From a5667b3999db9714fe12fd9d9d493d3b0809c6ec Mon Sep 17 00:00:00 2001 From: Zhiqi Date: Thu, 1 Jul 2021 17:24:29 +0800 Subject: [PATCH 0001/1892] init combo v2 --- combo/__init__.py | 0 combo/logical/__init__.py | 0 combo/physical/__init__.py | 0 combo/physical/operator/__init__.py | 0 combo/physical/operator/linear.py | 23 +++++++++ examples/linear.py | 80 +++++++++++++++++++++++++++++ scripts/env-setup.sh | 22 ++++++++ scripts/keep.py | 13 +++++ setup.py | 12 +++++ 9 files changed, 150 insertions(+) create mode 100644 combo/__init__.py create mode 100644 combo/logical/__init__.py create mode 100644 combo/physical/__init__.py create mode 100644 combo/physical/operator/__init__.py create mode 100644 combo/physical/operator/linear.py create mode 100644 examples/linear.py create mode 100644 scripts/env-setup.sh create mode 100644 scripts/keep.py create mode 100644 setup.py diff --git a/combo/__init__.py b/combo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/logical/__init__.py b/combo/logical/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/physical/__init__.py b/combo/physical/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/physical/operator/__init__.py b/combo/physical/operator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/physical/operator/linear.py b/combo/physical/operator/linear.py new file mode 100644 index 00000000..f9566d11 --- /dev/null +++ b/combo/physical/operator/linear.py @@ -0,0 +1,23 @@ +from typing import Optional +import torch +from torch import Tensor +from torch.overrides import has_torch_function_variadic, handle_torch_function + +def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: + r""" + Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. + + This operator supports :ref:`TensorFloat32`. + + Shape: + + - Input: :math:`(N, *, in\_features)` N is the batch size, `*` means any number of + additional dimensions + - Weight: :math:`(out\_features, in\_features)` + - Bias: :math:`(out\_features)` + - Output: :math:`(N, *, out\_features)` + """ + if has_torch_function_variadic(input, weight): + print('go through here') + return handle_torch_function(linear, (input, weight), input, weight, bias=bias) + return torch._C._nn.linear(input, weight, bias) \ No newline at end of file diff --git a/examples/linear.py b/examples/linear.py new file mode 100644 index 00000000..d05d75e4 --- /dev/null +++ b/examples/linear.py @@ -0,0 +1,80 @@ +import torch +from torch import nn +from torch import Tensor +from torch.nn.parameter import Parameter +import torch.functional as F + +import combo.physical.operator as combo_op + +import math +import argparse + + +class Linear(nn.Module): + + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super(Linear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input: Tensor) -> Tensor: + return combo_op.linear(input, self.weight, self.bias) + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=4., classes=1000): + super().__init__() + self.net = nn.Sequential( + Linear(dim, dim * mult), + nn.GELU(), + nn.Dropout(dropout), + Linear(dim * mult, dim) + ) + + self.classifier = Linear(dim, classes) + + def forward(self, x, labels): + output = self.net(x) + output = self.classifier(output) + loss = F.cross_entory(output, labels) + return loss + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--dim', type=int, default=1024) + parser.add_argument('--bs', type=int, default=8) + parser.add_argument('--classes', type=int, default=10) + args = parser.parse_args() + + torch.cuda.set_device(0) + model = FeedForward(args.dim) + model = model.cuda() + + inputs = torch.rand((args.bs, args.dim)).cuda() + labels = torch.randint((args.bs, args.classes)).cuda() + for _ in range(100): + loss = model(inputs, labels) + loss.backward() + print('Done.') \ No newline at end of file diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh new file mode 100644 index 00000000..a7ed129c --- /dev/null +++ b/scripts/env-setup.sh @@ -0,0 +1,22 @@ + +echo using docker image pytorch-cuda11.3: nvcr.io/nvidia/pytorch:21.04-py3 + +git config --global core.editor "vim" +git config --global user.name "Zhiqi Lin" +git config --global user.email "v-zhiql@microsoft.com" + +git config --global core.editor "vim" +git config --global user.name "Zhiqi Lin" +git config --global user.email "v-zhiql@microsoft.com" + + +cd /root +wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.tmux.conf +wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.vimrc + +echo 'export PATH=/opt/conda/bin:$PATH' >> ~/.bashrc +echo 'export PATH=/usr/local/cuda/bin:$PATH' >> ~/.bashrc +echo 'export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH' >> ~/.bashrc +echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc + +python setup.py develop diff --git a/scripts/keep.py b/scripts/keep.py new file mode 100644 index 00000000..14e4d93f --- /dev/null +++ b/scripts/keep.py @@ -0,0 +1,13 @@ +import torch +import time + +interval = 10 + +a = torch.rand((8192, 8192)).cuda() +b = torch.rand((8192, 8192)).cuda() + + +while True: + for _ in range(1000): + c = a * b + time.sleep(interval) diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..c24c4bd9 --- /dev/null +++ b/setup.py @@ -0,0 +1,12 @@ +import setuptools + +setuptools.setup( + name= 'combo', + version= '0.1', + author= 'Zhiqi Lin', + author_email= 'v-zhiql@microsoft.com', + description= 'Combo', + long_description= 'Combo', + packages= ['combo'], + python_requires= '>=3.6', +) From 700994446905621e876eb49d6cab79c496acfbd5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 1 Jul 2021 10:38:27 +0000 Subject: [PATCH 0002/1892] setup maintain scripts --- scripts/env-setup.sh | 14 +++++++------- scripts/keep.py | 9 +++++++-- 2 files changed, 14 insertions(+), 9 deletions(-) mode change 100644 => 100755 scripts/env-setup.sh diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh old mode 100644 new mode 100755 index a7ed129c..080b6fa9 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -5,14 +5,14 @@ git config --global core.editor "vim" git config --global user.name "Zhiqi Lin" git config --global user.email "v-zhiql@microsoft.com" -git config --global core.editor "vim" -git config --global user.name "Zhiqi Lin" -git config --global user.email "v-zhiql@microsoft.com" - +sudo git config --global core.editor "vim" +sudo git config --global user.name "Zhiqi Lin" +sudo git config --global user.email "v-zhiql@microsoft.com" +sudo chmod -R a+w /opt/conda -cd /root -wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.tmux.conf -wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.vimrc +sudo apt-get install tmux -y +wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.tmux.conf -O ~/.tmux.conf +wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.vimrc -O ~/.vimrc echo 'export PATH=/opt/conda/bin:$PATH' >> ~/.bashrc echo 'export PATH=/usr/local/cuda/bin:$PATH' >> ~/.bashrc diff --git a/scripts/keep.py b/scripts/keep.py index 14e4d93f..16ee0c11 100644 --- a/scripts/keep.py +++ b/scripts/keep.py @@ -1,13 +1,18 @@ import torch import time -interval = 10 +interval = 2 +torch.cuda.set_device(3) a = torch.rand((8192, 8192)).cuda() b = torch.rand((8192, 8192)).cuda() while True: - for _ in range(1000): + tic = time.time() + for _ in range(5000): c = a * b + torch.cuda.synchronize() + toc = time.time() + print('time span: {}s'.format(toc - tic)) time.sleep(interval) From 15616df40f24a4dc3c225622c09e49c2420d06f1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 1 Jul 2021 10:58:17 +0000 Subject: [PATCH 0003/1892] linear init --- combo/physical/operator/__init__.py | 1 + combo/physical/operator/linear.py | 4 ++-- examples/linear.py | 13 +++++++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/combo/physical/operator/__init__.py b/combo/physical/operator/__init__.py index e69de29b..0151a3e1 100644 --- a/combo/physical/operator/__init__.py +++ b/combo/physical/operator/__init__.py @@ -0,0 +1 @@ +from combo.physical.operator.linear import linear_op \ No newline at end of file diff --git a/combo/physical/operator/linear.py b/combo/physical/operator/linear.py index f9566d11..73bb8ed4 100644 --- a/combo/physical/operator/linear.py +++ b/combo/physical/operator/linear.py @@ -3,7 +3,7 @@ from torch import Tensor from torch.overrides import has_torch_function_variadic, handle_torch_function -def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: +def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. @@ -18,6 +18,6 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens - Output: :math:`(N, *, out\_features)` """ if has_torch_function_variadic(input, weight): - print('go through here') + print('note: this branch should not pass') return handle_torch_function(linear, (input, weight), input, weight, bias=bias) return torch._C._nn.linear(input, weight, bias) \ No newline at end of file diff --git a/examples/linear.py b/examples/linear.py index d05d75e4..61010862 100644 --- a/examples/linear.py +++ b/examples/linear.py @@ -2,7 +2,7 @@ from torch import nn from torch import Tensor from torch.nn.parameter import Parameter -import torch.functional as F +import torch.nn.functional as F import combo.physical.operator as combo_op @@ -38,11 +38,11 @@ def reset_parameters(self) -> None: nn.init.uniform_(self.bias, -bound, bound) def forward(self, input: Tensor) -> Tensor: - return combo_op.linear(input, self.weight, self.bias) + return combo_op.linear_op(input, self.weight, self.bias) class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=4., classes=1000): + def __init__(self, dim, dropout=0., mult=16, classes=1000): super().__init__() self.net = nn.Sequential( Linear(dim, dim * mult), @@ -56,7 +56,7 @@ def __init__(self, dim, dropout=0., mult=4., classes=1000): def forward(self, x, labels): output = self.net(x) output = self.classifier(output) - loss = F.cross_entory(output, labels) + loss = F.cross_entropy(output, labels) return loss @@ -64,16 +64,17 @@ def forward(self, x, labels): parser = argparse.ArgumentParser() parser.add_argument('--dim', type=int, default=1024) + parser.add_argument('--heads', type=int, default=16) parser.add_argument('--bs', type=int, default=8) parser.add_argument('--classes', type=int, default=10) args = parser.parse_args() torch.cuda.set_device(0) - model = FeedForward(args.dim) + model = FeedForward(args.dim, mult=args.heads, classes=args.classes) model = model.cuda() inputs = torch.rand((args.bs, args.dim)).cuda() - labels = torch.randint((args.bs, args.classes)).cuda() + labels = torch.randint(0, 10, (args.bs, )).cuda() for _ in range(100): loss = model(inputs, labels) loss.backward() From 4d2b9b8baf60e9ce61c5c5d420bb6ad99ce9c133 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 1 Jul 2021 12:08:34 +0000 Subject: [PATCH 0004/1892] maintainable --- .gitignore | 2 ++ scripts/keep.py | 35 +++++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 12 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..c0f52faf --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +*.egg-info \ No newline at end of file diff --git a/scripts/keep.py b/scripts/keep.py index 16ee0c11..0934ed10 100644 --- a/scripts/keep.py +++ b/scripts/keep.py @@ -1,18 +1,29 @@ import torch import time +import argparse -interval = 2 -torch.cuda.set_device(3) -a = torch.rand((8192, 8192)).cuda() -b = torch.rand((8192, 8192)).cuda() +def keep(rank, args): + torch.cuda.set_device(rank) + a = torch.rand((8192, 8192)).cuda() + b = torch.rand((8192, 8192)).cuda() -while True: - tic = time.time() - for _ in range(5000): - c = a * b - torch.cuda.synchronize() - toc = time.time() - print('time span: {}s'.format(toc - tic)) - time.sleep(interval) + while True: + tic = time.time() + for _ in range(5000): + c = a * b + torch.cuda.synchronize() + toc = time.time() + if rank == 0: + print('time span: {}s'.format(toc - tic)) + time.sleep(args.interval) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--interval', type=int, default=2) + parser.add_argument('--gpus', type=int, default=1) + args = parser.parse_args() + + torch.multiprocessing.spawn(keep, args=(args,), nprocs=args.gpus, join=True) From ccb32649ebf77f43776d29ba9cd4f95c4c7a1c58 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 1 Jul 2021 12:37:41 +0000 Subject: [PATCH 0005/1892] auto detect gpu util --- scripts/keep.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/scripts/keep.py b/scripts/keep.py index 0934ed10..d2e87aa9 100644 --- a/scripts/keep.py +++ b/scripts/keep.py @@ -2,6 +2,32 @@ import time import argparse +import subprocess +import re + +def get_gpu_util(rank): + + cmds = [ + 'nvidia-smi', + '-i', + str(rank), + ] + proc = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = proc.communicate() + outputs = stdout.decode('utf-8').split('\n') + + util = 0 + for output in outputs[::-1]: + # switch to performance line + if 'Default' in output: + # match all the numbers and return the last one + util = re.findall(r'\d+', output)[-1] + util = int(util) + break + else: + print("rank {}: couldn't match any, check GPU status!".format(rank)) + return util + def keep(rank, args): @@ -16,8 +42,16 @@ def keep(rank, args): torch.cuda.synchronize() toc = time.time() if rank == 0: - print('time span: {}s'.format(toc - tic)) + print('benchmark 8K matmul: time span: {}ms'.format((toc - tic) * 1000 / 5000)) time.sleep(args.interval) + while True: + util = get_gpu_util(rank) + if util >= 0: + break + print('rank {}: find gpu busy, keep sleeping...'.format(rank)) + time.sleep(args.interval) + print('rank {} gets up'.format(rank)) + if __name__ == '__main__': From d28543699acb01d3f167ce2be491df8dcd43de6b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 1 Jul 2021 12:43:25 +0000 Subject: [PATCH 0006/1892] work in progress: enable linear tensor partition --- combo/physical/comm/__init__.py | 0 combo/physical/comm/helper.py | 177 ++++++++++++++++++++++++++++++ combo/physical/device/__init__.py | 0 combo/physical/device/group.py | 11 ++ combo/physical/operator/linear.py | 28 ++++- 5 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 combo/physical/comm/__init__.py create mode 100644 combo/physical/comm/helper.py create mode 100644 combo/physical/device/__init__.py create mode 100644 combo/physical/device/group.py diff --git a/combo/physical/comm/__init__.py b/combo/physical/comm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/physical/comm/helper.py b/combo/physical/comm/helper.py new file mode 100644 index 00000000..c33f18ff --- /dev/null +++ b/combo/physical/comm/helper.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank + + +def split_tensor_along_last_dim(tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def _reduce(input_): + """All-reduce the the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if get_tensor_model_parallel_world_size()==1: + return input_ + + # All-reduce. + torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) + + return input_ + + +def _split(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size==1: + return input_ + + # Split along last dimension. + input_list = split_tensor_along_last_dim(input_, world_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = get_tensor_model_parallel_rank() + output = input_list[rank].contiguous() + + return output + + +def _gather(input_): + """Gather tensors and concatinate along the last dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size==1: + return input_ + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = get_tensor_model_parallel_rank() + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _ScatterToModelParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_): + return _split(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output) + + +class _GatherFromModelParallelRegion(torch.autograd.Function): + """Gather the input from model parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_): + return _gather(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output) + + +# ----------------- +# Helper functions. +# ----------------- + +def copy_to_tensor_model_parallel_region(input_): + return _CopyToModelParallelRegion.apply(input_) + + +def reduce_from_tensor_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) + + +def scatter_to_tensor_model_parallel_region(input_): + return _ScatterToModelParallelRegion.apply(input_) + + +def gather_from_tensor_model_parallel_region(input_): + return _GatherFromModelParallelRegion.apply(input_) diff --git a/combo/physical/device/__init__.py b/combo/physical/device/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/physical/device/group.py b/combo/physical/device/group.py new file mode 100644 index 00000000..b3218727 --- /dev/null +++ b/combo/physical/device/group.py @@ -0,0 +1,11 @@ +""" +Communication group settings among devices +""" + +import torch + +def create_group(device_count): + """ + Create device group + """ + pass \ No newline at end of file diff --git a/combo/physical/operator/linear.py b/combo/physical/operator/linear.py index 73bb8ed4..9f31e13c 100644 --- a/combo/physical/operator/linear.py +++ b/combo/physical/operator/linear.py @@ -3,6 +3,7 @@ from torch import Tensor from torch.overrides import has_torch_function_variadic, handle_torch_function + def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. @@ -20,4 +21,29 @@ def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> T if has_torch_function_variadic(input, weight): print('note: this branch should not pass') return handle_torch_function(linear, (input, weight), input, weight, bias=bias) - return torch._C._nn.linear(input, weight, bias) \ No newline at end of file + + # single GPU version + if True: + output = torch._C._nn.linear(input, weight, bias) + + # multi-GPU version + # - Assume input is full + # - split cloumn: Y = XA + b where A = [A_1, ..., A_p] + elif True: + # forward: identity; backward: allreduce + input = copy_to_tensor_model_parallel_region(input) + output = torch._C._nn.linear(input, weight, bias) + # forward: allgather; backward: scatter + output = gather_from_tensor_model_parallel_region(output) + + # multi-GPU version + # - Assume input is full + # - split row: Y = XA + b where X = [X_1, ..., X_p], A = [A_1 || ... || A_p] + elif True: + # forward: scatter; backward: allgather + input = scatter_to_tensor_model_parallel_region(input) + output = torch._C._nn.linear(input, weight, bias) + # forward: reduce; backward: identity + output = reduce_from_tensor_model_parallel_region(output) + + return output \ No newline at end of file From 1e2e0f5181d16d061ca4e86242ca6579d7de6c23 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 2 Jul 2021 05:37:26 +0000 Subject: [PATCH 0007/1892] add comm groups --- combo/physical/device/group.py | 70 +++++++++++++++++++++++++++++++--- tests/test_group.py | 26 +++++++++++++ 2 files changed, 91 insertions(+), 5 deletions(-) create mode 100644 tests/test_group.py diff --git a/combo/physical/device/group.py b/combo/physical/device/group.py index b3218727..baee1d84 100644 --- a/combo/physical/device/group.py +++ b/combo/physical/device/group.py @@ -3,9 +3,69 @@ """ import torch +import os -def create_group(device_count): - """ - Create device group - """ - pass \ No newline at end of file + +class DeviceGroup: + + class __DeviceGroup: + + def __init__(self): + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + # world_size=device_num, + # init_method='tcp://' + '{master_ip}:{port}'.format(master_ip=master_ip, port=port) + ) + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + # assume each node has the same device number + self.local_rank = int(os.environ.get('LOCAL_RANK')) + self.node_id = self.rank // torch.cuda.device_count() + self.groups = dict() + torch.cuda.set_device(self.local_rank) + + instance = None + + def __init__(self): + if not DeviceGroup.instance: + DeviceGroup.instance = DeviceGroup.__DeviceGroup() + + def __getattr__(self, name): + return getattr(self.instance, name) + + # def __setattr__(self, name): + # return setattr(self.instance, name) + + def __len__(self, name): + return DeviceGroup.instance.world_size + + def get_group(self, ranks): + """ + Create and return rank groups on-demand + """ + rank_bits = DeviceGroup.bitmap(ranks) + if rank_bits not in self.instance.groups: + self.groups[rank_bits] = torch.distributed.new_group(list(ranks)) + return self.groups[rank_bits] + + @staticmethod + def bitmap(ranks): + """ + map the rank list to the bit map string + """ + bits = '0' * DeviceGroup.instance.world_size + for rank in ranks: + if rank >= len(bits): + raise ValueError("rank {} out of range ({})".format(rank, len(bits))) + bits = bits[0:rank] + '1' + bits[rank+1:] + return bits + + def __repr__(self): + msg = 'node id: [{}] rank: [{}] local rank: [{}]\n'.format(self.node_id, self.rank, self.local_rank) + msg += 'communication groups (ranks):\n' + for bitmap, group in self.groups.items(): + ranks = [rank for rank, bit in enumerate(bitmap) if bit == '1'] + if self.instance.rank in ranks: + msg += '\t group {}: my group rank: [{}]\n'.format(ranks, torch.distributed.get_rank(group)) + return msg diff --git a/tests/test_group.py b/tests/test_group.py new file mode 100644 index 00000000..f10b75b1 --- /dev/null +++ b/tests/test_group.py @@ -0,0 +1,26 @@ +""" +Test this with: + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=6000 \ + --use_env \ + tests/test_group.py +""" + +from combo.physical.device.group import DeviceGroup + + + +if __name__ == '__main__': + + # init distributed + group = DeviceGroup() + + sub_group_1 = group.get_group([0,2]) + sub_group_2 = group.get_group([1,3]) + + print(group) \ No newline at end of file From e3371a7583fccc0c0343ed823508bad057c5f2c4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 2 Jul 2021 07:24:22 +0000 Subject: [PATCH 0008/1892] tensor parallelism boundary --- combo/physical/comm/__init__.py | 0 combo/physical/comm/helper.py | 177 ----------------------- combo/physical/operator/comm/__init__.py | 1 + combo/physical/operator/comm/boundary.py | 132 +++++++++++++++++ 4 files changed, 133 insertions(+), 177 deletions(-) delete mode 100644 combo/physical/comm/__init__.py delete mode 100644 combo/physical/comm/helper.py create mode 100644 combo/physical/operator/comm/__init__.py create mode 100644 combo/physical/operator/comm/boundary.py diff --git a/combo/physical/comm/__init__.py b/combo/physical/comm/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/combo/physical/comm/helper.py b/combo/physical/comm/helper.py deleted file mode 100644 index c33f18ff..00000000 --- a/combo/physical/comm/helper.py +++ /dev/null @@ -1,177 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank - - -def split_tensor_along_last_dim(tensor, num_partitions, - contiguous_split_chunks=False): - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = divide(tensor.size()[last_dim], num_partitions) - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -def _reduce(input_): - """All-reduce the the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if get_tensor_model_parallel_world_size()==1: - return input_ - - # All-reduce. - torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) - - return input_ - - -def _split(input_): - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size==1: - return input_ - - # Split along last dimension. - input_list = split_tensor_along_last_dim(input_, world_size) - - # Note: torch.split does not create contiguous tensors by default. - rank = get_tensor_model_parallel_rank() - output = input_list[rank].contiguous() - - return output - - -def _gather(input_): - """Gather tensors and concatinate along the last dimension.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size==1: - return input_ - - # Size and dimension. - last_dim = input_.dim() - 1 - rank = get_tensor_model_parallel_rank() - - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) - - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() - - return output - - -class _CopyToModelParallelRegion(torch.autograd.Function): - """Pass the input to the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_): - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output) - - -class _ReduceFromModelParallelRegion(torch.autograd.Function): - """All-reduce the input from the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return _reduce(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce(input_) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class _ScatterToModelParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - @staticmethod - def symbolic(graph, input_): - return _split(input_) - - @staticmethod - def forward(ctx, input_): - return _split(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather(grad_output) - - -class _GatherFromModelParallelRegion(torch.autograd.Function): - """Gather the input from model parallel region and concatinate.""" - - @staticmethod - def symbolic(graph, input_): - return _gather(input_) - - @staticmethod - def forward(ctx, input_): - return _gather(input_) - - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output) - - -# ----------------- -# Helper functions. -# ----------------- - -def copy_to_tensor_model_parallel_region(input_): - return _CopyToModelParallelRegion.apply(input_) - - -def reduce_from_tensor_model_parallel_region(input_): - return _ReduceFromModelParallelRegion.apply(input_) - - -def scatter_to_tensor_model_parallel_region(input_): - return _ScatterToModelParallelRegion.apply(input_) - - -def gather_from_tensor_model_parallel_region(input_): - return _GatherFromModelParallelRegion.apply(input_) diff --git a/combo/physical/operator/comm/__init__.py b/combo/physical/operator/comm/__init__.py new file mode 100644 index 00000000..7390d802 --- /dev/null +++ b/combo/physical/operator/comm/__init__.py @@ -0,0 +1 @@ +from combo.physical.operator.comm.boundary import * \ No newline at end of file diff --git a/combo/physical/operator/comm/boundary.py b/combo/physical/operator/comm/boundary.py new file mode 100644 index 00000000..fd8f0634 --- /dev/null +++ b/combo/physical/operator/comm/boundary.py @@ -0,0 +1,132 @@ + +import torch + +from combo.physical.device.group import DeviceGroup + + +__all__ = ['parallel_in', 'gather_out', 'scatter_in', 'reduce_out'] + + +def _reduce(input_, group): + """All-reduce the the input tensor across model parallel group.""" + + # allreduce + torch.distributed.all_reduce(input_, group=group) + return input_ + + +def _split(input_, dim, chunk_num, rank): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + # bypass the function if we are using only 1 GPU. + if chunk_num == 1: + return input_ + # split along specified dim + if input_.size()[dim] % chunk_num != 0: + raise RuntimeError("backward on Gather Out Error: un divideable") + dim_size = input_.size()[dim] // chunk_num + tensor_list = torch.split(input_, dim_size, dim=dim) + # note: torch.split does not create contiguous tensors by default. + output = tensor_list[rank].contiguous() + return output + + +def _gather(input_, dim, group): + """Gather tensors and concatinate along the last dimension.""" + + world_size = torch.distributed.get_world_size(group) + rank = torch.distributed.get_rank(group) + # bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + # note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=dim).contiguous() + return output + + +class _ParallelIn(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def forward(ctx, input_, ranks): + # record group + group = DeviceGroup().get_group(ranks) + ctx.constants = group + # identitfy forward + return input_ + + @staticmethod + def backward(ctx, grad_output): + # allreduce + group = ctx.constants + return _reduce(grad_output, group) + + +class _GatherOut(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def forward(ctx, input_, dim, ranks): + # record group + group = DeviceGroup().get_group(ranks) + ctx.constants = (group, dim) + # allgather + return _gather(input_, dim, group) + + @staticmethod + def backward(ctx, grad_output): + group, dim = ctx.constants + world_size = torch.distributed.get_world_size(group) + rank = torch.distributed.get_rank(group) + return _split(grad_output, dim, world_size, rank) + + +class _ScatterIn(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def forward(ctx, input_, dim, ranks): + group = DeviceGroup().get_group(ranks) + world_size = torch.distributed.get_world_size(group) + rank = torch.distributed.get_rank(group) + ctx.constants = (group, dim) + return _split(input_, dim, world_size, rank) + + @staticmethod + def backward(ctx, grad_output): + group, dim = ctx.constants + return _gather(grad_output, dim, group) + + +class _ReduceOut(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def forward(ctx, input_, ranks): + group = DeviceGroup().get_group(ranks) + return _reduce(input_, group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +def parallel_in(input_, ranks): + return _ParallelIn.apply(input_, ranks) + + +def gather_out(input_, dim, ranks): + return _GatherOut.apply(input_, dim, ranks) + + +def scatter_in(input_, dim, ranks): + return _ScatterIn.apply(input_, dim, ranks) + + +def reduce_out(input_): + return _ReduceOut.apply(input_) From 553cf2e2c6753f0aede05f8ae30c6b67957c1992 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Jul 2021 14:00:13 +0000 Subject: [PATCH 0009/1892] none group if all the ranks are included --- combo/physical/device/group.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/combo/physical/device/group.py b/combo/physical/device/group.py index baee1d84..f6dfc470 100644 --- a/combo/physical/device/group.py +++ b/combo/physical/device/group.py @@ -43,7 +43,11 @@ def __len__(self, name): def get_group(self, ranks): """ Create and return rank groups on-demand + + None will be returned if length of ranks are equal to world size """ + if len(ranks) == self.instance.world_size: + return None rank_bits = DeviceGroup.bitmap(ranks) if rank_bits not in self.instance.groups: self.groups[rank_bits] = torch.distributed.new_group(list(ranks)) From a5709795a06daeb7af4f91aa6868b4bb910a0d6c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Jul 2021 14:01:02 +0000 Subject: [PATCH 0010/1892] fix boundary backward --- combo/physical/operator/comm/boundary.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/combo/physical/operator/comm/boundary.py b/combo/physical/operator/comm/boundary.py index fd8f0634..8af61b2a 100644 --- a/combo/physical/operator/comm/boundary.py +++ b/combo/physical/operator/comm/boundary.py @@ -1,3 +1,7 @@ +""" +Autograd backward needs to return the same number of gradients as input, +even if they are not tensors. +""" import torch @@ -64,7 +68,7 @@ def forward(ctx, input_, ranks): def backward(ctx, grad_output): # allreduce group = ctx.constants - return _reduce(grad_output, group) + return _reduce(grad_output, group), None class _GatherOut(torch.autograd.Function): @@ -83,7 +87,7 @@ def backward(ctx, grad_output): group, dim = ctx.constants world_size = torch.distributed.get_world_size(group) rank = torch.distributed.get_rank(group) - return _split(grad_output, dim, world_size, rank) + return _split(grad_output, dim, world_size, rank), None, None class _ScatterIn(torch.autograd.Function): @@ -100,7 +104,7 @@ def forward(ctx, input_, dim, ranks): @staticmethod def backward(ctx, grad_output): group, dim = ctx.constants - return _gather(grad_output, dim, group) + return _gather(grad_output, dim, group), None, None class _ReduceOut(torch.autograd.Function): From 7193475dfc62863eb307879dbaf7050e603efc41 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Jul 2021 14:03:35 +0000 Subject: [PATCH 0011/1892] fix bugs on reduce out --- combo/physical/operator/comm/boundary.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/combo/physical/operator/comm/boundary.py b/combo/physical/operator/comm/boundary.py index 8af61b2a..be5f9465 100644 --- a/combo/physical/operator/comm/boundary.py +++ b/combo/physical/operator/comm/boundary.py @@ -117,7 +117,7 @@ def forward(ctx, input_, ranks): @staticmethod def backward(ctx, grad_output): - return grad_output + return grad_output, None def parallel_in(input_, ranks): @@ -132,5 +132,5 @@ def scatter_in(input_, dim, ranks): return _ScatterIn.apply(input_, dim, ranks) -def reduce_out(input_): - return _ReduceOut.apply(input_) +def reduce_out(input_, ranks): + return _ReduceOut.apply(input_, ranks) From 719d9c9ee7ddf7a29f793b0ca3a1b0a8649ec277 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Jul 2021 14:05:18 +0000 Subject: [PATCH 0012/1892] multi-gpu linear version --- combo/physical/operator/linear.py | 34 ++++++++++++++++++++++--------- examples/linear.py | 17 +++++++++++++++- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/combo/physical/operator/linear.py b/combo/physical/operator/linear.py index 9f31e13c..fa9311fe 100644 --- a/combo/physical/operator/linear.py +++ b/combo/physical/operator/linear.py @@ -3,6 +3,9 @@ from torch import Tensor from torch.overrides import has_torch_function_variadic, handle_torch_function +import combo.physical.operator.comm as comm +from combo.physical.device.group import DeviceGroup + def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" @@ -22,28 +25,39 @@ def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> T print('note: this branch should not pass') return handle_torch_function(linear, (input, weight), input, weight, bias=bias) + devices = [0, 1] + rank = torch.distributed.get_rank(DeviceGroup().get_group(devices)) + # single GPU version if True: output = torch._C._nn.linear(input, weight, bias) # multi-GPU version # - Assume input is full - # - split cloumn: Y = XA + b where A = [A_1, ..., A_p] - elif True: + # - split cloumn of W: Y = XW + b where W = [W_1, ..., W_p] + elif False: + # get weight chunk + weight = torch.chunk(weight, chunks=len(devices), dim=0)[rank].contiguous() + if bias is not None: + bias = torch.chunk(bias, chunks=len(devices), dim=0)[rank].contiguous() # forward: identity; backward: allreduce - input = copy_to_tensor_model_parallel_region(input) + input = comm.parallel_in(input, ranks=devices) output = torch._C._nn.linear(input, weight, bias) - # forward: allgather; backward: scatter - output = gather_from_tensor_model_parallel_region(output) + # forward: allgather; backward: split + output = comm.gather_out(output, dim=-1, ranks=devices) - # multi-GPU version + # multi-GPU version Y = XW + b # - Assume input is full - # - split row: Y = XA + b where X = [X_1, ..., X_p], A = [A_1 || ... || A_p] - elif True: + # - split row of W, column of X: + # - Y = X = [X_1, ..., X_p] + # - W = [W_1 // ... // W_p] + elif False: + # get weight chunk + weight = torch.chunk(weight, chunks=len(devices), dim=1)[rank] # forward: scatter; backward: allgather - input = scatter_to_tensor_model_parallel_region(input) + input = comm.scatter_in(input, dim=-1, ranks=devices) output = torch._C._nn.linear(input, weight, bias) # forward: reduce; backward: identity - output = reduce_from_tensor_model_parallel_region(output) + output = comm.reduce_out(output, ranks=devices) return output \ No newline at end of file diff --git a/examples/linear.py b/examples/linear.py index 61010862..96deec6b 100644 --- a/examples/linear.py +++ b/examples/linear.py @@ -1,9 +1,21 @@ +""" +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=6000 \ + --use_env \ + examples/linear.py +""" + import torch from torch import nn from torch import Tensor from torch.nn.parameter import Parameter import torch.nn.functional as F +import combo import combo.physical.operator as combo_op import math @@ -69,7 +81,10 @@ def forward(self, x, labels): parser.add_argument('--classes', type=int, default=10) args = parser.parse_args() - torch.cuda.set_device(0) + # init distributed env + group = combo.physical.device.group.DeviceGroup() + print(group) + model = FeedForward(args.dim, mult=args.heads, classes=args.classes) model = model.cuda() From 8fb17ba5357f7e7535d04edbf013695f9ba98707 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Jul 2021 14:07:52 +0000 Subject: [PATCH 0013/1892] nccl test --- tests/test_nccl.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 tests/test_nccl.py diff --git a/tests/test_nccl.py b/tests/test_nccl.py new file mode 100644 index 00000000..348244c6 --- /dev/null +++ b/tests/test_nccl.py @@ -0,0 +1,89 @@ + +""" +Single node usage: +e.g., 8 GPUs +python -m torch.distributed.launch --nproc_per_node=4 test_nccl.py +Multi-node usage: +e.g., 2-node each with 8 GPUs +python -m torch.distributed.launch --nproc_per_node=8 --node_rank=0 --master_port=6000 --master_addr='master ip iddress' --nnodes=2 test_nccl.py +python -m torch.distributed.launch --nproc_per_node=8 --node_rank=1 --master_port=6000 --master_addr='master ip iddress' --nnodes=2 test_nccl.py +""" + +import torch +import time +import sys +import os +import argparse + + +def print_each_rank(msg, select=True, outfile=''): + myrank = torch.distributed.get_rank() + outfile = sys.stdout if outfile == '' else outfile + for rank in range(torch.distributed.get_world_size()): + if select: + if myrank == rank: + f = open(outfile, 'a') if outfile != sys.stdout else sys.stdout + f.write('rank [{}]: {}\n'.format(rank, msg)) + if outfile != sys.stdout: + f.close() + torch.distributed.barrier() + + +def test_nccl(size, local_rank): + msg = torch.ones((size,)).cuda() + # warm up + for _ in range(20): + out = torch.distributed.all_reduce(msg) + torch.cuda.synchronize() + # profile + tic = time.perf_counter() + for _ in range(100): + out = torch.distributed.all_reduce(msg) + torch.cuda.synchronize() + toc = time.perf_counter() + + span = (toc - tic) * 1000 / 100 # in ms + bandwidth = size / span / 1e6 # in GB/s + print_each_rank( + 'NCCL Allreduce | Msg Size: {:.0f} MB | Algo Bandwidth: {:.2f} GB/s'.format( + size / 1024 / 1024, bandwidth), + select=(local_rank==0), + ) + +def test_allgather(size, local_rank): + msg = torch.ones((size,)).cuda() + tensor_list = [torch.empty_like(msg) for _ in range(torch.distributed.get_world_size())] + + tic = time.perf_counter() + for _ in range(100): + out = torch.distributed.all_gather(tensor_list, msg) + torch.cuda.synchronize() + print_each_rank('Passed all-gather') + toc = time.perf_counter() + + +def benchmark(args): + size = args.begin + while size <= args.end: + # test_allgather(size * 1024 * 1024, args.local_rank) + test_nccl(size * 1024 * 1024, args.local_rank) # MB to B + size *= 2 + print_each_rank('test on nccl is done') + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--begin', type=int, default=4, + help='start message size in MB') + parser.add_argument('--end', type=int, default=64, + help='end message size in MB') + parser.add_argument('--local_rank', type=int, required=True, + help='specified by torch.distributed.launch') + args = parser.parse_args() + + torch.distributed.init_process_group(backend='nccl') + print_each_rank('local rank-{} launches'.format(args.local_rank)) + + torch.cuda.set_device(args.local_rank) + benchmark(args) \ No newline at end of file From 9d9e090b07d843903617bab0644f66e64230530d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Jul 2021 14:19:47 +0000 Subject: [PATCH 0014/1892] list information we need --- combo/physical/operator/linear.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/combo/physical/operator/linear.py b/combo/physical/operator/linear.py index fa9311fe..67568aba 100644 --- a/combo/physical/operator/linear.py +++ b/combo/physical/operator/linear.py @@ -25,6 +25,11 @@ def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> T print('note: this branch should not pass') return handle_torch_function(linear, (input, weight), input, weight, bias=bias) + # Information needed for enabling multiple GPUs: + # - involved devices -> (could be involved in tensor interfacee design) + # - which algorithm to take -> (semantic description / pattern match?) + # e.g., semantic description: allgather(split(weight, dim=0) * input + split(bias, dim=0)) + # - how we handle weight -> (everytime we need chunk weight / bias if we only focus on op) devices = [0, 1] rank = torch.distributed.get_rank(DeviceGroup().get_group(devices)) From a743c34040205532ddfda69b3511a71891ba16cf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 6 Jul 2021 08:38:43 +0000 Subject: [PATCH 0015/1892] add generic physical operator interface --- combo/physical/operator/generic.py | 49 ++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 combo/physical/operator/generic.py diff --git a/combo/physical/operator/generic.py b/combo/physical/operator/generic.py new file mode 100644 index 00000000..d4603475 --- /dev/null +++ b/combo/physical/operator/generic.py @@ -0,0 +1,49 @@ +""" +Physical Generic Operator definition. + +The output communication works in a lazy execution way. Communication will only +happen in the front of the next executed op in case the layout doesn't match. +""" + +class GenericOp: + + def __init__(self, func): + """ + func: Should be a logical operator handling holistic tensors. + """ + + # operator: take any inputs and generate output + self.F = func + + # function inputs requirement + self.input_layout = dict() + + # the expected function output holistic layout + self.output_layout = dict() + + def boundary_in(self, args, **kwargs): + """ + Transform tensors in args and kwargs to match the + input layout requirement + """ + pass + + def warp_to_holistic_tensor(self, outputs): + """ + Wrap local computed tensor into a holistic view + by using self.output_layout + """ + pass + + def execute(self, args, **kwargs): + + # data transformations to match input layout requirement + self.boundary_in(args, kwargs) + + # do execution + outputs = self.F(args, kwargs) + + # wrap in holistic tensor with output layout + outputs = self.warp_to_holistic_tensor(outputs) + + return outputs \ No newline at end of file From 26f8d43ec8b6d3a8d50f42e257aaa6aba73f52a5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Jul 2021 08:20:31 +0000 Subject: [PATCH 0016/1892] re-org operator layers --- combo/{logical => holist}/__init__.py | 0 .../{physical => holist/operator}/__init__.py | 0 combo/{physical/device => logic}/__init__.py | 0 combo/logic/generics.py | 54 +++++++++++++++++++ combo/physic/__init__.py | 0 combo/physic/device/__init__.py | 0 combo/{physical => physic}/device/group.py | 0 .../{physical => physic}/operator/__init__.py | 0 .../operator/comm/__init__.py | 0 .../operator/comm/boundary.py | 0 .../operator/generics.py} | 0 combo/{physical => physic}/operator/linear.py | 14 ++++- 12 files changed, 66 insertions(+), 2 deletions(-) rename combo/{logical => holist}/__init__.py (100%) rename combo/{physical => holist/operator}/__init__.py (100%) rename combo/{physical/device => logic}/__init__.py (100%) create mode 100644 combo/logic/generics.py create mode 100644 combo/physic/__init__.py create mode 100644 combo/physic/device/__init__.py rename combo/{physical => physic}/device/group.py (100%) rename combo/{physical => physic}/operator/__init__.py (100%) rename combo/{physical => physic}/operator/comm/__init__.py (100%) rename combo/{physical => physic}/operator/comm/boundary.py (100%) rename combo/{physical/operator/generic.py => physic/operator/generics.py} (100%) rename combo/{physical => physic}/operator/linear.py (82%) diff --git a/combo/logical/__init__.py b/combo/holist/__init__.py similarity index 100% rename from combo/logical/__init__.py rename to combo/holist/__init__.py diff --git a/combo/physical/__init__.py b/combo/holist/operator/__init__.py similarity index 100% rename from combo/physical/__init__.py rename to combo/holist/operator/__init__.py diff --git a/combo/physical/device/__init__.py b/combo/logic/__init__.py similarity index 100% rename from combo/physical/device/__init__.py rename to combo/logic/__init__.py diff --git a/combo/logic/generics.py b/combo/logic/generics.py new file mode 100644 index 00000000..219d5507 --- /dev/null +++ b/combo/logic/generics.py @@ -0,0 +1,54 @@ +""" + +A Logical Operator: + * Statusless + * Can be executed by only one kernel (atomic) on single device + +Logical Operator + |- Holistic Operator 1 + | |- Physical Operator(s) + |- Holistic Operator 2 + |- ... + +Holistic operators are allowed to nested in hybrid-distribution strategy + +""" + +class HolisticOpFactory: + + def __init__(self): + + self.holist_ops = list() + + def register(self, holistic_op): + """ + Register a holistic op as one of the anchors + """ + #TODO: type check + self.holist_ops.append(holist_ops) + + def composite_op(self, args, **kwargs): + """ + Given input tensor args, choose holistic operator(s) + for distributed execution plan + + Returns: + An hybrid-operator function which may composite by + nested holistic operators + """ + pass + + + +class GenericLogicalOp: + + def __init__(self): + + # candidate holistic operator + self.factory = HolisticOpFactory() + + def __call__(self, args, **kwargs): + """ + Policy here to determine which holistic operator(s) are called + """ + pass \ No newline at end of file diff --git a/combo/physic/__init__.py b/combo/physic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/physic/device/__init__.py b/combo/physic/device/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/physical/device/group.py b/combo/physic/device/group.py similarity index 100% rename from combo/physical/device/group.py rename to combo/physic/device/group.py diff --git a/combo/physical/operator/__init__.py b/combo/physic/operator/__init__.py similarity index 100% rename from combo/physical/operator/__init__.py rename to combo/physic/operator/__init__.py diff --git a/combo/physical/operator/comm/__init__.py b/combo/physic/operator/comm/__init__.py similarity index 100% rename from combo/physical/operator/comm/__init__.py rename to combo/physic/operator/comm/__init__.py diff --git a/combo/physical/operator/comm/boundary.py b/combo/physic/operator/comm/boundary.py similarity index 100% rename from combo/physical/operator/comm/boundary.py rename to combo/physic/operator/comm/boundary.py diff --git a/combo/physical/operator/generic.py b/combo/physic/operator/generics.py similarity index 100% rename from combo/physical/operator/generic.py rename to combo/physic/operator/generics.py diff --git a/combo/physical/operator/linear.py b/combo/physic/operator/linear.py similarity index 82% rename from combo/physical/operator/linear.py rename to combo/physic/operator/linear.py index 67568aba..35eda888 100644 --- a/combo/physical/operator/linear.py +++ b/combo/physic/operator/linear.py @@ -3,8 +3,8 @@ from torch import Tensor from torch.overrides import has_torch_function_variadic, handle_torch_function -import combo.physical.operator.comm as comm -from combo.physical.device.group import DeviceGroup +import combo.physic.operator.comm as comm +from combo.physic.device.group import DeviceGroup def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: @@ -65,4 +65,14 @@ def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> T # forward: reduce; backward: identity output = comm.reduce_out(output, ranks=devices) + # Pesudo-code + else: + # data parallelism + input=[(0, Split())], weight=[], bias=[], output=[(0, Split())] + # tensor parallelism, weight column split + input=[], weight=[(0, Split())], bias=[(0, Split())], output=[(-1, Split())] + # tensor parallelism: data column + weight row + input=[(1, Split())] weight=[(1, Split()], bias=[], output=[(ALL, Partial(Sum))] + + return output \ No newline at end of file From bd53a3fcc623642ec7e016fe9a5c451b0a9c7d9f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Jul 2021 08:31:46 +0000 Subject: [PATCH 0017/1892] update logic operator interface --- combo/logic/generics.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/combo/logic/generics.py b/combo/logic/generics.py index 219d5507..3bdc7005 100644 --- a/combo/logic/generics.py +++ b/combo/logic/generics.py @@ -27,7 +27,7 @@ def register(self, holistic_op): #TODO: type check self.holist_ops.append(holist_ops) - def composite_op(self, args, **kwargs): + def get_op(self, args, **kwargs): """ Given input tensor args, choose holistic operator(s) for distributed execution plan @@ -51,4 +51,7 @@ def __call__(self, args, **kwargs): """ Policy here to determine which holistic operator(s) are called """ - pass \ No newline at end of file + composite_op = self.factory.get_op(args, kwargs) + # run operator with the strategy plan + outputs = composite_op(args, kwargs) + return outputs \ No newline at end of file From 7c3d4a8a092338ab9ffe0ebeb1d3490b4367604a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Jul 2021 08:40:21 +0000 Subject: [PATCH 0018/1892] update holistic operator interface --- combo/holist/operator/generics.py | 32 +++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 combo/holist/operator/generics.py diff --git a/combo/holist/operator/generics.py b/combo/holist/operator/generics.py new file mode 100644 index 00000000..69409ac8 --- /dev/null +++ b/combo/holist/operator/generics.py @@ -0,0 +1,32 @@ + +""" +Holistic Operator Generics + +The holistic operator needed to be registered into logical op +""" + +class GenericHolisticOp: + + def __init__(self, input_layout, output_layout): + + # holistic layout of input to wark on + self.input_layout = input_layout + # expected holistic layout of output + self.output_layout = output_layout + + def input_transform(self, args, **kwargs): + """input transformation to the required layout""" + pass + + def forward(self, args, **kwargs): + """Expert code for doing operation""" + pass + + def __call__(self, args, **kwargs): + """Operator execution""" + + self.input_transform(args, kwargs) + + outputs = self.forward(args, kwargs) + + return outputs From 1d3f92b3d6073a35dc05e7e1f04c91a26905a30b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Jul 2021 12:50:22 +0000 Subject: [PATCH 0019/1892] holist operator generics --- combo/holist/operator/generics.py | 39 +++++++++++++++++++------ combo/physic/operator/generics.py | 48 ++----------------------------- 2 files changed, 32 insertions(+), 55 deletions(-) diff --git a/combo/holist/operator/generics.py b/combo/holist/operator/generics.py index 69409ac8..ada37f6a 100644 --- a/combo/holist/operator/generics.py +++ b/combo/holist/operator/generics.py @@ -3,19 +3,36 @@ Holistic Operator Generics The holistic operator needed to be registered into logical op + +The output communication works in a lazy execution way. Communication will only +happen in the front of the next executed op in case the layout doesn't match. """ class GenericHolisticOp: def __init__(self, input_layout, output_layout): - # holistic layout of input to wark on - self.input_layout = input_layout - # expected holistic layout of output - self.output_layout = output_layout + # operator: take any inputs and generate output + self.F = func + + # holistic layout of input to work on + self.input_layout = dict() - def input_transform(self, args, **kwargs): - """input transformation to the required layout""" + # holistic layout of output + self.output_layout = dict() + + def boundary_in(self, args, **kwargs): + """ + Transform tensors in args and kwargs to match the + input layout requirement + """ + pass + + def warp_to_holistic_tensor(self, outputs): + """ + Wrap local computed tensor into a holistic view + by using self.output_layout + """ pass def forward(self, args, **kwargs): @@ -23,10 +40,14 @@ def forward(self, args, **kwargs): pass def __call__(self, args, **kwargs): - """Operator execution""" - self.input_transform(args, kwargs) + # data transformations to match input layout requirement + self.boundary_in(args, kwargs) + + # do execution + outputs = self.F(args, kwargs) - outputs = self.forward(args, kwargs) + # wrap in holistic tensor with output layout + outputs = self.warp_to_holistic_tensor(outputs) return outputs diff --git a/combo/physic/operator/generics.py b/combo/physic/operator/generics.py index d4603475..68ec043f 100644 --- a/combo/physic/operator/generics.py +++ b/combo/physic/operator/generics.py @@ -1,49 +1,5 @@ """ -Physical Generic Operator definition. - -The output communication works in a lazy execution way. Communication will only -happen in the front of the next executed op in case the layout doesn't match. +This should be the interface with C level kernel launch """ -class GenericOp: - - def __init__(self, func): - """ - func: Should be a logical operator handling holistic tensors. - """ - - # operator: take any inputs and generate output - self.F = func - - # function inputs requirement - self.input_layout = dict() - - # the expected function output holistic layout - self.output_layout = dict() - - def boundary_in(self, args, **kwargs): - """ - Transform tensors in args and kwargs to match the - input layout requirement - """ - pass - - def warp_to_holistic_tensor(self, outputs): - """ - Wrap local computed tensor into a holistic view - by using self.output_layout - """ - pass - - def execute(self, args, **kwargs): - - # data transformations to match input layout requirement - self.boundary_in(args, kwargs) - - # do execution - outputs = self.F(args, kwargs) - - # wrap in holistic tensor with output layout - outputs = self.warp_to_holistic_tensor(outputs) - - return outputs \ No newline at end of file +import torch From 5f864d511341fb1098245cea45ddc597c08344c8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 8 Jul 2021 06:10:12 +0000 Subject: [PATCH 0020/1892] update generics abstraction --- combo/holist/operator/generics.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/combo/holist/operator/generics.py b/combo/holist/operator/generics.py index ada37f6a..ddab26e8 100644 --- a/combo/holist/operator/generics.py +++ b/combo/holist/operator/generics.py @@ -12,9 +12,6 @@ class GenericHolisticOp: def __init__(self, input_layout, output_layout): - # operator: take any inputs and generate output - self.F = func - # holistic layout of input to work on self.input_layout = dict() @@ -36,7 +33,8 @@ def warp_to_holistic_tensor(self, outputs): pass def forward(self, args, **kwargs): - """Expert code for doing operation""" + """Expert code for doing operation + Call to the physical operator for execution""" pass def __call__(self, args, **kwargs): @@ -45,7 +43,7 @@ def __call__(self, args, **kwargs): self.boundary_in(args, kwargs) # do execution - outputs = self.F(args, kwargs) + outputs = self.forward(args, kwargs) # wrap in holistic tensor with output layout outputs = self.warp_to_holistic_tensor(outputs) From 0541f2c477c0ba66d5cccbc145575e4ef3955b71 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 14 Jul 2021 03:08:35 +0000 Subject: [PATCH 0021/1892] re-organize levels --- combo/{holist => device}/__init__.py | 0 .../operator => device/physic}/__init__.py | 0 .../{physic/device => device/physic}/group.py | 0 combo/{logic => operator}/__init__.py | 0 combo/{physic => operator/holist}/__init__.py | 0 .../operator => operator/holist}/generics.py | 0 .../device => operator/logic}/__init__.py | 0 combo/{ => operator}/logic/generics.py | 0 .../operator => operator/physic}/__init__.py | 0 .../physic}/comm/__init__.py | 0 .../physic}/comm/boundary.py | 0 .../operator => operator/physic}/generics.py | 0 .../operator => operator/physic}/linear.py | 0 combo/tensor/__init__.py | 0 combo/tensor/logic/__init__.py | 0 combo/tensor/logic/community.py | 33 +++++++++++++++++++ combo/tensor/logic/tensor.py | 8 +++++ combo/tensor/physic/__init__.py | 0 combo/tensor/physic/tensor.py | 0 19 files changed, 41 insertions(+) rename combo/{holist => device}/__init__.py (100%) rename combo/{holist/operator => device/physic}/__init__.py (100%) rename combo/{physic/device => device/physic}/group.py (100%) rename combo/{logic => operator}/__init__.py (100%) rename combo/{physic => operator/holist}/__init__.py (100%) rename combo/{holist/operator => operator/holist}/generics.py (100%) rename combo/{physic/device => operator/logic}/__init__.py (100%) rename combo/{ => operator}/logic/generics.py (100%) rename combo/{physic/operator => operator/physic}/__init__.py (100%) rename combo/{physic/operator => operator/physic}/comm/__init__.py (100%) rename combo/{physic/operator => operator/physic}/comm/boundary.py (100%) rename combo/{physic/operator => operator/physic}/generics.py (100%) rename combo/{physic/operator => operator/physic}/linear.py (100%) create mode 100644 combo/tensor/__init__.py create mode 100644 combo/tensor/logic/__init__.py create mode 100644 combo/tensor/logic/community.py create mode 100644 combo/tensor/logic/tensor.py create mode 100644 combo/tensor/physic/__init__.py create mode 100644 combo/tensor/physic/tensor.py diff --git a/combo/holist/__init__.py b/combo/device/__init__.py similarity index 100% rename from combo/holist/__init__.py rename to combo/device/__init__.py diff --git a/combo/holist/operator/__init__.py b/combo/device/physic/__init__.py similarity index 100% rename from combo/holist/operator/__init__.py rename to combo/device/physic/__init__.py diff --git a/combo/physic/device/group.py b/combo/device/physic/group.py similarity index 100% rename from combo/physic/device/group.py rename to combo/device/physic/group.py diff --git a/combo/logic/__init__.py b/combo/operator/__init__.py similarity index 100% rename from combo/logic/__init__.py rename to combo/operator/__init__.py diff --git a/combo/physic/__init__.py b/combo/operator/holist/__init__.py similarity index 100% rename from combo/physic/__init__.py rename to combo/operator/holist/__init__.py diff --git a/combo/holist/operator/generics.py b/combo/operator/holist/generics.py similarity index 100% rename from combo/holist/operator/generics.py rename to combo/operator/holist/generics.py diff --git a/combo/physic/device/__init__.py b/combo/operator/logic/__init__.py similarity index 100% rename from combo/physic/device/__init__.py rename to combo/operator/logic/__init__.py diff --git a/combo/logic/generics.py b/combo/operator/logic/generics.py similarity index 100% rename from combo/logic/generics.py rename to combo/operator/logic/generics.py diff --git a/combo/physic/operator/__init__.py b/combo/operator/physic/__init__.py similarity index 100% rename from combo/physic/operator/__init__.py rename to combo/operator/physic/__init__.py diff --git a/combo/physic/operator/comm/__init__.py b/combo/operator/physic/comm/__init__.py similarity index 100% rename from combo/physic/operator/comm/__init__.py rename to combo/operator/physic/comm/__init__.py diff --git a/combo/physic/operator/comm/boundary.py b/combo/operator/physic/comm/boundary.py similarity index 100% rename from combo/physic/operator/comm/boundary.py rename to combo/operator/physic/comm/boundary.py diff --git a/combo/physic/operator/generics.py b/combo/operator/physic/generics.py similarity index 100% rename from combo/physic/operator/generics.py rename to combo/operator/physic/generics.py diff --git a/combo/physic/operator/linear.py b/combo/operator/physic/linear.py similarity index 100% rename from combo/physic/operator/linear.py rename to combo/operator/physic/linear.py diff --git a/combo/tensor/__init__.py b/combo/tensor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/tensor/logic/__init__.py b/combo/tensor/logic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/tensor/logic/community.py b/combo/tensor/logic/community.py new file mode 100644 index 00000000..4a4fad69 --- /dev/null +++ b/combo/tensor/logic/community.py @@ -0,0 +1,33 @@ + + +class Community: + + def __init__(self, logical_tensor, reduction=None): + """Create Community based on the logical tensor + + Attribute: + parent (LogicalTensor): + Logical Tensor the Community belongs to + reduction (Callable or None): + Reduction function for retrieve back physical tensors + + """ + self.parent = logical_tensor + self.reduction = reduction + + def spread(self, device_list): + """Create physical tensors and spread to devices + + Argument: + device_list (list[int]): device id list + + Return: + + """ + pass + + def fuse(self): + """Fuse the spread physical tensors into the one + Perform reduction function to get the results on each physical tensor + """ + pass diff --git a/combo/tensor/logic/tensor.py b/combo/tensor/logic/tensor.py new file mode 100644 index 00000000..dc4fb8b4 --- /dev/null +++ b/combo/tensor/logic/tensor.py @@ -0,0 +1,8 @@ + + +class LogicTensor: + + def __init__(self, ): + + self.communities = list() + diff --git a/combo/tensor/physic/__init__.py b/combo/tensor/physic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/combo/tensor/physic/tensor.py b/combo/tensor/physic/tensor.py new file mode 100644 index 00000000..e69de29b From 3589f1310b712468f62dcabc5bd7fe8fd52c5169 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 14 Jul 2021 07:10:11 +0000 Subject: [PATCH 0022/1892] update virtual community interface --- combo/tensor/logic/community.py | 54 ++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/combo/tensor/logic/community.py b/combo/tensor/logic/community.py index 4a4fad69..83bad95e 100644 --- a/combo/tensor/logic/community.py +++ b/combo/tensor/logic/community.py @@ -1,3 +1,34 @@ +import torch + + +__all__ = ['ReductionOpPool', 'Community'] + + +class _Reduction(type): + + Sum = torch.distributed.all_reduce + + # identity for replica + Replica = lambda physical_tensor, group : physical_tensor + + def register(name, udf): + """ + Reduction functions should be in function format: + + Arguments: + PhysicalTensor + Communication Group + + Return: + PhysicalTensor + """ + if hasattr(cls, name): + raise KeyError("{} is registered".format(name)) + setattr(cls, name, udf) + + +class ReductionOpPool(metaclass=_Reduction): + pass class Community: @@ -16,18 +47,31 @@ def __init__(self, logical_tensor, reduction=None): self.reduction = reduction def spread(self, device_list): - """Create physical tensors and spread to devices + """Spread physical tensors to devices + + Create physical tensors for this community and spread out + based on the given device list. + + This offers policy module an interface to decide which devices + to spread. Argument: device_list (list[int]): device id list Return: + PhysicalTensor(s) or None: + For SPMD programming model: + if current device is in the `device_list`, + than return the corresponding physical tensor, + else None + + For Global View programming model: + return list[PhysicalTensor] with the same + order of `device_list`. """ pass - def fuse(self): - """Fuse the spread physical tensors into the one - Perform reduction function to get the results on each physical tensor - """ + def sync(self): + """Synchrnoize the spread physical tensors by reduction operation""" pass From 0f741f61530c44f487b1f1a46382e37eb4e0f177 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 14 Jul 2021 08:17:36 +0000 Subject: [PATCH 0023/1892] update logical tensor interface --- combo/tensor/logic/segment.py | 57 +++++++++++++++++++++++++++++++++++ combo/tensor/logic/tensor.py | 21 ++++++++++++- 2 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 combo/tensor/logic/segment.py diff --git a/combo/tensor/logic/segment.py b/combo/tensor/logic/segment.py new file mode 100644 index 00000000..97e94049 --- /dev/null +++ b/combo/tensor/logic/segment.py @@ -0,0 +1,57 @@ +""" +This is the interface for describing which set of data is needed for +gathering a community. +""" + +## Basic interface to cover all the cases + + +class DataSegment: + """ + The basic primitive to gather data in the logical tensor. + + """ + + def __init__(self, indices_list=None): + """ + Args: + indices_list (list[ list[int] ]): + List of index + """ + + self.indices = indices_list + + def convert_to_indices(self): + """ + Convert to index list + """ + pass + + +## Higher level interface to cover the most cases ## + +class TileSegment(DataSegment): + """ + A tile is a contigonous block on the logical tensor shape, + which can be represented as the start position + offset (shape) + """ + + def __init__(self, anchor, offset): + """ + Args: + anchor (list[int]): start position of the tile + offset (list[int]): offset (shape) of the tile + """ + if len(anchor) != len(offset): + raise ValueError("Require anchor length to be equal with offset length") + super().__init__() + self.anchor = anchor + self.offset = offset + + def convert_to_indices(self): + """ + Convert anchor and offset to index list + """ + pass + + diff --git a/combo/tensor/logic/tensor.py b/combo/tensor/logic/tensor.py index dc4fb8b4..a76c4d5f 100644 --- a/combo/tensor/logic/tensor.py +++ b/combo/tensor/logic/tensor.py @@ -1,8 +1,27 @@ +from combo.tensor.community import Community, ReductionOpPool -class LogicTensor: +class LogicalTensor: def __init__(self, ): self.communities = list() + def create_community(self, segment): + """Create a community by given the segment""" + self.communities.append(segment) + + def fuse(self, communities=None, reduction=ReductionOpPool.Replica): + """Fuse multiple communities into one + + Synchronization will done for each community to retrieve the right + result. + + Args: + communities (list[Community]): + The particular comunities to merge. + If not specified (None), fuse all the communities. + reduction: + Reduction operator for the new fused community. + """" + pass From c2908e9e1d6e0881bc59dc4bb70d2a74e1fb7835 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 14 Jul 2021 08:19:05 +0000 Subject: [PATCH 0024/1892] move community outside as it is a connection to physical tensor and logical tensor --- combo/tensor/{logic => }/community.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) rename combo/tensor/{logic => }/community.py (92%) diff --git a/combo/tensor/logic/community.py b/combo/tensor/community.py similarity index 92% rename from combo/tensor/logic/community.py rename to combo/tensor/community.py index 83bad95e..c449c290 100644 --- a/combo/tensor/logic/community.py +++ b/combo/tensor/community.py @@ -33,12 +33,14 @@ class ReductionOpPool(metaclass=_Reduction): class Community: - def __init__(self, logical_tensor, reduction=None): + def __init__(self, logical_tensor, segment, reduction=None): """Create Community based on the logical tensor Attribute: parent (LogicalTensor): Logical Tensor the Community belongs to + segment (DataSegment): + indices of logical_tensor for this community reduction (Callable or None): Reduction function for retrieve back physical tensors From b15b6f2517a9e4d89f372bb233d654c1c296e775 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 14 Jul 2021 08:31:59 +0000 Subject: [PATCH 0025/1892] add communiy with connection of logical / physical tensor --- combo/tensor/community.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/combo/tensor/community.py b/combo/tensor/community.py index c449c290..dd1724f5 100644 --- a/combo/tensor/community.py +++ b/combo/tensor/community.py @@ -45,7 +45,13 @@ def __init__(self, logical_tensor, segment, reduction=None): Reduction function for retrieve back physical tensors """ + # connection to logical tensor self.parent = logical_tensor + self.segment = segment + + # connection to physical tensor + self.mapping = None + self.reduction = reduction def spread(self, device_list): From cd94413774a5cd1c2bd06854dcc2ae98b951f908 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 15 Jul 2021 11:46:08 +0000 Subject: [PATCH 0026/1892] community interface design --- combo/tensor/community.py | 25 ++++++++++++++++++++++--- combo/tensor/logic/segment.py | 16 +++++++++++++++- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/combo/tensor/community.py b/combo/tensor/community.py index dd1724f5..9a1f504c 100644 --- a/combo/tensor/community.py +++ b/combo/tensor/community.py @@ -36,6 +36,11 @@ class Community: def __init__(self, logical_tensor, segment, reduction=None): """Create Community based on the logical tensor + Community manages one: + + 1). Logical Tensor data mapping to Physical Tensor data storage + 2). Materialized Physical Tensors + Attribute: parent (LogicalTensor): Logical Tensor the Community belongs to @@ -47,17 +52,20 @@ def __init__(self, logical_tensor, segment, reduction=None): """ # connection to logical tensor self.parent = logical_tensor + + # DataSegment to indicate both element set and data format mapping self.segment = segment - # connection to physical tensor - self.mapping = None + # connection to physical tensor (the PyTorch Tensor) + self.phsyical_tensor = None + self.materialized = False self.reduction = reduction def spread(self, device_list): """Spread physical tensors to devices - Create physical tensors for this community and spread out + Materialize physical tensors for this community and spread out based on the given device list. This offers policy module an interface to decide which devices @@ -83,3 +91,14 @@ def spread(self, device_list): def sync(self): """Synchrnoize the spread physical tensors by reduction operation""" pass + + def get_physical_tensor(self): + """Get physical tensor if materialized + + Returns: + PhysicalTensor (if materialized) + """" + if self.materialized: + return self.physical_tensor + else: + raise RuntimeError("The Community has not been materialized to physical tensors") diff --git a/combo/tensor/logic/segment.py b/combo/tensor/logic/segment.py index 97e94049..d54ebf34 100644 --- a/combo/tensor/logic/segment.py +++ b/combo/tensor/logic/segment.py @@ -10,12 +10,13 @@ class DataSegment: """ The basic primitive to gather data in the logical tensor. + The order of indices indicate the physical storage (1-D array) order """ def __init__(self, indices_list=None): """ Args: - indices_list (list[ list[int] ]): + indices_list (list[ list[int], ]): List of index """ @@ -27,6 +28,16 @@ def convert_to_indices(self): """ pass + def reorder(self, new_orders): + """ + Reorder the indices. + + Note this can be only called before materialize physical tensors, + or called from underlying operation that will change physical storage format + """ + #TODO: validation check + self.indices = new_orders + ## Higher level interface to cover the most cases ## @@ -53,5 +64,8 @@ def convert_to_indices(self): Convert anchor and offset to index list """ pass + + def reorder(self): + pass From e5790d0ab6e791a468c90fda3a8052d9c2b982e4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 15 Jul 2021 13:09:53 +0000 Subject: [PATCH 0027/1892] update holistic op interface --- combo/operator/holist/generics.py | 32 +++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/combo/operator/holist/generics.py b/combo/operator/holist/generics.py index ddab26e8..09d176f7 100644 --- a/combo/operator/holist/generics.py +++ b/combo/operator/holist/generics.py @@ -10,25 +10,41 @@ class GenericHolisticOp: - def __init__(self, input_layout, output_layout): + def __init__(self, + input_layout, output_layout, + input_format=None, output_format=None): + """ + Layout is the community distribution requirement for input and + output logical tensors. - # holistic layout of input to work on + Format is the dimension ordering based on the logical format, + `None` indicates the format is consistent with logical op, + otherwise should be a list of integers like torch.Tensor.permute() + on the logical required format. + """ + # holistic layout of input self.input_layout = dict() + self.input_format = input_format # holistic layout of output self.output_layout = dict() + self.output_format = output_format - def boundary_in(self, args, **kwargs): + def input_adapter(self, args, **kwargs): """ Transform tensors in args and kwargs to match the input layout requirement """ + # step 1: data reformat based on the input argument + + # step 2: physical tensor placement (policy) + + # step 3: community matching pass - def warp_to_holistic_tensor(self, outputs): + def output_adapter(self, outputs): """ - Wrap local computed tensor into a holistic view - by using self.output_layout + Data reformat to logical op format """ pass @@ -40,12 +56,12 @@ def forward(self, args, **kwargs): def __call__(self, args, **kwargs): # data transformations to match input layout requirement - self.boundary_in(args, kwargs) + self.input_adapter(args, kwargs) # do execution outputs = self.forward(args, kwargs) # wrap in holistic tensor with output layout - outputs = self.warp_to_holistic_tensor(outputs) + outputs = self.output_adapter(outputs) return outputs From f9f4419c5272bba402b7b63ccc88e05176bee40f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 15 Jul 2021 13:22:34 +0000 Subject: [PATCH 0028/1892] add example code for linear op: It's really hard for dynamic tensor shape --- combo/operator/holist/linear.py | 76 +++++++++++++++++++++++++++++++++ combo/operator/logic/linear.py | 16 +++++++ 2 files changed, 92 insertions(+) create mode 100644 combo/operator/holist/linear.py create mode 100644 combo/operator/logic/linear.py diff --git a/combo/operator/holist/linear.py b/combo/operator/holist/linear.py new file mode 100644 index 00000000..30123d35 --- /dev/null +++ b/combo/operator/holist/linear.py @@ -0,0 +1,76 @@ +from combo.operator.holist.generics import GenericHolisticOp + +import combo.operator.physic as physic_op + +from combo.tensor.logic.tensor import LogicalTensor +from combo.tensor.logic.segment import TileSegment +from combo.tensor.community import Community + +# expert space to declare all kinds of holistic operators + +__all__ = ['kHolistLinearSets'] + +class LinearColumnWeight(GenericHolisticOp): + """ + Perform Y = XW + b -> Y = X[W1,W2] + [b1,b2] + Split W and b on the last dimension + """ + + def __init__(self): + + # TODO + inputs_layout = None + # TODO + weight_layout = None + # TODO + bias_layout = None + # TODO + output_layout = None + + super().__init__( + input_layout=(inputs_layout, weight_layout), + output_layout=(output_layout,) + ) + + def forward(self, inputs, weight, bias): + outputs = list() + # TODO: handle bias is None + for pw, pb in zip(weight, bias): + output = physic_op.linear(inputs, weight, bias) + outputs.append(outputs) + return outputs + + +class LinearColumnInputRowWeight(GenericHolisticOp): + """ + Perform + Y = XW + b + -> Y = [X1,X2] * [W1//W2] + b] + -> Y = X1W1 + X2W2 + b + Split X (inputs) in column major (last dim), + Split W (weights) in row major (first dim) + """ + + def __init__(self): + + # TODO + inputs_layout = None + # TODO + weight_layout = None + # TODO + bias_layout = None + # TODO + output_layout = None + + super().__init__( + input_layout=(inputs_layout, weight_layout), + output_layout=(output_layout,) + ) + + def forward(self, inputs, weight, bias): + #TODO: semantic errors on bias + output = physic_op.linear(inputs, weight, bias) + return [output] + + +kHolistLinearSets = [LinearColumnWeight(), LinearColumnInputRowWeight()] \ No newline at end of file diff --git a/combo/operator/logic/linear.py b/combo/operator/logic/linear.py new file mode 100644 index 00000000..1ce25184 --- /dev/null +++ b/combo/operator/logic/linear.py @@ -0,0 +1,16 @@ +from combo.operator.logic.generics import generics +from combo.operator.holist.linear import kHolistLinearSets + +__all__ = ['linear'] + +def Linear(generics.GenericLogicalOp): + + def __init__(self): + super().__init__(self) + + # register holistic operators + for holist_op in kHolistLinearSets: + self.factory.register(holist_op) + +# initialize op +linear = Linear() \ No newline at end of file From 6fb702f8a08f1dd9723ceaa268a60b8bf19a3a46 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Jul 2021 03:17:15 +0000 Subject: [PATCH 0029/1892] rename to cube --- combo/operator/physic/__init__.py | 1 - combo/operator/physic/comm/__init__.py | 1 - {combo => cube}/__init__.py | 0 {combo => cube}/device/__init__.py | 0 {combo => cube}/device/physic/__init__.py | 0 {combo => cube}/device/physic/group.py | 0 {combo => cube}/operator/__init__.py | 0 {combo => cube}/operator/holist/__init__.py | 0 {combo => cube}/operator/holist/generics.py | 0 {combo => cube}/operator/holist/linear.py | 10 +++++----- {combo => cube}/operator/logic/__init__.py | 0 {combo => cube}/operator/logic/generics.py | 0 {combo => cube}/operator/logic/linear.py | 4 ++-- cube/operator/physic/__init__.py | 1 + cube/operator/physic/comm/__init__.py | 1 + {combo => cube}/operator/physic/comm/boundary.py | 2 +- {combo => cube}/operator/physic/generics.py | 0 {combo => cube}/operator/physic/linear.py | 4 ++-- {combo => cube}/tensor/__init__.py | 0 {combo => cube}/tensor/community.py | 0 {combo => cube}/tensor/logic/__init__.py | 0 {combo => cube}/tensor/logic/segment.py | 0 {combo => cube}/tensor/logic/tensor.py | 2 +- {combo => cube}/tensor/physic/__init__.py | 0 {combo => cube}/tensor/physic/tensor.py | 0 setup.py | 8 ++++---- 26 files changed, 17 insertions(+), 17 deletions(-) delete mode 100644 combo/operator/physic/__init__.py delete mode 100644 combo/operator/physic/comm/__init__.py rename {combo => cube}/__init__.py (100%) rename {combo => cube}/device/__init__.py (100%) rename {combo => cube}/device/physic/__init__.py (100%) rename {combo => cube}/device/physic/group.py (100%) rename {combo => cube}/operator/__init__.py (100%) rename {combo => cube}/operator/holist/__init__.py (100%) rename {combo => cube}/operator/holist/generics.py (100%) rename {combo => cube}/operator/holist/linear.py (87%) rename {combo => cube}/operator/logic/__init__.py (100%) rename {combo => cube}/operator/logic/generics.py (100%) rename {combo => cube}/operator/logic/linear.py (72%) create mode 100644 cube/operator/physic/__init__.py create mode 100644 cube/operator/physic/comm/__init__.py rename {combo => cube}/operator/physic/comm/boundary.py (98%) rename {combo => cube}/operator/physic/generics.py (100%) rename {combo => cube}/operator/physic/linear.py (94%) rename {combo => cube}/tensor/__init__.py (100%) rename {combo => cube}/tensor/community.py (100%) rename {combo => cube}/tensor/logic/__init__.py (100%) rename {combo => cube}/tensor/logic/segment.py (100%) rename {combo => cube}/tensor/logic/tensor.py (92%) rename {combo => cube}/tensor/physic/__init__.py (100%) rename {combo => cube}/tensor/physic/tensor.py (100%) diff --git a/combo/operator/physic/__init__.py b/combo/operator/physic/__init__.py deleted file mode 100644 index 0151a3e1..00000000 --- a/combo/operator/physic/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from combo.physical.operator.linear import linear_op \ No newline at end of file diff --git a/combo/operator/physic/comm/__init__.py b/combo/operator/physic/comm/__init__.py deleted file mode 100644 index 7390d802..00000000 --- a/combo/operator/physic/comm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from combo.physical.operator.comm.boundary import * \ No newline at end of file diff --git a/combo/__init__.py b/cube/__init__.py similarity index 100% rename from combo/__init__.py rename to cube/__init__.py diff --git a/combo/device/__init__.py b/cube/device/__init__.py similarity index 100% rename from combo/device/__init__.py rename to cube/device/__init__.py diff --git a/combo/device/physic/__init__.py b/cube/device/physic/__init__.py similarity index 100% rename from combo/device/physic/__init__.py rename to cube/device/physic/__init__.py diff --git a/combo/device/physic/group.py b/cube/device/physic/group.py similarity index 100% rename from combo/device/physic/group.py rename to cube/device/physic/group.py diff --git a/combo/operator/__init__.py b/cube/operator/__init__.py similarity index 100% rename from combo/operator/__init__.py rename to cube/operator/__init__.py diff --git a/combo/operator/holist/__init__.py b/cube/operator/holist/__init__.py similarity index 100% rename from combo/operator/holist/__init__.py rename to cube/operator/holist/__init__.py diff --git a/combo/operator/holist/generics.py b/cube/operator/holist/generics.py similarity index 100% rename from combo/operator/holist/generics.py rename to cube/operator/holist/generics.py diff --git a/combo/operator/holist/linear.py b/cube/operator/holist/linear.py similarity index 87% rename from combo/operator/holist/linear.py rename to cube/operator/holist/linear.py index 30123d35..8c58807a 100644 --- a/combo/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -1,10 +1,10 @@ -from combo.operator.holist.generics import GenericHolisticOp +from cube.operator.holist.generics import GenericHolisticOp -import combo.operator.physic as physic_op +import cube.operator.physic as physic_op -from combo.tensor.logic.tensor import LogicalTensor -from combo.tensor.logic.segment import TileSegment -from combo.tensor.community import Community +from cube.tensor.logic.tensor import LogicalTensor +from cube.tensor.logic.segment import TileSegment +from cube.tensor.community import Community # expert space to declare all kinds of holistic operators diff --git a/combo/operator/logic/__init__.py b/cube/operator/logic/__init__.py similarity index 100% rename from combo/operator/logic/__init__.py rename to cube/operator/logic/__init__.py diff --git a/combo/operator/logic/generics.py b/cube/operator/logic/generics.py similarity index 100% rename from combo/operator/logic/generics.py rename to cube/operator/logic/generics.py diff --git a/combo/operator/logic/linear.py b/cube/operator/logic/linear.py similarity index 72% rename from combo/operator/logic/linear.py rename to cube/operator/logic/linear.py index 1ce25184..c5f40bf3 100644 --- a/combo/operator/logic/linear.py +++ b/cube/operator/logic/linear.py @@ -1,5 +1,5 @@ -from combo.operator.logic.generics import generics -from combo.operator.holist.linear import kHolistLinearSets +from cube.operator.logic.generics import generics +from cube.operator.holist.linear import kHolistLinearSets __all__ = ['linear'] diff --git a/cube/operator/physic/__init__.py b/cube/operator/physic/__init__.py new file mode 100644 index 00000000..cb37c2a2 --- /dev/null +++ b/cube/operator/physic/__init__.py @@ -0,0 +1 @@ +from cube.physical.operator.linear import linear_op \ No newline at end of file diff --git a/cube/operator/physic/comm/__init__.py b/cube/operator/physic/comm/__init__.py new file mode 100644 index 00000000..962cde7b --- /dev/null +++ b/cube/operator/physic/comm/__init__.py @@ -0,0 +1 @@ +from cube.physical.operator.comm.boundary import * \ No newline at end of file diff --git a/combo/operator/physic/comm/boundary.py b/cube/operator/physic/comm/boundary.py similarity index 98% rename from combo/operator/physic/comm/boundary.py rename to cube/operator/physic/comm/boundary.py index be5f9465..54960a66 100644 --- a/combo/operator/physic/comm/boundary.py +++ b/cube/operator/physic/comm/boundary.py @@ -5,7 +5,7 @@ import torch -from combo.physical.device.group import DeviceGroup +from cube.physical.device.group import DeviceGroup __all__ = ['parallel_in', 'gather_out', 'scatter_in', 'reduce_out'] diff --git a/combo/operator/physic/generics.py b/cube/operator/physic/generics.py similarity index 100% rename from combo/operator/physic/generics.py rename to cube/operator/physic/generics.py diff --git a/combo/operator/physic/linear.py b/cube/operator/physic/linear.py similarity index 94% rename from combo/operator/physic/linear.py rename to cube/operator/physic/linear.py index 35eda888..be8d925e 100644 --- a/combo/operator/physic/linear.py +++ b/cube/operator/physic/linear.py @@ -3,8 +3,8 @@ from torch import Tensor from torch.overrides import has_torch_function_variadic, handle_torch_function -import combo.physic.operator.comm as comm -from combo.physic.device.group import DeviceGroup +import cube.physic.operator.comm as comm +from cube.physic.device.group import DeviceGroup def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: diff --git a/combo/tensor/__init__.py b/cube/tensor/__init__.py similarity index 100% rename from combo/tensor/__init__.py rename to cube/tensor/__init__.py diff --git a/combo/tensor/community.py b/cube/tensor/community.py similarity index 100% rename from combo/tensor/community.py rename to cube/tensor/community.py diff --git a/combo/tensor/logic/__init__.py b/cube/tensor/logic/__init__.py similarity index 100% rename from combo/tensor/logic/__init__.py rename to cube/tensor/logic/__init__.py diff --git a/combo/tensor/logic/segment.py b/cube/tensor/logic/segment.py similarity index 100% rename from combo/tensor/logic/segment.py rename to cube/tensor/logic/segment.py diff --git a/combo/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py similarity index 92% rename from combo/tensor/logic/tensor.py rename to cube/tensor/logic/tensor.py index a76c4d5f..3c185c0c 100644 --- a/combo/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -1,4 +1,4 @@ -from combo.tensor.community import Community, ReductionOpPool +from cube.tensor.community import Community, ReductionOpPool class LogicalTensor: diff --git a/combo/tensor/physic/__init__.py b/cube/tensor/physic/__init__.py similarity index 100% rename from combo/tensor/physic/__init__.py rename to cube/tensor/physic/__init__.py diff --git a/combo/tensor/physic/tensor.py b/cube/tensor/physic/tensor.py similarity index 100% rename from combo/tensor/physic/tensor.py rename to cube/tensor/physic/tensor.py diff --git a/setup.py b/setup.py index c24c4bd9..41c696a8 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,12 @@ import setuptools setuptools.setup( - name= 'combo', + name= 'cube', version= '0.1', author= 'Zhiqi Lin', author_email= 'v-zhiql@microsoft.com', - description= 'Combo', - long_description= 'Combo', - packages= ['combo'], + description= 'Magic Cube for configurable-DNN framework', + long_description= 'Magic Cube for configurable-DNN framework', + packages= ['cube'], python_requires= '>=3.6', ) From 96c854db9c1f02bfd9c8a4bdc8310f756f62c5bb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Jul 2021 03:18:41 +0000 Subject: [PATCH 0030/1892] add example primitives for splitting axis --- cube/tensor/logic/segment.py | 45 ++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/cube/tensor/logic/segment.py b/cube/tensor/logic/segment.py index d54ebf34..2081de15 100644 --- a/cube/tensor/logic/segment.py +++ b/cube/tensor/logic/segment.py @@ -3,9 +3,8 @@ gathering a community. """ -## Basic interface to cover all the cases - +## Basic interface to cover all the cases class DataSegment: """ The basic primitive to gather data in the logical tensor. @@ -40,7 +39,6 @@ def reorder(self, new_orders): ## Higher level interface to cover the most cases ## - class TileSegment(DataSegment): """ A tile is a contigonous block on the logical tensor shape, @@ -67,5 +65,44 @@ def convert_to_indices(self): def reorder(self): pass - + +# primitives to describe segmentation pattern + +class SplitAxis: + + def __init__(self, axis, chunk_num=None, chunk_size=None, overlap=0): + """ + Segmentation Pattern Requirement (parameters): + + axis (int): the axis to split + + chunk_num (None, int, tuple(int, int)): + valid chunk numbers to split. + If None, then any chunk number is valid; + If an integer, only the specified chunk number is valid; + If a tuple(min, max), the chunk number wihtin the scope [min,max] is valid + + chunk_size (None, int, tuple(int, int)): + valid chunk size. + If None, any size is valid; + If an integer, each chunk size is valid; + if a tuple(min, max), the chunk size wihtin the scope [min,max] is valid + + overlap (0, int, tuple(int, int)): + valid size for overlaping on the boundary of each splitted chunks. + If None, any overlapping is valid + If an integer, each overlap size is valid; + if a tuple(min, max), the overlap size wihtin the scope [min,max] is valid + + """ + self.axis = axis + self.chunk_num = chunk_num + self.chunk_size = chunk_size + self.overlap = overlap + + def __call__(self, shape): + """ + Runtime community generation given the logical tensor shape + """ + pass From 8b22938da2d5fbb4d0cb52219501e5cd4e581ed7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Jul 2021 04:49:11 +0000 Subject: [PATCH 0031/1892] primitive + runtime interface --- cube/tensor/logic/segment/__init__.py | 0 cube/tensor/logic/segment/primitive.py | 47 +++++++++++++++++++ .../logic/{segment.py => segment/runtime.py} | 41 ---------------- 3 files changed, 47 insertions(+), 41 deletions(-) create mode 100644 cube/tensor/logic/segment/__init__.py create mode 100644 cube/tensor/logic/segment/primitive.py rename cube/tensor/logic/{segment.py => segment/runtime.py} (55%) diff --git a/cube/tensor/logic/segment/__init__.py b/cube/tensor/logic/segment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/tensor/logic/segment/primitive.py b/cube/tensor/logic/segment/primitive.py new file mode 100644 index 00000000..ac38d3a3 --- /dev/null +++ b/cube/tensor/logic/segment/primitive.py @@ -0,0 +1,47 @@ +""" +This is the primitive to describe the layout requirement. + +The primitive translates the requirement before runtime to the +actual community groups during the runtime +""" + + +# primitives to describe segmentation pattern + +class SplitAxis: + + def __init__(self, axis, chunk_num=None, chunk_size=None, overlap=0): + """ + Segmentation Pattern Requirement (parameters): + + axis (int): the axis to split + + chunk_num (None, int, tuple(int, int)): + valid chunk numbers to split. + If None, then any chunk number is valid; + If an integer, only the specified chunk number is valid; + If a tuple(min, max), the chunk number wihtin the scope [min,max] is valid + + chunk_size (None, int, tuple(int, int)): + valid chunk size. + If None, any size is valid; + If an integer, each chunk size is valid; + if a tuple(min, max), the chunk size wihtin the scope [min,max] is valid + + overlap (0, int, tuple(int, int)): + valid size for overlaping on the boundary of each splitted chunks. + If None, any overlapping is valid + If an integer, each overlap size is valid; + if a tuple(min, max), the overlap size wihtin the scope [min,max] is valid + + """ + self.axis = axis + self.chunk_num = chunk_num + self.chunk_size = chunk_size + self.overlap = overlap + + def __call__(self, shape): + """ + Runtime community generation given the logical tensor shape + """ + pass diff --git a/cube/tensor/logic/segment.py b/cube/tensor/logic/segment/runtime.py similarity index 55% rename from cube/tensor/logic/segment.py rename to cube/tensor/logic/segment/runtime.py index 2081de15..ef85e867 100644 --- a/cube/tensor/logic/segment.py +++ b/cube/tensor/logic/segment/runtime.py @@ -65,44 +65,3 @@ def convert_to_indices(self): def reorder(self): pass - - -# primitives to describe segmentation pattern - -class SplitAxis: - - def __init__(self, axis, chunk_num=None, chunk_size=None, overlap=0): - """ - Segmentation Pattern Requirement (parameters): - - axis (int): the axis to split - - chunk_num (None, int, tuple(int, int)): - valid chunk numbers to split. - If None, then any chunk number is valid; - If an integer, only the specified chunk number is valid; - If a tuple(min, max), the chunk number wihtin the scope [min,max] is valid - - chunk_size (None, int, tuple(int, int)): - valid chunk size. - If None, any size is valid; - If an integer, each chunk size is valid; - if a tuple(min, max), the chunk size wihtin the scope [min,max] is valid - - overlap (0, int, tuple(int, int)): - valid size for overlaping on the boundary of each splitted chunks. - If None, any overlapping is valid - If an integer, each overlap size is valid; - if a tuple(min, max), the overlap size wihtin the scope [min,max] is valid - - """ - self.axis = axis - self.chunk_num = chunk_num - self.chunk_size = chunk_size - self.overlap = overlap - - def __call__(self, shape): - """ - Runtime community generation given the logical tensor shape - """ - pass From 552b1077bf759dea8ef080564042ef548bdc4e30 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Jul 2021 06:50:26 +0000 Subject: [PATCH 0032/1892] two stage for segmentation --- cube/tensor/logic/segment/outline.py | 52 +++++++++++++++++ cube/tensor/logic/segment/primitive.py | 81 ++++++++++++++++---------- cube/tensor/logic/segment/runtime.py | 67 --------------------- 3 files changed, 102 insertions(+), 98 deletions(-) create mode 100644 cube/tensor/logic/segment/outline.py delete mode 100644 cube/tensor/logic/segment/runtime.py diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py new file mode 100644 index 00000000..c1bdae56 --- /dev/null +++ b/cube/tensor/logic/segment/outline.py @@ -0,0 +1,52 @@ +""" +This is the description interface to describe the +segementation requirement (restrictions). + +The description includes two parts: + + 1). restriction description on tensor segementation + + 2). Translation procedure in runtime to translate such a restriction + to the real segmentation on given logical tensor shape. +""" + + +# primitives to describe segmentation pattern + +class SplitAxis: + + def __init__(self, axis, chunk_num=None, chunk_size=None, overlap=0): + """ + Segmentation Pattern Requirement (parameters): + + axis (int): the axis to split + + chunk_num (None, int, tuple(int, int)): + valid chunk numbers to split. + If None, then any chunk number is valid; + If an integer, only the specified chunk number is valid; + If a tuple(min, max), the chunk number wihtin the scope [min,max] is valid + + chunk_size (None, int, tuple(int, int)): + valid chunk size. + If None, any size is valid; + If an integer, each chunk size is valid; + if a tuple(min, max), the chunk size wihtin the scope [min,max] is valid + + overlap (0, int, tuple(int, int)): + valid size for overlaping on the boundary of each splitted chunks. + If None, any overlapping is valid + If an integer, each overlap size is valid; + if a tuple(min, max), the overlap size wihtin the scope [min,max] is valid + + """ + self.axis = axis + self.chunk_num = chunk_num + self.chunk_size = chunk_size + self.overlap = overlap + + def __call__(self, shape): + """ + Runtime community generation given the logical tensor shape + """ + pass diff --git a/cube/tensor/logic/segment/primitive.py b/cube/tensor/logic/segment/primitive.py index ac38d3a3..c17bb86a 100644 --- a/cube/tensor/logic/segment/primitive.py +++ b/cube/tensor/logic/segment/primitive.py @@ -1,47 +1,66 @@ """ -This is the primitive to describe the layout requirement. - -The primitive translates the requirement before runtime to the -actual community groups during the runtime +This is the runtime primitive sets to setup community for a logical tensor. """ -# primitives to describe segmentation pattern +## Basic interface to cover all the cases +class DataSegment: + """ + The basic primitive to gather data in the logical tensor. -class SplitAxis: + The order of indices indicate the physical storage (1-D array) order + """ - def __init__(self, axis, chunk_num=None, chunk_size=None, overlap=0): + def __init__(self, indices_list=None): + """ + Args: + indices_list (list[ list[int], ]): + List of index """ - Segmentation Pattern Requirement (parameters): - axis (int): the axis to split + self.indices = indices_list - chunk_num (None, int, tuple(int, int)): - valid chunk numbers to split. - If None, then any chunk number is valid; - If an integer, only the specified chunk number is valid; - If a tuple(min, max), the chunk number wihtin the scope [min,max] is valid + def convert_to_indices(self): + """ + Convert to index list + """ + pass - chunk_size (None, int, tuple(int, int)): - valid chunk size. - If None, any size is valid; - If an integer, each chunk size is valid; - if a tuple(min, max), the chunk size wihtin the scope [min,max] is valid - - overlap (0, int, tuple(int, int)): - valid size for overlaping on the boundary of each splitted chunks. - If None, any overlapping is valid - If an integer, each overlap size is valid; - if a tuple(min, max), the overlap size wihtin the scope [min,max] is valid + def reorder(self, new_orders): + """ + Reorder the indices. + Note this can be only called before materialize physical tensors, + or called from underlying operation that will change physical storage format """ - self.axis = axis - self.chunk_num = chunk_num - self.chunk_size = chunk_size - self.overlap = overlap + #TODO: validation check + self.indices = new_orders + + +## Higher level interface to cover the most cases ## +class TileSegment(DataSegment): + """ + A tile is a contigonous block on the logical tensor shape, + which can be represented as the start position + offset (shape) + """ + + def __init__(self, anchor, offset): + """ + Args: + anchor (list[int]): start position of the tile + offset (list[int]): offset (shape) of the tile + """ + if len(anchor) != len(offset): + raise ValueError("Require anchor length to be equal with offset length") + super().__init__() + self.anchor = anchor + self.offset = offset - def __call__(self, shape): + def convert_to_indices(self): """ - Runtime community generation given the logical tensor shape + Convert anchor and offset to index list """ pass + + def reorder(self): + pass diff --git a/cube/tensor/logic/segment/runtime.py b/cube/tensor/logic/segment/runtime.py deleted file mode 100644 index ef85e867..00000000 --- a/cube/tensor/logic/segment/runtime.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -This is the interface for describing which set of data is needed for -gathering a community. -""" - - -## Basic interface to cover all the cases -class DataSegment: - """ - The basic primitive to gather data in the logical tensor. - - The order of indices indicate the physical storage (1-D array) order - """ - - def __init__(self, indices_list=None): - """ - Args: - indices_list (list[ list[int], ]): - List of index - """ - - self.indices = indices_list - - def convert_to_indices(self): - """ - Convert to index list - """ - pass - - def reorder(self, new_orders): - """ - Reorder the indices. - - Note this can be only called before materialize physical tensors, - or called from underlying operation that will change physical storage format - """ - #TODO: validation check - self.indices = new_orders - - -## Higher level interface to cover the most cases ## -class TileSegment(DataSegment): - """ - A tile is a contigonous block on the logical tensor shape, - which can be represented as the start position + offset (shape) - """ - - def __init__(self, anchor, offset): - """ - Args: - anchor (list[int]): start position of the tile - offset (list[int]): offset (shape) of the tile - """ - if len(anchor) != len(offset): - raise ValueError("Require anchor length to be equal with offset length") - super().__init__() - self.anchor = anchor - self.offset = offset - - def convert_to_indices(self): - """ - Convert anchor and offset to index list - """ - pass - - def reorder(self): - pass From 42d5deb5e1606af7a5521d465f019ed214a86931 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Jul 2021 07:27:30 +0000 Subject: [PATCH 0033/1892] add basic primitive for runtime segment generation --- cube/tensor/logic/segment/outline.py | 2 +- cube/tensor/logic/segment/primitive.py | 22 ++++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py index c1bdae56..f7a5aa09 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/segment/outline.py @@ -11,7 +11,7 @@ """ -# primitives to describe segmentation pattern +# interface to setup restrictions on the segmentation class SplitAxis: diff --git a/cube/tensor/logic/segment/primitive.py b/cube/tensor/logic/segment/primitive.py index c17bb86a..dbbd1547 100644 --- a/cube/tensor/logic/segment/primitive.py +++ b/cube/tensor/logic/segment/primitive.py @@ -3,7 +3,7 @@ """ -## Basic interface to cover all the cases +## Basic structure for holding a segment -> cover all the cases ## class DataSegment: """ The basic primitive to gather data in the logical tensor. @@ -37,7 +37,7 @@ def reorder(self, new_orders): self.indices = new_orders -## Higher level interface to cover the most cases ## +## Higher structure to cover the most cases ## class TileSegment(DataSegment): """ A tile is a contigonous block on the logical tensor shape, @@ -64,3 +64,21 @@ def convert_to_indices(self): def reorder(self): pass + + +## Primitive sets for translation ## + +def create_from_indices(indices): + return DataSegment(indices) + + +def create_from_tiles(anchor, offset): + # segments = list() + # dims = len(offset) + # for dim_id in range(dims): + # indices = None # -> TODO: generate indices along the dim_id + # segment = create_from_indices(indices) + # segments.append(segment) + # segment = merge_segments(segments) + # return segment + return TileSegment(anchor, offset) \ No newline at end of file From 63cc45790d53e9abc7f029ca6b0b417a00d14df9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Jul 2021 13:34:42 +0000 Subject: [PATCH 0034/1892] case study for all kinds of optimizations --- examples/case_study/config_linear.py | 123 +++++++++++++++++++++++++++ examples/case_study/naive_linear.py | 32 +++++++ 2 files changed, 155 insertions(+) create mode 100644 examples/case_study/config_linear.py create mode 100644 examples/case_study/naive_linear.py diff --git a/examples/case_study/config_linear.py b/examples/case_study/config_linear.py new file mode 100644 index 00000000..2a583a3a --- /dev/null +++ b/examples/case_study/config_linear.py @@ -0,0 +1,123 @@ +"""Example Usage + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=6000 \ + --use_env \ + examples/case_study/config_linear.py +""" + +import torch +from torch.nn.parameter import Parameter +torch.manual_seed(121) + +# tensor parallel - split weight in column +def linear_tensor_parallel(input, weight, bias): + ### Policy need to know ### + devices = [0, 1] # how many device to perform? + + ### Necessary information to know ### + rank = torch.distributed.get_rank() # which role I participate? + + ### Additional ops need to use ### + class InputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return input_ + @staticmethod + def backward(ctx, grad_output): + return torch.distributed.all_reduce(grad_output) + + class OutputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_) + output = torch.cat(tensor_list, dim=-1) + return output + @staticmethod + def backward(ctx, grad_output): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + tensor_list = torch.split( + grad_output, grad_output.size()[-1]//world_size, dim=-1 + ) + return tensor_list[rank].contiguous() + + ### Input Adapter ### + weight = torch.chunk(weight, chunks=len(devices), dim=0)[rank].contiguous() + bias = torch.chunk(bias, chunks=len(devices), dim=0)[rank].contiguous() + input = InputAdapter.apply(input) + + ### Forward ### + output = torch._C._nn.linear(input, weight, bias) + + ### Ouput Adapter ### + # insert a forward + backward op at last (allgather - split) + output = OutputAdapter.apply(output) + return output + + +# data parallel +def linear_data_parallel(input, weight, bias): + ### Additional ops need to use ### + + ### Input Adapter ### + weight.register_hook(lambda grad: torch.distributed.allreduce(grad)) + bias.register_hook(lambda grad: torch.distributed.allreduce(grad)) + + ### Forward ### + output = torch._C._nn.linear(input, weight, bias) + + ### Output Adapter ### -> no need + return output + + + + +######### Utility ############# +def print_each_rank(msg, selected_rank=None): + myrank = torch.distributed.get_rank() + for rank in range(torch.distributed.get_world_size()): + if selected_rank is None or myrank in selected_rank: + if myrank == rank: + print('rank [{}]: {}\n'.format(rank, msg)) + torch.distributed.barrier() + + +if __name__ == '__main__': + + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + ) + torch.cuda.set_device(torch.distributed.get_rank()) + + # tensor definition + batch_size = 32 + out_features = 1024 + in_features = 1024 + weight = Parameter(torch.rand((out_features, in_features))).cuda() + # print_each_rank('weight: {}'.format(weight)) + bias = Parameter(torch.rand(out_features)).cuda() + # print_each_rank('bias: {}'.format(bias)) + input = torch.rand((batch_size, in_features)).cuda() + # print_each_rank('input: {}'.format(input)) + + # model parallel + print_each_rank('======== Model Parallel =========', [0]) + output = linear_tensor_parallel(input, weight, bias) + loss = torch.mean(output) + print_each_rank(loss) + loss.backward() + print_each_rank('======== Model Parallel =========', [0]) + + # data parallel + print_each_rank('======== Data Parallel =========', [0]) + print_each_rank('======== Data Parallel =========', [0]) diff --git a/examples/case_study/naive_linear.py b/examples/case_study/naive_linear.py new file mode 100644 index 00000000..5c7fda48 --- /dev/null +++ b/examples/case_study/naive_linear.py @@ -0,0 +1,32 @@ +import torch +from torch.nn.parameter import Parameter +torch.manual_seed(121) + + +def linear(input, weight, bias=None): + output = torch._C._nn.linear(input, weight, bias) + return output + + +if __name__ == '__main__': + + torch.cuda.set_device(0) + + # tensor definition + batch_size = 32 + out_features = 1024 + in_features = 1024 + weight = Parameter(torch.rand((out_features, in_features))).cuda() + # print('weight: ', weight) + bias = Parameter(torch.rand(out_features)).cuda() + # print('bias: ', bias) + input = torch.rand((batch_size, in_features)).cuda() + # print('input: ', input) + + # op compute + print('======== Naive Single Device =======') + output = linear(input, weight, bias) + loss = torch.mean(output) + print(loss) + loss.backward() + print('======== Naive Single Device =======') From aa390c500d03c69f483a12c556cf2f251b2463a0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Jul 2021 14:10:33 +0000 Subject: [PATCH 0035/1892] fix grad error --- examples/case_study/config_linear.py | 6 ++++-- examples/case_study/naive_linear.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/case_study/config_linear.py b/examples/case_study/config_linear.py index 2a583a3a..f8366d67 100644 --- a/examples/case_study/config_linear.py +++ b/examples/case_study/config_linear.py @@ -103,9 +103,9 @@ def print_each_rank(msg, selected_rank=None): batch_size = 32 out_features = 1024 in_features = 1024 - weight = Parameter(torch.rand((out_features, in_features))).cuda() + weight = torch.rand((out_features, in_features)).cuda().requires_grad_() # print_each_rank('weight: {}'.format(weight)) - bias = Parameter(torch.rand(out_features)).cuda() + bias = torch.rand(out_features).cuda().requires_grad_() # print_each_rank('bias: {}'.format(bias)) input = torch.rand((batch_size, in_features)).cuda() # print_each_rank('input: {}'.format(input)) @@ -116,6 +116,8 @@ def print_each_rank(msg, selected_rank=None): loss = torch.mean(output) print_each_rank(loss) loss.backward() + # note weight is created as transposed + print_each_rank('weight grad: {}'.format(weight.grad.t())) print_each_rank('======== Model Parallel =========', [0]) # data parallel diff --git a/examples/case_study/naive_linear.py b/examples/case_study/naive_linear.py index 5c7fda48..bbdead40 100644 --- a/examples/case_study/naive_linear.py +++ b/examples/case_study/naive_linear.py @@ -16,9 +16,9 @@ def linear(input, weight, bias=None): batch_size = 32 out_features = 1024 in_features = 1024 - weight = Parameter(torch.rand((out_features, in_features))).cuda() + weight = torch.rand((out_features, in_features)).cuda().requires_grad_() # print('weight: ', weight) - bias = Parameter(torch.rand(out_features)).cuda() + bias = torch.rand(out_features).cuda().requires_grad_() # print('bias: ', bias) input = torch.rand((batch_size, in_features)).cuda() # print('input: ', input) @@ -29,4 +29,5 @@ def linear(input, weight, bias=None): loss = torch.mean(output) print(loss) loss.backward() + print('weight grad: ', weight.grad.t()) print('======== Naive Single Device =======') From c730c8208ee847926b0ca89c7cf4f3ecd43a8bde Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Jul 2021 14:13:41 +0000 Subject: [PATCH 0036/1892] add data parallel example --- examples/case_study/config_linear.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/case_study/config_linear.py b/examples/case_study/config_linear.py index f8366d67..3959ecfd 100644 --- a/examples/case_study/config_linear.py +++ b/examples/case_study/config_linear.py @@ -69,8 +69,8 @@ def linear_data_parallel(input, weight, bias): ### Additional ops need to use ### ### Input Adapter ### - weight.register_hook(lambda grad: torch.distributed.allreduce(grad)) - bias.register_hook(lambda grad: torch.distributed.allreduce(grad)) + weight.register_hook(lambda grad: torch.distributed.all_reduce(grad)) + bias.register_hook(lambda grad: torch.distributed.all_reduce(grad)) ### Forward ### output = torch._C._nn.linear(input, weight, bias) @@ -121,5 +121,11 @@ def print_each_rank(msg, selected_rank=None): print_each_rank('======== Model Parallel =========', [0]) # data parallel + weight.grad = None + bias.grad = None print_each_rank('======== Data Parallel =========', [0]) + output = linear_data_parallel(input, weight, bias) + loss = torch.mean(output) + loss.backward() + print_each_rank('weight grad: {}'.format(weight.grad.t())) print_each_rank('======== Data Parallel =========', [0]) From b179c15bb9c4bb3eefc327544deb3f69b8237a2d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 17 Jul 2021 09:48:06 +0000 Subject: [PATCH 0037/1892] strange hangs on hybrid parallelisms --- examples/case_study/config_linear.py | 100 +++++++++++++++++++++++++-- 1 file changed, 96 insertions(+), 4 deletions(-) diff --git a/examples/case_study/config_linear.py b/examples/case_study/config_linear.py index 3959ecfd..6fe9bdb8 100644 --- a/examples/case_study/config_linear.py +++ b/examples/case_study/config_linear.py @@ -1,23 +1,24 @@ """Example Usage python -m torch.distributed.launch \ - --nproc_per_node=2 \ + --nproc_per_node=4 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ - --master_port=6000 \ + --master_port=62000 \ --use_env \ examples/case_study/config_linear.py """ import torch +import os from torch.nn.parameter import Parameter torch.manual_seed(121) # tensor parallel - split weight in column def linear_tensor_parallel(input, weight, bias): ### Policy need to know ### - devices = [0, 1] # how many device to perform? + devices = [0, 1, 2, 3] # how many device to perform? ### Necessary information to know ### rank = torch.distributed.get_rank() # which role I participate? @@ -79,6 +80,84 @@ def linear_data_parallel(input, weight, bias): return output +# tensor + data parallel +def linear_hybrid_tensor_data_parallel(input, weight, bias): + ### Policy need to know ### + devices = [0, 1, 2, 3] # how many device to perform? + + ### Necessary information to execute ### + rank = torch.distributed.get_rank() # which role I participate? + if rank in [0,2]: + dp_group = torch.distributed.new_group([0,2]) + else: + dp_group = torch.distributed.new_group([1,3]) + dp_rank = torch.distributed.get_rank(group=dp_group) + + if rank in [0,1]: + tp_group = torch.distributed.new_group([0,1]) + else: + tp_group = torch.distributed.new_group([2,3]) + tp_rank = torch.distributed.get_rank(group=tp_group) + tp_world_size = torch.distributed.get_world_size(group=tp_group) + # print_each_rank('tp world size: {} tp rank: {}'.format(tp_world_size, tp_rank)) + + + ### Additional Ops ### + class InputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group): + ctx.constants = group + return input_ + @staticmethod + def backward(ctx, grad_output): + group = ctx.constants + return torch.distributed.all_reduce(grad_output, group=group), None + + class OutputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group, dim=-1): + world_size = torch.distributed.get_world_size(group=group) + rank = torch.distributed.get_rank(group=group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + output = torch.cat(tensor_list, dim=dim) + ctx.constants = (group, dim) + return output + @staticmethod + def backward(ctx, grad_output): + group, dim = ctx.constants + world_size = torch.distributed.get_world_size(group=group) + rank = torch.distributed.get_rank(group=group) + tensor_list = torch.split( + grad_output, grad_output.size()[-1]//world_size, dim=dim + ) + return tensor_list[rank].contiguous(), None, None + + ### Input Adapter - Slice ### + weight = torch.chunk(weight, chunks=tp_world_size, dim=0)[tp_rank].contiguous() + bias = torch.chunk(bias, chunks=tp_world_size, dim=0)[tp_rank].contiguous() + # replicate is implicitly done due to SPMD + + ### Input Adapter - Data Parallel ### + weight.register_hook(lambda grad: torch.distributed.all_reduce(grad, group=dp_group)) + bias.register_hook(lambda grad: torch.distributed.all_reduce(grad, group=dp_group)) + + torch.distributed.barrier() + ### Input Adapter - Tensor Parallel ### + input = InputAdapter.apply(input, tp_group) + + ### Forward ### + output = torch._C._nn.linear(input, weight, bias) + + ### Output Adapter - Tensor Parallel ### + output = OutputAdapter.apply(output, tp_group, -1) + + ### Ouput Adapter - Data Parallel ### + ## No need + + return output + ######### Utility ############# @@ -93,11 +172,12 @@ def print_each_rank(msg, selected_rank=None): if __name__ == '__main__': + local_rank = int(os.environ.get('LOCAL_RANK')) + torch.cuda.set_device(local_rank) torch.distributed.init_process_group( backend='nccl', init_method='env://', ) - torch.cuda.set_device(torch.distributed.get_rank()) # tensor definition batch_size = 32 @@ -129,3 +209,15 @@ def print_each_rank(msg, selected_rank=None): loss.backward() print_each_rank('weight grad: {}'.format(weight.grad.t())) print_each_rank('======== Data Parallel =========', [0]) + #TODO: remove hook + + # hybrid tensor-data parallel + weight.grad = None + bias.grad = None + print_each_rank('======== Data + Tensor Parallel =========', [0]) + output = linear_hybrid_tensor_data_parallel(input, weight, bias) + loss = torch.mean(output) + print_each_rank(loss) + loss.backward() + print_each_rank('weight grad: {}'.format(weight.grad.t())) + print_each_rank('======== Data + Tensor Parallel =========', [0]) From c4d90d68fe8d772307af3d30681316cd2a408093 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 17 Jul 2021 11:39:02 +0000 Subject: [PATCH 0038/1892] correct impl for hybrid (tensor + data) parallelism: it turns out each torch.distributed.new_group requires called from all processes --- examples/case_study/config_linear.py | 55 +++++++++++++++++++--------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/examples/case_study/config_linear.py b/examples/case_study/config_linear.py index 6fe9bdb8..0c8d65d6 100644 --- a/examples/case_study/config_linear.py +++ b/examples/case_study/config_linear.py @@ -15,6 +15,8 @@ from torch.nn.parameter import Parameter torch.manual_seed(121) +hooks = list() + # tensor parallel - split weight in column def linear_tensor_parallel(input, weight, bias): ### Policy need to know ### @@ -70,8 +72,10 @@ def linear_data_parallel(input, weight, bias): ### Additional ops need to use ### ### Input Adapter ### - weight.register_hook(lambda grad: torch.distributed.all_reduce(grad)) - bias.register_hook(lambda grad: torch.distributed.all_reduce(grad)) + hw = weight.register_hook(lambda grad: torch.distributed.all_reduce(grad)) + hb = bias.register_hook(lambda grad: torch.distributed.all_reduce(grad)) + global hooks + hooks += [hw, hb] ### Forward ### output = torch._C._nn.linear(input, weight, bias) @@ -83,24 +87,40 @@ def linear_data_parallel(input, weight, bias): # tensor + data parallel def linear_hybrid_tensor_data_parallel(input, weight, bias): ### Policy need to know ### - devices = [0, 1, 2, 3] # how many device to perform? + tp_size = 2 # how many device to perform? + dp_size = 2 ### Necessary information to execute ### rank = torch.distributed.get_rank() # which role I participate? - if rank in [0,2]: - dp_group = torch.distributed.new_group([0,2]) - else: - dp_group = torch.distributed.new_group([1,3]) - dp_rank = torch.distributed.get_rank(group=dp_group) - - if rank in [0,1]: - tp_group = torch.distributed.new_group([0,1]) - else: - tp_group = torch.distributed.new_group([2,3]) + + # data parallel group + dp_group = None + group = torch.distributed.new_group([0,2]) + if rank in [0, 2]: + dp_group = group + group = torch.distributed.new_group([1,3]) + if rank in [1, 3]: + dp_group = group + + # tensor parallel group + tp_group = None + group = torch.distributed.new_group([0,1]) + if rank in [0, 1]: + tp_group = group + group = torch.distributed.new_group([2,3]) + if rank in [2, 3]: + tp_group = group tp_rank = torch.distributed.get_rank(group=tp_group) tp_world_size = torch.distributed.get_world_size(group=tp_group) - # print_each_rank('tp world size: {} tp rank: {}'.format(tp_world_size, tp_rank)) - + print_each_rank( + 'rank global:tp:dp=[{}:{}:{}] | size global:tp:dp=[{}:{}:{}]'.format( + torch.distributed.get_rank(), + torch.distributed.get_rank(tp_group), + torch.distributed.get_rank(dp_group), + torch.distributed.get_world_size(), + torch.distributed.get_world_size(tp_group), + torch.distributed.get_world_size(dp_group) + )) ### Additional Ops ### class InputAdapter(torch.autograd.Function): @@ -209,15 +229,16 @@ def print_each_rank(msg, selected_rank=None): loss.backward() print_each_rank('weight grad: {}'.format(weight.grad.t())) print_each_rank('======== Data Parallel =========', [0]) - #TODO: remove hook # hybrid tensor-data parallel weight.grad = None bias.grad = None + for hook in hooks: + hook.remove() print_each_rank('======== Data + Tensor Parallel =========', [0]) output = linear_hybrid_tensor_data_parallel(input, weight, bias) loss = torch.mean(output) - print_each_rank(loss) + # print_each_rank(loss) loss.backward() print_each_rank('weight grad: {}'.format(weight.grad.t())) print_each_rank('======== Data + Tensor Parallel =========', [0]) From 15b4971e623639e0dde604224fd92031d2f84171 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 17 Jul 2021 11:39:41 +0000 Subject: [PATCH 0039/1892] rename to parallel --- examples/case_study/{config_linear.py => parallel_linear.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename examples/case_study/{config_linear.py => parallel_linear.py} (99%) diff --git a/examples/case_study/config_linear.py b/examples/case_study/parallel_linear.py similarity index 99% rename from examples/case_study/config_linear.py rename to examples/case_study/parallel_linear.py index 0c8d65d6..9172a581 100644 --- a/examples/case_study/config_linear.py +++ b/examples/case_study/parallel_linear.py @@ -7,7 +7,7 @@ --master_addr=127.0.0.1 \ --master_port=62000 \ --use_env \ - examples/case_study/config_linear.py + examples/case_study/parallel_linear.py """ import torch From 6a0b31941e8d16341707b372a59b38fe8f5df7a1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 19 Jul 2021 01:20:47 +0000 Subject: [PATCH 0040/1892] update comment --- examples/case_study/parallel_linear.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/case_study/parallel_linear.py b/examples/case_study/parallel_linear.py index 9172a581..cbe38323 100644 --- a/examples/case_study/parallel_linear.py +++ b/examples/case_study/parallel_linear.py @@ -53,9 +53,11 @@ def backward(ctx, grad_output): ) return tensor_list[rank].contiguous() - ### Input Adapter ### + ### Input Slice ### weight = torch.chunk(weight, chunks=len(devices), dim=0)[rank].contiguous() bias = torch.chunk(bias, chunks=len(devices), dim=0)[rank].contiguous() + + ### Input Adapter ### input = InputAdapter.apply(input) ### Forward ### @@ -70,6 +72,7 @@ def backward(ctx, grad_output): # data parallel def linear_data_parallel(input, weight, bias): ### Additional ops need to use ### + # -> torch.distributed.all_reduce at backward ### Input Adapter ### hw = weight.register_hook(lambda grad: torch.distributed.all_reduce(grad)) @@ -87,7 +90,7 @@ def linear_data_parallel(input, weight, bias): # tensor + data parallel def linear_hybrid_tensor_data_parallel(input, weight, bias): ### Policy need to know ### - tp_size = 2 # how many device to perform? + tp_size = 2 # how many slices? which device? dp_size = 2 ### Necessary information to execute ### From 19d22282debdc5460ec89d88d794edd76b56ff00 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 19 Jul 2021 04:50:58 +0000 Subject: [PATCH 0041/1892] add comment to belongs --- examples/case_study/parallel_linear.py | 39 ++++++++++++++++---------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/examples/case_study/parallel_linear.py b/examples/case_study/parallel_linear.py index cbe38323..c9d35148 100644 --- a/examples/case_study/parallel_linear.py +++ b/examples/case_study/parallel_linear.py @@ -25,7 +25,7 @@ def linear_tensor_parallel(input, weight, bias): ### Necessary information to know ### rank = torch.distributed.get_rank() # which role I participate? - ### Additional ops need to use ### + ### Additional ops need to use ### -- TODO: System Generated class InputAdapter(torch.autograd.Function): @staticmethod def forward(ctx, input_): @@ -53,17 +53,17 @@ def backward(ctx, grad_output): ) return tensor_list[rank].contiguous() - ### Input Slice ### + ### Input Slice ### TODO: expert description on how to tile weight = torch.chunk(weight, chunks=len(devices), dim=0)[rank].contiguous() bias = torch.chunk(bias, chunks=len(devices), dim=0)[rank].contiguous() - ### Input Adapter ### + ### Input Adapter ### TODO: system generated according to segmentation input = InputAdapter.apply(input) - ### Forward ### + ### Forward ### TODO: expert description on how to compute output = torch._C._nn.linear(input, weight, bias) - ### Ouput Adapter ### + ### Ouput Adapter ### TODO: system generated according to segmentation # insert a forward + backward op at last (allgather - split) output = OutputAdapter.apply(output) return output @@ -71,19 +71,28 @@ def backward(ctx, grad_output): # data parallel def linear_data_parallel(input, weight, bias): + ### Policy need to know ### + devices = [0, 1, 2, 3] # how many device to perform? + + ### Necessary information to know ### + rank = torch.distributed.get_rank() # which role I participate? + ### Additional ops need to use ### # -> torch.distributed.all_reduce at backward - ### Input Adapter ### + ### Input Slice ### TODO: expert description on how to tile + input = torch.chunk(input, chunks=len(devices), dim=0)[rank].contiguous() + + ### Input Adapter ### TODO: system generated according to segmentation hw = weight.register_hook(lambda grad: torch.distributed.all_reduce(grad)) hb = bias.register_hook(lambda grad: torch.distributed.all_reduce(grad)) global hooks hooks += [hw, hb] - ### Forward ### + ### Forward ### TODO: expert description on how to compute output = torch._C._nn.linear(input, weight, bias) - ### Output Adapter ### -> no need + ### Output Adapter ### TODO: system generated according to segmentation return output @@ -104,6 +113,7 @@ def linear_hybrid_tensor_data_parallel(input, weight, bias): group = torch.distributed.new_group([1,3]) if rank in [1, 3]: dp_group = group + dp_rank = torch.distributed.get_rank(group=dp_group) # tensor parallel group tp_group = None @@ -157,23 +167,22 @@ def backward(ctx, grad_output): ) return tensor_list[rank].contiguous(), None, None - ### Input Adapter - Slice ### + ### Input Adapter - Slice ### TODO: expert description on how to tile + input = torch.chunk(input, chunks=dp_size, dim=0)[dp_rank].contiguous() weight = torch.chunk(weight, chunks=tp_world_size, dim=0)[tp_rank].contiguous() bias = torch.chunk(bias, chunks=tp_world_size, dim=0)[tp_rank].contiguous() - # replicate is implicitly done due to SPMD - ### Input Adapter - Data Parallel ### + ### Input Adapter - Data Parallel ### TODO: system generated according to segmentation weight.register_hook(lambda grad: torch.distributed.all_reduce(grad, group=dp_group)) bias.register_hook(lambda grad: torch.distributed.all_reduce(grad, group=dp_group)) - torch.distributed.barrier() - ### Input Adapter - Tensor Parallel ### + ### Input Adapter - Tensor Parallel ### TODO: system generated according to segmentation input = InputAdapter.apply(input, tp_group) - ### Forward ### + ### Forward ### TODO: expert description on how to compute output = torch._C._nn.linear(input, weight, bias) - ### Output Adapter - Tensor Parallel ### + ### Output Adapter - Tensor Parallel ### TODO: system generated according to segmentation output = OutputAdapter.apply(output, tp_group, -1) ### Ouput Adapter - Data Parallel ### From 2b438fae09fb914de4bba398de6f9d78c24e35ff Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 19 Jul 2021 05:13:05 +0000 Subject: [PATCH 0042/1892] add comment --- examples/case_study/parallel_linear.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/case_study/parallel_linear.py b/examples/case_study/parallel_linear.py index c9d35148..285d1367 100644 --- a/examples/case_study/parallel_linear.py +++ b/examples/case_study/parallel_linear.py @@ -84,8 +84,12 @@ def linear_data_parallel(input, weight, bias): input = torch.chunk(input, chunks=len(devices), dim=0)[rank].contiguous() ### Input Adapter ### TODO: system generated according to segmentation - hw = weight.register_hook(lambda grad: torch.distributed.all_reduce(grad)) - hb = bias.register_hook(lambda grad: torch.distributed.all_reduce(grad)) + def grad_hook(grad): + torch.distributed.all_reduce(grad) + grad /= len(devices) + return grad + hw = weight.register_hook(grad_hook) + hb = bias.register_hook(grad_hook) global hooks hooks += [hw, hb] @@ -173,8 +177,14 @@ def backward(ctx, grad_output): bias = torch.chunk(bias, chunks=tp_world_size, dim=0)[tp_rank].contiguous() ### Input Adapter - Data Parallel ### TODO: system generated according to segmentation - weight.register_hook(lambda grad: torch.distributed.all_reduce(grad, group=dp_group)) - bias.register_hook(lambda grad: torch.distributed.all_reduce(grad, group=dp_group)) + def grad_hook(grad): + torch.distributed.all_reduce(grad, group=dp_group) + grad /= dp_size + return grad + hw = weight.register_hook(grad_hook) + hb = bias.register_hook(grad_hook) + global hooks + hooks += [hw, hb] ### Input Adapter - Tensor Parallel ### TODO: system generated according to segmentation input = InputAdapter.apply(input, tp_group) From f601d4b17d28ba6166a5df421397d889195d6bb2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 04:31:41 +0000 Subject: [PATCH 0043/1892] not work yet: swap --- examples/case_study/memory_linear.py | 154 +++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 examples/case_study/memory_linear.py diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py new file mode 100644 index 00000000..d6600682 --- /dev/null +++ b/examples/case_study/memory_linear.py @@ -0,0 +1,154 @@ +import torch +import os + +torch.manual_seed(121) + +### Checkpoint PyTorch Implementation (Skip un-deterministic scenario) ### +# Note this implementation can only work with a module that consists +# multiple operators. This will won't work for one OP because the output +# for this module will be saved in next op +def checkpoint_module_linear(input, weight, bias): + + class Checkpoint(torch.autograd.Function): + """General class to wrapper op to enable checkpoint""" + @staticmethod + def forward(ctx, run_function, *args): + ctx.run_function = run_function + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + ctx.save_for_backward(*tensor_inputs) + + with torch.no_grad(): + outputs = run_function(*args) + return outputs + @staticmethod + def backward(ctx, *args): + # retrieve what need to regenerate tensors + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + # re-generate + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + # detach inputs + detached_inputs = list() + for input in inputs: + if torch.is_tensor(input): + x = input.detach() + x.requires_grad = input.requires_grad + else: + x = input + detached_inputs.append(x) + detached_inputs = tuple(detached_inputs) + # generate output tensor + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + if torch.is_tensor(outputs): + outputs = (outputs,) + # run backward to tensors that require a grad + outputs_with_grad = list() + args_with_grad = list() + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs) + return (None, None) + grads + + output = Checkpoint.apply(torch._C._nn.linear, input, weight, bias) + return output + + +### Swap linear ### +def swap_linear(input, weight, bias): + ## Note pytorch tensor.to() will always return a copy + + ### Policy ### + op_device_id = 0 # where to perform the device + output_swap = True # whether output tensor needs swap + + ### Additional Swap operator ### + # Note autograd will not work in pytorch + # as pytorch will record each input, even you do the inplacement-update + # class SwapOutTensor(torch.autograd.Function): + # @staticmethod + # def forward(ctx, tensor): + # ctx.constants = tensor.get_device() + # cpu_tensor = tensor.cpu() + # tensor.data = cpu_tensor # inplace-update + # return tensor + # @staticmethod + # def backward(ctx, grad_output): + # device_id = ctx.constants + # grad = grad_output.cuda(device_id) + # grad_output.data = grad + # return grad_output + + ### Input swap-in (if needed) ### + input_swap = None + if input.get_device() != op_device_id: + input = input.cuda(op_device_id) + input_swap = -1 # CPU + weight_swap = None + if weight.get_device() != op_device_id: + weight_swap = -1 # CPU + weight = weight.cuda(op_device_id) + bias_swap = None + if bias.get_device() != op_device_id: + bias_swap = -1 # CPU + bias = bias.cuda(op_device_id) + + ### Compute ### + output = torch._C._nn.linear(input, weight, bias) + print(output) + # inplacement update + output.data = output.cpu() + print(output) + + # Here we need the backward to take back the intermediate tensor + + ### Swap out if needed ### TODO: swapout can be in any place + # if output_swap: + # output = SwapOutTensor.apply(output) + # if input_swap == -1: + # input_swap = SwapOutTensor.apply(input, ) + return output + + + +if __name__ == '__main__': + + torch.cuda.set_device(0) + + # tensor definition + batch_size = 32 + out_features = 1024 + in_features = 1024 + weight = torch.rand((out_features, in_features)).cuda().requires_grad_() + # print('weight: ', weight) + bias = torch.rand(out_features).cuda().requires_grad_() + # print('bias: ', bias) + input = torch.rand((batch_size, in_features)).cuda() + # print('input: ', input) + + # op compute + print('======== Checkpointing Single Device =======') + output = swap_linear(input, weight, bias) + print('output device: {}'.format(output.get_device())) + print(output) + output = output.cuda() + print('output device: {}'.format(output.get_device())) + print(output) + loss = torch.mean(output) + print(loss) + loss.backward() + print('weight grad: ', weight.grad.t()) + print('======== Checkpointing Single Device =======') From a490dcafa9902db52eb635c363e23dd74fa3e6ac Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 05:36:28 +0000 Subject: [PATCH 0044/1892] only enable (not sure) the weight / grad wapping --- examples/case_study/memory_linear.py | 78 +++++++++++----------------- 1 file changed, 31 insertions(+), 47 deletions(-) diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py index d6600682..ac3723a9 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/memory_linear.py @@ -68,58 +68,43 @@ def backward(ctx, *args): ### Swap linear ### -def swap_linear(input, weight, bias): +def swap_weight_grad_linear(input, weight, bias): ## Note pytorch tensor.to() will always return a copy ### Policy ### - op_device_id = 0 # where to perform the device - output_swap = True # whether output tensor needs swap - - ### Additional Swap operator ### - # Note autograd will not work in pytorch - # as pytorch will record each input, even you do the inplacement-update - # class SwapOutTensor(torch.autograd.Function): - # @staticmethod - # def forward(ctx, tensor): - # ctx.constants = tensor.get_device() - # cpu_tensor = tensor.cpu() - # tensor.data = cpu_tensor # inplace-update - # return tensor - # @staticmethod - # def backward(ctx, grad_output): - # device_id = ctx.constants - # grad = grad_output.cuda(device_id) - # grad_output.data = grad - # return grad_output + op_device_id = 0 # where to perform the device + # output_swap = False # whether output tensor needs swap + weight_swap = True + bias_swap = True + gradient_swap = True ### Input swap-in (if needed) ### - input_swap = None - if input.get_device() != op_device_id: - input = input.cuda(op_device_id) - input_swap = -1 # CPU - weight_swap = None - if weight.get_device() != op_device_id: - weight_swap = -1 # CPU - weight = weight.cuda(op_device_id) - bias_swap = None - if bias.get_device() != op_device_id: - bias_swap = -1 # CPU - bias = bias.cuda(op_device_id) + weight_locate = weight.get_device() + if weight_locate == -1: + weight.data = weight.cuda(op_device_id) + bias_locate = bias.get_device() + if bias_locate == -1: # current on CPU + bias.data = bias.cuda(op_device_id) + + ### Adatper to swap out gradient ### + def swap_out_grad(grad): + grad.data = grad.cpu() + return grad + if gradient_swap: + weight.register_hook(swap_out_grad) + bias.register_hook(swap_out_grad) ### Compute ### output = torch._C._nn.linear(input, weight, bias) - print(output) - # inplacement update - output.data = output.cpu() - print(output) - - # Here we need the backward to take back the intermediate tensor + # inplacement swap + # output.data = output.cpu() ### Swap out if needed ### TODO: swapout can be in any place - # if output_swap: - # output = SwapOutTensor.apply(output) - # if input_swap == -1: - # input_swap = SwapOutTensor.apply(input, ) + if weight_swap: + weight.data = weight.cpu() + if bias_swap: + bias.data = bias.cpu() + return output @@ -141,12 +126,11 @@ def swap_linear(input, weight, bias): # op compute print('======== Checkpointing Single Device =======') - output = swap_linear(input, weight, bias) - print('output device: {}'.format(output.get_device())) - print(output) - output = output.cuda() + # first locate on cpu + weight.data = weight.cpu() + bias.data = bias.cpu() + output = swap_weight_grad_linear(input, weight, bias) print('output device: {}'.format(output.get_device())) - print(output) loss = torch.mean(output) print(loss) loss.backward() From fdcf810b622ebc30f268a8d0775997efa204ec2c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 07:00:03 +0000 Subject: [PATCH 0045/1892] enable weight/bias swapping --- examples/case_study/memory_linear.py | 42 +++++++++++++++++++--------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py index ac3723a9..fb93a99d 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/memory_linear.py @@ -88,7 +88,7 @@ def swap_weight_grad_linear(input, weight, bias): ### Adatper to swap out gradient ### def swap_out_grad(grad): - grad.data = grad.cpu() + grad.data = grad.detach().cpu() return grad if gradient_swap: weight.register_hook(swap_out_grad) @@ -101,9 +101,11 @@ def swap_out_grad(grad): ### Swap out if needed ### TODO: swapout can be in any place if weight_swap: - weight.data = weight.cpu() + weight.data = weight.detach().cpu() if bias_swap: - bias.data = bias.cpu() + bias.data = bias.detach().cpu() + # print(weight) + # print(bias) return output @@ -112,27 +114,41 @@ def swap_out_grad(grad): if __name__ == '__main__': torch.cuda.set_device(0) + init_memory = torch.cuda.memory_allocated() # tensor definition batch_size = 32 - out_features = 1024 - in_features = 1024 + out_features = 10240 + in_features = 10240 ## 100 MB weight weight = torch.rand((out_features, in_features)).cuda().requires_grad_() # print('weight: ', weight) bias = torch.rand(out_features).cuda().requires_grad_() - # print('bias: ', bias) input = torch.rand((batch_size, in_features)).cuda() - # print('input: ', input) + + input_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 # op compute print('======== Checkpointing Single Device =======') - # first locate on cpu - weight.data = weight.cpu() - bias.data = bias.cpu() + + # swap out weight + weight.data = weight.detach().cpu() + bias.data = bias.detach().cpu() + + weight_swap_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 + output = swap_weight_grad_linear(input, weight, bias) - print('output device: {}'.format(output.get_device())) - loss = torch.mean(output) - print(loss) + loss = torch.mean(output) * 100 loss.backward() + + finish_op_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 + + # allocate tensor on gpu to see if swap workds + tmp = torch.rand((out_features, in_features)).cuda() + after_alloc_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 + + print('memory consumption (MB): input-require: {:.2f} | after swap weight: {:.2f} | after op run {:.2f} | after allocate {:.2f}'.format( + input_memory, weight_swap_memory, finish_op_memory, after_alloc_memory)) + + # correctness verify print('weight grad: ', weight.grad.t()) print('======== Checkpointing Single Device =======') From 5d275d6df92cbfbdfa30a659bd4e27e6b773f416 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 07:35:23 +0000 Subject: [PATCH 0046/1892] enlarge tensor size --- examples/case_study/naive_linear.py | 6 +++--- examples/case_study/parallel_linear.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/case_study/naive_linear.py b/examples/case_study/naive_linear.py index bbdead40..c370abcd 100644 --- a/examples/case_study/naive_linear.py +++ b/examples/case_study/naive_linear.py @@ -14,8 +14,8 @@ def linear(input, weight, bias=None): # tensor definition batch_size = 32 - out_features = 1024 - in_features = 1024 + out_features = 10240 + in_features = 10240 weight = torch.rand((out_features, in_features)).cuda().requires_grad_() # print('weight: ', weight) bias = torch.rand(out_features).cuda().requires_grad_() @@ -26,7 +26,7 @@ def linear(input, weight, bias=None): # op compute print('======== Naive Single Device =======') output = linear(input, weight, bias) - loss = torch.mean(output) + loss = torch.mean(output) * 100 print(loss) loss.backward() print('weight grad: ', weight.grad.t()) diff --git a/examples/case_study/parallel_linear.py b/examples/case_study/parallel_linear.py index 285d1367..2dd798a2 100644 --- a/examples/case_study/parallel_linear.py +++ b/examples/case_study/parallel_linear.py @@ -223,8 +223,8 @@ def print_each_rank(msg, selected_rank=None): # tensor definition batch_size = 32 - out_features = 1024 - in_features = 1024 + out_features = 10240 + in_features = 10240 weight = torch.rand((out_features, in_features)).cuda().requires_grad_() # print_each_rank('weight: {}'.format(weight)) bias = torch.rand(out_features).cuda().requires_grad_() @@ -235,7 +235,7 @@ def print_each_rank(msg, selected_rank=None): # model parallel print_each_rank('======== Model Parallel =========', [0]) output = linear_tensor_parallel(input, weight, bias) - loss = torch.mean(output) + loss = torch.mean(output) * 100 print_each_rank(loss) loss.backward() # note weight is created as transposed @@ -247,7 +247,7 @@ def print_each_rank(msg, selected_rank=None): bias.grad = None print_each_rank('======== Data Parallel =========', [0]) output = linear_data_parallel(input, weight, bias) - loss = torch.mean(output) + loss = torch.mean(output) * 100 loss.backward() print_each_rank('weight grad: {}'.format(weight.grad.t())) print_each_rank('======== Data Parallel =========', [0]) @@ -259,7 +259,7 @@ def print_each_rank(msg, selected_rank=None): hook.remove() print_each_rank('======== Data + Tensor Parallel =========', [0]) output = linear_hybrid_tensor_data_parallel(input, weight, bias) - loss = torch.mean(output) + loss = torch.mean(output) * 100 # print_each_rank(loss) loss.backward() print_each_rank('weight grad: {}'.format(weight.grad.t())) From 26646c016f55f091bdfad46b1bd6f117e39a4091 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 09:27:57 +0000 Subject: [PATCH 0047/1892] finally make it work by customize forward / backward --- examples/case_study/memory_linear.py | 129 ++++++++++++++++----------- 1 file changed, 78 insertions(+), 51 deletions(-) diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py index fb93a99d..36e0943d 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/memory_linear.py @@ -3,6 +3,8 @@ torch.manual_seed(121) +tensor_map = dict() + ### Checkpoint PyTorch Implementation (Skip un-deterministic scenario) ### # Note this implementation can only work with a module that consists # multiple operators. This will won't work for one OP because the output @@ -67,48 +69,73 @@ def backward(ctx, *args): return output -### Swap linear ### -def swap_weight_grad_linear(input, weight, bias): - ## Note pytorch tensor.to() will always return a copy - - ### Policy ### - op_device_id = 0 # where to perform the device - # output_swap = False # whether output tensor needs swap - weight_swap = True - bias_swap = True - gradient_swap = True - - ### Input swap-in (if needed) ### - weight_locate = weight.get_device() - if weight_locate == -1: - weight.data = weight.cuda(op_device_id) - bias_locate = bias.get_device() - if bias_locate == -1: # current on CPU - bias.data = bias.cuda(op_device_id) - - ### Adatper to swap out gradient ### - def swap_out_grad(grad): - grad.data = grad.detach().cpu() - return grad - if gradient_swap: - weight.register_hook(swap_out_grad) - bias.register_hook(swap_out_grad) - - ### Compute ### - output = torch._C._nn.linear(input, weight, bias) - # inplacement swap - # output.data = output.cpu() - - ### Swap out if needed ### TODO: swapout can be in any place - if weight_swap: - weight.data = weight.detach().cpu() - if bias_swap: - bias.data = bias.detach().cpu() - # print(weight) - # print(bias) +def swap_weight_grad_linear_v2(input, weight, bias): - return output + class SwapLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias, swap_weight=True, swap_bias=True): + + weight_id = id(weight) + bias_id = id(bias) + ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id)) + tensor_map[weight_id] = weight + tensor_map[bias_id] = bias_id + + ctx.constants = (swap_weight, swap_bias) + + # retrieve from cpu memory + if swap_weight: + weight.data = weight.detach().cuda() + if swap_bias: + bias.data = bias.detach().cuda() + # compute + output = torch._C._nn.linear(input, weight, bias) + + # offload to CPU + if swap_weight: + weight.data = weight.detach().cpu() + if swap_bias: + bias.data = bias.detach().cpu() + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight_id, bias_id = ctx.saved_tensors + weight = tensor_map[weight_id.item()] + bias = tensor_map[bias_id.item()] + swap_weight, swap_bias = ctx.constants + + grad_input = grad_weight = grad_bas = None + if ctx.needs_input_grad[0]: + print('computing grad of input...') + # retrieve weight + if swap_weight: + weight.data = weight.cuda() + grad_input = grad_output.matmul(weight) + if swap_weight: + weight.data = weight.detach().cpu() + if ctx.needs_input_grad[1]: + dim = grad_output.dim() + if dim > 2: + grad_weight = grad\ + .view(-1, grad_output.shape[-1])\ + .t()\ + .matmul(input.view(-1, input.shape[-1])) + else: + grad_weight = grad_output.t().matmul(input) + if swap_weight: + grad_weight.data = grad_weight.detach().cpu() + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + if swap_bias: + grad_bias.data = grad_bias.detach().cpu() + print('here') + return grad_input, grad_weight, grad_bias, None, None + + output = SwapLinear.apply(input, weight, bias, + True, True) + return output if __name__ == '__main__': @@ -120,23 +147,22 @@ def swap_out_grad(grad): batch_size = 32 out_features = 10240 in_features = 10240 ## 100 MB weight - weight = torch.rand((out_features, in_features)).cuda().requires_grad_() - # print('weight: ', weight) - bias = torch.rand(out_features).cuda().requires_grad_() + weight_1 = torch.rand((out_features, in_features)).requires_grad_() + bias_1 = torch.rand(out_features).requires_grad_() + weight_2 = torch.rand((out_features, in_features)).requires_grad_() + bias_2 = torch.rand(out_features).requires_grad_() input = torch.rand((batch_size, in_features)).cuda() input_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 # op compute print('======== Checkpointing Single Device =======') - - # swap out weight - weight.data = weight.detach().cpu() - bias.data = bias.detach().cpu() weight_swap_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - output = swap_weight_grad_linear(input, weight, bias) + output = swap_weight_grad_linear_v2(input, weight_1, bias_1) + print('output: {}'.format(output)) + output = swap_weight_grad_linear_v2(output, weight_2, bias_2) loss = torch.mean(output) * 100 loss.backward() @@ -146,9 +172,10 @@ def swap_out_grad(grad): tmp = torch.rand((out_features, in_features)).cuda() after_alloc_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - print('memory consumption (MB): input-require: {:.2f} | after swap weight: {:.2f} | after op run {:.2f} | after allocate {:.2f}'.format( - input_memory, weight_swap_memory, finish_op_memory, after_alloc_memory)) + max_allocated = (torch.cuda.max_memory_allocated() - init_memory) / 1024 / 1024 + print('memory consumption (MB): max allocated: {:.2f} | input-require: {:.2f} | after swap weight: {:.2f} | after op run {:.2f} | after allocate {:.2f}'.format( + max_allocated, input_memory, weight_swap_memory, finish_op_memory, after_alloc_memory)) # correctness verify - print('weight grad: ', weight.grad.t()) + print('weight grad: ', weight_2.grad.t()) print('======== Checkpointing Single Device =======') From 8a6f76d87b9d6d47d66abae8bc7a5689feaeafe7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 10:57:57 +0000 Subject: [PATCH 0048/1892] finally this is the right version with some abstractions --- examples/case_study/memory_linear.py | 68 +++++++++++++++++----------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py index 36e0943d..7482931a 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/memory_linear.py @@ -71,32 +71,46 @@ def backward(ctx, *args): def swap_weight_grad_linear_v2(input, weight, bias): + # op placement + op_device = torch.device('cuda:0') + + # tensor placement: this should be set at tensor creation stage + # note here if change this, we also need to change tensor init at main + weight.host_device = torch.device('cpu') + bias.host_device = torch.device('cpu') + + # grad placement: this can be set before running + grad_device = torch.device('cuda:0') + def grad_swap(grad): + grad.data = grad.detach().to(grad_device) + return grad + weight.register_hook(grad_swap) + bias.register_hook(grad_swap) + class SwapLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, bias, swap_weight=True, swap_bias=True): + def forward(ctx, input, weight, bias): weight_id = id(weight) bias_id = id(bias) ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id)) tensor_map[weight_id] = weight - tensor_map[bias_id] = bias_id - - ctx.constants = (swap_weight, swap_bias) + tensor_map[bias_id] = bias # retrieve from cpu memory - if swap_weight: - weight.data = weight.detach().cuda() - if swap_bias: - bias.data = bias.detach().cuda() + if weight.device != op_device: + weight.data = weight.detach().to(op_device) + if bias.get_device() != op_device: + bias.data = bias.detach().to(op_device) # compute output = torch._C._nn.linear(input, weight, bias) # offload to CPU - if swap_weight: - weight.data = weight.detach().cpu() - if swap_bias: - bias.data = bias.detach().cpu() + if weight.device != weight.host_device: + weight.data = weight.detach().to(weight.host_device) + if bias.device != bias.host_device: + bias.data = bias.detach().to(bias.host_device) return output @staticmethod @@ -104,17 +118,16 @@ def backward(ctx, grad_output): input, weight_id, bias_id = ctx.saved_tensors weight = tensor_map[weight_id.item()] bias = tensor_map[bias_id.item()] - swap_weight, swap_bias = ctx.constants grad_input = grad_weight = grad_bas = None if ctx.needs_input_grad[0]: print('computing grad of input...') # retrieve weight - if swap_weight: - weight.data = weight.cuda() + if weight.device != op_device: + weight.data = weight.detach().to(op_device) grad_input = grad_output.matmul(weight) - if swap_weight: - weight.data = weight.detach().cpu() + if weight.device != weight.host_device: + weight.data = weight.detach().to(weight.host_device) if ctx.needs_input_grad[1]: dim = grad_output.dim() if dim > 2: @@ -124,17 +137,20 @@ def backward(ctx, grad_output): .matmul(input.view(-1, input.shape[-1])) else: grad_weight = grad_output.t().matmul(input) - if swap_weight: - grad_weight.data = grad_weight.detach().cpu() if ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) - if swap_bias: - grad_bias.data = grad_bias.detach().cpu() - print('here') - return grad_input, grad_weight, grad_bias, None, None + + ### Move gradient to it's tensor host device ### + ### WARNING: there will be up to 2 redundant I/O if we require + ### gradient to place differently with its tensor + if grad_weight is not None: + grad_weight.data = grad_weight.detach().to(weight.host_device) + if grad_bias is not None: + grad_bias.data = grad_bias.detach().to(bias.host_device) + + return grad_input, grad_weight, grad_bias - output = SwapLinear.apply(input, weight, bias, - True, True) + output = SwapLinear.apply(input, weight, bias) return output @@ -161,7 +177,7 @@ def backward(ctx, grad_output): weight_swap_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 output = swap_weight_grad_linear_v2(input, weight_1, bias_1) - print('output: {}'.format(output)) + # print('output: {}'.format(output)) output = swap_weight_grad_linear_v2(output, weight_2, bias_2) loss = torch.mean(output) * 100 loss.backward() From 431a51ae83d74be38db9ac5951f5e7a1416b6982 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 11:08:20 +0000 Subject: [PATCH 0049/1892] clean memory offloading --- examples/case_study/memory_linear.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py index 7482931a..32ea41b8 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/memory_linear.py @@ -69,7 +69,7 @@ def backward(ctx, *args): return output -def swap_weight_grad_linear_v2(input, weight, bias): +def swap_weight_grad_linear(input, weight, bias): # op placement op_device = torch.device('cuda:0') @@ -80,7 +80,7 @@ def swap_weight_grad_linear_v2(input, weight, bias): bias.host_device = torch.device('cpu') # grad placement: this can be set before running - grad_device = torch.device('cuda:0') + grad_device = torch.device('cpu') def grad_swap(grad): grad.data = grad.detach().to(grad_device) return grad @@ -165,21 +165,20 @@ def backward(ctx, grad_output): in_features = 10240 ## 100 MB weight weight_1 = torch.rand((out_features, in_features)).requires_grad_() bias_1 = torch.rand(out_features).requires_grad_() + input = torch.rand((batch_size, in_features)).cuda() weight_2 = torch.rand((out_features, in_features)).requires_grad_() bias_2 = torch.rand(out_features).requires_grad_() - input = torch.rand((batch_size, in_features)).cuda() input_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 # op compute - print('======== Checkpointing Single Device =======') - + print('======== Offloading Single Device =======') weight_swap_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - output = swap_weight_grad_linear_v2(input, weight_1, bias_1) - # print('output: {}'.format(output)) - output = swap_weight_grad_linear_v2(output, weight_2, bias_2) + output = swap_weight_grad_linear(input, weight_1, bias_1) + output = swap_weight_grad_linear(output, weight_2, bias_2) loss = torch.mean(output) * 100 + print('loss: {}'.format(loss)) loss.backward() finish_op_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 @@ -193,5 +192,5 @@ def backward(ctx, grad_output): max_allocated, input_memory, weight_swap_memory, finish_op_memory, after_alloc_memory)) # correctness verify - print('weight grad: ', weight_2.grad.t()) - print('======== Checkpointing Single Device =======') + print('weight grad: ', weight_1.grad.t()) + print('======== Offloading Single Device =======') From 4ed1783e8ef5989eb2f88b8aac702bed6b3479b8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 11:28:55 +0000 Subject: [PATCH 0050/1892] add policy description --- examples/case_study/memory_linear.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py index 32ea41b8..9bb0d84d 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/memory_linear.py @@ -71,6 +71,8 @@ def backward(ctx, *args): def swap_weight_grad_linear(input, weight, bias): + ### Policy ### + # op placement op_device = torch.device('cuda:0') @@ -87,6 +89,11 @@ def grad_swap(grad): weight.register_hook(grad_swap) bias.register_hook(grad_swap) + ## Timing when a tensor swapped in/out + ## On-demand? Pre-fetch? All-consumed? + + ##### + class SwapLinear(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias): From 95eb9997e00c7c0fa17284c9b564a9cb736f7bd8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Jul 2021 12:03:36 +0000 Subject: [PATCH 0051/1892] move max allocated before allocation --- examples/case_study/memory_linear.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py index 9bb0d84d..62ee7890 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/memory_linear.py @@ -82,7 +82,7 @@ def swap_weight_grad_linear(input, weight, bias): bias.host_device = torch.device('cpu') # grad placement: this can be set before running - grad_device = torch.device('cpu') + grad_device = torch.device('cuda:0') def grad_swap(grad): grad.data = grad.detach().to(grad_device) return grad @@ -189,14 +189,14 @@ def backward(ctx, grad_output): loss.backward() finish_op_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 + max_allocated = (torch.cuda.max_memory_allocated() - init_memory) / 1024 / 1024 # allocate tensor on gpu to see if swap workds tmp = torch.rand((out_features, in_features)).cuda() after_alloc_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - max_allocated = (torch.cuda.max_memory_allocated() - init_memory) / 1024 / 1024 - print('memory consumption (MB): max allocated: {:.2f} | input-require: {:.2f} | after swap weight: {:.2f} | after op run {:.2f} | after allocate {:.2f}'.format( - max_allocated, input_memory, weight_swap_memory, finish_op_memory, after_alloc_memory)) + print('Memory Consumption (MB):\n\t input-require: {:.2f}\n\t after swap weight: {:.2f}\n\t after op run {:.2f}\n\t max allocated: {:.2f}\n\t after allocate {:.2f}'.format( + input_memory, weight_swap_memory, finish_op_memory, max_allocated, after_alloc_memory)) # correctness verify print('weight grad: ', weight_1.grad.t()) From fb78c40394b06e67a65192885df7d43b07dd7845 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 21 Jul 2021 02:29:19 +0000 Subject: [PATCH 0052/1892] enrich working flow with more details --- cube/operator/holist/generics.py | 33 +++++++++++++++++++++++++++----- cube/operator/holist/linear.py | 10 +++++----- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index 09d176f7..d6a8af17 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -23,11 +23,11 @@ def __init__(self, on the logical required format. """ # holistic layout of input - self.input_layout = dict() + self.input_layout = input_layout self.input_format = input_format # holistic layout of output - self.output_layout = dict() + self.output_layout = output_layout self.output_format = output_format def input_adapter(self, args, **kwargs): @@ -36,17 +36,40 @@ def input_adapter(self, args, **kwargs): input layout requirement """ # step 1: data reformat based on the input argument + #TODO: data dimension format transformation + tensor_inputs = list() + for arg in args: + #TODO: kwargs + if cube.is_tensor(arg): + tensor_inputs.append(arg) + tensor_segments = list() + for outliner, tensor in zip(self.input_layout, tensor_inputs): + segments = outliner(tensor.shape) + tensor_segments.append(segments) # step 2: physical tensor placement (policy) + #TODO: policy module + tensor_communities = policy_module(tensor_segments) - # step 3: community matching - pass + # step 3: community matching + for communities, tensor in zip(tensor_communities, tensor_inputs): + tensor.match(communities) def output_adapter(self, outputs): """ Data reformat to logical op format """ - pass + if not isinstance(outputs, tuple): + outputs = (outputs,) + output_tensors = list() + for output in outputs: + if cube.is_tensor(output): + if cube.is_tensor(output): + output_tensors.append(output) + for outliner, output in zip(self.output_layout, output_tensors): + segments = outliner(output.shape) + output.to_logic_tensor(segments) + def forward(self, args, **kwargs): """Expert code for doing operation diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index 8c58807a..86057076 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -3,7 +3,7 @@ import cube.operator.physic as physic_op from cube.tensor.logic.tensor import LogicalTensor -from cube.tensor.logic.segment import TileSegment +import cube.tensor.logic.segment.outline as outline from cube.tensor.community import Community # expert space to declare all kinds of holistic operators @@ -19,13 +19,13 @@ class LinearColumnWeight(GenericHolisticOp): def __init__(self): # TODO - inputs_layout = None + inputs_layout = Full # TODO - weight_layout = None + weight_layout = outline.SplitAxis(axis=0, chunk_num=None, overlap=0) # TODO - bias_layout = None + bias_layout = outline.SplitAxis(axis=0, chunk_num=None, overlap=0) # TODO - output_layout = None + output_layout = outline.Align(weight_layout) super().__init__( input_layout=(inputs_layout, weight_layout), From 19a1210a2a486aab44bf8ec57a1f07814d16b8a1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 21 Jul 2021 02:35:05 +0000 Subject: [PATCH 0053/1892] update outliner --- cube/tensor/logic/segment/outline.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py index f7a5aa09..f12f3df1 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/segment/outline.py @@ -13,9 +13,17 @@ # interface to setup restrictions on the segmentation +class Full: + + def __init__(self): + pass + + def __call__(self, shape): + pass + class SplitAxis: - def __init__(self, axis, chunk_num=None, chunk_size=None, overlap=0): + def __init__(self, axis, chunk_num=None, overlap=0): """ Segmentation Pattern Requirement (parameters): @@ -27,12 +35,7 @@ def __init__(self, axis, chunk_num=None, chunk_size=None, overlap=0): If an integer, only the specified chunk number is valid; If a tuple(min, max), the chunk number wihtin the scope [min,max] is valid - chunk_size (None, int, tuple(int, int)): - valid chunk size. - If None, any size is valid; - If an integer, each chunk size is valid; - if a tuple(min, max), the chunk size wihtin the scope [min,max] is valid - + overlap (0, int, tuple(int, int)): valid size for overlaping on the boundary of each splitted chunks. If None, any overlapping is valid @@ -42,11 +45,10 @@ def __init__(self, axis, chunk_num=None, chunk_size=None, overlap=0): """ self.axis = axis self.chunk_num = chunk_num - self.chunk_size = chunk_size self.overlap = overlap def __call__(self, shape): """ - Runtime community generation given the logical tensor shape + Runtime segment generation given the logical tensor shape """ pass From 9f52e0fa37947f7614577da086ad12807a9a9b81 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 21 Jul 2021 08:43:21 +0000 Subject: [PATCH 0054/1892] add physical tensor interface with checkpoint / swap --- cube/tensor/physic/tensor.py | 57 ++++++++++++++++++++++++++++++++++++ tests/test_physic_tensor.py | 25 ++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 tests/test_physic_tensor.py diff --git a/cube/tensor/physic/tensor.py b/cube/tensor/physic/tensor.py index e69de29b..480dc532 100644 --- a/cube/tensor/physic/tensor.py +++ b/cube/tensor/physic/tensor.py @@ -0,0 +1,57 @@ +import torch + + +class PhysicTensor(torch.Tensor): + """ + Additional attributes on top of PyTorch Tensor: + + data_host_device: + Tensor data placement. The device is responsible + for managing the tensor data. + + grad_host_device: + Gradient data placement. The device is responsible + for managing the gradient data. If no grad required + for this Tensor, this option won't have impact + """ + @property + def data_host_device(self): + if not hasattr(self, '_data_host_device'): + self._data_host_device = self.device + return self._data_host_device + + + @data_host_device.setter + def data_host_device(self, device): + if not isinstance(device, torch.device): + raise TypeError('Expected torch.device') + self._data_host_device = device + # inplacement move device to the host place + if self.device != self.data_host_device: + self.data = self.to(self.data_host_device) + + + @property + def grad_host_device(self): + if not hasattr(self, '_grad_host_device'): + self._grad_host_device = self.device + return self._grad_host_device + + + @grad_host_device.setter + def grad_host_device(self, device): + if not isinstance(device, torch.device): + raise TypeError('Expected torch.device') + self._grad_host_device = device + # inplacement move device to the host place + if self.device != self.grad_host_device: + self.data = self.to(self.grad_host_device) + + + def move_(self, device): + """ + inplacement device movement + """ + if not isinstance(device, torch.device): + raise TypeError('Expected torch.device') + self.data = self.to(device) diff --git a/tests/test_physic_tensor.py b/tests/test_physic_tensor.py new file mode 100644 index 00000000..e96f88ea --- /dev/null +++ b/tests/test_physic_tensor.py @@ -0,0 +1,25 @@ +import torch +from cube.tensor.physic.tensor import PhysicTensor + +def test_type(): + tensor1 = PhysicTensor([1,2,3,4]) + tensor2 = PhysicTensor([2,3,4,5]) + tensor_out = tensor1 + tensor2 + assert isinstance(tensor_out, PhysicTensor) + + +def test_data_host_device(): + tensor = PhysicTensor([1,2,3,4]) + assert tensor.data_host_device == torch.device('cpu') + tensor.data_host_device = torch.device('cuda:0') + assert tensor.device == torch.device('cuda:0') + tensor.move_(torch.device('cpu')) + assert tensor.device == torch.device('cpu') + + +if __name__ == '__main__': + + test_type() + test_data_host_device() + + print('test passed') \ No newline at end of file From 9e6479e0d16675de7b37002010465bb6ce6951d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 21 Jul 2021 09:02:56 +0000 Subject: [PATCH 0055/1892] grad hook for host device --- cube/tensor/physic/tensor.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/cube/tensor/physic/tensor.py b/cube/tensor/physic/tensor.py index 480dc532..4a04949e 100644 --- a/cube/tensor/physic/tensor.py +++ b/cube/tensor/physic/tensor.py @@ -20,21 +20,21 @@ def data_host_device(self): self._data_host_device = self.device return self._data_host_device - @data_host_device.setter def data_host_device(self, device): if not isinstance(device, torch.device): raise TypeError('Expected torch.device') self._data_host_device = device - # inplacement move device to the host place + # inplacement movement to host device if self.device != self.data_host_device: - self.data = self.to(self.data_host_device) - + self.move_(self.data_host_device) @property def grad_host_device(self): + if not hasattr(self, '_grad_host_hook'): + self._grad_host_hook = None if not hasattr(self, '_grad_host_device'): - self._grad_host_device = self.device + self._grad_host_device = self._data_host_device return self._grad_host_device @@ -43,9 +43,16 @@ def grad_host_device(self, device): if not isinstance(device, torch.device): raise TypeError('Expected torch.device') self._grad_host_device = device - # inplacement move device to the host place - if self.device != self.grad_host_device: - self.data = self.to(self.grad_host_device) + # inplacement movement to host device + if self.grad is not None: + self.grad.data = self.grad.detach().to(self.grad_host_device) + # modify hooks + if self._grad_host_hook is not None: + self._grad_host_hook.remove() + def move_grad(grad): + grad.data = grad.detach().to(self.grad_host_device) + return grad + self._grad_host_hook = self.register_hook(move_grad) def move_(self, device): @@ -54,4 +61,4 @@ def move_(self, device): """ if not isinstance(device, torch.device): raise TypeError('Expected torch.device') - self.data = self.to(device) + self.data = self.detach().to(device) From c9999b6053f18b1efd3bc9755206f4a68c5a8117 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 22 Jul 2021 06:48:40 +0000 Subject: [PATCH 0056/1892] add some comments --- examples/case_study/memory_linear.py | 11 ++++++++++- examples/case_study/parallel_linear.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/case_study/memory_linear.py b/examples/case_study/memory_linear.py index 62ee7890..f4d14989 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/memory_linear.py @@ -89,8 +89,17 @@ def grad_swap(grad): weight.register_hook(grad_swap) bias.register_hook(grad_swap) + ## Placement for a tensor swap in/out + ## where to swap in: op.device (op placement policy) + ## where to swap out: tensor.swap_to (policy) + ## Timing when a tensor swapped in/out - ## On-demand? Pre-fetch? All-consumed? + ## Basic Time block (each op is a slot?) + ## Event-driven (tesnor access? on-demand? | dynamic scenario?) + + # Policy description + # op.device = torch.device('cuda:0') + # ... ##### diff --git a/examples/case_study/parallel_linear.py b/examples/case_study/parallel_linear.py index 2dd798a2..ced2100e 100644 --- a/examples/case_study/parallel_linear.py +++ b/examples/case_study/parallel_linear.py @@ -25,7 +25,7 @@ def linear_tensor_parallel(input, weight, bias): ### Necessary information to know ### rank = torch.distributed.get_rank() # which role I participate? - ### Additional ops need to use ### -- TODO: System Generated + ### Additional ops need to use ### -- TODO: System provided class InputAdapter(torch.autograd.Function): @staticmethod def forward(ctx, input_): From 077c419489c43fb670ccf33f1edf6f2ce4aa93c5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 23 Jul 2021 02:03:26 +0000 Subject: [PATCH 0057/1892] holist linear update --- cube/operator/holist/linear.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index 86057076..0ff1a879 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -19,7 +19,7 @@ class LinearColumnWeight(GenericHolisticOp): def __init__(self): # TODO - inputs_layout = Full + inputs_layout = outline.Full # TODO weight_layout = outline.SplitAxis(axis=0, chunk_num=None, overlap=0) # TODO @@ -28,7 +28,7 @@ def __init__(self): output_layout = outline.Align(weight_layout) super().__init__( - input_layout=(inputs_layout, weight_layout), + input_layout=(inputs_layout, weight_layout, bias_layout), output_layout=(output_layout,) ) @@ -54,16 +54,17 @@ class LinearColumnInputRowWeight(GenericHolisticOp): def __init__(self): # TODO - inputs_layout = None + inputs_layout = outline.SplitAxis(axis=-1, chunk_num=None, overlap=0) # TODO - weight_layout = None + align = outline.Align(inputs_layout.chunk_num) + weight_layout = outline.SplitAxis(axis=1, chunk_num=align, overlap=0) # TODO - bias_layout = None + bias_layout = outline.Broadcast(reduce=ReductionOpPool.Sum) # TODO - output_layout = None + output_layout = outline.Broadcast(reduce=ReductionOpPool.Sum) super().__init__( - input_layout=(inputs_layout, weight_layout), + input_layout=(inputs_layout, weight_layout, bias_layout), output_layout=(output_layout,) ) From cc205a81de85f8f4c32c3694c10ae3bee7cb4276 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 23 Jul 2021 02:47:39 +0000 Subject: [PATCH 0058/1892] segment should have reduction ops --- cube/tensor/community.py | 31 +------------ cube/tensor/logic/segment/outline.py | 19 +++++++- .../segment/{primitive.py => segment.py} | 45 ++++++++++++++++--- 3 files changed, 58 insertions(+), 37 deletions(-) rename cube/tensor/logic/segment/{primitive.py => segment.py} (66%) diff --git a/cube/tensor/community.py b/cube/tensor/community.py index 9a1f504c..137f8f93 100644 --- a/cube/tensor/community.py +++ b/cube/tensor/community.py @@ -1,39 +1,13 @@ import torch -__all__ = ['ReductionOpPool', 'Community'] +__all__ = ['Community'] -class _Reduction(type): - - Sum = torch.distributed.all_reduce - - # identity for replica - Replica = lambda physical_tensor, group : physical_tensor - - def register(name, udf): - """ - Reduction functions should be in function format: - - Arguments: - PhysicalTensor - Communication Group - - Return: - PhysicalTensor - """ - if hasattr(cls, name): - raise KeyError("{} is registered".format(name)) - setattr(cls, name, udf) - - -class ReductionOpPool(metaclass=_Reduction): - pass - class Community: - def __init__(self, logical_tensor, segment, reduction=None): + def __init__(self, logical_tensor, segment): """Create Community based on the logical tensor Community manages one: @@ -60,7 +34,6 @@ def __init__(self, logical_tensor, segment, reduction=None): self.phsyical_tensor = None self.materialized = False - self.reduction = reduction def spread(self, device_list): """Spread physical tensors to devices diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py index f12f3df1..7e16b243 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/segment/outline.py @@ -49,6 +49,23 @@ def __init__(self, axis, chunk_num=None, overlap=0): def __call__(self, shape): """ - Runtime segment generation given the logical tensor shape + Runtime community generation given the logical tensor shape """ pass + + +class Broadcast: + + def __init__(self, reduce_op=None): + """ + Segmentation Pattern Requirement: + + The same shape with output but (may) need additional reduction + """ + self.reduce_op = reduce_op + + def __call__(self, shape): + """ + Runtime segment generation given the logical tensor shape + """ + return shape diff --git a/cube/tensor/logic/segment/primitive.py b/cube/tensor/logic/segment/segment.py similarity index 66% rename from cube/tensor/logic/segment/primitive.py rename to cube/tensor/logic/segment/segment.py index dbbd1547..51e1609f 100644 --- a/cube/tensor/logic/segment/primitive.py +++ b/cube/tensor/logic/segment/segment.py @@ -3,6 +3,36 @@ """ +__all__ = ['ReductionOp', 'DataSegment', 'TileSegment'] + + +class _Reduction(type): + + Sum = torch.distributed.all_reduce + + # identity for replica + Replica = lambda physical_tensor, group : physical_tensor + + def register(name, udf): + """ + Reduction functions should be in function format: + + Arguments: + PhysicalTensor + Communication Group + + Return: + PhysicalTensor + """ + if hasattr(cls, name): + raise KeyError("{} is registered".format(name)) + setattr(cls, name, udf) + + +class ReductionOp(metaclass=_Reduction): + pass + + ## Basic structure for holding a segment -> cover all the cases ## class DataSegment: """ @@ -11,7 +41,7 @@ class DataSegment: The order of indices indicate the physical storage (1-D array) order """ - def __init__(self, indices_list=None): + def __init__(self, indices_list=None, reduction=None): """ Args: indices_list (list[ list[int], ]): @@ -19,6 +49,7 @@ def __init__(self, indices_list=None): """ self.indices = indices_list + self.reduction = None def convert_to_indices(self): """ @@ -44,7 +75,7 @@ class TileSegment(DataSegment): which can be represented as the start position + offset (shape) """ - def __init__(self, anchor, offset): + def __init__(self, anchor, offset, reduction=None): """ Args: anchor (list[int]): start position of the tile @@ -52,7 +83,7 @@ def __init__(self, anchor, offset): """ if len(anchor) != len(offset): raise ValueError("Require anchor length to be equal with offset length") - super().__init__() + super().__init__(reduction=reduction) self.anchor = anchor self.offset = offset @@ -68,11 +99,11 @@ def reorder(self): ## Primitive sets for translation ## -def create_from_indices(indices): - return DataSegment(indices) +def create_from_indices(indices, reduction): + return DataSegment(indices, reduction) -def create_from_tiles(anchor, offset): +def create_from_tiles(anchor, offset, reduction): # segments = list() # dims = len(offset) # for dim_id in range(dims): @@ -81,4 +112,4 @@ def create_from_tiles(anchor, offset): # segments.append(segment) # segment = merge_segments(segments) # return segment - return TileSegment(anchor, offset) \ No newline at end of file + return TileSegment(anchor, offset, reduction) \ No newline at end of file From a581b94a7ba2c8faef13d62dbefbb698df90079f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 23 Jul 2021 02:48:17 +0000 Subject: [PATCH 0059/1892] killall cmd --- scripts/env-setup.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index 080b6fa9..09fa35da 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -11,6 +11,7 @@ sudo git config --global user.email "v-zhiql@microsoft.com" sudo chmod -R a+w /opt/conda sudo apt-get install tmux -y +sudo apt-get install psmisc -y wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.tmux.conf -O ~/.tmux.conf wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.vimrc -O ~/.vimrc From 1f06f23a1b438f388c571e76fc8175bcf5f13ccb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 23 Jul 2021 03:15:41 +0000 Subject: [PATCH 0060/1892] logic operator enable with simple policy --- cube/operator/logic/generics.py | 23 +++++++++++++++++++++-- cube/operator/logic/linear.py | 4 +++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/cube/operator/logic/generics.py b/cube/operator/logic/generics.py index 3bdc7005..a3114aa3 100644 --- a/cube/operator/logic/generics.py +++ b/cube/operator/logic/generics.py @@ -36,7 +36,8 @@ def get_op(self, args, **kwargs): An hybrid-operator function which may composite by nested holistic operators """ - pass + # TODO: hybrid parallelism generation + return self.holist_ops[0] @@ -46,12 +47,30 @@ def __init__(self): # candidate holistic operator self.factory = HolisticOpFactory() + self.policy_fn = None + + def register_policy(self, policy_fn): + """ + Register a policy function to customize how composite + holistic op generated during runtime. + + The `policy_fn` takes self.factory as input and returns a composite + holistic operator (callable) + """ + if not callable(policy_fn): + raise TypeError("Expected a callable function") + self.policy_fn = policy_fn def __call__(self, args, **kwargs): """ Policy here to determine which holistic operator(s) are called """ - composite_op = self.factory.get_op(args, kwargs) + # use default policy + if self.policy_fn is None: + composite_op = self.factory.get_op(args, kwargs) + # use user-customized policy + else: + composite_op = self.policy_fn(self.factory) # run operator with the strategy plan outputs = composite_op(args, kwargs) return outputs \ No newline at end of file diff --git a/cube/operator/logic/linear.py b/cube/operator/logic/linear.py index c5f40bf3..eac8e1b1 100644 --- a/cube/operator/logic/linear.py +++ b/cube/operator/logic/linear.py @@ -1,8 +1,10 @@ from cube.operator.logic.generics import generics from cube.operator.holist.linear import kHolistLinearSets + __all__ = ['linear'] + def Linear(generics.GenericLogicalOp): def __init__(self): @@ -13,4 +15,4 @@ def __init__(self): self.factory.register(holist_op) # initialize op -linear = Linear() \ No newline at end of file +linear = Linear() From 5137710c2fdb7d6445988d69ff6c89502ae020a8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 23 Jul 2021 07:53:40 +0000 Subject: [PATCH 0061/1892] logical tensor impl --- cube/tensor/logic/tensor.py | 97 ++++++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 18 deletions(-) diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 3c185c0c..bec77387 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -1,27 +1,88 @@ -from cube.tensor.community import Community, ReductionOpPool +from cube.tensor.community import Community +from cube.tensor.logic.segment.segment import DataSegment class LogicalTensor: + """ + The logical tensor + """ - def __init__(self, ): + def __init__(self, shape, init_data=True): - self.communities = list() + self.shape = shape + # segment -> community + self.communities = dict() + self.segments = list() + self.data = None + if init_data: + self.data = torch.randn(shape).detach() + + def get_physical_tensor(self, segment): + """ + Get physical tensor from the community. + + Args: + idx: index for community + + Returns: + torch.Tensor or None + """ + community = self.communities[idx] + return community.get_physical_tensor() + + def get_community(self, segment): + """ + Get Community based on the segment + """ + if not isinstance(segment, DataSegment): + raise ValueError("Expected (derived) DataSegment to chooese Community") + if segment not in self.communities: + raise KeyError("The segment doesn't found in current tensor") + return self.communities[segment] + + def __getitem__(self, key): + """ + key: + if key is DataSegment, return community + ##TODO: DOUBLE CHECK + if key is slice, return new logical tensor + """ + if isinstance(key, DataSegment): + return self.get_community(key) + else: + ## TODO: should return logical tensor / views + return self.data[key] def create_community(self, segment): - """Create a community by given the segment""" - self.communities.append(segment) + """ + Create a community by given the segment + """ + if segment in self.communities: + raise KeyError("The segment already exists") + self.communities[segment] = Community(segment) + self.segments.append(segment) + + def set_community(self, community): + """ + Set a community - def fuse(self, communities=None, reduction=ReductionOpPool.Replica): - """Fuse multiple communities into one + Warning: if there is a segment in this tensor that matches + with the given community's segment, the original community + will be overrided + """ + if not isinstance(community): + raise TypeError("Expected a community") + segment = community.segment + if segment not in self.communities: + self.segments.append(segment) + self.communities[segment] = community - Synchronization will done for each community to retrieve the right - result. - - Args: - communities (list[Community]): - The particular comunities to merge. - If not specified (None), fuse all the communities. - reduction: - Reduction operator for the new fused community. - """" - pass + def remove_community(self, segment): + """ + Remove a community by given the segment + """ + #TODO: check whether a sync-back is needed + if segment not in self.communities: + raise KeyError("The segment doesn't exist") + del self.communities[segment] + self.segments.remove(segment) From ae54c003dd0dbdc237561613992fb55e7d08c221 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 23 Jul 2021 08:47:10 +0000 Subject: [PATCH 0062/1892] setup interface for logical tensor --- cube/tensor/community.py | 59 +++++++++++++++------------- cube/tensor/logic/segment/segment.py | 29 ++++++++++---- cube/tensor/physic/tensor.py | 9 +++++ 3 files changed, 62 insertions(+), 35 deletions(-) diff --git a/cube/tensor/community.py b/cube/tensor/community.py index 137f8f93..7f402177 100644 --- a/cube/tensor/community.py +++ b/cube/tensor/community.py @@ -1,13 +1,12 @@ import torch - +from cube.device.physic.group import DeviceGroup __all__ = ['Community'] - class Community: - def __init__(self, logical_tensor, segment): + def __init__(self, segment): """Create Community based on the logical tensor Community manages one: @@ -16,27 +15,22 @@ def __init__(self, logical_tensor, segment): 2). Materialized Physical Tensors Attribute: - parent (LogicalTensor): - Logical Tensor the Community belongs to segment (DataSegment): indices of logical_tensor for this community - reduction (Callable or None): - Reduction function for retrieve back physical tensors + """ # connection to logical tensor - self.parent = logical_tensor - # DataSegment to indicate both element set and data format mapping self.segment = segment # connection to physical tensor (the PyTorch Tensor) self.phsyical_tensor = None + self.group = list() self.materialized = False - - def spread(self, device_list): - """Spread physical tensors to devices + def deploy(self, ranks, logic_tensor, value_map_fn=None): + """deploy (materialize) to physical tensors Materialize physical tensors for this community and spread out based on the given device list. @@ -45,25 +39,36 @@ def spread(self, device_list): to spread. Argument: - device_list (list[int]): device id list + ranks (list[int]): device id list + value_map_fn (callable): + takes the tensor, rank, world_size, + return a new tensor + """ - Return: - PhysicalTensor(s) or None: - - For SPMD programming model: - if current device is in the `device_list`, - than return the corresponding physical tensor, - else None + rank = DeviceGroup().rank + self.group = DeviceGroup().get_group(ranks) + if rank not in ranks: + self.physical_tensor = None + else: + if logic_tensor.data is None: + # TODO: check overlap + self.physical_tensor = torch.randn(self.segment.shape, device='cuda') + else: + # select from cpu view + self.physical_tensor = torch.empty(self.segment.shape, devic='cuda') + self.physical_tensor.copy_(logic_tensor[self.segment.get_indices()]) + if value_map_fn is not None: + self.physical_tensor.data = self.value_map_fn(physical_tensor) - For Global View programming model: - return list[PhysicalTensor] with the same - order of `device_list`. + def sync(self): """ - pass + Synchrnoize the spread physical tensors by reduction operation - def sync(self): - """Synchrnoize the spread physical tensors by reduction operation""" - pass + This should be a out-placement device for differentiable communication ops. + + Each device should call this, including no-physical-tensor devices + """ + self.physical_tensor = self.segment.reduction(self.physical_tensor, self.group) def get_physical_tensor(self): """Get physical tensor if materialized diff --git a/cube/tensor/logic/segment/segment.py b/cube/tensor/logic/segment/segment.py index 51e1609f..9f7156b0 100644 --- a/cube/tensor/logic/segment/segment.py +++ b/cube/tensor/logic/segment/segment.py @@ -41,21 +41,31 @@ class DataSegment: The order of indices indicate the physical storage (1-D array) order """ - def __init__(self, indices_list=None, reduction=None): + def __init__(self, indices_list=None, reduction=None, shape=None): """ Args: indices_list (list[ list[int], ]): List of index + reduction (ReductionOp): + How to reduction to the logical value + shape: + shape on the indices list """ self.indices = indices_list + if shape is None: + if indices_list is None: + raise RuntimeError("Provide shape if indices_list is empty") + self.shape = (len(indices_list[0]),) + else: + self.shape = shape self.reduction = None - def convert_to_indices(self): + def get_indices(self): """ Convert to index list """ - pass + return self.indices def reorder(self, new_orders): """ @@ -64,8 +74,8 @@ def reorder(self, new_orders): Note this can be only called before materialize physical tensors, or called from underlying operation that will change physical storage format """ - #TODO: validation check - self.indices = new_orders + for dim in range(len(self.indices)): + self.indices[dim] = self.indices[dim][new_orders] ## Higher structure to cover the most cases ## @@ -83,15 +93,18 @@ def __init__(self, anchor, offset, reduction=None): """ if len(anchor) != len(offset): raise ValueError("Require anchor length to be equal with offset length") - super().__init__(reduction=reduction) + super().__init__(reduction=reduction, shape=tuple(offset)) self.anchor = anchor self.offset = offset - def convert_to_indices(self): + def get_indices(self): """ Convert anchor and offset to index list """ - pass + indices = list() + for start, ofst in zip(self.anchor, self.offset): + indices.append(slice(start, start + ofst)) + return tuple(indices) def reorder(self): pass diff --git a/cube/tensor/physic/tensor.py b/cube/tensor/physic/tensor.py index 4a04949e..dad58a9e 100644 --- a/cube/tensor/physic/tensor.py +++ b/cube/tensor/physic/tensor.py @@ -62,3 +62,12 @@ def move_(self, device): if not isinstance(device, torch.device): raise TypeError('Expected torch.device') self.data = self.detach().to(device) + + def move_grad_(self, device): + """ + inplacement device move on tensor grad + """ + if not isinstance(device, torch.device): + raise TypeError('Expected torch.device') + if self.grad is not None: + self.grad.data = self.grad.detach().to(device) \ No newline at end of file From 38fb95ec13577cd2092d1b465a35aec5d89f30e9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 23 Jul 2021 09:26:36 +0000 Subject: [PATCH 0063/1892] add outliner --- cube/tensor/logic/segment/outline.py | 46 ++++++++++++++-------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py index 7e16b243..8c7d2a50 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/segment/outline.py @@ -10,20 +10,25 @@ to the real segmentation on given logical tensor shape. """ +from cube.tensor.logic.segment.segment import TileSegment, ReductionOp + # interface to setup restrictions on the segmentation + class Full: - def __init__(self): - pass + def __init__(self, reduction=None): + self.reduction=None def __call__(self, shape): - pass + segment = TileSegment([0] * len(shape), shape, self.reduction) + return [segment] + class SplitAxis: - def __init__(self, axis, chunk_num=None, overlap=0): + def __init__(self, axis, chunk_num=None, overlap=0, reduction=None, uniform=True): """ Segmentation Pattern Requirement (parameters): @@ -45,27 +50,22 @@ def __init__(self, axis, chunk_num=None, overlap=0): """ self.axis = axis self.chunk_num = chunk_num + self.uniform = True self.overlap = overlap - - def __call__(self, shape): - """ - Runtime community generation given the logical tensor shape - """ - pass - - -class Broadcast: - - def __init__(self, reduce_op=None): - """ - Segmentation Pattern Requirement: - - The same shape with output but (may) need additional reduction - """ - self.reduce_op = reduce_op + self.reduction = reduction def __call__(self, shape): """ Runtime segment generation given the logical tensor shape - """ - return shape + + This is the policy that how to do the translation. + """ + segments = list() + shape[axis] = shape[axis] // self.chunk_num + anchor = [0] * self.chunk_num + for _ in range(self.chunk_num): + segment = TileSegment( + list(anchor), list(shape), reduction=ReductionOp.Replica) + anchor[axis] += shape[axis] + segments.append(segment) + return segments From aae82be8a6a5843e75e3699805bc04dfc781ca47 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 25 Jul 2021 13:51:13 +0000 Subject: [PATCH 0064/1892] fix tensor segment bugs --- cube/tensor/community.py | 5 +- cube/tensor/logic/segment/__init__.py | 5 ++ cube/tensor/logic/segment/segment.py | 54 ++++++++++----- cube/tensor/logic/tensor.py | 1 + tests/tensor/test_segment.py | 94 +++++++++++++++++++++++++++ 5 files changed, 140 insertions(+), 19 deletions(-) create mode 100644 tests/tensor/test_segment.py diff --git a/cube/tensor/community.py b/cube/tensor/community.py index 7f402177..9216a44c 100644 --- a/cube/tensor/community.py +++ b/cube/tensor/community.py @@ -1,6 +1,7 @@ import torch from cube.device.physic.group import DeviceGroup + __all__ = ['Community'] @@ -17,8 +18,6 @@ def __init__(self, segment): Attribute: segment (DataSegment): indices of logical_tensor for this community - - """ # connection to logical tensor # DataSegment to indicate both element set and data format mapping @@ -75,7 +74,7 @@ def get_physical_tensor(self): Returns: PhysicalTensor (if materialized) - """" + """ if self.materialized: return self.physical_tensor else: diff --git a/cube/tensor/logic/segment/__init__.py b/cube/tensor/logic/segment/__init__.py index e69de29b..ad6e085a 100644 --- a/cube/tensor/logic/segment/__init__.py +++ b/cube/tensor/logic/segment/__init__.py @@ -0,0 +1,5 @@ +from cube.tensor.logic.segment.segment import ReductionOp +from cube.tensor.logic.segment.segment import DataSegment, TileSegment +from cube.tensor.logic.segment.segment import create_from_indices, create_from_tiles + +from cube.tensor.logic.segment.outline import Full, SplitAxis \ No newline at end of file diff --git a/cube/tensor/logic/segment/segment.py b/cube/tensor/logic/segment/segment.py index 9f7156b0..4d758524 100644 --- a/cube/tensor/logic/segment/segment.py +++ b/cube/tensor/logic/segment/segment.py @@ -2,8 +2,7 @@ This is the runtime primitive sets to setup community for a logical tensor. """ - -__all__ = ['ReductionOp', 'DataSegment', 'TileSegment'] +import torch class _Reduction(type): @@ -13,7 +12,7 @@ class _Reduction(type): # identity for replica Replica = lambda physical_tensor, group : physical_tensor - def register(name, udf): + def register(cls, name, udf): """ Reduction functions should be in function format: @@ -41,7 +40,7 @@ class DataSegment: The order of indices indicate the physical storage (1-D array) order """ - def __init__(self, indices_list=None, reduction=None, shape=None): + def __init__(self, indices_list=None, shape=None, reduction=None): """ Args: indices_list (list[ list[int], ]): @@ -58,14 +57,15 @@ def __init__(self, indices_list=None, reduction=None, shape=None): raise RuntimeError("Provide shape if indices_list is empty") self.shape = (len(indices_list[0]),) else: + # TODO: check shape self.shape = shape - self.reduction = None + self.reduction = reduction def get_indices(self): """ Convert to index list """ - return self.indices + return tuple(self.indices) def reorder(self, new_orders): """ @@ -73,9 +73,13 @@ def reorder(self, new_orders): Note this can be only called before materialize physical tensors, or called from underlying operation that will change physical storage format + + Args: + new_orders (iteratable): order of each index """ + #TODO: check if materialized for dim in range(len(self.indices)): - self.indices[dim] = self.indices[dim][new_orders] + self.indices[dim] = [self.indices[dim][idx] for idx in new_orders] ## Higher structure to cover the most cases ## @@ -85,24 +89,23 @@ class TileSegment(DataSegment): which can be represented as the start position + offset (shape) """ - def __init__(self, anchor, offset, reduction=None): + def __init__(self, anchor, shape, reduction=None): """ Args: anchor (list[int]): start position of the tile offset (list[int]): offset (shape) of the tile """ - if len(anchor) != len(offset): + if len(anchor) != len(shape): raise ValueError("Require anchor length to be equal with offset length") - super().__init__(reduction=reduction, shape=tuple(offset)) + super().__init__(shape=shape, reduction=reduction) self.anchor = anchor - self.offset = offset def get_indices(self): """ Convert anchor and offset to index list """ indices = list() - for start, ofst in zip(self.anchor, self.offset): + for start, ofst in zip(self.anchor, self.shape): indices.append(slice(start, start + ofst)) return tuple(indices) @@ -112,11 +115,30 @@ def reorder(self): ## Primitive sets for translation ## -def create_from_indices(indices, reduction): - return DataSegment(indices, reduction) +def create_from_indices(indices, shape, reduction): + """ + Create a data segment from indices, and format in shape. + The indices list will determine how data will be organized in + storage. + + Args: + indices (list[list[int]]): + Represent indices from logical tensor shape + len(indices) is the dimension, + e.g., index [3,4,5] and [2,7,9] is represented as + [[3,2], [4,7], [5,9]] + shape (tuple or list): + the segment shape + reduction (ReductionOp): + How to generate correct logical results from reduction op. + + Returns: + DataSegment instance + """ + return DataSegment(indices, shape, reduction) -def create_from_tiles(anchor, offset, reduction): +def create_from_tiles(anchor, shape, reduction): # segments = list() # dims = len(offset) # for dim_id in range(dims): @@ -125,4 +147,4 @@ def create_from_tiles(anchor, offset, reduction): # segments.append(segment) # segment = merge_segments(segments) # return segment - return TileSegment(anchor, offset, reduction) \ No newline at end of file + return TileSegment(anchor, shape, reduction) \ No newline at end of file diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index bec77387..0738b2fc 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -15,6 +15,7 @@ def __init__(self, shape, init_data=True): self.segments = list() self.data = None if init_data: + import torch self.data = torch.randn(shape).detach() def get_physical_tensor(self, segment): diff --git a/tests/tensor/test_segment.py b/tests/tensor/test_segment.py new file mode 100644 index 00000000..16050cd8 --- /dev/null +++ b/tests/tensor/test_segment.py @@ -0,0 +1,94 @@ +import cube.tensor.logic.segment as segment +import torch + + +def test_reduction_op_register(): + + def reduce_fn(physical_tensor, group): + return physical_tensor + segment.ReductionOp.register("ReduceSum", reduce_fn) + + # segment.ReductionOp.register("Replica", reduce_fn) + + tensor = torch.randn((3,4)) + out = segment.ReductionOp.ReduceSum(tensor, None) + assert out is tensor + + +## TODO: test all the provided reduction op +def test_reduction_op_replica(): + #TODO: check correctness + assert callable(segment.ReductionOp.Replica) + + +def test_data_segment_init(): + + tensor = torch.randn((10,10,10)) + indices = [[5,3,2,4], + [1,2,7,4], + [3,4,5,4]] + seg = segment.DataSegment( + indices, shape=(4,1), reduction=segment.ReductionOp.Replica) + assert seg.indices == indices + assert seg.shape == (4,1) + assert seg.reduction == segment.ReductionOp.Replica + + +def test_data_segment_get_indices(): + + tensor = torch.randn((10,10,10)) + indices = [[5,3,2,4], + [1,2,7,4], + [3,4,5,4]] + seg = segment.DataSegment( + indices, shape=(4,1), reduction=segment.ReductionOp.Replica) + sub_tensor = tensor[seg.get_indices()] + assert sub_tensor.size() == torch.Size([4]) + + +def test_data_segment_reorder(): + + tensor = torch.randn((10,10,10)) + indices = [[5,3,2,4], + [1,2,7,4], + [3,4,5,4]] + seg = segment.DataSegment( + indices, shape=(4,1), reduction=segment.ReductionOp.Replica) + sub_tensor = tensor[seg.get_indices()] + + seg.reorder([2,3,1,0]) + ref_tensor = sub_tensor[([2,3,1,0])] + check_tensor = tensor[seg.get_indices()] + assert torch.all(torch.eq(ref_tensor, check_tensor)) + + +def test_tile_segment_init(): + + tensor = torch.randn((10,10,10)) + seg = segment.TileSegment( + anchor=(2,3,1), shape=(4,4,4), reduction=segment.ReductionOp.Replica) + assert seg.shape == (4,4,4) + assert seg.anchor == (2,3,1) + assert seg.reduction == segment.ReductionOp.Replica + + +def test_tile_segment_get_indices(): + + tensor = torch.randn((10,10,10)) + seg = segment.TileSegment( + anchor=(2,3,1), shape=(4,4,4), reduction=segment.ReductionOp.Replica) + ref_tensor = tensor[(slice(2,2+4), slice(3,3+4), slice(1,1+4))] + sub_tensor = tensor[seg.get_indices()] + assert sub_tensor.size() == torch.Size([4,4,4]) + assert torch.all(torch.eq(ref_tensor, sub_tensor)) + + +if __name__ == '__main__': + + test_reduction_op_register() + test_reduction_op_replica() + test_data_segment_init() + test_data_segment_get_indices() + test_data_segment_reorder() + test_tile_segment_init() + test_tile_segment_get_indices() \ No newline at end of file From 35be2bc7b6f056f803d5d407d689eccfc63cce2c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 25 Jul 2021 14:32:06 +0000 Subject: [PATCH 0065/1892] fix outline bugs --- cube/tensor/logic/segment/outline.py | 15 ++++--- tests/tensor/test_outline.py | 67 ++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 tests/tensor/test_outline.py diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py index 8c7d2a50..2f7e7390 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/segment/outline.py @@ -19,10 +19,10 @@ class Full: def __init__(self, reduction=None): - self.reduction=None + self.reduction = reduction def __call__(self, shape): - segment = TileSegment([0] * len(shape), shape, self.reduction) + segment = TileSegment([0] * len(shape), list(shape), self.reduction) return [segment] @@ -61,11 +61,12 @@ def __call__(self, shape): This is the policy that how to do the translation. """ segments = list() - shape[axis] = shape[axis] // self.chunk_num - anchor = [0] * self.chunk_num - for _ in range(self.chunk_num): + shape = list(shape) + shape[self.axis] = shape[self.axis] // self.chunk_num + anchor = [0] * len(shape) + for cid in range(self.chunk_num): segment = TileSegment( - list(anchor), list(shape), reduction=ReductionOp.Replica) - anchor[axis] += shape[axis] + list(anchor), list(shape), reduction=self.reduction[cid]) + anchor[self.axis] += shape[self.axis] segments.append(segment) return segments diff --git a/tests/tensor/test_outline.py b/tests/tensor/test_outline.py new file mode 100644 index 00000000..fde43df9 --- /dev/null +++ b/tests/tensor/test_outline.py @@ -0,0 +1,67 @@ +import cube.tensor.logic.segment.outline as outline +import cube.tensor.logic.segment as segment + +import torch + + +def test_full(): + + shape = (10,10,10) + tensor = torch.randn(shape) + full_dsp = outline.Full(reduction=segment.ReductionOp.Replica) + assert full_dsp.reduction == segment.ReductionOp.Replica + + segments = full_dsp(tensor.shape) + assert len(segments) == 1 + tile_seg = segments[0] + assert type(tile_seg) == segment.TileSegment + + sub_tensor = tensor[tile_seg.get_indices()] + assert torch.all(torch.eq(sub_tensor, tensor)) + + +def test_split_axis(): + + axis = 1 + num = 8 + + shape = (4,16,4) + tensor = torch.randn(shape) + split_dsp = outline.SplitAxis( + axis=axis, chunk_num=None, overlap=0, + reduction=segment.ReductionOp.Replica, uniform=True) + assert split_dsp.axis == 1 + assert split_dsp.chunk_num is None + assert split_dsp.uniform is True + assert split_dsp.overlap == 0 + assert split_dsp.reduction == segment.ReductionOp.Replica + + ## Policy here to decide how to split + if split_dsp.chunk_num is None: + split_dsp.chunk_num = num + split_dsp.reduction = [segment.ReductionOp.Replica] * num + ### + + + segs = split_dsp(tensor.shape) + assert len(segs) == num + assert torch.all(torch.Tensor([type(seg) == segment.TileSegment for seg in segs])) + + ofst = 0 + expected_shape = list(shape) + expected_shape[axis] = shape[axis] // num + for cid in range(num): + seg = segs[cid] + sub_tensor = tensor[seg.get_indices()] + ref_tensor = tensor[:,ofst:ofst+expected_shape[axis],:] + # print('sub tensor {}: {}'.format(sub_tensor.size(), sub_tensor)) + # print('ref tensor {}: {}'.format(ref_tensor.size(), ref_tensor)) + assert sub_tensor.size() == torch.Size(expected_shape) + assert torch.all(torch.eq(sub_tensor, ref_tensor)) + ofst += expected_shape[axis] + + +if __name__ == '__main__': + + test_full() + test_split_axis() \ No newline at end of file From 4de8c5a297a6e97a9717ebd1d4985ff2420f699a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 25 Jul 2021 15:09:15 +0000 Subject: [PATCH 0066/1892] test on device group --- tests/test_group.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/tests/test_group.py b/tests/test_group.py index f10b75b1..10ac0526 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -11,8 +11,27 @@ tests/test_group.py """ -from combo.physical.device.group import DeviceGroup +from cube.device.physic.group import DeviceGroup +import torch + + +def test_sub_group(): + + group = DeviceGroup() + myrank = group.rank + sub_group_1 = group.get_group([0,2]) + if myrank in [0,2]: + assert torch.distributed.get_rank(sub_group_1) in [0,1] + else: + assert torch.distributed.get_rank(sub_group_1) == -1 + + sub_group_2 = group.get_group([1,3]) + if myrank in [1,3]: + assert torch.distributed.get_rank(sub_group_2) in [0,1] + else: + assert torch.distributed.get_rank(sub_group_2) == -1 + # print(group) if __name__ == '__main__': @@ -20,7 +39,4 @@ # init distributed group = DeviceGroup() - sub_group_1 = group.get_group([0,2]) - sub_group_2 = group.get_group([1,3]) - - print(group) \ No newline at end of file + test_sub_group() From 486b249ead4df3eb53f9519cfdecac88109d7d1b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 25 Jul 2021 16:29:47 +0000 Subject: [PATCH 0067/1892] waiting to fix: segment reduction op will be passed with self --- cube/tensor/community.py | 18 +++-- cube/tensor/logic/segment/segment.py | 3 +- tests/tensor/test_community.py | 99 ++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 7 deletions(-) create mode 100644 tests/tensor/test_community.py diff --git a/cube/tensor/community.py b/cube/tensor/community.py index 9216a44c..68b42d2b 100644 --- a/cube/tensor/community.py +++ b/cube/tensor/community.py @@ -24,7 +24,7 @@ def __init__(self, segment): self.segment = segment # connection to physical tensor (the PyTorch Tensor) - self.phsyical_tensor = None + self.physical_tensor = None self.group = list() self.materialized = False @@ -51,13 +51,14 @@ def deploy(self, ranks, logic_tensor, value_map_fn=None): else: if logic_tensor.data is None: # TODO: check overlap - self.physical_tensor = torch.randn(self.segment.shape, device='cuda') + self.physical_tensor = torch.randn(tuple(self.segment.shape), device='cuda') else: # select from cpu view - self.physical_tensor = torch.empty(self.segment.shape, devic='cuda') - self.physical_tensor.copy_(logic_tensor[self.segment.get_indices()]) + self.physical_tensor = torch.empty(tuple(self.segment.shape), device='cuda') + self.physical_tensor.copy_(logic_tensor.data[self.segment.get_indices()]) if value_map_fn is not None: - self.physical_tensor.data = self.value_map_fn(physical_tensor) + self.physical_tensor.data = value_map_fn(self.physical_tensor) + self.materialized = True def sync(self): """ @@ -67,7 +68,12 @@ def sync(self): Each device should call this, including no-physical-tensor devices """ - self.physical_tensor = self.segment.reduction(self.physical_tensor, self.group) + if self.materialized: + if self.physical_tensor is not None: + # self.segment.reduction.__func__(self.physical_tensor, group=self.group) + torch.distributed.all_reduce(self.physical_tensor, group=self.group) + else: + raise RuntimeError("The Community has not been materialized to physical tensors") def get_physical_tensor(self): """Get physical tensor if materialized diff --git a/cube/tensor/logic/segment/segment.py b/cube/tensor/logic/segment/segment.py index 4d758524..0c6c8195 100644 --- a/cube/tensor/logic/segment/segment.py +++ b/cube/tensor/logic/segment/segment.py @@ -5,6 +5,7 @@ import torch +# TODO: reduction op should be in torch autograd function class _Reduction(type): Sum = torch.distributed.all_reduce @@ -59,7 +60,7 @@ def __init__(self, indices_list=None, shape=None, reduction=None): else: # TODO: check shape self.shape = shape - self.reduction = reduction + self.reduction = staticmethod(reduction) def get_indices(self): """ diff --git a/tests/tensor/test_community.py b/tests/tensor/test_community.py new file mode 100644 index 00000000..1b4cc74b --- /dev/null +++ b/tests/tensor/test_community.py @@ -0,0 +1,99 @@ +""" +cmd for running the test + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + tests/tensor/test_community.py +""" + +from cube.tensor.community import Community +import cube.tensor.logic.segment as segment +from cube.device.physic.group import DeviceGroup + +import torch +import os +torch.manual_seed(121) + + +def test_community_init(): + + tensor = torch.randn((10,10,10)) + seg = segment.TileSegment( + anchor=(2,3,1), shape=(4,4,4), reduction=segment.ReductionOp.Replica) + community = Community(seg) + + assert community.segment == seg + assert community.physical_tensor is None + assert len(community.group) == 0 + assert community.materialized is False + + +def test_community_deploy(): + + tensor = torch.randn((10,10,10)) + seg = segment.TileSegment( + anchor=(2,3,1), shape=(4,4,4), + reduction=segment.ReductionOp.Replica) + community = Community(seg) + + # policy for scaling out + # using torch.Tensor to test + ranks = [0,2] + community.deploy(ranks, tensor, None) + + # check + myrank = DeviceGroup().rank + if myrank not in ranks: + assert community.physical_tensor is None + else: + sub_tensor = community.physical_tensor + assert torch.is_tensor(sub_tensor) + assert sub_tensor.size() == torch.Size([4,4,4]) + assert sub_tensor.device == torch.device('cuda:{}'.format(myrank)) + assert torch.all(torch.eq(sub_tensor.cpu(), tensor[seg.get_indices()])) + assert torch.distributed.get_world_size(community.group) == 2 + + +def test_community_sync(): + tensor = torch.randn((10,10,10)) + seg = segment.TileSegment( + anchor=(2,3,1), shape=(4,4,4), + reduction=segment.ReductionOp.Sum) + community = Community(seg) + + # deploy with value modification + ranks = [0,2] + community.deploy(ranks, tensor, + value_map_fn=lambda tensor: tensor / 2) + + # check + sub_tensor = community.get_physical_tensor() + ref_tensor = tensor[seg.get_indices()].cuda() + myrank = DeviceGroup().rank + if myrank in ranks: + assert torch.all(torch.eq(sub_tensor, ref_tensor / 2)) + + # sync to get logical value + community.sync() + sub_tensor = community.get_physical_tensor() + if myrank not in ranks: + assert sub_tensor is None + else: + # print('ref: {}'.format(ref_tensor)) + assert torch.allclose(sub_tensor, ref_tensor) is True + + + +if __name__ == '__main__': + + group = DeviceGroup() + torch.distributed.barrier() + + test_community_init() + test_community_deploy() + test_community_sync() \ No newline at end of file From 0a21b38f82d5ac88238801ca27f899e2318f2d78 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 01:27:44 +0000 Subject: [PATCH 0068/1892] workaround: use list to wrap func --- cube/operator/physic/__init__.py | 1 - cube/operator/physic/comm/__init__.py | 2 +- cube/operator/physic/comm/boundary.py | 35 ++++++++++++--------------- cube/tensor/community.py | 5 ++-- cube/tensor/logic/segment/segment.py | 12 +++++---- tests/tensor/test_segment.py | 4 +-- 6 files changed, 29 insertions(+), 30 deletions(-) diff --git a/cube/operator/physic/__init__.py b/cube/operator/physic/__init__.py index cb37c2a2..e69de29b 100644 --- a/cube/operator/physic/__init__.py +++ b/cube/operator/physic/__init__.py @@ -1 +0,0 @@ -from cube.physical.operator.linear import linear_op \ No newline at end of file diff --git a/cube/operator/physic/comm/__init__.py b/cube/operator/physic/comm/__init__.py index 962cde7b..3d140a8e 100644 --- a/cube/operator/physic/comm/__init__.py +++ b/cube/operator/physic/comm/__init__.py @@ -1 +1 @@ -from cube.physical.operator.comm.boundary import * \ No newline at end of file +from cube.operator.physic.comm.boundary import * \ No newline at end of file diff --git a/cube/operator/physic/comm/boundary.py b/cube/operator/physic/comm/boundary.py index 54960a66..ca27a516 100644 --- a/cube/operator/physic/comm/boundary.py +++ b/cube/operator/physic/comm/boundary.py @@ -5,10 +5,10 @@ import torch -from cube.physical.device.group import DeviceGroup +from cube.device.physic.group import DeviceGroup -__all__ = ['parallel_in', 'gather_out', 'scatter_in', 'reduce_out'] +__all__ = ['replicate', 'gather_out', 'scatter_in', 'reduce_sum'] def _reduce(input_, group): @@ -57,9 +57,8 @@ class _ParallelIn(torch.autograd.Function): """Pass the input to the model parallel region.""" @staticmethod - def forward(ctx, input_, ranks): + def forward(ctx, input_, group): # record group - group = DeviceGroup().get_group(ranks) ctx.constants = group # identitfy forward return input_ @@ -68,16 +67,15 @@ def forward(ctx, input_, ranks): def backward(ctx, grad_output): # allreduce group = ctx.constants - return _reduce(grad_output, group), None + return torch.distributed.all_reduce(grad_output, group=group), None class _GatherOut(torch.autograd.Function): """Split the input and keep only the corresponding chuck to the rank.""" @staticmethod - def forward(ctx, input_, dim, ranks): + def forward(ctx, input_, dim, group): # record group - group = DeviceGroup().get_group(ranks) ctx.constants = (group, dim) # allgather return _gather(input_, dim, group) @@ -94,8 +92,7 @@ class _ScatterIn(torch.autograd.Function): """Split the input and keep only the corresponding chuck to the rank.""" @staticmethod - def forward(ctx, input_, dim, ranks): - group = DeviceGroup().get_group(ranks) + def forward(ctx, input_, dim, group): world_size = torch.distributed.get_world_size(group) rank = torch.distributed.get_rank(group) ctx.constants = (group, dim) @@ -111,8 +108,7 @@ class _ReduceOut(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @staticmethod - def forward(ctx, input_, ranks): - group = DeviceGroup().get_group(ranks) + def forward(ctx, input_, group): return _reduce(input_, group) @staticmethod @@ -120,17 +116,18 @@ def backward(ctx, grad_output): return grad_output, None -def parallel_in(input_, ranks): - return _ParallelIn.apply(input_, ranks) +def replicate(input_, group): + return _ParallelIn.apply(input_, group) -def gather_out(input_, dim, ranks): - return _GatherOut.apply(input_, dim, ranks) +def gather_out(input_, dim, group): + return _GatherOut.apply(input_, dim, group) -def scatter_in(input_, dim, ranks): - return _ScatterIn.apply(input_, dim, ranks) +def scatter_in(input_, dim, group): + return _ScatterIn.apply(input_, dim, group) -def reduce_out(input_, ranks): - return _ReduceOut.apply(input_, ranks) +def reduce_sum(input_, group): + return _ReduceOut.apply(input_, group) + diff --git a/cube/tensor/community.py b/cube/tensor/community.py index 68b42d2b..2f94d36c 100644 --- a/cube/tensor/community.py +++ b/cube/tensor/community.py @@ -22,6 +22,7 @@ def __init__(self, segment): # connection to logical tensor # DataSegment to indicate both element set and data format mapping self.segment = segment + self.reduction = segment.reduction # connection to physical tensor (the PyTorch Tensor) self.physical_tensor = None @@ -70,8 +71,8 @@ def sync(self): """ if self.materialized: if self.physical_tensor is not None: - # self.segment.reduction.__func__(self.physical_tensor, group=self.group) - torch.distributed.all_reduce(self.physical_tensor, group=self.group) + #TODO: elegant impl on calling reduction op + self.reduction[0](self.physical_tensor, group=self.group) else: raise RuntimeError("The Community has not been materialized to physical tensors") diff --git a/cube/tensor/logic/segment/segment.py b/cube/tensor/logic/segment/segment.py index 0c6c8195..7aceb6d7 100644 --- a/cube/tensor/logic/segment/segment.py +++ b/cube/tensor/logic/segment/segment.py @@ -2,16 +2,18 @@ This is the runtime primitive sets to setup community for a logical tensor. """ +from cube.operator.physic.comm import replicate, reduce_sum import torch # TODO: reduction op should be in torch autograd function class _Reduction(type): - Sum = torch.distributed.all_reduce + # forward: all_reduce, backward: identity + Sum = [reduce_sum] - # identity for replica - Replica = lambda physical_tensor, group : physical_tensor + # forward: identity, backward: all_reduce + Replica = [replicate] def register(cls, name, udf): """ @@ -26,7 +28,7 @@ def register(cls, name, udf): """ if hasattr(cls, name): raise KeyError("{} is registered".format(name)) - setattr(cls, name, udf) + setattr(cls, name, [udf]) class ReductionOp(metaclass=_Reduction): @@ -60,7 +62,7 @@ def __init__(self, indices_list=None, shape=None, reduction=None): else: # TODO: check shape self.shape = shape - self.reduction = staticmethod(reduction) + self.reduction = reduction def get_indices(self): """ diff --git a/tests/tensor/test_segment.py b/tests/tensor/test_segment.py index 16050cd8..a9a9e371 100644 --- a/tests/tensor/test_segment.py +++ b/tests/tensor/test_segment.py @@ -11,14 +11,14 @@ def reduce_fn(physical_tensor, group): # segment.ReductionOp.register("Replica", reduce_fn) tensor = torch.randn((3,4)) - out = segment.ReductionOp.ReduceSum(tensor, None) + out = segment.ReductionOp.ReduceSum[0](tensor, None) assert out is tensor ## TODO: test all the provided reduction op def test_reduction_op_replica(): #TODO: check correctness - assert callable(segment.ReductionOp.Replica) + assert callable(segment.ReductionOp.Replica[0]) def test_data_segment_init(): From 58a87dbb7963274ece4a98ebaee54ae6e8b3b18f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 05:48:00 +0000 Subject: [PATCH 0069/1892] add test on logic op --- cube/operator/holist/generics.py | 128 ++++++++++++++++++++--------- cube/operator/logic/generics.py | 50 +++++++---- tests/operator/test_holistic_op.py | 0 tests/operator/test_logical_op.py | 49 +++++++++++ 4 files changed, 174 insertions(+), 53 deletions(-) create mode 100644 tests/operator/test_holistic_op.py create mode 100644 tests/operator/test_logical_op.py diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index d6a8af17..c5f112af 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -8,11 +8,15 @@ happen in the front of the next executed op in case the layout doesn't match. """ +from cube.tensor.logic.tensor import LogicalTensor +from cube.tensor.community import Community + class GenericHolisticOp: def __init__(self, input_layout, output_layout, - input_format=None, output_format=None): + input_format=None, output_format=None, + dim_order=None): """ Layout is the community distribution requirement for input and output logical tensors. @@ -21,70 +25,120 @@ def __init__(self, `None` indicates the format is consistent with logical op, otherwise should be a list of integers like torch.Tensor.permute() on the logical required format. + + Args: + input_laytout (list[Outliner, None]): outliner for each input + The length of outliner should be equal to the number of input + input_format (list[list[int], None]): + input dim order compare with logical definition + output_laytout (list[Outlinter, None]): outliner for each output + The length of outliner should be equal to the number of output + output_format (list[list[int], None]): + output dim order compare with logical definition """ - # holistic layout of input + + # holistic layout (outliner) of input self.input_layout = input_layout self.input_format = input_format # holistic layout of output self.output_layout = output_layout self.output_format = output_format + + self.logical_op = None - def input_adapter(self, args, **kwargs): + def input_adapter(self, *args, **kwargs): """ Transform tensors in args and kwargs to match the - input layout requirement + input layout requirement, Currently kwargs is not allowed to + have tensors """ - # step 1: data reformat based on the input argument - #TODO: data dimension format transformation - tensor_inputs = list() - for arg in args: - #TODO: kwargs - if cube.is_tensor(arg): - tensor_inputs.append(arg) - tensor_segments = list() - for outliner, tensor in zip(self.input_layout, tensor_inputs): - segments = outliner(tensor.shape) - tensor_segments.append(segments) + #TODO: kwargs - # step 2: physical tensor placement (policy) - #TODO: policy module - tensor_communities = policy_module(tensor_segments) + input_num = len(args) + if len(self.input_layout) != input_num: + raise RuntimeError("Fail to adapt input: layout length not equal") + if len(self.input_format) != input_num: + raise RuntimeError("Fail to adapt input: format length not equal") + + # step 1: data reformat based on the input argument + for input, dim_order in zip(args, self.input_layout): + if dim_order is not None and isinstance(input, LogicalTensor): + input.permute(dim_order) - # step 3: community matching - for communities, tensor in zip(tensor_communities, tensor_inputs): - tensor.match(communities) + # step 2: get segments based on expert discription + tensor_segments = list() + for outliner, tensor in zip(self.input_layout, args): + if outliner is not None: + segments = outliner(tensor.shape) + tensor_segments.append(segments) + else: + tensor_segments.append(None) + + # step 3: physical tensor placement (policy) + if self.policy_module is not None: + tensor_communities, tensor_devices = self.policy_fn[0](args, tensor_segments) + else: + # init community without policy decision + tensor_communities = [[Community(seg) for seg in segments] for segments in tensor_segments] + tensor_devices = None + tensor_val_map_fns = None + + # step 4: community matching + for communities, devices, tensor in zip(tensor_communities, tensor_devices, tensor_inputs): + tensor.match(communities, tensor_devices, tensor_val_map_fns) def output_adapter(self, outputs): """ Data reformat to logical op format """ - if not isinstance(outputs, tuple): - outputs = (outputs,) - output_tensors = list() - for output in outputs: - if cube.is_tensor(output): - if cube.is_tensor(output): - output_tensors.append(output) - for outliner, output in zip(self.output_layout, output_tensors): - segments = outliner(output.shape) - output.to_logic_tensor(segments) + if not isinstance(outputs, list): + outputs = [outputs] + # step 1: construct to logical tensor + logical_outputs = list() + for output_segs, outliner, shape in zip(self.outputs, self.output_layout, self.logical_shapes): + if shape is not None: + communities = outliner(shape) + output = construct_from_community(shape, communities, output_segs) + logical_outputs.append(output) + else: + logical_outputs.append(output_segs) + # step 2: data reformat based on the output + for out_id in range(len(self.output_format)): + dim_order = self.output_format[out_id] + if dim_order is not None and isinstance(logical_outputs[out_id], LogicalTensor): + logical_ouputs[out_id] = logical_ouputs[out_id].permute(dim_order) + return logical_outputs + - def forward(self, args, **kwargs): + def forward(self, *args, **kwargs): """Expert code for doing operation Call to the physical operator for execution""" - pass + raise NotImplementedError("Error call to generics") + - def __call__(self, args, **kwargs): + def __call__(self, *args, **kwargs): # data transformations to match input layout requirement - self.input_adapter(args, kwargs) + self.input_adapter(*args, **kwargs) # do execution - outputs = self.forward(args, kwargs) + outputs = self.forward(*args, **kwargs) - # wrap in holistic tensor with output layout + # wrap to logical tensor + self.logical_shapes = self.logical_op.infer_shapes(*args, **kwargs) outputs = self.output_adapter(outputs) return outputs + + def register_policy(self, policy_fn): + """ + Register a policy to take inputs (logical tensors) and segments, + generate device placement for each community, and corresponding + message mapping + + Args: + plicy_fn (callable) + """ + self.policy_fn = [policy_fn] diff --git a/cube/operator/logic/generics.py b/cube/operator/logic/generics.py index a3114aa3..1fe75250 100644 --- a/cube/operator/logic/generics.py +++ b/cube/operator/logic/generics.py @@ -20,24 +20,26 @@ def __init__(self): self.holist_ops = list() + def __len__(self): + """ + Return the number of holistic op registered + """ + return len(self.holist_ops) + def register(self, holistic_op): """ Register a holistic op as one of the anchors """ - #TODO: type check - self.holist_ops.append(holist_ops) + self.holist_ops.append(holistic_op) - def get_op(self, args, **kwargs): + def get_op(self, idx): """ - Given input tensor args, choose holistic operator(s) - for distributed execution plan + Get holistic operator based on idx Returns: - An hybrid-operator function which may composite by - nested holistic operators + HolisticOp instance """ - # TODO: hybrid parallelism generation - return self.holist_ops[0] + return self.holist_ops[idx] @@ -59,18 +61,34 @@ def register_policy(self, policy_fn): """ if not callable(policy_fn): raise TypeError("Expected a callable function") - self.policy_fn = policy_fn - - def __call__(self, args, **kwargs): + self.policy_fn = [policy_fn] + + def shape_infer(self, *args, **kwargs): """ - Policy here to determine which holistic operator(s) are called + Output shape inference according to inputs + + Args: + Operator input + + Returns: + shapes tuple(list[int]): shape for each output tensor """ + raise NotImplementedError("Expected a shape infer engine") + + def get_op(self, *args, **kwargs): # use default policy if self.policy_fn is None: - composite_op = self.factory.get_op(args, kwargs) + composite_op = self.factory.get_op(0) # use user-customized policy else: - composite_op = self.policy_fn(self.factory) + composite_op = self.policy_fn[0](self.factory, *args, **kwargs) + return composite_op + + def __call__(self, *args, **kwargs): + """ + Policy here to determine which holistic operator(s) are called + """ + composite_op = self.get_op(*args, **kwargs) # run operator with the strategy plan - outputs = composite_op(args, kwargs) + outputs = composite_op(*args, **kwargs) return outputs \ No newline at end of file diff --git a/tests/operator/test_holistic_op.py b/tests/operator/test_holistic_op.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/operator/test_logical_op.py b/tests/operator/test_logical_op.py new file mode 100644 index 00000000..33cab4d9 --- /dev/null +++ b/tests/operator/test_logical_op.py @@ -0,0 +1,49 @@ +from cube.operator.logic.generics import HolisticOpFactory, GenericLogicalOp + + +def test_factory(): + + factory = HolisticOpFactory() + assert len(factory) == 0 + + class HolisticOp: pass + holistic_op = HolisticOp() + + factory.register(holistic_op) + assert len(factor) == 1 + + op = factory.get_op(0) + assert op is holistic_op + + +def test_generic_logical_op_init(): + + generic_op = GenericLogicalOp() + assert len(generic_op.factory) == 0 + assert generic_op.policy_fn is None + + +def test_generic_logical_op_register(): + + generic_op = GenericLogicalOp() + + class HolisticOp: pass + holistic_op = HolisticOp() + + generic_op.factory.register(holistic_op) + + def policy_fn(factory): + return factory.get_op(0) + + generic_op.register_policy(policy_fn) + assert generic_op.policy_fn is not None + + op = generic_op.get_op() + assert op is holistic_op + + +if __name__ == '__main__': + + test_factory + test_generic_logical_op_init() + test_generic_logical_op_register() \ No newline at end of file From ac75d88eaf28ad2d8adcf5e9c0cb932fd48b2d03 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 06:07:04 +0000 Subject: [PATCH 0070/1892] add generic physic op --- cube/operator/physic/generics.py | 30 ++++++++++++++++++++++++++++++ tests/operator/test_physic_op.py | 22 ++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 tests/operator/test_physic_op.py diff --git a/cube/operator/physic/generics.py b/cube/operator/physic/generics.py index 68ec043f..1e11ec83 100644 --- a/cube/operator/physic/generics.py +++ b/cube/operator/physic/generics.py @@ -3,3 +3,33 @@ """ import torch + + +class GenericPhysicOp: + + def __init__(self, func, placement=None): + + if not callable(func): + raise TypeError("Expect callable function") + self.func = [func] + self.placement = placement + + def set_placement(self, placement): + if not isinstance(placement, torch.device): + raise TypeError("Expected torch device") + self.placement = placement + + def __call__(self, *args, **kwargs): + + # tensor movement + for arg in args: + if torch.is_tensor(arg): + if arg.device != self.placement: + arg.data = arg.detach().to(self.placement) + for key in kwargs: + if torch.is_tensor(kwargs[key]): + if kwargs[key].device != self.placement: + kwargs[key].data = kwargs[key].detach().to(self.placement) + + outputs = self.func[0](*args, **kwargs) + return outputs diff --git a/tests/operator/test_physic_op.py b/tests/operator/test_physic_op.py new file mode 100644 index 00000000..3feb0c0c --- /dev/null +++ b/tests/operator/test_physic_op.py @@ -0,0 +1,22 @@ +from cube.operator.physic.generics import GenericPhysicOp +import torch + + +def test_physic_generic_op(): + + op = GenericPhysicOp(torch._C._nn.linear) + assert op.placement is None + + op.set_placement(torch.device('cuda:0')) + + matA = torch.randn((1024,1024)) + matB = torch.randn((1024,1024)) + matC = op(matA, matB, bias=None) + assert matC.device == torch.device('cuda:0') + + matC_ref = torch._C._nn.linear(matA, matB, bias=None) + assert torch.allclose(matC, matC_ref) is True + + +if __name__ == '__main__': + test_physic_generic_op() From f38d1a0892a4663fe2763f06dbfc037a4b28ebda Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 07:44:45 +0000 Subject: [PATCH 0071/1892] add placement to physcial op --- cube/operator/physic/generics.py | 71 +++++++++++++++++++++++++++----- tests/operator/test_physic_op.py | 36 +++++++++++++--- 2 files changed, 92 insertions(+), 15 deletions(-) diff --git a/cube/operator/physic/generics.py b/cube/operator/physic/generics.py index 1e11ec83..d52c7d3e 100644 --- a/cube/operator/physic/generics.py +++ b/cube/operator/physic/generics.py @@ -2,34 +2,85 @@ This should be the interface with C level kernel launch """ +from cube.device.physic.group import DeviceGroup import torch +class OpResult: + """ + The empty result is used for re-constructing community + """ + def __init__(self, result, ranks): + self.res = result + self.placement = ranks + + def get_result(self): + return self.res + class GenericPhysicOp: + """ + The generic physical op takes at least one physical tensor, + and generates at least one physical tensor. + + If there is no tensor as input, will return an empty result + which indicates which rank will generate the correct one. + """ def __init__(self, func, placement=None): if not callable(func): raise TypeError("Expect callable function") self.func = [func] - self.placement = placement + self._placement = placement + self.execute_flag = False + self.policy_fn = None - def set_placement(self, placement): - if not isinstance(placement, torch.device): - raise TypeError("Expected torch device") - self.placement = placement + @property + def placement(self): + """ + Ranks for the op to execute + """ + return self._placement + + @placement.setter + def placement(self, ranks): + if not isinstance(ranks, list): + raise TypeError("Expected list of int ranks") + self._placement = ranks + if DeviceGroup().rank not in self.placement: + self.execute_flag = False + else: + self.execute_flag = True + + def register_policy(self, policy_fn): + if not callable(policy_fn): + raise TypeError("Expected callable policy function") + self.policy_fn = [policy_fn] def __call__(self, *args, **kwargs): + #TODO: fix for model-partition with send/recv + if self.placement is None: + if self.policy_fn is None: + #TODO: fix: this will break between-device consistency view + self.placement = [torch.cuda.current_device()] + else: + self.placement = self.policy_fn(*args, **kwargs) + if not self.execute_flag: + return OpResult(None, self.placement) # tensor movement for arg in args: if torch.is_tensor(arg): - if arg.device != self.placement: - arg.data = arg.detach().to(self.placement) + #TODO: rank -> device mapping, send/recv + if arg.device.index not in self.placement: + #TODO: rank -> device mapping, send/recv + arg.data = arg.detach().cuda() for key in kwargs: if torch.is_tensor(kwargs[key]): - if kwargs[key].device != self.placement: - kwargs[key].data = kwargs[key].detach().to(self.placement) + #TODO: rank -> device mapping, send/recv + if kwargs[key].device.index not in self.placement: + # TODO: rank -> device mapping, send/recv + kwargs[key].data = kwargs[key].detach().cuda() outputs = self.func[0](*args, **kwargs) - return outputs + return OpResult(outputs, self.placement) diff --git a/tests/operator/test_physic_op.py b/tests/operator/test_physic_op.py index 3feb0c0c..a2472e50 100644 --- a/tests/operator/test_physic_op.py +++ b/tests/operator/test_physic_op.py @@ -1,21 +1,47 @@ -from cube.operator.physic.generics import GenericPhysicOp +""" +cmd for running the test + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + tests/operator/test_physic_op.py +""" + +from cube.device.physic.group import DeviceGroup +from cube.operator.physic.generics import GenericPhysicOp, OpResult import torch def test_physic_generic_op(): + myrank = DeviceGroup().rank + ranks = [0, 2] + op = GenericPhysicOp(torch._C._nn.linear) assert op.placement is None - - op.set_placement(torch.device('cuda:0')) + + op.placement = ranks + assert op.func[0] is torch._C._nn.linear + assert op.placement == [0, 2] + assert op.execute_flag == (myrank in ranks) matA = torch.randn((1024,1024)) matB = torch.randn((1024,1024)) matC = op(matA, matB, bias=None) - assert matC.device == torch.device('cuda:0') + + assert set(matC.placement) == set(ranks) + if myrank in ranks: + assert torch.is_tensor(matC.get_result()) + else: + assert matC.get_result() is None matC_ref = torch._C._nn.linear(matA, matB, bias=None) - assert torch.allclose(matC, matC_ref) is True + if myrank in ranks: + assert torch.allclose(matC.get_result(), matC_ref) is True if __name__ == '__main__': From 4a49409499d3efa35816e95738599e75a2b47c62 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 07:49:46 +0000 Subject: [PATCH 0072/1892] add linear op --- cube/operator/physic/linear.py | 80 +++------------------------- tests/operator/test_physic_linear.py | 48 +++++++++++++++++ 2 files changed, 55 insertions(+), 73 deletions(-) create mode 100644 tests/operator/test_physic_linear.py diff --git a/cube/operator/physic/linear.py b/cube/operator/physic/linear.py index be8d925e..8af300f9 100644 --- a/cube/operator/physic/linear.py +++ b/cube/operator/physic/linear.py @@ -1,78 +1,12 @@ -from typing import Optional -import torch -from torch import Tensor -from torch.overrides import has_torch_function_variadic, handle_torch_function - -import cube.physic.operator.comm as comm -from cube.physic.device.group import DeviceGroup - +from cube.operator.physic.generics import GenericPhysicOp -def linear_op(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: - r""" - Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. - - This operator supports :ref:`TensorFloat32`. +import torch - Shape: - - Input: :math:`(N, *, in\_features)` N is the batch size, `*` means any number of - additional dimensions - - Weight: :math:`(out\_features, in\_features)` - - Bias: :math:`(out\_features)` - - Output: :math:`(N, *, out\_features)` +class Linear(GenericPhysicOp): + """ + Apply matmul: Out = input * weight^T + bias """ - if has_torch_function_variadic(input, weight): - print('note: this branch should not pass') - return handle_torch_function(linear, (input, weight), input, weight, bias=bias) - - # Information needed for enabling multiple GPUs: - # - involved devices -> (could be involved in tensor interfacee design) - # - which algorithm to take -> (semantic description / pattern match?) - # e.g., semantic description: allgather(split(weight, dim=0) * input + split(bias, dim=0)) - # - how we handle weight -> (everytime we need chunk weight / bias if we only focus on op) - devices = [0, 1] - rank = torch.distributed.get_rank(DeviceGroup().get_group(devices)) - - # single GPU version - if True: - output = torch._C._nn.linear(input, weight, bias) - - # multi-GPU version - # - Assume input is full - # - split cloumn of W: Y = XW + b where W = [W_1, ..., W_p] - elif False: - # get weight chunk - weight = torch.chunk(weight, chunks=len(devices), dim=0)[rank].contiguous() - if bias is not None: - bias = torch.chunk(bias, chunks=len(devices), dim=0)[rank].contiguous() - # forward: identity; backward: allreduce - input = comm.parallel_in(input, ranks=devices) - output = torch._C._nn.linear(input, weight, bias) - # forward: allgather; backward: split - output = comm.gather_out(output, dim=-1, ranks=devices) - - # multi-GPU version Y = XW + b - # - Assume input is full - # - split row of W, column of X: - # - Y = X = [X_1, ..., X_p] - # - W = [W_1 // ... // W_p] - elif False: - # get weight chunk - weight = torch.chunk(weight, chunks=len(devices), dim=1)[rank] - # forward: scatter; backward: allgather - input = comm.scatter_in(input, dim=-1, ranks=devices) - output = torch._C._nn.linear(input, weight, bias) - # forward: reduce; backward: identity - output = comm.reduce_out(output, ranks=devices) - - # Pesudo-code - else: - # data parallelism - input=[(0, Split())], weight=[], bias=[], output=[(0, Split())] - # tensor parallelism, weight column split - input=[], weight=[(0, Split())], bias=[(0, Split())], output=[(-1, Split())] - # tensor parallelism: data column + weight row - input=[(1, Split())] weight=[(1, Split()], bias=[], output=[(ALL, Partial(Sum))] - - return output \ No newline at end of file + def __init__(self, placement=None): + super().__init__(torch._C._nn.linear, placement) diff --git a/tests/operator/test_physic_linear.py b/tests/operator/test_physic_linear.py new file mode 100644 index 00000000..f51b51ed --- /dev/null +++ b/tests/operator/test_physic_linear.py @@ -0,0 +1,48 @@ +""" +cmd for running the test + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + tests/operator/test_physic_linear.py +""" + +from cube.device.physic.group import DeviceGroup +from cube.operator.physic.linear import Linear +import torch + + +def test_physic_generic_op(): + + myrank = DeviceGroup().rank + ranks = [0, 2] + + op = Linear() + assert op.placement is None + + op.placement = ranks + assert op.func[0] is torch._C._nn.linear + assert op.placement == [0, 2] + assert op.execute_flag == (myrank in ranks) + + matA = torch.randn((1024,1024)) + matB = torch.randn((1024,1024)) + matC = op(matA, matB, bias=None) + + assert set(matC.placement) == set(ranks) + if myrank in ranks: + assert torch.is_tensor(matC.get_result()) + else: + assert matC.get_result() is None + + matC_ref = torch._C._nn.linear(matA, matB, bias=None) + if myrank in ranks: + assert torch.allclose(matC.get_result(), matC_ref) is True + + +if __name__ == '__main__': + test_physic_generic_op() From 2f8855fc6333ee267779ce13a95630baa02a12b0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 08:42:04 +0000 Subject: [PATCH 0073/1892] update generics --- cube/operator/holist/generics.py | 35 ++++++++++++++++++++++++-------- cube/operator/holist/linear.py | 5 ++--- cube/operator/logic/linear.py | 13 +++++++++--- 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index c5f112af..fff700b0 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -11,6 +11,7 @@ from cube.tensor.logic.tensor import LogicalTensor from cube.tensor.community import Community + class GenericHolisticOp: def __init__(self, @@ -46,6 +47,14 @@ def __init__(self, self.output_format = output_format self.logical_op = None + self.policy_fn = None + + def set_logic_op(self, logic_op): + """ + Set logic op. This will be automatically called when the + holistic op registered in a logical op. + """ + self.logical_op = logic_op def input_adapter(self, *args, **kwargs): """ @@ -66,7 +75,7 @@ def input_adapter(self, *args, **kwargs): if dim_order is not None and isinstance(input, LogicalTensor): input.permute(dim_order) - # step 2: get segments based on expert discription + # step 2: get segments based on expert description tensor_segments = list() for outliner, tensor in zip(self.input_layout, args): if outliner is not None: @@ -88,6 +97,20 @@ def input_adapter(self, *args, **kwargs): for communities, devices, tensor in zip(tensor_communities, tensor_devices, tensor_inputs): tensor.match(communities, tensor_devices, tensor_val_map_fns) + def forward(self, *args, **kwargs): + """ + Expert code for doing operation + Call to the physical operator for execution + + Expert needs to gurantee the returned value is list[tuple(OpResult,),] + + Each item in list is the corresponding output to logical op output. + + Each item in the logical op output is a OpResult to the segment specified + by the expert. The order should be consistent with specified segment. + """ + raise NotImplementedError("Error call to generics") + def output_adapter(self, outputs): """ Data reformat to logical op format @@ -110,14 +133,6 @@ def output_adapter(self, outputs): logical_ouputs[out_id] = logical_ouputs[out_id].permute(dim_order) return logical_outputs - - - def forward(self, *args, **kwargs): - """Expert code for doing operation - Call to the physical operator for execution""" - raise NotImplementedError("Error call to generics") - - def __call__(self, *args, **kwargs): # data transformations to match input layout requirement @@ -127,6 +142,8 @@ def __call__(self, *args, **kwargs): outputs = self.forward(*args, **kwargs) # wrap to logical tensor + if self.logical_op is None: + raise RuntimeError("This holistic op doesn't have logical op") self.logical_shapes = self.logical_op.infer_shapes(*args, **kwargs) outputs = self.output_adapter(outputs) diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index 0ff1a879..2abe0834 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -59,9 +59,9 @@ def __init__(self): align = outline.Align(inputs_layout.chunk_num) weight_layout = outline.SplitAxis(axis=1, chunk_num=align, overlap=0) # TODO - bias_layout = outline.Broadcast(reduce=ReductionOpPool.Sum) + bias_layout = outline.Full(reduce=ReductionOpPool.Sum) # TODO - output_layout = outline.Broadcast(reduce=ReductionOpPool.Sum) + output_layout = outline.Full(reduce=ReductionOpPool.Sum) super().__init__( input_layout=(inputs_layout, weight_layout, bias_layout), @@ -69,7 +69,6 @@ def __init__(self): ) def forward(self, inputs, weight, bias): - #TODO: semantic errors on bias output = physic_op.linear(inputs, weight, bias) return [output] diff --git a/cube/operator/logic/linear.py b/cube/operator/logic/linear.py index eac8e1b1..d4ab70ee 100644 --- a/cube/operator/logic/linear.py +++ b/cube/operator/logic/linear.py @@ -2,7 +2,7 @@ from cube.operator.holist.linear import kHolistLinearSets -__all__ = ['linear'] +__all__ = ['Linear'] def Linear(generics.GenericLogicalOp): @@ -12,7 +12,14 @@ def __init__(self): # register holistic operators for holist_op in kHolistLinearSets: + holist_op.set_logic_op(self) self.factory.register(holist_op) -# initialize op -linear = Linear() + def shape_infer(self, input_shape, weight_shape, bias_shape=None) + """ + Return the outputs shape [list[int],] + """ + output_shape = list(input_shape) + output_shape[-1] = weight_shape[-1] + return [output_shape] + From c140e15c2b87e09e13a8daae37dd94285b16c830 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 12:24:14 +0000 Subject: [PATCH 0074/1892] outline with align support --- cube/tensor/logic/segment/outline.py | 71 ++++++++++++++++++++++++---- cube/tensor/logic/segment/segment.py | 6 +-- tests/tensor/test_outline.py | 51 ++++++++++++++++---- 3 files changed, 106 insertions(+), 22 deletions(-) diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py index 2f7e7390..a13d597f 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/segment/outline.py @@ -13,20 +13,72 @@ from cube.tensor.logic.segment.segment import TileSegment, ReductionOp +class MutableContainer: + + def __init__(self, scope): + self.__val = None + self.__scope = scope + + def get(self, scope=False): + if scope: + return self.__scope + else: + return self.__val + + def set(self, val): + if self.__scope is not None: + if val not in self.__scope: + raise ValueError("Fail to set container, out of range") + self.__val = val + + +class ConstantContainer: + + def __init__(self, val): + self.__val = val + + def get(self): + return self.__val + + def set(self, val): + raise RuntimeError("Cannot set a ConstantContainer") + + # interface to setup restrictions on the segmentation -class Full: +class BaseOutline: + """ + Basic class for declare outline + To setup an attribute (requirement), use `inst_baseoutline.attribute_name = val` + """ def __init__(self, reduction=None): self.reduction = reduction + def __setattr__(self, key, val): + if key in self.__dict__: + self.__dict__[key].set(val) + #TODO: Align semantics will not allow setting val on child, need a new class + elif isinstance(val, MutableContainer) or isinstance(val, ConstantContainer): + self.__dict__[key] = val + elif val is None or isinstance(val, range) or isinstance(val, set): + self.__dict__[key] = MutableContainer(val) + else: + self.__dict__[key] = ConstantContainer(val) + + +class Full(BaseOutline): + + def __init__(self, reduction=None): + super().__init__(reduction) + def __call__(self, shape): - segment = TileSegment([0] * len(shape), list(shape), self.reduction) + segment = TileSegment([0] * len(shape), list(shape), self.reduction.get()) return [segment] -class SplitAxis: +class SplitAxis(BaseOutline): def __init__(self, axis, chunk_num=None, overlap=0, reduction=None, uniform=True): """ @@ -48,11 +100,11 @@ def __init__(self, axis, chunk_num=None, overlap=0, reduction=None, uniform=True if a tuple(min, max), the overlap size wihtin the scope [min,max] is valid """ + super().__init__(reduction) self.axis = axis self.chunk_num = chunk_num - self.uniform = True + self.uniform = uniform self.overlap = overlap - self.reduction = reduction def __call__(self, shape): """ @@ -62,11 +114,12 @@ def __call__(self, shape): """ segments = list() shape = list(shape) - shape[self.axis] = shape[self.axis] // self.chunk_num + shape[self.axis.get()] = shape[self.axis.get()] // self.chunk_num.get() anchor = [0] * len(shape) - for cid in range(self.chunk_num): + #TODO: support list of reductions + for cid in range(self.chunk_num.get()): segment = TileSegment( - list(anchor), list(shape), reduction=self.reduction[cid]) - anchor[self.axis] += shape[self.axis] + list(anchor), list(shape), reduction=self.reduction) + anchor[self.axis.get()] += shape[self.axis.get()] segments.append(segment) return segments diff --git a/cube/tensor/logic/segment/segment.py b/cube/tensor/logic/segment/segment.py index 7aceb6d7..db0acd1b 100644 --- a/cube/tensor/logic/segment/segment.py +++ b/cube/tensor/logic/segment/segment.py @@ -10,10 +10,10 @@ class _Reduction(type): # forward: all_reduce, backward: identity - Sum = [reduce_sum] + Sum = (reduce_sum,) # forward: identity, backward: all_reduce - Replica = [replicate] + Replica = (replicate,) def register(cls, name, udf): """ @@ -28,7 +28,7 @@ def register(cls, name, udf): """ if hasattr(cls, name): raise KeyError("{} is registered".format(name)) - setattr(cls, name, [udf]) + setattr(cls, name, (udf,)) class ReductionOp(metaclass=_Reduction): diff --git a/tests/tensor/test_outline.py b/tests/tensor/test_outline.py index fde43df9..86bb825b 100644 --- a/tests/tensor/test_outline.py +++ b/tests/tensor/test_outline.py @@ -4,12 +4,25 @@ import torch +def test_base(): + + dsp1 = outline.BaseOutline(reduction=segment.ReductionOp.Sum) + assert isinstance(dsp1.reduction, outline.ConstantContainer) + assert dsp1.reduction.get() == segment.ReductionOp.Sum + + choice = {segment.ReductionOp.Sum, segment.ReductionOp.Replica} + dsp2 = outline.BaseOutline(reduction=choice) + assert isinstance(dsp2.reduction, outline.MutableContainer) + assert dsp2.reduction.get() is None + assert dsp2.reduction.get(scope=True) == choice + + def test_full(): shape = (10,10,10) tensor = torch.randn(shape) full_dsp = outline.Full(reduction=segment.ReductionOp.Replica) - assert full_dsp.reduction == segment.ReductionOp.Replica + assert full_dsp.reduction.get() == segment.ReductionOp.Replica segments = full_dsp(tensor.shape) assert len(segments) == 1 @@ -30,22 +43,23 @@ def test_split_axis(): split_dsp = outline.SplitAxis( axis=axis, chunk_num=None, overlap=0, reduction=segment.ReductionOp.Replica, uniform=True) - assert split_dsp.axis == 1 - assert split_dsp.chunk_num is None - assert split_dsp.uniform is True - assert split_dsp.overlap == 0 - assert split_dsp.reduction == segment.ReductionOp.Replica + assert split_dsp.axis.get() == 1 + assert split_dsp.chunk_num.get() is None + assert split_dsp.uniform.get() is True + assert split_dsp.overlap.get() == 0 + assert split_dsp.reduction.get() == segment.ReductionOp.Replica ## Policy here to decide how to split - if split_dsp.chunk_num is None: + if split_dsp.chunk_num.get() is None: split_dsp.chunk_num = num - split_dsp.reduction = [segment.ReductionOp.Replica] * num ### segs = split_dsp(tensor.shape) assert len(segs) == num - assert torch.all(torch.Tensor([type(seg) == segment.TileSegment for seg in segs])) + assert torch.all( + torch.Tensor( + [type(seg) == segment.TileSegment for seg in segs])).item() is True ofst = 0 expected_shape = list(shape) @@ -61,7 +75,24 @@ def test_split_axis(): ofst += expected_shape[axis] +def test_align(): + + dsp1 = outline.SplitAxis( + axis=1, chunk_num=None, overlap=0, + reduction=segment.ReductionOp.Replica, uniform=True) + + dsp2 = outline.SplitAxis( + axis=2, chunk_num=dsp1.chunk_num, overlap=0, + reduction=segment.ReductionOp.Replica, uniform=True) + + dsp1.chunk_num = 3 + assert dsp2.chunk_num.get() == 3 + assert dsp2.axis.get() == 2 + + if __name__ == '__main__': + test_base() test_full() - test_split_axis() \ No newline at end of file + test_split_axis() + test_align() \ No newline at end of file From 8f07bea82e45e34b2652c4c9514e44fb77ef8f32 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 12:35:53 +0000 Subject: [PATCH 0075/1892] linear distributed description with alignment --- cube/operator/holist/linear.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index 2abe0834..67a4e02a 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -8,8 +8,10 @@ # expert space to declare all kinds of holistic operators + __all__ = ['kHolistLinearSets'] + class LinearColumnWeight(GenericHolisticOp): """ Perform Y = XW + b -> Y = X[W1,W2] + [b1,b2] @@ -25,11 +27,11 @@ def __init__(self): # TODO bias_layout = outline.SplitAxis(axis=0, chunk_num=None, overlap=0) # TODO - output_layout = outline.Align(weight_layout) + output_layout = weight_layout super().__init__( - input_layout=(inputs_layout, weight_layout, bias_layout), - output_layout=(output_layout,) + input_layout=[inputs_layout, weight_layout, bias_layout], + output_layout=[output_layout,] ) def forward(self, inputs, weight, bias): @@ -64,8 +66,8 @@ def __init__(self): output_layout = outline.Full(reduce=ReductionOpPool.Sum) super().__init__( - input_layout=(inputs_layout, weight_layout, bias_layout), - output_layout=(output_layout,) + input_layout=[inputs_layout, weight_layout, bias_layout], + output_layout=[output_layout,] ) def forward(self, inputs, weight, bias): From 98ce7d5631437b51dfd34e1628efc56b00a0fab5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 26 Jul 2021 12:43:19 +0000 Subject: [PATCH 0076/1892] on-going: holistic op verification --- cube/operator/holist/generics.py | 13 ++++++- tests/operator/test_holistic_op.py | 59 ++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index fff700b0..b1026844 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -16,8 +16,8 @@ class GenericHolisticOp: def __init__(self, input_layout, output_layout, - input_format=None, output_format=None, - dim_order=None): + input_format=None, output_format=None + ): """ Layout is the community distribution requirement for input and output logical tensors. @@ -39,6 +39,15 @@ def __init__(self, """ # holistic layout (outliner) of input + if not isinstance(input_layout, list): + raise TypeError("Require input layout for HolistOp is a list") + if not isinstance(input_format, list) or input: + raise TypeError("Require input format for HolistOp is a list") + if not isinstance(output_layout, list): + raise TypeError("Require output layout for HolistOp is a list") + if not isinstance(output_format, list): + raise TypeError("Require output format for HolistOp is a list") + self.input_layout = input_layout self.input_format = input_format diff --git a/tests/operator/test_holistic_op.py b/tests/operator/test_holistic_op.py index e69de29b..ff4087e8 100644 --- a/tests/operator/test_holistic_op.py +++ b/tests/operator/test_holistic_op.py @@ -0,0 +1,59 @@ +import cube.tensor.logic.segment as sg +from cube.tensor.logic.tensor import LogicalTensor + +from cube.operator.holist.generics import GenericHolisticOp + + +def test_generic_holistic_op_init(): + + # description + input_layout = sg.SplitAxis( + axis=0, overlap=0, reduction=sg.ReductionOp.Replica + ) + weight_layout = sg.Full(reduction=sg.ReductionOp.Replica) + output_layout = sg.SplitAxis( + axis=0, overlap=0, chunk_num=input_layout.chunk_num, + reduction=sg.ReductionOp.Replica + ) + + op = GenericHolisticOp( + input_layout=[input_layout, weight_layout], + output_layout=[output_layout], + input_format=[None, None], + output_format=[None], + ) + + assert len(op.input_layout) == 2 + assert len(op.input_format) == 2 + assert len(op.output_layout) == 1 + assert len(op.output_format) == 1 + assert op.logical_op is None + assert op.policy_fn is None + + +def test_generic_holistic_op_input_adapter(): + + input_layout = sg.SplitAxis( + axis=0, overlap=0, reduction=sg.ReductionOp.Replica + ) + weight_layout = sg.Full(reduction=sg.ReductionOp.Replica) + output_layout = sg.SplitAxis( + axis=0, overlap=0, chunk_num=input_layout.chunk_num, + reduction=sg.ReductionOp.Replica + ) + + op = GenericHolisticOp( + input_layout=[input_layout, weight_layout], + output_layout=[output_layout], + input_format=[None, None], + output_format=[None], + ) + + input = LogicalTensor(shape=(1024, 1024)) + weight = LogicalTensor(shape=(1024, 1024)) + + + +if __name__ == '__main__': + + test_generic_holistic_op_init() \ No newline at end of file From 9a57dc4125da348f6199ee0fbb4a5404172a1edd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 03:21:11 +0000 Subject: [PATCH 0077/1892] add repr for segment --- cube/tensor/logic/segment/segment.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cube/tensor/logic/segment/segment.py b/cube/tensor/logic/segment/segment.py index db0acd1b..79711068 100644 --- a/cube/tensor/logic/segment/segment.py +++ b/cube/tensor/logic/segment/segment.py @@ -84,6 +84,11 @@ def reorder(self, new_orders): for dim in range(len(self.indices)): self.indices[dim] = [self.indices[dim][idx] for idx in new_orders] + def __repr__(self): + msg = 'DataSegment(indices_len={}, reduction={})'.format( + len(self.indices), self.reduction + ) + ## Higher structure to cover the most cases ## class TileSegment(DataSegment): @@ -115,6 +120,12 @@ def get_indices(self): def reorder(self): pass + def __repr__(self): + msg = 'TileSegment(anchor={}, shape={}, reduction={})'.format( + self.anchor, self.shape, self.reduction + ) + return msg + ## Primitive sets for translation ## From a7216cfc913d2934486c049ad71caef570519268 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 03:21:33 +0000 Subject: [PATCH 0078/1892] fix bug in reduction transformation of SplitAxis --- cube/tensor/logic/segment/outline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py index a13d597f..877fa3e0 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/segment/outline.py @@ -119,7 +119,7 @@ def __call__(self, shape): #TODO: support list of reductions for cid in range(self.chunk_num.get()): segment = TileSegment( - list(anchor), list(shape), reduction=self.reduction) + list(anchor), list(shape), reduction=self.reduction.get()) anchor[self.axis.get()] += shape[self.axis.get()] segments.append(segment) return segments From 616d62a785e0a700ce22b61b62ad164ba411bbb3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 03:22:21 +0000 Subject: [PATCH 0079/1892] init match functionality --- cube/tensor/logic/tensor.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 0738b2fc..4bbffbc2 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -18,6 +18,23 @@ def __init__(self, shape, init_data=True): import torch self.data = torch.randn(shape).detach() + def match(communities, ranks=None, val_map_fns=None): + """ + Match the LogicalTensor with community list. + """ + #TODO: community matching and transformation + if ranks is None: + ranks = [None] * len(communities) + if val_map_fn is None: + val_map_fn = [None] * len(communities) + if len(self.communities) == 0: + for cid in range(len(communities)): + self.set_community(community) + if not community.materialized: + rank_list = ranks[cid] + val_map_fn = ranks[cid] + community.deploy(ranks, self, val_map_fn) + def get_physical_tensor(self, segment): """ Get physical tensor from the community. From 63c14ac31d95a325b5c15bcda0b64032c58f85d2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 04:50:02 +0000 Subject: [PATCH 0080/1892] tensor match bug fix --- cube/tensor/logic/tensor.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 4bbffbc2..5911d8ba 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -18,22 +18,29 @@ def __init__(self, shape, init_data=True): import torch self.data = torch.randn(shape).detach() - def match(communities, ranks=None, val_map_fns=None): + def match(self, communities, ranks=None, val_map_fns=None): """ Match the LogicalTensor with community list. """ + # type check + ranks = [None] * len(communities) if ranks is None else ranks + val_map_fns = [None] * len(communities) if val_map_fns is None else val_map_fns + if not isinstance(ranks, list): + raise TypeError("Expected ranks to be a list or None") + if not isinstance(ranks, list): + raise TypeError("Expected ranks to be a list or None") + #TODO: community matching and transformation - if ranks is None: - ranks = [None] * len(communities) - if val_map_fn is None: - val_map_fn = [None] * len(communities) if len(self.communities) == 0: for cid in range(len(communities)): + community = communities[cid] self.set_community(community) if not community.materialized: rank_list = ranks[cid] - val_map_fn = ranks[cid] - community.deploy(ranks, self, val_map_fn) + val_map_fn = val_map_fns[cid] + community.deploy(rank_list, self, val_map_fn) + else: + raise NotImplementedError def get_physical_tensor(self, segment): """ @@ -45,7 +52,7 @@ def get_physical_tensor(self, segment): Returns: torch.Tensor or None """ - community = self.communities[idx] + community = self.communities[segment] return community.get_physical_tensor() def get_community(self, segment): @@ -88,7 +95,7 @@ def set_community(self, community): with the given community's segment, the original community will be overrided """ - if not isinstance(community): + if not isinstance(community, Community): raise TypeError("Expected a community") segment = community.segment if segment not in self.communities: From 4ef08d9c719bdd79f6198046de8c7cb093e44f92 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 04:50:45 +0000 Subject: [PATCH 0081/1892] test pass on input adapter --- cube/operator/holist/generics.py | 44 ++++++++++++++++------------ tests/operator/test_holistic_op.py | 47 ++++++++++++++++++++++++++---- 2 files changed, 68 insertions(+), 23 deletions(-) diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index b1026844..9a3e5445 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -41,7 +41,7 @@ def __init__(self, # holistic layout (outliner) of input if not isinstance(input_layout, list): raise TypeError("Require input layout for HolistOp is a list") - if not isinstance(input_format, list) or input: + if not isinstance(input_format, list): raise TypeError("Require input format for HolistOp is a list") if not isinstance(output_layout, list): raise TypeError("Require output layout for HolistOp is a list") @@ -80,31 +80,37 @@ def input_adapter(self, *args, **kwargs): raise RuntimeError("Fail to adapt input: format length not equal") # step 1: data reformat based on the input argument - for input, dim_order in zip(args, self.input_layout): - if dim_order is not None and isinstance(input, LogicalTensor): + for input, dim_order in zip(args, self.input_format): + if dim_order is not None: input.permute(dim_order) - # step 2: get segments based on expert description - tensor_segments = list() - for outliner, tensor in zip(self.input_layout, args): - if outliner is not None: + # step 2: get communities based on expert description + input_communities = list() + for tensor, outliner in zip(args, self.input_layout): + if outliner is not None and isinstance(tensor, LogicalTensor): segments = outliner(tensor.shape) - tensor_segments.append(segments) + communities = [Community(seg) for seg in segments] + input_communities.append(communities) else: - tensor_segments.append(None) + input_communities.append(None) # step 3: physical tensor placement (policy) - if self.policy_module is not None: - tensor_communities, tensor_devices = self.policy_fn[0](args, tensor_segments) + if self.policy_fn is not None: + input_ranks, input_val_map_fns = \ + self.policy_fn[0](input_communities, *args) else: - # init community without policy decision - tensor_communities = [[Community(seg) for seg in segments] for segments in tensor_segments] - tensor_devices = None - tensor_val_map_fns = None + # TODO: default policy + input_ranks = [None] * len(args) + input_val_map_fns = [None] * len(args) # step 4: community matching - for communities, devices, tensor in zip(tensor_communities, tensor_devices, tensor_inputs): - tensor.match(communities, tensor_devices, tensor_val_map_fns) + for tid in range(len(args)): + tensor = args[tid] + if isinstance(tensor, LogicalTensor): + communities = input_communities[tid] + ranks = input_ranks[tid] + val_map_fn = input_val_map_fns[tid] + tensor.match(communities, ranks, val_map_fn) def forward(self, *args, **kwargs): """ @@ -167,4 +173,6 @@ def register_policy(self, policy_fn): Args: plicy_fn (callable) """ - self.policy_fn = [policy_fn] + if not callable(policy_fn): + raise TypeError("Expected callable function") + self.policy_fn = (policy_fn,) diff --git a/tests/operator/test_holistic_op.py b/tests/operator/test_holistic_op.py index ff4087e8..98659a9c 100644 --- a/tests/operator/test_holistic_op.py +++ b/tests/operator/test_holistic_op.py @@ -1,8 +1,21 @@ +""" +cmd for running the test + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + tests/operator/test_holistic_op.py +""" + import cube.tensor.logic.segment as sg from cube.tensor.logic.tensor import LogicalTensor - from cube.operator.holist.generics import GenericHolisticOp - +from cube.device.physic.group import DeviceGroup +import torch def test_generic_holistic_op_init(): @@ -51,9 +64,33 @@ def test_generic_holistic_op_input_adapter(): input = LogicalTensor(shape=(1024, 1024)) weight = LogicalTensor(shape=(1024, 1024)) - + ## Policy Here + input_layout.chunk_num = 4 + assert output_layout.chunk_num.get() == 4 + def policy_fn(input_communities, input, weight): + input_ranks = [ + [[0],[1],[2],[3]], + [[0,1,2,3]] + ] + input_val_map_fns = list([None, None]) + return input_ranks, input_val_map_fns -if __name__ == '__main__': + op.register_policy(policy_fn) + op.input_adapter(input, weight) + + myrank = DeviceGroup().rank + assert len(input.communities) == 4 + assert len(weight.communities) == 1 + physical_tensor = input.get_physical_tensor(input.segments[myrank]) + piece = 1024 // 4 + start = int(myrank * piece) + assert torch.allclose(physical_tensor, input.data.cuda()[start:start+piece, :]) is True + physical_tensor = weight.get_physical_tensor(weight.segments[0]) + assert torch.allclose(physical_tensor, weight.data.cuda()) is True - test_generic_holistic_op_init() \ No newline at end of file + +if __name__ == '__main__': + group = DeviceGroup() + test_generic_holistic_op_init() + test_generic_holistic_op_input_adapter() From 130554fd15c47a338d01d3edba35d9f157de0803 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 05:51:33 +0000 Subject: [PATCH 0082/1892] community set physical tensor --- cube/tensor/community.py | 16 ++++++++++++++++ tests/tensor/test_community.py | 17 +++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/cube/tensor/community.py b/cube/tensor/community.py index 2f94d36c..375e2f3f 100644 --- a/cube/tensor/community.py +++ b/cube/tensor/community.py @@ -86,3 +86,19 @@ def get_physical_tensor(self): return self.physical_tensor else: raise RuntimeError("The Community has not been materialized to physical tensors") + + def set_physical_tensor(self, physical_tensor, ranks): + if self.materialized: + raise RuntimeError("Setting physical tensors to a materialized community") + if not isinstance(physical_tensor, torch.Tensor): + raise TypeError("physical_tensor: Expected a torch tensor") + if not isinstance(ranks, list): + raise TypeError("ranks: Expected a list[int]") + if physical_tensor.size() != torch.Size(self.segment.shape): + raise RuntimeError( + "Trying to set a community where physical tensor shape " + "doesn't match with segment shape") + #TODO: device check + self.physical_tensor = physical_tensor + self.group = DeviceGroup().get_group(ranks) + self.materialized = True diff --git a/tests/tensor/test_community.py b/tests/tensor/test_community.py index 1b4cc74b..ced08a34 100644 --- a/tests/tensor/test_community.py +++ b/tests/tensor/test_community.py @@ -86,7 +86,19 @@ def test_community_sync(): else: # print('ref: {}'.format(ref_tensor)) assert torch.allclose(sub_tensor, ref_tensor) is True - + + +def test_community_set_physical_tensor(): + seg = segment.TileSegment( + anchor=(2,3,1), shape=(4,4,4), + reduction=segment.ReductionOp.Sum) + community = Community(seg) + + tensor = torch.randn((4,4,4)) + community.set_physical_tensor(tensor, [0,1,2]) + assert community.materialized is True + assert community.group == DeviceGroup().get_group([0,1,2]) + assert community.physical_tensor is tensor if __name__ == '__main__': @@ -96,4 +108,5 @@ def test_community_sync(): test_community_init() test_community_deploy() - test_community_sync() \ No newline at end of file + test_community_sync() + test_community_set_physical_tensor() \ No newline at end of file From fced88def754a547418b483a462c6b0c9df85a5c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 07:43:29 +0000 Subject: [PATCH 0083/1892] test construct logical tensor --- cube/tensor/logic/tensor.py | 23 ++++++++++++++++++---- tests/tensor/test_logical_tensor.py | 30 +++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 tests/tensor/test_logical_tensor.py diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 5911d8ba..6f013fc6 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -42,6 +42,13 @@ def match(self, communities, ranks=None, val_map_fns=None): else: raise NotImplementedError + @staticmethod + def construct(shape, communities): + tensor = LogicalTensor(shape=shape, init_data=False) + for community in communities: + tensor.set_community(community) + return tensor + def get_physical_tensor(self, segment): """ Get physical tensor from the community. @@ -55,14 +62,22 @@ def get_physical_tensor(self, segment): community = self.communities[segment] return community.get_physical_tensor() - def get_community(self, segment): + def get_community(self, segment_or_index): """ Get Community based on the segment + + Args: + segment_or_index (DataSegment or int): + + Returns: + Community """ - if not isinstance(segment, DataSegment): + if isinstance(segment_or_index, DataSegment): + return self.communities[segment_or_index] + elif isinstance(segment_or_index, int): + return self.communities[self.segments[segment_or_index]] + else: raise ValueError("Expected (derived) DataSegment to chooese Community") - if segment not in self.communities: - raise KeyError("The segment doesn't found in current tensor") return self.communities[segment] def __getitem__(self, key): diff --git a/tests/tensor/test_logical_tensor.py b/tests/tensor/test_logical_tensor.py new file mode 100644 index 00000000..dbc871ee --- /dev/null +++ b/tests/tensor/test_logical_tensor.py @@ -0,0 +1,30 @@ +from cube.tensor.logic.tensor import LogicalTensor +from cube.tensor.community import Community +import cube.tensor.logic.segment as segment + + +def test_logical_tensor_init(): + + #TODO + pass + + +def test_logical_tensor_construct(): + + seg = segment.TileSegment( + anchor=(2,3,1), shape=(4,4,4), + reduction=segment.ReductionOp.Replica) + community = Community(seg) + + logical_tensor = LogicalTensor.construct((10,10,10), [community]) + + assert isinstance(logical_tensor, LogicalTensor) + assert len(logical_tensor.communities) == 1 + assert logical_tensor.get_community(0) is community + assert logical_tensor.shape == (10,10,10) + + +if __name__ == '__main__': + + test_logical_tensor_init() + test_logical_tensor_construct() \ No newline at end of file From 4c178614a1f533fa946cffd824fc640ecbbd42a7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 12:34:24 +0000 Subject: [PATCH 0084/1892] runnable on parallel linear --- cube/operator/holist/generics.py | 50 ++++++++++---- cube/operator/holist/linear.py | 83 ++++++++++++++++-------- cube/operator/logic/linear.py | 6 +- cube/operator/physic/generics.py | 11 +++- cube/tensor/community.py | 16 +++-- cube/tensor/logic/segment/outline.py | 17 ++++- cube/tensor/logic/tensor.py | 27 ++++++-- tests/operator/test_holistic_linear.py | 90 ++++++++++++++++++++++++++ tests/operator/test_logical_op.py | 2 +- 9 files changed, 242 insertions(+), 60 deletions(-) create mode 100644 tests/operator/test_holistic_linear.py diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index 9a3e5445..b2167d83 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -28,11 +28,11 @@ def __init__(self, on the logical required format. Args: - input_laytout (list[Outliner, None]): outliner for each input + input_layout (list[Outliner, None]): outliner for each input The length of outliner should be equal to the number of input input_format (list[list[int], None]): input dim order compare with logical definition - output_laytout (list[Outlinter, None]): outliner for each output + output_layout (list[Outlinter, None]): outliner for each output The length of outliner should be equal to the number of output output_format (list[list[int], None]): output dim order compare with logical definition @@ -129,24 +129,39 @@ def forward(self, *args, **kwargs): def output_adapter(self, outputs): """ Data reformat to logical op format + + Args: + outputs (tuple(list[physical_tensor],)) + each `list[physical_tensor]` represents a output of the op + with is communities + Returns: + logical outputs (tuple(LogicalTensor,)): + the logical tensor list """ - if not isinstance(outputs, list): - outputs = [outputs] + #TODO: fix: data re-format order. Should be ahead of logical tensor construction + if not isinstance(outputs, tuple): + outputs = (outputs,) # step 1: construct to logical tensor logical_outputs = list() - for output_segs, outliner, shape in zip(self.outputs, self.output_layout, self.logical_shapes): - if shape is not None: - communities = outliner(shape) - output = construct_from_community(shape, communities, output_segs) - logical_outputs.append(output) - else: - logical_outputs.append(output_segs) + for output, outliner, shape in zip(outputs, self.output_layout, self.logical_shapes): + segments = outliner(shape) + communities = [Community(segment) for segment in segments] + for community, op_res in zip(communities, output): + #if DeviceGroup().rank == 0: + # print(op_res.res.size(), community.segment.shape) + community.set_physical_tensor(op_res.res, op_res.placement) + output = LogicalTensor.construct(shape, communities) + logical_outputs.append(output) # step 2: data reformat based on the output for out_id in range(len(self.output_format)): dim_order = self.output_format[out_id] if dim_order is not None and isinstance(logical_outputs[out_id], LogicalTensor): logical_ouputs[out_id] = logical_ouputs[out_id].permute(dim_order) - return logical_outputs + + if len(logical_outputs) == 1: + return logical_outputs[0] + else: + return tuple(logical_outputs) def __call__(self, *args, **kwargs): @@ -159,12 +174,12 @@ def __call__(self, *args, **kwargs): # wrap to logical tensor if self.logical_op is None: raise RuntimeError("This holistic op doesn't have logical op") - self.logical_shapes = self.logical_op.infer_shapes(*args, **kwargs) + self.logical_shapes = self.logical_op.shape_infer(*args, **kwargs) outputs = self.output_adapter(outputs) return outputs - def register_policy(self, policy_fn): + def set_deploy_policy(self, policy_fn): """ Register a policy to take inputs (logical tensors) and segments, generate device placement for each community, and corresponding @@ -176,3 +191,10 @@ def register_policy(self, policy_fn): if not callable(policy_fn): raise TypeError("Expected callable function") self.policy_fn = (policy_fn,) + + def set_segmentation_policy(self, policy_fn): + for outliner in self.input_layout: + outliner.set_policy(policy_fn) + for outliner in self.output_layout: + outliner.set_policy(policy_fn) + diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index 67a4e02a..2a85ad23 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -1,11 +1,15 @@ from cube.operator.holist.generics import GenericHolisticOp -import cube.operator.physic as physic_op +import cube.operator.physic.linear as phy_linear from cube.tensor.logic.tensor import LogicalTensor -import cube.tensor.logic.segment.outline as outline +import cube.tensor.logic.segment as sg from cube.tensor.community import Community +# Debug +from cube.device.physic.group import DeviceGroup +import torch + # expert space to declare all kinds of holistic operators @@ -20,26 +24,49 @@ class LinearColumnWeight(GenericHolisticOp): def __init__(self): - # TODO - inputs_layout = outline.Full - # TODO - weight_layout = outline.SplitAxis(axis=0, chunk_num=None, overlap=0) - # TODO - bias_layout = outline.SplitAxis(axis=0, chunk_num=None, overlap=0) - # TODO - output_layout = weight_layout + inputs_layout = sg.outline.Full( + reduction=sg.ReductionOp.Replica + ) + + weight_layout = sg.outline.SplitAxis( + axis=0, chunk_num=None, overlap=0, uniform=False, + reduction=sg.ReductionOp.Replica + ) + + bias_layout = weight_layout + + output_layout = sg.outline.SplitAxis( + axis=1, chunk_num=weight_layout.chunk_num, overlap=0, uniform=False, + reduction=sg.ReductionOp.Replica + ) super().__init__( input_layout=[inputs_layout, weight_layout, bias_layout], - output_layout=[output_layout,] + output_layout=[output_layout,], + input_format=[None, None, None], + output_format=[None] ) - def forward(self, inputs, weight, bias): + def forward(self, input, weight, bias): outputs = list() # TODO: handle bias is None - for pw, pb in zip(weight, bias): - output = physic_op.linear(inputs, weight, bias) - outputs.append(outputs) + physical_input = input.get_physical_tensor(0) + for cid in range(len(weight)): + # output = physic_op.linear(inputs, weight, bias) + #TODO: TensorContainer to enable op placement + tensor movement + #TODO: ExecutionScheduler to handle re-compute / swap + #TODO: nested hybrid call to enable hybrid-parallelisms + #TODO: double-check necessety of stateful physical operator + physical_weight = weight.get_physical_tensor(cid) + # if DeviceGroup().rank == 0: + # print(physical_weight) + physical_bias = bias.get_physical_tensor(cid) + # TODO: this is the policy decision + phy_op = phy_linear.Linear(placement=weight.get_community(cid).placement) + output = phy_op(physical_input, physical_weight, physical_bias) + # if DeviceGroup().rank == 0: + # print(output) + outputs.append(output) return outputs @@ -55,24 +82,28 @@ class LinearColumnInputRowWeight(GenericHolisticOp): def __init__(self): - # TODO - inputs_layout = outline.SplitAxis(axis=-1, chunk_num=None, overlap=0) - # TODO - align = outline.Align(inputs_layout.chunk_num) - weight_layout = outline.SplitAxis(axis=1, chunk_num=align, overlap=0) - # TODO - bias_layout = outline.Full(reduce=ReductionOpPool.Sum) - # TODO - output_layout = outline.Full(reduce=ReductionOpPool.Sum) + inputs_layout = sg.outline.SplitAxis( + axis=-1, chunk_num=None, overlap=0, + reduction=sg.ReductionOp.Replica) + + weight_layout = sg.outline.SplitAxis( + axis=1, chunk_num=inputs_layout.chunk_num, overlap=0, + reduction=sg.ReductionOp.Replica) + + bias_layout = sg.outline.Full(reduction=sg.ReductionOp.Sum) + + output_layout = sg.outline.Full(reduction=sg.ReductionOp.Sum) super().__init__( input_layout=[inputs_layout, weight_layout, bias_layout], - output_layout=[output_layout,] + output_layout=[output_layout,], + input_format=[None, None, None], + output_format=[None, None, None] ) def forward(self, inputs, weight, bias): output = physic_op.linear(inputs, weight, bias) - return [output] + return [output,] kHolistLinearSets = [LinearColumnWeight(), LinearColumnInputRowWeight()] \ No newline at end of file diff --git a/cube/operator/logic/linear.py b/cube/operator/logic/linear.py index d4ab70ee..33c9b80a 100644 --- a/cube/operator/logic/linear.py +++ b/cube/operator/logic/linear.py @@ -15,11 +15,11 @@ def __init__(self): holist_op.set_logic_op(self) self.factory.register(holist_op) - def shape_infer(self, input_shape, weight_shape, bias_shape=None) + def shape_infer(self, input_shape, weight_shape, bias_shape=None): """ Return the outputs shape [list[int],] """ - output_shape = list(input_shape) - output_shape[-1] = weight_shape[-1] + output_shape = input_shape.shape + output_shape[-1] = weight_shape.shape[0] return [output_shape] diff --git a/cube/operator/physic/generics.py b/cube/operator/physic/generics.py index d52c7d3e..2ca3e3d0 100644 --- a/cube/operator/physic/generics.py +++ b/cube/operator/physic/generics.py @@ -16,6 +16,9 @@ def __init__(self, result, ranks): def get_result(self): return self.res + def __repr__(self): + return "OpResult(res={}, placement={})".format(self.res, self.placement) + class GenericPhysicOp: """ @@ -30,10 +33,14 @@ def __init__(self, func, placement=None): if not callable(func): raise TypeError("Expect callable function") - self.func = [func] - self._placement = placement + if not (isinstance(placement, list) or placement is None): + raise TypeError("Expected placement init with None or list[int]") + self.func = (func,) + self._placement = None self.execute_flag = False self.policy_fn = None + if isinstance(placement, list): + self.placement = placement @property def placement(self): diff --git a/cube/tensor/community.py b/cube/tensor/community.py index 375e2f3f..aaa8f58c 100644 --- a/cube/tensor/community.py +++ b/cube/tensor/community.py @@ -26,6 +26,7 @@ def __init__(self, segment): # connection to physical tensor (the PyTorch Tensor) self.physical_tensor = None + self.placement = list() self.group = list() self.materialized = False @@ -45,6 +46,9 @@ def deploy(self, ranks, logic_tensor, value_map_fn=None): return a new tensor """ + if not isinstance(ranks, list): + raise TypeError("Expected ranks in list[int]") + self.placement = ranks rank = DeviceGroup().rank self.group = DeviceGroup().get_group(ranks) if rank not in ranks: @@ -90,15 +94,13 @@ def get_physical_tensor(self): def set_physical_tensor(self, physical_tensor, ranks): if self.materialized: raise RuntimeError("Setting physical tensors to a materialized community") - if not isinstance(physical_tensor, torch.Tensor): - raise TypeError("physical_tensor: Expected a torch tensor") if not isinstance(ranks, list): raise TypeError("ranks: Expected a list[int]") - if physical_tensor.size() != torch.Size(self.segment.shape): - raise RuntimeError( - "Trying to set a community where physical tensor shape " - "doesn't match with segment shape") - #TODO: device check + if physical_tensor is not None: + if list(physical_tensor.size()) != list(self.segment.shape): + raise RuntimeError( + "Trying to set a community where physical tensor shape " + "doesn't match with segment shape") self.physical_tensor = physical_tensor self.group = DeviceGroup().get_group(ranks) self.materialized = True diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/segment/outline.py index 877fa3e0..568e8cce 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/segment/outline.py @@ -55,6 +55,13 @@ class BaseOutline: """ def __init__(self, reduction=None): self.reduction = reduction + # decide how to generate segmentation given the requirement + self.policy_fn = None + + def set_policy(self, policy_fn): + if not callable(policy_fn): + raise TypeError("Expected a function to take BaseOutline instance") + self.policy_fn = policy_fn def __setattr__(self, key, val): if key in self.__dict__: @@ -67,6 +74,10 @@ def __setattr__(self, key, val): else: self.__dict__[key] = ConstantContainer(val) + def __call__(self): + if self.policy_fn is not None: + self.policy_fn.get()(self) + class Full(BaseOutline): @@ -74,6 +85,8 @@ def __init__(self, reduction=None): super().__init__(reduction) def __call__(self, shape): + #TODO: super call seperate + super().__call__() segment = TileSegment([0] * len(shape), list(shape), self.reduction.get()) return [segment] @@ -111,7 +124,9 @@ def __call__(self, shape): Runtime segment generation given the logical tensor shape This is the policy that how to do the translation. - """ + """ + #TODO: super call seperate + super().__call__() segments = list() shape = list(shape) shape[self.axis.get()] = shape[self.axis.get()] // self.chunk_num.get() diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 6f013fc6..7e96383f 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -29,6 +29,16 @@ def match(self, communities, ranks=None, val_map_fns=None): raise TypeError("Expected ranks to be a list or None") if not isinstance(ranks, list): raise TypeError("Expected ranks to be a list or None") + if len(ranks) != len(communities): + raise RuntimeError( + "Un-matched length of communities ({}) : ranks ({})".format( + len(communities), len(ranks)) + ) + if len(val_map_fns) != len(communities): + raise RuntimeError( + "Un-matched length of communities ({}) : ranks ({})".format( + len(communities), len(ranks)) + ) #TODO: community matching and transformation if len(self.communities) == 0: @@ -49,7 +59,7 @@ def construct(shape, communities): tensor.set_community(community) return tensor - def get_physical_tensor(self, segment): + def get_physical_tensor(self, segment_or_index): """ Get physical tensor from the community. @@ -59,8 +69,13 @@ def get_physical_tensor(self, segment): Returns: torch.Tensor or None """ - community = self.communities[segment] - return community.get_physical_tensor() + return self.get_community(segment_or_index).get_physical_tensor() + + def __len__(self): + """ + Return community number + """ + return len(self.segments) def get_community(self, segment_or_index): """ @@ -72,10 +87,10 @@ def get_community(self, segment_or_index): Returns: Community """ - if isinstance(segment_or_index, DataSegment): - return self.communities[segment_or_index] - elif isinstance(segment_or_index, int): + if isinstance(segment_or_index, int): return self.communities[self.segments[segment_or_index]] + elif isinstance(segment_or_index, DataSegment): + return self.communities[segment_or_index] else: raise ValueError("Expected (derived) DataSegment to chooese Community") return self.communities[segment] diff --git a/tests/operator/test_holistic_linear.py b/tests/operator/test_holistic_linear.py new file mode 100644 index 00000000..b1539cba --- /dev/null +++ b/tests/operator/test_holistic_linear.py @@ -0,0 +1,90 @@ +""" +cmd for running the test + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + tests/operator/test_holistic_linear.py +""" + +from cube.tensor.logic.tensor import LogicalTensor +import cube.tensor.logic.segment as sg + +from cube.operator.holist.linear import LinearColumnWeight +from cube.operator.holist.linear import LinearColumnInputRowWeight + +from cube.device.physic.group import DeviceGroup + +import torch +torch.manual_seed(121) + + +class LogicalLinear: + + def __init__(self): pass + + def shape_infer(self, input_shape, weight_shape, bias_shape=None): + """ + Return the outputs shape [list[int],] + """ + #TODO: change all shape impl to list + output_shape = list(input_shape.shape) + output_shape[-1] = weight_shape.shape[0] + return [output_shape] + + +def test_holistic_linear_op_column_weight(): + + input = LogicalTensor(shape=(1024,1024)) + weight = LogicalTensor(shape=(1024,1024)) + bias = LogicalTensor(shape=(1024,)) + + holistic_op = LinearColumnWeight() + holistic_op.logical_op = LogicalLinear() + + # policy setup + def policy_for_how_many_tiles(outliner): + if isinstance(outliner, sg.outline.Full): + pass + elif isinstance(outliner, sg.outline.SplitAxis): + if outliner.chunk_num.get() is None: + outliner.chunk_num = 4 + else: + raise TypeError("Unhandled outliner type") + + def policy_for_each_tile_placement(community, input, weight, bias): + input_ranks = [ + [[0,1,2,3]], + [[0],[1],[2],[3]], + [[0],[1],[2],[3]] + ] + input_val_map_fns = list([None, None, None]) + return input_ranks, input_val_map_fns + + holistic_op.set_deploy_policy( + policy_for_each_tile_placement + ) + holistic_op.set_segmentation_policy( + policy_for_how_many_tiles + ) + + output = holistic_op(input, weight, bias) + + output_ref = torch._C._nn.linear(input.data.cuda(), weight.data.cuda(), bias.data.cuda()) + rank = DeviceGroup().rank + output_ref = torch.chunk(output_ref, chunks=4, dim=1)[rank].contiguous() + out = output.get_physical_tensor(rank) + if rank == 0: + print('ref: ', output_ref) + print('get: ', out) + print('sum: ', torch.sum(torch.abs(out - output_ref))) + print(torch.allclose(output.get_physical_tensor(rank), output_ref)) # is True + + +if __name__ == '__main__': + + test_holistic_linear_op_column_weight() \ No newline at end of file diff --git a/tests/operator/test_logical_op.py b/tests/operator/test_logical_op.py index 33cab4d9..fc8f60cb 100644 --- a/tests/operator/test_logical_op.py +++ b/tests/operator/test_logical_op.py @@ -44,6 +44,6 @@ def policy_fn(factory): if __name__ == '__main__': - test_factory + test_factory() test_generic_logical_op_init() test_generic_logical_op_register() \ No newline at end of file From 9bbf799df28db65dd7995a1f61b683facefeae82 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 13:15:32 +0000 Subject: [PATCH 0085/1892] POC to find hardware bias --- tests/operator/test_holistic_linear.py | 61 +++++++++++++++++++++----- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/tests/operator/test_holistic_linear.py b/tests/operator/test_holistic_linear.py index b1539cba..d4f3b83e 100644 --- a/tests/operator/test_holistic_linear.py +++ b/tests/operator/test_holistic_linear.py @@ -20,7 +20,7 @@ from cube.device.physic.group import DeviceGroup import torch -torch.manual_seed(121) +torch.manual_seed(100) class LogicalLinear: @@ -37,16 +37,44 @@ def shape_infer(self, input_shape, weight_shape, bias_shape=None): return [output_shape] -def test_holistic_linear_op_column_weight(): +def test_linear_POC(): + + N = 1024 + input = torch.randn((1024, 1024)).cuda() + weight = torch.randn((N, 1024)) + bias = torch.randn((N,)) + + rank = DeviceGroup().rank + # partial + partial_weight = torch.chunk(weight, 4, dim=0)[rank].cuda() + partial_bias = torch.chunk(bias, 4, dim=0)[rank].cuda() + partial_out = torch._C._nn.linear(input, partial_weight, partial_bias) + + # full + out_full = torch._C._nn.linear(input, weight.cuda(), bias.cuda()) + ref_out = torch.chunk(out_full, 4, dim=1)[rank].cuda() + + if rank == 0: + print('max bias: ', torch.max(torch.abs(partial_out - ref_out))) + print('sum bias: ', torch.sum(torch.abs(partial_out - ref_out))) + + +def test_holistic_linear_op_column_weight(): + """ + Note: Due to unknown reason in hardware, the output will have up to + 0.0001 bias. This is verified in `test_linear_POC()` The larger + K results larger bias. + """ + N = 1024 input = LogicalTensor(shape=(1024,1024)) - weight = LogicalTensor(shape=(1024,1024)) - bias = LogicalTensor(shape=(1024,)) + weight = LogicalTensor(shape=(N,1024)) + bias = LogicalTensor(shape=(N,)) holistic_op = LinearColumnWeight() holistic_op.logical_op = LogicalLinear() - # policy setup + # ================================ Policy =========================== def policy_for_how_many_tiles(outliner): if isinstance(outliner, sg.outline.Full): pass @@ -71,20 +99,29 @@ def policy_for_each_tile_placement(community, input, weight, bias): holistic_op.set_segmentation_policy( policy_for_how_many_tiles ) + # ================================ Policy =========================== output = holistic_op(input, weight, bias) - output_ref = torch._C._nn.linear(input.data.cuda(), weight.data.cuda(), bias.data.cuda()) + # =============================== Test ============================== + output_ref = torch._C._nn.linear( + input.data.cuda(), weight.data.cuda(), bias.data.cuda() + ) rank = DeviceGroup().rank output_ref = torch.chunk(output_ref, chunks=4, dim=1)[rank].contiguous() out = output.get_physical_tensor(rank) - if rank == 0: - print('ref: ', output_ref) - print('get: ', out) - print('sum: ', torch.sum(torch.abs(out - output_ref))) - print(torch.allclose(output.get_physical_tensor(rank), output_ref)) # is True + # if rank == 0: + # print('ref: ', output_ref) + # print('get: ', out) + # print('max bias: ', torch.max(torch.abs(out - output_ref))) + # print('sum bias: ', torch.sum(torch.abs(out - output_ref))) + error_max = torch.max(torch.abs(out - output_ref)) + assert error_max.item() < 2e-4 + # =============================== Test ============================== if __name__ == '__main__': - + group = DeviceGroup() + + # test_linear_POC() test_holistic_linear_op_column_weight() \ No newline at end of file From 6d98ef07055a11e60fbbf45ae4d10b0fdd209bb0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jul 2021 13:30:13 +0000 Subject: [PATCH 0086/1892] update env --- scripts/env-setup.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index 09fa35da..83dcf35e 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -1,5 +1,5 @@ -echo using docker image pytorch-cuda11.3: nvcr.io/nvidia/pytorch:21.04-py3 +echo using docker image pytorch-cuda11.3: nvcr.io/nvidia/pytorch:21.06-py3 git config --global core.editor "vim" git config --global user.name "Zhiqi Lin" @@ -20,4 +20,7 @@ echo 'export PATH=/usr/local/cuda/bin:$PATH' >> ~/.bashrc echo 'export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH' >> ~/.bashrc echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc +# cmd for count code lines +# find cube/ -name "*.py" -print0 | xargs -0 wc -l + python setup.py develop From ec4950969a856e70de54686dfb5e48f13987c93d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 28 Jul 2021 05:31:26 +0000 Subject: [PATCH 0087/1892] parallel primitives --- examples/case_study/parallel_primitive.py | 69 +++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 examples/case_study/parallel_primitive.py diff --git a/examples/case_study/parallel_primitive.py b/examples/case_study/parallel_primitive.py new file mode 100644 index 00000000..1c5bb13d --- /dev/null +++ b/examples/case_study/parallel_primitive.py @@ -0,0 +1,69 @@ +import torch +import os +from cube.device.physic.grou import DeviceGroup() + +torch.manual_seed(121) + + +def linear_tensor_parallel(inputs, weight, bias): + + rank = DeviceGroup().rank + + M = 1024 + K = 1024 + N = 1024 + + ### ============ Input Adapter ============ ### + + # select need to consider transformation from one segmentation to another + inputs_segment = select( + tensor = inputs, + indices = (slice(0, M), slice(0, K)), + shape = (M, K) + ) + + weight_segment = select( + tensor = weight, + indices = (slice(rank * (N // 4), (rank + 1) * (N // 4)), slice(0, K)), + shape = (N // 4, K) + ) + + bias_segment = select( + tensor = bias, + indices = (slice(rank * (N // 4), (rank + 1) * (N // 4)),), + shape = (N // 4,) + ) + + inputs = deploy( + segment = inputs_segment, + ranks = [0, 1, 2, 3], + val_map_op = IdentityForwardAllreduceBackward + ) + + weight = deploy( + segment = weight_segment, + ranks = [rank], + val_map_op = None + ) + + bias = deploy( + segment = bias_segment, + ranks = [rank], + val_map_op = None + ) + ### ============ Input Adapter ============ ### + + + ### ============ Compute ============ ### + output = torch._C._nn.linear(inputs, weight, bias) + ### ============ Compute ============ ### + + + ### ============ Output Adapter ============ ### + segment = recover( + tensor = output, + ranks = [0, 1, 2, 3], + reduction_op = AllGatherForwardSplitBackward + ) + # construct to logical tensor and return + ### ============ Output Adapter ============ ### From fd2ad711f3a4d46e748f6d87b5beed0bd23963d3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 28 Jul 2021 13:22:40 +0000 Subject: [PATCH 0088/1892] init with new primitives --- cube/tensor/indices.py | 72 ++++++++++++++++++++++++ cube/tensor/segment.py | 125 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 cube/tensor/indices.py create mode 100644 cube/tensor/segment.py diff --git a/cube/tensor/indices.py b/cube/tensor/indices.py new file mode 100644 index 00000000..45d19dbe --- /dev/null +++ b/cube/tensor/indices.py @@ -0,0 +1,72 @@ +""" +Basic structure for holding indices -> cover all the cases +""" + + +class BaseIndices: + """ + The basic primitive to gather data in the logical tensor. + + The order of indices indicate the physical storage (1-D array) order + """ + + def __init__(self, indices_list): + """ + Args: + indices_list (list[list[int],], tuple(slice(int, int),)): + indices list + """ + self.indices = tuple(indices_list) + + def get(self): + """ + Get indexable indices + """ + return tuple(self.indices) + + def reorder(self, new_orders): + """ + Reorder the indices. + + Note this can be only called before materialize physical tensors, + or called from underlying operation that will change physical storage format + + Args: + new_orders (iteratable): order of each index + """ + for dim in range(len(self.indices)): + self.indices[dim] = [self.indices[dim][idx] for idx in new_orders] + + def __repr__(self): + msg = 'BaseIndices(indices_len={})'.format( + len(self.indices), self.reduction + ) + + +class TileIndices(BaseIndices): + """ + A tile is a contigonous block on the logical tensor shape, + which can be represented as the start position + offset (shape) + """ + + def __init__(self, anchor, shape): + """ + Args: + anchor (list[int]): start position of the tile + offset (list[int]): offset (shape) of the tile + """ + indices = list() + for start, ofst in zip(self.anchor, self.shape): + indices.append(slice(start, start + ofst)) + super().__init__(tuple(indices)) + self.anchor = anchor + self.shape = shape + + def reorder(self): + raise NotImplementedError + + def __repr__(self): + msg = 'TileIndices(anchor={}, shape={})'.format( + self.anchor, self.shape + ) + return msg diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py new file mode 100644 index 00000000..a38ff7e5 --- /dev/null +++ b/cube/tensor/segment.py @@ -0,0 +1,125 @@ +import torch +from cube.device.physic.group import DeviceGroup + + +__all__ = ['Segment'] + + +class Segment: + + def __init__(self, logical_tensor, indices, shape): + """Create Segment based on the logical tensor + + Segment manages: + + 1). LogicalTensor indices mapping to Physical Tensor data storage + 2). Materialized Physical Tensor + + Attribute: + indices (tuple(slice,) or list[list[int]]): + indices of logical_tensor for this segment + deploy_op (None or callable): + deploy op to take logical value and map + merge_op (None or callable): + merge op to take physical tensor + """ + if not isinstance(logical_tensor, LogicalTensor): + raise TypeError("Expected logical_tensor to be LogicalTensor") + if not isinstance(indices, BaseIndices): + raise TypeError("Expected indices to be BaseIndices") + + # logical tensor + self.logical_tensor = logical_tensor + + # segment info + self.indices = indices + self.shape = shape + + # physical tensor (the PyTorch Tensor) + self.physical_tensor = None + + # deploy information + self.placement = list() + self.group = list() + self.deploy_op = None + self.materialized = False + + # recover op + self.merge_op = None + + def deploy(self, ranks, value_map_op=None): + """deploy (materialize) to physical tensors + + Materialize physical tensors for this community and spread out + based on the given device list. + + This offers policy module an interface to decide which devices + to spread. + + Argument: + ranks (list[int]): device id list + value_map_fn (callable): + takes the tensor, rank, world_size, + return a new tensor + """ + if not isinstance(ranks, list): + raise TypeError("Expected ranks in list[int]") + + rank = DeviceGroup().rank + self.placement = ranks + self.group = DeviceGroup().get_group(ranks) + if rank in ranks: + if self.logic_tensor.data is None: + # TODO: check overlap + self.physical_tensor = torch.randn(tuple(self.segment.shape), device='cuda') + else: + # select from logical data + self.physical_tensor = torch.empty(tuple(self.shape), device='cuda') + self.physical_tensor.copy_(self.logic_tensor.data[self.indices.get()]) + if value_map_op is not None: + self.physical_tensor.data = value_map_fn(self.physical_tensor) + self.materialized = True + + def recover(self, reduction_op): + """ + Recover the deployed physical tensors by reduction operation + + Each rank can call this even there is no physical tensor on it. + + Args: + reduction_op (callable): + inplacement update on physical tensor + + Returns: + None. The physical tensor will be updated to match logical data + """ + if self.materialized: + if self.physical_tensor is not None: + reduction_op(self.physical_tensor, group=self.group) + else: + raise RuntimeError("The Segment has not been materialized") + + def get_physical_tensor(self): + """Get physical tensor if materialized + + Returns: + PhysicalTensor (if materialized) + """ + if self.materialized: + return self.physical_tensor + else: + raise RuntimeError("The Segment has not been materialized") + + def set_physical_tensor(self, physical_tensor, ranks): + if self.materialized: + raise RuntimeError("Setting physical tensors to a materialized community") + if not isinstance(ranks, list): + raise TypeError("ranks: Expected a list[int]") + if physical_tensor is not None: + if list(physical_tensor.size()) != list(self.segment.shape): + raise RuntimeError( + "Trying to set a community where physical tensor shape " + "doesn't match with segment shape") + self.physical_tensor = physical_tensor + self.group = DeviceGroup().get_group(ranks) + self.materialized = True From 29aab653efe9e723e5fcc3ea1575a2fabe88f1e6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 28 Jul 2021 13:43:10 +0000 Subject: [PATCH 0089/1892] refine to use clean abstractions --- cube/operator/physic/comm/reduction.py | 31 ++++ cube/tensor/community.py | 106 ------------- cube/tensor/logic/{segment => }/outline.py | 35 +++-- cube/tensor/logic/segment/__init__.py | 5 - cube/tensor/logic/segment/segment.py | 164 --------------------- cube/tensor/logic/tensor.py | 1 + 6 files changed, 53 insertions(+), 289 deletions(-) create mode 100644 cube/operator/physic/comm/reduction.py delete mode 100644 cube/tensor/community.py rename cube/tensor/logic/{segment => }/outline.py (82%) delete mode 100644 cube/tensor/logic/segment/__init__.py delete mode 100644 cube/tensor/logic/segment/segment.py diff --git a/cube/operator/physic/comm/reduction.py b/cube/operator/physic/comm/reduction.py new file mode 100644 index 00000000..07fdbec8 --- /dev/null +++ b/cube/operator/physic/comm/reduction.py @@ -0,0 +1,31 @@ +from cube.operator.physic.comm import replicate, reduce_sum +import torch + + +# TODO: reduction op should be in torch autograd function +class _Reduction(type): + + # forward: all_reduce, backward: identity + Sum = (reduce_sum,) + + # forward: identity, backward: all_reduce + Replica = (replicate,) + + def register(cls, name, udf): + """ + Reduction functions should be in function format: + + Arguments: + PhysicalTensor + Communication Group + + Return: + PhysicalTensor + """ + if hasattr(cls, name): + raise KeyError("{} is registered".format(name)) + setattr(cls, name, (udf,)) + + +class ReductionOp(metaclass=_Reduction): + pass diff --git a/cube/tensor/community.py b/cube/tensor/community.py deleted file mode 100644 index aaa8f58c..00000000 --- a/cube/tensor/community.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -from cube.device.physic.group import DeviceGroup - - -__all__ = ['Community'] - - -class Community: - - def __init__(self, segment): - """Create Community based on the logical tensor - - Community manages one: - - 1). Logical Tensor data mapping to Physical Tensor data storage - 2). Materialized Physical Tensors - - Attribute: - segment (DataSegment): - indices of logical_tensor for this community - """ - # connection to logical tensor - # DataSegment to indicate both element set and data format mapping - self.segment = segment - self.reduction = segment.reduction - - # connection to physical tensor (the PyTorch Tensor) - self.physical_tensor = None - self.placement = list() - self.group = list() - self.materialized = False - - def deploy(self, ranks, logic_tensor, value_map_fn=None): - """deploy (materialize) to physical tensors - - Materialize physical tensors for this community and spread out - based on the given device list. - - This offers policy module an interface to decide which devices - to spread. - - Argument: - ranks (list[int]): device id list - value_map_fn (callable): - takes the tensor, rank, world_size, - return a new tensor - """ - - if not isinstance(ranks, list): - raise TypeError("Expected ranks in list[int]") - self.placement = ranks - rank = DeviceGroup().rank - self.group = DeviceGroup().get_group(ranks) - if rank not in ranks: - self.physical_tensor = None - else: - if logic_tensor.data is None: - # TODO: check overlap - self.physical_tensor = torch.randn(tuple(self.segment.shape), device='cuda') - else: - # select from cpu view - self.physical_tensor = torch.empty(tuple(self.segment.shape), device='cuda') - self.physical_tensor.copy_(logic_tensor.data[self.segment.get_indices()]) - if value_map_fn is not None: - self.physical_tensor.data = value_map_fn(self.physical_tensor) - self.materialized = True - - def sync(self): - """ - Synchrnoize the spread physical tensors by reduction operation - - This should be a out-placement device for differentiable communication ops. - - Each device should call this, including no-physical-tensor devices - """ - if self.materialized: - if self.physical_tensor is not None: - #TODO: elegant impl on calling reduction op - self.reduction[0](self.physical_tensor, group=self.group) - else: - raise RuntimeError("The Community has not been materialized to physical tensors") - - def get_physical_tensor(self): - """Get physical tensor if materialized - - Returns: - PhysicalTensor (if materialized) - """ - if self.materialized: - return self.physical_tensor - else: - raise RuntimeError("The Community has not been materialized to physical tensors") - - def set_physical_tensor(self, physical_tensor, ranks): - if self.materialized: - raise RuntimeError("Setting physical tensors to a materialized community") - if not isinstance(ranks, list): - raise TypeError("ranks: Expected a list[int]") - if physical_tensor is not None: - if list(physical_tensor.size()) != list(self.segment.shape): - raise RuntimeError( - "Trying to set a community where physical tensor shape " - "doesn't match with segment shape") - self.physical_tensor = physical_tensor - self.group = DeviceGroup().get_group(ranks) - self.materialized = True diff --git a/cube/tensor/logic/segment/outline.py b/cube/tensor/logic/outline.py similarity index 82% rename from cube/tensor/logic/segment/outline.py rename to cube/tensor/logic/outline.py index 568e8cce..e9f90c38 100644 --- a/cube/tensor/logic/segment/outline.py +++ b/cube/tensor/logic/outline.py @@ -10,7 +10,8 @@ to the real segmentation on given logical tensor shape. """ -from cube.tensor.logic.segment.segment import TileSegment, ReductionOp +from cube.tensor.segment import Segment +from cube.tensor.indices import TileIndices class MutableContainer: @@ -46,7 +47,6 @@ def set(self, val): # interface to setup restrictions on the segmentation - class BaseOutline: """ Basic class for declare outline @@ -74,20 +74,29 @@ def __setattr__(self, key, val): else: self.__dict__[key] = ConstantContainer(val) - def __call__(self): + def interpret(self, logical_tensor): + raise NotImplementedError + + def __call__(self, logical_tensor): + if not isinstance(logical_tensor, LogicalTensor): + raise TypeError("Expected logical_tensor is instance of LogicalTensor") + + #TODO: merge out to fuse in configurable space if self.policy_fn is not None: self.policy_fn.get()(self) + self.interpret(logical_tensor) + class Full(BaseOutline): def __init__(self, reduction=None): super().__init__(reduction) - def __call__(self, shape): - #TODO: super call seperate - super().__call__() - segment = TileSegment([0] * len(shape), list(shape), self.reduction.get()) + def interpret(self, logical_tensor): + shape = logical_tensor.shape + indices = TileIndices([0] * len(shape), shape) + segment = Segment(logical_tensor, indices, self.reduction.get()) return [segment] @@ -119,22 +128,20 @@ def __init__(self, axis, chunk_num=None, overlap=0, reduction=None, uniform=True self.uniform = uniform self.overlap = overlap - def __call__(self, shape): + def interpret(self, logical_tensor): """ Runtime segment generation given the logical tensor shape This is the policy that how to do the translation. """ - #TODO: super call seperate - super().__call__() segments = list() - shape = list(shape) + shape = list(logical_tensor.shape) shape[self.axis.get()] = shape[self.axis.get()] // self.chunk_num.get() anchor = [0] * len(shape) #TODO: support list of reductions for cid in range(self.chunk_num.get()): - segment = TileSegment( - list(anchor), list(shape), reduction=self.reduction.get()) - anchor[self.axis.get()] += shape[self.axis.get()] + indices = TileIndices(anchor, shape) + segment = Segment(logical_tensor, indices) segments.append(segment) + anchor[self.axis.get()] += shape[self.axis.get()] return segments diff --git a/cube/tensor/logic/segment/__init__.py b/cube/tensor/logic/segment/__init__.py deleted file mode 100644 index ad6e085a..00000000 --- a/cube/tensor/logic/segment/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from cube.tensor.logic.segment.segment import ReductionOp -from cube.tensor.logic.segment.segment import DataSegment, TileSegment -from cube.tensor.logic.segment.segment import create_from_indices, create_from_tiles - -from cube.tensor.logic.segment.outline import Full, SplitAxis \ No newline at end of file diff --git a/cube/tensor/logic/segment/segment.py b/cube/tensor/logic/segment/segment.py deleted file mode 100644 index 79711068..00000000 --- a/cube/tensor/logic/segment/segment.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -This is the runtime primitive sets to setup community for a logical tensor. -""" - -from cube.operator.physic.comm import replicate, reduce_sum -import torch - - -# TODO: reduction op should be in torch autograd function -class _Reduction(type): - - # forward: all_reduce, backward: identity - Sum = (reduce_sum,) - - # forward: identity, backward: all_reduce - Replica = (replicate,) - - def register(cls, name, udf): - """ - Reduction functions should be in function format: - - Arguments: - PhysicalTensor - Communication Group - - Return: - PhysicalTensor - """ - if hasattr(cls, name): - raise KeyError("{} is registered".format(name)) - setattr(cls, name, (udf,)) - - -class ReductionOp(metaclass=_Reduction): - pass - - -## Basic structure for holding a segment -> cover all the cases ## -class DataSegment: - """ - The basic primitive to gather data in the logical tensor. - - The order of indices indicate the physical storage (1-D array) order - """ - - def __init__(self, indices_list=None, shape=None, reduction=None): - """ - Args: - indices_list (list[ list[int], ]): - List of index - reduction (ReductionOp): - How to reduction to the logical value - shape: - shape on the indices list - """ - - self.indices = indices_list - if shape is None: - if indices_list is None: - raise RuntimeError("Provide shape if indices_list is empty") - self.shape = (len(indices_list[0]),) - else: - # TODO: check shape - self.shape = shape - self.reduction = reduction - - def get_indices(self): - """ - Convert to index list - """ - return tuple(self.indices) - - def reorder(self, new_orders): - """ - Reorder the indices. - - Note this can be only called before materialize physical tensors, - or called from underlying operation that will change physical storage format - - Args: - new_orders (iteratable): order of each index - """ - #TODO: check if materialized - for dim in range(len(self.indices)): - self.indices[dim] = [self.indices[dim][idx] for idx in new_orders] - - def __repr__(self): - msg = 'DataSegment(indices_len={}, reduction={})'.format( - len(self.indices), self.reduction - ) - - -## Higher structure to cover the most cases ## -class TileSegment(DataSegment): - """ - A tile is a contigonous block on the logical tensor shape, - which can be represented as the start position + offset (shape) - """ - - def __init__(self, anchor, shape, reduction=None): - """ - Args: - anchor (list[int]): start position of the tile - offset (list[int]): offset (shape) of the tile - """ - if len(anchor) != len(shape): - raise ValueError("Require anchor length to be equal with offset length") - super().__init__(shape=shape, reduction=reduction) - self.anchor = anchor - - def get_indices(self): - """ - Convert anchor and offset to index list - """ - indices = list() - for start, ofst in zip(self.anchor, self.shape): - indices.append(slice(start, start + ofst)) - return tuple(indices) - - def reorder(self): - pass - - def __repr__(self): - msg = 'TileSegment(anchor={}, shape={}, reduction={})'.format( - self.anchor, self.shape, self.reduction - ) - return msg - - -## Primitive sets for translation ## - -def create_from_indices(indices, shape, reduction): - """ - Create a data segment from indices, and format in shape. - The indices list will determine how data will be organized in - storage. - - Args: - indices (list[list[int]]): - Represent indices from logical tensor shape - len(indices) is the dimension, - e.g., index [3,4,5] and [2,7,9] is represented as - [[3,2], [4,7], [5,9]] - shape (tuple or list): - the segment shape - reduction (ReductionOp): - How to generate correct logical results from reduction op. - - Returns: - DataSegment instance - """ - return DataSegment(indices, shape, reduction) - - -def create_from_tiles(anchor, shape, reduction): - # segments = list() - # dims = len(offset) - # for dim_id in range(dims): - # indices = None # -> TODO: generate indices along the dim_id - # segment = create_from_indices(indices) - # segments.append(segment) - # segment = merge_segments(segments) - # return segment - return TileSegment(anchor, shape, reduction) \ No newline at end of file diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 7e96383f..964014c3 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -21,6 +21,7 @@ def __init__(self, shape, init_data=True): def match(self, communities, ranks=None, val_map_fns=None): """ Match the LogicalTensor with community list. + TODO: change name """ # type check ranks = [None] * len(communities) if ranks is None else ranks From 596f9af5410a87201bb215acbd4bd1800f679541 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Jul 2021 02:13:00 +0000 Subject: [PATCH 0090/1892] indices for base and tile --- cube/tensor/indices.py | 34 ++++++++++++++++++++++++++-- tests/tensor/test_indices.py | 43 ++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 tests/tensor/test_indices.py diff --git a/cube/tensor/indices.py b/cube/tensor/indices.py index 45d19dbe..88b1622d 100644 --- a/cube/tensor/indices.py +++ b/cube/tensor/indices.py @@ -18,6 +18,18 @@ def __init__(self, indices_list): """ self.indices = tuple(indices_list) + def ndim(self): + """ + Return dims of this indices + """ + return len(self.indices) + + def size(self): + """ + Return total number of index + """ + return len(self.indices[0]) + def get(self): """ Get indexable indices @@ -34,8 +46,11 @@ def reorder(self, new_orders): Args: new_orders (iteratable): order of each index """ - for dim in range(len(self.indices)): - self.indices[dim] = [self.indices[dim][idx] for idx in new_orders] + new_orders = list(new_orders) + indices = list(self.indices) + for dim in range(self.ndim()): + indices[dim] = [self.indices[dim][idx] for idx in new_orders] + self.indices = tuple(indices) def __repr__(self): msg = 'BaseIndices(indices_len={})'.format( @@ -56,11 +71,26 @@ def __init__(self, anchor, shape): offset (list[int]): offset (shape) of the tile """ indices = list() + size = 1 for start, ofst in zip(self.anchor, self.shape): indices.append(slice(start, start + ofst)) + size *= ofst super().__init__(tuple(indices)) self.anchor = anchor self.shape = shape + self.size = size + + def ndim(self): + """ + Return dims of this indices + """ + return len(self.indices) + + def size(self): + """ + Return total number of index + """ + return self.size def reorder(self): raise NotImplementedError diff --git a/tests/tensor/test_indices.py b/tests/tensor/test_indices.py new file mode 100644 index 00000000..42438837 --- /dev/null +++ b/tests/tensor/test_indices.py @@ -0,0 +1,43 @@ +from cube.tensor.indices import BaseIndices, TileIndices + +import torch + +def test_base_indices(): + + tensor = torch.randn((10, 10, 10)) + + # test init + sparse_indices = ( + [2,3,1,4], + [0,4,8,4], + [7,5,9,4] + ) + indices = BaseIndices(sparse_indices) + assert indices.indices == sparse_indices + + # test ndim + assert indices.ndim() == 3 + + # test size + assert indices.size() == 4 + + # test get + sub_tensor = tensor[indices.get()] + assert torch.allclose(sub_tensor, tensor[sparse_indices]) is True + + # test reorder + arg_order = [2, 1, 0, 3] + indices.reorder(arg_order) + sub_tensor = tensor[indices.get()] + + sparse_indices = ( + [1,3,2,4], + [8,4,0,4], + [9,5,7,4] + ) + ref_tensor = tensor[sparse_indices] + assert torch.allclose(sub_tensor, ref_tensor) is True + + +if __name__ == '__main__': + test_base_indices() From a1867afad3baf37636ac7ec2edb283ef5264c9b9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Jul 2021 02:24:39 +0000 Subject: [PATCH 0091/1892] tensor indices on tile --- cube/tensor/indices.py | 6 +++--- tests/tensor/test_indices.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/cube/tensor/indices.py b/cube/tensor/indices.py index 88b1622d..08862ecf 100644 --- a/cube/tensor/indices.py +++ b/cube/tensor/indices.py @@ -72,13 +72,13 @@ def __init__(self, anchor, shape): """ indices = list() size = 1 - for start, ofst in zip(self.anchor, self.shape): + for start, ofst in zip(anchor, shape): indices.append(slice(start, start + ofst)) size *= ofst super().__init__(tuple(indices)) self.anchor = anchor self.shape = shape - self.size = size + self.elenum = size def ndim(self): """ @@ -90,7 +90,7 @@ def size(self): """ Return total number of index """ - return self.size + return self.elenum def reorder(self): raise NotImplementedError diff --git a/tests/tensor/test_indices.py b/tests/tensor/test_indices.py index 42438837..37e6b419 100644 --- a/tests/tensor/test_indices.py +++ b/tests/tensor/test_indices.py @@ -39,5 +39,33 @@ def test_base_indices(): assert torch.allclose(sub_tensor, ref_tensor) is True +def test_tile_indices(): + + tensor = torch.randn((10, 10, 10)) + + anchor = [3,4,5] + ofst = [2,4,3] + indices = TileIndices(anchor, ofst) + assert indices.anchor == anchor + assert indices.shape == ofst + assert indices.elenum == 2 * 4 * 3 + + # test ndim + assert indices.ndim() == 3 + + # test size + assert indices.size() == 2 * 4 * 3 + + # test get + sub_tensor = tensor[indices.get()] + assert sub_tensor.size() == torch.Size(ofst) + ref_tensor = tensor[(slice(3,3+2), slice(4,4+4), slice(5,5+3))] + assert torch.allclose(sub_tensor, ref_tensor) is True + + # test reorder + ##TODO + + if __name__ == '__main__': test_base_indices() + test_tile_indices() \ No newline at end of file From 4bd465dfa7fd96391c8b97b524c9d24e5666314c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Jul 2021 03:18:32 +0000 Subject: [PATCH 0092/1892] segment test pass --- cube/tensor/logic/tensor.py | 28 +++---- cube/tensor/segment.py | 22 ++--- tests/tensor/test_segment.py | 157 ++++++++++++++++++++--------------- 3 files changed, 111 insertions(+), 96 deletions(-) diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 964014c3..8f6a8ffb 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -1,5 +1,5 @@ -from cube.tensor.community import Community -from cube.tensor.logic.segment.segment import DataSegment +from cube.tensor.segment import Segment +from cube.tensor.indices import BaseIndices class LogicalTensor: @@ -8,15 +8,19 @@ class LogicalTensor: """ def __init__(self, shape, init_data=True): - - self.shape = shape - # segment -> community - self.communities = dict() + self.shape = tuple(shape) self.segments = list() self.data = None if init_data: import torch self.data = torch.randn(shape).detach() + + def select(self, indices, shape): + """ + Create a Segment given the indices for this logical tensor, + and the Segment will use shape. + """ + self.segments.append(Segment(self, indices, shape)) def match(self, communities, ranks=None, val_map_fns=None): """ @@ -98,16 +102,10 @@ def get_community(self, segment_or_index): def __getitem__(self, key): """ - key: - if key is DataSegment, return community - ##TODO: DOUBLE CHECK - if key is slice, return new logical tensor + """ - if isinstance(key, DataSegment): - return self.get_community(key) - else: - ## TODO: should return logical tensor / views - return self.data[key] + # TODO: create new logical tensor / change layout + return self.data[key] def create_community(self, segment): """ diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py index a38ff7e5..1ee9bfad 100644 --- a/cube/tensor/segment.py +++ b/cube/tensor/segment.py @@ -1,8 +1,7 @@ -import torch from cube.device.physic.group import DeviceGroup +from cube.tensor.indices import BaseIndices - -__all__ = ['Segment'] +import torch class Segment: @@ -23,8 +22,6 @@ def __init__(self, logical_tensor, indices, shape): merge_op (None or callable): merge op to take physical tensor """ - if not isinstance(logical_tensor, LogicalTensor): - raise TypeError("Expected logical_tensor to be LogicalTensor") if not isinstance(indices, BaseIndices): raise TypeError("Expected indices to be BaseIndices") @@ -33,14 +30,14 @@ def __init__(self, logical_tensor, indices, shape): # segment info self.indices = indices - self.shape = shape + self.shape = tuple(shape) # physical tensor (the PyTorch Tensor) self.physical_tensor = None # deploy information self.placement = list() - self.group = list() + self.group = None self.deploy_op = None self.materialized = False @@ -58,9 +55,8 @@ def deploy(self, ranks, value_map_op=None): Argument: ranks (list[int]): device id list - value_map_fn (callable): - takes the tensor, rank, world_size, - return a new tensor + value_map_op (callable): + takes the tensor, return a new tensor """ if not isinstance(ranks, list): raise TypeError("Expected ranks in list[int]") @@ -69,15 +65,15 @@ def deploy(self, ranks, value_map_op=None): self.placement = ranks self.group = DeviceGroup().get_group(ranks) if rank in ranks: - if self.logic_tensor.data is None: + if self.logical_tensor.data is None: # TODO: check overlap self.physical_tensor = torch.randn(tuple(self.segment.shape), device='cuda') else: # select from logical data self.physical_tensor = torch.empty(tuple(self.shape), device='cuda') - self.physical_tensor.copy_(self.logic_tensor.data[self.indices.get()]) + self.physical_tensor.copy_(self.logical_tensor.data[self.indices.get()]) if value_map_op is not None: - self.physical_tensor.data = value_map_fn(self.physical_tensor) + self.physical_tensor.data = value_map_op(self.physical_tensor) self.materialized = True def recover(self, reduction_op): diff --git a/tests/tensor/test_segment.py b/tests/tensor/test_segment.py index a9a9e371..e16a1506 100644 --- a/tests/tensor/test_segment.py +++ b/tests/tensor/test_segment.py @@ -1,94 +1,115 @@ -import cube.tensor.logic.segment as segment -import torch - +""" +cmd for running the test + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + tests/tensor/test_segment.py +""" + +from cube.tensor.logic.tensor import LogicalTensor +from cube.tensor.segment import Segment +from cube.tensor.indices import BaseIndices, TileIndices +from cube.device.physic.group import DeviceGroup -def test_reduction_op_register(): +import torch +import os +torch.manual_seed(121) - def reduce_fn(physical_tensor, group): - return physical_tensor - segment.ReductionOp.register("ReduceSum", reduce_fn) - # segment.ReductionOp.register("Replica", reduce_fn) +def test_segment_init(): - tensor = torch.randn((3,4)) - out = segment.ReductionOp.ReduceSum[0](tensor, None) - assert out is tensor + tensor = LogicalTensor((10,10,10)) + anchor = [3,4,5] + ofst = [2,4,3] + indices = TileIndices(anchor, ofst) -## TODO: test all the provided reduction op -def test_reduction_op_replica(): - #TODO: check correctness - assert callable(segment.ReductionOp.Replica[0]) + segment = Segment(tensor, indices, ofst) + assert segment.logical_tensor is tensor + assert segment.shape == tuple(ofst) + assert segment.physical_tensor is None + assert len(segment.placement) == 0 + assert segment.group is None + assert segment.deploy_op is None + assert segment.materialized is False + assert segment.merge_op is None -def test_data_segment_init(): - tensor = torch.randn((10,10,10)) - indices = [[5,3,2,4], - [1,2,7,4], - [3,4,5,4]] - seg = segment.DataSegment( - indices, shape=(4,1), reduction=segment.ReductionOp.Replica) - assert seg.indices == indices - assert seg.shape == (4,1) - assert seg.reduction == segment.ReductionOp.Replica +def test_segment_deploy(): + myrank = DeviceGroup().rank + tensor = LogicalTensor((10,10,10)) -def test_data_segment_get_indices(): + anchor = [3,4,5] + ofst = [2,4,3] + indices = TileIndices(anchor, ofst) - tensor = torch.randn((10,10,10)) - indices = [[5,3,2,4], - [1,2,7,4], - [3,4,5,4]] - seg = segment.DataSegment( - indices, shape=(4,1), reduction=segment.ReductionOp.Replica) - sub_tensor = tensor[seg.get_indices()] - assert sub_tensor.size() == torch.Size([4]) + segment = Segment(tensor, indices, ofst) + ranks = [0,2] + segment.deploy(ranks, value_map_op=None) -def test_data_segment_reorder(): + physical_tensor = segment.get_physical_tensor() + tensor_ref = tensor.data[indices.get()].cuda() + if myrank in ranks: + assert physical_tensor.device == torch.device('cuda:{}'.format(myrank)) + assert torch.allclose(physical_tensor, tensor_ref) + else: + assert physical_tensor is None + assert segment.placement == ranks + assert segment.group == DeviceGroup().get_group(ranks) + assert segment.deploy_op is None + assert segment.materialized is True + assert segment.merge_op is None - tensor = torch.randn((10,10,10)) - indices = [[5,3,2,4], - [1,2,7,4], - [3,4,5,4]] - seg = segment.DataSegment( - indices, shape=(4,1), reduction=segment.ReductionOp.Replica) - sub_tensor = tensor[seg.get_indices()] - seg.reorder([2,3,1,0]) - ref_tensor = sub_tensor[([2,3,1,0])] - check_tensor = tensor[seg.get_indices()] - assert torch.all(torch.eq(ref_tensor, check_tensor)) +def test_segment_recover(): + myrank = DeviceGroup().rank + tensor = LogicalTensor((10,10,10)) -def test_tile_segment_init(): + anchor = [3,4,5] + ofst = [2,4,3] + indices = TileIndices(anchor, ofst) - tensor = torch.randn((10,10,10)) - seg = segment.TileSegment( - anchor=(2,3,1), shape=(4,4,4), reduction=segment.ReductionOp.Replica) - assert seg.shape == (4,4,4) - assert seg.anchor == (2,3,1) - assert seg.reduction == segment.ReductionOp.Replica + segment = Segment(tensor, indices, ofst) + ranks = [0,2] + segment.deploy(ranks, value_map_op=lambda tensor: tensor / 2) -def test_tile_segment_get_indices(): + # deploy check + physical_tensor = segment.get_physical_tensor() + tensor_ref = tensor.data[indices.get()].cuda() / 2 + if myrank in [0,2]: + assert physical_tensor.device == torch.device('cuda:{}'.format(myrank)) + assert torch.allclose(physical_tensor, tensor_ref) is True + else: + assert physical_tensor is None - tensor = torch.randn((10,10,10)) - seg = segment.TileSegment( - anchor=(2,3,1), shape=(4,4,4), reduction=segment.ReductionOp.Replica) - ref_tensor = tensor[(slice(2,2+4), slice(3,3+4), slice(1,1+4))] - sub_tensor = tensor[seg.get_indices()] - assert sub_tensor.size() == torch.Size([4,4,4]) - assert torch.all(torch.eq(ref_tensor, sub_tensor)) + # recover to get logical value + def reduction_op(tensor, group): + torch.distributed.all_reduce(tensor, group=group) + segment.recover(reduction_op=reduction_op) + physical_tensor = segment.get_physical_tensor() + + tensor_ref = tensor.data[indices.get()].cuda() + if myrank in [0,2]: + assert physical_tensor.device == torch.device('cuda:{}'.format(myrank)) + assert torch.allclose(physical_tensor, tensor_ref) is True + else: + assert physical_tensor is None if __name__ == '__main__': - test_reduction_op_register() - test_reduction_op_replica() - test_data_segment_init() - test_data_segment_get_indices() - test_data_segment_reorder() - test_tile_segment_init() - test_tile_segment_get_indices() \ No newline at end of file + group = DeviceGroup() + + test_segment_init() + test_segment_deploy() + test_segment_recover() From da07e965d07e705aecd2ea8ff203921b233daef5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Jul 2021 05:50:31 +0000 Subject: [PATCH 0093/1892] logical tensor test --- cube/tensor/logic/tensor.py | 170 ++++++++++++++-------------- cube/tensor/segment.py | 6 +- tests/tensor/test_community.py | 112 ------------------ tests/tensor/test_logical_tensor.py | 102 ++++++++++++++--- 4 files changed, 174 insertions(+), 216 deletions(-) delete mode 100644 tests/tensor/test_community.py diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 8f6a8ffb..702063ff 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -8,73 +8,80 @@ class LogicalTensor: """ def __init__(self, shape, init_data=True): + """ + Create an empty logical tensor with no segmentations + + Args: + shape (tuple[int] or list[int]): + shape of the tensor + init_data (Boolean): + if True, init a CPU data. Otherwise no data initialized. + """ self.shape = tuple(shape) self.segments = list() self.data = None if init_data: import torch self.data = torch.randn(shape).detach() + + def fill(self, physical_tensors, ranks): + """ + Construct the logical tensor with physical tensors. + + Args: + physical_tensors (list[PhysicalTensor, None]): + the list length should be equal to len(self.segments) + ranks (list[list[int],]): + each segment will pair with a list of ranks + """ + if self.data is not None: + raise RuntimeError("Only allowed fill physical tensors when data is not None") + for segment, physical_tensor, ranks in zip(self.segments, physical_tensors, ranks): + segment.set_physical_tensor(physical_tensor, ranks) def select(self, indices, shape): """ Create a Segment given the indices for this logical tensor, and the Segment will use shape. """ - self.segments.append(Segment(self, indices, shape)) - - def match(self, communities, ranks=None, val_map_fns=None): - """ - Match the LogicalTensor with community list. - TODO: change name - """ - # type check - ranks = [None] * len(communities) if ranks is None else ranks - val_map_fns = [None] * len(communities) if val_map_fns is None else val_map_fns - if not isinstance(ranks, list): - raise TypeError("Expected ranks to be a list or None") - if not isinstance(ranks, list): - raise TypeError("Expected ranks to be a list or None") - if len(ranks) != len(communities): - raise RuntimeError( - "Un-matched length of communities ({}) : ranks ({})".format( - len(communities), len(ranks)) - ) - if len(val_map_fns) != len(communities): - raise RuntimeError( - "Un-matched length of communities ({}) : ranks ({})".format( - len(communities), len(ranks)) - ) + segment = Segment(self, indices, shape) + return segment + + def transform(self, segments, ranks=None, val_map_ops=None): + """ + Transform the LogicalTensor with community list. + TODO: check if this should create a new logical tensor + """ + if not (isinstance(ranks, list) and len(ranks) == len(segments)): + raise ValueError("Expected ranks to be a list with equal length of segments") + if not (isinstance(ranks, list) and len(val_map_ops) == len(segments)): + raise ValueError("Expected ranks to be a list with equal length of segments") - #TODO: community matching and transformation - if len(self.communities) == 0: - for cid in range(len(communities)): - community = communities[cid] - self.set_community(community) - if not community.materialized: - rank_list = ranks[cid] - val_map_fn = val_map_fns[cid] - community.deploy(rank_list, self, val_map_fn) + if len(self.segments) == 0: + for sid in range(len(segments)): + segment = segments[sid] + self.add_segment(segment) + if not segment.materialized: + deploy_ranks = ranks[sid] + if not isinstance(deploy_ranks, list): + raise TypeError('Expected ranks to be list[list[int],]') + deploy_ops = val_map_ops[sid] + segment.deploy(deploy_ranks, deploy_ops) + #TODO: segment transformation on existing segments else: raise NotImplementedError - - @staticmethod - def construct(shape, communities): - tensor = LogicalTensor(shape=shape, init_data=False) - for community in communities: - tensor.set_community(community) - return tensor - - def get_physical_tensor(self, segment_or_index): + + def get_physical_tensor(self, index): """ - Get physical tensor from the community. + Get physical tensor from the segment. Args: - idx: index for community + idx: index for segment Returns: torch.Tensor or None """ - return self.get_community(segment_or_index).get_physical_tensor() + return self.get_segment(index).get_physical_tensor() def __len__(self): """ @@ -82,24 +89,6 @@ def __len__(self): """ return len(self.segments) - def get_community(self, segment_or_index): - """ - Get Community based on the segment - - Args: - segment_or_index (DataSegment or int): - - Returns: - Community - """ - if isinstance(segment_or_index, int): - return self.communities[self.segments[segment_or_index]] - elif isinstance(segment_or_index, DataSegment): - return self.communities[segment_or_index] - else: - raise ValueError("Expected (derived) DataSegment to chooese Community") - return self.communities[segment] - def __getitem__(self, key): """ @@ -107,36 +96,43 @@ def __getitem__(self, key): # TODO: create new logical tensor / change layout return self.data[key] - def create_community(self, segment): - """ - Create a community by given the segment + def get_segment(self, idx): """ - if segment in self.communities: - raise KeyError("The segment already exists") - self.communities[segment] = Community(segment) - self.segments.append(segment) - - def set_community(self, community): + Get a segment using index + + Args: + idx (int): index to segment list + + Returns: + Segment """ - Set a community + return self.segments[idx] - Warning: if there is a segment in this tensor that matches - with the given community's segment, the original community - will be overrided + def add_segment(self, segment): """ - if not isinstance(community, Community): - raise TypeError("Expected a community") - segment = community.segment - if segment not in self.communities: - self.segments.append(segment) - self.communities[segment] = community + Add a segment. - def remove_community(self, segment): + Note adding a segment will change the segment parent logical tensor + to this tensor + """ + if not isinstance(segment, Segment): + raise TypeError("Expected a segment") + segment.logical_tensor = self + if segment in self.segments: + raise RuntimeError("Segment is already added") + self.segments.append(segment) + + def remove_segment(self, segment_or_index): """ Remove a community by given the segment """ #TODO: check whether a sync-back is needed - if segment not in self.communities: - raise KeyError("The segment doesn't exist") - del self.communities[segment] - self.segments.remove(segment) + if isinstance(segment_or_index, Segment): + if segment not in self.segments: + raise KeyError("The segment doesn't exist") + self.segments.remove(segment) + elif isinstance(segment_or_index, int): + del self.segments[segment_or_index] + else: + raise ValueError("Expected Segment instance or index int") + diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py index 1ee9bfad..2f7f7504 100644 --- a/cube/tensor/segment.py +++ b/cube/tensor/segment.py @@ -71,7 +71,9 @@ def deploy(self, ranks, value_map_op=None): else: # select from logical data self.physical_tensor = torch.empty(tuple(self.shape), device='cuda') - self.physical_tensor.copy_(self.logical_tensor.data[self.indices.get()]) + self.physical_tensor.copy_( + self.logical_tensor.data[self.indices.get()].reshape(self.shape) + ) if value_map_op is not None: self.physical_tensor.data = value_map_op(self.physical_tensor) self.materialized = True @@ -112,7 +114,7 @@ def set_physical_tensor(self, physical_tensor, ranks): if not isinstance(ranks, list): raise TypeError("ranks: Expected a list[int]") if physical_tensor is not None: - if list(physical_tensor.size()) != list(self.segment.shape): + if list(physical_tensor.size()) != list(self.shape): raise RuntimeError( "Trying to set a community where physical tensor shape " "doesn't match with segment shape") diff --git a/tests/tensor/test_community.py b/tests/tensor/test_community.py deleted file mode 100644 index ced08a34..00000000 --- a/tests/tensor/test_community.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -cmd for running the test - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - tests/tensor/test_community.py -""" - -from cube.tensor.community import Community -import cube.tensor.logic.segment as segment -from cube.device.physic.group import DeviceGroup - -import torch -import os -torch.manual_seed(121) - - -def test_community_init(): - - tensor = torch.randn((10,10,10)) - seg = segment.TileSegment( - anchor=(2,3,1), shape=(4,4,4), reduction=segment.ReductionOp.Replica) - community = Community(seg) - - assert community.segment == seg - assert community.physical_tensor is None - assert len(community.group) == 0 - assert community.materialized is False - - -def test_community_deploy(): - - tensor = torch.randn((10,10,10)) - seg = segment.TileSegment( - anchor=(2,3,1), shape=(4,4,4), - reduction=segment.ReductionOp.Replica) - community = Community(seg) - - # policy for scaling out - # using torch.Tensor to test - ranks = [0,2] - community.deploy(ranks, tensor, None) - - # check - myrank = DeviceGroup().rank - if myrank not in ranks: - assert community.physical_tensor is None - else: - sub_tensor = community.physical_tensor - assert torch.is_tensor(sub_tensor) - assert sub_tensor.size() == torch.Size([4,4,4]) - assert sub_tensor.device == torch.device('cuda:{}'.format(myrank)) - assert torch.all(torch.eq(sub_tensor.cpu(), tensor[seg.get_indices()])) - assert torch.distributed.get_world_size(community.group) == 2 - - -def test_community_sync(): - tensor = torch.randn((10,10,10)) - seg = segment.TileSegment( - anchor=(2,3,1), shape=(4,4,4), - reduction=segment.ReductionOp.Sum) - community = Community(seg) - - # deploy with value modification - ranks = [0,2] - community.deploy(ranks, tensor, - value_map_fn=lambda tensor: tensor / 2) - - # check - sub_tensor = community.get_physical_tensor() - ref_tensor = tensor[seg.get_indices()].cuda() - myrank = DeviceGroup().rank - if myrank in ranks: - assert torch.all(torch.eq(sub_tensor, ref_tensor / 2)) - - # sync to get logical value - community.sync() - sub_tensor = community.get_physical_tensor() - if myrank not in ranks: - assert sub_tensor is None - else: - # print('ref: {}'.format(ref_tensor)) - assert torch.allclose(sub_tensor, ref_tensor) is True - - -def test_community_set_physical_tensor(): - seg = segment.TileSegment( - anchor=(2,3,1), shape=(4,4,4), - reduction=segment.ReductionOp.Sum) - community = Community(seg) - - tensor = torch.randn((4,4,4)) - community.set_physical_tensor(tensor, [0,1,2]) - assert community.materialized is True - assert community.group == DeviceGroup().get_group([0,1,2]) - assert community.physical_tensor is tensor - - -if __name__ == '__main__': - - group = DeviceGroup() - torch.distributed.barrier() - - test_community_init() - test_community_deploy() - test_community_sync() - test_community_set_physical_tensor() \ No newline at end of file diff --git a/tests/tensor/test_logical_tensor.py b/tests/tensor/test_logical_tensor.py index dbc871ee..d6483824 100644 --- a/tests/tensor/test_logical_tensor.py +++ b/tests/tensor/test_logical_tensor.py @@ -1,30 +1,102 @@ +""" +cmd for running the test + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + tests/tensor/test_logical_tensor.py +""" + +from cube.tensor.indices import BaseIndices from cube.tensor.logic.tensor import LogicalTensor -from cube.tensor.community import Community -import cube.tensor.logic.segment as segment +from cube.tensor.segment import Segment +from cube.device.physic.group import DeviceGroup + +import torch def test_logical_tensor_init(): - #TODO - pass + tensor = LogicalTensor(shape=(10,10,10)) + assert tensor.shape == (10, 10, 10) + assert len(tensor.segments) == 0 + assert tensor.data is not None + assert tensor.data.size() == torch.Size([10,10,10]) + + +def test_logical_tensor_select(): + tensor = LogicalTensor(shape=(10,10,10)) + sparse_indices = ( + [2,3,1,4], + [0,4,8,4], + [7,5,9,4] + ) + indices = BaseIndices(sparse_indices) + segment = tensor.select(indices, shape=(2,2)) + assert isinstance(segment, Segment) + assert segment.materialized is False + +def test_logical_tensor_fill(): -def test_logical_tensor_construct(): + myrank = DeviceGroup().rank - seg = segment.TileSegment( - anchor=(2,3,1), shape=(4,4,4), - reduction=segment.ReductionOp.Replica) - community = Community(seg) + tensor = LogicalTensor(shape=(10,10,10), init_data=False) + sparse_indices = ( + [2,3,1,4], + [0,4,8,4], + [7,5,9,4] + ) + indices = BaseIndices(sparse_indices) + segment = tensor.select(indices, shape=(2,2)) + tensor.add_segment(segment) - logical_tensor = LogicalTensor.construct((10,10,10), [community]) + assert segment.materialized is False + assert len(tensor.segments) == 1 - assert isinstance(logical_tensor, LogicalTensor) - assert len(logical_tensor.communities) == 1 - assert logical_tensor.get_community(0) is community - assert logical_tensor.shape == (10,10,10) + ranks = [1, 3] + if myrank in ranks: + phy_tensor = torch.randn((2,2)).cuda() + else: + phy_tensor = None + tensor.fill([phy_tensor], [ranks]) + assert segment.materialized is True + if myrank in ranks: + assert tensor.get_physical_tensor(0) is not None + else: + assert tensor.get_physical_tensor(0) is None + + +def test_logical_tensor_transform(): + + tensor = LogicalTensor(shape=(10,10,10)) + sparse_indices = ( + [2,3,1,4], + [0,4,8,4], + [7,5,9,4] + ) + indices = BaseIndices(sparse_indices) + segment = tensor.select(indices, shape=(2,2)) + + ranks = [0,1,3] + tensor.transform([segment], [ranks], [None]) + + myrank = DeviceGroup().rank + if myrank in ranks: + assert tensor.get_physical_tensor(0) is not None + else: + assert tensor.get_physical_tensor(0) is None if __name__ == '__main__': + group = DeviceGroup() + test_logical_tensor_init() - test_logical_tensor_construct() \ No newline at end of file + test_logical_tensor_select() + test_logical_tensor_fill() + test_logical_tensor_transform() \ No newline at end of file From 23fc822a61caa7e55b6fb066133752038ba9ec75 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Jul 2021 07:01:17 +0000 Subject: [PATCH 0094/1892] init condition container --- cube/config/__init__.py | 0 cube/config/container.py | 92 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 cube/config/__init__.py create mode 100644 cube/config/container.py diff --git a/cube/config/__init__.py b/cube/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/config/container.py b/cube/config/container.py new file mode 100644 index 00000000..5a1a3045 --- /dev/null +++ b/cube/config/container.py @@ -0,0 +1,92 @@ + +class ConditionContainer: + + def __init__(self, satisfy_fn): + if not callable(satisfy_fn): + raise TypeError("Expected function") + self._satisfy_fn = satisfy_fn + self._val = None + self._choices = None + self._lock = False + + def get(self): + """ + Get the current set value (default None if not set) + """ + return self._val + + def set(self, val): + """ + Set the value, will raise ValueError if not satisfy + """ + if self._lock: + raise RuntimeError("Try to set a locked config") + if self._choices is not None: + if not self.satisfy(val, self._choices): + raise ValueError("Fail to set config") + self._val = val + + def lock(self): + """ + Lock the value, will not allow change + """ + self._lock = True + + def satisfy(self, val): + """ + Check whether the value satisfy the choices + + Returns: + True if satisfy, False not + """ + return self._satisfy_fn(val, self._choices) + + def choices(self): + """ + Return choices. + + Use list(container.choices) to see all the choices + """ + return self._choices + + def reset(self, choices): + """ + Reset choices + """ + self._val = None + self._choices = choices + + +class ChoiceContainer(ConditionContainer): + + def __init__(self, choices): + """ + Create a choice container, the value can only be + the item in the choices. + + choices (iterable): + list or range + """ + def satisfy_fn(val, choices): + return val in choices + super().__init__(satisfy_fn) + self._choices = choices + + +class TypeContainer(ConditionContainer): + + def __init__(self, type_choices): + """ + Create a type container, the value can only be + the instance of the type in the choices. + + type_choices (iterable): + usually a list[type] + """ + def satisfy_fn(val, choices): + for t in choices: + if isinstance(val, t): + return True + return False + super().__init__(satisfy_fn) + self._choices = type_choices From 03b8577682392ce0b29206a60b8fad24b2274b57 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Jul 2021 14:16:29 +0000 Subject: [PATCH 0095/1892] select needs to select with val map op, and deploy now only allows replica --- cube/tensor/logic/outline.py | 55 +++++++++-------------------- cube/tensor/logic/tensor.py | 13 +++++-- cube/tensor/segment.py | 36 +++++++++++-------- tests/tensor/test_logical_tensor.py | 6 ++-- tests/tensor/test_segment.py | 37 ++++++++----------- 5 files changed, 66 insertions(+), 81 deletions(-) diff --git a/cube/tensor/logic/outline.py b/cube/tensor/logic/outline.py index e9f90c38..76ebf21c 100644 --- a/cube/tensor/logic/outline.py +++ b/cube/tensor/logic/outline.py @@ -1,6 +1,6 @@ """ This is the description interface to describe the -segementation requirement (restrictions). +segmentation requirement (restrictions). The description includes two parts: @@ -14,37 +14,6 @@ from cube.tensor.indices import TileIndices -class MutableContainer: - - def __init__(self, scope): - self.__val = None - self.__scope = scope - - def get(self, scope=False): - if scope: - return self.__scope - else: - return self.__val - - def set(self, val): - if self.__scope is not None: - if val not in self.__scope: - raise ValueError("Fail to set container, out of range") - self.__val = val - - -class ConstantContainer: - - def __init__(self, val): - self.__val = val - - def get(self): - return self.__val - - def set(self, val): - raise RuntimeError("Cannot set a ConstantContainer") - - # interface to setup restrictions on the segmentation class BaseOutline: @@ -53,7 +22,7 @@ class BaseOutline: To setup an attribute (requirement), use `inst_baseoutline.attribute_name = val` """ - def __init__(self, reduction=None): + def __init__(self): self.reduction = reduction # decide how to generate segmentation given the requirement self.policy_fn = None @@ -88,10 +57,10 @@ def __call__(self, logical_tensor): self.interpret(logical_tensor) -class Full(BaseOutline): +class Full(ConfigTemplate): - def __init__(self, reduction=None): - super().__init__(reduction) + def __init__(self): + pass def interpret(self, logical_tensor): shape = logical_tensor.shape @@ -100,9 +69,9 @@ def interpret(self, logical_tensor): return [segment] -class SplitAxis(BaseOutline): +class SplitAxis(ConfigTemplate): - def __init__(self, axis, chunk_num=None, overlap=0, reduction=None, uniform=True): + def __init__(self, axis, chunk_num=None, overlap=0, uniform=True): """ Segmentation Pattern Requirement (parameters): @@ -122,7 +91,7 @@ def __init__(self, axis, chunk_num=None, overlap=0, reduction=None, uniform=True if a tuple(min, max), the overlap size wihtin the scope [min,max] is valid """ - super().__init__(reduction) + super().__init__() self.axis = axis self.chunk_num = chunk_num self.uniform = uniform @@ -145,3 +114,11 @@ def interpret(self, logical_tensor): segments.append(segment) anchor[self.axis.get()] += shape[self.axis.get()] return segments + + +class SplitValue(ConfigTemplate): + + def __init__(self, chunk_num=None, val_map_op=None): + ##TODO + self.chunk_num = chunk_num + self.val_map_op = val_map_op diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 702063ff..fb132ceb 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -39,12 +39,12 @@ def fill(self, physical_tensors, ranks): for segment, physical_tensor, ranks in zip(self.segments, physical_tensors, ranks): segment.set_physical_tensor(physical_tensor, ranks) - def select(self, indices, shape): + def select(self, indices, val_map_op, shape): """ Create a Segment given the indices for this logical tensor, and the Segment will use shape. """ - segment = Segment(self, indices, shape) + segment = Segment(self, indices, val_map_op, shape) return segment def transform(self, segments, ranks=None, val_map_ops=None): @@ -66,7 +66,7 @@ def transform(self, segments, ranks=None, val_map_ops=None): if not isinstance(deploy_ranks, list): raise TypeError('Expected ranks to be list[list[int],]') deploy_ops = val_map_ops[sid] - segment.deploy(deploy_ranks, deploy_ops) + segment.deploy(deploy_ranks) #TODO: segment transformation on existing segments else: raise NotImplementedError @@ -136,3 +136,10 @@ def remove_segment(self, segment_or_index): else: raise ValueError("Expected Segment instance or index int") + def merge_segment(self, indices, reduction_op): + """ + Merge segments for the logical tensor + + The merged segments will be placed at the end of the list. + """ + raise NotImplementedError diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py index 2f7f7504..1360f5b5 100644 --- a/cube/tensor/segment.py +++ b/cube/tensor/segment.py @@ -6,7 +6,7 @@ class Segment: - def __init__(self, logical_tensor, indices, shape): + def __init__(self, logical_tensor, indices, val_map_op, shape): """Create Segment based on the logical tensor Segment manages: @@ -17,7 +17,7 @@ def __init__(self, logical_tensor, indices, shape): Attribute: indices (tuple(slice,) or list[list[int]]): indices of logical_tensor for this segment - deploy_op (None or callable): + val_map_op (None or callable): deploy op to take logical value and map merge_op (None or callable): merge op to take physical tensor @@ -30,6 +30,8 @@ def __init__(self, logical_tensor, indices, shape): # segment info self.indices = indices + self.val_map_ops = list() + self.add_val_map_op(val_map_op) self.shape = tuple(shape) # physical tensor (the PyTorch Tensor) @@ -38,13 +40,12 @@ def __init__(self, logical_tensor, indices, shape): # deploy information self.placement = list() self.group = None - self.deploy_op = None self.materialized = False # recover op self.merge_op = None - def deploy(self, ranks, value_map_op=None): + def deploy(self, ranks): """deploy (materialize) to physical tensors Materialize physical tensors for this community and spread out @@ -66,16 +67,14 @@ def deploy(self, ranks, value_map_op=None): self.group = DeviceGroup().get_group(ranks) if rank in ranks: if self.logical_tensor.data is None: - # TODO: check overlap - self.physical_tensor = torch.randn(tuple(self.segment.shape), device='cuda') - else: - # select from logical data - self.physical_tensor = torch.empty(tuple(self.shape), device='cuda') - self.physical_tensor.copy_( - self.logical_tensor.data[self.indices.get()].reshape(self.shape) - ) - if value_map_op is not None: - self.physical_tensor.data = value_map_op(self.physical_tensor) + raise RuntimeError("Try deploying a segment from a logical tensor without data") + # select from logical data + self.physical_tensor = torch.empty(tuple(self.shape), device='cuda') + self.physical_tensor.copy_( + self.logical_tensor.data[self.indices.get()].reshape(self.shape) + ) + for val_map_op in self.val_map_ops: + self.physical_tensor.data = val_map_op(self.physical_tensor) self.materialized = True def recover(self, reduction_op): @@ -97,6 +96,15 @@ def recover(self, reduction_op): else: raise RuntimeError("The Segment has not been materialized") + def add_val_map_op(self, val_map_op): + """ + Append val_map_op to the end + """ + if val_map_op is not None: + if not callable(val_map_op): + raise TypeError("Expected val_map_op to be callable or None") + self.val_map_ops.append(val_map_op) + def get_physical_tensor(self): """Get physical tensor if materialized diff --git a/tests/tensor/test_logical_tensor.py b/tests/tensor/test_logical_tensor.py index d6483824..e7b5e06e 100644 --- a/tests/tensor/test_logical_tensor.py +++ b/tests/tensor/test_logical_tensor.py @@ -36,7 +36,7 @@ def test_logical_tensor_select(): [7,5,9,4] ) indices = BaseIndices(sparse_indices) - segment = tensor.select(indices, shape=(2,2)) + segment = tensor.select(indices, None, shape=(2,2)) assert isinstance(segment, Segment) assert segment.materialized is False @@ -52,7 +52,7 @@ def test_logical_tensor_fill(): [7,5,9,4] ) indices = BaseIndices(sparse_indices) - segment = tensor.select(indices, shape=(2,2)) + segment = tensor.select(indices, None, shape=(2,2)) tensor.add_segment(segment) assert segment.materialized is False @@ -80,7 +80,7 @@ def test_logical_tensor_transform(): [7,5,9,4] ) indices = BaseIndices(sparse_indices) - segment = tensor.select(indices, shape=(2,2)) + segment = tensor.select(indices, None, shape=(2,2)) ranks = [0,1,3] tensor.transform([segment], [ranks], [None]) diff --git a/tests/tensor/test_segment.py b/tests/tensor/test_segment.py index e16a1506..fbfde213 100644 --- a/tests/tensor/test_segment.py +++ b/tests/tensor/test_segment.py @@ -29,14 +29,14 @@ def test_segment_init(): ofst = [2,4,3] indices = TileIndices(anchor, ofst) - segment = Segment(tensor, indices, ofst) + segment = Segment(tensor, indices, None, ofst) assert segment.logical_tensor is tensor assert segment.shape == tuple(ofst) assert segment.physical_tensor is None assert len(segment.placement) == 0 assert segment.group is None - assert segment.deploy_op is None + assert len(segment.val_map_ops) == 0 assert segment.materialized is False assert segment.merge_op is None @@ -50,10 +50,10 @@ def test_segment_deploy(): ofst = [2,4,3] indices = TileIndices(anchor, ofst) - segment = Segment(tensor, indices, ofst) + segment = Segment(tensor, indices, None, ofst) ranks = [0,2] - segment.deploy(ranks, value_map_op=None) + segment.deploy(ranks) physical_tensor = segment.get_physical_tensor() tensor_ref = tensor.data[indices.get()].cuda() @@ -64,12 +64,12 @@ def test_segment_deploy(): assert physical_tensor is None assert segment.placement == ranks assert segment.group == DeviceGroup().get_group(ranks) - assert segment.deploy_op is None + assert len(segment.val_map_ops) == 0 assert segment.materialized is True assert segment.merge_op is None -def test_segment_recover(): +def test_segment_deploy_with_val_map(): myrank = DeviceGroup().rank tensor = LogicalTensor((10,10,10)) @@ -78,10 +78,16 @@ def test_segment_recover(): ofst = [2,4,3] indices = TileIndices(anchor, ofst) - segment = Segment(tensor, indices, ofst) + segment = Segment( + logical_tensor = tensor, + indices = indices, + val_map_op = lambda tensor: tensor / 2, + shape = ofst + ) + assert len(segment.val_map_ops) == 1 ranks = [0,2] - segment.deploy(ranks, value_map_op=lambda tensor: tensor / 2) + segment.deploy(ranks) # deploy check physical_tensor = segment.get_physical_tensor() @@ -92,19 +98,6 @@ def test_segment_recover(): else: assert physical_tensor is None - # recover to get logical value - def reduction_op(tensor, group): - torch.distributed.all_reduce(tensor, group=group) - segment.recover(reduction_op=reduction_op) - physical_tensor = segment.get_physical_tensor() - - tensor_ref = tensor.data[indices.get()].cuda() - if myrank in [0,2]: - assert physical_tensor.device == torch.device('cuda:{}'.format(myrank)) - assert torch.allclose(physical_tensor, tensor_ref) is True - else: - assert physical_tensor is None - if __name__ == '__main__': @@ -112,4 +105,4 @@ def reduction_op(tensor, group): test_segment_init() test_segment_deploy() - test_segment_recover() + test_segment_deploy_with_val_map() From 9982808dba9fcda5b452a476de058106cf5d64df Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 1 Aug 2021 15:07:46 +0000 Subject: [PATCH 0096/1892] outline with z3 constraints: --- cube/tensor/logic/outline.py | 214 +++++++++++++++++++++++------------ requirements.txt | 1 + tests/tensor/test_outline.py | 164 ++++++++++++++++----------- 3 files changed, 242 insertions(+), 137 deletions(-) create mode 100644 requirements.txt diff --git a/cube/tensor/logic/outline.py b/cube/tensor/logic/outline.py index 76ebf21c..b7526f02 100644 --- a/cube/tensor/logic/outline.py +++ b/cube/tensor/logic/outline.py @@ -13,6 +13,8 @@ from cube.tensor.segment import Segment from cube.tensor.indices import TileIndices +import z3 + # interface to setup restrictions on the segmentation @@ -22,103 +24,171 @@ class BaseOutline: To setup an attribute (requirement), use `inst_baseoutline.attribute_name = val` """ - def __init__(self): - self.reduction = reduction - # decide how to generate segmentation given the requirement - self.policy_fn = None - - def set_policy(self, policy_fn): - if not callable(policy_fn): - raise TypeError("Expected a function to take BaseOutline instance") - self.policy_fn = policy_fn - - def __setattr__(self, key, val): - if key in self.__dict__: - self.__dict__[key].set(val) - #TODO: Align semantics will not allow setting val on child, need a new class - elif isinstance(val, MutableContainer) or isinstance(val, ConstantContainer): - self.__dict__[key] = val - elif val is None or isinstance(val, range) or isinstance(val, set): - self.__dict__[key] = MutableContainer(val) - else: - self.__dict__[key] = ConstantContainer(val) - - def interpret(self, logical_tensor): - raise NotImplementedError + def __init__(self, solver, shape): + super().__init__() + self.solver = solver + self.shape = shape + self.attributes = list() + + def get_attributes(self): + return self.attributes + + def add_field(self, **kwargs): + """ + Add a config field to current instance - def __call__(self, logical_tensor): - if not isinstance(logical_tensor, LogicalTensor): - raise TypeError("Expected logical_tensor is instance of LogicalTensor") + Usage: self.add_field(key=val): - #TODO: merge out to fuse in configurable space - if self.policy_fn is not None: - self.policy_fn.get()(self) + key is the name for the config attribute, val is the choices - self.interpret(logical_tensor) + val type: + list[int]: the key can only be the options from the val; + int: the key can only be the val; + range: the key can only be the val in the range; + None: the key can be any integers + z3.z3.ArithRef: the key is aligned with another attribute + """ + for key in kwargs: + if key in self.__dict__: + raise RuntimeError("{} already in config field".format(key)) + val = kwargs[key] + if isinstance(val, list): + if not all([isinstance(arg, int) for arg in val]): + raise TypeError("{} only supports list[int] choices".format(key)) + self.__dict__[key] = z3.Int(key) + self.attributes.append(self.__dict__[key]) + self.solver.add(z3.Or([self.__dict__[key] == val for val in val])) + elif isinstance(val, int): + self.__dict__[key] = z3.Int(str(id(self))+key) + self.attributes.append(self.__dict__[key]) + self.solver.add(self.__dict__[key] == val) + elif isinstance(val, range): + self.__dict__[key] = z3.Int(str(id(self))+key) + self.attributes.append(self.__dict__[key]) + self.solver.add(self.__dict__[key] >= val[0]) + raise NotImplementedError + elif val is None: + self.__dict__[key] = z3.Int(str(id(self))+key) + self.attributes.append(self.__dict__[key]) + elif isinstance(val, z3.z3.ArithRef): + self.__dict__[key] = val + else: + raise TypeError("{} can only be int, list[int], z3.Int()".format(key)) + + def remove_config(self, config): + if not isinstance(config, z3.z3.ModelRef): + raise TypeError("Expected config from z3 model()") + self.solver.add(z3.Or([z3.Not(attr == config[attr]) for attr in self.attributes])) + + def interpret(self, logical_tensor, config): + raise NotImplementedError -class Full(ConfigTemplate): +class Full(BaseOutline): - def __init__(self): - pass + def __init__(self, solver, shape): + super().__init__(solver, shape) - def interpret(self, logical_tensor): - shape = logical_tensor.shape - indices = TileIndices([0] * len(shape), shape) - segment = Segment(logical_tensor, indices, self.reduction.get()) + def interpret(self, logical_tensor, config): + if not isinstance(config, z3.z3.ModelRef): + raise TypeError("Expected config from z3 model()") + indices = TileIndices([0] * len(self.shape), self.shape) + segment = logical_tensor.select(indices, None, self.shape) return [segment] -class SplitAxis(ConfigTemplate): +class SplitAxis(BaseOutline): - def __init__(self, axis, chunk_num=None, overlap=0, uniform=True): + def __init__(self, solver, shape, axis, chunk_num, overlap): """ - Segmentation Pattern Requirement (parameters): - - axis (int): the axis to split + Split the logical tensor spatially in `axis` dimension + + TODO: support split axis with non-uniform chunk size + + shape: list / tuple int + shape of input logical tensor + axis: int + which axis to split + chunk_num: options (iterable int) / None / int: + how many segments to produce + uniform: Boolean + whether restrict to uniform split + overlap: options (iterable int) / int: + overlap size on the boundary + """ + if not isinstance(axis, int): + raise RuntimeError("Expected axis to be an integer") - chunk_num (None, int, tuple(int, int)): - valid chunk numbers to split. - If None, then any chunk number is valid; - If an integer, only the specified chunk number is valid; - If a tuple(min, max), the chunk number wihtin the scope [min,max] is valid + super().__init__(solver, shape) + self.axis = axis + + self.add_field(overlap=overlap) + self.solver.add(self.overlap >= 0) + self.add_field(chunk_num=chunk_num) + self.solver.add(self.chunk_num >= 0) - overlap (0, int, tuple(int, int)): - valid size for overlaping on the boundary of each splitted chunks. - If None, any overlapping is valid - If an integer, each overlap size is valid; - if a tuple(min, max), the overlap size wihtin the scope [min,max] is valid + # TODO: change to array to adapt with non-uniform cases + self.add_field(chunk_size=None) + + # setup constraints + total_size = self.shape[self.axis] + self.solver.add(self.chunk_num * self.chunk_size - self.overlap * (self.chunk_num - 1) == total_size) + def interpret(self, logical_tensor, config): """ - super().__init__() - self.axis = axis - self.chunk_num = chunk_num - self.uniform = uniform - self.overlap = overlap + Get segments from config - def interpret(self, logical_tensor): - """ - Runtime segment generation given the logical tensor shape + Args: + logical_tensor (LogicalTensor): + the logical tensor + config: + Config searched by model output - This is the policy that how to do the translation. """ - segments = list() - shape = list(logical_tensor.shape) - shape[self.axis.get()] = shape[self.axis.get()] // self.chunk_num.get() + if tuple(logical_tensor.shape) != tuple(self.shape): + raise RuntimeError("The logical tensor's shape doesn't match") + if not isinstance(config, z3.z3.ModelRef): + raise TypeError("Expected config from z3 model()") + chunk_num = config[self.chunk_num].as_long() + chunk_size = config[self.chunk_size].as_long() + shape = list(self.shape) + shape[self.axis] = chunk_size anchor = [0] * len(shape) - #TODO: support list of reductions - for cid in range(self.chunk_num.get()): + segments = list() + for cid in range(chunk_num): indices = TileIndices(anchor, shape) - segment = Segment(logical_tensor, indices) + segment = logical_tensor.select(indices, None, shape) segments.append(segment) - anchor[self.axis.get()] += shape[self.axis.get()] + anchor[self.axis] += shape[self.axis] return segments -class SplitValue(ConfigTemplate): +class SplitValue(BaseOutline): + + def __init__(self, solver, shape, chunk_num, val_map_op): + """ + Split the whole tensor in value dimension. + + Each segment shape will be same with logical tensor. - def __init__(self, chunk_num=None, val_map_op=None): - ##TODO - self.chunk_num = chunk_num + Each segment value will be modified by `val_map_op`. + """ + if not callable(val_map_op): + raise TypeError("Expected val_map_op a callable function") + super().__init__(solver, shape) + self.add_field(chunk_num=chunk_num) + self.solver.add(self.chunk_num >= 1) self.val_map_op = val_map_op + + def interpret(self, logical_tensor, config): + if tuple(logical_tensor.shape) != tuple(self.shape): + raise RuntimeError("The logical tensor's shape doesn't match") + chunk_num = config[self.chunk_num].as_long() + segments = list() + for cid in range(chunk_num): + # full tensor shape + indices = TileIndices([0] * len(self.shape), self.shape) + segment = logical_tensor.select(indices, self.val_map_op, self.shape) + segments.append(segment) + return segments diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..f2466477 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +z3-solver \ No newline at end of file diff --git a/tests/tensor/test_outline.py b/tests/tensor/test_outline.py index 86bb825b..40bea106 100644 --- a/tests/tensor/test_outline.py +++ b/tests/tensor/test_outline.py @@ -1,98 +1,132 @@ -import cube.tensor.logic.segment.outline as outline -import cube.tensor.logic.segment as segment +from cube.tensor.logic.tensor import LogicalTensor +import cube.tensor.logic.outline as outline +from cube.tensor.segment import Segment import torch +import z3 -def test_base(): - - dsp1 = outline.BaseOutline(reduction=segment.ReductionOp.Sum) - assert isinstance(dsp1.reduction, outline.ConstantContainer) - assert dsp1.reduction.get() == segment.ReductionOp.Sum - - choice = {segment.ReductionOp.Sum, segment.ReductionOp.Replica} - dsp2 = outline.BaseOutline(reduction=choice) - assert isinstance(dsp2.reduction, outline.MutableContainer) - assert dsp2.reduction.get() is None - assert dsp2.reduction.get(scope=True) == choice +def iter_each_config(solver, attrs): + if len(attrs) == 0: + solver.check() + yield solver.model() + else: + while solver.check() == z3.sat: + config = solver.model() + solver.add(z3.Or([z3.Not(attr == config[attr]) for attr in attrs])) + yield config def test_full(): - shape = (10,10,10) tensor = torch.randn(shape) - full_dsp = outline.Full(reduction=segment.ReductionOp.Replica) - assert full_dsp.reduction.get() == segment.ReductionOp.Replica + solver = z3.Solver() + + full_dsp = outline.Full(solver, shape) + assert len(full_dsp.get_attributes()) == 0 - segments = full_dsp(tensor.shape) - assert len(segments) == 1 - tile_seg = segments[0] - assert type(tile_seg) == segment.TileSegment + configs = list() + for config in iter_each_config(solver, full_dsp.get_attributes()): + configs.append(config) + + assert len(configs) == 1 + config = configs[0] - sub_tensor = tensor[tile_seg.get_indices()] - assert torch.all(torch.eq(sub_tensor, tensor)) + tensor = LogicalTensor(shape=shape) + segments = full_dsp.interpret(tensor, config) + assert len(segments) == 1 + assert tuple(segments[0].shape) == tuple(tensor.shape) + assert torch.allclose(tensor.data, tensor.data[segments[0].indices.get()]) is True def test_split_axis(): axis = 1 - num = 8 + shape = [1024, 16] + solver = z3.Solver() - shape = (4,16,4) tensor = torch.randn(shape) split_dsp = outline.SplitAxis( - axis=axis, chunk_num=None, overlap=0, - reduction=segment.ReductionOp.Replica, uniform=True) - assert split_dsp.axis.get() == 1 - assert split_dsp.chunk_num.get() is None - assert split_dsp.uniform.get() is True - assert split_dsp.overlap.get() == 0 - assert split_dsp.reduction.get() == segment.ReductionOp.Replica - - ## Policy here to decide how to split - if split_dsp.chunk_num.get() is None: - split_dsp.chunk_num = num - ### - - - segs = split_dsp(tensor.shape) - assert len(segs) == num - assert torch.all( - torch.Tensor( - [type(seg) == segment.TileSegment for seg in segs])).item() is True - - ofst = 0 - expected_shape = list(shape) - expected_shape[axis] = shape[axis] // num - for cid in range(num): - seg = segs[cid] - sub_tensor = tensor[seg.get_indices()] - ref_tensor = tensor[:,ofst:ofst+expected_shape[axis],:] - # print('sub tensor {}: {}'.format(sub_tensor.size(), sub_tensor)) - # print('ref tensor {}: {}'.format(ref_tensor.size(), ref_tensor)) - assert sub_tensor.size() == torch.Size(expected_shape) - assert torch.all(torch.eq(sub_tensor, ref_tensor)) - ofst += expected_shape[axis] + solver, shape, axis, chunk_num=None, overlap=0 + ) + + # test config space + configs = list() + for config in iter_each_config(solver, split_dsp.get_attributes()): + configs.append(config) + assert len(configs) == 5 + + # test segments + tensor = LogicalTensor(shape=shape) + segments = split_dsp.interpret(tensor, configs[0]) + shape_axis = [segment.shape[axis] for segment in segments] + assert sum(shape_axis) == shape[axis] + + +def test_split_axis_with_constraints(): + + axis = 1 + shape = [1024, 16] + solver = z3.Solver() + + split_dsp = outline.SplitAxis( + solver, shape, axis, chunk_num=None, overlap=0 + ) + + # this can be set due to device number constraints + split_dsp.solver.add(split_dsp.chunk_num <= 8) + + configs = list() + for config in iter_each_config(solver, split_dsp.get_attributes()): + configs.append(config) + # print(config) + assert len(configs) == 4 + + +def test_split_value(): + + shape = [1024, 32] + split_op = lambda tensor, rank, world_size : tensor / world_size + solver = z3.Solver() + + split_dsp = outline.SplitValue(solver, shape, None, split_op) + split_dsp.solver.add(split_dsp.chunk_num <= 4) + configs = list() + for config in iter_each_config(solver, split_dsp.get_attributes()): + configs.append(config) + assert len(configs) == 4 + + tensor = LogicalTensor(shape=shape) + segments = split_dsp.interpret(tensor, configs[0]) + for segment in segments: + assert torch.allclose(tensor.data, tensor.data[segment.indices.get()]) is True def test_align(): + shape = [1024, 16] + solver = z3.Solver() + dsp1 = outline.SplitAxis( - axis=1, chunk_num=None, overlap=0, - reduction=segment.ReductionOp.Replica, uniform=True) + solver, shape, axis=0, chunk_num=None, overlap=0, + ) dsp2 = outline.SplitAxis( - axis=2, chunk_num=dsp1.chunk_num, overlap=0, - reduction=segment.ReductionOp.Replica, uniform=True) - - dsp1.chunk_num = 3 - assert dsp2.chunk_num.get() == 3 - assert dsp2.axis.get() == 2 + solver, shape, axis=1, chunk_num=dsp1.chunk_num, overlap=0, + ) + + configs = list() + attrs = dsp1.get_attributes() + dsp2.get_attributes() + for config in iter_each_config(solver, attrs): + configs.append(config) + assert len(configs) == 5 if __name__ == '__main__': - test_base() + # test_base() test_full() test_split_axis() + test_split_axis_with_constraints() + test_split_value() test_align() \ No newline at end of file From bc65c9e78691641abc97609b2b80cf1e616eb6c7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 1 Aug 2021 15:08:48 +0000 Subject: [PATCH 0097/1892] config container (deprecated) --- cube/config/container.py | 43 +++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/cube/config/container.py b/cube/config/container.py index 5a1a3045..37639e62 100644 --- a/cube/config/container.py +++ b/cube/config/container.py @@ -4,7 +4,8 @@ class ConditionContainer: def __init__(self, satisfy_fn): if not callable(satisfy_fn): raise TypeError("Expected function") - self._satisfy_fn = satisfy_fn + self._satisfy_fn = (satisfy_fn,) + self._condition_fn = None self._val = None self._choices = None self._lock = False @@ -20,11 +21,17 @@ def set(self, val): Set the value, will raise ValueError if not satisfy """ if self._lock: - raise RuntimeError("Try to set a locked config") + raise False if self._choices is not None: if not self.satisfy(val, self._choices): - raise ValueError("Fail to set config") + return False + val_backup = self._val self._val = val + if self._condition_fn is not None: + if not self._condition_fn(): + self._val = val_backup + return False + return True def lock(self): """ @@ -34,12 +41,12 @@ def lock(self): def satisfy(self, val): """ - Check whether the value satisfy the choices + Check whether the value satisfy the choices and conditions Returns: True if satisfy, False not """ - return self._satisfy_fn(val, self._choices) + return self._satisfy_fn[0](val, self._choices) def choices(self): """ @@ -70,6 +77,8 @@ def __init__(self, choices): def satisfy_fn(val, choices): return val in choices super().__init__(satisfy_fn) + if not hasattr(choices, '__iter__'): + choices = [choices] self._choices = choices @@ -90,3 +99,27 @@ def satisfy_fn(val, choices): return False super().__init__(satisfy_fn) self._choices = type_choices + + +class UniformSumContainer(ConditionContainer): + + def __init__(self, summation): + """ + Create a summation restriction container + """ + def satisfy_fn(val, choices): + return len(set(val) == 1) and sum(val) == choices + super().__init__(satisfy_fn) + self._choices = summation + + +class SumContainer(ConditionContainer): + + def __init__(self, summation, slots=None): + """ + Create a summation restriction container + """ + def satisfy_fn(val, choices): + return sum(val) == choices + super().__init__(satisfy_fn) + self._choices = summation From 7e27ae9220936c087eeed4b2a829d21671d08421 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 1 Aug 2021 15:12:32 +0000 Subject: [PATCH 0098/1892] val map op takes 3 args: tensor, rank, world_size --- cube/tensor/segment.py | 5 +++-- tests/tensor/test_segment.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py index 1360f5b5..77dd6eb8 100644 --- a/cube/tensor/segment.py +++ b/cube/tensor/segment.py @@ -57,7 +57,8 @@ def deploy(self, ranks): Argument: ranks (list[int]): device id list value_map_op (callable): - takes the tensor, return a new tensor + takes the tensor, rank, world_size, + return a new tensor """ if not isinstance(ranks, list): raise TypeError("Expected ranks in list[int]") @@ -74,7 +75,7 @@ def deploy(self, ranks): self.logical_tensor.data[self.indices.get()].reshape(self.shape) ) for val_map_op in self.val_map_ops: - self.physical_tensor.data = val_map_op(self.physical_tensor) + self.physical_tensor.data = val_map_op(self.physical_tensor, rank, len(ranks)) self.materialized = True def recover(self, reduction_op): diff --git a/tests/tensor/test_segment.py b/tests/tensor/test_segment.py index fbfde213..a63ae028 100644 --- a/tests/tensor/test_segment.py +++ b/tests/tensor/test_segment.py @@ -81,7 +81,7 @@ def test_segment_deploy_with_val_map(): segment = Segment( logical_tensor = tensor, indices = indices, - val_map_op = lambda tensor: tensor / 2, + val_map_op = lambda tensor, rank, world_size: tensor / world_size, shape = ofst ) assert len(segment.val_map_ops) == 1 @@ -91,7 +91,7 @@ def test_segment_deploy_with_val_map(): # deploy check physical_tensor = segment.get_physical_tensor() - tensor_ref = tensor.data[indices.get()].cuda() / 2 + tensor_ref = tensor.data[indices.get()].cuda() / len(ranks) if myrank in [0,2]: assert physical_tensor.device == torch.device('cuda:{}'.format(myrank)) assert torch.allclose(physical_tensor, tensor_ref) is True From 1bb3bccb78d6eb02edf9d9176ebbd13c7a438561 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 1 Aug 2021 15:14:49 +0000 Subject: [PATCH 0099/1892] update envs --- scripts/env-setup.sh | 2 ++ scripts/keep.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index 83dcf35e..d28e204e 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -24,3 +24,5 @@ echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc # find cube/ -name "*.py" -print0 | xargs -0 wc -l python setup.py develop +pip install -r requirements.txt + diff --git a/scripts/keep.py b/scripts/keep.py index d2e87aa9..5fd58d03 100644 --- a/scripts/keep.py +++ b/scripts/keep.py @@ -46,7 +46,7 @@ def keep(rank, args): time.sleep(args.interval) while True: util = get_gpu_util(rank) - if util >= 0: + if util > 0: break print('rank {}: find gpu busy, keep sleeping...'.format(rank)) time.sleep(args.interval) From 4aa5b69ece3e63768f28e89ab6618c0936400701 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 2 Aug 2021 06:24:14 +0000 Subject: [PATCH 0100/1892] holistic op test --- cube/operator/holist/generics.py | 164 +++++++++++++------------ tests/operator/test_holistic_linear.py | 11 +- tests/operator/test_holistic_op.py | 129 +++++++++++-------- 3 files changed, 172 insertions(+), 132 deletions(-) diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index b2167d83..90ca71cc 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -9,15 +9,14 @@ """ from cube.tensor.logic.tensor import LogicalTensor -from cube.tensor.community import Community +from cube.tensor.logic.outline import BaseOutline + +import z3 class GenericHolisticOp: - def __init__(self, - input_layout, output_layout, - input_format=None, output_format=None - ): + def __init__(self, shapes): """ Layout is the community distribution requirement for input and output logical tensors. @@ -37,33 +36,58 @@ def __init__(self, output_format (list[list[int], None]): output dim order compare with logical definition """ + self.solver = z3.Solver() + self.shapes = shapes - # holistic layout (outliner) of input - if not isinstance(input_layout, list): - raise TypeError("Require input layout for HolistOp is a list") - if not isinstance(input_format, list): - raise TypeError("Require input format for HolistOp is a list") - if not isinstance(output_layout, list): - raise TypeError("Require output layout for HolistOp is a list") - if not isinstance(output_format, list): - raise TypeError("Require output format for HolistOp is a list") - - self.input_layout = input_layout - self.input_format = input_format - - # holistic layout of output - self.output_layout = output_layout - self.output_format = output_format + self.input_layouts = list() + self.output_layouts = list() self.logical_op = None + self.output_shapes = list() + + self.attributes = list() self.policy_fn = None + self.config = None + + def set_input_layouts(self, layouts): + """ + Set input layout + + Args: + layouts (list[BaseOutline]): layout list for input logical tensor + """ + for layout in layouts: + if not isinstance(layout, BaseOutline): + TypeError("Require input layout for HolistOp is a list[BaseOutline]") + self.attributes += layout.get_attributes() + self.input_layouts.append(layout) + + def set_output_layouts(self, layouts): + """ + Set output layout + + Args: + layouts (list[BaseOutline]): layout list for output logical tensor + """ + for layout in layouts: + if not isinstance(layout, BaseOutline): + TypeError("Require input layout for HolistOp is a list[BaseOutline]") + self.attributes += layout.get_attributes() + self.output_layouts.append(layout) def set_logic_op(self, logic_op): """ Set logic op. This will be automatically called when the holistic op registered in a logical op. """ + # if not isinstance(logic_op, GenericLogicalOp): + # raise TypeError("Require a logic op to register") self.logical_op = logic_op + + def set_config(self, config): + if not isinstance(config, z3.z3.ModelRef): + raise TypeError("Expected config from z3 solver.model()") + self.config = config def input_adapter(self, *args, **kwargs): """ @@ -74,43 +98,38 @@ def input_adapter(self, *args, **kwargs): #TODO: kwargs input_num = len(args) - if len(self.input_layout) != input_num: + if len(self.input_layouts) != input_num: raise RuntimeError("Fail to adapt input: layout length not equal") - if len(self.input_format) != input_num: - raise RuntimeError("Fail to adapt input: format length not equal") + # if len(self.input_format) != input_num: + # raise RuntimeError("Fail to adapt input: format length not equal") # step 1: data reformat based on the input argument - for input, dim_order in zip(args, self.input_format): - if dim_order is not None: - input.permute(dim_order) - - # step 2: get communities based on expert description - input_communities = list() - for tensor, outliner in zip(args, self.input_layout): + # for input, dim_order in zip(args, self.input_format): + # if dim_order is not None: + # input.permute(dim_order) + + # step 2: Policy: segmentation + deploy decision + if self.policy_fn is None: + raise RuntimeError("Expected a runtime configuration policy") + config, input_ranks = self.policy_fn[0](self) + self.set_config(config) + + # step 3: segmentation + input_segments = list() + for tensor, outliner in zip(args, self.input_layouts): if outliner is not None and isinstance(tensor, LogicalTensor): - segments = outliner(tensor.shape) - communities = [Community(seg) for seg in segments] - input_communities.append(communities) + segments = outliner.interpret(tensor, self.config) + input_segments.append(segments) else: - input_communities.append(None) + input_segments.append(None) - # step 3: physical tensor placement (policy) - if self.policy_fn is not None: - input_ranks, input_val_map_fns = \ - self.policy_fn[0](input_communities, *args) - else: - # TODO: default policy - input_ranks = [None] * len(args) - input_val_map_fns = [None] * len(args) - - # step 4: community matching + # step 4: deploy for tid in range(len(args)): tensor = args[tid] if isinstance(tensor, LogicalTensor): - communities = input_communities[tid] + segments = input_segments[tid] ranks = input_ranks[tid] - val_map_fn = input_val_map_fns[tid] - tensor.match(communities, ranks, val_map_fn) + tensor.transform(segments, ranks) def forward(self, *args, **kwargs): """ @@ -131,9 +150,9 @@ def output_adapter(self, outputs): Data reformat to logical op format Args: - outputs (tuple(list[physical_tensor],)) - each `list[physical_tensor]` represents a output of the op - with is communities + outputs (tuple(list[OpResult],)) + each `list[OpResult]` represents a output of the op + with its segments Returns: logical outputs (tuple(LogicalTensor,)): the logical tensor list @@ -141,22 +160,24 @@ def output_adapter(self, outputs): #TODO: fix: data re-format order. Should be ahead of logical tensor construction if not isinstance(outputs, tuple): outputs = (outputs,) + # step 1: construct to logical tensor - logical_outputs = list() - for output, outliner, shape in zip(outputs, self.output_layout, self.logical_shapes): - segments = outliner(shape) - communities = [Community(segment) for segment in segments] - for community, op_res in zip(communities, output): - #if DeviceGroup().rank == 0: - # print(op_res.res.size(), community.segment.shape) - community.set_physical_tensor(op_res.res, op_res.placement) - output = LogicalTensor.construct(shape, communities) - logical_outputs.append(output) + for output, outliner in zip(outputs, self.output_layouts): + logical_tensor = LogicalTensor(outliner.shape, init_data=False) + segments = outliner.interpret(shape, self.config) + for segment in segments: + logical_tensor.add_segment(segment) + logical_tensor.fill( + physical_tensors=[op_res.res for op_res in output], + ranks=[op_res.placement for op_res in output] + ) + logical_outputs.append(logical_tensor) + # step 2: data reformat based on the output - for out_id in range(len(self.output_format)): - dim_order = self.output_format[out_id] - if dim_order is not None and isinstance(logical_outputs[out_id], LogicalTensor): - logical_ouputs[out_id] = logical_ouputs[out_id].permute(dim_order) + # for out_id in range(len(self.output_format)): + # dim_order = self.output_format[out_id] + # if dim_order is not None and isinstance(logical_outputs[out_id], LogicalTensor): + # logical_ouputs[out_id] = logical_ouputs[out_id].permute(dim_order) if len(logical_outputs) == 1: return logical_outputs[0] @@ -172,16 +193,13 @@ def __call__(self, *args, **kwargs): outputs = self.forward(*args, **kwargs) # wrap to logical tensor - if self.logical_op is None: - raise RuntimeError("This holistic op doesn't have logical op") - self.logical_shapes = self.logical_op.shape_infer(*args, **kwargs) outputs = self.output_adapter(outputs) return outputs - def set_deploy_policy(self, policy_fn): + def set_policy(self, policy_fn): """ - Register a policy to take inputs (logical tensors) and segments, + Register a policy to take layouts and solver, generate device placement for each community, and corresponding message mapping @@ -191,10 +209,4 @@ def set_deploy_policy(self, policy_fn): if not callable(policy_fn): raise TypeError("Expected callable function") self.policy_fn = (policy_fn,) - - def set_segmentation_policy(self, policy_fn): - for outliner in self.input_layout: - outliner.set_policy(policy_fn) - for outliner in self.output_layout: - outliner.set_policy(policy_fn) diff --git a/tests/operator/test_holistic_linear.py b/tests/operator/test_holistic_linear.py index d4f3b83e..1f5e4751 100644 --- a/tests/operator/test_holistic_linear.py +++ b/tests/operator/test_holistic_linear.py @@ -71,10 +71,13 @@ def test_holistic_linear_op_column_weight(): weight = LogicalTensor(shape=(N,1024)) bias = LogicalTensor(shape=(N,)) + # output = LogicalLinear(input, weight, bias) + + # ================================ Policy =========================== + holistic_op = LinearColumnWeight() holistic_op.logical_op = LogicalLinear() - # ================================ Policy =========================== def policy_for_how_many_tiles(outliner): if isinstance(outliner, sg.outline.Full): pass @@ -83,15 +86,19 @@ def policy_for_how_many_tiles(outliner): outliner.chunk_num = 4 else: raise TypeError("Unhandled outliner type") + # -> together def policy_for_each_tile_placement(community, input, weight, bias): + # generate results (hard code) [helper function] input_ranks = [ - [[0,1,2,3]], + [[0,1,2,3]], [DeviceGroup().all_ranks()] [[0],[1],[2],[3]], [[0],[1],[2],[3]] ] input_val_map_fns = list([None, None, None]) return input_ranks, input_val_map_fns + + # Missing Policy: where physical op executed? holistic_op.set_deploy_policy( policy_for_each_tile_placement diff --git a/tests/operator/test_holistic_op.py b/tests/operator/test_holistic_op.py index 98659a9c..d5dd75a6 100644 --- a/tests/operator/test_holistic_op.py +++ b/tests/operator/test_holistic_op.py @@ -11,84 +11,105 @@ tests/operator/test_holistic_op.py """ -import cube.tensor.logic.segment as sg +import cube.tensor.logic.outline as outline from cube.tensor.logic.tensor import LogicalTensor from cube.operator.holist.generics import GenericHolisticOp from cube.device.physic.group import DeviceGroup import torch +import z3 def test_generic_holistic_op_init(): + shapes = [(32, 2048), (1024, 2048), (32, 1024)] + op = GenericHolisticOp(shapes) + # description - input_layout = sg.SplitAxis( - axis=0, overlap=0, reduction=sg.ReductionOp.Replica + input_layout = outline.Full( + op.solver, op.shapes[0], ) - weight_layout = sg.Full(reduction=sg.ReductionOp.Replica) - output_layout = sg.SplitAxis( - axis=0, overlap=0, chunk_num=input_layout.chunk_num, - reduction=sg.ReductionOp.Replica + weight_layout = outline.SplitAxis( + op.solver, op.shapes[1], + axis=0, chunk_num=None, overlap=0, ) - - op = GenericHolisticOp( - input_layout=[input_layout, weight_layout], - output_layout=[output_layout], - input_format=[None, None], - output_format=[None], + output_layout = outline.SplitAxis( + op.solver, op.shapes[2], + axis=0, chunk_num=weight_layout.chunk_num, overlap=0, ) - assert len(op.input_layout) == 2 - assert len(op.input_format) == 2 - assert len(op.output_layout) == 1 - assert len(op.output_format) == 1 + assert op.shapes == shapes + assert len(op.input_layouts) == 0 + assert len(op.output_layouts) == 0 assert op.logical_op is None assert op.policy_fn is None + op.set_input_layouts([input_layout, weight_layout]) + op.set_output_layouts([output_layout]) + + assert len(op.input_layouts) == 2 + assert len(op.output_layouts) == 1 + assert len(op.attributes) == 5 + def test_generic_holistic_op_input_adapter(): - input_layout = sg.SplitAxis( - axis=0, overlap=0, reduction=sg.ReductionOp.Replica + shapes = [(32, 512), (1024, 512), (32, 1024)] + input = LogicalTensor(shape=shapes[0]) + weight = LogicalTensor(shape=shapes[1]) + + op = GenericHolisticOp(shapes) + + # description + input_layout = outline.Full( + op.solver, op.shapes[0], ) - weight_layout = sg.Full(reduction=sg.ReductionOp.Replica) - output_layout = sg.SplitAxis( - axis=0, overlap=0, chunk_num=input_layout.chunk_num, - reduction=sg.ReductionOp.Replica + weight_layout = outline.SplitAxis( + op.solver, op.shapes[1], + axis=0, chunk_num=None, overlap=0, ) - - op = GenericHolisticOp( - input_layout=[input_layout, weight_layout], - output_layout=[output_layout], - input_format=[None, None], - output_format=[None], + output_layout = outline.SplitAxis( + op.solver, op.shapes[2], + axis=0, chunk_num=weight_layout.chunk_num, overlap=0, ) - input = LogicalTensor(shape=(1024, 1024)) - weight = LogicalTensor(shape=(1024, 1024)) - - ## Policy Here - input_layout.chunk_num = 4 - assert output_layout.chunk_num.get() == 4 - def policy_fn(input_communities, input, weight): - input_ranks = [ - [[0],[1],[2],[3]], - [[0,1,2,3]] - ] - input_val_map_fns = list([None, None]) - return input_ranks, input_val_map_fns - - op.register_policy(policy_fn) + op.set_input_layouts([input_layout, weight_layout]) + op.set_output_layouts([output_layout]) + + def policy(holist_op): + solver = holist_op.solver + attributes = holist_op.attributes + input_layout = holist_op.input_layouts[0] + weight_layout = holist_op.input_layouts[1] + output_layout = holist_op.output_layouts[0] + + # add restrictions based on device num + device_num = torch.cuda.device_count() + solver.add(weight_layout.chunk_num <= 4) + + # iterate all configs + configs = list() + while solver.check() == z3.sat: + config = solver.model() + configs.append(config) + solver.add( + z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) + ) + if len(attributes) == 0: + break + # choose one config -- suppose to the first + config = configs[0] + + # deploy decisions + chunk_num = config[weight_layout.chunk_num].as_long() + input_ranks = [list(range(0, chunk_num)),] + weight_ranks = list() + for rank in range(chunk_num): + weight_ranks.append([rank]) + + return config, [input_ranks, weight_ranks] + + op.set_policy(policy) op.input_adapter(input, weight) - myrank = DeviceGroup().rank - assert len(input.communities) == 4 - assert len(weight.communities) == 1 - physical_tensor = input.get_physical_tensor(input.segments[myrank]) - piece = 1024 // 4 - start = int(myrank * piece) - assert torch.allclose(physical_tensor, input.data.cuda()[start:start+piece, :]) is True - physical_tensor = weight.get_physical_tensor(weight.segments[0]) - assert torch.allclose(physical_tensor, weight.data.cuda()) is True - if __name__ == '__main__': group = DeviceGroup() From ee2c8b7605d04edcc5b9d793645ee1cf4e23eb9b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 2 Aug 2021 06:25:54 +0000 Subject: [PATCH 0101/1892] add helper funcs --- cube/tensor/logic/tensor.py | 9 +++------ cube/tensor/segment.py | 5 +++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index fb132ceb..40f4e765 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -47,16 +47,14 @@ def select(self, indices, val_map_op, shape): segment = Segment(self, indices, val_map_op, shape) return segment - def transform(self, segments, ranks=None, val_map_ops=None): + def transform(self, segments, ranks=None): """ - Transform the LogicalTensor with community list. + Transform the LogicalTensor with segment list. TODO: check if this should create a new logical tensor """ if not (isinstance(ranks, list) and len(ranks) == len(segments)): raise ValueError("Expected ranks to be a list with equal length of segments") - if not (isinstance(ranks, list) and len(val_map_ops) == len(segments)): - raise ValueError("Expected ranks to be a list with equal length of segments") - + if len(self.segments) == 0: for sid in range(len(segments)): segment = segments[sid] @@ -65,7 +63,6 @@ def transform(self, segments, ranks=None, val_map_ops=None): deploy_ranks = ranks[sid] if not isinstance(deploy_ranks, list): raise TypeError('Expected ranks to be list[list[int],]') - deploy_ops = val_map_ops[sid] segment.deploy(deploy_ranks) #TODO: segment transformation on existing segments else: diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py index 77dd6eb8..633f0c80 100644 --- a/cube/tensor/segment.py +++ b/cube/tensor/segment.py @@ -130,3 +130,8 @@ def set_physical_tensor(self, physical_tensor, ranks): self.physical_tensor = physical_tensor self.group = DeviceGroup().get_group(ranks) self.materialized = True + + def __repr__(self): + msg = 'Segment(Indices: {} | Materialized: {})'.format(self.indices, self.materialized) + return msg + \ No newline at end of file From bae10342bbf9786988d75d3ed0d87fe3b641fe83 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 2 Aug 2021 06:32:36 +0000 Subject: [PATCH 0102/1892] logical op test --- cube/operator/logic/generics.py | 30 ++++++++++++++++++++++-------- tests/operator/test_logical_op.py | 26 ++++++++++++-------------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/cube/operator/logic/generics.py b/cube/operator/logic/generics.py index 1fe75250..47c47f40 100644 --- a/cube/operator/logic/generics.py +++ b/cube/operator/logic/generics.py @@ -28,20 +28,21 @@ def __len__(self): def register(self, holistic_op): """ - Register a holistic op as one of the anchors + Register a holistic op (class) as one of the anchors """ self.holist_ops.append(holistic_op) - def get_op(self, idx): + def get_op(self, idx, shapes): """ Get holistic operator based on idx + The holistic operator will be initialized with shapes + Returns: HolisticOp instance """ - return self.holist_ops[idx] + return self.holist_ops[idx](shapes) - class GenericLogicalOp: @@ -51,7 +52,7 @@ def __init__(self): self.factory = HolisticOpFactory() self.policy_fn = None - def register_policy(self, policy_fn): + def set_policy(self, policy_fn): """ Register a policy function to customize how composite holistic op generated during runtime. @@ -75,13 +76,26 @@ def shape_infer(self, *args, **kwargs): """ raise NotImplementedError("Expected a shape infer engine") + def get_shapes(self, *args, **kwargs): + # get shapes of input and output + shapes = list() + for arg in args: + if isinstance(LogicalTensor, arg): + shapes.append(arg.shape) + else: + shapes.append(None) + shapes.append(self.shape_infer(*args, **kwargs)) + return shapes + def get_op(self, *args, **kwargs): + # get shapes of input and output + shapes = self.get_shapes() # use default policy if self.policy_fn is None: - composite_op = self.factory.get_op(0) + composite_op = self.factory.get_op(0, shapes) # use user-customized policy else: - composite_op = self.policy_fn[0](self.factory, *args, **kwargs) + composite_op = self.policy_fn[0](self.factory, shapes) return composite_op def __call__(self, *args, **kwargs): @@ -91,4 +105,4 @@ def __call__(self, *args, **kwargs): composite_op = self.get_op(*args, **kwargs) # run operator with the strategy plan outputs = composite_op(*args, **kwargs) - return outputs \ No newline at end of file + return outputs diff --git a/tests/operator/test_logical_op.py b/tests/operator/test_logical_op.py index fc8f60cb..931eb0d0 100644 --- a/tests/operator/test_logical_op.py +++ b/tests/operator/test_logical_op.py @@ -6,14 +6,14 @@ def test_factory(): factory = HolisticOpFactory() assert len(factory) == 0 - class HolisticOp: pass - holistic_op = HolisticOp() + class HolisticOp: + def __init__(self, shape): pass - factory.register(holistic_op) - assert len(factor) == 1 + factory.register(HolisticOp) + assert len(factory) == 1 - op = factory.get_op(0) - assert op is holistic_op + op = factory.get_op(0, [(1024, 1024)]) + assert isinstance(op, HolisticOp) def test_generic_logical_op_init(): @@ -27,19 +27,17 @@ def test_generic_logical_op_register(): generic_op = GenericLogicalOp() - class HolisticOp: pass - holistic_op = HolisticOp() + class HolisticOp: + def __init__(self, shape): pass - generic_op.factory.register(holistic_op) + generic_op.factory.register(HolisticOp) - def policy_fn(factory): - return factory.get_op(0) + def policy_fn(factory, shapes): + return factory.get_op(0, shapes) - generic_op.register_policy(policy_fn) + generic_op.set_policy(policy_fn) assert generic_op.policy_fn is not None - op = generic_op.get_op() - assert op is holistic_op if __name__ == '__main__': From bead2625a53baf35e9a66a66f95b83148738df74 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 2 Aug 2021 08:25:56 +0000 Subject: [PATCH 0103/1892] linear column weight split test --- cube/operator/holist/generics.py | 6 +- cube/operator/holist/linear.py | 110 +++++++++++++++---------- cube/tensor/logic/tensor.py | 22 +++++ tests/operator/test_holistic_linear.py | 95 +++++++++++---------- 4 files changed, 140 insertions(+), 93 deletions(-) diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index 90ca71cc..aa9aa89f 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -11,6 +11,8 @@ from cube.tensor.logic.tensor import LogicalTensor from cube.tensor.logic.outline import BaseOutline +from cube.device.physic.group import DeviceGroup + import z3 @@ -162,9 +164,10 @@ def output_adapter(self, outputs): outputs = (outputs,) # step 1: construct to logical tensor + logical_outputs = list() for output, outliner in zip(outputs, self.output_layouts): logical_tensor = LogicalTensor(outliner.shape, init_data=False) - segments = outliner.interpret(shape, self.config) + segments = outliner.interpret(logical_tensor, self.config) for segment in segments: logical_tensor.add_segment(segment) logical_tensor.fill( @@ -190,6 +193,7 @@ def __call__(self, *args, **kwargs): self.input_adapter(*args, **kwargs) # do execution + args, kwargs = LogicalTensor.to_segments(*args, **kwargs) outputs = self.forward(*args, **kwargs) # wrap to logical tensor diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index 2a85ad23..af29d73e 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -3,8 +3,7 @@ import cube.operator.physic.linear as phy_linear from cube.tensor.logic.tensor import LogicalTensor -import cube.tensor.logic.segment as sg -from cube.tensor.community import Community +import cube.tensor.logic.outline as outline # Debug from cube.device.physic.group import DeviceGroup @@ -22,47 +21,52 @@ class LinearColumnWeight(GenericHolisticOp): Split W and b on the last dimension """ - def __init__(self): + def __init__(self, shapes): - inputs_layout = sg.outline.Full( - reduction=sg.ReductionOp.Replica - ) + super().__init__(shapes) - weight_layout = sg.outline.SplitAxis( - axis=0, chunk_num=None, overlap=0, uniform=False, - reduction=sg.ReductionOp.Replica + # input layouts + input_layout = outline.Full( + self.solver, self.shapes[0] ) - - bias_layout = weight_layout - - output_layout = sg.outline.SplitAxis( - axis=1, chunk_num=weight_layout.chunk_num, overlap=0, uniform=False, - reduction=sg.ReductionOp.Replica + weight_layout = outline.SplitAxis( + self.solver, self.shapes[1], + axis=0, chunk_num=None, overlap=0 ) - - super().__init__( - input_layout=[inputs_layout, weight_layout, bias_layout], - output_layout=[output_layout,], - input_format=[None, None, None], - output_format=[None] + bias_layout = outline.SplitAxis( + self.solver, self.shapes[2], + axis=0, chunk_num=weight_layout.chunk_num, overlap=0 ) + # output layouts + output_layout = outline.SplitAxis( + self.solver, self.shapes[3], + axis=1, chunk_num=weight_layout.chunk_num, overlap=0 + ) + + self.set_input_layouts([input_layout, weight_layout, bias_layout]) + self.set_output_layouts([output_layout]) def forward(self, input, weight, bias): + """ + input: list[Segment] of input + weight: list[Segment] of weight + bias: list[Segment] of bias + """ outputs = list() # TODO: handle bias is None - physical_input = input.get_physical_tensor(0) - for cid in range(len(weight)): + physical_input = input[0].get_physical_tensor() + for weight_seg, bias_seg in zip(weight, bias): # output = physic_op.linear(inputs, weight, bias) #TODO: TensorContainer to enable op placement + tensor movement #TODO: ExecutionScheduler to handle re-compute / swap #TODO: nested hybrid call to enable hybrid-parallelisms #TODO: double-check necessety of stateful physical operator - physical_weight = weight.get_physical_tensor(cid) + physical_weight = weight_seg.get_physical_tensor() # if DeviceGroup().rank == 0: # print(physical_weight) - physical_bias = bias.get_physical_tensor(cid) + physical_bias = bias_seg.get_physical_tensor() # TODO: this is the policy decision - phy_op = phy_linear.Linear(placement=weight.get_community(cid).placement) + phy_op = phy_linear.Linear(placement=weight_seg.placement) output = phy_op(physical_input, physical_weight, physical_bias) # if DeviceGroup().rank == 0: # print(output) @@ -80,30 +84,48 @@ class LinearColumnInputRowWeight(GenericHolisticOp): Split W (weights) in row major (first dim) """ - def __init__(self): + def __init__(self, shapes): - inputs_layout = sg.outline.SplitAxis( - axis=-1, chunk_num=None, overlap=0, - reduction=sg.ReductionOp.Replica) + super().__init__(shapes) - weight_layout = sg.outline.SplitAxis( - axis=1, chunk_num=inputs_layout.chunk_num, overlap=0, - reduction=sg.ReductionOp.Replica) + input_layout = outline.SplitAxis( + self.solver, self.shapes, + axis=-1, chunk_num=None, overlap=0, + ) - bias_layout = sg.outline.Full(reduction=sg.ReductionOp.Sum) + weight_layout = outline.SplitAxis( + self.solver, self.shapes, + axis=1, chunk_num=input_layout.chunk_num, overlap=0, + ) - output_layout = sg.outline.Full(reduction=sg.ReductionOp.Sum) + bias_layout = outline.SplitValue( + self.solver, self.shapes, + chunk_num=input_layout.chunk_num, + val_map_op=lambda tensor, rank, world_size : tensor / world_size + ) - super().__init__( - input_layout=[inputs_layout, weight_layout, bias_layout], - output_layout=[output_layout,], - input_format=[None, None, None], - output_format=[None, None, None] + # output layout will only use reduce op + output_layout = outline.SplitValue( + self.solver, self.shapes, + chunk_num=input_layout.chunk_num, + val_map_op=lambda tensor, rank, world_size : tensor, + val_reduce_op=lambda tensor, rank, world_size : tensor ) + + self.set_input_layouts([input_layout, weight_layout, bias_layout]) + self.set_output_layouts([output_layout]) - def forward(self, inputs, weight, bias): - output = physic_op.linear(inputs, weight, bias) - return [output,] + def forward(self, input, weight, bias): + outputs = list() + for input_seg, weight_seg, bias_seg in zip(input, weight, bias): + phy_op = phy_linear.Linear(placement=weight_seg.placement) + output = phy_op( + input_seg.get_physical_tensor(), + weight.get_physical_tensor(), + bias.get_physical_tensor() + ) + outputs.append(output) + return outputs -kHolistLinearSets = [LinearColumnWeight(), LinearColumnInputRowWeight()] \ No newline at end of file +kHolistLinearSets = [LinearColumnWeight, LinearColumnInputRowWeight] \ No newline at end of file diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index 40f4e765..a1ccf465 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -1,6 +1,7 @@ from cube.tensor.segment import Segment from cube.tensor.indices import BaseIndices +from cube.device.physic.group import DeviceGroup class LogicalTensor: """ @@ -140,3 +141,24 @@ def merge_segment(self, indices, reduction_op): The merged segments will be placed at the end of the list. """ raise NotImplementedError + + def __repr__(self): + return 'LogicalTensor[{} with {} Segments]'.format( + tuple(self.shape), len(self.segments) + ) + + @staticmethod + def to_segments(*args, **kwargs): + args_segments = list() + for arg in args: + if isinstance(arg, LogicalTensor): + args_segments.append(arg.segments) + else: + args_segments.append(arg) + kwargs_segments = dict() + for key in kwargs: + if isinstance(kwargs[key], LogicalTensor): + kwargs_segments[key] = kwargs[key].segments + else: + kwargs_segments[key] = kwargs[key] + return tuple(args_segments), kwargs_segments diff --git a/tests/operator/test_holistic_linear.py b/tests/operator/test_holistic_linear.py index 1f5e4751..db64d221 100644 --- a/tests/operator/test_holistic_linear.py +++ b/tests/operator/test_holistic_linear.py @@ -12,7 +12,6 @@ """ from cube.tensor.logic.tensor import LogicalTensor -import cube.tensor.logic.segment as sg from cube.operator.holist.linear import LinearColumnWeight from cube.operator.holist.linear import LinearColumnInputRowWeight @@ -20,23 +19,10 @@ from cube.device.physic.group import DeviceGroup import torch +import z3 torch.manual_seed(100) -class LogicalLinear: - - def __init__(self): pass - - def shape_infer(self, input_shape, weight_shape, bias_shape=None): - """ - Return the outputs shape [list[int],] - """ - #TODO: change all shape impl to list - output_shape = list(input_shape.shape) - output_shape[-1] = weight_shape.shape[0] - return [output_shape] - - def test_linear_POC(): N = 1024 @@ -67,48 +53,61 @@ def test_holistic_linear_op_column_weight(): K results larger bias. """ N = 1024 - input = LogicalTensor(shape=(1024,1024)) - weight = LogicalTensor(shape=(N,1024)) - bias = LogicalTensor(shape=(N,)) - - # output = LogicalLinear(input, weight, bias) + shapes = [(1024, 1024), (N, 1024), (N,), (1024, N)] + input = LogicalTensor(shape=shapes[0]) + weight = LogicalTensor(shape=shapes[1]) + bias = LogicalTensor(shape=shapes[2]) # ================================ Policy =========================== - holistic_op = LinearColumnWeight() - holistic_op.logical_op = LogicalLinear() - - def policy_for_how_many_tiles(outliner): - if isinstance(outliner, sg.outline.Full): - pass - elif isinstance(outliner, sg.outline.SplitAxis): - if outliner.chunk_num.get() is None: - outliner.chunk_num = 4 - else: - raise TypeError("Unhandled outliner type") - # -> together - - def policy_for_each_tile_placement(community, input, weight, bias): - # generate results (hard code) [helper function] - input_ranks = [ - [[0,1,2,3]], [DeviceGroup().all_ranks()] - [[0],[1],[2],[3]], - [[0],[1],[2],[3]] - ] - input_val_map_fns = list([None, None, None]) - return input_ranks, input_val_map_fns + holistic_op = LinearColumnWeight(shapes) + + def policy(holist_op): + solver = holist_op.solver + attributes = holist_op.attributes + input_layout = holist_op.input_layouts[0] + weight_layout = holist_op.input_layouts[1] + bias_layout = holist_op.input_layouts[2] + output_layout = holist_op.output_layouts[0] + + # add restrictions based on device num + device_num = torch.cuda.device_count() + solver.add(weight_layout.chunk_num == 4) + + # iterate all configs + configs = list() + while solver.check() == z3.sat: + config = solver.model() + if DeviceGroup().rank == 0: + print('find config: {}'.format(config)) + configs.append(config) + solver.add( + z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) + ) + if len(attributes) == 0: + break + # choose one config -- suppose to the first + config = configs[0] + if DeviceGroup().rank == 0: + print('selected config: {}'.format(config)) + + # deploy decisions + chunk_num = config[weight_layout.chunk_num].as_long() + input_ranks = [list(range(0, chunk_num)),] + weight_ranks = list() + for rank in range(chunk_num): + weight_ranks.append([rank]) + bias_ranks = weight_ranks + + return config, [input_ranks, weight_ranks, bias_ranks] # Missing Policy: where physical op executed? - holistic_op.set_deploy_policy( - policy_for_each_tile_placement - ) - holistic_op.set_segmentation_policy( - policy_for_how_many_tiles - ) + holistic_op.set_policy(policy) # ================================ Policy =========================== output = holistic_op(input, weight, bias) + print('segments: {}'.format(len(output.segments))) # =============================== Test ============================== output_ref = torch._C._nn.linear( From 0261e9325d244ae0cd118398c6cb28962bbe6796 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 2 Aug 2021 11:50:57 +0000 Subject: [PATCH 0104/1892] select needs a value map-reduce op --- cube/operator/holist/linear.py | 6 ++--- cube/operator/physic/comm/mapreduce.py | 36 +++++++++++++++++++++++++ cube/operator/physic/comm/reduction.py | 31 --------------------- cube/tensor/logic/outline.py | 10 +++---- cube/tensor/logic/tensor.py | 4 +-- cube/tensor/segment.py | 37 +++++++++++++------------- tests/tensor/test_logical_tensor.py | 2 +- tests/tensor/test_outline.py | 4 +-- tests/tensor/test_segment.py | 11 ++++---- 9 files changed, 71 insertions(+), 70 deletions(-) create mode 100644 cube/operator/physic/comm/mapreduce.py delete mode 100644 cube/operator/physic/comm/reduction.py diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index af29d73e..18fc888c 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -1,6 +1,7 @@ from cube.operator.holist.generics import GenericHolisticOp import cube.operator.physic.linear as phy_linear +from cube.operator.physic.comm.mapreduce import PartialSum from cube.tensor.logic.tensor import LogicalTensor import cube.tensor.logic.outline as outline @@ -101,15 +102,14 @@ def __init__(self, shapes): bias_layout = outline.SplitValue( self.solver, self.shapes, chunk_num=input_layout.chunk_num, - val_map_op=lambda tensor, rank, world_size : tensor / world_size + val_op=PartialSum ) # output layout will only use reduce op output_layout = outline.SplitValue( self.solver, self.shapes, chunk_num=input_layout.chunk_num, - val_map_op=lambda tensor, rank, world_size : tensor, - val_reduce_op=lambda tensor, rank, world_size : tensor + val_op=PartialSum ) self.set_input_layouts([input_layout, weight_layout, bias_layout]) diff --git a/cube/operator/physic/comm/mapreduce.py b/cube/operator/physic/comm/mapreduce.py new file mode 100644 index 00000000..3b3beaf0 --- /dev/null +++ b/cube/operator/physic/comm/mapreduce.py @@ -0,0 +1,36 @@ +import torch + + +class ValueMapReduceOp: + + def __init__(self, val_map_op, val_reduce_op): + if not (callable(val_map_op) and callable(val_reduce_op)): + raise TypeError("Expected val_map_op and val_reduce_o callable") + self.val_map_op = (val_map_op,) + self.val_reduce_op = (val_reduce_op,) + + def map(self, tensor, group): + if not torch.is_tensor(tensor): + raise RuntimeError("Expected tensor to be torch.Tensor") + return self.val_map_op[0](tensor, group) + + def reduce(self, tensor, group): + if not torch.is_tensor(tensor): + raise RuntimeError("Expected `tensor` to be torch.Tensor") + return self.val_map_op[0](tensor, group) + + +def _val_split_map(tensor, group): + world_size = torch.distributed.get_world_size(group) + return tensor / world_size + + +def _val_sum_reduce(tensor, group): + torch.distributed.all_reduce(tensor, group=group) + return tensor + + +PartialSum = ValueMapReduceOp( + val_map_op = _val_split_map, + val_reduce_op = _val_sum_reduce +) diff --git a/cube/operator/physic/comm/reduction.py b/cube/operator/physic/comm/reduction.py deleted file mode 100644 index 07fdbec8..00000000 --- a/cube/operator/physic/comm/reduction.py +++ /dev/null @@ -1,31 +0,0 @@ -from cube.operator.physic.comm import replicate, reduce_sum -import torch - - -# TODO: reduction op should be in torch autograd function -class _Reduction(type): - - # forward: all_reduce, backward: identity - Sum = (reduce_sum,) - - # forward: identity, backward: all_reduce - Replica = (replicate,) - - def register(cls, name, udf): - """ - Reduction functions should be in function format: - - Arguments: - PhysicalTensor - Communication Group - - Return: - PhysicalTensor - """ - if hasattr(cls, name): - raise KeyError("{} is registered".format(name)) - setattr(cls, name, (udf,)) - - -class ReductionOp(metaclass=_Reduction): - pass diff --git a/cube/tensor/logic/outline.py b/cube/tensor/logic/outline.py index b7526f02..fbfc30e5 100644 --- a/cube/tensor/logic/outline.py +++ b/cube/tensor/logic/outline.py @@ -166,20 +166,18 @@ def interpret(self, logical_tensor, config): class SplitValue(BaseOutline): - def __init__(self, solver, shape, chunk_num, val_map_op): + def __init__(self, solver, shape, chunk_num, val_op): """ Split the whole tensor in value dimension. Each segment shape will be same with logical tensor. - Each segment value will be modified by `val_map_op`. + Each segment value will be modified by `val_op`. """ - if not callable(val_map_op): - raise TypeError("Expected val_map_op a callable function") super().__init__(solver, shape) self.add_field(chunk_num=chunk_num) self.solver.add(self.chunk_num >= 1) - self.val_map_op = val_map_op + self.val_op = val_op def interpret(self, logical_tensor, config): if tuple(logical_tensor.shape) != tuple(self.shape): @@ -189,6 +187,6 @@ def interpret(self, logical_tensor, config): for cid in range(chunk_num): # full tensor shape indices = TileIndices([0] * len(self.shape), self.shape) - segment = logical_tensor.select(indices, self.val_map_op, self.shape) + segment = logical_tensor.select(indices, self.val_op, self.shape) segments.append(segment) return segments diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index a1ccf465..d02ee882 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -40,12 +40,12 @@ def fill(self, physical_tensors, ranks): for segment, physical_tensor, ranks in zip(self.segments, physical_tensors, ranks): segment.set_physical_tensor(physical_tensor, ranks) - def select(self, indices, val_map_op, shape): + def select(self, indices, val_op, shape): """ Create a Segment given the indices for this logical tensor, and the Segment will use shape. """ - segment = Segment(self, indices, val_map_op, shape) + segment = Segment(self, indices, val_op, shape) return segment def transform(self, segments, ranks=None): diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py index 633f0c80..bcbea01d 100644 --- a/cube/tensor/segment.py +++ b/cube/tensor/segment.py @@ -6,7 +6,7 @@ class Segment: - def __init__(self, logical_tensor, indices, val_map_op, shape): + def __init__(self, logical_tensor, indices, val_op, shape): """Create Segment based on the logical tensor Segment manages: @@ -17,10 +17,9 @@ def __init__(self, logical_tensor, indices, val_map_op, shape): Attribute: indices (tuple(slice,) or list[list[int]]): indices of logical_tensor for this segment - val_map_op (None or callable): - deploy op to take logical value and map - merge_op (None or callable): - merge op to take physical tensor + val_op (ValueMapReduceOp): + deploy op to take logical value and group in for value mapping + merge op to take mapped value and group in for value reduction """ if not isinstance(indices, BaseIndices): raise TypeError("Expected indices to be BaseIndices") @@ -30,8 +29,6 @@ def __init__(self, logical_tensor, indices, val_map_op, shape): # segment info self.indices = indices - self.val_map_ops = list() - self.add_val_map_op(val_map_op) self.shape = tuple(shape) # physical tensor (the PyTorch Tensor) @@ -42,8 +39,9 @@ def __init__(self, logical_tensor, indices, val_map_op, shape): self.group = None self.materialized = False - # recover op - self.merge_op = None + # val ops + self.val_ops = list() + self.add_val_op(val_op) def deploy(self, ranks): """deploy (materialize) to physical tensors @@ -74,8 +72,8 @@ def deploy(self, ranks): self.physical_tensor.copy_( self.logical_tensor.data[self.indices.get()].reshape(self.shape) ) - for val_map_op in self.val_map_ops: - self.physical_tensor.data = val_map_op(self.physical_tensor, rank, len(ranks)) + for val_op in self.val_ops: + self.physical_tensor.data = val_op.map(self.physical_tensor, self.group) self.materialized = True def recover(self, reduction_op): @@ -97,14 +95,14 @@ def recover(self, reduction_op): else: raise RuntimeError("The Segment has not been materialized") - def add_val_map_op(self, val_map_op): + def add_val_op(self, val_op): """ - Append val_map_op to the end + Append val_op to the end """ - if val_map_op is not None: - if not callable(val_map_op): - raise TypeError("Expected val_map_op to be callable or None") - self.val_map_ops.append(val_map_op) + if val_op is not None: + if not (callable(val_op.map) and callable(val_op.reduce)): + raise TypeError("Expected val_op to be ValMapReudceOp") + self.val_ops.append(val_op) def get_physical_tensor(self): """Get physical tensor if materialized @@ -132,6 +130,7 @@ def set_physical_tensor(self, physical_tensor, ranks): self.materialized = True def __repr__(self): - msg = 'Segment(Indices: {} | Materialized: {})'.format(self.indices, self.materialized) - return msg + return 'Segment(Indices: {} | Materialized: {})'.format( + self.indices, self.materialized + ) \ No newline at end of file diff --git a/tests/tensor/test_logical_tensor.py b/tests/tensor/test_logical_tensor.py index e7b5e06e..6d9ad033 100644 --- a/tests/tensor/test_logical_tensor.py +++ b/tests/tensor/test_logical_tensor.py @@ -83,7 +83,7 @@ def test_logical_tensor_transform(): segment = tensor.select(indices, None, shape=(2,2)) ranks = [0,1,3] - tensor.transform([segment], [ranks], [None]) + tensor.transform([segment], [ranks]) myrank = DeviceGroup().rank if myrank in ranks: diff --git a/tests/tensor/test_outline.py b/tests/tensor/test_outline.py index 40bea106..edf91194 100644 --- a/tests/tensor/test_outline.py +++ b/tests/tensor/test_outline.py @@ -1,6 +1,7 @@ from cube.tensor.logic.tensor import LogicalTensor import cube.tensor.logic.outline as outline from cube.tensor.segment import Segment +from cube.operator.physic.comm.mapreduce import PartialSum import torch import z3 @@ -86,10 +87,9 @@ def test_split_axis_with_constraints(): def test_split_value(): shape = [1024, 32] - split_op = lambda tensor, rank, world_size : tensor / world_size solver = z3.Solver() - split_dsp = outline.SplitValue(solver, shape, None, split_op) + split_dsp = outline.SplitValue(solver, shape, None, PartialSum) split_dsp.solver.add(split_dsp.chunk_num <= 4) configs = list() for config in iter_each_config(solver, split_dsp.get_attributes()): diff --git a/tests/tensor/test_segment.py b/tests/tensor/test_segment.py index a63ae028..b7bf0947 100644 --- a/tests/tensor/test_segment.py +++ b/tests/tensor/test_segment.py @@ -15,6 +15,7 @@ from cube.tensor.segment import Segment from cube.tensor.indices import BaseIndices, TileIndices from cube.device.physic.group import DeviceGroup +from cube.operator.physic.comm.mapreduce import PartialSum import torch import os @@ -36,9 +37,8 @@ def test_segment_init(): assert segment.physical_tensor is None assert len(segment.placement) == 0 assert segment.group is None - assert len(segment.val_map_ops) == 0 + assert len(segment.val_ops) == 0 assert segment.materialized is False - assert segment.merge_op is None def test_segment_deploy(): @@ -64,9 +64,8 @@ def test_segment_deploy(): assert physical_tensor is None assert segment.placement == ranks assert segment.group == DeviceGroup().get_group(ranks) - assert len(segment.val_map_ops) == 0 + assert len(segment.val_ops) == 0 assert segment.materialized is True - assert segment.merge_op is None def test_segment_deploy_with_val_map(): @@ -81,10 +80,10 @@ def test_segment_deploy_with_val_map(): segment = Segment( logical_tensor = tensor, indices = indices, - val_map_op = lambda tensor, rank, world_size: tensor / world_size, + val_op = PartialSum, shape = ofst ) - assert len(segment.val_map_ops) == 1 + assert len(segment.val_ops) == 1 ranks = [0,2] segment.deploy(ranks) From b2944ef559811e1744d4bd08a0d1193d8b0a1f9b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 2 Aug 2021 14:07:56 +0000 Subject: [PATCH 0105/1892] log among-segments reduction info --- cube/operator/holist/linear.py | 12 ++-- cube/tensor/logic/outline.py | 2 + cube/tensor/logic/tensor.py | 7 ++- cube/tensor/segment.py | 44 +++++++++---- tests/operator/test_holistic_linear.py | 86 +++++++++++++++++++++++++- 5 files changed, 130 insertions(+), 21 deletions(-) diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index 18fc888c..560c6408 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -90,24 +90,24 @@ def __init__(self, shapes): super().__init__(shapes) input_layout = outline.SplitAxis( - self.solver, self.shapes, + self.solver, self.shapes[0], axis=-1, chunk_num=None, overlap=0, ) weight_layout = outline.SplitAxis( - self.solver, self.shapes, + self.solver, self.shapes[1], axis=1, chunk_num=input_layout.chunk_num, overlap=0, ) bias_layout = outline.SplitValue( - self.solver, self.shapes, + self.solver, self.shapes[2], chunk_num=input_layout.chunk_num, val_op=PartialSum ) # output layout will only use reduce op output_layout = outline.SplitValue( - self.solver, self.shapes, + self.solver, self.shapes[3], chunk_num=input_layout.chunk_num, val_op=PartialSum ) @@ -121,8 +121,8 @@ def forward(self, input, weight, bias): phy_op = phy_linear.Linear(placement=weight_seg.placement) output = phy_op( input_seg.get_physical_tensor(), - weight.get_physical_tensor(), - bias.get_physical_tensor() + weight_seg.get_physical_tensor(), + bias_seg.get_physical_tensor() ) outputs.append(output) return outputs diff --git a/cube/tensor/logic/outline.py b/cube/tensor/logic/outline.py index fbfc30e5..ccc19410 100644 --- a/cube/tensor/logic/outline.py +++ b/cube/tensor/logic/outline.py @@ -189,4 +189,6 @@ def interpret(self, logical_tensor, config): indices = TileIndices([0] * len(self.shape), self.shape) segment = logical_tensor.select(indices, self.val_op, self.shape) segments.append(segment) + for segment in segments: + segment.val_op_segs.append(segments) return segments diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py index d02ee882..79283912 100644 --- a/cube/tensor/logic/tensor.py +++ b/cube/tensor/logic/tensor.py @@ -57,6 +57,7 @@ def transform(self, segments, ranks=None): raise ValueError("Expected ranks to be a list with equal length of segments") if len(self.segments) == 0: + # setting up the placement for all segments for sid in range(len(segments)): segment = segments[sid] self.add_segment(segment) @@ -64,7 +65,11 @@ def transform(self, segments, ranks=None): deploy_ranks = ranks[sid] if not isinstance(deploy_ranks, list): raise TypeError('Expected ranks to be list[list[int],]') - segment.deploy(deploy_ranks) + segment.placement = deploy_ranks + # deploy with the placement + for segment in self.segments: + if not segment.materialized: + segment.deploy() #TODO: segment transformation on existing segments else: raise NotImplementedError diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py index bcbea01d..942d2f63 100644 --- a/cube/tensor/segment.py +++ b/cube/tensor/segment.py @@ -31,6 +31,11 @@ def __init__(self, logical_tensor, indices, val_op, shape): self.indices = indices self.shape = tuple(shape) + # val ops + self.val_ops = list() + self.val_op_segs = list() + self.add_val_op(val_op) + # physical tensor (the PyTorch Tensor) self.physical_tensor = None @@ -39,11 +44,7 @@ def __init__(self, logical_tensor, indices, val_op, shape): self.group = None self.materialized = False - # val ops - self.val_ops = list() - self.add_val_op(val_op) - - def deploy(self, ranks): + def deploy(self, ranks=None): """deploy (materialize) to physical tensors Materialize physical tensors for this community and spread out @@ -53,18 +54,27 @@ def deploy(self, ranks): to spread. Argument: - ranks (list[int]): device id list + ranks (list[int] or None): + if rank id list: deploy based on this list + if None: deploy based on setted self.placement value_map_op (callable): takes the tensor, rank, world_size, return a new tensor """ - if not isinstance(ranks, list): - raise TypeError("Expected ranks in list[int]") + if isinstance(ranks, list): + self.placement = ranks + elif ranks is None and self.placement is None: + raise TypeError("Expected self.placement when ranks is None") + + #TODO: remove this constraints + if len(self.val_ops) > 0 and len(self.placement) > 1: + raise RuntimeError("Currently segment with val_ops only allows to deploy on one rank") rank = DeviceGroup().rank - self.placement = ranks - self.group = DeviceGroup().get_group(ranks) - if rank in ranks: + self.group = DeviceGroup().get_group(self.placement) + + # set physical tensors + if rank in self.placement: if self.logical_tensor.data is None: raise RuntimeError("Try deploying a segment from a logical tensor without data") # select from logical data @@ -72,8 +82,16 @@ def deploy(self, ranks): self.physical_tensor.copy_( self.logical_tensor.data[self.indices.get()].reshape(self.shape) ) - for val_op in self.val_ops: - self.physical_tensor.data = val_op.map(self.physical_tensor, self.group) + + # go through val_op + for val_op, segs in zip(self.val_ops, self.val_op_segs): + if len(segs) == 0: + raise RuntimeError("Missing segments for val op") + op_ranks = [seg.placement[0] for seg in segs] + group = DeviceGroup().get_group(op_ranks) + if rank in self.placement: + self.physical_tensor.data = val_op.map(self.physical_tensor, group) + self.materialized = True def recover(self, reduction_op): diff --git a/tests/operator/test_holistic_linear.py b/tests/operator/test_holistic_linear.py index db64d221..885b9564 100644 --- a/tests/operator/test_holistic_linear.py +++ b/tests/operator/test_holistic_linear.py @@ -126,8 +126,92 @@ def policy(holist_op): # =============================== Test ============================== +def test_holistic_linear_op_column_input_row_weight(): + + N = 1024 + shapes = [(1024, 1024), (N, 1024), (N,), (1024, N)] + input = LogicalTensor(shape=shapes[0]) + weight = LogicalTensor(shape=shapes[1]) + bias = LogicalTensor(shape=shapes[2]) + + # ================================ Policy =========================== + + holistic_op = LinearColumnInputRowWeight(shapes) + + def policy(holist_op): + solver = holist_op.solver + attributes = holist_op.attributes + input_layout = holist_op.input_layouts[0] + weight_layout = holist_op.input_layouts[1] + bias_layout = holist_op.input_layouts[2] + output_layout = holist_op.output_layouts[0] + + # add restrictions based on device num + device_num = torch.cuda.device_count() + solver.add(weight_layout.chunk_num == 4) + + # iterate all configs + configs = list() + while solver.check() == z3.sat: + config = solver.model() + if DeviceGroup().rank == 0: + print('find config: {}'.format(config)) + configs.append(config) + solver.add( + z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) + ) + if len(attributes) == 0: + break + # choose one config -- suppose to the first + config = configs[0] + if DeviceGroup().rank == 0: + print('selected config: {}'.format(config)) + + # deploy decisions + chunk_num = config[weight_layout.chunk_num].as_long() + input_ranks = list() + for rank in range(chunk_num): + input_ranks.append([rank]) + weight_ranks = input_ranks + bias_ranks = weight_ranks + + return config, [input_ranks, weight_ranks, bias_ranks] + + # Missing Policy: where physical op executed? + + holistic_op.set_policy(policy) + # ================================ Policy =========================== + + output = holistic_op(input, weight, bias) + print('segments: {}'.format(len(output.segments))) + + # =============================== Test ============================== + rank = DeviceGroup().rank + input_ref = torch.chunk(input.data.cuda(), chunks=4, dim=-1)[rank] + weight_ref = torch.chunk(weight.data.cuda(), chunks=4, dim=1)[rank] + bias_ref = bias.data.cuda() / 4 + # if rank == 0: + # print('input ref: ', input_ref) + # print('weight ref: ', weight_ref) + # print('bias ref: ', bias_ref) + + output_ref = torch._C._nn.linear( + input_ref, weight_ref, bias_ref + ) + out = output.get_physical_tensor(rank) + # if rank == 0: + # print('ref: ', output_ref) + # print('get: ', out) + # print('max bias: ', torch.max(torch.abs(out - output_ref))) + # print('sum bias: ', torch.sum(torch.abs(out - output_ref))) + error_max = torch.max(torch.abs(out - output_ref)) + assert error_max.item() < 2e-4 + # =============================== Test ============================== + + if __name__ == '__main__': group = DeviceGroup() # test_linear_POC() - test_holistic_linear_op_column_weight() \ No newline at end of file + test_holistic_linear_op_column_weight() + test_holistic_linear_op_column_input_row_weight() From 0f4bd2f0b5ca575acddd17d99cfee76f1444c2e6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 3 Aug 2021 10:38:02 +0000 Subject: [PATCH 0106/1892] add constraints interface --- cube/tensor/logic/outline.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/cube/tensor/logic/outline.py b/cube/tensor/logic/outline.py index ccc19410..399fa322 100644 --- a/cube/tensor/logic/outline.py +++ b/cube/tensor/logic/outline.py @@ -57,7 +57,7 @@ def add_field(self, **kwargs): raise TypeError("{} only supports list[int] choices".format(key)) self.__dict__[key] = z3.Int(key) self.attributes.append(self.__dict__[key]) - self.solver.add(z3.Or([self.__dict__[key] == val for val in val])) + self.solver.add(z3.Or([self.__dict__[key] == v for v in val])) elif isinstance(val, int): self.__dict__[key] = z3.Int(str(id(self))+key) self.attributes.append(self.__dict__[key]) @@ -75,6 +75,14 @@ def add_field(self, **kwargs): else: raise TypeError("{} can only be int, list[int], z3.Int()".format(key)) + def add_constraint(self, constraint): + """ + Add a constraint + """ + if not isinstance(constraint, z3.z3.BoolRef): + raise TypeError("Expected z3.z3.BoolRef constraints") + self.solver.add(constraint) + def remove_config(self, config): if not isinstance(config, z3.z3.ModelRef): raise TypeError("Expected config from z3 model()") @@ -123,17 +131,19 @@ def __init__(self, solver, shape, axis, chunk_num, overlap): self.axis = axis self.add_field(overlap=overlap) - self.solver.add(self.overlap >= 0) + self.add_constraint(self.overlap >= 0) self.add_field(chunk_num=chunk_num) - self.solver.add(self.chunk_num >= 0) + self.add_constraint(self.chunk_num >= 0) # TODO: change to array to adapt with non-uniform cases self.add_field(chunk_size=None) # setup constraints total_size = self.shape[self.axis] - self.solver.add(self.chunk_num * self.chunk_size - self.overlap * (self.chunk_num - 1) == total_size) + self.add_constraint( + self.chunk_num * self.chunk_size - self.overlap * (self.chunk_num - 1) == total_size + ) def interpret(self, logical_tensor, config): """ @@ -176,7 +186,7 @@ def __init__(self, solver, shape, chunk_num, val_op): """ super().__init__(solver, shape) self.add_field(chunk_num=chunk_num) - self.solver.add(self.chunk_num >= 1) + self.add_constraint(self.chunk_num >= 1) self.val_op = val_op def interpret(self, logical_tensor, config): From 91065879ea71613639d80cf68ac7eb2d4c046be8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 3 Aug 2021 11:36:30 +0000 Subject: [PATCH 0107/1892] add linear example --- cube/__init__.py | 3 + cube/nn/__init__.py | 1 + cube/nn/linear.py | 30 +++++++ cube/operator/holist/generics.py | 38 ++++++--- cube/operator/holist/linear.py | 16 ++-- cube/operator/logic/__init__.py | 1 + cube/operator/logic/generics.py | 53 +++++++----- cube/operator/logic/linear.py | 18 ++--- examples/ffn.py | 93 +++++++++++++++++++++ examples/linear.py | 134 ++++++++++++++++--------------- 10 files changed, 276 insertions(+), 111 deletions(-) create mode 100644 cube/nn/__init__.py create mode 100644 cube/nn/linear.py create mode 100644 examples/ffn.py diff --git a/cube/__init__.py b/cube/__init__.py index e69de29b..39c29770 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -0,0 +1,3 @@ +from cube import operator +from cube import nn +from cube import device diff --git a/cube/nn/__init__.py b/cube/nn/__init__.py new file mode 100644 index 00000000..4394da26 --- /dev/null +++ b/cube/nn/__init__.py @@ -0,0 +1 @@ +from cube.nn.linear import Linear \ No newline at end of file diff --git a/cube/nn/linear.py b/cube/nn/linear.py new file mode 100644 index 00000000..8f43fbe2 --- /dev/null +++ b/cube/nn/linear.py @@ -0,0 +1,30 @@ +import torch +from cube.tensor.logic.tensor import LogicalTensor +import cube.operator.logic as logic_op + +import math + +from torch import nn + +class Linear(nn.Module): + + __constants__ = ['in_features', 'out_features'] + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super(Linear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = LogicalTensor((out_features, in_features)) + if bias: + self.bias = LogicalTensor((out_features,)) + self.reset_parameters() + # Actually here we can pass shapes + self.op = logic_op.Linear() + + def reset_parameters(self) -> None: + pass + + def forward(self, input: LogicalTensor) -> LogicalTensor: + return self.op(input, self.weight, self.bias) diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index aa9aa89f..b940e20e 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -18,6 +18,8 @@ class GenericHolisticOp: + _default_policy_fn = None + def __init__(self, shapes): """ Layout is the community distribution requirement for input and @@ -77,14 +79,14 @@ def set_output_layouts(self, layouts): self.attributes += layout.get_attributes() self.output_layouts.append(layout) - def set_logic_op(self, logic_op): + def add_constraint(self, constraint): """ - Set logic op. This will be automatically called when the - holistic op registered in a logical op. + Add cross-layout constraint to the solver """ - # if not isinstance(logic_op, GenericLogicalOp): - # raise TypeError("Require a logic op to register") - self.logical_op = logic_op + if not isinstance(constraint, z3.z3.BoolRef): + raise TypeError("Expected z3.z3.BoolRef constraints") + self.solver.add(constraint) + def set_config(self, config): if not isinstance(config, z3.z3.ModelRef): @@ -111,9 +113,10 @@ def input_adapter(self, *args, **kwargs): # input.permute(dim_order) # step 2: Policy: segmentation + deploy decision - if self.policy_fn is None: - raise RuntimeError("Expected a runtime configuration policy") - config, input_ranks = self.policy_fn[0](self) + policy_fn = self._default_policy_fn + if self.policy_fn is not None: + policy_fn = self.policy_fn + config, input_ranks = policy_fn[0](self) self.set_config(config) # step 3: segmentation @@ -203,9 +206,8 @@ def __call__(self, *args, **kwargs): def set_policy(self, policy_fn): """ - Register a policy to take layouts and solver, - generate device placement for each community, and corresponding - message mapping + Register a customized policy to take layouts and solver, + generate segmentation plan and deploy plan Args: plicy_fn (callable) @@ -214,3 +216,15 @@ def set_policy(self, policy_fn): raise TypeError("Expected callable function") self.policy_fn = (policy_fn,) + @classmethod + def set_default_policy(cls, policy_fn): + """ + Register a policy for all instances. Take layouts and solver, + generate segmentation plan and deploy plan + + Args: + plicy_fn (callable) + """ + if not callable(policy_fn): + raise TypeError("Expected callable function") + cls._default_policy_fn = (policy_fn,) diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index 560c6408..b40e6849 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -36,13 +36,16 @@ def __init__(self, shapes): ) bias_layout = outline.SplitAxis( self.solver, self.shapes[2], - axis=0, chunk_num=weight_layout.chunk_num, overlap=0 + axis=0, chunk_num=None, overlap=0 ) + self.add_constraint(bias_layout.chunk_num == weight_layout.chunk_num) + # output layouts output_layout = outline.SplitAxis( self.solver, self.shapes[3], - axis=1, chunk_num=weight_layout.chunk_num, overlap=0 + axis=1, chunk_num=None, overlap=0 ) + self.add_constraint(output_layout.chunk_num == weight_layout.chunk_num) self.set_input_layouts([input_layout, weight_layout, bias_layout]) self.set_output_layouts([output_layout]) @@ -96,21 +99,24 @@ def __init__(self, shapes): weight_layout = outline.SplitAxis( self.solver, self.shapes[1], - axis=1, chunk_num=input_layout.chunk_num, overlap=0, + axis=1, chunk_num=None, overlap=0, ) + self.add_constraint(weight_layout.chunk_num == input_layout.chunk_num) bias_layout = outline.SplitValue( self.solver, self.shapes[2], - chunk_num=input_layout.chunk_num, + chunk_num=None, val_op=PartialSum ) + self.add_constraint(bias_layout.chunk_num == input_layout.chunk_num) # output layout will only use reduce op output_layout = outline.SplitValue( self.solver, self.shapes[3], - chunk_num=input_layout.chunk_num, + chunk_num=None, val_op=PartialSum ) + self.add_constraint(output_layout.chunk_num == input_layout.chunk_num) self.set_input_layouts([input_layout, weight_layout, bias_layout]) self.set_output_layouts([output_layout]) diff --git a/cube/operator/logic/__init__.py b/cube/operator/logic/__init__.py index e69de29b..39163d3e 100644 --- a/cube/operator/logic/__init__.py +++ b/cube/operator/logic/__init__.py @@ -0,0 +1 @@ +from cube.operator.logic.linear import Linear \ No newline at end of file diff --git a/cube/operator/logic/generics.py b/cube/operator/logic/generics.py index 47c47f40..5cd2cd73 100644 --- a/cube/operator/logic/generics.py +++ b/cube/operator/logic/generics.py @@ -1,5 +1,4 @@ """ - A Logical Operator: * Statusless * Can be executed by only one kernel (atomic) on single device @@ -11,9 +10,11 @@ |- ... Holistic operators are allowed to nested in hybrid-distribution strategy - """ +from cube.tensor.logic.tensor import LogicalTensor + + class HolisticOpFactory: def __init__(self): @@ -46,23 +47,13 @@ def get_op(self, idx, shapes): class GenericLogicalOp: + _default_policy_fn = None + def __init__(self): # candidate holistic operator self.factory = HolisticOpFactory() self.policy_fn = None - - def set_policy(self, policy_fn): - """ - Register a policy function to customize how composite - holistic op generated during runtime. - - The `policy_fn` takes self.factory as input and returns a composite - holistic operator (callable) - """ - if not callable(policy_fn): - raise TypeError("Expected a callable function") - self.policy_fn = [policy_fn] def shape_infer(self, *args, **kwargs): """ @@ -80,19 +71,20 @@ def get_shapes(self, *args, **kwargs): # get shapes of input and output shapes = list() for arg in args: - if isinstance(LogicalTensor, arg): + if isinstance(arg, LogicalTensor): shapes.append(arg.shape) else: shapes.append(None) - shapes.append(self.shape_infer(*args, **kwargs)) + shapes += self.shape_infer(*args, **kwargs) return shapes def get_op(self, *args, **kwargs): # get shapes of input and output - shapes = self.get_shapes() + shapes = self.get_shapes(*args, **kwargs) + print(shapes) # use default policy if self.policy_fn is None: - composite_op = self.factory.get_op(0, shapes) + composite_op = self._default_policy_fn[0](self.factory, shapes) # use user-customized policy else: composite_op = self.policy_fn[0](self.factory, shapes) @@ -106,3 +98,28 @@ def __call__(self, *args, **kwargs): # run operator with the strategy plan outputs = composite_op(*args, **kwargs) return outputs + + def set_policy(self, policy_fn): + """ + Register a policy function to customize how composite + holistic op generated during runtime. + + The `policy_fn` takes self.factory as input and returns a composite + holistic operator (callable) + """ + if not callable(policy_fn): + raise TypeError("Expected a callable function") + self.policy_fn = (policy_fn,) + + @classmethod + def set_default_policy(self, policy_fn): + """ + Register a default policy function to all instances. + Customize how composite holistic op generated during runtime. + + The `policy_fn` takes self.factory and shapes as input, + and returns a composite holistic operator (callable) + """ + if not callable(policy_fn): + raise TypeError("Expected a callable function") + self._default_policy_fn = (policy_fn,) diff --git a/cube/operator/logic/linear.py b/cube/operator/logic/linear.py index 33c9b80a..c1a108a5 100644 --- a/cube/operator/logic/linear.py +++ b/cube/operator/logic/linear.py @@ -1,25 +1,21 @@ -from cube.operator.logic.generics import generics +from cube.operator.logic.generics import GenericLogicalOp from cube.operator.holist.linear import kHolistLinearSets -__all__ = ['Linear'] - - -def Linear(generics.GenericLogicalOp): +class Linear(GenericLogicalOp): def __init__(self): - super().__init__(self) + super().__init__() # register holistic operators for holist_op in kHolistLinearSets: - holist_op.set_logic_op(self) self.factory.register(holist_op) - def shape_infer(self, input_shape, weight_shape, bias_shape=None): + def shape_infer(self, input, weight, bias=None): """ Return the outputs shape [list[int],] """ - output_shape = input_shape.shape - output_shape[-1] = weight_shape.shape[0] - return [output_shape] + output_shape = list(input.shape) + output_shape[-1] = weight.shape[0] + return [output_shape,] diff --git a/examples/ffn.py b/examples/ffn.py new file mode 100644 index 00000000..b5e90099 --- /dev/null +++ b/examples/ffn.py @@ -0,0 +1,93 @@ +""" +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=6000 \ + --use_env \ + examples/linear.py +""" + +import torch +from torch import nn +from torch import Tensor +from torch.nn.parameter import Parameter +import torch.nn.functional as F + +import math +import argparse + + +class Linear(nn.Module): + + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super(Linear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input: Tensor) -> Tensor: + return combo_op.linear_op(input, self.weight, self.bias) + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=16, classes=1000): + super().__init__() + self.net = nn.Sequential( + Linear(dim, dim * mult), + nn.GELU(), + nn.Dropout(dropout), + Linear(dim * mult, dim) + ) + + self.classifier = Linear(dim, classes) + + def forward(self, x, labels): + output = self.net(x) + output = self.classifier(output) + loss = F.cross_entropy(output, labels) + return loss + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--dim', type=int, default=1024) + parser.add_argument('--heads', type=int, default=16) + parser.add_argument('--bs', type=int, default=8) + parser.add_argument('--classes', type=int, default=10) + args = parser.parse_args() + + # init distributed env + group = combo.physical.device.group.DeviceGroup() + print(group) + + model = FeedForward(args.dim, mult=args.heads, classes=args.classes) + model = model.cuda() + + inputs = torch.rand((args.bs, args.dim)).cuda() + labels = torch.randint(0, 10, (args.bs, )).cuda() + for _ in range(100): + loss = model(inputs, labels) + loss.backward() + print('Done.') \ No newline at end of file diff --git a/examples/linear.py b/examples/linear.py index 96deec6b..adb933cc 100644 --- a/examples/linear.py +++ b/examples/linear.py @@ -9,88 +9,92 @@ examples/linear.py """ +import cube +from cube.tensor.logic.tensor import LogicalTensor +from cube.device.physic.group import DeviceGroup + import torch from torch import nn -from torch import Tensor -from torch.nn.parameter import Parameter -import torch.nn.functional as F - -import combo -import combo.physical.operator as combo_op - -import math import argparse +import z3 + +torch.manual_seed(100) + + +# Expert Policy + +def select_policy(holistic_ops, shapes): + return holistic_ops.get_op(0, shapes) + + +def segment_policy(holist_op): + solver = holist_op.solver + attributes = holist_op.attributes + input_layout = holist_op.input_layouts[0] + weight_layout = holist_op.input_layouts[1] + bias_layout = holist_op.input_layouts[2] + output_layout = holist_op.output_layouts[0] + # add restrictions based on device num + device_num = torch.cuda.device_count() + solver.add(weight_layout.chunk_num == 4) + + # iterate all configs + configs = list() + while solver.check() == z3.sat: + config = solver.model() + if DeviceGroup().rank == 0: + print('find config: {}'.format(config)) + configs.append(config) + solver.add( + z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) + ) + if len(attributes) == 0: + break + # choose one config -- suppose to the first + config = configs[0] + if DeviceGroup().rank == 0: + print('selected config: {}'.format(config)) + # deploy decisions + chunk_num = config[weight_layout.chunk_num].as_long() + input_ranks = [list(range(0, chunk_num)),] + weight_ranks = list() + for rank in range(chunk_num): + weight_ranks.append([rank]) + bias_ranks = weight_ranks + return config, [input_ranks, weight_ranks, bias_ranks] -class Linear(nn.Module): - - __constants__ = ['in_features', 'out_features'] - in_features: int - out_features: int - weight: Tensor - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super(Linear, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs)) - if bias: - self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter('bias', None) - self.reset_parameters() +cube.operator.logic.linear.Linear.set_default_policy(select_policy) +cube.operator.holist.linear.LinearColumnWeight.set_default_policy(segment_policy) - def reset_parameters(self) -> None: - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(self.bias, -bound, bound) - def forward(self, input: Tensor) -> Tensor: - return combo_op.linear_op(input, self.weight, self.bias) +# User Network +class SingleLinear(nn.Module): -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): + def __init__(self, dim, mult): super().__init__() - self.net = nn.Sequential( - Linear(dim, dim * mult), - nn.GELU(), - nn.Dropout(dropout), - Linear(dim * mult, dim) - ) - - self.classifier = Linear(dim, classes) - - def forward(self, x, labels): + self.net = cube.nn.Linear(dim, dim * mult) + + def forward(self, x): output = self.net(x) - output = self.classifier(output) - loss = F.cross_entropy(output, labels) - return loss + return output if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--dim', type=int, default=1024) - parser.add_argument('--heads', type=int, default=16) - parser.add_argument('--bs', type=int, default=8) - parser.add_argument('--classes', type=int, default=10) + parser.add_argument('--bs', type=int, default=32) + parser.add_argument('--dim', type=int, default=128) + parser.add_argument('--mult', type=int, default=16) args = parser.parse_args() # init distributed env - group = combo.physical.device.group.DeviceGroup() - print(group) - - model = FeedForward(args.dim, mult=args.heads, classes=args.classes) - model = model.cuda() - - inputs = torch.rand((args.bs, args.dim)).cuda() - labels = torch.randint(0, 10, (args.bs, )).cuda() - for _ in range(100): - loss = model(inputs, labels) - loss.backward() - print('Done.') \ No newline at end of file + group = DeviceGroup() + + model = SingleLinear(args.dim, args.mult) + + inputs = LogicalTensor((args.bs, args.dim)) + output = model(inputs) + print('Done.') From e071c9607b339357997f7dc2934c8f6bd6b49b89 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 3 Aug 2021 11:39:36 +0000 Subject: [PATCH 0108/1892] include torch.nn.Module into cube --- cube/nn/__init__.py | 1 + examples/linear.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cube/nn/__init__.py b/cube/nn/__init__.py index 4394da26..7a8bac35 100644 --- a/cube/nn/__init__.py +++ b/cube/nn/__init__.py @@ -1 +1,2 @@ +from torch.nn import Module from cube.nn.linear import Linear \ No newline at end of file diff --git a/examples/linear.py b/examples/linear.py index adb933cc..99f6a36d 100644 --- a/examples/linear.py +++ b/examples/linear.py @@ -10,11 +10,11 @@ """ import cube +from cube import nn from cube.tensor.logic.tensor import LogicalTensor from cube.device.physic.group import DeviceGroup import torch -from torch import nn import argparse import z3 From 70cb8f9645f763c449580919bd8b1b71413549ce Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 3 Aug 2021 12:35:01 +0000 Subject: [PATCH 0109/1892] outline takes outputs ind input for init --- cube/operator/holist/generics.py | 44 +++++------------ cube/operator/holist/linear.py | 37 ++++++-------- cube/operator/logic/generics.py | 8 ++- cube/operator/logic/linear.py | 1 - cube/tensor/logic/outline.py | 68 ++++++++++++-------------- tests/operator/test_holistic_linear.py | 10 ++-- 6 files changed, 71 insertions(+), 97 deletions(-) diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py index b940e20e..3780b26e 100644 --- a/cube/operator/holist/generics.py +++ b/cube/operator/holist/generics.py @@ -20,28 +20,18 @@ class GenericHolisticOp: _default_policy_fn = None - def __init__(self, shapes): + def __init__(self, outputs, *args, **kwargs): """ Layout is the community distribution requirement for input and output logical tensors. - Format is the dimension ordering based on the logical format, - `None` indicates the format is consistent with logical op, - otherwise should be a list of integers like torch.Tensor.permute() - on the logical required format. - Args: - input_layout (list[Outliner, None]): outliner for each input - The length of outliner should be equal to the number of input - input_format (list[list[int], None]): - input dim order compare with logical definition - output_layout (list[Outlinter, None]): outliner for each output - The length of outliner should be equal to the number of output - output_format (list[list[int], None]): - output dim order compare with logical definition + outputs (list[LogicalTensor]): + output logical tensor (empty data) + *args, **kwargs: input arguments + """ self.solver = z3.Solver() - self.shapes = shapes self.input_layouts = list() self.output_layouts = list() @@ -86,7 +76,6 @@ def add_constraint(self, constraint): if not isinstance(constraint, z3.z3.BoolRef): raise TypeError("Expected z3.z3.BoolRef constraints") self.solver.add(constraint) - def set_config(self, config): if not isinstance(config, z3.z3.ModelRef): @@ -100,23 +89,16 @@ def input_adapter(self, *args, **kwargs): have tensors """ #TODO: kwargs - - input_num = len(args) - if len(self.input_layouts) != input_num: + if len(self.input_layouts) != len(args): raise RuntimeError("Fail to adapt input: layout length not equal") - # if len(self.input_format) != input_num: - # raise RuntimeError("Fail to adapt input: format length not equal") - - # step 1: data reformat based on the input argument - # for input, dim_order in zip(args, self.input_format): - # if dim_order is not None: - # input.permute(dim_order) + + # step1: TODO: format (dimension reorder support) # step 2: Policy: segmentation + deploy decision policy_fn = self._default_policy_fn if self.policy_fn is not None: policy_fn = self.policy_fn - config, input_ranks = policy_fn[0](self) + config, input_ranks = policy_fn[0](self, *args, **kwargs) self.set_config(config) # step 3: segmentation @@ -179,12 +161,8 @@ def output_adapter(self, outputs): ) logical_outputs.append(logical_tensor) - # step 2: data reformat based on the output - # for out_id in range(len(self.output_format)): - # dim_order = self.output_format[out_id] - # if dim_order is not None and isinstance(logical_outputs[out_id], LogicalTensor): - # logical_ouputs[out_id] = logical_ouputs[out_id].permute(dim_order) - + # step 2: TODO: data reformat based on the output + if len(logical_outputs) == 1: return logical_outputs[0] else: diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py index b40e6849..34f3c5ec 100644 --- a/cube/operator/holist/linear.py +++ b/cube/operator/holist/linear.py @@ -10,11 +10,6 @@ from cube.device.physic.group import DeviceGroup import torch -# expert space to declare all kinds of holistic operators - - -__all__ = ['kHolistLinearSets'] - class LinearColumnWeight(GenericHolisticOp): """ @@ -22,27 +17,26 @@ class LinearColumnWeight(GenericHolisticOp): Split W and b on the last dimension """ - def __init__(self, shapes): + def __init__(self, outputs, input, weight, bias): - super().__init__(shapes) + super().__init__(outputs, input, weight, bias) # input layouts - input_layout = outline.Full( - self.solver, self.shapes[0] - ) + input_layout = outline.Full(self.solver, input) + weight_layout = outline.SplitAxis( - self.solver, self.shapes[1], + self.solver, weight, axis=0, chunk_num=None, overlap=0 ) bias_layout = outline.SplitAxis( - self.solver, self.shapes[2], + self.solver, bias, axis=0, chunk_num=None, overlap=0 ) self.add_constraint(bias_layout.chunk_num == weight_layout.chunk_num) # output layouts output_layout = outline.SplitAxis( - self.solver, self.shapes[3], + self.solver, outputs[0], axis=1, chunk_num=None, overlap=0 ) self.add_constraint(output_layout.chunk_num == weight_layout.chunk_num) @@ -82,29 +76,30 @@ class LinearColumnInputRowWeight(GenericHolisticOp): """ Perform Y = XW + b - -> Y = [X1,X2] * [W1//W2] + b] - -> Y = X1W1 + X2W2 + b + -> Y = [X1,X2] * [W1//W2] + [b1 + b2]] + -> Y = X1W1 + X2W2 + b1 + b2 Split X (inputs) in column major (last dim), Split W (weights) in row major (first dim) + Split b (bias) in value major """ - def __init__(self, shapes): + def __init__(self, outputs, input, weight, bias): - super().__init__(shapes) + super().__init__(outputs, input, weight, bias) input_layout = outline.SplitAxis( - self.solver, self.shapes[0], + self.solver, input, axis=-1, chunk_num=None, overlap=0, ) weight_layout = outline.SplitAxis( - self.solver, self.shapes[1], + self.solver, weight, axis=1, chunk_num=None, overlap=0, ) self.add_constraint(weight_layout.chunk_num == input_layout.chunk_num) bias_layout = outline.SplitValue( - self.solver, self.shapes[2], + self.solver, bias, chunk_num=None, val_op=PartialSum ) @@ -112,7 +107,7 @@ def __init__(self, shapes): # output layout will only use reduce op output_layout = outline.SplitValue( - self.solver, self.shapes[3], + self.solver, outputs[0], chunk_num=None, val_op=PartialSum ) diff --git a/cube/operator/logic/generics.py b/cube/operator/logic/generics.py index 5cd2cd73..3bfb2773 100644 --- a/cube/operator/logic/generics.py +++ b/cube/operator/logic/generics.py @@ -33,16 +33,20 @@ def register(self, holistic_op): """ self.holist_ops.append(holistic_op) - def get_op(self, idx, shapes): + def get_op(self, idx, *args, **kwargs): """ Get holistic operator based on idx The holistic operator will be initialized with shapes + Args: + idx (int): index for the holist op factory + args, kwargs: (logical) tensor inputs + Returns: HolisticOp instance """ - return self.holist_ops[idx](shapes) + return self.holist_ops[idx](*args, **kwargs) class GenericLogicalOp: diff --git a/cube/operator/logic/linear.py b/cube/operator/logic/linear.py index c1a108a5..e8f32314 100644 --- a/cube/operator/logic/linear.py +++ b/cube/operator/logic/linear.py @@ -18,4 +18,3 @@ def shape_infer(self, input, weight, bias=None): output_shape = list(input.shape) output_shape[-1] = weight.shape[0] return [output_shape,] - diff --git a/cube/tensor/logic/outline.py b/cube/tensor/logic/outline.py index 399fa322..b94e3984 100644 --- a/cube/tensor/logic/outline.py +++ b/cube/tensor/logic/outline.py @@ -6,8 +6,8 @@ 1). restriction description on tensor segementation - 2). Translation procedure in runtime to translate such a restriction - to the real segmentation on given logical tensor shape. + 2). Translation procedure to translate such a restriction + to the real segmentation on given logical tensor. """ from cube.tensor.segment import Segment @@ -16,18 +16,17 @@ import z3 -# interface to setup restrictions on the segmentation - class BaseOutline: """ Basic class for declare outline To setup an attribute (requirement), use `inst_baseoutline.attribute_name = val` """ - def __init__(self, solver, shape): - super().__init__() + def __init__(self, solver, tensor): + if not isinstance(solver, z3.z3.Solver): + raise TypeError("Expected solver to be an z3.z3.Solver") self.solver = solver - self.shape = shape + self.shape = tensor.shape self.attributes = list() def get_attributes(self): @@ -88,28 +87,38 @@ def remove_config(self, config): raise TypeError("Expected config from z3 model()") self.solver.add(z3.Or([z3.Not(attr == config[attr]) for attr in self.attributes])) - def interpret(self, logical_tensor, config): + def interpret(self, tensor, config): + """ + Interpret to a list of segment based on the logical tensor and config + + Args: + tensor (LogicalTensor) + config (z3.z3.ModelRef) + + Returns: + list[Segment] + """ raise NotImplementedError class Full(BaseOutline): - def __init__(self, solver, shape): - super().__init__(solver, shape) + def __init__(self, solver, tensor): + super().__init__(solver, tensor) - def interpret(self, logical_tensor, config): + def interpret(self, tensor, config): if not isinstance(config, z3.z3.ModelRef): raise TypeError("Expected config from z3 model()") indices = TileIndices([0] * len(self.shape), self.shape) - segment = logical_tensor.select(indices, None, self.shape) + segment = tensor.select(indices, None, self.shape) return [segment] class SplitAxis(BaseOutline): - def __init__(self, solver, shape, axis, chunk_num, overlap): + def __init__(self, solver, tensor, axis, chunk_num, overlap): """ - Split the logical tensor spatially in `axis` dimension + Split the logical tensor uniformly in `axis` dimension TODO: support split axis with non-uniform chunk size @@ -119,15 +128,13 @@ def __init__(self, solver, shape, axis, chunk_num, overlap): which axis to split chunk_num: options (iterable int) / None / int: how many segments to produce - uniform: Boolean - whether restrict to uniform split overlap: options (iterable int) / int: overlap size on the boundary """ if not isinstance(axis, int): raise RuntimeError("Expected axis to be an integer") + super().__init__(solver, tensor) - super().__init__(solver, shape) self.axis = axis self.add_field(overlap=overlap) @@ -145,18 +152,8 @@ def __init__(self, solver, shape, axis, chunk_num, overlap): self.chunk_num * self.chunk_size - self.overlap * (self.chunk_num - 1) == total_size ) - def interpret(self, logical_tensor, config): - """ - Get segments from config - - Args: - logical_tensor (LogicalTensor): - the logical tensor - config: - Config searched by model output - - """ - if tuple(logical_tensor.shape) != tuple(self.shape): + def interpret(self, tensor, config): + if tuple(tensor.shape) != tuple(self.shape): raise RuntimeError("The logical tensor's shape doesn't match") if not isinstance(config, z3.z3.ModelRef): raise TypeError("Expected config from z3 model()") @@ -168,7 +165,7 @@ def interpret(self, logical_tensor, config): segments = list() for cid in range(chunk_num): indices = TileIndices(anchor, shape) - segment = logical_tensor.select(indices, None, shape) + segment = tensor.select(indices, None, shape) segments.append(segment) anchor[self.axis] += shape[self.axis] return segments @@ -176,7 +173,7 @@ def interpret(self, logical_tensor, config): class SplitValue(BaseOutline): - def __init__(self, solver, shape, chunk_num, val_op): + def __init__(self, solver, tensor, chunk_num, val_op): """ Split the whole tensor in value dimension. @@ -184,20 +181,19 @@ def __init__(self, solver, shape, chunk_num, val_op): Each segment value will be modified by `val_op`. """ - super().__init__(solver, shape) + super().__init__(solver, tensor) self.add_field(chunk_num=chunk_num) self.add_constraint(self.chunk_num >= 1) self.val_op = val_op - def interpret(self, logical_tensor, config): - if tuple(logical_tensor.shape) != tuple(self.shape): + def interpret(self, tensor, config): + if tuple(tensor.shape) != tuple(self.shape): raise RuntimeError("The logical tensor's shape doesn't match") chunk_num = config[self.chunk_num].as_long() segments = list() for cid in range(chunk_num): - # full tensor shape indices = TileIndices([0] * len(self.shape), self.shape) - segment = logical_tensor.select(indices, self.val_op, self.shape) + segment = tensor.select(indices, self.val_op, self.shape) segments.append(segment) for segment in segments: segment.val_op_segs.append(segments) diff --git a/tests/operator/test_holistic_linear.py b/tests/operator/test_holistic_linear.py index 885b9564..dfe0ae6d 100644 --- a/tests/operator/test_holistic_linear.py +++ b/tests/operator/test_holistic_linear.py @@ -57,12 +57,13 @@ def test_holistic_linear_op_column_weight(): input = LogicalTensor(shape=shapes[0]) weight = LogicalTensor(shape=shapes[1]) bias = LogicalTensor(shape=shapes[2]) + outputs = [LogicalTensor(shapes[3])] # ================================ Policy =========================== - holistic_op = LinearColumnWeight(shapes) + holistic_op = LinearColumnWeight(outputs, input, weight, bias) - def policy(holist_op): + def policy(holist_op, input, weight, bias): solver = holist_op.solver attributes = holist_op.attributes input_layout = holist_op.input_layouts[0] @@ -133,12 +134,13 @@ def test_holistic_linear_op_column_input_row_weight(): input = LogicalTensor(shape=shapes[0]) weight = LogicalTensor(shape=shapes[1]) bias = LogicalTensor(shape=shapes[2]) + outputs = [LogicalTensor(shapes[3])] # ================================ Policy =========================== - holistic_op = LinearColumnInputRowWeight(shapes) + holistic_op = LinearColumnInputRowWeight(outputs, input, weight, bias) - def policy(holist_op): + def policy(holist_op, input, weight, bias): solver = holist_op.solver attributes = holist_op.attributes input_layout = holist_op.input_layouts[0] From ab4647edacbb517c91c72bb8cdddf2f64a061793 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 3 Aug 2021 14:14:21 +0000 Subject: [PATCH 0110/1892] add linear example --- cube/operator/logic/generics.py | 23 ++++++----------------- examples/linear.py | 31 ++++++++++++++++++++++--------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/cube/operator/logic/generics.py b/cube/operator/logic/generics.py index 3bfb2773..81915ce3 100644 --- a/cube/operator/logic/generics.py +++ b/cube/operator/logic/generics.py @@ -33,7 +33,7 @@ def register(self, holistic_op): """ self.holist_ops.append(holistic_op) - def get_op(self, idx, *args, **kwargs): + def get_op(self, idx, outputs, *args, **kwargs): """ Get holistic operator based on idx @@ -46,7 +46,7 @@ def get_op(self, idx, *args, **kwargs): Returns: HolisticOp instance """ - return self.holist_ops[idx](*args, **kwargs) + return self.holist_ops[idx](outputs, *args, **kwargs) class GenericLogicalOp: @@ -71,27 +71,16 @@ def shape_infer(self, *args, **kwargs): """ raise NotImplementedError("Expected a shape infer engine") - def get_shapes(self, *args, **kwargs): - # get shapes of input and output - shapes = list() - for arg in args: - if isinstance(arg, LogicalTensor): - shapes.append(arg.shape) - else: - shapes.append(None) - shapes += self.shape_infer(*args, **kwargs) - return shapes - def get_op(self, *args, **kwargs): # get shapes of input and output - shapes = self.get_shapes(*args, **kwargs) - print(shapes) + shapes = self.shape_infer(*args, **kwargs) + outputs = [LogicalTensor(shape=shape, init_data=False) for shape in shapes] # use default policy if self.policy_fn is None: - composite_op = self._default_policy_fn[0](self.factory, shapes) + composite_op = self._default_policy_fn[0](self.factory, outputs, *args, **kwargs) # use user-customized policy else: - composite_op = self.policy_fn[0](self.factory, shapes) + composite_op = self.policy_fn[0](self.factory, outputs, *args, **kwargs) return composite_op def __call__(self, *args, **kwargs): diff --git a/examples/linear.py b/examples/linear.py index 99f6a36d..360dcfdc 100644 --- a/examples/linear.py +++ b/examples/linear.py @@ -24,20 +24,30 @@ # Expert Policy -def select_policy(holistic_ops, shapes): - return holistic_ops.get_op(0, shapes) - - -def segment_policy(holist_op): +def select_policy(holistic_ops, outputs, *args, **kwargs): + """ + Args: + Candidates: holistic_ops + *args, **kwargs: op input + """ + return holistic_ops.get_op(0, outputs, *args, **kwargs) + + +def segment_policy(holist_op, input, weight, bias): + """ + Args: + holistic_op (HolisticOp) + *args, **kwargs: op input + """ solver = holist_op.solver attributes = holist_op.attributes input_layout = holist_op.input_layouts[0] weight_layout = holist_op.input_layouts[1] bias_layout = holist_op.input_layouts[2] output_layout = holist_op.output_layouts[0] + # add restrictions based on device num - device_num = torch.cuda.device_count() - solver.add(weight_layout.chunk_num == 4) + holist_op.add_constraint(weight_layout.chunk_num == 4) # iterate all configs configs = list() @@ -75,7 +85,7 @@ class SingleLinear(nn.Module): def __init__(self, dim, mult): super().__init__() - self.net = cube.nn.Linear(dim, dim * mult) + self.net = nn.Linear(dim, dim * mult) def forward(self, x): output = self.net(x) @@ -91,10 +101,13 @@ def forward(self, x): args = parser.parse_args() # init distributed env - group = DeviceGroup() + rank = DeviceGroup().rank model = SingleLinear(args.dim, args.mult) inputs = LogicalTensor((args.bs, args.dim)) output = model(inputs) + + assert isinstance(output, LogicalTensor) + assert torch.is_tensor(output.get_physical_tensor(rank)) print('Done.') From 71d76c5c482cad81c3da9d014d557bb58957fd8e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 3 Aug 2021 14:37:13 +0000 Subject: [PATCH 0111/1892] wrap config choices --- cube/__init__.py | 1 + cube/config/__init__.py | 1 + cube/config/utils.py | 25 +++++++++++++++++++++++++ examples/linear.py | 17 +++++------------ 4 files changed, 32 insertions(+), 12 deletions(-) create mode 100644 cube/config/utils.py diff --git a/cube/__init__.py b/cube/__init__.py index 39c29770..996b1f04 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,3 +1,4 @@ from cube import operator from cube import nn from cube import device +from cube import config \ No newline at end of file diff --git a/cube/config/__init__.py b/cube/config/__init__.py index e69de29b..9bd37031 100644 --- a/cube/config/__init__.py +++ b/cube/config/__init__.py @@ -0,0 +1 @@ +from cube.config.utils import choices \ No newline at end of file diff --git a/cube/config/utils.py b/cube/config/utils.py new file mode 100644 index 00000000..610f6383 --- /dev/null +++ b/cube/config/utils.py @@ -0,0 +1,25 @@ +import copy +import z3 + +def choices(solver, attributes): + """ + Iterate each the config space + + Args: + solver (z3.z3.Solver) + attributes (list[z3.z3.xx]) + + Yield: + config (z3.z3.ModelRef) + """ + if not isinstance(solver, z3.z3.Solver): + raise TypeError("Expected solver to be an z3 solver") + solver = copy.deepcopy(solver) + while solver.check() == z3.sat: + config = solver.model() + solver.add( + z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) + ) + yield config + if len(attributes) == 0: + break diff --git a/examples/linear.py b/examples/linear.py index 360dcfdc..f8118cad 100644 --- a/examples/linear.py +++ b/examples/linear.py @@ -41,30 +41,23 @@ def segment_policy(holist_op, input, weight, bias): """ solver = holist_op.solver attributes = holist_op.attributes - input_layout = holist_op.input_layouts[0] weight_layout = holist_op.input_layouts[1] - bias_layout = holist_op.input_layouts[2] - output_layout = holist_op.output_layouts[0] # add restrictions based on device num holist_op.add_constraint(weight_layout.chunk_num == 4) # iterate all configs configs = list() - while solver.check() == z3.sat: - config = solver.model() + for config in cube.config.choices(solver, attributes): if DeviceGroup().rank == 0: - print('find config: {}'.format(config)) + print('find config: \n', config) configs.append(config) - solver.add( - z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) - ) - if len(attributes) == 0: - break - # choose one config -- suppose to the first + + # choose one config -- policy decision config = configs[0] if DeviceGroup().rank == 0: print('selected config: {}'.format(config)) + # deploy decisions chunk_num = config[weight_layout.chunk_num].as_long() input_ranks = [list(range(0, chunk_num)),] From 66879a4fc9b6447d8fd4db135661310445d2c069 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 3 Aug 2021 14:39:43 +0000 Subject: [PATCH 0112/1892] util detect --- scripts/keep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/keep.py b/scripts/keep.py index 5fd58d03..3054c7f2 100644 --- a/scripts/keep.py +++ b/scripts/keep.py @@ -46,7 +46,7 @@ def keep(rank, args): time.sleep(args.interval) while True: util = get_gpu_util(rank) - if util > 0: + if util <= 10: break print('rank {}: find gpu busy, keep sleeping...'.format(rank)) time.sleep(args.interval) From 8f1e7f371a6eb2885c2adad28507ca7ec2cc2ad7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 4 Aug 2021 10:20:43 +0000 Subject: [PATCH 0113/1892] init pipeline --- examples/case_study/pipeline_linear.py | 199 +++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 examples/case_study/pipeline_linear.py diff --git a/examples/case_study/pipeline_linear.py b/examples/case_study/pipeline_linear.py new file mode 100644 index 00000000..04c3991b --- /dev/null +++ b/examples/case_study/pipeline_linear.py @@ -0,0 +1,199 @@ +import torch +from torch import nn + + +class Linears(nn.Module): + """ + Note in model creation, it will only construct model chunks + that belong to this rank + """ + + def __init__(self, features, layers=4): + super().__init__() + self.ops = nn.ModuleList([]) + + myrank = torch.distributed.get_rank() + ngpus = torch.distributed.get_world_size() + op_per_rank = int(layers / ngpus) + + for _ in range(op_per_rank): + self.ops.append(nn.Linear(features, features)) + + def forward(self, x): + out = x + for op in self.ops: + out = op(out) + return out + + +def is_last_stage(): + return torch.distributed.get_rank() == 0 + + +def is_last_stage(): + return torch.distributed.get_rank() == torch.distributed.get_world_size() - 1 + + +#================= WhatToDO functions ==================# + +def forward_step(model, input_tensor): + output_tensor = model(input_tensor) + # last stage: calcuate loss + if is_last_stage(): + output_tensor = torch.sum(output_tensor) + return output_tensor + + +def backward_step(input_tensor, output_tensor, output_tensor_grad): + """ + Calculate input tensor gradient + """ + if input_tensor is not None: + input_tensor.retain_grad() + torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) + input_tensor_grad = None + if input_tensor is not None: + input_tensor_grad = input_tensor.grad + return input_tensor_grad + +#================= WhatToDO functions ==================# + +#================= Between Stage functions ==================# + +def send(tensor, to_rank): + """ + send tensor to the target rank + """ + if to_rank < 0 or to_rank >= torch.distributed.get_world_size(): + return None + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, to_rank + ) + reqs = torch.distributed.batch_isend_irecv([send_op]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + + +def recv(shape, from_rank, inputs_first_stage): + if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): + return None + tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device() + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, from_rank + ) + reqs = torch.distributed.batch_isend_irecv([recv_op]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + return tensor + +def send_and_recv(send_tensor, to_rank, recv_shape, from_rank, inputs_first_stage): + if to_rank > torch.distributed.get_world_size() or from_rank < 0: + return None + recv_tensor = torch.empty( + recv_shape, requires_grad=True, device=torch.cuda.current_device() + ) + send_op = torch.distributed.P2POp( + torch.distributed.isend, send_tensor, to_rank + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, from_rank + ) + reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + return recv_tensor + + + +#================= Between Stage functions ==================# + + + +def scheduling_1f1b(model, inputs, bs, feats, micro_bs): + myrank = torch.distributed.get_rank() + + num_microbatches = bs / micro_bs + num_warmup_microbatches = \ + (torch.distributed.get_world_size() - + torch.distributed.get_rank() - 1) + num_warmup_remaining = num_microbatches - num_warmup_microbatches + + input_tensors = list() + output_tensors = list() + + if inputs is not None: + inputs = torch.chunk(input_tensor, chunks=num_microbatches, dim=0) + + # warmup forward pass + for i in range(num_warmup_microbatches): + # recv forward + input_tensor = recv(torch.Size([bs, feats]), myrank-1, inputs) + # forward + output_tensor = forward_step(model, input_tensor) + # send forward + send(output_tensor, myrank+1) + + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + + # before running 1F1B, need to recieve first forward tensor + if num_warmup_remaining > 0: + # recv forward + input_tensor = recv(torch.Size([bs, feats]), myrank-1, inputs) + if input_tensor is None: + input_tensor = inputs[i+num_warmup_microbatches] + + # run 1F1B + for i in range(num_warmup_remaining): + # forward + output_tensor = forward_step(model, input_tensor) + # send forward + recv backward grads + output_tensor_grad = send_and_recv_backward( + output_tensor, myrank+1, torch.Size([bs, feats]), myrank+1) + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + # backward + input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) + input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad) + if i != (num_warmup_remaining-1): + # send backward grads + recv forward results + input_tensor = send_and_recv( + input_tensor_grad, myrank-1, torch.Size([bs, feats]), myrank-1) + else: # last iteration - no more inputs + input_tensor = None + # send backward grads + send(input_tensor_grad, myrank - 1) + + # cooldown + for i in range(num_warmup_microbatches): + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + # recv backward gradients + output_tensor_grad = recv(torch.Size([bs, feats]), myrank+1) + # backward + input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad) + # send backward gradients + send(input_tensor_grad, myrank-1) + + +if __name__ == '__main__': + + # initialize distributed env + local_rank = int(os.environ.get('LOCAL_RANK')) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + ) + + batch_size = 32 + features = 1024 + + torch.randn((batch_size, features)) + + From cd4f49776b8c183b00557147a3144f6363785ab1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 5 Aug 2021 02:39:34 +0000 Subject: [PATCH 0114/1892] 1f1b pipeline finish --- examples/case_study/pipeline_linear.py | 75 ++++++++++++++++++-------- 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/examples/case_study/pipeline_linear.py b/examples/case_study/pipeline_linear.py index 04c3991b..a16370c0 100644 --- a/examples/case_study/pipeline_linear.py +++ b/examples/case_study/pipeline_linear.py @@ -1,5 +1,18 @@ +"""Example Usage + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + examples/case_study/pipeline_linear.py +""" + import torch from torch import nn +import os class Linears(nn.Module): @@ -26,7 +39,7 @@ def forward(self, x): return out -def is_last_stage(): +def is_first_stage(): return torch.distributed.get_rank() == 0 @@ -48,11 +61,11 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad): """ Calculate input tensor gradient """ - if input_tensor is not None: + if input_tensor is not None and input_tensor.requires_grad: input_tensor.retain_grad() torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) input_tensor_grad = None - if input_tensor is not None: + if input_tensor is not None and input_tensor.requires_grad: input_tensor_grad = input_tensor.grad return input_tensor_grad @@ -75,9 +88,9 @@ def send(tensor, to_rank): torch.cuda.synchronize() -def recv(shape, from_rank, inputs_first_stage): +def recv(shape, from_rank, boundary_tensor): if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): - return None + return boundary_tensor tensor = torch.empty( shape, requires_grad=True, device=torch.cuda.current_device() ) @@ -90,17 +103,17 @@ def recv(shape, from_rank, inputs_first_stage): torch.cuda.synchronize() return tensor -def send_and_recv(send_tensor, to_rank, recv_shape, from_rank, inputs_first_stage): - if to_rank > torch.distributed.get_world_size() or from_rank < 0: - return None +def send_and_recv(send_tensor, recv_shape, rank, boundary_tensor): + if rank < 0 or rank >= torch.distributed.get_world_size(): + return boundary_tensor recv_tensor = torch.empty( recv_shape, requires_grad=True, device=torch.cuda.current_device() ) send_op = torch.distributed.P2POp( - torch.distributed.isend, send_tensor, to_rank + torch.distributed.isend, send_tensor, rank ) recv_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_tensor, from_rank + torch.distributed.irecv, recv_tensor, rank ) reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) for req in reqs: @@ -117,7 +130,7 @@ def send_and_recv(send_tensor, to_rank, recv_shape, from_rank, inputs_first_stag def scheduling_1f1b(model, inputs, bs, feats, micro_bs): myrank = torch.distributed.get_rank() - num_microbatches = bs / micro_bs + num_microbatches = int(bs / micro_bs) num_warmup_microbatches = \ (torch.distributed.get_world_size() - torch.distributed.get_rank() - 1) @@ -127,15 +140,19 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): output_tensors = list() if inputs is not None: - inputs = torch.chunk(input_tensor, chunks=num_microbatches, dim=0) + inputs = torch.chunk(inputs, chunks=num_microbatches, dim=0) + else: + inputs = [None] * num_microbatches # warmup forward pass for i in range(num_warmup_microbatches): # recv forward - input_tensor = recv(torch.Size([bs, feats]), myrank-1, inputs) + print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) + input_tensor = recv(torch.Size([micro_bs, feats]), myrank-1, inputs[i]) # forward output_tensor = forward_step(model, input_tensor) # send forward + print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) send(output_tensor, myrank+1) input_tensors.append(input_tensor) @@ -144,17 +161,19 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): # before running 1F1B, need to recieve first forward tensor if num_warmup_remaining > 0: # recv forward - input_tensor = recv(torch.Size([bs, feats]), myrank-1, inputs) - if input_tensor is None: - input_tensor = inputs[i+num_warmup_microbatches] + print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) + input_tensor = recv(torch.Size([micro_bs, feats]), myrank-1, inputs[num_warmup_microbatches]) # run 1F1B for i in range(num_warmup_remaining): # forward + if input_tensor is None: + print('[1f1b] rank {}: Unexpected None at step {}'.format(myrank, i)) output_tensor = forward_step(model, input_tensor) # send forward + recv backward grads - output_tensor_grad = send_and_recv_backward( - output_tensor, myrank+1, torch.Size([bs, feats]), myrank+1) + print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) + output_tensor_grad = send_and_recv( + output_tensor, torch.Size([micro_bs, feats]), myrank+1, None) input_tensors.append(input_tensor) output_tensors.append(output_tensor) # backward @@ -162,19 +181,21 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad) if i != (num_warmup_remaining-1): # send backward grads + recv forward results + print('[1f1b] rank {}: step-{}: sending backward + recving forward...'.format(myrank, i)) input_tensor = send_and_recv( - input_tensor_grad, myrank-1, torch.Size([bs, feats]), myrank-1) + input_tensor_grad, torch.Size([micro_bs, feats]), myrank-1, inputs[num_warmup_microbatches+i+1]) else: # last iteration - no more inputs input_tensor = None # send backward grads - send(input_tensor_grad, myrank - 1) + print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) + send(input_tensor_grad, myrank-1) # cooldown for i in range(num_warmup_microbatches): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) # recv backward gradients - output_tensor_grad = recv(torch.Size([bs, feats]), myrank+1) + output_tensor_grad = recv(torch.Size([micro_bs, feats]), myrank+1, None) # backward input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad) # send backward gradients @@ -190,10 +211,18 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): backend='nccl', init_method='env://', ) + myrank = torch.distributed.get_rank() - batch_size = 32 + bs = 32 + micro_bs = 1 features = 1024 - torch.randn((batch_size, features)) + model = Linears(features, layers=4).cuda() + if myrank == 0: + inputs = torch.randn((bs, features)).cuda() + else: + inputs = None + for _ in range(50): + scheduling_1f1b(model, inputs, bs, features, micro_bs) From b5fdf671f4d330fbc89451c267a43cd151e6f885 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 5 Aug 2021 02:43:45 +0000 Subject: [PATCH 0115/1892] add status info --- examples/case_study/pipeline_linear.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/case_study/pipeline_linear.py b/examples/case_study/pipeline_linear.py index a16370c0..8a0dceac 100644 --- a/examples/case_study/pipeline_linear.py +++ b/examples/case_study/pipeline_linear.py @@ -199,6 +199,7 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): # backward input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad) # send backward gradients + print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) send(input_tensor_grad, myrank-1) From d76a5fdf7fb35659ffddef3d63d661fe918791ac Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 5 Aug 2021 02:54:53 +0000 Subject: [PATCH 0116/1892] separate offload and recompute --- .../{memory_linear.py => offload_linear.py} | 63 ----------------- examples/case_study/recompute_linear.py | 67 +++++++++++++++++++ 2 files changed, 67 insertions(+), 63 deletions(-) rename examples/case_study/{memory_linear.py => offload_linear.py} (67%) create mode 100644 examples/case_study/recompute_linear.py diff --git a/examples/case_study/memory_linear.py b/examples/case_study/offload_linear.py similarity index 67% rename from examples/case_study/memory_linear.py rename to examples/case_study/offload_linear.py index f4d14989..7df14b87 100644 --- a/examples/case_study/memory_linear.py +++ b/examples/case_study/offload_linear.py @@ -5,69 +5,6 @@ tensor_map = dict() -### Checkpoint PyTorch Implementation (Skip un-deterministic scenario) ### -# Note this implementation can only work with a module that consists -# multiple operators. This will won't work for one OP because the output -# for this module will be saved in next op -def checkpoint_module_linear(input, weight, bias): - - class Checkpoint(torch.autograd.Function): - """General class to wrapper op to enable checkpoint""" - @staticmethod - def forward(ctx, run_function, *args): - ctx.run_function = run_function - ctx.tensor_indices = [] - tensor_inputs = [] - for i, arg in enumerate(args): - if torch.is_tensor(arg): - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - ctx.save_for_backward(*tensor_inputs) - - with torch.no_grad(): - outputs = run_function(*args) - return outputs - @staticmethod - def backward(ctx, *args): - # retrieve what need to regenerate tensors - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors - # re-generate - for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i] - # detach inputs - detached_inputs = list() - for input in inputs: - if torch.is_tensor(input): - x = input.detach() - x.requires_grad = input.requires_grad - else: - x = input - detached_inputs.append(x) - detached_inputs = tuple(detached_inputs) - # generate output tensor - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - if torch.is_tensor(outputs): - outputs = (outputs,) - # run backward to tensors that require a grad - outputs_with_grad = list() - args_with_grad = list() - if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: - outputs_with_grad.append(outputs[i]) - args_with_grad.append(args[i]) - torch.autograd.backward(outputs_with_grad, args_with_grad) - grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None - for inp in detached_inputs) - return (None, None) + grads - - output = Checkpoint.apply(torch._C._nn.linear, input, weight, bias) - return output - def swap_weight_grad_linear(input, weight, bias): diff --git a/examples/case_study/recompute_linear.py b/examples/case_study/recompute_linear.py new file mode 100644 index 00000000..f46fae58 --- /dev/null +++ b/examples/case_study/recompute_linear.py @@ -0,0 +1,67 @@ +import torch +import os + +torch.manual_seed(121) + +### Checkpoint PyTorch Implementation (Skip un-deterministic scenario) ### +# Note this implementation can only work with a module that consists +# multiple operators. This will won't work for one OP because the output +# for this module will be saved in next op +def checkpoint_module_linear(input, weight, bias): + + class Checkpoint(torch.autograd.Function): + """General class to wrapper op to enable checkpoint""" + @staticmethod + def forward(ctx, run_function, *args): + ctx.run_function = run_function + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + ctx.save_for_backward(*tensor_inputs) + + with torch.no_grad(): + outputs = run_function(*args) + return outputs + @staticmethod + def backward(ctx, *args): + # retrieve what need to regenerate tensors + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + # re-generate + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + # detach inputs + detached_inputs = list() + for input in inputs: + if torch.is_tensor(input): + x = input.detach() + x.requires_grad = input.requires_grad + else: + x = input + detached_inputs.append(x) + detached_inputs = tuple(detached_inputs) + # generate output tensor + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + if torch.is_tensor(outputs): + outputs = (outputs,) + # run backward to tensors that require a grad + outputs_with_grad = list() + args_with_grad = list() + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs) + return (None, None) + grads + + output = Checkpoint.apply(torch._C._nn.linear, input, weight, bias) + return output \ No newline at end of file From e1837909d64e1328021fdb6bccbabe374c695c13 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 5 Aug 2021 03:10:01 +0000 Subject: [PATCH 0117/1892] gradient accumulation example --- .../case_study/grad_accumulation_linear.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 examples/case_study/grad_accumulation_linear.py diff --git a/examples/case_study/grad_accumulation_linear.py b/examples/case_study/grad_accumulation_linear.py new file mode 100644 index 00000000..b51b4104 --- /dev/null +++ b/examples/case_study/grad_accumulation_linear.py @@ -0,0 +1,33 @@ +import torch + + +if __name__ == '__main__': + + global_bs = 128 + bs = 32 + feats = 1024 + + inputs = [torch.randn((bs, feats)).cuda() for _ in range(16)] + weight = torch.randn((feats, feats)).cuda().requires_grad_() + bias = torch.randn((feats,)).cuda().requires_grad_() + + update_interval = int(global_bs / bs) + tic = 0 + for input_data in inputs: + tic += 1 + # forward + print('forward') + out = torch._C._nn.linear(input_data, weight, bias) + loss = torch.sum(out) + # backward - calculate grad: + # note pytorch in default accumulates gradients + print('backward') + loss.backward() + + # weight update + if tic % update_interval == 0: + print('weight update') + weight.data += weight.grad + weight.grad = None + bias.data += bias.grad + bias.grad = None From 538c9185ace76847d1b2782f141d80a3c9502e0b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 5 Aug 2021 03:12:07 +0000 Subject: [PATCH 0118/1892] add ultimate code --- examples/case_study/{ => ultimate}/grad_accumulation_linear.py | 0 examples/case_study/{ => ultimate}/offload_linear.py | 0 examples/case_study/{ => ultimate}/parallel_linear.py | 0 examples/case_study/{ => ultimate}/pipeline_linear.py | 0 examples/case_study/{ => ultimate}/recompute_linear.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename examples/case_study/{ => ultimate}/grad_accumulation_linear.py (100%) rename examples/case_study/{ => ultimate}/offload_linear.py (100%) rename examples/case_study/{ => ultimate}/parallel_linear.py (100%) rename examples/case_study/{ => ultimate}/pipeline_linear.py (100%) rename examples/case_study/{ => ultimate}/recompute_linear.py (100%) diff --git a/examples/case_study/grad_accumulation_linear.py b/examples/case_study/ultimate/grad_accumulation_linear.py similarity index 100% rename from examples/case_study/grad_accumulation_linear.py rename to examples/case_study/ultimate/grad_accumulation_linear.py diff --git a/examples/case_study/offload_linear.py b/examples/case_study/ultimate/offload_linear.py similarity index 100% rename from examples/case_study/offload_linear.py rename to examples/case_study/ultimate/offload_linear.py diff --git a/examples/case_study/parallel_linear.py b/examples/case_study/ultimate/parallel_linear.py similarity index 100% rename from examples/case_study/parallel_linear.py rename to examples/case_study/ultimate/parallel_linear.py diff --git a/examples/case_study/pipeline_linear.py b/examples/case_study/ultimate/pipeline_linear.py similarity index 100% rename from examples/case_study/pipeline_linear.py rename to examples/case_study/ultimate/pipeline_linear.py diff --git a/examples/case_study/recompute_linear.py b/examples/case_study/ultimate/recompute_linear.py similarity index 100% rename from examples/case_study/recompute_linear.py rename to examples/case_study/ultimate/recompute_linear.py From 88634ebb4359549b48c9615472cdc57f26c8df3e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 5 Aug 2021 03:24:18 +0000 Subject: [PATCH 0119/1892] logical code --- .../case_study/{ => logic}/naive_linear.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) rename examples/case_study/{ => logic}/naive_linear.py (54%) diff --git a/examples/case_study/naive_linear.py b/examples/case_study/logic/naive_linear.py similarity index 54% rename from examples/case_study/naive_linear.py rename to examples/case_study/logic/naive_linear.py index c370abcd..7304f477 100644 --- a/examples/case_study/naive_linear.py +++ b/examples/case_study/logic/naive_linear.py @@ -13,21 +13,27 @@ def linear(input, weight, bias=None): torch.cuda.set_device(0) # tensor definition - batch_size = 32 - out_features = 10240 - in_features = 10240 + batch_size = 128 + out_features = 1024 + in_features = 1024 weight = torch.rand((out_features, in_features)).cuda().requires_grad_() - # print('weight: ', weight) bias = torch.rand(out_features).cuda().requires_grad_() - # print('bias: ', bias) input = torch.rand((batch_size, in_features)).cuda() + # print('weight: ', weight) + # print('bias: ', bias) # print('input: ', input) - # op compute - print('======== Naive Single Device =======') - output = linear(input, weight, bias) - loss = torch.mean(output) * 100 - print(loss) - loss.backward() - print('weight grad: ', weight.grad.t()) - print('======== Naive Single Device =======') + # iterations + for _ in range(4): + # forward + output = linear(input, weight, bias) + loss = torch.mean(output) + print(loss) + # backward + loss.backward() + # print('weight grad: ', weight.grad.t()) + # weight update + weight.data += weight.grad + weight.grad = None + bias.data += bias.grad + bias.grad = None From d2618bb51f0e4d5f87733d8d90b0c16f6514ba73 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 5 Aug 2021 09:25:58 +0000 Subject: [PATCH 0120/1892] add zero-redundancy code of paramter / gradient partitioning --- examples/case_study/ultimate/zero_linear.py | 167 ++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 examples/case_study/ultimate/zero_linear.py diff --git a/examples/case_study/ultimate/zero_linear.py b/examples/case_study/ultimate/zero_linear.py new file mode 100644 index 00000000..4198071a --- /dev/null +++ b/examples/case_study/ultimate/zero_linear.py @@ -0,0 +1,167 @@ +""" +Zero Redundancy Implementation + +Partition Weights / Gradients / Optimizer States across GPUs + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + examples/case_study/ultimate/zero_linear.py +""" +import torch +import os +torch.manual_seed(121) + +tensor_map = dict() + +def linear_zero(input, weight, bias): + ### weight / bias is partitioned ### + class ZeroLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + + weight_id = id(weight) + bias_id = id(bias) + ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id)) + tensor_map[weight_id] = weight + tensor_map[bias_id] = bias + + # ======= all-gather parameters ========= # + device_num = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + # all-gather weight + weight_list = [torch.empty_like(weight) for _ in range(device_num)] + weight_list[rank] = weight + torch.distributed.all_gather(weight_list, weight) + weight_full = torch.cat(weight_list, dim=0).contiguous() + # all-gather bias + bias_list = [torch.empty_like(bias) for _ in range(device_num)] + bias_list[rank] = bias + torch.distributed.all_gather(bias_list, bias) + bias_full = torch.cat(bias_list, dim=0).contiguous() + # ======= all-gather parameters ========= # + + # compute: -> use full weight / bias + output = torch._C._nn.linear(input, weight_full, bias_full) + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight_id, bias_id = ctx.saved_tensors + weight = tensor_map[weight_id.item()] + bias = tensor_map[bias_id.item()] + + grad_input = grad_weight = grad_bas = None + if ctx.needs_input_grad[0]: + # ========== all-gather weight =========== # + weight_list = [torch.empty_like(weight) for _ in range(device_num)] + weight_list[rank] = weight + torch.distributed.all_gather(weight_list, weight) + weight_full = torch.cat(weight_list, dim=0).contiguous() + # ========== all-gather weight =========== # + + grad_input = grad_output.matmul(weight_full) + + if ctx.needs_input_grad[1]: + dim = grad_output.dim() + if dim > 2: + grad_weight_full = grad\ + .view(-1, grad_output.shape[-1])\ + .t()\ + .matmul(input.view(-1, input.shape[-1])) + else: + grad_weight_full = grad_output.t().matmul(input) + if ctx.needs_input_grad[2]: + grad_bias_full = grad_output.sum(0) + + ## ========== reduce-scatter for data parallelism ========= ## + device_num = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + grad_weight_list = list(torch.chunk(grad_weight_full, chunks=device_num, dim=0)) + grad_weight = torch.empty_like(grad_weight_list[rank]) + torch.distributed.reduce_scatter(grad_weight, grad_weight_list) + grad_bias_list = list(torch.chunk(grad_bias_full, chunks=device_num, dim=0)) + grad_bias = torch.empty_like(grad_bias_list[rank]) + torch.distributed.reduce_scatter(grad_bias, grad_bias_list) + ## ========== reduce-scatter for data parallelism ========= ## + + return grad_input, grad_weight, grad_bias + + output = ZeroLinear.apply(input, weight, bias) + return output + + +######### Utility ############# +def print_each_rank(msg, selected_rank=None): + myrank = torch.distributed.get_rank() + for rank in range(torch.distributed.get_world_size()): + if selected_rank is None or myrank in selected_rank: + if myrank == rank: + print('rank [{}]: {}\n'.format(rank, msg)) + torch.distributed.barrier() + + +if __name__ == '__main__': + + local_rank = int(os.environ.get('LOCAL_RANK')) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + ) + devices = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # tensor definition + batch_size = 32 + out_features = 10240 + in_features = 10240 ## 100 MB weight + + # weight + weight = torch.chunk( + torch.rand((out_features, in_features)), + chunks=devices, + dim=0 + )[rank].contiguous().cuda().requires_grad_() + + # bias + bias = torch.chunk( + torch.rand((out_features,)), + chunks=devices, + dim=0 + )[rank].contiguous().cuda().requires_grad_() + + # data + input = torch.rand((batch_size, in_features)).cuda() + + # op compute + print_each_rank('======== Zero-Redundancy =======', [0]) + + output = linear_zero(input, weight, bias) + loss = torch.mean(output) * 100 + print_each_rank('loss: {}'.format(loss)) + loss.backward() + + with torch.no_grad(): + weight.data += weight.grad + bias.data += bias.grad + + # finish_op_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 + # max_allocated = (torch.cuda.max_memory_allocated() - init_memory) / 1024 / 1024 + + # allocate tensor on gpu to see if swap workds + # after_alloc_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 + + # print('Memory Consumption (MB):\n\t input-require: {:.2f}\n\t after swap weight: {:.2f}\n\t after op run {:.2f}\n\t max allocated: {:.2f}\n\t after allocate {:.2f}'.format( + # input_memory, weight_swap_memory, finish_op_memory, max_allocated, after_alloc_memory)) + + # correctness verify + output = linear_zero(input, weight, bias) + loss = torch.mean(output) * 100 + print_each_rank('loss: {}'.format(loss)) + print_each_rank('======== Zero-Redundancy =======', [0]) From ad474a59665080c32aabfb2f4347c85a4479cc22 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 5 Aug 2021 10:41:09 +0000 Subject: [PATCH 0121/1892] clear comments --- examples/case_study/ultimate/grad_accumulation_linear.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/case_study/ultimate/grad_accumulation_linear.py b/examples/case_study/ultimate/grad_accumulation_linear.py index b51b4104..4f74474c 100644 --- a/examples/case_study/ultimate/grad_accumulation_linear.py +++ b/examples/case_study/ultimate/grad_accumulation_linear.py @@ -16,17 +16,16 @@ for input_data in inputs: tic += 1 # forward - print('forward') out = torch._C._nn.linear(input_data, weight, bias) loss = torch.sum(out) # backward - calculate grad: - # note pytorch in default accumulates gradients - print('backward') loss.backward() - + # Note: during backward, PyTorch will do tensor.grad += computed_grad + # if tensor had gradient, then do accumulation by default. + # weight update if tic % update_interval == 0: - print('weight update') + # weight update weight.data += weight.grad weight.grad = None bias.data += bias.grad From 9ae1f33795fbaa36ce9e67e73153cb52460f56c6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 Aug 2021 02:53:19 +0000 Subject: [PATCH 0122/1892] add adam optimizer --- examples/case_study/logic/naive_linear.py | 60 ++++++++++++++-- .../ultimate/grad_accumulation_linear.py | 71 ++++++++++++++++--- examples/case_study/ultimate/zero_linear.py | 52 +++++++++++++- 3 files changed, 167 insertions(+), 16 deletions(-) diff --git a/examples/case_study/logic/naive_linear.py b/examples/case_study/logic/naive_linear.py index 7304f477..ef50e8d2 100644 --- a/examples/case_study/logic/naive_linear.py +++ b/examples/case_study/logic/naive_linear.py @@ -1,5 +1,7 @@ import torch from torch.nn.parameter import Parameter +import math + torch.manual_seed(121) @@ -8,6 +10,24 @@ def linear(input, weight, bias=None): return output +def apply_adam(params, grads, exp_avgs, exp_avg_sqs, steps, beta1, beta2, lr): + for i, param in enumerate(params): + + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = steps[-1] + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8) + step_size = lr / bias_correction1 + param.addcdiv_(exp_avg, denom, value=-step_size) + + if __name__ == '__main__': torch.cuda.set_device(0) @@ -16,24 +36,54 @@ def linear(input, weight, bias=None): batch_size = 128 out_features = 1024 in_features = 1024 + weight = torch.rand((out_features, in_features)).cuda().requires_grad_() bias = torch.rand(out_features).cuda().requires_grad_() input = torch.rand((batch_size, in_features)).cuda() # print('weight: ', weight) # print('bias: ', bias) # print('input: ', input) + + ## Adam optimizer states -- 2x more weights volume + weight_exp_avg = torch.zeros_like( + weight, memory_format=torch.preserve_format + ) + weight_exp_avg_sq = torch.zeros_like( + weight, memory_format=torch.preserve_format + ) + bias_exp_avg = torch.zeros_like( + bias, memory_format=torch.preserve_format + ) + bias_exp_avg_sq = torch.zeros_like( + bias, memory_format=torch.preserve_format + ) + state_steps = list() + lr = 0.01 + beta1 = 0.5 + beta2 = 0.5 # iterations for _ in range(4): - # forward + # ======= step1: forward ======= # output = linear(input, weight, bias) loss = torch.mean(output) print(loss) - # backward + + # ======= step2: backward ======= # loss.backward() # print('weight grad: ', weight.grad.t()) - # weight update - weight.data += weight.grad + + # ======= step3: update ======= # + params = [weight, bias] + grads = [weight.grad, bias.grad] + exp_avgs = [weight_exp_avg, bias_exp_avg] + exp_avg_sqs = [weight_exp_avg_sq, bias_exp_avg_sq] + state_steps.append(len(state_steps)+1) + with torch.no_grad(): + apply_adam( + params, grads, exp_avgs, exp_avg_sqs, state_steps, + beta1, beta2, lr + ) + # zero out grad weight.grad = None - bias.data += bias.grad bias.grad = None diff --git a/examples/case_study/ultimate/grad_accumulation_linear.py b/examples/case_study/ultimate/grad_accumulation_linear.py index 4f74474c..659fcf7e 100644 --- a/examples/case_study/ultimate/grad_accumulation_linear.py +++ b/examples/case_study/ultimate/grad_accumulation_linear.py @@ -1,4 +1,26 @@ import torch +import math + +torch.manual_seed(121) + + +def apply_adam(params, grads, exp_avgs, exp_avg_sqs, steps, beta1, beta2, lr): + for i, param in enumerate(params): + + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = steps[-1] + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8) + step_size = lr / bias_correction1 + param.addcdiv_(exp_avg, denom, value=-step_size) + if __name__ == '__main__': @@ -7,26 +29,57 @@ bs = 32 feats = 1024 - inputs = [torch.randn((bs, feats)).cuda() for _ in range(16)] weight = torch.randn((feats, feats)).cuda().requires_grad_() bias = torch.randn((feats,)).cuda().requires_grad_() + ## Adam optimizer states -- 2x more weights volume + weight_exp_avg = torch.zeros_like( + weight, memory_format=torch.preserve_format + ) + weight_exp_avg_sq = torch.zeros_like( + weight, memory_format=torch.preserve_format + ) + bias_exp_avg = torch.zeros_like( + bias, memory_format=torch.preserve_format + ) + bias_exp_avg_sq = torch.zeros_like( + bias, memory_format=torch.preserve_format + ) + state_steps = list() + lr = 0.01 + beta1 = 0.5 + beta2 = 0.5 + + inputs = [torch.randn((bs, feats)).cuda() for _ in range(16)] + # inputs = [torch.randn((bs, feats)).cuda()] * 16 # for debug + update_interval = int(global_bs / bs) tic = 0 for input_data in inputs: tic += 1 - # forward + + # ======= step1: forward ======= # out = torch._C._nn.linear(input_data, weight, bias) - loss = torch.sum(out) - # backward - calculate grad: + loss = torch.mean(out) / update_interval ## loss also need scale + print('loss: {}'.format(loss)) + + # ======= step2: backward ======= # loss.backward() # Note: during backward, PyTorch will do tensor.grad += computed_grad - # if tensor had gradient, then do accumulation by default. + # if tensor had gradient, this will do accumulation by default. - # weight update + # ======= step3: update ======= # if tic % update_interval == 0: - # weight update - weight.data += weight.grad + params = [weight, bias] + grads = [weight.grad, bias.grad] + exp_avgs = [weight_exp_avg, bias_exp_avg] + exp_avg_sqs = [weight_exp_avg_sq, bias_exp_avg_sq] + state_steps.append(len(state_steps)+1) + with torch.no_grad(): + apply_adam( + params, grads, exp_avgs, exp_avg_sqs, state_steps, + beta1, beta2, lr + ) + # zero out grad weight.grad = None - bias.data += bias.grad bias.grad = None diff --git a/examples/case_study/ultimate/zero_linear.py b/examples/case_study/ultimate/zero_linear.py index 4198071a..1a73ee40 100644 --- a/examples/case_study/ultimate/zero_linear.py +++ b/examples/case_study/ultimate/zero_linear.py @@ -14,6 +14,7 @@ """ import torch import os +import math torch.manual_seed(121) tensor_map = dict() @@ -96,6 +97,24 @@ def backward(ctx, grad_output): return output +def apply_adam(params, grads, exp_avgs, exp_avg_sqs, steps, beta1, beta2, lr): + for i, param in enumerate(params): + + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = steps[-1] + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8) + step_size = lr / bias_correction1 + param.addcdiv_(exp_avg, denom, value=-step_size) + + ######### Utility ############# def print_each_rank(msg, selected_rank=None): myrank = torch.distributed.get_rank() @@ -136,6 +155,24 @@ def print_each_rank(msg, selected_rank=None): dim=0 )[rank].contiguous().cuda().requires_grad_() + ## Adam optimizer states -- Zero-DP: the states are partitioned + weight_exp_avg = torch.zeros_like( + weight, memory_format=torch.preserve_format + ) + weight_exp_avg_sq = torch.zeros_like( + weight, memory_format=torch.preserve_format + ) + bias_exp_avg = torch.zeros_like( + bias, memory_format=torch.preserve_format + ) + bias_exp_avg_sq = torch.zeros_like( + bias, memory_format=torch.preserve_format + ) + state_steps = list() + lr = 0.01 + beta1 = 0.5 + beta2 = 0.5 + # data input = torch.rand((batch_size, in_features)).cuda() @@ -147,9 +184,20 @@ def print_each_rank(msg, selected_rank=None): print_each_rank('loss: {}'.format(loss)) loss.backward() + # adam optimizer + params = [weight, bias] + grads = [weight.grad, bias.grad] + exp_avgs = [weight_exp_avg, bias_exp_avg] + exp_avg_sqs = [weight_exp_avg_sq, bias_exp_avg_sq] + state_steps.append(len(state_steps)+1) with torch.no_grad(): - weight.data += weight.grad - bias.data += bias.grad + apply_adam( + params, grads, exp_avgs, exp_avg_sqs, state_steps, + beta1, beta2, lr + ) + # zero out grad + weight.grad = None + bias.grad = None # finish_op_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 # max_allocated = (torch.cuda.max_memory_allocated() - init_memory) / 1024 / 1024 From 90ae7c8e85649d38779fe186ac97e0a763d4c5e6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 Aug 2021 03:08:44 +0000 Subject: [PATCH 0123/1892] pipeline 1f1b --- .../case_study/ultimate/pipeline_linear.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/case_study/ultimate/pipeline_linear.py b/examples/case_study/ultimate/pipeline_linear.py index 8a0dceac..8b100069 100644 --- a/examples/case_study/ultimate/pipeline_linear.py +++ b/examples/case_study/ultimate/pipeline_linear.py @@ -7,7 +7,7 @@ --master_addr=127.0.0.1 \ --master_port=62000 \ --use_env \ - examples/case_study/pipeline_linear.py + examples/case_study/ultimate/pipeline_linear.py """ import torch @@ -21,15 +21,15 @@ class Linears(nn.Module): that belong to this rank """ - def __init__(self, features, layers=4): + def __init__(self, features, op_num=4): super().__init__() self.ops = nn.ModuleList([]) myrank = torch.distributed.get_rank() ngpus = torch.distributed.get_world_size() - op_per_rank = int(layers / ngpus) + op_num_per_rank = int(op_num / ngpus) - for _ in range(op_per_rank): + for _ in range(op_num_per_rank): self.ops.append(nn.Linear(features, features)) def forward(self, x): @@ -103,6 +103,7 @@ def recv(shape, from_rank, boundary_tensor): torch.cuda.synchronize() return tensor + def send_and_recv(send_tensor, recv_shape, rank, boundary_tensor): if rank < 0 or rank >= torch.distributed.get_world_size(): return boundary_tensor @@ -121,11 +122,10 @@ def send_and_recv(send_tensor, recv_shape, rank, boundary_tensor): torch.cuda.synchronize() return recv_tensor - - #================= Between Stage functions ==================# +#================= Scheduling ==================# def scheduling_1f1b(model, inputs, bs, feats, micro_bs): myrank = torch.distributed.get_rank() @@ -167,8 +167,6 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): # run 1F1B for i in range(num_warmup_remaining): # forward - if input_tensor is None: - print('[1f1b] rank {}: Unexpected None at step {}'.format(myrank, i)) output_tensor = forward_step(model, input_tensor) # send forward + recv backward grads print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) @@ -190,7 +188,7 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) send(input_tensor_grad, myrank-1) - # cooldown + # cooldown gradient trans back for i in range(num_warmup_microbatches): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) @@ -202,6 +200,8 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) send(input_tensor_grad, myrank-1) +#================= Scheduling ==================# + if __name__ == '__main__': @@ -216,9 +216,9 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): bs = 32 micro_bs = 1 - features = 1024 + features = 10240 - model = Linears(features, layers=4).cuda() + model = Linears(features, op_num=4).cuda() if myrank == 0: inputs = torch.randn((bs, features)).cuda() From 4a91c64601b7b52cc6ce75fdf3fd35127aa3563b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 Aug 2021 03:24:20 +0000 Subject: [PATCH 0124/1892] model partition linear --- .../ultimate/model_partition_linear.py | 154 ++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 examples/case_study/ultimate/model_partition_linear.py diff --git a/examples/case_study/ultimate/model_partition_linear.py b/examples/case_study/ultimate/model_partition_linear.py new file mode 100644 index 00000000..75b262d3 --- /dev/null +++ b/examples/case_study/ultimate/model_partition_linear.py @@ -0,0 +1,154 @@ +"""Example Usage + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + examples/case_study/ultimate/model_partition_linear.py +""" + +import torch +from torch import nn +import os + +class Linears(nn.Module): + """ + Note in model creation, it will only construct model chunks + that belong to this rank + """ + + def __init__(self, features, op_num=4): + super().__init__() + self.ops = nn.ModuleList([]) + + myrank = torch.distributed.get_rank() + ngpus = torch.distributed.get_world_size() + op_num_per_rank = int(op_num / ngpus) + + for _ in range(op_num_per_rank): + self.ops.append(nn.Linear(features, features)) + + def forward(self, x): + out = x + for op in self.ops: + out = op(out) + return out + + +def is_last_stage(): + return torch.distributed.get_rank() == torch.distributed.get_world_size() - 1 + +#================= WhatToDO functions ==================# + +def forward_step(model, input_tensor): + output_tensor = model(input_tensor) + # last stage: calcuate loss + if is_last_stage(): + output_tensor = torch.sum(output_tensor) + print('loss: {}'.format(output_tensor)) + return output_tensor + + +def backward_step(input_tensor, output_tensor, output_tensor_grad): + """ + Calculate input tensor gradient + """ + if input_tensor is not None and input_tensor.requires_grad: + input_tensor.retain_grad() + torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) + input_tensor_grad = None + if input_tensor is not None and input_tensor.requires_grad: + input_tensor_grad = input_tensor.grad + return input_tensor_grad + +#================= WhatToDO functions ==================# + +#================= Between Stage functions ==================# + +def send(tensor, to_rank): + """ + send tensor to the target rank + """ + if to_rank < 0 or to_rank >= torch.distributed.get_world_size(): + return None + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, to_rank + ) + reqs = torch.distributed.batch_isend_irecv([send_op]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + + +def recv(shape, from_rank, boundary_tensor): + if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): + return boundary_tensor + tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device() + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, from_rank + ) + reqs = torch.distributed.batch_isend_irecv([recv_op]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + return tensor + +#================= Between Stage functions ==================# + + +#================= Scheduling ==================# + +def scheduling_naive(model, inputs, bs, feats): + + myrank = torch.distributed.get_rank() + + # ================ forward pass ================ # + # recv input data + input_tensor = recv(torch.Size([bs, feats]), myrank-1, inputs) + # forward + output_tensor = forward_step(model, input_tensor) + # send forward + send(output_tensor, myrank+1) + + # ================ backward pass ================ # + # recv backward + output_tensor_grad = recv(torch.Size([bs, feats]), myrank+1, None) + # backward + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad) + # send backward + send(input_tensor_grad, myrank-1) + + # ================ weight update ================ # + # xxx + +#================= Scheduling ==================# + + +if __name__ == '__main__': + + # initialize distributed env + local_rank = int(os.environ.get('LOCAL_RANK')) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group( + backend='nccl', + init_method='env://', + ) + myrank = torch.distributed.get_rank() + + bs = 32 + features = 10240 + + model = Linears(features, op_num=4).cuda() + + if myrank == 0: + inputs = torch.randn((bs, features)).cuda() + else: + inputs = None + + scheduling_naive(model, inputs, bs, features) From d47c6bdcab36cd478916f3f8581e0a6c84878635 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 Aug 2021 05:24:12 +0000 Subject: [PATCH 0125/1892] profiling code --- examples/case_study/ultimate/pipeline_linear.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/case_study/ultimate/pipeline_linear.py b/examples/case_study/ultimate/pipeline_linear.py index 8b100069..77bdfb25 100644 --- a/examples/case_study/ultimate/pipeline_linear.py +++ b/examples/case_study/ultimate/pipeline_linear.py @@ -13,6 +13,7 @@ import torch from torch import nn import os +import time class Linears(nn.Module): @@ -227,3 +228,5 @@ def scheduling_1f1b(model, inputs, bs, feats, micro_bs): for _ in range(50): scheduling_1f1b(model, inputs, bs, features, micro_bs) + # torch.distributed.barrier() # for profiling only + # time.sleep(1) From f36d855b40f6fb112d6791362a37518ae1fe9b17 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 Aug 2021 06:30:23 +0000 Subject: [PATCH 0126/1892] update comment --- examples/case_study/ultimate/parallel_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/case_study/ultimate/parallel_linear.py b/examples/case_study/ultimate/parallel_linear.py index ced2100e..1f077c27 100644 --- a/examples/case_study/ultimate/parallel_linear.py +++ b/examples/case_study/ultimate/parallel_linear.py @@ -232,7 +232,7 @@ def print_each_rank(msg, selected_rank=None): input = torch.rand((batch_size, in_features)).cuda() # print_each_rank('input: {}'.format(input)) - # model parallel + # tensor parallel print_each_rank('======== Model Parallel =========', [0]) output = linear_tensor_parallel(input, weight, bias) loss = torch.mean(output) * 100 From a2f328f95d11369ec3209a6a47203e585509ab50 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 Aug 2021 06:55:20 +0000 Subject: [PATCH 0127/1892] parallel primitive update --- examples/case_study/parallel_primitive.py | 64 +++++++++++------------ 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/examples/case_study/parallel_primitive.py b/examples/case_study/parallel_primitive.py index 1c5bb13d..6a0e79d9 100644 --- a/examples/case_study/parallel_primitive.py +++ b/examples/case_study/parallel_primitive.py @@ -13,57 +13,57 @@ def linear_tensor_parallel(inputs, weight, bias): K = 1024 N = 1024 - ### ============ Input Adapter ============ ### - - # select need to consider transformation from one segmentation to another - inputs_segment = select( + inputs = select( tensor = inputs, indices = (slice(0, M), slice(0, K)), + val_op = None, + ranks = [0, 1, 2, 3], shape = (M, K) ) - weight_segment = select( + weight = select( tensor = weight, indices = (slice(rank * (N // 4), (rank + 1) * (N // 4)), slice(0, K)), + val_op = None, + ranks = [0, 1, 2, 3], shape = (N // 4, K) ) - bias_segment = select( + bias = select( tensor = bias, indices = (slice(rank * (N // 4), (rank + 1) * (N // 4)),), - shape = (N // 4,) - ) - - inputs = deploy( - segment = inputs_segment, + val_op = None, ranks = [0, 1, 2, 3], - val_map_op = IdentityForwardAllreduceBackward - ) - - weight = deploy( - segment = weight_segment, - ranks = [rank], - val_map_op = None - ) - - bias = deploy( - segment = bias_segment, - ranks = [rank], - val_map_op = None + shape = (N // 4,) ) - ### ============ Input Adapter ============ ### - ### ============ Compute ============ ### + # each rank do this output = torch._C._nn.linear(inputs, weight, bias) ### ============ Compute ============ ### - - ### ============ Output Adapter ============ ### - segment = recover( + output = merge( tensor = output, ranks = [0, 1, 2, 3], - reduction_op = AllGatherForwardSplitBackward + merge_op = all_gather ) - # construct to logical tensor and return - ### ============ Output Adapter ============ ### + + + + # inputs = deploy( + # segment = inputs_segment, + # ranks = [0, 1, 2, 3], + # val_map_op = IdentityForwardAllreduceBackward + # ) + + # weight = deploy( + # segment = weight_segment, + # ranks = [rank], + # val_map_op = None + # ) + + # bias = deploy( + # segment = bias_segment, + # ranks = [rank], + # val_map_op = None + # ) \ No newline at end of file From 6d114eaebe85885a5d72ac4321403a8e279e6868 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 9 Aug 2021 01:42:10 +0000 Subject: [PATCH 0128/1892] naive torch code for ffn --- examples/ffn.py | 88 ++++++++++++++++--------------------------------- 1 file changed, 28 insertions(+), 60 deletions(-) diff --git a/examples/ffn.py b/examples/ffn.py index b5e90099..17a66a80 100644 --- a/examples/ffn.py +++ b/examples/ffn.py @@ -1,72 +1,33 @@ -""" -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=6000 \ - --use_env \ - examples/linear.py -""" - import torch from torch import nn -from torch import Tensor -from torch.nn.parameter import Parameter import torch.nn.functional as F -import math import argparse -class Linear(nn.Module): - - __constants__ = ['in_features', 'out_features'] - in_features: int - out_features: int - weight: Tensor - - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super(Linear, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs)) - if bias: - self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter('bias', None) - self.reset_parameters() - - def reset_parameters(self) -> None: - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(self.bias, -bound, bound) - - def forward(self, input: Tensor) -> Tensor: - return combo_op.linear_op(input, self.weight, self.bias) - - class FeedForward(nn.Module): def __init__(self, dim, dropout=0., mult=16, classes=1000): super().__init__() self.net = nn.Sequential( - Linear(dim, dim * mult), + nn.Linear(dim, dim * mult), nn.GELU(), nn.Dropout(dropout), - Linear(dim * mult, dim) + nn.Linear(dim * mult, dim) ) - self.classifier = Linear(dim, classes) + self.classifier = nn.Linear(dim, classes) - def forward(self, x, labels): + def forward(self, x): output = self.net(x) output = self.classifier(output) - loss = F.cross_entropy(output, labels) - return loss + return output + + +def data_iter(bs, dim, classes, length=64): + for _ in range(length): + data = torch.randn((bs, dim)) + label = torch.randint(0, classes, (bs,)) + yield data, label if __name__ == '__main__': @@ -78,16 +39,23 @@ def forward(self, x, labels): parser.add_argument('--classes', type=int, default=10) args = parser.parse_args() - # init distributed env - group = combo.physical.device.group.DeviceGroup() - print(group) - model = FeedForward(args.dim, mult=args.heads, classes=args.classes) model = model.cuda() - inputs = torch.rand((args.bs, args.dim)).cuda() - labels = torch.randint(0, 10, (args.bs, )).cuda() - for _ in range(100): - loss = model(inputs, labels) + optimizer = torch.optim.Adam( + model.parameters(), + lr=0.001, + betas=(0.9, 0.99), + weight_decay=0 + ) + + for (data, label) in data_iter(args.bs, args.dim, args.classes): + data, label = data.cuda(), label.cuda() + # forward + output = model(data) + loss = F.cross_entropy(output, label) + # backward loss.backward() - print('Done.') \ No newline at end of file + # weight update + optimizer.step() + optimizer.zero_grad() From 06e01a045fb1db607a1d06443755dba766fd26b8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 9 Aug 2021 02:35:29 +0000 Subject: [PATCH 0129/1892] example for policy --- examples/case_study/policy/logical_code.py | 69 ++++++++++++++++++++++ examples/case_study/policy/policy.py | 34 +++++++++++ 2 files changed, 103 insertions(+) create mode 100644 examples/case_study/policy/logical_code.py create mode 100644 examples/case_study/policy/policy.py diff --git a/examples/case_study/policy/logical_code.py b/examples/case_study/policy/logical_code.py new file mode 100644 index 00000000..f77589ce --- /dev/null +++ b/examples/case_study/policy/logical_code.py @@ -0,0 +1,69 @@ +import torch +from torch import nn +import torch.nn.functional as F + +import argparse + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=16, classes=1000): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim * mult), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim) + ) + + self.classifier = nn.Linear(dim, classes) + + def forward(self, x): + with annotate(data_parallel): + output = self.net(x) + output = self.classifier(output) + return output + + +def data_iter(bs, dim, classes, length=64): + for _ in range(length): + data = torch.randn((bs, dim)) + label = torch.randint(0, classes, (bs,)) + yield data, label + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--dim', type=int, default=1024) + parser.add_argument('--heads', type=int, default=16) + parser.add_argument('--bs', type=int, default=8) + parser.add_argument('--classes', type=int, default=10) + args = parser.parse_args() + + model = FeedForward(args.dim, mult=args.heads, classes=args.classes) + # model = model.cuda() + + ### ======= get DAG and modify by policy ======= ### + dag = get_dag(model, data) + new_dag = policy(dag, resources)[myrank] + model = new_dag + ### ======= get DAG and modify by policy ======= ### + + optimizer = torch.optim.Adam( + model.parameters(), + lr=0.001, + betas=(0.9, 0.99), + weight_decay=0 + ) + + + for (data, label) in data_iter(args.bs, args.dim, args.classes): + data, label = data.cuda(), label.cuda() + # forward + output = model(data) + loss = F.cross_entropy(output, label) + # backward + loss.backward() + # weight update + optimizer.step() + optimizer.zero_grad() diff --git a/examples/case_study/policy/policy.py b/examples/case_study/policy/policy.py new file mode 100644 index 00000000..0ef22b51 --- /dev/null +++ b/examples/case_study/policy/policy.py @@ -0,0 +1,34 @@ + + +def policy(DAG, resources): + """ + Args: + * DAG: semantic (logical) computation graph + * Resources: Environment inlcuding devices, network topology etc + Returns: + * DAGs (list[DAG]) execution (local & physical) DAG for each rank + """ + for inputs, op, outputs in DAG: + # tensor placement / lifecycle adapter + if is_annotated(inputs): + placement_lifecycle_adapter(DAG, inputs) + if is_annotated(op): + # distributed op adapter + dist_op = select(op, inputs, resources) + replace(DAG, op, dist_op) + # input tensor segmentation adapter + input_adapter(DAG, dist_op, inputs) + # output tensor segmentation adapter + output_adapter(DAG, dist_op, outputs) + # tensor move / destroy + if is_annotated(outputs): + placement_lifecycle_adapter(DAG, outputs) + DAGs = generate_for_each_rank(DAG, resources) + return DAGs + + +def select(op, inputs, resources): + op_candidates = get_distributed_ops(type(op)) + for candidate in op_candidates: + if candidate.same_segmentation(inputs): + return candidate From 480f7bc37797f99721aafeac8dc2fb0369972f7b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 9 Aug 2021 06:49:40 +0000 Subject: [PATCH 0130/1892] update code --- examples/case_study/policy/logical_code.py | 6 +++++ examples/case_study/policy/policy.py | 31 +++++++++++++++------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/examples/case_study/policy/logical_code.py b/examples/case_study/policy/logical_code.py index f77589ce..c997ff19 100644 --- a/examples/case_study/policy/logical_code.py +++ b/examples/case_study/policy/logical_code.py @@ -67,3 +67,9 @@ def data_iter(bs, dim, classes, length=64): # weight update optimizer.step() optimizer.zero_grad() + + +## dynamics? + +## weight update + forward concurrent + diff --git a/examples/case_study/policy/policy.py b/examples/case_study/policy/policy.py index 0ef22b51..6e4d1581 100644 --- a/examples/case_study/policy/policy.py +++ b/examples/case_study/policy/policy.py @@ -1,4 +1,14 @@ +""" +DAG interface: + add_op + + delete_op + + update_op + + find_op / iter_op +""" def policy(DAG, resources): """ @@ -8,26 +18,29 @@ def policy(DAG, resources): Returns: * DAGs (list[DAG]) execution (local & physical) DAG for each rank """ - for inputs, op, outputs in DAG: - # tensor placement / lifecycle adapter - if is_annotated(inputs): - placement_lifecycle_adapter(DAG, inputs) + for inputs, op, outputs in iter_op(DAG): if is_annotated(op): # distributed op adapter - dist_op = select(op, inputs, resources) - replace(DAG, op, dist_op) + dist_op = select_dist_op(op, inputs, resources) + replace_op(DAG, op, dist_op) # input tensor segmentation adapter - input_adapter(DAG, dist_op, inputs) + inputs = input_adapter(DAG, dist_op, inputs) # output tensor segmentation adapter - output_adapter(DAG, dist_op, outputs) + outputs = output_adapter(DAG, dist_op, outputs) + # tensor placement / lifecycle adapter + if is_annotated(inputs): + placement_lifecycle_adapter(DAG, inputs) # tensor move / destroy if is_annotated(outputs): placement_lifecycle_adapter(DAG, outputs) + # scheduling + # TODO: do we need to include scheduling in the DAG? + # materialize to physical op DAGs = generate_for_each_rank(DAG, resources) return DAGs -def select(op, inputs, resources): +def select_dist_op(op, inputs, resources): op_candidates = get_distributed_ops(type(op)) for candidate in op_candidates: if candidate.same_segmentation(inputs): From 1f46700675eb7c927ec211b41f44ef730d4a756a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 12 Aug 2021 06:50:35 +0000 Subject: [PATCH 0131/1892] megatron benchmark code --- benchmark/megatron_gpt_2.sh | 47 +++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100755 benchmark/megatron_gpt_2.sh diff --git a/benchmark/megatron_gpt_2.sh b/benchmark/megatron_gpt_2.sh new file mode 100755 index 00000000..f5f86900 --- /dev/null +++ b/benchmark/megatron_gpt_2.sh @@ -0,0 +1,47 @@ +#! /bin/bash + +# Runs the "345M" parameter model + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=62001 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=/data/webtext2/my-gpt2_text_document + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +## Optional Config ## +# --checkpoint-activations \ +# NCCL_P2P_DISABLE=1 + +NCCL_P2P_DISABLE=1 python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + /workspace/Megatron-LM/pretrain_gpt.py \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --micro-batch-size 4 \ + --global-batch-size 16 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --data-path $DATA_PATH \ + --vocab-file /data/gpt2-vocab.json \ + --merge-file /data/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --log-interval 100 \ + --fp16 From 3b9e2f09b6639fb8de3ca1ead56cd0db0dcba5c5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 12 Aug 2021 11:26:36 +0000 Subject: [PATCH 0132/1892] script update --- benchmark/megatron_gpt_2.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmark/megatron_gpt_2.sh b/benchmark/megatron_gpt_2.sh index f5f86900..19e2c294 100755 --- a/benchmark/megatron_gpt_2.sh +++ b/benchmark/megatron_gpt_2.sh @@ -10,7 +10,7 @@ NNODES=1 NODE_RANK=0 WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) -DATA_PATH=/data/webtext2/my-gpt2_text_document +DATA_PATH=/mydata/LargeModel/GPT-2/webtext2/my-gpt2_text_document DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" @@ -18,15 +18,15 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $ # --checkpoint-activations \ # NCCL_P2P_DISABLE=1 -NCCL_P2P_DISABLE=1 python -m torch.distributed.launch $DISTRIBUTED_ARGS \ +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ /workspace/Megatron-LM/pretrain_gpt.py \ - --tensor-model-parallel-size 8 \ - --pipeline-model-parallel-size 1 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 8 \ --num-layers 24 \ --hidden-size 1024 \ --num-attention-heads 16 \ --micro-batch-size 4 \ - --global-batch-size 16 \ + --global-batch-size 64 \ --seq-length 1024 \ --max-position-embeddings 1024 \ --train-iters 500000 \ From 5c81dcd0a1adf774b40290ebb8db20f0bbcb96d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 13 Aug 2021 02:14:13 +0000 Subject: [PATCH 0133/1892] update env with azcopy --- scripts/env-setup.sh | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index d28e204e..80843eba 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -12,6 +12,22 @@ sudo chmod -R a+w /opt/conda sudo apt-get install tmux -y sudo apt-get install psmisc -y +sudo apt-get install lsof -y + +# install blob +# sudo apt-get install lsb-release -y +# wget https://packages.microsoft.com/config/ubuntu/20.04/packages-microsoft-prod.deb +# sudo dpkg -i packages-microsoft-prod.deb +# sudo apt-get update +# sudo apt-get install blobfuse -y +# sudo rm packages-microsoft-prod.deb + +# install azcopy +wget https://azcopyvnext.azureedge.net/release20210616/azcopy_linux_amd64_10.11.0.tar.gz -O azcopy.tar.gz +tar -zxvf azcopy.tar.gz +sudo mv azcopy_linux_amd64_10.11.0/azcopy /usr/bin/ +rm -rf azcopy_linux_amd64_10.11.0 azcopy.tar.gz + wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.tmux.conf -O ~/.tmux.conf wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.vimrc -O ~/.vimrc From b03f183e27ea381f42319853e204cd924a3f622b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 13 Aug 2021 03:01:34 +0000 Subject: [PATCH 0134/1892] to blob data --- benchmark/megatron_gpt_2.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark/megatron_gpt_2.sh b/benchmark/megatron_gpt_2.sh index 19e2c294..d2f525d0 100755 --- a/benchmark/megatron_gpt_2.sh +++ b/benchmark/megatron_gpt_2.sh @@ -32,8 +32,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --train-iters 500000 \ --lr-decay-iters 320000 \ --data-path $DATA_PATH \ - --vocab-file /data/gpt2-vocab.json \ - --merge-file /data/gpt2-merges.txt \ + --vocab-file /mydata/LargeModel/GPT-2/gpt2-vocab.json \ + --merge-file /mydata/LargeModel/GPT-2/gpt2-merges.txt \ --data-impl mmap \ --split 949,50,1 \ --distributed-backend nccl \ From 477b4b61af80a1932234fe0f900aadc232619a3b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 18 Aug 2021 05:08:10 +0000 Subject: [PATCH 0135/1892] update benchmark tool --- benchmark/megatron_gpt_2.sh | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/benchmark/megatron_gpt_2.sh b/benchmark/megatron_gpt_2.sh index d2f525d0..08b52245 100755 --- a/benchmark/megatron_gpt_2.sh +++ b/benchmark/megatron_gpt_2.sh @@ -17,15 +17,19 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $ ## Optional Config ## # --checkpoint-activations \ # NCCL_P2P_DISABLE=1 +# --fp16 + +rm -rf /workspace/Megatron-LM/megatron/fused_kernels/build python -m torch.distributed.launch $DISTRIBUTED_ARGS \ /workspace/Megatron-LM/pretrain_gpt.py \ + --checkpoint-activations \ --tensor-model-parallel-size 1 \ --pipeline-model-parallel-size 8 \ --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --micro-batch-size 4 \ + --hidden-size 2304 \ + --num-attention-heads 24 \ + --micro-batch-size 1 \ --global-batch-size 64 \ --seq-length 1024 \ --max-position-embeddings 1024 \ @@ -43,5 +47,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --weight-decay 1e-2 \ --clip-grad 1.0 \ --lr-warmup-fraction .01 \ - --log-interval 100 \ - --fp16 + --no-masked-softmax-fusion \ + --no-bias-dropout-fusion \ + --no-bias-gelu-fusion \ + --log-interval 10 From ba7b68a38d5b7791faeda8f06c412c61051abd97 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 20 Aug 2021 06:50:01 +0000 Subject: [PATCH 0136/1892] add schedule primitive --- examples/case_study/schedule_primitive.py | 102 ++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 examples/case_study/schedule_primitive.py diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py new file mode 100644 index 00000000..359a827b --- /dev/null +++ b/examples/case_study/schedule_primitive.py @@ -0,0 +1,102 @@ +import torch + +from functools import partial + +## Primitive ## + +def select(tensor, indices, val_map_op=None): + pass + +def execute(action, *args, **kwargs): + return action(*args, **kwargs) + +def add_flow(*actions): pass + +def run(schedule): + """ + Take a list of actions and execute in list order + """ + for action in schedule: + outs = execute(action) + return outs + +class Action: pass + + +# ===================== Basic steps ================== # +def general_action(flow_in, *args, **kwargs): + """ + flow_in: the output from previous actions + """ + pass + +def forward(flow_in, model, data): pass + +def backward(flow_in): pass + +def update(flow_in, optimizer): pass +# ===================== Basic steps ================== # + + +def naive_schedule(model, data, optimizer): + + f = Action(partial(forward, model=model, data=data)) + b = Action(partial(backward)) + u = Action(partial(update, optimizer=optimizer)) + + add_flow(f, b ,u) + + schedules = [f, b, u] + + return schedules + + +def pipeline_schedule(model, data, optimizer, num_microbatches=4): + + # forward, backward, update function + f = partial(forward, model=model) + b = partial(backward) + u = partial(update, optimizer=optimizer) + + # suppose we have 4 devices using 1f1b with num micro-batches=4 + chunk_size = data.size(0) / 4 + data = [ + select(data, slice(chunk_size * 0, chunk_size * 1)), + select(data, slice(chunk_size * 1, chunk_size * 2)), + select(data, slice(chunk_size * 2, chunk_size * 3)), + select(data, slice(chunk_size * 3, chunk_size * 4)) + ] + + f0 = Action(partial(f, data=data[0])) + f1 = Action(partial(f, data=data[1])) + f2 = Action(partial(f, data=data[2])) + f3 = Action(partial(f, data=data[3])) + + b0 = Action(b) + b1 = Action(b) + b2 = Action(b) + b3 = Action(b) + + u = Action(u) + + # add data flow f0 -> b0 -> u + add_flow(f0, b0, u) + add_flow(f1, b1, u) + add_flow(f2, b2, u) + add_flow(f3, b3, u) + + + global_schedule = [ + [f0, f1, f2, f3, b0, b1, b2, b3, u], # rank 0 + [f0, f1, f2, b0, f3, b1, b2, b3, u], # rank 1 + [f0, f1, b0, b2, f1, f3, b2, b3, u], # rank 2 + [f0, b0, f1, b1, f2, b2, f3, b3, u], # rank 3 + ] + + # schedules will be in dead lock + [ + [f0, b0, f1, b1], + [f0, f1, b0, b1], + ] + + return global_schedule[torch.distributed.get_rank()] \ No newline at end of file From 6c3e7da4454ee4654d2e5ae6b54b16ded7a1841d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 21 Aug 2021 07:52:31 +0000 Subject: [PATCH 0137/1892] add full training procedure --- examples/case_study/schedule_primitive.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index 359a827b..b17a6cfa 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -12,12 +12,12 @@ def execute(action, *args, **kwargs): def add_flow(*actions): pass -def run(schedule): +def run(schedule, *args, **kwargs): """ Take a list of actions and execute in list order """ for action in schedule: - outs = execute(action) + outs = execute(action, *args, **kwargs) return outs class Action: pass @@ -99,4 +99,21 @@ def pipeline_schedule(model, data, optimizer, num_microbatches=4): [f0, f1, b0, b1], ] - return global_schedule[torch.distributed.get_rank()] \ No newline at end of file + return global_schedule[torch.distributed.get_rank()] + + +if __name__ == '__main__': + + # define logical model / optimizer / data loader + class LogicalModel: pass + class Optimizer: pass + class DataLoader: pass + + + model = LogicalModel() + optimizer = Optimizer(model.parameters()) + dataloader = DataLoader() + + schedule = pipeline_schedule(model, optimizer, num_microbatches=4) + for data in dataloader: + run(schedule, data) From 295061e207be96f85f18b30bc4f0bb9617d613fa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 Aug 2021 02:01:16 +0000 Subject: [PATCH 0138/1892] schedule pritmive --- examples/case_study/schedule_primitive.py | 37 ++++++++++++----------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index b17a6cfa..a511641a 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -38,9 +38,9 @@ def update(flow_in, optimizer): pass # ===================== Basic steps ================== # -def naive_schedule(model, data, optimizer): +def naive_schedule(model, optimizer): - f = Action(partial(forward, model=model, data=data)) + f = Action(partial(forward, model=model)) b = Action(partial(backward)) u = Action(partial(update, optimizer=optimizer)) @@ -48,10 +48,10 @@ def naive_schedule(model, data, optimizer): schedules = [f, b, u] - return schedules + return partial(run, schedules) -def pipeline_schedule(model, data, optimizer, num_microbatches=4): +def pipeline_schedule(model, optimizer, num_microbatches=4): # forward, backward, update function f = partial(forward, model=model) @@ -59,13 +59,16 @@ def pipeline_schedule(model, data, optimizer, num_microbatches=4): u = partial(update, optimizer=optimizer) # suppose we have 4 devices using 1f1b with num micro-batches=4 - chunk_size = data.size(0) / 4 - data = [ - select(data, slice(chunk_size * 0, chunk_size * 1)), - select(data, slice(chunk_size * 1, chunk_size * 2)), - select(data, slice(chunk_size * 2, chunk_size * 3)), - select(data, slice(chunk_size * 3, chunk_size * 4)) - ] + + def slicer(data): + chunk_size = data.size(0) / 4 + data = [ + select(data, slice(chunk_size * 0, chunk_size * 1)), + select(data, slice(chunk_size * 1, chunk_size * 2)), + select(data, slice(chunk_size * 2, chunk_size * 3)), + select(data, slice(chunk_size * 3, chunk_size * 4)) + ] + d0 = Action(slicer) f0 = Action(partial(f, data=data[0])) f1 = Action(partial(f, data=data[1])) @@ -87,10 +90,10 @@ def pipeline_schedule(model, data, optimizer, num_microbatches=4): global_schedule = [ - [f0, f1, f2, f3, b0, b1, b2, b3, u], # rank 0 - [f0, f1, f2, b0, f3, b1, b2, b3, u], # rank 1 - [f0, f1, b0, b2, f1, f3, b2, b3, u], # rank 2 - [f0, b0, f1, b1, f2, b2, f3, b3, u], # rank 3 + [d0, f0, f1, f2, f3, b0, b1, b2, b3, u], # rank 0 + [f0, f1, f2, b0, f3, b1, b2, b3, u], # rank 1 + [f0, f1, b0, b2, f1, f3, b2, b3, u], # rank 2 + [f0, b0, f1, b1, f2, b2, f3, b3, u], # rank 3 ] # schedules will be in dead lock @@ -99,7 +102,7 @@ def pipeline_schedule(model, data, optimizer, num_microbatches=4): [f0, f1, b0, b1], ] - return global_schedule[torch.distributed.get_rank()] + return partial(run, global_schedule[torch.distributed.get_rank()]) if __name__ == '__main__': @@ -116,4 +119,4 @@ class DataLoader: pass schedule = pipeline_schedule(model, optimizer, num_microbatches=4) for data in dataloader: - run(schedule, data) + schedule(data=data) From b4e324e7d0b2d3a51b4ac52a67b7bda2a2d40a0f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 Aug 2021 02:40:04 +0000 Subject: [PATCH 0139/1892] schedule + parallel primitive --- examples/case_study/parallel_primitive.py | 113 ++++++++++++++-------- examples/case_study/schedule_primitive.py | 2 +- 2 files changed, 71 insertions(+), 44 deletions(-) diff --git a/examples/case_study/parallel_primitive.py b/examples/case_study/parallel_primitive.py index 6a0e79d9..36f7a3e7 100644 --- a/examples/case_study/parallel_primitive.py +++ b/examples/case_study/parallel_primitive.py @@ -1,69 +1,96 @@ import torch import os -from cube.device.physic.grou import DeviceGroup() + +from functools import partial torch.manual_seed(121) +# select from logical tensor with indices -> generate a logical tensor +def select(tensor, indices, val_op, shape): pass + +# deploy logical tensor to devices +def deploy(tensor, ranks): pass + +# merge logical tensors at `ranks` devices +def merge(tensor, ranks, merge_op): pass + + +class LogicalTensor: pass +class PhyiscalTensor: pass -def linear_tensor_parallel(inputs, weight, bias): - rank = DeviceGroup().rank +def linear_tensor_parallel(inputs, weight, bias, output): + """ + inputs: (M, K) + weight: (N, K) + bias: (N,) + output: (M, N) + + Perform: (M, K) * (\delta N, K) + (\delta N,) = (M, \delta N) + """ M = 1024 K = 1024 N = 1024 + # Tensor Split # -- System + policy generated inputs = select( tensor = inputs, indices = (slice(0, M), slice(0, K)), val_op = None, - ranks = [0, 1, 2, 3], shape = (M, K) ) - weight = select( - tensor = weight, - indices = (slice(rank * (N // 4), (rank + 1) * (N // 4)), slice(0, K)), - val_op = None, - ranks = [0, 1, 2, 3], - shape = (N // 4, K) - ) + weights, biases, outputs = list(), list(), list() + for cid in range(4): + weights.append(select( + tensor = weight, + indices = (slice(cid * (N // 4), (cid + 1) * (N // 4)), slice(0, K)), + val_op = None, + shape = (N // 4, K) + )) - bias = select( - tensor = bias, - indices = (slice(rank * (N // 4), (rank + 1) * (N // 4)),), - val_op = None, - ranks = [0, 1, 2, 3], - shape = (N // 4,) - ) + biases.append(select( + tensor = bias, + indices = (slice(cid * (N // 4), (cid + 1) * (N // 4)),), + val_op = None, + shape = (N // 4,) + )) - ### ============ Compute ============ ### - # each rank do this - output = torch._C._nn.linear(inputs, weight, bias) - ### ============ Compute ============ ### + outputs.append(select( + tensor = output, + indices = (slice(slice(0, M), cid * (N // 4), (cid + 1) * (N // 4)),), + val_op = None, + shape = (M, N // 4) + )) + # Tensor Split # - output = merge( - tensor = output, - ranks = [0, 1, 2, 3], - merge_op = all_gather + # Tensor Deployment # -- System + policy generated + inputs = deploy( + segment = inputs, + ranks = [0, 1, 2, 3] ) + for rank, (weight, bias) in enumerate(zip(weights, biases)): + weight = deploy( + segment = weight, + ranks = [rank], + ) + bias = deploy( + segment = bias, + ranks = [rank], + ) + # Tensor Deployment # + # Compute # -- Expert specified + for weight, bias, output in enumerate(zip(weights, biases, outputs)): + # physical tensor + chunk = torch._C._nn.linear(inputs, weight, bias) + output.fill(chunk) - # inputs = deploy( - # segment = inputs_segment, - # ranks = [0, 1, 2, 3], - # val_map_op = IdentityForwardAllreduceBackward - # ) - - # weight = deploy( - # segment = weight_segment, - # ranks = [rank], - # val_map_op = None - # ) - - # bias = deploy( - # segment = bias_segment, - # ranks = [rank], - # val_map_op = None - # ) \ No newline at end of file + # Generate logical tensor -- System generated + merge( + tensor = output, + ranks = [0, 1, 2, 3], + merge_op = partial(all_gather, dim=1) + ) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index a511641a..b126cab0 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -4,7 +4,7 @@ ## Primitive ## -def select(tensor, indices, val_map_op=None): +def select(tensor, indices, val_map_op=None, shape=None): pass def execute(action, *args, **kwargs): From 1fb41d72c05fd10bcae85cf94fd236ef381185d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 Aug 2021 07:27:34 +0000 Subject: [PATCH 0140/1892] using only forward + backward --- examples/case_study/schedule_primitive.py | 166 ++++++++++++++-------- 1 file changed, 108 insertions(+), 58 deletions(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index b126cab0..5d0770c3 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -7,17 +7,32 @@ def select(tensor, indices, val_map_op=None, shape=None): pass -def execute(action, *args, **kwargs): - return action(*args, **kwargs) +def execute(action, **kwargs): + # action instance will automatically take flow-in results + # and select the chunked kwargs + return action(**kwargs) -def add_flow(*actions): pass +def add_flow(*actions): + # this will set all input actions with same flow-id + pass -def run(schedule, *args, **kwargs): +def run(schedule, num_microbs, *args): """ Take a list of actions and execute in list order """ + chunked_args = list() + for arg in args: + if torch.is_tensor(arg): + chunk_size = data.size(0) / num_microbs + arg = [ + select(arg, slice(chunk_size * 0, chunk_size * 1)), + select(arg, slice(chunk_size * 1, chunk_size * 2)), + select(arg, slice(chunk_size * 2, chunk_size * 3)), + select(arg, slice(chunk_size * 3, chunk_size * 4)) + ] + chunked_args.append(arg) for action in schedule: - outs = execute(action, *args, **kwargs) + outs = execute(action, *tuple(args)) return outs class Action: pass @@ -30,79 +45,92 @@ def general_action(flow_in, *args, **kwargs): """ pass -def forward(flow_in, model, data): pass +def forward(flow_in, model, data): + loss = model(data) + return loss -def backward(flow_in): pass +def backward(flow_in): + flow_in.backwrd() + return flow_in -def update(flow_in, optimizer): pass # ===================== Basic steps ================== # -def naive_schedule(model, optimizer): - - f = Action(partial(forward, model=model)) - b = Action(partial(backward)) - u = Action(partial(update, optimizer=optimizer)) - - add_flow(f, b ,u) - - schedules = [f, b, u] - - return partial(run, schedules) +def naive_schedule(f, b): + f = Action(f) + b = Action(b) + add_flow(f, b) + schedules = [f, b] + return partial(run, schedules, num_microbs=1) -def pipeline_schedule(model, optimizer, num_microbatches=4): +def grad_accumulate(f, b, accum_times=4): + forwards = [Action(f, fid=fid) for fid in range(accum_times)] + backwards = [Action(b, fid=fid) for fid in range(accum_times)] + schedules = list() + for f, b in zip(forwards, backwards): + add_flow(f, b) + schedules += [f, b] + return partial(run, schedules, num_microbs=accum_times) - # forward, backward, update function - f = partial(forward, model=model) - b = partial(backward) - u = partial(update, optimizer=optimizer) +def pipeline_schedule(f, b, num_microbs=4): + """ + f: forward function + b: backward function + """ # suppose we have 4 devices using 1f1b with num micro-batches=4 - def slicer(data): - chunk_size = data.size(0) / 4 - data = [ - select(data, slice(chunk_size * 0, chunk_size * 1)), - select(data, slice(chunk_size * 1, chunk_size * 2)), - select(data, slice(chunk_size * 2, chunk_size * 3)), - select(data, slice(chunk_size * 3, chunk_size * 4)) - ] - d0 = Action(slicer) - - f0 = Action(partial(f, data=data[0])) - f1 = Action(partial(f, data=data[1])) - f2 = Action(partial(f, data=data[2])) - f3 = Action(partial(f, data=data[3])) + f0 = Action(partial(f), fid=0) + f1 = Action(partial(f), fid=1) + f2 = Action(partial(f), fid=2) + f3 = Action(partial(f), fid=3) b0 = Action(b) b1 = Action(b) b2 = Action(b) b3 = Action(b) - u = Action(u) - - # add data flow f0 -> b0 -> u - add_flow(f0, b0, u) - add_flow(f1, b1, u) - add_flow(f2, b2, u) - add_flow(f3, b3, u) + # add data flow f0 -> b0 + add_flow(f0, b0) + add_flow(f1, b1) + add_flow(f2, b2) + add_flow(f3, b3) global_schedule = [ - [d0, f0, f1, f2, f3, b0, b1, b2, b3, u], # rank 0 - [f0, f1, f2, b0, f3, b1, b2, b3, u], # rank 1 - [f0, f1, b0, b2, f1, f3, b2, b3, u], # rank 2 - [f0, b0, f1, b1, f2, b2, f3, b3, u], # rank 3 + [f0, f1, f2, f3, b0, b1, b2, b3], # rank 0 + [f0, f1, f2, b0, f3, b1, b2, b3], # rank 1 + [f0, f1, b0, b2, f1, f3, b2, b3], # rank 2 + [f0, b0, f1, b1, f2, b2, f3, b3], # rank 3 ] + myschedule = global_schedule[torch.distributed.get_rank()] + # schedules will be in dead lock - [ - [f0, b0, f1, b1], - [f0, f1, b0, b1], - ] + # [ + # [f0, b0, f1, b1], + # [f0, f1, b0, b1], + # ] + + return partial(run, myschedule, num_microbs=num_microbs) - return partial(run, global_schedule[torch.distributed.get_rank()]) + +def dist_policy(DAG, resources): + """ + Policy decided the parallelisms and op-placement + """ + return DAG + + +def schedule_policy(model, forward_fn, backward_fn, bs): + """ + forward_fn: forward function + backward_fn: backward_function + bs: global batch size + """ + num_microbs = 4 if bs >= 4 else bs + return pipeline_schedule(forward_fn, backward_fn, num_microbs) if __name__ == '__main__': @@ -115,8 +143,30 @@ class DataLoader: pass model = LogicalModel() optimizer = Optimizer(model.parameters()) - dataloader = DataLoader() + dataloader = DataLoader(bs=1024) - schedule = pipeline_schedule(model, optimizer, num_microbatches=4) - for data in dataloader: - schedule(data=data) + def forward_step(flow_in, data, label, **kwargs): + # this requires loss computation needs to be in the model + output = model(data, label) + return output + + def backward_step(output, **kwargs): + output.backward() + return output + + # policy for placement and parallelisms + model = dist_policy(get_dag(model, input_shapes), resources) + # data flow scheduling policy + schedule = schedule_policy(model, forward_step, backward_step, bs=1024) + + for epoch in range(100): + for step, (data, label) in enumerate(dataloader): + loss = schedule(data=data) + optimizer.step() + # lr_scheduler.step() + optimizer.zero_grad() + print(loss) + + if (epoch + 1) % 4 == 0: + model.eval() + # evaluation \ No newline at end of file From 721cb4a4b62b5f392b9a0fd16cd260597bd6e25b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 24 Aug 2021 02:45:09 +0000 Subject: [PATCH 0141/1892] pipeline poc to stop in middle --- examples/poc/pipeline.py | 71 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 examples/poc/pipeline.py diff --git a/examples/poc/pipeline.py b/examples/poc/pipeline.py new file mode 100644 index 00000000..0d6e6141 --- /dev/null +++ b/examples/poc/pipeline.py @@ -0,0 +1,71 @@ +""" +This is to check whether backward can be stopped in the middle + +Verified by using `detach()`, `requires_grad_()` and `retain_grad()` +""" + + +import torch +from torch import nn + +torch.manual_seed(100) + + +class LinearModel(nn.Module): + + def __init__(self, dim): + super().__init__() + self.linear1 = nn.Linear(dim, dim) + self.linear2 = nn.Linear(dim, dim) + self.linear3 = nn.Linear(dim, dim) + self.linear4 = nn.Linear(dim, dim) + + def forward(self, x): + x2_ = None + + x1 = self.linear1(x) + + x2 = self.linear2(x1) + + x2_ = x2.detach() + x2_.requires_grad_() + x2_.retain_grad() + x3 = self.linear3(x2_) + + x4 = self.linear4(x3) + + return x4, x2, x2_ + + +if __name__ == '__main__': + + bs = 32 + dim = 1024 + + model = LinearModel(dim) + model = model.cuda() + + inputs = torch.randn((bs, dim), device=torch.device('cuda:0')) + + output, x2, x2_ = model(inputs) + loss = torch.sum(output) + + # check before backward grads + # print('before linear1 weight grad:\n{}'.format(model.linear1.weight.grad)) + # print('before linear2 weight grad:\n{}'.format(model.linear3.weight.grad)) + # print('before x2 tensor:\n{}'.format(x2.grad)) + # print('===============================') + assert model.linear1.weight.grad is None + assert model.linear2.weight.grad is None + + loss.backward() + assert model.linear1.weight.grad is None + assert torch.is_tensor(model.linear3.weight.grad) is True + # print('after linear1 weight grad :\n{}'.format(model.linear1.weight.grad)) + # print('after linear2 weight grad :\n{}'.format(model.linear3.weight.grad)) + # print('after x2 tensor:\n{}'.format(x2.grad)) + + torch.autograd.backward(x2, grad_tensors=x2_.grad) + assert torch.is_tensor(model.linear1.weight.grad) is True + # print('===============================') + # print('after autograd linear1 weight grad :\n{}'.format(model.linear1.weight.grad)) From 5c5c6a536e77109523994aa3b8c06e1d2b676869 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 24 Aug 2021 06:38:24 +0000 Subject: [PATCH 0142/1892] add timer --- cube/utils.py | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 cube/utils.py diff --git a/cube/utils.py b/cube/utils.py new file mode 100644 index 00000000..5bd9d229 --- /dev/null +++ b/cube/utils.py @@ -0,0 +1,89 @@ +import sys +import torch +import time + + +def print_each_rank(msg, rank_only=None, outfile=''): + myrank = torch.distributed.get_rank() + outfile = sys.stdout if outfile == '' else outfile + for rank in range(torch.distributed.get_world_size()): + if rank_only is None: + if myrank == rank: + f = open(outfile, 'a') if outfile != sys.stdout else sys.stdout + f.write('rank [{}]: {}\n'.format(rank, msg)) + if outfile != sys.stdout: + f.close() + else: + if myrank == rank_only and rank_only == rank: + f = open(outfile, 'a') if outfile != sys.stdout else sys.stdout + f.write('rank [{}]: {}\n'.format(rank, msg)) + if outfile != sys.stdout: + f.close() + torch.distributed.barrier() + + +class CudaTimer: + + """ + Singleton Timer + """ + + class __CudaTimer: + + def __init__(self): + self.start_t = None + self.stop_t = None + self.field = dict() + self.field_data = dict() + + instance = None + + def __init__(self): + if not CudaTimer.instance: + CudaTimer.instance = CudaTimer.__CudaTimer() + + def start(self, field_name='default'): + torch.cuda.synchronize() + if field_name not in CudaTimer.instance.field: + CudaTimer.instance.field[field_name] = list() + CudaTimer.instance.field_data[field_name] = 0 + CudaTimer.instance.field[field_name].append(time.time()) + + def stop(self, field_name='default'): + """ + Return in ms + """ + if field_name not in CudaTimer.instance.field: + raise RuntimeError("Missing start on the field") + torch.cuda.synchronize() + stop_time = time.time() + start_time = CudaTimer.instance.field[field_name].pop(-1) + span = stop_time - start_time # in seconds + CudaTimer.instance.field_data[field_name] += span + return span + + def duration(self, times, field_name='default'): + if field_name not in CudaTimer.instance.field: + raise RuntimeError(f"Missing start on the field {field_name}") + if len(CudaTimer.instance.field[field_name]) != 0: + raise RuntimeError(f"timer for field {field_name} not stopped") + return CudaTimer.instance.field_data[field_name] / times * 1000 # in ms + + def __getattr__(self, name): + return getattr(self.instance, name) + + def clear(self): + CudaTimer.instance = CudaTimer.__CudaTimer() + + def print_all(self, times): + msg = list() + comm_span = 0 + for field_name in CudaTimer.instance.field_data: + span = self.duration(times, field_name) + if 'send' in field_name or 'recv' in field_name: + comm_span += span + msg.append('{} : {:.2f} ms'.format(field_name, span)) + msg.append('{} : {:.2f} ms'.format('communication', comm_span)) + msg = ' | '.join(msg) + + print_each_rank(msg) From eb639129643d9e3d45bdf1d379fb16d66d92b83c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 25 Aug 2021 01:49:11 +0000 Subject: [PATCH 0143/1892] update pritmitive --- examples/case_study/schedule_primitive.py | 60 ++++++++++++++++------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index 5d0770c3..f511e095 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -123,14 +123,15 @@ def dist_policy(DAG, resources): return DAG -def schedule_policy(model, forward_fn, backward_fn, bs): +def set_schedule_policy(model, specific_schedule, bs): """ forward_fn: forward function backward_fn: backward_function bs: global batch size """ num_microbs = 4 if bs >= 4 else bs - return pipeline_schedule(forward_fn, backward_fn, num_microbs) + schedule = pipeline_schedule(model.forward, backward, num_microbs) + model.set_schedule(schedule) if __name__ == '__main__': @@ -145,23 +146,33 @@ class DataLoader: pass optimizer = Optimizer(model.parameters()) dataloader = DataLoader(bs=1024) - def forward_step(flow_in, data, label, **kwargs): - # this requires loss computation needs to be in the model - output = model(data, label) - return output - - def backward_step(output, **kwargs): - output.backward() - return output - - # policy for placement and parallelisms - model = dist_policy(get_dag(model, input_shapes), resources) - # data flow scheduling policy - schedule = schedule_policy(model, forward_step, backward_step, bs=1024) + # def forward_step(flow_in, data, label, **kwargs): + # # this requires loss computation needs to be in the model + # # output = model(data, label) + # output = model(data, label) + # # function wrapper + # loss = compute_loss(output) + # return output + # + # def backward_step(output, **kwargs): + # output.backward() + # return output + + # policy for placement and parallelisms -- will be hidden + model = dist_policy(get_dag(model, loss_compute, input_shapes), resources) + # data flow scheduling policy -- will be hidden + set_schedule_policy(model, pipeline_schedule, bs=1024) for epoch in range(100): for step, (data, label) in enumerate(dataloader): - loss = schedule(data=data) + # enqueue forward specfied by schedule and execute the first one + output = model(data) + # accessing partial output data without generation will rase warning + # pop forward until to generate the backward tensor + loss = compute_loss(output, label) + loss.backward() + + # loss = schedule(data=data) optimizer.step() # lr_scheduler.step() optimizer.zero_grad() @@ -169,4 +180,19 @@ def backward_step(output, **kwargs): if (epoch + 1) % 4 == 0: model.eval() - # evaluation \ No newline at end of file + # evaluation + + +# class Model: +# +# def forward(self, data): +# # non-torch op wrapper +# +# if data[0] > 1: +# self.net1(data) +# else: +# self.net2(data) +# +# def forward_(self, data) +# +# self.q = [(self.forward, data[0]), (xxx)] \ No newline at end of file From d4e09ca20365dfbd7ba8166cd4a4a735a7f10d02 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 26 Aug 2021 07:47:16 +0000 Subject: [PATCH 0144/1892] revise to use consistency order --- examples/case_study/schedule_primitive.py | 163 ++++++++++++++++------ 1 file changed, 117 insertions(+), 46 deletions(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index f511e095..85b6d3ff 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -1,3 +1,4 @@ +from typing import Sequence import torch from functools import partial @@ -74,46 +75,56 @@ def grad_accumulate(f, b, accum_times=4): return partial(run, schedules, num_microbs=accum_times) -def pipeline_schedule(f, b, num_microbs=4): +def pipeline_1f1b_schedule(forward, backward, update, num_stages=2, num_microbs=4): """ f: forward function b: backward function - """ - # suppose we have 4 devices using 1f1b with num micro-batches=4 - - f0 = Action(partial(f), fid=0) - f1 = Action(partial(f), fid=1) - f2 = Action(partial(f), fid=2) - f3 = Action(partial(f), fid=3) - - b0 = Action(b) - b1 = Action(b) - b2 = Action(b) - b3 = Action(b) - # add data flow f0 -> b0 - add_flow(f0, b0) - add_flow(f1, b1) - add_flow(f2, b2) - add_flow(f3, b3) - - + Suppose model is partitioned to `num_stages` with input `num_microbs` micro-batches + """ + # suppose we have 2 stages using 1f1b with num micro-batches=4 + + # f[stage_id, data_id] + partial_sequences = [] + for data_id in range(num_microbs): + one_mbatch = PartialSequence() + for stage_id in range(num_stages): + one_mbatch.append(Action(forward)) + for stage_id in range(num_stages): + one_mbatch.append(Action(backward)) + if data_id == num_microbs - 1: + one_mbatch.append(Action(update)) + partial_sequences.append(one_mbatch) + for S in range(num_stages): + seq = PartialSequence([partial_sequences[-1-S][-num_stages-1]], Action(update)) + partial_sequences.append(seq) + + + # Action f[stage, micro-batch] + # f[S, D]: forward on stage S for micro-batch id D + f = [partial_sequences[i][:num_stages] for i in range(num_microbs)] + + # Action b[stage, micro-batch] + # b[S, D]: backward on stage S for micro-batch id D + b = [partial_sequences[i][num_stages:] for i in range(num_microbs)] + + # Action u[stage, micro-batch] + # u[S, D]: update weight on stage S + u = [partial_sequences[i+num_microbs][1] for i in range(num_stages)] + + + # ========================= + # !@#$#%$&^$# -- policy generated a legal global execution order global_schedule = [ - [f0, f1, f2, f3, b0, b1, b2, b3], # rank 0 - [f0, f1, f2, b0, f3, b1, b2, b3], # rank 1 - [f0, f1, b0, b2, f1, f3, b2, b3], # rank 2 - [f0, b0, f1, b1, f2, b2, f3, b3], # rank 3 + f[0,0], f[1,0], b[1,0], + f[0,1], b[0,0], f[1,1], b[1,1], + f[0,2], b[0,1], f[1,2], b[0,2], + f[0,3], b[0,2], f[1,3], b[1,3], u[1], + u[0] ] + # ========================= - myschedule = global_schedule[torch.distributed.get_rank()] - - # schedules will be in dead lock - # [ - # [f0, b0, f1, b1], - # [f0, f1, b0, b1], - # ] - - return partial(run, myschedule, num_microbs=num_microbs) + return global_schedule def dist_policy(DAG, resources): @@ -183,16 +194,76 @@ class DataLoader: pass # evaluation -# class Model: -# -# def forward(self, data): -# # non-torch op wrapper -# -# if data[0] > 1: -# self.net1(data) -# else: -# self.net2(data) -# -# def forward_(self, data) -# -# self.q = [(self.forward, data[0]), (xxx)] \ No newline at end of file +def train_iter_grad_accumulate(model, datas, stage=2, micro_bs=4): + + out_s0_d0 = forward(model[0], datas[0]) + out_s1_d0 = forward(model[1], out_s0_d0) + grad_s1_d0 = backward(out_s1_d0) + grad_s0_d0 = backward(out_s0_d0, grad=grad_s1_d0) + + out_s0_d1 = forward(model[0], datas[1]) + out_s1_d1 = forward(model[1], out_s0_d1) + grad_s1_d1 = backward(out_s1_d1) + grad_s0_d1 = backward(out_s0_d0, grad=grad_s1_d1) + + out_s0_d2 = forward(model[0], datas[2]) + out_s1_d2 = forward(model[1], out_s0_d2) + grad_s1_d2 = backward(out_s1_d2) + grad_s0_d2 = backward(out_s0_d0, grad=grad_s1_d2) + + out_s0_d3 = forward(model[0], datas[3]) + out_s1_d3 = forward(model[1], out_s0_d3) + grad_s1_d3 = backward(out_s1_d3) + grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) + + update_gradient(model[0]) + update_gradient(model[1]) + + +def train_iter_1f1b(model, datas, stage=2, micro_bs=4): + + out_s0_d0 = forward(model[0], datas[0]) + out_s1_d0 = forward(model[1], out_s0_d0) + grad_s1_d0 = backward(out_s1_d0) + + out_s0_d1 = forward(model[0], datas[1]) + grad_s0_d0 = backward(out_s0_d0, grads=grad_s1_d0) + out_s1_d1 = forward(model[1], out_s0_d1) + grad_s1_d1 = backward(out_s1_d1) + + out_s0_d2 = forward(model[0], datas[2]) + grad_s0_d1 = backward(out_s0_d0, grad=grad_s1_d1) + out_s1_d2 = forward(model[1], out_s0_d2) + grad_s1_d2 = backward(out_s1_d2) + + out_s0_d3 = forward(model[0], datas[3]) + grad_s0_d2 = backward(out_s0_d0, grad=grad_s1_d2) + out_s1_d3 = forward(model[1], out_s0_d3) + grad_s1_d3 = backward(out_s1_d3) + update_gradient(model[1]) + + grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) + update_gradient(model[0]) + + +def train_iter_gpipe(model, datas, stage=2, micro_bs=4): + + out_s0_d0 = forward(model[0], datas[0]) + out_s1_d0 = forward(model[1], out_s0_d0) + out_s0_d1 = forward(model[0], datas[1]) + out_s1_d1 = forward(model[1], out_s0_d1) + out_s0_d2 = forward(model[0], datas[2]) + out_s1_d2 = forward(model[1], out_s0_d2) + out_s0_d3 = forward(model[0], datas[3]) + out_s1_d3 = forward(model[1], out_s0_d3) + + grad_s1_d0 = backward(out_s1_d0) + grad_s0_d0 = backward(out_s0_d0, grad=grad_s1_d0) + grad_s1_d1 = backward(out_s1_d1) + grad_s0_d1 = backward(out_s0_d0, grad=grad_s1_d1) + grad_s1_d2 = backward(out_s1_d2) + grad_s0_d2 = backward(out_s0_d0, grad=grad_s1_d2) + grad_s1_d3 = backward(out_s1_d3) + update_gradient(model[1]) + grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) + update_gradient(model[0]) From 5ad29922c6f73875ffb03628a690820e52bb52b4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 27 Aug 2021 02:00:26 +0000 Subject: [PATCH 0145/1892] gradient update with model weight grad --- examples/case_study/schedule_primitive.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index 85b6d3ff..bc34a987 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -216,8 +216,8 @@ def train_iter_grad_accumulate(model, datas, stage=2, micro_bs=4): grad_s1_d3 = backward(out_s1_d3) grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) - update_gradient(model[0]) - update_gradient(model[1]) + update_gradient(model[0], model[0].weights.grad) + update_gradient(model[1], model[1].weights.grad) def train_iter_1f1b(model, datas, stage=2, micro_bs=4): @@ -240,10 +240,10 @@ def train_iter_1f1b(model, datas, stage=2, micro_bs=4): grad_s0_d2 = backward(out_s0_d0, grad=grad_s1_d2) out_s1_d3 = forward(model[1], out_s0_d3) grad_s1_d3 = backward(out_s1_d3) - update_gradient(model[1]) + update_gradient(model[1], model[1].weights.grad) grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) - update_gradient(model[0]) + update_gradient(model[0], model[0].weights.grad) def train_iter_gpipe(model, datas, stage=2, micro_bs=4): @@ -264,6 +264,6 @@ def train_iter_gpipe(model, datas, stage=2, micro_bs=4): grad_s1_d2 = backward(out_s1_d2) grad_s0_d2 = backward(out_s0_d0, grad=grad_s1_d2) grad_s1_d3 = backward(out_s1_d3) - update_gradient(model[1]) + update_gradient(model[1], model[1].weights.grad) grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) - update_gradient(model[0]) + update_gradient(model[0], model[0].weights.grad) From 5beadc79036fc82c74c024ad95fefa09dfe0833f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 Aug 2021 02:52:35 +0000 Subject: [PATCH 0146/1892] update temporal execution logic --- examples/case_study/schedule_primitive.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index bc34a987..81527200 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -21,6 +21,7 @@ def run(schedule, num_microbs, *args): """ Take a list of actions and execute in list order """ + myrank = torch.distributed.get_rank() chunked_args = list() for arg in args: if torch.is_tensor(arg): @@ -33,7 +34,11 @@ def run(schedule, num_microbs, *args): ] chunked_args.append(arg) for action in schedule: - outs = execute(action, *tuple(args)) + if action.device == myrank: + # wait for cross-device dependency (if have) + action.wait() + # execute + outs = execute(action, *tuple(args)) return outs class Action: pass From 5ebdd966049625858c7897111cdab27677bee34b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Sep 2021 07:26:18 +0000 Subject: [PATCH 0147/1892] add 1f1b description --- examples/case_study/schedule_primitive.py | 83 ++++++++++------------- 1 file changed, 36 insertions(+), 47 deletions(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index 81527200..17ea6ea8 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -43,6 +43,8 @@ def run(schedule, num_microbs, *args): class Action: pass +def check_consistency(sequence, actions, relations): pass + # ===================== Basic steps ================== # def general_action(flow_in, *args, **kwargs): @@ -80,56 +82,43 @@ def grad_accumulate(f, b, accum_times=4): return partial(run, schedules, num_microbs=accum_times) -def pipeline_1f1b_schedule(forward, backward, update, num_stages=2, num_microbs=4): +def pipeline_1f1b_schedules(actions, relations): """ - f: forward function - b: backward function + Pipeline 1f1b policy description + + Actions: a list of actions - Suppose model is partitioned to `num_stages` with input `num_microbs` micro-batches + relations: list[(Action1, Action2)]: a list of tuples indicate partial order """ - # suppose we have 2 stages using 1f1b with num micro-batches=4 - - # f[stage_id, data_id] - partial_sequences = [] - for data_id in range(num_microbs): - one_mbatch = PartialSequence() - for stage_id in range(num_stages): - one_mbatch.append(Action(forward)) - for stage_id in range(num_stages): - one_mbatch.append(Action(backward)) - if data_id == num_microbs - 1: - one_mbatch.append(Action(update)) - partial_sequences.append(one_mbatch) - for S in range(num_stages): - seq = PartialSequence([partial_sequences[-1-S][-num_stages-1]], Action(update)) - partial_sequences.append(seq) - - - # Action f[stage, micro-batch] - # f[S, D]: forward on stage S for micro-batch id D - f = [partial_sequences[i][:num_stages] for i in range(num_microbs)] - - # Action b[stage, micro-batch] - # b[S, D]: backward on stage S for micro-batch id D - b = [partial_sequences[i][num_stages:] for i in range(num_microbs)] - - # Action u[stage, micro-batch] - # u[S, D]: update weight on stage S - u = [partial_sequences[i+num_microbs][1] for i in range(num_stages)] - - - # ========================= - # !@#$#%$&^$# -- policy generated a legal global execution order - global_schedule = [ - f[0,0], f[1,0], b[1,0], - f[0,1], b[0,0], f[1,1], b[1,1], - f[0,2], b[0,1], f[1,2], b[0,2], - f[0,3], b[0,2], f[1,3], b[1,3], u[1], - u[0] - ] - # ========================= - - return global_schedule + + # suppose input actions are forward and backward of grad accumulation + # suppose in forward -> ... -> forward -> backward -> ... -> backward + num_stage = torch.distributed.get_world_size() + num_micro_batch = len(actions) / 2 / num_stage + + f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] + b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + stage] + + sequence = list() + + # warmup: + for stage in range(num_stage): + for mid in range(stage): + sequence.append(f(stage, mid)) + + # steady + cooldown: + for mid in range(num_micro_batch): + # enqueue backward + for stage in range(num_stage-1, -1, -1): + sequence.append(b(stage, mid)) + # enqueue forward + for stage in range(num_stage): + f_mid = mid + 1 + num_stage - stage + if f_mid >= num_micro_batch: + continue + sequence.append(f(stage, f_mid)) + assert check_consistency(sequence, actions, relations) + return sequence def dist_policy(DAG, resources): From 0207d60cde833d1a81f3eb11b0a8773ea727c5d8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Sep 2021 08:10:14 +0000 Subject: [PATCH 0148/1892] add 1f1b description --- examples/case_study/schedule_primitive.py | 66 +++++++++++++++++++++-- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/schedule_primitive.py index 17ea6ea8..79da99ee 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/schedule_primitive.py @@ -84,7 +84,7 @@ def grad_accumulate(f, b, accum_times=4): def pipeline_1f1b_schedules(actions, relations): """ - Pipeline 1f1b policy description + Pipeline 1f1b policy description -- generate a sequence Actions: a list of actions @@ -94,11 +94,17 @@ def pipeline_1f1b_schedules(actions, relations): # suppose input actions are forward and backward of grad accumulation # suppose in forward -> ... -> forward -> backward -> ... -> backward num_stage = torch.distributed.get_world_size() - num_micro_batch = len(actions) / 2 / num_stage + num_microbatch = len(actions) / 2 / num_stage f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + stage] + # action placement + for stage in range(num_stage): + for mid in range(num_microbatch): + f(stage, mid).device = torch.device.cuda(stage) + b(stage, mid).device = torch.device.cuda(stage) + sequence = list() # warmup: @@ -107,20 +113,72 @@ def pipeline_1f1b_schedules(actions, relations): sequence.append(f(stage, mid)) # steady + cooldown: - for mid in range(num_micro_batch): + for mid in range(num_microbatch): # enqueue backward for stage in range(num_stage-1, -1, -1): sequence.append(b(stage, mid)) # enqueue forward for stage in range(num_stage): f_mid = mid + 1 + num_stage - stage - if f_mid >= num_micro_batch: + if f_mid >= num_microbatch: continue sequence.append(f(stage, f_mid)) assert check_consistency(sequence, actions, relations) return sequence +def pipeline_1f1b_schedule(actions, relations): + """ + Pipeline 1f1b policy description -- each device order + + Actions: a list of actions + + relations: list[(Action1, Action2)]: a list of tuples indicate partial order + """ + num_stage = torch.distributed.get_world_size() + num_microbatch = len(actions) / 2 / num_stage + + f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] + b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + stage] + + # action placement + for stage in range(num_stage): + for mid in range(num_microbatch): + f(stage, mid).device = torch.device.cuda(stage) + b(stage, mid).device = torch.device.cuda(stage) + + # action in-device order + stage_order = list() + + for stage in range(num_stage): + order = list() + num_warmup_microbatch = num_stage - stage - 1 + num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch) + num_microbatch_remain = num_microbatch - num_warmup_microbatch + + # warmup + for mid in range(num_warmup_microbatch): + order.append(f(stage, mid)) + + # steady + for i in range(num_microbatch_remain): + f_mid = num_warmup_microbatch + i + b_mid = i + order.append(f(stage, f_mid)) + order.append(b(stage, b_mid)) + + # cooldown + for i in range(num_warmup_microbatch): + b_mid = num_microbatch_remain + i + order.append(b(stage, b_mid)) + + stage_order.append(order) + + assert check_consistency(stage_order, actions, relations) + return stage_order + + + def dist_policy(DAG, resources): """ Policy decided the parallelisms and op-placement From e7fd6d5feba3258c8f95453c4210d7833bd87081 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Sep 2021 11:46:15 +0000 Subject: [PATCH 0149/1892] start search plans --- cube/schedule/__init__.py | 0 cube/schedule/action.py | 65 ++++++++++++++++++++++++++++++++++ cube/schedule/checker.py | 32 +++++++++++++++++ cube/schedule/iterator.py | 23 ++++++++++++ examples/poc/pipeline_space.py | 51 ++++++++++++++++++++++++++ 5 files changed, 171 insertions(+) create mode 100644 cube/schedule/__init__.py create mode 100644 cube/schedule/action.py create mode 100644 cube/schedule/checker.py create mode 100644 cube/schedule/iterator.py create mode 100644 examples/poc/pipeline_space.py diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/schedule/action.py b/cube/schedule/action.py new file mode 100644 index 00000000..5c10c387 --- /dev/null +++ b/cube/schedule/action.py @@ -0,0 +1,65 @@ + + +class Action: + + def __init__(self, fn): + """ + fn: a function call to perform a set of operators + """ + self._fn = [fn,] + self.pre_actions = list() + self.outputs = None + self.name = 'NotSet' + + def __call__(self, *args, **kwargs): + """ + Execute the action + """ + outputs = self.get_input() + outputs = self._fn[0](outputs, *args, **kwargs) + self.outputs = outputs + + def get_input(self): + """ + Get input for the flow-ins from pre_actions + """ + raise NotImplementedError + + def add_pre_action(self, action): + self.pre_actions.append(action) + + def depends_on(self, action): + """ + check if the self -> action + + Note: this may return false negative as it will only check + 1-hop dependency + """ + if not isinstance(action, Action): + raise TypeError("Expected action to be an Action") + return action in self.pre_actions + + def tag(self, name): + """ + tag a string to indicate this action (as name) + """ + self.name = name + + def __repr__(self): + return self.name + + +def add_flow(action1, action2): + """ + Add happened before dependency action1 -> action2 + + Args: + action1 (Action) + action2 (Action) + """ + if not isinstance(action1, Action): + raise TypeError("Expected action1 to be an Action") + if not isinstance(action2, Action): + raise TypeError("Expected action2 to be an Anction") + if not action1.depends_on(action2): + action1.add_pre_action(action2) diff --git a/cube/schedule/checker.py b/cube/schedule/checker.py new file mode 100644 index 00000000..51b95ff5 --- /dev/null +++ b/cube/schedule/checker.py @@ -0,0 +1,32 @@ +from cube.schedule.action import Action + + +def correct_check(sequence, actions, relations): + """ + Check if sequence satisfies the sequential consistency model + Args: + sequence (list[Actions]): action sequence + actions (list[Action]): action lists + relations (list(tuple(Action, Action))): + contains happened before tuple list + Returns: + Boolean: whether satisfies the partial order specified in relations + """ + if not all([isinstance(action, Action) for action in sequence]): + raise TypeError("Expected the sequence to be list[Action]") + if not all([isinstance(action, Action) for action in actions]): + raise TypeError("Expected the actions to be list[Action]") + + # check if all Actions in `actions` are used by sequence + if set(sequence) != set(actions): + return False + + # check partial order + for (action1, action2) in relations: + act1_idx = sequence.index(action1) + act2_idx = sequence.index(action2) + if act1_idx >= act2_idx: + return False + + # check passed + return True diff --git a/cube/schedule/iterator.py b/cube/schedule/iterator.py new file mode 100644 index 00000000..12b7cf03 --- /dev/null +++ b/cube/schedule/iterator.py @@ -0,0 +1,23 @@ +from cube.schedule.action import Action +from cube.schedule.checker import correct_check + +import itertools + + +def legal_sequence(actions, relations): + """ + Yield all possible legal sequence given the list of actions + + Args: + actions (list[Actions]) + + Yield: + sequence (list[Actions]) + """ + if not all([isinstance(action, Action) for action in actions]): + raise TypeError("Expected the sequence to be list[Action]") + + for seq in itertools.permutations(actions): + seq = list(seq) + if correct_check(seq, actions, relations): + yield seq diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py new file mode 100644 index 00000000..a60ebab1 --- /dev/null +++ b/examples/poc/pipeline_space.py @@ -0,0 +1,51 @@ +from cube.schedule.action import Action +from cube.schedule.iterator import legal_sequence + +import argparse + +from examples.case_study.schedule_primitive import grad_accumulate + + +def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): + actions = list() + relations = list() + for mid in range(num_microbatch): + # forward + for stage in range(num_stage): + action = Action(forward_fn) + action.tag('f(S{},D{})'.format(stage, mid)) + if stage != 0: + relation = (actions[-1], action) + relations.append(relation) + actions.append(action) + # backward + for stage in range(num_stage): + action = Action(backward_fn) + action.tag('b(S{},D{})'.format(num_stage - 1 - stage, mid)) + relation = (actions[-1], action) + relations.append(relation) + actions.append(action) + return actions, relations + + +def print_all_legal_sequence(actions, relations): + for cnt, seq in enumerate(legal_sequence(actions, relations)): + print(seq) + print('total found {} legal sequences'.format(cnt + 1)) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--nstage', type=int, default=4, + help='number of stages') + parser.add_argument('--nmb', type=int, default=4, + help='number of micro-batch') + args = parser.parse_args() + + forward = lambda data: data + backward = lambda grad: grad + + actions, relations = get_semantic(forward, backward, args.nstage, args.nmb) + + print_all_legal_sequence(actions, relations) From 5f35ea5e7547ab4b41929a757153937c935bcce0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Sep 2021 13:01:40 +0000 Subject: [PATCH 0150/1892] work in progress --- cube/schedule/action.py | 3 +- examples/poc/pipeline_space.py | 87 +++++++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/cube/schedule/action.py b/cube/schedule/action.py index 5c10c387..74be96af 100644 --- a/cube/schedule/action.py +++ b/cube/schedule/action.py @@ -10,6 +10,7 @@ def __init__(self, fn): self.pre_actions = list() self.outputs = None self.name = 'NotSet' + self.device = -1 def __call__(self, *args, **kwargs): """ @@ -46,7 +47,7 @@ def tag(self, name): self.name = name def __repr__(self): - return self.name + return self.name+'@{}'.format(self.device) def add_flow(action1, action2): diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index a60ebab1..00f10edc 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -1,9 +1,12 @@ +from typing import Sequence from cube.schedule.action import Action from cube.schedule.iterator import legal_sequence import argparse +import re -from examples.case_study.schedule_primitive import grad_accumulate +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): @@ -16,24 +19,102 @@ def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): action.tag('f(S{},D{})'.format(stage, mid)) if stage != 0: relation = (actions[-1], action) + action.depends_on(actions[-1]) relations.append(relation) actions.append(action) # backward for stage in range(num_stage): action = Action(backward_fn) action.tag('b(S{},D{})'.format(num_stage - 1 - stage, mid)) + # relation relation = (actions[-1], action) + action.depends_on(actions[-1]) + # append to relation sets relations.append(relation) actions.append(action) return actions, relations +def get_stage_and_mid(action): + ids = re.findall(r"S(\d+),D(\d+)", action.name) + stage, mid = int(ids[0][0]), int(ids[0][1]) + return stage, mid + + +def device_search(sequence): + pass + + +def draw_execution_plan(seq, forward_fn, backward_fn, ndevice): + forward_time = 1 + backward_time = 2 + # record each action end time + current_time = [[1] for _ in range(ndevice)] + device_actions = [list() for _ in range(ndevice)] + + recs = dict() + + for action in seq: + if action.device == -1 or action.device >= ndevice: + raise RuntimeError("action {} device not assigned or out of boundary".format(action)) + start_time = current_time[action.device][-1] + for dev_id, (end_times, dev_actions) in enumerate(zip(current_time, device_actions)): + if dev_id == action.device: + continue + # go through to check if the action has dependencies + for end_time, dev_action in zip(end_times, dev_actions): + print(dev_action) + if action.depends_on(dev_action): + print('find dependency {} -> {}, end time: {}'.format(action, dev_action, end_time)) + start_time = max(start_time, end_time) + elif dev_action.depends_on(action): + raise RuntimeError("Action happened before") + # draw regtangular + if action._fn[0] == forward_fn: + span_time = forward_time + color = 'blue' + elif action._fn[0] == backward_fn: + span_time = backward_time + color = 'orange' + # stage, mid = get_stage_and_mid(action) + recs[action.name] = Rectangle((start_time, action.device), span_time, 1, + color=color, ec='black', lw=1.5) + # update timeline + current_time[action.device].append(start_time + span_time) + device_actions[action.device].append(action) + + fig, ax = plt.subplots() + for r in recs: + ax.add_artist(recs[r]) + rx, ry = recs[r].get_xy() + cx = rx + recs[r].get_width() / 2.0 + cy = ry + recs[r].get_height() / 2.0 + ax.annotate(r, (cx, cy), color='w', weight='bold', + fontsize=8, ha='center', va='center') + + ax.set_xlim((1, len(seq) + 1)) + ax.set_ylim((0, ndevice)) + ax.set_aspect('equal') + plt.savefig('./tmp.png') + + def print_all_legal_sequence(actions, relations): for cnt, seq in enumerate(legal_sequence(actions, relations)): print(seq) print('total found {} legal sequences'.format(cnt + 1)) +def fixed_placement_sequence(actions, relations, ndevices, forward, backward): + for cnt, seq in enumerate(legal_sequence(actions, relations)): + # assign device + for action in seq: + stage, mid = get_stage_and_mid(action) + action.device = stage % ndevices + draw_execution_plan(seq, forward, backward, ndevices) + break + print('total found {} legal sequences'.format(cnt + 1)) + + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -41,6 +122,8 @@ def print_all_legal_sequence(actions, relations): help='number of stages') parser.add_argument('--nmb', type=int, default=4, help='number of micro-batch') + parser.add_argument('--ndev', type=int, default=4, + help='number of devices') args = parser.parse_args() forward = lambda data: data @@ -48,4 +131,4 @@ def print_all_legal_sequence(actions, relations): actions, relations = get_semantic(forward, backward, args.nstage, args.nmb) - print_all_legal_sequence(actions, relations) + fixed_placement_sequence(actions, relations, args.ndev, forward, backward) From 90ee42a89665d928f5ad8f0e3bc84775e19811b0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 02:08:32 +0000 Subject: [PATCH 0151/1892] draw the pipeline execution plan --- cube/schedule/action.py | 4 ++-- examples/poc/pipeline_space.py | 30 +++++++++++++++--------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/cube/schedule/action.py b/cube/schedule/action.py index 74be96af..744ca6bf 100644 --- a/cube/schedule/action.py +++ b/cube/schedule/action.py @@ -62,5 +62,5 @@ def add_flow(action1, action2): raise TypeError("Expected action1 to be an Action") if not isinstance(action2, Action): raise TypeError("Expected action2 to be an Anction") - if not action1.depends_on(action2): - action1.add_pre_action(action2) + if not action2.depends_on(action1): + action2.add_pre_action(action1) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index 00f10edc..216bc2dd 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -1,5 +1,5 @@ from typing import Sequence -from cube.schedule.action import Action +from cube.schedule.action import Action, add_flow from cube.schedule.iterator import legal_sequence import argparse @@ -19,7 +19,7 @@ def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): action.tag('f(S{},D{})'.format(stage, mid)) if stage != 0: relation = (actions[-1], action) - action.depends_on(actions[-1]) + add_flow(actions[-1], action) relations.append(relation) actions.append(action) # backward @@ -28,7 +28,7 @@ def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): action.tag('b(S{},D{})'.format(num_stage - 1 - stage, mid)) # relation relation = (actions[-1], action) - action.depends_on(actions[-1]) + add_flow(actions[-1], action) # append to relation sets relations.append(relation) actions.append(action) @@ -62,8 +62,7 @@ def draw_execution_plan(seq, forward_fn, backward_fn, ndevice): if dev_id == action.device: continue # go through to check if the action has dependencies - for end_time, dev_action in zip(end_times, dev_actions): - print(dev_action) + for end_time, dev_action in zip(end_times[1:], dev_actions): if action.depends_on(dev_action): print('find dependency {} -> {}, end time: {}'.format(action, dev_action, end_time)) start_time = max(start_time, end_time) @@ -76,23 +75,23 @@ def draw_execution_plan(seq, forward_fn, backward_fn, ndevice): elif action._fn[0] == backward_fn: span_time = backward_time color = 'orange' - # stage, mid = get_stage_and_mid(action) - recs[action.name] = Rectangle((start_time, action.device), span_time, 1, + recs[action] = Rectangle((start_time, action.device), span_time, 1, color=color, ec='black', lw=1.5) # update timeline current_time[action.device].append(start_time + span_time) device_actions[action.device].append(action) fig, ax = plt.subplots() - for r in recs: - ax.add_artist(recs[r]) - rx, ry = recs[r].get_xy() - cx = rx + recs[r].get_width() / 2.0 - cy = ry + recs[r].get_height() / 2.0 - ax.annotate(r, (cx, cy), color='w', weight='bold', + for action in recs: + stage, mid = get_stage_and_mid(action) + ax.add_artist(recs[action]) + rx, ry = recs[action].get_xy() + cx = rx + recs[action].get_width() / 2.0 + cy = ry + recs[action].get_height() / 2.0 + ax.annotate(f'S{stage}D{mid}', (cx, cy), color='w', weight='bold', fontsize=8, ha='center', va='center') - ax.set_xlim((1, len(seq) + 1)) + ax.set_xlim((1, int(len(seq) * 1.5) + 1)) ax.set_ylim((0, ndevice)) ax.set_aspect('equal') plt.savefig('./tmp.png') @@ -110,8 +109,9 @@ def fixed_placement_sequence(actions, relations, ndevices, forward, backward): for action in seq: stage, mid = get_stage_and_mid(action) action.device = stage % ndevices + print(f'found seq > {seq}') draw_execution_plan(seq, forward, backward, ndevices) - break + input('>>> ') print('total found {} legal sequences'.format(cnt + 1)) From c6606460ee25714cb4fb71ce217d1633cbe75bee Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 05:35:05 +0000 Subject: [PATCH 0152/1892] draw plans and do search --- cube/schedule/action.py | 5 +- cube/schedule/iterator.py | 42 +++++++++++++ cube/schedule/plan.py | 112 +++++++++++++++++++++++++++++++++ examples/poc/pipeline_space.py | 86 +++++++------------------ 4 files changed, 182 insertions(+), 63 deletions(-) create mode 100644 cube/schedule/plan.py diff --git a/cube/schedule/action.py b/cube/schedule/action.py index 744ca6bf..e983253f 100644 --- a/cube/schedule/action.py +++ b/cube/schedule/action.py @@ -9,8 +9,10 @@ def __init__(self, fn): self._fn = [fn,] self.pre_actions = list() self.outputs = None - self.name = 'NotSet' + self.name = 'None' + self.fid = None # flow id self.device = -1 + self.est_latency = 1 def __call__(self, *args, **kwargs): """ @@ -28,6 +30,7 @@ def get_input(self): def add_pre_action(self, action): self.pre_actions.append(action) + self.fid = action.fid def depends_on(self, action): """ diff --git a/cube/schedule/iterator.py b/cube/schedule/iterator.py index 12b7cf03..1b874550 100644 --- a/cube/schedule/iterator.py +++ b/cube/schedule/iterator.py @@ -21,3 +21,45 @@ def legal_sequence(actions, relations): seq = list(seq) if correct_check(seq, actions, relations): yield seq + + +def ready_action_set(actions, relations, flip=False): + """ + Return a list of actions can be executed now + """ + flip = -1 if flip else 1 + ready_actions = list() + for action in actions[::flip]: + satisfy = True + for (_, succ) in relations: + if succ == action: + satisfy = False + break + if satisfy: + ready_actions.append(action) + return ready_actions + + +def remove_dependency(action, relations): + new_relations = list() + for (pre, succ) in relations: + # remove dependency + if pre == action: + continue + new_relations.append((pre, succ)) + return new_relations + + +def sequence_space(actions, relations, seq=list()): + if len(actions) == 0: + yield seq + # inital entry + entry_actions = ready_action_set(actions, relations, flip=len(actions) % 2 == 0) + for action in entry_actions: + seq = seq + [action] + action_idx = actions.index(action) + sub_actions = actions[:action_idx] + actions[action_idx+1:] + sub_relations = remove_dependency(action, relations) + for res in sequence_space(sub_actions, sub_relations, seq): + yield res + seq = seq[:-1] diff --git a/cube/schedule/plan.py b/cube/schedule/plan.py new file mode 100644 index 00000000..f48a7ff1 --- /dev/null +++ b/cube/schedule/plan.py @@ -0,0 +1,112 @@ + +class ExecutionPlan: + + def __init__(self, seq, ndevice): + """ + Seq: action sequence + ndevice: device number + """ + self.seq = seq + self.ndevice = ndevice + self.device_timeline = None + self.device_actions = None + + def gen(self): + """ + Generate execution plan + """ + # timeline: [(start_time, end_time)] + self.device_timeline = [list() for _ in range(self.ndevice)] + self.device_actions = [list() for _ in range(self.ndevice)] + + for action in self.seq: + if action.device == -1 or action.device >= self.ndevice: + raise RuntimeError("action {} device not assigned or out of boundary".format(action)) + if len(self.device_timeline[action.device]) == 0: + start_time = 1 + else: + start_time = self.device_timeline[action.device][-1][1] + for dev_id, (timeline, dev_actions) in enumerate(zip(self.device_timeline, self.device_actions)): + if dev_id == action.device: + continue + # go through to check if the action has dependencies + for (_, end_time), dev_action in zip(timeline[::-1], dev_actions[::-1]): + if action.depends_on(dev_action): + # print('find dependency {} -> {}, end time: {}'.format(action, dev_action, end_time)) + start_time = max(start_time, end_time) + break + elif dev_action.depends_on(action): + raise RuntimeError("Action happened before") + # update timeline + self.device_timeline[action.device].append((start_time, start_time + action.est_latency)) + self.device_actions[action.device].append(action) + + def actions(self, device_id): + """ + Get action sequence for the specific device id + """ + if device_id >= self.ndevice: + raise ValueError(f"device id out of boundary ({device_id} >= {self.ndeivce})") + if self.device_actions is None: + self.gen() + return self.device_actions[device_id] + + def timeline(self, device_id): + """ + Get action timeline for the specific device id + """ + if device_id >= self.ndevice: + raise ValueError(f"device id out of boundary ({device_id} >= {self.ndeivce})") + if self.device_timeline is None: + self.gen() + return self.device_timeline[device_id] + + def get_time(self): + if self.device_timeline is None: + self.gen() + return max( + [timeline[-1][1] for timeline in self.device_timeline] + ) + + def draw(self, outfile='./execplan.png'): + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + plt.rcParams['figure.figsize'] = (12.0, 4.0) + + if self.device_actions is None: + self.gen() + + fig, ax = plt.subplots() + plan_time = self.get_time() + + # xaxis + ax.set_xlim((1, plan_time)) + plt.xticks(list(range(1, plan_time+1, 1))) + ax.xaxis.grid(True, linestyle='--') + plt.xlabel('time') + + # yaxis + ax.set_ylim((0.5, self.ndevice+0.5)) + plt.yticks(list(range(1, self.ndevice+1, 1))) + ax.invert_yaxis() + plt.ylabel('device id') + + ax.set_aspect('equal') + + for devid in range(len(self.device_actions)): + timeline = self.device_timeline[devid] + actions = self.device_actions[devid] + for action, (start, end) in zip(actions, timeline): + # draw + color = 'blue' if (end - start) == 1 else 'orange' + rec = Rectangle((start, devid + 0.5), end-start, 1, + color=color, ec='black', lw=1.5) + ax.add_artist(rec) + rx, ry = rec.get_xy() + cx = rx + rec.get_width() / 2.0 + cy = ry + rec.get_height() / 2.0 + anno = action.name if action.fid is None else action.fid + ax.annotate(anno, (cx, cy), color='w', weight='bold', + fontsize=10, ha='center', va='center') + # plt.grid() + plt.savefig(outfile) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index 216bc2dd..4b3323e6 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -1,6 +1,7 @@ from typing import Sequence from cube.schedule.action import Action, add_flow -from cube.schedule.iterator import legal_sequence +from cube.schedule.iterator import sequence_space +from cube.schedule.plan import ExecutionPlan import argparse import re @@ -10,22 +11,29 @@ def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): + forward_time = 1 + backward_time = 2 + actions = list() relations = list() for mid in range(num_microbatch): # forward for stage in range(num_stage): action = Action(forward_fn) - action.tag('f(S{},D{})'.format(stage, mid)) + action.est_latency = forward_time + action.tag('fS{}D{}'.format(stage, mid)) if stage != 0: relation = (actions[-1], action) add_flow(actions[-1], action) relations.append(relation) + else: + action.fid = mid actions.append(action) # backward for stage in range(num_stage): action = Action(backward_fn) - action.tag('b(S{},D{})'.format(num_stage - 1 - stage, mid)) + action.est_latency = backward_time + action.tag('bS{}D{}'.format(num_stage - 1 - stage, mid)) # relation relation = (actions[-1], action) add_flow(actions[-1], action) @@ -36,7 +44,7 @@ def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): def get_stage_and_mid(action): - ids = re.findall(r"S(\d+),D(\d+)", action.name) + ids = re.findall(r"S(\d+)D(\d+)", action.name) stage, mid = int(ids[0][0]), int(ids[0][1]) return stage, mid @@ -45,72 +53,25 @@ def device_search(sequence): pass -def draw_execution_plan(seq, forward_fn, backward_fn, ndevice): - forward_time = 1 - backward_time = 2 - # record each action end time - current_time = [[1] for _ in range(ndevice)] - device_actions = [list() for _ in range(ndevice)] - - recs = dict() - - for action in seq: - if action.device == -1 or action.device >= ndevice: - raise RuntimeError("action {} device not assigned or out of boundary".format(action)) - start_time = current_time[action.device][-1] - for dev_id, (end_times, dev_actions) in enumerate(zip(current_time, device_actions)): - if dev_id == action.device: - continue - # go through to check if the action has dependencies - for end_time, dev_action in zip(end_times[1:], dev_actions): - if action.depends_on(dev_action): - print('find dependency {} -> {}, end time: {}'.format(action, dev_action, end_time)) - start_time = max(start_time, end_time) - elif dev_action.depends_on(action): - raise RuntimeError("Action happened before") - # draw regtangular - if action._fn[0] == forward_fn: - span_time = forward_time - color = 'blue' - elif action._fn[0] == backward_fn: - span_time = backward_time - color = 'orange' - recs[action] = Rectangle((start_time, action.device), span_time, 1, - color=color, ec='black', lw=1.5) - # update timeline - current_time[action.device].append(start_time + span_time) - device_actions[action.device].append(action) - - fig, ax = plt.subplots() - for action in recs: - stage, mid = get_stage_and_mid(action) - ax.add_artist(recs[action]) - rx, ry = recs[action].get_xy() - cx = rx + recs[action].get_width() / 2.0 - cy = ry + recs[action].get_height() / 2.0 - ax.annotate(f'S{stage}D{mid}', (cx, cy), color='w', weight='bold', - fontsize=8, ha='center', va='center') - - ax.set_xlim((1, int(len(seq) * 1.5) + 1)) - ax.set_ylim((0, ndevice)) - ax.set_aspect('equal') - plt.savefig('./tmp.png') - - def print_all_legal_sequence(actions, relations): - for cnt, seq in enumerate(legal_sequence(actions, relations)): + for cnt, seq in enumerate(sequence_space(actions, relations)): print(seq) print('total found {} legal sequences'.format(cnt + 1)) -def fixed_placement_sequence(actions, relations, ndevices, forward, backward): - for cnt, seq in enumerate(legal_sequence(actions, relations)): +def fixed_placement_sequence(actions, relations, ndevice, forward, backward): + for cnt, seq in enumerate(sequence_space(actions, relations)): # assign device for action in seq: stage, mid = get_stage_and_mid(action) - action.device = stage % ndevices - print(f'found seq > {seq}') - draw_execution_plan(seq, forward, backward, ndevices) + action.device = stage % ndevice + execplan = ExecutionPlan(seq, ndevice) + execplan.gen() + iter_time = execplan.get_time() + print(f'found seq > {seq} \t time {iter_time}') + # if iter_time > 28: + # continue + execplan.draw(outfile='tmp.png') input('>>> ') print('total found {} legal sequences'.format(cnt + 1)) @@ -131,4 +92,5 @@ def fixed_placement_sequence(actions, relations, ndevices, forward, backward): actions, relations = get_semantic(forward, backward, args.nstage, args.nmb) + #print_all_legal_sequence(actions, relations) fixed_placement_sequence(actions, relations, args.ndev, forward, backward) From d11cc3ee515a06cb7003b7cf502e8ac1747e6537 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 06:06:20 +0000 Subject: [PATCH 0153/1892] add gpipe and 1f1b example --- examples/poc/pipeline_space.py | 87 ++++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 5 deletions(-) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index 4b3323e6..2fa0868b 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -2,6 +2,7 @@ from cube.schedule.action import Action, add_flow from cube.schedule.iterator import sequence_space from cube.schedule.plan import ExecutionPlan +from cube.schedule.checker import correct_check import argparse import re @@ -59,7 +60,7 @@ def print_all_legal_sequence(actions, relations): print('total found {} legal sequences'.format(cnt + 1)) -def fixed_placement_sequence(actions, relations, ndevice, forward, backward): +def fixed_placement_sequence(actions, relations, ndevice, max_time): for cnt, seq in enumerate(sequence_space(actions, relations)): # assign device for action in seq: @@ -69,13 +70,85 @@ def fixed_placement_sequence(actions, relations, ndevice, forward, backward): execplan.gen() iter_time = execplan.get_time() print(f'found seq > {seq} \t time {iter_time}') - # if iter_time > 28: - # continue + if iter_time > max_time: + continue execplan.draw(outfile='tmp.png') input('>>> ') print('total found {} legal sequences'.format(cnt + 1)) +def pipe_1f1b(actions, relations, nstage, ndevice, nmb): + num_stage = nstage + num_microbatch = nmb + + f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] + b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] + + # action placement + for stage in range(num_stage): + for mid in range(num_microbatch): + f(stage, mid).device = stage % ndevice + print(f(stage, mid), f'stage={stage}, mid={mid}, device={stage % ndevice}') + b(stage, mid).device = stage % ndevice + print(b(stage, mid), f'stage={stage}, mid={mid}') + + sequence = list() + + # warmup: + for stage in range(num_stage): + for mid in range(num_stage-stage): + sequence.append(f(stage, mid)) + + # steady + cooldown: + for mid in range(num_microbatch): + # enqueue backward + for stage in range(num_stage-1, -1, -1): + sequence.append(b(stage, mid)) + # enqueue forward + for stage in range(num_stage): + f_mid = mid + num_stage - stage + if f_mid >= num_microbatch: + continue + sequence.append(f(stage, f_mid)) + print(sequence) + assert correct_check(sequence, actions, relations) + execplan = ExecutionPlan(sequence, ndevice) + execplan.draw(outfile='./pipeline-1f1b.png') + + +def gpipe(actions, relations, nstage, ndevice, nmb): + num_stage = nstage + num_microbatch = nmb + + f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] + b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] + + # action placement + for stage in range(num_stage): + for mid in range(num_microbatch): + f(stage, mid).device = stage % ndevice + print(f(stage, mid), f'stage={stage}, mid={mid}, device={stage % ndevice}') + b(stage, mid).device = stage % ndevice + print(b(stage, mid), f'stage={stage}, mid={mid}') + + sequence = list() + + # warmup: + for stage in range(num_stage): + for mid in range(num_microbatch): + sequence.append(f(stage, mid)) + + # backward + for stage in range(num_stage): + for mid in range(num_microbatch): + sequence.append(b(num_stage - 1 - stage, mid)) + + print(sequence) + # assert correct_check(sequence, actions, relations) + execplan = ExecutionPlan(sequence, ndevice) + execplan.draw(outfile='./gpipe.png') + + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -85,6 +158,8 @@ def fixed_placement_sequence(actions, relations, ndevice, forward, backward): help='number of micro-batch') parser.add_argument('--ndev', type=int, default=4, help='number of devices') + parser.add_argument('--max-time', type=int, default=100, + help='maximal time. Will filter out plans that have larger time than this') args = parser.parse_args() forward = lambda data: data @@ -92,5 +167,7 @@ def fixed_placement_sequence(actions, relations, ndevice, forward, backward): actions, relations = get_semantic(forward, backward, args.nstage, args.nmb) - #print_all_legal_sequence(actions, relations) - fixed_placement_sequence(actions, relations, args.ndev, forward, backward) + pipe_1f1b(actions, relations, args.nstage, args.ndev, args.nmb) + gpipe(actions, relations, args.nstage, args.ndev, args.nmb) + + fixed_placement_sequence(actions, relations, args.ndev, args.max_time) From d35e0c9b376387850848cc19530700c7587292da Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 07:34:08 +0000 Subject: [PATCH 0154/1892] enable full search space --- cube/schedule/action.py | 1 + cube/schedule/iterator.py | 41 +++++++++++++++++++---- cube/schedule/plan.py | 21 +++++++++++- examples/poc/pipeline_space.py | 59 +++++++++++++++++++++++----------- 4 files changed, 96 insertions(+), 26 deletions(-) diff --git a/cube/schedule/action.py b/cube/schedule/action.py index e983253f..1efed1fc 100644 --- a/cube/schedule/action.py +++ b/cube/schedule/action.py @@ -13,6 +13,7 @@ def __init__(self, fn): self.fid = None # flow id self.device = -1 self.est_latency = 1 + self.est_memory = 1 def __call__(self, *args, **kwargs): """ diff --git a/cube/schedule/iterator.py b/cube/schedule/iterator.py index 1b874550..c5327c51 100644 --- a/cube/schedule/iterator.py +++ b/cube/schedule/iterator.py @@ -2,6 +2,7 @@ from cube.schedule.checker import correct_check import itertools +import numpy as np def legal_sequence(actions, relations): @@ -23,13 +24,12 @@ def legal_sequence(actions, relations): yield seq -def ready_action_set(actions, relations, flip=False): +def ready_action_set(actions, relations): """ Return a list of actions can be executed now """ - flip = -1 if flip else 1 ready_actions = list() - for action in actions[::flip]: + for action in actions: satisfy = True for (_, succ) in relations: if succ == action: @@ -50,16 +50,43 @@ def remove_dependency(action, relations): return new_relations -def sequence_space(actions, relations, seq=list()): +def sequence_space(actions, relations, path_shuffle=True, seq=list()): if len(actions) == 0: yield seq # inital entry - entry_actions = ready_action_set(actions, relations, flip=len(actions) % 2 == 0) - for action in entry_actions: + entry_actions = ready_action_set(actions, relations) + entry_actions = np.array(entry_actions) + if path_shuffle: + np.random.shuffle(entry_actions) + for aid, action in enumerate(entry_actions): + if len(seq) == 0: + print(f'> search progress: [{aid}/{len(entry_actions)}]...') seq = seq + [action] action_idx = actions.index(action) sub_actions = actions[:action_idx] + actions[action_idx+1:] sub_relations = remove_dependency(action, relations) - for res in sequence_space(sub_actions, sub_relations, seq): + for res in sequence_space(sub_actions, sub_relations, path_shuffle, seq): yield res seq = seq[:-1] + + +def placement_space(actions, ndevice, fb_same=True, path_shuffle=True, assigned=0): + if assigned == len(actions): + yield actions + return + + action = actions[assigned] + device_choice = np.array(list(range(ndevice)), dtype=np.int) + if path_shuffle: + np.random.shuffle(device_choice) + + if fb_same: + for assigned_action in actions[:assigned]: + # assume action name likes 'fS0D1' + if action.name[1:] == assigned_action.name[1:]: + device_choice = [assigned_action.device] + break + for device in device_choice: + action.device = device + for res in placement_space(actions, ndevice, fb_same, path_shuffle, assigned+1): + yield res diff --git a/cube/schedule/plan.py b/cube/schedule/plan.py index f48a7ff1..9190aeba 100644 --- a/cube/schedule/plan.py +++ b/cube/schedule/plan.py @@ -65,7 +65,23 @@ def get_time(self): if self.device_timeline is None: self.gen() return max( - [timeline[-1][1] for timeline in self.device_timeline] + [timeline[-1][1] for timeline in self.device_timeline if len(timeline) != 0] + ) + + def get_memory(self): + if self.device_timeline is None: + self.gen() + + def device_memory(actions): + max_mem = 0 + cur_mem = 0 + for action in actions: + cur_mem += action.est_memory + max_mem = max(cur_mem, max_mem) + return max_mem + + return max( + [device_memory(actions) for actions in self.device_actions] ) def draw(self, outfile='./execplan.png'): @@ -110,3 +126,6 @@ def draw(self, outfile='./execplan.png'): fontsize=10, ha='center', va='center') # plt.grid() plt.savefig(outfile) + + def to_json(self): + return [repr(action) for action in self.seq] diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index 2fa0868b..c16c81fb 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -1,14 +1,11 @@ -from typing import Sequence from cube.schedule.action import Action, add_flow -from cube.schedule.iterator import sequence_space +from cube.schedule.iterator import sequence_space, placement_space from cube.schedule.plan import ExecutionPlan from cube.schedule.checker import correct_check import argparse import re - -import matplotlib.pyplot as plt -from matplotlib.patches import Rectangle +import json def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): @@ -22,6 +19,7 @@ def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): for stage in range(num_stage): action = Action(forward_fn) action.est_latency = forward_time + action.est_memory = 1 action.tag('fS{}D{}'.format(stage, mid)) if stage != 0: relation = (actions[-1], action) @@ -34,6 +32,7 @@ def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): for stage in range(num_stage): action = Action(backward_fn) action.est_latency = backward_time + action.est_memory = -1 action.tag('bS{}D{}'.format(num_stage - 1 - stage, mid)) # relation relation = (actions[-1], action) @@ -50,17 +49,42 @@ def get_stage_and_mid(action): return stage, mid -def device_search(sequence): - pass +def full_grid_search(actions, relations, ndevice, nmb): + """ + Search minimal time plan under the memory constraints + """ + memory_buckets = dict() + for activation_num in range(1, nmb+1): + memory_buckets[activation_num] = None -def print_all_legal_sequence(actions, relations): for cnt, seq in enumerate(sequence_space(actions, relations)): - print(seq) - print('total found {} legal sequences'.format(cnt + 1)) - - -def fixed_placement_sequence(actions, relations, ndevice, max_time): + for dev_num, dev_seq in enumerate(placement_space(seq, ndevice, fb_same=True)): + # print(f'on sequence > {dev_seq}') + execplan = ExecutionPlan(dev_seq, ndevice) + execplan.gen() + span = execplan.get_time() + memory = execplan.get_memory() + # update plan + for upper_mem in memory_buckets: + if memory <= upper_mem: + if memory_buckets[upper_mem] is None: + memory_buckets[upper_mem] = execplan + execplan.draw(outfile=f'./figs/plan.mem{memory}.png') + if span < memory_buckets[upper_mem].get_time(): + memory_buckets[upper_mem] = execplan + execplan.draw(outfile=f'./figs/plan.mem{memory}.png') + print(f'> found a better seq {seq} time {span} mem {memory}') + # input(f'>>> done on {dev_num+1} device placement ') + # dump to json + print(f'> totally done search on {cnt+1} sequences') + for key in memory_buckets: + memory_buckets[key] = memory_buckets[key].to_json() + with open('./figs/results.json', 'w') as outfile: + json.dump(memory_buckets, outfile) + + +def fixed_placement_search(actions, relations, ndevice, max_time): for cnt, seq in enumerate(sequence_space(actions, relations)): # assign device for action in seq: @@ -158,8 +182,6 @@ def gpipe(actions, relations, nstage, ndevice, nmb): help='number of micro-batch') parser.add_argument('--ndev', type=int, default=4, help='number of devices') - parser.add_argument('--max-time', type=int, default=100, - help='maximal time. Will filter out plans that have larger time than this') args = parser.parse_args() forward = lambda data: data @@ -167,7 +189,8 @@ def gpipe(actions, relations, nstage, ndevice, nmb): actions, relations = get_semantic(forward, backward, args.nstage, args.nmb) - pipe_1f1b(actions, relations, args.nstage, args.ndev, args.nmb) - gpipe(actions, relations, args.nstage, args.ndev, args.nmb) + # pipe_1f1b(actions, relations, args.nstage, args.ndev, args.nmb) + # gpipe(actions, relations, args.nstage, args.ndev, args.nmb) - fixed_placement_sequence(actions, relations, args.ndev, args.max_time) + # fixed_placement_search(actions, relations, args.ndev, max_time=100) + full_grid_search(actions, relations, args.ndev, args.nmb) \ No newline at end of file From 06080c9662fdedcd73e0f170b876095d43e68c1f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 07:55:06 +0000 Subject: [PATCH 0155/1892] search add throughput --- examples/poc/pipeline_space.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index c16c81fb..c846ee0f 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -6,6 +6,7 @@ import argparse import re import json +import time def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): @@ -58,6 +59,7 @@ def full_grid_search(actions, relations, ndevice, nmb): for activation_num in range(1, nmb+1): memory_buckets[activation_num] = None + tic = time.time() for cnt, seq in enumerate(sequence_space(actions, relations)): for dev_num, dev_seq in enumerate(placement_space(seq, ndevice, fb_same=True)): # print(f'on sequence > {dev_seq}') @@ -76,6 +78,10 @@ def full_grid_search(actions, relations, ndevice, nmb): execplan.draw(outfile=f'./figs/plan.mem{memory}.png') print(f'> found a better seq {seq} time {span} mem {memory}') # input(f'>>> done on {dev_num+1} device placement ') + if (cnt+1) % 1000 == 0: + throughput = 1000 * (nmb ** ndevice) / (time.time() - tic) + tic = time.time() + print('> search [{}-{}] throughput {:.2f} spatial sequences / sec'.format(cnt+1-1000, cnt+1, throughput)) # dump to json print(f'> totally done search on {cnt+1} sequences') for key in memory_buckets: From cc328617a3fa941327a3c51f8b7d67a9e22db7c3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 08:37:21 +0000 Subject: [PATCH 0156/1892] add outpath --- examples/poc/pipeline_space.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index c846ee0f..f2affead 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -7,6 +7,7 @@ import re import json import time +import os def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): @@ -50,7 +51,7 @@ def get_stage_and_mid(action): return stage, mid -def full_grid_search(actions, relations, ndevice, nmb): +def full_grid_search(actions, relations, ndevice, nmb, outpath='./figs'): """ Search minimal time plan under the memory constraints """ @@ -72,10 +73,10 @@ def full_grid_search(actions, relations, ndevice, nmb): if memory <= upper_mem: if memory_buckets[upper_mem] is None: memory_buckets[upper_mem] = execplan - execplan.draw(outfile=f'./figs/plan.mem{memory}.png') + execplan.draw(outfile=os.path.join(outpath, f'{ndevice}nmb{nmb}dev.mem{memory}.png')) if span < memory_buckets[upper_mem].get_time(): memory_buckets[upper_mem] = execplan - execplan.draw(outfile=f'./figs/plan.mem{memory}.png') + execplan.draw(outfile=os.path.join(outpath, f'{ndevice}nmb{nmb}dev.mem{memory}.png')) print(f'> found a better seq {seq} time {span} mem {memory}') # input(f'>>> done on {dev_num+1} device placement ') if (cnt+1) % 1000 == 0: @@ -86,7 +87,7 @@ def full_grid_search(actions, relations, ndevice, nmb): print(f'> totally done search on {cnt+1} sequences') for key in memory_buckets: memory_buckets[key] = memory_buckets[key].to_json() - with open('./figs/results.json', 'w') as outfile: + with open(os.path.join(outpath, 'results.json'), 'w') as outfile: json.dump(memory_buckets, outfile) @@ -188,6 +189,7 @@ def gpipe(actions, relations, nstage, ndevice, nmb): help='number of micro-batch') parser.add_argument('--ndev', type=int, default=4, help='number of devices') + parser.add_argument('--outpath', type=str, default='/mydata/MagicCube/search/pipeline/') args = parser.parse_args() forward = lambda data: data @@ -199,4 +201,4 @@ def gpipe(actions, relations, nstage, ndevice, nmb): # gpipe(actions, relations, args.nstage, args.ndev, args.nmb) # fixed_placement_search(actions, relations, args.ndev, max_time=100) - full_grid_search(actions, relations, args.ndev, args.nmb) \ No newline at end of file + full_grid_search(actions, relations, args.ndev, args.nmb, args.outpath) \ No newline at end of file From cc4ba94f0ddfd3573eed2de1343f773187bdc690 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 11:33:31 +0000 Subject: [PATCH 0157/1892] multiprocess serach --- examples/poc/pipeline_space.py | 85 ++++++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 5 deletions(-) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index f2affead..f7567df1 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -1,5 +1,5 @@ from cube.schedule.action import Action, add_flow -from cube.schedule.iterator import sequence_space, placement_space +from cube.schedule.iterator import sequence_space, sequence_space_batched, placement_space from cube.schedule.plan import ExecutionPlan from cube.schedule.checker import correct_check @@ -8,6 +8,8 @@ import json import time import os +import multiprocessing as mp +from functools import partial def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): @@ -91,6 +93,76 @@ def full_grid_search(actions, relations, ndevice, nmb, outpath='./figs'): json.dump(memory_buckets, outfile) +def worker_search(seqs, nmb, ndevice): + sub_memory_buckets = dict() + for activation_num in range(1, nmb+1): + sub_memory_buckets[activation_num] = None + for seq in seqs: + for dev_seq in placement_space(seq, ndevice, fb_same=True): + execplan = ExecutionPlan(dev_seq, ndevice) + execplan.gen() + span = execplan.get_time() + memory = execplan.get_memory() + # update plan + for upper_mem in sub_memory_buckets: + if memory <= upper_mem: + if sub_memory_buckets[upper_mem] is None: + sub_memory_buckets[upper_mem] = execplan + if span < sub_memory_buckets[upper_mem].get_time(): + sub_memory_buckets[upper_mem] = execplan + return sub_memory_buckets + + +def full_grid_search_mp(actions, relations, ndevice, nmb, outpath='./figs', nworker=40): + """ + Search minimal time plan under the memory constraints + """ + pool = mp.Pool(processes=nworker) + + memory_buckets = dict() + for activation_num in range(1, nmb+1): + memory_buckets[activation_num] = None + + def merge(sub_memory_buckets): + for upper_mem in sub_memory_buckets: + if sub_memory_buckets[upper_mem] is None: + continue + execplan = sub_memory_buckets[upper_mem] + span = execplan.get_time() + memory = execplan.get_memory() + if memory_buckets[upper_mem] is None: + memory_buckets[upper_mem] = execplan + execplan.draw(outfile=os.path.join(outpath, f'{ndevice}nmb{nmb}dev.mem{memory}.png')) + print(f'> found a better seq {execplan.seq} time {span} mem {memory}') + if span < memory_buckets[upper_mem].get_time(): + memory_buckets[upper_mem] = execplan + execplan.draw(outfile=os.path.join(outpath, f'{ndevice}nmb{nmb}dev.mem{memory}.png')) + print(f'> found a better seq {execplan.seq} time {span} mem {memory}') + + bs = (nworker, 20) + nseqs = 0 + for seqs in sequence_space_batched(actions, relations, bs=bs): + results = list() + for wid in range(nworker): + res = pool.apply_async(worker_search, args=(seqs[wid], nmb, ndevice)) + results.append(res) + nseqs += sum([len(worker_seqs) for worker_seqs in seqs]) + print(f'assigned {nseqs} sequences') + for res in results: + sub_buckets = res.get() + merge(sub_buckets) + + pool.close() + pool.join() + + # dump to json + print(f'> totally done search on {nseqs} sequences') + for key in memory_buckets: + memory_buckets[key] = memory_buckets[key].to_json() + with open(os.path.join(outpath, 'results.json'), 'w') as outfile: + json.dump(memory_buckets, outfile) + + def fixed_placement_search(actions, relations, ndevice, max_time): for cnt, seq in enumerate(sequence_space(actions, relations)): # assign device @@ -180,6 +252,12 @@ def gpipe(actions, relations, nstage, ndevice, nmb): execplan.draw(outfile='./gpipe.png') +def forward(data): + pass + +def backward(grad): + pass + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -191,9 +269,6 @@ def gpipe(actions, relations, nstage, ndevice, nmb): help='number of devices') parser.add_argument('--outpath', type=str, default='/mydata/MagicCube/search/pipeline/') args = parser.parse_args() - - forward = lambda data: data - backward = lambda grad: grad actions, relations = get_semantic(forward, backward, args.nstage, args.nmb) @@ -201,4 +276,4 @@ def gpipe(actions, relations, nstage, ndevice, nmb): # gpipe(actions, relations, args.nstage, args.ndev, args.nmb) # fixed_placement_search(actions, relations, args.ndev, max_time=100) - full_grid_search(actions, relations, args.ndev, args.nmb, args.outpath) \ No newline at end of file + full_grid_search_mp(actions, relations, args.ndev, args.nmb, args.outpath) \ No newline at end of file From d81660569c640ad3c9e9c27789221b9c9c0f4d96 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 11:39:50 +0000 Subject: [PATCH 0158/1892] iterator --- cube/schedule/iterator.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/cube/schedule/iterator.py b/cube/schedule/iterator.py index c5327c51..d074b026 100644 --- a/cube/schedule/iterator.py +++ b/cube/schedule/iterator.py @@ -70,6 +70,23 @@ def sequence_space(actions, relations, path_shuffle=True, seq=list()): seq = seq[:-1] +def sequence_space_batched(actions, relations, bs): + """ + bs: tuple (num_workers, seq_per_worker) + """ + seqs = list() + for seq in sequence_space(actions, relations): + seqs.append(seq) + if len(seqs) % (bs[0] * bs[1]) == 0: + seqs = [seqs[wid*bs[1]:(wid+1)*bs[1]] for wid in range(bs[0])] + yield seqs + seqs = list() + # tail + if len(seqs) != 0: + seqs = [seqs[wid*bs[1]:(wid+1)*bs[1]] for wid in range(bs[0])] + yield seqs + + def placement_space(actions, ndevice, fb_same=True, path_shuffle=True, assigned=0): if assigned == len(actions): yield actions From 799ff0a0e4a5036e01df5d08b77a4dbeee44fb2d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 20:22:43 +0800 Subject: [PATCH 0159/1892] multiprocess and fix bugs --- examples/poc/pipeline_space.py | 62 ++++++++++++++-------------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index f7567df1..5db49066 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -53,6 +53,13 @@ def get_stage_and_mid(action): return stage, mid +def fixed_placement(actions, ndevice, **kwargs): + for action in actions: + stage, _ = get_stage_and_mid(action) + action.device = stage % ndevice + yield actions + + def full_grid_search(actions, relations, ndevice, nmb, outpath='./figs'): """ Search minimal time plan under the memory constraints @@ -93,12 +100,12 @@ def full_grid_search(actions, relations, ndevice, nmb, outpath='./figs'): json.dump(memory_buckets, outfile) -def worker_search(seqs, nmb, ndevice): +def worker_search(seqs, nstage, ndevice, space_iter=placement_space): sub_memory_buckets = dict() - for activation_num in range(1, nmb+1): + for activation_num in range(1, nstage+1): sub_memory_buckets[activation_num] = None for seq in seqs: - for dev_seq in placement_space(seq, ndevice, fb_same=True): + for dev_seq in space_iter(seq, ndevice, fb_same=True): execplan = ExecutionPlan(dev_seq, ndevice) execplan.gen() span = execplan.get_time() @@ -113,14 +120,14 @@ def worker_search(seqs, nmb, ndevice): return sub_memory_buckets -def full_grid_search_mp(actions, relations, ndevice, nmb, outpath='./figs', nworker=40): +def space_search_mp(actions, relations, nstage, nmb, ndevice, outpath, space_iter=placement_space, nworker=40): """ Search minimal time plan under the memory constraints """ pool = mp.Pool(processes=nworker) memory_buckets = dict() - for activation_num in range(1, nmb+1): + for activation_num in range(1, nstage+1): memory_buckets[activation_num] = None def merge(sub_memory_buckets): @@ -132,24 +139,24 @@ def merge(sub_memory_buckets): memory = execplan.get_memory() if memory_buckets[upper_mem] is None: memory_buckets[upper_mem] = execplan - execplan.draw(outfile=os.path.join(outpath, f'{ndevice}nmb{nmb}dev.mem{memory}.png')) + execplan.draw(outfile=os.path.join(outpath, f'{nstage}stage.{nmb}nmb.{ndevice}dev.mem{memory}.png')) print(f'> found a better seq {execplan.seq} time {span} mem {memory}') if span < memory_buckets[upper_mem].get_time(): memory_buckets[upper_mem] = execplan - execplan.draw(outfile=os.path.join(outpath, f'{ndevice}nmb{nmb}dev.mem{memory}.png')) + execplan.draw(outfile=os.path.join(outpath, f'{nstage}stage.{nmb}nmb.{ndevice}dev.mem{memory}.png')) print(f'> found a better seq {execplan.seq} time {span} mem {memory}') - bs = (nworker, 20) + bs = (nworker, 100) nseqs = 0 for seqs in sequence_space_batched(actions, relations, bs=bs): - results = list() + handles = list() for wid in range(nworker): - res = pool.apply_async(worker_search, args=(seqs[wid], nmb, ndevice)) - results.append(res) + handle = pool.apply_async(worker_search, args=(seqs[wid], nstage, ndevice, space_iter)) + handles.append(handle) nseqs += sum([len(worker_seqs) for worker_seqs in seqs]) print(f'assigned {nseqs} sequences') - for res in results: - sub_buckets = res.get() + for handle in handles: + sub_buckets = handle.get() merge(sub_buckets) pool.close() @@ -163,24 +170,7 @@ def merge(sub_memory_buckets): json.dump(memory_buckets, outfile) -def fixed_placement_search(actions, relations, ndevice, max_time): - for cnt, seq in enumerate(sequence_space(actions, relations)): - # assign device - for action in seq: - stage, mid = get_stage_and_mid(action) - action.device = stage % ndevice - execplan = ExecutionPlan(seq, ndevice) - execplan.gen() - iter_time = execplan.get_time() - print(f'found seq > {seq} \t time {iter_time}') - if iter_time > max_time: - continue - execplan.draw(outfile='tmp.png') - input('>>> ') - print('total found {} legal sequences'.format(cnt + 1)) - - -def pipe_1f1b(actions, relations, nstage, ndevice, nmb): +def pipe_1f1b(actions, relations, nstage, nmb, ndevice): num_stage = nstage num_microbatch = nmb @@ -219,7 +209,7 @@ def pipe_1f1b(actions, relations, nstage, ndevice, nmb): execplan.draw(outfile='./pipeline-1f1b.png') -def gpipe(actions, relations, nstage, ndevice, nmb): +def gpipe(actions, relations, nstage, nmb, ndevice): num_stage = nstage num_microbatch = nmb @@ -272,8 +262,6 @@ def backward(grad): actions, relations = get_semantic(forward, backward, args.nstage, args.nmb) - # pipe_1f1b(actions, relations, args.nstage, args.ndev, args.nmb) - # gpipe(actions, relations, args.nstage, args.ndev, args.nmb) - - # fixed_placement_search(actions, relations, args.ndev, max_time=100) - full_grid_search_mp(actions, relations, args.ndev, args.nmb, args.outpath) \ No newline at end of file + # pipe_1f1b(actions, relations, args.nstage, args.nmb, args.ndev) + # gpipe(actions, relations, args.nstage, args.nmb, args.ndev) + space_search_mp(actions, relations, args.nstage, args.nmb, args.ndev, args.outpath, space_iter=fixed_placement) From d0a09512481966287782d40e747788fb05ae28fe Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Sep 2021 20:43:02 +0800 Subject: [PATCH 0160/1892] dump json bug fix --- examples/poc/pipeline_space.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index 5db49066..620ccb40 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -165,7 +165,8 @@ def merge(sub_memory_buckets): # dump to json print(f'> totally done search on {nseqs} sequences') for key in memory_buckets: - memory_buckets[key] = memory_buckets[key].to_json() + if memory_buckets[key] is not None: + memory_buckets[key] = memory_buckets[key].to_json() with open(os.path.join(outpath, 'results.json'), 'w') as outfile: json.dump(memory_buckets, outfile) From 6629227e78cf7304c13c54753a9177eb787b2baf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 3 Sep 2021 16:21:02 +0800 Subject: [PATCH 0161/1892] update primitive examples --- examples/case_study/parallel_primitive.py | 96 ---------- examples/case_study/spatial_primitive.py | 165 ++++++++++++++++++ ...ule_primitive.py => temporal_primitive.py} | 109 ++++-------- 3 files changed, 200 insertions(+), 170 deletions(-) delete mode 100644 examples/case_study/parallel_primitive.py create mode 100644 examples/case_study/spatial_primitive.py rename examples/case_study/{schedule_primitive.py => temporal_primitive.py} (79%) diff --git a/examples/case_study/parallel_primitive.py b/examples/case_study/parallel_primitive.py deleted file mode 100644 index 36f7a3e7..00000000 --- a/examples/case_study/parallel_primitive.py +++ /dev/null @@ -1,96 +0,0 @@ -import torch -import os - -from functools import partial - -torch.manual_seed(121) - -# select from logical tensor with indices -> generate a logical tensor -def select(tensor, indices, val_op, shape): pass - -# deploy logical tensor to devices -def deploy(tensor, ranks): pass - -# merge logical tensors at `ranks` devices -def merge(tensor, ranks, merge_op): pass - - -class LogicalTensor: pass -class PhyiscalTensor: pass - - -def linear_tensor_parallel(inputs, weight, bias, output): - """ - inputs: (M, K) - weight: (N, K) - bias: (N,) - output: (M, N) - - Perform: (M, K) * (\delta N, K) + (\delta N,) = (M, \delta N) - """ - - M = 1024 - K = 1024 - N = 1024 - - # Tensor Split # -- System + policy generated - inputs = select( - tensor = inputs, - indices = (slice(0, M), slice(0, K)), - val_op = None, - shape = (M, K) - ) - - weights, biases, outputs = list(), list(), list() - for cid in range(4): - weights.append(select( - tensor = weight, - indices = (slice(cid * (N // 4), (cid + 1) * (N // 4)), slice(0, K)), - val_op = None, - shape = (N // 4, K) - )) - - biases.append(select( - tensor = bias, - indices = (slice(cid * (N // 4), (cid + 1) * (N // 4)),), - val_op = None, - shape = (N // 4,) - )) - - outputs.append(select( - tensor = output, - indices = (slice(slice(0, M), cid * (N // 4), (cid + 1) * (N // 4)),), - val_op = None, - shape = (M, N // 4) - )) - # Tensor Split # - - # Tensor Deployment # -- System + policy generated - inputs = deploy( - segment = inputs, - ranks = [0, 1, 2, 3] - ) - - for rank, (weight, bias) in enumerate(zip(weights, biases)): - weight = deploy( - segment = weight, - ranks = [rank], - ) - bias = deploy( - segment = bias, - ranks = [rank], - ) - # Tensor Deployment # - - # Compute # -- Expert specified - for weight, bias, output in enumerate(zip(weights, biases, outputs)): - # physical tensor - chunk = torch._C._nn.linear(inputs, weight, bias) - output.fill(chunk) - - # Generate logical tensor -- System generated - merge( - tensor = output, - ranks = [0, 1, 2, 3], - merge_op = partial(all_gather, dim=1) - ) diff --git a/examples/case_study/spatial_primitive.py b/examples/case_study/spatial_primitive.py new file mode 100644 index 00000000..7439e77c --- /dev/null +++ b/examples/case_study/spatial_primitive.py @@ -0,0 +1,165 @@ +import torch +import os + +from functools import partial + +torch.manual_seed(121) + +class LogicalOp: pass +class PhyiscalOp: pass + +class LogicalTensor: pass +class PhyiscalTensor: pass + +# select from logical tensor with indices -> generate a logical tensor +def select(tensor: LogicalTensor, indices, val_map_op, shape) -> LogicalTensor: pass + +# deploy logical tensor to devices +def deploy(tensor: LogicalTensor, ranks) -> list(PhyiscalTensor): pass + +# merge logical tensors at `ranks` devices +def merge(tensor: LogicalTensor, ranks, val_reduce_op): pass + +# tensor movement: move physical tensor to rank +def move(tensor: PhyiscalTensor, rank): pass + +# tensor release: release the data in physical tensor inside tensor +def release(tensor: PhyiscalTensor): pass + +# tensor re-genrewate: bring back the data for the physical tensor +def generate(tensor: PhyiscalTensor, rank): pass + + + +## =============== tensor parallelism on matmul ============== ## + +def all_gather(tensors, dim): pass + + +def linear_tensor_parallel(inputs, weight, bias, output): + """ + inputs: (M, K) + weight: (N, K) + bias: (N,) + output: (M, N) + + Perform: (M, K) * (\delta N, K) + (\delta N,) = (M, \delta N) + """ + + M = 1024 + K = 1024 + N = 1024 + + # Tensor split -- system + policy generated + inputs = select( + tensor = inputs, + indices = (slice(0, M), slice(0, K)), + val_map_op = None, + shape = (M, K) + ) + + weights, biases, outputs = list(), list(), list() + for cid in range(4): + weights.append(select( + tensor = weight, + indices = (slice(cid * (N // 4), (cid + 1) * (N // 4)), slice(0, K)), + val_map_op = None, + shape = (N // 4, K) + )) + + biases.append(select( + tensor = bias, + indices = (slice(cid * (N // 4), (cid + 1) * (N // 4)),), + val_map_op = None, + shape = (N // 4,) + )) + + outputs.append(select( + tensor = output, + indices = (slice(slice(0, M), cid * (N // 4), (cid + 1) * (N // 4)),), + val_map_op = None, + shape = (M, N // 4) + )) + + # Algorithm -- Expert specified + for weight, bias, output in enumerate(zip(weights, biases, outputs)): + # physical tensor + chunk = torch._C._nn.linear(inputs, weight, bias) + # physical tensor fill in to logical tensor + output.fill(chunk) + + # Tensor deployment -- system + policy generated + inputs = deploy( + segment = inputs, + ranks = [0, 1, 2, 3] + ) + + for rank, (weight, bias) in enumerate(zip(weights, biases)): + weight = deploy( + segment = weight, + ranks = [rank], + ) + bias = deploy( + segment = bias, + ranks = [rank], + ) + + # Logical tensor merge -- system + policy generated + merge( + tensor = outputs, + ranks = [0, 1, 2, 3], + merge_op = partial(all_gather, dim=1) + ) + + +## =============== tensor movement / re-generation ============== ## + +def custom_op(forward_fn, backward_fn): pass + +def offload(inputs: PhyiscalTensor, weights: list(PhyiscalTensor), ops: list(PhyiscalOp)): + """ + offload a feature_map after forward the 3rd op + retrieve (prefetch) the feature_map after backward the 5th op + """ + feature_maps = [inputs] + offload_step = 2 + retrieve_step = 4 + for step, (weight, op) in enumerate(zip(weights, ops)): + tensor = feature_maps[-1] + # retrieve + if step == retrieve_step: + feature_maps[-1] = custom_op( + forward_fn=partial((lambda input: input), input=feature_maps[-1]), + backward_fn=partial(move, feature_maps[offload_step + 1], rank=0) + ) + # op calculation + out = op(tensor, weight) + # offload + if step == offload_step: + move(tensor, rank=-1) + feature_maps.append(out) + + +def checkpoint(inputs: PhyiscalTensor, weights: list(PhyiscalTensor), ops: list(PhyiscalOp)): + """ + checkpoint a feature_map after forward the 3rd op + re-generate (possible for packing with other operator) after backward the 5th op + """ + feature_maps = [inputs] + release_step = 2 + recompute_step = 4 + released_tensor = None + for step, (weight, op) in enumerate(zip(weights, ops)): + tensor = feature_maps[-1] + # retrieve + if step == recompute_step: + feature_maps[-1] = custom_op( + forward_fn=partial((lambda input: input), input=feature_maps[-1]), + backward_fn=partial(generate, feature_maps[release_step + 1], rank=0) + ) + # op calculation + out = op(tensor, weight) + # offload + if step == release_step: + release(tensor) + feature_maps.append(out) diff --git a/examples/case_study/schedule_primitive.py b/examples/case_study/temporal_primitive.py similarity index 79% rename from examples/case_study/schedule_primitive.py rename to examples/case_study/temporal_primitive.py index 79da99ee..1245c890 100644 --- a/examples/case_study/schedule_primitive.py +++ b/examples/case_study/temporal_primitive.py @@ -1,13 +1,15 @@ -from typing import Sequence import torch from functools import partial -## Primitive ## def select(tensor, indices, val_map_op=None, shape=None): pass +## Abstractions and Primitivse ## + +class Action: pass + def execute(action, **kwargs): # action instance will automatically take flow-in results # and select the chunked kwargs @@ -17,6 +19,9 @@ def add_flow(*actions): # this will set all input actions with same flow-id pass + +## System Runtime units ## + def run(schedule, num_microbs, *args): """ Take a list of actions and execute in list order @@ -41,48 +46,29 @@ def run(schedule, num_microbs, *args): outs = execute(action, *tuple(args)) return outs -class Action: pass - def check_consistency(sequence, actions, relations): pass -# ===================== Basic steps ================== # -def general_action(flow_in, *args, **kwargs): - """ - flow_in: the output from previous actions - """ - pass - -def forward(flow_in, model, data): - loss = model(data) - return loss - -def backward(flow_in): - flow_in.backwrd() - return flow_in - -# ===================== Basic steps ================== # - - -def naive_schedule(f, b): - f = Action(f) - b = Action(b) - add_flow(f, b) - schedules = [f, b] - return partial(run, schedules, num_microbs=1) +# Schedule example +def naive_schedule(actions: list(Action), relations: set((Action, Action))) -> list(Action): + """ + Args: + actions: order specified by AI scientist (the reference semantic) + relations: set of action dependencies (action1, action2): action1 -> action2 -def grad_accumulate(f, b, accum_times=4): - forwards = [Action(f, fid=fid) for fid in range(accum_times)] - backwards = [Action(b, fid=fid) for fid in range(accum_times)] - schedules = list() - for f, b in zip(forwards, backwards): - add_flow(f, b) - schedules += [f, b] - return partial(run, schedules, num_microbs=accum_times) + Returns: + a execution sequence following the abstraction + """ + # placement + for action in actions: + action.device = 0 + # execution sequence + sequence = actions + return sequence -def pipeline_1f1b_schedules(actions, relations): +def pipeline_1f1b_schedule(actions, relations): """ Pipeline 1f1b policy description -- generate a sequence @@ -97,7 +83,7 @@ def pipeline_1f1b_schedules(actions, relations): num_microbatch = len(actions) / 2 / num_stage f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] - b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + stage] + b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] # action placement for stage in range(num_stage): @@ -139,7 +125,7 @@ def pipeline_1f1b_schedule(actions, relations): num_microbatch = len(actions) / 2 / num_stage f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] - b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + stage] + b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] # action placement for stage in range(num_stage): @@ -179,53 +165,19 @@ def pipeline_1f1b_schedule(actions, relations): -def dist_policy(DAG, resources): - """ - Policy decided the parallelisms and op-placement - """ - return DAG - - -def set_schedule_policy(model, specific_schedule, bs): - """ - forward_fn: forward function - backward_fn: backward_function - bs: global batch size - """ - num_microbs = 4 if bs >= 4 else bs - schedule = pipeline_schedule(model.forward, backward, num_microbs) - model.set_schedule(schedule) - - if __name__ == '__main__': # define logical model / optimizer / data loader class LogicalModel: pass class Optimizer: pass class DataLoader: pass + compute_loss = lambda output, label : output model = LogicalModel() optimizer = Optimizer(model.parameters()) dataloader = DataLoader(bs=1024) - # def forward_step(flow_in, data, label, **kwargs): - # # this requires loss computation needs to be in the model - # # output = model(data, label) - # output = model(data, label) - # # function wrapper - # loss = compute_loss(output) - # return output - # - # def backward_step(output, **kwargs): - # output.backward() - # return output - - # policy for placement and parallelisms -- will be hidden - model = dist_policy(get_dag(model, loss_compute, input_shapes), resources) - # data flow scheduling policy -- will be hidden - set_schedule_policy(model, pipeline_schedule, bs=1024) - for epoch in range(100): for step, (data, label) in enumerate(dataloader): # enqueue forward specfied by schedule and execute the first one @@ -246,6 +198,15 @@ class DataLoader: pass # evaluation + + +# ======== example sequences for all kinds of configuration ============= + +forward = lambda model, data: model(data) +backward = lambda grad, output: output.backward(grad) +update_gradient = lambda model, grad: model.update(grad) + + def train_iter_grad_accumulate(model, datas, stage=2, micro_bs=4): out_s0_d0 = forward(model[0], datas[0]) From baa3d9a0ed953555c49c29924230c74c45e5f42b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 3 Sep 2021 16:54:01 +0800 Subject: [PATCH 0162/1892] a new iterator --- cube/schedule/iterator.py | 15 +++++++++++++-- examples/poc/pipeline_space.py | 6 +++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cube/schedule/iterator.py b/cube/schedule/iterator.py index d074b026..317ed3d2 100644 --- a/cube/schedule/iterator.py +++ b/cube/schedule/iterator.py @@ -70,12 +70,23 @@ def sequence_space(actions, relations, path_shuffle=True, seq=list()): seq = seq[:-1] -def sequence_space_batched(actions, relations, bs): +def sequence_space_bfs(actions, relations, path_shuffle=True): + # reverse relation + reverse_relation = list() + for relation in relations: + reverse_relation.append((relation[1], relation[0])) + # reverse seq + for seq in sequence_space(actions, reverse_relation, path_shuffle): + yield seq[::-1] + + +def sequence_space_batched(actions, relations, bs, bfs=False): """ bs: tuple (num_workers, seq_per_worker) """ seqs = list() - for seq in sequence_space(actions, relations): + space_iter = sequence_space_bfs if bfs else sequence_space + for seq in space_iter(actions, relations): seqs.append(seq) if len(seqs) % (bs[0] * bs[1]) == 0: seqs = [seqs[wid*bs[1]:(wid+1)*bs[1]] for wid in range(bs[0])] diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index 620ccb40..f52794dd 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -102,7 +102,7 @@ def full_grid_search(actions, relations, ndevice, nmb, outpath='./figs'): def worker_search(seqs, nstage, ndevice, space_iter=placement_space): sub_memory_buckets = dict() - for activation_num in range(1, nstage+1): + for activation_num in range(1, 2*nstage+1): sub_memory_buckets[activation_num] = None for seq in seqs: for dev_seq in space_iter(seq, ndevice, fb_same=True): @@ -127,7 +127,7 @@ def space_search_mp(actions, relations, nstage, nmb, ndevice, outpath, space_ite pool = mp.Pool(processes=nworker) memory_buckets = dict() - for activation_num in range(1, nstage+1): + for activation_num in range(1, 2*nstage+1): memory_buckets[activation_num] = None def merge(sub_memory_buckets): @@ -146,7 +146,7 @@ def merge(sub_memory_buckets): execplan.draw(outfile=os.path.join(outpath, f'{nstage}stage.{nmb}nmb.{ndevice}dev.mem{memory}.png')) print(f'> found a better seq {execplan.seq} time {span} mem {memory}') - bs = (nworker, 100) + bs = (nworker, 256) nseqs = 0 for seqs in sequence_space_batched(actions, relations, bs=bs): handles = list() From 418ddf617ddbe47f3a7ba7b07b852c241e5b790d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 3 Sep 2021 19:41:51 +0800 Subject: [PATCH 0163/1892] space size calculation --- cube/schedule/iterator.py | 29 +++++++++++++++++ examples/poc/pipeline_space.py | 5 ++- examples/poc/space_size.py | 57 ++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 examples/poc/space_size.py diff --git a/cube/schedule/iterator.py b/cube/schedule/iterator.py index 317ed3d2..a05a77a0 100644 --- a/cube/schedule/iterator.py +++ b/cube/schedule/iterator.py @@ -5,6 +5,35 @@ import numpy as np +def _comb(n, m): + """ + Calcualte combination C(n,m): select n from m (n < m) + """ + res = 1 + for j in range(0, min(n, m)): + res *= (m-j) / (min(n, m) - j) + return int(res) + + +def get_pipeline_seq_space_size(nstage, nmb): + """ + Calculate legal sequence number given num stage and num microbatch + + \prod \limits_{i=1}^{nmb} C(nstage, i*nstage) + + Args: + nstage: number of stages + nmb: number of micro batch + + Return: + total legal line + """ + res = 1 + for i in range(1, nmb+1): + res *= _comb(nstage*2, i*nstage*2) + return res + + def legal_sequence(actions, relations): """ Yield all possible legal sequence given the list of actions diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py index f52794dd..02f7c00a 100644 --- a/examples/poc/pipeline_space.py +++ b/examples/poc/pipeline_space.py @@ -258,6 +258,8 @@ def backward(grad): help='number of micro-batch') parser.add_argument('--ndev', type=int, default=4, help='number of devices') + parser.add_argument('--full-placement', action='store_true', + help='device assignment for each action will be fully explored') parser.add_argument('--outpath', type=str, default='/mydata/MagicCube/search/pipeline/') args = parser.parse_args() @@ -265,4 +267,5 @@ def backward(grad): # pipe_1f1b(actions, relations, args.nstage, args.nmb, args.ndev) # gpipe(actions, relations, args.nstage, args.nmb, args.ndev) - space_search_mp(actions, relations, args.nstage, args.nmb, args.ndev, args.outpath, space_iter=fixed_placement) + space_iter = placement_space if args.full_placement else fixed_placement + space_search_mp(actions, relations, args.nstage, args.nmb, args.ndev, args.outpath, space_iter=space_iter) diff --git a/examples/poc/space_size.py b/examples/poc/space_size.py new file mode 100644 index 00000000..fb1b4995 --- /dev/null +++ b/examples/poc/space_size.py @@ -0,0 +1,57 @@ +from cube.schedule.iterator import get_pipeline_seq_space_size + +import argparse + + +def get_seq_space_size(nstage, nmb): + """ + Calculate legal sequence number given num stage and num microbatch + + \prod \limits_{i=1}^{nmb} C(nstage, i*nstage) + + Args: + nstage: number of stages + nmb: number of micro batch + + Return: + total legal line + """ + return get_pipeline_seq_space_size(nstage, nmb) + + +def get_device_space_size(nstage, nmb, ndevice): + """ + Calculate legal spatial sequence number given num stage and num microbatch + + \prod \limits_{i=1}^{nmb} C(nstage, i*nstage) + + Args: + nstage: number of stages + nmb: number of micro batch + ndevice: number of device + + Return: + total legal line + """ + num_actions = nmb * nstage * 2 + device_space_size = ndevice ** num_actions + return device_space_size + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--nstage', type=int, default=4, + help='number of stages') + parser.add_argument('--nmb', type=int, default=4, + help='number of micro-batch') + parser.add_argument('--ndev', type=int, default=4, + help='number of devices') + args = parser.parse_args() + + seq_space = get_seq_space_size(args.nstage, args.nmb) + print('legal sequence space: {}'.format(seq_space)) + dev_space = get_device_space_size(args.nstage, args.nmb, args.ndev) + print('spatial space for one sequence: {}'.format(dev_space)) + total_space = seq_space * dev_space + print('total space: {}'.format(total_space)) From d2db893169f99b07954c2df89db010ddefcc0c9e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 10 Sep 2021 13:14:25 +0800 Subject: [PATCH 0164/1892] update policy with spatial scheduling --- examples/case_study/policy/logical_code.py | 68 +++++++++++---------- examples/case_study/policy/policy.py | 70 +++++++++++----------- 2 files changed, 71 insertions(+), 67 deletions(-) diff --git a/examples/case_study/policy/logical_code.py b/examples/case_study/policy/logical_code.py index c997ff19..fb03a747 100644 --- a/examples/case_study/policy/logical_code.py +++ b/examples/case_study/policy/logical_code.py @@ -4,6 +4,10 @@ import argparse +def sschedule(partial_dag, resources): pass +def tschedule(train_fn): pass +resources = None # available hardware resources + class FeedForward(nn.Module): def __init__(self, dim, dropout=0., mult=16, classes=1000): @@ -17,18 +21,25 @@ def __init__(self, dim, dropout=0., mult=16, classes=1000): self.classifier = nn.Linear(dim, classes) - def forward(self, x): - with annotate(data_parallel): - output = self.net(x) - output = self.classifier(output) - return output + def forward(self, data, label): + output = self.net(data) + output = self.classifier(output) + loss = F.cross_entropy(output, label) + return loss -def data_iter(bs, dim, classes, length=64): +def data_iter(gbs, dim, classes, length=1024, mbs=None): + mbs = mbs if mbs is not None else gbs + num_mb = gbs // mbs for _ in range(length): - data = torch.randn((bs, dim)) - label = torch.randint(0, classes, (bs,)) - yield data, label + gbs_data = list() + gbs_label = list() + for _ in range(num_mb): + mbs_data = torch.randn((mbs, dim)) + mbs_label = torch.randint(0, classes, (mbs,)) + gbs_data.append(mbs_data) + gbs_label.append(mbs_label) + yield gbs_data, gbs_label if __name__ == '__main__': @@ -36,18 +47,26 @@ def data_iter(bs, dim, classes, length=64): parser = argparse.ArgumentParser() parser.add_argument('--dim', type=int, default=1024) parser.add_argument('--heads', type=int, default=16) - parser.add_argument('--bs', type=int, default=8) + parser.add_argument('--gbs', type=int, default=64) + parser.add_argument('--mbs', type=int, default=4) parser.add_argument('--classes', type=int, default=10) args = parser.parse_args() model = FeedForward(args.dim, mult=args.heads, classes=args.classes) # model = model.cuda() - ### ======= get DAG and modify by policy ======= ### - dag = get_dag(model, data) - new_dag = policy(dag, resources)[myrank] - model = new_dag - ### ======= get DAG and modify by policy ======= ### + # spatial schedule + model = sschedule(model, resources) + # temporal schedule + @tschedule + def train_iter(data, label): + # forward + loss = model(data, label) + # backward + loss.backward() + # update + optimizer.step() + optimizer.zero_grad() optimizer = torch.optim.Adam( model.parameters(), @@ -56,20 +75,5 @@ def data_iter(bs, dim, classes, length=64): weight_decay=0 ) - - for (data, label) in data_iter(args.bs, args.dim, args.classes): - data, label = data.cuda(), label.cuda() - # forward - output = model(data) - loss = F.cross_entropy(output, label) - # backward - loss.backward() - # weight update - optimizer.step() - optimizer.zero_grad() - - -## dynamics? - -## weight update + forward concurrent - + for (data, label) in data_iter(args.gbs, args.dim, args.classes, mbs=args.mbs): + train_iter(data, label) diff --git a/examples/case_study/policy/policy.py b/examples/case_study/policy/policy.py index 6e4d1581..249b71cc 100644 --- a/examples/case_study/policy/policy.py +++ b/examples/case_study/policy/policy.py @@ -1,47 +1,47 @@ -""" -DAG interface: - add_op +def select(tensor, indices, val_map_op=None, shape=None): pass - delete_op +def input_adapter(inputs, target): pass - update_op +def iter_op(DAG): pass +def generate_for_each_rank(pDAG): pass - find_op / iter_op -""" -def policy(DAG, resources): +def sschedule_dp(pDAG, resources, input_tensors): """ + Data Parallel Description + Args: - * DAG: semantic (logical) computation graph + * pDAG: partial semantic (logical) computation graph * Resources: Environment inlcuding devices, network topology etc Returns: - * DAGs (list[DAG]) execution (local & physical) DAG for each rank + * pDAGs (list[DAG]) execution (local & physical) DAG for each rank """ - for inputs, op, outputs in iter_op(DAG): - if is_annotated(op): - # distributed op adapter - dist_op = select_dist_op(op, inputs, resources) - replace_op(DAG, op, dist_op) - # input tensor segmentation adapter - inputs = input_adapter(DAG, dist_op, inputs) - # output tensor segmentation adapter - outputs = output_adapter(DAG, dist_op, outputs) - # tensor placement / lifecycle adapter - if is_annotated(inputs): - placement_lifecycle_adapter(DAG, inputs) - # tensor move / destroy - if is_annotated(outputs): - placement_lifecycle_adapter(DAG, outputs) - # scheduling - # TODO: do we need to include scheduling in the DAG? + ndevs = resources.ndevs + for data in input_tensors: + shape = data.shape + for sid in range(ndevs): + chunk_shape = () + for dim, size in enumerate(shape): + if dim == 0: + chunk_size = shape[0] // ndevs + chunk_shape.append(slice(sid * chunk_size, (sid+1) * chunk_size)) + else: + chunk_shape.append(slice(0, size)) + data.add_segment(select(data, chunk_shape, None)) + pDAG.op[0].set_partition(input_tensors) + for op in iter_op(pDAG): + # inputs: op input tensors + # outputs: op output tensors + for inputs, dist_op, outputs in op.dist_candidates(): + # find the data parallelism + if dist_op.satisfy(inputs): + # set placement + dist_op.op_placement = list(range(ndevs)) + # replace logical op to data parallelism + input_adapter(dist_op.inputs, target=inputs) + # output will be in data parallel format + pDAG.replace(op, dist_op) # materialize to physical op - DAGs = generate_for_each_rank(DAG, resources) + DAGs = generate_for_each_rank(pDAG, resources) return DAGs - - -def select_dist_op(op, inputs, resources): - op_candidates = get_distributed_ops(type(op)) - for candidate in op_candidates: - if candidate.same_segmentation(inputs): - return candidate From 61458d03c1eb00b60418117768949665f0029ca5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 10 Sep 2021 13:17:40 +0800 Subject: [PATCH 0165/1892] add temporal schedule --- examples/case_study/policy/policy.py | 64 +++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/examples/case_study/policy/policy.py b/examples/case_study/policy/policy.py index 249b71cc..4d602c16 100644 --- a/examples/case_study/policy/policy.py +++ b/examples/case_study/policy/policy.py @@ -1,4 +1,5 @@ - +import torch + def select(tensor, indices, val_map_op=None, shape=None): pass def input_adapter(inputs, target): pass @@ -20,22 +21,23 @@ def sschedule_dp(pDAG, resources, input_tensors): ndevs = resources.ndevs for data in input_tensors: shape = data.shape - for sid in range(ndevs): + # set num micro-batch to 4 + for sid in range(ndevs * 4): chunk_shape = () for dim, size in enumerate(shape): if dim == 0: - chunk_size = shape[0] // ndevs + chunk_size = shape[0] // ndevs // 4 chunk_shape.append(slice(sid * chunk_size, (sid+1) * chunk_size)) else: chunk_shape.append(slice(0, size)) data.add_segment(select(data, chunk_shape, None)) pDAG.op[0].set_partition(input_tensors) - for op in iter_op(pDAG): + for inputs, op, outputs in iter_op(pDAG): # inputs: op input tensors # outputs: op output tensors - for inputs, dist_op, outputs in op.dist_candidates(): + for dist_op in op.dist_candidates(): # find the data parallelism - if dist_op.satisfy(inputs): + if dist_op.satisfy_and_set(inputs): # set placement dist_op.op_placement = list(range(ndevs)) # replace logical op to data parallelism @@ -45,3 +47,53 @@ def sschedule_dp(pDAG, resources, input_tensors): # materialize to physical op DAGs = generate_for_each_rank(pDAG, resources) return DAGs + + +def tschedule_1f1b(actions, relations, resources): + """ + Pipeline 1f1b policy description -- each device order + + Actions: a list of actions + + relations: list[(Action1, Action2)]: a list of tuples indicate partial order + """ + num_stage = resources.n_gpus + num_microbatch = len(actions) / 2 / num_stage + + f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] + b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] + + # action placement + for stage in range(num_stage): + for mid in range(num_microbatch): + f(stage, mid).device = torch.device.cuda(stage) + b(stage, mid).device = torch.device.cuda(stage) + + # action in-device order + stage_order = list() + + for stage in range(num_stage): + order = list() + num_warmup_microbatch = num_stage - stage - 1 + num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch) + num_microbatch_remain = num_microbatch - num_warmup_microbatch + + # warmup + for mid in range(num_warmup_microbatch): + order.append(f(stage, mid)) + + # steady + for i in range(num_microbatch_remain): + f_mid = num_warmup_microbatch + i + b_mid = i + order.append(f(stage, f_mid)) + order.append(b(stage, b_mid)) + + # cooldown + for i in range(num_warmup_microbatch): + b_mid = num_microbatch_remain + i + order.append(b(stage, b_mid)) + + stage_order.append(order) + + return stage_order From b36070c5ae06dd71d0d6ce807cfd122d50a7e74b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 13 Sep 2021 11:12:22 +0800 Subject: [PATCH 0166/1892] add constraints description --- examples/case_study/spatial_primitive.py | 43 ++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/examples/case_study/spatial_primitive.py b/examples/case_study/spatial_primitive.py index 7439e77c..e4829469 100644 --- a/examples/case_study/spatial_primitive.py +++ b/examples/case_study/spatial_primitive.py @@ -36,6 +36,10 @@ def generate(tensor: PhyiscalTensor, rank): pass def all_gather(tensors, dim): pass +# the logical op linear: +def linear(inputs, weight, bias) -> LogicalTensor: pass + + def linear_tensor_parallel(inputs, weight, bias, output): """ inputs: (M, K) @@ -112,6 +116,45 @@ def linear_tensor_parallel(inputs, weight, bias, output): ) +def linear_tensor_parallel_space(inputs, weight, bias, output): + """ + inputs: (M, K) + weight: (N, K) + bias: (N,) + output: (M, N) + + Perform: (M, K) * (\delta N, K) + (\delta N,) = (M, \delta N) + """ + + # no split + def Full(): pass + # split at axis + def SplitAxis(axis, chunk_num, overlap): pass + + # add constraints for inter-tensors + def add_constraint(condition): pass + + # ========= segmentation constraints ===========# + inputs.segment = Full() + weight.segment = SplitAxis( + axis=0, chunk_num=None, overlap=0 + ) + bias.segment = SplitAxis( + axis=0, chunk_num=None, overlap=0 + ) + add_constraint(bias.segment.chunk_num == weight.segment.chunk_num) + + output.segment = SplitAxis( + axis=1, chunk_num=None, overlap=0 + ) + add_constraint(output.segment.chunk_num == weight.layout.chunk_num) + + # ========= distributed algorithms ============# + for pweight, pbias, pout in zip(weight, bias, output): + pout.fill(linear(inputs, pweight, pbias)) + return output + + ## =============== tensor movement / re-generation ============== ## def custom_op(forward_fn, backward_fn): pass From 518df96b0e2efc4a25f313c37051178784fae470 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 13 Sep 2021 14:08:20 +0800 Subject: [PATCH 0167/1892] update dp policy --- examples/case_study/policy/policy.py | 57 ++++++++++++---------------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/examples/case_study/policy/policy.py b/examples/case_study/policy/policy.py index 4d602c16..bb52ab6a 100644 --- a/examples/case_study/policy/policy.py +++ b/examples/case_study/policy/policy.py @@ -13,40 +13,37 @@ def sschedule_dp(pDAG, resources, input_tensors): Data Parallel Description Args: - * pDAG: partial semantic (logical) computation graph + * pDAG: (partial) logical computation graph * Resources: Environment inlcuding devices, network topology etc Returns: * pDAGs (list[DAG]) execution (local & physical) DAG for each rank """ + # rank [0,1,..., pp_size-1], [pp_size, ..., 2*pp_size - 1], ... ndevs = resources.ndevs - for data in input_tensors: - shape = data.shape - # set num micro-batch to 4 - for sid in range(ndevs * 4): - chunk_shape = () - for dim, size in enumerate(shape): - if dim == 0: - chunk_size = shape[0] // ndevs // 4 - chunk_shape.append(slice(sid * chunk_size, (sid+1) * chunk_size)) - else: - chunk_shape.append(slice(0, size)) - data.add_segment(select(data, chunk_shape, None)) - pDAG.op[0].set_partition(input_tensors) - for inputs, op, outputs in iter_op(pDAG): - # inputs: op input tensors - # outputs: op output tensors - for dist_op in op.dist_candidates(): + # suppose 8 devices, 4 for pipeline, 2 for data parallel + dp_size = 2 + pp_size = 4 + for op in iter_op(pDAG): + for op_id, dist_op in enumerate(op.dist_candidates()): # find the data parallelism - if dist_op.satisfy_and_set(inputs): - # set placement - dist_op.op_placement = list(range(ndevs)) - # replace logical op to data parallelism - input_adapter(dist_op.inputs, target=inputs) - # output will be in data parallel format + if is_data_parallelism(dist_op): + for tensor in dist_op.inputs + dist_op.outputs: + if isinstance(tensor.segment, SplitAxis): + # pipeline micro-batch = 4 + tensor.segment.chunk_num = dp_size * 4 + # translate to logical tensor segments + tensor.segment.translate() + dist_op.generate_ops() + # setup placement + stage = op_id // (len(pDAG) // pp_size) + for dp_id, sub_op in enumerate(dist_op.ops): + sub_op.device = (dp_id % dp_size) * pp_size + stage + # materialize -- call to the deploy + dist_op.materialize() + # generate input adapter pDAG.replace(op, dist_op) - # materialize to physical op - DAGs = generate_for_each_rank(pDAG, resources) - return DAGs + break + return pDAG def tschedule_1f1b(actions, relations, resources): @@ -63,12 +60,6 @@ def tschedule_1f1b(actions, relations, resources): f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] - # action placement - for stage in range(num_stage): - for mid in range(num_microbatch): - f(stage, mid).device = torch.device.cuda(stage) - b(stage, mid).device = torch.device.cuda(stage) - # action in-device order stage_order = list() From 49c5c0f85ca67389d6d0c988d160f7c3bd5358ca Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 13 Sep 2021 14:09:12 +0800 Subject: [PATCH 0168/1892] torch.fx poc --- examples/poc/torchfx.py | 167 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 examples/poc/torchfx.py diff --git a/examples/poc/torchfx.py b/examples/poc/torchfx.py new file mode 100644 index 00000000..b53d232f --- /dev/null +++ b/examples/poc/torchfx.py @@ -0,0 +1,167 @@ +""" +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=62000 \ + --use_env \ + examples/poc/torchfx.py +""" + +import torch +from torch import nn +import torch.nn.functional as F +from torch.fx import symbolic_trace + +import os + + +local_rank = int(os.environ.get('LOCAL_RANK')) +torch.cuda.set_device(local_rank) +torch.distributed.init_process_group( + backend='nccl', + init_method='env://', +) + +# ====================== Check for normal module ========================== +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=16, classes=1000): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult) + self.gelu = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim * mult, dim) + self.classifier = nn.Linear(dim, classes) + + def forward(self, x): + output = self.linear1(x) + output = self.gelu(output) + output = self.dropout(output) + output = self.linear2(output) + output = self.classifier(output) + return output + +model = FeedForward(dim=1024).cuda() +graph_module = symbolic_trace(model) +if local_rank == 0: + print(graph_module) + print(graph_module.code) + print(graph_module.graph) + + +# ====================== Check for autograd function ========================== +class CustomOp(torch.autograd.Function): + @staticmethod + def symbolic(graph, input, weight): + return torch.matmul(input, weight) + @staticmethod + def forward(ctx, input, weight): + ctx.save_for_backward(input, weight) + return torch.matmul(input, weight) + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + return input+weight, input+weight + +class CustomModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, weight): + out = CustomOp.apply(input, weight) + return out + +custom_op = CustomModule().cuda() + +input = torch.ones((1024, 1024)).cuda().requires_grad_() +weight = torch.ones((1024, 1024)).cuda().requires_grad_() + +if local_rank == 0: + custom_op_trace = symbolic_trace(custom_op) + print(custom_op_trace) + print(custom_op_trace.code) + print(custom_op_trace.graph) + # traced graph call + out = custom_op_trace(input, weight) + torch.sum(out).backward() + print(out) + print('weight grad: ', weight.grad) + # original graph call + + out = custom_op(input, weight) + input.grad = None + weight.grad = None + torch.sum(out).backward() + print('weight grad expected: ', weight.grad) + print(out) + +torch.distributed.barrier() + + +# ====================== Check for function with communications ========================== +class InputAdapter(torch.autograd.Function): + @staticmethod + def symbolic(graph, input_): + return input_ + @staticmethod + def forward(ctx, input_): + return input_ + @staticmethod + def backward(ctx, grad_output): + return torch.distributed.all_reduce(grad_output) + + +class OutputAdapter(torch.autograd.Function): + @staticmethod + def symbolic(graph, input_): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_) + output = torch.cat(tensor_list, dim=-1) + return output + @staticmethod + def forward(ctx, input_): + # world_size = torch.distributed.get_world_size() + # rank = torch.distributed.get_rank() + # tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + # tensor_list[rank] = input_ + # torch.distributed.all_gather(tensor_list, input_) + # output = torch.cat(tensor_list, dim=-1) + output = input_ + torch.distributed.all_reduce(output) + return output + @staticmethod + def backward(ctx, grad_output): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + tensor_list = torch.split( + grad_output, grad_output.size()[-1]//world_size, dim=-1 + ) + return tensor_list[rank].contiguous() + + +class LinearComm(nn.Module): + def __init__(self, input_feats, output_feats): + super().__init__() + self.linear = nn.Linear(input_feats, output_feats) + def forward(self, x): + x = InputAdapter.apply(x) + x = self.linear(x) + x = OutputAdapter.apply(x) + return x + +comm_linear = LinearComm(1024, 1024).cuda() +graph_comm = symbolic_trace(comm_linear) +if local_rank == 0: + print(graph_comm.graph) + print(graph_comm.code) + +input = torch.ones((1024, 1024)).cuda().requires_grad_() +out = graph_comm(input) +out_ref = comm_linear(input) +if local_rank == 0: + print('out: ', out) + print('out expected: ', out_ref) From 159366aac2d5ab586f2e241a810472e74f97064e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Sep 2021 17:25:26 +0800 Subject: [PATCH 0169/1892] feedforward network example --- examples/ffn.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/examples/ffn.py b/examples/ffn.py index 17a66a80..960adb28 100644 --- a/examples/ffn.py +++ b/examples/ffn.py @@ -6,21 +6,19 @@ class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): + def __init__(self, dim, dropout=0., mult=4): super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, dim * mult), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(dim * mult, dim) - ) - - self.classifier = nn.Linear(dim, classes) + self.linear1 = nn.Linear(dim, dim * mult) + self.gelu = nn.GELU() + self.linear2 = nn.Linear(dim * mult, dim) + self.dropout = nn.Dropout(dropout) def forward(self, x): - output = self.net(x) - output = self.classifier(output) - return output + x = self.linear1(x) + x = self.gelu(x) + x = self.linear2(x) + x = self.dropout(x) + return x def data_iter(bs, dim, classes, length=64): @@ -34,13 +32,12 @@ def data_iter(bs, dim, classes, length=64): parser = argparse.ArgumentParser() parser.add_argument('--dim', type=int, default=1024) - parser.add_argument('--heads', type=int, default=16) parser.add_argument('--bs', type=int, default=8) parser.add_argument('--classes', type=int, default=10) args = parser.parse_args() - model = FeedForward(args.dim, mult=args.heads, classes=args.classes) - model = model.cuda() + model = torch.jit.script(FeedForward(args.dim).cuda()) + print(model.code) optimizer = torch.optim.Adam( model.parameters(), From 0133216532daf7cad6f092b37417e40b1900e6b0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Sep 2021 20:28:18 +0800 Subject: [PATCH 0170/1892] init code translation --- cube/codegen/__init__.py | 0 cube/codegen/codegen.py | 11 +++++++ cube/codegen/syntax/__init__.py | 0 cube/codegen/syntax/blocks.py | 52 +++++++++++++++++++++++++++++++ cube/codegen/syntax/symtable.py | 50 ++++++++++++++++++++++++++++++ cube/graph/__init__.py | 0 cube/graph/graph.py | 38 +++++++++++++++++++++++ cube/graph/parser.py | 55 +++++++++++++++++++++++++++++++++ 8 files changed, 206 insertions(+) create mode 100644 cube/codegen/__init__.py create mode 100644 cube/codegen/codegen.py create mode 100644 cube/codegen/syntax/__init__.py create mode 100644 cube/codegen/syntax/blocks.py create mode 100644 cube/codegen/syntax/symtable.py create mode 100644 cube/graph/__init__.py create mode 100644 cube/graph/graph.py create mode 100644 cube/graph/parser.py diff --git a/cube/codegen/__init__.py b/cube/codegen/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py new file mode 100644 index 00000000..723bcd29 --- /dev/null +++ b/cube/codegen/codegen.py @@ -0,0 +1,11 @@ +""" +Generate Pytorch code given the model DAG and the transformation config +""" + + +class ModelCodeGen: + + def __init__(self, action_irs): + pass + + diff --git a/cube/codegen/syntax/__init__.py b/cube/codegen/syntax/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/codegen/syntax/blocks.py b/cube/codegen/syntax/blocks.py new file mode 100644 index 00000000..33e44485 --- /dev/null +++ b/cube/codegen/syntax/blocks.py @@ -0,0 +1,52 @@ + + +class Block: + + def __init__(self, title): + if not isinstance(title, str): + raise TypeError(f"Expected string, but got {type(title)}") + self.code = [title] + + def __enter__(self): + return self + + def insert_body(self, code): + if isinstance(code, list): + self.code += code + elif type(code) == str: + self.code.append(code) + else: + raise TypeError + + def __exit__(self, exc_type, exc_value, exc_tb): + # add indent for function block + for idx in range(1, len(self.code)): + self.code[idx] = '\t' + self.code[idx] + if not exc_tb is None: + print('Error detected in function block') + + +class FunctionBlock(Block): + + def __init__(self, func_name, args): + self.func_name = func_name + self.param_name = args + args = ', '.join(args) + title = f'def {self.func_name}({args}):' + super().__init__(title) + + +class ClassBlock(Block): + + def __init__(self, class_name, derived=None): + if not isinstance(class_name, str): + raise TypeError("Expected class_name to be str") + if not isinstance(derived, list) and derived is not None: + raise TypeError("Expcted derived to be None or list[str]") + self.class_name = class_name + if derived: + derived = ', '.join(derived) + derived = f'({derived})' + title = f'class {self.class_name}{derived}' + super().__init__(self, title) + diff --git a/cube/codegen/syntax/symtable.py b/cube/codegen/syntax/symtable.py new file mode 100644 index 00000000..abf8235f --- /dev/null +++ b/cube/codegen/syntax/symtable.py @@ -0,0 +1,50 @@ + + +class SymbolTable: + """ + Symbolic table for saving declared variables. + + Assume the program will first declare all possible used + variables before entering any of its sub (children) scopes. + + Attributes: + name (str): name of this scope + _varlist (dict{str: DType}): declared variable dict + var_name -> type_of_var + """ + + def __init__(self): + self._varlist = list() + + def create(self, var_name): + """ + Create a variable with a type. + + If var_name is already declared: + if the declared type matches with var_name, return False. + else raise Error + + If var name is not declared, decalre the var and return True. + + Args: + var_name (str): variable name + var_type (Dtype): variable type + + Returns: + True if declared, False if the var already exists. + """ + assert isinstance(var_name, str) + if var_name in self._varlist: + return False + else: + self._varlist.append(var_name) + + def exist(self, var_name): + """ + Check whether a variable exists + """ + assert isinstance(var_name, str) + if var_name in self._varlist: + return True + else: + return False diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/graph/graph.py b/cube/graph/graph.py new file mode 100644 index 00000000..5be57e55 --- /dev/null +++ b/cube/graph/graph.py @@ -0,0 +1,38 @@ +""" +Convert PyTorch nn.Module to our IRGraph +""" +import torch + +class Node: + + def __init__(self, name, type_name): + """ + Create a node with name (variable name) and module type (module_name) + + Args: + name (str): the var name of the module + type_name: the type name of the module + + Example: + init code: + self.linear1 = torch.nn.Linear(input_feats, output_feats) + forward code: + output = self.linear1(input) + => + name = linear1; type_name = torch.nn.Linear + """ + pass + + +class IRGraph: + + def __init__(self, module, example_inputs=None): + + self.module = module + self.script_module = torch.jit.script(module) + # model info + self.module_name = None + + + def _convert(self): + self.module_name = self.script_module.original_name diff --git a/cube/graph/parser.py b/cube/graph/parser.py new file mode 100644 index 00000000..83af4e77 --- /dev/null +++ b/cube/graph/parser.py @@ -0,0 +1,55 @@ +import torch +import enum + +class ScriptNodeKind(enum.Enum): + PrimGetAttr = 1 + PrimCallMethod = 2 + + +class ScriptModuleParser: + + @staticmethod + def get_node_type(node: torch._C.Node): + if node.kind() == 'prim::GetAttr': + return ScriptNodeKind.PrimGetAttr + if node.kind() == 'prim::CallMethod': + return ScriptNodeKind.PrimCallMethod + + @staticmethod + def parse_node(node: torch._C.Node, module: torch.jit.RecursiveScriptModule): + """ + Parse the node and get the inputs, module type, outputs + + Returns: + Inputs: list[int] + Modulelin + """ + ntype = ScriptNodeKind.get_node_type(node) + if ntype == ScriptNodeKind.PrimGetAttr: + return ScriptModuleParser.parse_prim_attr_node(node, module) + if ntype == ScriptNodeKind.PrimCallMethod: + return ScriptModuleParser.parse_prim_call_node(node, module) + + @staticmethod + def parse_prim_attr_node(node, module): + """ + Parse script module node like: + %2 :__torch__.torch.nn.modules.linear.___torch_mangle_0.Linear = prim::GetAttr[name="linear1"](%self) + """ + name = node.s('name') + module_mangle = node.outputsAt(0).type().str() + module_name_demangle = list() + for mname in module_mangle.split('.'): + if mname == '__torch__' or '_mangle' in mname: + continue + module_name_demangle.append(mname) + module_name_demangle = '.'.joint(module_name_demangle) + # TODO + + + def parse_prim_call_node(node, module): + """ + Parse script module node like: + %output.1 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # /tmp/ipykernel_27188/97711738.py:11:17 + """ + #TODO \ No newline at end of file From 3020a2239bbf73ffeedcbb295d5cdb36b5ede47d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 15 Sep 2021 11:26:15 +0800 Subject: [PATCH 0171/1892] init graph --- cube/graph/graph.py | 181 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 167 insertions(+), 14 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5be57e55..77b633c4 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -1,17 +1,28 @@ """ Convert PyTorch nn.Module to our IRGraph """ -import torch +from typing import List, Optional -class Node: - def __init__(self, name, type_name): +__all__ = ['IROperation', 'IRTensor', 'IRGraph'] + + +class IROperation: + """ + IROperation serves as IRGraph node + """ + + def __init__(self, node_id: int, + op_name: str, + input_length: int, output_length: int, + label: str,): """ Create a node with name (variable name) and module type (module_name) Args: - name (str): the var name of the module - type_name: the type name of the module + node_id (int): the int + label (str): the var name of the module + type: the type name of the module Example: init code: @@ -19,20 +30,162 @@ def __init__(self, name, type_name): forward code: output = self.linear1(input) => - name = linear1; type_name = torch.nn.Linear + label = linear1; + op = torch._C._nn.linear + """ + # node info + self._id: int = node_id + self.label: str = label + + # node type + self.op: str = op_name + + # edge (dataflow info) + self._inputs: List[IRTensor] = [None] * input_length + self._predecessors: List[IROperation] = [None] * input_length + # todo for outputs + self._outputs: List[IRTensor] = [None] * output_length + self._successors: List[IROperation] = [None] * output_length + + def inputs(self, index: Optional(None, int) = None): + """ + Get input tensor at input index + + Args: + index (int or None): + index of the inputs, None will return the nodes + for all the inputs """ - pass + if isinstance(index, int): + if index >= len(self._inputs): + raise RuntimeError( + f"Get the input out of range ({index} >= {len(self._inputs)}" + ) + return self._inputs[index] + elif index is None: + return self._inputs + else: + raise TypeError("Expected index to be None or int") + + def predecessors(self, index: Optional(None, int) = None): + """ + Get input operator at input index + """ + if isinstance(index, int): + if index >= len(self._inputs): + raise RuntimeError( + f"Get the input out of range ({index} >= {len(self._inputs)}" + ) + return self._predecessors[index] + elif index is None: + return self._predecessors + else: + raise TypeError("Expected index to be None or int") + + def outputs(self, index: Optional(None, int) = None): + """ + Get output tensor at output index + + Args: + index (int or None): + index of the outputs, None will return the nodes + for all the outputs + """ + if isinstance(index, int): + if index >= len(self._outputs): + raise RuntimeError( + f"Get the output out of range ({index} >= {len(self._outputs)}" + ) + return self._outputs[index] + elif index is None: + return self._outputs + else: + raise TypeError("Expected index to be None or int") + + def successors(self, index: Optional(None, int) = None): + """ + Get output operator at output index + + Args: + index (int or None): + index of the outputs, None will return the nodes + for all the outputs + """ + if isinstance(index, int): + if index >= len(self._outputs): + raise RuntimeError( + f"Get the output out of range ({index} >= {len(self._outputs)}" + ) + return self._successors[index] + elif index is None: + return self._successors + else: + raise TypeError("Expected index to be None or int") + + def set_predecessor(self, input_index: int, node: IROperation, out_index: int): + """ + Set self node the input node. self.input[input_index] = node.output[out_index] + """ + if not isinstance(node, IROperation): + raise TypeError("Expected node to be IROperation") + if input_index >= len(self.inputs()): + raise RuntimeError( + f"Set the input out of range ({input_index} >= {len(self._inputs)})" + ) + self._inputs[input_index] = node.outputs(out_index) + self._predecessors[input_index] = node + node.set_successor(out_index, self) + + def set_successor(self, out_index: int, node: IROperation): + """ + Set self node the output index node. + `node` will take the self.outputs(index) as the input + """ + if out_index >= len(self._outputs): + raise RuntimeError( + f"Set output index out of range ({out_index} >= {len(self._outputs)}" + ) + self._successors[out_index] = node + + +class IRTensor: + """ + IRTensor serves as IRGraph edge + """ + def __init__(self, edge_id: int, shape: List[int], label: str): + self._id = edge_id + self.shape = shape + self.label = label class IRGraph: + """ + PyTorch IR Graph + + The IR Graph only contains forward graph + """ - def __init__(self, module, example_inputs=None): + def __init__(self, module_name: str): + self.module_name = module_name + self._nodes: List[IROperation] = list() - self.module = module - self.script_module = torch.jit.script(module) - # model info - self.module_name = None + def add_node(self, node: IROperation): + if not isinstance(node, IROperation): + raise TypeError("Expected node to be IROperation") + self._nodes.append(node) + def nodes(self, index: Optional(None, int)): + """ + Get node at position index + """ + if index >= len(self._nodes): + raise RuntimeError( + f"Get node out of range ({index} >= {len(self._nodes)})" + ) + return self._nodes[index] - def _convert(self): - self.module_name = self.script_module.original_name + def replace(self, target: IROperation, nodes: List[IROperation]): + """ + Replace the node with new nodes (IRGraph) + """ + raise NotImplementedError From e77656182fc8f10a5ac2ae2863e3c87f2501bd93 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 15 Sep 2021 21:00:49 +0800 Subject: [PATCH 0172/1892] working in progress --- cube/graph/graph.py | 36 +++++----- cube/graph/parser.py | 159 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 168 insertions(+), 27 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 77b633c4..023f4b2a 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -12,17 +12,17 @@ class IROperation: IROperation serves as IRGraph node """ - def __init__(self, node_id: int, - op_name: str, - input_length: int, output_length: int, - label: str,): + def __init__(self, + name: str, + signature: str, + input_length: int, + output_length: int): """ Create a node with name (variable name) and module type (module_name) Args: - node_id (int): the int - label (str): the var name of the module - type: the type name of the module + name (str): the op semantic name + signature (str): the op signature, e.g., torch.functional.nn.linear Example: init code: @@ -34,11 +34,11 @@ def __init__(self, node_id: int, op = torch._C._nn.linear """ # node info - self._id: int = node_id - self.label: str = label + self._id: int = NotImplementedError + self.name: str = name - # node type - self.op: str = op_name + # op signature + self.signature: str = signature # edge (dataflow info) self._inputs: List[IRTensor] = [None] * input_length @@ -47,7 +47,7 @@ def __init__(self, node_id: int, self._outputs: List[IRTensor] = [None] * output_length self._successors: List[IROperation] = [None] * output_length - def inputs(self, index: Optional(None, int) = None): + def inputs(self, index: Optional[int] = None): """ Get input tensor at input index @@ -67,7 +67,7 @@ def inputs(self, index: Optional(None, int) = None): else: raise TypeError("Expected index to be None or int") - def predecessors(self, index: Optional(None, int) = None): + def predecessors(self, index: Optional[int] = None): """ Get input operator at input index """ @@ -82,7 +82,7 @@ def predecessors(self, index: Optional(None, int) = None): else: raise TypeError("Expected index to be None or int") - def outputs(self, index: Optional(None, int) = None): + def outputs(self, index: Optional[int] = None): """ Get output tensor at output index @@ -102,7 +102,7 @@ def outputs(self, index: Optional(None, int) = None): else: raise TypeError("Expected index to be None or int") - def successors(self, index: Optional(None, int) = None): + def successors(self, index: Optional[int] = None): """ Get output operator at output index @@ -122,7 +122,7 @@ def successors(self, index: Optional(None, int) = None): else: raise TypeError("Expected index to be None or int") - def set_predecessor(self, input_index: int, node: IROperation, out_index: int): + def set_predecessor(self, input_index: int, node, out_index: int): """ Set self node the input node. self.input[input_index] = node.output[out_index] """ @@ -136,7 +136,7 @@ def set_predecessor(self, input_index: int, node: IROperation, out_index: int): self._predecessors[input_index] = node node.set_successor(out_index, self) - def set_successor(self, out_index: int, node: IROperation): + def set_successor(self, out_index: int, node): """ Set self node the output index node. `node` will take the self.outputs(index) as the input @@ -174,7 +174,7 @@ def add_node(self, node: IROperation): raise TypeError("Expected node to be IROperation") self._nodes.append(node) - def nodes(self, index: Optional(None, int)): + def nodes(self, index: Optional[int]): """ Get node at position index """ diff --git a/cube/graph/parser.py b/cube/graph/parser.py index 83af4e77..0edd6ffa 100644 --- a/cube/graph/parser.py +++ b/cube/graph/parser.py @@ -1,19 +1,87 @@ import torch import enum +from cube.graph.graph import IROperation + +from typing import Dict, List class ScriptNodeKind(enum.Enum): PrimGetAttr = 1 PrimCallMethod = 2 + PrimCallFunction = 3 # -> the parser may end here + PrimConstant = 4 + AtenOp = 5 # -> the parser may end here + PrimIf = 6 # dynamic class ScriptModuleParser: @staticmethod - def get_node_type(node: torch._C.Node): + def parse_module(module: torch.jit.RecursiveModule) -> List[IROperation]: + """ + The overall entry to parse a torchscript graph module + """ + all_ir_nodes: List[IROperation] = list() + for node in module.graph.nodes(): + ir_nodes = None + node_type = ScriptModuleParser.ntype(node) + if node_type == ScriptNodeKind.PrimCallFunction: + ir_nodes = ScriptModuleParser.parse_prim_function_node(node, module) + if node_type == ScriptNodeKind.AtenOp: + ir_nodes = ScriptModuleParser.parse_aten_node(node, module) + if node_type == ScriptNodeKind.PrimCallMethod: + ir_nodes = ScriptModuleParser.parse_prim_method_node(node, module) + if ir_nodes is not None: + all_ir_nodes.append(ir_nodes) + return all_ir_nodes + + @staticmethod + def ntype(node: torch._C.Node): if node.kind() == 'prim::GetAttr': return ScriptNodeKind.PrimGetAttr if node.kind() == 'prim::CallMethod': return ScriptNodeKind.PrimCallMethod + if node.kind() == 'prim::CallFunction': # the op call + return ScriptNodeKind.PrimCallFunction + if node.kind() == 'prim::Constant': + return ScriptNodeKind.PrimConstant + if node.kind().startswith('aten::'): + return ScriptNodeKind.AtenOp + if node.kind() == 'prim::If': + return ScriptNodeKind.PrimIf + raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") + + @staticmethod + def parse_prim_function_node(node, module) -> IROperation: + """ + parse node like: + Tensor = prim::CallFunction(%5, %input.1, %3, %4) + %5 : Function = prim::Constant[name="linear"]() + """ + fnode = node.inputsAt(0).node() # function node + if not ScriptModuleParser.ntype(fnode) == ScriptNodeKind.PrimConstant: + raise RuntimeError(f"Found unexpected function call node: {fnode}") + var_name, fsig, _ = ScriptModuleParser.parse_prim_constant_node(fnode, module) + input_length = len(list(fnode.inputs())) - 1 # -1 for the first signature + output_length = len(list(fnode.outputs())) + ir_node = IROperation( + signature = fsig, + name = fnode.s('name'), + input_length=input_length, + output_length=output_length, + ) + return ir_node + + @staticmethod + def parse_aten_node(node, module) -> IROperation: + """ + Parse script module node like: + %13 : Tensor = aten::gt(%output1.1, %output2.1) + """ + input_nodes = list() + for input in node.inputs(): + input_node = input.node() + input_nodes.append(input_node) + @staticmethod def parse_node(node: torch._C.Node, module: torch.jit.RecursiveScriptModule): @@ -28,28 +96,101 @@ def parse_node(node: torch._C.Node, module: torch.jit.RecursiveScriptModule): if ntype == ScriptNodeKind.PrimGetAttr: return ScriptModuleParser.parse_prim_attr_node(node, module) if ntype == ScriptNodeKind.PrimCallMethod: - return ScriptModuleParser.parse_prim_call_node(node, module) + return ScriptModuleParser.parse_prim_method_node(node, module) + if ntype == ScriptNodeKind.PrimCallFunction: + return ScriptModuleParser.parse_prim_function_node(node, module) + if ntype == ScriptNodeKind.AtenOp: + return ScriptModuleParser.parse_aten_node(node, module) + if ntype == ScriptNodeKind.PrimIf: + return ScriptModuleParser.parse_prim_if_node(node, module) + raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") + + @staticmethod + def parse_prim_method_node(node, module) -> List[IROperation]: + """ + Parse script module node like: + %output.1 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) + + Find the module and + """ + # forward + label = node.s('name') + if label != 'forward': + raise RuntimeError(f"{node} is calling function {label} that is not `forward`") + call_module = getattr(module, label) + ir_nodes = ScriptModuleParser.parse_module(call_module) + # TODO: rename nodes + return ir_nodes @staticmethod - def parse_prim_attr_node(node, module): + def parse_prim_attr_node(node, module) -> Dict: """ Parse script module node like: %2 :__torch__.torch.nn.modules.linear.___torch_mangle_0.Linear = prim::GetAttr[name="linear1"](%self) + %3 : Tensor = prim::GetAttr[name="weight"](%self) + The __torch__.torch.nn.modules.* will be ignored """ - name = node.s('name') + if node.inputsAt(0).debugeName() != 'self': + raise RuntimeError(f"Fail to parse {node} due to missing %self") + # user defined var name: linear1 + label = node.s('name') + + # output names: [2] + output_vars : List[str] = list() + for output in node.outputs(): + output_vars.append(output.debugName()) + + # module type name module_mangle = node.outputsAt(0).type().str() + if 'torch.nn.modules' in module_mangle: + return None module_name_demangle = list() for mname in module_mangle.split('.'): if mname == '__torch__' or '_mangle' in mname: continue module_name_demangle.append(mname) module_name_demangle = '.'.joint(module_name_demangle) - # TODO + return dict( + input_ids=list(), output_names=output_names, + module=module_name_demangle, + label=label + ) + + @staticmethod + def parse_prim_constant_node(node): + output = node.outputsAt(0) + if output.type().str() == 'Function': + func_name = repr(output.type()) + + @staticmethod + def parse_prim_if_node(node, module): + """ + Parse script module node like + %output2 : Tensor = prim::If(%15) # /tmp/ipykernel_27188/2459450745.py:13:8 + block0(): + -> (%output1.1) + block1(): + -> (%output2.1) + """ + raise NotImplementedError("Dynamic Graph is not supported yet") - def parse_prim_call_node(node, module): + @staticmethod + def flatten(smodule, depth=0): """ - Parse script module node like: - %output.1 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # /tmp/ipykernel_27188/97711738.py:11:17 + Flatten the recursive script module to function and aten primitives """ - #TODO \ No newline at end of file + # stashed_module = list() + if len(list(smodule.children())) == 0: + for node in smodule.graph.nodes(): + print(' '*depth, node) + else: + for node in smodule.graph.nodes(): + ntype = ScriptModuleParser.get_node_type(node) + print(' '*depth, node) + if ntype == ScriptNodeKind.PrimCallMethod: + label = node.inputsAt(0).node().s('name') + submodule = getattr(smodule, label) + ScriptModuleParser.flatten(submodule, depth+1) + + From f078573f48710a254d44e6d08bdcfd63bf2a7e97 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Sep 2021 11:09:27 +0800 Subject: [PATCH 0173/1892] graph strcuture --- cube/graph/graph.py | 82 ++++++++++++++++++++++++++++++++++++-------- cube/graph/unique.py | 34 ++++++++++++++++++ 2 files changed, 102 insertions(+), 14 deletions(-) create mode 100644 cube/graph/unique.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 023f4b2a..e340ddec 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -1,7 +1,8 @@ """ Convert PyTorch nn.Module to our IRGraph """ -from typing import List, Optional +from cube.graph.unique import IDGenerator +from typing import List, Optional, Any __all__ = ['IROperation', 'IRTensor', 'IRGraph'] @@ -34,7 +35,7 @@ def __init__(self, op = torch._C._nn.linear """ # node info - self._id: int = NotImplementedError + self._id: int = IDGenerator().gen_op_id() self.name: str = name # op signature @@ -44,8 +45,8 @@ def __init__(self, self._inputs: List[IRTensor] = [None] * input_length self._predecessors: List[IROperation] = [None] * input_length # todo for outputs - self._outputs: List[IRTensor] = [None] * output_length - self._successors: List[IROperation] = [None] * output_length + self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] + self._successors: List[List(IROperation)] = [list() for _ in range(output_length)] def inputs(self, index: Optional[int] = None): """ @@ -122,6 +123,24 @@ def successors(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") + def set_input(self, input_index: int, val: Any): + """ + Set the node inputs[input_index] with the tensor + + val: IRTensor or any deterministic value (int, bool, str, etc) + """ + if input_index >= len(self.inputs()): + raise RuntimeError( + f"Set the input out of range ({input_index} >= {len(self._inputs)})" + ) + # set tensor + self._inputs[input_index] = val + if isinstance(val, IRTensor): + # set predecessor + self._predecessors[input_index] = val.src() + # set the source node successor + val.src()._add_successor(val, self) + def set_predecessor(self, input_index: int, node, out_index: int): """ Set self node the input node. self.input[input_index] = node.output[out_index] @@ -136,26 +155,61 @@ def set_predecessor(self, input_index: int, node, out_index: int): self._predecessors[input_index] = node node.set_successor(out_index, self) - def set_successor(self, out_index: int, node): + def _add_successor(self, tensor, node): """ Set self node the output index node. `node` will take the self.outputs(index) as the input """ - if out_index >= len(self._outputs): - raise RuntimeError( - f"Set output index out of range ({out_index} >= {len(self._outputs)}" - ) - self._successors[out_index] = node + out_index = self._outputs.index(tensor) + if out_index < 0: + raise RuntimeError("Fail to find output tensor") + self._successors[out_index].append(node) class IRTensor: """ IRTensor serves as IRGraph edge """ - def __init__(self, edge_id: int, shape: List[int], label: str): - self._id = edge_id - self.shape = shape - self.label = label + def __init__(self, shape=None, name=None): + + self._id: int = IDGenerator().gen_tensor_id() + self._shape: Optional(List[int]) = shape + self.name = name + self.device = -1 + + # connected to IROperation + self._src_nodes: IROperation = None # -> output of the node + self._dst_nodes: List[IROperation] = list() # -> input of the nodes + + @property + def shape(self): + return self._shape + + @shape.setter + def shape(self, val): + if self._shape is not None: + raise RuntimeError("Try to change shape") + if not all([isinstance(size, int) for size in val]): + raise RuntimeError("Expected shape to be list[int]") + self._shape = val + + def src(self) -> Optional[IROperation]: + return self._src_nodes + + def dst(self, index: Optional[int] = None): + if index >= len(self._dst_nodes): + raise RuntimeError("get tensor dst out of range") + return self._dst_nodes[index] + + def set_src_nodes(self, node: IROperation): + if not isinstance(node, IROperation): + raise TypeError("IRTensor source node should be IROperation") + self._src_nodes = node + + def add_dst_nodes(self, node: IROperation): + if not isinstance(node, IROperation): + raise TypeError("IRTensor destination node should be IROperation") + self._dst_nodes.append(IROperation) class IRGraph: diff --git a/cube/graph/unique.py b/cube/graph/unique.py new file mode 100644 index 00000000..ee34a95a --- /dev/null +++ b/cube/graph/unique.py @@ -0,0 +1,34 @@ + +class IDGenerator: + """ + Tensor / Operator manager. To guarantee that each IRTensor / IROperator id + is unique and progressively increases. + + This class is designed in singleton pattern. + """ + class __IDGenerator: + def __init__(self): + + self._tensor_id = 0 + self._op_id = 0 + + instance = None + + def __init__(self): + if not IDGenerator.instance: + IDGenerator.instance = IDGenerator.__IDGenerator() + + def __getattr__(self, name): + return getattr(self.instance, name) + + def gen_tensor_id(self): + self.instance._tensor_id += 1 + return self.instance._tensor_id + + def gen_op_id(self): + self.instance._op_id += 1 + return self.instance._op_id + + def clear(self): + self.instance._tensor_id = 0 + self.instance._op_id = 0 From ceaaaaf18d38cb24ba83981b0aab9a5ee2b40d29 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Sep 2021 15:47:13 +0800 Subject: [PATCH 0174/1892] runnable --- cube/graph/graph.py | 11 +- cube/graph/parser.py | 373 +++++++++++++++++++++++++++++++++---------- 2 files changed, 301 insertions(+), 83 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index e340ddec..e10bfe0a 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -139,7 +139,8 @@ def set_input(self, input_index: int, val: Any): # set predecessor self._predecessors[input_index] = val.src() # set the source node successor - val.src()._add_successor(val, self) + if isinstance(val.src(), IROperation): + val.src()._add_successor(val, self) def set_predecessor(self, input_index: int, node, out_index: int): """ @@ -165,6 +166,10 @@ def _add_successor(self, tensor, node): raise RuntimeError("Fail to find output tensor") self._successors[out_index].append(node) + def __repr__(self): + dscp = f'Op(id={self._id}, signature={self.signature}, inputs={self._inputs}, outputs={self._outputs})' + return dscp + class IRTensor: """ @@ -211,6 +216,10 @@ def add_dst_nodes(self, node: IROperation): raise TypeError("IRTensor destination node should be IROperation") self._dst_nodes.append(IROperation) + def __repr__(self): + dscp = f'Tensor(id={self._id}, name={self.name}, shape={self.shape})' + return dscp + class IRGraph: """ diff --git a/cube/graph/parser.py b/cube/graph/parser.py index 0edd6ffa..ed517df2 100644 --- a/cube/graph/parser.py +++ b/cube/graph/parser.py @@ -1,8 +1,114 @@ import torch import enum -from cube.graph.graph import IROperation +import re +from collections import OrderedDict +from typing import List, Any, Tuple + +from cube.graph.graph import IROperation, IRTensor + + +class _Frame: + """ + Frame to save call stack and variable + """ + def __init__(self): + + # var name -> value (IRTesnor, deterministic) + self._vars: List[dict[str, Any]] = list() + self._var_stack: List[str] = list() + + def push(self): + """ + This should only be called when step in a module + """ + self._vars.append(OrderedDict()) + + def pop(self): + """ + This should only be called step out a module + """ + if len(self._vars) == 0: + raise RuntimeError("Try to pop stack with 0 depth") + self._vars.pop() + + def add_var(self, var_name: str, val: Any, graph_arg: int = 0): + """ + Add variable to the current frame + + Args: + var_name (str): variable name (unique) + val: variable content + graph_arg (int): + indicate whether it is an argument of the graph. + If is 0, is not a graph arg. + If > 0, is a graph arg, will try to find + val from previous frame + """ + if not isinstance(var_name, str): + raise RuntimeError("Expected var_name is str") + if var_name in self._vars[-1]: + raise KeyError("Try to insert an already existed variable") + if graph_arg == 0: + self._vars[-1][var_name] = val + elif graph_arg > 0: + # root graph entry + if self.depth() == 1: + self._vars[-1][var_name] = val + # fucnton call + else: + prev_frame = self._vars[-2] + param_name = self._var_stack[0-graph_arg] + val = prev_frame[param_name] + self._vars[-1][var_name] = val + else: + raise ValueError("graph_arg (int) must be >= 0") + + def get_var(self, var_name: str) -> Any: + """ + Get variable value according to var_name + + Special mapping between frames (function calls): + + input.x will be mapped to output.k at the about 1-hop frame + + Returns: + val (Any) + """ + # first check whether we have variable in this frame + if var_name in self._vars[-1]: + return self._vars[-1][var_name] + raise KeyError(f"Cannot find var name {var_name}") + + def push_param(self, var_name): + """ + push var name to the method stack + + Args: + var_name (str): variable name + """ + if var_name not in self._vars[-1]: + raise KeyError(f"push {var_name} not declared") + self._var_stack.append(var_name) + + def pop_param(self, times=1): + """ + pop var name from the method stack + """ + for _ in range(times): + self._var_stack.pop() + + def depth(self): + return len(self._vars) + + def __repr__(self): + dscp = f'frame: depth: {self.depth()}\n var table:' + for var_name in self._vars[-1].keys(): + dscp += f'\n {var_name} : {self._vars[-1][var_name]}' + dscp += f'\n var stack:' + for var_name in self._var_stack: + dscp += f'\n {var_name}' + return dscp -from typing import Dict, List class ScriptNodeKind(enum.Enum): PrimGetAttr = 1 @@ -16,23 +122,36 @@ class ScriptNodeKind(enum.Enum): class ScriptModuleParser: @staticmethod - def parse_module(module: torch.jit.RecursiveModule) -> List[IROperation]: + def parse_module(module, + frame: _Frame = _Frame()) -> Tuple[List[IROperation], List[IRTensor]]: """ The overall entry to parse a torchscript graph module """ + frame.push() + + # handle graph input -- Assuming all the inputs are tensors + input_var_name = [input.debugName() for input in module.graph.inputs()] + # [1:] is to omit self + for var_name in input_var_name[1:][::-1]: + frame.add_var(var_name, IRTensor(name=var_name), graph_arg=True) + all_ir_nodes: List[IROperation] = list() for node in module.graph.nodes(): - ir_nodes = None - node_type = ScriptModuleParser.ntype(node) - if node_type == ScriptNodeKind.PrimCallFunction: - ir_nodes = ScriptModuleParser.parse_prim_function_node(node, module) - if node_type == ScriptNodeKind.AtenOp: - ir_nodes = ScriptModuleParser.parse_aten_node(node, module) - if node_type == ScriptNodeKind.PrimCallMethod: - ir_nodes = ScriptModuleParser.parse_prim_method_node(node, module) - if ir_nodes is not None: + # debug info + print(f'on parsing:\n\t{node}') + ir_nodes = ScriptModuleParser.parse_node(node, module, frame) + print(f'> {frame}') + print(f'> {ir_nodes}') + _ = input('>>>') + if len(ir_nodes) != 0: all_ir_nodes.append(ir_nodes) - return all_ir_nodes + + # handle graph output -- Assuming all the output are tensors + output_var_name = [output.debugName() for output in module.graph.outputs()] + output_val = [frame.get_var(var_name) for var_name in output_var_name] + + frame.pop() + return all_ir_nodes, output_val @staticmethod def ntype(node: torch._C.Node): @@ -51,120 +170,208 @@ def ntype(node: torch._C.Node): raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @staticmethod - def parse_prim_function_node(node, module) -> IROperation: + def parse_node(node: torch._C.Node, module, frame: _Frame) -> List[IROperation]: + """ + Parse the node and return the IROperation nodes + """ + node_type = ScriptModuleParser.ntype(node) + if node_type == ScriptNodeKind.PrimCallFunction: + return ScriptModuleParser.parse_prim_function_node(node, module, frame) + if node_type == ScriptNodeKind.AtenOp: + return ScriptModuleParser.parse_aten_node(node, module, frame) + if node_type == ScriptNodeKind.PrimCallMethod: + return ScriptModuleParser.parse_prim_method_node(node, module, frame) + if node_type == ScriptNodeKind.PrimGetAttr: + return ScriptModuleParser.parse_prim_attr_node(node, module, frame) + if node_type == ScriptNodeKind.PrimConstant: + return ScriptModuleParser.parse_prim_constant_node(node, module, frame) + + @staticmethod + def parse_prim_function_node(node, module, frame: _Frame) -> List[IROperation]: """ parse node like: Tensor = prim::CallFunction(%5, %input.1, %3, %4) %5 : Function = prim::Constant[name="linear"]() """ - fnode = node.inputsAt(0).node() # function node + inputs = [input for input in node.inputs()] + outputs = [output for output in node.outputs()] + + # handle function node + fnode = node.inputsAt(0).node() if not ScriptModuleParser.ntype(fnode) == ScriptNodeKind.PrimConstant: raise RuntimeError(f"Found unexpected function call node: {fnode}") - var_name, fsig, _ = ScriptModuleParser.parse_prim_constant_node(fnode, module) - input_length = len(list(fnode.inputs())) - 1 # -1 for the first signature - output_length = len(list(fnode.outputs())) + fsig = frame.get_var(inputs[0].debugName()) + + # create IR node ir_node = IROperation( signature = fsig, name = fnode.s('name'), - input_length=input_length, - output_length=output_length, + input_length=len(inputs) - 1, + output_length=len(outputs), ) - return ir_node + + # handle inputs -- in stack with reverse order + for index, input in enumerate(inputs[1:]): + var_name = input.debugName() + val = frame.get_var(var_name) + ir_node.set_input(index, val) + + # handle outputs + for index, output in enumerate(outputs): + frame.add_var(output.debugName(), ir_node.outputs(index)) + + return [ir_node] @staticmethod - def parse_aten_node(node, module) -> IROperation: + def parse_aten_node(node, module, frame: _Frame) -> List[IROperation]: """ Parse script module node like: %13 : Tensor = aten::gt(%output1.1, %output2.1) """ - input_nodes = list() - for input in node.inputs(): - input_node = input.node() - input_nodes.append(input_node) + fsig = node.kind() + fsig = re.sub('aten::', 'torch.', fsig) + inputs = [input for input in node.inputs()] + outputs = [output for output in node.outputs()] + # create IR node + ir_node = IROperation( + signature = fsig, + name = fsig, + input_length = len(inputs), + output_length = len(outputs) + ) - @staticmethod - def parse_node(node: torch._C.Node, module: torch.jit.RecursiveScriptModule): - """ - Parse the node and get the inputs, module type, outputs + # handle inputs + inputs = [input for input in node.inputs()] + # in stack with reverse order + for index, input in enumerate(inputs): + var_name = input.debugName() + val = frame.get_var(var_name) + ir_node.set_input(index, val) + + # handle outputs + outputs = [output for output in node.outputs()] + for index, output in enumerate(outputs): + frame.add_var(output.debugName(), ir_node.outputs(index)) + + return [ir_node] - Returns: - Inputs: list[int] - Modulelin - """ - ntype = ScriptNodeKind.get_node_type(node) - if ntype == ScriptNodeKind.PrimGetAttr: - return ScriptModuleParser.parse_prim_attr_node(node, module) - if ntype == ScriptNodeKind.PrimCallMethod: - return ScriptModuleParser.parse_prim_method_node(node, module) - if ntype == ScriptNodeKind.PrimCallFunction: - return ScriptModuleParser.parse_prim_function_node(node, module) - if ntype == ScriptNodeKind.AtenOp: - return ScriptModuleParser.parse_aten_node(node, module) - if ntype == ScriptNodeKind.PrimIf: - return ScriptModuleParser.parse_prim_if_node(node, module) - raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") - @staticmethod - def parse_prim_method_node(node, module) -> List[IROperation]: + def parse_prim_method_node(node, module, frame: _Frame) -> List[IROperation]: """ Parse script module node like: %output.1 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) - Find the module and + prim::CallMethod has a underlying submodule """ + inputs = [input for input in node.inputs()] + outputs = [output for output in node.outputs()] + # forward label = node.s('name') if label != 'forward': raise RuntimeError(f"{node} is calling function {label} that is not `forward`") - call_module = getattr(module, label) - ir_nodes = ScriptModuleParser.parse_module(call_module) - # TODO: rename nodes + + # cell node that will not appear in final graph + signature = frame.get_var(node.inputsAt(0).debugName()) + ir_node = IROperation( + name = signature + '.' + label, + signature = signature, + input_length = len(inputs) - 1, + output_length = len(outputs) + ) + + # handle inputs -- in stack with reverse order + for index, input in enumerate(inputs[1:][::-1]): + var_name = input.debugName() + val = frame.get_var(var_name) + ir_node.set_input(-1-index, val) + frame.push_param(var_name) + + print(f'> {frame}') + + # recursively parse the module + module_label = node.inputsAt(0).node().s('name') + call_module = getattr(module, module_label) + ir_nodes, outputs_val = ScriptModuleParser.parse_module(call_module, frame) + + # pop out the frame + frame.pop_param(times=len(inputs)-1) + + # handle outputs + outputs = [output for output in node.outputs()] + for index, (output, val) in enumerate(zip(outputs, outputs_val)): + frame.add_var(output.debugName(), val) + return ir_nodes @staticmethod - def parse_prim_attr_node(node, module) -> Dict: + def parse_prim_attr_node(node, module, frame) -> List[None]: """ Parse script module node like: %2 :__torch__.torch.nn.modules.linear.___torch_mangle_0.Linear = prim::GetAttr[name="linear1"](%self) %3 : Tensor = prim::GetAttr[name="weight"](%self) The __torch__.torch.nn.modules.* will be ignored + + This will add frame with the variable name and it's value + + The value can be: + 1). (IRTensor) the tensor edge in graph + 2). (str code) symbolic value based on runtime info (e.g., self.training) + 3). (str) Function or torch.nn.moudles + + Returns: + Empty list """ - if node.inputsAt(0).debugeName() != 'self': + if node.inputsAt(0).debugName() != 'self': raise RuntimeError(f"Fail to parse {node} due to missing %self") - # user defined var name: linear1 + label = node.s('name') + var_name = node.outputsAt(0).debugName() + dtype = node.outputsAt(0).type().str() - # output names: [2] - output_vars : List[str] = list() - for output in node.outputs(): - output_vars.append(output.debugName()) - - # module type name - module_mangle = node.outputsAt(0).type().str() - if 'torch.nn.modules' in module_mangle: - return None - module_name_demangle = list() - for mname in module_mangle.split('.'): - if mname == '__torch__' or '_mangle' in mname: - continue - module_name_demangle.append(mname) - module_name_demangle = '.'.joint(module_name_demangle) - - return dict( - input_ids=list(), output_names=output_names, - module=module_name_demangle, - label=label - ) + # this usually means weight (nn.Parameter in torch) + if dtype == 'Tensor': + ir_tensor = IRTensor(name=label) + frame.add_var(var_name, ir_tensor) + # symbolic attributes + elif dtype in ['bool', 'int', 'float']: + frame.add_var(var_name, 'self.' + label) + # NoneType + elif dtype == 'NoneType': + frame.add_var(var_name, None) + # module name or other things cannot handle + else: + frame.add_var(var_name, label) + return list() @staticmethod - def parse_prim_constant_node(node): - output = node.outputsAt(0) - if output.type().str() == 'Function': - func_name = repr(output.type()) + def parse_prim_constant_node(node, module, frame) -> List[None]: + """ + Parse script module node like: + %6 : Function = prim::Constant[name="dropout"]() + %5 : bool = prim::Constant[value=0]() + + This will add frame with the variable name and it's value + + Returns: + Empty list + """ + if len(list(node.inputs())) != 0: + raise RuntimeError(f"prim::Constant node: {node} has inputs") + var_name = node.outputsAt(0).debugName() + dtype = node.outputsAt(0).type().str() + + if dtype == 'Function': + signature = repr(node.outputsAt(0).type()) + frame.add_var(var_name, signature) + else: + val = node.outputsAt(0).toIValue() + frame.add_var(var_name, val) + return list() @staticmethod - def parse_prim_if_node(node, module): + def parse_prim_if_node(node, module, frame: _Frame) -> List[IROperation]: """ Parse script module node like %output2 : Tensor = prim::If(%15) # /tmp/ipykernel_27188/2459450745.py:13:8 @@ -181,12 +388,14 @@ def flatten(smodule, depth=0): Flatten the recursive script module to function and aten primitives """ # stashed_module = list() + inputs = [input for input in smodule.graph.inputs()] + print(' '*depth, f'graph inputs: {inputs}') if len(list(smodule.children())) == 0: for node in smodule.graph.nodes(): print(' '*depth, node) else: for node in smodule.graph.nodes(): - ntype = ScriptModuleParser.get_node_type(node) + ntype = ScriptModuleParser.ntype(node) print(' '*depth, node) if ntype == ScriptNodeKind.PrimCallMethod: label = node.inputsAt(0).node().s('name') From 471d762e2faa58ac17a3eb64685d77f74347a44b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Sep 2021 17:30:35 +0800 Subject: [PATCH 0175/1892] fix dependency bug --- cube/graph/__init__.py | 3 + cube/graph/converter.py | 17 +++++ cube/graph/frame.py | 107 +++++++++++++++++++++++++++ cube/graph/graph.py | 42 +++++++---- cube/graph/parser.py | 158 ++++++---------------------------------- 5 files changed, 180 insertions(+), 147 deletions(-) create mode 100644 cube/graph/converter.py create mode 100644 cube/graph/frame.py diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index e69de29b..5fdefdc7 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -0,0 +1,3 @@ +from cube.graph.graph import IRGraph, IRTensor, IROperation +from cube.graph.converter import convert + diff --git a/cube/graph/converter.py b/cube/graph/converter.py new file mode 100644 index 00000000..21b84c3b --- /dev/null +++ b/cube/graph/converter.py @@ -0,0 +1,17 @@ +from cube.graph.parser import ScriptModuleParser +from cube.graph.graph import IRGraph + +import torch + +def convert(model: torch.nn.Module) -> IRGraph: + """ + Convert toch.nn.Module based model into IRGraph + """ + try: + smodule = torch.jit.script(model) + except Exception: + raise RuntimeError("Cannot convert module into torchscript moudle.") + module_name = smodule.original_name + inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule) + graph = IRGraph(nodes, inputs, outputs, module_name) + return graph diff --git a/cube/graph/frame.py b/cube/graph/frame.py new file mode 100644 index 00000000..72feba1d --- /dev/null +++ b/cube/graph/frame.py @@ -0,0 +1,107 @@ +from collections import OrderedDict +from typing import List, Any + + +class Frame: + """ + Frame to save call stack and variable + """ + def __init__(self): + + # var name -> value (IRTesnor, deterministic) + self._vars: List[dict[str, Any]] = list() + self._var_stack: List[str] = list() + + def push(self): + """ + This should only be called when step in a module + """ + self._vars.append(OrderedDict()) + + def pop(self): + """ + This should only be called step out a module + """ + if len(self._vars) == 0: + raise RuntimeError("Try to pop stack with 0 depth") + self._vars.pop() + + def add_var(self, var_name: str, val: Any, graph_arg: int = -1): + """ + Add variable to the current frame + + Args: + var_name (str): variable name (unique) + val: variable content + graph_arg (int): + indicate whether it is an argument of the graph. + If is 0, is not a graph arg. + If > 0, is a graph arg, will try to find + val from previous frame + """ + if not isinstance(var_name, str): + raise RuntimeError("Expected var_name is str") + if var_name in self._vars[-1]: + raise KeyError("Try to insert an already existed variable") + # not a function parameter, no need for mapping + if graph_arg == -1: + self._vars[-1][var_name] = val + # a function parameter, may need for mapping + elif graph_arg >= 0: + # root graph entry + if self.depth() == 1: + self._vars[-1][var_name] = val + # fucnton call + else: + prev_frame = self._vars[-2] + param_name = self._var_stack[-1-graph_arg] + val = prev_frame[param_name] + self._vars[-1][var_name] = val + else: + raise ValueError("graph_arg (int) must be >= 0") + + def get_var(self, var_name: str) -> Any: + """ + Get variable value according to var_name + + Special mapping between frames (function calls): + + input.x will be mapped to output.k at the about 1-hop frame + + Returns: + val (Any) + """ + # first check whether we have variable in this frame + if var_name in self._vars[-1]: + return self._vars[-1][var_name] + raise KeyError(f"Cannot find var name {var_name}") + + def push_param(self, var_name): + """ + push var name to the method stack + + Args: + var_name (str): variable name + """ + if var_name not in self._vars[-1]: + raise KeyError(f"push {var_name} not declared") + self._var_stack.append(var_name) + + def pop_param(self, times=1): + """ + pop var name from the method stack + """ + for _ in range(times): + self._var_stack.pop() + + def depth(self): + return len(self._vars) + + def __repr__(self): + dscp = f'frame: depth: {self.depth()}\n var table:' + for var_name in self._vars[-1].keys(): + dscp += f'\n {var_name} : {self._vars[-1][var_name]}' + dscp += f'\n var stack:' + for var_name in self._var_stack: + dscp += f'\n {var_name}' + return dscp diff --git a/cube/graph/graph.py b/cube/graph/graph.py index e10bfe0a..1116e458 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -24,15 +24,8 @@ def __init__(self, Args: name (str): the op semantic name signature (str): the op signature, e.g., torch.functional.nn.linear - - Example: - init code: - self.linear1 = torch.nn.Linear(input_feats, output_feats) - forward code: - output = self.linear1(input) - => - label = linear1; - op = torch._C._nn.linear + input_length (int): the number of inputs for the op + output_length (int): the number of outputs for the op """ # node info self._id: int = IDGenerator().gen_op_id() @@ -46,6 +39,8 @@ def __init__(self, self._predecessors: List[IROperation] = [None] * input_length # todo for outputs self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] + for tensor in self._outputs: + tensor.set_src_node(self) self._successors: List[List(IROperation)] = [list() for _ in range(output_length)] def inputs(self, index: Optional[int] = None): @@ -206,7 +201,7 @@ def dst(self, index: Optional[int] = None): raise RuntimeError("get tensor dst out of range") return self._dst_nodes[index] - def set_src_nodes(self, node: IROperation): + def set_src_node(self, node: IROperation): if not isinstance(node, IROperation): raise TypeError("IRTensor source node should be IROperation") self._src_nodes = node @@ -217,7 +212,7 @@ def add_dst_nodes(self, node: IROperation): self._dst_nodes.append(IROperation) def __repr__(self): - dscp = f'Tensor(id={self._id}, name={self.name}, shape={self.shape})' + dscp = f'Tensor(id={self._id}, shape={self.shape})' return dscp @@ -228,9 +223,15 @@ class IRGraph: The IR Graph only contains forward graph """ - def __init__(self, module_name: str): + def __init__(self, + nodes: List[IROperation], + input_tensors: List[IRTensor], + output_tensors: List[IRTensor], + module_name: str): self.module_name = module_name - self._nodes: List[IROperation] = list() + self._nodes: List[IROperation] = nodes + self._input_tensors = input_tensors + self._output_tensors = output_tensors def add_node(self, node: IROperation): if not isinstance(node, IROperation): @@ -252,3 +253,18 @@ def replace(self, target: IROperation, nodes: List[IROperation]): Replace the node with new nodes (IRGraph) """ raise NotImplementedError + + def __repr__(self): + dscp = '' + # inputs + dscp += f'Inputs: {self._input_tensors}\n' + # nodes + for node in self._nodes: + succ_node_ids = [None] * len(node.outputs()) + for out_idx, node_list in enumerate(node.successors()): + node_list = [snode._id for snode in node_list] + succ_node_ids[out_idx] = node_list + dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" + # outputs + dscp += f'\nOutputs: {self._output_tensors}' + return dscp diff --git a/cube/graph/parser.py b/cube/graph/parser.py index ed517df2..fbd7d061 100644 --- a/cube/graph/parser.py +++ b/cube/graph/parser.py @@ -1,114 +1,10 @@ import torch import enum import re -from collections import OrderedDict -from typing import List, Any, Tuple +from typing import List, Tuple from cube.graph.graph import IROperation, IRTensor - - -class _Frame: - """ - Frame to save call stack and variable - """ - def __init__(self): - - # var name -> value (IRTesnor, deterministic) - self._vars: List[dict[str, Any]] = list() - self._var_stack: List[str] = list() - - def push(self): - """ - This should only be called when step in a module - """ - self._vars.append(OrderedDict()) - - def pop(self): - """ - This should only be called step out a module - """ - if len(self._vars) == 0: - raise RuntimeError("Try to pop stack with 0 depth") - self._vars.pop() - - def add_var(self, var_name: str, val: Any, graph_arg: int = 0): - """ - Add variable to the current frame - - Args: - var_name (str): variable name (unique) - val: variable content - graph_arg (int): - indicate whether it is an argument of the graph. - If is 0, is not a graph arg. - If > 0, is a graph arg, will try to find - val from previous frame - """ - if not isinstance(var_name, str): - raise RuntimeError("Expected var_name is str") - if var_name in self._vars[-1]: - raise KeyError("Try to insert an already existed variable") - if graph_arg == 0: - self._vars[-1][var_name] = val - elif graph_arg > 0: - # root graph entry - if self.depth() == 1: - self._vars[-1][var_name] = val - # fucnton call - else: - prev_frame = self._vars[-2] - param_name = self._var_stack[0-graph_arg] - val = prev_frame[param_name] - self._vars[-1][var_name] = val - else: - raise ValueError("graph_arg (int) must be >= 0") - - def get_var(self, var_name: str) -> Any: - """ - Get variable value according to var_name - - Special mapping between frames (function calls): - - input.x will be mapped to output.k at the about 1-hop frame - - Returns: - val (Any) - """ - # first check whether we have variable in this frame - if var_name in self._vars[-1]: - return self._vars[-1][var_name] - raise KeyError(f"Cannot find var name {var_name}") - - def push_param(self, var_name): - """ - push var name to the method stack - - Args: - var_name (str): variable name - """ - if var_name not in self._vars[-1]: - raise KeyError(f"push {var_name} not declared") - self._var_stack.append(var_name) - - def pop_param(self, times=1): - """ - pop var name from the method stack - """ - for _ in range(times): - self._var_stack.pop() - - def depth(self): - return len(self._vars) - - def __repr__(self): - dscp = f'frame: depth: {self.depth()}\n var table:' - for var_name in self._vars[-1].keys(): - dscp += f'\n {var_name} : {self._vars[-1][var_name]}' - dscp += f'\n var stack:' - for var_name in self._var_stack: - dscp += f'\n {var_name}' - return dscp - +from cube.graph.frame import Frame class ScriptNodeKind(enum.Enum): PrimGetAttr = 1 @@ -123,7 +19,8 @@ class ScriptModuleParser: @staticmethod def parse_module(module, - frame: _Frame = _Frame()) -> Tuple[List[IROperation], List[IRTensor]]: + frame: Frame = Frame()) \ + -> Tuple[List[IRTensor], List[IROperation], List[IRTensor]]: """ The overall entry to parse a torchscript graph module """ @@ -132,26 +29,27 @@ def parse_module(module, # handle graph input -- Assuming all the inputs are tensors input_var_name = [input.debugName() for input in module.graph.inputs()] # [1:] is to omit self - for var_name in input_var_name[1:][::-1]: - frame.add_var(var_name, IRTensor(name=var_name), graph_arg=True) + for index, var_name in enumerate(input_var_name[1:]): + frame.add_var(var_name, IRTensor(name=var_name), graph_arg=index) + input_val = [frame.get_var(var_name) for var_name in input_var_name[1:]] all_ir_nodes: List[IROperation] = list() for node in module.graph.nodes(): # debug info - print(f'on parsing:\n\t{node}') + # print(f'on parsing:\n\t{node}') ir_nodes = ScriptModuleParser.parse_node(node, module, frame) - print(f'> {frame}') - print(f'> {ir_nodes}') - _ = input('>>>') + # print(f'> {frame}') + # print(f'> {ir_nodes}') + # _ = input('>>>') if len(ir_nodes) != 0: - all_ir_nodes.append(ir_nodes) + all_ir_nodes += ir_nodes # handle graph output -- Assuming all the output are tensors output_var_name = [output.debugName() for output in module.graph.outputs()] output_val = [frame.get_var(var_name) for var_name in output_var_name] frame.pop() - return all_ir_nodes, output_val + return input_val, all_ir_nodes, output_val @staticmethod def ntype(node: torch._C.Node): @@ -170,7 +68,7 @@ def ntype(node: torch._C.Node): raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @staticmethod - def parse_node(node: torch._C.Node, module, frame: _Frame) -> List[IROperation]: + def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IROperation]: """ Parse the node and return the IROperation nodes """ @@ -187,7 +85,7 @@ def parse_node(node: torch._C.Node, module, frame: _Frame) -> List[IROperation]: return ScriptModuleParser.parse_prim_constant_node(node, module, frame) @staticmethod - def parse_prim_function_node(node, module, frame: _Frame) -> List[IROperation]: + def parse_prim_function_node(node, module, frame: Frame) -> List[IROperation]: """ parse node like: Tensor = prim::CallFunction(%5, %input.1, %3, %4) @@ -223,7 +121,7 @@ def parse_prim_function_node(node, module, frame: _Frame) -> List[IROperation]: return [ir_node] @staticmethod - def parse_aten_node(node, module, frame: _Frame) -> List[IROperation]: + def parse_aten_node(node, module, frame: Frame) -> List[IROperation]: """ Parse script module node like: %13 : Tensor = aten::gt(%output1.1, %output2.1) @@ -257,7 +155,7 @@ def parse_aten_node(node, module, frame: _Frame) -> List[IROperation]: return [ir_node] @staticmethod - def parse_prim_method_node(node, module, frame: _Frame) -> List[IROperation]: + def parse_prim_method_node(node, module, frame: Frame) -> List[IROperation]: """ Parse script module node like: %output.1 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) @@ -272,35 +170,25 @@ def parse_prim_method_node(node, module, frame: _Frame) -> List[IROperation]: if label != 'forward': raise RuntimeError(f"{node} is calling function {label} that is not `forward`") - # cell node that will not appear in final graph - signature = frame.get_var(node.inputsAt(0).debugName()) - ir_node = IROperation( - name = signature + '.' + label, - signature = signature, - input_length = len(inputs) - 1, - output_length = len(outputs) - ) - # handle inputs -- in stack with reverse order - for index, input in enumerate(inputs[1:][::-1]): + for input in inputs[1:][::-1]: var_name = input.debugName() val = frame.get_var(var_name) - ir_node.set_input(-1-index, val) frame.push_param(var_name) - print(f'> {frame}') + # print(f'> {frame}') # recursively parse the module module_label = node.inputsAt(0).node().s('name') call_module = getattr(module, module_label) - ir_nodes, outputs_val = ScriptModuleParser.parse_module(call_module, frame) + _, ir_nodes, outputs_val = ScriptModuleParser.parse_module(call_module, frame) # pop out the frame frame.pop_param(times=len(inputs)-1) # handle outputs outputs = [output for output in node.outputs()] - for index, (output, val) in enumerate(zip(outputs, outputs_val)): + for output, val in zip(outputs, outputs_val): frame.add_var(output.debugName(), val) return ir_nodes @@ -364,6 +252,8 @@ def parse_prim_constant_node(node, module, frame) -> List[None]: if dtype == 'Function': signature = repr(node.outputsAt(0).type()) + if '__torch__.' in signature: + signature = re.sub('__torch__.', '', signature) frame.add_var(var_name, signature) else: val = node.outputsAt(0).toIValue() @@ -371,7 +261,7 @@ def parse_prim_constant_node(node, module, frame) -> List[None]: return list() @staticmethod - def parse_prim_if_node(node, module, frame: _Frame) -> List[IROperation]: + def parse_prim_if_node(node, module, frame: Frame) -> List[IROperation]: """ Parse script module node like %output2 : Tensor = prim::If(%15) # /tmp/ipykernel_27188/2459450745.py:13:8 From 71298fcb4dd2b2279c30a459de786725bfff5995 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Sep 2021 20:30:35 +0800 Subject: [PATCH 0176/1892] add codegen for spatial --- cube/codegen/codegen.py | 100 +++++++++++++++++++++++++++++++- cube/codegen/syntax/blocks.py | 13 ++++- cube/codegen/syntax/symtable.py | 15 ++--- cube/graph/graph.py | 40 +++++++++++-- 4 files changed, 147 insertions(+), 21 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 723bcd29..de0cfbae 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -2,10 +2,104 @@ Generate Pytorch code given the model DAG and the transformation config """ +from typing import List, Any -class ModelCodeGen: +from cube.graph import IRGraph, IRTensor, IROperation +from cube.codegen.syntax.symtable import SymbolTable +from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock - def __init__(self, action_irs): - pass +class SScheduleCodeGen: + """ + Generate spatial code for the model + """ + def __init__(self, graph: IRGraph): + if not isinstance(graph, IRGraph): + raise TypeError("graph should be IRGraph") + self.graph = graph + # model full code + self.code: List[str] = list() + # module init code + self.declare_region: List[str] = list() + # module forward code + self.forward_region: List[str] = list() + # module member name + self.symbols = SymbolTable() + + def gen(self, outfile=None) -> List[str]: + """ + Generate model implementation code based on the given graph. + """ + # register forward input + fargs = [self.naming(input) for input in self.graph.inputs()] + for name in fargs: + self.symbols.create(name) + + # parse graph body + for node in self.graph.nodes(): + self.emit_op_call(node) + # emit input declaration + for arg in node.inputs(): + self.emit_var_declare(arg) + # record output tensor name + for out in node.outputs(): + if isinstance(out, IRTensor) or isinstance(out, str): + self.symbols.create(self.naming(out)) + + # generate full code + with ClassBlock(class_name='GenModel', derived='torch.nn.Module') as cb: + with FunctionBlock(func_name='__init__', args=['self']) as ib: + ib.insert_body(self.declare_region) + cb.insert_body(ib.code) + with FunctionBlock(func_name='forward', args=['self']+fargs) as fb: + fb.insert_body(self.emit_op_call) + cb.insert_body(fb.code) + self.code = cb.code + + # write to file + if outfile: + with open(outfile, 'w'): + for line in self.code: + outfile.write(line) + + return self.code + + def emit_var_declare(self, var: Any): + """ + Emit tensor declaration code + """ + if isinstance(var, IRTensor): + name = self.naming(var) + # indicate this is a leaf tensor, should be parameter + if self.symbols.create(name): + code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(IRTensor.shape)}))' + self.declare_region.append(code) + elif isinstance(var, str): + name = self.naming(var) + if self.symbols.create(name): + #TODO: add type info + code = f'self.{name} = None' + self.declare_region.append(code) + return + + def emit_op_call(self, node: IROperation): + """ + Emit op forward code + """ + op_code = node.signature + out_region = ', '.join([self.naming(out) for out in node.outputs()]) + arg_region = '(' + ', '.join([self.naming(arg) for arg in node.inputs()]) + ')' + code = f'{out_region} = {op_code}{arg_region}' + self.forward_region.append(code) + + def naming(self, tensor: Any) -> str: + """ + Return the var name (unique for different variable) + """ + if isinstance(tensor, IRTensor): + tensor_name = 'tensor' if tensor.name is None else tensor.name + name = '_'.join([tensor_name, tensor._id]) + else: + name = str(tensor) + return name diff --git a/cube/codegen/syntax/blocks.py b/cube/codegen/syntax/blocks.py index 33e44485..b16bbb8e 100644 --- a/cube/codegen/syntax/blocks.py +++ b/cube/codegen/syntax/blocks.py @@ -1,4 +1,4 @@ - +from typing import List class Block: @@ -27,8 +27,15 @@ def __exit__(self, exc_type, exc_value, exc_tb): class FunctionBlock(Block): - - def __init__(self, func_name, args): + """ + Create a function block with function definition + """ + + def __init__(self, func_name: str, args: List[str]): + if not isinstance(func_name, str): + raise TypeError("Expected func_name to be str") + if not isinstance(args, list): + raise TypeError("Expcted args to be list[str]") self.func_name = func_name self.param_name = args args = ', '.join(args) diff --git a/cube/codegen/syntax/symtable.py b/cube/codegen/syntax/symtable.py index abf8235f..3712453e 100644 --- a/cube/codegen/syntax/symtable.py +++ b/cube/codegen/syntax/symtable.py @@ -16,20 +16,13 @@ class SymbolTable: def __init__(self): self._varlist = list() - def create(self, var_name): + def create(self, var_name: str): """ - Create a variable with a type. - - If var_name is already declared: - if the declared type matches with var_name, return False. - else raise Error - - If var name is not declared, decalre the var and return True. + Create a variable. Args: var_name (str): variable name - var_type (Dtype): variable type - + Returns: True if declared, False if the var already exists. """ @@ -39,7 +32,7 @@ def create(self, var_name): else: self._varlist.append(var_name) - def exist(self, var_name): + def exist(self, var_name: str): """ Check whether a variable exists """ diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 1116e458..52a71a97 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -230,8 +230,8 @@ def __init__(self, module_name: str): self.module_name = module_name self._nodes: List[IROperation] = nodes - self._input_tensors = input_tensors - self._output_tensors = output_tensors + self._inputs = input_tensors + self._outputs = output_tensors def add_node(self, node: IROperation): if not isinstance(node, IROperation): @@ -248,6 +248,38 @@ def nodes(self, index: Optional[int]): ) return self._nodes[index] + def inputs(self, index: Optional[int] = None): + if isinstance(index, int): + if index >= len(self._inputs): + raise RuntimeError( + f"Get the input out of range ({index} >= {len(self._inputs)}" + ) + return self._inputs[index] + elif index is None: + return self._inputs + else: + raise TypeError("Expected index to be None or int") + + def outputs(self, index: Optional[int] = None): + """ + Get output tensor at output index + + Args: + index (int or None): + index of the outputs, None will return the nodes + for all the outputs + """ + if isinstance(index, int): + if index >= len(self._outputs): + raise RuntimeError( + f"Get the output out of range ({index} >= {len(self._outputs)}" + ) + return self._outputs[index] + elif index is None: + return self._outputs + else: + raise TypeError("Expected index to be None or int") + def replace(self, target: IROperation, nodes: List[IROperation]): """ Replace the node with new nodes (IRGraph) @@ -257,7 +289,7 @@ def replace(self, target: IROperation, nodes: List[IROperation]): def __repr__(self): dscp = '' # inputs - dscp += f'Inputs: {self._input_tensors}\n' + dscp += f'Inputs: {self._inputs}\n' # nodes for node in self._nodes: succ_node_ids = [None] * len(node.outputs()) @@ -266,5 +298,5 @@ def __repr__(self): succ_node_ids[out_idx] = node_list dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" # outputs - dscp += f'\nOutputs: {self._output_tensors}' + dscp += f'\nOutputs: {self._outputs}' return dscp From 00b80ba15bd9cdcd32b0e5a68e8faa26459cb084 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 09:54:15 +0800 Subject: [PATCH 0177/1892] add weight shape --- cube/graph/parser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cube/graph/parser.py b/cube/graph/parser.py index fbd7d061..fbe5e8a9 100644 --- a/cube/graph/parser.py +++ b/cube/graph/parser.py @@ -220,7 +220,8 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: # this usually means weight (nn.Parameter in torch) if dtype == 'Tensor': - ir_tensor = IRTensor(name=label) + shape = list(getattr(module, label).shape) + ir_tensor = IRTensor(name=label, shape=shape) frame.add_var(var_name, ir_tensor) # symbolic attributes elif dtype in ['bool', 'int', 'float']: From 7b9de431b5adfc59dd4481fab1d4c3a436460adc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 12:39:19 +0800 Subject: [PATCH 0178/1892] infer shape --- cube/__init__.py | 4 - cube/config/__init__.py | 1 - cube/config/container.py | 125 -------------- cube/config/utils.py | 25 --- cube/graph/converter.py | 7 +- cube/graph/graph.py | 36 +++- cube/graph/mapping.py | 36 ++++ cube/graph/parser.py | 22 ++- cube/nn/__init__.py | 2 - cube/nn/linear.py | 30 ---- cube/operator/__init__.py | 2 + cube/operator/holist/__init__.py | 0 cube/operator/holist/generics.py | 208 ----------------------- cube/operator/holist/linear.py | 132 --------------- cube/operator/logic/__init__.py | 2 +- cube/operator/logic/function.py | 46 ++++++ cube/operator/logic/generics.py | 106 +++++------- cube/operator/logic/linear.py | 20 --- cube/operator/physic/__init__.py | 0 cube/operator/physic/comm/__init__.py | 1 - cube/operator/physic/comm/boundary.py | 133 --------------- cube/operator/physic/comm/mapreduce.py | 36 ---- cube/operator/physic/generics.py | 93 ----------- cube/operator/physic/linear.py | 12 -- cube/tschedule/__init__.py | 1 + cube/tschedule/action.py | 90 ++++++++++ tests/graph/test_parser.py | 42 +++++ tests/operator/test_holistic_linear.py | 219 ------------------------- tests/operator/test_holistic_op.py | 117 ------------- tests/operator/test_logical_op.py | 47 ------ tests/operator/test_physic_linear.py | 48 ------ tests/operator/test_physic_op.py | 48 ------ 32 files changed, 311 insertions(+), 1380 deletions(-) delete mode 100644 cube/config/__init__.py delete mode 100644 cube/config/container.py delete mode 100644 cube/config/utils.py create mode 100644 cube/graph/mapping.py delete mode 100644 cube/nn/__init__.py delete mode 100644 cube/nn/linear.py delete mode 100644 cube/operator/holist/__init__.py delete mode 100644 cube/operator/holist/generics.py delete mode 100644 cube/operator/holist/linear.py create mode 100644 cube/operator/logic/function.py delete mode 100644 cube/operator/logic/linear.py delete mode 100644 cube/operator/physic/__init__.py delete mode 100644 cube/operator/physic/comm/__init__.py delete mode 100644 cube/operator/physic/comm/boundary.py delete mode 100644 cube/operator/physic/comm/mapreduce.py delete mode 100644 cube/operator/physic/generics.py delete mode 100644 cube/operator/physic/linear.py create mode 100644 cube/tschedule/__init__.py create mode 100644 cube/tschedule/action.py create mode 100644 tests/graph/test_parser.py delete mode 100644 tests/operator/test_holistic_linear.py delete mode 100644 tests/operator/test_holistic_op.py delete mode 100644 tests/operator/test_logical_op.py delete mode 100644 tests/operator/test_physic_linear.py delete mode 100644 tests/operator/test_physic_op.py diff --git a/cube/__init__.py b/cube/__init__.py index 996b1f04..e69de29b 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,4 +0,0 @@ -from cube import operator -from cube import nn -from cube import device -from cube import config \ No newline at end of file diff --git a/cube/config/__init__.py b/cube/config/__init__.py deleted file mode 100644 index 9bd37031..00000000 --- a/cube/config/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cube.config.utils import choices \ No newline at end of file diff --git a/cube/config/container.py b/cube/config/container.py deleted file mode 100644 index 37639e62..00000000 --- a/cube/config/container.py +++ /dev/null @@ -1,125 +0,0 @@ - -class ConditionContainer: - - def __init__(self, satisfy_fn): - if not callable(satisfy_fn): - raise TypeError("Expected function") - self._satisfy_fn = (satisfy_fn,) - self._condition_fn = None - self._val = None - self._choices = None - self._lock = False - - def get(self): - """ - Get the current set value (default None if not set) - """ - return self._val - - def set(self, val): - """ - Set the value, will raise ValueError if not satisfy - """ - if self._lock: - raise False - if self._choices is not None: - if not self.satisfy(val, self._choices): - return False - val_backup = self._val - self._val = val - if self._condition_fn is not None: - if not self._condition_fn(): - self._val = val_backup - return False - return True - - def lock(self): - """ - Lock the value, will not allow change - """ - self._lock = True - - def satisfy(self, val): - """ - Check whether the value satisfy the choices and conditions - - Returns: - True if satisfy, False not - """ - return self._satisfy_fn[0](val, self._choices) - - def choices(self): - """ - Return choices. - - Use list(container.choices) to see all the choices - """ - return self._choices - - def reset(self, choices): - """ - Reset choices - """ - self._val = None - self._choices = choices - - -class ChoiceContainer(ConditionContainer): - - def __init__(self, choices): - """ - Create a choice container, the value can only be - the item in the choices. - - choices (iterable): - list or range - """ - def satisfy_fn(val, choices): - return val in choices - super().__init__(satisfy_fn) - if not hasattr(choices, '__iter__'): - choices = [choices] - self._choices = choices - - -class TypeContainer(ConditionContainer): - - def __init__(self, type_choices): - """ - Create a type container, the value can only be - the instance of the type in the choices. - - type_choices (iterable): - usually a list[type] - """ - def satisfy_fn(val, choices): - for t in choices: - if isinstance(val, t): - return True - return False - super().__init__(satisfy_fn) - self._choices = type_choices - - -class UniformSumContainer(ConditionContainer): - - def __init__(self, summation): - """ - Create a summation restriction container - """ - def satisfy_fn(val, choices): - return len(set(val) == 1) and sum(val) == choices - super().__init__(satisfy_fn) - self._choices = summation - - -class SumContainer(ConditionContainer): - - def __init__(self, summation, slots=None): - """ - Create a summation restriction container - """ - def satisfy_fn(val, choices): - return sum(val) == choices - super().__init__(satisfy_fn) - self._choices = summation diff --git a/cube/config/utils.py b/cube/config/utils.py deleted file mode 100644 index 610f6383..00000000 --- a/cube/config/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import copy -import z3 - -def choices(solver, attributes): - """ - Iterate each the config space - - Args: - solver (z3.z3.Solver) - attributes (list[z3.z3.xx]) - - Yield: - config (z3.z3.ModelRef) - """ - if not isinstance(solver, z3.z3.Solver): - raise TypeError("Expected solver to be an z3 solver") - solver = copy.deepcopy(solver) - while solver.check() == z3.sat: - config = solver.model() - solver.add( - z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) - ) - yield config - if len(attributes) == 0: - break diff --git a/cube/graph/converter.py b/cube/graph/converter.py index 21b84c3b..fa5d212c 100644 --- a/cube/graph/converter.py +++ b/cube/graph/converter.py @@ -1,9 +1,12 @@ +from typing import Optional, List + from cube.graph.parser import ScriptModuleParser from cube.graph.graph import IRGraph import torch -def convert(model: torch.nn.Module) -> IRGraph: +def convert(model: torch.nn.Module, + input_shapes: Optional[ List[List[int],] ] = None) -> IRGraph: """ Convert toch.nn.Module based model into IRGraph """ @@ -12,6 +15,6 @@ def convert(model: torch.nn.Module) -> IRGraph: except Exception: raise RuntimeError("Cannot convert module into torchscript moudle.") module_name = smodule.original_name - inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule) + inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) graph = IRGraph(nodes, inputs, outputs, module_name) return graph diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 52a71a97..e35e9db1 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -1,7 +1,6 @@ -""" -Convert PyTorch nn.Module to our IRGraph -""" from cube.graph.unique import IDGenerator +from cube.graph.mapping import IR2LogicOp + from typing import List, Optional, Any @@ -31,8 +30,9 @@ def __init__(self, self._id: int = IDGenerator().gen_op_id() self.name: str = name - # op signature + # op signature and op class self.signature: str = signature + self.op = IR2LogicOp.map(self.signature) # edge (dataflow info) self._inputs: List[IRTensor] = [None] * input_length @@ -161,6 +161,29 @@ def _add_successor(self, tensor, node): raise RuntimeError("Fail to find output tensor") self._successors[out_index].append(node) + def infer_shape(self): + """ + Infer output value shape + """ + shapes = list() + for input in self.inputs(): + if isinstance(input, IRTensor): + if input.shape is None: + return False + shapes.append(input.shape) + else: + shapes.append([1,]) + shapes = tuple(shapes) + out_shapes = self.op.shape_infer(*shapes) + if len(out_shapes) != len(self._outputs): + raise RuntimeError( + "The logical op semantic doesn't match with parsed op" + ) + for shape, val in zip(out_shapes, self._outputs): + if isinstance(val, IRTensor): + val.shape = shape + return True + def __repr__(self): dscp = f'Op(id={self._id}, signature={self.signature}, inputs={self._inputs}, outputs={self._outputs})' return dscp @@ -187,9 +210,10 @@ def shape(self): @shape.setter def shape(self, val): - if self._shape is not None: + if self._shape is not None and self._shape != val: raise RuntimeError("Try to change shape") - if not all([isinstance(size, int) for size in val]): + if not isinstance(val, list) or \ + not all([isinstance(size, int) for size in val]): raise RuntimeError("Expected shape to be list[int]") self._shape = val diff --git a/cube/graph/mapping.py b/cube/graph/mapping.py new file mode 100644 index 00000000..24c0499e --- /dev/null +++ b/cube/graph/mapping.py @@ -0,0 +1,36 @@ +""" +Mapping of + IROperation -> cube.operator.logic.generics.GenericLogicalOp +""" + +import cube.operator.logic as logic + +class IR2LogicOp: + + @staticmethod + def map(signature: str) -> logic.GenericLogicalOp : + """ + Map the signature to GenericLogicalOp + """ + if signature in IR2LogicOp.kOpMap: + return IR2LogicOp.kOpMap[signature] + raise KeyError(f"{signature} is not supported yet") + + # functional templates + __ftemplate = lambda name: f'torch.nn.functional.{name}' + + # tensor template + __ttemplate = lambda name: f'torch.{name}' + + kOpMap = { + + __ftemplate('linear') : logic.Linear, + + __ftemplate('dropout') : logic.Dropout, + + __ftemplate('gelu') : logic.GeLU, + + __ttemplate('add') : logic.TensorAdd, + + } + diff --git a/cube/graph/parser.py b/cube/graph/parser.py index fbe5e8a9..c11b70d5 100644 --- a/cube/graph/parser.py +++ b/cube/graph/parser.py @@ -1,9 +1,9 @@ import torch import enum import re -from typing import List, Tuple +from typing import List, Tuple, Optional -from cube.graph.graph import IROperation, IRTensor +from cube.graph import IROperation, IRTensor from cube.graph.frame import Frame class ScriptNodeKind(enum.Enum): @@ -19,6 +19,7 @@ class ScriptModuleParser: @staticmethod def parse_module(module, + input_shapes: Optional[ Tuple[List[int],] ] = None, frame: Frame = Frame()) \ -> Tuple[List[IRTensor], List[IROperation], List[IRTensor]]: """ @@ -28,11 +29,20 @@ def parse_module(module, # handle graph input -- Assuming all the inputs are tensors input_var_name = [input.debugName() for input in module.graph.inputs()] - # [1:] is to omit self - for index, var_name in enumerate(input_var_name[1:]): + for index, var_name in enumerate(input_var_name[1:]): # omit self frame.add_var(var_name, IRTensor(name=var_name), graph_arg=index) input_val = [frame.get_var(var_name) for var_name in input_var_name[1:]] + # handle input shape + if input_shapes: + if len(input_val) != len(input_shapes): + raise RuntimeError( + f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(input_val)})" + ) + for shape, val in zip(input_shapes, input_val): + if isinstance(val, IRTensor): + val.shape = shape + all_ir_nodes: List[IROperation] = list() for node in module.graph.nodes(): # debug info @@ -42,6 +52,8 @@ def parse_module(module, # print(f'> {ir_nodes}') # _ = input('>>>') if len(ir_nodes) != 0: + for ir_node in ir_nodes: + ir_node.infer_shape() all_ir_nodes += ir_nodes # handle graph output -- Assuming all the output are tensors @@ -181,7 +193,7 @@ def parse_prim_method_node(node, module, frame: Frame) -> List[IROperation]: # recursively parse the module module_label = node.inputsAt(0).node().s('name') call_module = getattr(module, module_label) - _, ir_nodes, outputs_val = ScriptModuleParser.parse_module(call_module, frame) + _, ir_nodes, outputs_val = ScriptModuleParser.parse_module(call_module, frame=frame) # pop out the frame frame.pop_param(times=len(inputs)-1) diff --git a/cube/nn/__init__.py b/cube/nn/__init__.py deleted file mode 100644 index 7a8bac35..00000000 --- a/cube/nn/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from torch.nn import Module -from cube.nn.linear import Linear \ No newline at end of file diff --git a/cube/nn/linear.py b/cube/nn/linear.py deleted file mode 100644 index 8f43fbe2..00000000 --- a/cube/nn/linear.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -from cube.tensor.logic.tensor import LogicalTensor -import cube.operator.logic as logic_op - -import math - -from torch import nn - -class Linear(nn.Module): - - __constants__ = ['in_features', 'out_features'] - - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super(Linear, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = LogicalTensor((out_features, in_features)) - if bias: - self.bias = LogicalTensor((out_features,)) - self.reset_parameters() - # Actually here we can pass shapes - self.op = logic_op.Linear() - - def reset_parameters(self) -> None: - pass - - def forward(self, input: LogicalTensor) -> LogicalTensor: - return self.op(input, self.weight, self.bias) diff --git a/cube/operator/__init__.py b/cube/operator/__init__.py index e69de29b..e8250dad 100644 --- a/cube/operator/__init__.py +++ b/cube/operator/__init__.py @@ -0,0 +1,2 @@ +import cube.operator.logic as logic + diff --git a/cube/operator/holist/__init__.py b/cube/operator/holist/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/operator/holist/generics.py b/cube/operator/holist/generics.py deleted file mode 100644 index 3780b26e..00000000 --- a/cube/operator/holist/generics.py +++ /dev/null @@ -1,208 +0,0 @@ - -""" -Holistic Operator Generics - -The holistic operator needed to be registered into logical op - -The output communication works in a lazy execution way. Communication will only -happen in the front of the next executed op in case the layout doesn't match. -""" - -from cube.tensor.logic.tensor import LogicalTensor -from cube.tensor.logic.outline import BaseOutline - -from cube.device.physic.group import DeviceGroup - -import z3 - - -class GenericHolisticOp: - - _default_policy_fn = None - - def __init__(self, outputs, *args, **kwargs): - """ - Layout is the community distribution requirement for input and - output logical tensors. - - Args: - outputs (list[LogicalTensor]): - output logical tensor (empty data) - *args, **kwargs: input arguments - - """ - self.solver = z3.Solver() - - self.input_layouts = list() - self.output_layouts = list() - - self.logical_op = None - self.output_shapes = list() - - self.attributes = list() - self.policy_fn = None - self.config = None - - def set_input_layouts(self, layouts): - """ - Set input layout - - Args: - layouts (list[BaseOutline]): layout list for input logical tensor - """ - for layout in layouts: - if not isinstance(layout, BaseOutline): - TypeError("Require input layout for HolistOp is a list[BaseOutline]") - self.attributes += layout.get_attributes() - self.input_layouts.append(layout) - - def set_output_layouts(self, layouts): - """ - Set output layout - - Args: - layouts (list[BaseOutline]): layout list for output logical tensor - """ - for layout in layouts: - if not isinstance(layout, BaseOutline): - TypeError("Require input layout for HolistOp is a list[BaseOutline]") - self.attributes += layout.get_attributes() - self.output_layouts.append(layout) - - def add_constraint(self, constraint): - """ - Add cross-layout constraint to the solver - """ - if not isinstance(constraint, z3.z3.BoolRef): - raise TypeError("Expected z3.z3.BoolRef constraints") - self.solver.add(constraint) - - def set_config(self, config): - if not isinstance(config, z3.z3.ModelRef): - raise TypeError("Expected config from z3 solver.model()") - self.config = config - - def input_adapter(self, *args, **kwargs): - """ - Transform tensors in args and kwargs to match the - input layout requirement, Currently kwargs is not allowed to - have tensors - """ - #TODO: kwargs - if len(self.input_layouts) != len(args): - raise RuntimeError("Fail to adapt input: layout length not equal") - - # step1: TODO: format (dimension reorder support) - - # step 2: Policy: segmentation + deploy decision - policy_fn = self._default_policy_fn - if self.policy_fn is not None: - policy_fn = self.policy_fn - config, input_ranks = policy_fn[0](self, *args, **kwargs) - self.set_config(config) - - # step 3: segmentation - input_segments = list() - for tensor, outliner in zip(args, self.input_layouts): - if outliner is not None and isinstance(tensor, LogicalTensor): - segments = outliner.interpret(tensor, self.config) - input_segments.append(segments) - else: - input_segments.append(None) - - # step 4: deploy - for tid in range(len(args)): - tensor = args[tid] - if isinstance(tensor, LogicalTensor): - segments = input_segments[tid] - ranks = input_ranks[tid] - tensor.transform(segments, ranks) - - def forward(self, *args, **kwargs): - """ - Expert code for doing operation - Call to the physical operator for execution - - Expert needs to gurantee the returned value is list[tuple(OpResult,),] - - Each item in list is the corresponding output to logical op output. - - Each item in the logical op output is a OpResult to the segment specified - by the expert. The order should be consistent with specified segment. - """ - raise NotImplementedError("Error call to generics") - - def output_adapter(self, outputs): - """ - Data reformat to logical op format - - Args: - outputs (tuple(list[OpResult],)) - each `list[OpResult]` represents a output of the op - with its segments - Returns: - logical outputs (tuple(LogicalTensor,)): - the logical tensor list - """ - #TODO: fix: data re-format order. Should be ahead of logical tensor construction - if not isinstance(outputs, tuple): - outputs = (outputs,) - - # step 1: construct to logical tensor - logical_outputs = list() - for output, outliner in zip(outputs, self.output_layouts): - logical_tensor = LogicalTensor(outliner.shape, init_data=False) - segments = outliner.interpret(logical_tensor, self.config) - for segment in segments: - logical_tensor.add_segment(segment) - logical_tensor.fill( - physical_tensors=[op_res.res for op_res in output], - ranks=[op_res.placement for op_res in output] - ) - logical_outputs.append(logical_tensor) - - # step 2: TODO: data reformat based on the output - - if len(logical_outputs) == 1: - return logical_outputs[0] - else: - return tuple(logical_outputs) - - def __call__(self, *args, **kwargs): - - # data transformations to match input layout requirement - self.input_adapter(*args, **kwargs) - - # do execution - args, kwargs = LogicalTensor.to_segments(*args, **kwargs) - outputs = self.forward(*args, **kwargs) - - # wrap to logical tensor - outputs = self.output_adapter(outputs) - - return outputs - - def set_policy(self, policy_fn): - """ - Register a customized policy to take layouts and solver, - generate segmentation plan and deploy plan - - Args: - plicy_fn (callable) - """ - if not callable(policy_fn): - raise TypeError("Expected callable function") - self.policy_fn = (policy_fn,) - - @classmethod - def set_default_policy(cls, policy_fn): - """ - Register a policy for all instances. Take layouts and solver, - generate segmentation plan and deploy plan - - Args: - plicy_fn (callable) - """ - if not callable(policy_fn): - raise TypeError("Expected callable function") - cls._default_policy_fn = (policy_fn,) diff --git a/cube/operator/holist/linear.py b/cube/operator/holist/linear.py deleted file mode 100644 index 34f3c5ec..00000000 --- a/cube/operator/holist/linear.py +++ /dev/null @@ -1,132 +0,0 @@ -from cube.operator.holist.generics import GenericHolisticOp - -import cube.operator.physic.linear as phy_linear -from cube.operator.physic.comm.mapreduce import PartialSum - -from cube.tensor.logic.tensor import LogicalTensor -import cube.tensor.logic.outline as outline - -# Debug -from cube.device.physic.group import DeviceGroup -import torch - - -class LinearColumnWeight(GenericHolisticOp): - """ - Perform Y = XW + b -> Y = X[W1,W2] + [b1,b2] - Split W and b on the last dimension - """ - - def __init__(self, outputs, input, weight, bias): - - super().__init__(outputs, input, weight, bias) - - # input layouts - input_layout = outline.Full(self.solver, input) - - weight_layout = outline.SplitAxis( - self.solver, weight, - axis=0, chunk_num=None, overlap=0 - ) - bias_layout = outline.SplitAxis( - self.solver, bias, - axis=0, chunk_num=None, overlap=0 - ) - self.add_constraint(bias_layout.chunk_num == weight_layout.chunk_num) - - # output layouts - output_layout = outline.SplitAxis( - self.solver, outputs[0], - axis=1, chunk_num=None, overlap=0 - ) - self.add_constraint(output_layout.chunk_num == weight_layout.chunk_num) - - self.set_input_layouts([input_layout, weight_layout, bias_layout]) - self.set_output_layouts([output_layout]) - - def forward(self, input, weight, bias): - """ - input: list[Segment] of input - weight: list[Segment] of weight - bias: list[Segment] of bias - """ - outputs = list() - # TODO: handle bias is None - physical_input = input[0].get_physical_tensor() - for weight_seg, bias_seg in zip(weight, bias): - # output = physic_op.linear(inputs, weight, bias) - #TODO: TensorContainer to enable op placement + tensor movement - #TODO: ExecutionScheduler to handle re-compute / swap - #TODO: nested hybrid call to enable hybrid-parallelisms - #TODO: double-check necessety of stateful physical operator - physical_weight = weight_seg.get_physical_tensor() - # if DeviceGroup().rank == 0: - # print(physical_weight) - physical_bias = bias_seg.get_physical_tensor() - # TODO: this is the policy decision - phy_op = phy_linear.Linear(placement=weight_seg.placement) - output = phy_op(physical_input, physical_weight, physical_bias) - # if DeviceGroup().rank == 0: - # print(output) - outputs.append(output) - return outputs - - -class LinearColumnInputRowWeight(GenericHolisticOp): - """ - Perform - Y = XW + b - -> Y = [X1,X2] * [W1//W2] + [b1 + b2]] - -> Y = X1W1 + X2W2 + b1 + b2 - Split X (inputs) in column major (last dim), - Split W (weights) in row major (first dim) - Split b (bias) in value major - """ - - def __init__(self, outputs, input, weight, bias): - - super().__init__(outputs, input, weight, bias) - - input_layout = outline.SplitAxis( - self.solver, input, - axis=-1, chunk_num=None, overlap=0, - ) - - weight_layout = outline.SplitAxis( - self.solver, weight, - axis=1, chunk_num=None, overlap=0, - ) - self.add_constraint(weight_layout.chunk_num == input_layout.chunk_num) - - bias_layout = outline.SplitValue( - self.solver, bias, - chunk_num=None, - val_op=PartialSum - ) - self.add_constraint(bias_layout.chunk_num == input_layout.chunk_num) - - # output layout will only use reduce op - output_layout = outline.SplitValue( - self.solver, outputs[0], - chunk_num=None, - val_op=PartialSum - ) - self.add_constraint(output_layout.chunk_num == input_layout.chunk_num) - - self.set_input_layouts([input_layout, weight_layout, bias_layout]) - self.set_output_layouts([output_layout]) - - def forward(self, input, weight, bias): - outputs = list() - for input_seg, weight_seg, bias_seg in zip(input, weight, bias): - phy_op = phy_linear.Linear(placement=weight_seg.placement) - output = phy_op( - input_seg.get_physical_tensor(), - weight_seg.get_physical_tensor(), - bias_seg.get_physical_tensor() - ) - outputs.append(output) - return outputs - - -kHolistLinearSets = [LinearColumnWeight, LinearColumnInputRowWeight] \ No newline at end of file diff --git a/cube/operator/logic/__init__.py b/cube/operator/logic/__init__.py index 39163d3e..46a60fcb 100644 --- a/cube/operator/logic/__init__.py +++ b/cube/operator/logic/__init__.py @@ -1 +1 @@ -from cube.operator.logic.linear import Linear \ No newline at end of file +from cube.operator.logic.function import * \ No newline at end of file diff --git a/cube/operator/logic/function.py b/cube/operator/logic/function.py new file mode 100644 index 00000000..52a1d290 --- /dev/null +++ b/cube/operator/logic/function.py @@ -0,0 +1,46 @@ +from typing import List, Optional + +from cube.operator.logic.generics import GenericLogicalOp +from cube.operator.logic.generics import ElementSameInputOp + + +class Linear(GenericLogicalOp): + + @staticmethod + def candidates(): + raise NotImplementedError + + @staticmethod + def shape_infer(input: List[int], + weight: List[int], + bias: Optional[List[int]] = None): + """ + input: [(D), M, K] + weight: [N, K] + bias: [N,] + """ + out_shape = list(input) + out_shape[-1] = weight[0] + return [out_shape] + + def translate(self, config): + raise NotImplementedError + + +class GeLU(ElementSameInputOp): + + def __init__(self, signature: str): + super().__init__(signature) + +class Dropout(ElementSameInputOp): + + def __init__(self, signature: str): + super().__init__(signature) + + +# ================== aten tensor op ======================== + +class TensorAdd(ElementSameInputOp): + + def __init__(self, signature: str): + super().__init__(signature) diff --git a/cube/operator/logic/generics.py b/cube/operator/logic/generics.py index 81915ce3..d0a2b33a 100644 --- a/cube/operator/logic/generics.py +++ b/cube/operator/logic/generics.py @@ -1,37 +1,23 @@ -""" -A Logical Operator: - * Statusless - * Can be executed by only one kernel (atomic) on single device +from typing import List -Logical Operator - |- Holistic Operator 1 - | |- Physical Operator(s) - |- Holistic Operator 2 - |- ... -Holistic operators are allowed to nested in hybrid-distribution strategy -""" - -from cube.tensor.logic.tensor import LogicalTensor - - -class HolisticOpFactory: +class DistAlgorithmFactory: def __init__(self): - self.holist_ops = list() + self.algorithms = list() def __len__(self): """ Return the number of holistic op registered """ - return len(self.holist_ops) + return len(self.algorithms) - def register(self, holistic_op): + def register(self, algorithm): """ Register a holistic op (class) as one of the anchors """ - self.holist_ops.append(holistic_op) + self.algorithms.append(algorithm) def get_op(self, idx, outputs, *args, **kwargs): """ @@ -46,20 +32,28 @@ def get_op(self, idx, outputs, *args, **kwargs): Returns: HolisticOp instance """ - return self.holist_ops[idx](outputs, *args, **kwargs) + return self.algorithms[idx](outputs, *args, **kwargs) class GenericLogicalOp: - _default_policy_fn = None - - def __init__(self): + def __init__(self, signature: str): + """ + Generic logical operator - # candidate holistic operator - self.factory = HolisticOpFactory() - self.policy_fn = None + signature (str): + Framework implementation signature, + e.g., 'torch.nn.functional.linear' + """ + if not isinstance(signature, str): + raise TypeError("Expect signature to be a string") + # factory + self.factory = DistAlgorithmFactory() + # torch impl signature + self.signature = signature - def shape_infer(self, *args, **kwargs): + @staticmethod + def shape_infer(*args, **kwargs): """ Output shape inference according to inputs @@ -70,49 +64,31 @@ def shape_infer(self, *args, **kwargs): shapes tuple(list[int]): shape for each output tensor """ raise NotImplementedError("Expected a shape infer engine") - - def get_op(self, *args, **kwargs): - # get shapes of input and output - shapes = self.shape_infer(*args, **kwargs) - outputs = [LogicalTensor(shape=shape, init_data=False) for shape in shapes] - # use default policy - if self.policy_fn is None: - composite_op = self._default_policy_fn[0](self.factory, outputs, *args, **kwargs) - # use user-customized policy - else: - composite_op = self.policy_fn[0](self.factory, outputs, *args, **kwargs) - return composite_op - - def __call__(self, *args, **kwargs): + + def register_algorithm(self, algorithm): """ - Policy here to determine which holistic operator(s) are called + Register a distributed algoritm description """ - composite_op = self.get_op(*args, **kwargs) - # run operator with the strategy plan - outputs = composite_op(*args, **kwargs) - return outputs + self.factory.register(algorithm) - def set_policy(self, policy_fn): + def translate(self, config): """ - Register a policy function to customize how composite - holistic op generated during runtime. + Translate the algorithm to implementation + """ + raise NotImplementedError("Expected a tranlation for operator") + - The `policy_fn` takes self.factory as input and returns a composite - holistic operator (callable) +class ElementSameInputOp(GenericLogicalOp): + + def __init__(self): """ - if not callable(policy_fn): - raise TypeError("Expected a callable function") - self.policy_fn = (policy_fn,) - - @classmethod - def set_default_policy(self, policy_fn): + Elementwise Operator """ - Register a default policy function to all instances. - Customize how composite holistic op generated during runtime. + super().__init__() - The `policy_fn` takes self.factory and shapes as input, - and returns a composite holistic operator (callable) + @staticmethod + def shape_infer(input: List[int], *args, **kwargs): + """ + Element-wise single input op """ - if not callable(policy_fn): - raise TypeError("Expected a callable function") - self._default_policy_fn = (policy_fn,) + return [input] diff --git a/cube/operator/logic/linear.py b/cube/operator/logic/linear.py deleted file mode 100644 index e8f32314..00000000 --- a/cube/operator/logic/linear.py +++ /dev/null @@ -1,20 +0,0 @@ -from cube.operator.logic.generics import GenericLogicalOp -from cube.operator.holist.linear import kHolistLinearSets - - -class Linear(GenericLogicalOp): - - def __init__(self): - super().__init__() - - # register holistic operators - for holist_op in kHolistLinearSets: - self.factory.register(holist_op) - - def shape_infer(self, input, weight, bias=None): - """ - Return the outputs shape [list[int],] - """ - output_shape = list(input.shape) - output_shape[-1] = weight.shape[0] - return [output_shape,] diff --git a/cube/operator/physic/__init__.py b/cube/operator/physic/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/operator/physic/comm/__init__.py b/cube/operator/physic/comm/__init__.py deleted file mode 100644 index 3d140a8e..00000000 --- a/cube/operator/physic/comm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cube.operator.physic.comm.boundary import * \ No newline at end of file diff --git a/cube/operator/physic/comm/boundary.py b/cube/operator/physic/comm/boundary.py deleted file mode 100644 index ca27a516..00000000 --- a/cube/operator/physic/comm/boundary.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Autograd backward needs to return the same number of gradients as input, -even if they are not tensors. -""" - -import torch - -from cube.device.physic.group import DeviceGroup - - -__all__ = ['replicate', 'gather_out', 'scatter_in', 'reduce_sum'] - - -def _reduce(input_, group): - """All-reduce the the input tensor across model parallel group.""" - - # allreduce - torch.distributed.all_reduce(input_, group=group) - return input_ - - -def _split(input_, dim, chunk_num, rank): - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - # bypass the function if we are using only 1 GPU. - if chunk_num == 1: - return input_ - # split along specified dim - if input_.size()[dim] % chunk_num != 0: - raise RuntimeError("backward on Gather Out Error: un divideable") - dim_size = input_.size()[dim] // chunk_num - tensor_list = torch.split(input_, dim_size, dim=dim) - # note: torch.split does not create contiguous tensors by default. - output = tensor_list[rank].contiguous() - return output - - -def _gather(input_, dim, group): - """Gather tensors and concatinate along the last dimension.""" - - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) - # bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=group) - # note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=dim).contiguous() - return output - - -class _ParallelIn(torch.autograd.Function): - """Pass the input to the model parallel region.""" - - @staticmethod - def forward(ctx, input_, group): - # record group - ctx.constants = group - # identitfy forward - return input_ - - @staticmethod - def backward(ctx, grad_output): - # allreduce - group = ctx.constants - return torch.distributed.all_reduce(grad_output, group=group), None - - -class _GatherOut(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - @staticmethod - def forward(ctx, input_, dim, group): - # record group - ctx.constants = (group, dim) - # allgather - return _gather(input_, dim, group) - - @staticmethod - def backward(ctx, grad_output): - group, dim = ctx.constants - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) - return _split(grad_output, dim, world_size, rank), None, None - - -class _ScatterIn(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - @staticmethod - def forward(ctx, input_, dim, group): - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) - ctx.constants = (group, dim) - return _split(input_, dim, world_size, rank) - - @staticmethod - def backward(ctx, grad_output): - group, dim = ctx.constants - return _gather(grad_output, dim, group), None, None - - -class _ReduceOut(torch.autograd.Function): - """All-reduce the input from the model parallel region.""" - - @staticmethod - def forward(ctx, input_, group): - return _reduce(input_, group) - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -def replicate(input_, group): - return _ParallelIn.apply(input_, group) - - -def gather_out(input_, dim, group): - return _GatherOut.apply(input_, dim, group) - - -def scatter_in(input_, dim, group): - return _ScatterIn.apply(input_, dim, group) - - -def reduce_sum(input_, group): - return _ReduceOut.apply(input_, group) - diff --git a/cube/operator/physic/comm/mapreduce.py b/cube/operator/physic/comm/mapreduce.py deleted file mode 100644 index 3b3beaf0..00000000 --- a/cube/operator/physic/comm/mapreduce.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch - - -class ValueMapReduceOp: - - def __init__(self, val_map_op, val_reduce_op): - if not (callable(val_map_op) and callable(val_reduce_op)): - raise TypeError("Expected val_map_op and val_reduce_o callable") - self.val_map_op = (val_map_op,) - self.val_reduce_op = (val_reduce_op,) - - def map(self, tensor, group): - if not torch.is_tensor(tensor): - raise RuntimeError("Expected tensor to be torch.Tensor") - return self.val_map_op[0](tensor, group) - - def reduce(self, tensor, group): - if not torch.is_tensor(tensor): - raise RuntimeError("Expected `tensor` to be torch.Tensor") - return self.val_map_op[0](tensor, group) - - -def _val_split_map(tensor, group): - world_size = torch.distributed.get_world_size(group) - return tensor / world_size - - -def _val_sum_reduce(tensor, group): - torch.distributed.all_reduce(tensor, group=group) - return tensor - - -PartialSum = ValueMapReduceOp( - val_map_op = _val_split_map, - val_reduce_op = _val_sum_reduce -) diff --git a/cube/operator/physic/generics.py b/cube/operator/physic/generics.py deleted file mode 100644 index 2ca3e3d0..00000000 --- a/cube/operator/physic/generics.py +++ /dev/null @@ -1,93 +0,0 @@ -""" -This should be the interface with C level kernel launch -""" - -from cube.device.physic.group import DeviceGroup -import torch - -class OpResult: - """ - The empty result is used for re-constructing community - """ - def __init__(self, result, ranks): - self.res = result - self.placement = ranks - - def get_result(self): - return self.res - - def __repr__(self): - return "OpResult(res={}, placement={})".format(self.res, self.placement) - - -class GenericPhysicOp: - """ - The generic physical op takes at least one physical tensor, - and generates at least one physical tensor. - - If there is no tensor as input, will return an empty result - which indicates which rank will generate the correct one. - """ - - def __init__(self, func, placement=None): - - if not callable(func): - raise TypeError("Expect callable function") - if not (isinstance(placement, list) or placement is None): - raise TypeError("Expected placement init with None or list[int]") - self.func = (func,) - self._placement = None - self.execute_flag = False - self.policy_fn = None - if isinstance(placement, list): - self.placement = placement - - @property - def placement(self): - """ - Ranks for the op to execute - """ - return self._placement - - @placement.setter - def placement(self, ranks): - if not isinstance(ranks, list): - raise TypeError("Expected list of int ranks") - self._placement = ranks - if DeviceGroup().rank not in self.placement: - self.execute_flag = False - else: - self.execute_flag = True - - def register_policy(self, policy_fn): - if not callable(policy_fn): - raise TypeError("Expected callable policy function") - self.policy_fn = [policy_fn] - - def __call__(self, *args, **kwargs): - #TODO: fix for model-partition with send/recv - if self.placement is None: - if self.policy_fn is None: - #TODO: fix: this will break between-device consistency view - self.placement = [torch.cuda.current_device()] - else: - self.placement = self.policy_fn(*args, **kwargs) - if not self.execute_flag: - return OpResult(None, self.placement) - - # tensor movement - for arg in args: - if torch.is_tensor(arg): - #TODO: rank -> device mapping, send/recv - if arg.device.index not in self.placement: - #TODO: rank -> device mapping, send/recv - arg.data = arg.detach().cuda() - for key in kwargs: - if torch.is_tensor(kwargs[key]): - #TODO: rank -> device mapping, send/recv - if kwargs[key].device.index not in self.placement: - # TODO: rank -> device mapping, send/recv - kwargs[key].data = kwargs[key].detach().cuda() - - outputs = self.func[0](*args, **kwargs) - return OpResult(outputs, self.placement) diff --git a/cube/operator/physic/linear.py b/cube/operator/physic/linear.py deleted file mode 100644 index 8af300f9..00000000 --- a/cube/operator/physic/linear.py +++ /dev/null @@ -1,12 +0,0 @@ -from cube.operator.physic.generics import GenericPhysicOp - -import torch - - -class Linear(GenericPhysicOp): - """ - Apply matmul: Out = input * weight^T + bias - """ - - def __init__(self, placement=None): - super().__init__(torch._C._nn.linear, placement) diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py new file mode 100644 index 00000000..23966a1c --- /dev/null +++ b/cube/tschedule/__init__.py @@ -0,0 +1 @@ +from cube.tschedule.action import Action \ No newline at end of file diff --git a/cube/tschedule/action.py b/cube/tschedule/action.py new file mode 100644 index 00000000..57f80cad --- /dev/null +++ b/cube/tschedule/action.py @@ -0,0 +1,90 @@ +from typing import List + +from cube.graph import IRGraph + + +class Action: + """ + Action represents a (sub-)graph which contains operators on the + same device + """ + def __init__(self, graph: IRGraph, device: int): + + if not isinstance(graph, IRGraph): + raise TypeError("Require graph to be IRGraph") + if not isinstance(device, int): + raise TypeError("Require device to be int") + # set up attributes + self.graph: IRGraph = graph + self.device: int = device + self.name: str = None + # dependencies + self._pre_actions: List[Action] = list() + self._post_actions: List[Action] = list() + + @property + def device(self): + return self._device + + @device.setter + def device(self, device): + for op in self.graph.nodes(): + op.deivce = device + self._device = device + + def tag(self, name: str): + """ + Tag a string to indicate this action (as name) + """ + self.name = name + + def happen_before(self, action): + """ + Check if the self -> (happened before) action + """ + if not isinstance(action, Action): + raise TypeError("Expected action to be an Action") + return action in self._post_actions + + def post_actions(self): + """ + Get post action list + """ + return self._post_actions + + def happen_after(self, action): + """ + Check if the action -> (happened before) self + + Note: this may return false negative as it will only check + 1-hop dependency + """ + if not isinstance(action, Action): + raise TypeError("Expected action to be an Action") + return action in self._pre_actions + + def pre_actions(self): + """ + Get pre action list + + Note: this may return false negative as it will only check + 1-hop dependency + """ + return self._pre_actions + + def add_flow(self, action): + """ + Make this action (self) -> (happened before) action + """ + if not isinstance(action, Action): + raise TypeError("Expected action to be Action") + self._post_actions.append(action) + action._add_pre_action(self) + + def _add_pre_action(self, action): + """ + Add successor that requries this action happened first + """ + if not isinstance(action, Action): + raise TypeError("Expected action to be Action") + self._successors.append(action) diff --git a/tests/graph/test_parser.py b/tests/graph/test_parser.py new file mode 100644 index 00000000..b67817a2 --- /dev/null +++ b/tests/graph/test_parser.py @@ -0,0 +1,42 @@ +from cube.graph.parser import ScriptModuleParser +import torch +from torch import nn + +import cube.graph as cgraph +from cube.graph.parser import ScriptModuleParser + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=16, classes=1000): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult, bias=False) + self.gelu = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim * mult, dim) + self.classifier = nn.Linear(dim, classes) + + def forward(self, data, x: int = 4): + output = self.linear1(data) + output = self.gelu(output) + output = self.dropout(output) + output = output + data + output = self.linear2(output) + output = self.classifier(output) + return output + +model = FeedForward(dim=1024) +smodule = torch.jit.script(model) + + +def test_flatten(smodule): + ScriptModuleParser.flatten(smodule) + +def test_parse_module(model): + return cgraph.convert(model, input_shapes=([1024,1024],[1,])) + + +if __name__ == '__main__': + + + # test_flatten(smodule) + graph = test_parse_module(model) + print(graph) \ No newline at end of file diff --git a/tests/operator/test_holistic_linear.py b/tests/operator/test_holistic_linear.py deleted file mode 100644 index dfe0ae6d..00000000 --- a/tests/operator/test_holistic_linear.py +++ /dev/null @@ -1,219 +0,0 @@ -""" -cmd for running the test - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - tests/operator/test_holistic_linear.py -""" - -from cube.tensor.logic.tensor import LogicalTensor - -from cube.operator.holist.linear import LinearColumnWeight -from cube.operator.holist.linear import LinearColumnInputRowWeight - -from cube.device.physic.group import DeviceGroup - -import torch -import z3 -torch.manual_seed(100) - - -def test_linear_POC(): - - N = 1024 - input = torch.randn((1024, 1024)).cuda() - weight = torch.randn((N, 1024)) - bias = torch.randn((N,)) - - rank = DeviceGroup().rank - - # partial - partial_weight = torch.chunk(weight, 4, dim=0)[rank].cuda() - partial_bias = torch.chunk(bias, 4, dim=0)[rank].cuda() - partial_out = torch._C._nn.linear(input, partial_weight, partial_bias) - - # full - out_full = torch._C._nn.linear(input, weight.cuda(), bias.cuda()) - ref_out = torch.chunk(out_full, 4, dim=1)[rank].cuda() - - if rank == 0: - print('max bias: ', torch.max(torch.abs(partial_out - ref_out))) - print('sum bias: ', torch.sum(torch.abs(partial_out - ref_out))) - - -def test_holistic_linear_op_column_weight(): - """ - Note: Due to unknown reason in hardware, the output will have up to - 0.0001 bias. This is verified in `test_linear_POC()` The larger - K results larger bias. - """ - N = 1024 - shapes = [(1024, 1024), (N, 1024), (N,), (1024, N)] - input = LogicalTensor(shape=shapes[0]) - weight = LogicalTensor(shape=shapes[1]) - bias = LogicalTensor(shape=shapes[2]) - outputs = [LogicalTensor(shapes[3])] - - # ================================ Policy =========================== - - holistic_op = LinearColumnWeight(outputs, input, weight, bias) - - def policy(holist_op, input, weight, bias): - solver = holist_op.solver - attributes = holist_op.attributes - input_layout = holist_op.input_layouts[0] - weight_layout = holist_op.input_layouts[1] - bias_layout = holist_op.input_layouts[2] - output_layout = holist_op.output_layouts[0] - - # add restrictions based on device num - device_num = torch.cuda.device_count() - solver.add(weight_layout.chunk_num == 4) - - # iterate all configs - configs = list() - while solver.check() == z3.sat: - config = solver.model() - if DeviceGroup().rank == 0: - print('find config: {}'.format(config)) - configs.append(config) - solver.add( - z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) - ) - if len(attributes) == 0: - break - # choose one config -- suppose to the first - config = configs[0] - if DeviceGroup().rank == 0: - print('selected config: {}'.format(config)) - - # deploy decisions - chunk_num = config[weight_layout.chunk_num].as_long() - input_ranks = [list(range(0, chunk_num)),] - weight_ranks = list() - for rank in range(chunk_num): - weight_ranks.append([rank]) - bias_ranks = weight_ranks - - return config, [input_ranks, weight_ranks, bias_ranks] - - # Missing Policy: where physical op executed? - - holistic_op.set_policy(policy) - # ================================ Policy =========================== - - output = holistic_op(input, weight, bias) - print('segments: {}'.format(len(output.segments))) - - # =============================== Test ============================== - output_ref = torch._C._nn.linear( - input.data.cuda(), weight.data.cuda(), bias.data.cuda() - ) - rank = DeviceGroup().rank - output_ref = torch.chunk(output_ref, chunks=4, dim=1)[rank].contiguous() - out = output.get_physical_tensor(rank) - # if rank == 0: - # print('ref: ', output_ref) - # print('get: ', out) - # print('max bias: ', torch.max(torch.abs(out - output_ref))) - # print('sum bias: ', torch.sum(torch.abs(out - output_ref))) - error_max = torch.max(torch.abs(out - output_ref)) - assert error_max.item() < 2e-4 - # =============================== Test ============================== - - -def test_holistic_linear_op_column_input_row_weight(): - - N = 1024 - shapes = [(1024, 1024), (N, 1024), (N,), (1024, N)] - input = LogicalTensor(shape=shapes[0]) - weight = LogicalTensor(shape=shapes[1]) - bias = LogicalTensor(shape=shapes[2]) - outputs = [LogicalTensor(shapes[3])] - - # ================================ Policy =========================== - - holistic_op = LinearColumnInputRowWeight(outputs, input, weight, bias) - - def policy(holist_op, input, weight, bias): - solver = holist_op.solver - attributes = holist_op.attributes - input_layout = holist_op.input_layouts[0] - weight_layout = holist_op.input_layouts[1] - bias_layout = holist_op.input_layouts[2] - output_layout = holist_op.output_layouts[0] - - # add restrictions based on device num - device_num = torch.cuda.device_count() - solver.add(weight_layout.chunk_num == 4) - - # iterate all configs - configs = list() - while solver.check() == z3.sat: - config = solver.model() - if DeviceGroup().rank == 0: - print('find config: {}'.format(config)) - configs.append(config) - solver.add( - z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) - ) - if len(attributes) == 0: - break - # choose one config -- suppose to the first - config = configs[0] - if DeviceGroup().rank == 0: - print('selected config: {}'.format(config)) - - # deploy decisions - chunk_num = config[weight_layout.chunk_num].as_long() - input_ranks = list() - for rank in range(chunk_num): - input_ranks.append([rank]) - weight_ranks = input_ranks - bias_ranks = weight_ranks - - return config, [input_ranks, weight_ranks, bias_ranks] - - # Missing Policy: where physical op executed? - - holistic_op.set_policy(policy) - # ================================ Policy =========================== - - output = holistic_op(input, weight, bias) - print('segments: {}'.format(len(output.segments))) - - # =============================== Test ============================== - rank = DeviceGroup().rank - input_ref = torch.chunk(input.data.cuda(), chunks=4, dim=-1)[rank] - weight_ref = torch.chunk(weight.data.cuda(), chunks=4, dim=1)[rank] - bias_ref = bias.data.cuda() / 4 - # if rank == 0: - # print('input ref: ', input_ref) - # print('weight ref: ', weight_ref) - # print('bias ref: ', bias_ref) - - output_ref = torch._C._nn.linear( - input_ref, weight_ref, bias_ref - ) - out = output.get_physical_tensor(rank) - # if rank == 0: - # print('ref: ', output_ref) - # print('get: ', out) - # print('max bias: ', torch.max(torch.abs(out - output_ref))) - # print('sum bias: ', torch.sum(torch.abs(out - output_ref))) - error_max = torch.max(torch.abs(out - output_ref)) - assert error_max.item() < 2e-4 - # =============================== Test ============================== - - -if __name__ == '__main__': - group = DeviceGroup() - - # test_linear_POC() - test_holistic_linear_op_column_weight() - test_holistic_linear_op_column_input_row_weight() diff --git a/tests/operator/test_holistic_op.py b/tests/operator/test_holistic_op.py deleted file mode 100644 index d5dd75a6..00000000 --- a/tests/operator/test_holistic_op.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -cmd for running the test - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - tests/operator/test_holistic_op.py -""" - -import cube.tensor.logic.outline as outline -from cube.tensor.logic.tensor import LogicalTensor -from cube.operator.holist.generics import GenericHolisticOp -from cube.device.physic.group import DeviceGroup -import torch -import z3 - -def test_generic_holistic_op_init(): - - shapes = [(32, 2048), (1024, 2048), (32, 1024)] - op = GenericHolisticOp(shapes) - - # description - input_layout = outline.Full( - op.solver, op.shapes[0], - ) - weight_layout = outline.SplitAxis( - op.solver, op.shapes[1], - axis=0, chunk_num=None, overlap=0, - ) - output_layout = outline.SplitAxis( - op.solver, op.shapes[2], - axis=0, chunk_num=weight_layout.chunk_num, overlap=0, - ) - - assert op.shapes == shapes - assert len(op.input_layouts) == 0 - assert len(op.output_layouts) == 0 - assert op.logical_op is None - assert op.policy_fn is None - - op.set_input_layouts([input_layout, weight_layout]) - op.set_output_layouts([output_layout]) - - assert len(op.input_layouts) == 2 - assert len(op.output_layouts) == 1 - assert len(op.attributes) == 5 - - -def test_generic_holistic_op_input_adapter(): - - shapes = [(32, 512), (1024, 512), (32, 1024)] - input = LogicalTensor(shape=shapes[0]) - weight = LogicalTensor(shape=shapes[1]) - - op = GenericHolisticOp(shapes) - - # description - input_layout = outline.Full( - op.solver, op.shapes[0], - ) - weight_layout = outline.SplitAxis( - op.solver, op.shapes[1], - axis=0, chunk_num=None, overlap=0, - ) - output_layout = outline.SplitAxis( - op.solver, op.shapes[2], - axis=0, chunk_num=weight_layout.chunk_num, overlap=0, - ) - - op.set_input_layouts([input_layout, weight_layout]) - op.set_output_layouts([output_layout]) - - def policy(holist_op): - solver = holist_op.solver - attributes = holist_op.attributes - input_layout = holist_op.input_layouts[0] - weight_layout = holist_op.input_layouts[1] - output_layout = holist_op.output_layouts[0] - - # add restrictions based on device num - device_num = torch.cuda.device_count() - solver.add(weight_layout.chunk_num <= 4) - - # iterate all configs - configs = list() - while solver.check() == z3.sat: - config = solver.model() - configs.append(config) - solver.add( - z3.Or([z3.Not(attr == config[attr]) for attr in attributes]) - ) - if len(attributes) == 0: - break - # choose one config -- suppose to the first - config = configs[0] - - # deploy decisions - chunk_num = config[weight_layout.chunk_num].as_long() - input_ranks = [list(range(0, chunk_num)),] - weight_ranks = list() - for rank in range(chunk_num): - weight_ranks.append([rank]) - - return config, [input_ranks, weight_ranks] - - op.set_policy(policy) - op.input_adapter(input, weight) - - -if __name__ == '__main__': - group = DeviceGroup() - test_generic_holistic_op_init() - test_generic_holistic_op_input_adapter() diff --git a/tests/operator/test_logical_op.py b/tests/operator/test_logical_op.py deleted file mode 100644 index 931eb0d0..00000000 --- a/tests/operator/test_logical_op.py +++ /dev/null @@ -1,47 +0,0 @@ -from cube.operator.logic.generics import HolisticOpFactory, GenericLogicalOp - - -def test_factory(): - - factory = HolisticOpFactory() - assert len(factory) == 0 - - class HolisticOp: - def __init__(self, shape): pass - - factory.register(HolisticOp) - assert len(factory) == 1 - - op = factory.get_op(0, [(1024, 1024)]) - assert isinstance(op, HolisticOp) - - -def test_generic_logical_op_init(): - - generic_op = GenericLogicalOp() - assert len(generic_op.factory) == 0 - assert generic_op.policy_fn is None - - -def test_generic_logical_op_register(): - - generic_op = GenericLogicalOp() - - class HolisticOp: - def __init__(self, shape): pass - - generic_op.factory.register(HolisticOp) - - def policy_fn(factory, shapes): - return factory.get_op(0, shapes) - - generic_op.set_policy(policy_fn) - assert generic_op.policy_fn is not None - - - -if __name__ == '__main__': - - test_factory() - test_generic_logical_op_init() - test_generic_logical_op_register() \ No newline at end of file diff --git a/tests/operator/test_physic_linear.py b/tests/operator/test_physic_linear.py deleted file mode 100644 index f51b51ed..00000000 --- a/tests/operator/test_physic_linear.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -cmd for running the test - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - tests/operator/test_physic_linear.py -""" - -from cube.device.physic.group import DeviceGroup -from cube.operator.physic.linear import Linear -import torch - - -def test_physic_generic_op(): - - myrank = DeviceGroup().rank - ranks = [0, 2] - - op = Linear() - assert op.placement is None - - op.placement = ranks - assert op.func[0] is torch._C._nn.linear - assert op.placement == [0, 2] - assert op.execute_flag == (myrank in ranks) - - matA = torch.randn((1024,1024)) - matB = torch.randn((1024,1024)) - matC = op(matA, matB, bias=None) - - assert set(matC.placement) == set(ranks) - if myrank in ranks: - assert torch.is_tensor(matC.get_result()) - else: - assert matC.get_result() is None - - matC_ref = torch._C._nn.linear(matA, matB, bias=None) - if myrank in ranks: - assert torch.allclose(matC.get_result(), matC_ref) is True - - -if __name__ == '__main__': - test_physic_generic_op() diff --git a/tests/operator/test_physic_op.py b/tests/operator/test_physic_op.py deleted file mode 100644 index a2472e50..00000000 --- a/tests/operator/test_physic_op.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -cmd for running the test - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - tests/operator/test_physic_op.py -""" - -from cube.device.physic.group import DeviceGroup -from cube.operator.physic.generics import GenericPhysicOp, OpResult -import torch - - -def test_physic_generic_op(): - - myrank = DeviceGroup().rank - ranks = [0, 2] - - op = GenericPhysicOp(torch._C._nn.linear) - assert op.placement is None - - op.placement = ranks - assert op.func[0] is torch._C._nn.linear - assert op.placement == [0, 2] - assert op.execute_flag == (myrank in ranks) - - matA = torch.randn((1024,1024)) - matB = torch.randn((1024,1024)) - matC = op(matA, matB, bias=None) - - assert set(matC.placement) == set(ranks) - if myrank in ranks: - assert torch.is_tensor(matC.get_result()) - else: - assert matC.get_result() is None - - matC_ref = torch._C._nn.linear(matA, matB, bias=None) - if myrank in ranks: - assert torch.allclose(matC.get_result(), matC_ref) is True - - -if __name__ == '__main__': - test_physic_generic_op() From 3d1620050137053e5644259d01538b49fb580049 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 13:55:42 +0800 Subject: [PATCH 0179/1892] code generation --- cube/codegen/codegen.py | 64 ++++++++++++++++++++++++--------- cube/codegen/syntax/blocks.py | 11 +++--- cube/codegen/syntax/symtable.py | 3 +- cube/graph/graph.py | 23 ++++++++---- tests/codegen/test_codegen.py | 39 ++++++++++++++++++++ tests/graph/test_parser.py | 1 - 6 files changed, 111 insertions(+), 30 deletions(-) create mode 100644 tests/codegen/test_codegen.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index de0cfbae..b1ef33da 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -2,6 +2,7 @@ Generate Pytorch code given the model DAG and the transformation config """ +from inspect import Arguments from typing import List, Any from cube.graph import IRGraph, IRTensor, IROperation @@ -19,7 +20,7 @@ def __init__(self, graph: IRGraph): raise TypeError("graph should be IRGraph") self.graph = graph # model full code - self.code: List[str] = list() + self.code: List[str] = ['import torch', '', ''] # module init code self.declare_region: List[str] = list() # module forward code @@ -48,20 +49,27 @@ def gen(self, outfile=None) -> List[str]: self.symbols.create(self.naming(out)) # generate full code - with ClassBlock(class_name='GenModel', derived='torch.nn.Module') as cb: + with ClassBlock(class_name='GenModel', derived=['torch.nn.Module']) as cb: with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.declare_region) + cb.insert_body('') cb.insert_body(ib.code) with FunctionBlock(func_name='forward', args=['self']+fargs) as fb: - fb.insert_body(self.emit_op_call) + fb.insert_body(self.forward_region) + # generate output + out_names = self._forward_region_arg_names(self.graph.outputs()) + return_code = f"return {', '.join(out_names)}" + fb.insert_body(return_code) + cb.insert_body('') cb.insert_body(fb.code) - self.code = cb.code + self.code += cb.code + self.code += [''] # write to file if outfile: - with open(outfile, 'w'): - for line in self.code: - outfile.write(line) + with open(outfile, 'w') as f: + code = '\n'.join(self.code) + f.write(code) return self.code @@ -73,14 +81,16 @@ def emit_var_declare(self, var: Any): name = self.naming(var) # indicate this is a leaf tensor, should be parameter if self.symbols.create(name): - code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(IRTensor.shape)}))' + code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(var.shape)}))' self.declare_region.append(code) elif isinstance(var, str): - name = self.naming(var) - if self.symbols.create(name): - #TODO: add type info - code = f'self.{name} = None' - self.declare_region.append(code) + # TODO: handle var that is not default in nn.Module + pass + # name = self.naming(var) + # if self.symbols.create(name): + # #TODO: add type info + # code = f'{name} = None' + # self.declare_region.append(code) return def emit_op_call(self, node: IROperation): @@ -88,18 +98,38 @@ def emit_op_call(self, node: IROperation): Emit op forward code """ op_code = node.signature - out_region = ', '.join([self.naming(out) for out in node.outputs()]) - arg_region = '(' + ', '.join([self.naming(arg) for arg in node.inputs()]) + ')' - code = f'{out_region} = {op_code}{arg_region}' + out_names = self._forward_region_arg_names(node.outputs()) + out_names = ', '.join(out_names) + arg_names = self._forward_region_arg_names(node.inputs()) + arg_region = '(' + ', '.join(arg_names) + ')' + code = f'{out_names} = {op_code}{arg_region}' self.forward_region.append(code) + def _forward_region_arg_names(self, args: List[Any]): + """ + Generate arg name list for forward region. + + Will add prefix 'self.' for var defined in declare region + """ + named_args : List[str] = list() + for arg in args: + if isinstance(arg, IRTensor) and arg.is_leaf(): + named_args.append('self.' + self.naming(arg)) + else: + named_args.append(self.naming(arg)) + return named_args + def naming(self, tensor: Any) -> str: """ Return the var name (unique for different variable) + + If the var is a leaf tensor, will add prefix `self.` to its name """ if isinstance(tensor, IRTensor): tensor_name = 'tensor' if tensor.name is None else tensor.name - name = '_'.join([tensor_name, tensor._id]) + if '.' in tensor_name: + tensor_name = tensor_name.split('.')[0] + name = '_'.join([tensor_name, str(tensor._id)]) else: name = str(tensor) return name diff --git a/cube/codegen/syntax/blocks.py b/cube/codegen/syntax/blocks.py index b16bbb8e..51488a4e 100644 --- a/cube/codegen/syntax/blocks.py +++ b/cube/codegen/syntax/blocks.py @@ -13,10 +13,12 @@ def __enter__(self): def insert_body(self, code): if isinstance(code, list): self.code += code - elif type(code) == str: + elif isinstance(code, str): self.code.append(code) else: - raise TypeError + raise TypeError( + f"Get type {type(code)} but expected list[str] or list" + ) def __exit__(self, exc_type, exc_value, exc_tb): # add indent for function block @@ -54,6 +56,5 @@ def __init__(self, class_name, derived=None): if derived: derived = ', '.join(derived) derived = f'({derived})' - title = f'class {self.class_name}{derived}' - super().__init__(self, title) - + title = f'class {self.class_name}{derived}:' + super().__init__(title) diff --git a/cube/codegen/syntax/symtable.py b/cube/codegen/syntax/symtable.py index 3712453e..8732112f 100644 --- a/cube/codegen/syntax/symtable.py +++ b/cube/codegen/syntax/symtable.py @@ -31,7 +31,8 @@ def create(self, var_name: str): return False else: self._varlist.append(var_name) - + return True + def exist(self, var_name: str): """ Check whether a variable exists diff --git a/cube/graph/graph.py b/cube/graph/graph.py index e35e9db1..51295011 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -235,6 +235,12 @@ def add_dst_nodes(self, node: IROperation): raise TypeError("IRTensor destination node should be IROperation") self._dst_nodes.append(IROperation) + def is_leaf(self): + """ + Check if it is a leaf tensor (parameter) + """ + return self.src() is None + def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape})' return dscp @@ -262,15 +268,20 @@ def add_node(self, node: IROperation): raise TypeError("Expected node to be IROperation") self._nodes.append(node) - def nodes(self, index: Optional[int]): + def nodes(self, index: Optional[int] = None): """ Get node at position index """ - if index >= len(self._nodes): - raise RuntimeError( - f"Get node out of range ({index} >= {len(self._nodes)})" - ) - return self._nodes[index] + if isinstance(index, int): + if index >= len(self._nodes): + raise RuntimeError( + f"Get node out of range ({index} >= {len(self._nodes)})" + ) + return self._nodes[index] + elif index is None: + return self._nodes + else: + raise TypeError("Expected index to be None or int") def inputs(self, index: Optional[int] = None): if isinstance(index, int): diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py new file mode 100644 index 00000000..265b5640 --- /dev/null +++ b/tests/codegen/test_codegen.py @@ -0,0 +1,39 @@ +import cube.graph as cgraph +from cube.codegen.codegen import SScheduleCodeGen + +import torch +from torch import nn + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=16, classes=1000): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult, bias=False) + self.gelu = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim * mult, dim) + self.classifier = nn.Linear(dim, classes) + + def forward(self, data, x: int = 4): + output = self.linear1(data) + output = self.gelu(output) + output = self.dropout(output) + output = output + data + output = self.linear2(output) + output = self.classifier(output) + return output + + +model = FeedForward(dim=1024) + + +def test_codegen(model): + graph = cgraph.convert(model, + input_shapes=([1024,1024],[1,])) + gener = SScheduleCodeGen(graph) + gener.gen(outfile='code.py') + + +if __name__ == '__main__': + + test_codegen(model) \ No newline at end of file diff --git a/tests/graph/test_parser.py b/tests/graph/test_parser.py index b67817a2..50329005 100644 --- a/tests/graph/test_parser.py +++ b/tests/graph/test_parser.py @@ -24,7 +24,6 @@ def forward(self, data, x: int = 4): return output model = FeedForward(dim=1024) -smodule = torch.jit.script(model) def test_flatten(smodule): From c4894b1517e18a9ec0173ecd0f168f2059d27dbc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 14:13:43 +0800 Subject: [PATCH 0180/1892] fix self. var naming --- cube/codegen/codegen.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index b1ef33da..114db6cc 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,14 +1,14 @@ """ Generate Pytorch code given the model DAG and the transformation config """ - -from inspect import Arguments from typing import List, Any from cube.graph import IRGraph, IRTensor, IROperation from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock +import torch + class SScheduleCodeGen: """ @@ -27,6 +27,8 @@ def __init__(self, graph: IRGraph): self.forward_region: List[str] = list() # module member name self.symbols = SymbolTable() + # ref module to check shared variables + self._ref_module = torch.nn.Module() def gen(self, outfile=None) -> List[str]: """ @@ -84,13 +86,13 @@ def emit_var_declare(self, var: Any): code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(var.shape)}))' self.declare_region.append(code) elif isinstance(var, str): - # TODO: handle var that is not default in nn.Module - pass - # name = self.naming(var) - # if self.symbols.create(name): - # #TODO: add type info - # code = f'{name} = None' - # self.declare_region.append(code) + name = self.naming(var) + if name.startswith('self.'): + if not hasattr(self._ref_module, var): + if self.symbols.create(name): + #TODO: add default value + code = f'{name} = None' + self.declare_region.append(code) return def emit_op_call(self, node: IROperation): @@ -112,8 +114,11 @@ def _forward_region_arg_names(self, args: List[Any]): Will add prefix 'self.' for var defined in declare region """ named_args : List[str] = list() + input_name = [self.naming(input) for input in self.graph.inputs()] for arg in args: - if isinstance(arg, IRTensor) and arg.is_leaf(): + name = self.naming(arg) + if isinstance(arg, IRTensor) and \ + arg.is_leaf() and (name not in input_name): named_args.append('self.' + self.naming(arg)) else: named_args.append(self.naming(arg)) From 61e37df88217563827c1c281ff3a1cdbc47dfa4b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 14:26:28 +0800 Subject: [PATCH 0181/1892] fix shared var --- cube/codegen/codegen.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 114db6cc..be25ad02 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -67,13 +67,12 @@ def gen(self, outfile=None) -> List[str]: self.code += cb.code self.code += [''] + code = '\n'.join(self.code) # write to file if outfile: with open(outfile, 'w') as f: - code = '\n'.join(self.code) f.write(code) - - return self.code + return code def emit_var_declare(self, var: Any): """ @@ -88,7 +87,7 @@ def emit_var_declare(self, var: Any): elif isinstance(var, str): name = self.naming(var) if name.startswith('self.'): - if not hasattr(self._ref_module, var): + if not hasattr(self._ref_module, var[5:]): if self.symbols.create(name): #TODO: add default value code = f'{name} = None' From d3b41df34e74ba7bd5cc95eba7e1ca4118cd13ed Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 14:44:21 +0800 Subject: [PATCH 0182/1892] fix super() init, now gen code is runnable --- cube/codegen/syntax/blocks.py | 15 ++++++++++++++- tests/codegen/test_codegen.py | 17 ++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/cube/codegen/syntax/blocks.py b/cube/codegen/syntax/blocks.py index 51488a4e..12459f90 100644 --- a/cube/codegen/syntax/blocks.py +++ b/cube/codegen/syntax/blocks.py @@ -31,9 +31,12 @@ def __exit__(self, exc_type, exc_value, exc_tb): class FunctionBlock(Block): """ Create a function block with function definition + + If class has derived class, then require the derived classes + has no argument for __init__ """ - def __init__(self, func_name: str, args: List[str]): + def __init__(self, func_name: str, args: List[str], derived=True): if not isinstance(func_name, str): raise TypeError("Expected func_name to be str") if not isinstance(args, list): @@ -43,9 +46,19 @@ def __init__(self, func_name: str, args: List[str]): args = ', '.join(args) title = f'def {self.func_name}({args}):' super().__init__(title) + self.derived = derived + + def __enter__(self): + # assume no argument for initialize super class + if self.derived and self.func_name == '__init__': + self.insert_body('super().__init__()') + return self class ClassBlock(Block): + """ + Class definition. + """ def __init__(self, class_name, derived=None): if not isinstance(class_name, str): diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index 265b5640..e02df3f3 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -27,11 +27,26 @@ def forward(self, data, x: int = 4): model = FeedForward(dim=1024) +def import_from_file(filename): + print(f'> loading GenModel from {filename} ...') + import importlib.util + spec = importlib.util.spec_from_file_location("GenModel", filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.GenModel + + def test_codegen(model): graph = cgraph.convert(model, input_shapes=([1024,1024],[1,])) gener = SScheduleCodeGen(graph) - gener.gen(outfile='code.py') + code = gener.gen(outfile='code.py') + + # execute + print(code) + GenModel = import_from_file('code.py') + model = GenModel() + print(model) if __name__ == '__main__': From c119784704a9a60ad8b6c095d64c0ac1fc5e758c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 15:41:55 +0800 Subject: [PATCH 0183/1892] trainable example --- tests/codegen/test_codegen.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index e02df3f3..8133ecc2 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -36,6 +36,12 @@ def import_from_file(filename): return module.GenModel +def init_weight(parameters): + for param in parameters: + with torch.no_grad(): + torch.nn.init.uniform_(param) + + def test_codegen(model): graph = cgraph.convert(model, input_shapes=([1024,1024],[1,])) @@ -43,10 +49,26 @@ def test_codegen(model): code = gener.gen(outfile='code.py') # execute + print("> ===== Generated Code =====") print(code) + print("< ===== Generated Code =====") + GenModel = import_from_file('code.py') - model = GenModel() - print(model) + model = GenModel().cuda() + + init_weight(model.parameters()) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + print("> training 10 iterations...") + + for _ in range(10): + data = torch.randn([64,1024], device=torch.device('cuda:0')) + out = model(data, 0) + loss = torch.mean(out) / 1000 + print(f'> loss: {loss.item()}') + loss.backward() + optimizer.step() + optimizer.zero_grad() if __name__ == '__main__': From 6b5c0a7a6798e7950298c4c69211d0e37a3eea81 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 16:22:29 +0800 Subject: [PATCH 0184/1892] fix aten node kwargs (omit them) --- cube/graph/parser.py | 30 +++++++++++++++++++++--------- tests/codegen/test_codegen.py | 2 +- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/cube/graph/parser.py b/cube/graph/parser.py index c11b70d5..983c04eb 100644 --- a/cube/graph/parser.py +++ b/cube/graph/parser.py @@ -143,24 +143,36 @@ def parse_aten_node(node, module, frame: Frame) -> List[IROperation]: inputs = [input for input in node.inputs()] outputs = [output for output in node.outputs()] + # handle inputs: + # TODO: fix omitted kwargs + # We will omit arg index >= 2 as we assume the + # tensor op at most gets 2 tensor, others are kwargs + input_val = list() + maybe_kwarg = len(inputs) > 2 + for reverse_index, input in enumerate(inputs[::-1]): + var_name = input.debugName() + val = frame.get_var(var_name) + index = len(inputs) - 1 - reverse_index + if maybe_kwarg and (not isinstance(val, IRTensor)) and index > 1: + continue + else: + input_val.append(val) + maybe_kwarg = False + input_val = input_val[::-1] + if len(input_val) < len(inputs): + print(f"Warning: some non-tensor arguments are ommited in {fsig}") + # create IR node ir_node = IROperation( signature = fsig, name = fsig, - input_length = len(inputs), + input_length = len(input_val), output_length = len(outputs) ) - - # handle inputs - inputs = [input for input in node.inputs()] - # in stack with reverse order - for index, input in enumerate(inputs): - var_name = input.debugName() - val = frame.get_var(var_name) + for index, val in enumerate(input_val): ir_node.set_input(index, val) # handle outputs - outputs = [output for output in node.outputs()] for index, output in enumerate(outputs): frame.add_var(output.debugName(), ir_node.outputs(index)) diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index 8133ecc2..c096281e 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -18,8 +18,8 @@ def forward(self, data, x: int = 4): output = self.linear1(data) output = self.gelu(output) output = self.dropout(output) - output = output + data output = self.linear2(output) + output = output + data output = self.classifier(output) return output From d9c37cbb23315319fe3ce450163794b058f70b8f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Sep 2021 19:39:59 +0800 Subject: [PATCH 0185/1892] restructure code --- cube/graph/__init__.py | 2 +- cube/graph/graph.py | 6 ++++++ cube/graph/parser/__init__.py | 2 ++ cube/graph/{ => parser}/converter.py | 0 cube/graph/{ => parser}/frame.py | 0 cube/graph/{ => parser}/parser.py | 2 +- tests/codegen/test_codegen.py | 2 +- tests/graph/test_parser.py | 6 +++--- 8 files changed, 14 insertions(+), 6 deletions(-) create mode 100644 cube/graph/parser/__init__.py rename cube/graph/{ => parser}/converter.py (100%) rename cube/graph/{ => parser}/frame.py (100%) rename cube/graph/{ => parser}/parser.py (99%) diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index 5fdefdc7..8704f201 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -1,3 +1,3 @@ from cube.graph.graph import IRGraph, IRTensor, IROperation -from cube.graph.converter import convert +from cube.graph import parser diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 51295011..fbf6a6f4 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -321,6 +321,12 @@ def replace(self, target: IROperation, nodes: List[IROperation]): """ raise NotImplementedError + def forward(self, *args, **kwargs): + raise NotImplementedError + + def __call__(self, *args, **kwargs): + raise NotImplementedError + def __repr__(self): dscp = '' # inputs diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py new file mode 100644 index 00000000..9d11b160 --- /dev/null +++ b/cube/graph/parser/__init__.py @@ -0,0 +1,2 @@ +from cube.graph.parser.parser import ScriptModuleParser +from cube.graph.parser.converter import convert \ No newline at end of file diff --git a/cube/graph/converter.py b/cube/graph/parser/converter.py similarity index 100% rename from cube/graph/converter.py rename to cube/graph/parser/converter.py diff --git a/cube/graph/frame.py b/cube/graph/parser/frame.py similarity index 100% rename from cube/graph/frame.py rename to cube/graph/parser/frame.py diff --git a/cube/graph/parser.py b/cube/graph/parser/parser.py similarity index 99% rename from cube/graph/parser.py rename to cube/graph/parser/parser.py index 983c04eb..1fa0013c 100644 --- a/cube/graph/parser.py +++ b/cube/graph/parser/parser.py @@ -4,7 +4,7 @@ from typing import List, Tuple, Optional from cube.graph import IROperation, IRTensor -from cube.graph.frame import Frame +from cube.graph.parser.frame import Frame class ScriptNodeKind(enum.Enum): PrimGetAttr = 1 diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index c096281e..16133af0 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -43,7 +43,7 @@ def init_weight(parameters): def test_codegen(model): - graph = cgraph.convert(model, + graph = cgraph.parser.convert(model, input_shapes=([1024,1024],[1,])) gener = SScheduleCodeGen(graph) code = gener.gen(outfile='code.py') diff --git a/tests/graph/test_parser.py b/tests/graph/test_parser.py index 50329005..774135e7 100644 --- a/tests/graph/test_parser.py +++ b/tests/graph/test_parser.py @@ -1,5 +1,5 @@ -from cube.graph.parser import ScriptModuleParser -import torch +from cube.graph import parser +import cube.graph.parser as parser from torch import nn import cube.graph as cgraph @@ -30,7 +30,7 @@ def test_flatten(smodule): ScriptModuleParser.flatten(smodule) def test_parse_module(model): - return cgraph.convert(model, input_shapes=([1024,1024],[1,])) + return parser.convert(model, input_shapes=([1024,1024],[1,])) if __name__ == '__main__': From 4b8472f494407bcf9609917523c1ce224f98fb29 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 18 Sep 2021 10:43:13 +0800 Subject: [PATCH 0186/1892] add sequence model --- cube/tschedule/sequence.py | 154 +++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 cube/tschedule/sequence.py diff --git a/cube/tschedule/sequence.py b/cube/tschedule/sequence.py new file mode 100644 index 00000000..cb680a77 --- /dev/null +++ b/cube/tschedule/sequence.py @@ -0,0 +1,154 @@ +from typing import List, Tuple, NewType +import numpy as np + +from cube.tschedule.action import Action + + +class ASequence: + + def __init__(self, actions: List[Action]): + + if not all([isinstance(action, Action) for action in actions]): + raise TypeError("Expected a list of Actions") + + self.sequence = actions + + def __iter__(self): + """ + Iterate on the actions + """ + return self.sequence + + def __len__(self) -> int: + """ + Get number of action in the sequence + """ + return len(self.sequence) + + def append(self, action: Action): + if not isinstance(action, Action): + raise TypeError("Expected an action") + self.sequence.append(action) + + def pop(self) -> Action: + """ + Pop the last action and return + """ + if len(self.sequence) == 0: + return None + return self.sequence.pop() + + def is_correct(self): + """ + Check whether sequence + satisfies the sequential consistency model + """ + for index, action in enumerate(self.sequence): + for pre_action in action.pre_actions(): + # find the pre-action not appear in sequence + if not pre_action in self.sequence: + return False + pre_idx = self.sequence.index(pre_action) + # violate happened before + if pre_idx >= index: + return False + return True + + +# ======= Blow should be moved from this module ======== # + +Relation = NewType('Relation', List[Tuple[Action, Action]]) + + +class ScheduleSpace: + + @staticmethod + def tspace(remain_actions: List[Action], + path_shuffle=True, + relations=None, + seq: ASequence = ASequence(list())): + """ + Iterate on the legal sequence space + """ + if len(remain_actions) == 0: + yield seq + # inital entry + if relations is None: + relations = ScheduleSpace._get_relations(remain_actions) + entry_actions = ScheduleSpace._ready_actions(remain_actions, relations) + entry_actions = np.array(entry_actions) + + # recursive search + if path_shuffle: + np.random.shuffle(entry_actions) + for aid, action in enumerate(entry_actions): + if len(seq) == 0: + print(f'> search progress: [{aid+1}/{len(entry_actions)}]...') + seq.append(action) + action_idx = remain_actions.index(action) + sub_actions = remain_actions[:action_idx] + remain_actions[action_idx+1:] + sub_relations = ScheduleSpace._remove_action(action, relations) + for res in ScheduleSpace.space(sub_actions, path_shuffle, sub_relations, seq): + yield res + seq.pop() + + + @staticmethod + def sspace(actions: List[Action], ndevice: int, path_shuffle=True, depth=0): + """ + Iterate on the possible action space + """ + if depth == len(actions): + yield actions + return + action = actions[depth] + device_choice = np.array(list(range(ndevice)), dtype=np.int) + if path_shuffle: + np.random.shuffle(device_choice) + for device in device_choice: + action.device = device + for res in ScheduleSpace.sspace(actions, ndevice, path_shuffle, depth+1): + yield res + + + @staticmethod + def _ready_actions(actions: List[Action], sub_relations: Relation) -> List[Action]: + """ + Get ready to emit actions based on sub_relations + """ + ready_actions = list() + for action in actions: + satisfy = True + for (_, succ) in sub_relations: + if succ == action: + satisfy = False + break + if satisfy: + ready_actions.append(action) + return ready_actions + + + @staticmethod + def _get_relations(actions: List[Action]) -> Relation: + """ + Get relation tuples (Action1 -> Action2) + """ + relations = list() + for action in actions: + relation = [(pre_action, action) for pre_action in action.pre_actions()] + if len(relation) != 0: + relations += relation + return relations + + + @staticmethod + def _remove_action(target: Action, relations: Relation) -> Relation: + """ + Remove the target action from relation set + """ + sub_relations = list() + for (pre, succ) in relations: + if pre == target or succ == target: + continue + sub_relations.append((pre, succ)) + return sub_relations From b2dfe9e03848b2e9d7a901e5acba171924275939 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 18 Sep 2021 15:00:59 +0800 Subject: [PATCH 0187/1892] graph forward to actions --- cube/graph/graph.py | 76 ++++++++++++++++++++++++++++--- cube/tschedule/__init__.py | 2 +- cube/tschedule/action.py | 13 +++--- cube/tschedule/pool.py | 53 +++++++++++++++++++++ tests/tschedule/test_tschedule.py | 52 +++++++++++++++++++++ 5 files changed, 183 insertions(+), 13 deletions(-) create mode 100644 cube/tschedule/pool.py create mode 100644 tests/tschedule/test_tschedule.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index fbf6a6f4..5101a5c8 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -1,7 +1,11 @@ from cube.graph.unique import IDGenerator from cube.graph.mapping import IR2LogicOp -from typing import List, Optional, Any +from cube.tschedule.action import Action +from cube.tschedule.pool import TSchedulePool + +from typing import Union, Tuple, List, Optional, Any +import copy __all__ = ['IROperation', 'IRTensor', 'IRGraph'] @@ -32,7 +36,8 @@ def __init__(self, # op signature and op class self.signature: str = signature - self.op = IR2LogicOp.map(self.signature) + self.semantic = IR2LogicOp.map(self.signature) + self.device = None # edge (dataflow info) self._inputs: List[IRTensor] = [None] * input_length @@ -174,7 +179,7 @@ def infer_shape(self): else: shapes.append([1,]) shapes = tuple(shapes) - out_shapes = self.op.shape_infer(*shapes) + out_shapes = self.semantic.shape_infer(*shapes) if len(out_shapes) != len(self._outputs): raise RuntimeError( "The logical op semantic doesn't match with parsed op" @@ -204,6 +209,18 @@ def __init__(self, shape=None, name=None): self._src_nodes: IROperation = None # -> output of the node self._dst_nodes: List[IROperation] = list() # -> input of the nodes + def __copy__(self): + """ + Copy the tensor that will be same except a new id + """ + tensor = IRTensor(self._shape, self.name) + tensor.device = self.device + tensor._id = IDGenerator().gen_tensor_id() + return tensor + + def __deepcopy__(self): + raise RuntimeError("DeepCopy is not allowed.") + @property def shape(self): return self._shape @@ -241,6 +258,9 @@ def is_leaf(self): """ return self.src() is None + def backward(self, tensors = None): + raise NotImplementedError + def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape})' return dscp @@ -321,11 +341,55 @@ def replace(self, target: IROperation, nodes: List[IROperation]): """ raise NotImplementedError - def forward(self, *args, **kwargs): - raise NotImplementedError + def forward(self, *args, **kwargs) -> Union[IRTensor, Tuple[IRTensor]]: + """ + forward will divide the graph into Actions according to + node device assignment + + Currently each forward call will result in a new flow + even if the input is same + + Returns: + List[Action] + """ + if len(self._outputs) == 1: + return copy.copy(self._outputs[0]) + else: + return tuple([copy.copy(output) for output in self._outputs]) def __call__(self, *args, **kwargs): - raise NotImplementedError + """ + Register forward action + """ + curr_nodes: List[IROperation] = list() + curr_device = None + + def _wrap_to_action(): + sub_graph = IRGraph( + curr_nodes, self._inputs, self._outputs, + module_name=self.module_name + ) + action = Action(sub_graph, device=curr_device) + action.tag('forward') + return action + + for node in self.nodes(): + device = node.device + if device is None: + raise RuntimeError("All the node should be assigned to devices") + if device != curr_device and curr_device is not None: + # note we still use same input and output to make consistency + action = _wrap_to_action() + # register to schedule space + TSchedulePool().add_action(action) + curr_nodes = list() + curr_device = device + curr_nodes.append(node) + if curr_device is not None: + action = _wrap_to_action() + TSchedulePool().add_action(action) + + return self.forward(*args, **kwargs) def __repr__(self): dscp = '' diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py index 23966a1c..210ed4b8 100644 --- a/cube/tschedule/__init__.py +++ b/cube/tschedule/__init__.py @@ -1 +1 @@ -from cube.tschedule.action import Action \ No newline at end of file +from cube.tschedule.action import Action diff --git a/cube/tschedule/action.py b/cube/tschedule/action.py index 57f80cad..a1d71631 100644 --- a/cube/tschedule/action.py +++ b/cube/tschedule/action.py @@ -1,21 +1,17 @@ from typing import List -from cube.graph import IRGraph - class Action: """ Action represents a (sub-)graph which contains operators on the same device """ - def __init__(self, graph: IRGraph, device: int): + def __init__(self, ir_graph, device: int): - if not isinstance(graph, IRGraph): - raise TypeError("Require graph to be IRGraph") if not isinstance(device, int): raise TypeError("Require device to be int") # set up attributes - self.graph: IRGraph = graph + self.graph = ir_graph self.device: int = device self.name: str = None # dependencies @@ -88,3 +84,8 @@ def _add_pre_action(self, action): if not isinstance(action, Action): raise TypeError("Expected action to be Action") self._successors.append(action) + + def __repr__(self): + dscp = f'Action({self.name}):\n\t{self.graph.outputs()} <- {self.graph.inputs()}' + return dscp + \ No newline at end of file diff --git a/cube/tschedule/pool.py b/cube/tschedule/pool.py new file mode 100644 index 00000000..8b0ba80b --- /dev/null +++ b/cube/tschedule/pool.py @@ -0,0 +1,53 @@ +from typing import List, Callable + +from cube.tschedule.action import Action + + +class TSchedulePool: + + class __TSchedulePool: + + def __init__(self): + + self._actions: List[Action] = list() + + instance = None + + def __init__(self): + if not TSchedulePool.instance: + TSchedulePool.instance = TSchedulePool.__TSchedulePool() + + def __getattr__(self, name): + return getattr(self.instance, name) + + def add_action(self, action: Action): + self.instance._actions.append(action) + + def actions(self): + return self.instance._actions + + def clear(self): + self.instance._actions = list() + + def __repr__(self): + dscp = '\n'.join([repr(action) for action in self._actions]) + return dscp + + +def schedule(fn: Callable, policy=None, *args, **kwargs): + """ + AI Scientist calls like: + + @cube.tschedule.schedule + def train_step(model, optimizer, datas, labels): + for (data, label) in datas: + loss = model(data, label) + loss.backward() + optimizer.step() + optimizer.zero_grad() + ... + for datas, labels in dataloader(): + train_step(model, optimizer, datas, labels) + ... + """ + raise NotImplementedError \ No newline at end of file diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py new file mode 100644 index 00000000..55a3e344 --- /dev/null +++ b/tests/tschedule/test_tschedule.py @@ -0,0 +1,52 @@ +from torch import nn + +import cube.graph.parser as parser +import cube.tschedule as tschedule + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=16, classes=1000): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult, bias=False) + self.gelu = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim * mult, dim) + self.classifier = nn.Linear(dim, classes) + + def forward(self, data, label: int = 4): + output = self.linear1(data) + output = self.gelu(output) + output = self.dropout(output) + output = self.linear2(output) + output = output + data + output = self.classifier(output) + return output + + +model = FeedForward(dim=1024) +ir_graph = parser.convert(model, input_shapes=([64,1024],[64,])) + +# device assignment +for node in ir_graph.nodes(): + node.device = 0 + + +def test_graph_forward(ir_graph): + + tensor1 = ir_graph() + print(tensor1) + print(tschedule.pool.TSchedulePool()) + tensor2 = ir_graph() + print(tensor2) + print(tschedule.pool.TSchedulePool()) + + +def test_graph_backward(ir_graph): + + tensor = ir_graph() + tensor.backward() + + +if __name__ == '__main__': + + test_graph_forward(ir_graph) From 549d4d25dbbd8eba7a13f97da2857a4f72643de6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 18 Sep 2021 15:19:56 +0800 Subject: [PATCH 0188/1892] re-structure code --- cube/graph/__init__.py | 2 +- cube/graph/ir_graph.py | 150 +++++++++++++++++++++++++++ cube/graph/{graph.py => ir_opten.py} | 150 +-------------------------- cube/graph/parser/converter.py | 2 +- tests/graph/test_parser.py | 13 ++- 5 files changed, 161 insertions(+), 156 deletions(-) create mode 100644 cube/graph/ir_graph.py rename cube/graph/{graph.py => ir_opten.py} (63%) diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index 8704f201..86899d86 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -1,3 +1,3 @@ -from cube.graph.graph import IRGraph, IRTensor, IROperation +from cube.graph.ir_graph import IRGraph, IRTensor, IROperation from cube.graph import parser diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py new file mode 100644 index 00000000..cc62a5ee --- /dev/null +++ b/cube/graph/ir_graph.py @@ -0,0 +1,150 @@ +from cube.graph.ir_opten import IROperation, IRTensor +from cube.tschedule.action import Action +from cube.tschedule.pool import TSchedulePool + +from typing import Union, Tuple, List, Optional +import copy + + +__all__ = ['IRGraph'] + + +class IRGraph: + """ + PyTorch IR Graph + + The IR Graph only contains forward graph + """ + + def __init__(self, + nodes: List[IROperation], + input_tensors: List[IRTensor], + output_tensors: List[IRTensor], + module_name: str): + self.module_name = module_name + self._nodes: List[IROperation] = nodes + self._inputs = input_tensors + self._outputs = output_tensors + + def add_node(self, node: IROperation): + if not isinstance(node, IROperation): + raise TypeError("Expected node to be IROperation") + self._nodes.append(node) + + def nodes(self, index: Optional[int] = None): + """ + Get node at position index + """ + if isinstance(index, int): + if index >= len(self._nodes): + raise RuntimeError( + f"Get node out of range ({index} >= {len(self._nodes)})" + ) + return self._nodes[index] + elif index is None: + return self._nodes + else: + raise TypeError("Expected index to be None or int") + + def inputs(self, index: Optional[int] = None): + if isinstance(index, int): + if index >= len(self._inputs): + raise RuntimeError( + f"Get the input out of range ({index} >= {len(self._inputs)}" + ) + return self._inputs[index] + elif index is None: + return self._inputs + else: + raise TypeError("Expected index to be None or int") + + def outputs(self, index: Optional[int] = None): + """ + Get output tensor at output index + + Args: + index (int or None): + index of the outputs, None will return the nodes + for all the outputs + """ + if isinstance(index, int): + if index >= len(self._outputs): + raise RuntimeError( + f"Get the output out of range ({index} >= {len(self._outputs)}" + ) + return self._outputs[index] + elif index is None: + return self._outputs + else: + raise TypeError("Expected index to be None or int") + + def replace(self, target: IROperation, nodes: List[IROperation]): + """ + Replace the node with new nodes (IRGraph) + """ + raise NotImplementedError + + def forward(self, *args, **kwargs) -> Union[IRTensor, Tuple[IRTensor]]: + """ + forward will divide the graph into Actions according to + node device assignment + + Currently each forward call will result in a new flow + even if the input is same + + Returns: + List[Action] + """ + if len(self._outputs) == 1: + return copy.copy(self._outputs[0]) + else: + return tuple([copy.copy(output) for output in self._outputs]) + + def __call__(self, *args, **kwargs): + """ + Register forward action + """ + curr_nodes: List[IROperation] = list() + curr_device = None + + def _wrap_to_action(): + sub_graph = IRGraph( + curr_nodes, self._inputs, self._outputs, + module_name=self.module_name + ) + action = Action(sub_graph, device=curr_device) + action.tag('forward') + return action + + for node in self.nodes(): + device = node.device + if device is None: + raise RuntimeError("All the node should be assigned to devices") + if device != curr_device and curr_device is not None: + # note we still use same input and output to make consistency + action = _wrap_to_action() + # register to schedule space + TSchedulePool().add_action(action) + curr_nodes = list() + curr_device = device + curr_nodes.append(node) + if curr_device is not None: + action = _wrap_to_action() + TSchedulePool().add_action(action) + + return self.forward(*args, **kwargs) + + def __repr__(self): + dscp = '' + # inputs + dscp += f'Inputs: {self._inputs}\n' + # nodes + for node in self._nodes: + succ_node_ids = [None] * len(node.outputs()) + for out_idx, node_list in enumerate(node.successors()): + node_list = [snode._id for snode in node_list] + succ_node_ids[out_idx] = node_list + dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" + # outputs + dscp += f'\nOutputs: {self._outputs}' + return dscp diff --git a/cube/graph/graph.py b/cube/graph/ir_opten.py similarity index 63% rename from cube/graph/graph.py rename to cube/graph/ir_opten.py index 5101a5c8..5741cb1d 100644 --- a/cube/graph/graph.py +++ b/cube/graph/ir_opten.py @@ -1,14 +1,10 @@ from cube.graph.unique import IDGenerator from cube.graph.mapping import IR2LogicOp -from cube.tschedule.action import Action -from cube.tschedule.pool import TSchedulePool +from typing import List, Optional, Any -from typing import Union, Tuple, List, Optional, Any -import copy - -__all__ = ['IROperation', 'IRTensor', 'IRGraph'] +__all__ = ['IROperation', 'IRTensor'] class IROperation: @@ -250,7 +246,7 @@ def set_src_node(self, node: IROperation): def add_dst_nodes(self, node: IROperation): if not isinstance(node, IROperation): raise TypeError("IRTensor destination node should be IROperation") - self._dst_nodes.append(IROperation) + self._dst_nodes.append(node) def is_leaf(self): """ @@ -265,143 +261,3 @@ def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape})' return dscp - -class IRGraph: - """ - PyTorch IR Graph - - The IR Graph only contains forward graph - """ - - def __init__(self, - nodes: List[IROperation], - input_tensors: List[IRTensor], - output_tensors: List[IRTensor], - module_name: str): - self.module_name = module_name - self._nodes: List[IROperation] = nodes - self._inputs = input_tensors - self._outputs = output_tensors - - def add_node(self, node: IROperation): - if not isinstance(node, IROperation): - raise TypeError("Expected node to be IROperation") - self._nodes.append(node) - - def nodes(self, index: Optional[int] = None): - """ - Get node at position index - """ - if isinstance(index, int): - if index >= len(self._nodes): - raise RuntimeError( - f"Get node out of range ({index} >= {len(self._nodes)})" - ) - return self._nodes[index] - elif index is None: - return self._nodes - else: - raise TypeError("Expected index to be None or int") - - def inputs(self, index: Optional[int] = None): - if isinstance(index, int): - if index >= len(self._inputs): - raise RuntimeError( - f"Get the input out of range ({index} >= {len(self._inputs)}" - ) - return self._inputs[index] - elif index is None: - return self._inputs - else: - raise TypeError("Expected index to be None or int") - - def outputs(self, index: Optional[int] = None): - """ - Get output tensor at output index - - Args: - index (int or None): - index of the outputs, None will return the nodes - for all the outputs - """ - if isinstance(index, int): - if index >= len(self._outputs): - raise RuntimeError( - f"Get the output out of range ({index} >= {len(self._outputs)}" - ) - return self._outputs[index] - elif index is None: - return self._outputs - else: - raise TypeError("Expected index to be None or int") - - def replace(self, target: IROperation, nodes: List[IROperation]): - """ - Replace the node with new nodes (IRGraph) - """ - raise NotImplementedError - - def forward(self, *args, **kwargs) -> Union[IRTensor, Tuple[IRTensor]]: - """ - forward will divide the graph into Actions according to - node device assignment - - Currently each forward call will result in a new flow - even if the input is same - - Returns: - List[Action] - """ - if len(self._outputs) == 1: - return copy.copy(self._outputs[0]) - else: - return tuple([copy.copy(output) for output in self._outputs]) - - def __call__(self, *args, **kwargs): - """ - Register forward action - """ - curr_nodes: List[IROperation] = list() - curr_device = None - - def _wrap_to_action(): - sub_graph = IRGraph( - curr_nodes, self._inputs, self._outputs, - module_name=self.module_name - ) - action = Action(sub_graph, device=curr_device) - action.tag('forward') - return action - - for node in self.nodes(): - device = node.device - if device is None: - raise RuntimeError("All the node should be assigned to devices") - if device != curr_device and curr_device is not None: - # note we still use same input and output to make consistency - action = _wrap_to_action() - # register to schedule space - TSchedulePool().add_action(action) - curr_nodes = list() - curr_device = device - curr_nodes.append(node) - if curr_device is not None: - action = _wrap_to_action() - TSchedulePool().add_action(action) - - return self.forward(*args, **kwargs) - - def __repr__(self): - dscp = '' - # inputs - dscp += f'Inputs: {self._inputs}\n' - # nodes - for node in self._nodes: - succ_node_ids = [None] * len(node.outputs()) - for out_idx, node_list in enumerate(node.successors()): - node_list = [snode._id for snode in node_list] - succ_node_ids[out_idx] = node_list - dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" - # outputs - dscp += f'\nOutputs: {self._outputs}' - return dscp diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index fa5d212c..31c557b7 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,7 +1,7 @@ from typing import Optional, List from cube.graph.parser import ScriptModuleParser -from cube.graph.graph import IRGraph +from cube.graph import IRGraph import torch diff --git a/tests/graph/test_parser.py b/tests/graph/test_parser.py index 774135e7..9ee2ef93 100644 --- a/tests/graph/test_parser.py +++ b/tests/graph/test_parser.py @@ -1,10 +1,9 @@ -from cube.graph import parser import cube.graph.parser as parser -from torch import nn - -import cube.graph as cgraph from cube.graph.parser import ScriptModuleParser +import torch +from torch import nn + class FeedForward(nn.Module): def __init__(self, dim, dropout=0., mult=16, classes=1000): super().__init__() @@ -26,7 +25,8 @@ def forward(self, data, x: int = 4): model = FeedForward(dim=1024) -def test_flatten(smodule): +def test_flatten(model): + smodule = torch.jit.script(model) ScriptModuleParser.flatten(smodule) def test_parse_module(model): @@ -34,8 +34,7 @@ def test_parse_module(model): if __name__ == '__main__': - - # test_flatten(smodule) + # test_flatten(model) graph = test_parse_module(model) print(graph) \ No newline at end of file From 14fde4eb5c6dce95f410df051e09d4ff4fee0034 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 21 Sep 2021 10:43:59 +0800 Subject: [PATCH 0189/1892] add local graph --- cube/codegen/codegen.py | 6 +-- cube/graph/__init__.py | 3 +- cube/graph/ir_graph.py | 49 +++++++++++++++++- cube/graph/ir_opten.py | 98 ++++++++++++++++++++++++++++++++--- tests/codegen/test_codegen.py | 10 ++-- 5 files changed, 150 insertions(+), 16 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index be25ad02..8110b54f 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -3,7 +3,7 @@ """ from typing import List, Any -from cube.graph import IRGraph, IRTensor, IROperation +from cube.graph import IRGraph, IRLocalGraph, IRTensor, IROperation from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -15,10 +15,10 @@ class SScheduleCodeGen: Generate spatial code for the model """ - def __init__(self, graph: IRGraph): + def __init__(self, graph: IRGraph, device: int): if not isinstance(graph, IRGraph): raise TypeError("graph should be IRGraph") - self.graph = graph + self.graph = IRLocalGraph(graph, device=device) # model full code self.code: List[str] = ['import torch', '', ''] # module init code diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index 86899d86..924c2f74 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -1,3 +1,4 @@ -from cube.graph.ir_graph import IRGraph, IRTensor, IROperation +from cube.graph.ir_graph import IRGraph, IRLocalGraph +from cube.graph.ir_opten import IRTensor, IROperation, OperationType from cube.graph import parser diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index cc62a5ee..78c7e06c 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -6,7 +6,7 @@ import copy -__all__ = ['IRGraph'] +__all__ = ['IRGraph', 'IRLocalGraph'] class IRGraph: @@ -148,3 +148,50 @@ def __repr__(self): # outputs dscp += f'\nOutputs: {self._outputs}' return dscp + + +class IRLocalGraph(IRGraph): + + def __init__(self, graph: IRGraph, device: int): + + if not isinstance(graph, IRGraph): + raise TypeError(f"Expected graph: IRGraph but go {type(graph)}") + if not isinstance(device, int): + raise TypeError(f"Expected device: int but not {type(device)}") + + self.global_graph = graph + self.device = device + self.send_tensors = list() + self.recv_tensors = list() + # get nodes belong to this graph + nodes = list() + all_tensors = set() + for node in self.global_graph.nodes(): + # collect on device node, inputs and outputs + if node.on_device(self.device): + nodes.append(node) + # collect send tensors and recv tensors + if node.semantic == 'move': + if device in node.inputs(0).device: + self.send_tensors.append(node.inputs(0)) + if device in node.outputs(0).device: + self.recv_tensors.append(node.outputs(0)) + all_tensors.update(node.inputs()) + all_tensors.update(node.outputs()) + + # model inputs and outputs + model_inputs = list() + model_outputs = list() + for input in self.global_graph.inputs(): + if input in all_tensors: + model_inputs.append(input) + for output in self.global_graph.outputs(): + if output in all_tensors: + model_outputs.append(output) + + super().__init__( + nodes, + model_inputs + self.recv_tensors, # input tensors + model_outputs + self.send_tensors, # output tensors + self.global_graph.module_name + f'Rank{self.device}' + ) diff --git a/cube/graph/ir_opten.py b/cube/graph/ir_opten.py index 5741cb1d..e7bc366e 100644 --- a/cube/graph/ir_opten.py +++ b/cube/graph/ir_opten.py @@ -1,10 +1,17 @@ from cube.graph.unique import IDGenerator from cube.graph.mapping import IR2LogicOp -from typing import List, Optional, Any +from enum import Enum +from typing import List, Optional, Any, Union -__all__ = ['IROperation', 'IRTensor'] +__all__ = ['OperationType', 'IROperation', 'IRTensor'] + + +class OperationType(Enum): + + Comp = 1 # computation + Comm = 2 # communication class IROperation: @@ -16,7 +23,8 @@ def __init__(self, name: str, signature: str, input_length: int, - output_length: int): + output_length: int, + type=OperationType.Comp): """ Create a node with name (variable name) and module type (module_name) @@ -33,8 +41,9 @@ def __init__(self, # op signature and op class self.signature: str = signature self.semantic = IR2LogicOp.map(self.signature) - self.device = None - + self._type = type + self._device = None + # edge (dataflow info) self._inputs: List[IRTensor] = [None] * input_length self._predecessors: List[IROperation] = [None] * input_length @@ -44,6 +53,60 @@ def __init__(self, tensor.set_src_node(self) self._successors: List[List(IROperation)] = [list() for _ in range(output_length)] + @property + def type(self) -> OperationType: + return self._type + + @type.setter + def type(self, _): + raise RuntimeError("Not allowed to set type except initialization") + + @property + def device(self): + return self._device + + @device.setter + def device(self, device_id: Union[int, List[int]]): + """ + Set the operation device. + + For computation operators, they are only allowed + to happen on one device (int) + + For communication operators (e.g., move, all-reduce), + they are allowed to happend on multiple devices + """ + if self.type == OperationType.Comp: + if not isinstance(device_id, int): + raise KeyError("Require computation op device: int") + elif self.type == OperationType.Comm: + if not isinstance(device_id, list): + raise TypeError("Require communication op device: List[int]") + else: + raise TypeError("Unknown operator type") + self._device = device_id + for output in self._outputs: + if isinstance(output, IRTensor): + output.device = device_id + + def on_device(self, device_id: int): + """ + Check whether the operation is on device_id + + Returns: + Boolean + """ + if not isinstance(device_id, int): + raise TypeError("Expected device id to be int") + if self._device is None: + return False + if self.type == OperationType.Comm: + return device_id in self.device + elif self.type == OperationType.Comp: + return device_id == self.device + else: + raise RuntimeError("Unkown Operation type") + def inputs(self, index: Optional[int] = None): """ Get input tensor at input index @@ -199,7 +262,7 @@ def __init__(self, shape=None, name=None): self._id: int = IDGenerator().gen_tensor_id() self._shape: Optional(List[int]) = shape self.name = name - self.device = -1 + self._device = list() # connected to IROperation self._src_nodes: IROperation = None # -> output of the node @@ -230,6 +293,24 @@ def shape(self, val): raise RuntimeError("Expected shape to be list[int]") self._shape = val + @property + def device(self) -> List[int]: + return self._device + + @device.setter + def device(self, device_id: Union[int, List[int]]): + """ + Set placement of the tensor + + A tensor can be placed on multiple devices as input + for multiple operations on different devices + """ + if not (isinstance(device_id, int) or isinstance(device_id, list)): + raise TypeError(f"Expected device id to be int or List[int]") + if isinstance(device_id, int): + device_id = [device_id] + self._device = device_id + def src(self) -> Optional[IROperation]: return self._src_nodes @@ -254,7 +335,10 @@ def is_leaf(self): """ return self.src() is None - def backward(self, tensors = None): + def backward(self): + """ + Backward will generate a backward action scheduling pool + """ raise NotImplementedError def __repr__(self): diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index 16133af0..4a360d00 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -14,7 +14,7 @@ def __init__(self, dim, dropout=0., mult=16, classes=1000): self.linear2 = nn.Linear(dim * mult, dim) self.classifier = nn.Linear(dim, classes) - def forward(self, data, x: int = 4): + def forward(self, data): output = self.linear1(data) output = self.gelu(output) output = self.dropout(output) @@ -44,8 +44,10 @@ def init_weight(parameters): def test_codegen(model): graph = cgraph.parser.convert(model, - input_shapes=([1024,1024],[1,])) - gener = SScheduleCodeGen(graph) + input_shapes=([1024,1024],)) + for node in graph.nodes(): + node.device = 0 + gener = SScheduleCodeGen(graph, device=0) code = gener.gen(outfile='code.py') # execute @@ -63,7 +65,7 @@ def test_codegen(model): for _ in range(10): data = torch.randn([64,1024], device=torch.device('cuda:0')) - out = model(data, 0) + out = model(data) loss = torch.mean(out) / 1000 print(f'> loss: {loss.item()}') loss.backward() From 758e6ccb4328dca055e0d1117bb37312d84dd876 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 22 Sep 2021 12:36:38 +0800 Subject: [PATCH 0190/1892] generate forward / backward graph actions --- cube/codegen/codegen.py | 8 +- cube/graph/ir_graph.py | 162 ++++++++++++++++++++----- cube/graph/ir_opten.py | 192 +++++++++++++++++++++++------- cube/graph/mapping.py | 3 + cube/graph/parser/converter.py | 6 +- tests/codegen/test_codegen.py | 9 +- tests/tschedule/test_tschedule.py | 37 +++++- 7 files changed, 331 insertions(+), 86 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 8110b54f..91ad8e80 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -3,7 +3,7 @@ """ from typing import List, Any -from cube.graph import IRGraph, IRLocalGraph, IRTensor, IROperation +from cube.graph import IRLocalGraph, IRTensor, IROperation from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -15,10 +15,10 @@ class SScheduleCodeGen: Generate spatial code for the model """ - def __init__(self, graph: IRGraph, device: int): - if not isinstance(graph, IRGraph): + def __init__(self, graph: IRLocalGraph): + if not isinstance(graph, IRLocalGraph): raise TypeError("graph should be IRGraph") - self.graph = IRLocalGraph(graph, device=device) + self.graph = graph # model full code self.code: List[str] = ['import torch', '', ''] # module init code diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 78c7e06c..437926b5 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -25,6 +25,8 @@ def __init__(self, self._nodes: List[IROperation] = nodes self._inputs = input_tensors self._outputs = output_tensors + # default is forward graph + self.tag = 'forward' def add_node(self, node: IROperation): if not isinstance(node, IROperation): @@ -96,9 +98,15 @@ def forward(self, *args, **kwargs) -> Union[IRTensor, Tuple[IRTensor]]: List[Action] """ if len(self._outputs) == 1: - return copy.copy(self._outputs[0]) + tensor = copy.copy(self._outputs[0]) + tensor.set_forward_graph(self) + return tensor else: - return tuple([copy.copy(output) for output in self._outputs]) + tensors = tuple([copy.copy(output) for output in self._outputs]) + for tensor in tensors: + if isinstance(tensor, IRTensor): + tensor.set_forward_graph(self) + return tensors def __call__(self, *args, **kwargs): """ @@ -108,17 +116,17 @@ def __call__(self, *args, **kwargs): curr_device = None def _wrap_to_action(): - sub_graph = IRGraph( - curr_nodes, self._inputs, self._outputs, - module_name=self.module_name + sub_graph = IRLocalGraph( + curr_nodes, self, device=curr_device[0] #FIXME ) - action = Action(sub_graph, device=curr_device) - action.tag('forward') + action = Action(sub_graph, device=curr_device[0]) #FIXME + action.tag(self.tag) return action for node in self.nodes(): + #FIXME: will fail in multi-branch placement (backward) device = node.device - if device is None: + if len(node.device) == 0: raise RuntimeError("All the node should be assigned to devices") if device != curr_device and curr_device is not None: # note we still use same input and output to make consistency @@ -134,6 +142,78 @@ def _wrap_to_action(): return self.forward(*args, **kwargs) + def backward(self): + """ + Backward will generate a backward action scheduling pool + + Construct a reverse graph of forward and seperate to actions + """ + # travel graph in reverse order + all_tensors = dict() + + def get_tensor_grad(tensor): + if tensor._id not in all_tensors: + new_tensor = copy.deepcopy(tensor) + if tensor.name is None: + new_tensor.name = 'grad' + else: + new_tensor.name = tensor.name + '_grad' + new_tensor._src_nodes = list() + new_tensor._dst_nodes = list() + # reverse op + devices = set() + for node in tensor.dst(): + devices.update(node.device) + new_tensor.device = list(devices) + all_tensors[tensor._id] = new_tensor + return new_tensor + else: + return all_tensors[tensor._id] + + backward_nodes = list() + for fnode in self._nodes[::-1]: + inputs = list() + for input in fnode.outputs(): + if isinstance(input, IRTensor) and input.requires_grad: + tensor = get_tensor_grad(input) + inputs.append(tensor) + else: + inputs.append(None) + outputs = list() + for output in fnode.inputs(): + if isinstance(output, IRTensor) and output.requires_grad: + tensor = get_tensor_grad(output) + outputs.append(tensor) + else: + outputs.append(None) + bp_node = IROperation( + name = fnode.name + '_backward', + signature = fnode.signature, + input_length = len(inputs), + output_length = len(outputs), + type=fnode.type + ) + bp_node.device = fnode.device + print(bp_node) + for idx, arg in enumerate(inputs): + bp_node.set_input(idx, arg) + for idx, arg in enumerate(outputs): + bp_node.set_output(idx, arg) + backward_nodes.append(bp_node) + # none inputs for loss + inputs = list() + # none outputs for loss + outputs = list() + graph = IRGraph( + backward_nodes, + inputs, outputs, + self.module_name + 'Backward' + ) + print(graph) + graph.tag = 'backward' + graph() + + def __repr__(self): dscp = '' # inputs @@ -152,45 +232,67 @@ def __repr__(self): class IRLocalGraph(IRGraph): - def __init__(self, graph: IRGraph, device: int): + def __init__(self, + sub_nodes: List[IROperation], + global_graph: IRGraph, + device: int + ): - if not isinstance(graph, IRGraph): - raise TypeError(f"Expected graph: IRGraph but go {type(graph)}") + if not isinstance(global_graph, IRGraph): + raise TypeError(f"Expected graph: IRGraph but go {type(global_graph)}") if not isinstance(device, int): raise TypeError(f"Expected device: int but not {type(device)}") - - self.global_graph = graph + for node in sub_nodes: + if not node.on_device(device): + raise RuntimeError(f"Local Graph requires all nodes on device {device}") + self.global_graph = global_graph self.device = device self.send_tensors = list() + self.send_devices = list() self.recv_tensors = list() + self.recv_devices = list() # get nodes belong to this graph - nodes = list() - all_tensors = set() - for node in self.global_graph.nodes(): - # collect on device node, inputs and outputs - if node.on_device(self.device): - nodes.append(node) - # collect send tensors and recv tensors - if node.semantic == 'move': - if device in node.inputs(0).device: - self.send_tensors.append(node.inputs(0)) - if device in node.outputs(0).device: - self.recv_tensors.append(node.outputs(0)) - all_tensors.update(node.inputs()) - all_tensors.update(node.outputs()) + all_tensors = list() + for node in sub_nodes: + # collect recv tensors + for input in node.inputs(): + if isinstance(input, IRTensor): + if self.device not in input.device: + if input not in self.recv_tensors: + self.recv_tensors.append(input) + self.recv_devices.append(self.device) + # collect send tensors + for output in node.outputs(): + if isinstance(output, IRTensor): + succ_nodes = output.dst() + for succ_node in succ_nodes: + if not succ_node.on_device(self.device): + if output not in self.send_tensors: + self.send_tensors.append(output) + self.send_devices.append(succ_node.device) + # move semantic + # if node.semantic == 'move': + # if device in node.inputs(0).device: + # self.send_tensors.append(node.inputs(0)) + # self.send_devices.append(node.outputs(0).device) + # if device in node.outputs(0).device: + # self.recv_tensors.append(node.outputs(0)) + # self.recv_devices.append(node.inputs(0).device) + all_tensors += node.inputs() + all_tensors += node.outputs() # model inputs and outputs model_inputs = list() model_outputs = list() for input in self.global_graph.inputs(): - if input in all_tensors: + if input in all_tensors and input not in self.recv_tensors: model_inputs.append(input) for output in self.global_graph.outputs(): - if output in all_tensors: + if output in all_tensors and output not in self.send_tensors: model_outputs.append(output) super().__init__( - nodes, + sub_nodes, model_inputs + self.recv_tensors, # input tensors model_outputs + self.send_tensors, # output tensors self.global_graph.module_name + f'Rank{self.device}' diff --git a/cube/graph/ir_opten.py b/cube/graph/ir_opten.py index e7bc366e..8ac6a03d 100644 --- a/cube/graph/ir_opten.py +++ b/cube/graph/ir_opten.py @@ -1,8 +1,45 @@ +""" +IROperation: + + Semantic operation representation (node) in IRGraph. + An operation is of Computation (Comp) or Communication (Comm) type. + + A Comp type operation can be assigned to multiple devices for redundant computation. + A Comm type operation can be assigned to multiple devices (List[int]). + + Each IROperation can have (multiple) input args and (multiple) output args. + +IRTensor: + + Semantic tensor representation (edge) in IRGraph. + + IRTensor can be assigned (deploy) to multiple devices (List[int]) + + The IRTensor is a logical tensor that + + 1). can be generated from multipe operations (i.e., different operators + can generate different part of the IRTensor). + + => multiple source IROperation. + + 2). can be used as input for multiple operations. + + => multiple destination IROperation + + +IROperation can accept tensors that are placed on the different devices. + +Set the operation device will in default change the output tensor placement +and input leaf tensor placement to match with the operation. +""" + + from cube.graph.unique import IDGenerator from cube.graph.mapping import IR2LogicOp from enum import Enum from typing import List, Optional, Any, Union +import copy __all__ = ['OperationType', 'IROperation', 'IRTensor'] @@ -42,16 +79,17 @@ def __init__(self, self.signature: str = signature self.semantic = IR2LogicOp.map(self.signature) self._type = type - self._device = None + self._device = list() - # edge (dataflow info) + # source operations self._inputs: List[IRTensor] = [None] * input_length - self._predecessors: List[IROperation] = [None] * input_length - # todo for outputs + self._predecessors: List[List[IROperation]] = [list() for _ in range(input_length)] + + # destination operations self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] for tensor in self._outputs: - tensor.set_src_node(self) - self._successors: List[List(IROperation)] = [list() for _ in range(output_length)] + tensor.add_src_node(self) + self._successors: List[List[IROperation]] = [list() for _ in range(output_length)] @property def type(self) -> OperationType: @@ -76,15 +114,19 @@ def device(self, device_id: Union[int, List[int]]): For communication operators (e.g., move, all-reduce), they are allowed to happend on multiple devices """ - if self.type == OperationType.Comp: - if not isinstance(device_id, int): - raise KeyError("Require computation op device: int") - elif self.type == OperationType.Comm: - if not isinstance(device_id, list): - raise TypeError("Require communication op device: List[int]") - else: - raise TypeError("Unknown operator type") + if isinstance(device_id, int): + device_id = [device_id] + if not all([isinstance(devid, int) for devid in device_id]): + raise KeyError("Require device Union[int, List[int]]") self._device = device_id + for input in self._inputs: + # in default, parameters will be placed on all devices + # that needs it + if isinstance(input, IRTensor) and input.is_leaf(): + devices = set() + for node in input.dst(): + devices.update(node.device) + input.device = list(devices) for output in self._outputs: if isinstance(output, IRTensor): output.device = device_id @@ -98,14 +140,7 @@ def on_device(self, device_id: int): """ if not isinstance(device_id, int): raise TypeError("Expected device id to be int") - if self._device is None: - return False - if self.type == OperationType.Comm: - return device_id in self.device - elif self.type == OperationType.Comp: - return device_id == self.device - else: - raise RuntimeError("Unkown Operation type") + return device_id in self.device def inputs(self, index: Optional[int] = None): """ @@ -195,11 +230,35 @@ def set_input(self, input_index: int, val: Any): # set tensor self._inputs[input_index] = val if isinstance(val, IRTensor): + # set tensor dst + val.add_dst_node(self) # set predecessor self._predecessors[input_index] = val.src() # set the source node successor - if isinstance(val.src(), IROperation): - val.src()._add_successor(val, self) + for node in val.src(): + if isinstance(node, IROperation): + node._add_successor(val, self) + + def set_output(self, output_index: int, val: Any): + """ + Set the node inputs[output_index] with the tensor + + val: IRTensor or any deterministic value (int, bool, str, etc) + """ + if output_index >= len(self.outputs()): + raise RuntimeError( + f"Set the input out of range ({output_index} >= {len(self._inputs)})" + ) + # set tensor + self._outputs[output_index] = val + if isinstance(val, IRTensor): + # set predecessor + for node in val.src(): + if isinstance(node, IROperation): + self._successors[output_index].append(node) + # set the source node + if self not in val.src(): + val.add_src_node(self) def set_predecessor(self, input_index: int, node, out_index: int): """ @@ -249,7 +308,21 @@ def infer_shape(self): return True def __repr__(self): - dscp = f'Op(id={self._id}, signature={self.signature}, inputs={self._inputs}, outputs={self._outputs})' + inputs = list() + for tensor in self.inputs(): + if isinstance(tensor, IRTensor): + inputs.append(f't{tensor._id}') + else: + inputs.append(tensor) + + outputs = list() + for tensor in self.outputs(): + if isinstance(tensor, IRTensor): + outputs.append(f't{tensor._id}') + else: + outputs.append(tensor) + + dscp = f'Op(id={self._id}, signature={self.signature}, device={self.device}, inputs={inputs}, outputs={outputs})' return dscp @@ -265,20 +338,50 @@ def __init__(self, shape=None, name=None): self._device = list() # connected to IROperation - self._src_nodes: IROperation = None # -> output of the node + self._src_nodes: List[IROperation] = list() # -> output of the node self._dst_nodes: List[IROperation] = list() # -> input of the nodes + # forward graph + self.requires_grad = True + self.forward_graph = None + + def set_forward_graph(self, graph): + """ + Set forward graph (IRGraph) + """ + self.forward_graph = graph + def __copy__(self): """ Copy the tensor that will be same except a new id """ tensor = IRTensor(self._shape, self.name) - tensor.device = self.device - tensor._id = IDGenerator().gen_tensor_id() + new_id = tensor._id + for key in self.__dict__: + setattr(tensor, key, getattr(self, key)) + tensor._id = new_id + return tensor + + def __deepcopy__(self, memo): + """ + Deep Copy will copy the exactly same tensor with same tensor id + """ + tensor = IRTensor(self._shape, self.name) + for key in self.__dict__: + val = getattr(self, key) + if isinstance(val, IRTensor): + pass + if isinstance(val, list) and all([isinstance(v, IRTensor) for v in val]): + pass + else: + val = copy.copy(val) + setattr(tensor, key, val) return tensor - def __deepcopy__(self): - raise RuntimeError("DeepCopy is not allowed.") + def __eq__(self, tensor): + if not isinstance(tensor, IRTensor): + return False + return self._id == tensor._id @property def shape(self): @@ -298,33 +401,35 @@ def device(self) -> List[int]: return self._device @device.setter - def device(self, device_id: Union[int, List[int]]): + def device(self, device_id: List[int]): """ Set placement of the tensor A tensor can be placed on multiple devices as input for multiple operations on different devices """ - if not (isinstance(device_id, int) or isinstance(device_id, list)): - raise TypeError(f"Expected device id to be int or List[int]") if isinstance(device_id, int): device_id = [device_id] + if not all([isinstance(devid, int) for devid in device_id]) : + raise TypeError(f"Expected device id to be int or List[int]") self._device = device_id - def src(self) -> Optional[IROperation]: + def src(self) -> List[IROperation]: return self._src_nodes def dst(self, index: Optional[int] = None): - if index >= len(self._dst_nodes): + if index is None: + return self._dst_nodes + elif index >= len(self._dst_nodes): raise RuntimeError("get tensor dst out of range") return self._dst_nodes[index] - def set_src_node(self, node: IROperation): + def add_src_node(self, node: IROperation): if not isinstance(node, IROperation): raise TypeError("IRTensor source node should be IROperation") - self._src_nodes = node + self._src_nodes.append(node) - def add_dst_nodes(self, node: IROperation): + def add_dst_node(self, node: IROperation): if not isinstance(node, IROperation): raise TypeError("IRTensor destination node should be IROperation") self._dst_nodes.append(node) @@ -333,15 +438,20 @@ def is_leaf(self): """ Check if it is a leaf tensor (parameter) """ - return self.src() is None + return len(self.src()) == 0 def backward(self): """ Backward will generate a backward action scheduling pool + + Construct a reverse graph of forward and seperate to actions """ - raise NotImplementedError + if self.forward_graph is None: + raise RuntimeError("Backward on a tensor without forward graph") + self.forward_graph.backward() + def __repr__(self): - dscp = f'Tensor(id={self._id}, shape={self.shape})' + dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device})' return dscp diff --git a/cube/graph/mapping.py b/cube/graph/mapping.py index 24c0499e..7156e18b 100644 --- a/cube/graph/mapping.py +++ b/cube/graph/mapping.py @@ -32,5 +32,8 @@ def map(signature: str) -> logic.GenericLogicalOp : __ttemplate('add') : logic.TensorAdd, + # runtime collectives + 'cube.runtime.spatial.move': 'move', + } diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 31c557b7..d5bcfa68 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,7 +1,8 @@ +from cube.graph.ir_opten import IRTensor from typing import Optional, List from cube.graph.parser import ScriptModuleParser -from cube.graph import IRGraph +from cube.graph import IRGraph, IRTensor import torch @@ -16,5 +17,8 @@ def convert(model: torch.nn.Module, raise RuntimeError("Cannot convert module into torchscript moudle.") module_name = smodule.original_name inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) + for input in inputs: + if isinstance(input, IRTensor): + input.requires_grad = False graph = IRGraph(nodes, inputs, outputs, module_name) return graph diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index 4a360d00..b8e37411 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -1,4 +1,4 @@ -import cube.graph as cgraph +from cube.graph import parser, IRLocalGraph from cube.codegen.codegen import SScheduleCodeGen import torch @@ -43,13 +43,14 @@ def init_weight(parameters): def test_codegen(model): - graph = cgraph.parser.convert(model, + graph = parser.convert(model, input_shapes=([1024,1024],)) for node in graph.nodes(): node.device = 0 - gener = SScheduleCodeGen(graph, device=0) + local_graph = IRLocalGraph(graph.nodes(), graph, device=0) + gener = SScheduleCodeGen(local_graph) code = gener.gen(outfile='code.py') - + # execute print("> ===== Generated Code =====") print(code) diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 55a3e344..227b9ecf 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -1,7 +1,9 @@ +from cube.tschedule.pool import TSchedulePool +from cube.graph.ir_opten import IRTensor from torch import nn import cube.graph.parser as parser -import cube.tschedule as tschedule +from cube.codegen.codegen import SScheduleCodeGen class FeedForward(nn.Module): @@ -25,28 +27,51 @@ def forward(self, data, label: int = 4): model = FeedForward(dim=1024) ir_graph = parser.convert(model, input_shapes=([64,1024],[64,])) +print(" > Forward IRGraph ========") +print(ir_graph) +print(" < ==============\n") # device assignment -for node in ir_graph.nodes(): - node.device = 0 +for input in ir_graph.inputs(): + if isinstance(input, IRTensor): + input.device = [0,1] +for nid, node in enumerate(ir_graph.nodes()): + if nid <= 2: + node.device = 0 + else: + node.device = 1 def test_graph_forward(ir_graph): + TSchedulePool().clear() tensor1 = ir_graph() print(tensor1) - print(tschedule.pool.TSchedulePool()) + # print(tschedule.pool.TSchedulePool()) tensor2 = ir_graph() print(tensor2) - print(tschedule.pool.TSchedulePool()) + for action in TSchedulePool().actions(): + print('\n', action) + gener = SScheduleCodeGen(action.graph) + code = gener.gen() + print("> ===== Generated Code =====") + print(code) + print("< ===== Generated Code =====") + print(TSchedulePool()) def test_graph_backward(ir_graph): + TSchedulePool().clear() tensor = ir_graph() tensor.backward() + #tensor = ir_graph() + #tensor.backward() + print('====== Backward Test =======') + print(TSchedulePool()) if __name__ == '__main__': - test_graph_forward(ir_graph) + #test_graph_forward(ir_graph) + test_graph_backward(ir_graph) \ No newline at end of file From fa2460c4a5d90c9480f10d9ac31bef75cbb5c7f1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 22 Sep 2021 22:01:36 +0800 Subject: [PATCH 0191/1892] init cell --- cube/codegen/codegen.py | 177 +++++++++++++++- cube/graph/__init__.py | 5 +- cube/graph/ir_action.py | 126 +++++++++++ cube/graph/{ir_opten.py => ir_cten.py} | 198 +++++------------- cube/graph/ir_graph.py | 275 ++++++++++++++----------- cube/graph/ir_op.py | 100 +++++++++ cube/graph/parser/converter.py | 2 +- cube/graph/unique.py | 8 +- cube/tschedule/__init__.py | 1 - cube/tschedule/action.py | 91 -------- cube/tschedule/pool.py | 17 +- cube/tschedule/sequence.py | 27 +-- tests/codegen/test_codegen.py | 6 +- tests/tschedule/test_tschedule.py | 19 +- 14 files changed, 654 insertions(+), 398 deletions(-) create mode 100644 cube/graph/ir_action.py rename cube/graph/{ir_opten.py => ir_cten.py} (65%) create mode 100644 cube/graph/ir_op.py delete mode 100644 cube/tschedule/action.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 91ad8e80..4f6cb987 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,9 +1,10 @@ """ Generate Pytorch code given the model DAG and the transformation config """ -from typing import List, Any +from cube.tschedule.sequence import ASequence +from typing import List, Any, Dict -from cube.graph import IRLocalGraph, IRTensor, IROperation +from cube.graph import IRAction, IRTensor, IROperation from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -15,10 +16,10 @@ class SScheduleCodeGen: Generate spatial code for the model """ - def __init__(self, graph: IRLocalGraph): - if not isinstance(graph, IRLocalGraph): + def __init__(self, action: IRAction): + if not isinstance(action, IRAction): raise TypeError("graph should be IRGraph") - self.graph = graph + self.graph = action.graph # model full code self.code: List[str] = ['import torch', '', ''] # module init code @@ -30,7 +31,7 @@ def __init__(self, graph: IRLocalGraph): # ref module to check shared variables self._ref_module = torch.nn.Module() - def gen(self, outfile=None) -> List[str]: + def gen(self, device: int, outfile=None) -> str: """ Generate model implementation code based on the given graph. """ @@ -137,3 +138,167 @@ def naming(self, tensor: Any) -> str: else: name = str(tensor) return name + + +class TScheduleCodeGen: + + def __init__(self, seq: ASequence): + if not isinstance(seq, ASequence): + raise TypeError("seq should be ASequence") + self.seq = seq + # model full code + self.code: List[str] = ['from typing import Tuple', 'import torch', '', ''] + # module member name + self.symbols = SymbolTable() + + def gen(self, device: int, outfile=None) -> str: + """ + Generate scheduling code based on the given actions + """ + actions = list() + for action in self.seq.actions(): + if device in action.device: + actions.append(action) + + # {send: xxx, recv: xxx} action1 {send:xxx, recv:xxx} action2 .... + action_with_comms = list(dict()) + for action in actions: + send_tensors = [self.naming(tensor) for tensor in action.graph.send_tensors] + send_devices = action.graph.send_devices + send_shapes = tuple([tensor.shape for tensor in send_tensors]) + recv_tensors = [self.naming(tensor) for tensor in action.graph.recv_tensors] + recv_devices = action.graph.recv_devices + recv_shapes = tuple([tensor.shape for tensor in recv_tensors]) + + comm = action_with_comms[-1] + # recv before the action + if len(recv_tensors) != 0: + comm.update({ + 'recv_tensors' : recv_tensors, + 'recv_devices' : recv_devices, + 'recv_shapes' : recv_shapes + }) + # action + action_with_comms.append(action) + # send after the action + comm = dict() + if len(send_tensors) != 0: + comm.update({ + 'send_tensors' : send_tensors, + 'send_devices' : send_devices, + 'send_shapes' : send_shapes + }) + action_with_comms.append(comm) + + # generate code + with FunctionBlock(func_name='_train_step', + args=['model', 'inputs: Tuple[Tuple[Tensor]]']) as fb: + for action_or_comm in action_with_comms: + if isinstance(action_or_comm, dict): + code = self.emit_comm(action_or_comm) + fb.insert_body(code) + else: + code = self.emit_action(action_or_comm) + fb.insert_body(code) + self.code += fb.code + self.code += [''] + + code = '\n'.join(self.code) + # write to file + if outfile: + with open(outfile, 'w') as f: + f.write(code) + return code + + def emit_comm(self, comm: Dict) -> List[str]: + """ + Emit send / recv code + """ + ssign = 'cube.runtime.spatial.send({send_tensors}, {shapes}, {to_devices})' + rsign = 'cube.runtime.spatial.recv({shapes}, {from_devices})' + srsign = 'cube.runtime.spatial.send_and_recv({send_tensors}, {send_shapes}, {to_devices}, {recv_shapes}, {from_devices})' + + # generate for send + if ('send_tensors') in comm and ('recv_tensors' not in comm): + code = ssign.format( + send_tensors = comm['send_tensors'], + shapes = comm['send_shapes'], + to_devices = comm['send_devices'] + ) + return code + # generate for recv + elif ('send_tensors' not in comm) and ('recv_tensors' in comm): + body = rsign.format( + shapes = comm['recv_shapes'], + from_devices = comm['recv_devices'] + ) + return_val = repr(tuple(comm['recv_tensors'])) + code = f'{return_val} = {body}' + return code + # generate for send + recv + else: + body = srsign.format( + send_tensors = comm['send_tensors'], + shapes = comm['send_shapes'], + to_devices = comm['send_devices'], + recv_shapes = comm['recv_shapes'], + from_devices = comm['recv_devices'] + ) + return_val = repr(tuple(comm['recv_tensors'])) + code = f'{return_val} = {body}' + return code + + def emit_action(self, action) -> List[str]: + """ + Emit action code + """ + fsign = 'cube.runtime.temporal.forward({model}, *inputs[{fid}]})' + bsign = 'cube.runtime.temporal.backward({input_tensors}, {output_tensors}, {output_grads})' + + if action.tag == 'forward': + body = fsign.format( + model = 'model', + fid = 0 + ) + outputs = [self.naming(output) for output in action.outputs()] + return_val = repr(tuple(outputs)) + code = f'{return_val} = {body}' + return code + elif action.tag == 'backward': + # 1). input_tensors are forward inputs (happened before action inputs) + # 2). output_tensors are forward results (action.inputs()) + # 3). output_grads are recved tesnors of this graph (graph.recv_tensors) + output_tensors = [self.naming(input) for input in action.inputs()] + output_grads = None + if len(action.graph.recv_tensors) != 0: + output_grads = [self.naming(tensor) for tensor in action.graph.recv_tensors] + output_grads = repr(tuple(output_grads)) + + body = bsign.format( + input_tensors = None, + output_tensors = output_tensors, + output_grads = output_grads + ) + + # returned value are graph.outputs + return_val = [self.naming(tensor) for tensor in action.graph.outputs()] + return_val = repr(tuple(return_val)) + code = f'{return_val} = {body}' + return code + else: + raise RuntimeError(f"Unsupported action tag: {action.tag}") + + def naming(self, tensor: Any) -> str: + """ + Return the var name (unique for different variable) + + If the var is a leaf tensor, will add prefix `self.` to its name + """ + if isinstance(tensor, IRTensor): + tensor_name = 'tensor' if tensor.name is None else tensor.name + if '.' in tensor_name: + tensor_name = tensor_name.split('.')[0] + name = '_'.join([tensor_name, str(tensor._id)]) + else: + name = str(tensor) + return name diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index 924c2f74..07158ba6 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -1,4 +1,5 @@ -from cube.graph.ir_graph import IRGraph, IRLocalGraph -from cube.graph.ir_opten import IRTensor, IROperation, OperationType +from cube.graph.ir_graph import IRGraph, IRAction +from cube.graph.ir_cten import IRTensor, IRCell +from cube.graph.ir_op import IROperation from cube.graph import parser diff --git a/cube/graph/ir_action.py b/cube/graph/ir_action.py new file mode 100644 index 00000000..8af864be --- /dev/null +++ b/cube/graph/ir_action.py @@ -0,0 +1,126 @@ +from typing import List, Any, Union + +from cube.graph.ir_cten import IRCell, IRTensor +from cube.graph.ir_graph import IRGraph + + +__all__ = ['IRAction'] + +# outputs = cube.runtime.temporal.forward(model, *args) +__forward_signature = 'cube.runtime.temporal.forward' +# grads = cube.runtime.temporal.backward(input_tensors, output_tensors, output_grads) +__backward_signature = 'cube.runtime.temporal.backward' + + +class IRAction(IRCell): + + def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): + + if isinstance(devices, int): + devices = [devices] + + if not isinstance(global_graph, IRGraph): + raise TypeError(f"Expected graph: IRGraph but go {type(global_graph)}") + + if global_graph.tag == 'forward': + signature = __forward_signature + elif global_graph.tag == 'backward': + signature = __backward_signature + else: + raise RuntimeError(f"Unsupported graph tag: {self.global_graph.tag}") + + # send tensors + self.send_tensors = list() + self.send_devices = list() + + # recv tensors + self.recv_tensors = list() + self.recv_devices = list() + + # get nodes belong to this graph + all_tensors = list() + for node in sub_nodes: + # collect recv tensors + for input in node.inputs(): + if isinstance(input, IRTensor): + recv_devices = list(set(devices) - set(input.device)) + if len(recv_devices) != 0: + if input not in self.recv_tensors: + self.recv_tensors.append(input) + self.recv_devices += recv_devices + # collect send tensors + for output in node.outputs(): + if isinstance(output, IRTensor): + succ_nodes = output.dst() + for succ_node in succ_nodes: + send_devices = list(set(devices) - set(succ_node.device)) + if len(send_devices) != 0: + if output not in self.send_tensors: + self.send_tensors.append(output) + self.send_devices.append(send_devices) + all_tensors += node.inputs() + all_tensors += node.outputs() + + # action graph inputs and outputs + inputs = list() + outputs = list() + for input in global_graph.inputs(): + if input in all_tensors and input not in self.recv_tensors: + inputs.append(input) + for output in global_graph.outputs(): + if output in all_tensors and output not in self.send_tensors: + outputs.append(output) + + self.graph = IRGraph( + nodes = sub_nodes, + input_tensors = inputs + self.recv_tensors, + output_tensors = outputs + self.send_tensors, + module_name = global_graph.name + ) + + action_inputs = [self.graph] + [None] * len(self.graph.inputs()) + super().__init__( + name = self.global_graph.tag, + signature = signature, + input_length = len(action_inputs), + output_length = len(self.graph.outputs()) + ) + self.device = devices + self._inputs = action_inputs + + def map_output(self, graph_output_tensor: Any) -> Any: + if graph_output_tensor not in self.graph.outputs(): + return None + index = self.graph.outputs().index(graph_output_tensor) + return self.outputs(index) + + def happen_before(self, action): + """ + Check if the self -> (happened before) action + """ + if not isinstance(action, IRAction): + raise TypeError("Expected action to be an Action") + for pre_actions in self.successors(): + if action in pre_actions: + return True + return False + + def happen_after(self, action): + """ + Check if the action -> (happened before) self + + Note: this may return false negative as it will only check + 1-hop dependency + """ + if not isinstance(action, IRAction): + raise TypeError("Expected action to be an Action") + for pre_actions in self.predecessors(): + if action in pre_actions: + return True + return False + + def add_flow(self, action): + """ + self -> (happened before) action + """ + raise NotImplementedError diff --git a/cube/graph/ir_opten.py b/cube/graph/ir_cten.py similarity index 65% rename from cube/graph/ir_opten.py rename to cube/graph/ir_cten.py index 8ac6a03d..ab8a75a2 100644 --- a/cube/graph/ir_opten.py +++ b/cube/graph/ir_cten.py @@ -1,103 +1,51 @@ -""" -IROperation: - - Semantic operation representation (node) in IRGraph. - An operation is of Computation (Comp) or Communication (Comm) type. - - A Comp type operation can be assigned to multiple devices for redundant computation. - A Comm type operation can be assigned to multiple devices (List[int]). - - Each IROperation can have (multiple) input args and (multiple) output args. - -IRTensor: - - Semantic tensor representation (edge) in IRGraph. - - IRTensor can be assigned (deploy) to multiple devices (List[int]) - - The IRTensor is a logical tensor that - - 1). can be generated from multipe operations (i.e., different operators - can generate different part of the IRTensor). - - => multiple source IROperation. - - 2). can be used as input for multiple operations. - - => multiple destination IROperation - - -IROperation can accept tensors that are placed on the different devices. - -Set the operation device will in default change the output tensor placement -and input leaf tensor placement to match with the operation. -""" - - -from cube.graph.unique import IDGenerator -from cube.graph.mapping import IR2LogicOp - -from enum import Enum -from typing import List, Optional, Any, Union +from typing import List, Union, Optional, Any import copy +from cube.graph.unique import IDGenerator -__all__ = ['OperationType', 'IROperation', 'IRTensor'] - - -class OperationType(Enum): - Comp = 1 # computation - Comm = 2 # communication +__all__ = ['IRCell', 'IRTensor'] -class IROperation: +class IRCell: """ - IROperation serves as IRGraph node + IRCell serves as a general node for different purpose """ def __init__(self, - name: str, + name: str, signature: str, input_length: int, - output_length: int, - type=OperationType.Comp): + output_length: int): """ Create a node with name (variable name) and module type (module_name) Args: - name (str): the op semantic name - signature (str): the op signature, e.g., torch.functional.nn.linear + name (str): the cell name + signature (str): the cell function signature, + e.g., torch.functional.nn.linear input_length (int): the number of inputs for the op output_length (int): the number of outputs for the op """ # node info - self._id: int = IDGenerator().gen_op_id() + self._id: int = IDGenerator().gen_cell_id() self.name: str = name + self.signature = signature - # op signature and op class - self.signature: str = signature - self.semantic = IR2LogicOp.map(self.signature) - self._type = type + # device self._device = list() - # source operations - self._inputs: List[IRTensor] = [None] * input_length - self._predecessors: List[List[IROperation]] = [list() for _ in range(input_length)] + # source tensors + self._inputs: List[Any] = [None] * input_length + # source cells + self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length)] - # destination operations + # destination tensors self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] - for tensor in self._outputs: - tensor.add_src_node(self) - self._successors: List[List[IROperation]] = [list() for _ in range(output_length)] - - @property - def type(self) -> OperationType: - return self._type - - @type.setter - def type(self, _): - raise RuntimeError("Not allowed to set type except initialization") + for output in self._outputs: + output.add_src_node(self) + # destination cells + self._successors: List[List[IRCell]] = [list() for _ in range(output_length)] @property def device(self): @@ -107,29 +55,12 @@ def device(self): def device(self, device_id: Union[int, List[int]]): """ Set the operation device. - - For computation operators, they are only allowed - to happen on one device (int) - - For communication operators (e.g., move, all-reduce), - they are allowed to happend on multiple devices """ if isinstance(device_id, int): device_id = [device_id] if not all([isinstance(devid, int) for devid in device_id]): raise KeyError("Require device Union[int, List[int]]") self._device = device_id - for input in self._inputs: - # in default, parameters will be placed on all devices - # that needs it - if isinstance(input, IRTensor) and input.is_leaf(): - devices = set() - for node in input.dst(): - devices.update(node.device) - input.device = list(devices) - for output in self._outputs: - if isinstance(output, IRTensor): - output.device = device_id def on_device(self, device_id: int): """ @@ -139,7 +70,7 @@ def on_device(self, device_id: int): Boolean """ if not isinstance(device_id, int): - raise TypeError("Expected device id to be int") + raise TypeError(f"Expected device id to be int but got {type(device_id)}") return device_id in self.device def inputs(self, index: Optional[int] = None): @@ -236,8 +167,8 @@ def set_input(self, input_index: int, val: Any): self._predecessors[input_index] = val.src() # set the source node successor for node in val.src(): - if isinstance(node, IROperation): - node._add_successor(val, self) + if isinstance(node, IRCell): + node.add_successor(val, self) def set_output(self, output_index: int, val: Any): """ @@ -254,76 +185,58 @@ def set_output(self, output_index: int, val: Any): if isinstance(val, IRTensor): # set predecessor for node in val.src(): - if isinstance(node, IROperation): + if isinstance(node, IRCell): self._successors[output_index].append(node) # set the source node if self not in val.src(): val.add_src_node(self) - def set_predecessor(self, input_index: int, node, out_index: int): + def add_predecessor(self, input_index: int, node, out_index: int): """ Set self node the input node. self.input[input_index] = node.output[out_index] """ - if not isinstance(node, IROperation): - raise TypeError("Expected node to be IROperation") + if not isinstance(node, IRCell): + raise TypeError("Expected node to be IRCell") if input_index >= len(self.inputs()): raise RuntimeError( f"Set the input out of range ({input_index} >= {len(self._inputs)})" ) self._inputs[input_index] = node.outputs(out_index) - self._predecessors[input_index] = node - node.set_successor(out_index, self) + self._predecessors[input_index].append(node) + node.add_successor(out_index, self) - def _add_successor(self, tensor, node): + def add_successor(self, tensor, node): """ Set self node the output index node. `node` will take the self.outputs(index) as the input """ + if not isinstance(node, IRCell): + raise TypeError("Expected node to be IRCell") out_index = self._outputs.index(tensor) if out_index < 0: raise RuntimeError("Fail to find output tensor") self._successors[out_index].append(node) - def infer_shape(self): + def __repr__(self): """ - Infer output value shape + Cell string presentation """ - shapes = list() - for input in self.inputs(): - if isinstance(input, IRTensor): - if input.shape is None: - return False - shapes.append(input.shape) - else: - shapes.append([1,]) - shapes = tuple(shapes) - out_shapes = self.semantic.shape_infer(*shapes) - if len(out_shapes) != len(self._outputs): - raise RuntimeError( - "The logical op semantic doesn't match with parsed op" - ) - for shape, val in zip(out_shapes, self._outputs): - if isinstance(val, IRTensor): - val.shape = shape - return True - - def __repr__(self): inputs = list() for tensor in self.inputs(): if isinstance(tensor, IRTensor): inputs.append(f't{tensor._id}') else: inputs.append(tensor) - + outputs = list() for tensor in self.outputs(): if isinstance(tensor, IRTensor): outputs.append(f't{tensor._id}') else: outputs.append(tensor) - - dscp = f'Op(id={self._id}, signature={self.signature}, device={self.device}, inputs={inputs}, outputs={outputs})' - return dscp + dcsp = f'Cell-{self._id}({self.signature}, device={self.device})'\ + f'({inputs}) -> {outputs}' + return dcsp class IRTensor: @@ -337,19 +250,19 @@ def __init__(self, shape=None, name=None): self.name = name self._device = list() - # connected to IROperation - self._src_nodes: List[IROperation] = list() # -> output of the node - self._dst_nodes: List[IROperation] = list() # -> input of the nodes + # connected to IRCell + self._src_nodes: List[IRCell] = list() # -> output of the node + self._dst_nodes: List[IRCell] = list() # -> input of the nodes # forward graph self.requires_grad = True - self.forward_graph = None + self.gen_graph = None - def set_forward_graph(self, graph): + def set_gen_graph(self, graph): """ Set forward graph (IRGraph) """ - self.forward_graph = graph + self.gen_graph = graph def __copy__(self): """ @@ -414,7 +327,7 @@ def device(self, device_id: List[int]): raise TypeError(f"Expected device id to be int or List[int]") self._device = device_id - def src(self) -> List[IROperation]: + def src(self) -> List[IRCell]: return self._src_nodes def dst(self, index: Optional[int] = None): @@ -424,14 +337,14 @@ def dst(self, index: Optional[int] = None): raise RuntimeError("get tensor dst out of range") return self._dst_nodes[index] - def add_src_node(self, node: IROperation): - if not isinstance(node, IROperation): - raise TypeError("IRTensor source node should be IROperation") + def add_src_node(self, node: IRCell): + if not isinstance(node, IRCell): + raise TypeError("IRTensor source node should be IRCell") self._src_nodes.append(node) - def add_dst_node(self, node: IROperation): - if not isinstance(node, IROperation): - raise TypeError("IRTensor destination node should be IROperation") + def add_dst_node(self, node: IRCell): + if not isinstance(node, IRCell): + raise TypeError("IRTensor destination node should be IRCell") self._dst_nodes.append(node) def is_leaf(self): @@ -446,12 +359,11 @@ def backward(self): Construct a reverse graph of forward and seperate to actions """ - if self.forward_graph is None: + if self.gen_graph is None: raise RuntimeError("Backward on a tensor without forward graph") - self.forward_graph.backward() + self.gen_graph.backward(self) def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device})' return dscp - diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 437926b5..8f6a0d87 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -1,15 +1,16 @@ -from cube.graph.ir_opten import IROperation, IRTensor -from cube.tschedule.action import Action +from cube.graph.ir_cten import IRTensor, IRCell +from cube.graph.ir_op import IROperation from cube.tschedule.pool import TSchedulePool -from typing import Union, Tuple, List, Optional +from typing import Union, Tuple, List, Optional, Any import copy -__all__ = ['IRGraph', 'IRLocalGraph'] +__all__ = ['IRGraph', 'IRAction'] -class IRGraph: + +class IRGraph(IRCell): """ PyTorch IR Graph @@ -21,15 +22,20 @@ def __init__(self, input_tensors: List[IRTensor], output_tensors: List[IRTensor], module_name: str): - self.module_name = module_name + self._nodes: List[IROperation] = nodes + super().__init__( + name=module_name, + signature=module_name, + input_length=len(input_tensors), + output_length=len(output_tensors) + ) self._inputs = input_tensors self._outputs = output_tensors - # default is forward graph self.tag = 'forward' - def add_node(self, node: IROperation): - if not isinstance(node, IROperation): + def add_node(self, node: IRCell): + if not isinstance(node, IRCell): raise TypeError("Expected node to be IROperation") self._nodes.append(node) @@ -48,45 +54,13 @@ def nodes(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") - def inputs(self, index: Optional[int] = None): - if isinstance(index, int): - if index >= len(self._inputs): - raise RuntimeError( - f"Get the input out of range ({index} >= {len(self._inputs)}" - ) - return self._inputs[index] - elif index is None: - return self._inputs - else: - raise TypeError("Expected index to be None or int") - - def outputs(self, index: Optional[int] = None): - """ - Get output tensor at output index - - Args: - index (int or None): - index of the outputs, None will return the nodes - for all the outputs - """ - if isinstance(index, int): - if index >= len(self._outputs): - raise RuntimeError( - f"Get the output out of range ({index} >= {len(self._outputs)}" - ) - return self._outputs[index] - elif index is None: - return self._outputs - else: - raise TypeError("Expected index to be None or int") - def replace(self, target: IROperation, nodes: List[IROperation]): """ Replace the node with new nodes (IRGraph) """ raise NotImplementedError - def forward(self, *args, **kwargs) -> Union[IRTensor, Tuple[IRTensor]]: + def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: """ forward will divide the graph into Actions according to node device assignment @@ -97,52 +71,71 @@ def forward(self, *args, **kwargs) -> Union[IRTensor, Tuple[IRTensor]]: Returns: List[Action] """ - if len(self._outputs) == 1: - tensor = copy.copy(self._outputs[0]) - tensor.set_forward_graph(self) - return tensor - else: - tensors = tuple([copy.copy(output) for output in self._outputs]) - for tensor in tensors: - if isinstance(tensor, IRTensor): - tensor.set_forward_graph(self) - return tensors - - def __call__(self, *args, **kwargs): - """ - Register forward action - """ - curr_nodes: List[IROperation] = list() - curr_device = None - - def _wrap_to_action(): - sub_graph = IRLocalGraph( - curr_nodes, self, device=curr_device[0] #FIXME + # check input num + if len(args) != len(self.inputs()): + raise RuntimeError( + f"Expected {len(self.inputs())} input args but got {len(args)}" ) - action = Action(sub_graph, device=curr_device[0]) #FIXME - action.tag(self.tag) - return action + # check input type + if not all([type(arg) is type(input) for arg, input in zip(args, self.inputs())]): + raise RuntimeError(f"Expected input type the same") + + curr_nodes: List[IROperation] = list() + curr_device = list() + total_actions = list() for node in self.nodes(): - #FIXME: will fail in multi-branch placement (backward) device = node.device if len(node.device) == 0: raise RuntimeError("All the node should be assigned to devices") - if device != curr_device and curr_device is not None: - # note we still use same input and output to make consistency - action = _wrap_to_action() + if set(device) != set(curr_device) and len(curr_device) != 0: + # create action + action = IRAction(curr_nodes, self, devices=curr_device) + total_actions.append(action) # register to schedule space TSchedulePool().add_action(action) curr_nodes = list() curr_device = device curr_nodes.append(node) if curr_device is not None: - action = _wrap_to_action() + action = IRAction(curr_nodes, self, devices=curr_device) + total_actions.append(action) TSchedulePool().add_action(action) - return self.forward(*args, **kwargs) + # setup action inputs + head = total_actions[0] + for idx, arg in enumerate(args): + head.set_input(idx + 1, arg) # 0 is for graph itself + outputs_tensors = [*head.graph.outputs()] + outputs_actions = [head] * len(head.graph.outputs()) + for action in total_actions[1:]: + for idx, input in enumerate(action.graph.inputs()): + if input not in outputs_tensors: + raise RuntimeError(f"Cannot find {input} tensors") + pre_action = outputs_actions[outputs_tensors.index(input)] + val = pre_action.map_output(input) + action.set_input(idx + 1, val) + outputs_tensors += action.graph.outputs() + outputs_actions += [action] * len(action.graph.outputs()) + + # return tensors + outputs = tuple(total_actions[-1].outputs()) + for output in outputs: + output.set_gen_graph(self) + if len(outputs) == 1: + return outputs[0] + elif len(outputs) == 0: + return None + else: + return outputs + + def __call__(self, *args): + """ + Register forward action + """ + return self.forward(*args) - def backward(self): + def backward(self, loss: IRTensor): """ Backward will generate a backward action scheduling pool @@ -190,11 +183,9 @@ def get_tensor_grad(tensor): name = fnode.name + '_backward', signature = fnode.signature, input_length = len(inputs), - output_length = len(outputs), - type=fnode.type + output_length = len(outputs) ) bp_node.device = fnode.device - print(bp_node) for idx, arg in enumerate(inputs): bp_node.set_input(idx, arg) for idx, arg in enumerate(outputs): @@ -207,12 +198,11 @@ def get_tensor_grad(tensor): graph = IRGraph( backward_nodes, inputs, outputs, - self.module_name + 'Backward' + self.name + 'Backward' ) print(graph) graph.tag = 'backward' - graph() - + graph(loss) def __repr__(self): dscp = '' @@ -230,70 +220,125 @@ def __repr__(self): return dscp -class IRLocalGraph(IRGraph): +# outputs = cube.runtime.temporal.forward(model, *args) +_forward_signature = 'cube.runtime.temporal.forward' +# grads = cube.runtime.temporal.backward(input_tensors, output_tensors, output_grads) +_backward_signature = 'cube.runtime.temporal.backward' - def __init__(self, - sub_nodes: List[IROperation], - global_graph: IRGraph, - device: int - ): + +class IRAction(IRCell): + + def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): + + if isinstance(devices, int): + devices = [devices] if not isinstance(global_graph, IRGraph): raise TypeError(f"Expected graph: IRGraph but go {type(global_graph)}") - if not isinstance(device, int): - raise TypeError(f"Expected device: int but not {type(device)}") - for node in sub_nodes: - if not node.on_device(device): - raise RuntimeError(f"Local Graph requires all nodes on device {device}") - self.global_graph = global_graph - self.device = device + + if global_graph.tag == 'forward': + signature = _forward_signature + elif global_graph.tag == 'backward': + signature = _backward_signature + else: + raise RuntimeError(f"Unsupported graph tag: {self.global_graph.tag}") + + # send tensors self.send_tensors = list() self.send_devices = list() + + # recv tensors self.recv_tensors = list() self.recv_devices = list() + # get nodes belong to this graph all_tensors = list() for node in sub_nodes: # collect recv tensors for input in node.inputs(): if isinstance(input, IRTensor): - if self.device not in input.device: + recv_devices = list(set(devices) - set(input.device)) + if len(recv_devices) != 0: if input not in self.recv_tensors: self.recv_tensors.append(input) - self.recv_devices.append(self.device) + self.recv_devices.append(recv_devices) # collect send tensors for output in node.outputs(): if isinstance(output, IRTensor): succ_nodes = output.dst() for succ_node in succ_nodes: - if not succ_node.on_device(self.device): + send_devices = list(set(devices) - set(succ_node.device)) + if len(send_devices) != 0: if output not in self.send_tensors: self.send_tensors.append(output) - self.send_devices.append(succ_node.device) - # move semantic - # if node.semantic == 'move': - # if device in node.inputs(0).device: - # self.send_tensors.append(node.inputs(0)) - # self.send_devices.append(node.outputs(0).device) - # if device in node.outputs(0).device: - # self.recv_tensors.append(node.outputs(0)) - # self.recv_devices.append(node.inputs(0).device) + self.send_devices.append(send_devices) all_tensors += node.inputs() all_tensors += node.outputs() - # model inputs and outputs - model_inputs = list() - model_outputs = list() - for input in self.global_graph.inputs(): + # action graph inputs and outputs + inputs = list() + outputs = list() + for input in global_graph.inputs(): if input in all_tensors and input not in self.recv_tensors: - model_inputs.append(input) - for output in self.global_graph.outputs(): + inputs.append(input) + for output in global_graph.outputs(): if output in all_tensors and output not in self.send_tensors: - model_outputs.append(output) + outputs.append(output) + self.graph = IRGraph( + nodes = sub_nodes, + input_tensors = inputs + self.recv_tensors, + output_tensors = outputs + self.send_tensors, + module_name = global_graph.name + ) + + action_inputs = [self.graph] + [None] * len(self.graph.inputs()) super().__init__( - sub_nodes, - model_inputs + self.recv_tensors, # input tensors - model_outputs + self.send_tensors, # output tensors - self.global_graph.module_name + f'Rank{self.device}' + name = global_graph.tag, + signature = signature, + input_length = len(action_inputs), + output_length = len(self.graph.outputs()) ) + self.device = devices + self._inputs = action_inputs + + def map_output(self, graph_output_tensor: Any) -> Any: + if graph_output_tensor not in self.graph.outputs(): + return None + index = self.graph.outputs().index(graph_output_tensor) + return self.outputs(index) + + def happen_before(self, action): + """ + Check if the self -> (happened before) action + """ + if not isinstance(action, IRAction): + raise TypeError("Expected action to be an Action") + for pre_actions in self.successors(): + if action in pre_actions: + return True + return False + + def happen_after(self, action): + """ + Check if the action -> (happened before) self + + Note: this may return false negative as it will only check + 1-hop dependency + """ + if not isinstance(action, IRAction): + raise TypeError("Expected action to be an Action") + for pre_actions in self.predecessors(): + if action in pre_actions: + return True + return False + + def add_flow(self, action): + """ + self -> (happened before) action + """ + raise NotImplementedError + + def __repr__(self): + dscp = f'Action({self.name}):\n\t{self.graph.inputs()} -> {self.graph.outputs()}' + return dscp diff --git a/cube/graph/ir_op.py b/cube/graph/ir_op.py new file mode 100644 index 00000000..01071d93 --- /dev/null +++ b/cube/graph/ir_op.py @@ -0,0 +1,100 @@ +from typing import List, Union + +from cube.graph.ir_cten import IRTensor, IRCell +from cube.graph.mapping import IR2LogicOp + + +__call__ = ['IROperation'] + + +class IROperation(IRCell): + + def __init__(self, + name: str, + signature: str, + input_length: int, + output_length: int): + """ + Create a node with name (variable name) and module type (module_name) + + Args: + name (str): the op semantic name + signature (str): the op signature, e.g., torch.functional.nn.linear + input_length (int): the number of inputs for the op + output_length (int): the number of outputs for the op + """ + super().__init__(name, signature, input_length, output_length) + self.semantic = IR2LogicOp.map(self.signature) + + @property + def device(self): + return self._device + + @device.setter + def device(self, device_id: Union[int, List[int]]): + """ + Set the operation device. + + For computation operators, they are only allowed + to happen on one device (int) + + For communication operators (e.g., move, all-reduce), + they are allowed to happend on multiple devices + """ + if isinstance(device_id, int): + device_id = [device_id] + if not all([isinstance(devid, int) for devid in device_id]): + raise ValueError("Require device Union[int, List[int]]") + self._device = device_id + for input in self._inputs: + # in default, parameters will be placed on + # all devices that needs it + if isinstance(input, IRTensor) and input.is_leaf(): + devices = set() + for node in input.dst(): + devices.update(node.device) + input.device = list(devices) + for output in self._outputs: + if isinstance(output, IRTensor): + output.device = device_id + + def infer_shape(self): + """ + Infer output value shape + """ + shapes = list() + for input in self.inputs(): + if isinstance(input, IRTensor): + if input.shape is None: + return False + shapes.append(input.shape) + else: + shapes.append([1,]) + shapes = tuple(shapes) + out_shapes = self.semantic.shape_infer(*shapes) + if len(out_shapes) != len(self._outputs): + raise RuntimeError( + "The logical op semantic doesn't match with parsed op" + ) + for shape, val in zip(out_shapes, self._outputs): + if isinstance(val, IRTensor): + val.shape = shape + return True + + def __repr__(self): + inputs = list() + for tensor in self.inputs(): + if isinstance(tensor, IRTensor): + inputs.append(f't{tensor._id}') + else: + inputs.append(tensor) + + outputs = list() + for tensor in self.outputs(): + if isinstance(tensor, IRTensor): + outputs.append(f't{tensor._id}') + else: + outputs.append(tensor) + + dscp = f'Op(id={self._id}, signature={self.signature}, device={self.device}, inputs={inputs}, outputs={outputs})' + return dscp diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index d5bcfa68..79e80257 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,4 +1,4 @@ -from cube.graph.ir_opten import IRTensor +from cube.graph.ir_cten import IRTensor from typing import Optional, List from cube.graph.parser import ScriptModuleParser diff --git a/cube/graph/unique.py b/cube/graph/unique.py index ee34a95a..ac45bd27 100644 --- a/cube/graph/unique.py +++ b/cube/graph/unique.py @@ -10,7 +10,7 @@ class __IDGenerator: def __init__(self): self._tensor_id = 0 - self._op_id = 0 + self._cell_id = 0 instance = None @@ -25,9 +25,9 @@ def gen_tensor_id(self): self.instance._tensor_id += 1 return self.instance._tensor_id - def gen_op_id(self): - self.instance._op_id += 1 - return self.instance._op_id + def gen_cell_id(self): + self.instance._cell_id += 1 + return self.instance._cell_id def clear(self): self.instance._tensor_id = 0 diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py index 210ed4b8..e69de29b 100644 --- a/cube/tschedule/__init__.py +++ b/cube/tschedule/__init__.py @@ -1 +0,0 @@ -from cube.tschedule.action import Action diff --git a/cube/tschedule/action.py b/cube/tschedule/action.py deleted file mode 100644 index a1d71631..00000000 --- a/cube/tschedule/action.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import List - - -class Action: - """ - Action represents a (sub-)graph which contains operators on the - same device - """ - def __init__(self, ir_graph, device: int): - - if not isinstance(device, int): - raise TypeError("Require device to be int") - # set up attributes - self.graph = ir_graph - self.device: int = device - self.name: str = None - # dependencies - self._pre_actions: List[Action] = list() - self._post_actions: List[Action] = list() - - @property - def device(self): - return self._device - - @device.setter - def device(self, device): - for op in self.graph.nodes(): - op.deivce = device - self._device = device - - def tag(self, name: str): - """ - Tag a string to indicate this action (as name) - """ - self.name = name - - def happen_before(self, action): - """ - Check if the self -> (happened before) action - """ - if not isinstance(action, Action): - raise TypeError("Expected action to be an Action") - return action in self._post_actions - - def post_actions(self): - """ - Get post action list - """ - return self._post_actions - - def happen_after(self, action): - """ - Check if the action -> (happened before) self - - Note: this may return false negative as it will only check - 1-hop dependency - """ - if not isinstance(action, Action): - raise TypeError("Expected action to be an Action") - return action in self._pre_actions - - def pre_actions(self): - """ - Get pre action list - - Note: this may return false negative as it will only check - 1-hop dependency - """ - return self._pre_actions - - def add_flow(self, action): - """ - Make this action (self) -> (happened before) action - """ - if not isinstance(action, Action): - raise TypeError("Expected action to be Action") - self._post_actions.append(action) - action._add_pre_action(self) - - def _add_pre_action(self, action): - """ - Add successor that requries this action happened first - """ - if not isinstance(action, Action): - raise TypeError("Expected action to be Action") - self._successors.append(action) - - def __repr__(self): - dscp = f'Action({self.name}):\n\t{self.graph.outputs()} <- {self.graph.inputs()}' - return dscp - \ No newline at end of file diff --git a/cube/tschedule/pool.py b/cube/tschedule/pool.py index 8b0ba80b..21c5c5b4 100644 --- a/cube/tschedule/pool.py +++ b/cube/tschedule/pool.py @@ -1,6 +1,4 @@ -from typing import List, Callable - -from cube.tschedule.action import Action +from typing import Callable class TSchedulePool: @@ -9,7 +7,8 @@ class __TSchedulePool: def __init__(self): - self._actions: List[Action] = list() + self._actions = list() + self._flow_id = -1 instance = None @@ -20,7 +19,7 @@ def __init__(self): def __getattr__(self, name): return getattr(self.instance, name) - def add_action(self, action: Action): + def add_action(self, action): self.instance._actions.append(action) def actions(self): @@ -28,6 +27,14 @@ def actions(self): def clear(self): self.instance._actions = list() + self.instance._flow_id = -1 + + def gen_id(self) -> int: + """ + Generate an unique action id + """ + self.instance._flow_id += 1 + return self.instance._flow_id def __repr__(self): dscp = '\n'.join([repr(action) for action in self._actions]) diff --git a/cube/tschedule/sequence.py b/cube/tschedule/sequence.py index cb680a77..2673752d 100644 --- a/cube/tschedule/sequence.py +++ b/cube/tschedule/sequence.py @@ -1,15 +1,10 @@ -from typing import List, Tuple, NewType +from typing import List, Tuple, NewType, Any import numpy as np -from cube.tschedule.action import Action - class ASequence: - def __init__(self, actions: List[Action]): - - if not all([isinstance(action, Action) for action in actions]): - raise TypeError("Expected a list of Actions") + def __init__(self, actions): self.sequence = actions @@ -25,12 +20,10 @@ def __len__(self) -> int: """ return len(self.sequence) - def append(self, action: Action): - if not isinstance(action, Action): - raise TypeError("Expected an action") + def append(self, action): self.sequence.append(action) - def pop(self) -> Action: + def pop(self): """ Pop the last action and return """ @@ -57,13 +50,13 @@ def is_correct(self): # ======= Blow should be moved from this module ======== # -Relation = NewType('Relation', List[Tuple[Action, Action]]) +Relation = NewType('Relation', List[Tuple[Any, Any]]) class ScheduleSpace: @staticmethod - def tspace(remain_actions: List[Action], + def tspace(remain_actions, path_shuffle=True, relations=None, seq: ASequence = ASequence(list())): @@ -94,7 +87,7 @@ def tspace(remain_actions: List[Action], @staticmethod - def sspace(actions: List[Action], ndevice: int, path_shuffle=True, depth=0): + def sspace(actions, ndevice: int, path_shuffle=True, depth=0): """ Iterate on the possible action space """ @@ -112,7 +105,7 @@ def sspace(actions: List[Action], ndevice: int, path_shuffle=True, depth=0): @staticmethod - def _ready_actions(actions: List[Action], sub_relations: Relation) -> List[Action]: + def _ready_actions(actions, sub_relations: Relation): """ Get ready to emit actions based on sub_relations """ @@ -129,7 +122,7 @@ def _ready_actions(actions: List[Action], sub_relations: Relation) -> List[Actio @staticmethod - def _get_relations(actions: List[Action]) -> Relation: + def _get_relations(actions) -> Relation: """ Get relation tuples (Action1 -> Action2) """ @@ -142,7 +135,7 @@ def _get_relations(actions: List[Action]) -> Relation: @staticmethod - def _remove_action(target: Action, relations: Relation) -> Relation: + def _remove_action(target, relations: Relation) -> Relation: """ Remove the target action from relation set """ diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index b8e37411..ad71f042 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -1,4 +1,4 @@ -from cube.graph import parser, IRLocalGraph +from cube.graph import parser, IRAction, IRTensor from cube.codegen.codegen import SScheduleCodeGen import torch @@ -47,9 +47,9 @@ def test_codegen(model): input_shapes=([1024,1024],)) for node in graph.nodes(): node.device = 0 - local_graph = IRLocalGraph(graph.nodes(), graph, device=0) + local_graph = IRAction(graph.nodes(), graph, devices=[0]) gener = SScheduleCodeGen(local_graph) - code = gener.gen(outfile='code.py') + code = gener.gen(device=0, outfile='code.py') # execute print("> ===== Generated Code =====") diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 227b9ecf..278313df 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -1,5 +1,5 @@ from cube.tschedule.pool import TSchedulePool -from cube.graph.ir_opten import IRTensor +from cube.graph.ir_cten import IRTensor from torch import nn import cube.graph.parser as parser @@ -15,18 +15,17 @@ def __init__(self, dim, dropout=0., mult=16, classes=1000): self.linear2 = nn.Linear(dim * mult, dim) self.classifier = nn.Linear(dim, classes) - def forward(self, data, label: int = 4): + def forward(self, data): output = self.linear1(data) output = self.gelu(output) output = self.dropout(output) output = self.linear2(output) - output = output + data output = self.classifier(output) return output model = FeedForward(dim=1024) -ir_graph = parser.convert(model, input_shapes=([64,1024],[64,])) +ir_graph = parser.convert(model, input_shapes=([64,1024],)) print(" > Forward IRGraph ========") print(ir_graph) print(" < ==============\n") @@ -34,7 +33,7 @@ def forward(self, data, label: int = 4): # device assignment for input in ir_graph.inputs(): if isinstance(input, IRTensor): - input.device = [0,1] + input.device = [0] for nid, node in enumerate(ir_graph.nodes()): if nid <= 2: node.device = 0 @@ -45,15 +44,15 @@ def forward(self, data, label: int = 4): def test_graph_forward(ir_graph): TSchedulePool().clear() - tensor1 = ir_graph() + tensor1 = ir_graph(IRTensor(shape=[64,1024])) print(tensor1) # print(tschedule.pool.TSchedulePool()) - tensor2 = ir_graph() - print(tensor2) + # tensor2 = ir_graph() + # print(tensor2) for action in TSchedulePool().actions(): print('\n', action) gener = SScheduleCodeGen(action.graph) - code = gener.gen() + code = gener.gen(device=action.device[0]) print("> ===== Generated Code =====") print(code) print("< ===== Generated Code =====") @@ -63,7 +62,7 @@ def test_graph_forward(ir_graph): def test_graph_backward(ir_graph): TSchedulePool().clear() - tensor = ir_graph() + tensor = ir_graph(IRTensor(shape=[64,1024])) tensor.backward() #tensor = ir_graph() #tensor.backward() From e2e766f6f731014f34ee2546e4de29cc49d98a17 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Sep 2021 10:25:20 +0800 Subject: [PATCH 0192/1892] backward graph with action dependency --- cube/graph/ir_action.py | 126 ------------------------------ cube/graph/ir_graph.py | 40 +++++----- cube/graph/mapping.py | 2 + cube/operator/logic/function.py | 11 +++ tests/tschedule/test_tschedule.py | 18 ++++- 5 files changed, 51 insertions(+), 146 deletions(-) delete mode 100644 cube/graph/ir_action.py diff --git a/cube/graph/ir_action.py b/cube/graph/ir_action.py deleted file mode 100644 index 8af864be..00000000 --- a/cube/graph/ir_action.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import List, Any, Union - -from cube.graph.ir_cten import IRCell, IRTensor -from cube.graph.ir_graph import IRGraph - - -__all__ = ['IRAction'] - -# outputs = cube.runtime.temporal.forward(model, *args) -__forward_signature = 'cube.runtime.temporal.forward' -# grads = cube.runtime.temporal.backward(input_tensors, output_tensors, output_grads) -__backward_signature = 'cube.runtime.temporal.backward' - - -class IRAction(IRCell): - - def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): - - if isinstance(devices, int): - devices = [devices] - - if not isinstance(global_graph, IRGraph): - raise TypeError(f"Expected graph: IRGraph but go {type(global_graph)}") - - if global_graph.tag == 'forward': - signature = __forward_signature - elif global_graph.tag == 'backward': - signature = __backward_signature - else: - raise RuntimeError(f"Unsupported graph tag: {self.global_graph.tag}") - - # send tensors - self.send_tensors = list() - self.send_devices = list() - - # recv tensors - self.recv_tensors = list() - self.recv_devices = list() - - # get nodes belong to this graph - all_tensors = list() - for node in sub_nodes: - # collect recv tensors - for input in node.inputs(): - if isinstance(input, IRTensor): - recv_devices = list(set(devices) - set(input.device)) - if len(recv_devices) != 0: - if input not in self.recv_tensors: - self.recv_tensors.append(input) - self.recv_devices += recv_devices - # collect send tensors - for output in node.outputs(): - if isinstance(output, IRTensor): - succ_nodes = output.dst() - for succ_node in succ_nodes: - send_devices = list(set(devices) - set(succ_node.device)) - if len(send_devices) != 0: - if output not in self.send_tensors: - self.send_tensors.append(output) - self.send_devices.append(send_devices) - all_tensors += node.inputs() - all_tensors += node.outputs() - - # action graph inputs and outputs - inputs = list() - outputs = list() - for input in global_graph.inputs(): - if input in all_tensors and input not in self.recv_tensors: - inputs.append(input) - for output in global_graph.outputs(): - if output in all_tensors and output not in self.send_tensors: - outputs.append(output) - - self.graph = IRGraph( - nodes = sub_nodes, - input_tensors = inputs + self.recv_tensors, - output_tensors = outputs + self.send_tensors, - module_name = global_graph.name - ) - - action_inputs = [self.graph] + [None] * len(self.graph.inputs()) - super().__init__( - name = self.global_graph.tag, - signature = signature, - input_length = len(action_inputs), - output_length = len(self.graph.outputs()) - ) - self.device = devices - self._inputs = action_inputs - - def map_output(self, graph_output_tensor: Any) -> Any: - if graph_output_tensor not in self.graph.outputs(): - return None - index = self.graph.outputs().index(graph_output_tensor) - return self.outputs(index) - - def happen_before(self, action): - """ - Check if the self -> (happened before) action - """ - if not isinstance(action, IRAction): - raise TypeError("Expected action to be an Action") - for pre_actions in self.successors(): - if action in pre_actions: - return True - return False - - def happen_after(self, action): - """ - Check if the action -> (happened before) self - - Note: this may return false negative as it will only check - 1-hop dependency - """ - if not isinstance(action, IRAction): - raise TypeError("Expected action to be an Action") - for pre_actions in self.predecessors(): - if action in pre_actions: - return True - return False - - def add_flow(self, action): - """ - self -> (happened before) action - """ - raise NotImplementedError diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 8f6a0d87..bdce3239 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -3,13 +3,11 @@ from cube.tschedule.pool import TSchedulePool from typing import Union, Tuple, List, Optional, Any -import copy __all__ = ['IRGraph', 'IRAction'] - class IRGraph(IRCell): """ PyTorch IR Graph @@ -105,7 +103,7 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: # setup action inputs head = total_actions[0] for idx, arg in enumerate(args): - head.set_input(idx + 1, arg) # 0 is for graph itself + head.set_input(idx, arg) outputs_tensors = [*head.graph.outputs()] outputs_actions = [head] * len(head.graph.outputs()) for action in total_actions[1:]: @@ -114,7 +112,7 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: raise RuntimeError(f"Cannot find {input} tensors") pre_action = outputs_actions[outputs_tensors.index(input)] val = pre_action.map_output(input) - action.set_input(idx + 1, val) + action.set_input(idx, val) outputs_tensors += action.graph.outputs() outputs_actions += [action] * len(action.graph.outputs()) @@ -146,13 +144,10 @@ def backward(self, loss: IRTensor): def get_tensor_grad(tensor): if tensor._id not in all_tensors: - new_tensor = copy.deepcopy(tensor) - if tensor.name is None: - new_tensor.name = 'grad' - else: - new_tensor.name = tensor.name + '_grad' - new_tensor._src_nodes = list() - new_tensor._dst_nodes = list() + name = 'grad' if tensor.name is None else tensor.name + '_grad' + new_tensor = IRTensor( + shape=tensor.shape, name=name + ) # reverse op devices = set() for node in tensor.dst(): @@ -163,12 +158,21 @@ def get_tensor_grad(tensor): else: return all_tensors[tensor._id] + # backward graph inputs + graph_inputs = list() + # none outputs for loss + graph_outputs = list() + # nodes backward_nodes = list() + all_bp_tensors = list() for fnode in self._nodes[::-1]: inputs = list() for input in fnode.outputs(): if isinstance(input, IRTensor) and input.requires_grad: tensor = get_tensor_grad(input) + if tensor not in all_bp_tensors: + graph_inputs.append(tensor) + all_bp_tensors.append(tensor) inputs.append(tensor) else: inputs.append(None) @@ -176,6 +180,7 @@ def get_tensor_grad(tensor): for output in fnode.inputs(): if isinstance(output, IRTensor) and output.requires_grad: tensor = get_tensor_grad(output) + all_bp_tensors.append(tensor) outputs.append(tensor) else: outputs.append(None) @@ -191,13 +196,9 @@ def get_tensor_grad(tensor): for idx, arg in enumerate(outputs): bp_node.set_output(idx, arg) backward_nodes.append(bp_node) - # none inputs for loss - inputs = list() - # none outputs for loss - outputs = list() graph = IRGraph( backward_nodes, - inputs, outputs, + graph_inputs, graph_outputs, self.name + 'Backward' ) print(graph) @@ -292,7 +293,7 @@ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): module_name = global_graph.name ) - action_inputs = [self.graph] + [None] * len(self.graph.inputs()) + action_inputs = [None] * len(self.graph.inputs()) super().__init__( name = global_graph.tag, signature = signature, @@ -301,6 +302,7 @@ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): ) self.device = devices self._inputs = action_inputs + print(self.graph) def map_output(self, graph_output_tensor: Any) -> Any: if graph_output_tensor not in self.graph.outputs(): @@ -340,5 +342,7 @@ def add_flow(self, action): raise NotImplementedError def __repr__(self): - dscp = f'Action({self.name}):\n\t{self.graph.inputs()} -> {self.graph.outputs()}' + action_inputs = [f't{tensor._id}' for tensor in self.inputs()] + action_outputs = [f't{tensor._id}' for tensor in self.outputs()] + dscp = f'Action({self.name}):\n\t{self.graph.inputs()} ({action_inputs}) -> {self.graph.outputs()} ({action_outputs})' return dscp diff --git a/cube/graph/mapping.py b/cube/graph/mapping.py index 7156e18b..a94902bc 100644 --- a/cube/graph/mapping.py +++ b/cube/graph/mapping.py @@ -32,6 +32,8 @@ def map(signature: str) -> logic.GenericLogicalOp : __ttemplate('add') : logic.TensorAdd, + __ttemplate('sum') : logic.TensorSum, + # runtime collectives 'cube.runtime.spatial.move': 'move', diff --git a/cube/operator/logic/function.py b/cube/operator/logic/function.py index 52a1d290..2090ae7b 100644 --- a/cube/operator/logic/function.py +++ b/cube/operator/logic/function.py @@ -44,3 +44,14 @@ class TensorAdd(ElementSameInputOp): def __init__(self, signature: str): super().__init__(signature) + + +class TensorSum(GenericLogicalOp): + + @staticmethod + def candidates(): + raise NotImplementedError + + @staticmethod + def shape_infer(*args): + return [[1,],] diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 278313df..24a61270 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -1,6 +1,7 @@ from cube.tschedule.pool import TSchedulePool from cube.graph.ir_cten import IRTensor from torch import nn +import torch import cube.graph.parser as parser from cube.codegen.codegen import SScheduleCodeGen @@ -21,7 +22,8 @@ def forward(self, data): output = self.dropout(output) output = self.linear2(output) output = self.classifier(output) - return output + loss = torch.sum(output) + return loss model = FeedForward(dim=1024) @@ -73,4 +75,16 @@ def test_graph_backward(ir_graph): if __name__ == '__main__': #test_graph_forward(ir_graph) - test_graph_backward(ir_graph) \ No newline at end of file + test_graph_backward(ir_graph) + + + +""" +loss = cube.runtime.temporal.forward(model, input1, input2, xxx) +grad1, grad2, ... = cube.runtime.temporal.backward(loss, None) +""" + +""" +out1, out2 = cube.runtime.temporal.forward(model, input1) +cube.runtime.temporal.backward(out1, out2, out1_grad, out2_grad) +""" \ No newline at end of file From c11428f9a425219853b3d20c5914146cac8c75a5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Sep 2021 11:02:31 +0800 Subject: [PATCH 0193/1892] successor and predecessor flatten --- cube/graph/ir_cten.py | 14 ++++++++++---- cube/graph/ir_graph.py | 9 +++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/cube/graph/ir_cten.py b/cube/graph/ir_cten.py index ab8a75a2..07ef0130 100644 --- a/cube/graph/ir_cten.py +++ b/cube/graph/ir_cten.py @@ -93,7 +93,7 @@ def inputs(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") - def predecessors(self, index: Optional[int] = None): + def predecessors(self, index: Optional[int] = None) -> List: """ Get input operator at input index """ @@ -104,7 +104,10 @@ def predecessors(self, index: Optional[int] = None): ) return self._predecessors[index] elif index is None: - return self._predecessors + predecessors = list() + for pre_cells in self._predecessors: + predecessors += pre_cells + return predecessors else: raise TypeError("Expected index to be None or int") @@ -128,7 +131,7 @@ def outputs(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") - def successors(self, index: Optional[int] = None): + def successors(self, index: Optional[int] = None) -> List: """ Get output operator at output index @@ -144,7 +147,10 @@ def successors(self, index: Optional[int] = None): ) return self._successors[index] elif index is None: - return self._successors + successors = list() + for post_cells in self._successors: + successors += post_cells + return post_cells else: raise TypeError("Expected index to be None or int") diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index bdce3239..0d3fbdb9 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -144,10 +144,11 @@ def backward(self, loss: IRTensor): def get_tensor_grad(tensor): if tensor._id not in all_tensors: - name = 'grad' if tensor.name is None else tensor.name + '_grad' + #name = 'grad' if tensor.name is None else tensor.name + '_grad' new_tensor = IRTensor( - shape=tensor.shape, name=name + shape=tensor.shape, name=tensor.name ) + new_tensor._id = tensor._id # -> keep same tensor # reverse op devices = set() for node in tensor.dst(): @@ -212,8 +213,8 @@ def __repr__(self): # nodes for node in self._nodes: succ_node_ids = [None] * len(node.outputs()) - for out_idx, node_list in enumerate(node.successors()): - node_list = [snode._id for snode in node_list] + for out_idx in range(len(node.outputs())): + node_list = [snode._id for snode in node.successors(out_idx)] succ_node_ids[out_idx] = node_list dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" # outputs From 7e05bfbdeb0483892756f077aa0449d0b8162e37 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Sep 2021 14:09:33 +0800 Subject: [PATCH 0194/1892] about to finish code gen --- cube/codegen/codegen.py | 112 ++++++++++++++++++++---------- cube/graph/ir_cten.py | 8 +-- cube/graph/ir_graph.py | 40 ++++++++--- cube/graph/ir_seq.py | 102 +++++++++++++++++++++++++++ cube/graph/mapping.py | 1 + tests/tschedule/test_tschedule.py | 12 +++- 6 files changed, 220 insertions(+), 55 deletions(-) create mode 100644 cube/graph/ir_seq.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 4f6cb987..7e696e7d 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,10 +1,10 @@ """ Generate Pytorch code given the model DAG and the transformation config """ -from cube.tschedule.sequence import ASequence from typing import List, Any, Dict from cube.graph import IRAction, IRTensor, IROperation +from cube.graph.ir_seq import IRSequence from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -21,7 +21,9 @@ def __init__(self, action: IRAction): raise TypeError("graph should be IRGraph") self.graph = action.graph # model full code - self.code: List[str] = ['import torch', '', ''] + self.code: List[str] = [ + '########## Generated Code ###########', + 'import torch', '', ''] # module init code self.declare_region: List[str] = list() # module forward code @@ -142,12 +144,14 @@ def naming(self, tensor: Any) -> str: class TScheduleCodeGen: - def __init__(self, seq: ASequence): - if not isinstance(seq, ASequence): - raise TypeError("seq should be ASequence") + def __init__(self, seq: IRSequence): + if not isinstance(seq, IRSequence): + raise TypeError("seq should be IRSequence") self.seq = seq # model full code - self.code: List[str] = ['from typing import Tuple', 'import torch', '', ''] + self.code: List[str] = [ + '########## Generated Code ###########', + 'from typing import Tuple', 'import torch', '', ''] # module member name self.symbols = SymbolTable() @@ -156,19 +160,30 @@ def gen(self, device: int, outfile=None) -> str: Generate scheduling code based on the given actions """ actions = list() - for action in self.seq.actions(): + for action in self.seq: if device in action.device: actions.append(action) # {send: xxx, recv: xxx} action1 {send:xxx, recv:xxx} action2 .... - action_with_comms = list(dict()) + action_with_comms = [dict()] for action in actions: - send_tensors = [self.naming(tensor) for tensor in action.graph.send_tensors] - send_devices = action.graph.send_devices - send_shapes = tuple([tensor.shape for tensor in send_tensors]) - recv_tensors = [self.naming(tensor) for tensor in action.graph.recv_tensors] - recv_devices = action.graph.recv_devices - recv_shapes = tuple([tensor.shape for tensor in recv_tensors]) + num_send_tensors = len(action.send_tensors) + if num_send_tensors == 0: + send_tensors = list() + else: + send_tensors = action.outputs()[-num_send_tensors:] + send_tensors = [self.naming(tensor) for tensor in send_tensors] + send_devices = action.send_devices + send_shapes = tuple([tensor.shape for tensor in action.send_tensors]) + + num_recv_tensors = len(action.recv_tensors) + if num_recv_tensors == 0: + recv_tensors = list() + else: + recv_tensors = action.inputs()[-num_recv_tensors:] + recv_tensors = [self.naming(tensor) for tensor in recv_tensors] + recv_devices = action.recv_devices + recv_shapes = tuple([tensor.shape for tensor in action.recv_tensors]) comm = action_with_comms[-1] # recv before the action @@ -220,8 +235,9 @@ def emit_comm(self, comm: Dict) -> List[str]: # generate for send if ('send_tensors') in comm and ('recv_tensors' not in comm): + send_tensors = ', '.join(comm['send_tensors']) code = ssign.format( - send_tensors = comm['send_tensors'], + send_tensors = send_tensors, shapes = comm['send_shapes'], to_devices = comm['send_devices'] ) @@ -232,58 +248,78 @@ def emit_comm(self, comm: Dict) -> List[str]: shapes = comm['recv_shapes'], from_devices = comm['recv_devices'] ) - return_val = repr(tuple(comm['recv_tensors'])) + return_val = ','.join(comm['recv_tensors']) code = f'{return_val} = {body}' return code # generate for send + recv - else: + elif ('send_tensors' in comm) and ('recv_tensors' in comm): + send_tensors = ', '.join(comm['send_tensors']) body = srsign.format( - send_tensors = comm['send_tensors'], - shapes = comm['send_shapes'], + send_tensors = send_tensors, + send_shapes = comm['send_shapes'], to_devices = comm['send_devices'], recv_shapes = comm['recv_shapes'], from_devices = comm['recv_devices'] ) - return_val = repr(tuple(comm['recv_tensors'])) + return_val = ','.join(comm['recv_tensors']) code = f'{return_val} = {body}' return code + else: + return [] - def emit_action(self, action) -> List[str]: + def emit_action(self, action: IRAction) -> List[str]: """ Emit action code """ - fsign = 'cube.runtime.temporal.forward({model}, *inputs[{fid}]})' + fsign = 'cube.runtime.temporal.forward({model}, *{inputs})' bsign = 'cube.runtime.temporal.backward({input_tensors}, {output_tensors}, {output_grads})' - if action.tag == 'forward': + if action.name == 'forward': + inputs = [self.naming(tensor) for tensor in action.inputs()] + inputs = '(' + ', '.join(inputs) + ',)' body = fsign.format( model = 'model', - fid = 0 + inputs = inputs ) outputs = [self.naming(output) for output in action.outputs()] - return_val = repr(tuple(outputs)) + return_val = ','.join(outputs) code = f'{return_val} = {body}' return code - elif action.tag == 'backward': + + elif action.name == 'backward': # 1). input_tensors are forward inputs (happened before action inputs) - # 2). output_tensors are forward results (action.inputs()) + # => backward graph output tensor (share tensor in forward / backward graph) + # 2). output_tensors are forward outputs (action.inputs()) + # => backward graph input tensor (share tensor in forward / backward) # 3). output_grads are recved tesnors of this graph (graph.recv_tensors) - output_tensors = [self.naming(input) for input in action.inputs()] - output_grads = None - if len(action.graph.recv_tensors) != 0: - output_grads = [self.naming(tensor) for tensor in action.graph.recv_tensors] - output_grads = repr(tuple(output_grads)) + # => backward graph input tensor (graph.recv_tensors) + forward_inputs = self.seq.get_forward_inputs(action) + forward_inputs = [self.naming(tensor) for tensor in forward_inputs] + forward_inputs = '(' + ', '.join(forward_inputs) + ',)' + forward_outputs = self.seq.get_forward_outputs(action) + forward_outputs = [self.naming(tensor) for tensor in forward_outputs] + forward_outputs = '(' + ', '.join(forward_outputs) + ',)' + num_recv_tensors = len(action.recv_tensors) + if num_recv_tensors == 0: + recv_grads = list() + else: + recv_grads = action.inputs()[-num_recv_tensors:] + recv_grads = [self.naming(tensor) for tensor in recv_grads] + recv_grads = '(' + ','.join(recv_grads) + ',)' body = bsign.format( - input_tensors = None, - output_tensors = output_tensors, - output_grads = output_grads + input_tensors = forward_inputs, + output_tensors = forward_outputs, + output_grads = recv_grads ) # returned value are graph.outputs - return_val = [self.naming(tensor) for tensor in action.graph.outputs()] - return_val = repr(tuple(return_val)) - code = f'{return_val} = {body}' + return_val = [self.naming(tensor) for tensor in action.outputs()] + if len(return_val) > 0: + return_code = ', '.join(return_val) + ' = ' + else: + return_code = '' + code = f'{return_code}{body}' return code else: raise RuntimeError(f"Unsupported action tag: {action.tag}") diff --git a/cube/graph/ir_cten.py b/cube/graph/ir_cten.py index 07ef0130..0927cb24 100644 --- a/cube/graph/ir_cten.py +++ b/cube/graph/ir_cten.py @@ -89,7 +89,7 @@ def inputs(self, index: Optional[int] = None): ) return self._inputs[index] elif index is None: - return self._inputs + return copy.copy(self._inputs) else: raise TypeError("Expected index to be None or int") @@ -102,7 +102,7 @@ def predecessors(self, index: Optional[int] = None) -> List: raise RuntimeError( f"Get the input out of range ({index} >= {len(self._inputs)}" ) - return self._predecessors[index] + return copy.copy(self._predecessors[index]) elif index is None: predecessors = list() for pre_cells in self._predecessors: @@ -127,7 +127,7 @@ def outputs(self, index: Optional[int] = None): ) return self._outputs[index] elif index is None: - return self._outputs + return copy.copy(self._outputs) else: raise TypeError("Expected index to be None or int") @@ -145,7 +145,7 @@ def successors(self, index: Optional[int] = None) -> List: raise RuntimeError( f"Get the output out of range ({index} >= {len(self._outputs)}" ) - return self._successors[index] + return copy.copy(self._successors[index]) elif index is None: successors = list() for post_cells in self._successors: diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 0d3fbdb9..adee1cb4 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -1,3 +1,4 @@ +from torch._C import device from cube.graph.ir_cten import IRTensor, IRCell from cube.graph.ir_op import IROperation from cube.tschedule.pool import TSchedulePool @@ -153,7 +154,10 @@ def get_tensor_grad(tensor): devices = set() for node in tensor.dst(): devices.update(node.device) - new_tensor.device = list(devices) + devices = list(devices) + if len(devices) == 0: + devices = tensor.device + new_tensor.device = devices all_tensors[tensor._id] = new_tensor return new_tensor else: @@ -166,6 +170,15 @@ def get_tensor_grad(tensor): # nodes backward_nodes = list() all_bp_tensors = list() + # the first node: loss to none + # none_node = IROperation( + # name = 'tonone', + # signature = 'cube.temporal.to_none', + # input_length=1, + # output_length=1 + # ) + # none_node.set_input(0, loss) + for fnode in self._nodes[::-1]: inputs = list() for input in fnode.outputs(): @@ -207,7 +220,7 @@ def get_tensor_grad(tensor): graph(loss) def __repr__(self): - dscp = '' + dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs dscp += f'Inputs: {self._inputs}\n' # nodes @@ -218,7 +231,7 @@ def __repr__(self): succ_node_ids[out_idx] = node_list dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" # outputs - dscp += f'\nOutputs: {self._outputs}' + dscp += f"\nOutputs: {self._outputs}\n{'=' * len(self.name)}\n" return dscp @@ -229,6 +242,11 @@ def __repr__(self): class IRAction(IRCell): + """ + Action recv tensors must be inside of Action inputs, + and can be mapped to Action.graph.inputs + + """ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): @@ -302,6 +320,9 @@ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): output_length = len(self.graph.outputs()) ) self.device = devices + for output in self.outputs(): + if isinstance(output, IRTensor): + output.device = devices self._inputs = action_inputs print(self.graph) @@ -314,13 +335,13 @@ def map_output(self, graph_output_tensor: Any) -> Any: def happen_before(self, action): """ Check if the self -> (happened before) action + + Note: this may return false negative as it will only check + 1-hop dependency """ if not isinstance(action, IRAction): raise TypeError("Expected action to be an Action") - for pre_actions in self.successors(): - if action in pre_actions: - return True - return False + return self in action.predecessors() def happen_after(self, action): """ @@ -331,10 +352,7 @@ def happen_after(self, action): """ if not isinstance(action, IRAction): raise TypeError("Expected action to be an Action") - for pre_actions in self.predecessors(): - if action in pre_actions: - return True - return False + return self in action.successors() def add_flow(self, action): """ diff --git a/cube/graph/ir_seq.py b/cube/graph/ir_seq.py new file mode 100644 index 00000000..4a43db51 --- /dev/null +++ b/cube/graph/ir_seq.py @@ -0,0 +1,102 @@ +from typing import List, Tuple, NewType, Any, Optional +import numpy as np + +from cube.graph.ir_cten import IRCell, IRTensor +from cube.graph.ir_graph import IRAction + + +class IRSequence(IRCell): + + def __init__(self, actions: List[IRAction]): + + if not all([isinstance(action, IRAction) for action in actions]): + raise TypeError("Expected a list of IRActions") + + super().__init__( + name = 'action', + signature = 'None', + input_length = 0, + output_length = 0 + ) + self.sequence = actions + + def __iter__(self): + return iter(self.sequence) + + def __len__(self): + return len(self.sequence) + + def append(self, action: IRAction): + self.sequence.append(action) + + def get_forward_inputs(self, action: IRAction) -> List[Any]: + """ + Get corresponding forward action inputs + + The backward graph output tensor shuould be forward graph input tensor + """ + if action.name == 'forward': + return action.inputs() + if action.name == 'backward': + bp_graph_outputs = action.graph.outputs() + fw_action_inputs = [None] * len(bp_graph_outputs) + pre_actions = action.predecessors() + while len(pre_actions) != 0: + pre = list() + for pre_action in pre_actions: + if pre_action.name == 'forward': + for bidx, output in enumerate(bp_graph_outputs): + for fidx, input in enumerate(pre_action.graph.inputs()): + if input == output: + fw_action_inputs[bidx] = pre_action.inputs(fidx) + pre += pre_action.predecessors() + pre_actions = pre + if None in fw_action_inputs: + raise RuntimeError("Couldn't found forward inputs") + return fw_action_inputs + raise RuntimeError(f"Unsupported action name: {action.name}") + + def get_forward_outputs(self, action: IRAction) -> List[Any]: + """ + Get corresponding forward action outputs + + The backward graph input tensor should be forward graph output tensor + """ + if action.name == 'forward': + return action.inputs() + if action.name == 'backward': + bp_graph_inputs = action.graph.inputs() + fw_action_outputs = [None] * len(bp_graph_inputs) + pre_actions = action.predecessors() + while len(pre_actions) != 0: + pre = list() + for pre_action in pre_actions: + if pre_action.name == 'forward': + for bidx, output in enumerate(bp_graph_inputs): + for fidx, input in enumerate(pre_action.graph.outputs()): + if input == output: + fw_action_outputs[bidx] = pre_action.outputs(fidx) + pre += pre_action.predecessors() + pre_actions = pre + if None in fw_action_outputs: + raise RuntimeError("Couldn't found forward inputs") + return fw_action_outputs + raise RuntimeError(f"Unsupported action name: {action.name}") + + + def is_correct(self): + """ + Check whether sequence + satisfies the sequential consistency model + """ + + for index, action in enumerate(self.sequence): + for pre_action in action.predecessors(): + # find the pre-action not appear in sequence + if not pre_action in self.sequence: + return False + pre_idx = self.sequence.index(pre_action) + # violate sequential consistency model + if pre_idx >= index: + return False + return True \ No newline at end of file diff --git a/cube/graph/mapping.py b/cube/graph/mapping.py index a94902bc..17a58614 100644 --- a/cube/graph/mapping.py +++ b/cube/graph/mapping.py @@ -14,6 +14,7 @@ def map(signature: str) -> logic.GenericLogicalOp : """ if signature in IR2LogicOp.kOpMap: return IR2LogicOp.kOpMap[signature] + # return None raise KeyError(f"{signature} is not supported yet") # functional templates diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 24a61270..768e751c 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -1,5 +1,6 @@ from cube.tschedule.pool import TSchedulePool from cube.graph.ir_cten import IRTensor +from cube.graph.ir_seq import IRSequence from torch import nn import torch @@ -66,11 +67,18 @@ def test_graph_backward(ir_graph): TSchedulePool().clear() tensor = ir_graph(IRTensor(shape=[64,1024])) tensor.backward() - #tensor = ir_graph() - #tensor.backward() + tensor = ir_graph(IRTensor(shape=[64,1024])) + tensor.backward() print('====== Backward Test =======') print(TSchedulePool()) + sequence = IRSequence(TSchedulePool().actions()) + from cube.codegen.codegen import TScheduleCodeGen + gener = TScheduleCodeGen(sequence) + code = gener.gen(device=0) + code = gener.gen(device=1) + print(code) + if __name__ == '__main__': From 3b81868a676347eb3cd87a785c55f22ce4563772 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Sep 2021 18:00:15 +0800 Subject: [PATCH 0195/1892] action re-structure --- cube/codegen/codegen.py | 36 ++++---- cube/graph/ir_cten.py | 42 ++++++++- cube/graph/ir_graph.py | 64 ++++++------- cube/tschedule/sequence.py | 147 ------------------------------ tests/tschedule/test_tschedule.py | 26 +++++- 5 files changed, 113 insertions(+), 202 deletions(-) delete mode 100644 cube/tschedule/sequence.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 7e696e7d..d608aafc 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -167,25 +167,18 @@ def gen(self, device: int, outfile=None) -> str: # {send: xxx, recv: xxx} action1 {send:xxx, recv:xxx} action2 .... action_with_comms = [dict()] for action in actions: - num_send_tensors = len(action.send_tensors) - if num_send_tensors == 0: - send_tensors = list() - else: - send_tensors = action.outputs()[-num_send_tensors:] + # send info + send_tensors, send_devices = action.get_send_tensors() + send_shapes = tuple([tensor.shape for tensor in send_tensors]) send_tensors = [self.naming(tensor) for tensor in send_tensors] - send_devices = action.send_devices - send_shapes = tuple([tensor.shape for tensor in action.send_tensors]) - num_recv_tensors = len(action.recv_tensors) - if num_recv_tensors == 0: - recv_tensors = list() - else: - recv_tensors = action.inputs()[-num_recv_tensors:] + # recv info + recv_tensors, recv_devices = action.get_recv_tensors() + recv_shapes = tuple([tensor.shape for tensor in recv_tensors]) recv_tensors = [self.naming(tensor) for tensor in recv_tensors] - recv_devices = action.recv_devices - recv_shapes = tuple([tensor.shape for tensor in action.recv_tensors]) comm = action_with_comms[-1] + # recv before the action if len(recv_tensors) != 0: comm.update({ @@ -193,8 +186,10 @@ def gen(self, device: int, outfile=None) -> str: 'recv_devices' : recv_devices, 'recv_shapes' : recv_shapes }) + # action action_with_comms.append(action) + # send after the action comm = dict() if len(send_tensors) != 0: @@ -207,7 +202,7 @@ def gen(self, device: int, outfile=None) -> str: # generate code with FunctionBlock(func_name='_train_step', - args=['model', 'inputs: Tuple[Tuple[Tensor]]']) as fb: + args=['model', 'dataloader']) as fb: for action_or_comm in action_with_comms: if isinstance(action_or_comm, dict): code = self.emit_comm(action_or_comm) @@ -331,10 +326,13 @@ def naming(self, tensor: Any) -> str: If the var is a leaf tensor, will add prefix `self.` to its name """ if isinstance(tensor, IRTensor): - tensor_name = 'tensor' if tensor.name is None else tensor.name - if '.' in tensor_name: - tensor_name = tensor_name.split('.')[0] - name = '_'.join([tensor_name, str(tensor._id)]) + if len(tensor.src()) == 0: + name = '*next(dataloader)' + else: + tensor_name = 'tensor' if tensor.name is None else tensor.name + if '.' in tensor_name: + tensor_name = tensor_name.split('.')[0] + name = '_'.join([tensor_name, str(tensor._id)]) else: name = str(tensor) return name diff --git a/cube/graph/ir_cten.py b/cube/graph/ir_cten.py index 0927cb24..ea5aa3e2 100644 --- a/cube/graph/ir_cten.py +++ b/cube/graph/ir_cten.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional, Any +from typing import List, Union, Optional, Any, Tuple import copy from cube.graph.unique import IDGenerator @@ -223,6 +223,46 @@ def add_successor(self, tensor, node): raise RuntimeError("Fail to find output tensor") self._successors[out_index].append(node) + def get_send_tensors(self): + """ + Collect send tensors at cell level. + This will not care what happened inside this cell + + Returns: + send_tensors: list of IRTensor + send_devices: list of list[int] devices for each tensor + """ + send_tensors = list() + send_devices = list() + for idx, output in enumerate(self.outputs()): + if isinstance(output, IRTensor): + succ_cells = self.successors(idx) + for cell in succ_cells: + devices = set(cell.device) - set(output.device) + if len(devices) != 0: + send_tensors.append(output) + send_devices.append(list(devices)) + return send_tensors, send_devices + + def get_recv_tensors(self): + """ + Collect recv tensors at cell level. + This will not care what happened inside this cell + + Returns: + recv_tensors: list of IRTensor + recv_devices: list of list[int] devices for each tensor + """ + recv_tensors = list() + recv_devices = list() + for input in self.inputs(): + if isinstance(input, IRTensor): + devices = set(self.device) - set(input.device) + if len(devices) != 0: + recv_tensors.append(input) + recv_devices.append(list(devices)) + return recv_tensors, recv_devices + def __repr__(self): """ Cell string presentation diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index adee1cb4..2ce5f638 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -170,14 +170,6 @@ def get_tensor_grad(tensor): # nodes backward_nodes = list() all_bp_tensors = list() - # the first node: loss to none - # none_node = IROperation( - # name = 'tonone', - # signature = 'cube.temporal.to_none', - # input_length=1, - # output_length=1 - # ) - # none_node.set_input(0, loss) for fnode in self._nodes[::-1]: inputs = list() @@ -215,7 +207,7 @@ def get_tensor_grad(tensor): graph_inputs, graph_outputs, self.name + 'Backward' ) - print(graph) + # print(graph) graph.tag = 'backward' graph(loss) @@ -264,13 +256,11 @@ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): raise RuntimeError(f"Unsupported graph tag: {self.global_graph.tag}") # send tensors - self.send_tensors = list() - self.send_devices = list() - + send_tensors = list() + send_devices = list() # recv tensors - self.recv_tensors = list() - self.recv_devices = list() - + recv_tensors = list() + recv_devices = list() # get nodes belong to this graph all_tensors = list() for node in sub_nodes: @@ -279,9 +269,9 @@ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): if isinstance(input, IRTensor): recv_devices = list(set(devices) - set(input.device)) if len(recv_devices) != 0: - if input not in self.recv_tensors: - self.recv_tensors.append(input) - self.recv_devices.append(recv_devices) + if input not in recv_tensors: + recv_tensors.append(input) + recv_devices.append(recv_devices) # collect send tensors for output in node.outputs(): if isinstance(output, IRTensor): @@ -289,9 +279,9 @@ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): for succ_node in succ_nodes: send_devices = list(set(devices) - set(succ_node.device)) if len(send_devices) != 0: - if output not in self.send_tensors: - self.send_tensors.append(output) - self.send_devices.append(send_devices) + if output not in send_tensors: + send_tensors.append(output) + send_devices.append(send_devices) all_tensors += node.inputs() all_tensors += node.outputs() @@ -299,32 +289,42 @@ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): inputs = list() outputs = list() for input in global_graph.inputs(): - if input in all_tensors and input not in self.recv_tensors: + if input in all_tensors and input not in recv_tensors: inputs.append(input) for output in global_graph.outputs(): - if output in all_tensors and output not in self.send_tensors: + if output in all_tensors and output not in send_tensors: outputs.append(output) + self._send_ofst = len(outputs) + self._recv_ofst = len(inputs) + self.graph = IRGraph( nodes = sub_nodes, - input_tensors = inputs + self.recv_tensors, - output_tensors = outputs + self.send_tensors, + input_tensors = inputs + recv_tensors, + output_tensors = outputs + send_tensors, module_name = global_graph.name ) - action_inputs = [None] * len(self.graph.inputs()) super().__init__( name = global_graph.tag, signature = signature, - input_length = len(action_inputs), + input_length = len(self.graph.inputs()), output_length = len(self.graph.outputs()) ) + # set action device self.device = devices - for output in self.outputs(): - if isinstance(output, IRTensor): - output.device = devices - self._inputs = action_inputs - print(self.graph) + # set output shape + for output, g_out in zip(self.outputs(), self.graph.outputs()): + output.device = devices + output.shape = g_out.shape + + @property + def send_tensors(self): + return self._outputs[self._send_ofst:] + + @property + def recv_tensors(self): + return self._inputs[self._recv_ofst:] def map_output(self, graph_output_tensor: Any) -> Any: if graph_output_tensor not in self.graph.outputs(): diff --git a/cube/tschedule/sequence.py b/cube/tschedule/sequence.py deleted file mode 100644 index 2673752d..00000000 --- a/cube/tschedule/sequence.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import List, Tuple, NewType, Any -import numpy as np - - -class ASequence: - - def __init__(self, actions): - - self.sequence = actions - - def __iter__(self): - """ - Iterate on the actions - """ - return self.sequence - - def __len__(self) -> int: - """ - Get number of action in the sequence - """ - return len(self.sequence) - - def append(self, action): - self.sequence.append(action) - - def pop(self): - """ - Pop the last action and return - """ - if len(self.sequence) == 0: - return None - return self.sequence.pop() - - def is_correct(self): - """ - Check whether sequence - satisfies the sequential consistency model - """ - for index, action in enumerate(self.sequence): - for pre_action in action.pre_actions(): - # find the pre-action not appear in sequence - if not pre_action in self.sequence: - return False - pre_idx = self.sequence.index(pre_action) - # violate happened before - if pre_idx >= index: - return False - return True - - -# ======= Blow should be moved from this module ======== # - -Relation = NewType('Relation', List[Tuple[Any, Any]]) - - -class ScheduleSpace: - - @staticmethod - def tspace(remain_actions, - path_shuffle=True, - relations=None, - seq: ASequence = ASequence(list())): - """ - Iterate on the legal sequence space - """ - if len(remain_actions) == 0: - yield seq - # inital entry - if relations is None: - relations = ScheduleSpace._get_relations(remain_actions) - entry_actions = ScheduleSpace._ready_actions(remain_actions, relations) - entry_actions = np.array(entry_actions) - - # recursive search - if path_shuffle: - np.random.shuffle(entry_actions) - for aid, action in enumerate(entry_actions): - if len(seq) == 0: - print(f'> search progress: [{aid+1}/{len(entry_actions)}]...') - seq.append(action) - action_idx = remain_actions.index(action) - sub_actions = remain_actions[:action_idx] + remain_actions[action_idx+1:] - sub_relations = ScheduleSpace._remove_action(action, relations) - for res in ScheduleSpace.space(sub_actions, path_shuffle, sub_relations, seq): - yield res - seq.pop() - - - @staticmethod - def sspace(actions, ndevice: int, path_shuffle=True, depth=0): - """ - Iterate on the possible action space - """ - if depth == len(actions): - yield actions - return - action = actions[depth] - device_choice = np.array(list(range(ndevice)), dtype=np.int) - if path_shuffle: - np.random.shuffle(device_choice) - for device in device_choice: - action.device = device - for res in ScheduleSpace.sspace(actions, ndevice, path_shuffle, depth+1): - yield res - - - @staticmethod - def _ready_actions(actions, sub_relations: Relation): - """ - Get ready to emit actions based on sub_relations - """ - ready_actions = list() - for action in actions: - satisfy = True - for (_, succ) in sub_relations: - if succ == action: - satisfy = False - break - if satisfy: - ready_actions.append(action) - return ready_actions - - - @staticmethod - def _get_relations(actions) -> Relation: - """ - Get relation tuples (Action1 -> Action2) - """ - relations = list() - for action in actions: - relation = [(pre_action, action) for pre_action in action.pre_actions()] - if len(relation) != 0: - relations += relation - return relations - - - @staticmethod - def _remove_action(target, relations: Relation) -> Relation: - """ - Remove the target action from relation set - """ - sub_relations = list() - for (pre, succ) in relations: - if pre == target or succ == target: - continue - sub_relations.append((pre, succ)) - return sub_relations diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 768e751c..8405f62f 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -1,3 +1,4 @@ +from cube.graph.ir_graph import IRGraph from cube.tschedule.pool import TSchedulePool from cube.graph.ir_cten import IRTensor from cube.graph.ir_seq import IRSequence @@ -5,7 +6,7 @@ import torch import cube.graph.parser as parser -from cube.codegen.codegen import SScheduleCodeGen +from cube.codegen.codegen import SScheduleCodeGen, TScheduleCodeGen class FeedForward(nn.Module): @@ -65,9 +66,13 @@ def test_graph_forward(ir_graph): def test_graph_backward(ir_graph): TSchedulePool().clear() - tensor = ir_graph(IRTensor(shape=[64,1024])) + input = IRTensor(shape=[64,1024]) + input.device = [0] + tensor = ir_graph(input) tensor.backward() - tensor = ir_graph(IRTensor(shape=[64,1024])) + input = IRTensor(shape=[64,1024]) + input.device = [0] + tensor = ir_graph(input) tensor.backward() print('====== Backward Test =======') print(TSchedulePool()) @@ -79,6 +84,21 @@ def test_graph_backward(ir_graph): code = gener.gen(device=1) print(code) +def test_graph(ir_graph): + + datas = None + model: IRGraph = None + + @tschedule(model=ir_graph) + def train_step(model, datas): + for data in datas: + loss = model(data) + loss.backward() + + for epoch in range(10): + for datas in dataloader(bs=64, mbs=4): + train_step(model, datas) + if __name__ == '__main__': From e49bc5a36996bff0ec198820b02d702b67e74de6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Sep 2021 18:29:56 +0800 Subject: [PATCH 0196/1892] move action to subgraph creation --- cube/graph/ir_graph.py | 113 +++++++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 50 deletions(-) diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 2ce5f638..9438f38c 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -1,4 +1,3 @@ -from torch._C import device from cube.graph.ir_cten import IRTensor, IRCell from cube.graph.ir_op import IROperation from cube.tschedule.pool import TSchedulePool @@ -211,6 +210,66 @@ def get_tensor_grad(tensor): graph.tag = 'backward' graph(loss) + def subgraph(self, sub_nodes: List[IRCell]): + """ + Create a subgraph with sub nodes. + + The remote tensor will be set as graph input (recv tensors) + and graph output (send tensors) + + Return: + IRGraph, + recv tensor starting offset (int) in input, + send tensor starting offset (int) in output + """ + def _update(x_tensors, x_devices, tensor, devices): + if tensor not in x_tensors: + x_tensors.append(tensor) + x_devices.append(set(devices)) + else: + idx = x_tensors.index(tensor) + x_devices[idx].update(set(devices)) + + # recv tensors + recv_tensors = list() + recv_devices = list() + # send tensors + send_tensors = list() + send_devices = list() + # get nodes belong to this graph + all_tensors = list() + for node in sub_nodes: + # collect recv tensors + tensors_and_devices = node.get_recv_tensors() + for r_tensor, r_devices in zip(*tensors_and_devices): + _update(recv_tensors, recv_devices, r_tensor, r_devices) + # collect send tensors + tensors_and_devices = node.get_send_tensors() + for s_tensor, s_devices in zip(*tensors_and_devices): + _update(send_tensors, send_devices, s_tensor, s_devices) + all_tensors += node.inputs() + all_tensors += node.outputs() + + # set extra graph inputs and outputs + inputs = list() + outputs = list() + for input in self.inputs(): + if input in all_tensors and input not in recv_tensors: + inputs.append(input) + for output in self.outputs(): + if output in all_tensors and output not in send_tensors: + outputs.append(output) + + graph = IRGraph( + nodes = sub_nodes, + input_tensors = inputs + recv_tensors, + output_tensors = outputs + send_tensors, + module_name = self.name + ) + + return graph, len(inputs), len(outputs) + + def __repr__(self): dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs @@ -255,55 +314,9 @@ def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): else: raise RuntimeError(f"Unsupported graph tag: {self.global_graph.tag}") - # send tensors - send_tensors = list() - send_devices = list() - # recv tensors - recv_tensors = list() - recv_devices = list() - # get nodes belong to this graph - all_tensors = list() - for node in sub_nodes: - # collect recv tensors - for input in node.inputs(): - if isinstance(input, IRTensor): - recv_devices = list(set(devices) - set(input.device)) - if len(recv_devices) != 0: - if input not in recv_tensors: - recv_tensors.append(input) - recv_devices.append(recv_devices) - # collect send tensors - for output in node.outputs(): - if isinstance(output, IRTensor): - succ_nodes = output.dst() - for succ_node in succ_nodes: - send_devices = list(set(devices) - set(succ_node.device)) - if len(send_devices) != 0: - if output not in send_tensors: - send_tensors.append(output) - send_devices.append(send_devices) - all_tensors += node.inputs() - all_tensors += node.outputs() - - # action graph inputs and outputs - inputs = list() - outputs = list() - for input in global_graph.inputs(): - if input in all_tensors and input not in recv_tensors: - inputs.append(input) - for output in global_graph.outputs(): - if output in all_tensors and output not in send_tensors: - outputs.append(output) - - self._send_ofst = len(outputs) - self._recv_ofst = len(inputs) - - self.graph = IRGraph( - nodes = sub_nodes, - input_tensors = inputs + recv_tensors, - output_tensors = outputs + send_tensors, - module_name = global_graph.name - ) + self.graph, recv_ofst, send_ofst = global_graph.subgraph(sub_nodes) + self._recv_ofst = recv_ofst + self._send_ofst = send_ofst super().__init__( name = global_graph.tag, From ba2d8a0b4b03f680ebf24829cd3ebb2b29bf0a10 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Sep 2021 19:04:46 +0800 Subject: [PATCH 0197/1892] clean up forward procedure --- cube/graph/ir_graph.py | 46 ++++++++++++++++-------------------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 9438f38c..5e03cf96 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -78,20 +78,21 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: if not all([type(arg) is type(input) for arg, input in zip(args, self.inputs())]): raise RuntimeError(f"Expected input type the same") - curr_nodes: List[IROperation] = list() - curr_device = list() + curr_nodes: List[IRCell] = list() + curr_device = self.nodes(0).device total_actions = list() for node in self.nodes(): device = node.device if len(node.device) == 0: raise RuntimeError("All the node should be assigned to devices") - if set(device) != set(curr_device) and len(curr_device) != 0: + if set(device) != set(curr_device): # create action action = IRAction(curr_nodes, self, devices=curr_device) total_actions.append(action) # register to schedule space TSchedulePool().add_action(action) + # clear curr_nodes = list() curr_device = device curr_nodes.append(node) @@ -101,31 +102,25 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: TSchedulePool().add_action(action) # setup action inputs - head = total_actions[0] - for idx, arg in enumerate(args): - head.set_input(idx, arg) - outputs_tensors = [*head.graph.outputs()] - outputs_actions = [head] * len(head.graph.outputs()) - for action in total_actions[1:]: + output_map = { + gten._id : aten for gten, aten in zip(self.inputs(), args) + } + for action in total_actions: for idx, input in enumerate(action.graph.inputs()): - if input not in outputs_tensors: - raise RuntimeError(f"Cannot find {input} tensors") - pre_action = outputs_actions[outputs_tensors.index(input)] - val = pre_action.map_output(input) - action.set_input(idx, val) - outputs_tensors += action.graph.outputs() - outputs_actions += [action] * len(action.graph.outputs()) + if isinstance(input, IRTensor): + input = output_map[input._id] + action.set_input(idx, input) + for action_out, graph_out in zip(action.outputs(), action.graph.outputs()): + output_map[graph_out._id] = action_out # return tensors outputs = tuple(total_actions[-1].outputs()) for output in outputs: output.set_gen_graph(self) - if len(outputs) == 1: - return outputs[0] - elif len(outputs) == 0: - return None - else: - return outputs + + if len(outputs) == 1: return outputs[0] + elif len(outputs) == 0: return None + else: return outputs def __call__(self, *args): """ @@ -206,7 +201,6 @@ def get_tensor_grad(tensor): graph_inputs, graph_outputs, self.name + 'Backward' ) - # print(graph) graph.tag = 'backward' graph(loss) @@ -339,12 +333,6 @@ def send_tensors(self): def recv_tensors(self): return self._inputs[self._recv_ofst:] - def map_output(self, graph_output_tensor: Any) -> Any: - if graph_output_tensor not in self.graph.outputs(): - return None - index = self.graph.outputs().index(graph_output_tensor) - return self.outputs(index) - def happen_before(self, action): """ Check if the self -> (happened before) action From 94762d5e48d2dda52dfe9cd440e6b31953909e1d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 24 Sep 2021 13:49:19 +0800 Subject: [PATCH 0198/1892] fix bugs on recv from devices --- cube/graph/ir_cten.py | 2 +- cube/graph/ir_graph.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/graph/ir_cten.py b/cube/graph/ir_cten.py index ea5aa3e2..abd8a19f 100644 --- a/cube/graph/ir_cten.py +++ b/cube/graph/ir_cten.py @@ -257,7 +257,7 @@ def get_recv_tensors(self): recv_devices = list() for input in self.inputs(): if isinstance(input, IRTensor): - devices = set(self.device) - set(input.device) + devices = set(input.device) - set(self.device) if len(devices) != 0: recv_tensors.append(input) recv_devices.append(list(devices)) diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 5e03cf96..75382964 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -1,8 +1,7 @@ from cube.graph.ir_cten import IRTensor, IRCell from cube.graph.ir_op import IROperation -from cube.tschedule.pool import TSchedulePool -from typing import Union, Tuple, List, Optional, Any +from typing import Union, Tuple, List, Optional __all__ = ['IRGraph', 'IRAction'] @@ -69,6 +68,7 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: Returns: List[Action] """ + from cube.tschedule.pool import TSchedulePool # check input num if len(args) != len(self.inputs()): raise RuntimeError( From 751f01b0963ac9c2accdd1ff4808c665df868790 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 24 Sep 2021 14:24:34 +0800 Subject: [PATCH 0199/1892] generate code for spatial and temporal --- cube/__init__.py | 4 ++ cube/codegen/codegen.py | 61 +++++++++++---------- cube/graph/ir_seq.py | 2 +- cube/graph/parser/parser.py | 3 + cube/runtime/__init__.py | 1 + cube/runtime/collectives.py | 90 ++++++++++++++++++++++++++++++ cube/runtime/temporal.py | 38 +++++++++++++ cube/sschedule/__init__.py | 55 +++++++++++++++++++ cube/tschedule/__init__.py | 97 +++++++++++++++++++++++++++++++++ cube/tschedule/pool.py | 20 ------- examples/e2e.py | 102 ++++++++++++++++++++++++++++++++++ examples/ffn.py | 58 -------------------- examples/linear.py | 106 ------------------------------------ gencode0.py | 36 ++++++++++++ gencode1.py | 40 ++++++++++++++ 15 files changed, 498 insertions(+), 215 deletions(-) create mode 100644 cube/runtime/__init__.py create mode 100644 cube/runtime/collectives.py create mode 100644 cube/runtime/temporal.py create mode 100644 cube/sschedule/__init__.py create mode 100644 examples/e2e.py delete mode 100644 examples/ffn.py delete mode 100644 examples/linear.py create mode 100644 gencode0.py create mode 100644 gencode1.py diff --git a/cube/__init__.py b/cube/__init__.py index e69de29b..e8b31d7d 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -0,0 +1,4 @@ +from cube.device.physic.group import DeviceGroup +from cube import sschedule +from cube import tschedule +from cube import runtime \ No newline at end of file diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index d608aafc..7b99d212 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -9,6 +9,7 @@ from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock import torch +import copy class SScheduleCodeGen: @@ -21,8 +22,8 @@ def __init__(self, action: IRAction): raise TypeError("graph should be IRGraph") self.graph = action.graph # model full code - self.code: List[str] = [ - '########## Generated Code ###########', + self.init_code: List[str] = [ + '\n\n########## Generated Code ###########', 'import torch', '', ''] # module init code self.declare_region: List[str] = list() @@ -33,10 +34,11 @@ def __init__(self, action: IRAction): # ref module to check shared variables self._ref_module = torch.nn.Module() - def gen(self, device: int, outfile=None) -> str: + def gen(self, device: int, outfile=None, attach=False) -> str: """ Generate model implementation code based on the given graph. """ + gencode = copy.copy(self.init_code) # register forward input fargs = [self.naming(input) for input in self.graph.inputs()] for name in fargs: @@ -67,13 +69,13 @@ def gen(self, device: int, outfile=None) -> str: fb.insert_body(return_code) cb.insert_body('') cb.insert_body(fb.code) - self.code += cb.code - self.code += [''] + gencode += cb.code + gencode += [''] - code = '\n'.join(self.code) + code = '\n'.join(gencode) # write to file if outfile: - with open(outfile, 'w') as f: + with open(outfile, 'a' if attach else 'w') as f: f.write(code) return code @@ -149,16 +151,17 @@ def __init__(self, seq: IRSequence): raise TypeError("seq should be IRSequence") self.seq = seq # model full code - self.code: List[str] = [ - '########## Generated Code ###########', - 'from typing import Tuple', 'import torch', '', ''] + self.init_code: List[str] = [ + '\n\n########## Generated Code ###########', + 'import torch', 'import cube', ''] # module member name self.symbols = SymbolTable() - def gen(self, device: int, outfile=None) -> str: + def gen(self, device: int, outfile=None, attach=False) -> str: """ Generate scheduling code based on the given actions """ + gencode = copy.copy(self.init_code) actions = list() for action in self.seq: if device in action.device: @@ -210,13 +213,13 @@ def gen(self, device: int, outfile=None) -> str: else: code = self.emit_action(action_or_comm) fb.insert_body(code) - self.code += fb.code - self.code += [''] + gencode += fb.code + gencode += [''] - code = '\n'.join(self.code) + code = '\n'.join(gencode) # write to file if outfile: - with open(outfile, 'w') as f: + with open(outfile, 'a' if attach else 'w') as f: f.write(code) return code @@ -224,40 +227,38 @@ def emit_comm(self, comm: Dict) -> List[str]: """ Emit send / recv code """ - ssign = 'cube.runtime.spatial.send({send_tensors}, {shapes}, {to_devices})' - rsign = 'cube.runtime.spatial.recv({shapes}, {from_devices})' - srsign = 'cube.runtime.spatial.send_and_recv({send_tensors}, {send_shapes}, {to_devices}, {recv_shapes}, {from_devices})' + ssign = 'cube.runtime.collectives.send({send_tensors}, {to_devices})' + rsign = 'cube.runtime.collectives.recv({shapes}, {from_devices})' + srsign = 'cube.runtime.collectives.send_and_recv({send_tensors}, {to_devices}, {recv_shapes}, {from_devices})' # generate for send if ('send_tensors') in comm and ('recv_tensors' not in comm): - send_tensors = ', '.join(comm['send_tensors']) + send_tensors = '(' + ', '.join(comm['send_tensors'] + ['']) + ')' code = ssign.format( send_tensors = send_tensors, - shapes = comm['send_shapes'], to_devices = comm['send_devices'] ) - return code + return code + f" # send: {comm['send_shapes']}" # generate for recv elif ('send_tensors' not in comm) and ('recv_tensors' in comm): body = rsign.format( shapes = comm['recv_shapes'], from_devices = comm['recv_devices'] ) - return_val = ','.join(comm['recv_tensors']) + return_val = ', '.join(comm['recv_tensors']) code = f'{return_val} = {body}' return code # generate for send + recv elif ('send_tensors' in comm) and ('recv_tensors' in comm): - send_tensors = ', '.join(comm['send_tensors']) + send_tensors = '(' + ', '.join(comm['send_tensors'] + ['']) + ')' body = srsign.format( send_tensors = send_tensors, - send_shapes = comm['send_shapes'], to_devices = comm['send_devices'], recv_shapes = comm['recv_shapes'], from_devices = comm['recv_devices'] ) - return_val = ','.join(comm['recv_tensors']) - code = f'{return_val} = {body}' + return_val = ', '.join(comm['recv_tensors']) + code = f"{return_val} = {body} # send: {comm['send_shapes']}" return code else: return [] @@ -271,7 +272,7 @@ def emit_action(self, action: IRAction) -> List[str]: if action.name == 'forward': inputs = [self.naming(tensor) for tensor in action.inputs()] - inputs = '(' + ', '.join(inputs) + ',)' + inputs = '(' + ', '.join(inputs + ['']) + ')' body = fsign.format( model = 'model', inputs = inputs @@ -290,17 +291,17 @@ def emit_action(self, action: IRAction) -> List[str]: # => backward graph input tensor (graph.recv_tensors) forward_inputs = self.seq.get_forward_inputs(action) forward_inputs = [self.naming(tensor) for tensor in forward_inputs] - forward_inputs = '(' + ', '.join(forward_inputs) + ',)' + forward_inputs = '(' + ', '.join(forward_inputs + ['']) + ')' forward_outputs = self.seq.get_forward_outputs(action) forward_outputs = [self.naming(tensor) for tensor in forward_outputs] - forward_outputs = '(' + ', '.join(forward_outputs) + ',)' + forward_outputs = '(' + ', '.join(forward_outputs + ['']) + ')' num_recv_tensors = len(action.recv_tensors) if num_recv_tensors == 0: recv_grads = list() else: recv_grads = action.inputs()[-num_recv_tensors:] recv_grads = [self.naming(tensor) for tensor in recv_grads] - recv_grads = '(' + ','.join(recv_grads) + ',)' + recv_grads = '(' + ', '.join(recv_grads + ['']) + ')' body = bsign.format( input_tensors = forward_inputs, diff --git a/cube/graph/ir_seq.py b/cube/graph/ir_seq.py index 4a43db51..be754862 100644 --- a/cube/graph/ir_seq.py +++ b/cube/graph/ir_seq.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, NewType, Any, Optional +from typing import List, Any import numpy as np from cube.graph.ir_cten import IRCell, IRTensor diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 1fa0013c..4429bcba 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -159,6 +159,9 @@ def parse_aten_node(node, module, frame: Frame) -> List[IROperation]: input_val.append(val) maybe_kwarg = False input_val = input_val[::-1] + # handle single operand e.g., torch.sum + if input_val[1] is None: + input_val = input_val[:1] + input_val[2:] if len(input_val) < len(inputs): print(f"Warning: some non-tensor arguments are ommited in {fsig}") diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py new file mode 100644 index 00000000..7fe29b15 --- /dev/null +++ b/cube/runtime/__init__.py @@ -0,0 +1 @@ +from cube.runtime import collectives, temporal diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py new file mode 100644 index 00000000..04dc9fd2 --- /dev/null +++ b/cube/runtime/collectives.py @@ -0,0 +1,90 @@ +from typing import List + +import torch + + +def send(tensors, to_ranks: List[List[int]]): + """ + send tensor to the remote devices. Each tensor can be + sent to multiple devices + + Args: + tensors (List[torch.Tensor]): list of tensor to send + tensor_devices (List[List[int]]): tensor sent devices + """ + print('sending...') + send_ops = list() + for tensor, ranks in zip(tensors, to_ranks): + for rank in ranks: + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, rank + ) + send_ops.append(send_op) + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + + +def recv(shapes: List[List[int]], from_ranks: List[List[int]]): + print('recving...') + recv_ops = list() + recv_tensors = list() + for shape, ranks in zip(shapes, from_ranks): + if len(ranks) != 1: + raise RuntimeError( + "Not supported for recving same tensor from multiple devices" + ) + rank = ranks[0] + tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device() + ) + recv_tensors.append(tensor) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, rank + ) + recv_ops.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(recv_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + + if len(recv_tensors) == 0: return None + elif len(recv_tensors) == 1: return recv_tensors[0] + else: return tuple(recv_tensors) + + +def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): + print('sending and recving...') + ops = list() + recv_tensors = list() + for tensor, ranks in zip(send_tensors, to_ranks): + if not torch.is_tensor(tensor): + raise RuntimeError(f"Expected {tensor} to be tensor") + for rank in ranks: + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, rank + ) + ops.append(send_op) + for shape, ranks in zip(recv_shapes, from_ranks): + if len(ranks) != 1: + raise RuntimeError( + "Not supported for recving same tensor from multiple devices" + ) + rank = ranks[0] + tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device() + ) + recv_tensors.append(tensor) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, rank + ) + ops.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + + if len(recv_tensors) == 0: return None + elif len(recv_tensors) == 1: return recv_tensors[0] + else: return tuple(recv_tensors) diff --git a/cube/runtime/temporal.py b/cube/runtime/temporal.py new file mode 100644 index 00000000..5e8703ba --- /dev/null +++ b/cube/runtime/temporal.py @@ -0,0 +1,38 @@ +from typing import Tuple, Any +import torch + + +def forward(model, *input_tensors: Tuple[Any]): + """ + forward the model + """ + outputs = model(*input_tensors) + print('forwarding... ') + return outputs + + +def backward(input_tensors, output_tensors, output_tensor_grads): + """ + Backward on the tensors + """ + for tensor in input_tensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + tensor.retain_grad() + + # TODO: gen code should contain None in output_tensor_grads + if len(output_tensor_grads) != len(output_tensors): + output_tensor_grads = [None] * len(output_tensors) + + for tensor, grads in zip(output_tensors, output_tensor_grads): + print('backwarding... ') + torch.autograd.backward(tensor, grad_tensors=grads) + grads = list() + for tensor in input_tensors: + # print('backward input tensor: {}'.format(tensor)) + if torch.is_tensor(tensor) and tensor.requires_grad: + grads.append(tensor.grad) + else: + grads.append(None) + if len(grads) == 0: return None + elif len(grads) == 1: return grads[0] + else: return tuple(grads) diff --git a/cube/sschedule/__init__.py b/cube/sschedule/__init__.py new file mode 100644 index 00000000..622ac5c9 --- /dev/null +++ b/cube/sschedule/__init__.py @@ -0,0 +1,55 @@ +from cube.graph import parser +from cube.graph.ir_graph import IRGraph, IRAction +from cube.codegen.codegen import SScheduleCodeGen + + +class SpatialModule: + + def __init__(self, ir_graph): + # the full semantic graph + self._ir_graph = ir_graph + # the spatial pytorch module for specific rank + self._loaded_module = None + + def get_graph(self): + return self._ir_graph + + def gen_module(self, rank, outfile, attach=False) -> str: + """ + Set the module + """ + # TODO: support multiple graph segments + subnodes = [node for node in self._ir_graph.nodes() if node.on_device(rank)] + # subgraph = self._ir_graph.subgraph(subnodes) + action = IRAction(subnodes, self._ir_graph, devices=[rank]) + gener = SScheduleCodeGen(action) + code = gener.gen(device=rank, outfile=outfile, attach=attach) + return code + + def load_module(self, filename: str): + print(f'> loading generated spatial moduel from {filename}') + import importlib.util + spec = importlib.util.spec_from_file_location("GenModel", filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._loaded_module = module.GenModel().cuda() + + def get_gen_module(self): + return self._loaded_module + + def clear_module(self): + self._loaded_module = None + + +def schedule(module, input_shapes, policy_fn=None): + """ + Spatial schedule + + Returns: + IRGraph + """ + ir_graph = parser.convert(module, input_shapes=input_shapes) + module = SpatialModule(ir_graph) + if policy_fn: + module._ir_graph = policy_fn(module.get_graph()) + return module diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py index e69de29b..07ba260a 100644 --- a/cube/tschedule/__init__.py +++ b/cube/tschedule/__init__.py @@ -0,0 +1,97 @@ +from typing import Callable, Optional +import torch + +from cube.tschedule.pool import TSchedulePool +from cube.graph.ir_cten import IRTensor +from cube.graph.ir_seq import IRSequence +from cube.codegen.codegen import TScheduleCodeGen + + +class IRTesnorDataLoader: + + def __init__(self, dataloader): + self.dataloader = dataloader + + def __iter__(self): + return self + + def __next__(self): + datas = next(self.dataloader) + ir_datas = list() + for data in datas: + if torch.is_tensor(data): + tensor = IRTensor(shape=list(data.size()), name='input') + tensor.device = [0] + else: + tensor = data + ir_datas.append(tensor) + return tuple(ir_datas) + + +def schedule(model, dataloader, policy_fn: Optional[Callable] = None): + """ + AI Scientist calls like: + + @cube.tschedule.schedule + def train_step(model, dataloader): + # do a 4-time gradient accumulation + for acc_step, (data, label) in enumerate(dataloader): + if acc_step < 4: + loss = model(data, label) + loss.backward() + else: + break + ... + + for epoch in range(100): + train_step(model, data_loader) + optimizer.step() + optimizer.zero_grad() + + ... + """ + ir_graph = model.get_graph() + ir_dataloader = IRTesnorDataLoader(dataloader) + myrank = torch.distributed.get_rank() + + def _load_tschedule_fn(filename) -> Callable: + print(f'> [{myrank}] loading generated schedule from {filename} ...') + import importlib.util + spec = importlib.util.spec_from_file_location( + "_train_step", filename + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module._train_step + + def decorator(fn: Callable) -> Callable: + filename = 'gencode{}.py' + if myrank == 0: + TSchedulePool().clear() + # collect trace + fn(ir_graph, ir_dataloader) + actions = TSchedulePool().actions() + seq = IRSequence(actions) + # policy + if policy_fn: + seq = policy_fn(seq) + + world_size = torch.distributed.get_world_size() + tgener = TScheduleCodeGen(seq) + for rank in range(world_size): + fname = filename.format(rank) + # generate spatial module code + model.gen_module(rank, fname, attach=False) + # generate temporal schedule code + tgener.gen( + device = rank, + outfile = fname, + attach=True + ) + torch.distributed.barrier() + # load module + model.load_module(filename.format(myrank)) + # load temporal + return _load_tschedule_fn(filename.format(myrank)) + + return decorator diff --git a/cube/tschedule/pool.py b/cube/tschedule/pool.py index 21c5c5b4..545d204a 100644 --- a/cube/tschedule/pool.py +++ b/cube/tschedule/pool.py @@ -1,4 +1,3 @@ -from typing import Callable class TSchedulePool: @@ -39,22 +38,3 @@ def gen_id(self) -> int: def __repr__(self): dscp = '\n'.join([repr(action) for action in self._actions]) return dscp - - -def schedule(fn: Callable, policy=None, *args, **kwargs): - """ - AI Scientist calls like: - - @cube.tschedule.schedule - def train_step(model, optimizer, datas, labels): - for (data, label) in datas: - loss = model(data, label) - loss.backward() - optimizer.step() - optimizer.zero_grad() - ... - for datas, labels in dataloader(): - train_step(model, optimizer, datas, labels) - ... - """ - raise NotImplementedError \ No newline at end of file diff --git a/examples/e2e.py b/examples/e2e.py new file mode 100644 index 00000000..66819c5c --- /dev/null +++ b/examples/e2e.py @@ -0,0 +1,102 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/e2e.py +""" + +import torch +from torch import nn + +import cube +from cube.graph.ir_cten import IRTensor + +def spolicy(ir_graph): + + for input in ir_graph.inputs(): + if isinstance(input, IRTensor): + input.device = [0] + for nid, node in enumerate(ir_graph.nodes()): + if nid <= 2: + node.device = 0 + else: + node.device = 1 + return ir_graph + + +class FakeDataLoader: + def __init__(self, batch_size, num=32): + self.batch_size = batch_size + self.length = num + self.pos = 0 + def __iter__(self): + self.pos = 0 + return self + def __next__(self): + self.pos += 1 + if self.pos == self.length: + raise StopIteration + return (torch.randn((self.batch_size, 1024)).cuda(),) + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=16, classes=1000): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult, bias=False) + self.gelu = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim * mult, dim) + self.classifier = nn.Linear(dim, classes) + + def forward(self, data): + output = self.linear1(data) + output = self.gelu(output) + output = self.dropout(output) + output = self.linear2(output) + output = self.classifier(output) + loss = torch.sum(output) + return loss + +def init_weight(parameters): + for param in parameters: + with torch.no_grad(): + torch.nn.init.uniform_(param) + + +def train(): + model = FeedForward(dim=1024) + model = cube.sschedule.schedule( + model, input_shapes=([64,1024],), + policy_fn=spolicy + ) + + dataloader = FakeDataLoader(64) + + @cube.tschedule.schedule(model, dataloader) + def train_iter(model, dataloader): + for _ in range(4): + (data,) = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + init_weight(model.parameters()) + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + for epoch in range(100): + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + +if __name__ == '__main__': + + cube.DeviceGroup() + train() diff --git a/examples/ffn.py b/examples/ffn.py deleted file mode 100644 index 960adb28..00000000 --- a/examples/ffn.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F - -import argparse - - -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=4): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult) - self.gelu = nn.GELU() - self.linear2 = nn.Linear(dim * mult, dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.linear1(x) - x = self.gelu(x) - x = self.linear2(x) - x = self.dropout(x) - return x - - -def data_iter(bs, dim, classes, length=64): - for _ in range(length): - data = torch.randn((bs, dim)) - label = torch.randint(0, classes, (bs,)) - yield data, label - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--dim', type=int, default=1024) - parser.add_argument('--bs', type=int, default=8) - parser.add_argument('--classes', type=int, default=10) - args = parser.parse_args() - - model = torch.jit.script(FeedForward(args.dim).cuda()) - print(model.code) - - optimizer = torch.optim.Adam( - model.parameters(), - lr=0.001, - betas=(0.9, 0.99), - weight_decay=0 - ) - - for (data, label) in data_iter(args.bs, args.dim, args.classes): - data, label = data.cuda(), label.cuda() - # forward - output = model(data) - loss = F.cross_entropy(output, label) - # backward - loss.backward() - # weight update - optimizer.step() - optimizer.zero_grad() diff --git a/examples/linear.py b/examples/linear.py deleted file mode 100644 index f8118cad..00000000 --- a/examples/linear.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=6000 \ - --use_env \ - examples/linear.py -""" - -import cube -from cube import nn -from cube.tensor.logic.tensor import LogicalTensor -from cube.device.physic.group import DeviceGroup - -import torch -import argparse - -import z3 - -torch.manual_seed(100) - - -# Expert Policy - -def select_policy(holistic_ops, outputs, *args, **kwargs): - """ - Args: - Candidates: holistic_ops - *args, **kwargs: op input - """ - return holistic_ops.get_op(0, outputs, *args, **kwargs) - - -def segment_policy(holist_op, input, weight, bias): - """ - Args: - holistic_op (HolisticOp) - *args, **kwargs: op input - """ - solver = holist_op.solver - attributes = holist_op.attributes - weight_layout = holist_op.input_layouts[1] - - # add restrictions based on device num - holist_op.add_constraint(weight_layout.chunk_num == 4) - - # iterate all configs - configs = list() - for config in cube.config.choices(solver, attributes): - if DeviceGroup().rank == 0: - print('find config: \n', config) - configs.append(config) - - # choose one config -- policy decision - config = configs[0] - if DeviceGroup().rank == 0: - print('selected config: {}'.format(config)) - - # deploy decisions - chunk_num = config[weight_layout.chunk_num].as_long() - input_ranks = [list(range(0, chunk_num)),] - weight_ranks = list() - for rank in range(chunk_num): - weight_ranks.append([rank]) - bias_ranks = weight_ranks - return config, [input_ranks, weight_ranks, bias_ranks] - - -cube.operator.logic.linear.Linear.set_default_policy(select_policy) -cube.operator.holist.linear.LinearColumnWeight.set_default_policy(segment_policy) - - - -# User Network -class SingleLinear(nn.Module): - - def __init__(self, dim, mult): - super().__init__() - self.net = nn.Linear(dim, dim * mult) - - def forward(self, x): - output = self.net(x) - return output - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--bs', type=int, default=32) - parser.add_argument('--dim', type=int, default=128) - parser.add_argument('--mult', type=int, default=16) - args = parser.parse_args() - - # init distributed env - rank = DeviceGroup().rank - - model = SingleLinear(args.dim, args.mult) - - inputs = LogicalTensor((args.bs, args.dim)) - output = model(inputs) - - assert isinstance(output, LogicalTensor) - assert torch.is_tensor(output.get_physical_tensor(rank)) - print('Done.') diff --git a/gencode0.py b/gencode0.py new file mode 100644 index 00000000..6f22f362 --- /dev/null +++ b/gencode0.py @@ -0,0 +1,36 @@ + + +########## Generated Code ########### +import torch + + +class GenModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.weight_3 = torch.nn.Parameter(torch.empty((16384, 1024))) + + def forward(self, data_1): + tensor_4 = torch.nn.functional.linear(data_1, self.weight_3, None) + tensor_6 = torch.nn.functional.gelu(tensor_4) + tensor_8 = torch.nn.functional.dropout(tensor_6, 0.0, self.training, False) + return tensor_8 + + +########## Generated Code ########### +import torch +import cube + +def _train_step(model, dataloader): + tensor_21 = cube.runtime.temporal.forward(model, *(*next(dataloader), )) + tensor_51 = cube.runtime.collectives.send_and_recv((tensor_21, ), [[1]], ([64, 16384],), [[1]]) # send: ([64, 16384],) + cube.runtime.temporal.backward((), (tensor_21, ), (tensor_51, )) + tensor_54 = cube.runtime.temporal.forward(model, *(*next(dataloader), )) + tensor_84 = cube.runtime.collectives.send_and_recv((tensor_54, ), [[1]], ([64, 16384],), [[1]]) # send: ([64, 16384],) + cube.runtime.temporal.backward((), (tensor_54, ), (tensor_84, )) + tensor_87 = cube.runtime.temporal.forward(model, *(*next(dataloader), )) + tensor_117 = cube.runtime.collectives.send_and_recv((tensor_87, ), [[1]], ([64, 16384],), [[1]]) # send: ([64, 16384],) + cube.runtime.temporal.backward((), (tensor_87, ), (tensor_117, )) + tensor_120 = cube.runtime.temporal.forward(model, *(*next(dataloader), )) + tensor_150 = cube.runtime.collectives.send_and_recv((tensor_120, ), [[1]], ([64, 16384],), [[1]]) # send: ([64, 16384],) + cube.runtime.temporal.backward((), (tensor_120, ), (tensor_150, )) diff --git a/gencode1.py b/gencode1.py new file mode 100644 index 00000000..a476a4c2 --- /dev/null +++ b/gencode1.py @@ -0,0 +1,40 @@ + + +########## Generated Code ########### +import torch + + +class GenModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.weight_10 = torch.nn.Parameter(torch.empty((1024, 16384))) + self.bias_11 = torch.nn.Parameter(torch.empty((1024,))) + self.weight_14 = torch.nn.Parameter(torch.empty((1000, 1024))) + self.bias_15 = torch.nn.Parameter(torch.empty((1000,))) + + def forward(self, tensor_8): + tensor_12 = torch.nn.functional.linear(tensor_8, self.weight_10, self.bias_11) + tensor_16 = torch.nn.functional.linear(tensor_12, self.weight_14, self.bias_15) + tensor_17 = torch.sum(tensor_16) + return tensor_17 + + +########## Generated Code ########### +import torch +import cube + +def _train_step(model, dataloader): + tensor_21 = cube.runtime.collectives.recv(([64, 16384],), [[0]]) + tensor_23 = cube.runtime.temporal.forward(model, *(tensor_21, )) + tensor_51 = cube.runtime.temporal.backward((tensor_21, ), (tensor_23, ), ()) + tensor_54 = cube.runtime.collectives.send_and_recv((tensor_51, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) + tensor_56 = cube.runtime.temporal.forward(model, *(tensor_54, )) + tensor_84 = cube.runtime.temporal.backward((tensor_54, ), (tensor_56, ), ()) + tensor_87 = cube.runtime.collectives.send_and_recv((tensor_84, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) + tensor_89 = cube.runtime.temporal.forward(model, *(tensor_87, )) + tensor_117 = cube.runtime.temporal.backward((tensor_87, ), (tensor_89, ), ()) + tensor_120 = cube.runtime.collectives.send_and_recv((tensor_117, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) + tensor_122 = cube.runtime.temporal.forward(model, *(tensor_120, )) + tensor_150 = cube.runtime.temporal.backward((tensor_120, ), (tensor_122, ), ()) + cube.runtime.collectives.send((tensor_150, ), [[0]]) # send: ([64, 16384],) From 5cd34db0cc30f074e51848d859cdce5875d23b24 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 24 Sep 2021 15:22:19 +0800 Subject: [PATCH 0200/1892] backward grads will be None if not have --- cube/codegen/codegen.py | 5 ++-- cube/runtime/temporal.py | 5 ++-- gencode1.py | 8 +++--- tests/tschedule/test_tschedule.py | 46 ++++++------------------------- 4 files changed, 19 insertions(+), 45 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 7b99d212..5d8cf8f4 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -297,9 +297,10 @@ def emit_action(self, action: IRAction) -> List[str]: forward_outputs = '(' + ', '.join(forward_outputs + ['']) + ')' num_recv_tensors = len(action.recv_tensors) if num_recv_tensors == 0: - recv_grads = list() + recv_grads = [None] else: - recv_grads = action.inputs()[-num_recv_tensors:] + # recv_grads = action.inputs()[-num_recv_tensors:] + recv_grads = action.recv_tensors recv_grads = [self.naming(tensor) for tensor in recv_grads] recv_grads = '(' + ', '.join(recv_grads + ['']) + ')' diff --git a/cube/runtime/temporal.py b/cube/runtime/temporal.py index 5e8703ba..bca6125d 100644 --- a/cube/runtime/temporal.py +++ b/cube/runtime/temporal.py @@ -19,9 +19,10 @@ def backward(input_tensors, output_tensors, output_tensor_grads): if torch.is_tensor(tensor) and tensor.requires_grad: tensor.retain_grad() - # TODO: gen code should contain None in output_tensor_grads if len(output_tensor_grads) != len(output_tensors): - output_tensor_grads = [None] * len(output_tensors) + raise RuntimeError( + "Expected same length of out tensors and grads" + ) for tensor, grads in zip(output_tensors, output_tensor_grads): print('backwarding... ') diff --git a/gencode1.py b/gencode1.py index a476a4c2..0d5cde6e 100644 --- a/gencode1.py +++ b/gencode1.py @@ -27,14 +27,14 @@ def forward(self, tensor_8): def _train_step(model, dataloader): tensor_21 = cube.runtime.collectives.recv(([64, 16384],), [[0]]) tensor_23 = cube.runtime.temporal.forward(model, *(tensor_21, )) - tensor_51 = cube.runtime.temporal.backward((tensor_21, ), (tensor_23, ), ()) + tensor_51 = cube.runtime.temporal.backward((tensor_21, ), (tensor_23, ), (None, )) tensor_54 = cube.runtime.collectives.send_and_recv((tensor_51, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) tensor_56 = cube.runtime.temporal.forward(model, *(tensor_54, )) - tensor_84 = cube.runtime.temporal.backward((tensor_54, ), (tensor_56, ), ()) + tensor_84 = cube.runtime.temporal.backward((tensor_54, ), (tensor_56, ), (None, )) tensor_87 = cube.runtime.collectives.send_and_recv((tensor_84, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) tensor_89 = cube.runtime.temporal.forward(model, *(tensor_87, )) - tensor_117 = cube.runtime.temporal.backward((tensor_87, ), (tensor_89, ), ()) + tensor_117 = cube.runtime.temporal.backward((tensor_87, ), (tensor_89, ), (None, )) tensor_120 = cube.runtime.collectives.send_and_recv((tensor_117, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) tensor_122 = cube.runtime.temporal.forward(model, *(tensor_120, )) - tensor_150 = cube.runtime.temporal.backward((tensor_120, ), (tensor_122, ), ()) + tensor_150 = cube.runtime.temporal.backward((tensor_120, ), (tensor_122, ), (None, )) cube.runtime.collectives.send((tensor_150, ), [[0]]) # send: ([64, 16384],) diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 8405f62f..272dcb26 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -30,9 +30,9 @@ def forward(self, data): model = FeedForward(dim=1024) ir_graph = parser.convert(model, input_shapes=([64,1024],)) -print(" > Forward IRGraph ========") +print('====== Forward Graph =======\n') print(ir_graph) -print(" < ==============\n") +print('====== Forward Graph =======\n') # device assignment for input in ir_graph.inputs(): @@ -50,17 +50,13 @@ def test_graph_forward(ir_graph): TSchedulePool().clear() tensor1 = ir_graph(IRTensor(shape=[64,1024])) print(tensor1) - # print(tschedule.pool.TSchedulePool()) - # tensor2 = ir_graph() - # print(tensor2) + print('====== Forward Test =======') for action in TSchedulePool().actions(): - print('\n', action) - gener = SScheduleCodeGen(action.graph) + gener = SScheduleCodeGen(action) code = gener.gen(device=action.device[0]) - print("> ===== Generated Code =====") print(code) - print("< ===== Generated Code =====") print(TSchedulePool()) + print('\n====== Forward Test =======\n') def test_graph_backward(ir_graph): @@ -74,45 +70,21 @@ def test_graph_backward(ir_graph): input.device = [0] tensor = ir_graph(input) tensor.backward() - print('====== Backward Test =======') + print('====== Backward Test =======\n') print(TSchedulePool()) sequence = IRSequence(TSchedulePool().actions()) from cube.codegen.codegen import TScheduleCodeGen gener = TScheduleCodeGen(sequence) code = gener.gen(device=0) + print(code) code = gener.gen(device=1) print(code) -def test_graph(ir_graph): - - datas = None - model: IRGraph = None - - @tschedule(model=ir_graph) - def train_step(model, datas): - for data in datas: - loss = model(data) - loss.backward() - - for epoch in range(10): - for datas in dataloader(bs=64, mbs=4): - train_step(model, datas) + print('\n====== Backward Test =======\n') if __name__ == '__main__': - #test_graph_forward(ir_graph) + test_graph_forward(ir_graph) test_graph_backward(ir_graph) - - - -""" -loss = cube.runtime.temporal.forward(model, input1, input2, xxx) -grad1, grad2, ... = cube.runtime.temporal.backward(loss, None) -""" - -""" -out1, out2 = cube.runtime.temporal.forward(model, input1) -cube.runtime.temporal.backward(out1, out2, out1_grad, out2_grad) -""" \ No newline at end of file From 3efcf623fee222317c58c045e2b3e80778a5c2bd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 28 Sep 2021 20:04:52 +0800 Subject: [PATCH 0201/1892] policy for megatron: --- examples/case_study/megatron_policy.py | 157 +++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 examples/case_study/megatron_policy.py diff --git a/examples/case_study/megatron_policy.py b/examples/case_study/megatron_policy.py new file mode 100644 index 00000000..3d3a9b86 --- /dev/null +++ b/examples/case_study/megatron_policy.py @@ -0,0 +1,157 @@ +from typing import List + +# spatial +def select(tensor, indices, val_op, shape): pass +def assign(tensor, ranks: List): pass + +# temporal +def merge(su1, su2): pass + + +def spolicy(model, runtime_info, tp_size, dp_size, pp_size): + + n_devices = runtime_info.ndevs + + # each op is divided in (mp_dsize, dp_size) + # and put in (pp_size) stage + # TODO + devices = device_rank_group(n_devices, tp_size, dp_size, pp_size) + + # pipeline stage + total_nodes = len(model.nodes()) + num_op_per_stage = total_nodes // pp_size + for idx, op in enumerate(model.nodes()): + stage_id = idx // num_op_per_stage + assign(op, devices[stage_id]) + + # data parallel + for op in model.nodes(): + # data parallel algorithm (suppose at index 0) + dp_algo = op.logical_op.dist_algo(0) + dp_devices = op.device + sub_graph = select( + op = op, + algorithm = dp_algo, + config = dict(chunk_num=dp_size, uniform=True) + ) + for dp_op, tp_devices in zip(sub_graph, dp_devices): + assign(dp_op, tp_devices) + model.replace(op, sub_graph) + + # tensor parallel + # a transformer attention layer: + # [attention: col_split(mm + mm + mm) + row_split(mm)] + # a transformer feedforward layer: + # [feedforwrd: col_split(mm) + row_split(mm)] + for idx in range(total_nodes): + for dp_rank in range(dp_size): + op = model.nodes(dp_size * idx + dp_rank) + devices = op.devices + sub_graph = None + # Attention block + # [1st linear -> 2nd linear) + if first_to_2nd_linear(op): + # split column + tp_col_algo = op.logical_op.dist_algo(1) + sub_graph = select( + op = op, + algorithm = tp_col_algo, + config = dict(chunk_num=tp_size, uniform=True) + ) + # 2nd linear + elif is_2nd_linear(op): + # split row + tp_row_algo = op.logical_op.dist_algo(2) + sub_graph = select( + op = op, + algorithm = tp_row_algo, + config = dict(chunk_num=tp_size, uniform=True) + ) + # MLP block + # [3rd linear -> 4th linear] + elif thrid_to_4th_linear(op): + # split column + tp_col_algo = op.logical_op.dist_algo(1) + sub_graph = select( + op = op, + algorithm = tp_col_algo, + config = dict(chunk_num=tp_size, uniform=True) + ) + elif is_4th_linear(op): + # split row + tp_row_algo = op.logical_op.dist_algo(2) + sub_graph = select( + op = op, + algorithm = tp_row_algo, + config = dict(chunk_num=tp_size, uniform=True) + ) + # else: no change, do redundant computation + if sub_graph: + # assign device + for op, device in zip(sub_graph, devices): + assign(op, [device]) + model.replace(op, sub_graph) + return model + + +def tpolicy(sus, relations, tp_size, pp_size, num_microbatch): + """ + Pipeline 1f1b policy description -- generate a sequence + + Actions: a list of actions + + relations: list[(Action1, Action2)]: a list of tuples indicate partial order + """ + + # put sus to forward-backward sequences: List[List[SU(op)]] + fb_op_seqs = list() + for su in sus: + for fb_seq in fb_op_seqs: + if fb_seq[-1].happen_before(su): + fb_seq.append(su) + break + else: + fb_op_seqs.append([su]) + + # merge to stages: List[List[SU(stage of ops)]] + fb_stage_seqs = list() + for fb_seq in fb_op_seqs: + merged_su = fb_seq[0] + merged_tag = fb_seq[0].tag + for su in fb_seq[1]: + if su.device == merged_su and su.tag == merged_tag: + merged_su = merge(merged_su, su) + else: + fb_stage_seqs.append(merged_su) + merged_su = su + merged_tag = su.tag + merged_su = merge(merged_su, su) + + # pp_size forward + pp_size backward + assert (pp_size * 2 == len(fb_stage_seqs[0])) + + num_stage = pp_size + + f = lambda stage, micro_batch_id: fb_stage_seqs[micro_batch_id][stage] + b = lambda stage, micro_batch_id: fb_stage_seqs[micro_batch_id][num_stage + stage] + + sequence = list() + + # warmup: + for stage in range(num_stage): + for mid in range(stage): + sequence.append(f(stage, mid)) + + # steady + cooldown: + for mid in range(num_microbatch): + # enqueue backward + for stage in range(num_stage-1, -1, -1): + sequence.append(b(stage, mid)) + # enqueue forward + for stage in range(num_stage): + f_mid = mid + 1 + num_stage - stage + if f_mid >= num_microbatch: + continue + sequence.append(f(stage, f_mid)) + assert check_consistency(sequence, sus, relations) + return sequence From eb8beed9198fb724fca72eed43ab6043b3409cca Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 29 Sep 2021 16:37:49 +0800 Subject: [PATCH 0202/1892] group to device map --- examples/case_study/megatron_policy.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/case_study/megatron_policy.py b/examples/case_study/megatron_policy.py index 3d3a9b86..a27860c5 100644 --- a/examples/case_study/megatron_policy.py +++ b/examples/case_study/megatron_policy.py @@ -14,28 +14,27 @@ def spolicy(model, runtime_info, tp_size, dp_size, pp_size): # each op is divided in (mp_dsize, dp_size) # and put in (pp_size) stage - # TODO - devices = device_rank_group(n_devices, tp_size, dp_size, pp_size) + # TODO groups[stage][dp_group][tp_group] = devices (List[int]) + groups = parallel_group(n_devices, tp_size, dp_size, pp_size) # pipeline stage total_nodes = len(model.nodes()) num_op_per_stage = total_nodes // pp_size for idx, op in enumerate(model.nodes()): - stage_id = idx // num_op_per_stage - assign(op, devices[stage_id]) + pp_stage = idx // num_op_per_stage + op.group = [pp_stage] # data parallel for op in model.nodes(): # data parallel algorithm (suppose at index 0) dp_algo = op.logical_op.dist_algo(0) - dp_devices = op.device sub_graph = select( op = op, algorithm = dp_algo, config = dict(chunk_num=dp_size, uniform=True) ) - for dp_op, tp_devices in zip(sub_graph, dp_devices): - assign(dp_op, tp_devices) + for dp_stage, dp_op in sub_graph.nodes(): + dp_op.group.append(dp_stage) model.replace(op, sub_graph) # tensor parallel @@ -87,10 +86,14 @@ def spolicy(model, runtime_info, tp_size, dp_size, pp_size): ) # else: no change, do redundant computation if sub_graph: - # assign device - for op, device in zip(sub_graph, devices): - assign(op, [device]) + for tp_stage, op in enumerate(sub_graph): + op.group.append(tp_stage) model.replace(op, sub_graph) + # device assignment + for op in model.nodes(): + pp_stage, dp_stage, tp_stage = op.group + device = groups[pp_stage][dp_stage][tp_stage] + assign(op, device) return model From 70a21eb8a43e712323e8e15c93324dfc47ea1153 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 10 Oct 2021 16:09:54 +0800 Subject: [PATCH 0203/1892] update to align with system design --- cube/__init__.py | 3 +- cube/codegen/codegen.py | 330 +++++++++++---------- cube/graph/__init__.py | 2 +- cube/graph/ir_comm.py | 65 ++++ cube/graph/ir_cten.py | 266 +++++++++-------- cube/graph/ir_graph.py | 477 ++++++++++++++---------------- cube/graph/ir_op.py | 36 +-- cube/graph/ir_seq.py | 102 ------- cube/graph/unique.py | 2 +- cube/sschedule/__init__.py | 10 +- cube/sschedule/adapter.py | 44 +++ cube/sschedule/prim.py | 55 ++++ cube/tschedule/__init__.py | 6 +- cube/tschedule/pool.py | 14 +- cube/tschedule/su.py | 189 ++++++++++++ cube/tschedule/suseq.py | 200 +++++++++++++ tests/graph/test_graph.py | 63 ++++ tests/graph/test_parser.py | 16 +- tests/tschedule/test_tschedule.py | 111 +++++-- 19 files changed, 1262 insertions(+), 729 deletions(-) create mode 100644 cube/graph/ir_comm.py delete mode 100644 cube/graph/ir_seq.py create mode 100644 cube/sschedule/adapter.py create mode 100644 cube/sschedule/prim.py create mode 100644 cube/tschedule/su.py create mode 100644 cube/tschedule/suseq.py create mode 100644 tests/graph/test_graph.py diff --git a/cube/__init__.py b/cube/__init__.py index e8b31d7d..13952670 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,4 +1,5 @@ from cube.device.physic.group import DeviceGroup from cube import sschedule from cube import tschedule -from cube import runtime \ No newline at end of file +from cube import runtime + diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 5d8cf8f4..6e012ba8 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,10 +1,12 @@ """ Generate Pytorch code given the model DAG and the transformation config """ -from typing import List, Any, Dict +from typing import List, Any +from cube.graph.ir_comm import IRCommType, IRCommunication -from cube.graph import IRAction, IRTensor, IROperation -from cube.graph.ir_seq import IRSequence +from cube.graph.ir_cten import IRTensor +from cube.tschedule.suseq import SUSequence +from cube.tschedule.su import ScheduleUnit from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -17,10 +19,10 @@ class SScheduleCodeGen: Generate spatial code for the model """ - def __init__(self, action: IRAction): - if not isinstance(action, IRAction): - raise TypeError("graph should be IRGraph") - self.graph = action.graph + def __init__(self, seq: SUSequence): + if not isinstance(seq, SUSequence): + raise TypeError("seq should be SUSequence") + self.sus = seq.sus() # model full code self.init_code: List[str] = [ '\n\n########## Generated Code ###########', @@ -28,6 +30,7 @@ def __init__(self, action: IRAction): # module init code self.declare_region: List[str] = list() # module forward code + self.all_su_forward_region: List[List[str]] = list() self.forward_region: List[str] = list() # module member name self.symbols = SymbolTable() @@ -38,22 +41,39 @@ def gen(self, device: int, outfile=None, attach=False) -> str: """ Generate model implementation code based on the given graph. """ + device_sus = [su for su in self.sus if (device in su.device) and (su.tag == 'forward')] + gencode = copy.copy(self.init_code) + # register forward input - fargs = [self.naming(input) for input in self.graph.inputs()] - for name in fargs: - self.symbols.create(name) + su_args: List[List[str]] = list() + for su in device_sus: + fargs = [self.naming(input) for input in su.inputs()] + for name in fargs: + self.symbols.create(name) + su_args.append(fargs) # parse graph body - for node in self.graph.nodes(): - self.emit_op_call(node) - # emit input declaration - for arg in node.inputs(): - self.emit_var_declare(arg) - # record output tensor name - for out in node.outputs(): - if isinstance(out, IRTensor) or isinstance(out, str): - self.symbols.create(self.naming(out)) + print(f'device: {device}: {device_sus}') + for su in device_sus: + print('====', su) + for node in su.nodes(): + print(node) + for node in su.nodes(): + if isinstance(node, IRCommunication): + self.emit_comm_call(node, su) + else: + self.emit_op_call(node, su) + # emit input declaration + for arg in node.inputs(): + self.emit_var_declare(arg) + # record output tensor name + for out in node.outputs(): + if isinstance(out, IRTensor) or isinstance(out, str): + self.symbols.create(self.naming(out)) + print(self.forward_region) + self.all_su_forward_region.append(self.forward_region) + self.forward_region = list() # generate full code with ClassBlock(class_name='GenModel', derived=['torch.nn.Module']) as cb: @@ -61,14 +81,18 @@ def gen(self, device: int, outfile=None, attach=False) -> str: ib.insert_body(self.declare_region) cb.insert_body('') cb.insert_body(ib.code) - with FunctionBlock(func_name='forward', args=['self']+fargs) as fb: - fb.insert_body(self.forward_region) - # generate output - out_names = self._forward_region_arg_names(self.graph.outputs()) - return_code = f"return {', '.join(out_names)}" - fb.insert_body(return_code) - cb.insert_body('') - cb.insert_body(fb.code) + for idx, su in enumerate(device_sus): + name = f'su{self.sus.index(su)}' + input_args = ['self'] + su_args[idx] + forward_code = self.all_su_forward_region[idx] + with FunctionBlock(func_name=name, args=input_args) as fb: + fb.insert_body(forward_code) + # generate output + out_names = self._forward_region_arg_names(su.outputs(), su) + return_code = f"return {', '.join(out_names)}" + fb.insert_body(return_code) + cb.insert_body('') + cb.insert_body(fb.code) gencode += cb.code gencode += [''] @@ -77,6 +101,9 @@ def gen(self, device: int, outfile=None, attach=False) -> str: if outfile: with open(outfile, 'a' if attach else 'w') as f: f.write(code) + + # clear used buffer + self.clear() return code def emit_var_declare(self, var: Any): @@ -99,33 +126,59 @@ def emit_var_declare(self, var: Any): self.declare_region.append(code) return - def emit_op_call(self, node: IROperation): + def emit_op_call(self, node, su: ScheduleUnit): """ Emit op forward code """ op_code = node.signature - out_names = self._forward_region_arg_names(node.outputs()) - out_names = ', '.join(out_names) - arg_names = self._forward_region_arg_names(node.inputs()) + arg_names = self._forward_region_arg_names(node.inputs(), su) arg_region = '(' + ', '.join(arg_names) + ')' - code = f'{out_names} = {op_code}{arg_region}' + if len(node.outputs()) == 0: + code = f'{op_code}{arg_region}' + else: + out_names = self._forward_region_arg_names(node.outputs(), su) + out_names = ', '.join(out_names) + code = f'{out_names} = {op_code}{arg_region}' + self.forward_region.append(code) + + def emit_comm_call(self, node, su: ScheduleUnit): + """ + Emit communication code + """ + comm_code = node.signature + send_tensors = self._forward_region_arg_names(node.inputs(), su) + send_ranks = node.send_ranks + recv_tensors = self._forward_region_arg_names(node.outputs(), su) + recv_shapes = [tensor.shape for tensor in node.outputs()] + recv_ranks = node.recv_ranks + if node.comm_type == IRCommType.Send: + send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' + code = f'{comm_code}({send_tensors}, {send_ranks})' + elif node.comm_type == IRCommType.Recv: + recv_tensors = '(' + ', '.join(recv_tensors + ['']) + ')' + code = f'{recv_tensors} = {comm_code}({recv_shapes}, {recv_ranks})' + elif node.comm_type == IRCommType.SendRecv: + send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' + recv_tensors = '(' + ', '.join(recv_tensors + ['']) + ')' + code = f'{recv_tensors} = {comm_code}({send_tensors}, {send_ranks}, {recv_shapes}, {recv_ranks})' + else: + raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") self.forward_region.append(code) - def _forward_region_arg_names(self, args: List[Any]): + def _forward_region_arg_names(self, tensors: List[Any], su: ScheduleUnit): """ Generate arg name list for forward region. Will add prefix 'self.' for var defined in declare region """ named_args : List[str] = list() - input_name = [self.naming(input) for input in self.graph.inputs()] - for arg in args: - name = self.naming(arg) - if isinstance(arg, IRTensor) and \ - arg.is_leaf() and (name not in input_name): - named_args.append('self.' + self.naming(arg)) + for tensor in tensors: + name = self.naming(tensor) + if isinstance(tensor, IRTensor) and \ + tensor.is_leaf(su.nodes()) and (tensor not in su.inputs()): + named_args.append('self.' + name) else: - named_args.append(self.naming(arg)) + named_args.append(self.naming(name)) return named_args def naming(self, tensor: Any) -> str: @@ -143,12 +196,25 @@ def naming(self, tensor: Any) -> str: name = str(tensor) return name + def clear(self): + """ + Clear buffer that used for generating code + """ + # module init code + self.declare_region: List[str] = list() + # module forward code + self.all_su_forward_region: List[List[str]] = list() + self.forward_region: List[str] = list() + # module member name + self.symbols = SymbolTable() + + class TScheduleCodeGen: - def __init__(self, seq: IRSequence): - if not isinstance(seq, IRSequence): - raise TypeError("seq should be IRSequence") + def __init__(self, seq: SUSequence): + if not isinstance(seq, SUSequence): + raise TypeError("seq should be SUSequence") self.seq = seq # model full code self.init_code: List[str] = [ @@ -159,60 +225,20 @@ def __init__(self, seq: IRSequence): def gen(self, device: int, outfile=None, attach=False) -> str: """ - Generate scheduling code based on the given actions + Generate scheduling code based on the given sus """ gencode = copy.copy(self.init_code) - actions = list() - for action in self.seq: - if device in action.device: - actions.append(action) - - # {send: xxx, recv: xxx} action1 {send:xxx, recv:xxx} action2 .... - action_with_comms = [dict()] - for action in actions: - # send info - send_tensors, send_devices = action.get_send_tensors() - send_shapes = tuple([tensor.shape for tensor in send_tensors]) - send_tensors = [self.naming(tensor) for tensor in send_tensors] - - # recv info - recv_tensors, recv_devices = action.get_recv_tensors() - recv_shapes = tuple([tensor.shape for tensor in recv_tensors]) - recv_tensors = [self.naming(tensor) for tensor in recv_tensors] - - comm = action_with_comms[-1] - - # recv before the action - if len(recv_tensors) != 0: - comm.update({ - 'recv_tensors' : recv_tensors, - 'recv_devices' : recv_devices, - 'recv_shapes' : recv_shapes - }) - - # action - action_with_comms.append(action) - - # send after the action - comm = dict() - if len(send_tensors) != 0: - comm.update({ - 'send_tensors' : send_tensors, - 'send_devices' : send_devices, - 'send_shapes' : send_shapes - }) - action_with_comms.append(comm) + device_sus = [su for su in self.seq.sus() if device in su.device] # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: - for action_or_comm in action_with_comms: - if isinstance(action_or_comm, dict): - code = self.emit_comm(action_or_comm) - fb.insert_body(code) - else: - code = self.emit_action(action_or_comm) - fb.insert_body(code) + data_code = self.emit_data(device_sus) + fb.insert_body(data_code) + for su in device_sus: + name = f'su{self.seq.sus().index(su)}' + code = self.emit_su(su, name=name) + fb.insert_body(code) gencode += fb.code gencode += [''] @@ -223,95 +249,76 @@ def gen(self, device: int, outfile=None, attach=False) -> str: f.write(code) return code - def emit_comm(self, comm: Dict) -> List[str]: + def emit_data(self, device_sus) -> List[str]: """ - Emit send / recv code + Emit dataloader iter code """ - ssign = 'cube.runtime.collectives.send({send_tensors}, {to_devices})' - rsign = 'cube.runtime.collectives.recv({shapes}, {from_devices})' - srsign = 'cube.runtime.collectives.send_and_recv({send_tensors}, {to_devices}, {recv_shapes}, {from_devices})' - - # generate for send - if ('send_tensors') in comm and ('recv_tensors' not in comm): - send_tensors = '(' + ', '.join(comm['send_tensors'] + ['']) + ')' - code = ssign.format( - send_tensors = send_tensors, - to_devices = comm['send_devices'] - ) - return code + f" # send: {comm['send_shapes']}" - # generate for recv - elif ('send_tensors' not in comm) and ('recv_tensors' in comm): - body = rsign.format( - shapes = comm['recv_shapes'], - from_devices = comm['recv_devices'] - ) - return_val = ', '.join(comm['recv_tensors']) - code = f'{return_val} = {body}' - return code - # generate for send + recv - elif ('send_tensors' in comm) and ('recv_tensors' in comm): - send_tensors = '(' + ', '.join(comm['send_tensors'] + ['']) + ')' - body = srsign.format( - send_tensors = send_tensors, - to_devices = comm['send_devices'], - recv_shapes = comm['recv_shapes'], - from_devices = comm['recv_devices'] - ) - return_val = ', '.join(comm['recv_tensors']) - code = f"{return_val} = {body} # send: {comm['send_shapes']}" - return code - else: - return [] + # TODO: dataloader to op node + inputs = list() + for su in device_sus: + su_inputs = [ + self.naming(input, su) for input in su.inputs() \ + if input.is_leaf(device_sus) + ] + inputs += su_inputs + data_code = list() + if len(inputs) != 0: + inputs = '(' + ', '.join(inputs + ['']) + ')' + data_code.append(inputs + ' = next(dataloader)') + return data_code - def emit_action(self, action: IRAction) -> List[str]: + def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: """ - Emit action code + Emit su code """ fsign = 'cube.runtime.temporal.forward({model}, *{inputs})' bsign = 'cube.runtime.temporal.backward({input_tensors}, {output_tensors}, {output_grads})' - if action.name == 'forward': - inputs = [self.naming(tensor) for tensor in action.inputs()] + if su.tag == 'forward': + inputs = [self.naming(tensor, su) for tensor in su.inputs()] inputs = '(' + ', '.join(inputs + ['']) + ')' body = fsign.format( - model = 'model', + model = f'model.{name}', inputs = inputs ) - outputs = [self.naming(output) for output in action.outputs()] + outputs = [self.naming(output, su) for output in su.outputs()] return_val = ','.join(outputs) - code = f'{return_val} = {body}' + if len(su.outputs()) == 0: + code = body + else: + code = f'{return_val} = {body}' return code - elif action.name == 'backward': - # 1). input_tensors are forward inputs (happened before action inputs) + elif su.tag == 'backward': + # 1). input_tensors are forward inputs (happened before su inputs) # => backward graph output tensor (share tensor in forward / backward graph) - # 2). output_tensors are forward outputs (action.inputs()) + # 2). output_tensors are forward outputs (su.inputs()) # => backward graph input tensor (share tensor in forward / backward) # 3). output_grads are recved tesnors of this graph (graph.recv_tensors) # => backward graph input tensor (graph.recv_tensors) - forward_inputs = self.seq.get_forward_inputs(action) - forward_inputs = [self.naming(tensor) for tensor in forward_inputs] + fsu = su.mirror + forward_inputs = [self.naming(tensor, fsu) for tensor in fsu.inputs()] forward_inputs = '(' + ', '.join(forward_inputs + ['']) + ')' - forward_outputs = self.seq.get_forward_outputs(action) - forward_outputs = [self.naming(tensor) for tensor in forward_outputs] + forward_outputs = [self.naming(tensor, fsu) for tensor in fsu.outputs()] forward_outputs = '(' + ', '.join(forward_outputs + ['']) + ')' - num_recv_tensors = len(action.recv_tensors) - if num_recv_tensors == 0: - recv_grads = [None] - else: - # recv_grads = action.inputs()[-num_recv_tensors:] - recv_grads = action.recv_tensors - recv_grads = [self.naming(tensor) for tensor in recv_grads] - recv_grads = '(' + ', '.join(recv_grads + ['']) + ')' + + grads = list() + for tensor in su.inputs(): + # the thensor is loss, no grad needs + if tensor in fsu.outputs(): + grads.append('None') + else: + grads.append(self.naming(tensor, su)) + grads = '(' + ', '.join(grads + ['']) + ')' body = bsign.format( input_tensors = forward_inputs, output_tensors = forward_outputs, - output_grads = recv_grads + output_grads = grads ) # returned value are graph.outputs - return_val = [self.naming(tensor) for tensor in action.outputs()] + return_val = [self.naming(tensor, su) for tensor in su.outputs()] if len(return_val) > 0: return_code = ', '.join(return_val) + ' = ' else: @@ -319,22 +326,23 @@ def emit_action(self, action: IRAction) -> List[str]: code = f'{return_code}{body}' return code else: - raise RuntimeError(f"Unsupported action tag: {action.tag}") + raise RuntimeError(f"Unsupported su tag: {su.tag}") - def naming(self, tensor: Any) -> str: + def naming(self, tensor: Any, su) -> str: """ Return the var name (unique for different variable) If the var is a leaf tensor, will add prefix `self.` to its name """ if isinstance(tensor, IRTensor): - if len(tensor.src()) == 0: - name = '*next(dataloader)' - else: - tensor_name = 'tensor' if tensor.name is None else tensor.name - if '.' in tensor_name: - tensor_name = tensor_name.split('.')[0] - name = '_'.join([tensor_name, str(tensor._id)]) + # note in su there is no parameters + # if len(tensor.src(su.nodes())) == 0: + # name = '*next(dataloader)' + # else: + tensor_name = 'tensor' if tensor.name is None else tensor.name + if '.' in tensor_name: + tensor_name = tensor_name.split('.')[0] + name = '_'.join([tensor_name, str(tensor._id)]) else: name = str(tensor) return name diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index 07158ba6..1cfb314f 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -1,4 +1,4 @@ -from cube.graph.ir_graph import IRGraph, IRAction +from cube.graph.ir_graph import IRGraph from cube.graph.ir_cten import IRTensor, IRCell from cube.graph.ir_op import IROperation from cube.graph import parser diff --git a/cube/graph/ir_comm.py b/cube/graph/ir_comm.py new file mode 100644 index 00000000..7d5642d5 --- /dev/null +++ b/cube/graph/ir_comm.py @@ -0,0 +1,65 @@ +from typing import List +from enum import Enum + +from cube.graph.ir_cten import IRCell + + +class IRCommType(Enum): + + Send = 'send' + Recv = 'recv' + SendRecv = 'sendrecv' + + +class IRCommunication(IRCell): + """ + Communication cell for IRCell + """ + + def __init__(self, + send_tensors=list(), send_ranks: List[List[int]] = list(), + recv_tensors=list(), recv_ranks: List[List[int]] =list()): + """ + Create a basic send, recv or sendrecv communication node + """ + if len(send_tensors) != 0 and len(recv_tensors) != 0: + comm_type = IRCommType.SendRecv + signature = 'cube.runtime.collectives.sendrecv' + elif len(send_tensors) != 0 and len(recv_tensors) == 0: + comm_type = IRCommType.Send + signature = 'cube.runtime.collectives.send' + elif len(recv_tensors) != 0 and len(send_tensors) == 0: + comm_type = IRCommType.Recv + signature = 'cube.runtime.collectives.recv' + else: + raise ValueError( + "Expected at least one of send_tensors and recv_tensors" + ) + + self.comm_type = comm_type + self.send_tensors = list() + self.send_ranks = list() + self.recv_tensors = list() + self.recv_ranks = list() + + super().__init__( + name = comm_type.value, + signature = signature, + input_length = len(send_tensors), + output_length = len(recv_tensors) + ) + + for idx, (tensor, to_device) in enumerate(zip(send_tensors, send_ranks)): + self.set_input(idx, tensor) + self.send_tensors.append(self.inputs(idx)) + self.send_ranks.append(to_device) + + for idx, (tensor, from_device) in enumerate(zip(recv_tensors, recv_ranks)): + self.set_output(idx, tensor) + self.recv_tensors.append(self.outputs(idx)) + self.recv_ranks.append(from_device) + + def merge(self, other): + if not isinstance(other, IRCommunication): + raise RuntimeError("Expected IRCommunication to merge") + raise NotImplementedError diff --git a/cube/graph/ir_cten.py b/cube/graph/ir_cten.py index abd8a19f..d8224671 100644 --- a/cube/graph/ir_cten.py +++ b/cube/graph/ir_cten.py @@ -1,4 +1,20 @@ -from typing import List, Union, Optional, Any, Tuple +""" +IRCell: + a graph node component serving for different purpose, + e.g., operator, device graph, graph + +IRTensor: + Tensor representation serving for edges to connect IRCells + +The input of IRCell are IRTensors or any deterministic values (e.g., int). +If an IRTensor is the input of Cell, then Cell.device \in IRTensor.deivce + +The output of IRCell are IRTensors or any deterministic values (e.g., int) +If an IRTensor is the output of Cell, then Cell.device == IRTensor.device +""" + + +from typing import List, Union, Optional, Any import copy from cube.graph.unique import IDGenerator @@ -32,20 +48,22 @@ def __init__(self, self.name: str = name self.signature = signature - # device self._device = list() # source tensors self._inputs: List[Any] = [None] * input_length - # source cells - self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length)] # destination tensors self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] - for output in self._outputs: - output.add_src_node(self) + for tensor in self._outputs: + tensor.attach_cell(self) + # destination cells + # -- will only be set when initializing to a graph self._successors: List[List[IRCell]] = [list() for _ in range(output_length)] + # source cells: note a tensor can be generated by many cells + # -- will only be set when initializing to a graph + self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length)] @property def device(self): @@ -73,7 +91,7 @@ def on_device(self, device_id: int): raise TypeError(f"Expected device id to be int but got {type(device_id)}") return device_id in self.device - def inputs(self, index: Optional[int] = None): + def inputs(self, index: Optional[int] = None) -> Union[List[Any], Any]: """ Get input tensor at input index @@ -81,6 +99,9 @@ def inputs(self, index: Optional[int] = None): index (int or None): index of the inputs, None will return the nodes for all the inputs + + Returns: + values: Union[List[Any], Any] """ if isinstance(index, int): if index >= len(self._inputs): @@ -96,6 +117,9 @@ def inputs(self, index: Optional[int] = None): def predecessors(self, index: Optional[int] = None) -> List: """ Get input operator at input index + + Returns: + cell(s): Union[List[IRCell], IRCell] """ if isinstance(index, int): if index >= len(self._inputs): @@ -111,7 +135,7 @@ def predecessors(self, index: Optional[int] = None) -> List: else: raise TypeError("Expected index to be None or int") - def outputs(self, index: Optional[int] = None): + def outputs(self, index: Optional[int] = None) -> Union[List[Any], Any]: """ Get output tensor at output index @@ -119,6 +143,9 @@ def outputs(self, index: Optional[int] = None): index (int or None): index of the outputs, None will return the nodes for all the outputs + + Returns: + values: Union[List[Any], Any] """ if isinstance(index, int): if index >= len(self._outputs): @@ -150,7 +177,7 @@ def successors(self, index: Optional[int] = None) -> List: successors = list() for post_cells in self._successors: successors += post_cells - return post_cells + return successors else: raise TypeError("Expected index to be None or int") @@ -158,48 +185,41 @@ def set_input(self, input_index: int, val: Any): """ Set the node inputs[input_index] with the tensor - val: IRTensor or any deterministic value (int, bool, str, etc) + Args: + val: Union[IRTensor, Any] """ if input_index >= len(self.inputs()): raise RuntimeError( f"Set the input out of range ({input_index} >= {len(self._inputs)})" ) - # set tensor - self._inputs[input_index] = val if isinstance(val, IRTensor): + # copy the val + val = copy.copy(val) # set tensor dst - val.add_dst_node(self) - # set predecessor - self._predecessors[input_index] = val.src() - # set the source node successor - for node in val.src(): - if isinstance(node, IRCell): - node.add_successor(val, self) + val.attach_cell(self) + self._inputs[input_index] = val def set_output(self, output_index: int, val: Any): """ Set the node inputs[output_index] with the tensor - val: IRTensor or any deterministic value (int, bool, str, etc) + Args: + val: Union[IRTensor, Any] + IRTensor or any deterministic value (int, bool, str, etc) """ if output_index >= len(self.outputs()): raise RuntimeError( f"Set the input out of range ({output_index} >= {len(self._inputs)})" ) - # set tensor - self._outputs[output_index] = val if isinstance(val, IRTensor): - # set predecessor - for node in val.src(): - if isinstance(node, IRCell): - self._successors[output_index].append(node) - # set the source node - if self not in val.src(): - val.add_src_node(self) + val = copy.copy(val) + val.attach_cell(self) + self._outputs[output_index] = val - def add_predecessor(self, input_index: int, node, out_index: int): + def add_predecessor(self, input_index: int, node): """ - Set self node the input node. self.input[input_index] = node.output[out_index] + Add a predecessor cell in the input_index slot. + self.input[input_index] = node.output[out_index] """ if not isinstance(node, IRCell): raise TypeError("Expected node to be IRCell") @@ -207,61 +227,18 @@ def add_predecessor(self, input_index: int, node, out_index: int): raise RuntimeError( f"Set the input out of range ({input_index} >= {len(self._inputs)})" ) - self._inputs[input_index] = node.outputs(out_index) - self._predecessors[input_index].append(node) - node.add_successor(out_index, self) + if node not in self._predecessors[input_index]: + self._predecessors[input_index].append(node) - def add_successor(self, tensor, node): + def add_successor(self, output_index: int, node): """ Set self node the output index node. `node` will take the self.outputs(index) as the input """ if not isinstance(node, IRCell): raise TypeError("Expected node to be IRCell") - out_index = self._outputs.index(tensor) - if out_index < 0: - raise RuntimeError("Fail to find output tensor") - self._successors[out_index].append(node) - - def get_send_tensors(self): - """ - Collect send tensors at cell level. - This will not care what happened inside this cell - - Returns: - send_tensors: list of IRTensor - send_devices: list of list[int] devices for each tensor - """ - send_tensors = list() - send_devices = list() - for idx, output in enumerate(self.outputs()): - if isinstance(output, IRTensor): - succ_cells = self.successors(idx) - for cell in succ_cells: - devices = set(cell.device) - set(output.device) - if len(devices) != 0: - send_tensors.append(output) - send_devices.append(list(devices)) - return send_tensors, send_devices - - def get_recv_tensors(self): - """ - Collect recv tensors at cell level. - This will not care what happened inside this cell - - Returns: - recv_tensors: list of IRTensor - recv_devices: list of list[int] devices for each tensor - """ - recv_tensors = list() - recv_devices = list() - for input in self.inputs(): - if isinstance(input, IRTensor): - devices = set(input.device) - set(self.device) - if len(devices) != 0: - recv_tensors.append(input) - recv_devices.append(list(devices)) - return recv_tensors, recv_devices + if node not in self._successors[output_index]: + self._successors[output_index].append(node) def __repr__(self): """ @@ -270,14 +247,14 @@ def __repr__(self): inputs = list() for tensor in self.inputs(): if isinstance(tensor, IRTensor): - inputs.append(f't{tensor._id}') + inputs.append(f't{tensor._id}-dev{tensor.device}') else: inputs.append(tensor) outputs = list() for tensor in self.outputs(): if isinstance(tensor, IRTensor): - outputs.append(f't{tensor._id}') + outputs.append(f't{tensor._id}-dev{tensor.device}') else: outputs.append(tensor) dcsp = f'Cell-{self._id}({self.signature}, device={self.device})'\ @@ -294,33 +271,66 @@ def __init__(self, shape=None, name=None): self._id: int = IDGenerator().gen_tensor_id() self._shape: Optional(List[int]) = shape self.name = name - self._device = list() - # connected to IRCell - self._src_nodes: List[IRCell] = list() # -> output of the node - self._dst_nodes: List[IRCell] = list() # -> input of the nodes + # device + self._cell: List[IRCell] = list() # forward graph self.requires_grad = True - self.gen_graph = None + self.trace = None + + def attach_cell(self, cell: IRCell): + if not isinstance(cell, IRCell): + raise TypeError("Expected an IRCell") + if cell not in self._cell: + self._cell.append(cell) - def set_gen_graph(self, graph): + def detach_cell(self, cell: IRCell): + if not isinstance(cell, IRCell): + raise TypeError("Expected an IRCell") + if cell not in self._cell: + raise RuntimeError("the target cell not in the attached list") + self._cell.remove(cell) + + def set_trace(self, sus: List): """ - Set forward graph (IRGraph) + Set tensor generation trace """ - self.gen_graph = graph + if not isinstance(sus, list): + raise TypeError("Expected List[ScheduleUnit]") + self.trace = sus - def __copy__(self): + def renew(self): """ - Copy the tensor that will be same except a new id + Renew a new tensor with same name and shape + + Returns: + tensor """ tensor = IRTensor(self._shape, self.name) new_id = tensor._id for key in self.__dict__: setattr(tensor, key, getattr(self, key)) + # clear attached cells + tensor._cell = list() tensor._id = new_id return tensor + def __copy__(self): + """ + Copy the tensor that will have the exactly same id + except the empty attached cell + + Returns: + tensor + """ + tensor = IRTensor(self._shape, self.name) + for key in self.__dict__: + setattr(tensor, key, getattr(self, key)) + # clear attached cells + tensor._cell = list() + return tensor + def __deepcopy__(self, memo): """ Deep Copy will copy the exactly same tensor with same tensor id @@ -357,47 +367,46 @@ def shape(self, val): @property def device(self) -> List[int]: - return self._device + device = set() + for cell in self._cell: + device.update(set(cell.device)) + return list(device) @device.setter - def device(self, device_id: List[int]): - """ - Set placement of the tensor + def device(self, device_id: Union[int, List[int]]): + raise RuntimeError( + "tensor placement is not allowed to set manually" + ) - A tensor can be placed on multiple devices as input - for multiple operations on different devices + def src(self, cells: List[IRCell]) -> List[IRCell]: """ - if isinstance(device_id, int): - device_id = [device_id] - if not all([isinstance(devid, int) for devid in device_id]) : - raise TypeError(f"Expected device id to be int or List[int]") - self._device = device_id - - def src(self) -> List[IRCell]: - return self._src_nodes - - def dst(self, index: Optional[int] = None): - if index is None: - return self._dst_nodes - elif index >= len(self._dst_nodes): - raise RuntimeError("get tensor dst out of range") - return self._dst_nodes[index] - - def add_src_node(self, node: IRCell): - if not isinstance(node, IRCell): - raise TypeError("IRTensor source node should be IRCell") - self._src_nodes.append(node) + Return all the cells that will generate this tensor + """ + src_cells = list() + for cell in cells: + if not isinstance(cell, IRCell): + raise TypeError("Expected cells to be List[IRCell]") + if self in cell.outputs(): + src_cells.append(cell) + return src_cells - def add_dst_node(self, node: IRCell): - if not isinstance(node, IRCell): - raise TypeError("IRTensor destination node should be IRCell") - self._dst_nodes.append(node) + def dst(self, cells: List[IRCell]) -> List[IRCell]: + """ + Return all the cells that will generate this tensor + """ + dst_cells = list() + for cell in cells: + if not isinstance(cell, IRCell): + raise TypeError("Expected cells to be List[IRCell]") + if self in cell.inputs(): + dst_cells.append(cell) + return dst_cells - def is_leaf(self): + def is_leaf(self, cells: List[IRCell]): """ - Check if it is a leaf tensor (parameter) + Check if it is a leaf tensor (parameter or input data) """ - return len(self.src()) == 0 + return len(self.src(cells)) == 0 def backward(self): """ @@ -405,10 +414,11 @@ def backward(self): Construct a reverse graph of forward and seperate to actions """ - if self.gen_graph is None: - raise RuntimeError("Backward on a tensor without forward graph") - self.gen_graph.backward(self) - + if self.trace is None: + return + from cube.tschedule.pool import TSchedulePool + for fsu in self.trace[::-1]: + TSchedulePool().add_su(fsu.mirror) def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device})' diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 75382964..fc57195e 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -1,10 +1,23 @@ +""" +IRGraph: + a graph that is composed by node (IROperation) and edge (IRTensor). + + Note the device of graph.inputs() can be different of the same input + tensor of operation node in the graph. In this case, a move operation + will be inserted at scheduling time. +""" + +from typing import Union, Tuple, List, Optional, Any + from cube.graph.ir_cten import IRTensor, IRCell from cube.graph.ir_op import IROperation +from cube.graph.ir_comm import IRCommunication +from cube.tschedule.su import ScheduleUnit, forward_convert -from typing import Union, Tuple, List, Optional +import copy -__all__ = ['IRGraph', 'IRAction'] +__all__ = ['IRGraph'] class IRGraph(IRCell): @@ -21,16 +34,119 @@ def __init__(self, module_name: str): self._nodes: List[IROperation] = nodes + self.reset_dependency() + super().__init__( name=module_name, signature=module_name, input_length=len(input_tensors), output_length=len(output_tensors) ) - self._inputs = input_tensors - self._outputs = output_tensors + + for idx, tensor in enumerate(input_tensors): + self.set_input(idx, tensor) + for idx, tensor in enumerate(output_tensors): + self.set_output(idx, tensor) + self.tag = 'forward' + def reset_dependency(self): + """ + Reset the node dataflow dependency + """ + # set node predecessors and successors + for src_idx in range(len(self._nodes)): + src_cell = self._nodes[src_idx] + src_cell._successors = [ + list() for _ in range(len(src_cell.outputs())) + ] + for dst_idx in range(src_idx + 1, len(self._nodes)): + dst_cell = self._nodes[dst_idx] + dst_cell._predecessors = [ + list() for _ in range(len(dst_cell.inputs())) + ] + for tensor in src_cell.outputs(): + if isinstance(tensor, IRTensor): + if tensor in dst_cell.inputs(): + src_output_idx = src_cell.outputs().index(tensor) + src_cell.add_successor(src_output_idx, dst_cell) + dst_input_idx = dst_cell.inputs().index(tensor) + dst_cell.add_predecessor(dst_input_idx, src_cell) + + def copy(self, reverse=False): + """ + Copy the graph but re-new the intermediate tensor + """ + new_tensors = dict() # old graph tensor._id -> new tensor + + def _renew(val: Any): + if not isinstance(val, IRTensor): + return val + # parameters + if val.is_leaf(self.nodes()) and val not in self.inputs(): + return val + # intermediate data + if val._id not in new_tensors: + tensor = val.renew() + new_tensors[val._id] = tensor + return new_tensors[val._id] + + nodes = list() + for node in self.nodes(): + + if isinstance(node, IRCommunication): + send_tensors = [_renew(tensor) for tensor in node.inputs()] + send_ranks = node.send_ranks + recv_tensors = [_renew(tensor) for tensor in node.outputs()] + recv_ranks = node.recv_ranks + if reverse: + send_tensors, recv_tensors = recv_tensors, send_tensors + send_ranks, recv_ranks = recv_ranks, send_ranks + + new_node = IRCommunication( + send_tensors = send_tensors, + send_ranks = send_ranks, + recv_tensors = recv_tensors, + recv_ranks = recv_ranks + ) + + elif isinstance(node, IROperation): + inputs = node.inputs() + outputs = node.outputs() + if reverse: + inputs, outputs = outputs, inputs + + new_node = IROperation( + node.name, node.signature, + len(inputs), len(outputs) + ) + # set inputs + for idx, val in enumerate(inputs): + new_node.set_input(idx, _renew(val)) + # set outputs + for idx, val in enumerate(outputs): + new_node.set_output(idx, _renew(val)) + else: + raise TypeError("Found node with unsupported copy") + new_node.device = node.device + nodes.append(new_node) + + inputs = [_renew(input) for input in self.inputs()] + outputs = [_renew(output) for output in self.outputs()] + + if reverse: + inputs, outputs = outputs, inputs + nodes = nodes[::-1] + + copied_graph = IRGraph( + nodes = nodes, + input_tensors = inputs, + output_tensors = outputs, + module_name = self.name + ) + copied_graph.tag = self.tag + return copied_graph + def add_node(self, node: IRCell): if not isinstance(node, IRCell): raise TypeError("Expected node to be IROperation") @@ -47,15 +163,67 @@ def nodes(self, index: Optional[int] = None): ) return self._nodes[index] elif index is None: - return self._nodes + return copy.copy(self._nodes) else: raise TypeError("Expected index to be None or int") - def replace(self, target: IROperation, nodes: List[IROperation]): + def insert(self, node, src_node=None, dst_node=None, replaced_tensor=None): """ - Replace the node with new nodes (IRGraph) + Insert a node between src_node and dst_node. In default, + if dst_node is not None, the node will be inserted right before + dst_node. If the replaced_tensor is provided, the replaced_tensor + in dst_node's inputs will be removed, and the output of node will be + set as input for dst_node. """ - raise NotImplementedError + if not isinstance(node, IRCell): + raise TypeError("Expected IRCell to insert") + if dst_node is not None: + if dst_node not in self._nodes: + raise KeyError("dst_node not found") + if replaced_tensor is not None: + if replaced_tensor not in dst_node.inputs(): + raise RuntimeError(f"Expected dst_node input has {replaced_tensor}") + # remove dst_node input + input_index = dst_node.inputs().index(replaced_tensor) + if len(node.outputs()) != 1: + raise RuntimeError("replaced node requires output length to be 1") + dst_node.set_input(input_index, node.outputs(0)) + # insert node + index = self._nodes.index(dst_node) + self._nodes.insert(index, node) + elif src_node is not None: + if src_node not in self._nodes: + raise KeyError("src_node not found") + index = self._nodes.index(src_node) + self._nodes = self._nodes[:index+1] + [node] + self._nodes[index+1:] + else: + raise TypeError("Expected at least one of [src_node, dst_node]") + #TODO: optimize this + self.reset_dependency() + + def replace_tensor(self, old_tensor: IRTensor, new_tensor: IRTensor): + """ + Replace tensor from old_tensor to new_tensor for all the graph. + """ + def _replace_inputs(cell, old_tensor, new_tensor): + index = cell.inputs().index(old_tensor) + cell.set_input(index, new_tensor) + + def _replace_outputs(cell, old_tensor, new_tensor): + index = cell.outputs().index(old_tensor) + cell.set_output(index, new_tensor) + + if old_tensor in self.inputs(): + _replace_inputs(self, old_tensor, new_tensor) + + for node in self.nodes(): + if old_tensor in node.inputs(): + _replace_inputs(node, old_tensor, new_tensor) + if old_tensor in node.outputs(): + _replace_outputs(node, old_tensor, new_tensor) + + if old_tensor in self.outputs(): + _replace_outputs(self, old_tensor, new_tensor) def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: """ @@ -74,49 +242,38 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: raise RuntimeError( f"Expected {len(self.inputs())} input args but got {len(args)}" ) - # check input type - if not all([type(arg) is type(input) for arg, input in zip(args, self.inputs())]): - raise RuntimeError(f"Expected input type the same") - - curr_nodes: List[IRCell] = list() - curr_device = self.nodes(0).device - total_actions = list() - for node in self.nodes(): - device = node.device - if len(node.device) == 0: - raise RuntimeError("All the node should be assigned to devices") - if set(device) != set(curr_device): - # create action - action = IRAction(curr_nodes, self, devices=curr_device) - total_actions.append(action) - # register to schedule space - TSchedulePool().add_action(action) - # clear - curr_nodes = list() - curr_device = device - curr_nodes.append(node) - if curr_device is not None: - action = IRAction(curr_nodes, self, devices=curr_device) - total_actions.append(action) - TSchedulePool().add_action(action) - - # setup action inputs - output_map = { - gten._id : aten for gten, aten in zip(self.inputs(), args) - } - for action in total_actions: - for idx, input in enumerate(action.graph.inputs()): - if isinstance(input, IRTensor): - input = output_map[input._id] - action.set_input(idx, input) - for action_out, graph_out in zip(action.outputs(), action.graph.outputs()): - output_map[graph_out._id] = action_out + forward_graph = self.copy() + backward_graph = self.copy(reverse=True) + backward_graph.tag = 'backward' + + # set input + for input, arg in zip(self.inputs(), args): + if type(arg) != type(input): + raise RuntimeError(f"Expected input type the same") + forward_graph.replace_tensor(input, arg) + + fsus = forward_convert(forward_graph) # return tensors - outputs = tuple(total_actions[-1].outputs()) + outputs = forward_graph.outputs() for output in outputs: - output.set_gen_graph(self) + output.set_trace(fsus) + + # set backward graph input + for input, output in zip(backward_graph.inputs(), outputs): + backward_graph.replace_tensor(input, output) + + bsus = forward_convert(backward_graph) + + for fsu, bsu in zip(fsus, bsus[::-1]): + fsu.set_mirror(bsu) + bsu.set_mirror(fsu) + print(f'pair: {fsu} <-> {bsu}') + + # add forward schedule to pool + for su in fsus: + TSchedulePool().add_su(su) if len(outputs) == 1: return outputs[0] elif len(outputs) == 0: return None @@ -128,82 +285,6 @@ def __call__(self, *args): """ return self.forward(*args) - def backward(self, loss: IRTensor): - """ - Backward will generate a backward action scheduling pool - - Construct a reverse graph of forward and seperate to actions - """ - # travel graph in reverse order - all_tensors = dict() - - def get_tensor_grad(tensor): - if tensor._id not in all_tensors: - #name = 'grad' if tensor.name is None else tensor.name + '_grad' - new_tensor = IRTensor( - shape=tensor.shape, name=tensor.name - ) - new_tensor._id = tensor._id # -> keep same tensor - # reverse op - devices = set() - for node in tensor.dst(): - devices.update(node.device) - devices = list(devices) - if len(devices) == 0: - devices = tensor.device - new_tensor.device = devices - all_tensors[tensor._id] = new_tensor - return new_tensor - else: - return all_tensors[tensor._id] - - # backward graph inputs - graph_inputs = list() - # none outputs for loss - graph_outputs = list() - # nodes - backward_nodes = list() - all_bp_tensors = list() - - for fnode in self._nodes[::-1]: - inputs = list() - for input in fnode.outputs(): - if isinstance(input, IRTensor) and input.requires_grad: - tensor = get_tensor_grad(input) - if tensor not in all_bp_tensors: - graph_inputs.append(tensor) - all_bp_tensors.append(tensor) - inputs.append(tensor) - else: - inputs.append(None) - outputs = list() - for output in fnode.inputs(): - if isinstance(output, IRTensor) and output.requires_grad: - tensor = get_tensor_grad(output) - all_bp_tensors.append(tensor) - outputs.append(tensor) - else: - outputs.append(None) - bp_node = IROperation( - name = fnode.name + '_backward', - signature = fnode.signature, - input_length = len(inputs), - output_length = len(outputs) - ) - bp_node.device = fnode.device - for idx, arg in enumerate(inputs): - bp_node.set_input(idx, arg) - for idx, arg in enumerate(outputs): - bp_node.set_output(idx, arg) - backward_nodes.append(bp_node) - graph = IRGraph( - backward_nodes, - graph_inputs, graph_outputs, - self.name + 'Backward' - ) - graph.tag = 'backward' - graph(loss) - def subgraph(self, sub_nodes: List[IRCell]): """ Create a subgraph with sub nodes. @@ -212,56 +293,36 @@ def subgraph(self, sub_nodes: List[IRCell]): and graph output (send tensors) Return: - IRGraph, - recv tensor starting offset (int) in input, - send tensor starting offset (int) in output + IRGraph """ - def _update(x_tensors, x_devices, tensor, devices): - if tensor not in x_tensors: - x_tensors.append(tensor) - x_devices.append(set(devices)) - else: - idx = x_tensors.index(tensor) - x_devices[idx].update(set(devices)) - - # recv tensors - recv_tensors = list() - recv_devices = list() - # send tensors - send_tensors = list() - send_devices = list() - # get nodes belong to this graph - all_tensors = list() - for node in sub_nodes: - # collect recv tensors - tensors_and_devices = node.get_recv_tensors() - for r_tensor, r_devices in zip(*tensors_and_devices): - _update(recv_tensors, recv_devices, r_tensor, r_devices) - # collect send tensors - tensors_and_devices = node.get_send_tensors() - for s_tensor, s_devices in zip(*tensors_and_devices): - _update(send_tensors, send_devices, s_tensor, s_devices) - all_tensors += node.inputs() - all_tensors += node.outputs() - - # set extra graph inputs and outputs + # find input inputs = list() outputs = list() - for input in self.inputs(): - if input in all_tensors and input not in recv_tensors: - inputs.append(input) - for output in self.outputs(): - if output in all_tensors and output not in send_tensors: - outputs.append(output) + for node in sub_nodes: + outer_cells = list(set(self.nodes()) - set(sub_nodes)) + for tensor in node.inputs(): + if isinstance(tensor, IRTensor) and tensor not in inputs: + # if a tensor is generated by other nodes out of sub_nodes, + # then this tensor should be the input + src_nodes = tensor.src(outer_cells) + if len(src_nodes) != 0 or tensor in self.inputs(): + inputs.append(tensor) + for tensor in node.outputs(): + if isinstance(tensor, IRTensor) and tensor not in outputs: + # if a tensor is used by other nodes out of sub_nodes, + # then this tensor should be output + dst_nodes = tensor.dst(outer_cells) + if len(dst_nodes) != 0 or tensor in self.outputs(): + outputs.append(tensor) graph = IRGraph( nodes = sub_nodes, - input_tensors = inputs + recv_tensors, - output_tensors = outputs + send_tensors, + input_tensors = inputs, + output_tensors = outputs, module_name = self.name ) - return graph, len(inputs), len(outputs) + return graph def __repr__(self): @@ -278,91 +339,3 @@ def __repr__(self): # outputs dscp += f"\nOutputs: {self._outputs}\n{'=' * len(self.name)}\n" return dscp - - -# outputs = cube.runtime.temporal.forward(model, *args) -_forward_signature = 'cube.runtime.temporal.forward' -# grads = cube.runtime.temporal.backward(input_tensors, output_tensors, output_grads) -_backward_signature = 'cube.runtime.temporal.backward' - - -class IRAction(IRCell): - """ - Action recv tensors must be inside of Action inputs, - and can be mapped to Action.graph.inputs - - """ - - def __init__(self, sub_nodes, global_graph, devices: Union[List[int], int]): - - if isinstance(devices, int): - devices = [devices] - - if not isinstance(global_graph, IRGraph): - raise TypeError(f"Expected graph: IRGraph but go {type(global_graph)}") - - if global_graph.tag == 'forward': - signature = _forward_signature - elif global_graph.tag == 'backward': - signature = _backward_signature - else: - raise RuntimeError(f"Unsupported graph tag: {self.global_graph.tag}") - - self.graph, recv_ofst, send_ofst = global_graph.subgraph(sub_nodes) - self._recv_ofst = recv_ofst - self._send_ofst = send_ofst - - super().__init__( - name = global_graph.tag, - signature = signature, - input_length = len(self.graph.inputs()), - output_length = len(self.graph.outputs()) - ) - # set action device - self.device = devices - # set output shape - for output, g_out in zip(self.outputs(), self.graph.outputs()): - output.device = devices - output.shape = g_out.shape - - @property - def send_tensors(self): - return self._outputs[self._send_ofst:] - - @property - def recv_tensors(self): - return self._inputs[self._recv_ofst:] - - def happen_before(self, action): - """ - Check if the self -> (happened before) action - - Note: this may return false negative as it will only check - 1-hop dependency - """ - if not isinstance(action, IRAction): - raise TypeError("Expected action to be an Action") - return self in action.predecessors() - - def happen_after(self, action): - """ - Check if the action -> (happened before) self - - Note: this may return false negative as it will only check - 1-hop dependency - """ - if not isinstance(action, IRAction): - raise TypeError("Expected action to be an Action") - return self in action.successors() - - def add_flow(self, action): - """ - self -> (happened before) action - """ - raise NotImplementedError - - def __repr__(self): - action_inputs = [f't{tensor._id}' for tensor in self.inputs()] - action_outputs = [f't{tensor._id}' for tensor in self.outputs()] - dscp = f'Action({self.name}):\n\t{self.graph.inputs()} ({action_inputs}) -> {self.graph.outputs()} ({action_outputs})' - return dscp diff --git a/cube/graph/ir_op.py b/cube/graph/ir_op.py index 01071d93..7e6f7abf 100644 --- a/cube/graph/ir_op.py +++ b/cube/graph/ir_op.py @@ -26,38 +26,6 @@ def __init__(self, super().__init__(name, signature, input_length, output_length) self.semantic = IR2LogicOp.map(self.signature) - @property - def device(self): - return self._device - - @device.setter - def device(self, device_id: Union[int, List[int]]): - """ - Set the operation device. - - For computation operators, they are only allowed - to happen on one device (int) - - For communication operators (e.g., move, all-reduce), - they are allowed to happend on multiple devices - """ - if isinstance(device_id, int): - device_id = [device_id] - if not all([isinstance(devid, int) for devid in device_id]): - raise ValueError("Require device Union[int, List[int]]") - self._device = device_id - for input in self._inputs: - # in default, parameters will be placed on - # all devices that needs it - if isinstance(input, IRTensor) and input.is_leaf(): - devices = set() - for node in input.dst(): - devices.update(node.device) - input.device = list(devices) - for output in self._outputs: - if isinstance(output, IRTensor): - output.device = device_id - def infer_shape(self): """ Infer output value shape @@ -85,14 +53,14 @@ def __repr__(self): inputs = list() for tensor in self.inputs(): if isinstance(tensor, IRTensor): - inputs.append(f't{tensor._id}') + inputs.append(f't{tensor._id}-dev{tensor.device}') else: inputs.append(tensor) outputs = list() for tensor in self.outputs(): if isinstance(tensor, IRTensor): - outputs.append(f't{tensor._id}') + outputs.append(f't{tensor._id}-dev{tensor.device}') else: outputs.append(tensor) diff --git a/cube/graph/ir_seq.py b/cube/graph/ir_seq.py deleted file mode 100644 index be754862..00000000 --- a/cube/graph/ir_seq.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import List, Any -import numpy as np - -from cube.graph.ir_cten import IRCell, IRTensor -from cube.graph.ir_graph import IRAction - - -class IRSequence(IRCell): - - def __init__(self, actions: List[IRAction]): - - if not all([isinstance(action, IRAction) for action in actions]): - raise TypeError("Expected a list of IRActions") - - super().__init__( - name = 'action', - signature = 'None', - input_length = 0, - output_length = 0 - ) - self.sequence = actions - - def __iter__(self): - return iter(self.sequence) - - def __len__(self): - return len(self.sequence) - - def append(self, action: IRAction): - self.sequence.append(action) - - def get_forward_inputs(self, action: IRAction) -> List[Any]: - """ - Get corresponding forward action inputs - - The backward graph output tensor shuould be forward graph input tensor - """ - if action.name == 'forward': - return action.inputs() - if action.name == 'backward': - bp_graph_outputs = action.graph.outputs() - fw_action_inputs = [None] * len(bp_graph_outputs) - pre_actions = action.predecessors() - while len(pre_actions) != 0: - pre = list() - for pre_action in pre_actions: - if pre_action.name == 'forward': - for bidx, output in enumerate(bp_graph_outputs): - for fidx, input in enumerate(pre_action.graph.inputs()): - if input == output: - fw_action_inputs[bidx] = pre_action.inputs(fidx) - pre += pre_action.predecessors() - pre_actions = pre - if None in fw_action_inputs: - raise RuntimeError("Couldn't found forward inputs") - return fw_action_inputs - raise RuntimeError(f"Unsupported action name: {action.name}") - - def get_forward_outputs(self, action: IRAction) -> List[Any]: - """ - Get corresponding forward action outputs - - The backward graph input tensor should be forward graph output tensor - """ - if action.name == 'forward': - return action.inputs() - if action.name == 'backward': - bp_graph_inputs = action.graph.inputs() - fw_action_outputs = [None] * len(bp_graph_inputs) - pre_actions = action.predecessors() - while len(pre_actions) != 0: - pre = list() - for pre_action in pre_actions: - if pre_action.name == 'forward': - for bidx, output in enumerate(bp_graph_inputs): - for fidx, input in enumerate(pre_action.graph.outputs()): - if input == output: - fw_action_outputs[bidx] = pre_action.outputs(fidx) - pre += pre_action.predecessors() - pre_actions = pre - if None in fw_action_outputs: - raise RuntimeError("Couldn't found forward inputs") - return fw_action_outputs - raise RuntimeError(f"Unsupported action name: {action.name}") - - - def is_correct(self): - """ - Check whether sequence - satisfies the sequential consistency model - """ - - for index, action in enumerate(self.sequence): - for pre_action in action.predecessors(): - # find the pre-action not appear in sequence - if not pre_action in self.sequence: - return False - pre_idx = self.sequence.index(pre_action) - # violate sequential consistency model - if pre_idx >= index: - return False - return True \ No newline at end of file diff --git a/cube/graph/unique.py b/cube/graph/unique.py index ac45bd27..635a456c 100644 --- a/cube/graph/unique.py +++ b/cube/graph/unique.py @@ -31,4 +31,4 @@ def gen_cell_id(self): def clear(self): self.instance._tensor_id = 0 - self.instance._op_id = 0 + self.instance._cell_id = 0 diff --git a/cube/sschedule/__init__.py b/cube/sschedule/__init__.py index 622ac5c9..5fa7b2bf 100644 --- a/cube/sschedule/__init__.py +++ b/cube/sschedule/__init__.py @@ -1,6 +1,6 @@ from cube.graph import parser -from cube.graph.ir_graph import IRGraph, IRAction from cube.codegen.codegen import SScheduleCodeGen +from cube.tschedule.su import ScheduleUnit class SpatialModule: @@ -14,15 +14,11 @@ def __init__(self, ir_graph): def get_graph(self): return self._ir_graph - def gen_module(self, rank, outfile, attach=False) -> str: + def gen_module(self, seq, rank, outfile, attach=False) -> str: """ Set the module """ - # TODO: support multiple graph segments - subnodes = [node for node in self._ir_graph.nodes() if node.on_device(rank)] - # subgraph = self._ir_graph.subgraph(subnodes) - action = IRAction(subnodes, self._ir_graph, devices=[rank]) - gener = SScheduleCodeGen(action) + gener = SScheduleCodeGen(seq) code = gener.gen(device=rank, outfile=outfile, attach=attach) return code diff --git a/cube/sschedule/adapter.py b/cube/sschedule/adapter.py new file mode 100644 index 00000000..22c6c79b --- /dev/null +++ b/cube/sschedule/adapter.py @@ -0,0 +1,44 @@ +from typing import Tuple + +from cube.graph.ir_comm import IRCommunication +from cube.graph.ir_graph import IRGraph + + +class Adapter: + + @staticmethod + def adapt(graph: IRGraph) -> IRGraph: + for src_node in graph.nodes(): + for out_idx, tensor in enumerate(src_node.outputs()): + for dst_node in src_node.successors(out_idx): + if set(src_node.device) != set(dst_node.device): + from_rank = src_node.device + to_rank = dst_node.device + from_rank, to_rank = from_rank, to_rank + #TODO check if it is a tensor + send_node, recv_node = Adapter.create_tensor_move( + tensor = tensor, + from_rank = from_rank, + to_rank = to_rank + ) + graph.insert(send_node, src_node=src_node) + graph.insert(recv_node, dst_node=dst_node, + replaced_tensor=tensor) + return graph + + @staticmethod + def create_tensor_move(tensor, from_rank: int, to_rank: int) -> Tuple[IRCommunication, IRCommunication]: + # send node + ir_send_node = IRCommunication( + send_tensors = [tensor], + send_ranks = [to_rank] + ) + ir_send_node.device = from_rank + # recv node + ir_recv_node = IRCommunication( + recv_tensors = [tensor], + recv_ranks = [from_rank] + ) + ir_recv_node.device = to_rank + return ir_send_node, ir_recv_node + diff --git a/cube/sschedule/prim.py b/cube/sschedule/prim.py new file mode 100644 index 00000000..3371c421 --- /dev/null +++ b/cube/sschedule/prim.py @@ -0,0 +1,55 @@ +""" +Spatial primitives for policy +""" +from cube.graph.ir_cten import IRCell, IRTensor +from cube.graph.ir_graph import IRGraph + +from typing import List, Union + + +def assign(inst: Union[IRTensor, IRCell], ranks: List[int], graph: IRGraph) -> None: + """ + Assign a IRTensor / IRCell with spatial rank placement + + For IRCell: + the device attribute will be set to ranks, + the inputs and outputs of this IRCell will also be changed + to ranks. + + For IRTensor: + A move operation will be changed and inserted in order: + output_node -> move -> input_node + """ + if not all([isinstance(rank, int) for rank in ranks]): + raise TypeError("Expected ranks to be List[int]") + if isinstance(inst, IRCell): + inst.device = ranks + elif isinstance(inst, IRTensor): + if set(inst.device) == set(ranks): + return + # find nodes that generated this tensor from the graph + src_node = list() + dst_node = list() + for node in graph.nodes(): + if inst in node.outputs(): + src_node.append(node) + if inst in node.inputs(): + dst_node.append(node) + if len(src_node) == 0: # a leaf tensor + raise NotImplementedError( + "Prim [assign]: moving parameter is not supported" + ) + if len(dst_node) == 0: # a loss tensor + raise RuntimeError( + "Prim [assign]: moving a tensor that is never used in graph" + ) + raise NotImplementedError( + "Prim [assign]: moving tensor is not supported yet" + ) + else: + raise TypeError("Expected inst to ba Union[IRTensor, IRCell]") + + +def select(tensor: IRTensor, indices, val_op, shape) -> IRTensor: + raise NotImplementedError("Prim [select]: selecting sub IRTensor is not supported") + diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py index 07ba260a..7f4f64df 100644 --- a/cube/tschedule/__init__.py +++ b/cube/tschedule/__init__.py @@ -3,7 +3,7 @@ from cube.tschedule.pool import TSchedulePool from cube.graph.ir_cten import IRTensor -from cube.graph.ir_seq import IRSequence +from cube.tschedule.suseq import SUSequence from cube.codegen.codegen import TScheduleCodeGen @@ -70,8 +70,8 @@ def decorator(fn: Callable) -> Callable: TSchedulePool().clear() # collect trace fn(ir_graph, ir_dataloader) - actions = TSchedulePool().actions() - seq = IRSequence(actions) + sus = TSchedulePool().sus() + seq = SUSequence(sus) # policy if policy_fn: seq = policy_fn(seq) diff --git a/cube/tschedule/pool.py b/cube/tschedule/pool.py index 545d204a..8843faf3 100644 --- a/cube/tschedule/pool.py +++ b/cube/tschedule/pool.py @@ -6,7 +6,7 @@ class __TSchedulePool: def __init__(self): - self._actions = list() + self._sus = list() self._flow_id = -1 instance = None @@ -18,14 +18,14 @@ def __init__(self): def __getattr__(self, name): return getattr(self.instance, name) - def add_action(self, action): - self.instance._actions.append(action) + def add_su(self, su): + self.instance._sus.append(su) - def actions(self): - return self.instance._actions + def sus(self): + return self.instance._sus def clear(self): - self.instance._actions = list() + self.instance._sus = list() self.instance._flow_id = -1 def gen_id(self) -> int: @@ -36,5 +36,5 @@ def gen_id(self) -> int: return self.instance._flow_id def __repr__(self): - dscp = '\n'.join([repr(action) for action in self._actions]) + dscp = '\n'.join([repr(su) for su in self._sus]) return dscp diff --git a/cube/tschedule/su.py b/cube/tschedule/su.py new file mode 100644 index 00000000..96e201ad --- /dev/null +++ b/cube/tschedule/su.py @@ -0,0 +1,189 @@ +from typing import Union, List, Optional +import copy +from enum import Enum + +from cube.graph.ir_comm import IRCommunication + +from cube.graph.ir_cten import IRCell + + +class SUType(Enum): + + Forward = 'forward' + Backward = 'backward' + Adapter = 'adapter' + + +class ScheduleUnit(IRCell): + """ + Action recv tensors must be inside of Action inputs, + and can be mapped to Action.graph.inputs + + """ + + # outputs = cube.runtime.temporal.forward(model, *args) + _forward_signature = 'cube.runtime.temporal.forward' + # grads = cube.runtime.temporal.backward( + # input_tensors, output_tensors, output_grads + # ) + _backward_signature = 'cube.runtime.temporal.backward' + + def __init__(self, sub_nodes, graph, devices: Union[List[int], int]): + + if all([isinstance(node, IRCommunication) for node in sub_nodes]): + self.tag = 'forward' + else: + self.tag = graph.tag + + self.global_graph = graph + + if self.tag == 'forward': + signature = ScheduleUnit._forward_signature + elif self.tag == 'backward': + signature = ScheduleUnit._backward_signature + else: + raise RuntimeError(f"Unsupported graph tag: {self.tag}") + + subgraph = graph.subgraph(sub_nodes) + self._nodes = sub_nodes + + super().__init__( + name = self.tag, + signature = signature, + input_length = len(subgraph.inputs()), + output_length = len(subgraph.outputs()) + ) + + for idx, input in enumerate(subgraph.inputs()): + self.set_input(idx, input) + for idx, output in enumerate(subgraph.outputs()): + self.set_output(idx, output) + + # set su device + self.device = devices + + # additional control dependency for add_flow + self._ctrl_predecessors = list() + self._ctrl_successors = list() + + self.mirror = None + + def set_mirror(self, su): + """ + Create a mirrored ScheduleUnit: the + inputs and outputs are reversed + """ + if not isinstance(su, ScheduleUnit): + raise TypeError("Expected mirror to be ScheduleUnit") + self.mirror = su + + def nodes(self, index: Optional[int] = None): + """ + Get node at position index + """ + if isinstance(index, int): + if index >= len(self._nodes): + raise RuntimeError( + f"Get node out of range ({index} >= {len(self._nodes)})" + ) + return self._nodes[index] + elif index is None: + return copy.copy(self._nodes) + else: + raise TypeError("Expected index to be None or int") + + def add_predecessor(self, input_index: int, su): + """ + Add a predecessor cell in the input_index slot. + self.input[input_index] = node.output[out_index] + """ + if input_index == -1: + self._predecessors.append(su) + else: + super().add_predecessor(input_index, su) + + def predecessors(self, index: Optional[int] = None) -> List: + """ + Get 1-hop predecessor cells including control predecessors + + Args: + index (Optional[int]): + -1: return control predecessors + None: return all predecessors including index + >0 : return input SUs at input index + + Returns: + cell(s): List[IRCell] + """ + if isinstance(index, int): + if index == -1: + return copy.copy(self._ctrl_predecessors) + if index >= len(self._inputs): + raise RuntimeError( + f"Get the input out of range ({index} >= {len(self._inputs)}" + ) + return copy.copy(self._predecessors[index]) + elif index is None: + predecessors = list() + for pre_cells in self._predecessors: + predecessors += pre_cells + predecessors += self._ctrl_predecessors + return predecessors + else: + raise TypeError("Expected index to be None or int") + + def add_successor(self, output_index: int, su): + """ + Set self node the output index node. + `node` will take the self.outputs(index) as the input + """ + if output_index == -1: + self._successors.append(su) + else: + super().add_successor(output_index, su) + + def successors(self, index: Optional[int] = None) -> List: + """ + Get 1-hop successor cells including control successors + + Args: + index (Optional[int]): + -1: return control successors + None: return all successors including index + >0 : return output SUs at output index + + Returns: + cells: List[ScheduleUnit] + """ + if isinstance(index, int): + if index == -1: + return copy.copy*self._ctrl_successors + if index >= len(self._outputs): + raise RuntimeError( + f"Get the output out of range ({index} >= {len(self._outputs)}" + ) + return copy.copy(self._successors[index]) + elif index is None: + successors = list() + for post_cells in self._successors: + successors += post_cells + successors += self._ctrl_successors + return successors + else: + raise TypeError("Expected index to be None or int") + + def __repr__(self): + su_inputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.inputs()] + su_outputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.outputs()] + dscp = f'SU({self.name}, nodes={len(self.nodes())})-dev{self.device}: {su_inputs} -> {su_outputs}' + return dscp + + +def forward_convert(graph) -> List[ScheduleUnit]: + sus = list() + for node in graph.nodes(): + devices = node.device + for device in devices: + su = ScheduleUnit([node], graph, device) + sus.append(su) + return sus diff --git a/cube/tschedule/suseq.py b/cube/tschedule/suseq.py new file mode 100644 index 00000000..26910170 --- /dev/null +++ b/cube/tschedule/suseq.py @@ -0,0 +1,200 @@ +from typing import List, Any, Optional +import copy + +from cube.graph.ir_cten import IRCell, IRTensor +from cube.tschedule.su import ScheduleUnit + + +class SUSequence(IRCell): + + def __init__(self, sus: List[ScheduleUnit]): + + if not all([isinstance(su, ScheduleUnit) for su in sus]): + raise TypeError("Expected a list of ScheduleUnits") + + super().__init__( + name = 'SU', + signature = 'None', + input_length = 0, + output_length = 0 + ) + self.sequence = sus + self.reset_dependency() + + def reset_dependency(self): + """ + Reset the node dataflow dependency + """ + # set node predecessors and successors + for src_idx in range(len(self.sequence)): + src_cell = self.sequence[src_idx] + src_cell._successors = [ + list() for _ in range(len(src_cell.outputs())) + ] + for dst_idx in range(src_idx + 1, len(self.sequence)): + dst_su = self.sequence[dst_idx] + dst_su._predecessors = [ + list() for _ in range(len(dst_su.inputs())) + ] + for tensor in src_cell.outputs(): + if isinstance(tensor, IRTensor): + if tensor in dst_su.inputs(): + src_output_idx = src_cell.outputs().index(tensor) + src_cell.add_successor(src_output_idx, dst_su) + dst_input_idx = dst_su.inputs().index(tensor) + dst_su.add_predecessor(dst_input_idx, src_cell) + + def __len__(self): + return len(self.sequence) + + def sus(self, index: Optional[int] = None): + """ + Return ScheduleUnit + + Args: + + """ + if isinstance(index, int): + if index >= len(self.sequence): + raise RuntimeError( + f"Get node out of range ({index} >= {len(self.sequence)})" + ) + return self.sequence[index] + elif index is None: + return copy.copy(self.sequence) + else: + raise TypeError("Expected index to be None or int") + + def happen_before(self, su1, su2): + """ + Check if the su1 -> (happened before) su2 + + Returns: + Boolean + """ + if not isinstance(su1, ScheduleUnit) or \ + not isinstance(su2, ScheduleUnit): + raise TypeError("Expected su to be an ScheduleUnit") + if su2 in su1.successors(): + return True + else: + for succ_su in su1.successors(): + if self.happen_before(succ_su, su2): + return True + return False + + def happen_after(self, su1, su2): + """ + Check if the su2 -> (happened before) su1 + + Returns: + Boolean + """ + return self.happen_before(su2, su1) + + def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: + """ + Merge two ScheduleUnit. This requires + + 1). all the nodes in one SU happens before / after + all the nodes in another SU. (Guaranteed by default + as all the operations on sequence are semantic-correct) + + 2). all the nodes in both SU are on the same device, + have same tags and they are not equal. + + 3). Deadlock-free merge. Suppose + SU1 (dev0) -> SU2 (dev1) -> SU3 (dev0) + Then merge SU1 and SU3 to SU4 will cause + deadlock on SU4 -> <- SU2 + + Note due to PyTorch limitation, + merging two forward ScheduleUnits will also cause + the merge of corresponding two backward ScheduleUnits. + + Returns: + if succeed: A merged ScheduleUnit. + if fail: None + """ + + if not isinstance(su1, ScheduleUnit) or \ + not isinstance(su2, ScheduleUnit): + raise TypeError("Expected SU1 and SU2 are ScheduleUnit") + if su1 not in self.sequence or su2 not in self.sequence: + raise ValueError("Expected both su1 and su2 are in sequence") + + # 2) all the nodes in both SU are on the same device + if su1 == su2 or su1.tag != su2.tag: + return None + if set(su1.device) != set(su2.device): + return None + + # 3) deadlock-free merge + index_su1 = self.sequence.index(su1) + index_su2 = self.sequence.index(su2) + # make su1 happen before su2 + su1, su2 = (su1, su2) if index_su1 < index_su2 else (su2, su1) + inter_sus = self.sequence[index_su1+1:index_su2] + for su in inter_sus: + if su.happen_after(su1) and su.happen_before(su2): + return None + + # merge forward su + sub_nodes = su1.nodes() + su2.nodes() + merged_su = ScheduleUnit(sub_nodes, su1.global_graph, su1.device) + + # merge mirrored su + # mirror_su2 -> mirror_su1 + mirror_su1, mirror_su2 = su1.mirror, su2.mirror + sub_nodes = mirror_su2.nodes() + mirror_su1.nodes() + merged_mirror_su = ScheduleUnit( + sub_nodes, mirror_su1.global_graph, mirror_su1.device + ) + + # set mirror + merged_su.set_mirror(merged_mirror_su) + merged_mirror_su.set_mirror(merged_su) + + # replace + self.sequence[index_su1] = merged_su + self.sequence.remove(su2) + if mirror_su1 in self.sequence and mirror_su2 in self.sequence: + index_mirror_su2 = self.sequence.index(mirror_su2) + self.sequence[index_mirror_su2] = merged_mirror_su + self.sequence.remove(mirror_su1) + + # TODO: optimize: reset dependency + self.reset_dependency() + return merged_su + + def add_flow(self, su1, su2): + """ + Add control flow dependency su1 -> su2 + """ + if not isinstance(su1, ScheduleUnit) or not isinstance(su2, ScheduleUnit): + raise TypeError("Expected both SU1 and SU2 are ScheduleUnit") + su1.add_successors(-1, su2) + su2.add_predecessors(-1, su1) + + def is_correct(self): + """ + Check whether sequence + satisfies the sequential consistency model + """ + + for index, su in enumerate(self.sequence): + for pre_su in su.predecessors(): + # find the pre-su not appear in sequence + if not pre_su in self.sequence: + return False + pre_idx = self.sequence.index(pre_su) + # violate sequential consistency model + if pre_idx >= index: + return False + return True + + def __repr__(self): + dscp = f'ScheduleSeq (len={len(self)}):\n' + for su in self.sequence: + dscp += f'\t{su}\n' + return dscp diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py new file mode 100644 index 00000000..7e09e57b --- /dev/null +++ b/tests/graph/test_graph.py @@ -0,0 +1,63 @@ + +import cube.graph.parser as parser +from cube.sschedule.adapter import Adapter +from cube.graph.unique import IDGenerator +import copy + +import torch +from torch import nn + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0., mult=16, classes=1000): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult, bias=False) + self.gelu = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim * mult, dim) + self.classifier = nn.Linear(dim, classes) + + def forward(self, data, x: int = 4): + output = self.linear1(data) + output = self.gelu(output) + output = self.dropout(output) + output = output + data + output = self.linear2(output) + output = self.classifier(output) + return output + +model = FeedForward(dim=1024) +graph = parser.convert(model, input_shapes=([1024,1024],[1,])) + + +def test_sendrecv_adapter(graph): + for nid, node in enumerate(graph.nodes()): + if nid < 3: + node.device = 0 + else: + node.device = 1 + print('==== graph (not adapted) ====') + print(graph) + graph = Adapter.adapt(graph) + print('==== graph (after adapter) ====') + print(graph) + + +def test_graph_copy(graph): + graph = graph.copy() + print('====== Copied Graph =====') + print(graph) + print('====== Copied Graph =====') + + +def test_graph_reverse(graph): + graph = graph.copy(reverse=True) + print('====== Reversed Graph =====') + print(graph) + print('====== Reversed Graph =====') + + +if __name__ == '__main__': + + test_sendrecv_adapter(graph) + test_graph_copy(graph) + test_graph_reverse(graph) diff --git a/tests/graph/test_parser.py b/tests/graph/test_parser.py index 9ee2ef93..79ce25c7 100644 --- a/tests/graph/test_parser.py +++ b/tests/graph/test_parser.py @@ -1,5 +1,6 @@ import cube.graph.parser as parser from cube.graph.parser import ScriptModuleParser +from cube.graph.unique import IDGenerator import torch from torch import nn @@ -30,11 +31,20 @@ def test_flatten(model): ScriptModuleParser.flatten(smodule) def test_parse_module(model): - return parser.convert(model, input_shapes=([1024,1024],[1,])) + graph = parser.convert(model, input_shapes=([1024,1024],[1,])) + print(graph) + +def test_device_set(model): + IDGenerator().clear() + graph = parser.convert(model, input_shapes=([1024,1024],[1,])) + for node in graph.nodes(): + node.device = 0 + print('==== graph (with device) ====') + print(graph) if __name__ == '__main__': # test_flatten(model) - graph = test_parse_module(model) - print(graph) \ No newline at end of file + test_parse_module(model) + test_device_set(model) \ No newline at end of file diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 272dcb26..4558805c 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -1,7 +1,8 @@ from cube.graph.ir_graph import IRGraph from cube.tschedule.pool import TSchedulePool from cube.graph.ir_cten import IRTensor -from cube.graph.ir_seq import IRSequence +from cube.tschedule.suseq import SUSequence +from cube.sschedule.adapter import Adapter from torch import nn import torch @@ -30,56 +31,107 @@ def forward(self, data): model = FeedForward(dim=1024) ir_graph = parser.convert(model, input_shapes=([64,1024],)) -print('====== Forward Graph =======\n') -print(ir_graph) -print('====== Forward Graph =======\n') # device assignment -for input in ir_graph.inputs(): - if isinstance(input, IRTensor): - input.device = [0] for nid, node in enumerate(ir_graph.nodes()): - if nid <= 2: + if nid < 3: node.device = 0 else: node.device = 1 +print('====== Forward Graph =======\n') +print(ir_graph) +ir_graph = Adapter.adapt(ir_graph) +print('====== Forward Graph =======\n') + def test_graph_forward(ir_graph): TSchedulePool().clear() tensor1 = ir_graph(IRTensor(shape=[64,1024])) - print(tensor1) + tensor2 = ir_graph(IRTensor(shape=[64,1024])) + assert tensor1 != tensor2 print('====== Forward Test =======') - for action in TSchedulePool().actions(): - gener = SScheduleCodeGen(action) - code = gener.gen(device=action.device[0]) - print(code) - print(TSchedulePool()) + seq = SUSequence(TSchedulePool().sus()) + + for su in seq.sus(): + print(su) + print('\n====== Forward Test =======\n') +def test_su_merge(ir_graph): + + TSchedulePool().clear() + loss = ir_graph(IRTensor(shape=[64,1024])) + seq = SUSequence(TSchedulePool().sus()) + + first_stage = seq.sus()[0:3] + second_stage = seq.sus()[5:8] + + su1 = seq.sus(0) + for su in first_stage[1:]: + su1 = seq.merge(su1, su) + assert su1 is not None + + su2 = second_stage[0] + for su in second_stage[1:]: + su2 = seq.merge(su2, su) + + for su in seq.sus(): + print(su) + + # spatial code + sgener = SScheduleCodeGen(seq) + scode = sgener.gen(device=0) + print(scode) + + # temporal code + tgener = TScheduleCodeGen(seq) + tcode = tgener.gen(device=0) + print(tcode) + + def test_graph_backward(ir_graph): TSchedulePool().clear() - input = IRTensor(shape=[64,1024]) - input.device = [0] - tensor = ir_graph(input) - tensor.backward() - input = IRTensor(shape=[64,1024]) - input.device = [0] - tensor = ir_graph(input) - tensor.backward() + micro_bs = 1 + for _ in range(micro_bs): + loss = ir_graph(IRTensor(shape=[64,1024])) + loss.backward() print('====== Backward Test =======\n') print(TSchedulePool()) - sequence = IRSequence(TSchedulePool().actions()) - from cube.codegen.codegen import TScheduleCodeGen - gener = TScheduleCodeGen(sequence) - code = gener.gen(device=0) - print(code) - code = gener.gen(device=1) - print(code) + seq = SUSequence(TSchedulePool().sus()) + first_stage = seq.sus()[0:3] + second_stage = seq.sus()[5:8] + + su1 = seq.sus(0) + for su in first_stage[1:]: + su1 = seq.merge(su1, su) + assert su1 is not None + + su2 = second_stage[0] + for su in second_stage[1:]: + su2 = seq.merge(su2, su) + + print('===== seq before gen ====') + print(seq) + for su in seq.sus(): + print(f'pair: {su} <-> {su.mirror}') + + sgener = SScheduleCodeGen(seq) + scode = sgener.gen(device=0) + print(scode) + scode = sgener.gen(device=1) + print(scode) + + # temporal code + tgener = TScheduleCodeGen(seq) + tcode = tgener.gen(device=0) + print(tcode) + tcode = tgener.gen(device=1) + print(tcode) print('\n====== Backward Test =======\n') @@ -87,4 +139,5 @@ def test_graph_backward(ir_graph): if __name__ == '__main__': test_graph_forward(ir_graph) + test_su_merge(ir_graph) test_graph_backward(ir_graph) From 3422c76eb6d3f544cdf215ef007a5fc8ae85e557 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 10 Oct 2021 19:45:06 +0800 Subject: [PATCH 0204/1892] enable e2e: TODO: multi mircobatch --- cube/codegen/codegen.py | 19 +++++------- cube/graph/ir_comm.py | 32 +++++++++++++++++++- cube/graph/ir_graph.py | 1 - cube/sschedule/__init__.py | 3 +- cube/sschedule/adapter.py | 1 + cube/tschedule/__init__.py | 3 +- cube/tschedule/su.py | 8 ++++- cube/tschedule/suseq.py | 17 +++++++---- examples/e2e.py | 49 ++++++++++++++++++++++++++----- tests/tschedule/test_tschedule.py | 19 ++++++++++++ 10 files changed, 122 insertions(+), 30 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 6e012ba8..4a48ae38 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -22,11 +22,11 @@ class SScheduleCodeGen: def __init__(self, seq: SUSequence): if not isinstance(seq, SUSequence): raise TypeError("seq should be SUSequence") - self.sus = seq.sus() + self.seq = seq # model full code self.init_code: List[str] = [ '\n\n########## Generated Code ###########', - 'import torch', '', ''] + 'import torch', 'import cube', '', ''] # module init code self.declare_region: List[str] = list() # module forward code @@ -41,7 +41,8 @@ def gen(self, device: int, outfile=None, attach=False) -> str: """ Generate model implementation code based on the given graph. """ - device_sus = [su for su in self.sus if (device in su.device) and (su.tag == 'forward')] + device_sus = [su for su in self.seq.sus() \ + if device in su.device and su.tag != 'backward'] gencode = copy.copy(self.init_code) @@ -54,11 +55,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: su_args.append(fargs) # parse graph body - print(f'device: {device}: {device_sus}') for su in device_sus: - print('====', su) - for node in su.nodes(): - print(node) for node in su.nodes(): if isinstance(node, IRCommunication): self.emit_comm_call(node, su) @@ -82,7 +79,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: cb.insert_body('') cb.insert_body(ib.code) for idx, su in enumerate(device_sus): - name = f'su{self.sus.index(su)}' + name = f'su{self.seq.sus().index(su)}' input_args = ['self'] + su_args[idx] forward_code = self.all_su_forward_region[idx] with FunctionBlock(func_name=name, args=input_args) as fb: @@ -155,11 +152,11 @@ def emit_comm_call(self, node, su: ScheduleUnit): send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' code = f'{comm_code}({send_tensors}, {send_ranks})' elif node.comm_type == IRCommType.Recv: - recv_tensors = '(' + ', '.join(recv_tensors + ['']) + ')' + recv_tensors = ', '.join(recv_tensors) code = f'{recv_tensors} = {comm_code}({recv_shapes}, {recv_ranks})' elif node.comm_type == IRCommType.SendRecv: send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' - recv_tensors = '(' + ', '.join(recv_tensors + ['']) + ')' + recv_tensors = ', '.join(recv_tensors) code = f'{recv_tensors} = {comm_code}({send_tensors}, {send_ranks}, {recv_shapes}, {recv_ranks})' else: raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") @@ -274,7 +271,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: fsign = 'cube.runtime.temporal.forward({model}, *{inputs})' bsign = 'cube.runtime.temporal.backward({input_tensors}, {output_tensors}, {output_grads})' - if su.tag == 'forward': + if su.tag == 'forward' or su.tag == 'adapter': inputs = [self.naming(tensor, su) for tensor in su.inputs()] inputs = '(' + ', '.join(inputs + ['']) + ')' body = fsign.format( diff --git a/cube/graph/ir_comm.py b/cube/graph/ir_comm.py index 7d5642d5..02dfbb3b 100644 --- a/cube/graph/ir_comm.py +++ b/cube/graph/ir_comm.py @@ -1,7 +1,7 @@ from typing import List from enum import Enum -from cube.graph.ir_cten import IRCell +from cube.graph.ir_cten import IRCell, IRTensor class IRCommType(Enum): @@ -59,7 +59,37 @@ def __init__(self, self.recv_tensors.append(self.outputs(idx)) self.recv_ranks.append(from_device) + self.msg_id = self._id + + def pair(self, other): + """ + Pair two comm node to have same message id. + + The `other` message id is set same with caller + """ + if not isinstance(other, IRCommunication): + raise RuntimeError("Expected IRCommunication to pair") + other.msg_id = self.msg_id + def merge(self, other): if not isinstance(other, IRCommunication): raise RuntimeError("Expected IRCommunication to merge") raise NotImplementedError + + def __repr__(self): + inputs = list() + for tensor in self.inputs(): + if isinstance(tensor, IRTensor): + inputs.append(f't{tensor._id}-dev{tensor.device}') + else: + inputs.append(tensor) + + outputs = list() + for tensor in self.outputs(): + if isinstance(tensor, IRTensor): + outputs.append(f't{tensor._id}-dev{tensor.device}') + else: + outputs.append(tensor) + + dscp = f'SendRecv(msg_id={self.msg_id}, device={self.device}, send={inputs}, recv={outputs})' + return dscp \ No newline at end of file diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index fc57195e..6c1b9759 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -269,7 +269,6 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: for fsu, bsu in zip(fsus, bsus[::-1]): fsu.set_mirror(bsu) bsu.set_mirror(fsu) - print(f'pair: {fsu} <-> {bsu}') # add forward schedule to pool for su in fsus: diff --git a/cube/sschedule/__init__.py b/cube/sschedule/__init__.py index 5fa7b2bf..1c9c5673 100644 --- a/cube/sschedule/__init__.py +++ b/cube/sschedule/__init__.py @@ -1,6 +1,6 @@ from cube.graph import parser from cube.codegen.codegen import SScheduleCodeGen -from cube.tschedule.su import ScheduleUnit +from cube.sschedule.adapter import Adapter class SpatialModule: @@ -48,4 +48,5 @@ def schedule(module, input_shapes, policy_fn=None): module = SpatialModule(ir_graph) if policy_fn: module._ir_graph = policy_fn(module.get_graph()) + module._ir_graph = Adapter.adapt(module._ir_graph) return module diff --git a/cube/sschedule/adapter.py b/cube/sschedule/adapter.py index 22c6c79b..21d09bb0 100644 --- a/cube/sschedule/adapter.py +++ b/cube/sschedule/adapter.py @@ -40,5 +40,6 @@ def create_tensor_move(tensor, from_rank: int, to_rank: int) -> Tuple[IRCommunic recv_ranks = [from_rank] ) ir_recv_node.device = to_rank + ir_send_node.pair(ir_recv_node) return ir_send_node, ir_recv_node diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py index 7f4f64df..2815aecc 100644 --- a/cube/tschedule/__init__.py +++ b/cube/tschedule/__init__.py @@ -21,7 +21,6 @@ def __next__(self): for data in datas: if torch.is_tensor(data): tensor = IRTensor(shape=list(data.size()), name='input') - tensor.device = [0] else: tensor = data ir_datas.append(tensor) @@ -81,7 +80,7 @@ def decorator(fn: Callable) -> Callable: for rank in range(world_size): fname = filename.format(rank) # generate spatial module code - model.gen_module(rank, fname, attach=False) + model.gen_module(seq, rank, fname, attach=False) # generate temporal schedule code tgener.gen( device = rank, diff --git a/cube/tschedule/su.py b/cube/tschedule/su.py index 96e201ad..8b8a4422 100644 --- a/cube/tschedule/su.py +++ b/cube/tschedule/su.py @@ -27,11 +27,15 @@ class ScheduleUnit(IRCell): # input_tensors, output_tensors, output_grads # ) _backward_signature = 'cube.runtime.temporal.backward' + # cube.runtime.collectives.sendrecv(send_tensors, send_ranks, + # recv_shapes, from_ranks + # ) + _adapter_signature = 'cube.runtime.collectives.sendrecv' def __init__(self, sub_nodes, graph, devices: Union[List[int], int]): if all([isinstance(node, IRCommunication) for node in sub_nodes]): - self.tag = 'forward' + self.tag = 'adapter' else: self.tag = graph.tag @@ -41,6 +45,8 @@ def __init__(self, sub_nodes, graph, devices: Union[List[int], int]): signature = ScheduleUnit._forward_signature elif self.tag == 'backward': signature = ScheduleUnit._backward_signature + elif self.tag == 'adapter': + signature = ScheduleUnit._adapter_signature else: raise RuntimeError(f"Unsupported graph tag: {self.tag}") diff --git a/cube/tschedule/suseq.py b/cube/tschedule/suseq.py index 26910170..7ae2a5d6 100644 --- a/cube/tschedule/suseq.py +++ b/cube/tschedule/suseq.py @@ -120,8 +120,10 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: if not isinstance(su1, ScheduleUnit) or \ not isinstance(su2, ScheduleUnit): raise TypeError("Expected SU1 and SU2 are ScheduleUnit") - if su1 not in self.sequence or su2 not in self.sequence: - raise ValueError("Expected both su1 and su2 are in sequence") + if su1 not in self.sequence: + raise ValueError(f"su1: {su1} not in sequence") + if su2 not in self.sequence: + raise ValueError(f"su2: {su2} not in sequence") # 2) all the nodes in both SU are on the same device if su1 == su2 or su1.tag != su2.tag: @@ -134,9 +136,10 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: index_su2 = self.sequence.index(su2) # make su1 happen before su2 su1, su2 = (su1, su2) if index_su1 < index_su2 else (su2, su1) + index_su1, index_su2 = min(index_su1, index_su2), max(index_su1, index_su2) inter_sus = self.sequence[index_su1+1:index_su2] for su in inter_sus: - if su.happen_after(su1) and su.happen_before(su2): + if su1.happen_after(su) and su.happen_before(su2): return None # merge forward su @@ -195,6 +198,10 @@ def is_correct(self): def __repr__(self): dscp = f'ScheduleSeq (len={len(self)}):\n' - for su in self.sequence: - dscp += f'\t{su}\n' + for node in self.sequence: + succ_node_ids = [None] * len(node.outputs()) + for out_idx in range(len(node.outputs())): + node_list = [snode._id for snode in node.successors(out_idx)] + succ_node_ids[out_idx] = node_list + dscp += f"\n{node._id}: {node} -> su id {succ_node_ids}\n" return dscp diff --git a/examples/e2e.py b/examples/e2e.py index 66819c5c..eafb803c 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -15,20 +15,53 @@ from torch import nn import cube -from cube.graph.ir_cten import IRTensor -def spolicy(ir_graph): - for input in ir_graph.inputs(): - if isinstance(input, IRTensor): - input.device = [0] +def spolicy(ir_graph): for nid, node in enumerate(ir_graph.nodes()): - if nid <= 2: + if nid < 3: node.device = 0 else: node.device = 1 return ir_graph +def tpolicy(seq): + # put to micro-batch forward-backward sequence + fb_op_seqs = list() + for su in seq.sus(): + for fb_seq in fb_op_seqs: + for ksu in fb_seq[::-1]: + if seq.happen_before(ksu, su): + fb_seq.append(su) + break + else: + continue + break + else: + fb_op_seqs.append([su]) + + # merge to stages + fb_stage_seqs = list() + for fb_seq in fb_op_seqs: + merged_su = fb_seq[0] + for su in fb_seq[1:]: + if su.tag == 'backward': + break + out_su = seq.merge(merged_su, su) + if out_su is not None: + merged_su = out_su + else: + print('=====', merged_su) + fb_stage_seqs.append(merged_su) + merged_su = su + fb_stage_seqs.append(merged_su) + + for mbs_seq in fb_stage_seqs: + print('mirobatch seq:', mbs_seq) + print(seq) + return seq + + class FakeDataLoader: def __init__(self, batch_size, num=32): @@ -78,9 +111,9 @@ def train(): dataloader = FakeDataLoader(64) - @cube.tschedule.schedule(model, dataloader) + @cube.tschedule.schedule(model, dataloader, policy_fn=tpolicy) def train_iter(model, dataloader): - for _ in range(4): + for _ in range(1): (data,) = next(dataloader) loss = model(data) loss.backward() diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 4558805c..4449455f 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -136,8 +136,27 @@ def test_graph_backward(ir_graph): print('\n====== Backward Test =======\n') +def test_su_merge(ir_graph): + TSchedulePool().clear() + loss = ir_graph(IRTensor(shape=[64,1024])) + seq = SUSequence(TSchedulePool().sus()) + sus = seq.sus()[0:4] + + su1 = seq.sus(0) + for su in sus[1:]: + su1 = seq.merge(su1, su) + assert su1 is not None + print(su1) + for node in su1.nodes(): + print(node) + print('====') + for node in ir_graph.nodes(): + print(node) + + if __name__ == '__main__': test_graph_forward(ir_graph) test_su_merge(ir_graph) test_graph_backward(ir_graph) + test_su_merge(ir_graph) From 655b38be134193ac92dc563464b2e1674489bd18 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 11 Oct 2021 19:39:50 +0800 Subject: [PATCH 0205/1892] dataloader to node --- cube/codegen/codegen.py | 18 +++-- cube/graph/ir_graph.py | 107 +++++++++++++++++++++++++----- cube/tschedule/__init__.py | 24 ++++--- cube/tschedule/su.py | 76 ++++++++++----------- cube/tschedule/suseq.py | 12 ++-- examples/e2e.py | 33 ++++----- gencode0.py | 93 +++++++++++++++++++++----- gencode1.py | 100 +++++++++++++++++++++------- tests/tschedule/test_tschedule.py | 33 ++------- 9 files changed, 337 insertions(+), 159 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 4a48ae38..832c691f 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -6,7 +6,7 @@ from cube.graph.ir_cten import IRTensor from cube.tschedule.suseq import SUSequence -from cube.tschedule.su import ScheduleUnit +from cube.tschedule.su import ScheduleUnit, SUType from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -42,7 +42,9 @@ def gen(self, device: int, outfile=None, attach=False) -> str: Generate model implementation code based on the given graph. """ device_sus = [su for su in self.seq.sus() \ - if device in su.device and su.tag != 'backward'] + if device in su.device \ + and su.stype != SUType.Backward \ + and su.stype != SUType.Dataloader] gencode = copy.copy(self.init_code) @@ -271,7 +273,15 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: fsign = 'cube.runtime.temporal.forward({model}, *{inputs})' bsign = 'cube.runtime.temporal.backward({input_tensors}, {output_tensors}, {output_grads})' - if su.tag == 'forward' or su.tag == 'adapter': + if su.stype == SUType.Dataloader: + if len(su.inputs()) != 0: + raise RuntimeError("Dataloader su has no inputs") + outputs = [self.naming(output, su) for output in su.outputs()] + return_val = ','.join(outputs) + code = f'{return_val} = {su.signature}' + return code + + elif su.stype == SUType.Forward or su.stype == SUType.Adapter: inputs = [self.naming(tensor, su) for tensor in su.inputs()] inputs = '(' + ', '.join(inputs + ['']) + ')' body = fsign.format( @@ -286,7 +296,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: code = f'{return_val} = {body}' return code - elif su.tag == 'backward': + elif su.stype == SUType.Backward: # 1). input_tensors are forward inputs (happened before su inputs) # => backward graph output tensor (share tensor in forward / backward graph) # 2). output_tensors are forward outputs (su.inputs()) diff --git a/cube/graph/ir_graph.py b/cube/graph/ir_graph.py index 6c1b9759..08f65bb5 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/ir_graph.py @@ -9,10 +9,13 @@ from typing import Union, Tuple, List, Optional, Any +from torch._C import device + from cube.graph.ir_cten import IRTensor, IRCell from cube.graph.ir_op import IROperation from cube.graph.ir_comm import IRCommunication -from cube.tschedule.su import ScheduleUnit, forward_convert +from cube.runtime.temporal import forward +from cube.tschedule.su import SUType, logic_translator import copy @@ -29,13 +32,18 @@ class IRGraph(IRCell): def __init__(self, nodes: List[IROperation], - input_tensors: List[IRTensor], - output_tensors: List[IRTensor], + input_tensors: Optional[List[IRTensor]], + output_tensors: Optional[List[IRTensor]], module_name: str): self._nodes: List[IROperation] = nodes self.reset_dependency() + if input_tensors is None: + input_tensors = IRGraph.get_inputs(nodes) + if output_tensors is None: + output_tensors = IRGraph.get_outputs(nodes) + super().__init__( name=module_name, signature=module_name, @@ -234,7 +242,7 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: even if the input is same Returns: - List[Action] + IRTensors """ from cube.tschedule.pool import TSchedulePool # check input num @@ -243,28 +251,58 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: f"Expected {len(self.inputs())} input args but got {len(args)}" ) - forward_graph = self.copy() - backward_graph = self.copy(reverse=True) - backward_graph.tag = 'backward' + fgraph = self.copy() + bgraph = self.copy(reverse=True) + bgraph.tag = 'backward' # set input - for input, arg in zip(self.inputs(), args): + for input, arg in zip(fgraph.inputs(), args): if type(arg) != type(input): raise RuntimeError(f"Expected input type the same") - forward_graph.replace_tensor(input, arg) + fgraph.replace_tensor(input, arg) + + # dataloader su + cell = IRCell( + name = 'dataloader', + signature = 'dataloder.__next__', + input_length = 0, + output_length = len(args) + ) + for idx, arg in enumerate(args): + cell.set_output(idx, arg) + + devices = set() + for idx, arg in enumerate(args): + cell.set_output(idx, arg) + if isinstance(arg, IRTensor): + for node in arg.dst(fgraph.nodes()): + devices.update(set(node.device)) + cell.device = list(devices) + + dataloader = IRGraph([cell], None, None, 'dataloader') + data_sus = logic_translator( + dataloader, + su_type=SUType.Dataloader + ) - fsus = forward_convert(forward_graph) + for su in data_sus: + TSchedulePool().add_su(su) + + # forward su + fsus = logic_translator(fgraph, su_type=SUType.Forward) # return tensors - outputs = forward_graph.outputs() + outputs = fgraph.outputs() for output in outputs: - output.set_trace(fsus) + if isinstance(output, IRTensor): + output.set_trace(fsus) # set backward graph input - for input, output in zip(backward_graph.inputs(), outputs): - backward_graph.replace_tensor(input, output) + for input, output in zip(bgraph.inputs(), outputs): + bgraph.replace_tensor(input, output) - bsus = forward_convert(backward_graph) + # backward su + bsus = logic_translator(bgraph, su_type=SUType.Backward) for fsu, bsu in zip(fsus, bsus[::-1]): fsu.set_mirror(bsu) @@ -324,6 +362,45 @@ def subgraph(self, sub_nodes: List[IRCell]): return graph + @staticmethod + def get_inputs(nodes: List[IRCell]) -> List[IRTensor]: + """ + Get all the input tensors the is not generated by nodes + + Returns: + List[IRTensor] + """ + all_outputs = list() + for node in nodes: + all_outputs += node.outputs() + inputs = list() + for node in nodes: + for input in node.inputs(): + if isinstance(input, IRTensor): + if input not in all_outputs: + inputs.append(input) + return inputs + + @staticmethod + def get_outputs(nodes: List[IRCell]) -> List[IRTensor]: + """ + Get all the input tensors the is not generated by nodes + + Returns: + List[IRTensor] + """ + all_inputs = list() + for node in nodes: + all_inputs += node.inputs() + outputs = list() + for node in nodes: + for output in node.outputs(): + if isinstance(output, IRTensor): + if output not in all_inputs: + outputs.append(output) + return outputs + + def __repr__(self): dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py index 2815aecc..d96fd82a 100644 --- a/cube/tschedule/__init__.py +++ b/cube/tschedule/__init__.py @@ -1,9 +1,9 @@ from typing import Callable, Optional import torch - from cube.tschedule.pool import TSchedulePool -from cube.graph.ir_cten import IRTensor +from cube.graph.ir_cten import IRTensor, IRCell from cube.tschedule.suseq import SUSequence +from cube.tschedule.su import ScheduleUnit from cube.codegen.codegen import TScheduleCodeGen @@ -16,15 +16,19 @@ def __iter__(self): return self def __next__(self): + # generate a schedule node datas = next(self.dataloader) - ir_datas = list() - for data in datas: - if torch.is_tensor(data): - tensor = IRTensor(shape=list(data.size()), name='input') - else: - tensor = data - ir_datas.append(tensor) - return tuple(ir_datas) + if not isinstance(datas, tuple): + datas = (datas,) + + outputs = [ + IRTensor(shape=list(data.shape), name='data') for data in datas + ] + for output in outputs: + output.requires_grad = False + + #TODO: check data type consistency + return tuple(outputs) def schedule(model, dataloader, policy_fn: Optional[Callable] = None): diff --git a/cube/tschedule/su.py b/cube/tschedule/su.py index 8b8a4422..a8421c65 100644 --- a/cube/tschedule/su.py +++ b/cube/tschedule/su.py @@ -1,7 +1,6 @@ from typing import Union, List, Optional import copy from enum import Enum - from cube.graph.ir_comm import IRCommunication from cube.graph.ir_cten import IRCell @@ -9,60 +8,52 @@ class SUType(Enum): - Forward = 'forward' - Backward = 'backward' - Adapter = 'adapter' - - -class ScheduleUnit(IRCell): - """ - Action recv tensors must be inside of Action inputs, - and can be mapped to Action.graph.inputs - - """ - # outputs = cube.runtime.temporal.forward(model, *args) - _forward_signature = 'cube.runtime.temporal.forward' + Forward = 'cube.runtime.temporal.forward' + # grads = cube.runtime.temporal.backward( # input_tensors, output_tensors, output_grads # ) - _backward_signature = 'cube.runtime.temporal.backward' + Backward = 'cube.runtime.temporal.backward' + # cube.runtime.collectives.sendrecv(send_tensors, send_ranks, # recv_shapes, from_ranks # ) - _adapter_signature = 'cube.runtime.collectives.sendrecv' + Adapter = 'cube.runtime.collectives.sendrecv' - def __init__(self, sub_nodes, graph, devices: Union[List[int], int]): + Dataloader = 'next(dataloader)' - if all([isinstance(node, IRCommunication) for node in sub_nodes]): - self.tag = 'adapter' - else: - self.tag = graph.tag - self.global_graph = graph +class ScheduleUnit(IRCell): + """ + Action recv tensors must be inside of Action inputs, + and can be mapped to Action.graph.inputs - if self.tag == 'forward': - signature = ScheduleUnit._forward_signature - elif self.tag == 'backward': - signature = ScheduleUnit._backward_signature - elif self.tag == 'adapter': - signature = ScheduleUnit._adapter_signature - else: - raise RuntimeError(f"Unsupported graph tag: {self.tag}") + """ + + def __init__(self, sub_nodes, graph, devices: Union[List[int], int], stype: SUType): + + if not isinstance(stype, SUType): + raise TypeError("Expected stype be SUType") + + self.stype = stype + self.global_graph = graph subgraph = graph.subgraph(sub_nodes) - self._nodes = sub_nodes + inputs = subgraph.inputs() + outputs = subgraph.outputs() super().__init__( - name = self.tag, - signature = signature, - input_length = len(subgraph.inputs()), - output_length = len(subgraph.outputs()) + name = graph.name, + signature = stype.value, + input_length = len(inputs), + output_length = len(outputs) ) - for idx, input in enumerate(subgraph.inputs()): + self._nodes = sub_nodes + for idx, input in enumerate(inputs): self.set_input(idx, input) - for idx, output in enumerate(subgraph.outputs()): + for idx, output in enumerate(outputs): self.set_output(idx, output) # set su device @@ -181,15 +172,20 @@ def successors(self, index: Optional[int] = None) -> List: def __repr__(self): su_inputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.inputs()] su_outputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.outputs()] - dscp = f'SU({self.name}, nodes={len(self.nodes())})-dev{self.device}: {su_inputs} -> {su_outputs}' + dscp = f'SU({self.stype}, nodes={len(self.nodes())})-dev{self.device}: {su_inputs} -> {su_outputs}' return dscp -def forward_convert(graph) -> List[ScheduleUnit]: +def logic_translator(graph, su_type: SUType) -> List[ScheduleUnit]: + if not isinstance(su_type, SUType): + raise TypeError("Expected SU Type") sus = list() for node in graph.nodes(): + stype = su_type + if isinstance(node, IRCommunication): + stype = SUType.Adapter devices = node.device for device in devices: - su = ScheduleUnit([node], graph, device) + su = ScheduleUnit([node], graph, device, stype) sus.append(su) return sus diff --git a/cube/tschedule/suseq.py b/cube/tschedule/suseq.py index 7ae2a5d6..379825c7 100644 --- a/cube/tschedule/suseq.py +++ b/cube/tschedule/suseq.py @@ -10,7 +10,9 @@ class SUSequence(IRCell): def __init__(self, sus: List[ScheduleUnit]): if not all([isinstance(su, ScheduleUnit) for su in sus]): - raise TypeError("Expected a list of ScheduleUnits") + raise TypeError( + f"Expected a list of ScheduleUnits, but got {type(sus)}" + ) super().__init__( name = 'SU', @@ -126,7 +128,7 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: raise ValueError(f"su2: {su2} not in sequence") # 2) all the nodes in both SU are on the same device - if su1 == su2 or su1.tag != su2.tag: + if su1 == su2 or su1.stype != su2.stype: return None if set(su1.device) != set(su2.device): return None @@ -144,14 +146,16 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: # merge forward su sub_nodes = su1.nodes() + su2.nodes() - merged_su = ScheduleUnit(sub_nodes, su1.global_graph, su1.device) + merged_su = ScheduleUnit( + sub_nodes, su1.global_graph, su1.device, su1.stype + ) # merge mirrored su # mirror_su2 -> mirror_su1 mirror_su1, mirror_su2 = su1.mirror, su2.mirror sub_nodes = mirror_su2.nodes() + mirror_su1.nodes() merged_mirror_su = ScheduleUnit( - sub_nodes, mirror_su1.global_graph, mirror_su1.device + sub_nodes, mirror_su1.global_graph, mirror_su1.device, mirror_su1.stype ) # set mirror diff --git a/examples/e2e.py b/examples/e2e.py index eafb803c..42045bd4 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -15,6 +15,7 @@ from torch import nn import cube +from cube.tschedule.su import SUType def spolicy(ir_graph): @@ -41,30 +42,20 @@ def tpolicy(seq): fb_op_seqs.append([su]) # merge to stages - fb_stage_seqs = list() for fb_seq in fb_op_seqs: merged_su = fb_seq[0] for su in fb_seq[1:]: - if su.tag == 'backward': - break - out_su = seq.merge(merged_su, su) - if out_su is not None: - merged_su = out_su - else: - print('=====', merged_su) - fb_stage_seqs.append(merged_su) - merged_su = su - fb_stage_seqs.append(merged_su) - - for mbs_seq in fb_stage_seqs: - print('mirobatch seq:', mbs_seq) + if su.stype == SUType.Backward: + continue + msu = seq.merge(merged_su, su) + merged_su = su if msu is None else msu print(seq) return seq class FakeDataLoader: - def __init__(self, batch_size, num=32): + def __init__(self, batch_size, num=640): self.batch_size = batch_size self.length = num self.pos = 0 @@ -75,7 +66,7 @@ def __next__(self): self.pos += 1 if self.pos == self.length: raise StopIteration - return (torch.randn((self.batch_size, 1024)).cuda(),) + return torch.randn((self.batch_size, 1024)).cuda() class FeedForward(nn.Module): @@ -103,17 +94,19 @@ def init_weight(parameters): def train(): + batch_size = 64 + model = FeedForward(dim=1024) model = cube.sschedule.schedule( - model, input_shapes=([64,1024],), + model, input_shapes=([batch_size,1024],), policy_fn=spolicy ) - dataloader = FakeDataLoader(64) + dataloader = FakeDataLoader(batch_size) @cube.tschedule.schedule(model, dataloader, policy_fn=tpolicy) def train_iter(model, dataloader): - for _ in range(1): + for _ in range(4): (data,) = next(dataloader) loss = model(data) loss.backward() @@ -123,7 +116,7 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - for epoch in range(100): + for epoch in range(10): train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() diff --git a/gencode0.py b/gencode0.py index 6f22f362..b9ecca0c 100644 --- a/gencode0.py +++ b/gencode0.py @@ -2,6 +2,7 @@ ########## Generated Code ########### import torch +import cube class GenModel(torch.nn.Module): @@ -10,11 +11,61 @@ def __init__(self): super().__init__() self.weight_3 = torch.nn.Parameter(torch.empty((16384, 1024))) - def forward(self, data_1): - tensor_4 = torch.nn.functional.linear(data_1, self.weight_3, None) - tensor_6 = torch.nn.functional.gelu(tensor_4) - tensor_8 = torch.nn.functional.dropout(tensor_6, 0.0, self.training, False) - return tensor_8 + def su1(self, data_36): + tensor_41 = torch.nn.functional.linear(data_36, self.weight_3, None) + tensor_45 = torch.nn.functional.gelu(tensor_41) + tensor_49 = torch.nn.functional.dropout(tensor_45, 0.0, self.training, False) + return tensor_49 + + def su2(self, tensor_49): + cube.runtime.collectives.send((tensor_49, ), [[1]]) + return + + def su7(self): + tensor_89 = cube.runtime.collectives.recv([[64, 16384]], [[1]]) + return tensor_89 + + def su10(self, data_215): + tensor_220 = torch.nn.functional.linear(data_215, self.weight_3, None) + tensor_224 = torch.nn.functional.gelu(tensor_220) + tensor_228 = torch.nn.functional.dropout(tensor_224, 0.0, self.training, False) + return tensor_228 + + def su11(self, tensor_228): + cube.runtime.collectives.send((tensor_228, ), [[1]]) + return + + def su16(self): + tensor_268 = cube.runtime.collectives.recv([[64, 16384]], [[1]]) + return tensor_268 + + def su19(self, data_394): + tensor_399 = torch.nn.functional.linear(data_394, self.weight_3, None) + tensor_403 = torch.nn.functional.gelu(tensor_399) + tensor_407 = torch.nn.functional.dropout(tensor_403, 0.0, self.training, False) + return tensor_407 + + def su20(self, tensor_407): + cube.runtime.collectives.send((tensor_407, ), [[1]]) + return + + def su25(self): + tensor_447 = cube.runtime.collectives.recv([[64, 16384]], [[1]]) + return tensor_447 + + def su28(self, data_573): + tensor_578 = torch.nn.functional.linear(data_573, self.weight_3, None) + tensor_582 = torch.nn.functional.gelu(tensor_578) + tensor_586 = torch.nn.functional.dropout(tensor_582, 0.0, self.training, False) + return tensor_586 + + def su29(self, tensor_586): + cube.runtime.collectives.send((tensor_586, ), [[1]]) + return + + def su34(self): + tensor_626 = cube.runtime.collectives.recv([[64, 16384]], [[1]]) + return tensor_626 ########## Generated Code ########### @@ -22,15 +73,23 @@ def forward(self, data_1): import cube def _train_step(model, dataloader): - tensor_21 = cube.runtime.temporal.forward(model, *(*next(dataloader), )) - tensor_51 = cube.runtime.collectives.send_and_recv((tensor_21, ), [[1]], ([64, 16384],), [[1]]) # send: ([64, 16384],) - cube.runtime.temporal.backward((), (tensor_21, ), (tensor_51, )) - tensor_54 = cube.runtime.temporal.forward(model, *(*next(dataloader), )) - tensor_84 = cube.runtime.collectives.send_and_recv((tensor_54, ), [[1]], ([64, 16384],), [[1]]) # send: ([64, 16384],) - cube.runtime.temporal.backward((), (tensor_54, ), (tensor_84, )) - tensor_87 = cube.runtime.temporal.forward(model, *(*next(dataloader), )) - tensor_117 = cube.runtime.collectives.send_and_recv((tensor_87, ), [[1]], ([64, 16384],), [[1]]) # send: ([64, 16384],) - cube.runtime.temporal.backward((), (tensor_87, ), (tensor_117, )) - tensor_120 = cube.runtime.temporal.forward(model, *(*next(dataloader), )) - tensor_150 = cube.runtime.collectives.send_and_recv((tensor_120, ), [[1]], ([64, 16384],), [[1]]) # send: ([64, 16384],) - cube.runtime.temporal.backward((), (tensor_120, ), (tensor_150, )) + data_36 = next(dataloader) + tensor_49 = cube.runtime.temporal.forward(model.su1, *(data_36, )) + cube.runtime.temporal.forward(model.su2, *(tensor_49, )) + tensor_89 = cube.runtime.temporal.forward(model.su7, *()) + data_78 = cube.runtime.temporal.backward((data_36, ), (tensor_49, ), (tensor_89, )) + data_215 = next(dataloader) + tensor_228 = cube.runtime.temporal.forward(model.su10, *(data_215, )) + cube.runtime.temporal.forward(model.su11, *(tensor_228, )) + tensor_268 = cube.runtime.temporal.forward(model.su16, *()) + data_257 = cube.runtime.temporal.backward((data_215, ), (tensor_228, ), (tensor_268, )) + data_394 = next(dataloader) + tensor_407 = cube.runtime.temporal.forward(model.su19, *(data_394, )) + cube.runtime.temporal.forward(model.su20, *(tensor_407, )) + tensor_447 = cube.runtime.temporal.forward(model.su25, *()) + data_436 = cube.runtime.temporal.backward((data_394, ), (tensor_407, ), (tensor_447, )) + data_573 = next(dataloader) + tensor_586 = cube.runtime.temporal.forward(model.su28, *(data_573, )) + cube.runtime.temporal.forward(model.su29, *(tensor_586, )) + tensor_626 = cube.runtime.temporal.forward(model.su34, *()) + data_615 = cube.runtime.temporal.backward((data_573, ), (tensor_586, ), (tensor_626, )) diff --git a/gencode1.py b/gencode1.py index 0d5cde6e..b356b8a7 100644 --- a/gencode1.py +++ b/gencode1.py @@ -2,22 +2,73 @@ ########## Generated Code ########### import torch +import cube class GenModel(torch.nn.Module): def __init__(self): super().__init__() - self.weight_10 = torch.nn.Parameter(torch.empty((1024, 16384))) - self.bias_11 = torch.nn.Parameter(torch.empty((1024,))) - self.weight_14 = torch.nn.Parameter(torch.empty((1000, 1024))) - self.bias_15 = torch.nn.Parameter(torch.empty((1000,))) - - def forward(self, tensor_8): - tensor_12 = torch.nn.functional.linear(tensor_8, self.weight_10, self.bias_11) - tensor_16 = torch.nn.functional.linear(tensor_12, self.weight_14, self.bias_15) - tensor_17 = torch.sum(tensor_16) - return tensor_17 + self.weight_14 = torch.nn.Parameter(torch.empty((1024, 16384))) + self.bias_15 = torch.nn.Parameter(torch.empty((1024,))) + self.weight_21 = torch.nn.Parameter(torch.empty((1000, 1024))) + self.bias_22 = torch.nn.Parameter(torch.empty((1000,))) + + def su3(self): + tensor_49 = cube.runtime.collectives.recv([[64, 16384]], [[0]]) + return tensor_49 + + def su4(self, tensor_49): + tensor_58 = torch.nn.functional.linear(tensor_49, self.weight_14, self.bias_15) + tensor_64 = torch.nn.functional.linear(tensor_58, self.weight_21, self.bias_22) + tensor_68 = torch.sum(tensor_64) + return tensor_68 + + def su6(self, tensor_89): + cube.runtime.collectives.send((tensor_89, ), [[0]]) + return + + def su12(self): + tensor_228 = cube.runtime.collectives.recv([[64, 16384]], [[0]]) + return tensor_228 + + def su13(self, tensor_228): + tensor_237 = torch.nn.functional.linear(tensor_228, self.weight_14, self.bias_15) + tensor_243 = torch.nn.functional.linear(tensor_237, self.weight_21, self.bias_22) + tensor_247 = torch.sum(tensor_243) + return tensor_247 + + def su15(self, tensor_268): + cube.runtime.collectives.send((tensor_268, ), [[0]]) + return + + def su21(self): + tensor_407 = cube.runtime.collectives.recv([[64, 16384]], [[0]]) + return tensor_407 + + def su22(self, tensor_407): + tensor_416 = torch.nn.functional.linear(tensor_407, self.weight_14, self.bias_15) + tensor_422 = torch.nn.functional.linear(tensor_416, self.weight_21, self.bias_22) + tensor_426 = torch.sum(tensor_422) + return tensor_426 + + def su24(self, tensor_447): + cube.runtime.collectives.send((tensor_447, ), [[0]]) + return + + def su30(self): + tensor_586 = cube.runtime.collectives.recv([[64, 16384]], [[0]]) + return tensor_586 + + def su31(self, tensor_586): + tensor_595 = torch.nn.functional.linear(tensor_586, self.weight_14, self.bias_15) + tensor_601 = torch.nn.functional.linear(tensor_595, self.weight_21, self.bias_22) + tensor_605 = torch.sum(tensor_601) + return tensor_605 + + def su33(self, tensor_626): + cube.runtime.collectives.send((tensor_626, ), [[0]]) + return ########## Generated Code ########### @@ -25,16 +76,19 @@ def forward(self, tensor_8): import cube def _train_step(model, dataloader): - tensor_21 = cube.runtime.collectives.recv(([64, 16384],), [[0]]) - tensor_23 = cube.runtime.temporal.forward(model, *(tensor_21, )) - tensor_51 = cube.runtime.temporal.backward((tensor_21, ), (tensor_23, ), (None, )) - tensor_54 = cube.runtime.collectives.send_and_recv((tensor_51, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) - tensor_56 = cube.runtime.temporal.forward(model, *(tensor_54, )) - tensor_84 = cube.runtime.temporal.backward((tensor_54, ), (tensor_56, ), (None, )) - tensor_87 = cube.runtime.collectives.send_and_recv((tensor_84, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) - tensor_89 = cube.runtime.temporal.forward(model, *(tensor_87, )) - tensor_117 = cube.runtime.temporal.backward((tensor_87, ), (tensor_89, ), (None, )) - tensor_120 = cube.runtime.collectives.send_and_recv((tensor_117, ), [[0]], ([64, 16384],), [[0]]) # send: ([64, 16384],) - tensor_122 = cube.runtime.temporal.forward(model, *(tensor_120, )) - tensor_150 = cube.runtime.temporal.backward((tensor_120, ), (tensor_122, ), (None, )) - cube.runtime.collectives.send((tensor_150, ), [[0]]) # send: ([64, 16384],) + tensor_49 = cube.runtime.temporal.forward(model.su3, *()) + tensor_68 = cube.runtime.temporal.forward(model.su4, *(tensor_49, )) + tensor_89 = cube.runtime.temporal.backward((tensor_49, ), (tensor_68, ), (None, )) + cube.runtime.temporal.forward(model.su6, *(tensor_89, )) + tensor_228 = cube.runtime.temporal.forward(model.su12, *()) + tensor_247 = cube.runtime.temporal.forward(model.su13, *(tensor_228, )) + tensor_268 = cube.runtime.temporal.backward((tensor_228, ), (tensor_247, ), (None, )) + cube.runtime.temporal.forward(model.su15, *(tensor_268, )) + tensor_407 = cube.runtime.temporal.forward(model.su21, *()) + tensor_426 = cube.runtime.temporal.forward(model.su22, *(tensor_407, )) + tensor_447 = cube.runtime.temporal.backward((tensor_407, ), (tensor_426, ), (None, )) + cube.runtime.temporal.forward(model.su24, *(tensor_447, )) + tensor_586 = cube.runtime.temporal.forward(model.su30, *()) + tensor_605 = cube.runtime.temporal.forward(model.su31, *(tensor_586, )) + tensor_626 = cube.runtime.temporal.backward((tensor_586, ), (tensor_605, ), (None, )) + cube.runtime.temporal.forward(model.su33, *(tensor_626, )) diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py index 4449455f..b7835b01 100644 --- a/tests/tschedule/test_tschedule.py +++ b/tests/tschedule/test_tschedule.py @@ -66,10 +66,10 @@ def test_su_merge(ir_graph): loss = ir_graph(IRTensor(shape=[64,1024])) seq = SUSequence(TSchedulePool().sus()) - first_stage = seq.sus()[0:3] - second_stage = seq.sus()[5:8] + first_stage = seq.sus()[1:4] + second_stage = seq.sus()[6:9] - su1 = seq.sus(0) + su1 = seq.sus(1) for su in first_stage[1:]: su1 = seq.merge(su1, su) assert su1 is not None @@ -95,7 +95,7 @@ def test_su_merge(ir_graph): def test_graph_backward(ir_graph): TSchedulePool().clear() - micro_bs = 1 + micro_bs = 2 for _ in range(micro_bs): loss = ir_graph(IRTensor(shape=[64,1024])) loss.backward() @@ -103,10 +103,10 @@ def test_graph_backward(ir_graph): print(TSchedulePool()) seq = SUSequence(TSchedulePool().sus()) - first_stage = seq.sus()[0:3] - second_stage = seq.sus()[5:8] + first_stage = seq.sus()[1:4] + second_stage = seq.sus()[6:9] - su1 = seq.sus(0) + su1 = seq.sus(1) for su in first_stage[1:]: su1 = seq.merge(su1, su) assert su1 is not None @@ -136,27 +136,8 @@ def test_graph_backward(ir_graph): print('\n====== Backward Test =======\n') -def test_su_merge(ir_graph): - TSchedulePool().clear() - loss = ir_graph(IRTensor(shape=[64,1024])) - seq = SUSequence(TSchedulePool().sus()) - sus = seq.sus()[0:4] - - su1 = seq.sus(0) - for su in sus[1:]: - su1 = seq.merge(su1, su) - assert su1 is not None - print(su1) - for node in su1.nodes(): - print(node) - print('====') - for node in ir_graph.nodes(): - print(node) - - if __name__ == '__main__': test_graph_forward(ir_graph) test_su_merge(ir_graph) test_graph_backward(ir_graph) - test_su_merge(ir_graph) From 0617e05233f90b8cfc2941b4bec60bd4215d1e00 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 13 Oct 2021 15:21:11 +0800 Subject: [PATCH 0206/1892] dataloader data type check --- cube/tschedule/__init__.py | 14 +++++++++----- examples/e2e.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py index d96fd82a..45ea3843 100644 --- a/cube/tschedule/__init__.py +++ b/cube/tschedule/__init__.py @@ -1,7 +1,7 @@ from typing import Callable, Optional import torch from cube.tschedule.pool import TSchedulePool -from cube.graph.ir_cten import IRTensor, IRCell +from cube.graph.ir_cten import IRTensor from cube.tschedule.suseq import SUSequence from cube.tschedule.su import ScheduleUnit from cube.codegen.codegen import TScheduleCodeGen @@ -24,11 +24,15 @@ def __next__(self): outputs = [ IRTensor(shape=list(data.shape), name='data') for data in datas ] - for output in outputs: - output.requires_grad = False + for idx, (output, data) in enumerate(zip(outputs, datas)): + if not torch.is_tensor(data): + outputs[idx] = data + else: + output.requires_grad = False - #TODO: check data type consistency - return tuple(outputs) + if len(outputs) == 0: return + elif len(outputs) == 1: return outputs[0] + else: return tuple(outputs) def schedule(model, dataloader, policy_fn: Optional[Callable] = None): diff --git a/examples/e2e.py b/examples/e2e.py index 42045bd4..a207c152 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -107,7 +107,7 @@ def train(): @cube.tschedule.schedule(model, dataloader, policy_fn=tpolicy) def train_iter(model, dataloader): for _ in range(4): - (data,) = next(dataloader) + data = next(dataloader) loss = model(data) loss.backward() model = model.get_gen_module() From a6a984c44b89d1aacd3d1463ef28f1a5be54ff04 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 13 Oct 2021 15:22:41 +0800 Subject: [PATCH 0207/1892] dataloader typo fix --- cube/tschedule/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py index 45ea3843..df159fc2 100644 --- a/cube/tschedule/__init__.py +++ b/cube/tschedule/__init__.py @@ -7,7 +7,7 @@ from cube.codegen.codegen import TScheduleCodeGen -class IRTesnorDataLoader: +class IRTensorDataLoader: def __init__(self, dataloader): self.dataloader = dataloader @@ -58,7 +58,7 @@ def train_step(model, dataloader): ... """ ir_graph = model.get_graph() - ir_dataloader = IRTesnorDataLoader(dataloader) + ir_dataloader = IRTensorDataLoader(dataloader) myrank = torch.distributed.get_rank() def _load_tschedule_fn(filename) -> Callable: From fe06a95cd7a3ff429130f4fb5bbbf998152b31d4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 14 Oct 2021 13:14:01 +0800 Subject: [PATCH 0208/1892] linear example --- examples/linears.py | 140 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 examples/linears.py diff --git a/examples/linears.py b/examples/linears.py new file mode 100644 index 00000000..3a9cd3d0 --- /dev/null +++ b/examples/linears.py @@ -0,0 +1,140 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/e2e.py +""" +from typing import List + +import torch +from torch import nn + +import cube +from cube.tschedule.su import ScheduleUnit +from cube.tschedule.suseq import SUSequence + + +def trans_policy(graph, resource): + """ + The transformation policy transposes linear using data parallel + """ + ndevice = resource.ngpus + for node in graph.nodes(): + algorithm = node.algorithms('data_parallel') + graph.select(node, algorithm, config=dict(chunk_size=ndevice)) + return graph + + +def schedule_policy(seq: SUSequence, resource): + """ + The schedule policy uses 1F1B (interleaved) pipeline + """ + ndevice = resource.ngpus + + # batch_seqs[idx]: the idx-th forward-backward 4 linear forward + backward + batch_seqs: List[List[ScheduleUnit]] = group_by_batches(seq.sus()) + num_fsus = len(seq.sus()) // len(batch_seqs) // 2 + + # assign devices -- intra device order + for batch_seq in batch_seqs: + for idx, su in enumerate(batch_seq): + stage = idx // (num_fsus // ndevice) + if idx < num_fsus: + seq.assign(su, stage) + else: + seq.assign(su, ndevice - stage % ndevice) + + + # assign devices -- inter device order + f = lambda stage, micro_batch_id: batch_seqs[micro_batch_id][stage] + b = lambda stage, micro_batch_id: batch_seqs[micro_batch_id][-stage] + + reorder = list() + # warmup + for stage in range(ndevice): + for micro_batch_id in range(stage): + reorder = reorder.append(f(stage, micro_batch_id)) + # steady + cooldown + for stage in range(ndevice): + # backward + for micro_batch_id in range(len(batch_seqs)): + reorder.append(b(stage, micro_batch_id)) + # forward + for stage in range(ndevice): + f_mirco_batch_id = micro_batch_id + 1 + ndevice - stage + if f_mirco_batch_id >= len(batch_seqs): + continue + reorder.append(f(stage, f_mirco_batch_id)) + + for idx, su in enumerate(reorder): + seq.move(su, idx) + + + + + +class FakeDataLoader: + def __init__(self, shape, num=640): + self.shape = shape + self.length = num + self.pos = 0 + def __iter__(self): + self.pos = 0 + return self + def __next__(self): + self.pos += 1 + if self.pos == self.length: + raise StopIteration + return torch.randn(self.shape).cuda() + + +class MLP(nn.Module): + def __init__(self, dim, mult=16): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult, bias=False) + self.linear2 = nn.Linear(dim * mult, dim) + self.linear3 = nn.Linear(dim, dim * mult, bias=False) + self.linear4 = nn.Linear(dim * mult, dim) + + def forward(self, data): + output = self.linear1(data) + output = self.linear2(output) + output = self.linear3(output) + output = self.linear4(output) + return output + + +def train(): + batch_size = 64 + dim = 1024 + + model = MLP(dim=dim) + model = model.cuda() + + dataloader = FakeDataLoader((batch_size, dim)) + + def train_iter(model, dataloader): + for _ in range(4): + data = next(dataloader) + output = model(data) + loss = torch.sum(output) / 1000 + print(f'loss={loss.item()}') + loss.backward() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + for epoch in range(10): + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + +if __name__ == '__main__': + + train() \ No newline at end of file From 4cb7873b1068b4433cf706555b15844f8a562db4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Oct 2021 00:26:15 +0800 Subject: [PATCH 0209/1892] re-strcuture and add test --- cube/__init__.py | 2 - cube/graph/__init__.py | 6 +- cube/graph/{ir_comm.py => comm.py} | 2 +- cube/graph/{ir_graph.py => graph.py} | 20 +-- cube/graph/{ir_op.py => operator.py} | 2 +- cube/graph/parser/converter.py | 4 +- cube/graph/parser/parser.py | 3 +- cube/graph/tensor.py | 130 +++++++++++++++ cube/ir/__init__.py | 1 + cube/{graph/ir_cten.py => ir/cten.py} | 76 +++++++-- cube/{graph => ir}/unique.py | 0 tests/graph/test_graph.py | 164 ++++++++++++------- tests/graph/test_parser.py | 52 +++--- tests/ir/test_cell.py | 219 ++++++++++++++++++++++++++ tests/ir/test_tensor.py | 192 ++++++++++++++++++++++ 15 files changed, 762 insertions(+), 111 deletions(-) rename cube/graph/{ir_comm.py => comm.py} (98%) rename cube/graph/{ir_graph.py => graph.py} (96%) rename cube/graph/{ir_op.py => operator.py} (97%) create mode 100644 cube/graph/tensor.py create mode 100644 cube/ir/__init__.py rename cube/{graph/ir_cten.py => ir/cten.py} (86%) rename cube/{graph => ir}/unique.py (100%) create mode 100644 tests/ir/test_cell.py create mode 100644 tests/ir/test_tensor.py diff --git a/cube/__init__.py b/cube/__init__.py index 13952670..491513f6 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,5 +1,3 @@ from cube.device.physic.group import DeviceGroup -from cube import sschedule -from cube import tschedule from cube import runtime diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index 1cfb314f..30225faa 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -1,5 +1,5 @@ -from cube.graph.ir_graph import IRGraph -from cube.graph.ir_cten import IRTensor, IRCell -from cube.graph.ir_op import IROperation +from cube.graph.graph import IRGraph +from cube.graph.operator import IROperation +from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.graph import parser diff --git a/cube/graph/ir_comm.py b/cube/graph/comm.py similarity index 98% rename from cube/graph/ir_comm.py rename to cube/graph/comm.py index 02dfbb3b..c9a4948d 100644 --- a/cube/graph/ir_comm.py +++ b/cube/graph/comm.py @@ -1,7 +1,7 @@ from typing import List from enum import Enum -from cube.graph.ir_cten import IRCell, IRTensor +from cube.ir.cten import IRCell, IRTensor class IRCommType(Enum): diff --git a/cube/graph/ir_graph.py b/cube/graph/graph.py similarity index 96% rename from cube/graph/ir_graph.py rename to cube/graph/graph.py index 08f65bb5..fa5428c4 100644 --- a/cube/graph/ir_graph.py +++ b/cube/graph/graph.py @@ -9,13 +9,10 @@ from typing import Union, Tuple, List, Optional, Any -from torch._C import device - -from cube.graph.ir_cten import IRTensor, IRCell -from cube.graph.ir_op import IROperation -from cube.graph.ir_comm import IRCommunication -from cube.runtime.temporal import forward -from cube.tschedule.su import SUType, logic_translator +from cube.ir.cten import IRTensor, IRCell +from cube.graph.operator import IROperation +from cube.graph.comm import IRCommunication +# from cube.tschedule.su import SUType, logic_translator import copy @@ -155,11 +152,6 @@ def _renew(val: Any): copied_graph.tag = self.tag return copied_graph - def add_node(self, node: IRCell): - if not isinstance(node, IRCell): - raise TypeError("Expected node to be IROperation") - self._nodes.append(node) - def nodes(self, index: Optional[int] = None): """ Get node at position index @@ -209,7 +201,7 @@ def insert(self, node, src_node=None, dst_node=None, replaced_tensor=None): #TODO: optimize this self.reset_dependency() - def replace_tensor(self, old_tensor: IRTensor, new_tensor: IRTensor): + def _replace_tensor(self, old_tensor: IRTensor, new_tensor: IRTensor): """ Replace tensor from old_tensor to new_tensor for all the graph. """ @@ -259,7 +251,7 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: for input, arg in zip(fgraph.inputs(), args): if type(arg) != type(input): raise RuntimeError(f"Expected input type the same") - fgraph.replace_tensor(input, arg) + fgraph._replace_tensor(input, arg) # dataloader su cell = IRCell( diff --git a/cube/graph/ir_op.py b/cube/graph/operator.py similarity index 97% rename from cube/graph/ir_op.py rename to cube/graph/operator.py index 7e6f7abf..941ff83c 100644 --- a/cube/graph/ir_op.py +++ b/cube/graph/operator.py @@ -1,6 +1,6 @@ from typing import List, Union -from cube.graph.ir_cten import IRTensor, IRCell +from cube.ir.cten import IRTensor, IRCell from cube.graph.mapping import IR2LogicOp diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 79e80257..42e22015 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,8 +1,8 @@ -from cube.graph.ir_cten import IRTensor from typing import Optional, List +from cube.ir.cten import IRTensor from cube.graph.parser import ScriptModuleParser -from cube.graph import IRGraph, IRTensor +from cube.graph import IRGraph import torch diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 4429bcba..263ac027 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -3,7 +3,8 @@ import re from typing import List, Tuple, Optional -from cube.graph import IROperation, IRTensor +from cube.graph import IROperation +from cube.ir.cten import IRTensor from cube.graph.parser.frame import Frame class ScriptNodeKind(enum.Enum): diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py new file mode 100644 index 00000000..825aa953 --- /dev/null +++ b/cube/graph/tensor.py @@ -0,0 +1,130 @@ +from typing import List, Optional, Callable +import copy + +from cube.ir.cten import IRTensor + + +__all__ = ['IRFullTensor', 'IRSubTensor'] + + +class IRFullTensor(IRTensor): + + def __init__(self, shape=None, name=None): + + super().__init__(shape, name) + + self._segments = list() + # indices: List[IndexMap] for each segment + self._indices: List = list() + # value op + self._val_ops: List = list() + + def segments(self, index: Optional[int] = None): + """ + Get the SubTensors at index position + """ + if index is None: + return copy.copy(self._segments) + else: + return self._segments[index] + + def indices(self, index: Optional[int] = None): + """ + Get the SubTensors mapping indices + """ + if index is None: + return copy.copy(self._indices) + else: + return self._indices[index] + + def val_ops(self, index: Optional[int] = None): + """ + Get the SubTensors val_op + """ + if index is None: + return copy.copy(self._val_ops) + else: + return self._val_ops[index] + + def select(self, indices, val_op: Optional[Callable], shape: List[int]): + """ + Select a SubTensor from FullTensor. + + Note due to implementation issue, one value in the full tensor + cannot be splitted by different val_op + + Args: + indices: the index of this tensor's index + + val_op: how the tensor is merged with the other + sub_tensor at same location + + shape: the sub_tensor shape. + + Returns: + IRSubTensor + """ + sub_tensor = IRSubTensor(self, indices, val_op, shape) + self._segments.append(sub_tensor) + self._indices.append(indices) + self._val_ops.append(val_op) + return sub_tensor + + +class IRSubTensor: + + def __init__(self, full_tensor: IRTensor, indices, val_op=None, shape=None): + """ + Create an IRSubTensor. + + Args: + full_tensor: the full tensor + indices: index list + val_op: the value operation to merge SubTensors into one + """ + super.__init__(shape=shape, name=full_tensor.name) + + # the full tensor + self._full_tensor = full_tensor + + # the index from full_tensor + self._index_map = indices + + # val merge op + self.val_merge_op = val_op + + @property + def parent(self) -> IRFullTensor: + """ + Return the full tensor of this sub tensor + """ + return self._full_tensor + + def index_map(self): + """ + Return indices list mapped to the full tensor + """ + return copy.copy(self._index_map) + + @property + def val_op(self): + return self.val_merge_op + + def select(self, indices, val_op, shape=None): + """ + Select an IRSubTensor + + Args: + indices: the index of this tensor's index + + val_op: the value operation to merge + co-located indices of SubTensors into one + + shape: the sub_tensor shape + + Returns: + IRSubTensor + """ + index_map = self.index_map[indices] + sub_tensor = self.full_tensor.select(index_map, val_op, shape) + return sub_tensor diff --git a/cube/ir/__init__.py b/cube/ir/__init__.py new file mode 100644 index 00000000..96030a6e --- /dev/null +++ b/cube/ir/__init__.py @@ -0,0 +1 @@ +from cube.ir.cten import IRTensor, IRCell \ No newline at end of file diff --git a/cube/graph/ir_cten.py b/cube/ir/cten.py similarity index 86% rename from cube/graph/ir_cten.py rename to cube/ir/cten.py index d8224671..e20d13a2 100644 --- a/cube/graph/ir_cten.py +++ b/cube/ir/cten.py @@ -17,7 +17,7 @@ from typing import List, Union, Optional, Any import copy -from cube.graph.unique import IDGenerator +from cube.ir.unique import IDGenerator __all__ = ['IRCell', 'IRTensor'] @@ -187,6 +187,9 @@ def set_input(self, input_index: int, val: Any): Args: val: Union[IRTensor, Any] + + Return: + the set tensor """ if input_index >= len(self.inputs()): raise RuntimeError( @@ -198,6 +201,7 @@ def set_input(self, input_index: int, val: Any): # set tensor dst val.attach_cell(self) self._inputs[input_index] = val + return val def set_output(self, output_index: int, val: Any): """ @@ -215,30 +219,72 @@ def set_output(self, output_index: int, val: Any): val = copy.copy(val) val.attach_cell(self) self._outputs[output_index] = val + return val - def add_predecessor(self, input_index: int, node): + def add_predecessor(self, input_index: int, cell): """ Add a predecessor cell in the input_index slot. - self.input[input_index] = node.output[out_index] + + Note this won't add successor if caller cell to the node """ - if not isinstance(node, IRCell): + if not isinstance(cell, IRCell): raise TypeError("Expected node to be IRCell") if input_index >= len(self.inputs()): raise RuntimeError( f"Set the input out of range ({input_index} >= {len(self._inputs)})" ) - if node not in self._predecessors[input_index]: - self._predecessors[input_index].append(node) + if cell not in self._predecessors[input_index]: + self._predecessors[input_index].append(cell) - def add_successor(self, output_index: int, node): + def add_successor(self, output_index: int, cell): """ Set self node the output index node. `node` will take the self.outputs(index) as the input """ - if not isinstance(node, IRCell): + if not isinstance(cell, IRCell): raise TypeError("Expected node to be IRCell") - if node not in self._successors[output_index]: - self._successors[output_index].append(node) + if cell not in self._successors[output_index]: + self._successors[output_index].append(cell) + + @staticmethod + def get_inputs(cells): + """ + Get all the input tensors the is not generated by nodes + + Inputs + + Returns: + List[IRTensor] + """ + all_outputs = list() + for cell in cells: + all_outputs += cell.outputs() + inputs = list() + for cell in cells: + for input in cell.inputs(): + if isinstance(input, IRTensor): + if input not in all_outputs: + inputs.append(input) + return inputs + + @staticmethod + def get_outputs(cells): + """ + Get all the input tensors the is not generated by nodes + + Returns: + List[IRTensor] + """ + all_inputs = list() + for node in cells: + all_inputs += node.inputs() + outputs = list() + for node in cells: + for output in node.outputs(): + if isinstance(output, IRTensor): + if output not in all_inputs: + outputs.append(output) + return outputs def __repr__(self): """ @@ -280,12 +326,19 @@ def __init__(self, shape=None, name=None): self.trace = None def attach_cell(self, cell: IRCell): + """ + Attach to a cell, to be with input or output + """ if not isinstance(cell, IRCell): raise TypeError("Expected an IRCell") if cell not in self._cell: self._cell.append(cell) def detach_cell(self, cell: IRCell): + """ + Detach from a cell, when removing from cell's input + and output + """ if not isinstance(cell, IRCell): raise TypeError("Expected an IRCell") if cell not in self._cell: @@ -302,7 +355,8 @@ def set_trace(self, sus: List): def renew(self): """ - Renew a new tensor with same name and shape + Renew a new tensor with same name and shape, + but with a different new id Returns: tensor diff --git a/cube/graph/unique.py b/cube/ir/unique.py similarity index 100% rename from cube/graph/unique.py rename to cube/ir/unique.py diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 7e09e57b..cc86676d 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -1,63 +1,119 @@ +from cube.graph.graph import IRGraph +from cube.graph.tensor import IRFullTensor +from cube.graph.operator import IROperation -import cube.graph.parser as parser -from cube.sschedule.adapter import Adapter -from cube.graph.unique import IDGenerator -import copy - -import torch -from torch import nn - -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult, bias=False) - self.gelu = nn.GELU() - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim * mult, dim) - self.classifier = nn.Linear(dim, classes) - - def forward(self, data, x: int = 4): - output = self.linear1(data) - output = self.gelu(output) - output = self.dropout(output) - output = output + data - output = self.linear2(output) - output = self.classifier(output) - return output - -model = FeedForward(dim=1024) -graph = parser.convert(model, input_shapes=([1024,1024],[1,])) - - -def test_sendrecv_adapter(graph): - for nid, node in enumerate(graph.nodes()): - if nid < 3: - node.device = 0 - else: - node.device = 1 - print('==== graph (not adapted) ====') - print(graph) - graph = Adapter.adapt(graph) - print('==== graph (after adapter) ====') - print(graph) +def construct_model(): + + input = IRFullTensor(shape=[64,1024], name='data') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = IROperation( + name='linear1', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear1.set_input(0, input) + linear1.set_input(1, weight1) + linear1.set_input(2, bias1) + + # linear2 + linear2 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear2.set_input(0, linear1.outputs(0)) + linear2.set_input(1, weight2) + + # linear3 + linear3 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear3.set_input(0, linear2.outputs(0)) + linear3.set_input(1, weight3) + linear3.set_input(2, bias3) + + # return [input], [ops], [output] + return [input], [linear1, linear2, linear3], [linear3.outputs(0)] -def test_graph_copy(graph): - graph = graph.copy() - print('====== Copied Graph =====') - print(graph) - print('====== Copied Graph =====') +def test_graph_init(): -def test_graph_reverse(graph): - graph = graph.copy(reverse=True) - print('====== Reversed Graph =====') + inputs, ops, outputs = construct_model() + graph = IRGraph(ops, inputs, outputs, 'MLP') print(graph) - print('====== Reversed Graph =====') + assert len(graph.inputs()) == 1 + assert len(graph.outputs()) == 1 + assert graph.tag == 'forward' + assert graph.name == 'MLP' + + all_inputs = list() + all_outputs = list() + for node in graph.nodes(): + all_inputs += node.inputs() + all_outputs += node.outputs() + + # check inputs + for input in inputs: + assert input in graph.inputs() + assert input in all_inputs + for output in outputs: + assert output in graph.outputs() + assert output in all_outputs + + # check dependency + node1, node2, node3 = graph.nodes() + assert node2 in node1.successors() + assert node3 in node2.successors() + assert node1 in node2.predecessors() + assert node2 in node3.predecessors() + # one-hop test + assert node1 not in node3.predecessors() + assert node3 not in node1.successors() + # false test + assert node1 not in node2.successors() + assert node3 not in node2.predecessors() + + +def test_graph_nodes(): + inputs, ops, outputs = construct_model() + graph = IRGraph(ops, inputs, outputs, 'MLP') + assert id(graph.nodes()) != id(graph.nodes()) + assert graph.nodes(1) == ops[1] + + +def test_graph_copy(): + inputs, ops, outputs = construct_model() + graph = IRGraph(ops, inputs, outputs, 'MLP') -if __name__ == '__main__': + cgraph = graph.copy(reverse=False) + print(cgraph) + for gnode, cnode in zip(graph.nodes(), cgraph.nodes()): + assert gnode.name == cnode.name + assert gnode.signature == cnode.signature + assert len(gnode.inputs()) == len(cnode.inputs()) + assert len(gnode.outputs()) == len(cnode.outputs()) + assert len(gnode.predecessors()) == len(cnode.predecessors()) + assert len(gnode.successors()) == len(cnode.successors()) - test_sendrecv_adapter(graph) - test_graph_copy(graph) - test_graph_reverse(graph) + rgraph = graph.copy(reverse=True) + print(rgraph) + for gnode, cnode in zip(graph.nodes(), rgraph.nodes()[::-1]): + assert gnode.name == cnode.name + assert gnode.signature == cnode.signature + assert len(gnode.outputs()) == len(cnode.inputs()) + assert len(gnode.inputs()) == len(cnode.outputs()) + assert len(gnode.predecessors()) == len(cnode.successors()) + assert len(gnode.successors()) == len(cnode.predecessors()) diff --git a/tests/graph/test_parser.py b/tests/graph/test_parser.py index 79ce25c7..4bcce9be 100644 --- a/tests/graph/test_parser.py +++ b/tests/graph/test_parser.py @@ -1,9 +1,8 @@ +from torch import nn + import cube.graph.parser as parser -from cube.graph.parser import ScriptModuleParser -from cube.graph.unique import IDGenerator +from cube.ir.cten import IRTensor -import torch -from torch import nn class FeedForward(nn.Module): def __init__(self, dim, dropout=0., mult=16, classes=1000): @@ -23,28 +22,37 @@ def forward(self, data, x: int = 4): output = self.classifier(output) return output + model = FeedForward(dim=1024) -def test_flatten(model): - smodule = torch.jit.script(model) - ScriptModuleParser.flatten(smodule) +def test_parse_module(): -def test_parse_module(model): graph = parser.convert(model, input_shapes=([1024,1024],[1,])) print(graph) - -def test_device_set(model): - IDGenerator().clear() - graph = parser.convert(model, input_shapes=([1024,1024],[1,])) - for node in graph.nodes(): - node.device = 0 - print('==== graph (with device) ====') - print(graph) - - -if __name__ == '__main__': + assert len(graph.nodes()) == 6 + assert len(graph.inputs()) == 2 + assert len(graph.outputs()) == 1 - # test_flatten(model) - test_parse_module(model) - test_device_set(model) \ No newline at end of file + node1, node2, node3, node4, node5, node6 = graph.nodes() + assert node1.signature == 'torch.nn.functional.linear' + assert node2.signature == 'torch.nn.functional.gelu' + assert node3.signature == 'torch.nn.functional.dropout' + assert node4.signature == 'torch.add' + assert node5.signature == 'torch.nn.functional.linear' + assert node6.signature == 'torch.nn.functional.linear' + + assert node1.inputs(2) is None + assert isinstance(node5.inputs(2), IRTensor) + + # dependency + assert node2.predecessors() == [node1] + assert node3.predecessors() == [node2] + assert node4.predecessors() == [node3] + assert node5.predecessors() == [node4] + assert node6.predecessors() == [node5] + assert node1.successors() == [node2] + assert node2.successors() == [node3] + assert node3.successors() == [node4] + assert node4.successors() == [node5] + assert node5.successors() == [node6] diff --git a/tests/ir/test_cell.py b/tests/ir/test_cell.py new file mode 100644 index 00000000..30fb2857 --- /dev/null +++ b/tests/ir/test_cell.py @@ -0,0 +1,219 @@ +from cube.ir.cten import IRCell, IRTensor + + +def test_cell_init(): + + cell = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + cell2 = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + assert cell2._id != cell._id + + assert len(cell.device) == 0 + assert cell.name == 'cell_test' + assert cell.signature == 'torch.nn.functional.linear' + assert len(cell.inputs()) == 3 + assert len(cell.outputs()) == 1 + assert len(cell.device) == 0 + + +def test_cell_device(): + + cell = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + assert len(cell.device) == 0 + cell.device = 2 + assert len(cell.device) == 1 + assert cell.device[0] == 2 + assert cell.on_device(2) + assert not cell.on_device(3) + + cell.device = [2,3] + assert len(cell.device) == 2 + assert set(cell.device) == set([2, 3]) + assert cell.on_device(2) + assert cell.on_device(3) + assert not cell.on_device(4) + + +def test_cell_inputs(): + + cell = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + assert len(cell.inputs()) == 3 + for input in cell.inputs(): + assert input is None + + # the copy behavior + inputs = cell.inputs() + inputs[2] = 0 + assert cell.inputs(2) is None + + for idx in range(len(cell.inputs())): + assert cell.inputs(idx) is None + tensor = IRTensor(shape=[1024,], name='input') + cell.set_input(idx, tensor) + assert cell.inputs(idx) == tensor + + +def test_cell_outputs(): + + cell = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + assert len(cell.outputs()) == 1 + for output in cell.outputs(): + assert isinstance(output, IRTensor) + + # the copy behavior + outputs = cell.outputs() + outputs[0] = 4 + assert cell.outputs(0) != 4 + + for idx in range(len(cell.outputs())): + output = cell.outputs(idx) + tensor = IRTensor(shape=[1024,], name='output') + cell.set_output(0, tensor) + assert cell.outputs(0) == tensor + assert cell.outputs(0) != output + + +def test_cell_predecessor(): + + cell_prev = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + cell_post = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + assert len(cell_post.predecessors()) == 0 + assert len(cell_prev.predecessors()) == 0 + + cell_post.add_predecessor(1, cell_prev) + assert cell_prev in cell_post.predecessors() + assert len(cell_post.predecessors()) == 1 + assert cell_prev in cell_post.predecessors(1) + + assert len(cell_post.successors()) == 0 + + +def test_cell_successor(): + + cell_prev = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + cell_post = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + assert len(cell_prev.successors()) == 0 + assert len(cell_post.successors()) == 0 + + cell_prev.add_successor(0, cell_post) + assert cell_post in cell_prev.successors() + assert len(cell_prev.successors()) == 1 + assert cell_post in cell_prev.successors() + + assert len(cell_post.predecessors()) == 0 + + +def test_cell_get_inputs_and_outputs(): + + cell1 = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + input1 = IRTensor(shape=[1024, 1024]) + weight1 = IRTensor(shape=[1024, 1024]) + bias1 = IRTensor(shape=[1024,]) + + cell1.set_input(0, input1) + cell1.set_input(1, weight1) + cell1.set_input(2, bias1) + + + cell2 = IRCell( + name='cell_test', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + + input2 = IRTensor(shape=[1024, 1024]) + weight2 = IRTensor(shape=[1024, 1024]) + bias2 = IRTensor(shape=[1024,]) + + cell2.set_input(0, input2) + cell2.set_input(1, weight2) + cell2.set_input(2, bias2) + + inputs = IRCell.get_inputs([cell1, cell2]) + assert len(inputs) == 6 + assert input1 in inputs + assert weight1 in inputs + assert bias1 in inputs + assert input2 in inputs + assert weight2 in inputs + assert bias2 in inputs + + outputs = IRCell.get_outputs([cell1, cell2]) + assert len(outputs) == 2 + for output in cell1.outputs() + cell2.outputs(): + assert output in outputs + + # overlapped + cell2.set_input(1, weight1) + cell2.set_input(0, cell1.outputs(0)) + + inputs = IRCell.get_inputs([cell1, cell2]) + assert len(inputs) == 5 + assert input1 in inputs + assert weight1 in inputs + assert bias1 in inputs + assert bias2 in inputs + + outputs = IRCell.get_outputs([cell1, cell2]) + assert len(outputs) == 1 + assert cell2.outputs(0) in outputs + assert cell1.outputs(0) not in outputs diff --git a/tests/ir/test_tensor.py b/tests/ir/test_tensor.py new file mode 100644 index 00000000..6a172907 --- /dev/null +++ b/tests/ir/test_tensor.py @@ -0,0 +1,192 @@ +import copy + +from cube.ir.cten import IRTensor, IRCell + + +def test_tensor_init(): + + tensor1 = IRTensor() + tensor2 = IRTensor(shape=[1,2,3]) + tensor3 = IRTensor(shape=[1024], name='tensor') + + assert tensor1._id != tensor2._id + assert tensor2._id != tensor3._id + + assert tensor1.shape is None + assert tensor2.shape == [1,2,3] + assert tensor3.shape == [1024,] + + assert tensor1.name is None + assert tensor2.name is None + assert tensor3.name == 'tensor' + + assert len(tensor1.device) == 0 + assert len(tensor2.device) == 0 + + assert tensor1.requires_grad + assert tensor2.requires_grad + assert tensor3.requires_grad + + +def test_tensor_attach(): + + tensor1 = IRTensor() + tensor2 = IRTensor() + cell = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + + tensor1.attach_cell(cell) + assert cell in tensor1._cell + assert len(tensor1._cell) == 1 + assert len(tensor2._cell) == 0 + + tensor1.detach_cell(cell) + assert cell not in tensor1._cell + assert len(tensor1._cell) == 0 + + cell.set_input(0, tensor1) + cell.set_output(0, tensor1) + assert len(tensor1._cell) == 0 + assert len(cell.inputs(0)._cell) == 1 + + +def test_tensor_renew(): + + tensor1 = IRTensor(shape=[1024], name='renew_tensor') + cell = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + cell.set_input(0, tensor1) + tensor1 = cell.inputs(0) + + tensor2 = tensor1.renew() + assert tensor2.shape == tensor1.shape + assert tensor2.name == tensor1.name + assert tensor2 not in cell.inputs() + assert len(tensor2._cell) == 0 + assert tensor2.requires_grad == tensor1.requires_grad + + +def test_tensor_copy(): + + tensor1 = IRTensor(shape=[1024], name='renew_tensor') + cell = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + tensor1 = cell.set_input(0, tensor1) + + tensor2 = copy.copy(tensor1) + assert tensor2 == tensor1 + assert len(tensor2._cell) == 0 + + +def test_tensor_device(): + + tensor1 = IRTensor(shape=[1024], name='renew_tensor') + cell1 = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + cell2 = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + tensor1 = cell1.set_input(0, tensor1) + tensor2 = cell2.set_input(0, tensor1) + + assert tensor1 == tensor2 + + assert len(tensor1.device) == 0 + assert len(tensor2.device) == 0 + + cell1.device = 2 + assert tensor1.device == [2] + assert len(tensor2.device) == 0 + + cell2.device = 3 + assert tensor1.device == [2] + assert tensor2.device == [3] + + +def test_tensor_dst(): + tensor1 = IRTensor(shape=[1024], name='renew_tensor') + cell1 = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + cell2 = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + + cell1.set_input(0, tensor1) + cells = tensor1.dst([cell1, cell2]) + assert set(cells) == set([cell1]) + + cell2.set_input(0, tensor1) + cells = tensor1.dst([cell1, cell2]) + assert set(cells) == set([cell1, cell2]) + + +def test_tensor_src(): + tensor1 = IRTensor(shape=[1024], name='renew_tensor') + cell1 = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + cell2 = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + + cell1.set_output(0, tensor1) + cells = tensor1.src([cell1, cell2]) + assert set(cells) == set([cell1]) + + cell2.set_output(0, tensor1) + cells = tensor1.src([cell1, cell2]) + assert set(cells) == set([cell1, cell2]) + + +def test_tensor_is_leaf(): + tensor1 = IRTensor(shape=[1024], name='renew_tensor') + cell1 = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + cell2 = IRCell( + name='cell', + signature='any', + input_length=3, + output_length=1 + ) + cell1.set_input(0, tensor1) + assert tensor1.is_leaf([cell1]) + + cell2.set_input(0, cell1.outputs(0)) + assert cell2.outputs(0).is_leaf([cell1]) + assert not cell2.outputs(0).is_leaf([cell1, cell2]) From e13d091a1fb319c4513012199ba441d63f9a3dad Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Oct 2021 13:37:12 +0800 Subject: [PATCH 0210/1892] add mapping --- cube/graph/graph.py | 118 +-------------------------- cube/graph/tensor.py | 160 ++++++++++++++++++++++++++++++++++--- tests/graph/test_tensor.py | 122 ++++++++++++++++++++++++++++ 3 files changed, 277 insertions(+), 123 deletions(-) create mode 100644 tests/graph/test_tensor.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index fa5428c4..1864b80f 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -12,7 +12,6 @@ from cube.ir.cten import IRTensor, IRCell from cube.graph.operator import IROperation from cube.graph.comm import IRCommunication -# from cube.tschedule.su import SUType, logic_translator import copy @@ -37,9 +36,9 @@ def __init__(self, self.reset_dependency() if input_tensors is None: - input_tensors = IRGraph.get_inputs(nodes) + input_tensors = IRCell.get_inputs(nodes) if output_tensors is None: - output_tensors = IRGraph.get_outputs(nodes) + output_tensors = IRCell.get_outputs(nodes) super().__init__( name=module_name, @@ -236,77 +235,8 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: Returns: IRTensors """ - from cube.tschedule.pool import TSchedulePool - # check input num - if len(args) != len(self.inputs()): - raise RuntimeError( - f"Expected {len(self.inputs())} input args but got {len(args)}" - ) - - fgraph = self.copy() - bgraph = self.copy(reverse=True) - bgraph.tag = 'backward' - - # set input - for input, arg in zip(fgraph.inputs(), args): - if type(arg) != type(input): - raise RuntimeError(f"Expected input type the same") - fgraph._replace_tensor(input, arg) - - # dataloader su - cell = IRCell( - name = 'dataloader', - signature = 'dataloder.__next__', - input_length = 0, - output_length = len(args) - ) - for idx, arg in enumerate(args): - cell.set_output(idx, arg) - - devices = set() - for idx, arg in enumerate(args): - cell.set_output(idx, arg) - if isinstance(arg, IRTensor): - for node in arg.dst(fgraph.nodes()): - devices.update(set(node.device)) - cell.device = list(devices) - - dataloader = IRGraph([cell], None, None, 'dataloader') - data_sus = logic_translator( - dataloader, - su_type=SUType.Dataloader - ) - - for su in data_sus: - TSchedulePool().add_su(su) - - # forward su - fsus = logic_translator(fgraph, su_type=SUType.Forward) - - # return tensors - outputs = fgraph.outputs() - for output in outputs: - if isinstance(output, IRTensor): - output.set_trace(fsus) - - # set backward graph input - for input, output in zip(bgraph.inputs(), outputs): - bgraph.replace_tensor(input, output) - - # backward su - bsus = logic_translator(bgraph, su_type=SUType.Backward) - - for fsu, bsu in zip(fsus, bsus[::-1]): - fsu.set_mirror(bsu) - bsu.set_mirror(fsu) - - # add forward schedule to pool - for su in fsus: - TSchedulePool().add_su(su) - - if len(outputs) == 1: return outputs[0] - elif len(outputs) == 0: return None - else: return outputs + from cube.schedule.translator import LogicTranslator + return LogicTranslator.forward(self, *args) def __call__(self, *args): """ @@ -353,46 +283,6 @@ def subgraph(self, sub_nodes: List[IRCell]): return graph - - @staticmethod - def get_inputs(nodes: List[IRCell]) -> List[IRTensor]: - """ - Get all the input tensors the is not generated by nodes - - Returns: - List[IRTensor] - """ - all_outputs = list() - for node in nodes: - all_outputs += node.outputs() - inputs = list() - for node in nodes: - for input in node.inputs(): - if isinstance(input, IRTensor): - if input not in all_outputs: - inputs.append(input) - return inputs - - @staticmethod - def get_outputs(nodes: List[IRCell]) -> List[IRTensor]: - """ - Get all the input tensors the is not generated by nodes - - Returns: - List[IRTensor] - """ - all_inputs = list() - for node in nodes: - all_inputs += node.inputs() - outputs = list() - for node in nodes: - for output in node.outputs(): - if isinstance(output, IRTensor): - if output not in all_inputs: - outputs.append(output) - return outputs - - def __repr__(self): dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 825aa953..86b8493d 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -1,10 +1,116 @@ from typing import List, Optional, Callable import copy +import math from cube.ir.cten import IRTensor -__all__ = ['IRFullTensor', 'IRSubTensor'] +__all__ = ['IndexMap', 'IRFullTensor', 'IRSubTensor'] + + +class IndexMap: + + def __init__(self, indices): + + if not isinstance(indices, tuple): + raise TypeError("Expected indices to be a tuple") + + if not all([isinstance(s, slice) for s in indices]): + raise NotImplementedError( + "Only support for sliced index mapping" + ) + self._indices = indices + + def get(self): + """ + Get indices + """ + return self._indices + + @property + def ndims(self) -> int: + """ + Number of dimensions of the index map + """ + return len(self._indices) + + @property + def neles(self) -> int: + """ + Number of elements of the index map + """ + nelements = 1 + for slicer in self._indices: + count = slicer.stop - slicer.start + if slicer.step: + count = int(count // slicer.step) + nelements *= count + return nelements + + def map(self, submap): + """ + Map from the current indices by sub_indices. + + Args: + sub_indices: IndexMap + + Returns: + sub_indices: IndexMap + + """ + if not isinstance(submap, IndexMap): + raise TypeError("Expected IndexMap") + if self.ndims != submap.ndims: + raise ValueError("Expected same length of sub_indices") + + # e.g., (slice(0, M), slice(0, int(K // 2)) + sub = list() + for dim_indices, dim_sub_indices in zip(self.get(), submap.get()): + start, stop = dim_indices.start, dim_indices.stop + step = dim_indices.step if dim_indices.step else 1 + + sub_start, sub_stop = dim_sub_indices.start, dim_sub_indices.stop + sub_step = dim_sub_indices.step if dim_sub_indices.step else 1 + + new_start = start + sub_start + new_stop = new_start + sub_stop - sub_start + new_step = step * sub_step + if new_stop > stop: + raise ValueError("Trying to map a index out of range") + sub.append(slice(new_start, new_stop, new_step)) + return IndexMap(tuple(sub)) + + def overlap(self, other): + """ + Check if this indices overlapped with the other + + Args: + other: IndexMap + + Returns: + Boolean: True has overlap, otherwise False + """ + if not isinstance(other, IndexMap): + raise TypeError("Expected IndexMap") + + if other.ndims != self.ndims: + raise TypeError("Expected same dimension") + + for slicer1, slicer2 in zip(self.get(), other.get()): + start1, stop1 = slicer1.start, slicer1.stop + step1 = slicer1.step if slicer1.step else 1 + + start2, stop2 = slicer2.start, slicer2.stop + step2 = slicer2.step if slicer2.step else 1 + + if step1 == step2: + if min(stop1, stop2) <= max(start1, start2): + return False + elif start1 % step1 != start2 % step2: + return False + else: + raise NotImplementedError(f"not supported for differnt steps") + return True class IRFullTensor(IRTensor): @@ -28,7 +134,7 @@ def segments(self, index: Optional[int] = None): else: return self._segments[index] - def indices(self, index: Optional[int] = None): + def indices(self, index: Optional[int] = None) -> IndexMap: """ Get the SubTensors mapping indices """ @@ -66,12 +172,29 @@ def select(self, indices, val_op: Optional[Callable], shape: List[int]): """ sub_tensor = IRSubTensor(self, indices, val_op, shape) self._segments.append(sub_tensor) - self._indices.append(indices) + self._indices.append(IndexMap(indices)) self._val_ops.append(val_op) return sub_tensor + def overlap(self, other): + """ + Check if the two tensor is overlapped. -class IRSubTensor: + Returns: + True if they are sharing co-located position in + the full tensor, otherwise False + """ + if not isinstance(other, IRTensor): + raise TypeError("Expected Tensor") + if isinstance(other, IRFullTensor): + return self == other + elif isinstance(other, IRSubTensor): + return other.parent == self + else: + raise TypeError("Customized IRTensor not support") + + +class IRSubTensor(IRTensor): def __init__(self, full_tensor: IRTensor, indices, val_op=None, shape=None): """ @@ -82,13 +205,13 @@ def __init__(self, full_tensor: IRTensor, indices, val_op=None, shape=None): indices: index list val_op: the value operation to merge SubTensors into one """ - super.__init__(shape=shape, name=full_tensor.name) + super().__init__(shape=shape, name=full_tensor.name) # the full tensor self._full_tensor = full_tensor # the index from full_tensor - self._index_map = indices + self._index_map = IndexMap(indices) # val merge op self.val_merge_op = val_op @@ -100,7 +223,8 @@ def parent(self) -> IRFullTensor: """ return self._full_tensor - def index_map(self): + @property + def indices(self) -> IndexMap: """ Return indices list mapped to the full tensor """ @@ -125,6 +249,24 @@ def select(self, indices, val_op, shape=None): Returns: IRSubTensor """ - index_map = self.index_map[indices] - sub_tensor = self.full_tensor.select(index_map, val_op, shape) + sub_map = IndexMap(indices) + index_map = self.indices.map(sub_map) + sub_tensor = self.parent.select(index_map.get(), val_op, shape) return sub_tensor + + def overlap(self, other): + """ + Check if the two tensor is overlapped. + + Returns: + True if they are sharing co-located position in + the full tensor, otherwise False + """ + if not isinstance(other, IRTensor): + raise TypeError("Expected Tensor") + if isinstance(other, IRFullTensor): + return self.parent == other + elif isinstance(other, IRSubTensor): + return self.indices.overlap(other.indices) + else: + raise TypeError("Customized IRTensor not support") diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py new file mode 100644 index 00000000..5f371b33 --- /dev/null +++ b/tests/graph/test_tensor.py @@ -0,0 +1,122 @@ +from cube.graph.tensor import IRFullTensor, IRSubTensor + + +def test_full_tensor_init(): + + tensor = IRFullTensor(shape=[1024,1024], name='full_tensor') + assert tensor.shape == [1024, 1024] + assert tensor.name == 'full_tensor' + + +def test_full_tensor_select(): + + tensor = IRFullTensor(shape=[1024,1024], name='tensor') + assert len(tensor.segments()) == 0 + assert len(tensor.indices()) == 0 + assert len(tensor.val_ops()) == 0 + + sub_tensor1 = tensor.select( + indices = (slice(0, 1024), slice(0, 512)), + val_op = None, + shape = (1024, 512) + ) + + sub_tensor2 = tensor.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_op = None, + shape = (1024, 512) + ) + + assert sub_tensor1.shape == (1024, 512) + assert sub_tensor1.name == 'tensor' + + assert sub_tensor2.shape == (1024, 512) + assert sub_tensor2.name == 'tensor' + + assert len(tensor.segments()) == 2 + assert len(tensor.indices()) == 2 + assert len(tensor.val_ops()) == 2 + + +def test_full_tensor_overlap(): + + tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') + sub_tensor1 = tensor1.select( + indices = (slice(0, 1024), slice(256, 1024)), + val_op = None, + shape = (1024, 768) + ) + + sub_tensor2 = tensor1.select( + indices = (slice(0, 1024, 2), slice(512, 1024)), + val_op = None, + shape = (1024, 512) + ) + sub_tensor3 = tensor1.select( + indices = (slice(1, 1024, 2), slice(512, 1024)), + val_op = None, + shape = (1024, 512) + ) + + tensor2 = IRFullTensor(shape=[1024,1024], name='tensor') + + assert tensor1.overlap(sub_tensor1) + assert tensor1.overlap(tensor1) + assert not tensor1.overlap(tensor2) + assert not tensor2.overlap(sub_tensor1) + + assert not sub_tensor2.overlap(sub_tensor3) + + +def test_sub_tensor_select(): + + tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') + sub_tensor1 = tensor1.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_op = None, + shape = (1024, 512) + ) + sub_tensor2 = sub_tensor1.select( + indices = (slice(512, 1024), slice(0, 256)), + val_op = None, + shape = (512, 256) + ) + sub_tensor3 = sub_tensor1.select( + indices = (slice(512, 1024), slice(256, 512)), + val_op = None, + shape = (512, 256) + ) + + indices = sub_tensor2.indices.get() + assert indices == (slice(512, 1024, 1), slice(512, 768, 1)) + indices = sub_tensor3.indices.get() + assert indices == (slice(512, 1024, 1), slice(768, 1024, 1)) + + assert len(tensor1.segments()) == 3 + assert sub_tensor1 in tensor1.segments() + assert sub_tensor2 in tensor1.segments() + assert sub_tensor3 in tensor1.segments() + + +def test_sub_tensor_overlap(): + + tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') + sub_tensor1 = tensor1.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_op = None, + shape = (1024, 512) + ) + sub_tensor2 = sub_tensor1.select( + indices = (slice(512, 1024), slice(0, 256)), + val_op = None, + shape = (512, 256) + ) + sub_tensor3 = sub_tensor1.select( + indices = (slice(512, 1024), slice(256, 512)), + val_op = None, + shape = (512, 256) + ) + + assert sub_tensor1.overlap(sub_tensor2) + assert sub_tensor1.overlap(sub_tensor3) + assert not sub_tensor2.overlap(sub_tensor3) From 280641eb80c3cf46fdd557dca19a24f9a4be951f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Oct 2021 14:33:15 +0800 Subject: [PATCH 0211/1892] add common set --- cube/graph/tensor.py | 81 +++++++++++++++++++++++++++++++++++++- tests/graph/test_tensor.py | 35 ++++++++++++++++ 2 files changed, 115 insertions(+), 1 deletion(-) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 86b8493d..cb1a25eb 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -1,6 +1,5 @@ from typing import List, Optional, Callable import copy -import math from cube.ir.cten import IRTensor @@ -47,6 +46,19 @@ def neles(self) -> int: nelements *= count return nelements + @property + def shape(self) -> List[int]: + """ + Get the shape of the slice + """ + shape = list() + for slicer in self._indices: + count = slicer.stop - slicer.start + if slicer.step: + count = int(count // slicer.step) + shape.append(count) + return shape + def map(self, submap): """ Map from the current indices by sub_indices. @@ -112,6 +124,34 @@ def overlap(self, other): raise NotImplementedError(f"not supported for differnt steps") return True + def __and__(self, other): + """ + Get the common part + + Args: + other: IndexMap + + Returns: + IndexMap for the common part + """ + if not self.overlap(other): + return None + slices = list() + for slicer1, slicer2 in zip(self.get(), other.get()): + start1, stop1 = slicer1.start, slicer1.stop + step1 = slicer1.step if slicer1.step else 1 + + start2, stop2 = slicer2.start, slicer2.stop + step2 = slicer2.step if slicer2.step else 1 + + if step1 == step2: + start = max(start1, start2) + stop = min(stop1, stop2) + slices.append(slice(start, stop, step1)) + else: + raise NotImplementedError(f"not supported for differnt steps") + return IndexMap(tuple(slices)) + class IRFullTensor(IRTensor): @@ -193,6 +233,19 @@ def overlap(self, other): else: raise TypeError("Customized IRTensor not support") + def common(self, other) -> Optional[IRTensor]: + """ + Get the common sub-tensor + + Args: + IRTensor + + Returns: + None for not overlap, + else IRSubTensor or IRFullTensor + """ + return other if self.overlap(other) else None + class IRSubTensor(IRTensor): @@ -270,3 +323,29 @@ def overlap(self, other): return self.indices.overlap(other.indices) else: raise TypeError("Customized IRTensor not support") + + def common(self, other): + """ + Get the common sub-tensor + + Args: + IRTensor + + Returns: + None for not overlap, + else IRSubTensor or IRFullTensor + """ + if self.overlap(other): + if isinstance(other, IRFullTensor): + return self + elif isinstance(other, IRSubTensor): + indices = self.indices & other.indices + sub_tensor = self.parent.select( + indices = indices.get(), + val_op = self.val_op, + shape = indices.shape + ) + return sub_tensor + else: + raise NotImplementedError("Customized IRTensor not support") + return None diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py index 5f371b33..b7bdf5c6 100644 --- a/tests/graph/test_tensor.py +++ b/tests/graph/test_tensor.py @@ -120,3 +120,38 @@ def test_sub_tensor_overlap(): assert sub_tensor1.overlap(sub_tensor2) assert sub_tensor1.overlap(sub_tensor3) assert not sub_tensor2.overlap(sub_tensor3) + + +def test_sub_tensor_common(): + + tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') + sub_tensor_col1 = tensor1.select( + indices = (slice(0, 1024), slice(0, 512)), + val_op = None, + shape = (1024, 512) + ) + sub_tensor_col2 = tensor1.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_op = None, + shape = (1024, 512) + ) + sub_tensor_row1 = tensor1.select( + indices = (slice(0, 512), slice(0, 1024)), + val_op = None, + shape = (512, 1024) + ) + sub_tensor_row2 = tensor1.select( + indices = (slice(512, 1024), slice(0, 1024)), + val_op = None, + shape = (512, 1024) + ) + + lt = sub_tensor_col1.common(sub_tensor_row1) + rt = sub_tensor_col2.common(sub_tensor_row1) + lb = sub_tensor_row2.common(sub_tensor_col1) + rb = sub_tensor_row2.common(sub_tensor_col2) + + assert lt.indices.get() == (slice(0, 512, 1), slice(0, 512, 1)) + assert rt.indices.get() == (slice(0, 512, 1), slice(512, 1024, 1)) + assert lb.indices.get() == (slice(512, 1024, 1), slice(0, 512, 1)) + assert rb.indices.get() == (slice(512, 1024, 1), slice(512, 1024, 1)) From edebcb751ba20cb0bca4cc18f7a0d15fd3587585 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Oct 2021 19:45:03 +0800 Subject: [PATCH 0212/1892] add graph parameters --- cube/graph/graph.py | 17 ++++++++++++++++- cube/graph/parser/parser.py | 14 +++++++------- cube/graph/tensor.py | 2 +- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 1864b80f..9b751b0a 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -52,6 +52,15 @@ def __init__(self, for idx, tensor in enumerate(output_tensors): self.set_output(idx, tensor) + # set parameter + self._parameters = list() + for node in self._nodes: + for input in node.inputs(): + if isinstance(input, IRTensor): + if input not in input_tensors and \ + input.is_leaf(self._nodes): + input.as_param() + self._parameters.append(input) self.tag = 'forward' def reset_dependency(self): @@ -77,6 +86,12 @@ def reset_dependency(self): dst_input_idx = dst_cell.inputs().index(tensor) dst_cell.add_predecessor(dst_input_idx, src_cell) + def parameters(self): + """ + Return parameter list + """ + return copy.copy(self._parameters) + def copy(self, reverse=False): """ Copy the graph but re-new the intermediate tensor @@ -87,7 +102,7 @@ def _renew(val: Any): if not isinstance(val, IRTensor): return val # parameters - if val.is_leaf(self.nodes()) and val not in self.inputs(): + if val.is_param(): return val # intermediate data if val._id not in new_tensors: diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 263ac027..b6aab7b2 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -4,7 +4,7 @@ from typing import List, Tuple, Optional from cube.graph import IROperation -from cube.ir.cten import IRTensor +from cube.graph.tensor import IRFullTensor from cube.graph.parser.frame import Frame class ScriptNodeKind(enum.Enum): @@ -22,7 +22,7 @@ class ScriptModuleParser: def parse_module(module, input_shapes: Optional[ Tuple[List[int],] ] = None, frame: Frame = Frame()) \ - -> Tuple[List[IRTensor], List[IROperation], List[IRTensor]]: + -> Tuple[List[IRFullTensor], List[IROperation], List[IRFullTensor]]: """ The overall entry to parse a torchscript graph module """ @@ -31,7 +31,7 @@ def parse_module(module, # handle graph input -- Assuming all the inputs are tensors input_var_name = [input.debugName() for input in module.graph.inputs()] for index, var_name in enumerate(input_var_name[1:]): # omit self - frame.add_var(var_name, IRTensor(name=var_name), graph_arg=index) + frame.add_var(var_name, IRFullTensor(name=var_name), graph_arg=index) input_val = [frame.get_var(var_name) for var_name in input_var_name[1:]] # handle input shape @@ -41,7 +41,7 @@ def parse_module(module, f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(input_val)})" ) for shape, val in zip(input_shapes, input_val): - if isinstance(val, IRTensor): + if isinstance(val, IRFullTensor): val.shape = shape all_ir_nodes: List[IROperation] = list() @@ -154,7 +154,7 @@ def parse_aten_node(node, module, frame: Frame) -> List[IROperation]: var_name = input.debugName() val = frame.get_var(var_name) index = len(inputs) - 1 - reverse_index - if maybe_kwarg and (not isinstance(val, IRTensor)) and index > 1: + if maybe_kwarg and (not isinstance(val, IRFullTensor)) and index > 1: continue else: input_val.append(val) @@ -232,7 +232,7 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: This will add frame with the variable name and it's value The value can be: - 1). (IRTensor) the tensor edge in graph + 1). (IRFullTensor) the tensor edge in graph 2). (str code) symbolic value based on runtime info (e.g., self.training) 3). (str) Function or torch.nn.moudles @@ -249,7 +249,7 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: # this usually means weight (nn.Parameter in torch) if dtype == 'Tensor': shape = list(getattr(module, label).shape) - ir_tensor = IRTensor(name=label, shape=shape) + ir_tensor = IRFullTensor(name=label, shape=shape) frame.add_var(var_name, ir_tensor) # symbolic attributes elif dtype in ['bool', 'int', 'float']: diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index cb1a25eb..69527c7c 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -316,7 +316,7 @@ def overlap(self, other): the full tensor, otherwise False """ if not isinstance(other, IRTensor): - raise TypeError("Expected Tensor") + return False if isinstance(other, IRFullTensor): return self.parent == other elif isinstance(other, IRSubTensor): From 80602d8d5c8e0eab140af9133edbddea521b0cbd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Oct 2021 19:46:26 +0800 Subject: [PATCH 0213/1892] test parameter --- tests/graph/test_graph.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index cc86676d..3e1def76 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -86,6 +86,10 @@ def test_graph_init(): assert node1 not in node2.successors() assert node3 not in node2.predecessors() + # weight test + params = graph.parameters() + assert len(params) == 5 + def test_graph_nodes(): inputs, ops, outputs = construct_model() @@ -100,6 +104,11 @@ def test_graph_copy(): cgraph = graph.copy(reverse=False) print(cgraph) + + cparam_id = [param._id for param in cgraph.parameters()] + param_id = [param._id for param in graph.parameters()] + assert set(cparam_id) == set(param_id) + for gnode, cnode in zip(graph.nodes(), cgraph.nodes()): assert gnode.name == cnode.name assert gnode.signature == cnode.signature From 0a41cad1151cd86a3fe70e8663e198863b26a378 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Oct 2021 22:57:00 +0800 Subject: [PATCH 0214/1892] fix copy for fulltensor and subtensor --- cube/graph/operator.py | 4 +++ cube/graph/tensor.py | 64 ++++++++++++++++++++++++++++++++++++++ tests/graph/test_graph.py | 8 +++++ tests/graph/test_tensor.py | 7 +++++ 4 files changed, 83 insertions(+) diff --git a/cube/graph/operator.py b/cube/graph/operator.py index 941ff83c..e323e095 100644 --- a/cube/graph/operator.py +++ b/cube/graph/operator.py @@ -1,6 +1,7 @@ from typing import List, Union from cube.ir.cten import IRTensor, IRCell +from cube.graph.tensor import IRFullTensor from cube.graph.mapping import IR2LogicOp @@ -24,6 +25,9 @@ def __init__(self, output_length (int): the number of outputs for the op """ super().__init__(name, signature, input_length, output_length) + outputs = [IRFullTensor() for _ in range(output_length)] + for idx, output in enumerate(outputs): + self.set_output(idx, output) self.semantic = IR2LogicOp.map(self.signature) def infer_shape(self): diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 69527c7c..05bd6a21 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -165,6 +165,38 @@ def __init__(self, shape=None, name=None): # value op self._val_ops: List = list() + def __copy__(self): + """ + Copy the tensor that will have the exactly same id + except the empty attached cell + + Returns: + tensor + """ + tensor = IRFullTensor(self._shape, self.name) + for key in self.__dict__: + setattr(tensor, key, getattr(self, key)) + # clear attached cells + tensor._cell = list() + return tensor + + def renew(self): + """ + Renew a new tensor with same name and shape, + but with a different new id + + Returns: + tensor + """ + tensor = IRFullTensor(self._shape, self.name) + new_id = tensor._id + for key in self.__dict__: + setattr(tensor, key, getattr(self, key)) + # clear attached cells + tensor._cell = list() + tensor._id = new_id + return tensor + def segments(self, index: Optional[int] = None): """ Get the SubTensors at index position @@ -287,6 +319,38 @@ def indices(self) -> IndexMap: def val_op(self): return self.val_merge_op + def __copy__(self): + """ + Copy the tensor that will have the exactly same id + except the empty attached cell + + Returns: + tensor + """ + tensor = IRSubTensor(self._shape, self.name) + for key in self.__dict__: + setattr(tensor, key, getattr(self, key)) + # clear attached cells + tensor._cell = list() + return tensor + + def renew(self): + """ + Renew a new tensor with same name and shape, + but with a different new id + + Returns: + tensor + """ + tensor = IRSubTensor(self._shape, self.name) + new_id = tensor._id + for key in self.__dict__: + setattr(tensor, key, getattr(self, key)) + # clear attached cells + tensor._cell = list() + tensor._id = new_id + return tensor + def select(self, indices, val_op, shape=None): """ Select an IRSubTensor diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 3e1def76..7c97542c 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -1,6 +1,7 @@ from cube.graph.graph import IRGraph from cube.graph.tensor import IRFullTensor from cube.graph.operator import IROperation +from cube.ir.cten import IRTensor def construct_model(): @@ -65,6 +66,13 @@ def test_graph_init(): all_inputs += node.inputs() all_outputs += node.outputs() + for input in all_inputs: + if isinstance(input, IRTensor): + assert isinstance(input, IRFullTensor) + for output in all_outputs: + if isinstance(output, IRTensor): + assert isinstance(output, IRFullTensor) + # check inputs for input in inputs: assert input in graph.inputs() diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py index b7bdf5c6..c8ec1921 100644 --- a/tests/graph/test_tensor.py +++ b/tests/graph/test_tensor.py @@ -1,3 +1,5 @@ +import copy + from cube.graph.tensor import IRFullTensor, IRSubTensor @@ -7,6 +9,11 @@ def test_full_tensor_init(): assert tensor.shape == [1024, 1024] assert tensor.name == 'full_tensor' +def test_full_tensor_constrcut(): + + tensor = IRFullTensor(shape=[1024,1024], name='full_tensor') + ctensor = copy.copy(tensor) + assert isinstance(ctensor, IRFullTensor) def test_full_tensor_select(): From 1b5bc82cc4fe05c498e3e8c587e0536b7f08c172 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Oct 2021 23:21:12 +0800 Subject: [PATCH 0215/1892] add su and sugraph --- cube/schedule/__init__.py | 123 +++++++++++++ cube/schedule/pool.py | 33 ++++ cube/schedule/su.py | 295 +++++++++++++++++++++++++++++++ cube/schedule/sugraph.py | 314 +++++++++++++++++++++++++++++++++ tests/schedule/test_pool.py | 24 +++ tests/schedule/test_su.py | 112 ++++++++++++ tests/schedule/test_sugraph.py | 118 +++++++++++++ 7 files changed, 1019 insertions(+) create mode 100644 cube/schedule/pool.py create mode 100644 cube/schedule/su.py create mode 100644 cube/schedule/sugraph.py create mode 100644 tests/schedule/test_pool.py create mode 100644 tests/schedule/test_su.py create mode 100644 tests/schedule/test_sugraph.py diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index e69de29b..2f73e064 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -0,0 +1,123 @@ +from typing import Callable, Optional +import torch + +from cube.schedule.pool import SchedulePool +from cube.schedule.translator import IRDataLoader, LogicTranslator +from cube.schedule.sugraph import SUGraph +from cube.codegen.codegen import TScheduleCodeGen + + +class SemanticModel: + + def __init__(self, model: torch.nn.Module, input_shapes): + """ + Create semantic model based on AI Scientist description. + """ + from cube.graph import parser + self.ir_graph = parser.convert( + model, input_shapes=input_shapes + ) + self._loaded_module = None + + def get_graph(self): + return self.ir_graph + + def load_module(self, filename: str): + import importlib.util + print(f'> loading generated spatial moduel from {filename}') + spec = importlib.util.spec_from_file_location("GenModel", filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._loaded_module = module.GenModel().cuda() + + def get_gen_module(self): + return self._loaded_module + + def clear_module(self): + self._loaded_module = None + + def __call__(self, *args): + if self._loaded_module: + return self._loaded_module(*args) + else: + return self.ir_graph(*args) + + +def schedule(model: SemanticModel, dataloader, policy_fn: Optional[Callable] = None): + """ + AI Scientist calls like: + + @cube.tschedule.schedule(model, dataloader, policy_fn=policy) + def train_step(model, dataloader): + # do a 4-time gradient accumulation + for acc_step, (data, label) in enumerate(dataloader): + if acc_step < 4: + loss = model(data, label) + loss.backward() + else: + break + ... + + for epoch in range(100): + train_step(model, data_loader) + optimizer.step() + optimizer.zero_grad() + + ... + """ + if not isinstance(model, SemanticModel): + raise TypeError("Expect Semantic Model") + + ir_graph = model.get_graph() + ir_dataloader = IRDataLoader(dataloader) + myrank = torch.distributed.get_rank() + + + def _load_tschedule_fn(filename) -> Callable: + import importlib.util + print(f'> [{myrank}] loading generated schedule from {filename} ...') + spec = importlib.util.spec_from_file_location( + "_train_step", filename + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module._train_step + + def decorator(fn: Callable) -> Callable: + filename = 'gencode{}.py' + + if myrank == 0: + SchedulePool().clear() + + # logic translator + fn(ir_graph, ir_dataloader) + sus = SchedulePool().sus() + + # adapter + sus_with_adapter = LogicTranslator.gen_adapter(sus) + + # policy + su_graph = SUGraph(sus_with_adapter) + if policy_fn: + seq = policy_fn(su_graph) + + # code generation + world_size = torch.distributed.get_world_size() + tgener = TScheduleCodeGen(seq) + for rank in range(world_size): + fname = filename.format(rank) + # generate spatial module code + model.gen_module(seq, rank, fname, attach=False) + # generate temporal schedule code + tgener.gen( + device = rank, + outfile = fname, + attach=True + ) + torch.distributed.barrier() + # load module + model.load_module(filename.format(myrank)) + # load temporal + return _load_tschedule_fn(filename.format(myrank)) + + return decorator diff --git a/cube/schedule/pool.py b/cube/schedule/pool.py new file mode 100644 index 00000000..11de56e9 --- /dev/null +++ b/cube/schedule/pool.py @@ -0,0 +1,33 @@ +from typing import List +import copy + + +class SchedulePool: + + class __SchedulePool: + + def __init__(self): + + self._sus = list() + + instance = None + + def __init__(self): + if not SchedulePool.instance: + SchedulePool.instance = SchedulePool.__SchedulePool() + + def __getattr__(self, name): + return getattr(self.instance, name) + + def add_su(self, su): + self.instance._sus.append(su) + + def sus(self) -> List: + return copy.copy(self.instance._sus) + + def clear(self): + self.instance._sus = list() + + def __repr__(self): + dscp = '\n'.join([repr(su) for su in self._sus]) + return dscp diff --git a/cube/schedule/su.py b/cube/schedule/su.py new file mode 100644 index 00000000..1fa51b9c --- /dev/null +++ b/cube/schedule/su.py @@ -0,0 +1,295 @@ +from typing import List, Optional, Tuple +import copy +from enum import Enum + +from cube.ir.cten import IRCell + + +class SUType(Enum): + + # outputs = cube.runtime.temporal.forward(model, *args) + Forward = 'cube.runtime.temporal.forward' + + # grads = cube.runtime.temporal.backward( + # input_tensors, output_tensors, output_grads + # ) + Backward = 'cube.runtime.temporal.backward' + + # cube.runtime.collectives.sendrecv(send_tensors, send_ranks, + # recv_shapes, from_ranks + # ) + Adapter = 'cube.runtime.collectives.sendrecv' + + Dataloader = 'next(dataloader)' + + +class ScheduleUnit(IRCell): + """ + Action recv tensors must be inside of Action inputs, + and can be mapped to Action.graph.inputs + + """ + + def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): + + if not all([isinstance(node, IRCell) for node in nodes]): + raise ValueError("Expected each nodes to be List[IRCell]") + if not isinstance(stype, SUType): + raise TypeError("Expected stype be SUType") + + # get inputs and outputs + inputs = IRCell.get_inputs(nodes) + inputs = [input for input in inputs if not input.is_param()] + outputs = IRCell.get_outputs(nodes) + super().__init__( + name = name, + signature = stype.value, + input_length = len(inputs), + output_length = len(outputs) + ) + + self.stype = stype + + self._nodes = nodes + for idx, input in enumerate(inputs): + self.set_input(idx, input) + for idx, output in enumerate(outputs): + self.set_output(idx, output) + + # each input is associated with + # send adapters and recv adapters (send + recv) + self._send_in_adapters: List[List[ScheduleUnit]] = [ + list() for _ in range(len(inputs)) + ] + self._recv_in_adapters: List[List[ScheduleUnit]] = [ + list() for _ in range(len(inputs)) + ] + + # each output is associated with + # send adapters and recv adapters (send + recv) + self._send_out_adapters: List[List[ScheduleUnit]] = [ + list() for _ in range(len(outputs)) + ] + self._recv_out_adapters: List[List[ScheduleUnit]] = [ + list() for _ in range(len(outputs)) + ] + + # additional control dependency for add_flow + self._ctrl_predecessors = list() + self._ctrl_successors = list() + + self.mirror = None + + def __copy__(self): + """ + Copy the SU. Note the mirror su is also copied + """ + su = ScheduleUnit(self._nodes, self.stype, self.name) + if self.mirror is not None: + mirror_su = self.mirror + mirror_su = ScheduleUnit( + mirror_su._nodes, mirror_su.stype, mirror_su.name + ) + su.set_mirror(mirror_su) + mirror_su.set_mirror(su) + return su + + def set_mirror(self, su): + """ + Create a mirrored ScheduleUnit: the + inputs and outputs are reversed + """ + if not isinstance(su, ScheduleUnit): + raise TypeError("Expected mirror to be ScheduleUnit") + self.mirror = su + + def in_adapters(self, index: Optional[int] = None) -> List: + """ + Get adapter for the input tensor at index + + Returns: + Tuple[List[ScheduleUnit], List[ScheduleUnit]]: + the send_adapters and recv_adapters + """ + if isinstance(index, int): + if index >= len(self._inputs): + raise RuntimeError( + f"Get index out of range ({index} >= {len(self._inputs)})" + ) + send_adapters = copy.copy(self._send_in_adapters[index]) + recv_adapters = copy.copy(self._recv_in_adapters[index]) + return send_adapters, recv_adapters + elif index is None: + all_send_adapters = list() + all_recv_adapters = list() + for adapters in self._send_in_adapters: + all_send_adapters += adapters + for adapters in self._recv_in_adapters: + all_recv_adapters += adapters + return all_send_adapters, all_recv_adapters + else: + raise TypeError("Expected index to be None or int") + + def out_adapters(self, index: Optional[int] = None) -> Tuple[List, List]: + """ + Get adapter for the output tensor at index + + Returns: + Tuple[List[ScheduleUnit], List[ScheduleUnit]]: + the send_adapters and recv_adapters + """ + if isinstance(index, int): + if index >= len(self._outputs): + raise RuntimeError( + f"Get index out of range ({index} >= {len(self._outputs)})" + ) + send_adapters = copy.copy(self._send_out_adapters[index]) + recv_adapters = copy.copy(self._recv_out_adapters[index]) + return send_adapters, recv_adapters + elif index is None: + all_send_adapters = list() + all_recv_adapters = list() + for adapters in self._send_out_adapters: + all_send_adapters += adapters + for adapters in self._recv_out_adapters: + all_recv_adapters += adapters + return all_send_adapters, all_recv_adapters + else: + raise TypeError("Expected index to be None or int") + + def _add_in_adapter(self, index: int, send_adapter, recv_adapter): + """ + Add adapters to the input tensor of this SU + + Args: + index (int): the input index + send_adapter (ScheduleUnit) + recv_adapter (ScheduleUnit) + """ + if index >= len(self._inputs): + raise ValueError(f"index {index} out of range {len(self._inputs)}") + if not isinstance(send_adapter, ScheduleUnit): + raise TypeError("Expected send adapter to be ScheduleUnit") + if not isinstance(recv_adapter, ScheduleUnit): + raise TypeError("Expected recv adapter to be ScheduleUnit") + self._send_in_adapters[index].append(send_adapter) + self._recv_in_adapters[index].append(recv_adapter) + + def _add_out_adapter(self, index: int, send_adapter, recv_adapter): + """ + Add adapters to the output tensor of this SU + + Args: + index (int): the output index + send_adapter (ScheduleUnit) + recv_adapter (ScheduleUnit) + """ + if index >= len(self._outputs): + raise ValueError(f"index {index} out of range {len(self._outputs)}") + if not isinstance(send_adapter, ScheduleUnit): + raise TypeError("Expected send adapter to be ScheduleUnit") + if not isinstance(recv_adapter, ScheduleUnit): + raise TypeError("Expected recv adapter to be ScheduleUnit") + self._send_out_adapters[index].append(send_adapter) + self._recv_out_adapters[index].append(recv_adapter) + + def nodes(self, index: Optional[int] = None): + """ + Get node at position index + """ + if isinstance(index, int): + if index >= len(self._nodes): + raise RuntimeError( + f"Get node out of range ({index} >= {len(self._nodes)})" + ) + return self._nodes[index] + elif index is None: + return copy.copy(self._nodes) + else: + raise TypeError("Expected index to be None or int") + + def add_predecessor(self, input_index: int, su): + """ + Add a predecessor cell in the input_index slot. + self.input[input_index] = node.output[out_index] + """ + if input_index == -1: + self._ctrl_predecessors.append(su) + else: + super().add_predecessor(input_index, su) + + def predecessors(self, index: Optional[int] = None) -> List: + """ + Get 1-hop predecessor cells including control predecessors + + Args: + index (Optional[int]): + -1: return control predecessors + None: return all predecessors including index + >0 : return input SUs at input index + + Returns: + cell(s): List[IRCell] + """ + if isinstance(index, int): + if index == -1: + return copy.copy(self._ctrl_predecessors) + if index >= len(self._inputs): + raise RuntimeError( + f"Get the input out of range ({index} >= {len(self._inputs)}" + ) + return copy.copy(self._predecessors[index]) + elif index is None: + predecessors = list() + for pre_cells in self._predecessors: + predecessors += pre_cells + predecessors += self._ctrl_predecessors + return predecessors + else: + raise TypeError("Expected index to be None or int") + + def add_successor(self, output_index: int, su): + """ + Set self node the output index node. + `node` will take the self.outputs(index) as the input + """ + if output_index == -1: + self._ctrl_successors.append(su) + else: + super().add_successor(output_index, su) + + def successors(self, index: Optional[int] = None) -> List: + """ + Get 1-hop successor cells including control successors + + Args: + index (Optional[int]): + -1: return control successors + None: return all successors including index + >0 : return output SUs at output index + + Returns: + cells: List[ScheduleUnit] + """ + if isinstance(index, int): + if index == -1: + return copy.copy*self._ctrl_successors + if index >= len(self._outputs): + raise RuntimeError( + f"Get the output out of range ({index} >= {len(self._outputs)}" + ) + return copy.copy(self._successors[index]) + elif index is None: + successors = list() + for post_cells in self._successors: + successors += post_cells + successors += self._ctrl_successors + return successors + else: + raise TypeError("Expected index to be None or int") + + def __repr__(self): + su_inputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.inputs()] + su_outputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.outputs()] + dscp = f'SU({self.stype}, nodes={len(self.nodes())})-dev{self.device}: {su_inputs} -> {su_outputs}' + return dscp diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py new file mode 100644 index 00000000..d9491f23 --- /dev/null +++ b/cube/schedule/sugraph.py @@ -0,0 +1,314 @@ +from typing import List, Optional, Union +import copy + +from cube.ir.cten import IRCell +from cube.schedule.su import SUType, ScheduleUnit + + +class SUGraph(IRCell): + + def __init__(self, sus: List[ScheduleUnit]): + + if not all([isinstance(su, ScheduleUnit) for su in sus]): + raise TypeError( + f"Expected a list of ScheduleUnits, but got {type(sus)}" + ) + + inputs = IRCell.get_inputs(sus) + outputs = IRCell.get_outputs(sus) + super().__init__( + name = 'SU', + signature = 'None', + input_length = len(inputs), + output_length = len(outputs) + ) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + for idx, output in enumerate(outputs): + self.set_output(idx, output) + + self.sequence = sus + self.reset_dependency() + + @property + def nnodes(self) -> int: + """ + Get number of nodes (int) + """ + return len(self.sequence) + + def reset_dependency(self): + """ + Reset the node dataflow dependency + """ + # set node predecessors and successors + for src_idx in range(self.nnodes): + src_cell = self.sequence[src_idx] + src_cell._successors = [ + list() for _ in range(len(src_cell.outputs())) + ] + for dst_su in self.sequence[src_idx+1:]: + dst_su._predecessors = [ + list() for _ in range(len(dst_su.inputs())) + ] + for out_idx, out_tensor in enumerate(src_cell.outputs()): + for in_idx, in_tensor in enumerate(dst_su.inputs()): + if out_tensor.overlap(in_tensor): + src_cell.add_successor(out_idx, dst_su) + dst_su.add_predecessor(in_idx, src_cell) + + def __len__(self): + return len(self.sequence) + + def sus(self, index: Optional[int] = None): + """ + Return ScheduleUnit + + Args: + + """ + if isinstance(index, int): + if index >= len(self.sequence): + raise RuntimeError( + f"Get node out of range ({index} >= {len(self.sequence)})" + ) + return self.sequence[index] + elif index is None: + return copy.copy(self.sequence) + else: + raise TypeError("Expected index to be None or int") + + def happen_before(self, su1, su2): + """ + Check if the su1 -> (happened before) su2 + + Returns: + Boolean + """ + if not isinstance(su1, ScheduleUnit) or \ + not isinstance(su2, ScheduleUnit): + raise TypeError("Expected su to be an ScheduleUnit") + if su2 in su1.successors(): + return True + else: + for succ_su in su1.successors(): + if self.happen_before(succ_su, su2): + return True + return False + + def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: + """ + Merge two ScheduleUnit. This requires + + 1). all the nodes in one SU happens before / after + all the nodes in another SU. (Guaranteed by default + as all the operations on sequence are semantic-correct) + + 2). all the nodes in both SU are on the same device, + have same tags and they are not equal. + + 3). Deadlock-free merge. Suppose + SU1 (dev0) -> SU2 (dev1) -> SU3 (dev0) + Then merge SU1 and SU3 to SU4 will cause + deadlock on SU4 -> <- SU2 + + Note due to PyTorch limitation, + merging two forward ScheduleUnits will also cause + the merge of corresponding two backward ScheduleUnits. + + Returns: + if succeed: A merged ScheduleUnit. + if fail: None + """ + + if not isinstance(su1, ScheduleUnit) or \ + not isinstance(su2, ScheduleUnit): + raise TypeError("Expected SU1 and SU2 are ScheduleUnit") + if su1 not in self.sequence: + raise ValueError(f"su1: {su1} not in sequence") + if su2 not in self.sequence: + raise ValueError(f"su2: {su2} not in sequence") + + # 2) all the nodes in both SU are on the same device + if su1 == su2 or su1.stype != su2.stype: + return None + if set(su1.device) != set(su2.device): + return None + + index_su1 = self.sequence.index(su1) + index_su2 = self.sequence.index(su2) + su1, su2 = (su1, su2) if index_su1 < index_su2 else (su2, su1) + # 3) deadlock-free merge + index_su1, index_su2 = min(index_su1, index_su2), max(index_su1, index_su2) + inter_sus = self.sequence[index_su1+1:index_su2] + for su in inter_sus: + if self.happen_before(su1, su) and self.happen_before(su, su2): + return None + + # merge forward su + sub_nodes = su1.nodes() + su2.nodes() + merged_su = ScheduleUnit(sub_nodes, su1.stype) + + # merge mirrored su + # mirror_su2 -> mirror_su1 + mirror_su1, mirror_su2 = su1.mirror, su2.mirror + if mirror_su1 is not None and mirror_su2 is not None: + sub_nodes = mirror_su2.nodes() + mirror_su1.nodes() + merged_mirror_su = ScheduleUnit(sub_nodes, mirror_su1.stype) + # set mirror + merged_su.set_mirror(merged_mirror_su) + merged_mirror_su.set_mirror(merged_su) + elif mirror_su1 is None and mirror_su2 is None: + merged_mirror_su = None + else: + raise RuntimeError( + "The merged su should be both have mirror or both not have." + ) + + # replace + self.sequence[index_su1] = merged_su + self.sequence.remove(su2) + if mirror_su1 in self.sequence and mirror_su2 in self.sequence: + index_mirror_su2 = self.sequence.index(mirror_su2) + self.sequence[index_mirror_su2] = merged_mirror_su + self.sequence.remove(mirror_su1) + + # TODO: optimize: reset dependency + self.reset_dependency() + return merged_su + + def add_flow(self, su1, su2): + """ + Add control flow dependency su1 -> su2 + """ + if not isinstance(su1, ScheduleUnit) or not isinstance(su2, ScheduleUnit): + raise TypeError("Expected both SU1 and SU2 are ScheduleUnit") + su1.add_successors(-1, su2) + su2.add_predecessors(-1, su1) + + def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): + """ + Assign SU to devices. + + The assignment will automatically trigger the generation of + Adapter SU. + + 1) if ranks has multiple int, then the su is copied as the same + SU will be happened redundantly on multiple devices. + + 2) if the input tensor this su is decided to be generated on + other devices, then Adapter SUs (send SU and recv SU) will + be generated and inserted right before this SU. + """ + if su not in self.sequence: + raise ValueError(f"SU {su} is not in the SUGraph") + if isinstance(ranks, int): + ranks = [ranks] + elif not all([isinstance(int, rank) for rank in ranks]): + raise TypeError("Expected type ranks to be Union[int, List[int]]") + + if set(su.device) == set(ranks): + return + + if len(ranks) != 1: + # copy su + sus = [copy.copy(su) for _ in range(len(ranks)-1)] + sus = [self] + sus + for su in ranks: + index = self.sus().index(su) + self.sequence.insert(su, index) + self.reset_dependency() + for su, rank in zip(sus, ranks): + self.assign(su, rank) + + # set device + su.device = ranks + + # set adapter device for the input + for idx in range(len(su.inputs())): + send_adapters, recv_adapters = su.in_adapters(idx) + for send_adapter in send_adapters: + send_adapter.nodes(0).send_ranks = [ranks[0],] + for recv_adapter in recv_adapters: + recv_adapter.device = ranks + + # set adapter device for the output + for idx in range(len(su.outputs())): + send_adapters, recv_adapters = su.out_adapters(idx) + for send_adapter in send_adapters: + send_adapter.device = ranks + for recv_adapter in recv_adapters: + recv_adapter.nodes(0).recv_ranks = [ranks[0],] + return True + + def set_order(self, seq: List[ScheduleUnit]): + if not all([isinstance(su, ScheduleUnit) for su in seq]): + raise ValueError("Expected a list of SUs") + if len(seq) != len(self.sequence): + raise ValueError("Expected seq length equal with Graph sus") + for su in seq: + if su not in self.sequence: + raise ValueError(f"Found SU {su} in seq but not in graph") + # correctness check + if not SUGraph.is_topo_order(seq, integrity_check=True): + raise ValueError("Cannot satisfy topological order") + self.sequence = seq + return True + + + @staticmethod + def is_topo_order(seq: List[ScheduleUnit], integrity_check=False): + """ + Check whether seq satisfies topological order. + + Args: + seq: List of ScheduleUnit + integrity_check: + If true, performs additional integrity check that requires + all the SUs in predecessor and successor of a SU should + appear in the sequence. + + Returns: + Boolean: True for satisfying topo order, otherwise False. + """ + + for index, su in enumerate(seq): + for pre_su in su.predecessors(): + # find the pre-su not appear in sequence + if integrity_check and not pre_su in seq: + return False + pre_idx = seq.index(pre_su) + # violate topological order + if pre_idx >= index: + return False + return True + + def __repr__(self): + dscp = f'ScheduleSeq (len={len(self)}):\n' + for node in self.sequence: + succ_node_ids = [None] * len(node.outputs()) + for out_idx in range(len(node.outputs())): + node_list = [snode._id for snode in node.successors(out_idx)] + succ_node_ids[out_idx] = node_list + dscp += f"\n{node._id}: {node} -> su id {succ_node_ids}\n" + return dscp + + +class SeqSpace: + + @staticmethod + def space_size(seq, device_num=1): + """ + Calculate legal + """ + + def _comb(n, m): + """ + Calcualte combination C(n,m): select n from m (n < m) + """ + res = 1 + for j in range(0, min(n, m)): + res *= (m-j) / (min(n, m) - j) + return int(res) + + raise NotImplementedError diff --git a/tests/schedule/test_pool.py b/tests/schedule/test_pool.py new file mode 100644 index 00000000..16bb4fb6 --- /dev/null +++ b/tests/schedule/test_pool.py @@ -0,0 +1,24 @@ +from cube.schedule.pool import SchedulePool +from cube.schedule.su import SUType, ScheduleUnit + +from cube.ir.cten import IRCell, IRTensor + + +def test_schedule_pool(): + + SchedulePool().clear() + assert len(SchedulePool()._sus) == 0 + assert len(SchedulePool().sus()) == 0 + + cell = IRCell( + name='test', signature='test', input_length=4, output_length=2 + ) + su = ScheduleUnit([cell], SUType.Forward, name='su') + SchedulePool().add_su(su) + + assert len(SchedulePool()._sus) == 1 + assert len(SchedulePool().sus()) == 1 + + for record_su in SchedulePool().sus(): + assert record_su == su + diff --git a/tests/schedule/test_su.py b/tests/schedule/test_su.py new file mode 100644 index 00000000..a7d498c2 --- /dev/null +++ b/tests/schedule/test_su.py @@ -0,0 +1,112 @@ +import copy + +from cube.graph.tensor import IRFullTensor +from cube.graph.operator import IROperation +from cube.graph.graph import IRGraph + +from cube.schedule.su import SUType, ScheduleUnit + + +def construct_model(): + + input = IRFullTensor(shape=[64,1024], name='data') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = IROperation( + name='linear1', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear1.set_input(0, input) + linear1.set_input(1, weight1) + linear1.set_input(2, bias1) + + # linear2 + linear2 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear2.set_input(0, linear1.outputs(0)) + linear2.set_input(1, weight2) + + # linear3 + linear3 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear3.set_input(0, linear2.outputs(0)) + linear3.set_input(1, weight3) + linear3.set_input(2, bias3) + + # return [input], [ops], [output] + return [input], [linear1, linear2, linear3], [linear3.outputs(0)] + + +def test_su_init(): + + inputs, nodes, outputs = construct_model() + graph = IRGraph(nodes, inputs, outputs, 'Test') + linear1, linear2, linear3 = nodes + + su1 = ScheduleUnit([linear1], stype=SUType.Forward) + assert len(su1.inputs()) == 1 + assert len(su1.outputs()) == 1 + assert su1.signature == SUType.Forward.value + + assert su1.mirror is None + assert su1.stype == SUType.Forward + assert su1._nodes == [linear1] + assert len(su1._send_in_adapters) == 1 + assert len(su1._recv_in_adapters) == 1 + assert len(su1._send_out_adapters) == 1 + assert len(su1._recv_out_adapters) == 1 + assert len(su1._ctrl_predecessors) == 0 + assert len(su1._ctrl_successors) == 0 + + su2 = ScheduleUnit([linear1, linear2], stype=SUType.Forward) + assert len(su2.inputs()) == 1 + assert len(su2.outputs()) == 1 + assert su2.signature == SUType.Forward.value + + su3 = ScheduleUnit([linear1, linear2, linear3], stype=SUType.Forward) + assert len(su3.inputs()) == 1 + assert len(su3.outputs()) == 1 + assert su3.signature == SUType.Forward.value + + +def test_su_copy(): + + inputs, nodes, outputs = construct_model() + graph = IRGraph(nodes, inputs, outputs, 'Test') + linear1, linear2, linear3 = nodes + + su1 = ScheduleUnit([linear1, linear2], stype=SUType.Forward) + assert len(su1.inputs()) == 1 + assert len(su1.outputs()) == 1 + assert su1.signature == SUType.Forward.value + + su2 = ScheduleUnit([linear1, linear2, linear3], stype=SUType.Forward) + assert len(su2.inputs()) == 1 + assert len(su2.outputs()) == 1 + assert su2.signature == SUType.Forward.value + + su1.set_mirror(su2) + + csu = copy.copy(su1) + assert csu.inputs() == su1.inputs() + assert csu.outputs() == su1.outputs() + + assert csu.mirror is not None + mirror = csu.mirror + assert mirror.inputs() == su2.inputs() + assert mirror.outputs() == su2.outputs() diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py new file mode 100644 index 00000000..f7fb1d07 --- /dev/null +++ b/tests/schedule/test_sugraph.py @@ -0,0 +1,118 @@ +from cube.graph.tensor import IRFullTensor +from cube.graph.operator import IROperation +from cube.graph.graph import IRGraph + +from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.sugraph import SUGraph + + + +def construct_graph(): + + input = IRFullTensor(shape=[64,1024], name='data') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = IROperation( + name='linear1', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear1.set_input(0, input) + linear1.set_input(1, weight1) + linear1.set_input(2, bias1) + + # linear2 + linear2 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear2.set_input(0, linear1.outputs(0)) + linear2.set_input(1, weight2) + + # linear3 + linear3 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear3.set_input(0, linear2.outputs(0)) + linear3.set_input(1, weight3) + linear3.set_input(2, bias3) + + graph = IRGraph( + nodes=[linear1, linear2, linear3], + input_tensors=[input], + output_tensors=linear3.outputs(), + module_name="Test" + ) + return graph + + +def test_graph_init(): + + graph = construct_graph() + sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] + + sugraph = SUGraph(sus) + assert len(sugraph.inputs()) == 1 + assert len(sugraph.outputs()) == 1 + assert graph.inputs() == sugraph.inputs() + assert graph.outputs() == sugraph.outputs() + + assert sugraph.sequence == sus + + # test dependency + su1, su2, su3 = sus + assert su2 in su1.successors() + assert su3 in su2.successors() + assert su3 not in su1.successors() + assert su1 in su2.predecessors() + assert su1 in su2.predecessors(0) + assert su2 in su3.predecessors() + assert su1 not in su3.predecessors() + + +def test_sugraph_happen_before(): + + graph = construct_graph() + sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] + + sugraph = SUGraph(sus) + su1, su2, su3 = sugraph.sus() + + assert sugraph.happen_before(su1, su2) + assert not sugraph.happen_before(su2, su1) + assert sugraph.happen_before(su1, su3) + assert not sugraph.happen_before(su3, su1) + assert sugraph.happen_before(su2, su3) + assert not sugraph.happen_before(su3, su2) + + +def test_sugraph_merge(): + + graph = construct_graph() + sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] + + sugraph = SUGraph(sus) + su1, su2, su3 = sugraph.sus() + + assert sugraph.merge(su1, su3) is None + + su12 = sugraph.merge(su1, su2) + assert sugraph.nnodes == 2 + assert len(su12.inputs()) == 1 + assert len(su12.outputs()) == 1 + assert len(su12.nodes()) == 2 + assert su12 in sugraph.sus() + assert su1 not in sugraph.sus() + assert su2 not in sugraph.sus() + assert sugraph.happen_before(su12, su3) From e17fbe4d4ada338d494d87522361ddf5646772c7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Oct 2021 23:32:24 +0800 Subject: [PATCH 0216/1892] add test for add_flow --- cube/ir/cten.py | 41 +++++++++----------- cube/schedule/action.py | 70 ---------------------------------- cube/schedule/checker.py | 32 ---------------- cube/schedule/sugraph.py | 13 +++++-- tests/schedule/test_sugraph.py | 18 +++++++++ 5 files changed, 45 insertions(+), 129 deletions(-) delete mode 100644 cube/schedule/action.py delete mode 100644 cube/schedule/checker.py diff --git a/cube/ir/cten.py b/cube/ir/cten.py index e20d13a2..60a97454 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -322,6 +322,7 @@ def __init__(self, shape=None, name=None): self._cell: List[IRCell] = list() # forward graph + self._is_param = False self.requires_grad = True self.trace = None @@ -353,6 +354,19 @@ def set_trace(self, sus: List): raise TypeError("Expected List[ScheduleUnit]") self.trace = sus + def as_param(self): + """ + Set the tensor as trainable parameter + """ + self.requires_grad = True + self._is_param = True + + def is_param(self): + """ + Check if the tensor is parameter + """ + return self._is_param + def renew(self): """ Renew a new tensor with same name and shape, @@ -385,22 +399,6 @@ def __copy__(self): tensor._cell = list() return tensor - def __deepcopy__(self, memo): - """ - Deep Copy will copy the exactly same tensor with same tensor id - """ - tensor = IRTensor(self._shape, self.name) - for key in self.__dict__: - val = getattr(self, key) - if isinstance(val, IRTensor): - pass - if isinstance(val, list) and all([isinstance(v, IRTensor) for v in val]): - pass - else: - val = copy.copy(val) - setattr(tensor, key, val) - return tensor - def __eq__(self, tensor): if not isinstance(tensor, IRTensor): return False @@ -464,15 +462,10 @@ def is_leaf(self, cells: List[IRCell]): def backward(self): """ - Backward will generate a backward action scheduling pool - - Construct a reverse graph of forward and seperate to actions + Autograd backward on the tensor """ - if self.trace is None: - return - from cube.tschedule.pool import TSchedulePool - for fsu in self.trace[::-1]: - TSchedulePool().add_su(fsu.mirror) + from cube.schedule.translator import LogicTranslator + return LogicTranslator.backward(self) def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device})' diff --git a/cube/schedule/action.py b/cube/schedule/action.py deleted file mode 100644 index 1efed1fc..00000000 --- a/cube/schedule/action.py +++ /dev/null @@ -1,70 +0,0 @@ - - -class Action: - - def __init__(self, fn): - """ - fn: a function call to perform a set of operators - """ - self._fn = [fn,] - self.pre_actions = list() - self.outputs = None - self.name = 'None' - self.fid = None # flow id - self.device = -1 - self.est_latency = 1 - self.est_memory = 1 - - def __call__(self, *args, **kwargs): - """ - Execute the action - """ - outputs = self.get_input() - outputs = self._fn[0](outputs, *args, **kwargs) - self.outputs = outputs - - def get_input(self): - """ - Get input for the flow-ins from pre_actions - """ - raise NotImplementedError - - def add_pre_action(self, action): - self.pre_actions.append(action) - self.fid = action.fid - - def depends_on(self, action): - """ - check if the self -> action - - Note: this may return false negative as it will only check - 1-hop dependency - """ - if not isinstance(action, Action): - raise TypeError("Expected action to be an Action") - return action in self.pre_actions - - def tag(self, name): - """ - tag a string to indicate this action (as name) - """ - self.name = name - - def __repr__(self): - return self.name+'@{}'.format(self.device) - - -def add_flow(action1, action2): - """ - Add happened before dependency action1 -> action2 - - Args: - action1 (Action) - action2 (Action) - """ - if not isinstance(action1, Action): - raise TypeError("Expected action1 to be an Action") - if not isinstance(action2, Action): - raise TypeError("Expected action2 to be an Anction") - if not action2.depends_on(action1): - action2.add_pre_action(action1) diff --git a/cube/schedule/checker.py b/cube/schedule/checker.py deleted file mode 100644 index 51b95ff5..00000000 --- a/cube/schedule/checker.py +++ /dev/null @@ -1,32 +0,0 @@ -from cube.schedule.action import Action - - -def correct_check(sequence, actions, relations): - """ - Check if sequence satisfies the sequential consistency model - Args: - sequence (list[Actions]): action sequence - actions (list[Action]): action lists - relations (list(tuple(Action, Action))): - contains happened before tuple list - Returns: - Boolean: whether satisfies the partial order specified in relations - """ - if not all([isinstance(action, Action) for action in sequence]): - raise TypeError("Expected the sequence to be list[Action]") - if not all([isinstance(action, Action) for action in actions]): - raise TypeError("Expected the actions to be list[Action]") - - # check if all Actions in `actions` are used by sequence - if set(sequence) != set(actions): - return False - - # check partial order - for (action1, action2) in relations: - act1_idx = sequence.index(action1) - act2_idx = sequence.index(action2) - if act1_idx >= act2_idx: - return False - - # check passed - return True diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index d9491f23..44c6a9ef 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -177,14 +177,21 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: self.reset_dependency() return merged_su - def add_flow(self, su1, su2): + def add_flow(self, su1: ScheduleUnit, su2: ScheduleUnit): """ Add control flow dependency su1 -> su2 """ if not isinstance(su1, ScheduleUnit) or not isinstance(su2, ScheduleUnit): raise TypeError("Expected both SU1 and SU2 are ScheduleUnit") - su1.add_successors(-1, su2) - su2.add_predecessors(-1, su1) + if su1 not in self.sequence: + raise ValueError(f"su1 {su1} not in SUGraph") + if su2 not in self.sequence: + raise ValueError(f"su1 {su2} not in SUGraph") + if self.happen_before(su2, su1): + return False + su1.add_successor(-1, su2) + su2.add_predecessor(-1, su1) + return True def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): """ diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index f7fb1d07..6c09bec9 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -116,3 +116,21 @@ def test_sugraph_merge(): assert su1 not in sugraph.sus() assert su2 not in sugraph.sus() assert sugraph.happen_before(su12, su3) + + +def test_sugraph_add_flow(): + + graph = construct_graph() + sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] + + sugraph = SUGraph(sus) + su1, su2, su3 = sugraph.sus() + + assert su1 not in su3.predecessors() + assert su3 not in su1.successors() + + assert not sugraph.add_flow(su3, su1) + + assert sugraph.add_flow(su1, su3) + assert su1 in su3.predecessors() + assert su3 in su1.successors() From b13510548988ed12b6be2ad332dad447f6ad9398 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 19 Oct 2021 09:38:01 +0800 Subject: [PATCH 0217/1892] add test for assign and set_order --- cube/schedule/su.py | 1 + cube/schedule/sugraph.py | 20 ++++-- tests/schedule/test_sugraph.py | 116 +++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 4 deletions(-) diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 1fa51b9c..b00b42f6 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -85,6 +85,7 @@ def __copy__(self): Copy the SU. Note the mirror su is also copied """ su = ScheduleUnit(self._nodes, self.stype, self.name) + #TODO: adapter copy if self.mirror is not None: mirror_su = self.mirror mirror_su = ScheduleUnit( diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 44c6a9ef..7b275faa 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -214,11 +214,16 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): elif not all([isinstance(int, rank) for rank in ranks]): raise TypeError("Expected type ranks to be Union[int, List[int]]") + if su.stype == SUType.Adapter: + return False + if set(su.device) == set(ranks): - return + return True if len(ranks) != 1: # copy su + # TODO: adatper copy + print('warning: Missing adapter copy!!') sus = [copy.copy(su) for _ in range(len(ranks)-1)] sus = [self] + sus for su in ranks: @@ -249,16 +254,23 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): return True def set_order(self, seq: List[ScheduleUnit]): + """ + set a topological order for SUGraph, which requires seq: + + 1). The set of SUs in seq must be equal to set of SUGraph + 2). Staisfies topological order + + """ if not all([isinstance(su, ScheduleUnit) for su in seq]): raise ValueError("Expected a list of SUs") if len(seq) != len(self.sequence): - raise ValueError("Expected seq length equal with Graph sus") + return False for su in seq: if su not in self.sequence: - raise ValueError(f"Found SU {su} in seq but not in graph") + return False # correctness check if not SUGraph.is_topo_order(seq, integrity_check=True): - raise ValueError("Cannot satisfy topological order") + return False self.sequence = seq return True diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index 6c09bec9..f498de90 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -1,4 +1,5 @@ from cube.graph.tensor import IRFullTensor +from cube.graph.comm import IRCommunication from cube.graph.operator import IROperation from cube.graph.graph import IRGraph @@ -134,3 +135,118 @@ def test_sugraph_add_flow(): assert sugraph.add_flow(su1, su3) assert su1 in su3.predecessors() assert su3 in su1.successors() + + +def test_sugraph_assign(): + + graph = construct_graph() + sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] + + su1, su2, su3 = sus + + # adapter between su1-su2 + send_op = IRCommunication( + send_tensors=[su1.outputs(0)], + send_ranks = [-1] + ) + recv_op = IRCommunication( + recv_tensors=[su1.outputs(0)], + recv_ranks = [-1] + ) + send_op.pair(recv_op) + send_su12 = ScheduleUnit([send_op], SUType.Adapter, name='send') + recv_su12 = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + su1._add_out_adapter(0, send_su12, recv_su12) + su2._add_in_adapter(0, send_su12, recv_su12) + + # adapter between su2-su3 + send_op = IRCommunication( + send_tensors=[su1.outputs(0)], + send_ranks = [-1] + ) + recv_op = IRCommunication( + recv_tensors=[su1.outputs(0)], + recv_ranks = [-1] + ) + send_op.pair(recv_op) + send_su23 = ScheduleUnit([send_op], SUType.Adapter, name='send') + recv_su23 = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + su2._add_out_adapter(0, send_su23, recv_su23) + su3._add_in_adapter(0, send_su23, recv_su23) + + sugraph = SUGraph( + [su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3] + ) + + assert sugraph.assign(su1, 0) + assert su1.device == [0] + assert send_su12.device == [0] + assert send_su12.nodes(0).send_ranks == [-1] + assert recv_su12.device == [] + assert recv_su12.nodes(0).recv_ranks == [0] + + assert sugraph.assign(su2, 1) + assert su1.device == [0] + assert send_su12.device == [0] + assert send_su12.nodes(0).send_ranks == [1] + assert recv_su12.device == [1] + assert recv_su12.nodes(0).recv_ranks == [0] + + assert sugraph.assign(su3, 1) + assert su3.device == [1] + assert send_su23.device == [1] + assert send_su23.nodes(0).send_ranks == [1] + assert recv_su23.device == [1] + assert recv_su23.nodes(0).recv_ranks == [1] + + assert not sugraph.assign(send_su12, 3) + + +def test_sugraph_assign(): + + graph = construct_graph() + sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] + + su1, su2, su3 = sus + + # adapter between su1-su2 + send_op = IRCommunication( + send_tensors=[su1.outputs(0)], + send_ranks = [-1] + ) + recv_op = IRCommunication( + recv_tensors=[su1.outputs(0)], + recv_ranks = [-1] + ) + send_op.pair(recv_op) + send_su12 = ScheduleUnit([send_op], SUType.Adapter, name='send') + recv_su12 = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + su1._add_out_adapter(0, send_su12, recv_su12) + su2._add_in_adapter(0, send_su12, recv_su12) + + # adapter between su2-su3 + send_op = IRCommunication( + send_tensors=[su1.outputs(0)], + send_ranks = [-1] + ) + recv_op = IRCommunication( + recv_tensors=[su1.outputs(0)], + recv_ranks = [-1] + ) + send_op.pair(recv_op) + send_su23 = ScheduleUnit([send_op], SUType.Adapter, name='send') + recv_su23 = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + su2._add_out_adapter(0, send_su23, recv_su23) + su3._add_in_adapter(0, send_su23, recv_su23) + + sugraph = SUGraph( + [su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3] + ) + + assert not sugraph.set_order( + [su2, send_su12, recv_su12, su1, send_su23, recv_su23, su3] + ) + + assert sugraph.set_order( + [su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3] + ) From 8e0c9e6d610f204c205c94a60f65c5f20e56993f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 19 Oct 2021 10:48:56 +0800 Subject: [PATCH 0218/1892] logic translator --- cube/schedule/su.py | 17 ++++ cube/schedule/translator.py | 156 +++++++++++++++++++++++++++++++++ tests/graph/test_parser.py | 2 + tests/schedule/test_sugraph.py | 1 - 4 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 cube/schedule/translator.py diff --git a/cube/schedule/su.py b/cube/schedule/su.py index b00b42f6..25f7313d 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -158,6 +158,23 @@ def out_adapters(self, index: Optional[int] = None) -> Tuple[List, List]: else: raise TypeError("Expected index to be None or int") + def _clear_adapters(self): + """ + Clear all adapters for this SU + """ + self._send_in_adapters: List[List[ScheduleUnit]] = [ + list() for _ in range(len(self.inputs())) + ] + self._recv_in_adapters: List[List[ScheduleUnit]] = [ + list() for _ in range(len(self.inputs())) + ] + self._send_out_adapters: List[List[ScheduleUnit]] = [ + list() for _ in range(len(self.outputs())) + ] + self._recv_out_adapters: List[List[ScheduleUnit]] = [ + list() for _ in range(len(self.outputs())) + ] + def _add_in_adapter(self, index: int, send_adapter, recv_adapter): """ Add adapters to the input tensor of this SU diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py new file mode 100644 index 00000000..aa6a981c --- /dev/null +++ b/cube/schedule/translator.py @@ -0,0 +1,156 @@ +""" +Traning Logic Translator + +The traning logic first translate the training logic into +Schedule Units, and then add Adapter ScheduleUnit +""" +from typing import List +import torch + +from cube.ir.cten import IRCell, IRTensor +from cube.graph.tensor import IRFullTensor +from cube.graph.comm import IRCommunication +from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.pool import SchedulePool +from cube.schedule.sugraph import SUGraph + + +class IRDataLoader: + + def __init__(self, dataloader): + self.dataloader = iter(dataloader) + + def __iter__(self): + return self + + def __next__(self): + return LogicTranslator.load_data(self) + + +class LogicTranslator: + + @staticmethod + def load_data(dataloader: IRDataLoader): + """ + Translator Action: Load data from data loaderw + """ + datas = next(dataloader.dataloader) + if not isinstance(datas, tuple): + datas = (datas,) + + # data IRTensor + outputs = list() + for data in datas: + if torch.is_tensor(data): + data = IRFullTensor(shape=list(data.shape), name='data') + data.requires_grad = False + outputs.append(data) + + cell = IRCell( + name='dataloader', + signature='dataloader.__next__', + input_length=0, + output_length=len(datas) + ) + for idx, output in enumerate(outputs): + cell.set_output(idx, output) + + su = ScheduleUnit([cell], stype=SUType.Dataloader, name='DataLoader') + SchedulePool().add_su(su) + + if len(outputs) == 0: return + elif len(outputs) == 1: return outputs[0] + else: return tuple(outputs) + + @staticmethod + def forward(graph, *args): + """ + Translator Action: forward an IRGraph + """ + + def _forward(graph, stype, *args): + # set input + for input, arg in zip(graph.inputs(), args): + graph._replace_tensor(input, arg) + # translate to SUs + sus = list() + for node in graph.nodes(): + su = ScheduleUnit([node], stype, name=str(stype)) + sus.append(su) + return sus + + # forward graph + fgraph = graph.copy(reverse=False) + # backward graph + bgraph = graph.copy(reverse=True) + bgraph.tag = 'backward' + + # translate forward graph + fsus = _forward(fgraph, SUType.Forward, *args) + bsus = _forward(bgraph, SUType.Backward, *(fgraph.outputs())) + for fsu, bsu in zip(fsus, bsus[::-1]): + fsu.set_mirror(bsu) + bsu.set_mirror(fsu) + SchedulePool().add_su(fsu) + + for output in fgraph.outputs(): + output.set_trace(fsus) + + outputs = fgraph.outputs() + if len(outputs) == 1: return outputs[0] + elif len(outputs) == 0: return None + else: return outputs + + @staticmethod + def backward(tensor: IRTensor): + """ + Translator Action: backward a tensor + """ + if tensor.trace is None: + return + for fsu in tensor.trace[::-1]: + SchedulePool().add_su(fsu.mirror) + + @staticmethod + def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: + """ + Each computation SU has adapters for its inputs + """ + sugraph = SUGraph(sus) + + # clear adapters + for su in sugraph.sus(): + su._clear_adapters() + + for su in sugraph.sus(): + for in_idx, input in enumerate(su.inputs()): + if not isinstance(input, IRTensor): + continue + pre_sus = su.predecessors(in_idx) + for pre_su in pre_sus: + for out_idx, output in enumerate(pre_su.outputs()): + if output.overlap(input): + sub_tensor = output.common(input) + send_op = IRCommunication( + send_tensors=[sub_tensor], + send_ranks = [-1] + ) + recv_op = IRCommunication( + recv_tensors=[sub_tensor], + recv_ranks = [-1] + ) + send_op.pair(recv_op) + send_su = ScheduleUnit([send_op], SUType.Adapter, name='send') + recv_su = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + su._add_in_adapter(in_idx, send_su, recv_su) + pre_su._add_out_adapter(out_idx, send_su, recv_su) + + sus_with_adapter = list() + for su in sus: + for idx in range(len(su.inputs())): + send_adapters, recv_adapters = su.in_adapters(idx) + for send_su, recv_su in zip(send_adapters, recv_adapters): + sus_with_adapter.append(send_su) + sus_with_adapter.append(recv_su) + sus_with_adapter.append(su) + return sus_with_adapter diff --git a/tests/graph/test_parser.py b/tests/graph/test_parser.py index 4bcce9be..c36bee0e 100644 --- a/tests/graph/test_parser.py +++ b/tests/graph/test_parser.py @@ -56,3 +56,5 @@ def test_parse_module(): assert node3.successors() == [node4] assert node4.successors() == [node5] assert node5.successors() == [node6] + + assert graph.outputs(0).shape == [1024, 1000] diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index f498de90..5bb66fca 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -7,7 +7,6 @@ from cube.schedule.sugraph import SUGraph - def construct_graph(): input = IRFullTensor(shape=[64,1024], name='data') From 05bec95d03b7925633afd6f8fcd6ee12ae04d46b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 19 Oct 2021 12:36:40 +0800 Subject: [PATCH 0219/1892] fix workflow bugs --- cube/schedule/__init__.py | 33 ++++-- tests/schedule/test_translator.py | 168 ++++++++++++++++++++++++++++++ tests/schedule/test_worflow.py | 103 ++++++++++++++++++ 3 files changed, 297 insertions(+), 7 deletions(-) create mode 100644 tests/schedule/test_translator.py create mode 100644 tests/schedule/test_worflow.py diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index 2f73e064..d240cf51 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -4,7 +4,7 @@ from cube.schedule.pool import SchedulePool from cube.schedule.translator import IRDataLoader, LogicTranslator from cube.schedule.sugraph import SUGraph -from cube.codegen.codegen import TScheduleCodeGen +from cube.codegen.codegen import SScheduleCodeGen, TScheduleCodeGen class SemanticModel: @@ -70,8 +70,13 @@ def train_step(model, dataloader): ir_graph = model.get_graph() ir_dataloader = IRDataLoader(dataloader) - myrank = torch.distributed.get_rank() + if torch.distributed.is_initialized(): + # multiple device + myrank = torch.distributed.get_rank() + else: + # single device + myrank = 0 def _load_tschedule_fn(filename) -> Callable: import importlib.util @@ -99,22 +104,36 @@ def decorator(fn: Callable) -> Callable: # policy su_graph = SUGraph(sus_with_adapter) if policy_fn: - seq = policy_fn(su_graph) + # TODO: add resource + su_graph = policy_fn(su_graph, None) + + # check assignment and order + for su in su_graph.sus(): + if len(su.device) == 0: + raise RuntimeError(f"SU {su} device is not set") + if not SUGraph.is_topo_order(su_graph.sus()): + raise RuntimeError(f"SUGraph order is not topological order") + + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 # code generation - world_size = torch.distributed.get_world_size() - tgener = TScheduleCodeGen(seq) + tgener = TScheduleCodeGen(su_graph) + sgener = SScheduleCodeGen(su_graph) for rank in range(world_size): fname = filename.format(rank) # generate spatial module code - model.gen_module(seq, rank, fname, attach=False) + sgener.gen(rank, outfile=fname, attach=True) # generate temporal schedule code tgener.gen( device = rank, outfile = fname, attach=True ) - torch.distributed.barrier() + if torch.distributed.is_initialized(): + torch.distributed.barrier() # load module model.load_module(filename.format(myrank)) # load temporal diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py new file mode 100644 index 00000000..ec529653 --- /dev/null +++ b/tests/schedule/test_translator.py @@ -0,0 +1,168 @@ +import torch + +from cube.schedule.translator import LogicTranslator +from cube.schedule.translator import IRDataLoader +from cube.schedule.su import SUType +from cube.schedule.pool import SchedulePool + +from cube.graph.tensor import IRFullTensor +from cube.graph.operator import IROperation +from cube.graph.graph import IRGraph + + +class FakeDataLoader: + def __init__(self, batch_size, num=640): + self.batch_size = batch_size + self.length = num + self.pos = 0 + def __iter__(self): + self.pos = 0 + return self + def __next__(self): + self.pos += 1 + if self.pos == self.length: + raise StopIteration + return torch.randn((self.batch_size, 1024)) + + +def construct_graph(): + + input = IRFullTensor(shape=[64,1024], name='data') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = IROperation( + name='linear1', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear1.set_input(0, input) + linear1.set_input(1, weight1) + linear1.set_input(2, bias1) + linear1.infer_shape() + + # linear2 + linear2 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear2.set_input(0, linear1.outputs(0)) + linear2.set_input(1, weight2) + linear2.infer_shape() + + # linear3 + linear3 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear3.set_input(0, linear2.outputs(0)) + linear3.set_input(1, weight3) + linear3.set_input(2, bias3) + linear3.infer_shape() + + graph = IRGraph( + nodes=[linear1, linear2, linear3], + input_tensors=[input], + output_tensors=linear3.outputs(), + module_name="Test" + ) + return graph + + +def test_load_dataloader(): + + SchedulePool().clear() + dataloader = IRDataLoader(FakeDataLoader(batch_size=64)) + + data1 = next(dataloader) + assert isinstance(data1, IRFullTensor) + assert data1.shape == [64, 1024] + + data2 = next(dataloader) + assert len(SchedulePool().sus()) == 2 + assert all([su.stype == SUType.Dataloader for su in SchedulePool().sus()]) + + data3 = LogicTranslator.load_data(dataloader) + assert isinstance(data1, IRFullTensor) + assert data1.shape == [64, 1024] + assert len(SchedulePool().sus()) == 3 + assert all([su.stype == SUType.Dataloader for su in SchedulePool().sus()]) + + +def test_translator_forward(): + SchedulePool().clear() + + graph = construct_graph() + print(graph) + data = IRFullTensor(shape=[64,1024], name='data') + output = graph(data) + + assert isinstance(output, IRFullTensor) + assert output.shape == [64, 1024] + assert output.trace is not None + + sus = SchedulePool().sus() + assert len(sus) == 3 + assert output.trace == sus + for su in sus: + assert su.stype == SUType.Forward + assert su.mirror is not None + + +def test_translator_backward(): + SchedulePool().clear() + + graph = construct_graph() + data = IRFullTensor(shape=[64,1024], name='data') + output = graph(data) + + output.backward() + + sus = SchedulePool().sus() + assert len(sus) == 6 + fsus = sus[0:3] + bsus = sus[3:] + for fsu, bsu in zip(fsus, bsus[::-1]): + assert fsu.mirror == bsu + assert bsu.mirror == fsu + assert bsu.stype == SUType.Backward + + +def test_translatro_gen_adapter(): + SchedulePool().clear() + + graph = construct_graph() + data = IRFullTensor(shape=[64,1024], name='data') + output = graph(data) + + # forward adatpers + sus = SchedulePool().sus() + sus = LogicTranslator.gen_adapter(sus) + assert len(sus) == 7 + su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3 = sus + assert su1.stype == SUType.Forward + assert su2.stype == SUType.Forward + assert su3.stype == SUType.Forward + assert send_su12.stype == SUType.Adapter + assert recv_su12.stype == SUType.Adapter + assert send_su23.stype == SUType.Adapter + assert recv_su23.stype == SUType.Adapter + + # backward adapters + output.backward() + sus = SchedulePool().sus() + sus = LogicTranslator.gen_adapter(sus) + for su in sus: + print(su) + # note loss will be the input to autograd, therefore + # have additional adapters + assert len(sus) == 16 diff --git a/tests/schedule/test_worflow.py b/tests/schedule/test_worflow.py new file mode 100644 index 00000000..b6e6e804 --- /dev/null +++ b/tests/schedule/test_worflow.py @@ -0,0 +1,103 @@ +import torch +from torch import nn + +import cube +from cube.graph.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.pool import SchedulePool +from cube.schedule.sugraph import SUGraph +from cube.schedule.translator import LogicTranslator, IRDataLoader + + +class MLP(nn.Module): + def __init__(self, dim, mult=16): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult, bias=False) + self.linear2 = nn.Linear(dim * mult, dim) + self.linear3 = nn.Linear(dim, dim * mult, bias=False) + self.linear4 = nn.Linear(dim * mult, dim) + + def forward(self, data): + output = self.linear1(data) + output = self.linear2(output) + output = self.linear3(output) + output = self.linear4(output) + return output + + +class FakeDataLoader: + def __init__(self, shape, num=640): + self.shape = shape + self.length = num + self.pos = 0 + def __iter__(self): + self.pos = 0 + return self + def __next__(self): + self.pos += 1 + if self.pos == self.length: + raise StopIteration + return torch.randn(self.shape).cuda() + + +def test_semantic_model(): + dim = 1024 + model = MLP(dim=dim) + model = cube.schedule.SemanticModel( + model, + input_shapes=([64, dim],) + ) + assert isinstance(model.ir_graph, IRGraph) + assert model._loaded_module is None + + +def test_schedule(): + + SchedulePool().clear() + + dim = 1024 + batch_size = 64 + + model = MLP(dim=dim) + model = cube.schedule.SemanticModel( + model, + input_shapes=([batch_size, dim],) + ) + + dataloader = FakeDataLoader((batch_size, dim)) + dataloader = IRDataLoader(dataloader) + + def policy(sugraph, resources): + # dataloader + sugraph.assign(sugraph.sus(0), 0) + + fsus = [su for su in sugraph.sus() if su.stype == SUType.Forward] + for idx, fsu in enumerate(fsus): + bsu = fsu.mirror + if idx < 2: + sugraph.assign(fsu, 0) + sugraph.assign(bsu, 0) + else: + sugraph.assign(fsu, 1) + sugraph.assign(bsu, 1) + return sugraph + + def train_iter(model, dataloader): + num_micro_batch = 1 + for _ in range(num_micro_batch): + data = next(dataloader) + output = model(data) + output.backward() + + train_iter(model, dataloader) + + sus = SchedulePool().sus() + sus_with_adapter = LogicTranslator.gen_adapter(sus) + sugraph = SUGraph(sus_with_adapter) + + sugraph = policy(sugraph, None) + + for su in sugraph.sus(): + print(su) + + assert len(sugraph.sus()) == 1 + 2 * (4 * 3) From 82cfe74ebe003a38cdd1c25c09c3b535731deca2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 19 Oct 2021 20:03:00 +0800 Subject: [PATCH 0220/1892] graph pass for remove and merge SUs --- cube/__init__.py | 1 + cube/schedule/__init__.py | 6 ++ cube/schedule/graphpass.py | 50 +++++++++++++ cube/schedule/sugraph.py | 52 ++++++++------ tests/schedule/test_graphpass.py | 117 +++++++++++++++++++++++++++++++ 5 files changed, 203 insertions(+), 23 deletions(-) create mode 100644 cube/schedule/graphpass.py create mode 100644 tests/schedule/test_graphpass.py diff --git a/cube/__init__.py b/cube/__init__.py index 491513f6..3b647e70 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,3 +1,4 @@ from cube.device.physic.group import DeviceGroup +from cube import schedule from cube import runtime diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index d240cf51..ad4286f0 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -4,6 +4,8 @@ from cube.schedule.pool import SchedulePool from cube.schedule.translator import IRDataLoader, LogicTranslator from cube.schedule.sugraph import SUGraph +from cube.schedule.graphpass import SUGraphPass + from cube.codegen.codegen import SScheduleCodeGen, TScheduleCodeGen @@ -114,6 +116,10 @@ def decorator(fn: Callable) -> Callable: if not SUGraph.is_topo_order(su_graph.sus()): raise RuntimeError(f"SUGraph order is not topological order") + # graph pass to remove redundant sus + su_graph = SUGraphPass.remove_redundant_adapters(su_graph) + su_graph = SUGraphPass.merge_small_sus(su_graph) + if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() else: diff --git a/cube/schedule/graphpass.py b/cube/schedule/graphpass.py new file mode 100644 index 00000000..cb9f7d06 --- /dev/null +++ b/cube/schedule/graphpass.py @@ -0,0 +1,50 @@ +from cube.schedule.sugraph import SUGraph +from cube.schedule.su import SUType, ScheduleUnit + + +class SUGraphPass: + + @staticmethod + def remove_redundant_adapters(sugraph: SUGraph) -> SUGraph: + """ + Remove redundant adapters + + A redundant adapter is sending and recving + on the same device + """ + redundant_adapters = list() + for su in sugraph.sus(): + if su.stype != SUType.Adapter: + for idx in range(len(su.outputs())): + send_adapters, recv_adapters = su.out_adapters(idx) + for sadapter, radapter in zip(send_adapters, recv_adapters): + # indicate a tensor selection in-device + if sadapter.device == radapter.device: + if len(sadapter.inputs()) != 1: + raise NotImplementedError + # indicate identity op: + if sadapter.inputs(0).shape == su.outputs(idx).shape: + redundant_adapters.append(sadapter) + redundant_adapters.append(radapter) + + all_sus = sugraph.sus() + for adapter in redundant_adapters: + if adapter in all_sus: + all_sus.remove(adapter) + + sugraph = SUGraph(all_sus) + return sugraph + + @staticmethod + def merge_small_sus(sugraph: SUGraph) -> SUGraph: + """ + Merge SU to a larger one if possible + """ + merged_su = None + for su in sugraph.sus(): + if su.stype == SUType.Forward: + if not isinstance(merged_su, ScheduleUnit): + merged_su = su + continue + merged_su = sugraph.merge(merged_su, su) + return sugraph diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 7b275faa..0abacb95 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -43,19 +43,19 @@ def reset_dependency(self): """ # set node predecessors and successors for src_idx in range(self.nnodes): - src_cell = self.sequence[src_idx] - src_cell._successors = [ - list() for _ in range(len(src_cell.outputs())) + src_su = self.sequence[src_idx] + src_su._successors = [ + list() for _ in range(len(src_su.outputs())) ] for dst_su in self.sequence[src_idx+1:]: dst_su._predecessors = [ list() for _ in range(len(dst_su.inputs())) ] - for out_idx, out_tensor in enumerate(src_cell.outputs()): + for out_idx, out_tensor in enumerate(src_su.outputs()): for in_idx, in_tensor in enumerate(dst_su.inputs()): if out_tensor.overlap(in_tensor): - src_cell.add_successor(out_idx, dst_su) - dst_su.add_predecessor(in_idx, src_cell) + src_su.add_successor(out_idx, dst_su) + dst_su.add_predecessor(in_idx, src_su) def __len__(self): return len(self.sequence) @@ -132,8 +132,12 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: # 2) all the nodes in both SU are on the same device if su1 == su2 or su1.stype != su2.stype: return None - if set(su1.device) != set(su2.device): + if su1.device != su2.device: return None + + #TODO: GraphPass on remove redundant adapter also need TODO + if su1.stype == SUType.Adapter: + raise NotImplementedError("Not supported for merging Adapter") index_su1 = self.sequence.index(su1) index_su2 = self.sequence.index(su2) @@ -148,19 +152,21 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: # merge forward su sub_nodes = su1.nodes() + su2.nodes() merged_su = ScheduleUnit(sub_nodes, su1.stype) + merged_su.device = su1.device # merge mirrored su # mirror_su2 -> mirror_su1 mirror_su1, mirror_su2 = su1.mirror, su2.mirror - if mirror_su1 is not None and mirror_su2 is not None: - sub_nodes = mirror_su2.nodes() + mirror_su1.nodes() - merged_mirror_su = ScheduleUnit(sub_nodes, mirror_su1.stype) - # set mirror - merged_su.set_mirror(merged_mirror_su) - merged_mirror_su.set_mirror(merged_su) - elif mirror_su1 is None and mirror_su2 is None: - merged_mirror_su = None - else: + merged_mirror_su = None + if mirror_su1 and mirror_su2: + if mirror_su1.device == mirror_su2.device: + sub_nodes = mirror_su2.nodes() + mirror_su1.nodes() + merged_mirror_su = ScheduleUnit(sub_nodes, mirror_su1.stype) + merged_mirror_su.device = mirror_su1.device + # set mirror + merged_su.set_mirror(merged_mirror_su) + merged_mirror_su.set_mirror(merged_su) + elif mirror_su1 or mirror_su2: raise RuntimeError( "The merged su should be both have mirror or both not have." ) @@ -168,10 +174,11 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: # replace self.sequence[index_su1] = merged_su self.sequence.remove(su2) - if mirror_su1 in self.sequence and mirror_su2 in self.sequence: - index_mirror_su2 = self.sequence.index(mirror_su2) - self.sequence[index_mirror_su2] = merged_mirror_su - self.sequence.remove(mirror_su1) + if merged_mirror_su: + if mirror_su1 in self.sequence and mirror_su2 in self.sequence: + index_mirror_su2 = self.sequence.index(mirror_su2) + self.sequence[index_mirror_su2] = merged_mirror_su + self.sequence.remove(mirror_su1) # TODO: optimize: reset dependency self.reset_dependency() @@ -225,10 +232,9 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): # TODO: adatper copy print('warning: Missing adapter copy!!') sus = [copy.copy(su) for _ in range(len(ranks)-1)] - sus = [self] + sus - for su in ranks: + for su in sus: index = self.sus().index(su) - self.sequence.insert(su, index) + self.sequence.insert(index, su) self.reset_dependency() for su, rank in zip(sus, ranks): self.assign(su, rank) diff --git a/tests/schedule/test_graphpass.py b/tests/schedule/test_graphpass.py new file mode 100644 index 00000000..f27ce97a --- /dev/null +++ b/tests/schedule/test_graphpass.py @@ -0,0 +1,117 @@ +from cube.graph.tensor import IRFullTensor +from cube.graph.operator import IROperation +from cube.graph.graph import IRGraph + + +from cube.schedule.graphpass import SUGraphPass +from cube.schedule.pool import SchedulePool +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.schedule.translator import LogicTranslator + + +def construct_graph(): + + input = IRFullTensor(shape=[64,1024], name='data') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = IROperation( + name='linear1', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear1.set_input(0, input) + linear1.set_input(1, weight1) + linear1.set_input(2, bias1) + + # linear2 + linear2 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear2.set_input(0, linear1.outputs(0)) + linear2.set_input(1, weight2) + + # linear3 + linear3 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear3.set_input(0, linear2.outputs(0)) + linear3.set_input(1, weight3) + linear3.set_input(2, bias3) + + graph = IRGraph( + nodes=[linear1, linear2, linear3], + input_tensors=[input], + output_tensors=linear3.outputs(), + module_name="Test" + ) + return graph + + +def test_remove_adapter(): + + SchedulePool().clear() + + graph = construct_graph() + data = IRFullTensor(shape=[64,1024], name='data') + output = graph(data) + output.backward() + + # forward adatpers + sus = SchedulePool().sus() + sus = LogicTranslator.gen_adapter(sus) + + sugraph = SUGraph(sus) + for su in sugraph.sus(): + sugraph.assign(su, 0) + sugraph = SUGraphPass.remove_redundant_adapters(sugraph) + for su in sugraph.sus(): + print(su) + for su in sugraph.sus(): + assert su.stype != SUType.Adapter + assert len(sugraph.sus()) == 6 + + +def test_merge_small_sus(): + + SchedulePool().clear() + + graph = construct_graph() + data = IRFullTensor(shape=[64,1024], name='data') + output = graph(data) + output.backward() + + # forward adatpers + sus = SchedulePool().sus() + + sugraph = SUGraph(sus) + + for su in sugraph.sus(): + sugraph.assign(su, 0) + + print('orignal:') + for su in sugraph.sus(): + print(su) + + sugraph = SUGraphPass.merge_small_sus(sugraph) + + print('changed:') + for su in sugraph.sus(): + print(su) + + assert len(sugraph.sus()) == 2 + assert sugraph.sus(0).stype == SUType.Forward + assert sugraph.sus(1).stype == SUType.Backward + assert sugraph.sus(0).mirror == sugraph.sus(1) From 009cc3175e3fcf3ea8364407580f5b6aa8d7cf84 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 21 Oct 2021 10:05:17 +0800 Subject: [PATCH 0221/1892] move comm to schedule adapter --- cube/graph/graph.py | 19 +-- cube/schedule/adapter/__init__.py | 0 cube/{graph => schedule/adapter}/comm.py | 0 cube/schedule/adapter/select.py | 98 ++++++++++++++++ cube/schedule/execplan.py | 112 ++++++++++++++++++ cube/schedule/translator.py | 2 +- tests/codegen/test_codegen.py | 140 ++++++++++------------ tests/schedule/test_sugraph.py | 2 +- tests/tschedule/test_tschedule.py | 143 ----------------------- 9 files changed, 274 insertions(+), 242 deletions(-) create mode 100644 cube/schedule/adapter/__init__.py rename cube/{graph => schedule/adapter}/comm.py (100%) create mode 100644 cube/schedule/adapter/select.py create mode 100644 cube/schedule/execplan.py delete mode 100644 tests/tschedule/test_tschedule.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 9b751b0a..49a69df4 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -11,7 +11,6 @@ from cube.ir.cten import IRTensor, IRCell from cube.graph.operator import IROperation -from cube.graph.comm import IRCommunication import copy @@ -113,23 +112,7 @@ def _renew(val: Any): nodes = list() for node in self.nodes(): - if isinstance(node, IRCommunication): - send_tensors = [_renew(tensor) for tensor in node.inputs()] - send_ranks = node.send_ranks - recv_tensors = [_renew(tensor) for tensor in node.outputs()] - recv_ranks = node.recv_ranks - if reverse: - send_tensors, recv_tensors = recv_tensors, send_tensors - send_ranks, recv_ranks = recv_ranks, send_ranks - - new_node = IRCommunication( - send_tensors = send_tensors, - send_ranks = send_ranks, - recv_tensors = recv_tensors, - recv_ranks = recv_ranks - ) - - elif isinstance(node, IROperation): + if isinstance(node, IROperation): inputs = node.inputs() outputs = node.outputs() if reverse: diff --git a/cube/schedule/adapter/__init__.py b/cube/schedule/adapter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/graph/comm.py b/cube/schedule/adapter/comm.py similarity index 100% rename from cube/graph/comm.py rename to cube/schedule/adapter/comm.py diff --git a/cube/schedule/adapter/select.py b/cube/schedule/adapter/select.py new file mode 100644 index 00000000..433978ba --- /dev/null +++ b/cube/schedule/adapter/select.py @@ -0,0 +1,98 @@ +from typing import List +from enum import Enum +import numpy as np + +from cube.ir.cten import IRCell, IRTensor +from cube.graph.tensor import IRSubTensor, IRFullTensor + + +class IRReshapeType(Enum): + + Select = 'cube.runtime.adapter.select' + Merge = 'cube.runtime.adapter.merge' + + +class IRTensorReshape(IRCell): + """ + Tensor transformation by convert source tensors + to destination tensors + + Select: + src_tensors is only one tensor, dst_tensors has (multiple) tensors. + This will select the sub_tensor and generate what it need + + Merge: + src_tensors has (multiple) tensors, dst_tensors is only one tensor. + This will merge the sub_tensor and generate what it need + """ + def __init__(self, src_tensors: List[IRTensor], dst_tensors: List[IRTensor]): + + if len(src_tensors) != 1 and len(dst_tensors) != 1: + raise ValueError("Expected at least one of tensors has length 1") + self._src_tensors = src_tensors + self._dst_tensors = dst_tensors + + self.ttype = None + + self.select_indices = list() + self.merge_axis = None + + if len(src_tensors) == 1: + self.ttype = IRReshapeType.Select + src_tensor = src_tensors[0] + # select + for tensor in dst_tensors: + indices = tensor.common(src_tensor) + self.select_indices.append(indices) + + elif len(dst_tensors) == 1: + self.ttype = IRReshapeType.Merge + dst_tensor = dst_tensors[0] + # find dims to concat + ndims = len(dst_tensor.shape) + indices = [set() for _ in range(ndims)] + for src_tensor in src_tensors: + if isinstance(src_tensor, IRSubTensor): + for ndim, slicer in enumerate(src_tensor.indices.get()): + indices[ndim].add(slicer) + elif isinstance(dst_tensor, IRFullTensor): + for ndim, dim_len in enumerate(src_tensor.shape): + slicer = slice(0, dim_len, 1) + indices[ndim].add(slicer) + # check if only one dim set has multiple slicer + for dim, dim_indices in enumerate(indices): + if len(dim_indices) != 1: + if self.merge_axis is not None: + raise NotImplementedError("Only support merge on one axis") + self.merge_axis = dim + dim_indices = indices[self.merge_axis] + # check if they are overlapped + starts = np.array([slicer.start for slicer in dim_indices]) + stops = np.array([slicer.stop for slicer in dim_indices]) + steps = np.array([slicer.step for slicer in dim_indices]) + sorted_idx = np.argsort(starts) + sorted_starts = starts[sorted_idx] + sorted_stops = stops[sorted_idx] + sorted_steps = steps[sorted_idx] + for last_stop, begin_start in zip(sorted_stops[:-1], sorted_starts[1:]): + if last_stop != begin_start: + raise NotImplementedError(f"Concatenation fails due to axis {last_stop} != {begin_start}") + for step in sorted_steps: + if step != 1: + raise NotImplementedError(f"Found a SubTensor step {step} != 1") + # re-order + dst_tensors = dst_tensors[sorted_idx] + + else: + raise RuntimeError("Internal Error") + + super().__init__( + name = 'transformation', + signature = self.ttype.value, + input_length = len(src_tensors), + output_length = len(dst_tensors) + ) + for idx, input in enumerate(src_tensors): + self.set_input(idx, input) + for idx, output in enumerate(dst_tensors): + self.set_output(idx, output) diff --git a/cube/schedule/execplan.py b/cube/schedule/execplan.py new file mode 100644 index 00000000..ef8e920b --- /dev/null +++ b/cube/schedule/execplan.py @@ -0,0 +1,112 @@ +from typing import List, Optional + +from cube.schedule.sugraph import SUGraph +from cube.schedule.su import SUType, ScheduleUnit + + +class ExectuionPlan: + + def __init__(self, seq: SUGraph): + + self.seq = seq + self.device_seq = dict() + for su in seq.sequence: + device = su.device[0] + if device not in self.device_seq: + self.device_seq[device] = [su] + else: + self.device_seq[device].append(su) + + def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): + """ + Draw the execution timeline. + + Args: + span (List[int]): + length equal to schedule unit num. + Each element stands for the time span for corresponding SU + + outfile: + the output file name + """ + ndevice = max(self.device_seq.keys()) + 1 + # timeline [ [ (start_time, end_time), ... ], ... ] + device_timeline = [list() for _ in range(ndevice)] + device_sus = [list() for _ in range(ndevice)] + + if spans is None: + spans = list() + for su in self.seq.sus(): + span = 0 + if su.stype == SUType.Forward: + span = 1 + elif su.stype == SUType.Backward: + span = 2 + elif su.stype == SUType.Adapter: + span = 0.1 + spans.append(span) + + for su, span_time in zip(self.seq.sequence, spans): + device = su.device[0] + + # tight execution if no dependency + if len(device_timeline[device]) == 0: + start_time = 1 + else: + start_time = device_timeline[device][-1][1] + + # check dependency + for devid, (timeline, dev_sus) in enumerate(zip(device_timeline, device_sus)): + if devid == device: + continue + for suid, (_, end_time) in enumerate(timeline[::-1]): + other_su = dev_sus[::-1][suid] + if other_su.happen_before(su): + start_time = max(start_time, end_time) + break + + device_timeline[device].append((start_time, start_time + span_time)) + device_sus[device].append(su) + + # draw the timeline + if outfile is not None: + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + plt.rcParams['figure.figsize'] = (12.0, 4.0) + + max_time = max( + [tline[-1][1] for tline in device_timeline if len(tline) != 0] + ) + + fig, ax = plt.subplots() + ax.set_xlim((1, max_time)) + plt.xticks(list(range(1, max_time+1, 1))) + ax.xaxis.grid(True, linestyle='--') + plt.xlabel('time') + + # yaxis + ax.set_ylim((0.5, self.ndevice+0.5)) + plt.yticks(list(range(1, self.ndevice+1, 1))) + ax.invert_yaxis() + plt.ylabel('device id') + + ax.set_aspect('equal') + + for devid in range(ndevice): + timeline = device_timeline[devid] + sus = device_sus[devid] + for su, (start, end) in zip(sus, timeline): + # draw + color = 'blue' if (end - start) == 1 else 'orange' + rec = Rectangle((start, devid + 0.5), end-start, 1, + color=color, ec='black', lw=1.5) + ax.add_artist(rec) + rx, ry = rec.get_xy() + cx = rx + rec.get_width() / 2.0 + cy = ry + rec.get_height() / 2.0 + anno = str(su.stype) + # anno = su.name if action.fid is None else action.fid + ax.annotate(anno, (cx, cy), color='w', weight='bold', + fontsize=10, ha='center', va='center') + # plt.grid() + plt.savefig(outfile) \ No newline at end of file diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index aa6a981c..679b61ba 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -9,7 +9,7 @@ from cube.ir.cten import IRCell, IRTensor from cube.graph.tensor import IRFullTensor -from cube.graph.comm import IRCommunication +from cube.schedule.adapter.comm import IRCommunication from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.pool import SchedulePool from cube.schedule.sugraph import SUGraph diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index ad71f042..d6cf6884 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -1,79 +1,61 @@ -from cube.graph import parser, IRAction, IRTensor -from cube.codegen.codegen import SScheduleCodeGen - -import torch -from torch import nn - - -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult, bias=False) - self.gelu = nn.GELU() - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim * mult, dim) - self.classifier = nn.Linear(dim, classes) - - def forward(self, data): - output = self.linear1(data) - output = self.gelu(output) - output = self.dropout(output) - output = self.linear2(output) - output = output + data - output = self.classifier(output) - return output - - -model = FeedForward(dim=1024) - - -def import_from_file(filename): - print(f'> loading GenModel from {filename} ...') - import importlib.util - spec = importlib.util.spec_from_file_location("GenModel", filename) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.GenModel - - -def init_weight(parameters): - for param in parameters: - with torch.no_grad(): - torch.nn.init.uniform_(param) - - -def test_codegen(model): - graph = parser.convert(model, - input_shapes=([1024,1024],)) - for node in graph.nodes(): - node.device = 0 - local_graph = IRAction(graph.nodes(), graph, devices=[0]) - gener = SScheduleCodeGen(local_graph) - code = gener.gen(device=0, outfile='code.py') - - # execute - print("> ===== Generated Code =====") - print(code) - print("< ===== Generated Code =====") - - GenModel = import_from_file('code.py') - model = GenModel().cuda() - - init_weight(model.parameters()) - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - print("> training 10 iterations...") - - for _ in range(10): - data = torch.randn([64,1024], device=torch.device('cuda:0')) - out = model(data) - loss = torch.mean(out) / 1000 - print(f'> loss: {loss.item()}') - loss.backward() - optimizer.step() - optimizer.zero_grad() - - -if __name__ == '__main__': - - test_codegen(model) \ No newline at end of file +from cube.graph.tensor import IRFullTensor +from cube.graph.comm import IRCommunication +from cube.graph.operator import IROperation +from cube.graph.graph import IRGraph + +from cube.codegen.codegen import SScheduleCodeGen, TScheduleCodeGen + + +def construct_graph(): + + input = IRFullTensor(shape=[64,1024], name='data') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = IROperation( + name='linear1', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear1.set_input(0, input) + linear1.set_input(1, weight1) + linear1.set_input(2, bias1) + + # linear2 + linear2 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear2.set_input(0, linear1.outputs(0)) + linear2.set_input(1, weight2) + + # linear3 + linear3 = IROperation( + name='linear2', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear3.set_input(0, linear2.outputs(0)) + linear3.set_input(1, weight3) + linear3.set_input(2, bias3) + + graph = IRGraph( + nodes=[linear1, linear2, linear3], + input_tensors=[input], + output_tensors=linear3.outputs(), + module_name="Test" + ) + return graph + + +def test_model_gen(): + + \ No newline at end of file diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index 5bb66fca..64b42bd1 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -1,10 +1,10 @@ from cube.graph.tensor import IRFullTensor -from cube.graph.comm import IRCommunication from cube.graph.operator import IROperation from cube.graph.graph import IRGraph from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.sugraph import SUGraph +from cube.schedule.adapter.comm import IRCommunication def construct_graph(): diff --git a/tests/tschedule/test_tschedule.py b/tests/tschedule/test_tschedule.py deleted file mode 100644 index b7835b01..00000000 --- a/tests/tschedule/test_tschedule.py +++ /dev/null @@ -1,143 +0,0 @@ -from cube.graph.ir_graph import IRGraph -from cube.tschedule.pool import TSchedulePool -from cube.graph.ir_cten import IRTensor -from cube.tschedule.suseq import SUSequence -from cube.sschedule.adapter import Adapter -from torch import nn -import torch - -import cube.graph.parser as parser -from cube.codegen.codegen import SScheduleCodeGen, TScheduleCodeGen - - -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult, bias=False) - self.gelu = nn.GELU() - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim * mult, dim) - self.classifier = nn.Linear(dim, classes) - - def forward(self, data): - output = self.linear1(data) - output = self.gelu(output) - output = self.dropout(output) - output = self.linear2(output) - output = self.classifier(output) - loss = torch.sum(output) - return loss - - -model = FeedForward(dim=1024) -ir_graph = parser.convert(model, input_shapes=([64,1024],)) - -# device assignment -for nid, node in enumerate(ir_graph.nodes()): - if nid < 3: - node.device = 0 - else: - node.device = 1 - -print('====== Forward Graph =======\n') -print(ir_graph) -ir_graph = Adapter.adapt(ir_graph) -print('====== Forward Graph =======\n') - - -def test_graph_forward(ir_graph): - - TSchedulePool().clear() - tensor1 = ir_graph(IRTensor(shape=[64,1024])) - tensor2 = ir_graph(IRTensor(shape=[64,1024])) - assert tensor1 != tensor2 - print('====== Forward Test =======') - seq = SUSequence(TSchedulePool().sus()) - - for su in seq.sus(): - print(su) - - print('\n====== Forward Test =======\n') - - -def test_su_merge(ir_graph): - - TSchedulePool().clear() - loss = ir_graph(IRTensor(shape=[64,1024])) - seq = SUSequence(TSchedulePool().sus()) - - first_stage = seq.sus()[1:4] - second_stage = seq.sus()[6:9] - - su1 = seq.sus(1) - for su in first_stage[1:]: - su1 = seq.merge(su1, su) - assert su1 is not None - - su2 = second_stage[0] - for su in second_stage[1:]: - su2 = seq.merge(su2, su) - - for su in seq.sus(): - print(su) - - # spatial code - sgener = SScheduleCodeGen(seq) - scode = sgener.gen(device=0) - print(scode) - - # temporal code - tgener = TScheduleCodeGen(seq) - tcode = tgener.gen(device=0) - print(tcode) - - -def test_graph_backward(ir_graph): - - TSchedulePool().clear() - micro_bs = 2 - for _ in range(micro_bs): - loss = ir_graph(IRTensor(shape=[64,1024])) - loss.backward() - print('====== Backward Test =======\n') - print(TSchedulePool()) - - seq = SUSequence(TSchedulePool().sus()) - first_stage = seq.sus()[1:4] - second_stage = seq.sus()[6:9] - - su1 = seq.sus(1) - for su in first_stage[1:]: - su1 = seq.merge(su1, su) - assert su1 is not None - - su2 = second_stage[0] - for su in second_stage[1:]: - su2 = seq.merge(su2, su) - - print('===== seq before gen ====') - print(seq) - for su in seq.sus(): - print(f'pair: {su} <-> {su.mirror}') - - sgener = SScheduleCodeGen(seq) - scode = sgener.gen(device=0) - print(scode) - scode = sgener.gen(device=1) - print(scode) - - # temporal code - tgener = TScheduleCodeGen(seq) - tcode = tgener.gen(device=0) - print(tcode) - tcode = tgener.gen(device=1) - print(tcode) - - print('\n====== Backward Test =======\n') - - -if __name__ == '__main__': - - test_graph_forward(ir_graph) - test_su_merge(ir_graph) - test_graph_backward(ir_graph) From 9e602a4186a05d5381495db4f2341d54f4eab96f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 21 Oct 2021 10:10:24 +0800 Subject: [PATCH 0222/1892] align primitives to linear --- examples/linears.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/examples/linears.py b/examples/linears.py index 3a9cd3d0..e10fdb27 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -16,8 +16,8 @@ from torch import nn import cube -from cube.tschedule.su import ScheduleUnit -from cube.tschedule.suseq import SUSequence +from cube.schedule.su import ScheduleUnit +from cube.schedule.sugraph import SUGraph def trans_policy(graph, resource): @@ -25,13 +25,13 @@ def trans_policy(graph, resource): The transformation policy transposes linear using data parallel """ ndevice = resource.ngpus - for node in graph.nodes(): - algorithm = node.algorithms('data_parallel') - graph.select(node, algorithm, config=dict(chunk_size=ndevice)) + for op in graph.nodes(): + algorithm = op.algorithms('data_parallel') + graph.partition(op, algorithm, config=dict(chunk_size=ndevice)) return graph -def schedule_policy(seq: SUSequence, resource): +def schedule_policy(seq: SUGraph, resource): """ The schedule policy uses 1F1B (interleaved) pipeline """ @@ -71,10 +71,7 @@ def schedule_policy(seq: SUSequence, resource): if f_mirco_batch_id >= len(batch_seqs): continue reorder.append(f(stage, f_mirco_batch_id)) - - for idx, su in enumerate(reorder): - seq.move(su, idx) - + SUGraph.set_order(reorder) From 1320a891d996dc309909c63abda6f48401f4b7c4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 21 Oct 2021 12:44:18 +0800 Subject: [PATCH 0223/1892] graph to subtensor --- cube/graph/graph.py | 60 +++++++++++++++------- cube/graph/tensor.py | 82 ++++++++++++++++++++++++++----- cube/schedule/translator.py | 2 +- tests/schedule/test_graphpass.py | 3 ++ tests/schedule/test_su.py | 3 ++ tests/schedule/test_sugraph.py | 3 ++ tests/schedule/test_translator.py | 14 +++--- 7 files changed, 129 insertions(+), 38 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 49a69df4..5ad3dca4 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -8,11 +8,11 @@ """ from typing import Union, Tuple, List, Optional, Any +import copy from cube.ir.cten import IRTensor, IRCell from cube.graph.operator import IROperation - -import copy +from cube.graph.tensor import IRFullTensor __all__ = ['IRGraph'] @@ -32,6 +32,7 @@ def __init__(self, module_name: str): self._nodes: List[IROperation] = nodes + self._parameters = list() self.reset_dependency() if input_tensors is None: @@ -46,13 +47,33 @@ def __init__(self, output_length=len(output_tensors) ) - for idx, tensor in enumerate(input_tensors): + # convert to SubTensor + inputs = list() + for tensor in input_tensors: + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + inputs.append(tensor) + outputs = list() + for tensor in output_tensors: + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + outputs.append(tensor) + for node in self.nodes(): + for idx, tensor in enumerate(node.inputs()): + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + node.set_input(idx, tensor) + for idx, tensor in enumerate(node.outputs()): + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + node.set_output(idx, tensor) + + for idx, tensor in enumerate(inputs): self.set_input(idx, tensor) - for idx, tensor in enumerate(output_tensors): + for idx, tensor in enumerate(outputs): self.set_output(idx, tensor) # set parameter - self._parameters = list() for node in self._nodes: for input in node.inputs(): if isinstance(input, IRTensor): @@ -68,22 +89,23 @@ def reset_dependency(self): """ # set node predecessors and successors for src_idx in range(len(self._nodes)): - src_cell = self._nodes[src_idx] - src_cell._successors = [ - list() for _ in range(len(src_cell.outputs())) + src_node = self._nodes[src_idx] + src_node._successors = [ + list() for _ in range(len(src_node.outputs())) ] - for dst_idx in range(src_idx + 1, len(self._nodes)): - dst_cell = self._nodes[dst_idx] - dst_cell._predecessors = [ - list() for _ in range(len(dst_cell.inputs())) + for dst_node in self._nodes[src_idx+1:]: + dst_node._predecessors = [ + list() for _ in range(len(dst_node.inputs())) ] - for tensor in src_cell.outputs(): - if isinstance(tensor, IRTensor): - if tensor in dst_cell.inputs(): - src_output_idx = src_cell.outputs().index(tensor) - src_cell.add_successor(src_output_idx, dst_cell) - dst_input_idx = dst_cell.inputs().index(tensor) - dst_cell.add_predecessor(dst_input_idx, src_cell) + for out_idx, out_tensor in enumerate(src_node.outputs()): + if not isinstance(out_tensor, IRTensor): + continue + for in_idx, in_tensor in enumerate(dst_node.inputs()): + if not isinstance(in_tensor, IRTensor): + continue + if out_tensor.overlap(in_tensor): + src_node.add_successor(out_idx, dst_node) + dst_node.add_predecessor(in_idx, src_node) def parameters(self): """ diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 05bd6a21..2b7d5c05 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Callable +from typing import List, Optional, Callable, Union, Tuple import copy from cube.ir.cten import IRTensor @@ -20,6 +20,16 @@ def __init__(self, indices): ) self._indices = indices + def __eq__(self, other): + if isinstance(other, IndexMap): + if self.ndims != self.ndims: + return False + for myslicer, oslicer in zip(self.get(), other.get()): + if myslicer != oslicer: + return False + return True + return False + def get(self): """ Get indices @@ -224,7 +234,7 @@ def val_ops(self, index: Optional[int] = None): else: return self._val_ops[index] - def select(self, indices, val_op: Optional[Callable], shape: List[int]): + def select(self, indices: Union[Tuple, IndexMap], val_op: Optional[Callable], shape: List[int]): """ Select a SubTensor from FullTensor. @@ -242,11 +252,19 @@ def select(self, indices, val_op: Optional[Callable], shape: List[int]): Returns: IRSubTensor """ - sub_tensor = IRSubTensor(self, indices, val_op, shape) - self._segments.append(sub_tensor) - self._indices.append(IndexMap(indices)) - self._val_ops.append(val_op) - return sub_tensor + if not isinstance(indices, IndexMap): + indices = IndexMap(indices) + if indices in self._indices: + index = self._indices.index(indices) + sub_tensor = self._segments[index] + if sub_tensor.val_op == val_op: + return sub_tensor + else: + sub_tensor = IRSubTensor(self, indices, val_op, shape) + self._segments.append(sub_tensor) + self._indices.append(indices) + self._val_ops.append(val_op) + return sub_tensor def overlap(self, other): """ @@ -278,6 +296,26 @@ def common(self, other) -> Optional[IRTensor]: """ return other if self.overlap(other) else None + def tosub(self): + """ + Convert to SubTensor by selecting all indices + """ + if self.shape is None: + raise RuntimeError("Expected know shape") + slicers = list() + for dim_len in self.shape: + slicers.append(slice(0, dim_len, 1)) + sub_tensor = self.select( + indices=tuple(slicers), + val_op=None, + shape=self.shape + ) + return sub_tensor + + def __repr__(self): + dscp = f'FullTensor(id={self._id}, shape={self.shape}, device={self.device})' + return dscp + class IRSubTensor(IRTensor): @@ -290,17 +328,33 @@ def __init__(self, full_tensor: IRTensor, indices, val_op=None, shape=None): indices: index list val_op: the value operation to merge SubTensors into one """ + if not isinstance(full_tensor, IRFullTensor): + raise TypeError(f"Expected IRFullTensor but got {full_tensor}") super().__init__(shape=shape, name=full_tensor.name) # the full tensor self._full_tensor = full_tensor # the index from full_tensor - self._index_map = IndexMap(indices) + if not isinstance(indices, IndexMap): + indices = IndexMap(indices) + self._index_map = indices # val merge op self.val_merge_op = val_op + def __eq__(self, other): + + if isinstance(other, IRFullTensor): + return self.parent == other and self.shape == other.shape + if isinstance(other, IRSubTensor): + if self.parent != other.parent: + return False + if other.indices == self.indices and self.shape == other.shape: + return True + return False + return False + @property def parent(self) -> IRFullTensor: """ @@ -327,7 +381,7 @@ def __copy__(self): Returns: tensor """ - tensor = IRSubTensor(self._shape, self.name) + tensor = IRSubTensor(self.parent, self.indices, self.val_op, self._shape) for key in self.__dict__: setattr(tensor, key, getattr(self, key)) # clear attached cells @@ -342,7 +396,7 @@ def renew(self): Returns: tensor """ - tensor = IRSubTensor(self._shape, self.name) + tensor = IRSubTensor(self.parent, self.indices, self.val_op, self._shape) new_id = tensor._id for key in self.__dict__: setattr(tensor, key, getattr(self, key)) @@ -351,7 +405,7 @@ def renew(self): tensor._id = new_id return tensor - def select(self, indices, val_op, shape=None): + def select(self, indices: Union[Tuple, IndexMap], val_op, shape=None): """ Select an IRSubTensor @@ -384,6 +438,8 @@ def overlap(self, other): if isinstance(other, IRFullTensor): return self.parent == other elif isinstance(other, IRSubTensor): + if self.parent != other.parent: + return False return self.indices.overlap(other.indices) else: raise TypeError("Customized IRTensor not support") @@ -413,3 +469,7 @@ def common(self, other): else: raise NotImplementedError("Customized IRTensor not support") return None + + def __repr__(self): + dscp = f'SubTensor(id={self._id}, shape={self.shape}, device={self.device})' + return dscp \ No newline at end of file diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index 679b61ba..faa125d6 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -42,7 +42,7 @@ def load_data(dataloader: IRDataLoader): outputs = list() for data in datas: if torch.is_tensor(data): - data = IRFullTensor(shape=list(data.shape), name='data') + data = IRFullTensor(shape=list(data.shape), name='data').tosub() data.requires_grad = False outputs.append(data) diff --git a/tests/schedule/test_graphpass.py b/tests/schedule/test_graphpass.py index f27ce97a..3955dfcf 100644 --- a/tests/schedule/test_graphpass.py +++ b/tests/schedule/test_graphpass.py @@ -29,6 +29,7 @@ def construct_graph(): linear1.set_input(0, input) linear1.set_input(1, weight1) linear1.set_input(2, bias1) + linear1.infer_shape() # linear2 linear2 = IROperation( @@ -39,6 +40,7 @@ def construct_graph(): ) linear2.set_input(0, linear1.outputs(0)) linear2.set_input(1, weight2) + linear2.infer_shape() # linear3 linear3 = IROperation( @@ -50,6 +52,7 @@ def construct_graph(): linear3.set_input(0, linear2.outputs(0)) linear3.set_input(1, weight3) linear3.set_input(2, bias3) + linear3.infer_shape() graph = IRGraph( nodes=[linear1, linear2, linear3], diff --git a/tests/schedule/test_su.py b/tests/schedule/test_su.py index a7d498c2..9f3c4f23 100644 --- a/tests/schedule/test_su.py +++ b/tests/schedule/test_su.py @@ -26,6 +26,7 @@ def construct_model(): linear1.set_input(0, input) linear1.set_input(1, weight1) linear1.set_input(2, bias1) + linear1.infer_shape() # linear2 linear2 = IROperation( @@ -36,6 +37,7 @@ def construct_model(): ) linear2.set_input(0, linear1.outputs(0)) linear2.set_input(1, weight2) + linear2.infer_shape() # linear3 linear3 = IROperation( @@ -47,6 +49,7 @@ def construct_model(): linear3.set_input(0, linear2.outputs(0)) linear3.set_input(1, weight3) linear3.set_input(2, bias3) + linear3.infer_shape() # return [input], [ops], [output] return [input], [linear1, linear2, linear3], [linear3.outputs(0)] diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index 64b42bd1..b2a53c04 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -26,6 +26,7 @@ def construct_graph(): linear1.set_input(0, input) linear1.set_input(1, weight1) linear1.set_input(2, bias1) + linear1.infer_shape() # linear2 linear2 = IROperation( @@ -36,6 +37,7 @@ def construct_graph(): ) linear2.set_input(0, linear1.outputs(0)) linear2.set_input(1, weight2) + linear2.infer_shape() # linear3 linear3 = IROperation( @@ -47,6 +49,7 @@ def construct_graph(): linear3.set_input(0, linear2.outputs(0)) linear3.set_input(1, weight3) linear3.set_input(2, bias3) + linear3.infer_shape() graph = IRGraph( nodes=[linear1, linear2, linear3], diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py index ec529653..aff64ad7 100644 --- a/tests/schedule/test_translator.py +++ b/tests/schedule/test_translator.py @@ -5,7 +5,7 @@ from cube.schedule.su import SUType from cube.schedule.pool import SchedulePool -from cube.graph.tensor import IRFullTensor +from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.graph.operator import IROperation from cube.graph.graph import IRGraph @@ -84,7 +84,7 @@ def test_load_dataloader(): dataloader = IRDataLoader(FakeDataLoader(batch_size=64)) data1 = next(dataloader) - assert isinstance(data1, IRFullTensor) + assert isinstance(data1, IRSubTensor) assert data1.shape == [64, 1024] data2 = next(dataloader) @@ -92,7 +92,7 @@ def test_load_dataloader(): assert all([su.stype == SUType.Dataloader for su in SchedulePool().sus()]) data3 = LogicTranslator.load_data(dataloader) - assert isinstance(data1, IRFullTensor) + assert isinstance(data1, IRSubTensor) assert data1.shape == [64, 1024] assert len(SchedulePool().sus()) == 3 assert all([su.stype == SUType.Dataloader for su in SchedulePool().sus()]) @@ -103,10 +103,10 @@ def test_translator_forward(): graph = construct_graph() print(graph) - data = IRFullTensor(shape=[64,1024], name='data') + data = IRFullTensor(shape=[64,1024], name='data').tosub() output = graph(data) - assert isinstance(output, IRFullTensor) + assert isinstance(output, IRSubTensor) assert output.shape == [64, 1024] assert output.trace is not None @@ -122,7 +122,7 @@ def test_translator_backward(): SchedulePool().clear() graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data') + data = IRFullTensor(shape=[64,1024], name='data').tosub() output = graph(data) output.backward() @@ -141,7 +141,7 @@ def test_translatro_gen_adapter(): SchedulePool().clear() graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data') + data = IRFullTensor(shape=[64,1024], name='data').tosub() output = graph(data) # forward adatpers From ac122104a6c51ff00e4eaf5a1dff414c0114f628 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 21 Oct 2021 13:55:10 +0800 Subject: [PATCH 0224/1892] add select for adapter --- cube/schedule/adapter/select.py | 88 ++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 34 deletions(-) diff --git a/cube/schedule/adapter/select.py b/cube/schedule/adapter/select.py index 433978ba..4d6dcc7a 100644 --- a/cube/schedule/adapter/select.py +++ b/cube/schedule/adapter/select.py @@ -1,9 +1,9 @@ -from typing import List +from typing import List, Optional from enum import Enum import numpy as np -from cube.ir.cten import IRCell, IRTensor -from cube.graph.tensor import IRSubTensor, IRFullTensor +from cube.ir.cten import IRCell +from cube.graph.tensor import IRSubTensor, IndexMap class IRReshapeType(Enum): @@ -25,25 +25,26 @@ class IRTensorReshape(IRCell): src_tensors has (multiple) tensors, dst_tensors is only one tensor. This will merge the sub_tensor and generate what it need """ - def __init__(self, src_tensors: List[IRTensor], dst_tensors: List[IRTensor]): + def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor]): if len(src_tensors) != 1 and len(dst_tensors) != 1: raise ValueError("Expected at least one of tensors has length 1") - self._src_tensors = src_tensors - self._dst_tensors = dst_tensors self.ttype = None - self.select_indices = list() - self.merge_axis = None + self._select_indices: List[IndexMap] = list() + self._merge_axis = None if len(src_tensors) == 1: self.ttype = IRReshapeType.Select src_tensor = src_tensors[0] + if not isinstance(src_tensor, IRSubTensor): + raise TypeError(f"Expected IRSubTensor but got {type(src_tensor)}") # select for tensor in dst_tensors: - indices = tensor.common(src_tensor) - self.select_indices.append(indices) + indices = tensor.indices & src_tensor.indices + print(indices.get()) + self._select_indices.append(indices) elif len(dst_tensors) == 1: self.ttype = IRReshapeType.Merge @@ -54,34 +55,36 @@ def __init__(self, src_tensors: List[IRTensor], dst_tensors: List[IRTensor]): for src_tensor in src_tensors: if isinstance(src_tensor, IRSubTensor): for ndim, slicer in enumerate(src_tensor.indices.get()): - indices[ndim].add(slicer) - elif isinstance(dst_tensor, IRFullTensor): - for ndim, dim_len in enumerate(src_tensor.shape): - slicer = slice(0, dim_len, 1) - indices[ndim].add(slicer) + indices[ndim].add((slicer.start, slicer.stop, slicer.step)) + else: + raise RuntimeError( + f"Expected SubTensor but got {type(src_tensor)}" + ) # check if only one dim set has multiple slicer for dim, dim_indices in enumerate(indices): if len(dim_indices) != 1: - if self.merge_axis is not None: + if self._merge_axis is not None: raise NotImplementedError("Only support merge on one axis") - self.merge_axis = dim - dim_indices = indices[self.merge_axis] - # check if they are overlapped - starts = np.array([slicer.start for slicer in dim_indices]) - stops = np.array([slicer.stop for slicer in dim_indices]) - steps = np.array([slicer.step for slicer in dim_indices]) - sorted_idx = np.argsort(starts) - sorted_starts = starts[sorted_idx] - sorted_stops = stops[sorted_idx] - sorted_steps = steps[sorted_idx] - for last_stop, begin_start in zip(sorted_stops[:-1], sorted_starts[1:]): - if last_stop != begin_start: - raise NotImplementedError(f"Concatenation fails due to axis {last_stop} != {begin_start}") - for step in sorted_steps: - if step != 1: - raise NotImplementedError(f"Found a SubTensor step {step} != 1") - # re-order - dst_tensors = dst_tensors[sorted_idx] + self._merge_axis = dim + # get merge axis + if self._merge_axis is not None: + dim_indices = indices[self._merge_axis] + # check if they are overlapped + starts = np.array([slicer[0] for slicer in dim_indices]) + stops = np.array([slicer[1] for slicer in dim_indices]) + steps = np.array([slicer[2] for slicer in dim_indices]) + sorted_idx = np.argsort(starts) + sorted_starts = starts[sorted_idx] + sorted_stops = stops[sorted_idx] + sorted_steps = steps[sorted_idx] + for last_stop, begin_start in zip(sorted_stops[:-1], sorted_starts[1:]): + if last_stop != begin_start: + raise NotImplementedError(f"Concatenation fails due to axis {last_stop} != {begin_start}") + for step in sorted_steps: + if step != 1: + raise NotImplementedError(f"Found a SubTensor step {step} != 1") + # re-order + src_tensors = np.array(src_tensors)[sorted_idx] else: raise RuntimeError("Internal Error") @@ -96,3 +99,20 @@ def __init__(self, src_tensors: List[IRTensor], dst_tensors: List[IRTensor]): self.set_input(idx, input) for idx, output in enumerate(dst_tensors): self.set_output(idx, output) + + @property + def select_indices(self) -> List[IndexMap]: + return self._select_indices + + @property + def merge_axis(self) -> Optional[int]: + return self._merge_axis + + def is_identity(self): + """ + Check if this transformation is a non-op + """ + if len(self.inputs()) == 1 and len(self.outputs()) == 1: + if self.inputs(0) == self.outputs(0): + return True + return False From da4d10a3bcbb13a93e1aade0a8065befcb25e072 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 21 Oct 2021 14:02:34 +0800 Subject: [PATCH 0225/1892] test adapter --- cube/schedule/adapter/select.py | 11 +++- tests/schedule/test_adapter_select.py | 83 +++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 tests/schedule/test_adapter_select.py diff --git a/cube/schedule/adapter/select.py b/cube/schedule/adapter/select.py index 4d6dcc7a..1ac85ce7 100644 --- a/cube/schedule/adapter/select.py +++ b/cube/schedule/adapter/select.py @@ -112,7 +112,14 @@ def is_identity(self): """ Check if this transformation is a non-op """ - if len(self.inputs()) == 1 and len(self.outputs()) == 1: - if self.inputs(0) == self.outputs(0): + if self.ttype == IRReshapeType.Select: + src_tensor = self.inputs(0) + for dst_tensor in self.outputs(): + if dst_tensor != src_tensor: + return False + return True + if self.ttype == IRReshapeType.Merge: + if self.merge_axis is None: return True + return False return False diff --git a/tests/schedule/test_adapter_select.py b/tests/schedule/test_adapter_select.py new file mode 100644 index 00000000..8739d62f --- /dev/null +++ b/tests/schedule/test_adapter_select.py @@ -0,0 +1,83 @@ + +from cube.schedule.adapter.select import IRReshapeType +from cube.schedule.adapter.select import IRTensorReshape + +from cube.graph.tensor import IRFullTensor, IndexMap + + +def test_tensor_reshape_init(): + + tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() + + tensor2 = tensor1.select( + indices = (slice(0, 512), slice(0, 1024)), + val_op = None, + shape = [512, 1024] + ) + + tensor3 = tensor1.select( + indices = (slice(512, 1024), slice(0, 1024)), + val_op = None, + shape = [512, 1024] + ) + + reshape = IRTensorReshape( + src_tensors=[tensor1], + dst_tensors=[tensor2, tensor3] + ) + + assert len(reshape.inputs()) == 1 + assert len(reshape.outputs()) == 2 + assert reshape.ttype == IRReshapeType.Select + assert reshape.select_indices == [ + IndexMap((slice(0, 512, 1), slice(0, 1024, 1))), + IndexMap((slice(512, 1024, 1), slice(0, 1024, 1))), + ] + assert reshape.merge_axis is None + + reshape = IRTensorReshape( + dst_tensors=[tensor1], + src_tensors=[tensor2, tensor3] + ) + + assert len(reshape.inputs()) == 2 + assert len(reshape.outputs()) == 1 + assert reshape.ttype == IRReshapeType.Merge + assert reshape.merge_axis == 0 + assert len(reshape.select_indices) == 0 + + +def test_adapter_select_is_identity(): + + tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() + + tensor2 = tensor1.select( + indices = (slice(512, 1024), slice(0, 1024)), + val_op = None, + shape = [512, 1024] + ) + + tensor3 = tensor2.select( + indices = (slice(0, 256), slice(0, 1024)), + val_op = None, + shape = [256, 1024] + ) + + tensor4 = tensor1.select( + indices = (slice(512, 768), slice(0, 1024)), + val_op = None, + shape = [256, 1024] + ) + + tensor5 = tensor1.select( + indices = (slice(512, 768), slice(0, 1024)), + val_op = None, + shape = [256, 1024] + ) + + reshape = IRTensorReshape( + dst_tensors=[tensor2], + src_tensors=[tensor4, tensor5] + ) + + assert reshape.is_identity() From 0dd1a92c3348a4f705d8ebb2213baa9d4f201b5d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 21 Oct 2021 14:26:41 +0800 Subject: [PATCH 0226/1892] select merge bug fix --- cube/schedule/adapter/select.py | 5 ++++- tests/schedule/test_adapter_select.py | 9 +++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/cube/schedule/adapter/select.py b/cube/schedule/adapter/select.py index 1ac85ce7..5a19d855 100644 --- a/cube/schedule/adapter/select.py +++ b/cube/schedule/adapter/select.py @@ -43,7 +43,6 @@ def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor # select for tensor in dst_tensors: indices = tensor.indices & src_tensor.indices - print(indices.get()) self._select_indices.append(indices) elif len(dst_tensors) == 1: @@ -66,6 +65,10 @@ def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor if self._merge_axis is not None: raise NotImplementedError("Only support merge on one axis") self._merge_axis = dim + if self._merge_axis is None: + # check the coverage + if src_tensors[0].indices != dst_tensor.indices: + raise RuntimeError("Not cover all the indices to merge.") # get merge axis if self._merge_axis is not None: dim_indices = indices[self._merge_axis] diff --git a/tests/schedule/test_adapter_select.py b/tests/schedule/test_adapter_select.py index 8739d62f..d0281ee7 100644 --- a/tests/schedule/test_adapter_select.py +++ b/tests/schedule/test_adapter_select.py @@ -76,8 +76,13 @@ def test_adapter_select_is_identity(): ) reshape = IRTensorReshape( - dst_tensors=[tensor2], - src_tensors=[tensor4, tensor5] + src_tensors=[tensor2], + dst_tensors=[tensor4, tensor5] ) + assert not reshape.is_identity() + reshape = IRTensorReshape( + src_tensors=[tensor3], + dst_tensors=[tensor4, tensor5] + ) assert reshape.is_identity() From 4612fb55200dbb37f9d884473f0dd361752a8afc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Oct 2021 12:23:56 +0800 Subject: [PATCH 0227/1892] change examples to align to current design --- examples/case_study/megatron_policy.py | 228 ++++++++++++------------- examples/e2e.py | 2 +- examples/linears.py | 12 +- 3 files changed, 117 insertions(+), 125 deletions(-) diff --git a/examples/case_study/megatron_policy.py b/examples/case_study/megatron_policy.py index a27860c5..773dfeb6 100644 --- a/examples/case_study/megatron_policy.py +++ b/examples/case_study/megatron_policy.py @@ -1,142 +1,128 @@ from typing import List -# spatial -def select(tensor, indices, val_op, shape): pass -def assign(tensor, ranks: List): pass +from cube.schedule.su import SUType -# temporal -def merge(su1, su2): pass +def transform_policy(graph, resource): -def spolicy(model, runtime_info, tp_size, dp_size, pp_size): - - n_devices = runtime_info.ndevs + # suppose this is the policy config that both + # transformation and schedule policy know + tp_size = 8, + pp_size = 4, + dp_size = resource.ndev // (tp_size * pp_size) + num_micro_batch = 16 # each op is divided in (mp_dsize, dp_size) # and put in (pp_size) stage # TODO groups[stage][dp_group][tp_group] = devices (List[int]) - groups = parallel_group(n_devices, tp_size, dp_size, pp_size) - # pipeline stage - total_nodes = len(model.nodes()) - num_op_per_stage = total_nodes // pp_size - for idx, op in enumerate(model.nodes()): - pp_stage = idx // num_op_per_stage - op.group = [pp_stage] + # data + pipeline parallelism: first transform graph + for idx, op in enumerate(graph.nodes()): + algorithm = op.algorithm('data_parallel') + graph.partition( + op, algorithm, config=dict(chunk_size=num_micro_batch * dp_size) + ) + pp_stage = idx // (len(graph.nodes()) // pp_size) + op.tag('pp_stage', pp_stage) # data parallel - for op in model.nodes(): - # data parallel algorithm (suppose at index 0) - dp_algo = op.logical_op.dist_algo(0) - sub_graph = select( - op = op, - algorithm = dp_algo, - config = dict(chunk_num=dp_size, uniform=True) - ) - for dp_stage, dp_op in sub_graph.nodes(): - dp_op.group.append(dp_stage) - model.replace(op, sub_graph) + for op in graph.nodes(): + algorithm = op.algorithm('data_parallel') + graph.partition(op) # tensor parallel # a transformer attention layer: # [attention: col_split(mm + mm + mm) + row_split(mm)] # a transformer feedforward layer: # [feedforwrd: col_split(mm) + row_split(mm)] - for idx in range(total_nodes): - for dp_rank in range(dp_size): - op = model.nodes(dp_size * idx + dp_rank) - devices = op.devices - sub_graph = None - # Attention block - # [1st linear -> 2nd linear) - if first_to_2nd_linear(op): - # split column - tp_col_algo = op.logical_op.dist_algo(1) - sub_graph = select( - op = op, - algorithm = tp_col_algo, - config = dict(chunk_num=tp_size, uniform=True) - ) - # 2nd linear - elif is_2nd_linear(op): - # split row - tp_row_algo = op.logical_op.dist_algo(2) - sub_graph = select( - op = op, - algorithm = tp_row_algo, - config = dict(chunk_num=tp_size, uniform=True) - ) - # MLP block - # [3rd linear -> 4th linear] - elif thrid_to_4th_linear(op): - # split column - tp_col_algo = op.logical_op.dist_algo(1) - sub_graph = select( - op = op, - algorithm = tp_col_algo, - config = dict(chunk_num=tp_size, uniform=True) - ) - elif is_4th_linear(op): - # split row - tp_row_algo = op.logical_op.dist_algo(2) - sub_graph = select( - op = op, - algorithm = tp_row_algo, - config = dict(chunk_num=tp_size, uniform=True) - ) - # else: no change, do redundant computation - if sub_graph: - for tp_stage, op in enumerate(sub_graph): - op.group.append(tp_stage) - model.replace(op, sub_graph) - # device assignment - for op in model.nodes(): - pp_stage, dp_stage, tp_stage = op.group - device = groups[pp_stage][dp_stage][tp_stage] - assign(op, device) - return model - - -def tpolicy(sus, relations, tp_size, pp_size, num_microbatch): - """ - Pipeline 1f1b policy description -- generate a sequence - - Actions: a list of actions - - relations: list[(Action1, Action2)]: a list of tuples indicate partial order - """ + for idx, op in enumerate(graph.nodes()): + # Attention block + # [1st linear -> 2nd linear) + if op_from_1st_to_2nd_linear(op): + # split column + tp_col_algo = op.logical_op.dist_algo(1) + graph.partition( + op = op, + algorithm = tp_col_algo, + config = dict(chunk_num=tp_size, uniform=True) + ) + # 2nd linear + elif op_is_2nd_linear(op): + # split row + tp_row_algo = op.logical_op.dist_algo(2) + graph.partition( + op = op, + algorithm = tp_row_algo, + config = dict(chunk_num=tp_size, uniform=True) + ) + # MLP block + # [3rd linear -> 4th linear] + elif op_from_3rd_to_4th_linear(op): + # split column + tp_col_algo = op.logical_op.dist_algo(1) + graph.partition( + op = op, + algorithm = tp_col_algo, + config = dict(chunk_num=tp_size, uniform=True) + ) + elif op_is_4th_linear(op): + # split row + tp_row_algo = op.logical_op.dist_algo(2) + graph.partition( + op = op, + algorithm = tp_row_algo, + config = dict(chunk_num=tp_size, uniform=True) + ) + return graph + + +def schedule_policy(su_graph, resource): + + # suppose this is the policy config that both + # transformation and schedule policy know + tp_size = 8, + pp_size = 4, + dp_size = resource.ndev // (tp_size * pp_size) + num_micro_batch = 16 + + # given tp, pp, dp, num mirco batch, set the device id + # for hierachical: [pipeline][data][tensor] = device (int) + dev_groups = set_device_id(tp_size, dp_size, pp_size, num_micro_batch) # put sus to forward-backward sequences: List[List[SU(op)]] - fb_op_seqs = list() - for su in sus: - for fb_seq in fb_op_seqs: - if fb_seq[-1].happen_before(su): - fb_seq.append(su) - break - else: - fb_op_seqs.append([su]) - - # merge to stages: List[List[SU(stage of ops)]] - fb_stage_seqs = list() - for fb_seq in fb_op_seqs: - merged_su = fb_seq[0] - merged_tag = fb_seq[0].tag - for su in fb_seq[1]: - if su.device == merged_su and su.tag == merged_tag: - merged_su = merge(merged_su, su) + fb_op_sus = list() + for su in su_graph.sus(): + if su.stype == SUType.Forward or su.stype == SUType.Backward: + for fb_seq in fb_op_sus: + if fb_seq[-1].happen_before(su): + fb_seq.append(su) + break else: - fb_stage_seqs.append(merged_su) - merged_su = su - merged_tag = su.tag - merged_su = merge(merged_su, su) - - # pp_size forward + pp_size backward - assert (pp_size * 2 == len(fb_stage_seqs[0])) + fb_op_sus.append([su]) + + # merge to stages: List[List[SU(stage sequential of ops)]] + fb_stage_sus = list() + assert len(fb_op_sus) == tp_size * dp_size * num_micro_batch + for dp in range(dp_size): + for tp in range(tp_size): + fb_stage_sus.append([]) + fb_sus = fb_op_sus[dp * dp_size + tp] + for idx, su in enumerate(fb_sus): + pp = idx // ( len(fb_sus) // pp_size) + device = dev_groups[pp][dp][tp] + su_graph.assign(su, device) + merged_su = None + for su in fb_sus: + if merged_su is None: + merged_su = su + fb_stage_sus[-1].append([su]) + else: + # same device op can be merged + merged_su = su_graph.merge(merged_su, su) num_stage = pp_size - - f = lambda stage, micro_batch_id: fb_stage_seqs[micro_batch_id][stage] - b = lambda stage, micro_batch_id: fb_stage_seqs[micro_batch_id][num_stage + stage] + f = lambda stage, micro_batch_id: fb_stage_sus[micro_batch_id][stage] + b = lambda stage, micro_batch_id: fb_stage_sus[micro_batch_id][num_stage + stage] sequence = list() @@ -146,15 +132,17 @@ def tpolicy(sus, relations, tp_size, pp_size, num_microbatch): sequence.append(f(stage, mid)) # steady + cooldown: - for mid in range(num_microbatch): + for mid in range(num_micro_batch): # enqueue backward for stage in range(num_stage-1, -1, -1): sequence.append(b(stage, mid)) # enqueue forward for stage in range(num_stage): f_mid = mid + 1 + num_stage - stage - if f_mid >= num_microbatch: + if f_mid >= num_micro_batch: continue sequence.append(f(stage, f_mid)) - assert check_consistency(sequence, sus, relations) - return sequence + + # infor system the control dependency by topological assignment + su_graph.set_order(sequence) + return su_graph diff --git a/examples/e2e.py b/examples/e2e.py index a207c152..5b452b6f 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -15,7 +15,7 @@ from torch import nn import cube -from cube.tschedule.su import SUType +from cube.schedule.su import SUType def spolicy(ir_graph): diff --git a/examples/linears.py b/examples/linears.py index e10fdb27..6d349361 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -20,12 +20,13 @@ from cube.schedule.sugraph import SUGraph -def trans_policy(graph, resource): +def transform_policy(graph, resource): """ The transformation policy transposes linear using data parallel """ ndevice = resource.ngpus for op in graph.nodes(): + # TODO: which dimension is batch algorithm = op.algorithms('data_parallel') graph.partition(op, algorithm, config=dict(chunk_size=ndevice)) return graph @@ -41,7 +42,7 @@ def schedule_policy(seq: SUGraph, resource): batch_seqs: List[List[ScheduleUnit]] = group_by_batches(seq.sus()) num_fsus = len(seq.sus()) // len(batch_seqs) // 2 - # assign devices -- intra device order + # device placement -- inter-device order for batch_seq in batch_seqs: for idx, su in enumerate(batch_seq): stage = idx // (num_fsus // ndevice) @@ -50,8 +51,8 @@ def schedule_policy(seq: SUGraph, resource): else: seq.assign(su, ndevice - stage % ndevice) - - # assign devices -- inter device order + + # decide topo order -- intra-device order f = lambda stage, micro_batch_id: batch_seqs[micro_batch_id][stage] b = lambda stage, micro_batch_id: batch_seqs[micro_batch_id][-stage] @@ -71,6 +72,7 @@ def schedule_policy(seq: SUGraph, resource): if f_mirco_batch_id >= len(batch_seqs): continue reorder.append(f(stage, f_mirco_batch_id)) + # inform system the topological order that could do pipeline parallelism SUGraph.set_order(reorder) @@ -112,10 +114,12 @@ def train(): dim = 1024 model = MLP(dim=dim) + model = cube.schedule.transform(model, policy_fn=transform_policy) model = model.cuda() dataloader = FakeDataLoader((batch_size, dim)) + @cube.schedule.schedule(model, dataloader, policy_fn=schedule_policy) def train_iter(model, dataloader): for _ in range(4): data = next(dataloader) From 606caa4622ee1490514872df57fb18bb249f72d9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Oct 2021 12:25:04 +0800 Subject: [PATCH 0228/1892] add attention example --- examples/attention.py | 156 ++++++++++++++++++++++++++++++++++++++++++ gencode0.py | 95 ------------------------- gencode1.py | 94 ------------------------- 3 files changed, 156 insertions(+), 189 deletions(-) create mode 100644 examples/attention.py delete mode 100644 gencode0.py delete mode 100644 gencode1.py diff --git a/examples/attention.py b/examples/attention.py new file mode 100644 index 00000000..9eb8e57e --- /dev/null +++ b/examples/attention.py @@ -0,0 +1,156 @@ +from typing import Optional +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +class MultiHeadSelfAttention(nn.Module): + + def __init__(self, seq_len, embed_dim, heads, dropout): + super().__init__() + + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_heads = heads + self.dim_head = embed_dim // heads + self.scale = self.dim_head ** -0.5 + + self.weight_qkv = torch.nn.Parameter(torch.empty( + 3 * embed_dim, embed_dim + )) + self.weight_out = torch.nn.Parameter(torch.empty( + embed_dim, embed_dim + )) + self.dropout = nn.Dropout(dropout) + + self._reset_parameters() + + def _reset_parameters(self): + torch.nn.init.xavier_uniform_(self.weight_qkv) + torch.nn.init.xavier_uniform_(self.weight_out) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + x: [L, N, E]: seq_len, batch_size, embedding dimension + output: [L, N, E] + """ + bs = x.shape[1] + + qkv = F.linear(x, self.weight_qkv, None).chunk(3, dim=-1) + q, k, v = qkv + q = q.contiguous().view(self.seq_len, (bs * self.num_heads), self.dim_head) + q = q.transpose(0, 1) + # => q: (batch size, seq_len, embed_dim) + k = k.contiguous().view(self.seq_len, (bs * self.num_heads), self.dim_head) + k = k.transpose(0, 1) + v = v.contiguous().view(self.seq_len, (bs * self.num_heads), self.dim_head) + v = v.transpose(0, 1) + + q = q * self.scale + attn = torch.bmm(q, k.transpose(-2, -1)) + attn = F.softmax(attn, dim=-1) + attn = self.dropout(attn) + output = torch.bmm(attn, v) + output = output.transpose(0, 1).contiguous() + output = output.view(self.seq_len, bs, self.embed_dim) + output = F.linear(output, self.weight_out) + return output + + def _ref_forward(self, x, mask: Optional[torch.Tensor] = None): + """ + X: [L, N, E]: seq_len, batch_size, embedding dimension + """ + output, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=self.embed_dim, + num_heads=self.num_heads, + in_proj_weight=self.weight_qkv, + in_proj_bias=None, + bias_k = None, + bias_v = None, + add_zero_attn=False, + dropout_p=self.dropout.p, + out_proj_weight=self.weight_out, + out_proj_bias=None, + training=self.training, + need_weights=False + ) + return output + + +class Attention(nn.Module): + def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.seq_len = seq_len + self.scale = dim_head ** -0.5 + + self.stable = stable + self.causal = causal + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None): + b, n, _, h, device = *x.shape, self.heads, x.device + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q = rearrange(qkv[0], 'b n (h d) -> b h n d', h = h) + k = rearrange(qkv[0], 'b n (h d) -> b h n d', h = h) + v = rearrange(qkv[0], 'b n (h d) -> b h n d', h = h) + + + q = q * self.scale + + dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) + mask_value = max_neg_value(dots) + + if mask: + mask = rearrange(mask, 'b j -> b () () j') + dots.masked_fill_(~mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + dots.masked_fill_(mask, mask_value) + + attn = torch.softmax(dots, dim=-1) + + out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + + +if __name__ == '__main__': + + L = 64 + N = 16 + E = 128 + n_heads = 8 + + model = MultiHeadSelfAttention(L, E, n_heads, dropout=0.0) + + x = torch.rand((L, N, E)) + + out_ref = model._ref_forward(x) + out = model(x) + # print(out) + # print(out_ref) + assert torch.allclose(out, out_ref) is True + print('Test passed') + module = torch.jit.script(model) + print(module.graph) + print(module.code) \ No newline at end of file diff --git a/gencode0.py b/gencode0.py deleted file mode 100644 index b9ecca0c..00000000 --- a/gencode0.py +++ /dev/null @@ -1,95 +0,0 @@ - - -########## Generated Code ########### -import torch -import cube - - -class GenModel(torch.nn.Module): - - def __init__(self): - super().__init__() - self.weight_3 = torch.nn.Parameter(torch.empty((16384, 1024))) - - def su1(self, data_36): - tensor_41 = torch.nn.functional.linear(data_36, self.weight_3, None) - tensor_45 = torch.nn.functional.gelu(tensor_41) - tensor_49 = torch.nn.functional.dropout(tensor_45, 0.0, self.training, False) - return tensor_49 - - def su2(self, tensor_49): - cube.runtime.collectives.send((tensor_49, ), [[1]]) - return - - def su7(self): - tensor_89 = cube.runtime.collectives.recv([[64, 16384]], [[1]]) - return tensor_89 - - def su10(self, data_215): - tensor_220 = torch.nn.functional.linear(data_215, self.weight_3, None) - tensor_224 = torch.nn.functional.gelu(tensor_220) - tensor_228 = torch.nn.functional.dropout(tensor_224, 0.0, self.training, False) - return tensor_228 - - def su11(self, tensor_228): - cube.runtime.collectives.send((tensor_228, ), [[1]]) - return - - def su16(self): - tensor_268 = cube.runtime.collectives.recv([[64, 16384]], [[1]]) - return tensor_268 - - def su19(self, data_394): - tensor_399 = torch.nn.functional.linear(data_394, self.weight_3, None) - tensor_403 = torch.nn.functional.gelu(tensor_399) - tensor_407 = torch.nn.functional.dropout(tensor_403, 0.0, self.training, False) - return tensor_407 - - def su20(self, tensor_407): - cube.runtime.collectives.send((tensor_407, ), [[1]]) - return - - def su25(self): - tensor_447 = cube.runtime.collectives.recv([[64, 16384]], [[1]]) - return tensor_447 - - def su28(self, data_573): - tensor_578 = torch.nn.functional.linear(data_573, self.weight_3, None) - tensor_582 = torch.nn.functional.gelu(tensor_578) - tensor_586 = torch.nn.functional.dropout(tensor_582, 0.0, self.training, False) - return tensor_586 - - def su29(self, tensor_586): - cube.runtime.collectives.send((tensor_586, ), [[1]]) - return - - def su34(self): - tensor_626 = cube.runtime.collectives.recv([[64, 16384]], [[1]]) - return tensor_626 - - -########## Generated Code ########### -import torch -import cube - -def _train_step(model, dataloader): - data_36 = next(dataloader) - tensor_49 = cube.runtime.temporal.forward(model.su1, *(data_36, )) - cube.runtime.temporal.forward(model.su2, *(tensor_49, )) - tensor_89 = cube.runtime.temporal.forward(model.su7, *()) - data_78 = cube.runtime.temporal.backward((data_36, ), (tensor_49, ), (tensor_89, )) - data_215 = next(dataloader) - tensor_228 = cube.runtime.temporal.forward(model.su10, *(data_215, )) - cube.runtime.temporal.forward(model.su11, *(tensor_228, )) - tensor_268 = cube.runtime.temporal.forward(model.su16, *()) - data_257 = cube.runtime.temporal.backward((data_215, ), (tensor_228, ), (tensor_268, )) - data_394 = next(dataloader) - tensor_407 = cube.runtime.temporal.forward(model.su19, *(data_394, )) - cube.runtime.temporal.forward(model.su20, *(tensor_407, )) - tensor_447 = cube.runtime.temporal.forward(model.su25, *()) - data_436 = cube.runtime.temporal.backward((data_394, ), (tensor_407, ), (tensor_447, )) - data_573 = next(dataloader) - tensor_586 = cube.runtime.temporal.forward(model.su28, *(data_573, )) - cube.runtime.temporal.forward(model.su29, *(tensor_586, )) - tensor_626 = cube.runtime.temporal.forward(model.su34, *()) - data_615 = cube.runtime.temporal.backward((data_573, ), (tensor_586, ), (tensor_626, )) diff --git a/gencode1.py b/gencode1.py deleted file mode 100644 index b356b8a7..00000000 --- a/gencode1.py +++ /dev/null @@ -1,94 +0,0 @@ - - -########## Generated Code ########### -import torch -import cube - - -class GenModel(torch.nn.Module): - - def __init__(self): - super().__init__() - self.weight_14 = torch.nn.Parameter(torch.empty((1024, 16384))) - self.bias_15 = torch.nn.Parameter(torch.empty((1024,))) - self.weight_21 = torch.nn.Parameter(torch.empty((1000, 1024))) - self.bias_22 = torch.nn.Parameter(torch.empty((1000,))) - - def su3(self): - tensor_49 = cube.runtime.collectives.recv([[64, 16384]], [[0]]) - return tensor_49 - - def su4(self, tensor_49): - tensor_58 = torch.nn.functional.linear(tensor_49, self.weight_14, self.bias_15) - tensor_64 = torch.nn.functional.linear(tensor_58, self.weight_21, self.bias_22) - tensor_68 = torch.sum(tensor_64) - return tensor_68 - - def su6(self, tensor_89): - cube.runtime.collectives.send((tensor_89, ), [[0]]) - return - - def su12(self): - tensor_228 = cube.runtime.collectives.recv([[64, 16384]], [[0]]) - return tensor_228 - - def su13(self, tensor_228): - tensor_237 = torch.nn.functional.linear(tensor_228, self.weight_14, self.bias_15) - tensor_243 = torch.nn.functional.linear(tensor_237, self.weight_21, self.bias_22) - tensor_247 = torch.sum(tensor_243) - return tensor_247 - - def su15(self, tensor_268): - cube.runtime.collectives.send((tensor_268, ), [[0]]) - return - - def su21(self): - tensor_407 = cube.runtime.collectives.recv([[64, 16384]], [[0]]) - return tensor_407 - - def su22(self, tensor_407): - tensor_416 = torch.nn.functional.linear(tensor_407, self.weight_14, self.bias_15) - tensor_422 = torch.nn.functional.linear(tensor_416, self.weight_21, self.bias_22) - tensor_426 = torch.sum(tensor_422) - return tensor_426 - - def su24(self, tensor_447): - cube.runtime.collectives.send((tensor_447, ), [[0]]) - return - - def su30(self): - tensor_586 = cube.runtime.collectives.recv([[64, 16384]], [[0]]) - return tensor_586 - - def su31(self, tensor_586): - tensor_595 = torch.nn.functional.linear(tensor_586, self.weight_14, self.bias_15) - tensor_601 = torch.nn.functional.linear(tensor_595, self.weight_21, self.bias_22) - tensor_605 = torch.sum(tensor_601) - return tensor_605 - - def su33(self, tensor_626): - cube.runtime.collectives.send((tensor_626, ), [[0]]) - return - - -########## Generated Code ########### -import torch -import cube - -def _train_step(model, dataloader): - tensor_49 = cube.runtime.temporal.forward(model.su3, *()) - tensor_68 = cube.runtime.temporal.forward(model.su4, *(tensor_49, )) - tensor_89 = cube.runtime.temporal.backward((tensor_49, ), (tensor_68, ), (None, )) - cube.runtime.temporal.forward(model.su6, *(tensor_89, )) - tensor_228 = cube.runtime.temporal.forward(model.su12, *()) - tensor_247 = cube.runtime.temporal.forward(model.su13, *(tensor_228, )) - tensor_268 = cube.runtime.temporal.backward((tensor_228, ), (tensor_247, ), (None, )) - cube.runtime.temporal.forward(model.su15, *(tensor_268, )) - tensor_407 = cube.runtime.temporal.forward(model.su21, *()) - tensor_426 = cube.runtime.temporal.forward(model.su22, *(tensor_407, )) - tensor_447 = cube.runtime.temporal.backward((tensor_407, ), (tensor_426, ), (None, )) - cube.runtime.temporal.forward(model.su24, *(tensor_447, )) - tensor_586 = cube.runtime.temporal.forward(model.su30, *()) - tensor_605 = cube.runtime.temporal.forward(model.su31, *(tensor_586, )) - tensor_626 = cube.runtime.temporal.backward((tensor_586, ), (tensor_605, ), (None, )) - cube.runtime.temporal.forward(model.su33, *(tensor_626, )) From cab24b550ce4668b02bc4ae7ae33a9a49b8a260a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Oct 2021 12:25:58 +0800 Subject: [PATCH 0229/1892] test for graph --- tests/graph/test_graph.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 7c97542c..102e3524 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -1,5 +1,5 @@ from cube.graph.graph import IRGraph -from cube.graph.tensor import IRFullTensor +from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.graph.operator import IROperation from cube.ir.cten import IRTensor @@ -23,6 +23,7 @@ def construct_model(): linear1.set_input(0, input) linear1.set_input(1, weight1) linear1.set_input(2, bias1) + linear1.infer_shape() # linear2 linear2 = IROperation( @@ -33,6 +34,7 @@ def construct_model(): ) linear2.set_input(0, linear1.outputs(0)) linear2.set_input(1, weight2) + linear2.infer_shape() # linear3 linear3 = IROperation( @@ -44,6 +46,7 @@ def construct_model(): linear3.set_input(0, linear2.outputs(0)) linear3.set_input(1, weight3) linear3.set_input(2, bias3) + linear3.infer_shape() # return [input], [ops], [output] return [input], [linear1, linear2, linear3], [linear3.outputs(0)] @@ -68,18 +71,20 @@ def test_graph_init(): for input in all_inputs: if isinstance(input, IRTensor): - assert isinstance(input, IRFullTensor) + assert isinstance(input, IRSubTensor) for output in all_outputs: if isinstance(output, IRTensor): - assert isinstance(output, IRFullTensor) + assert isinstance(output, IRSubTensor) # check inputs - for input in inputs: - assert input in graph.inputs() - assert input in all_inputs - for output in outputs: - assert output in graph.outputs() - assert output in all_outputs + for full_input, sub_input in zip(inputs, graph.inputs()): + assert full_input.overlap(sub_input) + assert full_input.shape == sub_input.shape + assert sub_input in all_inputs + for full_output, sub_output in zip(outputs, graph.outputs()): + assert full_output.overlap(sub_output) + assert full_output.shape == sub_output.shape + assert sub_output in all_outputs # check dependency node1, node2, node3 = graph.nodes() From 4eb19cace6d2949f520f9995299a87c1511e5367 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Oct 2021 12:26:12 +0800 Subject: [PATCH 0230/1892] tensor repr --- cube/graph/tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 2b7d5c05..eb2cc6bf 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -162,6 +162,10 @@ def __and__(self, other): raise NotImplementedError(f"not supported for differnt steps") return IndexMap(tuple(slices)) + def __repr__(self): + dscp = repr(self._indices) + return dscp + class IRFullTensor(IRTensor): From fbd489387f2a72e5bb9f3123fc530a44f6031dd8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Oct 2021 12:27:11 +0800 Subject: [PATCH 0231/1892] on-going: translator for adapter --- cube/schedule/translator.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index faa125d6..07d597d5 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -5,11 +5,13 @@ Schedule Units, and then add Adapter ScheduleUnit """ from typing import List +from numpy.lib.arraysetops import isin import torch from cube.ir.cten import IRCell, IRTensor from cube.graph.tensor import IRFullTensor from cube.schedule.adapter.comm import IRCommunication +from cube.schedule.adapter.select import IRTensorReshape from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.pool import SchedulePool from cube.schedule.sugraph import SUGraph @@ -127,10 +129,13 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: if not isinstance(input, IRTensor): continue pre_sus = su.predecessors(in_idx) + select_tensors = list() for pre_su in pre_sus: for out_idx, output in enumerate(pre_su.outputs()): if output.overlap(input): - sub_tensor = output.common(input) + sub_tensor = input.common(output) + if sub_tensor != input: + select_tensors.append(sub_tensor) send_op = IRCommunication( send_tensors=[sub_tensor], send_ranks = [-1] @@ -144,7 +149,28 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: recv_su = ScheduleUnit([recv_op], SUType.Adapter, name='recv') su._add_in_adapter(in_idx, send_su, recv_su) pre_su._add_out_adapter(out_idx, send_su, recv_su) - + # TODO: add adapter for select + if len(select_tensors) != 0: + select_op = IRTensorReshape( + src_tensors=[input], dst_tensors=select_tensors + ) + select_su = ScheduleUnit([select_op], SUType.Adapter, name='select') + + # TODO: add adapter for merge + for out_idx, output in enumerate(su.outputs()): + if not isinstance(output, IRTensor): + continue + merge_tensors = list() + for send_adapters, recv_adapters in su.out_adapters(out_idx): + for recv_adapter in recv_adapters: + for tensor in recv_adapter.nodes(0).recv_tensors: + if tensor != output: + merge_tensors.append(tensor) + merge_op = IRTensorReshape( + src_tensors=merge_tensors, dst_tensors=output + ) + merge_su = ScheduleUnit([merge_op], SUType.Adapter, name='merge') + sus_with_adapter = list() for su in sus: for idx in range(len(su.inputs())): From b027606b7f5dc7e6bdcb9ca7edd7f4a9bdf49d03 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Oct 2021 12:27:35 +0800 Subject: [PATCH 0232/1892] on-going: codegen for adapter --- cube/codegen/codegen.py | 57 +++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 832c691f..b79ee107 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -2,11 +2,11 @@ Generate Pytorch code given the model DAG and the transformation config """ from typing import List, Any -from cube.graph.ir_comm import IRCommType, IRCommunication -from cube.graph.ir_cten import IRTensor -from cube.tschedule.suseq import SUSequence -from cube.tschedule.su import ScheduleUnit, SUType +from cube.ir.cten import IRTensor +from cube.schedule.sugraph import SUGraph +from cube.schedule.su import ScheduleUnit, SUType +from cube.schedule.adapter.comm import IRCommType, IRCommunication from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -19,9 +19,9 @@ class SScheduleCodeGen: Generate spatial code for the model """ - def __init__(self, seq: SUSequence): - if not isinstance(seq, SUSequence): - raise TypeError("seq should be SUSequence") + def __init__(self, seq: SUGraph): + if not isinstance(seq, SUGraph): + raise TypeError("seq should be SUGraph") self.seq = seq # model full code self.init_code: List[str] = [ @@ -70,7 +70,6 @@ def gen(self, device: int, outfile=None, attach=False) -> str: for out in node.outputs(): if isinstance(out, IRTensor) or isinstance(out, str): self.symbols.create(self.naming(out)) - print(self.forward_region) self.all_su_forward_region.append(self.forward_region) self.forward_region = list() @@ -112,7 +111,8 @@ def emit_var_declare(self, var: Any): if isinstance(var, IRTensor): name = self.naming(var) # indicate this is a leaf tensor, should be parameter - if self.symbols.create(name): + if var.is_param(): + self.symbols.create(name) code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(var.shape)}))' self.declare_region.append(code) elif isinstance(var, str): @@ -164,6 +164,34 @@ def emit_comm_call(self, node, su: ScheduleUnit): raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") self.forward_region.append(code) + def emit_adapter_call(self, node, su: ScheduleUnit): + """ + Emit in-device tensor transformation call. + + Note the in-device transformation only happens on send + """ + send_tensors, send_ranks = list(), list() + recv_tensors, recv_ranks = list(), list() + for node in su.nodes(): + trans_tensors, trans_ranks = node.send_tensors, node.send_ranks + send_tensors = list() + send_ranks = list() + for trans_tensor, trans_rank in zip(trans_tensors, trans_ranks): + #TODO: tensor transformation + # cross-devie send + if su.device[0] != trans_rank: + send_tensors.append(trans_tensor) + send_ranks.append(trans_rank) + trans_tensors, trans_ranks = node.recv_tensors, node.recv_ranks + for trans_tensor, trans_rank in zip(trans_tensors, trans_ranks): + # cross-devie send + if su.device[0] != trans_rank: + recv_tensors.append(trans_tensor) + recv_ranks.append(trans_rank) + + + + def _forward_region_arg_names(self, tensors: List[Any], su: ScheduleUnit): """ Generate arg name list for forward region. @@ -173,8 +201,7 @@ def _forward_region_arg_names(self, tensors: List[Any], su: ScheduleUnit): named_args : List[str] = list() for tensor in tensors: name = self.naming(tensor) - if isinstance(tensor, IRTensor) and \ - tensor.is_leaf(su.nodes()) and (tensor not in su.inputs()): + if isinstance(tensor, IRTensor) and tensor.is_param(): named_args.append('self.' + name) else: named_args.append(self.naming(name)) @@ -183,8 +210,6 @@ def _forward_region_arg_names(self, tensors: List[Any], su: ScheduleUnit): def naming(self, tensor: Any) -> str: """ Return the var name (unique for different variable) - - If the var is a leaf tensor, will add prefix `self.` to its name """ if isinstance(tensor, IRTensor): tensor_name = 'tensor' if tensor.name is None else tensor.name @@ -211,9 +236,9 @@ def clear(self): class TScheduleCodeGen: - def __init__(self, seq: SUSequence): - if not isinstance(seq, SUSequence): - raise TypeError("seq should be SUSequence") + def __init__(self, seq: SUGraph): + if not isinstance(seq, SUGraph): + raise TypeError("seq should be SUGraph") self.seq = seq # model full code self.init_code: List[str] = [ From 27e69e139bd38aa31e4ed168e1ed75b517092f22 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Oct 2021 12:29:06 +0800 Subject: [PATCH 0233/1892] remove useless code --- cube/sschedule/__init__.py | 52 --------- cube/sschedule/adapter.py | 45 -------- cube/sschedule/prim.py | 55 ---------- cube/tschedule/__init__.py | 104 ------------------ cube/tschedule/pool.py | 40 ------- cube/tschedule/su.py | 191 --------------------------------- cube/tschedule/suseq.py | 211 ------------------------------------- 7 files changed, 698 deletions(-) delete mode 100644 cube/sschedule/__init__.py delete mode 100644 cube/sschedule/adapter.py delete mode 100644 cube/sschedule/prim.py delete mode 100644 cube/tschedule/__init__.py delete mode 100644 cube/tschedule/pool.py delete mode 100644 cube/tschedule/su.py delete mode 100644 cube/tschedule/suseq.py diff --git a/cube/sschedule/__init__.py b/cube/sschedule/__init__.py deleted file mode 100644 index 1c9c5673..00000000 --- a/cube/sschedule/__init__.py +++ /dev/null @@ -1,52 +0,0 @@ -from cube.graph import parser -from cube.codegen.codegen import SScheduleCodeGen -from cube.sschedule.adapter import Adapter - - -class SpatialModule: - - def __init__(self, ir_graph): - # the full semantic graph - self._ir_graph = ir_graph - # the spatial pytorch module for specific rank - self._loaded_module = None - - def get_graph(self): - return self._ir_graph - - def gen_module(self, seq, rank, outfile, attach=False) -> str: - """ - Set the module - """ - gener = SScheduleCodeGen(seq) - code = gener.gen(device=rank, outfile=outfile, attach=attach) - return code - - def load_module(self, filename: str): - print(f'> loading generated spatial moduel from {filename}') - import importlib.util - spec = importlib.util.spec_from_file_location("GenModel", filename) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self._loaded_module = module.GenModel().cuda() - - def get_gen_module(self): - return self._loaded_module - - def clear_module(self): - self._loaded_module = None - - -def schedule(module, input_shapes, policy_fn=None): - """ - Spatial schedule - - Returns: - IRGraph - """ - ir_graph = parser.convert(module, input_shapes=input_shapes) - module = SpatialModule(ir_graph) - if policy_fn: - module._ir_graph = policy_fn(module.get_graph()) - module._ir_graph = Adapter.adapt(module._ir_graph) - return module diff --git a/cube/sschedule/adapter.py b/cube/sschedule/adapter.py deleted file mode 100644 index 21d09bb0..00000000 --- a/cube/sschedule/adapter.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Tuple - -from cube.graph.ir_comm import IRCommunication -from cube.graph.ir_graph import IRGraph - - -class Adapter: - - @staticmethod - def adapt(graph: IRGraph) -> IRGraph: - for src_node in graph.nodes(): - for out_idx, tensor in enumerate(src_node.outputs()): - for dst_node in src_node.successors(out_idx): - if set(src_node.device) != set(dst_node.device): - from_rank = src_node.device - to_rank = dst_node.device - from_rank, to_rank = from_rank, to_rank - #TODO check if it is a tensor - send_node, recv_node = Adapter.create_tensor_move( - tensor = tensor, - from_rank = from_rank, - to_rank = to_rank - ) - graph.insert(send_node, src_node=src_node) - graph.insert(recv_node, dst_node=dst_node, - replaced_tensor=tensor) - return graph - - @staticmethod - def create_tensor_move(tensor, from_rank: int, to_rank: int) -> Tuple[IRCommunication, IRCommunication]: - # send node - ir_send_node = IRCommunication( - send_tensors = [tensor], - send_ranks = [to_rank] - ) - ir_send_node.device = from_rank - # recv node - ir_recv_node = IRCommunication( - recv_tensors = [tensor], - recv_ranks = [from_rank] - ) - ir_recv_node.device = to_rank - ir_send_node.pair(ir_recv_node) - return ir_send_node, ir_recv_node - diff --git a/cube/sschedule/prim.py b/cube/sschedule/prim.py deleted file mode 100644 index 3371c421..00000000 --- a/cube/sschedule/prim.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -Spatial primitives for policy -""" -from cube.graph.ir_cten import IRCell, IRTensor -from cube.graph.ir_graph import IRGraph - -from typing import List, Union - - -def assign(inst: Union[IRTensor, IRCell], ranks: List[int], graph: IRGraph) -> None: - """ - Assign a IRTensor / IRCell with spatial rank placement - - For IRCell: - the device attribute will be set to ranks, - the inputs and outputs of this IRCell will also be changed - to ranks. - - For IRTensor: - A move operation will be changed and inserted in order: - output_node -> move -> input_node - """ - if not all([isinstance(rank, int) for rank in ranks]): - raise TypeError("Expected ranks to be List[int]") - if isinstance(inst, IRCell): - inst.device = ranks - elif isinstance(inst, IRTensor): - if set(inst.device) == set(ranks): - return - # find nodes that generated this tensor from the graph - src_node = list() - dst_node = list() - for node in graph.nodes(): - if inst in node.outputs(): - src_node.append(node) - if inst in node.inputs(): - dst_node.append(node) - if len(src_node) == 0: # a leaf tensor - raise NotImplementedError( - "Prim [assign]: moving parameter is not supported" - ) - if len(dst_node) == 0: # a loss tensor - raise RuntimeError( - "Prim [assign]: moving a tensor that is never used in graph" - ) - raise NotImplementedError( - "Prim [assign]: moving tensor is not supported yet" - ) - else: - raise TypeError("Expected inst to ba Union[IRTensor, IRCell]") - - -def select(tensor: IRTensor, indices, val_op, shape) -> IRTensor: - raise NotImplementedError("Prim [select]: selecting sub IRTensor is not supported") - diff --git a/cube/tschedule/__init__.py b/cube/tschedule/__init__.py deleted file mode 100644 index df159fc2..00000000 --- a/cube/tschedule/__init__.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Callable, Optional -import torch -from cube.tschedule.pool import TSchedulePool -from cube.graph.ir_cten import IRTensor -from cube.tschedule.suseq import SUSequence -from cube.tschedule.su import ScheduleUnit -from cube.codegen.codegen import TScheduleCodeGen - - -class IRTensorDataLoader: - - def __init__(self, dataloader): - self.dataloader = dataloader - - def __iter__(self): - return self - - def __next__(self): - # generate a schedule node - datas = next(self.dataloader) - if not isinstance(datas, tuple): - datas = (datas,) - - outputs = [ - IRTensor(shape=list(data.shape), name='data') for data in datas - ] - for idx, (output, data) in enumerate(zip(outputs, datas)): - if not torch.is_tensor(data): - outputs[idx] = data - else: - output.requires_grad = False - - if len(outputs) == 0: return - elif len(outputs) == 1: return outputs[0] - else: return tuple(outputs) - - -def schedule(model, dataloader, policy_fn: Optional[Callable] = None): - """ - AI Scientist calls like: - - @cube.tschedule.schedule - def train_step(model, dataloader): - # do a 4-time gradient accumulation - for acc_step, (data, label) in enumerate(dataloader): - if acc_step < 4: - loss = model(data, label) - loss.backward() - else: - break - ... - - for epoch in range(100): - train_step(model, data_loader) - optimizer.step() - optimizer.zero_grad() - - ... - """ - ir_graph = model.get_graph() - ir_dataloader = IRTensorDataLoader(dataloader) - myrank = torch.distributed.get_rank() - - def _load_tschedule_fn(filename) -> Callable: - print(f'> [{myrank}] loading generated schedule from {filename} ...') - import importlib.util - spec = importlib.util.spec_from_file_location( - "_train_step", filename - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module._train_step - - def decorator(fn: Callable) -> Callable: - filename = 'gencode{}.py' - if myrank == 0: - TSchedulePool().clear() - # collect trace - fn(ir_graph, ir_dataloader) - sus = TSchedulePool().sus() - seq = SUSequence(sus) - # policy - if policy_fn: - seq = policy_fn(seq) - - world_size = torch.distributed.get_world_size() - tgener = TScheduleCodeGen(seq) - for rank in range(world_size): - fname = filename.format(rank) - # generate spatial module code - model.gen_module(seq, rank, fname, attach=False) - # generate temporal schedule code - tgener.gen( - device = rank, - outfile = fname, - attach=True - ) - torch.distributed.barrier() - # load module - model.load_module(filename.format(myrank)) - # load temporal - return _load_tschedule_fn(filename.format(myrank)) - - return decorator diff --git a/cube/tschedule/pool.py b/cube/tschedule/pool.py deleted file mode 100644 index 8843faf3..00000000 --- a/cube/tschedule/pool.py +++ /dev/null @@ -1,40 +0,0 @@ - - -class TSchedulePool: - - class __TSchedulePool: - - def __init__(self): - - self._sus = list() - self._flow_id = -1 - - instance = None - - def __init__(self): - if not TSchedulePool.instance: - TSchedulePool.instance = TSchedulePool.__TSchedulePool() - - def __getattr__(self, name): - return getattr(self.instance, name) - - def add_su(self, su): - self.instance._sus.append(su) - - def sus(self): - return self.instance._sus - - def clear(self): - self.instance._sus = list() - self.instance._flow_id = -1 - - def gen_id(self) -> int: - """ - Generate an unique action id - """ - self.instance._flow_id += 1 - return self.instance._flow_id - - def __repr__(self): - dscp = '\n'.join([repr(su) for su in self._sus]) - return dscp diff --git a/cube/tschedule/su.py b/cube/tschedule/su.py deleted file mode 100644 index a8421c65..00000000 --- a/cube/tschedule/su.py +++ /dev/null @@ -1,191 +0,0 @@ -from typing import Union, List, Optional -import copy -from enum import Enum -from cube.graph.ir_comm import IRCommunication - -from cube.graph.ir_cten import IRCell - - -class SUType(Enum): - - # outputs = cube.runtime.temporal.forward(model, *args) - Forward = 'cube.runtime.temporal.forward' - - # grads = cube.runtime.temporal.backward( - # input_tensors, output_tensors, output_grads - # ) - Backward = 'cube.runtime.temporal.backward' - - # cube.runtime.collectives.sendrecv(send_tensors, send_ranks, - # recv_shapes, from_ranks - # ) - Adapter = 'cube.runtime.collectives.sendrecv' - - Dataloader = 'next(dataloader)' - - -class ScheduleUnit(IRCell): - """ - Action recv tensors must be inside of Action inputs, - and can be mapped to Action.graph.inputs - - """ - - def __init__(self, sub_nodes, graph, devices: Union[List[int], int], stype: SUType): - - if not isinstance(stype, SUType): - raise TypeError("Expected stype be SUType") - - self.stype = stype - self.global_graph = graph - - subgraph = graph.subgraph(sub_nodes) - inputs = subgraph.inputs() - outputs = subgraph.outputs() - - super().__init__( - name = graph.name, - signature = stype.value, - input_length = len(inputs), - output_length = len(outputs) - ) - - self._nodes = sub_nodes - for idx, input in enumerate(inputs): - self.set_input(idx, input) - for idx, output in enumerate(outputs): - self.set_output(idx, output) - - # set su device - self.device = devices - - # additional control dependency for add_flow - self._ctrl_predecessors = list() - self._ctrl_successors = list() - - self.mirror = None - - def set_mirror(self, su): - """ - Create a mirrored ScheduleUnit: the - inputs and outputs are reversed - """ - if not isinstance(su, ScheduleUnit): - raise TypeError("Expected mirror to be ScheduleUnit") - self.mirror = su - - def nodes(self, index: Optional[int] = None): - """ - Get node at position index - """ - if isinstance(index, int): - if index >= len(self._nodes): - raise RuntimeError( - f"Get node out of range ({index} >= {len(self._nodes)})" - ) - return self._nodes[index] - elif index is None: - return copy.copy(self._nodes) - else: - raise TypeError("Expected index to be None or int") - - def add_predecessor(self, input_index: int, su): - """ - Add a predecessor cell in the input_index slot. - self.input[input_index] = node.output[out_index] - """ - if input_index == -1: - self._predecessors.append(su) - else: - super().add_predecessor(input_index, su) - - def predecessors(self, index: Optional[int] = None) -> List: - """ - Get 1-hop predecessor cells including control predecessors - - Args: - index (Optional[int]): - -1: return control predecessors - None: return all predecessors including index - >0 : return input SUs at input index - - Returns: - cell(s): List[IRCell] - """ - if isinstance(index, int): - if index == -1: - return copy.copy(self._ctrl_predecessors) - if index >= len(self._inputs): - raise RuntimeError( - f"Get the input out of range ({index} >= {len(self._inputs)}" - ) - return copy.copy(self._predecessors[index]) - elif index is None: - predecessors = list() - for pre_cells in self._predecessors: - predecessors += pre_cells - predecessors += self._ctrl_predecessors - return predecessors - else: - raise TypeError("Expected index to be None or int") - - def add_successor(self, output_index: int, su): - """ - Set self node the output index node. - `node` will take the self.outputs(index) as the input - """ - if output_index == -1: - self._successors.append(su) - else: - super().add_successor(output_index, su) - - def successors(self, index: Optional[int] = None) -> List: - """ - Get 1-hop successor cells including control successors - - Args: - index (Optional[int]): - -1: return control successors - None: return all successors including index - >0 : return output SUs at output index - - Returns: - cells: List[ScheduleUnit] - """ - if isinstance(index, int): - if index == -1: - return copy.copy*self._ctrl_successors - if index >= len(self._outputs): - raise RuntimeError( - f"Get the output out of range ({index} >= {len(self._outputs)}" - ) - return copy.copy(self._successors[index]) - elif index is None: - successors = list() - for post_cells in self._successors: - successors += post_cells - successors += self._ctrl_successors - return successors - else: - raise TypeError("Expected index to be None or int") - - def __repr__(self): - su_inputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.inputs()] - su_outputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.outputs()] - dscp = f'SU({self.stype}, nodes={len(self.nodes())})-dev{self.device}: {su_inputs} -> {su_outputs}' - return dscp - - -def logic_translator(graph, su_type: SUType) -> List[ScheduleUnit]: - if not isinstance(su_type, SUType): - raise TypeError("Expected SU Type") - sus = list() - for node in graph.nodes(): - stype = su_type - if isinstance(node, IRCommunication): - stype = SUType.Adapter - devices = node.device - for device in devices: - su = ScheduleUnit([node], graph, device, stype) - sus.append(su) - return sus diff --git a/cube/tschedule/suseq.py b/cube/tschedule/suseq.py deleted file mode 100644 index 379825c7..00000000 --- a/cube/tschedule/suseq.py +++ /dev/null @@ -1,211 +0,0 @@ -from typing import List, Any, Optional -import copy - -from cube.graph.ir_cten import IRCell, IRTensor -from cube.tschedule.su import ScheduleUnit - - -class SUSequence(IRCell): - - def __init__(self, sus: List[ScheduleUnit]): - - if not all([isinstance(su, ScheduleUnit) for su in sus]): - raise TypeError( - f"Expected a list of ScheduleUnits, but got {type(sus)}" - ) - - super().__init__( - name = 'SU', - signature = 'None', - input_length = 0, - output_length = 0 - ) - self.sequence = sus - self.reset_dependency() - - def reset_dependency(self): - """ - Reset the node dataflow dependency - """ - # set node predecessors and successors - for src_idx in range(len(self.sequence)): - src_cell = self.sequence[src_idx] - src_cell._successors = [ - list() for _ in range(len(src_cell.outputs())) - ] - for dst_idx in range(src_idx + 1, len(self.sequence)): - dst_su = self.sequence[dst_idx] - dst_su._predecessors = [ - list() for _ in range(len(dst_su.inputs())) - ] - for tensor in src_cell.outputs(): - if isinstance(tensor, IRTensor): - if tensor in dst_su.inputs(): - src_output_idx = src_cell.outputs().index(tensor) - src_cell.add_successor(src_output_idx, dst_su) - dst_input_idx = dst_su.inputs().index(tensor) - dst_su.add_predecessor(dst_input_idx, src_cell) - - def __len__(self): - return len(self.sequence) - - def sus(self, index: Optional[int] = None): - """ - Return ScheduleUnit - - Args: - - """ - if isinstance(index, int): - if index >= len(self.sequence): - raise RuntimeError( - f"Get node out of range ({index} >= {len(self.sequence)})" - ) - return self.sequence[index] - elif index is None: - return copy.copy(self.sequence) - else: - raise TypeError("Expected index to be None or int") - - def happen_before(self, su1, su2): - """ - Check if the su1 -> (happened before) su2 - - Returns: - Boolean - """ - if not isinstance(su1, ScheduleUnit) or \ - not isinstance(su2, ScheduleUnit): - raise TypeError("Expected su to be an ScheduleUnit") - if su2 in su1.successors(): - return True - else: - for succ_su in su1.successors(): - if self.happen_before(succ_su, su2): - return True - return False - - def happen_after(self, su1, su2): - """ - Check if the su2 -> (happened before) su1 - - Returns: - Boolean - """ - return self.happen_before(su2, su1) - - def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: - """ - Merge two ScheduleUnit. This requires - - 1). all the nodes in one SU happens before / after - all the nodes in another SU. (Guaranteed by default - as all the operations on sequence are semantic-correct) - - 2). all the nodes in both SU are on the same device, - have same tags and they are not equal. - - 3). Deadlock-free merge. Suppose - SU1 (dev0) -> SU2 (dev1) -> SU3 (dev0) - Then merge SU1 and SU3 to SU4 will cause - deadlock on SU4 -> <- SU2 - - Note due to PyTorch limitation, - merging two forward ScheduleUnits will also cause - the merge of corresponding two backward ScheduleUnits. - - Returns: - if succeed: A merged ScheduleUnit. - if fail: None - """ - - if not isinstance(su1, ScheduleUnit) or \ - not isinstance(su2, ScheduleUnit): - raise TypeError("Expected SU1 and SU2 are ScheduleUnit") - if su1 not in self.sequence: - raise ValueError(f"su1: {su1} not in sequence") - if su2 not in self.sequence: - raise ValueError(f"su2: {su2} not in sequence") - - # 2) all the nodes in both SU are on the same device - if su1 == su2 or su1.stype != su2.stype: - return None - if set(su1.device) != set(su2.device): - return None - - # 3) deadlock-free merge - index_su1 = self.sequence.index(su1) - index_su2 = self.sequence.index(su2) - # make su1 happen before su2 - su1, su2 = (su1, su2) if index_su1 < index_su2 else (su2, su1) - index_su1, index_su2 = min(index_su1, index_su2), max(index_su1, index_su2) - inter_sus = self.sequence[index_su1+1:index_su2] - for su in inter_sus: - if su1.happen_after(su) and su.happen_before(su2): - return None - - # merge forward su - sub_nodes = su1.nodes() + su2.nodes() - merged_su = ScheduleUnit( - sub_nodes, su1.global_graph, su1.device, su1.stype - ) - - # merge mirrored su - # mirror_su2 -> mirror_su1 - mirror_su1, mirror_su2 = su1.mirror, su2.mirror - sub_nodes = mirror_su2.nodes() + mirror_su1.nodes() - merged_mirror_su = ScheduleUnit( - sub_nodes, mirror_su1.global_graph, mirror_su1.device, mirror_su1.stype - ) - - # set mirror - merged_su.set_mirror(merged_mirror_su) - merged_mirror_su.set_mirror(merged_su) - - # replace - self.sequence[index_su1] = merged_su - self.sequence.remove(su2) - if mirror_su1 in self.sequence and mirror_su2 in self.sequence: - index_mirror_su2 = self.sequence.index(mirror_su2) - self.sequence[index_mirror_su2] = merged_mirror_su - self.sequence.remove(mirror_su1) - - # TODO: optimize: reset dependency - self.reset_dependency() - return merged_su - - def add_flow(self, su1, su2): - """ - Add control flow dependency su1 -> su2 - """ - if not isinstance(su1, ScheduleUnit) or not isinstance(su2, ScheduleUnit): - raise TypeError("Expected both SU1 and SU2 are ScheduleUnit") - su1.add_successors(-1, su2) - su2.add_predecessors(-1, su1) - - def is_correct(self): - """ - Check whether sequence - satisfies the sequential consistency model - """ - - for index, su in enumerate(self.sequence): - for pre_su in su.predecessors(): - # find the pre-su not appear in sequence - if not pre_su in self.sequence: - return False - pre_idx = self.sequence.index(pre_su) - # violate sequential consistency model - if pre_idx >= index: - return False - return True - - def __repr__(self): - dscp = f'ScheduleSeq (len={len(self)}):\n' - for node in self.sequence: - succ_node_ids = [None] * len(node.outputs()) - for out_idx in range(len(node.outputs())): - node_list = [snode._id for snode in node.successors(out_idx)] - succ_node_ids[out_idx] = node_list - dscp += f"\n{node._id}: {node} -> su id {succ_node_ids}\n" - return dscp From 5d12615526a6a6e5710d9205d439a2eefbb01806 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Oct 2021 14:26:34 +0800 Subject: [PATCH 0234/1892] select / merge su --- cube/schedule/adapter/select.py | 2 + cube/schedule/su.py | 76 +++++++++++++++++++++++++++++++-- cube/schedule/translator.py | 54 ++++++++++++++--------- 3 files changed, 108 insertions(+), 24 deletions(-) diff --git a/cube/schedule/adapter/select.py b/cube/schedule/adapter/select.py index 5a19d855..120f87e8 100644 --- a/cube/schedule/adapter/select.py +++ b/cube/schedule/adapter/select.py @@ -63,6 +63,8 @@ def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor for dim, dim_indices in enumerate(indices): if len(dim_indices) != 1: if self._merge_axis is not None: + print("src: ", src_tensors) + print("dst: ", dst_tensors) raise NotImplementedError("Only support merge on one axis") self._merge_axis = dim if self._merge_axis is None: diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 25f7313d..4134cc0d 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -56,8 +56,9 @@ def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): for idx, output in enumerate(outputs): self.set_output(idx, output) - # each input is associated with - # send adapters and recv adapters (send + recv) + # each input is associated with a reshape (merge) adatpers and + # a couple of send adapters and recv adapters (send + recv) + self._merge_adapters: List[ScheduleUnit] = [None] * len(inputs) self._send_in_adapters: List[List[ScheduleUnit]] = [ list() for _ in range(len(inputs)) ] @@ -65,8 +66,9 @@ def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): list() for _ in range(len(inputs)) ] - # each output is associated with - # send adapters and recv adapters (send + recv) + # each input is associated with a reshape (select) adatpers and + # a couple of send adapters and recv adapters (send + recv) + self._select_adapters: List[ScheduleUnit] = [None] * len(outputs) self._send_out_adapters: List[List[ScheduleUnit]] = [ list() for _ in range(len(outputs)) ] @@ -131,6 +133,25 @@ def in_adapters(self, index: Optional[int] = None) -> List: else: raise TypeError("Expected index to be None or int") + def merge_adapters(self, index: Optional[int] = None) -> List: + """ + Get select adapter for the input tensor at index + + Returns: + Union[ScheduleUnit, List[ScheduleUnit]] + """ + if isinstance(index, int): + if index >= len(self._inputs): + raise RuntimeError( + f"Get index out of range ({index} >= {len(self._inputs)})" + ) + select_adapter = self._merge_adapters[index] + return select_adapter + elif index is None: + return copy.copy(self._merge_adapters) + else: + raise TypeError("Expected index to be None or int") + def out_adapters(self, index: Optional[int] = None) -> Tuple[List, List]: """ Get adapter for the output tensor at index @@ -158,6 +179,25 @@ def out_adapters(self, index: Optional[int] = None) -> Tuple[List, List]: else: raise TypeError("Expected index to be None or int") + def select_adapters(self, index: Optional[int] = None) -> List: + """ + Get select adapter for the input tensor at index + + Returns: + Union[ScheduleUnit, List[ScheduleUnit]] + """ + if isinstance(index, int): + if index >= len(self._outputs): + raise RuntimeError( + f"Get index out of range ({index} >= {len(self._outputs)})" + ) + select_adapter = self._select_adapters[index] + return select_adapter + elif index is None: + return copy.copy(self._select_adapters) + else: + raise TypeError("Expected index to be None or int") + def _clear_adapters(self): """ Clear all adapters for this SU @@ -193,6 +233,20 @@ def _add_in_adapter(self, index: int, send_adapter, recv_adapter): self._send_in_adapters[index].append(send_adapter) self._recv_in_adapters[index].append(recv_adapter) + def _set_merge_adapter(self, index: int, merge_adapter): + """ + Set adapters to the input tensor of this SU + + Args: + index (int): the input index + merge_adapter (ScheduleUnit) + """ + if index >= len(self._inputs): + raise ValueError(f"index {index} out of range {len(self._inputs)}") + if not isinstance(merge_adapter, ScheduleUnit): + raise TypeError("Expected merge adapter to be ScheduleUnit") + self._merge_adapters[index] = merge_adapter + def _add_out_adapter(self, index: int, send_adapter, recv_adapter): """ Add adapters to the output tensor of this SU @@ -211,6 +265,20 @@ def _add_out_adapter(self, index: int, send_adapter, recv_adapter): self._send_out_adapters[index].append(send_adapter) self._recv_out_adapters[index].append(recv_adapter) + def _set_select_adapter(self, index: int, select_adapter): + """ + Set adapters to the output tensor of this SU + + Args: + index (int): the output index + select_adapter (ScheduleUnit) + """ + if index >= len(self._outputs): + raise ValueError(f"index {index} out of range {len(self._inputs)}") + if not isinstance(select_adapter, ScheduleUnit): + raise TypeError("Expected merge adapter to be ScheduleUnit") + self._select_adapters[index] = select_adapter + def nodes(self, index: Optional[int] = None): """ Get node at position index diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index 07d597d5..113f59d2 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -5,7 +5,6 @@ Schedule Units, and then add Adapter ScheduleUnit """ from typing import List -from numpy.lib.arraysetops import isin import torch from cube.ir.cten import IRCell, IRTensor @@ -129,13 +128,13 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: if not isinstance(input, IRTensor): continue pre_sus = su.predecessors(in_idx) - select_tensors = list() + tensor_segments = list() for pre_su in pre_sus: for out_idx, output in enumerate(pre_su.outputs()): if output.overlap(input): sub_tensor = input.common(output) if sub_tensor != input: - select_tensors.append(sub_tensor) + tensor_segments.append(sub_tensor) send_op = IRCommunication( send_tensors=[sub_tensor], send_ranks = [-1] @@ -149,34 +148,49 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: recv_su = ScheduleUnit([recv_op], SUType.Adapter, name='recv') su._add_in_adapter(in_idx, send_su, recv_su) pre_su._add_out_adapter(out_idx, send_su, recv_su) - # TODO: add adapter for select - if len(select_tensors) != 0: - select_op = IRTensorReshape( - src_tensors=[input], dst_tensors=select_tensors + # add adapter for merge + if len(tensor_segments) != 0: + merge_op = IRTensorReshape( + src_tensors=tensor_segments, dst_tensors=[input] ) - select_su = ScheduleUnit([select_op], SUType.Adapter, name='select') - - # TODO: add adapter for merge + merge_su = ScheduleUnit([merge_op], SUType.Adapter, name='merge') + su._set_merge_adapter(in_idx, merge_su) + + # add adapter for select for out_idx, output in enumerate(su.outputs()): if not isinstance(output, IRTensor): continue - merge_tensors = list() - for send_adapters, recv_adapters in su.out_adapters(out_idx): - for recv_adapter in recv_adapters: - for tensor in recv_adapter.nodes(0).recv_tensors: - if tensor != output: - merge_tensors.append(tensor) - merge_op = IRTensorReshape( - src_tensors=merge_tensors, dst_tensors=output - ) - merge_su = ScheduleUnit([merge_op], SUType.Adapter, name='merge') + select_tensors = list() + send_adapters, recv_adapters = su.out_adapters(out_idx) + for send_adapter in send_adapters: + for tensor in send_adapter.nodes(0).send_tensors: + if tensor != output: + select_tensors.append(tensor) + if len(select_tensors) != 0: + select_op = IRTensorReshape( + src_tensors=[output], dst_tensors=select_tensors + ) + select_su = ScheduleUnit( + [select_op], SUType.Adapter, name='select' + ) + su._set_select_adapter(out_idx, select_su) sus_with_adapter = list() for su in sus: + # send + recv + merge for idx in range(len(su.inputs())): + merge_su = su.merge_adapters(idx) + if merge_su: + sus_with_adapter.append(merge_su) send_adapters, recv_adapters = su.in_adapters(idx) for send_su, recv_su in zip(send_adapters, recv_adapters): sus_with_adapter.append(send_su) sus_with_adapter.append(recv_su) + # excute sus_with_adapter.append(su) + # select + for idx in range(len(su.outputs())): + select_su = su.select_adapters(idx) + if select_su: + sus_with_adapter.append(select_su) return sus_with_adapter From 58303abbc8dcea9358348433e4d23e46a424727a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 28 Oct 2021 16:36:31 +0800 Subject: [PATCH 0235/1892] fix graph copy on break id --- cube/graph/graph.py | 25 +++++++++++++++++++------ cube/graph/tensor.py | 35 +++++++++++------------------------ 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5ad3dca4..26856b8b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -117,19 +117,32 @@ def copy(self, reverse=False): """ Copy the graph but re-new the intermediate tensor """ - new_tensors = dict() # old graph tensor._id -> new tensor + # old graph tensor.parent._id -> new full tensor + new_full_tensors = dict() def _renew(val: Any): if not isinstance(val, IRTensor): return val + elif isinstance(val, IRFullTensor): + raise RuntimeError("Found Full Tensor") # parameters - if val.is_param(): + if not reverse and val.is_param(): return val # intermediate data - if val._id not in new_tensors: - tensor = val.renew() - new_tensors[val._id] = tensor - return new_tensors[val._id] + if val.parent._id not in new_full_tensors: + full_tensor = val.parent.renew() + new_full_tensors[val.parent._id] = full_tensor + else: + full_tensor = new_full_tensors[val.parent._id] + new_val = full_tensor.select( + indices=val.indices, + val_op=val.val_op, + shape=val.shape + ) + if reverse and val.is_param(): + #TODO: something strange here: id not change + new_val.name = 'grad_' + new_val.name + return new_val nodes = list() for node in self.nodes(): diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index eb2cc6bf..f3fb617d 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -25,7 +25,11 @@ def __eq__(self, other): if self.ndims != self.ndims: return False for myslicer, oslicer in zip(self.get(), other.get()): - if myslicer != oslicer: + mstart, mstop = myslicer.start, myslicer.stop + mstep = myslicer.step if myslicer.stop is not None else 1 + ostart, ostop = oslicer.start, oslicer.stop + ostep = oslicer.step if oslicer.step is not None else 1 + if mstart != ostart or mstop != ostop or mstep != ostep: return False return True return False @@ -262,13 +266,13 @@ def select(self, indices: Union[Tuple, IndexMap], val_op: Optional[Callable], sh index = self._indices.index(indices) sub_tensor = self._segments[index] if sub_tensor.val_op == val_op: + print('here') return sub_tensor - else: - sub_tensor = IRSubTensor(self, indices, val_op, shape) - self._segments.append(sub_tensor) - self._indices.append(indices) - self._val_ops.append(val_op) - return sub_tensor + sub_tensor = IRSubTensor(self, indices, val_op, shape) + self._segments.append(sub_tensor) + self._indices.append(indices) + self._val_ops.append(val_op) + return sub_tensor def overlap(self, other): """ @@ -392,23 +396,6 @@ def __copy__(self): tensor._cell = list() return tensor - def renew(self): - """ - Renew a new tensor with same name and shape, - but with a different new id - - Returns: - tensor - """ - tensor = IRSubTensor(self.parent, self.indices, self.val_op, self._shape) - new_id = tensor._id - for key in self.__dict__: - setattr(tensor, key, getattr(self, key)) - # clear attached cells - tensor._cell = list() - tensor._id = new_id - return tensor - def select(self, indices: Union[Tuple, IndexMap], val_op, shape=None): """ Select an IRSubTensor From 08693eb3baef242ed24dc156773e770c26e64727 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 28 Oct 2021 19:16:06 +0800 Subject: [PATCH 0236/1892] fix param annotation bugs --- cube/graph/graph.py | 14 ++++++-------- cube/graph/tensor.py | 35 ++++++++++++++++++++++++++--------- cube/ir/cten.py | 3 +++ 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 26856b8b..363d4bc1 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -125,23 +125,21 @@ def _renew(val: Any): return val elif isinstance(val, IRFullTensor): raise RuntimeError("Found Full Tensor") - # parameters - if not reverse and val.is_param(): + # parameters in forward + if (not reverse) and val.is_param(): return val - # intermediate data + # intermediate / gradient data if val.parent._id not in new_full_tensors: - full_tensor = val.parent.renew() - new_full_tensors[val.parent._id] = full_tensor - else: - full_tensor = new_full_tensors[val.parent._id] + new_full_tensors[val.parent._id] = val.parent.like() + full_tensor = new_full_tensors[val.parent._id] new_val = full_tensor.select( indices=val.indices, val_op=val.val_op, shape=val.shape ) if reverse and val.is_param(): - #TODO: something strange here: id not change new_val.name = 'grad_' + new_val.name + assert new_val.is_param() return new_val nodes = list() diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index f3fb617d..89b3291d 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -198,21 +198,26 @@ def __copy__(self): tensor._cell = list() return tensor - def renew(self): + def as_param(self): """ - Renew a new tensor with same name and shape, + Set the tensor as trainable parameter + """ + self.requires_grad = True + self._is_param = True + for sub_tensor in self._segments: + sub_tensor.as_param() + + def like(self): + """ + Create a new tensor with same name and shape, but with a different new id Returns: tensor """ tensor = IRFullTensor(self._shape, self.name) - new_id = tensor._id - for key in self.__dict__: - setattr(tensor, key, getattr(self, key)) - # clear attached cells - tensor._cell = list() - tensor._id = new_id + for attr in IRFullTensor._attr: + setattr(tensor, attr, getattr(self, attr)) return tensor def segments(self, index: Optional[int] = None): @@ -266,9 +271,12 @@ def select(self, indices: Union[Tuple, IndexMap], val_op: Optional[Callable], sh index = self._indices.index(indices) sub_tensor = self._segments[index] if sub_tensor.val_op == val_op: - print('here') return sub_tensor + sub_tensor = IRSubTensor(self, indices, val_op, shape) + for attr in IRFullTensor._attr: + setattr(sub_tensor, attr, getattr(self, attr)) + self._segments.append(sub_tensor) self._indices.append(indices) self._val_ops.append(val_op) @@ -396,6 +404,15 @@ def __copy__(self): tensor._cell = list() return tensor + def as_param(self): + """ + Set the tensor as trainable parameter + """ + if not self.parent.is_param(): + self.parent.as_param() + self.requires_grad = True + self._is_param = True + def select(self, indices: Union[Tuple, IndexMap], val_op, shape=None): """ Select an IRSubTensor diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 60a97454..d710efd4 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -312,6 +312,9 @@ class IRTensor: """ IRTensor serves as IRGraph edge """ + + _attr = ['name', '_is_param', 'requires_grad'] + def __init__(self, shape=None, name=None): self._id: int = IDGenerator().gen_tensor_id() From 13dd790e2253faaf890f32023aad4b9790d182df Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 29 Oct 2021 10:32:24 +0800 Subject: [PATCH 0237/1892] fix merge bug and run pipeline --- cube/codegen/codegen.py | 103 +++++++++++++++--------------- cube/runtime/collectives.py | 22 +++---- cube/schedule/__init__.py | 15 +++-- cube/schedule/graphpass.py | 20 ++++-- cube/schedule/su.py | 1 + cube/schedule/sugraph.py | 81 +++++++++++++++++++++-- examples/e2e.py | 43 ++++++------- tests/codegen/test_codegen.py | 96 ++++++++++++++++++++++++++-- tests/schedule/test_graphpass.py | 18 +++--- tests/schedule/test_translator.py | 2 +- 10 files changed, 278 insertions(+), 123 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index b79ee107..41b81adb 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -2,30 +2,33 @@ Generate Pytorch code given the model DAG and the transformation config """ from typing import List, Any +import torch +import copy from cube.ir.cten import IRTensor from cube.schedule.sugraph import SUGraph from cube.schedule.su import ScheduleUnit, SUType from cube.schedule.adapter.comm import IRCommType, IRCommunication +from cube.schedule.adapter.select import IRTensorReshape, IRReshapeType from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock -import torch -import copy - -class SScheduleCodeGen: +class ModelCodeGen: """ Generate spatial code for the model """ - def __init__(self, seq: SUGraph): - if not isinstance(seq, SUGraph): - raise TypeError("seq should be SUGraph") - self.seq = seq + def __init__(self, sugraph: SUGraph): + if not isinstance(sugraph, SUGraph): + raise TypeError("sugraph should be SUGraph") + for su in sugraph.sus(): + if len(su.device) == 0: + raise RuntimeError(f"SU: {su} is not assigned to device") + self.seq = sugraph # model full code self.init_code: List[str] = [ - '\n\n########## Generated Code ###########', + '\n\n########## Generated Model Code ###########', 'import torch', 'import cube', '', ''] # module init code self.declare_region: List[str] = list() @@ -59,10 +62,12 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # parse graph body for su in device_sus: for node in su.nodes(): + if isinstance(node, IRTensorReshape): + self.emit_reshape_call(node) if isinstance(node, IRCommunication): - self.emit_comm_call(node, su) + self.emit_comm_call(node) else: - self.emit_op_call(node, su) + self.emit_op_call(node) # emit input declaration for arg in node.inputs(): self.emit_var_declare(arg) @@ -86,7 +91,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: with FunctionBlock(func_name=name, args=input_args) as fb: fb.insert_body(forward_code) # generate output - out_names = self._forward_region_arg_names(su.outputs(), su) + out_names = self._forward_region_arg_names(su.outputs()) return_code = f"return {', '.join(out_names)}" fb.insert_body(return_code) cb.insert_body('') @@ -111,8 +116,7 @@ def emit_var_declare(self, var: Any): if isinstance(var, IRTensor): name = self.naming(var) # indicate this is a leaf tensor, should be parameter - if var.is_param(): - self.symbols.create(name) + if var.is_param()and self.symbols.create(name): code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(var.shape)}))' self.declare_region.append(code) elif isinstance(var, str): @@ -125,74 +129,68 @@ def emit_var_declare(self, var: Any): self.declare_region.append(code) return - def emit_op_call(self, node, su: ScheduleUnit): + def emit_op_call(self, node): """ Emit op forward code """ op_code = node.signature - arg_names = self._forward_region_arg_names(node.inputs(), su) + arg_names = self._forward_region_arg_names(node.inputs()) arg_region = '(' + ', '.join(arg_names) + ')' if len(node.outputs()) == 0: code = f'{op_code}{arg_region}' else: - out_names = self._forward_region_arg_names(node.outputs(), su) + out_names = self._forward_region_arg_names(node.outputs()) out_names = ', '.join(out_names) code = f'{out_names} = {op_code}{arg_region}' self.forward_region.append(code) - def emit_comm_call(self, node, su: ScheduleUnit): + def emit_comm_call(self, node): """ Emit communication code """ comm_code = node.signature - send_tensors = self._forward_region_arg_names(node.inputs(), su) + send_tensors = self._forward_region_arg_names(node.inputs()) + send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' send_ranks = node.send_ranks - recv_tensors = self._forward_region_arg_names(node.outputs(), su) + recv_tensors = self._forward_region_arg_names(node.outputs()) + recv_tensors = ', '.join(recv_tensors) recv_shapes = [tensor.shape for tensor in node.outputs()] recv_ranks = node.recv_ranks if node.comm_type == IRCommType.Send: - send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' code = f'{comm_code}({send_tensors}, {send_ranks})' elif node.comm_type == IRCommType.Recv: - recv_tensors = ', '.join(recv_tensors) code = f'{recv_tensors} = {comm_code}({recv_shapes}, {recv_ranks})' elif node.comm_type == IRCommType.SendRecv: - send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' - recv_tensors = ', '.join(recv_tensors) code = f'{recv_tensors} = {comm_code}({send_tensors}, {send_ranks}, {recv_shapes}, {recv_ranks})' else: raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") self.forward_region.append(code) - def emit_adapter_call(self, node, su: ScheduleUnit): + def emit_reshape_call(self, node): """ - Emit in-device tensor transformation call. - - Note the in-device transformation only happens on send + Emit in-device tensor select / merge call. """ - send_tensors, send_ranks = list(), list() - recv_tensors, recv_ranks = list(), list() - for node in su.nodes(): - trans_tensors, trans_ranks = node.send_tensors, node.send_ranks - send_tensors = list() - send_ranks = list() - for trans_tensor, trans_rank in zip(trans_tensors, trans_ranks): - #TODO: tensor transformation - # cross-devie send - if su.device[0] != trans_rank: - send_tensors.append(trans_tensor) - send_ranks.append(trans_rank) - trans_tensors, trans_ranks = node.recv_tensors, node.recv_ranks - for trans_tensor, trans_rank in zip(trans_tensors, trans_ranks): - # cross-devie send - if su.device[0] != trans_rank: - recv_tensors.append(trans_tensor) - recv_ranks.append(trans_rank) - - - + src_tensors = self._forward_region_arg_names(node.inputs()) + dst_tensors = self._forward_region_arg_names(node.outputs()) + # emit select + if node.ttype == IRReshapeType.Select: + src_tensor = src_tensors[0] + #TODO: relative indices + indices = node.select_indices + indices = [slicer.get() for slicer in indices] + dst_tensors = ', '.join(dst_tensors) + code = f'{dst_tensors} = {node.signature}({src_tensor}, {indices})' + self.forward_region.append(code) + elif node.ttype == IRReshapeType.Merge: + axis = node.merge_axis + src_tensor = '(' + ', '.join(src_tensors + ['']) + ')' + dst_tensor = dst_tensors[0] + code = f'{dst_tensor} = {node.signature}({src_tensor}, {axis})' + self.forward_region.append(code) + else: + raise TypeError(f"Unknown Reshape Type: {node.ttype}") - def _forward_region_arg_names(self, tensors: List[Any], su: ScheduleUnit): + def _forward_region_arg_names(self, tensors: List[Any]): """ Generate arg name list for forward region. @@ -233,8 +231,7 @@ def clear(self): self.symbols = SymbolTable() - -class TScheduleCodeGen: +class ScheduleCodeGen: def __init__(self, seq: SUGraph): if not isinstance(seq, SUGraph): @@ -242,7 +239,7 @@ def __init__(self, seq: SUGraph): self.seq = seq # model full code self.init_code: List[str] = [ - '\n\n########## Generated Code ###########', + '\n\n########## Generated Schedule Code ###########', 'import torch', 'import cube', ''] # module member name self.symbols = SymbolTable() diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 04dc9fd2..35a41f3f 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -3,7 +3,7 @@ import torch -def send(tensors, to_ranks: List[List[int]]): +def send(tensors, to_ranks: List[int]): """ send tensor to the remote devices. Each tensor can be sent to multiple devices @@ -14,28 +14,22 @@ def send(tensors, to_ranks: List[List[int]]): """ print('sending...') send_ops = list() - for tensor, ranks in zip(tensors, to_ranks): - for rank in ranks: - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, rank - ) - send_ops.append(send_op) + for tensor, rank in zip(tensors, to_ranks): + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, rank + ) + send_ops.append(send_op) reqs = torch.distributed.batch_isend_irecv(send_ops) for req in reqs: req.wait() torch.cuda.synchronize() -def recv(shapes: List[List[int]], from_ranks: List[List[int]]): +def recv(shapes: List[List[int]], from_ranks: List[int]): print('recving...') recv_ops = list() recv_tensors = list() - for shape, ranks in zip(shapes, from_ranks): - if len(ranks) != 1: - raise RuntimeError( - "Not supported for recving same tensor from multiple devices" - ) - rank = ranks[0] + for shape, rank in zip(shapes, from_ranks): tensor = torch.empty( shape, requires_grad=True, device=torch.cuda.current_device() ) diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index ad4286f0..b6690f2c 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -6,12 +6,12 @@ from cube.schedule.sugraph import SUGraph from cube.schedule.graphpass import SUGraphPass -from cube.codegen.codegen import SScheduleCodeGen, TScheduleCodeGen +from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen class SemanticModel: - def __init__(self, model: torch.nn.Module, input_shapes): + def __init__(self, model: torch.nn.Module, input_shapes, policy_fn=None): """ Create semantic model based on AI Scientist description. """ @@ -19,6 +19,8 @@ def __init__(self, model: torch.nn.Module, input_shapes): self.ir_graph = parser.convert( model, input_shapes=input_shapes ) + if policy_fn: + self.ir_graph = policy_fn(self.ir_graph, None) self._loaded_module = None def get_graph(self): @@ -119,6 +121,7 @@ def decorator(fn: Callable) -> Callable: # graph pass to remove redundant sus su_graph = SUGraphPass.remove_redundant_adapters(su_graph) su_graph = SUGraphPass.merge_small_sus(su_graph) + print(su_graph) if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() @@ -126,14 +129,14 @@ def decorator(fn: Callable) -> Callable: world_size = 1 # code generation - tgener = TScheduleCodeGen(su_graph) - sgener = SScheduleCodeGen(su_graph) + mgener = ModelCodeGen(su_graph) + sgener = ScheduleCodeGen(su_graph) for rank in range(world_size): fname = filename.format(rank) # generate spatial module code - sgener.gen(rank, outfile=fname, attach=True) + mgener.gen(rank, outfile=fname, attach=False) # generate temporal schedule code - tgener.gen( + sgener.gen( device = rank, outfile = fname, attach=True diff --git a/cube/schedule/graphpass.py b/cube/schedule/graphpass.py index cb9f7d06..90691048 100644 --- a/cube/schedule/graphpass.py +++ b/cube/schedule/graphpass.py @@ -40,11 +40,19 @@ def merge_small_sus(sugraph: SUGraph) -> SUGraph: """ Merge SU to a larger one if possible """ - merged_su = None + devices = set() for su in sugraph.sus(): - if su.stype == SUType.Forward: - if not isinstance(merged_su, ScheduleUnit): - merged_su = su - continue - merged_su = sugraph.merge(merged_su, su) + devices.update(set(su.device)) + for device in devices: + dev_sus = [su for su in sugraph.sus() if device in su.device] + merged_su = None + for su in dev_sus: + if su.stype == SUType.Forward: + if not isinstance(merged_su, ScheduleUnit): + merged_su = su + continue + merged_su = sugraph.merge(merged_su, su) + if not isinstance(merged_su, ScheduleUnit): + merged_su = su + return sugraph diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 4134cc0d..aeca5d11 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -41,6 +41,7 @@ def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): inputs = IRCell.get_inputs(nodes) inputs = [input for input in inputs if not input.is_param()] outputs = IRCell.get_outputs(nodes) + outputs = [output for output in outputs if not output.is_param()] super().__init__( name = name, signature = stype.value, diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 0abacb95..0f36f204 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -1,5 +1,7 @@ +import enum from typing import List, Optional, Union import copy +from cube import schedule from cube.ir.cten import IRCell from cube.schedule.su import SUType, ScheduleUnit @@ -98,7 +100,7 @@ def happen_before(self, su1, su2): def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: """ - Merge two ScheduleUnit. This requires + Merge two ScheduleUnit as well as their adapters. This requires 1). all the nodes in one SU happens before / after all the nodes in another SU. (Guaranteed by default @@ -121,6 +123,62 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: if fail: None """ + def _adapter_merge(first_su: ScheduleUnit, second_su: ScheduleUnit, merged_su: ScheduleUnit): + # move from first_su adapter + # print(f' 1st SU: {first_su} \n 2nd SU: {second_su} \n merged SU: {merged_su}') + for idx, input in enumerate(first_su.inputs()): + send_adapters, recv_adapters = first_su.in_adapters(idx) + merge_adapter = first_su.merge_adapters(idx) + merge_idx = merged_su.inputs().index(input) + for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): + merged_su._add_in_adapter(merge_idx, send_adapter, recv_adapter) + if merge_adapter in self.sequence: + merged_su._set_merge_adapter(merge_idx, merge_adapter) + for idx, output in enumerate(first_su.outputs()): + send_adapters, recv_adapters = first_su.out_adapters(idx) + select_adapter = first_su.select_adapters(idx) + if output in merged_su.outputs() and output not in second_su.outputs(): + merge_idx = merged_su.outputs().index(output) + for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): + merged_su._add_out_adapter(merge_idx, send_adapter, recv_adapter) + if select_adapter: + merged_su._set_select_adapter(merge_idx, select_adapter) + else: + if merge_adapter in self.sequence: + self.sequence.remove(merge_adapter) + # move from su2 adapter + for idx, input in enumerate(second_su.inputs()): + send_adapters, recv_adapters = second_su.in_adapters(idx) + merge_adapter = second_su.merge_adapters(idx) + if input in merged_su.inputs() and input not in first_su.inputs(): + merge_idx = merged_su.inputs().index(input) + for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): + merged_su._add_in_adapter(merge_idx, send_adapter, recv_adapter) + if merge_adapter: + merged_su._set_merge_adapter(merge_idx, merge_adapter) + else: + for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): + # print(f'removing: {send_adapter}') + # print(f'removing: {recv_adapter}') + if send_adapter in self.sequence: + self.sequence.remove(send_adapter) + if recv_adapter in self.sequence: + self.sequence.remove(recv_adapter) + if merge_adapter in self.sequence: + self.sequence.remove(merge_adapter) + for idx, output in enumerate(second_su.outputs()): + send_adapters, recv_adapters = second_su.out_adapters(idx) + select_adapter = second_su.select_adapters(idx) + if output in merged_su.outputs(): + merge_idx = merged_su.outputs().index(output) + for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): + merged_su._add_out_adapter(merge_idx, send_adapter, recv_adapter) + if select_adapter: + merged_su._set_select_adapter(merge_idx, select_adapter) + else: + if select_adapter: + self.sequence.remove(select_adapter) + if not isinstance(su1, ScheduleUnit) or \ not isinstance(su2, ScheduleUnit): raise TypeError("Expected SU1 and SU2 are ScheduleUnit") @@ -134,8 +192,7 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: return None if su1.device != su2.device: return None - - #TODO: GraphPass on remove redundant adapter also need TODO + if su1.stype == SUType.Adapter: raise NotImplementedError("Not supported for merging Adapter") @@ -146,13 +203,19 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: index_su1, index_su2 = min(index_su1, index_su2), max(index_su1, index_su2) inter_sus = self.sequence[index_su1+1:index_su2] for su in inter_sus: - if self.happen_before(su1, su) and self.happen_before(su, su2): + # in theory the below condition satisfies merge, but it may + # break the topo order + # e.g., su1 -> adapter1 ,....., adapter2 -> su2 + # if self.happen_before(su1, su) and self.happen_before(su, su2): + # to keep topo order: + if self.happen_before(su, su2): return None # merge forward su sub_nodes = su1.nodes() + su2.nodes() merged_su = ScheduleUnit(sub_nodes, su1.stype) merged_su.device = su1.device + _adapter_merge(su1, su2, merged_su) # merge mirrored su # mirror_su2 -> mirror_su1 @@ -163,6 +226,7 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: sub_nodes = mirror_su2.nodes() + mirror_su1.nodes() merged_mirror_su = ScheduleUnit(sub_nodes, mirror_su1.stype) merged_mirror_su.device = mirror_su1.device + _adapter_merge(mirror_su2, mirror_su1, merged_mirror_su) # set mirror merged_su.set_mirror(merged_mirror_su) merged_mirror_su.set_mirror(merged_su) @@ -204,8 +268,7 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): """ Assign SU to devices. - The assignment will automatically trigger the generation of - Adapter SU. + The assignment will automatically set device of its Adapter SU. 1) if ranks has multiple int, then the su is copied as the same SU will be happened redundantly on multiple devices. @@ -245,18 +308,24 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): # set adapter device for the input for idx in range(len(su.inputs())): send_adapters, recv_adapters = su.in_adapters(idx) + merge_adapter = su.merge_adapters(idx) for send_adapter in send_adapters: send_adapter.nodes(0).send_ranks = [ranks[0],] for recv_adapter in recv_adapters: recv_adapter.device = ranks + if merge_adapter is not None: + merge_adapter.device = ranks # set adapter device for the output for idx in range(len(su.outputs())): send_adapters, recv_adapters = su.out_adapters(idx) + select_adapter = su.select_adapters(idx) for send_adapter in send_adapters: send_adapter.device = ranks for recv_adapter in recv_adapters: recv_adapter.nodes(0).recv_ranks = [ranks[0],] + if select_adapter is not None: + select_adapter.device = ranks return True def set_order(self, seq: List[ScheduleUnit]): diff --git a/examples/e2e.py b/examples/e2e.py index 5b452b6f..53013bbd 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -18,21 +18,16 @@ from cube.schedule.su import SUType -def spolicy(ir_graph): - for nid, node in enumerate(ir_graph.nodes()): - if nid < 3: - node.device = 0 - else: - node.device = 1 +def trans_policy(ir_graph, resource): return ir_graph -def tpolicy(seq): +def schedule_policy(sugraph, resource): # put to micro-batch forward-backward sequence fb_op_seqs = list() - for su in seq.sus(): + for su in sugraph.sus(): for fb_seq in fb_op_seqs: for ksu in fb_seq[::-1]: - if seq.happen_before(ksu, su): + if sugraph.happen_before(ksu, su): fb_seq.append(su) break else: @@ -41,17 +36,19 @@ def tpolicy(seq): else: fb_op_seqs.append([su]) - # merge to stages - for fb_seq in fb_op_seqs: - merged_su = fb_seq[0] - for su in fb_seq[1:]: - if su.stype == SUType.Backward: - continue - msu = seq.merge(merged_su, su) - merged_su = su if msu is None else msu - print(seq) - return seq - + for fb_sus in fb_op_seqs: + sugraph.assign(fb_sus[0], 0) + idx = 0 + for su in fb_sus[1:]: + if su.stype == SUType.Forward: + if idx < 3: + sugraph.assign(su, 0) + sugraph.assign(su.mirror, 0) + else: + sugraph.assign(su, 1) + sugraph.assign(su.mirror, 1) + idx += 1 + return sugraph class FakeDataLoader: @@ -97,14 +94,14 @@ def train(): batch_size = 64 model = FeedForward(dim=1024) - model = cube.sschedule.schedule( + model = cube.schedule.SemanticModel( model, input_shapes=([batch_size,1024],), - policy_fn=spolicy + policy_fn=trans_policy ) dataloader = FakeDataLoader(batch_size) - @cube.tschedule.schedule(model, dataloader, policy_fn=tpolicy) + @cube.schedule.schedule(model, dataloader, policy_fn=schedule_policy) def train_iter(model, dataloader): for _ in range(4): data = next(dataloader) diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index d6cf6884..672646b8 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -1,9 +1,30 @@ from cube.graph.tensor import IRFullTensor -from cube.graph.comm import IRCommunication from cube.graph.operator import IROperation from cube.graph.graph import IRGraph +from cube.schedule.pool import SchedulePool +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.schedule.translator import IRDataLoader +from cube.schedule.translator import LogicTranslator +from cube.schedule.graphpass import SUGraphPass +import torch -from cube.codegen.codegen import SScheduleCodeGen, TScheduleCodeGen +from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen + + +class FakeDataLoader: + def __init__(self, batch_size, num=640): + self.batch_size = batch_size + self.length = num + self.pos = 0 + def __iter__(self): + self.pos = 0 + return self + def __next__(self): + self.pos += 1 + if self.pos == self.length: + raise StopIteration + return torch.randn((self.batch_size, 1024)) def construct_graph(): @@ -14,6 +35,8 @@ def construct_graph(): weight2 = IRFullTensor(shape=[1024, 1024], name='weight') weight3 = IRFullTensor(shape=[1024, 1024], name='weight') bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + weight4 = IRFullTensor(shape=[1024, 1024], name='weight') + bias4 = IRFullTensor(shape=[1024, 1024], name='bias') # linear1 linear1 = IROperation( @@ -25,6 +48,7 @@ def construct_graph(): linear1.set_input(0, input) linear1.set_input(1, weight1) linear1.set_input(2, bias1) + linear1.infer_shape() # linear2 linear2 = IROperation( @@ -35,10 +59,11 @@ def construct_graph(): ) linear2.set_input(0, linear1.outputs(0)) linear2.set_input(1, weight2) + linear2.infer_shape() # linear3 linear3 = IROperation( - name='linear2', + name='linear3', signature='torch.nn.functional.linear', input_length=3, output_length=1 @@ -46,9 +71,22 @@ def construct_graph(): linear3.set_input(0, linear2.outputs(0)) linear3.set_input(1, weight3) linear3.set_input(2, bias3) + linear3.infer_shape() + + # linear4 + linear4 = IROperation( + name='linear4', + signature='torch.nn.functional.linear', + input_length=3, + output_length=1 + ) + linear4.set_input(0, linear3.outputs(0)) + linear4.set_input(1, weight4) + linear4.set_input(2, bias4) + linear4.infer_shape() graph = IRGraph( - nodes=[linear1, linear2, linear3], + nodes=[linear1, linear2, linear3, linear4], input_tensors=[input], output_tensors=linear3.outputs(), module_name="Test" @@ -58,4 +96,52 @@ def construct_graph(): def test_model_gen(): - \ No newline at end of file + SchedulePool().clear() + + graph = construct_graph() + dataloader = IRDataLoader(FakeDataLoader(batch_size=64)) + + data = next(dataloader) + output = graph(data) + output.backward() + + sus = SchedulePool().sus() + sus = LogicTranslator.gen_adapter(sus) + + sugraph = SUGraph(sus) + fsus = [su for su in sugraph.sus() if su.stype == SUType.Forward] + dsus = [su for su in sugraph.sus() if su.stype == SUType.Dataloader] + for dsu in dsus: + sugraph.assign(dsu, 0) + for idx, su in enumerate(fsus): + if su.stype == SUType.Forward: + if idx < 2: + sugraph.assign(su, 0) + sugraph.assign(su.mirror, 0) + else: + sugraph.assign(su, 1) + sugraph.assign(su.mirror, 1) + + sugraph = SUGraphPass.remove_redundant_adapters(sugraph) + sugraph = SUGraphPass.merge_small_sus(sugraph) + + print(sugraph) + + mgener = ModelCodeGen(sugraph) + tgener = ScheduleCodeGen(sugraph) + + mcode0 = mgener.gen(device = 0) + tcode0 = tgener.gen(device = 0) + print('model code on device 0: ') + print(mcode0) + print('schedule code on device 0: ') + print(tcode0) + + mcode1 = mgener.gen(device = 1) + tcode1 = tgener.gen(device = 1) + print('model code on device 1: ') + print(mcode1) + print('schedule code on device 1: ') + print(tcode1) + + assert False diff --git a/tests/schedule/test_graphpass.py b/tests/schedule/test_graphpass.py index 3955dfcf..68040497 100644 --- a/tests/schedule/test_graphpass.py +++ b/tests/schedule/test_graphpass.py @@ -98,23 +98,23 @@ def test_merge_small_sus(): # forward adatpers sus = SchedulePool().sus() + sus = LogicTranslator.gen_adapter(sus) sugraph = SUGraph(sus) for su in sugraph.sus(): - sugraph.assign(su, 0) + if su.stype != SUType.Adapter: + sugraph.assign(su, 0) print('orignal:') - for su in sugraph.sus(): - print(su) + print(sugraph) sugraph = SUGraphPass.merge_small_sus(sugraph) - print('changed:') - for su in sugraph.sus(): - print(su) + print('merged:') + print(sugraph) - assert len(sugraph.sus()) == 2 + assert len(sugraph.sus()) == 4 assert sugraph.sus(0).stype == SUType.Forward - assert sugraph.sus(1).stype == SUType.Backward - assert sugraph.sus(0).mirror == sugraph.sus(1) + assert sugraph.sus(3).stype == SUType.Backward + assert sugraph.sus(0).mirror == sugraph.sus(3) diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py index aff64ad7..734086ca 100644 --- a/tests/schedule/test_translator.py +++ b/tests/schedule/test_translator.py @@ -137,7 +137,7 @@ def test_translator_backward(): assert bsu.stype == SUType.Backward -def test_translatro_gen_adapter(): +def test_translator_gen_adapter(): SchedulePool().clear() graph = construct_graph() From f3c3dadbb02f8ed04a41fe7110e0ae27d3ea9750 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 1 Nov 2021 16:04:31 +0800 Subject: [PATCH 0238/1892] add policy for recompute / zero --- .../{ => policy}/megatron_policy.py | 0 examples/case_study/policy/recompute.py | 39 +++++++++++++++++++ examples/case_study/policy/zero.py | 26 +++++++++++++ 3 files changed, 65 insertions(+) rename examples/case_study/{ => policy}/megatron_policy.py (100%) create mode 100644 examples/case_study/policy/recompute.py create mode 100644 examples/case_study/policy/zero.py diff --git a/examples/case_study/megatron_policy.py b/examples/case_study/policy/megatron_policy.py similarity index 100% rename from examples/case_study/megatron_policy.py rename to examples/case_study/policy/megatron_policy.py diff --git a/examples/case_study/policy/recompute.py b/examples/case_study/policy/recompute.py new file mode 100644 index 00000000..d12f03cf --- /dev/null +++ b/examples/case_study/policy/recompute.py @@ -0,0 +1,39 @@ +from cube.schedule.su import SUType + + +def transformation_policy(graph, resource): + + def _recompute_op(graph, ops): + """ + PyTorch Checkpointing + """ + for op in ops[1:-1]: + for idx, output in enumerate(op.outputs()): + succ_ops = graph.successors(op, idx) + succ_ops = [ + op for op in succ_ops if op.type == SUType.Backward + ] + # remove output tensor connection between op -> [succ_ops], + # duplicate op with to connect with succ_ops + graph.incarnation(output, op, succ_ops) + + # checkpointing tensor + chunk_num = 4 + # forward ops + fops = [node for node in graph.nodes() if node.type == SUType.Forward] + chunk_size = int(len(fops) // chunk_num) + for cid in range(chunk_num): + chunk_fops = fops[chunk_size * cid, chunk_size * (cid + 1)] + _recompute_op(graph, chunk_fops) + + +def schedule_policy(sugraph, resource): + + for su in sugraph.sus(): + sugraph.assign(su, 0) + if su.is_incarnation(): + succ_sus = sugraph.successors(su) + for succ_su in succ_sus: + if sugraph.merge(su, succ_su): + break + sugraph.set_order(sugraph.random_topo_order()) diff --git a/examples/case_study/policy/zero.py b/examples/case_study/policy/zero.py new file mode 100644 index 00000000..960d024f --- /dev/null +++ b/examples/case_study/policy/zero.py @@ -0,0 +1,26 @@ +from cube.schedule.su import SUType + +def transformation_policy(graph, resource): + + for op in graph.nodes(): + if op.type == SUType.Forward: + algorithm = op.algorithms('data_parallelism') + sub_graph = graph.partition(op, algorithm, config=dict(chunk_size=resource.ngpus)) + if op.type == SUType.Optimizer: + algorithm = op.algorithms('split_axis_0') + sub_graph = graph.partition(op, algorithm, config=dict(chunk_size=resource.ngpus)) + + return graph + + +def schedule_policy(sugraph, resource): + + semantic_ops = dict() + for su in sugraph.sus(): + if su.nodes(0).semantic_ops not in semantic_ops: + semantic_ops[su.nodes(0).semantic_ops] = list() + semantic_ops[su.nodes(0).semantic_ops].append(su) + for semantic_op in semantic_ops: + for idx, su in enumerate(semantic_ops[semantic_op]): + gpu_id = idx % resource.ngpus + sugraph.assign(su, gpu_id) From 4e0cc96a3c1f1ac3c7a79616cdc0e147fc156fb5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 1 Nov 2021 19:23:55 +0800 Subject: [PATCH 0239/1892] primitives for incarnation --- examples/case_study/policy/recompute.py | 34 ++++++++++++++++++++----- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/examples/case_study/policy/recompute.py b/examples/case_study/policy/recompute.py index d12f03cf..75c71c27 100644 --- a/examples/case_study/policy/recompute.py +++ b/examples/case_study/policy/recompute.py @@ -1,21 +1,43 @@ from cube.schedule.su import SUType +def choose_input(op, input_incarnation): pass +def choose_output(op, output_incarnation): pass +def create_incar(graph, tensor_or_op): pass + def transformation_policy(graph, resource): - def _recompute_op(graph, ops): + def _recompute_ops(graph, ops): """ PyTorch Checkpointing """ - for op in ops[1:-1]: + tensors_incar = list() + ops_incar = list() + + for op in ops[:-1]: + op_incar = graph.create_incar(op) + ops_incar.append(op_incar) + for output in op.outputs(): + tensor_incar = graph.create_incar(output) + tensors_incar.append(tensor_incar) + ops_incar.choose_output(tensor_incar) + for op in ops_incar[1:]: + for input in op.outputs(): + for input_incar in input.get_incar(): + if input_incar in tensors_incar: + graph.choose_input(op, input_incar) + # else keep in memory + for op in ops[1:]: for idx, output in enumerate(op.outputs()): succ_ops = graph.successors(op, idx) succ_ops = [ op for op in succ_ops if op.type == SUType.Backward ] - # remove output tensor connection between op -> [succ_ops], - # duplicate op with to connect with succ_ops - graph.incarnation(output, op, succ_ops) + for succ_op in succ_ops: + for input in succ_op.inputs(): + for input_incar in input.get_incar(): + if input_incar in tensors_incar: + graph.choose_input(succ_op, input_incar) # checkpointing tensor chunk_num = 4 @@ -24,7 +46,7 @@ def _recompute_op(graph, ops): chunk_size = int(len(fops) // chunk_num) for cid in range(chunk_num): chunk_fops = fops[chunk_size * cid, chunk_size * (cid + 1)] - _recompute_op(graph, chunk_fops) + _recompute_ops(graph, chunk_fops) def schedule_policy(sugraph, resource): From f9a1af9ff547e4bb83f369ff2b5cb9d5c0ee5ed7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Nov 2021 09:43:53 +0800 Subject: [PATCH 0240/1892] fix typo --- examples/case_study/policy/recompute.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/case_study/policy/recompute.py b/examples/case_study/policy/recompute.py index 75c71c27..24ff55b4 100644 --- a/examples/case_study/policy/recompute.py +++ b/examples/case_study/policy/recompute.py @@ -2,7 +2,7 @@ def choose_input(op, input_incarnation): pass def choose_output(op, output_incarnation): pass -def create_incar(graph, tensor_or_op): pass +def create_incar(tensor_or_op): pass def transformation_policy(graph, resource): @@ -20,9 +20,9 @@ def _recompute_ops(graph, ops): for output in op.outputs(): tensor_incar = graph.create_incar(output) tensors_incar.append(tensor_incar) - ops_incar.choose_output(tensor_incar) + graph.choose_output(ops_incar, tensor_incar) for op in ops_incar[1:]: - for input in op.outputs(): + for input in op.inputs(): for input_incar in input.get_incar(): if input_incar in tensors_incar: graph.choose_input(op, input_incar) From 5d6c046118d52ea2e9ec1c657113a61d20e4540c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Nov 2021 18:02:26 +0800 Subject: [PATCH 0241/1892] operator function --- cube/graph/graph.py | 11 +++ cube/graph/mapping.py | 42 --------- cube/graph/operator/__init__.py | 1 + cube/graph/operator/function.py | 117 ++++++++++++++++++++++++++ cube/graph/{ => operator}/operator.py | 27 +----- cube/graph/parser/mapping.py | 44 ++++++++++ cube/graph/parser/parser.py | 34 ++++---- 7 files changed, 192 insertions(+), 84 deletions(-) delete mode 100644 cube/graph/mapping.py create mode 100644 cube/graph/operator/__init__.py create mode 100644 cube/graph/operator/function.py rename cube/graph/{ => operator}/operator.py (61%) create mode 100644 cube/graph/parser/mapping.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 363d4bc1..178a245f 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -314,6 +314,17 @@ def subgraph(self, sub_nodes: List[IRCell]): return graph + ## Primitives for policy expression ## + + def partition(self, op, op_partition_algorithm, config): + raise NotImplementedError + + def merge(self, sub_graph, target_op, op_partition_algorithm): + raise NotImplementedError + + def identity(self, input_tensor, dst_op): + raise NotImplementedError + def __repr__(self): dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs diff --git a/cube/graph/mapping.py b/cube/graph/mapping.py deleted file mode 100644 index 17a58614..00000000 --- a/cube/graph/mapping.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -Mapping of - IROperation -> cube.operator.logic.generics.GenericLogicalOp -""" - -import cube.operator.logic as logic - -class IR2LogicOp: - - @staticmethod - def map(signature: str) -> logic.GenericLogicalOp : - """ - Map the signature to GenericLogicalOp - """ - if signature in IR2LogicOp.kOpMap: - return IR2LogicOp.kOpMap[signature] - # return None - raise KeyError(f"{signature} is not supported yet") - - # functional templates - __ftemplate = lambda name: f'torch.nn.functional.{name}' - - # tensor template - __ttemplate = lambda name: f'torch.{name}' - - kOpMap = { - - __ftemplate('linear') : logic.Linear, - - __ftemplate('dropout') : logic.Dropout, - - __ftemplate('gelu') : logic.GeLU, - - __ttemplate('add') : logic.TensorAdd, - - __ttemplate('sum') : logic.TensorSum, - - # runtime collectives - 'cube.runtime.spatial.move': 'move', - - } - diff --git a/cube/graph/operator/__init__.py b/cube/graph/operator/__init__.py new file mode 100644 index 00000000..77f5f416 --- /dev/null +++ b/cube/graph/operator/__init__.py @@ -0,0 +1 @@ +from cube.graph.operator.operator import IROperation \ No newline at end of file diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py new file mode 100644 index 00000000..a86918b7 --- /dev/null +++ b/cube/graph/operator/function.py @@ -0,0 +1,117 @@ +import copy + +from cube.graph.operator import IROperation +from cube.ir.cten import IRTensor + + +class Linear(IROperation): + + def __init__(self, signature, inputs, name='linear', **kwargs): + + input, weight, bias = inputs + super().__init__( + name, signature, + input_length=3, + output_length=1 + ) + self.set_input(0, input) + self.set_input(1, weight) + self.set_input(2, bias) + + def infer_shape(self): + """ + input: [(D), M, K] + weight: [N, K] + bias: [N,] + """ + if len(self.inputs(0).shape) != 0 and len(self.inputs(1).shape) != 0: + shape = self.inputs(0).shape[:-1] + self.inputs(1).shape[:1] + self._outputs[0].shape = shape + return True + return False + + +class ElementWise(IROperation): + """ + Functions like torch.add (tensor1 + tensor2 / scaler) + """ + + def __init__(self, signature, inputs, name='elementwise', **kwargs): + + super().__init__( + name, signature, + input_length=len(inputs), + output_length=1 + ) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + + def infer_shape(self): + for input in self.inputs(): + if isinstance(input, IRTensor): + if len(input.shape) != 0: + self._outputs[0].shape = copy.copy(input.shape) + return True + return False + return False + + +class ElementWiseActivation(IROperation): + """ + functions like GELU, RELU, Dropout. + + Exclude softmax + """ + + def __init__(self, signature, inputs, name='elementwise_activation', **kwargs): + + super().__init__( + name, signature, + input_length=len(inputs), + output_length=1 + ) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + + def infer_shape(self): + for input in self.inputs(): + if isinstance(input, IRTensor): + if len(input.shape) != 0: + self._outputs[0].shape = copy.copy(input.shape) + return True + return False + return False + + +class Reduce(IROperation): + """ + functions like sum, mean, cross_entropy + """ + def __init__(self, signature, inputs, name='reduce', **kwargs): + super().__init__( + name, signature, + input_length=len(inputs), + output_length=1 + ) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + + def infer_shape(self): + self._outputs[0].shape = [1] + return True + + +class UnkownOperator(IROperation): + + def __init__(self, signature, inputs, name='unknown_op', n_output=None): + + super().__init__( + name, signature=signature, + input_length=len(inputs), + output_length=n_output, + ) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + + def infer_shape(self): + return False diff --git a/cube/graph/operator.py b/cube/graph/operator/operator.py similarity index 61% rename from cube/graph/operator.py rename to cube/graph/operator/operator.py index e323e095..c0fb0bb2 100644 --- a/cube/graph/operator.py +++ b/cube/graph/operator/operator.py @@ -1,8 +1,5 @@ -from typing import List, Union - from cube.ir.cten import IRTensor, IRCell from cube.graph.tensor import IRFullTensor -from cube.graph.mapping import IR2LogicOp __call__ = ['IROperation'] @@ -28,43 +25,25 @@ def __init__(self, outputs = [IRFullTensor() for _ in range(output_length)] for idx, output in enumerate(outputs): self.set_output(idx, output) - self.semantic = IR2LogicOp.map(self.signature) def infer_shape(self): """ Infer output value shape """ - shapes = list() - for input in self.inputs(): - if isinstance(input, IRTensor): - if input.shape is None: - return False - shapes.append(input.shape) - else: - shapes.append([1,]) - shapes = tuple(shapes) - out_shapes = self.semantic.shape_infer(*shapes) - if len(out_shapes) != len(self._outputs): - raise RuntimeError( - "The logical op semantic doesn't match with parsed op" - ) - for shape, val in zip(out_shapes, self._outputs): - if isinstance(val, IRTensor): - val.shape = shape - return True + raise NotImplementedError def __repr__(self): inputs = list() for tensor in self.inputs(): if isinstance(tensor, IRTensor): - inputs.append(f't{tensor._id}-dev{tensor.device}') + inputs.append(f't{tensor._id}') else: inputs.append(tensor) outputs = list() for tensor in self.outputs(): if isinstance(tensor, IRTensor): - outputs.append(f't{tensor._id}-dev{tensor.device}') + outputs.append(f't{tensor._id}') else: outputs.append(tensor) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py new file mode 100644 index 00000000..15f628e6 --- /dev/null +++ b/cube/graph/parser/mapping.py @@ -0,0 +1,44 @@ +""" +Mapping of + Signature -> IROperator +""" +from functools import partial + +import cube.graph.operator.function as function +from cube.graph.operator.operator import IROperation + + +class Sign2Op: + + @staticmethod + def map(signature: str) -> IROperation: + """ + Map the signature to GenericLogicalOp + """ + if signature in Sign2Op.kOpMap: + return partial(Sign2Op.kOpMap[signature], signature=signature) + else: + raise KeyError(f"{signature} is not supported yet") + # print(f'warning: {signature} is not recognized') + # return partial(function.UnkownOperator, signature=signature) + + # functional templates + __ftemplate = lambda name: f'torch.nn.functional.{name}' + + # tensor template + __ttemplate = lambda name: f'torch.{name}' + + kOpMap = { + + __ftemplate('linear') : function.Linear, + + __ftemplate('dropout') : partial(function.ElementWiseActivation, name='dropout'), + + __ftemplate('gelu') : partial(function.ElementWiseActivation, name='gelu'), + + __ttemplate('add') : partial(function.ElementWise, name='add'), + + __ttemplate('sum') : partial(function.Reduce, name='sum'), + + } + diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index b6aab7b2..fd61477b 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -6,6 +6,8 @@ from cube.graph import IROperation from cube.graph.tensor import IRFullTensor from cube.graph.parser.frame import Frame +from cube.graph.parser.mapping import Sign2Op + class ScriptNodeKind(enum.Enum): PrimGetAttr = 1 @@ -113,20 +115,19 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IROperation]: raise RuntimeError(f"Found unexpected function call node: {fnode}") fsig = frame.get_var(inputs[0].debugName()) - # create IR node - ir_node = IROperation( - signature = fsig, - name = fnode.s('name'), - input_length=len(inputs) - 1, - output_length=len(outputs), - ) - # handle inputs -- in stack with reverse order + input_vals = list() for index, input in enumerate(inputs[1:]): var_name = input.debugName() val = frame.get_var(var_name) - ir_node.set_input(index, val) - + input_vals.append(val) + + ir_node = Sign2Op.map(fsig)(inputs=input_vals, n_outputs=len(outputs)) + if len(ir_node.outputs()) != len(outputs): + raise RuntimeError( + f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" + ) + # handle outputs for index, output in enumerate(outputs): frame.add_var(output.debugName(), ir_node.outputs(index)) @@ -167,14 +168,11 @@ def parse_aten_node(node, module, frame: Frame) -> List[IROperation]: print(f"Warning: some non-tensor arguments are ommited in {fsig}") # create IR node - ir_node = IROperation( - signature = fsig, - name = fsig, - input_length = len(input_val), - output_length = len(outputs) - ) - for index, val in enumerate(input_val): - ir_node.set_input(index, val) + ir_node = Sign2Op.map(fsig)(inputs=input_val, n_output=len(outputs)) + if len(ir_node.outputs()) != len(outputs): + raise RuntimeError( + f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" + ) # handle outputs for index, output in enumerate(outputs): From cb41906eaaa1fc400061d742f9b49cd6503dbec8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Nov 2021 18:32:02 +0800 Subject: [PATCH 0242/1892] fix merge bugs --- cube/schedule/sugraph.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 0f36f204..700995e0 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -1,7 +1,5 @@ -import enum from typing import List, Optional, Union import copy -from cube import schedule from cube.ir.cten import IRCell from cube.schedule.su import SUType, ScheduleUnit @@ -208,7 +206,7 @@ def _adapter_merge(first_su: ScheduleUnit, second_su: ScheduleUnit, merged_su: S # e.g., su1 -> adapter1 ,....., adapter2 -> su2 # if self.happen_before(su1, su) and self.happen_before(su, su2): # to keep topo order: - if self.happen_before(su, su2): + if su.stype != SUType.Adapter and self.happen_before(su, su2): return None # merge forward su From b4113b60581d9ab7593084f89ce0c82fea0ff779 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Nov 2021 18:32:46 +0800 Subject: [PATCH 0243/1892] fix tests --- tests/graph/test_graph.py | 27 ++++++++------------------- tests/schedule/test_graphpass.py | 27 ++++++++------------------- tests/schedule/test_pool.py | 2 +- tests/schedule/test_su.py | 27 ++++++++------------------- tests/schedule/test_sugraph.py | 27 ++++++++------------------- tests/schedule/test_translator.py | 27 ++++++++------------------- 6 files changed, 41 insertions(+), 96 deletions(-) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 102e3524..13cee965 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -1,6 +1,6 @@ from cube.graph.graph import IRGraph from cube.graph.tensor import IRFullTensor, IRSubTensor -from cube.graph.operator import IROperation +from cube.graph.operator.function import Linear from cube.ir.cten import IRTensor @@ -14,38 +14,27 @@ def construct_model(): bias3 = IRFullTensor(shape=[1024, 1024], name='bias') # linear1 - linear1 = IROperation( + linear1 = Linear( name='linear1', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [input, weight1, bias1], ) - linear1.set_input(0, input) - linear1.set_input(1, weight1) - linear1.set_input(2, bias1) linear1.infer_shape() # linear2 - linear2 = IROperation( + linear2 = Linear( name='linear2', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear1.outputs(0), weight2, None], ) - linear2.set_input(0, linear1.outputs(0)) - linear2.set_input(1, weight2) linear2.infer_shape() # linear3 - linear3 = IROperation( - name='linear2', + linear3 = Linear( + name='linear3', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear2.outputs(0), weight3, bias3], ) - linear3.set_input(0, linear2.outputs(0)) - linear3.set_input(1, weight3) - linear3.set_input(2, bias3) linear3.infer_shape() # return [input], [ops], [output] diff --git a/tests/schedule/test_graphpass.py b/tests/schedule/test_graphpass.py index 68040497..a47ad736 100644 --- a/tests/schedule/test_graphpass.py +++ b/tests/schedule/test_graphpass.py @@ -1,5 +1,5 @@ from cube.graph.tensor import IRFullTensor -from cube.graph.operator import IROperation +from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph @@ -20,38 +20,27 @@ def construct_graph(): bias3 = IRFullTensor(shape=[1024, 1024], name='bias') # linear1 - linear1 = IROperation( + linear1 = Linear( name='linear1', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [input, weight1, bias1], ) - linear1.set_input(0, input) - linear1.set_input(1, weight1) - linear1.set_input(2, bias1) linear1.infer_shape() # linear2 - linear2 = IROperation( + linear2 = Linear( name='linear2', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear1.outputs(0), weight2, None], ) - linear2.set_input(0, linear1.outputs(0)) - linear2.set_input(1, weight2) linear2.infer_shape() # linear3 - linear3 = IROperation( - name='linear2', + linear3 = Linear( + name='linear3', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear2.outputs(0), weight3, bias3], ) - linear3.set_input(0, linear2.outputs(0)) - linear3.set_input(1, weight3) - linear3.set_input(2, bias3) linear3.infer_shape() graph = IRGraph( diff --git a/tests/schedule/test_pool.py b/tests/schedule/test_pool.py index 16bb4fb6..d375f209 100644 --- a/tests/schedule/test_pool.py +++ b/tests/schedule/test_pool.py @@ -1,7 +1,7 @@ from cube.schedule.pool import SchedulePool from cube.schedule.su import SUType, ScheduleUnit -from cube.ir.cten import IRCell, IRTensor +from cube.ir.cten import IRCell def test_schedule_pool(): diff --git a/tests/schedule/test_su.py b/tests/schedule/test_su.py index 9f3c4f23..8964e0f5 100644 --- a/tests/schedule/test_su.py +++ b/tests/schedule/test_su.py @@ -1,7 +1,7 @@ import copy from cube.graph.tensor import IRFullTensor -from cube.graph.operator import IROperation +from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph from cube.schedule.su import SUType, ScheduleUnit @@ -17,38 +17,27 @@ def construct_model(): bias3 = IRFullTensor(shape=[1024, 1024], name='bias') # linear1 - linear1 = IROperation( + linear1 = Linear( name='linear1', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [input, weight1, bias1], ) - linear1.set_input(0, input) - linear1.set_input(1, weight1) - linear1.set_input(2, bias1) linear1.infer_shape() # linear2 - linear2 = IROperation( + linear2 = Linear( name='linear2', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear1.outputs(0), weight2, None], ) - linear2.set_input(0, linear1.outputs(0)) - linear2.set_input(1, weight2) linear2.infer_shape() # linear3 - linear3 = IROperation( - name='linear2', + linear3 = Linear( + name='linear3', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear2.outputs(0), weight3, bias3], ) - linear3.set_input(0, linear2.outputs(0)) - linear3.set_input(1, weight3) - linear3.set_input(2, bias3) linear3.infer_shape() # return [input], [ops], [output] diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index b2a53c04..1d3602be 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -1,5 +1,5 @@ from cube.graph.tensor import IRFullTensor -from cube.graph.operator import IROperation +from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph from cube.schedule.su import SUType, ScheduleUnit @@ -17,38 +17,27 @@ def construct_graph(): bias3 = IRFullTensor(shape=[1024, 1024], name='bias') # linear1 - linear1 = IROperation( + linear1 = Linear( name='linear1', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [input, weight1, bias1], ) - linear1.set_input(0, input) - linear1.set_input(1, weight1) - linear1.set_input(2, bias1) linear1.infer_shape() # linear2 - linear2 = IROperation( + linear2 = Linear( name='linear2', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear1.outputs(0), weight2, None], ) - linear2.set_input(0, linear1.outputs(0)) - linear2.set_input(1, weight2) linear2.infer_shape() # linear3 - linear3 = IROperation( - name='linear2', + linear3 = Linear( + name='linear3', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear2.outputs(0), weight3, bias3], ) - linear3.set_input(0, linear2.outputs(0)) - linear3.set_input(1, weight3) - linear3.set_input(2, bias3) linear3.infer_shape() graph = IRGraph( diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py index 734086ca..a0ccb3d8 100644 --- a/tests/schedule/test_translator.py +++ b/tests/schedule/test_translator.py @@ -6,7 +6,7 @@ from cube.schedule.pool import SchedulePool from cube.graph.tensor import IRFullTensor, IRSubTensor -from cube.graph.operator import IROperation +from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph @@ -35,38 +35,27 @@ def construct_graph(): bias3 = IRFullTensor(shape=[1024, 1024], name='bias') # linear1 - linear1 = IROperation( + linear1 = Linear( name='linear1', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [input, weight1, bias1], ) - linear1.set_input(0, input) - linear1.set_input(1, weight1) - linear1.set_input(2, bias1) linear1.infer_shape() # linear2 - linear2 = IROperation( + linear2 = Linear( name='linear2', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear1.outputs(0), weight2, None], ) - linear2.set_input(0, linear1.outputs(0)) - linear2.set_input(1, weight2) linear2.infer_shape() # linear3 - linear3 = IROperation( - name='linear2', + linear3 = Linear( + name='linear3', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear2.outputs(0), weight3, bias3], ) - linear3.set_input(0, linear2.outputs(0)) - linear3.set_input(1, weight3) - linear3.set_input(2, bias3) linear3.infer_shape() graph = IRGraph( From 0770b2f582bcd10b0ac936558b3c6384191b4d5f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Nov 2021 19:42:30 +0800 Subject: [PATCH 0244/1892] init algorithm --- cube/algorithm/__init__.py | 0 cube/algorithm/factory.py | 43 +++++++++++++ cube/algorithm/generics.py | 54 ++++++++++++++++ cube/algorithm/linear.py | 129 +++++++++++++++++++++++++++++++++++++ cube/algorithm/utils.py | 34 ++++++++++ 5 files changed, 260 insertions(+) create mode 100644 cube/algorithm/__init__.py create mode 100644 cube/algorithm/factory.py create mode 100644 cube/algorithm/generics.py create mode 100644 cube/algorithm/linear.py create mode 100644 cube/algorithm/utils.py diff --git a/cube/algorithm/__init__.py b/cube/algorithm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py new file mode 100644 index 00000000..062b272f --- /dev/null +++ b/cube/algorithm/factory.py @@ -0,0 +1,43 @@ +from typing import Dict, Any + + +class DistAlgorithmFactory: + + class __DistAlgorithmFactory: + + def __init__(self): + # [LogicOp][tag] = algorithm + self._algos: Dict[Any, Dict[str, Any]] = dict() + + instance = None + + def __init__(self): + if not DistAlgorithmFactory.instance: + DistAlgorithmFactory.instance = DistAlgorithmFactory.__DistAlgorithmFactory() + + def __getattr__(self, name): + return getattr(self.instance, name) + + def register(self, op, algorithm, tag: str): + """ + Register a holistic op (class) as one of the anchors + """ + if op not in self.instance._algos: + self.instance._algos[op] = dict() + self._algos[op][tag] = algorithm + + def algorithms(self, op, tag = None): + """ + Get op tranformed algorithms + + Args: + op (IROperation): index for the holist op factory + args, kwargs: (logical) tensor inputs + + Returns: + algorithm class + """ + if tag: + return self.instance._algos[op][tag] + else: + return self.instance._algos[op].values() diff --git a/cube/algorithm/generics.py b/cube/algorithm/generics.py new file mode 100644 index 00000000..bc1a7686 --- /dev/null +++ b/cube/algorithm/generics.py @@ -0,0 +1,54 @@ +from typing import List, Dict, Optional + + +class GenericDistAlgo: + + def __init__(self, + input_shapes: List[Optional[List[int]]], + output_shapes: List[List[int]]): + """ + Layout is the community distribution requirement for input and + output logical tensors. + + Format is the dimension ordering based on the logical format, + `None` indicates the format is consistent with logical op, + otherwise should be a list of integers like torch.Tensor.permute() + on the logical required format. + + Args: + input_layout (list[Outliner, None]): outliner for each input. + The length of outliner should be equal to the number of input + output_layout (list[Outlinter, None]): outliner for each output + The length of outliner should be equal to the number of output + # TODO: + input_format (list[list[int], None]): + input dim order compare with logical definition + output_format (list[list[int], None]): + output dim order compare with logical definition + """ + + self.input_shapes = input_shapes + self.output_shapes = output_shapes + + self.logical_op = None + + def set_logic_op(self, logic_op): + """ + Set logic op. This will be automatically called when the + holistic op registered in a logical op. + """ + # if not isinstance(logic_op, GenericLogicalOp): + # raise TypeError("Require a logic op to register") + self.logical_op = logic_op + + def satisfy(self, config: Dict): + """ + Check if the config satisfies instantiation conditions + """ + raise NotImplementedError + + def instantiate(self, config: Dict): + """ + Instantiate the algorithm given the config + """ + raise NotImplementedError \ No newline at end of file diff --git a/cube/algorithm/linear.py b/cube/algorithm/linear.py new file mode 100644 index 00000000..8b15fe2d --- /dev/null +++ b/cube/algorithm/linear.py @@ -0,0 +1,129 @@ +from typing import List, Optional, Dict + +from cube.algorithm.utils import split_axis, split_value +from cube.algorithm.generics import GenericDistAlgo + +from cube.operator.logic.function import Linear + + +_kWaitDecision = None + + +class LinearDataParallel(GenericDistAlgo): + + def __init__(self, input_shapes: List[Optional[List[int]]], output_shapes: List[int]): + + super().__init__(input_shapes, output_shapes) + + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + input_shape = self.input_shapes[0] + if input_shape[0] % chunk_num != 0: + return False + return True + + def instantiate(self, inputs, outputs, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + input, weight, bias = inputs + output = outputs[0] + + ins = split_axis(input, 0, self.chunk_num) + outs = split_axis(output, 0, self.chunk_num) + + nodes = list() + for input_chunk, output_chunk in zip(ins, outs): + node = Linear( + signature='torch.nn.functional.linear', + inputs=[input_chunk, weight, bias], + name='linear' + ) + node.set_output(0, output_chunk) + nodes.append(node) + return nodes + + +class LinearColumnWeight(GenericDistAlgo): + + def __init__(self, input_shapes: List[Optional[List[int]]], output_shapes: List[int]): + + super().__init__(input_shapes, output_shapes) + + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + weight_shape = self.input_shapes[1] + if weight_shape[0] % chunk_num != 0: + return False + return True + + def instantiate(self, inputs, outputs, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + input, weight, bias = inputs + output = outputs[0] + + ws = split_axis(weight, 0, self.chunk_num) + if bias is not None: + bs = split_axis(bias, 0, self.chunk_num) + else: + bs = [None] * self.chunk_num + os = split_axis(output, 1, self.chunk_num) + + nodes = list() + for w, b, o in zip(ws, bs, os): + node = Linear( + signature='torch.nn.functional.linear', + inputs=[input, w, b], + name='linear' + ) + node.set_output(0, o) + nodes.append(node) + return nodes + + +class LinearRowWeight(GenericDistAlgo): + + def __init__(self, input_shapes: List[Optional[List[int]]], output_shapes: List[int]): + + super().__init__(input_shapes, output_shapes) + + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + weight_shape = self.input_shapes[1] + if weight_shape[1] % chunk_num != 0: + return False + return True + + def instantiate(self, inputs, outputs, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + input, weight, bias = inputs + output = outputs[0] + + ins = split_axis(input, 1, self.chunk_num) + ws = split_axis(weight, 1, self.chunk_num) + if bias: + bs = split_value(bias, self.chunk_num) + else: + bs = [None] * self.chunk_num + os = split_value(output, self.chunk_num) + + nodes = list() + for x, w, b, o in zip(ins, ws, bs, os): + node = Linear( + signature='torch.nn.functional.linear', + inputs=[x, w, b], + name='linear' + ) + node.set_output(0, o) + nodes.append(node) + return nodes diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py new file mode 100644 index 00000000..2211027d --- /dev/null +++ b/cube/algorithm/utils.py @@ -0,0 +1,34 @@ + +from cube.ir.cten import IRTensor + + +def split_axis(tensor: IRTensor, axis: int, chunk_num: int): + + if axis >= len(tensor.shape): + raise RuntimeError(f"Axis should within dims ({axis} >= {len(tensor.shape)})") + + chunk_size = int(tensor.shape[axis] // chunk_num) + + shape_slicer = list() + chunk_shape = list() + for dim, nele in enumerate(tensor.shape): + if dim != axis: + shape_slicer.append(slice(0, nele, 1)) + chunk_shape.append(nele) + else: + shape_slicer.append(None) + chunk_shape.append(chunk_size) + + sub_tensors = list() + for cid in range(chunk_size): + shape_slicer[axis] = slice(chunk_size * cid, chunk_size * (cid + 1)) + sub_tensors.append(tensor.select( + indices = tuple(shape_slicer), + val_op = None, + shape = chunk_shape + )) + return sub_tensors + + +def split_value(tensor: IRTensor, chunk_num: int): + raise NotImplementedError \ No newline at end of file From 86c45296e2cbc919861e81ad049c7dd3fce948e5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Nov 2021 19:56:31 +0800 Subject: [PATCH 0245/1892] remove original operator --- cube/algorithm/linear.py | 8 ++- cube/operator/__init__.py | 2 - cube/operator/logic/__init__.py | 1 - cube/operator/logic/function.py | 57 -------------------- cube/operator/logic/generics.py | 94 --------------------------------- 5 files changed, 7 insertions(+), 155 deletions(-) delete mode 100644 cube/operator/__init__.py delete mode 100644 cube/operator/logic/__init__.py delete mode 100644 cube/operator/logic/function.py delete mode 100644 cube/operator/logic/generics.py diff --git a/cube/algorithm/linear.py b/cube/algorithm/linear.py index 8b15fe2d..771dd118 100644 --- a/cube/algorithm/linear.py +++ b/cube/algorithm/linear.py @@ -2,8 +2,9 @@ from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo +from cube.algorithm.factory import DistAlgorithmFactory -from cube.operator.logic.function import Linear +from cube.graph.operator.function import Linear _kWaitDecision = None @@ -127,3 +128,8 @@ def instantiate(self, inputs, outputs, config: Dict): node.set_output(0, o) nodes.append(node) return nodes + + +DistAlgorithmFactory().register(Linear, LinearDataParallel, tag='data') +DistAlgorithmFactory().register(Linear, LinearColumnWeight, tag='column') +DistAlgorithmFactory().register(Linear, LinearRowWeight, tag='row') diff --git a/cube/operator/__init__.py b/cube/operator/__init__.py deleted file mode 100644 index e8250dad..00000000 --- a/cube/operator/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -import cube.operator.logic as logic - diff --git a/cube/operator/logic/__init__.py b/cube/operator/logic/__init__.py deleted file mode 100644 index 46a60fcb..00000000 --- a/cube/operator/logic/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cube.operator.logic.function import * \ No newline at end of file diff --git a/cube/operator/logic/function.py b/cube/operator/logic/function.py deleted file mode 100644 index 2090ae7b..00000000 --- a/cube/operator/logic/function.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import List, Optional - -from cube.operator.logic.generics import GenericLogicalOp -from cube.operator.logic.generics import ElementSameInputOp - - -class Linear(GenericLogicalOp): - - @staticmethod - def candidates(): - raise NotImplementedError - - @staticmethod - def shape_infer(input: List[int], - weight: List[int], - bias: Optional[List[int]] = None): - """ - input: [(D), M, K] - weight: [N, K] - bias: [N,] - """ - out_shape = list(input) - out_shape[-1] = weight[0] - return [out_shape] - - def translate(self, config): - raise NotImplementedError - - -class GeLU(ElementSameInputOp): - - def __init__(self, signature: str): - super().__init__(signature) - -class Dropout(ElementSameInputOp): - - def __init__(self, signature: str): - super().__init__(signature) - - -# ================== aten tensor op ======================== - -class TensorAdd(ElementSameInputOp): - - def __init__(self, signature: str): - super().__init__(signature) - - -class TensorSum(GenericLogicalOp): - - @staticmethod - def candidates(): - raise NotImplementedError - - @staticmethod - def shape_infer(*args): - return [[1,],] diff --git a/cube/operator/logic/generics.py b/cube/operator/logic/generics.py deleted file mode 100644 index d0a2b33a..00000000 --- a/cube/operator/logic/generics.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import List - - -class DistAlgorithmFactory: - - def __init__(self): - - self.algorithms = list() - - def __len__(self): - """ - Return the number of holistic op registered - """ - return len(self.algorithms) - - def register(self, algorithm): - """ - Register a holistic op (class) as one of the anchors - """ - self.algorithms.append(algorithm) - - def get_op(self, idx, outputs, *args, **kwargs): - """ - Get holistic operator based on idx - - The holistic operator will be initialized with shapes - - Args: - idx (int): index for the holist op factory - args, kwargs: (logical) tensor inputs - - Returns: - HolisticOp instance - """ - return self.algorithms[idx](outputs, *args, **kwargs) - - -class GenericLogicalOp: - - def __init__(self, signature: str): - """ - Generic logical operator - - signature (str): - Framework implementation signature, - e.g., 'torch.nn.functional.linear' - """ - if not isinstance(signature, str): - raise TypeError("Expect signature to be a string") - # factory - self.factory = DistAlgorithmFactory() - # torch impl signature - self.signature = signature - - @staticmethod - def shape_infer(*args, **kwargs): - """ - Output shape inference according to inputs - - Args: - Operator input - - Returns: - shapes tuple(list[int]): shape for each output tensor - """ - raise NotImplementedError("Expected a shape infer engine") - - def register_algorithm(self, algorithm): - """ - Register a distributed algoritm description - """ - self.factory.register(algorithm) - - def translate(self, config): - """ - Translate the algorithm to implementation - """ - raise NotImplementedError("Expected a tranlation for operator") - - -class ElementSameInputOp(GenericLogicalOp): - - def __init__(self): - """ - Elementwise Operator - """ - super().__init__() - - @staticmethod - def shape_infer(input: List[int], *args, **kwargs): - """ - Element-wise single input op - """ - return [input] From 117f97c8eb41f6da1933c6fc71551d207da9179c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 3 Nov 2021 15:52:48 +0800 Subject: [PATCH 0246/1892] value map for tensor abstraction --- cube/algorithm/utils.py | 2 +- cube/graph/graph.py | 2 +- cube/graph/tensor.py | 162 +++++++++++++++++++------- tests/graph/test_tensor.py | 34 +++--- tests/schedule/test_adapter_select.py | 12 +- 5 files changed, 144 insertions(+), 68 deletions(-) diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py index 2211027d..8e743500 100644 --- a/cube/algorithm/utils.py +++ b/cube/algorithm/utils.py @@ -24,7 +24,7 @@ def split_axis(tensor: IRTensor, axis: int, chunk_num: int): shape_slicer[axis] = slice(chunk_size * cid, chunk_size * (cid + 1)) sub_tensors.append(tensor.select( indices = tuple(shape_slicer), - val_op = None, + val_map = None, shape = chunk_shape )) return sub_tensors diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 178a245f..ae4d82f3 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -134,7 +134,7 @@ def _renew(val: Any): full_tensor = new_full_tensors[val.parent._id] new_val = full_tensor.select( indices=val.indices, - val_op=val.val_op, + val_map=val.val_map, shape=val.shape ) if reverse and val.is_param(): diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 89b3291d..279804b8 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -1,10 +1,10 @@ -from typing import List, Optional, Callable, Union, Tuple +from typing import List, Optional, Union, Tuple import copy from cube.ir.cten import IRTensor -__all__ = ['IndexMap', 'IRFullTensor', 'IRSubTensor'] +__all__ = ['IndexMap', 'ValueMap', 'IRFullTensor', 'IRSubTensor'] class IndexMap: @@ -171,6 +171,75 @@ def __repr__(self): return dscp +class ValueMap: + r""" + Represent the value split. + + Value is represented as a summation of several variables + + value = \sigma_{i=1}^{chunk_num} a_i + + two tensors consider as same value mapping: + they have same chunk num and share the same a_i (idx) + + Note we regard these mapping as same: + 1.0 = 0.9 (a1) + 0.1 (a2) + 1.0 = 0.4 (a1) + 0.6 (a2) + + The mapping doesn't consider what a1 really contains, but only + consider the variable (a) itself and number of variable. + """ + + def __init__(self, idx: int, chunk_num: int): + if idx >= chunk_num or idx < 0: + raise ValueError("Expected idx in [0, chunk_num)") + self._idx = idx + self._chunk_num = chunk_num + + @property + def idx(self): + return self._idx + + @property + def chunk_num(self): + return self._chunk_num + + def map(self, sub_map): + if not isinstance(sub_map, ValueMap): + raise TypeError("Expected sub_map to be ValueMap") + idx = self.idx + sub_map.idx + chunk_num = self.chunk_num - 1 + sub_map.chunk_num + return ValueMap(idx, chunk_num) + + def __eq__(self, other): + if isinstance(other, ValueMap): + if other.idx == self.idx and other.chunk_num == self.chunk_num: + return True + return False + + +def _to_index_map(indices: Union[Tuple, IndexMap]): + if not isinstance(indices, tuple) and not isinstance(indices, IndexMap): + raise TypeError("Expected indices to be tuple or IndexMap") + if isinstance(indices, tuple): + indices = IndexMap(indices) + return indices + + +def _to_value_map(val_map: Union[Tuple, ValueMap, None]): + if not isinstance(val_map, tuple) and \ + not isinstance(val_map, ValueMap) and \ + not val_map is None: + raise TypeError("Expected val_map to be tuple, IndexMap or None") + if val_map is None: + val_map = ValueMap(0, 1) + elif isinstance(val_map, tuple): + if len(val_map) != 2: + raise ValueError("Expected tuple to be (idx, chunk_num)") + val_map = ValueMap(*val_map) + return val_map + + class IRFullTensor(IRTensor): def __init__(self, shape=None, name=None): @@ -181,7 +250,7 @@ def __init__(self, shape=None, name=None): # indices: List[IndexMap] for each segment self._indices: List = list() # value op - self._val_ops: List = list() + self._val_maps: List = list() def __copy__(self): """ @@ -238,48 +307,49 @@ def indices(self, index: Optional[int] = None) -> IndexMap: else: return self._indices[index] - def val_ops(self, index: Optional[int] = None): + def val_maps(self, index: Optional[int] = None): """ - Get the SubTensors val_op + Get the SubTensors val_map """ if index is None: - return copy.copy(self._val_ops) + return copy.copy(self._val_maps) else: - return self._val_ops[index] + return self._val_maps[index] - def select(self, indices: Union[Tuple, IndexMap], val_op: Optional[Callable], shape: List[int]): + def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap, None], shape: List[int]): """ Select a SubTensor from FullTensor. Note due to implementation issue, one value in the full tensor - cannot be splitted by different val_op + cannot be splitted by different val_map Args: indices: the index of this tensor's index - val_op: how the tensor is merged with the other - sub_tensor at same location + val_map: how the tensor mapped from original value shape: the sub_tensor shape. Returns: IRSubTensor """ - if not isinstance(indices, IndexMap): - indices = IndexMap(indices) - if indices in self._indices: - index = self._indices.index(indices) - sub_tensor = self._segments[index] - if sub_tensor.val_op == val_op: + indices = _to_index_map(indices) + val_map = _to_value_map(val_map) + + for idx in range(len(self._segments)): + indmap = self._indices[idx] + valmap = self._val_maps[idx] + sub_tensor = self._segments[idx] + if indmap == indices and valmap == val_map: return sub_tensor - sub_tensor = IRSubTensor(self, indices, val_op, shape) + sub_tensor = IRSubTensor(self, indices, val_map, shape) for attr in IRFullTensor._attr: setattr(sub_tensor, attr, getattr(self, attr)) self._segments.append(sub_tensor) self._indices.append(indices) - self._val_ops.append(val_op) + self._val_maps.append(val_map) return sub_tensor def overlap(self, other): @@ -323,7 +393,7 @@ def tosub(self): slicers.append(slice(0, dim_len, 1)) sub_tensor = self.select( indices=tuple(slicers), - val_op=None, + val_map=None, shape=self.shape ) return sub_tensor @@ -335,14 +405,14 @@ def __repr__(self): class IRSubTensor(IRTensor): - def __init__(self, full_tensor: IRTensor, indices, val_op=None, shape=None): + def __init__(self, full_tensor: IRTensor, indices, val_map: Optional[ValueMap] =None, shape=None): """ Create an IRSubTensor. Args: full_tensor: the full tensor indices: index list - val_op: the value operation to merge SubTensors into one + val_map: the value operation to merge SubTensors into one """ if not isinstance(full_tensor, IRFullTensor): raise TypeError(f"Expected IRFullTensor but got {full_tensor}") @@ -352,23 +422,22 @@ def __init__(self, full_tensor: IRTensor, indices, val_op=None, shape=None): self._full_tensor = full_tensor # the index from full_tensor - if not isinstance(indices, IndexMap): - indices = IndexMap(indices) - self._index_map = indices + self._index_map = _to_index_map(indices) - # val merge op - self.val_merge_op = val_op + # val map + self._val_map = _to_value_map(val_map) def __eq__(self, other): if isinstance(other, IRFullTensor): - return self.parent == other and self.shape == other.shape + return self.parent == other and \ + self.shape == other.shape and \ + self.val_map == ValueMap(0, 1) if isinstance(other, IRSubTensor): - if self.parent != other.parent: - return False - if other.indices == self.indices and self.shape == other.shape: - return True - return False + return self.parent == other.parent and \ + self.indices == other.indices and \ + self.val_map == other.val_map and \ + self.shape == other.shape return False @property @@ -386,8 +455,8 @@ def indices(self) -> IndexMap: return copy.copy(self._index_map) @property - def val_op(self): - return self.val_merge_op + def val_map(self): + return copy.copy(self._val_map) def __copy__(self): """ @@ -397,7 +466,7 @@ def __copy__(self): Returns: tensor """ - tensor = IRSubTensor(self.parent, self.indices, self.val_op, self._shape) + tensor = IRSubTensor(self.parent, self.indices, self.val_map, self._shape) for key in self.__dict__: setattr(tensor, key, getattr(self, key)) # clear attached cells @@ -413,14 +482,14 @@ def as_param(self): self.requires_grad = True self._is_param = True - def select(self, indices: Union[Tuple, IndexMap], val_op, shape=None): + def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap, None], shape=None): """ Select an IRSubTensor Args: indices: the index of this tensor's index - val_op: the value operation to merge + val_map: the value operation to merge co-located indices of SubTensors into one shape: the sub_tensor shape @@ -428,9 +497,15 @@ def select(self, indices: Union[Tuple, IndexMap], val_op, shape=None): Returns: IRSubTensor """ - sub_map = IndexMap(indices) - index_map = self.indices.map(sub_map) - sub_tensor = self.parent.select(index_map.get(), val_op, shape) + sub_ind_map = _to_index_map(indices) + sub_val_map = _to_value_map(val_map) + + # index mapping + index_map = self.indices.map(sub_ind_map) + # value mapping + val_map = self.val_map.map(sub_val_map) + + sub_tensor = self.parent.select(index_map, val_map, shape) return sub_tensor def overlap(self, other): @@ -448,7 +523,8 @@ def overlap(self, other): elif isinstance(other, IRSubTensor): if self.parent != other.parent: return False - return self.indices.overlap(other.indices) + return self.indices.overlap(other.indices) and \ + self.val_map == other.val_map else: raise TypeError("Customized IRTensor not support") @@ -470,7 +546,7 @@ def common(self, other): indices = self.indices & other.indices sub_tensor = self.parent.select( indices = indices.get(), - val_op = self.val_op, + val_map = self.val_map, shape = indices.shape ) return sub_tensor diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py index c8ec1921..5553ac32 100644 --- a/tests/graph/test_tensor.py +++ b/tests/graph/test_tensor.py @@ -20,17 +20,17 @@ def test_full_tensor_select(): tensor = IRFullTensor(shape=[1024,1024], name='tensor') assert len(tensor.segments()) == 0 assert len(tensor.indices()) == 0 - assert len(tensor.val_ops()) == 0 + assert len(tensor.val_maps()) == 0 sub_tensor1 = tensor.select( indices = (slice(0, 1024), slice(0, 512)), - val_op = None, + val_map = None, shape = (1024, 512) ) sub_tensor2 = tensor.select( indices = (slice(0, 1024), slice(512, 1024)), - val_op = None, + val_map = None, shape = (1024, 512) ) @@ -42,7 +42,7 @@ def test_full_tensor_select(): assert len(tensor.segments()) == 2 assert len(tensor.indices()) == 2 - assert len(tensor.val_ops()) == 2 + assert len(tensor.val_maps()) == 2 def test_full_tensor_overlap(): @@ -50,18 +50,18 @@ def test_full_tensor_overlap(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( indices = (slice(0, 1024), slice(256, 1024)), - val_op = None, + val_map = None, shape = (1024, 768) ) sub_tensor2 = tensor1.select( indices = (slice(0, 1024, 2), slice(512, 1024)), - val_op = None, + val_map = None, shape = (1024, 512) ) sub_tensor3 = tensor1.select( indices = (slice(1, 1024, 2), slice(512, 1024)), - val_op = None, + val_map = None, shape = (1024, 512) ) @@ -80,17 +80,17 @@ def test_sub_tensor_select(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( indices = (slice(0, 1024), slice(512, 1024)), - val_op = None, + val_map = None, shape = (1024, 512) ) sub_tensor2 = sub_tensor1.select( indices = (slice(512, 1024), slice(0, 256)), - val_op = None, + val_map = None, shape = (512, 256) ) sub_tensor3 = sub_tensor1.select( indices = (slice(512, 1024), slice(256, 512)), - val_op = None, + val_map = None, shape = (512, 256) ) @@ -110,17 +110,17 @@ def test_sub_tensor_overlap(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( indices = (slice(0, 1024), slice(512, 1024)), - val_op = None, + val_map = None, shape = (1024, 512) ) sub_tensor2 = sub_tensor1.select( indices = (slice(512, 1024), slice(0, 256)), - val_op = None, + val_map = None, shape = (512, 256) ) sub_tensor3 = sub_tensor1.select( indices = (slice(512, 1024), slice(256, 512)), - val_op = None, + val_map = None, shape = (512, 256) ) @@ -134,22 +134,22 @@ def test_sub_tensor_common(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor_col1 = tensor1.select( indices = (slice(0, 1024), slice(0, 512)), - val_op = None, + val_map = None, shape = (1024, 512) ) sub_tensor_col2 = tensor1.select( indices = (slice(0, 1024), slice(512, 1024)), - val_op = None, + val_map = None, shape = (1024, 512) ) sub_tensor_row1 = tensor1.select( indices = (slice(0, 512), slice(0, 1024)), - val_op = None, + val_map = None, shape = (512, 1024) ) sub_tensor_row2 = tensor1.select( indices = (slice(512, 1024), slice(0, 1024)), - val_op = None, + val_map = None, shape = (512, 1024) ) diff --git a/tests/schedule/test_adapter_select.py b/tests/schedule/test_adapter_select.py index d0281ee7..0f89f0f1 100644 --- a/tests/schedule/test_adapter_select.py +++ b/tests/schedule/test_adapter_select.py @@ -11,13 +11,13 @@ def test_tensor_reshape_init(): tensor2 = tensor1.select( indices = (slice(0, 512), slice(0, 1024)), - val_op = None, + val_map = None, shape = [512, 1024] ) tensor3 = tensor1.select( indices = (slice(512, 1024), slice(0, 1024)), - val_op = None, + val_map = None, shape = [512, 1024] ) @@ -53,25 +53,25 @@ def test_adapter_select_is_identity(): tensor2 = tensor1.select( indices = (slice(512, 1024), slice(0, 1024)), - val_op = None, + val_map = None, shape = [512, 1024] ) tensor3 = tensor2.select( indices = (slice(0, 256), slice(0, 1024)), - val_op = None, + val_map = None, shape = [256, 1024] ) tensor4 = tensor1.select( indices = (slice(512, 768), slice(0, 1024)), - val_op = None, + val_map = None, shape = [256, 1024] ) tensor5 = tensor1.select( indices = (slice(512, 768), slice(0, 1024)), - val_op = None, + val_map = None, shape = [256, 1024] ) From ea8ab4c264dc15e2e0c4ebbf399a664d8ffb2b33 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 3 Nov 2021 18:36:14 +0800 Subject: [PATCH 0247/1892] test for linear data parallel --- cube/algorithm/factory.py | 12 +++++- cube/algorithm/generics.py | 2 +- cube/algorithm/linear.py | 27 +++++--------- cube/algorithm/utils.py | 19 +++++++++- tests/algorithm/test_factory.py | 15 ++++++++ tests/algorithm/test_generics.py | 21 +++++++++++ tests/algorithm/test_linear_algo.py | 58 +++++++++++++++++++++++++++++ 7 files changed, 133 insertions(+), 21 deletions(-) create mode 100644 tests/algorithm/test_factory.py create mode 100644 tests/algorithm/test_generics.py create mode 100644 tests/algorithm/test_linear_algo.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 062b272f..ba3384df 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -14,6 +14,7 @@ def __init__(self): def __init__(self): if not DistAlgorithmFactory.instance: DistAlgorithmFactory.instance = DistAlgorithmFactory.__DistAlgorithmFactory() + self._load_predefined_algos() def __getattr__(self, name): return getattr(self.instance, name) @@ -24,7 +25,7 @@ def register(self, op, algorithm, tag: str): """ if op not in self.instance._algos: self.instance._algos[op] = dict() - self._algos[op][tag] = algorithm + self.instance._algos[op][tag] = algorithm def algorithms(self, op, tag = None): """ @@ -37,7 +38,16 @@ def algorithms(self, op, tag = None): Returns: algorithm class """ + if op not in self.instance._algos: + raise KeyError("Op {op} is not registered in factory") if tag: return self.instance._algos[op][tag] else: return self.instance._algos[op].values() + + def _load_predefined_algos(self): + + import cube.algorithm.linear as linear + self.register(linear.Linear, linear.LinearDataParallel, tag='data') + self.register(linear.Linear, linear.LinearColumnWeight, tag='column') + self.register(linear.Linear, linear.LinearRowWeight, tag='row') diff --git a/cube/algorithm/generics.py b/cube/algorithm/generics.py index bc1a7686..534852c4 100644 --- a/cube/algorithm/generics.py +++ b/cube/algorithm/generics.py @@ -47,7 +47,7 @@ def satisfy(self, config: Dict): """ raise NotImplementedError - def instantiate(self, config: Dict): + def instantiate(self, node, config: Dict): """ Instantiate the algorithm given the config """ diff --git a/cube/algorithm/linear.py b/cube/algorithm/linear.py index 771dd118..97fb966c 100644 --- a/cube/algorithm/linear.py +++ b/cube/algorithm/linear.py @@ -2,8 +2,6 @@ from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo -from cube.algorithm.factory import DistAlgorithmFactory - from cube.graph.operator.function import Linear @@ -21,16 +19,16 @@ def __init__(self, input_shapes: List[Optional[List[int]]], output_shapes: List[ def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) input_shape = self.input_shapes[0] - if input_shape[0] % chunk_num != 0: + if chunk_num > 0 and input_shape[0] % chunk_num != 0: return False return True - def instantiate(self, inputs, outputs, config: Dict): + def instantiate(self, node, config: Dict): if not self.satisfy(config): raise RuntimeError("Instantiate failed. Condition not satisfied.") self.chunk_num = int(config['chunk_num']) - input, weight, bias = inputs - output = outputs[0] + input, weight, bias = node.inputs() + output = node.outputs(0) ins = split_axis(input, 0, self.chunk_num) outs = split_axis(output, 0, self.chunk_num) @@ -62,12 +60,12 @@ def satisfy(self, config: Dict): return False return True - def instantiate(self, inputs, outputs, config: Dict): + def instantiate(self, node, config: Dict): if not self.satisfy(config): raise RuntimeError("Instantiate failed. Condition not satisfied.") self.chunk_num = int(config['chunk_num']) - input, weight, bias = inputs - output = outputs[0] + input, weight, bias = node.inputs() + output = node.outputs(0) ws = split_axis(weight, 0, self.chunk_num) if bias is not None: @@ -103,12 +101,12 @@ def satisfy(self, config: Dict): return False return True - def instantiate(self, inputs, outputs, config: Dict): + def instantiate(self, node, config: Dict): if not self.satisfy(config): raise RuntimeError("Instantiate failed. Condition not satisfied.") self.chunk_num = int(config['chunk_num']) - input, weight, bias = inputs - output = outputs[0] + input, weight, bias = node.inputs() + output = node.outputs(0) ins = split_axis(input, 1, self.chunk_num) ws = split_axis(weight, 1, self.chunk_num) @@ -128,8 +126,3 @@ def instantiate(self, inputs, outputs, config: Dict): node.set_output(0, o) nodes.append(node) return nodes - - -DistAlgorithmFactory().register(Linear, LinearDataParallel, tag='data') -DistAlgorithmFactory().register(Linear, LinearColumnWeight, tag='column') -DistAlgorithmFactory().register(Linear, LinearRowWeight, tag='row') diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py index 8e743500..f21338a0 100644 --- a/cube/algorithm/utils.py +++ b/cube/algorithm/utils.py @@ -20,7 +20,7 @@ def split_axis(tensor: IRTensor, axis: int, chunk_num: int): chunk_shape.append(chunk_size) sub_tensors = list() - for cid in range(chunk_size): + for cid in range(chunk_num): shape_slicer[axis] = slice(chunk_size * cid, chunk_size * (cid + 1)) sub_tensors.append(tensor.select( indices = tuple(shape_slicer), @@ -31,4 +31,19 @@ def split_axis(tensor: IRTensor, axis: int, chunk_num: int): def split_value(tensor: IRTensor, chunk_num: int): - raise NotImplementedError \ No newline at end of file + + # full shape + shape_slicer = list() + for nele in tensor.shape: + shape_slicer.append(slice(0, nele, 1)) + + sub_tensors = list() + for idx in range(chunk_num): + sub_tensor = tensor.select( + indices = tuple(shape_slicer), + val_map = (idx, chunk_num), + shape = tensor.shape + ) + sub_tensors.append(sub_tensor) + + return sub_tensors diff --git a/tests/algorithm/test_factory.py b/tests/algorithm/test_factory.py new file mode 100644 index 00000000..b7453274 --- /dev/null +++ b/tests/algorithm/test_factory.py @@ -0,0 +1,15 @@ +from cube.algorithm.factory import DistAlgorithmFactory +from cube.graph.operator.function import Linear +from cube.algorithm.generics import GenericDistAlgo + + +def test_factory_init(): + factory = DistAlgorithmFactory() + assert len(factory.algorithms(Linear)) == 3 + + +def test_factory_tag(): + + factory = DistAlgorithmFactory() + dp = factory.algorithms(Linear, tag='data') + assert issubclass(dp, GenericDistAlgo) diff --git a/tests/algorithm/test_generics.py b/tests/algorithm/test_generics.py new file mode 100644 index 00000000..5fbe97e0 --- /dev/null +++ b/tests/algorithm/test_generics.py @@ -0,0 +1,21 @@ +from cube.algorithm.generics import GenericDistAlgo +from cube.graph.operator.function import Linear + + +def test_generic_algo_init(): + + algo = GenericDistAlgo( + input_shapes=[[1024,1024], [1024, 1024], None], + output_shapes=[[1024, 1024]] + ) + assert algo.logical_op is None + + +def test_generic_set_logic_op(): + + algo = GenericDistAlgo( + input_shapes=[[1024,1024], [1024, 1024], None], + output_shapes=[[1024, 1024]] + ) + algo.set_logic_op(Linear) + assert algo.logical_op == Linear diff --git a/tests/algorithm/test_linear_algo.py b/tests/algorithm/test_linear_algo.py new file mode 100644 index 00000000..b94eb66b --- /dev/null +++ b/tests/algorithm/test_linear_algo.py @@ -0,0 +1,58 @@ +from cube.graph.operator.function import Linear +from cube.algorithm.linear import LinearDataParallel, LinearColumnWeight, LinearRowWeight +from cube.graph.tensor import IRFullTensor + + +def test_linear_data_parallel(): + + input = IRFullTensor(shape=[1024, 1024], name='input').tosub() + weight = IRFullTensor(shape=[1000, 1024], name='weight').tosub() + bias = IRFullTensor(shape=[1000,], name='bias').tosub() + + semantic_op = Linear( + signature='torch.nn.functional.linear', + inputs = [input, weight, bias], + ) + semantic_op.infer_shape() + + input_shapes = list() + for input in semantic_op.inputs(): + input_shapes.append(input.shape) + + output_shapes = list() + for output in semantic_op.outputs(): + output_shapes.append(output.shape) + + linear_dp = LinearDataParallel( + input_shapes=input_shapes, + output_shapes=output_shapes, + ) + + assert linear_dp.chunk_num is None + + # test satisfy + assert linear_dp.satisfy(dict(chunk_num=4)) + assert not linear_dp.satisfy(dict(chunk_num=10)) + + nodes = linear_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, Linear) + + inputs = [node.inputs(0) for node in nodes] + weights = [node.inputs(1) for node in nodes] + biass = [node.inputs(2) for node in nodes] + + for x in inputs: + assert x.shape == [256, 1024] + assert not inputs[0].overlap(inputs[1]) + assert not inputs[0].overlap(inputs[2]) + assert not inputs[0].overlap(inputs[3]) + + for w in weights: + assert w.shape == [1000, 1024] + assert w == weight + + for b in biass: + assert b.shape == [1000] + assert b == bias From fdb30c4fbe1919e4827b549168b6588d54fcd570 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Nov 2021 11:19:01 +0800 Subject: [PATCH 0248/1892] test for linear parallelisms --- cube/algorithm/utils.py | 2 +- cube/graph/tensor.py | 5 +- cube/tensor/__init__.py | 0 cube/tensor/indices.py | 102 -------------- cube/tensor/logic/__init__.py | 0 cube/tensor/logic/outline.py | 200 ---------------------------- cube/tensor/logic/tensor.py | 169 ----------------------- cube/tensor/physic/__init__.py | 0 cube/tensor/physic/tensor.py | 73 ---------- cube/tensor/segment.py | 154 --------------------- tests/algorithm/test_linear_algo.py | 158 +++++++++++++++++++++- tests/tensor/test_indices.py | 71 ---------- tests/tensor/test_logical_tensor.py | 102 -------------- tests/tensor/test_outline.py | 132 ------------------ tests/tensor/test_segment.py | 107 --------------- tests/test_physic_tensor.py | 25 ---- 16 files changed, 160 insertions(+), 1140 deletions(-) delete mode 100644 cube/tensor/__init__.py delete mode 100644 cube/tensor/indices.py delete mode 100644 cube/tensor/logic/__init__.py delete mode 100644 cube/tensor/logic/outline.py delete mode 100644 cube/tensor/logic/tensor.py delete mode 100644 cube/tensor/physic/__init__.py delete mode 100644 cube/tensor/physic/tensor.py delete mode 100644 cube/tensor/segment.py delete mode 100644 tests/tensor/test_indices.py delete mode 100644 tests/tensor/test_logical_tensor.py delete mode 100644 tests/tensor/test_outline.py delete mode 100644 tests/tensor/test_segment.py delete mode 100644 tests/test_physic_tensor.py diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py index f21338a0..a552092c 100644 --- a/cube/algorithm/utils.py +++ b/cube/algorithm/utils.py @@ -21,7 +21,7 @@ def split_axis(tensor: IRTensor, axis: int, chunk_num: int): sub_tensors = list() for cid in range(chunk_num): - shape_slicer[axis] = slice(chunk_size * cid, chunk_size * (cid + 1)) + shape_slicer[axis] = slice(chunk_size * cid, chunk_size * (cid + 1), 1) sub_tensors.append(tensor.select( indices = tuple(shape_slicer), val_map = None, diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 279804b8..69fc7d07 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -217,6 +217,9 @@ def __eq__(self, other): return True return False + def __repr__(self): + return f'({self.idx}/{self.chunk_num})' + def _to_index_map(indices: Union[Tuple, IndexMap]): if not isinstance(indices, tuple) and not isinstance(indices, IndexMap): @@ -555,5 +558,5 @@ def common(self, other): return None def __repr__(self): - dscp = f'SubTensor(id={self._id}, shape={self.shape}, device={self.device})' + dscp = f'SubTensor(id={self._id}, shape={self.shape}, device={self.device}, ind={self.indices}, val={self.val_map})' return dscp \ No newline at end of file diff --git a/cube/tensor/__init__.py b/cube/tensor/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/tensor/indices.py b/cube/tensor/indices.py deleted file mode 100644 index 08862ecf..00000000 --- a/cube/tensor/indices.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -Basic structure for holding indices -> cover all the cases -""" - - -class BaseIndices: - """ - The basic primitive to gather data in the logical tensor. - - The order of indices indicate the physical storage (1-D array) order - """ - - def __init__(self, indices_list): - """ - Args: - indices_list (list[list[int],], tuple(slice(int, int),)): - indices list - """ - self.indices = tuple(indices_list) - - def ndim(self): - """ - Return dims of this indices - """ - return len(self.indices) - - def size(self): - """ - Return total number of index - """ - return len(self.indices[0]) - - def get(self): - """ - Get indexable indices - """ - return tuple(self.indices) - - def reorder(self, new_orders): - """ - Reorder the indices. - - Note this can be only called before materialize physical tensors, - or called from underlying operation that will change physical storage format - - Args: - new_orders (iteratable): order of each index - """ - new_orders = list(new_orders) - indices = list(self.indices) - for dim in range(self.ndim()): - indices[dim] = [self.indices[dim][idx] for idx in new_orders] - self.indices = tuple(indices) - - def __repr__(self): - msg = 'BaseIndices(indices_len={})'.format( - len(self.indices), self.reduction - ) - - -class TileIndices(BaseIndices): - """ - A tile is a contigonous block on the logical tensor shape, - which can be represented as the start position + offset (shape) - """ - - def __init__(self, anchor, shape): - """ - Args: - anchor (list[int]): start position of the tile - offset (list[int]): offset (shape) of the tile - """ - indices = list() - size = 1 - for start, ofst in zip(anchor, shape): - indices.append(slice(start, start + ofst)) - size *= ofst - super().__init__(tuple(indices)) - self.anchor = anchor - self.shape = shape - self.elenum = size - - def ndim(self): - """ - Return dims of this indices - """ - return len(self.indices) - - def size(self): - """ - Return total number of index - """ - return self.elenum - - def reorder(self): - raise NotImplementedError - - def __repr__(self): - msg = 'TileIndices(anchor={}, shape={})'.format( - self.anchor, self.shape - ) - return msg diff --git a/cube/tensor/logic/__init__.py b/cube/tensor/logic/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/tensor/logic/outline.py b/cube/tensor/logic/outline.py deleted file mode 100644 index b94e3984..00000000 --- a/cube/tensor/logic/outline.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -This is the description interface to describe the -segmentation requirement (restrictions). - -The description includes two parts: - - 1). restriction description on tensor segementation - - 2). Translation procedure to translate such a restriction - to the real segmentation on given logical tensor. -""" - -from cube.tensor.segment import Segment -from cube.tensor.indices import TileIndices - -import z3 - - -class BaseOutline: - """ - Basic class for declare outline - - To setup an attribute (requirement), use `inst_baseoutline.attribute_name = val` - """ - def __init__(self, solver, tensor): - if not isinstance(solver, z3.z3.Solver): - raise TypeError("Expected solver to be an z3.z3.Solver") - self.solver = solver - self.shape = tensor.shape - self.attributes = list() - - def get_attributes(self): - return self.attributes - - def add_field(self, **kwargs): - """ - Add a config field to current instance - - Usage: self.add_field(key=val): - - key is the name for the config attribute, val is the choices - - val type: - list[int]: the key can only be the options from the val; - int: the key can only be the val; - range: the key can only be the val in the range; - None: the key can be any integers - z3.z3.ArithRef: the key is aligned with another attribute - """ - for key in kwargs: - if key in self.__dict__: - raise RuntimeError("{} already in config field".format(key)) - val = kwargs[key] - if isinstance(val, list): - if not all([isinstance(arg, int) for arg in val]): - raise TypeError("{} only supports list[int] choices".format(key)) - self.__dict__[key] = z3.Int(key) - self.attributes.append(self.__dict__[key]) - self.solver.add(z3.Or([self.__dict__[key] == v for v in val])) - elif isinstance(val, int): - self.__dict__[key] = z3.Int(str(id(self))+key) - self.attributes.append(self.__dict__[key]) - self.solver.add(self.__dict__[key] == val) - elif isinstance(val, range): - self.__dict__[key] = z3.Int(str(id(self))+key) - self.attributes.append(self.__dict__[key]) - self.solver.add(self.__dict__[key] >= val[0]) - raise NotImplementedError - elif val is None: - self.__dict__[key] = z3.Int(str(id(self))+key) - self.attributes.append(self.__dict__[key]) - elif isinstance(val, z3.z3.ArithRef): - self.__dict__[key] = val - else: - raise TypeError("{} can only be int, list[int], z3.Int()".format(key)) - - def add_constraint(self, constraint): - """ - Add a constraint - """ - if not isinstance(constraint, z3.z3.BoolRef): - raise TypeError("Expected z3.z3.BoolRef constraints") - self.solver.add(constraint) - - def remove_config(self, config): - if not isinstance(config, z3.z3.ModelRef): - raise TypeError("Expected config from z3 model()") - self.solver.add(z3.Or([z3.Not(attr == config[attr]) for attr in self.attributes])) - - def interpret(self, tensor, config): - """ - Interpret to a list of segment based on the logical tensor and config - - Args: - tensor (LogicalTensor) - config (z3.z3.ModelRef) - - Returns: - list[Segment] - """ - raise NotImplementedError - - -class Full(BaseOutline): - - def __init__(self, solver, tensor): - super().__init__(solver, tensor) - - def interpret(self, tensor, config): - if not isinstance(config, z3.z3.ModelRef): - raise TypeError("Expected config from z3 model()") - indices = TileIndices([0] * len(self.shape), self.shape) - segment = tensor.select(indices, None, self.shape) - return [segment] - - -class SplitAxis(BaseOutline): - - def __init__(self, solver, tensor, axis, chunk_num, overlap): - """ - Split the logical tensor uniformly in `axis` dimension - - TODO: support split axis with non-uniform chunk size - - shape: list / tuple int - shape of input logical tensor - axis: int - which axis to split - chunk_num: options (iterable int) / None / int: - how many segments to produce - overlap: options (iterable int) / int: - overlap size on the boundary - """ - if not isinstance(axis, int): - raise RuntimeError("Expected axis to be an integer") - super().__init__(solver, tensor) - - self.axis = axis - - self.add_field(overlap=overlap) - self.add_constraint(self.overlap >= 0) - - self.add_field(chunk_num=chunk_num) - self.add_constraint(self.chunk_num >= 0) - - # TODO: change to array to adapt with non-uniform cases - self.add_field(chunk_size=None) - - # setup constraints - total_size = self.shape[self.axis] - self.add_constraint( - self.chunk_num * self.chunk_size - self.overlap * (self.chunk_num - 1) == total_size - ) - - def interpret(self, tensor, config): - if tuple(tensor.shape) != tuple(self.shape): - raise RuntimeError("The logical tensor's shape doesn't match") - if not isinstance(config, z3.z3.ModelRef): - raise TypeError("Expected config from z3 model()") - chunk_num = config[self.chunk_num].as_long() - chunk_size = config[self.chunk_size].as_long() - shape = list(self.shape) - shape[self.axis] = chunk_size - anchor = [0] * len(shape) - segments = list() - for cid in range(chunk_num): - indices = TileIndices(anchor, shape) - segment = tensor.select(indices, None, shape) - segments.append(segment) - anchor[self.axis] += shape[self.axis] - return segments - - -class SplitValue(BaseOutline): - - def __init__(self, solver, tensor, chunk_num, val_op): - """ - Split the whole tensor in value dimension. - - Each segment shape will be same with logical tensor. - - Each segment value will be modified by `val_op`. - """ - super().__init__(solver, tensor) - self.add_field(chunk_num=chunk_num) - self.add_constraint(self.chunk_num >= 1) - self.val_op = val_op - - def interpret(self, tensor, config): - if tuple(tensor.shape) != tuple(self.shape): - raise RuntimeError("The logical tensor's shape doesn't match") - chunk_num = config[self.chunk_num].as_long() - segments = list() - for cid in range(chunk_num): - indices = TileIndices([0] * len(self.shape), self.shape) - segment = tensor.select(indices, self.val_op, self.shape) - segments.append(segment) - for segment in segments: - segment.val_op_segs.append(segments) - return segments diff --git a/cube/tensor/logic/tensor.py b/cube/tensor/logic/tensor.py deleted file mode 100644 index 79283912..00000000 --- a/cube/tensor/logic/tensor.py +++ /dev/null @@ -1,169 +0,0 @@ -from cube.tensor.segment import Segment -from cube.tensor.indices import BaseIndices - -from cube.device.physic.group import DeviceGroup - -class LogicalTensor: - """ - The logical tensor - """ - - def __init__(self, shape, init_data=True): - """ - Create an empty logical tensor with no segmentations - - Args: - shape (tuple[int] or list[int]): - shape of the tensor - init_data (Boolean): - if True, init a CPU data. Otherwise no data initialized. - """ - self.shape = tuple(shape) - self.segments = list() - self.data = None - if init_data: - import torch - self.data = torch.randn(shape).detach() - - def fill(self, physical_tensors, ranks): - """ - Construct the logical tensor with physical tensors. - - Args: - physical_tensors (list[PhysicalTensor, None]): - the list length should be equal to len(self.segments) - ranks (list[list[int],]): - each segment will pair with a list of ranks - """ - if self.data is not None: - raise RuntimeError("Only allowed fill physical tensors when data is not None") - for segment, physical_tensor, ranks in zip(self.segments, physical_tensors, ranks): - segment.set_physical_tensor(physical_tensor, ranks) - - def select(self, indices, val_op, shape): - """ - Create a Segment given the indices for this logical tensor, - and the Segment will use shape. - """ - segment = Segment(self, indices, val_op, shape) - return segment - - def transform(self, segments, ranks=None): - """ - Transform the LogicalTensor with segment list. - TODO: check if this should create a new logical tensor - """ - if not (isinstance(ranks, list) and len(ranks) == len(segments)): - raise ValueError("Expected ranks to be a list with equal length of segments") - - if len(self.segments) == 0: - # setting up the placement for all segments - for sid in range(len(segments)): - segment = segments[sid] - self.add_segment(segment) - if not segment.materialized: - deploy_ranks = ranks[sid] - if not isinstance(deploy_ranks, list): - raise TypeError('Expected ranks to be list[list[int],]') - segment.placement = deploy_ranks - # deploy with the placement - for segment in self.segments: - if not segment.materialized: - segment.deploy() - #TODO: segment transformation on existing segments - else: - raise NotImplementedError - - def get_physical_tensor(self, index): - """ - Get physical tensor from the segment. - - Args: - idx: index for segment - - Returns: - torch.Tensor or None - """ - return self.get_segment(index).get_physical_tensor() - - def __len__(self): - """ - Return community number - """ - return len(self.segments) - - def __getitem__(self, key): - """ - - """ - # TODO: create new logical tensor / change layout - return self.data[key] - - def get_segment(self, idx): - """ - Get a segment using index - - Args: - idx (int): index to segment list - - Returns: - Segment - """ - return self.segments[idx] - - def add_segment(self, segment): - """ - Add a segment. - - Note adding a segment will change the segment parent logical tensor - to this tensor - """ - if not isinstance(segment, Segment): - raise TypeError("Expected a segment") - segment.logical_tensor = self - if segment in self.segments: - raise RuntimeError("Segment is already added") - self.segments.append(segment) - - def remove_segment(self, segment_or_index): - """ - Remove a community by given the segment - """ - #TODO: check whether a sync-back is needed - if isinstance(segment_or_index, Segment): - if segment not in self.segments: - raise KeyError("The segment doesn't exist") - self.segments.remove(segment) - elif isinstance(segment_or_index, int): - del self.segments[segment_or_index] - else: - raise ValueError("Expected Segment instance or index int") - - def merge_segment(self, indices, reduction_op): - """ - Merge segments for the logical tensor - - The merged segments will be placed at the end of the list. - """ - raise NotImplementedError - - def __repr__(self): - return 'LogicalTensor[{} with {} Segments]'.format( - tuple(self.shape), len(self.segments) - ) - - @staticmethod - def to_segments(*args, **kwargs): - args_segments = list() - for arg in args: - if isinstance(arg, LogicalTensor): - args_segments.append(arg.segments) - else: - args_segments.append(arg) - kwargs_segments = dict() - for key in kwargs: - if isinstance(kwargs[key], LogicalTensor): - kwargs_segments[key] = kwargs[key].segments - else: - kwargs_segments[key] = kwargs[key] - return tuple(args_segments), kwargs_segments diff --git a/cube/tensor/physic/__init__.py b/cube/tensor/physic/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/tensor/physic/tensor.py b/cube/tensor/physic/tensor.py deleted file mode 100644 index dad58a9e..00000000 --- a/cube/tensor/physic/tensor.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch - - -class PhysicTensor(torch.Tensor): - """ - Additional attributes on top of PyTorch Tensor: - - data_host_device: - Tensor data placement. The device is responsible - for managing the tensor data. - - grad_host_device: - Gradient data placement. The device is responsible - for managing the gradient data. If no grad required - for this Tensor, this option won't have impact - """ - @property - def data_host_device(self): - if not hasattr(self, '_data_host_device'): - self._data_host_device = self.device - return self._data_host_device - - @data_host_device.setter - def data_host_device(self, device): - if not isinstance(device, torch.device): - raise TypeError('Expected torch.device') - self._data_host_device = device - # inplacement movement to host device - if self.device != self.data_host_device: - self.move_(self.data_host_device) - - @property - def grad_host_device(self): - if not hasattr(self, '_grad_host_hook'): - self._grad_host_hook = None - if not hasattr(self, '_grad_host_device'): - self._grad_host_device = self._data_host_device - return self._grad_host_device - - - @grad_host_device.setter - def grad_host_device(self, device): - if not isinstance(device, torch.device): - raise TypeError('Expected torch.device') - self._grad_host_device = device - # inplacement movement to host device - if self.grad is not None: - self.grad.data = self.grad.detach().to(self.grad_host_device) - # modify hooks - if self._grad_host_hook is not None: - self._grad_host_hook.remove() - def move_grad(grad): - grad.data = grad.detach().to(self.grad_host_device) - return grad - self._grad_host_hook = self.register_hook(move_grad) - - - def move_(self, device): - """ - inplacement device movement - """ - if not isinstance(device, torch.device): - raise TypeError('Expected torch.device') - self.data = self.detach().to(device) - - def move_grad_(self, device): - """ - inplacement device move on tensor grad - """ - if not isinstance(device, torch.device): - raise TypeError('Expected torch.device') - if self.grad is not None: - self.grad.data = self.grad.detach().to(device) \ No newline at end of file diff --git a/cube/tensor/segment.py b/cube/tensor/segment.py deleted file mode 100644 index 942d2f63..00000000 --- a/cube/tensor/segment.py +++ /dev/null @@ -1,154 +0,0 @@ -from cube.device.physic.group import DeviceGroup -from cube.tensor.indices import BaseIndices - -import torch - - -class Segment: - - def __init__(self, logical_tensor, indices, val_op, shape): - """Create Segment based on the logical tensor - - Segment manages: - - 1). LogicalTensor indices mapping to Physical Tensor data storage - 2). Materialized Physical Tensor - - Attribute: - indices (tuple(slice,) or list[list[int]]): - indices of logical_tensor for this segment - val_op (ValueMapReduceOp): - deploy op to take logical value and group in for value mapping - merge op to take mapped value and group in for value reduction - """ - if not isinstance(indices, BaseIndices): - raise TypeError("Expected indices to be BaseIndices") - - # logical tensor - self.logical_tensor = logical_tensor - - # segment info - self.indices = indices - self.shape = tuple(shape) - - # val ops - self.val_ops = list() - self.val_op_segs = list() - self.add_val_op(val_op) - - # physical tensor (the PyTorch Tensor) - self.physical_tensor = None - - # deploy information - self.placement = list() - self.group = None - self.materialized = False - - def deploy(self, ranks=None): - """deploy (materialize) to physical tensors - - Materialize physical tensors for this community and spread out - based on the given device list. - - This offers policy module an interface to decide which devices - to spread. - - Argument: - ranks (list[int] or None): - if rank id list: deploy based on this list - if None: deploy based on setted self.placement - value_map_op (callable): - takes the tensor, rank, world_size, - return a new tensor - """ - if isinstance(ranks, list): - self.placement = ranks - elif ranks is None and self.placement is None: - raise TypeError("Expected self.placement when ranks is None") - - #TODO: remove this constraints - if len(self.val_ops) > 0 and len(self.placement) > 1: - raise RuntimeError("Currently segment with val_ops only allows to deploy on one rank") - - rank = DeviceGroup().rank - self.group = DeviceGroup().get_group(self.placement) - - # set physical tensors - if rank in self.placement: - if self.logical_tensor.data is None: - raise RuntimeError("Try deploying a segment from a logical tensor without data") - # select from logical data - self.physical_tensor = torch.empty(tuple(self.shape), device='cuda') - self.physical_tensor.copy_( - self.logical_tensor.data[self.indices.get()].reshape(self.shape) - ) - - # go through val_op - for val_op, segs in zip(self.val_ops, self.val_op_segs): - if len(segs) == 0: - raise RuntimeError("Missing segments for val op") - op_ranks = [seg.placement[0] for seg in segs] - group = DeviceGroup().get_group(op_ranks) - if rank in self.placement: - self.physical_tensor.data = val_op.map(self.physical_tensor, group) - - self.materialized = True - - def recover(self, reduction_op): - """ - Recover the deployed physical tensors by reduction operation - - Each rank can call this even there is no physical tensor on it. - - Args: - reduction_op (callable): - inplacement update on physical tensor - - Returns: - None. The physical tensor will be updated to match logical data - """ - if self.materialized: - if self.physical_tensor is not None: - reduction_op(self.physical_tensor, group=self.group) - else: - raise RuntimeError("The Segment has not been materialized") - - def add_val_op(self, val_op): - """ - Append val_op to the end - """ - if val_op is not None: - if not (callable(val_op.map) and callable(val_op.reduce)): - raise TypeError("Expected val_op to be ValMapReudceOp") - self.val_ops.append(val_op) - - def get_physical_tensor(self): - """Get physical tensor if materialized - - Returns: - PhysicalTensor (if materialized) - """ - if self.materialized: - return self.physical_tensor - else: - raise RuntimeError("The Segment has not been materialized") - - def set_physical_tensor(self, physical_tensor, ranks): - if self.materialized: - raise RuntimeError("Setting physical tensors to a materialized community") - if not isinstance(ranks, list): - raise TypeError("ranks: Expected a list[int]") - if physical_tensor is not None: - if list(physical_tensor.size()) != list(self.shape): - raise RuntimeError( - "Trying to set a community where physical tensor shape " - "doesn't match with segment shape") - self.physical_tensor = physical_tensor - self.group = DeviceGroup().get_group(ranks) - self.materialized = True - - def __repr__(self): - return 'Segment(Indices: {} | Materialized: {})'.format( - self.indices, self.materialized - ) - \ No newline at end of file diff --git a/tests/algorithm/test_linear_algo.py b/tests/algorithm/test_linear_algo.py index b94eb66b..861445c0 100644 --- a/tests/algorithm/test_linear_algo.py +++ b/tests/algorithm/test_linear_algo.py @@ -1,6 +1,8 @@ from cube.graph.operator.function import Linear -from cube.algorithm.linear import LinearDataParallel, LinearColumnWeight, LinearRowWeight -from cube.graph.tensor import IRFullTensor +from cube.algorithm.linear import LinearDataParallel +from cube.algorithm.linear import LinearColumnWeight +from cube.algorithm.linear import LinearRowWeight +from cube.graph.tensor import IRFullTensor, ValueMap def test_linear_data_parallel(): @@ -40,11 +42,21 @@ def test_linear_data_parallel(): assert isinstance(node, Linear) inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) weights = [node.inputs(1) for node in nodes] + print('weights:') + for weight in weights: + print(weight) biass = [node.inputs(2) for node in nodes] + print('bias:') + for bias in biass: + print(bias) - for x in inputs: + for idx, x in enumerate(inputs): assert x.shape == [256, 1024] + assert x.indices.get()[0] == slice(256 * idx, 256 * (idx + 1), 1) assert not inputs[0].overlap(inputs[1]) assert not inputs[0].overlap(inputs[2]) assert not inputs[0].overlap(inputs[3]) @@ -56,3 +68,143 @@ def test_linear_data_parallel(): for b in biass: assert b.shape == [1000] assert b == bias + + +def test_linear_column_weight(): + input = IRFullTensor(shape=[1024, 1024], name='input').tosub() + weight = IRFullTensor(shape=[1000, 1024], name='weight').tosub() + bias = IRFullTensor(shape=[1000,], name='bias').tosub() + + semantic_op = Linear( + signature='torch.nn.functional.linear', + inputs = [input, weight, bias], + ) + semantic_op.infer_shape() + + input_shapes = list() + for input in semantic_op.inputs(): + input_shapes.append(input.shape) + + output_shapes = list() + for output in semantic_op.outputs(): + output_shapes.append(output.shape) + + linear_col_weight = LinearColumnWeight( + input_shapes=input_shapes, + output_shapes=output_shapes, + ) + + # test satisfy + assert linear_col_weight.satisfy(dict(chunk_num=4)) + assert linear_col_weight.satisfy(dict(chunk_num=10)) + assert not linear_col_weight.satisfy(dict(chunk_num=12)) + + nodes = linear_col_weight.instantiate(semantic_op, config=dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, Linear) + + inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) + weights = [node.inputs(1) for node in nodes] + print('weights:') + for weight in weights: + print(weight) + biass = [node.inputs(2) for node in nodes] + print('bias:') + for bias in biass: + print(bias) + outputs = [node.outputs(0) for node in nodes] + print('output:') + for output in outputs: + print(output) + + for x in inputs: + assert x == input + + for idx, w in enumerate(weights): + assert w.shape == [250, 1024] + assert w.indices.get()[0] == slice(250 * idx, 250 * (idx + 1), 1) + + for idx, b in enumerate(biass): + assert b.shape == [250] + assert b.indices.get() == (slice(250 * idx, 250 * (idx + 1), 1),) + + for idx, output in enumerate(outputs): + assert output.shape == [1024, 250] + assert output.indices.get()[0] == slice(0, 1024, 1) + assert output.indices.get()[1] == slice(250 * idx, 250 * (idx + 1), 1) + + +def test_linear_row(): + input = IRFullTensor(shape=[1024, 1024], name='input').tosub() + weight = IRFullTensor(shape=[1000, 1024], name='weight').tosub() + bias = IRFullTensor(shape=[1000,], name='bias').tosub() + + semantic_op = Linear( + signature='torch.nn.functional.linear', + inputs = [input, weight, bias], + ) + semantic_op.infer_shape() + + input_shapes = list() + for input in semantic_op.inputs(): + input_shapes.append(input.shape) + + output_shapes = list() + for output in semantic_op.outputs(): + output_shapes.append(output.shape) + + linear_row_weight = LinearRowWeight( + input_shapes=input_shapes, + output_shapes=output_shapes, + ) + + # test satisfy + assert linear_row_weight.satisfy(dict(chunk_num=4)) + assert not linear_row_weight.satisfy(dict(chunk_num=10)) + assert not linear_row_weight.satisfy(dict(chunk_num=12)) + + nodes = linear_row_weight.instantiate(semantic_op, config=dict(chunk_num=4)) + + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, Linear) + + inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) + weights = [node.inputs(1) for node in nodes] + print('weights:') + for weight in weights: + print(weight) + biass = [node.inputs(2) for node in nodes] + print('bias:') + for bias in biass: + print(bias) + outputs = [node.outputs(0) for node in nodes] + print('output:') + for output in outputs: + print(output) + + for idx, x in enumerate(inputs): + assert x.shape == [1024, 256] + assert x.indices.get()[1] == slice(256 * idx, 256 * (idx + 1), 1) + assert x.val_map == ValueMap(0, 1) + + for idx, w in enumerate(weights): + assert w.shape == [1000, 256] + assert w.indices.get()[1] == slice(256 * idx, 256 * (idx + 1), 1) + assert w.val_map == ValueMap(0, 1) + + for idx, b in enumerate(biass): + assert b.shape == [1000,] + assert b.indices.get()[0] == slice(0, 1000, 1) + assert b.val_map == ValueMap(idx, 4) + + for idx, output in enumerate(outputs): + assert output.shape == [1024, 1000] + assert output.val_map == ValueMap(idx, 4) diff --git a/tests/tensor/test_indices.py b/tests/tensor/test_indices.py deleted file mode 100644 index 37e6b419..00000000 --- a/tests/tensor/test_indices.py +++ /dev/null @@ -1,71 +0,0 @@ -from cube.tensor.indices import BaseIndices, TileIndices - -import torch - -def test_base_indices(): - - tensor = torch.randn((10, 10, 10)) - - # test init - sparse_indices = ( - [2,3,1,4], - [0,4,8,4], - [7,5,9,4] - ) - indices = BaseIndices(sparse_indices) - assert indices.indices == sparse_indices - - # test ndim - assert indices.ndim() == 3 - - # test size - assert indices.size() == 4 - - # test get - sub_tensor = tensor[indices.get()] - assert torch.allclose(sub_tensor, tensor[sparse_indices]) is True - - # test reorder - arg_order = [2, 1, 0, 3] - indices.reorder(arg_order) - sub_tensor = tensor[indices.get()] - - sparse_indices = ( - [1,3,2,4], - [8,4,0,4], - [9,5,7,4] - ) - ref_tensor = tensor[sparse_indices] - assert torch.allclose(sub_tensor, ref_tensor) is True - - -def test_tile_indices(): - - tensor = torch.randn((10, 10, 10)) - - anchor = [3,4,5] - ofst = [2,4,3] - indices = TileIndices(anchor, ofst) - assert indices.anchor == anchor - assert indices.shape == ofst - assert indices.elenum == 2 * 4 * 3 - - # test ndim - assert indices.ndim() == 3 - - # test size - assert indices.size() == 2 * 4 * 3 - - # test get - sub_tensor = tensor[indices.get()] - assert sub_tensor.size() == torch.Size(ofst) - ref_tensor = tensor[(slice(3,3+2), slice(4,4+4), slice(5,5+3))] - assert torch.allclose(sub_tensor, ref_tensor) is True - - # test reorder - ##TODO - - -if __name__ == '__main__': - test_base_indices() - test_tile_indices() \ No newline at end of file diff --git a/tests/tensor/test_logical_tensor.py b/tests/tensor/test_logical_tensor.py deleted file mode 100644 index 6d9ad033..00000000 --- a/tests/tensor/test_logical_tensor.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -cmd for running the test - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - tests/tensor/test_logical_tensor.py -""" - -from cube.tensor.indices import BaseIndices -from cube.tensor.logic.tensor import LogicalTensor -from cube.tensor.segment import Segment -from cube.device.physic.group import DeviceGroup - -import torch - - -def test_logical_tensor_init(): - - tensor = LogicalTensor(shape=(10,10,10)) - assert tensor.shape == (10, 10, 10) - assert len(tensor.segments) == 0 - assert tensor.data is not None - assert tensor.data.size() == torch.Size([10,10,10]) - - -def test_logical_tensor_select(): - tensor = LogicalTensor(shape=(10,10,10)) - sparse_indices = ( - [2,3,1,4], - [0,4,8,4], - [7,5,9,4] - ) - indices = BaseIndices(sparse_indices) - segment = tensor.select(indices, None, shape=(2,2)) - assert isinstance(segment, Segment) - assert segment.materialized is False - - -def test_logical_tensor_fill(): - - myrank = DeviceGroup().rank - - tensor = LogicalTensor(shape=(10,10,10), init_data=False) - sparse_indices = ( - [2,3,1,4], - [0,4,8,4], - [7,5,9,4] - ) - indices = BaseIndices(sparse_indices) - segment = tensor.select(indices, None, shape=(2,2)) - tensor.add_segment(segment) - - assert segment.materialized is False - assert len(tensor.segments) == 1 - - ranks = [1, 3] - if myrank in ranks: - phy_tensor = torch.randn((2,2)).cuda() - else: - phy_tensor = None - tensor.fill([phy_tensor], [ranks]) - assert segment.materialized is True - if myrank in ranks: - assert tensor.get_physical_tensor(0) is not None - else: - assert tensor.get_physical_tensor(0) is None - - -def test_logical_tensor_transform(): - - tensor = LogicalTensor(shape=(10,10,10)) - sparse_indices = ( - [2,3,1,4], - [0,4,8,4], - [7,5,9,4] - ) - indices = BaseIndices(sparse_indices) - segment = tensor.select(indices, None, shape=(2,2)) - - ranks = [0,1,3] - tensor.transform([segment], [ranks]) - - myrank = DeviceGroup().rank - if myrank in ranks: - assert tensor.get_physical_tensor(0) is not None - else: - assert tensor.get_physical_tensor(0) is None - - -if __name__ == '__main__': - - group = DeviceGroup() - - test_logical_tensor_init() - test_logical_tensor_select() - test_logical_tensor_fill() - test_logical_tensor_transform() \ No newline at end of file diff --git a/tests/tensor/test_outline.py b/tests/tensor/test_outline.py deleted file mode 100644 index edf91194..00000000 --- a/tests/tensor/test_outline.py +++ /dev/null @@ -1,132 +0,0 @@ -from cube.tensor.logic.tensor import LogicalTensor -import cube.tensor.logic.outline as outline -from cube.tensor.segment import Segment -from cube.operator.physic.comm.mapreduce import PartialSum - -import torch -import z3 - - -def iter_each_config(solver, attrs): - if len(attrs) == 0: - solver.check() - yield solver.model() - else: - while solver.check() == z3.sat: - config = solver.model() - solver.add(z3.Or([z3.Not(attr == config[attr]) for attr in attrs])) - yield config - - -def test_full(): - shape = (10,10,10) - tensor = torch.randn(shape) - solver = z3.Solver() - - full_dsp = outline.Full(solver, shape) - assert len(full_dsp.get_attributes()) == 0 - - configs = list() - for config in iter_each_config(solver, full_dsp.get_attributes()): - configs.append(config) - - assert len(configs) == 1 - config = configs[0] - - tensor = LogicalTensor(shape=shape) - segments = full_dsp.interpret(tensor, config) - assert len(segments) == 1 - assert tuple(segments[0].shape) == tuple(tensor.shape) - assert torch.allclose(tensor.data, tensor.data[segments[0].indices.get()]) is True - - -def test_split_axis(): - - axis = 1 - shape = [1024, 16] - solver = z3.Solver() - - tensor = torch.randn(shape) - split_dsp = outline.SplitAxis( - solver, shape, axis, chunk_num=None, overlap=0 - ) - - # test config space - configs = list() - for config in iter_each_config(solver, split_dsp.get_attributes()): - configs.append(config) - assert len(configs) == 5 - - # test segments - tensor = LogicalTensor(shape=shape) - segments = split_dsp.interpret(tensor, configs[0]) - shape_axis = [segment.shape[axis] for segment in segments] - assert sum(shape_axis) == shape[axis] - - -def test_split_axis_with_constraints(): - - axis = 1 - shape = [1024, 16] - solver = z3.Solver() - - split_dsp = outline.SplitAxis( - solver, shape, axis, chunk_num=None, overlap=0 - ) - - # this can be set due to device number constraints - split_dsp.solver.add(split_dsp.chunk_num <= 8) - - configs = list() - for config in iter_each_config(solver, split_dsp.get_attributes()): - configs.append(config) - # print(config) - assert len(configs) == 4 - - -def test_split_value(): - - shape = [1024, 32] - solver = z3.Solver() - - split_dsp = outline.SplitValue(solver, shape, None, PartialSum) - split_dsp.solver.add(split_dsp.chunk_num <= 4) - configs = list() - for config in iter_each_config(solver, split_dsp.get_attributes()): - configs.append(config) - assert len(configs) == 4 - - tensor = LogicalTensor(shape=shape) - segments = split_dsp.interpret(tensor, configs[0]) - for segment in segments: - assert torch.allclose(tensor.data, tensor.data[segment.indices.get()]) is True - - -def test_align(): - - shape = [1024, 16] - solver = z3.Solver() - - dsp1 = outline.SplitAxis( - solver, shape, axis=0, chunk_num=None, overlap=0, - ) - - dsp2 = outline.SplitAxis( - solver, shape, axis=1, chunk_num=dsp1.chunk_num, overlap=0, - ) - - configs = list() - attrs = dsp1.get_attributes() + dsp2.get_attributes() - for config in iter_each_config(solver, attrs): - configs.append(config) - assert len(configs) == 5 - - -if __name__ == '__main__': - - # test_base() - test_full() - test_split_axis() - test_split_axis_with_constraints() - test_split_value() - test_align() \ No newline at end of file diff --git a/tests/tensor/test_segment.py b/tests/tensor/test_segment.py deleted file mode 100644 index b7bf0947..00000000 --- a/tests/tensor/test_segment.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -cmd for running the test - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - tests/tensor/test_segment.py -""" - -from cube.tensor.logic.tensor import LogicalTensor -from cube.tensor.segment import Segment -from cube.tensor.indices import BaseIndices, TileIndices -from cube.device.physic.group import DeviceGroup -from cube.operator.physic.comm.mapreduce import PartialSum - -import torch -import os -torch.manual_seed(121) - - -def test_segment_init(): - - tensor = LogicalTensor((10,10,10)) - - anchor = [3,4,5] - ofst = [2,4,3] - indices = TileIndices(anchor, ofst) - - segment = Segment(tensor, indices, None, ofst) - - assert segment.logical_tensor is tensor - assert segment.shape == tuple(ofst) - assert segment.physical_tensor is None - assert len(segment.placement) == 0 - assert segment.group is None - assert len(segment.val_ops) == 0 - assert segment.materialized is False - - -def test_segment_deploy(): - - myrank = DeviceGroup().rank - tensor = LogicalTensor((10,10,10)) - - anchor = [3,4,5] - ofst = [2,4,3] - indices = TileIndices(anchor, ofst) - - segment = Segment(tensor, indices, None, ofst) - - ranks = [0,2] - segment.deploy(ranks) - - physical_tensor = segment.get_physical_tensor() - tensor_ref = tensor.data[indices.get()].cuda() - if myrank in ranks: - assert physical_tensor.device == torch.device('cuda:{}'.format(myrank)) - assert torch.allclose(physical_tensor, tensor_ref) - else: - assert physical_tensor is None - assert segment.placement == ranks - assert segment.group == DeviceGroup().get_group(ranks) - assert len(segment.val_ops) == 0 - assert segment.materialized is True - - -def test_segment_deploy_with_val_map(): - - myrank = DeviceGroup().rank - tensor = LogicalTensor((10,10,10)) - - anchor = [3,4,5] - ofst = [2,4,3] - indices = TileIndices(anchor, ofst) - - segment = Segment( - logical_tensor = tensor, - indices = indices, - val_op = PartialSum, - shape = ofst - ) - assert len(segment.val_ops) == 1 - - ranks = [0,2] - segment.deploy(ranks) - - # deploy check - physical_tensor = segment.get_physical_tensor() - tensor_ref = tensor.data[indices.get()].cuda() / len(ranks) - if myrank in [0,2]: - assert physical_tensor.device == torch.device('cuda:{}'.format(myrank)) - assert torch.allclose(physical_tensor, tensor_ref) is True - else: - assert physical_tensor is None - - -if __name__ == '__main__': - - group = DeviceGroup() - - test_segment_init() - test_segment_deploy() - test_segment_deploy_with_val_map() diff --git a/tests/test_physic_tensor.py b/tests/test_physic_tensor.py deleted file mode 100644 index e96f88ea..00000000 --- a/tests/test_physic_tensor.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from cube.tensor.physic.tensor import PhysicTensor - -def test_type(): - tensor1 = PhysicTensor([1,2,3,4]) - tensor2 = PhysicTensor([2,3,4,5]) - tensor_out = tensor1 + tensor2 - assert isinstance(tensor_out, PhysicTensor) - - -def test_data_host_device(): - tensor = PhysicTensor([1,2,3,4]) - assert tensor.data_host_device == torch.device('cpu') - tensor.data_host_device = torch.device('cuda:0') - assert tensor.device == torch.device('cuda:0') - tensor.move_(torch.device('cpu')) - assert tensor.device == torch.device('cpu') - - -if __name__ == '__main__': - - test_type() - test_data_host_device() - - print('test passed') \ No newline at end of file From ce707a9c917c27a33b7682e57c01c5f2cf1e5bd3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Nov 2021 12:56:45 +0800 Subject: [PATCH 0249/1892] add dependency for overlapping value map --- cube/graph/tensor.py | 13 ++++++++++++- tests/graph/test_tensor.py | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 69fc7d07..0f9e3740 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -211,6 +211,17 @@ def map(self, sub_map): chunk_num = self.chunk_num - 1 + sub_map.chunk_num return ValueMap(idx, chunk_num) + def overlap(self, other): + if not isinstance(other, ValueMap): + raise TypeError("Expected ValueMap") + if self.chunk_num == other.chunk_num: + return self.idx == other.idx + else: + if self.chunk_num == 1 or other.chunk_num == 1: + return True + else: + raise NotImplementedError("Not Implemented") + def __eq__(self, other): if isinstance(other, ValueMap): if other.idx == self.idx and other.chunk_num == self.chunk_num: @@ -527,7 +538,7 @@ def overlap(self, other): if self.parent != other.parent: return False return self.indices.overlap(other.indices) and \ - self.val_map == other.val_map + self.val_map.overlap(other.val_map) else: raise TypeError("Customized IRTensor not support") diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py index 5553ac32..a3ac6d58 100644 --- a/tests/graph/test_tensor.py +++ b/tests/graph/test_tensor.py @@ -1,6 +1,6 @@ import copy -from cube.graph.tensor import IRFullTensor, IRSubTensor +from cube.graph.tensor import IRFullTensor, IRSubTensor, ValueMap def test_full_tensor_init(): @@ -105,7 +105,7 @@ def test_sub_tensor_select(): assert sub_tensor3 in tensor1.segments() -def test_sub_tensor_overlap(): +def test_sub_tensor_ind_overlap(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( @@ -129,6 +129,36 @@ def test_sub_tensor_overlap(): assert not sub_tensor2.overlap(sub_tensor3) +def test_sub_tensor_val_overlap(): + tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') + sub_tensor1 = tensor1.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_map = None, + shape = (1024, 512) + ) + sub_tensor2 = tensor1.select( + indices = (slice(0, 1024), slice(0, 512)), + val_map = (0, 4), + shape = (1024, 512) + ) + sub_tensor3 = tensor1.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_map = (0, 4), + shape = (1024, 512) + ) + sub_tensor4 = tensor1.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_map = (1, 4), + shape = (1024, 512) + ) + + assert not sub_tensor1.overlap(sub_tensor2) + assert not sub_tensor2.overlap(sub_tensor3) + assert sub_tensor1.overlap(sub_tensor3) + assert sub_tensor1.overlap(sub_tensor4) + assert sub_tensor4.overlap(sub_tensor1) + assert not sub_tensor3.overlap(sub_tensor4) + def test_sub_tensor_common(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') From 6dcb3261d6448f891537c1229805e051a288867b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Nov 2021 13:38:12 +0800 Subject: [PATCH 0250/1892] change interface --- cube/algorithm/factory.py | 9 ++++++++ cube/algorithm/generics.py | 36 ++++++++++++++++++----------- cube/algorithm/linear.py | 18 ++++++++++----- tests/algorithm/test_generics.py | 33 +++++++++++++------------- tests/algorithm/test_linear_algo.py | 31 +++---------------------- 5 files changed, 64 insertions(+), 63 deletions(-) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index ba3384df..7a7bfbea 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -19,6 +19,15 @@ def __init__(self): def __getattr__(self, name): return getattr(self.instance, name) + def exist(self, op): + """ + Check if the factory has op's algorithm recorded + + Returns: + True if have, False if not + """ + return op in self.instance._algos + def register(self, op, algorithm, tag: str): """ Register a holistic op (class) as one of the anchors diff --git a/cube/algorithm/generics.py b/cube/algorithm/generics.py index 534852c4..80c35b95 100644 --- a/cube/algorithm/generics.py +++ b/cube/algorithm/generics.py @@ -1,11 +1,11 @@ -from typing import List, Dict, Optional +from typing import Dict + +from cube.ir.cten import IRCell, IRTensor class GenericDistAlgo: - def __init__(self, - input_shapes: List[Optional[List[int]]], - output_shapes: List[List[int]]): + def __init__(self, node: IRCell): """ Layout is the community distribution requirement for input and output logical tensors. @@ -26,20 +26,30 @@ def __init__(self, output_format (list[list[int], None]): output dim order compare with logical definition """ + if not isinstance(node, IRCell): + raise TypeError("Expected node to be IRCell") + + input_shapes = list() + for input in node.inputs(): + if isinstance(input, IRTensor): + input_shapes.append(input.shape) + else: + input_shapes.append(None) + output_shapes = list() + for output in node.outputs(): + if isinstance(output, IRTensor): + output_shapes.append(output.shape) + else: + output_shapes.append(None) self.input_shapes = input_shapes self.output_shapes = output_shapes - self.logical_op = None + self._logical_op = node - def set_logic_op(self, logic_op): - """ - Set logic op. This will be automatically called when the - holistic op registered in a logical op. - """ - # if not isinstance(logic_op, GenericLogicalOp): - # raise TypeError("Require a logic op to register") - self.logical_op = logic_op + @property + def logic_op(self): + return self._logical_op def satisfy(self, config: Dict): """ diff --git a/cube/algorithm/linear.py b/cube/algorithm/linear.py index 97fb966c..bdd36313 100644 --- a/cube/algorithm/linear.py +++ b/cube/algorithm/linear.py @@ -10,9 +10,11 @@ class LinearDataParallel(GenericDistAlgo): - def __init__(self, input_shapes: List[Optional[List[int]]], output_shapes: List[int]): + def __init__(self, node: Linear): - super().__init__(input_shapes, output_shapes) + if not isinstance(node, Linear): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) self.chunk_num = _kWaitDecision @@ -47,9 +49,11 @@ def instantiate(self, node, config: Dict): class LinearColumnWeight(GenericDistAlgo): - def __init__(self, input_shapes: List[Optional[List[int]]], output_shapes: List[int]): + def __init__(self, node: Linear): - super().__init__(input_shapes, output_shapes) + if not isinstance(node, Linear): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) self.chunk_num = _kWaitDecision @@ -88,9 +92,11 @@ def instantiate(self, node, config: Dict): class LinearRowWeight(GenericDistAlgo): - def __init__(self, input_shapes: List[Optional[List[int]]], output_shapes: List[int]): + def __init__(self, node: Linear): - super().__init__(input_shapes, output_shapes) + if not isinstance(node, Linear): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) self.chunk_num = _kWaitDecision diff --git a/tests/algorithm/test_generics.py b/tests/algorithm/test_generics.py index 5fbe97e0..303b5eb6 100644 --- a/tests/algorithm/test_generics.py +++ b/tests/algorithm/test_generics.py @@ -1,21 +1,22 @@ from cube.algorithm.generics import GenericDistAlgo -from cube.graph.operator.function import Linear +from cube.ir.cten import IRCell, IRTensor def test_generic_algo_init(): + input1 = IRTensor(shape=[1024, 1024]) + input2 = IRTensor(shape=[1024, 1000]) + bias = None + cell = IRCell(name='test', signature='test', input_length=3, output_length=1) + cell.set_input(0, input1) + cell.set_input(1, input2) + cell.set_input(2, bias) + cell.outputs(0).shape = [1024, 1000] - algo = GenericDistAlgo( - input_shapes=[[1024,1024], [1024, 1024], None], - output_shapes=[[1024, 1024]] - ) - assert algo.logical_op is None - - -def test_generic_set_logic_op(): - - algo = GenericDistAlgo( - input_shapes=[[1024,1024], [1024, 1024], None], - output_shapes=[[1024, 1024]] - ) - algo.set_logic_op(Linear) - assert algo.logical_op == Linear + algo = GenericDistAlgo(cell) + assert algo.logic_op is cell + assert len(algo.input_shapes) == 3 + assert algo.input_shapes[0] == [1024, 1024] + assert algo.input_shapes[1] == [1024, 1000] + assert algo.input_shapes[2] is None + assert len(algo.output_shapes) == 1 + assert algo.output_shapes[0] == [1024, 1000] diff --git a/tests/algorithm/test_linear_algo.py b/tests/algorithm/test_linear_algo.py index 861445c0..67e14fb2 100644 --- a/tests/algorithm/test_linear_algo.py +++ b/tests/algorithm/test_linear_algo.py @@ -17,18 +17,7 @@ def test_linear_data_parallel(): ) semantic_op.infer_shape() - input_shapes = list() - for input in semantic_op.inputs(): - input_shapes.append(input.shape) - - output_shapes = list() - for output in semantic_op.outputs(): - output_shapes.append(output.shape) - - linear_dp = LinearDataParallel( - input_shapes=input_shapes, - output_shapes=output_shapes, - ) + linear_dp = LinearDataParallel(semantic_op) assert linear_dp.chunk_num is None @@ -81,18 +70,7 @@ def test_linear_column_weight(): ) semantic_op.infer_shape() - input_shapes = list() - for input in semantic_op.inputs(): - input_shapes.append(input.shape) - - output_shapes = list() - for output in semantic_op.outputs(): - output_shapes.append(output.shape) - - linear_col_weight = LinearColumnWeight( - input_shapes=input_shapes, - output_shapes=output_shapes, - ) + linear_col_weight = LinearColumnWeight(semantic_op) # test satisfy assert linear_col_weight.satisfy(dict(chunk_num=4)) @@ -157,10 +135,7 @@ def test_linear_row(): for output in semantic_op.outputs(): output_shapes.append(output.shape) - linear_row_weight = LinearRowWeight( - input_shapes=input_shapes, - output_shapes=output_shapes, - ) + linear_row_weight = LinearRowWeight(semantic_op) # test satisfy assert linear_row_weight.satisfy(dict(chunk_num=4)) From aceb023b7a4c6e9115e0e8c2a05e46b70c2f0c17 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Nov 2021 13:51:28 +0800 Subject: [PATCH 0251/1892] add op algorithm interface --- cube/algorithm/factory.py | 7 +++++-- cube/graph/operator/operator.py | 25 +++++++++++++++++++++++++ tests/graph/test_function.py | 19 +++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 tests/graph/test_function.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 7a7bfbea..0edef270 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -19,14 +19,17 @@ def __init__(self): def __getattr__(self, name): return getattr(self.instance, name) - def exist(self, op): + def exist(self, op, tag=None): """ Check if the factory has op's algorithm recorded Returns: True if have, False if not """ - return op in self.instance._algos + if tag is None: + return op in self.instance._algos + else: + return op in self.instance._algos and tag in self.instance._algos[op] def register(self, op, algorithm, tag: str): """ diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index c0fb0bb2..156c9ea8 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,5 +1,8 @@ +from typing import Optional + from cube.ir.cten import IRTensor, IRCell from cube.graph.tensor import IRFullTensor +from cube.algorithm.factory import DistAlgorithmFactory __call__ = ['IROperation'] @@ -32,6 +35,28 @@ def infer_shape(self): """ raise NotImplementedError + def algorithms(self, tag: Optional[str] = None): + """ + get algorithm from algorithm factory + + Args: + tag: str or None. If None, return all + """ + factory = DistAlgorithmFactory() + if tag is None: + templates = list() + if factory.exist(type(self)): + templates = factory.algorithms(type(self)) + algos = list() + for template in templates: + algos.append(template(self)) + return algos + else: + if not factory.exist(type(self)): + return None + template = factory.algorithms(type(self), tag) + return template(self) + def __repr__(self): inputs = list() for tensor in self.inputs(): diff --git a/tests/graph/test_function.py b/tests/graph/test_function.py new file mode 100644 index 00000000..927be6ae --- /dev/null +++ b/tests/graph/test_function.py @@ -0,0 +1,19 @@ +from cube.graph.operator.function import Linear +from cube.graph.tensor import IRFullTensor +from cube.algorithm.linear import LinearDataParallel + + +def test_linear_algo(): + + input = IRFullTensor(shape=[1024, 1024], name='input').tosub() + weight = IRFullTensor(shape=[1000, 1024], name='weight').tosub() + bias = IRFullTensor(shape=[1000,], name='bias').tosub() + + semantic_op = Linear( + signature='torch.nn.functional.linear', + inputs = [input, weight, bias], + ) + semantic_op.infer_shape() + + assert len(semantic_op.algorithms()) == 3 + assert isinstance(semantic_op.algorithms('data'), LinearDataParallel) From 6fe09eee4085eabb681f6f9df2c65a12003468d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Nov 2021 15:50:05 +0800 Subject: [PATCH 0252/1892] fix bugs on dependency track --- cube/algorithm/generics.py | 2 +- cube/graph/graph.py | 55 +++++++++++++++++++++++++++------ cube/graph/operator/operator.py | 2 +- cube/ir/cten.py | 18 ++++++++++- tests/graph/test_graph.py | 31 +++++++++++++++++++ 5 files changed, 95 insertions(+), 13 deletions(-) diff --git a/cube/algorithm/generics.py b/cube/algorithm/generics.py index 80c35b95..73b05a73 100644 --- a/cube/algorithm/generics.py +++ b/cube/algorithm/generics.py @@ -45,7 +45,7 @@ def __init__(self, node: IRCell): self.input_shapes = input_shapes self.output_shapes = output_shapes - self._logical_op = node + self._logical_op = type(node) @property def logic_op(self): diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ae4d82f3..71b5e393 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,13 +7,15 @@ will be inserted at scheduling time. """ -from typing import Union, Tuple, List, Optional, Any +from typing import Union, Tuple, List, Optional, Any, Dict import copy from cube.ir.cten import IRTensor, IRCell from cube.graph.operator import IROperation from cube.graph.tensor import IRFullTensor +from cube.algorithm.generics import GenericDistAlgo + __all__ = ['IRGraph'] @@ -33,7 +35,6 @@ def __init__(self, self._nodes: List[IROperation] = nodes self._parameters = list() - self.reset_dependency() if input_tensors is None: input_tensors = IRCell.get_inputs(nodes) @@ -82,21 +83,19 @@ def __init__(self, input.as_param() self._parameters.append(input) self.tag = 'forward' + self.reset_dependency() def reset_dependency(self): """ Reset the node dataflow dependency """ + for node in self._nodes: + node.clear_predecessor() + node.clear_successor() # set node predecessors and successors for src_idx in range(len(self._nodes)): src_node = self._nodes[src_idx] - src_node._successors = [ - list() for _ in range(len(src_node.outputs())) - ] for dst_node in self._nodes[src_idx+1:]: - dst_node._predecessors = [ - list() for _ in range(len(dst_node.inputs())) - ] for out_idx, out_tensor in enumerate(src_node.outputs()): if not isinstance(out_tensor, IRTensor): continue @@ -314,10 +313,46 @@ def subgraph(self, sub_nodes: List[IRCell]): return graph + def _remove(self, node: IRCell): + """ + Remove a node from graph + """ + if node in self.nodes(): + self._nodes.remove(node) + #TODO: remove parameters + self.reset_dependency() + ## Primitives for policy expression ## - def partition(self, op, op_partition_algorithm, config): - raise NotImplementedError + def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: + """ + Policy primitive. Partition an operator by using + op_partition_algorithm and its configuration + + Args: + op: cell to be partitioned + algo: generic distributed algorithm related to the op + config: dict + + Returns: + nodes: List[IRCell] if partitioned successfully. + None if failed + """ + if not isinstance(op, IRCell): + raise TypeError("Expected op to be IRCell (IROperation)") + if not isinstance(algo, GenericDistAlgo): + raise TypeError("Expected algo to be GenericDistAlgo") + + if algo.logic_op != type(op): + return None + if not algo.satisfy(config): + return None + nodes = algo.instantiate(op, config) + idx = self._nodes.index(op) + self._nodes = self._nodes[:idx] + nodes + self._nodes[idx+1:] + self.reset_dependency() + return copy.copy(nodes) + def merge(self, sub_graph, target_op, op_partition_algorithm): raise NotImplementedError diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 156c9ea8..c8da2ded 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -52,7 +52,7 @@ def algorithms(self, tag: Optional[str] = None): algos.append(template(self)) return algos else: - if not factory.exist(type(self)): + if not factory.exist(type(self), tag): return None template = factory.algorithms(type(self), tag) return template(self) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index d710efd4..11fc1e63 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -1,4 +1,4 @@ -""" +r""" IRCell: a graph node component serving for different purpose, e.g., operator, device graph, graph @@ -236,6 +236,14 @@ def add_predecessor(self, input_index: int, cell): if cell not in self._predecessors[input_index]: self._predecessors[input_index].append(cell) + def clear_predecessor(self): + """ + Clear all predecessors + """ + self._predecessors = [ + list() for _ in range(len(self.inputs())) + ] + def add_successor(self, output_index: int, cell): """ Set self node the output index node. @@ -246,6 +254,14 @@ def add_successor(self, output_index: int, cell): if cell not in self._successors[output_index]: self._successors[output_index].append(cell) + def clear_successor(self): + """ + Clear all successors + """ + self._successors = [ + list() for _ in range(len(self.outputs())) + ] + @staticmethod def get_inputs(cells): """ diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 13cee965..7ca9307e 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -128,3 +128,34 @@ def test_graph_copy(): assert len(gnode.inputs()) == len(cnode.outputs()) assert len(gnode.predecessors()) == len(cnode.successors()) assert len(gnode.successors()) == len(cnode.predecessors()) + + +def test_graph_partition(): + + inputs, ops, outputs = construct_model() + graph = IRGraph(ops, inputs, outputs, 'MLP') + + node1, node2, node3 = graph.nodes() + + algo = node2.algorithms('data') + sub_nodes = graph.partition(node2, algo, config=dict(chunk_num=4)) + assert sub_nodes is not None + assert len(graph.nodes()) == 6 + dnode1, dnode2, dnode3, dnode4 = sub_nodes + assert dnode2 not in dnode1.successors() + assert dnode3 not in dnode1.successors() + assert dnode4 not in dnode1.successors() + + algo = node3.algorithms('column') + sub_nodes = graph.partition(node3, algo, config=dict(chunk_num=4)) + print(graph) + + cnode1, cnode2, cnode3, cnode4 = sub_nodes + for cnode in sub_nodes: + print(cnode, cnode.successors()) + print(cnode.predecessors(0)) + assert dnode1 in cnode.predecessors() + assert dnode2 in cnode.predecessors() + assert dnode3 in cnode.predecessors() + assert dnode4 in cnode.predecessors() + assert len(graph.nodes()) == 9 From fc7c5afae1a898fe3594802be7c9e4bd14c50b32 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Nov 2021 16:39:26 +0800 Subject: [PATCH 0253/1892] merge device in runtime --- cube/__init__.py | 2 +- cube/codegen/codegen.py | 4 ++-- cube/device/__init__.py | 0 cube/device/physic/__init__.py | 0 cube/runtime/__init__.py | 2 +- cube/{device/physic/group.py => runtime/device.py} | 0 cube/runtime/{temporal.py => executor.py} | 14 +++++++++----- tests/graph/test_graph.py | 2 -- 8 files changed, 13 insertions(+), 11 deletions(-) delete mode 100644 cube/device/__init__.py delete mode 100644 cube/device/physic/__init__.py rename cube/{device/physic/group.py => runtime/device.py} (100%) rename cube/runtime/{temporal.py => executor.py} (82%) diff --git a/cube/__init__.py b/cube/__init__.py index 3b647e70..a91d0ffa 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,4 +1,4 @@ -from cube.device.physic.group import DeviceGroup +from cube.runtime.device import DeviceGroup from cube import schedule from cube import runtime diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 41b81adb..4f553ae0 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -292,8 +292,8 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: """ Emit su code """ - fsign = 'cube.runtime.temporal.forward({model}, *{inputs})' - bsign = 'cube.runtime.temporal.backward({input_tensors}, {output_tensors}, {output_grads})' + fsign = 'cube.runtime.executor.fexecute({model}, *{inputs})' + bsign = 'cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' if su.stype == SUType.Dataloader: if len(su.inputs()) != 0: diff --git a/cube/device/__init__.py b/cube/device/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/device/physic/__init__.py b/cube/device/physic/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index 7fe29b15..90200379 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -1 +1 @@ -from cube.runtime import collectives, temporal +from cube.runtime import collectives, executor, device diff --git a/cube/device/physic/group.py b/cube/runtime/device.py similarity index 100% rename from cube/device/physic/group.py rename to cube/runtime/device.py diff --git a/cube/runtime/temporal.py b/cube/runtime/executor.py similarity index 82% rename from cube/runtime/temporal.py rename to cube/runtime/executor.py index bca6125d..a5d049c5 100644 --- a/cube/runtime/temporal.py +++ b/cube/runtime/executor.py @@ -1,19 +1,23 @@ -from typing import Tuple, Any +r""" +SU Executor for runtime +""" + +from typing import Tuple, Any, Callable import torch -def forward(model, *input_tensors: Tuple[Any]): +def fexecute(su: Callable, *input_tensors: Tuple[Any]): """ - forward the model + forward the SUs """ - outputs = model(*input_tensors) + outputs = su(*input_tensors) print('forwarding... ') return outputs def backward(input_tensors, output_tensors, output_tensor_grads): """ - Backward on the tensors + Backward the SUs """ for tensor in input_tensors: if torch.is_tensor(tensor) and tensor.requires_grad: diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 7ca9307e..e3f9c77c 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -152,8 +152,6 @@ def test_graph_partition(): cnode1, cnode2, cnode3, cnode4 = sub_nodes for cnode in sub_nodes: - print(cnode, cnode.successors()) - print(cnode.predecessors(0)) assert dnode1 in cnode.predecessors() assert dnode2 in cnode.predecessors() assert dnode3 in cnode.predecessors() From a2ca94fd9c2fea37514e3da9358a2375b3462cfd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Nov 2021 16:46:30 +0800 Subject: [PATCH 0254/1892] move test files --- tests/{ => runtime}/test_group.py | 2 +- tests/{ => runtime}/test_nccl.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/{ => runtime}/test_group.py (94%) rename tests/{ => runtime}/test_nccl.py (100%) diff --git a/tests/test_group.py b/tests/runtime/test_group.py similarity index 94% rename from tests/test_group.py rename to tests/runtime/test_group.py index 10ac0526..21bd37d4 100644 --- a/tests/test_group.py +++ b/tests/runtime/test_group.py @@ -11,7 +11,7 @@ tests/test_group.py """ -from cube.device.physic.group import DeviceGroup +from cube.runtime.device import DeviceGroup import torch diff --git a/tests/test_nccl.py b/tests/runtime/test_nccl.py similarity index 100% rename from tests/test_nccl.py rename to tests/runtime/test_nccl.py From 37ddaadd2dac3f68855c7d3105cb5a63086ee335 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 6 Nov 2021 13:16:30 +0800 Subject: [PATCH 0255/1892] add operator types --- cube/algorithm/factory.py | 2 +- cube/graph/__init__.py | 2 +- cube/graph/gpass.py | 111 ++++++++++++++++ cube/graph/graph.py | 89 +------------ cube/graph/operator/__init__.py | 4 +- cube/graph/operator/adapter.py | 221 ++++++++++++++++++++++++++++++++ cube/graph/operator/function.py | 12 +- cube/graph/operator/operator.py | 127 +++++++++++++++++- cube/graph/parser/mapping.py | 4 +- cube/graph/parser/parser.py | 18 +-- cube/graph/tensor.py | 19 +++ 11 files changed, 499 insertions(+), 110 deletions(-) create mode 100644 cube/graph/gpass.py create mode 100644 cube/graph/operator/adapter.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 0edef270..d88e7785 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -44,7 +44,7 @@ def algorithms(self, op, tag = None): Get op tranformed algorithms Args: - op (IROperation): index for the holist op factory + op (IRFwOperation): index for the holist op factory args, kwargs: (logical) tensor inputs Returns: diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index 30225faa..527a5d19 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -1,5 +1,5 @@ from cube.graph.graph import IRGraph -from cube.graph.operator import IROperation +from cube.graph.operator import IRFwOperation from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.graph import parser diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py new file mode 100644 index 00000000..88bc6850 --- /dev/null +++ b/cube/graph/gpass.py @@ -0,0 +1,111 @@ +from typing import Any +import copy + +from cube.graph.graph import IRGraph +from cube.graph.tensor import IRSubTensor +from cube.graph.operator import IRBpOperation + +from cube.ir.cten import IRTensor + + +__all__ = ['forward', 'backward'] + + +class _TensorGener: + + def __init__(self): + self.symbol = dict() + + def renew(self, val: Any, keep_param=True): + self._check_is_sub_tensor(val) + if not isinstance(val, IRTensor): + return val + if keep_param and val.is_param(): + return val + if val.parent._id not in self.symbol: + self.symbol[val.parent._id] = val.parent.like() + new_val = self.symbol[val.parent._id].select( + indices=val.indices, + val_map=val.val_map, + shape=val.shape + ) + return new_val + + def set_map(self, origin: Any, new: Any): + self._check_is_sub_tensor(origin) + self._check_is_sub_tensor(new) + if isinstance(origin, IRSubTensor): + tid = origin.parent._id + if isinstance(new, IRSubTensor): + self.symbol[tid] = new.parent + return + self.symbol[tid] = new + + def _check_is_sub_tensor(self, tensor): + if isinstance(tensor, IRTensor): + if not isinstance(tensor, IRSubTensor): + raise TypeError("Tensor only allows to be SubTensor") + + +def forward(graph, *args) -> IRGraph: + """ + Forward the IRGraph, replacing all the intermediate tensors + """ + if not isinstance(graph, IRGraph): + raise TypeError("Forwarding requires IRGraph") + + gener = _TensorGener() + + for input, arg in zip(graph.inputs(), args): + gener.set_map(input, arg) + + fnodes = list() + bnodes = list() + for node in graph.nodes(): + inputs = node.inputs() + outputs = node.outputs() + + # forwrd node + fnode = copy.copy(node) + # set forward inputs + for idx, val in enumerate(inputs): + fnode.set_input(idx, gener.renew(val)) + # set forward outputs + for idx, val in enumerate(outputs): + fnode.set_output(idx, gener.renew(val)) + + # backward node + bnode = IRBpOperation(data_num=len(inputs), grad_num=len(outputs)) + # set backward grad + for idx, val in enumerate(fnode.inputs()): + # set input + bnode.set_data(idx, val) + val = val if isinstance(val, IRTensor) else None + val = gener.renew(val, keep_param=False) + val = val.as_grad() if isinstance(val, IRTensor) else val + # set gradient output + bnode.set_output(idx, val) + for idx, val in enumerate(fnode.outputs()): + # set gradient input + val = gener.renew(val, keep_param=False) + val = val.as_grad() if isinstance(val, IRTensor) else val + bnode.set_grad(idx, val) + + fnode.device = node.device + bnode.device = node.device + + # mirror node for forward / backward + fnode.mirror = bnode + bnode.mirror = fnode + + fnodes.append(fnode) + bnodes.append(bnode) + + inputs = [gener.renew(input) for input in graph.inputs()] + outputs = [gener.renew(output) for output in graph.outputs()] + + fgraph = IRGraph(fnodes, inputs, outputs, graph.name) + for output in fgraph.outputs(): + output.set_trace(fgraph.nodes()) + return fgraph + diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 71b5e393..c2d2e8c9 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -1,6 +1,6 @@ """ IRGraph: - a graph that is composed by node (IROperation) and edge (IRTensor). + a graph that is composed by node (IRFwOperation) and edge (IRTensor). Note the device of graph.inputs() can be different of the same input tensor of operation node in the graph. In this case, a move operation @@ -11,7 +11,6 @@ import copy from cube.ir.cten import IRTensor, IRCell -from cube.graph.operator import IROperation from cube.graph.tensor import IRFullTensor from cube.algorithm.generics import GenericDistAlgo @@ -28,12 +27,12 @@ class IRGraph(IRCell): """ def __init__(self, - nodes: List[IROperation], + nodes: List[IRCell], input_tensors: Optional[List[IRTensor]], output_tensors: Optional[List[IRTensor]], module_name: str): - self._nodes: List[IROperation] = nodes + self._nodes: List[IRCell] = nodes self._parameters = list() if input_tensors is None: @@ -82,7 +81,6 @@ def __init__(self, input.is_leaf(self._nodes): input.as_param() self._parameters.append(input) - self.tag = 'forward' self.reset_dependency() def reset_dependency(self): @@ -112,75 +110,6 @@ def parameters(self): """ return copy.copy(self._parameters) - def copy(self, reverse=False): - """ - Copy the graph but re-new the intermediate tensor - """ - # old graph tensor.parent._id -> new full tensor - new_full_tensors = dict() - - def _renew(val: Any): - if not isinstance(val, IRTensor): - return val - elif isinstance(val, IRFullTensor): - raise RuntimeError("Found Full Tensor") - # parameters in forward - if (not reverse) and val.is_param(): - return val - # intermediate / gradient data - if val.parent._id not in new_full_tensors: - new_full_tensors[val.parent._id] = val.parent.like() - full_tensor = new_full_tensors[val.parent._id] - new_val = full_tensor.select( - indices=val.indices, - val_map=val.val_map, - shape=val.shape - ) - if reverse and val.is_param(): - new_val.name = 'grad_' + new_val.name - assert new_val.is_param() - return new_val - - nodes = list() - for node in self.nodes(): - - if isinstance(node, IROperation): - inputs = node.inputs() - outputs = node.outputs() - if reverse: - inputs, outputs = outputs, inputs - - new_node = IROperation( - node.name, node.signature, - len(inputs), len(outputs) - ) - # set inputs - for idx, val in enumerate(inputs): - new_node.set_input(idx, _renew(val)) - # set outputs - for idx, val in enumerate(outputs): - new_node.set_output(idx, _renew(val)) - else: - raise TypeError("Found node with unsupported copy") - new_node.device = node.device - nodes.append(new_node) - - inputs = [_renew(input) for input in self.inputs()] - outputs = [_renew(output) for output in self.outputs()] - - if reverse: - inputs, outputs = outputs, inputs - nodes = nodes[::-1] - - copied_graph = IRGraph( - nodes = nodes, - input_tensors = inputs, - output_tensors = outputs, - module_name = self.name - ) - copied_graph.tag = self.tag - return copied_graph - def nodes(self, index: Optional[int] = None): """ Get node at position index @@ -310,18 +239,8 @@ def subgraph(self, sub_nodes: List[IRCell]): output_tensors = outputs, module_name = self.name ) - return graph - def _remove(self, node: IRCell): - """ - Remove a node from graph - """ - if node in self.nodes(): - self._nodes.remove(node) - #TODO: remove parameters - self.reset_dependency() - ## Primitives for policy expression ## def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: @@ -339,7 +258,7 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional None if failed """ if not isinstance(op, IRCell): - raise TypeError("Expected op to be IRCell (IROperation)") + raise TypeError("Expected op to be IRCell (IRFwOperation)") if not isinstance(algo, GenericDistAlgo): raise TypeError("Expected algo to be GenericDistAlgo") diff --git a/cube/graph/operator/__init__.py b/cube/graph/operator/__init__.py index 77f5f416..80f9ae2e 100644 --- a/cube/graph/operator/__init__.py +++ b/cube/graph/operator/__init__.py @@ -1 +1,3 @@ -from cube.graph.operator.operator import IROperation \ No newline at end of file +from cube.graph.operator.operator import IRFwOperation +from cube.graph.operator.operator import IRBpOperation +from cube.graph.operator.operator import IRDataOperation \ No newline at end of file diff --git a/cube/graph/operator/adapter.py b/cube/graph/operator/adapter.py new file mode 100644 index 00000000..ed5f8ca5 --- /dev/null +++ b/cube/graph/operator/adapter.py @@ -0,0 +1,221 @@ +from typing import List, Optional +from enum import Enum +import numpy as np + +from cube.ir.cten import IRCell, IRTensor +from cube.graph.tensor import IRSubTensor, IndexMap + + +class IRReshapeType(Enum): + + Select = 'cube.runtime.adapter.select' + Merge = 'cube.runtime.adapter.merge' + + +class IRShapeAdapter(IRCell): + """ + Tensor transformation by convert source tensors + to destination tensors + + Select: + src_tensors is only one tensor, dst_tensors has (multiple) tensors. + This will select the sub_tensor and generate what it need + + Merge: + src_tensors has (multiple) tensors, dst_tensors is only one tensor. + This will merge the sub_tensor and generate what it need + """ + def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor]): + + if len(src_tensors) != 1 and len(dst_tensors) != 1: + raise ValueError("Expected at least one of tensors has length 1") + + self.ttype = None + + self._select_indices: List[IndexMap] = list() + self._merge_axis = None + + if len(src_tensors) == 1: + self.ttype = IRReshapeType.Select + src_tensor = src_tensors[0] + if not isinstance(src_tensor, IRSubTensor): + raise TypeError(f"Expected IRSubTensor but got {type(src_tensor)}") + # select + for tensor in dst_tensors: + indices = tensor.indices & src_tensor.indices + self._select_indices.append(indices) + + elif len(dst_tensors) == 1: + self.ttype = IRReshapeType.Merge + dst_tensor = dst_tensors[0] + # find dims to concat + ndims = len(dst_tensor.shape) + indices = [set() for _ in range(ndims)] + for src_tensor in src_tensors: + if isinstance(src_tensor, IRSubTensor): + for ndim, slicer in enumerate(src_tensor.indices.get()): + indices[ndim].add((slicer.start, slicer.stop, slicer.step)) + else: + raise RuntimeError( + f"Expected SubTensor but got {type(src_tensor)}" + ) + # check if only one dim set has multiple slicer + for dim, dim_indices in enumerate(indices): + if len(dim_indices) != 1: + if self._merge_axis is not None: + print("src: ", src_tensors) + print("dst: ", dst_tensors) + raise NotImplementedError("Only support merge on one axis") + self._merge_axis = dim + if self._merge_axis is None: + # check the coverage + if src_tensors[0].indices != dst_tensor.indices: + raise RuntimeError("Not cover all the indices to merge.") + # get merge axis + if self._merge_axis is not None: + dim_indices = indices[self._merge_axis] + # check if they are overlapped + starts = np.array([slicer[0] for slicer in dim_indices]) + stops = np.array([slicer[1] for slicer in dim_indices]) + steps = np.array([slicer[2] for slicer in dim_indices]) + sorted_idx = np.argsort(starts) + sorted_starts = starts[sorted_idx] + sorted_stops = stops[sorted_idx] + sorted_steps = steps[sorted_idx] + for last_stop, begin_start in zip(sorted_stops[:-1], sorted_starts[1:]): + if last_stop != begin_start: + raise NotImplementedError(f"Concatenation fails due to axis {last_stop} != {begin_start}") + for step in sorted_steps: + if step != 1: + raise NotImplementedError(f"Found a SubTensor step {step} != 1") + # re-order + src_tensors = np.array(src_tensors)[sorted_idx] + + else: + raise RuntimeError("Internal Error") + + super().__init__( + name = 'transformation', + signature = self.ttype.value, + input_length = len(src_tensors), + output_length = len(dst_tensors) + ) + for idx, input in enumerate(src_tensors): + self.set_input(idx, input) + for idx, output in enumerate(dst_tensors): + self.set_output(idx, output) + + @property + def select_indices(self) -> List[IndexMap]: + return self._select_indices + + @property + def merge_axis(self) -> Optional[int]: + return self._merge_axis + + def is_identity(self): + """ + Check if this transformation is a non-op + """ + if self.ttype == IRReshapeType.Select: + src_tensor = self.inputs(0) + for dst_tensor in self.outputs(): + if dst_tensor != src_tensor: + return False + return True + if self.ttype == IRReshapeType.Merge: + if self.merge_axis is None: + return True + return False + return False + + +class IRCommType(Enum): + + Send = 'send' + Recv = 'recv' + SendRecv = 'sendrecv' + + +class IRCommAdapter(IRCell): + """ + Communication cell for IRCell + """ + + def __init__(self, + send_tensors=list(), send_ranks: List[List[int]] = list(), + recv_tensors=list(), recv_ranks: List[List[int]] =list()): + """ + Create a basic send, recv or sendrecv communication node + """ + if len(send_tensors) != 0 and len(recv_tensors) != 0: + comm_type = IRCommType.SendRecv + signature = 'cube.runtime.collectives.sendrecv' + elif len(send_tensors) != 0 and len(recv_tensors) == 0: + comm_type = IRCommType.Send + signature = 'cube.runtime.collectives.send' + elif len(recv_tensors) != 0 and len(send_tensors) == 0: + comm_type = IRCommType.Recv + signature = 'cube.runtime.collectives.recv' + else: + raise ValueError( + "Expected at least one of send_tensors and recv_tensors" + ) + + self.comm_type = comm_type + self.send_tensors = list() + self.send_ranks = list() + self.recv_tensors = list() + self.recv_ranks = list() + + super().__init__( + name = comm_type.value, + signature = signature, + input_length = len(send_tensors), + output_length = len(recv_tensors) + ) + + for idx, (tensor, to_device) in enumerate(zip(send_tensors, send_ranks)): + self.set_input(idx, tensor) + self.send_tensors.append(self.inputs(idx)) + self.send_ranks.append(to_device) + + for idx, (tensor, from_device) in enumerate(zip(recv_tensors, recv_ranks)): + self.set_output(idx, tensor) + self.recv_tensors.append(self.outputs(idx)) + self.recv_ranks.append(from_device) + + self.msg_id = self._id + + def pair(self, other): + """ + Pair two comm node to have same message id. + + The `other` message id is set same with caller + """ + if not isinstance(other, IRCommAdapter): + raise RuntimeError("Expected IRCommAdapter to pair") + other.msg_id = self.msg_id + + def merge(self, other): + if not isinstance(other, IRCommAdapter): + raise RuntimeError("Expected IRCommAdapter to merge") + raise NotImplementedError + + def __repr__(self): + inputs = list() + for tensor in self.inputs(): + if isinstance(tensor, IRTensor): + inputs.append(f't{tensor._id}-dev{tensor.device}') + else: + inputs.append(tensor) + + outputs = list() + for tensor in self.outputs(): + if isinstance(tensor, IRTensor): + outputs.append(f't{tensor._id}-dev{tensor.device}') + else: + outputs.append(tensor) + + dscp = f'SendRecv(msg_id={self.msg_id}, device={self.device}, send={inputs}, recv={outputs})' + return dscp diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index a86918b7..2fc25e68 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -1,10 +1,10 @@ import copy -from cube.graph.operator import IROperation +from cube.graph.operator import IRFwOperation from cube.ir.cten import IRTensor -class Linear(IROperation): +class Linear(IRFwOperation): def __init__(self, signature, inputs, name='linear', **kwargs): @@ -31,7 +31,7 @@ def infer_shape(self): return False -class ElementWise(IROperation): +class ElementWise(IRFwOperation): """ Functions like torch.add (tensor1 + tensor2 / scaler) """ @@ -56,7 +56,7 @@ def infer_shape(self): return False -class ElementWiseActivation(IROperation): +class ElementWiseActivation(IRFwOperation): """ functions like GELU, RELU, Dropout. @@ -83,7 +83,7 @@ def infer_shape(self): return False -class Reduce(IROperation): +class Reduce(IRFwOperation): """ functions like sum, mean, cross_entropy """ @@ -101,7 +101,7 @@ def infer_shape(self): return True -class UnkownOperator(IROperation): +class UnkownOperator(IRFwOperation): def __init__(self, signature, inputs, name='unknown_op', n_output=None): diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index c8da2ded..70b0b8e1 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,14 +1,14 @@ -from typing import Optional +from typing import Any, Optional, Union, List from cube.ir.cten import IRTensor, IRCell from cube.graph.tensor import IRFullTensor from cube.algorithm.factory import DistAlgorithmFactory -__call__ = ['IROperation'] +__call__ = ['IRFwOperation', 'IRBpOperation'] -class IROperation(IRCell): +class IRFwOperation(IRCell): def __init__(self, name: str, @@ -61,16 +61,133 @@ def __repr__(self): inputs = list() for tensor in self.inputs(): if isinstance(tensor, IRTensor): - inputs.append(f't{tensor._id}') + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + inputs.append(f'{anno}{tensor._id}') else: inputs.append(tensor) outputs = list() for tensor in self.outputs(): if isinstance(tensor, IRTensor): - outputs.append(f't{tensor._id}') + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + outputs.append(f'{anno}{tensor._id}') else: outputs.append(tensor) dscp = f'Op(id={self._id}, signature={self.signature}, device={self.device}, inputs={inputs}, outputs={outputs})' return dscp + + +class IRBpOperation(IRCell): + + def __init__(self, data_num, grad_num, name='backward'): + signature = 'torch.autograd.backward' + self.data_num = data_num + self.grad_num = grad_num + super().__init__( + name, signature, + input_length=data_num + grad_num, + output_length=data_num + ) + + def datas(self, index: Optional[int] = None) -> Union[List[Any], Any]: + if index is None: + return self.inputs()[:self.data_num] + if index >= self.data_num: + raise RuntimeError( + f"Set the input out of range ({index} >= {self.data_num})" + ) + return self.inputs(index) + + def grads(self, index: Optional[int] = None) -> Union[List[Any], Any]: + if index is None: + return self.inputs()[self.data_num:] + elif index >= self.grad_num: + raise RuntimeError( + f"Set the input out of range ({index} >= {self.grad_num})" + ) + return self.inputs(index + self.data_num) + + def set_data(self, input_index: int, val: Any): + """ + Set the node inputs[input_index] with the tensor + + Args: + val: Union[IRTensor, Any] + + Return: + the set tensor + """ + if input_index >= self.data_num: + raise RuntimeError( + f"Set the input out of range ({input_index} >= {self.data_num})" + ) + return self.set_input(input_index, val) + + def set_grad(self, input_index: int, val: Any): + """ + Set the node gradient at input index + + Args: + input_idx: input index + val: Union[IRTensor, Any] + + Return: + The set val + """ + if input_index >= self.grad_num: + raise RuntimeError( + f"Set the grad out of range ({input_index} >= {self.grad_num})" + ) + return self.set_input(input_index + self.data_num, val) + + def __repr__(self): + datas = list() + for tensor in self.datas(): + if isinstance(tensor, IRTensor): + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + datas.append(f'{anno}{tensor._id}') + else: + datas.append(tensor) + + grads = list() + for tensor in self.grads(): + if isinstance(tensor, IRTensor): + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + grads.append(f'{anno}{tensor._id}') + else: + grads.append(tensor) + + outputs = list() + for tensor in self.outputs(): + if isinstance(tensor, IRTensor): + outputs.append(f't{tensor._id}') + else: + outputs.append(tensor) + + dscp = f'bOp(id={self._id}, signature={self.signature}, device={self.device}, grads={grads}, datas={datas}, outputs={outputs})' + return dscp + + +class IRDataOperation(IRCell): + + def __init__(self, data_num: int, name='dataloader'): + + signature = 'dataloader.__next__' + super().__init__(name, signature, 0, data_num) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 15f628e6..8a8f3290 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -5,13 +5,13 @@ from functools import partial import cube.graph.operator.function as function -from cube.graph.operator.operator import IROperation +from cube.graph.operator.operator import IRFwOperation class Sign2Op: @staticmethod - def map(signature: str) -> IROperation: + def map(signature: str) -> IRFwOperation: """ Map the signature to GenericLogicalOp """ diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index fd61477b..b9370ecd 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -3,7 +3,7 @@ import re from typing import List, Tuple, Optional -from cube.graph import IROperation +from cube.graph import IRFwOperation from cube.graph.tensor import IRFullTensor from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import Sign2Op @@ -24,7 +24,7 @@ class ScriptModuleParser: def parse_module(module, input_shapes: Optional[ Tuple[List[int],] ] = None, frame: Frame = Frame()) \ - -> Tuple[List[IRFullTensor], List[IROperation], List[IRFullTensor]]: + -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: """ The overall entry to parse a torchscript graph module """ @@ -46,7 +46,7 @@ def parse_module(module, if isinstance(val, IRFullTensor): val.shape = shape - all_ir_nodes: List[IROperation] = list() + all_ir_nodes: List[IRFwOperation] = list() for node in module.graph.nodes(): # debug info # print(f'on parsing:\n\t{node}') @@ -83,9 +83,9 @@ def ntype(node: torch._C.Node): raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @staticmethod - def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IROperation]: + def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation]: """ - Parse the node and return the IROperation nodes + Parse the node and return the IRFwOperation nodes """ node_type = ScriptModuleParser.ntype(node) if node_type == ScriptNodeKind.PrimCallFunction: @@ -100,7 +100,7 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IROperation]: return ScriptModuleParser.parse_prim_constant_node(node, module, frame) @staticmethod - def parse_prim_function_node(node, module, frame: Frame) -> List[IROperation]: + def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: """ parse node like: Tensor = prim::CallFunction(%5, %input.1, %3, %4) @@ -135,7 +135,7 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IROperation]: return [ir_node] @staticmethod - def parse_aten_node(node, module, frame: Frame) -> List[IROperation]: + def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: """ Parse script module node like: %13 : Tensor = aten::gt(%output1.1, %output2.1) @@ -181,7 +181,7 @@ def parse_aten_node(node, module, frame: Frame) -> List[IROperation]: return [ir_node] @staticmethod - def parse_prim_method_node(node, module, frame: Frame) -> List[IROperation]: + def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: """ Parse script module node like: %output.1 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) @@ -288,7 +288,7 @@ def parse_prim_constant_node(node, module, frame) -> List[None]: return list() @staticmethod - def parse_prim_if_node(node, module, frame: Frame) -> List[IROperation]: + def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: """ Parse script module node like %output2 : Tensor = prim::If(%15) # /tmp/ipykernel_27188/2459450745.py:13:8 diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 0f9e3740..7f4ad4a3 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -287,9 +287,18 @@ def as_param(self): """ self.requires_grad = True self._is_param = True + self._is_grad = False for sub_tensor in self._segments: sub_tensor.as_param() + def as_grad(self): + self.requires_grad = False + self._is_param = False + self._is_grad = True + for sub_tensor in self._segments: + sub_tensor.as_grad() + return self + def like(self): """ Create a new tensor with same name and shape, @@ -495,6 +504,16 @@ def as_param(self): self.parent.as_param() self.requires_grad = True self._is_param = True + self._is_grad = False + return self + + def as_grad(self): + if not self.parent.is_grad(): + self.parent.as_grad() + self.requires_grad = False + self._is_grad = True + self._is_param = False + return self def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap, None], shape=None): """ From b1d5dcc7b275479b1038e40a31ba0775e3ed874c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 6 Nov 2021 13:16:56 +0800 Subject: [PATCH 0256/1892] add grads --- cube/ir/cten.py | 66 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 11fc1e63..65a33af6 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -24,7 +24,7 @@ class IRCell: - """ + r""" IRCell serves as a general node for different purpose """ @@ -65,9 +65,11 @@ def __init__(self, # -- will only be set when initializing to a graph self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length)] + self._mirror = None + @property def device(self): - return self._device + return list(self._device) @device.setter def device(self, device_id: Union[int, List[int]]): @@ -80,6 +82,19 @@ def device(self, device_id: Union[int, List[int]]): raise KeyError("Require device Union[int, List[int]]") self._device = device_id + @property + def mirror(self): + """ + The mirror cell. E.g., forward op / backward op. + """ + return self._mirror + + @mirror.setter + def mirror(self, other): + if not isinstance(other, IRCell): + raise TypeError("Expected mirror to be IRCell") + self._mirror = other + def on_device(self, device_id: int): """ Check whether the operation is on device_id @@ -329,7 +344,7 @@ class IRTensor: IRTensor serves as IRGraph edge """ - _attr = ['name', '_is_param', 'requires_grad'] + _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad'] def __init__(self, shape=None, name=None): @@ -340,9 +355,12 @@ def __init__(self, shape=None, name=None): # device self._cell: List[IRCell] = list() - # forward graph self._is_param = False - self.requires_grad = True + self._is_grad = False + self._requires_grad = True + + self._grad = None + self.trace = None def attach_cell(self, cell: IRCell): @@ -373,12 +391,26 @@ def set_trace(self, sus: List): raise TypeError("Expected List[ScheduleUnit]") self.trace = sus + @property + def requires_grad(self): + return self._requires_grad + + @requires_grad.setter + def requires_grad(self, requires: bool): + if not isinstance(requires, bool): + raise TypeError("Expected bool") + self._requires_grad = requires + if not requires: + self.grad = None + def as_param(self): """ Set the tensor as trainable parameter """ self.requires_grad = True + self._is_grad = False self._is_param = True + return self def is_param(self): """ @@ -386,6 +418,30 @@ def is_param(self): """ return self._is_param + @property + def grad(self): + return self._grad + + @grad.setter + def grad(self, grad): + if grad is not None and not isinstance(grad, IRTensor): + raise TypeError("grad can only be None or Tensor") + if self.is_grad() and grad is not None: + raise RuntimeError("Cannot assign grad to a gradient") + if not self.requires_grad and grad is not None: + raise RuntimeError("Cannot assign grad to a frozen tensor") + self._grad = grad + self.requires_grad = True + + def as_grad(self): + self.requires_grad = False + self._is_param = False + self._is_grad = True + return self + + def is_grad(self): + return self._is_grad + def renew(self): """ Renew a new tensor with same name and shape, From eeb13e19cf5fdec43c1a8428e84b3dfd4a6a07de Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Nov 2021 12:34:50 +0800 Subject: [PATCH 0257/1892] fix bug on index map equal; add grad --- cube/graph/gpass.py | 10 ++++--- cube/graph/tensor.py | 4 +-- cube/ir/cten.py | 5 ---- tests/graph/test_graph.py | 55 +++++++++++++++++++++----------------- tests/graph/test_tensor.py | 37 +++++++++++++++++++++++++ 5 files changed, 75 insertions(+), 36 deletions(-) diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index 88bc6850..84b79a9e 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -80,11 +80,13 @@ def forward(graph, *args) -> IRGraph: for idx, val in enumerate(fnode.inputs()): # set input bnode.set_data(idx, val) - val = val if isinstance(val, IRTensor) else None - val = gener.renew(val, keep_param=False) - val = val.as_grad() if isinstance(val, IRTensor) else val # set gradient output - bnode.set_output(idx, val) + val = val if isinstance(val, IRTensor) else None + grad = gener.renew(val, keep_param=False) + grad = grad.as_grad() if isinstance(grad, IRTensor) else grad + if isinstance(val, IRTensor) and val.requires_grad: + val.grad = grad + bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): # set gradient input val = gener.renew(val, keep_param=False) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 7f4ad4a3..bde4fb05 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -26,7 +26,7 @@ def __eq__(self, other): return False for myslicer, oslicer in zip(self.get(), other.get()): mstart, mstop = myslicer.start, myslicer.stop - mstep = myslicer.step if myslicer.stop is not None else 1 + mstep = myslicer.step if myslicer.step is not None else 1 ostart, ostop = oslicer.start, oslicer.stop ostep = oslicer.step if oslicer.step is not None else 1 if mstart != ostart or mstop != ostop or mstep != ostep: @@ -292,7 +292,6 @@ def as_param(self): sub_tensor.as_param() def as_grad(self): - self.requires_grad = False self._is_param = False self._is_grad = True for sub_tensor in self._segments: @@ -510,7 +509,6 @@ def as_param(self): def as_grad(self): if not self.parent.is_grad(): self.parent.as_grad() - self.requires_grad = False self._is_grad = True self._is_param = False return self diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 65a33af6..9a9b5662 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -426,15 +426,10 @@ def grad(self): def grad(self, grad): if grad is not None and not isinstance(grad, IRTensor): raise TypeError("grad can only be None or Tensor") - if self.is_grad() and grad is not None: - raise RuntimeError("Cannot assign grad to a gradient") - if not self.requires_grad and grad is not None: - raise RuntimeError("Cannot assign grad to a frozen tensor") self._grad = grad self.requires_grad = True def as_grad(self): - self.requires_grad = False self._is_param = False self._is_grad = True return self diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index e3f9c77c..671a8038 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -1,6 +1,8 @@ from cube.graph.graph import IRGraph +from cube.graph.operator.operator import IRBpOperation from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.graph.operator.function import Linear +import cube.graph.gpass as gpass from cube.ir.cten import IRTensor @@ -49,7 +51,6 @@ def test_graph_init(): assert len(graph.inputs()) == 1 assert len(graph.outputs()) == 1 - assert graph.tag == 'forward' assert graph.name == 'MLP' all_inputs = list() @@ -100,34 +101,40 @@ def test_graph_nodes(): assert graph.nodes(1) == ops[1] -def test_graph_copy(): +def test_graph_forward(): inputs, ops, outputs = construct_model() graph = IRGraph(ops, inputs, outputs, 'MLP') - cgraph = graph.copy(reverse=False) - print(cgraph) + fgraph = gpass.forward(graph, *graph.inputs()) + print(fgraph) - cparam_id = [param._id for param in cgraph.parameters()] + fparam_id = [param._id for param in fgraph.parameters()] param_id = [param._id for param in graph.parameters()] - assert set(cparam_id) == set(param_id) - - for gnode, cnode in zip(graph.nodes(), cgraph.nodes()): - assert gnode.name == cnode.name - assert gnode.signature == cnode.signature - assert len(gnode.inputs()) == len(cnode.inputs()) - assert len(gnode.outputs()) == len(cnode.outputs()) - assert len(gnode.predecessors()) == len(cnode.predecessors()) - assert len(gnode.successors()) == len(cnode.successors()) - - rgraph = graph.copy(reverse=True) - print(rgraph) - for gnode, cnode in zip(graph.nodes(), rgraph.nodes()[::-1]): - assert gnode.name == cnode.name - assert gnode.signature == cnode.signature - assert len(gnode.outputs()) == len(cnode.inputs()) - assert len(gnode.inputs()) == len(cnode.outputs()) - assert len(gnode.predecessors()) == len(cnode.successors()) - assert len(gnode.successors()) == len(cnode.predecessors()) + assert set(fparam_id) == set(param_id) + + for gnode, fnode in zip(graph.nodes(), fgraph.nodes()): + assert gnode.name == fnode.name + assert gnode.signature == fnode.signature + assert len(gnode.inputs()) == len(fnode.inputs()) + assert len(gnode.outputs()) == len(fnode.outputs()) + assert len(gnode.predecessors()) == len(fnode.predecessors()) + assert len(gnode.successors()) == len(fnode.successors()) + + # test backward + bnodes = [node.mirror for node in fgraph.nodes()][::-1] + bgraph = IRGraph(bnodes, None, None, module_name='backwards') + print(bgraph) + bnode1, bnode2, bnode3 = bnodes + for bnode in bnodes: + assert isinstance(bnode, IRBpOperation) + assert len(bnode.inputs()) == 4 + assert len(bnode.outputs()) == 3 + assert bnode2 in bnode1.successors() + assert bnode3 in bnode2.successors() + assert not bnode3 in bnode1.successors() + assert bnode1 in bnode2.predecessors() + assert bnode2 in bnode3.predecessors() + assert not bnode1 in bnode3.predecessors() def test_graph_partition(): diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py index a3ac6d58..a0de3bd8 100644 --- a/tests/graph/test_tensor.py +++ b/tests/graph/test_tensor.py @@ -192,3 +192,40 @@ def test_sub_tensor_common(): assert rt.indices.get() == (slice(0, 512, 1), slice(512, 1024, 1)) assert lb.indices.get() == (slice(512, 1024, 1), slice(0, 512, 1)) assert rb.indices.get() == (slice(512, 1024, 1), slice(512, 1024, 1)) + + +def test_sub_tensor_as_grad(): + tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') + sub_tensor1 = tensor1.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_map = None, + shape = (1024, 512) + ) + + sub_tensor1.as_grad() + assert sub_tensor1.is_grad() + + sub_tensor2 = tensor1.select( + indices = (slice(0, 1024), slice(0, 512)), + val_map = (0, 4), + shape = (1024, 512) + ) + assert sub_tensor2.is_grad() + + +def test_sub_tensor_copy(): + tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') + sub_tensor1 = tensor1.select( + indices = (slice(0, 1024), slice(512, 1024)), + val_map = None, + shape = (1024, 512) + ) + sub_tensor2 = tensor1.select( + indices = (slice(0, 1024), slice(0, 512)), + val_map = (0, 4), + shape = (1024, 512) + ) + sub_tensor1.grad = sub_tensor2 + cpy_tensor = copy.copy(sub_tensor1) + assert cpy_tensor.grad == sub_tensor2 + From d3d9a05784d65c9ab51d9ec0f2c927d3cc1c04a1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Nov 2021 13:00:55 +0800 Subject: [PATCH 0258/1892] switch to full graph transformation --- cube/codegen/codegen.py | 82 +++--- cube/graph/gpass.py | 8 +- cube/schedule/__init__.py | 37 +-- cube/schedule/execplan.py | 2 +- cube/schedule/graphpass.py | 2 +- cube/schedule/pool.py | 33 ++- cube/schedule/su.py | 153 +++++++---- cube/schedule/sugraph.py | 407 ++++++++++++++++++++---------- cube/schedule/translator.py | 157 ++---------- examples/e2e.py | 5 +- tests/codegen/test_codegen.py | 54 ++-- tests/schedule/test_graphpass.py | 30 +-- tests/schedule/test_pool.py | 16 +- tests/schedule/test_su.py | 22 +- tests/schedule/test_sugraph.py | 33 ++- tests/schedule/test_translator.py | 61 +++-- tests/schedule/test_worflow.py | 20 +- 17 files changed, 611 insertions(+), 511 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 4f553ae0..da27b77d 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -54,7 +54,11 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # register forward input su_args: List[List[str]] = list() for su in device_sus: - fargs = [self.naming(input) for input in su.inputs()] + fargs = list() + for input in su.inputs(): + if isinstance(input, IRTensor) and input.is_param(): + continue + fargs.append(self.naming(input)) for name in fargs: self.symbols.create(name) su_args.append(fargs) @@ -115,8 +119,9 @@ def emit_var_declare(self, var: Any): """ if isinstance(var, IRTensor): name = self.naming(var) - # indicate this is a leaf tensor, should be parameter - if var.is_param()and self.symbols.create(name): + # emit parameter code + if var.is_param() and not self.symbols.exist(name): + self.symbols.create(name) code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(var.shape)}))' self.declare_region.append(code) elif isinstance(var, str): @@ -254,8 +259,6 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: - data_code = self.emit_data(device_sus) - fb.insert_body(data_code) for su in device_sus: name = f'su{self.seq.sus().index(su)}' code = self.emit_su(su, name=name) @@ -270,24 +273,6 @@ def gen(self, device: int, outfile=None, attach=False) -> str: f.write(code) return code - def emit_data(self, device_sus) -> List[str]: - """ - Emit dataloader iter code - """ - # TODO: dataloader to op node - inputs = list() - for su in device_sus: - su_inputs = [ - self.naming(input, su) for input in su.inputs() \ - if input.is_leaf(device_sus) - ] - inputs += su_inputs - data_code = list() - if len(inputs) != 0: - inputs = '(' + ', '.join(inputs + ['']) + ')' - data_code.append(inputs + ' = next(dataloader)') - return data_code - def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: """ Emit su code @@ -300,11 +285,16 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: raise RuntimeError("Dataloader su has no inputs") outputs = [self.naming(output, su) for output in su.outputs()] return_val = ','.join(outputs) - code = f'{return_val} = {su.signature}' + code = f'{return_val} = next(dataloader)' return code - elif su.stype == SUType.Forward or su.stype == SUType.Adapter: - inputs = [self.naming(tensor, su) for tensor in su.inputs()] + elif su.stype == SUType.Forward or su.stype == SUType.Comm: + inputs = list() + for tensor in su.inputs(): + if isinstance(tensor, IRTensor): + if tensor.is_param(): + continue + inputs.append(self.naming(tensor, su)) inputs = '(' + ', '.join(inputs + ['']) + ')' body = fsign.format( model = f'model.{name}', @@ -326,28 +316,38 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: # 3). output_grads are recved tesnors of this graph (graph.recv_tensors) # => backward graph input tensor (graph.recv_tensors) fsu = su.mirror - forward_inputs = [self.naming(tensor, fsu) for tensor in fsu.inputs()] - forward_inputs = '(' + ', '.join(forward_inputs + ['']) + ')' - forward_outputs = [self.naming(tensor, fsu) for tensor in fsu.outputs()] - forward_outputs = '(' + ', '.join(forward_outputs + ['']) + ')' - - grads = list() - for tensor in su.inputs(): - # the thensor is loss, no grad needs - if tensor in fsu.outputs(): - grads.append('None') + finputs = list() + for tensor in fsu.inputs(): + if isinstance(tensor, IRTensor): + if tensor.is_param(): + continue + finputs.append(self.naming(tensor, fsu)) + fargs = '(' + ', '.join(finputs + ['']) + ')' + + foutputs = list() + for tensor in fsu.outputs(): + foutputs.append(self.naming(tensor, fsu)) + foutputs = '(' + ', '.join(foutputs + ['']) + ')' + + in_grads = list() + for tensor in fsu.outputs(): + grad = tensor.grad + if grad in fsu.outputs(): + in_grads.append('None') else: - grads.append(self.naming(tensor, su)) - grads = '(' + ', '.join(grads + ['']) + ')' + in_grads.append(self.naming(grad, su)) + in_grads = '(' + ', '.join(in_grads + ['']) + ')' body = bsign.format( - input_tensors = forward_inputs, - output_tensors = forward_outputs, - output_grads = grads + input_tensors = fargs, + output_tensors = foutputs, + output_grads = in_grads ) # returned value are graph.outputs return_val = [self.naming(tensor, su) for tensor in su.outputs()] + # TODO: fix this by using grad attributed + return_val = return_val[:len(finputs)] if len(return_val) > 0: return_code = ', '.join(return_val) + ' = ' else: diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index 84b79a9e..a8e7edb9 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -89,9 +89,11 @@ def forward(graph, *args) -> IRGraph: bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): # set gradient input - val = gener.renew(val, keep_param=False) - val = val.as_grad() if isinstance(val, IRTensor) else val - bnode.set_grad(idx, val) + grad = gener.renew(val, keep_param=False) + grad = grad.as_grad() if isinstance(grad, IRTensor) else grad + if isinstance(val, IRTensor) and val.requires_grad: + val.grad = grad + bnode.set_grad(idx, grad) fnode.device = node.device bnode.device = node.device diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index b6690f2c..0baa0b0f 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -1,9 +1,10 @@ from typing import Callable, Optional import torch +from cube.graph.graph import IRGraph from cube.schedule.pool import SchedulePool -from cube.schedule.translator import IRDataLoader, LogicTranslator -from cube.schedule.sugraph import SUGraph +from cube.schedule.translator import IRDataLoader +from cube.schedule.sugraph import SUGraph, SUGraphGener from cube.schedule.graphpass import SUGraphPass from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -11,7 +12,7 @@ class SemanticModel: - def __init__(self, model: torch.nn.Module, input_shapes, policy_fn=None): + def __init__(self, model: torch.nn.Module, input_shapes): """ Create semantic model based on AI Scientist description. """ @@ -19,8 +20,6 @@ def __init__(self, model: torch.nn.Module, input_shapes, policy_fn=None): self.ir_graph = parser.convert( model, input_shapes=input_shapes ) - if policy_fn: - self.ir_graph = policy_fn(self.ir_graph, None) self._loaded_module = None def get_graph(self): @@ -47,7 +46,9 @@ def __call__(self, *args): return self.ir_graph(*args) -def schedule(model: SemanticModel, dataloader, policy_fn: Optional[Callable] = None): +def schedule(model: SemanticModel, dataloader, + transform_policy: Optional[Callable] = None, + schedule_policy: Optional[Callable] = None): """ AI Scientist calls like: @@ -100,26 +101,30 @@ def decorator(fn: Callable) -> Callable: # logic translator fn(ir_graph, ir_dataloader) - sus = SchedulePool().sus() - # adapter - sus_with_adapter = LogicTranslator.gen_adapter(sus) + nodes = SchedulePool().nodes() - # policy - su_graph = SUGraph(sus_with_adapter) - if policy_fn: + # graph transformation + graph = IRGraph(nodes, None, None, ir_graph.name) + if transform_policy: + graph = transform_policy(graph, None) + + # sugraph + sugraph = SUGraphGener.gen_sugraph(graph.nodes()) + if schedule_policy: # TODO: add resource - su_graph = policy_fn(su_graph, None) + sugraph = schedule_policy(sugraph, None) # check assignment and order - for su in su_graph.sus(): + print(sugraph) + for su in sugraph.sus(): if len(su.device) == 0: raise RuntimeError(f"SU {su} device is not set") - if not SUGraph.is_topo_order(su_graph.sus()): + if not SUGraph.is_topo_order(sugraph.sus()): raise RuntimeError(f"SUGraph order is not topological order") # graph pass to remove redundant sus - su_graph = SUGraphPass.remove_redundant_adapters(su_graph) + su_graph = SUGraphPass.remove_redundant_adapters(sugraph) su_graph = SUGraphPass.merge_small_sus(su_graph) print(su_graph) diff --git a/cube/schedule/execplan.py b/cube/schedule/execplan.py index ef8e920b..34094600 100644 --- a/cube/schedule/execplan.py +++ b/cube/schedule/execplan.py @@ -42,7 +42,7 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): span = 1 elif su.stype == SUType.Backward: span = 2 - elif su.stype == SUType.Adapter: + elif su.stype == SUType.Comm: span = 0.1 spans.append(span) diff --git a/cube/schedule/graphpass.py b/cube/schedule/graphpass.py index 90691048..23e4b89e 100644 --- a/cube/schedule/graphpass.py +++ b/cube/schedule/graphpass.py @@ -14,7 +14,7 @@ def remove_redundant_adapters(sugraph: SUGraph) -> SUGraph: """ redundant_adapters = list() for su in sugraph.sus(): - if su.stype != SUType.Adapter: + if su.stype != SUType.Comm: for idx in range(len(su.outputs())): send_adapters, recv_adapters = su.out_adapters(idx) for sadapter, radapter in zip(send_adapters, recv_adapters): diff --git a/cube/schedule/pool.py b/cube/schedule/pool.py index 11de56e9..fd9f2045 100644 --- a/cube/schedule/pool.py +++ b/cube/schedule/pool.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any import copy @@ -8,7 +8,8 @@ class __SchedulePool: def __init__(self): - self._sus = list() + self._nodes = list() + self._tapes = dict() instance = None @@ -19,15 +20,31 @@ def __init__(self): def __getattr__(self, name): return getattr(self.instance, name) - def add_su(self, su): - self.instance._sus.append(su) + def add_node(self, node): + self.instance._nodes.append(node) - def sus(self) -> List: - return copy.copy(self.instance._sus) + def nodes(self) -> List: + return copy.copy(self.instance._nodes) + + def tape(self, tensor, trace: Any): + """ + Record the trace generated to this tensor + """ + self.instance._tapes[tensor._id] = trace + + def get_tape(self, tensor): + """ + Get the trace given the tensor + """ + if tensor._id not in self.instance._tapes: + return None + else: + return self.instance._tapes[tensor._id] def clear(self): - self.instance._sus = list() + self.instance._nodes = list() + self.instance._tapes = dict() def __repr__(self): - dscp = '\n'.join([repr(su) for su in self._sus]) + dscp = '\n'.join([repr(node) for node in self._nodes]) return dscp diff --git a/cube/schedule/su.py b/cube/schedule/su.py index aeca5d11..5eda73b5 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -2,35 +2,44 @@ import copy from enum import Enum -from cube.ir.cten import IRCell +from cube.ir.cten import IRCell, IRTensor +from cube.graph.operator import IRBpOperation class SUType(Enum): + Dataloader = 'next(dataloader)' + # outputs = cube.runtime.temporal.forward(model, *args) - Forward = 'cube.runtime.temporal.forward' + Forward = 'cube.runtime.executor.fexecute' # grads = cube.runtime.temporal.backward( # input_tensors, output_tensors, output_grads # ) - Backward = 'cube.runtime.temporal.backward' + Backward = 'cube.runtime.executor.backward' + + Transform = 'cube.runtime.adapter.transform' # cube.runtime.collectives.sendrecv(send_tensors, send_ranks, # recv_shapes, from_ranks # ) - Adapter = 'cube.runtime.collectives.sendrecv' + Comm = 'cube.runtime.adapter.sendrecv' - Dataloader = 'next(dataloader)' + Empty = 'None' class ScheduleUnit(IRCell): - """ - Action recv tensors must be inside of Action inputs, - and can be mapped to Action.graph.inputs - + r""" + ScheduleUnit for policy scheduling. """ def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): + """ + Create a ScheduleUnit. + + Args: + nodes (List[IRCell]): A list of nodes in IRGraph + """ if not all([isinstance(node, IRCell) for node in nodes]): raise ValueError("Expected each nodes to be List[IRCell]") @@ -38,10 +47,11 @@ def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): raise TypeError("Expected stype be SUType") # get inputs and outputs + # TODO: fix bug on multi-branch inputs = IRCell.get_inputs(nodes) - inputs = [input for input in inputs if not input.is_param()] + # inputs = [input for input in inputs if not input.is_param()] outputs = IRCell.get_outputs(nodes) - outputs = [output for output in outputs if not output.is_param()] + # outputs = [output for output in outputs if not output.is_param()] super().__init__( name = name, signature = stype.value, @@ -81,8 +91,6 @@ def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): self._ctrl_predecessors = list() self._ctrl_successors = list() - self.mirror = None - def __copy__(self): """ Copy the SU. Note the mirror su is also copied @@ -94,19 +102,10 @@ def __copy__(self): mirror_su = ScheduleUnit( mirror_su._nodes, mirror_su.stype, mirror_su.name ) - su.set_mirror(mirror_su) - mirror_su.set_mirror(su) + su.mirror = mirror_su + mirror_su.mirror = su return su - def set_mirror(self, su): - """ - Create a mirrored ScheduleUnit: the - inputs and outputs are reversed - """ - if not isinstance(su, ScheduleUnit): - raise TypeError("Expected mirror to be ScheduleUnit") - self.mirror = su - def in_adapters(self, index: Optional[int] = None) -> List: """ Get adapter for the input tensor at index @@ -209,6 +208,8 @@ def _clear_adapters(self): self._recv_in_adapters: List[List[ScheduleUnit]] = [ list() for _ in range(len(self.inputs())) ] + self._merge_adapters: List[ScheduleUnit] = [None] * len(self._inputs) + self._select_adapters: List[ScheduleUnit] = [None] * len(self._outputs) self._send_out_adapters: List[List[ScheduleUnit]] = [ list() for _ in range(len(self.outputs())) ] @@ -216,7 +217,7 @@ def _clear_adapters(self): list() for _ in range(len(self.outputs())) ] - def _add_in_adapter(self, index: int, send_adapter, recv_adapter): + def _add_in_adapter(self, index: int, send_adapters, recv_adapters): """ Add adapters to the input tensor of this SU @@ -227,12 +228,19 @@ def _add_in_adapter(self, index: int, send_adapter, recv_adapter): """ if index >= len(self._inputs): raise ValueError(f"index {index} out of range {len(self._inputs)}") - if not isinstance(send_adapter, ScheduleUnit): - raise TypeError("Expected send adapter to be ScheduleUnit") - if not isinstance(recv_adapter, ScheduleUnit): - raise TypeError("Expected recv adapter to be ScheduleUnit") - self._send_in_adapters[index].append(send_adapter) - self._recv_in_adapters[index].append(recv_adapter) + if isinstance(send_adapters, ScheduleUnit): + send_adapters = [send_adapters] + if not all(isinstance(adapter, ScheduleUnit) for adapter in send_adapters): + raise TypeError("Expected send adapter to be (list of) ScheduleUnit") + if isinstance(recv_adapters, ScheduleUnit): + recv_adapters = [recv_adapters] + if not all(isinstance(adapter, ScheduleUnit) for adapter in send_adapters): + raise TypeError("Expected recv adapters to be (list of) ScheduleUnit") + if len(send_adapters) != len(recv_adapters): + raise ValueError("Expected same number of send / recv adapters") + for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): + self._send_in_adapters[index].append(send_adapter) + self._recv_in_adapters[index].append(recv_adapter) def _set_merge_adapter(self, index: int, merge_adapter): """ @@ -244,11 +252,11 @@ def _set_merge_adapter(self, index: int, merge_adapter): """ if index >= len(self._inputs): raise ValueError(f"index {index} out of range {len(self._inputs)}") - if not isinstance(merge_adapter, ScheduleUnit): - raise TypeError("Expected merge adapter to be ScheduleUnit") + if merge_adapter is not None and not isinstance(merge_adapter, ScheduleUnit): + raise TypeError("Expected merge adapter to be None or ScheduleUnit") self._merge_adapters[index] = merge_adapter - def _add_out_adapter(self, index: int, send_adapter, recv_adapter): + def _add_out_adapter(self, index: int, send_adapters, recv_adapters): """ Add adapters to the output tensor of this SU @@ -259,12 +267,19 @@ def _add_out_adapter(self, index: int, send_adapter, recv_adapter): """ if index >= len(self._outputs): raise ValueError(f"index {index} out of range {len(self._outputs)}") - if not isinstance(send_adapter, ScheduleUnit): - raise TypeError("Expected send adapter to be ScheduleUnit") - if not isinstance(recv_adapter, ScheduleUnit): - raise TypeError("Expected recv adapter to be ScheduleUnit") - self._send_out_adapters[index].append(send_adapter) - self._recv_out_adapters[index].append(recv_adapter) + if isinstance(send_adapters, ScheduleUnit): + send_adapters = [send_adapters] + if not all(isinstance(adapter, ScheduleUnit) for adapter in send_adapters): + raise TypeError("Expected send adapter to be (list of) ScheduleUnit") + if isinstance(recv_adapters, ScheduleUnit): + recv_adapters = [recv_adapters] + if not all(isinstance(adapter, ScheduleUnit) for adapter in send_adapters): + raise TypeError("Expected recv adapters to be (list of) ScheduleUnit") + if len(send_adapters) != len(recv_adapters): + raise ValueError("Expected same number of send / recv adapters") + for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): + self._send_out_adapters[index].append(send_adapter) + self._recv_out_adapters[index].append(recv_adapter) def _set_select_adapter(self, index: int, select_adapter): """ @@ -276,10 +291,40 @@ def _set_select_adapter(self, index: int, select_adapter): """ if index >= len(self._outputs): raise ValueError(f"index {index} out of range {len(self._inputs)}") - if not isinstance(select_adapter, ScheduleUnit): - raise TypeError("Expected merge adapter to be ScheduleUnit") + if select_adapter is not None and not isinstance(select_adapter, ScheduleUnit): + raise TypeError("Expected merge adapter to be Optional[ScheduleUnit]") self._select_adapters[index] = select_adapter + def _remove_adapter(self, adapter): + """ + Remove the adapter + """ + for send_adapters in self._send_in_adapters: + if adapter in send_adapters: + send_adapters.remove(adapter) + return True + for recv_adapters in self._recv_in_adapters: + if adapter in recv_adapters: + recv_adapters.remove(adapter) + return True + if adapter in self._merge_adapters: + idx = self._merge_adapters.index(adapter) + self._merge_adapters[idx] = None + return True + if adapter in self._select_adapters: + idx = self._select_adapters.index(adapter) + self._select_adapters[idx] = None + return True + for send_adapters in self._send_out_adapters: + if adapter in send_adapters: + send_adapters.remove(adapter) + return True + for recv_adapters in self._recv_out_adapters: + if adapter in recv_adapters: + recv_adapters.remove(adapter) + return True + return False + def nodes(self, index: Optional[int] = None): """ Get node at position index @@ -376,7 +421,27 @@ def successors(self, index: Optional[int] = None) -> List: raise TypeError("Expected index to be None or int") def __repr__(self): - su_inputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.inputs()] - su_outputs = [f't{tensor._id}-dev{tensor.device}' for tensor in self.outputs()] + su_inputs = list() + for tensor in self.inputs(): + if isinstance(tensor, IRTensor): + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + su_inputs.append(f'{anno}{tensor._id}') + else: + su_inputs.append(tensor) + su_outputs = list() + for tensor in self.outputs(): + if isinstance(tensor, IRTensor): + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + su_outputs.append(f'{anno}{tensor._id}') + else: + su_outputs.append(tensor) dscp = f'SU({self.stype}, nodes={len(self.nodes())})-dev{self.device}: {su_inputs} -> {su_outputs}' return dscp diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 700995e0..21d6868c 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -1,8 +1,15 @@ from typing import List, Optional, Union import copy -from cube.ir.cten import IRCell + +from cube.ir.cten import IRCell, IRTensor +from cube.graph.operator import IRBpOperation +from cube.graph.operator import IRDataOperation +from cube.graph.operator import IRFwOperation +from cube.graph.graph import IRGraph from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.adapter.comm import IRCommunication +from cube.schedule.adapter.select import IRTensorReshape class SUGraph(IRCell): @@ -15,7 +22,9 @@ def __init__(self, sus: List[ScheduleUnit]): ) inputs = IRCell.get_inputs(sus) + inputs = [input for input in inputs if not input.is_param()] outputs = IRCell.get_outputs(sus) + outputs = [output for output in outputs if not output.is_param()] super().__init__( name = 'SU', signature = 'None', @@ -28,7 +37,7 @@ def __init__(self, sus: List[ScheduleUnit]): self.set_output(idx, output) self.sequence = sus - self.reset_dependency() + SUGraph.reset_dependency(self.sequence) @property def nnodes(self) -> int: @@ -37,25 +46,38 @@ def nnodes(self) -> int: """ return len(self.sequence) - def reset_dependency(self): + @staticmethod + def reset_dependency(sus: List[ScheduleUnit]): """ Reset the node dataflow dependency """ - # set node predecessors and successors - for src_idx in range(self.nnodes): - src_su = self.sequence[src_idx] - src_su._successors = [ - list() for _ in range(len(src_su.outputs())) - ] - for dst_su in self.sequence[src_idx+1:]: - dst_su._predecessors = [ - list() for _ in range(len(dst_su.inputs())) - ] - for out_idx, out_tensor in enumerate(src_su.outputs()): - for in_idx, in_tensor in enumerate(dst_su.inputs()): + if not all([isinstance(su, ScheduleUnit) for su in sus]): + raise TypeError("Expected list of schedule unit") + for su in sus: + su.clear_predecessor() + su.clear_successor() + for src_idx in range(len(sus)): + src = sus[src_idx] + for dst in sus[src_idx+1:]: + for out_idx, out_tensor in enumerate(src.outputs()): + for in_idx, in_tensor in enumerate(dst.inputs()): if out_tensor.overlap(in_tensor): - src_su.add_successor(out_idx, dst_su) - dst_su.add_predecessor(in_idx, src_su) + src.add_successor(out_idx, dst) + dst.add_predecessor(in_idx, src) + + @staticmethod + def gen_comm_adapter(sus: List[ScheduleUnit]): + """ + Generate communication adapter for each SU + """ + pass + + @staticmethod + def gen_trans_adapter(sus: List[ScheduleUnit]): + """ + Generate transformation adapter for each SU + """ + pass def __len__(self): return len(self.sequence) @@ -78,6 +100,21 @@ def sus(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") + def fsus(self) -> List[ScheduleUnit]: + """ + Get forward ScheduleUnits sequence. + """ + return [su for su in self.sequence if su.stype == SUType.Forward] + + def get_graph(self, sus: List[ScheduleUnit], name: str) -> IRGraph: + """ + Generate IRGraph + """ + nodes = list() + for su in sus: + nodes += su.nodes() + return IRGraph(nodes, None, None, name) + def happen_before(self, su1, su2): """ Check if the su1 -> (happened before) su2 @@ -121,130 +158,111 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: if fail: None """ - def _adapter_merge(first_su: ScheduleUnit, second_su: ScheduleUnit, merged_su: ScheduleUnit): - # move from first_su adapter - # print(f' 1st SU: {first_su} \n 2nd SU: {second_su} \n merged SU: {merged_su}') - for idx, input in enumerate(first_su.inputs()): - send_adapters, recv_adapters = first_su.in_adapters(idx) - merge_adapter = first_su.merge_adapters(idx) - merge_idx = merged_su.inputs().index(input) - for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): - merged_su._add_in_adapter(merge_idx, send_adapter, recv_adapter) - if merge_adapter in self.sequence: - merged_su._set_merge_adapter(merge_idx, merge_adapter) - for idx, output in enumerate(first_su.outputs()): - send_adapters, recv_adapters = first_su.out_adapters(idx) - select_adapter = first_su.select_adapters(idx) - if output in merged_su.outputs() and output not in second_su.outputs(): - merge_idx = merged_su.outputs().index(output) - for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): - merged_su._add_out_adapter(merge_idx, send_adapter, recv_adapter) - if select_adapter: - merged_su._set_select_adapter(merge_idx, select_adapter) - else: - if merge_adapter in self.sequence: - self.sequence.remove(merge_adapter) - # move from su2 adapter - for idx, input in enumerate(second_su.inputs()): - send_adapters, recv_adapters = second_su.in_adapters(idx) - merge_adapter = second_su.merge_adapters(idx) - if input in merged_su.inputs() and input not in first_su.inputs(): - merge_idx = merged_su.inputs().index(input) - for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): - merged_su._add_in_adapter(merge_idx, send_adapter, recv_adapter) - if merge_adapter: - merged_su._set_merge_adapter(merge_idx, merge_adapter) - else: - for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): - # print(f'removing: {send_adapter}') - # print(f'removing: {recv_adapter}') - if send_adapter in self.sequence: - self.sequence.remove(send_adapter) - if recv_adapter in self.sequence: - self.sequence.remove(recv_adapter) - if merge_adapter in self.sequence: - self.sequence.remove(merge_adapter) - for idx, output in enumerate(second_su.outputs()): - send_adapters, recv_adapters = second_su.out_adapters(idx) - select_adapter = second_su.select_adapters(idx) - if output in merged_su.outputs(): - merge_idx = merged_su.outputs().index(output) - for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): - merged_su._add_out_adapter(merge_idx, send_adapter, recv_adapter) - if select_adapter: - merged_su._set_select_adapter(merge_idx, select_adapter) - else: - if select_adapter: - self.sequence.remove(select_adapter) + fsus = self.fsus() + if su1 not in fsus: + raise RuntimeError(f"SU1: {su1} not in forward SUs") + if su2 not in fsus: + raise RuntimeError(f"SU2: {su2} not in forward SUs") - if not isinstance(su1, ScheduleUnit) or \ - not isinstance(su2, ScheduleUnit): - raise TypeError("Expected SU1 and SU2 are ScheduleUnit") - if su1 not in self.sequence: - raise ValueError(f"su1: {su1} not in sequence") - if su2 not in self.sequence: - raise ValueError(f"su2: {su2} not in sequence") - - # 2) all the nodes in both SU are on the same device - if su1 == su2 or su1.stype != su2.stype: - return None + idx1, idx2 = fsus.index(su1), fsus.index(su2) + su1, su2 = (su1, su2) if idx1 < idx2 else (su2, su1) + + # condition 1): same device if su1.device != su2.device: return None - if su1.stype == SUType.Adapter: - raise NotImplementedError("Not supported for merging Adapter") - - index_su1 = self.sequence.index(su1) - index_su2 = self.sequence.index(su2) - su1, su2 = (su1, su2) if index_su1 < index_su2 else (su2, su1) - # 3) deadlock-free merge - index_su1, index_su2 = min(index_su1, index_su2), max(index_su1, index_su2) - inter_sus = self.sequence[index_su1+1:index_su2] + # condition 2): su2 input cannot be got from both su1 and other su + start, stop = min(idx1, idx2), max(idx1, idx2) + inter_sus = fsus[start+1:stop] for su in inter_sus: - # in theory the below condition satisfies merge, but it may - # break the topo order - # e.g., su1 -> adapter1 ,....., adapter2 -> su2 - # if self.happen_before(su1, su) and self.happen_before(su, su2): - # to keep topo order: - if su.stype != SUType.Adapter and self.happen_before(su, su2): + if self.happen_before(su, su2): + return None + for idx in range(len(su2.inputs())): + prev_sus = su2.predecessors(idx) + prev_sus = [su for su in prev_sus if su.stype != SUType.Comm] + if su2 in prev_sus and len(prev_sus) > 1: return None - # merge forward su - sub_nodes = su1.nodes() + su2.nodes() - merged_su = ScheduleUnit(sub_nodes, su1.stype) - merged_su.device = su1.device - _adapter_merge(su1, su2, merged_su) - - # merge mirrored su - # mirror_su2 -> mirror_su1 - mirror_su1, mirror_su2 = su1.mirror, su2.mirror - merged_mirror_su = None - if mirror_su1 and mirror_su2: - if mirror_su1.device == mirror_su2.device: - sub_nodes = mirror_su2.nodes() + mirror_su1.nodes() - merged_mirror_su = ScheduleUnit(sub_nodes, mirror_su1.stype) - merged_mirror_su.device = mirror_su1.device - _adapter_merge(mirror_su2, mirror_su1, merged_mirror_su) - # set mirror - merged_su.set_mirror(merged_mirror_su) - merged_mirror_su.set_mirror(merged_su) - elif mirror_su1 or mirror_su2: - raise RuntimeError( - "The merged su should be both have mirror or both not have." + # start merging + fnodes = su1.nodes() + su2.nodes() + # TODO: fix multi-branch + fsu = ScheduleUnit(fnodes, SUType.Forward, name='fsu') + fsu.device = su1.device + + bnodes = [node.mirror for node in fnodes][::-1] + skip_bp = all([bnode is None for bnode in bnodes]) + if not skip_bp: + bnode = IRBpOperation( + data_num=len(fsu.inputs()), + grad_num=len(fsu.outputs()) ) - - # replace - self.sequence[index_su1] = merged_su + for idx, input in enumerate(fsu.inputs()): + bnode.set_data(idx, input) + fout_grads = [out.grad for out in fsu.outputs()] + for idx, fout_grad in enumerate(fout_grads): + bnode.set_grad(idx, fout_grad) + for idx, fin in enumerate(fsu.inputs()): + if isinstance(fin, IRTensor): + bnode.set_output(idx, fin.grad) + else: + bnode.set_output(idx, None) + for output in fsu.outputs(): + print(output.grad) + bsu = ScheduleUnit([bnode], stype=SUType.Backward, name='bsu') + bsu.device = su2.mirror.device + fsu.mirror = bsu + bsu.mirror = fsu + + def _set_adapters(su1, su2, msu): + # set adapter + for idx, input in enumerate(msu.inputs()): + if input in su1.inputs(): + su1_idx = su1.inputs().index(input) + adapters = su1.in_adapters(su1_idx) + merge_adapter = su1.merge_adapters(su1_idx) + elif input in su2.inputs(): + su2_idx = su2.inputs().index(input) + adapters = su2.in_adapters(su2_idx) + merge_adapter = su2.merge_adapters(su2_idx) + else: + raise RuntimeError("Internal Error: not found input SU") + msu._add_in_adapter(idx, *adapters) + msu._set_merge_adapter(idx, merge_adapter) + for idx, output in enumerate(msu.outputs()): + if output in su1.outputs(): + su1_idx = su1.outputs().index(output) + adapters = su1.out_adapters(su1_idx) + select_adapter = su1.select_adapters(su1_idx) + elif output in su2.outputs(): + su2_idx = su2.outputs().index(output) + adapters = su2.out_adapters(su2_idx) + select_adapter = su2.select_adapters(su2_idx) + else: + raise RuntimeError("Internal Error: not found output SU") + msu._add_out_adapter(idx, *adapters) + msu._set_merge_adapter(idx, select_adapter) + # remove adapters + for idx, input in enumerate(su2.inputs()): + if input not in msu.inputs(): + sadapters, radapters = su2.in_adapters(idx) + for adapter in [sadapters + radapters]: + if adapter in self.sequence: + self.sequence.remove(adapter) + + _set_adapters(su1, su2, fsu) + if not skip_bp: + _set_adapters(su2.mirror, su1.mirror, bsu) + + # replace + self.sequence[self.sequence.index(su1)] = fsu self.sequence.remove(su2) - if merged_mirror_su: - if mirror_su1 in self.sequence and mirror_su2 in self.sequence: - index_mirror_su2 = self.sequence.index(mirror_su2) - self.sequence[index_mirror_su2] = merged_mirror_su - self.sequence.remove(mirror_su1) + if not skip_bp: + self.sequence[self.sequence.index(su2.mirror)] = bsu + self.sequence.remove(su1.mirror) - # TODO: optimize: reset dependency - self.reset_dependency() - return merged_su + # re-gen adapter + SUGraph.reset_dependency(self.sequence) + return fsu def add_flow(self, su1: ScheduleUnit, su2: ScheduleUnit): """ @@ -282,7 +300,7 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): elif not all([isinstance(int, rank) for rank in ranks]): raise TypeError("Expected type ranks to be Union[int, List[int]]") - if su.stype == SUType.Adapter: + if su.stype == SUType.Comm: return False if set(su.device) == set(ranks): @@ -296,7 +314,7 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): for su in sus: index = self.sus().index(su) self.sequence.insert(index, su) - self.reset_dependency() + SUGraph.reset_dependency(self.sequence) for su, rank in zip(sus, ranks): self.assign(su, rank) @@ -347,6 +365,96 @@ def set_order(self, seq: List[ScheduleUnit]): self.sequence = seq return True + @staticmethod + def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: + """ + Each computation SU has adapters for its inputs. + """ + sugraph = SUGraph(sus) + + # clear adapters + for su in sugraph.sus(): + su._clear_adapters() + + for su in sugraph.sus(): + for in_idx, input in enumerate(su.inputs()): + if not isinstance(input, IRTensor): + continue + pre_sus = su.predecessors(in_idx) + tensor_segments = list() + for pre_su in pre_sus: + for out_idx, output in enumerate(pre_su.outputs()): + if output.overlap(input): + sub_tensor = input.common(output) + if sub_tensor != input: + tensor_segments.append(sub_tensor) + send_op = IRCommunication( + send_tensors=[sub_tensor], + send_ranks = [-1] + ) + recv_op = IRCommunication( + recv_tensors=[sub_tensor], + recv_ranks = [-1] + ) + send_op.pair(recv_op) + send_su = ScheduleUnit([send_op], SUType.Comm, name='send') + recv_su = ScheduleUnit([recv_op], SUType.Comm, name='recv') + su._add_in_adapter(in_idx, send_su, recv_su) + send_su.device = su.device + pre_su._add_out_adapter(out_idx, send_su, recv_su) + recv_su.device = su.device + # add adapter for merge + if len(tensor_segments) != 0: + merge_op = IRTensorReshape( + src_tensors=tensor_segments, dst_tensors=[input] + ) + merge_su = ScheduleUnit([merge_op], SUType.Comm, name='merge') + su._set_merge_adapter(in_idx, merge_su) + merge_su.device = su.device + + # add adapter for select + for out_idx, output in enumerate(su.outputs()): + if not isinstance(output, IRTensor): + continue + select_tensors = list() + send_adapters, recv_adapters = su.out_adapters(out_idx) + for send_adapter in send_adapters: + for tensor in send_adapter.nodes(0).send_tensors: + if tensor != output: + select_tensors.append(tensor) + if len(select_tensors) != 0: + select_op = IRTensorReshape( + src_tensors=[output], dst_tensors=select_tensors + ) + select_su = ScheduleUnit( + [select_op], SUType.Comm, name='select' + ) + su._set_select_adapter(out_idx, select_su) + select_su.device = su.device + + sus_with_adapter = list() + for su in sus: + # send + recv + merge + for idx in range(len(su.inputs())): + merge_su = su.merge_adapters(idx) + send_adapters, recv_adapters = su.in_adapters(idx) + # PyTorch implementation issue: forward + backward happened on same device + if su.stype == SUType.Backward and not su.inputs(idx).is_grad(): + continue + for send_su, recv_su in zip(send_adapters, recv_adapters): + sus_with_adapter.append(send_su) + sus_with_adapter.append(recv_su) + if merge_su: + sus_with_adapter.append(merge_su) + # excute + sus_with_adapter.append(su) + # select + for idx in range(len(su.outputs())): + select_su = su.select_adapters(idx) + if select_su: + sus_with_adapter.append(select_su) + return sus_with_adapter + @staticmethod def is_topo_order(seq: List[ScheduleUnit], integrity_check=False): @@ -386,6 +494,39 @@ def __repr__(self): return dscp +class SUGraphGener: + + @staticmethod + def gen_sugraph(nodes) -> SUGraph: + """ + Generate SUGraph from SchedulePool + """ + sus = list() + fnodes = list() + fsus: List[ScheduleUnit] = list() + for node in nodes: + su = ScheduleUnit([node], stype=SUType.Empty, name='su') + if isinstance(node, IRDataOperation): + stype = SUType.Dataloader + elif isinstance(node, IRFwOperation): + stype = SUType.Forward + fnodes.append(node) + fsus.append(su) + elif isinstance(node, IRBpOperation): + stype = SUType.Backward + index = fnodes.index(node.mirror) + fsu = fsus[index] + su.mirror = fsu + fsu.mirror = su + else: + raise NotImplementedError("Not implemented node type") + su.stype = stype + sus.append(su) + sus_with_adapter = SUGraph.gen_adapter(sus) + sugraph = SUGraph(sus_with_adapter) + return sugraph + + class SeqSpace: @staticmethod diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index 113f59d2..b7e4be3a 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -4,16 +4,13 @@ The traning logic first translate the training logic into Schedule Units, and then add Adapter ScheduleUnit """ -from typing import List import torch -from cube.ir.cten import IRCell, IRTensor +from cube.ir.cten import IRTensor from cube.graph.tensor import IRFullTensor -from cube.schedule.adapter.comm import IRCommunication -from cube.schedule.adapter.select import IRTensorReshape -from cube.schedule.su import SUType, ScheduleUnit +from cube.graph.operator import IRDataOperation +import cube.graph.gpass as gpass from cube.schedule.pool import SchedulePool -from cube.schedule.sugraph import SUGraph class IRDataLoader: @@ -47,18 +44,13 @@ def load_data(dataloader: IRDataLoader): data.requires_grad = False outputs.append(data) - cell = IRCell( - name='dataloader', - signature='dataloader.__next__', - input_length=0, - output_length=len(datas) + data_op = IRDataOperation( + data_num=len(datas) ) for idx, output in enumerate(outputs): - cell.set_output(idx, output) - - su = ScheduleUnit([cell], stype=SUType.Dataloader, name='DataLoader') - SchedulePool().add_su(su) + data_op.set_output(idx, output) + SchedulePool().add_node(data_op) if len(outputs) == 0: return elif len(outputs) == 1: return outputs[0] else: return tuple(outputs) @@ -68,129 +60,32 @@ def forward(graph, *args): """ Translator Action: forward an IRGraph """ - - def _forward(graph, stype, *args): - # set input - for input, arg in zip(graph.inputs(), args): - graph._replace_tensor(input, arg) - # translate to SUs - sus = list() - for node in graph.nodes(): - su = ScheduleUnit([node], stype, name=str(stype)) - sus.append(su) - return sus - - # forward graph - fgraph = graph.copy(reverse=False) - # backward graph - bgraph = graph.copy(reverse=True) - bgraph.tag = 'backward' - - # translate forward graph - fsus = _forward(fgraph, SUType.Forward, *args) - bsus = _forward(bgraph, SUType.Backward, *(fgraph.outputs())) - for fsu, bsu in zip(fsus, bsus[::-1]): - fsu.set_mirror(bsu) - bsu.set_mirror(fsu) - SchedulePool().add_su(fsu) - + fgraph = gpass.forward(graph, *args) + for node in fgraph.nodes(): + SchedulePool().add_node(node) for output in fgraph.outputs(): - output.set_trace(fsus) - + SchedulePool().tape(output, fgraph.nodes()) outputs = fgraph.outputs() if len(outputs) == 1: return outputs[0] elif len(outputs) == 0: return None else: return outputs @staticmethod - def backward(tensor: IRTensor): + def backward(loss: IRTensor): """ Translator Action: backward a tensor """ - if tensor.trace is None: - return - for fsu in tensor.trace[::-1]: - SchedulePool().add_su(fsu.mirror) - - @staticmethod - def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: - """ - Each computation SU has adapters for its inputs - """ - sugraph = SUGraph(sus) - - # clear adapters - for su in sugraph.sus(): - su._clear_adapters() - - for su in sugraph.sus(): - for in_idx, input in enumerate(su.inputs()): - if not isinstance(input, IRTensor): - continue - pre_sus = su.predecessors(in_idx) - tensor_segments = list() - for pre_su in pre_sus: - for out_idx, output in enumerate(pre_su.outputs()): - if output.overlap(input): - sub_tensor = input.common(output) - if sub_tensor != input: - tensor_segments.append(sub_tensor) - send_op = IRCommunication( - send_tensors=[sub_tensor], - send_ranks = [-1] - ) - recv_op = IRCommunication( - recv_tensors=[sub_tensor], - recv_ranks = [-1] - ) - send_op.pair(recv_op) - send_su = ScheduleUnit([send_op], SUType.Adapter, name='send') - recv_su = ScheduleUnit([recv_op], SUType.Adapter, name='recv') - su._add_in_adapter(in_idx, send_su, recv_su) - pre_su._add_out_adapter(out_idx, send_su, recv_su) - # add adapter for merge - if len(tensor_segments) != 0: - merge_op = IRTensorReshape( - src_tensors=tensor_segments, dst_tensors=[input] - ) - merge_su = ScheduleUnit([merge_op], SUType.Adapter, name='merge') - su._set_merge_adapter(in_idx, merge_su) - - # add adapter for select - for out_idx, output in enumerate(su.outputs()): - if not isinstance(output, IRTensor): - continue - select_tensors = list() - send_adapters, recv_adapters = su.out_adapters(out_idx) - for send_adapter in send_adapters: - for tensor in send_adapter.nodes(0).send_tensors: - if tensor != output: - select_tensors.append(tensor) - if len(select_tensors) != 0: - select_op = IRTensorReshape( - src_tensors=[output], dst_tensors=select_tensors - ) - select_su = ScheduleUnit( - [select_op], SUType.Adapter, name='select' - ) - su._set_select_adapter(out_idx, select_su) - - sus_with_adapter = list() - for su in sus: - # send + recv + merge - for idx in range(len(su.inputs())): - merge_su = su.merge_adapters(idx) - if merge_su: - sus_with_adapter.append(merge_su) - send_adapters, recv_adapters = su.in_adapters(idx) - for send_su, recv_su in zip(send_adapters, recv_adapters): - sus_with_adapter.append(send_su) - sus_with_adapter.append(recv_su) - # excute - sus_with_adapter.append(su) - # select - for idx in range(len(su.outputs())): - select_su = su.select_adapters(idx) - if select_su: - sus_with_adapter.append(select_su) - return sus_with_adapter + trace = SchedulePool().get_tape(loss) + if trace is None: + raise RuntimeError("No forward detected") + bnode = None + loss_idx = None + for node in trace[::-1]: + if loss in node.outputs(): + bnode = node.mirror + loss_idx = node.outputs().index(loss) + node.outputs(loss_idx).grad = loss + bnode.set_grad(loss_idx, loss) + break + for node in trace[::-1]: + SchedulePool().add_node(node.mirror) diff --git a/examples/e2e.py b/examples/e2e.py index 53013bbd..d073a195 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -96,14 +96,13 @@ def train(): model = FeedForward(dim=1024) model = cube.schedule.SemanticModel( model, input_shapes=([batch_size,1024],), - policy_fn=trans_policy ) dataloader = FakeDataLoader(batch_size) - @cube.schedule.schedule(model, dataloader, policy_fn=schedule_policy) + @cube.schedule.schedule(model, dataloader, transform_policy=trans_policy, schedule_policy=schedule_policy) def train_iter(model, dataloader): - for _ in range(4): + for _ in range(1): data = next(dataloader) loss = model(data) loss.backward() diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index 672646b8..fc9d5ab7 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -1,11 +1,10 @@ from cube.graph.tensor import IRFullTensor -from cube.graph.operator import IROperation +from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph +from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.sugraph import SUGraphGener from cube.schedule.translator import IRDataLoader -from cube.schedule.translator import LogicTranslator from cube.schedule.graphpass import SUGraphPass import torch @@ -39,56 +38,41 @@ def construct_graph(): bias4 = IRFullTensor(shape=[1024, 1024], name='bias') # linear1 - linear1 = IROperation( + linear1 = Linear( name='linear1', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [input, weight1, bias1], ) - linear1.set_input(0, input) - linear1.set_input(1, weight1) - linear1.set_input(2, bias1) linear1.infer_shape() # linear2 - linear2 = IROperation( + linear2 = Linear( name='linear2', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear1.outputs(0), weight2, None], ) - linear2.set_input(0, linear1.outputs(0)) - linear2.set_input(1, weight2) linear2.infer_shape() # linear3 - linear3 = IROperation( + linear3 = Linear( name='linear3', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear2.outputs(0), weight3, bias3], ) - linear3.set_input(0, linear2.outputs(0)) - linear3.set_input(1, weight3) - linear3.set_input(2, bias3) linear3.infer_shape() # linear4 - linear4 = IROperation( + linear4 = Linear( name='linear4', signature='torch.nn.functional.linear', - input_length=3, - output_length=1 + inputs= [linear3.outputs(0), weight4, bias4], ) - linear4.set_input(0, linear3.outputs(0)) - linear4.set_input(1, weight4) - linear4.set_input(2, bias4) linear4.infer_shape() graph = IRGraph( nodes=[linear1, linear2, linear3, linear4], input_tensors=[input], - output_tensors=linear3.outputs(), + output_tensors=linear4.outputs(), module_name="Test" ) return graph @@ -105,10 +89,11 @@ def test_model_gen(): output = graph(data) output.backward() - sus = SchedulePool().sus() - sus = LogicTranslator.gen_adapter(sus) + nodes = SchedulePool().nodes() + graph = IRGraph(nodes, None, None, module_name='Test') + + sugraph = SUGraphGener.gen_sugraph(nodes) - sugraph = SUGraph(sus) fsus = [su for su in sugraph.sus() if su.stype == SUType.Forward] dsus = [su for su in sugraph.sus() if su.stype == SUType.Dataloader] for dsu in dsus: @@ -122,10 +107,13 @@ def test_model_gen(): sugraph.assign(su, 1) sugraph.assign(su.mirror, 1) + print('after asignment:\n', sugraph) + sugraph = SUGraphPass.remove_redundant_adapters(sugraph) - sugraph = SUGraphPass.merge_small_sus(sugraph) + print('after remove adapter:\n', sugraph) - print(sugraph) + sugraph = SUGraphPass.merge_small_sus(sugraph) + print('after merge samll SU:\n', sugraph) mgener = ModelCodeGen(sugraph) tgener = ScheduleCodeGen(sugraph) diff --git a/tests/schedule/test_graphpass.py b/tests/schedule/test_graphpass.py index a47ad736..654a281b 100644 --- a/tests/schedule/test_graphpass.py +++ b/tests/schedule/test_graphpass.py @@ -6,8 +6,7 @@ from cube.schedule.graphpass import SUGraphPass from cube.schedule.pool import SchedulePool from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.schedule.translator import LogicTranslator +from cube.schedule.sugraph import SUGraphGener def construct_graph(): @@ -57,22 +56,20 @@ def test_remove_adapter(): SchedulePool().clear() graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data') + data = IRFullTensor(shape=[64,1024], name='data').tosub() output = graph(data) output.backward() - # forward adatpers - sus = SchedulePool().sus() - sus = LogicTranslator.gen_adapter(sus) + nodes = SchedulePool().nodes() + sugraph = SUGraphGener.gen_sugraph(nodes) - sugraph = SUGraph(sus) for su in sugraph.sus(): sugraph.assign(su, 0) sugraph = SUGraphPass.remove_redundant_adapters(sugraph) for su in sugraph.sus(): print(su) for su in sugraph.sus(): - assert su.stype != SUType.Adapter + assert su.stype != SUType.Comm assert len(sugraph.sus()) == 6 @@ -81,18 +78,15 @@ def test_merge_small_sus(): SchedulePool().clear() graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data') + data = IRFullTensor(shape=[64,1024], name='data').tosub() output = graph(data) output.backward() - # forward adatpers - sus = SchedulePool().sus() - sus = LogicTranslator.gen_adapter(sus) - - sugraph = SUGraph(sus) + nodes = SchedulePool().nodes() + sugraph = SUGraphGener.gen_sugraph(nodes) for su in sugraph.sus(): - if su.stype != SUType.Adapter: + if su.stype != SUType.Comm: sugraph.assign(su, 0) print('orignal:') @@ -103,7 +97,5 @@ def test_merge_small_sus(): print('merged:') print(sugraph) - assert len(sugraph.sus()) == 4 - assert sugraph.sus(0).stype == SUType.Forward - assert sugraph.sus(3).stype == SUType.Backward - assert sugraph.sus(0).mirror == sugraph.sus(3) + assert len(sugraph.sus()) == 12 + assert all([su.stype == SUType.Forward for su in sugraph.fsus()]) diff --git a/tests/schedule/test_pool.py b/tests/schedule/test_pool.py index d375f209..d98813bc 100644 --- a/tests/schedule/test_pool.py +++ b/tests/schedule/test_pool.py @@ -7,18 +7,16 @@ def test_schedule_pool(): SchedulePool().clear() - assert len(SchedulePool()._sus) == 0 - assert len(SchedulePool().sus()) == 0 + assert len(SchedulePool()._nodes) == 0 + assert len(SchedulePool().nodes()) == 0 cell = IRCell( name='test', signature='test', input_length=4, output_length=2 ) - su = ScheduleUnit([cell], SUType.Forward, name='su') - SchedulePool().add_su(su) + SchedulePool().add_node(cell) - assert len(SchedulePool()._sus) == 1 - assert len(SchedulePool().sus()) == 1 - - for record_su in SchedulePool().sus(): - assert record_su == su + assert len(SchedulePool()._nodes) == 1 + assert len(SchedulePool().nodes()) == 1 + for record_node in SchedulePool().nodes(): + assert record_node == cell diff --git a/tests/schedule/test_su.py b/tests/schedule/test_su.py index 8964e0f5..50794915 100644 --- a/tests/schedule/test_su.py +++ b/tests/schedule/test_su.py @@ -51,27 +51,29 @@ def test_su_init(): linear1, linear2, linear3 = nodes su1 = ScheduleUnit([linear1], stype=SUType.Forward) - assert len(su1.inputs()) == 1 + assert len(su1.inputs()) == 3 assert len(su1.outputs()) == 1 assert su1.signature == SUType.Forward.value assert su1.mirror is None assert su1.stype == SUType.Forward assert su1._nodes == [linear1] - assert len(su1._send_in_adapters) == 1 - assert len(su1._recv_in_adapters) == 1 + assert len(su1._send_in_adapters) == 3 + assert len(su1._recv_in_adapters) == 3 assert len(su1._send_out_adapters) == 1 assert len(su1._recv_out_adapters) == 1 assert len(su1._ctrl_predecessors) == 0 assert len(su1._ctrl_successors) == 0 su2 = ScheduleUnit([linear1, linear2], stype=SUType.Forward) - assert len(su2.inputs()) == 1 + print('su2:', su2) + assert len(su2.inputs()) == 4 assert len(su2.outputs()) == 1 assert su2.signature == SUType.Forward.value su3 = ScheduleUnit([linear1, linear2, linear3], stype=SUType.Forward) - assert len(su3.inputs()) == 1 + print('su3:', su3) + assert len(su3.inputs()) == 6 assert len(su3.outputs()) == 1 assert su3.signature == SUType.Forward.value @@ -83,16 +85,8 @@ def test_su_copy(): linear1, linear2, linear3 = nodes su1 = ScheduleUnit([linear1, linear2], stype=SUType.Forward) - assert len(su1.inputs()) == 1 - assert len(su1.outputs()) == 1 - assert su1.signature == SUType.Forward.value - su2 = ScheduleUnit([linear1, linear2, linear3], stype=SUType.Forward) - assert len(su2.inputs()) == 1 - assert len(su2.outputs()) == 1 - assert su2.signature == SUType.Forward.value - - su1.set_mirror(su2) + su1.mirror = su2 csu = copy.copy(su1) assert csu.inputs() == su1.inputs() diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index 1d3602be..803183e8 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -55,6 +55,7 @@ def test_graph_init(): sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] sugraph = SUGraph(sus) + print(sugraph) assert len(sugraph.inputs()) == 1 assert len(sugraph.outputs()) == 1 assert graph.inputs() == sugraph.inputs() @@ -95,13 +96,17 @@ def test_sugraph_merge(): sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] sugraph = SUGraph(sus) - su1, su2, su3 = sugraph.sus() + su1, su2, su3 = sugraph.fsus() assert sugraph.merge(su1, su3) is None - + + print('origin: ') + print(sugraph) su12 = sugraph.merge(su1, su2) - assert sugraph.nnodes == 2 - assert len(su12.inputs()) == 1 + print('merged: ') + print(sugraph) + assert sugraph.nnodes == 4 + assert len(su12.inputs()) == 4 assert len(su12.outputs()) == 1 assert len(su12.nodes()) == 2 assert su12 in sugraph.sus() @@ -128,7 +133,7 @@ def test_sugraph_add_flow(): assert su3 in su1.successors() -def test_sugraph_assign(): +def test_sugraph_assign1(): graph = construct_graph() sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] @@ -145,8 +150,8 @@ def test_sugraph_assign(): recv_ranks = [-1] ) send_op.pair(recv_op) - send_su12 = ScheduleUnit([send_op], SUType.Adapter, name='send') - recv_su12 = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + send_su12 = ScheduleUnit([send_op], SUType.Comm, name='send') + recv_su12 = ScheduleUnit([recv_op], SUType.Comm, name='recv') su1._add_out_adapter(0, send_su12, recv_su12) su2._add_in_adapter(0, send_su12, recv_su12) @@ -160,8 +165,8 @@ def test_sugraph_assign(): recv_ranks = [-1] ) send_op.pair(recv_op) - send_su23 = ScheduleUnit([send_op], SUType.Adapter, name='send') - recv_su23 = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + send_su23 = ScheduleUnit([send_op], SUType.Comm, name='send') + recv_su23 = ScheduleUnit([recv_op], SUType.Comm, name='recv') su2._add_out_adapter(0, send_su23, recv_su23) su3._add_in_adapter(0, send_su23, recv_su23) @@ -193,7 +198,7 @@ def test_sugraph_assign(): assert not sugraph.assign(send_su12, 3) -def test_sugraph_assign(): +def test_sugraph_assign2(): graph = construct_graph() sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] @@ -210,8 +215,8 @@ def test_sugraph_assign(): recv_ranks = [-1] ) send_op.pair(recv_op) - send_su12 = ScheduleUnit([send_op], SUType.Adapter, name='send') - recv_su12 = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + send_su12 = ScheduleUnit([send_op], SUType.Comm, name='send') + recv_su12 = ScheduleUnit([recv_op], SUType.Comm, name='recv') su1._add_out_adapter(0, send_su12, recv_su12) su2._add_in_adapter(0, send_su12, recv_su12) @@ -225,8 +230,8 @@ def test_sugraph_assign(): recv_ranks = [-1] ) send_op.pair(recv_op) - send_su23 = ScheduleUnit([send_op], SUType.Adapter, name='send') - recv_su23 = ScheduleUnit([recv_op], SUType.Adapter, name='recv') + send_su23 = ScheduleUnit([send_op], SUType.Comm, name='send') + recv_su23 = ScheduleUnit([recv_op], SUType.Comm, name='recv') su2._add_out_adapter(0, send_su23, recv_su23) su3._add_in_adapter(0, send_su23, recv_su23) diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py index a0ccb3d8..a2e53d18 100644 --- a/tests/schedule/test_translator.py +++ b/tests/schedule/test_translator.py @@ -1,6 +1,7 @@ import torch +from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.schedule.translator import LogicTranslator +from cube.schedule.translator import LogicTranslator, SUGraphGener from cube.schedule.translator import IRDataLoader from cube.schedule.su import SUType from cube.schedule.pool import SchedulePool @@ -77,14 +78,14 @@ def test_load_dataloader(): assert data1.shape == [64, 1024] data2 = next(dataloader) - assert len(SchedulePool().sus()) == 2 - assert all([su.stype == SUType.Dataloader for su in SchedulePool().sus()]) + assert len(SchedulePool().nodes()) == 2 + assert all([isinstance(node, IRDataOperation) for node in SchedulePool().nodes()]) data3 = LogicTranslator.load_data(dataloader) assert isinstance(data1, IRSubTensor) assert data1.shape == [64, 1024] - assert len(SchedulePool().sus()) == 3 - assert all([su.stype == SUType.Dataloader for su in SchedulePool().sus()]) + assert len(SchedulePool().nodes()) == 3 + assert all([isinstance(node, IRDataOperation) for node in SchedulePool().nodes()]) def test_translator_forward(): @@ -99,12 +100,12 @@ def test_translator_forward(): assert output.shape == [64, 1024] assert output.trace is not None - sus = SchedulePool().sus() - assert len(sus) == 3 - assert output.trace == sus - for su in sus: - assert su.stype == SUType.Forward - assert su.mirror is not None + nodes = SchedulePool().nodes() + assert len(nodes) == 3 + assert isinstance(SchedulePool().get_tape(output), list) + for node in nodes: + assert isinstance(node, IRFwOperation) + assert isinstance(node.mirror, IRBpOperation) def test_translator_backward(): @@ -113,20 +114,18 @@ def test_translator_backward(): graph = construct_graph() data = IRFullTensor(shape=[64,1024], name='data').tosub() output = graph(data) - output.backward() - sus = SchedulePool().sus() - assert len(sus) == 6 - fsus = sus[0:3] - bsus = sus[3:] - for fsu, bsu in zip(fsus, bsus[::-1]): + nodes = SchedulePool().nodes() + assert len(nodes) == 6 + fnodes = nodes[0:3] + bnodes = nodes[3:] + for fsu, bsu in zip(fnodes, bnodes[::-1]): assert fsu.mirror == bsu assert bsu.mirror == fsu - assert bsu.stype == SUType.Backward -def test_translator_gen_adapter(): +def test_sugraph_gener_gen(): SchedulePool().clear() graph = construct_graph() @@ -134,24 +133,24 @@ def test_translator_gen_adapter(): output = graph(data) # forward adatpers - sus = SchedulePool().sus() - sus = LogicTranslator.gen_adapter(sus) - assert len(sus) == 7 - su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3 = sus + nodes = SchedulePool().nodes() + sugraph = SUGraphGener.gen_sugraph(nodes) + assert len(sugraph.sus()) == 7 + su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3 = sugraph.sus() assert su1.stype == SUType.Forward assert su2.stype == SUType.Forward assert su3.stype == SUType.Forward - assert send_su12.stype == SUType.Adapter - assert recv_su12.stype == SUType.Adapter - assert send_su23.stype == SUType.Adapter - assert recv_su23.stype == SUType.Adapter + assert send_su12.stype == SUType.Comm + assert recv_su12.stype == SUType.Comm + assert send_su23.stype == SUType.Comm + assert recv_su23.stype == SUType.Comm # backward adapters output.backward() - sus = SchedulePool().sus() - sus = LogicTranslator.gen_adapter(sus) - for su in sus: + nodes = SchedulePool().nodes() + sugraph = SUGraphGener.gen_sugraph(nodes) + for su in sugraph.sus(): print(su) # note loss will be the input to autograd, therefore # have additional adapters - assert len(sus) == 16 + assert len(sugraph.sus()) == 18 diff --git a/tests/schedule/test_worflow.py b/tests/schedule/test_worflow.py index b6e6e804..ee5fa866 100644 --- a/tests/schedule/test_worflow.py +++ b/tests/schedule/test_worflow.py @@ -3,10 +3,10 @@ import cube from cube.graph.graph import IRGraph -from cube.schedule.su import SUType +from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.pool import SchedulePool from cube.schedule.sugraph import SUGraph -from cube.schedule.translator import LogicTranslator, IRDataLoader +from cube.schedule.translator import LogicTranslator, IRDataLoader, SUGraphGener class MLP(nn.Module): @@ -37,7 +37,7 @@ def __next__(self): self.pos += 1 if self.pos == self.length: raise StopIteration - return torch.randn(self.shape).cuda() + return torch.randn(self.shape) def test_semantic_model(): @@ -91,13 +91,13 @@ def train_iter(model, dataloader): train_iter(model, dataloader) - sus = SchedulePool().sus() - sus_with_adapter = LogicTranslator.gen_adapter(sus) - sugraph = SUGraph(sus_with_adapter) + nodes = SchedulePool().nodes() + graph = IRGraph(nodes, None, None, 'testmodel') + print(graph) - sugraph = policy(sugraph, None) + sugraph = SUGraphGener.gen_sugraph(nodes) - for su in sugraph.sus(): - print(su) + sugraph = policy(sugraph, None) + print(sugraph) - assert len(sugraph.sus()) == 1 + 2 * (4 * 3) + assert len(sugraph.sus()) == 33 From 12f4e6306c772a872e9a75bb32f0fb2a1226b47c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Nov 2021 16:18:40 +0800 Subject: [PATCH 0259/1892] fix multi forward bug --- cube/graph/gpass.py | 26 ++++++++++++++------------ tests/graph/test_graph.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index a8e7edb9..ec6115af 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -8,7 +8,7 @@ from cube.ir.cten import IRTensor -__all__ = ['forward', 'backward'] +__all__ = ['forward'] class _TensorGener: @@ -67,6 +67,8 @@ def forward(graph, *args) -> IRGraph: # forwrd node fnode = copy.copy(node) + fnode._inputs = inputs + fnode._outputs = outputs # set forward inputs for idx, val in enumerate(inputs): fnode.set_input(idx, gener.renew(val)) @@ -81,18 +83,20 @@ def forward(graph, *args) -> IRGraph: # set input bnode.set_data(idx, val) # set gradient output - val = val if isinstance(val, IRTensor) else None - grad = gener.renew(val, keep_param=False) - grad = grad.as_grad() if isinstance(grad, IRTensor) else grad - if isinstance(val, IRTensor) and val.requires_grad: - val.grad = grad + grad = None + if isinstance(val, IRTensor): + # TODO: requires_grad = False should be set to None + grad = gener.renew(val, keep_param=False).as_grad() + val.add_grad(grad) bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): # set gradient input - grad = gener.renew(val, keep_param=False) - grad = grad.as_grad() if isinstance(grad, IRTensor) else grad - if isinstance(val, IRTensor) and val.requires_grad: - val.grad = grad + grad = None + if isinstance(val, IRTensor): + # TODO: requires_grad = False should be set to None + grad = gener.renew(val, keep_param=False).as_grad() + # TODO: this grad should be partitioned in value dimension + val.add_grad(grad) bnode.set_grad(idx, grad) fnode.device = node.device @@ -109,7 +113,5 @@ def forward(graph, *args) -> IRGraph: outputs = [gener.renew(output) for output in graph.outputs()] fgraph = IRGraph(fnodes, inputs, outputs, graph.name) - for output in fgraph.outputs(): - output.set_trace(fgraph.nodes()) return fgraph diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 671a8038..66f7c6b8 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -137,6 +137,24 @@ def test_graph_forward(): assert not bnode1 in bnode3.predecessors() +def test_graph_multi_forward(): + inputs, ops, outputs = construct_model() + graph = IRGraph(ops, inputs, outputs, 'MLP') + + def _gen_data(graph): + data = list() + for input in graph.inputs(): + data.append(input.parent.like().tosub()) + return data + fgraph1 = gpass.forward(graph, *_gen_data(graph)) + fgraph2 = gpass.forward(graph, *_gen_data(graph)) + print(fgraph1) + print(fgraph2) + assert fgraph1.inputs != fgraph2.inputs() + for node1, node2 in zip(fgraph1.nodes(), fgraph2.nodes()): + assert node1.inputs() != node2.inputs() + + def test_graph_partition(): inputs, ops, outputs = construct_model() From 3e1f711717bfeda62a1190bd7a4a4a6ca075b849 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Nov 2021 16:36:40 +0800 Subject: [PATCH 0260/1892] add gradient multi version --- cube/codegen/codegen.py | 28 +++++++++++++----------- cube/ir/cten.py | 27 +++++++++++++---------- cube/schedule/sugraph.py | 41 +++++++++++++++++++++++++++++------ cube/schedule/translator.py | 3 ++- examples/e2e.py | 36 +++++++++++++++--------------- tests/codegen/test_codegen.py | 41 +++++++++++++++++++++++++---------- tests/graph/test_tensor.py | 4 ++-- 7 files changed, 117 insertions(+), 63 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index da27b77d..11d9c2b8 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -324,24 +324,26 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: finputs.append(self.naming(tensor, fsu)) fargs = '(' + ', '.join(finputs + ['']) + ')' - foutputs = list() + fouts = list() for tensor in fsu.outputs(): - foutputs.append(self.naming(tensor, fsu)) - foutputs = '(' + ', '.join(foutputs + ['']) + ')' - - in_grads = list() - for tensor in fsu.outputs(): - grad = tensor.grad + fouts.append(self.naming(tensor, fsu)) + fouts = '(' + ', '.join(fouts + ['']) + ')' + + fout_grads = list() + for fout in fsu.outputs(): + grad = None + for fout_grad in fout.grads: + if fout_grad in su.inputs(): + grad = fout_grad if grad in fsu.outputs(): - in_grads.append('None') - else: - in_grads.append(self.naming(grad, su)) - in_grads = '(' + ', '.join(in_grads + ['']) + ')' + grad = None + fout_grads.append(self.naming(grad, su)) + fout_grads = '(' + ', '.join(fout_grads + ['']) + ')' body = bsign.format( input_tensors = fargs, - output_tensors = foutputs, - output_grads = in_grads + output_tensors = fouts, + output_grads = fout_grads ) # returned value are graph.outputs diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 9a9b5662..7b235d64 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -344,7 +344,7 @@ class IRTensor: IRTensor serves as IRGraph edge """ - _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad'] + _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grads'] def __init__(self, shape=None, name=None): @@ -359,9 +359,7 @@ def __init__(self, shape=None, name=None): self._is_grad = False self._requires_grad = True - self._grad = None - - self.trace = None + self._grads = list() def attach_cell(self, cell: IRCell): """ @@ -401,7 +399,7 @@ def requires_grad(self, requires: bool): raise TypeError("Expected bool") self._requires_grad = requires if not requires: - self.grad = None + self._grads = list() def as_param(self): """ @@ -419,15 +417,22 @@ def is_param(self): return self._is_param @property - def grad(self): - return self._grad + def grads(self) -> List: + return self._grads + + @grads.setter + def grads(self, grads: List): + if grads is None: + grads = list() + if not all([isinstance(grad, IRTensor) for grad in grads]): + raise TypeError("grad can only be None or List[Tensor]") + self._grads = grads + self.requires_grad = True - @grad.setter - def grad(self, grad): + def add_grad(self, grad): if grad is not None and not isinstance(grad, IRTensor): raise TypeError("grad can only be None or Tensor") - self._grad = grad - self.requires_grad = True + self._grads += [grad] def as_grad(self): self._is_param = False diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 21d6868c..7a9bbb11 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -175,7 +175,8 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: start, stop = min(idx1, idx2), max(idx1, idx2) inter_sus = fsus[start+1:stop] for su in inter_sus: - if self.happen_before(su, su2): + # FIXME: currently only allow other device su exists + if self.happen_before(su1, su) or self.happen_before(su, su2): return None for idx in range(len(su2.inputs())): prev_sus = su2.predecessors(idx) @@ -198,16 +199,39 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: ) for idx, input in enumerate(fsu.inputs()): bnode.set_data(idx, input) - fout_grads = [out.grad for out in fsu.outputs()] + + # FIXME: fail case: forward -> forward -> backward -> backward + fout_grads = list() + for fout in fsu.outputs(): + for grad in fout.grads: + if grad in su1.mirror.inputs() + su2.mirror.inputs(): + fout_grads.append(grad) + break + else: + raise RuntimeError("Cannot fout find gradient") for idx, fout_grad in enumerate(fout_grads): bnode.set_grad(idx, fout_grad) - for idx, fin in enumerate(fsu.inputs()): + + fin_grads = list() + for fin in fsu.inputs(): if isinstance(fin, IRTensor): - bnode.set_output(idx, fin.grad) + for grad in fin.grads: + if grad in su1.mirror.outputs() + su2.mirror.outputs(): + fin_grads.append(grad) + break + else: + print(f'msu = {fsu}') + print(f'fin = {fin}') + print(f'fin grads = {fin.grads}') + print(f'fsu1 = {su1}') + print(f'fsu2 = {su2}') + print(f'bsu1 = {su1.mirror}') + print(f'bsu2 = {su2.mirror}') + raise RuntimeError("Cannot find fin gradient") else: - bnode.set_output(idx, None) - for output in fsu.outputs(): - print(output.grad) + fin_grads.append(None) + for idx, fin_grad in enumerate(fin_grads): + bnode.set_output(idx, fin_grad) bsu = ScheduleUnit([bnode], stype=SUType.Backward, name='bsu') bsu.device = su2.mirror.device fsu.mirror = bsu @@ -225,6 +249,9 @@ def _set_adapters(su1, su2, msu): adapters = su2.in_adapters(su2_idx) merge_adapter = su2.merge_adapters(su2_idx) else: + print(f'> Error: msu: {msu}') + print(f'> Error: su1: {su1}') + print(f'> Error: su2: {su2}') raise RuntimeError("Internal Error: not found input SU") msu._add_in_adapter(idx, *adapters) msu._set_merge_adapter(idx, merge_adapter) diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index b7e4be3a..4430ffde 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -84,7 +84,8 @@ def backward(loss: IRTensor): if loss in node.outputs(): bnode = node.mirror loss_idx = node.outputs().index(loss) - node.outputs(loss_idx).grad = loss + # TODO: fix why cannot use loss.add_grad + node.outputs(loss_idx).add_grad(loss) bnode.set_grad(loss_idx, loss) break for node in trace[::-1]: diff --git a/examples/e2e.py b/examples/e2e.py index d073a195..84e982fa 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -16,38 +16,40 @@ import cube from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph def trans_policy(ir_graph, resource): return ir_graph -def schedule_policy(sugraph, resource): +def schedule_policy(sugraph: SUGraph, resource): # put to micro-batch forward-backward sequence fb_op_seqs = list() - for su in sugraph.sus(): + for fsu in sugraph.fsus(): for fb_seq in fb_op_seqs: for ksu in fb_seq[::-1]: - if sugraph.happen_before(ksu, su): - fb_seq.append(su) + if sugraph.happen_before(ksu, fsu): + fb_seq.append(fsu) break else: continue break else: - fb_op_seqs.append([su]) + fb_op_seqs.append([fsu]) + + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + print(f'> collect {len(fb_op_seqs)} forward-backward sequence') for fb_sus in fb_op_seqs: - sugraph.assign(fb_sus[0], 0) - idx = 0 - for su in fb_sus[1:]: - if su.stype == SUType.Forward: - if idx < 3: - sugraph.assign(su, 0) - sugraph.assign(su.mirror, 0) - else: - sugraph.assign(su, 1) - sugraph.assign(su.mirror, 1) - idx += 1 + for idx, su in enumerate(fb_sus): + if idx < 3: + sugraph.assign(su, 0) + sugraph.assign(su.mirror, 0) + else: + sugraph.assign(su, 1) + sugraph.assign(su.mirror, 1) return sugraph @@ -102,7 +104,7 @@ def train(): @cube.schedule.schedule(model, dataloader, transform_policy=trans_policy, schedule_policy=schedule_policy) def train_iter(model, dataloader): - for _ in range(1): + for _ in range(4): data = next(dataloader) loss = model(data) loss.backward() diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index fc9d5ab7..00fa8947 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -82,24 +82,41 @@ def test_model_gen(): SchedulePool().clear() + grad_accum = 2 + graph = construct_graph() dataloader = IRDataLoader(FakeDataLoader(batch_size=64)) - data = next(dataloader) - output = graph(data) - output.backward() + for _ in range(grad_accum): + data = next(dataloader) + output = graph(data) + output.backward() nodes = SchedulePool().nodes() graph = IRGraph(nodes, None, None, module_name='Test') sugraph = SUGraphGener.gen_sugraph(nodes) - fsus = [su for su in sugraph.sus() if su.stype == SUType.Forward] - dsus = [su for su in sugraph.sus() if su.stype == SUType.Dataloader] - for dsu in dsus: - sugraph.assign(dsu, 0) - for idx, su in enumerate(fsus): - if su.stype == SUType.Forward: + fb_seqs = list() + for fsu in sugraph.fsus(): + for fb_seq in fb_seqs: + for ksu in fb_seq[::-1]: + if sugraph.happen_before(ksu, fsu): + fb_seq.append(fsu) + break + else: + continue + break + else: + fb_seqs.append([fsu]) + assert len(fb_seqs) == grad_accum + + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + + for fb_seq in fb_seqs: + for idx, su in enumerate(fb_seq): if idx < 2: sugraph.assign(su, 0) sugraph.assign(su.mirror, 0) @@ -107,13 +124,13 @@ def test_model_gen(): sugraph.assign(su, 1) sugraph.assign(su.mirror, 1) - print('after asignment:\n', sugraph) + print('========= after asignment: ==========\n', sugraph) sugraph = SUGraphPass.remove_redundant_adapters(sugraph) - print('after remove adapter:\n', sugraph) + print('========= after remove adapter: ==========\n', sugraph) sugraph = SUGraphPass.merge_small_sus(sugraph) - print('after merge samll SU:\n', sugraph) + print('========= after merge small SU: ==========\n', sugraph) mgener = ModelCodeGen(sugraph) tgener = ScheduleCodeGen(sugraph) diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py index a0de3bd8..13f212b6 100644 --- a/tests/graph/test_tensor.py +++ b/tests/graph/test_tensor.py @@ -225,7 +225,7 @@ def test_sub_tensor_copy(): val_map = (0, 4), shape = (1024, 512) ) - sub_tensor1.grad = sub_tensor2 + sub_tensor1.grads = [sub_tensor2] cpy_tensor = copy.copy(sub_tensor1) - assert cpy_tensor.grad == sub_tensor2 + assert cpy_tensor.grads[0] == sub_tensor2 From 7b4c4fd2fca9d4107f4ac8a790824cafdd2591b7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Nov 2021 17:31:04 +0800 Subject: [PATCH 0261/1892] fix su merge bugs --- cube/schedule/graphpass.py | 17 ++++++----------- cube/schedule/sugraph.py | 2 +- tests/algorithm/test_generics.py | 2 +- tests/schedule/test_graphpass.py | 3 +-- tests/schedule/test_sugraph.py | 2 +- tests/schedule/test_translator.py | 6 +++--- .../{test_worflow.py => test_workflow.py} | 6 +++--- 7 files changed, 16 insertions(+), 22 deletions(-) rename tests/schedule/{test_worflow.py => test_workflow.py} (93%) diff --git a/cube/schedule/graphpass.py b/cube/schedule/graphpass.py index 23e4b89e..1149e248 100644 --- a/cube/schedule/graphpass.py +++ b/cube/schedule/graphpass.py @@ -44,15 +44,10 @@ def merge_small_sus(sugraph: SUGraph) -> SUGraph: for su in sugraph.sus(): devices.update(set(su.device)) for device in devices: - dev_sus = [su for su in sugraph.sus() if device in su.device] - merged_su = None - for su in dev_sus: - if su.stype == SUType.Forward: - if not isinstance(merged_su, ScheduleUnit): - merged_su = su - continue - merged_su = sugraph.merge(merged_su, su) - if not isinstance(merged_su, ScheduleUnit): - merged_su = su - + dev_sus = [su for su in sugraph.fsus() if device in su.device] + merged_su = dev_sus[0] + for su in dev_sus[1:]: + merged_su = sugraph.merge(merged_su, su) + if merged_su is None: + merged_su = su return sugraph diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 7a9bbb11..656926a5 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -272,7 +272,7 @@ def _set_adapters(su1, su2, msu): for idx, input in enumerate(su2.inputs()): if input not in msu.inputs(): sadapters, radapters = su2.in_adapters(idx) - for adapter in [sadapters + radapters]: + for adapter in sadapters + radapters: if adapter in self.sequence: self.sequence.remove(adapter) diff --git a/tests/algorithm/test_generics.py b/tests/algorithm/test_generics.py index 303b5eb6..088729d2 100644 --- a/tests/algorithm/test_generics.py +++ b/tests/algorithm/test_generics.py @@ -13,7 +13,7 @@ def test_generic_algo_init(): cell.outputs(0).shape = [1024, 1000] algo = GenericDistAlgo(cell) - assert algo.logic_op is cell + assert algo.logic_op is IRCell assert len(algo.input_shapes) == 3 assert algo.input_shapes[0] == [1024, 1024] assert algo.input_shapes[1] == [1024, 1000] diff --git a/tests/schedule/test_graphpass.py b/tests/schedule/test_graphpass.py index 654a281b..4c1798bc 100644 --- a/tests/schedule/test_graphpass.py +++ b/tests/schedule/test_graphpass.py @@ -97,5 +97,4 @@ def test_merge_small_sus(): print('merged:') print(sugraph) - assert len(sugraph.sus()) == 12 - assert all([su.stype == SUType.Forward for su in sugraph.fsus()]) + assert len(sugraph.sus()) == 2 diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index 803183e8..012efceb 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -105,7 +105,7 @@ def test_sugraph_merge(): su12 = sugraph.merge(su1, su2) print('merged: ') print(sugraph) - assert sugraph.nnodes == 4 + assert sugraph.nnodes == 2 assert len(su12.inputs()) == 4 assert len(su12.outputs()) == 1 assert len(su12.nodes()) == 2 diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py index a2e53d18..0ec8df8e 100644 --- a/tests/schedule/test_translator.py +++ b/tests/schedule/test_translator.py @@ -1,9 +1,10 @@ import torch from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.schedule.translator import LogicTranslator, SUGraphGener +from cube.schedule.translator import LogicTranslator from cube.schedule.translator import IRDataLoader from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph, SUGraphGener from cube.schedule.pool import SchedulePool from cube.graph.tensor import IRFullTensor, IRSubTensor @@ -98,7 +99,6 @@ def test_translator_forward(): assert isinstance(output, IRSubTensor) assert output.shape == [64, 1024] - assert output.trace is not None nodes = SchedulePool().nodes() assert len(nodes) == 3 @@ -153,4 +153,4 @@ def test_sugraph_gener_gen(): print(su) # note loss will be the input to autograd, therefore # have additional adapters - assert len(sugraph.sus()) == 18 + assert len(sugraph.sus()) == 14 diff --git a/tests/schedule/test_worflow.py b/tests/schedule/test_workflow.py similarity index 93% rename from tests/schedule/test_worflow.py rename to tests/schedule/test_workflow.py index ee5fa866..59614f5d 100644 --- a/tests/schedule/test_worflow.py +++ b/tests/schedule/test_workflow.py @@ -3,10 +3,10 @@ import cube from cube.graph.graph import IRGraph -from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.su import SUType from cube.schedule.pool import SchedulePool -from cube.schedule.sugraph import SUGraph -from cube.schedule.translator import LogicTranslator, IRDataLoader, SUGraphGener +from cube.schedule.sugraph import SUGraphGener +from cube.schedule.translator import IRDataLoader class MLP(nn.Module): From cfb3c57081a3bd2f7653de7fe1568c2ca92f5b32 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 Nov 2021 17:24:37 +0800 Subject: [PATCH 0262/1892] gradient infer and split --- cube/codegen/codegen.py | 8 +-- cube/graph/gpass.py | 26 ++++++--- cube/graph/graph.py | 55 +++++++++++++++++- cube/graph/operator/operator.py | 18 +++++- cube/graph/tensor.py | 95 +++++++++++++++++++++++++++---- cube/ir/cten.py | 37 ++++-------- cube/schedule/sugraph.py | 45 +++------------ cube/schedule/translator.py | 3 +- examples/e2e.py | 11 ++-- tests/codegen/test_codegen.py | 2 +- tests/schedule/test_translator.py | 13 +++-- 11 files changed, 209 insertions(+), 104 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 11d9c2b8..3ad0c34e 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -331,13 +331,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: fout_grads = list() for fout in fsu.outputs(): - grad = None - for fout_grad in fout.grads: - if fout_grad in su.inputs(): - grad = fout_grad - if grad in fsu.outputs(): - grad = None - fout_grads.append(self.naming(grad, su)) + fout_grads.append(self.naming(fout.grad, fsu)) fout_grads = '(' + ', '.join(fout_grads + ['']) + ')' body = bsign.format( diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index ec6115af..155948ec 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -66,7 +66,8 @@ def forward(graph, *args) -> IRGraph: outputs = node.outputs() # forwrd node - fnode = copy.copy(node) + # fnode = copy.copy(node) + fnode = node fnode._inputs = inputs fnode._outputs = outputs # set forward inputs @@ -84,19 +85,22 @@ def forward(graph, *args) -> IRGraph: bnode.set_data(idx, val) # set gradient output grad = None - if isinstance(val, IRTensor): + if isinstance(val, IRSubTensor): # TODO: requires_grad = False should be set to None - grad = gener.renew(val, keep_param=False).as_grad() - val.add_grad(grad) + # grad = gener.renew(val, keep_param=False).as_grad() + grad = val.get_grad(fnode) + val.grad = grad bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): # set gradient input grad = None - if isinstance(val, IRTensor): + if isinstance(val, IRSubTensor): # TODO: requires_grad = False should be set to None - grad = gener.renew(val, keep_param=False).as_grad() + grad = val.get_grad(fnode) + val.grad = grad + # grad = gener.renew(val, keep_param=False).as_grad() # TODO: this grad should be partitioned in value dimension - val.add_grad(grad) + # val.add_grad(grad) bnode.set_grad(idx, grad) fnode.device = node.device @@ -112,6 +116,10 @@ def forward(graph, *args) -> IRGraph: inputs = [gener.renew(input) for input in graph.inputs()] outputs = [gener.renew(output) for output in graph.outputs()] - fgraph = IRGraph(fnodes, inputs, outputs, graph.name) - return fgraph + for idx, input in enumerate(inputs): + graph.set_input(idx, input) + for idx, output in enumerate(outputs): + graph.set_output(idx, output) + # fgraph = IRGraph(fnodes, inputs, outputs, graph.name) + return graph diff --git a/cube/graph/graph.py b/cube/graph/graph.py index c2d2e8c9..1b4cde5e 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -9,6 +9,7 @@ from typing import Union, Tuple, List, Optional, Any, Dict import copy +from cube.graph.operator.operator import IRBpOperation from cube.ir.cten import IRTensor, IRCell from cube.graph.tensor import IRFullTensor @@ -246,7 +247,8 @@ def subgraph(self, sub_nodes: List[IRCell]): def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: """ Policy primitive. Partition an operator by using - op_partition_algorithm and its configuration + op_partition_algorithm and its configuration. Note the + backward op-partition will be automatically done. Args: op: cell to be partitioned @@ -267,6 +269,31 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional if not algo.satisfy(config): return None nodes = algo.instantiate(op, config) + # set backward mirror node + if op.mirror is not None: + # generate mirror node + for fnode in nodes: + bnode = IRBpOperation( + data_num=len(fnode.inputs()), + grad_num=len(fnode.outputs()) + ) + for idx, val in enumerate(fnode.inputs()): + bnode.set_data(idx, val) + grad = None + if isinstance(val, IRTensor): + # this is wrong + grad = val.grads[-1] + bnode.set_output(idx, grad) + for idx, val in enumerate(fnode.outputs()): + grad = None + if isinstance(val, IRTensor): + # this is wrong + grad = val.grads[-1] + bnode.set_grad(idx, grad) + fnode.mirror = bnode + fnode.device = op.device + bnode.mirror = fnode + bnode.device = op.mirror.device idx = self._nodes.index(op) self._nodes = self._nodes[:idx] + nodes + self._nodes[idx+1:] self.reset_dependency() @@ -282,7 +309,29 @@ def identity(self, input_tensor, dst_op): def __repr__(self): dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs - dscp += f'Inputs: {self._inputs}\n' + inputs = list() + for tensor in self.inputs(): + if isinstance(tensor, IRTensor): + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + inputs.append(f'{anno}{tensor._id}') + else: + inputs.append(tensor) + outputs = list() + for tensor in self.outputs(): + if isinstance(tensor, IRTensor): + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + outputs.append(f'{anno}{tensor._id}') + else: + outputs.append(tensor) + dscp += f"Inputs: {inputs}\n" # nodes for node in self._nodes: succ_node_ids = [None] * len(node.outputs()) @@ -291,5 +340,5 @@ def __repr__(self): succ_node_ids[out_idx] = node_list dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" # outputs - dscp += f"\nOutputs: {self._outputs}\n{'=' * len(self.name)}\n" + dscp += f"\nOutputs: {outputs}\n{'=' * len(self.name)}\n" return dscp diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 70b0b8e1..e47030de 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,7 +1,7 @@ from typing import Any, Optional, Union, List from cube.ir.cten import IRTensor, IRCell -from cube.graph.tensor import IRFullTensor +from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory @@ -57,6 +57,15 @@ def algorithms(self, tag: Optional[str] = None): template = factory.algorithms(type(self), tag) return template(self) + def set_input(self, input_index: int, val: Any): + old_val = self.inputs(input_index) + # remove the old one + if isinstance(old_val, IRSubTensor): + old_val.parent._rm_fdst_cell(self) + if isinstance(val, IRSubTensor): + val.parent._add_fdst_cell(self) + return super().set_input(input_index, val) + def __repr__(self): inputs = list() for tensor in self.inputs(): @@ -177,7 +186,12 @@ def __repr__(self): outputs = list() for tensor in self.outputs(): if isinstance(tensor, IRTensor): - outputs.append(f't{tensor._id}') + anno = 't' + if tensor.is_param(): + anno = 'w' + if tensor.is_grad(): + anno = 'g' + outputs.append(f'{anno}{tensor._id}') else: outputs.append(tensor) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index bde4fb05..9cbed74a 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -1,7 +1,20 @@ +""" +SubTensor Gradient rule: + +1). for input tensors, gradient SubTensor is: + indices = input.indices; + val is splitted by referencing times on the indices + +2). for output tensors, gradient SubTensor is: + indices = output.indices; + val follows same value splitting rules with output +""" + + from typing import List, Optional, Union, Tuple import copy -from cube.ir.cten import IRTensor +from cube.ir.cten import IRCell, IRTensor __all__ = ['IndexMap', 'ValueMap', 'IRFullTensor', 'IRSubTensor'] @@ -256,7 +269,7 @@ def _to_value_map(val_map: Union[Tuple, ValueMap, None]): class IRFullTensor(IRTensor): - def __init__(self, shape=None, name=None): + def __init__(self, shape=None, name=None, requires_grad=True): super().__init__(shape, name) @@ -266,20 +279,43 @@ def __init__(self, shape=None, name=None): # value op self._val_maps: List = list() + # track gradient + self._forward_dst_cells = list() + + self.requires_grad = requires_grad + if requires_grad: + grad = IRFullTensor(shape, 'g' + self.name, False).as_grad() + self.grad = grad + def __copy__(self): """ - Copy the tensor that will have the exactly same id - except the empty attached cell + Full tensor should only exist one instance per id Returns: tensor """ - tensor = IRFullTensor(self._shape, self.name) - for key in self.__dict__: - setattr(tensor, key, getattr(self, key)) - # clear attached cells - tensor._cell = list() - return tensor + return self + + def _add_fdst_cell(self, cell: IRCell): + if not isinstance(cell, IRCell): + raise TypeError("Expect an IRCell") + if cell not in self._forward_dst_cells: + if None in self._forward_dst_cells: + idx = self._forward_dst_cells.index(None) + self._forward_dst_cells[idx] = cell + else: + self._forward_dst_cells.append(cell) + + def _rm_fdst_cell(self, cell: IRCell): + if not isinstance(cell, IRCell): + raise TypeError("Expect an IRCell") + if cell in self._forward_dst_cells: + # setting to None to keep value map order + idx = self._forward_dst_cells.index(cell) + self._forward_dst_cells[idx] = None + + def forward_dst_cells(self): + return copy.copy(self._forward_dst_cells) def as_param(self): """ @@ -368,6 +404,7 @@ def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap sub_tensor = IRSubTensor(self, indices, val_map, shape) for attr in IRFullTensor._attr: setattr(sub_tensor, attr, getattr(self, attr)) + sub_tensor.grad = None self._segments.append(sub_tensor) self._indices.append(indices) @@ -513,6 +550,44 @@ def as_grad(self): self._is_param = False return self + def get_grad(self, fcell: IRCell): + """ + Get gradient of this tensor which is associated by a + forward cell + """ + if not self.requires_grad: + raise RuntimeError("require a gradient for a non-grad tensor") + full_grad = self.parent.grad + if full_grad is None: + return None + if self in fcell.inputs(): + fdst_cells = self.parent.forward_dst_cells() + ref_cells = list() + for dst_cell in fdst_cells: + for input in dst_cell.inputs(): + if self.overlap(input): + ref_cells.append(dst_cell) + break + ref_times = len(ref_cells) + if ref_times == 0: + raise RuntimeError("Internal Error: ref time is 0") + idx = ref_cells.index(fcell) + grad = full_grad.select( + indices = self.indices, + val_map = (idx, ref_times), + shape = self.shape + ) + return grad.as_grad() + elif self in fcell.outputs(): + grad = full_grad.select( + indices = self.indices, + val_map = self.val_map, + shape = self.shape + ) + return grad.as_grad() + else: + raise RuntimeError(f"{self} not found in cell {fcell}") + def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap, None], shape=None): """ Select an IRSubTensor diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 7b235d64..754f5da7 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -344,13 +344,13 @@ class IRTensor: IRTensor serves as IRGraph edge """ - _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grads'] + _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad'] def __init__(self, shape=None, name=None): self._id: int = IDGenerator().gen_tensor_id() self._shape: Optional(List[int]) = shape - self.name = name + self.name = name if name else 'tensor' # device self._cell: List[IRCell] = list() @@ -359,7 +359,7 @@ def __init__(self, shape=None, name=None): self._is_grad = False self._requires_grad = True - self._grads = list() + self._grad = None def attach_cell(self, cell: IRCell): """ @@ -381,14 +381,6 @@ def detach_cell(self, cell: IRCell): raise RuntimeError("the target cell not in the attached list") self._cell.remove(cell) - def set_trace(self, sus: List): - """ - Set tensor generation trace - """ - if not isinstance(sus, list): - raise TypeError("Expected List[ScheduleUnit]") - self.trace = sus - @property def requires_grad(self): return self._requires_grad @@ -399,7 +391,7 @@ def requires_grad(self, requires: bool): raise TypeError("Expected bool") self._requires_grad = requires if not requires: - self._grads = list() + self.grad = None def as_param(self): """ @@ -417,22 +409,15 @@ def is_param(self): return self._is_param @property - def grads(self) -> List: - return self._grads - - @grads.setter - def grads(self, grads: List): - if grads is None: - grads = list() - if not all([isinstance(grad, IRTensor) for grad in grads]): - raise TypeError("grad can only be None or List[Tensor]") - self._grads = grads - self.requires_grad = True + def grad(self): + return self._grad - def add_grad(self, grad): - if grad is not None and not isinstance(grad, IRTensor): + @grad.setter + def grad(self, grad): + if grad and not isinstance(grad, IRTensor): raise TypeError("grad can only be None or Tensor") - self._grads += [grad] + self._grad = grad + self.requires_grad = True def as_grad(self): self._is_param = False diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 656926a5..5d18bb60 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -197,47 +197,20 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: data_num=len(fsu.inputs()), grad_num=len(fsu.outputs()) ) - for idx, input in enumerate(fsu.inputs()): - bnode.set_data(idx, input) - - # FIXME: fail case: forward -> forward -> backward -> backward - fout_grads = list() - for fout in fsu.outputs(): - for grad in fout.grads: - if grad in su1.mirror.inputs() + su2.mirror.inputs(): - fout_grads.append(grad) - break - else: - raise RuntimeError("Cannot fout find gradient") - for idx, fout_grad in enumerate(fout_grads): - bnode.set_grad(idx, fout_grad) - - fin_grads = list() - for fin in fsu.inputs(): - if isinstance(fin, IRTensor): - for grad in fin.grads: - if grad in su1.mirror.outputs() + su2.mirror.outputs(): - fin_grads.append(grad) - break - else: - print(f'msu = {fsu}') - print(f'fin = {fin}') - print(f'fin grads = {fin.grads}') - print(f'fsu1 = {su1}') - print(f'fsu2 = {su2}') - print(f'bsu1 = {su1.mirror}') - print(f'bsu2 = {su2.mirror}') - raise RuntimeError("Cannot find fin gradient") - else: - fin_grads.append(None) - for idx, fin_grad in enumerate(fin_grads): - bnode.set_output(idx, fin_grad) + for idx, fin in enumerate(fsu.inputs()): + bnode.set_data(idx, fin) + + for idx, fout in enumerate(fsu.outputs()): + bnode.set_grad(idx, fout.grad) + + for idx, fin in enumerate(fsu.inputs()): + bnode.set_output(idx, fin.grad) bsu = ScheduleUnit([bnode], stype=SUType.Backward, name='bsu') bsu.device = su2.mirror.device fsu.mirror = bsu bsu.mirror = fsu - def _set_adapters(su1, su2, msu): + def _set_adapters(su1: ScheduleUnit, su2: ScheduleUnit, msu: ScheduleUnit): # set adapter for idx, input in enumerate(msu.inputs()): if input in su1.inputs(): diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index 4430ffde..b901e286 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -85,7 +85,8 @@ def backward(loss: IRTensor): bnode = node.mirror loss_idx = node.outputs().index(loss) # TODO: fix why cannot use loss.add_grad - node.outputs(loss_idx).add_grad(loss) + # assert False + node.outputs(loss_idx).grad = loss bnode.set_grad(loss_idx, loss) break for node in trace[::-1]: diff --git a/examples/e2e.py b/examples/e2e.py index 84e982fa..9d7599dd 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -104,10 +104,13 @@ def train(): @cube.schedule.schedule(model, dataloader, transform_policy=trans_policy, schedule_policy=schedule_policy) def train_iter(model, dataloader): - for _ in range(4): - data = next(dataloader) - loss = model(data) - loss.backward() + # for _ in range(1): + # data = next(dataloader) + # loss = model(data) + # loss.backward() + data = next(dataloader) + loss = model(data) + loss.backward() model = model.get_gen_module() init_weight(model.parameters()) diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index 00fa8947..ada62f3e 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -82,7 +82,7 @@ def test_model_gen(): SchedulePool().clear() - grad_accum = 2 + grad_accum = 1 graph = construct_graph() dataloader = IRDataLoader(FakeDataLoader(batch_size=64)) diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py index 0ec8df8e..cdcca670 100644 --- a/tests/schedule/test_translator.py +++ b/tests/schedule/test_translator.py @@ -113,16 +113,19 @@ def test_translator_backward(): graph = construct_graph() data = IRFullTensor(shape=[64,1024], name='data').tosub() - output = graph(data) - output.backward() + loss = graph(data) + loss.backward() nodes = SchedulePool().nodes() + for node in nodes: + print(node) assert len(nodes) == 6 fnodes = nodes[0:3] bnodes = nodes[3:] - for fsu, bsu in zip(fnodes, bnodes[::-1]): - assert fsu.mirror == bsu - assert bsu.mirror == fsu + assert loss in bnodes[0].inputs() + for fnode, bnode in zip(fnodes, bnodes[::-1]): + assert fnode.mirror == bnode + assert bnode.mirror == fnode def test_sugraph_gener_gen(): From ccb2537bc98c5508dcf1bb80bfc1c3814a1a792d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 Nov 2021 19:47:51 +0800 Subject: [PATCH 0263/1892] test tensor grad --- tests/graph/test_tensor_grad.py | 153 ++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 tests/graph/test_tensor_grad.py diff --git a/tests/graph/test_tensor_grad.py b/tests/graph/test_tensor_grad.py new file mode 100644 index 00000000..227899c5 --- /dev/null +++ b/tests/graph/test_tensor_grad.py @@ -0,0 +1,153 @@ +from cube.graph.graph import IRGraph +from cube.graph.tensor import IRFullTensor, ValueMap +from cube.graph.operator.function import Linear, ElementWise +import cube.graph.gpass as gpass +from cube.ir.cten import IRTensor + + +def construct_model(): + + input1 = IRFullTensor(shape=[64,1024], name='data1') + input2 = IRFullTensor(shape=[64,1024], name='data2') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = Linear( + name='linear1', + signature='torch.nn.functional.linear', + inputs= [input1, weight1, bias1], + ) + linear1.infer_shape() + + # linear2 + linear2 = Linear( + name='linear2', + signature='torch.nn.functional.linear', + inputs= [linear1.outputs(0), weight2, None], + ) + linear2.infer_shape() + + # linear3 + linear3 = Linear( + name='linear3', + signature='torch.nn.functional.linear', + inputs= [linear2.outputs(0), weight3, bias3], + ) + linear3.infer_shape() + + # linear4 + linear4 = Linear( + name='linear4', + signature='torch.nn.functional.linear', + inputs= [input2, weight1, bias1], + ) + linear4.infer_shape() + + # element-wise + add5 = ElementWise( + name='add', + signature='torch.add', + inputs=[linear2.outputs(0), linear3.outputs(0)] + ) + add5.infer_shape() + + # element-wise + add6 = ElementWise( + name='add', + signature='torch.add', + inputs=[add5.outputs(0), linear4.outputs(0)] + ) + add6.infer_shape() + + # return [input], [ops], [output] + return [input1, input2], [linear1, linear2, linear3, linear4, add5, add6], [add6.outputs(0)] + + +def test_tensor_grad(): + + inputs, ops, outputs = construct_model() + linear1, linear2, linear3, linear4, add5, add6 = ops + graph = IRGraph(ops, inputs, outputs, 'MLP') + print(graph) + + all_parent_tids = list() + all_parent_tensors = list() + for op in ops: + for input in op.inputs(): + if isinstance(input, IRTensor): + if input.parent._id not in all_parent_tids: + all_parent_tensors.append(input.parent) + + for pten in all_parent_tensors: + assert pten.grad is None + print(pten.name, pten) + cell_ids = [cell._id for cell in pten.forward_dst_cells()] + print('forward_dst_cells id:', cell_ids) + print('') + + print('test grad:') + + input = linear1.inputs(0) + assert input.grad is None + gin = input.get_grad(linear1) + assert gin.val_map == ValueMap(0, 1) + print(gin.name, gin) + + weight = linear1.inputs(1) + gw = weight.get_grad(linear1) + assert gw.val_map == ValueMap(0, 2) + print(gw.name, gw) + + weight = linear4.inputs(1) + gw = weight.get_grad(linear4) + assert gw.val_map == ValueMap(1, 2) + print(gw.name, gw) + + out2 = linear2.outputs(0) + gout2 = out2.get_grad(linear2) + print(gout2.name, gout2) + assert gout2.val_map == ValueMap(0, 1) + gout2 = out2.get_grad(linear3) + print(gout2.name, gout2) + assert gout2.val_map == ValueMap(0, 2) + gout2 = out2.get_grad(add5) + print(gout2.name, gout2) + assert gout2.val_map == ValueMap(1, 2) + + out3 = linear3.outputs(0) + gout3 = out3.get_grad(linear3) + print(gout3.name, gout3) + assert gout3.val_map == ValueMap(0, 1) + gout3 = out3.get_grad(add5) + print(gout3.name, gout3) + assert gout3.val_map == ValueMap(0, 1) + + for node in graph.nodes(): + assert node.mirror is None + + print('test forward graph:') + inputs = [inputs[0].tosub(), inputs[1].tosub()] + graph = gpass.forward(graph, *inputs) + print(graph) + for node in graph.nodes()[::-1]: + print(node.mirror) + + gw1 = linear1.mirror.outputs(1) + assert gw1.is_grad() + print(gw1) + gw4 = linear4.mirror.outputs(1) + assert gw4.is_grad() + print(gw4) + + assert gw1.parent == gw4.parent + assert gw1.shape == gw4.shape + assert gw1.indices == gw4.indices + assert gw1.val_map != gw4.val_map + + # assert False From 08b3031860ccb0aad903573cf0166e53521f9394 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 01:47:23 +0800 Subject: [PATCH 0264/1892] graph partition enabled --- cube/algorithm/dataloader.py | 44 +++++++++++++++++++++++ cube/algorithm/factory.py | 3 ++ cube/graph/graph.py | 50 ++++++++++++++++++-------- cube/graph/operator/operator.py | 22 ++++++++++++ cube/graph/tensor.py | 2 +- cube/ir/cten.py | 6 ++-- cube/schedule/sugraph.py | 62 ++++++++++++++++++++++++++++++--- cube/schedule/translator.py | 23 ++++++------ tests/codegen/test_codegen.py | 39 +++++++++++++++++---- 9 files changed, 210 insertions(+), 41 deletions(-) create mode 100644 cube/algorithm/dataloader.py diff --git a/cube/algorithm/dataloader.py b/cube/algorithm/dataloader.py new file mode 100644 index 00000000..f863e592 --- /dev/null +++ b/cube/algorithm/dataloader.py @@ -0,0 +1,44 @@ +from typing import List, Dict, Type + +from cube.algorithm.utils import split_axis +from cube.algorithm.generics import GenericDistAlgo +from cube.graph.operator.operator import IRDataOperation + + +_kWaitDecision = None + + +class DPDataLoader(GenericDistAlgo): + + def __init__(self, node: IRDataOperation): + + if not isinstance(node, IRDataOperation): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + for shape in self.output_shapes: + if chunk_num > 0 and shape[0] % chunk_num != 0: + return False + return True + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + sub_outputs = list() + for output in node.outputs(): + sub_output = split_axis(output, 0, self.chunk_num) + sub_outputs.append(sub_output) + + nodes = list() + for sub_outs in zip(*sub_outputs): + node = IRDataOperation(data_num = len(sub_outs)) + for idx, out in enumerate(sub_outs): + node.set_output(idx, out) + nodes.append(node) + return nodes diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index d88e7785..69d6eb4e 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -59,6 +59,9 @@ def algorithms(self, op, tag = None): def _load_predefined_algos(self): + import cube.algorithm.dataloader as dataloader + self.register(dataloader.IRDataOperation, dataloader.DPDataLoader, tag='data') + import cube.algorithm.linear as linear self.register(linear.Linear, linear.LinearDataParallel, tag='data') self.register(linear.Linear, linear.LinearColumnWeight, tag='column') diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 1b4cde5e..063077fe 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -12,7 +12,7 @@ from cube.graph.operator.operator import IRBpOperation from cube.ir.cten import IRTensor, IRCell -from cube.graph.tensor import IRFullTensor +from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.generics import GenericDistAlgo @@ -37,9 +37,17 @@ def __init__(self, self._parameters = list() if input_tensors is None: - input_tensors = IRCell.get_inputs(nodes) + input_tensors = list() + inputs = IRCell.get_inputs(nodes) + for input in inputs: + if not input.is_param(): + input_tensors.append(input) if output_tensors is None: - output_tensors = IRCell.get_outputs(nodes) + output_tensors = list() + outputs = IRCell.get_outputs(nodes) + for output in outputs: + if not output.is_param(): + output_tensors.append(output) super().__init__( name=module_name, @@ -78,6 +86,10 @@ def __init__(self, for node in self._nodes: for input in node.inputs(): if isinstance(input, IRTensor): + if input.is_param(): + # parameters already set + self._parameters.append(input) + continue if input not in input_tensors and \ input.is_leaf(self._nodes): input.as_param() @@ -268,11 +280,14 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional return None if not algo.satisfy(config): return None - nodes = algo.instantiate(op, config) + fnodes = algo.instantiate(op, config) + # remove reference + for idx in range(len(op.inputs())): + op.set_input(idx, None) # set backward mirror node if op.mirror is not None: # generate mirror node - for fnode in nodes: + for fnode in fnodes: bnode = IRBpOperation( data_num=len(fnode.inputs()), grad_num=len(fnode.outputs()) @@ -280,25 +295,32 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional for idx, val in enumerate(fnode.inputs()): bnode.set_data(idx, val) grad = None - if isinstance(val, IRTensor): - # this is wrong - grad = val.grads[-1] + if isinstance(val, IRSubTensor): + if val.requires_grad and val.grad is None: + grad = val.get_grad(fnode) + val.grad = grad + grad = val.grad bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): grad = None - if isinstance(val, IRTensor): - # this is wrong - grad = val.grads[-1] + if isinstance(val, IRSubTensor): + if val.requires_grad and val.grad is None: + grad = val.get_grad(fnode) + val.grad = grad + grad = val.grad bnode.set_grad(idx, grad) fnode.mirror = bnode fnode.device = op.device bnode.mirror = fnode bnode.device = op.mirror.device idx = self._nodes.index(op) - self._nodes = self._nodes[:idx] + nodes + self._nodes[idx+1:] + self._nodes = self._nodes[:idx] + fnodes + self._nodes[idx+1:] + if op.mirror is not None: + idx = self._nodes.index(op.mirror) + bnodes = [node.mirror for node in fnodes][::-1] + self._nodes = self._nodes[:idx] + bnodes + self._nodes[idx+1:] self.reset_dependency() - return copy.copy(nodes) - + return copy.copy(fnodes) def merge(self, sub_graph, target_op, op_partition_algorithm): raise NotImplementedError diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index e47030de..0be30a5e 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -205,3 +205,25 @@ def __init__(self, data_num: int, name='dataloader'): signature = 'dataloader.__next__' super().__init__(name, signature, 0, data_num) + + def algorithms(self, tag: Optional[str] = None): + """ + get algorithm from algorithm factory + + Args: + tag: str or None. If None, return all + """ + factory = DistAlgorithmFactory() + if tag is None: + templates = list() + if factory.exist(type(self)): + templates = factory.algorithms(type(self)) + algos = list() + for template in templates: + algos.append(template(self)) + return algos + else: + if not factory.exist(type(self), tag): + return None + template = factory.algorithms(type(self), tag) + return template(self) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 9cbed74a..168b8cf1 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -315,7 +315,7 @@ def _rm_fdst_cell(self, cell: IRCell): self._forward_dst_cells[idx] = None def forward_dst_cells(self): - return copy.copy(self._forward_dst_cells) + return [cell for cell in self._forward_dst_cells if cell is not None] def as_param(self): """ diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 754f5da7..13e6ea92 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -295,7 +295,8 @@ def get_inputs(cells): for input in cell.inputs(): if isinstance(input, IRTensor): if input not in all_outputs: - inputs.append(input) + if input not in inputs: + inputs.append(input) return inputs @staticmethod @@ -314,7 +315,8 @@ def get_outputs(cells): for output in node.outputs(): if isinstance(output, IRTensor): if output not in all_inputs: - outputs.append(output) + if output not in outputs: + outputs.append(output) return outputs def __repr__(self): diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 5d18bb60..e293be03 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -365,6 +365,56 @@ def set_order(self, seq: List[ScheduleUnit]): self.sequence = seq return True + def partial_set_order(self, seq: List[ScheduleUnit]): + """ + Set a order of the sequence using part of SUs. + + A random topological order will be set under + the constraints of given `seq` order + """ + seq = copy.copy(seq) + for su in seq: + if su not in self.sequence: + raise RuntimeError(f"SU {su} is not in SUGraph") + if not SUGraph.is_topo_order(seq, integrity_check=False): + return False + remain_sus : ScheduleUnit = list() + for su in self.sequence: + if su not in seq: + remain_sus.append(su) + for rsu in remain_sus: + if len(rsu.inputs()) > 0: + happen_before_sus = rsu.predecessors() + idx = 0 + while len(happen_before_sus) > 0: + if idx == len(seq): + raise RuntimeError( + f"Internal Error: SU {rsu} cannot be inserted" + ) + su = seq[idx] + if su in happen_before_sus: + happen_before_sus.remove(su) + idx += 1 + seq.insert(idx, rsu) + else: + succ_sus = rsu.successors() + idx = len(seq) + while len(succ_sus) > 0: + idx -= 1 + if idx < 0: + idx = 0 + break + su = seq[idx] + if su in succ_sus: + succ_sus.remove(su) + seq.insert(idx, rsu) + + if not SUGraph.is_topo_order(seq, integrity_check=True): + raise RuntimeError("Internal Error: topo is not guaranteed.") + self.sequence = seq + return True + + @staticmethod def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: """ @@ -475,12 +525,14 @@ def is_topo_order(seq: List[ScheduleUnit], integrity_check=False): for index, su in enumerate(seq): for pre_su in su.predecessors(): # find the pre-su not appear in sequence - if integrity_check and not pre_su in seq: + if integrity_check: + if pre_su not in seq: + return False + if pre_su in seq: + pre_idx = seq.index(pre_su) + # violate topological order + if pre_idx >= index: return False - pre_idx = seq.index(pre_su) - # violate topological order - if pre_idx >= index: - return False return True def __repr__(self): diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index b901e286..15c0723b 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -6,8 +6,7 @@ """ import torch -from cube.ir.cten import IRTensor -from cube.graph.tensor import IRFullTensor +from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.graph.operator import IRDataOperation import cube.graph.gpass as gpass from cube.schedule.pool import SchedulePool @@ -71,23 +70,21 @@ def forward(graph, *args): else: return outputs @staticmethod - def backward(loss: IRTensor): + def backward(loss: IRSubTensor): """ Translator Action: backward a tensor """ trace = SchedulePool().get_tape(loss) if trace is None: raise RuntimeError("No forward detected") + # make gradient point to it self + loss.parent.grad = loss.parent bnode = None - loss_idx = None - for node in trace[::-1]: - if loss in node.outputs(): - bnode = node.mirror - loss_idx = node.outputs().index(loss) - # TODO: fix why cannot use loss.add_grad - # assert False - node.outputs(loss_idx).grad = loss - bnode.set_grad(loss_idx, loss) - break + for node in trace: + for idx, output in enumerate(node.outputs()): + if loss.overlap(output): + bnode = node.mirror + output.grad = output + bnode.set_grad(idx, output) for node in trace[::-1]: SchedulePool().add_node(node.mirror) diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index ada62f3e..d46bb9eb 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -1,3 +1,4 @@ +from cube.graph.operator.operator import IRDataOperation, IRFwOperation from cube.graph.tensor import IRFullTensor from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph @@ -82,20 +83,28 @@ def test_model_gen(): SchedulePool().clear() - grad_accum = 1 + grad_accum = 2 graph = construct_graph() dataloader = IRDataLoader(FakeDataLoader(batch_size=64)) - for _ in range(grad_accum): - data = next(dataloader) - output = graph(data) - output.backward() + data = next(dataloader) + output = graph(data) + output.backward() nodes = SchedulePool().nodes() graph = IRGraph(nodes, None, None, module_name='Test') - sugraph = SUGraphGener.gen_sugraph(nodes) + for node in graph.nodes(): + if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): + algo = node.algorithms('data') + graph.partition(node, algo, config=dict(chunk_num=grad_accum)) + + # print(graph) + + sugraph = SUGraphGener.gen_sugraph(graph.nodes()) + + # print(sugraph) fb_seqs = list() for fsu in sugraph.fsus(): @@ -109,6 +118,12 @@ def test_model_gen(): break else: fb_seqs.append([fsu]) + + if len(fb_seqs) != grad_accum: + for idx, fb_seq in enumerate(fb_seqs): + print(f'> sequence {idx}:') + for su in fb_seq: + print(su) assert len(fb_seqs) == grad_accum for su in sugraph.sus(): @@ -123,9 +138,21 @@ def test_model_gen(): else: sugraph.assign(su, 1) sugraph.assign(su.mirror, 1) + + for fb_seq in fb_seqs: + fb_seq += [fsu.mirror for fsu in fb_seq][::-1] print('========= after asignment: ==========\n', sugraph) + seqs = list() + for fb_seq in fb_seqs: + seqs += fb_seq + print('> seqs:') + for su in seqs: + print(su) + sugraph.partial_set_order(seqs) + print('========= after reorder: ==========\n', sugraph) + sugraph = SUGraphPass.remove_redundant_adapters(sugraph) print('========= after remove adapter: ==========\n', sugraph) From 15041f5f8d4fedef35a958cc99d4904389595ad0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 11:03:46 +0800 Subject: [PATCH 0265/1892] fix select adapter gen bug --- cube/schedule/sugraph.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index e293be03..62ecde8d 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -458,11 +458,12 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: merge_op = IRTensorReshape( src_tensors=tensor_segments, dst_tensors=[input] ) - merge_su = ScheduleUnit([merge_op], SUType.Comm, name='merge') + merge_su = ScheduleUnit([merge_op], SUType.Transform, name='merge') su._set_merge_adapter(in_idx, merge_su) merge_su.device = su.device - # add adapter for select + # add adapter for select + for su in sugraph.sus(): for out_idx, output in enumerate(su.outputs()): if not isinstance(output, IRTensor): continue @@ -477,7 +478,7 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: src_tensors=[output], dst_tensors=select_tensors ) select_su = ScheduleUnit( - [select_op], SUType.Comm, name='select' + [select_op], SUType.Transform, name='select' ) su._set_select_adapter(out_idx, select_su) select_su.device = su.device From 4918af0ddc90eb58e73e5ca0e413a809c8e99c08 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 13:24:46 +0800 Subject: [PATCH 0266/1892] add recv su predecessors --- cube/graph/tensor.py | 14 ++++++++++--- cube/schedule/sugraph.py | 43 +++++++++++++++++----------------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 168b8cf1..a8b8a565 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -1,11 +1,19 @@ -""" +r""" SubTensor Gradient rule: -1). for input tensors, gradient SubTensor is: +SubTensor's logical grad = SubTensor.parent.grad.select( + indices = SubTensor.indices, + val_map = SubTensor.val_map, + shape = SubTensor.shape +) + +FwOperation -> BpOperation rule: + +1). for (FwOp) input tensors, gradient SubTensor is: indices = input.indices; val is splitted by referencing times on the indices -2). for output tensors, gradient SubTensor is: +2). for (FwOp) output tensors, gradient SubTensor is: indices = output.indices; val follows same value splitting rules with output """ diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 62ecde8d..9dc17d36 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -60,6 +60,12 @@ def reset_dependency(sus: List[ScheduleUnit]): src = sus[src_idx] for dst in sus[src_idx+1:]: for out_idx, out_tensor in enumerate(src.outputs()): + # special dependency for communication adapter + if dst.stype == SUType.Comm: + for recv_tensor in dst.outputs(): + if out_tensor.overlap(recv_tensor): + src.add_successor(out_idx, dst) + dst.add_predecessor(-1, src) for in_idx, in_tensor in enumerate(dst.inputs()): if out_tensor.overlap(in_tensor): src.add_successor(out_idx, dst) @@ -383,31 +389,18 @@ def partial_set_order(self, seq: List[ScheduleUnit]): if su not in seq: remain_sus.append(su) for rsu in remain_sus: - if len(rsu.inputs()) > 0: - happen_before_sus = rsu.predecessors() - idx = 0 - while len(happen_before_sus) > 0: - if idx == len(seq): - raise RuntimeError( - f"Internal Error: SU {rsu} cannot be inserted" - ) - su = seq[idx] - if su in happen_before_sus: - happen_before_sus.remove(su) - idx += 1 - seq.insert(idx, rsu) - else: - succ_sus = rsu.successors() - idx = len(seq) - while len(succ_sus) > 0: - idx -= 1 - if idx < 0: - idx = 0 - break - su = seq[idx] - if su in succ_sus: - succ_sus.remove(su) - seq.insert(idx, rsu) + happen_before_sus = rsu.predecessors() + idx = 0 + while len(happen_before_sus) > 0: + if idx == len(seq): + raise RuntimeError( + f"Internal Error: SU {rsu} cannot be inserted" + ) + su = seq[idx] + if su in happen_before_sus: + happen_before_sus.remove(su) + idx += 1 + seq.insert(idx, rsu) if not SUGraph.is_topo_order(seq, integrity_check=True): raise RuntimeError("Internal Error: topo is not guaranteed.") From 2e99171032767c4841b33b14a2ec08800191c12b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 16:04:58 +0800 Subject: [PATCH 0267/1892] enable elementwise split on data dimension --- cube/algorithm/elementwise.py | 56 ++++++++++++++++++++++++ tests/algorithm/test_elementwise_algo.py | 42 ++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 cube/algorithm/elementwise.py create mode 100644 tests/algorithm/test_elementwise_algo.py diff --git a/cube/algorithm/elementwise.py b/cube/algorithm/elementwise.py new file mode 100644 index 00000000..66bd61b9 --- /dev/null +++ b/cube/algorithm/elementwise.py @@ -0,0 +1,56 @@ +from typing import Dict + +from cube.algorithm.utils import split_axis +from cube.algorithm.generics import GenericDistAlgo +from cube.graph.operator.function import ElementWise +from cube.ir.cten import IRTensor + + +_kWaitDecision = None + + +class ElementWiseDataParallel(GenericDistAlgo): + + def __init__(self, node: ElementWise): + if not isinstance(node, ElementWise): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + input_shape = self.input_shapes[0] + if chunk_num > 0 and input_shape[0] % chunk_num != 0: + return False + return True + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + sub_inputs = list() + for input in node.inputs(): + if isinstance(input, IRTensor): + sub_input = split_axis(input, 0, self.chunk_num) + else: + sub_input = [input] * self.chunk_num + sub_inputs.append(sub_input) + + sub_outputs = list() + for output in node.outputs(): + if isinstance(output, IRTensor): + sub_output = split_axis(output, 0, self.chunk_num) + else: + sub_output = [output] * self.chunk_num + sub_outputs.append(sub_output) + + nodes = list() + for idx, sub_input in enumerate(zip(*sub_inputs)): + node = ElementWise(node.signature, inputs=sub_input, name=node.name) + nodes.append(node) + for idx, sub_output in enumerate(zip(*sub_outputs)): + node = nodes[idx] + for idx, output in enumerate(sub_output): + node.set_output(idx, output) + return nodes diff --git a/tests/algorithm/test_elementwise_algo.py b/tests/algorithm/test_elementwise_algo.py new file mode 100644 index 00000000..74e61fe3 --- /dev/null +++ b/tests/algorithm/test_elementwise_algo.py @@ -0,0 +1,42 @@ +from cube.graph.operator.function import ElementWise +from cube.algorithm.elementwise import ElementWiseDataParallel +from cube.graph.tensor import IRFullTensor + + +def test_elementwise_data_parallel(): + + input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() + input2 = IRFullTensor(shape=[1024, 1024], name='input2').tosub() + + semantic_op = ElementWise( + signature='torch.add', inputs=[input1, input2], name='add' + ) + semantic_op.infer_shape() + print('semantic op:') + print(semantic_op) + + op_dp = ElementWiseDataParallel(semantic_op) + + assert op_dp.chunk_num is None + + # test satisfy + assert op_dp.satisfy(dict(chunk_num = 4)) + assert not op_dp.satisfy(dict(chunk_num = 10)) + + nodes = op_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, ElementWise) + + for node in nodes: + print('=======') + print(node) + print('inputs:') + for input in node.inputs(): + print(input) + assert input.shape == [256, 1024] + print('outputs:') + for output in node.outputs(): + print(output) + assert output.shape == [256, 1024] + From 5c84772dae819a0458c1e7c69d4ad47006a27ffe Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 16:18:48 +0800 Subject: [PATCH 0268/1892] add reduce split on data dimension --- cube/algorithm/factory.py | 6 ++++ cube/algorithm/reduce.py | 55 +++++++++++++++++++++++++++++ tests/algorithm/test_reduce_algo.py | 41 +++++++++++++++++++++ 3 files changed, 102 insertions(+) create mode 100644 cube/algorithm/reduce.py create mode 100644 tests/algorithm/test_reduce_algo.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 69d6eb4e..95549ca9 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -66,3 +66,9 @@ def _load_predefined_algos(self): self.register(linear.Linear, linear.LinearDataParallel, tag='data') self.register(linear.Linear, linear.LinearColumnWeight, tag='column') self.register(linear.Linear, linear.LinearRowWeight, tag='row') + + import cube.algorithm.elementwise as elew + self.register(elew.ElementWise, elew.ElementWiseDataParallel, tag='data') + + import cube.algorithm.reduce as reduce + self.register(reduce.Reduce, reduce.ReduceDataParallel, tag='data') diff --git a/cube/algorithm/reduce.py b/cube/algorithm/reduce.py new file mode 100644 index 00000000..98caa652 --- /dev/null +++ b/cube/algorithm/reduce.py @@ -0,0 +1,55 @@ +from typing import Dict + +from cube.algorithm.utils import split_axis, split_value +from cube.algorithm.generics import GenericDistAlgo +from cube.graph.operator.function import Reduce +from cube.ir.cten import IRTensor + +_kWaitDecision = None + + +class ReduceDataParallel(GenericDistAlgo): + + def __init__(self, node: Reduce): + if not isinstance(node, Reduce): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + input_shape = self.input_shapes[0] + if chunk_num > 0 and input_shape[0] % chunk_num != 0: + return False + return True + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + sub_inputs = list() + for input in node.inputs(): + if isinstance(input, IRTensor): + sub_input = split_axis(input, 0, self.chunk_num) + else: + sub_input = [input] * self.chunk_num + sub_inputs.append(sub_input) + + sub_outputs = list() + for output in node.outputs(): + if isinstance(output, IRTensor): + sub_output = split_value(output, self.chunk_num) + else: + sub_output = [output] * self.chunk_num + sub_outputs.append(sub_output) + + nodes = list() + for idx, sub_input in enumerate(zip(*sub_inputs)): + node = Reduce(node.signature, inputs=sub_input, name=node.name) + nodes.append(node) + for idx, sub_output in enumerate(zip(*sub_outputs)): + node = nodes[idx] + for oidx, output in enumerate(sub_output): + node.set_output(oidx, output) + return nodes diff --git a/tests/algorithm/test_reduce_algo.py b/tests/algorithm/test_reduce_algo.py new file mode 100644 index 00000000..5f009150 --- /dev/null +++ b/tests/algorithm/test_reduce_algo.py @@ -0,0 +1,41 @@ +from cube.graph.operator.function import Reduce +from cube.algorithm.reduce import ReduceDataParallel +from cube.graph.tensor import IRFullTensor, ValueMap + + +def test_elementwise_data_parallel(): + + input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() + + semantic_op = Reduce( + signature='torch.sum', inputs=[input1], name='add' + ) + semantic_op.infer_shape() + print('semantic op:') + print(semantic_op) + + op_dp = ReduceDataParallel(semantic_op) + + assert op_dp.chunk_num is None + + # test satisfy + assert op_dp.satisfy(dict(chunk_num = 4)) + assert not op_dp.satisfy(dict(chunk_num = 10)) + + nodes = op_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, Reduce) + + for idx, node in enumerate(nodes): + print('=======') + print(node) + print('inputs:') + for input in node.inputs(): + print(input) + assert input.shape == [256, 1024] + print('outputs:') + for output in node.outputs(): + print(output) + assert output.shape == [1] + assert output.val_map == ValueMap(idx, 4) From b7b6336da06ef0a73fee51ac86fe7a7f1b8fd979 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 16:43:30 +0800 Subject: [PATCH 0269/1892] linear example fix --- cube/schedule/__init__.py | 31 +++++++++-- examples/e2e.py | 3 + examples/linears.py | 114 +++++++++++++++++++------------------- 3 files changed, 85 insertions(+), 63 deletions(-) diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index 0baa0b0f..b80437dd 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -3,6 +3,7 @@ from cube.graph.graph import IRGraph from cube.schedule.pool import SchedulePool +from cube.schedule.su import SUType from cube.schedule.translator import IRDataLoader from cube.schedule.sugraph import SUGraph, SUGraphGener from cube.schedule.graphpass import SUGraphPass @@ -95,7 +96,7 @@ def _load_tschedule_fn(filename) -> Callable: def decorator(fn: Callable) -> Callable: filename = 'gencode{}.py' - + batch_size = torch.tensor([-1], dtype=torch.int).cuda() if myrank == 0: SchedulePool().clear() @@ -124,9 +125,9 @@ def decorator(fn: Callable) -> Callable: raise RuntimeError(f"SUGraph order is not topological order") # graph pass to remove redundant sus - su_graph = SUGraphPass.remove_redundant_adapters(sugraph) - su_graph = SUGraphPass.merge_small_sus(su_graph) - print(su_graph) + sugraph = SUGraphPass.remove_redundant_adapters(sugraph) + sugraph = SUGraphPass.merge_small_sus(sugraph) + print(sugraph) if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() @@ -134,8 +135,8 @@ def decorator(fn: Callable) -> Callable: world_size = 1 # code generation - mgener = ModelCodeGen(su_graph) - sgener = ScheduleCodeGen(su_graph) + mgener = ModelCodeGen(sugraph) + sgener = ScheduleCodeGen(sugraph) for rank in range(world_size): fname = filename.format(rank) # generate spatial module code @@ -146,8 +147,26 @@ def decorator(fn: Callable) -> Callable: outfile = fname, attach=True ) + # get dataloader batch size + data = None + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + data = su.outputs(0) + break + if data is None: + raise RuntimeError("dataloader not found in SUGraph") + # assume batch_size is always first dimension + batch_size = torch.tensor([data.shape[0]], dtype=torch.int).cuda() + if torch.distributed.is_initialized(): torch.distributed.barrier() + + # reset dataloader + torch.distributed.broadcast(batch_size, src=0) + batch_size = batch_size.item() + print(f'> reseting dataloader batch size to {batch_size}') + dataloader.reset(batch_size=batch_size) + # load module model.load_module(filename.format(myrank)) # load temporal diff --git a/examples/e2e.py b/examples/e2e.py index 9d7599dd..b5f977a7 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -61,6 +61,9 @@ def __init__(self, batch_size, num=640): def __iter__(self): self.pos = 0 return self + def reset(self, batch_size): + self.batch_size = batch_size + self.pos = 0 def __next__(self): self.pos += 1 if self.pos == self.length: diff --git a/examples/linears.py b/examples/linears.py index 6d349361..bcc9b85a 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -8,15 +8,15 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/e2e.py + examples/linears.py """ -from typing import List import torch from torch import nn import cube -from cube.schedule.su import ScheduleUnit +from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraph @@ -24,68 +24,66 @@ def transform_policy(graph, resource): """ The transformation policy transposes linear using data parallel """ - ndevice = resource.ngpus - for op in graph.nodes(): - # TODO: which dimension is batch - algorithm = op.algorithms('data_parallel') - graph.partition(op, algorithm, config=dict(chunk_size=ndevice)) + for node in graph.nodes(): + if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): + algo = node.algorithms('data') + graph.partition(node, algo, config=dict(chunk_num=2)) return graph -def schedule_policy(seq: SUGraph, resource): +def schedule_policy(sugraph: SUGraph, resource): """ The schedule policy uses 1F1B (interleaved) pipeline """ - ndevice = resource.ngpus - - # batch_seqs[idx]: the idx-th forward-backward 4 linear forward + backward - batch_seqs: List[List[ScheduleUnit]] = group_by_batches(seq.sus()) - num_fsus = len(seq.sus()) // len(batch_seqs) // 2 - - # device placement -- inter-device order - for batch_seq in batch_seqs: - for idx, su in enumerate(batch_seq): - stage = idx // (num_fsus // ndevice) - if idx < num_fsus: - seq.assign(su, stage) + fb_seqs = list() + for fsu in sugraph.fsus(): + for fb_seq in fb_seqs: + for ksu in fb_seq[::-1]: + if sugraph.happen_before(ksu, fsu): + fb_seq.append(fsu) + break else: - seq.assign(su, ndevice - stage % ndevice) - - - # decide topo order -- intra-device order - f = lambda stage, micro_batch_id: batch_seqs[micro_batch_id][stage] - b = lambda stage, micro_batch_id: batch_seqs[micro_batch_id][-stage] - - reorder = list() - # warmup - for stage in range(ndevice): - for micro_batch_id in range(stage): - reorder = reorder.append(f(stage, micro_batch_id)) - # steady + cooldown - for stage in range(ndevice): - # backward - for micro_batch_id in range(len(batch_seqs)): - reorder.append(b(stage, micro_batch_id)) - # forward - for stage in range(ndevice): - f_mirco_batch_id = micro_batch_id + 1 + ndevice - stage - if f_mirco_batch_id >= len(batch_seqs): continue - reorder.append(f(stage, f_mirco_batch_id)) - # inform system the topological order that could do pipeline parallelism - SUGraph.set_order(reorder) + break + else: + fb_seqs.append([fsu]) + + # device assignment + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + + for fb_seq in fb_seqs: + for idx, su in enumerate(fb_seq): + if idx < 2: + sugraph.assign(su, 0) + sugraph.assign(su.mirror, 0) + else: + sugraph.assign(su, 1) + sugraph.assign(su.mirror, 1) + # set partial order + for fb_seq in fb_seqs: + fb_seq += [fsu.mirror for fsu in fb_seq][::-1] + seqs = list() + for fb_seq in fb_seqs: + seqs += fb_seq + sugraph.partial_set_order(seqs) + return sugraph class FakeDataLoader: def __init__(self, shape, num=640): - self.shape = shape + self.shape = list(shape) self.length = num self.pos = 0 def __iter__(self): self.pos = 0 return self + def reset(self, batch_size): + self.shape[0] = batch_size + self.pos = 0 def __next__(self): self.pos += 1 if self.pos == self.length: @@ -106,7 +104,8 @@ def forward(self, data): output = self.linear2(output) output = self.linear3(output) output = self.linear4(output) - return output + loss = torch.sum(output) + return loss def train(): @@ -114,19 +113,19 @@ def train(): dim = 1024 model = MLP(dim=dim) - model = cube.schedule.transform(model, policy_fn=transform_policy) - model = model.cuda() + model = cube.schedule.SemanticModel( + model, input_shapes=([batch_size, dim],), + ) - dataloader = FakeDataLoader((batch_size, dim)) + dataloader = FakeDataLoader([batch_size, dim]) - @cube.schedule.schedule(model, dataloader, policy_fn=schedule_policy) + @cube.schedule.schedule(model, dataloader, transform_policy=transform_policy, schedule_policy=schedule_policy) def train_iter(model, dataloader): - for _ in range(4): - data = next(dataloader) - output = model(data) - loss = torch.sum(output) / 1000 - print(f'loss={loss.item()}') - loss.backward() + data = next(dataloader) + loss = model(data) + # print(f'loss={loss.item()}') + loss.backward() + model = model.get_gen_module() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) @@ -138,4 +137,5 @@ def train_iter(model, dataloader): if __name__ == '__main__': + cube.DeviceGroup() train() \ No newline at end of file From 55c532d9c9aa9e38736ee97bf0f529f4a3101a42 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 16:56:28 +0800 Subject: [PATCH 0270/1892] synthetic data loader --- cube/runtime/__init__.py | 1 + cube/runtime/syndata.py | 42 +++++++++++++++++++++++++++++++++++++++ cube/schedule/__init__.py | 4 ++-- examples/linears.py | 20 ++----------------- 4 files changed, 47 insertions(+), 20 deletions(-) create mode 100644 cube/runtime/syndata.py diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index 90200379..126aa0d2 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -1 +1,2 @@ from cube.runtime import collectives, executor, device +from cube.runtime import syndata \ No newline at end of file diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py new file mode 100644 index 00000000..74cb8392 --- /dev/null +++ b/cube/runtime/syndata.py @@ -0,0 +1,42 @@ +r""" +Synthetic Data Loader +""" + +from typing import List +import torch + + +__all__ = ['SynDataLoader'] + + +class SynDataLoader: + r""" + Synthetic dataloader to produce tensors + for given shape. + """ + def __init__(self, num: int, *shapes: List[List[int]]): + self.shapes = list(shapes) + self.length = num + self.pos = 0 + + def __iter__(self): + self.pos = 0 + return self + + def reset(self, batch_size: int): + """ + Reset batch size + """ + for shape in self.shapes: + shape[0] = batch_size + + def __next__(self): + self.pos += 1 + if self.pos == self.length: + raise StopIteration + datas = list() + for shape in self.shapes: + data = torch.randn(shape).cuda() + datas.append(data) + if len(datas) == 1: return datas[0] + else: return tuple(datas) diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index b80437dd..c882ff22 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -117,7 +117,7 @@ def decorator(fn: Callable) -> Callable: sugraph = schedule_policy(sugraph, None) # check assignment and order - print(sugraph) + # print(sugraph) for su in sugraph.sus(): if len(su.device) == 0: raise RuntimeError(f"SU {su} device is not set") @@ -127,7 +127,7 @@ def decorator(fn: Callable) -> Callable: # graph pass to remove redundant sus sugraph = SUGraphPass.remove_redundant_adapters(sugraph) sugraph = SUGraphPass.merge_small_sus(sugraph) - print(sugraph) + print(f'> after merge small sus:\n {sugraph}') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() diff --git a/examples/linears.py b/examples/linears.py index bcc9b85a..f3e0aef3 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -73,23 +73,7 @@ def schedule_policy(sugraph: SUGraph, resource): return sugraph -class FakeDataLoader: - def __init__(self, shape, num=640): - self.shape = list(shape) - self.length = num - self.pos = 0 - def __iter__(self): - self.pos = 0 - return self - def reset(self, batch_size): - self.shape[0] = batch_size - self.pos = 0 - def __next__(self): - self.pos += 1 - if self.pos == self.length: - raise StopIteration - return torch.randn(self.shape).cuda() - +# =================== Semantic Model Description ==================== class MLP(nn.Module): def __init__(self, dim, mult=16): @@ -117,7 +101,7 @@ def train(): model, input_shapes=([batch_size, dim],), ) - dataloader = FakeDataLoader([batch_size, dim]) + dataloader = cube.runtime.syndata.SynDataLoader(640, [batch_size, dim]) @cube.schedule.schedule(model, dataloader, transform_policy=transform_policy, schedule_policy=schedule_policy) def train_iter(model, dataloader): From 614797ea023104021e66db8aea776f22f3454d9d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 18:39:53 +0800 Subject: [PATCH 0271/1892] fix su merge bug: not consider backward --- cube/__init__.py | 4 ++++ cube/runtime/__init__.py | 3 ++- cube/runtime/resource.py | 25 +++++++++++++++++++++++++ cube/schedule/__init__.py | 12 ++++++++---- cube/schedule/sugraph.py | 5 +++-- examples/linears.py | 19 +++++++++---------- 6 files changed, 51 insertions(+), 17 deletions(-) create mode 100644 cube/runtime/resource.py diff --git a/cube/__init__.py b/cube/__init__.py index a91d0ffa..4d9d894e 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -2,3 +2,7 @@ from cube import schedule from cube import runtime + +def init(): + _ = DeviceGroup() + _ = runtime.resource.EnvResource() diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index 126aa0d2..c3f5515b 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -1,2 +1,3 @@ from cube.runtime import collectives, executor, device -from cube.runtime import syndata \ No newline at end of file +from cube.runtime import syndata +from cube.runtime import resource \ No newline at end of file diff --git a/cube/runtime/resource.py b/cube/runtime/resource.py new file mode 100644 index 00000000..d67f5742 --- /dev/null +++ b/cube/runtime/resource.py @@ -0,0 +1,25 @@ +r""" +Runtime information +""" + +import torch + + +class EnvResource: + + class __EnvResource: + + def __init__(self): + # number of gpus + self.ngpus = torch.distributed.get_world_size() + # device topology + self.topo = None + + instance = None + + def __init__(self): + if not EnvResource.instance: + EnvResource.instance = EnvResource.__EnvResource() + + def __getattr__(self, name): + return getattr(self.instance, name) diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index c882ff22..229110ec 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -1,7 +1,8 @@ from typing import Callable, Optional import torch -from cube.graph.graph import IRGraph +import cube +from cube.graph.graph import IRGraph from cube.schedule.pool import SchedulePool from cube.schedule.su import SUType from cube.schedule.translator import IRDataLoader @@ -99,8 +100,10 @@ def decorator(fn: Callable) -> Callable: batch_size = torch.tensor([-1], dtype=torch.int).cuda() if myrank == 0: SchedulePool().clear() + resource = cube.runtime.resource.EnvResource() # logic translator + # print(f'> ir_graph:\n{ir_graph}') fn(ir_graph, ir_dataloader) nodes = SchedulePool().nodes() @@ -108,13 +111,13 @@ def decorator(fn: Callable) -> Callable: # graph transformation graph = IRGraph(nodes, None, None, ir_graph.name) if transform_policy: - graph = transform_policy(graph, None) + graph = transform_policy(graph, resource) # sugraph sugraph = SUGraphGener.gen_sugraph(graph.nodes()) if schedule_policy: # TODO: add resource - sugraph = schedule_policy(sugraph, None) + sugraph = schedule_policy(sugraph, resource) # check assignment and order # print(sugraph) @@ -126,8 +129,9 @@ def decorator(fn: Callable) -> Callable: # graph pass to remove redundant sus sugraph = SUGraphPass.remove_redundant_adapters(sugraph) + # print(f'> after remove redundant adapters:\n {sugraph}') sugraph = SUGraphPass.merge_small_sus(sugraph) - print(f'> after merge small sus:\n {sugraph}') + # print(f'> after merge small sus:\n {sugraph}') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 9dc17d36..e1023b65 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -170,7 +170,7 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: if su2 not in fsus: raise RuntimeError(f"SU2: {su2} not in forward SUs") - idx1, idx2 = fsus.index(su1), fsus.index(su2) + idx1, idx2 = self.sequence.index(su1), self.sequence.index(su2) su1, su2 = (su1, su2) if idx1 < idx2 else (su2, su1) # condition 1): same device @@ -179,7 +179,8 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: # condition 2): su2 input cannot be got from both su1 and other su start, stop = min(idx1, idx2), max(idx1, idx2) - inter_sus = fsus[start+1:stop] + inter_sus = self.sequence[start+1:stop] + inter_sus = [su for su in inter_sus if su.stype != SUType.Comm] for su in inter_sus: # FIXME: currently only allow other device su exists if self.happen_before(su1, su) or self.happen_before(su, su2): diff --git a/examples/linears.py b/examples/linears.py index f3e0aef3..ef6f4095 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -15,7 +15,6 @@ from torch import nn import cube -from cube.graph.operator.operator import IRDataOperation, IRFwOperation from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraph @@ -24,16 +23,18 @@ def transform_policy(graph, resource): """ The transformation policy transposes linear using data parallel """ + from cube.graph.operator.operator import IRDataOperation, IRFwOperation for node in graph.nodes(): if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): algo = node.algorithms('data') - graph.partition(node, algo, config=dict(chunk_num=2)) + assert algo is not None + graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) return graph def schedule_policy(sugraph: SUGraph, resource): """ - The schedule policy uses 1F1B (interleaved) pipeline + The schedule policy """ fb_seqs = list() for fsu in sugraph.fsus(): @@ -53,14 +54,12 @@ def schedule_policy(sugraph: SUGraph, resource): if su.stype == SUType.Dataloader: sugraph.assign(su, 0) + print(f'> collect {len(fb_seqs)} forward-backward sequence') for fb_seq in fb_seqs: for idx, su in enumerate(fb_seq): - if idx < 2: - sugraph.assign(su, 0) - sugraph.assign(su.mirror, 0) - else: - sugraph.assign(su, 1) - sugraph.assign(su.mirror, 1) + devid = idx % resource.ngpus + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) # set partial order for fb_seq in fb_seqs: @@ -121,5 +120,5 @@ def train_iter(model, dataloader): if __name__ == '__main__': - cube.DeviceGroup() + cube.init() train() \ No newline at end of file From 8ba463eeb72783c3a96e9a0ad9c18fdad41d5de0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 20:47:42 +0800 Subject: [PATCH 0272/1892] add execution plan --- cube/execplan/__init__.py | 1 + cube/{schedule => execplan}/execplan.py | 44 ++++++++++-- cube/execplan/planpass/__init__.py | 0 cube/execplan/planpass/merge.py | 85 +++++++++++++++++++++++ cube/execplan/planpass/planpass.py | 8 +++ cube/execplan/planpass/redundant.py | 30 ++++++++ tests/execplan/test_planpass_merge.py | 84 ++++++++++++++++++++++ tests/execplan/test_planpass_redundant.py | 78 +++++++++++++++++++++ 8 files changed, 326 insertions(+), 4 deletions(-) create mode 100644 cube/execplan/__init__.py rename cube/{schedule => execplan}/execplan.py (75%) create mode 100644 cube/execplan/planpass/__init__.py create mode 100644 cube/execplan/planpass/merge.py create mode 100644 cube/execplan/planpass/planpass.py create mode 100644 cube/execplan/planpass/redundant.py create mode 100644 tests/execplan/test_planpass_merge.py create mode 100644 tests/execplan/test_planpass_redundant.py diff --git a/cube/execplan/__init__.py b/cube/execplan/__init__.py new file mode 100644 index 00000000..a1160701 --- /dev/null +++ b/cube/execplan/__init__.py @@ -0,0 +1 @@ +from cube.execplan.execplan import ExectuionPlan \ No newline at end of file diff --git a/cube/schedule/execplan.py b/cube/execplan/execplan.py similarity index 75% rename from cube/schedule/execplan.py rename to cube/execplan/execplan.py index 34094600..fe4d4a9a 100644 --- a/cube/schedule/execplan.py +++ b/cube/execplan/execplan.py @@ -1,4 +1,5 @@ from typing import List, Optional +import copy from cube.schedule.sugraph import SUGraph from cube.schedule.su import SUType, ScheduleUnit @@ -6,17 +7,52 @@ class ExectuionPlan: - def __init__(self, seq: SUGraph): - - self.seq = seq + def __init__(self, sugraph: SUGraph): + if not isinstance(sugraph, SUGraph): + raise TypeError("Expected a list of ScheduleUnit") + self.sugraph = sugraph self.device_seq = dict() - for su in seq.sequence: + for su in sugraph.sus(): device = su.device[0] if device not in self.device_seq: self.device_seq[device] = [su] else: self.device_seq[device].append(su) + def devices(self) -> List[int]: + """ + Get device set + """ + return self.device_seq.keys() + + def sequence(self, device_id: int) -> List[ScheduleUnit]: + """ + Get a copy of execution sequence for device id + + Note changing the list content will not change the execution plan. + """ + if device_id not in self.device_seq: + return list() + return copy.copy(self.device_seq[device_id]) + + def at(self, device_id: int) -> List[ScheduleUnit]: + """ + Access the sequence for device id + + Note changing the list content will change the execution plan. + """ + if device_id not in self.device_seq: + return list() + return self.device_seq[device_id] + + def set(self, device_id: int, seq: List[ScheduleUnit]): + """ + Set device sequence + """ + if not all([isinstance(su, ScheduleUnit) for su in seq]): + raise TypeError("Expected a list of ScheduleUnit") + self.device_seq[device_id] = seq + def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): """ Draw the execution timeline. diff --git a/cube/execplan/planpass/__init__.py b/cube/execplan/planpass/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py new file mode 100644 index 00000000..df3eff69 --- /dev/null +++ b/cube/execplan/planpass/merge.py @@ -0,0 +1,85 @@ +from typing import List + +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.planpass import PlanPass +from cube.graph.operator.operator import IRBpOperation +from cube.schedule.su import SUType, ScheduleUnit + + +class MergeComputeAdapters(PlanPass): + + @staticmethod + def apply(execplan: ExectuionPlan) -> ExectuionPlan: + """ + Merge consecutive forward SUs + """ + for devid in execplan.devices(): + dev_seq = execplan.sequence(devid) + pieces: List[ScheduleUnit] = list() + for seqidx, su in enumerate(execplan.sequence(devid)): + if su.stype in [SUType.Forward]: + allow_merge = len(pieces) == 0 + for psu in pieces[::-1]: + if execplan.sugraph.happen_before(psu, su): + allow_merge = True + break + if allow_merge: + dev_seq[seqidx] = None + if su.mirror is not None: + if su.mirror not in dev_seq: + raise RuntimeError( + "Expected backward and forward on same device") + idx = dev_seq.index(su.mirror) + dev_seq[idx] = None + pieces.append(su) + continue + # merge pieces + if len(pieces) > 1: + # merged forward su + mfsu = MergeComputeAdapters._merge(pieces, devid) + mbsu = mfsu.mirror + # insert merged forward su + dev_seq[seqidx-1] = mfsu + # insert merged backward su + bidx = len(dev_seq) + for fsu in pieces: + bsu = fsu.mirror + if bsu is not None: + idx = execplan.sequence(devid).index(bsu) + dev_seq[idx] = None + bidx = min(bidx, idx) + if bidx != len(dev_seq): + dev_seq[bidx] = mbsu + pieces = list() + dev_seq = [su for su in dev_seq if su is not None] + execplan.set(devid, dev_seq) + return execplan + + @staticmethod + def _merge(pieces: List[ScheduleUnit], devid: int) -> ScheduleUnit: + """ + Merge a list of SU into one. + """ + fnodes = list() + for fsu in pieces: + fnodes += fsu.nodes() + # TODO: fix multi-branch + mfsu = ScheduleUnit(fnodes, SUType.Forward, name='fsu') + mfsu.device = devid + + # merged backward su + mbnode = IRBpOperation( + data_num=len(mfsu.inputs()), + grad_num=len(mfsu.outputs()) + ) + for idx, fin in enumerate(mfsu.inputs()): + mbnode.set_data(idx, fin) + mbnode.set_output(idx, fin.grad) + for idx, fout in enumerate(mfsu.outputs()): + mbnode.set_grad(idx, fout.grad) + mbsu = ScheduleUnit([mbnode], SUType.Backward, name='bsu') + mbsu.device = devid + + mfsu.mirror = mbsu + mbsu.mirror = mfsu + return mfsu diff --git a/cube/execplan/planpass/planpass.py b/cube/execplan/planpass/planpass.py new file mode 100644 index 00000000..558de959 --- /dev/null +++ b/cube/execplan/planpass/planpass.py @@ -0,0 +1,8 @@ +from cube.execplan import ExectuionPlan + + +class PlanPass: + + @staticmethod + def apply(execplan: ExectuionPlan) -> ExectuionPlan: + raise NotImplementedError diff --git a/cube/execplan/planpass/redundant.py b/cube/execplan/planpass/redundant.py new file mode 100644 index 00000000..aef9c8ed --- /dev/null +++ b/cube/execplan/planpass/redundant.py @@ -0,0 +1,30 @@ +from cube.execplan import ExectuionPlan +from cube.schedule.su import SUType +from cube.execplan.planpass.planpass import PlanPass + + +class RemoveRedundantAdapters(PlanPass): + + @staticmethod + def apply(execplan: ExectuionPlan) -> ExectuionPlan: + """ + Remove redundant adapters + + A redundant adapter is sending / recving tensors on the same deivce + """ + for devid in execplan.devices(): + seq = execplan.sequence(devid) + comms = [su for su in seq if su.stype == SUType.Comm] + for comm in comms: + send_ranks = set([devid]) + recv_ranks = set([devid]) + for node in comm.nodes(): + send_ranks.update(node.send_ranks) + recv_ranks.update(node.recv_ranks) + if list(send_ranks) != [devid]: + continue + if list(recv_ranks) != [devid]: + continue + # remove + execplan.at(devid).remove(comm) + return execplan diff --git a/tests/execplan/test_planpass_merge.py b/tests/execplan/test_planpass_merge.py new file mode 100644 index 00000000..4d240563 --- /dev/null +++ b/tests/execplan/test_planpass_merge.py @@ -0,0 +1,84 @@ +from cube.graph.tensor import IRFullTensor +from cube.graph.operator.function import Linear +from cube.graph.graph import IRGraph + +from cube.schedule.pool import SchedulePool +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraphGener + +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.redundant import RemoveRedundantAdapters +from cube.execplan.planpass.merge import MergeComputeAdapters + + +def construct_graph(): + + input = IRFullTensor(shape=[64,1024], name='data') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = Linear( + name='linear1', + signature='torch.nn.functional.linear', + inputs= [input, weight1, bias1], + ) + linear1.infer_shape() + + # linear2 + linear2 = Linear( + name='linear2', + signature='torch.nn.functional.linear', + inputs= [linear1.outputs(0), weight2, None], + ) + linear2.infer_shape() + + # linear3 + linear3 = Linear( + name='linear3', + signature='torch.nn.functional.linear', + inputs= [linear2.outputs(0), weight3, bias3], + ) + linear3.infer_shape() + + graph = IRGraph( + nodes=[linear1, linear2, linear3], + input_tensors=[input], + output_tensors=linear3.outputs(), + module_name="Test" + ) + return graph + + +def test_planpass_merge(): + SchedulePool().clear() + + graph = construct_graph() + data = IRFullTensor(shape=[64,1024], name='data').tosub() + output = graph(data) + output.backward() + + nodes = SchedulePool().nodes() + sugraph = SUGraphGener.gen_sugraph(nodes) + + for su in sugraph.sus(): + if su.stype != SUType.Comm: + sugraph.assign(su, 0) + + print('orignal:') + print(sugraph) + + execplan = ExectuionPlan(sugraph) + execplan = RemoveRedundantAdapters.apply(execplan) + execplan = MergeComputeAdapters.apply(execplan) + + print('merged:') + for devid in execplan.devices(): + print(f'> device {devid}') + for su in execplan.sequence(devid): + print(su) + assert su.stype != SUType.Comm + assert len(execplan.sequence(0)) == 2 \ No newline at end of file diff --git a/tests/execplan/test_planpass_redundant.py b/tests/execplan/test_planpass_redundant.py new file mode 100644 index 00000000..d63c99fb --- /dev/null +++ b/tests/execplan/test_planpass_redundant.py @@ -0,0 +1,78 @@ +from cube.graph.tensor import IRFullTensor +from cube.graph.operator.function import Linear +from cube.graph.graph import IRGraph + +from cube.schedule.pool import SchedulePool +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraphGener + +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.redundant import RemoveRedundantAdapters + + +def construct_graph(): + + input = IRFullTensor(shape=[64,1024], name='data') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = Linear( + name='linear1', + signature='torch.nn.functional.linear', + inputs= [input, weight1, bias1], + ) + linear1.infer_shape() + + # linear2 + linear2 = Linear( + name='linear2', + signature='torch.nn.functional.linear', + inputs= [linear1.outputs(0), weight2, None], + ) + linear2.infer_shape() + + # linear3 + linear3 = Linear( + name='linear3', + signature='torch.nn.functional.linear', + inputs= [linear2.outputs(0), weight3, bias3], + ) + linear3.infer_shape() + + graph = IRGraph( + nodes=[linear1, linear2, linear3], + input_tensors=[input], + output_tensors=linear3.outputs(), + module_name="Test" + ) + return graph + + +def test_remove_adapter(): + + SchedulePool().clear() + + graph = construct_graph() + data = IRFullTensor(shape=[64,1024], name='data').tosub() + output = graph(data) + output.backward() + + nodes = SchedulePool().nodes() + sugraph = SUGraphGener.gen_sugraph(nodes) + + for su in sugraph.sus(): + sugraph.assign(su, 0) + + execplan = ExectuionPlan(sugraph) + execplan = RemoveRedundantAdapters.apply(execplan) + + for devid in execplan.devices(): + print(f'> device {devid}') + for su in execplan.sequence(devid): + print(su) + assert su.stype != SUType.Comm + assert len(execplan.sequence(0)) == 6 \ No newline at end of file From a9d25f765929894c1934f7c6539461ec4f106723 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 20:55:42 +0800 Subject: [PATCH 0273/1892] fix naming typo --- cube/execplan/execplan.py | 2 ++ cube/execplan/planpass/merge.py | 4 ++-- tests/execplan/test_planpass_merge.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index fe4d4a9a..f8e941f6 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -13,6 +13,8 @@ def __init__(self, sugraph: SUGraph): self.sugraph = sugraph self.device_seq = dict() for su in sugraph.sus(): + if len(su.device) == 0: + raise RuntimeError(f"device not set: SU {su}") device = su.device[0] if device not in self.device_seq: self.device_seq[device] = [su] diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index df3eff69..b75e6863 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -6,7 +6,7 @@ from cube.schedule.su import SUType, ScheduleUnit -class MergeComputeAdapters(PlanPass): +class MergeComputeSU(PlanPass): @staticmethod def apply(execplan: ExectuionPlan) -> ExectuionPlan: @@ -36,7 +36,7 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: # merge pieces if len(pieces) > 1: # merged forward su - mfsu = MergeComputeAdapters._merge(pieces, devid) + mfsu = MergeComputeSU._merge(pieces, devid) mbsu = mfsu.mirror # insert merged forward su dev_seq[seqidx-1] = mfsu diff --git a/tests/execplan/test_planpass_merge.py b/tests/execplan/test_planpass_merge.py index 4d240563..936e912a 100644 --- a/tests/execplan/test_planpass_merge.py +++ b/tests/execplan/test_planpass_merge.py @@ -8,7 +8,7 @@ from cube.execplan import ExectuionPlan from cube.execplan.planpass.redundant import RemoveRedundantAdapters -from cube.execplan.planpass.merge import MergeComputeAdapters +from cube.execplan.planpass.merge import MergeComputeSU def construct_graph(): @@ -73,7 +73,7 @@ def test_planpass_merge(): execplan = ExectuionPlan(sugraph) execplan = RemoveRedundantAdapters.apply(execplan) - execplan = MergeComputeAdapters.apply(execplan) + execplan = MergeComputeSU.apply(execplan) print('merged:') for devid in execplan.devices(): From 195aa549a39b939818fd11edf37a142d2c04cb0a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Nov 2021 23:06:35 +0800 Subject: [PATCH 0274/1892] switch to execution plan --- cube/codegen/codegen.py | 34 ++++++++++----------- cube/execplan/execplan.py | 11 ++++++- cube/execplan/planpass/merge.py | 4 ++- cube/schedule/__init__.py | 21 +++++++------ cube/schedule/graphpass.py | 53 --------------------------------- tests/codegen/test_codegen.py | 25 ++++++++++------ 6 files changed, 57 insertions(+), 91 deletions(-) delete mode 100644 cube/schedule/graphpass.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 3ad0c34e..8aa7616b 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -6,7 +6,8 @@ import copy from cube.ir.cten import IRTensor -from cube.schedule.sugraph import SUGraph +from cube.execplan import ExectuionPlan + from cube.schedule.su import ScheduleUnit, SUType from cube.schedule.adapter.comm import IRCommType, IRCommunication from cube.schedule.adapter.select import IRTensorReshape, IRReshapeType @@ -19,13 +20,10 @@ class ModelCodeGen: Generate spatial code for the model """ - def __init__(self, sugraph: SUGraph): - if not isinstance(sugraph, SUGraph): - raise TypeError("sugraph should be SUGraph") - for su in sugraph.sus(): - if len(su.device) == 0: - raise RuntimeError(f"SU: {su} is not assigned to device") - self.seq = sugraph + def __init__(self, execplan: ExectuionPlan): + if not isinstance(execplan, ExectuionPlan): + raise TypeError("execplan should be ExecutionPlan") + self.execplan = execplan # model full code self.init_code: List[str] = [ '\n\n########## Generated Model Code ###########', @@ -44,9 +42,9 @@ def gen(self, device: int, outfile=None, attach=False) -> str: """ Generate model implementation code based on the given graph. """ - device_sus = [su for su in self.seq.sus() \ - if device in su.device \ - and su.stype != SUType.Backward \ + device_sus = self.execplan.sequence(device) + device_sus = [su for su in device_sus \ + if su.stype != SUType.Backward \ and su.stype != SUType.Dataloader] gencode = copy.copy(self.init_code) @@ -89,7 +87,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: cb.insert_body('') cb.insert_body(ib.code) for idx, su in enumerate(device_sus): - name = f'su{self.seq.sus().index(su)}' + name = f'su{su._id}' input_args = ['self'] + su_args[idx] forward_code = self.all_su_forward_region[idx] with FunctionBlock(func_name=name, args=input_args) as fb: @@ -238,10 +236,10 @@ def clear(self): class ScheduleCodeGen: - def __init__(self, seq: SUGraph): - if not isinstance(seq, SUGraph): - raise TypeError("seq should be SUGraph") - self.seq = seq + def __init__(self, execplan: ExectuionPlan): + if not isinstance(execplan, ExectuionPlan): + raise TypeError("execplan should be ExecutionPlan") + self.execplan = execplan # model full code self.init_code: List[str] = [ '\n\n########## Generated Schedule Code ###########', @@ -254,13 +252,13 @@ def gen(self, device: int, outfile=None, attach=False) -> str: Generate scheduling code based on the given sus """ gencode = copy.copy(self.init_code) - device_sus = [su for su in self.seq.sus() if device in su.device] + device_sus = self.execplan.sequence(device) # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: for su in device_sus: - name = f'su{self.seq.sus().index(su)}' + name = f'su{su._id}' code = self.emit_su(su, name=name) fb.insert_body(code) gencode += fb.code diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index f8e941f6..f0dee2ce 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -147,4 +147,13 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): ax.annotate(anno, (cx, cy), color='w', weight='bold', fontsize=10, ha='center', va='center') # plt.grid() - plt.savefig(outfile) \ No newline at end of file + plt.savefig(outfile) + + + def __repr__(self): + dscp = f'Execution Plan ({self.sugraph.name}):\n' + for devid in self.devices(): + dscp += f'====> Device {devid}:\n' + for su in self.sequence(devid): + dscp += f'{su}\n' + return dscp diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index b75e6863..fe54f073 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -34,7 +34,7 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: pieces.append(su) continue # merge pieces - if len(pieces) > 1: + if len(pieces) > 0: # merged forward su mfsu = MergeComputeSU._merge(pieces, devid) mbsu = mfsu.mirror @@ -60,6 +60,8 @@ def _merge(pieces: List[ScheduleUnit], devid: int) -> ScheduleUnit: """ Merge a list of SU into one. """ + if len(pieces) == 1: + return pieces[0] fnodes = list() for fsu in pieces: fnodes += fsu.nodes() diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index 229110ec..1a9ee3c1 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -7,7 +7,10 @@ from cube.schedule.su import SUType from cube.schedule.translator import IRDataLoader from cube.schedule.sugraph import SUGraph, SUGraphGener -from cube.schedule.graphpass import SUGraphPass + +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.redundant import RemoveRedundantAdapters +from cube.execplan.planpass.merge import MergeComputeSU from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -116,7 +119,6 @@ def decorator(fn: Callable) -> Callable: # sugraph sugraph = SUGraphGener.gen_sugraph(graph.nodes()) if schedule_policy: - # TODO: add resource sugraph = schedule_policy(sugraph, resource) # check assignment and order @@ -127,11 +129,12 @@ def decorator(fn: Callable) -> Callable: if not SUGraph.is_topo_order(sugraph.sus()): raise RuntimeError(f"SUGraph order is not topological order") - # graph pass to remove redundant sus - sugraph = SUGraphPass.remove_redundant_adapters(sugraph) - # print(f'> after remove redundant adapters:\n {sugraph}') - sugraph = SUGraphPass.merge_small_sus(sugraph) - # print(f'> after merge small sus:\n {sugraph}') + execplan = ExectuionPlan(sugraph) + # plan pass to remove redundant sus + execplan = RemoveRedundantAdapters.apply(execplan) + # print(f'> after remove redundant adapters:\n {execplan}') + execplan = MergeComputeSU.apply(execplan) + # print(f'> after merge compute SU:\n{execplan}') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() @@ -139,8 +142,8 @@ def decorator(fn: Callable) -> Callable: world_size = 1 # code generation - mgener = ModelCodeGen(sugraph) - sgener = ScheduleCodeGen(sugraph) + mgener = ModelCodeGen(execplan) + sgener = ScheduleCodeGen(execplan) for rank in range(world_size): fname = filename.format(rank) # generate spatial module code diff --git a/cube/schedule/graphpass.py b/cube/schedule/graphpass.py deleted file mode 100644 index 1149e248..00000000 --- a/cube/schedule/graphpass.py +++ /dev/null @@ -1,53 +0,0 @@ -from cube.schedule.sugraph import SUGraph -from cube.schedule.su import SUType, ScheduleUnit - - -class SUGraphPass: - - @staticmethod - def remove_redundant_adapters(sugraph: SUGraph) -> SUGraph: - """ - Remove redundant adapters - - A redundant adapter is sending and recving - on the same device - """ - redundant_adapters = list() - for su in sugraph.sus(): - if su.stype != SUType.Comm: - for idx in range(len(su.outputs())): - send_adapters, recv_adapters = su.out_adapters(idx) - for sadapter, radapter in zip(send_adapters, recv_adapters): - # indicate a tensor selection in-device - if sadapter.device == radapter.device: - if len(sadapter.inputs()) != 1: - raise NotImplementedError - # indicate identity op: - if sadapter.inputs(0).shape == su.outputs(idx).shape: - redundant_adapters.append(sadapter) - redundant_adapters.append(radapter) - - all_sus = sugraph.sus() - for adapter in redundant_adapters: - if adapter in all_sus: - all_sus.remove(adapter) - - sugraph = SUGraph(all_sus) - return sugraph - - @staticmethod - def merge_small_sus(sugraph: SUGraph) -> SUGraph: - """ - Merge SU to a larger one if possible - """ - devices = set() - for su in sugraph.sus(): - devices.update(set(su.device)) - for device in devices: - dev_sus = [su for su in sugraph.fsus() if device in su.device] - merged_su = dev_sus[0] - for su in dev_sus[1:]: - merged_su = sugraph.merge(merged_su, su) - if merged_su is None: - merged_su = su - return sugraph diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py index d46bb9eb..fc05862f 100644 --- a/tests/codegen/test_codegen.py +++ b/tests/codegen/test_codegen.py @@ -3,10 +3,14 @@ from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraphGener from cube.schedule.translator import IRDataLoader -from cube.schedule.graphpass import SUGraphPass + +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.redundant import RemoveRedundantAdapters +from cube.execplan.planpass.merge import MergeComputeSU + import torch from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -151,16 +155,19 @@ def test_model_gen(): for su in seqs: print(su) sugraph.partial_set_order(seqs) - print('========= after reorder: ==========\n', sugraph) - sugraph = SUGraphPass.remove_redundant_adapters(sugraph) - print('========= after remove adapter: ==========\n', sugraph) + # print('========= after reorder: ==========\n', sugraph) + + execplan = ExectuionPlan(sugraph) + execplan = RemoveRedundantAdapters.apply(execplan) + + # print('========= after remove adapter: ==========\n', execplan) - sugraph = SUGraphPass.merge_small_sus(sugraph) - print('========= after merge small SU: ==========\n', sugraph) + execplan = MergeComputeSU.apply(execplan) + # print('========= after merge small SU: ==========\n', execplan) - mgener = ModelCodeGen(sugraph) - tgener = ScheduleCodeGen(sugraph) + mgener = ModelCodeGen(execplan) + tgener = ScheduleCodeGen(execplan) mcode0 = mgener.gen(device = 0) tcode0 = tgener.gen(device = 0) From c08734009ef515afb137f4ed199a35d2c58788b6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 11 Nov 2021 10:55:37 +0800 Subject: [PATCH 0275/1892] clean up code --- cube/__init__.py | 5 +- cube/compiler.py | 191 ++++++++++++++++++++++++++++++++++++++ cube/schedule/__init__.py | 180 +---------------------------------- cube/schedule/plan.py | 131 -------------------------- examples/linears.py | 4 +- examples/poc/torchfx.py | 167 --------------------------------- 6 files changed, 197 insertions(+), 481 deletions(-) create mode 100644 cube/compiler.py delete mode 100644 cube/schedule/plan.py delete mode 100644 examples/poc/torchfx.py diff --git a/cube/__init__.py b/cube/__init__.py index 4d9d894e..1f0f0c33 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,8 +1,9 @@ -from cube.runtime.device import DeviceGroup from cube import schedule from cube import runtime +from cube.compiler import SemanticModel, compile + def init(): - _ = DeviceGroup() + _ = runtime.device.DeviceGroup() _ = runtime.resource.EnvResource() diff --git a/cube/compiler.py b/cube/compiler.py new file mode 100644 index 00000000..854aec6a --- /dev/null +++ b/cube/compiler.py @@ -0,0 +1,191 @@ +from typing import Callable, Optional, Tuple +import torch + +import cube +from cube.graph.graph import IRGraph +from cube.schedule.pool import SchedulePool +from cube.schedule.su import SUType +from cube.schedule.translator import IRDataLoader +from cube.schedule.sugraph import SUGraph, SUGraphGener + +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.redundant import RemoveRedundantAdapters +from cube.execplan.planpass.merge import MergeComputeSU + +from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen + + +class SemanticModel: + + def __init__(self, model: torch.nn.Module, input_shapes): + """ + Create semantic model based on AI Scientist description. + """ + from cube.graph import parser + self.ir_graph = parser.convert( + model, input_shapes=input_shapes + ) + self._loaded_module = None + + def get_graph(self): + return self.ir_graph + + def load_module(self, filename: str): + import importlib.util + print(f'> loading generated spatial moduel from {filename}') + spec = importlib.util.spec_from_file_location("GenModel", filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._loaded_module = module.GenModel().cuda() + + def get_gen_module(self): + return self._loaded_module + + def clear_module(self): + self._loaded_module = None + + def __call__(self, *args): + if self._loaded_module: + return self._loaded_module(*args) + else: + return self.ir_graph(*args) + + +def compile(model: SemanticModel, dataloader, + policy: Tuple[Optional[Callable], Optional[Callable]] = (None, None)): + """ + AI Scientist calls like: + + @cube.compile(model, dataloader, policy=(trans_policy, schedule_policy)) + def train_step(model, dataloader): + # do a 4-time gradient accumulation + for acc_step, (data, label) in enumerate(dataloader): + if acc_step < 4: + loss = model(data, label) + loss.backward() + else: + break + ... + + for epoch in range(100): + train_step(model, data_loader) + optimizer.step() + optimizer.zero_grad() + + ... + + Args: + model: AI Scientist specified SemanticModel + dataloader: dataloader used for training + policy: tuple of transformation policy and scheduling policy + """ + if not isinstance(model, SemanticModel): + raise TypeError("Expect Semantic Model") + if len(policy) != 2: + raise TypeError( + "Expected policy to be tuple of transformation + scheduling policy" + ) + transform_policy, schedule_policy = policy + + ir_graph = model.get_graph() + ir_dataloader = IRDataLoader(dataloader) + + if torch.distributed.is_initialized(): + # multiple device + myrank = torch.distributed.get_rank() + else: + # single device + myrank = 0 + + def _load_tschedule_fn(filename) -> Callable: + import importlib.util + print(f'> [{myrank}] loading generated schedule from {filename} ...') + spec = importlib.util.spec_from_file_location( + "_train_step", filename + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module._train_step + + def decorator(fn: Callable) -> Callable: + filename = 'gencode{}.py' + batch_size = torch.tensor([-1], dtype=torch.int).cuda() + if myrank == 0: + SchedulePool().clear() + resource = cube.runtime.resource.EnvResource() + + # logic translator + # print(f'> ir_graph:\n{ir_graph}') + fn(ir_graph, ir_dataloader) + + nodes = SchedulePool().nodes() + + # graph transformation + graph = IRGraph(nodes, None, None, ir_graph.name) + if transform_policy: + graph = transform_policy(graph, resource) + + # sugraph + sugraph = SUGraphGener.gen_sugraph(graph.nodes()) + if schedule_policy: + sugraph = schedule_policy(sugraph, resource) + + # check assignment and order + # print(sugraph) + for su in sugraph.sus(): + if len(su.device) == 0: + raise RuntimeError(f"SU {su} device is not set") + if not SUGraph.is_topo_order(sugraph.sus()): + raise RuntimeError(f"SUGraph order is not topological order") + + execplan = ExectuionPlan(sugraph) + # plan pass to remove redundant sus + execplan = RemoveRedundantAdapters.apply(execplan) + # print(f'> after remove redundant adapters:\n {execplan}') + execplan = MergeComputeSU.apply(execplan) + # print(f'> after merge compute SU:\n{execplan}') + + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + + # code generation + mgener = ModelCodeGen(execplan) + sgener = ScheduleCodeGen(execplan) + for rank in range(world_size): + fname = filename.format(rank) + # generate spatial module code + mgener.gen(rank, outfile=fname, attach=False) + # generate temporal schedule code + sgener.gen( + device = rank, + outfile = fname, + attach=True + ) + # get dataloader batch size + data = None + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + data = su.outputs(0) + break + if data is None: + raise RuntimeError("dataloader not found in SUGraph") + # assume batch_size is always first dimension + batch_size = torch.tensor([data.shape[0]], dtype=torch.int).cuda() + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + # reset dataloader + torch.distributed.broadcast(batch_size, src=0) + batch_size = batch_size.item() + print(f'> reseting dataloader batch size to {batch_size}') + dataloader.reset(batch_size=batch_size) + + # load module + model.load_module(filename.format(myrank)) + # load temporal + return _load_tschedule_fn(filename.format(myrank)) + + return decorator diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py index 1a9ee3c1..23058c3a 100644 --- a/cube/schedule/__init__.py +++ b/cube/schedule/__init__.py @@ -1,182 +1,4 @@ -from typing import Callable, Optional -import torch - -import cube -from cube.graph.graph import IRGraph from cube.schedule.pool import SchedulePool from cube.schedule.su import SUType from cube.schedule.translator import IRDataLoader -from cube.schedule.sugraph import SUGraph, SUGraphGener - -from cube.execplan import ExectuionPlan -from cube.execplan.planpass.redundant import RemoveRedundantAdapters -from cube.execplan.planpass.merge import MergeComputeSU - -from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen - - -class SemanticModel: - - def __init__(self, model: torch.nn.Module, input_shapes): - """ - Create semantic model based on AI Scientist description. - """ - from cube.graph import parser - self.ir_graph = parser.convert( - model, input_shapes=input_shapes - ) - self._loaded_module = None - - def get_graph(self): - return self.ir_graph - - def load_module(self, filename: str): - import importlib.util - print(f'> loading generated spatial moduel from {filename}') - spec = importlib.util.spec_from_file_location("GenModel", filename) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self._loaded_module = module.GenModel().cuda() - - def get_gen_module(self): - return self._loaded_module - - def clear_module(self): - self._loaded_module = None - - def __call__(self, *args): - if self._loaded_module: - return self._loaded_module(*args) - else: - return self.ir_graph(*args) - - -def schedule(model: SemanticModel, dataloader, - transform_policy: Optional[Callable] = None, - schedule_policy: Optional[Callable] = None): - """ - AI Scientist calls like: - - @cube.tschedule.schedule(model, dataloader, policy_fn=policy) - def train_step(model, dataloader): - # do a 4-time gradient accumulation - for acc_step, (data, label) in enumerate(dataloader): - if acc_step < 4: - loss = model(data, label) - loss.backward() - else: - break - ... - - for epoch in range(100): - train_step(model, data_loader) - optimizer.step() - optimizer.zero_grad() - - ... - """ - if not isinstance(model, SemanticModel): - raise TypeError("Expect Semantic Model") - - ir_graph = model.get_graph() - ir_dataloader = IRDataLoader(dataloader) - - if torch.distributed.is_initialized(): - # multiple device - myrank = torch.distributed.get_rank() - else: - # single device - myrank = 0 - - def _load_tschedule_fn(filename) -> Callable: - import importlib.util - print(f'> [{myrank}] loading generated schedule from {filename} ...') - spec = importlib.util.spec_from_file_location( - "_train_step", filename - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module._train_step - - def decorator(fn: Callable) -> Callable: - filename = 'gencode{}.py' - batch_size = torch.tensor([-1], dtype=torch.int).cuda() - if myrank == 0: - SchedulePool().clear() - resource = cube.runtime.resource.EnvResource() - - # logic translator - # print(f'> ir_graph:\n{ir_graph}') - fn(ir_graph, ir_dataloader) - - nodes = SchedulePool().nodes() - - # graph transformation - graph = IRGraph(nodes, None, None, ir_graph.name) - if transform_policy: - graph = transform_policy(graph, resource) - - # sugraph - sugraph = SUGraphGener.gen_sugraph(graph.nodes()) - if schedule_policy: - sugraph = schedule_policy(sugraph, resource) - - # check assignment and order - # print(sugraph) - for su in sugraph.sus(): - if len(su.device) == 0: - raise RuntimeError(f"SU {su} device is not set") - if not SUGraph.is_topo_order(sugraph.sus()): - raise RuntimeError(f"SUGraph order is not topological order") - - execplan = ExectuionPlan(sugraph) - # plan pass to remove redundant sus - execplan = RemoveRedundantAdapters.apply(execplan) - # print(f'> after remove redundant adapters:\n {execplan}') - execplan = MergeComputeSU.apply(execplan) - # print(f'> after merge compute SU:\n{execplan}') - - if torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size() - else: - world_size = 1 - - # code generation - mgener = ModelCodeGen(execplan) - sgener = ScheduleCodeGen(execplan) - for rank in range(world_size): - fname = filename.format(rank) - # generate spatial module code - mgener.gen(rank, outfile=fname, attach=False) - # generate temporal schedule code - sgener.gen( - device = rank, - outfile = fname, - attach=True - ) - # get dataloader batch size - data = None - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - data = su.outputs(0) - break - if data is None: - raise RuntimeError("dataloader not found in SUGraph") - # assume batch_size is always first dimension - batch_size = torch.tensor([data.shape[0]], dtype=torch.int).cuda() - - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - # reset dataloader - torch.distributed.broadcast(batch_size, src=0) - batch_size = batch_size.item() - print(f'> reseting dataloader batch size to {batch_size}') - dataloader.reset(batch_size=batch_size) - - # load module - model.load_module(filename.format(myrank)) - # load temporal - return _load_tschedule_fn(filename.format(myrank)) - - return decorator +from cube.schedule.sugraph import SUGraph diff --git a/cube/schedule/plan.py b/cube/schedule/plan.py deleted file mode 100644 index 9190aeba..00000000 --- a/cube/schedule/plan.py +++ /dev/null @@ -1,131 +0,0 @@ - -class ExecutionPlan: - - def __init__(self, seq, ndevice): - """ - Seq: action sequence - ndevice: device number - """ - self.seq = seq - self.ndevice = ndevice - self.device_timeline = None - self.device_actions = None - - def gen(self): - """ - Generate execution plan - """ - # timeline: [(start_time, end_time)] - self.device_timeline = [list() for _ in range(self.ndevice)] - self.device_actions = [list() for _ in range(self.ndevice)] - - for action in self.seq: - if action.device == -1 or action.device >= self.ndevice: - raise RuntimeError("action {} device not assigned or out of boundary".format(action)) - if len(self.device_timeline[action.device]) == 0: - start_time = 1 - else: - start_time = self.device_timeline[action.device][-1][1] - for dev_id, (timeline, dev_actions) in enumerate(zip(self.device_timeline, self.device_actions)): - if dev_id == action.device: - continue - # go through to check if the action has dependencies - for (_, end_time), dev_action in zip(timeline[::-1], dev_actions[::-1]): - if action.depends_on(dev_action): - # print('find dependency {} -> {}, end time: {}'.format(action, dev_action, end_time)) - start_time = max(start_time, end_time) - break - elif dev_action.depends_on(action): - raise RuntimeError("Action happened before") - # update timeline - self.device_timeline[action.device].append((start_time, start_time + action.est_latency)) - self.device_actions[action.device].append(action) - - def actions(self, device_id): - """ - Get action sequence for the specific device id - """ - if device_id >= self.ndevice: - raise ValueError(f"device id out of boundary ({device_id} >= {self.ndeivce})") - if self.device_actions is None: - self.gen() - return self.device_actions[device_id] - - def timeline(self, device_id): - """ - Get action timeline for the specific device id - """ - if device_id >= self.ndevice: - raise ValueError(f"device id out of boundary ({device_id} >= {self.ndeivce})") - if self.device_timeline is None: - self.gen() - return self.device_timeline[device_id] - - def get_time(self): - if self.device_timeline is None: - self.gen() - return max( - [timeline[-1][1] for timeline in self.device_timeline if len(timeline) != 0] - ) - - def get_memory(self): - if self.device_timeline is None: - self.gen() - - def device_memory(actions): - max_mem = 0 - cur_mem = 0 - for action in actions: - cur_mem += action.est_memory - max_mem = max(cur_mem, max_mem) - return max_mem - - return max( - [device_memory(actions) for actions in self.device_actions] - ) - - def draw(self, outfile='./execplan.png'): - import matplotlib.pyplot as plt - from matplotlib.patches import Rectangle - plt.rcParams['figure.figsize'] = (12.0, 4.0) - - if self.device_actions is None: - self.gen() - - fig, ax = plt.subplots() - plan_time = self.get_time() - - # xaxis - ax.set_xlim((1, plan_time)) - plt.xticks(list(range(1, plan_time+1, 1))) - ax.xaxis.grid(True, linestyle='--') - plt.xlabel('time') - - # yaxis - ax.set_ylim((0.5, self.ndevice+0.5)) - plt.yticks(list(range(1, self.ndevice+1, 1))) - ax.invert_yaxis() - plt.ylabel('device id') - - ax.set_aspect('equal') - - for devid in range(len(self.device_actions)): - timeline = self.device_timeline[devid] - actions = self.device_actions[devid] - for action, (start, end) in zip(actions, timeline): - # draw - color = 'blue' if (end - start) == 1 else 'orange' - rec = Rectangle((start, devid + 0.5), end-start, 1, - color=color, ec='black', lw=1.5) - ax.add_artist(rec) - rx, ry = rec.get_xy() - cx = rx + rec.get_width() / 2.0 - cy = ry + rec.get_height() / 2.0 - anno = action.name if action.fid is None else action.fid - ax.annotate(anno, (cx, cy), color='w', weight='bold', - fontsize=10, ha='center', va='center') - # plt.grid() - plt.savefig(outfile) - - def to_json(self): - return [repr(action) for action in self.seq] diff --git a/examples/linears.py b/examples/linears.py index ef6f4095..8b17b645 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -96,13 +96,13 @@ def train(): dim = 1024 model = MLP(dim=dim) - model = cube.schedule.SemanticModel( + model = cube.SemanticModel( model, input_shapes=([batch_size, dim],), ) dataloader = cube.runtime.syndata.SynDataLoader(640, [batch_size, dim]) - @cube.schedule.schedule(model, dataloader, transform_policy=transform_policy, schedule_policy=schedule_policy) + @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) diff --git a/examples/poc/torchfx.py b/examples/poc/torchfx.py deleted file mode 100644 index b53d232f..00000000 --- a/examples/poc/torchfx.py +++ /dev/null @@ -1,167 +0,0 @@ -""" -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - examples/poc/torchfx.py -""" - -import torch -from torch import nn -import torch.nn.functional as F -from torch.fx import symbolic_trace - -import os - - -local_rank = int(os.environ.get('LOCAL_RANK')) -torch.cuda.set_device(local_rank) -torch.distributed.init_process_group( - backend='nccl', - init_method='env://', -) - -# ====================== Check for normal module ========================== -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult) - self.gelu = nn.GELU() - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim * mult, dim) - self.classifier = nn.Linear(dim, classes) - - def forward(self, x): - output = self.linear1(x) - output = self.gelu(output) - output = self.dropout(output) - output = self.linear2(output) - output = self.classifier(output) - return output - -model = FeedForward(dim=1024).cuda() -graph_module = symbolic_trace(model) -if local_rank == 0: - print(graph_module) - print(graph_module.code) - print(graph_module.graph) - - -# ====================== Check for autograd function ========================== -class CustomOp(torch.autograd.Function): - @staticmethod - def symbolic(graph, input, weight): - return torch.matmul(input, weight) - @staticmethod - def forward(ctx, input, weight): - ctx.save_for_backward(input, weight) - return torch.matmul(input, weight) - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - return input+weight, input+weight - -class CustomModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input, weight): - out = CustomOp.apply(input, weight) - return out - -custom_op = CustomModule().cuda() - -input = torch.ones((1024, 1024)).cuda().requires_grad_() -weight = torch.ones((1024, 1024)).cuda().requires_grad_() - -if local_rank == 0: - custom_op_trace = symbolic_trace(custom_op) - print(custom_op_trace) - print(custom_op_trace.code) - print(custom_op_trace.graph) - # traced graph call - out = custom_op_trace(input, weight) - torch.sum(out).backward() - print(out) - print('weight grad: ', weight.grad) - # original graph call - - out = custom_op(input, weight) - input.grad = None - weight.grad = None - torch.sum(out).backward() - print('weight grad expected: ', weight.grad) - print(out) - -torch.distributed.barrier() - - -# ====================== Check for function with communications ========================== -class InputAdapter(torch.autograd.Function): - @staticmethod - def symbolic(graph, input_): - return input_ - @staticmethod - def forward(ctx, input_): - return input_ - @staticmethod - def backward(ctx, grad_output): - return torch.distributed.all_reduce(grad_output) - - -class OutputAdapter(torch.autograd.Function): - @staticmethod - def symbolic(graph, input_): - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_) - output = torch.cat(tensor_list, dim=-1) - return output - @staticmethod - def forward(ctx, input_): - # world_size = torch.distributed.get_world_size() - # rank = torch.distributed.get_rank() - # tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - # tensor_list[rank] = input_ - # torch.distributed.all_gather(tensor_list, input_) - # output = torch.cat(tensor_list, dim=-1) - output = input_ - torch.distributed.all_reduce(output) - return output - @staticmethod - def backward(ctx, grad_output): - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - tensor_list = torch.split( - grad_output, grad_output.size()[-1]//world_size, dim=-1 - ) - return tensor_list[rank].contiguous() - - -class LinearComm(nn.Module): - def __init__(self, input_feats, output_feats): - super().__init__() - self.linear = nn.Linear(input_feats, output_feats) - def forward(self, x): - x = InputAdapter.apply(x) - x = self.linear(x) - x = OutputAdapter.apply(x) - return x - -comm_linear = LinearComm(1024, 1024).cuda() -graph_comm = symbolic_trace(comm_linear) -if local_rank == 0: - print(graph_comm.graph) - print(graph_comm.code) - -input = torch.ones((1024, 1024)).cuda().requires_grad_() -out = graph_comm(input) -out_ref = comm_linear(input) -if local_rank == 0: - print('out: ', out) - print('out expected: ', out_ref) From cb4f462ac65902ba34667d1d8af111ff7cacd870 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 11 Nov 2021 13:34:25 +0800 Subject: [PATCH 0276/1892] test random device assignment --- cube/compiler.py | 2 +- examples/linears.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 854aec6a..802ef4dd 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -143,7 +143,7 @@ def decorator(fn: Callable) -> Callable: execplan = RemoveRedundantAdapters.apply(execplan) # print(f'> after remove redundant adapters:\n {execplan}') execplan = MergeComputeSU.apply(execplan) - # print(f'> after merge compute SU:\n{execplan}') + print(f'> after merge compute SU:\n{execplan}') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() diff --git a/examples/linears.py b/examples/linears.py index 8b17b645..d45dbdac 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -13,6 +13,8 @@ import torch from torch import nn +import math +import random import cube from cube.schedule.su import SUType @@ -56,8 +58,11 @@ def schedule_policy(sugraph: SUGraph, resource): print(f'> collect {len(fb_seqs)} forward-backward sequence') for fb_seq in fb_seqs: + chunk_num = int(math.ceil(len(fb_seq) / resource.ngpus)) for idx, su in enumerate(fb_seq): - devid = idx % resource.ngpus + # devid = int(idx // chunk_num) + # devid = idx % resource.ngpus + devid = random.randint(0, resource.ngpus - 1) sugraph.assign(su, devid) sugraph.assign(su.mirror, devid) @@ -92,7 +97,7 @@ def forward(self, data): def train(): - batch_size = 64 + batch_size = 128 dim = 1024 model = MLP(dim=dim) @@ -100,7 +105,7 @@ def train(): model, input_shapes=([batch_size, dim],), ) - dataloader = cube.runtime.syndata.SynDataLoader(640, [batch_size, dim]) + dataloader = cube.runtime.syndata.SynDataLoader(1280, [batch_size, dim]) @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) def train_iter(model, dataloader): From a9410745bd5de0949cdcf3a8986b903f4d54889e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 11 Nov 2021 14:16:03 +0800 Subject: [PATCH 0277/1892] change name --- cube/codegen/codegen.py | 8 +- cube/graph/operator/adapter.py | 221 ------------------ .../adapter/{select.py => transform.py} | 12 +- cube/schedule/sugraph.py | 6 +- tests/schedule/test_adapter_select.py | 16 +- tests/schedule/test_graphpass.py | 100 -------- tests/schedule/test_workflow.py | 6 +- 7 files changed, 24 insertions(+), 345 deletions(-) delete mode 100644 cube/graph/operator/adapter.py rename cube/schedule/adapter/{select.py => transform.py} (94%) delete mode 100644 tests/schedule/test_graphpass.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 8aa7616b..31f955aa 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -10,7 +10,7 @@ from cube.schedule.su import ScheduleUnit, SUType from cube.schedule.adapter.comm import IRCommType, IRCommunication -from cube.schedule.adapter.select import IRTensorReshape, IRReshapeType +from cube.schedule.adapter.transform import IRTensorTransform, IRTransformType from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -64,7 +64,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # parse graph body for su in device_sus: for node in su.nodes(): - if isinstance(node, IRTensorReshape): + if isinstance(node, IRTensorTransform): self.emit_reshape_call(node) if isinstance(node, IRCommunication): self.emit_comm_call(node) @@ -176,7 +176,7 @@ def emit_reshape_call(self, node): src_tensors = self._forward_region_arg_names(node.inputs()) dst_tensors = self._forward_region_arg_names(node.outputs()) # emit select - if node.ttype == IRReshapeType.Select: + if node.ttype == IRTransformType.Select: src_tensor = src_tensors[0] #TODO: relative indices indices = node.select_indices @@ -184,7 +184,7 @@ def emit_reshape_call(self, node): dst_tensors = ', '.join(dst_tensors) code = f'{dst_tensors} = {node.signature}({src_tensor}, {indices})' self.forward_region.append(code) - elif node.ttype == IRReshapeType.Merge: + elif node.ttype == IRTransformType.Merge: axis = node.merge_axis src_tensor = '(' + ', '.join(src_tensors + ['']) + ')' dst_tensor = dst_tensors[0] diff --git a/cube/graph/operator/adapter.py b/cube/graph/operator/adapter.py deleted file mode 100644 index ed5f8ca5..00000000 --- a/cube/graph/operator/adapter.py +++ /dev/null @@ -1,221 +0,0 @@ -from typing import List, Optional -from enum import Enum -import numpy as np - -from cube.ir.cten import IRCell, IRTensor -from cube.graph.tensor import IRSubTensor, IndexMap - - -class IRReshapeType(Enum): - - Select = 'cube.runtime.adapter.select' - Merge = 'cube.runtime.adapter.merge' - - -class IRShapeAdapter(IRCell): - """ - Tensor transformation by convert source tensors - to destination tensors - - Select: - src_tensors is only one tensor, dst_tensors has (multiple) tensors. - This will select the sub_tensor and generate what it need - - Merge: - src_tensors has (multiple) tensors, dst_tensors is only one tensor. - This will merge the sub_tensor and generate what it need - """ - def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor]): - - if len(src_tensors) != 1 and len(dst_tensors) != 1: - raise ValueError("Expected at least one of tensors has length 1") - - self.ttype = None - - self._select_indices: List[IndexMap] = list() - self._merge_axis = None - - if len(src_tensors) == 1: - self.ttype = IRReshapeType.Select - src_tensor = src_tensors[0] - if not isinstance(src_tensor, IRSubTensor): - raise TypeError(f"Expected IRSubTensor but got {type(src_tensor)}") - # select - for tensor in dst_tensors: - indices = tensor.indices & src_tensor.indices - self._select_indices.append(indices) - - elif len(dst_tensors) == 1: - self.ttype = IRReshapeType.Merge - dst_tensor = dst_tensors[0] - # find dims to concat - ndims = len(dst_tensor.shape) - indices = [set() for _ in range(ndims)] - for src_tensor in src_tensors: - if isinstance(src_tensor, IRSubTensor): - for ndim, slicer in enumerate(src_tensor.indices.get()): - indices[ndim].add((slicer.start, slicer.stop, slicer.step)) - else: - raise RuntimeError( - f"Expected SubTensor but got {type(src_tensor)}" - ) - # check if only one dim set has multiple slicer - for dim, dim_indices in enumerate(indices): - if len(dim_indices) != 1: - if self._merge_axis is not None: - print("src: ", src_tensors) - print("dst: ", dst_tensors) - raise NotImplementedError("Only support merge on one axis") - self._merge_axis = dim - if self._merge_axis is None: - # check the coverage - if src_tensors[0].indices != dst_tensor.indices: - raise RuntimeError("Not cover all the indices to merge.") - # get merge axis - if self._merge_axis is not None: - dim_indices = indices[self._merge_axis] - # check if they are overlapped - starts = np.array([slicer[0] for slicer in dim_indices]) - stops = np.array([slicer[1] for slicer in dim_indices]) - steps = np.array([slicer[2] for slicer in dim_indices]) - sorted_idx = np.argsort(starts) - sorted_starts = starts[sorted_idx] - sorted_stops = stops[sorted_idx] - sorted_steps = steps[sorted_idx] - for last_stop, begin_start in zip(sorted_stops[:-1], sorted_starts[1:]): - if last_stop != begin_start: - raise NotImplementedError(f"Concatenation fails due to axis {last_stop} != {begin_start}") - for step in sorted_steps: - if step != 1: - raise NotImplementedError(f"Found a SubTensor step {step} != 1") - # re-order - src_tensors = np.array(src_tensors)[sorted_idx] - - else: - raise RuntimeError("Internal Error") - - super().__init__( - name = 'transformation', - signature = self.ttype.value, - input_length = len(src_tensors), - output_length = len(dst_tensors) - ) - for idx, input in enumerate(src_tensors): - self.set_input(idx, input) - for idx, output in enumerate(dst_tensors): - self.set_output(idx, output) - - @property - def select_indices(self) -> List[IndexMap]: - return self._select_indices - - @property - def merge_axis(self) -> Optional[int]: - return self._merge_axis - - def is_identity(self): - """ - Check if this transformation is a non-op - """ - if self.ttype == IRReshapeType.Select: - src_tensor = self.inputs(0) - for dst_tensor in self.outputs(): - if dst_tensor != src_tensor: - return False - return True - if self.ttype == IRReshapeType.Merge: - if self.merge_axis is None: - return True - return False - return False - - -class IRCommType(Enum): - - Send = 'send' - Recv = 'recv' - SendRecv = 'sendrecv' - - -class IRCommAdapter(IRCell): - """ - Communication cell for IRCell - """ - - def __init__(self, - send_tensors=list(), send_ranks: List[List[int]] = list(), - recv_tensors=list(), recv_ranks: List[List[int]] =list()): - """ - Create a basic send, recv or sendrecv communication node - """ - if len(send_tensors) != 0 and len(recv_tensors) != 0: - comm_type = IRCommType.SendRecv - signature = 'cube.runtime.collectives.sendrecv' - elif len(send_tensors) != 0 and len(recv_tensors) == 0: - comm_type = IRCommType.Send - signature = 'cube.runtime.collectives.send' - elif len(recv_tensors) != 0 and len(send_tensors) == 0: - comm_type = IRCommType.Recv - signature = 'cube.runtime.collectives.recv' - else: - raise ValueError( - "Expected at least one of send_tensors and recv_tensors" - ) - - self.comm_type = comm_type - self.send_tensors = list() - self.send_ranks = list() - self.recv_tensors = list() - self.recv_ranks = list() - - super().__init__( - name = comm_type.value, - signature = signature, - input_length = len(send_tensors), - output_length = len(recv_tensors) - ) - - for idx, (tensor, to_device) in enumerate(zip(send_tensors, send_ranks)): - self.set_input(idx, tensor) - self.send_tensors.append(self.inputs(idx)) - self.send_ranks.append(to_device) - - for idx, (tensor, from_device) in enumerate(zip(recv_tensors, recv_ranks)): - self.set_output(idx, tensor) - self.recv_tensors.append(self.outputs(idx)) - self.recv_ranks.append(from_device) - - self.msg_id = self._id - - def pair(self, other): - """ - Pair two comm node to have same message id. - - The `other` message id is set same with caller - """ - if not isinstance(other, IRCommAdapter): - raise RuntimeError("Expected IRCommAdapter to pair") - other.msg_id = self.msg_id - - def merge(self, other): - if not isinstance(other, IRCommAdapter): - raise RuntimeError("Expected IRCommAdapter to merge") - raise NotImplementedError - - def __repr__(self): - inputs = list() - for tensor in self.inputs(): - if isinstance(tensor, IRTensor): - inputs.append(f't{tensor._id}-dev{tensor.device}') - else: - inputs.append(tensor) - - outputs = list() - for tensor in self.outputs(): - if isinstance(tensor, IRTensor): - outputs.append(f't{tensor._id}-dev{tensor.device}') - else: - outputs.append(tensor) - - dscp = f'SendRecv(msg_id={self.msg_id}, device={self.device}, send={inputs}, recv={outputs})' - return dscp diff --git a/cube/schedule/adapter/select.py b/cube/schedule/adapter/transform.py similarity index 94% rename from cube/schedule/adapter/select.py rename to cube/schedule/adapter/transform.py index 120f87e8..9c23098a 100644 --- a/cube/schedule/adapter/select.py +++ b/cube/schedule/adapter/transform.py @@ -6,13 +6,13 @@ from cube.graph.tensor import IRSubTensor, IndexMap -class IRReshapeType(Enum): +class IRTransformType(Enum): Select = 'cube.runtime.adapter.select' Merge = 'cube.runtime.adapter.merge' -class IRTensorReshape(IRCell): +class IRTensorTransform(IRCell): """ Tensor transformation by convert source tensors to destination tensors @@ -36,7 +36,7 @@ def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor self._merge_axis = None if len(src_tensors) == 1: - self.ttype = IRReshapeType.Select + self.ttype = IRTransformType.Select src_tensor = src_tensors[0] if not isinstance(src_tensor, IRSubTensor): raise TypeError(f"Expected IRSubTensor but got {type(src_tensor)}") @@ -46,7 +46,7 @@ def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor self._select_indices.append(indices) elif len(dst_tensors) == 1: - self.ttype = IRReshapeType.Merge + self.ttype = IRTransformType.Merge dst_tensor = dst_tensors[0] # find dims to concat ndims = len(dst_tensor.shape) @@ -117,13 +117,13 @@ def is_identity(self): """ Check if this transformation is a non-op """ - if self.ttype == IRReshapeType.Select: + if self.ttype == IRTransformType.Select: src_tensor = self.inputs(0) for dst_tensor in self.outputs(): if dst_tensor != src_tensor: return False return True - if self.ttype == IRReshapeType.Merge: + if self.ttype == IRTransformType.Merge: if self.merge_axis is None: return True return False diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index e1023b65..5b93bb06 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -9,7 +9,7 @@ from cube.graph.graph import IRGraph from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.adapter.comm import IRCommunication -from cube.schedule.adapter.select import IRTensorReshape +from cube.schedule.adapter.transform import IRTensorTransform class SUGraph(IRCell): @@ -449,7 +449,7 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: recv_su.device = su.device # add adapter for merge if len(tensor_segments) != 0: - merge_op = IRTensorReshape( + merge_op = IRTensorTransform( src_tensors=tensor_segments, dst_tensors=[input] ) merge_su = ScheduleUnit([merge_op], SUType.Transform, name='merge') @@ -468,7 +468,7 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: if tensor != output: select_tensors.append(tensor) if len(select_tensors) != 0: - select_op = IRTensorReshape( + select_op = IRTensorTransform( src_tensors=[output], dst_tensors=select_tensors ) select_su = ScheduleUnit( diff --git a/tests/schedule/test_adapter_select.py b/tests/schedule/test_adapter_select.py index 0f89f0f1..38542f88 100644 --- a/tests/schedule/test_adapter_select.py +++ b/tests/schedule/test_adapter_select.py @@ -1,6 +1,6 @@ -from cube.schedule.adapter.select import IRReshapeType -from cube.schedule.adapter.select import IRTensorReshape +from cube.schedule.adapter.transform import IRTransformType +from cube.schedule.adapter.transform import IRTensorTransform from cube.graph.tensor import IRFullTensor, IndexMap @@ -21,28 +21,28 @@ def test_tensor_reshape_init(): shape = [512, 1024] ) - reshape = IRTensorReshape( + reshape = IRTensorTransform( src_tensors=[tensor1], dst_tensors=[tensor2, tensor3] ) assert len(reshape.inputs()) == 1 assert len(reshape.outputs()) == 2 - assert reshape.ttype == IRReshapeType.Select + assert reshape.ttype == IRTransformType.Select assert reshape.select_indices == [ IndexMap((slice(0, 512, 1), slice(0, 1024, 1))), IndexMap((slice(512, 1024, 1), slice(0, 1024, 1))), ] assert reshape.merge_axis is None - reshape = IRTensorReshape( + reshape = IRTensorTransform( dst_tensors=[tensor1], src_tensors=[tensor2, tensor3] ) assert len(reshape.inputs()) == 2 assert len(reshape.outputs()) == 1 - assert reshape.ttype == IRReshapeType.Merge + assert reshape.ttype == IRTransformType.Merge assert reshape.merge_axis == 0 assert len(reshape.select_indices) == 0 @@ -75,13 +75,13 @@ def test_adapter_select_is_identity(): shape = [256, 1024] ) - reshape = IRTensorReshape( + reshape = IRTensorTransform( src_tensors=[tensor2], dst_tensors=[tensor4, tensor5] ) assert not reshape.is_identity() - reshape = IRTensorReshape( + reshape = IRTensorTransform( src_tensors=[tensor3], dst_tensors=[tensor4, tensor5] ) diff --git a/tests/schedule/test_graphpass.py b/tests/schedule/test_graphpass.py deleted file mode 100644 index 4c1798bc..00000000 --- a/tests/schedule/test_graphpass.py +++ /dev/null @@ -1,100 +0,0 @@ -from cube.graph.tensor import IRFullTensor -from cube.graph.operator.function import Linear -from cube.graph.graph import IRGraph - - -from cube.schedule.graphpass import SUGraphPass -from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraphGener - - -def construct_graph(): - - input = IRFullTensor(shape=[64,1024], name='data') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - graph = IRGraph( - nodes=[linear1, linear2, linear3], - input_tensors=[input], - output_tensors=linear3.outputs(), - module_name="Test" - ) - return graph - - -def test_remove_adapter(): - - SchedulePool().clear() - - graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data').tosub() - output = graph(data) - output.backward() - - nodes = SchedulePool().nodes() - sugraph = SUGraphGener.gen_sugraph(nodes) - - for su in sugraph.sus(): - sugraph.assign(su, 0) - sugraph = SUGraphPass.remove_redundant_adapters(sugraph) - for su in sugraph.sus(): - print(su) - for su in sugraph.sus(): - assert su.stype != SUType.Comm - assert len(sugraph.sus()) == 6 - - -def test_merge_small_sus(): - - SchedulePool().clear() - - graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data').tosub() - output = graph(data) - output.backward() - - nodes = SchedulePool().nodes() - sugraph = SUGraphGener.gen_sugraph(nodes) - - for su in sugraph.sus(): - if su.stype != SUType.Comm: - sugraph.assign(su, 0) - - print('orignal:') - print(sugraph) - - sugraph = SUGraphPass.merge_small_sus(sugraph) - - print('merged:') - print(sugraph) - - assert len(sugraph.sus()) == 2 diff --git a/tests/schedule/test_workflow.py b/tests/schedule/test_workflow.py index 59614f5d..71aa3313 100644 --- a/tests/schedule/test_workflow.py +++ b/tests/schedule/test_workflow.py @@ -43,7 +43,7 @@ def __next__(self): def test_semantic_model(): dim = 1024 model = MLP(dim=dim) - model = cube.schedule.SemanticModel( + model = cube.SemanticModel( model, input_shapes=([64, dim],) ) @@ -59,7 +59,7 @@ def test_schedule(): batch_size = 64 model = MLP(dim=dim) - model = cube.schedule.SemanticModel( + model = cube.SemanticModel( model, input_shapes=([batch_size, dim],) ) @@ -100,4 +100,4 @@ def train_iter(model, dataloader): sugraph = policy(sugraph, None) print(sugraph) - assert len(sugraph.sus()) == 33 + assert len(sugraph.sus()) == 23 From 1f54cddd01989e4cdc66bff47503a3285f494077 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 11 Nov 2021 20:25:19 +0800 Subject: [PATCH 0278/1892] transform plan gen; tensor value map bug fix --- cube/graph/tensor.py | 4 +- cube/runtime/transform.py | 44 ++++ cube/schedule/adapter/transform.py | 314 +++++++++++++++++------ tests/schedule/test_adapter_select.py | 88 ------- tests/schedule/test_adapter_transform.py | 196 ++++++++++++++ 5 files changed, 481 insertions(+), 165 deletions(-) create mode 100644 cube/runtime/transform.py delete mode 100644 tests/schedule/test_adapter_select.py create mode 100644 tests/schedule/test_adapter_transform.py diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index a8b8a565..3be5d2c0 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -228,8 +228,8 @@ def chunk_num(self): def map(self, sub_map): if not isinstance(sub_map, ValueMap): raise TypeError("Expected sub_map to be ValueMap") - idx = self.idx + sub_map.idx - chunk_num = self.chunk_num - 1 + sub_map.chunk_num + idx = self.chunk_num * self.idx + sub_map.idx + chunk_num = self.chunk_num * sub_map.chunk_num return ValueMap(idx, chunk_num) def overlap(self, other): diff --git a/cube/runtime/transform.py b/cube/runtime/transform.py new file mode 100644 index 00000000..436f7a34 --- /dev/null +++ b/cube/runtime/transform.py @@ -0,0 +1,44 @@ +""" +Adapter: Tensor Transformation +""" + +from typing import List, Tuple, Optional +import torch + + +def select(tensor: torch.Tensor, + indices: Tuple[slice], val_map: Tuple[int, int]) -> torch.Tensor: + + with torch.no_grad(): + sub_tensor = tensor[indices] + if val_map != (0, 1): + sub_tensor = sub_tensor / val_map[1] + sub_tensor = sub_tensor.contiguous() + return sub_tensor + +def merge(tensors: List[torch.Tensor], + concat: Optional[int] = None, + add: bool = False): + """ + Runtime primitive to finish tensor transformation. + + Warning: No contiguous is called!!! need to explicitly called + before communication + + Args: + tensors: a list of torch tensor + concat: Optional[int]: the dimension to merge + add: bool: whether to perform value merge + """ + if (concat is not None) ^ (add is True): # xor condition + raise RuntimeError("Expected concat or add") + if concat is not None: + with torch.no_grad(): + out = torch.cat(tensors, concat) + return out + if add is not None: + with torch.no_grad(): + out = tensors[0] + for tensor in tensors[1:]: + out = out + tensor + return out diff --git a/cube/schedule/adapter/transform.py b/cube/schedule/adapter/transform.py index 9c23098a..a61bdecb 100644 --- a/cube/schedule/adapter/transform.py +++ b/cube/schedule/adapter/transform.py @@ -1,9 +1,10 @@ +import copy from typing import List, Optional from enum import Enum import numpy as np from cube.ir.cten import IRCell -from cube.graph.tensor import IRSubTensor, IndexMap +from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap class IRTransformType(Enum): @@ -27,72 +28,28 @@ class IRTensorTransform(IRCell): """ def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor]): - if len(src_tensors) != 1 and len(dst_tensors) != 1: + if not all([isinstance(t, IRSubTensor) for t in src_tensors]): + raise TypeError("Expected src tensors to be IRSubTensor") + if not all([isinstance(t, IRSubTensor) for t in dst_tensors]): + raise TypeError("Expected dst tensors to be IRSubTensor") + if not ((len(src_tensors) == 1) or (len(dst_tensors) == 1)): raise ValueError("Expected at least one of tensors has length 1") self.ttype = None + self._trace = list() - self._select_indices: List[IndexMap] = list() - self._merge_axis = None - + # select if len(src_tensors) == 1: self.ttype = IRTransformType.Select - src_tensor = src_tensors[0] - if not isinstance(src_tensor, IRSubTensor): - raise TypeError(f"Expected IRSubTensor but got {type(src_tensor)}") - # select - for tensor in dst_tensors: - indices = tensor.indices & src_tensor.indices - self._select_indices.append(indices) - + self._trace = SelectPlan.gen(src_tensors[0], dst_tensors) + + # merge elif len(dst_tensors) == 1: self.ttype = IRTransformType.Merge - dst_tensor = dst_tensors[0] - # find dims to concat - ndims = len(dst_tensor.shape) - indices = [set() for _ in range(ndims)] - for src_tensor in src_tensors: - if isinstance(src_tensor, IRSubTensor): - for ndim, slicer in enumerate(src_tensor.indices.get()): - indices[ndim].add((slicer.start, slicer.stop, slicer.step)) - else: - raise RuntimeError( - f"Expected SubTensor but got {type(src_tensor)}" - ) - # check if only one dim set has multiple slicer - for dim, dim_indices in enumerate(indices): - if len(dim_indices) != 1: - if self._merge_axis is not None: - print("src: ", src_tensors) - print("dst: ", dst_tensors) - raise NotImplementedError("Only support merge on one axis") - self._merge_axis = dim - if self._merge_axis is None: - # check the coverage - if src_tensors[0].indices != dst_tensor.indices: - raise RuntimeError("Not cover all the indices to merge.") - # get merge axis - if self._merge_axis is not None: - dim_indices = indices[self._merge_axis] - # check if they are overlapped - starts = np.array([slicer[0] for slicer in dim_indices]) - stops = np.array([slicer[1] for slicer in dim_indices]) - steps = np.array([slicer[2] for slicer in dim_indices]) - sorted_idx = np.argsort(starts) - sorted_starts = starts[sorted_idx] - sorted_stops = stops[sorted_idx] - sorted_steps = steps[sorted_idx] - for last_stop, begin_start in zip(sorted_stops[:-1], sorted_starts[1:]): - if last_stop != begin_start: - raise NotImplementedError(f"Concatenation fails due to axis {last_stop} != {begin_start}") - for step in sorted_steps: - if step != 1: - raise NotImplementedError(f"Found a SubTensor step {step} != 1") - # re-order - src_tensors = np.array(src_tensors)[sorted_idx] + self._trace = MergePlan.gen(src_tensors, dst_tensors[0]) else: - raise RuntimeError("Internal Error") + raise NotImplementedError super().__init__( name = 'transformation', @@ -105,26 +62,233 @@ def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor for idx, output in enumerate(dst_tensors): self.set_output(idx, output) - @property - def select_indices(self) -> List[IndexMap]: - return self._select_indices - - @property - def merge_axis(self) -> Optional[int]: - return self._merge_axis + def trace(self): + """ + Get trace of transformation + """ + return copy.copy(self._trace) def is_identity(self): """ Check if this transformation is a non-op """ - if self.ttype == IRTransformType.Select: - src_tensor = self.inputs(0) - for dst_tensor in self.outputs(): - if dst_tensor != src_tensor: - return False - return True - if self.ttype == IRTransformType.Merge: - if self.merge_axis is None: - return True - return False - return False + return len(self._trace) == 0 + + +class SelectPrim: + + def __init__(self, tensor: IRSubTensor, indices: IndexMap, val_map: ValueMap, shape: List[int]): + self.tensor = tensor + self.indices = indices + self.val_map = val_map + self.shape = shape + self.output = None + + def set_output(self, output: IRSubTensor): + self.output = output + + def __repr__(self): + dscp = f't{self.output._id} = select(t{self.tensor._id}, {self.indices}, {self.val_map}, {self.shape})' + return dscp + + +class SelectPlan: + + @staticmethod + def gen(input: IRSubTensor, outputs: List[IRSubTensor]) -> List[SelectPrim]: + trace: List[SelectPrim] = list() + islicers: List[slice] = input.indices.get() + for output in outputs: + if output == input: + continue + oslicers: List[slice] = output.indices.get() + # indices + indices = list() + for islicer, oslicer in zip(islicers, oslicers): + istart, istop, istep = islicer.start, islicer.stop, islicer.step + ostart, ostop, ostep = oslicer.start, oslicer.stop, oslicer.step + if ostep % istep != 0: + raise RuntimeError("Step condition fails") + # relative offset + start = ostart - istart + stop = start + ostop - ostart + slicer = slice(start, stop, ostep) + indices.append(slicer) + indices = IndexMap(tuple(indices)) + # value map + if output.val_map == input.val_map: + val_map = ValueMap(0, 1) + elif input.val_map == ValueMap(0, 1): + val_map = output.val_map + else: + print(output) + raise NotImplementedError( + f"Not supported value trans: {input.val_map} -> {output.val_map}" + ) + prim = SelectPrim(input, indices, val_map, output.shape) + prim.set_output(output) + trace.append(prim) + return trace + + +class MergePrim: + def __init__(self, + tensors: List[IRSubTensor], + concat: Optional[int] = None, + add: bool = False): + if not ((concat is not None) ^ (add is True)): # xor condition + raise RuntimeError("Expected concat or add") + self.tensors = tensors + self.concat = concat + self.add = add + self.output = None + # re-order tensor + if isinstance(concat, int): + slicers = [tensor.indices.get()[concat] for tensor in tensors] + starts = np.array([slicer.start for slicer in slicers], dtype=int) + sorted_idx = np.argsort(starts) + tensors = np.array(tensors)[sorted_idx] + self.tensors = tensors.tolist() + + def set_output(self, output: IRSubTensor): + self.output = output + + + def __repr__(self): + tensors = [f't{t._id}' for t in self.tensors] + tensors = '[' + ', '.join(tensors) + ']' + dscp = f't{self.output._id} = merge({tensors}, axis={self.concat}, add={self.add})' + return dscp + + +class MergePlan: + + @staticmethod + def gen(inputs: List[IRSubTensor], output: IRSubTensor) -> List[MergePrim]: + """ + Generate merge plan from input tensors to the output. + """ + if not all([isinstance(t, IRSubTensor) for t in inputs]): + raise TypeError("Expected inputs: List[IRSubTensor]") + if not isinstance(output, IRSubTensor): + raise TypeError("Expected inputs: List[IRSubTensor]") + + trace : List[MergePrim] = list() + remain_tensors = copy.copy(inputs) + dst_tensor = output + if dst_tensor in remain_tensors: + return trace + out = None + while out != dst_tensor: + # concat or merge + out = None + merge = False + for idx1 in range(len(remain_tensors) - 1): + for idx2 in range(idx1 + 1, len(remain_tensors)): + tensor1 = remain_tensors[idx1] + tensor2 = remain_tensors[idx2] + out = MergePlan.concat(tensor1, tensor2) + if out is not None: + out_tensor, concat_dim = out + out = out_tensor + prim = MergePrim([tensor1, tensor2], concat_dim, False) + prim.set_output(out_tensor) + trace.append(prim) + merge = True + break + out = MergePlan.add(tensor1, tensor2) + if out is not None: + prim = MergePrim([tensor1, tensor2], None, True) + prim.set_output(out) + trace.append(prim) + merge = True + break + if merge: + remain_tensors.remove(tensor1) + remain_tensors.remove(tensor2) + remain_tensors.append(out) + break + # cannot merge or add + if out is None: + raise RuntimeError("Merge Plan not found") + return trace + + + @staticmethod + def concat(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: + """ + Check if two tensor can be merged. + If they can be merged, return the merge index + """ + if not isinstance(tensor1, IRSubTensor) or not isinstance(tensor2, IRSubTensor): + raise TypeError("Expected two tensors") + if tensor1.overlap(tensor2): + return None + if tensor1.parent != tensor2.parent: + return None + if tensor1.val_map != tensor2.val_map: + return None + indices1 = tensor1.indices.get() + indices2 = tensor2.indices.get() + indices = list() + if len(indices1) != len(indices2): + return None + axis = None + for dim, (slicer1, slicer2) in enumerate(zip(indices1, indices2)): + if slicer1 != slicer2: + start1, stop1, step1 = slicer1.start, slicer1.stop, slicer1.step + start2, stop2, step2 = slicer2.start, slicer2.stop, slicer2.step + if step1 != step2: + return None + if axis is not None: + return None + if start1 < start2 and stop1 == start2: + axis = dim + indices.append(slice(start1, stop2, step1)) + elif start1 > start2 and start1 == stop2: + axis = dim + indices.append(slice(start2, stop1, step1)) + else: + return None + else: + indices.append(slicer1) + shapes = list() + for idx, (nele1, nele2) in enumerate(zip(tensor1.shape, tensor2.shape)): + nele = nele1 if idx != axis else nele1 + nele2 + shapes.append(nele) + mtensor = tensor1.parent.select( + indices = tuple(indices), + val_map = tensor1.val_map, + shape = shapes + ) + return mtensor, axis + + @staticmethod + def add(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: + if not isinstance(tensor1, IRSubTensor) or not isinstance(tensor2, IRSubTensor): + raise TypeError("Expected two tensors") + if tensor1.overlap(tensor2): + return None + if tensor1.parent != tensor2.parent: + return None + if tensor1.indices != tensor2.indices: + return None + if tensor1.val_map.chunk_num != tensor2.val_map.chunk_num: + return None + chunk_num = tensor1.val_map.chunk_num + idx1, idx2 = tensor1.val_map.idx, tensor2.val_map.idx + if chunk_num % 2 != 0: + return None + chunk_num = int(chunk_num // 2) + if chunk_num == 1: + idx = 0 + else: + if int(idx1 // chunk_num) != int(idx2 // chunk_num): + return None + idx = int(idx1 // chunk_num) + mtensor = tensor1.parent.select( + indices = tensor1.indices, + val_map = (idx, chunk_num), + shape = tensor1.shape + ) + return mtensor diff --git a/tests/schedule/test_adapter_select.py b/tests/schedule/test_adapter_select.py deleted file mode 100644 index 38542f88..00000000 --- a/tests/schedule/test_adapter_select.py +++ /dev/null @@ -1,88 +0,0 @@ - -from cube.schedule.adapter.transform import IRTransformType -from cube.schedule.adapter.transform import IRTensorTransform - -from cube.graph.tensor import IRFullTensor, IndexMap - - -def test_tensor_reshape_init(): - - tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() - - tensor2 = tensor1.select( - indices = (slice(0, 512), slice(0, 1024)), - val_map = None, - shape = [512, 1024] - ) - - tensor3 = tensor1.select( - indices = (slice(512, 1024), slice(0, 1024)), - val_map = None, - shape = [512, 1024] - ) - - reshape = IRTensorTransform( - src_tensors=[tensor1], - dst_tensors=[tensor2, tensor3] - ) - - assert len(reshape.inputs()) == 1 - assert len(reshape.outputs()) == 2 - assert reshape.ttype == IRTransformType.Select - assert reshape.select_indices == [ - IndexMap((slice(0, 512, 1), slice(0, 1024, 1))), - IndexMap((slice(512, 1024, 1), slice(0, 1024, 1))), - ] - assert reshape.merge_axis is None - - reshape = IRTensorTransform( - dst_tensors=[tensor1], - src_tensors=[tensor2, tensor3] - ) - - assert len(reshape.inputs()) == 2 - assert len(reshape.outputs()) == 1 - assert reshape.ttype == IRTransformType.Merge - assert reshape.merge_axis == 0 - assert len(reshape.select_indices) == 0 - - -def test_adapter_select_is_identity(): - - tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() - - tensor2 = tensor1.select( - indices = (slice(512, 1024), slice(0, 1024)), - val_map = None, - shape = [512, 1024] - ) - - tensor3 = tensor2.select( - indices = (slice(0, 256), slice(0, 1024)), - val_map = None, - shape = [256, 1024] - ) - - tensor4 = tensor1.select( - indices = (slice(512, 768), slice(0, 1024)), - val_map = None, - shape = [256, 1024] - ) - - tensor5 = tensor1.select( - indices = (slice(512, 768), slice(0, 1024)), - val_map = None, - shape = [256, 1024] - ) - - reshape = IRTensorTransform( - src_tensors=[tensor2], - dst_tensors=[tensor4, tensor5] - ) - assert not reshape.is_identity() - - reshape = IRTensorTransform( - src_tensors=[tensor3], - dst_tensors=[tensor4, tensor5] - ) - assert reshape.is_identity() diff --git a/tests/schedule/test_adapter_transform.py b/tests/schedule/test_adapter_transform.py new file mode 100644 index 00000000..9b5c831a --- /dev/null +++ b/tests/schedule/test_adapter_transform.py @@ -0,0 +1,196 @@ +from cube.schedule.adapter.transform import IRTransformType +from cube.schedule.adapter.transform import IRTensorTransform + +from cube.graph.tensor import IRFullTensor + + +def test_tensor_transform_select(): + + tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() + + tensor2 = tensor1.select( + indices = (slice(0, 512), slice(0, 1024)), + val_map = (0, 1), + shape = [512, 1024] + ) + + tensor3 = tensor1.select( + indices = (slice(512, 1024), slice(0, 1024)), + val_map = (0, 2), + shape = [512, 1024] + ) + + tensor4 = tensor3.select( + indices = (slice(0, 256), slice(0, 512)), + val_map = (0, 1), + shape = [256, 512] + ) + + tensor5 = tensor3.select( + indices = (slice(256, 512), slice(0, 512)), + val_map = (0, 1), + shape = [256, 512] + ) + + select1 = IRTensorTransform( + src_tensors=[tensor1], + dst_tensors=[tensor2, tensor3] + ) + assert len(select1.inputs()) == 1 + assert len(select1.outputs()) == 2 + assert select1.ttype == IRTransformType.Select + + print('> select1:', select1) + for prim in select1.trace(): + print(prim) + + select2 = IRTensorTransform( + src_tensors=[tensor3], + dst_tensors=[tensor4, tensor5] + ) + print('> select2:', select2) + for prim in select2.trace(): + print(prim) + assert False + + +def test_tensor_transform_merge(): + tensor0 = IRFullTensor(shape=[1024,1024], name='test1').tosub() + + tensor1 = tensor0.select( + indices = (slice(0, 512), slice(0, 512)), + val_map = None, + shape = [256, 1024] + ) + + tensor2 = tensor0.select( + indices = (slice(0, 512), slice(512, 1024)), + val_map = None, + shape = [256, 1024] + ) + + tensor3 = tensor0.select( + indices = (slice(512, 1024), slice(0, 512)), + val_map = None, + shape = [256, 512] + ) + + tensor4 = tensor0.select( + indices = (slice(512, 1024), slice(512, 1024)), + val_map = None, + shape = [256, 512] + ) + + tensor5 = tensor0.select( + indices = (slice(512, 1024), slice(0, 1024)), + val_map = None, + shape = [256, 512] + ) + + merge1 = IRTensorTransform( + src_tensors=[tensor1, tensor2, tensor3, tensor4], + dst_tensors=[tensor0] + ) + assert len(merge1.inputs()) == 4 + assert len(merge1.outputs()) == 1 + assert merge1.ttype == IRTransformType.Merge + + print('> merge1:') + for prim in merge1.trace(): + print(prim) + assert merge1.trace()[-1].output == tensor0 + assert merge1.trace()[-1].output._id == tensor0._id + + merge2 = IRTensorTransform( + src_tensors=[tensor3, tensor4], + dst_tensors=[tensor5] + ) + print('> merge2:') + for prim in merge2.trace(): + print(prim) + assert merge2.trace()[-1].output == tensor5 + assert merge2.trace()[-1].output._id == tensor5._id + # assert False + + tensor6 = tensor0.select( + indices = (slice(0, 256), slice(0, 1024)), + val_map = (0, 4), + shape = [256, 1024] + ) + tensor7 = tensor0.select( + indices = (slice(0, 256), slice(0, 1024)), + val_map = (1, 4), + shape = [256, 1024] + ) + tensor8 = tensor0.select( + indices = (slice(0, 256), slice(0, 1024)), + val_map = (2, 4), + shape = [256, 1024] + ) + tensor9 = tensor0.select( + indices = (slice(0, 256), slice(0, 1024)), + val_map = (3, 4), + shape = [256, 1024] + ) + + tensor10 = tensor0.select( + indices = (slice(0, 256), slice(0, 1024)), + val_map = (0, 1) + ) + + merge3 = IRTensorTransform( + src_tensors=[tensor6, tensor7, tensor8, tensor9], + dst_tensors=[tensor10] + ) + print('> merge3:') + for prim in merge3.trace(): + print(prim) + assert merge3.trace()[-1].output._id == tensor10._id + # assert False + + +def test_transform_identity(): + + tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() + + tensor2 = tensor1.select( + indices = (slice(512, 1024), slice(0, 1024)), + val_map = None, + shape = [512, 1024] + ) + + tensor3 = tensor2.select( + indices = (slice(0, 256), slice(0, 1024)), + val_map = None, + shape = [256, 1024] + ) + + tensor4 = tensor1.select( + indices = (slice(512, 768), slice(0, 1024)), + val_map = None, + shape = [256, 1024] + ) + + tensor5 = tensor1.select( + indices = (slice(512, 768), slice(0, 1024)), + val_map = None, + shape = [256, 1024] + ) + + select1 = IRTensorTransform( + src_tensors=[tensor2], + dst_tensors=[tensor4, tensor5] + ) + assert not select1.is_identity() + + select2 = IRTensorTransform( + src_tensors=[tensor3], + dst_tensors=[tensor4, tensor5] + ) + assert select2.is_identity() + + merge1 = IRTensorTransform( + src_tensors=[tensor4], + dst_tensors=[tensor5] + ) + assert merge1.is_identity() From 3db9d27b5b2d4f4e9ab2b9d20d5785134678cad2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 11 Nov 2021 23:31:08 +0800 Subject: [PATCH 0279/1892] fix tensor common bugs; fix gpass bugs --- cube/graph/gpass.py | 4 ---- cube/graph/graph.py | 14 ++++++-------- cube/graph/tensor.py | 22 ++++++++++++++++++++-- cube/schedule/adapter/transform.py | 2 +- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index 155948ec..b5fc7a84 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -87,7 +87,6 @@ def forward(graph, *args) -> IRGraph: grad = None if isinstance(val, IRSubTensor): # TODO: requires_grad = False should be set to None - # grad = gener.renew(val, keep_param=False).as_grad() grad = val.get_grad(fnode) val.grad = grad bnode.set_output(idx, grad) @@ -98,9 +97,6 @@ def forward(graph, *args) -> IRGraph: # TODO: requires_grad = False should be set to None grad = val.get_grad(fnode) val.grad = grad - # grad = gener.renew(val, keep_param=False).as_grad() - # TODO: this grad should be partitioned in value dimension - # val.add_grad(grad) bnode.set_grad(idx, grad) fnode.device = node.device diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 063077fe..36857665 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -296,18 +296,16 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional bnode.set_data(idx, val) grad = None if isinstance(val, IRSubTensor): - if val.requires_grad and val.grad is None: - grad = val.get_grad(fnode) - val.grad = grad - grad = val.grad + # TODO: remove grad is grad doesn't require it + grad = val.get_grad(fnode) + val.grad = grad bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): grad = None if isinstance(val, IRSubTensor): - if val.requires_grad and val.grad is None: - grad = val.get_grad(fnode) - val.grad = grad - grad = val.grad + # TODO: remove grad is grad doesn't require it + grad = val.get_grad(fnode) + val.grad = grad bnode.set_grad(idx, grad) fnode.mirror = bnode fnode.device = op.device diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 3be5d2c0..7d6da53b 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -22,6 +22,8 @@ from typing import List, Optional, Union, Tuple import copy +from numpy.lib.arraysetops import isin + from cube.ir.cten import IRCell, IRTensor @@ -249,6 +251,21 @@ def __eq__(self, other): return True return False + def __and__(self, other): + """ + Find the common part + """ + if not isinstance(other, ValueMap): + raise TypeError("Expected ValueMap for & operator") + if not self.overlap(other): + return None + if self.chunk_num == other.chunk_num: + return ValueMap(self.idx, self.chunk_num) + if self.chunk_num == 1: + return ValueMap(other.idx, other.chunk_num) + else: + return ValueMap(self.idx, self.chunk_num) + def __repr__(self): return f'({self.idx}/{self.chunk_num})' @@ -658,9 +675,10 @@ def common(self, other): return self elif isinstance(other, IRSubTensor): indices = self.indices & other.indices + val_map = self.val_map & other.val_map sub_tensor = self.parent.select( - indices = indices.get(), - val_map = self.val_map, + indices = indices, + val_map = val_map, shape = indices.shape ) return sub_tensor diff --git a/cube/schedule/adapter/transform.py b/cube/schedule/adapter/transform.py index a61bdecb..bc767620 100644 --- a/cube/schedule/adapter/transform.py +++ b/cube/schedule/adapter/transform.py @@ -123,7 +123,7 @@ def gen(input: IRSubTensor, outputs: List[IRSubTensor]) -> List[SelectPrim]: else: print(output) raise NotImplementedError( - f"Not supported value trans: {input.val_map} -> {output.val_map}" + f"Not supported value select: {input.val_map} -> {output.val_map}" ) prim = SelectPrim(input, indices, val_map, output.shape) prim.set_output(output) From 141613f8de467438ba6de6d2ed71124e8faf83d0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 11 Nov 2021 23:31:48 +0800 Subject: [PATCH 0280/1892] init transform codegen --- cube/codegen/codegen.py | 42 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 31f955aa..3c756c40 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -2,6 +2,7 @@ Generate Pytorch code given the model DAG and the transformation config """ from typing import List, Any +from numpy import isin import torch import copy @@ -10,7 +11,8 @@ from cube.schedule.su import ScheduleUnit, SUType from cube.schedule.adapter.comm import IRCommType, IRCommunication -from cube.schedule.adapter.transform import IRTensorTransform, IRTransformType +from cube.schedule.adapter.transform import IRTensorTransform +from cube.schedule.adapter.transform import SelectPrim, MergePrim from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -65,7 +67,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: for su in device_sus: for node in su.nodes(): if isinstance(node, IRTensorTransform): - self.emit_reshape_call(node) + self.emit_transform_call(node) if isinstance(node, IRCommunication): self.emit_comm_call(node) else: @@ -169,29 +171,25 @@ def emit_comm_call(self, node): raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") self.forward_region.append(code) - def emit_reshape_call(self, node): + def emit_transform_call(self, node: IRTensorTransform): """ Emit in-device tensor select / merge call. """ - src_tensors = self._forward_region_arg_names(node.inputs()) - dst_tensors = self._forward_region_arg_names(node.outputs()) - # emit select - if node.ttype == IRTransformType.Select: - src_tensor = src_tensors[0] - #TODO: relative indices - indices = node.select_indices - indices = [slicer.get() for slicer in indices] - dst_tensors = ', '.join(dst_tensors) - code = f'{dst_tensors} = {node.signature}({src_tensor}, {indices})' - self.forward_region.append(code) - elif node.ttype == IRTransformType.Merge: - axis = node.merge_axis - src_tensor = '(' + ', '.join(src_tensors + ['']) + ')' - dst_tensor = dst_tensors[0] - code = f'{dst_tensor} = {node.signature}({src_tensor}, {axis})' - self.forward_region.append(code) - else: - raise TypeError(f"Unknown Reshape Type: {node.ttype}") + for prim in node.trace(): + if isinstance(prim, SelectPrim): + signature = 'cube.runtime.transform.select({tensor}, {indices}, {val_map})' + input = self.naming(prim.tensor) + indices = repr(prim.indices) + val_map = repr(tuple([prim.val_map.idx, prim.val_map.chunk_num])) + output = self.naming(prim.output) + code = f'{output} = {signature.format(tensor=input, indices=indices, val_map=val_map)}' + self.forward_region.append(code) + elif isinstance(prim, MergePrim): + signature = 'cube.runtime.transform.merge({tensors}, {concat}, {add})' + inputs = self._forward_region_arg_names(prim.tensors) + output = self.naming(prim.output) + code = f'{output} = {signature.format(tensors=inputs, concat=prim.concat, add=prim.add)}' + self.forward_region.append(code) def _forward_region_arg_names(self, tensors: List[Any]): """ From e9eb01a78958b7b90f3efbde644e79fae57abb52 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Nov 2021 10:45:00 +0800 Subject: [PATCH 0281/1892] add tag interface --- cube/ir/cten.py | 12 ++++++++++++ cube/schedule/su.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 13e6ea92..065df596 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -66,6 +66,7 @@ def __init__(self, self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length)] self._mirror = None + self._tag = None @property def device(self): @@ -319,6 +320,17 @@ def get_outputs(cells): outputs.append(output) return outputs + @property + def tag(self) -> Any: + return self._tag + + @tag.setter + def tag(self, info: Any): + """ + Tag an info to the cell + """ + self._tag = info + def __repr__(self): """ Cell string presentation diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 5eda73b5..6cb02e04 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -91,6 +91,8 @@ def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): self._ctrl_predecessors = list() self._ctrl_successors = list() + self._tag = [node.tag for node in nodes] + def __copy__(self): """ Copy the SU. Note the mirror su is also copied From b5904c48a872990ae577e7622a251725886c6ca3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Nov 2021 12:55:55 +0800 Subject: [PATCH 0282/1892] enable tensor parallelism: fix bugs on adapter gen --- cube/codegen/codegen.py | 128 ++++++++++--------- cube/execplan/execplan.py | 10 +- cube/runtime/__init__.py | 3 +- cube/runtime/transform.py | 2 +- cube/schedule/su.py | 2 +- cube/schedule/sugraph.py | 31 ++--- examples/linears.py | 64 +--------- examples/policy/col_parallel.py | 38 ++++++ examples/policy/pipe_parallel.py | 61 +++++++++ tests/codegen/test_partition_codegen.py | 112 +++++++++++++++++ tests/graph/test_graph_partition.py | 157 ++++++++++++++++++++++++ 11 files changed, 465 insertions(+), 143 deletions(-) create mode 100644 examples/policy/col_parallel.py create mode 100644 examples/policy/pipe_parallel.py create mode 100644 tests/codegen/test_partition_codegen.py create mode 100644 tests/graph/test_graph_partition.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 3c756c40..1d89596a 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -17,15 +17,46 @@ from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock -class ModelCodeGen: +class CodeGen: """ - Generate spatial code for the model + Generate code for the model """ - def __init__(self, execplan: ExectuionPlan): if not isinstance(execplan, ExectuionPlan): raise TypeError("execplan should be ExecutionPlan") self.execplan = execplan + + def su_naming(self, su: ScheduleUnit) -> str: + if su.stype == SUType.Forward: + return f"fwcp{su._id}" + if su.stype == SUType.Backward: + return f"bwcp{su._id}" + if su.stype == SUType.Comm: + return f"comm{su._id}" + if su.stype == SUType.Transform: + return f"trans{su._id}" + + def tensor_naming(self, tensor: Any) -> str: + """ + Return the var name (unique for different variable) + """ + if isinstance(tensor, IRTensor): + tensor_name = 'tensor' if tensor.name is None else tensor.name + if '.' in tensor_name: + tensor_name = tensor_name.split('.')[0] + name = '_'.join([tensor_name, str(tensor._id)]) + else: + name = str(tensor) + return name + + +class ModelCodeGen(CodeGen): + """ + Generate spatial code for the model + """ + + def __init__(self, execplan: ExectuionPlan): + super().__init__(execplan) # model full code self.init_code: List[str] = [ '\n\n########## Generated Model Code ###########', @@ -58,7 +89,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: for input in su.inputs(): if isinstance(input, IRTensor) and input.is_param(): continue - fargs.append(self.naming(input)) + fargs.append(self.tensor_naming(input)) for name in fargs: self.symbols.create(name) su_args.append(fargs) @@ -68,7 +99,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: for node in su.nodes(): if isinstance(node, IRTensorTransform): self.emit_transform_call(node) - if isinstance(node, IRCommunication): + elif isinstance(node, IRCommunication): self.emit_comm_call(node) else: self.emit_op_call(node) @@ -78,7 +109,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # record output tensor name for out in node.outputs(): if isinstance(out, IRTensor) or isinstance(out, str): - self.symbols.create(self.naming(out)) + self.symbols.create(self.tensor_naming(out)) self.all_su_forward_region.append(self.forward_region) self.forward_region = list() @@ -89,7 +120,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: cb.insert_body('') cb.insert_body(ib.code) for idx, su in enumerate(device_sus): - name = f'su{su._id}' + name = self.su_naming(su) input_args = ['self'] + su_args[idx] forward_code = self.all_su_forward_region[idx] with FunctionBlock(func_name=name, args=input_args) as fb: @@ -118,14 +149,14 @@ def emit_var_declare(self, var: Any): Emit tensor declaration code """ if isinstance(var, IRTensor): - name = self.naming(var) + name = self.tensor_naming(var) # emit parameter code if var.is_param() and not self.symbols.exist(name): self.symbols.create(name) code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(var.shape)}))' self.declare_region.append(code) elif isinstance(var, str): - name = self.naming(var) + name = self.tensor_naming(var) if name.startswith('self.'): if not hasattr(self._ref_module, var[5:]): if self.symbols.create(name): @@ -178,18 +209,28 @@ def emit_transform_call(self, node: IRTensorTransform): for prim in node.trace(): if isinstance(prim, SelectPrim): signature = 'cube.runtime.transform.select({tensor}, {indices}, {val_map})' - input = self.naming(prim.tensor) + input = self.tensor_naming(prim.tensor) indices = repr(prim.indices) val_map = repr(tuple([prim.val_map.idx, prim.val_map.chunk_num])) - output = self.naming(prim.output) + output = self.tensor_naming(prim.output) code = f'{output} = {signature.format(tensor=input, indices=indices, val_map=val_map)}' self.forward_region.append(code) elif isinstance(prim, MergePrim): signature = 'cube.runtime.transform.merge({tensors}, {concat}, {add})' inputs = self._forward_region_arg_names(prim.tensors) - output = self.naming(prim.output) + inputs = '(' + ', '.join(inputs) + ')' + output = self.tensor_naming(prim.output) code = f'{output} = {signature.format(tensors=inputs, concat=prim.concat, add=prim.add)}' self.forward_region.append(code) + else: + raise RuntimeError(f"Not supported prim: {type(prim)}") + for output in node.outputs(): + # contiguous and requires grad + output = self.tensor_naming(output) + code = f'{output} = {output}.contiguous()' + self.forward_region.append(code) + code = f'{output} = {output}.requires_grad_()' + self.forward_region.append(code) def _forward_region_arg_names(self, tensors: List[Any]): """ @@ -199,26 +240,13 @@ def _forward_region_arg_names(self, tensors: List[Any]): """ named_args : List[str] = list() for tensor in tensors: - name = self.naming(tensor) + name = self.tensor_naming(tensor) if isinstance(tensor, IRTensor) and tensor.is_param(): named_args.append('self.' + name) else: - named_args.append(self.naming(name)) + named_args.append(self.tensor_naming(name)) return named_args - def naming(self, tensor: Any) -> str: - """ - Return the var name (unique for different variable) - """ - if isinstance(tensor, IRTensor): - tensor_name = 'tensor' if tensor.name is None else tensor.name - if '.' in tensor_name: - tensor_name = tensor_name.split('.')[0] - name = '_'.join([tensor_name, str(tensor._id)]) - else: - name = str(tensor) - return name - def clear(self): """ Clear buffer that used for generating code @@ -232,12 +260,10 @@ def clear(self): self.symbols = SymbolTable() -class ScheduleCodeGen: +class ScheduleCodeGen(CodeGen): def __init__(self, execplan: ExectuionPlan): - if not isinstance(execplan, ExectuionPlan): - raise TypeError("execplan should be ExecutionPlan") - self.execplan = execplan + super().__init__(execplan) # model full code self.init_code: List[str] = [ '\n\n########## Generated Schedule Code ###########', @@ -256,7 +282,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: for su in device_sus: - name = f'su{su._id}' + name = self.su_naming(su) code = self.emit_su(su, name=name) fb.insert_body(code) gencode += fb.code @@ -273,30 +299,31 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: """ Emit su code """ + fsu_types = [SUType.Forward, SUType.Comm, SUType.Transform] fsign = 'cube.runtime.executor.fexecute({model}, *{inputs})' bsign = 'cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' if su.stype == SUType.Dataloader: if len(su.inputs()) != 0: raise RuntimeError("Dataloader su has no inputs") - outputs = [self.naming(output, su) for output in su.outputs()] + outputs = [self.tensor_naming(output) for output in su.outputs()] return_val = ','.join(outputs) code = f'{return_val} = next(dataloader)' return code - elif su.stype == SUType.Forward or su.stype == SUType.Comm: + elif su.stype in fsu_types: inputs = list() for tensor in su.inputs(): if isinstance(tensor, IRTensor): if tensor.is_param(): continue - inputs.append(self.naming(tensor, su)) + inputs.append(self.tensor_naming(tensor)) inputs = '(' + ', '.join(inputs + ['']) + ')' body = fsign.format( model = f'model.{name}', inputs = inputs ) - outputs = [self.naming(output, su) for output in su.outputs()] + outputs = [self.tensor_naming(output) for output in su.outputs()] return_val = ','.join(outputs) if len(su.outputs()) == 0: code = body @@ -317,17 +344,17 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: if isinstance(tensor, IRTensor): if tensor.is_param(): continue - finputs.append(self.naming(tensor, fsu)) + finputs.append(self.tensor_naming(tensor)) fargs = '(' + ', '.join(finputs + ['']) + ')' fouts = list() for tensor in fsu.outputs(): - fouts.append(self.naming(tensor, fsu)) + fouts.append(self.tensor_naming(tensor)) fouts = '(' + ', '.join(fouts + ['']) + ')' fout_grads = list() for fout in fsu.outputs(): - fout_grads.append(self.naming(fout.grad, fsu)) + fout_grads.append(self.tensor_naming(fout.grad)) fout_grads = '(' + ', '.join(fout_grads + ['']) + ')' body = bsign.format( @@ -337,7 +364,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: ) # returned value are graph.outputs - return_val = [self.naming(tensor, su) for tensor in su.outputs()] + return_val = [self.tensor_naming(tensor) for tensor in su.outputs()] # TODO: fix this by using grad attributed return_val = return_val[:len(finputs)] if len(return_val) > 0: @@ -347,23 +374,4 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: code = f'{return_code}{body}' return code else: - raise RuntimeError(f"Unsupported su tag: {su.tag}") - - def naming(self, tensor: Any, su) -> str: - """ - Return the var name (unique for different variable) - - If the var is a leaf tensor, will add prefix `self.` to its name - """ - if isinstance(tensor, IRTensor): - # note in su there is no parameters - # if len(tensor.src(su.nodes())) == 0: - # name = '*next(dataloader)' - # else: - tensor_name = 'tensor' if tensor.name is None else tensor.name - if '.' in tensor_name: - tensor_name = tensor_name.split('.')[0] - name = '_'.join([tensor_name, str(tensor._id)]) - else: - name = str(tensor) - return name + raise RuntimeError(f"Unsupported SUType: {su.stype}") diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index f0dee2ce..2fa2ded4 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -15,11 +15,11 @@ def __init__(self, sugraph: SUGraph): for su in sugraph.sus(): if len(su.device) == 0: raise RuntimeError(f"device not set: SU {su}") - device = su.device[0] - if device not in self.device_seq: - self.device_seq[device] = [su] - else: - self.device_seq[device].append(su) + for device in su.device: + if device not in self.device_seq: + self.device_seq[device] = [su] + else: + self.device_seq[device].append(su) def devices(self) -> List[int]: """ diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index c3f5515b..a2112ae0 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -1,3 +1,4 @@ -from cube.runtime import collectives, executor, device +from cube.runtime import collectives, executor, transform +from cube.runtime import device from cube.runtime import syndata from cube.runtime import resource \ No newline at end of file diff --git a/cube/runtime/transform.py b/cube/runtime/transform.py index 436f7a34..f3976501 100644 --- a/cube/runtime/transform.py +++ b/cube/runtime/transform.py @@ -30,7 +30,7 @@ def merge(tensors: List[torch.Tensor], concat: Optional[int]: the dimension to merge add: bool: whether to perform value merge """ - if (concat is not None) ^ (add is True): # xor condition + if not ((concat is not None) ^ (add is True)): # xor condition raise RuntimeError("Expected concat or add") if concat is not None: with torch.no_grad(): diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 6cb02e04..b28d796f 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -97,8 +97,8 @@ def __copy__(self): """ Copy the SU. Note the mirror su is also copied """ + raise NotImplementedError("Copy SU is not supported yet") su = ScheduleUnit(self._nodes, self.stype, self.name) - #TODO: adapter copy if self.mirror is not None: mirror_su = self.mirror mirror_su = ScheduleUnit( diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 5b93bb06..eb490a14 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -304,7 +304,7 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): raise ValueError(f"SU {su} is not in the SUGraph") if isinstance(ranks, int): ranks = [ranks] - elif not all([isinstance(int, rank) for rank in ranks]): + elif not all([isinstance(rank, int) for rank in ranks]): raise TypeError("Expected type ranks to be Union[int, List[int]]") if su.stype == SUType.Comm: @@ -314,16 +314,18 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): return True if len(ranks) != 1: - # copy su - # TODO: adatper copy - print('warning: Missing adapter copy!!') - sus = [copy.copy(su) for _ in range(len(ranks)-1)] - for su in sus: - index = self.sus().index(su) - self.sequence.insert(index, su) - SUGraph.reset_dependency(self.sequence) - for su, rank in zip(sus, ranks): - self.assign(su, rank) + if su.stype == SUType.Dataloader: + su.device = ranks + else: + raise NotImplementedError("Assign multiple ranks to one SU is not supported") + # print('warning: Missing adapter copy!!') + # sus = [copy.copy(su) for _ in range(len(ranks)-1)] + # for su in sus: + # index = self.sus().index(su) + # self.sequence.insert(index, su) + # SUGraph.reset_dependency(self.sequence) + # for su, rank in zip(sus, ranks): + # self.assign(su, rank) # set device su.device = ranks @@ -430,7 +432,7 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: for out_idx, output in enumerate(pre_su.outputs()): if output.overlap(input): sub_tensor = input.common(output) - if sub_tensor != input: + if sub_tensor != input and sub_tensor not in tensor_segments: tensor_segments.append(sub_tensor) send_op = IRCommunication( send_tensors=[sub_tensor], @@ -465,7 +467,7 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: send_adapters, recv_adapters = su.out_adapters(out_idx) for send_adapter in send_adapters: for tensor in send_adapter.nodes(0).send_tensors: - if tensor != output: + if tensor != output and tensor not in select_tensors: select_tensors.append(tensor) if len(select_tensors) != 0: select_op = IRTensorTransform( @@ -537,7 +539,8 @@ def __repr__(self): for out_idx in range(len(node.outputs())): node_list = [snode._id for snode in node.successors(out_idx)] succ_node_ids[out_idx] = node_list - dscp += f"\n{node._id}: {node} -> su id {succ_node_ids}\n" + dscp += f"{node._id}: {node}\n" + # dscp += f"\n{node._id}: {node} -> su id {succ_node_ids}\n" return dscp diff --git a/examples/linears.py b/examples/linears.py index d45dbdac..88d2f38a 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -13,69 +13,10 @@ import torch from torch import nn -import math -import random import cube -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph - - -def transform_policy(graph, resource): - """ - The transformation policy transposes linear using data parallel - """ - from cube.graph.operator.operator import IRDataOperation, IRFwOperation - for node in graph.nodes(): - if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): - algo = node.algorithms('data') - assert algo is not None - graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy - """ - fb_seqs = list() - for fsu in sugraph.fsus(): - for fb_seq in fb_seqs: - for ksu in fb_seq[::-1]: - if sugraph.happen_before(ksu, fsu): - fb_seq.append(fsu) - break - else: - continue - break - else: - fb_seqs.append([fsu]) - - # device assignment - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - - print(f'> collect {len(fb_seqs)} forward-backward sequence') - for fb_seq in fb_seqs: - chunk_num = int(math.ceil(len(fb_seq) / resource.ngpus)) - for idx, su in enumerate(fb_seq): - # devid = int(idx // chunk_num) - # devid = idx % resource.ngpus - devid = random.randint(0, resource.ngpus - 1) - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - - # set partial order - for fb_seq in fb_seqs: - fb_seq += [fsu.mirror for fsu in fb_seq][::-1] - - seqs = list() - for fb_seq in fb_seqs: - seqs += fb_seq - sugraph.partial_set_order(seqs) - return sugraph - +from examples.policy.col_parallel import transform_policy +from examples.policy.col_parallel import schedule_policy # =================== Semantic Model Description ==================== @@ -121,6 +62,7 @@ def train_iter(model, dataloader): train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() + print('====> end iteration') if __name__ == '__main__': diff --git a/examples/policy/col_parallel.py b/examples/policy/col_parallel.py new file mode 100644 index 00000000..4fbb2937 --- /dev/null +++ b/examples/policy/col_parallel.py @@ -0,0 +1,38 @@ + + + +import cube +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRDataOperation, IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using column parallel + """ + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + algo = node.algorithms('column') + if algo is None: + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + # sugraph.assign(su, list(range(resource.ngpus))) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + return sugraph diff --git a/examples/policy/pipe_parallel.py b/examples/policy/pipe_parallel.py new file mode 100644 index 00000000..d8535c2e --- /dev/null +++ b/examples/policy/pipe_parallel.py @@ -0,0 +1,61 @@ +import math +import random + +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph + + +def transform_policy(graph, resource): + """ + The transformation policy transposes linear using data parallel + """ + from cube.graph.operator.operator import IRDataOperation, IRFwOperation + for node in graph.nodes(): + if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): + algo = node.algorithms('data') + assert algo is not None + graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy + """ + fb_seqs = list() + for fsu in sugraph.fsus(): + for fb_seq in fb_seqs: + for ksu in fb_seq[::-1]: + if sugraph.happen_before(ksu, fsu): + fb_seq.append(fsu) + break + else: + continue + break + else: + fb_seqs.append([fsu]) + + # device assignment + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + + print(f'> collect {len(fb_seqs)} forward-backward sequence') + for fb_seq in fb_seqs: + chunk_num = int(math.ceil(len(fb_seq) / resource.ngpus)) + for idx, su in enumerate(fb_seq): + # devid = int(idx // chunk_num) + # devid = idx % resource.ngpus + devid = random.randint(0, resource.ngpus - 1) + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + + # set partial order + for fb_seq in fb_seqs: + fb_seq += [fsu.mirror for fsu in fb_seq][::-1] + + seqs = list() + for fb_seq in fb_seqs: + seqs += fb_seq + sugraph.partial_set_order(seqs) + return sugraph diff --git a/tests/codegen/test_partition_codegen.py b/tests/codegen/test_partition_codegen.py new file mode 100644 index 00000000..8b55805c --- /dev/null +++ b/tests/codegen/test_partition_codegen.py @@ -0,0 +1,112 @@ +from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.graph.tensor import IRFullTensor +from cube.graph.operator.function import Linear +from cube.graph.graph import IRGraph +from cube.schedule.pool import SchedulePool +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraphGener +from cube.schedule.translator import IRDataLoader + +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.redundant import RemoveRedundantAdapters +from cube.execplan.planpass.merge import MergeComputeSU + +from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen + + +def simple_linear(): + input1 = IRFullTensor(shape=[64,1024], name='data1') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = Linear( + name='linear1', + signature='torch.nn.functional.linear', + inputs= [input1, weight1, bias1], + ) + linear1.infer_shape() + + # linear2 + linear2 = Linear( + name='linear2', + signature='torch.nn.functional.linear', + inputs= [linear1.outputs(0), weight2, None], + ) + linear2.infer_shape() + + # linear3 + linear3 = Linear( + name='linear3', + signature='torch.nn.functional.linear', + inputs= [linear2.outputs(0), weight3, bias3], + ) + linear3.infer_shape() + return [input1], [linear1, linear2, linear3], [linear3.outputs(0)] + + +def test_linear_col_codegen(): + + SchedulePool().clear() + ngpus = 2 + + inputs, ops, outputs = simple_linear() + linear1, linear2, linear3, = ops + graph = IRGraph(ops, inputs, outputs, 'MLP') + print(graph) + + inputs = [inputs[0].tosub()] + loss = graph(*inputs) + loss.backward() + + nodes = SchedulePool().nodes() + fbgraph = IRGraph(nodes, None, None, 'MLPFull') + print(fbgraph) + + # replace first linear by data parallel + algo = linear1.algorithms('column') + subnodes1 = fbgraph.partition(linear1, algo, config=dict(chunk_num=ngpus)) + + algo = linear2.algorithms('column') + subnodes2 = fbgraph.partition(linear2, algo, config=dict(chunk_num=ngpus)) + + algo = linear3.algorithms('column') + subnodes3 = fbgraph.partition(linear3, algo, config=dict(chunk_num=ngpus)) + + print(fbgraph) + + sugraph = SUGraphGener.gen_sugraph(fbgraph.nodes()) + algosu1 = sugraph.fsus()[:ngpus] + for idx, su in enumerate(algosu1): + sugraph.assign(su, idx) + sugraph.assign(su.mirror, idx) + algosu2 = sugraph.fsus()[ngpus: ngpus * 2] + for idx, su in enumerate(algosu2): + sugraph.assign(su, idx) + sugraph.assign(su.mirror, idx) + algosu3 = sugraph.fsus()[ngpus * 2: ngpus * 3] + for idx, su in enumerate(algosu3): + sugraph.assign(su, idx) + sugraph.assign(su.mirror, idx) + print(sugraph) + + execplan = ExectuionPlan(sugraph) + execplan = RemoveRedundantAdapters.apply(execplan) + + execplan = MergeComputeSU.apply(execplan) + + mgener = ModelCodeGen(execplan) + tgener = ScheduleCodeGen(execplan) + + for devid in range(ngpus): + mcode0 = mgener.gen(device=devid, outfile=f'test{devid}.py') + tcode0 = tgener.gen(device=devid, outfile=f'test{devid}.py', attach=True) + print(f'===> model code on device {devid}: ') + print(mcode0) + print(f'===> schedule code on device {devid}: ') + print(tcode0) + + assert False \ No newline at end of file diff --git a/tests/graph/test_graph_partition.py b/tests/graph/test_graph_partition.py new file mode 100644 index 00000000..f3a348f5 --- /dev/null +++ b/tests/graph/test_graph_partition.py @@ -0,0 +1,157 @@ +import enum +from cube.graph.graph import IRGraph +from cube.graph.tensor import IRFullTensor, ValueMap +from cube.graph.operator.function import Linear, ElementWise +from cube.schedule.pool import SchedulePool +from cube.schedule.sugraph import SUGraphGener + + +def simple_linear(): + input1 = IRFullTensor(shape=[64,1024], name='data1') + weight1 = IRFullTensor(shape=[1024, 1024], name='weight') + bias1 = IRFullTensor(shape=[1024, 1024], name='bias') + weight2 = IRFullTensor(shape=[1024, 1024], name='weight') + weight3 = IRFullTensor(shape=[1024, 1024], name='weight') + bias3 = IRFullTensor(shape=[1024, 1024], name='bias') + + # linear1 + linear1 = Linear( + name='linear1', + signature='torch.nn.functional.linear', + inputs= [input1, weight1, bias1], + ) + linear1.infer_shape() + + # linear2 + linear2 = Linear( + name='linear2', + signature='torch.nn.functional.linear', + inputs= [linear1.outputs(0), weight2, None], + ) + linear2.infer_shape() + + # linear3 + linear3 = Linear( + name='linear3', + signature='torch.nn.functional.linear', + inputs= [linear2.outputs(0), weight3, bias3], + ) + linear3.infer_shape() + return [input1], [linear1, linear2, linear3], [linear3.outputs(0)] + + +def test_linear_dp_partition(): + + SchedulePool().clear() + + inputs, ops, outputs = simple_linear() + linear1, linear2, linear3, = ops + graph = IRGraph(ops, inputs, outputs, 'MLP') + print(graph) + + inputs = [inputs[0].tosub()] + loss = graph(*inputs) + loss.backward() + + nodes = SchedulePool().nodes() + fbgraph = IRGraph(nodes, None, None, 'MLPFull') + print(fbgraph) + + # replace first linear by data parallel + algo = linear1.algorithms('data') + subnodes = fbgraph.partition(linear1, algo, config=dict(chunk_num=4)) + + algo = linear2.algorithms('data') + subnodes = fbgraph.partition(linear2, algo, config=dict(chunk_num=4)) + + algo = linear3.algorithms('data') + subnodes = fbgraph.partition(linear3, algo, config=dict(chunk_num=4)) + + print(fbgraph) + for node in subnodes: + print(node) + print(node.mirror) + # assert False + +def test_linear_hybrid_partition(): + + SchedulePool().clear() + ngpus = 2 + + inputs, ops, outputs = simple_linear() + linear1, linear2, linear3, = ops + graph = IRGraph(ops, inputs, outputs, 'MLP') + print(graph) + + inputs = [inputs[0].tosub()] + loss = graph(*inputs) + loss.backward() + + nodes = SchedulePool().nodes() + fbgraph = IRGraph(nodes, None, None, 'MLPFull') + print(fbgraph) + + # replace first linear by data parallel + algo = linear1.algorithms('column') + subnodes1 = fbgraph.partition(linear1, algo, config=dict(chunk_num=ngpus)) + + algo = linear2.algorithms('column') + subnodes2 = fbgraph.partition(linear2, algo, config=dict(chunk_num=ngpus)) + + algo = linear3.algorithms('column') + subnodes3 = fbgraph.partition(linear3, algo, config=dict(chunk_num=ngpus)) + + print(fbgraph) + # for node in subnodes: + # print(node) + # print(node.mirror) + + sugraph = SUGraphGener.gen_sugraph(fbgraph.nodes()) + algosu1 = sugraph.fsus()[:ngpus] + for idx, su in enumerate(algosu1): + sugraph.assign(su, idx) + sugraph.assign(su.mirror, idx) + algosu2 = sugraph.fsus()[ngpus: ngpus * 2] + for idx, su in enumerate(algosu2): + sugraph.assign(su, idx) + sugraph.assign(su.mirror, idx) + algosu3 = sugraph.fsus()[ngpus * 2: ngpus * 3] + for idx, su in enumerate(algosu3): + sugraph.assign(su, idx) + sugraph.assign(su.mirror, idx) + print(sugraph) + + print('===== algo 1 =====') + for idx, su in enumerate(algosu1): + print('F:', su) + print('B:', su.mirror) + data_grad = su.mirror.outputs(0) + data_grad_ref = su.inputs(0).get_grad(su.nodes(0)) + print('grad :', data_grad) + print('grad ref:', data_grad_ref) + assert data_grad == data_grad_ref + assert data_grad.val_map == ValueMap(idx, ngpus) + + print('===== algo 2 =====') + for idx, su in enumerate(algosu2): + print('F:', su) + print('B:', su.mirror) + data_grad = su.mirror.outputs(0) + data_grad_ref = su.inputs(0).get_grad(su.nodes(0)) + print('grad :', data_grad) + print('grad ref:', data_grad_ref) + assert data_grad == data_grad_ref + assert data_grad.val_map == ValueMap(idx, ngpus) + + print('===== algo 3 =====') + for idx, su in enumerate(algosu3): + print('F:', su) + print('B:', su.mirror) + data_grad = su.mirror.outputs(0) + data_grad_ref = su.inputs(0).get_grad(su.nodes(0)) + print('grad :', data_grad) + print('grad ref:', data_grad_ref) + assert data_grad == data_grad_ref + assert data_grad.val_map == ValueMap(idx, ngpus) + + assert False From 5ce2157df88b5486fe018b8a42b6a541157ad9f0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Nov 2021 13:41:49 +0800 Subject: [PATCH 0283/1892] enable all kinds of parallelism except data parallel --- cube/codegen/codegen.py | 2 ++ examples/linears.py | 4 +-- examples/policy/hybrid_parallel.py | 39 ++++++++++++++++++++++++++++++ examples/policy/no_parallel.py | 23 ++++++++++++++++++ examples/policy/row_parallel.py | 34 ++++++++++++++++++++++++++ 5 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 examples/policy/hybrid_parallel.py create mode 100644 examples/policy/no_parallel.py create mode 100644 examples/policy/row_parallel.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 1d89596a..a9c6de78 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -281,6 +281,8 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: + if len(device_sus) == 0: + fb.insert_body('pass') for su in device_sus: name = self.su_naming(su) code = self.emit_su(su, name=name) diff --git a/examples/linears.py b/examples/linears.py index 88d2f38a..4a6b4ba3 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -15,8 +15,8 @@ from torch import nn import cube -from examples.policy.col_parallel import transform_policy -from examples.policy.col_parallel import schedule_policy +from examples.policy.hybrid_parallel import transform_policy +from examples.policy.hybrid_parallel import schedule_policy # =================== Semantic Model Description ==================== diff --git a/examples/policy/hybrid_parallel.py b/examples/policy/hybrid_parallel.py new file mode 100644 index 00000000..dd673e4f --- /dev/null +++ b/examples/policy/hybrid_parallel.py @@ -0,0 +1,39 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using column parallel + """ + for node in graph.nodes(): + idx = 0 + if isinstance(node, IRFwOperation): + algo = None + if idx % 2 == 0: + algo = node.algorithms('column') + else: + algo = node.algorithms('row') + if algo is None: + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + # sugraph.assign(su, list(range(resource.ngpus))) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + return sugraph diff --git a/examples/policy/no_parallel.py b/examples/policy/no_parallel.py new file mode 100644 index 00000000..702c02b7 --- /dev/null +++ b/examples/policy/no_parallel.py @@ -0,0 +1,23 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using column parallel + """ + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + for su in sugraph.fsus(): + sugraph.assign(su, 0) + sugraph.assign(su.mirror, 0) + return sugraph diff --git a/examples/policy/row_parallel.py b/examples/policy/row_parallel.py new file mode 100644 index 00000000..0c1198fa --- /dev/null +++ b/examples/policy/row_parallel.py @@ -0,0 +1,34 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using column parallel + """ + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + algo = node.algorithms('row') + if algo is None: + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + # sugraph.assign(su, list(range(resource.ngpus))) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + return sugraph From 7af594b67bb532dffbb79ed284118a64b1907df5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Nov 2021 15:09:57 +0800 Subject: [PATCH 0284/1892] fix policy error --- examples/policy/hybrid_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/policy/hybrid_parallel.py b/examples/policy/hybrid_parallel.py index dd673e4f..5de96ce1 100644 --- a/examples/policy/hybrid_parallel.py +++ b/examples/policy/hybrid_parallel.py @@ -18,6 +18,7 @@ def transform_policy(graph: IRGraph, resource): algo = node.algorithms('row') if algo is None: algo = node.algorithms('data') + idx += 1 sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx From 384dcdf861a93ed62b33a2ff0e04b96ccc44d13d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Nov 2021 15:40:39 +0800 Subject: [PATCH 0285/1892] add reducer for weight gradient sync --- cube/runtime/reducer.py | 77 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 cube/runtime/reducer.py diff --git a/cube/runtime/reducer.py b/cube/runtime/reducer.py new file mode 100644 index 00000000..a946ffbe --- /dev/null +++ b/cube/runtime/reducer.py @@ -0,0 +1,77 @@ +""" +Borrowed from Megatron Implementation +""" + +from typing import List +import torch + +from cube.runtime.device import DeviceGroup + + +class Reducer: + + def __init__(self, ranks: List[int]): + + self._params: List[torch.nn.Parameter] = list() + # note this need to be called for every device + self.ranks = ranks + self._group = DeviceGroup().get_group(ranks) + + def add_param(self, param: torch.nn.Parameter): + self._params.append(param) + + def allreduce(self): + """ + Reduce gradients across given group + """ + buckets = {} + for param in self._params: + if param.requires_grad and param.grad is not None: + tp = param.data.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + # TODO: figure out why Megatron needs this? + # param.main_grad = param.grad + # for each bucket, do all-reduce + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = self._flatten_dense_tensors(grads) + coalesced /= len(self.ranks) + torch.distributed.all_reduce(coalesced, group=self._group) + all_synced = self._unflatten_dense_tensors(coalesced, grads) + for grad, synced in zip(grads, all_synced): + grad.copy_(synced) + + def _flatten_dense_tensors(self, tensors): + """ + Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of + same dense type. + + Since inputs are dense, the resulting tensor will be a concatenated 1D + buffer. Element-wise operation on this buffer will be equivalent to + operating individually. + + Args: + tensors (Iterable[Tensor]): dense tensors to flatten. + Returns: + A contiguous 1D buffer containing input tensors. + """ + return torch._C._nn.flatten_dense_tensors(tensors) + + def _unflatten_dense_tensors(self, flat, tensors): + """ + View a flat buffer using the sizes of tensors. Assume that tensors are of + same dense type, and that flat is given by _flatten_dense_tensors. + + Args: + flat (Tensor): flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): dense tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened dense tensors with sizes same as tensors and values from + flat. + """ + return torch._C._nn.unflatten_dense_tensors(flat, tensors) From 917bf5e00d5c09e6d1d28d48db41690dcbb589b5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Nov 2021 17:33:55 +0800 Subject: [PATCH 0286/1892] init data parallel --- cube/codegen/codegen.py | 24 +++++++++++- cube/execplan/planpass/gfuse.py | 69 +++++++++++++++++++++++++++++++++ cube/graph/operator/operator.py | 19 ++++++++- cube/schedule/su.py | 4 +- cube/schedule/sugraph.py | 6 +++ 5 files changed, 118 insertions(+), 4 deletions(-) create mode 100644 cube/execplan/planpass/gfuse.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index a9c6de78..cea8946c 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -5,6 +5,7 @@ from numpy import isin import torch import copy +from cube.graph.operator.operator import IRFwOperation, IROptimOperation from cube.ir.cten import IRTensor from cube.execplan import ExectuionPlan @@ -35,6 +36,8 @@ def su_naming(self, su: ScheduleUnit) -> str: return f"comm{su._id}" if su.stype == SUType.Transform: return f"trans{su._id}" + if su.stype == SUType.Optimizer: + return f"optim{su._id}" def tensor_naming(self, tensor: Any) -> str: """ @@ -97,12 +100,14 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # parse graph body for su in device_sus: for node in su.nodes(): + if isinstance(node, IRFwOperation): + self.emit_op_call(node) if isinstance(node, IRTensorTransform): self.emit_transform_call(node) elif isinstance(node, IRCommunication): self.emit_comm_call(node) - else: - self.emit_op_call(node) + elif isinstance(node, IROptimOperation): + self.emit_optim_call(node) # emit input declaration for arg in node.inputs(): self.emit_var_declare(arg) @@ -232,6 +237,21 @@ def emit_transform_call(self, node: IRTensorTransform): code = f'{output} = {output}.requires_grad_()' self.forward_region.append(code) + def emit_optim_call(self, node: IROptimOperation): + ranks = node.ranks + grads = node.inputs() + reducer_name = f'self.reducer{node._id}' + # create reducer in declare region + init_code = f'{reducer_name} = cube.runtime.reducer.Reducer({ranks})' + self.declare_region.append(init_code) + grads = self._forward_region_arg_names(grads) + for grad in grads: + add_param_code = f'{reducer_name}.add_param({grad})' + self.declare_region.append(add_param_code) + # create call in forward region + call_code = f'{reducer_name}.allreduce()' + self.forward_region.append(call_code) + def _forward_region_arg_names(self, tensors: List[Any]): """ Generate arg name list for forward region. diff --git a/cube/execplan/planpass/gfuse.py b/cube/execplan/planpass/gfuse.py new file mode 100644 index 00000000..3fe0e97e --- /dev/null +++ b/cube/execplan/planpass/gfuse.py @@ -0,0 +1,69 @@ +""" +Gradient Allreduce Fusion +""" +from typing import Dict, Tuple, List +from cube.graph.operator.operator import IROptimOperation + +from cube.graph.tensor import IRSubTensor + +from cube.execplan import ExectuionPlan +from cube.schedule.su import SUType, ScheduleUnit +from cube.execplan.planpass.planpass import PlanPass + + +class WeightGradAllreduceFusion(PlanPass): + + @staticmethod + def apply(execplan: ExectuionPlan) -> ExectuionPlan: + """ + Apply weight gradient allreduce fusion + """ + reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() + params = WeightGradAllreduceFusion._get_weight_grads(execplan) + for param in params: + grads = params[param] + ranks = tuple(grads.keys()) # ranks are used for group + grads = [grads[devid][-1] for devid in grads] + if len(ranks) == 1: + continue + if ranks not in reducers: + reducers[ranks] = list() + for grad in grads: + reducers[ranks].append(grad) + # generate reducer for each rank + for ranks in reducers: + grads = reducers[ranks] + # even though some ranks don't need allreduce, + # pytorch still requires each rank simutaneously call the + # communication group initialization + for devid in execplan.devices(): + opt_op = IROptimOperation(grads, ranks) + reduce_su = ScheduleUnit([opt_op], SUType.Optimizer) + reduce_su.device = devid + execplan.at(devid).append(reduce_su) + return execplan + + @staticmethod + def _get_weight_grads(execplan: ExectuionPlan) -> Dict: + """ + Get weight gradient + + Return Dict[IRSubTensor, Dict[int, List[IRSubTensor]]] + (grads = params[param][device]) + """ + # grad = params[param][device] + params = dict() + for devid in execplan.devices(): + bsus = [su for su in execplan.sequence(devid) if su.stype == SUType.Backward] + for bsu in bsus: + # bsu has only one node + for input in bsu.inputs(): + if isinstance(input, IRSubTensor) and input.is_param(): + if input not in params: + params[input] = {devid : list()} + grad = input.grad + assert grad is not None + if grad in params[input][devid]: + raise RuntimeError("Already logged grad?") + params[input][devid].append(grad) + return params diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 0be30a5e..d180e79c 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,11 +1,12 @@ from typing import Any, Optional, Union, List +import copy from cube.ir.cten import IRTensor, IRCell from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory -__call__ = ['IRFwOperation', 'IRBpOperation'] +__all__ = ['IRFwOperation', 'IRBpOperation', 'IRDataOperation', 'IROptimOperation'] class IRFwOperation(IRCell): @@ -227,3 +228,19 @@ def algorithms(self, tag: Optional[str] = None): return None template = factory.algorithms(type(self), tag) return template(self) + + +class IROptimOperation(IRCell): + + def __init__(self, grads: List[IRSubTensor], ranks: List[int], name='optimizer'): + if not all([isinstance(grad, IRSubTensor) and grad.is_grad() for grad in grads]): + raise RuntimeError("Expected a list of gradient IRSubTensor") + if not all([isinstance(rank, int) for rank in ranks]): + raise RuntimeError("Expected a list of int") + signature = None + self._ranks = ranks + super.__init__(name, signature, len(grads), 0) + + @property + def ranks(self): + return copy.copy(self._ranks) diff --git a/cube/schedule/su.py b/cube/schedule/su.py index b28d796f..2c87fe30 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -18,13 +18,15 @@ class SUType(Enum): # ) Backward = 'cube.runtime.executor.backward' - Transform = 'cube.runtime.adapter.transform' + Transform = 'cube.runtime.transform' # cube.runtime.collectives.sendrecv(send_tensors, send_ranks, # recv_shapes, from_ranks # ) Comm = 'cube.runtime.adapter.sendrecv' + Optimizer = 'cube.runtime.reducer.Reduce' + Empty = 'None' diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index eb490a14..c54c5d0f 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -106,6 +106,12 @@ def sus(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") + def get_sus(self, stype: SUType) -> List[ScheduleUnit]: + """ + Get SUs that are of stype + """ + return [su for su in self.sequence if su.stype == stype] + def fsus(self) -> List[ScheduleUnit]: """ Get forward ScheduleUnits sequence. From 2d66b3b085753f2d9e5d670370d1c056f9f30483 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Nov 2021 22:56:16 +0800 Subject: [PATCH 0287/1892] support data parallel by adding reducer --- cube/codegen/codegen.py | 8 +++--- cube/compiler.py | 4 ++- cube/execplan/planpass/gfuse.py | 48 ++++++++++++++++++++++----------- cube/execplan/planpass/merge.py | 41 ++++++++++++---------------- cube/graph/gpass.py | 6 ++--- cube/graph/graph.py | 6 ++--- cube/graph/operator/operator.py | 9 ++++--- cube/runtime/__init__.py | 1 + examples/policy/col_parallel.py | 6 +---- examples/policy/dp_parallel.py | 33 +++++++++++++++++++++++ 10 files changed, 103 insertions(+), 59 deletions(-) create mode 100644 examples/policy/dp_parallel.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index cea8946c..8640cabd 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -238,11 +238,11 @@ def emit_transform_call(self, node: IRTensorTransform): self.forward_region.append(code) def emit_optim_call(self, node: IROptimOperation): - ranks = node.ranks + ranks = list(node.ranks) grads = node.inputs() reducer_name = f'self.reducer{node._id}' # create reducer in declare region - init_code = f'{reducer_name} = cube.runtime.reducer.Reducer({ranks})' + init_code = f'{reducer_name} = cube.runtime.reducer.Reducer(ranks={ranks})' self.declare_region.append(init_code) grads = self._forward_region_arg_names(grads) for grad in grads: @@ -264,7 +264,7 @@ def _forward_region_arg_names(self, tensors: List[Any]): if isinstance(tensor, IRTensor) and tensor.is_param(): named_args.append('self.' + name) else: - named_args.append(self.tensor_naming(name)) + named_args.append(name) return named_args def clear(self): @@ -321,7 +321,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: """ Emit su code """ - fsu_types = [SUType.Forward, SUType.Comm, SUType.Transform] + fsu_types = [SUType.Forward, SUType.Comm, SUType.Transform, SUType.Optimizer] fsign = 'cube.runtime.executor.fexecute({model}, *{inputs})' bsign = 'cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' diff --git a/cube/compiler.py b/cube/compiler.py index 802ef4dd..6a895583 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -11,6 +11,7 @@ from cube.execplan import ExectuionPlan from cube.execplan.planpass.redundant import RemoveRedundantAdapters from cube.execplan.planpass.merge import MergeComputeSU +from cube.execplan.planpass.gfuse import WeightGradAllreduceFusion from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -141,8 +142,9 @@ def decorator(fn: Callable) -> Callable: execplan = ExectuionPlan(sugraph) # plan pass to remove redundant sus execplan = RemoveRedundantAdapters.apply(execplan) - # print(f'> after remove redundant adapters:\n {execplan}') + print(f'> after remove redundant adapters:\n {execplan}') execplan = MergeComputeSU.apply(execplan) + execplan = WeightGradAllreduceFusion.apply(execplan) print(f'> after merge compute SU:\n{execplan}') if torch.distributed.is_initialized(): diff --git a/cube/execplan/planpass/gfuse.py b/cube/execplan/planpass/gfuse.py index 3fe0e97e..f17695c8 100644 --- a/cube/execplan/planpass/gfuse.py +++ b/cube/execplan/planpass/gfuse.py @@ -2,8 +2,11 @@ Gradient Allreduce Fusion """ from typing import Dict, Tuple, List -from cube.graph.operator.operator import IROptimOperation +import sys +import copy + +from cube.graph.operator.operator import IROptimOperation from cube.graph.tensor import IRSubTensor from cube.execplan import ExectuionPlan @@ -19,25 +22,31 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: Apply weight gradient allreduce fusion """ reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() - params = WeightGradAllreduceFusion._get_weight_grads(execplan) - for param in params: - grads = params[param] + weights, params = WeightGradAllreduceFusion._get_weight_grads(execplan) + for param_id in params: + grads = params[param_id] ranks = tuple(grads.keys()) # ranks are used for group - grads = [grads[devid][-1] for devid in grads] if len(ranks) == 1: continue + grads_num = [len(grads[devid]) for devid in grads] + if len(set(grads_num)) > 1: + sys.stderr.write("May require weighted allreduce!\n") if ranks not in reducers: reducers[ranks] = list() - for grad in grads: - reducers[ranks].append(grad) + reducers[ranks].append(weights[param_id]) # generate reducer for each rank for ranks in reducers: - grads = reducers[ranks] + weights = reducers[ranks] # even though some ranks don't need allreduce, # pytorch still requires each rank simutaneously call the # communication group initialization for devid in execplan.devices(): - opt_op = IROptimOperation(grads, ranks) + dev_weights = copy.copy(weights) + for idx, weight in enumerate(dev_weights): + if devid not in params[weight._id]: + dev_weights[idx] = None + dev_weights = [w for w in dev_weights if w is not None] + opt_op = IROptimOperation(dev_weights, ranks) reduce_su = ScheduleUnit([opt_op], SUType.Optimizer) reduce_su.device = devid execplan.at(devid).append(reduce_su) @@ -52,18 +61,25 @@ def _get_weight_grads(execplan: ExectuionPlan) -> Dict: (grads = params[param][device]) """ # grad = params[param][device] - params = dict() + grads = dict() + weights = dict() for devid in execplan.devices(): bsus = [su for su in execplan.sequence(devid) if su.stype == SUType.Backward] for bsu in bsus: # bsu has only one node for input in bsu.inputs(): if isinstance(input, IRSubTensor) and input.is_param(): - if input not in params: - params[input] = {devid : list()} + if input._id not in grads: + grads[input._id] = dict() + weights[input._id] = input + if devid not in grads[input._id]: + grads[input._id][devid] = list() grad = input.grad - assert grad is not None - if grad in params[input][devid]: + if grad is None: + print(input.name, input) + print(grad) + assert grad is not None + if grad in grads[input._id][devid]: raise RuntimeError("Already logged grad?") - params[input][devid].append(grad) - return params + grads[input._id][devid].append(grad) + return weights, grads diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index fe54f073..ebf58d17 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -11,13 +11,14 @@ class MergeComputeSU(PlanPass): @staticmethod def apply(execplan: ExectuionPlan) -> ExectuionPlan: """ - Merge consecutive forward SUs + Merge consecutive backward SUs. The forward SUs will + also be merged if possible """ for devid in execplan.devices(): - dev_seq = execplan.sequence(devid) + dev_seq = execplan.sequence(devid) + [None] pieces: List[ScheduleUnit] = list() - for seqidx, su in enumerate(execplan.sequence(devid)): - if su.stype in [SUType.Forward]: + for seqidx, su in enumerate(dev_seq): + if su and su.stype in [SUType.Backward]: allow_merge = len(pieces) == 0 for psu in pieces[::-1]: if execplan.sugraph.happen_before(psu, su): @@ -25,31 +26,23 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: break if allow_merge: dev_seq[seqidx] = None - if su.mirror is not None: - if su.mirror not in dev_seq: - raise RuntimeError( - "Expected backward and forward on same device") - idx = dev_seq.index(su.mirror) - dev_seq[idx] = None pieces.append(su) continue - # merge pieces + # merged forward su if len(pieces) > 0: - # merged forward su - mfsu = MergeComputeSU._merge(pieces, devid) + fsus = [bsu.mirror for bsu in pieces][::-1] + if not all([fsu and (fsu in dev_seq) for fsu in fsus]): + raise RuntimeError("Expected same device fw-bw") + mfsu = MergeComputeSU._merge(fsus, devid) mbsu = mfsu.mirror - # insert merged forward su - dev_seq[seqidx-1] = mfsu # insert merged backward su - bidx = len(dev_seq) - for fsu in pieces: - bsu = fsu.mirror - if bsu is not None: - idx = execplan.sequence(devid).index(bsu) - dev_seq[idx] = None - bidx = min(bidx, idx) - if bidx != len(dev_seq): - dev_seq[bidx] = mbsu + dev_seq[seqidx-1] = mbsu + fsus_idx = [dev_seq.index(fsu) for fsu in fsus] + # insert merged forward su + if max(fsus_idx) - min(fsus_idx) == len(fsus) - 1: + for fidx in fsus_idx: + dev_seq[fidx] = None + dev_seq[min(fsus_idx)] = mfsu pieces = list() dev_seq = [su for su in dev_seq if su is not None] execplan.set(devid, dev_seq) diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index b5fc7a84..e264a368 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -81,14 +81,14 @@ def forward(graph, *args) -> IRGraph: bnode = IRBpOperation(data_num=len(inputs), grad_num=len(outputs)) # set backward grad for idx, val in enumerate(fnode.inputs()): - # set input - bnode.set_data(idx, val) - # set gradient output grad = None if isinstance(val, IRSubTensor): # TODO: requires_grad = False should be set to None grad = val.get_grad(fnode) val.grad = grad + # set input + bnode.set_data(idx, val) + # set gradient output bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): # set gradient input diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 36857665..c2a9af4f 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -293,17 +293,17 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional grad_num=len(fnode.outputs()) ) for idx, val in enumerate(fnode.inputs()): - bnode.set_data(idx, val) grad = None if isinstance(val, IRSubTensor): - # TODO: remove grad is grad doesn't require it + # TODO: requires_grad = False should be set to None grad = val.get_grad(fnode) val.grad = grad + bnode.set_data(idx, val) bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): grad = None if isinstance(val, IRSubTensor): - # TODO: remove grad is grad doesn't require it + # TODO: requires_grad = False should be set to None grad = val.get_grad(fnode) val.grad = grad bnode.set_grad(idx, grad) diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index d180e79c..a4d871b5 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -232,14 +232,17 @@ def algorithms(self, tag: Optional[str] = None): class IROptimOperation(IRCell): - def __init__(self, grads: List[IRSubTensor], ranks: List[int], name='optimizer'): - if not all([isinstance(grad, IRSubTensor) and grad.is_grad() for grad in grads]): + def __init__(self, weights: List[IRSubTensor], ranks: List[int], name='optimizer'): + if not all([isinstance(w, IRSubTensor) and w.is_param() for w in weights]): raise RuntimeError("Expected a list of gradient IRSubTensor") if not all([isinstance(rank, int) for rank in ranks]): raise RuntimeError("Expected a list of int") signature = None self._ranks = ranks - super.__init__(name, signature, len(grads), 0) + + super().__init__(name, signature, len(weights), 0) + for idx, weight in enumerate(weights): + self.set_input(idx, weight) @property def ranks(self): diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index a2112ae0..c66a7d0c 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -1,4 +1,5 @@ from cube.runtime import collectives, executor, transform from cube.runtime import device +from cube.runtime import reducer from cube.runtime import syndata from cube.runtime import resource \ No newline at end of file diff --git a/examples/policy/col_parallel.py b/examples/policy/col_parallel.py index 4fbb2937..98be31f7 100644 --- a/examples/policy/col_parallel.py +++ b/examples/policy/col_parallel.py @@ -1,11 +1,7 @@ - - - -import cube from cube.graph import IRGraph from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.graph.operator.operator import IRFwOperation def transform_policy(graph: IRGraph, resource): diff --git a/examples/policy/dp_parallel.py b/examples/policy/dp_parallel.py new file mode 100644 index 00000000..36f52d3d --- /dev/null +++ b/examples/policy/dp_parallel.py @@ -0,0 +1,33 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using column parallel + """ + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + algo = node.algorithms('data') + assert algo + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + # sugraph.assign(su, list(range(resource.ngpus))) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + return sugraph From 6370e80a4dac0b23a5bad6dcb03b4a28da0318d6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Nov 2021 23:03:42 +0800 Subject: [PATCH 0288/1892] fix value map bug --- cube/schedule/adapter/transform.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cube/schedule/adapter/transform.py b/cube/schedule/adapter/transform.py index bc767620..968e4698 100644 --- a/cube/schedule/adapter/transform.py +++ b/cube/schedule/adapter/transform.py @@ -280,12 +280,9 @@ def add(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: if chunk_num % 2 != 0: return None chunk_num = int(chunk_num // 2) - if chunk_num == 1: - idx = 0 - else: - if int(idx1 // chunk_num) != int(idx2 // chunk_num): - return None - idx = int(idx1 // chunk_num) + if int(idx1 // 2) != int(idx2 // 2): + return None + idx = int(idx1 // 2) mtensor = tensor1.parent.select( indices = tensor1.indices, val_map = (idx, chunk_num), From c3e56993d396a5753abdc5c7a0fafc3d59c7bc5c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 13 Nov 2021 00:10:07 +0800 Subject: [PATCH 0289/1892] remove requires_grad for gradient --- cube/codegen/codegen.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 8640cabd..78fb248d 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -2,7 +2,6 @@ Generate Pytorch code given the model DAG and the transformation config """ from typing import List, Any -from numpy import isin import torch import copy from cube.graph.operator.operator import IRFwOperation, IROptimOperation @@ -231,11 +230,12 @@ def emit_transform_call(self, node: IRTensorTransform): raise RuntimeError(f"Not supported prim: {type(prim)}") for output in node.outputs(): # contiguous and requires grad - output = self.tensor_naming(output) - code = f'{output} = {output}.contiguous()' - self.forward_region.append(code) - code = f'{output} = {output}.requires_grad_()' + output_name = self.tensor_naming(output) + code = f'{output_name} = {output_name}.contiguous()' self.forward_region.append(code) + if not output.is_grad(): + code = f'{output_name} = {output_name}.requires_grad_()' + self.forward_region.append(code) def emit_optim_call(self, node: IROptimOperation): ranks = list(node.ranks) From f3f7b3a026648d26de8c0f3df3afbe86e11d658c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 13 Nov 2021 15:16:48 +0800 Subject: [PATCH 0290/1892] commit set up profiler --- cube/profiler/__init__.py | 1 + cube/{utils.py => profiler/timer.py} | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 cube/profiler/__init__.py rename cube/{utils.py => profiler/timer.py} (91%) diff --git a/cube/profiler/__init__.py b/cube/profiler/__init__.py new file mode 100644 index 00000000..5649c2ee --- /dev/null +++ b/cube/profiler/__init__.py @@ -0,0 +1 @@ +from cube.profiler.timer import CudaTimer \ No newline at end of file diff --git a/cube/utils.py b/cube/profiler/timer.py similarity index 91% rename from cube/utils.py rename to cube/profiler/timer.py index 5bd9d229..0cbb31b7 100644 --- a/cube/utils.py +++ b/cube/profiler/timer.py @@ -1,6 +1,7 @@ +import time import sys + import torch -import time def print_each_rank(msg, rank_only=None, outfile=''): @@ -23,11 +24,9 @@ def print_each_rank(msg, rank_only=None, outfile=''): class CudaTimer: - - """ + r""" Singleton Timer """ - class __CudaTimer: def __init__(self): @@ -43,6 +42,11 @@ def __init__(self): CudaTimer.instance = CudaTimer.__CudaTimer() def start(self, field_name='default'): + """ + Start recording time on the the field + + Note `start` and `stop` on the same field can be called nestly + """ torch.cuda.synchronize() if field_name not in CudaTimer.instance.field: CudaTimer.instance.field[field_name] = list() @@ -51,7 +55,10 @@ def start(self, field_name='default'): def stop(self, field_name='default'): """ - Return in ms + Return the time span from last `start` on the smae field name to now + + Returns: + float (ms) """ if field_name not in CudaTimer.instance.field: raise RuntimeError("Missing start on the field") From 4529e1849ba8771840feb9c0aacd58ffdf1827f4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 13 Nov 2021 15:22:39 +0800 Subject: [PATCH 0291/1892] add readme --- README.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 00000000..54faa364 --- /dev/null +++ b/README.md @@ -0,0 +1,25 @@ +# MagicCube + +AI System Compiler to compile a semantic (single-device) model to distributed model using policies specified by System Expert. + +## Install + +```python +pip install -r requirements.txt +python setup.py develop +``` + +## Run Examples + +* [Micro Benchmark] Run a mutiple MLP Model + +```sh +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/linears.py +``` From db0d71020a51e879804425c140aef781739701a3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 13 Nov 2021 15:24:39 +0800 Subject: [PATCH 0292/1892] fix bug: unneccessary grad scale --- cube/codegen/codegen.py | 4 ++++ cube/runtime/reducer.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 78fb248d..4c5691e6 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -376,6 +376,10 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: fout_grads = list() for fout in fsu.outputs(): + # the loss computed starting point + if fout == fout.grad: + #TODO: mean<0, N> needs to divide by N times + pass fout_grads.append(self.tensor_naming(fout.grad)) fout_grads = '(' + ', '.join(fout_grads + ['']) + ')' diff --git a/cube/runtime/reducer.py b/cube/runtime/reducer.py index a946ffbe..37125233 100644 --- a/cube/runtime/reducer.py +++ b/cube/runtime/reducer.py @@ -38,7 +38,7 @@ def allreduce(self): bucket = buckets[tp] grads = [param.grad.data for param in bucket] coalesced = self._flatten_dense_tensors(grads) - coalesced /= len(self.ranks) + # coalesced /= len(self.ranks) torch.distributed.all_reduce(coalesced, group=self._group) all_synced = self._unflatten_dense_tensors(coalesced, grads) for grad, synced in zip(grads, all_synced): From 88a084c3f858e40c31d3a0696fff195a70339c18 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 13 Nov 2021 16:35:43 +0800 Subject: [PATCH 0293/1892] operator repr info update --- cube/graph/graph.py | 4 ++-- cube/graph/operator/operator.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index c2a9af4f..bb670c79 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -271,10 +271,10 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional nodes: List[IRCell] if partitioned successfully. None if failed """ - if not isinstance(op, IRCell): - raise TypeError("Expected op to be IRCell (IRFwOperation)") if not isinstance(algo, GenericDistAlgo): raise TypeError("Expected algo to be GenericDistAlgo") + if op not in self.nodes(): + raise RuntimeError("Not Exist: {op}") if algo.logic_op != type(op): return None diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index a4d871b5..136c5128 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -92,7 +92,8 @@ def __repr__(self): else: outputs.append(tensor) - dscp = f'Op(id={self._id}, signature={self.signature}, device={self.device}, inputs={inputs}, outputs={outputs})' + sign = self.signature.split('.')[-1] + dscp = f'Op{self._id}(sign={sign}, inputs={inputs}, outputs={outputs})' return dscp @@ -196,7 +197,8 @@ def __repr__(self): else: outputs.append(tensor) - dscp = f'bOp(id={self._id}, signature={self.signature}, device={self.device}, grads={grads}, datas={datas}, outputs={outputs})' + sign = self.signature.split('.')[-1] + dscp = f'bOp{self._id}(sign={sign}, grads={grads}, datas={datas}, outputs={outputs})' return dscp @@ -207,6 +209,12 @@ def __init__(self, data_num: int, name='dataloader'): signature = 'dataloader.__next__' super().__init__(name, signature, 0, data_num) + def infer_shape(self): + """ + Infer output value shape + """ + return True + def algorithms(self, tag: Optional[str] = None): """ get algorithm from algorithm factory From 6bae7d79a23b50185a4f6ff158b1907fba8fdb26 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 14 Nov 2021 15:58:19 +0800 Subject: [PATCH 0294/1892] sorted ranks for redundant --- cube/execplan/planpass/gfuse.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/execplan/planpass/gfuse.py b/cube/execplan/planpass/gfuse.py index f17695c8..e2710d71 100644 --- a/cube/execplan/planpass/gfuse.py +++ b/cube/execplan/planpass/gfuse.py @@ -25,7 +25,9 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: weights, params = WeightGradAllreduceFusion._get_weight_grads(execplan) for param_id in params: grads = params[param_id] - ranks = tuple(grads.keys()) # ranks are used for group + ranks = list(grads.keys()) + ranks.sort() + ranks = tuple(ranks) # ranks are used for group if len(ranks) == 1: continue grads_num = [len(grads[devid]) for devid in grads] From 08a074c43a7d2415c442b19835f57d66bf5e5f01 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 14 Nov 2021 15:59:54 +0800 Subject: [PATCH 0295/1892] repr for better debugging --- cube/graph/operator/operator.py | 15 ++++++++++----- cube/graph/tensor.py | 2 -- cube/schedule/adapter/transform.py | 3 ++- cube/schedule/sugraph.py | 18 ++++++++++++------ 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 136c5128..5b63647d 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -76,7 +76,8 @@ def __repr__(self): anno = 'w' if tensor.is_grad(): anno = 'g' - inputs.append(f'{anno}{tensor._id}') + # inputs.append(f'{anno}{tensor._id}') + inputs.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') else: inputs.append(tensor) @@ -88,7 +89,8 @@ def __repr__(self): anno = 'w' if tensor.is_grad(): anno = 'g' - outputs.append(f'{anno}{tensor._id}') + # outputs.append(f'{anno}{tensor._id}') + outputs.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') else: outputs.append(tensor) @@ -169,7 +171,8 @@ def __repr__(self): anno = 'w' if tensor.is_grad(): anno = 'g' - datas.append(f'{anno}{tensor._id}') + # datas.append(f'{anno}{tensor._id}') + datas.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') else: datas.append(tensor) @@ -181,7 +184,8 @@ def __repr__(self): anno = 'w' if tensor.is_grad(): anno = 'g' - grads.append(f'{anno}{tensor._id}') + # grads.append(f'{anno}{tensor._id}') + grads.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') else: grads.append(tensor) @@ -193,7 +197,8 @@ def __repr__(self): anno = 'w' if tensor.is_grad(): anno = 'g' - outputs.append(f'{anno}{tensor._id}') + # outputs.append(f'{anno}{tensor._id}') + outputs.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') else: outputs.append(tensor) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 7d6da53b..e3fe7839 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -22,8 +22,6 @@ from typing import List, Optional, Union, Tuple import copy -from numpy.lib.arraysetops import isin - from cube.ir.cten import IRCell, IRTensor diff --git a/cube/schedule/adapter/transform.py b/cube/schedule/adapter/transform.py index 968e4698..fee89af9 100644 --- a/cube/schedule/adapter/transform.py +++ b/cube/schedule/adapter/transform.py @@ -121,7 +121,8 @@ def gen(input: IRSubTensor, outputs: List[IRSubTensor]) -> List[SelectPrim]: elif input.val_map == ValueMap(0, 1): val_map = output.val_map else: - print(output) + print('from: ', input) + print('to : ', output) raise NotImplementedError( f"Not supported value select: {input.val_map} -> {output.val_map}" ) diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index c54c5d0f..ef564c9b 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -457,9 +457,12 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: recv_su.device = su.device # add adapter for merge if len(tensor_segments) != 0: - merge_op = IRTensorTransform( - src_tensors=tensor_segments, dst_tensors=[input] - ) + try: + merge_op = IRTensorTransform( + src_tensors=tensor_segments, dst_tensors=[input] + ) + except Exception: + raise RuntimeError(f"Merge Generation Error: {su}") merge_su = ScheduleUnit([merge_op], SUType.Transform, name='merge') su._set_merge_adapter(in_idx, merge_su) merge_su.device = su.device @@ -476,9 +479,12 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: if tensor != output and tensor not in select_tensors: select_tensors.append(tensor) if len(select_tensors) != 0: - select_op = IRTensorTransform( - src_tensors=[output], dst_tensors=select_tensors - ) + try: + select_op = IRTensorTransform( + src_tensors=[output], dst_tensors=select_tensors + ) + except Exception: + raise RuntimeError(f"Select Generation Error: {su}") select_su = ScheduleUnit( [select_op], SUType.Transform, name='select' ) From ad5ca331b73039af9ae7402ed9ee0c92d6443bbc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 14 Nov 2021 16:08:23 +0800 Subject: [PATCH 0296/1892] fix a bug in gradient value map generation: nested transformation will cause gradient's value map drift --- cube/compiler.py | 2 +- cube/graph/graph.py | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 6a895583..244d053e 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -142,7 +142,7 @@ def decorator(fn: Callable) -> Callable: execplan = ExectuionPlan(sugraph) # plan pass to remove redundant sus execplan = RemoveRedundantAdapters.apply(execplan) - print(f'> after remove redundant adapters:\n {execplan}') + # print(f'> after remove redundant adapters:\n {execplan}') execplan = MergeComputeSU.apply(execplan) execplan = WeightGradAllreduceFusion.apply(execplan) print(f'> after merge compute SU:\n{execplan}') diff --git a/cube/graph/graph.py b/cube/graph/graph.py index bb670c79..5d824308 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,7 +7,7 @@ will be inserted at scheduling time. """ -from typing import Union, Tuple, List, Optional, Any, Dict +from typing import Union, Tuple, List, Optional, Dict import copy from cube.graph.operator.operator import IRBpOperation @@ -274,7 +274,7 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional if not isinstance(algo, GenericDistAlgo): raise TypeError("Expected algo to be GenericDistAlgo") if op not in self.nodes(): - raise RuntimeError("Not Exist: {op}") + raise RuntimeError(f"Not Exist: {op}") if algo.logic_op != type(op): return None @@ -286,6 +286,19 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional op.set_input(idx, None) # set backward mirror node if op.mirror is not None: + # go through related op to reset the related gradient + for fnode in fnodes: + for val in fnode.inputs(): + if not isinstance(val, IRSubTensor): + continue + # TODO: requires_grad = False should be set to None + val.grad = val.get_grad(fnode) + for related_op in val.parent.forward_dst_cells(): + for idx, rval in enumerate(related_op.inputs()): + if val.overlap(rval): + rval.grad = rval.get_grad(related_op) + if related_op.mirror is not None: + related_op.mirror.set_output(idx, rval.grad) # generate mirror node for fnode in fnodes: bnode = IRBpOperation( @@ -295,9 +308,7 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional for idx, val in enumerate(fnode.inputs()): grad = None if isinstance(val, IRSubTensor): - # TODO: requires_grad = False should be set to None - grad = val.get_grad(fnode) - val.grad = grad + grad = val.grad bnode.set_data(idx, val) bnode.set_output(idx, grad) for idx, val in enumerate(fnode.outputs()): From d57104f31e0f9b6fbbb3a01e7f213f349b208ddf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 14 Nov 2021 16:08:49 +0800 Subject: [PATCH 0297/1892] add nested parallelism example --- examples/linears.py | 12 ++++--- examples/policy/dp_parallel.py | 12 +++---- examples/policy/nested_parallel.py | 53 ++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 11 deletions(-) create mode 100644 examples/policy/nested_parallel.py diff --git a/examples/linears.py b/examples/linears.py index 4a6b4ba3..256f1f7c 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -2,7 +2,7 @@ example: python -m torch.distributed.launch \ - --nproc_per_node=2 \ + --nproc_per_node=4 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ @@ -15,8 +15,8 @@ from torch import nn import cube -from examples.policy.hybrid_parallel import transform_policy -from examples.policy.hybrid_parallel import schedule_policy +from examples.policy.nested_parallel import transform_policy +from examples.policy.nested_parallel import schedule_policy # =================== Semantic Model Description ==================== @@ -58,11 +58,13 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - for epoch in range(10): + iter_num = 128 + for step in range(iter_num): train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() - print('====> end iteration') + if (step + 1) % 20 == 0: + print(f'iter [{step + 1}/{iter_num}]') if __name__ == '__main__': diff --git a/examples/policy/dp_parallel.py b/examples/policy/dp_parallel.py index 36f52d3d..6d1bb5c4 100644 --- a/examples/policy/dp_parallel.py +++ b/examples/policy/dp_parallel.py @@ -1,15 +1,15 @@ from cube.graph import IRGraph from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation +from cube.graph.operator.operator import IRDataOperation, IRFwOperation def transform_policy(graph: IRGraph, resource): """ - The transformation policy transposes linear using column parallel + The transformation policy transposes linear using data parallel """ - for node in graph.nodes(): - if isinstance(node, IRFwOperation): + for node in graph.nodes(): + if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): algo = node.algorithms('data') assert algo sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) @@ -24,8 +24,8 @@ def schedule_policy(sugraph: SUGraph, resource): """ for su in sugraph.sus(): if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - # sugraph.assign(su, list(range(resource.ngpus))) + devid = su.tag[0] + sugraph.assign(su, devid) for su in sugraph.fsus(): devid = su.tag[0] sugraph.assign(su, devid) diff --git a/examples/policy/nested_parallel.py b/examples/policy/nested_parallel.py new file mode 100644 index 00000000..56c369b9 --- /dev/null +++ b/examples/policy/nested_parallel.py @@ -0,0 +1,53 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRDataOperation, IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using data parallel + """ + tp = 2 + dp = int(resource.ngpus // tp) + for node in graph.nodes(): + # partition data loader at data dimension + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=dp)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx * tp + # partition operators first in column and then in data + if isinstance(node, IRFwOperation): + all_sub_nodes = list() + if node.algorithms('column') is not None: + algo = node.algorithms('column') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=tp)) + for sub_node in sub_nodes: + algo = sub_node.algorithms('data') + ssub_nodes = graph.partition(sub_node, algo, config=dict(chunk_num=dp)) + all_sub_nodes += ssub_nodes + else: + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + all_sub_nodes += sub_nodes + # add tags (vdev) for node + for idx, ssub_node in enumerate(all_sub_nodes): + ssub_node.tag = idx + # print(graph) + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + devid = su.tag[0] + sugraph.assign(su, devid) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + return sugraph From b7bdf73e93eebe59a7d1532977878a420110f4f0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 14 Nov 2021 17:22:47 +0800 Subject: [PATCH 0298/1892] parameter sync for correctness --- cube/codegen/codegen.py | 22 ++++++++++++++++------ cube/compiler.py | 2 ++ cube/runtime/__init__.py | 3 ++- cube/runtime/module.py | 22 ++++++++++++++++++++++ cube/runtime/reducer.py | 8 ++++++++ 5 files changed, 50 insertions(+), 7 deletions(-) create mode 100644 cube/runtime/module.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 4c5691e6..4b8da4ad 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -106,6 +106,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: elif isinstance(node, IRCommunication): self.emit_comm_call(node) elif isinstance(node, IROptimOperation): + self.emit_optim_init(node) self.emit_optim_call(node) # emit input declaration for arg in node.inputs(): @@ -118,7 +119,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: self.forward_region = list() # generate full code - with ClassBlock(class_name='GenModel', derived=['torch.nn.Module']) as cb: + with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.declare_region) cb.insert_body('') @@ -237,18 +238,27 @@ def emit_transform_call(self, node: IRTensorTransform): code = f'{output_name} = {output_name}.requires_grad_()' self.forward_region.append(code) - def emit_optim_call(self, node: IROptimOperation): + def emit_optim_init(self, node: IROptimOperation): + # reducer init interface + reducer_init = '{reducer} = cube.runtime.reducer.Reducer(ranks={ranks})' + reducer_add = 'self.add_reducer({reducer})' + add_param = '{reducer}.add_param({grad})' + # create reducer in declare region ranks = list(node.ranks) grads = node.inputs() reducer_name = f'self.reducer{node._id}' - # create reducer in declare region - init_code = f'{reducer_name} = cube.runtime.reducer.Reducer(ranks={ranks})' + self.declare_region.append('') + init_code = reducer_init.format(reducer=reducer_name, ranks=ranks) self.declare_region.append(init_code) grads = self._forward_region_arg_names(grads) for grad in grads: - add_param_code = f'{reducer_name}.add_param({grad})' + add_param_code = add_param.format(reducer=reducer_name, grad=grad) self.declare_region.append(add_param_code) - # create call in forward region + add_code = reducer_add.format(reducer=reducer_name) + self.declare_region.append(add_code) + + def emit_optim_call(self, node: IROptimOperation): + reducer_name = f'self.reducer{node._id}' call_code = f'{reducer_name}.allreduce()' self.forward_region.append(call_code) diff --git a/cube/compiler.py b/cube/compiler.py index 244d053e..9431c857 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -38,6 +38,8 @@ def load_module(self, filename: str): module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) self._loaded_module = module.GenModel().cuda() + # sync parameters before start training + self._loaded_module.sync_params() def get_gen_module(self): return self._loaded_module diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index c66a7d0c..993a8c63 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -2,4 +2,5 @@ from cube.runtime import device from cube.runtime import reducer from cube.runtime import syndata -from cube.runtime import resource \ No newline at end of file +from cube.runtime import resource +from cube.runtime import module \ No newline at end of file diff --git a/cube/runtime/module.py b/cube/runtime/module.py new file mode 100644 index 00000000..a6c5a688 --- /dev/null +++ b/cube/runtime/module.py @@ -0,0 +1,22 @@ +import torch +from cube.runtime.reducer import Reducer + + +class CubeModule(torch.nn.Module): + """ + The module is responsible for parameter synchronization + before training + """ + + def __init__(self): + super().__init__() + self._reducers = list() + + def add_reducer(self, reducer: Reducer): + if not isinstance(reducer, Reducer): + raise RuntimeError(f"Expected a Reducer but got {type(reducer)}") + self._reducers.append(reducer) + + def sync_params(self): + for reducer in self._reducers: + reducer.sync() diff --git a/cube/runtime/reducer.py b/cube/runtime/reducer.py index 37125233..012c9375 100644 --- a/cube/runtime/reducer.py +++ b/cube/runtime/reducer.py @@ -44,6 +44,14 @@ def allreduce(self): for grad, synced in zip(grads, all_synced): grad.copy_(synced) + def sync(self): + """ + Sync parameters before training + """ + for param in self._params: + torch.distributed.broadcast(param, self.ranks[0], group=self._group) + torch.cuda.synchronize() + def _flatten_dense_tensors(self, tensors): """ Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of From 8f3b0387a2cd42cc02d35ab611f17c18c8b94448 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 14 Nov 2021 18:29:18 +0800 Subject: [PATCH 0299/1892] fix bugs in value map submap --- cube/graph/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index e3fe7839..213219d3 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -213,7 +213,7 @@ class ValueMap: def __init__(self, idx: int, chunk_num: int): if idx >= chunk_num or idx < 0: - raise ValueError("Expected idx in [0, chunk_num)") + raise ValueError(f"Expected idx {idx} in [0, {chunk_num})") self._idx = idx self._chunk_num = chunk_num @@ -228,7 +228,7 @@ def chunk_num(self): def map(self, sub_map): if not isinstance(sub_map, ValueMap): raise TypeError("Expected sub_map to be ValueMap") - idx = self.chunk_num * self.idx + sub_map.idx + idx = self.idx * sub_map.chunk_num + sub_map.idx chunk_num = self.chunk_num * sub_map.chunk_num return ValueMap(idx, chunk_num) From d8ad9d588c2885b42d4df931737d2038c8f3fc81 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 00:39:51 +0800 Subject: [PATCH 0300/1892] add constraints to partition algo --- cube/graph/graph.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5d824308..50e57ed5 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -281,6 +281,15 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional if not algo.satisfy(config): return None fnodes = algo.instantiate(op, config) + + #FIXME: we don't allow non-weight input to be splitted in value + for fnode in fnodes: + for input in fnode.inputs(): + if isinstance(input, IRSubTensor): + if input.val_map.chunk_num != 1 and not input.is_param(): + raise NotImplementedError( + f"Not support feature-map {input} to be splitted in value as input" + ) # remove reference for idx in range(len(op.inputs())): op.set_input(idx, None) From 258dc223c8c1185ebe8907f1d91e1c33f43e16de Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 00:41:21 +0800 Subject: [PATCH 0301/1892] fix correctness bugs: grad of output should have no val map --- cube/graph/tensor.py | 15 ++++++++++++--- cube/schedule/translator.py | 10 ++++++---- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 213219d3..927d2ad3 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -15,12 +15,13 @@ 2). for (FwOp) output tensors, gradient SubTensor is: indices = output.indices; - val follows same value splitting rules with output + val is always (0/1) """ from typing import List, Optional, Union, Tuple import copy +import math from cube.ir.cten import IRCell, IRTensor @@ -241,7 +242,15 @@ def overlap(self, other): if self.chunk_num == 1 or other.chunk_num == 1: return True else: - raise NotImplementedError("Not Implemented") + chk1, chk2 = self.chunk_num, other.chunk_num + time1 = int(chk2 / math.gcd(chk1, chk2)) + time2 = int(chk1 / math.gcd(chk1, chk2)) + span1 = (self.idx * time1, self.idx * time1 + time1) + span2 = (other.idx * time2, other.idx * time2 + time2) + if max(span1[0], span2[0]) < min(span1[1], span2[1]): + return True + else: + return False def __eq__(self, other): if isinstance(other, ValueMap): @@ -604,7 +613,7 @@ def get_grad(self, fcell: IRCell): elif self in fcell.outputs(): grad = full_grad.select( indices = self.indices, - val_map = self.val_map, + val_map = (0, 1), shape = self.shape ) return grad.as_grad() diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index 15c0723b..09525f9d 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -77,14 +77,16 @@ def backward(loss: IRSubTensor): trace = SchedulePool().get_tape(loss) if trace is None: raise RuntimeError("No forward detected") - # make gradient point to it self - loss.parent.grad = loss.parent + # make grad to 1.0 + if not loss.shape == [1]: + raise RuntimeError("backward can only perform on the scaler tensor") + loss.parent.grad = None bnode = None for node in trace: for idx, output in enumerate(node.outputs()): if loss.overlap(output): bnode = node.mirror - output.grad = output - bnode.set_grad(idx, output) + output.grad = None + bnode.set_grad(idx, None) for node in trace[::-1]: SchedulePool().add_node(node.mirror) From a046d5c80a6d00fec4149c9b2fae20f04ec039a5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 01:44:41 +0800 Subject: [PATCH 0302/1892] fix bugs on merge backward SUs --- cube/execplan/planpass/gfuse.py | 3 --- cube/execplan/planpass/merge.py | 20 +++++++++++++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/cube/execplan/planpass/gfuse.py b/cube/execplan/planpass/gfuse.py index e2710d71..d324ec24 100644 --- a/cube/execplan/planpass/gfuse.py +++ b/cube/execplan/planpass/gfuse.py @@ -30,9 +30,6 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: ranks = tuple(ranks) # ranks are used for group if len(ranks) == 1: continue - grads_num = [len(grads[devid]) for devid in grads] - if len(set(grads_num)) > 1: - sys.stderr.write("May require weighted allreduce!\n") if ranks not in reducers: reducers[ranks] = list() reducers[ranks].append(weights[param_id]) diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index ebf58d17..6ce3a8f3 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -17,15 +17,23 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: for devid in execplan.devices(): dev_seq = execplan.sequence(devid) + [None] pieces: List[ScheduleUnit] = list() + adapters: List[ScheduleUnit] = list() for seqidx, su in enumerate(dev_seq): + if su and su.stype in [SUType.Comm, SUType.Transform]: + if len(pieces) > 0: + adapters.append(su) + continue if su and su.stype in [SUType.Backward]: allow_merge = len(pieces) == 0 for psu in pieces[::-1]: if execplan.sugraph.happen_before(psu, su): allow_merge = True break + for adapter in adapters: + if execplan.sugraph.happen_before(adapter, su): + allow_merge = False + break if allow_merge: - dev_seq[seqidx] = None pieces.append(su) continue # merged forward su @@ -36,14 +44,20 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: mfsu = MergeComputeSU._merge(fsus, devid) mbsu = mfsu.mirror # insert merged backward su - dev_seq[seqidx-1] = mbsu - fsus_idx = [dev_seq.index(fsu) for fsu in fsus] + mbsu_idx = min([dev_seq.index(bsu) for bsu in pieces]) + for bsu in pieces: + dev_seq[dev_seq.index(bsu)] = None + dev_seq[mbsu_idx] = mbsu # insert merged forward su + fsus_idx = [dev_seq.index(fsu) for fsu in fsus] if max(fsus_idx) - min(fsus_idx) == len(fsus) - 1: for fidx in fsus_idx: dev_seq[fidx] = None dev_seq[min(fsus_idx)] = mfsu pieces = list() + if su and su.stype in [SUType.Backward]: + pieces = [su] + adapters = list() dev_seq = [su for su in dev_seq if su is not None] execplan.set(devid, dev_seq) return execplan From da38b34ae3f55821b312485274c65b1f81b7ca64 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 01:45:09 +0800 Subject: [PATCH 0303/1892] sorted device list --- cube/execplan/execplan.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 2fa2ded4..d1f461e4 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -25,7 +25,9 @@ def devices(self) -> List[int]: """ Get device set """ - return self.device_seq.keys() + devices = list(self.device_seq.keys()) + devices.sort() + return devices def sequence(self, device_id: int) -> List[ScheduleUnit]: """ From 70b56df3a5c38ab672f5bf007286f1fe02031a64 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 01:45:57 +0800 Subject: [PATCH 0304/1892] fix policy bugs and print info --- examples/policy/col_parallel.py | 1 + examples/policy/{dp_parallel.py => data_parallel.py} | 1 + examples/policy/hybrid_parallel.py | 10 +++++++--- 3 files changed, 9 insertions(+), 3 deletions(-) rename examples/policy/{dp_parallel.py => data_parallel.py} (98%) diff --git a/examples/policy/col_parallel.py b/examples/policy/col_parallel.py index 98be31f7..76e2e194 100644 --- a/examples/policy/col_parallel.py +++ b/examples/policy/col_parallel.py @@ -16,6 +16,7 @@ def transform_policy(graph: IRGraph, resource): sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx + print(graph) return graph diff --git a/examples/policy/dp_parallel.py b/examples/policy/data_parallel.py similarity index 98% rename from examples/policy/dp_parallel.py rename to examples/policy/data_parallel.py index 6d1bb5c4..960eeb4b 100644 --- a/examples/policy/dp_parallel.py +++ b/examples/policy/data_parallel.py @@ -15,6 +15,7 @@ def transform_policy(graph: IRGraph, resource): sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx + print(graph) return graph diff --git a/examples/policy/hybrid_parallel.py b/examples/policy/hybrid_parallel.py index 5de96ce1..65b8ed76 100644 --- a/examples/policy/hybrid_parallel.py +++ b/examples/policy/hybrid_parallel.py @@ -8,20 +8,24 @@ def transform_policy(graph: IRGraph, resource): """ The transformation policy transposes linear using column parallel """ + linear_idx = 0 for node in graph.nodes(): - idx = 0 if isinstance(node, IRFwOperation): algo = None - if idx % 2 == 0: + if linear_idx % 2 == 0: + print(f'> column partition: {node}') algo = node.algorithms('column') else: + print(f'> row partition: {node}') algo = node.algorithms('row') if algo is None: + print(f'> data partition: {node}') algo = node.algorithms('data') - idx += 1 + linear_idx += 1 sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx + print(graph) return graph From b535f420505dca438594f1e29dd11b85306bbbdb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 01:47:12 +0800 Subject: [PATCH 0305/1892] add megatron partition policy --- examples/policy/megatron_parallel.py | 61 ++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 examples/policy/megatron_parallel.py diff --git a/examples/policy/megatron_parallel.py b/examples/policy/megatron_parallel.py new file mode 100644 index 00000000..54c2c121 --- /dev/null +++ b/examples/policy/megatron_parallel.py @@ -0,0 +1,61 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation, IRDataOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using data parallel + """ + tp = 2 + dp = int(resource.ngpus // tp) + linear_idx = 0 + for node in graph.nodes(): + # partition data loader at data dimension + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=dp)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx * tp + # partition operators first in column and then in data + if isinstance(node, IRFwOperation): + all_sub_nodes = list() + if node.algorithms('column') is not None: + if linear_idx % 2 == 0: + print(' ==> column partition') + algo = node.algorithms('column') + else: + print(' ==> row partition') + algo = node.algorithms('row') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=tp)) + for sub_node in sub_nodes: + print(' ==> data partition') + algo = sub_node.algorithms('data') + ssub_nodes = graph.partition(sub_node, algo, config=dict(chunk_num=dp)) + all_sub_nodes += ssub_nodes + linear_idx += 1 + else: + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + all_sub_nodes += sub_nodes + # add tags (vdev) for node + for idx, ssub_node in enumerate(all_sub_nodes): + ssub_node.tag = idx + print(graph) + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + devid = su.tag[0] + sugraph.assign(su, devid) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + return sugraph \ No newline at end of file From 7836c0a7a9cf42ffa7849a0f3956d3766b332308 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 01:48:13 +0800 Subject: [PATCH 0306/1892] compiler debug output --- cube/compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 9431c857..12206876 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -144,10 +144,11 @@ def decorator(fn: Callable) -> Callable: execplan = ExectuionPlan(sugraph) # plan pass to remove redundant sus execplan = RemoveRedundantAdapters.apply(execplan) - # print(f'> after remove redundant adapters:\n {execplan}') + print(f'> after remove redundant adapters:\n {execplan}') execplan = MergeComputeSU.apply(execplan) + print(f'> after merge backward SU:\n {execplan}') execplan = WeightGradAllreduceFusion.apply(execplan) - print(f'> after merge compute SU:\n{execplan}') + print(f'> after add allreduce:\n{execplan}') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() From d236f37491f5c529f6c25a8e1d6423369caa3ec9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 01:48:48 +0800 Subject: [PATCH 0307/1892] switch to megatron partition policy --- examples/linears.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/linears.py b/examples/linears.py index 256f1f7c..9293d4d7 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -15,8 +15,8 @@ from torch import nn import cube -from examples.policy.nested_parallel import transform_policy -from examples.policy.nested_parallel import schedule_policy +from examples.policy.hybrid_parallel import transform_policy +from examples.policy.hybrid_parallel import schedule_policy # =================== Semantic Model Description ==================== From e0ebe291ab300b74df4fcbadac52ecf9f629787c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 10:42:52 +0800 Subject: [PATCH 0308/1892] init weight --- cube/compiler.py | 1 + cube/runtime/module.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/cube/compiler.py b/cube/compiler.py index 12206876..5f247522 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -38,6 +38,7 @@ def load_module(self, filename: str): module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) self._loaded_module = module.GenModel().cuda() + self._loaded_module.init_param() # sync parameters before start training self._loaded_module.sync_params() diff --git a/cube/runtime/module.py b/cube/runtime/module.py index a6c5a688..e13b5b21 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -20,3 +20,7 @@ def add_reducer(self, reducer: Reducer): def sync_params(self): for reducer in self._reducers: reducer.sync() + + def init_param(self): + for param in self.parameters(): + torch.nn.init.uniform_(param) From 826cb3e0a08c511d2e690c0d1c2596a939dcf6c5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 10:43:26 +0800 Subject: [PATCH 0309/1892] clean debug output --- cube/runtime/collectives.py | 6 +++--- cube/runtime/executor.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 35a41f3f..5d6ad856 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -12,7 +12,7 @@ def send(tensors, to_ranks: List[int]): tensors (List[torch.Tensor]): list of tensor to send tensor_devices (List[List[int]]): tensor sent devices """ - print('sending...') + # print('sending...') send_ops = list() for tensor, rank in zip(tensors, to_ranks): send_op = torch.distributed.P2POp( @@ -26,7 +26,7 @@ def send(tensors, to_ranks: List[int]): def recv(shapes: List[List[int]], from_ranks: List[int]): - print('recving...') + # print('recving...') recv_ops = list() recv_tensors = list() for shape, rank in zip(shapes, from_ranks): @@ -49,7 +49,7 @@ def recv(shapes: List[List[int]], from_ranks: List[int]): def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): - print('sending and recving...') + # print('sending and recving...') ops = list() recv_tensors = list() for tensor, ranks in zip(send_tensors, to_ranks): diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index a5d049c5..44d00e85 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -11,7 +11,7 @@ def fexecute(su: Callable, *input_tensors: Tuple[Any]): forward the SUs """ outputs = su(*input_tensors) - print('forwarding... ') + # print('forwarding... ') return outputs @@ -29,7 +29,7 @@ def backward(input_tensors, output_tensors, output_tensor_grads): ) for tensor, grads in zip(output_tensors, output_tensor_grads): - print('backwarding... ') + # print('backwarding... ') torch.autograd.backward(tensor, grad_tensors=grads) grads = list() for tensor in input_tensors: From 63045a03177ca173d6ffb0c061c6581e2cd5d807 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 15 Nov 2021 15:15:44 +0800 Subject: [PATCH 0310/1892] add benchmark and profilor test --- benchmark/megatron/layers.py | 158 ++++++++++++++++++++++++++++++++++ benchmark/megatron/linears.py | 119 +++++++++++++++++++++++++ examples/linears.py | 13 ++- 3 files changed, 287 insertions(+), 3 deletions(-) create mode 100644 benchmark/megatron/layers.py create mode 100644 benchmark/megatron/linears.py diff --git a/benchmark/megatron/layers.py b/benchmark/megatron/layers.py new file mode 100644 index 00000000..47ee055e --- /dev/null +++ b/benchmark/megatron/layers.py @@ -0,0 +1,158 @@ +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + + +def _reduce(input_): + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + world_size = torch.distributed.get_world_size() + if world_size == 1: + return input_ + torch.distributed.all_reduce(input_, group=None) + return input_ + + +def _split(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + # Bypass the function if we are using only 1 GPU. + if world_size==1: + return input_ + last_dim = input_.dim() - 1 + last_dim_size = input_.size()[last_dim] // world_size + tensor_list = torch.split(input_, last_dim_size, dim=last_dim) + output = tensor_list[rank].contiguous() + return output + + +def _gather(input_): + """Gather tensors and concatinate along the last dimension.""" + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + # Bypass the function if we are using only 1 GPU. + if world_size==1: + return input_ + # Size and dimension. + last_dim = input_.dim() - 1 + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=None) + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + return output + + +class ColumnInputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return input_ + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + + +class ColumnOutputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _gather(input_) + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output) + + +class RowInputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _split(input_) + + @staticmethod + def backward(ctx, grad_outputs): + return _gather(grad_outputs) + + +class RowOutputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class ColumnParallelLinear(torch.nn.Module): + + def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.full_input = full_input + self.full_output = full_output + + world_size = torch.distributed.get_world_size() + self.weight = Parameter(torch.empty( + int(self.output_size // world_size), + self.input_size, + )) + if bias: + self.bias = Parameter(torch.empty( + int(self.output_size // world_size), + )) + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + bias = self.bias + if not self.full_input: + raise RuntimeError("Expected full tensor input") + input_parallel = ColumnInputAdapter.apply(input_) + output_parallel = F.linear(input_parallel, self.weight, bias) + if self.full_output: + output = ColumnOutputAdapter.apply(output_parallel) + else: + output = output_parallel + return output + + +class RowParallelLinear(torch.nn.Module): + + def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.full_input = full_input + self.full_output = full_output + + world_size = torch.distributed.get_world_size() + self.weight = Parameter(torch.empty( + self.output_size, + int(self.input_size // world_size), + )) + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + bias = self.bias + if self.full_input: + input_parallel = RowInputAdapter.apply(input_) + else: + input_parallel = input_ + output_parallel = F.linear(input_parallel, self.weight, bias) + if self.full_output: + output = RowOutputAdapter.apply(output_parallel) + else: + output = output_parallel + return output + diff --git a/benchmark/megatron/linears.py b/benchmark/megatron/linears.py new file mode 100644 index 00000000..b05cbb15 --- /dev/null +++ b/benchmark/megatron/linears.py @@ -0,0 +1,119 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + benchmark/megatron/linears.py +""" + +import argparse + +import torch +from torch import nn +from benchmark.megatron.layers import ColumnParallelLinear, RowParallelLinear + +import cube +from cube.profiler import CudaTimer + + +class ColumnMLP(nn.Module): + def __init__(self, dim, mult=16): + super().__init__() + self.linear1 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=True) + self.linear2 = ColumnParallelLinear(dim * mult, dim, full_input=True, full_output=True) + self.linear3 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=True) + self.linear4 = ColumnParallelLinear(dim * mult, dim, full_input=True, full_output=True) + + def forward(self, data): + output = self.linear1(data) + output = self.linear2(output) + output = self.linear3(output) + output = self.linear4(output) + loss = torch.sum(output) + return loss + + +class RowMLP(nn.Module): + def __init__(self, dim, mult=16): + super().__init__() + self.linear1 = RowParallelLinear(dim, dim * mult, full_input=True, full_output=True) + self.linear2 = RowParallelLinear(dim * mult, dim, full_input=True, full_output=True) + self.linear3 = RowParallelLinear(dim, dim * mult, full_input=True, full_output=True) + self.linear4 = RowParallelLinear(dim * mult, dim, full_input=True, full_output=True) + + def forward(self, data): + output = self.linear1(data) + output = self.linear2(output) + output = self.linear3(output) + output = self.linear4(output) + loss = torch.sum(output) + return loss + + +class HybridMLP(nn.Module): + def __init__(self, dim, mult=16): + super().__init__() + self.linear1 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=False) + self.linear2 = RowParallelLinear(dim * mult, dim, full_input=False, full_output=True) + self.linear3 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=False) + self.linear4 = RowParallelLinear(dim * mult, dim, full_input=False, full_output=True) + + def forward(self, data): + output = self.linear1(data) + output = self.linear2(output) + output = self.linear3(output) + output = self.linear4(output) + loss = torch.sum(output) + return loss + + +def train(args): + + batch_size = 128 + dim = 1024 + + # model = ColumnMLP(dim=dim).cuda() + # model = RowMLP(dim=dim).cuda() + model = HybridMLP(dim=dim).cuda() + + for param in model.parameters(): + torch.nn.init.uniform_(param) + + dataloader = cube.runtime.syndata.SynDataLoader(1280, [batch_size, dim]) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 10: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 10: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print(f'iter [{step + 1}/{iter_num}]') + + print('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-10, field_name='e2e'))) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='inspect') + parser.add_argument('--bs', type=int, default=128) + args = parser.parse_args() + + cube.init() + train(args) diff --git a/examples/linears.py b/examples/linears.py index 9293d4d7..b34afd4b 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -15,6 +15,7 @@ from torch import nn import cube +from cube.profiler import CudaTimer from examples.policy.hybrid_parallel import transform_policy from examples.policy.hybrid_parallel import schedule_policy @@ -23,9 +24,9 @@ class MLP(nn.Module): def __init__(self, dim, mult=16): super().__init__() - self.linear1 = nn.Linear(dim, dim * mult, bias=False) + self.linear1 = nn.Linear(dim, dim * mult) self.linear2 = nn.Linear(dim * mult, dim) - self.linear3 = nn.Linear(dim, dim * mult, bias=False) + self.linear3 = nn.Linear(dim, dim * mult) self.linear4 = nn.Linear(dim * mult, dim) def forward(self, data): @@ -52,7 +53,6 @@ def train(): def train_iter(model, dataloader): data = next(dataloader) loss = model(data) - # print(f'loss={loss.item()}') loss.backward() model = model.get_gen_module() @@ -60,12 +60,19 @@ def train_iter(model, dataloader): iter_num = 128 for step in range(iter_num): + if step >= 10: + CudaTimer().start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() + if step >= 10: + CudaTimer().stop('e2e') if (step + 1) % 20 == 0: print(f'iter [{step + 1}/{iter_num}]') + print('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-10, field_name='e2e'))) + if __name__ == '__main__': From dcdcbbb8c378633a52b055e0cc92868014ccef62 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 16 Nov 2021 13:40:10 +0800 Subject: [PATCH 0311/1892] p2p fusion init --- cube/execplan/planpass/p2pfusion.py | 217 +++++++++++++++++++++++++++ cube/runtime/collectives.py | 64 +++++++- cube/schedule/adapter/collectives.py | 46 ++++++ 3 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 cube/execplan/planpass/p2pfusion.py create mode 100644 cube/schedule/adapter/collectives.py diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py new file mode 100644 index 00000000..a7ee18ff --- /dev/null +++ b/cube/execplan/planpass/p2pfusion.py @@ -0,0 +1,217 @@ +from typing import List, Dict +from cube.execplan import ExectuionPlan +from cube.ir.cten import IRTensor +from cube.schedule.su import SUType, ScheduleUnit +from cube.execplan.planpass.planpass import PlanPass + +from cube.schedule.adapter.collectives import IRCollType, IRCollectives + + +class P2PFusion(PlanPass): + + @staticmethod + def apply(execplan: ExectuionPlan) -> ExectuionPlan: + # dict[pid][devid] = list of sub_tensors + fous, fins = P2PFusion.collect_tensors(execplan, SUType.Forward) + bous, bins = P2PFusion.collect_tensors(execplan, SUType.Backward) + # debug + print('=====> forward') + for pid in fins: + if pid not in fous: + continue + if P2PFusion.have_comm(fous[pid], fins[pid]): + print(f'=> parent tensor id: {pid}') + for devid in fous[pid]: + print(f' ==> device: {devid}') + for val in fous[pid][devid]: + print(f' o:', val) + for devid in fins[pid]: + print(f' ==> device: {devid}') + for val in fins[pid][devid]: + print(f' i:', val) + matchers = [ + P2PFusion.match_allreduce, P2PFusion.match_allgather, + P2PFusion.match_reducescatter, P2PFusion.match_broadcast, + ] + for pid in fins: + if pid not in fous: + continue + tous, tins = fous[pid], fins[pid] + if P2PFusion.have_comm(tous, tins): + colls = None + for matcher in matchers: + colls = matcher(tous, tins) + if colls: + break + return execplan + + @staticmethod + def collect_tensors(execplan: ExectuionPlan, stype: SUType): + # dict[pid][devid] = list of sub_tensors + ous = dict() + ins = dict() + for devid in execplan.devices(): + dev_seq = execplan.sequence(devid) + for su in dev_seq: + if su.stype == stype: + for val in su.inputs(): + if isinstance(val, IRTensor): + pid = val.parent._id + if pid not in ins: + ins[pid] = dict() + if devid not in ins[pid]: + ins[pid][devid] = list() + # TODO: may have redundancy + ins[pid][devid].append(val) + for idx, val in enumerate(su.outputs()): + if isinstance(val, IRTensor): + pid = val.parent._id + if pid not in ous: + ous[pid] = dict() + if devid not in ous[pid]: + ous[pid][devid] = list() + select_su = su.select_adapters(idx) + if select_su: + for out in select_su.outputs(): + # TODO: may have redundancy + ous[pid][devid].append(out) + else: + # TODO: may have redundancy + ous[pid][devid].append(val) + return ous, ins + + @staticmethod + def have_comm(tensor_ous, tensor_ins): + """ + Check if they don't have communications + """ + for devid in tensor_ins: + if devid not in tensor_ous: + return True + # no transmission + if input in tensor_ous[devid]: + continue + # have transmission + else: + return True + return False + + @staticmethod + def transmission(tensor_ous, in_tensor) -> Dict[int, List[IRTensor]]: + trans_tensors = dict() + for devid in tensor_ous: + for out in tensor_ous[devid]: + if devid not in trans_tensors: + trans_tensors[devid] = list() + if in_tensor.overlap(out): + trans_tensors[devid].append(out) + return trans_tensors + + @staticmethod + def match_allreduce(tous, tins): + """ + Allreduce semantic: + + Each device holds a recvs same spatial tensor from all device and + sends to all device. + The recved tensors are summed into one + """ + return None + + @staticmethod + def match_allgather(tous, tins): + """ + Allgather semantic: + + Each device performs same transformation merge. + + !!Note: Each input in merge su can be paired with a pair, find + them and remove!! Fuse merge, send, recv into one merge!! + """ + allgather_sus = list() + # {tensor_id: [device_id]} + in_devices: Dict[int, List[int]] = dict() + # {tensor_id: [tensors] + in_tensors: Dict[int, List[IRTensor]] = dict() + for devid in tins: + for in_tensor in tins[devid]: + tid = in_tensor._id + if tid not in in_devices: + in_devices[tid] = list() + in_tensors[tid] = list() + in_devices[tid].append(devid) + in_tensors[tid].append(in_tensor) + for tid in in_devices: + share_tensors = in_tensors[tid] + # P2P transmission + if len(in_devices[tid]) <= 1: + continue + in_tensor = in_tensors[tid][0] + # {rank: [IRTensor]}} + out_tensors = P2PFusion.transmission(tous, in_tensor) + out_devices = set(out_tensors.keys()) + if out_devices == set(in_devices[tid]): + # multiple transmission FIXME: remove redundancy + if not all([len(out_tensors[odev]) == 1 for odev in out_devices]): + continue + # check same value map and no overlap indices + unique_valmaps = list() + for odev in out_tensors: + valmap = out_tensors[odev][0].val_map + if valmap not in unique_valmaps: + unique_valmaps.append(valmap) + if len(unique_valmaps) != 1: + continue + # check no overlap indices + all_indices = list() + overlap = False + for odev in out_tensors: + indices = out_tensors[odev][0].indices + for pre_indices in all_indices: + overlap = pre_indices.overlap(indices) + all_indices.append(indices) + if overlap: + continue + + ranks = list(out_tensors.keys()) + inputs = [[out_tensors[rank][0]] for rank in ranks] + outputs = list() + for rank in ranks: + sh_idx = in_devices[tid].index(rank) + outputs.append([share_tensors[sh_idx]]) + + for input, output, rank in zip(inputs, outputs, ranks): + op = IRCollectives(input, output, ranks, IRCollType.AllGather) + su = ScheduleUnit([op], SUType.Comm, name='allgather') + su.device = rank + allgather_sus.append(su) + + print('>> find allgather pattern:') + print(f'device group: {ranks}') + for input in inputs: + print(f'src: {input}') + for output in outputs: + print(f'dst: {output}') + + if len(allgather_sus) == 0: + return None + else: + return allgather_sus + + @staticmethod + def match_reducescatter(tous, tins): + """ + ReduceScatter semantic: + + Each device performs same + """ + return None + + @staticmethod + def match_broadcast(tous, tins): + """ + Broadcast semantic: + + The root device send the its tensor to all the devices + """ + return None \ No newline at end of file diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 5d6ad856..00de5fe6 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -1,7 +1,8 @@ from typing import List - import torch +from cube.runtime.device import DeviceGroup + def send(tensors, to_ranks: List[int]): """ @@ -14,6 +15,10 @@ def send(tensors, to_ranks: List[int]): """ # print('sending...') send_ops = list() + + ## synthetic ## + # return + for tensor, rank in zip(tensors, to_ranks): send_op = torch.distributed.P2POp( torch.distributed.isend, tensor, rank @@ -29,6 +34,14 @@ def recv(shapes: List[List[int]], from_ranks: List[int]): # print('recving...') recv_ops = list() recv_tensors = list() + + ## synthetic ## + # for shape in shapes: + # recv_tensors.append( + # torch.ones(tuple(shape), + # device=torch.cuda.current_device() + # )) + for shape, rank in zip(shapes, from_ranks): tensor = torch.empty( shape, requires_grad=True, device=torch.cuda.current_device() @@ -82,3 +95,52 @@ def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): if len(recv_tensors) == 0: return None elif len(recv_tensors) == 1: return recv_tensors[0] else: return tuple(recv_tensors) + + +def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): + """ + Allreduce + """ + assert len(tensors) == 1 + tensor = tensors[0] + group = DeviceGroup().get_group(ranks) + torch.distributed.all_reduce(tensor, group=group) + return tensor + + +def all_gather(tensors: List[torch.Tensor], ranks: List[int]): + """ + Allgather + """ + assert len(tensors) == 1 + tensor = tensors[0] + group = DeviceGroup().get_group(ranks) + tensor_list = [torch.empty_like(tensor) for _ in ranks] + idx = ranks.index(DeviceGroup().rank) + tensor_list[idx] = tensor + torch.distributed.all_gather(tensor_list, group=group) + return tensor_list + + +def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): + """ + ReduceScatter + """ + group = DeviceGroup().get_group(ranks) + idx = ranks.index(DeviceGroup().rank) + output = tensors[idx] + torch.distributed.reduce_scatter( + output, tensors, group=group + ) + return output + + +def broadcast(tensors: List[torch.Tensor], ranks: List[int]): + """ + Broadcast. ranks[0] is the root + """ + assert len(tensors) == 1 + tensor = tensors[0] + group = DeviceGroup().get_group(ranks) + torch.distributed.broadcast(tensor, ranks[0], group=group) + return tensor diff --git a/cube/schedule/adapter/collectives.py b/cube/schedule/adapter/collectives.py new file mode 100644 index 00000000..a361d1b3 --- /dev/null +++ b/cube/schedule/adapter/collectives.py @@ -0,0 +1,46 @@ +from typing import List +from enum import Enum + +from cube.ir.cten import IRCell, IRTensor + + +class IRCollType(Enum): + + AllReduce = 'all_reduce' + AllGather = 'all_gather' + ReduceScatter = 'reduce_scatter' + + +class IRCollectives(IRCell): + """ + Collective cell for IRCell + """ + + def __init__(self, inputs: List[IRTensor], outputs: List[IRTensor], + ranks: List[int], colltype: IRCollType): + + if not isinstance(colltype, IRCollType): + raise TypeError("colltype Expected IRCollType") + if not all([isinstance(rank, int) for rank in ranks]): + raise TypeError("ranks should be List[int]") + + self.comm_type = colltype + if colltype == IRCollType.AllReduce: + signature = 'cube.runtime.collectives.all_reduce' + if colltype == IRCollType.AllGather: + signature = 'cube.runtime.collectives.all_gather' + if colltype == IRCollType.ReduceScatter: + signature = 'cube.runtime.collectives.reduce_scatter' + + self.ranks = ranks + + super().__init__( + name = colltype.value, + signature = signature, + input_length = len(inputs), + output_length = len(outputs) + ) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + for idx, output in enumerate(outputs): + self.set_output(idx, output) From f9270a9ddc5faaac06308564d4241ff82bfee33c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 16 Nov 2021 13:57:29 +0800 Subject: [PATCH 0312/1892] change sutype comm to p2p --- cube/codegen/codegen.py | 4 ++-- cube/execplan/execplan.py | 2 +- cube/execplan/planpass/merge.py | 2 +- cube/execplan/planpass/redundant.py | 2 +- cube/schedule/su.py | 3 ++- cube/schedule/sugraph.py | 12 ++++++------ tests/execplan/test_planpass_merge.py | 4 ++-- tests/execplan/test_planpass_redundant.py | 2 +- tests/schedule/test_sugraph.py | 16 ++++++++-------- tests/schedule/test_translator.py | 8 ++++---- 10 files changed, 28 insertions(+), 27 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 4b8da4ad..960aaaf9 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -31,7 +31,7 @@ def su_naming(self, su: ScheduleUnit) -> str: return f"fwcp{su._id}" if su.stype == SUType.Backward: return f"bwcp{su._id}" - if su.stype == SUType.Comm: + if su.stype == SUType.P2P: return f"comm{su._id}" if su.stype == SUType.Transform: return f"trans{su._id}" @@ -331,7 +331,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: """ Emit su code """ - fsu_types = [SUType.Forward, SUType.Comm, SUType.Transform, SUType.Optimizer] + fsu_types = [SUType.Forward, SUType.P2P, SUType.Transform, SUType.Optimizer] fsign = 'cube.runtime.executor.fexecute({model}, *{inputs})' bsign = 'cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index d1f461e4..455a30ab 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -82,7 +82,7 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): span = 1 elif su.stype == SUType.Backward: span = 2 - elif su.stype == SUType.Comm: + elif su.stype == SUType.P2P: span = 0.1 spans.append(span) diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index 6ce3a8f3..ab535aa3 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -19,7 +19,7 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: pieces: List[ScheduleUnit] = list() adapters: List[ScheduleUnit] = list() for seqidx, su in enumerate(dev_seq): - if su and su.stype in [SUType.Comm, SUType.Transform]: + if su and su.stype in [SUType.P2P, SUType.Transform]: if len(pieces) > 0: adapters.append(su) continue diff --git a/cube/execplan/planpass/redundant.py b/cube/execplan/planpass/redundant.py index aef9c8ed..7dc2731f 100644 --- a/cube/execplan/planpass/redundant.py +++ b/cube/execplan/planpass/redundant.py @@ -14,7 +14,7 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: """ for devid in execplan.devices(): seq = execplan.sequence(devid) - comms = [su for su in seq if su.stype == SUType.Comm] + comms = [su for su in seq if su.stype == SUType.P2P] for comm in comms: send_ranks = set([devid]) recv_ranks = set([devid]) diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 2c87fe30..26acdcbd 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -23,7 +23,8 @@ class SUType(Enum): # cube.runtime.collectives.sendrecv(send_tensors, send_ranks, # recv_shapes, from_ranks # ) - Comm = 'cube.runtime.adapter.sendrecv' + P2P = 'cube.runtime.adapter.sendrecv' + Coll = 'cube.runtime.adapter.coll' Optimizer = 'cube.runtime.reducer.Reduce' diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index ef564c9b..1d5ec63f 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -61,7 +61,7 @@ def reset_dependency(sus: List[ScheduleUnit]): for dst in sus[src_idx+1:]: for out_idx, out_tensor in enumerate(src.outputs()): # special dependency for communication adapter - if dst.stype == SUType.Comm: + if dst.stype == SUType.P2P: for recv_tensor in dst.outputs(): if out_tensor.overlap(recv_tensor): src.add_successor(out_idx, dst) @@ -186,14 +186,14 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: # condition 2): su2 input cannot be got from both su1 and other su start, stop = min(idx1, idx2), max(idx1, idx2) inter_sus = self.sequence[start+1:stop] - inter_sus = [su for su in inter_sus if su.stype != SUType.Comm] + inter_sus = [su for su in inter_sus if su.stype != SUType.P2P] for su in inter_sus: # FIXME: currently only allow other device su exists if self.happen_before(su1, su) or self.happen_before(su, su2): return None for idx in range(len(su2.inputs())): prev_sus = su2.predecessors(idx) - prev_sus = [su for su in prev_sus if su.stype != SUType.Comm] + prev_sus = [su for su in prev_sus if su.stype != SUType.P2P] if su2 in prev_sus and len(prev_sus) > 1: return None @@ -313,7 +313,7 @@ def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): elif not all([isinstance(rank, int) for rank in ranks]): raise TypeError("Expected type ranks to be Union[int, List[int]]") - if su.stype == SUType.Comm: + if su.stype == SUType.P2P: return False if set(su.device) == set(ranks): @@ -449,8 +449,8 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: recv_ranks = [-1] ) send_op.pair(recv_op) - send_su = ScheduleUnit([send_op], SUType.Comm, name='send') - recv_su = ScheduleUnit([recv_op], SUType.Comm, name='recv') + send_su = ScheduleUnit([send_op], SUType.P2P, name='send') + recv_su = ScheduleUnit([recv_op], SUType.P2P, name='recv') su._add_in_adapter(in_idx, send_su, recv_su) send_su.device = su.device pre_su._add_out_adapter(out_idx, send_su, recv_su) diff --git a/tests/execplan/test_planpass_merge.py b/tests/execplan/test_planpass_merge.py index 936e912a..215c9711 100644 --- a/tests/execplan/test_planpass_merge.py +++ b/tests/execplan/test_planpass_merge.py @@ -65,7 +65,7 @@ def test_planpass_merge(): sugraph = SUGraphGener.gen_sugraph(nodes) for su in sugraph.sus(): - if su.stype != SUType.Comm: + if su.stype != SUType.P2P: sugraph.assign(su, 0) print('orignal:') @@ -80,5 +80,5 @@ def test_planpass_merge(): print(f'> device {devid}') for su in execplan.sequence(devid): print(su) - assert su.stype != SUType.Comm + assert su.stype != SUType.P2P assert len(execplan.sequence(0)) == 2 \ No newline at end of file diff --git a/tests/execplan/test_planpass_redundant.py b/tests/execplan/test_planpass_redundant.py index d63c99fb..30cafb11 100644 --- a/tests/execplan/test_planpass_redundant.py +++ b/tests/execplan/test_planpass_redundant.py @@ -74,5 +74,5 @@ def test_remove_adapter(): print(f'> device {devid}') for su in execplan.sequence(devid): print(su) - assert su.stype != SUType.Comm + assert su.stype != SUType.P2P assert len(execplan.sequence(0)) == 6 \ No newline at end of file diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index 012efceb..0ddfd868 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -150,8 +150,8 @@ def test_sugraph_assign1(): recv_ranks = [-1] ) send_op.pair(recv_op) - send_su12 = ScheduleUnit([send_op], SUType.Comm, name='send') - recv_su12 = ScheduleUnit([recv_op], SUType.Comm, name='recv') + send_su12 = ScheduleUnit([send_op], SUType.P2P, name='send') + recv_su12 = ScheduleUnit([recv_op], SUType.P2P, name='recv') su1._add_out_adapter(0, send_su12, recv_su12) su2._add_in_adapter(0, send_su12, recv_su12) @@ -165,8 +165,8 @@ def test_sugraph_assign1(): recv_ranks = [-1] ) send_op.pair(recv_op) - send_su23 = ScheduleUnit([send_op], SUType.Comm, name='send') - recv_su23 = ScheduleUnit([recv_op], SUType.Comm, name='recv') + send_su23 = ScheduleUnit([send_op], SUType.P2P, name='send') + recv_su23 = ScheduleUnit([recv_op], SUType.P2P, name='recv') su2._add_out_adapter(0, send_su23, recv_su23) su3._add_in_adapter(0, send_su23, recv_su23) @@ -215,8 +215,8 @@ def test_sugraph_assign2(): recv_ranks = [-1] ) send_op.pair(recv_op) - send_su12 = ScheduleUnit([send_op], SUType.Comm, name='send') - recv_su12 = ScheduleUnit([recv_op], SUType.Comm, name='recv') + send_su12 = ScheduleUnit([send_op], SUType.P2P, name='send') + recv_su12 = ScheduleUnit([recv_op], SUType.P2P, name='recv') su1._add_out_adapter(0, send_su12, recv_su12) su2._add_in_adapter(0, send_su12, recv_su12) @@ -230,8 +230,8 @@ def test_sugraph_assign2(): recv_ranks = [-1] ) send_op.pair(recv_op) - send_su23 = ScheduleUnit([send_op], SUType.Comm, name='send') - recv_su23 = ScheduleUnit([recv_op], SUType.Comm, name='recv') + send_su23 = ScheduleUnit([send_op], SUType.P2P, name='send') + recv_su23 = ScheduleUnit([recv_op], SUType.P2P, name='recv') su2._add_out_adapter(0, send_su23, recv_su23) su3._add_in_adapter(0, send_su23, recv_su23) diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py index cdcca670..abf8f84c 100644 --- a/tests/schedule/test_translator.py +++ b/tests/schedule/test_translator.py @@ -143,10 +143,10 @@ def test_sugraph_gener_gen(): assert su1.stype == SUType.Forward assert su2.stype == SUType.Forward assert su3.stype == SUType.Forward - assert send_su12.stype == SUType.Comm - assert recv_su12.stype == SUType.Comm - assert send_su23.stype == SUType.Comm - assert recv_su23.stype == SUType.Comm + assert send_su12.stype == SUType.P2P + assert recv_su12.stype == SUType.P2P + assert send_su23.stype == SUType.P2P + assert recv_su23.stype == SUType.P2P # backward adapters output.backward() From 499305815ae94f53cfffa986bc35c15029901d45 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 16 Nov 2021 14:51:26 +0800 Subject: [PATCH 0313/1892] p2p fusion on allgather --- cube/execplan/planpass/p2pfusion.py | 43 ++++++++++++++++++++++------- cube/runtime/collectives.py | 6 +++- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index a7ee18ff..c81dfbe8 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -38,11 +38,14 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: continue tous, tins = fous[pid], fins[pid] if P2PFusion.have_comm(tous, tins): - colls = None + colls : List[ScheduleUnit] = None for matcher in matchers: colls = matcher(tous, tins) if colls: break + if colls is not None: + for coll_su in colls: + P2PFusion.add_collective(execplan, coll_su) return execplan @staticmethod @@ -96,6 +99,28 @@ def have_comm(tensor_ous, tensor_ins): return True return False + @staticmethod + def add_collective(execplan: ExectuionPlan, coll_su: ScheduleUnit): + print(f'inserting Collective SU: {coll_su}') + # find insert place: the first send + devid = coll_su.device[0] + ranks = coll_su.nodes(0).ranks + for idx, su in enumerate(execplan.sequence(devid)): + # send + if su.stype == SUType.P2P and len(su.inputs()) == 1: + if su.inputs(0) in coll_su.inputs(): + execplan.at(devid)[idx] = coll_su + break + else: + raise RuntimeError("Cannot find a send P2P") + # all the send, recv of the inputs will be removed in ranks + for input in coll_su.inputs(): + for rank in ranks: + for su in execplan.sequence(rank): + # remove send / recv + if su.stype == SUType.P2P and input in (su.inputs() + su.outputs()): + execplan.at(rank).remove(su) + @staticmethod def transmission(tensor_ous, in_tensor) -> Dict[int, List[IRTensor]]: trans_tensors = dict() @@ -142,7 +167,6 @@ def match_allgather(tous, tins): in_devices[tid].append(devid) in_tensors[tid].append(in_tensor) for tid in in_devices: - share_tensors = in_tensors[tid] # P2P transmission if len(in_devices[tid]) <= 1: continue @@ -175,14 +199,13 @@ def match_allgather(tous, tins): ranks = list(out_tensors.keys()) inputs = [[out_tensors[rank][0]] for rank in ranks] - outputs = list() - for rank in ranks: - sh_idx = in_devices[tid].index(rank) - outputs.append([share_tensors[sh_idx]]) - - for input, output, rank in zip(inputs, outputs, ranks): - op = IRCollectives(input, output, ranks, IRCollType.AllGather) - su = ScheduleUnit([op], SUType.Comm, name='allgather') + + for input, rank in zip(inputs, ranks): + outputs = [ + input[0] for idx, input in enumerate(inputs) if idx != rank + ] + op = IRCollectives(input, outputs, ranks, IRCollType.AllGather) + su = ScheduleUnit([op], SUType.Coll, name='allgather') su.device = rank allgather_sus.append(su) diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 00de5fe6..e2c104d1 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -119,7 +119,11 @@ def all_gather(tensors: List[torch.Tensor], ranks: List[int]): idx = ranks.index(DeviceGroup().rank) tensor_list[idx] = tensor torch.distributed.all_gather(tensor_list, group=group) - return tensor_list + tensor_list = [t for oidx, t in enumerate(tensor_list) if oidx != idx] + if len(tensor_list) == 1: + return tensor_list[0] + else: + return tensor_list def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): From 1d1c20db5b4ed697dcc516bc4c6afaa0f5f91cfa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 16 Nov 2021 15:15:59 +0800 Subject: [PATCH 0314/1892] enable allgather codegen --- cube/codegen/codegen.py | 24 +++++++++++++++++++++--- cube/execplan/planpass/p2pfusion.py | 22 ++++++++++++++++++---- cube/runtime/collectives.py | 2 +- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 960aaaf9..00ef38ad 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -2,12 +2,14 @@ Generate Pytorch code given the model DAG and the transformation config """ from typing import List, Any +from numpy import isin import torch import copy from cube.graph.operator.operator import IRFwOperation, IROptimOperation from cube.ir.cten import IRTensor from cube.execplan import ExectuionPlan +from cube.schedule.adapter.collectives import IRCollectives from cube.schedule.su import ScheduleUnit, SUType from cube.schedule.adapter.comm import IRCommType, IRCommunication @@ -32,7 +34,9 @@ def su_naming(self, su: ScheduleUnit) -> str: if su.stype == SUType.Backward: return f"bwcp{su._id}" if su.stype == SUType.P2P: - return f"comm{su._id}" + return f"p2p{su._id}" + if su.stype == SUType.Coll: + return f"coll{su._id}" if su.stype == SUType.Transform: return f"trans{su._id}" if su.stype == SUType.Optimizer: @@ -101,13 +105,17 @@ def gen(self, device: int, outfile=None, attach=False) -> str: for node in su.nodes(): if isinstance(node, IRFwOperation): self.emit_op_call(node) - if isinstance(node, IRTensorTransform): + elif isinstance(node, IRTensorTransform): self.emit_transform_call(node) elif isinstance(node, IRCommunication): self.emit_comm_call(node) + elif isinstance(node, IRCollectives): + self.emit_collective_call(node) elif isinstance(node, IROptimOperation): self.emit_optim_init(node) self.emit_optim_call(node) + else: + raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") # emit input declaration for arg in node.inputs(): self.emit_var_declare(arg) @@ -207,6 +215,16 @@ def emit_comm_call(self, node): raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") self.forward_region.append(code) + + def emit_collective_call(self, node): + ranks = node.ranks + inputs = self._forward_region_arg_names(node.inputs()) + inputs = '(' + ', '.join(inputs + ['']) + ')' + outputs = self._forward_region_arg_names(node.outputs()) + outputs = ', '.join(outputs) + code = f'{outputs} = {node.signature}({inputs}, {ranks})' + self.forward_region.append(code) + def emit_transform_call(self, node: IRTensorTransform): """ Emit in-device tensor select / merge call. @@ -331,7 +349,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: """ Emit su code """ - fsu_types = [SUType.Forward, SUType.P2P, SUType.Transform, SUType.Optimizer] + fsu_types = [SUType.Forward, SUType.P2P, SUType.Coll, SUType.Transform, SUType.Optimizer] fsign = 'cube.runtime.executor.fexecute({model}, *{inputs})' bsign = 'cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index c81dfbe8..af11b6f5 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -106,11 +106,13 @@ def add_collective(execplan: ExectuionPlan, coll_su: ScheduleUnit): devid = coll_su.device[0] ranks = coll_su.nodes(0).ranks for idx, su in enumerate(execplan.sequence(devid)): - # send - if su.stype == SUType.P2P and len(su.inputs()) == 1: - if su.inputs(0) in coll_su.inputs(): + # send or recv + if su.stype == SUType.P2P: + sr_tensor = (su.inputs() + su.outputs())[0] + if sr_tensor in coll_su.inputs() + coll_su.outputs(): execplan.at(devid)[idx] = coll_su break + else: raise RuntimeError("Cannot find a send P2P") # all the send, recv of the inputs will be removed in ranks @@ -228,7 +230,19 @@ def match_reducescatter(tous, tins): Each device performs same """ - return None + rs_sus = list() + # {tensor_id: [device_id]} + in_devices: Dict[int, List[int]] = dict() + # {tensor_id: [tensors] + in_tensors: Dict[int, List[IRTensor]] = dict() + for devid in tins: + for in_tensor in tins[devid]: + tid = in_tensor._id + if tid not in in_devices: + in_devices[tid] = list() + in_tensors[tid] = list() + in_devices[tid].append(devid) + in_tensors[tid].append(in_tensor) @staticmethod def match_broadcast(tous, tins): diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index e2c104d1..2546b933 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -118,7 +118,7 @@ def all_gather(tensors: List[torch.Tensor], ranks: List[int]): tensor_list = [torch.empty_like(tensor) for _ in ranks] idx = ranks.index(DeviceGroup().rank) tensor_list[idx] = tensor - torch.distributed.all_gather(tensor_list, group=group) + torch.distributed.all_gather(tensor_list, tensor, group=group) tensor_list = [t for oidx, t in enumerate(tensor_list) if oidx != idx] if len(tensor_list) == 1: return tensor_list[0] From 434e8700c3ccecd117aad493ed8736e1d88cad7e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 16 Nov 2021 17:11:01 +0800 Subject: [PATCH 0315/1892] enable p2p fusion to reduce-scatter --- cube/execplan/planpass/p2pfusion.py | 110 ++++++++++++++++++++++++++++ cube/runtime/collectives.py | 3 +- 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index af11b6f5..e96ab541 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -1,5 +1,6 @@ from typing import List, Dict from cube.execplan import ExectuionPlan +from cube.graph.tensor import ValueMap from cube.ir.cten import IRTensor from cube.schedule.su import SUType, ScheduleUnit from cube.execplan.planpass.planpass import PlanPass @@ -46,6 +47,19 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: if colls is not None: for coll_su in colls: P2PFusion.add_collective(execplan, coll_su) + for pid in bins: + if pid not in bous: + continue + tous, tins = bous[pid], bins[pid] + if P2PFusion.have_comm(tous, tins): + colls : List[ScheduleUnit] = None + for matcher in matchers: + colls = matcher(tous, tins) + if colls: + break + if colls is not None: + for coll_su in colls: + P2PFusion.add_collective(execplan, coll_su) return execplan @staticmethod @@ -112,6 +126,13 @@ def add_collective(execplan: ExectuionPlan, coll_su: ScheduleUnit): if sr_tensor in coll_su.inputs() + coll_su.outputs(): execplan.at(devid)[idx] = coll_su break + # merge + if su.stype == SUType.Transform and len(su.inputs()) > 1: + merge_out = su.outputs(0) + if merge_out in coll_su.outputs(): + assert len(coll_su.outputs()) == 1 + execplan.at(devid)[idx] = coll_su + break else: raise RuntimeError("Cannot find a send P2P") @@ -122,6 +143,12 @@ def add_collective(execplan: ExectuionPlan, coll_su: ScheduleUnit): # remove send / recv if su.stype == SUType.P2P and input in (su.inputs() + su.outputs()): execplan.at(rank).remove(su) + # remove merge if coll generate merge results + if su.stype == SUType.Transform and len(su.inputs()) > 1: + merge_out = su.outputs(0) + if merge_out in coll_su.outputs(): + assert len(coll_su.outputs()) == 1 + execplan.at(rank).remove(su) @staticmethod def transmission(tensor_ous, in_tensor) -> Dict[int, List[IRTensor]]: @@ -238,11 +265,94 @@ def match_reducescatter(tous, tins): for devid in tins: for in_tensor in tins[devid]: tid = in_tensor._id + if in_tensor.val_map != ValueMap(0, 1): + continue if tid not in in_devices: in_devices[tid] = list() in_tensors[tid] = list() in_devices[tid].append(devid) in_tensors[tid].append(in_tensor) + # {in_tensor_id: [reduce_tensor device]} + reduce_out_devices = dict() + # {in_tensor_id: [reduce out tensors]} + reduce_out_tensors = dict() + for tid in in_devices: + # P2P transmission + if len(in_devices[tid]) != 1: + continue + in_tensor = in_tensors[tid][0] + out_tensors = P2PFusion.transmission(tous, in_tensor) + + is_reduce = True + for devid in out_tensors: + # multiple transmission FIXME: remove redundancy + if not all([len(out_tensors[odev]) == 1 for odev in out_tensors]): + continue + if out_tensors[devid][0].val_map == ValueMap(0, 1): + is_reduce = False + break + if out_tensors[devid][0].indices != in_tensor.indices: + is_reduce = False + break + if is_reduce: + reduce_out_devices[tid] = list() + reduce_out_tensors[tid] = list() + for devid in out_tensors: + reduce_out_devices[tid].append(devid) + reduce_out_tensors[tid].append(out_tensors[devid][0]) + # reverse reduce_devices {tuple(devices): [in_tensors]} + reduce_tensors = dict() + for tid in reduce_out_devices: + devices = tuple(set(reduce_out_devices[tid])) + if devices not in reduce_tensors: + reduce_tensors[devices] = list() + reduce_tensors[devices].append(in_tensors[tid][0]) + # check conditions + for ranks in reduce_tensors: + reduce_in_tensors = reduce_tensors[ranks] + # reduce-scatter requires tensor num to be equal of num devs + if len(reduce_in_tensors) != len(ranks): + continue + # reduce in tensors should place on different devices + devices = [t.device[0] for t in reduce_in_tensors] + if set(devices) != set(ranks): + continue + + # satisfied! set up inputs, outputs and ranks + ranks = list(ranks) + ranks.sort() + + device_inputs = [None] * len(ranks) + for in_tensor in reduce_in_tensors: + out_tensors = reduce_out_tensors[in_tensor._id] + out_devs = [t.device[0] for t in out_tensors] + inputs = [ + out_tensors[out_devs.index(odev)] for odev in ranks + ] + ridx = ranks.index(in_tensor.device[0]) + device_inputs[ridx] = inputs + for in_tensor in reduce_in_tensors: + rank = in_tensor.device[0] + outputs = [in_tensor] + inputs = [inputs[rank] for inputs in device_inputs] + op = IRCollectives(inputs, outputs, ranks, IRCollType.ReduceScatter) + su = ScheduleUnit([op], SUType.Coll, name='reducescatter') + su.device = rank + rs_sus.append(su) + + print('>> find reduce-scatter pattern:') + print(f'device group: {ranks}') + for output in reduce_in_tensors: + tid = output._id + for input in reduce_out_tensors[tid]: + print(f'src: {input}') + print(f'dst: {output}') + + if len(rs_sus) == 0: + return None + else: + return rs_sus + @staticmethod def match_broadcast(tous, tins): diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 2546b933..6b067ce2 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -130,9 +130,10 @@ def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): """ ReduceScatter """ + tensors = list(tensors) group = DeviceGroup().get_group(ranks) idx = ranks.index(DeviceGroup().rank) - output = tensors[idx] + output = torch.empty_like(tensors[idx]) torch.distributed.reduce_scatter( output, tensors, group=group ) From c768d721a55a95ef929433c4eb7cd1e248048e78 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 16 Nov 2021 19:04:16 +0800 Subject: [PATCH 0316/1892] enable allreduce p2p fusion --- cube/compiler.py | 8 +- cube/execplan/planpass/p2pfusion.py | 190 ++++++++++++++++++---------- cube/runtime/collectives.py | 4 +- 3 files changed, 135 insertions(+), 67 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 5f247522..6359445b 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -12,6 +12,7 @@ from cube.execplan.planpass.redundant import RemoveRedundantAdapters from cube.execplan.planpass.merge import MergeComputeSU from cube.execplan.planpass.gfuse import WeightGradAllreduceFusion +from cube.execplan.planpass.p2pfusion import P2PFusion from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -145,11 +146,14 @@ def decorator(fn: Callable) -> Callable: execplan = ExectuionPlan(sugraph) # plan pass to remove redundant sus execplan = RemoveRedundantAdapters.apply(execplan) - print(f'> after remove redundant adapters:\n {execplan}') + # print(f'> after remove redundant adapters:\n {execplan}') execplan = MergeComputeSU.apply(execplan) print(f'> after merge backward SU:\n {execplan}') execplan = WeightGradAllreduceFusion.apply(execplan) - print(f'> after add allreduce:\n{execplan}') + # print(f'> after add allreduce:\n{execplan}') + + execplan = P2PFusion.apply(execplan) + print(f'> after fuse P2P SU:\n {execplan}') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index e96ab541..dae5701b 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -31,35 +31,24 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: for val in fins[pid][devid]: print(f' i:', val) matchers = [ - P2PFusion.match_allreduce, P2PFusion.match_allgather, - P2PFusion.match_reducescatter, P2PFusion.match_broadcast, + P2PFusion.match_allreduce, + P2PFusion.match_allgather, + P2PFusion.match_reducescatter, + P2PFusion.match_broadcast, ] - for pid in fins: - if pid not in fous: - continue - tous, tins = fous[pid], fins[pid] - if P2PFusion.have_comm(tous, tins): - colls : List[ScheduleUnit] = None - for matcher in matchers: - colls = matcher(tous, tins) - if colls: - break - if colls is not None: - for coll_su in colls: - P2PFusion.add_collective(execplan, coll_su) - for pid in bins: - if pid not in bous: - continue - tous, tins = bous[pid], bins[pid] - if P2PFusion.have_comm(tous, tins): - colls : List[ScheduleUnit] = None - for matcher in matchers: - colls = matcher(tous, tins) - if colls: - break - if colls is not None: - for coll_su in colls: - P2PFusion.add_collective(execplan, coll_su) + for ous, ins in zip([fous, bous], [fins, bins]): + for pid in ins: + if pid not in ous: + continue + tous, tins = ous[pid], ins[pid] + if P2PFusion.have_comm(tous, tins): + colls : List[ScheduleUnit] = None + for matcher in matchers: + colls = matcher(tous, tins) + if colls: + break + if colls is not None: + P2PFusion.add_collectives(execplan, colls) return execplan @staticmethod @@ -114,50 +103,51 @@ def have_comm(tensor_ous, tensor_ins): return False @staticmethod - def add_collective(execplan: ExectuionPlan, coll_su: ScheduleUnit): - print(f'inserting Collective SU: {coll_su}') - # find insert place: the first send - devid = coll_su.device[0] - ranks = coll_su.nodes(0).ranks - for idx, su in enumerate(execplan.sequence(devid)): - # send or recv - if su.stype == SUType.P2P: - sr_tensor = (su.inputs() + su.outputs())[0] - if sr_tensor in coll_su.inputs() + coll_su.outputs(): - execplan.at(devid)[idx] = coll_su - break - # merge - if su.stype == SUType.Transform and len(su.inputs()) > 1: - merge_out = su.outputs(0) - if merge_out in coll_su.outputs(): - assert len(coll_su.outputs()) == 1 - execplan.at(devid)[idx] = coll_su - break - - else: - raise RuntimeError("Cannot find a send P2P") + def add_collectives(execplan: ExectuionPlan, coll_sus: List[ScheduleUnit]): + for coll_su in coll_sus: + print(f'inserting Collective SU: {coll_su}') + # find insert place: the first send + devid = coll_su.device[0] + ranks = coll_su.nodes(0).ranks + for idx, su in enumerate(execplan.sequence(devid)): + # send or recv + if su.stype == SUType.P2P: + sr_tensor = (su.inputs() + su.outputs())[0] + if sr_tensor in coll_su.inputs() + coll_su.outputs(): + execplan.at(devid)[idx] = coll_su + break + # merge + if su.stype == SUType.Transform and len(su.inputs()) > 1: + merge_out = su.outputs(0) + if merge_out in coll_su.outputs(): + assert len(coll_su.outputs()) == 1 + execplan.at(devid)[idx] = coll_su + break + else: + raise RuntimeError("Cannot find a send P2P") # all the send, recv of the inputs will be removed in ranks - for input in coll_su.inputs(): - for rank in ranks: - for su in execplan.sequence(rank): - # remove send / recv - if su.stype == SUType.P2P and input in (su.inputs() + su.outputs()): - execplan.at(rank).remove(su) - # remove merge if coll generate merge results - if su.stype == SUType.Transform and len(su.inputs()) > 1: - merge_out = su.outputs(0) - if merge_out in coll_su.outputs(): - assert len(coll_su.outputs()) == 1 + for coll_su in coll_sus: + for input in coll_su.inputs(): + for rank in ranks: + for su in execplan.sequence(rank): + # remove send / recv + if su.stype == SUType.P2P and input in (su.inputs() + su.outputs()): execplan.at(rank).remove(su) + # remove merge if coll generate merge results + if su.stype == SUType.Transform and len(su.inputs()) > 1: + merge_out = su.outputs(0) + if merge_out in coll_su.outputs(): + assert len(coll_su.outputs()) == 1 + execplan.at(rank).remove(su) @staticmethod def transmission(tensor_ous, in_tensor) -> Dict[int, List[IRTensor]]: trans_tensors = dict() for devid in tensor_ous: for out in tensor_ous[devid]: - if devid not in trans_tensors: - trans_tensors[devid] = list() if in_tensor.overlap(out): + if devid not in trans_tensors: + trans_tensors[devid] = list() trans_tensors[devid].append(out) return trans_tensors @@ -170,7 +160,79 @@ def match_allreduce(tous, tins): sends to all device. The recved tensors are summed into one """ - return None + allreduce_sus = list() + # {tensor_id: [device_id]} + in_devices: Dict[int, List[int]] = dict() + # {tensor_id: [tensors] + in_tensors: Dict[int, List[IRTensor]] = dict() + for devid in tins: + for in_tensor in tins[devid]: + if in_tensor.val_map != ValueMap(0, 1): + continue + tid = in_tensor._id + if tid not in in_devices: + in_devices[tid] = list() + in_tensors[tid] = list() + in_devices[tid].append(devid) + in_tensors[tid].append(in_tensor) + for tid in in_devices: + # P2P transmission + if len(in_devices[tid]) <= 1: + continue + in_tensor = in_tensors[tid][0] + # {rank: [IRTensor]}} + out_tensors = P2PFusion.transmission(tous, in_tensor) + out_devices = set(out_tensors.keys()) + # check out tensor and reduce in tensor devices are the same set + if out_devices == set(in_devices[tid]): + # multiple transmission FIXME: remove redundancy + if not all([len(out_tensors[odev]) == 1 for odev in out_devices]): + continue + # check same indice map and no overlap value map + unique_indices = list() + for odev in out_tensors: + indices = out_tensors[odev][0].indices + if indices not in unique_indices: + unique_indices.append(indices) + if len(unique_indices) != 1: + continue + # check no overlap valmaps + all_valmaps = list() + overlap = False + for odev in out_tensors: + valmap = out_tensors[odev][0].val_map + for pre_valmp in all_valmaps: + overlap = pre_valmp.overlap(valmap) + all_valmaps.append(valmap) + if overlap: + continue + + ranks = list(out_tensors.keys()) + inputs = [[out_tensors[rank][0]] for rank in ranks] + + for input, rank in zip(inputs, ranks): + for in_tensor in in_tensors[tid]: + if in_tensor.device[0] == rank: + outputs = [in_tensor] + break + else: + raise RuntimeError("Internal Error") + op = IRCollectives(input, outputs, ranks, IRCollType.AllReduce) + su = ScheduleUnit([op], SUType.Coll, name='allgather') + su.device = rank + allreduce_sus.append(su) + + print('>> find allreduce pattern:') + print(f'device group: {ranks}') + for input in inputs: + print(f'src: {input}') + for output in outputs: + print(f'dst: {output}') + + if len(allreduce_sus) == 0: + return None + else: + return allreduce_sus @staticmethod def match_allgather(tous, tins): diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 6b067ce2..23e75e6a 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -103,6 +103,8 @@ def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): """ assert len(tensors) == 1 tensor = tensors[0] + tensor = tensor.detach() + tensor = tensor.requires_grad_() group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(tensor, group=group) return tensor @@ -133,7 +135,7 @@ def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): tensors = list(tensors) group = DeviceGroup().get_group(ranks) idx = ranks.index(DeviceGroup().rank) - output = torch.empty_like(tensors[idx]) + output = torch.empty_like(tensors[idx], requires_grad=True) torch.distributed.reduce_scatter( output, tensors, group=group ) From 695e4705de7cdfd659c6a5c93bcff7e1d992a890 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 16 Nov 2021 20:21:32 +0800 Subject: [PATCH 0317/1892] enable p2p fusion to broadcast --- cube/codegen/codegen.py | 12 ++- cube/execplan/planpass/p2pfusion.py | 145 ++++++++++++++++++++------- cube/runtime/collectives.py | 9 +- cube/schedule/adapter/collectives.py | 3 + 4 files changed, 125 insertions(+), 44 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 00ef38ad..a6e8fdbd 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -215,14 +215,22 @@ def emit_comm_call(self, node): raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") self.forward_region.append(code) - def emit_collective_call(self, node): ranks = node.ranks inputs = self._forward_region_arg_names(node.inputs()) + shape = None + if len(inputs) == 0: + assert len(node.outputs()) == 1 + shape = node.outputs(0).shape inputs = '(' + ', '.join(inputs + ['']) + ')' outputs = self._forward_region_arg_names(node.outputs()) outputs = ', '.join(outputs) - code = f'{outputs} = {node.signature}({inputs}, {ranks})' + if shape: + code = f'{node.signature}({inputs}, {ranks}, {shape})' + else: + code = f'{node.signature}({inputs}, {ranks})' + if outputs: + code = f'{outputs} = {code}' self.forward_region.append(code) def emit_transform_call(self, node: IRTensorTransform): diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index dae5701b..3bccfc62 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -13,23 +13,28 @@ class P2PFusion(PlanPass): @staticmethod def apply(execplan: ExectuionPlan) -> ExectuionPlan: # dict[pid][devid] = list of sub_tensors - fous, fins = P2PFusion.collect_tensors(execplan, SUType.Forward) - bous, bins = P2PFusion.collect_tensors(execplan, SUType.Backward) + fous, fins = P2PFusion.collect_tensors( + execplan, [SUType.Dataloader, SUType.Forward] + ) + bous, bins = P2PFusion.collect_tensors( + execplan, [SUType.Backward] + ) # debug - print('=====> forward') - for pid in fins: - if pid not in fous: - continue - if P2PFusion.have_comm(fous[pid], fins[pid]): - print(f'=> parent tensor id: {pid}') - for devid in fous[pid]: - print(f' ==> device: {devid}') - for val in fous[pid][devid]: - print(f' o:', val) - for devid in fins[pid]: - print(f' ==> device: {devid}') - for val in fins[pid][devid]: - print(f' i:', val) + # print('=====> forward') + # for pid in fins: + # if pid not in fous: + # continue + # if P2PFusion.have_comm(fous[pid], fins[pid]): + # print(f'=> parent tensor id: {pid}') + # for devid in fous[pid]: + # print(f' ==> device: {devid}') + # for val in fous[pid][devid]: + # print(f' o:', val) + # for devid in fins[pid]: + # print(f' ==> device: {devid}') + # for val in fins[pid][devid]: + # print(f' i:', val) + matchers = [ P2PFusion.match_allreduce, P2PFusion.match_allgather, @@ -52,16 +57,17 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: return execplan @staticmethod - def collect_tensors(execplan: ExectuionPlan, stype: SUType): + def collect_tensors(execplan: ExectuionPlan, stypes: List[SUType]): # dict[pid][devid] = list of sub_tensors ous = dict() ins = dict() for devid in execplan.devices(): dev_seq = execplan.sequence(devid) for su in dev_seq: - if su.stype == stype: + if su.stype in stypes: for val in su.inputs(): - if isinstance(val, IRTensor): + # FIXME: remove parameter constraints + if isinstance(val, IRTensor) and not val.is_param(): pid = val.parent._id if pid not in ins: ins[pid] = dict() @@ -222,12 +228,12 @@ def match_allreduce(tous, tins): su.device = rank allreduce_sus.append(su) - print('>> find allreduce pattern:') - print(f'device group: {ranks}') - for input in inputs: - print(f'src: {input}') - for output in outputs: - print(f'dst: {output}') + # print('>> find allreduce pattern:') + # print(f'device group: {ranks}') + # for input in inputs: + # print(f'src: {input}') + # for output in outputs: + # print(f'dst: {output}') if len(allreduce_sus) == 0: return None @@ -300,12 +306,12 @@ def match_allgather(tous, tins): su.device = rank allgather_sus.append(su) - print('>> find allgather pattern:') - print(f'device group: {ranks}') - for input in inputs: - print(f'src: {input}') - for output in outputs: - print(f'dst: {output}') + # print('>> find allgather pattern:') + # print(f'device group: {ranks}') + # for input in inputs: + # print(f'src: {input}') + # for output in outputs: + # print(f'dst: {output}') if len(allgather_sus) == 0: return None @@ -402,13 +408,13 @@ def match_reducescatter(tous, tins): su.device = rank rs_sus.append(su) - print('>> find reduce-scatter pattern:') - print(f'device group: {ranks}') - for output in reduce_in_tensors: - tid = output._id - for input in reduce_out_tensors[tid]: - print(f'src: {input}') - print(f'dst: {output}') + # print('>> find reduce-scatter pattern:') + # print(f'device group: {ranks}') + # for output in reduce_in_tensors: + # tid = output._id + # for input in reduce_out_tensors[tid]: + # print(f'src: {input}') + # print(f'dst: {output}') if len(rs_sus) == 0: return None @@ -423,4 +429,65 @@ def match_broadcast(tous, tins): The root device send the its tensor to all the devices """ - return None \ No newline at end of file + broadcast_sus = list() + # {tensor_id: [device_id]} + in_devices: Dict[int, List[int]] = dict() + # {tensor_id: [tensors] + in_tensors: Dict[int, List[IRTensor]] = dict() + for devid in tins: + for in_tensor in tins[devid]: + tid = in_tensor._id + if in_tensor.val_map != ValueMap(0, 1): + continue + if tid not in in_devices: + in_devices[tid] = list() + in_tensors[tid] = list() + in_devices[tid].append(devid) + in_tensors[tid].append(in_tensor) + + for tid in in_devices: + # P2P transmission + if len(in_devices[tid]) <= 2: + continue + in_tensor = in_tensors[tid][0] + out_tensors = P2PFusion.transmission(tous, in_tensor) + # multiple transmission FIXME: remove redundancy + if len(out_tensors.keys()) != 1: + continue + # multiple transmission FIXME: remove redundancy + if len(out_tensors[list(out_tensors.keys())[0]]) != 1: + continue + root_tensor = out_tensors[list(out_tensors.keys())[0]][0] + is_equal = True + for in_tensor in in_tensors[tid]: + if in_tensor != root_tensor: + is_equal = False + break + if not is_equal: + continue + ranks = [root_tensor.device[0]] + inputs = [[root_tensor],] + outputs = [[],] + for output in in_tensors[tid]: + devid = output.device[0] + if devid in ranks: + continue + ranks.append(devid) + outputs.append([output]) + inputs.append([]) + for input, output, rank in zip(inputs, outputs, ranks): + op = IRCollectives(input, output, ranks, IRCollType.Broadcast) + su = ScheduleUnit([op], SUType.Coll, name='broadcast') + su.device = rank + broadcast_sus.append(su) + + print('>> find broadcast pattern:') + print(f'device group: {ranks}') + print(su) + + if len(broadcast_sus) == 0: + return None + + + else: + return broadcast_sus diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 23e75e6a..c5605136 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -142,12 +142,15 @@ def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): return output -def broadcast(tensors: List[torch.Tensor], ranks: List[int]): +def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None): """ Broadcast. ranks[0] is the root """ - assert len(tensors) == 1 - tensor = tensors[0] + if len(tensors) == 1: + tensor = tensors[0] + else: + tensor = torch.empty(shape, device=torch.cuda.current_device()) + # tensor.requires_grad_() group = DeviceGroup().get_group(ranks) torch.distributed.broadcast(tensor, ranks[0], group=group) return tensor diff --git a/cube/schedule/adapter/collectives.py b/cube/schedule/adapter/collectives.py index a361d1b3..a004361c 100644 --- a/cube/schedule/adapter/collectives.py +++ b/cube/schedule/adapter/collectives.py @@ -9,6 +9,7 @@ class IRCollType(Enum): AllReduce = 'all_reduce' AllGather = 'all_gather' ReduceScatter = 'reduce_scatter' + Broadcast = 'broadcast' class IRCollectives(IRCell): @@ -31,6 +32,8 @@ def __init__(self, inputs: List[IRTensor], outputs: List[IRTensor], signature = 'cube.runtime.collectives.all_gather' if colltype == IRCollType.ReduceScatter: signature = 'cube.runtime.collectives.reduce_scatter' + if colltype == IRCollType.Broadcast: + signature = 'cube.runtime.collectives.broadcast' self.ranks = ranks From 6b548d0ed059f3298b6a5a580bf7591d56f6e71c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 16 Nov 2021 23:57:10 +0800 Subject: [PATCH 0318/1892] fix allgather and reducescatter bugs --- cube/execplan/planpass/p2pfusion.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index 3bccfc62..f7889061 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -111,7 +111,7 @@ def have_comm(tensor_ous, tensor_ins): @staticmethod def add_collectives(execplan: ExectuionPlan, coll_sus: List[ScheduleUnit]): for coll_su in coll_sus: - print(f'inserting Collective SU: {coll_su}') + print(f'inserting Collective SU: {coll_su.name}: {coll_su}') # find insert place: the first send devid = coll_su.device[0] ranks = coll_su.nodes(0).ranks @@ -295,13 +295,11 @@ def match_allgather(tous, tins): continue ranks = list(out_tensors.keys()) - inputs = [[out_tensors[rank][0]] for rank in ranks] + inputs = [out_tensors[rank][0] for rank in ranks] for input, rank in zip(inputs, ranks): - outputs = [ - input[0] for idx, input in enumerate(inputs) if idx != rank - ] - op = IRCollectives(input, outputs, ranks, IRCollType.AllGather) + outputs = [t for t in inputs if t != input] + op = IRCollectives([input], outputs, ranks, IRCollType.AllGather) su = ScheduleUnit([op], SUType.Coll, name='allgather') su.device = rank allgather_sus.append(su) @@ -402,7 +400,7 @@ def match_reducescatter(tous, tins): for in_tensor in reduce_in_tensors: rank = in_tensor.device[0] outputs = [in_tensor] - inputs = [inputs[rank] for inputs in device_inputs] + inputs = [inputs[ranks.index(rank)] for inputs in device_inputs] op = IRCollectives(inputs, outputs, ranks, IRCollType.ReduceScatter) su = ScheduleUnit([op], SUType.Coll, name='reducescatter') su.device = rank From a3e21eb6ba0f99188dcf25fdb48308a2234b8900 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 17 Nov 2021 00:21:23 +0800 Subject: [PATCH 0319/1892] fix removing bug --- cube/execplan/planpass/p2pfusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index f7889061..c6924169 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -133,6 +133,7 @@ def add_collectives(execplan: ExectuionPlan, coll_sus: List[ScheduleUnit]): raise RuntimeError("Cannot find a send P2P") # all the send, recv of the inputs will be removed in ranks for coll_su in coll_sus: + ranks = coll_su.nodes(0).ranks for input in coll_su.inputs(): for rank in ranks: for su in execplan.sequence(rank): From c614fde8ab14ffce9bd86404c4c7e15ffb2c2463 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 17 Nov 2021 00:41:26 +0800 Subject: [PATCH 0320/1892] fix bugs by create comm groups --- cube/codegen/codegen.py | 33 +++++++++++++++++++++++++++++- cube/runtime/collectives.py | 8 ++++++-- cube/runtime/module.py | 7 +++++++ examples/policy/nested_parallel.py | 2 +- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index a6e8fdbd..5c1c3ea6 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -2,7 +2,6 @@ Generate Pytorch code given the model DAG and the transformation config """ from typing import List, Any -from numpy import isin import torch import copy from cube.graph.operator.operator import IRFwOperation, IROptimOperation @@ -76,6 +75,25 @@ def __init__(self, execplan: ExectuionPlan): self.symbols = SymbolTable() # ref module to check shared variables self._ref_module = torch.nn.Module() + # groups + self._all_comm_groups = list() + self.get_all_groups() + + def get_all_groups(self): + """ + Get all communication groups. + + Creating communication group requires all the devices + enter the same call. + """ + for devid in self.execplan.devices(): + for su in self.execplan.sequence(devid): + if su.stype == SUType.Coll: + ranks = list(su.nodes(0).ranks) + ranks.sort() + ranks = tuple(ranks) + if ranks not in self._all_comm_groups: + self._all_comm_groups.append(ranks) def gen(self, device: int, outfile=None, attach=False) -> str: """ @@ -100,6 +118,9 @@ def gen(self, device: int, outfile=None, attach=False) -> str: self.symbols.create(name) su_args.append(fargs) + # init group + self.emit_comm_group_creation() + # parse graph body for su in device_sus: for node in su.nodes(): @@ -178,6 +199,16 @@ def emit_var_declare(self, var: Any): self.declare_region.append(code) return + def emit_comm_group_creation(self): + """ + Emit communication group creation code + """ + sign = 'self.init_group(ranks={ranks})' + for ranks in self._all_comm_groups: + ranks = list(ranks) + code = sign.format(ranks=ranks) + self.declare_region.append(code) + def emit_op_call(self, node): """ Emit op forward code diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index c5605136..b4f60a2d 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -13,7 +13,7 @@ def send(tensors, to_ranks: List[int]): tensors (List[torch.Tensor]): list of tensor to send tensor_devices (List[List[int]]): tensor sent devices """ - # print('sending...') + # print(f'{torch.distributed.get_rank()}: sending...') send_ops = list() ## synthetic ## @@ -31,7 +31,7 @@ def send(tensors, to_ranks: List[int]): def recv(shapes: List[List[int]], from_ranks: List[int]): - # print('recving...') + # print(f'{torch.distributed.get_rank()}: recving...') recv_ops = list() recv_tensors = list() @@ -101,6 +101,7 @@ def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): """ Allreduce """ + # print(f'{torch.distributed.get_rank()}: all_reduce...') assert len(tensors) == 1 tensor = tensors[0] tensor = tensor.detach() @@ -114,6 +115,7 @@ def all_gather(tensors: List[torch.Tensor], ranks: List[int]): """ Allgather """ + # print(f'{torch.distributed.get_rank()}: all_gather...') assert len(tensors) == 1 tensor = tensors[0] group = DeviceGroup().get_group(ranks) @@ -132,6 +134,7 @@ def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): """ ReduceScatter """ + # print(f'{torch.distributed.get_rank()}: reduce-scatter...') tensors = list(tensors) group = DeviceGroup().get_group(ranks) idx = ranks.index(DeviceGroup().rank) @@ -146,6 +149,7 @@ def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None): """ Broadcast. ranks[0] is the root """ + # print(f'{torch.distributed.get_rank()}: broadcast...') if len(tensors) == 1: tensor = tensors[0] else: diff --git a/cube/runtime/module.py b/cube/runtime/module.py index e13b5b21..d87dbaad 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,4 +1,6 @@ +from typing import List import torch +from cube.runtime.device import DeviceGroup from cube.runtime.reducer import Reducer @@ -24,3 +26,8 @@ def sync_params(self): def init_param(self): for param in self.parameters(): torch.nn.init.uniform_(param) + + def init_group(self, ranks: List[int]): + if not all([isinstance(rank, int) for rank in ranks]): + raise TypeError("Expected ranks to be List[int]") + DeviceGroup().get_group(ranks) diff --git a/examples/policy/nested_parallel.py b/examples/policy/nested_parallel.py index 56c369b9..9f158d26 100644 --- a/examples/policy/nested_parallel.py +++ b/examples/policy/nested_parallel.py @@ -16,7 +16,7 @@ def transform_policy(graph: IRGraph, resource): algo = node.algorithms('data') sub_nodes = graph.partition(node, algo, config=dict(chunk_num=dp)) for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx * tp + sub_node.tag = idx # partition operators first in column and then in data if isinstance(node, IRFwOperation): all_sub_nodes = list() From d8b1e4a95aeb79494b2f3c95246dbb6056a2ff77 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 17 Nov 2021 10:29:38 +0800 Subject: [PATCH 0321/1892] add make pair for comm --- cube/execplan/planpass/merge.py | 4 +-- cube/execplan/planpass/redundant.py | 24 ++++++++++++++++++ cube/graph/gpass.py | 5 ++-- cube/graph/graph.py | 37 ++++++++++++++++++++++++++-- cube/ir/cten.py | 13 +++++++--- cube/schedule/adapter/collectives.py | 3 ++- cube/schedule/adapter/comm.py | 21 ++-------------- cube/schedule/su.py | 9 ------- cube/schedule/sugraph.py | 20 +++------------ tests/schedule/test_su.py | 3 ++- tests/schedule/test_sugraph.py | 9 ++++--- 11 files changed, 88 insertions(+), 60 deletions(-) diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index ab535aa3..8a55d2fd 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -3,6 +3,7 @@ from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass from cube.graph.operator.operator import IRBpOperation +from cube.ir.cten import IRCell from cube.schedule.su import SUType, ScheduleUnit @@ -89,6 +90,5 @@ def _merge(pieces: List[ScheduleUnit], devid: int) -> ScheduleUnit: mbsu = ScheduleUnit([mbnode], SUType.Backward, name='bsu') mbsu.device = devid - mfsu.mirror = mbsu - mbsu.mirror = mfsu + IRCell.make_pair(mfsu, mbsu) return mfsu diff --git a/cube/execplan/planpass/redundant.py b/cube/execplan/planpass/redundant.py index 7dc2731f..8ec69ee5 100644 --- a/cube/execplan/planpass/redundant.py +++ b/cube/execplan/planpass/redundant.py @@ -12,6 +12,7 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: A redundant adapter is sending / recving tensors on the same deivce """ + # remove identity comm for devid in execplan.devices(): seq = execplan.sequence(devid) comms = [su for su in seq if su.stype == SUType.P2P] @@ -27,4 +28,27 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: continue # remove execplan.at(devid).remove(comm) + # remove redundant comm e.g., recving same tensor from other ranks + for devid in execplan.devices(): + all_outs = list() + seq = execplan.sequence(devid) + for su in seq: + # zero-output SU will not be removed + removable = len(su.outputs()) != 0 + for output in su.outputs(): + if output not in all_outs: + removable = False + all_outs.append(output) + if removable: + # only recv has output + execplan.at(devid).remove(su) + if su.stype == SUType.P2P: + # remove all the paired send + ranks = su.nodes(0).recv_ranks + if len(ranks) > 1: + raise NotImplementedError + rank = ranks[0] + if su.mirror not in execplan.at(rank): + raise RuntimeError("Recv Op not found!") + execplan.at(rank).remove(su.mirror) return execplan diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index e264a368..ae4bcf14 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -5,7 +5,7 @@ from cube.graph.tensor import IRSubTensor from cube.graph.operator import IRBpOperation -from cube.ir.cten import IRTensor +from cube.ir.cten import IRCell, IRTensor __all__ = ['forward'] @@ -103,8 +103,7 @@ def forward(graph, *args) -> IRGraph: bnode.device = node.device # mirror node for forward / backward - fnode.mirror = bnode - bnode.mirror = fnode + IRCell.make_pair(fnode, bnode) fnodes.append(fnode) bnodes.append(bnode) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 50e57ed5..102b31a6 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -254,6 +254,40 @@ def subgraph(self, sub_nodes: List[IRCell]): ) return graph + def replicate(self, op: IRCell, times=1): + """ + Replicate an operation with multiple times. + + This is temporary use to enable assign with multiple devices + """ + if not isinstance(op, IRCell): + raise TypeError("Expected an IRCell") + if not isinstance(times, int) or times < 1: + raise TypeError("Expected times to be int and >= 1") + + if op not in self.nodes(): + raise RuntimeError(f"Op {op} not exsits") + + ops = [op] + for _ in range(times - 1): + cpy_op = copy.copy(op) + for idx, input in enumerate(op.inputs()): + cpy_op.set_input(idx, input) + for idx, output in enumerate(op.outputs()): + cpy_op.set_output(idx, output) + if op.mirror is not None: + cpy_mirror_op = copy.copy(op.mirror) + for idx, input in enumerate(op.mirror.inputs()): + cpy_mirror_op.set_input(idx, input) + for idx, output in enumerate(op.mirror.outputs()): + cpy_mirror_op.set_output(idx, output) + IRCell.make_pair(cpy_op, cpy_mirror_op) + ops.append(cpy_op) + idx = self.nodes().index(op) + self._nodes = self._nodes[:idx] + ops + self._nodes[idx+1:] + self.reset_dependency() + return ops + ## Primitives for policy expression ## def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: @@ -327,9 +361,8 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional grad = val.get_grad(fnode) val.grad = grad bnode.set_grad(idx, grad) - fnode.mirror = bnode + IRCell.make_pair(fnode, bnode) fnode.device = op.device - bnode.mirror = fnode bnode.device = op.mirror.device idx = self._nodes.index(op) self._nodes = self._nodes[:idx] + fnodes + self._nodes[idx+1:] diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 065df596..787cfb15 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -92,9 +92,16 @@ def mirror(self): @mirror.setter def mirror(self, other): - if not isinstance(other, IRCell): - raise TypeError("Expected mirror to be IRCell") - self._mirror = other + raise RuntimeError("Use IRCell.make_pair instead") + + @staticmethod + def make_pair(cell1, cell2): + if not isinstance(cell1, IRCell): + raise TypeError("Expected cell1 to be IRCell") + if not isinstance(cell2, IRCell): + raise TypeError("Expected cell2 to be IRCell") + cell1._mirror = cell2 + cell2._mirror = cell1 def on_device(self, device_id: int): """ diff --git a/cube/schedule/adapter/collectives.py b/cube/schedule/adapter/collectives.py index a004361c..026772b3 100644 --- a/cube/schedule/adapter/collectives.py +++ b/cube/schedule/adapter/collectives.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Type from enum import Enum from cube.ir.cten import IRCell, IRTensor @@ -47,3 +47,4 @@ def __init__(self, inputs: List[IRTensor], outputs: List[IRTensor], self.set_input(idx, input) for idx, output in enumerate(outputs): self.set_output(idx, output) + diff --git a/cube/schedule/adapter/comm.py b/cube/schedule/adapter/comm.py index c9a4948d..671b65fc 100644 --- a/cube/schedule/adapter/comm.py +++ b/cube/schedule/adapter/comm.py @@ -59,23 +59,6 @@ def __init__(self, self.recv_tensors.append(self.outputs(idx)) self.recv_ranks.append(from_device) - self.msg_id = self._id - - def pair(self, other): - """ - Pair two comm node to have same message id. - - The `other` message id is set same with caller - """ - if not isinstance(other, IRCommunication): - raise RuntimeError("Expected IRCommunication to pair") - other.msg_id = self.msg_id - - def merge(self, other): - if not isinstance(other, IRCommunication): - raise RuntimeError("Expected IRCommunication to merge") - raise NotImplementedError - def __repr__(self): inputs = list() for tensor in self.inputs(): @@ -91,5 +74,5 @@ def __repr__(self): else: outputs.append(tensor) - dscp = f'SendRecv(msg_id={self.msg_id}, device={self.device}, send={inputs}, recv={outputs})' - return dscp \ No newline at end of file + dscp = f'SendRecv(id={self._id}, send={inputs}, recv={outputs})' + return dscp diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 26acdcbd..888512c8 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -101,15 +101,6 @@ def __copy__(self): Copy the SU. Note the mirror su is also copied """ raise NotImplementedError("Copy SU is not supported yet") - su = ScheduleUnit(self._nodes, self.stype, self.name) - if self.mirror is not None: - mirror_su = self.mirror - mirror_su = ScheduleUnit( - mirror_su._nodes, mirror_su.stype, mirror_su.name - ) - su.mirror = mirror_su - mirror_su.mirror = su - return su def in_adapters(self, index: Optional[int] = None) -> List: """ diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 1d5ec63f..4bb0b983 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -1,12 +1,10 @@ from typing import List, Optional, Union import copy - from cube.ir.cten import IRCell, IRTensor from cube.graph.operator import IRBpOperation from cube.graph.operator import IRDataOperation from cube.graph.operator import IRFwOperation -from cube.graph.graph import IRGraph from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.adapter.comm import IRCommunication from cube.schedule.adapter.transform import IRTensorTransform @@ -118,15 +116,6 @@ def fsus(self) -> List[ScheduleUnit]: """ return [su for su in self.sequence if su.stype == SUType.Forward] - def get_graph(self, sus: List[ScheduleUnit], name: str) -> IRGraph: - """ - Generate IRGraph - """ - nodes = list() - for su in sus: - nodes += su.nodes() - return IRGraph(nodes, None, None, name) - def happen_before(self, su1, su2): """ Check if the su1 -> (happened before) su2 @@ -220,8 +209,7 @@ def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: bnode.set_output(idx, fin.grad) bsu = ScheduleUnit([bnode], stype=SUType.Backward, name='bsu') bsu.device = su2.mirror.device - fsu.mirror = bsu - bsu.mirror = fsu + IRCell.make_pair(fsu, bsu) def _set_adapters(su1: ScheduleUnit, su2: ScheduleUnit, msu: ScheduleUnit): # set adapter @@ -448,13 +436,14 @@ def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: recv_tensors=[sub_tensor], recv_ranks = [-1] ) - send_op.pair(recv_op) + IRCell.make_pair(send_op, recv_op) send_su = ScheduleUnit([send_op], SUType.P2P, name='send') recv_su = ScheduleUnit([recv_op], SUType.P2P, name='recv') su._add_in_adapter(in_idx, send_su, recv_su) send_su.device = su.device pre_su._add_out_adapter(out_idx, send_su, recv_su) recv_su.device = su.device + IRCell.make_pair(send_su, recv_su) # add adapter for merge if len(tensor_segments) != 0: try: @@ -578,8 +567,7 @@ def gen_sugraph(nodes) -> SUGraph: stype = SUType.Backward index = fnodes.index(node.mirror) fsu = fsus[index] - su.mirror = fsu - fsu.mirror = su + IRCell.make_pair(su, fsu) else: raise NotImplementedError("Not implemented node type") su.stype = stype diff --git a/tests/schedule/test_su.py b/tests/schedule/test_su.py index 50794915..d255c139 100644 --- a/tests/schedule/test_su.py +++ b/tests/schedule/test_su.py @@ -3,6 +3,7 @@ from cube.graph.tensor import IRFullTensor from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph +from cube.ir.cten import IRCell from cube.schedule.su import SUType, ScheduleUnit @@ -86,7 +87,7 @@ def test_su_copy(): su1 = ScheduleUnit([linear1, linear2], stype=SUType.Forward) su2 = ScheduleUnit([linear1, linear2, linear3], stype=SUType.Forward) - su1.mirror = su2 + IRCell.make_pair(su1, su2) csu = copy.copy(su1) assert csu.inputs() == su1.inputs() diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py index 0ddfd868..7b72a3e2 100644 --- a/tests/schedule/test_sugraph.py +++ b/tests/schedule/test_sugraph.py @@ -1,6 +1,7 @@ from cube.graph.tensor import IRFullTensor from cube.graph.operator.function import Linear from cube.graph.graph import IRGraph +from cube.ir.cten import IRCell from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.sugraph import SUGraph @@ -149,7 +150,7 @@ def test_sugraph_assign1(): recv_tensors=[su1.outputs(0)], recv_ranks = [-1] ) - send_op.pair(recv_op) + IRCell.make_pair(send_op, recv_op) send_su12 = ScheduleUnit([send_op], SUType.P2P, name='send') recv_su12 = ScheduleUnit([recv_op], SUType.P2P, name='recv') su1._add_out_adapter(0, send_su12, recv_su12) @@ -164,7 +165,7 @@ def test_sugraph_assign1(): recv_tensors=[su1.outputs(0)], recv_ranks = [-1] ) - send_op.pair(recv_op) + IRCell.make_pair(send_op, recv_op) send_su23 = ScheduleUnit([send_op], SUType.P2P, name='send') recv_su23 = ScheduleUnit([recv_op], SUType.P2P, name='recv') su2._add_out_adapter(0, send_su23, recv_su23) @@ -214,7 +215,7 @@ def test_sugraph_assign2(): recv_tensors=[su1.outputs(0)], recv_ranks = [-1] ) - send_op.pair(recv_op) + IRCell.make_pair(send_op, recv_op) send_su12 = ScheduleUnit([send_op], SUType.P2P, name='send') recv_su12 = ScheduleUnit([recv_op], SUType.P2P, name='recv') su1._add_out_adapter(0, send_su12, recv_su12) @@ -229,7 +230,7 @@ def test_sugraph_assign2(): recv_tensors=[su1.outputs(0)], recv_ranks = [-1] ) - send_op.pair(recv_op) + IRCell.make_pair(send_op, recv_op) send_su23 = ScheduleUnit([send_op], SUType.P2P, name='send') recv_su23 = ScheduleUnit([recv_op], SUType.P2P, name='recv') su2._add_out_adapter(0, send_su23, recv_su23) From e30402e09d340ada63dd92ffba7693839cbc600c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 17 Nov 2021 17:20:49 +0800 Subject: [PATCH 0322/1892] replicate op call --- cube/execplan/planpass/merge.py | 43 +++++++++++++++++++++++++++--- cube/graph/graph.py | 8 ++++++ cube/ir/cten.py | 5 ++++ cube/schedule/sugraph.py | 24 +++++++++++++---- examples/policy/col_parallel.py | 18 ++++++++----- examples/policy/hybrid_parallel.py | 20 +++++++------- examples/policy/row_parallel.py | 15 ++++++----- 7 files changed, 104 insertions(+), 29 deletions(-) diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index 8a55d2fd..1ffe5147 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -5,6 +5,7 @@ from cube.graph.operator.operator import IRBpOperation from cube.ir.cten import IRCell from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.sugraph import SUGraph class MergeComputeSU(PlanPass): @@ -15,11 +16,12 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: Merge consecutive backward SUs. The forward SUs will also be merged if possible """ + sugraph = execplan.sugraph for devid in execplan.devices(): dev_seq = execplan.sequence(devid) + [None] pieces: List[ScheduleUnit] = list() adapters: List[ScheduleUnit] = list() - for seqidx, su in enumerate(dev_seq): + for su in dev_seq: if su and su.stype in [SUType.P2P, SUType.Transform]: if len(pieces) > 0: adapters.append(su) @@ -27,13 +29,18 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: if su and su.stype in [SUType.Backward]: allow_merge = len(pieces) == 0 for psu in pieces[::-1]: - if execplan.sugraph.happen_before(psu, su): + if sugraph.happen_before(psu, su): allow_merge = True break for adapter in adapters: - if execplan.sugraph.happen_before(adapter, su): + if sugraph.happen_before(adapter, su): allow_merge = False break + # no merge adapters connected between forward SUs + if allow_merge: + fsus = [su.mirror] + [bsu.mirror for bsu in pieces] #[::-1] + if MergeComputeSU._connected_by_adapter(execplan, fsus, devid): + allow_merge = False if allow_merge: pieces.append(su) continue @@ -92,3 +99,33 @@ def _merge(pieces: List[ScheduleUnit], devid: int) -> ScheduleUnit: IRCell.make_pair(mfsu, mbsu) return mfsu + + @staticmethod + def _connected_by_adapter(execplan: ExectuionPlan, fpieces, devid: int): + """ + Check if there is an adapter connecting forward SUs + """ + sugraph = execplan.sugraph + indices = [execplan.sequence(devid).index(fsu) for fsu in fpieces] + start = min(indices) + end = max(indices) + # check fsu1 -> asu -> fsu2 + for asu in execplan.sequence(devid)[start:end]: + if asu.stype in [SUType.P2P, SUType.Transform, SUType.Coll]: + happen_before = False + happen_after = False + # fsu1 -> asu + for fsu1 in fpieces: + if sugraph.happen_before(fsu1, asu): + happen_after = True + break + if not happen_after: + continue + # asu -> fsu2 + for fsu2 in fpieces: + if sugraph.happen_before(asu, fsu2): + happen_before = True + break + if happen_before and happen_after: + return True + return False diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 102b31a6..dab05702 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -269,6 +269,7 @@ def replicate(self, op: IRCell, times=1): raise RuntimeError(f"Op {op} not exsits") ops = [op] + mirror_ops = [op.mirror] for _ in range(times - 1): cpy_op = copy.copy(op) for idx, input in enumerate(op.inputs()): @@ -281,10 +282,17 @@ def replicate(self, op: IRCell, times=1): cpy_mirror_op.set_input(idx, input) for idx, output in enumerate(op.mirror.outputs()): cpy_mirror_op.set_output(idx, output) + mirror_ops.append(cpy_mirror_op) IRCell.make_pair(cpy_op, cpy_mirror_op) ops.append(cpy_op) idx = self.nodes().index(op) + # forward self._nodes = self._nodes[:idx] + ops + self._nodes[idx+1:] + # backward + if op.mirror: + mirror_ops = mirror_ops[::-1] + midx = self.nodes().index(op.mirror) + self._nodes = self._nodes[:midx] + mirror_ops + self._nodes[midx+1:] self.reset_dependency() return ops diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 787cfb15..df798df6 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -68,6 +68,11 @@ def __init__(self, self._mirror = None self._tag = None + def __eq__(self, other): + if isinstance(other, IRCell): + return self._id == other._id + return False + @property def device(self): return list(self._device) diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 4bb0b983..695d0450 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -116,13 +116,20 @@ def fsus(self) -> List[ScheduleUnit]: """ return [su for su in self.sequence if su.stype == SUType.Forward] - def happen_before(self, su1, su2): + def happen_before(self, su1, su2, visited=None): """ Check if the su1 -> (happened before) su2 Returns: Boolean """ + # FIXME: there is still a strange bug may cause infinite loop + if visited is None: + visited = list() + if su1 in visited: + return False + visited.append(su1) + if not isinstance(su1, ScheduleUnit) or \ not isinstance(su2, ScheduleUnit): raise TypeError("Expected su to be an ScheduleUnit") @@ -130,7 +137,10 @@ def happen_before(self, su1, su2): return True else: for succ_su in su1.successors(): - if self.happen_before(succ_su, su2): + # don't need to consider P2P comm dependency + if succ_su.stype == SUType.P2P: + continue + if self.happen_before(succ_su, su2, visited): return True return False @@ -540,8 +550,8 @@ def __repr__(self): for out_idx in range(len(node.outputs())): node_list = [snode._id for snode in node.successors(out_idx)] succ_node_ids[out_idx] = node_list - dscp += f"{node._id}: {node}\n" - # dscp += f"\n{node._id}: {node} -> su id {succ_node_ids}\n" + # dscp += f"{node._id}: {node}\n" + dscp += f"\n{node._id}: {node} -> su id {succ_node_ids}\n" return dscp @@ -565,9 +575,13 @@ def gen_sugraph(nodes) -> SUGraph: fsus.append(su) elif isinstance(node, IRBpOperation): stype = SUType.Backward - index = fnodes.index(node.mirror) + # get the last one same node + index = len(fnodes) - fnodes[::-1].index(node.mirror) - 1 fsu = fsus[index] IRCell.make_pair(su, fsu) + # remove fsu + fnodes.pop(index) + fsus.remove(fsu) else: raise NotImplementedError("Not implemented node type") su.stype = stype diff --git a/examples/policy/col_parallel.py b/examples/policy/col_parallel.py index 76e2e194..ca118e4b 100644 --- a/examples/policy/col_parallel.py +++ b/examples/policy/col_parallel.py @@ -1,7 +1,7 @@ from cube.graph import IRGraph from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation +from cube.graph.operator.operator import IRDataOperation, IRFwOperation def transform_policy(graph: IRGraph, resource): @@ -9,11 +9,12 @@ def transform_policy(graph: IRGraph, resource): The transformation policy transposes linear using column parallel """ for node in graph.nodes(): - if isinstance(node, IRFwOperation): + if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): algo = node.algorithms('column') - if algo is None: - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + if algo: + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx print(graph) @@ -24,12 +25,17 @@ def schedule_policy(sugraph: SUGraph, resource): """ The schedule policy assign devices """ + print(sugraph) for su in sugraph.sus(): if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) + devid = su.tag[0] + sugraph.assign(su, devid) # sugraph.assign(su, list(range(resource.ngpus))) for su in sugraph.fsus(): devid = su.tag[0] sugraph.assign(su, devid) + if su.mirror is None: + print(f'error su: {su}') + assert False sugraph.assign(su.mirror, devid) return sugraph diff --git a/examples/policy/hybrid_parallel.py b/examples/policy/hybrid_parallel.py index 65b8ed76..dd11fbc6 100644 --- a/examples/policy/hybrid_parallel.py +++ b/examples/policy/hybrid_parallel.py @@ -1,7 +1,7 @@ from cube.graph import IRGraph from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation +from cube.graph.operator.operator import IRDataOperation, IRFwOperation def transform_policy(graph: IRGraph, resource): @@ -10,19 +10,20 @@ def transform_policy(graph: IRGraph, resource): """ linear_idx = 0 for node in graph.nodes(): - if isinstance(node, IRFwOperation): - algo = None + if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): + algo = algo = None if linear_idx % 2 == 0: print(f'> column partition: {node}') algo = node.algorithms('column') else: print(f'> row partition: {node}') algo = node.algorithms('row') - if algo is None: - print(f'> data partition: {node}') - algo = node.algorithms('data') - linear_idx += 1 - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + if algo: + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + linear_idx += 1 + else: + print(f'> replicate: {node}') + sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx print(graph) @@ -35,7 +36,8 @@ def schedule_policy(sugraph: SUGraph, resource): """ for su in sugraph.sus(): if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) + devid = su.tag[0] + sugraph.assign(su, devid) # sugraph.assign(su, list(range(resource.ngpus))) for su in sugraph.fsus(): devid = su.tag[0] diff --git a/examples/policy/row_parallel.py b/examples/policy/row_parallel.py index 0c1198fa..0a07a473 100644 --- a/examples/policy/row_parallel.py +++ b/examples/policy/row_parallel.py @@ -1,7 +1,7 @@ from cube.graph import IRGraph from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation +from cube.graph.operator.operator import IRFwOperation, IRDataOperation def transform_policy(graph: IRGraph, resource): @@ -9,11 +9,12 @@ def transform_policy(graph: IRGraph, resource): The transformation policy transposes linear using column parallel """ for node in graph.nodes(): - if isinstance(node, IRFwOperation): + if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): algo = node.algorithms('row') - if algo is None: - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + if algo: + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx return graph @@ -23,9 +24,11 @@ def schedule_policy(sugraph: SUGraph, resource): """ The schedule policy assign devices """ + # print(sugraph) for su in sugraph.sus(): if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) + devid = su.tag[0] + sugraph.assign(su, devid) # sugraph.assign(su, list(range(resource.ngpus))) for su in sugraph.fsus(): devid = su.tag[0] From fe6512aa8ab5b5833a0eeb5c96631f14c215c8ce Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 17 Nov 2021 18:33:59 +0800 Subject: [PATCH 0323/1892] example for mlp microbench --- benchmark/megatron/linears.py | 2 ++ cube/execplan/planpass/merge.py | 1 - examples/linears.py | 10 ++++++---- examples/policy/hybrid_parallel.py | 2 +- examples/policy/megatron_parallel.py | 5 ++--- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/benchmark/megatron/linears.py b/benchmark/megatron/linears.py index b05cbb15..4113009e 100644 --- a/benchmark/megatron/linears.py +++ b/benchmark/megatron/linears.py @@ -89,6 +89,8 @@ def train(args): def train_iter(model, dataloader): data = next(dataloader) + # torch.distributed.broadcast(data, 0) + # torch.cuda.synchronize() loss = model(data) loss.backward() diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index 1ffe5147..8805e651 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -5,7 +5,6 @@ from cube.graph.operator.operator import IRBpOperation from cube.ir.cten import IRCell from cube.schedule.su import SUType, ScheduleUnit -from cube.schedule.sugraph import SUGraph class MergeComputeSU(PlanPass): diff --git a/examples/linears.py b/examples/linears.py index b34afd4b..054fe52e 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -16,8 +16,9 @@ import cube from cube.profiler import CudaTimer -from examples.policy.hybrid_parallel import transform_policy -from examples.policy.hybrid_parallel import schedule_policy +from cube.profiler.timer import print_each_rank +from examples.policy.megatron_parallel import transform_policy +from examples.policy.megatron_parallel import schedule_policy # =================== Semantic Model Description ==================== @@ -58,6 +59,7 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + torch.distributed.barrier() iter_num = 128 for step in range(iter_num): if step >= 10: @@ -68,9 +70,9 @@ def train_iter(model, dataloader): if step >= 10: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: - print(f'iter [{step + 1}/{iter_num}]') + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - print('e2e time (ms) per iteration: {} ms'.format( + print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-10, field_name='e2e'))) diff --git a/examples/policy/hybrid_parallel.py b/examples/policy/hybrid_parallel.py index dd11fbc6..49decb71 100644 --- a/examples/policy/hybrid_parallel.py +++ b/examples/policy/hybrid_parallel.py @@ -26,7 +26,7 @@ def transform_policy(graph: IRGraph, resource): sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx - print(graph) + # print(graph) return graph diff --git a/examples/policy/megatron_parallel.py b/examples/policy/megatron_parallel.py index 54c2c121..e5e40df1 100644 --- a/examples/policy/megatron_parallel.py +++ b/examples/policy/megatron_parallel.py @@ -17,7 +17,7 @@ def transform_policy(graph: IRGraph, resource): algo = node.algorithms('data') sub_nodes = graph.partition(node, algo, config=dict(chunk_num=dp)) for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx * tp + sub_node.tag = idx # partition operators first in column and then in data if isinstance(node, IRFwOperation): all_sub_nodes = list() @@ -36,8 +36,7 @@ def transform_policy(graph: IRGraph, resource): all_sub_nodes += ssub_nodes linear_idx += 1 else: - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + sub_nodes = graph.replicate(node, times=resource.ngpus) all_sub_nodes += sub_nodes # add tags (vdev) for node for idx, ssub_node in enumerate(all_sub_nodes): From 6fc9a65b3835adb04573f97f2ce89b1d162982bd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Nov 2021 10:44:28 +0800 Subject: [PATCH 0324/1892] add warm up --- benchmark/megatron/linears.py | 12 +++++++----- cube/profiler/timer.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/benchmark/megatron/linears.py b/benchmark/megatron/linears.py index 4113009e..e6070c71 100644 --- a/benchmark/megatron/linears.py +++ b/benchmark/megatron/linears.py @@ -19,6 +19,7 @@ import cube from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank class ColumnMLP(nn.Module): @@ -94,21 +95,22 @@ def train_iter(model, dataloader): loss = model(data) loss.backward() + CudaTimer().warmup(seconds=1.0) torch.distributed.barrier() iter_num = 128 for step in range(iter_num): - if step >= 10: + if step >= 40: CudaTimer().start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() - if step >= 10: + if step >= 40: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: - print(f'iter [{step + 1}/{iter_num}]') + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - print('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-10, field_name='e2e'))) + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) if __name__ == '__main__': diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index 0cbb31b7..bc1481f5 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -94,3 +94,20 @@ def print_all(self, times): msg = ' | '.join(msg) print_each_rank(msg) + + def warmup(self, seconds=1.0): + """ + Warm up GPU for `span` seconds. + """ + print('> warming up for 1 second') + data1 = torch.randn((4096, 4096), device=torch.cuda.current_device()) + data2 = torch.randn((4096, 4096), device=torch.cuda.current_device()) + # warm up 1s + if torch.distributed.is_initialized(): + torch.distributed.barrier() + start = time.time() + while time.time() - start < seconds: + out = torch.matmul(data1, data2) + # if torch.distributed.is_initialized(): + # torch.distributed.all_reduce(out) + torch.cuda.synchronize() From 88e51500a2907072a6ade0765c69c5f9815b6232 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Nov 2021 10:46:21 +0800 Subject: [PATCH 0325/1892] add warm up --- examples/linears.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/linears.py b/examples/linears.py index 054fe52e..16d8fb83 100644 --- a/examples/linears.py +++ b/examples/linears.py @@ -17,8 +17,8 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.policy.megatron_parallel import transform_policy -from examples.policy.megatron_parallel import schedule_policy +from examples.policy.hybrid_parallel import transform_policy +from examples.policy.hybrid_parallel import schedule_policy # =================== Semantic Model Description ==================== @@ -59,21 +59,22 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + CudaTimer().warmup() torch.distributed.barrier() iter_num = 128 for step in range(iter_num): - if step >= 10: + if step >= 40: CudaTimer().start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() - if step >= 10: + if step >= 40: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-10, field_name='e2e'))) + CudaTimer().duration(iter_num-40, field_name='e2e'))) if __name__ == '__main__': From b03b85f5ef54280444c08d5549cb5688e6189b28 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Nov 2021 10:50:02 +0800 Subject: [PATCH 0326/1892] add trace inspector --- examples/inspector.py | 93 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 examples/inspector.py diff --git a/examples/inspector.py b/examples/inspector.py new file mode 100644 index 00000000..99ce7438 --- /dev/null +++ b/examples/inspector.py @@ -0,0 +1,93 @@ +""" +Directly loading generated file for training + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/inspector.py +""" +import torch +import argparse +import time + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +kDataShapes = ([128, 1024],) + + +def load_module(filename: str): + import importlib.util + rank = torch.distributed.get_rank() + print(f'> [{rank}] loading generated spatial moduel from {filename}') + spec = importlib.util.spec_from_file_location("GenModel", filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + loaded_module = module.GenModel().cuda() + # sync parameters before start training + loaded_module.sync_params() + return loaded_module + + +def load_train_fn(filename: str): + import importlib.util + rank = torch.distributed.get_rank() + print(f'> [{rank}] loading generated schedule from {filename} ...') + spec = importlib.util.spec_from_file_location( + "_train_step", filename + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module._train_step + + +def train(args): + global kDataShapes + + dataloader = cube.runtime.syndata.SynDataLoader(1280, *kDataShapes) + + genfile = args.genfile.format(rank=torch.distributed.get_rank()) + model = load_module(genfile) + train_fn = load_train_fn(genfile) + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer().warmup() + torch.distributed.barrier() + with torch.profiler.profile() as prof: + iter_num = args.iter_num + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_fn(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + time.sleep(0.05) + + prof.export_chrome_trace(f"trace{torch.distributed.get_rank()}.json") + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='inspect') + parser.add_argument('--genfile', type=str, + default='gencode{rank}.py') + parser.add_argument('--iter-num', type=int, + default=128) + args = parser.parse_args() + + cube.init() + train(args) From 9cd764a3564c6951ace3652022a54cf397ac5d5b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Nov 2021 14:17:30 +0800 Subject: [PATCH 0327/1892] add torch scriptable transformer --- examples/attention.py | 156 ------------------------------- examples/transformer.py | 202 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 202 insertions(+), 156 deletions(-) delete mode 100644 examples/attention.py create mode 100644 examples/transformer.py diff --git a/examples/attention.py b/examples/attention.py deleted file mode 100644 index 9eb8e57e..00000000 --- a/examples/attention.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Optional -import torch -from torch import nn -import torch.nn.functional as F -from einops import rearrange - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -class MultiHeadSelfAttention(nn.Module): - - def __init__(self, seq_len, embed_dim, heads, dropout): - super().__init__() - - self.seq_len = seq_len - self.embed_dim = embed_dim - self.num_heads = heads - self.dim_head = embed_dim // heads - self.scale = self.dim_head ** -0.5 - - self.weight_qkv = torch.nn.Parameter(torch.empty( - 3 * embed_dim, embed_dim - )) - self.weight_out = torch.nn.Parameter(torch.empty( - embed_dim, embed_dim - )) - self.dropout = nn.Dropout(dropout) - - self._reset_parameters() - - def _reset_parameters(self): - torch.nn.init.xavier_uniform_(self.weight_qkv) - torch.nn.init.xavier_uniform_(self.weight_out) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - x: [L, N, E]: seq_len, batch_size, embedding dimension - output: [L, N, E] - """ - bs = x.shape[1] - - qkv = F.linear(x, self.weight_qkv, None).chunk(3, dim=-1) - q, k, v = qkv - q = q.contiguous().view(self.seq_len, (bs * self.num_heads), self.dim_head) - q = q.transpose(0, 1) - # => q: (batch size, seq_len, embed_dim) - k = k.contiguous().view(self.seq_len, (bs * self.num_heads), self.dim_head) - k = k.transpose(0, 1) - v = v.contiguous().view(self.seq_len, (bs * self.num_heads), self.dim_head) - v = v.transpose(0, 1) - - q = q * self.scale - attn = torch.bmm(q, k.transpose(-2, -1)) - attn = F.softmax(attn, dim=-1) - attn = self.dropout(attn) - output = torch.bmm(attn, v) - output = output.transpose(0, 1).contiguous() - output = output.view(self.seq_len, bs, self.embed_dim) - output = F.linear(output, self.weight_out) - return output - - def _ref_forward(self, x, mask: Optional[torch.Tensor] = None): - """ - X: [L, N, E]: seq_len, batch_size, embedding dimension - """ - output, _ = F.multi_head_attention_forward( - query=x, - key=x, - value=x, - embed_dim_to_check=self.embed_dim, - num_heads=self.num_heads, - in_proj_weight=self.weight_qkv, - in_proj_bias=None, - bias_k = None, - bias_v = None, - add_zero_attn=False, - dropout_p=self.dropout.p, - out_proj_weight=self.weight_out, - out_proj_bias=None, - training=self.training, - need_weights=False - ) - return output - - -class Attention(nn.Module): - def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False): - super().__init__() - inner_dim = dim_head * heads - self.heads = heads - self.seq_len = seq_len - self.scale = dim_head ** -0.5 - - self.stable = stable - self.causal = causal - - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x, mask = None): - b, n, _, h, device = *x.shape, self.heads, x.device - - qkv = self.to_qkv(x).chunk(3, dim = -1) - q = rearrange(qkv[0], 'b n (h d) -> b h n d', h = h) - k = rearrange(qkv[0], 'b n (h d) -> b h n d', h = h) - v = rearrange(qkv[0], 'b n (h d) -> b h n d', h = h) - - - q = q * self.scale - - dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) - mask_value = max_neg_value(dots) - - if mask: - mask = rearrange(mask, 'b j -> b () () j') - dots.masked_fill_(~mask, mask_value) - del mask - - if self.causal: - i, j = dots.shape[-2:] - mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() - dots.masked_fill_(mask, mask_value) - - attn = torch.softmax(dots, dim=-1) - - out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - out = self.to_out(out) - return out - - -if __name__ == '__main__': - - L = 64 - N = 16 - E = 128 - n_heads = 8 - - model = MultiHeadSelfAttention(L, E, n_heads, dropout=0.0) - - x = torch.rand((L, N, E)) - - out_ref = model._ref_forward(x) - out = model(x) - # print(out) - # print(out_ref) - assert torch.allclose(out, out_ref) is True - print('Test passed') - module = torch.jit.script(model) - print(module.graph) - print(module.code) \ No newline at end of file diff --git a/examples/transformer.py b/examples/transformer.py new file mode 100644 index 00000000..b5b39359 --- /dev/null +++ b/examples/transformer.py @@ -0,0 +1,202 @@ +from typing import Optional +import torch +from torch import nn +import torch.nn.functional as F + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +class MultiHeadSelfAttention(nn.Module): + + def __init__(self, seq_len, embed_dim, heads, dropout): + super().__init__() + + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_heads = heads + self.dim_head = embed_dim // heads + self.scale = self.dim_head ** -0.5 + + self.weight_qkv = torch.nn.Parameter(torch.empty( + 3 * embed_dim, embed_dim + )) + self.weight_out = torch.nn.Parameter(torch.empty( + embed_dim, embed_dim + )) + self.dropout = nn.Dropout(dropout) + + self._reset_parameters() + + def _reset_parameters(self): + torch.nn.init.xavier_uniform_(self.weight_qkv) + torch.nn.init.xavier_uniform_(self.weight_out) + + def forward(self, x, mask): + """ + x: [L, N, E]: seq_len, batch_size, embedding dimension + output: [L, N, E] + """ + bs = x.shape[1] + + # [L, N, E] -> [L, N, (num_heads * dim_head * 3)] + qkv = F.linear(x, self.weight_qkv, None).chunk(3, dim=-1) + # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 + q, k, v = qkv + + # [L, N, (num_heads * dim_head)] -> [L, (N * num_heads), dim_head] + q = q.contiguous() + q = q.view(self.seq_len, (bs * self.num_heads), self.dim_head) + # [L, N, (num_heads * dim_head)] -> [(N * num_heads), L, dim_head] + q = q.transpose(0, 1) + + k = k.contiguous() + k = k.view(self.seq_len, (bs * self.num_heads), self.dim_head) + k = k.transpose(0, 1) + + v = v.contiguous() + v = v.view(self.seq_len, (bs * self.num_heads), self.dim_head) + v = v.transpose(0, 1) + + # [(N * num_heads), L, dim_head] -> [(N * num_heads), L, dim_head] + q = q * self.scale + # [(N * num_heads), L, dim_head] * [(N * num_heads), dim_head, L] + # -> [(N * num_heads), L, L] + attn = torch.bmm(q, k.transpose(-2, -1)) + + # [(N * num_heads), L, L] -> [N, num_heads, L, L] + attn = attn.view(bs, self.num_heads, self.seq_len, self.seq_len) + # [N, num_heads, L, L] -> [N, num_heads, L, L] + attn = attn.masked_fill_(mask, -10000.0) + # [N, num_heads, L, L] -> [(N * num_heads), L, L] + attn = attn.view((bs * self.num_heads), self.seq_len, self.seq_len) + + # [(N * num_heads), L, L] -> [(N * num_heads), L, L] + attn = F.softmax(attn, dim=-1) + + # [(N * num_heads), L, L] -> [(N * num_heads), L, L] + attn = self.dropout(attn) + # [(N * num_heads), L, L] * [(N * num_heads), L, dim_head] + # -> [(N * num_heads), L, dim_head] + output = torch.bmm(attn, v) + # [(N * num_heads), L, dim_head] -> [L, (N * num_heads), dim_head] + output = output.transpose(0, 1).contiguous() + # [L, (N * num_heads), dim_head] -> [L, N, (num_heads * dim_head)] + output = output.view(self.seq_len, bs, self.embed_dim) + # [L, N, (num_heads * dim_head)] * [(num_heads * dim_head), (num_heads * dim_head)] + # => [L, N, (num_heads * dim_head)] + output = F.linear(output, self.weight_out) + return output + + def _ref_forward(self, x, mask: Optional[torch.Tensor] = None): + """ + X: [L, N, E]: seq_len, batch_size, embedding dimension + """ + output, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=self.embed_dim, + num_heads=self.num_heads, + in_proj_weight=self.weight_qkv, + in_proj_bias=None, + bias_k = None, + bias_v = None, + add_zero_attn=False, + dropout_p=self.dropout.p, + out_proj_weight=self.weight_out, + out_proj_bias=None, + training=self.training, + need_weights=False + ) + return output + + +class MLP(torch.nn.Module): + + def __init__(self, hidden_size: int): + super().__init__() + self.dense_h_to_4h = torch.nn.Linear( + hidden_size, 4 * hidden_size + ) + self.dense_4h_to_h = torch.nn.Linear( + 4 * hidden_size, hidden_size + ) + + def forward(self, hidden_states): + # [L, N, E] * [E, 4E] -> [L, N, 4E] + out = self.dense_h_to_4h(hidden_states) + # [L, N, 4E] -> [L, N, 4E] + out = F.gelu(out) + # [L, N, 4E] * [4E, E] -> [L, N, E] + out = self.dense_4h_to_h(out) + return out + + +class TransformerLayer(torch.nn.Module): + + def __init__(self, seq_len, hidden_size, head_num, dropout): + super().__init__() + # layer norm + self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + + self.attention = MultiHeadSelfAttention(seq_len, hidden_size, head_num, dropout) + self.attn_dropout = torch.nn.Dropout(dropout) + + self.mlp_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + self.mlp = MLP(hidden_size) + self.mlp_dropout = torch.nn.Dropout(dropout) + + def forward(self, hidden_states, attention_mask): + # Attention + in_attn_norm = self.input_layernorm(hidden_states) + attn_out = self.attention(in_attn_norm, attention_mask) + # residual + attn_out = self.attn_dropout(attn_out) + residual = attn_out + hidden_states + # MLP + in_mlp_norm = self.mlp_layernorm(residual) + mlp_out = self.mlp(in_mlp_norm) + # residual + mlp_out = self.mlp_dropout(mlp_out) + mlp_out = mlp_out + residual + return mlp_out + + +def get_attn_mask(batch_size: int, seq_len: int): + ones = torch.ones( + (batch_size, seq_len, seq_len), + device=torch.cuda.current_device() + ) + mask = torch.tril(ones) + mask = mask.view(batch_size, 1, seq_len, seq_len) + mask = (mask < 0.5) + return mask + + +def reset_parameter(model): + for param in model.parameters(): + torch.nn.init.uniform_(param) + +if __name__ == '__main__': + + L = 64 + N = 16 + E = 1024 + n_heads = 8 + + model = TransformerLayer(L, E, n_heads, 0.5).cuda() + reset_parameter(model) + + x = torch.rand((L, N, E)).cuda() + mask = get_attn_mask(N, L).cuda() + + out = model(x, mask) + print(out) + # print(out_ref) + # assert torch.allclose(out, out_ref) is True + print('Test passed') + module = torch.jit.script(model) + print(module.graph) + print(module.code) \ No newline at end of file From f517ab2d35ceb574f95750903248d8bc1353946b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Nov 2021 14:21:39 +0800 Subject: [PATCH 0328/1892] make code clean --- examples/transformer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/transformer.py b/examples/transformer.py index b5b39359..adacd78c 100644 --- a/examples/transformer.py +++ b/examples/transformer.py @@ -27,12 +27,6 @@ def __init__(self, seq_len, embed_dim, heads, dropout): )) self.dropout = nn.Dropout(dropout) - self._reset_parameters() - - def _reset_parameters(self): - torch.nn.init.xavier_uniform_(self.weight_qkv) - torch.nn.init.xavier_uniform_(self.weight_out) - def forward(self, x, mask): """ x: [L, N, E]: seq_len, batch_size, embedding dimension @@ -41,8 +35,9 @@ def forward(self, x, mask): bs = x.shape[1] # [L, N, E] -> [L, N, (num_heads * dim_head * 3)] - qkv = F.linear(x, self.weight_qkv, None).chunk(3, dim=-1) + qkv = F.linear(x, self.weight_qkv, None) # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 + qkv = qkv.chunk(3, dim=-1) q, k, v = qkv # [L, N, (num_heads * dim_head)] -> [L, (N * num_heads), dim_head] From 8b2d648af31a0d31d9d82a8501d31fca25becb04 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Nov 2021 14:52:51 +0800 Subject: [PATCH 0329/1892] add attention test --- examples/transformer.py | 47 +++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/examples/transformer.py b/examples/transformer.py index adacd78c..979dee4c 100644 --- a/examples/transformer.py +++ b/examples/transformer.py @@ -1,13 +1,8 @@ -from typing import Optional import torch from torch import nn import torch.nn.functional as F -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - class MultiHeadSelfAttention(nn.Module): def __init__(self, seq_len, embed_dim, heads, dropout): @@ -63,7 +58,8 @@ def forward(self, x, mask): # [(N * num_heads), L, L] -> [N, num_heads, L, L] attn = attn.view(bs, self.num_heads, self.seq_len, self.seq_len) # [N, num_heads, L, L] -> [N, num_heads, L, L] - attn = attn.masked_fill_(mask, -10000.0) + # attn += mask # pytorch official implementation + attn = attn.masked_fill_(mask, -100000.0) # [N, num_heads, L, L] -> [(N * num_heads), L, L] attn = attn.view((bs * self.num_heads), self.seq_len, self.seq_len) @@ -84,10 +80,20 @@ def forward(self, x, mask): output = F.linear(output, self.weight_out) return output - def _ref_forward(self, x, mask: Optional[torch.Tensor] = None): + def _ref_forward(self, x, mask=True): """ X: [L, N, E]: seq_len, batch_size, embedding dimension + mask: whether to use mask """ + if mask is not None: + ones = torch.ones( + (self.seq_len, self.seq_len), + device=torch.cuda.current_device() + ) + mask = torch.tril(ones) + mask = (mask < 0.5) + else: + mask = None output, _ = F.multi_head_attention_forward( query=x, key=x, @@ -102,8 +108,9 @@ def _ref_forward(self, x, mask: Optional[torch.Tensor] = None): dropout_p=self.dropout.p, out_proj_weight=self.weight_out, out_proj_bias=None, + attn_mask=mask, training=self.training, - need_weights=False + need_weights=False, ) return output @@ -174,6 +181,26 @@ def reset_parameter(model): for param in model.parameters(): torch.nn.init.uniform_(param) + +def test_attention(): + L = 64 + N = 16 + E = 128 + n_heads = 8 + + model = MultiHeadSelfAttention(L, E, n_heads, dropout=0.0).cuda() + reset_parameter(model) + + x = torch.rand((L, N, E)).cuda() + mask = get_attn_mask(N, L).cuda() + + out = model(x, mask) + out_ref = model._ref_forward(x, mask) + + assert torch.allclose(out, out_ref) is True + print('test passed') + + if __name__ == '__main__': L = 64 @@ -181,6 +208,8 @@ def reset_parameter(model): E = 1024 n_heads = 8 + # test_attention() + model = TransformerLayer(L, E, n_heads, 0.5).cuda() reset_parameter(model) @@ -194,4 +223,4 @@ def reset_parameter(model): print('Test passed') module = torch.jit.script(model) print(module.graph) - print(module.code) \ No newline at end of file + print(module.code) From a8479372a745b6578959aeb2f7c1965dcf4134ce Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Nov 2021 16:32:58 +0800 Subject: [PATCH 0330/1892] parser init prim tuple unpack --- cube/graph/parser/parser.py | 21 ++++++++++++++++++++- examples/transformer.py | 13 ++++++------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index b9370ecd..827bc146 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -16,6 +16,8 @@ class ScriptNodeKind(enum.Enum): PrimConstant = 4 AtenOp = 5 # -> the parser may end here PrimIf = 6 # dynamic + PrimListUnpack = 7 + PrimTupleUnpack = 8 class ScriptModuleParser: @@ -80,6 +82,10 @@ def ntype(node: torch._C.Node): return ScriptNodeKind.AtenOp if node.kind() == 'prim::If': return ScriptNodeKind.PrimIf + if node.kind() == 'prim::ListUnpack': + return ScriptNodeKind.PrimListUnpack + if node.kind() == 'prim::TupleUnpack': + return ScriptNodeKind.PrimTupleUnpack raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @staticmethod @@ -98,6 +104,11 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] return ScriptModuleParser.parse_prim_attr_node(node, module, frame) if node_type == ScriptNodeKind.PrimConstant: return ScriptModuleParser.parse_prim_constant_node(node, module, frame) + if node_type == ScriptNodeKind.PrimListUnpack: + return ScriptModuleParser.parse_prim_listunpack_node(node, module, frame) + if node_type == ScriptNodeKind.PrimTupleUnpack: + return ScriptModuleParser.parse_prim_tupleunpack_node(node, module, frame) + raise NotImplementedError(f"Un-supported node type {node_type}") @staticmethod def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: @@ -168,7 +179,7 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: print(f"Warning: some non-tensor arguments are ommited in {fsig}") # create IR node - ir_node = Sign2Op.map(fsig)(inputs=input_val, n_output=len(outputs)) + ir_node = Sign2Op.map(fsig)(inputs=input_val, n_outputs=len(outputs)) if len(ir_node.outputs()) != len(outputs): raise RuntimeError( f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" @@ -299,6 +310,14 @@ def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: """ raise NotImplementedError("Dynamic Graph is not supported yet") + @staticmethod + def parse_prim_listunpack_node(node, module, frame: Frame) -> List[None]: + raise NotImplementedError + + @staticmethod + def parse_prim_tupleunpack_node(node, module, frame) -> List[None]: + raise NotImplementedError + @staticmethod def flatten(smodule, depth=0): """ diff --git a/examples/transformer.py b/examples/transformer.py index 979dee4c..c317bfe5 100644 --- a/examples/transformer.py +++ b/examples/transformer.py @@ -38,15 +38,14 @@ def forward(self, x, mask): # [L, N, (num_heads * dim_head)] -> [L, (N * num_heads), dim_head] q = q.contiguous() q = q.view(self.seq_len, (bs * self.num_heads), self.dim_head) - # [L, N, (num_heads * dim_head)] -> [(N * num_heads), L, dim_head] - q = q.transpose(0, 1) - k = k.contiguous() k = k.view(self.seq_len, (bs * self.num_heads), self.dim_head) - k = k.transpose(0, 1) - v = v.contiguous() v = v.view(self.seq_len, (bs * self.num_heads), self.dim_head) + + # [L, N, (num_heads * dim_head)] -> [(N * num_heads), L, dim_head] + q = q.transpose(0, 1) + k = k.transpose(0, 1) v = v.transpose(0, 1) # [(N * num_heads), L, dim_head] -> [(N * num_heads), L, dim_head] @@ -208,7 +207,7 @@ def test_attention(): E = 1024 n_heads = 8 - # test_attention() + test_attention() model = TransformerLayer(L, E, n_heads, 0.5).cuda() reset_parameter(model) @@ -217,7 +216,7 @@ def test_attention(): mask = get_attn_mask(N, L).cuda() out = model(x, mask) - print(out) + # print(out) # print(out_ref) # assert torch.allclose(out, out_ref) is True print('Test passed') From ad37825d67903b116f97339668b223392a6d4359 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Nov 2021 19:33:27 +0800 Subject: [PATCH 0331/1892] handle tuple unpack prim --- cube/graph/parser/parser.py | 46 +++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 827bc146..c8dc6f0f 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -1,7 +1,7 @@ import torch import enum import re -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Union from cube.graph import IRFwOperation from cube.graph.tensor import IRFullTensor @@ -107,7 +107,7 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == ScriptNodeKind.PrimListUnpack: return ScriptModuleParser.parse_prim_listunpack_node(node, module, frame) if node_type == ScriptNodeKind.PrimTupleUnpack: - return ScriptModuleParser.parse_prim_tupleunpack_node(node, module, frame) + return list() # tuple unpack should only be used in prim function node raise NotImplementedError(f"Un-supported node type {node_type}") @staticmethod @@ -120,6 +120,20 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: inputs = [input for input in node.inputs()] outputs = [output for output in node.outputs()] + outputs: List[Union[torch._C.Value, IRFullTensor]] = list() + for output in node.outputs(): + # unpack the output type + if isinstance(output.type(), torch._C.TupleType): + for unpack_node in module.graph.nodes(): + if ScriptModuleParser.ntype(unpack_node) == ScriptNodeKind.PrimTupleUnpack: + if output in unpack_node.inputs(): + ScriptModuleParser.parse_prim_tupleunpack_node(unpack_node, module, frame) + break + tuple_outputs = frame.get_var(output.debugName()) + outputs += tuple_outputs + else: + outputs.append(output) + # handle function node fnode = node.inputsAt(0).node() if not ScriptModuleParser.ntype(fnode) == ScriptNodeKind.PrimConstant: @@ -141,7 +155,10 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: # handle outputs for index, output in enumerate(outputs): - frame.add_var(output.debugName(), ir_node.outputs(index)) + if isinstance(output, IRFullTensor): + ir_node.set_output(index, output) + else: + frame.add_var(output.debugName(), ir_node.outputs(index)) return [ir_node] @@ -316,7 +333,28 @@ def parse_prim_listunpack_node(node, module, frame: Frame) -> List[None]: @staticmethod def parse_prim_tupleunpack_node(node, module, frame) -> List[None]: - raise NotImplementedError + """ + Parse script module node like: + %q.1 : Tensor, %k.1 : Tensor, %v.1 : Tensor = prim::TupleUnpack(%11) + """ + inputs = [input for input in node.inputs()] + outputs = [output for output in node.outputs()] + if len(inputs) != 1: + raise RuntimeError("Find UnpackTuple has more than one input") + if len(outputs) == 1: + raise RuntimeError("Find UnpackTuple has only one output") + tuple_outs = list() + for output in outputs: + dtype = output.type().str() + var_name = output.debugName() + if dtype == 'Tensor': + ir_tensor = IRFullTensor(name=var_name) + tuple_outs.append(ir_tensor) + frame.add_var(var_name, ir_tensor) + else: + raise NotImplementedError + frame.add_var(inputs[0].debugName(), tuple_outs) + return list() @staticmethod def flatten(smodule, depth=0): From e1ac26a916de27e76904efec3033d8f960bacdee Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 19 Nov 2021 13:49:51 +0800 Subject: [PATCH 0332/1892] add operators for attention --- cube/graph/operator/function.py | 246 ++++++++++++++++++++++++++++-- cube/graph/operator/operator.py | 21 ++- cube/runtime/__init__.py | 3 +- cube/runtime/function/__init__.py | 1 + cube/runtime/function/complex.py | 51 +++++++ 5 files changed, 302 insertions(+), 20 deletions(-) create mode 100644 cube/runtime/function/__init__.py create mode 100644 cube/runtime/function/complex.py diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 2fc25e68..cd7d1803 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -24,20 +24,66 @@ def infer_shape(self): weight: [N, K] bias: [N,] """ - if len(self.inputs(0).shape) != 0 and len(self.inputs(1).shape) != 0: - shape = self.inputs(0).shape[:-1] + self.inputs(1).shape[:1] - self._outputs[0].shape = shape - return True - return False + if self.inputs(0).shape is None or self.inputs(1).shape is None: + return False + shape = self.inputs(0).shape[:-1] + self.inputs(1).shape[:1] + self._outputs[0].shape = shape + return True + + +class BatchLinear(IRFwOperation): + """ + Inputs: + + input1: [B, N, M] + input2: [B, M, P] + + Outputs: + + output: [B, N, P] + """ + + def __init__(self, signature, inputs, name='bmm', **kwargs): + + if len(inputs) != 2: + raise TypeError(f"Requires 2 inputs. But got {inputs}") + input1, input2 = inputs + super().__init__( + name, signature, + input_length=2, + output_length=1 + ) + self.set_input(0, input1) + self.set_input(1, input2) + + def infer_shape(self): + if self.inputs(0).shape is None or self.inputs(1).shape is None: + return False + b1, n1, m1 = self.inputs(0).shape + b2, m2, p2 = self.inputs(1).shape + if m1 != m2 or b1 != b2: + raise RuntimeError("Unmatch {b1} != {b2} or {m1} != {m2}") + shape = [b1, n1, p2] + self._outputs[0].shape = shape + return True class ElementWise(IRFwOperation): """ - Functions like torch.add (tensor1 + tensor2 / scaler) + Functions like torch.add, torch.mul, torch.sub, etc. """ def __init__(self, signature, inputs, name='elementwise', **kwargs): + """ + Inputs: + inputs[0]: IRTensor + inputs[1]: other (IRTensor or Number) + Outputs: + same shape as inputs[0] + """ + if len(inputs) != 2: + raise TypeError(f"Expected 2 inputs but got {inputs}") super().__init__( name, signature, input_length=len(inputs), @@ -47,13 +93,34 @@ def __init__(self, signature, inputs, name='elementwise', **kwargs): self.set_input(idx, input) def infer_shape(self): - for input in self.inputs(): - if isinstance(input, IRTensor): - if len(input.shape) != 0: - self._outputs[0].shape = copy.copy(input.shape) - return True - return False - return False + if self.inputs(0).shape is None: + return False + shape = copy.copy(self.inputs(0).shape) + self._outputs[0].shape = shape + return True + + +class Add(ElementWise): + """ + torch.add + """ + def __init__(self, signature, inputs, name='add', **kwargs): + """ + Inputs: + inputs[0]: IRTensor + inputs[1]: other (IRTensor or Number) + inputs[2]: alpha (Number) + Outputs: + same shape as inputs[0] + """ + if len(inputs) != 3: + raise TypeError( + f"Add expected 3 inputs: [tensor, other, alpha], but got {inputs}" + ) + super().__init__(signature, inputs[:2], name=name) + alpha = inputs[2] + if alpha != 1: + self.kwargs['alpha'] = alpha class ElementWiseActivation(IRFwOperation): @@ -101,14 +168,163 @@ def infer_shape(self): return True +class Sum(IRFwOperation): + """ + torch.sum + """ + def __init__(self, signature, inputs, name='sum', **kwargs): + + if len(inputs) <= 1: + raise TypeError(f"Expected at least 2 inputs, but got {inputs}") + if inputs[1] is not None and not isinstance(inputs[1], int): + raise TypeError(f"Expected inputs[1] to be None or int, but got {type(inputs[1])}") + + super().__init__( + name, signature, + input_length=1, + output_length=1 + ) + self.set_input(0, inputs[0]) + if inputs[1] is not None: + self.kwargs['dim'] = inputs[1] + if len(inputs) > 2: + self.kwargs['keepdim'] = inputs[2] + else: + self.kwargs['keepdim'] = False + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + shape = list() + if 'dim' in self.kwargs: + dim = [self.kwargs['dim']] + keepdim = self.kwargs['keepdim'] + for idx, nele in enumerate(self.inputs(0).shape): + if idx in dim: + if not keepdim: + continue + nele = 1 + shape.append(nele) + else: + shape = [1] + self._outputs[0].shape = shape + return True + + +class Softmax(IRFwOperation): + + def __init__(self, signature, inputs, name='softmax', **kwargs): + + if len(inputs) != 4: + raise TypeError(f"Expected 4 inputs, but got: {inputs}") + + tensor, dim, stacklevel, dtype = inputs[0], inputs[1], inputs[2], inputs[3] + super().__init__( + name, signature, input_length=1, output_length=1 + ) + self.set_input(0, tensor) + self.kwargs['dim'] = dim + self.kwargs['_stacklevel'] = stacklevel + self.kwargs['dtype'] = dtype + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + dim = self.kwargs['dim'] + shape = [ + nele for idx, nele in enumerate(self.inputs(0).shape) if idx != dim + ] + self._outputs[0].shape = shape + + +class Transpose(IRFwOperation): + """ + torch.transpose + """ + def __init__(self, signature, inputs, name='transpose', **kwargs): + + if len(inputs) != 3: + raise RuntimeError("expected 3 inputs ") + + if not isinstance(inputs[1], int): + raise TypeError(f"Expected 1st input: int, but got {type(inputs[1])}") + if not isinstance(inputs[2], int): + raise TypeError(f"Expected 1st input: int, but got {type(inputs[2])}") + + super().__init__( + name, signature, + input_length=1, + output_length=1 + ) + self.set_input(0, inputs[0]) + self.kwargs['dim1'] = inputs[1] + self.kwargs['dim2'] = inputs[2] + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + dim1 = self.kwargs['dim1'] + dim2 = self.kwargs['dim2'] + shape = copy.copy(list(self.inputs(0).shape)) + shape[dim1], shape[dim2] = shape[dim2], shape[dim1] + self._outputs[0].shape = shape + return True + + +class CubeComplexToQKV(IRFwOperation): + """ + function to QKV + """ + def __init__(self, signature, inputs, name='toqkv', **kwargs): + super().__init__( + name, signature, + input_length=len(inputs), + output_length=3 + ) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + shape = self.inputs(0).shape + for output in self.outputs(): + output.shape = shape + return True + + +class CubeComplexTrilMask(IRFwOperation): + """ + Function to tril_mask + """ + def __init__(self, signature, inputs, name='trilmask', **kwargs): + if len(inputs) != 2: + raise TypeError("Expected 2 input") + tensor, num_head = inputs[0], inputs[1] + super().__init__( + name, signature, + input_length=2, + output_length=1 + ) + self.set_input(0, tensor) + self.set_input(1, num_head) + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + shape = copy.copy(self.inputs(0).shape) + self._outputs[0].shape = shape + return True + + class UnkownOperator(IRFwOperation): - def __init__(self, signature, inputs, name='unknown_op', n_output=None): + def __init__(self, signature, inputs, name='unknown_op', n_outputs=None): super().__init__( name, signature=signature, input_length=len(inputs), - output_length=n_output, + output_length=n_outputs, ) for idx, input in enumerate(inputs): self.set_input(idx, input) diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 5b63647d..9ee2f71f 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -25,6 +25,8 @@ def __init__(self, input_length (int): the number of inputs for the op output_length (int): the number of outputs for the op """ + # additional argument + self.kwargs = dict() super().__init__(name, signature, input_length, output_length) outputs = [IRFullTensor() for _ in range(output_length)] for idx, output in enumerate(outputs): @@ -76,8 +78,13 @@ def __repr__(self): anno = 'w' if tensor.is_grad(): anno = 'g' - # inputs.append(f'{anno}{tensor._id}') - inputs.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') + if isinstance(tensor, IRFullTensor): + pid = tensor._id + valmap = (0,1) + else: + pid = tensor.parent._id + valmap = tensor.val_map + inputs.append(f'{anno}{tensor._id}(p{pid},{tensor.shape},{valmap})') else: inputs.append(tensor) @@ -89,8 +96,14 @@ def __repr__(self): anno = 'w' if tensor.is_grad(): anno = 'g' - # outputs.append(f'{anno}{tensor._id}') - outputs.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') + if isinstance(tensor, IRFullTensor): + pid = tensor._id + valmap = (0,1) + else: + pid = tensor.parent._id + valmap = tensor.val_map + pid = tensor.parent._id if hasattr(tensor, 'parent') else tensor._id + outputs.append(f'{anno}{tensor._id}(p{pid},{tensor.shape},{valmap})') else: outputs.append(tensor) diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index 993a8c63..47dd5028 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -3,4 +3,5 @@ from cube.runtime import reducer from cube.runtime import syndata from cube.runtime import resource -from cube.runtime import module \ No newline at end of file +from cube.runtime import module +from cube.runtime import function \ No newline at end of file diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py new file mode 100644 index 00000000..9dd5982d --- /dev/null +++ b/cube/runtime/function/__init__.py @@ -0,0 +1 @@ +from cube.runtime.function.complex import * \ No newline at end of file diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py new file mode 100644 index 00000000..90e876e0 --- /dev/null +++ b/cube/runtime/function/complex.py @@ -0,0 +1,51 @@ +import torch +import torch.nn.functional as F + + +def toqkv(input: torch.Tensor, weight: torch.nn.Parameter, + bs: int, seqlen: int, num_heads: int, dim_head: int): + """ + input: [L, N, E] (seqlen, batch size, embed dim (hidden size)) + weight: [E, E * 3] + + Returns: + Q: [L, N, E] + K: [L, N, E] + V: [L, N, E] + """ + qkv = F.linear(input, weight, None) + qkv = qkv.chunk(3, dim=-1) + q, k, v = qkv + q = q.contiguous() + q = q.view(seqlen, (bs * num_heads), dim_head) + k = k.contiguous() + k = k.view(seqlen, (bs * num_heads), dim_head) + v = v.contiguous() + v = v.view(seqlen, (bs * num_heads), dim_head) + return q, k, v + + +def tril_mask(input: torch.Tensor, num_heads: int): + """ + Inputs: + input: [N * num_heads, L, L] + num_head: int + + Returns: + output: [N * num_heads, L, L] + """ + bs: int = input.shape[0] // num_heads + seqlen: int = input.shape[2] + input = input.view(bs, num_heads, seqlen, seqlen) + # set up mask + ones = torch.ones( + (bs, seqlen, seqlen), + device=input.device, + ) + mask = torch.tril(ones) + mask = mask.view(bs, 1, seqlen, seqlen) + mask = (mask < 0.5) + # mask + masked_input = input.masked_fill_(mask, -100000.0) + masked_input = masked_input.view((bs * num_heads), seqlen, seqlen) + return masked_input From 24afd4a6a3fec5a6df8a634eadb0d1fcb0852db8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 19 Nov 2021 14:10:48 +0800 Subject: [PATCH 0333/1892] parse attention --- cube/graph/operator/function.py | 29 ++++++++++++++++++++ cube/graph/parser/mapping.py | 27 +++++++++++++++++-- cube/graph/parser/parser.py | 46 ++++++++++++++++++-------------- cube/runtime/function/complex.py | 18 +++++++++++++ 4 files changed, 98 insertions(+), 22 deletions(-) diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index cd7d1803..046ba90f 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -270,6 +270,7 @@ def infer_shape(self): self._outputs[0].shape = shape return True +# ===================== Cube Complex Operation ======================= class CubeComplexToQKV(IRFwOperation): """ @@ -317,6 +318,34 @@ def infer_shape(self): return True +class CubeComplexAttnView(IRFwOperation): + """ + Funtion to attention view + """ + def __init__(self, signature, inputs, name='attn_view', **kwargs): + if len(inputs) != 2: + raise TypeError("Expected 2 input") + tensor, num_head = inputs[0], inputs[1] + super().__init__( + name, signature, + input_length=2, + output_length=1 + ) + self.set_input(0, tensor) + self.set_input(1, num_head) + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + num_heads = self.inputs(1) + bs = self.inputs(0).shape[0] // num_heads + seqlen = self.inputs(0).shape[1] + dim_head = self.inputs(0).shape[2] + shape = [seqlen, bs, num_heads * dim_head] + self._outputs[0].shape = shape + return True + + class UnkownOperator(IRFwOperation): def __init__(self, signature, inputs, name='unknown_op', n_outputs=None): diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 8a8f3290..ceb49795 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -28,17 +28,40 @@ def map(signature: str) -> IRFwOperation: # tensor template __ttemplate = lambda name: f'torch.{name}' + # customized + __customize = lambda name: f'cube.runtime.function.complex.{name}' + kOpMap = { + # torch nn functional + __ftemplate('linear') : function.Linear, + __ftemplate('softmax') : function.Softmax, + __ftemplate('dropout') : partial(function.ElementWiseActivation, name='dropout'), __ftemplate('gelu') : partial(function.ElementWiseActivation, name='gelu'), - __ttemplate('add') : partial(function.ElementWise, name='add'), + # torch aten + + __ttemplate('add') : partial(function.Add, name='add'), + + __ttemplate('mul') : partial(function.ElementWise, name='mul'), + + __ttemplate('bmm') : function.BatchLinear, + + __ttemplate('sum') : partial(function.Sum, name='sum'), + + __ttemplate('transpose') : function.Transpose, + + # complex + + __customize('toqkv'): partial(function.CubeComplexToQKV, name='toqkv'), + + __customize('tril_mask'): function.CubeComplexTrilMask, - __ttemplate('sum') : partial(function.Reduce, name='sum'), + __customize('attn_view'): function.CubeComplexAttnView, } diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index c8dc6f0f..e9cc7e48 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -9,6 +9,9 @@ from cube.graph.parser.mapping import Sign2Op +_refmodule = torch.nn.Module() + + class ScriptNodeKind(enum.Enum): PrimGetAttr = 1 PrimCallMethod = 2 @@ -18,6 +21,7 @@ class ScriptNodeKind(enum.Enum): PrimIf = 6 # dynamic PrimListUnpack = 7 PrimTupleUnpack = 8 + PrimPythonOp = 9 class ScriptModuleParser: @@ -58,7 +62,9 @@ def parse_module(module, # _ = input('>>>') if len(ir_nodes) != 0: for ir_node in ir_nodes: - ir_node.infer_shape() + ret = ir_node.infer_shape() + if not ret: + print(f'warning: {ir_node} cannot infer shape') all_ir_nodes += ir_nodes # handle graph output -- Assuming all the output are tensors @@ -86,6 +92,8 @@ def ntype(node: torch._C.Node): return ScriptNodeKind.PrimListUnpack if node.kind() == 'prim::TupleUnpack': return ScriptNodeKind.PrimTupleUnpack + if node.kind() == 'prim::PythonOp': + return ScriptNodeKind.PrimPythonOp raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @staticmethod @@ -108,6 +116,8 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] return ScriptModuleParser.parse_prim_listunpack_node(node, module, frame) if node_type == ScriptNodeKind.PrimTupleUnpack: return list() # tuple unpack should only be used in prim function node + if node_type == ScriptNodeKind.PrimPythonOp: + return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) raise NotImplementedError(f"Un-supported node type {node_type}") @staticmethod @@ -140,7 +150,7 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: raise RuntimeError(f"Found unexpected function call node: {fnode}") fsig = frame.get_var(inputs[0].debugName()) - # handle inputs -- in stack with reverse order + # handle inputs input_vals = list() for index, input in enumerate(inputs[1:]): var_name = input.debugName() @@ -174,26 +184,11 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: outputs = [output for output in node.outputs()] # handle inputs: - # TODO: fix omitted kwargs - # We will omit arg index >= 2 as we assume the - # tensor op at most gets 2 tensor, others are kwargs input_val = list() - maybe_kwarg = len(inputs) > 2 - for reverse_index, input in enumerate(inputs[::-1]): + for input in inputs: var_name = input.debugName() val = frame.get_var(var_name) - index = len(inputs) - 1 - reverse_index - if maybe_kwarg and (not isinstance(val, IRFullTensor)) and index > 1: - continue - else: - input_val.append(val) - maybe_kwarg = False - input_val = input_val[::-1] - # handle single operand e.g., torch.sum - if input_val[1] is None: - input_val = input_val[:1] + input_val[2:] - if len(input_val) < len(inputs): - print(f"Warning: some non-tensor arguments are ommited in {fsig}") + input_val.append(val) # create IR node ir_node = Sign2Op.map(fsig)(inputs=input_val, n_outputs=len(outputs)) @@ -265,6 +260,7 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: Returns: Empty list """ + global _refmodule if node.inputsAt(0).debugName() != 'self': raise RuntimeError(f"Fail to parse {node} due to missing %self") @@ -279,7 +275,12 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: frame.add_var(var_name, ir_tensor) # symbolic attributes elif dtype in ['bool', 'int', 'float']: - frame.add_var(var_name, 'self.' + label) + if hasattr(_refmodule, label): + val = 'self.' + label + else: + val = getattr(module, label) + # print(f'get: var_name {var_name}: {val}') + frame.add_var(var_name, val) # NoneType elif dtype == 'NoneType': frame.add_var(var_name, None) @@ -356,6 +357,11 @@ def parse_prim_tupleunpack_node(node, module, frame) -> List[None]: frame.add_var(inputs[0].debugName(), tuple_outs) return list() + @staticmethod + def parse_prim_python_op_node(node, module, frame): + raise NotImplementedError("Cannot support torch.jit.ignore") + print(dir(node)) + @staticmethod def flatten(smodule, depth=0): """ diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index 90e876e0..39018071 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -49,3 +49,21 @@ def tril_mask(input: torch.Tensor, num_heads: int): masked_input = input.masked_fill_(mask, -100000.0) masked_input = masked_input.view((bs * num_heads), seqlen, seqlen) return masked_input + + +def attn_view(input: torch.Tensor, num_heads: int): + """ + Inputs: + [N * num_heads, L, dim_head] + + Outputs: + [L, N, num_heads * dim_head] + """ + bs: int = input.shape[0] // num_heads + seqlen: int = input.shape[1] + dim_head = input.shape[2] + # [(N * num_heads), L, dim_head] -> [L, (N * num_heads), dim_head] + input = input.transpose(0, 1).contiguous() + # [L, (N * num_heads), dim_head] -> [L, N, (num_heads * dim_head)] + input = input.view(seqlen, bs, num_heads * dim_head) + return input From 04cf8ae1595b2f2c5c046cc1fc7d6892eff2a9f3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 19 Nov 2021 14:54:06 +0800 Subject: [PATCH 0334/1892] parser for attention --- cube/graph/operator/function.py | 21 +++- cube/runtime/function/complex.py | 13 +- tests/graph/parser/test_parse_attention.py | 113 ++++++++++++++++++ .../test_parse_mlp.py} | 2 + 4 files changed, 136 insertions(+), 13 deletions(-) create mode 100644 tests/graph/parser/test_parse_attention.py rename tests/graph/{test_parser.py => parser/test_parse_mlp.py} (98%) diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 046ba90f..82720f80 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -277,20 +277,31 @@ class CubeComplexToQKV(IRFwOperation): function to QKV """ def __init__(self, signature, inputs, name='toqkv', **kwargs): + if len(inputs) != 5: + raise TypeError(f"Expected 5 arguments but goit {inputs}") + qkv = inputs[0] super().__init__( name, signature, - input_length=len(inputs), + input_length=1, output_length=3 ) - for idx, input in enumerate(inputs): - self.set_input(idx, input) + self.set_input(0, qkv) + self.kwargs['bs'] = inputs[1] + self.kwargs['seqlen'] = inputs[2] + self.kwargs['num_heads'] = inputs[3] + self.kwargs['dim_head'] = inputs[4] def infer_shape(self): if self.inputs(0).shape is None: return False - shape = self.inputs(0).shape + bs = self.kwargs['bs'] + seqlen = self.kwargs['seqlen'] + num_heads = self.kwargs['num_heads'] + dim_head = self.kwargs['dim_head'] + + shape = [seqlen, bs * num_heads, dim_head] for output in self.outputs(): - output.shape = shape + output.shape = copy.copy(shape) return True diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index 39018071..12241286 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -1,19 +1,16 @@ import torch -import torch.nn.functional as F -def toqkv(input: torch.Tensor, weight: torch.nn.Parameter, +def toqkv(qkv: torch.Tensor, bs: int, seqlen: int, num_heads: int, dim_head: int): """ - input: [L, N, E] (seqlen, batch size, embed dim (hidden size)) - weight: [E, E * 3] + input: [L, N, E * 3] (seqlen, batch size, num_heads * dim_head * 3)) Returns: - Q: [L, N, E] - K: [L, N, E] - V: [L, N, E] + Q: [L, N * num_heads, dim_head] + K: [L, N * num_heads, dim_head] + V: [L, N * num_heads, dim_head] """ - qkv = F.linear(input, weight, None) qkv = qkv.chunk(3, dim=-1) q, k, v = qkv q = q.contiguous() diff --git a/tests/graph/parser/test_parse_attention.py b/tests/graph/parser/test_parse_attention.py new file mode 100644 index 00000000..be159214 --- /dev/null +++ b/tests/graph/parser/test_parse_attention.py @@ -0,0 +1,113 @@ +from torch import nn +import torch +import torch.nn.functional as F + +from cube.graph import parser +import cube + + +class MultiHeadSelfAttention(nn.Module): + + def __init__(self, bs, seq_len, embed_dim, heads, dropout): + super().__init__() + + self.batch_size = bs + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_heads = heads + self.dim_head = embed_dim // heads + self.scale = self.dim_head ** -0.5 + + self.weight_qkv = torch.nn.Parameter(torch.empty( + 3 * embed_dim, embed_dim + )) + self.weight_out = torch.nn.Parameter(torch.empty( + embed_dim, embed_dim + )) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): #, mask): + """ + x: [L, N, E]: seq_len, batch_size, embedding dimension + output: [L, N, E] + """ + bs = self.batch_size + + # # [L, N, E] -> [L, N, (num_heads * dim_head * 3)] + qkv = F.linear(x, self.weight_qkv, None) + + # # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 + # qkv = qkv.chunk(3, dim=-1) + # q, k, v = qkv + # + # # [L, N, (num_heads * dim_head)] -> [L, (N * num_heads), dim_head] + # q = q.contiguous() + # q = q.view(self.seq_len, (bs * self.num_heads), self.dim_head) + # k = k.contiguous() + # k = k.view(self.seq_len, (bs * self.num_heads), self.dim_head) + # v = v.contiguous() + # v = v.view(self.seq_len, (bs * self.num_heads), self.dim_head) + + # [L, N, E] -> 3 x [L, (N * num_heads), dim_head] + q, k, v = cube.runtime.function.toqkv( + qkv, bs, self.seq_len, self.num_heads, self.dim_head + ) + + # [L, (N * num_heads), dim_head] -> [(N * num_heads), L, dim_head] + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + # [(N * num_heads), L, dim_head] -> [(N * num_heads), L, dim_head] + q = q * self.scale + # [(N * num_heads), L, dim_head] * [(N * num_heads), dim_head, L] + # -> [(N * num_heads), L, L] + k = k.transpose(-2, -1) + attn = torch.bmm(q, k) + + # # [(N * num_heads), L, L] -> [N, num_heads, L, L] + # attn = attn.view(bs, self.num_heads, self.seq_len, self.seq_len) + # # [N, num_heads, L, L] -> [N, num_heads, L, L] + # attn = attn.masked_fill_(mask, -100000.0) + # # [N, num_heads, L, L] -> [(N * num_heads), L, L] + # attn = attn.view((bs * self.num_heads), self.seq_len, self.seq_len) + attn = cube.runtime.function.tril_mask(attn, bs) + + # [(N * num_heads), L, L] -> [(N * num_heads), L, L] + attn = F.softmax(attn, dim=-1) + + # [(N * num_heads), L, L] -> [(N * num_heads), L, L] + attn = self.dropout(attn) + # [(N * num_heads), L, L] * [(N * num_heads), L, dim_head] + # -> [(N * num_heads), L, dim_head] + output = torch.bmm(attn, v) + + # # [(N * num_heads), L, dim_head] -> [L, (N * num_heads), dim_head] + # output = output.transpose(0, 1) + # output = output.contiguous() + # # [L, (N * num_heads), dim_head] -> [L, N, (num_heads * dim_head)] + # output = output.view(self.seq_len, bs, self.embed_dim) + output = cube.runtime.function.attn_view(output, self.num_heads) + + # [L, N, (num_heads * dim_head)] * [(num_heads * dim_head), (num_heads * dim_head)] + # => [L, N, (num_heads * dim_head)] + output = F.linear(output, self.weight_out) + return output + + +def test_parse_attention(): + + L = 64 # seq len + N = 16 # batch + E = 1024 # hiddend size = dim_head * num_head + n_heads = 8 + + model = MultiHeadSelfAttention(N, L, E, n_heads, dropout=0.5) + module = torch.jit.script(model) + print(module.graph) + # print(module.code) + + graph = parser.convert(model, input_shapes=([L, N, E],)) + print(graph) + + assert False diff --git a/tests/graph/test_parser.py b/tests/graph/parser/test_parse_mlp.py similarity index 98% rename from tests/graph/test_parser.py rename to tests/graph/parser/test_parse_mlp.py index c36bee0e..c902deac 100644 --- a/tests/graph/test_parser.py +++ b/tests/graph/parser/test_parse_mlp.py @@ -1,3 +1,4 @@ +import torch from torch import nn import cube.graph.parser as parser @@ -58,3 +59,4 @@ def test_parse_module(): assert node5.successors() == [node6] assert graph.outputs(0).shape == [1024, 1000] + assert False \ No newline at end of file From b12d3b75a79dd527cf8a5e4d9fdc2153db33bef7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 19 Nov 2021 16:53:58 +0800 Subject: [PATCH 0335/1892] fix bug in softmax shape infer --- cube/graph/operator/function.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 82720f80..6ef655b6 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -235,6 +235,7 @@ def infer_shape(self): nele for idx, nele in enumerate(self.inputs(0).shape) if idx != dim ] self._outputs[0].shape = shape + return True class Transpose(IRFwOperation): From d54bbd854df2772145fe8fe8fd57faf7f9493aca Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 19 Nov 2021 16:59:49 +0800 Subject: [PATCH 0336/1892] add kwargs for forward op --- cube/codegen/codegen.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 5c1c3ea6..2216e2c5 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -215,6 +215,11 @@ def emit_op_call(self, node): """ op_code = node.signature arg_names = self._forward_region_arg_names(node.inputs()) + kwargs = list() + for key in node.kwargs: + code = f'{key}={node.kwargs[key]}' + kwargs.append(code) + arg_names += kwargs arg_region = '(' + ', '.join(arg_names) + ')' if len(node.outputs()) == 0: code = f'{op_code}{arg_region}' From 002db8aebcdff04961d5ead71ff9013d140219da Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 19 Nov 2021 17:00:15 +0800 Subject: [PATCH 0337/1892] fix transpose bug --- cube/graph/operator/function.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 6ef655b6..a634a893 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -258,14 +258,14 @@ def __init__(self, signature, inputs, name='transpose', **kwargs): output_length=1 ) self.set_input(0, inputs[0]) - self.kwargs['dim1'] = inputs[1] - self.kwargs['dim2'] = inputs[2] + self.kwargs['dim0'] = inputs[1] + self.kwargs['dim1'] = inputs[2] def infer_shape(self): if self.inputs(0).shape is None: return False - dim1 = self.kwargs['dim1'] - dim2 = self.kwargs['dim2'] + dim1 = self.kwargs['dim0'] + dim2 = self.kwargs['dim1'] shape = copy.copy(list(self.inputs(0).shape)) shape[dim1], shape[dim2] = shape[dim2], shape[dim1] self._outputs[0].shape = shape From ae795e256fe02c30bd88988c1216a0c3da3e3573 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 20 Nov 2021 12:14:48 +0800 Subject: [PATCH 0338/1892] qkv linear --- cube/graph/operator/function.py | 22 ++++++++++------------ cube/runtime/function/complex.py | 13 ++++++++++--- tests/graph/parser/test_parse_attention.py | 8 +++++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index a634a893..91af8a18 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -278,31 +278,29 @@ class CubeComplexToQKV(IRFwOperation): function to QKV """ def __init__(self, signature, inputs, name='toqkv', **kwargs): - if len(inputs) != 5: - raise TypeError(f"Expected 5 arguments but goit {inputs}") - qkv = inputs[0] + if len(inputs) != 3: + raise TypeError(f"Expected 3 arguments but goit {inputs}") + qkv, weight = inputs[0], inputs[1] super().__init__( name, signature, - input_length=1, + input_length=2, output_length=3 ) self.set_input(0, qkv) - self.kwargs['bs'] = inputs[1] - self.kwargs['seqlen'] = inputs[2] - self.kwargs['num_heads'] = inputs[3] - self.kwargs['dim_head'] = inputs[4] + self.set_input(1, weight) + self.kwargs['num_heads'] = inputs[2] def infer_shape(self): if self.inputs(0).shape is None: return False - bs = self.kwargs['bs'] - seqlen = self.kwargs['seqlen'] + seqlen = self.inputs(0).shape[0] + bs = self.inputs(0).shape[1] num_heads = self.kwargs['num_heads'] - dim_head = self.kwargs['dim_head'] + dim_head = self.inputs(0).shape[2] // num_heads shape = [seqlen, bs * num_heads, dim_head] for output in self.outputs(): - output.shape = copy.copy(shape) + output.shape = shape return True diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index 12241286..e9237589 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -1,16 +1,23 @@ import torch +import torch.nn.functional as F -def toqkv(qkv: torch.Tensor, - bs: int, seqlen: int, num_heads: int, dim_head: int): +def toqkv(hidden_state: torch.Tensor, weight: torch.Tensor, + num_heads: int): """ - input: [L, N, E * 3] (seqlen, batch size, num_heads * dim_head * 3)) + Inputs: + hidden_state: [L, N, E] (seqlen, batch size, num_heads * dim_head) + weight: [E * 3, E] Returns: Q: [L, N * num_heads, dim_head] K: [L, N * num_heads, dim_head] V: [L, N * num_heads, dim_head] """ + seqlen = hidden_state.shape[0] + bs = hidden_state.shape[1] + dim_head = hidden_state.shape[2] // num_heads + qkv = F.linear(hidden_state, weight, None) qkv = qkv.chunk(3, dim=-1) q, k, v = qkv q = q.contiguous() diff --git a/tests/graph/parser/test_parse_attention.py b/tests/graph/parser/test_parse_attention.py index be159214..b90137bd 100644 --- a/tests/graph/parser/test_parse_attention.py +++ b/tests/graph/parser/test_parse_attention.py @@ -33,8 +33,10 @@ def forward(self, x): #, mask): """ bs = self.batch_size - # # [L, N, E] -> [L, N, (num_heads * dim_head * 3)] - qkv = F.linear(x, self.weight_qkv, None) + # [L, N, (num_heads * dim_head)], + # [(num_heads * dim_head), 3 * (num_heads * dim_head)] + # -> [L, N, (num_heads * dim_head * 3)] + # qkv = F.linear(x, self.weight_qkv, None) # # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 # qkv = qkv.chunk(3, dim=-1) @@ -50,7 +52,7 @@ def forward(self, x): #, mask): # [L, N, E] -> 3 x [L, (N * num_heads), dim_head] q, k, v = cube.runtime.function.toqkv( - qkv, bs, self.seq_len, self.num_heads, self.dim_head + x, self.weight_qkv, self.num_heads ) # [L, (N * num_heads), dim_head] -> [(N * num_heads), L, dim_head] From edc577768df5d3e9ec315cc04249d86350ec40ce Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 20 Nov 2021 14:36:28 +0800 Subject: [PATCH 0339/1892] add complex partition methods --- cube/algorithm/factory.py | 11 ++ cube/algorithm/ops/__init__.py | 0 cube/algorithm/ops/complex.py | 311 +++++++++++++++++++++++++++++++ cube/graph/operator/function.py | 46 +++-- cube/ir/cten.py | 4 +- cube/runtime/function/complex.py | 7 +- 6 files changed, 361 insertions(+), 18 deletions(-) create mode 100644 cube/algorithm/ops/__init__.py create mode 100644 cube/algorithm/ops/complex.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 95549ca9..e941787e 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -72,3 +72,14 @@ def _load_predefined_algos(self): import cube.algorithm.reduce as reduce self.register(reduce.Reduce, reduce.ReduceDataParallel, tag='data') + + import cube.algorithm.ops.complex as complex + self.register(complex.CubeComplexToQKV, complex.CubeToQKVDataParallel, tag='data') + self.register(complex.CubeComplexToQKV, complex.CubeToQKVHeadParallel, tag='head') + + self.register(complex.CubeComplexTrilMask, complex.CubeTrilMaskDataParallel, tag='data') + self.register(complex.CubeComplexTrilMask, complex.CubeTrilMaskHeadParallel, tag='head') + + self.register(complex.CubeComplexAttnView, complex.CubeAttnViewDataParallel, tag='data') + self.register(complex.CubeComplexAttnView, complex.CubeAttnViewHeadParallel, tag='head') + \ No newline at end of file diff --git a/cube/algorithm/ops/__init__.py b/cube/algorithm/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/algorithm/ops/complex.py b/cube/algorithm/ops/complex.py new file mode 100644 index 00000000..b4ac3500 --- /dev/null +++ b/cube/algorithm/ops/complex.py @@ -0,0 +1,311 @@ +from typing import Dict + +from cube.algorithm.utils import split_axis +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.operator.function import CubeComplexToQKV +from cube.graph.operator.function import CubeComplexTrilMask +from cube.graph.operator.function import CubeComplexAttnView + + +_kWaitDecision = None + + +class CubeToQKVDataParallel(GenericDistAlgo): + """ + Inputs: + hidden_state: [L, N, E] + weight: [3 * (num_heads * dim_head), E] + num_heads: int + + where L = sequence length, N = batch size, E = num_heads * dim_head + + Returns: + Q: [L, N * num_heads, dim_head] + K: [L, N * num_heads, dim_head] + V: [L, N * num_heads, dim_head] + """ + def __init__(self, node: CubeComplexToQKV): + + if not isinstance(node, CubeComplexToQKV): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.num_heads = node.kwargs['num_heads'] + self.bs = node.inputs(0).shape[1] + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.bs % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + hidden_size, weight = node.inputs() + q, k, v = node.outputs() + + ins = split_axis(hidden_size, 1, self.chunk_num) + qs = split_axis(q, 1, self.chunk_num) + ks = split_axis(k, 1, self.chunk_num) + vs = split_axis(v, 1, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ins[idx], weight, self.num_heads] + node = CubeComplexToQKV( + signature = 'cube.runtime.function.complex.toqkv', + inputs = inputs, + name = 'toqkv' + ) + node.set_output(0, qs[idx]) + node.set_output(1, ks[idx]) + node.set_output(2, vs[idx]) + nodes.append(node) + return nodes + + +class CubeToQKVHeadParallel(GenericDistAlgo): + """ + Inputs: + hidden_state: [L, N, E] (seqlen, batch size, num_heads * dim_head) + weight: [E * 3, E] + num_heads: int + + Returns: + Q: [L, N * num_heads, dim_head] + K: [L, N * num_heads, dim_head] + V: [L, N * num_heads, dim_head] + """ + def __init__(self, node: CubeComplexToQKV): + + if not isinstance(node, CubeComplexToQKV): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.num_heads = node.kwargs['num_heads'] + self.bs = node.inputs(0).shape[1] + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.num_heads % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + hidden_state, weight = node.inputs() + q, k, v = node.outputs() + + ws = split_axis(weight, 0, self.chunk_num) + qs = split_axis(q, 1, self.chunk_num) + ks = split_axis(k, 1, self.chunk_num) + vs = split_axis(v, 1, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [hidden_state, ws[idx], self.num_heads // self.chunk_num] + node = CubeComplexToQKV( + signature = 'cube.runtime.function.complex.toqkv', + inputs = inputs, + name = 'toqkv' + ) + node.set_output(0, qs[idx]) + node.set_output(1, ks[idx]) + node.set_output(2, vs[idx]) + nodes.append(node) + return nodes + + +class CubeTrilMaskDataParallel(GenericDistAlgo): + """ + Inputs: + input: [N * num_heads, L, L] + num_head: int + + Returns: + output: [N * num_heads, L, L] + """ + def __init__(self, node: CubeComplexTrilMask): + + if not isinstance(node, CubeComplexTrilMask): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.num_heads = node.kwargs['num_heads'] + self.bs = node.inputs(0).shape[0] // self.num_heads + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.bs % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + hidden_size = node.inputs(0) + masked_out = node.outputs(0) + + ins = split_axis(hidden_size, 0, self.chunk_num) + ous = split_axis(masked_out, 0, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ins[idx], self.num_heads] + node = CubeComplexTrilMask( + signature = 'cube.runtime.function.complex.tril_mask', + inputs = inputs, + name = 'tril_mask' + ) + node.set_output(0, ous[idx]) + nodes.append(node) + return nodes + + +class CubeTrilMaskHeadParallel(GenericDistAlgo): + """ + Inputs: + input: [N * num_heads, L, L] + num_head: int + + Returns: + output: [N * num_heads, L, L] + """ + def __init__(self, node: CubeComplexTrilMask): + + if not isinstance(node, CubeComplexTrilMask): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.num_heads = node.kwargs['num_heads'] + self.bs = node.inputs(0).shape[0] // self.num_heads + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.num_heads % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + hidden_size = node.inputs(0) + masked_out = node.outputs(0) + + ins = split_axis(hidden_size, 0, self.chunk_num) + ous = split_axis(masked_out, 0, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ins[idx], self.num_heads // self.chunk_num] + node = CubeComplexTrilMask( + signature = 'cube.runtime.function.complex.tril_mask', + inputs = inputs, + name = 'tril_mask' + ) + node.set_output(0, ous[idx]) + nodes.append(node) + return nodes + + +class CubeAttnViewDataParallel(GenericDistAlgo): + """ + Inputs: + [N * num_heads, L, dim_head] + + Outputs: + [L, N, num_heads * dim_head] + """ + def __init__(self, node: CubeComplexAttnView): + if not isinstance(node, CubeComplexAttnView): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.num_heads = node.kwargs['num_heads'] + self.bs = node.inputs(0).shape[0] // self.num_heads + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.bs % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + attn = node.inputs(0) + out = node.outputs(0) + + ins = split_axis(attn, 0, self.chunk_num) + ous = split_axis(out, 1, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ins[idx], self.num_heads] + node = CubeComplexAttnView( + signature = 'cube.runtime.function.complex.attn_view', + inputs = inputs, + name = 'attn_view' + ) + node.set_output(0, ous[idx]) + nodes.append(node) + return nodes + + +class CubeAttnViewHeadParallel(GenericDistAlgo): + """ + Inputs: + [N * num_heads, L, dim_head] + + Outputs: + [L, N, num_heads * dim_head] + """ + def __init__(self, node: CubeComplexAttnView): + if not isinstance(node, CubeComplexAttnView): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.num_heads = node.kwargs['num_heads'] + self.bs = node.inputs(0).shape[0] // self.num_heads + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.num_heads % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + attn = node.inputs(0) + out = node.outputs(0) + + ins = split_axis(attn, 0, self.chunk_num) + ous = split_axis(out, 2, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ins[idx], self.num_heads // self.chunk_num] + node = CubeComplexAttnView( + signature = 'cube.runtime.function.complex.attn_view', + inputs = inputs, + name = 'attn_view' + ) + node.set_output(0, ous[idx]) + nodes.append(node) + return nodes diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 91af8a18..340c2593 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -275,7 +275,17 @@ def infer_shape(self): class CubeComplexToQKV(IRFwOperation): """ - function to QKV + Inputs: + hidden_state: [L, N, E] + weight: [3 * (num_heads * dim_head), E] + num_heads: int + + where L = sequence length, N = batch size, E = num_heads * dim_head + + Returns: + Q: [L, N * num_heads, dim_head] + K: [L, N * num_heads, dim_head] + V: [L, N * num_heads, dim_head] """ def __init__(self, signature, inputs, name='toqkv', **kwargs): if len(inputs) != 3: @@ -291,12 +301,12 @@ def __init__(self, signature, inputs, name='toqkv', **kwargs): self.kwargs['num_heads'] = inputs[2] def infer_shape(self): - if self.inputs(0).shape is None: + if self.inputs(0).shape is None or self.inputs(1) is None: return False seqlen = self.inputs(0).shape[0] bs = self.inputs(0).shape[1] num_heads = self.kwargs['num_heads'] - dim_head = self.inputs(0).shape[2] // num_heads + dim_head = self.inputs(1).shape[0] // 3 // num_heads shape = [seqlen, bs * num_heads, dim_head] for output in self.outputs(): @@ -306,48 +316,56 @@ def infer_shape(self): class CubeComplexTrilMask(IRFwOperation): """ - Function to tril_mask + Inputs: + input: [N * num_heads, L, L] + num_head: int + + Returns: + output: [N * num_heads, L, L] """ def __init__(self, signature, inputs, name='trilmask', **kwargs): if len(inputs) != 2: raise TypeError("Expected 2 input") - tensor, num_head = inputs[0], inputs[1] + tensor, num_heads = inputs[0], inputs[1] super().__init__( name, signature, - input_length=2, + input_length=1, output_length=1 ) self.set_input(0, tensor) - self.set_input(1, num_head) + self.kwargs['num_heads'] = num_heads def infer_shape(self): if self.inputs(0).shape is None: return False - shape = copy.copy(self.inputs(0).shape) - self._outputs[0].shape = shape + self._outputs[0].shape = self.inputs(0).shape return True class CubeComplexAttnView(IRFwOperation): """ - Funtion to attention view + Inputs: + [N * num_heads, L, dim_head] + + Outputs: + [L, N, num_heads * dim_head] """ def __init__(self, signature, inputs, name='attn_view', **kwargs): if len(inputs) != 2: raise TypeError("Expected 2 input") - tensor, num_head = inputs[0], inputs[1] + tensor, num_heads = inputs[0], inputs[1] super().__init__( name, signature, - input_length=2, + input_length=1, output_length=1 ) self.set_input(0, tensor) - self.set_input(1, num_head) + self.kwargs['num_heads'] = num_heads def infer_shape(self): if self.inputs(0).shape is None: return False - num_heads = self.inputs(1) + num_heads = self.kwargs['num_heads'] bs = self.inputs(0).shape[0] // num_heads seqlen = self.inputs(0).shape[1] dim_head = self.inputs(0).shape[2] diff --git a/cube/ir/cten.py b/cube/ir/cten.py index df798df6..4e63318f 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -492,7 +492,7 @@ def __eq__(self, tensor): @property def shape(self): - return self._shape + return copy.copy(self._shape) @shape.setter def shape(self, val): @@ -501,7 +501,7 @@ def shape(self, val): if not isinstance(val, list) or \ not all([isinstance(size, int) for size in val]): raise RuntimeError("Expected shape to be list[int]") - self._shape = val + self._shape = copy.copy(list(val)) @property def device(self) -> List[int]: diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index e9237589..29707ae8 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -6,8 +6,11 @@ def toqkv(hidden_state: torch.Tensor, weight: torch.Tensor, num_heads: int): """ Inputs: - hidden_state: [L, N, E] (seqlen, batch size, num_heads * dim_head) - weight: [E * 3, E] + hidden_state: [L, N, E] + weight: [3 * (num_heads * dim_head), E] + num_heads: int + + where L = sequence length, N = batch size, E = num_heads * dim_head Returns: Q: [L, N * num_heads, dim_head] From 54e18793506e30deb38630ce49db2a515544ddbc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 20 Nov 2021 15:13:04 +0800 Subject: [PATCH 0340/1892] add bmm split --- cube/algorithm/factory.py | 6 + cube/algorithm/linear.py | 6 +- cube/algorithm/ops/bmm.py | 195 +++++++++++++++++++++ cube/graph/operator/function.py | 2 - tests/algorithm/test_bmm.py | 206 ++++++++++++++++++++++ tests/algorithm/test_complex.py | 298 ++++++++++++++++++++++++++++++++ 6 files changed, 708 insertions(+), 5 deletions(-) create mode 100644 cube/algorithm/ops/bmm.py create mode 100644 tests/algorithm/test_bmm.py create mode 100644 tests/algorithm/test_complex.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index e941787e..5fcc882a 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -82,4 +82,10 @@ def _load_predefined_algos(self): self.register(complex.CubeComplexAttnView, complex.CubeAttnViewDataParallel, tag='data') self.register(complex.CubeComplexAttnView, complex.CubeAttnViewHeadParallel, tag='head') + + import cube.algorithm.ops.bmm as bmm + self.register(bmm.BatchLinear, bmm.BatchLinearDataParallel, tag='data') + self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='n') + self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='m') + self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='p') \ No newline at end of file diff --git a/cube/algorithm/linear.py b/cube/algorithm/linear.py index bdd36313..b1fb7c22 100644 --- a/cube/algorithm/linear.py +++ b/cube/algorithm/linear.py @@ -21,9 +21,9 @@ def __init__(self, node: Linear): def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) input_shape = self.input_shapes[0] - if chunk_num > 0 and input_shape[0] % chunk_num != 0: - return False - return True + if chunk_num > 0 and input_shape[0] % chunk_num == 0: + return True + return False def instantiate(self, node, config: Dict): if not self.satisfy(config): diff --git a/cube/algorithm/ops/bmm.py b/cube/algorithm/ops/bmm.py new file mode 100644 index 00000000..dd25bd2c --- /dev/null +++ b/cube/algorithm/ops/bmm.py @@ -0,0 +1,195 @@ +from typing import Dict + +from cube.algorithm.utils import split_axis, split_value +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.operator.function import BatchLinear + + +_kWaitDecision = None + + +class BatchLinearDataParallel(GenericDistAlgo): + """ + Inputs: + input1: [B, N, M] + input2: [B, M, P] + + Outputs: + output: [B, N, P] + """ + + def __init__(self, node: BatchLinear): + + if not isinstance(node, BatchLinear): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + input_shape = self.input_shapes[0] + if chunk_num > 0 and input_shape[0] % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + input1, input2 = node.inputs() + output = node.outputs(0) + + in1s = split_axis(input1, 0, self.chunk_num) + in2s = split_axis(input2, 0, self.chunk_num) + outs = split_axis(output, 0, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + node = BatchLinear( + signature='torch.bmm', + inputs=[in1s[idx], in2s[idx]], + name='bmm' + ) + node.set_output(0, outs[idx]) + nodes.append(node) + return nodes + + +class BatchLinearNParallel(GenericDistAlgo): + """ + Inputs: + input1: [B, N, M] + input2: [B, M, P] + + Outputs: + output: [B, N, P] + """ + + def __init__(self, node: BatchLinear): + + if not isinstance(node, BatchLinear): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + input_shape = self.input_shapes[0] + if chunk_num > 0 and input_shape[1] % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + input1, input2 = node.inputs() + output = node.outputs(0) + + in1s = split_axis(input1, 1, self.chunk_num) + outs = split_axis(output, 1, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + node = BatchLinear( + signature='torch.bmm', + inputs=[in1s[idx], input2], + name='bmm' + ) + node.set_output(0, outs[idx]) + nodes.append(node) + return nodes + + +class BatchLinearMParallel(GenericDistAlgo): + """ + Inputs: + input1: [B, N, M] + input2: [B, M, P] + + Outputs: + output: [B, N, P] + """ + + def __init__(self, node: BatchLinear): + + if not isinstance(node, BatchLinear): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + input_shape = self.input_shapes[0] + if chunk_num > 0 and input_shape[2] % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + input1, input2 = node.inputs() + output = node.outputs(0) + + in1s = split_axis(input1, 2, self.chunk_num) + in2s = split_axis(input2, 1, self.chunk_num) + outs = split_value(output, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + node = BatchLinear( + signature='torch.bmm', + inputs=[in1s[idx], in2s[idx]], + name='bmm' + ) + node.set_output(0, outs[idx]) + nodes.append(node) + return nodes + + +class BatchLinearPParallel(GenericDistAlgo): + """ + Inputs: + input1: [B, N, M] + input2: [B, M, P] + + Outputs: + output: [B, N, P] + """ + + def __init__(self, node: BatchLinear): + + if not isinstance(node, BatchLinear): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + input_shape = self.input_shapes[1] + if chunk_num > 0 and input_shape[2] % chunk_num == 0: + return True + return False + + def instantiate(self, node, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + input1, input2 = node.inputs() + output = node.outputs(0) + + in2s = split_axis(input2, 2, self.chunk_num) + outs = split_axis(output, 2, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + node = BatchLinear( + signature='torch.bmm', + inputs=[input1, in2s[idx]], + name='bmm' + ) + node.set_output(0, outs[idx]) + nodes.append(node) + return nodes diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 340c2593..377e5e47 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -34,12 +34,10 @@ def infer_shape(self): class BatchLinear(IRFwOperation): """ Inputs: - input1: [B, N, M] input2: [B, M, P] Outputs: - output: [B, N, P] """ diff --git a/tests/algorithm/test_bmm.py b/tests/algorithm/test_bmm.py new file mode 100644 index 00000000..a8393294 --- /dev/null +++ b/tests/algorithm/test_bmm.py @@ -0,0 +1,206 @@ +import cube.algorithm.ops.bmm as bmm +from cube.graph.tensor import IRFullTensor, ValueMap + + +def test_bmm_data_parallel(): + + B = 64 # seq len + N = 256 # batch + M = 1024 # hiddend size = dim_head * num_head + P = 512 + input1 = IRFullTensor(shape=[B, N, M], name='hidden').tosub() + input2 = IRFullTensor(shape=[B, M, P], name='input2').tosub() + + semantic_op = bmm.BatchLinear( + signature='torch.bmm', + inputs = [input1, input2] + ) + semantic_op.infer_shape() + + bmm_dp = bmm.BatchLinearDataParallel(semantic_op) + + assert bmm_dp.chunk_num is None + + assert bmm_dp.satisfy(dict(chunk_num=8)) + assert not bmm_dp.satisfy(dict(chunk_num=9)) + + nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, bmm.BatchLinear) + + nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, bmm.BatchLinear) + + input1s = [node.inputs(0) for node in nodes] + print('inputs:') + for input in input1s: + print(input) + assert input.shape == [B // 4, N, M] + + input2s = [node.inputs(1) for node in nodes] + print('input2s:') + for input2 in input2s: + print(input2) + assert input2.shape == [B // 4, M, P] + + outputs = [node.outputs(0) for node in nodes] + for output in outputs: + print(output) + assert output.shape == [B // 4, N, P] + assert output.val_map == ValueMap(0, 1) + + +def test_bmm_n_parallel(): + + B = 64 # seq len + N = 256 # batch + M = 1024 # hiddend size = dim_head * num_head + P = 512 + input1 = IRFullTensor(shape=[B, N, M], name='hidden').tosub() + input2 = IRFullTensor(shape=[B, M, P], name='input2').tosub() + + semantic_op = bmm.BatchLinear( + signature='torch.bmm', + inputs = [input1, input2] + ) + semantic_op.infer_shape() + + bmm_dp = bmm.BatchLinearNParallel(semantic_op) + + assert bmm_dp.chunk_num is None + + assert bmm_dp.satisfy(dict(chunk_num=8)) + assert not bmm_dp.satisfy(dict(chunk_num=9)) + + nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, bmm.BatchLinear) + + nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, bmm.BatchLinear) + + input1s = [node.inputs(0) for node in nodes] + print('inputs:') + for input in input1s: + print(input) + assert input.shape == [B, N // 4, M] + + input2s = [node.inputs(1) for node in nodes] + print('input2s:') + for input2 in input2s: + print(input2) + assert input2.shape == [B, M, P] + + outputs = [node.outputs(0) for node in nodes] + for output in outputs: + print(output) + assert output.shape == [B, N // 4, P] + assert output.val_map == ValueMap(0, 1) + + +def test_bmm_m_parallel(): + + B = 64 + N = 256 + M = 1024 + P = 512 + input1 = IRFullTensor(shape=[B, N, M], name='input1').tosub() + input2 = IRFullTensor(shape=[B, M, P], name='input2').tosub() + + semantic_op = bmm.BatchLinear( + signature='torch.bmm', + inputs = [input1, input2] + ) + semantic_op.infer_shape() + + bmm_dp = bmm.BatchLinearMParallel(semantic_op) + + assert bmm_dp.chunk_num is None + + assert bmm_dp.satisfy(dict(chunk_num=8)) + assert not bmm_dp.satisfy(dict(chunk_num=9)) + + nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, bmm.BatchLinear) + + nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, bmm.BatchLinear) + + input1s = [node.inputs(0) for node in nodes] + print('inputs:') + for input in input1s: + print(input) + assert input.shape == [B, N, M // 4] + + input2s = [node.inputs(1) for node in nodes] + print('input2s:') + for input2 in input2s: + print(input2) + assert input2.shape == [B, M // 4, P] + + outputs = [node.outputs(0) for node in nodes] + for idx, output in enumerate(outputs): + print(output) + assert output.shape == [B, N, P] + assert output.val_map == ValueMap(idx, 4) + + +def test_bmm_p_parallel(): + + B = 64 # seq len + N = 256 # batch + M = 1024 # hiddend size = dim_head * num_head + P = 512 + input1 = IRFullTensor(shape=[B, N, M], name='hidden').tosub() + input2 = IRFullTensor(shape=[B, M, P], name='input2').tosub() + + semantic_op = bmm.BatchLinear( + signature='torch.bmm', + inputs = [input1, input2] + ) + semantic_op.infer_shape() + + bmm_dp = bmm.BatchLinearPParallel(semantic_op) + + assert bmm_dp.chunk_num is None + + assert bmm_dp.satisfy(dict(chunk_num=8)) + assert not bmm_dp.satisfy(dict(chunk_num=9)) + + nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, bmm.BatchLinear) + + nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, bmm.BatchLinear) + + input1s = [node.inputs(0) for node in nodes] + print('inputs:') + for input in input1s: + print(input) + assert input.shape == [B, N, M] + + input2s = [node.inputs(1) for node in nodes] + print('input2s:') + for input2 in input2s: + print(input2) + assert input2.shape == [B, M, P // 4] + + outputs = [node.outputs(0) for node in nodes] + for output in outputs: + print(output) + assert output.shape == [B, N, P // 4] + assert output.val_map == ValueMap(0, 1) \ No newline at end of file diff --git a/tests/algorithm/test_complex.py b/tests/algorithm/test_complex.py new file mode 100644 index 00000000..00041bd8 --- /dev/null +++ b/tests/algorithm/test_complex.py @@ -0,0 +1,298 @@ +import cube.algorithm.ops.complex as complex +from cube.graph.tensor import IRFullTensor, ValueMap + + +def test_complex_toqkv_data_parallel(): + + L = 64 # seq len + N = 16 # batch + E = 1024 # hiddend size = dim_head * num_head + num_heads = 8 + dim_head = E // num_heads + input = IRFullTensor(shape=[L, N, E], name='hidden').tosub() + weight = IRFullTensor(shape=[3 * E, E], name='weight').tosub() + + semantic_op = complex.CubeComplexToQKV( + signature='cube.runtime.function.complex.toqkv', + inputs = [input, weight, num_heads] + ) + semantic_op.infer_shape() + + qkv_dp = complex.CubeToQKVDataParallel(semantic_op) + + assert qkv_dp.chunk_num is None + + assert qkv_dp.satisfy(dict(chunk_num=8)) + assert not qkv_dp.satisfy(dict(chunk_num=32)) + + nodes = qkv_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexToQKV) + + inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) + assert input.shape == [L, N // 4, E] + weights = [node.inputs(1) for node in nodes] + + print('weights:') + for weight in weights: + print(weight) + assert weight.shape == [3 * E, E] + + sub_heads = [node.kwargs['num_heads'] for node in nodes] + print('num_heads:') + for nhead in sub_heads: + assert nhead == 8 + print(nhead) + + outputs = [node.outputs() for node in nodes] + print('outputs:') + for output in outputs: + q, k, v = output + print('q:', q) + print('k:', k) + print('v:', v) + assert q.shape == [L, N * num_heads // 4, dim_head] + assert k.shape == [L, N * num_heads // 4, dim_head] + assert v.shape == [L, N * num_heads // 4, dim_head] + + +def test_complex_toqkv_head_parallel(): + + L = 64 # seq len + N = 16 # batch + E = 1024 # hiddend size = dim_head * num_head + num_heads = 8 + dim_head = E // num_heads + input = IRFullTensor(shape=[L, N, E], name='hidden').tosub() + weight = IRFullTensor(shape=[3 * E, E], name='weight').tosub() + + semantic_op = complex.CubeComplexToQKV( + signature='cube.runtime.function.complex.toqkv', + inputs = [input, weight, num_heads] + ) + semantic_op.infer_shape() + + qkv_hp = complex.CubeToQKVHeadParallel(semantic_op) + + assert qkv_hp.chunk_num is None + + assert qkv_hp.satisfy(dict(chunk_num=8)) + assert not qkv_hp.satisfy(dict(chunk_num=32)) + + nodes = qkv_hp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexToQKV) + + inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) + assert input.shape == [L, N, E] + + weights = [node.inputs(1) for node in nodes] + print('weights:') + for weight in weights: + assert weight.shape == [3 * E // 4, E] + print(weight) + + sub_heads = [node.kwargs['num_heads'] for node in nodes] + print('sub_heads:') + for nhead in sub_heads: + assert nhead == num_heads // 4 + print(nhead) + + outputs = [node.outputs() for node in nodes] + print('outputs:') + for output in outputs: + q, k, v = output + print('q:', q) + print('k:', k) + print('v:', v) + assert q.shape == [L, N * num_heads // 4, dim_head] + assert k.shape == [L, N * num_heads // 4, dim_head] + assert v.shape == [L, N * num_heads // 4, dim_head] + + +def test_complex_tril_mask_data_parallel(): + + L = 64 # seq len + N = 16 # batch + num_heads = 8 + input = IRFullTensor(shape=[N * num_heads, L, L], name='hidden').tosub() + + semantic_op = complex.CubeComplexTrilMask( + signature = 'cube.runtime.function.complex.trill_mask', + inputs = [input, num_heads], + ) + semantic_op.infer_shape() + + mask_dp = complex.CubeTrilMaskDataParallel(semantic_op) + + assert mask_dp.chunk_num is None + + assert mask_dp.satisfy(dict(chunk_num=8)) + assert not mask_dp.satisfy(dict(chunk_num=32)) + + nodes = mask_dp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexTrilMask) + + inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) + assert input.shape == [N * num_heads // 4, L, L] + + sub_heads = [node.kwargs['num_heads'] for node in nodes] + print('num_heads:') + for nhead in sub_heads: + assert nhead == 8 + print(nhead) + + outputs = [node.outputs(0) for node in nodes] + print('outputs:') + for output in outputs: + print(output) + assert output.shape == [N * num_heads // 4, L, L] + + +def test_complex_tril_mask_head_parallel(): + + L = 64 # seq len + N = 16 # batch + num_heads = 8 + input = IRFullTensor(shape=[N * num_heads, L, L], name='hidden').tosub() + + semantic_op = complex.CubeComplexTrilMask( + signature = 'cube.runtime.function.complex.trill_mask', + inputs = [input, num_heads], + ) + semantic_op.infer_shape() + + mask_hp = complex.CubeTrilMaskHeadParallel(semantic_op) + + assert mask_hp.chunk_num is None + + assert mask_hp.satisfy(dict(chunk_num=8)) + assert not mask_hp.satisfy(dict(chunk_num=32)) + + nodes = mask_hp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexTrilMask) + + inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) + assert input.shape == [N * num_heads // 4, L, L] + + sub_heads = [node.kwargs['num_heads'] for node in nodes] + print('num_heads:') + for nhead in sub_heads: + assert nhead == num_heads // 4 + print(nhead) + + outputs = [node.outputs(0) for node in nodes] + print('outputs:') + for output in outputs: + print(output) + assert output.shape == [N * num_heads // 4, L, L] + + +def test_complex_attn_view_data_parallel(): + + L = 64 # seq len + N = 16 # batch + num_heads = 8 + dim_head = 128 + input = IRFullTensor( + shape=[N * num_heads, L, dim_head], name='hidden').tosub() + + semantic_op = complex.CubeComplexAttnView( + signature = 'cube.runtime.function.complex.trill_mask', + inputs = [input, num_heads], + ) + semantic_op.infer_shape() + + mask_hp = complex.CubeAttnViewDataParallel(semantic_op) + + assert mask_hp.chunk_num is None + + assert mask_hp.satisfy(dict(chunk_num=8)) + assert not mask_hp.satisfy(dict(chunk_num=32)) + + nodes = mask_hp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexAttnView) + + inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) + assert input.shape == [N * num_heads // 4, L, dim_head] + + sub_heads = [node.kwargs['num_heads'] for node in nodes] + print('num_heads:') + for nhead in sub_heads: + assert nhead == num_heads + print(nhead) + + outputs = [node.outputs(0) for node in nodes] + print('outputs:') + for output in outputs: + print(output) + assert output.shape == [L, N // 4, num_heads * dim_head] + + +def test_complex_attn_view_head_parallel(): + + L = 64 # seq len + N = 16 # batch + num_heads = 8 + dim_head = 128 + input = IRFullTensor( + shape=[N * num_heads, L, dim_head], name='hidden').tosub() + + semantic_op = complex.CubeComplexAttnView( + signature = 'cube.runtime.function.complex.trill_mask', + inputs = [input, num_heads], + ) + semantic_op.infer_shape() + + mask_hp = complex.CubeAttnViewHeadParallel(semantic_op) + + assert mask_hp.chunk_num is None + + assert mask_hp.satisfy(dict(chunk_num=8)) + assert not mask_hp.satisfy(dict(chunk_num=32)) + + nodes = mask_hp.instantiate(semantic_op, dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexAttnView) + + inputs = [node.inputs(0) for node in nodes] + print('inputs:') + for input in inputs: + print(input) + assert input.shape == [N * num_heads // 4, L, dim_head] + + sub_heads = [node.kwargs['num_heads'] for node in nodes] + print('num_heads:') + for nhead in sub_heads: + assert nhead == num_heads // 4 + print(nhead) + + outputs = [node.outputs(0) for node in nodes] + print('outputs:') + for output in outputs: + print(output) + assert output.shape == [L, N, num_heads * dim_head // 4] From 049c77358d2e303e3cadbd9de6e188b48641f8d2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 20 Nov 2021 15:15:52 +0800 Subject: [PATCH 0341/1892] restructure --- cube/algorithm/factory.py | 21 ++++++++++----------- cube/algorithm/{ => ops}/dataloader.py | 0 cube/algorithm/{ => ops}/elementwise.py | 0 cube/algorithm/{ => ops}/linear.py | 0 cube/algorithm/{ => ops}/reduce.py | 0 5 files changed, 10 insertions(+), 11 deletions(-) rename cube/algorithm/{ => ops}/dataloader.py (100%) rename cube/algorithm/{ => ops}/elementwise.py (100%) rename cube/algorithm/{ => ops}/linear.py (100%) rename cube/algorithm/{ => ops}/reduce.py (100%) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 5fcc882a..9f83c900 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -59,18 +59,24 @@ def algorithms(self, op, tag = None): def _load_predefined_algos(self): - import cube.algorithm.dataloader as dataloader + import cube.algorithm.ops.dataloader as dataloader self.register(dataloader.IRDataOperation, dataloader.DPDataLoader, tag='data') - import cube.algorithm.linear as linear + import cube.algorithm.ops.linear as linear self.register(linear.Linear, linear.LinearDataParallel, tag='data') self.register(linear.Linear, linear.LinearColumnWeight, tag='column') self.register(linear.Linear, linear.LinearRowWeight, tag='row') - import cube.algorithm.elementwise as elew + import cube.algorithm.ops.bmm as bmm + self.register(bmm.BatchLinear, bmm.BatchLinearDataParallel, tag='data') + self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='n') + self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='m') + self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='p') + + import cube.algorithm.ops.elementwise as elew self.register(elew.ElementWise, elew.ElementWiseDataParallel, tag='data') - import cube.algorithm.reduce as reduce + import cube.algorithm.ops.reduce as reduce self.register(reduce.Reduce, reduce.ReduceDataParallel, tag='data') import cube.algorithm.ops.complex as complex @@ -82,10 +88,3 @@ def _load_predefined_algos(self): self.register(complex.CubeComplexAttnView, complex.CubeAttnViewDataParallel, tag='data') self.register(complex.CubeComplexAttnView, complex.CubeAttnViewHeadParallel, tag='head') - - import cube.algorithm.ops.bmm as bmm - self.register(bmm.BatchLinear, bmm.BatchLinearDataParallel, tag='data') - self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='n') - self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='m') - self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='p') - \ No newline at end of file diff --git a/cube/algorithm/dataloader.py b/cube/algorithm/ops/dataloader.py similarity index 100% rename from cube/algorithm/dataloader.py rename to cube/algorithm/ops/dataloader.py diff --git a/cube/algorithm/elementwise.py b/cube/algorithm/ops/elementwise.py similarity index 100% rename from cube/algorithm/elementwise.py rename to cube/algorithm/ops/elementwise.py diff --git a/cube/algorithm/linear.py b/cube/algorithm/ops/linear.py similarity index 100% rename from cube/algorithm/linear.py rename to cube/algorithm/ops/linear.py diff --git a/cube/algorithm/reduce.py b/cube/algorithm/ops/reduce.py similarity index 100% rename from cube/algorithm/reduce.py rename to cube/algorithm/ops/reduce.py From bd40adf46901c3597b7c53fe47b59e707f688d3c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 20 Nov 2021 15:55:58 +0800 Subject: [PATCH 0342/1892] dim partition on elementwise op --- cube/algorithm/ops/elementwise.py | 37 ++++++++++++------- cube/graph/operator/function.py | 2 +- ...lementwise_algo.py => test_elementwise.py} | 20 ++++++++-- 3 files changed, 42 insertions(+), 17 deletions(-) rename tests/algorithm/{test_elementwise_algo.py => test_elementwise.py} (63%) diff --git a/cube/algorithm/ops/elementwise.py b/cube/algorithm/ops/elementwise.py index 66bd61b9..62dd21bc 100644 --- a/cube/algorithm/ops/elementwise.py +++ b/cube/algorithm/ops/elementwise.py @@ -2,37 +2,48 @@ from cube.algorithm.utils import split_axis from cube.algorithm.generics import GenericDistAlgo -from cube.graph.operator.function import ElementWise from cube.ir.cten import IRTensor +from cube.graph.operator.function import ElementWise + _kWaitDecision = None -class ElementWiseDataParallel(GenericDistAlgo): +class ElementWiseDimParallel(GenericDistAlgo): - def __init__(self, node: ElementWise): + def __init__(self, node: ElementWise, dim=None): if not isinstance(node, ElementWise): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) + self.ndim = len(node.inputs(0).shape) self.chunk_num = _kWaitDecision + self.dim = dim def satisfy(self, config: Dict): + if 'dim' in config: + dim = config['dim'] + else: + if self.dim is None: + raise RuntimeError("Expected dim in config") + dim = self.dim chunk_num = int(config['chunk_num']) - input_shape = self.input_shapes[0] - if chunk_num > 0 and input_shape[0] % chunk_num != 0: - return False - return True + shape = self.input_shapes[0] + if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: + return True + return False def instantiate(self, node, config: Dict): if not self.satisfy(config): raise RuntimeError("Instantiate failed. Condition not satisfied.") self.chunk_num = int(config['chunk_num']) + if 'dim' in config: + self.dim = config['dim'] sub_inputs = list() for input in node.inputs(): if isinstance(input, IRTensor): - sub_input = split_axis(input, 0, self.chunk_num) + sub_input = split_axis(input, self.dim, self.chunk_num) else: sub_input = [input] * self.chunk_num sub_inputs.append(sub_input) @@ -40,17 +51,17 @@ def instantiate(self, node, config: Dict): sub_outputs = list() for output in node.outputs(): if isinstance(output, IRTensor): - sub_output = split_axis(output, 0, self.chunk_num) + sub_output = split_axis(output, self.dim, self.chunk_num) else: sub_output = [output] * self.chunk_num sub_outputs.append(sub_output) nodes = list() for idx, sub_input in enumerate(zip(*sub_inputs)): - node = ElementWise(node.signature, inputs=sub_input, name=node.name) - nodes.append(node) + sub_node = ElementWise(node.signature, inputs=sub_input, name=node.name) + nodes.append(sub_node) for idx, sub_output in enumerate(zip(*sub_outputs)): - node = nodes[idx] + sub_node = nodes[idx] for idx, output in enumerate(sub_output): - node.set_output(idx, output) + sub_node.set_output(idx, output) return nodes diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 377e5e47..d8239af3 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -84,7 +84,7 @@ def __init__(self, signature, inputs, name='elementwise', **kwargs): raise TypeError(f"Expected 2 inputs but got {inputs}") super().__init__( name, signature, - input_length=len(inputs), + input_length=2, output_length=1 ) for idx, input in enumerate(inputs): diff --git a/tests/algorithm/test_elementwise_algo.py b/tests/algorithm/test_elementwise.py similarity index 63% rename from tests/algorithm/test_elementwise_algo.py rename to tests/algorithm/test_elementwise.py index 74e61fe3..3f40e174 100644 --- a/tests/algorithm/test_elementwise_algo.py +++ b/tests/algorithm/test_elementwise.py @@ -1,9 +1,9 @@ from cube.graph.operator.function import ElementWise -from cube.algorithm.elementwise import ElementWiseDataParallel +from cube.algorithm.ops.elementwise import ElementWiseDimParallel from cube.graph.tensor import IRFullTensor -def test_elementwise_data_parallel(): +def test_elementwise_dim_parallel(): input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() input2 = IRFullTensor(shape=[1024, 1024], name='input2').tosub() @@ -15,7 +15,7 @@ def test_elementwise_data_parallel(): print('semantic op:') print(semantic_op) - op_dp = ElementWiseDataParallel(semantic_op) + op_dp = ElementWiseDimParallel(semantic_op, dim=0) assert op_dp.chunk_num is None @@ -40,3 +40,17 @@ def test_elementwise_data_parallel(): print(output) assert output.shape == [256, 1024] + op_dp = ElementWiseDimParallel(semantic_op, dim=1) + nodes = op_dp.instantiate(semantic_op, dict(chunk_num=4)) + + for node in nodes: + print('=======') + print(node) + print('inputs:') + for input in node.inputs(): + print(input) + assert input.shape == [1024, 256] + print('outputs:') + for output in node.outputs(): + print(output) + assert output.shape == [1024, 256] From f2a8e7ecea57274acaa06f371767932e1b34b71d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 20 Nov 2021 23:40:15 +0800 Subject: [PATCH 0343/1892] update elementwise ops --- cube/algorithm/factory.py | 8 +- cube/algorithm/ops/activation.py | 115 ++++++++++++++++++++++++++++ cube/algorithm/ops/elementwise.py | 31 +++++++- cube/graph/operator/function.py | 90 ++++++++++++---------- tests/algorithm/test_elementwise.py | 40 +++++++++- 5 files changed, 237 insertions(+), 47 deletions(-) create mode 100644 cube/algorithm/ops/activation.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 9f83c900..b42860df 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -74,7 +74,13 @@ def _load_predefined_algos(self): self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='p') import cube.algorithm.ops.elementwise as elew - self.register(elew.ElementWise, elew.ElementWiseDataParallel, tag='data') + self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') + self.register(elew.Add, elew.AddDimParallel, tag='dim') + + import cube.algorithm.ops.activation as activation + self.register(activation.Activation, activation.ActivationDimParallel, tag='dim') + self.register(activation.Dropout, activation.DropoutDimParallel, tag='dim') + self.register(activation.Softmax, activation.SoftmaxDimParallel, tag ='dim') import cube.algorithm.ops.reduce as reduce self.register(reduce.Reduce, reduce.ReduceDataParallel, tag='data') diff --git a/cube/algorithm/ops/activation.py b/cube/algorithm/ops/activation.py new file mode 100644 index 00000000..1468648f --- /dev/null +++ b/cube/algorithm/ops/activation.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, List +import copy + +from cube.algorithm.utils import split_axis +from cube.algorithm.generics import GenericDistAlgo +from cube.ir.cten import IRTensor + +from cube.graph.operator.function import Activation +from cube.graph.operator.function import Dropout +from cube.graph.operator.function import Softmax + + +_kWaitDecision = None + + +class ActivationDimParallel(GenericDistAlgo): + + def __init__(self, node: Activation, dim=None, execlude_dims=None): + if not isinstance(node, Activation): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.ndim = len(node.inputs(0).shape) + self.chunk_num = _kWaitDecision + self.dim = dim + # stay dim convert to positive dim + self.stay_dims = list() + for sdim in node.stay_dims: + sdim = sdim if sdim >= 0 else self.ndim + sdim + self.stay_dims.append(sdim) + + def satisfy(self, config: Dict): + if 'dim' in config: + dim = config['dim'] + else: + if self.dim is None: + raise RuntimeError("Expected dim in config") + dim = self.dim + if dim < 0: + dim = self.ndim + dim + chunk_num = int(config['chunk_num']) + if dim in self.stay_dims: + return False + shape = self.input_shapes[0] + if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: + return True + return False + + def get_extra_kwargs(self, node) -> List[Any]: + """ + Get extra kwarg inputs for the activation + + Returns: + value in List + """ + return [] + + def instantiate(self, node: Activation, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + if 'dim' in config: + self.dim = config['dim'] + + sub_inputs = list() + for input in node.inputs(): + if isinstance(input, IRTensor): + sub_input = split_axis(input, self.dim, self.chunk_num) + else: + sub_input = [input] * self.chunk_num + sub_inputs.append(sub_input) + + sub_outputs = list() + for output in node.outputs(): + if isinstance(output, IRTensor): + sub_output = split_axis(output, self.dim, self.chunk_num) + else: + sub_output = [output] * self.chunk_num + sub_outputs.append(sub_output) + + nodes = list() + for idx, sub_input in enumerate(zip(*sub_inputs)): + extra_input = self.get_extra_kwargs(node) + sub_input += extra_input + sub_node = type(node)(node.signature, inputs=sub_input, name=node.name) + sub_node.stay_dims = node.stay_dims + nodes.append(sub_node) + for idx, sub_output in enumerate(zip(*sub_outputs)): + sub_node = nodes[idx] + for idx, output in enumerate(sub_output): + sub_node.set_output(idx, output) + return nodes + + +class DropoutDimParallel(ActivationDimParallel): + + def __init__(self, node: Activation, dim=None, execlude_dims=None): + super().__init__(node, dim=dim, execlude_dims=execlude_dims) + + def get_extra_kwargs(self, node: Dropout) -> List[Any]: + if not isinstance(node, Dropout): + raise TypeError("Expected Dropout for DropoutDimParallel") + kwargs = [node.kwargs['p'], node.kwargs['training'], node.kwargs['inplace']] + return kwargs + + +class SoftmaxDimParallel(ActivationDimParallel): + + def __init__(self, node: Activation, dim=None, execlude_dims=None): + super().__init__(node, dim=dim, execlude_dims=execlude_dims) + + def get_extra_kwargs(self, node) -> List[Any]: + if not isinstance(node, Softmax): + raise TypeError("Expected Softmax for SoftmaxDimParallel") + kwargs = [node.kwargs['dim'], node.kwargs['_stacklevel'], node.kwargs['dtype']] + return kwargs diff --git a/cube/algorithm/ops/elementwise.py b/cube/algorithm/ops/elementwise.py index 62dd21bc..2c83faab 100644 --- a/cube/algorithm/ops/elementwise.py +++ b/cube/algorithm/ops/elementwise.py @@ -1,10 +1,12 @@ -from typing import Dict +from typing import Any, List, Dict +import copy from cube.algorithm.utils import split_axis from cube.algorithm.generics import GenericDistAlgo from cube.ir.cten import IRTensor from cube.graph.operator.function import ElementWise +from cube.graph.operator.function import Add _kWaitDecision = None @@ -27,12 +29,23 @@ def satisfy(self, config: Dict): if self.dim is None: raise RuntimeError("Expected dim in config") dim = self.dim + if dim < 0: + dim = self.ndim + dim chunk_num = int(config['chunk_num']) shape = self.input_shapes[0] if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: return True return False + def get_extra_kwargs(self, node: ElementWise) -> List[Any]: + """ + Get extra kwarg inputs for the activation + + Returns: + value in List + """ + return [] + def instantiate(self, node, config: Dict): if not self.satisfy(config): raise RuntimeError("Instantiate failed. Condition not satisfied.") @@ -58,10 +71,24 @@ def instantiate(self, node, config: Dict): nodes = list() for idx, sub_input in enumerate(zip(*sub_inputs)): - sub_node = ElementWise(node.signature, inputs=sub_input, name=node.name) + print(sub_input) + sub_input = list(sub_input) + self.get_extra_kwargs(node) + sub_node = type(node)(node.signature, inputs=sub_input, name=node.name) + sub_node.kwargs = copy.copy(node.kwargs) nodes.append(sub_node) for idx, sub_output in enumerate(zip(*sub_outputs)): sub_node = nodes[idx] for idx, output in enumerate(sub_output): sub_node.set_output(idx, output) return nodes + + +class AddDimParallel(ElementWiseDimParallel): + + def __init__(self, node: ElementWise, dim=None): + super().__init__(node, dim=dim) + + def get_extra_kwargs(self, node: Add) -> List[Any]: + if not isinstance(node, Add): + raise TypeError("Expected Add for AddDimParallel") + return [node.kwargs['alpha']] diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index d8239af3..1f8b1db4 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -1,4 +1,5 @@ import copy +from typing import Type from cube.graph.operator import IRFwOperation from cube.ir.cten import IRTensor @@ -117,35 +118,69 @@ def __init__(self, signature, inputs, name='add', **kwargs): ) super().__init__(signature, inputs[:2], name=name) alpha = inputs[2] - if alpha != 1: - self.kwargs['alpha'] = alpha + self.kwargs['alpha'] = alpha -class ElementWiseActivation(IRFwOperation): +class Activation(IRFwOperation): """ functions like GELU, RELU, Dropout. Exclude softmax """ - def __init__(self, signature, inputs, name='elementwise_activation', **kwargs): + def __init__(self, signature, inputs, name='activation', **kwargs): + + if len(inputs) != 1: + raise TypeError("Expected single tensor input") super().__init__( name, signature, - input_length=len(inputs), + input_length=1, output_length=1 ) - for idx, input in enumerate(inputs): - self.set_input(idx, input) + self.set_input(0, inputs[0]) + # this is for partitioning indicator + self.stay_dims = list() def infer_shape(self): - for input in self.inputs(): - if isinstance(input, IRTensor): - if len(input.shape) != 0: - self._outputs[0].shape = copy.copy(input.shape) - return True - return False - return False + input = self.inputs(0) + if input.shape is None: + return False + self._outputs[0].shape = input.shape + return True + + +class Dropout(Activation): + """ + torch.nn.functional.dropout + """ + def __init__(self, signature, inputs, name='dropout', **kwargs): + + if len(inputs) != 4: + raise TypeError(f"Expected 4 inputs but got {inputs}") + super().__init__(signature, [inputs[0]], name) + self.set_input(0, inputs[0]) + self.kwargs['p'] = inputs[1] + self.kwargs['training'] = inputs[2] + self.kwargs['inplace'] = inputs[3] + + +class Softmax(Activation): + + def __init__(self, signature, inputs, name='softmax', **kwargs): + + if len(inputs) != 4: + raise TypeError(f"Expected 4 inputs, but got: {inputs}") + + tensor, dim, stacklevel, dtype = inputs[0], inputs[1], inputs[2], inputs[3] + super().__init__( + name, signature, input_length=1, output_length=1 + ) + self.set_input(0, tensor) + self.kwargs['dim'] = dim + self.kwargs['_stacklevel'] = stacklevel + self.kwargs['dtype'] = dtype + self.stay_dims.append(dim) class Reduce(IRFwOperation): @@ -209,33 +244,6 @@ def infer_shape(self): return True -class Softmax(IRFwOperation): - - def __init__(self, signature, inputs, name='softmax', **kwargs): - - if len(inputs) != 4: - raise TypeError(f"Expected 4 inputs, but got: {inputs}") - - tensor, dim, stacklevel, dtype = inputs[0], inputs[1], inputs[2], inputs[3] - super().__init__( - name, signature, input_length=1, output_length=1 - ) - self.set_input(0, tensor) - self.kwargs['dim'] = dim - self.kwargs['_stacklevel'] = stacklevel - self.kwargs['dtype'] = dtype - - def infer_shape(self): - if self.inputs(0).shape is None: - return False - dim = self.kwargs['dim'] - shape = [ - nele for idx, nele in enumerate(self.inputs(0).shape) if idx != dim - ] - self._outputs[0].shape = shape - return True - - class Transpose(IRFwOperation): """ torch.transpose diff --git a/tests/algorithm/test_elementwise.py b/tests/algorithm/test_elementwise.py index 3f40e174..f41d1b67 100644 --- a/tests/algorithm/test_elementwise.py +++ b/tests/algorithm/test_elementwise.py @@ -1,5 +1,5 @@ from cube.graph.operator.function import ElementWise -from cube.algorithm.ops.elementwise import ElementWiseDimParallel +import cube.algorithm.ops.elementwise as elew from cube.graph.tensor import IRFullTensor @@ -15,7 +15,7 @@ def test_elementwise_dim_parallel(): print('semantic op:') print(semantic_op) - op_dp = ElementWiseDimParallel(semantic_op, dim=0) + op_dp = elew.ElementWiseDimParallel(semantic_op, dim=0) assert op_dp.chunk_num is None @@ -40,7 +40,7 @@ def test_elementwise_dim_parallel(): print(output) assert output.shape == [256, 1024] - op_dp = ElementWiseDimParallel(semantic_op, dim=1) + op_dp = elew.ElementWiseDimParallel(semantic_op, dim=1) nodes = op_dp.instantiate(semantic_op, dict(chunk_num=4)) for node in nodes: @@ -54,3 +54,37 @@ def test_elementwise_dim_parallel(): for output in node.outputs(): print(output) assert output.shape == [1024, 256] + + +def test_add_dim_parallel(): + + input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() + input2 = IRFullTensor(shape=[1024, 1024], name='input2').tosub() + alpha = 1.0 + + semantic_op = elew.Add( + signature='torch.add', inputs=[input1, input2, alpha], name='add' + ) + semantic_op.infer_shape() + + dim_op = elew.AddDimParallel(semantic_op) + + assert dim_op.dim is None + assert dim_op.chunk_num is None + + assert dim_op.satisfy(config=dict(dim=1, chunk_num=4)) + assert dim_op.satisfy(config=dict(dim=-1, chunk_num=4)) + assert dim_op.satisfy(config=dict(dim=0, chunk_num=4)) + assert not dim_op.satisfy(config=dict(dim=2, chunk_num=4)) + + nodes = dim_op.instantiate(semantic_op, dict(dim=0, chunk_num=4)) + for node in nodes: + print(node) + assert isinstance(node, elew.Add) + for input in node.inputs(): + print(input) + assert input.shape == [1024 // 4, 1024] + for output in node.outputs(): + print(output) + assert output.shape == [1024 // 4, 1024] + assert node.kwargs == semantic_op.kwargs From 2b22ece0849959a3f45f1597243f9bfb66280af5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 20 Nov 2021 23:52:09 +0800 Subject: [PATCH 0344/1892] add dropout, softmax activations --- cube/algorithm/ops/activation.py | 2 +- cube/graph/operator/function.py | 4 +- tests/algorithm/test_activation.py | 75 ++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 tests/algorithm/test_activation.py diff --git a/cube/algorithm/ops/activation.py b/cube/algorithm/ops/activation.py index 1468648f..160ff9f2 100644 --- a/cube/algorithm/ops/activation.py +++ b/cube/algorithm/ops/activation.py @@ -80,7 +80,7 @@ def instantiate(self, node: Activation, config: Dict): nodes = list() for idx, sub_input in enumerate(zip(*sub_inputs)): extra_input = self.get_extra_kwargs(node) - sub_input += extra_input + sub_input = list(sub_input) + extra_input sub_node = type(node)(node.signature, inputs=sub_input, name=node.name) sub_node.stay_dims = node.stay_dims nodes.append(sub_node) diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 1f8b1db4..b0edfd4a 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -173,9 +173,7 @@ def __init__(self, signature, inputs, name='softmax', **kwargs): raise TypeError(f"Expected 4 inputs, but got: {inputs}") tensor, dim, stacklevel, dtype = inputs[0], inputs[1], inputs[2], inputs[3] - super().__init__( - name, signature, input_length=1, output_length=1 - ) + super().__init__(signature, inputs=[inputs[0]], name=name) self.set_input(0, tensor) self.kwargs['dim'] = dim self.kwargs['_stacklevel'] = stacklevel diff --git a/tests/algorithm/test_activation.py b/tests/algorithm/test_activation.py new file mode 100644 index 00000000..4a5e9dfb --- /dev/null +++ b/tests/algorithm/test_activation.py @@ -0,0 +1,75 @@ +import cube.algorithm.ops.activation as activation +from cube.graph.operator.function import Dropout +from cube.graph.tensor import IRFullTensor + + +def test_softmax_dim_parallel(): + + input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() + dim = -1 + stacklevel = 3 + dtype = None + + semantic_op = activation.Softmax( + signature = 'torch.nn.functional.softmax', + inputs = [input1, dim, stacklevel, dtype], + ) + semantic_op.infer_shape() + + op_dim = activation.SoftmaxDimParallel(semantic_op) + assert op_dim.dim is None + assert op_dim.chunk_num is None + + assert op_dim.satisfy(dict(dim=0, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) + assert not op_dim.satisfy(dict(dim=1, chunk_num=4)) + assert not op_dim.satisfy(dict(dim=-1, chunk_num=4)) + + nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) + for node in nodes: + print(node) + assert isinstance(node, activation.Softmax) + for input in node.inputs(): + print(input) + assert input.shape == [1024 // 4, 1024] + for output in node.outputs(): + print(output) + assert output.shape == [1024 // 4, 1024] + assert node.kwargs == semantic_op.kwargs + assert node.stay_dims == semantic_op.stay_dims + + +def test_dropout_dim_parallel(): + + input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() + p = 0.5 + training = True + inplace = False + + semantic_op = activation.Dropout( + signature = 'torch.nn.functional.softmax', + inputs = [input1, p, training, inplace], + ) + semantic_op.infer_shape() + + op_dim = activation.DropoutDimParallel(semantic_op) + assert op_dim.dim is None + assert op_dim.chunk_num is None + + assert op_dim.satisfy(dict(dim=0, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) + assert op_dim.satisfy(dict(dim=1, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-1, chunk_num=4)) + + nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) + for node in nodes: + print(node) + assert isinstance(node, activation.Dropout) + for input in node.inputs(): + print(input) + assert input.shape == [1024 // 4, 1024] + for output in node.outputs(): + print(output) + assert output.shape == [1024 // 4, 1024] + assert node.kwargs == semantic_op.kwargs + assert node.stay_dims == semantic_op.stay_dims From 6f8416e83bae24568297db14040b0d92ce198cf6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 00:53:37 +0800 Subject: [PATCH 0345/1892] add reduce sum partition --- cube/algorithm/factory.py | 2 +- cube/algorithm/ops/activation.py | 4 +- cube/algorithm/ops/reduce.py | 86 ++++++++++++++++++----------- cube/graph/operator/function.py | 51 +++++++++-------- tests/algorithm/test_reduce.py | 68 +++++++++++++++++++++++ tests/algorithm/test_reduce_algo.py | 41 -------------- 6 files changed, 150 insertions(+), 102 deletions(-) create mode 100644 tests/algorithm/test_reduce.py delete mode 100644 tests/algorithm/test_reduce_algo.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index b42860df..a8864298 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -83,7 +83,7 @@ def _load_predefined_algos(self): self.register(activation.Softmax, activation.SoftmaxDimParallel, tag ='dim') import cube.algorithm.ops.reduce as reduce - self.register(reduce.Reduce, reduce.ReduceDataParallel, tag='data') + self.register(reduce.Sum, reduce.SumDimParallel, tag='dim') import cube.algorithm.ops.complex as complex self.register(complex.CubeComplexToQKV, complex.CubeToQKVDataParallel, tag='data') diff --git a/cube/algorithm/ops/activation.py b/cube/algorithm/ops/activation.py index 160ff9f2..7ed73103 100644 --- a/cube/algorithm/ops/activation.py +++ b/cube/algorithm/ops/activation.py @@ -15,7 +15,7 @@ class ActivationDimParallel(GenericDistAlgo): - def __init__(self, node: Activation, dim=None, execlude_dims=None): + def __init__(self, node: Activation, dim=None): if not isinstance(node, Activation): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) @@ -82,7 +82,7 @@ def instantiate(self, node: Activation, config: Dict): extra_input = self.get_extra_kwargs(node) sub_input = list(sub_input) + extra_input sub_node = type(node)(node.signature, inputs=sub_input, name=node.name) - sub_node.stay_dims = node.stay_dims + sub_node.stay_dims = copy.copy(node.stay_dims) nodes.append(sub_node) for idx, sub_output in enumerate(zip(*sub_outputs)): sub_node = nodes[idx] diff --git a/cube/algorithm/ops/reduce.py b/cube/algorithm/ops/reduce.py index 98caa652..c9ebe54a 100644 --- a/cube/algorithm/ops/reduce.py +++ b/cube/algorithm/ops/reduce.py @@ -1,55 +1,77 @@ from typing import Dict +import copy from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo -from cube.graph.operator.function import Reduce -from cube.ir.cten import IRTensor + +from cube.graph.operator.function import Sum + _kWaitDecision = None -class ReduceDataParallel(GenericDistAlgo): +class SumDimParallel(GenericDistAlgo): - def __init__(self, node: Reduce): - if not isinstance(node, Reduce): + def __init__(self, node: Sum, dim=None): + if not isinstance(node, Sum): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) + self.ndim = len(node.inputs(0).shape) + self.reduce_dims = list(range(self.ndim)) + self.keepdim = [False] * self.ndim + if 'dim' in node.kwargs: + self.reduce_dims = [node.kwargs['dim']] + if 'keepdim' in node.kwargs: + self.keepdim = [node.kwargs['keepdim']] * self.ndim + self.chunk_num = _kWaitDecision + if dim is not None: + dim = self.ndim + dim if dim < 0 else dim + self.dim = dim def satisfy(self, config: Dict): + if 'dim' in config: + dim = config['dim'] + else: + if self.dim is None: + raise RuntimeError("Expected dim in config") + dim = self.dim + if dim < 0: + dim = self.ndim + dim chunk_num = int(config['chunk_num']) - input_shape = self.input_shapes[0] - if chunk_num > 0 and input_shape[0] % chunk_num != 0: - return False - return True + shape = self.input_shapes[0] + if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: + return True + return False - def instantiate(self, node, config: Dict): + def instantiate(self, node: Sum, config: Dict): if not self.satisfy(config): raise RuntimeError("Instantiate failed. Condition not satisfied.") self.chunk_num = int(config['chunk_num']) + if 'dim' in config: + self.dim = config['dim'] + self.dim = self.ndim + self.dim if self.dim < 0 else self.dim + + assert len(node.inputs()) == 1 + input = node.inputs(0) + sub_inputs = split_axis(input, self.dim, self.chunk_num) - sub_inputs = list() - for input in node.inputs(): - if isinstance(input, IRTensor): - sub_input = split_axis(input, 0, self.chunk_num) - else: - sub_input = [input] * self.chunk_num - sub_inputs.append(sub_input) - - sub_outputs = list() - for output in node.outputs(): - if isinstance(output, IRTensor): - sub_output = split_value(output, self.chunk_num) - else: - sub_output = [output] * self.chunk_num - sub_outputs.append(sub_output) + assert len(node.outputs()) == 1 + output = node.outputs(0) + print(self.reduce_dims) + if self.dim not in self.reduce_dims: + sub_outputs = split_axis(output, self.dim, self.chunk_num) + else: + sub_outputs = split_value(output, self.chunk_num) nodes = list() - for idx, sub_input in enumerate(zip(*sub_inputs)): - node = Reduce(node.signature, inputs=sub_input, name=node.name) - nodes.append(node) - for idx, sub_output in enumerate(zip(*sub_outputs)): - node = nodes[idx] - for oidx, output in enumerate(sub_output): - node.set_output(oidx, output) + if 'dim' in node.kwargs: + dim = node.kwargs['dim'] + else: + dim = None + for input, output in zip(sub_inputs, sub_outputs): + sub_node = type(node)(node.signature, inputs=[input, dim], name=node.name) + sub_node.kwargs = copy.copy(node.kwargs) + sub_node.set_output(0, output) + nodes.append(sub_node) return nodes diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index b0edfd4a..df223826 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -66,6 +66,7 @@ def infer_shape(self): self._outputs[0].shape = shape return True +# ============================= Elementwise ============================ class ElementWise(IRFwOperation): """ @@ -120,6 +121,7 @@ def __init__(self, signature, inputs, name='add', **kwargs): alpha = inputs[2] self.kwargs['alpha'] = alpha +# ============================= Activation ============================ class Activation(IRFwOperation): """ @@ -180,24 +182,7 @@ def __init__(self, signature, inputs, name='softmax', **kwargs): self.kwargs['dtype'] = dtype self.stay_dims.append(dim) - -class Reduce(IRFwOperation): - """ - functions like sum, mean, cross_entropy - """ - def __init__(self, signature, inputs, name='reduce', **kwargs): - super().__init__( - name, signature, - input_length=len(inputs), - output_length=1 - ) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - - def infer_shape(self): - self._outputs[0].shape = [1] - return True - +# ===================== Loss Computation (Reduce) ========================= class Sum(IRFwOperation): """ @@ -226,21 +211,35 @@ def __init__(self, signature, inputs, name='sum', **kwargs): def infer_shape(self): if self.inputs(0).shape is None: return False - shape = list() + + # change dim to positive value + ndim = len(self.inputs(0).shape) if 'dim' in self.kwargs: - dim = [self.kwargs['dim']] + dim = self.kwargs['dim'] + dim = ndim + dim if dim < 0 else dim + self.kwargs['dim'] = dim + reduce_dims = [dim] + else: + reduce_dims = list(range(ndim)) + + if 'keepdim' in self.kwargs: keepdim = self.kwargs['keepdim'] - for idx, nele in enumerate(self.inputs(0).shape): - if idx in dim: - if not keepdim: - continue - nele = 1 - shape.append(nele) else: + keepdim = False + + shape = list() + for dim, nele in enumerate(self.inputs(0).shape): + if dim in reduce_dims: + if keepdim: + shape.append(1) + else: + shape.append(nele) + if len(shape) == 0: shape = [1] self._outputs[0].shape = shape return True +# ========================= Memory Operation ========================== class Transpose(IRFwOperation): """ diff --git a/tests/algorithm/test_reduce.py b/tests/algorithm/test_reduce.py new file mode 100644 index 00000000..a4df3d2e --- /dev/null +++ b/tests/algorithm/test_reduce.py @@ -0,0 +1,68 @@ +import cube.algorithm.ops.reduce as reduce +from cube.graph.tensor import IRFullTensor, ValueMap + + +def test_reduce_dim_parallel(): + + input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() + dim = None + + semantic_op = reduce.Sum( + signature='torch.sum', inputs=[input1, dim], name='add' + ) + semantic_op.infer_shape() + print('semantic op:') + print(semantic_op) + + op_dim = reduce.SumDimParallel(semantic_op) + assert op_dim.dim is None + assert op_dim.chunk_num is None + + # test satisfy + assert op_dim.satisfy(dict(dim=0, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) + assert op_dim.satisfy(dict(dim=1, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-1, chunk_num=4)) + + nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, reduce.Sum) + + for idx, node in enumerate(nodes): + print('=======') + print(node) + print('inputs:') + for input in node.inputs(): + print(input) + assert input.shape == [256, 1024] + print('outputs:') + for output in node.outputs(): + print(output) + assert output.shape == [1] + assert output.val_map == ValueMap(idx, 4) + + + dim = 1 + semantic_op = reduce.Sum( + signature='torch.sum', inputs=[input1, dim], name='add' + ) + semantic_op.infer_shape() + assert op_dim.satisfy(dict(dim=0, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) + assert op_dim.satisfy(dict(dim=1, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-1, chunk_num=4)) + + op_dim = reduce.SumDimParallel(semantic_op) + nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) + for idx, node in enumerate(nodes): + print(node) + print('inputs:') + for input in node.inputs(): + print(input) + assert input.shape == [256, 1024] + print('outputs:') + for output in node.outputs(): + print(output) + assert output.shape == [256] + assert output.val_map == ValueMap(0, 1) diff --git a/tests/algorithm/test_reduce_algo.py b/tests/algorithm/test_reduce_algo.py deleted file mode 100644 index 5f009150..00000000 --- a/tests/algorithm/test_reduce_algo.py +++ /dev/null @@ -1,41 +0,0 @@ -from cube.graph.operator.function import Reduce -from cube.algorithm.reduce import ReduceDataParallel -from cube.graph.tensor import IRFullTensor, ValueMap - - -def test_elementwise_data_parallel(): - - input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() - - semantic_op = Reduce( - signature='torch.sum', inputs=[input1], name='add' - ) - semantic_op.infer_shape() - print('semantic op:') - print(semantic_op) - - op_dp = ReduceDataParallel(semantic_op) - - assert op_dp.chunk_num is None - - # test satisfy - assert op_dp.satisfy(dict(chunk_num = 4)) - assert not op_dp.satisfy(dict(chunk_num = 10)) - - nodes = op_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, Reduce) - - for idx, node in enumerate(nodes): - print('=======') - print(node) - print('inputs:') - for input in node.inputs(): - print(input) - assert input.shape == [256, 1024] - print('outputs:') - for output in node.outputs(): - print(output) - assert output.shape == [1] - assert output.val_map == ValueMap(idx, 4) From ce48e3a4f68e72b6016efcaaa0fe63bf6475e34b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 00:55:02 +0800 Subject: [PATCH 0346/1892] update mapping --- cube/graph/parser/mapping.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index ceb49795..09c09eac 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -39,19 +39,19 @@ def map(signature: str) -> IRFwOperation: __ftemplate('softmax') : function.Softmax, - __ftemplate('dropout') : partial(function.ElementWiseActivation, name='dropout'), + __ftemplate('dropout') : function.Dropout, - __ftemplate('gelu') : partial(function.ElementWiseActivation, name='gelu'), + __ftemplate('gelu') : partial(function.Activation, name='gelu'), # torch aten - __ttemplate('add') : partial(function.Add, name='add'), + __ttemplate('add') : function.Add, __ttemplate('mul') : partial(function.ElementWise, name='mul'), __ttemplate('bmm') : function.BatchLinear, - __ttemplate('sum') : partial(function.Sum, name='sum'), + __ttemplate('sum') : function.Sum, __ttemplate('transpose') : function.Transpose, From 632517fc929a4cdff1fbdd65028956c5a4e94351 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 01:17:01 +0800 Subject: [PATCH 0347/1892] add transpose partition --- cube/algorithm/factory.py | 3 ++ cube/algorithm/ops/memory.py | 68 +++++++++++++++++++++++++++++++++ cube/algorithm/ops/reduce.py | 1 - cube/graph/operator/function.py | 15 ++++++-- tests/algorithm/test_memory.py | 46 ++++++++++++++++++++++ 5 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 cube/algorithm/ops/memory.py create mode 100644 tests/algorithm/test_memory.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index a8864298..6c23d07a 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -94,3 +94,6 @@ def _load_predefined_algos(self): self.register(complex.CubeComplexAttnView, complex.CubeAttnViewDataParallel, tag='data') self.register(complex.CubeComplexAttnView, complex.CubeAttnViewHeadParallel, tag='head') + + import cube.algorithm.ops.memory as mem + self.register(mem.Transpose, mem.TransposeDimParallel, tag='dim') diff --git a/cube/algorithm/ops/memory.py b/cube/algorithm/ops/memory.py new file mode 100644 index 00000000..4a5493be --- /dev/null +++ b/cube/algorithm/ops/memory.py @@ -0,0 +1,68 @@ +from typing import Dict +import copy + +from cube.algorithm.utils import split_axis +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.operator.function import Transpose + + +_kWaitDecision = None + + +class TransposeDimParallel(GenericDistAlgo): + + def __init__(self, node: Transpose, dim=None): + if not isinstance(node, Transpose): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + + self.dim0 = node.kwargs['dim0'] + self.dim1 = node.kwargs['dim1'] + self.ndim = len(node.inputs(0).shape) + + # config + self.chunk_num = _kWaitDecision + self.dim = dim + + def satisfy(self, config: Dict): + if 'dim' in config: + dim = config['dim'] + dim = self.ndim + dim if dim < 0 else dim + else: + if self.dim is None: + raise RuntimeError("Expected dim in config") + dim = self.dim + chunk_num = int(config['chunk_num']) + shape = self.input_shapes[0] + if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: + return True + return False + + def instantiate(self, node: Transpose, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + if 'dim' in config: + self.dim = config['dim'] + + input = node.inputs(0) + sub_inputs = split_axis(input, self.dim, self.chunk_num) + + output = node.outputs(0) + target_dim = self.dim + if self.dim == self.dim0: + target_dim = self.dim1 + if self.dim == self.dim1: + target_dim = self.dim0 + sub_outputs = split_axis(output, target_dim, self.chunk_num) + + nodes = list() + for input, output in zip(sub_inputs, sub_outputs): + sub_node = type(node)( + node.signature, inputs=[input, self.dim0, self.dim1], name=node.name + ) + sub_node.kwargs = copy.copy(node.kwargs) + sub_node.set_output(0, output) + nodes.append(sub_node) + return nodes diff --git a/cube/algorithm/ops/reduce.py b/cube/algorithm/ops/reduce.py index c9ebe54a..a16e65f4 100644 --- a/cube/algorithm/ops/reduce.py +++ b/cube/algorithm/ops/reduce.py @@ -58,7 +58,6 @@ def instantiate(self, node: Sum, config: Dict): assert len(node.outputs()) == 1 output = node.outputs(0) - print(self.reduce_dims) if self.dim not in self.reduce_dims: sub_outputs = split_axis(output, self.dim, self.chunk_num) else: diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index df223826..5c68c98a 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -267,10 +267,17 @@ def __init__(self, signature, inputs, name='transpose', **kwargs): def infer_shape(self): if self.inputs(0).shape is None: return False - dim1 = self.kwargs['dim0'] - dim2 = self.kwargs['dim1'] - shape = copy.copy(list(self.inputs(0).shape)) - shape[dim1], shape[dim2] = shape[dim2], shape[dim1] + ndim = len(self.inputs(0).shape) + dim0 = self.kwargs['dim0'] + if dim0 < 0: + dim0 = ndim + dim0 + self.kwargs['dim0'] = dim0 + dim1 = self.kwargs['dim1'] + if dim1 < 0: + dim1 = ndim + dim1 + self.kwargs['dim1'] = dim1 + shape = list(self.inputs(0).shape) + shape[dim0], shape[dim1] = shape[dim1], shape[dim0] self._outputs[0].shape = shape return True diff --git a/tests/algorithm/test_memory.py b/tests/algorithm/test_memory.py new file mode 100644 index 00000000..efab4015 --- /dev/null +++ b/tests/algorithm/test_memory.py @@ -0,0 +1,46 @@ +import cube.algorithm.ops.memory as mem +from cube.graph.tensor import IRFullTensor, ValueMap + + +def test_transpose_dim_parallel(): + + M = 512 + N = 1024 + input1 = IRFullTensor(shape=[M, N], name='input1').tosub() + dim0 = 0 + dim1 = 1 + + semantic_op = mem.Transpose( + signature='torch.transpose', inputs=[input1, dim0, dim1], name='transpose' + ) + semantic_op.infer_shape() + print('semantic op:') + print(semantic_op) + + op_dim = mem.TransposeDimParallel(semantic_op) + assert op_dim.dim is None + assert op_dim.chunk_num is None + + # test satisfy + assert op_dim.satisfy(dict(dim=0, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) + assert op_dim.satisfy(dict(dim=1, chunk_num=4)) + assert op_dim.satisfy(dict(dim=-1, chunk_num=4)) + + nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, mem.Transpose) + + for idx, node in enumerate(nodes): + print('=======') + print(node) + print('inputs:') + for input in node.inputs(): + print(input) + assert input.shape == [M // 4, N] + print('outputs:') + for output in node.outputs(): + print(output) + assert output.shape == [N, M // 4] + assert output.val_map == ValueMap(0, 1) From db2fe590588d7c82136143705bbf4c19235eeb0c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 12:28:45 +0800 Subject: [PATCH 0348/1892] fix activation bug --- cube/algorithm/ops/activation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cube/algorithm/ops/activation.py b/cube/algorithm/ops/activation.py index 7ed73103..05aa4ec0 100644 --- a/cube/algorithm/ops/activation.py +++ b/cube/algorithm/ops/activation.py @@ -93,8 +93,8 @@ def instantiate(self, node: Activation, config: Dict): class DropoutDimParallel(ActivationDimParallel): - def __init__(self, node: Activation, dim=None, execlude_dims=None): - super().__init__(node, dim=dim, execlude_dims=execlude_dims) + def __init__(self, node: Activation, dim=None): + super().__init__(node, dim=dim) def get_extra_kwargs(self, node: Dropout) -> List[Any]: if not isinstance(node, Dropout): @@ -105,8 +105,8 @@ def get_extra_kwargs(self, node: Dropout) -> List[Any]: class SoftmaxDimParallel(ActivationDimParallel): - def __init__(self, node: Activation, dim=None, execlude_dims=None): - super().__init__(node, dim=dim, execlude_dims=execlude_dims) + def __init__(self, node: Activation, dim=None): + super().__init__(node, dim=dim) def get_extra_kwargs(self, node) -> List[Any]: if not isinstance(node, Softmax): From 0b93a92205ff87146e679d98c5b0e8d56c1e2e60 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 12:30:22 +0800 Subject: [PATCH 0349/1892] fix linear test --- tests/algorithm/test_linear_algo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/algorithm/test_linear_algo.py b/tests/algorithm/test_linear_algo.py index 67e14fb2..345d4292 100644 --- a/tests/algorithm/test_linear_algo.py +++ b/tests/algorithm/test_linear_algo.py @@ -1,7 +1,7 @@ from cube.graph.operator.function import Linear -from cube.algorithm.linear import LinearDataParallel -from cube.algorithm.linear import LinearColumnWeight -from cube.algorithm.linear import LinearRowWeight +from cube.algorithm.ops.linear import LinearDataParallel +from cube.algorithm.ops.linear import LinearColumnWeight +from cube.algorithm.ops.linear import LinearRowWeight from cube.graph.tensor import IRFullTensor, ValueMap From fbc98fdac0d0ef7bab0aa4629e7113d3fc5536f5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 13:53:06 +0800 Subject: [PATCH 0350/1892] linear support multi dimensional partition --- cube/algorithm/ops/linear.py | 52 ++++++++++++++++++++++++++------- cube/algorithm/utils.py | 6 +++- cube/graph/operator/function.py | 12 ++++++-- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/cube/algorithm/ops/linear.py b/cube/algorithm/ops/linear.py index b1fb7c22..54132e0f 100644 --- a/cube/algorithm/ops/linear.py +++ b/cube/algorithm/ops/linear.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Dict +from typing import Dict from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo @@ -9,6 +9,15 @@ class LinearDataParallel(GenericDistAlgo): + """ + Input: + input: [N, *, in_features] + weight: [out_features, in_features] + bias: [out_features,] + + Output: + [N, *, in_features] + """ def __init__(self, node: Linear): @@ -16,12 +25,31 @@ def __init__(self, node: Linear): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) + # input dimension + self.ndim = len(node.inputs(0).shape) + self.dim_choice = list(range(self.ndim - 1)) + self.chunk_num = _kWaitDecision + if len(self.dim_choice) == 1: + self.dim = 0 + else: + self.dim = _kWaitDecision def satisfy(self, config: Dict): + input_shape = self.input_shapes[0] + if input_shape is None: + return False chunk_num = int(config['chunk_num']) + if 'dim' in config: + dim = config['dim'] + else: + if self.dim is None: + raise RuntimeError("Expected dim in config") + dim = self.dim + if dim < 0: + dim = self.ndim + dim input_shape = self.input_shapes[0] - if chunk_num > 0 and input_shape[0] % chunk_num == 0: + if chunk_num > 0 and input_shape[dim] % chunk_num == 0: return True return False @@ -29,11 +57,13 @@ def instantiate(self, node, config: Dict): if not self.satisfy(config): raise RuntimeError("Instantiate failed. Condition not satisfied.") self.chunk_num = int(config['chunk_num']) + if 'dim' in config: + self.dim = config['dim'] input, weight, bias = node.inputs() output = node.outputs(0) - ins = split_axis(input, 0, self.chunk_num) - outs = split_axis(output, 0, self.chunk_num) + ins = split_axis(input, self.dim, self.chunk_num) + outs = split_axis(output, self.dim, self.chunk_num) nodes = list() for input_chunk, output_chunk in zip(ins, outs): @@ -60,9 +90,9 @@ def __init__(self, node: Linear): def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) weight_shape = self.input_shapes[1] - if weight_shape[0] % chunk_num != 0: - return False - return True + if chunk_num > 0 and weight_shape[0] % chunk_num == 0: + return True + return False def instantiate(self, node, config: Dict): if not self.satisfy(config): @@ -103,9 +133,9 @@ def __init__(self, node: Linear): def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) weight_shape = self.input_shapes[1] - if weight_shape[1] % chunk_num != 0: - return False - return True + if chunk_num > 0 and weight_shape[1] % chunk_num == 0: + return True + return False def instantiate(self, node, config: Dict): if not self.satisfy(config): @@ -114,7 +144,7 @@ def instantiate(self, node, config: Dict): input, weight, bias = node.inputs() output = node.outputs(0) - ins = split_axis(input, 1, self.chunk_num) + ins = split_axis(input, -1, self.chunk_num) ws = split_axis(weight, 1, self.chunk_num) if bias: bs = split_value(bias, self.chunk_num) diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py index a552092c..221280df 100644 --- a/cube/algorithm/utils.py +++ b/cube/algorithm/utils.py @@ -3,7 +3,11 @@ def split_axis(tensor: IRTensor, axis: int, chunk_num: int): - + """ + Split tensor along an axis. The axis can be positive or negative. + """ + if axis < 0: + axis = len(tensor.shape) + axis if axis >= len(tensor.shape): raise RuntimeError(f"Axis should within dims ({axis} >= {len(tensor.shape)})") diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 5c68c98a..e62bf40d 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -1,12 +1,18 @@ import copy -from typing import Type from cube.graph.operator import IRFwOperation -from cube.ir.cten import IRTensor class Linear(IRFwOperation): - + """ + Input: + input: [N, *, in_features] + weight: [out_features, in_features] + bias: [out_features,] + + Output: + [N, *, in_features] + """ def __init__(self, signature, inputs, name='linear', **kwargs): input, weight, bias = inputs From 6914847123603fc5b5a0ddf024efb54e5893da04 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 15:43:37 +0800 Subject: [PATCH 0351/1892] dataloader awares of batch size; enable attention tp --- cube/algorithm/ops/complex.py | 76 +++++----- cube/algorithm/ops/dataloader.py | 11 +- cube/algorithm/ops/elementwise.py | 1 - cube/compiler.py | 17 ++- cube/graph/operator/function.py | 43 +++--- cube/graph/operator/operator.py | 13 +- cube/runtime/collectives.py | 4 + cube/runtime/function/__init__.py | 1 + cube/runtime/function/complex.py | 48 +++---- cube/runtime/syndata.py | 54 +++++-- cube/schedule/translator.py | 15 +- examples/attention/attention.py | 142 +++++++++++++++++++ examples/attention/policy/no_parallel.py | 23 +++ examples/attention/policy/tensor_parallel.py | 109 ++++++++++++++ 14 files changed, 446 insertions(+), 111 deletions(-) create mode 100644 examples/attention/attention.py create mode 100644 examples/attention/policy/no_parallel.py create mode 100644 examples/attention/policy/tensor_parallel.py diff --git a/cube/algorithm/ops/complex.py b/cube/algorithm/ops/complex.py index b4ac3500..f64a2acf 100644 --- a/cube/algorithm/ops/complex.py +++ b/cube/algorithm/ops/complex.py @@ -15,15 +15,15 @@ class CubeToQKVDataParallel(GenericDistAlgo): """ Inputs: hidden_state: [L, N, E] - weight: [3 * (num_heads * dim_head), E] - num_heads: int + weight: [3 * (num_head * dim_head), E] + num_head: int - where L = sequence length, N = batch size, E = num_heads * dim_head + where L = sequence length, N = batch size, E = num_head * dim_head Returns: - Q: [L, N * num_heads, dim_head] - K: [L, N * num_heads, dim_head] - V: [L, N * num_heads, dim_head] + Q: [L, N * num_head, dim_head] + K: [L, N * num_head, dim_head] + V: [L, N * num_head, dim_head] """ def __init__(self, node: CubeComplexToQKV): @@ -31,7 +31,7 @@ def __init__(self, node: CubeComplexToQKV): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) self.chunk_num = _kWaitDecision - self.num_heads = node.kwargs['num_heads'] + self.num_head = node.kwargs['num_head'] self.bs = node.inputs(0).shape[1] def satisfy(self, config: Dict): @@ -55,7 +55,7 @@ def instantiate(self, node, config: Dict): nodes = list() for idx in range(self.chunk_num): - inputs = [ins[idx], weight, self.num_heads] + inputs = [ins[idx], weight, self.num_head] node = CubeComplexToQKV( signature = 'cube.runtime.function.complex.toqkv', inputs = inputs, @@ -71,14 +71,14 @@ def instantiate(self, node, config: Dict): class CubeToQKVHeadParallel(GenericDistAlgo): """ Inputs: - hidden_state: [L, N, E] (seqlen, batch size, num_heads * dim_head) + hidden_state: [L, N, E] (seqlen, batch size, num_head * dim_head) weight: [E * 3, E] - num_heads: int + num_head: int Returns: - Q: [L, N * num_heads, dim_head] - K: [L, N * num_heads, dim_head] - V: [L, N * num_heads, dim_head] + Q: [L, N * num_head, dim_head] + K: [L, N * num_head, dim_head] + V: [L, N * num_head, dim_head] """ def __init__(self, node: CubeComplexToQKV): @@ -86,12 +86,12 @@ def __init__(self, node: CubeComplexToQKV): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) self.chunk_num = _kWaitDecision - self.num_heads = node.kwargs['num_heads'] + self.num_head = node.kwargs['num_head'] self.bs = node.inputs(0).shape[1] def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.num_heads % chunk_num == 0: + if chunk_num > 0 and self.num_head % chunk_num == 0: return True return False @@ -110,7 +110,7 @@ def instantiate(self, node, config: Dict): nodes = list() for idx in range(self.chunk_num): - inputs = [hidden_state, ws[idx], self.num_heads // self.chunk_num] + inputs = [hidden_state, ws[idx], self.num_head // self.chunk_num] node = CubeComplexToQKV( signature = 'cube.runtime.function.complex.toqkv', inputs = inputs, @@ -126,11 +126,11 @@ def instantiate(self, node, config: Dict): class CubeTrilMaskDataParallel(GenericDistAlgo): """ Inputs: - input: [N * num_heads, L, L] + input: [N * num_head, L, L] num_head: int Returns: - output: [N * num_heads, L, L] + output: [N * num_head, L, L] """ def __init__(self, node: CubeComplexTrilMask): @@ -138,8 +138,8 @@ def __init__(self, node: CubeComplexTrilMask): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) self.chunk_num = _kWaitDecision - self.num_heads = node.kwargs['num_heads'] - self.bs = node.inputs(0).shape[0] // self.num_heads + self.num_head = node.kwargs['num_head'] + self.bs = node.inputs(0).shape[0] // self.num_head def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) @@ -160,7 +160,7 @@ def instantiate(self, node, config: Dict): nodes = list() for idx in range(self.chunk_num): - inputs = [ins[idx], self.num_heads] + inputs = [ins[idx], self.num_head] node = CubeComplexTrilMask( signature = 'cube.runtime.function.complex.tril_mask', inputs = inputs, @@ -174,11 +174,11 @@ def instantiate(self, node, config: Dict): class CubeTrilMaskHeadParallel(GenericDistAlgo): """ Inputs: - input: [N * num_heads, L, L] + input: [N * num_head, L, L] num_head: int Returns: - output: [N * num_heads, L, L] + output: [N * num_head, L, L] """ def __init__(self, node: CubeComplexTrilMask): @@ -186,12 +186,12 @@ def __init__(self, node: CubeComplexTrilMask): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) self.chunk_num = _kWaitDecision - self.num_heads = node.kwargs['num_heads'] - self.bs = node.inputs(0).shape[0] // self.num_heads + self.num_head = node.kwargs['num_head'] + self.bs = node.inputs(0).shape[0] // self.num_head def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.num_heads % chunk_num == 0: + if chunk_num > 0 and self.num_head % chunk_num == 0: return True return False @@ -208,7 +208,7 @@ def instantiate(self, node, config: Dict): nodes = list() for idx in range(self.chunk_num): - inputs = [ins[idx], self.num_heads // self.chunk_num] + inputs = [ins[idx], self.num_head // self.chunk_num] node = CubeComplexTrilMask( signature = 'cube.runtime.function.complex.tril_mask', inputs = inputs, @@ -222,18 +222,18 @@ def instantiate(self, node, config: Dict): class CubeAttnViewDataParallel(GenericDistAlgo): """ Inputs: - [N * num_heads, L, dim_head] + [N * num_head, L, dim_head] Outputs: - [L, N, num_heads * dim_head] + [L, N, num_head * dim_head] """ def __init__(self, node: CubeComplexAttnView): if not isinstance(node, CubeComplexAttnView): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) self.chunk_num = _kWaitDecision - self.num_heads = node.kwargs['num_heads'] - self.bs = node.inputs(0).shape[0] // self.num_heads + self.num_head = node.kwargs['num_head'] + self.bs = node.inputs(0).shape[0] // self.num_head def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) @@ -254,7 +254,7 @@ def instantiate(self, node, config: Dict): nodes = list() for idx in range(self.chunk_num): - inputs = [ins[idx], self.num_heads] + inputs = [ins[idx], self.num_head] node = CubeComplexAttnView( signature = 'cube.runtime.function.complex.attn_view', inputs = inputs, @@ -268,22 +268,22 @@ def instantiate(self, node, config: Dict): class CubeAttnViewHeadParallel(GenericDistAlgo): """ Inputs: - [N * num_heads, L, dim_head] + [N * num_head, L, dim_head] Outputs: - [L, N, num_heads * dim_head] + [L, N, num_head * dim_head] """ def __init__(self, node: CubeComplexAttnView): if not isinstance(node, CubeComplexAttnView): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) self.chunk_num = _kWaitDecision - self.num_heads = node.kwargs['num_heads'] - self.bs = node.inputs(0).shape[0] // self.num_heads + self.num_head = node.kwargs['num_head'] + self.bs = node.inputs(0).shape[0] // self.num_head def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.num_heads % chunk_num == 0: + if chunk_num > 0 and self.num_head % chunk_num == 0: return True return False @@ -300,7 +300,7 @@ def instantiate(self, node, config: Dict): nodes = list() for idx in range(self.chunk_num): - inputs = [ins[idx], self.num_heads // self.chunk_num] + inputs = [ins[idx], self.num_head // self.chunk_num] node = CubeComplexAttnView( signature = 'cube.runtime.function.complex.attn_view', inputs = inputs, diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py index f863e592..808fa44b 100644 --- a/cube/algorithm/ops/dataloader.py +++ b/cube/algorithm/ops/dataloader.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Type +from typing import Dict from cube.algorithm.utils import split_axis from cube.algorithm.generics import GenericDistAlgo @@ -15,13 +15,14 @@ def __init__(self, node: IRDataOperation): if not isinstance(node, IRDataOperation): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) + self.batch_dims = node.get_batch_dims() self.chunk_num = _kWaitDecision def satisfy(self, config: Dict): chunk_num = int(config['chunk_num']) - for shape in self.output_shapes: - if chunk_num > 0 and shape[0] % chunk_num != 0: + for bdim, shape in zip(self.batch_dims, self.output_shapes): + if chunk_num > 0 and shape[bdim] % chunk_num != 0: return False return True @@ -31,8 +32,8 @@ def instantiate(self, node, config: Dict): self.chunk_num = int(config['chunk_num']) sub_outputs = list() - for output in node.outputs(): - sub_output = split_axis(output, 0, self.chunk_num) + for bdim, output in zip(self.batch_dims, node.outputs()): + sub_output = split_axis(output, bdim, self.chunk_num) sub_outputs.append(sub_output) nodes = list() diff --git a/cube/algorithm/ops/elementwise.py b/cube/algorithm/ops/elementwise.py index 2c83faab..b0639bed 100644 --- a/cube/algorithm/ops/elementwise.py +++ b/cube/algorithm/ops/elementwise.py @@ -71,7 +71,6 @@ def instantiate(self, node, config: Dict): nodes = list() for idx, sub_input in enumerate(zip(*sub_inputs)): - print(sub_input) sub_input = list(sub_input) + self.get_extra_kwargs(node) sub_node = type(node)(node.signature, inputs=sub_input, name=node.name) sub_node.kwargs = copy.copy(node.kwargs) diff --git a/cube/compiler.py b/cube/compiler.py index 6359445b..be90b3cb 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -3,6 +3,7 @@ import cube from cube.graph.graph import IRGraph +from cube.graph.operator.operator import IRDataOperation from cube.schedule.pool import SchedulePool from cube.schedule.su import SUType from cube.schedule.translator import IRDataLoader @@ -174,15 +175,19 @@ def decorator(fn: Callable) -> Callable: attach=True ) # get dataloader batch size - data = None + batch_size = dict() # {devid: batch size} for su in sugraph.sus(): if su.stype == SUType.Dataloader: - data = su.outputs(0) - break - if data is None: - raise RuntimeError("dataloader not found in SUGraph") + data_op: IRDataOperation = su.nodes(0) + batch_dim = data_op.get_batch_dims()[0] + dev_batch_size = data_op.outputs(0).shape[batch_dim] + batch_size[su.device[0]] = dev_batch_size + all_batch_size = set([batch_size[dev] for dev in batch_size]) + if len(all_batch_size) != 1: + raise NotImplementedError("Heterogenous batch size it not supported") + batch_size = list(all_batch_size)[0] # assume batch_size is always first dimension - batch_size = torch.tensor([data.shape[0]], dtype=torch.int).cuda() + batch_size = torch.tensor([batch_size], dtype=torch.int).cuda() if torch.distributed.is_initialized(): torch.distributed.barrier() diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index e62bf40d..3d3ddce3 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -293,15 +293,16 @@ class CubeComplexToQKV(IRFwOperation): """ Inputs: hidden_state: [L, N, E] - weight: [3 * (num_heads * dim_head), E] - num_heads: int + weight: [3 * (num_head * dim_head), E] + num_head: int + dim_head: int - where L = sequence length, N = batch size, E = num_heads * dim_head + where L = sequence length, N = batch size, E = num_head * dim_head Returns: - Q: [L, N * num_heads, dim_head] - K: [L, N * num_heads, dim_head] - V: [L, N * num_heads, dim_head] + Q: [L, N * num_head, dim_head] + K: [L, N * num_head, dim_head] + V: [L, N * num_head, dim_head] """ def __init__(self, signature, inputs, name='toqkv', **kwargs): if len(inputs) != 3: @@ -314,17 +315,17 @@ def __init__(self, signature, inputs, name='toqkv', **kwargs): ) self.set_input(0, qkv) self.set_input(1, weight) - self.kwargs['num_heads'] = inputs[2] + self.kwargs['num_head'] = inputs[2] def infer_shape(self): if self.inputs(0).shape is None or self.inputs(1) is None: return False seqlen = self.inputs(0).shape[0] bs = self.inputs(0).shape[1] - num_heads = self.kwargs['num_heads'] - dim_head = self.inputs(1).shape[0] // 3 // num_heads + num_head = self.kwargs['num_head'] + dim_head = self.inputs(1).shape[0] // 3 // num_head - shape = [seqlen, bs * num_heads, dim_head] + shape = [seqlen, bs * num_head, dim_head] for output in self.outputs(): output.shape = shape return True @@ -333,23 +334,23 @@ def infer_shape(self): class CubeComplexTrilMask(IRFwOperation): """ Inputs: - input: [N * num_heads, L, L] + input: [N * num_head, L, L] num_head: int Returns: - output: [N * num_heads, L, L] + output: [N * num_head, L, L] """ def __init__(self, signature, inputs, name='trilmask', **kwargs): if len(inputs) != 2: raise TypeError("Expected 2 input") - tensor, num_heads = inputs[0], inputs[1] + tensor, num_head = inputs[0], inputs[1] super().__init__( name, signature, input_length=1, output_length=1 ) self.set_input(0, tensor) - self.kwargs['num_heads'] = num_heads + self.kwargs['num_head'] = num_head def infer_shape(self): if self.inputs(0).shape is None: @@ -361,31 +362,31 @@ def infer_shape(self): class CubeComplexAttnView(IRFwOperation): """ Inputs: - [N * num_heads, L, dim_head] + [N * num_head, L, dim_head] Outputs: - [L, N, num_heads * dim_head] + [L, N, num_head * dim_head] """ def __init__(self, signature, inputs, name='attn_view', **kwargs): if len(inputs) != 2: raise TypeError("Expected 2 input") - tensor, num_heads = inputs[0], inputs[1] + tensor, num_head = inputs[0], inputs[1] super().__init__( name, signature, input_length=1, output_length=1 ) self.set_input(0, tensor) - self.kwargs['num_heads'] = num_heads + self.kwargs['num_head'] = num_head def infer_shape(self): if self.inputs(0).shape is None: return False - num_heads = self.kwargs['num_heads'] - bs = self.inputs(0).shape[0] // num_heads + num_head = self.kwargs['num_head'] + bs = self.inputs(0).shape[0] // num_head seqlen = self.inputs(0).shape[1] dim_head = self.inputs(0).shape[2] - shape = [seqlen, bs, num_heads * dim_head] + shape = [seqlen, bs, num_head * dim_head] self._outputs[0].shape = shape return True diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 9ee2f71f..48f94838 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,6 +1,8 @@ from typing import Any, Optional, Union, List import copy +from torch._C import is_anomaly_enabled + from cube.ir.cten import IRTensor, IRCell from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory @@ -222,10 +224,17 @@ def __repr__(self): class IRDataOperation(IRCell): - def __init__(self, data_num: int, name='dataloader'): - + def __init__(self, data_num: int, batch_dims: List[int], name='dataloader'): + if not isinstance(batch_dims, list): + raise RuntimeError("Expected batch dims to be a list") + if len(batch_dims) != data_num: + raise RuntimeError("Expected each output data has a specified batch dim") signature = 'dataloader.__next__' super().__init__(name, signature, 0, data_num) + self.batch_dims = batch_dims + + def get_batch_dims(self): + return copy.copy(self.batch_dims) def infer_shape(self): """ diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index b4f60a2d..4ebae7c9 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -106,6 +106,10 @@ def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): tensor = tensors[0] tensor = tensor.detach() tensor = tensor.requires_grad_() + + ### Bypass ### + # return tensor + group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(tensor, group=group) return tensor diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py index 9dd5982d..a8c9ff96 100644 --- a/cube/runtime/function/__init__.py +++ b/cube/runtime/function/__init__.py @@ -1 +1,2 @@ +import cube.runtime.function.complex as complex from cube.runtime.function.complex import * \ No newline at end of file diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index 29707ae8..1544d218 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -3,47 +3,47 @@ def toqkv(hidden_state: torch.Tensor, weight: torch.Tensor, - num_heads: int): + num_head: int): """ Inputs: hidden_state: [L, N, E] - weight: [3 * (num_heads * dim_head), E] - num_heads: int + weight: [3 * (num_head * dim_head), E] + num_head: int - where L = sequence length, N = batch size, E = num_heads * dim_head + where L = sequence length, N = batch size, E = num_head * dim_head Returns: - Q: [L, N * num_heads, dim_head] - K: [L, N * num_heads, dim_head] - V: [L, N * num_heads, dim_head] + Q: [L, N * num_head, dim_head] + K: [L, N * num_head, dim_head] + V: [L, N * num_head, dim_head] """ seqlen = hidden_state.shape[0] bs = hidden_state.shape[1] - dim_head = hidden_state.shape[2] // num_heads + dim_head = weight.shape[0] // 3 // num_head qkv = F.linear(hidden_state, weight, None) qkv = qkv.chunk(3, dim=-1) q, k, v = qkv q = q.contiguous() - q = q.view(seqlen, (bs * num_heads), dim_head) + q = q.view(seqlen, (bs * num_head), dim_head) k = k.contiguous() - k = k.view(seqlen, (bs * num_heads), dim_head) + k = k.view(seqlen, (bs * num_head), dim_head) v = v.contiguous() - v = v.view(seqlen, (bs * num_heads), dim_head) + v = v.view(seqlen, (bs * num_head), dim_head) return q, k, v -def tril_mask(input: torch.Tensor, num_heads: int): +def tril_mask(input: torch.Tensor, num_head: int): """ Inputs: - input: [N * num_heads, L, L] + input: [N * num_head, L, L] num_head: int Returns: - output: [N * num_heads, L, L] + output: [N * num_head, L, L] """ - bs: int = input.shape[0] // num_heads + bs: int = input.shape[0] // num_head seqlen: int = input.shape[2] - input = input.view(bs, num_heads, seqlen, seqlen) + input = input.view(bs, num_head, seqlen, seqlen) # set up mask ones = torch.ones( (bs, seqlen, seqlen), @@ -54,23 +54,23 @@ def tril_mask(input: torch.Tensor, num_heads: int): mask = (mask < 0.5) # mask masked_input = input.masked_fill_(mask, -100000.0) - masked_input = masked_input.view((bs * num_heads), seqlen, seqlen) + masked_input = masked_input.view((bs * num_head), seqlen, seqlen) return masked_input -def attn_view(input: torch.Tensor, num_heads: int): +def attn_view(input: torch.Tensor, num_head: int): """ Inputs: - [N * num_heads, L, dim_head] + [N * num_head, L, dim_head] Outputs: - [L, N, num_heads * dim_head] + [L, N, num_head * dim_head] """ - bs: int = input.shape[0] // num_heads + bs: int = input.shape[0] // num_head seqlen: int = input.shape[1] dim_head = input.shape[2] - # [(N * num_heads), L, dim_head] -> [L, (N * num_heads), dim_head] + # [(N * num_head), L, dim_head] -> [L, (N * num_head), dim_head] input = input.transpose(0, 1).contiguous() - # [L, (N * num_heads), dim_head] -> [L, N, (num_heads * dim_head)] - input = input.view(seqlen, bs, num_heads * dim_head) + # [L, (N * num_head), dim_head] -> [L, N, (num_head * dim_head)] + input = input.view(seqlen, bs, num_head * dim_head) return input diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 74cb8392..8451b1fe 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -2,20 +2,57 @@ Synthetic Data Loader """ -from typing import List +from typing import List, Optional +import copy import torch -__all__ = ['SynDataLoader'] +__all__ = ['CubeDataLoader', 'SynDataLoader'] -class SynDataLoader: +class CubeDataLoader: + r""" + Cube Dataloader + """ + def __init__(self, batch_dims: List[int], *shapes: List[List[int]]): + """ + batch_dim: + The batch dimension for each input shapes + *shapes: + The shape for each data + """ + if not isinstance(batch_dims, list): + raise RuntimeError("Expected a List[int] for batch dims") + self.shapes = list(shapes) + self.batch_dims = batch_dims + + def get_batch_dims(self, idx: Optional[int] = None) -> int: + """ + Get batch dimension for idx-th data + """ + if idx is not None: + return self.batch_dims[idx] + else: + return copy.copy(self.batch_dims) + + def reset(self, batch_size: int): + """ + Reset batch size + """ + for bdim, shape in zip(self.batch_dims, self.shapes): + shape[bdim] = batch_size + print(f'> data loader output shape change to: {self.shapes}') + + +class SynDataLoader(CubeDataLoader): r""" Synthetic dataloader to produce tensors for given shape. """ - def __init__(self, num: int, *shapes: List[List[int]]): - self.shapes = list(shapes) + def __init__(self, num: int, batch_dim: List[int], *shapes: List[List[int]]): + if len(shapes) != len(batch_dim): + raise TypeError("Expected length of batch dim is same to shapes") + super().__init__(batch_dim, *shapes) self.length = num self.pos = 0 @@ -23,13 +60,6 @@ def __iter__(self): self.pos = 0 return self - def reset(self, batch_size: int): - """ - Reset batch size - """ - for shape in self.shapes: - shape[0] = batch_size - def __next__(self): self.pos += 1 if self.pos == self.length: diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py index 09525f9d..dae6ade8 100644 --- a/cube/schedule/translator.py +++ b/cube/schedule/translator.py @@ -4,6 +4,8 @@ The traning logic first translate the training logic into Schedule Units, and then add Adapter ScheduleUnit """ +import copy +from typing import Optional import torch from cube.graph.tensor import IRFullTensor, IRSubTensor @@ -11,11 +13,20 @@ import cube.graph.gpass as gpass from cube.schedule.pool import SchedulePool +from cube.runtime.syndata import CubeDataLoader + class IRDataLoader: - def __init__(self, dataloader): + def __init__(self, dataloader: CubeDataLoader): self.dataloader = iter(dataloader) + self.batch_dims = dataloader.get_batch_dims() + + def get_batch_dims(self, idx: Optional[int] = None) -> int: + if idx is None: + return copy.copy(self.batch_dims) + else: + return self.batch_dims[idx] def __iter__(self): return self @@ -44,7 +55,7 @@ def load_data(dataloader: IRDataLoader): outputs.append(data) data_op = IRDataOperation( - data_num=len(datas) + data_num=len(datas), batch_dims=dataloader.get_batch_dims(), ) for idx, output in enumerate(outputs): data_op.set_output(idx, output) diff --git a/examples/attention/attention.py b/examples/attention/attention.py new file mode 100644 index 00000000..328f9d00 --- /dev/null +++ b/examples/attention/attention.py @@ -0,0 +1,142 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/attention.py +""" + +import torch +from torch import nn +import torch.nn.functional as F +import cube + + +from examples.attention.policy.tensor_parallel import transform_policy +from examples.attention.policy.tensor_parallel import schedule_policy + +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +class MultiHeadSelfAttention(nn.Module): + + def __init__(self, seq_len, embed_dim, heads, dropout): + super().__init__() + + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_head = heads + self.dim_head = embed_dim // heads + self.scale = self.dim_head ** -0.5 + + self.weight_qkv = torch.nn.Parameter(torch.empty( + 3 * embed_dim, embed_dim + )) + self.weight_out = torch.nn.Parameter(torch.empty( + embed_dim, embed_dim + )) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """ + x: [L, N, E]: seq_len, batch_size, embedding dimension + output: [L, N, E] + """ + # [L, N, E] -> 3 x [L, (N * num_head), dim_head] + q, k, v = cube.runtime.function.toqkv( + x, self.weight_qkv, self.num_head + ) + + # [L, (N * num_head), dim_head] -> [(N * num_head), L, dim_head] + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + # [(N * num_head), L, dim_head] -> [(N * num_head), L, dim_head] + q = q * self.scale + # [(N * num_head), L, dim_head] -> [(N * num_head), dim_head, L] + k = k.transpose(-2, -1) + # [(N * num_head), L, dim_head] * [(N * num_head), dim_head, L] + # -> [(N * num_head), L, L] + attn = torch.bmm(q, k) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = cube.runtime.function.tril_mask(attn, self.num_head) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = F.softmax(attn, dim=-1) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = self.dropout(attn) + # [(N * num_head), L, L] * [(N * num_head), L, dim_head] + # -> [(N * num_head), L, dim_head] + output = torch.bmm(attn, v) + + # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] + output = cube.runtime.function.attn_view(output, self.num_head) + + # [L, (N * num_head), dim_head] * [] + output = F.linear(output, self.weight_out) + + loss = torch.sum(output) + return loss + + +def train(): + L = 128 # seq len + N = 32 # batch size + # configs: [hidden size, num_head] + # E, num_head = [1536, 16] # 1.2B model + # E, num_head = [1920, 20] # 2.5B model + # E, num_head = [2304, 24] # 4.2B model + E, num_head = [3072, 32] # 8.7B model + + + model = MultiHeadSelfAttention( + seq_len=L, embed_dim=E, heads=num_head, dropout=0.5 + ) + model = cube.SemanticModel( + # TODO: data parallel batch dim + model, input_shapes=([L, N, E],), + ) + + # TODO: data parallel batch dim + dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) + + @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + + +if __name__ == '__main__': + + cube.init() + train() diff --git a/examples/attention/policy/no_parallel.py b/examples/attention/policy/no_parallel.py new file mode 100644 index 00000000..702c02b7 --- /dev/null +++ b/examples/attention/policy/no_parallel.py @@ -0,0 +1,23 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using column parallel + """ + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + for su in sugraph.fsus(): + sugraph.assign(su, 0) + sugraph.assign(su.mirror, 0) + return sugraph diff --git a/examples/attention/policy/tensor_parallel.py b/examples/attention/policy/tensor_parallel.py new file mode 100644 index 00000000..1d3621d1 --- /dev/null +++ b/examples/attention/policy/tensor_parallel.py @@ -0,0 +1,109 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using data parallel + """ + ndevs = resource.ngpus + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + assert len(fnodes) == 14 + + toqkv = fnodes[0] + q_t = fnodes[1] + k_t = fnodes[2] + v_t = fnodes[3] + q_scale = fnodes[4] + k_t2 = fnodes[5] + qk_bmm = fnodes[6] + mask = fnodes[7] + softmax = fnodes[8] + dropout = fnodes[9] + attnv_bmm = fnodes[10] + attnview = fnodes[11] + linear = fnodes[12] + loss = fnodes[13] + + all_sub_nodes = list() + + algo = toqkv.algorithms('head') + sub_nodes = graph.partition(toqkv, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = q_t.algorithms('dim') + sub_nodes = graph.partition(q_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = k_t.algorithms('dim') + sub_nodes = graph.partition(k_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = v_t.algorithms('dim') + sub_nodes = graph.partition(v_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = q_scale.algorithms('dim') + sub_nodes = graph.partition(q_scale, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = k_t2.algorithms('dim') + sub_nodes = graph.partition(k_t2, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = qk_bmm.algorithms('data') + sub_nodes = graph.partition(qk_bmm, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = mask.algorithms('head') + sub_nodes = graph.partition(mask, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = softmax.algorithms('dim') + sub_nodes = graph.partition(softmax, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = dropout.algorithms('dim') + sub_nodes = graph.partition(dropout, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = attnv_bmm.algorithms('data') + sub_nodes = graph.partition(attnv_bmm, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = attnview.algorithms('head') + sub_nodes = graph.partition(attnview, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = linear.algorithms('row') + sub_nodes = graph.partition(linear, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + sub_nodes = graph.replicate(loss, times=ndevs) + all_sub_nodes.append(sub_nodes) + + for sub_nodes in all_sub_nodes: + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + print(graph) + # assert False + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + for su in sugraph.fsus(): + devid = su.tag[0] + print(f'assinging {su.nodes(0)}') + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + return sugraph From 48a1aaeedf9eacad86c8cbcc3658b7e1638f7b66 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 18:02:26 +0800 Subject: [PATCH 0352/1892] remove data loading time --- cube/runtime/syndata.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 8451b1fe..e8ea908d 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -56,17 +56,32 @@ def __init__(self, num: int, batch_dim: List[int], *shapes: List[List[int]]): self.length = num self.pos = 0 + self._buffer_num = None + self.datas: torch.Tensor = list() + self.set_data_buffer() + def __iter__(self): self.pos = 0 return self + def set_data_buffer(self, buffer_num = 4): + self.datas = list() + self._buffer_num = buffer_num + for _ in range(self._buffer_num): + datas = list() + for shape in self.shapes: + data = torch.randn(shape).cuda() + datas.append(data) + self.datas.append(datas) + + def reset(self, batch_size: int): + super().reset(batch_size) + self.set_data_buffer() + def __next__(self): self.pos += 1 if self.pos == self.length: raise StopIteration - datas = list() - for shape in self.shapes: - data = torch.randn(shape).cuda() - datas.append(data) + datas = self.datas[self.pos % self._buffer_num] if len(datas) == 1: return datas[0] else: return tuple(datas) From 0a123d1a5bf24e3e50f960a5a28d16ec7a019b24 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 18:09:05 +0800 Subject: [PATCH 0353/1892] fix bugs in dependency: remove inter-adapter dependency --- cube/schedule/su.py | 12 +++++++++--- cube/schedule/sugraph.py | 20 +++++++++++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 888512c8..19b2a644 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -3,7 +3,6 @@ from enum import Enum from cube.ir.cten import IRCell, IRTensor -from cube.graph.operator import IRBpOperation class SUType(Enum): @@ -194,9 +193,13 @@ def select_adapters(self, index: Optional[int] = None) -> List: else: raise TypeError("Expected index to be None or int") - def _clear_adapters(self): + def _clear_adapters(self, ctrl=False): """ - Clear all adapters for this SU + Clear all adapters for this SU. By default control dependency is keeped + + Args: + ctrl (boolean): if true, additional control dependency is removed. + if false, additional control dependency is keeped. """ self._send_in_adapters: List[List[ScheduleUnit]] = [ list() for _ in range(len(self.inputs())) @@ -212,6 +215,9 @@ def _clear_adapters(self): self._recv_out_adapters: List[List[ScheduleUnit]] = [ list() for _ in range(len(self.outputs())) ] + if ctrl: + self._ctrl_predecessors = list() + self._ctrl_successors = list() def _add_in_adapter(self, index: int, send_adapters, recv_adapters): """ diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 695d0450..a88d4a73 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -51,12 +51,16 @@ def reset_dependency(sus: List[ScheduleUnit]): """ if not all([isinstance(su, ScheduleUnit) for su in sus]): raise TypeError("Expected list of schedule unit") + adapters = [SUType.P2P, SUType.Coll, SUType.Transform] for su in sus: su.clear_predecessor() su.clear_successor() for src_idx in range(len(sus)): src = sus[src_idx] for dst in sus[src_idx+1:]: + # inter-adapter has no dependency + if src.stype in adapters and dst.stype in adapters: + continue for out_idx, out_tensor in enumerate(src.outputs()): # special dependency for communication adapter if dst.stype == SUType.P2P: @@ -378,13 +382,22 @@ def set_order(self, seq: List[ScheduleUnit]): self.sequence = seq return True - def partial_set_order(self, seq: List[ScheduleUnit]): + def partial_set_order(self, seq: List[ScheduleUnit], lazy=False): """ Set a order of the sequence using part of SUs. A random topological order will be set under the constraints of given `seq` order + + Args: + seq: partial scheduling sequence + lazy: + if True, the remaining SU is inserted only when it is needed. + if False, the remaining SU is inserted once it is ready. + """ + if lazy: + raise NotImplementedError("Not supported for Lazy") seq = copy.copy(seq) for su in seq: if su not in self.sequence: @@ -397,6 +410,11 @@ def partial_set_order(self, seq: List[ScheduleUnit]): remain_sus.append(su) for rsu in remain_sus: happen_before_sus = rsu.predecessors() + # A temporal fix for loss computation and backward + # -- as they have no dependency in theory + if rsu.stype == SUType.Backward: + if rsu.mirror not in happen_before_sus: + happen_before_sus.append(rsu.mirror) idx = 0 while len(happen_before_sus) > 0: if idx == len(seq): From ef938a76c3ada485f690ccc4b75be0f81141f1f9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 18:10:01 +0800 Subject: [PATCH 0354/1892] tensor parallelism --- examples/attention/attention.py | 2 +- examples/attention/policy/tensor_parallel.py | 5 +++-- examples/inspector.py | 7 +++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/attention/attention.py b/examples/attention/attention.py index 328f9d00..747250eb 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -89,7 +89,7 @@ def forward(self, x): def train(): - L = 128 # seq len + L = 512 # seq len N = 32 # batch size # configs: [hidden size, num_head] # E, num_head = [1536, 16] # 1.2B model diff --git a/examples/attention/policy/tensor_parallel.py b/examples/attention/policy/tensor_parallel.py index 1d3621d1..df132025 100644 --- a/examples/attention/policy/tensor_parallel.py +++ b/examples/attention/policy/tensor_parallel.py @@ -6,7 +6,7 @@ def transform_policy(graph: IRGraph, resource): """ - The transformation policy transposes linear using data parallel + The transformation policy transposes linear using tensor parallel """ ndevs = resource.ngpus @@ -103,7 +103,8 @@ def schedule_policy(sugraph: SUGraph, resource): sugraph.assign(su, 0) for su in sugraph.fsus(): devid = su.tag[0] - print(f'assinging {su.nodes(0)}') sugraph.assign(su, devid) sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + sugraph.partial_set_order(fsus, lazy=False) return sugraph diff --git a/examples/inspector.py b/examples/inspector.py index 99ce7438..a3bc7700 100644 --- a/examples/inspector.py +++ b/examples/inspector.py @@ -19,7 +19,8 @@ from cube.profiler.timer import print_each_rank -kDataShapes = ([128, 1024],) +kDataShapes = ([512, 32, 3072],) +kBatchDims = [1] def load_module(filename: str): @@ -50,7 +51,9 @@ def load_train_fn(filename: str): def train(args): global kDataShapes - dataloader = cube.runtime.syndata.SynDataLoader(1280, *kDataShapes) + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, kBatchDims, *kDataShapes + ) genfile = args.genfile.format(rank=torch.distributed.get_rank()) model = load_module(genfile) From 664d6e031e1b40d442d51153a732aef5d91f39bc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 18:33:31 +0800 Subject: [PATCH 0355/1892] fix bugs in splitting dataloader --- cube/algorithm/ops/dataloader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py index 808fa44b..7fd4f803 100644 --- a/cube/algorithm/ops/dataloader.py +++ b/cube/algorithm/ops/dataloader.py @@ -1,4 +1,5 @@ from typing import Dict +import copy from cube.algorithm.utils import split_axis from cube.algorithm.generics import GenericDistAlgo @@ -38,7 +39,8 @@ def instantiate(self, node, config: Dict): nodes = list() for sub_outs in zip(*sub_outputs): - node = IRDataOperation(data_num = len(sub_outs)) + node = IRDataOperation( + data_num = len(sub_outs), batch_dims = copy.copy(self.batch_dims)) for idx, out in enumerate(sub_outs): node.set_output(idx, out) nodes.append(node) From a375913a15c3986bd2bd87bf8bd94933db9f89df Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 18:33:48 +0800 Subject: [PATCH 0356/1892] fix bugs dependency --- cube/schedule/sugraph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index a88d4a73..fb77c58c 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -59,7 +59,9 @@ def reset_dependency(sus: List[ScheduleUnit]): src = sus[src_idx] for dst in sus[src_idx+1:]: # inter-adapter has no dependency - if src.stype in adapters and dst.stype in adapters: + if src.stype in adapters and \ + dst.stype in adapters and \ + src.stype == dst.stype: continue for out_idx, out_tensor in enumerate(src.outputs()): # special dependency for communication adapter From b4a3e79b635d4629dfb5cf9d25a13f2a25bdc6df Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 18:34:23 +0800 Subject: [PATCH 0357/1892] add data parallel policy for attention --- examples/attention/attention.py | 3 +- examples/attention/policy/data_parallel.py | 118 +++++++++++++++++++++ 2 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 examples/attention/policy/data_parallel.py diff --git a/examples/attention/attention.py b/examples/attention/attention.py index 747250eb..fa8b7c96 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -81,7 +81,8 @@ def forward(self, x): # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] output = cube.runtime.function.attn_view(output, self.num_head) - # [L, (N * num_head), dim_head] * [] + # [L, N, num_head * dim_head] * [E, embed_head * dim_head] + # -> [L, N, E] output = F.linear(output, self.weight_out) loss = torch.sum(output) diff --git a/examples/attention/policy/data_parallel.py b/examples/attention/policy/data_parallel.py new file mode 100644 index 00000000..d7819a94 --- /dev/null +++ b/examples/attention/policy/data_parallel.py @@ -0,0 +1,118 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation, IRDataOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using tensor parallel + """ + ndevs = resource.ngpus + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + assert len(fnodes) == 14 + + toqkv = fnodes[0] + q_t = fnodes[1] + k_t = fnodes[2] + v_t = fnodes[3] + q_scale = fnodes[4] + k_t2 = fnodes[5] + qk_bmm = fnodes[6] + mask = fnodes[7] + softmax = fnodes[8] + dropout = fnodes[9] + attnv_bmm = fnodes[10] + attnview = fnodes[11] + linear = fnodes[12] + loss = fnodes[13] + + all_sub_nodes = list() + + algo = toqkv.algorithms('data') + sub_nodes = graph.partition(toqkv, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = q_t.algorithms('dim') + sub_nodes = graph.partition(q_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = k_t.algorithms('dim') + sub_nodes = graph.partition(k_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = v_t.algorithms('dim') + sub_nodes = graph.partition(v_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = q_scale.algorithms('dim') + sub_nodes = graph.partition(q_scale, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = k_t2.algorithms('dim') + sub_nodes = graph.partition(k_t2, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = qk_bmm.algorithms('data') + sub_nodes = graph.partition(qk_bmm, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = mask.algorithms('head') + sub_nodes = graph.partition(mask, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = softmax.algorithms('dim') + sub_nodes = graph.partition(softmax, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = dropout.algorithms('dim') + sub_nodes = graph.partition(dropout, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = attnv_bmm.algorithms('data') + sub_nodes = graph.partition(attnv_bmm, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = attnview.algorithms('data') + sub_nodes = graph.partition(attnview, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = linear.algorithms('data') + sub_nodes = graph.partition(linear, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = loss.algorithms('dim') + sub_nodes = graph.partition(loss, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + # data loader + dataloaders = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] + for data_op in dataloaders: + algo = data_op.algorithms('data') + sub_nodes = graph.partition(data_op, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + for sub_nodes in all_sub_nodes: + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + print(graph) + # assert False + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + sugraph.partial_set_order(fsus, lazy=False) + return sugraph From 4e6130db0cfd94ca058c83585564d8cf9634a444 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 18:50:59 +0800 Subject: [PATCH 0358/1892] data parallel policy: dataloader fix --- examples/attention/attention.py | 8 +++----- examples/attention/policy/data_parallel.py | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/attention/attention.py b/examples/attention/attention.py index fa8b7c96..86c1cbb9 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -17,8 +17,8 @@ import cube -from examples.attention.policy.tensor_parallel import transform_policy -from examples.attention.policy.tensor_parallel import schedule_policy +from examples.attention.policy.data_parallel import transform_policy +from examples.attention.policy.data_parallel import schedule_policy from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank @@ -103,11 +103,9 @@ def train(): seq_len=L, embed_dim=E, heads=num_head, dropout=0.5 ) model = cube.SemanticModel( - # TODO: data parallel batch dim model, input_shapes=([L, N, E],), ) - - # TODO: data parallel batch dim + dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) diff --git a/examples/attention/policy/data_parallel.py b/examples/attention/policy/data_parallel.py index d7819a94..ab8aaf26 100644 --- a/examples/attention/policy/data_parallel.py +++ b/examples/attention/policy/data_parallel.py @@ -108,7 +108,8 @@ def schedule_policy(sugraph: SUGraph, resource): """ for su in sugraph.sus(): if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) + devid = su.tag[0] + sugraph.assign(su, devid) for su in sugraph.fsus(): devid = su.tag[0] sugraph.assign(su, devid) From cfd52aa37c0e2e6ab53953aa6dd60a64a8a1e507 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 19:04:22 +0800 Subject: [PATCH 0359/1892] fix column partition bugs --- cube/algorithm/ops/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/algorithm/ops/linear.py b/cube/algorithm/ops/linear.py index 54132e0f..af9cc23a 100644 --- a/cube/algorithm/ops/linear.py +++ b/cube/algorithm/ops/linear.py @@ -106,7 +106,7 @@ def instantiate(self, node, config: Dict): bs = split_axis(bias, 0, self.chunk_num) else: bs = [None] * self.chunk_num - os = split_axis(output, 1, self.chunk_num) + os = split_axis(output, -1, self.chunk_num) nodes = list() for w, b, o in zip(ws, bs, os): From 6c5e84f17b898af7444ba6218f257ce119e02491 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 19:08:02 +0800 Subject: [PATCH 0360/1892] ffn data & tensor parallel --- examples/ffn/ffn.py | 95 ++++++++++++++++++++++++++ examples/ffn/policy/data_parallel.py | 66 ++++++++++++++++++ examples/ffn/policy/tensor_parallel.py | 57 ++++++++++++++++ 3 files changed, 218 insertions(+) create mode 100644 examples/ffn/ffn.py create mode 100644 examples/ffn/policy/data_parallel.py create mode 100644 examples/ffn/policy/tensor_parallel.py diff --git a/examples/ffn/ffn.py b/examples/ffn/ffn.py new file mode 100644 index 00000000..f432db2b --- /dev/null +++ b/examples/ffn/ffn.py @@ -0,0 +1,95 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/ffn/ffn.py +""" + +import torch +import torch.nn.functional as F +import cube + +from examples.ffn.policy.data_parallel import transform_policy +from examples.ffn.policy.data_parallel import schedule_policy + +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +class FFN(torch.nn.Module): + + def __init__(self, hidden_size: int): + super().__init__() + self.dense_h_to_4h = torch.nn.Linear( + hidden_size, 4 * hidden_size + ) + self.dense_4h_to_h = torch.nn.Linear( + 4 * hidden_size, hidden_size + ) + + def forward(self, hidden_states): + # [L, N, E] * [E, 4E] -> [L, N, 4E] + out = self.dense_h_to_4h(hidden_states) + # [L, N, 4E] -> [L, N, 4E] + out = F.gelu(out) + # [L, N, 4E] * [4E, E] -> [L, N, E] + out = self.dense_4h_to_h(out) + + loss = torch.sum(out) + return loss + + +def train(): + L = 512 # seq len + N = 32 # batch size + # configs: [hidden size, num_head] + # E, num_head = [1536, 16] # 1.2B model + # E, num_head = [1920, 20] # 2.5B model + # E, num_head = [2304, 24] # 4.2B model + E, num_head = [3072, 32] # 8.7B model + + + model = FFN(hidden_size=E) + model = cube.SemanticModel( + model, input_shapes=([L, N, E],), + ) + + dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) + + @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + + +if __name__ == '__main__': + + cube.init() + train() diff --git a/examples/ffn/policy/data_parallel.py b/examples/ffn/policy/data_parallel.py new file mode 100644 index 00000000..6201b4f1 --- /dev/null +++ b/examples/ffn/policy/data_parallel.py @@ -0,0 +1,66 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation, IRDataOperation + + +def transform_policy(graph: IRGraph, resource): + + ndevs = resource.ngpus + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + assert len(fnodes) == 4 + + linear1 = fnodes[0] + gelu = fnodes[1] + linear2 = fnodes[2] + loss = fnodes[3] + + all_sub_nodes = list() + + algo = linear1.algorithms('data') + sub_nodes = graph.partition(linear1, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = gelu.algorithms('dim') + sub_nodes = graph.partition(gelu, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = linear2.algorithms('data') + sub_nodes = graph.partition(linear2, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = loss.algorithms('dim') + sub_nodes = graph.partition(loss, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + # data loader + dataloaders = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] + for data_op in dataloaders: + algo = data_op.algorithms('data') + sub_nodes = graph.partition(data_op, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + for sub_nodes in all_sub_nodes: + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + print(graph) + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + devid = su.tag[0] + sugraph.assign(su, devid) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + sugraph.partial_set_order(fsus, lazy=False) + return sugraph diff --git a/examples/ffn/policy/tensor_parallel.py b/examples/ffn/policy/tensor_parallel.py new file mode 100644 index 00000000..1d93c5b5 --- /dev/null +++ b/examples/ffn/policy/tensor_parallel.py @@ -0,0 +1,57 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + + ndevs = resource.ngpus + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + assert len(fnodes) == 4 + + linear1 = fnodes[0] + gelu = fnodes[1] + linear2 = fnodes[2] + loss = fnodes[3] + + all_sub_nodes = list() + + algo = linear1.algorithms('column') + sub_nodes = graph.partition(linear1, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = gelu.algorithms('dim') + sub_nodes = graph.partition(gelu, algo, config=dict(dim=2, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = linear2.algorithms('row') + sub_nodes = graph.partition(linear2, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + sub_nodes = graph.replicate(loss, times=ndevs) + all_sub_nodes.append(sub_nodes) + + for sub_nodes in all_sub_nodes: + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + print(graph) + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + sugraph.partial_set_order(fsus, lazy=False) + return sugraph From dbe9e6ad2e5db53a4ecf59140ade938e0aa8af5a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 21:46:37 +0800 Subject: [PATCH 0361/1892] add parsing on contruct list --- cube/graph/parser/parser.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index e9cc7e48..facc0d96 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -19,9 +19,10 @@ class ScriptNodeKind(enum.Enum): PrimConstant = 4 AtenOp = 5 # -> the parser may end here PrimIf = 6 # dynamic - PrimListUnpack = 7 - PrimTupleUnpack = 8 - PrimPythonOp = 9 + PrimListConstruct = 7 + PrimListUnpack = 8 + PrimTupleUnpack = 9 + PrimPythonOp = 10 class ScriptModuleParser: @@ -90,6 +91,8 @@ def ntype(node: torch._C.Node): return ScriptNodeKind.PrimIf if node.kind() == 'prim::ListUnpack': return ScriptNodeKind.PrimListUnpack + if node.kind() == 'prim::ListConstruct': + return ScriptNodeKind.PrimListConstruct if node.kind() == 'prim::TupleUnpack': return ScriptNodeKind.PrimTupleUnpack if node.kind() == 'prim::PythonOp': @@ -101,7 +104,11 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] """ Parse the node and return the IRFwOperation nodes """ - node_type = ScriptModuleParser.ntype(node) + try: + node_type = ScriptModuleParser.ntype(node) + except RuntimeError: + print(module.graph) + raise RuntimeError("Unsupported node kind {node.kind()} found in parsing. See above graph.") if node_type == ScriptNodeKind.PrimCallFunction: return ScriptModuleParser.parse_prim_function_node(node, module, frame) if node_type == ScriptNodeKind.AtenOp: @@ -112,6 +119,8 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] return ScriptModuleParser.parse_prim_attr_node(node, module, frame) if node_type == ScriptNodeKind.PrimConstant: return ScriptModuleParser.parse_prim_constant_node(node, module, frame) + if node_type == ScriptNodeKind.PrimListConstruct: + return ScriptModuleParser.parse_prim_list_construct_node(node, module, frame) if node_type == ScriptNodeKind.PrimListUnpack: return ScriptModuleParser.parse_prim_listunpack_node(node, module, frame) if node_type == ScriptNodeKind.PrimTupleUnpack: @@ -328,6 +337,22 @@ def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: """ raise NotImplementedError("Dynamic Graph is not supported yet") + @staticmethod + def parse_prim_list_construct_node(node, module, frame: Frame) -> List[None]: + """ + Parse script module node like + %8 : int[] = prim::ListConstruct(%3) + """ + inputs = [input for input in node.inputs()] + outputs = [output for output in node.outputs()] + assert len(outputs) == 1 + output = outputs[0] + out_val = list() + for input in inputs: + out_val.append(frame.get_var(input.debugName())) + frame.add_var(output.debugName(), out_val) + return list() + @staticmethod def parse_prim_listunpack_node(node, module, frame: Frame) -> List[None]: raise NotImplementedError From 56cd0a61cb4a0b80fdd0ca323f047e8b2f7f533e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 23:44:42 +0800 Subject: [PATCH 0362/1892] add layernorm --- cube/algorithm/factory.py | 3 ++ cube/algorithm/ops/layernorm.py | 63 +++++++++++++++++++++++++++++++++ cube/graph/operator/function.py | 28 +++++++++++++++ 3 files changed, 94 insertions(+) create mode 100644 cube/algorithm/ops/layernorm.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 6c23d07a..6d5bb33e 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -77,6 +77,9 @@ def _load_predefined_algos(self): self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') self.register(elew.Add, elew.AddDimParallel, tag='dim') + import cube.algorithm.ops.layernorm as ln + self.register(ln.LayerNorm, ln.LayerNormDimParallel, tag='dim') + import cube.algorithm.ops.activation as activation self.register(activation.Activation, activation.ActivationDimParallel, tag='dim') self.register(activation.Dropout, activation.DropoutDimParallel, tag='dim') diff --git a/cube/algorithm/ops/layernorm.py b/cube/algorithm/ops/layernorm.py new file mode 100644 index 00000000..51a1f467 --- /dev/null +++ b/cube/algorithm/ops/layernorm.py @@ -0,0 +1,63 @@ +from typing import Dict +import copy + +from cube.algorithm.utils import split_axis +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.operator.function import LayerNorm + + +_kWaitDecision = None + +class LayerNormDimParallel(GenericDistAlgo): + + def __init__(self, node: LayerNorm, dim=None): + if not isinstance(node, LayerNorm): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.ndim = len(node.inputs(0).shape) + last_ndims = len(node.inputs(1)) + self.stay_dims = list() + for dim in range(last_ndims): + self.stay_dims.append(self.ndim - dim - 1) + + self.chunk_num = _kWaitDecision + self.dim = dim + + def satisfy(self, config: Dict): + if 'dim' in config: + dim = config['dim'] + else: + if self.dim is None: + raise RuntimeError("Expected dim in config") + dim = self.dim + if dim < 0: + dim = self.ndim + dim + chunk_num = int(config['chunk_num']) + if dim in self.stay_dims: + return False + shape = self.input_shapes[0] + if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: + return True + return False + + def instantiate(self, node: LayerNorm, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + if 'dim' in config: + self.dim = config['dim'] + + input = node.inputs(0) + sub_inputs = split_axis(input, self.dim, self.chunk_num) + + output = node.outputs(0) + sub_outputs = split_axis(output, self.dim, self.chunk_num) + + nodes = list() + for sub_input, sub_output in zip(sub_inputs, sub_outputs): + inputs = [sub_input] + node.inputs()[1:] + [node.kwargs['eps']] + sub_node = LayerNorm(node.signature, inputs, node.name) + sub_node.set_output(0, sub_output) + nodes.append(sub_node) + return nodes diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 3d3ddce3..8cca15f6 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -127,6 +127,33 @@ def __init__(self, signature, inputs, name='add', **kwargs): alpha = inputs[2] self.kwargs['alpha'] = alpha + +class LayerNorm(IRFwOperation): + + def __init__(self, signature, inputs, name='layernorm', **kwargs): + + if len(inputs) != 5: + raise TypeError(f"Expected 5 inputs, but got: {inputs}") + input = inputs[0] + normalized_shape = inputs[1] + if not isinstance(normalized_shape, list): + raise TypeError(f"Expected list of int, but got: {type(normalized_shape)}") + weight = inputs[2] + bias = inputs[3] + eps = inputs[4] + + inputs = [input, normalized_shape, weight, bias] + super().__init__(name, signature, input_length=4, output_length=1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs['eps'] = eps + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + self.outputs(0).shape = self.inputs(0).shape + + # ============================= Activation ============================ class Activation(IRFwOperation): @@ -188,6 +215,7 @@ def __init__(self, signature, inputs, name='softmax', **kwargs): self.kwargs['dtype'] = dtype self.stay_dims.append(dim) + # ===================== Loss Computation (Reduce) ========================= class Sum(IRFwOperation): From 027dbeabbaf23e6d42491b072763099ef2201c88 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 23:45:25 +0800 Subject: [PATCH 0363/1892] add layernorm test --- tests/algorithm/test_layernorm.py | 80 +++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/algorithm/test_layernorm.py diff --git a/tests/algorithm/test_layernorm.py b/tests/algorithm/test_layernorm.py new file mode 100644 index 00000000..73174808 --- /dev/null +++ b/tests/algorithm/test_layernorm.py @@ -0,0 +1,80 @@ +from cube.graph.operator.function import ElementWise +import cube.algorithm.ops.layernorm as ln +from cube.graph.tensor import IRFullTensor + + +def test_elementwise_dim_parallel(): + + input1 = IRFullTensor(shape=[1024, 512, 256], name='input1').tosub() + normalized_shape = [256,] + weight = IRFullTensor(shape=[256], name='weight').tosub() + bias = IRFullTensor(shape=[256], name='bias').tosub() + eps = 1e-5 + + semantic_op = ln.LayerNorm( + signature='torch.nn.functional.layernorm', + inputs=[input1, normalized_shape, weight, bias, eps], + name='layernorm' + ) + semantic_op.infer_shape() + print('semantic op:') + print(semantic_op) + + op_dim = ln.LayerNormDimParallel(semantic_op) + + assert op_dim.chunk_num is None + + # test satisfy + assert op_dim.satisfy(dict(dim=0, chunk_num = 4)) + assert op_dim.satisfy(dict(dim=1, chunk_num = 8)) + assert not op_dim.satisfy(dict(dim=2, chunk_num = 8)) + + nodes = op_dim.instantiate(semantic_op, dict(dim=1, chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, ln.LayerNorm) + + for node in nodes: + print(node) + print('inputs:') + + input = node.inputs(0) + print(input) + assert input.shape == [1024, 512 // 4, 256] + + weight = node.inputs(2) + print(weight) + assert weight.shape == [256,] + + bias = node.inputs(3) + print(bias) + assert bias.shape == [256,] + + print('outputs:') + for output in node.outputs(): + print(output) + assert output.shape == [1024, 512 // 4, 256] + + op_dim = ln.LayerNormDimParallel(semantic_op, dim=0) + nodes = op_dim.instantiate(semantic_op, dict(chunk_num=4)) + + for node in nodes: + print(node) + print('inputs:') + + input = node.inputs(0) + print(input) + assert input.shape == [1024 // 4, 512, 256] + + weight = node.inputs(2) + print(weight) + assert weight.shape == [256,] + + bias = node.inputs(3) + print(bias) + assert bias.shape == [256,] + + print('outputs:') + for output in node.outputs(): + print(output) + assert input.shape == [1024 // 4, 512, 256] From a4eb8bf72c6f1e67f1a1f8d61637666c9296b938 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 23:47:57 +0800 Subject: [PATCH 0364/1892] parse layernorm --- cube/graph/operator/function.py | 1 + cube/graph/parser/mapping.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 8cca15f6..52bf043b 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -152,6 +152,7 @@ def infer_shape(self): if self.inputs(0).shape is None: return False self.outputs(0).shape = self.inputs(0).shape + return True # ============================= Activation ============================ diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 09c09eac..af0c9bbc 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -43,6 +43,8 @@ def map(signature: str) -> IRFwOperation: __ftemplate('gelu') : partial(function.Activation, name='gelu'), + __ftemplate('layer_norm'): function.LayerNorm, + # torch aten __ttemplate('add') : function.Add, From 14e6bcf713f5d1c9c83ab718f63c1d8dae45e5b0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 21 Nov 2021 23:51:56 +0800 Subject: [PATCH 0365/1892] parse transformer --- examples/transformer/transformer.py | 192 ++++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 examples/transformer/transformer.py diff --git a/examples/transformer/transformer.py b/examples/transformer/transformer.py new file mode 100644 index 00000000..b42dfb05 --- /dev/null +++ b/examples/transformer/transformer.py @@ -0,0 +1,192 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/transformer/transformer.py +""" + +import torch +from torch import nn +import torch.nn.functional as F +import cube + + +from examples.transformer.policy.tensor_parallel import transform_policy +from examples.transformer.policy.tensor_parallel import schedule_policy + +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +class MultiHeadSelfAttention(nn.Module): + + def __init__(self, seq_len, embed_dim, heads, dropout): + super().__init__() + + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_head = heads + self.dim_head = embed_dim // heads + self.scale = self.dim_head ** -0.5 + + self.weight_qkv = torch.nn.Parameter(torch.empty( + 3 * embed_dim, embed_dim + )) + self.weight_out = torch.nn.Parameter(torch.empty( + embed_dim, embed_dim + )) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """ + x: [L, N, E]: seq_len, batch_size, embedding dimension + output: [L, N, E] + """ + # [L, N, E] -> 3 x [L, (N * num_head), dim_head] + q, k, v = cube.runtime.function.toqkv( + x, self.weight_qkv, self.num_head + ) + + # [L, (N * num_head), dim_head] -> [(N * num_head), L, dim_head] + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + # [(N * num_head), L, dim_head] -> [(N * num_head), L, dim_head] + q = q * self.scale + # [(N * num_head), L, dim_head] -> [(N * num_head), dim_head, L] + k = k.transpose(-2, -1) + # [(N * num_head), L, dim_head] * [(N * num_head), dim_head, L] + # -> [(N * num_head), L, L] + attn = torch.bmm(q, k) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = cube.runtime.function.tril_mask(attn, self.num_head) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = F.softmax(attn, dim=-1) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = self.dropout(attn) + # [(N * num_head), L, L] * [(N * num_head), L, dim_head] + # -> [(N * num_head), L, dim_head] + output = torch.bmm(attn, v) + + # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] + output = cube.runtime.function.attn_view(output, self.num_head) + + # [L, N, num_head * dim_head] * [E, embed_head * dim_head] + # -> [L, N, E] + output = F.linear(output, self.weight_out) + return output + + +class FFN(torch.nn.Module): + + def __init__(self, hidden_size: int): + super().__init__() + self.dense_h_to_4h = torch.nn.Linear( + hidden_size, 4 * hidden_size + ) + self.dense_4h_to_h = torch.nn.Linear( + 4 * hidden_size, hidden_size + ) + + def forward(self, hidden_states): + # [L, N, E] * [E, 4E] -> [L, N, 4E] + out = self.dense_h_to_4h(hidden_states) + # [L, N, 4E] -> [L, N, 4E] + out = F.gelu(out) + # [L, N, 4E] * [4E, E] -> [L, N, E] + out = self.dense_4h_to_h(out) + return out + + +class TransformerLayer(torch.nn.Module): + + def __init__(self, seq_len, hidden_size, head_num, dropout): + super().__init__() + # layer norm + self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + + self.attention = MultiHeadSelfAttention(seq_len, hidden_size, head_num, dropout) + self.attn_dropout = torch.nn.Dropout(dropout) + + self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + self.ffn = FFN(hidden_size) + self.ffn_dropout = torch.nn.Dropout(dropout) + + def forward(self, hidden_states): + # Attention + in_attn_norm = self.input_layernorm(hidden_states) + attn_out = self.attention(in_attn_norm) + # residual + attn_out = self.attn_dropout(attn_out) + residual = attn_out + hidden_states + # ffn + in_ffn_norm = self.ffn_layernorm(residual) + ffn_out = self.ffn(in_ffn_norm) + # residual + ffn_out = self.ffn_dropout(ffn_out) + ffn_out = ffn_out + residual + + loss = torch.sum(ffn_out) + return loss + + +def train(): + L = 512 # seq len + N = 32 # batch size + # configs: [hidden size, num_head] + # E, num_head = [1536, 16] # 1.2B model + # E, num_head = [1920, 20] # 2.5B model + # E, num_head = [2304, 24] # 4.2B model + E, num_head = [3072, 32] # 8.7B model + + + model = TransformerLayer( + seq_len=L, hidden_size=E, head_num=num_head, dropout=0.5 + ) + model = cube.SemanticModel( + model, input_shapes=([L, N, E],), + ) + + dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) + + @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + + +if __name__ == '__main__': + + cube.init() + train() From d8629c66ad927de01a776f8a9b7633f45dbfa762 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Nov 2021 16:04:39 +0800 Subject: [PATCH 0366/1892] add torch adapt for multi branch --- cube/execplan/planpass/torchadapt.py | 155 +++++++++++++++++++++++++++ cube/graph/gpass.py | 21 ++-- 2 files changed, 166 insertions(+), 10 deletions(-) create mode 100644 cube/execplan/planpass/torchadapt.py diff --git a/cube/execplan/planpass/torchadapt.py b/cube/execplan/planpass/torchadapt.py new file mode 100644 index 00000000..bbc9833d --- /dev/null +++ b/cube/execplan/planpass/torchadapt.py @@ -0,0 +1,155 @@ +""" +PyTorch Adapter for multi-branch reference + +If a tensor is the input for multiple operators: + + the gradient of this tensor will be value splitted for each op-backward. + +However, in pytorch, the gradient is accumulated by default, this +will cause inconsistent behaviour for transoform SU when the referred +operators are on the same device or not. + +For the situation when the referred operators are on different devices: + Nothing happens + +For the situation when the referred operators are on same device: + The gradient will change to match `auto accumulation` semantics. + For first referred op: grad will be set to ValueMap(idx, num_referred_devices) + For other referred op: grad is set to None +""" + +from typing import Dict, List + +from cube.execplan import ExectuionPlan +from cube.graph.tensor import IRSubTensor, ValueMap +from cube.schedule.adapter.transform import IRTensorTransform +from cube.schedule.su import SUType, ScheduleUnit +from cube.execplan.planpass.planpass import PlanPass + + +class TorchRefAdapter(PlanPass): + + @staticmethod + def apply(execplan: ExectuionPlan): + # same device multiple reference + multiref = TorchRefAdapter.gather_tensor(execplan) + for tid in multiref: + print(f'tensor id: {tid}') + for devid in multiref[tid]: + for fsu in multiref[tid][devid]: + print(f'dev {devid}: {fsu}') + + for tid in multiref: + grad_num = len(multiref[tid]) + for idx, devid in enumerate(multiref[tid]): + # the first forward, the last backward + fsu = multiref[tid][devid][0] + ftensor = None + for input in fsu.inputs(): + if isinstance(input, IRSubTensor): + if input._id == tid: + ftensor = input + break + if ftensor is None: + raise RuntimeError("Internal Error: fsu not found input tensor") + grad = ftensor.parent.grad.select( + indices = ftensor.indices, + val_map = ValueMap(idx, grad_num), + shape = ftensor.shape + ) + rm_grad = TorchRefAdapter.set_grad(fsu, ftensor, grad) + TorchRefAdapter.replace_all(execplan, rm_grad, grad) + + # all the other reference place: set grad to none + for fsu in multiref[tid][devid][1:]: + rm_grad = TorchRefAdapter.set_grad(fsu, ftensor, grad=None) + TorchRefAdapter.replace_all(execplan, rm_grad, None) + + # reset select and merge adapters + for devid in execplan.devices(): + for idx, su in enumerate(execplan.sequence(devid)): + if su.stype == SUType.Transform: + ins = [input for input in su.inputs() if input is not None] + ous = [ou for ou in su.outputs() if ou is not None] + if len(ins) < len(su.inputs()) or len(ous) < len(su.outputs()): + for ou in ous: + if ou in ins: + break + trans = IRTensorTransform( + src_tensors=ins, dst_tensors=ous + ) + trans_su = ScheduleUnit([trans], SUType.Transform, name='trans') + trans_su.device = devid + if len(trans_su.outputs()) == 0: + # meaning outputs in inputs + execplan.at(devid).remove(su) + else: + execplan.at(devid)[idx] = trans_su + return execplan + + @staticmethod + def gather_tensor(execplan: ExectuionPlan) -> Dict: + """ + Return: + { + sub_tensor id: + device id: + [forward su] + } + """ + fwsus = dict() + for devid in execplan.devices(): + for fsu in execplan.sequence(devid): + if fsu.stype == SUType.Forward: + for input in fsu.inputs(): + if isinstance(input, IRSubTensor): + tid = input._id + if tid not in fwsus: + fwsus[tid] = dict() + if devid not in fwsus[tid]: + fwsus[tid][devid] = list() + fwsus[tid][devid].append(fsu) + multiref = dict() + for tid in fwsus: + for devid in fwsus[tid]: + if len(fwsus[tid][devid]) != 1: + multiref[tid] = fwsus[tid] + break + return multiref + + @staticmethod + def set_grad(fsu: ScheduleUnit, input: IRSubTensor, grad): + """ + Return removed grad + """ + if not isinstance(fsu, ScheduleUnit) or fsu.stype != SUType.Forward: + raise TypeError("Require SU to be forward SU") + # forward SU + findex = fsu.inputs().index(input) + fsu.inputs(findex).grad = grad + # backward SU + bsu = fsu.mirror + bindex = bsu.inputs().index(input) + bin = bsu.inputs(bindex) + gindex = bsu.outputs().index(bin.grad) + removed_grad = bin.grad + bin.grad = grad + bsu.set_output(gindex, grad) + return removed_grad + + @staticmethod + def replace_all(execplan: ExectuionPlan, src: IRSubTensor, dst): + for devid in execplan.devices(): + for su in execplan.sequence(devid): + if src in su.inputs(): + if len(su.inputs()) == 1: + execplan.at(devid).remove(su) + else: + index = su.inputs().index(src) + su.set_input(index, dst) + if src in su.outputs(): + if len(su.outputs()) == 1: + execplan.at(devid).remove(su) + else: + index = su.outputs().index(src) + su.set_output(index, dst) diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index ae4bcf14..f727a318 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -2,7 +2,7 @@ import copy from cube.graph.graph import IRGraph -from cube.graph.tensor import IRSubTensor +from cube.graph.tensor import IRSubTensor, ValueMap from cube.graph.operator import IRBpOperation from cube.ir.cten import IRCell, IRTensor @@ -61,11 +61,11 @@ def forward(graph, *args) -> IRGraph: fnodes = list() bnodes = list() + + # generate forward nodes for node in graph.nodes(): inputs = node.inputs() outputs = node.outputs() - - # forwrd node # fnode = copy.copy(node) fnode = node fnode._inputs = inputs @@ -76,8 +76,13 @@ def forward(graph, *args) -> IRGraph: # set forward outputs for idx, val in enumerate(outputs): fnode.set_output(idx, gener.renew(val)) - - # backward node + fnodes.append(fnode) + fnode.device = node.device + + # generate backward nodes + for fnode in fnodes: + inputs = fnode.inputs() + outputs = fnode.outputs() bnode = IRBpOperation(data_num=len(inputs), grad_num=len(outputs)) # set backward grad for idx, val in enumerate(fnode.inputs()): @@ -98,14 +103,10 @@ def forward(graph, *args) -> IRGraph: grad = val.get_grad(fnode) val.grad = grad bnode.set_grad(idx, grad) - - fnode.device = node.device bnode.device = node.device - + # mirror node for forward / backward IRCell.make_pair(fnode, bnode) - - fnodes.append(fnode) bnodes.append(bnode) inputs = [gener.renew(input) for input in graph.inputs()] From d373b71871056a33957f0b3c75972fe393824939 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Nov 2021 16:14:04 +0800 Subject: [PATCH 0367/1892] sugraph dependency reset for torch adapter --- cube/execplan/planpass/torchadapt.py | 6 ++++++ cube/schedule/sugraph.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/cube/execplan/planpass/torchadapt.py b/cube/execplan/planpass/torchadapt.py index bbc9833d..edf8cf02 100644 --- a/cube/execplan/planpass/torchadapt.py +++ b/cube/execplan/planpass/torchadapt.py @@ -83,8 +83,12 @@ def apply(execplan: ExectuionPlan): if len(trans_su.outputs()) == 0: # meaning outputs in inputs execplan.at(devid).remove(su) + execplan.sugraph.sequence.remove(su) else: execplan.at(devid)[idx] = trans_su + suidx = execplan.sugraph.sequence.index(su) + execplan.sugraph.sequence[suidx] = trans_su + execplan.sugraph.reset_dependency(execplan.sugraph.sus()) return execplan @staticmethod @@ -144,12 +148,14 @@ def replace_all(execplan: ExectuionPlan, src: IRSubTensor, dst): if src in su.inputs(): if len(su.inputs()) == 1: execplan.at(devid).remove(su) + execplan.sugraph.sequence.remove(su) else: index = su.inputs().index(src) su.set_input(index, dst) if src in su.outputs(): if len(su.outputs()) == 1: execplan.at(devid).remove(su) + execplan.sugraph.sequence.remove(su) else: index = su.outputs().index(src) su.set_output(index, dst) diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index fb77c58c..5a61a5c6 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -1,5 +1,6 @@ from typing import List, Optional, Union import copy +from cube.graph.tensor import IRSubTensor from cube.ir.cten import IRCell, IRTensor from cube.graph.operator import IRBpOperation @@ -64,6 +65,8 @@ def reset_dependency(sus: List[ScheduleUnit]): src.stype == dst.stype: continue for out_idx, out_tensor in enumerate(src.outputs()): + if not isinstance(out_tensor, IRTensor): + continue # special dependency for communication adapter if dst.stype == SUType.P2P: for recv_tensor in dst.outputs(): From b0e8a5570c3495a8b77394e0d3e3e076b18263d0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Nov 2021 20:43:58 +0800 Subject: [PATCH 0368/1892] fix command --- examples/attention/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/attention/attention.py b/examples/attention/attention.py index 86c1cbb9..c68409b3 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -2,13 +2,13 @@ example: python -m torch.distributed.launch \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/attention.py + examples/attention/attention.py """ import torch From 8b4faf570524ec82d4ba155c75b9f72bb1b09309 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Nov 2021 20:44:42 +0800 Subject: [PATCH 0369/1892] fix residual --- cube/schedule/su.py | 62 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/cube/schedule/su.py b/cube/schedule/su.py index 19b2a644..ab80f6e9 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -3,6 +3,7 @@ from enum import Enum from cube.ir.cten import IRCell, IRTensor +from cube.graph.operator import IRFwOperation class SUType(Enum): @@ -49,10 +50,9 @@ def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): raise TypeError("Expected stype be SUType") # get inputs and outputs - # TODO: fix bug on multi-branch - inputs = IRCell.get_inputs(nodes) + inputs = ScheduleUnit.get_inputs(nodes) # inputs = [input for input in inputs if not input.is_param()] - outputs = IRCell.get_outputs(nodes) + outputs = ScheduleUnit.get_outputs(nodes) # outputs = [output for output in outputs if not output.is_param()] super().__init__( name = name, @@ -422,6 +422,62 @@ def successors(self, index: Optional[int] = None) -> List: else: raise TypeError("Expected index to be None or int") + @staticmethod + def get_inputs(nodes): + """ + Get all the input tensors the is not generated by nodes + + Inputs + + Returns: + List[IRTensor] + """ + all_outputs = list() + for cell in nodes: + all_outputs += cell.outputs() + inputs = list() + for cell in nodes: + for input in cell.inputs(): + if isinstance(input, IRTensor): + if input not in all_outputs: + if input not in inputs: + inputs.append(input) + return inputs + + @staticmethod + def get_outputs(nodes: List[IRCell]): + """ + Get all the input tensors the is not generated by nodes + + Args: + This will also consider the successor forward nodes. + If it is required by other outside forward nodes, + put in the outputs list + + Returns: + List[IRTensor] + """ + all_inputs = list() + for node in nodes: + all_inputs += node.inputs() + outputs = list() + for node in nodes: + for idx, output in enumerate(node.outputs()): + if isinstance(output, IRTensor): + if output not in all_inputs: + if output not in outputs: + outputs.append(output) + continue + succs = node.successors(idx) + fsuccs = [ + fnode for fnode in succs if isinstance(fnode, IRFwOperation) + ] + for fsucc in fsuccs: + if fsucc not in nodes: + if output not in outputs: + outputs.append(output) + return outputs + def __repr__(self): su_inputs = list() for tensor in self.inputs(): From 0a29b7ae680c058219e20f090a3ca62fd4bb1ecd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Nov 2021 20:45:17 +0800 Subject: [PATCH 0370/1892] fix torchadapt --- cube/execplan/planpass/torchadapt.py | 151 +++++++++++++++++++-------- 1 file changed, 105 insertions(+), 46 deletions(-) diff --git a/cube/execplan/planpass/torchadapt.py b/cube/execplan/planpass/torchadapt.py index edf8cf02..8c467d5d 100644 --- a/cube/execplan/planpass/torchadapt.py +++ b/cube/execplan/planpass/torchadapt.py @@ -18,7 +18,7 @@ For other referred op: grad is set to None """ -from typing import Dict, List +from typing import Dict from cube.execplan import ExectuionPlan from cube.graph.tensor import IRSubTensor, ValueMap @@ -32,18 +32,43 @@ class TorchRefAdapter(PlanPass): @staticmethod def apply(execplan: ExectuionPlan): # same device multiple reference - multiref = TorchRefAdapter.gather_tensor(execplan) - for tid in multiref: - print(f'tensor id: {tid}') - for devid in multiref[tid]: - for fsu in multiref[tid][devid]: + multiref_fsus, multiref_fnodes = TorchRefAdapter.multi_ref_cells(execplan) + for tid in multiref_fsus: + print(f'multi-referred tensor id: {tid}') + for devid in multiref_fsus[tid]: + for fsu in multiref_fsus[tid][devid]: print(f'dev {devid}: {fsu}') - for tid in multiref: - grad_num = len(multiref[tid]) - for idx, devid in enumerate(multiref[tid]): + + for tid in multiref_fsus: + # check chunk num for each device + total_ops = set() + for devid in multiref_fnodes[tid]: + for op in multiref_fnodes[tid][devid]: + total_ops.add(op._id) + total_ops = list(total_ops) + num_ops = len(total_ops) + # how many ops are computed for each device + dev_ops = dict() + for devid in multiref_fnodes[tid]: + op_index = list() + for op in multiref_fnodes[tid][devid]: + op_index.append(total_ops.index(op._id)) + cnt = len(op_index) + if cnt != 1 and cnt != num_ops: + raise NotImplementedError("Only support even chunk for multi-ref") + dev_ops[devid] = op_index + + for idx, devid in enumerate(multiref_fsus[tid]): + # the value map should be op_num / total_ops + op_index = dev_ops[devid] + if len(op_index) == num_ops: + grad_idx, grad_num = 0, 1 + elif len(op_index) == 1: + grad_idx, grad_num = op_index[0], num_ops + # the first forward, the last backward - fsu = multiref[tid][devid][0] + fsu = multiref_fsus[tid][devid][0] ftensor = None for input in fsu.inputs(): if isinstance(input, IRSubTensor): @@ -54,16 +79,18 @@ def apply(execplan: ExectuionPlan): raise RuntimeError("Internal Error: fsu not found input tensor") grad = ftensor.parent.grad.select( indices = ftensor.indices, - val_map = ValueMap(idx, grad_num), + val_map = ValueMap(grad_idx, grad_num), shape = ftensor.shape ) rm_grad = TorchRefAdapter.set_grad(fsu, ftensor, grad) - TorchRefAdapter.replace_all(execplan, rm_grad, grad) + TorchRefAdapter.replace_all(execplan, rm_grad, grad, devid) # all the other reference place: set grad to none - for fsu in multiref[tid][devid][1:]: + for fsu in multiref_fsus[tid][devid][1:]: rm_grad = TorchRefAdapter.set_grad(fsu, ftensor, grad=None) - TorchRefAdapter.replace_all(execplan, rm_grad, None) + TorchRefAdapter.replace_all(execplan, rm_grad, None, devid) + + print(execplan) # reset select and merge adapters for devid in execplan.devices(): @@ -92,34 +119,43 @@ def apply(execplan: ExectuionPlan): return execplan @staticmethod - def gather_tensor(execplan: ExectuionPlan) -> Dict: + def multi_ref_cells(execplan: ExectuionPlan) -> Dict: """ Return: { sub_tensor id: device id: - [forward su] + [forward su or forward node] } """ - fwsus = dict() + fnodes = dict() + fsus = dict() for devid in execplan.devices(): for fsu in execplan.sequence(devid): if fsu.stype == SUType.Forward: for input in fsu.inputs(): if isinstance(input, IRSubTensor): tid = input._id - if tid not in fwsus: - fwsus[tid] = dict() - if devid not in fwsus[tid]: - fwsus[tid][devid] = list() - fwsus[tid][devid].append(fsu) - multiref = dict() - for tid in fwsus: - for devid in fwsus[tid]: - if len(fwsus[tid][devid]) != 1: - multiref[tid] = fwsus[tid] + if tid not in fnodes: + fnodes[tid] = dict() + fsus[tid] = dict() + if devid not in fnodes[tid]: + fnodes[tid][devid] = list() + fsus[tid][devid] = list() + fsus[tid][devid].append(fsu) + for node in fsu.nodes(): + if input in node.inputs(): + fnodes[tid][devid].append(node) + multiref_fnodes = dict() + multiref_sus = dict() + for tid in fnodes: + for devid in fnodes[tid]: + if len(fnodes[tid][devid]) != 1: + multiref_sus[tid] = fnodes[tid] + multiref_fnodes[tid] = fsus[tid] break - return multiref + return multiref_fnodes, multiref_sus + @staticmethod def set_grad(fsu: ScheduleUnit, input: IRSubTensor, grad): @@ -131,31 +167,54 @@ def set_grad(fsu: ScheduleUnit, input: IRSubTensor, grad): # forward SU findex = fsu.inputs().index(input) fsu.inputs(findex).grad = grad + if not len(fsu.nodes()) == 1: + raise RuntimeError("TorchAdapt should call before merge") + fnode = fsu.nodes(0) + findex = fnode.inputs().index(input) + fnode.inputs(findex).grad = grad # backward SU bsu = fsu.mirror bindex = bsu.inputs().index(input) bin = bsu.inputs(bindex) - gindex = bsu.outputs().index(bin.grad) + try: + gindex = bsu.outputs().index(bin.grad) + except ValueError: + raise RuntimeError( + (f"Internal Error: cannot find given grad in bsu: {bsu}:\n" + f"gradient given tensor: {bin}, grad: {bin.grad}") + ) removed_grad = bin.grad bin.grad = grad bsu.set_output(gindex, grad) return removed_grad @staticmethod - def replace_all(execplan: ExectuionPlan, src: IRSubTensor, dst): - for devid in execplan.devices(): - for su in execplan.sequence(devid): - if src in su.inputs(): - if len(su.inputs()) == 1: - execplan.at(devid).remove(su) - execplan.sugraph.sequence.remove(su) - else: - index = su.inputs().index(src) - su.set_input(index, dst) - if src in su.outputs(): - if len(su.outputs()) == 1: - execplan.at(devid).remove(su) - execplan.sugraph.sequence.remove(su) - else: - index = su.outputs().index(src) - su.set_output(index, dst) + def replace_all(execplan: ExectuionPlan, src: IRSubTensor, dst, devid: int): + for su in execplan.sequence(devid): + # pair removement for p2p will already remove su + if su not in execplan.at(devid): + continue + rm_su = None + if src in su.inputs(): + if len(su.inputs()) == 1 and dst is None: + execplan.at(devid).remove(su) + execplan.sugraph.sequence.remove(su) + rm_su = su + else: + index = su.inputs().index(src) + su.set_input(index, dst) + if src in su.outputs(): + if len(su.outputs()) == 1 and dst is None: + execplan.at(devid).remove(su) + execplan.sugraph.sequence.remove(su) + rm_su = su + else: + index = su.outputs().index(src) + su.set_output(index, dst) + # pair removement + if rm_su is not None and rm_su.stype == SUType.P2P: + mirror = rm_su.mirror + dev = mirror.device[0] + if mirror in execplan.at(dev): + execplan.at(dev).remove(mirror) + execplan.sugraph.sequence.remove(mirror) From b5983e9304914d0e310bef10b8cd4c856ee60bca Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Nov 2021 20:46:48 +0800 Subject: [PATCH 0371/1892] fix bugs on single device fusion --- cube/execplan/planpass/p2pfusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index c6924169..f3c0a6bb 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -46,6 +46,9 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: if pid not in ous: continue tous, tins = ous[pid], ins[pid] + # if they are on the single device, matching is skipped + if len(tous) == 1 and set(tous.keys()) == set(tins.keys()): + continue if P2PFusion.have_comm(tous, tins): colls : List[ScheduleUnit] = None for matcher in matchers: From d5276730a66b208621c1013652c1aa2d249091bb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Nov 2021 20:47:33 +0800 Subject: [PATCH 0372/1892] fix bugs on multiple tensor backward --- cube/runtime/executor.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 44d00e85..b712ac2f 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -28,9 +28,8 @@ def backward(input_tensors, output_tensors, output_tensor_grads): "Expected same length of out tensors and grads" ) - for tensor, grads in zip(output_tensors, output_tensor_grads): - # print('backwarding... ') - torch.autograd.backward(tensor, grad_tensors=grads) + # print('backwarding... ') + torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) grads = list() for tensor in input_tensors: # print('backward input tensor: {}'.format(tensor)) From 6ae4d81052aec77e31250667d94fd92107e62d3c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 09:12:56 +0800 Subject: [PATCH 0373/1892] name typo --- cube/compiler.py | 4 ++++ cube/execplan/planpass/p2pfusion.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cube/compiler.py b/cube/compiler.py index be90b3cb..e575f925 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -10,6 +10,7 @@ from cube.schedule.sugraph import SUGraph, SUGraphGener from cube.execplan import ExectuionPlan +from cube.execplan.planpass.torchadapt import TorchRefAdapter from cube.execplan.planpass.redundant import RemoveRedundantAdapters from cube.execplan.planpass.merge import MergeComputeSU from cube.execplan.planpass.gfuse import WeightGradAllreduceFusion @@ -145,6 +146,9 @@ def decorator(fn: Callable) -> Callable: raise RuntimeError(f"SUGraph order is not topological order") execplan = ExectuionPlan(sugraph) + # plan pass to adapt to pytorch semantic: multi branch gradient + # TODO: residual support + # execplan = TorchRefAdapter.apply(execplan) # plan pass to remove redundant sus execplan = RemoveRedundantAdapters.apply(execplan) # print(f'> after remove redundant adapters:\n {execplan}') diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index f3c0a6bb..e88070d1 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -228,7 +228,7 @@ def match_allreduce(tous, tins): else: raise RuntimeError("Internal Error") op = IRCollectives(input, outputs, ranks, IRCollType.AllReduce) - su = ScheduleUnit([op], SUType.Coll, name='allgather') + su = ScheduleUnit([op], SUType.Coll, name='allreduce') su.device = rank allreduce_sus.append(su) From 0ab62b32912897bb8aa40cde700fd51ef371fad1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 10:44:00 +0800 Subject: [PATCH 0374/1892] change backward to use autograd.grad --- cube/codegen/codegen.py | 16 +++++++++----- cube/runtime/executor.py | 46 ++++++++++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 2216e2c5..924c9d98 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -7,6 +7,7 @@ from cube.graph.operator.operator import IRFwOperation, IROptimOperation from cube.ir.cten import IRTensor +from cube.graph.tensor import ValueMap from cube.execplan import ExectuionPlan from cube.schedule.adapter.collectives import IRCollectives @@ -437,6 +438,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: for tensor in fsu.inputs(): if isinstance(tensor, IRTensor): if tensor.is_param(): + finputs.append('model.' + self.tensor_naming(tensor)) continue finputs.append(self.tensor_naming(tensor)) fargs = '(' + ', '.join(finputs + ['']) + ')' @@ -449,9 +451,7 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: fout_grads = list() for fout in fsu.outputs(): # the loss computed starting point - if fout == fout.grad: - #TODO: mean<0, N> needs to divide by N times - pass + # if fout == fout.grad: fout_grads.append(self.tensor_naming(fout.grad)) fout_grads = '(' + ', '.join(fout_grads + ['']) + ')' @@ -462,9 +462,15 @@ def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: ) # returned value are graph.outputs - return_val = [self.tensor_naming(tensor) for tensor in su.outputs()] + return_val = list() + for input in fsu.inputs(): + if isinstance(input, IRTensor): + return_val.append(self.tensor_naming(input.grad)) + else: + return_val.append(None) + # return_val = [self.tensor_naming(tensor.grad) for tensor in finputs] # TODO: fix this by using grad attributed - return_val = return_val[:len(finputs)] + # return_val = return_val[:len(finputs)] if len(return_val) > 0: return_code = ', '.join(return_val) + ' = ' else: diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index b712ac2f..c613ba93 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -2,7 +2,7 @@ SU Executor for runtime """ -from typing import Tuple, Any, Callable +from typing import Tuple, Any, Callable, List import torch @@ -15,7 +15,7 @@ def fexecute(su: Callable, *input_tensors: Tuple[Any]): return outputs -def backward(input_tensors, output_tensors, output_tensor_grads): +def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_grads): """ Backward the SUs """ @@ -28,15 +28,39 @@ def backward(input_tensors, output_tensors, output_tensor_grads): "Expected same length of out tensors and grads" ) - # print('backwarding... ') - torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) - grads = list() - for tensor in input_tensors: - # print('backward input tensor: {}'.format(tensor)) - if torch.is_tensor(tensor) and tensor.requires_grad: - grads.append(tensor.grad) - else: - grads.append(None) + inputs = list() + indices = list() + for idx, input in enumerate(input_tensors): + if torch.is_tensor(input) and input.requires_grad: + inputs.append(input) + indices.append(idx) + + grads = [None] * len(input_tensors) + if len(inputs) != 0: + # print('backwarding... ') + in_grads = torch.autograd.grad(output_tensors, inputs, output_tensor_grads) + for idx, grad in zip(indices, in_grads): + if input_tensors[idx].is_leaf: + input_tensors[idx].grad = grad + grads[idx] = grad + + # if len(inputs) != 0: + # torch.autograd.backward( + # output_tensors, + # grad_tensors=output_tensor_grads, + # inputs=inputs + # ) + # for idx, tensor in zip(indices, inputs): + # grads[idx] = tensor.grad + + # torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) + # grads = list() + # for tensor in input_tensors: + # # print('backward input tensor: {}'.format(tensor)) + # if torch.is_tensor(tensor) and tensor.requires_grad: + # grads.append(tensor.grad) + # else: + # grads.append(None) if len(grads) == 0: return None elif len(grads) == 1: return grads[0] else: return tuple(grads) From fed03201f259419bf3568f2ab16db48541fd17d3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 10:44:31 +0800 Subject: [PATCH 0375/1892] fix fuse bug for replicating --- cube/execplan/planpass/gfuse.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/cube/execplan/planpass/gfuse.py b/cube/execplan/planpass/gfuse.py index d324ec24..3adced5f 100644 --- a/cube/execplan/planpass/gfuse.py +++ b/cube/execplan/planpass/gfuse.py @@ -7,7 +7,7 @@ from cube.graph.operator.operator import IROptimOperation -from cube.graph.tensor import IRSubTensor +from cube.graph.tensor import IRSubTensor, ValueMap from cube.execplan import ExectuionPlan from cube.schedule.su import SUType, ScheduleUnit @@ -54,12 +54,12 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: @staticmethod def _get_weight_grads(execplan: ExectuionPlan) -> Dict: """ - Get weight gradient + Get weight and gradient - Return Dict[IRSubTensor, Dict[int, List[IRSubTensor]]] - (grads = params[param][device]) + weights: Dict[param_id: int, IRSubTensor] + grads : Dict[param_id: int, Dict[device: int, List[grad: IRSubTensor]]] + """ - # grad = params[param][device] grads = dict() weights = dict() for devid in execplan.devices(): @@ -68,16 +68,19 @@ def _get_weight_grads(execplan: ExectuionPlan) -> Dict: # bsu has only one node for input in bsu.inputs(): if isinstance(input, IRSubTensor) and input.is_param(): - if input._id not in grads: - grads[input._id] = dict() - weights[input._id] = input - if devid not in grads[input._id]: - grads[input._id][devid] = list() grad = input.grad if grad is None: print(input.name, input) print(grad) assert grad is not None + # nothing to sync + if grad.val_map == ValueMap(0, 1): + continue + if input._id not in grads: + grads[input._id] = dict() + weights[input._id] = input + if devid not in grads[input._id]: + grads[input._id][devid] = list() if grad in grads[input._id][devid]: raise RuntimeError("Already logged grad?") grads[input._id][devid].append(grad) From ff99a18130230b3b3828015b1b57ed081d172ee9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 10:44:56 +0800 Subject: [PATCH 0376/1892] transformer test --- examples/transformer/policy/no_parallel.py | 23 +++ .../transformer/policy/tensor_parallel.py | 161 ++++++++++++++++++ examples/transformer/transformer.py | 6 +- 3 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 examples/transformer/policy/no_parallel.py create mode 100644 examples/transformer/policy/tensor_parallel.py diff --git a/examples/transformer/policy/no_parallel.py b/examples/transformer/policy/no_parallel.py new file mode 100644 index 00000000..702c02b7 --- /dev/null +++ b/examples/transformer/policy/no_parallel.py @@ -0,0 +1,23 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using column parallel + """ + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + for su in sugraph.fsus(): + sugraph.assign(su, 0) + sugraph.assign(su.mirror, 0) + return sugraph diff --git a/examples/transformer/policy/tensor_parallel.py b/examples/transformer/policy/tensor_parallel.py new file mode 100644 index 00000000..0cedcb4c --- /dev/null +++ b/examples/transformer/policy/tensor_parallel.py @@ -0,0 +1,161 @@ +from cube.graph import IRGraph +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using tensor parallel + """ + print('> transforming graph...') + ndevs = resource.ngpus + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + assert len(fnodes) == 23 + + attn_ln = fnodes[0] + + toqkv = fnodes[1] + q_t = fnodes[2] + k_t = fnodes[3] + v_t = fnodes[4] + q_scale = fnodes[5] + k_t2 = fnodes[6] + qk_bmm = fnodes[7] + mask = fnodes[8] + softmax = fnodes[9] + attn_dropout = fnodes[10] + attnv_bmm = fnodes[11] + attnview = fnodes[12] + linear = fnodes[13] + + attn_post_dropout = fnodes[14] + attn_residual = fnodes[15] + + ffn_ln = fnodes[16] + ffn_linear1 = fnodes[17] + ffn_gelu = fnodes[18] + ffn_linear2 = fnodes[19] + + ffn_post_dropout = fnodes[20] + ffn_post_residual = fnodes[21] + + loss = fnodes[22] + + + all_sub_nodes = list() + + # ============== attention ============ + sub_nodes = graph.replicate(attn_ln, times=resource.ngpus) + all_sub_nodes.append(sub_nodes) + + algo = toqkv.algorithms('head') + sub_nodes = graph.partition(toqkv, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = q_t.algorithms('dim') + sub_nodes = graph.partition(q_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = k_t.algorithms('dim') + sub_nodes = graph.partition(k_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = v_t.algorithms('dim') + sub_nodes = graph.partition(v_t, algo, config=dict(dim=1, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = q_scale.algorithms('dim') + sub_nodes = graph.partition(q_scale, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = k_t2.algorithms('dim') + sub_nodes = graph.partition(k_t2, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = qk_bmm.algorithms('data') + sub_nodes = graph.partition(qk_bmm, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = mask.algorithms('head') + sub_nodes = graph.partition(mask, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = softmax.algorithms('dim') + sub_nodes = graph.partition(softmax, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = attn_dropout.algorithms('dim') + sub_nodes = graph.partition(attn_dropout, algo, config=dict(dim=0, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = attnv_bmm.algorithms('data') + sub_nodes = graph.partition(attnv_bmm, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = attnview.algorithms('head') + sub_nodes = graph.partition(attnview, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = linear.algorithms('row') + sub_nodes = graph.partition(linear, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + # ========== between attention and mlp =============== + sub_nodes = graph.replicate(attn_post_dropout, times=resource.ngpus) + all_sub_nodes.append(sub_nodes) + + sub_nodes = graph.replicate(attn_residual, times=resource.ngpus) + all_sub_nodes.append(sub_nodes) + + sub_nodes = graph.replicate(ffn_ln, times=resource.ngpus) + all_sub_nodes.append(sub_nodes) + + # =========== mlp =========== + algo = ffn_linear1.algorithms('column') + sub_nodes = graph.partition(ffn_linear1, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = ffn_gelu.algorithms('dim') + sub_nodes = graph.partition(ffn_gelu, algo, config=dict(dim=2, chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + algo = ffn_linear2.algorithms('row') + sub_nodes = graph.partition(ffn_linear2, algo, config=dict(chunk_num=ndevs)) + all_sub_nodes.append(sub_nodes) + + # ========== post mlp ======== + sub_nodes = graph.replicate(ffn_post_dropout, times=resource.ngpus) + all_sub_nodes.append(sub_nodes) + + sub_nodes = graph.replicate(ffn_post_residual, times=resource.ngpus) + all_sub_nodes.append(sub_nodes) + + # =========== loss =========== + sub_nodes = graph.replicate(loss, times=ndevs) + all_sub_nodes.append(sub_nodes) + + for sub_nodes in all_sub_nodes: + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + print(graph) + # assert False + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + sugraph.partial_set_order(fsus, lazy=False) + return sugraph diff --git a/examples/transformer/transformer.py b/examples/transformer/transformer.py index b42dfb05..b856da03 100644 --- a/examples/transformer/transformer.py +++ b/examples/transformer/transformer.py @@ -128,13 +128,15 @@ def forward(self, hidden_states): attn_out = self.attention(in_attn_norm) # residual attn_out = self.attn_dropout(attn_out) - residual = attn_out + hidden_states + # residual = attn_out + hidden_states + residual = attn_out * 2 # ffn in_ffn_norm = self.ffn_layernorm(residual) ffn_out = self.ffn(in_ffn_norm) # residual ffn_out = self.ffn_dropout(ffn_out) - ffn_out = ffn_out + residual + # ffn_out = ffn_out + residual + ffn_out = ffn_out * 2 loss = torch.sum(ffn_out) return loss From b5da6a770bca9e8c82d76ab5176c755635d96c81 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 11:28:36 +0800 Subject: [PATCH 0377/1892] add benchmark transformer --- benchmark/megatron/transformer.py | 207 ++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 benchmark/megatron/transformer.py diff --git a/benchmark/megatron/transformer.py b/benchmark/megatron/transformer.py new file mode 100644 index 00000000..89d391b2 --- /dev/null +++ b/benchmark/megatron/transformer.py @@ -0,0 +1,207 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + benchmark/megatron/transformer.py +""" + +import torch +from torch import nn +import torch.nn.functional as F +import cube +from benchmark.megatron.layers import ColumnParallelLinear, RowParallelLinear + + +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +class MultiHeadSelfAttention(nn.Module): + + def __init__(self, seq_len, embed_dim, heads, dropout): + super().__init__() + + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_head = heads + self.dim_head = embed_dim // heads + self.scale = self.dim_head ** -0.5 + + self.world_size = torch.distributed.get_world_size() + + self.toqkv = ColumnParallelLinear( + embed_dim, 3 * embed_dim, bias=False, + full_input=True, full_output=False + ) + self.out = RowParallelLinear( + embed_dim, embed_dim, bias=False, + full_input=False, full_output=True + ) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """ + x: [L, N, E]: seq_len, batch_size, embedding dimension + output: [L, N, E] + """ + bs = x.shape[1] + # [L, N, E] -> [L, N, (3 * num_heads * dim_head)] + qkv = self.toqkv(x) + # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 + qkv = qkv.chunk(3, dim=-1) + q, k, v = qkv + q = q.contiguous() + q = q.view(self.seq_len, (bs * self.num_head // self.world_size), self.dim_head) + k = k.contiguous() + k = k.view(self.seq_len, (bs * self.num_head // self.world_size), self.dim_head) + v = v.contiguous() + v = v.view(self.seq_len, (bs * self.num_head // self.world_size), self.dim_head) + + # [L, (N * num_head), dim_head] -> [(N * num_head), L, dim_head] + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + # [(N * num_head), L, dim_head] -> [(N * num_head), L, dim_head] + q = q * self.scale + # [(N * num_head), L, dim_head] -> [(N * num_head), dim_head, L] + k = k.transpose(-2, -1) + # [(N * num_head), L, dim_head] * [(N * num_head), dim_head, L] + # -> [(N * num_head), L, L] + attn = torch.bmm(q, k) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = cube.runtime.function.tril_mask( + attn, self.num_head // self.world_size + ) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = F.softmax(attn, dim=-1) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = self.dropout(attn) + # [(N * num_head), L, L] * [(N * num_head), L, dim_head] + # -> [(N * num_head), L, dim_head] + output = torch.bmm(attn, v) + + # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] + output = cube.runtime.function.attn_view( + output, self.num_head // self.world_size + ) + + # [L, N, num_head * dim_head] * [E, embed_head * dim_head] + # -> [L, N, E] + output = self.out(output) + return output + + +class FFN(torch.nn.Module): + + def __init__(self, hidden_size: int): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, 4 * hidden_size, + full_input=True, full_output=False + ) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, hidden_size, + full_input=False, full_output=True + ) + + def forward(self, hidden_states): + # [L, N, E] * [E, 4E] -> [L, N, 4E] + out = self.dense_h_to_4h(hidden_states) + # [L, N, 4E] -> [L, N, 4E] + out = F.gelu(out) + # [L, N, 4E] * [4E, E] -> [L, N, E] + out = self.dense_4h_to_h(out) + return out + + +class TransformerLayer(torch.nn.Module): + + def __init__(self, seq_len, hidden_size, head_num, dropout): + super().__init__() + # layer norm + self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + + self.attention = MultiHeadSelfAttention(seq_len, hidden_size, head_num, dropout) + self.attn_dropout = torch.nn.Dropout(dropout) + + self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + self.ffn = FFN(hidden_size) + self.ffn_dropout = torch.nn.Dropout(dropout) + + def forward(self, hidden_states): + # Attention + in_attn_norm = self.input_layernorm(hidden_states) + attn_out = self.attention(in_attn_norm) + # residual + attn_out = self.attn_dropout(attn_out) + # residual = attn_out + hidden_states + residual = attn_out * 2 + # ffn + in_ffn_norm = self.ffn_layernorm(residual) + ffn_out = self.ffn(in_ffn_norm) + # residual + ffn_out = self.ffn_dropout(ffn_out) + # ffn_out = ffn_out + residual + ffn_out = ffn_out * 2 + + loss = torch.sum(ffn_out) + return loss + + +def train(): + L = 512 # seq len + N = 32 # batch size + # configs: [hidden size, num_head] + # E, num_head = [1536, 16] # 1.2B model + # E, num_head = [1920, 20] # 2.5B model + # E, num_head = [2304, 24] # 4.2B model + E, num_head = [3072, 32] # 8.7B model + + + model = TransformerLayer( + seq_len=L, hidden_size=E, head_num=num_head, dropout=0.5 + ).cuda() + + dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) + + def train_iter(model, dataloader): + data = next(dataloader) + torch.distributed.broadcast(data, 0) + torch.cuda.synchronize() + loss = model(data) + loss.backward() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + + +if __name__ == '__main__': + + cube.init() + train() From fad8dc0efed72aeddf37f5d6a00bb071e34a64ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 13:50:35 +0800 Subject: [PATCH 0378/1892] backward executor --- cube/runtime/executor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index c613ba93..8b6ef2fb 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -19,9 +19,9 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr """ Backward the SUs """ - for tensor in input_tensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - tensor.retain_grad() + # for tensor in input_tensors: + # if torch.is_tensor(tensor) and tensor.requires_grad: + # tensor.retain_grad() if len(output_tensor_grads) != len(output_tensors): raise RuntimeError( @@ -40,7 +40,7 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr # print('backwarding... ') in_grads = torch.autograd.grad(output_tensors, inputs, output_tensor_grads) for idx, grad in zip(indices, in_grads): - if input_tensors[idx].is_leaf: + if isinstance(input_tensors[idx], torch.nn.Parameter): input_tensors[idx].grad = grad grads[idx] = grad @@ -61,6 +61,7 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr # grads.append(tensor.grad) # else: # grads.append(None) + if len(grads) == 0: return None elif len(grads) == 1: return grads[0] else: return tuple(grads) From 51078cda1ac6738ad7b7a0d082cb533e3c0944d1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 16:29:14 +0800 Subject: [PATCH 0379/1892] init gpt model --- cube/runtime/function/__init__.py | 3 +- cube/runtime/function/function.py | 18 ++ examples/gpt/gpt.py | 289 ++++++++++++++++++++++++++++++ 3 files changed, 309 insertions(+), 1 deletion(-) create mode 100644 cube/runtime/function/function.py create mode 100644 examples/gpt/gpt.py diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py index a8c9ff96..e46ddb50 100644 --- a/cube/runtime/function/__init__.py +++ b/cube/runtime/function/__init__.py @@ -1,2 +1,3 @@ import cube.runtime.function.complex as complex -from cube.runtime.function.complex import * \ No newline at end of file +from cube.runtime.function.complex import * +from cube.runtime.function.function import embedding \ No newline at end of file diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py new file mode 100644 index 00000000..9704f3a1 --- /dev/null +++ b/cube/runtime/function/function.py @@ -0,0 +1,18 @@ +import torch +import torch.nn.functional as F + + +def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): + """ + Embedding + """ + input_mask = (input < start) | (input >= stop) + masked_input = input.clone() - start + masked_input[input_mask] = 0 + output = F.embedding( + masked_input, weight, + None, None, 2.0, False, False + ) + output[input_mask, :] = 0.0 + return output + diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py new file mode 100644 index 00000000..cbd63c3d --- /dev/null +++ b/examples/gpt/gpt.py @@ -0,0 +1,289 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/transformer/transformer.py +""" + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.modules.normalization import LayerNorm +import cube +from cube.runtime.function.function import embedding + + +from examples.transformer.policy.tensor_parallel import transform_policy +from examples.transformer.policy.tensor_parallel import schedule_policy + +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +class MultiHeadSelfAttention(nn.Module): + + def __init__(self, seq_len, embed_dim, heads, dropout): + super().__init__() + + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_head = heads + self.dim_head = embed_dim // heads + self.scale = self.dim_head ** -0.5 + + self.weight_qkv = torch.nn.Parameter(torch.empty( + 3 * embed_dim, embed_dim + )) + self.weight_out = torch.nn.Parameter(torch.empty( + embed_dim, embed_dim + )) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """ + x: [L, N, E]: seq_len, batch_size, embedding dimension + output: [L, N, E] + """ + # [L, N, E] -> 3 x [L, (N * num_head), dim_head] + q, k, v = cube.runtime.function.toqkv( + x, self.weight_qkv, self.num_head + ) + + # [L, (N * num_head), dim_head] -> [(N * num_head), L, dim_head] + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + # [(N * num_head), L, dim_head] -> [(N * num_head), L, dim_head] + q = q * self.scale + # [(N * num_head), L, dim_head] -> [(N * num_head), dim_head, L] + k = k.transpose(-2, -1) + # [(N * num_head), L, dim_head] * [(N * num_head), dim_head, L] + # -> [(N * num_head), L, L] + attn = torch.bmm(q, k) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = cube.runtime.function.tril_mask(attn, self.num_head) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = F.softmax(attn, dim=-1) + + # [(N * num_head), L, L] -> [(N * num_head), L, L] + attn = self.dropout(attn) + # [(N * num_head), L, L] * [(N * num_head), L, dim_head] + # -> [(N * num_head), L, dim_head] + output = torch.bmm(attn, v) + + # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] + output = cube.runtime.function.attn_view(output, self.num_head) + + # [L, N, num_head * dim_head] * [E, embed_head * dim_head] + # -> [L, N, E] + output = F.linear(output, self.weight_out) + return output + + +class FFN(torch.nn.Module): + + def __init__(self, hidden_size: int): + super().__init__() + self.dense_h_to_4h = torch.nn.Linear( + hidden_size, 4 * hidden_size + ) + self.dense_4h_to_h = torch.nn.Linear( + 4 * hidden_size, hidden_size + ) + + def forward(self, hidden_states): + # [L, N, E] * [E, 4E] -> [L, N, 4E] + out = self.dense_h_to_4h(hidden_states) + # [L, N, 4E] -> [L, N, 4E] + out = F.gelu(out) + # [L, N, 4E] * [4E, E] -> [L, N, E] + out = self.dense_4h_to_h(out) + return out + + +class TransformerLayer(torch.nn.Module): + + def __init__(self, seq_len, hidden_size, head_num, dropout): + super().__init__() + # layer norm + self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + + self.attention = MultiHeadSelfAttention(seq_len, hidden_size, head_num, dropout) + self.attn_dropout = torch.nn.Dropout(dropout) + + self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + self.ffn = FFN(hidden_size) + self.ffn_dropout = torch.nn.Dropout(dropout) + + def forward(self, hidden_states): + # Attention + in_attn_norm = self.input_layernorm(hidden_states) + attn_out = self.attention(in_attn_norm) + # residual + attn_out = self.attn_dropout(attn_out) + # residual = attn_out + hidden_states + residual = attn_out * 2 + # ffn + in_ffn_norm = self.ffn_layernorm(residual) + ffn_out = self.ffn(in_ffn_norm) + # residual + ffn_out = self.ffn_dropout(ffn_out) + # ffn_out = ffn_out + residual + ffn_out = ffn_out * 2 + + loss = torch.sum(ffn_out) + return loss + + +class Embedding(torch.nn.Module): + + def __init__(self, num_embed, dim_embed, dropout): + super().__init__() + self.num_embed = num_embed + self.weight = torch.nn.Parameter( + torch.empty(self.num_embed, dim_embed) + ) + + def forward(self, input): + embeddings = cube.runtime.function.embedding( + 0, self.num_embed, input, self.weight + ) + return embeddings + + +class GPT(torch.nn.Module): + + def __init__(self, hidden_size, vocab_size, seqlen_size, + bs, seqlen, num_head, num_layers: int): + super().__init__() + + self.num_layers = num_layers + self.bs = bs + self.seqlen = seqlen + + # embeddings + self.vocab_size = vocab_size + self.vocab_embed_weight = torch.nn.Parameter( + torch.empty(vocab_size, hidden_size) + ) + self.seqlen_size = seqlen_size + self.pos_embed_weight = torch.nn.Parameter( + torch.empty(seqlen_size, hidden_size) + ) + + self.embed_dropout = torch.nn.Dropout(0.5) + + # transformer layers + self.layers = torch.nn.ModuleList( + [TransformerLayer(seqlen, hidden_size, num_head, 0.5) for _ in range(num_layers)] + ) + + # final linear + self.final_layernorm = LayerNorm( + hidden_size, 1e-5 + ) + + def forward(self, input_ids, position_ids): + """ + input_ids: + [bs, seqlen] + position_ids: + [bs, seqlen] + """ + + # preprocess: embedding + # [bs, seqlen] -> [bs, seqlen, hidden size] + words_embeddings = cube.runtime.function.embedding( + input_ids, self.vocab_embed_weight, 0, self.vocab_size + ) + # [bs, seqlen] -> [bs, seqlen, hidden size] + position_embeddings = cube.runtime.function.embedding( + position_ids, self.pos_embed_weight, 0, self.seqlen_size + ) + embeddings = words_embeddings + position_embeddings + encoder_input = self.embed_dropout(embeddings) + + # [bs, seqlen, hidden size] -> [seqlen, bs, hidden size] + hidden_states = encoder_input.transpose(0, 1).contiguous() + + # transformer + # [seqlen, bs, hidden size] -> [seqlen, bs, hidden size] + for layer in self.layers: + hidden_states = layer(hidden_states) + + hidden_states = self.final_layernorm(hidden_states) + + # post process + # [seqlen, bs, hidden size] -> [bs, seqlen, hidden size] + hidden_states = hidden_states.transpose(0, 1).contiguous() + # [bs, seqlen, hidden size] * [self.vocab_size, hidden size] + # => [bs, seqlen, self.vocab_size] + logits = F.linear(hidden_states, self.vocab_embed_weight) + + # loss # for verification, the mask is ommitted + loss = torch.sum(logits) / (self.seqlen * self.bs) + + return loss + + +def train(): + L = 512 # seq len + N = 32 # batch size + # configs: [hidden size, num_head] + # E, num_head = [1536, 16] # 1.2B model + # E, num_head = [1920, 20] # 2.5B model + # E, num_head = [2304, 24] # 4.2B model + E, num_head = [3072, 32] # 8.7B model + layers = 4 + + + model = GPT( + hidden_size=E, vocab_size=50304, seqlen_size=L, + bs=N, seqlen=L, num_head=num_head, num_layers=layers + ) + model = cube.SemanticModel( + model, input_shapes=([N, L], [N, L],), + ) + + dataloader = cube.runtime.syndata.SynDataLoader(1280, [0, 0], [N, L], [N, L]) + + @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + + +if __name__ == '__main__': + + cube.init() + train() From 63e70a3086181e52c2b4c3c55724ccd763ccf3a1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 16:46:37 +0800 Subject: [PATCH 0380/1892] fix half overflow --- cube/runtime/function/complex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index 1544d218..ff5819c5 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -53,7 +53,7 @@ def tril_mask(input: torch.Tensor, num_head: int): mask = mask.view(bs, 1, seqlen, seqlen) mask = (mask < 0.5) # mask - masked_input = input.masked_fill_(mask, -100000.0) + masked_input = input.masked_fill_(mask, -10000.0) masked_input = masked_input.view((bs * num_head), seqlen, seqlen) return masked_input From a81e99fdbc51bd88e84d67589a6530908cf91374 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 16:47:04 +0800 Subject: [PATCH 0381/1892] text data loader for embedding --- cube/runtime/syndata.py | 13 +++++++++++++ examples/gpt/gpt.py | 16 ++++++---------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index e8ea908d..ed7e672b 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -85,3 +85,16 @@ def __next__(self): datas = self.datas[self.pos % self._buffer_num] if len(datas) == 1: return datas[0] else: return tuple(datas) + + +class SynTextDataLoader(SynDataLoader): + + def set_data_buffer(self, buffer_num=4, text_num=50257): + self.datas = list() + self._buffer_num = buffer_num + for _ in range(self._buffer_num): + datas = list() + for shape in self.shapes: + data = torch.randint(0, text_num, shape, dtype=torch.long).cuda() + datas.append(data) + self.datas.append(datas) diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py index cbd63c3d..2b8a52d4 100644 --- a/examples/gpt/gpt.py +++ b/examples/gpt/gpt.py @@ -14,9 +14,7 @@ import torch from torch import nn import torch.nn.functional as F -from torch.nn.modules.normalization import LayerNorm import cube -from cube.runtime.function.function import embedding from examples.transformer.policy.tensor_parallel import transform_policy @@ -139,9 +137,7 @@ def forward(self, hidden_states): ffn_out = self.ffn_dropout(ffn_out) # ffn_out = ffn_out + residual ffn_out = ffn_out * 2 - - loss = torch.sum(ffn_out) - return loss + return ffn_out class Embedding(torch.nn.Module): @@ -188,7 +184,7 @@ def __init__(self, hidden_size, vocab_size, seqlen_size, ) # final linear - self.final_layernorm = LayerNorm( + self.final_layernorm = torch.nn.LayerNorm( hidden_size, 1e-5 ) @@ -237,7 +233,7 @@ def forward(self, input_ids, position_ids): def train(): L = 512 # seq len - N = 32 # batch size + N = 1 # batch size # configs: [hidden size, num_head] # E, num_head = [1536, 16] # 1.2B model # E, num_head = [1920, 20] # 2.5B model @@ -254,12 +250,12 @@ def train(): model, input_shapes=([N, L], [N, L],), ) - dataloader = cube.runtime.syndata.SynDataLoader(1280, [0, 0], [N, L], [N, L]) + dataloader = cube.runtime.syndata.SynTextDataLoader(1280, [0, 0], [N, L], [N, L]) @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) + input_ids, position_ids = next(dataloader) + loss = model(input_ids, position_ids) loss.backward() model = model.get_gen_module() From d85be6e87fbaca66b8e685f88f982a7b28a63fcc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 17:27:06 +0800 Subject: [PATCH 0382/1892] move linear --- .../{ => case_study/models}/transformer.py | 6 +- examples/e2e.py | 132 ------------------ examples/{ => mlp}/linears.py | 18 ++- examples/{ => mlp}/policy/col_parallel.py | 0 examples/{ => mlp}/policy/data_parallel.py | 10 +- examples/{ => mlp}/policy/hybrid_parallel.py | 0 .../{ => mlp}/policy/megatron_parallel.py | 0 examples/{ => mlp}/policy/no_parallel.py | 0 examples/{ => mlp}/policy/pipe_parallel.py | 48 +++---- examples/{ => mlp}/policy/row_parallel.py | 0 examples/policy/nested_parallel.py | 53 ------- 11 files changed, 48 insertions(+), 219 deletions(-) rename examples/{ => case_study/models}/transformer.py (98%) delete mode 100644 examples/e2e.py rename examples/{ => mlp}/linears.py (75%) rename examples/{ => mlp}/policy/col_parallel.py (100%) rename examples/{ => mlp}/policy/data_parallel.py (71%) rename examples/{ => mlp}/policy/hybrid_parallel.py (100%) rename examples/{ => mlp}/policy/megatron_parallel.py (100%) rename examples/{ => mlp}/policy/no_parallel.py (100%) rename examples/{ => mlp}/policy/pipe_parallel.py (50%) rename examples/{ => mlp}/policy/row_parallel.py (100%) delete mode 100644 examples/policy/nested_parallel.py diff --git a/examples/transformer.py b/examples/case_study/models/transformer.py similarity index 98% rename from examples/transformer.py rename to examples/case_study/models/transformer.py index c317bfe5..098a91a9 100644 --- a/examples/transformer.py +++ b/examples/case_study/models/transformer.py @@ -29,7 +29,7 @@ def forward(self, x, mask): """ bs = x.shape[1] - # [L, N, E] -> [L, N, (num_heads * dim_head * 3)] + # [L, N, E] -> [L, N, (3 * num_heads * dim_head)] qkv = F.linear(x, self.weight_qkv, None) # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 qkv = qkv.chunk(3, dim=-1) @@ -114,7 +114,7 @@ def _ref_forward(self, x, mask=True): return output -class MLP(torch.nn.Module): +class FFN(torch.nn.Module): def __init__(self, hidden_size: int): super().__init__() @@ -146,7 +146,7 @@ def __init__(self, seq_len, hidden_size, head_num, dropout): self.attn_dropout = torch.nn.Dropout(dropout) self.mlp_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - self.mlp = MLP(hidden_size) + self.mlp = FFN(hidden_size) self.mlp_dropout = torch.nn.Dropout(dropout) def forward(self, hidden_states, attention_mask): diff --git a/examples/e2e.py b/examples/e2e.py deleted file mode 100644 index b5f977a7..00000000 --- a/examples/e2e.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/e2e.py -""" - -import torch -from torch import nn - -import cube -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph - - -def trans_policy(ir_graph, resource): - return ir_graph - -def schedule_policy(sugraph: SUGraph, resource): - # put to micro-batch forward-backward sequence - fb_op_seqs = list() - for fsu in sugraph.fsus(): - for fb_seq in fb_op_seqs: - for ksu in fb_seq[::-1]: - if sugraph.happen_before(ksu, fsu): - fb_seq.append(fsu) - break - else: - continue - break - else: - fb_op_seqs.append([fsu]) - - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - - print(f'> collect {len(fb_op_seqs)} forward-backward sequence') - for fb_sus in fb_op_seqs: - for idx, su in enumerate(fb_sus): - if idx < 3: - sugraph.assign(su, 0) - sugraph.assign(su.mirror, 0) - else: - sugraph.assign(su, 1) - sugraph.assign(su.mirror, 1) - return sugraph - - -class FakeDataLoader: - def __init__(self, batch_size, num=640): - self.batch_size = batch_size - self.length = num - self.pos = 0 - def __iter__(self): - self.pos = 0 - return self - def reset(self, batch_size): - self.batch_size = batch_size - self.pos = 0 - def __next__(self): - self.pos += 1 - if self.pos == self.length: - raise StopIteration - return torch.randn((self.batch_size, 1024)).cuda() - - -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult, bias=False) - self.gelu = nn.GELU() - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim * mult, dim) - self.classifier = nn.Linear(dim, classes) - - def forward(self, data): - output = self.linear1(data) - output = self.gelu(output) - output = self.dropout(output) - output = self.linear2(output) - output = self.classifier(output) - loss = torch.sum(output) - return loss - -def init_weight(parameters): - for param in parameters: - with torch.no_grad(): - torch.nn.init.uniform_(param) - - -def train(): - batch_size = 64 - - model = FeedForward(dim=1024) - model = cube.schedule.SemanticModel( - model, input_shapes=([batch_size,1024],), - ) - - dataloader = FakeDataLoader(batch_size) - - @cube.schedule.schedule(model, dataloader, transform_policy=trans_policy, schedule_policy=schedule_policy) - def train_iter(model, dataloader): - # for _ in range(1): - # data = next(dataloader) - # loss = model(data) - # loss.backward() - data = next(dataloader) - loss = model(data) - loss.backward() - model = model.get_gen_module() - - init_weight(model.parameters()) - - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - for epoch in range(10): - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - - -if __name__ == '__main__': - - cube.DeviceGroup() - train() diff --git a/examples/linears.py b/examples/mlp/linears.py similarity index 75% rename from examples/linears.py rename to examples/mlp/linears.py index 16d8fb83..c2cd601c 100644 --- a/examples/linears.py +++ b/examples/mlp/linears.py @@ -8,7 +8,7 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/linears.py + examples/mlp/linears.py """ import torch @@ -17,24 +17,32 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.policy.hybrid_parallel import transform_policy -from examples.policy.hybrid_parallel import schedule_policy +from examples.mlp.policy.pipe_parallel import transform_policy +from examples.mlp.policy.pipe_parallel import schedule_policy # =================== Semantic Model Description ==================== class MLP(nn.Module): - def __init__(self, dim, mult=16): + def __init__(self, dim, mult=4): super().__init__() self.linear1 = nn.Linear(dim, dim * mult) self.linear2 = nn.Linear(dim * mult, dim) self.linear3 = nn.Linear(dim, dim * mult) self.linear4 = nn.Linear(dim * mult, dim) + self.linear5 = nn.Linear(dim, dim * mult) + self.linear6 = nn.Linear(dim * mult, dim) + self.linear7 = nn.Linear(dim, dim * mult) + self.linear8 = nn.Linear(dim * mult, dim) def forward(self, data): output = self.linear1(data) output = self.linear2(output) output = self.linear3(output) output = self.linear4(output) + output = self.linear5(output) + output = self.linear6(output) + output = self.linear7(output) + output = self.linear8(output) loss = torch.sum(output) return loss @@ -48,7 +56,7 @@ def train(): model, input_shapes=([batch_size, dim],), ) - dataloader = cube.runtime.syndata.SynDataLoader(1280, [batch_size, dim]) + dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [batch_size, dim]) @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) def train_iter(model, dataloader): diff --git a/examples/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py similarity index 100% rename from examples/policy/col_parallel.py rename to examples/mlp/policy/col_parallel.py diff --git a/examples/policy/data_parallel.py b/examples/mlp/policy/data_parallel.py similarity index 71% rename from examples/policy/data_parallel.py rename to examples/mlp/policy/data_parallel.py index 960eeb4b..ff74e6bf 100644 --- a/examples/policy/data_parallel.py +++ b/examples/mlp/policy/data_parallel.py @@ -11,8 +11,12 @@ def transform_policy(graph: IRGraph, resource): for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): algo = node.algorithms('data') - assert algo - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + if algo is None: + algo = node.algorithms('dim') + assert algo + sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=resource.ngpus)) + else: + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx print(graph) @@ -31,4 +35,6 @@ def schedule_policy(sugraph: SUGraph, resource): devid = su.tag[0] sugraph.assign(su, devid) sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + sugraph.partial_set_order(fsus, lazy=False) return sugraph diff --git a/examples/policy/hybrid_parallel.py b/examples/mlp/policy/hybrid_parallel.py similarity index 100% rename from examples/policy/hybrid_parallel.py rename to examples/mlp/policy/hybrid_parallel.py diff --git a/examples/policy/megatron_parallel.py b/examples/mlp/policy/megatron_parallel.py similarity index 100% rename from examples/policy/megatron_parallel.py rename to examples/mlp/policy/megatron_parallel.py diff --git a/examples/policy/no_parallel.py b/examples/mlp/policy/no_parallel.py similarity index 100% rename from examples/policy/no_parallel.py rename to examples/mlp/policy/no_parallel.py diff --git a/examples/policy/pipe_parallel.py b/examples/mlp/policy/pipe_parallel.py similarity index 50% rename from examples/policy/pipe_parallel.py rename to examples/mlp/policy/pipe_parallel.py index d8535c2e..cfc4da6c 100644 --- a/examples/policy/pipe_parallel.py +++ b/examples/mlp/policy/pipe_parallel.py @@ -1,7 +1,8 @@ +from typing import List import math import random -from cube.schedule.su import SUType +from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.sugraph import SUGraph @@ -13,8 +14,13 @@ def transform_policy(graph, resource): for node in graph.nodes(): if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): algo = node.algorithms('data') - assert algo is not None - graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + if algo is not None: + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + else: + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=resource.ngpus)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx return graph @@ -22,40 +28,34 @@ def schedule_policy(sugraph: SUGraph, resource): """ The schedule policy """ - fb_seqs = list() + fseqs: List[List[ScheduleUnit]] = [list() for _ in range(resource.ngpus)] + fbseqs: List[List[ScheduleUnit]] = [list() for _ in range(resource.ngpus)] + for fsu in sugraph.fsus(): - for fb_seq in fb_seqs: - for ksu in fb_seq[::-1]: - if sugraph.happen_before(ksu, fsu): - fb_seq.append(fsu) - break - else: - continue - break - else: - fb_seqs.append([fsu]) + micro_bs_id = fsu.tag[0] + fseqs[micro_bs_id].append(fsu) + + for micro_bs_id, fseq in enumerate(fbseqs): + bseq = [fsu.mirror for fsu in fseq][::-1] + fbseqs[micro_bs_id] = fseq + bseq # device assignment for su in sugraph.sus(): if su.stype == SUType.Dataloader: sugraph.assign(su, 0) - - print(f'> collect {len(fb_seqs)} forward-backward sequence') - for fb_seq in fb_seqs: - chunk_num = int(math.ceil(len(fb_seq) / resource.ngpus)) - for idx, su in enumerate(fb_seq): + + print(f'> collect {len(fseqs)} forward-backward sequence') + for fseq in fseqs: + chunk_num = int(math.ceil(len(fseq) / resource.ngpus)) + for idx, su in enumerate(fseq): # devid = int(idx // chunk_num) # devid = idx % resource.ngpus devid = random.randint(0, resource.ngpus - 1) sugraph.assign(su, devid) sugraph.assign(su.mirror, devid) - # set partial order - for fb_seq in fb_seqs: - fb_seq += [fsu.mirror for fsu in fb_seq][::-1] - seqs = list() - for fb_seq in fb_seqs: + for fb_seq in fbseqs: seqs += fb_seq sugraph.partial_set_order(seqs) return sugraph diff --git a/examples/policy/row_parallel.py b/examples/mlp/policy/row_parallel.py similarity index 100% rename from examples/policy/row_parallel.py rename to examples/mlp/policy/row_parallel.py diff --git a/examples/policy/nested_parallel.py b/examples/policy/nested_parallel.py deleted file mode 100644 index 9f158d26..00000000 --- a/examples/policy/nested_parallel.py +++ /dev/null @@ -1,53 +0,0 @@ -from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation - - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using data parallel - """ - tp = 2 - dp = int(resource.ngpus // tp) - for node in graph.nodes(): - # partition data loader at data dimension - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=dp)) - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - # partition operators first in column and then in data - if isinstance(node, IRFwOperation): - all_sub_nodes = list() - if node.algorithms('column') is not None: - algo = node.algorithms('column') - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=tp)) - for sub_node in sub_nodes: - algo = sub_node.algorithms('data') - ssub_nodes = graph.partition(sub_node, algo, config=dict(chunk_num=dp)) - all_sub_nodes += ssub_nodes - else: - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) - all_sub_nodes += sub_nodes - # add tags (vdev) for node - for idx, ssub_node in enumerate(all_sub_nodes): - ssub_node.tag = idx - # print(graph) - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - return sugraph From 82ff6e62f6e51d46676e185254df293f2f9b1c83 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 18:26:54 +0800 Subject: [PATCH 0383/1892] fix bugs on contiguous send --- cube/runtime/collectives.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 4ebae7c9..b3963d6c 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -4,7 +4,7 @@ from cube.runtime.device import DeviceGroup -def send(tensors, to_ranks: List[int]): +def send(tensors: List[torch.Tensor], to_ranks: List[int]): """ send tensor to the remote devices. Each tensor can be sent to multiple devices @@ -20,6 +20,8 @@ def send(tensors, to_ranks: List[int]): # return for tensor, rank in zip(tensors, to_ranks): + if not tensor.is_contiguous(): + tensor = tensor.contiguous() send_op = torch.distributed.P2POp( torch.distributed.isend, tensor, rank ) From 758f1e9b3666fc48f72c48ca7b04ec6a8b0c7a49 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 18:28:48 +0800 Subject: [PATCH 0384/1892] fix bugs on gradient accum --- cube/runtime/executor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 8b6ef2fb..8ff3c26f 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -40,8 +40,12 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr # print('backwarding... ') in_grads = torch.autograd.grad(output_tensors, inputs, output_tensor_grads) for idx, grad in zip(indices, in_grads): - if isinstance(input_tensors[idx], torch.nn.Parameter): - input_tensors[idx].grad = grad + tensor = input_tensors[idx] + if isinstance(tensor, torch.nn.Parameter): + if tensor.grad is not None: + tensor.grad += grad + else: + tensor.grad = grad grads[idx] = grad # if len(inputs) != 0: From 51967773ea6e8780b6ba640b89fbfb0ccd02a030 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 18:29:12 +0800 Subject: [PATCH 0385/1892] fix reorder bugs on send recv --- cube/schedule/sugraph.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index 5a61a5c6..e7d23a71 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -420,6 +420,16 @@ def partial_set_order(self, seq: List[ScheduleUnit], lazy=False): if rsu.stype == SUType.Backward: if rsu.mirror not in happen_before_sus: happen_before_sus.append(rsu.mirror) + # send / recv su pair should be colocated + if rsu.stype == SUType.P2P: + if rsu in seq: + continue + if rsu.mirror in seq: + index = seq.index(rsu.mirror) + seq.insert(idx+1, rsu) + continue + if rsu in seq: + raise RuntimeError(f"Internal Error: should not appear SU: {rsu}") idx = 0 while len(happen_before_sus) > 0: if idx == len(seq): From 74f585fd9f7d4dd33efe308044c50c3ca7542d0d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 18:29:38 +0800 Subject: [PATCH 0386/1892] random pipeline example --- examples/mlp/policy/pipe_parallel.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/mlp/policy/pipe_parallel.py b/examples/mlp/policy/pipe_parallel.py index cfc4da6c..912e6974 100644 --- a/examples/mlp/policy/pipe_parallel.py +++ b/examples/mlp/policy/pipe_parallel.py @@ -4,21 +4,22 @@ from cube.schedule.su import SUType, ScheduleUnit from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRDataOperation, IRFwOperation def transform_policy(graph, resource): """ The transformation policy transposes linear using data parallel """ - from cube.graph.operator.operator import IRDataOperation, IRFwOperation + micro_batch_num = resource.ngpus for node in graph.nodes(): if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): algo = node.algorithms('data') if algo is not None: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=micro_batch_num)) else: algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=resource.ngpus)) + sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=micro_batch_num)) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx return graph @@ -28,8 +29,10 @@ def schedule_policy(sugraph: SUGraph, resource): """ The schedule policy """ - fseqs: List[List[ScheduleUnit]] = [list() for _ in range(resource.ngpus)] - fbseqs: List[List[ScheduleUnit]] = [list() for _ in range(resource.ngpus)] + micro_batch_num = resource.ngpus + + fseqs: List[List[ScheduleUnit]] = [list() for _ in range(micro_batch_num)] + fbseqs: List[List[ScheduleUnit]] = [list() for _ in range(micro_batch_num)] for fsu in sugraph.fsus(): micro_bs_id = fsu.tag[0] @@ -58,4 +61,5 @@ def schedule_policy(sugraph: SUGraph, resource): for fb_seq in fbseqs: seqs += fb_seq sugraph.partial_set_order(seqs) + print(sugraph) return sugraph From 33e6d18f2d67c6fdb29fd34b425f89aa50efbb09 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 19:58:56 +0800 Subject: [PATCH 0387/1892] pipeline 1f1b example --- examples/mlp/linears.py | 4 +- examples/mlp/policy/pipe1f1b_parallel.py | 102 +++++++++++++++++++++++ 2 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 examples/mlp/policy/pipe1f1b_parallel.py diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index c2cd601c..7c187e91 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,8 +17,8 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.pipe_parallel import transform_policy -from examples.mlp.policy.pipe_parallel import schedule_policy +from examples.mlp.policy.pipe1f1b_parallel import transform_policy +from examples.mlp.policy.pipe1f1b_parallel import schedule_policy # =================== Semantic Model Description ==================== diff --git a/examples/mlp/policy/pipe1f1b_parallel.py b/examples/mlp/policy/pipe1f1b_parallel.py new file mode 100644 index 00000000..b4ff458b --- /dev/null +++ b/examples/mlp/policy/pipe1f1b_parallel.py @@ -0,0 +1,102 @@ +from typing import List +import math + +from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRDataOperation, IRFwOperation + + +def transform_policy(graph, resource): + """ + The transformation policy transposes linear using data parallel + """ + micro_batch_num = resource.ngpus + for node in graph.nodes(): + if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): + algo = node.algorithms('data') + if algo is not None: + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=micro_batch_num)) + else: + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=micro_batch_num)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy + """ + # each device is a stage + num_micro_batch = resource.ngpus + num_stage = resource.ngpus + + fseqs: List[List[ScheduleUnit]] = [list() for _ in range(num_micro_batch)] + fbseqs: List[List[ScheduleUnit]] = [list() for _ in range(num_micro_batch)] + + for fsu in sugraph.fsus(): + micro_bs_id = fsu.tag[0] + fseqs[micro_bs_id].append(fsu) + + for micro_bs_id, fseq in enumerate(fbseqs): + bseq = [fsu.mirror for fsu in fseq][::-1] + fbseqs[micro_bs_id] = fseq + bseq + + print(f'> collect {len(fseqs)} forward-backward sequence') + + # fstages[micro_batch_id][stage] = fstages[micro_batch_id * num_stage + stage] + fstages: List[List[ScheduleUnit]] = [ + list() for _ in range(num_micro_batch * num_stage) + ] + + def f(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: + return fstages[micro_batch_id * num_stage + stage_id] + + def b(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: + fstage = f(micro_batch_id, stage_id) + bstage = [fsu.mirror for fsu in fstage][::-1] + return bstage + + # assign su to stages + for micro_bid, fseq in enumerate(fseqs): + chunk_num = int(len(fseq) // resource.ngpus) + for idx, fsu in enumerate(fseq): + stage = min(int(idx // chunk_num), num_stage - 1) + fstages[micro_bid * num_stage + stage].append(fsu) + + # stage device assignment + for micro_bid in range(num_micro_batch): + for stage in range(num_stage): + for su in f(micro_bid, stage): + sugraph.assign(su, stage) + sugraph.assign(su.mirror, stage) + + # device assignment + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + + # 1f1b scheduling + seqs = list() + + # warmup + for stage in range(num_stage): + for mid in range(stage): + seqs += f(mid, stage) + + # steady + cooldown: + for mid in range(num_micro_batch): + # enqueue backward + for stage in range(num_stage-1, -1, -1): + seqs += b(mid, stage) + # enqueue forward + for stage in range(num_stage): + f_mid = mid + 1 + num_stage - stage + if f_mid >= num_micro_batch: + continue + seqs += f(f_mid, stage) + + sugraph.partial_set_order(seqs) + # print(sugraph) + return sugraph From ecbd565898d07ca5e25633209913512c43840662 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Nov 2021 23:55:29 +0800 Subject: [PATCH 0388/1892] add coalescing module self attention and feedforward --- cube/algorithm/ops/complex.py | 253 +++++++++++++++++++++++++++++- cube/graph/operator/function.py | 75 +++++++++ cube/runtime/function/complex.py | 67 ++++++++ tests/algorithm/test_complex.py | 257 +++++++++++++++++++++++++------ 4 files changed, 604 insertions(+), 48 deletions(-) diff --git a/cube/algorithm/ops/complex.py b/cube/algorithm/ops/complex.py index f64a2acf..06f0a60a 100644 --- a/cube/algorithm/ops/complex.py +++ b/cube/algorithm/ops/complex.py @@ -1,11 +1,13 @@ from typing import Dict -from cube.algorithm.utils import split_axis +from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo from cube.graph.operator.function import CubeComplexToQKV from cube.graph.operator.function import CubeComplexTrilMask from cube.graph.operator.function import CubeComplexAttnView +from cube.graph.operator.function import CubeComplexSelfAttention +from cube.graph.operator.function import CubeComplexFeedForward _kWaitDecision = None @@ -309,3 +311,252 @@ def instantiate(self, node, config: Dict): node.set_output(0, ous[idx]) nodes.append(node) return nodes + + +class CubeSelfAttentionHeadParallel(GenericDistAlgo): + """ + Multi-Head Self-Attention. + + L: sequence length + N: batch size + E: embedding size + + Inputs: + hidden_state: [L, N, E] + w_qkv : [3 * num_head * dim_head, E] + w_out : [E, E] + num_head: int + dim_head: int + dropout_p: float + + Outputs: + hidden_state: [L, N, E] + """ + def __init__(self, node: CubeComplexSelfAttention): + if not isinstance(node, CubeComplexSelfAttention): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.num_head = node.kwargs['num_head'] + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.num_head % chunk_num == 0: + return True + return False + + def instantiate(self, node: CubeComplexSelfAttention, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + hidden_state = node.inputs(0) + w_qkv = node.inputs(1) + w_out = node.inputs(2) + num_head = node.kwargs['num_head'] + dim_head = node.kwargs['dim_head'] + dropout_p = node.kwargs['dropout_p'] + out = node.outputs(0) + + + w_qkvs = split_axis(w_qkv, 0, self.chunk_num) + w_outs = split_axis(w_out, 1, self.chunk_num) + ous = split_value(out, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ + hidden_state, w_qkvs[idx], w_outs[idx], + num_head // self.chunk_num, dim_head, dropout_p + ] + node = CubeComplexSelfAttention( + signature = 'cube.runtime.function.complex.self_attn', + inputs = inputs, + ) + node.set_output(0, ous[idx]) + nodes.append(node) + return nodes + + +class CubeSelfAttentionDataParallel(GenericDistAlgo): + """ + Multi-Head Self-Attention. + + L: sequence length + N: batch size + E: embedding size + + Inputs: + hidden_state: [L, N, E] + w_qkv : [3 * num_head * dim_head, E] + w_out : [E, E] + num_head: int + dim_head: int + dropout_p: float + + Outputs: + hidden_state: [L, N, E] + """ + def __init__(self, node: CubeComplexSelfAttention): + if not isinstance(node, CubeComplexSelfAttention): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.bs = node.inputs(0).shape[1] + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.bs % chunk_num == 0: + return True + return False + + def instantiate(self, node: CubeComplexSelfAttention, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + hidden_state = node.inputs(0) + w_qkv = node.inputs(1) + w_out = node.inputs(2) + num_head = node.kwargs['num_head'] + dim_head = node.kwargs['dim_head'] + dropout_p = node.kwargs['dropout_p'] + out = node.outputs(0) + + ins = split_axis(hidden_state, 1, self.chunk_num) + ous = split_axis(out, 1, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ + ins[idx], w_qkv, w_out, + num_head, dim_head, dropout_p + ] + node = CubeComplexSelfAttention( + signature = 'cube.runtime.function.complex.self_attn', + inputs = inputs, + ) + node.set_output(0, ous[idx]) + nodes.append(node) + return nodes + + +class CubeFeedForwardTensorParallel(GenericDistAlgo): + """ + FeedForward + + Inputs: + hidden_state: [L, N, E] + w_proj1: [4 * E, E] + w_bias1: [4 * E,] + w_porj2: [E, 4 * E] + w_bias2: [E,] + + Outputs: + hidden_state: [L, N, E] + """ + def __init__(self, node: CubeComplexFeedForward): + if not isinstance(node, CubeComplexFeedForward): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.embed_size = node.inputs(1).shape[0] + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.embed_size % chunk_num == 0: + return True + return False + + def instantiate(self, node: CubeComplexFeedForward, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + hidden_state = node.inputs(0) + w_proj1 = node.inputs(1) + w_bias1 = node.inputs(2) + w_proj2 = node.inputs(3) + w_bias2 = node.inputs(4) + + out = node.outputs(0) + + w_proj1s = split_axis(w_proj1, 0, self.chunk_num) + w_bias1s = split_axis(w_bias1, 0, self.chunk_num) + w_proj2s = split_axis(w_proj2, 1, self.chunk_num) + w_bias2s = split_value(w_bias2, self.chunk_num) + + outs = split_value(out, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ + hidden_state, + w_proj1s[idx], w_bias1s[idx], + w_proj2s[idx], w_bias2s[idx] + ] + node = CubeComplexFeedForward( + signature = 'cube.runtime.function.complex.feedforward', + inputs = inputs, + ) + node.set_output(0, outs[idx]) + nodes.append(node) + return nodes + + +class CubeFeedForwardDataParallel(GenericDistAlgo): + """ + FeedForward + + Inputs: + hidden_state: [L, N, E] + w_proj1: [4 * E, E] + w_bias1: [4 * E,] + w_porj2: [E, 4 * E] + w_bias2: [E,] + + Outputs: + hidden_state: [L, N, E] + """ + def __init__(self, node: CubeComplexFeedForward): + if not isinstance(node, CubeComplexFeedForward): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.bs = node.inputs(0).shape[1] + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.bs % chunk_num == 0: + return True + return False + + def instantiate(self, node: CubeComplexFeedForward, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + hidden_state = node.inputs(0) + w_proj1 = node.inputs(1) + w_bias1 = node.inputs(2) + w_proj2 = node.inputs(3) + w_bias2 = node.inputs(4) + out = node.outputs(0) + + ins = split_axis(hidden_state, 1, self.chunk_num) + outs = split_axis(out, 1, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ + ins[idx], + w_proj1, w_bias1, + w_proj2, w_bias2, + ] + node = CubeComplexFeedForward( + signature = 'cube.runtime.function.complex.feedforward', + inputs = inputs, + ) + node.set_output(0, outs[idx]) + nodes.append(node) + return nodes diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 52bf043b..2d553e66 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -420,6 +420,81 @@ def infer_shape(self): return True +class CubeComplexSelfAttention(IRFwOperation): + """ + Multi-Head Self-Attention. + + L: sequence length + N: batch size + E: embedding size + + Inputs: + hidden_state: [L, N, E] + w_qkv : [3 * num_head * dim_head, E] + w_out : [E, E] + num_head: int + dim_head: int + dropout_p: float + + Outputs: + hidden_state: [L, N, E] + """ + def __init__(self, signature, inputs, name='selfattn', **kwargs): + if len(inputs) != 6: + raise RuntimeError(f"Expected 6 inputs but got {input}") + num_head: int = inputs[3] + dim_head: int = inputs[4] + dropout_p: float = inputs[5] + super().__init__( + name, signature, + input_length = 3, + output_length = 1 + ) + for idx, tensor in enumerate(inputs[:3]): + self.set_input(idx, tensor) + self.kwargs['num_head'] = num_head + self.kwargs['dim_head'] = dim_head + self.kwargs['dropout_p'] = dropout_p + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + self.outputs(0).shape = self.inputs(0).shape + return True + + +class CubeComplexFeedForward(IRFwOperation): + """ + FeedForward + + Inputs: + hidden_state: [L, N, E] + w_proj1: [4 * E, E] + w_bias1: [4 * E,] + w_porj2: [E, 4 * E] + w_bias2: [E,] + + Outputs: + hidden_state: [L, N, E] + """ + def __init__(self, signature, inputs, name='selfattn', **kwargs): + if len(inputs) != 5: + raise RuntimeError(f"Expected 6 inputs but got {input}") + super().__init__( + name, signature, + input_length = 5, + output_length = 1 + ) + for idx, tensor in enumerate(inputs): + self.set_input(idx, tensor) + + def infer_shape(self): + if self.inputs(0).shape is None: + return False + self.outputs(0).shape = self.inputs(0).shape + return True + + class UnkownOperator(IRFwOperation): def __init__(self, signature, inputs, name='unknown_op', n_outputs=None): diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index ff5819c5..0f4124c6 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -74,3 +74,70 @@ def attn_view(input: torch.Tensor, num_head: int): # [L, (N * num_head), dim_head] -> [L, N, (num_head * dim_head)] input = input.view(seqlen, bs, num_head * dim_head) return input + + +def self_attn(hidden_state, w_qkv, w_out, + num_head: int, dim_head: int, + dropout_p: float): + """ + Multi-Head Self-Attention. + + L: sequence length + N: batch size + E: embedding size + + Inputs: + hidden_state: [L, N, E] + w_qkv : [3 * num_head * dim_head, E] + w_out : [E, E] + + Outputs: + hidden_state: [L, N, E] + """ + scale = dim_head ** -0.5 + seqlen = hidden_state.shape[0] + bs = hidden_state.shape[1] + + qkv = F.linear(hidden_state, w_qkv, None) + qkv = qkv.chunk(3, dim=-1) + q, k, v = qkv + q = q.contiguous() + q = q.view(seqlen, (bs * num_head), dim_head) + k = k.contiguous() + k = k.view(seqlen, (bs * num_head), dim_head) + v = v.contiguous() + v = v.view(seqlen, (bs * num_head), dim_head) + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + q = q * scale + k = k.transpose(-2, -1) + attn = torch.bmm(q, k) + + attn = tril_mask(attn, num_head) + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, dropout_p, True, False) + output = torch.bmm(attn, v) + output = attn_view(output, num_head) + + output = F.linear(output, w_out, None) + return output + + +def feedforward(hidden_state, w_proj1, w_bias1, w_proj2, w_bias2): + """ + FeedForward + + Inputs: + hidden_state: [L, N, E] + w_proj1: [4 * E, E] + w_bias1: [4 * E,] + w_porj2: [E, 4 * E] + w_bias2: [E,] + """ + out = F.linear(hidden_state, w_proj1, w_bias1) + out = F.gelu(out) + out = F.linear(out, w_proj2, w_bias2) + return out diff --git a/tests/algorithm/test_complex.py b/tests/algorithm/test_complex.py index 00041bd8..958fd6de 100644 --- a/tests/algorithm/test_complex.py +++ b/tests/algorithm/test_complex.py @@ -7,14 +7,14 @@ def test_complex_toqkv_data_parallel(): L = 64 # seq len N = 16 # batch E = 1024 # hiddend size = dim_head * num_head - num_heads = 8 - dim_head = E // num_heads + num_head = 8 + dim_head = E // num_head input = IRFullTensor(shape=[L, N, E], name='hidden').tosub() weight = IRFullTensor(shape=[3 * E, E], name='weight').tosub() semantic_op = complex.CubeComplexToQKV( signature='cube.runtime.function.complex.toqkv', - inputs = [input, weight, num_heads] + inputs = [input, weight, num_head] ) semantic_op.infer_shape() @@ -42,8 +42,8 @@ def test_complex_toqkv_data_parallel(): print(weight) assert weight.shape == [3 * E, E] - sub_heads = [node.kwargs['num_heads'] for node in nodes] - print('num_heads:') + sub_heads = [node.kwargs['num_head'] for node in nodes] + print('num_head:') for nhead in sub_heads: assert nhead == 8 print(nhead) @@ -55,9 +55,9 @@ def test_complex_toqkv_data_parallel(): print('q:', q) print('k:', k) print('v:', v) - assert q.shape == [L, N * num_heads // 4, dim_head] - assert k.shape == [L, N * num_heads // 4, dim_head] - assert v.shape == [L, N * num_heads // 4, dim_head] + assert q.shape == [L, N * num_head // 4, dim_head] + assert k.shape == [L, N * num_head // 4, dim_head] + assert v.shape == [L, N * num_head // 4, dim_head] def test_complex_toqkv_head_parallel(): @@ -65,14 +65,14 @@ def test_complex_toqkv_head_parallel(): L = 64 # seq len N = 16 # batch E = 1024 # hiddend size = dim_head * num_head - num_heads = 8 - dim_head = E // num_heads + num_head = 8 + dim_head = E // num_head input = IRFullTensor(shape=[L, N, E], name='hidden').tosub() weight = IRFullTensor(shape=[3 * E, E], name='weight').tosub() semantic_op = complex.CubeComplexToQKV( signature='cube.runtime.function.complex.toqkv', - inputs = [input, weight, num_heads] + inputs = [input, weight, num_head] ) semantic_op.infer_shape() @@ -100,10 +100,10 @@ def test_complex_toqkv_head_parallel(): assert weight.shape == [3 * E // 4, E] print(weight) - sub_heads = [node.kwargs['num_heads'] for node in nodes] + sub_heads = [node.kwargs['num_head'] for node in nodes] print('sub_heads:') for nhead in sub_heads: - assert nhead == num_heads // 4 + assert nhead == num_head // 4 print(nhead) outputs = [node.outputs() for node in nodes] @@ -113,21 +113,21 @@ def test_complex_toqkv_head_parallel(): print('q:', q) print('k:', k) print('v:', v) - assert q.shape == [L, N * num_heads // 4, dim_head] - assert k.shape == [L, N * num_heads // 4, dim_head] - assert v.shape == [L, N * num_heads // 4, dim_head] + assert q.shape == [L, N * num_head // 4, dim_head] + assert k.shape == [L, N * num_head // 4, dim_head] + assert v.shape == [L, N * num_head // 4, dim_head] def test_complex_tril_mask_data_parallel(): L = 64 # seq len N = 16 # batch - num_heads = 8 - input = IRFullTensor(shape=[N * num_heads, L, L], name='hidden').tosub() + num_head = 8 + input = IRFullTensor(shape=[N * num_head, L, L], name='hidden').tosub() semantic_op = complex.CubeComplexTrilMask( signature = 'cube.runtime.function.complex.trill_mask', - inputs = [input, num_heads], + inputs = [input, num_head], ) semantic_op.infer_shape() @@ -147,10 +147,10 @@ def test_complex_tril_mask_data_parallel(): print('inputs:') for input in inputs: print(input) - assert input.shape == [N * num_heads // 4, L, L] + assert input.shape == [N * num_head // 4, L, L] - sub_heads = [node.kwargs['num_heads'] for node in nodes] - print('num_heads:') + sub_heads = [node.kwargs['num_head'] for node in nodes] + print('num_head:') for nhead in sub_heads: assert nhead == 8 print(nhead) @@ -159,19 +159,19 @@ def test_complex_tril_mask_data_parallel(): print('outputs:') for output in outputs: print(output) - assert output.shape == [N * num_heads // 4, L, L] + assert output.shape == [N * num_head // 4, L, L] def test_complex_tril_mask_head_parallel(): L = 64 # seq len N = 16 # batch - num_heads = 8 - input = IRFullTensor(shape=[N * num_heads, L, L], name='hidden').tosub() + num_head = 8 + input = IRFullTensor(shape=[N * num_head, L, L], name='hidden').tosub() semantic_op = complex.CubeComplexTrilMask( signature = 'cube.runtime.function.complex.trill_mask', - inputs = [input, num_heads], + inputs = [input, num_head], ) semantic_op.infer_shape() @@ -191,33 +191,33 @@ def test_complex_tril_mask_head_parallel(): print('inputs:') for input in inputs: print(input) - assert input.shape == [N * num_heads // 4, L, L] + assert input.shape == [N * num_head // 4, L, L] - sub_heads = [node.kwargs['num_heads'] for node in nodes] - print('num_heads:') + sub_heads = [node.kwargs['num_head'] for node in nodes] + print('num_head:') for nhead in sub_heads: - assert nhead == num_heads // 4 + assert nhead == num_head // 4 print(nhead) outputs = [node.outputs(0) for node in nodes] print('outputs:') for output in outputs: print(output) - assert output.shape == [N * num_heads // 4, L, L] + assert output.shape == [N * num_head // 4, L, L] def test_complex_attn_view_data_parallel(): L = 64 # seq len N = 16 # batch - num_heads = 8 + num_head = 8 dim_head = 128 input = IRFullTensor( - shape=[N * num_heads, L, dim_head], name='hidden').tosub() + shape=[N * num_head, L, dim_head], name='hidden').tosub() semantic_op = complex.CubeComplexAttnView( signature = 'cube.runtime.function.complex.trill_mask', - inputs = [input, num_heads], + inputs = [input, num_head], ) semantic_op.infer_shape() @@ -237,33 +237,33 @@ def test_complex_attn_view_data_parallel(): print('inputs:') for input in inputs: print(input) - assert input.shape == [N * num_heads // 4, L, dim_head] + assert input.shape == [N * num_head // 4, L, dim_head] - sub_heads = [node.kwargs['num_heads'] for node in nodes] - print('num_heads:') + sub_heads = [node.kwargs['num_head'] for node in nodes] + print('num_head:') for nhead in sub_heads: - assert nhead == num_heads + assert nhead == num_head print(nhead) outputs = [node.outputs(0) for node in nodes] print('outputs:') for output in outputs: print(output) - assert output.shape == [L, N // 4, num_heads * dim_head] + assert output.shape == [L, N // 4, num_head * dim_head] def test_complex_attn_view_head_parallel(): L = 64 # seq len N = 16 # batch - num_heads = 8 + num_head = 8 dim_head = 128 input = IRFullTensor( - shape=[N * num_heads, L, dim_head], name='hidden').tosub() + shape=[N * num_head, L, dim_head], name='hidden').tosub() semantic_op = complex.CubeComplexAttnView( signature = 'cube.runtime.function.complex.trill_mask', - inputs = [input, num_heads], + inputs = [input, num_head], ) semantic_op.infer_shape() @@ -283,16 +283,179 @@ def test_complex_attn_view_head_parallel(): print('inputs:') for input in inputs: print(input) - assert input.shape == [N * num_heads // 4, L, dim_head] + assert input.shape == [N * num_head // 4, L, dim_head] - sub_heads = [node.kwargs['num_heads'] for node in nodes] - print('num_heads:') + sub_heads = [node.kwargs['num_head'] for node in nodes] + print('num_head:') for nhead in sub_heads: - assert nhead == num_heads // 4 + assert nhead == num_head // 4 print(nhead) outputs = [node.outputs(0) for node in nodes] print('outputs:') for output in outputs: print(output) - assert output.shape == [L, N, num_heads * dim_head // 4] + assert output.shape == [L, N, num_head * dim_head // 4] + + +def test_complex_self_attention_head_parallel(): + L = 64 # seq len + N = 16 # batch + num_head = 8 + dim_head = 128 + E = num_head * dim_head + + input = IRFullTensor( + shape=[L, N, E], name='hidden').tosub() + w_qkv = IRFullTensor( + shape=[3 * num_head * dim_head, num_head * dim_head], name='wqkv').tosub() + w_out = IRFullTensor( + shape=[num_head * dim_head, num_head * dim_head], name='wout').tosub() + + semantic_op = complex.CubeComplexSelfAttention( + signature = 'cube.runtime.function.complex.self_attn', + inputs = [input, w_qkv, w_out, num_head, dim_head, 0.5], + ) + semantic_op.infer_shape() + + op_head = complex.CubeSelfAttentionHeadParallel(semantic_op) + + assert op_head.satisfy(config=dict(chunk_num=8)) + assert not op_head.satisfy(config=dict(chunk_num=16)) + + nodes = op_head.instantiate(semantic_op, config=dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexSelfAttention) + + for idx, node in enumerate(nodes): + assert node.outputs(0).shape == [L, N, E] + assert node.outputs(0).val_map == ValueMap(idx, 4) + assert node.kwargs['num_head'] == num_head // 4 + assert node.inputs(0).shape == [L, N, E] + assert node.inputs(1).shape == [3 * E // 4, E] + assert node.inputs(2).shape == [E, E // 4] + + +def test_complex_self_attention_data_parallel(): + L = 64 # seq len + N = 16 # batch + num_head = 8 + dim_head = 128 + E = num_head * dim_head + + input = IRFullTensor( + shape=[L, N, E], name='hidden').tosub() + w_qkv = IRFullTensor( + shape=[3 * num_head * dim_head, num_head * dim_head], name='wqkv').tosub() + w_out = IRFullTensor( + shape=[num_head * dim_head, num_head * dim_head], name='wout').tosub() + + semantic_op = complex.CubeComplexSelfAttention( + signature = 'cube.runtime.function.complex.self_attn', + inputs = [input, w_qkv, w_out, num_head, dim_head, 0.5], + ) + semantic_op.infer_shape() + + op_head = complex.CubeSelfAttentionDataParallel(semantic_op) + + assert op_head.satisfy(config=dict(chunk_num=8)) + assert not op_head.satisfy(config=dict(chunk_num=32)) + + nodes = op_head.instantiate(semantic_op, config=dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexSelfAttention) + + for idx, node in enumerate(nodes): + assert node.outputs(0).shape == [L, N // 4, E] + assert node.outputs(0).val_map == ValueMap(0, 1) + assert node.kwargs['num_head'] == num_head + assert node.inputs(0).shape == [L, N // 4, E] + assert node.inputs(1).shape == [3 * E, E] + assert node.inputs(2).shape == [E, E] + + +def test_complex_feedforward_tensor_parallel(): + L = 64 # seq len + N = 16 # batch + E = 1024 + + input = IRFullTensor( + shape=[L, N, E], name='hidden').tosub() + w_proj1 = IRFullTensor( + shape=[4 * E, E], name='proj1').tosub() + w_bias1 = IRFullTensor( + shape=[4 * E,], name='bias1').tosub() + w_proj2 = IRFullTensor( + shape=[E, 4 * E], name='proj2').tosub() + w_bias2 = IRFullTensor( + shape=[E,], name='bias2').tosub() + + semantic_op = complex.CubeComplexFeedForward( + signature = 'cube.runtime.function.complex.feedforward', + inputs = [input, w_proj1, w_bias1, w_proj2, w_bias2], + ) + semantic_op.infer_shape() + + op_head = complex.CubeFeedForwardTensorParallel(semantic_op) + + assert op_head.satisfy(config=dict(chunk_num=8)) + assert op_head.satisfy(config=dict(chunk_num=32)) + + nodes = op_head.instantiate(semantic_op, config=dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexFeedForward) + + for idx, node in enumerate(nodes): + assert node.outputs(0).shape == [L, N, E] + assert node.outputs(0).val_map == ValueMap(idx, 4) + assert node.inputs(0).shape == [L, N, E] + assert node.inputs(1).shape == [4 * E // 4, E] + assert node.inputs(2).shape == [4 * E // 4,] + assert node.inputs(3).shape == [E, 4 * E // 4] + assert node.inputs(4).shape == [E,] + assert node.inputs(4).val_map == ValueMap(idx, 4) + + +def test_complex_feedforward_data_parallel(): + L = 64 # seq len + N = 16 # batch + E = 1024 + + input = IRFullTensor( + shape=[L, N, E], name='hidden').tosub() + w_proj1 = IRFullTensor( + shape=[4 * E, E], name='proj1').tosub() + w_bias1 = IRFullTensor( + shape=[4 * E,], name='bias1').tosub() + w_proj2 = IRFullTensor( + shape=[E, 4 * E], name='proj2').tosub() + w_bias2 = IRFullTensor( + shape=[E,], name='bias2').tosub() + + semantic_op = complex.CubeComplexFeedForward( + signature = 'cube.runtime.function.complex.feedforward', + inputs = [input, w_proj1, w_bias1, w_proj2, w_bias2], + ) + semantic_op.infer_shape() + + op_head = complex.CubeFeedForwardDataParallel(semantic_op) + + assert op_head.satisfy(config=dict(chunk_num=8)) + assert not op_head.satisfy(config=dict(chunk_num=32)) + + nodes = op_head.instantiate(semantic_op, config=dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexFeedForward) + + for idx, node in enumerate(nodes): + assert node.outputs(0).shape == [L, N // 4, E] + assert node.outputs(0).val_map == ValueMap(0, 1) + assert node.inputs(0).shape == [L, N // 4, E] + assert node.inputs(1).shape == [4 * E, E] + assert node.inputs(2).shape == [4 * E,] + assert node.inputs(3).shape == [E, 4 * E] + assert node.inputs(4).shape == [E,] From af666c73fce31412239f99fa97f594e56da294ce Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 00:23:42 +0800 Subject: [PATCH 0389/1892] enbale multiple transformer layers --- cube/algorithm/factory.py | 6 + cube/graph/parser/mapping.py | 4 + .../transformer/policy/megatron_parallel.py | 48 +++++ examples/transformer/transformers.py | 193 ++++++++++++++++++ 4 files changed, 251 insertions(+) create mode 100644 examples/transformer/policy/megatron_parallel.py create mode 100644 examples/transformer/transformers.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 6d5bb33e..7a850034 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -98,5 +98,11 @@ def _load_predefined_algos(self): self.register(complex.CubeComplexAttnView, complex.CubeAttnViewDataParallel, tag='data') self.register(complex.CubeComplexAttnView, complex.CubeAttnViewHeadParallel, tag='head') + self.register(complex.CubeComplexSelfAttention, complex.CubeSelfAttentionDataParallel, tag='data') + self.register(complex.CubeComplexSelfAttention, complex.CubeSelfAttentionHeadParallel, tag='head') + + self.register(complex.CubeComplexFeedForward, complex.CubeFeedForwardDataParallel, tag='data') + self.register(complex.CubeComplexFeedForward, complex.CubeFeedForwardTensorParallel, tag='tensor') + import cube.algorithm.ops.memory as mem self.register(mem.Transpose, mem.TransposeDimParallel, tag='dim') diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index af0c9bbc..b9fbf36f 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -65,5 +65,9 @@ def map(signature: str) -> IRFwOperation: __customize('attn_view'): function.CubeComplexAttnView, + __customize('self_attn'): function.CubeComplexSelfAttention, + + __customize('feedforward'): function.CubeComplexFeedForward, + } diff --git a/examples/transformer/policy/megatron_parallel.py b/examples/transformer/policy/megatron_parallel.py new file mode 100644 index 00000000..84aac0d5 --- /dev/null +++ b/examples/transformer/policy/megatron_parallel.py @@ -0,0 +1,48 @@ +from torch.nn.modules import dropout +from cube.graph import IRGraph +from cube.graph.operator.function import CubeComplexFeedForward, CubeComplexSelfAttention +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using tensor parallel + """ + print('> transforming graph...') + ndevs = resource.ngpus + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + for fnode in fnodes: + if isinstance(fnode, CubeComplexSelfAttention): + algo = fnode.algorithms('head') + sub_nodes = graph.partition(fnode, algo, config=dict(chunk_num=ndevs)) + elif isinstance(fnode, CubeComplexFeedForward): + algo = fnode.algorithms('tensor') + sub_nodes = graph.partition(fnode, algo, config=dict(chunk_num=ndevs)) + else: + sub_nodes = graph.replicate(fnode, ndevs) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + # print(graph) + # assert False + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + sugraph.partial_set_order(fsus, lazy=False) + return sugraph diff --git a/examples/transformer/transformers.py b/examples/transformer/transformers.py new file mode 100644 index 00000000..ef4dd749 --- /dev/null +++ b/examples/transformer/transformers.py @@ -0,0 +1,193 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/transformer/transformers.py +""" + +import torch +from torch import nn +import cube + +from examples.transformer.policy.megatron_parallel import transform_policy +from examples.transformer.policy.megatron_parallel import schedule_policy + +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +class MultiHeadSelfAttention(nn.Module): + + def __init__(self, embed_dim, heads, dropout): + super().__init__() + + self.num_head = heads + self.dim_head = embed_dim // heads + self.dropout = dropout + + self.weight_qkv = torch.nn.Parameter(torch.empty( + 3 * embed_dim, embed_dim + )) + self.weight_out = torch.nn.Parameter(torch.empty( + embed_dim, embed_dim + )) + + def forward(self, x): + """ + Multi-Head Self-Attention. + + L: sequence length + N: batch size + E: embedding size + + Inputs: + hidden_state: [L, N, E] + w_qkv : [3 * num_head * dim_head, E] + w_out : [E, E] + + Outputs: + hidden_state: [L, N, E] + """ + + hidden_state = cube.runtime.function.complex.self_attn( + x, self.weight_qkv, self.weight_out, + self.num_head, self.dim_head, self.dropout + ) + return hidden_state + + +class FFN(torch.nn.Module): + + def __init__(self, hidden_size: int): + super().__init__() + self.proj1_weight = torch.nn.Parameter( + torch.empty(4 * hidden_size, hidden_size) + ) + self.proj1_bias = torch.nn.Parameter( + torch.empty(4 * hidden_size) + ) + self.proj2_weight = torch.nn.Parameter( + torch.empty(hidden_size, 4 * hidden_size) + ) + self.proj2_bias = torch.nn.Parameter( + torch.empty(hidden_size) + ) + + def forward(self, hidden_states): + hidden_states = cube.runtime.function.complex.feedforward( + hidden_states, + self.proj1_weight, self.proj1_bias, + self.proj2_weight, self.proj2_bias + ) + return hidden_states + + +class TransformerLayer(torch.nn.Module): + + def __init__(self, hidden_size, head_num, dropout): + super().__init__() + # layer norm + self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + + self.attention = MultiHeadSelfAttention(hidden_size, head_num, dropout) + self.attn_dropout = torch.nn.Dropout(dropout) + + self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) + self.ffn = FFN(hidden_size) + self.ffn_dropout = torch.nn.Dropout(dropout) + + def forward(self, hidden_states): + # Attention + in_attn_norm = self.input_layernorm(hidden_states) + attn_out = self.attention(in_attn_norm) + # residual + attn_out = self.attn_dropout(attn_out) + # residual = attn_out + hidden_states + residual = attn_out * 2 + # ffn + in_ffn_norm = self.ffn_layernorm(residual) + ffn_out = self.ffn(in_ffn_norm) + # residual + ffn_out = self.ffn_dropout(ffn_out) + # ffn_out = ffn_out + residual + ffn_out = ffn_out * 2 + + loss = torch.sum(ffn_out) + return loss + + +class Transformers(torch.nn.Module): + + def __init__(self, hidden_size, head_num, layer_num): + super().__init__() + + self.transformer1 = TransformerLayer(hidden_size, head_num, 0.5) + self.transformer2 = TransformerLayer(hidden_size, head_num, 0.5) + self.transformer3 = TransformerLayer(hidden_size, head_num, 0.5) + self.transformer4 = TransformerLayer(hidden_size, head_num, 0.5) + + def forward(self, hidden_states): + + hidden_states = self.transformer1(hidden_states) + hidden_states = self.transformer2(hidden_states) + hidden_states = self.transformer3(hidden_states) + hidden_states = self.transformer4(hidden_states) + loss = torch.sum(hidden_states) + return loss + + +def train(): + L = 512 # seq len + N = 16 # batch size + # configs: [hidden size, num_head] + # E, num_head = [1536, 16] # 1.2B model + # E, num_head = [1920, 20] # 2.5B model + # E, num_head = [2304, 24] # 4.2B model + E, num_head = [3072, 32] # 8.7B model + + + model = TransformerLayer( + hidden_size=E, head_num=num_head, dropout=0.5 + ) + model = cube.SemanticModel( + model, input_shapes=([L, N, E],), + ) + + dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) + + @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + +if __name__ == '__main__': + + cube.init() + train() \ No newline at end of file From 52744a550ebd964ac6743b854d1197e0cea5501d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 01:08:06 +0800 Subject: [PATCH 0390/1892] fix bugs on op replica --- cube/graph/graph.py | 15 ++++------- cube/graph/operator/operator.py | 44 +++++++++++++++++++++++++++++++-- cube/ir/cten.py | 4 +-- 3 files changed, 49 insertions(+), 14 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index dab05702..5bd15c66 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -267,21 +267,16 @@ def replicate(self, op: IRCell, times=1): if op not in self.nodes(): raise RuntimeError(f"Op {op} not exsits") + cpy_op = op.replicate() + if op.mirror is not None: + cpy_mirror_op = op.mirror.replicate() ops = [op] mirror_ops = [op.mirror] for _ in range(times - 1): - cpy_op = copy.copy(op) - for idx, input in enumerate(op.inputs()): - cpy_op.set_input(idx, input) - for idx, output in enumerate(op.outputs()): - cpy_op.set_output(idx, output) + cpy_op = op.replicate() if op.mirror is not None: - cpy_mirror_op = copy.copy(op.mirror) - for idx, input in enumerate(op.mirror.inputs()): - cpy_mirror_op.set_input(idx, input) - for idx, output in enumerate(op.mirror.outputs()): - cpy_mirror_op.set_output(idx, output) + cpy_mirror_op = op.mirror.replicate() mirror_ops.append(cpy_mirror_op) IRCell.make_pair(cpy_op, cpy_mirror_op) ops.append(cpy_op) diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 48f94838..9dcd28f5 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,8 +1,6 @@ from typing import Any, Optional, Union, List import copy -from torch._C import is_anomaly_enabled - from cube.ir.cten import IRTensor, IRCell from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory @@ -71,6 +69,20 @@ def set_input(self, input_index: int, val: Any): val.parent._add_fdst_cell(self) return super().set_input(input_index, val) + def replicate(self): + """ + Replicate the Operation + """ + cpy = copy.copy(self) + cpy._device = list() + cpy._inputs = copy.copy(self._inputs) + cpy._outputs = copy.copy(self._outputs) + cpy._mirror = None + cpy._tag = None + cpy.clear_predecessor() + cpy.clear_successor() + return cpy + def __repr__(self): inputs = list() for tensor in self.inputs(): @@ -126,6 +138,20 @@ def __init__(self, data_num, grad_num, name='backward'): output_length=data_num ) + def replicate(self): + """ + Replicate the backward op + """ + cpy = copy.copy(self) + cpy._device = list() + cpy._inputs = copy.copy(self._inputs) + cpy._outputs = copy.copy(self._outputs) + cpy._mirror = None + cpy._tag = None + cpy.clear_predecessor() + cpy.clear_successor() + return cpy + def datas(self, index: Optional[int] = None) -> Union[List[Any], Any]: if index is None: return self.inputs()[:self.data_num] @@ -233,6 +259,20 @@ def __init__(self, data_num: int, batch_dims: List[int], name='dataloader'): super().__init__(name, signature, 0, data_num) self.batch_dims = batch_dims + def replicate(self): + """ + Replicate the Operation + """ + cpy = copy.copy(self) + cpy._device = list() + cpy._inputs = copy.copy(self._inputs) + cpy._outputs = copy.copy(self._outputs) + cpy._mirror = None + cpy._tag = None + cpy.clear_predecessor() + cpy.clear_successor() + return cpy + def get_batch_dims(self): return copy.copy(self.batch_dims) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 4e63318f..cec17531 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -75,7 +75,7 @@ def __eq__(self, other): @property def device(self): - return list(self._device) + return copy.copy(self._device) @device.setter def device(self, device_id: Union[int, List[int]]): @@ -86,7 +86,7 @@ def device(self, device_id: Union[int, List[int]]): device_id = [device_id] if not all([isinstance(devid, int) for devid in device_id]): raise KeyError("Require device Union[int, List[int]]") - self._device = device_id + self._device = copy.copy(list(device_id)) @property def mirror(self): From acf07eae031d9f02b55019b8b7fd13ef4014ee16 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 01:15:35 +0800 Subject: [PATCH 0391/1892] megatron partition policy for transformers --- .../transformer/policy/megatron_parallel.py | 21 ++++++++++++++++--- examples/transformer/transformer.py | 2 +- examples/transformer/transformers.py | 2 +- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/examples/transformer/policy/megatron_parallel.py b/examples/transformer/policy/megatron_parallel.py index 84aac0d5..2ba4ec59 100644 --- a/examples/transformer/policy/megatron_parallel.py +++ b/examples/transformer/policy/megatron_parallel.py @@ -12,18 +12,33 @@ def transform_policy(graph: IRGraph, resource): """ print('> transforming graph...') ndevs = resource.ngpus + dp = 2 + tp = ndevs // dp fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] for fnode in fnodes: + sub_nodes = list() if isinstance(fnode, CubeComplexSelfAttention): algo = fnode.algorithms('head') - sub_nodes = graph.partition(fnode, algo, config=dict(chunk_num=ndevs)) + tp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=tp)) + for tp_node in tp_nodes: + algo = tp_node.algorithms('data') + dp_nodes = graph.partition(tp_node, algo, config=dict(chunk_num=dp)) + sub_nodes += dp_nodes elif isinstance(fnode, CubeComplexFeedForward): algo = fnode.algorithms('tensor') - sub_nodes = graph.partition(fnode, algo, config=dict(chunk_num=ndevs)) + tp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=tp)) + for tp_node in tp_nodes: + algo = tp_node.algorithms('data') + dp_nodes = graph.partition(tp_node, algo, config=dict(chunk_num=dp)) + sub_nodes += dp_nodes else: - sub_nodes = graph.replicate(fnode, ndevs) + rep_nodes = graph.replicate(fnode, times=tp) + for rep_node in rep_nodes: + algo = rep_node.algorithms('dim') + dp_nodes = graph.partition(rep_node, algo, config=dict(dim=1, chunk_num=dp)) + sub_nodes += dp_nodes for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx diff --git a/examples/transformer/transformer.py b/examples/transformer/transformer.py index b856da03..c28b6407 100644 --- a/examples/transformer/transformer.py +++ b/examples/transformer/transformer.py @@ -2,7 +2,7 @@ example: python -m torch.distributed.launch \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ diff --git a/examples/transformer/transformers.py b/examples/transformer/transformers.py index ef4dd749..c1331459 100644 --- a/examples/transformer/transformers.py +++ b/examples/transformer/transformers.py @@ -2,7 +2,7 @@ example: python -m torch.distributed.launch \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ From 0d3435491835f28f99bfec05513a9e6ef2f05da3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 10:17:09 +0800 Subject: [PATCH 0392/1892] allreduce contiguous --- cube/runtime/collectives.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index b3963d6c..4ab66dfd 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -106,6 +106,8 @@ def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): # print(f'{torch.distributed.get_rank()}: all_reduce...') assert len(tensors) == 1 tensor = tensors[0] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() tensor = tensor.detach() tensor = tensor.requires_grad_() From d9ac28a2cbbb42dbd1b83df10b5ba0b3b9586bcc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 11:16:16 +0800 Subject: [PATCH 0393/1892] transformers policy for megatron --- .../transformer/policy/megatron_parallel.py | 52 +++++++++++-------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/examples/transformer/policy/megatron_parallel.py b/examples/transformer/policy/megatron_parallel.py index 2ba4ec59..bab64c54 100644 --- a/examples/transformer/policy/megatron_parallel.py +++ b/examples/transformer/policy/megatron_parallel.py @@ -1,9 +1,8 @@ -from torch.nn.modules import dropout from cube.graph import IRGraph from cube.graph.operator.function import CubeComplexFeedForward, CubeComplexSelfAttention from cube.schedule.su import SUType from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation +from cube.graph.operator.operator import IRDataOperation, IRFwOperation def transform_policy(graph: IRGraph, resource): @@ -15,34 +14,44 @@ def transform_policy(graph: IRGraph, resource): dp = 2 tp = ndevs // dp + # dataloader + + dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] + for dnode in dnodes: + algo = dnode.algorithms('data') + dp_nodes = graph.partition(dnode, algo, config=dict(chunk_num=dp)) + for idx, dp_node in enumerate(dp_nodes): + dp_node.tag = idx * tp + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] for fnode in fnodes: sub_nodes = list() if isinstance(fnode, CubeComplexSelfAttention): - algo = fnode.algorithms('head') - tp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=tp)) - for tp_node in tp_nodes: - algo = tp_node.algorithms('data') - dp_nodes = graph.partition(tp_node, algo, config=dict(chunk_num=dp)) - sub_nodes += dp_nodes + algo = fnode.algorithms('data') + dp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=dp)) + for dp_node in dp_nodes: + algo = dp_node.algorithms('head') + tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) + sub_nodes += tp_nodes elif isinstance(fnode, CubeComplexFeedForward): - algo = fnode.algorithms('tensor') - tp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=tp)) - for tp_node in tp_nodes: - algo = tp_node.algorithms('data') - dp_nodes = graph.partition(tp_node, algo, config=dict(chunk_num=dp)) - sub_nodes += dp_nodes + algo = fnode.algorithms('data') + dp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=dp)) + for dp_node in dp_nodes: + algo = dp_node.algorithms('tensor') + tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) + sub_nodes += tp_nodes else: - rep_nodes = graph.replicate(fnode, times=tp) - for rep_node in rep_nodes: - algo = rep_node.algorithms('dim') - dp_nodes = graph.partition(rep_node, algo, config=dict(dim=1, chunk_num=dp)) - sub_nodes += dp_nodes + # note replicate should put in the last due to bugs: + algo = fnode.algorithms('dim') + dp_nodes = graph.partition(fnode, algo, config=dict(dim=1, chunk_num=dp)) + for dp_node in dp_nodes: + rep_nodes = graph.replicate(dp_node, times=tp) + sub_nodes += rep_nodes for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx - # print(graph) + print(graph) # assert False return graph @@ -53,7 +62,8 @@ def schedule_policy(sugraph: SUGraph, resource): """ for su in sugraph.sus(): if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) + devid = su.tag[0] + sugraph.assign(su, devid) for su in sugraph.fsus(): devid = su.tag[0] sugraph.assign(su, devid) From 7cca7da1c39c606f3229f499770f43dccc6497ce Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 13:45:49 +0800 Subject: [PATCH 0394/1892] embedding --- cube/algorithm/ops/complex.py | 96 +++++++++++++++++++++++++++++++ cube/graph/operator/function.py | 28 ++++++++- cube/runtime/function/__init__.py | 3 +- cube/runtime/function/complex.py | 24 ++++++++ cube/runtime/function/function.py | 18 ------ tests/algorithm/test_complex.py | 82 ++++++++++++++++++++++++++ 6 files changed, 230 insertions(+), 21 deletions(-) delete mode 100644 cube/runtime/function/function.py diff --git a/cube/algorithm/ops/complex.py b/cube/algorithm/ops/complex.py index 06f0a60a..9e985d65 100644 --- a/cube/algorithm/ops/complex.py +++ b/cube/algorithm/ops/complex.py @@ -8,6 +8,7 @@ from cube.graph.operator.function import CubeComplexAttnView from cube.graph.operator.function import CubeComplexSelfAttention from cube.graph.operator.function import CubeComplexFeedForward +from cube.graph.operator.function import CubeComplexEmbedding _kWaitDecision = None @@ -560,3 +561,98 @@ def instantiate(self, node: CubeComplexFeedForward, config: Dict): node.set_output(0, outs[idx]) nodes.append(node) return nodes + + +class CubeEmbedDataParallel(GenericDistAlgo): + + def __init__(self, node: CubeComplexEmbedding): + if not isinstance(node, CubeComplexEmbedding): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.dims = node.inputs(0).shape + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + dim = int(config['dim']) + if dim >= len(self.dims): + return False + if chunk_num > 0 and self.dims[dim] % chunk_num == 0: + return True + return False + + def instantiate(self, node: CubeComplexEmbedding, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + dim = int(config['dim']) + + input = node.inputs(0) + weight = node.inputs(1) + start = node.kwargs['start'] + stop = node.kwargs['stop'] + + out = node.outputs(0) + + ins = split_axis(input, dim, self.chunk_num) + outs = split_axis(out, dim, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + inputs = [ + ins[idx], weight, start, stop + ] + node = CubeComplexEmbedding( + signature = 'cube.runtime.function.complex.embedding', + inputs = inputs, + ) + node.set_output(0, outs[idx]) + nodes.append(node) + return nodes + + +class CubeEmbedShardingParallel(GenericDistAlgo): + + def __init__(self, node: CubeComplexEmbedding): + if not isinstance(node, CubeComplexEmbedding): + raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") + super().__init__(node) + self.chunk_num = _kWaitDecision + self.vocabs = node.inputs(1).shape[0] + + def satisfy(self, config: Dict): + chunk_num = int(config['chunk_num']) + if chunk_num > 0 and self.vocabs % chunk_num == 0: + return True + return False + + def instantiate(self, node: CubeComplexEmbedding, config: Dict): + if not self.satisfy(config): + raise RuntimeError("Instantiate failed. Condition not satisfied.") + self.chunk_num = int(config['chunk_num']) + + input = node.inputs(0) + weight = node.inputs(1) + start = node.kwargs['start'] + stop = node.kwargs['stop'] + shard = (stop - start) // self.chunk_num + + out = node.outputs(0) + + ws = split_axis(weight, 0, self.chunk_num) + outs = split_value(out, self.chunk_num) + + nodes = list() + for idx in range(self.chunk_num): + shard_start = start + shard * idx + shard_stop = shard_start + shard + inputs = [ + input, ws[idx], shard_start, shard_stop + ] + node = CubeComplexEmbedding( + signature = 'cube.runtime.function.complex.embedding', + inputs = inputs, + ) + node.set_output(0, outs[idx]) + nodes.append(node) + return nodes diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function.py index 2d553e66..a589c4ad 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function.py @@ -479,7 +479,7 @@ class CubeComplexFeedForward(IRFwOperation): """ def __init__(self, signature, inputs, name='selfattn', **kwargs): if len(inputs) != 5: - raise RuntimeError(f"Expected 6 inputs but got {input}") + raise RuntimeError(f"Expected 6 inputs but got {inputs}") super().__init__( name, signature, input_length = 5, @@ -495,6 +495,32 @@ def infer_shape(self): return True +class CubeComplexEmbedding(IRFwOperation): + """ + Embedding + """ + def __init__(self, signature, inputs, name='embedding', **kwargs): + if len(inputs) != 4: + raise RuntimeError(f"Expected 4 inputs but got {inputs}") + input, weight = inputs[0], inputs[1] + start, stop = inputs[2], inputs[3] + super().__init__( + name, signature, + input_length = 2, + output_length = 1 + ) + self.set_input(0, input) + self.set_input(1, weight) + self.kwargs['start'] = start + self.kwargs['stop'] = stop + + def infer_shape(self): + if self.inputs(0).shape is None or self.inputs(1).shape is None: + return False + self.outputs(0).shape = self.inputs(0).shape + [self.inputs(1).shape[1]] + return True + + class UnkownOperator(IRFwOperation): def __init__(self, signature, inputs, name='unknown_op', n_outputs=None): diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py index e46ddb50..a8c9ff96 100644 --- a/cube/runtime/function/__init__.py +++ b/cube/runtime/function/__init__.py @@ -1,3 +1,2 @@ import cube.runtime.function.complex as complex -from cube.runtime.function.complex import * -from cube.runtime.function.function import embedding \ No newline at end of file +from cube.runtime.function.complex import * \ No newline at end of file diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index 0f4124c6..b3be2b1b 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -141,3 +141,27 @@ def feedforward(hidden_state, w_proj1, w_bias1, w_proj2, w_bias2): out = F.gelu(out) out = F.linear(out, w_proj2, w_bias2) return out + + +def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): + """ + Embedding + + Inputs: + input: torch.Tensor [*] + weight: [vocab size, embed size] + start: int + stop: int + + Outputs: + output: [*, embed_size] + """ + input_mask = (input < start) | (input >= stop) + masked_input = input.clone() - start + masked_input[input_mask] = 0 + output = F.embedding( + masked_input, weight, + None, None, 2.0, False, False + ) + output[input_mask, :] = 0.0 + return output diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py deleted file mode 100644 index 9704f3a1..00000000 --- a/cube/runtime/function/function.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import torch.nn.functional as F - - -def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): - """ - Embedding - """ - input_mask = (input < start) | (input >= stop) - masked_input = input.clone() - start - masked_input[input_mask] = 0 - output = F.embedding( - masked_input, weight, - None, None, 2.0, False, False - ) - output[input_mask, :] = 0.0 - return output - diff --git a/tests/algorithm/test_complex.py b/tests/algorithm/test_complex.py index 958fd6de..3372f24a 100644 --- a/tests/algorithm/test_complex.py +++ b/tests/algorithm/test_complex.py @@ -459,3 +459,85 @@ def test_complex_feedforward_data_parallel(): assert node.inputs(2).shape == [4 * E,] assert node.inputs(3).shape == [E, 4 * E] assert node.inputs(4).shape == [E,] + + +def test_embed_shard_parallel(): + L = 64 # seq len + N = 16 # batch + vocab = 50304 + E = 1024 + + ids = IRFullTensor(shape=[L, N], name='hidden').tosub() + weight = IRFullTensor(shape=[vocab, E], name='hidden').tosub() + start = 0 + stop = vocab + + semantic_op = complex.CubeComplexEmbedding( + signature = 'cube.runtime.function.complex.embedding', + inputs = [ids, weight, start, stop] + ) + semantic_op.infer_shape() + + assert semantic_op.outputs(0).shape == [L, N, E] + + op_shard = complex.CubeEmbedShardingParallel(semantic_op) + + assert op_shard.satisfy(config=dict(chunk_num=8)) + assert op_shard.satisfy(config=dict(chunk_num=32)) + assert not op_shard.satisfy(config=dict(chunk_num=256)) + + nodes = op_shard.instantiate(semantic_op, config=dict(chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexEmbedding) + + start = semantic_op.kwargs['start'] + stop = semantic_op.kwargs['stop'] + shard = (stop - start) // 4 + for idx, node in enumerate(nodes): + assert node.outputs(0).shape == [L, N, E] + assert node.outputs(0).val_map == ValueMap(idx, 4) + assert node.inputs(0).shape == [L, N] + assert node.inputs(1).shape == [vocab // 4, E] + assert node.kwargs['start'] == start + idx * shard + assert node.kwargs['stop'] == start + (idx + 1) * shard + + +def test_embed_shard_parallel(): + L = 64 # seq len + N = 16 # batch + vocab = 50304 + E = 1024 + + ids = IRFullTensor(shape=[L, N], name='hidden').tosub() + weight = IRFullTensor(shape=[vocab, E], name='hidden').tosub() + start = 0 + stop = vocab + + semantic_op = complex.CubeComplexEmbedding( + signature = 'cube.runtime.function.complex.embedding', + inputs = [ids, weight, start, stop] + ) + semantic_op.infer_shape() + + assert semantic_op.outputs(0).shape == [L, N, E] + + op_shard = complex.CubeEmbedDataParallel(semantic_op) + + assert op_shard.satisfy(config=dict(dim=1, chunk_num=8)) + assert not op_shard.satisfy(config=dict(dim=1, chunk_num=32)) + + nodes = op_shard.instantiate(semantic_op, config=dict(dim=1, chunk_num=4)) + assert len(nodes) == 4 + for node in nodes: + assert isinstance(node, complex.CubeComplexEmbedding) + + start = semantic_op.kwargs['start'] + stop = semantic_op.kwargs['stop'] + for idx, node in enumerate(nodes): + assert node.outputs(0).shape == [L, N // 4, E] + assert node.outputs(0).val_map == ValueMap(0, 1) + assert node.inputs(0).shape == [L, N // 4] + assert node.inputs(1).shape == [vocab, E] + assert node.kwargs['start'] == start + assert node.kwargs['stop'] == stop From fc7a4132bba9a925826687ad118a3ec8899efa29 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 15:50:30 +0800 Subject: [PATCH 0395/1892] gpt example --- cube/algorithm/factory.py | 3 + cube/graph/parser/mapping.py | 2 + cube/runtime/executor.py | 3 +- cube/runtime/function/complex.py | 1 + examples/gpt/gpt.py | 144 ++++++++++------------- examples/gpt/policy/megatron_parallel.py | 130 ++++++++++++++++++++ examples/transformer/transformers.py | 10 +- 7 files changed, 204 insertions(+), 89 deletions(-) create mode 100644 examples/gpt/policy/megatron_parallel.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 7a850034..a907047e 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -104,5 +104,8 @@ def _load_predefined_algos(self): self.register(complex.CubeComplexFeedForward, complex.CubeFeedForwardDataParallel, tag='data') self.register(complex.CubeComplexFeedForward, complex.CubeFeedForwardTensorParallel, tag='tensor') + self.register(complex.CubeComplexEmbedding, complex.CubeEmbedDataParallel, tag='data') + self.register(complex.CubeComplexEmbedding, complex.CubeEmbedShardingParallel, tag='shard') + import cube.algorithm.ops.memory as mem self.register(mem.Transpose, mem.TransposeDimParallel, tag='dim') diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index b9fbf36f..fb4999f5 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -69,5 +69,7 @@ def map(signature: str) -> IRFwOperation: __customize('feedforward'): function.CubeComplexFeedForward, + __customize('embedding'): function.CubeComplexEmbedding, + } diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 8ff3c26f..abdecdc7 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -38,7 +38,8 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr grads = [None] * len(input_tensors) if len(inputs) != 0: # print('backwarding... ') - in_grads = torch.autograd.grad(output_tensors, inputs, output_tensor_grads) + in_grads = torch.autograd.grad( + output_tensors, inputs, output_tensor_grads, allow_unused=True) for idx, grad in zip(indices, in_grads): tensor = input_tensors[idx] if isinstance(tensor, torch.nn.Parameter): diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py index b3be2b1b..84111c69 100644 --- a/cube/runtime/function/complex.py +++ b/cube/runtime/function/complex.py @@ -156,6 +156,7 @@ def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): Outputs: output: [*, embed_size] """ + input = input.long() input_mask = (input < start) | (input >= stop) masked_input = input.clone() - start masked_input[input_mask] = 0 diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py index 2b8a52d4..a3e2783a 100644 --- a/examples/gpt/gpt.py +++ b/examples/gpt/gpt.py @@ -17,23 +17,22 @@ import cube -from examples.transformer.policy.tensor_parallel import transform_policy -from examples.transformer.policy.tensor_parallel import schedule_policy +from examples.gpt.policy.megatron_parallel import transform_policy +from examples.gpt.policy.megatron_parallel import schedule_policy from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary class MultiHeadSelfAttention(nn.Module): - def __init__(self, seq_len, embed_dim, heads, dropout): + def __init__(self, embed_dim, heads, dropout): super().__init__() - self.seq_len = seq_len - self.embed_dim = embed_dim self.num_head = heads self.dim_head = embed_dim // heads - self.scale = self.dim_head ** -0.5 + self.dropout = dropout self.weight_qkv = torch.nn.Parameter(torch.empty( 3 * embed_dim, embed_dim @@ -41,81 +40,65 @@ def __init__(self, seq_len, embed_dim, heads, dropout): self.weight_out = torch.nn.Parameter(torch.empty( embed_dim, embed_dim )) - self.dropout = nn.Dropout(dropout) def forward(self, x): """ - x: [L, N, E]: seq_len, batch_size, embedding dimension - output: [L, N, E] - """ - # [L, N, E] -> 3 x [L, (N * num_head), dim_head] - q, k, v = cube.runtime.function.toqkv( - x, self.weight_qkv, self.num_head - ) - - # [L, (N * num_head), dim_head] -> [(N * num_head), L, dim_head] - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - # [(N * num_head), L, dim_head] -> [(N * num_head), L, dim_head] - q = q * self.scale - # [(N * num_head), L, dim_head] -> [(N * num_head), dim_head, L] - k = k.transpose(-2, -1) - # [(N * num_head), L, dim_head] * [(N * num_head), dim_head, L] - # -> [(N * num_head), L, L] - attn = torch.bmm(q, k) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = cube.runtime.function.tril_mask(attn, self.num_head) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = F.softmax(attn, dim=-1) + Multi-Head Self-Attention. - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = self.dropout(attn) - # [(N * num_head), L, L] * [(N * num_head), L, dim_head] - # -> [(N * num_head), L, dim_head] - output = torch.bmm(attn, v) + L: sequence length + N: batch size + E: embedding size - # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] - output = cube.runtime.function.attn_view(output, self.num_head) + Inputs: + hidden_state: [L, N, E] + w_qkv : [3 * num_head * dim_head, E] + w_out : [E, E] - # [L, N, num_head * dim_head] * [E, embed_head * dim_head] - # -> [L, N, E] - output = F.linear(output, self.weight_out) - return output + Outputs: + hidden_state: [L, N, E] + """ + + hidden_state = cube.runtime.function.complex.self_attn( + x, self.weight_qkv, self.weight_out, + self.num_head, self.dim_head, self.dropout + ) + return hidden_state class FFN(torch.nn.Module): def __init__(self, hidden_size: int): super().__init__() - self.dense_h_to_4h = torch.nn.Linear( - hidden_size, 4 * hidden_size + self.proj1_weight = torch.nn.Parameter( + torch.empty(4 * hidden_size, hidden_size) + ) + self.proj1_bias = torch.nn.Parameter( + torch.empty(4 * hidden_size) + ) + self.proj2_weight = torch.nn.Parameter( + torch.empty(hidden_size, 4 * hidden_size) ) - self.dense_4h_to_h = torch.nn.Linear( - 4 * hidden_size, hidden_size + self.proj2_bias = torch.nn.Parameter( + torch.empty(hidden_size) ) def forward(self, hidden_states): - # [L, N, E] * [E, 4E] -> [L, N, 4E] - out = self.dense_h_to_4h(hidden_states) - # [L, N, 4E] -> [L, N, 4E] - out = F.gelu(out) - # [L, N, 4E] * [4E, E] -> [L, N, E] - out = self.dense_4h_to_h(out) - return out + hidden_states = cube.runtime.function.complex.feedforward( + hidden_states, + self.proj1_weight, self.proj1_bias, + self.proj2_weight, self.proj2_bias + ) + return hidden_states class TransformerLayer(torch.nn.Module): - def __init__(self, seq_len, hidden_size, head_num, dropout): + def __init__(self, hidden_size, num_head, dropout): super().__init__() # layer norm self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - self.attention = MultiHeadSelfAttention(seq_len, hidden_size, head_num, dropout) + self.attention = MultiHeadSelfAttention(hidden_size, num_head, dropout) self.attn_dropout = torch.nn.Dropout(dropout) self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) @@ -138,22 +121,6 @@ def forward(self, hidden_states): # ffn_out = ffn_out + residual ffn_out = ffn_out * 2 return ffn_out - - -class Embedding(torch.nn.Module): - - def __init__(self, num_embed, dim_embed, dropout): - super().__init__() - self.num_embed = num_embed - self.weight = torch.nn.Parameter( - torch.empty(self.num_embed, dim_embed) - ) - - def forward(self, input): - embeddings = cube.runtime.function.embedding( - 0, self.num_embed, input, self.weight - ) - return embeddings class GPT(torch.nn.Module): @@ -165,6 +132,7 @@ def __init__(self, hidden_size, vocab_size, seqlen_size, self.num_layers = num_layers self.bs = bs self.seqlen = seqlen + self.ntoken = 1.0 / self.bs * self.seqlen # embeddings self.vocab_size = vocab_size @@ -179,9 +147,13 @@ def __init__(self, hidden_size, vocab_size, seqlen_size, self.embed_dropout = torch.nn.Dropout(0.5) # transformer layers - self.layers = torch.nn.ModuleList( - [TransformerLayer(seqlen, hidden_size, num_head, 0.5) for _ in range(num_layers)] - ) + # self.layers = torch.nn.ModuleList( + # [TransformerLayer(seqlen, hidden_size, num_head, 0.5) for _ in range(num_layers)] + # ) + self.transform1 = TransformerLayer(hidden_size, num_head, 0.5) + self.transform2 = TransformerLayer(hidden_size, num_head, 0.5) + self.transform3 = TransformerLayer(hidden_size, num_head, 0.5) + self.transform4 = TransformerLayer(hidden_size, num_head, 0.5) # final linear self.final_layernorm = torch.nn.LayerNorm( @@ -209,31 +181,37 @@ def forward(self, input_ids, position_ids): encoder_input = self.embed_dropout(embeddings) # [bs, seqlen, hidden size] -> [seqlen, bs, hidden size] - hidden_states = encoder_input.transpose(0, 1).contiguous() + hidden_states = encoder_input.transpose(0, 1) #.contiguous() # transformer # [seqlen, bs, hidden size] -> [seqlen, bs, hidden size] - for layer in self.layers: - hidden_states = layer(hidden_states) + # for layer in self.layers: + # hidden_states = layer(hidden_states) + hidden_states = self.transform1(hidden_states) + hidden_states = self.transform2(hidden_states) + hidden_states = self.transform3(hidden_states) + hidden_states = self.transform4(hidden_states) hidden_states = self.final_layernorm(hidden_states) # post process # [seqlen, bs, hidden size] -> [bs, seqlen, hidden size] - hidden_states = hidden_states.transpose(0, 1).contiguous() + hidden_states = hidden_states.transpose(0, 1) # .contiguous() # [bs, seqlen, hidden size] * [self.vocab_size, hidden size] # => [bs, seqlen, self.vocab_size] logits = F.linear(hidden_states, self.vocab_embed_weight) # loss # for verification, the mask is ommitted - loss = torch.sum(logits) / (self.seqlen * self.bs) + # [bs, seqlen, self.vocab_size] -> [1] + loss = torch.sum(logits) + # loss = loss * self.ntoken return loss def train(): L = 512 # seq len - N = 1 # batch size + N = 4 # batch size # configs: [hidden size, num_head] # E, num_head = [1536, 16] # 1.2B model # E, num_head = [1920, 20] # 2.5B model @@ -275,6 +253,8 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + memory_summary() + print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-40, field_name='e2e'))) diff --git a/examples/gpt/policy/megatron_parallel.py b/examples/gpt/policy/megatron_parallel.py new file mode 100644 index 00000000..e5229a52 --- /dev/null +++ b/examples/gpt/policy/megatron_parallel.py @@ -0,0 +1,130 @@ +from cube.graph import IRGraph +from cube.graph.operator.function import CubeComplexEmbedding, Linear, Sum +from cube.graph.operator.function import CubeComplexFeedForward +from cube.graph.operator.function import CubeComplexSelfAttention +from cube.graph.operator.function import Transpose +from cube.schedule.su import SUType +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRDataOperation, IRFwOperation + + +def transform_policy(graph: IRGraph, resource): + """ + The transformation policy transposes linear using tensor parallel + """ + print('> transforming graph...') + ndevs = resource.ngpus + dp = 2 + tp = ndevs // dp + + # dataloader + + dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] + for dnode in dnodes: + algo = dnode.algorithms('data') + dp_nodes = graph.partition(dnode, algo, config=dict(chunk_num=dp)) + for idx, dp_node in enumerate(dp_nodes): + dp_node.tag = idx * tp + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # preprocess before transformer + for fnode in fnodes[:5]: + sub_nodes = list() + if isinstance(fnode, CubeComplexEmbedding): + algo = fnode.algorithms('data') + dp_nodes = graph.partition(fnode, algo, config=dict(dim=0, chunk_num=dp)) + if dp_nodes[0].inputs(1).shape[0] >= 50000: + for dp_node in dp_nodes: + algo = dp_node.algorithms('shard') + tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) + sub_nodes += tp_nodes + else: + for dp_node in dp_nodes: + tp_nodes = graph.replicate(dp_node, times=tp) + sub_nodes += tp_nodes + else: + algo = fnode.algorithms('dim') + assert algo + dp_nodes = graph.partition(fnode, algo, config=dict(dim=0, chunk_num=dp)) + for dp_node in dp_nodes: + tp_nodes = graph.replicate(dp_node, times=tp) + sub_nodes += tp_nodes + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + # transformers + for fnode in fnodes[5:-3]: + sub_nodes = list() + if isinstance(fnode, CubeComplexSelfAttention): + algo = fnode.algorithms('data') + dp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=dp)) + for dp_node in dp_nodes: + algo = dp_node.algorithms('head') + tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) + sub_nodes += tp_nodes + elif isinstance(fnode, CubeComplexFeedForward): + algo = fnode.algorithms('data') + dp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=dp)) + for dp_node in dp_nodes: + algo = dp_node.algorithms('tensor') + tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) + sub_nodes += tp_nodes + else: + # note replicate should put in the last due to bugs: + algo = fnode.algorithms('dim') + dp_nodes = graph.partition(fnode, algo, config=dict(dim=1, chunk_num=dp)) + for dp_node in dp_nodes: + rep_nodes = graph.replicate(dp_node, times=tp) + sub_nodes += rep_nodes + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + # post-process + for fnode in fnodes[-3:]: + sub_nodes = list() + if isinstance(fnode, Transpose): + algo = fnode.algorithms('dim') + dp_nodes = graph.partition(fnode, algo, config=dict(dim=1, chunk_num=dp)) + for dp_node in dp_nodes: + rep_nodes = graph.replicate(dp_node, times=tp) + sub_nodes += rep_nodes + elif isinstance(fnode, Linear): + algo = fnode.algorithms('data') + dp_nodes = graph.partition(fnode, algo, config=dict(dim=0, chunk_num=dp)) + for dp_node in dp_nodes: + algo = dp_node.algorithms('column') + tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) + sub_nodes += tp_nodes + elif isinstance(fnode, Sum): + algo = fnode.algorithms('dim') + dp_nodes = graph.partition(fnode, algo, config=dict(dim=0, chunk_num=dp)) + for dp_node in dp_nodes: + rep_nodes = graph.replicate(dp_node, times=tp) + sub_nodes += rep_nodes + else: + rep_nodes = graph.replicate(fnode, times=ndevs) + sub_nodes += rep_nodes + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + + print(graph) + # assert False + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy assign devices + """ + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + devid = su.tag[0] + sugraph.assign(su, devid) + for su in sugraph.fsus(): + devid = su.tag[0] + sugraph.assign(su, devid) + sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + sugraph.partial_set_order(fsus, lazy=False) + return sugraph diff --git a/examples/transformer/transformers.py b/examples/transformer/transformers.py index c1331459..305243d7 100644 --- a/examples/transformer/transformers.py +++ b/examples/transformer/transformers.py @@ -117,14 +117,12 @@ def forward(self, hidden_states): ffn_out = self.ffn_dropout(ffn_out) # ffn_out = ffn_out + residual ffn_out = ffn_out * 2 - - loss = torch.sum(ffn_out) - return loss + return ffn_out class Transformers(torch.nn.Module): - def __init__(self, hidden_size, head_num, layer_num): + def __init__(self, hidden_size, head_num): super().__init__() self.transformer1 = TransformerLayer(hidden_size, head_num, 0.5) @@ -152,8 +150,8 @@ def train(): E, num_head = [3072, 32] # 8.7B model - model = TransformerLayer( - hidden_size=E, head_num=num_head, dropout=0.5 + model = Transformers( + hidden_size=E, head_num=num_head ) model = cube.SemanticModel( model, input_shapes=([L, N, E],), From 0a0eac03ccc6266077c8e0af81f2594d973aea5f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 16:24:13 +0800 Subject: [PATCH 0396/1892] pipeline 1f1b for transformers --- cube/compiler.py | 4 +- .../transformer/policy/pipeline_parallel.py | 109 ++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 examples/transformer/policy/pipeline_parallel.py diff --git a/cube/compiler.py b/cube/compiler.py index e575f925..5d621953 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -153,12 +153,12 @@ def decorator(fn: Callable) -> Callable: execplan = RemoveRedundantAdapters.apply(execplan) # print(f'> after remove redundant adapters:\n {execplan}') execplan = MergeComputeSU.apply(execplan) - print(f'> after merge backward SU:\n {execplan}') + # print(f'> after merge backward SU:\n {execplan}') execplan = WeightGradAllreduceFusion.apply(execplan) # print(f'> after add allreduce:\n{execplan}') execplan = P2PFusion.apply(execplan) - print(f'> after fuse P2P SU:\n {execplan}') + # print(f'> after fuse P2P SU:\n {execplan}') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() diff --git a/examples/transformer/policy/pipeline_parallel.py b/examples/transformer/policy/pipeline_parallel.py new file mode 100644 index 00000000..366fb23b --- /dev/null +++ b/examples/transformer/policy/pipeline_parallel.py @@ -0,0 +1,109 @@ +from typing import List + +from cube.schedule.su import SUType, ScheduleUnit +from cube.schedule.sugraph import SUGraph +from cube.graph.operator.operator import IRDataOperation, IRFwOperation + +_batch_size = 8 +_micro_batch_size = 1 + +def transform_policy(graph, resource): + """ + The transformation policy transposes linear using data parallel + """ + print('> transforming graph...') + micro_batch_num = _batch_size // _micro_batch_size + + for node in graph.nodes(): + if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): + algo = node.algorithms('data') + if algo is not None: + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=micro_batch_num)) + else: + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, config=dict(dim=1, chunk_num=micro_batch_num)) + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx + print('> [Done] transforming graph...') + return graph + + +def schedule_policy(sugraph: SUGraph, resource): + """ + The schedule policy + """ + print('> scheduling su graph...') + num_micro_batch = _batch_size // _micro_batch_size + # each device is a stage + num_stage = resource.ngpus + + fseqs: List[List[ScheduleUnit]] = [list() for _ in range(num_micro_batch)] + fbseqs: List[List[ScheduleUnit]] = [list() for _ in range(num_micro_batch)] + + for fsu in sugraph.fsus(): + micro_bs_id = fsu.tag[0] + fseqs[micro_bs_id].append(fsu) + + for micro_bs_id, fseq in enumerate(fbseqs): + bseq = [fsu.mirror for fsu in fseq][::-1] + fbseqs[micro_bs_id] = fseq + bseq + + print(f'> collect {len(fseqs)} forward-backward sequence') + + # fstages[micro_batch_id][stage] = fstages[micro_batch_id * num_stage + stage] + fstages: List[List[ScheduleUnit]] = [ + list() for _ in range(num_micro_batch * num_stage) + ] + + def f(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: + return fstages[micro_batch_id * num_stage + stage_id] + + def b(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: + fstage = f(micro_batch_id, stage_id) + bstage = [fsu.mirror for fsu in fstage][::-1] + return bstage + + # assign su to stages + for micro_bid, fseq in enumerate(fseqs): + chunk_num = int(len(fseq) // resource.ngpus) + for idx, fsu in enumerate(fseq): + stage = min(int(idx // chunk_num), num_stage - 1) + fstages[micro_bid * num_stage + stage].append(fsu) + + # stage device assignment + for micro_bid in range(num_micro_batch): + for stage in range(num_stage): + for su in f(micro_bid, stage): + sugraph.assign(su, stage) + sugraph.assign(su.mirror, stage) + + # device assignment + for su in sugraph.sus(): + if su.stype == SUType.Dataloader: + sugraph.assign(su, 0) + + # 1f1b scheduling + seqs = list() + + # warmup + for stage in range(num_stage): + for mid in range(stage): + seqs += f(mid, stage) + + # steady + cooldown: + for mid in range(num_micro_batch): + # enqueue backward + for stage in range(num_stage-1, -1, -1): + seqs += b(mid, stage) + # enqueue forward + for stage in range(num_stage): + f_mid = mid + 1 + num_stage - stage + if f_mid >= num_micro_batch: + continue + seqs += f(f_mid, stage) + + sugraph.partial_set_order(seqs) + + print('> [Done] scheduling su graph') + # print(sugraph) + return sugraph From 914cbd887a39de59379f8cc8c2e84db64aa2369c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 18:38:00 +0800 Subject: [PATCH 0397/1892] benchmark on megatron --- benchmark/megatron/gpt.py | 154 ++++++++++++++++++++++++++++++ benchmark/megatron/layers.py | 32 +++++++ benchmark/megatron/transformer.py | 14 +-- 3 files changed, 194 insertions(+), 6 deletions(-) create mode 100644 benchmark/megatron/gpt.py diff --git a/benchmark/megatron/gpt.py b/benchmark/megatron/gpt.py new file mode 100644 index 00000000..861e074a --- /dev/null +++ b/benchmark/megatron/gpt.py @@ -0,0 +1,154 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + benchmark/megatron/gpt.py +""" + +import torch +import torch.nn.functional as F +import cube +from benchmark.megatron.layers import ColumnOutputAdapter, ShardEmbedding +from benchmark.megatron.transformer import TransformerLayer + + +from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary +from cube.profiler.timer import print_each_rank + + +class GPT(torch.nn.Module): + + def __init__(self, hidden_size, vocab_size, seqlen_size, + bs, seqlen, num_head, num_layers: int): + super().__init__() + + self.num_layers = num_layers + self.bs = bs + self.seqlen = seqlen + self.ntoken = 1.0 / self.bs * self.seqlen + + # embeddings + + self.vocab_size = vocab_size + self.vocab_embedding = ShardEmbedding(self.vocab_size, hidden_size) + self.seqlen_size = seqlen_size + self.pos_embed_weight = torch.nn.Parameter( + torch.empty(seqlen_size, hidden_size) + ) + + self.embed_dropout = torch.nn.Dropout(0.5) + + self.transform1 = TransformerLayer(seqlen, hidden_size, num_head, 0.5) + self.transform2 = TransformerLayer(seqlen, hidden_size, num_head, 0.5) + self.transform3 = TransformerLayer(seqlen, hidden_size, num_head, 0.5) + self.transform4 = TransformerLayer(seqlen, hidden_size, num_head, 0.5) + + # final linear + self.final_layernorm = torch.nn.LayerNorm( + hidden_size, 1e-5 + ) + + def forward(self, input_ids, position_ids): + """ + input_ids: + [bs, seqlen] + position_ids: + [bs, seqlen] + """ + + # preprocess: embedding + # [bs, seqlen] -> [bs, seqlen, hidden size] + words_embeddings = self.vocab_embedding(input_ids) + + # [bs, seqlen] -> [bs, seqlen, hidden size] + position_embeddings = cube.runtime.function.embedding( + position_ids, self.pos_embed_weight, 0, self.seqlen_size + ) + embeddings = words_embeddings + position_embeddings + encoder_input = self.embed_dropout(embeddings) + + # [bs, seqlen, hidden size] -> [seqlen, bs, hidden size] + hidden_states = encoder_input.transpose(0, 1) + + hidden_states = self.transform1(hidden_states) + hidden_states = self.transform2(hidden_states) + hidden_states = self.transform3(hidden_states) + hidden_states = self.transform4(hidden_states) + + hidden_states = self.final_layernorm(hidden_states) + + # post process + hidden_states = hidden_states.transpose(0, 1) # .contiguous() + logits = F.linear(hidden_states, self.vocab_embedding.weight) + # all gather + logits = ColumnOutputAdapter.apply(logits) + + # loss # for verification, the mask is ommitted + # [bs, seqlen, self.vocab_size] -> [1] + loss = torch.sum(logits) + # loss = loss * self.ntoken + return loss + +def train(): + L = 512 # seq len + N = 8 # batch size + # configs: [hidden size, num_head] + # E, num_head = [2304, 24, 24] # 1.7B model + E, num_head, layers = [3072, 32, 30] # 3.6B model + # E, num_head, layers = [4096, 32, 36] # 7.5B model + + print_each_rank('config: L={}, N={}, E={}, num-head={}'.format( + L, N, E, num_head + )) + + + model = GPT( + hidden_size=E, vocab_size=50304, seqlen_size=L, + bs=N, seqlen=L, num_head=num_head, num_layers=layers + ).cuda() + + dataloader = cube.runtime.syndata.SynTextDataLoader(1280, [0, 0], [N, L], [N, L]) + + def train_iter(model, dataloader): + input_ids, position_ids = next(dataloader) + torch.distributed.broadcast(input_ids, 0) + torch.distributed.broadcast(position_ids, 0) + torch.cuda.synchronize() + loss = model(input_ids, position_ids) + loss.backward() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + + +if __name__ == '__main__': + + cube.init() + train() \ No newline at end of file diff --git a/benchmark/megatron/layers.py b/benchmark/megatron/layers.py index 47ee055e..3ec1efd2 100644 --- a/benchmark/megatron/layers.py +++ b/benchmark/megatron/layers.py @@ -156,3 +156,35 @@ def forward(self, input_): output = output_parallel return output + +class ShardEmbedding(torch.nn.Module): + + def __init__(self, num_embeddings, embedding_dim): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + self.shard_num = torch.distributed.get_world_size() + self.myshard = torch.distributed.get_rank() + + shard_num_embeddings = self.num_embeddings // self.shard_num + self.vocab_start_index = shard_num_embeddings * self.myshard + self.vocab_end_index = self.vocab_start_index + shard_num_embeddings + + self.weight = torch.nn.Parameter( + torch.empty(shard_num_embeddings, self.embedding_dim) + ) + + def forward(self, input_): + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | \ + (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + output_parallel = F.embedding( + masked_input, self.weight, + None, None, 2., False, False + ) + output = RowOutputAdapter.apply(output_parallel) + return output diff --git a/benchmark/megatron/transformer.py b/benchmark/megatron/transformer.py index 89d391b2..e0dcbdc6 100644 --- a/benchmark/megatron/transformer.py +++ b/benchmark/megatron/transformer.py @@ -153,9 +153,7 @@ def forward(self, hidden_states): ffn_out = self.ffn_dropout(ffn_out) # ffn_out = ffn_out + residual ffn_out = ffn_out * 2 - - loss = torch.sum(ffn_out) - return loss + return ffn_out def train(): @@ -178,7 +176,8 @@ def train_iter(model, dataloader): data = next(dataloader) torch.distributed.broadcast(data, 0) torch.cuda.synchronize() - loss = model(data) + out = model(data) + loss = torch.sum(out) loss.backward() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) @@ -197,8 +196,11 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) if __name__ == '__main__': From 4b8b214cb2dc8396c222c4dbbcece9d86abdaa6d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 18:39:30 +0800 Subject: [PATCH 0398/1892] clean debug info --- cube/compiler.py | 8 ++++++++ cube/execplan/execplan.py | 6 ++++-- cube/execplan/planpass/p2pfusion.py | 8 ++++---- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 5d621953..b97ffb8d 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,5 +1,6 @@ from typing import Callable, Optional, Tuple import torch +import time import cube from cube.graph.graph import IRGraph @@ -118,6 +119,9 @@ def decorator(fn: Callable) -> Callable: filename = 'gencode{}.py' batch_size = torch.tensor([-1], dtype=torch.int).cuda() if myrank == 0: + + compile_start = time.time() + SchedulePool().clear() resource = cube.runtime.resource.EnvResource() @@ -193,6 +197,10 @@ def decorator(fn: Callable) -> Callable: # assume batch_size is always first dimension batch_size = torch.tensor([batch_size], dtype=torch.int).cuda() + compile_end = time.time() + compile_time = compile_end - compile_start + print(f'> compile time: {compile_time} seconds') + if torch.distributed.is_initialized(): torch.distributed.barrier() diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 455a30ab..14fc010b 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -69,7 +69,7 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): outfile: the output file name """ - ndevice = max(self.device_seq.keys()) + 1 + ndevice = len(self.devices()) # timeline [ [ (start_time, end_time), ... ], ... ] device_timeline = [list() for _ in range(ndevice)] device_sus = [list() for _ in range(ndevice)] @@ -82,8 +82,10 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): span = 1 elif su.stype == SUType.Backward: span = 2 - elif su.stype == SUType.P2P: + elif su.stype in [SUType.P2P, SUType.Transform]: span = 0.1 + else: + span = 0 spans.append(span) for su, span_time in zip(self.seq.sequence, spans): diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index e88070d1..c20d07b5 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -114,7 +114,7 @@ def have_comm(tensor_ous, tensor_ins): @staticmethod def add_collectives(execplan: ExectuionPlan, coll_sus: List[ScheduleUnit]): for coll_su in coll_sus: - print(f'inserting Collective SU: {coll_su.name}: {coll_su}') + # print(f'inserting Collective SU: {coll_su.name}: {coll_su}') # find insert place: the first send devid = coll_su.device[0] ranks = coll_su.nodes(0).ranks @@ -483,9 +483,9 @@ def match_broadcast(tous, tins): su.device = rank broadcast_sus.append(su) - print('>> find broadcast pattern:') - print(f'device group: {ranks}') - print(su) + # print('>> find broadcast pattern:') + # print(f'device group: {ranks}') + # print(su) if len(broadcast_sus) == 0: return None From 612f59eb143066aa1c77fc02b81361e02a148ecb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 18:40:21 +0800 Subject: [PATCH 0399/1892] gpt benchmark --- examples/gpt/gpt.py | 18 +++++++++--------- examples/gpt/policy/megatron_parallel.py | 5 +++-- examples/inspector.py | 18 ++++++++++++++---- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py index a3e2783a..e9952796 100644 --- a/examples/gpt/gpt.py +++ b/examples/gpt/gpt.py @@ -211,13 +211,11 @@ def forward(self, input_ids, position_ids): def train(): L = 512 # seq len - N = 4 # batch size + N = 8 # batch size # configs: [hidden size, num_head] - # E, num_head = [1536, 16] # 1.2B model - # E, num_head = [1920, 20] # 2.5B model - # E, num_head = [2304, 24] # 4.2B model - E, num_head = [3072, 32] # 8.7B model - layers = 4 + # E, num_head = [2304, 24, 24] # 1.7B model + E, num_head, layers = [3072, 32, 30] # 3.6B model + # E, num_head, layers = [4096, 32, 36] # 7.5B model model = GPT( @@ -253,10 +251,12 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) memory_summary() - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) if __name__ == '__main__': diff --git a/examples/gpt/policy/megatron_parallel.py b/examples/gpt/policy/megatron_parallel.py index e5229a52..78c8534e 100644 --- a/examples/gpt/policy/megatron_parallel.py +++ b/examples/gpt/policy/megatron_parallel.py @@ -14,7 +14,7 @@ def transform_policy(graph: IRGraph, resource): """ print('> transforming graph...') ndevs = resource.ngpus - dp = 2 + dp = 1 tp = ndevs // dp # dataloader @@ -108,7 +108,7 @@ def transform_policy(graph: IRGraph, resource): for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx - print(graph) + # print(graph) # assert False return graph @@ -117,6 +117,7 @@ def schedule_policy(sugraph: SUGraph, resource): """ The schedule policy assign devices """ + print('> scheduling SU...') for su in sugraph.sus(): if su.stype == SUType.Dataloader: devid = su.tag[0] diff --git a/examples/inspector.py b/examples/inspector.py index a3bc7700..55a979de 100644 --- a/examples/inspector.py +++ b/examples/inspector.py @@ -16,11 +16,16 @@ import cube from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary from cube.profiler.timer import print_each_rank - -kDataShapes = ([512, 32, 3072],) -kBatchDims = [1] +L, N, E = (512, 4, 3072) +# gpt +kBatchDims = [0, 0] +kDataShapes = ([N // 2, L], [N // 2, L]) +# transformer +# kBatchDims = [1] +# kDataShapes = ([512, 4, 3072],) def load_module(filename: str): @@ -51,7 +56,10 @@ def load_train_fn(filename: str): def train(args): global kDataShapes - dataloader = cube.runtime.syndata.SynDataLoader( + # dataloader = cube.runtime.syndata.SynDataLoader( + # 1280, kBatchDims, *kDataShapes + # ) + dataloader = cube.runtime.syndata.SynTextDataLoader( 1280, kBatchDims, *kDataShapes ) @@ -82,6 +90,8 @@ def train(args): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-40, field_name='e2e'))) + memory_summary() + if __name__ == '__main__': From 2ad04c472a8eebc117e180d719d659240acf4db6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 18:46:29 +0800 Subject: [PATCH 0400/1892] benchmark tools --- eval/benchmark_gpt.sh | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 eval/benchmark_gpt.sh diff --git a/eval/benchmark_gpt.sh b/eval/benchmark_gpt.sh new file mode 100644 index 00000000..5a2ef2e8 --- /dev/null +++ b/eval/benchmark_gpt.sh @@ -0,0 +1,29 @@ + +echo benchmarking gpt megatron hybrid parallelism... + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + benchmark/megatron/gpt.py > mydata/MagicCube/expdata/8B.2V100.Megatron.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + benchmark/megatron/gpt.py > mydata/MagicCube/expdata/8B.4V100.Megatron.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + benchmark/megatron/gpt.py > mydata/MagicCube/expdata/8B.8V100.Megatron.txt From c456572704cc65be62199016b3d41d11ec0415a2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 18:48:31 +0800 Subject: [PATCH 0401/1892] make runnable --- eval/benchmark_gpt.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 eval/benchmark_gpt.sh diff --git a/eval/benchmark_gpt.sh b/eval/benchmark_gpt.sh old mode 100644 new mode 100755 From bd2dd065e7ef38388d0438053eb1349e295fc15d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 18:51:02 +0800 Subject: [PATCH 0402/1892] benchmark memory tool --- cube/profiler/memory.py | 11 +++++++++++ eval/benchmark_gpt.sh | 6 +++--- 2 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 cube/profiler/memory.py diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py new file mode 100644 index 00000000..919b8fe3 --- /dev/null +++ b/cube/profiler/memory.py @@ -0,0 +1,11 @@ +import torch +from cube.profiler.timer import print_each_rank + +def memory_summary(): + rank = torch.distributed.get_rank() + # memory measurement + mem = torch.cuda.max_memory_allocated() + # mem = torch.cuda.max_memory_reserved() + print( + '{:.2f}GB memory consumption'.format(mem / 1024 / 1024 / 1024), + ) diff --git a/eval/benchmark_gpt.sh b/eval/benchmark_gpt.sh index 5a2ef2e8..aca5e7e3 100755 --- a/eval/benchmark_gpt.sh +++ b/eval/benchmark_gpt.sh @@ -8,7 +8,7 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - benchmark/megatron/gpt.py > mydata/MagicCube/expdata/8B.2V100.Megatron.txt + benchmark/megatron/gpt.py > /mydata/MagicCube/expdata/8B.2V100.Megatron.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -17,7 +17,7 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - benchmark/megatron/gpt.py > mydata/MagicCube/expdata/8B.4V100.Megatron.txt + benchmark/megatron/gpt.py > /mydata/MagicCube/expdata/8B.4V100.Megatron.txt python -m torch.distributed.launch \ --nproc_per_node=8 \ @@ -26,4 +26,4 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - benchmark/megatron/gpt.py > mydata/MagicCube/expdata/8B.8V100.Megatron.txt + benchmark/megatron/gpt.py > /mydata/MagicCube/expdata/8B.8V100.Megatron.txt From 63b61fa9f3d6b962736340de9f3c333ebfc15cfb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 20:09:43 +0800 Subject: [PATCH 0403/1892] residual is correct --- benchmark/megatron/transformer.py | 6 ++--- cube/runtime/collectives.py | 1 + cube/runtime/executor.py | 2 +- examples/gpt/gpt.py | 12 ++++----- examples/gpt/policy/megatron_parallel.py | 14 +++++----- .../transformer/policy/megatron_parallel.py | 4 +-- examples/transformer/transformer.py | 6 ++--- examples/transformer/transformers.py | 26 ++++++++++--------- 8 files changed, 34 insertions(+), 37 deletions(-) diff --git a/benchmark/megatron/transformer.py b/benchmark/megatron/transformer.py index e0dcbdc6..9c6f5ac9 100644 --- a/benchmark/megatron/transformer.py +++ b/benchmark/megatron/transformer.py @@ -144,15 +144,13 @@ def forward(self, hidden_states): attn_out = self.attention(in_attn_norm) # residual attn_out = self.attn_dropout(attn_out) - # residual = attn_out + hidden_states - residual = attn_out * 2 + residual = attn_out + hidden_states # ffn in_ffn_norm = self.ffn_layernorm(residual) ffn_out = self.ffn(in_ffn_norm) # residual ffn_out = self.ffn_dropout(ffn_out) - # ffn_out = ffn_out + residual - ffn_out = ffn_out * 2 + ffn_out = ffn_out + residual return ffn_out diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 4ab66dfd..7cabe3c8 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -158,6 +158,7 @@ def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None): Broadcast. ranks[0] is the root """ # print(f'{torch.distributed.get_rank()}: broadcast...') + # FIXME: data type if len(tensors) == 1: tensor = tensors[0] else: diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index abdecdc7..3f1ad947 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -37,7 +37,7 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr grads = [None] * len(input_tensors) if len(inputs) != 0: - # print('backwarding... ') + # print(f'{torch.distributed.get_rank()}: backwarding... ') in_grads = torch.autograd.grad( output_tensors, inputs, output_tensor_grads, allow_unused=True) for idx, grad in zip(indices, in_grads): diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py index e9952796..e35d06d9 100644 --- a/examples/gpt/gpt.py +++ b/examples/gpt/gpt.py @@ -2,13 +2,13 @@ example: python -m torch.distributed.launch \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/transformer/transformer.py + examples/gpt/gpt.py """ import torch @@ -111,15 +111,13 @@ def forward(self, hidden_states): attn_out = self.attention(in_attn_norm) # residual attn_out = self.attn_dropout(attn_out) - # residual = attn_out + hidden_states - residual = attn_out * 2 + residual = attn_out + hidden_states # ffn in_ffn_norm = self.ffn_layernorm(residual) ffn_out = self.ffn(in_ffn_norm) # residual ffn_out = self.ffn_dropout(ffn_out) - # ffn_out = ffn_out + residual - ffn_out = ffn_out * 2 + ffn_out = ffn_out + residual return ffn_out @@ -246,6 +244,8 @@ def train_iter(model, dataloader): train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() + if step == 1: + print('> passed on iteration') if step >= 40: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: diff --git a/examples/gpt/policy/megatron_parallel.py b/examples/gpt/policy/megatron_parallel.py index 78c8534e..c41d97fd 100644 --- a/examples/gpt/policy/megatron_parallel.py +++ b/examples/gpt/policy/megatron_parallel.py @@ -21,10 +21,14 @@ def transform_policy(graph: IRGraph, resource): dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] for dnode in dnodes: + sub_nodes = list() algo = dnode.algorithms('data') dp_nodes = graph.partition(dnode, algo, config=dict(chunk_num=dp)) - for idx, dp_node in enumerate(dp_nodes): - dp_node.tag = idx * tp + for dp_node in dp_nodes: + tp_nodes = graph.replicate(dp_node, times=tp) + sub_nodes += tp_nodes + for idx, sub_node in enumerate(sub_nodes): + sub_node.tag = idx fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] @@ -96,12 +100,6 @@ def transform_policy(graph: IRGraph, resource): algo = dp_node.algorithms('column') tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) sub_nodes += tp_nodes - elif isinstance(fnode, Sum): - algo = fnode.algorithms('dim') - dp_nodes = graph.partition(fnode, algo, config=dict(dim=0, chunk_num=dp)) - for dp_node in dp_nodes: - rep_nodes = graph.replicate(dp_node, times=tp) - sub_nodes += rep_nodes else: rep_nodes = graph.replicate(fnode, times=ndevs) sub_nodes += rep_nodes diff --git a/examples/transformer/policy/megatron_parallel.py b/examples/transformer/policy/megatron_parallel.py index bab64c54..15a00cb6 100644 --- a/examples/transformer/policy/megatron_parallel.py +++ b/examples/transformer/policy/megatron_parallel.py @@ -11,7 +11,7 @@ def transform_policy(graph: IRGraph, resource): """ print('> transforming graph...') ndevs = resource.ngpus - dp = 2 + dp = 1 tp = ndevs // dp # dataloader @@ -51,7 +51,7 @@ def transform_policy(graph: IRGraph, resource): for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx - print(graph) + # print(graph) # assert False return graph diff --git a/examples/transformer/transformer.py b/examples/transformer/transformer.py index c28b6407..f54c6d6e 100644 --- a/examples/transformer/transformer.py +++ b/examples/transformer/transformer.py @@ -128,15 +128,13 @@ def forward(self, hidden_states): attn_out = self.attention(in_attn_norm) # residual attn_out = self.attn_dropout(attn_out) - # residual = attn_out + hidden_states - residual = attn_out * 2 + residual = attn_out + hidden_states # ffn in_ffn_norm = self.ffn_layernorm(residual) ffn_out = self.ffn(in_ffn_norm) # residual ffn_out = self.ffn_dropout(ffn_out) - # ffn_out = ffn_out + residual - ffn_out = ffn_out * 2 + ffn_out = ffn_out + residual loss = torch.sum(ffn_out) return loss diff --git a/examples/transformer/transformers.py b/examples/transformer/transformers.py index 305243d7..7eddf913 100644 --- a/examples/transformer/transformers.py +++ b/examples/transformer/transformers.py @@ -20,6 +20,7 @@ from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary class MultiHeadSelfAttention(nn.Module): @@ -108,15 +109,13 @@ def forward(self, hidden_states): attn_out = self.attention(in_attn_norm) # residual attn_out = self.attn_dropout(attn_out) - # residual = attn_out + hidden_states - residual = attn_out * 2 + residual = attn_out + hidden_states # ffn in_ffn_norm = self.ffn_layernorm(residual) ffn_out = self.ffn(in_ffn_norm) # residual ffn_out = self.ffn_dropout(ffn_out) - # ffn_out = ffn_out + residual - ffn_out = ffn_out * 2 + ffn_out = ffn_out + residual return ffn_out @@ -142,12 +141,11 @@ def forward(self, hidden_states): def train(): L = 512 # seq len - N = 16 # batch size + N = 8 # batch size # configs: [hidden size, num_head] - # E, num_head = [1536, 16] # 1.2B model - # E, num_head = [1920, 20] # 2.5B model - # E, num_head = [2304, 24] # 4.2B model - E, num_head = [3072, 32] # 8.7B model + # E, num_head = [2304, 24, 24] # 1.7B model + E, num_head, layers = [3072, 32, 30] # 3.6B model + # E, num_head, layers = [4096, 32, 36] # 7.5B model model = Transformers( @@ -181,9 +179,13 @@ def train_iter(model, dataloader): CudaTimer().stop('e2e') if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) + + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() if __name__ == '__main__': From df329699f1cfb48f28be9dce5412a4f6fe26d324 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 20:42:51 +0800 Subject: [PATCH 0404/1892] benchmark tool --- eval/benchmark_gpt.sh | 27 +++++++++++++++++++++++++++ scripts/env-setup.sh | 4 ++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/eval/benchmark_gpt.sh b/eval/benchmark_gpt.sh index aca5e7e3..8d528214 100755 --- a/eval/benchmark_gpt.sh +++ b/eval/benchmark_gpt.sh @@ -27,3 +27,30 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ benchmark/megatron/gpt.py > /mydata/MagicCube/expdata/8B.8V100.Megatron.txt + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/gpt/gpt.py > /mydata/MagicCube/expdata/8B.2V100.Cube.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/gpt/gpt.py > /mydata/MagicCube/expdata/8B.4V100.Cube.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/gpt/gpt.py > /mydata/MagicCube/expdata/8B.8V100.Cube.txt diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index 80843eba..44ce1bdb 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -1,5 +1,5 @@ -echo using docker image pytorch-cuda11.3: nvcr.io/nvidia/pytorch:21.06-py3 +echo using docker image pytorch-cuda11.3: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime git config --global core.editor "vim" git config --global user.name "Zhiqi Lin" @@ -38,7 +38,7 @@ echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc # cmd for count code lines # find cube/ -name "*.py" -print0 | xargs -0 wc -l - +pip uninstall training_daemon python setup.py develop pip install -r requirements.txt From d96fb759797d88a10d1efebca41eb4ecfe1fc1ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 24 Nov 2021 23:07:51 +0800 Subject: [PATCH 0405/1892] disable redundant check --- cube/schedule/sugraph.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py index e7d23a71..9843e03a 100644 --- a/cube/schedule/sugraph.py +++ b/cube/schedule/sugraph.py @@ -78,20 +78,6 @@ def reset_dependency(sus: List[ScheduleUnit]): src.add_successor(out_idx, dst) dst.add_predecessor(in_idx, src) - @staticmethod - def gen_comm_adapter(sus: List[ScheduleUnit]): - """ - Generate communication adapter for each SU - """ - pass - - @staticmethod - def gen_trans_adapter(sus: List[ScheduleUnit]): - """ - Generate transformation adapter for each SU - """ - pass - def __len__(self): return len(self.sequence) @@ -426,7 +412,7 @@ def partial_set_order(self, seq: List[ScheduleUnit], lazy=False): continue if rsu.mirror in seq: index = seq.index(rsu.mirror) - seq.insert(idx+1, rsu) + seq.insert(index+1, rsu) continue if rsu in seq: raise RuntimeError(f"Internal Error: should not appear SU: {rsu}") @@ -442,8 +428,8 @@ def partial_set_order(self, seq: List[ScheduleUnit], lazy=False): idx += 1 seq.insert(idx, rsu) - if not SUGraph.is_topo_order(seq, integrity_check=True): - raise RuntimeError("Internal Error: topo is not guaranteed.") + # if not SUGraph.is_topo_order(seq, integrity_check=True): + # raise RuntimeError("Internal Error: topo is not guaranteed.") self.sequence = seq return True From 48c6acfdd69e52ac8355adacdfe1cf1de2cb1aeb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 10:24:03 +0800 Subject: [PATCH 0406/1892] change to adam optimizer --- cube/schedule/su.py | 20 +++++++++++++++ examples/gpt/gpt.py | 2 +- examples/gpt/policy/megatron_parallel.py | 9 +++++++ examples/inspector.py | 31 ++++++++++++++++-------- examples/mlp/policy/pipe1f1b_parallel.py | 3 ++- examples/transformer/transformer.py | 2 +- examples/transformer/transformers.py | 2 +- 7 files changed, 55 insertions(+), 14 deletions(-) diff --git a/cube/schedule/su.py b/cube/schedule/su.py index ab80f6e9..2aba8c18 100644 --- a/cube/schedule/su.py +++ b/cube/schedule/su.py @@ -422,6 +422,26 @@ def successors(self, index: Optional[int] = None) -> List: else: raise TypeError("Expected index to be None or int") + def is_identity(self): + """ + Check if the SU is identity + """ + # not assigned + if len(self.device) == 0: + return False + if self.stype == SUType.P2P: + send_ranks = set(self.device) + recv_ranks = set(self.device) + for node in self.nodes(): + send_ranks.update(node.send_ranks) + recv_ranks.update(node.recv_ranks) + if list(send_ranks) != self.device: + return False + if list(recv_ranks) != self.device: + return False + return True + return False + @staticmethod def get_inputs(nodes): """ diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py index e35d06d9..72ae4d1c 100644 --- a/examples/gpt/gpt.py +++ b/examples/gpt/gpt.py @@ -233,7 +233,7 @@ def train_iter(model, dataloader): loss.backward() model = model.get_gen_module() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) CudaTimer().warmup() torch.distributed.barrier() diff --git a/examples/gpt/policy/megatron_parallel.py b/examples/gpt/policy/megatron_parallel.py index c41d97fd..f43158a4 100644 --- a/examples/gpt/policy/megatron_parallel.py +++ b/examples/gpt/policy/megatron_parallel.py @@ -1,3 +1,5 @@ +import time + from cube.graph import IRGraph from cube.graph.operator.function import CubeComplexEmbedding, Linear, Sum from cube.graph.operator.function import CubeComplexFeedForward @@ -116,14 +118,21 @@ def schedule_policy(sugraph: SUGraph, resource): The schedule policy assign devices """ print('> scheduling SU...') + start_time = time.time() + for su in sugraph.sus(): if su.stype == SUType.Dataloader: devid = su.tag[0] sugraph.assign(su, devid) + print('> [scheduling] assign device...') for su in sugraph.fsus(): devid = su.tag[0] sugraph.assign(su, devid) sugraph.assign(su.mirror, devid) fsus = sugraph.fsus() + print('> [scheduling] setting schedule order...') sugraph.partial_set_order(fsus, lazy=False) + + span = time.time() - start_time + print('> Done scheduling: {:.2f} seconds'.format(span)) return sugraph diff --git a/examples/inspector.py b/examples/inspector.py index 55a979de..048487d8 100644 --- a/examples/inspector.py +++ b/examples/inspector.py @@ -19,10 +19,10 @@ from cube.profiler.memory import memory_summary from cube.profiler.timer import print_each_rank -L, N, E = (512, 4, 3072) +L, N, E = (512, 8, 3072) # gpt kBatchDims = [0, 0] -kDataShapes = ([N // 2, L], [N // 2, L]) +kDataShapes = ([N, L], [N, L]) # transformer # kBatchDims = [1] # kDataShapes = ([512, 4, 3072],) @@ -67,29 +67,39 @@ def train(args): model = load_module(genfile) train_fn = load_train_fn(genfile) - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) CudaTimer().warmup() torch.distributed.barrier() - with torch.profiler.profile() as prof: - iter_num = args.iter_num + iter_num = args.iter_num + + def train_iters(): for step in range(iter_num): if step >= 40: CudaTimer().start('e2e') train_fn(model, dataloader) optimizer.step() optimizer.zero_grad() + if step == 1: + print('test passed') if step >= 40: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) time.sleep(0.05) - prof.export_chrome_trace(f"trace{torch.distributed.get_rank()}.json") - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - + if args.profile: + with torch.profiler.profile() as prof: + train_iters() + prof.export_chrome_trace(f"trace{torch.distributed.get_rank()}.json") + else: + train_iters() + + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) memory_summary() @@ -100,6 +110,7 @@ def train(args): default='gencode{rank}.py') parser.add_argument('--iter-num', type=int, default=128) + parser.add_argument('--profile', dest='profile', action='store_true') args = parser.parse_args() cube.init() diff --git a/examples/mlp/policy/pipe1f1b_parallel.py b/examples/mlp/policy/pipe1f1b_parallel.py index b4ff458b..96b8d42a 100644 --- a/examples/mlp/policy/pipe1f1b_parallel.py +++ b/examples/mlp/policy/pipe1f1b_parallel.py @@ -18,6 +18,7 @@ def transform_policy(graph, resource): sub_nodes = graph.partition(node, algo, config=dict(chunk_num=micro_batch_num)) else: algo = node.algorithms('dim') + # dim trace sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=micro_batch_num)) for idx, sub_node in enumerate(sub_nodes): sub_node.tag = idx @@ -58,7 +59,7 @@ def b(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: bstage = [fsu.mirror for fsu in fstage][::-1] return bstage - # assign su to stages + # assign su to SU Group for micro_bid, fseq in enumerate(fseqs): chunk_num = int(len(fseq) // resource.ngpus) for idx, fsu in enumerate(fseq): diff --git a/examples/transformer/transformer.py b/examples/transformer/transformer.py index f54c6d6e..78759d9a 100644 --- a/examples/transformer/transformer.py +++ b/examples/transformer/transformer.py @@ -166,7 +166,7 @@ def train_iter(model, dataloader): loss.backward() model = model.get_gen_module() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) CudaTimer().warmup() torch.distributed.barrier() diff --git a/examples/transformer/transformers.py b/examples/transformer/transformers.py index 7eddf913..65f04721 100644 --- a/examples/transformer/transformers.py +++ b/examples/transformer/transformers.py @@ -164,7 +164,7 @@ def train_iter(model, dataloader): loss.backward() model = model.get_gen_module() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) CudaTimer().warmup() torch.distributed.barrier() From 46b0eb1af360014a78dd548c2538815e7e5bfade Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 13:56:30 +0800 Subject: [PATCH 0407/1892] init swin transformer example --- examples/swin/swin_transformer.py | 673 ++++++++++++++++++++++++++++++ 1 file changed, 673 insertions(+) create mode 100644 examples/swin/swin_transformer.py diff --git a/examples/swin/swin_transformer.py b/examples/swin/swin_transformer.py new file mode 100644 index 00000000..0289f6d2 --- /dev/null +++ b/examples/swin/swin_transformer.py @@ -0,0 +1,673 @@ + +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu + +# Copied and modified from +# -------------------------------------------------------- + +from typing import Optional +import torch +import torch.nn as nn + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + drop_path(x, self.drop_path_p) + ffn = self.norm2(x) + ffn = self.mlp(ffn) + x = x + drop_path(ffn, self.drop_path_p) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + # if self.ape: + # x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x): + # forward features + # x = self.forward_features(x) + x = self.patch_embed(x) + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def train(): + + # image batch input + N, C, H, W = [1, 3, 224, 224] + + embed_dim, depths, num_heads, window_size = [ + 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 + ] + + # embed_dim, depths, num_heads, window_size = [ + # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + # ] + + # 1.02B Model + # embed_dim, depths, num_heads, window_size = [ + # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # ] + + + model = SwinTransformer(embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size) + + module = torch.jit.script(model) + print(module.graph) + + model = model.cuda() + + dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [N, C, H, W]) + + def train_iter(model, dataloader): + img = next(dataloader) + loss = model(img) + loss = torch.sum(loss) + loss.backward() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on iteration') + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + + +if __name__ == '__main__': + + cube.init() + train() From 181df458d74449e5a20c015d5e79462059bc3751 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 14:33:38 +0800 Subject: [PATCH 0408/1892] remove useless branch --- examples/swin/swin_transformer.py | 38 +++++++++++++++++++------------ 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/examples/swin/swin_transformer.py b/examples/swin/swin_transformer.py index 0289f6d2..b9e39787 100644 --- a/examples/swin/swin_transformer.py +++ b/examples/swin/swin_transformer.py @@ -82,6 +82,20 @@ def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): return x +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. @@ -108,19 +122,6 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) @@ -142,8 +143,13 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): q = q * self.scale attn = (q @ k.transpose(-2, -1)) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + relative_position_bias = self.relative_position_bias_table[relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) @@ -628,8 +634,10 @@ def train(): num_heads = num_heads, window_size = window_size) + module = torch.jit.script(model) print(module.graph) + # print(parser.ScriptModuleParser.flatten(module, depth=2)) model = model.cuda() From b98f2ce797c80001bc5255f4cd8db4406a0e08b8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 15:12:25 +0800 Subject: [PATCH 0409/1892] add shape annotation --- examples/swin/swin_transformer.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/examples/swin/swin_transformer.py b/examples/swin/swin_transformer.py index b9e39787..56cc68c1 100644 --- a/examples/swin/swin_transformer.py +++ b/examples/swin/swin_transformer.py @@ -270,27 +270,36 @@ def forward(self, x): shifted_x = x # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x + # [B, H, W, C] -> [B, H * W, C] x = x.view(B, H * W, C) - - # FFN + # [B, H * W, C] -> [B, H * W, C] x = shortcut + drop_path(x, self.drop_path_p) + # FFN + # [B, H * W, C] -> [B, H * W, C] ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] x = x + drop_path(ffn, self.drop_path_p) return x From 1df84fe69064d75302b04862048d3ab686ffb4a5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 15:18:07 +0800 Subject: [PATCH 0410/1892] init swin --- benchmark/swin/swin_megatron.py | 690 ++++++++++++++++++++++++++++++++ 1 file changed, 690 insertions(+) create mode 100644 benchmark/swin/swin_megatron.py diff --git a/benchmark/swin/swin_megatron.py b/benchmark/swin/swin_megatron.py new file mode 100644 index 00000000..56cc68c1 --- /dev/null +++ b/benchmark/swin/swin_megatron.py @@ -0,0 +1,690 @@ + +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu + +# Copied and modified from +# -------------------------------------------------------- + +from typing import Optional +import torch +import torch.nn as nn + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + relative_position_bias = self.relative_position_bias_table[relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + # [B, H, W, C] -> [B, H * W, C] + x = x.view(B, H * W, C) + # [B, H * W, C] -> [B, H * W, C] + x = shortcut + drop_path(x, self.drop_path_p) + # FFN + # [B, H * W, C] -> [B, H * W, C] + ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] + ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] + x = x + drop_path(ffn, self.drop_path_p) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + # if self.ape: + # x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x): + # forward features + # x = self.forward_features(x) + x = self.patch_embed(x) + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def train(): + + # image batch input + N, C, H, W = [1, 3, 224, 224] + + embed_dim, depths, num_heads, window_size = [ + 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 + ] + + # embed_dim, depths, num_heads, window_size = [ + # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + # ] + + # 1.02B Model + # embed_dim, depths, num_heads, window_size = [ + # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # ] + + + model = SwinTransformer(embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size) + + + module = torch.jit.script(model) + print(module.graph) + # print(parser.ScriptModuleParser.flatten(module, depth=2)) + + model = model.cuda() + + dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [N, C, H, W]) + + def train_iter(model, dataloader): + img = next(dataloader) + loss = model(img) + loss = torch.sum(loss) + loss.backward() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on iteration') + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + + +if __name__ == '__main__': + + cube.init() + train() From 826d952b73088f2326fd0940469ac46ec2b2eb02 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 18:03:23 +0800 Subject: [PATCH 0411/1892] megatron parallel support --- benchmark/swin/layers.py | 213 ++++++++++++++++++++++++++++++++ benchmark/swin/swin_megatron.py | 134 +++++++++++++++----- cube/runtime/resource.py | 4 + 3 files changed, 322 insertions(+), 29 deletions(-) create mode 100644 benchmark/swin/layers.py diff --git a/benchmark/swin/layers.py b/benchmark/swin/layers.py new file mode 100644 index 00000000..81c3af06 --- /dev/null +++ b/benchmark/swin/layers.py @@ -0,0 +1,213 @@ +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from cube.profiler.timer import print_each_rank +from cube.runtime.resource import EnvResource + + +def _reduce(input_): + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + world_size = torch.distributed.get_world_size() + if world_size == 1: + return input_ + group = EnvResource().tp_group + torch.distributed.all_reduce(input_, group=group) + return input_ + + +def _split(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + # Bypass the function if we are using only 1 GPU. + if world_size==1: + return input_ + last_dim = input_.dim() - 1 + last_dim_size = input_.size()[last_dim] // world_size + tensor_list = torch.split(input_, last_dim_size, dim=last_dim) + output = tensor_list[rank].contiguous() + return output + + +def _gather(input_): + """Gather tensors and concatinate along the last dimension.""" + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + # Bypass the function if we are using only 1 GPU. + if world_size==1: + return input_ + # Size and dimension. + last_dim = input_.dim() - 1 + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + group = EnvResource().tp_group + torch.distributed.all_gather(tensor_list, input_, group=group) + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + return output + + +class ColumnInputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return input_ + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + + +class ColumnOutputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _gather(input_) + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output) + + +class RowInputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _split(input_) + + @staticmethod + def backward(ctx, grad_outputs): + return _gather(grad_outputs) + + +class RowOutputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class ColumnParallelLinear(torch.nn.Module): + + def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.full_input = full_input + self.full_output = full_output + + world_size = torch.distributed.get_world_size() + + # print_each_rank(f'> parallizing linear using column partition: ' + # f'{output_size} partitioned by {world_size} devices') + + # not if output size is smaller than world size, + # no parallel enbaled. Each device compute the same + if world_size > output_size: + world_size = 1 + + self.weight = Parameter(torch.empty( + int(self.output_size // world_size), + self.input_size, + )) + if bias: + self.bias = Parameter(torch.empty( + int(self.output_size // world_size), + )) + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + bias = self.bias + if not self.full_input: + raise RuntimeError("Expected full tensor input") + input_parallel = ColumnInputAdapter.apply(input_) + output_parallel = F.linear(input_parallel, self.weight, bias) + if self.full_output: + output = ColumnOutputAdapter.apply(output_parallel) + else: + output = output_parallel + return output + + +class RowParallelLinear(torch.nn.Module): + + def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.full_input = full_input + self.full_output = full_output + + world_size = torch.distributed.get_world_size() + + # print_each_rank(f'> parallizing linear using row partition: ' + # f'{output_size} partitioned by {world_size} devices') + + # not if output size is smaller than world size, + # no parallel enbaled. Each device compute the same + if world_size > output_size: + world_size = 1 + + self.weight = Parameter(torch.empty( + self.output_size, + int(self.input_size // world_size), + )) + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + bias = self.bias + if self.full_input: + input_parallel = RowInputAdapter.apply(input_) + else: + input_parallel = input_ + output_parallel = F.linear(input_parallel, self.weight, bias) + if self.full_output: + output = RowOutputAdapter.apply(output_parallel) + else: + output = output_parallel + return output + + +class ShardEmbedding(torch.nn.Module): + + def __init__(self, num_embeddings, embedding_dim): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + self.shard_num = torch.distributed.get_world_size() + self.myshard = torch.distributed.get_rank() + + shard_num_embeddings = self.num_embeddings // self.shard_num + self.vocab_start_index = shard_num_embeddings * self.myshard + self.vocab_end_index = self.vocab_start_index + shard_num_embeddings + + self.weight = torch.nn.Parameter( + torch.empty(shard_num_embeddings, self.embedding_dim) + ) + + def forward(self, input_): + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | \ + (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + output_parallel = F.embedding( + masked_input, self.weight, + None, None, 2., False, False + ) + output = RowOutputAdapter.apply(output_parallel) + return output diff --git a/benchmark/swin/swin_megatron.py b/benchmark/swin/swin_megatron.py index 56cc68c1..a8c0c301 100644 --- a/benchmark/swin/swin_megatron.py +++ b/benchmark/swin/swin_megatron.py @@ -17,6 +17,11 @@ from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary +import argparse + + +from benchmark.swin.layers import ColumnParallelLinear, RowParallelLinear + def drop_path(x, drop_prob: float = 0.): if drop_prob == 0.: @@ -29,14 +34,16 @@ def drop_path(x, drop_prob: float = 0.): return output -class Mlp(nn.Module): +class MegatronMlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, full_input=True, full_output=False) self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) + # self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = RowParallelLinear(hidden_features, out_features, full_input=False, full_output=True) self.drop = nn.Dropout(drop) def forward(self, x): @@ -96,7 +103,7 @@ def window_position_index(window_size_h: int, window_size_w: int): return relative_position_index -class WindowAttention(nn.Module): +class MegatronWindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: @@ -114,17 +121,20 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.global_num_heads = num_heads + self.num_heads = num_heads // torch.distributed.get_world_size() + self.dim_heads = dim // self.global_num_heads + self.scale = qk_scale or self.dim_heads ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + # self.proj = nn.Linear(dim, dim) + self.proj = RowParallelLinear(dim, dim, full_input=False, full_output=True) self.proj_drop = nn.Dropout(proj_drop) # trunc_normal_(self.relative_position_bias_table, std=.02) @@ -137,7 +147,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale @@ -163,7 +173,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) x = self.proj(x) x = self.proj_drop(x) return x @@ -220,14 +230,14 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) - self.attn = WindowAttention( + self.attn = MegatronWindowAttention( dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path_p = drop_path self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = MegatronMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if self.shift_size > 0: # calculate attention mask for SW-MSA @@ -619,19 +629,21 @@ def flops(self): return flops -def train(): +def train(args): + resource = cube.runtime.resource.EnvResource() # image batch input - N, C, H, W = [1, 3, 224, 224] - - embed_dim, depths, num_heads, window_size = [ - 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 - ] + N, C, H, W = [32, 3, 224, 224] # embed_dim, depths, num_heads, window_size = [ - # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 # ] + # 348.55 M + embed_dim, depths, num_heads, window_size = [ + 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + ] + # 1.02B Model # embed_dim, depths, num_heads, window_size = [ # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 @@ -642,14 +654,16 @@ def train(): depths = depths, num_heads = num_heads, window_size = window_size) - - - module = torch.jit.script(model) - print(module.graph) - # print(parser.ScriptModuleParser.flatten(module, depth=2)) - model = model.cuda() + # setup data parallel reducer + reducer = None + if torch.distributed.get_world_size(group=resource.dp_group) > 1: + print('> initialize weight reducer') + reducer = resource.reducer + for param in model.parameters(): + reducer.add_param(param) + dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [N, C, H, W]) def train_iter(model, dataloader): @@ -657,9 +671,15 @@ def train_iter(model, dataloader): loss = model(img) loss = torch.sum(loss) loss.backward() + if reducer is not None: + reducer.allreduce() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + # start training + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + CudaTimer().warmup() torch.distributed.barrier() iter_num = 128 @@ -670,7 +690,7 @@ def train_iter(model, dataloader): optimizer.step() optimizer.zero_grad() if step == 1: - print('> passed on iteration') + print('> passed on 1st iteration') if step >= 40: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: @@ -685,6 +705,62 @@ def train_iter(model, dataloader): if __name__ == '__main__': + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--tp', type=int, default=1, + help='tensor parallel size') + parser.add_argument('--dp', type=int, default=1, + help='data parallel size') + parser.add_argument('--pp', type=int, default=1, + help='pipeline parallel size') + parser.add_argument('--micro-bs', type=int, default=-1) + args = parser.parse_args() cube.init() - train() + + # allocate resource + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + tp_size, tp_group_nums = args.tp, ndevs // args.tp + dp_size, dp_group_nums = args.dp, ndevs // args.dp + pp_size, pp_group_nums = args.pp, ndevs // args.pp + + if not pp_size * dp_size * tp_size == ndevs: + raise RuntimeError("Expected all devices are used") + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize data parallel group + all_data_parallel_group_ranks = list() + for i in range(pp_size): + start_rank = i * pp_group_nums + end_rank = (i + 1) * pp_group_nums + for j in range(tp_size): + ranks = list(range(start_rank + j, end_rank, tp_size)) + all_data_parallel_group_ranks.append(ranks) + # initialize groups + group = devs.get_group(ranks) + if myrank in ranks: + resource.dp_group = group + resource.reducer = cube.runtime.reducer.Reducer(ranks) + + # initialize pipelne parallel groups + for i in range(dp_size): + ranks = [data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks] + group = devs.get_group(ranks) + if myrank in ranks: + resource.pp_group = group + + # initialize tensor parallel groups + for i in range(tp_group_nums): + ranks = list(range(i * tp_size, (i + 1) * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + resource.tp_group = group + + train(args) diff --git a/cube/runtime/resource.py b/cube/runtime/resource.py index d67f5742..63a5275c 100644 --- a/cube/runtime/resource.py +++ b/cube/runtime/resource.py @@ -23,3 +23,7 @@ def __init__(self): def __getattr__(self, name): return getattr(self.instance, name) + + + def __setattr__(self, name, val) -> None: + setattr(EnvResource.instance, name, val) From ea373438e24fe12e40ae981733ab57e4eb56fa40 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 18:41:38 +0800 Subject: [PATCH 0412/1892] fix bugs on tensor parallel --- benchmark/swin/layers.py | 18 +++++++++--------- benchmark/swin/swin_megatron.py | 14 +++++++++++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/benchmark/swin/layers.py b/benchmark/swin/layers.py index 81c3af06..dda71ad9 100644 --- a/benchmark/swin/layers.py +++ b/benchmark/swin/layers.py @@ -10,7 +10,7 @@ def _reduce(input_): """All-reduce the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size() + world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) if world_size == 1: return input_ group = EnvResource().tp_group @@ -22,8 +22,8 @@ def _split(input_): """Split the tensor along its last dimension and keep the corresponding slice.""" - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) + rank = torch.distributed.get_rank(group=EnvResource().tp_group) # Bypass the function if we are using only 1 GPU. if world_size==1: return input_ @@ -37,8 +37,8 @@ def _split(input_): def _gather(input_): """Gather tensors and concatinate along the last dimension.""" - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) + rank = torch.distributed.get_rank(group=EnvResource().tp_group) # Bypass the function if we are using only 1 GPU. if world_size==1: return input_ @@ -100,7 +100,7 @@ def __init__(self, input_size, output_size, bias=True, full_input=True, full_out self.full_input = full_input self.full_output = full_output - world_size = torch.distributed.get_world_size() + world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) # print_each_rank(f'> parallizing linear using column partition: ' # f'{output_size} partitioned by {world_size} devices') @@ -145,7 +145,7 @@ def __init__(self, input_size, output_size, bias=True, full_input=True, full_out self.full_input = full_input self.full_output = full_output - world_size = torch.distributed.get_world_size() + world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) # print_each_rank(f'> parallizing linear using row partition: ' # f'{output_size} partitioned by {world_size} devices') @@ -187,8 +187,8 @@ def __init__(self, num_embeddings, embedding_dim): self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim - self.shard_num = torch.distributed.get_world_size() - self.myshard = torch.distributed.get_rank() + self.shard_num = torch.distributed.get_world_size(group=EnvResource().tp_group) + self.myshard = torch.distributed.get_rank(group=EnvResource().tp_group) shard_num_embeddings = self.num_embeddings // self.shard_num self.vocab_start_index = shard_num_embeddings * self.myshard diff --git a/benchmark/swin/swin_megatron.py b/benchmark/swin/swin_megatron.py index a8c0c301..8f3cfe10 100644 --- a/benchmark/swin/swin_megatron.py +++ b/benchmark/swin/swin_megatron.py @@ -122,7 +122,8 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.dim = dim self.window_size = window_size # Wh, Ww self.global_num_heads = num_heads - self.num_heads = num_heads // torch.distributed.get_world_size() + group = cube.runtime.resource.EnvResource().tp_group + self.num_heads = num_heads // torch.distributed.get_world_size(group=group) self.dim_heads = dim // self.global_num_heads self.scale = qk_scale or self.dim_heads ** -0.5 @@ -658,13 +659,14 @@ def train(args): # setup data parallel reducer reducer = None - if torch.distributed.get_world_size(group=resource.dp_group) > 1: + if args.dp > 1: print('> initialize weight reducer') reducer = resource.reducer for param in model.parameters(): reducer.add_param(param) - dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [N, C, H, W]) + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [N // args.dp, C, H, W]) def train_iter(model, dataloader): img = next(dataloader) @@ -745,8 +747,10 @@ def train_iter(model, dataloader): # initialize groups group = devs.get_group(ranks) if myrank in ranks: + dp_ranks = ranks resource.dp_group = group resource.reducer = cube.runtime.reducer.Reducer(ranks) + print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) # initialize pipelne parallel groups for i in range(dp_size): @@ -754,13 +758,17 @@ def train_iter(model, dataloader): for data_parallel_group_ranks in all_data_parallel_group_ranks] group = devs.get_group(ranks) if myrank in ranks: + pp_ranks = ranks resource.pp_group = group + print_each_rank(f'initialzed pipeline parallel group: {pp_ranks}', rank_only=myrank) # initialize tensor parallel groups for i in range(tp_group_nums): ranks = list(range(i * tp_size, (i + 1) * tp_size)) group = devs.get_group(ranks) if myrank in ranks: + tp_ranks = ranks resource.tp_group = group + print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) train(args) From e2565b2cdbc2a8418e255e24669a35ca1ab587d3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 20:38:16 +0800 Subject: [PATCH 0413/1892] add 2.01B model setting --- benchmark/swin/swin_megatron.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/benchmark/swin/swin_megatron.py b/benchmark/swin/swin_megatron.py index 8f3cfe10..3b3bbb44 100644 --- a/benchmark/swin/swin_megatron.py +++ b/benchmark/swin/swin_megatron.py @@ -123,6 +123,9 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.window_size = window_size # Wh, Ww self.global_num_heads = num_heads group = cube.runtime.resource.EnvResource().tp_group + tp_world_size = torch.distributed.get_world_size(group=group) + if num_heads % tp_world_size != 0: + print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') self.num_heads = num_heads // torch.distributed.get_world_size(group=group) self.dim_heads = dim // self.global_num_heads self.scale = qk_scale or self.dim_heads ** -0.5 @@ -132,6 +135,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # print(f'qkv embed dim: {dim}') self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False) self.attn_drop = nn.Dropout(attn_drop) # self.proj = nn.Linear(dim, dim) @@ -634,28 +638,34 @@ def train(args): resource = cube.runtime.resource.EnvResource() # image batch input - N, C, H, W = [32, 3, 224, 224] + N, C, H, W = [1, 3, 224, 224] # embed_dim, depths, num_heads, window_size = [ # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 # ] # 348.55 M - embed_dim, depths, num_heads, window_size = [ - 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 - ] + # embed_dim, depths, num_heads, window_size = [ + # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + # ] - # 1.02B Model + # 895.7 M Model # embed_dim, depths, num_heads, window_size = [ # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 # ] + # 2.01B model + embed_dim, depths, num_heads, window_size = [ + 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 + ] + model = SwinTransformer(embed_dim = embed_dim, depths = depths, num_heads = num_heads, window_size = window_size) model = model.cuda() + memory_summary() # setup data parallel reducer reducer = None @@ -693,6 +703,7 @@ def train_iter(model, dataloader): optimizer.zero_grad() if step == 1: print('> passed on 1st iteration') + memory_summary() if step >= 40: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: From 55f5ead93a41515aca4c454bc2b632ae7743ad86 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 25 Nov 2021 20:46:37 +0800 Subject: [PATCH 0414/1892] swin running example --- benchmark/swin/swin_megatron.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/benchmark/swin/swin_megatron.py b/benchmark/swin/swin_megatron.py index 3b3bbb44..9f3ab9e8 100644 --- a/benchmark/swin/swin_megatron.py +++ b/benchmark/swin/swin_megatron.py @@ -1,11 +1,16 @@ # -------------------------------------------------------- -# Swin Transformer -# Copyright (c) 2021 Microsoft -# Licensed under The MIT License [see LICENSE for details] -# Written by Ze Liu - -# Copied and modified from +# Modified from Swin-Transformer Repo +""" +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + benchmark/swin/swin_megatron.py +""" # -------------------------------------------------------- from typing import Optional From 401e6313a72a7d92cd4fa96860092c3266fc675c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 26 Nov 2021 13:36:02 +0800 Subject: [PATCH 0415/1892] halo exchange --- cube/runtime/function/__init__.py | 3 +- cube/runtime/function/dist.py | 75 +++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 cube/runtime/function/dist.py diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py index a8c9ff96..294c00ce 100644 --- a/cube/runtime/function/__init__.py +++ b/cube/runtime/function/__init__.py @@ -1,2 +1,3 @@ import cube.runtime.function.complex as complex -from cube.runtime.function.complex import * \ No newline at end of file +from cube.runtime.function.complex import * +from cube.runtime.function.dist import * diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py new file mode 100644 index 00000000..e8952824 --- /dev/null +++ b/cube/runtime/function/dist.py @@ -0,0 +1,75 @@ +import torch + + +class RollDimParallel(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, shifts, dims, group, full_input=False, full_output=False): + pass + + @staticmethod + def backward(ctx): + pass + +def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, + group, full_input=False, full_output=False): + """ + partition torch.roll at shifted dimension + + Inputs: + input: [B, H, W, C] + shift: int + dim: int + """ + world_size = torch.distributed.get_world_size(group) + rank = torch.distributed.get_rank(group) + # halo exchange at H dimension + if dim == 1: + assert shift < 0 + shift = 0 - shift + local = input[:, shift:, :, :] + remote = input[:, slice(0, shift), :, :].contiguous() + recv_tensor = torch.empty_like(remote, requires_grad=True) + # send to next rank and recv from prevous rank + send_op = torch.distributed.P2POp( + torch.distributed.isend, remote, + (rank - 1 + world_size) % world_size, group=group + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, + (rank + 1) % world_size, group=group + ) + ops = [send_op, recv_op] if rank % 2 == 0 else [recv_op, send_op] + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + tensor = torch.cat((local, recv_tensor), dim=dim).contiguous() + return tensor + else: + raise NotImplementedError + + +def roll_dim_allgather(input: torch.Tensor, shift: int, dim: int, group, + full_input=False, full_output=False): + """ + partition torch.roll at shifted dimension + + Inputs: + input: [B, H, W, C] + shift: int + dim: int + """ + world_size = torch.distributed.get_world_size(group) + rank = torch.distributed.get_rank(group) + # allgather to have all and select what each rank needed + if dim == 1: + tensor_list = [torch.empty_like(input) for _ in range(world_size)] + tensor_list[rank] = input + torch.distributed.all_gather(tensor_list, input, group=group) + full_tensor = torch.cat(tuple(tensor_list), dim=dim).contiguous() + full_tensor = torch.roll(full_tensor, shifts=(shift,), dims=(dim,)) + chunk_len = input.shape[dim] + mytensor = full_tensor[:, rank * chunk_len : (rank + 1) * chunk_len, :, :] + mytensor = mytensor.contiguous() + return mytensor + else: + raise NotImplementedError From a5673f46e420e6391c4ae053c904892b68303ef9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 26 Nov 2021 13:40:37 +0800 Subject: [PATCH 0416/1892] add test for roll partition --- tests/runtime/rollsplit.py | 68 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 tests/runtime/rollsplit.py diff --git a/tests/runtime/rollsplit.py b/tests/runtime/rollsplit.py new file mode 100644 index 00000000..013462bf --- /dev/null +++ b/tests/runtime/rollsplit.py @@ -0,0 +1,68 @@ +""" +CUDA_VISIBLE_DEVICES=4,5,6,7 +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + tests/runtime/rollsplit.py + +""" + + +import torch + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +def test_roll_parallel(): + + # input size + group = None + world_size = torch.distributed.get_world_size(group=group) + input_size = [1, 224 // world_size, 224, 256] + # input_size = torch.arange(0, ) + input = torch.randn(input_size).cuda() * 10 + + CudaTimer().warmup(seconds=2) + + torch.distributed.barrier() + CudaTimer().start(field_name='roll_halo') + for _ in range(1000): + roll_out = cube.runtime.function.roll_dim_parallel( + input, -(9 // 2), 1, group + ) + CudaTimer().stop(field_name='roll_halo') + ref1 = roll_out + # print_each_rank(ref1, rank_only=0) + assert roll_out.shape == input.shape + span = CudaTimer().duration(times=1000, field_name='roll_halo') + print_each_rank('span on halo exchange: {:.2f} ms'.format(span)) + + + torch.distributed.barrier() + CudaTimer().start(field_name='roll_allgather') + for _ in range(1000): + roll_out = cube.runtime.function.roll_dim_allgather( + input, -(9 // 2), 1, group + ) + CudaTimer().stop(field_name='roll_allgather') + ref2 = roll_out + # print_each_rank(ref2, rank_only=0) + span = CudaTimer().duration(times=1000, field_name='roll_allgather') + print_each_rank('span on allgather exchange: {:.2f} ms'.format(span)) + + if not torch.allclose(ref1, ref2, atol=1e-3, rtol=1e-3): + print('correctness test failed') + else: + print('correctness test passed') + + +if __name__ == '__main__': + + cube.init() + test_roll_parallel() \ No newline at end of file From 4c3a3707f97f9f37b3c56b48f7d9224614116335 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 26 Nov 2021 14:39:44 +0800 Subject: [PATCH 0417/1892] add autograd for roll --- cube/runtime/function/dist.py | 109 ++++++++++++++++++++++++++-------- tests/runtime/rollsplit.py | 25 +++++++- 2 files changed, 105 insertions(+), 29 deletions(-) diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py index e8952824..de99c98e 100644 --- a/cube/runtime/function/dist.py +++ b/cube/runtime/function/dist.py @@ -1,17 +1,7 @@ import torch -class RollDimParallel(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, shifts, dims, group, full_input=False, full_output=False): - pass - - @staticmethod - def backward(ctx): - pass - -def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, - group, full_input=False, full_output=False): +def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): """ partition torch.roll at shifted dimension @@ -23,12 +13,17 @@ def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, world_size = torch.distributed.get_world_size(group) rank = torch.distributed.get_rank(group) # halo exchange at H dimension - if dim == 1: - assert shift < 0 + if shift < 0: shift = 0 - shift - local = input[:, shift:, :, :] - remote = input[:, slice(0, shift), :, :].contiguous() - recv_tensor = torch.empty_like(remote, requires_grad=True) + if dim == 1: + local = input[:, shift:, :, :] + remote = input[:, slice(0, shift), :, :].contiguous() + elif dim == 2: + local = input[:, :, shift:, :] + remote = input[:, :, slice(0, shift), :].contiguous() + else: + raise NotImplementedError("Only support on dim 1 and dim 2") + recv_tensor = torch.empty_like(remote) # send to next rank and recv from prevous rank send_op = torch.distributed.P2POp( torch.distributed.isend, remote, @@ -44,8 +39,33 @@ def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, req.wait() tensor = torch.cat((local, recv_tensor), dim=dim).contiguous() return tensor + elif shift > 0: + boundary = input.shape[dim] - shift + if dim == 1: + local = input[:, slice(0, boundary), :, :] + remote = input[:, slice(boundary, input.shape[dim]), :, :].contiguous() + elif dim == 2: + local = input[:, :, slice(0, boundary), :] + remote = input[:, :, slice(boundary, input.shape[dim]), :].contiguous() + else: + raise NotImplementedError("Only support on dim 1 and dim 2") + recv_tensor = torch.empty_like(remote) + send_op = torch.distributed.P2POp( + torch.distributed.isend, remote, + (rank + 1) % world_size, group=group + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, + (rank - 1 + world_size) % world_size, group=group + ) + ops = [send_op, recv_op] if rank % 2 == 0 else [recv_op, send_op] + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + tensor = torch.cat((recv_tensor, local), dim=dim).contiguous() + return tensor else: - raise NotImplementedError + return input def roll_dim_allgather(input: torch.Tensor, shift: int, dim: int, group, @@ -61,15 +81,52 @@ def roll_dim_allgather(input: torch.Tensor, shift: int, dim: int, group, world_size = torch.distributed.get_world_size(group) rank = torch.distributed.get_rank(group) # allgather to have all and select what each rank needed + tensor_list = [torch.empty_like(input) for _ in range(world_size)] + tensor_list[rank] = input + torch.distributed.all_gather(tensor_list, input, group=group) + full_tensor = torch.cat(tuple(tensor_list), dim=dim).contiguous() + full_tensor = torch.roll(full_tensor, shifts=(shift,), dims=(dim,)) + chunk_len = input.shape[dim] if dim == 1: - tensor_list = [torch.empty_like(input) for _ in range(world_size)] - tensor_list[rank] = input - torch.distributed.all_gather(tensor_list, input, group=group) - full_tensor = torch.cat(tuple(tensor_list), dim=dim).contiguous() - full_tensor = torch.roll(full_tensor, shifts=(shift,), dims=(dim,)) - chunk_len = input.shape[dim] mytensor = full_tensor[:, rank * chunk_len : (rank + 1) * chunk_len, :, :] - mytensor = mytensor.contiguous() - return mytensor + elif dim == 2: + mytensor = full_tensor[:, :, rank * chunk_len : (rank + 1) * chunk_len, :] else: - raise NotImplementedError + raise NotImplementedError("Only supported on dim 1 and dim 2") + mytensor = mytensor.contiguous() + return mytensor + + +class RollDimParallel(torch.autograd.Function): + """ + Halo exchange implementation on partitioning torch.roll + at shift dimension + + """ + @staticmethod + def forward(ctx, input_, shift: int, dim: int, group=None): + ctx.shift = shift + ctx.dim = dim + ctx.group = group + output = _roll_dim_parallel(input_, shift, dim, group) + return output + + @staticmethod + def backward(ctx, grad_output): + shift = ctx.shift + dim = ctx.dim + group = ctx.group + grad = _roll_dim_parallel(grad_output, 0-shift, dim, group) + return grad, None, None, None + + +def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): + """ + partition torch.roll at shifted dimension + + Inputs: + input: [B, H, W, C] + shift: int + dim: int + """ + return RollDimParallel.apply(input, shift, dim, group) diff --git a/tests/runtime/rollsplit.py b/tests/runtime/rollsplit.py index 013462bf..43f87fa9 100644 --- a/tests/runtime/rollsplit.py +++ b/tests/runtime/rollsplit.py @@ -34,7 +34,7 @@ def test_roll_parallel(): CudaTimer().start(field_name='roll_halo') for _ in range(1000): roll_out = cube.runtime.function.roll_dim_parallel( - input, -(9 // 2), 1, group + input, (9 // 2), 1, group ) CudaTimer().stop(field_name='roll_halo') ref1 = roll_out @@ -48,7 +48,7 @@ def test_roll_parallel(): CudaTimer().start(field_name='roll_allgather') for _ in range(1000): roll_out = cube.runtime.function.roll_dim_allgather( - input, -(9 // 2), 1, group + input, (9 // 2), 1, group ) CudaTimer().stop(field_name='roll_allgather') ref2 = roll_out @@ -62,7 +62,26 @@ def test_roll_parallel(): print('correctness test passed') +def test_roll_parallel_autograd(): + + group = None + world_size = torch.distributed.get_world_size(group=group) + input_size = [1, 224 // world_size, 224, 256] + # input_size = torch.arange(0, ) + input = torch.randn(input_size).cuda() * 10 + input = input.requires_grad_() + + out = cube.runtime.function.roll_dim_parallel( + input, (9 // 2), 1, group + ) + loss = torch.sum(out) + loss.backward() + print(loss) + print(input.grad) + + if __name__ == '__main__': cube.init() - test_roll_parallel() \ No newline at end of file + # test_roll_parallel() + test_roll_parallel_autograd() \ No newline at end of file From f50a3cc1d3736f7a384fb88d3049fa37958bbf29 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 26 Nov 2021 15:34:09 +0800 Subject: [PATCH 0418/1892] swin 1 for 1 batch size strong scaling --- examples/swin/swin_1.py | 794 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 794 insertions(+) create mode 100644 examples/swin/swin_1.py diff --git a/examples/swin/swin_1.py b/examples/swin/swin_1.py new file mode 100644 index 00000000..8663258d --- /dev/null +++ b/examples/swin/swin_1.py @@ -0,0 +1,794 @@ + +# -------------------------------------------------------- +# Modified from Swin-Transformer Repo +""" +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_1.py +""" +# -------------------------------------------------------- + +from typing import Optional +import torch +import torch.nn as nn + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary + +import argparse + + +from examples.swin.layers import ColumnParallelLinear, RowParallelLinear + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class MegatronMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, full_input=True, full_output=False) + self.act = act_layer() + # self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = RowParallelLinear(hidden_features, out_features, full_input=False, full_output=True) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +class MegatronWindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.global_num_heads = num_heads + group = cube.runtime.resource.EnvResource().tp_group + tp_world_size = torch.distributed.get_world_size(group=group) + if num_heads % tp_world_size != 0: + print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') + self.num_heads = num_heads // torch.distributed.get_world_size(group=group) + self.dim_heads = dim // self.global_num_heads + self.scale = qk_scale or self.dim_heads ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # print(f'qkv embed dim: {dim}') + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False) + self.attn_drop = nn.Dropout(attn_drop) + # self.proj = nn.Linear(dim, dim) + self.proj = RowParallelLinear(dim, dim, full_input=False, full_output=True) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + relative_position_bias = self.relative_position_bias_table[relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = MegatronWindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MegatronMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + # [B, H, W, C] -> [B, H * W, C] + x = x.view(B, H * W, C) + # [B, H * W, C] -> [B, H * W, C] + x = shortcut + drop_path(x, self.drop_path_p) + # FFN + # [B, H * W, C] -> [B, H * W, C] + ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] + ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] + x = x + drop_path(ffn, self.drop_path_p) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + print(f'input resolution: {input_resolution}') + print(f'window num: {input_resolution[0] * input_resolution[1] / window_size / window_size}') + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + # if self.ape: + # x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x): + # forward features + # x = self.forward_features(x) + x = self.patch_embed(x) + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def train(args): + resource = cube.runtime.resource.EnvResource() + + # image batch input + N, C, H, W = [1, 3, 224, 224] + + # embed_dim, depths, num_heads, window_size = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 + # ] + + # 348.55 M + # embed_dim, depths, num_heads, window_size = [ + # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + # ] + + # 895.7 M Model -- 224x224 + embed_dim, depths, num_heads, window_size = [ + 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 + ] + + # 2.01B model + # embed_dim, depths, num_heads, window_size = [ + # 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # ] + + + model = SwinTransformer(img_size = H, + embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size) + model = model.cuda() + memory_summary() + + # setup data parallel reducer + reducer = None + if args.dp > 1: + print('> initialize weight reducer') + reducer = resource.reducer + for param in model.parameters(): + reducer.add_param(param) + + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [N // args.dp, C, H, W]) + + def train_iter(model, dataloader): + img = next(dataloader) + loss = model(img) + loss = torch.sum(loss) + loss.backward() + if reducer is not None: + reducer.allreduce() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + # start training + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on 1st iteration') + memory_summary() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + + +if __name__ == '__main__': + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--tp', type=int, default=1, + help='tensor parallel size') + parser.add_argument('--dp', type=int, default=1, + help='data parallel size') + parser.add_argument('--pp', type=int, default=1, + help='pipeline parallel size') + parser.add_argument('--wp', type=int, default=1) + parser.add_argument('--micro-bs', type=int, default=-1) + args = parser.parse_args() + + cube.init() + + # allocate resource + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + tp_size, tp_group_nums = args.tp, ndevs // args.tp + dp_size, dp_group_nums = args.dp, ndevs // args.dp + pp_size, pp_group_nums = args.pp, ndevs // args.pp + + if not pp_size * dp_size * tp_size == ndevs: + raise RuntimeError("Expected all devices are used") + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize data parallel group + all_data_parallel_group_ranks = list() + for i in range(pp_size): + start_rank = i * pp_group_nums + end_rank = (i + 1) * pp_group_nums + for j in range(tp_size): + ranks = list(range(start_rank + j, end_rank, tp_size)) + all_data_parallel_group_ranks.append(ranks) + # initialize groups + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + resource.dp_group = group + resource.reducer = cube.runtime.reducer.Reducer(ranks) + print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) + + # initialize pipelne parallel groups + for i in range(dp_size): + ranks = [data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks] + group = devs.get_group(ranks) + if myrank in ranks: + pp_ranks = ranks + resource.pp_group = group + print_each_rank(f'initialzed pipeline parallel group: {pp_ranks}', rank_only=myrank) + + # initialize tensor parallel groups + for i in range(tp_group_nums): + ranks = list(range(i * tp_size, (i + 1) * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + resource.tp_group = group + print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + + train(args) From 1159990284bb6f846b8b6306cfd2af505a36407a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 26 Nov 2021 19:09:23 +0800 Subject: [PATCH 0419/1892] add grid autograd --- cube/runtime/function/dist.py | 95 +++++++++++++++++++++++++++++++++++ tests/runtime/rollsplit.py | 35 ++++++++++++- 2 files changed, 129 insertions(+), 1 deletion(-) diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py index de99c98e..4e44399d 100644 --- a/cube/runtime/function/dist.py +++ b/cube/runtime/function/dist.py @@ -130,3 +130,98 @@ def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): dim: int """ return RollDimParallel.apply(input, shift, dim, group) + + +class GridPartition(torch.autograd.Function): + """ + Full input + """ + @staticmethod + def forward(ctx, input_, nrow: int, ncol: int, group=None): + """ + input: [B, H, W, C] + """ + ctx.group = group + world_size = torch.distributed.get_world_size(group) + ctx.nrow = nrow + ctx.ncol = ncol + assert nrow * ncol == world_size + rank = torch.distributed.get_rank(group) + myrow = rank // ncol + mycol = rank % ncol + + chunk = torch.chunk(input_, nrow, dim=1)[myrow] + chunk = torch.chunk(chunk, ncol, dim=2)[mycol].contiguous() + return chunk + + @staticmethod + def backward(ctx, grad_output): + group = ctx.group + nrow = ctx.nrow + ncol = ctx.ncol + + world_size = torch.distributed.get_world_size(group) + rank = torch.distributed.get_rank(group) + grad_output = grad_output.contiguous() + tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] + tensor_list[rank] = grad_output + torch.distributed.all_gather(tensor_list, grad_output, group=group) + + rows = list() + for row in range(nrow): + row_slice = torch.cat(tuple(tensor_list[row*ncol:(row+1)*ncol]), dim=2) + rows.append(row_slice) + grad_output = torch.cat(tuple(rows), dim=1).contiguous() + return grad_output, None, None, None + + +class GridCollection(torch.autograd.Function): + """ + Full input + """ + @staticmethod + def forward(ctx, input_, nrow: int, ncol: int, group=None): + """ + input: [B, H, W, C] + output: [B, nrow * H, ncol * W, C] + """ + ctx.group = group + world_size = torch.distributed.get_world_size(group) + ctx.nrow = nrow + ctx.ncol = ncol + assert nrow * ncol == world_size + + world_size = torch.distributed.get_world_size(group) + rank = torch.distributed.get_rank(group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + rows = list() + for row in range(nrow): + row_slice = torch.cat(tuple(tensor_list[row*ncol:(row+1)*ncol]), dim=2) + rows.append(row_slice) + output = torch.cat(tuple(rows), dim=1).contiguous() + return output + + @staticmethod + def backward(ctx, grad_output): + group = ctx.group + nrow = ctx.nrow + ncol = ctx.ncol + + rank = torch.distributed.get_rank(group) + myrow = rank // ncol + mycol = rank % ncol + + chunk = torch.chunk(grad_output, nrow, dim=1)[myrow] + chunk = torch.chunk(chunk, ncol, dim=2)[mycol].contiguous() + return chunk, None, None, None + + +def grid_partition(input_, nrow, ncol, group=None): + return GridPartition.apply(input_, nrow, ncol, group) + + +def grid_collection(input_, nrow, ncol, group=None): + return GridCollection.apply(input_, nrow, ncol, group) diff --git a/tests/runtime/rollsplit.py b/tests/runtime/rollsplit.py index 43f87fa9..20eb29e2 100644 --- a/tests/runtime/rollsplit.py +++ b/tests/runtime/rollsplit.py @@ -80,8 +80,41 @@ def test_roll_parallel_autograd(): print(input.grad) +def test_grid_partition(): + + group = None + world_size = torch.distributed.get_world_size(group=group) + assert world_size == 4 + input_size = [1, 56, 56, 256] + input = torch.randn(input_size).cuda() * 10 + input = input.requires_grad_() + out = cube.runtime.function.grid_partition(input, 2, 2, group = None) + print(out.shape) + assert out.shape == torch.Size([1, 56 // 2, 56 // 2, 256]) + loss = torch.sum(out) + loss.backward() + # print(input.grad) + + +def test_grid_collection(): + + group = None + world_size = torch.distributed.get_world_size(group=group) + assert world_size == 4 + input_size = [1, 56 // 2, 56 // 2, 256] + input = torch.randn(input_size).cuda() * 10 + input = input.requires_grad_() + out = cube.runtime.function.grid_collection(input, 2, 2, group = None) + assert out.shape == torch.Size([1, 56, 56, 256]) + loss = torch.sum(out) + loss.backward() + # print(input.grad) + + if __name__ == '__main__': cube.init() # test_roll_parallel() - test_roll_parallel_autograd() \ No newline at end of file + # test_roll_parallel_autograd() + # test_grid_partition() + test_grid_collection() \ No newline at end of file From 560481283e974fe88daeb7dbdfeea11fa4e23ab7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 26 Nov 2021 20:34:03 +0800 Subject: [PATCH 0420/1892] fix p2p send recv to have global rank --- cube/runtime/function/dist.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py index 4e44399d..aa07b2d6 100644 --- a/cube/runtime/function/dist.py +++ b/cube/runtime/function/dist.py @@ -1,4 +1,14 @@ import torch +from torch.distributed.distributed_c10d import _get_global_rank + +from cube.profiler.timer import print_each_rank + + +def get_global_rank(group, group_rank): + if group is None: + return group_rank + else: + return _get_global_rank(group, group_rank) def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): @@ -11,6 +21,8 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): dim: int """ world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return torch.roll(input, (shift), (dim,)) rank = torch.distributed.get_rank(group) # halo exchange at H dimension if shift < 0: @@ -24,14 +36,21 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): else: raise NotImplementedError("Only support on dim 1 and dim 2") recv_tensor = torch.empty_like(remote) + # send to next rank and recv from prevous rank + # print_each_rank(f'send to {(rank - 1 + world_size) % world_size}, recv from {(rank + 1) % world_size}') + send_local_rank = (rank - 1 + world_size) % world_size + send_global_rank = get_global_rank(group, send_local_rank) + recv_local_rank = (rank + 1) % world_size + recv_global_rank = get_global_rank(group, recv_local_rank) + send_op = torch.distributed.P2POp( torch.distributed.isend, remote, - (rank - 1 + world_size) % world_size, group=group + send_global_rank, group=group ) recv_op = torch.distributed.P2POp( torch.distributed.irecv, recv_tensor, - (rank + 1) % world_size, group=group + recv_global_rank, group=group ) ops = [send_op, recv_op] if rank % 2 == 0 else [recv_op, send_op] reqs = torch.distributed.batch_isend_irecv(ops) @@ -50,13 +69,20 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): else: raise NotImplementedError("Only support on dim 1 and dim 2") recv_tensor = torch.empty_like(remote) + + # to global rank + send_local_rank = (rank + 1) % world_size + send_global_rank = get_global_rank(group, send_local_rank) + recv_local_rank = (rank - 1 + world_size) % world_size + recv_global_rank = get_global_rank(group, recv_local_rank) + send_op = torch.distributed.P2POp( torch.distributed.isend, remote, - (rank + 1) % world_size, group=group + send_global_rank, group=group ) recv_op = torch.distributed.P2POp( torch.distributed.irecv, recv_tensor, - (rank - 1 + world_size) % world_size, group=group + recv_global_rank, group=group ) ops = [send_op, recv_op] if rank % 2 == 0 else [recv_op, send_op] reqs = torch.distributed.batch_isend_irecv(ops) From 2998868defe08cbf969b8e513e9385d637ebaab2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 27 Nov 2021 15:03:08 +0800 Subject: [PATCH 0421/1892] unfold layers --- examples/swin/layers.py | 221 ++++++++++ examples/swin/swin_348M.py | 858 +++++++++++++++++++++++++++++++++++++ 2 files changed, 1079 insertions(+) create mode 100644 examples/swin/layers.py create mode 100644 examples/swin/swin_348M.py diff --git a/examples/swin/layers.py b/examples/swin/layers.py new file mode 100644 index 00000000..88f11a51 --- /dev/null +++ b/examples/swin/layers.py @@ -0,0 +1,221 @@ +import torch +from torch import autograd +import torch.nn.functional as F +from torch.nn.parameter import Parameter + + +def _reduce(input_, group): + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + torch.distributed.all_reduce(input_, group=group) + return input_ + + +def _split(input_, group): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + world_size = torch.distributed.get_world_size(group=group) + rank = torch.distributed.get_rank(group=group) + # Bypass the function if we are using only 1 GPU. + if world_size==1: + return input_ + last_dim = input_.dim() - 1 + last_dim_size = input_.size()[last_dim] // world_size + tensor_list = torch.split(input_, last_dim_size, dim=last_dim) + output = tensor_list[rank].contiguous() + return output + + +def _gather(input_, group): + """Gather tensors and concatinate along the last dimension.""" + + world_size = torch.distributed.get_world_size(group=group) + rank = torch.distributed.get_rank(group=group) + # Bypass the function if we are using only 1 GPU. + if world_size==1: + return input_ + # Size and dimension. + last_dim = input_.dim() - 1 + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + return output + + +class ColumnInputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + return input_ + @staticmethod + def backward(ctx, grad_output): + group = ctx.group + return _reduce(grad_output, group), None + + +class ColumnOutputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + return _gather(input_, group) + @staticmethod + def backward(ctx, grad_output): + group = ctx.group + return _split(grad_output, group), None + + +class RowInputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + return _split(input_, group) + + @staticmethod + def backward(ctx, grad_outputs): + group = ctx.group + return _gather(grad_outputs, group), None + + +class RowOutputAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + return _reduce(input_, group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class ColumnParallelLinear(torch.nn.Module): + + def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False, tp_group=-1): + super().__init__() + assert tp_group != -1 + self.input_size = input_size + self.output_size = output_size + self.full_input = full_input + self.full_output = full_output + + self.group = tp_group + world_size = torch.distributed.get_world_size(group=self.group) + + # print_each_rank(f'> parallizing linear using column partition: ' + # f'{output_size} partitioned by {world_size} devices') + + # not if output size is smaller than world size, + # no parallel enbaled. Each device compute the same + if world_size > output_size: + world_size = 1 + + self.weight = Parameter(torch.empty( + int(self.output_size // world_size), + self.input_size, + )) + if bias: + self.bias = Parameter(torch.empty( + int(self.output_size // world_size), + )) + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + bias = self.bias + if not self.full_input: + raise RuntimeError("Expected full tensor input") + input_parallel = ColumnInputAdapter.apply(input_, self.group) + output_parallel = F.linear(input_parallel, self.weight, bias) + if self.full_output: + output = ColumnOutputAdapter.apply(output_parallel, self.group) + else: + output = output_parallel + return output + + +class RowParallelLinear(torch.nn.Module): + + def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False, tp_group=-1): + super().__init__() + assert tp_group != -1 + self.input_size = input_size + self.output_size = output_size + self.full_input = full_input + self.full_output = full_output + + self.group = tp_group + world_size = torch.distributed.get_world_size(group=self.group) + + # print_each_rank(f'> parallizing linear using row partition: ' + # f'{output_size} partitioned by {world_size} devices') + + # not if output size is smaller than world size, + # no parallel enbaled. Each device compute the same + if world_size > output_size: + world_size = 1 + + self.weight = Parameter(torch.empty( + self.output_size, + int(self.input_size // world_size), + )) + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + bias = self.bias + if self.full_input: + input_parallel = RowInputAdapter.apply(input_, self.group) + else: + input_parallel = input_ + output_parallel = F.linear(input_parallel, self.weight, bias) + if self.full_output: + output = RowOutputAdapter.apply(output_parallel, self.group) + else: + output = output_parallel + return output + + +class ShardEmbedding(torch.nn.Module): + + def __init__(self, num_embeddings, embedding_dim, tp_group): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + self.group = tp_group + self.shard_num = torch.distributed.get_world_size(group=self.group) + self.myshard = torch.distributed.get_rank(group=self.group) + + shard_num_embeddings = self.num_embeddings // self.shard_num + self.vocab_start_index = shard_num_embeddings * self.myshard + self.vocab_end_index = self.vocab_start_index + shard_num_embeddings + + self.weight = torch.nn.Parameter( + torch.empty(shard_num_embeddings, self.embedding_dim) + ) + + def forward(self, input_): + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | \ + (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + output_parallel = F.embedding( + masked_input, self.weight, + None, None, 2., False, False + ) + output = RowOutputAdapter.apply(output_parallel, self.group) + return output diff --git a/examples/swin/swin_348M.py b/examples/swin/swin_348M.py new file mode 100644 index 00000000..5c2e1a9f --- /dev/null +++ b/examples/swin/swin_348M.py @@ -0,0 +1,858 @@ + +# -------------------------------------------------------- +# Modified from Swin-Transformer Repo +""" +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_348M.py +""" +# -------------------------------------------------------- + +from typing import Dict, Optional, Tuple +import torch +import torch.nn as nn +import argparse + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary +from cube.runtime import reducer +from cube.runtime.device import DeviceGroup +from cube.runtime.reducer import Reducer + +from examples.swin.layers import ColumnParallelLinear, RowParallelLinear + + +_reducer_groups: Dict[Tuple[int], Reducer] = dict() + + +def setup_device_group(tp: int, wp: int, dp: int, layer_id: int): + """ + Layer wise device group initialize + + Returns: + + """ + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + tp_size, tp_group_nums = tp, ndevs // tp + wp_size, wp_group_nums = wp, ndevs // wp + dp_size, dp_group_nums = dp, ndevs // dp + + if not tp_size * wp_size * dp_size == ndevs: + raise RuntimeError("Expected all devices are used") + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize tensor parallel groups + for i in range(tp_group_nums): + ranks = list(range(i * tp_size, (i + 1) * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + resource.tp_group = group + print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + + # initialize wp parallel group + all_wp_parallel_group_ranks = list() + for i in range(dp_size): + start_rank = i * dp_group_nums + end_rank = (i + 1) * dp_group_nums + for j in range(tp_size): + ranks = list(range(start_rank + j, end_rank, tp_size)) + all_wp_parallel_group_ranks.append(ranks) + # initialize groups + group = devs.get_group(ranks) + if myrank in ranks: + wp_ranks = ranks + _reducer_groups[tuple(ranks)] = Reducer(ranks) + print_each_rank(f'layer {layer_id}: initialzed window parallel group: {wp_ranks}', rank_only=myrank) + + # initialize data parallel groups + start_rank = 0 + end_rank = ndevs + for i in range(wp_size * tp_size): + ranks = list(range(i, ndevs, wp_size * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + _reducer_groups[tuple(ranks)] = Reducer(ranks) + print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) + return tp_ranks, wp_ranks, dp_ranks + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class MegatronMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tp_group=-1): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, full_input=True, full_output=False, tp_group=tp_group) + self.act = act_layer() + # self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = RowParallelLinear(hidden_features, out_features, full_input=False, full_output=True, tp_group=tp_group) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +class MegatronWindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., tp_group=-1): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.global_num_heads = num_heads + + tp_world_size = torch.distributed.get_world_size(group=tp_group) + if num_heads % tp_world_size != 0: + print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') + self.num_heads = num_heads // tp_world_size + + self.dim_heads = dim // self.global_num_heads + self.scale = qk_scale or self.dim_heads ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # print(f'qkv embed dim: {dim}') + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False, tp_group=tp_group) + self.attn_drop = nn.Dropout(attn_drop) + # self.proj = nn.Linear(dim, dim) + self.proj = RowParallelLinear(dim, dim, full_input=False, full_output=True, tp_group=tp_group) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + relative_position_bias = self.relative_position_bias_table[relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + tp_group=-1, wp_group=-1): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = MegatronWindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + tp_group=tp_group) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MegatronMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop, + tp_group=tp_group + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + # [B, H, W, C] -> [B, H * W, C] + x = x.view(B, H * W, C) + # [B, H * W, C] -> [B, H * W, C] + x = shortcut + drop_path(x, self.drop_path_p) + # FFN + # [B, H * W, C] -> [B, H * W, C] + ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] + ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] + x = x + drop_path(ffn, self.drop_path_p) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, + tp=1, wp=1, dp=1, layer_id=-1): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + self.resource = cube.runtime.resource.EnvResource() + tp_ranks, wp_ranks, dp_ranks = setup_device_group(tp, wp, dp, layer_id) + tp_group = DeviceGroup().get_group(tp_ranks) + wp_group = DeviceGroup().get_group(wp_ranks) + + # build blocks + self.blocks = nn.ModuleList() + for i in range(depth): + block = SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + tp_group=tp_group, wp_group=wp_group + ) + self.blocks.append(block) + + if wp > 1: + for param in self.blocks.parameters(): + _reducer_groups[tuple(wp_ranks)].add_param() + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + return x + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, dp=1, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + + # ====================== depth 0 =========================== + pconfig = dict(layer_id=0, tp=4, wp=1, dp=dp) + input_resolution = ( + patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) + ) + self.basic_layer0 = BasicLayer( + dim=int(embed_dim * 2 ** 0), + input_resolution=input_resolution, + depth=depths[0], + num_heads=num_heads[0], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:0]):sum(depths[:0 + 1])], + norm_layer=norm_layer, + **pconfig, + ) + + self.merging0 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 0), norm_layer=norm_layer + ) + + # ====================== depth 1 =========================== + pconfig = dict(layer_id=1, tp=4, wp=1, dp=dp) + input_resolution = ( + patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) + ) + self.basic_layer1 = BasicLayer( + dim=int(embed_dim * 2 ** 1), + input_resolution=input_resolution, + depth=depths[1], + num_heads=num_heads[1], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:1]):sum(depths[:1 + 1])], + norm_layer=norm_layer, + **pconfig, + ) + + self.merging1 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 1), norm_layer=norm_layer + ) + + + # ====================== depth 2 =========================== + pconfig = dict(layer_id=2, tp=4, wp=1, dp=dp) + input_resolution = ( + patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) + ) + self.basic_layer2 = BasicLayer( + dim=int(embed_dim * 2 ** 2), + input_resolution=input_resolution, + depth=depths[2], + num_heads=num_heads[2], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:2]):sum(depths[:2 + 1])], + norm_layer=norm_layer, + **pconfig + ) + + self.merging2 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 2), norm_layer=norm_layer + ) + + # ====================== depth 3 =========================== + pconfig = dict(layer_id=3, tp=4, wp=1, dp=dp) + self.basic_layer3 = BasicLayer( + dim=int(embed_dim * 2 ** 3), + input_resolution=(patches_resolution[0] // (2 ** 3), + patches_resolution[1] // (2 ** 3)), + depth=depths[3], + num_heads=num_heads[3], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:3]):sum(depths[:3 + 1])], + norm_layer=norm_layer, + **pconfig + ) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x) + + x = self.basic_layer0(x) + x = self.merging0(x) + x = self.basic_layer1(x) + x = self.merging1(x) + x = self.basic_layer2(x) + x = self.merging2(x) + x = self.basic_layer3(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C L + x = torch.flatten(x, 1) + + x = self.head(x) + return x + + +def train(args): + resource = cube.runtime.resource.EnvResource() + + # image batch input + N, C, H, W = [1, 3, 224, 224] + + # embed_dim, depths, num_heads, window_size = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 + # ] + + # 348.55 M + embed_dim, depths, num_heads, window_size = [ + 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + ] + + # 895.7 M Model + # embed_dim, depths, num_heads, window_size = [ + # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # ] + + # 2.01B model + # embed_dim, depths, num_heads, window_size = [ + # 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # ] + + + model = SwinTransformer(embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size) + model = model.cuda() + memory_summary() + + # setup data parallel reducer + reducer = None + if args.dp > 1: + print('> initialize weight reducer') + reducer = resource.reducer + for param in model.parameters(): + reducer.add_param(param) + + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [N // args.dp, C, H, W]) + + def train_iter(model, dataloader): + img = next(dataloader) + loss = model(img) + loss = torch.sum(loss) + loss.backward() + if reducer is not None: + reducer.allreduce() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + # start training + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on 1st iteration') + memory_summary() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + + +if __name__ == '__main__': + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--tp', type=int, default=1, + help='tensor parallel size') + parser.add_argument('--wp', type=int, default=1, + help='data parallel size') + parser.add_argument('--dp', type=int, default=1, + help='pipeline parallel size') + parser.add_argument('--micro-bs', type=int, default=-1) + args = parser.parse_args() + + cube.init() + + """ + # allocate resource + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + tp_size, tp_group_nums = args.tp, ndevs // args.tp + wp_size, wp_group_nums = args.wp, ndevs // args.wp + dp_size, dp_group_nums = args.dp, ndevs // args.dp + + if not tp_size * wp_size * dp_size == ndevs: + raise RuntimeError("Expected all devices are used") + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize tensor parallel groups + for i in range(tp_group_nums): + ranks = list(range(i * tp_size, (i + 1) * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + resource.tp_group = group + print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + + # initialize wp parallel group + all_wp_parallel_group_ranks = list() + for i in range(dp_size): + start_rank = i * dp_group_nums + end_rank = (i + 1) * dp_group_nums + for j in range(tp_size): + ranks = list(range(start_rank + j, end_rank, tp_size)) + all_wp_parallel_group_ranks.append(ranks) + # initialize groups + group = devs.get_group(ranks) + if myrank in ranks: + wp_ranks = ranks + resource.wp_group = group + resource.wp_reducer = cube.runtime.reducer.Reducer(ranks) + print_each_rank(f'initialzed window parallel group: {wp_ranks}', rank_only=myrank) + + # initialize data parallel groups + start_rank = 0 + end_rank = ndevs + for i in range(wp_size * tp_size): + ranks = list(range(i, ndevs, wp_size * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + resource.dp_group = group + resource.dp_reducer = cube.runtime.reducer.Reducer(ranks) + print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) + """ + train(args) From 9441a21d9122f3fa92b31056cd15f9762734cd56 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 27 Nov 2021 20:53:45 +0800 Subject: [PATCH 0422/1892] swin for window parallel --- cube/profiler/timer.py | 21 +- cube/runtime/function/dist.py | 68 ++- examples/swin/layers.py | 8 + examples/swin/swin_1.py | 794 ---------------------------------- examples/swin/swin_348M.py | 104 ++++- 5 files changed, 164 insertions(+), 831 deletions(-) delete mode 100644 examples/swin/swin_1.py diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index bc1481f5..f5e12e13 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -29,17 +29,26 @@ class CudaTimer: """ class __CudaTimer: - def __init__(self): + def __init__(self, **kwargs): self.start_t = None self.stop_t = None self.field = dict() self.field_data = dict() + self.enabled = True + if 'enable' in kwargs: + self.enabled = kwargs['enable'] instance = None - def __init__(self): + def __init__(self, enable = None): if not CudaTimer.instance: - CudaTimer.instance = CudaTimer.__CudaTimer() + kwargs = dict() + if enable is not None: + kwargs = dict(enable=enable) + CudaTimer.instance = CudaTimer.__CudaTimer(**kwargs) + elif enable is not None: + CudaTimer.instance.enabled = enable + def start(self, field_name='default'): """ @@ -47,6 +56,8 @@ def start(self, field_name='default'): Note `start` and `stop` on the same field can be called nestly """ + if not CudaTimer.instance.enabled: + return torch.cuda.synchronize() if field_name not in CudaTimer.instance.field: CudaTimer.instance.field[field_name] = list() @@ -60,6 +71,8 @@ def stop(self, field_name='default'): Returns: float (ms) """ + if not CudaTimer.instance.enabled: + return if field_name not in CudaTimer.instance.field: raise RuntimeError("Missing start on the field") torch.cuda.synchronize() @@ -90,7 +103,7 @@ def print_all(self, times): if 'send' in field_name or 'recv' in field_name: comm_span += span msg.append('{} : {:.2f} ms'.format(field_name, span)) - msg.append('{} : {:.2f} ms'.format('communication', comm_span)) + # msg.append('{} : {:.2f} ms'.format('communication', comm_span)) msg = ' | '.join(msg) print_each_rank(msg) diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py index aa07b2d6..c99ae874 100644 --- a/cube/runtime/function/dist.py +++ b/cube/runtime/function/dist.py @@ -1,8 +1,11 @@ +from typing import Tuple, List import torch from torch.distributed.distributed_c10d import _get_global_rank from cube.profiler.timer import print_each_rank +from cube.profiler.timer import CudaTimer + def get_global_rank(group, group_rank): if group is None: @@ -11,7 +14,7 @@ def get_global_rank(group, group_rank): return _get_global_rank(group, group_rank) -def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): +def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, group): """ partition torch.roll at shifted dimension @@ -20,10 +23,11 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): shift: int dim: int """ - world_size = torch.distributed.get_world_size(group) + world_size = len(dim_ranks) if world_size == 1: return torch.roll(input, (shift), (dim,)) - rank = torch.distributed.get_rank(group) + global_rank = torch.distributed.get_rank() + dim_rank = dim_ranks.index(torch.distributed.get_rank(group)) # halo exchange at H dimension if shift < 0: shift = 0 - shift @@ -38,26 +42,27 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): recv_tensor = torch.empty_like(remote) # send to next rank and recv from prevous rank - # print_each_rank(f'send to {(rank - 1 + world_size) % world_size}, recv from {(rank + 1) % world_size}') - send_local_rank = (rank - 1 + world_size) % world_size + send_local_rank = dim_ranks[(dim_rank - 1 + world_size) % world_size] send_global_rank = get_global_rank(group, send_local_rank) - recv_local_rank = (rank + 1) % world_size + recv_local_rank = dim_ranks[(dim_rank + 1) % world_size] recv_global_rank = get_global_rank(group, recv_local_rank) + # print_each_rank(f'send to {send_global_rank}, recv from {recv_global_rank}') send_op = torch.distributed.P2POp( torch.distributed.isend, remote, - send_global_rank, group=group + send_global_rank, group=group, tag=global_rank ) recv_op = torch.distributed.P2POp( torch.distributed.irecv, recv_tensor, - recv_global_rank, group=group + recv_global_rank, group=group, tag=recv_global_rank ) - ops = [send_op, recv_op] if rank % 2 == 0 else [recv_op, send_op] + ops = [send_op, recv_op] if dim_rank % 2 == 0 else [recv_op, send_op] reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: req.wait() tensor = torch.cat((local, recv_tensor), dim=dim).contiguous() return tensor + elif shift > 0: boundary = input.shape[dim] - shift if dim == 1: @@ -71,20 +76,21 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): recv_tensor = torch.empty_like(remote) # to global rank - send_local_rank = (rank + 1) % world_size + send_local_rank = dim_ranks[(dim_rank + 1) % world_size] send_global_rank = get_global_rank(group, send_local_rank) - recv_local_rank = (rank - 1 + world_size) % world_size + recv_local_rank = dim_ranks[(dim_rank - 1 + world_size) % world_size] recv_global_rank = get_global_rank(group, recv_local_rank) + # print_each_rank(f'send to {send_global_rank}, recv from {recv_global_rank}') send_op = torch.distributed.P2POp( torch.distributed.isend, remote, - send_global_rank, group=group + send_global_rank, group=group, tag=global_rank ) recv_op = torch.distributed.P2POp( torch.distributed.irecv, recv_tensor, - recv_global_rank, group=group + recv_global_rank, group=group, tag=recv_global_rank ) - ops = [send_op, recv_op] if rank % 2 == 0 else [recv_op, send_op] + ops = [send_op, recv_op] if dim_rank % 2 == 0 else [recv_op, send_op] reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: req.wait() @@ -130,23 +136,29 @@ class RollDimParallel(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, shift: int, dim: int, group=None): + def forward(ctx, input_, shift: int, dim: int, dim_ranks: List[int], group=None): + CudaTimer().start(field_name='roll parallel_fw') ctx.shift = shift ctx.dim = dim ctx.group = group - output = _roll_dim_parallel(input_, shift, dim, group) + ctx.dim_ranks = dim_ranks + output = _roll_dim_parallel(input_, shift, dim, dim_ranks, group) + CudaTimer().stop(field_name='roll parallel_fw') return output @staticmethod def backward(ctx, grad_output): + CudaTimer().start(field_name='roll parallel_bw') shift = ctx.shift dim = ctx.dim group = ctx.group - grad = _roll_dim_parallel(grad_output, 0-shift, dim, group) - return grad, None, None, None + dim_ranks = ctx.dim_ranks + grad = _roll_dim_parallel(grad_output, 0-shift, dim, dim_ranks, group) + CudaTimer().stop(field_name='roll parallel_bw') + return grad, None, None, None, None -def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): +def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, group): """ partition torch.roll at shifted dimension @@ -155,7 +167,15 @@ def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, group): shift: int dim: int """ - return RollDimParallel.apply(input, shift, dim, group) + return RollDimParallel.apply(input, shift, dim, dim_ranks, group) + + +def roll_grid_parallel(input: torch.Tensor, + shifts: Tuple[int, int], dims: Tuple[int, int], + nh_group_ranks: List[int], nw_group_ranks: List[int], group): + input = roll_dim_parallel(input, shifts[0], 1, nh_group_ranks, group) + input = roll_dim_parallel(input, shifts[1], 2, nw_group_ranks, group) + return input class GridPartition(torch.autograd.Function): @@ -167,6 +187,7 @@ def forward(ctx, input_, nrow: int, ncol: int, group=None): """ input: [B, H, W, C] """ + CudaTimer().start(field_name='grid_partition_forward') ctx.group = group world_size = torch.distributed.get_world_size(group) ctx.nrow = nrow @@ -178,10 +199,12 @@ def forward(ctx, input_, nrow: int, ncol: int, group=None): chunk = torch.chunk(input_, nrow, dim=1)[myrow] chunk = torch.chunk(chunk, ncol, dim=2)[mycol].contiguous() + CudaTimer().stop(field_name='grid_partition_forward') return chunk @staticmethod def backward(ctx, grad_output): + CudaTimer().start(field_name='grid_partition_backward') group = ctx.group nrow = ctx.nrow ncol = ctx.ncol @@ -198,6 +221,7 @@ def backward(ctx, grad_output): row_slice = torch.cat(tuple(tensor_list[row*ncol:(row+1)*ncol]), dim=2) rows.append(row_slice) grad_output = torch.cat(tuple(rows), dim=1).contiguous() + CudaTimer().stop(field_name='grid_partition_backward') return grad_output, None, None, None @@ -211,6 +235,7 @@ def forward(ctx, input_, nrow: int, ncol: int, group=None): input: [B, H, W, C] output: [B, nrow * H, ncol * W, C] """ + CudaTimer().start(field_name='grid_collection_forward') ctx.group = group world_size = torch.distributed.get_world_size(group) ctx.nrow = nrow @@ -228,10 +253,12 @@ def forward(ctx, input_, nrow: int, ncol: int, group=None): row_slice = torch.cat(tuple(tensor_list[row*ncol:(row+1)*ncol]), dim=2) rows.append(row_slice) output = torch.cat(tuple(rows), dim=1).contiguous() + CudaTimer().stop(field_name='grid_collection_forward') return output @staticmethod def backward(ctx, grad_output): + CudaTimer().start(field_name='grid_collection_backward') group = ctx.group nrow = ctx.nrow ncol = ctx.ncol @@ -242,6 +269,7 @@ def backward(ctx, grad_output): chunk = torch.chunk(grad_output, nrow, dim=1)[myrow] chunk = torch.chunk(chunk, ncol, dim=2)[mycol].contiguous() + CudaTimer().stop(field_name='grid_collection_backward') return chunk, None, None, None diff --git a/examples/swin/layers.py b/examples/swin/layers.py index 88f11a51..9fcd3cd9 100644 --- a/examples/swin/layers.py +++ b/examples/swin/layers.py @@ -2,16 +2,20 @@ from torch import autograd import torch.nn.functional as F from torch.nn.parameter import Parameter +from cube.profiler.timer import CudaTimer def _reduce(input_, group): """All-reduce the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. + CudaTimer().start(field_name='tp_allreduce') world_size = torch.distributed.get_world_size(group) if world_size == 1: + CudaTimer().stop(field_name='tp_allreduce') return input_ torch.distributed.all_reduce(input_, group=group) + CudaTimer().stop(field_name='tp_allreduce') return input_ @@ -33,11 +37,13 @@ def _split(input_, group): def _gather(input_, group): """Gather tensors and concatinate along the last dimension.""" + CudaTimer().start(field_name='tp_allgather') world_size = torch.distributed.get_world_size(group=group) rank = torch.distributed.get_rank(group=group) # Bypass the function if we are using only 1 GPU. if world_size==1: + CudaTimer().stop(field_name='tp_allgather') return input_ # Size and dimension. last_dim = input_.dim() - 1 @@ -46,6 +52,8 @@ def _gather(input_, group): torch.distributed.all_gather(tensor_list, input_, group=group) # Note: torch.cat already creates a contiguous tensor. output = torch.cat(tensor_list, dim=last_dim).contiguous() + + CudaTimer().stop(field_name='tp_allgather') return output diff --git a/examples/swin/swin_1.py b/examples/swin/swin_1.py deleted file mode 100644 index 8663258d..00000000 --- a/examples/swin/swin_1.py +++ /dev/null @@ -1,794 +0,0 @@ - -# -------------------------------------------------------- -# Modified from Swin-Transformer Repo -""" -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_1.py -""" -# -------------------------------------------------------- - -from typing import Optional -import torch -import torch.nn as nn - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary - -import argparse - - -from examples.swin.layers import ColumnParallelLinear, RowParallelLinear - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class MegatronMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, full_input=True, full_output=False) - self.act = act_layer() - # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, full_input=False, full_output=True) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class MegatronWindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.global_num_heads = num_heads - group = cube.runtime.resource.EnvResource().tp_group - tp_world_size = torch.distributed.get_world_size(group=group) - if num_heads % tp_world_size != 0: - print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') - self.num_heads = num_heads // torch.distributed.get_world_size(group=group) - self.dim_heads = dim // self.global_num_heads - self.scale = qk_scale or self.dim_heads ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - # print(f'qkv embed dim: {dim}') - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False) - self.attn_drop = nn.Dropout(attn_drop) - # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, full_input=False, full_output=True) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - relative_position_bias = self.relative_position_bias_table[relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = MegatronWindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MegatronMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - x = x + drop_path(ffn, self.drop_path_p) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - print(f'input resolution: {input_resolution}') - print(f'window num: {input_resolution[0] * input_resolution[1] / window_size / window_size}') - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x): - for blk in self.blocks: - x = blk(x) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) - self.layers.append(layer) - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def forward_features(self, x): - x = self.patch_embed(x) - # if self.ape: - # x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - return x - - def forward(self, x): - # forward features - # x = self.forward_features(x) - x = self.patch_embed(x) - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - - x = self.head(x) - return x - - def flops(self): - flops = 0 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes - return flops - - -def train(args): - resource = cube.runtime.resource.EnvResource() - - # image batch input - N, C, H, W = [1, 3, 224, 224] - - # embed_dim, depths, num_heads, window_size = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 - # ] - - # 348.55 M - # embed_dim, depths, num_heads, window_size = [ - # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 - # ] - - # 895.7 M Model -- 224x224 - embed_dim, depths, num_heads, window_size = [ - 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 - ] - - # 2.01B model - # embed_dim, depths, num_heads, window_size = [ - # 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 - # ] - - - model = SwinTransformer(img_size = H, - embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size) - model = model.cuda() - memory_summary() - - # setup data parallel reducer - reducer = None - if args.dp > 1: - print('> initialize weight reducer') - reducer = resource.reducer - for param in model.parameters(): - reducer.add_param(param) - - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [N // args.dp, C, H, W]) - - def train_iter(model, dataloader): - img = next(dataloader) - loss = model(img) - loss = torch.sum(loss) - loss.backward() - if reducer is not None: - reducer.allreduce() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - # start training - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - - -if __name__ == '__main__': - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--tp', type=int, default=1, - help='tensor parallel size') - parser.add_argument('--dp', type=int, default=1, - help='data parallel size') - parser.add_argument('--pp', type=int, default=1, - help='pipeline parallel size') - parser.add_argument('--wp', type=int, default=1) - parser.add_argument('--micro-bs', type=int, default=-1) - args = parser.parse_args() - - cube.init() - - # allocate resource - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - tp_size, tp_group_nums = args.tp, ndevs // args.tp - dp_size, dp_group_nums = args.dp, ndevs // args.dp - pp_size, pp_group_nums = args.pp, ndevs // args.pp - - if not pp_size * dp_size * tp_size == ndevs: - raise RuntimeError("Expected all devices are used") - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize data parallel group - all_data_parallel_group_ranks = list() - for i in range(pp_size): - start_rank = i * pp_group_nums - end_rank = (i + 1) * pp_group_nums - for j in range(tp_size): - ranks = list(range(start_rank + j, end_rank, tp_size)) - all_data_parallel_group_ranks.append(ranks) - # initialize groups - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - resource.dp_group = group - resource.reducer = cube.runtime.reducer.Reducer(ranks) - print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) - - # initialize pipelne parallel groups - for i in range(dp_size): - ranks = [data_parallel_group_ranks[i] - for data_parallel_group_ranks in all_data_parallel_group_ranks] - group = devs.get_group(ranks) - if myrank in ranks: - pp_ranks = ranks - resource.pp_group = group - print_each_rank(f'initialzed pipeline parallel group: {pp_ranks}', rank_only=myrank) - - # initialize tensor parallel groups - for i in range(tp_group_nums): - ranks = list(range(i * tp_size, (i + 1) * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - resource.tp_group = group - print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - train(args) diff --git a/examples/swin/swin_348M.py b/examples/swin/swin_348M.py index 5c2e1a9f..8e2a9059 100644 --- a/examples/swin/swin_348M.py +++ b/examples/swin/swin_348M.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn import argparse +import time import cube from cube.profiler import CudaTimer @@ -273,7 +274,7 @@ class SwinTransformerBlock(nn.Module): def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - tp_group=-1, wp_group=-1): + tp_group=-1, wp_plans=-1): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -304,6 +305,7 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 ) if self.shift_size > 0: + self.wp_group, self.wp_nH_ranks, self.wp_nW_ranks = wp_plans # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 @@ -339,7 +341,11 @@ def forward(self, x): # cyclic shift if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_x = cube.runtime.function.roll_grid_parallel( + x, (-self.shift_size, -self.shift_size), (1,2), + self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group + ) + # shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x @@ -363,7 +369,11 @@ def forward(self, x): # [B, H', W', C] -> [B, H, W, C] x = shifted_x if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = cube.runtime.function.roll_grid_parallel( + shifted_x, (self.shift_size, self.shift_size), (1,2), + self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group + ) + # x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) # [B, H, W, C] -> [B, H * W, C] x = x.view(B, H * W, C) # [B, H * W, C] -> [B, H * W, C] @@ -450,12 +460,47 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, tp_ranks, wp_ranks, dp_ranks = setup_device_group(tp, wp, dp, layer_id) tp_group = DeviceGroup().get_group(tp_ranks) wp_group = DeviceGroup().get_group(wp_ranks) + wp_nH_ranks = [-1] + wp_nW_ranks = [-1] + + # window parallel + self.wp_resolution = input_resolution + if wp > 1: + H, W = self.input_resolution + nH = 1 + nW = wp // nH + while nH <= nW: + if H % nH != 0 or W % nW != 0: + nW = nW // 2 + nH = int(nH * 2) + else: + break + if nH > nW: + raise RuntimeError(f"layer {layer_id}: Cannot window partition plan") + print_each_rank(f"layer {layer_id}: Find partition plan: Width // {nW}, Height // {nH}") + self.wp_resolution = (H // nH, W // nW) + self.wp_group = wp_group + # wp_group multi dim shift ranks + for i in range(nW): + ranks = list(range(i * nH, (i + 1) * nH)) + if torch.distributed.get_rank(wp_group) in ranks: + wp_nW_ranks = ranks + break + for i in range(nH): + ranks = list(range(i, wp, nH)) + if torch.distributed.get_rank(wp_group) in ranks: + wp_nH_ranks = ranks + break + assert wp_nH_ranks != [-1] + assert wp_nW_ranks != [-1] + print_each_rank(f'window parallel nH local ranks: {wp_nH_ranks}') + print_each_rank(f'window parallel nW local ranks: {wp_nW_ranks}') # build blocks self.blocks = nn.ModuleList() for i in range(depth): block = SwinTransformerBlock( - dim=dim, input_resolution=input_resolution, + dim=dim, input_resolution=self.wp_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, @@ -463,17 +508,35 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, - tp_group=tp_group, wp_group=wp_group + tp_group=tp_group, wp_plans=(wp_group, wp_nH_ranks, wp_nW_ranks) ) self.blocks.append(block) - + + self.wp_preprocess = False + self.wp_postprocess = False if wp > 1: for param in self.blocks.parameters(): - _reducer_groups[tuple(wp_ranks)].add_param() + _reducer_groups[tuple(wp_ranks)].add_param(param) + self.wp_preprocess = True + self.wp_postprocess = True def forward(self, x): + if self.wp_preprocess: + oH, oW = self.input_resolution + pH, pW = self.wp_resolution + x = x.view(-1, oH, oW, self.dim) + x = cube.runtime.function.grid_partition(x, oH // pH, oW // pW, group=self.wp_group) + x = x.view(-1, pH * pW, self.dim).contiguous() + for blk in self.blocks: x = blk(x) + + if self.wp_postprocess: + oH, oW = self.input_resolution + pH, pW = self.wp_resolution + x = x.view(-1, pH, pW, self.dim) + x = cube.runtime.function.grid_collection(x, oH // pH, oW // pW, group=self.wp_group) + x = x.view(-1, oH * oW, self.dim) return x @@ -583,7 +646,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # ====================== depth 0 =========================== - pconfig = dict(layer_id=0, tp=4, wp=1, dp=dp) + pconfig = dict(layer_id=0, tp=1, wp=4, dp=dp) input_resolution = ( patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) ) @@ -606,7 +669,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, ) # ====================== depth 1 =========================== - pconfig = dict(layer_id=1, tp=4, wp=1, dp=dp) + pconfig = dict(layer_id=1, tp=1, wp=4, dp=dp) input_resolution = ( patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) ) @@ -630,7 +693,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # ====================== depth 2 =========================== - pconfig = dict(layer_id=2, tp=4, wp=1, dp=dp) + pconfig = dict(layer_id=2, tp=1, wp=4, dp=dp) input_resolution = ( patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) ) @@ -688,13 +751,21 @@ def forward(self, x): x = self.patch_embed(x) x = self.pos_drop(x) + CudaTimer().start('basic_layer0') x = self.basic_layer0(x) + CudaTimer().stop('basic_layer0') x = self.merging0(x) + CudaTimer().start('basic_layer1') x = self.basic_layer1(x) + CudaTimer().stop('basic_layer1') x = self.merging1(x) + CudaTimer().start('basic_layer2') x = self.basic_layer2(x) + CudaTimer().stop('basic_layer2') x = self.merging2(x) + CudaTimer().start('basic_layer3') x = self.basic_layer3(x) + CudaTimer().stop('basic_layer3') x = self.norm(x) # B L C x = self.avgpool(x.transpose(1, 2)) # B C L @@ -762,12 +833,15 @@ def train_iter(model, dataloader): nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - CudaTimer().warmup() + CudaTimer(enable=False).warmup() torch.distributed.barrier() + span = 0 iter_num = 128 for step in range(iter_num): if step >= 40: - CudaTimer().start('e2e') + torch.cuda.synchronize() + start = time.time() + CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() @@ -775,16 +849,20 @@ def train_iter(model, dataloader): print('> passed on 1st iteration') memory_summary() if step >= 40: + torch.cuda.synchronize() + stop = time.time() + span += (stop - start) * 1000 CudaTimer().stop('e2e') if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + iter_time = span / (iter_num-40) throughput = N / iter_time * 1000 print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) memory_summary() + CudaTimer().print_all(times=iter_num-40) if __name__ == '__main__': From c279548d3ab54551ac317d39a10aa4fedc5c6851 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 27 Nov 2021 22:41:41 +0800 Subject: [PATCH 0423/1892] allreduce for wp --- examples/swin/swin_348M.py | 116 ++++++++++--------------------------- 1 file changed, 29 insertions(+), 87 deletions(-) diff --git a/examples/swin/swin_348M.py b/examples/swin/swin_348M.py index 8e2a9059..ae9d0102 100644 --- a/examples/swin/swin_348M.py +++ b/examples/swin/swin_348M.py @@ -60,7 +60,6 @@ def setup_device_group(tp: int, wp: int, dp: int, layer_id: int): group = devs.get_group(ranks) if myrank in ranks: tp_ranks = ranks - resource.tp_group = group print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) # initialize wp parallel group @@ -193,7 +192,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at tp_world_size = torch.distributed.get_world_size(group=tp_group) if num_heads % tp_world_size != 0: - print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') + raise RuntimeError(f'detecting un-even num head {num_heads} partition to {tp_world_size}') self.num_heads = num_heads // tp_world_size self.dim_heads = dim // self.global_num_heads @@ -493,8 +492,8 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, break assert wp_nH_ranks != [-1] assert wp_nW_ranks != [-1] - print_each_rank(f'window parallel nH local ranks: {wp_nH_ranks}') - print_each_rank(f'window parallel nW local ranks: {wp_nW_ranks}') + # print_each_rank(f'window parallel nH local ranks: {wp_nH_ranks}') + # print_each_rank(f'window parallel nW local ranks: {wp_nW_ranks}') # build blocks self.blocks = nn.ModuleList() @@ -615,7 +614,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, dp=1, **kwargs): + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, **kwargs): super().__init__() self.num_classes = num_classes @@ -646,7 +645,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # ====================== depth 0 =========================== - pconfig = dict(layer_id=0, tp=1, wp=4, dp=dp) + pconfig = pconfigs[0] input_resolution = ( patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) ) @@ -669,7 +668,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, ) # ====================== depth 1 =========================== - pconfig = dict(layer_id=1, tp=1, wp=4, dp=dp) + pconfig = pconfigs[1] input_resolution = ( patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) ) @@ -693,7 +692,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # ====================== depth 2 =========================== - pconfig = dict(layer_id=2, tp=1, wp=4, dp=dp) + pconfig = pconfigs[2] input_resolution = ( patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) ) @@ -716,7 +715,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, ) # ====================== depth 3 =========================== - pconfig = dict(layer_id=3, tp=4, wp=1, dp=dp) + pconfig = pconfigs[3] self.basic_layer3 = BasicLayer( dim=int(embed_dim * 2 ** 3), input_resolution=(patches_resolution[0] // (2 ** 3), @@ -775,9 +774,7 @@ def forward(self, x): return x -def train(args): - resource = cube.runtime.resource.EnvResource() - +def train(args, pconfigs): # image batch input N, C, H, W = [1, 3, 224, 224] @@ -786,15 +783,15 @@ def train(args): # ] # 348.55 M - embed_dim, depths, num_heads, window_size = [ - 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 - ] - - # 895.7 M Model # embed_dim, depths, num_heads, window_size = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 # ] + # 895.7 M Model + embed_dim, depths, num_heads, window_size = [ + 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 + ] + # 2.01B model # embed_dim, depths, num_heads, window_size = [ # 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 @@ -804,18 +801,11 @@ def train(args): model = SwinTransformer(embed_dim = embed_dim, depths = depths, num_heads = num_heads, - window_size = window_size) + window_size = window_size, + pconfigs = pconfigs) model = model.cuda() memory_summary() - # setup data parallel reducer - reducer = None - if args.dp > 1: - print('> initialize weight reducer') - reducer = resource.reducer - for param in model.parameters(): - reducer.add_param(param) - dataloader = cube.runtime.syndata.SynDataLoader( 1280, [0], [N // args.dp, C, H, W]) @@ -824,8 +814,11 @@ def train_iter(model, dataloader): loss = model(img) loss = torch.sum(loss) loss.backward() - if reducer is not None: + CudaTimer().start('wp_allreduce') + for ranks in _reducer_groups: + reducer = _reducer_groups[ranks] reducer.allreduce() + CudaTimer().stop('wp_allreduce') optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) @@ -869,68 +862,17 @@ def train_iter(model, dataloader): # resource allocation parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--tp', type=int, default=1, - help='tensor parallel size') - parser.add_argument('--wp', type=int, default=1, - help='data parallel size') parser.add_argument('--dp', type=int, default=1, help='pipeline parallel size') parser.add_argument('--micro-bs', type=int, default=-1) args = parser.parse_args() - cube.init() - - """ - # allocate resource - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - tp_size, tp_group_nums = args.tp, ndevs // args.tp - wp_size, wp_group_nums = args.wp, ndevs // args.wp - dp_size, dp_group_nums = args.dp, ndevs // args.dp - - if not tp_size * wp_size * dp_size == ndevs: - raise RuntimeError("Expected all devices are used") - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize tensor parallel groups - for i in range(tp_group_nums): - ranks = list(range(i * tp_size, (i + 1) * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - resource.tp_group = group - print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - # initialize wp parallel group - all_wp_parallel_group_ranks = list() - for i in range(dp_size): - start_rank = i * dp_group_nums - end_rank = (i + 1) * dp_group_nums - for j in range(tp_size): - ranks = list(range(start_rank + j, end_rank, tp_size)) - all_wp_parallel_group_ranks.append(ranks) - # initialize groups - group = devs.get_group(ranks) - if myrank in ranks: - wp_ranks = ranks - resource.wp_group = group - resource.wp_reducer = cube.runtime.reducer.Reducer(ranks) - print_each_rank(f'initialzed window parallel group: {wp_ranks}', rank_only=myrank) + pconfigs = [ + dict(layer_id=1, tp=4, wp=2, dp=args.dp), # basic layer 0 + dict(layer_id=2, tp=4, wp=2, dp=args.dp), # basic layer 1 + dict(layer_id=3, tp=4, wp=2, dp=args.dp), # basic layer 2 + dict(layer_id=4, tp=8, wp=1, dp=args.dp), # basic layer 3 + ] - # initialize data parallel groups - start_rank = 0 - end_rank = ndevs - for i in range(wp_size * tp_size): - ranks = list(range(i, ndevs, wp_size * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - resource.dp_group = group - resource.dp_reducer = cube.runtime.reducer.Reducer(ranks) - print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) - """ - train(args) + cube.init() + train(args, pconfigs) From 42dc60b67794669828cf9b9aadc4639b56be93db Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 28 Nov 2021 00:21:59 +0800 Subject: [PATCH 0424/1892] fix bugs on roll partition --- cube/runtime/function/dist.py | 25 +++++++++++++------------ examples/swin/swin_348M.py | 19 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py index c99ae874..67ff4aa5 100644 --- a/cube/runtime/function/dist.py +++ b/cube/runtime/function/dist.py @@ -23,6 +23,7 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, gro shift: int dim: int """ + return input world_size = len(dim_ranks) if world_size == 1: return torch.roll(input, (shift), (dim,)) @@ -137,24 +138,24 @@ class RollDimParallel(torch.autograd.Function): """ @staticmethod def forward(ctx, input_, shift: int, dim: int, dim_ranks: List[int], group=None): - CudaTimer().start(field_name='roll parallel_fw') + CudaTimer().start(field_name='roll parallel') ctx.shift = shift ctx.dim = dim ctx.group = group ctx.dim_ranks = dim_ranks output = _roll_dim_parallel(input_, shift, dim, dim_ranks, group) - CudaTimer().stop(field_name='roll parallel_fw') + CudaTimer().stop(field_name='roll parallel') return output @staticmethod def backward(ctx, grad_output): - CudaTimer().start(field_name='roll parallel_bw') + CudaTimer().start(field_name='roll parallel') shift = ctx.shift dim = ctx.dim group = ctx.group dim_ranks = ctx.dim_ranks grad = _roll_dim_parallel(grad_output, 0-shift, dim, dim_ranks, group) - CudaTimer().stop(field_name='roll parallel_bw') + CudaTimer().stop(field_name='roll parallel') return grad, None, None, None, None @@ -187,7 +188,7 @@ def forward(ctx, input_, nrow: int, ncol: int, group=None): """ input: [B, H, W, C] """ - CudaTimer().start(field_name='grid_partition_forward') + CudaTimer().start(field_name='grid_partition') ctx.group = group world_size = torch.distributed.get_world_size(group) ctx.nrow = nrow @@ -199,12 +200,12 @@ def forward(ctx, input_, nrow: int, ncol: int, group=None): chunk = torch.chunk(input_, nrow, dim=1)[myrow] chunk = torch.chunk(chunk, ncol, dim=2)[mycol].contiguous() - CudaTimer().stop(field_name='grid_partition_forward') + CudaTimer().stop(field_name='grid_partition') return chunk @staticmethod def backward(ctx, grad_output): - CudaTimer().start(field_name='grid_partition_backward') + CudaTimer().start(field_name='grid_partition') group = ctx.group nrow = ctx.nrow ncol = ctx.ncol @@ -221,7 +222,7 @@ def backward(ctx, grad_output): row_slice = torch.cat(tuple(tensor_list[row*ncol:(row+1)*ncol]), dim=2) rows.append(row_slice) grad_output = torch.cat(tuple(rows), dim=1).contiguous() - CudaTimer().stop(field_name='grid_partition_backward') + CudaTimer().stop(field_name='grid_partition') return grad_output, None, None, None @@ -235,7 +236,7 @@ def forward(ctx, input_, nrow: int, ncol: int, group=None): input: [B, H, W, C] output: [B, nrow * H, ncol * W, C] """ - CudaTimer().start(field_name='grid_collection_forward') + CudaTimer().start(field_name='grid_collection') ctx.group = group world_size = torch.distributed.get_world_size(group) ctx.nrow = nrow @@ -253,12 +254,12 @@ def forward(ctx, input_, nrow: int, ncol: int, group=None): row_slice = torch.cat(tuple(tensor_list[row*ncol:(row+1)*ncol]), dim=2) rows.append(row_slice) output = torch.cat(tuple(rows), dim=1).contiguous() - CudaTimer().stop(field_name='grid_collection_forward') + CudaTimer().stop(field_name='grid_collection') return output @staticmethod def backward(ctx, grad_output): - CudaTimer().start(field_name='grid_collection_backward') + CudaTimer().start(field_name='grid_collection') group = ctx.group nrow = ctx.nrow ncol = ctx.ncol @@ -269,7 +270,7 @@ def backward(ctx, grad_output): chunk = torch.chunk(grad_output, nrow, dim=1)[myrow] chunk = torch.chunk(chunk, ncol, dim=2)[mycol].contiguous() - CudaTimer().stop(field_name='grid_collection_backward') + CudaTimer().stop(field_name='grid_collection') return chunk, None, None, None diff --git a/examples/swin/swin_348M.py b/examples/swin/swin_348M.py index ae9d0102..27c2f323 100644 --- a/examples/swin/swin_348M.py +++ b/examples/swin/swin_348M.py @@ -23,7 +23,6 @@ from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary -from cube.runtime import reducer from cube.runtime.device import DeviceGroup from cube.runtime.reducer import Reducer @@ -476,24 +475,24 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, break if nH > nW: raise RuntimeError(f"layer {layer_id}: Cannot window partition plan") - print_each_rank(f"layer {layer_id}: Find partition plan: Width // {nW}, Height // {nH}") + print_each_rank(f"layer {layer_id}: Find partition plan: H{H} // {nH}, W{W} // {nW}") self.wp_resolution = (H // nH, W // nW) self.wp_group = wp_group # wp_group multi dim shift ranks - for i in range(nW): - ranks = list(range(i * nH, (i + 1) * nH)) + for i in range(nH): + ranks = list(range(i * nW, (i + 1) * nW)) if torch.distributed.get_rank(wp_group) in ranks: wp_nW_ranks = ranks break - for i in range(nH): - ranks = list(range(i, wp, nH)) + for i in range(nW): + ranks = list(range(i, wp, nW)) if torch.distributed.get_rank(wp_group) in ranks: wp_nH_ranks = ranks break assert wp_nH_ranks != [-1] assert wp_nW_ranks != [-1] - # print_each_rank(f'window parallel nH local ranks: {wp_nH_ranks}') - # print_each_rank(f'window parallel nW local ranks: {wp_nW_ranks}') + print_each_rank(f'window parallel nH group ranks: {wp_nH_ranks}') + print_each_rank(f'window parallel nW group ranks: {wp_nW_ranks}') # build blocks self.blocks = nn.ModuleList() @@ -868,9 +867,9 @@ def train_iter(model, dataloader): args = parser.parse_args() pconfigs = [ - dict(layer_id=1, tp=4, wp=2, dp=args.dp), # basic layer 0 + dict(layer_id=1, tp=2, wp=4, dp=args.dp), # basic layer 0 dict(layer_id=2, tp=4, wp=2, dp=args.dp), # basic layer 1 - dict(layer_id=3, tp=4, wp=2, dp=args.dp), # basic layer 2 + dict(layer_id=3, tp=2, wp=4, dp=args.dp), # basic layer 2 # prob at 8:1? dict(layer_id=4, tp=8, wp=1, dp=args.dp), # basic layer 3 ] From 9f6dd2e5f54ee7f529c314296266da6f5a179016 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 28 Nov 2021 16:40:55 +0800 Subject: [PATCH 0425/1892] add support on data, window, tensorr parallel --- examples/swin/{swin_348M.py => swin_dwt.py} | 127 +++++++++++++++----- examples/swin/swin_transformer.py | 14 ++- 2 files changed, 104 insertions(+), 37 deletions(-) rename examples/swin/{swin_348M.py => swin_dwt.py} (90%) diff --git a/examples/swin/swin_348M.py b/examples/swin/swin_dwt.py similarity index 90% rename from examples/swin/swin_348M.py rename to examples/swin/swin_dwt.py index 27c2f323..581da512 100644 --- a/examples/swin/swin_348M.py +++ b/examples/swin/swin_dwt.py @@ -13,6 +13,7 @@ """ # -------------------------------------------------------- +from functools import reduce from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -29,7 +30,8 @@ from examples.swin.layers import ColumnParallelLinear, RowParallelLinear -_reducer_groups: Dict[Tuple[int], Reducer] = dict() +_wp_reducer: Dict[Tuple[int], Reducer] = dict() +_dp_reducer: Dict[Tuple[int], Reducer] = dict() def setup_device_group(tp: int, wp: int, dp: int, layer_id: int): @@ -73,7 +75,7 @@ def setup_device_group(tp: int, wp: int, dp: int, layer_id: int): group = devs.get_group(ranks) if myrank in ranks: wp_ranks = ranks - _reducer_groups[tuple(ranks)] = Reducer(ranks) + _wp_reducer[tuple(ranks)] = Reducer(ranks) print_each_rank(f'layer {layer_id}: initialzed window parallel group: {wp_ranks}', rank_only=myrank) # initialize data parallel groups @@ -84,7 +86,7 @@ def setup_device_group(tp: int, wp: int, dp: int, layer_id: int): group = devs.get_group(ranks) if myrank in ranks: dp_ranks = ranks - _reducer_groups[tuple(ranks)] = Reducer(ranks) + _dp_reducer[tuple(ranks)] = Reducer(ranks) print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) return tp_ranks, wp_ranks, dp_ranks @@ -302,8 +304,10 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 tp_group=tp_group ) + self.wp_group, self.wp_nH_ranks, self.wp_nW_ranks = wp_plans + self.use_wp = torch.distributed.get_world_size(self.wp_group) != 1 + if self.shift_size > 0: - self.wp_group, self.wp_nH_ranks, self.wp_nW_ranks = wp_plans # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 @@ -339,11 +343,13 @@ def forward(self, x): # cyclic shift if self.shift_size > 0: - shifted_x = cube.runtime.function.roll_grid_parallel( - x, (-self.shift_size, -self.shift_size), (1,2), - self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group - ) - # shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + if self.use_wp: + shifted_x = cube.runtime.function.roll_grid_parallel( + x, (-self.shift_size, -self.shift_size), (1,2), + self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group + ) + else: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x @@ -367,11 +373,13 @@ def forward(self, x): # [B, H', W', C] -> [B, H, W, C] x = shifted_x if self.shift_size > 0: - x = cube.runtime.function.roll_grid_parallel( - shifted_x, (self.shift_size, self.shift_size), (1,2), - self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group - ) - # x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + if self.use_wp: + x = cube.runtime.function.roll_grid_parallel( + shifted_x, (self.shift_size, self.shift_size), (1,2), + self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group + ) + else: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) # [B, H, W, C] -> [B, H * W, C] x = x.view(B, H * W, C) # [B, H * W, C] -> [B, H * W, C] @@ -514,7 +522,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, self.wp_postprocess = False if wp > 1: for param in self.blocks.parameters(): - _reducer_groups[tuple(wp_ranks)].add_param(param) + _wp_reducer[tuple(wp_ranks)].add_param(param) self.wp_preprocess = True self.wp_postprocess = True @@ -774,26 +782,57 @@ def forward(self, x): def train(args, pconfigs): - # image batch input - N, C, H, W = [1, 3, 224, 224] - # embed_dim, depths, num_heads, window_size = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 + # dim_head is always 32 + + # img resolution, windows size: 224, 384, 518, 640 + C, H, W, window_size = [3, 224, 224, 7] + # C, H, W, window_size = [3, 384, 384, 12] + # C, H, W, window_size = [3, 518, 518, ?] + # C, H, W, window_size = [4, 640, 640, 20] + + # image batch size + N = 8 + + # Swin-Tiny + # embed_dim, depths, num_heads = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24] # ] - # 348.55 M - # embed_dim, depths, num_heads, window_size = [ - # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + # SwinV2-B: 87 M + # embed_dim, depths, num_heads = [ + # 128, [2, 2, 18, 2], [4, 8, 12, 24] # ] - # 895.7 M Model - embed_dim, depths, num_heads, window_size = [ - 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # SwinV2-L: 196 M + # embed_dim, depths, num_heads = [ + # 192, [2, 2, 18, 2], [6, 12, 24, 48] + # ] + + # SwinV2-H: 657 M + # embed_dim, depths, num_heads = [ + # 352, [2, 2, 18, 2], [11, 22, 44, 88] + # ] + + # SwinV2-H modified: 782 M + embed_dim, depths, num_heads = [ + 384, [2, 2, 18, 2], [12, 24, 48, 96] ] + # SwinV2-G: 2.5B Model + # embed_dim, depths, num_heads = [ + # 512, [2, 2, 42, 2], [16, 32, 64, 128] + # ] + + # 895.7 M Model + # embed_dim, depths, num_heads = [ + # 384, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + # 2.01B model - # embed_dim, depths, num_heads, window_size = [ - # 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # embed_dim, depths, num_heads = [ + # 576, [2, 2, 22, 2], [12, 24, 48, 96] # ] @@ -802,22 +841,44 @@ def train(args, pconfigs): num_heads = num_heads, window_size = window_size, pconfigs = pconfigs) + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + model = model.cuda() memory_summary() dataloader = cube.runtime.syndata.SynDataLoader( 1280, [0], [N // args.dp, C, H, W]) + if args.dp > 1: + assert len(_dp_reducer) == 1 + reducer = None + for ranks in _dp_reducer: + reducer = _dp_reducer[ranks] + for param in model.parameters(): + reduced = False + for wp_ranks in _wp_reducer: + if param in _wp_reducer[wp_ranks]._params: + reduced = True + break + if not reduced: + reducer.add_param(param) + def train_iter(model, dataloader): img = next(dataloader) loss = model(img) loss = torch.sum(loss) loss.backward() CudaTimer().start('wp_allreduce') - for ranks in _reducer_groups: - reducer = _reducer_groups[ranks] + for ranks in _wp_reducer: + reducer = _wp_reducer[ranks] reducer.allreduce() CudaTimer().stop('wp_allreduce') + CudaTimer().start('dp_allreduce') + for ranks in _dp_reducer: + reducer = _dp_reducer[ranks] + reducer.allreduce() + CudaTimer().stop('dp_allreduce') optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) @@ -867,10 +928,10 @@ def train_iter(model, dataloader): args = parser.parse_args() pconfigs = [ - dict(layer_id=1, tp=2, wp=4, dp=args.dp), # basic layer 0 - dict(layer_id=2, tp=4, wp=2, dp=args.dp), # basic layer 1 - dict(layer_id=3, tp=2, wp=4, dp=args.dp), # basic layer 2 # prob at 8:1? - dict(layer_id=4, tp=8, wp=1, dp=args.dp), # basic layer 3 + dict(layer_id=0, tp=4, wp=1, dp=args.dp), # basic layer 0 + dict(layer_id=1, tp=4, wp=1, dp=args.dp), # basic layer 1 + dict(layer_id=2, tp=4, wp=1, dp=args.dp), # basic layer 2 # prob at 8:1? + dict(layer_id=3, tp=4, wp=1, dp=args.dp), # basic layer 3 ] cube.init() diff --git a/examples/swin/swin_transformer.py b/examples/swin/swin_transformer.py index 56cc68c1..4d4f0763 100644 --- a/examples/swin/swin_transformer.py +++ b/examples/swin/swin_transformer.py @@ -624,17 +624,23 @@ def train(): # image batch input N, C, H, W = [1, 3, 224, 224] + # embed_dim, depths, num_heads, window_size = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 + # ] + + # 348.55 M embed_dim, depths, num_heads, window_size = [ - 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 + 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 ] + # 895.7 M Model -- 224x224 # embed_dim, depths, num_heads, window_size = [ - # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 + # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 # ] - # 1.02B Model + # 2.01B model # embed_dim, depths, num_heads, window_size = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 + # 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 # ] From 02c51072349bcaef2360620c7a8e946ea521a244 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 28 Nov 2021 17:22:55 +0800 Subject: [PATCH 0426/1892] cmd support for experiments --- examples/swin/swin_dwt.py | 50 +++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/examples/swin/swin_dwt.py b/examples/swin/swin_dwt.py index 581da512..ed0f6ad0 100644 --- a/examples/swin/swin_dwt.py +++ b/examples/swin/swin_dwt.py @@ -9,11 +9,15 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_348M.py + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 4 1 \ + --layer1 1 4 1 \ + --layer2 1 1 4 \ + --layer3 1 1 4 + """ # -------------------------------------------------------- -from functools import reduce from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -792,7 +796,7 @@ def train(args, pconfigs): # C, H, W, window_size = [4, 640, 640, 20] # image batch size - N = 8 + N = args.bs # Swin-Tiny # embed_dim, depths, num_heads = [ @@ -919,20 +923,46 @@ def train_iter(model, dataloader): if __name__ == '__main__': + + cube.init() # resource allocation parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--dp', type=int, default=1, - help='pipeline parallel size') + parser.add_argument('--layer0', type=int, nargs='+', + help='data, window tensor parallel config') + parser.add_argument('--layer1', type=int, nargs='+', + help='data, window tensor parallel config') + parser.add_argument('--layer2', type=int, nargs='+', + help='data, window tensor parallel config') + parser.add_argument('--layer3', type=int, nargs='+', + help='data, window tensor parallel config') + parser.add_argument('--bs', type=int, default=1, + help='bs') parser.add_argument('--micro-bs', type=int, default=-1) args = parser.parse_args() + assert len(args.layer0) == 3 + assert len(args.layer1) == 3 + assert len(args.layer2) == 3 + assert len(args.layer3) == 3 + + # data parallel should be same + assert args.layer0[0] == args.layer1[0] and args.layer1[0] == args.layer2[0] and args.layer2[0] == args.layer3[0] + args.dp = args.layer0[0] + pconfigs = [ - dict(layer_id=0, tp=4, wp=1, dp=args.dp), # basic layer 0 - dict(layer_id=1, tp=4, wp=1, dp=args.dp), # basic layer 1 - dict(layer_id=2, tp=4, wp=1, dp=args.dp), # basic layer 2 # prob at 8:1? - dict(layer_id=3, tp=4, wp=1, dp=args.dp), # basic layer 3 + dict(layer_id=0, dp=args.layer0[0], wp=args.layer0[1], tp=args.layer0[2]), # basic layer 0 + dict(layer_id=1, dp=args.layer1[0], wp=args.layer1[1], tp=args.layer1[2]), # basic layer 1 + dict(layer_id=2, dp=args.layer2[0], wp=args.layer2[1], tp=args.layer2[2]), # basic layer 2 # prob at 8:1? + dict(layer_id=3, dp=args.layer3[0], wp=args.layer3[1], tp=args.layer3[2]), # basic layer 3 ] - cube.init() + # pconfigs = [ + # dict(layer_id=0, tp=4, wp=1, dp=args.dp), # basic layer 0 + # dict(layer_id=1, tp=4, wp=1, dp=args.dp), # basic layer 1 + # dict(layer_id=2, tp=4, wp=1, dp=args.dp), # basic layer 2 # prob at 8:1? + # dict(layer_id=3, tp=4, wp=1, dp=args.dp), # basic layer 3 + # ] + + print_each_rank(pconfigs, rank_only=0) train(args, pconfigs) From aa38a03b02dc6ff1a196c22eef4daf067a9e93b7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 28 Nov 2021 18:47:12 +0800 Subject: [PATCH 0427/1892] exp --- eval/swin.sh | 135 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100755 eval/swin.sh diff --git a/eval/swin.sh b/eval/swin.sh new file mode 100755 index 00000000..39ed337c --- /dev/null +++ b/eval/swin.sh @@ -0,0 +1,135 @@ +#!/usr/sh + +# ================== Megatron Policy Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 2 1 1 \ + --layer1 2 1 1 \ + --layer2 2 1 1 \ + --layer3 2 1 1 \ + > 2gpu_dp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 4 1 1 \ + --layer1 4 1 1 \ + --layer2 4 1 1 \ + --layer3 4 1 1 \ + > 4gpu_dp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 8 1 1 \ + --layer1 8 1 1 \ + --layer2 8 1 1 \ + --layer3 8 1 1 \ + > 8gpu_dp.txt + +# ================== Maximal Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 2 1 \ + --layer1 1 2 1 \ + --layer2 1 2 1 \ + --layer3 1 2 1 \ + > 2gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 2 1 \ + --layer1 1 2 1 \ + --layer2 1 2 1 \ + --layer3 1 2 1 \ + > 4gpu_tp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 2 1 4 \ + --layer1 2 1 4 \ + --layer2 2 1 4 \ + --layer3 2 1 4 \ + > 8gpu_2dp4tp.txt + +# ================== Window + Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 2 1 \ + --layer1 1 2 1 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + > 2gpu_2wp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 4 1 \ + --layer1 1 4 1 \ + --layer2 1 1 4 \ + --layer3 1 1 4 \ + > 2gpu_4wp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 8 1 \ + --layer1 1 1 8 \ + --layer2 1 1 4 \ + --layer3 1 1 4 \ + > 2gpu_8wp8tp.txt From f026f5688cc8dfc9aed62080427944b49aa9718c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 28 Nov 2021 23:15:33 +0800 Subject: [PATCH 0428/1892] dp tp hybrid --- examples/swin/layers.py | 111 ++++- examples/swin/swin_dt.py | 938 ++++++++++++++++++++++++++++++++++++++ examples/swin/swin_dwt.py | 8 +- 3 files changed, 1034 insertions(+), 23 deletions(-) create mode 100644 examples/swin/swin_dt.py diff --git a/examples/swin/layers.py b/examples/swin/layers.py index 9fcd3cd9..65790dd4 100644 --- a/examples/swin/layers.py +++ b/examples/swin/layers.py @@ -19,7 +19,7 @@ def _reduce(input_, group): return input_ -def _split(input_, group): +def _split(input_, group, dim=-1): """Split the tensor along its last dimension and keep the corresponding slice.""" @@ -28,14 +28,13 @@ def _split(input_, group): # Bypass the function if we are using only 1 GPU. if world_size==1: return input_ - last_dim = input_.dim() - 1 - last_dim_size = input_.size()[last_dim] // world_size - tensor_list = torch.split(input_, last_dim_size, dim=last_dim) + dim_size = input_.size()[dim] // world_size + tensor_list = torch.split(input_, dim_size, dim=dim) output = tensor_list[rank].contiguous() return output -def _gather(input_, group): +def _gather(input_, group, dim=-1): """Gather tensors and concatinate along the last dimension.""" CudaTimer().start(field_name='tp_allgather') @@ -46,16 +45,30 @@ def _gather(input_, group): CudaTimer().stop(field_name='tp_allgather') return input_ # Size and dimension. - last_dim = input_.dim() - 1 tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ torch.distributed.all_gather(tensor_list, input_, group=group) # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() + output = torch.cat(tensor_list, dim=dim).contiguous() CudaTimer().stop(field_name='tp_allgather') return output +def _scatter(input_, group, dim=0): + """Reduce-Scatter tensor""" + CudaTimer().start(field_name='tp_reduce_scatter') + world_size = torch.distributed.get_world_size(group=group) + if world_size == 1: + CudaTimer().stop(field_name='tp_reduce_scatter') + return input_ + rank = torch.distributed.get_rank(group=group) + tensor_list = list(torch.chunk(input_, world_size, dim)) + # for idx, tensor in enumerate(tensor_list): + # tensor_list[idx] = tensor.contiguous() + torch.distributed.reduce_scatter(tensor_list[rank], tensor_list, group=group) + CudaTimer().stop(field_name='tp_reduce_scatter') + return tensor_list[rank] + class ColumnInputAdapter(torch.autograd.Function): @staticmethod @@ -102,15 +115,53 @@ def backward(ctx, grad_output): return grad_output, None +class DPtoTPAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group): + """ + split + """ + ctx.group = group + return _gather(input_, group, dim=0) + + @staticmethod + def backward(ctx, grad_output): + """ + reduce-scatter + """ + group = ctx.group + return _split(grad_output, group, dim=0), None + + +class TPtoDPAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group): + """ + Reduce-scatter + """ + ctx.group = group + return _split(input_, group, dim=0) + + @staticmethod + def backward(ctx, grad_output): + """ + all-gather + """ + group = ctx.group + return _gather(grad_output, group, dim=0), None + + + + class ColumnParallelLinear(torch.nn.Module): - def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False, tp_group=-1): + def __init__(self, input_size, output_size, bias=True, in_adapter=True, out_adapter=True, tp_group=-1): super().__init__() assert tp_group != -1 self.input_size = input_size self.output_size = output_size - self.full_input = full_input - self.full_output = full_output + self.in_adapter = in_adapter + self.out_adapter = out_adapter self.group = tp_group world_size = torch.distributed.get_world_size(group=self.group) @@ -138,11 +189,12 @@ def __init__(self, input_size, output_size, bias=True, full_input=True, full_out def forward(self, input_): bias = self.bias - if not self.full_input: - raise RuntimeError("Expected full tensor input") - input_parallel = ColumnInputAdapter.apply(input_, self.group) + if self.in_adapter: + input_parallel = ColumnInputAdapter.apply(input_, self.group) + else: + input_parallel = input_ output_parallel = F.linear(input_parallel, self.weight, bias) - if self.full_output: + if self.out_adapter: output = ColumnOutputAdapter.apply(output_parallel, self.group) else: output = output_parallel @@ -151,13 +203,13 @@ def forward(self, input_): class RowParallelLinear(torch.nn.Module): - def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False, tp_group=-1): + def __init__(self, input_size, output_size, bias=True, in_adapter=True, out_adapter=True, tp_group=-1): super().__init__() assert tp_group != -1 self.input_size = input_size self.output_size = output_size - self.full_input = full_input - self.full_output = full_output + self.in_adapter = in_adapter + self.out_adapter = out_adapter self.group = tp_group world_size = torch.distributed.get_world_size(group=self.group) @@ -183,12 +235,12 @@ def __init__(self, input_size, output_size, bias=True, full_input=True, full_out def forward(self, input_): bias = self.bias - if self.full_input: + if self.in_adapter: input_parallel = RowInputAdapter.apply(input_, self.group) else: input_parallel = input_ output_parallel = F.linear(input_parallel, self.weight, bias) - if self.full_output: + if self.out_adapter: output = RowOutputAdapter.apply(output_parallel, self.group) else: output = output_parallel @@ -227,3 +279,24 @@ def forward(self, input_): ) output = RowOutputAdapter.apply(output_parallel, self.group) return output + + +class DPtoTP(torch.nn.Module): + + def __init__(self, dp_group): + super().__init__() + self.group = dp_group + + def forward(self, input_): + return DPtoTPAdapter.apply(input_, self.group) + + +class TPtoDP(torch.nn.Module): + + def __init__(self, tp_group): + super().__init__() + self.group = tp_group + + def forward(self, input_): + return TPtoDPAdapter.apply(input_, self.group) + diff --git a/examples/swin/swin_dt.py b/examples/swin/swin_dt.py new file mode 100644 index 00000000..e32e4324 --- /dev/null +++ b/examples/swin/swin_dt.py @@ -0,0 +1,938 @@ + +# -------------------------------------------------------- +# Modified from Swin-Transformer Repo +""" +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dt.py --bs 8 \ + --layer0 4 1 \ + --layer1 4 1 \ + --layer2 1 4 \ + --layer3 1 4 + +""" +# -------------------------------------------------------- + +from typing import Dict, Optional, Tuple +import torch +import torch.nn as nn +import argparse +import time + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary +from cube.runtime.device import DeviceGroup +from cube.runtime.reducer import Reducer + +from examples.swin.layers import ColumnParallelLinear, DPtoTP, RowParallelLinear, TPtoDP + +_dp_reducer: Dict[Tuple[int], Reducer] = dict() + + +def setup_device_group(tp: int, dp: int, layer_id: int): + """ + Layer wise device group initialize + + Returns: + + """ + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + if not tp * dp == ndevs: + raise RuntimeError("Expected same device number") + + assert tp == 1 or dp == 1, "Currently hybrid not supported" + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize tensor parallel groups + for i in range(dp): + ranks = list(range(i * tp, (i + 1) * tp)) + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + + # initialize data parallel groups + for i in range(tp): + ranks = list(range(i, ndevs, tp)) + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) + return tp_ranks, dp_ranks + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class MegatronMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tp_group=-1): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=tp_group) + self.act = act_layer() + # self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=True, tp_group=tp_group) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +class MegatronWindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., tp_group=-1): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.global_num_heads = num_heads + + tp_world_size = torch.distributed.get_world_size(group=tp_group) + if num_heads % tp_world_size != 0: + raise RuntimeError(f'detecting un-even num head {num_heads} partition to {tp_world_size}') + self.num_heads = num_heads // tp_world_size + + self.dim_heads = dim // self.global_num_heads + self.scale = qk_scale or self.dim_heads ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) + self.attn_drop = nn.Dropout(attn_drop) + # self.proj = nn.Linear(dim, dim) + self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=True, tp_group=tp_group) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + relative_position_bias = self.relative_position_bias_table[relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + tp_group=-1): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = MegatronWindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + tp_group=tp_group) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MegatronMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop, + tp_group=tp_group + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + # [B, H, W, C] -> [B, H * W, C] + x = x.view(B, H * W, C) + # [B, H * W, C] -> [B, H * W, C] + x = shortcut + drop_path(x, self.drop_path_p) + # FFN + # [B, H * W, C] -> [B, H * W, C] + ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] + ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] + x = x + drop_path(ffn, self.drop_path_p) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, + tp_group=-1, layer_id=-1): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList() + for i in range(depth): + block = SwinTransformerBlock( + dim=dim, input_resolution=self.input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + tp_group=tp_group, + ) + self.blocks.append(block) + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + return x + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + + # ====================== depth 0 =========================== + pconfig = pconfigs[0] + l0_tp_ranks, l0_dp_ranks = setup_device_group(**pconfig) + tp_group = DeviceGroup().get_group(l0_tp_ranks) + + input_resolution = ( + patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) + ) + self.basic_layer0 = BasicLayer( + dim=int(embed_dim * 2 ** 0), + input_resolution=input_resolution, + depth=depths[0], + num_heads=num_heads[0], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:0]):sum(depths[:0 + 1])], + norm_layer=norm_layer, + tp_group=tp_group, + ) + + if len(l0_dp_ranks) > 1: + dp_ranks = tuple(l0_dp_ranks) + if dp_ranks not in _dp_reducer: + _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.basic_layer0.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + + # ====================== depth 1 =========================== + pconfig = pconfigs[1] + l1_tp_ranks, l1_dp_ranks = setup_device_group(**pconfig) + tp_group = DeviceGroup().get_group(l1_tp_ranks) + + # adapter + if len(l0_dp_ranks) > 1 and len(l1_tp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter01 = DPtoTP(DeviceGroup().get_group(l0_dp_ranks)) + elif len(l0_tp_ranks) > 1 and len(l1_dp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter01 = TPtoDP(DeviceGroup().get_group(l0_tp_ranks)) + else: + self.adapter01 = torch.nn.Identity() + + self.merging0 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 0), norm_layer=norm_layer + ) + + input_resolution = ( + patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) + ) + self.basic_layer1 = BasicLayer( + dim=int(embed_dim * 2 ** 1), + input_resolution=input_resolution, + depth=depths[1], + num_heads=num_heads[1], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:1]):sum(depths[:1 + 1])], + norm_layer=norm_layer, + tp_group=tp_group, + ) + + if len(l1_dp_ranks) > 1: + dp_ranks = tuple(l1_dp_ranks) + if dp_ranks not in _dp_reducer: + _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.basic_layer1.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.merging0.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + + # ====================== depth 2 =========================== + pconfig = pconfigs[2] + l2_tp_ranks, l2_dp_ranks = setup_device_group(**pconfig) + tp_group = DeviceGroup().get_group(l2_tp_ranks) + + # adapter + if len(l1_dp_ranks) > 1 and len(l2_tp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter12 = DPtoTP(DeviceGroup().get_group(l1_dp_ranks)) + elif len(l1_tp_ranks) > 1 and len(l2_dp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter12 = TPtoDP(DeviceGroup().get_group(l1_tp_ranks)) + else: + self.adapter12 = torch.nn.Identity() + + + self.merging1 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 1), norm_layer=norm_layer + ) + + input_resolution = ( + patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) + ) + self.basic_layer2 = BasicLayer( + dim=int(embed_dim * 2 ** 2), + input_resolution=input_resolution, + depth=depths[2], + num_heads=num_heads[2], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:2]):sum(depths[:2 + 1])], + norm_layer=norm_layer, + tp_group=tp_group + ) + + if len(l2_dp_ranks) > 1: + dp_ranks = tuple(l2_dp_ranks) + if dp_ranks not in _dp_reducer: + _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.basic_layer2.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.merging1.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + # ====================== depth 3 =========================== + pconfig = pconfigs[3] + l3_tp_ranks, l3_dp_ranks = setup_device_group(**pconfig) + tp_group = DeviceGroup().get_group(l3_tp_ranks) + + # adapter + if len(l2_dp_ranks) > 1 and len(l3_tp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter23 = DPtoTP(DeviceGroup().get_group(l2_dp_ranks)) + elif len(l2_tp_ranks) > 1 and len(l3_dp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter23 = TPtoDP(DeviceGroup().get_group(l2_tp_ranks)) + else: + self.adapter23 = torch.nn.Identity() + + self.merging2 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 2), norm_layer=norm_layer + ) + + self.basic_layer3 = BasicLayer( + dim=int(embed_dim * 2 ** 3), + input_resolution=(patches_resolution[0] // (2 ** 3), + patches_resolution[1] // (2 ** 3)), + depth=depths[3], + num_heads=num_heads[3], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:3]):sum(depths[:3 + 1])], + norm_layer=norm_layer, + tp_group=tp_group + ) + + if len(l3_dp_ranks) > 1: + dp_ranks = tuple(l3_dp_ranks) + if dp_ranks not in _dp_reducer: + _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.basic_layer3.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.merging2.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + if len(l3_dp_ranks) > 1: + dp_ranks = tuple(l3_dp_ranks) + for param in self.norm.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.head.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x) + + CudaTimer().start('basic_layer0') + x = self.basic_layer0(x) + CudaTimer().start('adapter') + x = self.adapter01(x) + CudaTimer().stop('adapter') + x = self.merging0(x) + CudaTimer().stop('basic_layer0') + + CudaTimer().start('basic_layer1') + x = self.basic_layer1(x) + CudaTimer().start('adapter') + x = self.adapter12(x) + CudaTimer().stop('adapter') + x = self.merging1(x) + CudaTimer().stop('basic_layer1') + + CudaTimer().start('basic_layer2') + x = self.basic_layer2(x) + CudaTimer().start('adapter') + x = self.adapter23(x) + CudaTimer().stop('adapter') + x = self.merging2(x) + CudaTimer().stop('basic_layer2') + + CudaTimer().start('basic_layer3') + x = self.basic_layer3(x) + CudaTimer().stop('basic_layer3') + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C L + x = torch.flatten(x, 1) + + x = self.head(x) + return x + + +def train(args, pconfigs): + + # dim_head is always 32 + + # img resolution, windows size: 224, 384, 518, 640 + C, H, W, window_size = [3, 224, 224, 7] + # C, H, W, window_size = [3, 384, 384, 12] + # C, H, W, window_size = [3, 518, 518, ?] + # C, H, W, window_size = [4, 640, 640, 20] + + # image batch size + N = args.bs + + # Swin-Tiny + # embed_dim, depths, num_heads = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24] + # ] + + # SwinV2-B: 87 M + # embed_dim, depths, num_heads = [ + # 128, [2, 2, 18, 2], [4, 8, 12, 24] + # ] + + # SwinV2-L: 196 M + # embed_dim, depths, num_heads = [ + # 192, [2, 2, 18, 2], [6, 12, 24, 48] + # ] + + # SwinV2-H: 657 M + # embed_dim, depths, num_heads = [ + # 352, [2, 2, 18, 2], [11, 22, 44, 88] + # ] + + # SwinV2-H modified: 782 M + embed_dim, depths, num_heads = [ + 384, [2, 2, 18, 2], [12, 24, 48, 96] + ] + + # SwinV2-G: 2.5B Model + # embed_dim, depths, num_heads = [ + # 512, [2, 2, 42, 2], [16, 32, 64, 128] + # ] + + # 895.7 M Model + # embed_dim, depths, num_heads = [ + # 384, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + + # 2.01B model + # embed_dim, depths, num_heads = [ + # 576, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + + model = SwinTransformer(embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size, + pconfigs = pconfigs) + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + model = model.cuda() + memory_summary() + + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [N // args.dp, C, H, W]) + + def train_iter(model, dataloader): + img = next(dataloader) + loss = model(img) + loss = torch.sum(loss) + loss.backward() + CudaTimer().start('dp_allreduce') + for ranks in _dp_reducer: + reducer = _dp_reducer[ranks] + reducer.allreduce() + CudaTimer().stop('dp_allreduce') + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + # start training + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + CudaTimer(enable=False).warmup() + torch.distributed.barrier() + span = 0 + iter_num = 128 + for step in range(iter_num): + if step >= 40: + torch.cuda.synchronize() + start = time.time() + CudaTimer(enable=True).start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on 1st iteration') + memory_summary() + if step >= 40: + torch.cuda.synchronize() + stop = time.time() + span += (stop - start) * 1000 + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = span / (iter_num-40) + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + CudaTimer().print_all(times=iter_num-40) + + +if __name__ == '__main__': + + cube.init() + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--layer0', type=int, nargs='+', + help='data, tensor parallel config') + parser.add_argument('--layer1', type=int, nargs='+', + help='data, tensor parallel config') + parser.add_argument('--layer2', type=int, nargs='+', + help='data, tensor parallel config') + parser.add_argument('--layer3', type=int, nargs='+', + help='data, tensor parallel config') + parser.add_argument('--bs', type=int, default=1, + help='bs') + parser.add_argument('--micro-bs', type=int, default=-1) + args = parser.parse_args() + + assert len(args.layer0) == 2 + assert len(args.layer1) == 2 + assert len(args.layer2) == 2 + assert len(args.layer3) == 2 + + # data parallel should be same + args.dp = args.layer0[0] + + pconfigs = [ + dict(layer_id=0, dp=args.layer0[0], tp=args.layer0[1]), # basic layer 0 + dict(layer_id=1, dp=args.layer1[0], tp=args.layer1[1]), # basic layer 1 + dict(layer_id=2, dp=args.layer2[0], tp=args.layer2[1]), # basic layer 2 + dict(layer_id=3, dp=args.layer3[0], tp=args.layer3[1]), # basic layer 3 + ] + + print_each_rank(pconfigs, rank_only=0) + train(args, pconfigs) diff --git a/examples/swin/swin_dwt.py b/examples/swin/swin_dwt.py index ed0f6ad0..9ff46c2d 100644 --- a/examples/swin/swin_dwt.py +++ b/examples/swin/swin_dwt.py @@ -112,10 +112,10 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay out_features = out_features or in_features hidden_features = hidden_features or in_features # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, full_input=True, full_output=False, tp_group=tp_group) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=tp_group) self.act = act_layer() # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, full_input=False, full_output=True, tp_group=tp_group) + self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=True, tp_group=tp_group) self.drop = nn.Dropout(drop) def forward(self, x): @@ -209,10 +209,10 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # print(f'qkv embed dim: {dim}') - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False, tp_group=tp_group) + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) self.attn_drop = nn.Dropout(attn_drop) # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, full_input=False, full_output=True, tp_group=tp_group) + self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=True, tp_group=tp_group) self.proj_drop = nn.Dropout(proj_drop) # trunc_normal_(self.relative_position_bias_table, std=.02) From b5dac6d64365ae91792b7bfc27605a1cb47e7400 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 28 Nov 2021 23:21:49 +0800 Subject: [PATCH 0429/1892] swin exp script --- eval/swin.sh | 104 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 74 insertions(+), 30 deletions(-) diff --git a/eval/swin.sh b/eval/swin.sh index 39ed337c..6c1246f5 100755 --- a/eval/swin.sh +++ b/eval/swin.sh @@ -1,4 +1,4 @@ -#!/usr/sh +mkdir -p expdata # ================== Megatron Policy Parallel =============== @@ -10,11 +10,11 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ examples/swin/swin_dwt.py --bs 8 \ - --layer0 2 1 1 \ - --layer1 2 1 1 \ - --layer2 2 1 1 \ - --layer3 2 1 1 \ - > 2gpu_dp.txt + --layer0 1 1 2 \ + --layer1 1 1 2 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + > expdata/2gpu_dp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -24,11 +24,11 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ examples/swin/swin_dwt.py --bs 8 \ - --layer0 4 1 1 \ - --layer1 4 1 1 \ - --layer2 4 1 1 \ - --layer3 4 1 1 \ - > 4gpu_dp.txt + --layer0 2 1 2 \ + --layer1 2 1 2 \ + --layer2 2 1 2 \ + --layer3 2 1 2 \ + > expdata/4gpu_dp.txt python -m torch.distributed.launch \ @@ -39,11 +39,11 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ examples/swin/swin_dwt.py --bs 8 \ - --layer0 8 1 1 \ - --layer1 8 1 1 \ - --layer2 8 1 1 \ - --layer3 8 1 1 \ - > 8gpu_dp.txt + --layer0 4 1 2 \ + --layer1 4 1 2 \ + --layer2 4 1 2 \ + --layer3 4 1 2 \ + > expdata/8gpu_dp.txt # ================== Maximal Tensor Parallel =============== @@ -55,11 +55,11 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ examples/swin/swin_dwt.py --bs 8 \ - --layer0 1 2 1 \ - --layer1 1 2 1 \ - --layer2 1 2 1 \ - --layer3 1 2 1 \ - > 2gpu_tp.txt + --layer0 1 1 2 \ + --layer1 1 1 2 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + > expdata/2gpu_tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -69,11 +69,11 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ examples/swin/swin_dwt.py --bs 8 \ - --layer0 1 2 1 \ - --layer1 1 2 1 \ - --layer2 1 2 1 \ - --layer3 1 2 1 \ - > 4gpu_tp.txt + --layer0 1 1 4 \ + --layer1 1 1 4 \ + --layer2 1 1 4 \ + --layer3 1 1 4 \ + > expdata/4gpu_tp.txt python -m torch.distributed.launch \ @@ -88,7 +88,7 @@ python -m torch.distributed.launch \ --layer1 2 1 4 \ --layer2 2 1 4 \ --layer3 2 1 4 \ - > 8gpu_2dp4tp.txt + > expdata/8gpu_2dp4tp.txt # ================== Window + Tensor Parallel =============== @@ -104,7 +104,7 @@ python -m torch.distributed.launch \ --layer1 1 2 1 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > 2gpu_2wp2tp.txt + > expdata/2gpu_2wp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -118,7 +118,7 @@ python -m torch.distributed.launch \ --layer1 1 4 1 \ --layer2 1 1 4 \ --layer3 1 1 4 \ - > 2gpu_4wp4tp.txt + > expdata/2gpu_4wp4tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -132,4 +132,48 @@ python -m torch.distributed.launch \ --layer1 1 1 8 \ --layer2 1 1 4 \ --layer3 1 1 4 \ - > 2gpu_8wp8tp.txt + > expdata/2gpu_8wp8tp.txt + + +# ================== Data + Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dt.py --bs 8 \ + --layer0 2 1 \ + --layer1 2 1 \ + --layer2 1 2 \ + --layer3 1 2 \ + > expdata/2gpu_dt_2dp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dt.py --bs 8 \ + --layer0 4 1 \ + --layer1 4 1 \ + --layer2 1 4 \ + --layer3 1 4 \ + > expdata/4gpu_dt_4dp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dt.py --bs 8 \ + --layer0 8 1 \ + --layer1 1 1 \ + --layer2 1 8 \ + --layer3 1 8 \ + > expdata/8gpu_dt_8dp8tp.txt From c180860026706e63d47e2b09ecae6ced72a95fc4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 00:24:30 +0800 Subject: [PATCH 0430/1892] add fp16 mode --- examples/swin/swin_dt.py | 13 +++++++++++-- examples/swin/swin_dwt.py | 9 ++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/examples/swin/swin_dt.py b/examples/swin/swin_dt.py index e32e4324..b6735836 100644 --- a/examples/swin/swin_dt.py +++ b/examples/swin/swin_dt.py @@ -525,7 +525,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, **kwargs): + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, fp16=False, **kwargs): super().__init__() self.num_classes = num_classes @@ -581,6 +581,8 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, dp_ranks = tuple(l0_dp_ranks) if dp_ranks not in _dp_reducer: _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.patch_embed.parameters(): + _dp_reducer[dp_ranks].add_param(param) for param in self.basic_layer0.parameters(): _dp_reducer[dp_ranks].add_param(param) @@ -842,6 +844,9 @@ def train(args, pconfigs): num_heads = num_heads, window_size = window_size, pconfigs = pconfigs) + if args.fp16: + print_each_rank('use half precision') + model = model.half() nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) @@ -850,6 +855,10 @@ def train(args, pconfigs): dataloader = cube.runtime.syndata.SynDataLoader( 1280, [0], [N // args.dp, C, H, W]) + + if args.fp16: + data_buff = [[e.half() for e in data] for data in dataloader.datas] + dataloader.datas = data_buff def train_iter(model, dataloader): img = next(dataloader) @@ -916,7 +925,7 @@ def train_iter(model, dataloader): help='data, tensor parallel config') parser.add_argument('--bs', type=int, default=1, help='bs') - parser.add_argument('--micro-bs', type=int, default=-1) + parser.add_argument('--fp16', action='store_true', dest='fp16') args = parser.parse_args() assert len(args.layer0) == 2 diff --git a/examples/swin/swin_dwt.py b/examples/swin/swin_dwt.py index 9ff46c2d..1e6eb7e6 100644 --- a/examples/swin/swin_dwt.py +++ b/examples/swin/swin_dwt.py @@ -848,12 +848,19 @@ def train(args, pconfigs): nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + if args.fp16: + print_each_rank('use half model') + model = model.half() model = model.cuda() memory_summary() dataloader = cube.runtime.syndata.SynDataLoader( 1280, [0], [N // args.dp, C, H, W]) + if args.fp16: + data_buff = [[e.half() for e in data] for data in dataloader.datas] + dataloader.datas = data_buff + if args.dp > 1: assert len(_dp_reducer) == 1 reducer = None @@ -938,7 +945,7 @@ def train_iter(model, dataloader): help='data, window tensor parallel config') parser.add_argument('--bs', type=int, default=1, help='bs') - parser.add_argument('--micro-bs', type=int, default=-1) + parser.add_argument('--fp16', action='store_true', dest='fp16') args = parser.parse_args() assert len(args.layer0) == 3 From 780b5d273f7e4ad3291aef0509fee16e8c0e15be Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 00:26:45 +0800 Subject: [PATCH 0431/1892] swin test for fp16 --- eval/swin.sh | 12 ++-- eval/swin_fp16.sh | 179 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 eval/swin_fp16.sh diff --git a/eval/swin.sh b/eval/swin.sh index 6c1246f5..fb2aab71 100755 --- a/eval/swin.sh +++ b/eval/swin.sh @@ -118,10 +118,10 @@ python -m torch.distributed.launch \ --layer1 1 4 1 \ --layer2 1 1 4 \ --layer3 1 1 4 \ - > expdata/2gpu_4wp4tp.txt + > expdata/4gpu_4wp4tp.txt python -m torch.distributed.launch \ - --nproc_per_node=4 \ + --nproc_per_node=8 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ @@ -130,9 +130,9 @@ python -m torch.distributed.launch \ examples/swin/swin_dwt.py --bs 8 \ --layer0 1 8 1 \ --layer1 1 1 8 \ - --layer2 1 1 4 \ - --layer3 1 1 4 \ - > expdata/2gpu_8wp8tp.txt + --layer2 1 1 8 \ + --layer3 1 1 8 \ + > expdata/8gpu_8wp8tp.txt # ================== Data + Tensor Parallel =============== @@ -173,7 +173,7 @@ python -m torch.distributed.launch \ --use_env \ examples/swin/swin_dt.py --bs 8 \ --layer0 8 1 \ - --layer1 1 1 \ + --layer1 8 1 \ --layer2 1 8 \ --layer3 1 8 \ > expdata/8gpu_dt_8dp8tp.txt diff --git a/eval/swin_fp16.sh b/eval/swin_fp16.sh new file mode 100644 index 00000000..c8bd2e08 --- /dev/null +++ b/eval/swin_fp16.sh @@ -0,0 +1,179 @@ +mkdir -p expdata_fp16 + +# ================== Megatron Policy Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 1 2 \ + --layer1 1 1 2 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + > expdata_fp16/2gpu_dp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 2 1 2 \ + --layer1 2 1 2 \ + --layer2 2 1 2 \ + --layer3 2 1 2 \ + > expdata_fp16/4gpu_dp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 4 1 2 \ + --layer1 4 1 2 \ + --layer2 4 1 2 \ + --layer3 4 1 2 \ + > expdata_fp16/8gpu_dp.txt + +# ================== Maximal Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 1 2 \ + --layer1 1 1 2 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + > expdata_fp16/2gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 1 4 \ + --layer1 1 1 4 \ + --layer2 1 1 4 \ + --layer3 1 1 4 \ + > expdata_fp16/4gpu_tp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 2 1 4 \ + --layer1 2 1 4 \ + --layer2 2 1 4 \ + --layer3 2 1 4 \ + > expdata_fp16/8gpu_2dp4tp.txt + +# ================== Window + Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 2 1 \ + --layer1 1 2 1 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + > expdata_fp16/2gpu_2wp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 4 1 \ + --layer1 1 4 1 \ + --layer2 1 1 4 \ + --layer3 1 1 4 \ + > expdata_fp16/4gpu_4wp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 8 1 \ + --layer1 1 1 8 \ + --layer2 1 1 8 \ + --layer3 1 1 8 \ + > expdata_fp16/8gpu_8wp8tp.txt + + +# ================== Data + Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dt.py --bs 8 \ + --layer0 2 1 \ + --layer1 2 1 \ + --layer2 1 2 \ + --layer3 1 2 \ + > expdata_fp16/2gpu_dt_2dp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dt.py --bs 8 \ + --layer0 4 1 \ + --layer1 4 1 \ + --layer2 1 4 \ + --layer3 1 4 \ + > expdata_fp16/4gpu_dt_4dp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dt.py --bs 8 \ + --layer0 8 1 \ + --layer1 8 1 \ + --layer2 1 8 \ + --layer3 1 8 \ + > expdata_fp16/8gpu_dt_8dp8tp.txt From bdbca6a32ddfce2070d798b5aa512acde64ee1ae Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 00:27:42 +0800 Subject: [PATCH 0432/1892] swin fp16 --- eval/swin_fp16.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 eval/swin_fp16.sh diff --git a/eval/swin_fp16.sh b/eval/swin_fp16.sh old mode 100644 new mode 100755 From 0ef4f95b85f05ccf8ff34accf9071552c27787fc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 00:30:07 +0800 Subject: [PATCH 0433/1892] swin for fp16 --- eval/swin_fp16.sh | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/eval/swin_fp16.sh b/eval/swin_fp16.sh index c8bd2e08..75808c81 100755 --- a/eval/swin_fp16.sh +++ b/eval/swin_fp16.sh @@ -14,6 +14,7 @@ python -m torch.distributed.launch \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ + --fp16 \ > expdata_fp16/2gpu_dp.txt python -m torch.distributed.launch \ @@ -28,6 +29,7 @@ python -m torch.distributed.launch \ --layer1 2 1 2 \ --layer2 2 1 2 \ --layer3 2 1 2 \ + --fp16 \ > expdata_fp16/4gpu_dp.txt @@ -43,6 +45,7 @@ python -m torch.distributed.launch \ --layer1 4 1 2 \ --layer2 4 1 2 \ --layer3 4 1 2 \ + --fp16 \ > expdata_fp16/8gpu_dp.txt # ================== Maximal Tensor Parallel =============== @@ -59,6 +62,7 @@ python -m torch.distributed.launch \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ + --fp16 \ > expdata_fp16/2gpu_tp.txt python -m torch.distributed.launch \ @@ -73,6 +77,7 @@ python -m torch.distributed.launch \ --layer1 1 1 4 \ --layer2 1 1 4 \ --layer3 1 1 4 \ + --fp16 \ > expdata_fp16/4gpu_tp.txt @@ -88,6 +93,7 @@ python -m torch.distributed.launch \ --layer1 2 1 4 \ --layer2 2 1 4 \ --layer3 2 1 4 \ + --fp16 \ > expdata_fp16/8gpu_2dp4tp.txt # ================== Window + Tensor Parallel =============== @@ -104,6 +110,7 @@ python -m torch.distributed.launch \ --layer1 1 2 1 \ --layer2 1 1 2 \ --layer3 1 1 2 \ + --fp16 \ > expdata_fp16/2gpu_2wp2tp.txt python -m torch.distributed.launch \ @@ -118,6 +125,7 @@ python -m torch.distributed.launch \ --layer1 1 4 1 \ --layer2 1 1 4 \ --layer3 1 1 4 \ + --fp16 \ > expdata_fp16/4gpu_4wp4tp.txt python -m torch.distributed.launch \ @@ -132,6 +140,7 @@ python -m torch.distributed.launch \ --layer1 1 1 8 \ --layer2 1 1 8 \ --layer3 1 1 8 \ + --fp16 \ > expdata_fp16/8gpu_8wp8tp.txt @@ -148,6 +157,7 @@ python -m torch.distributed.launch \ --layer1 2 1 \ --layer2 1 2 \ --layer3 1 2 \ + --fp16 \ > expdata_fp16/2gpu_dt_2dp2tp.txt python -m torch.distributed.launch \ @@ -162,6 +172,7 @@ python -m torch.distributed.launch \ --layer1 4 1 \ --layer2 1 4 \ --layer3 1 4 \ + --fp16 \ > expdata_fp16/4gpu_dt_4dp4tp.txt python -m torch.distributed.launch \ @@ -176,4 +187,5 @@ python -m torch.distributed.launch \ --layer1 8 1 \ --layer2 1 8 \ --layer3 1 8 \ + --fp16 \ > expdata_fp16/8gpu_dt_8dp8tp.txt From 8cc667ab7b5d2dfeaf252798138cf846dc8fc3dc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 10:00:35 +0800 Subject: [PATCH 0434/1892] switch to 384 resolution --- examples/swin/swin_dt.py | 15 ++++++++------- examples/swin/swin_dwt.py | 5 +++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/swin/swin_dt.py b/examples/swin/swin_dt.py index b6735836..931cb45e 100644 --- a/examples/swin/swin_dt.py +++ b/examples/swin/swin_dt.py @@ -10,8 +10,8 @@ --master_port=8004 \ --use_env \ examples/swin/swin_dt.py --bs 8 \ - --layer0 4 1 \ - --layer1 4 1 \ + --layer0 1 4 \ + --layer1 1 4 \ --layer2 1 4 \ --layer3 1 4 @@ -789,10 +789,10 @@ def train(args, pconfigs): # dim_head is always 32 # img resolution, windows size: 224, 384, 518, 640 - C, H, W, window_size = [3, 224, 224, 7] - # C, H, W, window_size = [3, 384, 384, 12] + # C, H, W, window_size = [3, 224, 224, 7] + C, H, W, window_size = [3, 384, 384, 12] # C, H, W, window_size = [3, 518, 518, ?] - # C, H, W, window_size = [4, 640, 640, 20] + # C, H, W, window_size = [3, 640, 640, 20] # image batch size N = args.bs @@ -804,7 +804,7 @@ def train(args, pconfigs): # SwinV2-B: 87 M # embed_dim, depths, num_heads = [ - # 128, [2, 2, 18, 2], [4, 8, 12, 24] + # 128, [2, 2, 18, 2], [4, 8, 16, 32] # ] # SwinV2-L: 196 M @@ -839,7 +839,8 @@ def train(args, pconfigs): # ] - model = SwinTransformer(embed_dim = embed_dim, + model = SwinTransformer(img_size = H, + embed_dim = embed_dim, depths = depths, num_heads = num_heads, window_size = window_size, diff --git a/examples/swin/swin_dwt.py b/examples/swin/swin_dwt.py index 1e6eb7e6..11d0ffcb 100644 --- a/examples/swin/swin_dwt.py +++ b/examples/swin/swin_dwt.py @@ -805,7 +805,7 @@ def train(args, pconfigs): # SwinV2-B: 87 M # embed_dim, depths, num_heads = [ - # 128, [2, 2, 18, 2], [4, 8, 12, 24] + # 128, [2, 2, 18, 2], [4, 8, 16, 32] # ] # SwinV2-L: 196 M @@ -840,7 +840,8 @@ def train(args, pconfigs): # ] - model = SwinTransformer(embed_dim = embed_dim, + model = SwinTransformer(img_size = H, + embed_dim = embed_dim, depths = depths, num_heads = num_heads, window_size = window_size, From 5bf2c74db9f95dd4856eb56e396f18992e12565b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 12:35:18 +0800 Subject: [PATCH 0435/1892] swin inference test --- eval/swin_infer.sh | 77 +++ examples/swin/swin_dwt_infer.py | 963 ++++++++++++++++++++++++++++++++ 2 files changed, 1040 insertions(+) create mode 100644 eval/swin_infer.sh create mode 100644 examples/swin/swin_dwt_infer.py diff --git a/eval/swin_infer.sh b/eval/swin_infer.sh new file mode 100644 index 00000000..65d28098 --- /dev/null +++ b/eval/swin_infer.sh @@ -0,0 +1,77 @@ +mkdir -p expinfer32 + +# ================== Maximal Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 2 \ + --layer1 1 1 2 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + > expinfer32/2gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 4 \ + --layer1 1 1 4 \ + --layer2 1 1 4 \ + --layer3 1 1 4 \ + > expinfer32/4gpu_tp.txt + + +# ================== Window + Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 2 1 \ + --layer1 1 2 1 \ + --layer2 1 2 1 \ + --layer3 1 1 2 \ + > expinfer32/2gpu_2wp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 4 1 \ + --layer1 1 4 1 \ + --layer2 1 4 1 \ + --layer3 1 1 4 \ + > expinfer32/4gpu_4wp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 8 1 \ + --layer1 1 1 8 \ + --layer2 1 1 8 \ + --layer3 1 1 8 \ + > expinfer32/8gpu_8wp8tp.txt + diff --git a/examples/swin/swin_dwt_infer.py b/examples/swin/swin_dwt_infer.py new file mode 100644 index 00000000..5f2918b2 --- /dev/null +++ b/examples/swin/swin_dwt_infer.py @@ -0,0 +1,963 @@ + +# -------------------------------------------------------- +# Modified from Swin-Transformer Repo +""" +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs 8 \ + --layer0 1 4 1 \ + --layer1 1 4 1 \ + --layer2 1 1 4 \ + --layer3 1 1 4 + +""" +# -------------------------------------------------------- + +from typing import Dict, Optional, Tuple +import torch +import torch.nn as nn +import argparse +import time + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary +from cube.runtime.device import DeviceGroup +from cube.runtime.reducer import Reducer + +from examples.swin.layers import ColumnParallelLinear, RowParallelLinear + + +_wp_reducer: Dict[Tuple[int], Reducer] = dict() +_dp_reducer: Dict[Tuple[int], Reducer] = dict() + + +def setup_device_group(tp: int, wp: int, dp: int, layer_id: int): + """ + Layer wise device group initialize + + Returns: + + """ + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + tp_size, tp_group_nums = tp, ndevs // tp + wp_size, wp_group_nums = wp, ndevs // wp + dp_size, dp_group_nums = dp, ndevs // dp + + if not tp_size * wp_size * dp_size == ndevs: + raise RuntimeError("Expected all devices are used") + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize tensor parallel groups + for i in range(tp_group_nums): + ranks = list(range(i * tp_size, (i + 1) * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + + # initialize wp parallel group + all_wp_parallel_group_ranks = list() + for i in range(dp_size): + start_rank = i * dp_group_nums + end_rank = (i + 1) * dp_group_nums + for j in range(tp_size): + ranks = list(range(start_rank + j, end_rank, tp_size)) + all_wp_parallel_group_ranks.append(ranks) + # initialize groups + group = devs.get_group(ranks) + if myrank in ranks: + wp_ranks = ranks + _wp_reducer[tuple(ranks)] = Reducer(ranks) + print_each_rank(f'layer {layer_id}: initialzed window parallel group: {wp_ranks}', rank_only=myrank) + + # initialize data parallel groups + start_rank = 0 + end_rank = ndevs + for i in range(wp_size * tp_size): + ranks = list(range(i, ndevs, wp_size * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + _dp_reducer[tuple(ranks)] = Reducer(ranks) + print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) + return tp_ranks, wp_ranks, dp_ranks + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class MegatronMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tp_group=-1): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=tp_group) + self.act = act_layer() + # self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=True, tp_group=tp_group) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +class MegatronWindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., tp_group=-1): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.global_num_heads = num_heads + + tp_world_size = torch.distributed.get_world_size(group=tp_group) + if num_heads % tp_world_size != 0: + raise RuntimeError(f'detecting un-even num head {num_heads} partition to {tp_world_size}') + self.num_heads = num_heads // tp_world_size + + self.dim_heads = dim // self.global_num_heads + self.scale = qk_scale or self.dim_heads ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # print(f'qkv embed dim: {dim}') + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) + self.attn_drop = nn.Dropout(attn_drop) + # self.proj = nn.Linear(dim, dim) + self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=True, tp_group=tp_group) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + relative_position_bias = self.relative_position_bias_table[relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + tp_group=-1, wp_plans=-1): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = MegatronWindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + tp_group=tp_group) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MegatronMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop, + tp_group=tp_group + ) + + self.wp_group, self.wp_nH_ranks, self.wp_nW_ranks = wp_plans + self.use_wp = torch.distributed.get_world_size(self.wp_group) != 1 + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + if self.use_wp: + shifted_x = cube.runtime.function.roll_grid_parallel( + x, (-self.shift_size, -self.shift_size), (1,2), + self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group + ) + else: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x + if self.shift_size > 0: + if self.use_wp: + x = cube.runtime.function.roll_grid_parallel( + shifted_x, (self.shift_size, self.shift_size), (1,2), + self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group + ) + else: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + # [B, H, W, C] -> [B, H * W, C] + x = x.view(B, H * W, C) + # [B, H * W, C] -> [B, H * W, C] + x = shortcut + drop_path(x, self.drop_path_p) + # FFN + # [B, H * W, C] -> [B, H * W, C] + ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] + ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] + x = x + drop_path(ffn, self.drop_path_p) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, + tp=1, wp=1, dp=1, layer_id=-1): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + self.resource = cube.runtime.resource.EnvResource() + tp_ranks, wp_ranks, dp_ranks = setup_device_group(tp, wp, dp, layer_id) + tp_group = DeviceGroup().get_group(tp_ranks) + wp_group = DeviceGroup().get_group(wp_ranks) + wp_nH_ranks = [-1] + wp_nW_ranks = [-1] + + # window parallel + self.wp_resolution = input_resolution + if wp > 1: + H, W = self.input_resolution + nH = 1 + nW = wp // nH + while nH <= nW: + if H % nH != 0 or W % nW != 0: + nW = nW // 2 + nH = int(nH * 2) + else: + break + if nH > nW: + raise RuntimeError(f"layer {layer_id}: Cannot window partition plan") + print_each_rank(f"layer {layer_id}: Find partition plan: H{H} // {nH}, W{W} // {nW}") + self.wp_resolution = (H // nH, W // nW) + self.wp_group = wp_group + # wp_group multi dim shift ranks + for i in range(nH): + ranks = list(range(i * nW, (i + 1) * nW)) + if torch.distributed.get_rank(wp_group) in ranks: + wp_nW_ranks = ranks + break + for i in range(nW): + ranks = list(range(i, wp, nW)) + if torch.distributed.get_rank(wp_group) in ranks: + wp_nH_ranks = ranks + break + assert wp_nH_ranks != [-1] + assert wp_nW_ranks != [-1] + print_each_rank(f'window parallel nH group ranks: {wp_nH_ranks}') + print_each_rank(f'window parallel nW group ranks: {wp_nW_ranks}') + + # build blocks + self.blocks = nn.ModuleList() + for i in range(depth): + block = SwinTransformerBlock( + dim=dim, input_resolution=self.wp_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + tp_group=tp_group, wp_plans=(wp_group, wp_nH_ranks, wp_nW_ranks) + ) + self.blocks.append(block) + + self.wp_preprocess = False + self.wp_postprocess = False + if wp > 1: + for param in self.blocks.parameters(): + _wp_reducer[tuple(wp_ranks)].add_param(param) + self.wp_preprocess = True + self.wp_postprocess = True + + def forward(self, x): + if self.wp_preprocess: + oH, oW = self.input_resolution + pH, pW = self.wp_resolution + x = x.view(-1, oH, oW, self.dim) + x = cube.runtime.function.grid_partition(x, oH // pH, oW // pW, group=self.wp_group) + x = x.view(-1, pH * pW, self.dim).contiguous() + + for blk in self.blocks: + x = blk(x) + + if self.wp_postprocess: + oH, oW = self.input_resolution + pH, pW = self.wp_resolution + x = x.view(-1, pH, pW, self.dim) + x = cube.runtime.function.grid_collection(x, oH // pH, oW // pW, group=self.wp_group) + x = x.view(-1, oH * oW, self.dim) + return x + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + + # ====================== depth 0 =========================== + pconfig = pconfigs[0] + input_resolution = ( + patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) + ) + self.basic_layer0 = BasicLayer( + dim=int(embed_dim * 2 ** 0), + input_resolution=input_resolution, + depth=depths[0], + num_heads=num_heads[0], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:0]):sum(depths[:0 + 1])], + norm_layer=norm_layer, + **pconfig, + ) + + self.merging0 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 0), norm_layer=norm_layer + ) + + # ====================== depth 1 =========================== + pconfig = pconfigs[1] + input_resolution = ( + patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) + ) + self.basic_layer1 = BasicLayer( + dim=int(embed_dim * 2 ** 1), + input_resolution=input_resolution, + depth=depths[1], + num_heads=num_heads[1], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:1]):sum(depths[:1 + 1])], + norm_layer=norm_layer, + **pconfig, + ) + + self.merging1 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 1), norm_layer=norm_layer + ) + + + # ====================== depth 2 =========================== + pconfig = pconfigs[2] + input_resolution = ( + patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) + ) + self.basic_layer2 = BasicLayer( + dim=int(embed_dim * 2 ** 2), + input_resolution=input_resolution, + depth=depths[2], + num_heads=num_heads[2], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:2]):sum(depths[:2 + 1])], + norm_layer=norm_layer, + **pconfig + ) + + self.merging2 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 2), norm_layer=norm_layer + ) + + # ====================== depth 3 =========================== + pconfig = pconfigs[3] + self.basic_layer3 = BasicLayer( + dim=int(embed_dim * 2 ** 3), + input_resolution=(patches_resolution[0] // (2 ** 3), + patches_resolution[1] // (2 ** 3)), + depth=depths[3], + num_heads=num_heads[3], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:3]):sum(depths[:3 + 1])], + norm_layer=norm_layer, + **pconfig + ) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x) + + CudaTimer().start('basic_layer0') + x = self.basic_layer0(x) + CudaTimer().stop('basic_layer0') + x = self.merging0(x) + CudaTimer().start('basic_layer1') + x = self.basic_layer1(x) + CudaTimer().stop('basic_layer1') + x = self.merging1(x) + CudaTimer().start('basic_layer2') + x = self.basic_layer2(x) + CudaTimer().stop('basic_layer2') + x = self.merging2(x) + CudaTimer().start('basic_layer3') + x = self.basic_layer3(x) + CudaTimer().stop('basic_layer3') + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C L + x = torch.flatten(x, 1) + + x = self.head(x) + return x + + +def train(args, pconfigs): + + # dim_head is always 32 + + # img resolution, windows size: 224, 384, 518, 640 + # C, H, W, window_size = [3, 224, 224, 7] + # C, H, W, window_size = [3, 384, 384, 12] + # C, H, W, window_size = [3, 518, 518, ?] + C, H, W, window_size = [3, 640, 640, 20] + + # image batch size + N = args.bs + + # Swin-Tiny + # embed_dim, depths, num_heads = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24] + # ] + + # SwinV2-B: 87 M + # embed_dim, depths, num_heads = [ + # 128, [2, 2, 18, 2], [4, 8, 16, 32] + # ] + + # SwinV2-L: 196 M + # embed_dim, depths, num_heads = [ + # 192, [2, 2, 18, 2], [6, 12, 24, 48] + # ] + + # SwinV2-H: 657 M + # embed_dim, depths, num_heads = [ + # 352, [2, 2, 18, 2], [11, 22, 44, 88] + # ] + + # SwinV2-H modified: 782 M + embed_dim, depths, num_heads = [ + 384, [2, 2, 18, 2], [12, 24, 48, 96] + ] + + # SwinV2-G: 2.5B Model + # embed_dim, depths, num_heads = [ + # 512, [2, 2, 42, 2], [16, 32, 64, 128] + # ] + + # 895.7 M Model + # embed_dim, depths, num_heads = [ + # 384, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + + # 2.01B model + # embed_dim, depths, num_heads = [ + # 576, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + + model = SwinTransformer(img_size = H, + embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size, + pconfigs = pconfigs) + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + if args.fp16: + print_each_rank('use half model') + model = model.half() + model = model.cuda() + memory_summary() + + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [N // args.dp, C, H, W]) + + if args.fp16: + data_buff = [[e.half() for e in data] for data in dataloader.datas] + dataloader.datas = data_buff + + if args.dp > 1: + assert len(_dp_reducer) == 1 + reducer = None + for ranks in _dp_reducer: + reducer = _dp_reducer[ranks] + for param in model.parameters(): + reduced = False + for wp_ranks in _wp_reducer: + if param in _wp_reducer[wp_ranks]._params: + reduced = True + break + if not reduced: + reducer.add_param(param) + + def infer_iter(model, dataloader): + with torch.no_grad(): + img = next(dataloader) + loss = model(img) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + # start training + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + CudaTimer(enable=False).warmup() + torch.distributed.barrier() + span = 0 + iter_num = 128 + for step in range(iter_num): + if step >= 40: + torch.cuda.synchronize() + start = time.time() + CudaTimer(enable=True).start('e2e') + infer_iter(model, dataloader) + if step == 1: + print('> passed on 1st iteration') + memory_summary() + if step >= 40: + torch.cuda.synchronize() + stop = time.time() + span += (stop - start) * 1000 + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = span / (iter_num-40) + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + CudaTimer().print_all(times=iter_num-40) + + +if __name__ == '__main__': + + cube.init() + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--layer0', type=int, nargs='+', + help='data, window tensor parallel config') + parser.add_argument('--layer1', type=int, nargs='+', + help='data, window tensor parallel config') + parser.add_argument('--layer2', type=int, nargs='+', + help='data, window tensor parallel config') + parser.add_argument('--layer3', type=int, nargs='+', + help='data, window tensor parallel config') + parser.add_argument('--bs', type=int, default=1, + help='bs') + parser.add_argument('--fp16', action='store_true', dest='fp16') + args = parser.parse_args() + + assert len(args.layer0) == 3 + assert len(args.layer1) == 3 + assert len(args.layer2) == 3 + assert len(args.layer3) == 3 + + # data parallel should be same + assert args.layer0[0] == args.layer1[0] and args.layer1[0] == args.layer2[0] and args.layer2[0] == args.layer3[0] + args.dp = args.layer0[0] + + pconfigs = [ + dict(layer_id=0, dp=args.layer0[0], wp=args.layer0[1], tp=args.layer0[2]), # basic layer 0 + dict(layer_id=1, dp=args.layer1[0], wp=args.layer1[1], tp=args.layer1[2]), # basic layer 1 + dict(layer_id=2, dp=args.layer2[0], wp=args.layer2[1], tp=args.layer2[2]), # basic layer 2 # prob at 8:1? + dict(layer_id=3, dp=args.layer3[0], wp=args.layer3[1], tp=args.layer3[2]), # basic layer 3 + ] + + # pconfigs = [ + # dict(layer_id=0, tp=4, wp=1, dp=args.dp), # basic layer 0 + # dict(layer_id=1, tp=4, wp=1, dp=args.dp), # basic layer 1 + # dict(layer_id=2, tp=4, wp=1, dp=args.dp), # basic layer 2 # prob at 8:1? + # dict(layer_id=3, tp=4, wp=1, dp=args.dp), # basic layer 3 + # ] + + print_each_rank(pconfigs, rank_only=0) + train(args, pconfigs) From 2854169337c77f8725b78cf5ffaf9b8b0e6038f1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 12:39:36 +0800 Subject: [PATCH 0436/1892] swin infer script --- eval/swin_infer.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) mode change 100644 => 100755 eval/swin_infer.sh diff --git a/eval/swin_infer.sh b/eval/swin_infer.sh old mode 100644 new mode 100755 index 65d28098..1efa659d --- a/eval/swin_infer.sh +++ b/eval/swin_infer.sh @@ -70,8 +70,8 @@ python -m torch.distributed.launch \ --use_env \ examples/swin/swin_dwt_infer.py --bs 1 \ --layer0 1 8 1 \ - --layer1 1 1 8 \ - --layer2 1 1 8 \ + --layer1 1 8 1 \ + --layer2 1 4 2 \ --layer3 1 1 8 \ > expinfer32/8gpu_8wp8tp.txt From 0401c539244b290e8143d9387b6b6fb0adaecf73 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 09:01:52 +0000 Subject: [PATCH 0437/1892] swin infer --- eval/swin_infer.sh | 71 ++++++++++++++++------- eval/swin_infer_bs1_640_Gfp16.sh | 98 ++++++++++++++++++++++++++++++++ examples/swin/swin_dt.py | 16 +++--- examples/swin/swin_dwt_infer.py | 49 ++++++---------- 4 files changed, 174 insertions(+), 60 deletions(-) create mode 100755 eval/swin_infer_bs1_640_Gfp16.sh diff --git a/eval/swin_infer.sh b/eval/swin_infer.sh index 1efa659d..13f468df 100755 --- a/eval/swin_infer.sh +++ b/eval/swin_infer.sh @@ -1,7 +1,21 @@ -mkdir -p expinfer32 +mkdir -p expinfer32_2.6B_fp16_bs4 # ================== Maximal Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 1 \ + --layer1 1 1 1 \ + --layer2 1 1 1 \ + --layer3 1 1 1 \ + > expinfer32_2.6B_fp16_bs4/1gpu_tp.txt + python -m torch.distributed.launch \ --nproc_per_node=2 \ --nnodes=1 \ @@ -14,7 +28,7 @@ python -m torch.distributed.launch \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > expinfer32/2gpu_tp.txt + > expinfer32_2.6B_fp16_bs4/2gpu_tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -28,7 +42,22 @@ python -m torch.distributed.launch \ --layer1 1 1 4 \ --layer2 1 1 4 \ --layer3 1 1 4 \ - > expinfer32/4gpu_tp.txt + > expinfer32_2.6B_fp16_bs4/4gpu_tp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 4 1 2 \ + --layer1 4 1 2 \ + --layer2 4 1 2 \ + --layer3 4 1 2 \ + > expinfer32_2.6B_fp16_bs4/8gpu_tp.txt # ================== Window + Tensor Parallel =============== @@ -40,12 +69,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 2 1 \ - --layer1 1 2 1 \ - --layer2 1 2 1 \ - --layer3 1 1 2 \ - > expinfer32/2gpu_2wp2tp.txt + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 2 1 1 \ + --layer1 2 1 1 \ + --layer2 2 1 1 \ + --layer3 2 1 1 \ + > expinfer32_2.6B_fp16_bs4/2gpu_2wp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -54,12 +83,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 4 1 \ - --layer1 1 4 1 \ - --layer2 1 4 1 \ - --layer3 1 1 4 \ - > expinfer32/4gpu_4wp4tp.txt + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 4 1 1 \ + --layer1 4 1 1 \ + --layer2 4 1 1 \ + --layer3 4 1 1 \ + > expinfer32_2.6B_fp16_bs4/4gpu_4wp4tp.txt python -m torch.distributed.launch \ --nproc_per_node=8 \ @@ -68,10 +97,10 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 8 1 \ - --layer1 1 8 1 \ - --layer2 1 4 2 \ - --layer3 1 1 8 \ - > expinfer32/8gpu_8wp8tp.txt + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 2 4 1 \ + --layer1 2 4 1 \ + --layer2 2 4 1 \ + --layer3 2 1 4 \ + > expinfer32_2.6B_fp16_bs4/8gpu_8wp8tp.txt diff --git a/eval/swin_infer_bs1_640_Gfp16.sh b/eval/swin_infer_bs1_640_Gfp16.sh new file mode 100755 index 00000000..1a160cf6 --- /dev/null +++ b/eval/swin_infer_bs1_640_Gfp16.sh @@ -0,0 +1,98 @@ +mkdir -p expinfer_Gfp16_bs1 + +# ================== Maximal Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 2 \ + --layer1 1 1 2 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + --fp16 \ + > expinfer_Gfp16_bs1/2gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 4 \ + --layer1 1 1 4 \ + --layer2 1 1 4 \ + --layer3 1 1 4 \ + --fp16 \ + > expinfer_Gfp16_bs1/4gpu_tp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 8 \ + --layer1 1 1 8 \ + --layer2 1 1 8 \ + --layer3 1 1 8 \ + --fp16 \ + > expinfer_Gfp16_bs1/8gpu_tp.txt + + +# ================== Window + Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 2 \ + --layer1 1 1 2 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + --fp16 \ + > expinfer_Gfp16_bs1/2gpu_2wp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 4 1 \ + --layer1 1 4 1 \ + --layer2 1 4 1 \ + --layer3 1 1 4 \ + --fp16 \ + > expinfer_Gfp16_bs1/4gpu_4wp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 8 1 \ + --layer1 1 8 1 \ + --layer2 1 4 2 \ + --layer3 2 1 4 \ + --fp16 \ + > expinfer_Gfp16_bs1/8gpu_8wp8tp.txt + diff --git a/examples/swin/swin_dt.py b/examples/swin/swin_dt.py index 931cb45e..03347f65 100644 --- a/examples/swin/swin_dt.py +++ b/examples/swin/swin_dt.py @@ -3,17 +3,17 @@ # Modified from Swin-Transformer Repo """ python -m torch.distributed.launch \ - --nproc_per_node=4 \ + --nproc_per_node=2 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dt.py --bs 8 \ - --layer0 1 4 \ - --layer1 1 4 \ - --layer2 1 4 \ - --layer3 1 4 + examples/swin/swin_dt.py --bs 2 \ + --layer0 1 2 \ + --layer1 1 2 \ + --layer2 1 2 \ + --layer3 1 2 """ # -------------------------------------------------------- @@ -789,8 +789,8 @@ def train(args, pconfigs): # dim_head is always 32 # img resolution, windows size: 224, 384, 518, 640 - # C, H, W, window_size = [3, 224, 224, 7] - C, H, W, window_size = [3, 384, 384, 12] + C, H, W, window_size = [3, 224, 224, 7] + # C, H, W, window_size = [3, 384, 384, 12] # C, H, W, window_size = [3, 518, 518, ?] # C, H, W, window_size = [3, 640, 640, 20] diff --git a/examples/swin/swin_dwt_infer.py b/examples/swin/swin_dwt_infer.py index 5f2918b2..b5328c55 100644 --- a/examples/swin/swin_dwt_infer.py +++ b/examples/swin/swin_dwt_infer.py @@ -794,6 +794,7 @@ def train(args, pconfigs): # C, H, W, window_size = [3, 384, 384, 12] # C, H, W, window_size = [3, 518, 518, ?] C, H, W, window_size = [3, 640, 640, 20] + # C, H, W, window_size = [3, 1536, 1536, 48] # image batch size N = args.bs @@ -819,15 +820,15 @@ def train(args, pconfigs): # ] # SwinV2-H modified: 782 M - embed_dim, depths, num_heads = [ - 384, [2, 2, 18, 2], [12, 24, 48, 96] - ] - - # SwinV2-G: 2.5B Model # embed_dim, depths, num_heads = [ - # 512, [2, 2, 42, 2], [16, 32, 64, 128] + # 384, [2, 2, 18, 2], [12, 24, 48, 96] # ] + # SwinV2-G: 2.5B Model + embed_dim, depths, num_heads = [ + 512, [2, 2, 42, 2], [16, 32, 64, 128] + ] + # 895.7 M Model # embed_dim, depths, num_heads = [ # 384, [2, 2, 22, 2], [12, 24, 48, 96] @@ -839,6 +840,11 @@ def train(args, pconfigs): # 576, [2, 2, 22, 2], [12, 24, 48, 96] # ] + print_each_rank( + f'Test setting: Resolution {H}, Embed {embed_dim}, depths: {depths}, heads: {num_heads}' + rank_only=0 + ) + model = SwinTransformer(img_size = H, embed_dim = embed_dim, @@ -862,20 +868,6 @@ def train(args, pconfigs): data_buff = [[e.half() for e in data] for data in dataloader.datas] dataloader.datas = data_buff - if args.dp > 1: - assert len(_dp_reducer) == 1 - reducer = None - for ranks in _dp_reducer: - reducer = _dp_reducer[ranks] - for param in model.parameters(): - reduced = False - for wp_ranks in _wp_reducer: - if param in _wp_reducer[wp_ranks]._params: - reduced = True - break - if not reduced: - reducer.add_param(param) - def infer_iter(model, dataloader): with torch.no_grad(): img = next(dataloader) @@ -890,9 +882,9 @@ def infer_iter(model, dataloader): CudaTimer(enable=False).warmup() torch.distributed.barrier() span = 0 - iter_num = 128 + iter_num = 60 for step in range(iter_num): - if step >= 40: + if step >= 20: torch.cuda.synchronize() start = time.time() CudaTimer(enable=True).start('e2e') @@ -900,7 +892,7 @@ def infer_iter(model, dataloader): if step == 1: print('> passed on 1st iteration') memory_summary() - if step >= 40: + if step >= 20: torch.cuda.synchronize() stop = time.time() span += (stop - start) * 1000 @@ -908,13 +900,13 @@ def infer_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = span / (iter_num-40) + iter_time = span / (iter_num-20) throughput = N / iter_time * 1000 print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) memory_summary() - CudaTimer().print_all(times=iter_num-40) + CudaTimer().print_all(times=iter_num-20) if __name__ == '__main__': @@ -952,12 +944,7 @@ def infer_iter(model, dataloader): dict(layer_id=3, dp=args.layer3[0], wp=args.layer3[1], tp=args.layer3[2]), # basic layer 3 ] - # pconfigs = [ - # dict(layer_id=0, tp=4, wp=1, dp=args.dp), # basic layer 0 - # dict(layer_id=1, tp=4, wp=1, dp=args.dp), # basic layer 1 - # dict(layer_id=2, tp=4, wp=1, dp=args.dp), # basic layer 2 # prob at 8:1? - # dict(layer_id=3, tp=4, wp=1, dp=args.dp), # basic layer 3 - # ] + args.fp16 = True print_each_rank(pconfigs, rank_only=0) train(args, pconfigs) From 4f3f9e2109be01d4f8d108b88b02d454b4395d93 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 09:04:58 +0000 Subject: [PATCH 0438/1892] infer script --- eval/swin_infer_bs1_640_Gfp16.sh | 20 +++++++++++++++++--- examples/swin/swin_dwt_infer.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/eval/swin_infer_bs1_640_Gfp16.sh b/eval/swin_infer_bs1_640_Gfp16.sh index 1a160cf6..504a8363 100755 --- a/eval/swin_infer_bs1_640_Gfp16.sh +++ b/eval/swin_infer_bs1_640_Gfp16.sh @@ -1,6 +1,20 @@ mkdir -p expinfer_Gfp16_bs1 # ================== Maximal Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 1 \ + --layer1 1 1 1 \ + --layer2 1 1 1 \ + --layer3 1 1 1 \ + --fp16 \ + > expinfer_Gfp16_bs1/1gpu_tp.txt python -m torch.distributed.launch \ --nproc_per_node=2 \ @@ -59,9 +73,9 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 2 \ - --layer1 1 1 2 \ - --layer2 1 1 2 \ + --layer0 1 2 1 \ + --layer1 1 2 1 \ + --layer2 1 2 1 \ --layer3 1 1 2 \ --fp16 \ > expinfer_Gfp16_bs1/2gpu_2wp2tp.txt diff --git a/examples/swin/swin_dwt_infer.py b/examples/swin/swin_dwt_infer.py index b5328c55..e3f0e6c6 100644 --- a/examples/swin/swin_dwt_infer.py +++ b/examples/swin/swin_dwt_infer.py @@ -841,7 +841,7 @@ def train(args, pconfigs): # ] print_each_rank( - f'Test setting: Resolution {H}, Embed {embed_dim}, depths: {depths}, heads: {num_heads}' + f'Test setting: Resolution {H}, Embed {embed_dim}, depths: {depths}, heads: {num_heads}', rank_only=0 ) From 28d08d4087d0df0b6ae09793cf6ee341cc36c914 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 20:04:54 +0800 Subject: [PATCH 0439/1892] avoid get world size --- examples/swin/layers.py | 58 +++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/examples/swin/layers.py b/examples/swin/layers.py index 65790dd4..61b70932 100644 --- a/examples/swin/layers.py +++ b/examples/swin/layers.py @@ -164,40 +164,37 @@ def __init__(self, input_size, output_size, bias=True, in_adapter=True, out_adap self.out_adapter = out_adapter self.group = tp_group - world_size = torch.distributed.get_world_size(group=self.group) + self.world_size = torch.distributed.get_world_size(group=self.group) # print_each_rank(f'> parallizing linear using column partition: ' # f'{output_size} partitioned by {world_size} devices') # not if output size is smaller than world size, # no parallel enbaled. Each device compute the same - if world_size > output_size: - world_size = 1 + if self.world_size > output_size: + raise RuntimeError self.weight = Parameter(torch.empty( - int(self.output_size // world_size), + int(self.output_size // self.world_size), self.input_size, )) if bias: self.bias = Parameter(torch.empty( - int(self.output_size // world_size), + int(self.output_size // self.world_size), )) - with torch.no_grad(): - self.bias.zero_() else: - self.register_parameter('bias', None) + self.bias = None def forward(self, input_): - bias = self.bias - if self.in_adapter: - input_parallel = ColumnInputAdapter.apply(input_, self.group) - else: - input_parallel = input_ - output_parallel = F.linear(input_parallel, self.weight, bias) - if self.out_adapter: - output = ColumnOutputAdapter.apply(output_parallel, self.group) - else: - output = output_parallel + + if self.in_adapter and self.world_size > 1: + input_ = ColumnInputAdapter.apply(input_, self.group) + + output = F.linear(input_, self.weight, self.bias) + + if self.out_adapter and self.world_size > 1: + output = ColumnOutputAdapter.apply(output, self.group) + return output @@ -212,19 +209,19 @@ def __init__(self, input_size, output_size, bias=True, in_adapter=True, out_adap self.out_adapter = out_adapter self.group = tp_group - world_size = torch.distributed.get_world_size(group=self.group) + self.world_size = torch.distributed.get_world_size(group=self.group) # print_each_rank(f'> parallizing linear using row partition: ' # f'{output_size} partitioned by {world_size} devices') # not if output size is smaller than world size, # no parallel enbaled. Each device compute the same - if world_size > output_size: - world_size = 1 + if self.world_size > input_size: + raise RuntimeError self.weight = Parameter(torch.empty( self.output_size, - int(self.input_size // world_size), + int(self.input_size // self.world_size), )) if bias: self.bias = Parameter(torch.empty(self.output_size)) @@ -235,15 +232,14 @@ def __init__(self, input_size, output_size, bias=True, in_adapter=True, out_adap def forward(self, input_): bias = self.bias - if self.in_adapter: - input_parallel = RowInputAdapter.apply(input_, self.group) - else: - input_parallel = input_ - output_parallel = F.linear(input_parallel, self.weight, bias) - if self.out_adapter: - output = RowOutputAdapter.apply(output_parallel, self.group) - else: - output = output_parallel + if self.in_adapter and self.world_size > 1: + input_ = RowInputAdapter.apply(input_, self.group) + + output = F.linear(input_, self.weight, bias) + + if self.out_adapter and self.world_size > 1: + output = RowOutputAdapter.apply(output, self.group) + return output From bd4e1eb5c020494948f27aab31d2b836389cb962 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 29 Nov 2021 12:49:42 +0000 Subject: [PATCH 0440/1892] add infer script --- eval/swin_infer_bs1_640_Gfp16.sh | 2 +- eval/swin_infer_bs2_640_Gfp16.sh | 112 +++++++++++++++++++++++++++++++ eval/swin_infer_bs4_640_Gfp16.sh | 112 +++++++++++++++++++++++++++++++ 3 files changed, 225 insertions(+), 1 deletion(-) create mode 100755 eval/swin_infer_bs2_640_Gfp16.sh create mode 100755 eval/swin_infer_bs4_640_Gfp16.sh diff --git a/eval/swin_infer_bs1_640_Gfp16.sh b/eval/swin_infer_bs1_640_Gfp16.sh index 504a8363..7fefa488 100755 --- a/eval/swin_infer_bs1_640_Gfp16.sh +++ b/eval/swin_infer_bs1_640_Gfp16.sh @@ -106,7 +106,7 @@ python -m torch.distributed.launch \ --layer0 1 8 1 \ --layer1 1 8 1 \ --layer2 1 4 2 \ - --layer3 2 1 4 \ + --layer3 1 1 8 \ --fp16 \ > expinfer_Gfp16_bs1/8gpu_8wp8tp.txt diff --git a/eval/swin_infer_bs2_640_Gfp16.sh b/eval/swin_infer_bs2_640_Gfp16.sh new file mode 100755 index 00000000..cb1540aa --- /dev/null +++ b/eval/swin_infer_bs2_640_Gfp16.sh @@ -0,0 +1,112 @@ +mkdir -p expinfer_Gfp16_bs2 + +# ================== Maximal Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 1 1 1 \ + --layer1 1 1 1 \ + --layer2 1 1 1 \ + --layer3 1 1 1 \ + --fp16 \ + > expinfer_Gfp16_bs2/1gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 1 1 \ + --layer1 2 1 1 \ + --layer2 2 1 1 \ + --layer3 2 1 1 \ + --fp16 \ + > expinfer_Gfp16_bs2/2gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 1 2 \ + --layer1 2 1 2 \ + --layer2 2 1 2 \ + --layer3 2 1 2 \ + --fp16 \ + > expinfer_Gfp16_bs2/4gpu_tp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 1 4 \ + --layer1 2 1 4 \ + --layer2 2 1 4 \ + --layer3 2 1 4 \ + --fp16 \ + > expinfer_Gfp16_bs2/8gpu_tp.txt + + +# ================== Window + Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 1 2 1 \ + --layer1 1 2 1 \ + --layer2 1 2 1 \ + --layer3 1 1 2 \ + --fp16 \ + > expinfer_Gfp16_bs2/2gpu_2wp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 1 4 1 \ + --layer1 1 4 1 \ + --layer2 1 4 1 \ + --layer3 1 1 4 \ + --fp16 \ + > expinfer_Gfp16_bs2/4gpu_4wp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 1 8 1 \ + --layer1 1 8 1 \ + --layer2 1 4 2 \ + --layer3 1 1 8 \ + --fp16 \ + > expinfer_Gfp16_bs2/8gpu_8wp8tp.txt + diff --git a/eval/swin_infer_bs4_640_Gfp16.sh b/eval/swin_infer_bs4_640_Gfp16.sh new file mode 100755 index 00000000..74eb8195 --- /dev/null +++ b/eval/swin_infer_bs4_640_Gfp16.sh @@ -0,0 +1,112 @@ +mkdir -p expinfer_Gfp16_bs4 + +# ================== Maximal Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 1 1 1 \ + --layer1 1 1 1 \ + --layer2 1 1 1 \ + --layer3 1 1 1 \ + --fp16 \ + > expinfer_Gfp16_bs4/1gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 2 1 1 \ + --layer1 2 1 1 \ + --layer2 2 1 1 \ + --layer3 2 1 1 \ + --fp16 \ + > expinfer_Gfp16_bs4/2gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 4 1 1 \ + --layer1 4 1 1 \ + --layer2 4 1 1 \ + --layer3 4 1 1 \ + --fp16 \ + > expinfer_Gfp16_bs4/4gpu_tp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 4 1 2 \ + --layer1 4 1 2 \ + --layer2 4 1 2 \ + --layer3 4 1 2 \ + --fp16 \ + > expinfer_Gfp16_bs4/8gpu_tp.txt + + +# ================== Window + Tensor Parallel =============== + +# python -m torch.distributed.launch \ +# --nproc_per_node=2 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt_infer.py --bs 4 \ +# --layer0 4 1 1 \ +# --layer1 4 1 1 \ +# --layer2 4 1 1 \ +# --layer3 4 1 1 \ +# --fp16 \ +# > expinfer_Gfp16_bs4/2gpu_2wp2tp.txt +# +# python -m torch.distributed.launch \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt_infer.py --bs 4 \ +# --layer0 4 1 1 \ +# --layer1 4 1 1 \ +# --layer2 4 1 1 \ +# --layer3 4 1 1 \ +# --fp16 \ +# > expinfer_Gfp16_bs4/4gpu_4wp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 4 \ + --layer0 4 2 1 \ + --layer1 4 2 1 \ + --layer2 4 2 1 \ + --layer3 4 1 2 \ + --fp16 \ + > expinfer_Gfp16_bs4/8gpu_8wp8tp.txt + From f0ea99ed43c3f581d62502a99865c2bed7fd76e3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 01:55:40 +0000 Subject: [PATCH 0441/1892] fix cpu code bugs --- examples/swin/swin_dt.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/swin/swin_dt.py b/examples/swin/swin_dt.py index 03347f65..7da2d339 100644 --- a/examples/swin/swin_dt.py +++ b/examples/swin/swin_dt.py @@ -10,11 +10,10 @@ --master_port=8004 \ --use_env \ examples/swin/swin_dt.py --bs 2 \ - --layer0 1 2 \ - --layer1 1 2 \ - --layer2 1 2 \ - --layer3 1 2 - + --layer0 2 1 \ + --layer1 2 1 \ + --layer2 2 1 \ + --layer3 2 1 """ # -------------------------------------------------------- @@ -185,6 +184,10 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # relative position index + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + self.register_buffer('relative_position_index', relative_position_index) + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) self.attn_drop = nn.Dropout(attn_drop) @@ -208,8 +211,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): q = q * self.scale attn = (q @ k.transpose(-2, -1)) - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - relative_position_bias = self.relative_position_bias_table[relative_position_index] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # [Wh * Ww, Wh * Ww, nH] relative_position_bias = relative_position_bias.view( self.window_size[0] * self.window_size[1], From 1f85d4c3095a1657a8c53e14e81208f627065aa2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 03:24:35 +0000 Subject: [PATCH 0442/1892] swin dwt bug fix --- examples/swin/swin_dwt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/swin/swin_dwt.py b/examples/swin/swin_dwt.py index 11d0ffcb..3ae2e785 100644 --- a/examples/swin/swin_dwt.py +++ b/examples/swin/swin_dwt.py @@ -207,6 +207,10 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # relative position index + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + self.register_buffer('relative_position_index', relative_position_index) + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # print(f'qkv embed dim: {dim}') self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) @@ -231,8 +235,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): q = q * self.scale attn = (q @ k.transpose(-2, -1)) - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - relative_position_bias = self.relative_position_bias_table[relative_position_index] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # [Wh * Ww, Wh * Ww, nH] relative_position_bias = relative_position_bias.view( self.window_size[0] * self.window_size[1], From 9e3f1006a689b5d9d57ca53aadad56ae48ff98c1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 03:29:31 +0000 Subject: [PATCH 0443/1892] fp32 test --- eval/{swin.sh => swin_train_fp32.sh} | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) rename eval/{swin.sh => swin_train_fp32.sh} (87%) diff --git a/eval/swin.sh b/eval/swin_train_fp32.sh similarity index 87% rename from eval/swin.sh rename to eval/swin_train_fp32.sh index fb2aab71..2bf31b74 100755 --- a/eval/swin.sh +++ b/eval/swin_train_fp32.sh @@ -1,4 +1,4 @@ -mkdir -p expdata +mkdir -p exptrain_782M_bs8_fp32 # ================== Megatron Policy Parallel =============== @@ -14,7 +14,7 @@ python -m torch.distributed.launch \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > expdata/2gpu_dp.txt + > exptrain_782M_bs8_fp32/2gpu_maxdp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -28,7 +28,7 @@ python -m torch.distributed.launch \ --layer1 2 1 2 \ --layer2 2 1 2 \ --layer3 2 1 2 \ - > expdata/4gpu_dp.txt + > exptrain_782M_bs8_fp32/4gpu_maxdp.txt python -m torch.distributed.launch \ @@ -43,7 +43,7 @@ python -m torch.distributed.launch \ --layer1 4 1 2 \ --layer2 4 1 2 \ --layer3 4 1 2 \ - > expdata/8gpu_dp.txt + > exptrain_782M_bs8_fp32/8gpu_maxdp.txt # ================== Maximal Tensor Parallel =============== @@ -59,7 +59,7 @@ python -m torch.distributed.launch \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > expdata/2gpu_tp.txt + > exptrain_782M_bs8_fp32/2gpu_maxtp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -73,7 +73,7 @@ python -m torch.distributed.launch \ --layer1 1 1 4 \ --layer2 1 1 4 \ --layer3 1 1 4 \ - > expdata/4gpu_tp.txt + > exptrain_782M_bs8_fp32/4gpu_maxtp.txt python -m torch.distributed.launch \ @@ -88,7 +88,7 @@ python -m torch.distributed.launch \ --layer1 2 1 4 \ --layer2 2 1 4 \ --layer3 2 1 4 \ - > expdata/8gpu_2dp4tp.txt + > exptrain_782M_bs8_fp32/8gpu_maxtp.txt # ================== Window + Tensor Parallel =============== @@ -104,7 +104,7 @@ python -m torch.distributed.launch \ --layer1 1 2 1 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > expdata/2gpu_2wp2tp.txt + > exptrain_782M_bs8_fp32/2gpu_2wp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -118,7 +118,7 @@ python -m torch.distributed.launch \ --layer1 1 4 1 \ --layer2 1 1 4 \ --layer3 1 1 4 \ - > expdata/4gpu_4wp4tp.txt + > exptrain_782M_bs8_fp32/4gpu_4wp4tp.txt python -m torch.distributed.launch \ --nproc_per_node=8 \ @@ -132,7 +132,7 @@ python -m torch.distributed.launch \ --layer1 1 1 8 \ --layer2 1 1 8 \ --layer3 1 1 8 \ - > expdata/8gpu_8wp8tp.txt + > exptrain_782M_bs8_fp32/8gpu_8wp8tp.txt # ================== Data + Tensor Parallel =============== @@ -148,7 +148,7 @@ python -m torch.distributed.launch \ --layer1 2 1 \ --layer2 1 2 \ --layer3 1 2 \ - > expdata/2gpu_dt_2dp2tp.txt + > exptrain_782M_bs8_fp32/2gpu_dt_2dp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -162,7 +162,7 @@ python -m torch.distributed.launch \ --layer1 4 1 \ --layer2 1 4 \ --layer3 1 4 \ - > expdata/4gpu_dt_4dp4tp.txt + > exptrain_782M_bs8_fp32/4gpu_dt_4dp4tp.txt python -m torch.distributed.launch \ --nproc_per_node=8 \ @@ -176,4 +176,4 @@ python -m torch.distributed.launch \ --layer1 8 1 \ --layer2 1 8 \ --layer3 1 8 \ - > expdata/8gpu_dt_8dp8tp.txt + > exptrain_782M_bs8_fp32/8gpu_dt_8dp8tp.txt From 216c0a46391f46406ee75916245cf1092e973c08 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 04:38:23 +0000 Subject: [PATCH 0444/1892] update train test script --- eval/{swin_fp16.sh => swin_train_fp16.sh} | 116 +++++++++++++++++----- eval/swin_train_fp32.sh | 52 +++++----- 2 files changed, 119 insertions(+), 49 deletions(-) rename eval/{swin_fp16.sh => swin_train_fp16.sh} (59%) diff --git a/eval/swin_fp16.sh b/eval/swin_train_fp16.sh similarity index 59% rename from eval/swin_fp16.sh rename to eval/swin_train_fp16.sh index 75808c81..b6ef102a 100755 --- a/eval/swin_fp16.sh +++ b/eval/swin_train_fp16.sh @@ -1,4 +1,8 @@ -mkdir -p expdata_fp16 +bs=$1 + +logfile=exptrain_782M_bs${bs}_fp32 + +mkdir -p ${logfile} # ================== Megatron Policy Parallel =============== @@ -9,13 +13,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 1 2 \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ --fp16 \ - > expdata_fp16/2gpu_dp.txt + > ${logfile}/2gpu_maxdp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -24,13 +28,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 2 1 2 \ --layer1 2 1 2 \ --layer2 2 1 2 \ --layer3 2 1 2 \ --fp16 \ - > expdata_fp16/4gpu_dp.txt + > ${logfile}/4gpu_maxdp2tp.txt python -m torch.distributed.launch \ @@ -40,13 +44,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 4 1 2 \ --layer1 4 1 2 \ --layer2 4 1 2 \ --layer3 4 1 2 \ --fp16 \ - > expdata_fp16/8gpu_dp.txt + > ${logfile}/8gpu_maxdp2tp.txt # ================== Maximal Tensor Parallel =============== @@ -57,13 +61,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 1 2 \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ --fp16 \ - > expdata_fp16/2gpu_tp.txt + > ${logfile}/2gpu_maxtp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -72,13 +76,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 1 4 \ --layer1 1 1 4 \ --layer2 1 1 4 \ --layer3 1 1 4 \ --fp16 \ - > expdata_fp16/4gpu_tp.txt + > ${logfile}/4gpu_maxtp.txt python -m torch.distributed.launch \ @@ -88,15 +92,29 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 2 1 4 \ --layer1 2 1 4 \ --layer2 2 1 4 \ --layer3 2 1 4 \ --fp16 \ - > expdata_fp16/8gpu_2dp4tp.txt + > ${logfile}/8gpu_maxtp.txt # ================== Window + Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs ${bs} \ + --layer0 1 1 1 \ + --layer1 1 1 1 \ + --layer2 1 1 1 \ + --layer3 1 1 1 \ + --fp16 \ + > ${logfile}/single.txt python -m torch.distributed.launch \ --nproc_per_node=2 \ @@ -105,13 +123,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 2 1 \ --layer1 1 2 1 \ --layer2 1 1 2 \ --layer3 1 1 2 \ --fp16 \ - > expdata_fp16/2gpu_2wp2tp.txt + > ${logfile}/2gpu_2wp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -120,13 +138,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 4 1 \ --layer1 1 4 1 \ --layer2 1 1 4 \ --layer3 1 1 4 \ --fp16 \ - > expdata_fp16/4gpu_4wp4tp.txt + > ${logfile}/4gpu_4wp4tp.txt python -m torch.distributed.launch \ --nproc_per_node=8 \ @@ -135,13 +153,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 8 1 \ --layer1 1 1 8 \ --layer2 1 1 8 \ --layer3 1 1 8 \ --fp16 \ - > expdata_fp16/8gpu_8wp8tp.txt + > ${logfile}/8gpu_8wp8tp.txt # ================== Data + Tensor Parallel =============== @@ -152,13 +170,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dt.py --bs 8 \ + examples/swin/swin_dt.py --bs ${bs} \ --layer0 2 1 \ --layer1 2 1 \ --layer2 1 2 \ --layer3 1 2 \ --fp16 \ - > expdata_fp16/2gpu_dt_2dp2tp.txt + > ${logfile}/2gpu_dt_2dp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -167,13 +185,13 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dt.py --bs 8 \ + examples/swin/swin_dt.py --bs ${bs} \ --layer0 4 1 \ --layer1 4 1 \ --layer2 1 4 \ --layer3 1 4 \ --fp16 \ - > expdata_fp16/4gpu_dt_4dp4tp.txt + > ${logfile}/4gpu_dt_4dp4tp.txt python -m torch.distributed.launch \ --nproc_per_node=8 \ @@ -182,10 +200,58 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dt.py --bs 8 \ + examples/swin/swin_dt.py --bs ${bs} \ --layer0 8 1 \ --layer1 8 1 \ --layer2 1 8 \ --layer3 1 8 \ --fp16 \ - > expdata_fp16/8gpu_dt_8dp8tp.txt + > ${logfile}/8gpu_dt_8dp8tp.txt + + +# ========================== Data Parallel ====================== # + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs ${bs} \ + --layer0 2 1 1 \ + --layer1 2 1 1 \ + --layer2 2 1 1 \ + --layer3 2 1 1 \ + --fp16 \ + > ${logfile}/2gpu_maxdp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs ${bs} \ + --layer0 4 1 1 \ + --layer1 4 1 1 \ + --layer2 4 1 1 \ + --layer3 4 1 1 \ + --fp16 \ + > ${logfile}/4gpu_maxdp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs ${bs} \ + --layer0 8 1 1 \ + --layer1 8 1 1 \ + --layer2 8 1 1 \ + --layer3 8 1 1 \ + --fp16 \ + > ${logfile}/8gpu_maxdp.txt \ No newline at end of file diff --git a/eval/swin_train_fp32.sh b/eval/swin_train_fp32.sh index 2bf31b74..8a15b796 100755 --- a/eval/swin_train_fp32.sh +++ b/eval/swin_train_fp32.sh @@ -1,4 +1,8 @@ -mkdir -p exptrain_782M_bs8_fp32 +bs=$1 + +logfile=exptrain_782M_bs${bs}_fp32 + +mkdir -p ${logfile} # ================== Megatron Policy Parallel =============== @@ -9,12 +13,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 1 2 \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > exptrain_782M_bs8_fp32/2gpu_maxdp.txt + > ${logfile}/2gpu_maxdp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -23,12 +27,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 2 1 2 \ --layer1 2 1 2 \ --layer2 2 1 2 \ --layer3 2 1 2 \ - > exptrain_782M_bs8_fp32/4gpu_maxdp.txt + > ${logfile}/4gpu_maxdp.txt python -m torch.distributed.launch \ @@ -38,12 +42,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 4 1 2 \ --layer1 4 1 2 \ --layer2 4 1 2 \ --layer3 4 1 2 \ - > exptrain_782M_bs8_fp32/8gpu_maxdp.txt + > ${logfile}/8gpu_maxdp.txt # ================== Maximal Tensor Parallel =============== @@ -54,12 +58,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 1 2 \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > exptrain_782M_bs8_fp32/2gpu_maxtp.txt + > ${logfile}/2gpu_maxtp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -68,12 +72,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 1 4 \ --layer1 1 1 4 \ --layer2 1 1 4 \ --layer3 1 1 4 \ - > exptrain_782M_bs8_fp32/4gpu_maxtp.txt + > ${logfile}/4gpu_maxtp.txt python -m torch.distributed.launch \ @@ -83,12 +87,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 2 1 4 \ --layer1 2 1 4 \ --layer2 2 1 4 \ --layer3 2 1 4 \ - > exptrain_782M_bs8_fp32/8gpu_maxtp.txt + > ${logfile}/8gpu_maxtp.txt # ================== Window + Tensor Parallel =============== @@ -99,12 +103,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 2 1 \ --layer1 1 2 1 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > exptrain_782M_bs8_fp32/2gpu_2wp2tp.txt + > ${logfile}/2gpu_2wp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -113,7 +117,7 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 4 1 \ --layer1 1 4 1 \ --layer2 1 1 4 \ @@ -127,12 +131,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + examples/swin/swin_dwt.py --bs ${bs} \ --layer0 1 8 1 \ --layer1 1 1 8 \ --layer2 1 1 8 \ --layer3 1 1 8 \ - > exptrain_782M_bs8_fp32/8gpu_8wp8tp.txt + > ${logfile}/8gpu_8wp8tp.txt # ================== Data + Tensor Parallel =============== @@ -143,12 +147,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dt.py --bs 8 \ + examples/swin/swin_dt.py --bs ${bs} \ --layer0 2 1 \ --layer1 2 1 \ --layer2 1 2 \ --layer3 1 2 \ - > exptrain_782M_bs8_fp32/2gpu_dt_2dp2tp.txt + > ${logfile}/2gpu_dt_2dp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -157,12 +161,12 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dt.py --bs 8 \ + examples/swin/swin_dt.py --bs ${bs} \ --layer0 4 1 \ --layer1 4 1 \ --layer2 1 4 \ --layer3 1 4 \ - > exptrain_782M_bs8_fp32/4gpu_dt_4dp4tp.txt + > ${logfile}/4gpu_dt_4dp4tp.txt python -m torch.distributed.launch \ --nproc_per_node=8 \ @@ -171,9 +175,9 @@ python -m torch.distributed.launch \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dt.py --bs 8 \ + examples/swin/swin_dt.py --bs ${bs} \ --layer0 8 1 \ --layer1 8 1 \ --layer2 1 8 \ --layer3 1 8 \ - > exptrain_782M_bs8_fp32/8gpu_dt_8dp8tp.txt + > ${logfile}/8gpu_dt_8dp8tp.txt From 90517cae19fe2ab92353c9966de19049abd9c4cf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 07:10:49 +0000 Subject: [PATCH 0445/1892] fix cpu code bug --- examples/swin/swin_dwt_infer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/swin/swin_dwt_infer.py b/examples/swin/swin_dwt_infer.py index e3f0e6c6..9c5600a7 100644 --- a/examples/swin/swin_dwt_infer.py +++ b/examples/swin/swin_dwt_infer.py @@ -207,6 +207,10 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # relative position index + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + self.register_buffer('relative_position_index', relative_position_index) + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # print(f'qkv embed dim: {dim}') self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) @@ -231,8 +235,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): q = q * self.scale attn = (q @ k.transpose(-2, -1)) - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - relative_position_bias = self.relative_position_bias_table[relative_position_index] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # [Wh * Ww, Wh * Ww, nH] relative_position_bias = relative_position_bias.view( self.window_size[0] * self.window_size[1], @@ -868,13 +871,12 @@ def train(args, pconfigs): data_buff = [[e.half() for e in data] for data in dataloader.datas] dataloader.datas = data_buff + model.eval() def infer_iter(model, dataloader): with torch.no_grad(): img = next(dataloader) loss = model(img) - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - # start training nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) @@ -944,7 +946,5 @@ def infer_iter(model, dataloader): dict(layer_id=3, dp=args.layer3[0], wp=args.layer3[1], tp=args.layer3[2]), # basic layer 3 ] - args.fp16 = True - print_each_rank(pconfigs, rank_only=0) train(args, pconfigs) From ace6fd62148e91bb6fd3d623f4cb31e72b78d48a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 07:25:34 +0000 Subject: [PATCH 0446/1892] dwt infer bug fix --- examples/swin/swin_dwt_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/swin/swin_dwt_infer.py b/examples/swin/swin_dwt_infer.py index 9c5600a7..652eb074 100644 --- a/examples/swin/swin_dwt_infer.py +++ b/examples/swin/swin_dwt_infer.py @@ -483,7 +483,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, nH = 1 nW = wp // nH while nH <= nW: - if H % nH != 0 or W % nW != 0: + if H % nH != 0 or W % nW != 0 or (H // nH) % window_size != 0 or (W // nW) % window_size != 0: nW = nW // 2 nH = int(nH * 2) else: From 2aa3bf5b0855684478581879b35b556767399de5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 07:31:34 +0000 Subject: [PATCH 0447/1892] infer bs2 --- eval/swin_infer_bs2_640_Gfp16.sh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/eval/swin_infer_bs2_640_Gfp16.sh b/eval/swin_infer_bs2_640_Gfp16.sh index cb1540aa..91c5e7ae 100755 --- a/eval/swin_infer_bs2_640_Gfp16.sh +++ b/eval/swin_infer_bs2_640_Gfp16.sh @@ -88,10 +88,10 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 1 4 1 \ - --layer1 1 4 1 \ - --layer2 1 4 1 \ - --layer3 1 1 4 \ + --layer0 2 2 1 \ + --layer1 2 2 1 \ + --layer2 2 2 1 \ + --layer3 2 1 2 \ --fp16 \ > expinfer_Gfp16_bs2/4gpu_4wp4tp.txt @@ -103,10 +103,10 @@ python -m torch.distributed.launch \ --master_port=8004 \ --use_env \ examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 1 8 1 \ - --layer1 1 8 1 \ - --layer2 1 4 2 \ - --layer3 1 1 8 \ + --layer0 2 4 1 \ + --layer1 2 4 1 \ + --layer2 2 4 1 \ + --layer3 2 1 4 \ --fp16 \ > expinfer_Gfp16_bs2/8gpu_8wp8tp.txt From 4455986a8d3be9fad9305711de861a4a0ce733fa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 11:58:09 +0000 Subject: [PATCH 0448/1892] 1f1b pipeline --- examples/swin/schedule.py | 234 +++++++++++ examples/swin/swin_pipe.py | 842 +++++++++++++++++++++++++++++++++++++ 2 files changed, 1076 insertions(+) create mode 100644 examples/swin/schedule.py create mode 100644 examples/swin/swin_pipe.py diff --git a/examples/swin/schedule.py b/examples/swin/schedule.py new file mode 100644 index 00000000..5b271bbc --- /dev/null +++ b/examples/swin/schedule.py @@ -0,0 +1,234 @@ +import torch + +from cube.profiler.timer import CudaTimer + + +def is_last_stage(): + return torch.distributed.get_rank() == torch.distributed.get_world_size() - 1 + + +#================= WhatToDO functions ==================# + +def forward_step(model, image, trans_input=None): + CudaTimer().start("forward") + output = model(image, trans_input) + CudaTimer().stop("forward") + return output + + +def backward_step(feature_map, output_tensor, output_tensor_grad): + """ + Calculate input tensor gradient + """ + if feature_map is not None and feature_map.requires_grad: + feature_map.retain_grad() + CudaTimer().start("backward") + torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) + CudaTimer().stop("backward") + input_tensor_grad = None + if feature_map is not None and feature_map.requires_grad: + input_tensor_grad = feature_map.grad + return input_tensor_grad + +#================= WhatToDO functions ==================# + +#================= Between Stage functions ==================# + +def send(tensors, to_rank): + """ + send tensor to the target rank + """ + if to_rank < 0 or to_rank >= torch.distributed.get_world_size(): + return None + assert isinstance(tensors, list) or isinstance(tensors, tuple) + CudaTimer().start("send") + reqs = list() + for tensor in tensors: + if tensor is None: + continue + elif torch.is_tensor(tensor): + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, to_rank + ) + reqs.append(send_op) + else: + raise RuntimeError("Expected tensor or None") + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("send") + + +def recv(shapes, from_rank, dtype=torch.float): + if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): + return [None] * len(shapes) + assert isinstance(shapes, list) or isinstance(shapes, tuple) + CudaTimer().start("recv") + reqs = list() + recved_tensors = list() + for shape in shapes: + if shape is None: + recved_tensors.append(None) + continue + tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device(), + dtype=dtype + ) + recved_tensors.append(tensor) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, from_rank + ) + reqs.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("recv") + return recved_tensors + + +def send_and_recv(send_tensors, recv_shapes, rank, dtype=torch.float): + if rank < 0 or rank >= torch.distributed.get_world_size(): + return [None] * len(recv_shapes) + assert isinstance(send_tensors, list) or isinstance(send_tensors, tuple) + assert isinstance(recv_shapes, list) or isinstance(recv_shapes, tuple) + CudaTimer().start("send_recv") + reqs = list() + recved_tensors = list() + for tensor in send_tensors: + if tensor is None: + continue + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, rank + ) + reqs.append(send_op) + for shape in recv_shapes: + if shape is None: + recved_tensors.append(None) + continue + recv_tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device(), + dtype=dtype + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, rank + ) + recved_tensors.append(recv_tensor) + reqs.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("send_recv") + return recved_tensors + +#================= Between Stage functions ==================# + +def split_batch(inputs, num_microbatches): + """ + Split a mini-batch to micro-batches + """ + assert isinstance(inputs, list) or isinstance(inputs, tuple) + input_chunks = list() + for feature_map in inputs: + if torch.is_tensor(feature_map): + feature_map = torch.chunk(feature_map, chunks=num_microbatches, dim=0) + else: + feature_map = [feature_map] * num_microbatches + input_chunks.append(feature_map) + micro_batches = list() + for micro_data in zip(*tuple(input_chunks)): + micro_batches.append(micro_data) + return micro_batches + + +#================= Scheduling ==================# + +def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): + myrank = torch.distributed.get_rank() + + num_microbatches = int(bs / micro_bs) + num_warmup_microbatches = \ + (torch.distributed.get_world_size() - + torch.distributed.get_rank() - 1) + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_warmup_remaining = num_microbatches - num_warmup_microbatches + + input_tensors = list() + output_tensors = list() + + inputs = split_batch(inputs, num_microbatches) + + # warmup forward pass + for i in range(num_warmup_microbatches): + # recv forward + # print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) + feature_map = recv( + (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype + )[0] + image = inputs[i][0] + # forward + output_tensor = forward_step(model, image, feature_map) + # send forward + # print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) + send((output_tensor,), myrank+1) + + input_tensors.append(feature_map) + output_tensors.append(output_tensor) + + # before running 1F1B, need to recieve first forward tensor + if num_warmup_remaining > 0: + # recv forward + # print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) + feature_map = recv( + (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype + )[0] + image = inputs[num_warmup_microbatches][0] + + # run 1F1B + for i in range(num_warmup_remaining): + # forward + output_tensor = forward_step(model, image, feature_map) + # send forward + recv backward grads + # print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) + output_tensor_grad = send_and_recv( + (output_tensor,), + (torch.Size([micro_bs] + model.out_size),), + myrank+1, dtype + )[0] + input_tensors.append(feature_map) + output_tensors.append(output_tensor) + # backward + feature_map, output_tensor = input_tensors.pop(0), output_tensors.pop(0) + input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) + if i != (num_warmup_remaining-1): + # send backward grads + recv forward results + # print('[1f1b] rank {}: step-{}: sending backward + recving forward...'.format(myrank, i)) + feature_map = send_and_recv( + (input_tensor_grad,), + (torch.Size([micro_bs] + model.in_size),), + myrank-1, dtype + )[0] + image = inputs[num_warmup_microbatches+i+1][0] + else: # last iteration - no more inputs + feature_map = None + # send backward grads + # print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) + send((input_tensor_grad,), myrank-1) + + # cooldown gradient trans back + for i in range(num_warmup_microbatches): + feature_map = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + # recv backward gradients + output_tensor_grad = recv( + (torch.Size([micro_bs] + model.out_size),), myrank+1, dtype + )[0] + # backward + input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) + # send backward gradients + # print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) + send((input_tensor_grad,), myrank-1) + +#================= Scheduling ==================# \ No newline at end of file diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py new file mode 100644 index 00000000..94c3650f --- /dev/null +++ b/examples/swin/swin_pipe.py @@ -0,0 +1,842 @@ + +# -------------------------------------------------------- +# Modified from Swin-Transformer Repo +""" +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_pipe.py --pp 4 +""" +# -------------------------------------------------------- + +from typing import Optional +import torch +import torch.nn as nn + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary + +import argparse + +from examples.swin.schedule import scheduling_1f1b, is_last_stage +from benchmark.swin.layers import ColumnParallelLinear, RowParallelLinear + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class MegatronMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, full_input=True, full_output=False) + self.act = act_layer() + # self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = RowParallelLinear(hidden_features, out_features, full_input=False, full_output=True) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + assert B == 1 + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + assert B == 1 + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +class MegatronWindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.global_num_heads = num_heads + group = cube.runtime.resource.EnvResource().tp_group + tp_world_size = torch.distributed.get_world_size(group=group) + if num_heads % tp_world_size != 0: + print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') + self.num_heads = num_heads // torch.distributed.get_world_size(group=group) + self.dim_heads = dim // self.global_num_heads + self.scale = qk_scale or self.dim_heads ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # relative position index + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + self.register_buffer('relative_position_index', relative_position_index) + + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # print(f'qkv embed dim: {dim}') + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False) + self.attn_drop = nn.Dropout(attn_drop) + # self.proj = nn.Linear(dim, dim) + self.proj = RowParallelLinear(dim, dim, full_input=False, full_output=True) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = MegatronWindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MegatronMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + H, W = self.input_resolution + self.in_size = [H * W, self.dim] + self.out_size = [H * W, self.dim] + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert B == 1 + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + # [B, H, W, C] -> [B, H * W, C] + x = x.view(B, H * W, C) + # [B, H * W, C] -> [B, H * W, C] + x = shortcut + drop_path(x, self.drop_path_p) + # FFN + # [B, H * W, C] -> [B, H * W, C] + ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] + ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] + x = x + drop_path(ffn, self.drop_path_p) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + H, W = self.input_resolution + self.in_size = [H * W, self.dim] + self.out_size = [H // 2 * W // 2, self.dim * 2] + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert B == 1 + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, module_lists=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + for i in range(depth): + block = SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + module_lists.append(block) + + # patch merging layer + if downsample is not None: + merging = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + module_lists.append(merging) + + def forward(self, x): + raise RuntimeError("Error call here") + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + layers = nn.ModuleList() + for i_layer in range(self.num_layers): + _ = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + module_lists=layers) + + # pipeline stage + resource =cube.runtime.resource.EnvResource() + pp_rank = torch.distributed.get_rank() + pp_size = torch.distributed.get_world_size() + chunk = len(layers) // pp_size + start = pp_rank * chunk + stop = (pp_rank + 1) * chunk + if is_last_stage(): + stop = len(layers) + self.layers = layers[start:stop] + print_each_rank([str(type(layer)) + '\n' for layer in self.layers]) + + self.in_size = self.layers[0].in_size + assert isinstance(self.in_size, list) + self.out_size = self.layers[-1].out_size + assert isinstance(self.out_size, list) + + + self.preprocess = False + if pp_rank == 0: + self.preprocess = True + self.in_size = [in_chans, img_size, img_size] + self.postprocess = False + if is_last_stage(): + self.postprocess = True + self.out_size = [1,] + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, image, feature_map=None): + + if self.preprocess: + x = self.patch_embed(image) + x = self.pos_drop(x) + feature_map = x + + for layer in self.layers: + feature_map = layer(feature_map) + x = feature_map + + if self.postprocess: + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C L + x = torch.flatten(x, 1) + x = self.head(x) + # simulate for simplicity + x = torch.sum(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def train(args): + resource = cube.runtime.resource.EnvResource() + + # dim_head is always 32 + + # img resolution, windows size: 224, 384, 518, 640 + # C, H, W, window_size = [3, 224, 224, 7] + C, H, W, window_size = [3, 384, 384, 12] + # C, H, W, window_size = [3, 518, 518, ?] + # C, H, W, window_size = [3, 640, 640, 20] + # C, H, W, window_size = [3, 1536, 1536, 48] + + # image batch size + N = args.gbs + + # Swin-Tiny + # embed_dim, depths, num_heads = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24] + # ] + + # SwinV2-B: 87 M + # embed_dim, depths, num_heads = [ + # 128, [2, 2, 18, 2], [4, 8, 16, 32] + # ] + + # SwinV2-L: 196 M + # embed_dim, depths, num_heads = [ + # 192, [2, 2, 18, 2], [6, 12, 24, 48] + # ] + + # SwinV2-H: 657 M + embed_dim, depths, num_heads = [ + 352, [2, 2, 18, 2], [11, 22, 44, 88] + ] + + # SwinV2-H modified: 782 M + # embed_dim, depths, num_heads = [ + # 384, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + + # SwinV2-G: 2.5B Model + # embed_dim, depths, num_heads = [ + # 512, [2, 2, 42, 2], [16, 32, 64, 128] + # ] + + # 895.7 M Model + # embed_dim, depths, num_heads = [ + # 384, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + + # 2.01B model + # embed_dim, depths, num_heads = [ + # 576, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + + model = SwinTransformer(img_size = H, + embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size) + model = model.cuda() + memory_summary() + + # setup data parallel reducer + # reducer = None + if args.dp > 1: + print('> initialize weight reducer') + reducer = resource.reducer + for param in model.parameters(): + reducer.add_param(param) + + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [N // args.dp, C, H, W]) + + def train_iter(model, dataloader): + # with torch.no_grad(): + img = next(dataloader) + scheduling_1f1b(model, [img], args.gbs, args.mbs, dtype=torch.float) + # if reducer is not None: + # reducer.allreduce() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + # start training + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on 1st iteration') + memory_summary() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + + +if __name__ == '__main__': + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--tp', type=int, default=1, + help='tensor parallel size') + parser.add_argument('--dp', type=int, default=1, + help='data parallel size') + parser.add_argument('--pp', type=int, default=1, + help='pipeline parallel size') + parser.add_argument('--gbs', type=int, default=-1) + parser.add_argument('--mbs', type=int, default=-1) + args = parser.parse_args() + + cube.init() + + # allocate resource + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + tp_size, tp_group_nums = args.tp, ndevs // args.tp + dp_size, dp_group_nums = args.dp, ndevs // args.dp + pp_size, pp_group_nums = args.pp, ndevs // args.pp + + if not pp_size * dp_size * tp_size == ndevs: + raise RuntimeError("Expected all devices are used") + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize data parallel group + all_data_parallel_group_ranks = list() + for i in range(pp_size): + start_rank = i * pp_group_nums + end_rank = (i + 1) * pp_group_nums + for j in range(tp_size): + ranks = list(range(start_rank + j, end_rank, tp_size)) + all_data_parallel_group_ranks.append(ranks) + # initialize groups + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + resource.dp_group = group + resource.reducer = cube.runtime.reducer.Reducer(ranks) + print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) + + # initialize pipelne parallel groups + for i in range(dp_size): + ranks = [data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks] + group = devs.get_group(ranks) + if myrank in ranks: + pp_ranks = ranks + resource.pp_group = group + print_each_rank(f'initialzed pipeline parallel group: {pp_ranks}', rank_only=myrank) + + # initialize tensor parallel groups + for i in range(tp_group_nums): + ranks = list(range(i * tp_size, (i + 1) * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + resource.tp_group = group + print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + + train(args) From 982c1bbc234bbc949477c6640032cb2fc14ffb9c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 12:01:40 +0000 Subject: [PATCH 0449/1892] swin pipe bug fix --- examples/swin/swin_pipe.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index 94c3650f..027f3bef 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -9,7 +9,7 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_pipe.py --pp 4 + examples/swin/swin_pipe.py --pp 8 --gbs 1 --mbs 1 """ # -------------------------------------------------------- @@ -69,7 +69,6 @@ def window_partition(x: torch.Tensor, window_size: int): windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape - assert B == 1 # [B, H_window_num, window_size, W_window_num, window_size, C] x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) # [B, H_window_num, W_window_num, window_size, window_size, C] @@ -90,7 +89,6 @@ def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) - assert B == 1 x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -287,7 +285,6 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 def forward(self, x): H, W = self.input_resolution B, L, C = x.shape - assert B == 1 assert L == H * W, "input feature has wrong size" shortcut = x @@ -379,7 +376,6 @@ def forward(self, x): """ H, W = self.input_resolution B, L, C = x.shape - assert B == 1 assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." From 4ad639189e09f94c3363eda11db6434776414643 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 12:43:50 +0000 Subject: [PATCH 0450/1892] swin pipe line --- examples/swin/swin_pipe.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index 027f3bef..3dc546f9 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -10,12 +10,15 @@ --master_port=8004 \ --use_env \ examples/swin/swin_pipe.py --pp 8 --gbs 1 --mbs 1 + +# V100-16GB: 8GPU: need checkpoint: 16 micro bs """ # -------------------------------------------------------- from typing import Optional import torch import torch.nn as nn +import torch.utils.checkpoint as checkpoint import cube from cube.profiler import CudaTimer @@ -661,8 +664,8 @@ def train(args): # dim_head is always 32 # img resolution, windows size: 224, 384, 518, 640 - # C, H, W, window_size = [3, 224, 224, 7] - C, H, W, window_size = [3, 384, 384, 12] + C, H, W, window_size = [3, 224, 224, 7] + # C, H, W, window_size = [3, 384, 384, 12] # C, H, W, window_size = [3, 518, 518, ?] # C, H, W, window_size = [3, 640, 640, 20] # C, H, W, window_size = [3, 1536, 1536, 48] @@ -686,15 +689,15 @@ def train(args): # ] # SwinV2-H: 657 M - embed_dim, depths, num_heads = [ - 352, [2, 2, 18, 2], [11, 22, 44, 88] - ] - - # SwinV2-H modified: 782 M # embed_dim, depths, num_heads = [ - # 384, [2, 2, 18, 2], [12, 24, 48, 96] + # 352, [2, 2, 18, 2], [11, 22, 44, 88] # ] + # SwinV2-H modified: 782 M + embed_dim, depths, num_heads = [ + 384, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # SwinV2-G: 2.5B Model # embed_dim, depths, num_heads = [ # 512, [2, 2, 42, 2], [16, 32, 64, 128] @@ -722,14 +725,11 @@ def train(args): # setup data parallel reducer # reducer = None - if args.dp > 1: - print('> initialize weight reducer') - reducer = resource.reducer - for param in model.parameters(): - reducer.add_param(param) + assert args.dp == 1 dataloader = cube.runtime.syndata.SynDataLoader( 1280, [0], [N // args.dp, C, H, W]) + dataloader.set_data_buffer(buffer_num=16) def train_iter(model, dataloader): # with torch.no_grad(): @@ -766,6 +766,8 @@ def train_iter(model, dataloader): print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) + + CudaTimer().print_all(times=iter_num-40) memory_summary() From ce6504242b0ee2e6ed84159e6113820d75e00c48 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 13:34:57 +0000 Subject: [PATCH 0451/1892] better model partition --- examples/swin/swin_pipe.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index 3dc546f9..61ec734c 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -11,7 +11,7 @@ --use_env \ examples/swin/swin_pipe.py --pp 8 --gbs 1 --mbs 1 -# V100-16GB: 8GPU: need checkpoint: 16 micro bs +# V100-16GB: 8GPU: need checkpoint: 8 micro bs """ # -------------------------------------------------------- @@ -20,6 +20,7 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint + import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank @@ -587,14 +588,20 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, module_lists=layers) # pipeline stage - resource =cube.runtime.resource.EnvResource() pp_rank = torch.distributed.get_rank() pp_size = torch.distributed.get_world_size() chunk = len(layers) // pp_size - start = pp_rank * chunk - stop = (pp_rank + 1) * chunk - if is_last_stage(): - stop = len(layers) + if len(layers) % pp_size != 0: + remain = len(layers) % pp_size + if pp_rank < remain: + start = pp_rank * (chunk+1) + chunk = chunk + 1 + else: + start = remain * (chunk + 1) + (pp_rank - remain) * chunk + else: + start = pp_rank * chunk + stop = start + chunk + print_each_rank(f'layer start -> end: {start} -> {stop}') self.layers = layers[start:stop] print_each_rank([str(type(layer)) + '\n' for layer in self.layers]) From 70707d645baa25e748cdbf6466886a5a174538d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 13:42:17 +0000 Subject: [PATCH 0452/1892] checkpoint --- examples/swin/swin_pipe.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index 61ec734c..a6f3f055 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -643,7 +643,8 @@ def forward(self, image, feature_map=None): feature_map = x for layer in self.layers: - feature_map = layer(feature_map) + feature_map = checkpoint.checkpoint(layer, feature_map) + # feature_map = layer(feature_map) x = feature_map if self.postprocess: @@ -753,9 +754,9 @@ def train_iter(model, dataloader): CudaTimer().warmup() torch.distributed.barrier() - iter_num = 128 + iter_num = 40 for step in range(iter_num): - if step >= 40: + if step >= 20: CudaTimer().start('e2e') train_iter(model, dataloader) optimizer.step() @@ -763,18 +764,18 @@ def train_iter(model, dataloader): if step == 1: print('> passed on 1st iteration') memory_summary() - if step >= 40: + if step >= 20: CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: + if (step + 1) % 10 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') + iter_time = CudaTimer().duration(iter_num-20, field_name='e2e') throughput = N / iter_time * 1000 print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) - CudaTimer().print_all(times=iter_num-40) + CudaTimer().print_all(times=iter_num-20) memory_summary() From 8da89a6a64140d1fcfeac54cfcd37b478f84cb69 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 15:29:44 +0000 Subject: [PATCH 0453/1892] swin pipeline cutomize checkpointing --- examples/swin/swin_pipe.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index a6f3f055..a8c09faf 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -590,6 +590,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # pipeline stage pp_rank = torch.distributed.get_rank() pp_size = torch.distributed.get_world_size() + chunk = len(layers) // pp_size if len(layers) % pp_size != 0: remain = len(layers) % pp_size @@ -601,6 +602,19 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, else: start = pp_rank * chunk stop = start + chunk + + # 8gpu layer assign + # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] + # assert sum(layer_split) == 27 + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) + + self.use_checkpoint = [False] * (stop - start) + for idx in range(stop - start): + if pp_rank == 0: + if idx in [0, 1]: + self.use_checkpoint[idx] = True + print_each_rank(f'layer start -> end: {start} -> {stop}') self.layers = layers[start:stop] print_each_rank([str(type(layer)) + '\n' for layer in self.layers]) @@ -642,9 +656,11 @@ def forward(self, image, feature_map=None): x = self.pos_drop(x) feature_map = x - for layer in self.layers: - feature_map = checkpoint.checkpoint(layer, feature_map) - # feature_map = layer(feature_map) + for layer, use_checkpoint in zip(self.layers, self.use_checkpoint): + if use_checkpoint: + feature_map = checkpoint.checkpoint(layer, feature_map) + else: + feature_map = layer(feature_map) x = feature_map if self.postprocess: From ff2cf859d578ba3e4c3ce396a09a84891f2daf9a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 30 Nov 2021 16:53:36 +0000 Subject: [PATCH 0454/1892] pipeline test code --- examples/swin/swin_pipe.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index a8c09faf..5aadec2f 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -603,19 +603,28 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, start = pp_rank * chunk stop = start + chunk + self.use_checkpoint = [True] * (stop - start) + # 8gpu layer assign - # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] + # layer_split = [3, 4, 3, 3, 3, 4, 4, 3] + # assert sum(layer_split) == 27 + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) + + # 4Ggpu layer assign + # layer_split = [6, 7, 7, 7] # assert sum(layer_split) == 27 # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) - self.use_checkpoint = [False] * (stop - start) - for idx in range(stop - start): - if pp_rank == 0: - if idx in [0, 1]: - self.use_checkpoint[idx] = True + # self.use_checkpoint = [False] * (stop - start) + # for idx in range(stop - start): + # if pp_rank == 0: + # if idx in [0,]: + # self.use_checkpoint[idx] = True print_each_rank(f'layer start -> end: {start} -> {stop}') + print_each_rank(self.use_checkpoint) self.layers = layers[start:stop] print_each_rank([str(type(layer)) + '\n' for layer in self.layers]) @@ -756,7 +765,6 @@ def train(args): dataloader.set_data_buffer(buffer_num=16) def train_iter(model, dataloader): - # with torch.no_grad(): img = next(dataloader) scheduling_1f1b(model, [img], args.gbs, args.mbs, dtype=torch.float) # if reducer is not None: @@ -770,9 +778,9 @@ def train_iter(model, dataloader): CudaTimer().warmup() torch.distributed.barrier() - iter_num = 40 + iter_num = 20 for step in range(iter_num): - if step >= 20: + if step >= 10: CudaTimer().start('e2e') train_iter(model, dataloader) optimizer.step() @@ -780,18 +788,18 @@ def train_iter(model, dataloader): if step == 1: print('> passed on 1st iteration') memory_summary() - if step >= 20: + if step >= 10: CudaTimer().stop('e2e') if (step + 1) % 10 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = CudaTimer().duration(iter_num-20, field_name='e2e') + iter_time = CudaTimer().duration(iter_num-10, field_name='e2e') throughput = N / iter_time * 1000 print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) - CudaTimer().print_all(times=iter_num-20) + CudaTimer().print_all(times=iter_num-10) memory_summary() From 907d9cdb3fec754a6e2ed170129d9dfcaf332030 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 05:57:46 +0000 Subject: [PATCH 0455/1892] init efficientnet --- examples/efficientnet/efficientnet.py | 397 +++++++++++++++++ examples/efficientnet/schedule.py | 234 ++++++++++ examples/efficientnet/train.py | 92 ++++ examples/efficientnet/utils.py | 586 ++++++++++++++++++++++++++ 4 files changed, 1309 insertions(+) create mode 100644 examples/efficientnet/efficientnet.py create mode 100644 examples/efficientnet/schedule.py create mode 100644 examples/efficientnet/train.py create mode 100644 examples/efficientnet/utils.py diff --git a/examples/efficientnet/efficientnet.py b/examples/efficientnet/efficientnet.py new file mode 100644 index 00000000..4f5de160 --- /dev/null +++ b/examples/efficientnet/efficientnet.py @@ -0,0 +1,397 @@ +"""model.py - Model and module class for EfficientNet. + They are built to mirror those in the official TensorFlow implementation. +""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). + +# https://arxiv.org/pdf/1911.04252.pdf + +import torch +from torch import nn +from torch.nn import functional as F +from .utils import ( + round_filters, + round_repeats, + drop_connect, + get_same_padding_conv2d, + get_model_params, + efficientnet_params, + load_pretrained_weights, + Swish, + MemoryEfficientSwish, + calculate_output_image_size +) + + +VALID_MODELS = ( + 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', + 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7', + 'efficientnet-b8', + + # Support the construction of 'efficientnet-l2' without pretrained weights + 'efficientnet-l2' +) + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block. + Args: + block_args (namedtuple): BlockArgs, defined in utils.py. + global_params (namedtuple): GlobalParam, defined in utils.py. + image_size (tuple or list): [image_height, image_width]. + References: + [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) + [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) + [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) + """ + + def __init__(self, block_args, global_params, image_size=None): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # whether to use skip connection and drop connect + + # Expansion phase (Inverted Bottleneck) + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise + kernel_size=k, stride=s, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + image_size = calculate_output_image_size(image_size, s) + + # Squeeze and Excitation layer, if desired + if self.has_se: + Conv2d = get_same_padding_conv2d(image_size=(1, 1)) + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) + self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) + self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) + + # Pointwise convolution phase + final_oup = self._block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """MBConvBlock's forward function. + Args: + inputs (tensor): Input tensor. + drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). + Returns: + Output of this block after processing. + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + # Pointwise Convolution + x = self._project_conv(x) + x = self._bn2(x) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + # The combination of skip connection and drop connect brings about stochastic depth. + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """EfficientNet model. + Most easily loaded with the .from_name or .from_pretrained methods. + Args: + blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. + global_params (namedtuple): A set of GlobalParams shared between blocks. + References: + [1] https://arxiv.org/abs/1905.11946 (EfficientNet) + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> model.eval() + >>> outputs = model(inputs) + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Get stem static or dynamic convolution depending on image size + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + # Stem + in_channels = 3 # rgb + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + image_size = calculate_output_image_size(image_size, 2) + + # Build blocks + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) + image_size = calculate_output_image_size(image_size, block_args.stride) + if block_args.num_repeat > 1: # modify block_args to keep same output size + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) + # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + if self._global_params.include_top: + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + + # set activation to memory efficient swish by default + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_endpoints(self, inputs): + """Use convolution layer to extract features + from reduction levels i in [1, 2, 3, 4, 5]. + Args: + inputs (tensor): Input tensor. + Returns: + Dictionary of last intermediate features + with reduction levels i in [1, 2, 3, 4, 5]. + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> endpoints = model.extract_endpoints(inputs) + >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) + >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) + >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) + >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) + >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) + >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) + """ + endpoints = dict() + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + prev_x = x + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + if prev_x.size(2) > x.size(2): + endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x + elif idx == len(self._blocks) - 1: + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + prev_x = x + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + + return endpoints + + def extract_features(self, inputs): + """use convolution layer to extract feature . + Args: + inputs (tensor): Input tensor. + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """EfficientNet's forward function. + Calls extract_features to extract features, applies final linear layer, and returns logits. + Args: + inputs (tensor): Input tensor. + Returns: + Output of this model after processing. + """ + # Convolution layers + x = self.extract_features(inputs) + # Pooling and final linear layer + x = self._avg_pooling(x) + if self._global_params.include_top: + x = x.flatten(start_dim=1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, in_channels=3, **override_params): + """Create an efficientnet model according to name. + Args: + model_name (str): Name for efficientnet. + in_channels (int): Input data's channel number. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + Returns: + An efficientnet model. + """ + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, override_params) + model = cls(blocks_args, global_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def from_pretrained(cls, model_name, weights_path=None, advprop=False, + in_channels=3, num_classes=1000, **override_params): + """Create an efficientnet model according to name. + Args: + model_name (str): Name for efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + advprop (bool): + Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + in_channels (int): Input data's channel number. + num_classes (int): + Number of categories for classification. + It controls the output size for final linear layer. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + Returns: + A pretrained efficientnet model. + """ + model = cls.from_name(model_name, num_classes=num_classes, **override_params) + load_pretrained_weights(model, model_name, weights_path=weights_path, + load_fc=(num_classes == 1000), advprop=advprop) + model._change_in_channels(in_channels) + return model + + @classmethod + def get_image_size(cls, model_name): + """Get the input image size for a given efficientnet model. + Args: + model_name (str): Name for efficientnet. + Returns: + Input image size (resolution). + """ + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """Validates model name. + Args: + model_name (str): Name for efficientnet. + Returns: + bool: Is a valid name or not. + """ + if model_name not in VALID_MODELS: + raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS)) + + def _change_in_channels(self, in_channels): + """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. + Args: + in_channels (int): Input data's channel number. + """ + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) \ No newline at end of file diff --git a/examples/efficientnet/schedule.py b/examples/efficientnet/schedule.py new file mode 100644 index 00000000..5b271bbc --- /dev/null +++ b/examples/efficientnet/schedule.py @@ -0,0 +1,234 @@ +import torch + +from cube.profiler.timer import CudaTimer + + +def is_last_stage(): + return torch.distributed.get_rank() == torch.distributed.get_world_size() - 1 + + +#================= WhatToDO functions ==================# + +def forward_step(model, image, trans_input=None): + CudaTimer().start("forward") + output = model(image, trans_input) + CudaTimer().stop("forward") + return output + + +def backward_step(feature_map, output_tensor, output_tensor_grad): + """ + Calculate input tensor gradient + """ + if feature_map is not None and feature_map.requires_grad: + feature_map.retain_grad() + CudaTimer().start("backward") + torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) + CudaTimer().stop("backward") + input_tensor_grad = None + if feature_map is not None and feature_map.requires_grad: + input_tensor_grad = feature_map.grad + return input_tensor_grad + +#================= WhatToDO functions ==================# + +#================= Between Stage functions ==================# + +def send(tensors, to_rank): + """ + send tensor to the target rank + """ + if to_rank < 0 or to_rank >= torch.distributed.get_world_size(): + return None + assert isinstance(tensors, list) or isinstance(tensors, tuple) + CudaTimer().start("send") + reqs = list() + for tensor in tensors: + if tensor is None: + continue + elif torch.is_tensor(tensor): + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, to_rank + ) + reqs.append(send_op) + else: + raise RuntimeError("Expected tensor or None") + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("send") + + +def recv(shapes, from_rank, dtype=torch.float): + if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): + return [None] * len(shapes) + assert isinstance(shapes, list) or isinstance(shapes, tuple) + CudaTimer().start("recv") + reqs = list() + recved_tensors = list() + for shape in shapes: + if shape is None: + recved_tensors.append(None) + continue + tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device(), + dtype=dtype + ) + recved_tensors.append(tensor) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, from_rank + ) + reqs.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("recv") + return recved_tensors + + +def send_and_recv(send_tensors, recv_shapes, rank, dtype=torch.float): + if rank < 0 or rank >= torch.distributed.get_world_size(): + return [None] * len(recv_shapes) + assert isinstance(send_tensors, list) or isinstance(send_tensors, tuple) + assert isinstance(recv_shapes, list) or isinstance(recv_shapes, tuple) + CudaTimer().start("send_recv") + reqs = list() + recved_tensors = list() + for tensor in send_tensors: + if tensor is None: + continue + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, rank + ) + reqs.append(send_op) + for shape in recv_shapes: + if shape is None: + recved_tensors.append(None) + continue + recv_tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device(), + dtype=dtype + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, rank + ) + recved_tensors.append(recv_tensor) + reqs.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("send_recv") + return recved_tensors + +#================= Between Stage functions ==================# + +def split_batch(inputs, num_microbatches): + """ + Split a mini-batch to micro-batches + """ + assert isinstance(inputs, list) or isinstance(inputs, tuple) + input_chunks = list() + for feature_map in inputs: + if torch.is_tensor(feature_map): + feature_map = torch.chunk(feature_map, chunks=num_microbatches, dim=0) + else: + feature_map = [feature_map] * num_microbatches + input_chunks.append(feature_map) + micro_batches = list() + for micro_data in zip(*tuple(input_chunks)): + micro_batches.append(micro_data) + return micro_batches + + +#================= Scheduling ==================# + +def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): + myrank = torch.distributed.get_rank() + + num_microbatches = int(bs / micro_bs) + num_warmup_microbatches = \ + (torch.distributed.get_world_size() - + torch.distributed.get_rank() - 1) + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_warmup_remaining = num_microbatches - num_warmup_microbatches + + input_tensors = list() + output_tensors = list() + + inputs = split_batch(inputs, num_microbatches) + + # warmup forward pass + for i in range(num_warmup_microbatches): + # recv forward + # print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) + feature_map = recv( + (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype + )[0] + image = inputs[i][0] + # forward + output_tensor = forward_step(model, image, feature_map) + # send forward + # print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) + send((output_tensor,), myrank+1) + + input_tensors.append(feature_map) + output_tensors.append(output_tensor) + + # before running 1F1B, need to recieve first forward tensor + if num_warmup_remaining > 0: + # recv forward + # print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) + feature_map = recv( + (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype + )[0] + image = inputs[num_warmup_microbatches][0] + + # run 1F1B + for i in range(num_warmup_remaining): + # forward + output_tensor = forward_step(model, image, feature_map) + # send forward + recv backward grads + # print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) + output_tensor_grad = send_and_recv( + (output_tensor,), + (torch.Size([micro_bs] + model.out_size),), + myrank+1, dtype + )[0] + input_tensors.append(feature_map) + output_tensors.append(output_tensor) + # backward + feature_map, output_tensor = input_tensors.pop(0), output_tensors.pop(0) + input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) + if i != (num_warmup_remaining-1): + # send backward grads + recv forward results + # print('[1f1b] rank {}: step-{}: sending backward + recving forward...'.format(myrank, i)) + feature_map = send_and_recv( + (input_tensor_grad,), + (torch.Size([micro_bs] + model.in_size),), + myrank-1, dtype + )[0] + image = inputs[num_warmup_microbatches+i+1][0] + else: # last iteration - no more inputs + feature_map = None + # send backward grads + # print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) + send((input_tensor_grad,), myrank-1) + + # cooldown gradient trans back + for i in range(num_warmup_microbatches): + feature_map = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + # recv backward gradients + output_tensor_grad = recv( + (torch.Size([micro_bs] + model.out_size),), myrank+1, dtype + )[0] + # backward + input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) + # send backward gradients + # print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) + send((input_tensor_grad,), myrank-1) + +#================= Scheduling ==================# \ No newline at end of file diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py new file mode 100644 index 00000000..8004f0ef --- /dev/null +++ b/examples/efficientnet/train.py @@ -0,0 +1,92 @@ +""" +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/efficientnet/train.py --bs 1 +""" + +import torch +from torch import nn +from examples.efficientnet.efficientnet import EfficientNet +import time +import argparse + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary +from cube.runtime.device import DeviceGroup +from cube.runtime.reducer import Reducer + + + +def train(args): + + N = args.bs + + # L2 config + # C, H, W = [3, 800, 800] + # model = EfficientNet.from_name('efficientnet-l2') + + # B8 config + C, H, W = [3, 672, 672] + model = EfficientNet.from_name('efficientnet-b8') + + model = model.cuda() + + if N % args.bs != 0: + raise RuntimeError("global bs is not divisible by DP") + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [N // args.dp, C, H, W]) + + + def train_iter(model, dataloader): + img = next(dataloader) + loss = model(img) + loss = torch.sum(loss) + loss.backward() + + optimizer = torch.optim.RMSprop(model.parameters()) + + CudaTimer(enable=False).warmup() + torch.distributed.barrier() + span = 0 + iter_num = 128 + for step in range(iter_num): + if step >= 40: + torch.cuda.synchronize() + start = time.time() + CudaTimer(enable=True).start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on 1st iteration') + memory_summary() + if step >= 40: + torch.cuda.synchronize() + stop = time.time() + span += (stop - start) * 1000 + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + +if __name__ == '__main__': + + cube.init() + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--bs', type=int, default=1, + help='bs') + parser.add_argument('--dp', type=int, default=1, + help='data parallel') + parser.add_argument('--fp16', action='store_true', dest='fp16') + args = parser.parse_args() + + train(args) diff --git a/examples/efficientnet/utils.py b/examples/efficientnet/utils.py new file mode 100644 index 00000000..4850a9e9 --- /dev/null +++ b/examples/efficientnet/utils.py @@ -0,0 +1,586 @@ +"""utils.py - Helper functions for building the model and for loading model parameters. + These helper functions are built to mirror those in the official TensorFlow implementation. +""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). + +import re +import math +import collections +from functools import partial +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo + + +################################################################################ +# Help functions for model architecture +################################################################################ + +# GlobalParams and BlockArgs: Two namedtuples +# Swish and MemoryEfficientSwish: Two implementations of the method +# round_filters and round_repeats: +# Functions to calculate params for scaling model width and depth ! ! ! +# get_width_and_height_from_size and calculate_output_image_size +# drop_connect: A structural design +# get_same_padding_conv2d: +# Conv2dDynamicSamePadding +# Conv2dStaticSamePadding +# get_same_padding_maxPool2d: +# MaxPool2dDynamicSamePadding +# MaxPool2dStaticSamePadding +# It's an additional function, not used in EfficientNet, +# but can be used in other model (such as EfficientDet). + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple('GlobalParams', [ + 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', + 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top']) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple('BlockArgs', [ + 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', + 'input_filters', 'output_filters', 'se_ratio', 'id_skip']) + +# Set GlobalParams and BlockArgs's defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + +# Swish activation function +if hasattr(nn, 'SiLU'): + Swish = nn.SiLU +else: + # For compatibility with old PyTorch versions + class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +# A memory-efficient implementation of Swish function +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + # TODO: modify the params names. + # maybe the names (width_divisor,min_width) + # are more suitable than (depth_divisor,min_depth). + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor # pay attention to this line when using min_depth + # follow the formula transferred from official TensorFlow implementation + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """Calculate module's repeat number of a block based on depth multiplier. + Use depth_coefficient of global_params. + Args: + repeats (int): num_repeat to be calculated. + global_params (namedtuple): Global params of the model. + Returns: + new repeat: New repeat number after calculating. + """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + # follow the formula transferred from official TensorFlow implementation + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """Drop connect. + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, 'p must be in range of [0,1]' + + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + + # generate binary_tensor mask according to probability (p for 0, 1-p for 1) + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + +def get_width_and_height_from_size(x): + """Obtain height and width from x. + Args: + x (int, tuple or list): Data size. + Returns: + size: A tuple or list (H,W). + """ + if isinstance(x, int): + return x, x + if isinstance(x, list) or isinstance(x, tuple): + return x + else: + raise TypeError() + + +def calculate_output_image_size(input_image_size, stride): + """Calculates the output image size when using Conv2dSamePadding with a stride. + Necessary for static padding. Thanks to mannatsingh for pointing this out. + Args: + input_image_size (int, tuple or list): Size of input image. + stride (int, tuple or list): Conv2d operation's stride. + Returns: + output_image_size: A list [H,W]. + """ + if input_image_size is None: + return None + image_height, image_width = get_width_and_height_from_size(input_image_size) + stride = stride if isinstance(stride, int) else stride[0] + image_height = int(math.ceil(image_height / stride)) + image_width = int(math.ceil(image_width / stride)) + return [image_height, image_width] + + +# Note: +# The following 'SamePadding' functions make output size equal ceil(input size/stride). +# Only when stride equals 1, can the output size be the same as input size. +# Don't be confused by their function names ! ! ! + +def get_same_padding_conv2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: + image_size (int or tuple): Size of the image. + Returns: + Conv2dDynamicSamePadding or Conv2dStaticSamePadding. + """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow, for a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + # Tips for 'SAME' mode padding. + # Given the following: + # i: width or height + # s: stride + # k: kernel size + # d: dilation + # p: padding + # Output after Conv2d: + # o = floor((i+p-((k-1)*d+1))/s+1) + # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1), + # => p = (i-1)*s+((k-1)*d+1)-i + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + # With the same calculation as Conv2dDynamicSamePadding + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, + pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +def get_same_padding_maxPool2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: + image_size (int or tuple): Size of the image. + Returns: + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + """ + if image_size is None: + return MaxPool2dDynamicSamePadding + else: + return partial(MaxPool2dStaticSamePadding, image_size=image_size) + + +class MaxPool2dDynamicSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + + +class MaxPool2dStaticSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, kernel_size, stride, image_size=None, **kwargs): + super().__init__(kernel_size, stride, **kwargs) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + return x + + +################################################################################ +# Helper functions for loading model params +################################################################################ + +# BlockDecoder: A Class for encoding and decoding BlockArgs +# efficientnet_params: A function to query compound coefficient +# get_model_params and efficientnet: +# Functions to get BlockArgs and GlobalParams for efficientnet +# url_map and url_map_advprop: Dicts of url_map for pretrained weights +# load_pretrained_weights: A function to load pretrained weights + +class BlockDecoder(object): + """Block Decoder for readability, + straight from the official TensorFlow repository. + """ + + @staticmethod + def _decode_block_string(block_string): + """Get a block through a string notation of arguments. + Args: + block_string (str): A string notation of arguments. + Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + Returns: + BlockArgs: The namedtuple defined at the top of this file. + """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) + + return BlockArgs( + num_repeat=int(options['r']), + kernel_size=int(options['k']), + stride=[int(options['s'][0])], + expand_ratio=int(options['e']), + input_filters=int(options['i']), + output_filters=int(options['o']), + se_ratio=float(options['se']) if 'se' in options else None, + id_skip=('noskip' not in block_string)) + + @staticmethod + def _encode_block_string(block): + """Encode a block to a string. + Args: + block (namedtuple): A BlockArgs type argument. + Returns: + block_string: A String form of BlockArgs. + """ + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """Decode a list of string notations to specify blocks inside the network. + Args: + string_list (list[str]): A list of strings, each string is a notation of block. + Returns: + blocks_args: A list of BlockArgs namedtuples of block args. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """Encode a list of BlockArgs to a list of strings. + Args: + blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + Returns: + block_strings: A list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet_params(model_name): + """Map EfficientNet model name to parameter coefficients. + Args: + model_name (str): Model name to be queried. + Returns: + params_dict[model_name]: A (width,depth,res,dropout) tuple. + """ + params_dict = { + # Coefficients: width,depth,res,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None, + dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True): + """Create BlockArgs and GlobalParams for efficientnet model. + Args: + width_coefficient (float) + depth_coefficient (float) + image_size (int) + dropout_rate (float) + drop_connect_rate (float) + num_classes (int) + Meaning as the name suggests. + Returns: + blocks_args, global_params. + """ + + # Blocks args for the whole model(efficientnet-b0 by default) + # It will be modified in the construction of EfficientNet Class according to model + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', + 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', + 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', + 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + image_size=image_size, + dropout_rate=dropout_rate, + + num_classes=num_classes, + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + drop_connect_rate=drop_connect_rate, + depth_divisor=8, + min_depth=None, + include_top=include_top, + ) + + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """Get the block args and global params for a given model name. + Args: + model_name (str): Model's name. + override_params (dict): A dict to modify global_params. + Returns: + blocks_args, global_params + """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + # note: all models have drop connect rate = 0.2 + blocks_args, global_params = efficientnet( + width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) + else: + raise NotImplementedError('model name is not pre-defined: {}'.format(model_name)) + if override_params: + # ValueError will be raised here if override_params has fields not included in global_params. + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +# train with Standard methods +# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) +url_map = { + 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', + 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', + 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', + 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', + 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', + 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', + 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', + 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', +} + +# train with Adversarial Examples(AdvProp) +# check more details in paper(Adversarial Examples Improve Image Recognition) +url_map_advprop = { + 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', + 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', + 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', + 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', + 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', + 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', + 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', + 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', + 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', +} + +# TODO: add the petrained weights url map of 'efficientnet-l2' + + +def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True): + """Loads pretrained weights from weights path or download using url. + Args: + model (Module): The whole model of efficientnet. + model_name (str): Model name of efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. + advprop (bool): Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + """ + if isinstance(weights_path, str): + state_dict = torch.load(weights_path) + else: + # AutoAugment or Advprop (different preprocessing) + url_map_ = url_map_advprop if advprop else url_map + state_dict = model_zoo.load_url(url_map_[model_name]) + + if load_fc: + ret = model.load_state_dict(state_dict, strict=False) + assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + ret = model.load_state_dict(state_dict, strict=False) + assert set(ret.missing_keys) == set( + ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) + assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) + + if verbose: + print('Loaded pretrained weights for {}'.format(model_name)) \ No newline at end of file From ffe51fc27eca64faa60ff4a73dcec44fe6b89864 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 06:29:17 +0000 Subject: [PATCH 0456/1892] enable b8 pipeline --- examples/efficientnet/efficientnet.py | 43 +++++++-- examples/efficientnet/train.py | 134 +++++++++++++++++++++++--- 2 files changed, 157 insertions(+), 20 deletions(-) diff --git a/examples/efficientnet/efficientnet.py b/examples/efficientnet/efficientnet.py index 4f5de160..0de902be 100644 --- a/examples/efficientnet/efficientnet.py +++ b/examples/efficientnet/efficientnet.py @@ -64,6 +64,7 @@ def __init__(self, block_args, global_params, image_size=None): self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size + in_image_size = image_size # Depthwise convolution phase k = self._block_args.kernel_size s = self._block_args.stride @@ -88,6 +89,9 @@ def __init__(self, block_args, global_params, image_size=None): self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) self._swish = MemoryEfficientSwish() + self.in_size = [inp, *in_image_size] + self.out_size = [final_oup, *image_size] + def forward(self, inputs, drop_connect_rate=None): """MBConvBlock's forward function. Args: @@ -99,6 +103,7 @@ def forward(self, inputs, drop_connect_rate=None): # Expansion and Depthwise Convolution x = inputs + assert list(x.shape)[1:] == self.in_size if self._block_args.expand_ratio != 1: x = self._expand_conv(inputs) x = self._bn0(x) @@ -127,6 +132,8 @@ def forward(self, inputs, drop_connect_rate=None): if drop_connect_rate: x = drop_connect(x, p=drop_connect_rate, training=self.training) x = x + inputs # skip connection + + assert list(x.shape)[1:] == self.out_size return x def set_swish(self, memory_efficient=True): @@ -212,6 +219,9 @@ def __init__(self, blocks_args=None, global_params=None): # set activation to memory efficient swish by default self._swish = MemoryEfficientSwish() + self.preprocess = True + self.postprocess = True + def set_swish(self, memory_efficient=True): """Sets swish function as memory efficient (for training) or standard (for export). Args: @@ -289,7 +299,7 @@ def extract_features(self, inputs): return x - def forward(self, inputs): + def forward(self, x, feature_map=None): """EfficientNet's forward function. Calls extract_features to extract features, applies final linear layer, and returns logits. Args: @@ -297,14 +307,29 @@ def forward(self, inputs): Returns: Output of this model after processing. """ - # Convolution layers - x = self.extract_features(inputs) - # Pooling and final linear layer - x = self._avg_pooling(x) - if self._global_params.include_top: - x = x.flatten(start_dim=1) - x = self._dropout(x) - x = self._fc(x) + if self.preprocess: + # Stem + x = self._swish(self._bn0(self._conv_stem(x))) + feature_map = x + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + feature_map = block(feature_map, drop_connect_rate=drop_connect_rate) + x = feature_map + + if self.postprocess: + # Head + x = self._swish(self._bn1(self._conv_head(x))) + # Pooling and final linear layer + x = self._avg_pooling(x) + if self._global_params.include_top: + x = x.flatten(start_dim=1) + x = self._dropout(x) + x = self._fc(x) + x = torch.sum(x) return x @classmethod diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index 8004f0ef..888bf1f0 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -1,12 +1,13 @@ """ python -m torch.distributed.launch \ - --nproc_per_node=2 \ + --nproc_per_node=8 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/efficientnet/train.py --bs 1 + examples/efficientnet/train.py \ + --pp 8 --gbs 8 --mbs 1 """ import torch @@ -21,12 +22,52 @@ from cube.profiler.memory import memory_summary from cube.runtime.device import DeviceGroup from cube.runtime.reducer import Reducer +from examples.efficientnet.schedule import is_last_stage, scheduling_1f1b +def model_partition(model, in_size): + # pipeline stage + pp_rank = torch.distributed.get_rank() + pp_size = torch.distributed.get_world_size() + + layers = model._blocks + + chunk = len(layers) // pp_size + if len(layers) % pp_size != 0: + remain = len(layers) % pp_size + if pp_rank < remain: + start = pp_rank * (chunk+1) + chunk = chunk + 1 + else: + start = remain * (chunk + 1) + (pp_rank - remain) * chunk + else: + start = pp_rank * chunk + stop = start + chunk + + print_each_rank(f'layer start -> end: {start} -> {stop}') + layers = layers[start:stop] + model._blocks = layers + + if pp_rank == 0: + model.preprocess = True + model.in_size = in_size + else: + model.preprocess = False + model.in_size = layers[0].in_size + + if is_last_stage(): + model.postprocess = True + model.out_size = [1,] + else: + model.postprocess = False + model.out_size = layers[-1].out_size + + return model + def train(args): - N = args.bs + N = args.gbs # L2 config # C, H, W = [3, 800, 800] @@ -36,19 +77,27 @@ def train(args): C, H, W = [3, 672, 672] model = EfficientNet.from_name('efficientnet-b8') + model = model_partition(model, [C, H, W]) + if args.fp16: + model == model.half() model = model.cuda() - if N % args.bs != 0: + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + memory_summary() + + if N % args.gbs != 0: raise RuntimeError("global bs is not divisible by DP") dataloader = cube.runtime.syndata.SynDataLoader( 1280, [0], [N // args.dp, C, H, W]) - + + if args.fp16: + data_buff = [[e.half() for e in data] for data in dataloader.datas] + dataloader.datas = data_buff def train_iter(model, dataloader): img = next(dataloader) - loss = model(img) - loss = torch.sum(loss) - loss.backward() + scheduling_1f1b(model, [img], args.gbs, args.mbs, dtype=torch.float) optimizer = torch.optim.RMSprop(model.parameters()) @@ -75,6 +124,15 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + iter_time = CudaTimer().duration(iter_num-10, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + + CudaTimer().print_all(times=iter_num-10) + memory_summary() + if __name__ == '__main__': @@ -82,11 +140,65 @@ def train_iter(model, dataloader): # resource allocation parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--bs', type=int, default=1, - help='bs') + parser.add_argument('--tp', type=int, default=1, + help='tensor parallel size') parser.add_argument('--dp', type=int, default=1, - help='data parallel') + help='data parallel size') + parser.add_argument('--pp', type=int, default=1, + help='pipeline parallel size') + parser.add_argument('--gbs', type=int, default=-1) + parser.add_argument('--mbs', type=int, default=-1) parser.add_argument('--fp16', action='store_true', dest='fp16') args = parser.parse_args() + + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + tp_size, tp_group_nums = args.tp, ndevs // args.tp + dp_size, dp_group_nums = args.dp, ndevs // args.dp + pp_size, pp_group_nums = args.pp, ndevs // args.pp + + if not pp_size * dp_size * tp_size == ndevs: + raise RuntimeError("Expected all devices are used") + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize data parallel group + all_data_parallel_group_ranks = list() + for i in range(pp_size): + start_rank = i * pp_group_nums + end_rank = (i + 1) * pp_group_nums + for j in range(tp_size): + ranks = list(range(start_rank + j, end_rank, tp_size)) + all_data_parallel_group_ranks.append(ranks) + # initialize groups + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + resource.dp_group = group + resource.reducer = cube.runtime.reducer.Reducer(ranks) + print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) + + # initialize pipelne parallel groups + for i in range(dp_size): + ranks = [data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks] + group = devs.get_group(ranks) + if myrank in ranks: + pp_ranks = ranks + resource.pp_group = group + print_each_rank(f'initialzed pipeline parallel group: {pp_ranks}', rank_only=myrank) + + # initialize tensor parallel groups + for i in range(tp_group_nums): + ranks = list(range(i * tp_size, (i + 1) * tp_size)) + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + resource.tp_group = group + print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + train(args) From c1e66f17832feaf7fa7fdc4046ace097b0027c0a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 06:34:08 +0000 Subject: [PATCH 0457/1892] enable 1f1b pipe on efficient-net l2 --- examples/efficientnet/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index 888bf1f0..5e6ef0ef 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -70,12 +70,12 @@ def train(args): N = args.gbs # L2 config - # C, H, W = [3, 800, 800] - # model = EfficientNet.from_name('efficientnet-l2') + C, H, W = [3, 800, 800] + model = EfficientNet.from_name('efficientnet-l2') # B8 config - C, H, W = [3, 672, 672] - model = EfficientNet.from_name('efficientnet-b8') + # C, H, W = [3, 672, 672] + # model = EfficientNet.from_name('efficientnet-b8') model = model_partition(model, [C, H, W]) if args.fp16: From 6c7d849591217ea7ff3115f1c358ffbd73f46142 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 06:43:51 +0000 Subject: [PATCH 0458/1892] efficientnet train --- examples/efficientnet/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index 5e6ef0ef..e95a2985 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -104,9 +104,9 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() torch.distributed.barrier() span = 0 - iter_num = 128 + iter_num = 40 for step in range(iter_num): - if step >= 40: + if step >= 20: torch.cuda.synchronize() start = time.time() CudaTimer(enable=True).start('e2e') @@ -116,7 +116,7 @@ def train_iter(model, dataloader): if step == 1: print('> passed on 1st iteration') memory_summary() - if step >= 40: + if step >= 20: torch.cuda.synchronize() stop = time.time() span += (stop - start) * 1000 @@ -124,13 +124,13 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = CudaTimer().duration(iter_num-10, field_name='e2e') + iter_time = CudaTimer().duration(iter_num-20, field_name='e2e') throughput = N / iter_time * 1000 print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) - CudaTimer().print_all(times=iter_num-10) + CudaTimer().print_all(times=iter_num-20) memory_summary() From 7108f59095f3f2cb1a2c521342e92d844b0df236 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 07:19:48 +0000 Subject: [PATCH 0459/1892] checkpoint --- examples/efficientnet/efficientnet.py | 34 ++++++++------------------- examples/efficientnet/train.py | 4 ++++ 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/examples/efficientnet/efficientnet.py b/examples/efficientnet/efficientnet.py index 0de902be..cf867633 100644 --- a/examples/efficientnet/efficientnet.py +++ b/examples/efficientnet/efficientnet.py @@ -24,6 +24,8 @@ calculate_output_image_size ) +import torch.utils.checkpoint as checkpoint + VALID_MODELS = ( 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', @@ -219,6 +221,7 @@ def __init__(self, blocks_args=None, global_params=None): # set activation to memory efficient swish by default self._swish = MemoryEfficientSwish() + self.use_checkpoint = [False] * len(self._blocks) self.preprocess = True self.postprocess = True @@ -276,29 +279,6 @@ def extract_endpoints(self, inputs): return endpoints - def extract_features(self, inputs): - """use convolution layer to extract feature . - Args: - inputs (tensor): Input tensor. - Returns: - Output of the final convolution - layer in the efficientnet model. - """ - # Stem - x = self._swish(self._bn0(self._conv_stem(inputs))) - - # Blocks - for idx, block in enumerate(self._blocks): - drop_connect_rate = self._global_params.drop_connect_rate - if drop_connect_rate: - drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate - x = block(x, drop_connect_rate=drop_connect_rate) - - # Head - x = self._swish(self._bn1(self._conv_head(x))) - - return x - def forward(self, x, feature_map=None): """EfficientNet's forward function. Calls extract_features to extract features, applies final linear layer, and returns logits. @@ -317,7 +297,13 @@ def forward(self, x, feature_map=None): drop_connect_rate = self._global_params.drop_connect_rate if drop_connect_rate: drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate - feature_map = block(feature_map, drop_connect_rate=drop_connect_rate) + else: + drop_connect_rate = None + if self.use_checkpoint[idx]: + feature_map = checkpoint.checkpoint(block, feature_map, drop_connect_rate) + else: + feature_map = block(feature_map, drop_connect_rate=drop_connect_rate) + # feature_map = block(feature_map, drop_connect_rate=drop_connect_rate) x = feature_map if self.postprocess: diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index e95a2985..4c43d440 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -44,9 +44,13 @@ def model_partition(model, in_size): start = pp_rank * chunk stop = start + chunk + # use_checkpoint = [False] * (stop - start) + use_checkpoint = [True] * (stop - start) + print_each_rank(f'layer start -> end: {start} -> {stop}') layers = layers[start:stop] model._blocks = layers + model.use_checkpoint = use_checkpoint if pp_rank == 0: model.preprocess = True From b6ae862c987e91ccc9b2459753d4a91438e0babf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 13:34:22 +0000 Subject: [PATCH 0460/1892] memory profiler --- cube/profiler/memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index 919b8fe3..ad2f9aff 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -6,6 +6,6 @@ def memory_summary(): # memory measurement mem = torch.cuda.max_memory_allocated() # mem = torch.cuda.max_memory_reserved() - print( + print_each_rank( '{:.2f}GB memory consumption'.format(mem / 1024 / 1024 / 1024), ) From 3bc17b1fbd29af731581b7361c18e4e8eac4e408 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 13:37:52 +0000 Subject: [PATCH 0461/1892] test script --- examples/efficientnet/train.py | 34 +++++++++++++++++++++++++++++++++- examples/swin/swin_pipe.py | 6 +++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index 4c43d440..6f365acf 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -47,6 +47,31 @@ def model_partition(model, in_size): # use_checkpoint = [False] * (stop - start) use_checkpoint = [True] * (stop - start) + # layer_split = [6, 6, 8, 16, 14, 13, 14, 11] + # assert sum(layer_split) == 88 + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) + # + # # use_checkpoint = [True] * (stop - start) + # use_checkpoint = [False] * (stop - start) + # # use_checkpoint = [True] * (stop - start) + # if pp_rank == 0: + # for idx in range(stop - start): + # if idx < 5: + # use_checkpoint[idx] = True + # if pp_rank == 1: + # for idx in range(stop - start): + # if idx < 5: + # use_checkpoint[idx] = True + # if pp_rank == 2: + # for idx in range(stop - start): + # if idx < 5: + # use_checkpoint[idx] = True + # if pp_rank == 3: + # for idx in range(stop - start): + # if idx < 4: + # use_checkpoint[idx] = True + print_each_rank(f'layer start -> end: {start} -> {stop}') layers = layers[start:stop] model._blocks = layers @@ -108,7 +133,7 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() torch.distributed.barrier() span = 0 - iter_num = 40 + iter_num = 60 for step in range(iter_num): if step >= 20: torch.cuda.synchronize() @@ -153,6 +178,8 @@ def train_iter(model, dataloader): parser.add_argument('--gbs', type=int, default=-1) parser.add_argument('--mbs', type=int, default=-1) parser.add_argument('--fp16', action='store_true', dest='fp16') + parser.add_argument('--memory-limit', type=float, default=None, + help='memory fraction limit') args = parser.parse_args() @@ -205,4 +232,9 @@ def train_iter(model, dataloader): resource.tp_group = group print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + if args.memory_limit is not None: + assert isinstance(args.memory_limit, float) + print_each_rank(f'set memory constraints on {args.memory_limit} fraction.') + torch.cuda.set_per_process_memory_fraction(args.memory_limit) + train(args) diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index 5aadec2f..b1d12a36 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -603,10 +603,10 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, start = pp_rank * chunk stop = start + chunk - self.use_checkpoint = [True] * (stop - start) + # self.use_checkpoint = [True] * (stop - start) # 8gpu layer assign - # layer_split = [3, 4, 3, 3, 3, 4, 4, 3] + # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] # assert sum(layer_split) == 27 # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) @@ -617,7 +617,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) - # self.use_checkpoint = [False] * (stop - start) + self.use_checkpoint = [False] * (stop - start) # for idx in range(stop - start): # if pp_rank == 0: # if idx in [0,]: From 9aa8cc82fe10a54bc11bae20742df6b182c74ab9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 13:40:33 +0000 Subject: [PATCH 0462/1892] eval scripts --- eval/swin_infer_224_782M_fp16.sh | 97 ++++++++++++++++++++++++++++++++ eval/swin_train_fp32.sh | 68 ++++++++++++++++++++-- 2 files changed, 161 insertions(+), 4 deletions(-) create mode 100644 eval/swin_infer_224_782M_fp16.sh diff --git a/eval/swin_infer_224_782M_fp16.sh b/eval/swin_infer_224_782M_fp16.sh new file mode 100644 index 00000000..6ad1ceb8 --- /dev/null +++ b/eval/swin_infer_224_782M_fp16.sh @@ -0,0 +1,97 @@ + +mkdir -p expinfer_224_782M_fp16 + +# ================== Maximal Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 1 \ + --layer1 1 1 1 \ + --layer2 1 1 1 \ + --layer3 1 1 1 \ + --fp16 \ + > expinfer_Gfp16_bs1/1gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 2 \ + --layer1 1 1 2 \ + --layer2 1 1 2 \ + --layer3 1 1 2 \ + --fp16 \ + > expinfer_Gfp16_bs1/2gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 1 4 \ + --layer1 1 1 4 \ + --layer2 1 1 4 \ + --layer3 1 1 4 \ + --fp16 \ + > expinfer_Gfp16_bs1/4gpu_tp.txt + + +# ================== Window + Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 2 1 \ + --layer1 1 2 1 \ + --layer2 1 2 1 \ + --layer3 1 1 2 \ + --fp16 \ + > expinfer_Gfp16_bs1/2gpu_2wp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 4 1 \ + --layer1 1 4 1 \ + --layer2 1 4 1 \ + --layer3 1 1 4 \ + --fp16 \ + > expinfer_Gfp16_bs1/4gpu_4wp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 1 \ + --layer0 1 8 1 \ + --layer1 1 8 1 \ + --layer2 1 4 2 \ + --layer3 1 1 8 \ + --fp16 \ + > expinfer_Gfp16_bs1/8gpu_8wp8tp.txt + diff --git a/eval/swin_train_fp32.sh b/eval/swin_train_fp32.sh index 8a15b796..ad441b70 100755 --- a/eval/swin_train_fp32.sh +++ b/eval/swin_train_fp32.sh @@ -1,9 +1,23 @@ bs=$1 -logfile=exptrain_782M_bs${bs}_fp32 +logfile=exptrain_782M_bs${bs}_fp32_384 mkdir -p ${logfile} +# python -m torch.distributed.launch \ +# --nproc_per_node=1 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt.py --bs ${bs} \ +# --layer0 1 1 1 \ +# --layer1 1 1 1 \ +# --layer2 1 1 1 \ +# --layer3 1 1 1 \ +# > ${logfile}/single.txt + # ================== Megatron Policy Parallel =============== python -m torch.distributed.launch \ @@ -18,7 +32,7 @@ python -m torch.distributed.launch \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - > ${logfile}/2gpu_maxdp.txt + > ${logfile}/2gpu_maxdp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -32,7 +46,7 @@ python -m torch.distributed.launch \ --layer1 2 1 2 \ --layer2 2 1 2 \ --layer3 2 1 2 \ - > ${logfile}/4gpu_maxdp.txt + > ${logfile}/4gpu_maxdp2tp.txt python -m torch.distributed.launch \ @@ -47,7 +61,7 @@ python -m torch.distributed.launch \ --layer1 4 1 2 \ --layer2 4 1 2 \ --layer3 4 1 2 \ - > ${logfile}/8gpu_maxdp.txt + > ${logfile}/8gpu_maxdp2tp.txt # ================== Maximal Tensor Parallel =============== @@ -181,3 +195,49 @@ python -m torch.distributed.launch \ --layer2 1 8 \ --layer3 1 8 \ > ${logfile}/8gpu_dt_8dp8tp.txt + + +# ================== Pure Data Parallel ============= + +# python -m torch.distributed.launch \ +# --nproc_per_node=2 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt.py --bs ${bs} \ +# --layer0 2 1 1 \ +# --layer1 2 1 1 \ +# --layer2 2 1 1 \ +# --layer3 2 1 1 \ +# > ${logfile}/2gpu_maxdp.txt +# +# +# python -m torch.distributed.launch \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt.py --bs ${bs} \ +# --layer0 4 1 1 \ +# --layer1 4 1 1 \ +# --layer2 4 1 1 \ +# --layer3 4 1 1 \ +# > ${logfile}/4gpu_maxdp.txt +# +# python -m torch.distributed.launch \ +# --nproc_per_node=8 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt.py --bs ${bs} \ +# --layer0 8 1 1 \ +# --layer1 8 1 1 \ +# --layer2 8 1 1 \ +# --layer3 8 1 1 \ +# > ${logfile}/8gpu_maxdp.txt From 635529f055df11b3cd0eb2d0d75adebf660616e7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Dec 2021 15:43:01 +0000 Subject: [PATCH 0463/1892] update script --- ...fp16.sh => swin_infer_bs1_224_782Mfp32.sh} | 22 ++-- eval/swin_infer_bs2_224_782Mfp32.sh | 107 ++++++++++++++++++ 2 files changed, 116 insertions(+), 13 deletions(-) rename eval/{swin_infer_224_782M_fp16.sh => swin_infer_bs1_224_782Mfp32.sh} (84%) mode change 100644 => 100755 create mode 100755 eval/swin_infer_bs2_224_782Mfp32.sh diff --git a/eval/swin_infer_224_782M_fp16.sh b/eval/swin_infer_bs1_224_782Mfp32.sh old mode 100644 new mode 100755 similarity index 84% rename from eval/swin_infer_224_782M_fp16.sh rename to eval/swin_infer_bs1_224_782Mfp32.sh index 6ad1ceb8..fe73f429 --- a/eval/swin_infer_224_782M_fp16.sh +++ b/eval/swin_infer_bs1_224_782Mfp32.sh @@ -1,5 +1,7 @@ -mkdir -p expinfer_224_782M_fp16 +logfile=expinfer_224_782M_fp32_bs1 + +mkdir -p ${logfile} # ================== Maximal Tensor Parallel =============== python -m torch.distributed.launch \ @@ -14,8 +16,7 @@ python -m torch.distributed.launch \ --layer1 1 1 1 \ --layer2 1 1 1 \ --layer3 1 1 1 \ - --fp16 \ - > expinfer_Gfp16_bs1/1gpu_tp.txt + > ${logfile}/1gpu_tp.txt python -m torch.distributed.launch \ --nproc_per_node=2 \ @@ -29,8 +30,7 @@ python -m torch.distributed.launch \ --layer1 1 1 2 \ --layer2 1 1 2 \ --layer3 1 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs1/2gpu_tp.txt + > ${logfile}/2gpu_tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -44,8 +44,7 @@ python -m torch.distributed.launch \ --layer1 1 1 4 \ --layer2 1 1 4 \ --layer3 1 1 4 \ - --fp16 \ - > expinfer_Gfp16_bs1/4gpu_tp.txt + > ${logfile}/4gpu_tp.txt # ================== Window + Tensor Parallel =============== @@ -62,8 +61,7 @@ python -m torch.distributed.launch \ --layer1 1 2 1 \ --layer2 1 2 1 \ --layer3 1 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs1/2gpu_2wp2tp.txt + > ${logfile}/2gpu_2wp2tp.txt python -m torch.distributed.launch \ --nproc_per_node=4 \ @@ -77,8 +75,7 @@ python -m torch.distributed.launch \ --layer1 1 4 1 \ --layer2 1 4 1 \ --layer3 1 1 4 \ - --fp16 \ - > expinfer_Gfp16_bs1/4gpu_4wp4tp.txt + > ${logfile}/4gpu_4wp4tp.txt python -m torch.distributed.launch \ --nproc_per_node=8 \ @@ -92,6 +89,5 @@ python -m torch.distributed.launch \ --layer1 1 8 1 \ --layer2 1 4 2 \ --layer3 1 1 8 \ - --fp16 \ - > expinfer_Gfp16_bs1/8gpu_8wp8tp.txt + > ${logfile}/8gpu_8wp8tp.txt diff --git a/eval/swin_infer_bs2_224_782Mfp32.sh b/eval/swin_infer_bs2_224_782Mfp32.sh new file mode 100755 index 00000000..46b32221 --- /dev/null +++ b/eval/swin_infer_bs2_224_782Mfp32.sh @@ -0,0 +1,107 @@ + +logfile=expinfer_224_782M_fp32_bs2 + +mkdir -p ${logfile} + +# ================== Maximal Tensor Parallel =============== +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 1 1 1 \ + --layer1 1 1 1 \ + --layer2 1 1 1 \ + --layer3 1 1 1 \ + > ${logfile}/1gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 1 1 \ + --layer1 2 1 1 \ + --layer2 2 1 1 \ + --layer3 2 1 1 \ + > ${logfile}/2gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 1 2 \ + --layer1 2 1 2 \ + --layer2 2 1 2 \ + --layer3 2 1 2 \ + > ${logfile}/4gpu_tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 1 4 \ + --layer1 2 1 4 \ + --layer2 2 1 4 \ + --layer3 2 1 4 \ + > ${logfile}/8gpu_tp.txt + + +# ================== Window + Tensor Parallel =============== + +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 1 1 \ + --layer1 2 1 1 \ + --layer2 2 1 1 \ + --layer3 2 1 1 \ + > ${logfile}/2gpu_2wp2tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 2 1 \ + --layer1 2 2 1 \ + --layer2 2 2 1 \ + --layer3 2 1 2 \ + > ${logfile}/4gpu_4wp4tp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt_infer.py --bs 2 \ + --layer0 2 4 1 \ + --layer1 2 4 1 \ + --layer2 2 4 1 \ + --layer3 2 1 4 \ + > ${logfile}/8gpu_8wp8tp.txt + From b632198d09afc78781d3155dfdec32b2e96b302b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Dec 2021 05:29:24 +0000 Subject: [PATCH 0464/1892] efficientdet 8GB mem constraints --- examples/efficientnet/train.py | 58 ++++++++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index 6f365acf..73a2f7af 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -47,21 +47,57 @@ def model_partition(model, in_size): # use_checkpoint = [False] * (stop - start) use_checkpoint = [True] * (stop - start) - # layer_split = [6, 6, 8, 16, 14, 13, 14, 11] - # assert sum(layer_split) == 88 + # layer_split = [8, 5, 7, 14, 15, 13, 15, 11] + # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) # - # # use_checkpoint = [True] * (stop - start) # use_checkpoint = [False] * (stop - start) - # # use_checkpoint = [True] * (stop - start) # if pp_rank == 0: # for idx in range(stop - start): - # if idx < 5: + # if idx < 7: # use_checkpoint[idx] = True # if pp_rank == 1: # for idx in range(stop - start): - # if idx < 5: + # if idx < 4: + # use_checkpoint[idx] = True + # if pp_rank == 2: + # for idx in range(stop - start): + # if idx < 4: + # use_checkpoint[idx] = True + # if pp_rank == 3: + # for idx in range(stop - start): + # if idx < 3: + # use_checkpoint[idx] = True + + # 8 gpu naive partition plan + # if pp_rank == 0: + # for idx in range(stop - start): + # if idx < 10: + # use_checkpoint[idx] = True + # if pp_rank == 1: + # for idx in range(stop - start): + # if idx < 8: + # use_checkpoint[idx] = True + # if pp_rank == 2: + # for idx in range(stop - start): + # if idx < 2: + # use_checkpoint[idx] = True + + # 8GB memory experiments + # layer_split = [8, 5, 7, 14, 14, 13, 16, 11] + # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) + # + # use_checkpoint = [False] * (stop - start) + # if pp_rank == 0: + # for idx in range(stop - start): + # if idx < 8: + # use_checkpoint[idx] = True + # if pp_rank == 1: + # for idx in range(stop - start): + # if idx < 4: # use_checkpoint[idx] = True # if pp_rank == 2: # for idx in range(stop - start): @@ -69,6 +105,14 @@ def model_partition(model, in_size): # use_checkpoint[idx] = True # if pp_rank == 3: # for idx in range(stop - start): + # if idx < 8: + # use_checkpoint[idx] = True + # if pp_rank == 4: + # for idx in range(stop - start): + # if idx < 5: + # use_checkpoint[idx] = True + # if pp_rank == 5: + # for idx in range(stop - start): # if idx < 4: # use_checkpoint[idx] = True @@ -133,7 +177,7 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() torch.distributed.barrier() span = 0 - iter_num = 60 + iter_num = 40 for step in range(iter_num): if step >= 20: torch.cuda.synchronize() From 38b59a4ee95ce452ab9b304171095fe48ca58db8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Dec 2021 06:58:51 +0000 Subject: [PATCH 0465/1892] update docker info --- scripts/env-setup.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index 44ce1bdb..a4898aff 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -1,5 +1,5 @@ -echo using docker image pytorch-cuda11.3: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime +echo using docker image nvcr.io/pytorch:pytorch-21.06-py3 git config --global core.editor "vim" git config --global user.name "Zhiqi Lin" @@ -38,7 +38,7 @@ echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc # cmd for count code lines # find cube/ -name "*.py" -print0 | xargs -0 wc -l -pip uninstall training_daemon +pip uninstall training_daemon -y python setup.py develop pip install -r requirements.txt From dfaddfc23b8f82b82d27575b8143faf0f6de53d0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Dec 2021 06:59:13 +0000 Subject: [PATCH 0466/1892] swin example --- examples/swin/swin_dwt_infer.py | 16 ++++++++-------- examples/swin/swin_pipe.py | 18 +++++++++--------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/swin/swin_dwt_infer.py b/examples/swin/swin_dwt_infer.py index 652eb074..c7f8ccda 100644 --- a/examples/swin/swin_dwt_infer.py +++ b/examples/swin/swin_dwt_infer.py @@ -793,10 +793,10 @@ def train(args, pconfigs): # dim_head is always 32 # img resolution, windows size: 224, 384, 518, 640 - # C, H, W, window_size = [3, 224, 224, 7] + C, H, W, window_size = [3, 224, 224, 7] # C, H, W, window_size = [3, 384, 384, 12] # C, H, W, window_size = [3, 518, 518, ?] - C, H, W, window_size = [3, 640, 640, 20] + # C, H, W, window_size = [3, 640, 640, 20] # C, H, W, window_size = [3, 1536, 1536, 48] # image batch size @@ -823,15 +823,15 @@ def train(args, pconfigs): # ] # SwinV2-H modified: 782 M - # embed_dim, depths, num_heads = [ - # 384, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - - # SwinV2-G: 2.5B Model embed_dim, depths, num_heads = [ - 512, [2, 2, 42, 2], [16, 32, 64, 128] + 384, [2, 2, 18, 2], [12, 24, 48, 96] ] + # SwinV2-G: 2.5B Model + # embed_dim, depths, num_heads = [ + # 512, [2, 2, 42, 2], [16, 32, 64, 128] + # ] + # 895.7 M Model # embed_dim, depths, num_heads = [ # 384, [2, 2, 22, 2], [12, 24, 48, 96] diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index b1d12a36..59a103a0 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -3,13 +3,13 @@ # Modified from Swin-Transformer Repo """ python -m torch.distributed.launch \ - --nproc_per_node=4 \ + --nproc_per_node=8 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_pipe.py --pp 8 --gbs 1 --mbs 1 + examples/swin/swin_pipe.py --pp 8 --gbs 32 --mbs 4 # V100-16GB: 8GPU: need checkpoint: 8 micro bs """ @@ -603,13 +603,19 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, start = pp_rank * chunk stop = start + chunk - # self.use_checkpoint = [True] * (stop - start) + self.use_checkpoint = [True] * (stop - start) + # self.use_checkpoint = [False] * (stop - start) # 8gpu layer assign # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] # assert sum(layer_split) == 27 # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) + # self.use_checkpoint = [False] * (stop - start) + # for idx in range(stop - start): + # if pp_rank == 0: + # if idx < 1: + # self.use_checkpoint[idx] = True # 4Ggpu layer assign # layer_split = [6, 7, 7, 7] @@ -617,12 +623,6 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) - self.use_checkpoint = [False] * (stop - start) - # for idx in range(stop - start): - # if pp_rank == 0: - # if idx in [0,]: - # self.use_checkpoint[idx] = True - print_each_rank(f'layer start -> end: {start} -> {stop}') print_each_rank(self.use_checkpoint) self.layers = layers[start:stop] From 7d56e30af749062494cac494a7f1cc3052be98c1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 3 Dec 2021 09:16:54 +0000 Subject: [PATCH 0467/1892] unflatten and flatten --- cube/runtime/reducer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/runtime/reducer.py b/cube/runtime/reducer.py index 012c9375..3968acc1 100644 --- a/cube/runtime/reducer.py +++ b/cube/runtime/reducer.py @@ -66,7 +66,7 @@ def _flatten_dense_tensors(self, tensors): Returns: A contiguous 1D buffer containing input tensors. """ - return torch._C._nn.flatten_dense_tensors(tensors) + return torch._utils._flatten_dense_tensors(tensors) def _unflatten_dense_tensors(self, flat, tensors): """ @@ -82,4 +82,4 @@ def _unflatten_dense_tensors(self, flat, tensors): Unflattened dense tensors with sizes same as tensors and values from flat. """ - return torch._C._nn.unflatten_dense_tensors(flat, tensors) + return torch._utils._unflatten_dense_tensors(flat, tensors) From e2d921b9d2567e267bb31902e4c5903b24e55267 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 3 Dec 2021 12:00:36 +0000 Subject: [PATCH 0468/1892] enable hybrid training --- cube/runtime/collectives.py | 24 +- cube/runtime/function/dist.py | 1 - examples/swin/hybrid_schedule.py | 251 ++++++++ examples/swin/pmodule.py | 65 ++ examples/swin/swin_hybrid.py | 1022 ++++++++++++++++++++++++++++++ 5 files changed, 1361 insertions(+), 2 deletions(-) create mode 100644 examples/swin/hybrid_schedule.py create mode 100644 examples/swin/pmodule.py create mode 100644 examples/swin/swin_hybrid.py diff --git a/cube/runtime/collectives.py b/cube/runtime/collectives.py index 7cabe3c8..245df862 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/collectives.py @@ -2,6 +2,7 @@ import torch from cube.runtime.device import DeviceGroup +from cube.profiler.timer import CudaTimer def send(tensors: List[torch.Tensor], to_ranks: List[int]): @@ -14,6 +15,7 @@ def send(tensors: List[torch.Tensor], to_ranks: List[int]): tensor_devices (List[List[int]]): tensor sent devices """ # print(f'{torch.distributed.get_rank()}: sending...') + CudaTimer().start(field_name='comm') send_ops = list() ## synthetic ## @@ -30,9 +32,11 @@ def send(tensors: List[torch.Tensor], to_ranks: List[int]): for req in reqs: req.wait() torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') def recv(shapes: List[List[int]], from_ranks: List[int]): + CudaTimer().start(field_name='comm') # print(f'{torch.distributed.get_rank()}: recving...') recv_ops = list() recv_tensors = list() @@ -43,7 +47,6 @@ def recv(shapes: List[List[int]], from_ranks: List[int]): # torch.ones(tuple(shape), # device=torch.cuda.current_device() # )) - for shape, rank in zip(shapes, from_ranks): tensor = torch.empty( shape, requires_grad=True, device=torch.cuda.current_device() @@ -58,12 +61,15 @@ def recv(shapes: List[List[int]], from_ranks: List[int]): req.wait() torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + if len(recv_tensors) == 0: return None elif len(recv_tensors) == 1: return recv_tensors[0] else: return tuple(recv_tensors) def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): + CudaTimer().start(field_name='comm') # print('sending and recving...') ops = list() recv_tensors = list() @@ -94,6 +100,8 @@ def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): req.wait() torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + if len(recv_tensors) == 0: return None elif len(recv_tensors) == 1: return recv_tensors[0] else: return tuple(recv_tensors) @@ -103,6 +111,7 @@ def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): """ Allreduce """ + CudaTimer().start(field_name='comm') # print(f'{torch.distributed.get_rank()}: all_reduce...') assert len(tensors) == 1 tensor = tensors[0] @@ -116,6 +125,8 @@ def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(tensor, group=group) + + CudaTimer().stop(field_name='comm') return tensor @@ -124,6 +135,8 @@ def all_gather(tensors: List[torch.Tensor], ranks: List[int]): Allgather """ # print(f'{torch.distributed.get_rank()}: all_gather...') + CudaTimer().start(field_name='comm') + assert len(tensors) == 1 tensor = tensors[0] group = DeviceGroup().get_group(ranks) @@ -132,6 +145,8 @@ def all_gather(tensors: List[torch.Tensor], ranks: List[int]): tensor_list[idx] = tensor torch.distributed.all_gather(tensor_list, tensor, group=group) tensor_list = [t for oidx, t in enumerate(tensor_list) if oidx != idx] + + CudaTimer().stop(field_name='comm') if len(tensor_list) == 1: return tensor_list[0] else: @@ -143,6 +158,8 @@ def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): ReduceScatter """ # print(f'{torch.distributed.get_rank()}: reduce-scatter...') + CudaTimer().start(field_name='comm') + tensors = list(tensors) group = DeviceGroup().get_group(ranks) idx = ranks.index(DeviceGroup().rank) @@ -150,6 +167,8 @@ def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): torch.distributed.reduce_scatter( output, tensors, group=group ) + + CudaTimer().stop(field_name='comm') return output @@ -157,6 +176,7 @@ def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None): """ Broadcast. ranks[0] is the root """ + CudaTimer().start(field_name='comm') # print(f'{torch.distributed.get_rank()}: broadcast...') # FIXME: data type if len(tensors) == 1: @@ -166,4 +186,6 @@ def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None): # tensor.requires_grad_() group = DeviceGroup().get_group(ranks) torch.distributed.broadcast(tensor, ranks[0], group=group) + + CudaTimer().stop(field_name='comm') return tensor diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py index 67ff4aa5..f367daf4 100644 --- a/cube/runtime/function/dist.py +++ b/cube/runtime/function/dist.py @@ -23,7 +23,6 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, gro shift: int dim: int """ - return input world_size = len(dim_ranks) if world_size == 1: return torch.roll(input, (shift), (dim,)) diff --git a/examples/swin/hybrid_schedule.py b/examples/swin/hybrid_schedule.py new file mode 100644 index 00000000..4b474fca --- /dev/null +++ b/examples/swin/hybrid_schedule.py @@ -0,0 +1,251 @@ +import torch + +from torch.distributed.distributed_c10d import _get_global_rank +from cube.profiler.timer import CudaTimer + + +def get_global_rank(group, group_rank): + if group is None: + return group_rank + else: + return _get_global_rank(group, group_rank) + + +def is_last_stage(group): + return torch.distributed.get_rank(group=group) == torch.distributed.get_world_size(group=group) - 1 + + +#================= WhatToDO functions ==================# + +def forward_step(model, image, trans_input=None): + CudaTimer().start("forward") + output = model(image, trans_input) + CudaTimer().stop("forward") + return output + + +def backward_step(feature_map, output_tensor, output_tensor_grad): + """ + Calculate input tensor gradient + """ + if feature_map is not None and feature_map.requires_grad: + feature_map.retain_grad() + CudaTimer().start("backward") + torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) + CudaTimer().stop("backward") + input_tensor_grad = None + if feature_map is not None and feature_map.requires_grad: + input_tensor_grad = feature_map.grad + return input_tensor_grad + +#================= WhatToDO functions ==================# + +#================= Between Stage functions ==================# + +def send(tensors, to_rank, group): + """ + send tensor to the target rank + """ + if to_rank < 0 or to_rank >= torch.distributed.get_world_size(group): + return None + if group is not None: + to_rank = get_global_rank(group, to_rank) + # print(f'send: {torch.distributed.get_rank()} -> {to_rank}: {tensors[0].shape}') + assert isinstance(tensors, list) or isinstance(tensors, tuple) + CudaTimer().start("send") + reqs = list() + for tensor in tensors: + if tensor is None: + continue + elif torch.is_tensor(tensor): + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, to_rank + ) + reqs.append(send_op) + else: + raise RuntimeError("Expected tensor or None") + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("send") + + +def recv(shapes, from_rank, dtype, group): + if from_rank < 0 or from_rank >= torch.distributed.get_world_size(group): + return [None] * len(shapes) + assert isinstance(shapes, list) or isinstance(shapes, tuple) + if group is not None: + from_rank = get_global_rank(group, from_rank) + # print(f'recv: {torch.distributed.get_rank()} <- {from_rank}: {shapes}') + CudaTimer().start("recv") + reqs = list() + recved_tensors = list() + for shape in shapes: + if shape is None: + recved_tensors.append(None) + continue + tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device(), + dtype=dtype + ) + recved_tensors.append(tensor) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, from_rank + ) + reqs.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("recv") + return recved_tensors + + +def send_and_recv(send_tensors, recv_shapes, rank, dtype, group): + if rank < 0 or rank >= torch.distributed.get_world_size(group): + return [None] * len(recv_shapes) + if group is not None: + rank = get_global_rank(group, rank) + # print(f'exchange: {torch.distributed.get_rank()} <-> {rank}: {recv_shapes}') + assert isinstance(send_tensors, list) or isinstance(send_tensors, tuple) + assert isinstance(recv_shapes, list) or isinstance(recv_shapes, tuple) + CudaTimer().start("send_recv") + reqs = list() + recved_tensors = list() + for tensor in send_tensors: + if tensor is None: + continue + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, rank + ) + reqs.append(send_op) + for shape in recv_shapes: + if shape is None: + recved_tensors.append(None) + continue + recv_tensor = torch.empty( + shape, requires_grad=True, device=torch.cuda.current_device(), + dtype=dtype + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, rank + ) + recved_tensors.append(recv_tensor) + reqs.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(reqs) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop("send_recv") + return recved_tensors + +#================= Between Stage functions ==================# + +def split_batch(inputs, num_microbatches): + """ + Split a mini-batch to micro-batches + """ + assert isinstance(inputs, list) or isinstance(inputs, tuple) + input_chunks = list() + for feature_map in inputs: + if torch.is_tensor(feature_map): + feature_map = torch.chunk(feature_map, chunks=num_microbatches, dim=0) + else: + feature_map = [feature_map] * num_microbatches + input_chunks.append(feature_map) + micro_batches = list() + for micro_data in zip(*tuple(input_chunks)): + micro_batches.append(micro_data) + return micro_batches + + +#================= Scheduling ==================# + +def scheduling_1f1b(model, inputs, bs, micro_bs, dtype, group): + myrank = torch.distributed.get_rank(group) + + num_microbatches = int(bs / micro_bs) + num_warmup_microbatches = \ + (torch.distributed.get_world_size(group) - + torch.distributed.get_rank(group) - 1) + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_warmup_remaining = num_microbatches - num_warmup_microbatches + + input_tensors = list() + output_tensors = list() + + inputs = split_batch(inputs, num_microbatches) + + # warmup forward pass + for i in range(num_warmup_microbatches): + # recv forward + # print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) + feature_map = recv( + (torch.Size(model.in_size),), myrank-1, dtype, group + )[0] + image = inputs[i][0] + # forward + output_tensor = forward_step(model, image, feature_map) + # send forward + # print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) + send((output_tensor,), myrank+1, group) + + input_tensors.append(feature_map) + output_tensors.append(output_tensor) + + # before running 1F1B, need to recieve first forward tensor + if num_warmup_remaining > 0: + # recv forward + # print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) + feature_map = recv( + (torch.Size(model.in_size),), myrank-1, dtype, group + )[0] + image = inputs[num_warmup_microbatches][0] + + # run 1F1B + for i in range(num_warmup_remaining): + # forward + output_tensor = forward_step(model, image, feature_map) + # send forward + recv backward grads + # print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) + output_tensor_grad = send_and_recv( + (output_tensor,), + (torch.Size(model.out_size),), + myrank+1, dtype, group + )[0] + input_tensors.append(feature_map) + output_tensors.append(output_tensor) + # backward + feature_map, output_tensor = input_tensors.pop(0), output_tensors.pop(0) + input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) + if i != (num_warmup_remaining-1): + # send backward grads + recv forward results + # print('[1f1b] rank {}: step-{}: sending backward + recving forward...'.format(myrank, i)) + feature_map = send_and_recv( + (input_tensor_grad,), + (torch.Size(model.in_size),), + myrank-1, dtype, group + )[0] + image = inputs[num_warmup_microbatches+i+1][0] + else: # last iteration - no more inputs + feature_map = None + # send backward grads + # print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) + send((input_tensor_grad,), myrank-1, group) + + # cooldown gradient trans back + for i in range(num_warmup_microbatches): + feature_map = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + # recv backward gradients + output_tensor_grad = recv( + (torch.Size(model.out_size),), myrank+1, dtype, group + )[0] + # backward + input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) + # send backward gradients + # print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) + send((input_tensor_grad,), myrank-1, group) + +#================= Scheduling ==================# \ No newline at end of file diff --git a/examples/swin/pmodule.py b/examples/swin/pmodule.py new file mode 100644 index 00000000..1ab0387e --- /dev/null +++ b/examples/swin/pmodule.py @@ -0,0 +1,65 @@ +from typing import List + +import torch + +from cube.runtime.device import DeviceGroup + + +class ParallelModule(torch.nn.Module): + + def __init__(self, pp_ranks: List[int] = list(), + dp_ranks: List[int] = list(), + tp_ranks: List[int] = list()): + + super().__init__() + self._pp_ranks = tuple(pp_ranks) + self._pp_group = DeviceGroup().get_group(pp_ranks) + + self._dp_ranks = tuple(dp_ranks) + self._dp_group = DeviceGroup().get_group(dp_ranks) + + self._tp_ranks = tuple(tp_ranks) + self._tp_group = DeviceGroup().get_group(tp_ranks) + + self.in_size = None + self.out_size = None + + @property + def pp_ranks(self): + return self._pp_ranks + + @property + def pp_group(self): + return self._pp_group + + def use_pp(self): + return len(self._pp_ranks) > 1 + + @property + def dp_ranks(self): + return self._dp_ranks + + @property + def dp_group(self): + return self._dp_group + + def use_dp(self): + return len(self._dp_ranks) > 1 + + @property + def tp_ranks(self): + return self._tp_ranks + + @property + def tp_group(self): + return self._tp_group + + @property + def use_tp(self): + return len(self._tp_ranks) > 1 + + def set_in_size(self, size: List[int]): + self.in_size = size + + def set_out_size(self, size: List[int]): + self.out_size = size \ No newline at end of file diff --git a/examples/swin/swin_hybrid.py b/examples/swin/swin_hybrid.py new file mode 100644 index 00000000..8aea699e --- /dev/null +++ b/examples/swin/swin_hybrid.py @@ -0,0 +1,1022 @@ + +# -------------------------------------------------------- +# Modified from Swin-Transformer Repo +""" +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_hybrid.py \ + --layer0 4 2 1 \ + --layer1 4 2 1 \ + --layer2 4 1 2 \ + --layer3 4 1 2 \ + --gbs 32 --mbs 2 + +# V100-16GB: 8GPU: need checkpoint: 8 micro bs +""" +# -------------------------------------------------------- + +from typing import Optional, Dict, Tuple +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + + +import cube +from cube.profiler import CudaTimer +from cube.runtime.device import DeviceGroup +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary +from cube.runtime.reducer import Reducer + +import argparse + +from examples.swin.hybrid_schedule import scheduling_1f1b, is_last_stage +from examples.swin.layers import ColumnParallelLinear, RowParallelLinear, DPtoTP, TPtoDP + +from examples.swin.pmodule import ParallelModule + + +_dp_reducer: Dict[Tuple[int], Reducer] = dict() + + +def setup_device_group(pp: int, dp: int, tp: int, layer_id: int): + """ + Layer wise device group initialize + + Returns: + + """ + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + if not pp * tp * dp == ndevs: + raise RuntimeError("Expected same device number") + + assert tp == 1 or dp == 1, "Currently hybrid not supported" + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize tensor parallel groups + for i in range(ndevs // tp): + ranks = list(range(i * tp, (i + 1) * tp)) + if len(ranks) > 1: + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + + # initialize data parallel groups + for i in range(pp): + start_rank = i * ndevs // pp + end_rank = (i+1) * ndevs // pp + for j in range(tp): + ranks = list(range(start_rank + j, end_rank, tp)) + if len(ranks) > 1: + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + _dp_reducer[tuple(dp_ranks)] = Reducer(dp_ranks) + print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) + + # initialize pipeline parallel groups + for i in range(dp * tp): + ranks = list(range(i, ndevs, tp * dp)) + if len(ranks) > 1: + group = devs.get_group(ranks) + if myrank in ranks: + pp_ranks = ranks + print_each_rank(f'layer {layer_id}: initialized pipeline parallel group: {pp_ranks}') + + return pp_ranks, dp_ranks, tp_ranks + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class MegatronMlp(ParallelModule): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., + pp_ranks=-1, tp_ranks=-1, dp_ranks=-1): + super().__init__( + pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks + ) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=self.tp_group) + self.act = act_layer() + # self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=True, tp_group=self.tp_group) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +class MegatronWindowAttention(ParallelModule): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0, + pp_ranks=-1, tp_ranks=-1, dp_ranks=-1): + + super().__init__( + pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks + ) + self.dim = dim + self.window_size = window_size # Wh, Ww + self.global_num_heads = num_heads + + tp_world_size = torch.distributed.get_world_size(group=self.tp_group) + if num_heads % tp_world_size != 0: + print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') + self.num_heads = num_heads // tp_world_size + self.dim_heads = dim // self.global_num_heads + self.scale = qk_scale or self.dim_heads ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # relative position index + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + self.register_buffer('relative_position_index', relative_position_index) + + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # print(f'qkv embed dim: {dim}') + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=self.tp_group) + self.attn_drop = nn.Dropout(attn_drop) + # self.proj = nn.Linear(dim, dim) + self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=True, tp_group=self.tp_group) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(ParallelModule): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + pp_ranks=-1, tp_ranks=-1, dp_ranks=-1, fw_bs=-1): + super().__init__( + pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks + ) + + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = MegatronWindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MegatronMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + H, W = self.input_resolution + + self.set_in_size([fw_bs // len(dp_ranks), H * W, self.dim]) + self.set_out_size([fw_bs // len(dp_ranks), H * W, self.dim]) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + # [B, H, W, C] -> [B, H * W, C] + x = x.view(B, H * W, C) + # [B, H * W, C] -> [B, H * W, C] + x = shortcut + drop_path(x, self.drop_path_p) + # FFN + # [B, H * W, C] -> [B, H * W, C] + ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] + ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] + x = x + drop_path(ffn, self.drop_path_p) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(ParallelModule): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, + pp_ranks=-1, tp_ranks=-1, dp_ranks=-1, fw_bs=-1): + super().__init__( + pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks + ) + + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + H, W = self.input_resolution + + self.set_in_size([fw_bs // len(dp_ranks), H * W, self.dim]) + self.set_out_size([fw_bs // len(dp_ranks), H // 2 * W // 2, self.dim * 2]) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(ParallelModule): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, + pp_ranks=-1, tp_ranks=-1, dp_ranks=-1, layer_id=-1, fw_bs=-1): + + super().__init__( + pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks + ) + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList([]) + for i in range(depth): + block = SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks, fw_bs=fw_bs + ) + self.blocks.append(block) + + def forward(self, x): + raise RuntimeError("Error call here") + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + pconfigs=None, fw_bs=-1, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + + tp_ranks, dp_ranks, pp_ranks = list(), list(), list() + for i in range(4): + pconfig = pconfigs[i] + layer_pp_ranks, layer_dp_ranks, layer_tp_ranks = setup_device_group(**pconfig) + tp_ranks.append(layer_tp_ranks) + dp_ranks.append(layer_dp_ranks) + pp_ranks.append(layer_pp_ranks) + + # build network layers + layers = nn.ModuleList() + for i_layer in range(self.num_layers): + pconfig = pconfigs[i_layer] + layer_tp_ranks, layer_dp_ranks = tp_ranks[i_layer], dp_ranks[i_layer] + + if i_layer != self.num_layers - 1: + next_layer_tp_ranks = tp_ranks[i_layer + 1] + next_layer_dp_ranks = dp_ranks[i_layer + 1] + else: + next_layer_dp_ranks = list() + next_layer_tp_ranks = list() + + input_resolution = ( + patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer) + ) + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=input_resolution, + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + pp_ranks=pp_ranks[i_layer], dp_ranks=dp_ranks[i_layer], tp_ranks=tp_ranks[i_layer], + fw_bs=fw_bs + ) + + for block in layer.blocks: + layers.append(block) + + if i_layer < self.num_layers - 1: + merging = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** i_layer), + norm_layer=norm_layer, + pp_ranks=pp_ranks[i_layer], dp_ranks=dp_ranks[i_layer], tp_ranks=tp_ranks[i_layer], + fw_bs = fw_bs, + ) + layers.append(merging) + else: + merging = None + + # adapter + if len(layer_tp_ranks) > 1 and len(next_layer_dp_ranks) > 1: + print_each_rank('add tp to dp adapters') + adapter = TPtoDP(DeviceGroup().get_group(next_layer_dp_ranks)) + adapter.in_size = layers[-1].out_size + out_size = [size for size in layers[-1].out_size] + out_size[0] = out_size[0] // len(next_layer_dp_ranks) + adapter.out_size = out_size + elif len(layer_dp_ranks) > 1 and len(next_layer_tp_ranks) > 1: + print_each_rank('add dp to tp adapters') + adapter = DPtoTP(DeviceGroup().get_group(next_layer_tp_ranks)) + adapter.in_size = layers[-1].out_size + out_size = [size for size in layers[-1].out_size] + out_size[0] = out_size[0] * len(layer_dp_ranks) + adapter.out_size = out_size + else: + adapter = torch.nn.Identity() + adapter.in_size = layers[-1].out_size + adapter.out_size = layers[-1].out_size + layers.append(adapter) + + + # ================ Pipeline Parallel Region ====================== + self.pp_group = DeviceGroup().get_group(pp_ranks[0]) + pp_rank = torch.distributed.get_rank(self.pp_group) + pp_size = torch.distributed.get_world_size(self.pp_group) + + assert len(layers) == 31 + + for block in layers: + print_each_rank(f'> block: {type(block).__name__}: in {block.in_size}, out: {block.out_size}', rank_only=0) + + chunk = len(layers) // pp_size + if len(layers) % pp_size != 0: + remain = len(layers) % pp_size + if pp_rank < remain: + start = pp_rank * (chunk+1) + chunk = chunk + 1 + else: + start = remain * (chunk + 1) + (pp_rank - remain) * chunk + else: + start = pp_rank * chunk + stop = start + chunk + + # self.use_checkpoint = [True] * (stop - start) + self.use_checkpoint = [False] * (stop - start) + + # 8gpu layer assign + # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] + # assert sum(layer_split) == 27 + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) + # self.use_checkpoint = [False] * (stop - start) + # for idx in range(stop - start): + # if pp_rank == 0: + # if idx < 1: + # self.use_checkpoint[idx] = True + + # 4Ggpu layer assign + # layer_split = [6, 7, 7, 7] + # assert sum(layer_split) == 27 + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) + + print_each_rank(f'layer start -> end: {start} -> {stop}') + print_each_rank(self.use_checkpoint) + self.layers = layers[start:stop] + + local_chunk = list() + for block in self.layers: + local_chunk.append(f'{type(block).__name__}: in: {block.in_size}; out: {block.out_size}') + local_chunk = '\n'.join(local_chunk) + print_each_rank('local chunk:\n' + local_chunk) + + self.in_size = self.layers[0].in_size + assert isinstance(self.in_size, list) + self.out_size = self.layers[-1].out_size + assert isinstance(self.out_size, list) + + self.preprocess = False + if pp_rank == 0: + self.preprocess = True + self.in_size = [in_chans, img_size, img_size] + self.postprocess = False + if is_last_stage(self.pp_group): + self.postprocess = True + self.out_size = [1,] + + if self.postprocess: + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + + # =================== Data Parallel ======================== + + self.split_data = len(dp_ranks[0]) + + # preprocess data parallel region + if self.preprocess and len(dp_ranks[0]) > 1: + for param in self.patch_embed.parameters(): + _dp_reducer[tuple(dp_ranks[0])].add_param(param) + + # block data parallel region + for block in self.layers: + if isinstance(block, ParallelModule): + if block.use_dp: + for param in block.parameters(): + _dp_reducer[block.dp_ranks].add_param(param) + + # postprocess data parallel region + if self.postprocess and len(dp_ranks[-1]) > 1: + for param in self.norm.parameters(): + _dp_reducer[tuple(dp_ranks[-1])].add_param(param) + for param in self.head.parameters(): + _dp_reducer[tuple(dp_ranks[-1])].add_param(param) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, image: torch.Tensor, feature_map=None): + + if self.preprocess: + with torch.no_grad(): + # FIXME: should select corresponding chunk + image = image.chunk(self.split_data, 0)[0] + x = self.patch_embed(image) + x = self.pos_drop(x) + feature_map = x + + for layer, use_checkpoint in zip(self.layers, self.use_checkpoint): + if use_checkpoint: + feature_map = checkpoint.checkpoint(layer, feature_map) + else: + feature_map = layer(feature_map) + x = feature_map + + if self.postprocess: + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C L + x = torch.flatten(x, 1) + x = self.head(x) + # simulate for simplicity + x = torch.sum(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def train(args, pconfigs): + + # dim_head is always 32 + + # img resolution, windows size: 224, 384, 518, 640 + C, H, W, window_size = [3, 224, 224, 7] + # C, H, W, window_size = [3, 384, 384, 12] + # C, H, W, window_size = [3, 518, 518, ?] + # C, H, W, window_size = [3, 640, 640, 20] + # C, H, W, window_size = [3, 1536, 1536, 48] + + # image batch size + N = args.gbs + + # Swin-Tiny + # embed_dim, depths, num_heads = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24] + # ] + + # SwinV2-B: 87 M + # embed_dim, depths, num_heads = [ + # 128, [2, 2, 18, 2], [4, 8, 16, 32] + # ] + + # SwinV2-L: 196 M + # embed_dim, depths, num_heads = [ + # 192, [2, 2, 18, 2], [6, 12, 24, 48] + # ] + + # SwinV2-H: 657 M + # embed_dim, depths, num_heads = [ + # 352, [2, 2, 18, 2], [11, 22, 44, 88] + # ] + + # SwinV2-H modified: 782 M + embed_dim, depths, num_heads = [ + 384, [2, 2, 18, 2], [12, 24, 48, 96] + ] + + # SwinV2-G: 2.5B Model + # embed_dim, depths, num_heads = [ + # 512, [2, 2, 42, 2], [16, 32, 64, 128] + # ] + + # 895.7 M Model + # embed_dim, depths, num_heads = [ + # 384, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + + # 2.01B model + # embed_dim, depths, num_heads = [ + # 576, [2, 2, 22, 2], [12, 24, 48, 96] + # ] + + + model = SwinTransformer(img_size = H, + embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size, + pconfigs = pconfigs, + fw_bs = args.mbs) + model = model.cuda() + memory_summary() + + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [args.gbs, C, H, W]) + dataloader.set_data_buffer(buffer_num=2) + + def train_iter(model, dataloader): + img = next(dataloader) + scheduling_1f1b(model, [img], args.gbs, args.mbs, torch.float, model.pp_group) + CudaTimer().start('dp_allreduce') + for ranks in _dp_reducer: + reducer = _dp_reducer[ranks] + reducer.allreduce() + CudaTimer().stop('dp_allreduce') + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + # start training + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + CudaTimer().warmup() + torch.distributed.barrier() + iter_num = 20 + for step in range(iter_num): + if step >= 10: + CudaTimer().start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on 1st iteration') + memory_summary() + if step >= 10: + CudaTimer().stop('e2e') + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = CudaTimer().duration(iter_num-10, field_name='e2e') + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + + CudaTimer().print_all(times=iter_num-10) + memory_summary() + + +if __name__ == '__main__': + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--layer0', type=int, nargs='+', + help='pipeline, data, tensor parallel config') + parser.add_argument('--layer1', type=int, nargs='+', + help='pipeline, data, tensor parallel config') + parser.add_argument('--layer2', type=int, nargs='+', + help='pipeline, data, tensor parallel config') + parser.add_argument('--layer3', type=int, nargs='+', + help='pipeline, data, tensor parallel config') + parser.add_argument('--gbs', type=int, default=-1) + parser.add_argument('--mbs', type=int, default=-1) + args = parser.parse_args() + + cube.init() + + # allocate resource + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + args.pp = args.layer0[0] + + pconfigs = [ + dict(layer_id=0, pp=args.layer0[0], dp=args.layer0[1], tp=args.layer0[2]), # basic layer 0 + dict(layer_id=1, pp=args.layer0[0], dp=args.layer1[1], tp=args.layer1[2]), # basic layer 1 + dict(layer_id=2, pp=args.layer0[0], dp=args.layer2[1], tp=args.layer2[2]), # basic layer 2 + dict(layer_id=3, pp=args.layer0[0], dp=args.layer3[1], tp=args.layer3[2]), # basic layer 3 + ] + + train(args, pconfigs) From 78f6ecfd8d665d05f117377afce4740dd887d23a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 3 Dec 2021 12:32:42 +0000 Subject: [PATCH 0469/1892] fix swin hybrid allreduce bug --- examples/swin/swin_hybrid.py | 11 ++++++----- examples/swin/swin_pipe.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/swin/swin_hybrid.py b/examples/swin/swin_hybrid.py index 8aea699e..ac354469 100644 --- a/examples/swin/swin_hybrid.py +++ b/examples/swin/swin_hybrid.py @@ -760,11 +760,12 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, self.use_checkpoint = [False] * (stop - start) # 8gpu layer assign + layer_split = [5, 5, 4, 3, 3, 3, 3, 5] # original # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] - # assert sum(layer_split) == 27 - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - # self.use_checkpoint = [False] * (stop - start) + assert sum(layer_split) == 31 + start = sum(layer_split[0:pp_rank]) + stop = sum(layer_split[0:pp_rank+1]) + self.use_checkpoint = [False] * (stop - start) # for idx in range(stop - start): # if pp_rank == 0: # if idx < 1: @@ -818,7 +819,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # block data parallel region for block in self.layers: if isinstance(block, ParallelModule): - if block.use_dp: + if block.use_dp(): for param in block.parameters(): _dp_reducer[block.dp_ranks].add_param(param) diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index 59a103a0..f74ea8e6 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -603,8 +603,8 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, start = pp_rank * chunk stop = start + chunk - self.use_checkpoint = [True] * (stop - start) - # self.use_checkpoint = [False] * (stop - start) + # self.use_checkpoint = [True] * (stop - start) + self.use_checkpoint = [False] * (stop - start) # 8gpu layer assign # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] From 762a44785b3ca94c6f77ad3c0b715d0e7a9025e7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 4 Dec 2021 06:28:49 +0000 Subject: [PATCH 0470/1892] swin for hybrid --- examples/mlp/linears.py | 31 ++++++++++---------- examples/mlp/policy/col_parallel.py | 3 ++ examples/swin/layers.py | 2 ++ examples/swin/swin_hybrid.py | 45 +++++++++++++++++++---------- 4 files changed, 50 insertions(+), 31 deletions(-) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 7c187e91..6a1f6290 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,39 +17,39 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.pipe1f1b_parallel import transform_policy -from examples.mlp.policy.pipe1f1b_parallel import schedule_policy +from examples.mlp.policy.col_parallel import transform_policy +from examples.mlp.policy.col_parallel import schedule_policy # =================== Semantic Model Description ==================== class MLP(nn.Module): - def __init__(self, dim, mult=4): + def __init__(self, dim, mult=1): super().__init__() self.linear1 = nn.Linear(dim, dim * mult) self.linear2 = nn.Linear(dim * mult, dim) self.linear3 = nn.Linear(dim, dim * mult) self.linear4 = nn.Linear(dim * mult, dim) - self.linear5 = nn.Linear(dim, dim * mult) - self.linear6 = nn.Linear(dim * mult, dim) - self.linear7 = nn.Linear(dim, dim * mult) - self.linear8 = nn.Linear(dim * mult, dim) + # self.linear5 = nn.Linear(dim, dim * mult) + # self.linear6 = nn.Linear(dim * mult, dim) + # self.linear7 = nn.Linear(dim, dim * mult) + # self.linear8 = nn.Linear(dim * mult, dim) def forward(self, data): output = self.linear1(data) output = self.linear2(output) output = self.linear3(output) output = self.linear4(output) - output = self.linear5(output) - output = self.linear6(output) - output = self.linear7(output) - output = self.linear8(output) + # output = self.linear5(output) + # output = self.linear6(output) + # output = self.linear7(output) + # output = self.linear8(output) loss = torch.sum(output) return loss def train(): - batch_size = 128 - dim = 1024 + batch_size = 8192 + dim = 8192 model = MLP(dim=dim) model = cube.SemanticModel( @@ -67,12 +67,12 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - CudaTimer().warmup() + CudaTimer(enable=False).warmup() torch.distributed.barrier() iter_num = 128 for step in range(iter_num): if step >= 40: - CudaTimer().start('e2e') + CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() @@ -83,6 +83,7 @@ def train_iter(model, dataloader): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-40, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-40) if __name__ == '__main__': diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index ca118e4b..96e4225b 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -38,4 +38,7 @@ def schedule_policy(sugraph: SUGraph, resource): print(f'error su: {su}') assert False sugraph.assign(su.mirror, devid) + fsus = sugraph.fsus() + print('> [scheduling] setting schedule order...') + sugraph.partial_set_order(fsus, lazy=False) return sugraph diff --git a/examples/swin/layers.py b/examples/swin/layers.py index 61b70932..27929d06 100644 --- a/examples/swin/layers.py +++ b/examples/swin/layers.py @@ -15,6 +15,7 @@ def _reduce(input_, group): CudaTimer().stop(field_name='tp_allreduce') return input_ torch.distributed.all_reduce(input_, group=group) + torch.cuda.synchronize() CudaTimer().stop(field_name='tp_allreduce') return input_ @@ -48,6 +49,7 @@ def _gather(input_, group, dim=-1): tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ torch.distributed.all_gather(tensor_list, input_, group=group) + torch.cuda.synchronize() # Note: torch.cat already creates a contiguous tensor. output = torch.cat(tensor_list, dim=dim).contiguous() diff --git a/examples/swin/swin_hybrid.py b/examples/swin/swin_hybrid.py index ac354469..01f6b4d8 100644 --- a/examples/swin/swin_hybrid.py +++ b/examples/swin/swin_hybrid.py @@ -10,11 +10,11 @@ --master_port=8004 \ --use_env \ examples/swin/swin_hybrid.py \ - --layer0 4 2 1 \ - --layer1 4 2 1 \ - --layer2 4 1 2 \ - --layer3 4 1 2 \ - --gbs 32 --mbs 2 + --layer0 8 1 1 \ + --layer1 8 1 1 \ + --layer2 8 1 1 \ + --layer3 8 1 1 \ + --gbs 1 --mbs 1 # V100-16GB: 8GPU: need checkpoint: 8 micro bs """ @@ -57,7 +57,7 @@ def setup_device_group(pp: int, dp: int, tp: int, layer_id: int): if not pp * tp * dp == ndevs: raise RuntimeError("Expected same device number") - assert tp == 1 or dp == 1, "Currently hybrid not supported" + # assert tp == 1 or dp == 1, "Currently hybrid not supported" devs = cube.runtime.device.DeviceGroup() @@ -362,6 +362,7 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 H, W = self.input_resolution + assert fw_bs // len(dp_ranks) != 0 self.set_in_size([fw_bs // len(dp_ranks), H * W, self.dim]) self.set_out_size([fw_bs // len(dp_ranks), H * W, self.dim]) @@ -455,6 +456,7 @@ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, H, W = self.input_resolution + assert fw_bs // len(dp_ranks) != 0 self.set_in_size([fw_bs // len(dp_ranks), H * W, self.dim]) self.set_out_size([fw_bs // len(dp_ranks), H // 2 * W // 2, self.dim * 2]) @@ -713,21 +715,24 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, merging = None # adapter - if len(layer_tp_ranks) > 1 and len(next_layer_dp_ranks) > 1: + if len(layer_dp_ranks) == 1 and len(layer_tp_ranks) > 1 \ + and len(next_layer_dp_ranks) > 1 and len(next_layer_tp_ranks) == 1: print_each_rank('add tp to dp adapters') adapter = TPtoDP(DeviceGroup().get_group(next_layer_dp_ranks)) adapter.in_size = layers[-1].out_size out_size = [size for size in layers[-1].out_size] out_size[0] = out_size[0] // len(next_layer_dp_ranks) adapter.out_size = out_size - elif len(layer_dp_ranks) > 1 and len(next_layer_tp_ranks) > 1: + elif len(layer_tp_ranks) == 1 and len(layer_dp_ranks) > 1 \ + and len(next_layer_tp_ranks) > 1 and len(next_layer_dp_ranks) == 1: print_each_rank('add dp to tp adapters') adapter = DPtoTP(DeviceGroup().get_group(next_layer_tp_ranks)) adapter.in_size = layers[-1].out_size out_size = [size for size in layers[-1].out_size] out_size[0] = out_size[0] * len(layer_dp_ranks) adapter.out_size = out_size - else: + elif len(layer_tp_ranks) == len(next_layer_tp_ranks) and \ + len(layer_dp_ranks) == len(next_layer_dp_ranks): adapter = torch.nn.Identity() adapter.in_size = layers[-1].out_size adapter.out_size = layers[-1].out_size @@ -756,16 +761,16 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, start = pp_rank * chunk stop = start + chunk - # self.use_checkpoint = [True] * (stop - start) - self.use_checkpoint = [False] * (stop - start) + # self.use_checkpoint = [False] * (stop - start) + self.use_checkpoint = [True] * (stop - start) # 8gpu layer assign - layer_split = [5, 5, 4, 3, 3, 3, 3, 5] # original + # layer_split = [5, 5, 4, 3, 3, 3, 3, 5] # original # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] - assert sum(layer_split) == 31 - start = sum(layer_split[0:pp_rank]) - stop = sum(layer_split[0:pp_rank+1]) - self.use_checkpoint = [False] * (stop - start) + # assert sum(layer_split) == 31 + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) + # self.use_checkpoint = [False] * (stop - start) # for idx in range(stop - start): # if pp_rank == 0: # if idx < 1: @@ -915,6 +920,14 @@ def train(args, pconfigs): embed_dim, depths, num_heads = [ 384, [2, 2, 18, 2], [12, 24, 48, 96] ] + # head dim 32 -> 48 + embed_dim, depths, num_heads = [ + 576, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # head dim 32 -> 64 + # embed_dim, depths, num_heads = [ + # 768, [2, 2, 18, 2], [12, 24, 48, 96] + # ] # SwinV2-G: 2.5B Model # embed_dim, depths, num_heads = [ From 57b3c4b3f1b11f26e1182dc8e11b585fb47528ec Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 4 Dec 2021 07:53:39 +0000 Subject: [PATCH 0471/1892] swin hybrid --- examples/swin/swin_dt.py | 4 ++++ examples/swin/swin_hybrid.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/swin/swin_dt.py b/examples/swin/swin_dt.py index 7da2d339..3f10db43 100644 --- a/examples/swin/swin_dt.py +++ b/examples/swin/swin_dt.py @@ -823,6 +823,10 @@ def train(args, pconfigs): embed_dim, depths, num_heads = [ 384, [2, 2, 18, 2], [12, 24, 48, 96] ] + # head dim 32 -> 48 + embed_dim, depths, num_heads = [ + 576, [2, 2, 18, 2], [12, 24, 48, 96] + ] # SwinV2-G: 2.5B Model # embed_dim, depths, num_heads = [ diff --git a/examples/swin/swin_hybrid.py b/examples/swin/swin_hybrid.py index 01f6b4d8..ccb3b00c 100644 --- a/examples/swin/swin_hybrid.py +++ b/examples/swin/swin_hybrid.py @@ -761,8 +761,8 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, start = pp_rank * chunk stop = start + chunk - # self.use_checkpoint = [False] * (stop - start) - self.use_checkpoint = [True] * (stop - start) + self.use_checkpoint = [False] * (stop - start) + # self.use_checkpoint = [True] * (stop - start) # 8gpu layer assign # layer_split = [5, 5, 4, 3, 3, 3, 3, 5] # original From afc09b28fcc8e133610c7e5405153594383124e0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 4 Dec 2021 09:48:06 +0000 Subject: [PATCH 0472/1892] infer --- eval/swin_infer.sh | 106 -------------- eval/swin_infer_bs2_224_782Mfp32.sh | 2 +- eval/swin_train_fp32.sh | 218 ++++++++++++++-------------- 3 files changed, 110 insertions(+), 216 deletions(-) delete mode 100755 eval/swin_infer.sh diff --git a/eval/swin_infer.sh b/eval/swin_infer.sh deleted file mode 100755 index 13f468df..00000000 --- a/eval/swin_infer.sh +++ /dev/null @@ -1,106 +0,0 @@ -mkdir -p expinfer32_2.6B_fp16_bs4 - -# ================== Maximal Tensor Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 1 \ - --layer1 1 1 1 \ - --layer2 1 1 1 \ - --layer3 1 1 1 \ - > expinfer32_2.6B_fp16_bs4/1gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 2 \ - --layer1 1 1 2 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - > expinfer32_2.6B_fp16_bs4/2gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 4 \ - --layer1 1 1 4 \ - --layer2 1 1 4 \ - --layer3 1 1 4 \ - > expinfer32_2.6B_fp16_bs4/4gpu_tp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 4 1 2 \ - --layer1 4 1 2 \ - --layer2 4 1 2 \ - --layer3 4 1 2 \ - > expinfer32_2.6B_fp16_bs4/8gpu_tp.txt - - -# ================== Window + Tensor Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 2 1 1 \ - --layer1 2 1 1 \ - --layer2 2 1 1 \ - --layer3 2 1 1 \ - > expinfer32_2.6B_fp16_bs4/2gpu_2wp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 4 1 1 \ - --layer1 4 1 1 \ - --layer2 4 1 1 \ - --layer3 4 1 1 \ - > expinfer32_2.6B_fp16_bs4/4gpu_4wp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 2 4 1 \ - --layer1 2 4 1 \ - --layer2 2 4 1 \ - --layer3 2 1 4 \ - > expinfer32_2.6B_fp16_bs4/8gpu_8wp8tp.txt - diff --git a/eval/swin_infer_bs2_224_782Mfp32.sh b/eval/swin_infer_bs2_224_782Mfp32.sh index 46b32221..628947b8 100755 --- a/eval/swin_infer_bs2_224_782Mfp32.sh +++ b/eval/swin_infer_bs2_224_782Mfp32.sh @@ -47,7 +47,7 @@ python -m torch.distributed.launch \ > ${logfile}/4gpu_tp.txt python -m torch.distributed.launch \ - --nproc_per_node=4 \ + --nproc_per_node=8 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ diff --git a/eval/swin_train_fp32.sh b/eval/swin_train_fp32.sh index ad441b70..97bc677e 100755 --- a/eval/swin_train_fp32.sh +++ b/eval/swin_train_fp32.sh @@ -4,19 +4,19 @@ logfile=exptrain_782M_bs${bs}_fp32_384 mkdir -p ${logfile} -# python -m torch.distributed.launch \ -# --nproc_per_node=1 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt.py --bs ${bs} \ -# --layer0 1 1 1 \ -# --layer1 1 1 1 \ -# --layer2 1 1 1 \ -# --layer3 1 1 1 \ -# > ${logfile}/single.txt +python -m torch.distributed.launch \ + --nproc_per_node=1 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs ${bs} \ + --layer0 1 1 1 \ + --layer1 1 1 1 \ + --layer2 1 1 1 \ + --layer3 1 1 1 \ + > ${logfile}/single.txt # ================== Megatron Policy Parallel =============== @@ -34,19 +34,19 @@ python -m torch.distributed.launch \ --layer3 1 1 2 \ > ${logfile}/2gpu_maxdp2tp.txt -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 2 1 2 \ - --layer1 2 1 2 \ - --layer2 2 1 2 \ - --layer3 2 1 2 \ - > ${logfile}/4gpu_maxdp2tp.txt +# python -m torch.distributed.launch \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt.py --bs ${bs} \ +# --layer0 2 1 2 \ +# --layer1 2 1 2 \ +# --layer2 2 1 2 \ +# --layer3 2 1 2 \ +# > ${logfile}/4gpu_maxdp2tp.txt python -m torch.distributed.launch \ @@ -110,47 +110,47 @@ python -m torch.distributed.launch \ # ================== Window + Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 2 1 \ - --layer1 1 2 1 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - > ${logfile}/2gpu_2wp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 4 1 \ - --layer1 1 4 1 \ - --layer2 1 1 4 \ - --layer3 1 1 4 \ - > exptrain_782M_bs8_fp32/4gpu_4wp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 8 1 \ - --layer1 1 1 8 \ - --layer2 1 1 8 \ - --layer3 1 1 8 \ - > ${logfile}/8gpu_8wp8tp.txt +# python -m torch.distributed.launch \ +# --nproc_per_node=2 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt.py --bs ${bs} \ +# --layer0 1 2 1 \ +# --layer1 1 2 1 \ +# --layer2 1 1 2 \ +# --layer3 1 1 2 \ +# > ${logfile}/2gpu_2wp2tp.txt +# +# python -m torch.distributed.launch \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt.py --bs ${bs} \ +# --layer0 1 4 1 \ +# --layer1 1 4 1 \ +# --layer2 1 1 4 \ +# --layer3 1 1 4 \ +# > exptrain_782M_bs8_fp32/4gpu_4wp4tp.txt +# +# python -m torch.distributed.launch \ +# --nproc_per_node=8 \ +# --nnodes=1 \ +# --node_rank=0 \ +# --master_addr=127.0.0.1 \ +# --master_port=8004 \ +# --use_env \ +# examples/swin/swin_dwt.py --bs ${bs} \ +# --layer0 1 8 1 \ +# --layer1 1 1 8 \ +# --layer2 1 1 8 \ +# --layer3 1 1 8 \ +# > ${logfile}/8gpu_8wp8tp.txt # ================== Data + Tensor Parallel =============== @@ -199,45 +199,45 @@ python -m torch.distributed.launch \ # ================== Pure Data Parallel ============= -# python -m torch.distributed.launch \ -# --nproc_per_node=2 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt.py --bs ${bs} \ -# --layer0 2 1 1 \ -# --layer1 2 1 1 \ -# --layer2 2 1 1 \ -# --layer3 2 1 1 \ -# > ${logfile}/2gpu_maxdp.txt -# -# -# python -m torch.distributed.launch \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt.py --bs ${bs} \ -# --layer0 4 1 1 \ -# --layer1 4 1 1 \ -# --layer2 4 1 1 \ -# --layer3 4 1 1 \ -# > ${logfile}/4gpu_maxdp.txt -# -# python -m torch.distributed.launch \ -# --nproc_per_node=8 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt.py --bs ${bs} \ -# --layer0 8 1 1 \ -# --layer1 8 1 1 \ -# --layer2 8 1 1 \ -# --layer3 8 1 1 \ -# > ${logfile}/8gpu_maxdp.txt +python -m torch.distributed.launch \ + --nproc_per_node=2 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs ${bs} \ + --layer0 2 1 1 \ + --layer1 2 1 1 \ + --layer2 2 1 1 \ + --layer3 2 1 1 \ + > ${logfile}/2gpu_maxdp.txt + + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs ${bs} \ + --layer0 4 1 1 \ + --layer1 4 1 1 \ + --layer2 4 1 1 \ + --layer3 4 1 1 \ + > ${logfile}/4gpu_maxdp.txt + +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_dwt.py --bs ${bs} \ + --layer0 8 1 1 \ + --layer1 8 1 1 \ + --layer2 8 1 1 \ + --layer3 8 1 1 \ + > ${logfile}/8gpu_maxdp.txt From cbd1976a3cb715f7e534ea9a55b31f03857ccf04 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 5 Dec 2021 10:05:57 +0000 Subject: [PATCH 0473/1892] add model configs --- examples/swin/swin_hybrid.py | 72 ++++++++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 11 deletions(-) diff --git a/examples/swin/swin_hybrid.py b/examples/swin/swin_hybrid.py index ccb3b00c..7c02af94 100644 --- a/examples/swin/swin_hybrid.py +++ b/examples/swin/swin_hybrid.py @@ -16,6 +16,20 @@ --layer3 8 1 1 \ --gbs 1 --mbs 1 +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=$NID \ + --master_addr=worker-0 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_hybrid.py \ + --layer0 2 8 1 \ + --layer1 2 8 1 \ + --layer2 2 8 1 \ + --layer3 2 8 1 \ + --gbs 8 --mbs 8 + # V100-16GB: 8GPU: need checkpoint: 8 micro bs """ # -------------------------------------------------------- @@ -464,6 +478,8 @@ def forward(self, x): """ x: B, H*W, C """ + assert list(x.shape) == self.in_size + H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" @@ -481,6 +497,7 @@ def forward(self, x): x = self.norm(x) x = self.reduction(x) + assert list(x.shape) == self.out_size return x def extra_repr(self) -> str: @@ -761,8 +778,8 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, start = pp_rank * chunk stop = start + chunk - self.use_checkpoint = [False] * (stop - start) - # self.use_checkpoint = [True] * (stop - start) + # self.use_checkpoint = [False] * (stop - start) + self.use_checkpoint = [True] * (stop - start) # 8gpu layer assign # layer_split = [5, 5, 4, 3, 3, 3, 3, 5] # original @@ -776,9 +793,11 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # if idx < 1: # self.use_checkpoint[idx] = True - # 4Ggpu layer assign + # 4 stage layer assign + # layer_split = [8, 8, 7, 8] # original # layer_split = [6, 7, 7, 7] - # assert sum(layer_split) == 27 + + # assert sum(layer_split) == 31 # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) @@ -920,14 +939,38 @@ def train(args, pconfigs): embed_dim, depths, num_heads = [ 384, [2, 2, 18, 2], [12, 24, 48, 96] ] - # head dim 32 -> 48 + # # head dim 32 -> 48 embed_dim, depths, num_heads = [ 576, [2, 2, 18, 2], [12, 24, 48, 96] ] - # head dim 32 -> 64 - # embed_dim, depths, num_heads = [ - # 768, [2, 2, 18, 2], [12, 24, 48, 96] - # ] + # # head dim 32 -> 64 -- too much + embed_dim, depths, num_heads = [ + 768, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # # head dim 32 -> 80 + embed_dim, depths, num_heads = [ + 960, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # # head dim 32 -> 96 + embed_dim, depths, num_heads = [ + 1152, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # # head dim 32 -> 112 + embed_dim, depths, num_heads = [ + 1344, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # head dim 32 -> 128 + embed_dim, depths, num_heads = [ + 1536, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # head dim 32 -> 144 + embed_dim, depths, num_heads = [ + 1728, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # head dim 32 -> 160 + embed_dim, depths, num_heads = [ + 1920, [2, 2, 18, 2], [12, 24, 48, 96] + ] # SwinV2-G: 2.5B Model # embed_dim, depths, num_heads = [ @@ -945,6 +988,10 @@ def train(args, pconfigs): # 576, [2, 2, 22, 2], [12, 24, 48, 96] # ] + print_each_rank( + f'config: embed_dim: {embed_dim}, depths: {depths}, num_heads: {num_heads}' + ) + model = SwinTransformer(img_size = H, embed_dim = embed_dim, @@ -963,6 +1010,7 @@ def train(args, pconfigs): def train_iter(model, dataloader): img = next(dataloader) scheduling_1f1b(model, [img], args.gbs, args.mbs, torch.float, model.pp_group) + torch.distributed.barrier() CudaTimer().start('dp_allreduce') for ranks in _dp_reducer: reducer = _dp_reducer[ranks] @@ -975,15 +1023,17 @@ def train_iter(model, dataloader): nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - CudaTimer().warmup() + CudaTimer(enable=False).warmup() torch.distributed.barrier() iter_num = 20 for step in range(iter_num): if step >= 10: - CudaTimer().start('e2e') + CudaTimer(enable=True).start('e2e') + torch.distributed.barrier() train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() + torch.distributed.barrier() if step == 1: print('> passed on 1st iteration') memory_summary() From e33092ec629b1859ad52e799362e84aeee3d2d60 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 5 Dec 2021 10:34:02 +0000 Subject: [PATCH 0474/1892] scaleup test --- eval/swin_scaleup.sh | 30 ++++++++++++++++ examples/swin/swin_dt.py | 72 +++++++++++++++++++++++---------------- examples/swin/swin_dwt.py | 10 +++--- 3 files changed, 77 insertions(+), 35 deletions(-) create mode 100644 eval/swin_scaleup.sh diff --git a/eval/swin_scaleup.sh b/eval/swin_scaleup.sh new file mode 100644 index 00000000..6c4de61d --- /dev/null +++ b/eval/swin_scaleup.sh @@ -0,0 +1,30 @@ + +# Swin cube maximal scaling +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=$NID \ + --master_addr=worker-0 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_hybrid.py \ + --layer0 2 8 1 \ + --layer1 2 1 8 \ + --layer2 2 1 8 \ + --layer3 2 1 8 \ + --gbs 8 --mbs 8 + +# Swin Megatron maximal scaling +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=$NID \ + --master_addr=worker-0 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_hybrid.py \ + --layer0 2 8 1 \ + --layer1 2 1 8 \ + --layer2 2 1 8 \ + --layer3 2 1 8 \ + --gbs 8 --mbs 8 diff --git a/examples/swin/swin_dt.py b/examples/swin/swin_dt.py index 3f10db43..8a0c3311 100644 --- a/examples/swin/swin_dt.py +++ b/examples/swin/swin_dt.py @@ -3,17 +3,17 @@ # Modified from Swin-Transformer Repo """ python -m torch.distributed.launch \ - --nproc_per_node=2 \ + --nproc_per_node=1 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dt.py --bs 2 \ - --layer0 2 1 \ - --layer1 2 1 \ - --layer2 2 1 \ - --layer3 2 1 + examples/swin/swin_dt.py --bs 16 \ + --layer0 1 1 \ + --layer1 1 1 \ + --layer2 1 1 \ + --layer3 1 1 """ # -------------------------------------------------------- @@ -815,33 +815,45 @@ def train(args, pconfigs): # ] # SwinV2-H: 657 M - # embed_dim, depths, num_heads = [ - # 352, [2, 2, 18, 2], [11, 22, 44, 88] - # ] - - # SwinV2-H modified: 782 M - embed_dim, depths, num_heads = [ - 384, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # head dim 32 -> 48 embed_dim, depths, num_heads = [ - 576, [2, 2, 18, 2], [12, 24, 48, 96] + 352, [2, 2, 18, 2], [11, 22, 44, 88] ] - # SwinV2-G: 2.5B Model + # # SwinV2-H modified: 782 M # embed_dim, depths, num_heads = [ - # 512, [2, 2, 42, 2], [16, 32, 64, 128] + # 384, [2, 2, 18, 2], [12, 24, 48, 96] # ] - - # 895.7 M Model + # # head dim 32 -> 48 # embed_dim, depths, num_heads = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96] + # 576, [2, 2, 18, 2], [12, 24, 48, 96] # ] - - - # 2.01B model + # # head dim 32 -> 64 -- too much + # embed_dim, depths, num_heads = [ + # 768, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 80 + # embed_dim, depths, num_heads = [ + # 960, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 96 + # embed_dim, depths, num_heads = [ + # 1152, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # # head dim 32 -> 112 + # embed_dim, depths, num_heads = [ + # 1344, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 128 + # embed_dim, depths, num_heads = [ + # 1536, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 144 + # embed_dim, depths, num_heads = [ + # 1728, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 160 # embed_dim, depths, num_heads = [ - # 576, [2, 2, 22, 2], [12, 24, 48, 96] + # 1920, [2, 2, 18, 2], [12, 24, 48, 96] # ] @@ -887,9 +899,9 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() torch.distributed.barrier() span = 0 - iter_num = 128 + iter_num = 60 for step in range(iter_num): - if step >= 40: + if step >= 20: torch.cuda.synchronize() start = time.time() CudaTimer(enable=True).start('e2e') @@ -899,7 +911,7 @@ def train_iter(model, dataloader): if step == 1: print('> passed on 1st iteration') memory_summary() - if step >= 40: + if step >= 20: torch.cuda.synchronize() stop = time.time() span += (stop - start) * 1000 @@ -907,13 +919,13 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = span / (iter_num-40) + iter_time = span / (iter_num-20) throughput = N / iter_time * 1000 print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) memory_summary() - CudaTimer().print_all(times=iter_num-40) + CudaTimer().print_all(times=iter_num-20) if __name__ == '__main__': diff --git a/examples/swin/swin_dwt.py b/examples/swin/swin_dwt.py index 3ae2e785..10f9c769 100644 --- a/examples/swin/swin_dwt.py +++ b/examples/swin/swin_dwt.py @@ -904,9 +904,9 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() torch.distributed.barrier() span = 0 - iter_num = 128 + iter_num = 60 for step in range(iter_num): - if step >= 40: + if step >= 20: torch.cuda.synchronize() start = time.time() CudaTimer(enable=True).start('e2e') @@ -916,7 +916,7 @@ def train_iter(model, dataloader): if step == 1: print('> passed on 1st iteration') memory_summary() - if step >= 40: + if step >= 20: torch.cuda.synchronize() stop = time.time() span += (stop - start) * 1000 @@ -924,13 +924,13 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = span / (iter_num-40) + iter_time = span / (iter_num-20) throughput = N / iter_time * 1000 print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) memory_summary() - CudaTimer().print_all(times=iter_num-40) + CudaTimer().print_all(times=iter_num-20) if __name__ == '__main__': From 748fe2794b8f1a691009570a745f112e34ec499d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 5 Dec 2021 12:00:03 +0000 Subject: [PATCH 0475/1892] dp + pp --- examples/efficientnet/schedule.py | 56 ++++++++++++++++++++----------- examples/efficientnet/train.py | 43 +++++++++++++----------- 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/examples/efficientnet/schedule.py b/examples/efficientnet/schedule.py index 5b271bbc..281d2814 100644 --- a/examples/efficientnet/schedule.py +++ b/examples/efficientnet/schedule.py @@ -1,10 +1,18 @@ import torch +from torch.distributed.distributed_c10d import _get_global_rank from cube.profiler.timer import CudaTimer -def is_last_stage(): - return torch.distributed.get_rank() == torch.distributed.get_world_size() - 1 +def get_global_rank(group, group_rank): + if group is None: + return group_rank + else: + return _get_global_rank(group, group_rank) + + +def is_last_stage(group): + return torch.distributed.get_rank(group=group) == torch.distributed.get_world_size(group=group) - 1 #================= WhatToDO functions ==================# @@ -34,12 +42,14 @@ def backward_step(feature_map, output_tensor, output_tensor_grad): #================= Between Stage functions ==================# -def send(tensors, to_rank): +def send(tensors, to_rank, group): """ send tensor to the target rank """ - if to_rank < 0 or to_rank >= torch.distributed.get_world_size(): + if to_rank < 0 or to_rank >= torch.distributed.get_world_size(group): return None + if group is not None: + to_rank = get_global_rank(group, to_rank) assert isinstance(tensors, list) or isinstance(tensors, tuple) CudaTimer().start("send") reqs = list() @@ -60,10 +70,13 @@ def send(tensors, to_rank): CudaTimer().stop("send") -def recv(shapes, from_rank, dtype=torch.float): - if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): +def recv(shapes, from_rank, dtype, group): + if from_rank < 0 or from_rank >= torch.distributed.get_world_size(group): return [None] * len(shapes) assert isinstance(shapes, list) or isinstance(shapes, tuple) + if group is not None: + from_rank = get_global_rank(group, from_rank) + # print(f'recv: {torch.distributed.get_rank()} <- {from_rank}: {shapes}') CudaTimer().start("recv") reqs = list() recved_tensors = list() @@ -88,9 +101,12 @@ def recv(shapes, from_rank, dtype=torch.float): return recved_tensors -def send_and_recv(send_tensors, recv_shapes, rank, dtype=torch.float): - if rank < 0 or rank >= torch.distributed.get_world_size(): +def send_and_recv(send_tensors, recv_shapes, rank, dtype, group): + if rank < 0 or rank >= torch.distributed.get_world_size(group): return [None] * len(recv_shapes) + if group is not None: + rank = get_global_rank(group, rank) + # print(f'exchange: {torch.distributed.get_rank()} <-> {rank}: {recv_shapes}') assert isinstance(send_tensors, list) or isinstance(send_tensors, tuple) assert isinstance(recv_shapes, list) or isinstance(recv_shapes, tuple) CudaTimer().start("send_recv") @@ -145,13 +161,13 @@ def split_batch(inputs, num_microbatches): #================= Scheduling ==================# -def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): - myrank = torch.distributed.get_rank() +def scheduling_1f1b(model, inputs, bs, micro_bs, dtype, group): + myrank = torch.distributed.get_rank(group) num_microbatches = int(bs / micro_bs) num_warmup_microbatches = \ - (torch.distributed.get_world_size() - - torch.distributed.get_rank() - 1) + (torch.distributed.get_world_size(group) - + torch.distributed.get_rank(group) - 1) num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) num_warmup_remaining = num_microbatches - num_warmup_microbatches @@ -165,14 +181,14 @@ def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): # recv forward # print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) feature_map = recv( - (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype + (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype, group )[0] image = inputs[i][0] # forward output_tensor = forward_step(model, image, feature_map) # send forward # print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) - send((output_tensor,), myrank+1) + send((output_tensor,), myrank+1, group) input_tensors.append(feature_map) output_tensors.append(output_tensor) @@ -182,7 +198,7 @@ def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): # recv forward # print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) feature_map = recv( - (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype + (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype, group )[0] image = inputs[num_warmup_microbatches][0] @@ -195,7 +211,7 @@ def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): output_tensor_grad = send_and_recv( (output_tensor,), (torch.Size([micro_bs] + model.out_size),), - myrank+1, dtype + myrank+1, dtype, group )[0] input_tensors.append(feature_map) output_tensors.append(output_tensor) @@ -208,14 +224,14 @@ def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): feature_map = send_and_recv( (input_tensor_grad,), (torch.Size([micro_bs] + model.in_size),), - myrank-1, dtype + myrank-1, dtype, group )[0] image = inputs[num_warmup_microbatches+i+1][0] else: # last iteration - no more inputs feature_map = None # send backward grads # print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) - send((input_tensor_grad,), myrank-1) + send((input_tensor_grad,), myrank-1, group) # cooldown gradient trans back for i in range(num_warmup_microbatches): @@ -223,12 +239,12 @@ def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): output_tensor = output_tensors.pop(0) # recv backward gradients output_tensor_grad = recv( - (torch.Size([micro_bs] + model.out_size),), myrank+1, dtype + (torch.Size([micro_bs] + model.out_size),), myrank+1, dtype, group )[0] # backward input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) # send backward gradients # print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) - send((input_tensor_grad,), myrank-1) + send((input_tensor_grad,), myrank-1, group) #================= Scheduling ==================# \ No newline at end of file diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index 73a2f7af..f7894fb6 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -9,9 +9,7 @@ examples/efficientnet/train.py \ --pp 8 --gbs 8 --mbs 1 """ - import torch -from torch import nn from examples.efficientnet.efficientnet import EfficientNet import time import argparse @@ -20,15 +18,14 @@ from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup -from cube.runtime.reducer import Reducer from examples.efficientnet.schedule import is_last_stage, scheduling_1f1b def model_partition(model, in_size): + resource = cube.runtime.resource.EnvResource() # pipeline stage - pp_rank = torch.distributed.get_rank() - pp_size = torch.distributed.get_world_size() + pp_rank = torch.distributed.get_rank(resource.pp_group) + pp_size = torch.distributed.get_world_size(resource.pp_group) layers = model._blocks @@ -128,7 +125,7 @@ def model_partition(model, in_size): model.preprocess = False model.in_size = layers[0].in_size - if is_last_stage(): + if is_last_stage(resource.pp_group): model.postprocess = True model.out_size = [1,] else: @@ -139,8 +136,7 @@ def model_partition(model, in_size): def train(args): - - N = args.gbs + resource = cube.runtime.resource.EnvResource() # L2 config C, H, W = [3, 800, 800] @@ -159,10 +155,10 @@ def train(args): print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) memory_summary() - if N % args.gbs != 0: + if args.gbs % args.dp != 0: raise RuntimeError("global bs is not divisible by DP") dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [N // args.dp, C, H, W]) + 1280, [0], [args.gbs // args.dp, C, H, W]) if args.fp16: data_buff = [[e.half() for e in data] for data in dataloader.datas] @@ -170,16 +166,24 @@ def train(args): def train_iter(model, dataloader): img = next(dataloader) - scheduling_1f1b(model, [img], args.gbs, args.mbs, dtype=torch.float) + scheduling_1f1b(model, [img], args.gbs // args.dp, args.mbs, torch.float, resource.pp_group) + CudaTimer().start('dp_allreduce') + resource.reducer.allreduce() + CudaTimer().stop('dp_allreduce') optimizer = torch.optim.RMSprop(model.parameters()) + if args.dp > 1: + print_each_rank('adding param for allreduce sync') + for param in model.parameters(): + resource.reducer.add_param(param) + CudaTimer(enable=False).warmup() torch.distributed.barrier() span = 0 - iter_num = 40 + iter_num = 10 for step in range(iter_num): - if step >= 20: + if step >= 10: torch.cuda.synchronize() start = time.time() CudaTimer(enable=True).start('e2e') @@ -189,21 +193,21 @@ def train_iter(model, dataloader): if step == 1: print('> passed on 1st iteration') memory_summary() - if step >= 20: + if step >= 10: torch.cuda.synchronize() stop = time.time() span += (stop - start) * 1000 CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: + if (step + 1) % 10 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = CudaTimer().duration(iter_num-20, field_name='e2e') - throughput = N / iter_time * 1000 + iter_time = CudaTimer().duration(iter_num-10, field_name='e2e') + throughput = args.gbs / iter_time * 1000 print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) - CudaTimer().print_all(times=iter_num-20) + CudaTimer().print_all(times=iter_num-10) memory_summary() @@ -258,6 +262,7 @@ def train_iter(model, dataloader): print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) # initialize pipelne parallel groups + resource.pp_group = -1 for i in range(dp_size): ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks] From a1c9726ef081767128946e0a0440d68ada606fc7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 5 Dec 2021 12:21:29 +0000 Subject: [PATCH 0476/1892] fix train bug --- examples/efficientnet/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index f7894fb6..a46ae634 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -181,7 +181,7 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() torch.distributed.barrier() span = 0 - iter_num = 10 + iter_num = 20 for step in range(iter_num): if step >= 10: torch.cuda.synchronize() From c61a8ecab1144f18b031ba10805bffac73fe3bca Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 8 Dec 2021 12:20:59 +0000 Subject: [PATCH 0477/1892] sync workers --- scripts/env-setup.sh | 1 + scripts/sync.sh | 4 ++++ scripts/sync4.sh | 5 +++++ 3 files changed, 10 insertions(+) create mode 100755 scripts/sync.sh create mode 100755 scripts/sync4.sh diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index a4898aff..d8fd916f 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -13,6 +13,7 @@ sudo chmod -R a+w /opt/conda sudo apt-get install tmux -y sudo apt-get install psmisc -y sudo apt-get install lsof -y +sudo apt-get install infiniband-diags -y # install blob # sudo apt-get install lsb-release -y diff --git a/scripts/sync.sh b/scripts/sync.sh new file mode 100755 index 00000000..16737ab6 --- /dev/null +++ b/scripts/sync.sh @@ -0,0 +1,4 @@ + +# usually worker-1 +worker_name=$1 +scp -r /workspace/MagicCube/examples ${worker_name}:/workspace/MagicCube/ \ No newline at end of file diff --git a/scripts/sync4.sh b/scripts/sync4.sh new file mode 100755 index 00000000..1edf2a53 --- /dev/null +++ b/scripts/sync4.sh @@ -0,0 +1,5 @@ + +scp -r /workspace/MagicCube/examples worker-1:/workspace/MagicCube/ +scp -r /workspace/MagicCube/examples worker-2:/workspace/MagicCube/ +scp -r /workspace/MagicCube/examples worker-3:/workspace/MagicCube/ +scp -r /workspace/MagicCube/examples worker-4:/workspace/MagicCube/ \ No newline at end of file From 6d990550d5a84cb0a66c5d63e8f0960cd241e688 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 8 Dec 2021 12:22:02 +0000 Subject: [PATCH 0478/1892] update swin --- examples/swin/layers.py | 26 + examples/swin/swin_dt.py | 6 +- examples/swin/swin_flexflow.py | 993 +++++++++++++++++++++++++++++++++ examples/swin/swin_hybrid.py | 6 +- examples/swin/swin_pipe.py | 4 +- 5 files changed, 1027 insertions(+), 8 deletions(-) create mode 100644 examples/swin/swin_flexflow.py diff --git a/examples/swin/layers.py b/examples/swin/layers.py index 27929d06..d9b80e2e 100644 --- a/examples/swin/layers.py +++ b/examples/swin/layers.py @@ -134,6 +134,23 @@ def backward(ctx, grad_output): group = ctx.group return _split(grad_output, group, dim=0), None +class ValueTPtoEleDPAdapter(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, group): + """ + Reduce Scatter + """ + ctx.group = group + return _scatter(input_, group, dim=0) + + @staticmethod + def backward(ctx, grad_output): + """ + Allgather + """ + group = ctx.group + return _gather(grad_output, group, dim=0), None + class TPtoDPAdapter(torch.autograd.Function): @staticmethod @@ -298,3 +315,12 @@ def __init__(self, tp_group): def forward(self, input_): return TPtoDPAdapter.apply(input_, self.group) + +class ValueTPtoEleDP(torch.nn.Module): + + def __init__(self, tp_group): + super().__init__() + self.group = tp_group + + def forward(self, input_): + return ValueTPtoEleDPAdapter.apply(input_, self.group) diff --git a/examples/swin/swin_dt.py b/examples/swin/swin_dt.py index 8a0c3311..27107b49 100644 --- a/examples/swin/swin_dt.py +++ b/examples/swin/swin_dt.py @@ -820,9 +820,9 @@ def train(args, pconfigs): ] # # SwinV2-H modified: 782 M - # embed_dim, depths, num_heads = [ - # 384, [2, 2, 18, 2], [12, 24, 48, 96] - # ] + embed_dim, depths, num_heads = [ + 384, [2, 2, 18, 2], [12, 24, 48, 96] + ] # # head dim 32 -> 48 # embed_dim, depths, num_heads = [ # 576, [2, 2, 18, 2], [12, 24, 48, 96] diff --git a/examples/swin/swin_flexflow.py b/examples/swin/swin_flexflow.py new file mode 100644 index 00000000..cb1f5fdf --- /dev/null +++ b/examples/swin/swin_flexflow.py @@ -0,0 +1,993 @@ + +# -------------------------------------------------------- +# Modified from Swin-Transformer Repo +""" +python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/swin/swin_flexflow.py --bs 8 \ + --layer0 8 1 \ + --layer1 8 1 \ + --layer2 1 8 \ + --layer3 1 8 +""" +# -------------------------------------------------------- + +from typing import Dict, Optional, Tuple +import torch +import torch.nn as nn +import argparse +import time + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary +from cube.runtime.device import DeviceGroup +from cube.runtime.reducer import Reducer + +from examples.swin.layers import ColumnParallelLinear, DPtoTP, ValueTPtoEleDP, RowParallelLinear, TPtoDP + +_dp_reducer: Dict[Tuple[int], Reducer] = dict() + + +def setup_device_group(tp: int, dp: int, layer_id: int): + """ + Layer wise device group initialize + + Returns: + + """ + resource = cube.runtime.resource.EnvResource() + ndevs = resource.ngpus + + if not tp * dp == ndevs: + raise RuntimeError("Expected same device number") + + assert tp == 1 or dp == 1, "Currently hybrid not supported" + + devs = cube.runtime.device.DeviceGroup() + + myrank = torch.distributed.get_rank() + + # initialize tensor parallel groups + for i in range(dp): + ranks = list(range(i * tp, (i + 1) * tp)) + group = devs.get_group(ranks) + if myrank in ranks: + tp_ranks = ranks + print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) + + # initialize data parallel groups + for i in range(tp): + ranks = list(range(i, ndevs, tp)) + group = devs.get_group(ranks) + if myrank in ranks: + dp_ranks = ranks + print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) + return tp_ranks, dp_ranks + + +def drop_path(x, drop_prob: float = 0.): + if drop_prob == 0.: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class MegatronMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tp_group=-1): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=tp_group) + self.act = act_layer() + # self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=False, tp_group=tp_group) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x: torch.Tensor, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # [B, H_window_num, window_size, W_window_num, window_size, C] + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + # [B, H_window_num, W_window_num, window_size, window_size, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + # [B * H_windows_num * W_window_size, window_size, window_size, C] + windows = windows.view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def window_position_index(window_size_h: int, window_size_w: int): + coords_h = torch.arange(window_size_h) + coords_w = torch.arange(window_size_w) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_w - 1 + relative_coords[:, :, 0] *= 2 * window_size_w - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +class MegatronWindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., tp_group=-1): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.global_num_heads = num_heads + + tp_world_size = torch.distributed.get_world_size(group=tp_group) + if num_heads % tp_world_size != 0: + raise RuntimeError(f'detecting un-even num head {num_heads} partition to {tp_world_size}') + self.num_heads = num_heads // tp_world_size + + self.dim_heads = dim // self.global_num_heads + self.scale = qk_scale or self.dim_heads ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # relative position index + relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) + self.register_buffer('relative_position_index', relative_position_index) + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) + self.attn_drop = nn.Dropout(attn_drop) + # self.proj = nn.Linear(dim, dim) + self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=False, tp_group=tp_group) + self.proj_drop = nn.Dropout(proj_drop) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] + # [Wh * Ww, Wh * Ww, nH] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + tp_group=-1): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = MegatronWindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + tp_group=tp_group) + + self.drop_path_p = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MegatronMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop, + tp_group=tp_group + ) + + self.partition_all_op = True + if self.partition_all_op and torch.distributed.get_world_size(tp_group) > 1: + print('> enabled all-op partitioning...') + self.val_tp_to_dp = ValueTPtoEleDP(tp_group) + self.tp_to_dp = TPtoDP(tp_group) + self.dp_to_tp = DPtoTP(tp_group) + else: + self.tp_to_dp = None + self.dp_to_tp = None + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] + x_windows = window_partition(shifted_x, self.window_size) + # -> [B * num_windows, window_size_h * windows_size_w, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # same in/out: [B * num_windows, window_size_h * windows_size_w, C] + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] + shifted_x = window_reverse(attn_windows, self.window_size, H, W) + + # reverse cyclic shift + # [B, H', W', C] -> [B, H, W, C] + x = shifted_x + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + # [B, H, W, C] -> [B, H * W, C] + x = x.view(B, H * W, C) + + if self.partition_all_op and self.tp_to_dp is not None: + x = self.val_tp_to_dp(x) + shortcut = self.tp_to_dp(shortcut) + + # [B, H * W, C] -> [B, H * W, C] + x = shortcut + drop_path(x, self.drop_path_p) + + if self.partition_all_op and self.dp_to_tp is not None: + x = self.dp_to_tp(x) + + # FFN + # [B, H * W, C] -> [B, H * W, C] + ffn = self.norm2(x) + # [B, H * W, C] -> [B, H * W, C] + ffn = self.mlp(ffn) + # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] + + if self.partition_all_op and self.tp_to_dp is not None: + x = self.val_tp_to_dp(x) + ffn = self.tp_to_dp(ffn) + + x = x + drop_path(ffn, self.drop_path_p) + + if self.partition_all_op and self.dp_to_tp is not None: + x = self.dp_to_tp(x) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, + tp_group=-1, layer_id=-1): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList() + for i in range(depth): + block = SwinTransformerBlock( + dim=dim, input_resolution=self.input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + tp_group=tp_group, + ) + self.blocks.append(block) + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + return x + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, fp16=False, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + # self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + # if self.ape: + # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + # trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + + # ====================== depth 0 =========================== + pconfig = pconfigs[0] + l0_tp_ranks, l0_dp_ranks = setup_device_group(**pconfig) + tp_group = DeviceGroup().get_group(l0_tp_ranks) + + input_resolution = ( + patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) + ) + self.basic_layer0 = BasicLayer( + dim=int(embed_dim * 2 ** 0), + input_resolution=input_resolution, + depth=depths[0], + num_heads=num_heads[0], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:0]):sum(depths[:0 + 1])], + norm_layer=norm_layer, + tp_group=tp_group, + ) + + if len(l0_dp_ranks) > 1: + dp_ranks = tuple(l0_dp_ranks) + if dp_ranks not in _dp_reducer: + _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.patch_embed.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.basic_layer0.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + + # ====================== depth 1 =========================== + pconfig = pconfigs[1] + l1_tp_ranks, l1_dp_ranks = setup_device_group(**pconfig) + tp_group = DeviceGroup().get_group(l1_tp_ranks) + + # adapter + if len(l0_dp_ranks) > 1 and len(l1_tp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter01 = DPtoTP(DeviceGroup().get_group(l0_dp_ranks)) + elif len(l0_tp_ranks) > 1 and len(l1_dp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter01 = TPtoDP(DeviceGroup().get_group(l0_tp_ranks)) + else: + self.adapter01 = torch.nn.Identity() + + self.merging0 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 0), norm_layer=norm_layer + ) + + input_resolution = ( + patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) + ) + self.basic_layer1 = BasicLayer( + dim=int(embed_dim * 2 ** 1), + input_resolution=input_resolution, + depth=depths[1], + num_heads=num_heads[1], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:1]):sum(depths[:1 + 1])], + norm_layer=norm_layer, + tp_group=tp_group, + ) + + if len(l1_dp_ranks) > 1: + dp_ranks = tuple(l1_dp_ranks) + if dp_ranks not in _dp_reducer: + _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.basic_layer1.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.merging0.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + + # ====================== depth 2 =========================== + pconfig = pconfigs[2] + l2_tp_ranks, l2_dp_ranks = setup_device_group(**pconfig) + tp_group = DeviceGroup().get_group(l2_tp_ranks) + + # adapter + if len(l1_dp_ranks) > 1 and len(l2_tp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter12 = DPtoTP(DeviceGroup().get_group(l1_dp_ranks)) + elif len(l1_tp_ranks) > 1 and len(l2_dp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter12 = TPtoDP(DeviceGroup().get_group(l1_tp_ranks)) + else: + self.adapter12 = torch.nn.Identity() + + + self.merging1 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 1), norm_layer=norm_layer + ) + + input_resolution = ( + patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) + ) + self.basic_layer2 = BasicLayer( + dim=int(embed_dim * 2 ** 2), + input_resolution=input_resolution, + depth=depths[2], + num_heads=num_heads[2], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:2]):sum(depths[:2 + 1])], + norm_layer=norm_layer, + tp_group=tp_group + ) + + if len(l2_dp_ranks) > 1: + dp_ranks = tuple(l2_dp_ranks) + if dp_ranks not in _dp_reducer: + _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.basic_layer2.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.merging1.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + # ====================== depth 3 =========================== + pconfig = pconfigs[3] + l3_tp_ranks, l3_dp_ranks = setup_device_group(**pconfig) + tp_group = DeviceGroup().get_group(l3_tp_ranks) + + # adapter + if len(l2_dp_ranks) > 1 and len(l3_tp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter23 = DPtoTP(DeviceGroup().get_group(l2_dp_ranks)) + elif len(l2_tp_ranks) > 1 and len(l3_dp_ranks) > 1: + print_each_rank('add dp to tp adapters') + self.adapter23 = TPtoDP(DeviceGroup().get_group(l2_tp_ranks)) + else: + self.adapter23 = torch.nn.Identity() + + self.merging2 = PatchMerging( + input_resolution, dim=int(embed_dim * 2 ** 2), norm_layer=norm_layer + ) + + self.basic_layer3 = BasicLayer( + dim=int(embed_dim * 2 ** 3), + input_resolution=(patches_resolution[0] // (2 ** 3), + patches_resolution[1] // (2 ** 3)), + depth=depths[3], + num_heads=num_heads[3], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:3]):sum(depths[:3 + 1])], + norm_layer=norm_layer, + tp_group=tp_group + ) + + if len(l3_dp_ranks) > 1: + dp_ranks = tuple(l3_dp_ranks) + if dp_ranks not in _dp_reducer: + _dp_reducer[dp_ranks] = Reducer(dp_ranks) + for param in self.basic_layer3.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.merging2.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) + + if len(l3_dp_ranks) > 1: + dp_ranks = tuple(l3_dp_ranks) + for param in self.norm.parameters(): + _dp_reducer[dp_ranks].add_param(param) + for param in self.head.parameters(): + _dp_reducer[dp_ranks].add_param(param) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x) + + CudaTimer().start('basic_layer0') + x = self.basic_layer0(x) + CudaTimer().start('adapter') + x = self.adapter01(x) + CudaTimer().stop('adapter') + x = self.merging0(x) + CudaTimer().stop('basic_layer0') + + CudaTimer().start('basic_layer1') + x = self.basic_layer1(x) + CudaTimer().start('adapter') + x = self.adapter12(x) + CudaTimer().stop('adapter') + x = self.merging1(x) + CudaTimer().stop('basic_layer1') + + CudaTimer().start('basic_layer2') + x = self.basic_layer2(x) + CudaTimer().start('adapter') + x = self.adapter23(x) + CudaTimer().stop('adapter') + x = self.merging2(x) + CudaTimer().stop('basic_layer2') + + CudaTimer().start('basic_layer3') + x = self.basic_layer3(x) + CudaTimer().stop('basic_layer3') + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C L + x = torch.flatten(x, 1) + + x = self.head(x) + return x + + +def train(args, pconfigs): + + # dim_head is always 32 + + # img resolution, windows size: 224, 384, 518, 640 + C, H, W, window_size = [3, 224, 224, 7] + # C, H, W, window_size = [3, 384, 384, 12] + # C, H, W, window_size = [3, 518, 518, ?] + # C, H, W, window_size = [3, 640, 640, 20] + + # image batch size + N = args.bs + + # Swin-Tiny + # embed_dim, depths, num_heads = [ + # 96, [2, 2, 6, 2], [3, 6, 12, 24] + # ] + + # SwinV2-B: 87 M + # embed_dim, depths, num_heads = [ + # 128, [2, 2, 18, 2], [4, 8, 16, 32] + # ] + + # SwinV2-L: 196 M + # embed_dim, depths, num_heads = [ + # 192, [2, 2, 18, 2], [6, 12, 24, 48] + # ] + + # SwinV2-H: 657 M + embed_dim, depths, num_heads = [ + 352, [2, 2, 18, 2], [11, 22, 44, 88] + ] + + # # SwinV2-H modified: 782 M + embed_dim, depths, num_heads = [ + 384, [2, 2, 18, 2], [12, 24, 48, 96] + ] + # # head dim 32 -> 48 + # embed_dim, depths, num_heads = [ + # 576, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 64 -- too much + # embed_dim, depths, num_heads = [ + # 768, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 80 + # embed_dim, depths, num_heads = [ + # 960, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 96 + # embed_dim, depths, num_heads = [ + # 1152, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # # head dim 32 -> 112 + # embed_dim, depths, num_heads = [ + # 1344, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 128 + # embed_dim, depths, num_heads = [ + # 1536, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 144 + # embed_dim, depths, num_heads = [ + # 1728, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + # # head dim 32 -> 160 + # embed_dim, depths, num_heads = [ + # 1920, [2, 2, 18, 2], [12, 24, 48, 96] + # ] + + + model = SwinTransformer(img_size = H, + embed_dim = embed_dim, + depths = depths, + num_heads = num_heads, + window_size = window_size, + pconfigs = pconfigs) + if args.fp16: + print_each_rank('use half precision') + model = model.half() + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + model = model.cuda() + memory_summary() + + dataloader = cube.runtime.syndata.SynDataLoader( + 1280, [0], [N // args.dp, C, H, W]) + + if args.fp16: + data_buff = [[e.half() for e in data] for data in dataloader.datas] + dataloader.datas = data_buff + + def train_iter(model, dataloader): + img = next(dataloader) + loss = model(img) + loss = torch.sum(loss) + loss.backward() + CudaTimer().start('dp_allreduce') + for ranks in _dp_reducer: + reducer = _dp_reducer[ranks] + reducer.allreduce() + CudaTimer().stop('dp_allreduce') + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + # start training + nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 + print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) + + CudaTimer(enable=False).warmup() + torch.distributed.barrier() + span = 0 + iter_num = 60 + for step in range(iter_num): + if step >= 20: + torch.cuda.synchronize() + start = time.time() + CudaTimer(enable=True).start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step == 1: + print('> passed on 1st iteration') + memory_summary() + if step >= 20: + torch.cuda.synchronize() + stop = time.time() + span += (stop - start) * 1000 + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + iter_time = span / (iter_num-20) + throughput = N / iter_time * 1000 + print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( + iter_time, throughput) + ) + memory_summary() + CudaTimer().print_all(times=iter_num-20) + + +if __name__ == '__main__': + + cube.init() + + # resource allocation + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--layer0', type=int, nargs='+', + help='data, tensor parallel config') + parser.add_argument('--layer1', type=int, nargs='+', + help='data, tensor parallel config') + parser.add_argument('--layer2', type=int, nargs='+', + help='data, tensor parallel config') + parser.add_argument('--layer3', type=int, nargs='+', + help='data, tensor parallel config') + parser.add_argument('--bs', type=int, default=1, + help='bs') + parser.add_argument('--fp16', action='store_true', dest='fp16') + args = parser.parse_args() + + assert len(args.layer0) == 2 + assert len(args.layer1) == 2 + assert len(args.layer2) == 2 + assert len(args.layer3) == 2 + + # data parallel should be same + args.dp = args.layer0[0] + + pconfigs = [ + dict(layer_id=0, dp=args.layer0[0], tp=args.layer0[1]), # basic layer 0 + dict(layer_id=1, dp=args.layer1[0], tp=args.layer1[1]), # basic layer 1 + dict(layer_id=2, dp=args.layer2[0], tp=args.layer2[1]), # basic layer 2 + dict(layer_id=3, dp=args.layer3[0], tp=args.layer3[1]), # basic layer 3 + ] + + print_each_rank(pconfigs, rank_only=0) + train(args, pconfigs) diff --git a/examples/swin/swin_hybrid.py b/examples/swin/swin_hybrid.py index 7c02af94..6fada771 100644 --- a/examples/swin/swin_hybrid.py +++ b/examples/swin/swin_hybrid.py @@ -968,9 +968,9 @@ def train(args, pconfigs): 1728, [2, 2, 18, 2], [12, 24, 48, 96] ] # head dim 32 -> 160 - embed_dim, depths, num_heads = [ - 1920, [2, 2, 18, 2], [12, 24, 48, 96] - ] + # embed_dim, depths, num_heads = [ + # 1920, [2, 2, 18, 2], [12, 24, 48, 96] + # ] # SwinV2-G: 2.5B Model # embed_dim, depths, num_heads = [ diff --git a/examples/swin/swin_pipe.py b/examples/swin/swin_pipe.py index f74ea8e6..3bd429b9 100644 --- a/examples/swin/swin_pipe.py +++ b/examples/swin/swin_pipe.py @@ -604,10 +604,10 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, stop = start + chunk # self.use_checkpoint = [True] * (stop - start) - self.use_checkpoint = [False] * (stop - start) + # self.use_checkpoint = [False] * (stop - start) # 8gpu layer assign - # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] + # layer_split = [3, 4, 3, 3, 3, 4, 3, 4] # assert sum(layer_split) == 27 # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) From 46f1e1f4b22f7301fb7d6bf9cbc9e13c143734df Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 8 Dec 2021 12:22:26 +0000 Subject: [PATCH 0479/1892] linear --- examples/mlp/linears.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 6a1f6290..8603d52a 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -29,27 +29,27 @@ def __init__(self, dim, mult=1): self.linear2 = nn.Linear(dim * mult, dim) self.linear3 = nn.Linear(dim, dim * mult) self.linear4 = nn.Linear(dim * mult, dim) - # self.linear5 = nn.Linear(dim, dim * mult) - # self.linear6 = nn.Linear(dim * mult, dim) - # self.linear7 = nn.Linear(dim, dim * mult) - # self.linear8 = nn.Linear(dim * mult, dim) + self.linear5 = nn.Linear(dim, dim * mult) + self.linear6 = nn.Linear(dim * mult, dim) + self.linear7 = nn.Linear(dim, dim * mult) + self.linear8 = nn.Linear(dim * mult, dim) def forward(self, data): output = self.linear1(data) output = self.linear2(output) output = self.linear3(output) output = self.linear4(output) - # output = self.linear5(output) - # output = self.linear6(output) - # output = self.linear7(output) - # output = self.linear8(output) + output = self.linear5(output) + output = self.linear6(output) + output = self.linear7(output) + output = self.linear8(output) loss = torch.sum(output) return loss def train(): - batch_size = 8192 - dim = 8192 + batch_size = 4096 + dim = 4096 model = MLP(dim=dim) model = cube.SemanticModel( From 70f74d3f9fb0665d639360c889e9bfc17049a51f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 8 Dec 2021 12:22:50 +0000 Subject: [PATCH 0480/1892] update --- examples/efficientnet/efficientnet.py | 5 ++++- examples/efficientnet/train.py | 26 ++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/examples/efficientnet/efficientnet.py b/examples/efficientnet/efficientnet.py index cf867633..4c8c1973 100644 --- a/examples/efficientnet/efficientnet.py +++ b/examples/efficientnet/efficientnet.py @@ -102,7 +102,7 @@ def forward(self, inputs, drop_connect_rate=None): Returns: Output of this block after processing. """ - + # before_allocated = torch.cuda.max_memory_allocated() # Expansion and Depthwise Convolution x = inputs assert list(x.shape)[1:] == self.in_size @@ -136,6 +136,9 @@ def forward(self, inputs, drop_connect_rate=None): x = x + inputs # skip connection assert list(x.shape)[1:] == self.out_size + # after_allocated = torch.cuda.max_memory_allocated() + # consumption = (after_allocated - before_allocated) / 1024 / 1024 + # print('{} {}'.format(self.layer_id, consumption)) return x def set_swish(self, memory_efficient=True): diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index a46ae634..b99d1667 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -7,7 +7,7 @@ --master_port=8004 \ --use_env \ examples/efficientnet/train.py \ - --pp 8 --gbs 8 --mbs 1 + --pp 8 --gbs 32 --mbs 1 """ import torch from examples.efficientnet.efficientnet import EfficientNet @@ -28,6 +28,8 @@ def model_partition(model, in_size): pp_size = torch.distributed.get_world_size(resource.pp_group) layers = model._blocks + # for lid, layer in enumerate(layers): + # layer.layer_id = lid chunk = len(layers) // pp_size if len(layers) % pp_size != 0: @@ -41,10 +43,25 @@ def model_partition(model, in_size): start = pp_rank * chunk stop = start + chunk + use_checkpoint = [False] * (stop - start) + # use_checkpoint = [True] * (stop - start) + + # 2 stage + # layer_split = [30, 58] + # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) # use_checkpoint = [False] * (stop - start) - use_checkpoint = [True] * (stop - start) + # if pp_rank == 0: + # for idx in range(stop - start): + # if idx < 23: + # use_checkpoint[idx] = True + # if pp_rank == 1: + # for idx in range(stop - start): + # if idx < 20: + # use_checkpoint[idx] = True - # layer_split = [8, 5, 7, 14, 15, 13, 15, 11] + # layer_split = [8, 5, 8, 13, 16, 12, 16, 10] # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" # start = sum(layer_split[0:pp_rank]) # stop = sum(layer_split[0:pp_rank+1]) @@ -52,12 +69,13 @@ def model_partition(model, in_size): # use_checkpoint = [False] * (stop - start) # if pp_rank == 0: # for idx in range(stop - start): - # if idx < 7: + # if idx < 3: # use_checkpoint[idx] = True # if pp_rank == 1: # for idx in range(stop - start): # if idx < 4: # use_checkpoint[idx] = True + # if pp_rank == 2: # for idx in range(stop - start): # if idx < 4: From 07d469e2b2f9b1906e8e64fa17815773d058389e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 10 Dec 2021 08:26:05 +0000 Subject: [PATCH 0481/1892] efficientnet training config for 16 gpus --- examples/efficientnet/train.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/examples/efficientnet/train.py b/examples/efficientnet/train.py index b99d1667..a49b4ea7 100644 --- a/examples/efficientnet/train.py +++ b/examples/efficientnet/train.py @@ -44,7 +44,7 @@ def model_partition(model, in_size): stop = start + chunk use_checkpoint = [False] * (stop - start) - # use_checkpoint = [True] * (stop - start) + use_checkpoint = [True] * (stop - start) # 2 stage # layer_split = [30, 58] @@ -99,6 +99,34 @@ def model_partition(model, in_size): # if idx < 2: # use_checkpoint[idx] = True + # 16GPU + # layer_split = [4, 3, 3, 3, 3, 5, 6, 7, 8, 8, 6, 6, 8, 8, 6, 4] + # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" + # start = sum(layer_split[0:pp_rank]) + # stop = sum(layer_split[0:pp_rank+1]) + # use_checkpoint = [False] * (stop - start) + # if pp_rank == 1: + # for idx in range(stop - start): + # if idx < 2: + # use_checkpoint[idx] = True + # if pp_rank == 2: + # for idx in range(stop - start): + # if idx < 1: + # use_checkpoint[idx] = True + use_checkpoint = [False] * (stop - start) + if pp_rank == 0: + for idx in range(stop - start): + if idx < 2: + use_checkpoint[idx] = True + if pp_rank == 1: + for idx in range(stop - start): + if idx < 4: + use_checkpoint[idx] = True + if pp_rank == 2: + for idx in range(stop - start): + if idx < 3: + use_checkpoint[idx] = True + # 8GB memory experiments # layer_split = [8, 5, 7, 14, 14, 13, 16, 11] # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" @@ -224,6 +252,9 @@ def train_iter(model, dataloader): print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( iter_time, throughput) ) + compute_time = CudaTimer().duration(iter_num-10, field_name='forward') + \ + CudaTimer().duration(iter_num-10, field_name='backward') + print_each_rank(f'compute time: {compute_time} ms') CudaTimer().print_all(times=iter_num-10) memory_summary() From c878e1283f9af5fa5a13ac757a30f92d3bf484c6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 12 Dec 2021 08:21:30 +0000 Subject: [PATCH 0482/1892] fix swin infer bug --- cube/runtime/function/dist.py | 9 ++++----- examples/swin/swin_dwt_infer.py | 16 ++++++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py index f367daf4..400bc3ad 100644 --- a/cube/runtime/function/dist.py +++ b/cube/runtime/function/dist.py @@ -26,7 +26,6 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, gro world_size = len(dim_ranks) if world_size == 1: return torch.roll(input, (shift), (dim,)) - global_rank = torch.distributed.get_rank() dim_rank = dim_ranks.index(torch.distributed.get_rank(group)) # halo exchange at H dimension if shift < 0: @@ -50,11 +49,11 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, gro send_op = torch.distributed.P2POp( torch.distributed.isend, remote, - send_global_rank, group=group, tag=global_rank + send_global_rank ) recv_op = torch.distributed.P2POp( torch.distributed.irecv, recv_tensor, - recv_global_rank, group=group, tag=recv_global_rank + recv_global_rank ) ops = [send_op, recv_op] if dim_rank % 2 == 0 else [recv_op, send_op] reqs = torch.distributed.batch_isend_irecv(ops) @@ -84,11 +83,11 @@ def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, gro send_op = torch.distributed.P2POp( torch.distributed.isend, remote, - send_global_rank, group=group, tag=global_rank + send_global_rank ) recv_op = torch.distributed.P2POp( torch.distributed.irecv, recv_tensor, - recv_global_rank, group=group, tag=recv_global_rank + recv_global_rank ) ops = [send_op, recv_op] if dim_rank % 2 == 0 else [recv_op, send_op] reqs = torch.distributed.batch_isend_irecv(ops) diff --git a/examples/swin/swin_dwt_infer.py b/examples/swin/swin_dwt_infer.py index c7f8ccda..a4c96c3d 100644 --- a/examples/swin/swin_dwt_infer.py +++ b/examples/swin/swin_dwt_infer.py @@ -281,7 +281,7 @@ class SwinTransformerBlock(nn.Module): def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - tp_group=-1, wp_plans=-1): + tp_group=-1, wp_plans=-1, layer_id=-1): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -289,8 +289,13 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows + self.wp_group, self.wp_nH_ranks, self.wp_nW_ranks = wp_plans + # if min(self.input_resolution) <= self.window_size: + # # if window size is larger than input resolution, we don't partition windows + # self.shift_size = 0 + # self.window_size = min(self.input_resolution) + if layer_id == 3: + print('set shift size to 0') self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" @@ -311,7 +316,6 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 tp_group=tp_group ) - self.wp_group, self.wp_nH_ranks, self.wp_nW_ranks = wp_plans self.use_wp = torch.distributed.get_world_size(self.wp_group) != 1 if self.shift_size > 0: @@ -329,7 +333,6 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) @@ -521,7 +524,8 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, - tp_group=tp_group, wp_plans=(wp_group, wp_nH_ranks, wp_nW_ranks) + tp_group=tp_group, wp_plans=(wp_group, wp_nH_ranks, wp_nW_ranks), + layer_id = layer_id ) self.blocks.append(block) From 1aa4d9caf811082dbfa68b8d50b5ed700ac50eea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 21 Dec 2021 17:16:28 +0800 Subject: [PATCH 0483/1892] add consumer & producer for ir tensor --- cube/compiler.py | 18 +++- cube/graph/graph.py | 2 +- cube/graph/operator/operator.py | 21 +++- cube/graph/parser/parser.py | 3 +- cube/graph/tensor.py | 144 +++++++++++++++------------- cube/ir/cten.py | 29 +++++- examples/mlp/policy/col_parallel.py | 2 +- 7 files changed, 141 insertions(+), 78 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index b97ffb8d..ec46dd48 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -146,22 +146,34 @@ def decorator(fn: Callable) -> Callable: for su in sugraph.sus(): if len(su.device) == 0: raise RuntimeError(f"SU {su} device is not set") - if not SUGraph.is_topo_order(sugraph.sus()): - raise RuntimeError(f"SUGraph order is not topological order") + # if not SUGraph.is_topo_order(sugraph.sus()): + # raise RuntimeError(f"SUGraph order is not topological order") execplan = ExectuionPlan(sugraph) # plan pass to adapt to pytorch semantic: multi branch gradient # TODO: residual support # execplan = TorchRefAdapter.apply(execplan) - # plan pass to remove redundant sus + # plan pass to remove redundant sus + start = time.time() execplan = RemoveRedundantAdapters.apply(execplan) + span = time.time() - start + print('> planpass on remove redundant adapter: {:.2f} s'.format(span)) # print(f'> after remove redundant adapters:\n {execplan}') + start = time.time() execplan = MergeComputeSU.apply(execplan) + span = time.time() - start + print('> planpass on merge compute: {:.2f} s'.format(span)) # print(f'> after merge backward SU:\n {execplan}') + start = time.time() execplan = WeightGradAllreduceFusion.apply(execplan) + span = time.time() - start + print('> planpass on grad allreduce: {:.2f} s'.format(span)) # print(f'> after add allreduce:\n{execplan}') + start = time.time() execplan = P2PFusion.apply(execplan) + span = time.time() - start + print('> planpass on p2p fusion: {:.2f} s'.format(span)) # print(f'> after fuse P2P SU:\n {execplan}') if torch.distributed.is_initialized(): diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5bd15c66..6a5bd2f1 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -339,7 +339,7 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional continue # TODO: requires_grad = False should be set to None val.grad = val.get_grad(fnode) - for related_op in val.parent.forward_dst_cells(): + for related_op in val.parent.consumers: for idx, rval in enumerate(related_op.inputs()): if val.overlap(rval): rval.grad = rval.get_grad(related_op) diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 9dcd28f5..e04f77cf 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -64,11 +64,20 @@ def set_input(self, input_index: int, val: Any): old_val = self.inputs(input_index) # remove the old one if isinstance(old_val, IRSubTensor): - old_val.parent._rm_fdst_cell(self) + old_val.parent.rm_consumer(self) if isinstance(val, IRSubTensor): - val.parent._add_fdst_cell(self) + val.parent.add_consumer(self, val) return super().set_input(input_index, val) + def set_output(self, output_index: int, val: Any): + old_val = self.outputs(output_index) + # remove the old one + if isinstance(old_val, IRSubTensor): + old_val.parent.rm_producer(self) + if isinstance(val, IRSubTensor): + val.parent.add_producer(self, val) + return super().set_output(output_index, val) + def replicate(self): """ Replicate the Operation @@ -303,6 +312,14 @@ def algorithms(self, tag: Optional[str] = None): return None template = factory.algorithms(type(self), tag) return template(self) + + def __repr__(self): + outputs = list() + for t in self.outputs(): + name = f't{t._id}(p{t.parent._id},{t.shape},{t.val_map})' + outputs.append(name) + dscp = f'DataLoader-{self._id}(outputs={outputs})' + return dscp class IROptimOperation(IRCell): diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index facc0d96..eae3c0e4 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -400,9 +400,8 @@ def flatten(smodule, depth=0): print(' '*depth, node) else: for node in smodule.graph.nodes(): - ntype = ScriptModuleParser.ntype(node) print(' '*depth, node) - if ntype == ScriptNodeKind.PrimCallMethod: + if node.kind() == 'prim::CallMethod': label = node.inputsAt(0).node().s('name') submodule = getattr(smodule, label) ScriptModuleParser.flatten(submodule, depth+1) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 927d2ad3..3a27a18e 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -301,18 +301,20 @@ def _to_value_map(val_map: Union[Tuple, ValueMap, None]): class IRFullTensor(IRTensor): - def __init__(self, shape=None, name=None, requires_grad=True): + def __init__(self, shape=None, name=None, requires_grad=True, dtype=float): - super().__init__(shape, name) + super().__init__(shape, name, dtype) - self._segments = list() - # indices: List[IndexMap] for each segment - self._indices: List = list() - # value op - self._val_maps: List = list() + # producer cell and produced sub tensor + self._producers: List[IRCell] = list() + self._ptensors : List[IRSubTensor] = list() - # track gradient - self._forward_dst_cells = list() + # consumer cell and consumed sub tensor + self._consumers: List[IRCell] = list() + self._ctensors : List[IRSubTensor] = list() + + # record all created sub_tensors + self._segments : List[IRSubTensor] = list() self.requires_grad = requires_grad if requires_grad: @@ -328,26 +330,67 @@ def __copy__(self): """ return self - def _add_fdst_cell(self, cell: IRCell): - if not isinstance(cell, IRCell): - raise TypeError("Expect an IRCell") - if cell not in self._forward_dst_cells: - if None in self._forward_dst_cells: - idx = self._forward_dst_cells.index(None) - self._forward_dst_cells[idx] = cell - else: - self._forward_dst_cells.append(cell) + @property + def producers(self) -> List[IRCell]: + """ + Producer IRCell list + """ + return self._producers + + @property + def ptensors(self): + """ + Produced IRSubTensor list correspongding to producer IRCell + """ + return self._ptensors + + @property + def consumers(self) -> List[IRCell]: + """ + Consumer IRCell list + """ + return self._consumers + + @property + def ctensors(self): + """ + Consumed IRSubTensor list correspongding to consumer IRCell + """ + return self._ctensors + + def add_producer(self, cell: IRCell, tensor: IRTensor): + if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): + raise TypeError("Expect an IRCell and an IRTensor") + if cell not in self.consumers: + self.producers.append(cell) + self.ptensors.append(tensor) - def _rm_fdst_cell(self, cell: IRCell): - if not isinstance(cell, IRCell): - raise TypeError("Expect an IRCell") - if cell in self._forward_dst_cells: - # setting to None to keep value map order - idx = self._forward_dst_cells.index(cell) - self._forward_dst_cells[idx] = None + def add_consumer(self, cell: IRCell, tensor: IRTensor): + if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): + raise TypeError("Expect an IRCell and an IRTensor") + if cell not in self.consumers: + self.consumers.append(cell) + self.ctensors.append(tensor) - def forward_dst_cells(self): - return [cell for cell in self._forward_dst_cells if cell is not None] + def rm_producer(self, cell: IRCell): + if cell not in self.producers: + raise KeyError(f"Cell {cell} not found in producer") + idx = self.producers.index(cell) + self.producers.pop(idx) + self.ptensors.pop(idx) + + def rm_consumer(self, cell: IRCell): + if cell not in self.consumers: + raise KeyError(f"Cell {cell} not found in producer") + idx = self.consumers.index(cell) + self.consumers.pop(idx) + self.ctensors.pop(idx) + + def subtensors(self): + """ + Get created sub-tensors of this tensor. + """ + return copy.copy(self._segments) def as_param(self): """ @@ -356,13 +399,13 @@ def as_param(self): self.requires_grad = True self._is_param = True self._is_grad = False - for sub_tensor in self._segments: + for sub_tensor in self.ptensors + self.ctensors: sub_tensor.as_param() def as_grad(self): self._is_param = False self._is_grad = True - for sub_tensor in self._segments: + for sub_tensor in self.ptensors + self.ctensors: sub_tensor.as_grad() return self @@ -379,33 +422,6 @@ def like(self): setattr(tensor, attr, getattr(self, attr)) return tensor - def segments(self, index: Optional[int] = None): - """ - Get the SubTensors at index position - """ - if index is None: - return copy.copy(self._segments) - else: - return self._segments[index] - - def indices(self, index: Optional[int] = None) -> IndexMap: - """ - Get the SubTensors mapping indices - """ - if index is None: - return copy.copy(self._indices) - else: - return self._indices[index] - - def val_maps(self, index: Optional[int] = None): - """ - Get the SubTensors val_map - """ - if index is None: - return copy.copy(self._val_maps) - else: - return self._val_maps[index] - def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap, None], shape: List[int]): """ Select a SubTensor from FullTensor. @@ -426,21 +442,16 @@ def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap indices = _to_index_map(indices) val_map = _to_value_map(val_map) - for idx in range(len(self._segments)): - indmap = self._indices[idx] - valmap = self._val_maps[idx] - sub_tensor = self._segments[idx] - if indmap == indices and valmap == val_map: + # return tensor to keep id same for same sub tensor + for sub_tensor in self.subtensors(): + if sub_tensor.indices == indices and sub_tensor.val_map == val_map: return sub_tensor sub_tensor = IRSubTensor(self, indices, val_map, shape) for attr in IRFullTensor._attr: setattr(sub_tensor, attr, getattr(self, attr)) sub_tensor.grad = None - self._segments.append(sub_tensor) - self._indices.append(indices) - self._val_maps.append(val_map) return sub_tensor def overlap(self, other): @@ -475,7 +486,7 @@ def common(self, other) -> Optional[IRTensor]: def tosub(self): """ - Convert to SubTensor by selecting all indices + Convert to SubTensor by selecting all indices and full value """ if self.shape is None: raise RuntimeError("Expected know shape") @@ -593,9 +604,8 @@ def get_grad(self, fcell: IRCell): if full_grad is None: return None if self in fcell.inputs(): - fdst_cells = self.parent.forward_dst_cells() ref_cells = list() - for dst_cell in fdst_cells: + for dst_cell in self.parent.consumers: for input in dst_cell.inputs(): if self.overlap(input): ref_cells.append(dst_cell) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index cec17531..a6a40bdf 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -370,9 +370,9 @@ class IRTensor: IRTensor serves as IRGraph edge """ - _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad'] + _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad', '_dtype'] - def __init__(self, shape=None, name=None): + def __init__(self, shape=None, name=None, dtype=float): self._id: int = IDGenerator().gen_tensor_id() self._shape: Optional(List[int]) = shape @@ -381,12 +381,37 @@ def __init__(self, shape=None, name=None): # device self._cell: List[IRCell] = list() + # TODO: support float16 + self._dtype = dtype + self._is_param = False self._is_grad = False self._requires_grad = True self._grad = None + @property + def requires_grad(self): + return self._requires_grad + + @requires_grad.setter + def requires_grad(self, val: bool): + self._requires_grad = val + + @property + def dtype(self): + """ + Data type + """ + return self._dtype + + @dtype.setter + def dtype(self, val): + """ + Set data type + """ + self._dtype = val + def attach_cell(self, cell: IRCell): """ Attach to a cell, to be with input or output diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index 96e4225b..7fafdce4 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -25,7 +25,7 @@ def schedule_policy(sugraph: SUGraph, resource): """ The schedule policy assign devices """ - print(sugraph) + # print(sugraph) for su in sugraph.sus(): if su.stype == SUType.Dataloader: devid = su.tag[0] From 622abc1da9bbb979968fc22e759637a7ec979e47 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 21 Dec 2021 17:26:28 +0800 Subject: [PATCH 0484/1892] rename to indmap and valmap --- cube/algorithm/utils.py | 8 +- cube/codegen/codegen.py | 8 +- cube/execplan/planpass/gfuse.py | 2 +- cube/execplan/planpass/merge.py | 6 +- cube/execplan/planpass/p2pfusion.py | 30 ++-- cube/execplan/planpass/torchadapt.py | 4 +- cube/graph/gpass.py | 4 +- cube/graph/graph.py | 2 +- cube/graph/operator/operator.py | 12 +- cube/graph/tensor.py | 152 ++++++++++----------- cube/runtime/executor.py | 8 +- cube/runtime/transform.py | 8 +- cube/schedule/adapter/transform.py | 64 ++++----- tests/algorithm/test_bmm.py | 8 +- tests/algorithm/test_complex.py | 14 +- tests/algorithm/test_linear_algo.py | 24 ++-- tests/algorithm/test_memory.py | 2 +- tests/algorithm/test_reduce.py | 4 +- tests/graph/parser/test_parse_attention.py | 2 +- tests/graph/test_graph_partition.py | 6 +- tests/graph/test_tensor.py | 112 +++++++-------- tests/graph/test_tensor_grad.py | 24 ++-- tests/runtime/rollsplit.py | 50 ++++++- tests/schedule/test_adapter_transform.py | 72 +++++----- 24 files changed, 335 insertions(+), 291 deletions(-) diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py index 221280df..67fae0b0 100644 --- a/cube/algorithm/utils.py +++ b/cube/algorithm/utils.py @@ -27,8 +27,8 @@ def split_axis(tensor: IRTensor, axis: int, chunk_num: int): for cid in range(chunk_num): shape_slicer[axis] = slice(chunk_size * cid, chunk_size * (cid + 1), 1) sub_tensors.append(tensor.select( - indices = tuple(shape_slicer), - val_map = None, + indmap = tuple(shape_slicer), + valmap = None, shape = chunk_shape )) return sub_tensors @@ -44,8 +44,8 @@ def split_value(tensor: IRTensor, chunk_num: int): sub_tensors = list() for idx in range(chunk_num): sub_tensor = tensor.select( - indices = tuple(shape_slicer), - val_map = (idx, chunk_num), + indmap = tuple(shape_slicer), + valmap = (idx, chunk_num), shape = tensor.shape ) sub_tensors.append(sub_tensor) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 924c9d98..0b3f9515 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -276,12 +276,12 @@ def emit_transform_call(self, node: IRTensorTransform): """ for prim in node.trace(): if isinstance(prim, SelectPrim): - signature = 'cube.runtime.transform.select({tensor}, {indices}, {val_map})' + signature = 'cube.runtime.transform.select({tensor}, {indmap}, {valmap})' input = self.tensor_naming(prim.tensor) - indices = repr(prim.indices) - val_map = repr(tuple([prim.val_map.idx, prim.val_map.chunk_num])) + indmap = repr(prim.indmap) + valmap = repr(tuple([prim.valmap.idx, prim.valmap.chunk_num])) output = self.tensor_naming(prim.output) - code = f'{output} = {signature.format(tensor=input, indices=indices, val_map=val_map)}' + code = f'{output} = {signature.format(tensor=input, indmap=indmap, valmap=valmap)}' self.forward_region.append(code) elif isinstance(prim, MergePrim): signature = 'cube.runtime.transform.merge({tensors}, {concat}, {add})' diff --git a/cube/execplan/planpass/gfuse.py b/cube/execplan/planpass/gfuse.py index 3adced5f..28faef39 100644 --- a/cube/execplan/planpass/gfuse.py +++ b/cube/execplan/planpass/gfuse.py @@ -74,7 +74,7 @@ def _get_weight_grads(execplan: ExectuionPlan) -> Dict: print(grad) assert grad is not None # nothing to sync - if grad.val_map == ValueMap(0, 1): + if grad.valmap == ValueMap(0, 1): continue if input._id not in grads: grads[input._id] = dict() diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py index 8805e651..23def674 100644 --- a/cube/execplan/planpass/merge.py +++ b/cube/execplan/planpass/merge.py @@ -105,9 +105,9 @@ def _connected_by_adapter(execplan: ExectuionPlan, fpieces, devid: int): Check if there is an adapter connecting forward SUs """ sugraph = execplan.sugraph - indices = [execplan.sequence(devid).index(fsu) for fsu in fpieces] - start = min(indices) - end = max(indices) + indmap = [execplan.sequence(devid).index(fsu) for fsu in fpieces] + start = min(indmap) + end = max(indmap) # check fsu1 -> asu -> fsu2 for asu in execplan.sequence(devid)[start:end]: if asu.stype in [SUType.P2P, SUType.Transform, SUType.Coll]: diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py index c20d07b5..a72ccff9 100644 --- a/cube/execplan/planpass/p2pfusion.py +++ b/cube/execplan/planpass/p2pfusion.py @@ -177,7 +177,7 @@ def match_allreduce(tous, tins): in_tensors: Dict[int, List[IRTensor]] = dict() for devid in tins: for in_tensor in tins[devid]: - if in_tensor.val_map != ValueMap(0, 1): + if in_tensor.valmap != ValueMap(0, 1): continue tid = in_tensor._id if tid not in in_devices: @@ -201,16 +201,16 @@ def match_allreduce(tous, tins): # check same indice map and no overlap value map unique_indices = list() for odev in out_tensors: - indices = out_tensors[odev][0].indices - if indices not in unique_indices: - unique_indices.append(indices) + indmap = out_tensors[odev][0].indmap + if indmap not in unique_indices: + unique_indices.append(indmap) if len(unique_indices) != 1: continue # check no overlap valmaps all_valmaps = list() overlap = False for odev in out_tensors: - valmap = out_tensors[odev][0].val_map + valmap = out_tensors[odev][0].valmap for pre_valmp in all_valmaps: overlap = pre_valmp.overlap(valmap) all_valmaps.append(valmap) @@ -279,22 +279,22 @@ def match_allgather(tous, tins): # multiple transmission FIXME: remove redundancy if not all([len(out_tensors[odev]) == 1 for odev in out_devices]): continue - # check same value map and no overlap indices + # check same value map and no overlap indmap unique_valmaps = list() for odev in out_tensors: - valmap = out_tensors[odev][0].val_map + valmap = out_tensors[odev][0].valmap if valmap not in unique_valmaps: unique_valmaps.append(valmap) if len(unique_valmaps) != 1: continue - # check no overlap indices + # check no overlap indmap all_indices = list() overlap = False for odev in out_tensors: - indices = out_tensors[odev][0].indices + indmap = out_tensors[odev][0].indmap for pre_indices in all_indices: - overlap = pre_indices.overlap(indices) - all_indices.append(indices) + overlap = pre_indices.overlap(indmap) + all_indices.append(indmap) if overlap: continue @@ -335,7 +335,7 @@ def match_reducescatter(tous, tins): for devid in tins: for in_tensor in tins[devid]: tid = in_tensor._id - if in_tensor.val_map != ValueMap(0, 1): + if in_tensor.valmap != ValueMap(0, 1): continue if tid not in in_devices: in_devices[tid] = list() @@ -358,10 +358,10 @@ def match_reducescatter(tous, tins): # multiple transmission FIXME: remove redundancy if not all([len(out_tensors[odev]) == 1 for odev in out_tensors]): continue - if out_tensors[devid][0].val_map == ValueMap(0, 1): + if out_tensors[devid][0].valmap == ValueMap(0, 1): is_reduce = False break - if out_tensors[devid][0].indices != in_tensor.indices: + if out_tensors[devid][0].indmap != in_tensor.indmap: is_reduce = False break if is_reduce: @@ -439,7 +439,7 @@ def match_broadcast(tous, tins): for devid in tins: for in_tensor in tins[devid]: tid = in_tensor._id - if in_tensor.val_map != ValueMap(0, 1): + if in_tensor.valmap != ValueMap(0, 1): continue if tid not in in_devices: in_devices[tid] = list() diff --git a/cube/execplan/planpass/torchadapt.py b/cube/execplan/planpass/torchadapt.py index 8c467d5d..b1cf460e 100644 --- a/cube/execplan/planpass/torchadapt.py +++ b/cube/execplan/planpass/torchadapt.py @@ -78,8 +78,8 @@ def apply(execplan: ExectuionPlan): if ftensor is None: raise RuntimeError("Internal Error: fsu not found input tensor") grad = ftensor.parent.grad.select( - indices = ftensor.indices, - val_map = ValueMap(grad_idx, grad_num), + indmap = ftensor.indmap, + valmap = ValueMap(grad_idx, grad_num), shape = ftensor.shape ) rm_grad = TorchRefAdapter.set_grad(fsu, ftensor, grad) diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index f727a318..614a1ba6 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -25,8 +25,8 @@ def renew(self, val: Any, keep_param=True): if val.parent._id not in self.symbol: self.symbol[val.parent._id] = val.parent.like() new_val = self.symbol[val.parent._id].select( - indices=val.indices, - val_map=val.val_map, + indmap=val.indmap, + valmap=val.valmap, shape=val.shape ) return new_val diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 6a5bd2f1..c94ab0b6 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -323,7 +323,7 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional for fnode in fnodes: for input in fnode.inputs(): if isinstance(input, IRSubTensor): - if input.val_map.chunk_num != 1 and not input.is_param(): + if input.valmap.chunk_num != 1 and not input.is_param(): raise NotImplementedError( f"Not support feature-map {input} to be splitted in value as input" ) diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index e04f77cf..92a3b52a 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -106,7 +106,7 @@ def __repr__(self): valmap = (0,1) else: pid = tensor.parent._id - valmap = tensor.val_map + valmap = tensor.valmap inputs.append(f'{anno}{tensor._id}(p{pid},{tensor.shape},{valmap})') else: inputs.append(tensor) @@ -124,7 +124,7 @@ def __repr__(self): valmap = (0,1) else: pid = tensor.parent._id - valmap = tensor.val_map + valmap = tensor.valmap pid = tensor.parent._id if hasattr(tensor, 'parent') else tensor._id outputs.append(f'{anno}{tensor._id}(p{pid},{tensor.shape},{valmap})') else: @@ -222,7 +222,7 @@ def __repr__(self): if tensor.is_grad(): anno = 'g' # datas.append(f'{anno}{tensor._id}') - datas.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') + datas.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.valmap})') else: datas.append(tensor) @@ -235,7 +235,7 @@ def __repr__(self): if tensor.is_grad(): anno = 'g' # grads.append(f'{anno}{tensor._id}') - grads.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') + grads.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.valmap})') else: grads.append(tensor) @@ -248,7 +248,7 @@ def __repr__(self): if tensor.is_grad(): anno = 'g' # outputs.append(f'{anno}{tensor._id}') - outputs.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.val_map})') + outputs.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.valmap})') else: outputs.append(tensor) @@ -316,7 +316,7 @@ def algorithms(self, tag: Optional[str] = None): def __repr__(self): outputs = list() for t in self.outputs(): - name = f't{t._id}(p{t.parent._id},{t.shape},{t.val_map})' + name = f't{t._id}(p{t.parent._id},{t.shape},{t.valmap})' outputs.append(name) dscp = f'DataLoader-{self._id}(outputs={outputs})' return dscp diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 3a27a18e..8e8fa087 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -2,19 +2,19 @@ SubTensor Gradient rule: SubTensor's logical grad = SubTensor.parent.grad.select( - indices = SubTensor.indices, - val_map = SubTensor.val_map, + indmap = SubTensor.indmap, + valmap = SubTensor.valmap, shape = SubTensor.shape ) FwOperation -> BpOperation rule: 1). for (FwOp) input tensors, gradient SubTensor is: - indices = input.indices; - val is splitted by referencing times on the indices + indmap = input.indmap; + val is splitted by referencing times on the indmap 2). for (FwOp) output tensors, gradient SubTensor is: - indices = output.indices; + indmap = output.indmap; val is always (0/1) """ @@ -31,16 +31,16 @@ class IndexMap: - def __init__(self, indices): + def __init__(self, indmap): - if not isinstance(indices, tuple): - raise TypeError("Expected indices to be a tuple") + if not isinstance(indmap, tuple): + raise TypeError("Expected indmap to be a tuple") - if not all([isinstance(s, slice) for s in indices]): + if not all([isinstance(s, slice) for s in indmap]): raise NotImplementedError( "Only support for sliced index mapping" ) - self._indices = indices + self._indices = indmap def __eq__(self, other): if isinstance(other, IndexMap): @@ -58,7 +58,7 @@ def __eq__(self, other): def get(self): """ - Get indices + Get indmap """ return self._indices @@ -97,7 +97,7 @@ def shape(self) -> List[int]: def map(self, submap): """ - Map from the current indices by sub_indices. + Map from the current indmap by sub_indices. Args: sub_indices: IndexMap @@ -130,7 +130,7 @@ def map(self, submap): def overlap(self, other): """ - Check if this indices overlapped with the other + Check if this indmap overlapped with the other Args: other: IndexMap @@ -277,26 +277,26 @@ def __repr__(self): return f'({self.idx}/{self.chunk_num})' -def _to_index_map(indices: Union[Tuple, IndexMap]): - if not isinstance(indices, tuple) and not isinstance(indices, IndexMap): - raise TypeError("Expected indices to be tuple or IndexMap") - if isinstance(indices, tuple): - indices = IndexMap(indices) - return indices +def _to_indmap(indmap: Union[Tuple, IndexMap]): + if not isinstance(indmap, tuple) and not isinstance(indmap, IndexMap): + raise TypeError("Expected indmap to be tuple or IndexMap") + if isinstance(indmap, tuple): + indmap = IndexMap(indmap) + return indmap -def _to_value_map(val_map: Union[Tuple, ValueMap, None]): - if not isinstance(val_map, tuple) and \ - not isinstance(val_map, ValueMap) and \ - not val_map is None: - raise TypeError("Expected val_map to be tuple, IndexMap or None") - if val_map is None: - val_map = ValueMap(0, 1) - elif isinstance(val_map, tuple): - if len(val_map) != 2: +def _to_value_map(valmap: Union[Tuple, ValueMap, None]): + if not isinstance(valmap, tuple) and \ + not isinstance(valmap, ValueMap) and \ + not valmap is None: + raise TypeError("Expected valmap to be tuple, IndexMap or None") + if valmap is None: + valmap = ValueMap(0, 1) + elif isinstance(valmap, tuple): + if len(valmap) != 2: raise ValueError("Expected tuple to be (idx, chunk_num)") - val_map = ValueMap(*val_map) - return val_map + valmap = ValueMap(*valmap) + return valmap class IRFullTensor(IRTensor): @@ -422,32 +422,32 @@ def like(self): setattr(tensor, attr, getattr(self, attr)) return tensor - def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap, None], shape: List[int]): + def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, None], shape: List[int]): """ Select a SubTensor from FullTensor. Note due to implementation issue, one value in the full tensor - cannot be splitted by different val_map + cannot be splitted by different valmap Args: - indices: the index of this tensor's index + indmap: the index of this tensor's index - val_map: how the tensor mapped from original value + valmap: how the tensor mapped from original value shape: the sub_tensor shape. Returns: IRSubTensor """ - indices = _to_index_map(indices) - val_map = _to_value_map(val_map) + indmap = _to_indmap(indmap) + valmap = _to_value_map(valmap) # return tensor to keep id same for same sub tensor for sub_tensor in self.subtensors(): - if sub_tensor.indices == indices and sub_tensor.val_map == val_map: + if sub_tensor.indmap == indmap and sub_tensor.valmap == valmap: return sub_tensor - sub_tensor = IRSubTensor(self, indices, val_map, shape) + sub_tensor = IRSubTensor(self, indmap, valmap, shape) for attr in IRFullTensor._attr: setattr(sub_tensor, attr, getattr(self, attr)) sub_tensor.grad = None @@ -486,7 +486,7 @@ def common(self, other) -> Optional[IRTensor]: def tosub(self): """ - Convert to SubTensor by selecting all indices and full value + Convert to SubTensor by selecting all indmap and full value """ if self.shape is None: raise RuntimeError("Expected know shape") @@ -494,8 +494,8 @@ def tosub(self): for dim_len in self.shape: slicers.append(slice(0, dim_len, 1)) sub_tensor = self.select( - indices=tuple(slicers), - val_map=None, + indmap=tuple(slicers), + valmap=None, shape=self.shape ) return sub_tensor @@ -507,14 +507,14 @@ def __repr__(self): class IRSubTensor(IRTensor): - def __init__(self, full_tensor: IRTensor, indices, val_map: Optional[ValueMap] =None, shape=None): + def __init__(self, full_tensor: IRTensor, indmap, valmap: Optional[ValueMap] =None, shape=None): """ Create an IRSubTensor. Args: full_tensor: the full tensor - indices: index list - val_map: the value operation to merge SubTensors into one + indmap: index list + valmap: the value operation to merge SubTensors into one """ if not isinstance(full_tensor, IRFullTensor): raise TypeError(f"Expected IRFullTensor but got {full_tensor}") @@ -524,21 +524,21 @@ def __init__(self, full_tensor: IRTensor, indices, val_map: Optional[ValueMap] = self._full_tensor = full_tensor # the index from full_tensor - self._index_map = _to_index_map(indices) + self._indmap = _to_indmap(indmap) # val map - self._val_map = _to_value_map(val_map) + self._valmap = _to_value_map(valmap) def __eq__(self, other): if isinstance(other, IRFullTensor): return self.parent == other and \ self.shape == other.shape and \ - self.val_map == ValueMap(0, 1) + self.valmap == ValueMap(0, 1) if isinstance(other, IRSubTensor): return self.parent == other.parent and \ - self.indices == other.indices and \ - self.val_map == other.val_map and \ + self.indmap == other.indmap and \ + self.valmap == other.valmap and \ self.shape == other.shape return False @@ -550,15 +550,15 @@ def parent(self) -> IRFullTensor: return self._full_tensor @property - def indices(self) -> IndexMap: + def indmap(self) -> IndexMap: """ - Return indices list mapped to the full tensor + Return indmap list mapped to the full tensor """ - return copy.copy(self._index_map) + return copy.copy(self._indmap) @property - def val_map(self): - return copy.copy(self._val_map) + def valmap(self): + return copy.copy(self._valmap) def __copy__(self): """ @@ -568,7 +568,7 @@ def __copy__(self): Returns: tensor """ - tensor = IRSubTensor(self.parent, self.indices, self.val_map, self._shape) + tensor = IRSubTensor(self.parent, self.indmap, self.valmap, self._shape) for key in self.__dict__: setattr(tensor, key, getattr(self, key)) # clear attached cells @@ -615,45 +615,45 @@ def get_grad(self, fcell: IRCell): raise RuntimeError("Internal Error: ref time is 0") idx = ref_cells.index(fcell) grad = full_grad.select( - indices = self.indices, - val_map = (idx, ref_times), + indmap = self.indmap, + valmap = (idx, ref_times), shape = self.shape ) return grad.as_grad() elif self in fcell.outputs(): grad = full_grad.select( - indices = self.indices, - val_map = (0, 1), + indmap = self.indmap, + valmap = (0, 1), shape = self.shape ) return grad.as_grad() else: raise RuntimeError(f"{self} not found in cell {fcell}") - def select(self, indices: Union[Tuple, IndexMap], val_map: Union[Tuple, ValueMap, None], shape=None): + def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, None], shape=None): """ Select an IRSubTensor Args: - indices: the index of this tensor's index + indmap: the index of this tensor's index - val_map: the value operation to merge - co-located indices of SubTensors into one + valmap: the value operation to merge + co-located indmap of SubTensors into one shape: the sub_tensor shape Returns: IRSubTensor """ - sub_ind_map = _to_index_map(indices) - sub_val_map = _to_value_map(val_map) + sub_ind_map = _to_indmap(indmap) + sub_valmap = _to_value_map(valmap) # index mapping - index_map = self.indices.map(sub_ind_map) + index_map = self.indmap.map(sub_ind_map) # value mapping - val_map = self.val_map.map(sub_val_map) + valmap = self.valmap.map(sub_valmap) - sub_tensor = self.parent.select(index_map, val_map, shape) + sub_tensor = self.parent.select(index_map, valmap, shape) return sub_tensor def overlap(self, other): @@ -671,8 +671,8 @@ def overlap(self, other): elif isinstance(other, IRSubTensor): if self.parent != other.parent: return False - return self.indices.overlap(other.indices) and \ - self.val_map.overlap(other.val_map) + return self.indmap.overlap(other.indmap) and \ + self.valmap.overlap(other.valmap) else: raise TypeError("Customized IRTensor not support") @@ -691,12 +691,12 @@ def common(self, other): if isinstance(other, IRFullTensor): return self elif isinstance(other, IRSubTensor): - indices = self.indices & other.indices - val_map = self.val_map & other.val_map + indmap = self.indmap & other.indmap + valmap = self.valmap & other.valmap sub_tensor = self.parent.select( - indices = indices, - val_map = val_map, - shape = indices.shape + indmap = indmap, + valmap = valmap, + shape = indmap.shape ) return sub_tensor else: @@ -704,5 +704,5 @@ def common(self, other): return None def __repr__(self): - dscp = f'SubTensor(id={self._id}, shape={self.shape}, device={self.device}, ind={self.indices}, val={self.val_map})' + dscp = f'SubTensor(id={self._id}, shape={self.shape}, device={self.device}, ind={self.indmap}, val={self.valmap})' return dscp \ No newline at end of file diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 3f1ad947..aa61c1f4 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -29,18 +29,18 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr ) inputs = list() - indices = list() + indmap = list() for idx, input in enumerate(input_tensors): if torch.is_tensor(input) and input.requires_grad: inputs.append(input) - indices.append(idx) + indmap.append(idx) grads = [None] * len(input_tensors) if len(inputs) != 0: # print(f'{torch.distributed.get_rank()}: backwarding... ') in_grads = torch.autograd.grad( output_tensors, inputs, output_tensor_grads, allow_unused=True) - for idx, grad in zip(indices, in_grads): + for idx, grad in zip(indmap, in_grads): tensor = input_tensors[idx] if isinstance(tensor, torch.nn.Parameter): if tensor.grad is not None: @@ -55,7 +55,7 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr # grad_tensors=output_tensor_grads, # inputs=inputs # ) - # for idx, tensor in zip(indices, inputs): + # for idx, tensor in zip(indmap, inputs): # grads[idx] = tensor.grad # torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) diff --git a/cube/runtime/transform.py b/cube/runtime/transform.py index f3976501..63fe24b4 100644 --- a/cube/runtime/transform.py +++ b/cube/runtime/transform.py @@ -7,12 +7,12 @@ def select(tensor: torch.Tensor, - indices: Tuple[slice], val_map: Tuple[int, int]) -> torch.Tensor: + indmap: Tuple[slice], valmap: Tuple[int, int]) -> torch.Tensor: with torch.no_grad(): - sub_tensor = tensor[indices] - if val_map != (0, 1): - sub_tensor = sub_tensor / val_map[1] + sub_tensor = tensor[indmap] + if valmap != (0, 1): + sub_tensor = sub_tensor / valmap[1] sub_tensor = sub_tensor.contiguous() return sub_tensor diff --git a/cube/schedule/adapter/transform.py b/cube/schedule/adapter/transform.py index fee89af9..358eb9c6 100644 --- a/cube/schedule/adapter/transform.py +++ b/cube/schedule/adapter/transform.py @@ -77,10 +77,10 @@ def is_identity(self): class SelectPrim: - def __init__(self, tensor: IRSubTensor, indices: IndexMap, val_map: ValueMap, shape: List[int]): + def __init__(self, tensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, shape: List[int]): self.tensor = tensor - self.indices = indices - self.val_map = val_map + self.indmap = indmap + self.valmap = valmap self.shape = shape self.output = None @@ -88,7 +88,7 @@ def set_output(self, output: IRSubTensor): self.output = output def __repr__(self): - dscp = f't{self.output._id} = select(t{self.tensor._id}, {self.indices}, {self.val_map}, {self.shape})' + dscp = f't{self.output._id} = select(t{self.tensor._id}, {self.indmap}, {self.valmap}, {self.shape})' return dscp @@ -97,13 +97,13 @@ class SelectPlan: @staticmethod def gen(input: IRSubTensor, outputs: List[IRSubTensor]) -> List[SelectPrim]: trace: List[SelectPrim] = list() - islicers: List[slice] = input.indices.get() + islicers: List[slice] = input.indmap.get() for output in outputs: if output == input: continue - oslicers: List[slice] = output.indices.get() - # indices - indices = list() + oslicers: List[slice] = output.indmap.get() + # indmap + indmap = list() for islicer, oslicer in zip(islicers, oslicers): istart, istop, istep = islicer.start, islicer.stop, islicer.step ostart, ostop, ostep = oslicer.start, oslicer.stop, oslicer.step @@ -113,20 +113,20 @@ def gen(input: IRSubTensor, outputs: List[IRSubTensor]) -> List[SelectPrim]: start = ostart - istart stop = start + ostop - ostart slicer = slice(start, stop, ostep) - indices.append(slicer) - indices = IndexMap(tuple(indices)) + indmap.append(slicer) + indmap = IndexMap(tuple(indmap)) # value map - if output.val_map == input.val_map: - val_map = ValueMap(0, 1) - elif input.val_map == ValueMap(0, 1): - val_map = output.val_map + if output.valmap == input.valmap: + valmap = ValueMap(0, 1) + elif input.valmap == ValueMap(0, 1): + valmap = output.valmap else: print('from: ', input) print('to : ', output) raise NotImplementedError( - f"Not supported value select: {input.val_map} -> {output.val_map}" + f"Not supported value select: {input.valmap} -> {output.valmap}" ) - prim = SelectPrim(input, indices, val_map, output.shape) + prim = SelectPrim(input, indmap, valmap, output.shape) prim.set_output(output) trace.append(prim) return trace @@ -145,7 +145,7 @@ def __init__(self, self.output = None # re-order tensor if isinstance(concat, int): - slicers = [tensor.indices.get()[concat] for tensor in tensors] + slicers = [tensor.indmap.get()[concat] for tensor in tensors] starts = np.array([slicer.start for slicer in slicers], dtype=int) sorted_idx = np.argsort(starts) tensors = np.array(tensors)[sorted_idx] @@ -227,11 +227,11 @@ def concat(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: return None if tensor1.parent != tensor2.parent: return None - if tensor1.val_map != tensor2.val_map: + if tensor1.valmap != tensor2.valmap: return None - indices1 = tensor1.indices.get() - indices2 = tensor2.indices.get() - indices = list() + indices1 = tensor1.indmap.get() + indices2 = tensor2.indmap.get() + indmap = list() if len(indices1) != len(indices2): return None axis = None @@ -245,21 +245,21 @@ def concat(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: return None if start1 < start2 and stop1 == start2: axis = dim - indices.append(slice(start1, stop2, step1)) + indmap.append(slice(start1, stop2, step1)) elif start1 > start2 and start1 == stop2: axis = dim - indices.append(slice(start2, stop1, step1)) + indmap.append(slice(start2, stop1, step1)) else: return None else: - indices.append(slicer1) + indmap.append(slicer1) shapes = list() for idx, (nele1, nele2) in enumerate(zip(tensor1.shape, tensor2.shape)): nele = nele1 if idx != axis else nele1 + nele2 shapes.append(nele) mtensor = tensor1.parent.select( - indices = tuple(indices), - val_map = tensor1.val_map, + indmap = tuple(indmap), + valmap = tensor1.valmap, shape = shapes ) return mtensor, axis @@ -272,12 +272,12 @@ def add(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: return None if tensor1.parent != tensor2.parent: return None - if tensor1.indices != tensor2.indices: + if tensor1.indmap != tensor2.indmap: return None - if tensor1.val_map.chunk_num != tensor2.val_map.chunk_num: + if tensor1.valmap.chunk_num != tensor2.valmap.chunk_num: return None - chunk_num = tensor1.val_map.chunk_num - idx1, idx2 = tensor1.val_map.idx, tensor2.val_map.idx + chunk_num = tensor1.valmap.chunk_num + idx1, idx2 = tensor1.valmap.idx, tensor2.valmap.idx if chunk_num % 2 != 0: return None chunk_num = int(chunk_num // 2) @@ -285,8 +285,8 @@ def add(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: return None idx = int(idx1 // 2) mtensor = tensor1.parent.select( - indices = tensor1.indices, - val_map = (idx, chunk_num), + indmap = tensor1.indmap, + valmap = (idx, chunk_num), shape = tensor1.shape ) return mtensor diff --git a/tests/algorithm/test_bmm.py b/tests/algorithm/test_bmm.py index a8393294..a161c0f2 100644 --- a/tests/algorithm/test_bmm.py +++ b/tests/algorithm/test_bmm.py @@ -50,7 +50,7 @@ def test_bmm_data_parallel(): for output in outputs: print(output) assert output.shape == [B // 4, N, P] - assert output.val_map == ValueMap(0, 1) + assert output.valmap == ValueMap(0, 1) def test_bmm_n_parallel(): @@ -101,7 +101,7 @@ def test_bmm_n_parallel(): for output in outputs: print(output) assert output.shape == [B, N // 4, P] - assert output.val_map == ValueMap(0, 1) + assert output.valmap == ValueMap(0, 1) def test_bmm_m_parallel(): @@ -152,7 +152,7 @@ def test_bmm_m_parallel(): for idx, output in enumerate(outputs): print(output) assert output.shape == [B, N, P] - assert output.val_map == ValueMap(idx, 4) + assert output.valmap == ValueMap(idx, 4) def test_bmm_p_parallel(): @@ -203,4 +203,4 @@ def test_bmm_p_parallel(): for output in outputs: print(output) assert output.shape == [B, N, P // 4] - assert output.val_map == ValueMap(0, 1) \ No newline at end of file + assert output.valmap == ValueMap(0, 1) \ No newline at end of file diff --git a/tests/algorithm/test_complex.py b/tests/algorithm/test_complex.py index 3372f24a..03d86e37 100644 --- a/tests/algorithm/test_complex.py +++ b/tests/algorithm/test_complex.py @@ -330,7 +330,7 @@ def test_complex_self_attention_head_parallel(): for idx, node in enumerate(nodes): assert node.outputs(0).shape == [L, N, E] - assert node.outputs(0).val_map == ValueMap(idx, 4) + assert node.outputs(0).valmap == ValueMap(idx, 4) assert node.kwargs['num_head'] == num_head // 4 assert node.inputs(0).shape == [L, N, E] assert node.inputs(1).shape == [3 * E // 4, E] @@ -369,7 +369,7 @@ def test_complex_self_attention_data_parallel(): for idx, node in enumerate(nodes): assert node.outputs(0).shape == [L, N // 4, E] - assert node.outputs(0).val_map == ValueMap(0, 1) + assert node.outputs(0).valmap == ValueMap(0, 1) assert node.kwargs['num_head'] == num_head assert node.inputs(0).shape == [L, N // 4, E] assert node.inputs(1).shape == [3 * E, E] @@ -410,13 +410,13 @@ def test_complex_feedforward_tensor_parallel(): for idx, node in enumerate(nodes): assert node.outputs(0).shape == [L, N, E] - assert node.outputs(0).val_map == ValueMap(idx, 4) + assert node.outputs(0).valmap == ValueMap(idx, 4) assert node.inputs(0).shape == [L, N, E] assert node.inputs(1).shape == [4 * E // 4, E] assert node.inputs(2).shape == [4 * E // 4,] assert node.inputs(3).shape == [E, 4 * E // 4] assert node.inputs(4).shape == [E,] - assert node.inputs(4).val_map == ValueMap(idx, 4) + assert node.inputs(4).valmap == ValueMap(idx, 4) def test_complex_feedforward_data_parallel(): @@ -453,7 +453,7 @@ def test_complex_feedforward_data_parallel(): for idx, node in enumerate(nodes): assert node.outputs(0).shape == [L, N // 4, E] - assert node.outputs(0).val_map == ValueMap(0, 1) + assert node.outputs(0).valmap == ValueMap(0, 1) assert node.inputs(0).shape == [L, N // 4, E] assert node.inputs(1).shape == [4 * E, E] assert node.inputs(2).shape == [4 * E,] @@ -496,7 +496,7 @@ def test_embed_shard_parallel(): shard = (stop - start) // 4 for idx, node in enumerate(nodes): assert node.outputs(0).shape == [L, N, E] - assert node.outputs(0).val_map == ValueMap(idx, 4) + assert node.outputs(0).valmap == ValueMap(idx, 4) assert node.inputs(0).shape == [L, N] assert node.inputs(1).shape == [vocab // 4, E] assert node.kwargs['start'] == start + idx * shard @@ -536,7 +536,7 @@ def test_embed_shard_parallel(): stop = semantic_op.kwargs['stop'] for idx, node in enumerate(nodes): assert node.outputs(0).shape == [L, N // 4, E] - assert node.outputs(0).val_map == ValueMap(0, 1) + assert node.outputs(0).valmap == ValueMap(0, 1) assert node.inputs(0).shape == [L, N // 4] assert node.inputs(1).shape == [vocab, E] assert node.kwargs['start'] == start diff --git a/tests/algorithm/test_linear_algo.py b/tests/algorithm/test_linear_algo.py index 345d4292..7fd58008 100644 --- a/tests/algorithm/test_linear_algo.py +++ b/tests/algorithm/test_linear_algo.py @@ -45,7 +45,7 @@ def test_linear_data_parallel(): for idx, x in enumerate(inputs): assert x.shape == [256, 1024] - assert x.indices.get()[0] == slice(256 * idx, 256 * (idx + 1), 1) + assert x.indmap.get()[0] == slice(256 * idx, 256 * (idx + 1), 1) assert not inputs[0].overlap(inputs[1]) assert not inputs[0].overlap(inputs[2]) assert not inputs[0].overlap(inputs[3]) @@ -104,16 +104,16 @@ def test_linear_column_weight(): for idx, w in enumerate(weights): assert w.shape == [250, 1024] - assert w.indices.get()[0] == slice(250 * idx, 250 * (idx + 1), 1) + assert w.indmap.get()[0] == slice(250 * idx, 250 * (idx + 1), 1) for idx, b in enumerate(biass): assert b.shape == [250] - assert b.indices.get() == (slice(250 * idx, 250 * (idx + 1), 1),) + assert b.indmap.get() == (slice(250 * idx, 250 * (idx + 1), 1),) for idx, output in enumerate(outputs): assert output.shape == [1024, 250] - assert output.indices.get()[0] == slice(0, 1024, 1) - assert output.indices.get()[1] == slice(250 * idx, 250 * (idx + 1), 1) + assert output.indmap.get()[0] == slice(0, 1024, 1) + assert output.indmap.get()[1] == slice(250 * idx, 250 * (idx + 1), 1) def test_linear_row(): @@ -167,19 +167,19 @@ def test_linear_row(): for idx, x in enumerate(inputs): assert x.shape == [1024, 256] - assert x.indices.get()[1] == slice(256 * idx, 256 * (idx + 1), 1) - assert x.val_map == ValueMap(0, 1) + assert x.indmap.get()[1] == slice(256 * idx, 256 * (idx + 1), 1) + assert x.valmap == ValueMap(0, 1) for idx, w in enumerate(weights): assert w.shape == [1000, 256] - assert w.indices.get()[1] == slice(256 * idx, 256 * (idx + 1), 1) - assert w.val_map == ValueMap(0, 1) + assert w.indmap.get()[1] == slice(256 * idx, 256 * (idx + 1), 1) + assert w.valmap == ValueMap(0, 1) for idx, b in enumerate(biass): assert b.shape == [1000,] - assert b.indices.get()[0] == slice(0, 1000, 1) - assert b.val_map == ValueMap(idx, 4) + assert b.indmap.get()[0] == slice(0, 1000, 1) + assert b.valmap == ValueMap(idx, 4) for idx, output in enumerate(outputs): assert output.shape == [1024, 1000] - assert output.val_map == ValueMap(idx, 4) + assert output.valmap == ValueMap(idx, 4) diff --git a/tests/algorithm/test_memory.py b/tests/algorithm/test_memory.py index efab4015..ed1b5176 100644 --- a/tests/algorithm/test_memory.py +++ b/tests/algorithm/test_memory.py @@ -43,4 +43,4 @@ def test_transpose_dim_parallel(): for output in node.outputs(): print(output) assert output.shape == [N, M // 4] - assert output.val_map == ValueMap(0, 1) + assert output.valmap == ValueMap(0, 1) diff --git a/tests/algorithm/test_reduce.py b/tests/algorithm/test_reduce.py index a4df3d2e..d2304616 100644 --- a/tests/algorithm/test_reduce.py +++ b/tests/algorithm/test_reduce.py @@ -40,7 +40,7 @@ def test_reduce_dim_parallel(): for output in node.outputs(): print(output) assert output.shape == [1] - assert output.val_map == ValueMap(idx, 4) + assert output.valmap == ValueMap(idx, 4) dim = 1 @@ -65,4 +65,4 @@ def test_reduce_dim_parallel(): for output in node.outputs(): print(output) assert output.shape == [256] - assert output.val_map == ValueMap(0, 1) + assert output.valmap == ValueMap(0, 1) diff --git a/tests/graph/parser/test_parse_attention.py b/tests/graph/parser/test_parse_attention.py index b90137bd..c8cdf25c 100644 --- a/tests/graph/parser/test_parse_attention.py +++ b/tests/graph/parser/test_parse_attention.py @@ -73,7 +73,7 @@ def forward(self, x): #, mask): # attn = attn.masked_fill_(mask, -100000.0) # # [N, num_heads, L, L] -> [(N * num_heads), L, L] # attn = attn.view((bs * self.num_heads), self.seq_len, self.seq_len) - attn = cube.runtime.function.tril_mask(attn, bs) + attn = cube.runtime.function.tril_mask(attn, self.num_heads) # [(N * num_heads), L, L] -> [(N * num_heads), L, L] attn = F.softmax(attn, dim=-1) diff --git a/tests/graph/test_graph_partition.py b/tests/graph/test_graph_partition.py index f3a348f5..ce9d4f10 100644 --- a/tests/graph/test_graph_partition.py +++ b/tests/graph/test_graph_partition.py @@ -130,7 +130,7 @@ def test_linear_hybrid_partition(): print('grad :', data_grad) print('grad ref:', data_grad_ref) assert data_grad == data_grad_ref - assert data_grad.val_map == ValueMap(idx, ngpus) + assert data_grad.valmap == ValueMap(idx, ngpus) print('===== algo 2 =====') for idx, su in enumerate(algosu2): @@ -141,7 +141,7 @@ def test_linear_hybrid_partition(): print('grad :', data_grad) print('grad ref:', data_grad_ref) assert data_grad == data_grad_ref - assert data_grad.val_map == ValueMap(idx, ngpus) + assert data_grad.valmap == ValueMap(idx, ngpus) print('===== algo 3 =====') for idx, su in enumerate(algosu3): @@ -152,6 +152,6 @@ def test_linear_hybrid_partition(): print('grad :', data_grad) print('grad ref:', data_grad_ref) assert data_grad == data_grad_ref - assert data_grad.val_map == ValueMap(idx, ngpus) + assert data_grad.valmap == ValueMap(idx, ngpus) assert False diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py index 13f212b6..66052ca9 100644 --- a/tests/graph/test_tensor.py +++ b/tests/graph/test_tensor.py @@ -19,18 +19,18 @@ def test_full_tensor_select(): tensor = IRFullTensor(shape=[1024,1024], name='tensor') assert len(tensor.segments()) == 0 - assert len(tensor.indices()) == 0 + assert len(tensor.indmap()) == 0 assert len(tensor.val_maps()) == 0 sub_tensor1 = tensor.select( - indices = (slice(0, 1024), slice(0, 512)), - val_map = None, + indmap = (slice(0, 1024), slice(0, 512)), + valmap = None, shape = (1024, 512) ) sub_tensor2 = tensor.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) @@ -41,7 +41,7 @@ def test_full_tensor_select(): assert sub_tensor2.name == 'tensor' assert len(tensor.segments()) == 2 - assert len(tensor.indices()) == 2 + assert len(tensor.indmap()) == 2 assert len(tensor.val_maps()) == 2 @@ -49,19 +49,19 @@ def test_full_tensor_overlap(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( - indices = (slice(0, 1024), slice(256, 1024)), - val_map = None, + indmap = (slice(0, 1024), slice(256, 1024)), + valmap = None, shape = (1024, 768) ) sub_tensor2 = tensor1.select( - indices = (slice(0, 1024, 2), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 1024, 2), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) sub_tensor3 = tensor1.select( - indices = (slice(1, 1024, 2), slice(512, 1024)), - val_map = None, + indmap = (slice(1, 1024, 2), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) @@ -79,25 +79,25 @@ def test_sub_tensor_select(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) sub_tensor2 = sub_tensor1.select( - indices = (slice(512, 1024), slice(0, 256)), - val_map = None, + indmap = (slice(512, 1024), slice(0, 256)), + valmap = None, shape = (512, 256) ) sub_tensor3 = sub_tensor1.select( - indices = (slice(512, 1024), slice(256, 512)), - val_map = None, + indmap = (slice(512, 1024), slice(256, 512)), + valmap = None, shape = (512, 256) ) - indices = sub_tensor2.indices.get() - assert indices == (slice(512, 1024, 1), slice(512, 768, 1)) - indices = sub_tensor3.indices.get() - assert indices == (slice(512, 1024, 1), slice(768, 1024, 1)) + indmap = sub_tensor2.indmap.get() + assert indmap == (slice(512, 1024, 1), slice(512, 768, 1)) + indmap = sub_tensor3.indmap.get() + assert indmap == (slice(512, 1024, 1), slice(768, 1024, 1)) assert len(tensor1.segments()) == 3 assert sub_tensor1 in tensor1.segments() @@ -109,18 +109,18 @@ def test_sub_tensor_ind_overlap(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) sub_tensor2 = sub_tensor1.select( - indices = (slice(512, 1024), slice(0, 256)), - val_map = None, + indmap = (slice(512, 1024), slice(0, 256)), + valmap = None, shape = (512, 256) ) sub_tensor3 = sub_tensor1.select( - indices = (slice(512, 1024), slice(256, 512)), - val_map = None, + indmap = (slice(512, 1024), slice(256, 512)), + valmap = None, shape = (512, 256) ) @@ -132,23 +132,23 @@ def test_sub_tensor_ind_overlap(): def test_sub_tensor_val_overlap(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) sub_tensor2 = tensor1.select( - indices = (slice(0, 1024), slice(0, 512)), - val_map = (0, 4), + indmap = (slice(0, 1024), slice(0, 512)), + valmap = (0, 4), shape = (1024, 512) ) sub_tensor3 = tensor1.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = (0, 4), + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = (0, 4), shape = (1024, 512) ) sub_tensor4 = tensor1.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = (1, 4), + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = (1, 4), shape = (1024, 512) ) @@ -163,23 +163,23 @@ def test_sub_tensor_common(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor_col1 = tensor1.select( - indices = (slice(0, 1024), slice(0, 512)), - val_map = None, + indmap = (slice(0, 1024), slice(0, 512)), + valmap = None, shape = (1024, 512) ) sub_tensor_col2 = tensor1.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) sub_tensor_row1 = tensor1.select( - indices = (slice(0, 512), slice(0, 1024)), - val_map = None, + indmap = (slice(0, 512), slice(0, 1024)), + valmap = None, shape = (512, 1024) ) sub_tensor_row2 = tensor1.select( - indices = (slice(512, 1024), slice(0, 1024)), - val_map = None, + indmap = (slice(512, 1024), slice(0, 1024)), + valmap = None, shape = (512, 1024) ) @@ -188,17 +188,17 @@ def test_sub_tensor_common(): lb = sub_tensor_row2.common(sub_tensor_col1) rb = sub_tensor_row2.common(sub_tensor_col2) - assert lt.indices.get() == (slice(0, 512, 1), slice(0, 512, 1)) - assert rt.indices.get() == (slice(0, 512, 1), slice(512, 1024, 1)) - assert lb.indices.get() == (slice(512, 1024, 1), slice(0, 512, 1)) - assert rb.indices.get() == (slice(512, 1024, 1), slice(512, 1024, 1)) + assert lt.indmap.get() == (slice(0, 512, 1), slice(0, 512, 1)) + assert rt.indmap.get() == (slice(0, 512, 1), slice(512, 1024, 1)) + assert lb.indmap.get() == (slice(512, 1024, 1), slice(0, 512, 1)) + assert rb.indmap.get() == (slice(512, 1024, 1), slice(512, 1024, 1)) def test_sub_tensor_as_grad(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) @@ -206,8 +206,8 @@ def test_sub_tensor_as_grad(): assert sub_tensor1.is_grad() sub_tensor2 = tensor1.select( - indices = (slice(0, 1024), slice(0, 512)), - val_map = (0, 4), + indmap = (slice(0, 1024), slice(0, 512)), + valmap = (0, 4), shape = (1024, 512) ) assert sub_tensor2.is_grad() @@ -216,13 +216,13 @@ def test_sub_tensor_as_grad(): def test_sub_tensor_copy(): tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') sub_tensor1 = tensor1.select( - indices = (slice(0, 1024), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 1024), slice(512, 1024)), + valmap = None, shape = (1024, 512) ) sub_tensor2 = tensor1.select( - indices = (slice(0, 1024), slice(0, 512)), - val_map = (0, 4), + indmap = (slice(0, 1024), slice(0, 512)), + valmap = (0, 4), shape = (1024, 512) ) sub_tensor1.grads = [sub_tensor2] diff --git a/tests/graph/test_tensor_grad.py b/tests/graph/test_tensor_grad.py index 227899c5..e1948814 100644 --- a/tests/graph/test_tensor_grad.py +++ b/tests/graph/test_tensor_grad.py @@ -87,8 +87,8 @@ def test_tensor_grad(): for pten in all_parent_tensors: assert pten.grad is None print(pten.name, pten) - cell_ids = [cell._id for cell in pten.forward_dst_cells()] - print('forward_dst_cells id:', cell_ids) + cell_ids = [cell._id for cell in pten.consumers] + print('consumers id:', cell_ids) print('') print('test grad:') @@ -96,37 +96,37 @@ def test_tensor_grad(): input = linear1.inputs(0) assert input.grad is None gin = input.get_grad(linear1) - assert gin.val_map == ValueMap(0, 1) + assert gin.valmap == ValueMap(0, 1) print(gin.name, gin) weight = linear1.inputs(1) gw = weight.get_grad(linear1) - assert gw.val_map == ValueMap(0, 2) + assert gw.valmap == ValueMap(0, 2) print(gw.name, gw) weight = linear4.inputs(1) gw = weight.get_grad(linear4) - assert gw.val_map == ValueMap(1, 2) + assert gw.valmap == ValueMap(1, 2) print(gw.name, gw) out2 = linear2.outputs(0) gout2 = out2.get_grad(linear2) print(gout2.name, gout2) - assert gout2.val_map == ValueMap(0, 1) + assert gout2.valmap == ValueMap(0, 1) gout2 = out2.get_grad(linear3) print(gout2.name, gout2) - assert gout2.val_map == ValueMap(0, 2) + assert gout2.valmap == ValueMap(0, 2) gout2 = out2.get_grad(add5) print(gout2.name, gout2) - assert gout2.val_map == ValueMap(1, 2) + assert gout2.valmap == ValueMap(1, 2) out3 = linear3.outputs(0) gout3 = out3.get_grad(linear3) print(gout3.name, gout3) - assert gout3.val_map == ValueMap(0, 1) + assert gout3.valmap == ValueMap(0, 1) gout3 = out3.get_grad(add5) print(gout3.name, gout3) - assert gout3.val_map == ValueMap(0, 1) + assert gout3.valmap == ValueMap(0, 1) for node in graph.nodes(): assert node.mirror is None @@ -147,7 +147,7 @@ def test_tensor_grad(): assert gw1.parent == gw4.parent assert gw1.shape == gw4.shape - assert gw1.indices == gw4.indices - assert gw1.val_map != gw4.val_map + assert gw1.indmap == gw4.indmap + assert gw1.valmap != gw4.valmap # assert False diff --git a/tests/runtime/rollsplit.py b/tests/runtime/rollsplit.py index 20eb29e2..21dc065a 100644 --- a/tests/runtime/rollsplit.py +++ b/tests/runtime/rollsplit.py @@ -34,7 +34,7 @@ def test_roll_parallel(): CudaTimer().start(field_name='roll_halo') for _ in range(1000): roll_out = cube.runtime.function.roll_dim_parallel( - input, (9 // 2), 1, group + input, (9 // 2), 1, list(range(world_size)), group ) CudaTimer().stop(field_name='roll_halo') ref1 = roll_out @@ -62,6 +62,49 @@ def test_roll_parallel(): print('correctness test passed') +def test_roll_grid_parallel(): + # input size + group = None + world_size = torch.distributed.get_world_size(group=group) + myrank = torch.distributed.get_rank() + assert world_size == 4 + # input_size = [1, 224 // 2, 224 // 2, 256] + # input = torch.randn(input_size).cuda() * 10 + input = torch.arange(myrank * 4, (myrank + 1) * 4).view(1, 2, 2, 1).float() + input = input.cuda() + print_each_rank(f'input: {input.view(-1)}') + + CudaTimer().warmup(seconds=2) + + torch.distributed.barrier() + CudaTimer().start(field_name='roll_halo') + for _ in range(1): + roll_out = cube.runtime.function.roll_grid_parallel( + input, (9 // 2, 9 // 2), (1, 2), 2, 2, group + ) + CudaTimer().stop(field_name='roll_halo') + ref1 = roll_out + # print_each_rank(ref1, rank_only=0) + assert roll_out.shape == input.shape + span = CudaTimer().duration(times=1000, field_name='roll_halo') + print_each_rank('span on halo exchange: {:.2f} ms'.format(span)) + + torch.distributed.barrier() + CudaTimer().start(field_name='roll_allgather') + for _ in range(1): + roll_out = cube.runtime.function.roll_dim_allgather( + input, (9 // 2), 1, group + ) + roll_out = cube.runtime.function.roll_dim_allgather( + roll_out, (9 // 2), 2, group + ) + CudaTimer().stop(field_name='roll_allgather') + ref2 = roll_out + # print_each_rank(ref2, rank_only=0) + span = CudaTimer().duration(times=1000, field_name='roll_allgather') + print_each_rank('span on allgather exchange: {:.2f} ms'.format(span)) + + def test_roll_parallel_autograd(): group = None @@ -114,7 +157,8 @@ def test_grid_collection(): if __name__ == '__main__': cube.init() - # test_roll_parallel() + test_roll_parallel() + # test_roll_grid_parallel() # test_roll_parallel_autograd() # test_grid_partition() - test_grid_collection() \ No newline at end of file + # test_grid_collection() \ No newline at end of file diff --git a/tests/schedule/test_adapter_transform.py b/tests/schedule/test_adapter_transform.py index 9b5c831a..d0f97b44 100644 --- a/tests/schedule/test_adapter_transform.py +++ b/tests/schedule/test_adapter_transform.py @@ -9,26 +9,26 @@ def test_tensor_transform_select(): tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() tensor2 = tensor1.select( - indices = (slice(0, 512), slice(0, 1024)), - val_map = (0, 1), + indmap = (slice(0, 512), slice(0, 1024)), + valmap = (0, 1), shape = [512, 1024] ) tensor3 = tensor1.select( - indices = (slice(512, 1024), slice(0, 1024)), - val_map = (0, 2), + indmap = (slice(512, 1024), slice(0, 1024)), + valmap = (0, 2), shape = [512, 1024] ) tensor4 = tensor3.select( - indices = (slice(0, 256), slice(0, 512)), - val_map = (0, 1), + indmap = (slice(0, 256), slice(0, 512)), + valmap = (0, 1), shape = [256, 512] ) tensor5 = tensor3.select( - indices = (slice(256, 512), slice(0, 512)), - val_map = (0, 1), + indmap = (slice(256, 512), slice(0, 512)), + valmap = (0, 1), shape = [256, 512] ) @@ -58,32 +58,32 @@ def test_tensor_transform_merge(): tensor0 = IRFullTensor(shape=[1024,1024], name='test1').tosub() tensor1 = tensor0.select( - indices = (slice(0, 512), slice(0, 512)), - val_map = None, + indmap = (slice(0, 512), slice(0, 512)), + valmap = None, shape = [256, 1024] ) tensor2 = tensor0.select( - indices = (slice(0, 512), slice(512, 1024)), - val_map = None, + indmap = (slice(0, 512), slice(512, 1024)), + valmap = None, shape = [256, 1024] ) tensor3 = tensor0.select( - indices = (slice(512, 1024), slice(0, 512)), - val_map = None, + indmap = (slice(512, 1024), slice(0, 512)), + valmap = None, shape = [256, 512] ) tensor4 = tensor0.select( - indices = (slice(512, 1024), slice(512, 1024)), - val_map = None, + indmap = (slice(512, 1024), slice(512, 1024)), + valmap = None, shape = [256, 512] ) tensor5 = tensor0.select( - indices = (slice(512, 1024), slice(0, 1024)), - val_map = None, + indmap = (slice(512, 1024), slice(0, 1024)), + valmap = None, shape = [256, 512] ) @@ -113,29 +113,29 @@ def test_tensor_transform_merge(): # assert False tensor6 = tensor0.select( - indices = (slice(0, 256), slice(0, 1024)), - val_map = (0, 4), + indmap = (slice(0, 256), slice(0, 1024)), + valmap = (0, 4), shape = [256, 1024] ) tensor7 = tensor0.select( - indices = (slice(0, 256), slice(0, 1024)), - val_map = (1, 4), + indmap = (slice(0, 256), slice(0, 1024)), + valmap = (1, 4), shape = [256, 1024] ) tensor8 = tensor0.select( - indices = (slice(0, 256), slice(0, 1024)), - val_map = (2, 4), + indmap = (slice(0, 256), slice(0, 1024)), + valmap = (2, 4), shape = [256, 1024] ) tensor9 = tensor0.select( - indices = (slice(0, 256), slice(0, 1024)), - val_map = (3, 4), + indmap = (slice(0, 256), slice(0, 1024)), + valmap = (3, 4), shape = [256, 1024] ) tensor10 = tensor0.select( - indices = (slice(0, 256), slice(0, 1024)), - val_map = (0, 1) + indmap = (slice(0, 256), slice(0, 1024)), + valmap = (0, 1) ) merge3 = IRTensorTransform( @@ -154,26 +154,26 @@ def test_transform_identity(): tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() tensor2 = tensor1.select( - indices = (slice(512, 1024), slice(0, 1024)), - val_map = None, + indmap = (slice(512, 1024), slice(0, 1024)), + valmap = None, shape = [512, 1024] ) tensor3 = tensor2.select( - indices = (slice(0, 256), slice(0, 1024)), - val_map = None, + indmap = (slice(0, 256), slice(0, 1024)), + valmap = None, shape = [256, 1024] ) tensor4 = tensor1.select( - indices = (slice(512, 768), slice(0, 1024)), - val_map = None, + indmap = (slice(512, 768), slice(0, 1024)), + valmap = None, shape = [256, 1024] ) tensor5 = tensor1.select( - indices = (slice(512, 768), slice(0, 1024)), - val_map = None, + indmap = (slice(512, 768), slice(0, 1024)), + valmap = None, shape = [256, 1024] ) From f88d669948f4105b5ff36ba9d6dc4f98ae70b322 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 22 Dec 2021 09:57:04 +0800 Subject: [PATCH 0485/1892] add dtype for tensor --- cube/graph/graph.py | 11 ++-------- cube/graph/parser/mapping.py | 28 +++++++++++++++++++++++ cube/graph/parser/parser.py | 15 +++++++++---- cube/graph/tensor.py | 33 +++++++--------------------- cube/ir/__init__.py | 3 ++- cube/ir/cten.py | 25 ++++++++++++++++----- cube/ir/dtype.py | 22 +++++++++++++++++++ tests/graph/parser/test_parse_mlp.py | 8 +++++++ 8 files changed, 101 insertions(+), 44 deletions(-) create mode 100644 cube/ir/dtype.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index c94ab0b6..a90ce40d 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -85,15 +85,8 @@ def __init__(self, # set parameter for node in self._nodes: for input in node.inputs(): - if isinstance(input, IRTensor): - if input.is_param(): - # parameters already set - self._parameters.append(input) - continue - if input not in input_tensors and \ - input.is_leaf(self._nodes): - input.as_param() - self._parameters.append(input) + if isinstance(input, IRTensor) and input.is_param(): + self._parameters.append(input) self.reset_dependency() def reset_dependency(self): diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index fb4999f5..d1feb6be 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -2,10 +2,13 @@ Mapping of Signature -> IROperator """ +import torch + from functools import partial import cube.graph.operator.function as function from cube.graph.operator.operator import IRFwOperation +import cube.ir as ir class Sign2Op: @@ -73,3 +76,28 @@ def map(signature: str) -> IRFwOperation: } + +class DType2IRDType: + + @staticmethod + def map(dtype: torch.dtype): + """ + Map the torch dtype to IRDType + """ + return DType2IRDType.kDtypeMap[dtype] + + kDtypeMap = { + torch.float32: ir.float32, + torch.float : ir.float32, + torch.float16: ir.float16, + torch.half : ir.float16, + torch.uint8 : ir.uint8, + torch.int8 : ir.int8, + torch.int16 : ir.int16, + torch.short : ir.int16, + torch.int32 : ir.int32, + torch.int : ir.int32, + torch.int64 : ir.int64, + torch.long : ir.int64, + torch.bool : ir.boolean + } diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index eae3c0e4..9bfe65b6 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -6,7 +6,7 @@ from cube.graph import IRFwOperation from cube.graph.tensor import IRFullTensor from cube.graph.parser.frame import Frame -from cube.graph.parser.mapping import Sign2Op +from cube.graph.parser.mapping import Sign2Op, DType2IRDType _refmodule = torch.nn.Module() @@ -40,7 +40,7 @@ def parse_module(module, # handle graph input -- Assuming all the inputs are tensors input_var_name = [input.debugName() for input in module.graph.inputs()] for index, var_name in enumerate(input_var_name[1:]): # omit self - frame.add_var(var_name, IRFullTensor(name=var_name), graph_arg=index) + frame.add_var(var_name, IRFullTensor(name=var_name, requires_grad=False), graph_arg=index) input_val = [frame.get_var(var_name) for var_name in input_var_name[1:]] # handle input shape @@ -279,8 +279,15 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: # this usually means weight (nn.Parameter in torch) if dtype == 'Tensor': - shape = list(getattr(module, label).shape) - ir_tensor = IRFullTensor(name=label, shape=shape) + tensor = getattr(module, label) + shape = list(tensor.shape) + ir_tensor = IRFullTensor( + name=label, shape=shape, + requires_grad=tensor.requires_grad, + dtype=DType2IRDType.map(tensor.dtype) + ) + if isinstance(tensor, torch.nn.Parameter): + ir_tensor.as_param() frame.add_var(var_name, ir_tensor) # symbolic attributes elif dtype in ['bool', 'int', 'float']: diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 8e8fa087..7c2c45d1 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -24,6 +24,7 @@ import math from cube.ir.cten import IRCell, IRTensor +import cube.ir as ir __all__ = ['IndexMap', 'ValueMap', 'IRFullTensor', 'IRSubTensor'] @@ -301,7 +302,7 @@ def _to_value_map(valmap: Union[Tuple, ValueMap, None]): class IRFullTensor(IRTensor): - def __init__(self, shape=None, name=None, requires_grad=True, dtype=float): + def __init__(self, shape=None, name=None, requires_grad=True, dtype=ir.float32): super().__init__(shape, name, dtype) @@ -399,14 +400,14 @@ def as_param(self): self.requires_grad = True self._is_param = True self._is_grad = False - for sub_tensor in self.ptensors + self.ctensors: - sub_tensor.as_param() + # for sub_tensor in self.ptensors + self.ctensors: + # sub_tensor.as_param() def as_grad(self): self._is_param = False self._is_grad = True - for sub_tensor in self.ptensors + self.ctensors: - sub_tensor.as_grad() + # for sub_tensor in self.ptensors + self.ctensors: + # sub_tensor.as_grad() return self def like(self): @@ -575,24 +576,6 @@ def __copy__(self): tensor._cell = list() return tensor - def as_param(self): - """ - Set the tensor as trainable parameter - """ - if not self.parent.is_param(): - self.parent.as_param() - self.requires_grad = True - self._is_param = True - self._is_grad = False - return self - - def as_grad(self): - if not self.parent.is_grad(): - self.parent.as_grad() - self._is_grad = True - self._is_param = False - return self - def get_grad(self, fcell: IRCell): """ Get gradient of this tensor which is associated by a @@ -619,14 +602,14 @@ def get_grad(self, fcell: IRCell): valmap = (idx, ref_times), shape = self.shape ) - return grad.as_grad() + return grad elif self in fcell.outputs(): grad = full_grad.select( indmap = self.indmap, valmap = (0, 1), shape = self.shape ) - return grad.as_grad() + return grad else: raise RuntimeError(f"{self} not found in cell {fcell}") diff --git a/cube/ir/__init__.py b/cube/ir/__init__.py index 96030a6e..10053459 100644 --- a/cube/ir/__init__.py +++ b/cube/ir/__init__.py @@ -1 +1,2 @@ -from cube.ir.cten import IRTensor, IRCell \ No newline at end of file +from cube.ir.cten import IRTensor, IRCell +from cube.ir.dtype import * \ No newline at end of file diff --git a/cube/ir/cten.py b/cube/ir/cten.py index a6a40bdf..e4873c6d 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -18,9 +18,10 @@ import copy from cube.ir.unique import IDGenerator +from cube.ir.dtype import IRDType -__all__ = ['IRCell', 'IRTensor'] +__all__ = ['IRCell', 'IRDType', 'IRTensor'] class IRCell: @@ -48,6 +49,7 @@ def __init__(self, self.name: str = name self.signature = signature + self._dtype = IRDType.unknown self._device = list() # source tensors @@ -228,6 +230,12 @@ def set_input(self, input_index: int, val: Any): val = copy.copy(val) # set tensor dst val.attach_cell(self) + # set input value dtype + if self._dtype != IRDType.unknown: + val.dtype = self._dtype + # set cell dtype + elif val.dtype != IRDType.unknown: + self._dtype = val.dtype self._inputs[input_index] = val return val @@ -246,6 +254,12 @@ def set_output(self, output_index: int, val: Any): if isinstance(val, IRTensor): val = copy.copy(val) val.attach_cell(self) + # set output value dtype + if self._dtype != IRDType.unknown: + val.dtype = self._dtype + # set cell dtype + elif val.dtype != IRDType.unknown: + self._dtype = val.dtype self._outputs[output_index] = val return val @@ -372,7 +386,7 @@ class IRTensor: _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad', '_dtype'] - def __init__(self, shape=None, name=None, dtype=float): + def __init__(self, shape=None, name=None, dtype=IRDType.unknown): self._id: int = IDGenerator().gen_tensor_id() self._shape: Optional(List[int]) = shape @@ -381,8 +395,7 @@ def __init__(self, shape=None, name=None, dtype=float): # device self._cell: List[IRCell] = list() - # TODO: support float16 - self._dtype = dtype + self._dtype: IRDType = dtype self._is_param = False self._is_grad = False @@ -406,10 +419,12 @@ def dtype(self): return self._dtype @dtype.setter - def dtype(self, val): + def dtype(self, val: IRDType): """ Set data type """ + if not isinstance(val, IRDType): + raise TypeError("Expected IRDType") self._dtype = val def attach_cell(self, cell: IRCell): diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py new file mode 100644 index 00000000..b70c14f1 --- /dev/null +++ b/cube/ir/dtype.py @@ -0,0 +1,22 @@ +from enum import Enum + +class IRDType(Enum): + float16 = 'float16' + float32 = 'float32' + int64 = 'int64' + int32 = 'int32' + int16 = 'int16' + int8 = 'int8' + uint8 = 'uint8' + boolean = 'bool' + unknown = 'unknown' + + +float16 = IRDType.float16 +float32 = IRDType.float32 +int64 = IRDType.int64 +int32 = IRDType.int32 +int16 = IRDType.int16 +int8 = IRDType.int8 +uint8 = IRDType.uint8 +boolean = IRDType.boolean diff --git a/tests/graph/parser/test_parse_mlp.py b/tests/graph/parser/test_parse_mlp.py index c902deac..e40f6aa4 100644 --- a/tests/graph/parser/test_parse_mlp.py +++ b/tests/graph/parser/test_parse_mlp.py @@ -3,6 +3,7 @@ import cube.graph.parser as parser from cube.ir.cten import IRTensor +import cube.ir as ir class FeedForward(nn.Module): @@ -58,5 +59,12 @@ def test_parse_module(): assert node4.successors() == [node5] assert node5.successors() == [node6] + # dtype + for node in graph.nodes(): + assert node._dtype == ir.float32 + for val in node.inputs() + node.outputs(): + if isinstance(val, IRTensor): + val.dtype == ir.float32 + assert graph.outputs(0).shape == [1024, 1000] assert False \ No newline at end of file From 2d5a867aa58a6c76b47f432d17cd27b61a66695e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 22 Dec 2021 16:10:30 +0800 Subject: [PATCH 0486/1892] add adapter for graph --- cube/graph/operator/adapter.py | 283 +++++++++++++++++++++++++++++++++ 1 file changed, 283 insertions(+) create mode 100644 cube/graph/operator/adapter.py diff --git a/cube/graph/operator/adapter.py b/cube/graph/operator/adapter.py new file mode 100644 index 00000000..2dda2a16 --- /dev/null +++ b/cube/graph/operator/adapter.py @@ -0,0 +1,283 @@ + +from typing import List, Optional, Tuple +import copy +import numpy as np + +from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap +from cube.ir.cten import IRCell + + +class SelectPrim: + + def __init__(self, tensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, shape: List[int]): + self.tensor = tensor + self.indmap = indmap + self.valmap = valmap + self.shape = shape + self.output = None + self.device = tensor.device + + def set_output(self, output: IRSubTensor): + self.output = output + + def __repr__(self): + dscp = f't{self.output._id} = select(t{self.tensor._id}, {self.indmap}, {self.valmap}, {self.shape})' + return dscp + + +class MovePrim: + + def __init__(self, tensor: IRSubTensor, from_rank: int, to_rank: int): + self.tensor = tensor + self.from_rank = from_rank + self.to_rank = to_rank + self.shape = tensor.shape + self.dtype = tensor.dtype + self.device = tensor.device + + +class MergePrim: + def __init__(self, + tensors: List[IRSubTensor], + concat: Optional[int] = None, + add: bool = False): + if not ((concat is not None) ^ (add is True)): # xor condition + raise RuntimeError("Expected concat or add") + self.tensors = tensors + self.concat = concat + self.add = add + self.output = None + # re-order tensor + if isinstance(concat, int): + slicers = [tensor.indmap.get()[concat] for tensor in tensors] + starts = np.array([slicer.start for slicer in slicers], dtype=int) + sorted_idx = np.argsort(starts) + tensors = np.array(tensors)[sorted_idx] + self.tensors = tensors.tolist() + self.device = None + + def set_output(self, output: IRSubTensor): + self.device = output.device + self.output = output + + @staticmethod + def concat(tensor1: IRSubTensor, tensor2: IRSubTensor) -> Optional[Tuple[IRSubTensor, int]]: + """ + Check if two tensor can be merged. + If they can be merged, return the merge index + """ + if not isinstance(tensor1, IRSubTensor) or not isinstance(tensor2, IRSubTensor): + raise TypeError("Expected two tensors") + if tensor1.overlap(tensor2): + return None + if tensor1.parent != tensor2.parent: + return None + if tensor1.valmap != tensor2.valmap: + return None + indices1 = tensor1.indmap.get() + indices2 = tensor2.indmap.get() + indmap = list() + if len(indices1) != len(indices2): + return None + axis = None + for dim, (slicer1, slicer2) in enumerate(zip(indices1, indices2)): + if slicer1 != slicer2: + start1, stop1, step1 = slicer1.start, slicer1.stop, slicer1.step + start2, stop2, step2 = slicer2.start, slicer2.stop, slicer2.step + if step1 != step2: + return None + if axis is not None: + return None + if start1 < start2 and stop1 == start2: + axis = dim + indmap.append(slice(start1, stop2, step1)) + elif start1 > start2 and start1 == stop2: + axis = dim + indmap.append(slice(start2, stop1, step1)) + else: + return None + else: + indmap.append(slicer1) + shapes = list() + for idx, (nele1, nele2) in enumerate(zip(tensor1.shape, tensor2.shape)): + nele = nele1 if idx != axis else nele1 + nele2 + shapes.append(nele) + mtensor = tensor1.parent.select( + indmap = tuple(indmap), + valmap = tensor1.valmap, + shape = shapes + ) + return mtensor, axis + + @staticmethod + def add(tensor1: IRSubTensor, tensor2: IRSubTensor) -> Optional[IRSubTensor]: + if not isinstance(tensor1, IRSubTensor) or not isinstance(tensor2, IRSubTensor): + raise TypeError("Expected two tensors") + if tensor1.overlap(tensor2): + return None + if tensor1.parent != tensor2.parent: + return None + if tensor1.indmap != tensor2.indmap: + return None + if tensor1.valmap.chunk_num != tensor2.valmap.chunk_num: + return None + chunk_num = tensor1.valmap.chunk_num + idx1, idx2 = tensor1.valmap.idx, tensor2.valmap.idx + if chunk_num % 2 != 0: + return None + chunk_num = int(chunk_num // 2) + if int(idx1 // 2) != int(idx2 // 2): + return None + idx = int(idx1 // 2) + mtensor = tensor1.parent.select( + indmap = tensor1.indmap, + valmap = (idx, chunk_num), + shape = tensor1.shape + ) + return mtensor + + def __repr__(self): + tensors = [f't{t._id}' for t in self.tensors] + tensors = '[' + ', '.join(tensors) + ']' + dscp = f't{self.output._id} = merge({tensors}, axis={self.concat}, add={self.add})' + return dscp + + +class IRAdapter(IRCell): + """ + Tensor Adapter for each operator. + + A Tensor Adapter has three stages: + * Select: select produced tensors + * Move: transfer the produced tensors + * Merge: merge the produced tensors + """ + def __init__(self, dst_tensor: IRSubTensor): + if not isinstance(dst_tensor, IRSubTensor): + raise RuntimeError("Expected IRSubTensor") + self.dst_tensor = dst_tensor + self._intersections = list() + + # ====== select ====== + self._select_trace = list() + + # ====== move ======= + self._move_trace = list() + + # ====== merge ======= + self._merge_trace = list() + + self._gen_select() + self._gen_move() + self._gen_merge() + + def _gen_select(self): + otensor = self.dst_tensor + odevice = otensor.device + + local, remote = list(), list() + for ptensor in otensor.parent.ptensors: + if ptensor.device == odevice: + local.append(ptensor) + else: + remote.append(ptensor) + # check local tensor + if otensor in local: + self._intersections.append(otensor) + return + # FIXME: multi producer may result overlapped region + for itensor in otensor.parent.ptensors: + if not itensor.overlap(otensor): + continue + common = otensor.common(itensor) + common.attach_cell(itensor._cell) + islicers = itensor.indmap.get() + oslicers = common.indmap.get() + # index map + indmap = list() + for islicer, oslicer in zip(islicers, oslicers): + istart, istop, istep = islicer.start, islicer.stop, islicer.step + ostart, ostop, ostep = oslicer.start, oslicer.stop, oslicer.step + if ostep % istep != 0: + raise RuntimeError("Step condition fails") + # relative offset + start = ostart - istart + stop = start + ostop - ostart + slicer = slice(start, stop, ostep) + indmap.append(slicer) + # value map + if itensor.valmap == common.valmap: + valmap = ValueMap(0, 1) + elif itensor.valmap == ValueMap(0, 1): + valmap = common.valmap + else: + print('from: ', itensor) + print('to : ', common) + raise NotImplementedError( + f"Not supported value select: {input.valmap} -> {common.valmap}" + ) + prim = SelectPrim(itensor, indmap, valmap, common.shape) + prim.set_output(common) + self._select_trace.append(prim) + self._intersections.append(common) + + def _gen_move(self): + odevice = self.dst_tensor.device + for tensor in self._intersections: + if tensor.device != odevice: + prim = MovePrim(prim, from_rank=tensor.device, to_rank=odevice) + self._move_trace.append(prim) + + def _gen_merge(self): + output = self.dst_tensor + remain_tensors = copy.copy(self._intersections) + if output in remain_tensors: + return + out = None + while out != output: + out = None + merged = False + for idx1 in range(len(remain_tensors) - 1): + for idx2 in range(idx1 + 1, len(remain_tensors)): + tensor1 = remain_tensors[idx1] + tensor2 = remain_tensors[idx2] + # try concat + out = MergePrim.concat(tensor1, tensor2) + if out is not None: + out_tensor, concat_dim = out + out = out_tensor + prim = MergePrim([tensor1, tensor2], concat_dim, False) + prim.set_output(out_tensor) + self._merge_trace.append(prim) + merged = True + break + # try add + out = MergePrim.add(tensor1, tensor2) + if out is not None: + prim = MergePrim([tensor1, tensor2], None, True) + prim.set_output(out) + self._merge_trace.append(prim) + merged = True + break + if merged: + remain_tensors.remove(tensor1) + remain_tensors.remove(tensor2) + remain_tensors.append(out) + break + # cannot merge or add + if out is None: + raise RuntimeError("Merge Plan not found") + + def is_identity(self): + """ + Check if the adapter does nothing + + Returns: + Boolean + """ + if len(self._select_trace) == 0 and \ + len(self._move_trace) == 0 and \ + len(self._merge_trace) == 0: + return True + else: + return False From 732b58730d2dffc92920381e493a302f6b65c0dc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 22 Dec 2021 20:02:48 +0800 Subject: [PATCH 0487/1892] backward gen update --- cube/graph/gpass.py | 40 ++-------- cube/graph/graph.py | 97 +++++++++-------------- cube/graph/operator/operator.py | 133 +++++++++++--------------------- cube/graph/parser/converter.py | 19 +++++ cube/graph/tensor.py | 17 +++- cube/ir/cten.py | 50 ++++++------ 6 files changed, 141 insertions(+), 215 deletions(-) diff --git a/cube/graph/gpass.py b/cube/graph/gpass.py index 614a1ba6..9af2a4bd 100644 --- a/cube/graph/gpass.py +++ b/cube/graph/gpass.py @@ -3,7 +3,7 @@ from cube.graph.graph import IRGraph from cube.graph.tensor import IRSubTensor, ValueMap -from cube.graph.operator import IRBpOperation +from cube.graph.operator import IRFwOperation, IRBpOperation from cube.ir.cten import IRCell, IRTensor @@ -60,14 +60,13 @@ def forward(graph, *args) -> IRGraph: gener.set_map(input, arg) fnodes = list() - bnodes = list() # generate forward nodes for node in graph.nodes(): inputs = node.inputs() outputs = node.outputs() # fnode = copy.copy(node) - fnode = node + fnode : IRFwOperation = node fnode._inputs = inputs fnode._outputs = outputs # set forward inputs @@ -77,37 +76,10 @@ def forward(graph, *args) -> IRGraph: for idx, val in enumerate(outputs): fnode.set_output(idx, gener.renew(val)) fnodes.append(fnode) - fnode.device = node.device - - # generate backward nodes - for fnode in fnodes: - inputs = fnode.inputs() - outputs = fnode.outputs() - bnode = IRBpOperation(data_num=len(inputs), grad_num=len(outputs)) - # set backward grad - for idx, val in enumerate(fnode.inputs()): - grad = None - if isinstance(val, IRSubTensor): - # TODO: requires_grad = False should be set to None - grad = val.get_grad(fnode) - val.grad = grad - # set input - bnode.set_data(idx, val) - # set gradient output - bnode.set_output(idx, grad) - for idx, val in enumerate(fnode.outputs()): - # set gradient input - grad = None - if isinstance(val, IRSubTensor): - # TODO: requires_grad = False should be set to None - grad = val.get_grad(fnode) - val.grad = grad - bnode.set_grad(idx, grad) - bnode.device = node.device - - # mirror node for forward / backward - IRCell.make_pair(fnode, bnode) - bnodes.append(bnode) + + # reverse is only to make op id looks consecutive + for fnode in graph.nodes()[::-1]: + fnode.gen_backward() inputs = [gener.renew(input) for input in graph.inputs()] outputs = [gener.renew(output) for output in graph.outputs()] diff --git a/cube/graph/graph.py b/cube/graph/graph.py index a90ce40d..b2b8d686 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -9,10 +9,11 @@ from typing import Union, Tuple, List, Optional, Dict import copy + from cube.graph.operator.operator import IRBpOperation from cube.ir.cten import IRTensor, IRCell -from cube.graph.tensor import IRFullTensor, IRSubTensor +from cube.graph.tensor import IRSubTensor from cube.algorithm.generics import GenericDistAlgo @@ -29,54 +30,27 @@ class IRGraph(IRCell): def __init__(self, nodes: List[IRCell], - input_tensors: Optional[List[IRTensor]], - output_tensors: Optional[List[IRTensor]], + inputs: Optional[List[IRTensor]], + outputs: Optional[List[IRTensor]], module_name: str): self._nodes: List[IRCell] = nodes self._parameters = list() - if input_tensors is None: - input_tensors = list() + if inputs is None: inputs = IRCell.get_inputs(nodes) - for input in inputs: - if not input.is_param(): - input_tensors.append(input) - if output_tensors is None: - output_tensors = list() + inputs = [t for t in inputs if not t.is_param()] + if outputs is None: outputs = IRCell.get_outputs(nodes) - for output in outputs: - if not output.is_param(): - output_tensors.append(output) + outputs = [t for t in outputs if not t.is_param()] super().__init__( name=module_name, signature=module_name, - input_length=len(input_tensors), - output_length=len(output_tensors) + input_length=len(inputs), + output_length=len(outputs) ) - # convert to SubTensor - inputs = list() - for tensor in input_tensors: - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - inputs.append(tensor) - outputs = list() - for tensor in output_tensors: - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - outputs.append(tensor) - for node in self.nodes(): - for idx, tensor in enumerate(node.inputs()): - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - node.set_input(idx, tensor) - for idx, tensor in enumerate(node.outputs()): - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - node.set_output(idx, tensor) - for idx, tensor in enumerate(inputs): self.set_input(idx, tensor) for idx, tensor in enumerate(outputs): @@ -247,6 +221,8 @@ def subgraph(self, sub_nodes: List[IRCell]): ) return graph + ## Parallel Policy Primitives ## + def replicate(self, op: IRCell, times=1): """ Replicate an operation with multiple times. @@ -284,8 +260,6 @@ def replicate(self, op: IRCell, times=1): self.reset_dependency() return ops - ## Primitives for policy expression ## - def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: """ Policy primitive. Partition an operator by using @@ -375,32 +349,31 @@ def merge(self, sub_graph, target_op, op_partition_algorithm): def identity(self, input_tensor, dst_op): raise NotImplementedError + ## Assign Policy Primitives ## + + def assign(self, op: IRCell, rank: int): + if op not in self._nodes: + raise KeyError(f"{op} is not in the graph") + if not isinstance(rank, int): + raise TypeError("Expected rank to be int") + op.device = rank + if op.mirror is not None: + op.mirror.device = rank + return True + + ## Schedule Policy Primitives ## + + def set_order(self, seq: List[IRCell]): + raise NotImplementedError + + def partial_set_order(self, seq: List[IRCell], lazy=False): + raise NotImplementedError + + def __repr__(self): dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs - inputs = list() - for tensor in self.inputs(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - inputs.append(f'{anno}{tensor._id}') - else: - inputs.append(tensor) - outputs = list() - for tensor in self.outputs(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - outputs.append(f'{anno}{tensor._id}') - else: - outputs.append(tensor) - dscp += f"Inputs: {inputs}\n" + dscp += f"Inputs: {self.inputs()}\n" # nodes for node in self._nodes: succ_node_ids = [None] * len(node.outputs()) @@ -409,5 +382,5 @@ def __repr__(self): succ_node_ids[out_idx] = node_list dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" # outputs - dscp += f"\nOutputs: {outputs}\n{'=' * len(self.name)}\n" + dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" return dscp diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 92a3b52a..94091b03 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,7 +1,7 @@ from typing import Any, Optional, Union, List import copy -from cube.ir.cten import IRTensor, IRCell +from cube.ir.cten import IRCell from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory @@ -27,7 +27,7 @@ def __init__(self, """ # additional argument self.kwargs = dict() - super().__init__(name, signature, input_length, output_length) + super().__init__(name, signature, input_length, output_length, init_outputs=False) outputs = [IRFullTensor() for _ in range(output_length)] for idx, output in enumerate(outputs): self.set_output(idx, output) @@ -92,46 +92,29 @@ def replicate(self): cpy.clear_successor() return cpy - def __repr__(self): - inputs = list() - for tensor in self.inputs(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - if isinstance(tensor, IRFullTensor): - pid = tensor._id - valmap = (0,1) - else: - pid = tensor.parent._id - valmap = tensor.valmap - inputs.append(f'{anno}{tensor._id}(p{pid},{tensor.shape},{valmap})') - else: - inputs.append(tensor) - - outputs = list() - for tensor in self.outputs(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - if isinstance(tensor, IRFullTensor): - pid = tensor._id - valmap = (0,1) - else: - pid = tensor.parent._id - valmap = tensor.valmap - pid = tensor.parent._id if hasattr(tensor, 'parent') else tensor._id - outputs.append(f'{anno}{tensor._id}(p{pid},{tensor.shape},{valmap})') - else: - outputs.append(tensor) + def gen_backward(self): + if self.mirror is not None: + raise RuntimeError( + "Backward Op already generated. Use self.mirror.update() instead.") + bnode = IRBpOperation( + data_num=len(self.inputs()), + grad_num=len(self.outputs()) + ) + for idx, input in enumerate(self.inputs()): + grad = None + if isinstance(input, IRSubTensor): + grad = input.get_grad(self) + bnode.set_data(idx, input) + bnode.set_output(idx, grad) + for idx, output in enumerate(self.outputs()): + grad = output.get_grad(self) + bnode.set_grad(idx, grad) + IRCell.make_pair(self, bnode) + return bnode + def __repr__(self): sign = self.signature.split('.')[-1] - dscp = f'Op{self._id}(sign={sign}, inputs={inputs}, outputs={outputs})' + dscp = f'FwOp{self._id}-{self.device}(sign={sign}, inputs={self.inputs()}, outputs={self.outputs()})' return dscp @@ -144,7 +127,8 @@ def __init__(self, data_num, grad_num, name='backward'): super().__init__( name, signature, input_length=data_num + grad_num, - output_length=data_num + output_length=data_num, + init_outputs=False ) def replicate(self): @@ -197,7 +181,9 @@ def set_data(self, input_index: int, val: Any): def set_grad(self, input_index: int, val: Any): """ - Set the node gradient at input index + Set the node gradient at input index. + The grad is same order with corresponding output tensor + of it's forward tensor Args: input_idx: input index @@ -212,48 +198,25 @@ def set_grad(self, input_index: int, val: Any): ) return self.set_input(input_index + self.data_num, val) - def __repr__(self): - datas = list() - for tensor in self.datas(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - # datas.append(f'{anno}{tensor._id}') - datas.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.valmap})') - else: - datas.append(tensor) - - grads = list() - for tensor in self.grads(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - # grads.append(f'{anno}{tensor._id}') - grads.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.valmap})') - else: - grads.append(tensor) - - outputs = list() - for tensor in self.outputs(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - # outputs.append(f'{anno}{tensor._id}') - outputs.append(f'{anno}{tensor._id}(p{tensor.parent._id},{tensor.shape},{tensor.valmap})') - else: - outputs.append(tensor) + def update(self): + """ + Update this backward operator. + This neccessary when op is partitioned and reference count is changed. + """ + fnode = self.mirror + for idx, input in enumerate(fnode.inputs()): + grad = None + if isinstance(input, IRSubTensor): + grad = input.get_grad(self) + self.set_data(idx, input) + self.set_output(idx, grad) + for idx, output in enumerate(self.outputs()): + grad = output.get_grad(self) + self.set_grad(idx, grad) + def __repr__(self): sign = self.signature.split('.')[-1] - dscp = f'bOp{self._id}(sign={sign}, grads={grads}, datas={datas}, outputs={outputs})' + dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, grads={self.grads()}, datas={self.datas()}, outputs={self.outputs()})' return dscp @@ -314,11 +277,7 @@ def algorithms(self, tag: Optional[str] = None): return template(self) def __repr__(self): - outputs = list() - for t in self.outputs(): - name = f't{t._id}(p{t.parent._id},{t.shape},{t.valmap})' - outputs.append(name) - dscp = f'DataLoader-{self._id}(outputs={outputs})' + dscp = f'DataLoader-{self._id}(outputs={self.outputs()})' return dscp diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 42e22015..f2ae946d 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,6 +1,7 @@ from typing import Optional, List from cube.ir.cten import IRTensor +from cube.graph.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser from cube.graph import IRGraph @@ -20,5 +21,23 @@ def convert(model: torch.nn.Module, for input in inputs: if isinstance(input, IRTensor): input.requires_grad = False + # convert to SubTensor + for idx, tensor in enumerate(inputs): + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + inputs[idx] = tensor + for idx, tensor in enumerate(outputs): + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + outputs[idx] = tensor + for node in nodes: + for idx, tensor in enumerate(node.inputs()): + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + node.set_input(idx, tensor) + for idx, tensor in enumerate(node.outputs()): + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + node.set_output(idx, tensor) graph = IRGraph(nodes, inputs, outputs, module_name) return graph diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 7c2c45d1..d043c4e3 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -573,7 +573,7 @@ def __copy__(self): for key in self.__dict__: setattr(tensor, key, getattr(self, key)) # clear attached cells - tensor._cell = list() + tensor._cell = None return tensor def get_grad(self, fcell: IRCell): @@ -582,7 +582,7 @@ def get_grad(self, fcell: IRCell): forward cell """ if not self.requires_grad: - raise RuntimeError("require a gradient for a non-grad tensor") + return None full_grad = self.parent.grad if full_grad is None: return None @@ -687,5 +687,14 @@ def common(self, other): return None def __repr__(self): - dscp = f'SubTensor(id={self._id}, shape={self.shape}, device={self.device}, ind={self.indmap}, val={self.valmap})' - return dscp \ No newline at end of file + anno = 't' + if self.is_param(): + anno = 'w' + if self.is_grad(): + anno = 'g' + dscp = f'{anno}{self._id}(p{self.parent._id},{self.shape},{self.valmap})' + return dscp + + def extra_repr(self): + dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device}, ind={self.indmap}, val={self.valmap})' + return dscp diff --git a/cube/ir/cten.py b/cube/ir/cten.py index e4873c6d..c50df996 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -33,7 +33,8 @@ def __init__(self, name: str, signature: str, input_length: int, - output_length: int): + output_length: int, + init_outputs = True): """ Create a node with name (variable name) and module type (module_name) @@ -56,9 +57,11 @@ def __init__(self, self._inputs: List[Any] = [None] * input_length # destination tensors - self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] - for tensor in self._outputs: - tensor.attach_cell(self) + self._outputs: List[IRTensor] = [None] * output_length + if init_outputs: + self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] + for tensor in self._outputs: + tensor.attach_cell(self) # destination cells # -- will only be set when initializing to a graph @@ -393,7 +396,7 @@ def __init__(self, shape=None, name=None, dtype=IRDType.unknown): self.name = name if name else 'tensor' # device - self._cell: List[IRCell] = list() + self._cell: Optional[IRCell] = None self._dtype: IRDType = dtype @@ -433,19 +436,23 @@ def attach_cell(self, cell: IRCell): """ if not isinstance(cell, IRCell): raise TypeError("Expected an IRCell") - if cell not in self._cell: - self._cell.append(cell) + self._cell = cell - def detach_cell(self, cell: IRCell): + def detach_cell(self): """ - Detach from a cell, when removing from cell's input - and output + Detach from a cell """ - if not isinstance(cell, IRCell): - raise TypeError("Expected an IRCell") - if cell not in self._cell: - raise RuntimeError("the target cell not in the attached list") - self._cell.remove(cell) + self._cell = None + + @property + def device(self) -> List[int]: + return self._cell.device + + @device.setter + def device(self, val: Union[int, List[int]]): + raise RuntimeError( + "tensor placement is not allowed to set manually" + ) @property def requires_grad(self): @@ -543,19 +550,6 @@ def shape(self, val): raise RuntimeError("Expected shape to be list[int]") self._shape = copy.copy(list(val)) - @property - def device(self) -> List[int]: - device = set() - for cell in self._cell: - device.update(set(cell.device)) - return list(device) - - @device.setter - def device(self, device_id: Union[int, List[int]]): - raise RuntimeError( - "tensor placement is not allowed to set manually" - ) - def src(self, cells: List[IRCell]) -> List[IRCell]: """ Return all the cells that will generate this tensor From 5b0849c13408e2e93d75796d093a24cfc886064f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Dec 2021 10:27:31 +0800 Subject: [PATCH 0488/1892] makeing P and A in same graph --- cube/__init__.py | 4 +- cube/graph/{gpass.py => _gpass.py} | 0 cube/graph/graph.py | 3 +- cube/graph/operator/operator.py | 2 +- cube/graph/parser/__init__.py | 2 +- cube/graph/parser/converter.py | 12 +++- cube/graph/tensor.py | 2 + cube/ir/cten.py | 4 +- cube/logics/__init__.py | 0 cube/logics/dataloader.py | 40 +++++++++++++ cube/logics/model.py | 93 ++++++++++++++++++++++++++++++ cube/logics/pool.py | 50 ++++++++++++++++ cube/logics/translator.py | 82 ++++++++++++++++++++++++++ 13 files changed, 286 insertions(+), 8 deletions(-) rename cube/graph/{gpass.py => _gpass.py} (100%) create mode 100644 cube/logics/__init__.py create mode 100644 cube/logics/dataloader.py create mode 100644 cube/logics/model.py create mode 100644 cube/logics/pool.py create mode 100644 cube/logics/translator.py diff --git a/cube/__init__.py b/cube/__init__.py index 1f0f0c33..fdb7c42a 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,7 +1,7 @@ -from cube import schedule +from cube import logics from cube import runtime -from cube.compiler import SemanticModel, compile +# from cube.compiler import SemanticModel, compile def init(): diff --git a/cube/graph/gpass.py b/cube/graph/_gpass.py similarity index 100% rename from cube/graph/gpass.py rename to cube/graph/_gpass.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index b2b8d686..b0181ad3 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -174,7 +174,7 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: Returns: IRTensors """ - from cube.schedule.translator import LogicTranslator + from cube.logics.translator import LogicTranslator return LogicTranslator.forward(self, *args) def __call__(self, *args): @@ -357,6 +357,7 @@ def assign(self, op: IRCell, rank: int): if not isinstance(rank, int): raise TypeError("Expected rank to be int") op.device = rank + # pytorch requirement if op.mirror is not None: op.mirror.device = rank return True diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 94091b03..a6f91268 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -207,7 +207,7 @@ def update(self): for idx, input in enumerate(fnode.inputs()): grad = None if isinstance(input, IRSubTensor): - grad = input.get_grad(self) + grad = input.get_grad(fnode) self.set_data(idx, input) self.set_output(idx, grad) for idx, output in enumerate(self.outputs()): diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index 9d11b160..ded01f96 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,2 +1,2 @@ from cube.graph.parser.parser import ScriptModuleParser -from cube.graph.parser.converter import convert \ No newline at end of file +from cube.graph.parser.converter import convert_model, convert_dataloader \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index f2ae946d..dc7377d1 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -4,10 +4,11 @@ from cube.graph.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser from cube.graph import IRGraph +from cube.logics.dataloader import IRDataLoader import torch -def convert(model: torch.nn.Module, +def convert_model(model: torch.nn.Module, input_shapes: Optional[ List[List[int],] ] = None) -> IRGraph: """ Convert toch.nn.Module based model into IRGraph @@ -41,3 +42,12 @@ def convert(model: torch.nn.Module, node.set_output(idx, tensor) graph = IRGraph(nodes, inputs, outputs, module_name) return graph + + +def convert_dataloader(dataloader) -> IRDataLoader: + """ + convert pytorch dataloader into IRDataLoader + """ + from cube.graph.parser.mapping import DType2IRDType + dataloader = IRDataLoader(dataloader, dtype_map=DType2IRDType) + return dataloader diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index d043c4e3..8099a12b 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -602,6 +602,7 @@ def get_grad(self, fcell: IRCell): valmap = (idx, ref_times), shape = self.shape ) + self.grad = grad return grad elif self in fcell.outputs(): grad = full_grad.select( @@ -609,6 +610,7 @@ def get_grad(self, fcell: IRCell): valmap = (0, 1), shape = self.shape ) + self.grad = grad return grad else: raise RuntimeError(f"{self} not found in cell {fcell}") diff --git a/cube/ir/cten.py b/cube/ir/cten.py index c50df996..320b0d6d 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -427,7 +427,7 @@ def dtype(self, val: IRDType): Set data type """ if not isinstance(val, IRDType): - raise TypeError("Expected IRDType") + raise TypeError(f"Expected IRDType but got {val}") self._dtype = val def attach_cell(self, cell: IRCell): @@ -584,7 +584,7 @@ def backward(self): """ Autograd backward on the tensor """ - from cube.schedule.translator import LogicTranslator + from cube.logics.translator import LogicTranslator return LogicTranslator.backward(self) def __repr__(self): diff --git a/cube/logics/__init__.py b/cube/logics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/logics/dataloader.py b/cube/logics/dataloader.py new file mode 100644 index 00000000..cc8871b1 --- /dev/null +++ b/cube/logics/dataloader.py @@ -0,0 +1,40 @@ +import copy +from typing import Optional + +import torch + +from cube.runtime.syndata import CubeDataLoader + + +class IRDataLoader: + + def __init__(self, dataloader: CubeDataLoader, dtype_map): + self.dataloader = iter(dataloader) + self.batch_dims = dataloader.get_batch_dims() + self.dtypes = list() + self.shapes = list() + + datas = next(dataloader) + if not isinstance(datas, tuple): + datas = (datas,) + + for data in datas: + if torch.is_tensor(data): + self.dtypes.append(dtype_map.map(data.dtype)) + self.shapes.append(list(data.shape)) + else: + raise NotImplementedError("Data should be torch.Tensor") + + def get_batch_dims(self, idx: Optional[int] = None) -> int: + if idx is None: + return copy.copy(self.batch_dims) + else: + return self.batch_dims[idx] + + def __iter__(self): + return self + + def __next__(self): + from cube.logics.translator import LogicTranslator + datas = LogicTranslator.load_data(self) + return datas diff --git a/cube/logics/model.py b/cube/logics/model.py new file mode 100644 index 00000000..f6073f73 --- /dev/null +++ b/cube/logics/model.py @@ -0,0 +1,93 @@ +from typing import Any +import copy + +from cube.graph.graph import IRGraph +from cube.graph.tensor import IRSubTensor +from cube.graph.operator import IRFwOperation + +from cube.ir.cten import IRTensor + + +__all__ = ['forward'] + + +class _TensorGener: + + def __init__(self): + self.symbol = dict() + + def renew(self, val: Any, keep_param=True): + self._check_is_sub_tensor(val) + if not isinstance(val, IRTensor): + return val + if keep_param and val.is_param(): + return val + if val.parent._id not in self.symbol: + self.symbol[val.parent._id] = val.parent.like() + new_val = self.symbol[val.parent._id].select( + indmap=val.indmap, + valmap=val.valmap, + shape=val.shape + ) + return new_val + + def set_map(self, origin: Any, new: Any): + self._check_is_sub_tensor(origin) + self._check_is_sub_tensor(new) + if isinstance(origin, IRSubTensor): + tid = origin.parent._id + if isinstance(new, IRSubTensor): + self.symbol[tid] = new.parent + return + self.symbol[tid] = new + + def _check_is_sub_tensor(self, tensor): + if isinstance(tensor, IRTensor): + if not isinstance(tensor, IRSubTensor): + raise TypeError("Tensor only allows to be SubTensor") + + +def forward(graph, *args) -> IRGraph: + """ + Forward the IRGraph, replacing all the intermediate tensors + """ + if not isinstance(graph, IRGraph): + raise TypeError("Forwarding requires IRGraph") + + gener = _TensorGener() + + for input, arg in zip(graph.inputs(), args): + gener.set_map(input, arg) + + fnodes = list() + + # generate forward nodes + for node in graph.nodes(): + inputs = node.inputs() + outputs = node.outputs() + # fnode = copy.copy(node) + fnode : IRFwOperation = node + fnode._inputs = inputs + fnode._outputs = outputs + # set forward inputs + for idx, val in enumerate(inputs): + fnode.set_input(idx, gener.renew(val)) + # set forward outputs + for idx, val in enumerate(outputs): + fnode.set_output(idx, gener.renew(val)) + fnodes.append(fnode) + + # reverse is only to make op id looks consecutive + for fnode in graph.nodes()[::-1]: + fnode.gen_backward() + + inputs = [gener.renew(input) for input in graph.inputs()] + outputs = [gener.renew(output) for output in graph.outputs()] + + for idx, input in enumerate(inputs): + graph.set_input(idx, input) + for idx, output in enumerate(outputs): + graph.set_output(idx, output) + + # fgraph = IRGraph(fnodes, inputs, outputs, graph.name) + return graph diff --git a/cube/logics/pool.py b/cube/logics/pool.py new file mode 100644 index 00000000..fd9f2045 --- /dev/null +++ b/cube/logics/pool.py @@ -0,0 +1,50 @@ +from typing import List, Any +import copy + + +class SchedulePool: + + class __SchedulePool: + + def __init__(self): + + self._nodes = list() + self._tapes = dict() + + instance = None + + def __init__(self): + if not SchedulePool.instance: + SchedulePool.instance = SchedulePool.__SchedulePool() + + def __getattr__(self, name): + return getattr(self.instance, name) + + def add_node(self, node): + self.instance._nodes.append(node) + + def nodes(self) -> List: + return copy.copy(self.instance._nodes) + + def tape(self, tensor, trace: Any): + """ + Record the trace generated to this tensor + """ + self.instance._tapes[tensor._id] = trace + + def get_tape(self, tensor): + """ + Get the trace given the tensor + """ + if tensor._id not in self.instance._tapes: + return None + else: + return self.instance._tapes[tensor._id] + + def clear(self): + self.instance._nodes = list() + self.instance._tapes = dict() + + def __repr__(self): + dscp = '\n'.join([repr(node) for node in self._nodes]) + return dscp diff --git a/cube/logics/translator.py b/cube/logics/translator.py new file mode 100644 index 00000000..220d20de --- /dev/null +++ b/cube/logics/translator.py @@ -0,0 +1,82 @@ +from cube.logics.dataloader import IRDataLoader +from cube.logics import model +from cube.logics.pool import SchedulePool + +from cube.graph.graph import IRGraph +from cube.graph.tensor import IRFullTensor, IRSubTensor +from cube.graph.operator import IRDataOperation + + +class LogicTranslator: + + @staticmethod + def gen_logic_graph(): + """ + Generate Training Logic Graph + """ + nodes = SchedulePool().nodes() + graph = IRGraph(nodes, inputs=[], outputs=None, module_name='LogicGraph') + return graph + + @staticmethod + def load_data(dataloader: IRDataLoader): + """ + Translator Action: Load data from data loaderw + """ + if not isinstance(dataloader, IRDataLoader): + raise TypeError("Expected IRDataLoader") + outputs = list() + for dtype, shape in zip(dataloader.dtypes, dataloader.shapes): + data = IRFullTensor( + shape, 'data', requires_grad=False, dtype=dtype + ).tosub() + outputs.append(data) + + data_op = IRDataOperation( + data_num=len(outputs), batch_dims=dataloader.get_batch_dims(), + ) + for idx, output in enumerate(outputs): + data_op.set_output(idx, output) + + SchedulePool().add_node(data_op) + if len(outputs) == 0: return + elif len(outputs) == 1: return outputs[0] + else: return tuple(outputs) + + @staticmethod + def forward(graph, *args): + """ + Translator Action: forward an IRGraph + """ + fgraph = model.forward(graph, *args) + for node in fgraph.nodes(): + SchedulePool().add_node(node) + for output in fgraph.outputs(): + SchedulePool().tape(output, fgraph.nodes()) + outputs = fgraph.outputs() + if len(outputs) == 1: return outputs[0] + elif len(outputs) == 0: return None + else: return outputs + + @staticmethod + def backward(loss: IRSubTensor): + """ + Translator Action: backward a tensor + """ + trace = SchedulePool().get_tape(loss) + if trace is None: + raise RuntimeError("No forward detected") + # make grad to 1.0 + if not loss.shape == [1]: + raise RuntimeError("backward can only perform on the scaler tensor") + loss.parent.requires_grad = False + for node in trace: + for output in node.outputs(): + if loss.overlap(output): + node.mirror.update() + for node in trace[::-1]: + SchedulePool().add_node(node.mirror) + + @staticmethod + def update(optimizer): + raise NotImplementedError From 7396f3799215b65f2604ef482c13c6881ba8e1d3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Dec 2021 13:54:58 +0800 Subject: [PATCH 0489/1892] fix graph bugs; enable new adapter --- cube/graph/adapter/__init__.py | 1 + cube/graph/{operator => adapter}/adapter.py | 54 +++++++++++------- cube/graph/adapter/gen.py | 26 +++++++++ cube/graph/graph.py | 63 ++++++++------------- cube/graph/operator/operator.py | 30 +++++++--- 5 files changed, 105 insertions(+), 69 deletions(-) create mode 100644 cube/graph/adapter/__init__.py rename cube/graph/{operator => adapter}/adapter.py (87%) create mode 100644 cube/graph/adapter/gen.py diff --git a/cube/graph/adapter/__init__.py b/cube/graph/adapter/__init__.py new file mode 100644 index 00000000..2f1a83e2 --- /dev/null +++ b/cube/graph/adapter/__init__.py @@ -0,0 +1 @@ +from cube.graph.adapter.gen import AdapterGener diff --git a/cube/graph/operator/adapter.py b/cube/graph/adapter/adapter.py similarity index 87% rename from cube/graph/operator/adapter.py rename to cube/graph/adapter/adapter.py index 2dda2a16..d3d8cbc2 100644 --- a/cube/graph/operator/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -9,16 +9,14 @@ class SelectPrim: - def __init__(self, tensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, shape: List[int]): + def __init__(self, tensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, + shape: List[int], output: IRSubTensor): self.tensor = tensor self.indmap = indmap self.valmap = valmap self.shape = shape - self.output = None - self.device = tensor.device - - def set_output(self, output: IRSubTensor): self.output = output + self.device = tensor.device def __repr__(self): dscp = f't{self.output._id} = select(t{self.tensor._id}, {self.indmap}, {self.valmap}, {self.shape})' @@ -37,16 +35,15 @@ def __init__(self, tensor: IRSubTensor, from_rank: int, to_rank: int): class MergePrim: - def __init__(self, - tensors: List[IRSubTensor], - concat: Optional[int] = None, - add: bool = False): + def __init__(self, tensors: List[IRSubTensor], + output: IRSubTensor, device: List[int], + concat: Optional[int] = None, add: bool = False): if not ((concat is not None) ^ (add is True)): # xor condition raise RuntimeError("Expected concat or add") self.tensors = tensors self.concat = concat self.add = add - self.output = None + self.output = output # re-order tensor if isinstance(concat, int): slicers = [tensor.indmap.get()[concat] for tensor in tensors] @@ -54,10 +51,9 @@ def __init__(self, sorted_idx = np.argsort(starts) tensors = np.array(tensors)[sorted_idx] self.tensors = tensors.tolist() - self.device = None + self.device = device def set_output(self, output: IRSubTensor): - self.device = output.device self.output = output @staticmethod @@ -153,6 +149,7 @@ class IRAdapter(IRCell): * Merge: merge the produced tensors """ def __init__(self, dst_tensor: IRSubTensor): + print(f'generating adapter for: {dst_tensor}') if not isinstance(dst_tensor, IRSubTensor): raise RuntimeError("Expected IRSubTensor") self.dst_tensor = dst_tensor @@ -160,6 +157,7 @@ def __init__(self, dst_tensor: IRSubTensor): # ====== select ====== self._select_trace = list() + self._select_ptensors = list() # ====== move ======= self._move_trace = list() @@ -171,10 +169,21 @@ def __init__(self, dst_tensor: IRSubTensor): self._gen_move() self._gen_merge() + super().__init__( + name='adapter', signature='adapter', + input_length=len(self._intersections), output_length=1, + init_outputs=False + ) + for idx, ptensor in enumerate(self._select_ptensors): + self.set_input(idx, ptensor) + self.set_output(0, dst_tensor) + def _gen_select(self): otensor = self.dst_tensor odevice = otensor.device + print(f'select: produced tensors: {otensor.parent.ptensors}') + local, remote = list(), list() for ptensor in otensor.parent.ptensors: if ptensor.device == odevice: @@ -184,6 +193,7 @@ def _gen_select(self): # check local tensor if otensor in local: self._intersections.append(otensor) + self._select_ptensors.append(ptensor) return # FIXME: multi producer may result overlapped region for itensor in otensor.parent.ptensors: @@ -216,16 +226,16 @@ def _gen_select(self): raise NotImplementedError( f"Not supported value select: {input.valmap} -> {common.valmap}" ) - prim = SelectPrim(itensor, indmap, valmap, common.shape) - prim.set_output(common) + prim = SelectPrim(itensor, indmap, valmap, common.shape, common) self._select_trace.append(prim) self._intersections.append(common) + self._select_ptensors.append(itensor) def _gen_move(self): odevice = self.dst_tensor.device for tensor in self._intersections: if tensor.device != odevice: - prim = MovePrim(prim, from_rank=tensor.device, to_rank=odevice) + prim = MovePrim(tensor, from_rank=tensor.device, to_rank=odevice) self._move_trace.append(prim) def _gen_merge(self): @@ -244,18 +254,15 @@ def _gen_merge(self): # try concat out = MergePrim.concat(tensor1, tensor2) if out is not None: - out_tensor, concat_dim = out - out = out_tensor - prim = MergePrim([tensor1, tensor2], concat_dim, False) - prim.set_output(out_tensor) + out, concat_dim = out + prim = MergePrim([tensor1, tensor2], out, output.device, concat_dim, False) self._merge_trace.append(prim) merged = True break # try add out = MergePrim.add(tensor1, tensor2) if out is not None: - prim = MergePrim([tensor1, tensor2], None, True) - prim.set_output(out) + prim = MergePrim([tensor1, tensor2], out, output.device, None, True) self._merge_trace.append(prim) merged = True break @@ -281,3 +288,8 @@ def is_identity(self): return True else: return False + + def __repr__(self): + dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' + return dscp + \ No newline at end of file diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py new file mode 100644 index 00000000..527da1c6 --- /dev/null +++ b/cube/graph/adapter/gen.py @@ -0,0 +1,26 @@ + +from cube.graph.graph import IRGraph +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + + +class AdapterGener: + + @staticmethod + def gen(graph: IRGraph) -> IRGraph: + for node in graph.nodes(): + if not isinstance(node, IRFwOperation): + continue + # adapter for input + for input in node.inputs(): + if not isinstance(input, IRTensor): + continue + # skip parameter + if input.is_param(): + continue + adapter = IRAdapter(input) + if not adapter.is_identity(): + idx = graph.nodes().index(node) + graph._nodes.insert(idx, adapter) + return graph diff --git a/cube/graph/graph.py b/cube/graph/graph.py index b0181ad3..7fee670a 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -10,8 +10,9 @@ from typing import Union, Tuple, List, Optional, Dict import copy -from cube.graph.operator.operator import IRBpOperation +from numpy import isin +from cube.graph.operator.operator import IRBpOperation, IRFwOperation from cube.ir.cten import IRTensor, IRCell from cube.graph.tensor import IRSubTensor @@ -294,46 +295,26 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional raise NotImplementedError( f"Not support feature-map {input} to be splitted in value as input" ) + # remove reference - for idx in range(len(op.inputs())): - op.set_input(idx, None) - # set backward mirror node - if op.mirror is not None: - # go through related op to reset the related gradient - for fnode in fnodes: - for val in fnode.inputs(): - if not isinstance(val, IRSubTensor): - continue - # TODO: requires_grad = False should be set to None - val.grad = val.get_grad(fnode) - for related_op in val.parent.consumers: - for idx, rval in enumerate(related_op.inputs()): - if val.overlap(rval): - rval.grad = rval.get_grad(related_op) - if related_op.mirror is not None: - related_op.mirror.set_output(idx, rval.grad) - # generate mirror node - for fnode in fnodes: - bnode = IRBpOperation( - data_num=len(fnode.inputs()), - grad_num=len(fnode.outputs()) - ) - for idx, val in enumerate(fnode.inputs()): - grad = None - if isinstance(val, IRSubTensor): - grad = val.grad - bnode.set_data(idx, val) - bnode.set_output(idx, grad) - for idx, val in enumerate(fnode.outputs()): - grad = None - if isinstance(val, IRSubTensor): - # TODO: requires_grad = False should be set to None - grad = val.get_grad(fnode) - val.grad = grad - bnode.set_grad(idx, grad) - IRCell.make_pair(fnode, bnode) - fnode.device = op.device - bnode.device = op.mirror.device + finputs = op.inputs() + op.make_empty() + + # generate backward + updated = set() + for input in finputs: + if not isinstance(input, IRSubTensor): + continue + # go through related consumers and update backward op + for fnode in input.parent.consumers: + if isinstance(fnode, IRFwOperation) and fnode._id not in updated: + if fnode.mirror is not None: + fnode.mirror.update() + else: + fnode.gen_backward() + updated.add(fnode._id) + + # insert nodes idx = self._nodes.index(op) self._nodes = self._nodes[:idx] + fnodes + self._nodes[idx+1:] if op.mirror is not None: @@ -377,6 +358,8 @@ def __repr__(self): dscp += f"Inputs: {self.inputs()}\n" # nodes for node in self._nodes: + # if isinstance(node, IRBpOperation): + # continue succ_node_ids = [None] * len(node.outputs()) for out_idx in range(len(node.outputs())): node_list = [snode._id for snode in node.successors(out_idx)] diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index a6f91268..17e41dd6 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -61,22 +61,26 @@ def algorithms(self, tag: Optional[str] = None): return template(self) def set_input(self, input_index: int, val: Any): + # remove the consumer old_val = self.inputs(input_index) - # remove the old one if isinstance(old_val, IRSubTensor): old_val.parent.rm_consumer(self) + # add the consumer + val = super().set_input(input_index, val) if isinstance(val, IRSubTensor): val.parent.add_consumer(self, val) - return super().set_input(input_index, val) + return val def set_output(self, output_index: int, val: Any): + # remove the producer old_val = self.outputs(output_index) - # remove the old one if isinstance(old_val, IRSubTensor): old_val.parent.rm_producer(self) + # add the producer + val = super().set_output(output_index, val) if isinstance(val, IRSubTensor): val.parent.add_producer(self, val) - return super().set_output(output_index, val) + return val def replicate(self): """ @@ -210,12 +214,11 @@ def update(self): grad = input.get_grad(fnode) self.set_data(idx, input) self.set_output(idx, grad) - for idx, output in enumerate(self.outputs()): - grad = output.get_grad(self) + for idx, output in enumerate(fnode.outputs()): + grad = output.get_grad(fnode) self.set_grad(idx, grad) def __repr__(self): - sign = self.signature.split('.')[-1] dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, grads={self.grads()}, datas={self.datas()}, outputs={self.outputs()})' return dscp @@ -254,6 +257,17 @@ def infer_shape(self): """ return True + def set_output(self, output_index: int, val: Any): + # remove the producer + old_val = self.outputs(output_index) + if isinstance(old_val, IRSubTensor): + old_val.parent.rm_producer(self) + # add the producer + val = super().set_output(output_index, val) + if isinstance(val, IRSubTensor): + val.parent.add_producer(self, val) + return val + def algorithms(self, tag: Optional[str] = None): """ get algorithm from algorithm factory @@ -277,7 +291,7 @@ def algorithms(self, tag: Optional[str] = None): return template(self) def __repr__(self): - dscp = f'DataLoader-{self._id}(outputs={self.outputs()})' + dscp = f'DataLoader{self._id}-{self.device}(outputs={self.outputs()})' return dscp From 0533853950c2eb59a807dcae8adddd5d8db7d468 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Dec 2021 13:55:38 +0800 Subject: [PATCH 0490/1892] empty operator --- cube/ir/cten.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 320b0d6d..d6baf937 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -307,6 +307,15 @@ def clear_successor(self): list() for _ in range(len(self.outputs())) ] + def make_empty(self): + """ + Clear all inputs, outputs of this Cell + """ + for idx in range(len(self.inputs())): + self.set_input(idx, None) + for idx in range(len(self.outputs())): + self.set_output(idx, None) + @staticmethod def get_inputs(cells): """ From 2a2be3f71a868c8258a1ece1c47f4b663a0e2082 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Dec 2021 15:03:34 +0800 Subject: [PATCH 0491/1892] add adapter --- cube/graph/adapter/adapter.py | 30 +++++++++++++++++++++++------- cube/logics/translator.py | 2 +- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index d3d8cbc2..c7cf837d 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -19,7 +19,7 @@ def __init__(self, tensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, self.device = tensor.device def __repr__(self): - dscp = f't{self.output._id} = select(t{self.tensor._id}, {self.indmap}, {self.valmap}, {self.shape})' + dscp = f'{self.output} = select({self.tensor})' return dscp @@ -33,6 +33,10 @@ def __init__(self, tensor: IRSubTensor, from_rank: int, to_rank: int): self.dtype = tensor.dtype self.device = tensor.device + def __repr__(self): + dscp = f'move({self.tensor}, from={self.from_rank}, to={self.to_rank})' + return dscp + class MergePrim: def __init__(self, tensors: List[IRSubTensor], @@ -133,9 +137,7 @@ def add(tensor1: IRSubTensor, tensor2: IRSubTensor) -> Optional[IRSubTensor]: return mtensor def __repr__(self): - tensors = [f't{t._id}' for t in self.tensors] - tensors = '[' + ', '.join(tensors) + ']' - dscp = f't{self.output._id} = merge({tensors}, axis={self.concat}, add={self.add})' + dscp = f'{self.output} = merge({self.tensors}, axis={self.concat}, add={self.add})' return dscp @@ -199,8 +201,15 @@ def _gen_select(self): for itensor in otensor.parent.ptensors: if not itensor.overlap(otensor): continue + + # intersection common = otensor.common(itensor) common.attach_cell(itensor._cell) + self._intersections.append(common) + self._select_ptensors.append(itensor) + if common == itensor: + continue + islicers = itensor.indmap.get() oslicers = common.indmap.get() # index map @@ -228,8 +237,6 @@ def _gen_select(self): ) prim = SelectPrim(itensor, indmap, valmap, common.shape, common) self._select_trace.append(prim) - self._intersections.append(common) - self._select_ptensors.append(itensor) def _gen_move(self): odevice = self.dst_tensor.device @@ -292,4 +299,13 @@ def is_identity(self): def __repr__(self): dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' return dscp - \ No newline at end of file + + def extra_repr(self): + """ + Detailed information + """ + dscp = repr(self) + ':\n' + # select + for prim in self._select_trace + self._move_trace + self._merge_trace: + dscp += '\t' + repr(prim) + '\n' + return dscp diff --git a/cube/logics/translator.py b/cube/logics/translator.py index 220d20de..6cf1442d 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -66,9 +66,9 @@ def backward(loss: IRSubTensor): trace = SchedulePool().get_tape(loss) if trace is None: raise RuntimeError("No forward detected") - # make grad to 1.0 if not loss.shape == [1]: raise RuntimeError("backward can only perform on the scaler tensor") + # grad should be None or 1.0 loss.parent.requires_grad = False for node in trace: for output in node.outputs(): From 696fd7eba5125cf7c29cc79821f6d82974f1f2fc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 24 Dec 2021 11:29:51 +0800 Subject: [PATCH 0492/1892] add backward adapter --- cube/graph/adapter/adapter.py | 53 ++++++++++++++++++++++++++------- cube/graph/adapter/gen.py | 37 +++++++++++++---------- cube/graph/graph.py | 3 +- cube/graph/operator/operator.py | 34 +++++++++++++++++++-- 4 files changed, 99 insertions(+), 28 deletions(-) diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index c7cf837d..f78630e4 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -16,7 +16,7 @@ def __init__(self, tensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, self.valmap = valmap self.shape = shape self.output = output - self.device = tensor.device + self.device: List[int] = tensor.device def __repr__(self): dscp = f'{self.output} = select({self.tensor})' @@ -31,7 +31,7 @@ def __init__(self, tensor: IRSubTensor, from_rank: int, to_rank: int): self.to_rank = to_rank self.shape = tensor.shape self.dtype = tensor.dtype - self.device = tensor.device + self.device: List[int] = [from_rank, to_rank] def __repr__(self): dscp = f'move({self.tensor}, from={self.from_rank}, to={self.to_rank})' @@ -55,7 +55,7 @@ def __init__(self, tensors: List[IRSubTensor], sorted_idx = np.argsort(starts) tensors = np.array(tensors)[sorted_idx] self.tensors = tensors.tolist() - self.device = device + self.device: List[int] = device def set_output(self, output: IRSubTensor): self.output = output @@ -179,6 +179,12 @@ def __init__(self, dst_tensor: IRSubTensor): for idx, ptensor in enumerate(self._select_ptensors): self.set_input(idx, ptensor) self.set_output(0, dst_tensor) + + # set up device + device = set() + for prim in self._select_trace + self._move_trace + self._merge_trace: + device.update(prim.device) + self.device = list(device) def _gen_select(self): otensor = self.dst_tensor @@ -242,7 +248,11 @@ def _gen_move(self): odevice = self.dst_tensor.device for tensor in self._intersections: if tensor.device != odevice: - prim = MovePrim(tensor, from_rank=tensor.device, to_rank=odevice) + if len(tensor.device) != 1 or len(odevice) != 1: + raise RuntimeError( + f"Expected tensor on a single device but got {tensor.device} and {odevice}" + ) + prim = MovePrim(tensor, from_rank=tensor.device[0], to_rank=odevice[0]) self._move_trace.append(prim) def _gen_merge(self): @@ -282,6 +292,34 @@ def _gen_merge(self): if out is None: raise RuntimeError("Merge Plan not found") + def prims(self, select=True, move=True, merge=True): + """ + Return prim list + """ + prims = list() + if select: + prims += self._select_trace + if move: + prims += self._move_trace + if merge: + prims += self._merge_trace + return prims + + def dispatch(self, rank: int) -> List: + """ + Get executed prim for a specific rank + + Returns: + List[Prims] + """ + if not isinstance(rank, int): + raise TypeError(f"Expected rank to be int but got {rank}") + prims = list() + for prim in self.prims(): + if rank in prims.device: + prims.append(prim) + return prims + def is_identity(self): """ Check if the adapter does nothing @@ -289,12 +327,7 @@ def is_identity(self): Returns: Boolean """ - if len(self._select_trace) == 0 and \ - len(self._move_trace) == 0 and \ - len(self._merge_trace) == 0: - return True - else: - return False + return len(self.prims()) == 0 def __repr__(self): dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index 527da1c6..b217dba2 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -1,8 +1,8 @@ from cube.graph.graph import IRGraph from cube.graph.adapter.adapter import IRAdapter -from cube.graph.operator.operator import IRFwOperation -from cube.ir.cten import IRTensor +from cube.graph.operator.operator import IRBpOperation, IRFwOperation +from cube.graph.tensor import IRSubTensor class AdapterGener: @@ -10,17 +10,24 @@ class AdapterGener: @staticmethod def gen(graph: IRGraph) -> IRGraph: for node in graph.nodes(): - if not isinstance(node, IRFwOperation): - continue - # adapter for input - for input in node.inputs(): - if not isinstance(input, IRTensor): - continue - # skip parameter - if input.is_param(): - continue - adapter = IRAdapter(input) - if not adapter.is_identity(): - idx = graph.nodes().index(node) - graph._nodes.insert(idx, adapter) + if isinstance(node, IRFwOperation): + for input in node.inputs(): + if not isinstance(input, IRSubTensor): + continue + # skip parameter + if input.is_param(): + continue + adapter = IRAdapter(input) + if not adapter.is_identity(): + idx = graph.nodes().index(node) + graph._nodes.insert(idx, adapter) + if isinstance(node, IRBpOperation): + for grad in node.grads(): + if not isinstance(grad, IRSubTensor): + continue + # skip parameter + adapter = IRAdapter(grad) + if not adapter.is_identity(): + idx = graph.nodes().index(node) + graph._nodes.insert(idx, adapter) return graph diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 7fee670a..e63b4fdf 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -299,6 +299,8 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional # remove reference finputs = op.inputs() op.make_empty() + if op.mirror is not None: + op.mirror.make_empty() # generate backward updated = set() @@ -350,7 +352,6 @@ def set_order(self, seq: List[IRCell]): def partial_set_order(self, seq: List[IRCell], lazy=False): raise NotImplementedError - def __repr__(self): dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 17e41dd6..a5f6670e 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -150,6 +150,9 @@ def replicate(self): return cpy def datas(self, index: Optional[int] = None) -> Union[List[Any], Any]: + """ + Forward inputs + """ if index is None: return self.inputs()[:self.data_num] if index >= self.data_num: @@ -159,6 +162,9 @@ def datas(self, index: Optional[int] = None) -> Union[List[Any], Any]: return self.inputs(index) def grads(self, index: Optional[int] = None) -> Union[List[Any], Any]: + """ + backward op input gradient (a.k.a. output gradient in forward) + """ if index is None: return self.inputs()[self.data_num:] elif index >= self.grad_num: @@ -185,7 +191,8 @@ def set_data(self, input_index: int, val: Any): def set_grad(self, input_index: int, val: Any): """ - Set the node gradient at input index. + Set the node input gradient + (i.e., output gradient in forward) at input index. The grad is same order with corresponding output tensor of it's forward tensor @@ -200,7 +207,30 @@ def set_grad(self, input_index: int, val: Any): raise RuntimeError( f"Set the grad out of range ({input_index} >= {self.grad_num})" ) - return self.set_input(input_index + self.data_num, val) + input_index += self.data_num + # remove the consumer + old_val = self.inputs(input_index) + if isinstance(old_val, IRSubTensor): + old_val.parent.rm_consumer(self) + # add the consumer + val = super().set_input(input_index, val) + if isinstance(val, IRSubTensor): + val.parent.add_consumer(self, val) + return val + + def set_output(self, output_index: int, val: Any): + """ + Set op output grad (Forward input gradient) + """ + # remove the producer + old_val = self.outputs(output_index) + if isinstance(old_val, IRSubTensor): + old_val.parent.rm_producer(self) + # add the producer + val = super().set_output(output_index, val) + if isinstance(val, IRSubTensor): + val.parent.add_producer(self, val) + return val def update(self): """ From c628f9c7cbb8991c7b0927f024ab96f9036e0e19 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 24 Dec 2021 15:53:20 +0800 Subject: [PATCH 0493/1892] use subgraph for schedule unit --- cube/execplan/execplan.py | 50 ++++---- cube/graph/adapter/adapter.py | 7 +- cube/graph/graph.py | 198 ++++++++++++++++++++------------ cube/graph/operator/operator.py | 22 ++++ 4 files changed, 179 insertions(+), 98 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 14fc010b..491b32c2 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -1,25 +1,27 @@ from typing import List, Optional import copy +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.operator.operator import IRBpOperation, IRFwOperation -from cube.schedule.sugraph import SUGraph -from cube.schedule.su import SUType, ScheduleUnit +from cube.ir.cten import IRCell +from cube.graph.graph import IRGraph class ExectuionPlan: - def __init__(self, sugraph: SUGraph): - if not isinstance(sugraph, SUGraph): + def __init__(self, graph: IRGraph): + if not isinstance(graph, IRGraph): raise TypeError("Expected a list of ScheduleUnit") - self.sugraph = sugraph + self.graph = graph self.device_seq = dict() - for su in sugraph.sus(): - if len(su.device) == 0: - raise RuntimeError(f"device not set: SU {su}") - for device in su.device: + for node in graph.nodes(): + if len(node.device) == 0: + raise RuntimeError(f"Node device not set: {node}") + for device in node.device: if device not in self.device_seq: - self.device_seq[device] = [su] + self.device_seq[device] = [node] else: - self.device_seq[device].append(su) + self.device_seq[device].append(node) def devices(self) -> List[int]: """ @@ -29,7 +31,7 @@ def devices(self) -> List[int]: devices.sort() return devices - def sequence(self, device_id: int) -> List[ScheduleUnit]: + def sequence(self, device_id: int) -> List[IRCell]: """ Get a copy of execution sequence for device id @@ -39,7 +41,7 @@ def sequence(self, device_id: int) -> List[ScheduleUnit]: return list() return copy.copy(self.device_seq[device_id]) - def at(self, device_id: int) -> List[ScheduleUnit]: + def at(self, device_id: int) -> List[IRCell]: """ Access the sequence for device id @@ -49,12 +51,12 @@ def at(self, device_id: int) -> List[ScheduleUnit]: return list() return self.device_seq[device_id] - def set(self, device_id: int, seq: List[ScheduleUnit]): + def set(self, device_id: int, seq: List[IRCell]): """ Set device sequence """ - if not all([isinstance(su, ScheduleUnit) for su in seq]): - raise TypeError("Expected a list of ScheduleUnit") + if not all([isinstance(su, IRCell) for su in seq]): + raise TypeError("Expected a list of Cell") self.device_seq[device_id] = seq def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): @@ -64,7 +66,7 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): Args: span (List[int]): length equal to schedule unit num. - Each element stands for the time span for corresponding SU + Each element stands for the time span for corresponding Cell outfile: the output file name @@ -76,13 +78,13 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): if spans is None: spans = list() - for su in self.seq.sus(): + for node in self.graph.nodes(): span = 0 - if su.stype == SUType.Forward: + if isinstance(node, IRFwOperation): span = 1 - elif su.stype == SUType.Backward: + elif isinstance(node, IRBpOperation): span = 2 - elif su.stype in [SUType.P2P, SUType.Transform]: + elif isinstance(node, IRAdapter): span = 0.1 else: span = 0 @@ -155,9 +157,9 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): def __repr__(self): - dscp = f'Execution Plan ({self.sugraph.name}):\n' + dscp = f'Execution Plan ({self.graph.name}):\n' for devid in self.devices(): dscp += f'====> Device {devid}:\n' - for su in self.sequence(devid): - dscp += f'{su}\n' + for node in self.sequence(devid): + dscp += f'{node.module_repr()}\n' return dscp diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index f78630e4..fc65dfd6 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -151,7 +151,7 @@ class IRAdapter(IRCell): * Merge: merge the produced tensors """ def __init__(self, dst_tensor: IRSubTensor): - print(f'generating adapter for: {dst_tensor}') + # print(f'generating adapter for: {dst_tensor}') if not isinstance(dst_tensor, IRSubTensor): raise RuntimeError("Expected IRSubTensor") self.dst_tensor = dst_tensor @@ -190,7 +190,7 @@ def _gen_select(self): otensor = self.dst_tensor odevice = otensor.device - print(f'select: produced tensors: {otensor.parent.ptensors}') + # print(f'select: produced tensors: {otensor.parent.ptensors}') local, remote = list(), list() for ptensor in otensor.parent.ptensors: @@ -332,6 +332,9 @@ def is_identity(self): def __repr__(self): dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' return dscp + + def module_repr(self) -> str: + return repr(self) def extra_repr(self): """ diff --git a/cube/graph/graph.py b/cube/graph/graph.py index e63b4fdf..6b5bd6d2 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -10,10 +10,8 @@ from typing import Union, Tuple, List, Optional, Dict import copy -from numpy import isin - -from cube.graph.operator.operator import IRBpOperation, IRFwOperation from cube.ir.cten import IRTensor, IRCell +from cube.graph.operator.operator import IRBpOperation, IRFwOperation from cube.graph.tensor import IRSubTensor from cube.algorithm.generics import GenericDistAlgo @@ -39,10 +37,10 @@ def __init__(self, self._parameters = list() if inputs is None: - inputs = IRCell.get_inputs(nodes) + inputs = IRGraph.get_inputs(nodes) inputs = [t for t in inputs if not t.is_param()] if outputs is None: - outputs = IRCell.get_outputs(nodes) + outputs = IRGraph.get_outputs(nodes) outputs = [t for t in outputs if not t.is_param()] super().__init__( @@ -106,39 +104,39 @@ def nodes(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") - def insert(self, node, src_node=None, dst_node=None, replaced_tensor=None): - """ - Insert a node between src_node and dst_node. In default, - if dst_node is not None, the node will be inserted right before - dst_node. If the replaced_tensor is provided, the replaced_tensor - in dst_node's inputs will be removed, and the output of node will be - set as input for dst_node. - """ - if not isinstance(node, IRCell): - raise TypeError("Expected IRCell to insert") - if dst_node is not None: - if dst_node not in self._nodes: - raise KeyError("dst_node not found") - if replaced_tensor is not None: - if replaced_tensor not in dst_node.inputs(): - raise RuntimeError(f"Expected dst_node input has {replaced_tensor}") - # remove dst_node input - input_index = dst_node.inputs().index(replaced_tensor) - if len(node.outputs()) != 1: - raise RuntimeError("replaced node requires output length to be 1") - dst_node.set_input(input_index, node.outputs(0)) - # insert node - index = self._nodes.index(dst_node) - self._nodes.insert(index, node) - elif src_node is not None: - if src_node not in self._nodes: - raise KeyError("src_node not found") - index = self._nodes.index(src_node) - self._nodes = self._nodes[:index+1] + [node] + self._nodes[index+1:] - else: - raise TypeError("Expected at least one of [src_node, dst_node]") - #TODO: optimize this - self.reset_dependency() + # def insert(self, node, src_node=None, dst_node=None, replaced_tensor=None): + # """ + # Insert a node between src_node and dst_node. In default, + # if dst_node is not None, the node will be inserted right before + # dst_node. If the replaced_tensor is provided, the replaced_tensor + # in dst_node's inputs will be removed, and the output of node will be + # set as input for dst_node. + # """ + # if not isinstance(node, IRCell): + # raise TypeError("Expected IRCell to insert") + # if dst_node is not None: + # if dst_node not in self._nodes: + # raise KeyError("dst_node not found") + # if replaced_tensor is not None: + # if replaced_tensor not in dst_node.inputs(): + # raise RuntimeError(f"Expected dst_node input has {replaced_tensor}") + # # remove dst_node input + # input_index = dst_node.inputs().index(replaced_tensor) + # if len(node.outputs()) != 1: + # raise RuntimeError("replaced node requires output length to be 1") + # dst_node.set_input(input_index, node.outputs(0)) + # # insert node + # index = self._nodes.index(dst_node) + # self._nodes.insert(index, node) + # elif src_node is not None: + # if src_node not in self._nodes: + # raise KeyError("src_node not found") + # index = self._nodes.index(src_node) + # self._nodes = self._nodes[:index+1] + [node] + self._nodes[index+1:] + # else: + # raise TypeError("Expected at least one of [src_node, dst_node]") + # #TODO: optimize this + # self.reset_dependency() def _replace_tensor(self, old_tensor: IRTensor, new_tensor: IRTensor): """ @@ -188,39 +186,74 @@ def subgraph(self, sub_nodes: List[IRCell]): """ Create a subgraph with sub nodes. - The remote tensor will be set as graph input (recv tensors) - and graph output (send tensors) - Return: IRGraph """ - # find input - inputs = list() - outputs = list() - for node in sub_nodes: - outer_cells = list(set(self.nodes()) - set(sub_nodes)) - for tensor in node.inputs(): - if isinstance(tensor, IRTensor) and tensor not in inputs: - # if a tensor is generated by other nodes out of sub_nodes, - # then this tensor should be the input - src_nodes = tensor.src(outer_cells) - if len(src_nodes) != 0 or tensor in self.inputs(): - inputs.append(tensor) - for tensor in node.outputs(): - if isinstance(tensor, IRTensor) and tensor not in outputs: - # if a tensor is used by other nodes out of sub_nodes, - # then this tensor should be output - dst_nodes = tensor.dst(outer_cells) - if len(dst_nodes) != 0 or tensor in self.outputs(): - outputs.append(tensor) - - graph = IRGraph( + subgraph = IRGraph( nodes = sub_nodes, - input_tensors = inputs, - output_tensors = outputs, - module_name = self.name + input_tensors = None, + output_tensors = None, + module_name = 'subgraph' ) - return graph + return subgraph + + @staticmethod + def get_inputs(nodes: List[IRCell]): + """ + Get all the input tensors the is not generated by nodes + + Inputs + + Returns: + List[IRTensor] + """ + all_outputs = list() + for node in nodes: + all_outputs += node.outputs() + inputs = list() + for cell in nodes: + for input in cell.inputs(): + if isinstance(input, IRTensor): + if input not in all_outputs: + if input not in inputs: + inputs.append(input) + return inputs + + @staticmethod + def get_outputs(nodes: List[IRCell]): + """ + Get all the output tensors the is not used by nodes + + Args: + This will also consider the successor forward nodes. + If it is required by other outside forward nodes, + put in the outputs list + + Returns: + List[IRTensor] + """ + all_inputs = list() + for node in nodes: + all_inputs += node.inputs() + outputs = list() + for node in nodes: + for idx, output in enumerate(node.outputs()): + # not consumed tensor + if isinstance(output, IRSubTensor): + if output not in all_inputs: + if output not in outputs: + outputs.append(output) + continue + # consumed by other nodes + succs = node.successors(idx) + fsuccs = [ + fnode for fnode in succs if isinstance(fnode, IRFwOperation) + ] + for fsucc in fsuccs: + if fsucc not in nodes: + if output not in outputs: + outputs.append(output) + return outputs ## Parallel Policy Primitives ## @@ -237,9 +270,6 @@ def replicate(self, op: IRCell, times=1): if op not in self.nodes(): raise RuntimeError(f"Op {op} not exsits") - cpy_op = op.replicate() - if op.mirror is not None: - cpy_mirror_op = op.mirror.replicate() ops = [op] mirror_ops = [op.mirror] @@ -263,9 +293,9 @@ def replicate(self, op: IRCell, times=1): def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: """ - Policy primitive. Partition an operator by using - op_partition_algorithm and its configuration. Note the - backward op-partition will be automatically done. + Partition an operator (op) by using + op partition algorithm (algo) and its configuration (config). + Note the backward op-partition will be automatically done. Args: op: cell to be partitioned @@ -347,6 +377,26 @@ def assign(self, op: IRCell, rank: int): ## Schedule Policy Primitives ## + def happen_before(self, node1: IRCell, node2: IRCell, skip=None) -> bool: + """ + Check node1 -> (happen before) node2 + + Returns: + Boolean + """ + skip = list() if skip is None else skip + if node1 in skip: + return False + if not isinstance(node1, IRCell) or not isinstance(node2, IRCell): + raise TypeError("Expected node to be IRCell") + if node2 in node1.successors(): + return True + else: + for succ_node in node1.successors(): + if self.happen_before(succ_node, node2, skip): + return True + return False + def set_order(self, seq: List[IRCell]): raise NotImplementedError @@ -354,6 +404,10 @@ def partial_set_order(self, seq: List[IRCell], lazy=False): raise NotImplementedError def __repr__(self): + dscp = f"Graph{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" + return dscp + + def extra_repr(self): dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs dscp += f"Inputs: {self.inputs()}\n" diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index a5f6670e..f54e7fb2 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -121,6 +121,15 @@ def __repr__(self): dscp = f'FwOp{self._id}-{self.device}(sign={sign}, inputs={self.inputs()}, outputs={self.outputs()})' return dscp + def module_repr(self) -> str: + """ + Weight-hidden string representation + """ + sign = self.signature.split('.')[-1] + ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_param()] + dscp = f'FwOp{self._id}-{self.device}(sign={sign}, inputs={ins}, outputs={self.outputs()})' + return dscp + class IRBpOperation(IRCell): @@ -252,6 +261,16 @@ def __repr__(self): dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, grads={self.grads()}, datas={self.datas()}, outputs={self.outputs()})' return dscp + def module_repr(self) -> str: + """ + Weight-hidden string representation + """ + ins = [t for t in self.datas() if isinstance(t, IRSubTensor) and not t.is_param()] + outs = [t.grad for t in ins] + assert all([out in self.outputs() for out in outs]) + dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, grads={self.grads()}, outputs={outs})' + return dscp + class IRDataOperation(IRCell): @@ -324,6 +343,9 @@ def __repr__(self): dscp = f'DataLoader{self._id}-{self.device}(outputs={self.outputs()})' return dscp + def module_repr(self) -> str: + return repr(self) + class IROptimOperation(IRCell): From 0ab51ab97933204b0cfaeaa44698a3b0e5c676ab Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 28 Dec 2021 17:29:28 +0800 Subject: [PATCH 0494/1892] model code gen --- cube/codegen/codegen.py | 523 ++++++++++++++++---------------- cube/graph/adapter/adapter.py | 210 +++++++------ cube/graph/adapter/gen.py | 6 +- cube/graph/graph.py | 33 +- cube/graph/operator/operator.py | 54 ++-- cube/ir/cten.py | 19 +- 6 files changed, 451 insertions(+), 394 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 0b3f9515..d58c23a0 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -4,17 +4,16 @@ from typing import List, Any import torch import copy -from cube.graph.operator.operator import IRFwOperation, IROptimOperation -from cube.ir.cten import IRTensor -from cube.graph.tensor import ValueMap +from cube.ir.cten import IRCell, IRTensor +from cube.graph.tensor import IRSubTensor +from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.graph.operator.operator import IROptimOperation +from cube.graph.adapter.adapter import IRAdapter, SelectPrim, MovePrim, MergePrim from cube.execplan import ExectuionPlan -from cube.schedule.adapter.collectives import IRCollectives +# from cube.schedule.adapter.collectives import IRCollectives -from cube.schedule.su import ScheduleUnit, SUType -from cube.schedule.adapter.comm import IRCommType, IRCommunication -from cube.schedule.adapter.transform import IRTensorTransform -from cube.schedule.adapter.transform import SelectPrim, MergePrim +from cube.graph.graph import IRGraph from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -28,19 +27,8 @@ def __init__(self, execplan: ExectuionPlan): raise TypeError("execplan should be ExecutionPlan") self.execplan = execplan - def su_naming(self, su: ScheduleUnit) -> str: - if su.stype == SUType.Forward: - return f"fwcp{su._id}" - if su.stype == SUType.Backward: - return f"bwcp{su._id}" - if su.stype == SUType.P2P: - return f"p2p{su._id}" - if su.stype == SUType.Coll: - return f"coll{su._id}" - if su.stype == SUType.Transform: - return f"trans{su._id}" - if su.stype == SUType.Optimizer: - return f"optim{su._id}" + def node_naming(self, node: IRCell) -> str: + return f"{node.name}{node._id}" def tensor_naming(self, tensor: Any) -> str: """ @@ -58,7 +46,7 @@ def tensor_naming(self, tensor: Any) -> str: class ModelCodeGen(CodeGen): """ - Generate spatial code for the model + Generate model code """ def __init__(self, execplan: ExectuionPlan): @@ -70,7 +58,7 @@ def __init__(self, execplan: ExectuionPlan): # module init code self.declare_region: List[str] = list() # module forward code - self.all_su_forward_region: List[List[str]] = list() + self.forward_region_units: List[List[str]] = list() self.forward_region: List[str] = list() # module member name self.symbols = SymbolTable() @@ -78,7 +66,7 @@ def __init__(self, execplan: ExectuionPlan): self._ref_module = torch.nn.Module() # groups self._all_comm_groups = list() - self.get_all_groups() + # self.get_all_groups() def get_all_groups(self): """ @@ -87,6 +75,7 @@ def get_all_groups(self): Creating communication group requires all the devices enter the same call. """ + raise NotImplementedError for devid in self.execplan.devices(): for su in self.execplan.sequence(devid): if su.stype == SUType.Coll: @@ -100,53 +89,50 @@ def gen(self, device: int, outfile=None, attach=False) -> str: """ Generate model implementation code based on the given graph. """ - device_sus = self.execplan.sequence(device) - device_sus = [su for su in device_sus \ - if su.stype != SUType.Backward \ - and su.stype != SUType.Dataloader] - gencode = copy.copy(self.init_code) + node_args: List[List[str]] = list() + gen_nodes: List[IRCell] = list() - # register forward input - su_args: List[List[str]] = list() - for su in device_sus: - fargs = list() - for input in su.inputs(): - if isinstance(input, IRTensor) and input.is_param(): - continue - fargs.append(self.tensor_naming(input)) - for name in fargs: - self.symbols.create(name) - su_args.append(fargs) - - # init group - self.emit_comm_group_creation() + # TODO init group + # self.emit_comm_group_creation() # parse graph body - for su in device_sus: - for node in su.nodes(): - if isinstance(node, IRFwOperation): - self.emit_op_call(node) - elif isinstance(node, IRTensorTransform): - self.emit_transform_call(node) - elif isinstance(node, IRCommunication): - self.emit_comm_call(node) - elif isinstance(node, IRCollectives): - self.emit_collective_call(node) - elif isinstance(node, IROptimOperation): - self.emit_optim_init(node) - self.emit_optim_call(node) - else: - raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") - # emit input declaration - for arg in node.inputs(): - self.emit_var_declare(arg) - # record output tensor name - for out in node.outputs(): - if isinstance(out, IRTensor) or isinstance(out, str): - self.symbols.create(self.tensor_naming(out)) - self.all_su_forward_region.append(self.forward_region) + for node in self.execplan.sequence(device): + if isinstance(node, IRGraph): + # skip backward ir graph + if all([isinstance(n, IRBpOperation) for n in node.nodes()]): + continue + self.emit_graph_call(node) + elif isinstance(node, IRFwOperation): + self.emit_op_call(node) + elif isinstance(node, IRAdapter): + node = node.dispatch(rank=device) + self.emit_adapter_call(node) + # elif isinstance(node, IRCollectives): + # self.emit_collective_call(node) + elif isinstance(node, IROptimOperation): + self.emit_optim_init(node) + self.emit_optim_call(node) + elif isinstance(node, IRBpOperation): + continue + elif isinstance(node, IRDataOperation): + continue + else: + raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") + # emit node tensor declaration + self.emit_node_declare(node) + # emit node code + self.forward_region_units.append(self.forward_region) self.forward_region = list() + gen_nodes.append(node) + args = list() + for t in node.inputs(): + if isinstance(t, IRSubTensor): + if not t.is_param(): + args.append(self.tensor_naming(t)) + else: + args.append(self.tensor_naming(t)) + node_args.append(args) # generate full code with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: @@ -154,15 +140,15 @@ def gen(self, device: int, outfile=None, attach=False) -> str: ib.insert_body(self.declare_region) cb.insert_body('') cb.insert_body(ib.code) - for idx, su in enumerate(device_sus): - name = self.su_naming(su) - input_args = ['self'] + su_args[idx] - forward_code = self.all_su_forward_region[idx] + for idx, node in enumerate(gen_nodes): + name = self.node_naming(node) + input_args = ['self'] + node_args[idx] + forward_code = self.forward_region_units[idx] with FunctionBlock(func_name=name, args=input_args) as fb: fb.insert_body(forward_code) # generate output - out_names = self._forward_region_arg_names(su.outputs()) - return_code = f"return {', '.join(out_names)}" + outputs = [self.tensor_naming(t) for t in node.outputs()] + return_code = f"return {', '.join(outputs)}" fb.insert_body(return_code) cb.insert_body('') cb.insert_body(fb.code) @@ -179,127 +165,175 @@ def gen(self, device: int, outfile=None, attach=False) -> str: self.clear() return code - def emit_var_declare(self, var: Any): + def emit_node_declare(self, node: IRCell): """ Emit tensor declaration code """ - if isinstance(var, IRTensor): - name = self.tensor_naming(var) - # emit parameter code - if var.is_param() and not self.symbols.exist(name): - self.symbols.create(name) - code = f'self.{name} = torch.nn.Parameter(torch.empty({tuple(var.shape)}))' - self.declare_region.append(code) - elif isinstance(var, str): - name = self.tensor_naming(var) - if name.startswith('self.'): - if not hasattr(self._ref_module, var[5:]): - if self.symbols.create(name): - #TODO: add default value - code = f'{name} = None' - self.declare_region.append(code) + for input in node.inputs(): + name = self.tensor_naming(input) + if isinstance(input, IRTensor): + if input.is_param() and not self.symbols.exist(name): + self.symbols.create(name) + code = f'{name} = torch.nn.Parameter(torch.empty({tuple(input.shape)}))' + self.declare_region.append(code) + if isinstance(input, str): + if name.startswith('self.'): + if not hasattr(self._ref_module, name[5:]): + raise NotImplementedError("member attribute is not added") + for output in node.outputs(): + self.symbols.create(self.tensor_naming(output)) return - def emit_comm_group_creation(self): - """ - Emit communication group creation code - """ - sign = 'self.init_group(ranks={ranks})' - for ranks in self._all_comm_groups: - ranks = list(ranks) - code = sign.format(ranks=ranks) - self.declare_region.append(code) - def emit_op_call(self, node): + def emit_graph_call(self, graph: IRGraph): + for node in graph.nodes(): + if isinstance(node, IRBpOperation): + raise RuntimeError("IRBpOperation is not expected in GenModel") + self.emit_op_call(node) + + + # def emit_comm_group_creation(self): + # """ + # Emit communication group creation code + # """ + # sign = 'self.init_group(ranks={ranks})' + # for ranks in self._all_comm_groups: + # ranks = list(ranks) + # code = sign.format(ranks=ranks) + # self.declare_region.append(code) + + def emit_op_call(self, node: IRFwOperation): """ Emit op forward code """ op_code = node.signature - arg_names = self._forward_region_arg_names(node.inputs()) + inputs = [self.tensor_naming(t) for t in node.inputs()] kwargs = list() for key in node.kwargs: code = f'{key}={node.kwargs[key]}' kwargs.append(code) - arg_names += kwargs - arg_region = '(' + ', '.join(arg_names) + ')' + inputs += kwargs + inputs = ', '.join(inputs) + body = f'{op_code}({inputs})' if len(node.outputs()) == 0: - code = f'{op_code}{arg_region}' - else: - out_names = self._forward_region_arg_names(node.outputs()) - out_names = ', '.join(out_names) - code = f'{out_names} = {op_code}{arg_region}' - self.forward_region.append(code) - - def emit_comm_call(self, node): - """ - Emit communication code - """ - comm_code = node.signature - send_tensors = self._forward_region_arg_names(node.inputs()) - send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' - send_ranks = node.send_ranks - recv_tensors = self._forward_region_arg_names(node.outputs()) - recv_tensors = ', '.join(recv_tensors) - recv_shapes = [tensor.shape for tensor in node.outputs()] - recv_ranks = node.recv_ranks - if node.comm_type == IRCommType.Send: - code = f'{comm_code}({send_tensors}, {send_ranks})' - elif node.comm_type == IRCommType.Recv: - code = f'{recv_tensors} = {comm_code}({recv_shapes}, {recv_ranks})' - elif node.comm_type == IRCommType.SendRecv: - code = f'{recv_tensors} = {comm_code}({send_tensors}, {send_ranks}, {recv_shapes}, {recv_ranks})' - else: - raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") - self.forward_region.append(code) - - def emit_collective_call(self, node): - ranks = node.ranks - inputs = self._forward_region_arg_names(node.inputs()) - shape = None - if len(inputs) == 0: - assert len(node.outputs()) == 1 - shape = node.outputs(0).shape - inputs = '(' + ', '.join(inputs + ['']) + ')' - outputs = self._forward_region_arg_names(node.outputs()) - outputs = ', '.join(outputs) - if shape: - code = f'{node.signature}({inputs}, {ranks}, {shape})' + code = body else: - code = f'{node.signature}({inputs}, {ranks})' - if outputs: - code = f'{outputs} = {code}' + outputs = [self.tensor_naming(t) for t in node.outputs()] + outputs = ', '.join(outputs) + code = f'{outputs} = {body}' self.forward_region.append(code) - def emit_transform_call(self, node: IRTensorTransform): + def emit_adapter_call(self, node: IRAdapter): """ - Emit in-device tensor select / merge call. + Emit adapter call """ - for prim in node.trace(): + if len(node.device) != 1: + raise RuntimeError("Expected IRAdapter to be dispatched") + rank = node.device[0] + for prim in node.prims(): + # emit select if isinstance(prim, SelectPrim): - signature = 'cube.runtime.transform.select({tensor}, {indmap}, {valmap})' + sign = 'cube.runtime.transform.select({tensor}, {indmap}, {valmap})' input = self.tensor_naming(prim.tensor) - indmap = repr(prim.indmap) - valmap = repr(tuple([prim.valmap.idx, prim.valmap.chunk_num])) output = self.tensor_naming(prim.output) - code = f'{output} = {signature.format(tensor=input, indmap=indmap, valmap=valmap)}' + valmap = (prim.valmap.idx, prim.valmap.chunk_num) + code = f'{output} = {sign.format(tensor=input, indmap=prim.indmap, valmap=valmap)}' self.forward_region.append(code) + # emit move + elif isinstance(prim, MovePrim): + send_sign = 'cube.runtime.transform.send({tensor}, {send_rank})' + recv_sign = 'cube.runtime.transform.recv({shape}, {from_rank}, {dtype})' + tensor = self.tensor_naming(prim.tensor) + # send + if rank == prim.from_rank: + code = f'{send_sign.format(tensor=tensor, send_rank=prim.to_rank)}' + self.forward_region.append(code) + # recv + elif rank == prim.to_rank: + output = self.tensor_naming(prim.tensor) + code = f'{tensor} = {recv_sign.format(shape=prim.shape, from_rank=prim.from_rank, dtype=prim.dtype)}' + self.forward_region.append(code) + # emit merge elif isinstance(prim, MergePrim): - signature = 'cube.runtime.transform.merge({tensors}, {concat}, {add})' - inputs = self._forward_region_arg_names(prim.tensors) - inputs = '(' + ', '.join(inputs) + ')' + sign = 'cube.runtime.transformation.merge({tensors}, {concat}, {add})' + inputs = [self.tensor_naming(t) for t in prim.tensors] + inputs = '(' + ','.join(inputs + ['']) + ')' output = self.tensor_naming(prim.output) - code = f'{output} = {signature.format(tensors=inputs, concat=prim.concat, add=prim.add)}' + code = f'{output} = {sign.format(tensors=inputs, concat=prim.concat, add=prim.add)}' self.forward_region.append(code) else: - raise RuntimeError(f"Not supported prim: {type(prim)}") - for output in node.outputs(): - # contiguous and requires grad - output_name = self.tensor_naming(output) - code = f'{output_name} = {output_name}.contiguous()' - self.forward_region.append(code) - if not output.is_grad(): - code = f'{output_name} = {output_name}.requires_grad_()' - self.forward_region.append(code) + raise TypeError(f"Unkown primitive types {type(prim)} of Adapter") + + # def emit_comm_call(self, node): + # """ + # Emit communication code + # """ + # comm_code = node.signature + # send_tensors = self._forward_region_arg_names(node.inputs()) + # send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' + # send_ranks = node.send_ranks + # recv_tensors = self._forward_region_arg_names(node.outputs()) + # recv_tensors = ', '.join(recv_tensors) + # recv_shapes = [tensor.shape for tensor in node.outputs()] + # recv_ranks = node.recv_ranks + # if node.comm_type == IRCommType.Send: + # code = f'{comm_code}({send_tensors}, {send_ranks})' + # elif node.comm_type == IRCommType.Recv: + # code = f'{recv_tensors} = {comm_code}({recv_shapes}, {recv_ranks})' + # elif node.comm_type == IRCommType.SendRecv: + # code = f'{recv_tensors} = {comm_code}({send_tensors}, {send_ranks}, {recv_shapes}, {recv_ranks})' + # else: + # raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") + # self.forward_region.append(code) + + # def emit_collective_call(self, node): + # ranks = node.ranks + # inputs = self._forward_region_arg_names(node.inputs()) + # shape = None + # if len(inputs) == 0: + # assert len(node.outputs()) == 1 + # shape = node.outputs(0).shape + # inputs = '(' + ', '.join(inputs + ['']) + ')' + # outputs = self._forward_region_arg_names(node.outputs()) + # outputs = ', '.join(outputs) + # if shape: + # code = f'{node.signature}({inputs}, {ranks}, {shape})' + # else: + # code = f'{node.signature}({inputs}, {ranks})' + # if outputs: + # code = f'{outputs} = {code}' + # self.forward_region.append(code) + + # def emit_transform_call(self, node): + # """ + # Emit in-device tensor select / merge call. + # """ + # for prim in node.trace(): + # if isinstance(prim, SelectPrim): + # signature = 'cube.runtime.transform.select({tensor}, {indmap}, {valmap})' + # input = self.tensor_naming(prim.tensor) + # indmap = repr(prim.indmap) + # valmap = repr(tuple([prim.valmap.idx, prim.valmap.chunk_num])) + # output = self.tensor_naming(prim.output) + # code = f'{output} = {signature.format(tensor=input, indmap=indmap, valmap=valmap)}' + # self.forward_region.append(code) + # elif isinstance(prim, MergePrim): + # signature = 'cube.runtime.transform.merge({tensors}, {concat}, {add})' + # inputs = self._forward_region_arg_names(prim.tensors) + # inputs = '(' + ', '.join(inputs) + ')' + # output = self.tensor_naming(prim.output) + # code = f'{output} = {signature.format(tensors=inputs, concat=prim.concat, add=prim.add)}' + # self.forward_region.append(code) + # else: + # raise RuntimeError(f"Not supported prim: {type(prim)}") + # for output in node.outputs(): + # # contiguous and requires grad + # output_name = self.tensor_naming(output) + # code = f'{output_name} = {output_name}.contiguous()' + # self.forward_region.append(code) + # if not output.is_grad(): + # code = f'{output_name} = {output_name}.requires_grad_()' + # self.forward_region.append(code) def emit_optim_init(self, node: IROptimOperation): # reducer init interface @@ -313,7 +347,7 @@ def emit_optim_init(self, node: IROptimOperation): self.declare_region.append('') init_code = reducer_init.format(reducer=reducer_name, ranks=ranks) self.declare_region.append(init_code) - grads = self._forward_region_arg_names(grads) + grads = [self.tensor_naming(t) for t in grads] for grad in grads: add_param_code = add_param.format(reducer=reducer_name, grad=grad) self.declare_region.append(add_param_code) @@ -325,20 +359,17 @@ def emit_optim_call(self, node: IROptimOperation): call_code = f'{reducer_name}.allreduce()' self.forward_region.append(call_code) - def _forward_region_arg_names(self, tensors: List[Any]): + def tensor_naming(self, tensor: Any): """ - Generate arg name list for forward region. + Generate tensor name. - Will add prefix 'self.' for var defined in declare region + Will add prefix 'self.' for parameters """ - named_args : List[str] = list() - for tensor in tensors: - name = self.tensor_naming(tensor) - if isinstance(tensor, IRTensor) and tensor.is_param(): - named_args.append('self.' + name) - else: - named_args.append(name) - return named_args + name = super().tensor_naming(tensor) + if isinstance(tensor, IRSubTensor): + if tensor.is_param(): + name = 'self.' + name + return name def clear(self): """ @@ -347,7 +378,7 @@ def clear(self): # module init code self.declare_region: List[str] = list() # module forward code - self.all_su_forward_region: List[List[str]] = list() + self.forward_region_units: List[List[str]] = list() self.forward_region: List[str] = list() # module member name self.symbols = SymbolTable() @@ -376,9 +407,9 @@ def gen(self, device: int, outfile=None, attach=False) -> str: args=['model', 'dataloader']) as fb: if len(device_sus) == 0: fb.insert_body('pass') - for su in device_sus: - name = self.su_naming(su) - code = self.emit_su(su, name=name) + for node in device_sus: + name = self.node_naming(node) + code = self.emit_node(node, name=name) fb.insert_body(code) gencode += fb.code gencode += [''] @@ -390,92 +421,60 @@ def gen(self, device: int, outfile=None, attach=False) -> str: f.write(code) return code - def emit_su(self, su: ScheduleUnit, name: str) -> List[str]: + def emit_node(self, node: IRCell, name: str) -> List[str]: """ - Emit su code + Emit node / subgraph code """ - fsu_types = [SUType.Forward, SUType.P2P, SUType.Coll, SUType.Transform, SUType.Optimizer] fsign = 'cube.runtime.executor.fexecute({model}, *{inputs})' bsign = 'cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' - if su.stype == SUType.Dataloader: - if len(su.inputs()) != 0: - raise RuntimeError("Dataloader su has no inputs") - outputs = [self.tensor_naming(output) for output in su.outputs()] + inputs = [self.tensor_naming(t) for t in node.inputs() if not t.is_param()] + outputs = [self.tensor_naming(t) for t in node.outputs()] + inputs = '(' + ','.join(inputs + ['']) + ')' + outputs = ', '.join(outputs) + + if isinstance(node, IRGraph): + is_backward = all([isinstance(n, IRBpOperation) for n in node.nodes()]) + # emit forward + if not is_backward: + body = fsign.format(model=f'model.{name}', inputs=inputs) + code = f'{outputs} = {body}' + # emit backward + else: + finputs = [t.data for t in node.outputs() if isinstance(t, IRSubTensor)] + finputs = '(' + ','.join(finputs + ['']) + ')' + foutputs = [t.data for t in node.inputs() if isinstance(t, IRSubTensor)] + foutputs = '(' + ','.join(foutputs + ['']) + ')' + outputs = [self.tensor_naming(t) for t in node.outputs() if isinstance(t, IRSubTensor)] + outputs = ', '.join(outputs) + body = bsign.format( + input_tensors=finputs, output_tensors=foutputs, output_grads=inputs + ) + code = f'{outputs} = {body}' + + elif isinstance(node, IRDataOperation): + if len(node.inputs()) != 0: + raise RuntimeError("Expect Dataloader node has no inputs") + outputs = [self.tensor_naming(output) for output in node.outputs()] return_val = ','.join(outputs) code = f'{return_val} = next(dataloader)' - return code - - elif su.stype in fsu_types: - inputs = list() - for tensor in su.inputs(): - if isinstance(tensor, IRTensor): - if tensor.is_param(): - continue - inputs.append(self.tensor_naming(tensor)) - inputs = '(' + ', '.join(inputs + ['']) + ')' - body = fsign.format( - model = f'model.{name}', - inputs = inputs - ) - outputs = [self.tensor_naming(output) for output in su.outputs()] - return_val = ','.join(outputs) - if len(su.outputs()) == 0: - code = body - else: - code = f'{return_val} = {body}' - return code - - elif su.stype == SUType.Backward: - # 1). input_tensors are forward inputs (happened before su inputs) - # => backward graph output tensor (share tensor in forward / backward graph) - # 2). output_tensors are forward outputs (su.inputs()) - # => backward graph input tensor (share tensor in forward / backward) - # 3). output_grads are recved tesnors of this graph (graph.recv_tensors) - # => backward graph input tensor (graph.recv_tensors) - fsu = su.mirror - finputs = list() - for tensor in fsu.inputs(): - if isinstance(tensor, IRTensor): - if tensor.is_param(): - finputs.append('model.' + self.tensor_naming(tensor)) - continue - finputs.append(self.tensor_naming(tensor)) - fargs = '(' + ', '.join(finputs + ['']) + ')' - - fouts = list() - for tensor in fsu.outputs(): - fouts.append(self.tensor_naming(tensor)) - fouts = '(' + ', '.join(fouts + ['']) + ')' - - fout_grads = list() - for fout in fsu.outputs(): - # the loss computed starting point - # if fout == fout.grad: - fout_grads.append(self.tensor_naming(fout.grad)) - fout_grads = '(' + ', '.join(fout_grads + ['']) + ')' - - body = bsign.format( - input_tensors = fargs, - output_tensors = fouts, - output_grads = fout_grads - ) - - # returned value are graph.outputs - return_val = list() - for input in fsu.inputs(): - if isinstance(input, IRTensor): - return_val.append(self.tensor_naming(input.grad)) - else: - return_val.append(None) - # return_val = [self.tensor_naming(tensor.grad) for tensor in finputs] - # TODO: fix this by using grad attributed - # return_val = return_val[:len(finputs)] - if len(return_val) > 0: - return_code = ', '.join(return_val) + ' = ' - else: - return_code = '' - code = f'{return_code}{body}' - return code + + elif isinstance(node, IRAdapter): + body = fsign.format(model=f'model.{name}', inputs=inputs) + code = f'{outputs} = {body}' + else: - raise RuntimeError(f"Unsupported SUType: {su.stype}") + raise RuntimeError(f"Unspported node type: {type(node)}") + return code + + def tensor_naming(self, tensor: Any): + """ + Generate tensor name. + + Will add prefix 'model.' for parameters + """ + name = super().tensor_naming(tensor) + if isinstance(tensor, IRSubTensor): + if tensor.is_param(): + name = 'model.' + name + return name diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index fc65dfd6..ba828921 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -150,47 +150,113 @@ class IRAdapter(IRCell): * Move: transfer the produced tensors * Merge: merge the produced tensors """ - def __init__(self, dst_tensor: IRSubTensor): - # print(f'generating adapter for: {dst_tensor}') - if not isinstance(dst_tensor, IRSubTensor): - raise RuntimeError("Expected IRSubTensor") - self.dst_tensor = dst_tensor - self._intersections = list() - - # ====== select ====== - self._select_trace = list() - self._select_ptensors = list() - - # ====== move ======= - self._move_trace = list() - - # ====== merge ======= - self._merge_trace = list() - self._gen_select() - self._gen_move() - self._gen_merge() + def __init__(self, prims, + inputs: List[IRSubTensor], idevices: List[List[int]], + outputs: List[IRSubTensor], odevices: List[List[int]]): + + self._prims = prims + self._idevices = tuple(idevices) + self._odevices = tuple(odevices) super().__init__( name='adapter', signature='adapter', - input_length=len(self._intersections), output_length=1, + input_length=len(inputs), + output_length=len(outputs), init_outputs=False ) - for idx, ptensor in enumerate(self._select_ptensors): - self.set_input(idx, ptensor) - self.set_output(0, dst_tensor) - + for idx, tensor in enumerate(inputs): + self.set_input(idx, tensor) + for idx, tensor in enumerate(outputs): + self.set_output(idx, tensor) + # set up device device = set() - for prim in self._select_trace + self._move_trace + self._merge_trace: + for prim in self._prims: device.update(prim.device) self.device = list(device) - def _gen_select(self): - otensor = self.dst_tensor - odevice = otensor.device + def prims(self, select=True, move=True, merge=True): + """ + Return prim list + """ + prims = list() + for prim in self._prims: + if select and isinstance(prim, SelectPrim): + prims.append(prim) + if move and isinstance(prim, MovePrim): + prims.append(prim) + if merge and isinstance(prim, MergePrim): + prims.append(prim) + return prims + + def dispatch(self, rank: int): + """ + Get Adapter for a specific rank + + Returns: + IRAdapter + """ + if not isinstance(rank, int): + raise TypeError(f"Expected rank to be int but got {rank}") + prims = list() + for prim in self.prims(): + if rank in prim.device: + prims.append(prim) + inputs, idevs = list(), list() + for input, devs in zip(self.inputs(), self._idevices): + if rank in devs: + inputs.append(input) + idevs.append(devs) + outputs, odevs = list(), list() + for output, devs in zip(self.outputs(), self._odevices): + if rank in devs: + outputs.append(output) + odevs.append(devs) + adapter = IRAdapter(prims, inputs, idevs, outputs, odevs) + adapter.name = self.name + adapter._id = self._id + adapter.device = rank + return adapter + + def is_identity(self): + """ + Check if the adapter does nothing + + Returns: + Boolean + """ + return len(self._prims) == 0 + + def __repr__(self): + dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' + return dscp + + def module_repr(self) -> str: + return repr(self) + + @staticmethod + def gen(dst_tensor: IRSubTensor): + # print(f'generating adapter for: {dst_tensor}') + if not isinstance(dst_tensor, IRSubTensor): + raise RuntimeError("Expected IRSubTensor") + inputs, intersections, select_prims = IRAdapter.gen_select(dst_tensor) + move_prims = IRAdapter.gen_move(dst_tensor, intersections) + merge_prims = IRAdapter.gen_merge(dst_tensor, intersections) + prims = select_prims + move_prims + merge_prims + idevs = [t.device for t in inputs] + odevs = [dst_tensor.device] + return IRAdapter(prims, inputs, idevs, [dst_tensor], odevs) + + @staticmethod + def gen_select(dst_tensor): + + inputs = list() + intersections = list() + prims = list() - # print(f'select: produced tensors: {otensor.parent.ptensors}') + otensor = dst_tensor + odevice = otensor.device local, remote = list(), list() for ptensor in otensor.parent.ptensors: @@ -198,11 +264,13 @@ def _gen_select(self): local.append(ptensor) else: remote.append(ptensor) + # check local tensor if otensor in local: - self._intersections.append(otensor) - self._select_ptensors.append(ptensor) - return + intersections.append(otensor) + inputs.append(ptensor) + return inputs, intersections, prims + # FIXME: multi producer may result overlapped region for itensor in otensor.parent.ptensors: if not itensor.overlap(otensor): @@ -211,8 +279,8 @@ def _gen_select(self): # intersection common = otensor.common(itensor) common.attach_cell(itensor._cell) - self._intersections.append(common) - self._select_ptensors.append(itensor) + intersections.append(common) + inputs.append(itensor) if common == itensor: continue @@ -242,24 +310,31 @@ def _gen_select(self): f"Not supported value select: {input.valmap} -> {common.valmap}" ) prim = SelectPrim(itensor, indmap, valmap, common.shape, common) - self._select_trace.append(prim) + prims.append(prim) + + return inputs, intersections, prims - def _gen_move(self): - odevice = self.dst_tensor.device - for tensor in self._intersections: + @staticmethod + def gen_move(dst_tensor, intersections): + prims = list() + odevice = dst_tensor.device + for tensor in intersections: if tensor.device != odevice: if len(tensor.device) != 1 or len(odevice) != 1: raise RuntimeError( f"Expected tensor on a single device but got {tensor.device} and {odevice}" ) prim = MovePrim(tensor, from_rank=tensor.device[0], to_rank=odevice[0]) - self._move_trace.append(prim) + prims.append(prim) + return prims - def _gen_merge(self): - output = self.dst_tensor - remain_tensors = copy.copy(self._intersections) + @staticmethod + def gen_merge(dst_tensor, intersections): + prims = list() + output = dst_tensor + remain_tensors = copy.copy(intersections) if output in remain_tensors: - return + return prims out = None while out != output: out = None @@ -273,14 +348,14 @@ def _gen_merge(self): if out is not None: out, concat_dim = out prim = MergePrim([tensor1, tensor2], out, output.device, concat_dim, False) - self._merge_trace.append(prim) + prims.append(prim) merged = True break # try add out = MergePrim.add(tensor1, tensor2) if out is not None: prim = MergePrim([tensor1, tensor2], out, output.device, None, True) - self._merge_trace.append(prim) + prims.append(prim) merged = True break if merged: @@ -291,57 +366,14 @@ def _gen_merge(self): # cannot merge or add if out is None: raise RuntimeError("Merge Plan not found") - - def prims(self, select=True, move=True, merge=True): - """ - Return prim list - """ - prims = list() - if select: - prims += self._select_trace - if move: - prims += self._move_trace - if merge: - prims += self._merge_trace return prims - def dispatch(self, rank: int) -> List: - """ - Get executed prim for a specific rank - - Returns: - List[Prims] - """ - if not isinstance(rank, int): - raise TypeError(f"Expected rank to be int but got {rank}") - prims = list() - for prim in self.prims(): - if rank in prims.device: - prims.append(prim) - return prims - - def is_identity(self): - """ - Check if the adapter does nothing - - Returns: - Boolean - """ - return len(self.prims()) == 0 - - def __repr__(self): - dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' - return dscp - - def module_repr(self) -> str: - return repr(self) - def extra_repr(self): """ Detailed information """ dscp = repr(self) + ':\n' # select - for prim in self._select_trace + self._move_trace + self._merge_trace: + for prim in self._select_prims + self._move_prims + self._merge_prims: dscp += '\t' + repr(prim) + '\n' return dscp diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index b217dba2..e15ca86d 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -17,16 +17,16 @@ def gen(graph: IRGraph) -> IRGraph: # skip parameter if input.is_param(): continue - adapter = IRAdapter(input) + adapter = IRAdapter.gen(input) if not adapter.is_identity(): idx = graph.nodes().index(node) graph._nodes.insert(idx, adapter) if isinstance(node, IRBpOperation): - for grad in node.grads(): + for grad in node.inputs(): if not isinstance(grad, IRSubTensor): continue # skip parameter - adapter = IRAdapter(grad) + adapter = IRAdapter.gen(grad) if not adapter.is_identity(): idx = graph.nodes().index(node) graph._nodes.insert(idx, adapter) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 6b5bd6d2..89f00f88 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -38,10 +38,10 @@ def __init__(self, if inputs is None: inputs = IRGraph.get_inputs(nodes) - inputs = [t for t in inputs if not t.is_param()] + # inputs = [t for t in inputs if not t.is_param()] if outputs is None: outputs = IRGraph.get_outputs(nodes) - outputs = [t for t in outputs if not t.is_param()] + # outputs = [t for t in outputs if not t.is_param()] super().__init__( name=module_name, @@ -189,11 +189,31 @@ def subgraph(self, sub_nodes: List[IRCell]): Return: IRGraph """ + sub_inputs = list() + sub_outputs = list() + for node in sub_nodes: + sub_inputs += node.inputs() + sub_outputs += node.outputs() + remain_inputs = list() + remain_outputs = list() + for node in self.nodes(): + if node in sub_nodes: + continue + remain_inputs += node.inputs() + remain_outputs += node.outputs() + inputs = list() + outputs = list() + for t in sub_inputs: + if isinstance(t, IRSubTensor) and t not in sub_outputs: + inputs.append(t) + for t in sub_outputs: + if isinstance(t, IRSubTensor) and t in remain_inputs: + outputs.append(t) subgraph = IRGraph( nodes = sub_nodes, - input_tensors = None, - output_tensors = None, - module_name = 'subgraph' + inputs = inputs, + outputs = outputs, + module_name = 'segment' ) return subgraph @@ -423,3 +443,6 @@ def extra_repr(self): # outputs dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" return dscp + + def module_repr(self): + return repr(self) \ No newline at end of file diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index f54e7fb2..0dbb11bd 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -108,11 +108,13 @@ def gen_backward(self): grad = None if isinstance(input, IRSubTensor): grad = input.get_grad(self) + input.grad = grad bnode.set_data(idx, input) bnode.set_output(idx, grad) for idx, output in enumerate(self.outputs()): grad = output.get_grad(self) - bnode.set_grad(idx, grad) + output.grad = grad + bnode.set_input(idx, grad) IRCell.make_pair(self, bnode) return bnode @@ -133,13 +135,19 @@ def module_repr(self) -> str: class IRBpOperation(IRCell): - def __init__(self, data_num, grad_num, name='backward'): + def __init__(self, data_num: int, grad_num, name='backward'): + """ + Args: + data_num (int): corresponding forward input length + grad_num (int): corresponding forward output length + """ signature = 'torch.autograd.backward' self.data_num = data_num self.grad_num = grad_num + self._datas = [None] * data_num super().__init__( name, signature, - input_length=data_num + grad_num, + input_length=grad_num, output_length=data_num, init_outputs=False ) @@ -163,26 +171,14 @@ def datas(self, index: Optional[int] = None) -> Union[List[Any], Any]: Forward inputs """ if index is None: - return self.inputs()[:self.data_num] + return copy.copy(self._datas[:self.data_num]) if index >= self.data_num: raise RuntimeError( f"Set the input out of range ({index} >= {self.data_num})" ) - return self.inputs(index) + return self._datas[index] - def grads(self, index: Optional[int] = None) -> Union[List[Any], Any]: - """ - backward op input gradient (a.k.a. output gradient in forward) - """ - if index is None: - return self.inputs()[self.data_num:] - elif index >= self.grad_num: - raise RuntimeError( - f"Set the input out of range ({index} >= {self.grad_num})" - ) - return self.inputs(index + self.data_num) - - def set_data(self, input_index: int, val: Any): + def set_data(self, data_index: int, val: Any): """ Set the node inputs[input_index] with the tensor @@ -192,13 +188,16 @@ def set_data(self, input_index: int, val: Any): Return: the set tensor """ - if input_index >= self.data_num: + if data_index >= self.data_num: raise RuntimeError( - f"Set the input out of range ({input_index} >= {self.data_num})" + f"Set the input out of range ({data_index} >= {self.data_num})" ) - return self.set_input(input_index, val) + val = copy.copy(val) + val.attach_cell(self) + self._datas[data_index] = val + return val - def set_grad(self, input_index: int, val: Any): + def set_input(self, input_index: int, val: Any): """ Set the node input gradient (i.e., output gradient in forward) at input index. @@ -212,11 +211,6 @@ def set_grad(self, input_index: int, val: Any): Return: The set val """ - if input_index >= self.grad_num: - raise RuntimeError( - f"Set the grad out of range ({input_index} >= {self.grad_num})" - ) - input_index += self.data_num # remove the consumer old_val = self.inputs(input_index) if isinstance(old_val, IRSubTensor): @@ -255,10 +249,10 @@ def update(self): self.set_output(idx, grad) for idx, output in enumerate(fnode.outputs()): grad = output.get_grad(fnode) - self.set_grad(idx, grad) + self.set_input(idx, grad) def __repr__(self): - dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, grads={self.grads()}, datas={self.datas()}, outputs={self.outputs()})' + dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, inputs={self.inputs()}, datas={self.datas()}, outputs={self.outputs()})' return dscp def module_repr(self) -> str: @@ -268,7 +262,7 @@ def module_repr(self) -> str: ins = [t for t in self.datas() if isinstance(t, IRSubTensor) and not t.is_param()] outs = [t.grad for t in ins] assert all([out in self.outputs() for out in outs]) - dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, grads={self.grads()}, outputs={outs})' + dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, inputs={self.inputs()}, outputs={outs})' return dscp diff --git a/cube/ir/cten.py b/cube/ir/cten.py index d6baf937..a0fa5df8 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -409,11 +409,12 @@ def __init__(self, shape=None, name=None, dtype=IRDType.unknown): self._dtype: IRDType = dtype - self._is_param = False - self._is_grad = False self._requires_grad = True + self._is_param = False - self._grad = None + self._is_grad = False + self._grad = None # the gradient of this tensor + self._data = None # the tensor of this gradient belongs to @property def requires_grad(self): @@ -490,16 +491,24 @@ def is_param(self): """ return self._is_param + @property + def data(self): + return self._data + @property def grad(self): return self._grad @grad.setter def grad(self, grad): - if grad and not isinstance(grad, IRTensor): + if grad is None: + self._grad = grad + return + elif not isinstance(grad, IRTensor): raise TypeError("grad can only be None or Tensor") - self._grad = grad self.requires_grad = True + self._grad = grad + grad._data = self def as_grad(self): self._is_param = False From d15c0084e6a26197614bffc8340ab5976dd9b11c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 28 Dec 2021 18:54:02 +0800 Subject: [PATCH 0495/1892] schedule code gen --- cube/codegen/codegen.py | 28 +++++++++++++++++++++------- cube/graph/graph.py | 10 +++++++--- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index d58c23a0..2bcd6958 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -400,14 +400,19 @@ def gen(self, device: int, outfile=None, attach=False) -> str: Generate scheduling code based on the given sus """ gencode = copy.copy(self.init_code) - device_sus = self.execplan.sequence(device) + + device_nodes = self.execplan.sequence(device) + for idx, node in enumerate(device_nodes): + if isinstance(node, IRAdapter): + node = node.dispatch(rank=device) + device_nodes[idx] = node # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: - if len(device_sus) == 0: + if len(device_nodes) == 0: fb.insert_body('pass') - for node in device_sus: + for node in device_nodes: name = self.node_naming(node) code = self.emit_node(node, name=name) fb.insert_body(code) @@ -430,8 +435,10 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: inputs = [self.tensor_naming(t) for t in node.inputs() if not t.is_param()] outputs = [self.tensor_naming(t) for t in node.outputs()] - inputs = '(' + ','.join(inputs + ['']) + ')' + inputs = '(' + ', '.join(inputs + ['']) + ')' outputs = ', '.join(outputs) + if len(outputs) == 0: + outputs = '_' if isinstance(node, IRGraph): is_backward = all([isinstance(n, IRBpOperation) for n in node.nodes()]) @@ -442,10 +449,17 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: # emit backward else: finputs = [t.data for t in node.outputs() if isinstance(t, IRSubTensor)] - finputs = '(' + ','.join(finputs + ['']) + ')' foutputs = [t.data for t in node.inputs() if isinstance(t, IRSubTensor)] - foutputs = '(' + ','.join(foutputs + ['']) + ')' - outputs = [self.tensor_naming(t) for t in node.outputs() if isinstance(t, IRSubTensor)] + outputs = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + # remove weight gradient in outputs + for input in finputs: + if input.is_param(): + outputs.remove(input.grad) + finputs = [self.tensor_naming(t) for t in finputs] + finputs = '(' + ', '.join(finputs + ['']) + ')' + foutputs = [self.tensor_naming(t) for t in foutputs] + foutputs = '(' + ', '.join(foutputs + ['']) + ')' + outputs = [self.tensor_naming(t) for t in outputs] outputs = ', '.join(outputs) body = bsign.format( input_tensors=finputs, output_tensors=foutputs, output_grads=inputs diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 89f00f88..01d8215f 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -205,10 +205,14 @@ def subgraph(self, sub_nodes: List[IRCell]): outputs = list() for t in sub_inputs: if isinstance(t, IRSubTensor) and t not in sub_outputs: - inputs.append(t) + if t not in inputs: + inputs.append(t) for t in sub_outputs: - if isinstance(t, IRSubTensor) and t in remain_inputs: - outputs.append(t) + if isinstance(t, IRSubTensor): + # not consumed or used outside this subgraph + if t not in sub_inputs or t in remain_inputs: + if t not in outputs: + outputs.append(t) subgraph = IRGraph( nodes = sub_nodes, inputs = inputs, From f4d1e5595cf778817349d15bdc670dbc380193a1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 28 Dec 2021 20:26:26 +0800 Subject: [PATCH 0496/1892] PAS for single graph --- cube/__init__.py | 2 +- cube/codegen/codegen.py | 22 +++- cube/compiler.py | 129 +++++++++++---------- cube/execplan/planpass/grouping.py | 71 ++++++++++++ cube/execplan/planpass/merge.py | 130 ---------------------- cube/graph/_gpass.py | 93 ---------------- cube/graph/adapter/adapter.py | 2 +- cube/graph/graph.py | 16 ++- cube/runtime/__init__.py | 6 +- cube/runtime/adapter/__init__.py | 11 ++ cube/runtime/{ => adapter}/collectives.py | 56 ++++------ cube/runtime/{ => adapter}/reducer.py | 0 cube/runtime/{ => adapter}/transform.py | 0 cube/runtime/executor.py | 87 ++++++++------- cube/runtime/module.py | 2 +- 15 files changed, 247 insertions(+), 380 deletions(-) create mode 100644 cube/execplan/planpass/grouping.py delete mode 100644 cube/execplan/planpass/merge.py delete mode 100644 cube/graph/_gpass.py create mode 100644 cube/runtime/adapter/__init__.py rename cube/runtime/{ => adapter}/collectives.py (82%) rename cube/runtime/{ => adapter}/reducer.py (100%) rename cube/runtime/{ => adapter}/transform.py (100%) diff --git a/cube/__init__.py b/cube/__init__.py index fdb7c42a..2effa9f0 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,7 +1,7 @@ from cube import logics from cube import runtime -# from cube.compiler import SemanticModel, compile +from cube.compiler import SemanticModel, compile def init(): diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 2bcd6958..dcb71a24 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -6,6 +6,7 @@ import copy from cube.ir.cten import IRCell, IRTensor +from cube.ir.dtype import IRDType from cube.graph.tensor import IRSubTensor from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation from cube.graph.operator.operator import IROptimOperation @@ -27,6 +28,11 @@ def __init__(self, execplan: ExectuionPlan): raise TypeError("execplan should be ExecutionPlan") self.execplan = execplan + def dtype_map(self, dtype: IRDType) -> str: + if not isinstance(dtype, IRDType): + raise TypeError("Expected IRDType") + return 'torch.' + dtype.value + def node_naming(self, node: IRCell) -> str: return f"{node.name}{node._id}" @@ -169,12 +175,13 @@ def emit_node_declare(self, node: IRCell): """ Emit tensor declaration code """ + sign = 'torch.nn.Parameter(torch.empty({shape}, dtype={dtype}))' for input in node.inputs(): name = self.tensor_naming(input) if isinstance(input, IRTensor): if input.is_param() and not self.symbols.exist(name): self.symbols.create(name) - code = f'{name} = torch.nn.Parameter(torch.empty({tuple(input.shape)}))' + code = f'{name} = {sign.format(shape=tuple(input.shape), dtype=self.dtype_map(input.dtype))}' self.declare_region.append(code) if isinstance(input, str): if name.startswith('self.'): @@ -233,7 +240,7 @@ def emit_adapter_call(self, node: IRAdapter): for prim in node.prims(): # emit select if isinstance(prim, SelectPrim): - sign = 'cube.runtime.transform.select({tensor}, {indmap}, {valmap})' + sign = 'cube.runtime.adapter.select({tensor}, {indmap}, {valmap})' input = self.tensor_naming(prim.tensor) output = self.tensor_naming(prim.output) valmap = (prim.valmap.idx, prim.valmap.chunk_num) @@ -241,8 +248,8 @@ def emit_adapter_call(self, node: IRAdapter): self.forward_region.append(code) # emit move elif isinstance(prim, MovePrim): - send_sign = 'cube.runtime.transform.send({tensor}, {send_rank})' - recv_sign = 'cube.runtime.transform.recv({shape}, {from_rank}, {dtype})' + send_sign = 'cube.runtime.adapter.send({tensor}, {send_rank})' + recv_sign = 'cube.runtime.adapter.recv({shape}, {from_rank}, {dtype})' tensor = self.tensor_naming(prim.tensor) # send if rank == prim.from_rank: @@ -251,11 +258,12 @@ def emit_adapter_call(self, node: IRAdapter): # recv elif rank == prim.to_rank: output = self.tensor_naming(prim.tensor) - code = f'{tensor} = {recv_sign.format(shape=prim.shape, from_rank=prim.from_rank, dtype=prim.dtype)}' + dtype = self.dtype_map(prim.dtype) + code = f'{tensor} = {recv_sign.format(shape=prim.shape, from_rank=prim.from_rank, dtype=dtype)}' self.forward_region.append(code) # emit merge elif isinstance(prim, MergePrim): - sign = 'cube.runtime.transformation.merge({tensors}, {concat}, {add})' + sign = 'cube.runtime.adapter.merge({tensors}, {concat}, {add})' inputs = [self.tensor_naming(t) for t in prim.tensors] inputs = '(' + ','.join(inputs + ['']) + ')' output = self.tensor_naming(prim.output) @@ -461,6 +469,8 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: foutputs = '(' + ', '.join(foutputs + ['']) + ')' outputs = [self.tensor_naming(t) for t in outputs] outputs = ', '.join(outputs) + if len(outputs) == 0: + outputs = '_' body = bsign.format( input_tensors=finputs, output_tensors=foutputs, output_grads=inputs ) diff --git a/cube/compiler.py b/cube/compiler.py index ec46dd48..60b5aa09 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,21 +1,24 @@ -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, Union import torch import time import cube + +from cube.graph import parser +from cube.graph.adapter.gen import AdapterGener from cube.graph.graph import IRGraph from cube.graph.operator.operator import IRDataOperation -from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType -from cube.schedule.translator import IRDataLoader -from cube.schedule.sugraph import SUGraph, SUGraphGener + +from cube.logics.pool import SchedulePool +from cube.logics.translator import LogicTranslator from cube.execplan import ExectuionPlan -from cube.execplan.planpass.torchadapt import TorchRefAdapter -from cube.execplan.planpass.redundant import RemoveRedundantAdapters -from cube.execplan.planpass.merge import MergeComputeSU -from cube.execplan.planpass.gfuse import WeightGradAllreduceFusion -from cube.execplan.planpass.p2pfusion import P2PFusion +# from cube.execplan.planpass.torchadapt import TorchRefAdapter +# from cube.execplan.planpass.redundant import RemoveRedundantAdapters +# from cube.execplan.planpass.merge import MergeComputeSU +# from cube.execplan.planpass.gfuse import WeightGradAllreduceFusion +# from cube.execplan.planpass.p2pfusion import P2PFusion +from cube.execplan.planpass.grouping import Grouping from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -27,7 +30,7 @@ def __init__(self, model: torch.nn.Module, input_shapes): Create semantic model based on AI Scientist description. """ from cube.graph import parser - self.ir_graph = parser.convert( + self.ir_graph = parser.convert_model( model, input_shapes=input_shapes ) self._loaded_module = None @@ -60,7 +63,7 @@ def __call__(self, *args): def compile(model: SemanticModel, dataloader, - policy: Tuple[Optional[Callable], Optional[Callable]] = (None, None)): + PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None): """ AI Scientist calls like: @@ -89,14 +92,11 @@ def train_step(model, dataloader): """ if not isinstance(model, SemanticModel): raise TypeError("Expect Semantic Model") - if len(policy) != 2: - raise TypeError( - "Expected policy to be tuple of transformation + scheduling policy" - ) - transform_policy, schedule_policy = policy + if callable(PAS): + PAS = (PAS,) - ir_graph = model.get_graph() - ir_dataloader = IRDataLoader(dataloader) + model_graph = model.get_graph() + ir_dataloader = parser.convert_dataloader(dataloader) if torch.distributed.is_initialized(): # multiple device @@ -126,54 +126,61 @@ def decorator(fn: Callable) -> Callable: resource = cube.runtime.resource.EnvResource() # logic translator - # print(f'> ir_graph:\n{ir_graph}') - fn(ir_graph, ir_dataloader) - - nodes = SchedulePool().nodes() - - # graph transformation - graph = IRGraph(nodes, None, None, ir_graph.name) - if transform_policy: - graph = transform_policy(graph, resource) + fn(model_graph, ir_dataloader) + graph = LogicTranslator.gen_logic_graph() - # sugraph - sugraph = SUGraphGener.gen_sugraph(graph.nodes()) - if schedule_policy: - sugraph = schedule_policy(sugraph, resource) + if len(PAS) == 1: + graph = PAS[0](graph, resource) + elif len(PAS) == 3: + P, A, S = PAS + graph = P(graph, resource) + graph = A(graph, resource) + graph = S(graph, resource) # check assignment and order - # print(sugraph) - for su in sugraph.sus(): - if len(su.device) == 0: - raise RuntimeError(f"SU {su} device is not set") + for node in graph.nodes(): + if len(node.device) == 0: + raise RuntimeError(f"Node {node} device is not set") # if not SUGraph.is_topo_order(sugraph.sus()): # raise RuntimeError(f"SUGraph order is not topological order") - execplan = ExectuionPlan(sugraph) + # generate adapter + graph = AdapterGener.gen(graph) + + # to execution plan + execplan = ExectuionPlan(graph) + + # plan pass for communication optimization + start = time.time() + execplan = Grouping.apply(execplan) + span = time.time() - start + print('> planpass on grouping operations: {:.2f} s'.format(span)) + + # plan pass to adapt to pytorch semantic: multi branch gradient # TODO: residual support # execplan = TorchRefAdapter.apply(execplan) # plan pass to remove redundant sus - start = time.time() - execplan = RemoveRedundantAdapters.apply(execplan) - span = time.time() - start - print('> planpass on remove redundant adapter: {:.2f} s'.format(span)) - # print(f'> after remove redundant adapters:\n {execplan}') - start = time.time() - execplan = MergeComputeSU.apply(execplan) - span = time.time() - start - print('> planpass on merge compute: {:.2f} s'.format(span)) - # print(f'> after merge backward SU:\n {execplan}') - start = time.time() - execplan = WeightGradAllreduceFusion.apply(execplan) - span = time.time() - start - print('> planpass on grad allreduce: {:.2f} s'.format(span)) + # start = time.time() + # execplan = RemoveRedundantAdapters.apply(execplan) + # span = time.time() - start + # print('> planpass on remove redundant adapter: {:.2f} s'.format(span)) + # # print(f'> after remove redundant adapters:\n {execplan}') + # start = time.time() + # execplan = MergeComputeSU.apply(execplan) + # span = time.time() - start + # print('> planpass on merge compute: {:.2f} s'.format(span)) + # # print(f'> after merge backward SU:\n {execplan}') + # start = time.time() + # execplan = WeightGradAllreduceFusion.apply(execplan) + # span = time.time() - start + # print('> planpass on grad allreduce: {:.2f} s'.format(span)) # print(f'> after add allreduce:\n{execplan}') - start = time.time() - execplan = P2PFusion.apply(execplan) - span = time.time() - start - print('> planpass on p2p fusion: {:.2f} s'.format(span)) + # start = time.time() + # execplan = P2PFusion.apply(execplan) + # span = time.time() - start + # print('> planpass on p2p fusion: {:.2f} s'.format(span)) # print(f'> after fuse P2P SU:\n {execplan}') if torch.distributed.is_initialized(): @@ -194,14 +201,14 @@ def decorator(fn: Callable) -> Callable: outfile = fname, attach=True ) + # get dataloader batch size batch_size = dict() # {devid: batch size} - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - data_op: IRDataOperation = su.nodes(0) - batch_dim = data_op.get_batch_dims()[0] - dev_batch_size = data_op.outputs(0).shape[batch_dim] - batch_size[su.device[0]] = dev_batch_size + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + batch_dim = node.get_batch_dims()[0] + dev_batch_size = node.outputs(0).shape[batch_dim] + batch_size[node.device[0]] = dev_batch_size all_batch_size = set([batch_size[dev] for dev in batch_size]) if len(all_batch_size) != 1: raise NotImplementedError("Heterogenous batch size it not supported") diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py new file mode 100644 index 00000000..405077d0 --- /dev/null +++ b/cube/execplan/planpass/grouping.py @@ -0,0 +1,71 @@ +""" +Operation grouping +""" + +from typing import List, Dict + +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.planpass import PlanPass +from cube.graph.operator.operator import IRBpOperation, IRFwOperation +from cube.graph.adapter.adapter import IRAdapter +from cube.ir.cten import IRCell + + +class Grouping(PlanPass): + + @staticmethod + def apply(execplan: ExectuionPlan) -> ExectuionPlan: + """ + Group contiguous forward and contiguous backward + into subgraph + """ + graph = execplan.graph + # step 1: group forward + adapter + groups = Grouping.group(execplan, [IRFwOperation]) + for devid in execplan.devices(): + for pieces in groups[devid]: + subgraph = graph.subgraph(pieces) + subgraph.device = devid + # update graph: replace the nodes with the subgraph + idx = graph.nodes().index(pieces[0]) + graph._nodes.insert(idx, subgraph) + for node in pieces: + graph._nodes.remove(node) + # update execution plan: replace the nodes with the subgraph + idx = execplan.sequence(devid).index(pieces[0]) + execplan.at(devid).insert(idx, subgraph) + for node in pieces: + execplan.at(devid).remove(node) + # step 2: group backward + groups = Grouping.group(execplan, [IRBpOperation]) + for devid in execplan.devices(): + for pieces in groups[devid]: + subgraph = graph.subgraph(pieces) + subgraph.device = devid + # update graph: replace the nodes with the subgraph + idx = graph.nodes().index(pieces[0]) + graph._nodes.insert(idx, subgraph) + for node in pieces: + graph._nodes.remove(node) + # update execution plan: replace the nodes with the subgraph + idx = execplan.sequence(devid).index(pieces[0]) + execplan.at(devid).insert(idx, subgraph) + for node in pieces: + execplan.at(devid).remove(node) + return execplan + + @staticmethod + def group(execplan, node_types: List) -> Dict[int, List[List[IRCell]]]: + groups = dict() + for devid in execplan.devices(): + groups[devid] = list() + pieces = list() + dev_seq = execplan.sequence(devid) + [None] + for node in dev_seq: + if all([isinstance(node, ntype) for ntype in node_types]): + pieces.append(node) + else: + if len(pieces) != 0: + groups[devid].append(pieces) + pieces = list() + return groups diff --git a/cube/execplan/planpass/merge.py b/cube/execplan/planpass/merge.py deleted file mode 100644 index 23def674..00000000 --- a/cube/execplan/planpass/merge.py +++ /dev/null @@ -1,130 +0,0 @@ -from typing import List - -from cube.execplan import ExectuionPlan -from cube.execplan.planpass.planpass import PlanPass -from cube.graph.operator.operator import IRBpOperation -from cube.ir.cten import IRCell -from cube.schedule.su import SUType, ScheduleUnit - - -class MergeComputeSU(PlanPass): - - @staticmethod - def apply(execplan: ExectuionPlan) -> ExectuionPlan: - """ - Merge consecutive backward SUs. The forward SUs will - also be merged if possible - """ - sugraph = execplan.sugraph - for devid in execplan.devices(): - dev_seq = execplan.sequence(devid) + [None] - pieces: List[ScheduleUnit] = list() - adapters: List[ScheduleUnit] = list() - for su in dev_seq: - if su and su.stype in [SUType.P2P, SUType.Transform]: - if len(pieces) > 0: - adapters.append(su) - continue - if su and su.stype in [SUType.Backward]: - allow_merge = len(pieces) == 0 - for psu in pieces[::-1]: - if sugraph.happen_before(psu, su): - allow_merge = True - break - for adapter in adapters: - if sugraph.happen_before(adapter, su): - allow_merge = False - break - # no merge adapters connected between forward SUs - if allow_merge: - fsus = [su.mirror] + [bsu.mirror for bsu in pieces] #[::-1] - if MergeComputeSU._connected_by_adapter(execplan, fsus, devid): - allow_merge = False - if allow_merge: - pieces.append(su) - continue - # merged forward su - if len(pieces) > 0: - fsus = [bsu.mirror for bsu in pieces][::-1] - if not all([fsu and (fsu in dev_seq) for fsu in fsus]): - raise RuntimeError("Expected same device fw-bw") - mfsu = MergeComputeSU._merge(fsus, devid) - mbsu = mfsu.mirror - # insert merged backward su - mbsu_idx = min([dev_seq.index(bsu) for bsu in pieces]) - for bsu in pieces: - dev_seq[dev_seq.index(bsu)] = None - dev_seq[mbsu_idx] = mbsu - # insert merged forward su - fsus_idx = [dev_seq.index(fsu) for fsu in fsus] - if max(fsus_idx) - min(fsus_idx) == len(fsus) - 1: - for fidx in fsus_idx: - dev_seq[fidx] = None - dev_seq[min(fsus_idx)] = mfsu - pieces = list() - if su and su.stype in [SUType.Backward]: - pieces = [su] - adapters = list() - dev_seq = [su for su in dev_seq if su is not None] - execplan.set(devid, dev_seq) - return execplan - - @staticmethod - def _merge(pieces: List[ScheduleUnit], devid: int) -> ScheduleUnit: - """ - Merge a list of SU into one. - """ - if len(pieces) == 1: - return pieces[0] - fnodes = list() - for fsu in pieces: - fnodes += fsu.nodes() - # TODO: fix multi-branch - mfsu = ScheduleUnit(fnodes, SUType.Forward, name='fsu') - mfsu.device = devid - - # merged backward su - mbnode = IRBpOperation( - data_num=len(mfsu.inputs()), - grad_num=len(mfsu.outputs()) - ) - for idx, fin in enumerate(mfsu.inputs()): - mbnode.set_data(idx, fin) - mbnode.set_output(idx, fin.grad) - for idx, fout in enumerate(mfsu.outputs()): - mbnode.set_grad(idx, fout.grad) - mbsu = ScheduleUnit([mbnode], SUType.Backward, name='bsu') - mbsu.device = devid - - IRCell.make_pair(mfsu, mbsu) - return mfsu - - @staticmethod - def _connected_by_adapter(execplan: ExectuionPlan, fpieces, devid: int): - """ - Check if there is an adapter connecting forward SUs - """ - sugraph = execplan.sugraph - indmap = [execplan.sequence(devid).index(fsu) for fsu in fpieces] - start = min(indmap) - end = max(indmap) - # check fsu1 -> asu -> fsu2 - for asu in execplan.sequence(devid)[start:end]: - if asu.stype in [SUType.P2P, SUType.Transform, SUType.Coll]: - happen_before = False - happen_after = False - # fsu1 -> asu - for fsu1 in fpieces: - if sugraph.happen_before(fsu1, asu): - happen_after = True - break - if not happen_after: - continue - # asu -> fsu2 - for fsu2 in fpieces: - if sugraph.happen_before(asu, fsu2): - happen_before = True - break - if happen_before and happen_after: - return True - return False diff --git a/cube/graph/_gpass.py b/cube/graph/_gpass.py deleted file mode 100644 index 9af2a4bd..00000000 --- a/cube/graph/_gpass.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Any -import copy - -from cube.graph.graph import IRGraph -from cube.graph.tensor import IRSubTensor, ValueMap -from cube.graph.operator import IRFwOperation, IRBpOperation - -from cube.ir.cten import IRCell, IRTensor - - -__all__ = ['forward'] - - -class _TensorGener: - - def __init__(self): - self.symbol = dict() - - def renew(self, val: Any, keep_param=True): - self._check_is_sub_tensor(val) - if not isinstance(val, IRTensor): - return val - if keep_param and val.is_param(): - return val - if val.parent._id not in self.symbol: - self.symbol[val.parent._id] = val.parent.like() - new_val = self.symbol[val.parent._id].select( - indmap=val.indmap, - valmap=val.valmap, - shape=val.shape - ) - return new_val - - def set_map(self, origin: Any, new: Any): - self._check_is_sub_tensor(origin) - self._check_is_sub_tensor(new) - if isinstance(origin, IRSubTensor): - tid = origin.parent._id - if isinstance(new, IRSubTensor): - self.symbol[tid] = new.parent - return - self.symbol[tid] = new - - def _check_is_sub_tensor(self, tensor): - if isinstance(tensor, IRTensor): - if not isinstance(tensor, IRSubTensor): - raise TypeError("Tensor only allows to be SubTensor") - - -def forward(graph, *args) -> IRGraph: - """ - Forward the IRGraph, replacing all the intermediate tensors - """ - if not isinstance(graph, IRGraph): - raise TypeError("Forwarding requires IRGraph") - - gener = _TensorGener() - - for input, arg in zip(graph.inputs(), args): - gener.set_map(input, arg) - - fnodes = list() - - # generate forward nodes - for node in graph.nodes(): - inputs = node.inputs() - outputs = node.outputs() - # fnode = copy.copy(node) - fnode : IRFwOperation = node - fnode._inputs = inputs - fnode._outputs = outputs - # set forward inputs - for idx, val in enumerate(inputs): - fnode.set_input(idx, gener.renew(val)) - # set forward outputs - for idx, val in enumerate(outputs): - fnode.set_output(idx, gener.renew(val)) - fnodes.append(fnode) - - # reverse is only to make op id looks consecutive - for fnode in graph.nodes()[::-1]: - fnode.gen_backward() - - inputs = [gener.renew(input) for input in graph.inputs()] - outputs = [gener.renew(output) for output in graph.outputs()] - - for idx, input in enumerate(inputs): - graph.set_input(idx, input) - for idx, output in enumerate(outputs): - graph.set_output(idx, output) - - # fgraph = IRGraph(fnodes, inputs, outputs, graph.name) - return graph diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index ba828921..0a594675 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -268,7 +268,7 @@ def gen_select(dst_tensor): # check local tensor if otensor in local: intersections.append(otensor) - inputs.append(ptensor) + inputs.append(otensor) return inputs, intersections, prims # FIXME: multi producer may result overlapped region diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 01d8215f..d7c81fe0 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -287,6 +287,7 @@ def replicate(self, op: IRCell, times=1): This is temporary use to enable assign with multiple devices """ + raise NotImplementedError("Replicate is not supported yet") if not isinstance(op, IRCell): raise TypeError("Expected an IRCell") if not isinstance(times, int) or times < 1: @@ -296,22 +297,19 @@ def replicate(self, op: IRCell, times=1): raise RuntimeError(f"Op {op} not exsits") ops = [op] - mirror_ops = [op.mirror] for _ in range(times - 1): - cpy_op = op.replicate() + dup_op = op.replicate() if op.mirror is not None: - cpy_mirror_op = op.mirror.replicate() - mirror_ops.append(cpy_mirror_op) - IRCell.make_pair(cpy_op, cpy_mirror_op) - ops.append(cpy_op) + dup_op.gen_backward() + ops.append(dup_op) idx = self.nodes().index(op) # forward self._nodes = self._nodes[:idx] + ops + self._nodes[idx+1:] # backward - if op.mirror: - mirror_ops = mirror_ops[::-1] + if op.mirror is not None: + bops = [op.mirror for op in ops][::-1] midx = self.nodes().index(op.mirror) - self._nodes = self._nodes[:midx] + mirror_ops + self._nodes[midx+1:] + self._nodes = self._nodes[:midx] + bops + self._nodes[midx+1:] self.reset_dependency() return ops diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index 47dd5028..1929a791 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -1,7 +1,7 @@ -from cube.runtime import collectives, executor, transform +from cube.runtime import executor from cube.runtime import device -from cube.runtime import reducer +from cube.runtime import adapter from cube.runtime import syndata from cube.runtime import resource from cube.runtime import module -from cube.runtime import function \ No newline at end of file +from cube.runtime import function diff --git a/cube/runtime/adapter/__init__.py b/cube/runtime/adapter/__init__.py new file mode 100644 index 00000000..0e4324f7 --- /dev/null +++ b/cube/runtime/adapter/__init__.py @@ -0,0 +1,11 @@ +# communications +from cube.runtime.adapter.collectives import send, recv +from cube.runtime.adapter.collectives import all_gather, all_reduce +from cube.runtime.adapter.collectives import reduce_scatter, broadcast + +# transformations +from cube.runtime.adapter.transform import select +from cube.runtime.adapter.transform import merge + +# reducer +from cube.runtime.adapter.reducer import Reducer diff --git a/cube/runtime/collectives.py b/cube/runtime/adapter/collectives.py similarity index 82% rename from cube/runtime/collectives.py rename to cube/runtime/adapter/collectives.py index 245df862..2d0f23ed 100644 --- a/cube/runtime/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -5,7 +5,7 @@ from cube.profiler.timer import CudaTimer -def send(tensors: List[torch.Tensor], to_ranks: List[int]): +def send(tensor: torch.Tensor, to_rank: int): """ send tensor to the remote devices. Each tensor can be sent to multiple devices @@ -16,18 +16,14 @@ def send(tensors: List[torch.Tensor], to_ranks: List[int]): """ # print(f'{torch.distributed.get_rank()}: sending...') CudaTimer().start(field_name='comm') + send_ops = list() - - ## synthetic ## - # return - - for tensor, rank in zip(tensors, to_ranks): - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, rank - ) - send_ops.append(send_op) + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, to_rank + ) + send_ops.append(send_op) reqs = torch.distributed.batch_isend_irecv(send_ops) for req in reqs: req.wait() @@ -35,37 +31,29 @@ def send(tensors: List[torch.Tensor], to_ranks: List[int]): CudaTimer().stop(field_name='comm') -def recv(shapes: List[List[int]], from_ranks: List[int]): - CudaTimer().start(field_name='comm') +def recv(shape: List[int], from_rank: int, dtype: torch.dtype): # print(f'{torch.distributed.get_rank()}: recving...') - recv_ops = list() - recv_tensors = list() - + CudaTimer().start(field_name='comm') ## synthetic ## # for shape in shapes: # recv_tensors.append( # torch.ones(tuple(shape), # device=torch.cuda.current_device() # )) - for shape, rank in zip(shapes, from_ranks): - tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device() - ) - recv_tensors.append(tensor) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, rank - ) - recv_ops.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(recv_ops) + # + tensor = torch.empty( + shape, requires_grad=True, dtype=dtype, + device=torch.cuda.current_device() + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, from_rank + ) + reqs = torch.distributed.batch_isend_irecv([recv_op]) for req in reqs: req.wait() torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - - if len(recv_tensors) == 0: return None - elif len(recv_tensors) == 1: return recv_tensors[0] - else: return tuple(recv_tensors) + return tensor def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): @@ -172,7 +160,7 @@ def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): return output -def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None): +def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None, dtype=None): """ Broadcast. ranks[0] is the root """ @@ -182,7 +170,7 @@ def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None): if len(tensors) == 1: tensor = tensors[0] else: - tensor = torch.empty(shape, device=torch.cuda.current_device()) + tensor = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype) # tensor.requires_grad_() group = DeviceGroup().get_group(ranks) torch.distributed.broadcast(tensor, ranks[0], group=group) diff --git a/cube/runtime/reducer.py b/cube/runtime/adapter/reducer.py similarity index 100% rename from cube/runtime/reducer.py rename to cube/runtime/adapter/reducer.py diff --git a/cube/runtime/transform.py b/cube/runtime/adapter/transform.py similarity index 100% rename from cube/runtime/transform.py rename to cube/runtime/adapter/transform.py diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index aa61c1f4..27c94769 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -1,72 +1,77 @@ r""" -SU Executor for runtime +Executor for runtime """ from typing import Tuple, Any, Callable, List import torch -def fexecute(su: Callable, *input_tensors: Tuple[Any]): +def fexecute(subgraph: Callable, *input_tensors: Tuple[Any]): """ - forward the SUs + forward the sub-graph. """ - outputs = su(*input_tensors) + outputs = subgraph(*input_tensors) # print('forwarding... ') return outputs def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_grads): """ - Backward the SUs - """ - # for tensor in input_tensors: - # if torch.is_tensor(tensor) and tensor.requires_grad: - # tensor.retain_grad() + Backward Procedure. + + input_tensors: List[torch.Tensor]: + tensors that their gradient need to be computed, including parameters. + Correspoinding forward input tensors. + + output_tensors: + tensors that start for gradient backward computation. + Corresponding to forward output tensors. - if len(output_tensor_grads) != len(output_tensors): - raise RuntimeError( - "Expected same length of out tensors and grads" - ) + output_tensor_grads: + gradient tensors corresponding to output_tensors. + Returns: + gradient in order of non-parameter tensors in input_tensors. + (Note parameter tnesors already have gradient accumulated at .grad attribute) + """ + # print(f'{torch.distributed.get_rank()}: backwarding... ') inputs = list() - indmap = list() - for idx, input in enumerate(input_tensors): - if torch.is_tensor(input) and input.requires_grad: - inputs.append(input) - indmap.append(idx) + for input in enumerate(input_tensors): + # skip returning gradients of parameters + if torch.is_tensor(input) and not isinstance(input, torch.nn.Parameter): + inputs.append(inputs) - grads = [None] * len(input_tensors) + grads = list() if len(inputs) != 0: - # print(f'{torch.distributed.get_rank()}: backwarding... ') in_grads = torch.autograd.grad( output_tensors, inputs, output_tensor_grads, allow_unused=True) - for idx, grad in zip(indmap, in_grads): - tensor = input_tensors[idx] + for tensor, grad in zip(inputs, in_grads): if isinstance(tensor, torch.nn.Parameter): if tensor.grad is not None: tensor.grad += grad else: tensor.grad = grad - grads[idx] = grad - - # if len(inputs) != 0: - # torch.autograd.backward( - # output_tensors, - # grad_tensors=output_tensor_grads, - # inputs=inputs - # ) - # for idx, tensor in zip(indmap, inputs): - # grads[idx] = tensor.grad - - # torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) - # grads = list() - # for tensor in input_tensors: - # # print('backward input tensor: {}'.format(tensor)) - # if torch.is_tensor(tensor) and tensor.requires_grad: - # grads.append(tensor.grad) - # else: - # grads.append(None) + else: + grads.append(grad) if len(grads) == 0: return None elif len(grads) == 1: return grads[0] else: return tuple(grads) + + +def backwardV2(input_tensors: List[torch.Tensor], output_tensors, output_tensor_grads): + inputs = list() + for input in enumerate(input_tensors): + # skip returning parameters + if torch.is_tensor(input) and not isinstance(input, torch.nn.Parameter): + inputs.append(inputs) + for tensor in input_tensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + tensor.retain_grad() + torch.autograd.backward( + output_tensors, + grad_tensors=output_tensor_grads, + inputs=input_tensors + ) + grads = [input.grad for input in inputs] + return grads diff --git a/cube/runtime/module.py b/cube/runtime/module.py index d87dbaad..46d69a36 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,7 +1,7 @@ from typing import List import torch from cube.runtime.device import DeviceGroup -from cube.runtime.reducer import Reducer +from cube.runtime.adapter.reducer import Reducer class CubeModule(torch.nn.Module): From 32382d77923a9c22a47191e183e5a5f3a3316f49 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 29 Dec 2021 11:23:22 +0800 Subject: [PATCH 0497/1892] codegen for p2p --- cube/codegen/codegen.py | 50 +++++++---- cube/execplan/planpass/grouping.py | 136 +++++++++++++++++++---------- cube/graph/adapter/gen.py | 6 ++ cube/graph/tensor.py | 2 + cube/ir/cten.py | 5 +- cube/runtime/executor.py | 37 ++++---- 6 files changed, 152 insertions(+), 84 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index dcb71a24..1af78a18 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -41,7 +41,7 @@ def tensor_naming(self, tensor: Any) -> str: Return the var name (unique for different variable) """ if isinstance(tensor, IRTensor): - tensor_name = 'tensor' if tensor.name is None else tensor.name + tensor_name = tensor.name if '.' in tensor_name: tensor_name = tensor_name.split('.')[0] name = '_'.join([tensor_name, str(tensor._id)]) @@ -271,6 +271,12 @@ def emit_adapter_call(self, node: IRAdapter): self.forward_region.append(code) else: raise TypeError(f"Unkown primitive types {type(prim)} of Adapter") + # requires grad generation + sign = '{output} = {output}.contiguous().requires_grad_()' + for output in node.outputs(): + if isinstance(output, IRSubTensor): + code = sign.format(output=self.tensor_naming(output)) + self.forward_region.append(code) # def emit_comm_call(self, node): # """ @@ -443,10 +449,8 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: inputs = [self.tensor_naming(t) for t in node.inputs() if not t.is_param()] outputs = [self.tensor_naming(t) for t in node.outputs()] - inputs = '(' + ', '.join(inputs + ['']) + ')' - outputs = ', '.join(outputs) - if len(outputs) == 0: - outputs = '_' + inputs = self.tuple_naming(inputs) + outputs = self.return_naming(outputs) if isinstance(node, IRGraph): is_backward = all([isinstance(n, IRBpOperation) for n in node.nodes()]) @@ -456,21 +460,18 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: code = f'{outputs} = {body}' # emit backward else: - finputs = [t.data for t in node.outputs() if isinstance(t, IRSubTensor)] - foutputs = [t.data for t in node.inputs() if isinstance(t, IRSubTensor)] - outputs = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + finputs = [t for t in node.mirror.inputs() if t.requires_grad] + foutputs = node.mirror.outputs() + inputs = [t.grad for t in foutputs] + outputs = [t.grad for t in finputs] # remove weight gradient in outputs for input in finputs: if input.is_param(): outputs.remove(input.grad) - finputs = [self.tensor_naming(t) for t in finputs] - finputs = '(' + ', '.join(finputs + ['']) + ')' - foutputs = [self.tensor_naming(t) for t in foutputs] - foutputs = '(' + ', '.join(foutputs + ['']) + ')' - outputs = [self.tensor_naming(t) for t in outputs] - outputs = ', '.join(outputs) - if len(outputs) == 0: - outputs = '_' + finputs = self.tuple_naming(finputs) + foutputs = self.tuple_naming(foutputs) + inputs = self.tuple_naming(inputs) + outputs = self.return_naming(outputs) body = bsign.format( input_tensors=finputs, output_tensors=foutputs, output_grads=inputs ) @@ -480,8 +481,8 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: if len(node.inputs()) != 0: raise RuntimeError("Expect Dataloader node has no inputs") outputs = [self.tensor_naming(output) for output in node.outputs()] - return_val = ','.join(outputs) - code = f'{return_val} = next(dataloader)' + outputs = self.return_naming(outputs) + code = f'{outputs} = next(dataloader)' elif isinstance(node, IRAdapter): body = fsign.format(model=f'model.{name}', inputs=inputs) @@ -491,6 +492,19 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: raise RuntimeError(f"Unspported node type: {type(node)}") return code + def tuple_naming(self, tensors: List[Any]) -> str: + tensors = [self.tensor_naming(t) for t in tensors] + tensors = '(' + ', '.join(tensors + ['']) + ')' + return tensors + + def return_naming(self, tensors: List[Any]) -> str: + tensors = [self.tensor_naming(t) for t in tensors] + if len(tensors) == 0: + tensors = '_' + else: + tensors = ', '.join(tensors) + return tensors + def tensor_naming(self, tensor: Any): """ Generate tensor name. diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 405077d0..3ccb17da 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -2,12 +2,11 @@ Operation grouping """ -from typing import List, Dict +from typing import List, Dict, Tuple from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass from cube.graph.operator.operator import IRBpOperation, IRFwOperation -from cube.graph.adapter.adapter import IRAdapter from cube.ir.cten import IRCell @@ -20,52 +19,97 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: into subgraph """ graph = execplan.graph - # step 1: group forward + adapter - groups = Grouping.group(execplan, [IRFwOperation]) + fgroups, bgroups = Grouping.group(execplan) for devid in execplan.devices(): - for pieces in groups[devid]: - subgraph = graph.subgraph(pieces) - subgraph.device = devid - # update graph: replace the nodes with the subgraph - idx = graph.nodes().index(pieces[0]) - graph._nodes.insert(idx, subgraph) - for node in pieces: - graph._nodes.remove(node) - # update execution plan: replace the nodes with the subgraph - idx = execplan.sequence(devid).index(pieces[0]) - execplan.at(devid).insert(idx, subgraph) - for node in pieces: - execplan.at(devid).remove(node) - # step 2: group backward - groups = Grouping.group(execplan, [IRBpOperation]) - for devid in execplan.devices(): - for pieces in groups[devid]: - subgraph = graph.subgraph(pieces) - subgraph.device = devid - # update graph: replace the nodes with the subgraph - idx = graph.nodes().index(pieces[0]) - graph._nodes.insert(idx, subgraph) - for node in pieces: - graph._nodes.remove(node) - # update execution plan: replace the nodes with the subgraph - idx = execplan.sequence(devid).index(pieces[0]) - execplan.at(devid).insert(idx, subgraph) - for node in pieces: - execplan.at(devid).remove(node) + for fpieces, bpieces in zip(fgroups[devid], bgroups[devid]): + fsubgraph = graph.subgraph(fpieces) + fsubgraph.device = devid + if bpieces is not None: + bsubgraph = graph.subgraph(bpieces) + bsubgraph.device = devid + IRCell.make_pair(fsubgraph, bsubgraph) + subgraphs = [fsubgraph] if bpieces is None else [fsubgraph, bsubgraph] + for subgraph in subgraphs: + pieces = subgraph.nodes() + # update graph: replace the nodes with the subgraph + idx = graph.nodes().index(pieces[0]) + graph._nodes.insert(idx, subgraph) + for node in pieces: + graph._nodes.remove(node) + # update execution plan: replace the nodes with the subgraph + idx = execplan.sequence(devid).index(pieces[0]) + execplan.at(devid).insert(idx, subgraph) + for node in pieces: + execplan.at(devid).remove(node) return execplan @staticmethod - def group(execplan, node_types: List) -> Dict[int, List[List[IRCell]]]: - groups = dict() + def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: + """ + Return forward groups and corresponding + backward groups for each device. + + Each group can be indexed by device id. + Each device id contains a list of forward / backward operations + + Returns: + Tuple: (fgroups, bgroups) + """ + fgroups, bgroups = dict(), dict() for devid in execplan.devices(): - groups[devid] = list() - pieces = list() - dev_seq = execplan.sequence(devid) + [None] - for node in dev_seq: - if all([isinstance(node, ntype) for ntype in node_types]): - pieces.append(node) - else: - if len(pieces) != 0: - groups[devid].append(pieces) - pieces = list() - return groups + fgroups[devid], bgroups[devid] = list(), list() + fpieces, bpieces = list(), list() + seq = execplan.sequence(devid) + fnodes = [fnode for fnode in seq if isinstance(fnode, IRFwOperation)] + have_backward = all( + [isinstance(fnode.mirror, IRBpOperation) for fnode in fnodes] + ) + # training + if have_backward: + bnodes = [fnode.mirror for fnode in fnodes] + for fnode, bnode in zip(fnodes + [-1], bnodes + [-1]): + fconsecutive = Grouping.consecutive(seq, fpieces, fnode) + bconsecutive = Grouping.consecutive(seq, bpieces, bnode) + if fconsecutive and bconsecutive: + fpieces.append(fnode) + bpieces.insert(0, bnode) + else: + if len(fpieces) != 0: + fgroups[devid].append(fpieces) + bgroups[devid].append(bpieces) + fpieces, bpieces = [fnode], [bnode] + # inference + else: + for fnode in fnodes: + fconsecutive = Grouping.consecutive(seq, fpieces, fnode) + if fconsecutive: + fpieces.append(fnode) + else: + if len(fpieces) != 0: + fgroups[devid].append(fpieces) + bgroups[devid].append(None) + fpieces, bpieces = [fnode], list() + return fgroups, bgroups + + @staticmethod + def consecutive(seq: List[IRCell], pieces: List[IRCell], node: IRCell): + """ + Check whether the piecies with new node + is consecutive in the sequence. + + Assume all the node in pieces will apear in seq. + If node not in the sequence, will return False. + """ + if len(pieces) == 0: + return True + if node not in seq: + return False + idx = seq.index(node) + pidx = [seq.index(pnode) for pnode in pieces] + # check whether pieces is consecutive + if max(pidx) - min(pidx) != len(pidx) - 1: + return False + # check whether new node adding new node is consecutive + if idx != max(pidx) + 1 and idx != min(pidx) - 1: + return False + return True diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index e15ca86d..02de2ab6 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -1,4 +1,5 @@ +from numpy import isin from cube.graph.graph import IRGraph from cube.graph.adapter.adapter import IRAdapter from cube.graph.operator.operator import IRBpOperation, IRFwOperation @@ -9,6 +10,11 @@ class AdapterGener: @staticmethod def gen(graph: IRGraph) -> IRGraph: + # update the gradient before generate adapter + for node in graph.nodes(): + if isinstance(node, IRBpOperation): + node.update() + # generate adapter for node in graph.nodes(): if isinstance(node, IRFwOperation): for input in node.inputs(): diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 8099a12b..271c4ae8 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -582,9 +582,11 @@ def get_grad(self, fcell: IRCell): forward cell """ if not self.requires_grad: + self.grad = None return None full_grad = self.parent.grad if full_grad is None: + self.grad = None return None if self in fcell.inputs(): ref_cells = list() diff --git a/cube/ir/cten.py b/cube/ir/cten.py index a0fa5df8..43f0c576 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -394,11 +394,14 @@ def __repr__(self): class IRTensor: """ IRTensor serves as IRGraph edge + + Note by setting IRTensor name to "None" indicates this tensor holds nothing + and will be translated to None in code generation. """ _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad', '_dtype'] - def __init__(self, shape=None, name=None, dtype=IRDType.unknown): + def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown): self._id: int = IDGenerator().gen_tensor_id() self._shape: Optional(List[int]) = shape diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 27c94769..809b29ed 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -15,7 +15,9 @@ def fexecute(subgraph: Callable, *input_tensors: Tuple[Any]): return outputs -def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_grads): +def backward(input_tensors : List[torch.Tensor], + output_tensors: List[torch.Tensor], + output_tensor_grads: List[torch.Tensor]): """ Backward Procedure. @@ -34,26 +36,23 @@ def backward(input_tensors: List[torch.Tensor], output_tensors, output_tensor_gr gradient in order of non-parameter tensors in input_tensors. (Note parameter tnesors already have gradient accumulated at .grad attribute) """ - # print(f'{torch.distributed.get_rank()}: backwarding... ') - inputs = list() - for input in enumerate(input_tensors): - # skip returning gradients of parameters - if torch.is_tensor(input) and not isinstance(input, torch.nn.Parameter): - inputs.append(inputs) - + if len(input_tensors) == 0: + return None grads = list() - if len(inputs) != 0: - in_grads = torch.autograd.grad( - output_tensors, inputs, output_tensor_grads, allow_unused=True) - for tensor, grad in zip(inputs, in_grads): - if isinstance(tensor, torch.nn.Parameter): - if tensor.grad is not None: - tensor.grad += grad - else: - tensor.grad = grad + in_grads = torch.autograd.grad( + outputs = output_tensors, + inputs = input_tensors, + grad_outputs = output_tensor_grads, + allow_unused=True + ) + for tensor, grad in zip(input_tensors, in_grads): + if isinstance(tensor, torch.nn.Parameter): + if tensor.grad is not None: + tensor.grad += grad else: - grads.append(grad) - + tensor.grad = grad + else: + grads.append(grad) if len(grads) == 0: return None elif len(grads) == 1: return grads[0] else: return tuple(grads) From 1f36c072bd60ba0c1620783de7532ddd49f9eb72 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 29 Dec 2021 14:05:06 +0800 Subject: [PATCH 0498/1892] add adapter for weight reducer --- cube/codegen/codegen.py | 152 ++++++++------------------------ cube/execplan/planpass/gfuse.py | 87 ------------------ cube/graph/adapter/adapter.py | 32 +++++-- cube/graph/adapter/gen.py | 63 ++++++++++++- cube/graph/operator/operator.py | 21 +---- 5 files changed, 123 insertions(+), 232 deletions(-) delete mode 100644 cube/execplan/planpass/gfuse.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 1af78a18..29269e2a 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,7 +1,7 @@ """ Generate Pytorch code given the model DAG and the transformation config """ -from typing import List, Any +from typing import Dict, List, Any, Tuple import torch import copy @@ -9,8 +9,8 @@ from cube.ir.dtype import IRDType from cube.graph.tensor import IRSubTensor from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.graph.operator.operator import IROptimOperation from cube.graph.adapter.adapter import IRAdapter, SelectPrim, MovePrim, MergePrim +from cube.graph.adapter.adapter import IRWeightReducer from cube.execplan import ExectuionPlan # from cube.schedule.adapter.collectives import IRCollectives @@ -70,26 +70,29 @@ def __init__(self, execplan: ExectuionPlan): self.symbols = SymbolTable() # ref module to check shared variables self._ref_module = torch.nn.Module() - # groups - self._all_comm_groups = list() - # self.get_all_groups() - def get_all_groups(self): + def init_comm_groups(self): """ Get all communication groups. Creating communication group requires all the devices enter the same call. """ - raise NotImplementedError - for devid in self.execplan.devices(): - for su in self.execplan.sequence(devid): - if su.stype == SUType.Coll: - ranks = list(su.nodes(0).ranks) - ranks.sort() - ranks = tuple(ranks) - if ranks not in self._all_comm_groups: - self._all_comm_groups.append(ranks) + sign = 'self.init_group(ranks={ranks})' + # collect groups from weight reducer + comm_groups: Dict[Tuple[int]] = list() + for node in self.execplan.graph.nodes(): + if isinstance(node, IRWeightReducer): + ranks = list(node.device) + ranks.sort() + ranks = tuple(ranks) + if ranks not in comm_groups: + comm_groups.append(ranks) + # TODO: collect groups from p2p fusion + # create communication group + for ranks in comm_groups: + code = sign.format(ranks=list(ranks)) + self.declare_region.append(code) def gen(self, device: int, outfile=None, attach=False) -> str: """ @@ -99,8 +102,8 @@ def gen(self, device: int, outfile=None, attach=False) -> str: node_args: List[List[str]] = list() gen_nodes: List[IRCell] = list() - # TODO init group - # self.emit_comm_group_creation() + # initialize communication groups + self.init_comm_groups() # parse graph body for node in self.execplan.sequence(device): @@ -116,9 +119,9 @@ def gen(self, device: int, outfile=None, attach=False) -> str: self.emit_adapter_call(node) # elif isinstance(node, IRCollectives): # self.emit_collective_call(node) - elif isinstance(node, IROptimOperation): - self.emit_optim_init(node) - self.emit_optim_call(node) + elif isinstance(node, IRWeightReducer): + self.emit_reducer_init(node) + self.emit_reducer_call(node) elif isinstance(node, IRBpOperation): continue elif isinstance(node, IRDataOperation): @@ -198,17 +201,6 @@ def emit_graph_call(self, graph: IRGraph): raise RuntimeError("IRBpOperation is not expected in GenModel") self.emit_op_call(node) - - # def emit_comm_group_creation(self): - # """ - # Emit communication group creation code - # """ - # sign = 'self.init_group(ranks={ranks})' - # for ranks in self._all_comm_groups: - # ranks = list(ranks) - # code = sign.format(ranks=ranks) - # self.declare_region.append(code) - def emit_op_call(self, node: IRFwOperation): """ Emit op forward code @@ -278,98 +270,26 @@ def emit_adapter_call(self, node: IRAdapter): code = sign.format(output=self.tensor_naming(output)) self.forward_region.append(code) - # def emit_comm_call(self, node): - # """ - # Emit communication code - # """ - # comm_code = node.signature - # send_tensors = self._forward_region_arg_names(node.inputs()) - # send_tensors = '(' + ', '.join(send_tensors + ['']) + ')' - # send_ranks = node.send_ranks - # recv_tensors = self._forward_region_arg_names(node.outputs()) - # recv_tensors = ', '.join(recv_tensors) - # recv_shapes = [tensor.shape for tensor in node.outputs()] - # recv_ranks = node.recv_ranks - # if node.comm_type == IRCommType.Send: - # code = f'{comm_code}({send_tensors}, {send_ranks})' - # elif node.comm_type == IRCommType.Recv: - # code = f'{recv_tensors} = {comm_code}({recv_shapes}, {recv_ranks})' - # elif node.comm_type == IRCommType.SendRecv: - # code = f'{recv_tensors} = {comm_code}({send_tensors}, {send_ranks}, {recv_shapes}, {recv_ranks})' - # else: - # raise TypeError(f"Unsupported IRCommmNode: {node.comm_type}") - # self.forward_region.append(code) - - # def emit_collective_call(self, node): - # ranks = node.ranks - # inputs = self._forward_region_arg_names(node.inputs()) - # shape = None - # if len(inputs) == 0: - # assert len(node.outputs()) == 1 - # shape = node.outputs(0).shape - # inputs = '(' + ', '.join(inputs + ['']) + ')' - # outputs = self._forward_region_arg_names(node.outputs()) - # outputs = ', '.join(outputs) - # if shape: - # code = f'{node.signature}({inputs}, {ranks}, {shape})' - # else: - # code = f'{node.signature}({inputs}, {ranks})' - # if outputs: - # code = f'{outputs} = {code}' - # self.forward_region.append(code) - - # def emit_transform_call(self, node): - # """ - # Emit in-device tensor select / merge call. - # """ - # for prim in node.trace(): - # if isinstance(prim, SelectPrim): - # signature = 'cube.runtime.transform.select({tensor}, {indmap}, {valmap})' - # input = self.tensor_naming(prim.tensor) - # indmap = repr(prim.indmap) - # valmap = repr(tuple([prim.valmap.idx, prim.valmap.chunk_num])) - # output = self.tensor_naming(prim.output) - # code = f'{output} = {signature.format(tensor=input, indmap=indmap, valmap=valmap)}' - # self.forward_region.append(code) - # elif isinstance(prim, MergePrim): - # signature = 'cube.runtime.transform.merge({tensors}, {concat}, {add})' - # inputs = self._forward_region_arg_names(prim.tensors) - # inputs = '(' + ', '.join(inputs) + ')' - # output = self.tensor_naming(prim.output) - # code = f'{output} = {signature.format(tensors=inputs, concat=prim.concat, add=prim.add)}' - # self.forward_region.append(code) - # else: - # raise RuntimeError(f"Not supported prim: {type(prim)}") - # for output in node.outputs(): - # # contiguous and requires grad - # output_name = self.tensor_naming(output) - # code = f'{output_name} = {output_name}.contiguous()' - # self.forward_region.append(code) - # if not output.is_grad(): - # code = f'{output_name} = {output_name}.requires_grad_()' - # self.forward_region.append(code) - - def emit_optim_init(self, node: IROptimOperation): + def emit_reducer_init(self, node: IRWeightReducer): # reducer init interface - reducer_init = '{reducer} = cube.runtime.reducer.Reducer(ranks={ranks})' + reducer_init = '{reducer} = cube.runtime.adapter.Reducer(ranks={ranks})' reducer_add = 'self.add_reducer({reducer})' - add_param = '{reducer}.add_param({grad})' + add_param = '{reducer}.add_param({weight})' # create reducer in declare region - ranks = list(node.ranks) - grads = node.inputs() - reducer_name = f'self.reducer{node._id}' + weights = node.inputs() + reducer_name = f'self.wreducer{node._id}' self.declare_region.append('') - init_code = reducer_init.format(reducer=reducer_name, ranks=ranks) + init_code = reducer_init.format(reducer=reducer_name, ranks=node.device) self.declare_region.append(init_code) - grads = [self.tensor_naming(t) for t in grads] - for grad in grads: - add_param_code = add_param.format(reducer=reducer_name, grad=grad) + weights = [self.tensor_naming(t) for t in weights] + for weight in weights: + add_param_code = add_param.format(reducer=reducer_name, weight=weight) self.declare_region.append(add_param_code) add_code = reducer_add.format(reducer=reducer_name) self.declare_region.append(add_code) - def emit_optim_call(self, node: IROptimOperation): - reducer_name = f'self.reducer{node._id}' + def emit_reducer_call(self, node: IRWeightReducer): + reducer_name = f'self.wreducer{node._id}' call_code = f'{reducer_name}.allreduce()' self.forward_region.append(call_code) @@ -488,6 +408,10 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: body = fsign.format(model=f'model.{name}', inputs=inputs) code = f'{outputs} = {body}' + elif isinstance(node, IRWeightReducer): + body = fsign.format(model=f'model.{name}', inputs='()') + code = f'{outputs} = {body}' + else: raise RuntimeError(f"Unspported node type: {type(node)}") return code diff --git a/cube/execplan/planpass/gfuse.py b/cube/execplan/planpass/gfuse.py deleted file mode 100644 index 28faef39..00000000 --- a/cube/execplan/planpass/gfuse.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -Gradient Allreduce Fusion -""" -from typing import Dict, Tuple, List -import sys -import copy - - -from cube.graph.operator.operator import IROptimOperation -from cube.graph.tensor import IRSubTensor, ValueMap - -from cube.execplan import ExectuionPlan -from cube.schedule.su import SUType, ScheduleUnit -from cube.execplan.planpass.planpass import PlanPass - - -class WeightGradAllreduceFusion(PlanPass): - - @staticmethod - def apply(execplan: ExectuionPlan) -> ExectuionPlan: - """ - Apply weight gradient allreduce fusion - """ - reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() - weights, params = WeightGradAllreduceFusion._get_weight_grads(execplan) - for param_id in params: - grads = params[param_id] - ranks = list(grads.keys()) - ranks.sort() - ranks = tuple(ranks) # ranks are used for group - if len(ranks) == 1: - continue - if ranks not in reducers: - reducers[ranks] = list() - reducers[ranks].append(weights[param_id]) - # generate reducer for each rank - for ranks in reducers: - weights = reducers[ranks] - # even though some ranks don't need allreduce, - # pytorch still requires each rank simutaneously call the - # communication group initialization - for devid in execplan.devices(): - dev_weights = copy.copy(weights) - for idx, weight in enumerate(dev_weights): - if devid not in params[weight._id]: - dev_weights[idx] = None - dev_weights = [w for w in dev_weights if w is not None] - opt_op = IROptimOperation(dev_weights, ranks) - reduce_su = ScheduleUnit([opt_op], SUType.Optimizer) - reduce_su.device = devid - execplan.at(devid).append(reduce_su) - return execplan - - @staticmethod - def _get_weight_grads(execplan: ExectuionPlan) -> Dict: - """ - Get weight and gradient - - weights: Dict[param_id: int, IRSubTensor] - grads : Dict[param_id: int, Dict[device: int, List[grad: IRSubTensor]]] - - """ - grads = dict() - weights = dict() - for devid in execplan.devices(): - bsus = [su for su in execplan.sequence(devid) if su.stype == SUType.Backward] - for bsu in bsus: - # bsu has only one node - for input in bsu.inputs(): - if isinstance(input, IRSubTensor) and input.is_param(): - grad = input.grad - if grad is None: - print(input.name, input) - print(grad) - assert grad is not None - # nothing to sync - if grad.valmap == ValueMap(0, 1): - continue - if input._id not in grads: - grads[input._id] = dict() - weights[input._id] = input - if devid not in grads[input._id]: - grads[input._id][devid] = list() - if grad in grads[input._id][devid]: - raise RuntimeError("Already logged grad?") - grads[input._id][devid].append(grad) - return weights, grads diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index 0a594675..12b4cf56 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -228,13 +228,6 @@ def is_identity(self): """ return len(self._prims) == 0 - def __repr__(self): - dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' - return dscp - - def module_repr(self) -> str: - return repr(self) - @staticmethod def gen(dst_tensor: IRSubTensor): # print(f'generating adapter for: {dst_tensor}') @@ -368,6 +361,13 @@ def gen_merge(dst_tensor, intersections): raise RuntimeError("Merge Plan not found") return prims + def __repr__(self): + dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' + return dscp + + def module_repr(self) -> str: + return repr(self) + def extra_repr(self): """ Detailed information @@ -377,3 +377,21 @@ def extra_repr(self): for prim in self._select_prims + self._move_prims + self._merge_prims: dscp += '\t' + repr(prim) + '\n' return dscp + + +class IRWeightReducer(IRCell): + + def __init__(self, weights: List[IRSubTensor], name='reducer'): + if not all([isinstance(w, IRSubTensor) and w.is_param() for w in weights]): + raise RuntimeError("Expected a list of gradient IRSubTensor") + signature = None + super().__init__(name, signature, len(weights), 0) + for idx, weight in enumerate(weights): + self.set_input(idx, weight) + + def __repr__(self): + dscp = f'WReducer{self._id}-{self.device}(inputs={self.inputs()})' + return dscp + + def module_repr(self) -> str: + return repr(self) diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index 02de2ab6..9ed490f0 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -1,20 +1,29 @@ +from typing import Dict, List, Tuple -from numpy import isin from cube.graph.graph import IRGraph -from cube.graph.adapter.adapter import IRAdapter +from cube.graph.tensor import IRSubTensor, ValueMap +from cube.graph.adapter.adapter import IRAdapter, IRWeightReducer from cube.graph.operator.operator import IRBpOperation, IRFwOperation -from cube.graph.tensor import IRSubTensor class AdapterGener: @staticmethod def gen(graph: IRGraph) -> IRGraph: + """ + Generate tensor adapter for both intermediate tensors and weights + """ + graph = AdapterGener.gen_activation_adapter(graph) + graph = AdapterGener.gen_weight_reducer(graph) + return graph + + @staticmethod + def gen_activation_adapter(graph: IRGraph) -> IRGraph: # update the gradient before generate adapter for node in graph.nodes(): if isinstance(node, IRBpOperation): node.update() - # generate adapter + # generate adapter for non-weight values for node in graph.nodes(): if isinstance(node, IRFwOperation): for input in node.inputs(): @@ -37,3 +46,49 @@ def gen(graph: IRGraph) -> IRGraph: idx = graph.nodes().index(node) graph._nodes.insert(idx, adapter) return graph + + + @staticmethod + def gen_weight_reducer(graph: IRGraph) -> IRGraph: + # step 1: get weight and gradient + # weights: Dict[weight_id: int, IRSubTensor] + # grads : Dict[weight_id: int, Dict[device: int, List[grad: IRSubTensor]]] + grads = dict() + weights = dict() + for fnode in graph.nodes(): + if not isinstance(fnode, IRFwOperation): + continue + devid = fnode.device[0] + for input in fnode.inputs(): + if isinstance(input, IRSubTensor) and input.is_param(): + grad = input.grad + # nothing to sync + if grad.valmap == ValueMap(0, 1): + continue + if input._id not in grads: + grads[input._id] = dict() + weights[input._id] = input + if devid not in grads[input._id]: + grads[input._id][devid] = list() + if grad in grads[input._id][devid]: + raise RuntimeError("Already logged grad?") + grads[input._id][devid].append(grad) + # step 2: generate weight. + # reducers: tuple(ranks): List[weight] + reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() + for wid in grads: + ranks = list(grads[wid].keys()) + ranks.sort() + ranks = tuple(ranks) # ranks are used for group + if len(ranks) == 1: + continue + if ranks not in reducers: + reducers[ranks] = list() + reducers[ranks].append(weights[wid]) + # generate reducer for each rank + for ranks in reducers: + weights = reducers[ranks] + opt_op = IRWeightReducer(weights) + opt_op.device = list(ranks) + graph._nodes.append(opt_op) + return graph \ No newline at end of file diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 0dbb11bd..b004319d 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -6,7 +6,7 @@ from cube.algorithm.factory import DistAlgorithmFactory -__all__ = ['IRFwOperation', 'IRBpOperation', 'IRDataOperation', 'IROptimOperation'] +__all__ = ['IRFwOperation', 'IRBpOperation', 'IRDataOperation'] class IRFwOperation(IRCell): @@ -339,22 +339,3 @@ def __repr__(self): def module_repr(self) -> str: return repr(self) - - -class IROptimOperation(IRCell): - - def __init__(self, weights: List[IRSubTensor], ranks: List[int], name='optimizer'): - if not all([isinstance(w, IRSubTensor) and w.is_param() for w in weights]): - raise RuntimeError("Expected a list of gradient IRSubTensor") - if not all([isinstance(rank, int) for rank in ranks]): - raise RuntimeError("Expected a list of int") - signature = None - self._ranks = ranks - - super().__init__(name, signature, len(weights), 0) - for idx, weight in enumerate(weights): - self.set_input(idx, weight) - - @property - def ranks(self): - return copy.copy(self._ranks) From 37e9fb8c3e7a7a40a0dfa346e7f1a8556a00c115 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 29 Dec 2021 20:04:02 +0800 Subject: [PATCH 0499/1892] init fusion --- cube/execplan/planpass/fusion.py | 107 +++++++++++++++++++++++++++++++ cube/graph/adapter/adapter.py | 18 ++++++ 2 files changed, 125 insertions(+) create mode 100644 cube/execplan/planpass/fusion.py diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py new file mode 100644 index 00000000..5a2f0ec2 --- /dev/null +++ b/cube/execplan/planpass/fusion.py @@ -0,0 +1,107 @@ +from typing import List, Dict + +from cube.execplan import ExectuionPlan +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.tensor import IRSubTensor, ValueMap +from cube.execplan.planpass.planpass import PlanPass + + +class P2PFusion(PlanPass): + + @staticmethod + def apply(execplan: ExectuionPlan) -> ExectuionPlan: + adapters = list() + for node in execplan.graph.nodes(): + if isinstance(node, IRAdapter): + adapters.append(node) + pass + + @staticmethod + def allgather_matcher(execplan: ExectuionPlan): + """ + Allgather semantic: + + Given a list of adapters: + 1). [Num] each adapter has same multiple inputs and same one output + 2). [Dev] inputs/outputs among adapters are from different device. device# = number of adapters. + 3). [Indmap] No-overlap index-map among inputs. + 4). [Valmap] each input value-map is same with output valuemap + """ + pass + + @staticmethod + def allreduce_matcher(execplan: ExectuionPlan): + """ + Allreduce semantic: + + Given a list of adapters: + 1). [Num] each adapter has different one input and same one output + 2). [Dev] inputs/outputs among adapters are from different devices + 2). [Indmap] inputs among adapters has same index-map with output. + 3). [Valmap] inputs have parital value-map. Output has full value-map + """ + pass + + @staticmethod + def reducescatter_matcher(execplan: ExectuionPlan): + """ + ReduceScatter semantic: + + Given a list of adapters: + 1). [Num] each adapter has different one input and different one output + 2). [Dev] inputs/outputs among adapters are from different devices + 3). [Indmap] inputs among adapters have same index-map + 4). [Indmap] outputs among adapters have different index-map + 5). [Valmap] inputs among adapters have different partial val-map. + 6). [Valmap] outputs among adapters have same Full val-map + """ + pass + + @staticmethod + def broadcast_matcher(execplan: ExectuionPlan): + """ + Broadcast semantic: + + Given a list of adapters: + 1). [Num] each adapter has same input and output. input = output. + 2). [Dev] inputs among adapters are from a same device. + 3). [Dev] outputs among adapters are from different devices + """ + pass + + # Utilities + @staticmethod + def group_by_output(adapters: List[IRAdapter]): + """ + Group the adapters by same output tensor + """ + tensors = dict() # tensor_id -> tensor + groups = dict() # tensor_id -> List[IRAdapter] + for adapter in adapters: + if len(adapter.outputs()) != 1: + raise RuntimeError("Expected only one output") + tensor = adapter.outputs(0) + tid = tensor._id + if tid not in tensors: + tensors[tid] = tensor + groups[tid] = list() + groups[tid].append(adapter) + return tensors, groups + + @staticmethod + def group_by_input(adapters: List[IRAdapter]): + """ + Group the adapters by same input tensor(s) + """ + tensors = dict() # Tuple[tensor_id] -> tensor + groups = dict() # Tuple[tensor_id] -> List[IRAdapter] + for adapter in adapters: + tensors = adapter.inputs + tids = [tensor._id for tensor in tensors] + tids.sort() + tids = tuple(tids) + if tids not in tensors: + tensors[tids] = tensors + groups[tids] = list() + groups[tids].append(adapter) + return tensors, groups diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index 12b4cf56..80c7d89f 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -190,6 +190,24 @@ def prims(self, select=True, move=True, merge=True): prims.append(prim) return prims + def idevice(self, input_index: int) -> List[int]: + """ + Get device for input tensor at input index. + + Returns: + device: List[int] + """ + return self._idevices[input_index] + + def odevice(self, output_index: int) -> List[int]: + """ + Get device for output tensor at output index. + + Returns: + device: List[int] + """ + return self._odevices[output_index] + def dispatch(self, rank: int): """ Get Adapter for a specific rank From 3e4766f615fe0a9df254e5c37d52461ffdb6d0fa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Dec 2021 14:18:32 +0800 Subject: [PATCH 0500/1892] mlp --- examples/mlp/linears.py | 21 ++++---- examples/mlp/policy/col_parallel.py | 83 +++++++++++++++++------------ 2 files changed, 60 insertions(+), 44 deletions(-) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 8603d52a..49efa870 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,8 +17,7 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.col_parallel import transform_policy -from examples.mlp.policy.col_parallel import schedule_policy +from examples.mlp.policy.col_parallel import PAS # =================== Semantic Model Description ==================== @@ -29,20 +28,20 @@ def __init__(self, dim, mult=1): self.linear2 = nn.Linear(dim * mult, dim) self.linear3 = nn.Linear(dim, dim * mult) self.linear4 = nn.Linear(dim * mult, dim) - self.linear5 = nn.Linear(dim, dim * mult) - self.linear6 = nn.Linear(dim * mult, dim) - self.linear7 = nn.Linear(dim, dim * mult) - self.linear8 = nn.Linear(dim * mult, dim) + # self.linear5 = nn.Linear(dim, dim * mult) + # self.linear6 = nn.Linear(dim * mult, dim) + # self.linear7 = nn.Linear(dim, dim * mult) + # self.linear8 = nn.Linear(dim * mult, dim) def forward(self, data): output = self.linear1(data) output = self.linear2(output) output = self.linear3(output) output = self.linear4(output) - output = self.linear5(output) - output = self.linear6(output) - output = self.linear7(output) - output = self.linear8(output) + # output = self.linear5(output) + # output = self.linear6(output) + # output = self.linear7(output) + # output = self.linear8(output) loss = torch.sum(output) return loss @@ -58,7 +57,7 @@ def train(): dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [batch_size, dim]) - @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + @cube.compile(model, dataloader, PAS=PAS) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index 7fafdce4..2c0a68e7 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -1,44 +1,61 @@ from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation -def transform_policy(graph: IRGraph, resource): +# def transform_policy(graph: IRGraph, resource): +# """ +# The transformation policy transposes linear using column parallel +# """ +# for node in graph.nodes(): +# if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): +# algo = node.algorithms('column') +# if algo: +# sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) +# else: +# sub_nodes = graph.replicate(node, times=resource.ngpus) +# for idx, sub_node in enumerate(sub_nodes): +# sub_node.tag = idx +# print(graph) +# return graph +# +# +# def schedule_policy(sugraph: SUGraph, resource): +# """ +# The schedule policy assign devices +# """ +# # print(sugraph) +# for su in sugraph.sus(): +# if su.stype == SUType.Dataloader: +# devid = su.tag[0] +# sugraph.assign(su, devid) +# # sugraph.assign(su, list(range(resource.ngpus))) +# for su in sugraph.fsus(): +# devid = su.tag[0] +# sugraph.assign(su, devid) +# if su.mirror is None: +# print(f'error su: {su}') +# assert False +# sugraph.assign(su.mirror, devid) +# fsus = sugraph.fsus() +# print('> [scheduling] setting schedule order...') +# sugraph.partial_set_order(fsus, lazy=False) +# return sugraph + + +def PAS(graph: IRGraph, resource): """ - The transformation policy transposes linear using column parallel + Linear Column Partition """ for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - algo = node.algorithms('column') + algo = node.algorithms('data') if algo: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + sub_nodes = graph.partition( + node, algo, config=dict(chunk_num=resource.ngpus) + ) else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - print(graph) + sub_nodes = [node] + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + print(graph.extra_repr()) return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - # print(sugraph) - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - # sugraph.assign(su, list(range(resource.ngpus))) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - if su.mirror is None: - print(f'error su: {su}') - assert False - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - print('> [scheduling] setting schedule order...') - sugraph.partial_set_order(fsus, lazy=False) - return sugraph From 8bbbd3a12e320cc24ebb29e222538100348d112e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Dec 2021 14:20:27 +0800 Subject: [PATCH 0501/1892] test for pas --- tests/graph/test_pas.py | 98 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 tests/graph/test_pas.py diff --git a/tests/graph/test_pas.py b/tests/graph/test_pas.py new file mode 100644 index 00000000..dad056cc --- /dev/null +++ b/tests/graph/test_pas.py @@ -0,0 +1,98 @@ +import torch +from torch import nn + +import cube +from cube.graph.adapter.adapter import IRAdapter +import cube.graph.parser as parser +from cube.logics.pool import SchedulePool +from cube.logics.translator import LogicTranslator +from cube.graph.adapter import AdapterGener +from cube.execplan import ExectuionPlan +from cube.execplan.planpass.grouping import Grouping + +from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen + +class MLP(nn.Module): + def __init__(self, dim, mult=4): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult) + self.linear2 = nn.Linear(dim * mult, dim) + self.linear3 = nn.Linear(dim, dim * mult) + self.linear4 = nn.Linear(dim * mult, dim) + + def forward(self, data): + output = self.linear1(data) + output = self.linear2(output) + output = self.linear3(output) + output = self.linear4(output) + loss = torch.sum(output) + return loss + +model = MLP(dim=1024) +dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [128, 1024]) +optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + +def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + + +def test_p(): + SchedulePool().clear() + + graph = parser.convert_model(model, input_shapes=([128, 1024],)) + loader = parser.convert_dataloader(dataloader) + + train_iter(graph, loader) + graph = LogicTranslator.gen_logic_graph() + # print(graph.extra_repr()) + # assert False + + node1, node2, node3, node4 = graph.nodes()[1:5] + for node in graph.nodes(): + graph.assign(node, rank=0) + algo = node2.algorithms('column') + subnodes = graph.partition(node2, algo, config=dict(chunk_num=4)) + for idx, subnode in enumerate(subnodes): + graph.assign(subnode, rank=idx) + + # print(graph.extra_repr()) + + graph = AdapterGener.gen(graph) + # print(graph) + # for node in graph.nodes(): + # if isinstance(node, IRAdapter): + # print(node.extra_repr()) + + execplan = ExectuionPlan(graph) + # print(execplan) + + execplan = Grouping.apply(execplan) + print(execplan) + # print(execplan.graph.extra_repr()) + + mcodegen = ModelCodeGen(execplan) + tcodegen = ScheduleCodeGen(execplan) + + mcode = mcodegen.gen(device=0) + tcode = tcodegen.gen(device=0) + print(mcode) + print(tcode) + + mcode = mcodegen.gen(device=1) + tcode = tcodegen.gen(device=1) + print(mcode) + print(tcode) + + mcode = mcodegen.gen(device=2) + tcode = tcodegen.gen(device=2) + print(mcode) + print(tcode) + mcode = mcodegen.gen(device=3) + tcode = tcodegen.gen(device=3) + print(mcode) + print(tcode) + + assert False From dc9898f9124ac7638b35ca21fa079fe72b69af80 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Dec 2021 16:47:37 +0800 Subject: [PATCH 0502/1892] p2pfusion on allgather --- cube/codegen/codegen.py | 32 +++- cube/compiler.py | 6 + cube/execplan/planpass/fusion.py | 296 +++++++++++++++++++++++++++++-- cube/graph/adapter/adapter.py | 73 +++++++- 4 files changed, 384 insertions(+), 23 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 29269e2a..8956d98c 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -9,7 +9,7 @@ from cube.ir.dtype import IRDType from cube.graph.tensor import IRSubTensor from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.graph.adapter.adapter import IRAdapter, SelectPrim, MovePrim, MergePrim +from cube.graph.adapter.adapter import CollectivePrim, IRAdapter, SelectPrim, MovePrim, MergePrim from cube.graph.adapter.adapter import IRWeightReducer from cube.execplan import ExectuionPlan # from cube.schedule.adapter.collectives import IRCollectives @@ -256,11 +256,24 @@ def emit_adapter_call(self, node: IRAdapter): # emit merge elif isinstance(prim, MergePrim): sign = 'cube.runtime.adapter.merge({tensors}, {concat}, {add})' - inputs = [self.tensor_naming(t) for t in prim.tensors] - inputs = '(' + ','.join(inputs + ['']) + ')' + inputs = self.tuple_naming(prim.tensors) output = self.tensor_naming(prim.output) code = f'{output} = {sign.format(tensors=inputs, concat=prim.concat, add=prim.add)}' self.forward_region.append(code) + # emit collectives + elif isinstance(prim, CollectivePrim): + sign = 'cube.runtime.adapter.{ctype}({input_tensors}, {output_shapes}, {output_dtypes}, {group})' + inputs = self.tuple_naming(prim.inputs) + outputs = self.return_naming(prim.outputs) + body = sign.format( + ctype=prim.ctype.value, + input_tensors = inputs, + output_shapes = prim.output_shapes, + output_dtypes = prim.output_dtypes, + group=prim.group + ) + code = f'{outputs} = {body}' + self.forward_region.append(code) else: raise TypeError(f"Unkown primitive types {type(prim)} of Adapter") # requires grad generation @@ -293,6 +306,19 @@ def emit_reducer_call(self, node: IRWeightReducer): call_code = f'{reducer_name}.allreduce()' self.forward_region.append(call_code) + def return_naming(self, tensors: List[Any]) -> str: + tensors = [self.tensor_naming(t) for t in tensors] + if len(tensors) == 0: + tensors = '_' + else: + tensors = ', '.join(tensors) + return tensors + + def tuple_naming(self, tensors: List[Any]) -> str: + tensors = [self.tensor_naming(t) for t in tensors] + tensors = '(' + ', '.join(tensors + ['']) + ')' + return tensors + def tensor_naming(self, tensor: Any): """ Generate tensor name. diff --git a/cube/compiler.py b/cube/compiler.py index 60b5aa09..9a0e0d0f 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -19,6 +19,7 @@ # from cube.execplan.planpass.gfuse import WeightGradAllreduceFusion # from cube.execplan.planpass.p2pfusion import P2PFusion from cube.execplan.planpass.grouping import Grouping +from cube.execplan.planpass.fusion import P2PFusion from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -156,6 +157,11 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on grouping operations: {:.2f} s'.format(span)) + start = time.time() + execplan = P2PFusion.apply(execplan) + span = time.time() - start + print('> planpass on p2pfusion operations: {:.2f} s'.format(span)) + # plan pass to adapt to pytorch semantic: multi branch gradient # TODO: residual support diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 5a2f0ec2..10e0bc03 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -1,8 +1,15 @@ from typing import List, Dict -from cube.execplan import ExectuionPlan -from cube.graph.adapter.adapter import IRAdapter +# debug only +# import sys +# if tid == tensor_id: print(f'out line: {sys._getframe().f_lineno}') + from cube.graph.tensor import IRSubTensor, ValueMap + +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.adapter.adapter import CollectivePrim, MergePrim + +from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass @@ -14,36 +21,153 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: for node in execplan.graph.nodes(): if isinstance(node, IRAdapter): adapters.append(node) - pass + matchers = [ + P2PFusion.allreduce_matcher, + P2PFusion.allgather_matcher, + P2PFusion.reducescatter_matcher, + P2PFusion.broadcast_matcher, + ] + for matcher in matchers: + matcher(execplan, adapters) + return execplan @staticmethod - def allgather_matcher(execplan: ExectuionPlan): + def allgather_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): """ Allgather semantic: Given a list of adapters: 1). [Num] each adapter has same multiple inputs and same one output - 2). [Dev] inputs/outputs among adapters are from different device. device# = number of adapters. - 3). [Indmap] No-overlap index-map among inputs. - 4). [Valmap] each input value-map is same with output valuemap + 2). [Dev] inputs/outputs among adapters are from different device. + 3). [Dev] adapters have same device. adapters# is same to device set. + 4). [Indmap] inputs inside one adapter are not overlapped + 5). [Valmap] each input value-map is same with output valuemap """ - pass + outputs, groups = P2PFusion.group_by_output(all_adapters) + for tid in outputs: + adapters: List[IRAdapter] = groups[tid] + # condition 1) + if not P2PFusion._check_multi_inputs(adapters): + continue + if not P2PFusion._check_same_inputs(adapters): + continue + # condition 2) + if not P2PFusion._check_different_inputs_devices(adapters, among=False): + continue + if not P2PFusion._check_different_outputs_devices(adapters, among=True): + continue + # condition 3) + cond = True + for adapter in adapters: + if len(adapters) != len(adapter.device): + cond = False + break + if not cond: + continue + # condition 4) + cond = True + for adapter in adapters: + if not P2PFusion._check_indmap_no_overlap(adapter.inputs()): + cond = False + break + if not cond: + continue + # condition 5) + cond = True + for adapter in adapters: + if not P2PFusion._check_valmap_same(adapter.inputs() + adapter.outputs()): + cond = False + break + if not cond: + continue + # gen allgather + print(f'generating allgather for tensor: {outputs[tid]} ...') + for adapter in adapters: + device = adapter.odevice(0) + input_idx = adapter.idevice().index(device) + inputs = [adapter.inputs(input_idx)] + coll = CollectivePrim( + ctype = CollectivePrim.Type.AllGather, + device = device, + group = adapter.device, + inputs = inputs, + input_shapes = None, + input_dtypes = None, + outputs = adapter.inputs(), + output_shapes = None, + output_dtypes = None, + ) + # merge prim still keeps, remove select and move prims + prims = [coll] + adapter.prims(select=False, move=False, coll=False) + adapter._prims = prims + for adapter in adapters: + all_adapters.remove(adapter) @staticmethod - def allreduce_matcher(execplan: ExectuionPlan): + def allreduce_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): """ Allreduce semantic: Given a list of adapters: 1). [Num] each adapter has different one input and same one output 2). [Dev] inputs/outputs among adapters are from different devices - 2). [Indmap] inputs among adapters has same index-map with output. - 3). [Valmap] inputs have parital value-map. Output has full value-map + 3). [Indmap] inputs among adapters has same index-map with output. + 4). [Valmap] inputs have parital value-map. Output has full value-map """ - pass + return + outputs, groups = P2PFusion.group_by_output(all_adapters) + for tid in outputs: + adapters = groups[tid] + # condition 1) + if not P2PFusion._check_multi_inputs(adapters): + continue + if not P2PFusion._check_same_inputs(adapters): + continue + # condition 2) + if not P2PFusion._check_different_inputs_devices(adapters, among=True): + continue + if not P2PFusion._check_different_outputs_devices(adapters, among=True): + continue + # condition 3) + cond = True + for adapter in adapters: + if not P2PFusion._check_indmap_same(adapter.inputs() + adapter.outputs()): + cond = False + break + if not cond: + continue + # condition 4) + inputs = list() + for adapter in adapters: + inputs += adapter.inputs() + if not P2PFusion._check_valmap_no_overlap(inputs): + continue + cond = True + for adapter in adapters: + if adapter.outputs(0).valmap != ValueMap(0, 1): + cond = False + break + if not cond: + continue + # generate + print(f'generating allreduce for tensor: {outputs[tid]} ...') + for adapter in adapters: + device = adapter.odevice(0) + input_idx = adapter.idevice().index(device) + inputs = [adapter.inputs(input_idx)] + coll = CollectivePrim( + ctype = CollectivePrim.Type.AllReduce, + device = device, + group = adapter.device, + inputs = inputs, + outputs = adapter.outputs(), + ) + adapter._prims = [coll] + for adapter in adapters: + all_adapters.remove(adapter) @staticmethod - def reducescatter_matcher(execplan: ExectuionPlan): + def reducescatter_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): """ ReduceScatter semantic: @@ -58,7 +182,7 @@ def reducescatter_matcher(execplan: ExectuionPlan): pass @staticmethod - def broadcast_matcher(execplan: ExectuionPlan): + def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): """ Broadcast semantic: @@ -105,3 +229,147 @@ def group_by_input(adapters: List[IRAdapter]): groups[tids] = list() groups[tids].append(adapter) return tensors, groups + + @staticmethod + def _check_same_inputs(adapters: List[IRAdapter]): + """ + Check if the inputs are same among adapters + """ + input_ids = list() + for adapter in adapters: + tids = [t._id for t in adapter.inputs()] + tids.sort() + input_ids.append(tids) + ninputs = [len(tids) for tids in input_ids] + # number of inputs not same + if len(set(ninputs)) != 1: + return False + # input ids not same + for tids in zip(*input_ids): + if len(set(tids)) != 1: + return False + return True + + @staticmethod + def _check_multi_inputs(adapters: List[IRAdapter]): + for adapter in adapters: + if len(adapter.inputs()) <= 1: + return False + return True + + @staticmethod + def _check_single_inputs(adapters: List[IRAdapter]): + for adapter in adapters: + if len(adapter.inputs()) != 1: + return False + return True + + @staticmethod + def _get_input_devices(adapter: IRAdapter) -> List[int]: + """ + Return sorted device list for all inputs + """ + device = set() + for idevice in adapter.idevice(): + device.update(idevice) + device = list(device) + device.sort() + return device + + @staticmethod + def _get_output_devices(adapter: IRAdapter) -> List[int]: + """ + Return sorted device list for all outputs + """ + device = set() + for odevice in adapter.odevice(): + device.update(odevice) + device = list(device) + device.sort() + return device + + @staticmethod + def _check_different_inputs_devices(adapters: List[IRAdapter], among: bool): + if among: + adapter_devices = list() + for adapter in adapters: + device = P2PFusion._get_input_devices(adapter) + adapter_devices.append(tuple(device)) + if len(set(adapter_devices)) != len(adapters): + return False + return True + else: + for adapter in adapters: + device = P2PFusion._get_input_devices(adapter) + # assume each tensor is attached to one deivce + if len(device) != len(adapter.inputs()): + return False + return True + + @staticmethod + def _check_different_outputs_devices(adapters: List[IRAdapter], among: bool): + if among: + adapter_devices = list() + for adapter in adapters: + device = set() + for odevice in adapter.odevice(): + device.update(odevice) + device = list(device) + device.sort() + adapter_devices.append(tuple(device)) + if len(set(adapter_devices)) != len(adapters): + return False + return True + else: + for adapter in adapters: + device = set() + for odevice in adapter.odevice(): + device.update(odevice) + # assume each tensor is attached to one deivce + if len(device) != len(adapter.outputs()): + return False + return True + + @staticmethod + def _check_indmap_same(tensors: List[IRSubTensor]): + if len(tensors) == 0: + return True + indmap = tensors[0].indmap + for tensor in tensors[1:]: + if tensor.indmap != indmap: + return False + return True + + @staticmethod + def _check_indmap_no_overlap(tensors: List[IRSubTensor]): + if len(tensors) == 0: + return True + for idx1 in range(len(tensors) - 1): + for idx2 in range(idx1 + 1, len(tensors)): + t1 = tensors[idx1] + t2 = tensors[idx2] + if t1.indmap.overlap(t2.indmap): + return False + return True + + @staticmethod + def _check_valmap_same(tensors: List[IRSubTensor]): + if len(tensors) == 0: + return True + valmap = tensors[0].valmap + for tensor in tensors[1:]: + if tensor.valmap != valmap: + return False + return True + + @staticmethod + def _check_valmap_no_overlap(tensors: List[IRSubTensor]): + if len(tensors) == 0: + return True + for idx1 in range(len(tensors) - 1): + for idx2 in range(idx1 + 1, len(tensors)): + t1 = tensors[idx1] + t2 = tensors[idx2] + if t1.valmap.overlap(t2.valmap): + return False + return True diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index 80c7d89f..fb1e0277 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -1,10 +1,12 @@ +from enum import Enum from typing import List, Optional, Tuple import copy import numpy as np from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap from cube.ir.cten import IRCell +from cube.ir.dtype import IRDType class SelectPrim: @@ -38,6 +40,57 @@ def __repr__(self): return dscp +class CollectivePrim: + + class Type(Enum): + AllReduce = 'all_reduce' + AllGather = 'all_gather' + ReduceScatter = 'reduce_scatter' + Broadcast = 'broadcast' + + def __init__(self, ctype: Enum, + device: List[int], + group: List[int], + inputs: List[IRSubTensor] = None, + input_shapes: List[List[int]] = None, + input_dtypes: List[IRDType] = None, + outputs: List[IRSubTensor] = None, + output_shapes: List[List[int]] = None, + output_dtypes: List[IRDType] = None): + """ + inputs: + the collective input tensors. Including remote tensors. + src_ranks: + the tensor rank for each corresponding input tensor + outputs: + the collective output tensors. Including remote tensors. + dst_ranks: + the tensor rank for each corresponding output tensor + device: + the collective to be performed rank. + Note n-device collective will have n CollectivePrim, + each needs to be assigned with a single device rank. + """ + self.ctype = ctype + # inputs + self.inputs: List[IRSubTensor] = inputs + self.input_shapes: List[IRSubTensor] = input_shapes + self.input_dtypes: List[IRDType] = input_dtypes + # outputs + self.outputs: List[IRSubTensor] = outputs + self.output_shapes: List[IRSubTensor] = output_shapes + self.output_dtypes: List[IRDType] = output_dtypes + # communication group + group.sort() + self.group: List[int] = group + # device + self.device = device + + def __repr__(self): + dscp = f'{self.outputs} = {self.ctype.value}(inputs={self.inputs}, group={self.group})' + return dscp + + class MergePrim: def __init__(self, tensors: List[IRSubTensor], output: IRSubTensor, device: List[int], @@ -176,7 +229,7 @@ def __init__(self, prims, device.update(prim.device) self.device = list(device) - def prims(self, select=True, move=True, merge=True): + def prims(self, select=True, move=True, merge=True, coll=True): """ Return prim list """ @@ -188,25 +241,33 @@ def prims(self, select=True, move=True, merge=True): prims.append(prim) if merge and isinstance(prim, MergePrim): prims.append(prim) + if coll and isinstance(prim, CollectivePrim): + prims.append(prim) return prims - def idevice(self, input_index: int) -> List[int]: + def idevice(self, input_index: int = None) -> List[int]: """ Get device for input tensor at input index. Returns: device: List[int] """ - return self._idevices[input_index] + if isinstance(input_index, int): + return self._idevices[input_index] + else: + return copy.copy(self._idevices) - def odevice(self, output_index: int) -> List[int]: + def odevice(self, output_index: int = None) -> List[int]: """ Get device for output tensor at output index. Returns: device: List[int] """ - return self._odevices[output_index] + if isinstance(output_index, int): + return self._odevices[output_index] + else: + return copy.copy(self._odevices) def dispatch(self, rank: int): """ @@ -392,7 +453,7 @@ def extra_repr(self): """ dscp = repr(self) + ':\n' # select - for prim in self._select_prims + self._move_prims + self._merge_prims: + for prim in self._prims: dscp += '\t' + repr(prim) + '\n' return dscp From 6f5555ff978159eb5ac99ad09fd93c817904c61e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Dec 2021 19:00:37 +0800 Subject: [PATCH 0503/1892] add reducescatter fusion --- cube/execplan/planpass/fusion.py | 116 ++++++++++++++++++++++++---- cube/graph/adapter/adapter.py | 9 +++ cube/runtime/adapter/collectives.py | 73 +++++++++-------- 3 files changed, 149 insertions(+), 49 deletions(-) diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 10e0bc03..7738ec2d 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -7,7 +7,7 @@ from cube.graph.tensor import IRSubTensor, ValueMap from cube.graph.adapter.adapter import IRAdapter -from cube.graph.adapter.adapter import CollectivePrim, MergePrim +from cube.graph.adapter.adapter import CollectivePrim from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass @@ -29,6 +29,15 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: ] for matcher in matchers: matcher(execplan, adapters) + # update adapter devices + for node in execplan.graph.nodes(): + if isinstance(node, IRAdapter): + node.update_device() + for devid in execplan.devices(): + for node in execplan.sequence(devid): + if isinstance(node, IRAdapter): + if devid not in node.device: + execplan.at(devid).remove(node) return execplan @staticmethod @@ -111,32 +120,40 @@ def allreduce_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): Given a list of adapters: 1). [Num] each adapter has different one input and same one output 2). [Dev] inputs/outputs among adapters are from different devices - 3). [Indmap] inputs among adapters has same index-map with output. - 4). [Valmap] inputs have parital value-map. Output has full value-map + 3). [Dev] adapters have same device. adapters# is same to device set. + 4). [Indmap] inputs among adapters has same index-map with output. + 5). [Valmap] inputs have parital value-map. Output has full value-map """ - return outputs, groups = P2PFusion.group_by_output(all_adapters) for tid in outputs: - adapters = groups[tid] + adapters: List[IRAdapter] = groups[tid] # condition 1) if not P2PFusion._check_multi_inputs(adapters): continue if not P2PFusion._check_same_inputs(adapters): continue # condition 2) - if not P2PFusion._check_different_inputs_devices(adapters, among=True): + if not P2PFusion._check_different_inputs_devices(adapters, among=False): continue if not P2PFusion._check_different_outputs_devices(adapters, among=True): continue # condition 3) cond = True for adapter in adapters: - if not P2PFusion._check_indmap_same(adapter.inputs() + adapter.outputs()): + if len(adapters) != len(adapter.device): cond = False break if not cond: continue # condition 4) + cond = True + for adapter in adapters: + if not P2PFusion._check_indmap_same(adapter.inputs() + adapter.outputs()): + cond = False + break + if not cond: + continue + # condition 5) inputs = list() for adapter in adapters: inputs += adapter.inputs() @@ -172,14 +189,84 @@ def reducescatter_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter] ReduceScatter semantic: Given a list of adapters: - 1). [Num] each adapter has different one input and different one output + 1). [Num] each adapter has same multiple input and different one output 2). [Dev] inputs/outputs among adapters are from different devices - 3). [Indmap] inputs among adapters have same index-map - 4). [Indmap] outputs among adapters have different index-map - 5). [Valmap] inputs among adapters have different partial val-map. - 6). [Valmap] outputs among adapters have same Full val-map + 3). [Dev] adapters have same device. adapters# is same to device set + 4). [Indmap] inputs of each adapter have same index-map + 5). [Indmap] outputs among adapters have different index-map + 6). [Valmap] inputs of each adapter have different partial val-map + 7). [Valmap] outputs among adapters have same Full val-map """ - pass + inputs, groups = P2PFusion.group_by_input(all_adapters) + for tids in inputs: + adapters: List[IRAdapter] = groups[tids] + # cond 1) + otids = [adapter.outputs(0)._id for adapter in adapters] + if len(set(otids)) != len(adapters): + continue + # cond 2) + if not P2PFusion._check_different_inputs_devices(adapters, among=False): + continue + if not P2PFusion._check_different_outputs_devices(adapters, among=True): + continue + # cond 3) + cond = True + for adapter in adapters: + if len(adapters) != len(adapter.device): + cond = False + break + if not cond: + continue + # cond 4) + cond = True + for adapter in adapters: + if not P2PFusion._check_indmap_same(adapter.inputs()): + cond = False + break + if not cond: + continue + # cond 5) + outputs = [adapter.outputs(0) for adapter in adapters] + if not P2PFusion._check_indmap_no_overlap(outputs): + continue + # cond 6) + cond = True + for adapter in adapters: + if not P2PFusion._check_valmap_no_overlap(adapter.inputs()): + cond = False + break + if not cond: + continue + # cond 7) + cond = True + for adapter in adapters: + if adapter.outputs(0).valmap != ValueMap(0, 1): + cond = False + break + if not cond: + continue + # gen reduce-scatter + print(f'generating reduce-scatter for tensor: {tids} ...') + all_select_prims = list() + for adapter in adapters: + all_select_prims += adapter.prims(move=False, merge=False, coll=False) + for adapter in adapters: + device = adapter.odevice(0) + sprims = [prim for prim in all_select_prims if prim.device == device] + if len(sprims) != len(adapters): + raise RuntimeError(f"got {len(sprims)} (!={len(adapters)}) select prims for reduce-scatter") + inputs = [sprim.output for sprim in sprims] + coll = CollectivePrim( + ctype = CollectivePrim.Type.ReduceScatter, + device = device, + group = adapter.device, + inputs = inputs, + outputs = adapter.outputs(), + ) + prims = sprims + [coll] + adapter._prims = prims + for adapter in adapters: + all_adapters.remove(adapter) @staticmethod def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): @@ -220,8 +307,7 @@ def group_by_input(adapters: List[IRAdapter]): tensors = dict() # Tuple[tensor_id] -> tensor groups = dict() # Tuple[tensor_id] -> List[IRAdapter] for adapter in adapters: - tensors = adapter.inputs - tids = [tensor._id for tensor in tensors] + tids = [tensor._id for tensor in adapter.inputs()] tids.sort() tids = tuple(tids) if tids not in tensors: diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index fb1e0277..3cc92036 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -298,6 +298,15 @@ def dispatch(self, rank: int): adapter.device = rank return adapter + def update_device(self): + """ + Update device (needed when adapter content changes, e.g., P2PFusion) + """ + device = set() + for prim in self._prims: + device.update(prim.device) + self.device = list(device) + def is_identity(self): """ Check if the adapter does nothing diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 2d0f23ed..a8237a91 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -2,7 +2,7 @@ import torch from cube.runtime.device import DeviceGroup -from cube.profiler.timer import CudaTimer +from cube.profiler.timer import CudaTimer, print_each_rank def send(tensor: torch.Tensor, to_rank: int): @@ -95,22 +95,27 @@ def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): else: return tuple(recv_tensors) -def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): +### Collective Universal Interface ### +# def universal(input_tensors: List[torch.Tensor], +# output_shapes: List[List[int]], +# output_dtypes: List[torch.dtype], +# ranks: List[int]) + + +def all_reduce(input_tensors: List[torch.Tensor], + output_shapes: List[List[int]], + output_dtypes: List[torch.dtype], + ranks: List[int]) -> torch.Tensor: """ Allreduce """ CudaTimer().start(field_name='comm') - # print(f'{torch.distributed.get_rank()}: all_reduce...') - assert len(tensors) == 1 - tensor = tensors[0] + assert len(input_tensors) == 1 + tensor = input_tensors if not tensor.is_contiguous(): tensor = tensor.contiguous() tensor = tensor.detach() tensor = tensor.requires_grad_() - - ### Bypass ### - # return tensor - group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(tensor, group=group) @@ -118,62 +123,62 @@ def all_reduce(tensors: List[torch.Tensor], ranks: List[int]): return tensor -def all_gather(tensors: List[torch.Tensor], ranks: List[int]): +def all_gather(input_tensors: List[torch.Tensor], + output_shapes: List[List[int]], + output_dtypes: List[torch.dtype], + ranks: List[int]) -> List[torch.Tensor]: """ Allgather """ - # print(f'{torch.distributed.get_rank()}: all_gather...') CudaTimer().start(field_name='comm') - - assert len(tensors) == 1 - tensor = tensors[0] + assert len(input_tensors) == 1 + tensor = input_tensors[0] group = DeviceGroup().get_group(ranks) tensor_list = [torch.empty_like(tensor) for _ in ranks] idx = ranks.index(DeviceGroup().rank) tensor_list[idx] = tensor torch.distributed.all_gather(tensor_list, tensor, group=group) - tensor_list = [t for oidx, t in enumerate(tensor_list) if oidx != idx] - CudaTimer().stop(field_name='comm') - if len(tensor_list) == 1: - return tensor_list[0] - else: - return tensor_list + return tensor_list -def reduce_scatter(tensors: List[torch.Tensor], ranks: List[int]): +def reduce_scatter(input_tensors: List[torch.Tensor], + output_shapes: List[List[int]], + output_dtypes: List[torch.dtype], + ranks: List[int]) -> List[torch.Tensor]: """ ReduceScatter """ - # print(f'{torch.distributed.get_rank()}: reduce-scatter...') CudaTimer().start(field_name='comm') - - tensors = list(tensors) + input_tensors = list(input_tensors) group = DeviceGroup().get_group(ranks) idx = ranks.index(DeviceGroup().rank) - output = torch.empty_like(tensors[idx], requires_grad=True) + output = torch.empty_like(input_tensors[idx], requires_grad=True) torch.distributed.reduce_scatter( - output, tensors, group=group + output, input_tensors, group=group ) - CudaTimer().stop(field_name='comm') return output -def broadcast(tensors: List[torch.Tensor], ranks: List[int], shape=None, dtype=None): +def broadcast(input_tensors: List[torch.Tensor], + output_shapes: List[List[int]], + output_dtypes: List[torch.dtype], + ranks: List[int]) -> List[torch.Tensor]: """ Broadcast. ranks[0] is the root """ CudaTimer().start(field_name='comm') - # print(f'{torch.distributed.get_rank()}: broadcast...') - # FIXME: data type - if len(tensors) == 1: - tensor = tensors[0] + assert len(input_tensors) == 1 or len(input_tensors) == 0 + if len(input_tensors) == 1: + tensor = input_tensors[0] else: + assert len(output_shapes) == 1 + assert len(output_dtypes) == 1 + shape = output_shapes[0] + dtype = output_dtypes[0] tensor = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype) - # tensor.requires_grad_() group = DeviceGroup().get_group(ranks) torch.distributed.broadcast(tensor, ranks[0], group=group) - CudaTimer().stop(field_name='comm') return tensor From 6f4208141d8ae76e25e3fe74c3abca24c32c4cee Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Dec 2021 20:11:36 +0800 Subject: [PATCH 0504/1892] fix allreduce runtime bug --- cube/execplan/planpass/fusion.py | 192 +++++++---- cube/execplan/planpass/p2pfusion.py | 495 ---------------------------- cube/execplan/planpass/redundant.py | 54 --- cube/graph/graph.py | 34 -- cube/runtime/adapter/collectives.py | 2 +- 5 files changed, 123 insertions(+), 654 deletions(-) delete mode 100644 cube/execplan/planpass/p2pfusion.py delete mode 100644 cube/execplan/planpass/redundant.py diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 7738ec2d..c5516a2b 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -1,4 +1,5 @@ -from typing import List, Dict +from typing import List +import copy # debug only # import sys @@ -12,6 +13,12 @@ from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass +# FIXME: all fusions don't consider input order! +# May get incorrect result in some cases. + +# FIXME: all fusions don't check if the communication can be happened at +# the same time + class P2PFusion(PlanPass): @@ -41,19 +48,20 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: return execplan @staticmethod - def allgather_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): + def allreduce_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): """ - Allgather semantic: + Allreduce semantic: Given a list of adapters: - 1). [Num] each adapter has same multiple inputs and same one output - 2). [Dev] inputs/outputs among adapters are from different device. + 1). [Num] each adapter has different one input and same one output + 2). [Dev] inputs/outputs among adapters are from different devices 3). [Dev] adapters have same device. adapters# is same to device set. - 4). [Indmap] inputs inside one adapter are not overlapped - 5). [Valmap] each input value-map is same with output valuemap + 4). [Indmap] inputs among adapters has same index-map with output. + 5). [Valmap] inputs have parital value-map. Output has full value-map """ outputs, groups = P2PFusion.group_by_output(all_adapters) for tid in outputs: + cond = True adapters: List[IRAdapter] = groups[tid] # condition 1) if not P2PFusion._check_multi_inputs(adapters): @@ -66,67 +74,61 @@ def allgather_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): if not P2PFusion._check_different_outputs_devices(adapters, among=True): continue # condition 3) - cond = True for adapter in adapters: if len(adapters) != len(adapter.device): cond = False break - if not cond: - continue + if not cond: continue # condition 4) - cond = True for adapter in adapters: - if not P2PFusion._check_indmap_no_overlap(adapter.inputs()): + if not P2PFusion._check_indmap_same(adapter.inputs() + adapter.outputs()): cond = False break - if not cond: - continue + if not cond: continue # condition 5) - cond = True for adapter in adapters: - if not P2PFusion._check_valmap_same(adapter.inputs() + adapter.outputs()): + if not P2PFusion._check_valmap_no_overlap(adapter.inputs()): cond = False break - if not cond: - continue - # gen allgather - print(f'generating allgather for tensor: {outputs[tid]} ...') + if not cond: continue + for adapter in adapters: + if adapter.outputs(0).valmap != ValueMap(0, 1): + cond = False + break + if not cond: continue + # generate + print(f'generating allreduce for tensor: {outputs[tid]} ...') for adapter in adapters: device = adapter.odevice(0) input_idx = adapter.idevice().index(device) inputs = [adapter.inputs(input_idx)] coll = CollectivePrim( - ctype = CollectivePrim.Type.AllGather, + ctype = CollectivePrim.Type.AllReduce, device = device, group = adapter.device, inputs = inputs, - input_shapes = None, - input_dtypes = None, - outputs = adapter.inputs(), - output_shapes = None, - output_dtypes = None, + outputs = adapter.outputs(), ) - # merge prim still keeps, remove select and move prims - prims = [coll] + adapter.prims(select=False, move=False, coll=False) - adapter._prims = prims + adapter._prims = [coll] for adapter in adapters: all_adapters.remove(adapter) @staticmethod - def allreduce_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): + def allgather_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): """ - Allreduce semantic: + Allgather semantic: Given a list of adapters: - 1). [Num] each adapter has different one input and same one output - 2). [Dev] inputs/outputs among adapters are from different devices + 1). [Num] each adapter has same multiple inputs and same one output + 2). [Dev] inputs/outputs among adapters are from different device. 3). [Dev] adapters have same device. adapters# is same to device set. - 4). [Indmap] inputs among adapters has same index-map with output. - 5). [Valmap] inputs have parital value-map. Output has full value-map + 4). [Indmap] inputs inside one adapter are not overlapped + 5). [Valmap] each input value-map is same with output valuemap """ outputs, groups = P2PFusion.group_by_output(all_adapters) for tid in outputs: adapters: List[IRAdapter] = groups[tid] + cond = True # condition 1) if not P2PFusion._check_multi_inputs(adapters): continue @@ -138,48 +140,43 @@ def allreduce_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): if not P2PFusion._check_different_outputs_devices(adapters, among=True): continue # condition 3) - cond = True for adapter in adapters: if len(adapters) != len(adapter.device): cond = False break - if not cond: - continue + if not cond: continue # condition 4) - cond = True for adapter in adapters: - if not P2PFusion._check_indmap_same(adapter.inputs() + adapter.outputs()): + if not P2PFusion._check_indmap_no_overlap(adapter.inputs()): cond = False break - if not cond: - continue + if not cond: continue # condition 5) - inputs = list() - for adapter in adapters: - inputs += adapter.inputs() - if not P2PFusion._check_valmap_no_overlap(inputs): - continue - cond = True for adapter in adapters: - if adapter.outputs(0).valmap != ValueMap(0, 1): + if not P2PFusion._check_valmap_same(adapter.inputs() + adapter.outputs()): cond = False break - if not cond: - continue - # generate - print(f'generating allreduce for tensor: {outputs[tid]} ...') + if not cond: continue + # gen allgather + print(f'generating allgather for tensor: {outputs[tid]} ...') for adapter in adapters: device = adapter.odevice(0) input_idx = adapter.idevice().index(device) inputs = [adapter.inputs(input_idx)] coll = CollectivePrim( - ctype = CollectivePrim.Type.AllReduce, + ctype = CollectivePrim.Type.AllGather, device = device, group = adapter.device, inputs = inputs, - outputs = adapter.outputs(), + input_shapes = None, + input_dtypes = None, + outputs = adapter.inputs(), + output_shapes = None, + output_dtypes = None, ) - adapter._prims = [coll] + # merge prim still keeps, remove select and move prims + prims = [coll] + adapter.prims(select=False, move=False, coll=False) + adapter._prims = prims for adapter in adapters: all_adapters.remove(adapter) @@ -200,6 +197,7 @@ def reducescatter_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter] inputs, groups = P2PFusion.group_by_input(all_adapters) for tids in inputs: adapters: List[IRAdapter] = groups[tids] + cond = True # cond 1) otids = [adapter.outputs(0)._id for adapter in adapters] if len(set(otids)) != len(adapters): @@ -210,41 +208,33 @@ def reducescatter_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter] if not P2PFusion._check_different_outputs_devices(adapters, among=True): continue # cond 3) - cond = True for adapter in adapters: if len(adapters) != len(adapter.device): cond = False break - if not cond: - continue + if not cond: continue # cond 4) - cond = True for adapter in adapters: if not P2PFusion._check_indmap_same(adapter.inputs()): cond = False break - if not cond: - continue + if not cond: continue # cond 5) outputs = [adapter.outputs(0) for adapter in adapters] if not P2PFusion._check_indmap_no_overlap(outputs): continue # cond 6) - cond = True for adapter in adapters: if not P2PFusion._check_valmap_no_overlap(adapter.inputs()): cond = False break - if not cond: - continue + if not cond: continue # cond 7) - cond = True for adapter in adapters: if adapter.outputs(0).valmap != ValueMap(0, 1): cond = False break - if not cond: - continue + if not cond: continue # gen reduce-scatter print(f'generating reduce-scatter for tensor: {tids} ...') all_select_prims = list() @@ -274,11 +264,73 @@ def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): Broadcast semantic: Given a list of adapters: - 1). [Num] each adapter has same input and output. input = output. + 1). [Num] each adapter has same one input and one output. input = output. 2). [Dev] inputs among adapters are from a same device. 3). [Dev] outputs among adapters are from different devices """ - pass + outputs, groups = P2PFusion.group_by_output(all_adapters) + for tid in outputs: + adapters: List[IRAdapter] = groups[tid] + cond = True + # cond 1) + if not P2PFusion._check_same_inputs(adapters): + continue + if not P2PFusion._check_single_inputs(adapters): + continue + for adapter in adapters: + if adapter.inputs(0) != adapter.outputs(0): + cond = False + break + if not cond: continue + # cond 2) + device = set([adapter.idevice(0)[0] for adapter in adapters]) + if len(device) != 1: + continue + # cond 3) + if not P2PFusion._check_different_outputs_devices(adapters, among=True): + continue + # gen broadcast + print(f'generating broadcast for tensor: {outputs[tid]} ... (NOT SUPPORTED)') + return + # put root rank to the first + root = list(device)[0] + group = set() + for adapter in adapters: + group.update(P2PFusion._get_output_devices(adapter)) + group.remove(root[0]) + # inputs + tensor = adapters[0].inputs() + group = [root] + list(group) + + prims = list() + for device in group: + inputs = tensor if device == root else None + output_shapes = [tensor.shape] + output_dtypes = [tensor.dtype] + coll = CollectivePrim( + ctype = CollectivePrim.Type.Broadcast, + device = device, + group = group, + inputs = inputs, + output_shapes = output_shapes, + output_dtypes = output_dtypes + ) + prims.append(coll) + + # add aditional adapter to root node + root_adapter = IRAdapter( + prims = [prims[0]], + inputs=[tensor], idevices=[[root],], + outputs=[tensor], odevices=[[root],] + ) + # TODO: this should insert into graph and execution plan + for adapter in [root_adapter] + adapters: + device = adapter.odevice(0)[0] + prim = prims[group.index(device)] + adapter._prims = [prim] + + for adapter in adapters: + all_adapters.remove(adapter) # Utilities @staticmethod diff --git a/cube/execplan/planpass/p2pfusion.py b/cube/execplan/planpass/p2pfusion.py deleted file mode 100644 index a72ccff9..00000000 --- a/cube/execplan/planpass/p2pfusion.py +++ /dev/null @@ -1,495 +0,0 @@ -from typing import List, Dict -from cube.execplan import ExectuionPlan -from cube.graph.tensor import ValueMap -from cube.ir.cten import IRTensor -from cube.schedule.su import SUType, ScheduleUnit -from cube.execplan.planpass.planpass import PlanPass - -from cube.schedule.adapter.collectives import IRCollType, IRCollectives - - -class P2PFusion(PlanPass): - - @staticmethod - def apply(execplan: ExectuionPlan) -> ExectuionPlan: - # dict[pid][devid] = list of sub_tensors - fous, fins = P2PFusion.collect_tensors( - execplan, [SUType.Dataloader, SUType.Forward] - ) - bous, bins = P2PFusion.collect_tensors( - execplan, [SUType.Backward] - ) - # debug - # print('=====> forward') - # for pid in fins: - # if pid not in fous: - # continue - # if P2PFusion.have_comm(fous[pid], fins[pid]): - # print(f'=> parent tensor id: {pid}') - # for devid in fous[pid]: - # print(f' ==> device: {devid}') - # for val in fous[pid][devid]: - # print(f' o:', val) - # for devid in fins[pid]: - # print(f' ==> device: {devid}') - # for val in fins[pid][devid]: - # print(f' i:', val) - - matchers = [ - P2PFusion.match_allreduce, - P2PFusion.match_allgather, - P2PFusion.match_reducescatter, - P2PFusion.match_broadcast, - ] - for ous, ins in zip([fous, bous], [fins, bins]): - for pid in ins: - if pid not in ous: - continue - tous, tins = ous[pid], ins[pid] - # if they are on the single device, matching is skipped - if len(tous) == 1 and set(tous.keys()) == set(tins.keys()): - continue - if P2PFusion.have_comm(tous, tins): - colls : List[ScheduleUnit] = None - for matcher in matchers: - colls = matcher(tous, tins) - if colls: - break - if colls is not None: - P2PFusion.add_collectives(execplan, colls) - return execplan - - @staticmethod - def collect_tensors(execplan: ExectuionPlan, stypes: List[SUType]): - # dict[pid][devid] = list of sub_tensors - ous = dict() - ins = dict() - for devid in execplan.devices(): - dev_seq = execplan.sequence(devid) - for su in dev_seq: - if su.stype in stypes: - for val in su.inputs(): - # FIXME: remove parameter constraints - if isinstance(val, IRTensor) and not val.is_param(): - pid = val.parent._id - if pid not in ins: - ins[pid] = dict() - if devid not in ins[pid]: - ins[pid][devid] = list() - # TODO: may have redundancy - ins[pid][devid].append(val) - for idx, val in enumerate(su.outputs()): - if isinstance(val, IRTensor): - pid = val.parent._id - if pid not in ous: - ous[pid] = dict() - if devid not in ous[pid]: - ous[pid][devid] = list() - select_su = su.select_adapters(idx) - if select_su: - for out in select_su.outputs(): - # TODO: may have redundancy - ous[pid][devid].append(out) - else: - # TODO: may have redundancy - ous[pid][devid].append(val) - return ous, ins - - @staticmethod - def have_comm(tensor_ous, tensor_ins): - """ - Check if they don't have communications - """ - for devid in tensor_ins: - if devid not in tensor_ous: - return True - # no transmission - if input in tensor_ous[devid]: - continue - # have transmission - else: - return True - return False - - @staticmethod - def add_collectives(execplan: ExectuionPlan, coll_sus: List[ScheduleUnit]): - for coll_su in coll_sus: - # print(f'inserting Collective SU: {coll_su.name}: {coll_su}') - # find insert place: the first send - devid = coll_su.device[0] - ranks = coll_su.nodes(0).ranks - for idx, su in enumerate(execplan.sequence(devid)): - # send or recv - if su.stype == SUType.P2P: - sr_tensor = (su.inputs() + su.outputs())[0] - if sr_tensor in coll_su.inputs() + coll_su.outputs(): - execplan.at(devid)[idx] = coll_su - break - # merge - if su.stype == SUType.Transform and len(su.inputs()) > 1: - merge_out = su.outputs(0) - if merge_out in coll_su.outputs(): - assert len(coll_su.outputs()) == 1 - execplan.at(devid)[idx] = coll_su - break - else: - raise RuntimeError("Cannot find a send P2P") - # all the send, recv of the inputs will be removed in ranks - for coll_su in coll_sus: - ranks = coll_su.nodes(0).ranks - for input in coll_su.inputs(): - for rank in ranks: - for su in execplan.sequence(rank): - # remove send / recv - if su.stype == SUType.P2P and input in (su.inputs() + su.outputs()): - execplan.at(rank).remove(su) - # remove merge if coll generate merge results - if su.stype == SUType.Transform and len(su.inputs()) > 1: - merge_out = su.outputs(0) - if merge_out in coll_su.outputs(): - assert len(coll_su.outputs()) == 1 - execplan.at(rank).remove(su) - - @staticmethod - def transmission(tensor_ous, in_tensor) -> Dict[int, List[IRTensor]]: - trans_tensors = dict() - for devid in tensor_ous: - for out in tensor_ous[devid]: - if in_tensor.overlap(out): - if devid not in trans_tensors: - trans_tensors[devid] = list() - trans_tensors[devid].append(out) - return trans_tensors - - @staticmethod - def match_allreduce(tous, tins): - """ - Allreduce semantic: - - Each device holds a recvs same spatial tensor from all device and - sends to all device. - The recved tensors are summed into one - """ - allreduce_sus = list() - # {tensor_id: [device_id]} - in_devices: Dict[int, List[int]] = dict() - # {tensor_id: [tensors] - in_tensors: Dict[int, List[IRTensor]] = dict() - for devid in tins: - for in_tensor in tins[devid]: - if in_tensor.valmap != ValueMap(0, 1): - continue - tid = in_tensor._id - if tid not in in_devices: - in_devices[tid] = list() - in_tensors[tid] = list() - in_devices[tid].append(devid) - in_tensors[tid].append(in_tensor) - for tid in in_devices: - # P2P transmission - if len(in_devices[tid]) <= 1: - continue - in_tensor = in_tensors[tid][0] - # {rank: [IRTensor]}} - out_tensors = P2PFusion.transmission(tous, in_tensor) - out_devices = set(out_tensors.keys()) - # check out tensor and reduce in tensor devices are the same set - if out_devices == set(in_devices[tid]): - # multiple transmission FIXME: remove redundancy - if not all([len(out_tensors[odev]) == 1 for odev in out_devices]): - continue - # check same indice map and no overlap value map - unique_indices = list() - for odev in out_tensors: - indmap = out_tensors[odev][0].indmap - if indmap not in unique_indices: - unique_indices.append(indmap) - if len(unique_indices) != 1: - continue - # check no overlap valmaps - all_valmaps = list() - overlap = False - for odev in out_tensors: - valmap = out_tensors[odev][0].valmap - for pre_valmp in all_valmaps: - overlap = pre_valmp.overlap(valmap) - all_valmaps.append(valmap) - if overlap: - continue - - ranks = list(out_tensors.keys()) - inputs = [[out_tensors[rank][0]] for rank in ranks] - - for input, rank in zip(inputs, ranks): - for in_tensor in in_tensors[tid]: - if in_tensor.device[0] == rank: - outputs = [in_tensor] - break - else: - raise RuntimeError("Internal Error") - op = IRCollectives(input, outputs, ranks, IRCollType.AllReduce) - su = ScheduleUnit([op], SUType.Coll, name='allreduce') - su.device = rank - allreduce_sus.append(su) - - # print('>> find allreduce pattern:') - # print(f'device group: {ranks}') - # for input in inputs: - # print(f'src: {input}') - # for output in outputs: - # print(f'dst: {output}') - - if len(allreduce_sus) == 0: - return None - else: - return allreduce_sus - - @staticmethod - def match_allgather(tous, tins): - """ - Allgather semantic: - - Each device performs same transformation merge. - - !!Note: Each input in merge su can be paired with a pair, find - them and remove!! Fuse merge, send, recv into one merge!! - """ - allgather_sus = list() - # {tensor_id: [device_id]} - in_devices: Dict[int, List[int]] = dict() - # {tensor_id: [tensors] - in_tensors: Dict[int, List[IRTensor]] = dict() - for devid in tins: - for in_tensor in tins[devid]: - tid = in_tensor._id - if tid not in in_devices: - in_devices[tid] = list() - in_tensors[tid] = list() - in_devices[tid].append(devid) - in_tensors[tid].append(in_tensor) - for tid in in_devices: - # P2P transmission - if len(in_devices[tid]) <= 1: - continue - in_tensor = in_tensors[tid][0] - # {rank: [IRTensor]}} - out_tensors = P2PFusion.transmission(tous, in_tensor) - out_devices = set(out_tensors.keys()) - if out_devices == set(in_devices[tid]): - # multiple transmission FIXME: remove redundancy - if not all([len(out_tensors[odev]) == 1 for odev in out_devices]): - continue - # check same value map and no overlap indmap - unique_valmaps = list() - for odev in out_tensors: - valmap = out_tensors[odev][0].valmap - if valmap not in unique_valmaps: - unique_valmaps.append(valmap) - if len(unique_valmaps) != 1: - continue - # check no overlap indmap - all_indices = list() - overlap = False - for odev in out_tensors: - indmap = out_tensors[odev][0].indmap - for pre_indices in all_indices: - overlap = pre_indices.overlap(indmap) - all_indices.append(indmap) - if overlap: - continue - - ranks = list(out_tensors.keys()) - inputs = [out_tensors[rank][0] for rank in ranks] - - for input, rank in zip(inputs, ranks): - outputs = [t for t in inputs if t != input] - op = IRCollectives([input], outputs, ranks, IRCollType.AllGather) - su = ScheduleUnit([op], SUType.Coll, name='allgather') - su.device = rank - allgather_sus.append(su) - - # print('>> find allgather pattern:') - # print(f'device group: {ranks}') - # for input in inputs: - # print(f'src: {input}') - # for output in outputs: - # print(f'dst: {output}') - - if len(allgather_sus) == 0: - return None - else: - return allgather_sus - - @staticmethod - def match_reducescatter(tous, tins): - """ - ReduceScatter semantic: - - Each device performs same - """ - rs_sus = list() - # {tensor_id: [device_id]} - in_devices: Dict[int, List[int]] = dict() - # {tensor_id: [tensors] - in_tensors: Dict[int, List[IRTensor]] = dict() - for devid in tins: - for in_tensor in tins[devid]: - tid = in_tensor._id - if in_tensor.valmap != ValueMap(0, 1): - continue - if tid not in in_devices: - in_devices[tid] = list() - in_tensors[tid] = list() - in_devices[tid].append(devid) - in_tensors[tid].append(in_tensor) - # {in_tensor_id: [reduce_tensor device]} - reduce_out_devices = dict() - # {in_tensor_id: [reduce out tensors]} - reduce_out_tensors = dict() - for tid in in_devices: - # P2P transmission - if len(in_devices[tid]) != 1: - continue - in_tensor = in_tensors[tid][0] - out_tensors = P2PFusion.transmission(tous, in_tensor) - - is_reduce = True - for devid in out_tensors: - # multiple transmission FIXME: remove redundancy - if not all([len(out_tensors[odev]) == 1 for odev in out_tensors]): - continue - if out_tensors[devid][0].valmap == ValueMap(0, 1): - is_reduce = False - break - if out_tensors[devid][0].indmap != in_tensor.indmap: - is_reduce = False - break - if is_reduce: - reduce_out_devices[tid] = list() - reduce_out_tensors[tid] = list() - for devid in out_tensors: - reduce_out_devices[tid].append(devid) - reduce_out_tensors[tid].append(out_tensors[devid][0]) - # reverse reduce_devices {tuple(devices): [in_tensors]} - reduce_tensors = dict() - for tid in reduce_out_devices: - devices = tuple(set(reduce_out_devices[tid])) - if devices not in reduce_tensors: - reduce_tensors[devices] = list() - reduce_tensors[devices].append(in_tensors[tid][0]) - # check conditions - for ranks in reduce_tensors: - reduce_in_tensors = reduce_tensors[ranks] - # reduce-scatter requires tensor num to be equal of num devs - if len(reduce_in_tensors) != len(ranks): - continue - # reduce in tensors should place on different devices - devices = [t.device[0] for t in reduce_in_tensors] - if set(devices) != set(ranks): - continue - - # satisfied! set up inputs, outputs and ranks - ranks = list(ranks) - ranks.sort() - - device_inputs = [None] * len(ranks) - for in_tensor in reduce_in_tensors: - out_tensors = reduce_out_tensors[in_tensor._id] - out_devs = [t.device[0] for t in out_tensors] - inputs = [ - out_tensors[out_devs.index(odev)] for odev in ranks - ] - ridx = ranks.index(in_tensor.device[0]) - device_inputs[ridx] = inputs - for in_tensor in reduce_in_tensors: - rank = in_tensor.device[0] - outputs = [in_tensor] - inputs = [inputs[ranks.index(rank)] for inputs in device_inputs] - op = IRCollectives(inputs, outputs, ranks, IRCollType.ReduceScatter) - su = ScheduleUnit([op], SUType.Coll, name='reducescatter') - su.device = rank - rs_sus.append(su) - - # print('>> find reduce-scatter pattern:') - # print(f'device group: {ranks}') - # for output in reduce_in_tensors: - # tid = output._id - # for input in reduce_out_tensors[tid]: - # print(f'src: {input}') - # print(f'dst: {output}') - - if len(rs_sus) == 0: - return None - else: - return rs_sus - - - @staticmethod - def match_broadcast(tous, tins): - """ - Broadcast semantic: - - The root device send the its tensor to all the devices - """ - broadcast_sus = list() - # {tensor_id: [device_id]} - in_devices: Dict[int, List[int]] = dict() - # {tensor_id: [tensors] - in_tensors: Dict[int, List[IRTensor]] = dict() - for devid in tins: - for in_tensor in tins[devid]: - tid = in_tensor._id - if in_tensor.valmap != ValueMap(0, 1): - continue - if tid not in in_devices: - in_devices[tid] = list() - in_tensors[tid] = list() - in_devices[tid].append(devid) - in_tensors[tid].append(in_tensor) - - for tid in in_devices: - # P2P transmission - if len(in_devices[tid]) <= 2: - continue - in_tensor = in_tensors[tid][0] - out_tensors = P2PFusion.transmission(tous, in_tensor) - # multiple transmission FIXME: remove redundancy - if len(out_tensors.keys()) != 1: - continue - # multiple transmission FIXME: remove redundancy - if len(out_tensors[list(out_tensors.keys())[0]]) != 1: - continue - root_tensor = out_tensors[list(out_tensors.keys())[0]][0] - is_equal = True - for in_tensor in in_tensors[tid]: - if in_tensor != root_tensor: - is_equal = False - break - if not is_equal: - continue - ranks = [root_tensor.device[0]] - inputs = [[root_tensor],] - outputs = [[],] - for output in in_tensors[tid]: - devid = output.device[0] - if devid in ranks: - continue - ranks.append(devid) - outputs.append([output]) - inputs.append([]) - for input, output, rank in zip(inputs, outputs, ranks): - op = IRCollectives(input, output, ranks, IRCollType.Broadcast) - su = ScheduleUnit([op], SUType.Coll, name='broadcast') - su.device = rank - broadcast_sus.append(su) - - # print('>> find broadcast pattern:') - # print(f'device group: {ranks}') - # print(su) - - if len(broadcast_sus) == 0: - return None - - - else: - return broadcast_sus diff --git a/cube/execplan/planpass/redundant.py b/cube/execplan/planpass/redundant.py deleted file mode 100644 index 8ec69ee5..00000000 --- a/cube/execplan/planpass/redundant.py +++ /dev/null @@ -1,54 +0,0 @@ -from cube.execplan import ExectuionPlan -from cube.schedule.su import SUType -from cube.execplan.planpass.planpass import PlanPass - - -class RemoveRedundantAdapters(PlanPass): - - @staticmethod - def apply(execplan: ExectuionPlan) -> ExectuionPlan: - """ - Remove redundant adapters - - A redundant adapter is sending / recving tensors on the same deivce - """ - # remove identity comm - for devid in execplan.devices(): - seq = execplan.sequence(devid) - comms = [su for su in seq if su.stype == SUType.P2P] - for comm in comms: - send_ranks = set([devid]) - recv_ranks = set([devid]) - for node in comm.nodes(): - send_ranks.update(node.send_ranks) - recv_ranks.update(node.recv_ranks) - if list(send_ranks) != [devid]: - continue - if list(recv_ranks) != [devid]: - continue - # remove - execplan.at(devid).remove(comm) - # remove redundant comm e.g., recving same tensor from other ranks - for devid in execplan.devices(): - all_outs = list() - seq = execplan.sequence(devid) - for su in seq: - # zero-output SU will not be removed - removable = len(su.outputs()) != 0 - for output in su.outputs(): - if output not in all_outs: - removable = False - all_outs.append(output) - if removable: - # only recv has output - execplan.at(devid).remove(su) - if su.stype == SUType.P2P: - # remove all the paired send - ranks = su.nodes(0).recv_ranks - if len(ranks) > 1: - raise NotImplementedError - rank = ranks[0] - if su.mirror not in execplan.at(rank): - raise RuntimeError("Recv Op not found!") - execplan.at(rank).remove(su.mirror) - return execplan diff --git a/cube/graph/graph.py b/cube/graph/graph.py index d7c81fe0..825d14fd 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -104,40 +104,6 @@ def nodes(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") - # def insert(self, node, src_node=None, dst_node=None, replaced_tensor=None): - # """ - # Insert a node between src_node and dst_node. In default, - # if dst_node is not None, the node will be inserted right before - # dst_node. If the replaced_tensor is provided, the replaced_tensor - # in dst_node's inputs will be removed, and the output of node will be - # set as input for dst_node. - # """ - # if not isinstance(node, IRCell): - # raise TypeError("Expected IRCell to insert") - # if dst_node is not None: - # if dst_node not in self._nodes: - # raise KeyError("dst_node not found") - # if replaced_tensor is not None: - # if replaced_tensor not in dst_node.inputs(): - # raise RuntimeError(f"Expected dst_node input has {replaced_tensor}") - # # remove dst_node input - # input_index = dst_node.inputs().index(replaced_tensor) - # if len(node.outputs()) != 1: - # raise RuntimeError("replaced node requires output length to be 1") - # dst_node.set_input(input_index, node.outputs(0)) - # # insert node - # index = self._nodes.index(dst_node) - # self._nodes.insert(index, node) - # elif src_node is not None: - # if src_node not in self._nodes: - # raise KeyError("src_node not found") - # index = self._nodes.index(src_node) - # self._nodes = self._nodes[:index+1] + [node] + self._nodes[index+1:] - # else: - # raise TypeError("Expected at least one of [src_node, dst_node]") - # #TODO: optimize this - # self.reset_dependency() - def _replace_tensor(self, old_tensor: IRTensor, new_tensor: IRTensor): """ Replace tensor from old_tensor to new_tensor for all the graph. diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index a8237a91..afea9feb 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -111,7 +111,7 @@ def all_reduce(input_tensors: List[torch.Tensor], """ CudaTimer().start(field_name='comm') assert len(input_tensors) == 1 - tensor = input_tensors + tensor = input_tensors[0] if not tensor.is_contiguous(): tensor = tensor.contiguous() tensor = tensor.detach() From 6ede929dccae9f1edcb2a1c34809f1d6d925ecf4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Dec 2021 20:13:02 +0800 Subject: [PATCH 0505/1892] use PAS abstraction --- examples/mlp/policy/col_parallel.py | 2 +- examples/mlp/policy/data_parallel.py | 41 ++++++---------------- examples/mlp/policy/hybrid_parallel.py | 48 +++++++------------------- examples/mlp/policy/no_parallel.py | 21 +---------- 4 files changed, 26 insertions(+), 86 deletions(-) diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index 2c0a68e7..20b92722 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -48,7 +48,7 @@ def PAS(graph: IRGraph, resource): """ for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - algo = node.algorithms('data') + algo = node.algorithms('column') if algo: sub_nodes = graph.partition( node, algo, config=dict(chunk_num=resource.ngpus) diff --git a/examples/mlp/policy/data_parallel.py b/examples/mlp/policy/data_parallel.py index ff74e6bf..6e1e540d 100644 --- a/examples/mlp/policy/data_parallel.py +++ b/examples/mlp/policy/data_parallel.py @@ -1,40 +1,21 @@ from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation -def transform_policy(graph: IRGraph, resource): +def PAS(graph: IRGraph, resource): """ - The transformation policy transposes linear using data parallel + Linear Column Partition """ - for node in graph.nodes(): + for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): algo = node.algorithms('data') - if algo is None: - algo = node.algorithms('dim') - assert algo - sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=resource.ngpus)) + if algo: + sub_nodes = graph.partition( + node, algo, config=dict(chunk_num=resource.ngpus) + ) else: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - print(graph) + sub_nodes = [node] + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + print(graph.extra_repr()) return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - sugraph.partial_set_order(fsus, lazy=False) - return sugraph diff --git a/examples/mlp/policy/hybrid_parallel.py b/examples/mlp/policy/hybrid_parallel.py index 49decb71..bd0c3784 100644 --- a/examples/mlp/policy/hybrid_parallel.py +++ b/examples/mlp/policy/hybrid_parallel.py @@ -1,46 +1,24 @@ from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation -def transform_policy(graph: IRGraph, resource): +def PAS(graph: IRGraph, resource): """ - The transformation policy transposes linear using column parallel + Linear Hybrid Partition """ - linear_idx = 0 - for node in graph.nodes(): + for idx, node in enumerate(graph.nodes()): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - algo = algo = None - if linear_idx % 2 == 0: - print(f'> column partition: {node}') - algo = node.algorithms('column') - else: - print(f'> row partition: {node}') + if idx % 2 == 0: algo = node.algorithms('row') + else: + algo = node.algorithms('column') if algo: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) - linear_idx += 1 + sub_nodes = graph.partition( + node, algo, config=dict(chunk_num=resource.ngpus) + ) else: - print(f'> replicate: {node}') - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - # print(graph) + sub_nodes = [node] + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + print(graph.extra_repr()) return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - # sugraph.assign(su, list(range(resource.ngpus))) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - return sugraph diff --git a/examples/mlp/policy/no_parallel.py b/examples/mlp/policy/no_parallel.py index 702c02b7..18e3fade 100644 --- a/examples/mlp/policy/no_parallel.py +++ b/examples/mlp/policy/no_parallel.py @@ -1,23 +1,4 @@ from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using column parallel - """ +def PAS(graph: IRGraph, resource): return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - for su in sugraph.fsus(): - sugraph.assign(su, 0) - sugraph.assign(su.mirror, 0) - return sugraph From 42e698f6ebb0908dc51c75d7218e7ec5e7221947 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 10:05:35 +0800 Subject: [PATCH 0506/1892] broadcast fusion --- cube/codegen/codegen.py | 8 ++++++-- cube/execplan/planpass/fusion.py | 22 +++++++++++++--------- cube/graph/adapter/adapter.py | 4 ++-- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 8956d98c..06dfee3c 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -265,11 +265,15 @@ def emit_adapter_call(self, node: IRAdapter): sign = 'cube.runtime.adapter.{ctype}({input_tensors}, {output_shapes}, {output_dtypes}, {group})' inputs = self.tuple_naming(prim.inputs) outputs = self.return_naming(prim.outputs) + dtypes = None + if prim.output_dtypes is not None: + dtypes = [self.dtype_map(dtype) for dtype in prim.output_dtypes] + dtypes = self.tuple_naming(dtypes) body = sign.format( ctype=prim.ctype.value, input_tensors = inputs, output_shapes = prim.output_shapes, - output_dtypes = prim.output_dtypes, + output_dtypes = dtypes, group=prim.group ) code = f'{outputs} = {body}' @@ -279,7 +283,7 @@ def emit_adapter_call(self, node: IRAdapter): # requires grad generation sign = '{output} = {output}.contiguous().requires_grad_()' for output in node.outputs(): - if isinstance(output, IRSubTensor): + if isinstance(output, IRSubTensor) and output.requires_grad: code = sign.format(output=self.tensor_naming(output)) self.forward_region.append(code) diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index c5516a2b..17e67426 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -290,28 +290,27 @@ def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): if not P2PFusion._check_different_outputs_devices(adapters, among=True): continue # gen broadcast - print(f'generating broadcast for tensor: {outputs[tid]} ... (NOT SUPPORTED)') - return + print(f'generating broadcast for tensor: {outputs[tid]}') # put root rank to the first root = list(device)[0] group = set() for adapter in adapters: group.update(P2PFusion._get_output_devices(adapter)) - group.remove(root[0]) - # inputs - tensor = adapters[0].inputs() group = [root] + list(group) + # input + tensor = adapters[0].inputs(0) prims = list() for device in group: - inputs = tensor if device == root else None + inputs = [tensor] if device == root else None output_shapes = [tensor.shape] output_dtypes = [tensor.dtype] coll = CollectivePrim( ctype = CollectivePrim.Type.Broadcast, - device = device, + device = [device], group = group, inputs = inputs, + outputs = [tensor], output_shapes = output_shapes, output_dtypes = output_dtypes ) @@ -323,8 +322,13 @@ def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): inputs=[tensor], idevices=[[root],], outputs=[tensor], odevices=[[root],] ) - # TODO: this should insert into graph and execution plan - for adapter in [root_adapter] + adapters: + # insert into graph and execution plan + index = min([execplan.graph.nodes().index(n) for n in adapters]) + execplan.graph._nodes.insert(index, root_adapter) + seq = [node for node in execplan.graph.nodes() if root in node.device] + execplan.set(root, seq) + + for adapter in adapters: device = adapter.odevice(0)[0] prim = prims[group.index(device)] adapter._prims = [prim] diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index 3cc92036..07cf9a4f 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -73,11 +73,11 @@ def __init__(self, ctype: Enum, """ self.ctype = ctype # inputs - self.inputs: List[IRSubTensor] = inputs + self.inputs: List[IRSubTensor] = inputs if inputs is not None else list() self.input_shapes: List[IRSubTensor] = input_shapes self.input_dtypes: List[IRDType] = input_dtypes # outputs - self.outputs: List[IRSubTensor] = outputs + self.outputs: List[IRSubTensor] = outputs if outputs is not None else list() self.output_shapes: List[IRSubTensor] = output_shapes self.output_dtypes: List[IRDType] = output_dtypes # communication group From 2d3d2c7423ae1262dad4b23c8257cf8fffa5457e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 11:29:05 +0800 Subject: [PATCH 0507/1892] add control dependency --- cube/ir/cten.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 43f0c576..3c31823c 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -63,12 +63,10 @@ def __init__(self, for tensor in self._outputs: tensor.attach_cell(self) - # destination cells - # -- will only be set when initializing to a graph - self._successors: List[List[IRCell]] = [list() for _ in range(output_length)] - # source cells: note a tensor can be generated by many cells - # -- will only be set when initializing to a graph - self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length)] + # destination cells. [-1] for control dependency + self._successors: List[List[IRCell]] = [list() for _ in range(output_length+1)] + # source cells. [-1] for control dependency + self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length+1)] self._mirror = None self._tag = None @@ -150,6 +148,7 @@ def inputs(self, index: Optional[int] = None) -> Union[List[Any], Any]: def predecessors(self, index: Optional[int] = None) -> List: """ Get input operator at input index + (or index = -1 for control dependency) Returns: cell(s): Union[List[IRCell], IRCell] @@ -197,8 +196,8 @@ def successors(self, index: Optional[int] = None) -> List: Args: index (int or None): - index of the outputs, None will return the nodes - for all the outputs + index of the outputs (or -1 for control dependency), + None will return the nodes for all the outputs """ if isinstance(index, int): if index >= len(self._outputs): @@ -270,7 +269,9 @@ def add_predecessor(self, input_index: int, cell): """ Add a predecessor cell in the input_index slot. - Note this won't add successor if caller cell to the node + Note this won't add successor if caller cell to the node + + To add control dependency, use `input_index=-1` """ if not isinstance(cell, IRCell): raise TypeError("Expected node to be IRCell") @@ -286,13 +287,15 @@ def clear_predecessor(self): Clear all predecessors """ self._predecessors = [ - list() for _ in range(len(self.inputs())) + list() for _ in range(len(self.inputs()) + 1) ] def add_successor(self, output_index: int, cell): """ Set self node the output index node. `node` will take the self.outputs(index) as the input + + To add control dependency, use `output_index=-1` """ if not isinstance(cell, IRCell): raise TypeError("Expected node to be IRCell") @@ -304,7 +307,7 @@ def clear_successor(self): Clear all successors """ self._successors = [ - list() for _ in range(len(self.outputs())) + list() for _ in range(len(self.outputs()) + 1) ] def make_empty(self): From 0da69b6c973ad9361d88bc751417c1f1c9cc64fe Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 11:29:36 +0800 Subject: [PATCH 0508/1892] eager adapter insert --- cube/graph/adapter/gen.py | 29 ++++++++++-- cube/graph/graph.py | 99 +++++++++++++++++++++++++++++++++++---- 2 files changed, 114 insertions(+), 14 deletions(-) diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index 9ed490f0..0d271c33 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -9,20 +9,32 @@ class AdapterGener: @staticmethod - def gen(graph: IRGraph) -> IRGraph: + def gen(graph: IRGraph, eager=True) -> IRGraph: """ Generate tensor adapter for both intermediate tensors and weights + + Args: + graph: IRGraph. + eager (Boolean): + if True, + each adapter will be inserted right after it's ready to execute. + if False (i.e., lazy), + each adatper will be inserted right before the tensor needs it. + Note weight reducers are always append to last. + Returns: + graph (IRGraph) """ - graph = AdapterGener.gen_activation_adapter(graph) + graph = AdapterGener.gen_activation_adapter(graph, eager) graph = AdapterGener.gen_weight_reducer(graph) return graph @staticmethod - def gen_activation_adapter(graph: IRGraph) -> IRGraph: + def gen_activation_adapter(graph: IRGraph, eager=True) -> IRGraph: # update the gradient before generate adapter for node in graph.nodes(): if isinstance(node, IRBpOperation): node.update() + all_adapters = list() # generate adapter for non-weight values for node in graph.nodes(): if isinstance(node, IRFwOperation): @@ -34,17 +46,24 @@ def gen_activation_adapter(graph: IRGraph) -> IRGraph: continue adapter = IRAdapter.gen(input) if not adapter.is_identity(): + all_adapters.append(adapter) idx = graph.nodes().index(node) graph._nodes.insert(idx, adapter) if isinstance(node, IRBpOperation): for grad in node.inputs(): if not isinstance(grad, IRSubTensor): continue - # skip parameter adapter = IRAdapter.gen(grad) if not adapter.is_identity(): + all_adapters.append(adapter) idx = graph.nodes().index(node) graph._nodes.insert(idx, adapter) + graph.reset_dependency() + if eager: + seq = graph.nodes() + for adapter in all_adapters: + seq.remove(adapter) + graph.partial_set_order(seq, eager=True) return graph @@ -73,7 +92,7 @@ def gen_weight_reducer(graph: IRGraph) -> IRGraph: if grad in grads[input._id][devid]: raise RuntimeError("Already logged grad?") grads[input._id][devid].append(grad) - # step 2: generate weight. + # step 2: generate reducers. # reducers: tuple(ranks): List[weight] reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() for wid in grads: diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 825d14fd..664407eb 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -38,10 +38,8 @@ def __init__(self, if inputs is None: inputs = IRGraph.get_inputs(nodes) - # inputs = [t for t in inputs if not t.is_param()] if outputs is None: outputs = IRGraph.get_outputs(nodes) - # outputs = [t for t in outputs if not t.is_param()] super().__init__( name=module_name, @@ -82,6 +80,14 @@ def reset_dependency(self): if out_tensor.overlap(in_tensor): src_node.add_successor(out_idx, dst_node) dst_node.add_predecessor(in_idx, src_node) + # set mirror as control dependency + for idx1, node1 in enumerate(self._nodes): + node2 = node1.mirror + if isinstance(node2, IRCell) and node2 in self._nodes: + idx2 = self._nodes.index(node2) + if idx1 < idx2: + node1.add_successor(-1, node2) + node2.add_predecessor(-1, node1) def parameters(self): """ @@ -386,10 +392,84 @@ def happen_before(self, node1: IRCell, node2: IRCell, skip=None) -> bool: return False def set_order(self, seq: List[IRCell]): - raise NotImplementedError + """ + Set a topological order for IRGraph, which requires seq: - def partial_set_order(self, seq: List[IRCell], lazy=False): - raise NotImplementedError + 1). The set of nodes in seq must be same with this IRGraph + 2). Staisfies topological order + + Returns: + True if set succesfully, False not. + """ + for node in seq: + if node not in self.nodes(): + return False + if len(seq) != len(self.nodes()): + return False + if not IRGraph.check_legal_order(seq, integrity_check=True): + return False + self._nodes = seq + return True + + def partial_set_order(self, seq: List[IRCell], eager=True): + """ + Set a partial topological order for IRGrah. + The remaining nodes will be automatically inserted to + make the full legal sequence. + + In most of the cases, `eager=True` has better performance. + + Args: + seq: partial scheduling sequence + eager (default True): + if True, the remaining nodes are inserted once it is ready + if Flase, the remaining nodes are inserted only when it is needed. + + Returns: + True if set succesfully, False not. + """ + seq = copy.copy(seq) + for node in seq: + if node not in self.nodes(): + return False + if not IRGraph.check_legal_order(seq, integrity_check=False): + return False + remain: List[IRCell] = [node for node in self.nodes() if node not in seq] + for node in remain: + if eager: + pre_indices = [seq.index(pre) for pre in node.predecessors()] + index = max(pre_indices) + 1 + else: + suc_indices = [seq.index[suc] for suc in node.successors()] + index = min(suc_indices) + seq.insert(index, node) + self._nodes = seq + return True + + @staticmethod + def check_legal_order(seq: List[IRCell], integrity_check=False): + """ + Check whether seq satisfies topological order. + + Args: + seq: List of IRCell + integrity_check: + If true, performs additional integrity check that requires + all the SUs in predecessor and successor of a SU should + appear in the sequence. + + Returns: + Boolean: True for satisfying topo order, otherwise False. + """ + for index, node in enumerate(seq): + for pre in node.predecessors(): + if pre in seq: + pre_idx = seq.index(pre) + if pre_idx >= index: + return False + elif integrity_check: + return False + return True def __repr__(self): dscp = f"Graph{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" @@ -403,10 +483,11 @@ def extra_repr(self): for node in self._nodes: # if isinstance(node, IRBpOperation): # continue - succ_node_ids = [None] * len(node.outputs()) - for out_idx in range(len(node.outputs())): - node_list = [snode._id for snode in node.successors(out_idx)] - succ_node_ids[out_idx] = node_list + succ_node_ids = [node._id for node in node.successors()] + # succ_node_ids = [None] * len(node.outputs()) + # for out_idx in range(len(node.outputs())): + # node_list = [snode._id for snode in node.successors(out_idx)] + # succ_node_ids[out_idx] = node_list dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" # outputs dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" From d06933c32a31cd274d53e8240c94a97999707de9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 11:30:13 +0800 Subject: [PATCH 0509/1892] remove useless schedule --- cube/schedule/__init__.py | 4 - cube/schedule/adapter/__init__.py | 0 cube/schedule/adapter/collectives.py | 50 --- cube/schedule/adapter/comm.py | 78 ---- cube/schedule/adapter/transform.py | 292 ------------- cube/schedule/iterator.py | 149 ------- cube/schedule/pool.py | 50 --- cube/schedule/su.py | 525 ---------------------- cube/schedule/sugraph.py | 630 --------------------------- cube/schedule/translator.py | 103 ----- 10 files changed, 1881 deletions(-) delete mode 100644 cube/schedule/__init__.py delete mode 100644 cube/schedule/adapter/__init__.py delete mode 100644 cube/schedule/adapter/collectives.py delete mode 100644 cube/schedule/adapter/comm.py delete mode 100644 cube/schedule/adapter/transform.py delete mode 100644 cube/schedule/iterator.py delete mode 100644 cube/schedule/pool.py delete mode 100644 cube/schedule/su.py delete mode 100644 cube/schedule/sugraph.py delete mode 100644 cube/schedule/translator.py diff --git a/cube/schedule/__init__.py b/cube/schedule/__init__.py deleted file mode 100644 index 23058c3a..00000000 --- a/cube/schedule/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType -from cube.schedule.translator import IRDataLoader -from cube.schedule.sugraph import SUGraph diff --git a/cube/schedule/adapter/__init__.py b/cube/schedule/adapter/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/schedule/adapter/collectives.py b/cube/schedule/adapter/collectives.py deleted file mode 100644 index 026772b3..00000000 --- a/cube/schedule/adapter/collectives.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import List, Type -from enum import Enum - -from cube.ir.cten import IRCell, IRTensor - - -class IRCollType(Enum): - - AllReduce = 'all_reduce' - AllGather = 'all_gather' - ReduceScatter = 'reduce_scatter' - Broadcast = 'broadcast' - - -class IRCollectives(IRCell): - """ - Collective cell for IRCell - """ - - def __init__(self, inputs: List[IRTensor], outputs: List[IRTensor], - ranks: List[int], colltype: IRCollType): - - if not isinstance(colltype, IRCollType): - raise TypeError("colltype Expected IRCollType") - if not all([isinstance(rank, int) for rank in ranks]): - raise TypeError("ranks should be List[int]") - - self.comm_type = colltype - if colltype == IRCollType.AllReduce: - signature = 'cube.runtime.collectives.all_reduce' - if colltype == IRCollType.AllGather: - signature = 'cube.runtime.collectives.all_gather' - if colltype == IRCollType.ReduceScatter: - signature = 'cube.runtime.collectives.reduce_scatter' - if colltype == IRCollType.Broadcast: - signature = 'cube.runtime.collectives.broadcast' - - self.ranks = ranks - - super().__init__( - name = colltype.value, - signature = signature, - input_length = len(inputs), - output_length = len(outputs) - ) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - for idx, output in enumerate(outputs): - self.set_output(idx, output) - diff --git a/cube/schedule/adapter/comm.py b/cube/schedule/adapter/comm.py deleted file mode 100644 index 671b65fc..00000000 --- a/cube/schedule/adapter/comm.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import List -from enum import Enum - -from cube.ir.cten import IRCell, IRTensor - - -class IRCommType(Enum): - - Send = 'send' - Recv = 'recv' - SendRecv = 'sendrecv' - - -class IRCommunication(IRCell): - """ - Communication cell for IRCell - """ - - def __init__(self, - send_tensors=list(), send_ranks: List[List[int]] = list(), - recv_tensors=list(), recv_ranks: List[List[int]] =list()): - """ - Create a basic send, recv or sendrecv communication node - """ - if len(send_tensors) != 0 and len(recv_tensors) != 0: - comm_type = IRCommType.SendRecv - signature = 'cube.runtime.collectives.sendrecv' - elif len(send_tensors) != 0 and len(recv_tensors) == 0: - comm_type = IRCommType.Send - signature = 'cube.runtime.collectives.send' - elif len(recv_tensors) != 0 and len(send_tensors) == 0: - comm_type = IRCommType.Recv - signature = 'cube.runtime.collectives.recv' - else: - raise ValueError( - "Expected at least one of send_tensors and recv_tensors" - ) - - self.comm_type = comm_type - self.send_tensors = list() - self.send_ranks = list() - self.recv_tensors = list() - self.recv_ranks = list() - - super().__init__( - name = comm_type.value, - signature = signature, - input_length = len(send_tensors), - output_length = len(recv_tensors) - ) - - for idx, (tensor, to_device) in enumerate(zip(send_tensors, send_ranks)): - self.set_input(idx, tensor) - self.send_tensors.append(self.inputs(idx)) - self.send_ranks.append(to_device) - - for idx, (tensor, from_device) in enumerate(zip(recv_tensors, recv_ranks)): - self.set_output(idx, tensor) - self.recv_tensors.append(self.outputs(idx)) - self.recv_ranks.append(from_device) - - def __repr__(self): - inputs = list() - for tensor in self.inputs(): - if isinstance(tensor, IRTensor): - inputs.append(f't{tensor._id}-dev{tensor.device}') - else: - inputs.append(tensor) - - outputs = list() - for tensor in self.outputs(): - if isinstance(tensor, IRTensor): - outputs.append(f't{tensor._id}-dev{tensor.device}') - else: - outputs.append(tensor) - - dscp = f'SendRecv(id={self._id}, send={inputs}, recv={outputs})' - return dscp diff --git a/cube/schedule/adapter/transform.py b/cube/schedule/adapter/transform.py deleted file mode 100644 index 358eb9c6..00000000 --- a/cube/schedule/adapter/transform.py +++ /dev/null @@ -1,292 +0,0 @@ -import copy -from typing import List, Optional -from enum import Enum -import numpy as np - -from cube.ir.cten import IRCell -from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap - - -class IRTransformType(Enum): - - Select = 'cube.runtime.adapter.select' - Merge = 'cube.runtime.adapter.merge' - - -class IRTensorTransform(IRCell): - """ - Tensor transformation by convert source tensors - to destination tensors - - Select: - src_tensors is only one tensor, dst_tensors has (multiple) tensors. - This will select the sub_tensor and generate what it need - - Merge: - src_tensors has (multiple) tensors, dst_tensors is only one tensor. - This will merge the sub_tensor and generate what it need - """ - def __init__(self, src_tensors: List[IRSubTensor], dst_tensors: List[IRSubTensor]): - - if not all([isinstance(t, IRSubTensor) for t in src_tensors]): - raise TypeError("Expected src tensors to be IRSubTensor") - if not all([isinstance(t, IRSubTensor) for t in dst_tensors]): - raise TypeError("Expected dst tensors to be IRSubTensor") - if not ((len(src_tensors) == 1) or (len(dst_tensors) == 1)): - raise ValueError("Expected at least one of tensors has length 1") - - self.ttype = None - self._trace = list() - - # select - if len(src_tensors) == 1: - self.ttype = IRTransformType.Select - self._trace = SelectPlan.gen(src_tensors[0], dst_tensors) - - # merge - elif len(dst_tensors) == 1: - self.ttype = IRTransformType.Merge - self._trace = MergePlan.gen(src_tensors, dst_tensors[0]) - - else: - raise NotImplementedError - - super().__init__( - name = 'transformation', - signature = self.ttype.value, - input_length = len(src_tensors), - output_length = len(dst_tensors) - ) - for idx, input in enumerate(src_tensors): - self.set_input(idx, input) - for idx, output in enumerate(dst_tensors): - self.set_output(idx, output) - - def trace(self): - """ - Get trace of transformation - """ - return copy.copy(self._trace) - - def is_identity(self): - """ - Check if this transformation is a non-op - """ - return len(self._trace) == 0 - - -class SelectPrim: - - def __init__(self, tensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, shape: List[int]): - self.tensor = tensor - self.indmap = indmap - self.valmap = valmap - self.shape = shape - self.output = None - - def set_output(self, output: IRSubTensor): - self.output = output - - def __repr__(self): - dscp = f't{self.output._id} = select(t{self.tensor._id}, {self.indmap}, {self.valmap}, {self.shape})' - return dscp - - -class SelectPlan: - - @staticmethod - def gen(input: IRSubTensor, outputs: List[IRSubTensor]) -> List[SelectPrim]: - trace: List[SelectPrim] = list() - islicers: List[slice] = input.indmap.get() - for output in outputs: - if output == input: - continue - oslicers: List[slice] = output.indmap.get() - # indmap - indmap = list() - for islicer, oslicer in zip(islicers, oslicers): - istart, istop, istep = islicer.start, islicer.stop, islicer.step - ostart, ostop, ostep = oslicer.start, oslicer.stop, oslicer.step - if ostep % istep != 0: - raise RuntimeError("Step condition fails") - # relative offset - start = ostart - istart - stop = start + ostop - ostart - slicer = slice(start, stop, ostep) - indmap.append(slicer) - indmap = IndexMap(tuple(indmap)) - # value map - if output.valmap == input.valmap: - valmap = ValueMap(0, 1) - elif input.valmap == ValueMap(0, 1): - valmap = output.valmap - else: - print('from: ', input) - print('to : ', output) - raise NotImplementedError( - f"Not supported value select: {input.valmap} -> {output.valmap}" - ) - prim = SelectPrim(input, indmap, valmap, output.shape) - prim.set_output(output) - trace.append(prim) - return trace - - -class MergePrim: - def __init__(self, - tensors: List[IRSubTensor], - concat: Optional[int] = None, - add: bool = False): - if not ((concat is not None) ^ (add is True)): # xor condition - raise RuntimeError("Expected concat or add") - self.tensors = tensors - self.concat = concat - self.add = add - self.output = None - # re-order tensor - if isinstance(concat, int): - slicers = [tensor.indmap.get()[concat] for tensor in tensors] - starts = np.array([slicer.start for slicer in slicers], dtype=int) - sorted_idx = np.argsort(starts) - tensors = np.array(tensors)[sorted_idx] - self.tensors = tensors.tolist() - - def set_output(self, output: IRSubTensor): - self.output = output - - - def __repr__(self): - tensors = [f't{t._id}' for t in self.tensors] - tensors = '[' + ', '.join(tensors) + ']' - dscp = f't{self.output._id} = merge({tensors}, axis={self.concat}, add={self.add})' - return dscp - - -class MergePlan: - - @staticmethod - def gen(inputs: List[IRSubTensor], output: IRSubTensor) -> List[MergePrim]: - """ - Generate merge plan from input tensors to the output. - """ - if not all([isinstance(t, IRSubTensor) for t in inputs]): - raise TypeError("Expected inputs: List[IRSubTensor]") - if not isinstance(output, IRSubTensor): - raise TypeError("Expected inputs: List[IRSubTensor]") - - trace : List[MergePrim] = list() - remain_tensors = copy.copy(inputs) - dst_tensor = output - if dst_tensor in remain_tensors: - return trace - out = None - while out != dst_tensor: - # concat or merge - out = None - merge = False - for idx1 in range(len(remain_tensors) - 1): - for idx2 in range(idx1 + 1, len(remain_tensors)): - tensor1 = remain_tensors[idx1] - tensor2 = remain_tensors[idx2] - out = MergePlan.concat(tensor1, tensor2) - if out is not None: - out_tensor, concat_dim = out - out = out_tensor - prim = MergePrim([tensor1, tensor2], concat_dim, False) - prim.set_output(out_tensor) - trace.append(prim) - merge = True - break - out = MergePlan.add(tensor1, tensor2) - if out is not None: - prim = MergePrim([tensor1, tensor2], None, True) - prim.set_output(out) - trace.append(prim) - merge = True - break - if merge: - remain_tensors.remove(tensor1) - remain_tensors.remove(tensor2) - remain_tensors.append(out) - break - # cannot merge or add - if out is None: - raise RuntimeError("Merge Plan not found") - return trace - - - @staticmethod - def concat(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: - """ - Check if two tensor can be merged. - If they can be merged, return the merge index - """ - if not isinstance(tensor1, IRSubTensor) or not isinstance(tensor2, IRSubTensor): - raise TypeError("Expected two tensors") - if tensor1.overlap(tensor2): - return None - if tensor1.parent != tensor2.parent: - return None - if tensor1.valmap != tensor2.valmap: - return None - indices1 = tensor1.indmap.get() - indices2 = tensor2.indmap.get() - indmap = list() - if len(indices1) != len(indices2): - return None - axis = None - for dim, (slicer1, slicer2) in enumerate(zip(indices1, indices2)): - if slicer1 != slicer2: - start1, stop1, step1 = slicer1.start, slicer1.stop, slicer1.step - start2, stop2, step2 = slicer2.start, slicer2.stop, slicer2.step - if step1 != step2: - return None - if axis is not None: - return None - if start1 < start2 and stop1 == start2: - axis = dim - indmap.append(slice(start1, stop2, step1)) - elif start1 > start2 and start1 == stop2: - axis = dim - indmap.append(slice(start2, stop1, step1)) - else: - return None - else: - indmap.append(slicer1) - shapes = list() - for idx, (nele1, nele2) in enumerate(zip(tensor1.shape, tensor2.shape)): - nele = nele1 if idx != axis else nele1 + nele2 - shapes.append(nele) - mtensor = tensor1.parent.select( - indmap = tuple(indmap), - valmap = tensor1.valmap, - shape = shapes - ) - return mtensor, axis - - @staticmethod - def add(tensor1: IRSubTensor, tensor2: IRSubTensor) -> int: - if not isinstance(tensor1, IRSubTensor) or not isinstance(tensor2, IRSubTensor): - raise TypeError("Expected two tensors") - if tensor1.overlap(tensor2): - return None - if tensor1.parent != tensor2.parent: - return None - if tensor1.indmap != tensor2.indmap: - return None - if tensor1.valmap.chunk_num != tensor2.valmap.chunk_num: - return None - chunk_num = tensor1.valmap.chunk_num - idx1, idx2 = tensor1.valmap.idx, tensor2.valmap.idx - if chunk_num % 2 != 0: - return None - chunk_num = int(chunk_num // 2) - if int(idx1 // 2) != int(idx2 // 2): - return None - idx = int(idx1 // 2) - mtensor = tensor1.parent.select( - indmap = tensor1.indmap, - valmap = (idx, chunk_num), - shape = tensor1.shape - ) - return mtensor diff --git a/cube/schedule/iterator.py b/cube/schedule/iterator.py deleted file mode 100644 index a05a77a0..00000000 --- a/cube/schedule/iterator.py +++ /dev/null @@ -1,149 +0,0 @@ -from cube.schedule.action import Action -from cube.schedule.checker import correct_check - -import itertools -import numpy as np - - -def _comb(n, m): - """ - Calcualte combination C(n,m): select n from m (n < m) - """ - res = 1 - for j in range(0, min(n, m)): - res *= (m-j) / (min(n, m) - j) - return int(res) - - -def get_pipeline_seq_space_size(nstage, nmb): - """ - Calculate legal sequence number given num stage and num microbatch - - \prod \limits_{i=1}^{nmb} C(nstage, i*nstage) - - Args: - nstage: number of stages - nmb: number of micro batch - - Return: - total legal line - """ - res = 1 - for i in range(1, nmb+1): - res *= _comb(nstage*2, i*nstage*2) - return res - - -def legal_sequence(actions, relations): - """ - Yield all possible legal sequence given the list of actions - - Args: - actions (list[Actions]) - - Yield: - sequence (list[Actions]) - """ - if not all([isinstance(action, Action) for action in actions]): - raise TypeError("Expected the sequence to be list[Action]") - - for seq in itertools.permutations(actions): - seq = list(seq) - if correct_check(seq, actions, relations): - yield seq - - -def ready_action_set(actions, relations): - """ - Return a list of actions can be executed now - """ - ready_actions = list() - for action in actions: - satisfy = True - for (_, succ) in relations: - if succ == action: - satisfy = False - break - if satisfy: - ready_actions.append(action) - return ready_actions - - -def remove_dependency(action, relations): - new_relations = list() - for (pre, succ) in relations: - # remove dependency - if pre == action: - continue - new_relations.append((pre, succ)) - return new_relations - - -def sequence_space(actions, relations, path_shuffle=True, seq=list()): - if len(actions) == 0: - yield seq - # inital entry - entry_actions = ready_action_set(actions, relations) - entry_actions = np.array(entry_actions) - if path_shuffle: - np.random.shuffle(entry_actions) - for aid, action in enumerate(entry_actions): - if len(seq) == 0: - print(f'> search progress: [{aid}/{len(entry_actions)}]...') - seq = seq + [action] - action_idx = actions.index(action) - sub_actions = actions[:action_idx] + actions[action_idx+1:] - sub_relations = remove_dependency(action, relations) - for res in sequence_space(sub_actions, sub_relations, path_shuffle, seq): - yield res - seq = seq[:-1] - - -def sequence_space_bfs(actions, relations, path_shuffle=True): - # reverse relation - reverse_relation = list() - for relation in relations: - reverse_relation.append((relation[1], relation[0])) - # reverse seq - for seq in sequence_space(actions, reverse_relation, path_shuffle): - yield seq[::-1] - - -def sequence_space_batched(actions, relations, bs, bfs=False): - """ - bs: tuple (num_workers, seq_per_worker) - """ - seqs = list() - space_iter = sequence_space_bfs if bfs else sequence_space - for seq in space_iter(actions, relations): - seqs.append(seq) - if len(seqs) % (bs[0] * bs[1]) == 0: - seqs = [seqs[wid*bs[1]:(wid+1)*bs[1]] for wid in range(bs[0])] - yield seqs - seqs = list() - # tail - if len(seqs) != 0: - seqs = [seqs[wid*bs[1]:(wid+1)*bs[1]] for wid in range(bs[0])] - yield seqs - - -def placement_space(actions, ndevice, fb_same=True, path_shuffle=True, assigned=0): - if assigned == len(actions): - yield actions - return - - action = actions[assigned] - device_choice = np.array(list(range(ndevice)), dtype=np.int) - if path_shuffle: - np.random.shuffle(device_choice) - - if fb_same: - for assigned_action in actions[:assigned]: - # assume action name likes 'fS0D1' - if action.name[1:] == assigned_action.name[1:]: - device_choice = [assigned_action.device] - break - for device in device_choice: - action.device = device - for res in placement_space(actions, ndevice, fb_same, path_shuffle, assigned+1): - yield res diff --git a/cube/schedule/pool.py b/cube/schedule/pool.py deleted file mode 100644 index fd9f2045..00000000 --- a/cube/schedule/pool.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import List, Any -import copy - - -class SchedulePool: - - class __SchedulePool: - - def __init__(self): - - self._nodes = list() - self._tapes = dict() - - instance = None - - def __init__(self): - if not SchedulePool.instance: - SchedulePool.instance = SchedulePool.__SchedulePool() - - def __getattr__(self, name): - return getattr(self.instance, name) - - def add_node(self, node): - self.instance._nodes.append(node) - - def nodes(self) -> List: - return copy.copy(self.instance._nodes) - - def tape(self, tensor, trace: Any): - """ - Record the trace generated to this tensor - """ - self.instance._tapes[tensor._id] = trace - - def get_tape(self, tensor): - """ - Get the trace given the tensor - """ - if tensor._id not in self.instance._tapes: - return None - else: - return self.instance._tapes[tensor._id] - - def clear(self): - self.instance._nodes = list() - self.instance._tapes = dict() - - def __repr__(self): - dscp = '\n'.join([repr(node) for node in self._nodes]) - return dscp diff --git a/cube/schedule/su.py b/cube/schedule/su.py deleted file mode 100644 index 2aba8c18..00000000 --- a/cube/schedule/su.py +++ /dev/null @@ -1,525 +0,0 @@ -from typing import List, Optional, Tuple -import copy -from enum import Enum - -from cube.ir.cten import IRCell, IRTensor -from cube.graph.operator import IRFwOperation - - -class SUType(Enum): - - Dataloader = 'next(dataloader)' - - # outputs = cube.runtime.temporal.forward(model, *args) - Forward = 'cube.runtime.executor.fexecute' - - # grads = cube.runtime.temporal.backward( - # input_tensors, output_tensors, output_grads - # ) - Backward = 'cube.runtime.executor.backward' - - Transform = 'cube.runtime.transform' - - # cube.runtime.collectives.sendrecv(send_tensors, send_ranks, - # recv_shapes, from_ranks - # ) - P2P = 'cube.runtime.adapter.sendrecv' - Coll = 'cube.runtime.adapter.coll' - - Optimizer = 'cube.runtime.reducer.Reduce' - - Empty = 'None' - - -class ScheduleUnit(IRCell): - r""" - ScheduleUnit for policy scheduling. - """ - - def __init__(self, nodes: List[IRCell], stype: SUType, name='su'): - """ - Create a ScheduleUnit. - - Args: - nodes (List[IRCell]): A list of nodes in IRGraph - """ - - if not all([isinstance(node, IRCell) for node in nodes]): - raise ValueError("Expected each nodes to be List[IRCell]") - if not isinstance(stype, SUType): - raise TypeError("Expected stype be SUType") - - # get inputs and outputs - inputs = ScheduleUnit.get_inputs(nodes) - # inputs = [input for input in inputs if not input.is_param()] - outputs = ScheduleUnit.get_outputs(nodes) - # outputs = [output for output in outputs if not output.is_param()] - super().__init__( - name = name, - signature = stype.value, - input_length = len(inputs), - output_length = len(outputs) - ) - - self.stype = stype - - self._nodes = nodes - for idx, input in enumerate(inputs): - self.set_input(idx, input) - for idx, output in enumerate(outputs): - self.set_output(idx, output) - - # each input is associated with a reshape (merge) adatpers and - # a couple of send adapters and recv adapters (send + recv) - self._merge_adapters: List[ScheduleUnit] = [None] * len(inputs) - self._send_in_adapters: List[List[ScheduleUnit]] = [ - list() for _ in range(len(inputs)) - ] - self._recv_in_adapters: List[List[ScheduleUnit]] = [ - list() for _ in range(len(inputs)) - ] - - # each input is associated with a reshape (select) adatpers and - # a couple of send adapters and recv adapters (send + recv) - self._select_adapters: List[ScheduleUnit] = [None] * len(outputs) - self._send_out_adapters: List[List[ScheduleUnit]] = [ - list() for _ in range(len(outputs)) - ] - self._recv_out_adapters: List[List[ScheduleUnit]] = [ - list() for _ in range(len(outputs)) - ] - - # additional control dependency for add_flow - self._ctrl_predecessors = list() - self._ctrl_successors = list() - - self._tag = [node.tag for node in nodes] - - def __copy__(self): - """ - Copy the SU. Note the mirror su is also copied - """ - raise NotImplementedError("Copy SU is not supported yet") - - def in_adapters(self, index: Optional[int] = None) -> List: - """ - Get adapter for the input tensor at index - - Returns: - Tuple[List[ScheduleUnit], List[ScheduleUnit]]: - the send_adapters and recv_adapters - """ - if isinstance(index, int): - if index >= len(self._inputs): - raise RuntimeError( - f"Get index out of range ({index} >= {len(self._inputs)})" - ) - send_adapters = copy.copy(self._send_in_adapters[index]) - recv_adapters = copy.copy(self._recv_in_adapters[index]) - return send_adapters, recv_adapters - elif index is None: - all_send_adapters = list() - all_recv_adapters = list() - for adapters in self._send_in_adapters: - all_send_adapters += adapters - for adapters in self._recv_in_adapters: - all_recv_adapters += adapters - return all_send_adapters, all_recv_adapters - else: - raise TypeError("Expected index to be None or int") - - def merge_adapters(self, index: Optional[int] = None) -> List: - """ - Get select adapter for the input tensor at index - - Returns: - Union[ScheduleUnit, List[ScheduleUnit]] - """ - if isinstance(index, int): - if index >= len(self._inputs): - raise RuntimeError( - f"Get index out of range ({index} >= {len(self._inputs)})" - ) - select_adapter = self._merge_adapters[index] - return select_adapter - elif index is None: - return copy.copy(self._merge_adapters) - else: - raise TypeError("Expected index to be None or int") - - def out_adapters(self, index: Optional[int] = None) -> Tuple[List, List]: - """ - Get adapter for the output tensor at index - - Returns: - Tuple[List[ScheduleUnit], List[ScheduleUnit]]: - the send_adapters and recv_adapters - """ - if isinstance(index, int): - if index >= len(self._outputs): - raise RuntimeError( - f"Get index out of range ({index} >= {len(self._outputs)})" - ) - send_adapters = copy.copy(self._send_out_adapters[index]) - recv_adapters = copy.copy(self._recv_out_adapters[index]) - return send_adapters, recv_adapters - elif index is None: - all_send_adapters = list() - all_recv_adapters = list() - for adapters in self._send_out_adapters: - all_send_adapters += adapters - for adapters in self._recv_out_adapters: - all_recv_adapters += adapters - return all_send_adapters, all_recv_adapters - else: - raise TypeError("Expected index to be None or int") - - def select_adapters(self, index: Optional[int] = None) -> List: - """ - Get select adapter for the input tensor at index - - Returns: - Union[ScheduleUnit, List[ScheduleUnit]] - """ - if isinstance(index, int): - if index >= len(self._outputs): - raise RuntimeError( - f"Get index out of range ({index} >= {len(self._outputs)})" - ) - select_adapter = self._select_adapters[index] - return select_adapter - elif index is None: - return copy.copy(self._select_adapters) - else: - raise TypeError("Expected index to be None or int") - - def _clear_adapters(self, ctrl=False): - """ - Clear all adapters for this SU. By default control dependency is keeped - - Args: - ctrl (boolean): if true, additional control dependency is removed. - if false, additional control dependency is keeped. - """ - self._send_in_adapters: List[List[ScheduleUnit]] = [ - list() for _ in range(len(self.inputs())) - ] - self._recv_in_adapters: List[List[ScheduleUnit]] = [ - list() for _ in range(len(self.inputs())) - ] - self._merge_adapters: List[ScheduleUnit] = [None] * len(self._inputs) - self._select_adapters: List[ScheduleUnit] = [None] * len(self._outputs) - self._send_out_adapters: List[List[ScheduleUnit]] = [ - list() for _ in range(len(self.outputs())) - ] - self._recv_out_adapters: List[List[ScheduleUnit]] = [ - list() for _ in range(len(self.outputs())) - ] - if ctrl: - self._ctrl_predecessors = list() - self._ctrl_successors = list() - - def _add_in_adapter(self, index: int, send_adapters, recv_adapters): - """ - Add adapters to the input tensor of this SU - - Args: - index (int): the input index - send_adapter (ScheduleUnit) - recv_adapter (ScheduleUnit) - """ - if index >= len(self._inputs): - raise ValueError(f"index {index} out of range {len(self._inputs)}") - if isinstance(send_adapters, ScheduleUnit): - send_adapters = [send_adapters] - if not all(isinstance(adapter, ScheduleUnit) for adapter in send_adapters): - raise TypeError("Expected send adapter to be (list of) ScheduleUnit") - if isinstance(recv_adapters, ScheduleUnit): - recv_adapters = [recv_adapters] - if not all(isinstance(adapter, ScheduleUnit) for adapter in send_adapters): - raise TypeError("Expected recv adapters to be (list of) ScheduleUnit") - if len(send_adapters) != len(recv_adapters): - raise ValueError("Expected same number of send / recv adapters") - for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): - self._send_in_adapters[index].append(send_adapter) - self._recv_in_adapters[index].append(recv_adapter) - - def _set_merge_adapter(self, index: int, merge_adapter): - """ - Set adapters to the input tensor of this SU - - Args: - index (int): the input index - merge_adapter (ScheduleUnit) - """ - if index >= len(self._inputs): - raise ValueError(f"index {index} out of range {len(self._inputs)}") - if merge_adapter is not None and not isinstance(merge_adapter, ScheduleUnit): - raise TypeError("Expected merge adapter to be None or ScheduleUnit") - self._merge_adapters[index] = merge_adapter - - def _add_out_adapter(self, index: int, send_adapters, recv_adapters): - """ - Add adapters to the output tensor of this SU - - Args: - index (int): the output index - send_adapter (ScheduleUnit) - recv_adapter (ScheduleUnit) - """ - if index >= len(self._outputs): - raise ValueError(f"index {index} out of range {len(self._outputs)}") - if isinstance(send_adapters, ScheduleUnit): - send_adapters = [send_adapters] - if not all(isinstance(adapter, ScheduleUnit) for adapter in send_adapters): - raise TypeError("Expected send adapter to be (list of) ScheduleUnit") - if isinstance(recv_adapters, ScheduleUnit): - recv_adapters = [recv_adapters] - if not all(isinstance(adapter, ScheduleUnit) for adapter in send_adapters): - raise TypeError("Expected recv adapters to be (list of) ScheduleUnit") - if len(send_adapters) != len(recv_adapters): - raise ValueError("Expected same number of send / recv adapters") - for send_adapter, recv_adapter in zip(send_adapters, recv_adapters): - self._send_out_adapters[index].append(send_adapter) - self._recv_out_adapters[index].append(recv_adapter) - - def _set_select_adapter(self, index: int, select_adapter): - """ - Set adapters to the output tensor of this SU - - Args: - index (int): the output index - select_adapter (ScheduleUnit) - """ - if index >= len(self._outputs): - raise ValueError(f"index {index} out of range {len(self._inputs)}") - if select_adapter is not None and not isinstance(select_adapter, ScheduleUnit): - raise TypeError("Expected merge adapter to be Optional[ScheduleUnit]") - self._select_adapters[index] = select_adapter - - def _remove_adapter(self, adapter): - """ - Remove the adapter - """ - for send_adapters in self._send_in_adapters: - if adapter in send_adapters: - send_adapters.remove(adapter) - return True - for recv_adapters in self._recv_in_adapters: - if adapter in recv_adapters: - recv_adapters.remove(adapter) - return True - if adapter in self._merge_adapters: - idx = self._merge_adapters.index(adapter) - self._merge_adapters[idx] = None - return True - if adapter in self._select_adapters: - idx = self._select_adapters.index(adapter) - self._select_adapters[idx] = None - return True - for send_adapters in self._send_out_adapters: - if adapter in send_adapters: - send_adapters.remove(adapter) - return True - for recv_adapters in self._recv_out_adapters: - if adapter in recv_adapters: - recv_adapters.remove(adapter) - return True - return False - - def nodes(self, index: Optional[int] = None): - """ - Get node at position index - """ - if isinstance(index, int): - if index >= len(self._nodes): - raise RuntimeError( - f"Get node out of range ({index} >= {len(self._nodes)})" - ) - return self._nodes[index] - elif index is None: - return copy.copy(self._nodes) - else: - raise TypeError("Expected index to be None or int") - - def add_predecessor(self, input_index: int, su): - """ - Add a predecessor cell in the input_index slot. - self.input[input_index] = node.output[out_index] - """ - if input_index == -1: - self._ctrl_predecessors.append(su) - else: - super().add_predecessor(input_index, su) - - def predecessors(self, index: Optional[int] = None) -> List: - """ - Get 1-hop predecessor cells including control predecessors - - Args: - index (Optional[int]): - -1: return control predecessors - None: return all predecessors including index - >0 : return input SUs at input index - - Returns: - cell(s): List[IRCell] - """ - if isinstance(index, int): - if index == -1: - return copy.copy(self._ctrl_predecessors) - if index >= len(self._inputs): - raise RuntimeError( - f"Get the input out of range ({index} >= {len(self._inputs)}" - ) - return copy.copy(self._predecessors[index]) - elif index is None: - predecessors = list() - for pre_cells in self._predecessors: - predecessors += pre_cells - predecessors += self._ctrl_predecessors - return predecessors - else: - raise TypeError("Expected index to be None or int") - - def add_successor(self, output_index: int, su): - """ - Set self node the output index node. - `node` will take the self.outputs(index) as the input - """ - if output_index == -1: - self._ctrl_successors.append(su) - else: - super().add_successor(output_index, su) - - def successors(self, index: Optional[int] = None) -> List: - """ - Get 1-hop successor cells including control successors - - Args: - index (Optional[int]): - -1: return control successors - None: return all successors including index - >0 : return output SUs at output index - - Returns: - cells: List[ScheduleUnit] - """ - if isinstance(index, int): - if index == -1: - return copy.copy*self._ctrl_successors - if index >= len(self._outputs): - raise RuntimeError( - f"Get the output out of range ({index} >= {len(self._outputs)}" - ) - return copy.copy(self._successors[index]) - elif index is None: - successors = list() - for post_cells in self._successors: - successors += post_cells - successors += self._ctrl_successors - return successors - else: - raise TypeError("Expected index to be None or int") - - def is_identity(self): - """ - Check if the SU is identity - """ - # not assigned - if len(self.device) == 0: - return False - if self.stype == SUType.P2P: - send_ranks = set(self.device) - recv_ranks = set(self.device) - for node in self.nodes(): - send_ranks.update(node.send_ranks) - recv_ranks.update(node.recv_ranks) - if list(send_ranks) != self.device: - return False - if list(recv_ranks) != self.device: - return False - return True - return False - - @staticmethod - def get_inputs(nodes): - """ - Get all the input tensors the is not generated by nodes - - Inputs - - Returns: - List[IRTensor] - """ - all_outputs = list() - for cell in nodes: - all_outputs += cell.outputs() - inputs = list() - for cell in nodes: - for input in cell.inputs(): - if isinstance(input, IRTensor): - if input not in all_outputs: - if input not in inputs: - inputs.append(input) - return inputs - - @staticmethod - def get_outputs(nodes: List[IRCell]): - """ - Get all the input tensors the is not generated by nodes - - Args: - This will also consider the successor forward nodes. - If it is required by other outside forward nodes, - put in the outputs list - - Returns: - List[IRTensor] - """ - all_inputs = list() - for node in nodes: - all_inputs += node.inputs() - outputs = list() - for node in nodes: - for idx, output in enumerate(node.outputs()): - if isinstance(output, IRTensor): - if output not in all_inputs: - if output not in outputs: - outputs.append(output) - continue - succs = node.successors(idx) - fsuccs = [ - fnode for fnode in succs if isinstance(fnode, IRFwOperation) - ] - for fsucc in fsuccs: - if fsucc not in nodes: - if output not in outputs: - outputs.append(output) - return outputs - - def __repr__(self): - su_inputs = list() - for tensor in self.inputs(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - su_inputs.append(f'{anno}{tensor._id}') - else: - su_inputs.append(tensor) - su_outputs = list() - for tensor in self.outputs(): - if isinstance(tensor, IRTensor): - anno = 't' - if tensor.is_param(): - anno = 'w' - if tensor.is_grad(): - anno = 'g' - su_outputs.append(f'{anno}{tensor._id}') - else: - su_outputs.append(tensor) - dscp = f'SU({self.stype}, nodes={len(self.nodes())})-dev{self.device}: {su_inputs} -> {su_outputs}' - return dscp diff --git a/cube/schedule/sugraph.py b/cube/schedule/sugraph.py deleted file mode 100644 index 9843e03a..00000000 --- a/cube/schedule/sugraph.py +++ /dev/null @@ -1,630 +0,0 @@ -from typing import List, Optional, Union -import copy -from cube.graph.tensor import IRSubTensor - -from cube.ir.cten import IRCell, IRTensor -from cube.graph.operator import IRBpOperation -from cube.graph.operator import IRDataOperation -from cube.graph.operator import IRFwOperation -from cube.schedule.su import SUType, ScheduleUnit -from cube.schedule.adapter.comm import IRCommunication -from cube.schedule.adapter.transform import IRTensorTransform - - -class SUGraph(IRCell): - - def __init__(self, sus: List[ScheduleUnit]): - - if not all([isinstance(su, ScheduleUnit) for su in sus]): - raise TypeError( - f"Expected a list of ScheduleUnits, but got {type(sus)}" - ) - - inputs = IRCell.get_inputs(sus) - inputs = [input for input in inputs if not input.is_param()] - outputs = IRCell.get_outputs(sus) - outputs = [output for output in outputs if not output.is_param()] - super().__init__( - name = 'SU', - signature = 'None', - input_length = len(inputs), - output_length = len(outputs) - ) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - for idx, output in enumerate(outputs): - self.set_output(idx, output) - - self.sequence = sus - SUGraph.reset_dependency(self.sequence) - - @property - def nnodes(self) -> int: - """ - Get number of nodes (int) - """ - return len(self.sequence) - - @staticmethod - def reset_dependency(sus: List[ScheduleUnit]): - """ - Reset the node dataflow dependency - """ - if not all([isinstance(su, ScheduleUnit) for su in sus]): - raise TypeError("Expected list of schedule unit") - adapters = [SUType.P2P, SUType.Coll, SUType.Transform] - for su in sus: - su.clear_predecessor() - su.clear_successor() - for src_idx in range(len(sus)): - src = sus[src_idx] - for dst in sus[src_idx+1:]: - # inter-adapter has no dependency - if src.stype in adapters and \ - dst.stype in adapters and \ - src.stype == dst.stype: - continue - for out_idx, out_tensor in enumerate(src.outputs()): - if not isinstance(out_tensor, IRTensor): - continue - # special dependency for communication adapter - if dst.stype == SUType.P2P: - for recv_tensor in dst.outputs(): - if out_tensor.overlap(recv_tensor): - src.add_successor(out_idx, dst) - dst.add_predecessor(-1, src) - for in_idx, in_tensor in enumerate(dst.inputs()): - if out_tensor.overlap(in_tensor): - src.add_successor(out_idx, dst) - dst.add_predecessor(in_idx, src) - - def __len__(self): - return len(self.sequence) - - def sus(self, index: Optional[int] = None): - """ - Return ScheduleUnit - - Args: - - """ - if isinstance(index, int): - if index >= len(self.sequence): - raise RuntimeError( - f"Get node out of range ({index} >= {len(self.sequence)})" - ) - return self.sequence[index] - elif index is None: - return copy.copy(self.sequence) - else: - raise TypeError("Expected index to be None or int") - - def get_sus(self, stype: SUType) -> List[ScheduleUnit]: - """ - Get SUs that are of stype - """ - return [su for su in self.sequence if su.stype == stype] - - def fsus(self) -> List[ScheduleUnit]: - """ - Get forward ScheduleUnits sequence. - """ - return [su for su in self.sequence if su.stype == SUType.Forward] - - def happen_before(self, su1, su2, visited=None): - """ - Check if the su1 -> (happened before) su2 - - Returns: - Boolean - """ - # FIXME: there is still a strange bug may cause infinite loop - if visited is None: - visited = list() - if su1 in visited: - return False - visited.append(su1) - - if not isinstance(su1, ScheduleUnit) or \ - not isinstance(su2, ScheduleUnit): - raise TypeError("Expected su to be an ScheduleUnit") - if su2 in su1.successors(): - return True - else: - for succ_su in su1.successors(): - # don't need to consider P2P comm dependency - if succ_su.stype == SUType.P2P: - continue - if self.happen_before(succ_su, su2, visited): - return True - return False - - def merge(self, su1: ScheduleUnit, su2: ScheduleUnit) -> ScheduleUnit: - """ - Merge two ScheduleUnit as well as their adapters. This requires - - 1). all the nodes in one SU happens before / after - all the nodes in another SU. (Guaranteed by default - as all the operations on sequence are semantic-correct) - - 2). all the nodes in both SU are on the same device, - have same tags and they are not equal. - - 3). Deadlock-free merge. Suppose - SU1 (dev0) -> SU2 (dev1) -> SU3 (dev0) - Then merge SU1 and SU3 to SU4 will cause - deadlock on SU4 -> <- SU2 - - Note due to PyTorch limitation, - merging two forward ScheduleUnits will also cause - the merge of corresponding two backward ScheduleUnits. - - Returns: - if succeed: A merged ScheduleUnit. - if fail: None - """ - - fsus = self.fsus() - if su1 not in fsus: - raise RuntimeError(f"SU1: {su1} not in forward SUs") - if su2 not in fsus: - raise RuntimeError(f"SU2: {su2} not in forward SUs") - - idx1, idx2 = self.sequence.index(su1), self.sequence.index(su2) - su1, su2 = (su1, su2) if idx1 < idx2 else (su2, su1) - - # condition 1): same device - if su1.device != su2.device: - return None - - # condition 2): su2 input cannot be got from both su1 and other su - start, stop = min(idx1, idx2), max(idx1, idx2) - inter_sus = self.sequence[start+1:stop] - inter_sus = [su for su in inter_sus if su.stype != SUType.P2P] - for su in inter_sus: - # FIXME: currently only allow other device su exists - if self.happen_before(su1, su) or self.happen_before(su, su2): - return None - for idx in range(len(su2.inputs())): - prev_sus = su2.predecessors(idx) - prev_sus = [su for su in prev_sus if su.stype != SUType.P2P] - if su2 in prev_sus and len(prev_sus) > 1: - return None - - # start merging - fnodes = su1.nodes() + su2.nodes() - # TODO: fix multi-branch - fsu = ScheduleUnit(fnodes, SUType.Forward, name='fsu') - fsu.device = su1.device - - bnodes = [node.mirror for node in fnodes][::-1] - skip_bp = all([bnode is None for bnode in bnodes]) - if not skip_bp: - bnode = IRBpOperation( - data_num=len(fsu.inputs()), - grad_num=len(fsu.outputs()) - ) - for idx, fin in enumerate(fsu.inputs()): - bnode.set_data(idx, fin) - - for idx, fout in enumerate(fsu.outputs()): - bnode.set_grad(idx, fout.grad) - - for idx, fin in enumerate(fsu.inputs()): - bnode.set_output(idx, fin.grad) - bsu = ScheduleUnit([bnode], stype=SUType.Backward, name='bsu') - bsu.device = su2.mirror.device - IRCell.make_pair(fsu, bsu) - - def _set_adapters(su1: ScheduleUnit, su2: ScheduleUnit, msu: ScheduleUnit): - # set adapter - for idx, input in enumerate(msu.inputs()): - if input in su1.inputs(): - su1_idx = su1.inputs().index(input) - adapters = su1.in_adapters(su1_idx) - merge_adapter = su1.merge_adapters(su1_idx) - elif input in su2.inputs(): - su2_idx = su2.inputs().index(input) - adapters = su2.in_adapters(su2_idx) - merge_adapter = su2.merge_adapters(su2_idx) - else: - print(f'> Error: msu: {msu}') - print(f'> Error: su1: {su1}') - print(f'> Error: su2: {su2}') - raise RuntimeError("Internal Error: not found input SU") - msu._add_in_adapter(idx, *adapters) - msu._set_merge_adapter(idx, merge_adapter) - for idx, output in enumerate(msu.outputs()): - if output in su1.outputs(): - su1_idx = su1.outputs().index(output) - adapters = su1.out_adapters(su1_idx) - select_adapter = su1.select_adapters(su1_idx) - elif output in su2.outputs(): - su2_idx = su2.outputs().index(output) - adapters = su2.out_adapters(su2_idx) - select_adapter = su2.select_adapters(su2_idx) - else: - raise RuntimeError("Internal Error: not found output SU") - msu._add_out_adapter(idx, *adapters) - msu._set_merge_adapter(idx, select_adapter) - # remove adapters - for idx, input in enumerate(su2.inputs()): - if input not in msu.inputs(): - sadapters, radapters = su2.in_adapters(idx) - for adapter in sadapters + radapters: - if adapter in self.sequence: - self.sequence.remove(adapter) - - _set_adapters(su1, su2, fsu) - if not skip_bp: - _set_adapters(su2.mirror, su1.mirror, bsu) - - # replace - self.sequence[self.sequence.index(su1)] = fsu - self.sequence.remove(su2) - if not skip_bp: - self.sequence[self.sequence.index(su2.mirror)] = bsu - self.sequence.remove(su1.mirror) - - # re-gen adapter - SUGraph.reset_dependency(self.sequence) - return fsu - - def add_flow(self, su1: ScheduleUnit, su2: ScheduleUnit): - """ - Add control flow dependency su1 -> su2 - """ - if not isinstance(su1, ScheduleUnit) or not isinstance(su2, ScheduleUnit): - raise TypeError("Expected both SU1 and SU2 are ScheduleUnit") - if su1 not in self.sequence: - raise ValueError(f"su1 {su1} not in SUGraph") - if su2 not in self.sequence: - raise ValueError(f"su1 {su2} not in SUGraph") - if self.happen_before(su2, su1): - return False - su1.add_successor(-1, su2) - su2.add_predecessor(-1, su1) - return True - - def assign(self, su: ScheduleUnit, ranks: Union[int, List[int]]): - """ - Assign SU to devices. - - The assignment will automatically set device of its Adapter SU. - - 1) if ranks has multiple int, then the su is copied as the same - SU will be happened redundantly on multiple devices. - - 2) if the input tensor this su is decided to be generated on - other devices, then Adapter SUs (send SU and recv SU) will - be generated and inserted right before this SU. - """ - if su not in self.sequence: - raise ValueError(f"SU {su} is not in the SUGraph") - if isinstance(ranks, int): - ranks = [ranks] - elif not all([isinstance(rank, int) for rank in ranks]): - raise TypeError("Expected type ranks to be Union[int, List[int]]") - - if su.stype == SUType.P2P: - return False - - if set(su.device) == set(ranks): - return True - - if len(ranks) != 1: - if su.stype == SUType.Dataloader: - su.device = ranks - else: - raise NotImplementedError("Assign multiple ranks to one SU is not supported") - # print('warning: Missing adapter copy!!') - # sus = [copy.copy(su) for _ in range(len(ranks)-1)] - # for su in sus: - # index = self.sus().index(su) - # self.sequence.insert(index, su) - # SUGraph.reset_dependency(self.sequence) - # for su, rank in zip(sus, ranks): - # self.assign(su, rank) - - # set device - su.device = ranks - - # set adapter device for the input - for idx in range(len(su.inputs())): - send_adapters, recv_adapters = su.in_adapters(idx) - merge_adapter = su.merge_adapters(idx) - for send_adapter in send_adapters: - send_adapter.nodes(0).send_ranks = [ranks[0],] - for recv_adapter in recv_adapters: - recv_adapter.device = ranks - if merge_adapter is not None: - merge_adapter.device = ranks - - # set adapter device for the output - for idx in range(len(su.outputs())): - send_adapters, recv_adapters = su.out_adapters(idx) - select_adapter = su.select_adapters(idx) - for send_adapter in send_adapters: - send_adapter.device = ranks - for recv_adapter in recv_adapters: - recv_adapter.nodes(0).recv_ranks = [ranks[0],] - if select_adapter is not None: - select_adapter.device = ranks - return True - - def set_order(self, seq: List[ScheduleUnit]): - """ - set a topological order for SUGraph, which requires seq: - - 1). The set of SUs in seq must be equal to set of SUGraph - 2). Staisfies topological order - - """ - if not all([isinstance(su, ScheduleUnit) for su in seq]): - raise ValueError("Expected a list of SUs") - if len(seq) != len(self.sequence): - return False - for su in seq: - if su not in self.sequence: - return False - # correctness check - if not SUGraph.is_topo_order(seq, integrity_check=True): - return False - self.sequence = seq - return True - - def partial_set_order(self, seq: List[ScheduleUnit], lazy=False): - """ - Set a order of the sequence using part of SUs. - - A random topological order will be set under - the constraints of given `seq` order - - Args: - seq: partial scheduling sequence - lazy: - if True, the remaining SU is inserted only when it is needed. - if False, the remaining SU is inserted once it is ready. - - """ - if lazy: - raise NotImplementedError("Not supported for Lazy") - seq = copy.copy(seq) - for su in seq: - if su not in self.sequence: - raise RuntimeError(f"SU {su} is not in SUGraph") - if not SUGraph.is_topo_order(seq, integrity_check=False): - return False - remain_sus : ScheduleUnit = list() - for su in self.sequence: - if su not in seq: - remain_sus.append(su) - for rsu in remain_sus: - happen_before_sus = rsu.predecessors() - # A temporal fix for loss computation and backward - # -- as they have no dependency in theory - if rsu.stype == SUType.Backward: - if rsu.mirror not in happen_before_sus: - happen_before_sus.append(rsu.mirror) - # send / recv su pair should be colocated - if rsu.stype == SUType.P2P: - if rsu in seq: - continue - if rsu.mirror in seq: - index = seq.index(rsu.mirror) - seq.insert(index+1, rsu) - continue - if rsu in seq: - raise RuntimeError(f"Internal Error: should not appear SU: {rsu}") - idx = 0 - while len(happen_before_sus) > 0: - if idx == len(seq): - raise RuntimeError( - f"Internal Error: SU {rsu} cannot be inserted" - ) - su = seq[idx] - if su in happen_before_sus: - happen_before_sus.remove(su) - idx += 1 - seq.insert(idx, rsu) - - # if not SUGraph.is_topo_order(seq, integrity_check=True): - # raise RuntimeError("Internal Error: topo is not guaranteed.") - self.sequence = seq - return True - - - @staticmethod - def gen_adapter(sus: List[ScheduleUnit]) -> List[ScheduleUnit]: - """ - Each computation SU has adapters for its inputs. - """ - sugraph = SUGraph(sus) - - # clear adapters - for su in sugraph.sus(): - su._clear_adapters() - - for su in sugraph.sus(): - for in_idx, input in enumerate(su.inputs()): - if not isinstance(input, IRTensor): - continue - pre_sus = su.predecessors(in_idx) - tensor_segments = list() - for pre_su in pre_sus: - for out_idx, output in enumerate(pre_su.outputs()): - if output.overlap(input): - sub_tensor = input.common(output) - if sub_tensor != input and sub_tensor not in tensor_segments: - tensor_segments.append(sub_tensor) - send_op = IRCommunication( - send_tensors=[sub_tensor], - send_ranks = [-1] - ) - recv_op = IRCommunication( - recv_tensors=[sub_tensor], - recv_ranks = [-1] - ) - IRCell.make_pair(send_op, recv_op) - send_su = ScheduleUnit([send_op], SUType.P2P, name='send') - recv_su = ScheduleUnit([recv_op], SUType.P2P, name='recv') - su._add_in_adapter(in_idx, send_su, recv_su) - send_su.device = su.device - pre_su._add_out_adapter(out_idx, send_su, recv_su) - recv_su.device = su.device - IRCell.make_pair(send_su, recv_su) - # add adapter for merge - if len(tensor_segments) != 0: - try: - merge_op = IRTensorTransform( - src_tensors=tensor_segments, dst_tensors=[input] - ) - except Exception: - raise RuntimeError(f"Merge Generation Error: {su}") - merge_su = ScheduleUnit([merge_op], SUType.Transform, name='merge') - su._set_merge_adapter(in_idx, merge_su) - merge_su.device = su.device - - # add adapter for select - for su in sugraph.sus(): - for out_idx, output in enumerate(su.outputs()): - if not isinstance(output, IRTensor): - continue - select_tensors = list() - send_adapters, recv_adapters = su.out_adapters(out_idx) - for send_adapter in send_adapters: - for tensor in send_adapter.nodes(0).send_tensors: - if tensor != output and tensor not in select_tensors: - select_tensors.append(tensor) - if len(select_tensors) != 0: - try: - select_op = IRTensorTransform( - src_tensors=[output], dst_tensors=select_tensors - ) - except Exception: - raise RuntimeError(f"Select Generation Error: {su}") - select_su = ScheduleUnit( - [select_op], SUType.Transform, name='select' - ) - su._set_select_adapter(out_idx, select_su) - select_su.device = su.device - - sus_with_adapter = list() - for su in sus: - # send + recv + merge - for idx in range(len(su.inputs())): - merge_su = su.merge_adapters(idx) - send_adapters, recv_adapters = su.in_adapters(idx) - # PyTorch implementation issue: forward + backward happened on same device - if su.stype == SUType.Backward and not su.inputs(idx).is_grad(): - continue - for send_su, recv_su in zip(send_adapters, recv_adapters): - sus_with_adapter.append(send_su) - sus_with_adapter.append(recv_su) - if merge_su: - sus_with_adapter.append(merge_su) - # excute - sus_with_adapter.append(su) - # select - for idx in range(len(su.outputs())): - select_su = su.select_adapters(idx) - if select_su: - sus_with_adapter.append(select_su) - return sus_with_adapter - - - @staticmethod - def is_topo_order(seq: List[ScheduleUnit], integrity_check=False): - """ - Check whether seq satisfies topological order. - - Args: - seq: List of ScheduleUnit - integrity_check: - If true, performs additional integrity check that requires - all the SUs in predecessor and successor of a SU should - appear in the sequence. - - Returns: - Boolean: True for satisfying topo order, otherwise False. - """ - - for index, su in enumerate(seq): - for pre_su in su.predecessors(): - # find the pre-su not appear in sequence - if integrity_check: - if pre_su not in seq: - return False - if pre_su in seq: - pre_idx = seq.index(pre_su) - # violate topological order - if pre_idx >= index: - return False - return True - - def __repr__(self): - dscp = f'ScheduleSeq (len={len(self)}):\n' - for node in self.sequence: - succ_node_ids = [None] * len(node.outputs()) - for out_idx in range(len(node.outputs())): - node_list = [snode._id for snode in node.successors(out_idx)] - succ_node_ids[out_idx] = node_list - # dscp += f"{node._id}: {node}\n" - dscp += f"\n{node._id}: {node} -> su id {succ_node_ids}\n" - return dscp - - -class SUGraphGener: - - @staticmethod - def gen_sugraph(nodes) -> SUGraph: - """ - Generate SUGraph from SchedulePool - """ - sus = list() - fnodes = list() - fsus: List[ScheduleUnit] = list() - for node in nodes: - su = ScheduleUnit([node], stype=SUType.Empty, name='su') - if isinstance(node, IRDataOperation): - stype = SUType.Dataloader - elif isinstance(node, IRFwOperation): - stype = SUType.Forward - fnodes.append(node) - fsus.append(su) - elif isinstance(node, IRBpOperation): - stype = SUType.Backward - # get the last one same node - index = len(fnodes) - fnodes[::-1].index(node.mirror) - 1 - fsu = fsus[index] - IRCell.make_pair(su, fsu) - # remove fsu - fnodes.pop(index) - fsus.remove(fsu) - else: - raise NotImplementedError("Not implemented node type") - su.stype = stype - sus.append(su) - sus_with_adapter = SUGraph.gen_adapter(sus) - sugraph = SUGraph(sus_with_adapter) - return sugraph - - -class SeqSpace: - - @staticmethod - def space_size(seq, device_num=1): - """ - Calculate legal - """ - - def _comb(n, m): - """ - Calcualte combination C(n,m): select n from m (n < m) - """ - res = 1 - for j in range(0, min(n, m)): - res *= (m-j) / (min(n, m) - j) - return int(res) - - raise NotImplementedError diff --git a/cube/schedule/translator.py b/cube/schedule/translator.py deleted file mode 100644 index dae6ade8..00000000 --- a/cube/schedule/translator.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Traning Logic Translator - -The traning logic first translate the training logic into -Schedule Units, and then add Adapter ScheduleUnit -""" -import copy -from typing import Optional -import torch - -from cube.graph.tensor import IRFullTensor, IRSubTensor -from cube.graph.operator import IRDataOperation -import cube.graph.gpass as gpass -from cube.schedule.pool import SchedulePool - -from cube.runtime.syndata import CubeDataLoader - - -class IRDataLoader: - - def __init__(self, dataloader: CubeDataLoader): - self.dataloader = iter(dataloader) - self.batch_dims = dataloader.get_batch_dims() - - def get_batch_dims(self, idx: Optional[int] = None) -> int: - if idx is None: - return copy.copy(self.batch_dims) - else: - return self.batch_dims[idx] - - def __iter__(self): - return self - - def __next__(self): - return LogicTranslator.load_data(self) - - -class LogicTranslator: - - @staticmethod - def load_data(dataloader: IRDataLoader): - """ - Translator Action: Load data from data loaderw - """ - datas = next(dataloader.dataloader) - if not isinstance(datas, tuple): - datas = (datas,) - - # data IRTensor - outputs = list() - for data in datas: - if torch.is_tensor(data): - data = IRFullTensor(shape=list(data.shape), name='data').tosub() - data.requires_grad = False - outputs.append(data) - - data_op = IRDataOperation( - data_num=len(datas), batch_dims=dataloader.get_batch_dims(), - ) - for idx, output in enumerate(outputs): - data_op.set_output(idx, output) - - SchedulePool().add_node(data_op) - if len(outputs) == 0: return - elif len(outputs) == 1: return outputs[0] - else: return tuple(outputs) - - @staticmethod - def forward(graph, *args): - """ - Translator Action: forward an IRGraph - """ - fgraph = gpass.forward(graph, *args) - for node in fgraph.nodes(): - SchedulePool().add_node(node) - for output in fgraph.outputs(): - SchedulePool().tape(output, fgraph.nodes()) - outputs = fgraph.outputs() - if len(outputs) == 1: return outputs[0] - elif len(outputs) == 0: return None - else: return outputs - - @staticmethod - def backward(loss: IRSubTensor): - """ - Translator Action: backward a tensor - """ - trace = SchedulePool().get_tape(loss) - if trace is None: - raise RuntimeError("No forward detected") - # make grad to 1.0 - if not loss.shape == [1]: - raise RuntimeError("backward can only perform on the scaler tensor") - loss.parent.grad = None - bnode = None - for node in trace: - for idx, output in enumerate(node.outputs()): - if loss.overlap(output): - bnode = node.mirror - output.grad = None - bnode.set_grad(idx, None) - for node in trace[::-1]: - SchedulePool().add_node(node.mirror) From ed3164528747bc3ff00d0bd40a31dc7c8130a254 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 13:12:45 +0800 Subject: [PATCH 0510/1892] communication group init for collectives --- cube/codegen/codegen.py | 22 ++++++++----- cube/compiler.py | 53 ++++++-------------------------- cube/execplan/planpass/fusion.py | 2 +- 3 files changed, 26 insertions(+), 51 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 06dfee3c..efae0841 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -7,14 +7,15 @@ from cube.ir.cten import IRCell, IRTensor from cube.ir.dtype import IRDType + from cube.graph.tensor import IRSubTensor from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation from cube.graph.adapter.adapter import CollectivePrim, IRAdapter, SelectPrim, MovePrim, MergePrim from cube.graph.adapter.adapter import IRWeightReducer +from cube.graph.graph import IRGraph + from cube.execplan import ExectuionPlan -# from cube.schedule.adapter.collectives import IRCollectives -from cube.graph.graph import IRGraph from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -78,17 +79,27 @@ def init_comm_groups(self): Creating communication group requires all the devices enter the same call. """ + graph = self.execplan.graph sign = 'self.init_group(ranks={ranks})' # collect groups from weight reducer comm_groups: Dict[Tuple[int]] = list() - for node in self.execplan.graph.nodes(): + for node in graph.nodes(): if isinstance(node, IRWeightReducer): ranks = list(node.device) ranks.sort() ranks = tuple(ranks) if ranks not in comm_groups: comm_groups.append(ranks) - # TODO: collect groups from p2p fusion + # collect groups from p2p fusion + adapters = [n for n in graph.nodes() if isinstance(n, IRAdapter)] + for adapter in adapters: + for prim in adapter.prims(select=False, move=False, merge=False): + if not isinstance(prim, CollectivePrim): + ranks = prim.group + ranks.sort() + ranks = tuple(ranks) + if ranks not in comm_groups: + comm_groups.append(ranks) # create communication group for ranks in comm_groups: code = sign.format(ranks=list(ranks)) @@ -117,8 +128,6 @@ def gen(self, device: int, outfile=None, attach=False) -> str: elif isinstance(node, IRAdapter): node = node.dispatch(rank=device) self.emit_adapter_call(node) - # elif isinstance(node, IRCollectives): - # self.emit_collective_call(node) elif isinstance(node, IRWeightReducer): self.emit_reducer_init(node) self.emit_reducer_call(node) @@ -194,7 +203,6 @@ def emit_node_declare(self, node: IRCell): self.symbols.create(self.tensor_naming(output)) return - def emit_graph_call(self, graph: IRGraph): for node in graph.nodes(): if isinstance(node, IRBpOperation): diff --git a/cube/compiler.py b/cube/compiler.py index 9a0e0d0f..e33397ec 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Tuple, Union import torch import time @@ -6,23 +6,18 @@ from cube.graph import parser from cube.graph.adapter.gen import AdapterGener -from cube.graph.graph import IRGraph from cube.graph.operator.operator import IRDataOperation from cube.logics.pool import SchedulePool from cube.logics.translator import LogicTranslator from cube.execplan import ExectuionPlan -# from cube.execplan.planpass.torchadapt import TorchRefAdapter -# from cube.execplan.planpass.redundant import RemoveRedundantAdapters -# from cube.execplan.planpass.merge import MergeComputeSU -# from cube.execplan.planpass.gfuse import WeightGradAllreduceFusion -# from cube.execplan.planpass.p2pfusion import P2PFusion from cube.execplan.planpass.grouping import Grouping from cube.execplan.planpass.fusion import P2PFusion from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen +from cube.profiler.timer import print_each_rank class SemanticModel: @@ -41,7 +36,6 @@ def get_graph(self): def load_module(self, filename: str): import importlib.util - print(f'> loading generated spatial moduel from {filename}') spec = importlib.util.spec_from_file_location("GenModel", filename) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -108,7 +102,6 @@ def train_step(model, dataloader): def _load_tschedule_fn(filename) -> Callable: import importlib.util - print(f'> [{myrank}] loading generated schedule from {filename} ...') spec = importlib.util.spec_from_file_location( "_train_step", filename ) @@ -142,8 +135,6 @@ def decorator(fn: Callable) -> Callable: for node in graph.nodes(): if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") - # if not SUGraph.is_topo_order(sugraph.sus()): - # raise RuntimeError(f"SUGraph order is not topological order") # generate adapter graph = AdapterGener.gen(graph) @@ -162,33 +153,6 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on p2pfusion operations: {:.2f} s'.format(span)) - - # plan pass to adapt to pytorch semantic: multi branch gradient - # TODO: residual support - # execplan = TorchRefAdapter.apply(execplan) - # plan pass to remove redundant sus - # start = time.time() - # execplan = RemoveRedundantAdapters.apply(execplan) - # span = time.time() - start - # print('> planpass on remove redundant adapter: {:.2f} s'.format(span)) - # # print(f'> after remove redundant adapters:\n {execplan}') - # start = time.time() - # execplan = MergeComputeSU.apply(execplan) - # span = time.time() - start - # print('> planpass on merge compute: {:.2f} s'.format(span)) - # # print(f'> after merge backward SU:\n {execplan}') - # start = time.time() - # execplan = WeightGradAllreduceFusion.apply(execplan) - # span = time.time() - start - # print('> planpass on grad allreduce: {:.2f} s'.format(span)) - # print(f'> after add allreduce:\n{execplan}') - - # start = time.time() - # execplan = P2PFusion.apply(execplan) - # span = time.time() - start - # print('> planpass on p2p fusion: {:.2f} s'.format(span)) - # print(f'> after fuse P2P SU:\n {execplan}') - if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() else: @@ -224,7 +188,7 @@ def decorator(fn: Callable) -> Callable: compile_end = time.time() compile_time = compile_end - compile_start - print(f'> compile time: {compile_time} seconds') + print('> compile time: {:.2f} seconds'.format(compile_time)) if torch.distributed.is_initialized(): torch.distributed.barrier() @@ -236,8 +200,11 @@ def decorator(fn: Callable) -> Callable: dataloader.reset(batch_size=batch_size) # load module - model.load_module(filename.format(myrank)) - # load temporal - return _load_tschedule_fn(filename.format(myrank)) - + filename = filename.format(myrank) + print_each_rank(f'loading generated module from {filename} ...') + model.load_module(filename) + # load temporal schedule + print_each_rank(f'loading generated schedule from {filename} ...') + return _load_tschedule_fn(filename) + return decorator diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 17e67426..cf4cb4cc 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -290,7 +290,7 @@ def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): if not P2PFusion._check_different_outputs_devices(adapters, among=True): continue # gen broadcast - print(f'generating broadcast for tensor: {outputs[tid]}') + print(f'generating broadcast for tensor: {outputs[tid]} ...') # put root rank to the first root = list(device)[0] group = set() From 1d85f801ce0d11f56dd9718adc234260e23d978e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 14:20:54 +0800 Subject: [PATCH 0511/1892] fix collective group bug --- cube/codegen/codegen.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index efae0841..ef09b172 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -93,17 +93,20 @@ def init_comm_groups(self): # collect groups from p2p fusion adapters = [n for n in graph.nodes() if isinstance(n, IRAdapter)] for adapter in adapters: - for prim in adapter.prims(select=False, move=False, merge=False): + for prim in adapter.prims(): if not isinstance(prim, CollectivePrim): - ranks = prim.group - ranks.sort() - ranks = tuple(ranks) - if ranks not in comm_groups: - comm_groups.append(ranks) + continue + ranks = prim.group + ranks.sort() + ranks = tuple(ranks) + if ranks not in comm_groups: + comm_groups.append(ranks) # create communication group + self.declare_region.append('# communication groups') for ranks in comm_groups: code = sign.format(ranks=list(ranks)) self.declare_region.append(code) + self.declare_region.append(' ') def gen(self, device: int, outfile=None, attach=False) -> str: """ From 77eea1dc13e4af6d89c556bdb66137aca445c412 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 14:21:25 +0800 Subject: [PATCH 0512/1892] megatron policy --- examples/mlp/policy/megatron_parallel.py | 74 ++++++++---------------- 1 file changed, 23 insertions(+), 51 deletions(-) diff --git a/examples/mlp/policy/megatron_parallel.py b/examples/mlp/policy/megatron_parallel.py index e5e40df1..15df2fd3 100644 --- a/examples/mlp/policy/megatron_parallel.py +++ b/examples/mlp/policy/megatron_parallel.py @@ -1,60 +1,32 @@ from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph from cube.graph.operator.operator import IRFwOperation, IRDataOperation -def transform_policy(graph: IRGraph, resource): +def PAS(graph: IRGraph, resource): """ - The transformation policy transposes linear using data parallel + Linear Hybrid + Nested Partition """ tp = 2 - dp = int(resource.ngpus // tp) - linear_idx = 0 - for node in graph.nodes(): - # partition data loader at data dimension - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=dp)) - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - # partition operators first in column and then in data - if isinstance(node, IRFwOperation): - all_sub_nodes = list() - if node.algorithms('column') is not None: - if linear_idx % 2 == 0: - print(' ==> column partition') - algo = node.algorithms('column') - else: - print(' ==> row partition') - algo = node.algorithms('row') - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=tp)) - for sub_node in sub_nodes: - print(' ==> data partition') - algo = sub_node.algorithms('data') - ssub_nodes = graph.partition(sub_node, algo, config=dict(chunk_num=dp)) - all_sub_nodes += ssub_nodes - linear_idx += 1 + dp = resource.ngpus // tp + for idx, node in enumerate(graph.nodes()): + if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): + if idx % 2 == 0: + algo = node.algorithms('row') else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - all_sub_nodes += sub_nodes - # add tags (vdev) for node - for idx, ssub_node in enumerate(all_sub_nodes): - ssub_node.tag = idx - print(graph) + algo = node.algorithms('column') + if algo: + sub_nodes = list() + tp_nodes = graph.partition( + node, algo, config=dict(chunk_num=tp) + ) + for tp_node in tp_nodes: + algo = tp_node.algorithms('data') + dp_nodes = graph.partition( + tp_node, algo, config=dict(chunk_num=dp)) + sub_nodes += dp_nodes + else: + sub_nodes = [node] + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + print(graph.extra_repr()) return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - return sugraph \ No newline at end of file From 1c0ee2e694fb2d6e461abac28c3e746650578414 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 15:17:51 +0800 Subject: [PATCH 0513/1892] fix broadcast bug --- cube/codegen/codegen.py | 2 +- cube/execplan/planpass/fusion.py | 8 +++++--- cube/graph/adapter/__init__.py | 1 - cube/graph/adapter/adapter.py | 29 ++++++++++++++--------------- cube/graph/graph.py | 4 ++++ 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index ef09b172..98a4c145 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -96,7 +96,7 @@ def init_comm_groups(self): for prim in adapter.prims(): if not isinstance(prim, CollectivePrim): continue - ranks = prim.group + ranks = list(prim.group) ranks.sort() ranks = tuple(ranks) if ranks not in comm_groups: diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index cf4cb4cc..37fbef13 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -283,8 +283,10 @@ def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): break if not cond: continue # cond 2) - device = set([adapter.idevice(0)[0] for adapter in adapters]) - if len(device) != 1: + root_device = set() + for adapter in adapters: + root_device.update(P2PFusion._get_input_devices(adapter)) + if len(root_device) != 1: continue # cond 3) if not P2PFusion._check_different_outputs_devices(adapters, among=True): @@ -292,7 +294,7 @@ def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): # gen broadcast print(f'generating broadcast for tensor: {outputs[tid]} ...') # put root rank to the first - root = list(device)[0] + root = list(root_device)[0] group = set() for adapter in adapters: group.update(P2PFusion._get_output_devices(adapter)) diff --git a/cube/graph/adapter/__init__.py b/cube/graph/adapter/__init__.py index 2f1a83e2..e69de29b 100644 --- a/cube/graph/adapter/__init__.py +++ b/cube/graph/adapter/__init__.py @@ -1 +0,0 @@ -from cube.graph.adapter.gen import AdapterGener diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index 07cf9a4f..353e7a12 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -49,14 +49,14 @@ class Type(Enum): Broadcast = 'broadcast' def __init__(self, ctype: Enum, - device: List[int], - group: List[int], - inputs: List[IRSubTensor] = None, - input_shapes: List[List[int]] = None, - input_dtypes: List[IRDType] = None, - outputs: List[IRSubTensor] = None, - output_shapes: List[List[int]] = None, - output_dtypes: List[IRDType] = None): + device: Tuple[int], + group: Tuple[int], + inputs: Tuple[IRSubTensor] = None, + input_shapes: Tuple[Tuple[int]] = None, + input_dtypes: Tuple[IRDType] = None, + outputs: Tuple[IRSubTensor] = None, + output_shapes: Tuple[Tuple[int]] = None, + output_dtypes: Tuple[IRDType] = None): """ inputs: the collective input tensors. Including remote tensors. @@ -73,18 +73,17 @@ def __init__(self, ctype: Enum, """ self.ctype = ctype # inputs - self.inputs: List[IRSubTensor] = inputs if inputs is not None else list() - self.input_shapes: List[IRSubTensor] = input_shapes - self.input_dtypes: List[IRDType] = input_dtypes + self.inputs: Tuple[IRSubTensor] = tuple(inputs) if inputs is not None else list() + self.input_shapes: Tuple[IRSubTensor] = input_shapes + self.input_dtypes: Tuple[IRDType] = input_dtypes # outputs - self.outputs: List[IRSubTensor] = outputs if outputs is not None else list() + self.outputs: Tuple[IRSubTensor] = outputs if outputs is not None else list() self.output_shapes: List[IRSubTensor] = output_shapes self.output_dtypes: List[IRDType] = output_dtypes # communication group - group.sort() - self.group: List[int] = group + self.group: Tuple[int] = tuple(group) # device - self.device = device + self.device = tuple(device) def __repr__(self): dscp = f'{self.outputs} = {self.ctype.value}(inputs={self.inputs}, group={self.group})' diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 664407eb..ddbbacbe 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -12,6 +12,7 @@ from cube.ir.cten import IRTensor, IRCell from cube.graph.operator.operator import IRBpOperation, IRFwOperation +from cube.graph.adapter.adapter import IRAdapter from cube.graph.tensor import IRSubTensor from cube.algorithm.generics import GenericDistAlgo @@ -71,6 +72,9 @@ def reset_dependency(self): for src_idx in range(len(self._nodes)): src_node = self._nodes[src_idx] for dst_node in self._nodes[src_idx+1:]: + # we don't consider dependencies among adapter + if isinstance(src_node, IRAdapter) and isinstance(dst_node, IRAdapter): + continue for out_idx, out_tensor in enumerate(src_node.outputs()): if not isinstance(out_tensor, IRTensor): continue From f6cdecda927e8de06eacc9e838d627e8c0f7a561 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Dec 2021 15:21:03 +0800 Subject: [PATCH 0514/1892] send will not go into broadcast --- cube/execplan/planpass/fusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 37fbef13..b9b9c771 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -272,6 +272,9 @@ def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): for tid in outputs: adapters: List[IRAdapter] = groups[tid] cond = True + # note send can also be broadcast. We skip this case + if len(adapters) <= 2: + continue # cond 1) if not P2PFusion._check_same_inputs(adapters): continue From 28a045938c14dfea814bd91010c997befc70c8e7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 1 Jan 2022 12:39:39 +0800 Subject: [PATCH 0515/1892] fix non-contiguous bug --- cube/runtime/adapter/collectives.py | 4 ++- examples/mlp/linears.py | 2 +- examples/mlp/policy/col_parallel.py | 40 --------------------- examples/mlp/policy/pipe_parallel.py | 54 ++++------------------------ examples/mlp/policy/row_parallel.py | 34 +++++------------- 5 files changed, 19 insertions(+), 115 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index afea9feb..c22a9cca 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -171,7 +171,9 @@ def broadcast(input_tensors: List[torch.Tensor], CudaTimer().start(field_name='comm') assert len(input_tensors) == 1 or len(input_tensors) == 0 if len(input_tensors) == 1: - tensor = input_tensors[0] + tensor: torch.Tensor = input_tensors[0] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() else: assert len(output_shapes) == 1 assert len(output_dtypes) == 1 diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 49efa870..e1dd2978 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,7 +17,7 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.col_parallel import PAS +from examples.mlp.policy.row_parallel import PAS # =================== Semantic Model Description ==================== diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index 20b92722..af3b82b7 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -2,46 +2,6 @@ from cube.graph.operator.operator import IRDataOperation, IRFwOperation -# def transform_policy(graph: IRGraph, resource): -# """ -# The transformation policy transposes linear using column parallel -# """ -# for node in graph.nodes(): -# if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): -# algo = node.algorithms('column') -# if algo: -# sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) -# else: -# sub_nodes = graph.replicate(node, times=resource.ngpus) -# for idx, sub_node in enumerate(sub_nodes): -# sub_node.tag = idx -# print(graph) -# return graph -# -# -# def schedule_policy(sugraph: SUGraph, resource): -# """ -# The schedule policy assign devices -# """ -# # print(sugraph) -# for su in sugraph.sus(): -# if su.stype == SUType.Dataloader: -# devid = su.tag[0] -# sugraph.assign(su, devid) -# # sugraph.assign(su, list(range(resource.ngpus))) -# for su in sugraph.fsus(): -# devid = su.tag[0] -# sugraph.assign(su, devid) -# if su.mirror is None: -# print(f'error su: {su}') -# assert False -# sugraph.assign(su.mirror, devid) -# fsus = sugraph.fsus() -# print('> [scheduling] setting schedule order...') -# sugraph.partial_set_order(fsus, lazy=False) -# return sugraph - - def PAS(graph: IRGraph, resource): """ Linear Column Partition diff --git a/examples/mlp/policy/pipe_parallel.py b/examples/mlp/policy/pipe_parallel.py index 912e6974..c25a8777 100644 --- a/examples/mlp/policy/pipe_parallel.py +++ b/examples/mlp/policy/pipe_parallel.py @@ -1,15 +1,12 @@ -from typing import List import math import random -from cube.schedule.su import SUType, ScheduleUnit -from cube.schedule.sugraph import SUGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation -def transform_policy(graph, resource): +def PAS(graph, resource): """ - The transformation policy transposes linear using data parallel + Random pipeline """ micro_batch_num = resource.ngpus for node in graph.nodes(): @@ -18,48 +15,9 @@ def transform_policy(graph, resource): if algo is not None: sub_nodes = graph.partition(node, algo, config=dict(chunk_num=micro_batch_num)) else: - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=micro_batch_num)) + sub_nodes = [node] for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx + device = random.randint(0, resource.ngpus - 1) + graph.assign(sub_node, device) + print(graph.extra_repr()) return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy - """ - micro_batch_num = resource.ngpus - - fseqs: List[List[ScheduleUnit]] = [list() for _ in range(micro_batch_num)] - fbseqs: List[List[ScheduleUnit]] = [list() for _ in range(micro_batch_num)] - - for fsu in sugraph.fsus(): - micro_bs_id = fsu.tag[0] - fseqs[micro_bs_id].append(fsu) - - for micro_bs_id, fseq in enumerate(fbseqs): - bseq = [fsu.mirror for fsu in fseq][::-1] - fbseqs[micro_bs_id] = fseq + bseq - - # device assignment - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - - print(f'> collect {len(fseqs)} forward-backward sequence') - for fseq in fseqs: - chunk_num = int(math.ceil(len(fseq) / resource.ngpus)) - for idx, su in enumerate(fseq): - # devid = int(idx // chunk_num) - # devid = idx % resource.ngpus - devid = random.randint(0, resource.ngpus - 1) - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - - seqs = list() - for fb_seq in fbseqs: - seqs += fb_seq - sugraph.partial_set_order(seqs) - print(sugraph) - return sugraph diff --git a/examples/mlp/policy/row_parallel.py b/examples/mlp/policy/row_parallel.py index 0a07a473..8f5d4e8e 100644 --- a/examples/mlp/policy/row_parallel.py +++ b/examples/mlp/policy/row_parallel.py @@ -1,37 +1,21 @@ from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph from cube.graph.operator.operator import IRFwOperation, IRDataOperation -def transform_policy(graph: IRGraph, resource): +def PAS(graph: IRGraph, resource): """ - The transformation policy transposes linear using column parallel + Linear Column Partition """ for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): algo = node.algorithms('row') if algo: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=resource.ngpus)) + sub_nodes = graph.partition( + node, algo, config=dict(chunk_num=resource.ngpus) + ) else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx + sub_nodes = [node] + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + print(graph.extra_repr()) return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - # print(sugraph) - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - # sugraph.assign(su, list(range(resource.ngpus))) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - return sugraph From a72d6877fcd278d6eac5ffe80342929484444614 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 1 Jan 2022 12:42:27 +0800 Subject: [PATCH 0516/1892] reduce scatter contigous tensor --- cube/runtime/adapter/collectives.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index c22a9cca..7602d638 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -151,6 +151,9 @@ def reduce_scatter(input_tensors: List[torch.Tensor], """ CudaTimer().start(field_name='comm') input_tensors = list(input_tensors) + for idx, tensor in enumerate(input_tensors): + if not tensor.is_contiguous(): + input_tensors[idx] = tensor.contiguous() group = DeviceGroup().get_group(ranks) idx = ranks.index(DeviceGroup().rank) output = torch.empty_like(input_tensors[idx], requires_grad=True) From a03e6fa072ce3638c6a68bf42d0e21dd2a98127c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 3 Jan 2022 14:32:01 +0800 Subject: [PATCH 0517/1892] enable replicate operators --- cube/graph/adapter/adapter.py | 17 +++++-- cube/graph/graph.py | 80 ++++++++++++++--------------- cube/graph/operator/operator.py | 89 ++++++++++++++++++++++++++++++--- cube/graph/tensor.py | 10 ++-- cube/ir/cten.py | 8 +-- 5 files changed, 142 insertions(+), 62 deletions(-) diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index 353e7a12..d158235b 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -331,6 +331,9 @@ def gen(dst_tensor: IRSubTensor): @staticmethod def gen_select(dst_tensor): + # TODO: consider previous adapter output as later adapter in-tensor + # for residual cases + inputs = list() intersections = list() prims = list() @@ -338,6 +341,7 @@ def gen_select(dst_tensor): otensor = dst_tensor odevice = otensor.device + # local and remote adapter in-tensor local, remote = list(), list() for ptensor in otensor.parent.ptensors: if ptensor.device == odevice: @@ -345,19 +349,19 @@ def gen_select(dst_tensor): else: remote.append(ptensor) - # check local tensor + # first check local in tensor if otensor in local: intersections.append(otensor) inputs.append(otensor) return inputs, intersections, prims - - # FIXME: multi producer may result overlapped region - for itensor in otensor.parent.ptensors: + + # check local + remote + for itensor in local + remote: if not itensor.overlap(otensor): continue # intersection - common = otensor.common(itensor) + common: IRSubTensor = otensor.common(itensor) common.attach_cell(itensor._cell) intersections.append(common) inputs.append(itensor) @@ -391,6 +395,9 @@ def gen_select(dst_tensor): ) prim = SelectPrim(itensor, indmap, valmap, common.shape, common) prims.append(prim) + # TODO: check union == otensor + if common == otensor: + break return inputs, intersections, prims diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ddbbacbe..ea90e372 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -11,7 +11,7 @@ import copy from cube.ir.cten import IRTensor, IRCell -from cube.graph.operator.operator import IRBpOperation, IRFwOperation +from cube.graph.operator.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.graph.adapter.adapter import IRAdapter from cube.graph.tensor import IRSubTensor @@ -114,30 +114,6 @@ def nodes(self, index: Optional[int] = None): else: raise TypeError("Expected index to be None or int") - def _replace_tensor(self, old_tensor: IRTensor, new_tensor: IRTensor): - """ - Replace tensor from old_tensor to new_tensor for all the graph. - """ - def _replace_inputs(cell, old_tensor, new_tensor): - index = cell.inputs().index(old_tensor) - cell.set_input(index, new_tensor) - - def _replace_outputs(cell, old_tensor, new_tensor): - index = cell.outputs().index(old_tensor) - cell.set_output(index, new_tensor) - - if old_tensor in self.inputs(): - _replace_inputs(self, old_tensor, new_tensor) - - for node in self.nodes(): - if old_tensor in node.inputs(): - _replace_inputs(node, old_tensor, new_tensor) - if old_tensor in node.outputs(): - _replace_outputs(node, old_tensor, new_tensor) - - if old_tensor in self.outputs(): - _replace_outputs(self, old_tensor, new_tensor) - def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: """ forward will divide the graph into Actions according to @@ -257,15 +233,14 @@ def get_outputs(nodes: List[IRCell]): ## Parallel Policy Primitives ## - def replicate(self, op: IRCell, times=1): + def replicate(self, op: IRCell, times=1) -> Optional[List[IRCell]]: """ - Replicate an operation with multiple times. + Replicate a forward or data operation multiple times. - This is temporary use to enable assign with multiple devices + The backward of the forward operation will automatically be replicated. """ - raise NotImplementedError("Replicate is not supported yet") - if not isinstance(op, IRCell): - raise TypeError("Expected an IRCell") + if not (isinstance(op, IRFwOperation) or isinstance(op, IRDataOperation)): + raise TypeError("Expected op to be forward op or data op") if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") @@ -274,15 +249,16 @@ def replicate(self, op: IRCell, times=1): ops = [op] for _ in range(times - 1): - dup_op = op.replicate() - if op.mirror is not None: - dup_op.gen_backward() - ops.append(dup_op) + ops.append(op.replicate()) + if isinstance(op.mirror, IRBpOperation): + for rep_op in ops[1:]: + print(rep_op) + rep_op.gen_backward() idx = self.nodes().index(op) # forward self._nodes = self._nodes[:idx] + ops + self._nodes[idx+1:] # backward - if op.mirror is not None: + if isinstance(op.mirror, IRCell): bops = [op.mirror for op in ops][::-1] midx = self.nodes().index(op.mirror) self._nodes = self._nodes[:midx] + bops + self._nodes[midx+1:] @@ -362,15 +338,35 @@ def identity(self, input_tensor, dst_op): ## Assign Policy Primitives ## - def assign(self, op: IRCell, rank: int): + def assign(self, op: IRCell, ranks: Union[int, List[int]]): + """ + Assign an operator (subgraph) to (multiple) rank(s). + + If `ranks` has multiple integer, then the operator will be replicated + `len(ranks)` times and assigned to given device correspondingly. + + Corresponding backward operators (if have) will also be replicated + and assigned to the same device with it's forward operator + + Returns: + True if assigned successfully. + False if not. + """ if op not in self._nodes: raise KeyError(f"{op} is not in the graph") - if not isinstance(rank, int): + if isinstance(ranks, int): + ranks = [ranks] + if not all([isinstance(rank, int) for rank in ranks]): raise TypeError("Expected rank to be int") - op.device = rank - # pytorch requirement - if op.mirror is not None: - op.mirror.device = rank + if len(ranks) > 1: + ops = self.replicate(op, times=len(ranks)) + else: + ops = [op] + for op, rank in zip(ops, ranks): + op.device = rank + # pytorch requirement: forward + backward happened on same device + if op.mirror is not None: + op.mirror.device = rank return True ## Schedule Policy Primitives ## diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index b004319d..c7146eb6 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -4,11 +4,70 @@ from cube.ir.cten import IRCell from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory +from cube.ir.unique import IDGenerator __all__ = ['IRFwOperation', 'IRBpOperation', 'IRDataOperation'] +class BaseOperator: + + def __init__(self, name: str, signature: str, + input_length: int, output_length: int, + init_outputs=False): + super().__init__(name, signature, + input_length, output_length, + init_outputs=init_outputs) + + def infer_shape(self): + """ + Infer output value shape + """ + raise NotImplementedError + + def set_input(self, input_index: int, val: Any): + # remove the consumer + old_val = self.inputs(input_index) + if isinstance(old_val, IRSubTensor): + old_val.parent.rm_consumer(self) + # add the consumer + val = super().set_input(input_index, val) + if isinstance(val, IRSubTensor): + val.parent.add_consumer(self, val) + return val + + def set_output(self, output_index: int, val: Any): + # remove the producer + old_val = self.outputs(output_index) + if isinstance(old_val, IRSubTensor): + old_val.parent.rm_producer(self) + # add the producer + val = super().set_output(output_index, val) + if isinstance(val, IRSubTensor): + val.parent.add_producer(self, val) + return val + + def replicate(self): + """ + Replicate the Operation + """ + cpy = copy.copy(self) + cpy._device = list() + cpy._id = IDGenerator().gen_cell_id() + # reset input and output + cpy._inputs = [None] * len(self.inputs()) + for idx, input in enumerate(self.inputs()): + cpy.set_input(idx, input) + cpy._outputs = [None] * len(self.outputs()) + for idx, output in enumerate(self.outputs()): + cpy.set_output(idx, output) + cpy._mirror = None + cpy._tag = None + cpy.clear_predecessor() + cpy.clear_successor() + return cpy + + class IRFwOperation(IRCell): def __init__(self, @@ -88,8 +147,14 @@ def replicate(self): """ cpy = copy.copy(self) cpy._device = list() - cpy._inputs = copy.copy(self._inputs) - cpy._outputs = copy.copy(self._outputs) + # cpy._id = IDGenerator().gen_cell_id() + # reset input and output + cpy._inputs = [None] * len(self.inputs()) + for idx, input in enumerate(self.inputs()): + cpy.set_input(idx, input) + cpy._outputs = [None] * len(self.outputs()) + for idx, output in enumerate(self.outputs()): + cpy.set_output(idx, output) cpy._mirror = None cpy._tag = None cpy.clear_predecessor() @@ -158,8 +223,14 @@ def replicate(self): """ cpy = copy.copy(self) cpy._device = list() - cpy._inputs = copy.copy(self._inputs) - cpy._outputs = copy.copy(self._outputs) + cpy._id = IDGenerator().gen_cell_id() + # reset input and output + cpy._inputs = [None] * len(self.inputs()) + for idx, input in enumerate(self.inputs()): + cpy.set_input(idx, input) + cpy._outputs = [None] * len(self.outputs()) + for idx, output in enumerate(self.outputs()): + cpy.set_output(idx, output) cpy._mirror = None cpy._tag = None cpy.clear_predecessor() @@ -283,8 +354,14 @@ def replicate(self): """ cpy = copy.copy(self) cpy._device = list() - cpy._inputs = copy.copy(self._inputs) - cpy._outputs = copy.copy(self._outputs) + cpy._id = IDGenerator().gen_cell_id() + # reset input and output + cpy._inputs = [None] * len(self.inputs()) + for idx, input in enumerate(self.inputs()): + cpy.set_input(idx, input) + cpy._outputs = [None] * len(self.outputs()) + for idx, output in enumerate(self.outputs()): + cpy.set_output(idx, output) cpy._mirror = None cpy._tag = None cpy.clear_predecessor() diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 271c4ae8..6b01f56f 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -589,16 +589,16 @@ def get_grad(self, fcell: IRCell): self.grad = None return None if self in fcell.inputs(): - ref_cells = list() + ref_cell_ids = list() for dst_cell in self.parent.consumers: for input in dst_cell.inputs(): - if self.overlap(input): - ref_cells.append(dst_cell) + if self.overlap(input) and dst_cell._id not in ref_cell_ids: + ref_cell_ids.append(dst_cell._id) break - ref_times = len(ref_cells) + ref_times = len(ref_cell_ids) if ref_times == 0: raise RuntimeError("Internal Error: ref time is 0") - idx = ref_cells.index(fcell) + idx = ref_cell_ids.index(fcell._id) grad = full_grad.select( indmap = self.indmap, valmap = (idx, ref_times), diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 3c31823c..0622b61f 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -71,10 +71,10 @@ def __init__(self, self._mirror = None self._tag = None - def __eq__(self, other): - if isinstance(other, IRCell): - return self._id == other._id - return False + # def __eq__(self, other): + # if isinstance(other, IRCell): + # return self._id == other._id + # return False @property def device(self): From f3f2becc500e46e2b8ae0bbf3409163fc0783442 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 3 Jan 2022 14:32:25 +0800 Subject: [PATCH 0518/1892] use replicate for small ops --- examples/mlp/linears.py | 2 +- examples/mlp/policy/col_parallel.py | 40 +++++++++++++++++++++++- examples/mlp/policy/data_parallel.py | 2 +- examples/mlp/policy/hybrid_parallel.py | 2 +- examples/mlp/policy/megatron_parallel.py | 2 +- examples/mlp/policy/row_parallel.py | 2 +- 6 files changed, 44 insertions(+), 6 deletions(-) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index e1dd2978..10cf8882 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,7 +17,7 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.row_parallel import PAS +from examples.mlp.policy.data_parallel import PAS # =================== Semantic Model Description ==================== diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index af3b82b7..6cb16844 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -2,6 +2,43 @@ from cube.graph.operator.operator import IRDataOperation, IRFwOperation +def P(graph, resource): + """ + P policy + """ + for node in graph.nodes(): + if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): + algo = node.algorithms('column') + if algo: + sub_nodes = graph.partition( + node, algo, config=dict(chunk_num=resource.ngpus) + ) + else: + # graph.assign(node, list(range(resource.ngpus))) + sub_nodes = graph.replicate(node, times=resource.ngpus) + # device hint + for idx, node in enumerate(sub_nodes): + node.tag = idx + return graph + + +def A(graph, resource): + """ + A policy + """ + for node in graph.nodes(): + if node.tag is not None: + device = node.tag + graph.assign(node, device) + + +def S(graph, resource): + """ + Schedule Policy. => use default schedule + """ + return graph + + def PAS(graph: IRGraph, resource): """ Linear Column Partition @@ -14,7 +51,8 @@ def PAS(graph: IRGraph, resource): node, algo, config=dict(chunk_num=resource.ngpus) ) else: - sub_nodes = [node] + # graph.assign(node, list(range(resource.ngpus))) + sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) print(graph.extra_repr()) diff --git a/examples/mlp/policy/data_parallel.py b/examples/mlp/policy/data_parallel.py index 6e1e540d..48c8bfd3 100644 --- a/examples/mlp/policy/data_parallel.py +++ b/examples/mlp/policy/data_parallel.py @@ -14,7 +14,7 @@ def PAS(graph: IRGraph, resource): node, algo, config=dict(chunk_num=resource.ngpus) ) else: - sub_nodes = [node] + sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) print(graph.extra_repr()) diff --git a/examples/mlp/policy/hybrid_parallel.py b/examples/mlp/policy/hybrid_parallel.py index bd0c3784..289d5121 100644 --- a/examples/mlp/policy/hybrid_parallel.py +++ b/examples/mlp/policy/hybrid_parallel.py @@ -17,7 +17,7 @@ def PAS(graph: IRGraph, resource): node, algo, config=dict(chunk_num=resource.ngpus) ) else: - sub_nodes = [node] + sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) print(graph.extra_repr()) diff --git a/examples/mlp/policy/megatron_parallel.py b/examples/mlp/policy/megatron_parallel.py index 15df2fd3..502fbeeb 100644 --- a/examples/mlp/policy/megatron_parallel.py +++ b/examples/mlp/policy/megatron_parallel.py @@ -25,7 +25,7 @@ def PAS(graph: IRGraph, resource): tp_node, algo, config=dict(chunk_num=dp)) sub_nodes += dp_nodes else: - sub_nodes = [node] + sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) print(graph.extra_repr()) diff --git a/examples/mlp/policy/row_parallel.py b/examples/mlp/policy/row_parallel.py index 8f5d4e8e..d8286178 100644 --- a/examples/mlp/policy/row_parallel.py +++ b/examples/mlp/policy/row_parallel.py @@ -14,7 +14,7 @@ def PAS(graph: IRGraph, resource): node, algo, config=dict(chunk_num=resource.ngpus) ) else: - sub_nodes = [node] + sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) print(graph.extra_repr()) From 2ac3c299ad0504bdbb88260e23e4dfa79f113079 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 3 Jan 2022 19:33:31 +0800 Subject: [PATCH 0519/1892] fix graph set order bug; add 1f1b pipeline --- cube/graph/adapter/gen.py | 14 +-- cube/graph/graph.py | 10 ++- examples/mlp/linears.py | 2 +- examples/mlp/policy/pipe1f1b_parallel.py | 106 +++++++---------------- 4 files changed, 49 insertions(+), 83 deletions(-) diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index 0d271c33..b0d6c35a 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -24,16 +24,16 @@ def gen(graph: IRGraph, eager=True) -> IRGraph: Returns: graph (IRGraph) """ + # update the gradient before generate adapter + for node in graph.nodes(): + if isinstance(node, IRBpOperation): + node.update() graph = AdapterGener.gen_activation_adapter(graph, eager) graph = AdapterGener.gen_weight_reducer(graph) return graph @staticmethod def gen_activation_adapter(graph: IRGraph, eager=True) -> IRGraph: - # update the gradient before generate adapter - for node in graph.nodes(): - if isinstance(node, IRBpOperation): - node.update() all_adapters = list() # generate adapter for non-weight values for node in graph.nodes(): @@ -90,7 +90,11 @@ def gen_weight_reducer(graph: IRGraph) -> IRGraph: if devid not in grads[input._id]: grads[input._id][devid] = list() if grad in grads[input._id][devid]: - raise RuntimeError("Already logged grad?") + raise RuntimeError( + "Find two same gradient (not expected). " + "This is usually due to replicated node assigned to same device. " + f"\nCheck node:\n\t{fnode}" + ) grads[input._id][devid].append(grad) # step 2: generate reducers. # reducers: tuple(ranks): List[weight] diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ea90e372..1d84d892 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -438,7 +438,10 @@ def partial_set_order(self, seq: List[IRCell], eager=True): for node in remain: if eager: pre_indices = [seq.index(pre) for pre in node.predecessors()] - index = max(pre_indices) + 1 + if len(pre_indices) == 0: + index = 0 + else: + index = max(pre_indices) + 1 else: suc_indices = [seq.index[suc] for suc in node.successors()] index = min(suc_indices) @@ -461,6 +464,7 @@ def check_legal_order(seq: List[IRCell], integrity_check=False): Returns: Boolean: True for satisfying topo order, otherwise False. """ + #TODO: check no new operators are created (including replicate) for index, node in enumerate(seq): for pre in node.predecessors(): if pre in seq: @@ -481,14 +485,12 @@ def extra_repr(self): dscp += f"Inputs: {self.inputs()}\n" # nodes for node in self._nodes: - # if isinstance(node, IRBpOperation): - # continue succ_node_ids = [node._id for node in node.successors()] # succ_node_ids = [None] * len(node.outputs()) # for out_idx in range(len(node.outputs())): # node_list = [snode._id for snode in node.successors(out_idx)] # succ_node_ids[out_idx] = node_list - dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}\n" + dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}" # outputs dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" return dscp diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 10cf8882..d2f218eb 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,7 +17,7 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.data_parallel import PAS +from examples.mlp.policy.pipe1f1b_parallel import PAS # =================== Semantic Model Description ==================== diff --git a/examples/mlp/policy/pipe1f1b_parallel.py b/examples/mlp/policy/pipe1f1b_parallel.py index 96b8d42a..a95e540e 100644 --- a/examples/mlp/policy/pipe1f1b_parallel.py +++ b/examples/mlp/policy/pipe1f1b_parallel.py @@ -1,91 +1,50 @@ -from typing import List -import math - -from cube.schedule.su import SUType, ScheduleUnit -from cube.schedule.sugraph import SUGraph +from cube.graph.graph import IRGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation -def transform_policy(graph, resource): +def PAS(graph: IRGraph, resource): """ - The transformation policy transposes linear using data parallel + 1F1B scheduling """ - micro_batch_num = resource.ngpus - for node in graph.nodes(): - if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): - algo = node.algorithms('data') - if algo is not None: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=micro_batch_num)) - else: - algo = node.algorithms('dim') - # dim trace - sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=micro_batch_num)) - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - return graph - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy - """ - # each device is a stage num_micro_batch = resource.ngpus num_stage = resource.ngpus - fseqs: List[List[ScheduleUnit]] = [list() for _ in range(num_micro_batch)] - fbseqs: List[List[ScheduleUnit]] = [list() for _ in range(num_micro_batch)] - - for fsu in sugraph.fsus(): - micro_bs_id = fsu.tag[0] - fseqs[micro_bs_id].append(fsu) - - for micro_bs_id, fseq in enumerate(fbseqs): - bseq = [fsu.mirror for fsu in fseq][::-1] - fbseqs[micro_bs_id] = fseq + bseq - - print(f'> collect {len(fseqs)} forward-backward sequence') - - # fstages[micro_batch_id][stage] = fstages[micro_batch_id * num_stage + stage] - fstages: List[List[ScheduleUnit]] = [ - list() for _ in range(num_micro_batch * num_stage) - ] + fstages = [list() for _ in range(num_micro_batch * num_stage)] - def f(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: + def f(micro_batch_id: int, stage_id: int): return fstages[micro_batch_id * num_stage + stage_id] - def b(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: + def b(micro_batch_id: int, stage_id: int): fstage = f(micro_batch_id, stage_id) - bstage = [fsu.mirror for fsu in fstage][::-1] + bstage = [fnode.mirror for fnode in fstage][::-1] return bstage - - # assign su to SU Group - for micro_bid, fseq in enumerate(fseqs): - chunk_num = int(len(fseq) // resource.ngpus) - for idx, fsu in enumerate(fseq): - stage = min(int(idx // chunk_num), num_stage - 1) - fstages[micro_bid * num_stage + stage].append(fsu) - - # stage device assignment - for micro_bid in range(num_micro_batch): - for stage in range(num_stage): - for su in f(micro_bid, stage): - sugraph.assign(su, stage) - sugraph.assign(su.mirror, stage) - # device assignment - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + stage_op_num = len(fnodes) // num_stage + for idx, node in enumerate(fnodes): + stage = min(idx // stage_op_num, num_stage - 1) + sub_nodes = None + algo = node.algorithms('data') + if algo is not None: + sub_nodes = graph.partition(node, algo, config=dict(chunk_num=num_micro_batch)) + else: + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=num_micro_batch)) + for mid, sub_node in enumerate(sub_nodes): + f(mid, stage).append(sub_node) + graph.assign(sub_node, stage) - # 1f1b scheduling + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + graph.assign(node, 0) + + # 1F1B scheduling seqs = list() - # warmup - for stage in range(num_stage): - for mid in range(stage): + for mid in range(num_micro_batch): + for stage in range(num_stage - mid): seqs += f(mid, stage) - # steady + cooldown: for mid in range(num_micro_batch): # enqueue backward @@ -93,11 +52,12 @@ def b(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: seqs += b(mid, stage) # enqueue forward for stage in range(num_stage): - f_mid = mid + 1 + num_stage - stage + f_mid = mid + num_stage - stage if f_mid >= num_micro_batch: continue seqs += f(f_mid, stage) + for node in seqs: + print(node) + graph.partial_set_order(seqs) - sugraph.partial_set_order(seqs) - # print(sugraph) - return sugraph + return graph From d3406ecf7cb9f87ed171c39a05e97ad2b552153b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 4 Jan 2022 11:06:38 +0800 Subject: [PATCH 0520/1892] separate PAS --- examples/mlp/linears.py | 5 ++++- examples/mlp/policy/col_parallel.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index d2f218eb..7a32963d 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,7 +17,10 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.pipe1f1b_parallel import PAS +# from examples.mlp.policy.col_parallel import PAS + +from examples.mlp.policy.col_parallel import P, A, S +PAS = (P, A, S) # =================== Semantic Model Description ==================== diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index 6cb16844..8ba0fa99 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -30,6 +30,7 @@ def A(graph, resource): if node.tag is not None: device = node.tag graph.assign(node, device) + return graph def S(graph, resource): From 5d3f6e4d565000d16b608fb928b83993b3a2e062 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 4 Jan 2022 12:44:14 +0800 Subject: [PATCH 0521/1892] update execplan draw --- cube/compiler.py | 2 + cube/execplan/execplan.py | 106 ++++++++++++++++++++++++-------------- 2 files changed, 68 insertions(+), 40 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index e33397ec..b4bfcad5 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -153,6 +153,8 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on p2pfusion operations: {:.2f} s'.format(span)) + # execplan.draw(outfile='execplan.png') + if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() else: diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 491b32c2..d32ffd4b 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -71,46 +71,71 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): outfile: the output file name """ + self.graph.reset_dependency() ndevice = len(self.devices()) # timeline [ [ (start_time, end_time), ... ], ... ] device_timeline = [list() for _ in range(ndevice)] - device_sus = [list() for _ in range(ndevice)] + device_nodes = [list() for _ in range(ndevice)] + + def map2time(node): + if isinstance(node, IRGraph): + span = 0 + for node in node.nodes(): + span += map2time(node) + if isinstance(node, IRFwOperation): + return 1 + if isinstance(node, IRBpOperation): + return 2 + if isinstance(node, IRAdapter): + return 0.5 + return 0 + + def map2color(node): + if isinstance(node, IRGraph): + return map2color(node.nodes(0)) + if isinstance(node, IRFwOperation): + return 'blue' + if isinstance(node, IRBpOperation): + return 'orange' + if isinstance(node, IRAdapter): + return 'green' + + def map2name(node): + if isinstance(node, IRGraph): + if all([isinstance(n, IRFwOperation) for n in node.nodes()]): + return f'f{node._id}' + if all([isinstance(n, IRBpOperation) for n in node.nodes()]): + if node.mirror is not None: + return f'b{node.mirror._id}' + return str(node._id) if spans is None: + print("Using default timing: fwop=1, bwop=2, adapter=0.1") spans = list() for node in self.graph.nodes(): - span = 0 - if isinstance(node, IRFwOperation): - span = 1 - elif isinstance(node, IRBpOperation): - span = 2 - elif isinstance(node, IRAdapter): - span = 0.1 - else: - span = 0 + span = map2time(node) spans.append(span) - for su, span_time in zip(self.seq.sequence, spans): - device = su.device[0] - - # tight execution if no dependency - if len(device_timeline[device]) == 0: - start_time = 1 - else: - start_time = device_timeline[device][-1][1] - - # check dependency - for devid, (timeline, dev_sus) in enumerate(zip(device_timeline, device_sus)): - if devid == device: - continue - for suid, (_, end_time) in enumerate(timeline[::-1]): - other_su = dev_sus[::-1][suid] - if other_su.happen_before(su): - start_time = max(start_time, end_time) - break - - device_timeline[device].append((start_time, start_time + span_time)) - device_sus[device].append(su) + graph = self.graph + for node, span_time in zip(self.graph.nodes(), spans): + for device in node.device: + # tight execution if no dependency + if len(device_timeline[device]) == 0: + start_time = 1 + else: + start_time = device_timeline[device][-1][1] + # check dependency + for devid, timeline in enumerate(device_timeline): + dev_seq = device_nodes[devid] + if devid == device: + continue + for nid, (_, end_time) in enumerate(timeline[::-1]): + other_node = dev_seq[::-1][nid] + if graph.happen_before(other_node, node): + start_time = max(start_time, end_time) + break + device_timeline[device].append((start_time, start_time + span_time)) + device_nodes[device].append(node) # draw the timeline if outfile is not None: @@ -124,13 +149,13 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): fig, ax = plt.subplots() ax.set_xlim((1, max_time)) - plt.xticks(list(range(1, max_time+1, 1))) + plt.xticks(list(range(1, int(max_time)+1, 1))) ax.xaxis.grid(True, linestyle='--') plt.xlabel('time') # yaxis - ax.set_ylim((0.5, self.ndevice+0.5)) - plt.yticks(list(range(1, self.ndevice+1, 1))) + ax.set_ylim((0.5, len(self.devices())+0.5)) + plt.yticks(list(range(1, len(self.devices())+1, 1))) ax.invert_yaxis() plt.ylabel('device id') @@ -138,19 +163,20 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): for devid in range(ndevice): timeline = device_timeline[devid] - sus = device_sus[devid] - for su, (start, end) in zip(sus, timeline): + nodes = device_nodes[devid] + for node, (start, end) in zip(nodes, timeline): + if end - start == 0: + continue # draw - color = 'blue' if (end - start) == 1 else 'orange' + color = map2color(node) rec = Rectangle((start, devid + 0.5), end-start, 1, color=color, ec='black', lw=1.5) ax.add_artist(rec) rx, ry = rec.get_xy() cx = rx + rec.get_width() / 2.0 cy = ry + rec.get_height() / 2.0 - anno = str(su.stype) - # anno = su.name if action.fid is None else action.fid - ax.annotate(anno, (cx, cy), color='w', weight='bold', + anno = map2name(node) + ax.annotate(anno, (cx, cy), color='w', # weight='bold', fontsize=10, ha='center', va='center') # plt.grid() plt.savefig(outfile) From c2500f94234cc28336d24d56ab78ef9798bdf122 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 5 Jan 2022 10:31:44 +0800 Subject: [PATCH 0522/1892] add einops --- cube/algorithm/factory.py | 30 +-- cube/algorithm/generics.py | 30 +-- cube/algorithm/ops/activation.py | 115 ----------- cube/algorithm/ops/bmm.py | 195 ------------------ cube/algorithm/ops/einops.py | 95 +++++++++ cube/algorithm/ops/linear.py | 164 --------------- cube/graph/graph.py | 5 +- cube/graph/operator/function/__init__.py | 2 + cube/graph/operator/function/einops.py | 159 ++++++++++++++ .../graph/operator/{ => function}/function.py | 163 ++++++++------- 10 files changed, 365 insertions(+), 593 deletions(-) delete mode 100644 cube/algorithm/ops/activation.py delete mode 100644 cube/algorithm/ops/bmm.py create mode 100644 cube/algorithm/ops/einops.py delete mode 100644 cube/algorithm/ops/linear.py create mode 100644 cube/graph/operator/function/__init__.py create mode 100644 cube/graph/operator/function/einops.py rename cube/graph/operator/{ => function}/function.py (80%) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index a907047e..7d4b1f20 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -62,31 +62,23 @@ def _load_predefined_algos(self): import cube.algorithm.ops.dataloader as dataloader self.register(dataloader.IRDataOperation, dataloader.DPDataLoader, tag='data') - import cube.algorithm.ops.linear as linear - self.register(linear.Linear, linear.LinearDataParallel, tag='data') - self.register(linear.Linear, linear.LinearColumnWeight, tag='column') - self.register(linear.Linear, linear.LinearRowWeight, tag='row') + import cube.algorithm.ops.einops as einops + self.register(einops.IREinops, einops.DimSplitEinops, tag='dim') - import cube.algorithm.ops.bmm as bmm - self.register(bmm.BatchLinear, bmm.BatchLinearDataParallel, tag='data') - self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='n') - self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='m') - self.register(bmm.BatchLinear, bmm.BatchLinearNParallel, tag='p') - - import cube.algorithm.ops.elementwise as elew - self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') - self.register(elew.Add, elew.AddDimParallel, tag='dim') + # import cube.algorithm.ops.elementwise as elew + # self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') + # self.register(elew.Add, elew.AddDimParallel, tag='dim') import cube.algorithm.ops.layernorm as ln self.register(ln.LayerNorm, ln.LayerNormDimParallel, tag='dim') - import cube.algorithm.ops.activation as activation - self.register(activation.Activation, activation.ActivationDimParallel, tag='dim') - self.register(activation.Dropout, activation.DropoutDimParallel, tag='dim') - self.register(activation.Softmax, activation.SoftmaxDimParallel, tag ='dim') + # import cube.algorithm.ops.activation as activation + # self.register(activation.Activation, activation.ActivationDimParallel, tag='dim') + # self.register(activation.Dropout, activation.DropoutDimParallel, tag='dim') + # self.register(activation.Softmax, activation.SoftmaxDimParallel, tag ='dim') - import cube.algorithm.ops.reduce as reduce - self.register(reduce.Sum, reduce.SumDimParallel, tag='dim') + # import cube.algorithm.ops.reduce as reduce + # self.register(reduce.Sum, reduce.SumDimParallel, tag='dim') import cube.algorithm.ops.complex as complex self.register(complex.CubeComplexToQKV, complex.CubeToQKVDataParallel, tag='data') diff --git a/cube/algorithm/generics.py b/cube/algorithm/generics.py index 73b05a73..e6d9cef2 100644 --- a/cube/algorithm/generics.py +++ b/cube/algorithm/generics.py @@ -15,11 +15,6 @@ def __init__(self, node: IRCell): otherwise should be a list of integers like torch.Tensor.permute() on the logical required format. - Args: - input_layout (list[Outliner, None]): outliner for each input. - The length of outliner should be equal to the number of input - output_layout (list[Outlinter, None]): outliner for each output - The length of outliner should be equal to the number of output # TODO: input_format (list[list[int], None]): input dim order compare with logical definition @@ -28,28 +23,11 @@ def __init__(self, node: IRCell): """ if not isinstance(node, IRCell): raise TypeError("Expected node to be IRCell") - - input_shapes = list() - for input in node.inputs(): - if isinstance(input, IRTensor): - input_shapes.append(input.shape) - else: - input_shapes.append(None) - output_shapes = list() - for output in node.outputs(): - if isinstance(output, IRTensor): - output_shapes.append(output.shape) - else: - output_shapes.append(None) - - self.input_shapes = input_shapes - self.output_shapes = output_shapes - - self._logical_op = type(node) + self._node = node @property - def logic_op(self): - return self._logical_op + def node(self) -> IRCell: + return self._node def satisfy(self, config: Dict): """ @@ -57,7 +35,7 @@ def satisfy(self, config: Dict): """ raise NotImplementedError - def instantiate(self, node, config: Dict): + def instantiate(self, config: Dict): """ Instantiate the algorithm given the config """ diff --git a/cube/algorithm/ops/activation.py b/cube/algorithm/ops/activation.py deleted file mode 100644 index 05aa4ec0..00000000 --- a/cube/algorithm/ops/activation.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Any, Dict, List -import copy - -from cube.algorithm.utils import split_axis -from cube.algorithm.generics import GenericDistAlgo -from cube.ir.cten import IRTensor - -from cube.graph.operator.function import Activation -from cube.graph.operator.function import Dropout -from cube.graph.operator.function import Softmax - - -_kWaitDecision = None - - -class ActivationDimParallel(GenericDistAlgo): - - def __init__(self, node: Activation, dim=None): - if not isinstance(node, Activation): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.ndim = len(node.inputs(0).shape) - self.chunk_num = _kWaitDecision - self.dim = dim - # stay dim convert to positive dim - self.stay_dims = list() - for sdim in node.stay_dims: - sdim = sdim if sdim >= 0 else self.ndim + sdim - self.stay_dims.append(sdim) - - def satisfy(self, config: Dict): - if 'dim' in config: - dim = config['dim'] - else: - if self.dim is None: - raise RuntimeError("Expected dim in config") - dim = self.dim - if dim < 0: - dim = self.ndim + dim - chunk_num = int(config['chunk_num']) - if dim in self.stay_dims: - return False - shape = self.input_shapes[0] - if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: - return True - return False - - def get_extra_kwargs(self, node) -> List[Any]: - """ - Get extra kwarg inputs for the activation - - Returns: - value in List - """ - return [] - - def instantiate(self, node: Activation, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - if 'dim' in config: - self.dim = config['dim'] - - sub_inputs = list() - for input in node.inputs(): - if isinstance(input, IRTensor): - sub_input = split_axis(input, self.dim, self.chunk_num) - else: - sub_input = [input] * self.chunk_num - sub_inputs.append(sub_input) - - sub_outputs = list() - for output in node.outputs(): - if isinstance(output, IRTensor): - sub_output = split_axis(output, self.dim, self.chunk_num) - else: - sub_output = [output] * self.chunk_num - sub_outputs.append(sub_output) - - nodes = list() - for idx, sub_input in enumerate(zip(*sub_inputs)): - extra_input = self.get_extra_kwargs(node) - sub_input = list(sub_input) + extra_input - sub_node = type(node)(node.signature, inputs=sub_input, name=node.name) - sub_node.stay_dims = copy.copy(node.stay_dims) - nodes.append(sub_node) - for idx, sub_output in enumerate(zip(*sub_outputs)): - sub_node = nodes[idx] - for idx, output in enumerate(sub_output): - sub_node.set_output(idx, output) - return nodes - - -class DropoutDimParallel(ActivationDimParallel): - - def __init__(self, node: Activation, dim=None): - super().__init__(node, dim=dim) - - def get_extra_kwargs(self, node: Dropout) -> List[Any]: - if not isinstance(node, Dropout): - raise TypeError("Expected Dropout for DropoutDimParallel") - kwargs = [node.kwargs['p'], node.kwargs['training'], node.kwargs['inplace']] - return kwargs - - -class SoftmaxDimParallel(ActivationDimParallel): - - def __init__(self, node: Activation, dim=None): - super().__init__(node, dim=dim) - - def get_extra_kwargs(self, node) -> List[Any]: - if not isinstance(node, Softmax): - raise TypeError("Expected Softmax for SoftmaxDimParallel") - kwargs = [node.kwargs['dim'], node.kwargs['_stacklevel'], node.kwargs['dtype']] - return kwargs diff --git a/cube/algorithm/ops/bmm.py b/cube/algorithm/ops/bmm.py deleted file mode 100644 index dd25bd2c..00000000 --- a/cube/algorithm/ops/bmm.py +++ /dev/null @@ -1,195 +0,0 @@ -from typing import Dict - -from cube.algorithm.utils import split_axis, split_value -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.operator.function import BatchLinear - - -_kWaitDecision = None - - -class BatchLinearDataParallel(GenericDistAlgo): - """ - Inputs: - input1: [B, N, M] - input2: [B, M, P] - - Outputs: - output: [B, N, P] - """ - - def __init__(self, node: BatchLinear): - - if not isinstance(node, BatchLinear): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - input_shape = self.input_shapes[0] - if chunk_num > 0 and input_shape[0] % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - input1, input2 = node.inputs() - output = node.outputs(0) - - in1s = split_axis(input1, 0, self.chunk_num) - in2s = split_axis(input2, 0, self.chunk_num) - outs = split_axis(output, 0, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - node = BatchLinear( - signature='torch.bmm', - inputs=[in1s[idx], in2s[idx]], - name='bmm' - ) - node.set_output(0, outs[idx]) - nodes.append(node) - return nodes - - -class BatchLinearNParallel(GenericDistAlgo): - """ - Inputs: - input1: [B, N, M] - input2: [B, M, P] - - Outputs: - output: [B, N, P] - """ - - def __init__(self, node: BatchLinear): - - if not isinstance(node, BatchLinear): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - input_shape = self.input_shapes[0] - if chunk_num > 0 and input_shape[1] % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - input1, input2 = node.inputs() - output = node.outputs(0) - - in1s = split_axis(input1, 1, self.chunk_num) - outs = split_axis(output, 1, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - node = BatchLinear( - signature='torch.bmm', - inputs=[in1s[idx], input2], - name='bmm' - ) - node.set_output(0, outs[idx]) - nodes.append(node) - return nodes - - -class BatchLinearMParallel(GenericDistAlgo): - """ - Inputs: - input1: [B, N, M] - input2: [B, M, P] - - Outputs: - output: [B, N, P] - """ - - def __init__(self, node: BatchLinear): - - if not isinstance(node, BatchLinear): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - input_shape = self.input_shapes[0] - if chunk_num > 0 and input_shape[2] % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - input1, input2 = node.inputs() - output = node.outputs(0) - - in1s = split_axis(input1, 2, self.chunk_num) - in2s = split_axis(input2, 1, self.chunk_num) - outs = split_value(output, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - node = BatchLinear( - signature='torch.bmm', - inputs=[in1s[idx], in2s[idx]], - name='bmm' - ) - node.set_output(0, outs[idx]) - nodes.append(node) - return nodes - - -class BatchLinearPParallel(GenericDistAlgo): - """ - Inputs: - input1: [B, N, M] - input2: [B, M, P] - - Outputs: - output: [B, N, P] - """ - - def __init__(self, node: BatchLinear): - - if not isinstance(node, BatchLinear): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - input_shape = self.input_shapes[1] - if chunk_num > 0 and input_shape[2] % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - input1, input2 = node.inputs() - output = node.outputs(0) - - in2s = split_axis(input2, 2, self.chunk_num) - outs = split_axis(output, 2, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - node = BatchLinear( - signature='torch.bmm', - inputs=[input1, in2s[idx]], - name='bmm' - ) - node.set_output(0, outs[idx]) - nodes.append(node) - return nodes diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py new file mode 100644 index 00000000..96e6b489 --- /dev/null +++ b/cube/algorithm/ops/einops.py @@ -0,0 +1,95 @@ +from typing import Any, List, Dict +import copy + +from cube.algorithm.utils import split_axis, split_value +from cube.algorithm.generics import GenericDistAlgo +from cube.ir.cten import IRTensor + +from cube.graph.operator.function import IREinops, EinDim + + +class DimSplitEinops(GenericDistAlgo): + """ + split Einops at dimension level. + + The sum-reduce dimension and non-reduce dimension can be splitted. + + For sum-reduce dimension, the output keeps same shape but has partial-sum valmap result. + For non-reduce dimension, the output keeps same valmap but has partial output shape. + For stay-reduce dimension, this dimension is not allowed to be splitted. + """ + + def __init__(self, node: IREinops): + if not isinstance(node, IREinops): + raise TypeError(f"Expect IREinops") + super().__init__(node) + + def satisfy(self, config: Dict): + """ + config = dict(idx=int, dim=int) + + idx: int + input index + dim: int + dimension of index-th input + num: int + number of chunks to partition + """ + for attr in ['idx', 'dim', 'num']: + if not attr in config: + raise KeyError("Expected idx, dim, num in the config") + node = self.node + idx: int = config['idx'] + dim: int = config['dim'] + num: int = config['num'] + if not (isinstance(idx, int) and abs(idx) < len(node.inputs())): + return False + if node.inputs(idx).shape is None or abs(dim) >= len(node.inputs(idx).shape): + return False + if node.inputs(idx).shape[dim] % num != 0: + return False + return True + + def instantiate(self, config: Dict) -> List[IREinops]: + if not self.satisfy(config): + return False + node: IREinops = self.node + idx: int = config['idx'] + dim: int = config['dim'] + num: int = config['num'] + axis: EinDim = node._ieins[idx][dim] + + # print(f'splitting: {node.einexpr()}') + + ins, ous = list(), list() + for iidx, input in enumerate(node.inputs()): + if axis in node._ieins[iidx]: + dim = node._ieins[iidx].index(axis) + sub_tensors = split_axis(input, dim, num) + ins.append(sub_tensors) + else: + ins.append([input] * num) + for oidx, output in enumerate(node.outputs()): + # split on the non-reduce axis, the output value keeps same + # but the output shape gets splitted + if axis in node._oeins[oidx]: + dim = node._oeins[oidx].index(axis) + if axis.is_reduce(): + raise RuntimeError(f"Reduced axis {dim} appeared in output") + sub_tensors = split_axis(output, dim, num) + ous.append(sub_tensors) + # split on the reduce axis, the output shape keeps same + # but the output value get splitted + else: + if not axis.is_reduce(): + raise RuntimeError(f"Expect axis {axis} to be reduced axis") + sub_tensors = split_value(output, num) + ous.append(sub_tensors) + + sub_nodes = list() + for nid in range(num): + inputs = [t[nid] for t in ins] + outputs = [t[nid] for t in ous] + sub_node = node.new(inputs, outputs) + sub_nodes.append(sub_node) + return sub_nodes diff --git a/cube/algorithm/ops/linear.py b/cube/algorithm/ops/linear.py deleted file mode 100644 index af9cc23a..00000000 --- a/cube/algorithm/ops/linear.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import Dict - -from cube.algorithm.utils import split_axis, split_value -from cube.algorithm.generics import GenericDistAlgo -from cube.graph.operator.function import Linear - - -_kWaitDecision = None - - -class LinearDataParallel(GenericDistAlgo): - """ - Input: - input: [N, *, in_features] - weight: [out_features, in_features] - bias: [out_features,] - - Output: - [N, *, in_features] - """ - - def __init__(self, node: Linear): - - if not isinstance(node, Linear): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - - # input dimension - self.ndim = len(node.inputs(0).shape) - self.dim_choice = list(range(self.ndim - 1)) - - self.chunk_num = _kWaitDecision - if len(self.dim_choice) == 1: - self.dim = 0 - else: - self.dim = _kWaitDecision - - def satisfy(self, config: Dict): - input_shape = self.input_shapes[0] - if input_shape is None: - return False - chunk_num = int(config['chunk_num']) - if 'dim' in config: - dim = config['dim'] - else: - if self.dim is None: - raise RuntimeError("Expected dim in config") - dim = self.dim - if dim < 0: - dim = self.ndim + dim - input_shape = self.input_shapes[0] - if chunk_num > 0 and input_shape[dim] % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - if 'dim' in config: - self.dim = config['dim'] - input, weight, bias = node.inputs() - output = node.outputs(0) - - ins = split_axis(input, self.dim, self.chunk_num) - outs = split_axis(output, self.dim, self.chunk_num) - - nodes = list() - for input_chunk, output_chunk in zip(ins, outs): - node = Linear( - signature='torch.nn.functional.linear', - inputs=[input_chunk, weight, bias], - name='linear' - ) - node.set_output(0, output_chunk) - nodes.append(node) - return nodes - - -class LinearColumnWeight(GenericDistAlgo): - - def __init__(self, node: Linear): - - if not isinstance(node, Linear): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - - self.chunk_num = _kWaitDecision - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - weight_shape = self.input_shapes[1] - if chunk_num > 0 and weight_shape[0] % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - input, weight, bias = node.inputs() - output = node.outputs(0) - - ws = split_axis(weight, 0, self.chunk_num) - if bias is not None: - bs = split_axis(bias, 0, self.chunk_num) - else: - bs = [None] * self.chunk_num - os = split_axis(output, -1, self.chunk_num) - - nodes = list() - for w, b, o in zip(ws, bs, os): - node = Linear( - signature='torch.nn.functional.linear', - inputs=[input, w, b], - name='linear' - ) - node.set_output(0, o) - nodes.append(node) - return nodes - - -class LinearRowWeight(GenericDistAlgo): - - def __init__(self, node: Linear): - - if not isinstance(node, Linear): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - - self.chunk_num = _kWaitDecision - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - weight_shape = self.input_shapes[1] - if chunk_num > 0 and weight_shape[1] % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - input, weight, bias = node.inputs() - output = node.outputs(0) - - ins = split_axis(input, -1, self.chunk_num) - ws = split_axis(weight, 1, self.chunk_num) - if bias: - bs = split_value(bias, self.chunk_num) - else: - bs = [None] * self.chunk_num - os = split_value(output, self.chunk_num) - - nodes = list() - for x, w, b, o in zip(ins, ws, bs, os): - node = Linear( - signature='torch.nn.functional.linear', - inputs=[x, w, b], - name='linear' - ) - node.set_output(0, o) - nodes.append(node) - return nodes diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 1d84d892..002be6b0 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -252,7 +252,6 @@ def replicate(self, op: IRCell, times=1) -> Optional[List[IRCell]]: ops.append(op.replicate()) if isinstance(op.mirror, IRBpOperation): for rep_op in ops[1:]: - print(rep_op) rep_op.gen_backward() idx = self.nodes().index(op) # forward @@ -285,11 +284,11 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional if op not in self.nodes(): raise RuntimeError(f"Not Exist: {op}") - if algo.logic_op != type(op): + if algo.node != op: return None if not algo.satisfy(config): return None - fnodes = algo.instantiate(op, config) + fnodes = algo.instantiate(config) #FIXME: we don't allow non-weight input to be splitted in value for fnode in fnodes: diff --git a/cube/graph/operator/function/__init__.py b/cube/graph/operator/function/__init__.py new file mode 100644 index 00000000..53aa6b6c --- /dev/null +++ b/cube/graph/operator/function/__init__.py @@ -0,0 +1,2 @@ +from cube.graph.operator.function.einops import EinDim, IREinops +from cube.graph.operator.function.function import * \ No newline at end of file diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py new file mode 100644 index 00000000..6cddddf6 --- /dev/null +++ b/cube/graph/operator/function/einops.py @@ -0,0 +1,159 @@ +""" +This operator class is highly inspired by eniops. +""" +import enum +from typing import List, Optional, Tuple + +from cube.ir.cten import IRTensor +from cube.graph.operator.operator import IRFwOperation +from cube.algorithm.factory import DistAlgorithmFactory + + +class EinDim: + + class ReduceType(enum.Enum): + Stay = 0 # the dim is not allowed to be split + Sum = 1 + + def __init__(self, name: str, reduce=None): + if not (str.isidentifier(name) or name == '*'): + raise ValueError("Einstein Axis name should be identifier") + self.name: str = name + self.reduce: Optional[EinDim.ReduceType] = reduce + + def __eq__(self, other): + if isinstance(other, EinDim): + if other.name == self.name: + return True + return False + + def is_reduce(self): + return self.reduce == EinDim.ReduceType.Sum + + def __repr__(self): + return self.name if not self.is_reduce() else self.name + "'" + + +class IREinops(IRFwOperation): + """ + Einstein expression on operators like reshape, view, permute, reduce. + """ + def __init__(self, name: str, signature: str, input_length: int, output_length:int): + super().__init__(name, signature, input_length, output_length) + self._ieins = [list() for _ in range(input_length)] + self._oeins = [list() for _ in range(output_length)] + + def new(self, inputs, outputs, **kwargs): + """ + Create a new same operation given the inputs and outputs + + Each operator needs to implement this. + """ + raise NotImplementedError + + def make_expression(self): + """ + Set einstein-like expression assuming input shapes are given. + + Each operator needs to implement this. + """ + raise NotImplementedError + + def infer_shape(self): + """ + Infer output value shape + """ + for input in self.inputs(): + if isinstance(input, IRTensor) and input.shape is None: + return False + self.make_expression() + # check expression + for input, ein_dims in zip(self.inputs(), self._ieins): + if len(ein_dims) == 0: + if isinstance(input, IRTensor): + raise RuntimeError(f"{self}: {input} has no ein-dims but is a tensor") + if len(ein_dims) != 0: + if not isinstance(input, IRTensor): + raise RuntimeError(f"{self}: {input} has ein-dims but is not a tensor") + if len(input.shape) != len(ein_dims): + raise RuntimeError(f"input tensor ndims ({len(input.shape)}) != ein-dims ({len(ein_dims)})") + # figure output shape + for oidx in range(len(self._outputs)): + output_shape = list() + for oein in self._oeins[oidx]: + for iidx in range(len(self._inputs)): + if oein in self._ieins[iidx]: + input = self.inputs(iidx) + dim = self._ieins[iidx].index(oein) + output_shape.append(input.shape[dim]) + break + self.outputs(oidx).shape = output_shape + return True + + def set_input_ein(self, input_index: int, dims: List[EinDim]): + """ + Set input einstein axis at input index + """ + if not all([isinstance(dim, EinDim) for dim in dims]): + raise TypeError("Expected Tuple[EinDim]") + self._ieins[input_index] = tuple(dims) + + def set_output_ein(self, output_index: int, dims: Tuple[EinDim]): + """ + Set output einstein axis at output index + """ + if not all([isinstance(dim, EinDim) for dim in dims]): + raise TypeError("Expected Tuple[EinDim]") + self._oeins[output_index] = tuple(dims) + + def einexpr(self) -> str: + inputs = list() + outputs = list() + for iein in self._ieins: + inputs.append(' '.join([repr(ein) for ein in iein])) + for oein in self._oeins: + outputs.append(' '.join([repr(ein) for ein in oein])) + return ', '.join(inputs) + ' -> ' + ', '.join(outputs) + + def algorithms(self, tag: Optional[str] = None): + factory = DistAlgorithmFactory() + if tag is None: + templates = list() + if factory.exist(IREinops): + templates = factory.algorithms(IREinops) + algos = list() + for template in templates: + algos.append(template(self)) + return algos + else: + if not factory.exist(IREinops, tag): + return None + template = factory.algorithms(IREinops, tag) + return template(self) + + def parse(self, expr: str): + """ + parse string like: + b m k, b k n -> b m n + """ + if not isinstance(expr, str): + raise TypeError("Expected string") + # remove space + expr = expr.replace(' ', '') + if expr.count('->') != 1: + raise ValueError("string must contain one ->") + input, output = expr.split('->') + inputs = input.split(',') + input_axises = list() + for input in inputs: + axises = list() + for dim in input: + reduce = EinDim.ReduceType.Sum if dim not in output else None + axises.append(EinDim(dim, reduce)) + input_axises.append(axises) + outputs = output.split(',') + output_axises = list() + for output in outputs: + axises = [EinDim(dim) for dim in output] + output_axises.append(axises) + return input_axises, output_axises diff --git a/cube/graph/operator/function.py b/cube/graph/operator/function/function.py similarity index 80% rename from cube/graph/operator/function.py rename to cube/graph/operator/function/function.py index a589c4ad..f8db7c50 100644 --- a/cube/graph/operator/function.py +++ b/cube/graph/operator/function/function.py @@ -1,20 +1,16 @@ -import copy +from typing import List +import string from cube.graph.operator import IRFwOperation +from cube.graph.operator.function.einops import EinDim, IREinops +from cube.ir.cten import IRTensor -class Linear(IRFwOperation): +class Linear(IREinops): """ - Input: - input: [N, *, in_features] - weight: [out_features, in_features] - bias: [out_features,] - - Output: - [N, *, in_features] + b * k, n k -> b * n """ def __init__(self, signature, inputs, name='linear', **kwargs): - input, weight, bias = inputs super().__init__( name, signature, @@ -25,31 +21,38 @@ def __init__(self, signature, inputs, name='linear', **kwargs): self.set_input(1, weight) self.set_input(2, bias) - def infer_shape(self): - """ - input: [(D), M, K] - weight: [N, K] - bias: [N,] - """ - if self.inputs(0).shape is None or self.inputs(1).shape is None: - return False - shape = self.inputs(0).shape[:-1] + self.inputs(1).shape[:1] - self._outputs[0].shape = shape - return True - - -class BatchLinear(IRFwOperation): + def make_expression(self): + expr = 'b * k, n k, n -> b * n' + [idims, wdims, bdims], [odims] = self.parse(expr) + if len(self.inputs(0).shape) == 2: + idims = [idims[0], idims[2]] + odims = [odims[0], odims[2]] + else: + extra_dims = list() + num_extra_dim = len(self.inputs(0).shape) - 2 + dims = [c for c in string.ascii_lowercase if c not in 'bkn'] + for num in range(num_extra_dim): + extra_dims.append(EinDim(dims[num])) + idims = [idims[0]] + extra_dims + [idims[-1]] + odims = [odims[0]] + extra_dims + [odims[-1]] + self.set_input_ein(0, idims) + self.set_input_ein(1, wdims) + if self.inputs(2) is not None: + self.set_input_ein(2, bdims) + self.set_output_ein(0, odims) + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + linear = Linear(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + linear.set_output(idx, output) + return linear + + +class BatchLinear(IREinops): """ - Inputs: - input1: [B, N, M] - input2: [B, M, P] - - Outputs: - output: [B, N, P] + b m k, b k n -> b m n """ - def __init__(self, signature, inputs, name='bmm', **kwargs): - if len(inputs) != 2: raise TypeError(f"Requires 2 inputs. But got {inputs}") input1, input2 = inputs @@ -61,33 +64,27 @@ def __init__(self, signature, inputs, name='bmm', **kwargs): self.set_input(0, input1) self.set_input(1, input2) - def infer_shape(self): - if self.inputs(0).shape is None or self.inputs(1).shape is None: - return False - b1, n1, m1 = self.inputs(0).shape - b2, m2, p2 = self.inputs(1).shape - if m1 != m2 or b1 != b2: - raise RuntimeError("Unmatch {b1} != {b2} or {m1} != {m2}") - shape = [b1, n1, p2] - self._outputs[0].shape = shape - return True + def make_expression(self): + expr = 'b m k, b k n -> b m n' + input_dims, output_dims = self.parse(expr) + for idx, input_dim in enumerate(input_dims): + self.set_input_ein(idx, input_dim) + for idx, output_dim in enumerate(output_dims): + self.set_output_ein(idx, output_dim) + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + bmm = BatchLinear(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + bmm.set_output(idx, output) + return bmm -# ============================= Elementwise ============================ -class ElementWise(IRFwOperation): +class ElementWise(IREinops): """ - Functions like torch.add, torch.mul, torch.sub, etc. + *, _ -> * """ def __init__(self, signature, inputs, name='elementwise', **kwargs): - """ - Inputs: - inputs[0]: IRTensor - inputs[1]: other (IRTensor or Number) - Outputs: - same shape as inputs[0] - """ - if len(inputs) != 2: raise TypeError(f"Expected 2 inputs but got {inputs}") super().__init__( @@ -98,12 +95,27 @@ def __init__(self, signature, inputs, name='elementwise', **kwargs): for idx, input in enumerate(inputs): self.set_input(idx, input) - def infer_shape(self): - if self.inputs(0).shape is None: - return False - shape = copy.copy(self.inputs(0).shape) - self._outputs[0].shape = shape - return True + def make_expression(self): + """ + """ + dims = string.ascii_lowercase + i1, i2 = self.inputs() + dim1 = [EinDim(dims[d]) for d in range(len(i1.shape))] + if isinstance(i2, IRTensor): + if i2.shape == i1.shape: + dim2 = dim1 + else: + raise NotImplementedError(f"Cannot match shape: {i1.shape} and {i2.shape}") + dim2 = list() + self.set_input_ein(0, dim1) + self.set_input_ein(1, dim2) + self.set_output_ein(0, dim1) + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + elew = ElementWise(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + elew.set_output(idx, output) + return elew class Add(ElementWise): @@ -127,6 +139,13 @@ def __init__(self, signature, inputs, name='add', **kwargs): alpha = inputs[2] self.kwargs['alpha'] = alpha + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + inputs = inputs = self.kwags['alpha'] + add = Add(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + add.set_output(idx, output) + return add + class LayerNorm(IRFwOperation): @@ -155,9 +174,7 @@ def infer_shape(self): return True -# ============================= Activation ============================ - -class Activation(IRFwOperation): +class Activation(IREinops): """ functions like GELU, RELU, Dropout. @@ -175,15 +192,15 @@ def __init__(self, signature, inputs, name='activation', **kwargs): output_length=1 ) self.set_input(0, inputs[0]) - # this is for partitioning indicator - self.stay_dims = list() - def infer_shape(self): - input = self.inputs(0) - if input.shape is None: - return False - self._outputs[0].shape = input.shape - return True + def make_expression(self): + """ + * -> * + """ + dims = string.ascii_lowercase + dim1 = [EinDim(dims[d]) for d in range(len(self.inputs(0).shape))] + self.set_input_ein(0, dim1) + self.set_output_ein(0, dim1) class Dropout(Activation): @@ -214,7 +231,11 @@ def __init__(self, signature, inputs, name='softmax', **kwargs): self.kwargs['dim'] = dim self.kwargs['_stacklevel'] = stacklevel self.kwargs['dtype'] = dtype - self.stay_dims.append(dim) + + def make_expression(self): + super().make_expression() + dim = self.kwargs['dim'] + self._ieins[0][dim].reduce = EinDim.ReduceType.Stay # ===================== Loss Computation (Reduce) ========================= From 438331cf07080236445b524e18fd91b131c45f3b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 5 Jan 2022 10:59:09 +0800 Subject: [PATCH 0523/1892] fix value split bug --- cube/algorithm/ops/einops.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py index 96e6b489..c35cd75e 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/einops.py @@ -1,9 +1,8 @@ -from typing import Any, List, Dict -import copy +from typing import List, Dict +import warnings from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo -from cube.ir.cten import IRTensor from cube.graph.operator.function import IREinops, EinDim @@ -68,7 +67,11 @@ def instantiate(self, config: Dict) -> List[IREinops]: sub_tensors = split_axis(input, dim, num) ins.append(sub_tensors) else: - ins.append([input] * num) + if axis.is_reduce(): + print(f'Warning: value split on one input tensor in node{node._id}:{node.name} as reduce axis {axis} not appeared.') + ins.append(split_value(input, num)) + else: + ins.append([input] * num) for oidx, output in enumerate(node.outputs()): # split on the non-reduce axis, the output value keeps same # but the output shape gets splitted @@ -90,6 +93,7 @@ def instantiate(self, config: Dict) -> List[IREinops]: for nid in range(num): inputs = [t[nid] for t in ins] outputs = [t[nid] for t in ous] - sub_node = node.new(inputs, outputs) + sub_node: IREinops = node.new(inputs, outputs) + sub_node.make_expression() sub_nodes.append(sub_node) return sub_nodes From 1235d6825a1737f8e8e23b247037308363c8df64 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 5 Jan 2022 10:59:21 +0800 Subject: [PATCH 0524/1892] update using einops --- examples/mlp/linears.py | 6 +++--- examples/mlp/policy/col_parallel.py | 8 ++++---- examples/mlp/policy/data_parallel.py | 4 ++-- examples/mlp/policy/hybrid_parallel.py | 8 +++----- examples/mlp/policy/megatron_parallel.py | 14 +++++++------- examples/mlp/policy/pipe_parallel.py | 5 +++-- examples/mlp/policy/row_parallel.py | 4 ++-- 7 files changed, 24 insertions(+), 25 deletions(-) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 7a32963d..fdde7693 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,10 +17,10 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -# from examples.mlp.policy.col_parallel import PAS +from examples.mlp.policy.pipe_parallel import PAS -from examples.mlp.policy.col_parallel import P, A, S -PAS = (P, A, S) +# from examples.mlp.policy.col_parallel import P, A, S +# PAS = (P, A, S) # =================== Semantic Model Description ==================== diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index 8ba0fa99..6c3cdf75 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -8,10 +8,10 @@ def P(graph, resource): """ for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - algo = node.algorithms('column') + algo = node.algorithms('dim') if algo: sub_nodes = graph.partition( - node, algo, config=dict(chunk_num=resource.ngpus) + node, algo, config=dict(idx=1, dim=0, num=resource.ngpus) ) else: # graph.assign(node, list(range(resource.ngpus))) @@ -46,10 +46,10 @@ def PAS(graph: IRGraph, resource): """ for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - algo = node.algorithms('column') + algo = node.algorithms('dim') if algo: sub_nodes = graph.partition( - node, algo, config=dict(chunk_num=resource.ngpus) + node, algo, config=dict(idx=1, dim=0, num=resource.ngpus) ) else: # graph.assign(node, list(range(resource.ngpus))) diff --git a/examples/mlp/policy/data_parallel.py b/examples/mlp/policy/data_parallel.py index 48c8bfd3..e0bee742 100644 --- a/examples/mlp/policy/data_parallel.py +++ b/examples/mlp/policy/data_parallel.py @@ -8,10 +8,10 @@ def PAS(graph: IRGraph, resource): """ for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - algo = node.algorithms('data') + algo = node.algorithms('dim') if algo: sub_nodes = graph.partition( - node, algo, config=dict(chunk_num=resource.ngpus) + node, algo, config=dict(idx=0, dim=0, num=resource.ngpus) ) else: sub_nodes = graph.replicate(node, times=resource.ngpus) diff --git a/examples/mlp/policy/hybrid_parallel.py b/examples/mlp/policy/hybrid_parallel.py index 289d5121..658c5271 100644 --- a/examples/mlp/policy/hybrid_parallel.py +++ b/examples/mlp/policy/hybrid_parallel.py @@ -8,13 +8,11 @@ def PAS(graph: IRGraph, resource): """ for idx, node in enumerate(graph.nodes()): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - if idx % 2 == 0: - algo = node.algorithms('row') - else: - algo = node.algorithms('column') + algo = node.algorithms('dim') if algo: sub_nodes = graph.partition( - node, algo, config=dict(chunk_num=resource.ngpus) + node, algo, + config=dict(idx=1, dim=(idx+1)%2, num=resource.ngpus) ) else: sub_nodes = graph.replicate(node, times=resource.ngpus) diff --git a/examples/mlp/policy/megatron_parallel.py b/examples/mlp/policy/megatron_parallel.py index 502fbeeb..40ba3df1 100644 --- a/examples/mlp/policy/megatron_parallel.py +++ b/examples/mlp/policy/megatron_parallel.py @@ -10,19 +10,19 @@ def PAS(graph: IRGraph, resource): dp = resource.ngpus // tp for idx, node in enumerate(graph.nodes()): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - if idx % 2 == 0: - algo = node.algorithms('row') - else: - algo = node.algorithms('column') + algo = node.algorithms('dim') if algo: sub_nodes = list() tp_nodes = graph.partition( - node, algo, config=dict(chunk_num=tp) + node, algo, + config=dict(idx=1, dim=(idx+1)%2, num=tp) ) for tp_node in tp_nodes: - algo = tp_node.algorithms('data') + algo = tp_node.algorithms('dim') dp_nodes = graph.partition( - tp_node, algo, config=dict(chunk_num=dp)) + tp_node, algo, + config=dict(idx=0, dim=0, num=dp) + ) sub_nodes += dp_nodes else: sub_nodes = graph.replicate(node, times=resource.ngpus) diff --git a/examples/mlp/policy/pipe_parallel.py b/examples/mlp/policy/pipe_parallel.py index c25a8777..c2c821a5 100644 --- a/examples/mlp/policy/pipe_parallel.py +++ b/examples/mlp/policy/pipe_parallel.py @@ -11,9 +11,10 @@ def PAS(graph, resource): micro_batch_num = resource.ngpus for node in graph.nodes(): if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): - algo = node.algorithms('data') + algo = node.algorithms('dim') if algo is not None: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=micro_batch_num)) + sub_nodes = graph.partition( + node, algo, config=dict(idx=0, dim=0, num=micro_batch_num)) else: sub_nodes = [node] for idx, sub_node in enumerate(sub_nodes): diff --git a/examples/mlp/policy/row_parallel.py b/examples/mlp/policy/row_parallel.py index d8286178..b44f49a4 100644 --- a/examples/mlp/policy/row_parallel.py +++ b/examples/mlp/policy/row_parallel.py @@ -8,10 +8,10 @@ def PAS(graph: IRGraph, resource): """ for node in graph.nodes(): if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - algo = node.algorithms('row') + algo = node.algorithms('dim') if algo: sub_nodes = graph.partition( - node, algo, config=dict(chunk_num=resource.ngpus) + node, algo, config=dict(idx=1, dim=1, num=resource.ngpus) ) else: sub_nodes = graph.replicate(node, times=resource.ngpus) From aca51dc51ce4e2aeaab132998bc64532c6a32509 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 5 Jan 2022 16:31:49 +0800 Subject: [PATCH 0525/1892] einops stay dim --- cube/algorithm/ops/einops.py | 5 +- cube/algorithm/ops/elementwise.py | 93 ------------------------ cube/algorithm/ops/memory.py | 68 ----------------- cube/graph/operator/function/function.py | 35 +++++---- 4 files changed, 23 insertions(+), 178 deletions(-) delete mode 100644 cube/algorithm/ops/elementwise.py delete mode 100644 cube/algorithm/ops/memory.py diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py index c35cd75e..d3039ea8 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/einops.py @@ -1,5 +1,4 @@ from typing import List, Dict -import warnings from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo @@ -37,7 +36,7 @@ def satisfy(self, config: Dict): for attr in ['idx', 'dim', 'num']: if not attr in config: raise KeyError("Expected idx, dim, num in the config") - node = self.node + node: IREinops = self.node idx: int = config['idx'] dim: int = config['dim'] num: int = config['num'] @@ -47,6 +46,8 @@ def satisfy(self, config: Dict): return False if node.inputs(idx).shape[dim] % num != 0: return False + if node._ieins[idx][dim].reduce == EinDim.ReduceType.Stay: + return False return True def instantiate(self, config: Dict) -> List[IREinops]: diff --git a/cube/algorithm/ops/elementwise.py b/cube/algorithm/ops/elementwise.py deleted file mode 100644 index b0639bed..00000000 --- a/cube/algorithm/ops/elementwise.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Any, List, Dict -import copy - -from cube.algorithm.utils import split_axis -from cube.algorithm.generics import GenericDistAlgo -from cube.ir.cten import IRTensor - -from cube.graph.operator.function import ElementWise -from cube.graph.operator.function import Add - - -_kWaitDecision = None - - -class ElementWiseDimParallel(GenericDistAlgo): - - def __init__(self, node: ElementWise, dim=None): - if not isinstance(node, ElementWise): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.ndim = len(node.inputs(0).shape) - self.chunk_num = _kWaitDecision - self.dim = dim - - def satisfy(self, config: Dict): - if 'dim' in config: - dim = config['dim'] - else: - if self.dim is None: - raise RuntimeError("Expected dim in config") - dim = self.dim - if dim < 0: - dim = self.ndim + dim - chunk_num = int(config['chunk_num']) - shape = self.input_shapes[0] - if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: - return True - return False - - def get_extra_kwargs(self, node: ElementWise) -> List[Any]: - """ - Get extra kwarg inputs for the activation - - Returns: - value in List - """ - return [] - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - if 'dim' in config: - self.dim = config['dim'] - - sub_inputs = list() - for input in node.inputs(): - if isinstance(input, IRTensor): - sub_input = split_axis(input, self.dim, self.chunk_num) - else: - sub_input = [input] * self.chunk_num - sub_inputs.append(sub_input) - - sub_outputs = list() - for output in node.outputs(): - if isinstance(output, IRTensor): - sub_output = split_axis(output, self.dim, self.chunk_num) - else: - sub_output = [output] * self.chunk_num - sub_outputs.append(sub_output) - - nodes = list() - for idx, sub_input in enumerate(zip(*sub_inputs)): - sub_input = list(sub_input) + self.get_extra_kwargs(node) - sub_node = type(node)(node.signature, inputs=sub_input, name=node.name) - sub_node.kwargs = copy.copy(node.kwargs) - nodes.append(sub_node) - for idx, sub_output in enumerate(zip(*sub_outputs)): - sub_node = nodes[idx] - for idx, output in enumerate(sub_output): - sub_node.set_output(idx, output) - return nodes - - -class AddDimParallel(ElementWiseDimParallel): - - def __init__(self, node: ElementWise, dim=None): - super().__init__(node, dim=dim) - - def get_extra_kwargs(self, node: Add) -> List[Any]: - if not isinstance(node, Add): - raise TypeError("Expected Add for AddDimParallel") - return [node.kwargs['alpha']] diff --git a/cube/algorithm/ops/memory.py b/cube/algorithm/ops/memory.py deleted file mode 100644 index 4a5493be..00000000 --- a/cube/algorithm/ops/memory.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Dict -import copy - -from cube.algorithm.utils import split_axis -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.operator.function import Transpose - - -_kWaitDecision = None - - -class TransposeDimParallel(GenericDistAlgo): - - def __init__(self, node: Transpose, dim=None): - if not isinstance(node, Transpose): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - - self.dim0 = node.kwargs['dim0'] - self.dim1 = node.kwargs['dim1'] - self.ndim = len(node.inputs(0).shape) - - # config - self.chunk_num = _kWaitDecision - self.dim = dim - - def satisfy(self, config: Dict): - if 'dim' in config: - dim = config['dim'] - dim = self.ndim + dim if dim < 0 else dim - else: - if self.dim is None: - raise RuntimeError("Expected dim in config") - dim = self.dim - chunk_num = int(config['chunk_num']) - shape = self.input_shapes[0] - if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: - return True - return False - - def instantiate(self, node: Transpose, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - if 'dim' in config: - self.dim = config['dim'] - - input = node.inputs(0) - sub_inputs = split_axis(input, self.dim, self.chunk_num) - - output = node.outputs(0) - target_dim = self.dim - if self.dim == self.dim0: - target_dim = self.dim1 - if self.dim == self.dim1: - target_dim = self.dim0 - sub_outputs = split_axis(output, target_dim, self.chunk_num) - - nodes = list() - for input, output in zip(sub_inputs, sub_outputs): - sub_node = type(node)( - node.signature, inputs=[input, self.dim0, self.dim1], name=node.name - ) - sub_node.kwargs = copy.copy(node.kwargs) - sub_node.set_output(0, output) - nodes.append(sub_node) - return nodes diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index f8db7c50..9d0621a4 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -1,3 +1,4 @@ +import copy from typing import List import string @@ -297,7 +298,7 @@ def infer_shape(self): # ========================= Memory Operation ========================== -class Transpose(IRFwOperation): +class Transpose(IREinops): """ torch.transpose """ @@ -320,22 +321,26 @@ def __init__(self, signature, inputs, name='transpose', **kwargs): self.kwargs['dim0'] = inputs[1] self.kwargs['dim1'] = inputs[2] - def infer_shape(self): - if self.inputs(0).shape is None: - return False - ndim = len(self.inputs(0).shape) + def make_expression(self): + """ + similar like a b c -> a c b + """ + dims = string.ascii_lowercase dim0 = self.kwargs['dim0'] - if dim0 < 0: - dim0 = ndim + dim0 - self.kwargs['dim0'] = dim0 dim1 = self.kwargs['dim1'] - if dim1 < 0: - dim1 = ndim + dim1 - self.kwargs['dim1'] = dim1 - shape = list(self.inputs(0).shape) - shape[dim0], shape[dim1] = shape[dim1], shape[dim0] - self._outputs[0].shape = shape - return True + input = self.inputs(0) + in_dim = [EinDim(dims[d]) for d in range(len(input.shape))] + ou_dim = copy.copy(in_dim) + ou_dim[dim0], ou_dim[dim1] = in_dim[dim1], in_dim[dim0] + + def renew(self, inputs: List[IRTensor], outputs: List[IRTensor]): + dim0 = self.kwargs['dim0'] + dim1 = self.kwargs['dim1'] + inputs += [dim0, dim1] + transpose = Transpose(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + transpose.set_output(idx, output) + return transpose # ===================== Cube Complex Operation ======================= From 9b0d779a6e71ccf51f8937a757ebf1ee53c21a80 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 5 Jan 2022 19:38:57 +0800 Subject: [PATCH 0526/1892] sum to einops --- cube/algorithm/factory.py | 4 +- cube/algorithm/ops/reduce.py | 76 ------------------- cube/graph/operator/function/einops.py | 25 ++++--- cube/graph/operator/function/function.py | 94 +++++++++++++++--------- 4 files changed, 78 insertions(+), 121 deletions(-) delete mode 100644 cube/algorithm/ops/reduce.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 7d4b1f20..b1cc8c01 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -99,5 +99,5 @@ def _load_predefined_algos(self): self.register(complex.CubeComplexEmbedding, complex.CubeEmbedDataParallel, tag='data') self.register(complex.CubeComplexEmbedding, complex.CubeEmbedShardingParallel, tag='shard') - import cube.algorithm.ops.memory as mem - self.register(mem.Transpose, mem.TransposeDimParallel, tag='dim') + # import cube.algorithm.ops.memory as mem + # self.register(mem.Transpose, mem.TransposeDimParallel, tag='dim') diff --git a/cube/algorithm/ops/reduce.py b/cube/algorithm/ops/reduce.py deleted file mode 100644 index a16e65f4..00000000 --- a/cube/algorithm/ops/reduce.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Dict -import copy - -from cube.algorithm.utils import split_axis, split_value -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.operator.function import Sum - - -_kWaitDecision = None - - -class SumDimParallel(GenericDistAlgo): - - def __init__(self, node: Sum, dim=None): - if not isinstance(node, Sum): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.ndim = len(node.inputs(0).shape) - self.reduce_dims = list(range(self.ndim)) - self.keepdim = [False] * self.ndim - if 'dim' in node.kwargs: - self.reduce_dims = [node.kwargs['dim']] - if 'keepdim' in node.kwargs: - self.keepdim = [node.kwargs['keepdim']] * self.ndim - - self.chunk_num = _kWaitDecision - if dim is not None: - dim = self.ndim + dim if dim < 0 else dim - self.dim = dim - - def satisfy(self, config: Dict): - if 'dim' in config: - dim = config['dim'] - else: - if self.dim is None: - raise RuntimeError("Expected dim in config") - dim = self.dim - if dim < 0: - dim = self.ndim + dim - chunk_num = int(config['chunk_num']) - shape = self.input_shapes[0] - if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: - return True - return False - - def instantiate(self, node: Sum, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - if 'dim' in config: - self.dim = config['dim'] - self.dim = self.ndim + self.dim if self.dim < 0 else self.dim - - assert len(node.inputs()) == 1 - input = node.inputs(0) - sub_inputs = split_axis(input, self.dim, self.chunk_num) - - assert len(node.outputs()) == 1 - output = node.outputs(0) - if self.dim not in self.reduce_dims: - sub_outputs = split_axis(output, self.dim, self.chunk_num) - else: - sub_outputs = split_value(output, self.chunk_num) - - nodes = list() - if 'dim' in node.kwargs: - dim = node.kwargs['dim'] - else: - dim = None - for input, output in zip(sub_inputs, sub_outputs): - sub_node = type(node)(node.signature, inputs=[input, dim], name=node.name) - sub_node.kwargs = copy.copy(node.kwargs) - sub_node.set_output(0, output) - nodes.append(sub_node) - return nodes diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 6cddddf6..c948f571 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -16,7 +16,7 @@ class ReduceType(enum.Enum): Sum = 1 def __init__(self, name: str, reduce=None): - if not (str.isidentifier(name) or name == '*'): + if not (str.isidentifier(name) or str.isdecimal(name) or name == '*'): raise ValueError("Einstein Axis name should be identifier") self.name: str = name self.reduce: Optional[EinDim.ReduceType] = reduce @@ -81,6 +81,9 @@ def infer_shape(self): for oidx in range(len(self._outputs)): output_shape = list() for oein in self._oeins[oidx]: + if str.isdecimal(oein.name): + output_shape.append(int(oein.name)) + continue for iidx in range(len(self._inputs)): if oein in self._ieins[iidx]: input = self.inputs(iidx) @@ -118,18 +121,20 @@ def einexpr(self) -> str: def algorithms(self, tag: Optional[str] = None): factory = DistAlgorithmFactory() if tag is None: - templates = list() - if factory.exist(IREinops): - templates = factory.algorithms(IREinops) algos = list() - for template in templates: - algos.append(template(self)) + if factory.exist(type(self)): + algos += [template(self) for template in factory.algorithms(type(self))] + if factory.exist(IREinops): + algos += [template(self) for template in factory.algorithms(IREinops)] return algos else: - if not factory.exist(IREinops, tag): - return None - template = factory.algorithms(IREinops, tag) - return template(self) + if factory.exist(type(self), tag): + template = factory.algorithms(type(self), tag) + return template(self) + if factory.exist(IREinops, tag): + template = factory.algorithms(IREinops, tag) + return template(self) + return None def parse(self, expr: str): """ diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 9d0621a4..59bd5869 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -203,6 +203,12 @@ def make_expression(self): self.set_input_ein(0, dim1) self.set_output_ein(0, dim1) + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + op = Activation(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + op.set_output(idx, output) + return op + class Dropout(Activation): """ @@ -218,6 +224,13 @@ def __init__(self, signature, inputs, name='dropout', **kwargs): self.kwargs['training'] = inputs[2] self.kwargs['inplace'] = inputs[3] + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + inputs = inputs + [self.kwargs['p'], self.kwargs['training'], self.kwargs['inplace']] + op = Dropout(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + op.set_output(idx, output) + return op + class Softmax(Activation): @@ -238,10 +251,16 @@ def make_expression(self): dim = self.kwargs['dim'] self._ieins[0][dim].reduce = EinDim.ReduceType.Stay + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + inputs = inputs + [self.kwargs['dim'], self.kwargs['_stacklevel'], self.kwargs['dtype']] + op = Dropout(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + op.set_output(idx, output) + return op # ===================== Loss Computation (Reduce) ========================= -class Sum(IRFwOperation): +class Sum(IREinops): """ torch.sum """ @@ -265,36 +284,43 @@ def __init__(self, signature, inputs, name='sum', **kwargs): else: self.kwargs['keepdim'] = False - def infer_shape(self): - if self.inputs(0).shape is None: - return False - - # change dim to positive value - ndim = len(self.inputs(0).shape) - if 'dim' in self.kwargs: - dim = self.kwargs['dim'] - dim = ndim + dim if dim < 0 else dim - self.kwargs['dim'] = dim - reduce_dims = [dim] - else: - reduce_dims = list(range(ndim)) - - if 'keepdim' in self.kwargs: - keepdim = self.kwargs['keepdim'] + def make_expression(self): + """ + * -> 1 (no extra kwarg) + a b c -> a c (dim b) + a b c -> a 1 c (dim b and keepdim) + """ + reducedim = None if 'dim' not in self.kwargs else self.kwargs['dim'] + keepdim = False if 'keepdim' not in self.kwargs else self.kwargs['keepdim'] + input = self.inputs(0) + dims = string.ascii_lowercase + in_dim = [ + EinDim(dims[d]) for d in range(len(input.shape))] + ou_dim = copy.copy(in_dim) + if reducedim is not None: + in_dim[reducedim].reduce = EinDim.ReduceType.Sum + if keepdim: + ou_dim[reducedim] = EinDim('1') + else: + ou_dim.pop(reducedim) else: - keepdim = False + for dim in in_dim: + dim.reduce = EinDim.ReduceType.Sum + ou_dim = [EinDim('1')] + self.set_input_ein(0, in_dim) + self.set_output_ein(0, ou_dim) - shape = list() - for dim, nele in enumerate(self.inputs(0).shape): - if dim in reduce_dims: - if keepdim: - shape.append(1) - else: - shape.append(nele) - if len(shape) == 0: - shape = [1] - self._outputs[0].shape = shape - return True + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + reducedim = None if 'dim' not in self.kwargs else self.kwargs['dim'] + keepdim = False if 'keepdim' not in self.kwargs else self.kwargs['keepdim'] + inputs += [reducedim] + if reducedim is not None: + if keepdim: + inputs += [keepdim] + sum_op = Sum(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + sum_op.set_output(idx, output) + return sum_op # ========================= Memory Operation ========================== @@ -332,15 +358,17 @@ def make_expression(self): in_dim = [EinDim(dims[d]) for d in range(len(input.shape))] ou_dim = copy.copy(in_dim) ou_dim[dim0], ou_dim[dim1] = in_dim[dim1], in_dim[dim0] + self.set_input_ein(0, in_dim) + self.set_output_ein(0, ou_dim) - def renew(self, inputs: List[IRTensor], outputs: List[IRTensor]): + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): dim0 = self.kwargs['dim0'] dim1 = self.kwargs['dim1'] inputs += [dim0, dim1] - transpose = Transpose(self.signature, inputs, self.name) + op = Transpose(self.signature, inputs, self.name) for idx, output in enumerate(outputs): - transpose.set_output(idx, output) - return transpose + op.set_output(idx, output) + return op # ===================== Cube Complex Operation ======================= From 965b54914d8c23c48a2c610626f16abdfe8e7e0d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 5 Jan 2022 19:47:40 +0800 Subject: [PATCH 0527/1892] update to use dim parallel --- examples/mlp/policy/col_parallel.py | 16 ++++++++++------ examples/mlp/policy/hybrid_parallel.py | 16 +++++++++------- examples/mlp/policy/megatron_parallel.py | 23 ++++++++++++----------- examples/mlp/policy/pipe1f1b_parallel.py | 11 ++++------- examples/mlp/policy/pipe_parallel.py | 12 +++++++----- examples/mlp/policy/row_parallel.py | 16 ++++++++++------ 6 files changed, 52 insertions(+), 42 deletions(-) diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index 6c3cdf75..6fc8bbaa 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -1,3 +1,4 @@ +import enum from cube.graph import IRGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation @@ -45,13 +46,16 @@ def PAS(graph: IRGraph, resource): Linear Column Partition """ for node in graph.nodes(): - if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + if isinstance(node, IRFwOperation): algo = node.algorithms('dim') - if algo: - sub_nodes = graph.partition( - node, algo, config=dict(idx=1, dim=0, num=resource.ngpus) - ) - else: + sub_nodes = graph.partition( + node, algo, config=dict(idx=1, dim=0, num=resource.ngpus) + ) + if sub_nodes is None: # partition fails # graph.assign(node, list(range(resource.ngpus))) sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): diff --git a/examples/mlp/policy/hybrid_parallel.py b/examples/mlp/policy/hybrid_parallel.py index 658c5271..b09f9a40 100644 --- a/examples/mlp/policy/hybrid_parallel.py +++ b/examples/mlp/policy/hybrid_parallel.py @@ -7,14 +7,16 @@ def PAS(graph: IRGraph, resource): Linear Hybrid Partition """ for idx, node in enumerate(graph.nodes()): - if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + if isinstance(node, IRFwOperation): algo = node.algorithms('dim') - if algo: - sub_nodes = graph.partition( - node, algo, - config=dict(idx=1, dim=(idx+1)%2, num=resource.ngpus) - ) - else: + sub_nodes = graph.partition( + node, algo, config=dict(idx=1, dim=(idx+1)%2, num=resource.ngpus) + ) + if sub_nodes is None: # partition fails sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) diff --git a/examples/mlp/policy/megatron_parallel.py b/examples/mlp/policy/megatron_parallel.py index 40ba3df1..1d5b12ed 100644 --- a/examples/mlp/policy/megatron_parallel.py +++ b/examples/mlp/policy/megatron_parallel.py @@ -9,20 +9,21 @@ def PAS(graph: IRGraph, resource): tp = 2 dp = resource.ngpus // tp for idx, node in enumerate(graph.nodes()): - if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + continue + if isinstance(node, IRFwOperation): + sub_nodes = list() algo = node.algorithms('dim') - if algo: - sub_nodes = list() - tp_nodes = graph.partition( - node, algo, - config=dict(idx=1, dim=(idx+1)%2, num=tp) - ) + tp_nodes = graph.partition( + node, algo, config=dict(idx=1, dim=(idx+1)%2, num=tp) + ) + if tp_nodes is not None: for tp_node in tp_nodes: algo = tp_node.algorithms('dim') - dp_nodes = graph.partition( - tp_node, algo, - config=dict(idx=0, dim=0, num=dp) - ) + dp_nodes = graph.partition(tp_node, algo, config=dict(idx=0, dim=0, num=dp)) sub_nodes += dp_nodes else: sub_nodes = graph.replicate(node, times=resource.ngpus) diff --git a/examples/mlp/policy/pipe1f1b_parallel.py b/examples/mlp/policy/pipe1f1b_parallel.py index a95e540e..bde03c16 100644 --- a/examples/mlp/policy/pipe1f1b_parallel.py +++ b/examples/mlp/policy/pipe1f1b_parallel.py @@ -24,13 +24,10 @@ def b(micro_batch_id: int, stage_id: int): stage_op_num = len(fnodes) // num_stage for idx, node in enumerate(fnodes): stage = min(idx // stage_op_num, num_stage - 1) - sub_nodes = None - algo = node.algorithms('data') - if algo is not None: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=num_micro_batch)) - else: - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, config=dict(dim=0, chunk_num=num_micro_batch)) + # partition at batch dimension + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, config=dict(idx=0, dim=0, num=num_micro_batch)) for mid, sub_node in enumerate(sub_nodes): f(mid, stage).append(sub_node) graph.assign(sub_node, stage) diff --git a/examples/mlp/policy/pipe_parallel.py b/examples/mlp/policy/pipe_parallel.py index c2c821a5..6ee4d674 100644 --- a/examples/mlp/policy/pipe_parallel.py +++ b/examples/mlp/policy/pipe_parallel.py @@ -10,12 +10,14 @@ def PAS(graph, resource): """ micro_batch_num = resource.ngpus for node in graph.nodes(): - if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): + if isinstance(node, IRDataOperation): + device = random.randint(0, resource.ngpus - 1) + graph.assign(node, device) + if isinstance(node, IRFwOperation): algo = node.algorithms('dim') - if algo is not None: - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=0, num=micro_batch_num)) - else: + sub_nodes = graph.partition( + node, algo, config=dict(idx=0, dim=0, num=micro_batch_num)) + if sub_nodes is None: sub_nodes = [node] for idx, sub_node in enumerate(sub_nodes): device = random.randint(0, resource.ngpus - 1) diff --git a/examples/mlp/policy/row_parallel.py b/examples/mlp/policy/row_parallel.py index b44f49a4..71e5c0c5 100644 --- a/examples/mlp/policy/row_parallel.py +++ b/examples/mlp/policy/row_parallel.py @@ -7,13 +7,17 @@ def PAS(graph: IRGraph, resource): Linear Column Partition """ for node in graph.nodes(): - if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + if isinstance(node, IRFwOperation): algo = node.algorithms('dim') - if algo: - sub_nodes = graph.partition( - node, algo, config=dict(idx=1, dim=1, num=resource.ngpus) - ) - else: + sub_nodes = graph.partition( + node, algo, config=dict(idx=1, dim=1, num=resource.ngpus) + ) + if sub_nodes is None: # partition fails + # graph.assign(node, list(range(resource.ngpus))) sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) From 010517314f5317d37795f8b2cbcaffd3923faea3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jan 2022 13:35:42 +0800 Subject: [PATCH 0528/1892] tensor producer and consumer is restricted by graph attach and detach --- cube/graph/adapter/adapter.py | 7 +- cube/graph/adapter/gen.py | 2 + cube/graph/graph.py | 129 +++++++++++++++++++++---------- cube/graph/operator/operator.py | 133 ++++++++++---------------------- cube/graph/tensor.py | 2 +- cube/logics/model.py | 12 +-- 6 files changed, 144 insertions(+), 141 deletions(-) diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index d158235b..ba87ea4b 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -452,7 +452,12 @@ def gen_merge(dst_tensor, intersections): break # cannot merge or add if out is None: - raise RuntimeError("Merge Plan not found") + print(f'failed tensor: {dst_tensor}') + print(f'ptensor:') + for tensor in dst_tensor.parent.ptensors: + print(f'node-{tensor._cell._id}: {tensor}') + print(f'intersections: {intersections}') + raise RuntimeError(f"Merge plan of tensor {dst_tensor} not found") return prims def __repr__(self): diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index b0d6c35a..5e9a29cd 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -27,7 +27,9 @@ def gen(graph: IRGraph, eager=True) -> IRGraph: # update the gradient before generate adapter for node in graph.nodes(): if isinstance(node, IRBpOperation): + idx = graph.detach(node) node.update() + graph.attach(node, idx) graph = AdapterGener.gen_activation_adapter(graph, eager) graph = AdapterGener.gen_weight_reducer(graph) return graph diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 002be6b0..bf5d52a5 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -33,8 +33,8 @@ def __init__(self, inputs: Optional[List[IRTensor]], outputs: Optional[List[IRTensor]], module_name: str): - - self._nodes: List[IRCell] = nodes + + self._nodes: List[IRCell] = list() self._parameters = list() if inputs is None: @@ -54,6 +54,10 @@ def __init__(self, for idx, tensor in enumerate(outputs): self.set_output(idx, tensor) + # insert node from nodes + for idx, node in enumerate(nodes): + self.attach(node, idx) + # set parameter for node in self._nodes: for input in node.inputs(): @@ -173,6 +177,54 @@ def subgraph(self, sub_nodes: List[IRCell]): ) return subgraph + def detach(self, node: IRCell, reset_dependency=False) -> int: + """ + Detach (remove) a node from current graph. + + All the used input and output tensors inside the node + are removed from consumed and produced tensor list. + + Return: + index (int): index of the detached node in the graph + """ + if node not in self.nodes(): + raise KeyError(f"node {node} is not in graph.") + ops = node.nodes() if isinstance(node, IRGraph) else [node] + for op in ops: + for input in op.inputs(): + if isinstance(input, IRSubTensor): + input.parent.rm_consumer(op) + for output in op.outputs(): + if isinstance(output, IRSubTensor): + output.parent.rm_producer(op) + index = self._nodes.index(node) + self._nodes.pop(index) + if reset_dependency: + self.reset_dependency() + return index + + def attach(self, node: IRCell, index, reset_dependency=False): + """ + Attach (insert) a node into current graph at node index. + + All the used input and output tensors inside the node are + recorded in consumed and produced tensor list. + """ + if node in self.nodes(): + raise KeyError(f"node {node} is already in graph.") + ops = node.nodes() if isinstance(node, IRGraph) else [node] + for op in ops: + for input in op.inputs(): + if isinstance(input, IRSubTensor): + input.parent.add_consumer(op, input) + for output in op.outputs(): + if isinstance(output, IRSubTensor): + output.parent.add_producer(op, output) + self._nodes.insert(index, node) + if reset_dependency: + self.reset_dependency() + return + @staticmethod def get_inputs(nodes: List[IRCell]): """ @@ -247,22 +299,21 @@ def replicate(self, op: IRCell, times=1) -> Optional[List[IRCell]]: if op not in self.nodes(): raise RuntimeError(f"Op {op} not exsits") - ops = [op] - for _ in range(times - 1): - ops.append(op.replicate()) + fnodes = [op.replicate() for _ in range(times - 1)] + # insert forward + fidx = self.nodes().index(op) + for idx, fnode in enumerate(fnodes): + self.attach(fnode, fidx + idx) + # insert backward if isinstance(op.mirror, IRBpOperation): - for rep_op in ops[1:]: - rep_op.gen_backward() - idx = self.nodes().index(op) - # forward - self._nodes = self._nodes[:idx] + ops + self._nodes[idx+1:] - # backward - if isinstance(op.mirror, IRCell): - bops = [op.mirror for op in ops][::-1] - midx = self.nodes().index(op.mirror) - self._nodes = self._nodes[:midx] + bops + self._nodes[midx+1:] + for fnode in fnodes: + fnode.gen_backward() + bnodes = [fnode.mirror for fnode in fnodes][::-1] + bidx = self.nodes().index(op.mirror) + for idx, bnode in enumerate(bnodes): + self.attach(bnode, bidx + idx) self.reset_dependency() - return ops + return [op] + fnodes def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: """ @@ -283,6 +334,8 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional raise TypeError("Expected algo to be GenericDistAlgo") if op not in self.nodes(): raise RuntimeError(f"Not Exist: {op}") + if not (isinstance(op, IRFwOperation) or isinstance(op, IRDataOperation)): + raise ValueError("Only allow op to be forward op or data op.") if algo.node != op: return None @@ -298,36 +351,30 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional raise NotImplementedError( f"Not support feature-map {input} to be splitted in value as input" ) - - # remove reference - finputs = op.inputs() - op.make_empty() - if op.mirror is not None: - op.mirror.make_empty() - - # generate backward + # update forward + findex = self.detach(op) + for idx, fnode in enumerate(fnodes): + self.attach(fnode, findex + idx) + # update backward + if isinstance(op.mirror, IRBpOperation): + bindex = self.detach(op.mirror) + bnodes = [fnode.gen_backward() for fnode in fnodes][::-1] + for idx, bnode in enumerate(bnodes): + self.attach(bnode, bindex + idx) + # update gradient updated = set() - for input in finputs: + for input in op.inputs(): if not isinstance(input, IRSubTensor): continue - # go through related consumers and update backward op for fnode in input.parent.consumers: - if isinstance(fnode, IRFwOperation) and fnode._id not in updated: - if fnode.mirror is not None: - fnode.mirror.update() - else: - fnode.gen_backward() - updated.add(fnode._id) - - # insert nodes - idx = self._nodes.index(op) - self._nodes = self._nodes[:idx] + fnodes + self._nodes[idx+1:] - if op.mirror is not None: - idx = self._nodes.index(op.mirror) - bnodes = [node.mirror for node in fnodes][::-1] - self._nodes = self._nodes[:idx] + bnodes + self._nodes[idx+1:] + bnode = fnode.mirror + if isinstance(bnode, IRBpOperation) and fnode._id not in updated: + idx = self.detach(bnode) + bnode.update() + self.attach(bnode, idx) + updated.add(fnode._id) self.reset_dependency() - return copy.copy(fnodes) + return fnodes def merge(self, sub_graph, target_op, op_partition_algorithm): raise NotImplementedError diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index c7146eb6..8f6daf36 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -25,27 +25,27 @@ def infer_shape(self): """ raise NotImplementedError - def set_input(self, input_index: int, val: Any): - # remove the consumer - old_val = self.inputs(input_index) - if isinstance(old_val, IRSubTensor): - old_val.parent.rm_consumer(self) - # add the consumer - val = super().set_input(input_index, val) - if isinstance(val, IRSubTensor): - val.parent.add_consumer(self, val) - return val - - def set_output(self, output_index: int, val: Any): - # remove the producer - old_val = self.outputs(output_index) - if isinstance(old_val, IRSubTensor): - old_val.parent.rm_producer(self) - # add the producer - val = super().set_output(output_index, val) - if isinstance(val, IRSubTensor): - val.parent.add_producer(self, val) - return val + # def set_input(self, input_index: int, val: Any): + # # remove the consumer + # old_val = self.inputs(input_index) + # if isinstance(old_val, IRSubTensor): + # old_val.parent.rm_consumer(self) + # # add the consumer + # val = super().set_input(input_index, val) + # if isinstance(val, IRSubTensor): + # val.parent.add_consumer(self, val) + # return val + + # def set_output(self, output_index: int, val: Any): + # # remove the producer + # old_val = self.outputs(output_index) + # if isinstance(old_val, IRSubTensor): + # old_val.parent.rm_producer(self) + # # add the producer + # val = super().set_output(output_index, val) + # if isinstance(val, IRSubTensor): + # val.parent.add_producer(self, val) + # return val def replicate(self): """ @@ -119,28 +119,6 @@ def algorithms(self, tag: Optional[str] = None): template = factory.algorithms(type(self), tag) return template(self) - def set_input(self, input_index: int, val: Any): - # remove the consumer - old_val = self.inputs(input_index) - if isinstance(old_val, IRSubTensor): - old_val.parent.rm_consumer(self) - # add the consumer - val = super().set_input(input_index, val) - if isinstance(val, IRSubTensor): - val.parent.add_consumer(self, val) - return val - - def set_output(self, output_index: int, val: Any): - # remove the producer - old_val = self.outputs(output_index) - if isinstance(old_val, IRSubTensor): - old_val.parent.rm_producer(self) - # add the producer - val = super().set_output(output_index, val) - if isinstance(val, IRSubTensor): - val.parent.add_producer(self, val) - return val - def replicate(self): """ Replicate the Operation @@ -162,6 +140,15 @@ def replicate(self): return cpy def gen_backward(self): + """ + Generate backward operator for this forward operator. + + Note by calling this API, this forward operator must be + attached into any of one IRGraph, or will lead to reference + count 0 error on gradient calcaultion. + + return: IRBpOperation + """ if self.mirror is not None: raise RuntimeError( "Backward Op already generated. Use self.mirror.update() instead.") @@ -268,48 +255,19 @@ def set_data(self, data_index: int, val: Any): self._datas[data_index] = val return val - def set_input(self, input_index: int, val: Any): - """ - Set the node input gradient - (i.e., output gradient in forward) at input index. - The grad is same order with corresponding output tensor - of it's forward tensor - - Args: - input_idx: input index - val: Union[IRTensor, Any] - - Return: - The set val - """ - # remove the consumer - old_val = self.inputs(input_index) - if isinstance(old_val, IRSubTensor): - old_val.parent.rm_consumer(self) - # add the consumer - val = super().set_input(input_index, val) - if isinstance(val, IRSubTensor): - val.parent.add_consumer(self, val) - return val - - def set_output(self, output_index: int, val: Any): - """ - Set op output grad (Forward input gradient) - """ - # remove the producer - old_val = self.outputs(output_index) - if isinstance(old_val, IRSubTensor): - old_val.parent.rm_producer(self) - # add the producer - val = super().set_output(output_index, val) - if isinstance(val, IRSubTensor): - val.parent.add_producer(self, val) - return val - def update(self): """ Update this backward operator. - This neccessary when op is partitioned and reference count is changed. + This is neccessary when op is partitioned and reference count is changed. + + Note in order to update produced and consumed tensor list, this call should be + wrapped with IRGraph detach and attach: + + ``` + idx = graph.detach(node) + node.update() + graph.attach(node, idx) + ``` """ fnode = self.mirror for idx, input in enumerate(fnode.inputs()): @@ -377,17 +335,6 @@ def infer_shape(self): """ return True - def set_output(self, output_index: int, val: Any): - # remove the producer - old_val = self.outputs(output_index) - if isinstance(old_val, IRSubTensor): - old_val.parent.rm_producer(self) - # add the producer - val = super().set_output(output_index, val) - if isinstance(val, IRSubTensor): - val.parent.add_producer(self, val) - return val - def algorithms(self, tag: Optional[str] = None): """ get algorithm from algorithm factory diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 6b01f56f..05acbb5b 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -362,7 +362,7 @@ def ctensors(self): def add_producer(self, cell: IRCell, tensor: IRTensor): if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): raise TypeError("Expect an IRCell and an IRTensor") - if cell not in self.consumers: + if cell not in self.producers: self.producers.append(cell) self.ptensors.append(tensor) diff --git a/cube/logics/model.py b/cube/logics/model.py index f6073f73..b390c974 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -62,11 +62,12 @@ def forward(graph, *args) -> IRGraph: fnodes = list() # generate forward nodes - for node in graph.nodes(): - inputs = node.inputs() - outputs = node.outputs() - # fnode = copy.copy(node) - fnode : IRFwOperation = node + for fnode in graph.nodes(): + fidx = graph.detach(fnode) + inputs = fnode.inputs() + outputs = fnode.outputs() + # fnode = copy.copy(fnode) + fnode : IRFwOperation = fnode fnode._inputs = inputs fnode._outputs = outputs # set forward inputs @@ -75,6 +76,7 @@ def forward(graph, *args) -> IRGraph: # set forward outputs for idx, val in enumerate(outputs): fnode.set_output(idx, gener.renew(val)) + graph.attach(fnode, fidx) fnodes.append(fnode) # reverse is only to make op id looks consecutive From 017760d840364d85d9252e711181a8ee92a47d24 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jan 2022 19:44:00 +0800 Subject: [PATCH 0529/1892] nodes merge for grid search --- cube/graph/graph.py | 56 ++++++++++++++++++++++++++++++++++++++-- cube/search/__init__.PY | 0 cube/search/iterator.py | 57 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 cube/search/__init__.PY create mode 100644 cube/search/iterator.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index bf5d52a5..56d6d999 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -376,8 +376,60 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional self.reset_dependency() return fnodes - def merge(self, sub_graph, target_op, op_partition_algorithm): - raise NotImplementedError + def merge(self, nodes: List[IRCell], target_node: IRCell): + """ + Merge consecutive nodes in the graph to the target_node. + Note corresponding mirror nodes (if have) will also be merged. + + We don't check computation equivalence between nodes and target_node. + + Merge requires nodes are consecutive in the graph sequence. + """ + if not isinstance(target_node, IRCell): + raise TypeError("Expected target node to be IRCell") + if target_node in self.nodes(): + raise ValueError("Target node is already in the graph") + for node in nodes: + if node not in self.nodes(): + raise KeyError(f"node {node} is not in the graph") + indices = [self.nodes().index(node) for node in nodes] + # consecutive + if max(indices) - min(indices) != len(indices) - 1: + return False + index = min(indices) + # update forward + for node in nodes: + self.detach(node) + self.attach(target_node, index) + # update backward + if all([isinstance(node.mirror, IRCell) for node in nodes]): + bidx = len(self.nodes()) + for node in nodes: + idx = self.detach(node.mirror) + bidx = min(idx, bidx) + if target_node.mirror is None: + if not isinstance(target_node, IRFwOperation): + raise RuntimeError("target node is not FwOp and doens't have mirror node") + target_node.gen_backward() + self.attach(target_node.mirror, bidx) + elif all([isinstance(node.mirror, None) for node in nodes]): + pass + else: + raise ValueError("nodes should have nothing-or-all mirror nodes") + # update weights + updated = set() + for node in nodes + [target_node]: + for input in node.inputs(): + if not isinstance(input, IRSubTensor): + continue + for fnode in input.parent.consumers: + bnode = fnode.mirror + if isinstance(bnode, IRBpOperation) and fnode._id not in updated: + idx = self.detach(bnode) + bnode.update() + self.attach(bnode, idx) + updated.add(fnode._id) + return True def identity(self, input_tensor, dst_op): raise NotImplementedError diff --git a/cube/search/__init__.PY b/cube/search/__init__.PY new file mode 100644 index 00000000..e69de29b diff --git a/cube/search/iterator.py b/cube/search/iterator.py new file mode 100644 index 00000000..e1fab246 --- /dev/null +++ b/cube/search/iterator.py @@ -0,0 +1,57 @@ +from itertools import combinations +from typing import Any, List + + +def comb_iter(candidates: List, pick_num: int): + """ + combination pickers + """ + return combinations(candidates, pick_num) + + +def otho_iter(slots: List[List[Any]]): + """ + othogonal pickers + + item for each slot can be randomly selected + """ + if len(slots) == 0: + yield [] + return + slot = slots[0] + if len(slots) == 1: + for item in slot: + yield [item] + else: + slots = slots[1:] + for item in slot: + for res in otho_iter(slots): + yield [item] + res + return + + +def factorization(K: int, num=1): + """ + Decompose K into `depth` numbers that + a1 * a2 * ... * a_depth = K + ($\prod\limits_{i=1}^depth a_i = K$) + + Yield: + List[int] + """ + if num == 1: + yield [K] + else: + for i in range(1, K+1): + if K % i == 0: + for res in factorization(K // i, num-1): + yield [i] + res + + +if __name__ == '__main__': + + # for seq in otho_iter([[1,2,3], [4,5], [6,7,8]]): + # print(seq) + + for seq in factorization(8, 2): + print(seq) \ No newline at end of file From b513870fe1e31d5d64ad41be0ee9f697bc7a19c3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jan 2022 19:44:26 +0800 Subject: [PATCH 0530/1892] grid search on partition plans --- examples/mlp/policy/search.py | 81 +++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 examples/mlp/policy/search.py diff --git a/examples/mlp/policy/search.py b/examples/mlp/policy/search.py new file mode 100644 index 00000000..b4e38078 --- /dev/null +++ b/examples/mlp/policy/search.py @@ -0,0 +1,81 @@ + +from typing import Dict, List +from itertools import combinations +from cube.graph import IRGraph +from cube.graph.operator.operator import IRDataOperation, IRFwOperation +import cube.search.iterator as iterator + + +def get_plan(graph: IRGraph, fnode: IRFwOperation, configs: List[Dict]) -> List[IRFwOperation]: + + all_nodes = [fnode] + for config in configs: + sub_nodes = list() + for node in all_nodes: + algo = node.algorithms('dim') + sub = graph.partition(node, algo, config) + if sub is None: + sub = graph.replicate(node, times=config['num']) + sub_nodes += sub + all_nodes = sub_nodes + return all_nodes + + +def compositions(graph: IRGraph, fnode: IRFwOperation, nest: List[int]) -> List[IRFwOperation]: + """" + e.g., + fnode: linear + nest: [2, 4] + will get 9 partition strategies of 8-nodes + """ + all_configs = [ + dict(idx=0, dim=0), # data parallel + dict(idx=0, dim=1), # row parallel + dict(idx=1, dim=0), # col parallel + ] + config_iter = combinations(all_configs, len(nest)) + for configs in config_iter: + for config, ndev in zip(configs, nest): + config['num'] = ndev + nodes = get_plan(graph, fnode, configs) + yield nodes + graph.merge(nodes, fnode) + + +def sequence(graph: IRGraph, fnodes: IRFwOperation, resource): + + nest_depth = 2 + nests = iterator.factorization(resource.ngpus, nest_depth) + + if len(fnodes) == 0: + yield list() + + for fnode in fnodes: + for nest in nests: + for seq in compositions(graph, fnode, nest): + for idx, node in enumerate(seq): + graph.assign(node, idx) + for remain in sequence(graph, fnodes[1:], resource): + yield seq + remain + + +def PAS(graph: IRGraph, resource): + + # replicate data operation + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + + fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] + + for idx, seq in enumerate(sequence(graph, fnodes, resource)): + print(f'searching index: {idx}') + print(graph.extra_repr()) + # for node in seq: + # print(node) + # print('\n') + print(f'==> grid searched on {idx+1} seq') + + raise NotImplementedError From 3d38daa7f336b71b7cbabe9a5a9c13bf729e70ab Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jan 2022 20:06:52 +0800 Subject: [PATCH 0531/1892] search tag to record trace --- examples/mlp/policy/search.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/examples/mlp/policy/search.py b/examples/mlp/policy/search.py index b4e38078..e1c910ad 100644 --- a/examples/mlp/policy/search.py +++ b/examples/mlp/policy/search.py @@ -16,8 +16,10 @@ def get_plan(graph: IRGraph, fnode: IRFwOperation, configs: List[Dict]) -> List[ sub = graph.partition(node, algo, config) if sub is None: sub = graph.replicate(node, times=config['num']) + fnode.tag = ('rep', 'rep') sub_nodes += sub all_nodes = sub_nodes + fnode.tag = tuple(config['name'] for config in configs) return all_nodes @@ -29,9 +31,9 @@ def compositions(graph: IRGraph, fnode: IRFwOperation, nest: List[int]) -> List[ will get 9 partition strategies of 8-nodes """ all_configs = [ - dict(idx=0, dim=0), # data parallel - dict(idx=0, dim=1), # row parallel - dict(idx=1, dim=0), # col parallel + dict(idx=0, dim=0, name='dat'), # data parallel + dict(idx=0, dim=1, name='row'), # row parallel + dict(idx=1, dim=0, name='col'), # col parallel ] config_iter = combinations(all_configs, len(nest)) for configs in config_iter: @@ -40,6 +42,7 @@ def compositions(graph: IRGraph, fnode: IRFwOperation, nest: List[int]) -> List[ nodes = get_plan(graph, fnode, configs) yield nodes graph.merge(nodes, fnode) + fnode.tag = None def sequence(graph: IRGraph, fnodes: IRFwOperation, resource): @@ -57,7 +60,14 @@ def sequence(graph: IRGraph, fnodes: IRFwOperation, resource): graph.assign(node, idx) for remain in sequence(graph, fnodes[1:], resource): yield seq + remain - + + +def comm_estimate(graph: IRGraph) -> int: + """ + Estimate communications + """ + pass + def PAS(graph: IRGraph, resource): @@ -68,14 +78,18 @@ def PAS(graph: IRGraph, resource): for idx, node in enumerate(sub_nodes): graph.assign(node, idx) + # replicate loss operation fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] + loss = fnodes[-1] + sub_nodes = graph.replicate(loss, times=resource.ngpus) + # search for linear operations + fnodes = fnodes[:-1] # only search linears for idx, seq in enumerate(sequence(graph, fnodes, resource)): print(f'searching index: {idx}') - print(graph.extra_repr()) - # for node in seq: - # print(node) - # print('\n') + # print(graph.extra_repr()) + for node in fnodes: + print(node.tag) print(f'==> grid searched on {idx+1} seq') raise NotImplementedError From 0f47fe7a1f307d7866ff6871e8e6eada1e2a81ba Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 7 Jan 2022 11:30:32 +0800 Subject: [PATCH 0532/1892] estimator for comm volume --- cube/ir/cten.py | 11 ++++++ cube/profiler/estimator.py | 70 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 cube/profiler/estimator.py diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 0622b61f..d913b50e 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -574,6 +574,17 @@ def shape(self, val): raise RuntimeError("Expected shape to be list[int]") self._shape = copy.copy(list(val)) + def nele(self) -> int: + """ + Get total number of element in the tensor. + """ + if self.shape is None: + raise RuntimeError("Tensor shape is not set") + cnt = 1 + for num in self.shape: + cnt *= num + return cnt + def src(self, cells: List[IRCell]) -> List[IRCell]: """ Return all the cells that will generate this tensor diff --git a/cube/profiler/estimator.py b/cube/profiler/estimator.py new file mode 100644 index 00000000..867b9afd --- /dev/null +++ b/cube/profiler/estimator.py @@ -0,0 +1,70 @@ + +from cube.graph.tensor import IRSubTensor, ValueMap +from cube.graph.adapter.adapter import IRAdapter +from cube.graph import IRGraph +from cube.ir.cten import IRCell, IRTensor + + +class Estimator: + + def __init__(self, graph: IRGraph): + """ + Estimator for policy use + """ + + self.graph = graph + + def comm_volume(self, device: int) -> int: + """ + Estimate message recv volume of device id. + This has no requirement for generating adapters in graph. + + Node that is not assigned to a particular device will not + be considered. + """ + volume = 0 + for node in self.graph.nodes(): + if isinstance(node, IRAdapter): + continue + if device in node.device: + volume += self.comm_volume_node(node) + return volume + + def comm_volume_node(self, node: IRCell) -> int: + """ + Estimate node message recv volume. + This has no requirement for generating adapters in graph. + """ + if node not in self.graph.nodes(): + raise KeyError(f"node {node} not in graph") + if len(node.device) == 0: + raise RuntimeError(f"node {node} device is not assigned") + volume = 0 + for input in node.inputs(): + if isinstance(input, IRSubTensor): + # reducer + if input.is_param(): + if input.grad.valmap != ValueMap(0, 1): + volume += input.nele() * (input.grad.valmap.chunk_num - 1) + # adapter + else: + local, remote = list(), list() + for ptensor in input.parent.ptensors: + if ptensor.device != input.device: + remote.append(ptensor) + else: + local.append(ptensor) + if input in local: + continue + else: + for ptensor in remote: + if input.overlap(ptensor): + intersection = input.common(ptensor) + volume += intersection.nele() + return volume + + def flops(self) -> int: + raise NotImplementedError + + def flops_node(self, node: IRCell) -> int: + raise NotImplementedError From 7d08d7d0443712c635d72cc4a0c9673d1e6d34f5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 7 Jan 2022 13:58:42 +0800 Subject: [PATCH 0533/1892] fix estimator bug --- cube/profiler/__init__.py | 3 ++- cube/profiler/estimator.py | 32 ++++++++++++++++++++++++++------ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/cube/profiler/__init__.py b/cube/profiler/__init__.py index 5649c2ee..e349da1d 100644 --- a/cube/profiler/__init__.py +++ b/cube/profiler/__init__.py @@ -1 +1,2 @@ -from cube.profiler.timer import CudaTimer \ No newline at end of file +from cube.profiler.timer import CudaTimer +from cube.profiler.estimator import Estimator \ No newline at end of file diff --git a/cube/profiler/estimator.py b/cube/profiler/estimator.py index 867b9afd..07ad6a44 100644 --- a/cube/profiler/estimator.py +++ b/cube/profiler/estimator.py @@ -1,4 +1,5 @@ +from cube.graph.operator.operator import IRBpOperation, IRFwOperation from cube.graph.tensor import IRSubTensor, ValueMap from cube.graph.adapter.adapter import IRAdapter from cube.graph import IRGraph @@ -34,6 +35,11 @@ def comm_volume_node(self, node: IRCell) -> int: """ Estimate node message recv volume. This has no requirement for generating adapters in graph. + + Note for intermediate tensor communication, the estimated + communication volume is: + Volume = 0 if local produced tensor can covor all the needed region. + else N#(remote produced overlapping region) """ if node not in self.graph.nodes(): raise KeyError(f"node {node} not in graph") @@ -54,13 +60,27 @@ def comm_volume_node(self, node: IRCell) -> int: remote.append(ptensor) else: local.append(ptensor) - if input in local: + # check local + local_cover = False + for ptensor in local: + if input.overlap(ptensor): + intersection = input.common(ptensor) + if intersection == input: + local_cover = True + break + if local_cover: continue - else: - for ptensor in remote: - if input.overlap(ptensor): - intersection = input.common(ptensor) - volume += intersection.nele() + for ptensor in remote: + if input.overlap(ptensor): + intersection = input.common(ptensor) + volume += intersection.nele() + # debug info + # if isinstance(node, IRFwOperation): + # print(f'fw{node._id}-{node.device}-{node.name}: {volume}') + # elif isinstance(node, IRBpOperation): + # print(f'bw{node._id}(fw{node.mirror._id}): {volume}') + # else: + # print(f'cell{node._id}-{node.device}-{node.name}: {volume}') return volume def flops(self) -> int: From 7f1f58f7b5359b1dcedd1893ce8e4737f6ee9d31 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 7 Jan 2022 14:00:21 +0800 Subject: [PATCH 0534/1892] add grid search for parallelisms --- examples/mlp/policy/search.py | 45 +++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/examples/mlp/policy/search.py b/examples/mlp/policy/search.py index e1c910ad..7c8cf12d 100644 --- a/examples/mlp/policy/search.py +++ b/examples/mlp/policy/search.py @@ -4,6 +4,9 @@ from cube.graph import IRGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation import cube.search.iterator as iterator +from cube.profiler.estimator import Estimator + +import numpy as np def get_plan(graph: IRGraph, fnode: IRFwOperation, configs: List[Dict]) -> List[IRFwOperation]: @@ -19,7 +22,7 @@ def get_plan(graph: IRGraph, fnode: IRFwOperation, configs: List[Dict]) -> List[ fnode.tag = ('rep', 'rep') sub_nodes += sub all_nodes = sub_nodes - fnode.tag = tuple(config['name'] for config in configs) + fnode.tag = tuple('{}-{}'.format(config['name'], config['num']) for config in configs) return all_nodes @@ -62,11 +65,15 @@ def sequence(graph: IRGraph, fnodes: IRFwOperation, resource): yield seq + remain -def comm_estimate(graph: IRGraph) -> int: +def comm_estimate(graph: IRGraph, ndevice: int) -> int: """ Estimate communications """ - pass + estimator = Estimator(graph) + total_volume = 0 + for devid in range(ndevice): + total_volume += estimator.comm_volume(devid) + return total_volume def PAS(graph: IRGraph, resource): @@ -82,14 +89,38 @@ def PAS(graph: IRGraph, resource): fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] loss = fnodes[-1] sub_nodes = graph.replicate(loss, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) # search for linear operations fnodes = fnodes[:-1] # only search linears + seqs = list() + comms = list() + plans = list() for idx, seq in enumerate(sequence(graph, fnodes, resource)): - print(f'searching index: {idx}') + print(f'searching index: {idx}...') + seqs.append(seq) + comm = comm_estimate(graph, resource.ngpus) + comms.append(comm) + plan = [node.tag for node in fnodes] + plans.append(plan) + print(f'comm volume: {comm}') + # for node in fnodes: + # print(node.tag) # print(graph.extra_repr()) - for node in fnodes: - print(node.tag) - print(f'==> grid searched on {idx+1} seq') + print(f'==> grid search done on {idx+1} seq') + print(f'\n\n') + + comms = np.array(comms) + indices = np.argsort(comms) + + top_indices = indices[:10] + top_plan = [plans[idx] for idx in top_indices] + top_comm = [comms[idx] for idx in top_indices] + for top_idx, (idx, plan, comm) in enumerate(zip(top_indices, top_plan, top_comm)): + print(f'top {top_idx} (plan index {idx}):') + for lid, node in enumerate(plan): + print(f'linear{lid}: {node}') + print(f'===> comm: {comm}') raise NotImplementedError From cdd484c458bbeee536018085fa514657252452e1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 7 Jan 2022 14:20:00 +0800 Subject: [PATCH 0535/1892] search with timer --- examples/mlp/policy/search.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/mlp/policy/search.py b/examples/mlp/policy/search.py index 7c8cf12d..5715e8f7 100644 --- a/examples/mlp/policy/search.py +++ b/examples/mlp/policy/search.py @@ -1,6 +1,8 @@ from typing import Dict, List +import time from itertools import combinations + from cube.graph import IRGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation import cube.search.iterator as iterator @@ -93,6 +95,8 @@ def PAS(graph: IRGraph, resource): graph.assign(sub_node, idx) # search for linear operations + start = time.time() + fnodes = fnodes[:-1] # only search linears seqs = list() comms = list() @@ -104,11 +108,11 @@ def PAS(graph: IRGraph, resource): comms.append(comm) plan = [node.tag for node in fnodes] plans.append(plan) - print(f'comm volume: {comm}') + print(f'comm volume param#: {comm}') # for node in fnodes: # print(node.tag) # print(graph.extra_repr()) - print(f'==> grid search done on {idx+1} seq') + print(f'==> grid search done on {idx+1} plans') print(f'\n\n') comms = np.array(comms) @@ -121,6 +125,9 @@ def PAS(graph: IRGraph, resource): print(f'top {top_idx} (plan index {idx}):') for lid, node in enumerate(plan): print(f'linear{lid}: {node}') - print(f'===> comm: {comm}') + print(f'===> comm param#: {comm}') + + end = time.time() + print('grid search time: {:.2f}'.format(end-start)) raise NotImplementedError From b6a80b4779ecfbdfa5709c297f2b0a5b95ce5be7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 9 Jan 2022 14:49:55 +0800 Subject: [PATCH 0536/1892] communication volume also consider gets from consumer --- cube/profiler/estimator.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/cube/profiler/estimator.py b/cube/profiler/estimator.py index 07ad6a44..6cec2514 100644 --- a/cube/profiler/estimator.py +++ b/cube/profiler/estimator.py @@ -60,7 +60,7 @@ def comm_volume_node(self, node: IRCell) -> int: remote.append(ptensor) else: local.append(ptensor) - # check local + # check local producer local_cover = False for ptensor in local: if input.overlap(ptensor): @@ -70,10 +70,29 @@ def comm_volume_node(self, node: IRCell) -> int: break if local_cover: continue + # check remote producer + remote_producer_volume = 0 for ptensor in remote: if input.overlap(ptensor): intersection = input.common(ptensor) - volume += intersection.nele() + remote_producer_volume += intersection.nele() + # check remote consumer + # TODO: need to check if all consumers can be + # merged to input + remote_consumer_volume = None + index = input.parent.consumers.index(node) + for ctensor in input.parent.ctensors[:index]: + if input.overlap(ctensor): + if remote_consumer_volume is None: + remote_consumer_volume = 0 + intersection = input.common(ctensor) + remote_consumer_volume += intersection.nele() + if intersection == input: + break + if remote_consumer_volume is None: + volume += remote_producer_volume + else: + volume += min(remote_consumer_volume, remote_producer_volume) # debug info # if isinstance(node, IRFwOperation): # print(f'fw{node._id}-{node.device}-{node.name}: {volume}') From bc64a6805cd2d15f4a87578f25b8ee50754ed2dc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 10:44:22 +0800 Subject: [PATCH 0537/1892] add searched optimal case --- examples/mlp/linears.py | 6 +-- examples/mlp/policy/col_parallel.py | 1 - examples/mlp/policy/megatron_parallel.py | 2 +- examples/mlp/policy/optimal.py | 60 ++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 5 deletions(-) create mode 100644 examples/mlp/policy/optimal.py diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index fdde7693..3f88b8d7 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -17,7 +17,7 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.pipe_parallel import PAS +from examples.mlp.policy.optimal import PAS # from examples.mlp.policy.col_parallel import P, A, S # PAS = (P, A, S) @@ -50,8 +50,8 @@ def forward(self, data): def train(): - batch_size = 4096 - dim = 4096 + batch_size = 8192 + dim = 8192 model = MLP(dim=dim) model = cube.SemanticModel( diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index 6fc8bbaa..b7855732 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -1,4 +1,3 @@ -import enum from cube.graph import IRGraph from cube.graph.operator.operator import IRDataOperation, IRFwOperation diff --git a/examples/mlp/policy/megatron_parallel.py b/examples/mlp/policy/megatron_parallel.py index 1d5b12ed..0d05500a 100644 --- a/examples/mlp/policy/megatron_parallel.py +++ b/examples/mlp/policy/megatron_parallel.py @@ -6,7 +6,7 @@ def PAS(graph: IRGraph, resource): """ Linear Hybrid + Nested Partition """ - tp = 2 + tp = 4 dp = resource.ngpus // tp for idx, node in enumerate(graph.nodes()): if isinstance(node, IRDataOperation): diff --git a/examples/mlp/policy/optimal.py b/examples/mlp/policy/optimal.py new file mode 100644 index 00000000..e94523b8 --- /dev/null +++ b/examples/mlp/policy/optimal.py @@ -0,0 +1,60 @@ +from cube.graph import IRGraph +from cube.graph.operator import IRFwOperation, IRDataOperation + + +def PAS(graph: IRGraph, resource): + + assert resource.ngpus == 4, "the optimal plan is for 4 GPU case." + + # replicate data operation + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + + # replicate loss operation + fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] + loss = fnodes[-1] + sub_nodes = graph.replicate(loss, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + + fnodes = fnodes[:-1] + # linear0 config + config0 = [ + None, + dict(idx=1, dim=0, num=4) # col + ] + # linear1 config + config1 = [ + dict(idx=0, dim=1, num=2), # row + dict(idx=1, dim=0, num=2), # col + ] + # linear2 config + config2 = [ + dict(idx=0, dim=0, num=2), # dat + dict(idx=0, dim=1, num=2), # row + ] + # linear3 config + config3 = [ + dict(idx=0, dim=0, num=2), # dat + dict(idx=0, dim=1, num=2), # row + ] + configs = [config0, config1, config2, config3] + assert len(fnodes) == len(configs) + for fnode, config in zip(fnodes, configs): + all_nodes = [fnode] + for conf in config: + if conf is None: + continue + sub_nodes = list() + for node in all_nodes: + algo = node.algorithms('dim') + nodes = graph.partition(node, algo, conf) + sub_nodes += nodes + all_nodes = sub_nodes + assert len(all_nodes) == 4 + for idx, node in enumerate(all_nodes): + graph.assign(node, idx) + return graph From f09d82ee53d35d4c0ab69ba41b17168fe60adee3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 10:45:06 +0800 Subject: [PATCH 0538/1892] add einops space --- cube/algorithm/ops/einops.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py index d3039ea8..423f2496 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/einops.py @@ -69,7 +69,7 @@ def instantiate(self, config: Dict) -> List[IREinops]: ins.append(sub_tensors) else: if axis.is_reduce(): - print(f'Warning: value split on one input tensor in node{node._id}:{node.name} as reduce axis {axis} not appeared.') + # print(f'Warning: value split on one input tensor in node{node._id}:{node.name} as reduce axis {axis} not appeared.') ins.append(split_value(input, num)) else: ins.append([input] * num) @@ -98,3 +98,27 @@ def instantiate(self, config: Dict) -> List[IREinops]: sub_node.make_expression() sub_nodes.append(sub_node) return sub_nodes + + def space(self, num_device: int) -> List[Dict[str, int]]: + """ + Return a list of possible configurations + given the number of devices + """ + possible_idx = list() + possible_dim = list() + num = num_device + dims = list() + node: IREinops = self.node + for idx, eindims in enumerate(node._ieins): + for dim, eindim in enumerate(eindims): + if eindim.reduce != EinDim.ReduceType.Stay: + if eindim not in dims: + dims.append(eindim) + possible_idx.append(idx) + possible_dim.append(dim) + possible_configs = list() + for idx, dim in zip(possible_idx, possible_dim): + config = dict(idx=idx, dim=dim, num=num) + if self.satisfy(config): + possible_configs.append(config) + return possible_configs From c3f04d3efbbc3e63370f2278a205e13f2bdc7d40 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 05:45:41 +0000 Subject: [PATCH 0539/1892] synthetic data interface --- benchmark/megatron/linears.py | 23 ++++++++++----- cube/runtime/syndata.py | 53 ++++++++++++++++++++++++----------- examples/mlp/linears.py | 19 ++++++++++++- 3 files changed, 70 insertions(+), 25 deletions(-) diff --git a/benchmark/megatron/linears.py b/benchmark/megatron/linears.py index e6070c71..9c1baf50 100644 --- a/benchmark/megatron/linears.py +++ b/benchmark/megatron/linears.py @@ -2,13 +2,18 @@ example: python -m torch.distributed.launch \ - --nproc_per_node=2 \ + --nproc_per_node=4 \ --nnodes=1 \ --node_rank=0 \ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ benchmark/megatron/linears.py + +torchrun --standalone \ + --nproc_per_node=4 \ + --nnodes=1 \ + benchmark/megatron/linears.py """ import argparse @@ -23,7 +28,7 @@ class ColumnMLP(nn.Module): - def __init__(self, dim, mult=16): + def __init__(self, dim, mult=1): super().__init__() self.linear1 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=True) self.linear2 = ColumnParallelLinear(dim * mult, dim, full_input=True, full_output=True) @@ -40,7 +45,7 @@ def forward(self, data): class RowMLP(nn.Module): - def __init__(self, dim, mult=16): + def __init__(self, dim, mult=1): super().__init__() self.linear1 = RowParallelLinear(dim, dim * mult, full_input=True, full_output=True) self.linear2 = RowParallelLinear(dim * mult, dim, full_input=True, full_output=True) @@ -57,7 +62,7 @@ def forward(self, data): class HybridMLP(nn.Module): - def __init__(self, dim, mult=16): + def __init__(self, dim, mult=1): super().__init__() self.linear1 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=False) self.linear2 = RowParallelLinear(dim * mult, dim, full_input=False, full_output=True) @@ -75,8 +80,8 @@ def forward(self, data): def train(args): - batch_size = 128 - dim = 1024 + batch_size = 8192 + dim = 8192 # model = ColumnMLP(dim=dim).cuda() # model = RowMLP(dim=dim).cuda() @@ -85,7 +90,11 @@ def train(args): for param in model.parameters(): torch.nn.init.uniform_(param) - dataloader = cube.runtime.syndata.SynDataLoader(1280, [batch_size, dim]) + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([batch_size, dim],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) def train_iter(model, dataloader): diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index ed7e672b..bb40d9c8 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -2,7 +2,7 @@ Synthetic Data Loader """ -from typing import List, Optional +from typing import List, Optional, Tuple import copy import torch @@ -14,16 +14,21 @@ class CubeDataLoader: r""" Cube Dataloader """ - def __init__(self, batch_dims: List[int], *shapes: List[List[int]]): + def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype], batch_dims: Tuple[int]): """ - batch_dim: - The batch dimension for each input shapes - *shapes: + shapes Tuple[Tuple[int]]: The shape for each data + dtypes Tuple[torch.dtype]: + The dtype for each data + batch_dims Tuple[int]: + The batch dimension of each data """ - if not isinstance(batch_dims, list): - raise RuntimeError("Expected a List[int] for batch dims") - self.shapes = list(shapes) + if not all(isinstance(shape, list) for shape in shapes): + raise TypeError("Expected each shape in shapes to be a list") + if len(shapes) != len(batch_dims) or len(shapes) != len(dtypes): + raise TypeError("Expected number batch dim and dtypes to len(shapes)") + self.shapes = shapes + self.dtypes = dtypes self.batch_dims = batch_dims def get_batch_dims(self, idx: Optional[int] = None) -> int: @@ -33,7 +38,7 @@ def get_batch_dims(self, idx: Optional[int] = None) -> int: if idx is not None: return self.batch_dims[idx] else: - return copy.copy(self.batch_dims) + return list(self.batch_dims) def reset(self, batch_size: int): """ @@ -47,13 +52,27 @@ def reset(self, batch_size: int): class SynDataLoader(CubeDataLoader): r""" Synthetic dataloader to produce tensors - for given shape. + for given shapes, dtypes. """ - def __init__(self, num: int, batch_dim: List[int], *shapes: List[List[int]]): - if len(shapes) != len(batch_dim): - raise TypeError("Expected length of batch dim is same to shapes") - super().__init__(batch_dim, *shapes) - self.length = num + def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, + batch_dims: Tuple[int] = None, length: int = 1280): + """ + shapes Tuple[Tuple[int]]: + The shape for each data + dtypes Tuple[torch.dtype]: + The dtype for each data (Default None: use torch.float32) + batch_dims Tuple[int]: + The batch dimension of each data (Default None: dimension 0 is the batch dim) + length int: + Total number of sample batches. (Default 1280) + """ + if batch_dims is None: + batch_dims = tuple([0] * len(shapes)) + if dtypes is None: + dtypes = tuple([torch.float] * len(shapes)) + + super().__init__(shapes, dtypes, batch_dims) + self.length = length self.pos = 0 self._buffer_num = None @@ -69,8 +88,8 @@ def set_data_buffer(self, buffer_num = 4): self._buffer_num = buffer_num for _ in range(self._buffer_num): datas = list() - for shape in self.shapes: - data = torch.randn(shape).cuda() + for shape, dtype in zip(self.shapes, self.dtypes): + data = torch.randn(shape, dtype=dtype).cuda() datas.append(data) self.datas.append(datas) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 3f88b8d7..d3b2e024 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -9,6 +9,19 @@ --master_port=8004 \ --use_env \ examples/mlp/linears.py + +OMP_NUM_THREADS=4 torchrun --standalone \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --rdzv_id=888 \ + --rdzv_backend=c10d \ + --rdzv_endpoint=worker0:8004 \ + examples/mlp/linears.py """ import torch @@ -58,7 +71,11 @@ def train(): model, input_shapes=([batch_size, dim],), ) - dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [batch_size, dim]) + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([batch_size, dim],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) @cube.compile(model, dataloader, PAS=PAS) def train_iter(model, dataloader): From cbeb5df8492b0de12ae722499b9ca917961d09e1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 06:05:45 +0000 Subject: [PATCH 0540/1892] fix inspector --- examples/inspector.py | 45 +++++++++++-------- .../{megatron_parallel.py => megatron.py} | 2 +- 2 files changed, 28 insertions(+), 19 deletions(-) rename examples/mlp/policy/{megatron_parallel.py => megatron.py} (97%) diff --git a/examples/inspector.py b/examples/inspector.py index 048487d8..d048ed51 100644 --- a/examples/inspector.py +++ b/examples/inspector.py @@ -9,6 +9,11 @@ --master_port=8004 \ --use_env \ examples/inspector.py + +OMP_NUM_THREADS=4 torchrun --standalone \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/inspector.py """ import torch import argparse @@ -19,13 +24,21 @@ from cube.profiler.memory import memory_summary from cube.profiler.timer import print_each_rank -L, N, E = (512, 8, 3072) # gpt -kBatchDims = [0, 0] -kDataShapes = ([N, L], [N, L]) +# L, N, E = (512, 8, 3072) +# kBatchDims = (0, 0) +# kDataShapes = ([N, L], [N, L]) +# kDTypes = (torch.float, torch.long) + +# mlp +kBatchDims = (0,) +kDataShapes = ([8192, 8192],) +kDTypes = (torch.float,) + # transformer -# kBatchDims = [1] +# kBatchDims = (1, ) # kDataShapes = ([512, 4, 3072],) +# kDTypes = (torch.float,) def load_module(filename: str): @@ -55,12 +68,10 @@ def load_train_fn(filename: str): def train(args): global kDataShapes - - # dataloader = cube.runtime.syndata.SynDataLoader( - # 1280, kBatchDims, *kDataShapes - # ) - dataloader = cube.runtime.syndata.SynTextDataLoader( - 1280, kBatchDims, *kDataShapes + global kDTypes + global kBatchDims + dataloader = cube.runtime.syndata.SynDataLoader( + kDataShapes, kDTypes, kBatchDims ) genfile = args.genfile.format(rank=torch.distributed.get_rank()) @@ -69,19 +80,19 @@ def train(args): optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - CudaTimer().warmup() + CudaTimer(enable=False).warmup() torch.distributed.barrier() iter_num = args.iter_num def train_iters(): for step in range(iter_num): if step >= 40: - CudaTimer().start('e2e') + CudaTimer(enable=True).start('e2e') train_fn(model, dataloader) optimizer.step() optimizer.zero_grad() if step == 1: - print('test passed') + print('passed 1 iteration') if step >= 40: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: @@ -95,11 +106,9 @@ def train_iters(): else: train_iters() - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-40) memory_summary() diff --git a/examples/mlp/policy/megatron_parallel.py b/examples/mlp/policy/megatron.py similarity index 97% rename from examples/mlp/policy/megatron_parallel.py rename to examples/mlp/policy/megatron.py index 0d05500a..c5a276fd 100644 --- a/examples/mlp/policy/megatron_parallel.py +++ b/examples/mlp/policy/megatron.py @@ -29,5 +29,5 @@ def PAS(graph: IRGraph, resource): sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) - print(graph.extra_repr()) + # print(graph.extra_repr()) return graph From 0bebcd454f9940d93b303248ee8f590569f51883 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 07:41:25 +0000 Subject: [PATCH 0541/1892] update dataloder partition --- cube/algorithm/ops/dataloader.py | 46 ++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py index 7fd4f803..44b61075 100644 --- a/cube/algorithm/ops/dataloader.py +++ b/cube/algorithm/ops/dataloader.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, List import copy from cube.algorithm.utils import split_axis @@ -6,9 +6,6 @@ from cube.graph.operator.operator import IRDataOperation -_kWaitDecision = None - - class DPDataLoader(GenericDistAlgo): def __init__(self, node: IRDataOperation): @@ -16,32 +13,41 @@ def __init__(self, node: IRDataOperation): if not isinstance(node, IRDataOperation): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) - self.batch_dims = node.get_batch_dims() - - self.chunk_num = _kWaitDecision def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - for bdim, shape in zip(self.batch_dims, self.output_shapes): - if chunk_num > 0 and shape[bdim] % chunk_num != 0: + """ + config = dict(dim=int) + num: int + number of chunks to partition + """ + for attr in ['num']: + if not attr in config: + raise KeyError("Expected idx, dim, num in the config") + node: IRDataOperation = self.node + num: int = config['num'] + dims: List[int] = node.get_batch_dims() + for dim, output in zip(dims, node.outputs()): + if output.shape[dim] % num != 0: return False return True - def instantiate(self, node, config: Dict): + def instantiate(self, config: Dict): if not self.satisfy(config): raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - sub_outputs = list() - for bdim, output in zip(self.batch_dims, node.outputs()): - sub_output = split_axis(output, bdim, self.chunk_num) - sub_outputs.append(sub_output) + node: IRDataOperation = self.node + num: int = config['num'] + dims: List[int] = node.get_batch_dims() + outputs = list() + for dim, output in zip(dims, node.outputs()): + output = split_axis(output, dim, num) + outputs.append(output) + nodes = list() - for sub_outs in zip(*sub_outputs): + for outs in zip(*outputs): node = IRDataOperation( - data_num = len(sub_outs), batch_dims = copy.copy(self.batch_dims)) - for idx, out in enumerate(sub_outs): + data_num=len(outs), batch_dims=copy.copy(dims)) + for idx, out in enumerate(outs): node.set_output(idx, out) nodes.append(node) return nodes From 179252505ae1b375411aae462006614a72493649 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 07:42:00 +0000 Subject: [PATCH 0542/1892] fix gelu in newer pytorch version --- cube/graph/operator/function/function.py | 21 +++++++++++++++++++++ cube/graph/parser/mapping.py | 2 +- cube/graph/parser/parser.py | 6 +++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 59bd5869..4be024e7 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -210,6 +210,27 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): return op +class GELU(Activation): + """ + torch.nn.functional.gelu(input, approximate: bool = False) + + Note `approximate` argument is new at pytorch version v1.11 + """ + def __init__(self, signature, inputs, name='gelu', **kwargs): + + super().__init__(signature, [inputs[0]], name) + if len(inputs) == 2: + self.kwargs['approximate'] = inputs[1] + self.set_input(0, inputs[0]) + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + if 'approximate' in self.kwargs: + inputs.append(self.kwargs['approximate']) + op = GELU(self.signature, inputs, self.name) + op.set_output(0, outputs[0]) + return op + + class Dropout(Activation): """ torch.nn.functional.dropout diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index d1feb6be..e7df2d25 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -44,7 +44,7 @@ def map(signature: str) -> IRFwOperation: __ftemplate('dropout') : function.Dropout, - __ftemplate('gelu') : partial(function.Activation, name='gelu'), + __ftemplate('gelu') : function.GELU, __ftemplate('layer_norm'): function.LayerNorm, diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 9bfe65b6..71a63797 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -165,8 +165,12 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: var_name = input.debugName() val = frame.get_var(var_name) input_vals.append(val) + try: + ir_node = Sign2Op.map(fsig)(inputs=input_vals, n_outputs=len(outputs)) + except Exception: + # print(module.code) + raise RuntimeError(f"Parsing error of {node}") - ir_node = Sign2Op.map(fsig)(inputs=input_vals, n_outputs=len(outputs)) if len(ir_node.outputs()) != len(outputs): raise RuntimeError( f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" From 2ba17b65acbd8d0750bcd2321aea177288d515fc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 07:42:37 +0000 Subject: [PATCH 0543/1892] fix allgather collectvie, add communication of allreduce --- cube/runtime/adapter/collectives.py | 2 ++ cube/runtime/adapter/reducer.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 7602d638..3d2d3c51 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -133,6 +133,8 @@ def all_gather(input_tensors: List[torch.Tensor], CudaTimer().start(field_name='comm') assert len(input_tensors) == 1 tensor = input_tensors[0] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() group = DeviceGroup().get_group(ranks) tensor_list = [torch.empty_like(tensor) for _ in ranks] idx = ranks.index(DeviceGroup().rank) diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 3968acc1..b59138da 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -6,7 +6,7 @@ import torch from cube.runtime.device import DeviceGroup - +from cube.profiler.timer import CudaTimer, print_each_rank class Reducer: @@ -35,6 +35,7 @@ def allreduce(self): # param.main_grad = param.grad # for each bucket, do all-reduce for tp in buckets: + CudaTimer().start(field_name='comm') bucket = buckets[tp] grads = [param.grad.data for param in bucket] coalesced = self._flatten_dense_tensors(grads) @@ -43,6 +44,7 @@ def allreduce(self): all_synced = self._unflatten_dense_tensors(coalesced, grads) for grad, synced in zip(grads, all_synced): grad.copy_(synced) + CudaTimer().stop(field_name='comm') def sync(self): """ From 70678c89073fc291a0a50d3aca10704f0c10c82c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 07:43:03 +0000 Subject: [PATCH 0544/1892] update ffn to data and hybrid tensor parallelism --- examples/ffn/ffn.py | 33 +++++++++---- examples/ffn/policy/data.py | 25 ++++++++++ examples/ffn/policy/data_parallel.py | 66 -------------------------- examples/ffn/policy/tensor.py | 32 +++++++++++++ examples/ffn/policy/tensor_parallel.py | 57 ---------------------- 5 files changed, 82 insertions(+), 131 deletions(-) create mode 100644 examples/ffn/policy/data.py delete mode 100644 examples/ffn/policy/data_parallel.py create mode 100644 examples/ffn/policy/tensor.py delete mode 100644 examples/ffn/policy/tensor_parallel.py diff --git a/examples/ffn/ffn.py b/examples/ffn/ffn.py index f432db2b..f822eb64 100644 --- a/examples/ffn/ffn.py +++ b/examples/ffn/ffn.py @@ -9,18 +9,30 @@ --master_port=8004 \ --use_env \ examples/ffn/ffn.py + +OMP_NUM_THREADS=4 torchrun --standalone \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/ffn/ffn.py + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --rdzv_id=888 \ + --rdzv_backend=c10d \ + --rdzv_endpoint=worker0:8004 \ + examples/ffn/ffn.py """ import torch import torch.nn.functional as F -import cube - -from examples.ffn.policy.data_parallel import transform_policy -from examples.ffn.policy.data_parallel import schedule_policy +import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank +from examples.ffn.policy.data import PAS + class FFN(torch.nn.Module): @@ -60,9 +72,13 @@ def train(): model, input_shapes=([L, N, E],), ) - dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([L, N, E],), + dtypes=(torch.float32,), + batch_dims=(1,) + ) - @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + @cube.compile(model, dataloader, PAS=PAS) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) @@ -71,12 +87,12 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - CudaTimer().warmup() + CudaTimer(enable=False).warmup() torch.distributed.barrier() iter_num = 128 for step in range(iter_num): if step >= 40: - CudaTimer().start('e2e') + CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() @@ -87,6 +103,7 @@ def train_iter(model, dataloader): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-40, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-40) if __name__ == '__main__': diff --git a/examples/ffn/policy/data.py b/examples/ffn/policy/data.py new file mode 100644 index 00000000..80200b62 --- /dev/null +++ b/examples/ffn/policy/data.py @@ -0,0 +1,25 @@ +from cube.graph import IRGraph +from cube.graph.operator import IRFwOperation, IRDataOperation + + +def PAS(graph: IRGraph, resource): + """ + Data Parallel + """ + # data operation + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(num=resource.ngpus)) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + batch_dim = node.get_batch_dims()[0] + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, config=dict(idx=0, dim=batch_dim, num=resource.ngpus) + ) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + return graph diff --git a/examples/ffn/policy/data_parallel.py b/examples/ffn/policy/data_parallel.py deleted file mode 100644 index 6201b4f1..00000000 --- a/examples/ffn/policy/data_parallel.py +++ /dev/null @@ -1,66 +0,0 @@ -from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation, IRDataOperation - - -def transform_policy(graph: IRGraph, resource): - - ndevs = resource.ngpus - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - assert len(fnodes) == 4 - - linear1 = fnodes[0] - gelu = fnodes[1] - linear2 = fnodes[2] - loss = fnodes[3] - - all_sub_nodes = list() - - algo = linear1.algorithms('data') - sub_nodes = graph.partition(linear1, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = gelu.algorithms('dim') - sub_nodes = graph.partition(gelu, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = linear2.algorithms('data') - sub_nodes = graph.partition(linear2, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = loss.algorithms('dim') - sub_nodes = graph.partition(loss, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - # data loader - dataloaders = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] - for data_op in dataloaders: - algo = data_op.algorithms('data') - sub_nodes = graph.partition(data_op, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - for sub_nodes in all_sub_nodes: - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - print(graph) - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - sugraph.partial_set_order(fsus, lazy=False) - return sugraph diff --git a/examples/ffn/policy/tensor.py b/examples/ffn/policy/tensor.py new file mode 100644 index 00000000..42e22fcd --- /dev/null +++ b/examples/ffn/policy/tensor.py @@ -0,0 +1,32 @@ +from cube.graph import IRGraph +from cube.graph.operator import IRFwOperation, IRDataOperation + + +def PAS(graph: IRGraph, resource): + """ + Hybrid parallel + """ + # data operation replication + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + # forward operation + configs = [ + dict(idx=1, dim=0, num=resource.ngpus), # linear col + dict(idx=0, dim=-1, num=resource.ngpus), # gelu col + dict(idx=0, dim=-1, num=resource.ngpus), # linear row + dict(idx=0, dim=-1, num=resource.ngpus), # sum + ] + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + assert len(fnodes) == len(configs) + for fnode, config in zip(fnodes, configs): + algo = fnode.algorithms('dim') + sub_nodes = graph.partition( + fnode, algo, config=config + ) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + return graph diff --git a/examples/ffn/policy/tensor_parallel.py b/examples/ffn/policy/tensor_parallel.py deleted file mode 100644 index 1d93c5b5..00000000 --- a/examples/ffn/policy/tensor_parallel.py +++ /dev/null @@ -1,57 +0,0 @@ -from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation - - -def transform_policy(graph: IRGraph, resource): - - ndevs = resource.ngpus - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - assert len(fnodes) == 4 - - linear1 = fnodes[0] - gelu = fnodes[1] - linear2 = fnodes[2] - loss = fnodes[3] - - all_sub_nodes = list() - - algo = linear1.algorithms('column') - sub_nodes = graph.partition(linear1, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = gelu.algorithms('dim') - sub_nodes = graph.partition(gelu, algo, config=dict(dim=2, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = linear2.algorithms('row') - sub_nodes = graph.partition(linear2, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - sub_nodes = graph.replicate(loss, times=ndevs) - all_sub_nodes.append(sub_nodes) - - for sub_nodes in all_sub_nodes: - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - print(graph) - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - sugraph.partial_set_order(fsus, lazy=False) - return sugraph From 78b2a7eaf4a6db45d5e659530a9eeab6ac0d5f23 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 07:43:22 +0000 Subject: [PATCH 0545/1892] update inspector with current impl --- examples/inspector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/inspector.py b/examples/inspector.py index d048ed51..4c6ceeea 100644 --- a/examples/inspector.py +++ b/examples/inspector.py @@ -119,7 +119,8 @@ def train_iters(): default='gencode{rank}.py') parser.add_argument('--iter-num', type=int, default=128) - parser.add_argument('--profile', dest='profile', action='store_true') + parser.add_argument('--profile', dest='profile', action='store_true', + help='use edge://tracing/ or chrome://tracing/ to open the file') args = parser.parse_args() cube.init() From efec9e99de4a8147e9d463240bfca8a85b3f7b81 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 07:57:13 +0000 Subject: [PATCH 0546/1892] refactor example code --- benchmark/swin/layers.py | 213 ----- benchmark/swin/swin_megatron.py | 790 ------------------ examples/case_study/logic/naive_linear.py | 89 -- examples/case_study/models/transformer.py | 225 ----- examples/case_study/policy/logical_code.py | 79 -- examples/case_study/policy/megatron_policy.py | 148 ---- examples/case_study/policy/policy.py | 90 -- examples/case_study/policy/recompute.py | 61 -- examples/case_study/policy/zero.py | 26 - examples/case_study/spatial_primitive.py | 208 ----- examples/case_study/temporal_primitive.py | 282 ------- .../ultimate/grad_accumulation_linear.py | 85 -- .../ultimate/model_partition_linear.py | 154 ---- .../case_study/ultimate/offload_linear.py | 149 ---- .../case_study/ultimate/parallel_linear.py | 266 ------ .../case_study/ultimate/pipeline_linear.py | 232 ----- .../case_study/ultimate/recompute_linear.py | 67 -- examples/case_study/ultimate/zero_linear.py | 215 ----- examples/{ffn => feedforward}/ffn.py | 8 +- examples/{ffn => feedforward}/policy/data.py | 0 .../{ffn => feedforward}/policy/tensor.py | 0 examples/poc/pipeline.py | 71 -- examples/poc/pipeline_space.py | 271 ------ examples/poc/space_size.py | 57 -- .../efficientnet/efficientnet.py | 0 .../efficientnet/schedule.py | 0 {examples => handcraft}/efficientnet/train.py | 4 +- {examples => handcraft}/efficientnet/utils.py | 0 {eval => handcraft/eval}/benchmark_gpt.sh | 0 .../eval}/swin_infer_bs1_224_782Mfp32.sh | 0 .../eval}/swin_infer_bs1_640_Gfp16.sh | 0 .../eval}/swin_infer_bs2_224_782Mfp32.sh | 0 .../eval}/swin_infer_bs2_640_Gfp16.sh | 0 .../eval}/swin_infer_bs4_640_Gfp16.sh | 0 {eval => handcraft/eval}/swin_scaleup.sh | 0 {eval => handcraft/eval}/swin_train_fp16.sh | 0 {eval => handcraft/eval}/swin_train_fp32.sh | 0 {benchmark => handcraft}/megatron/gpt.py | 6 +- {benchmark => handcraft}/megatron/layers.py | 0 {benchmark => handcraft}/megatron/linears.py | 6 +- .../megatron}/megatron_gpt_2.sh | 0 .../megatron/transformer.py | 2 +- .../swin/hybrid_schedule.py | 0 {examples => handcraft}/swin/layers.py | 0 {examples => handcraft}/swin/pmodule.py | 0 {examples => handcraft}/swin/schedule.py | 0 {examples => handcraft}/swin/swin_dt.py | 4 +- {examples => handcraft}/swin/swin_dwt.py | 4 +- .../swin/swin_dwt_infer.py | 6 +- {examples => handcraft}/swin/swin_flexflow.py | 4 +- {examples => handcraft}/swin/swin_hybrid.py | 12 +- {examples => handcraft}/swin/swin_pipe.py | 6 +- .../swin/swin_transformer.py | 0 53 files changed, 31 insertions(+), 3809 deletions(-) delete mode 100644 benchmark/swin/layers.py delete mode 100644 benchmark/swin/swin_megatron.py delete mode 100644 examples/case_study/logic/naive_linear.py delete mode 100644 examples/case_study/models/transformer.py delete mode 100644 examples/case_study/policy/logical_code.py delete mode 100644 examples/case_study/policy/megatron_policy.py delete mode 100644 examples/case_study/policy/policy.py delete mode 100644 examples/case_study/policy/recompute.py delete mode 100644 examples/case_study/policy/zero.py delete mode 100644 examples/case_study/spatial_primitive.py delete mode 100644 examples/case_study/temporal_primitive.py delete mode 100644 examples/case_study/ultimate/grad_accumulation_linear.py delete mode 100644 examples/case_study/ultimate/model_partition_linear.py delete mode 100644 examples/case_study/ultimate/offload_linear.py delete mode 100644 examples/case_study/ultimate/parallel_linear.py delete mode 100644 examples/case_study/ultimate/pipeline_linear.py delete mode 100644 examples/case_study/ultimate/recompute_linear.py delete mode 100644 examples/case_study/ultimate/zero_linear.py rename examples/{ffn => feedforward}/ffn.py (94%) rename examples/{ffn => feedforward}/policy/data.py (100%) rename examples/{ffn => feedforward}/policy/tensor.py (100%) delete mode 100644 examples/poc/pipeline.py delete mode 100644 examples/poc/pipeline_space.py delete mode 100644 examples/poc/space_size.py rename {examples => handcraft}/efficientnet/efficientnet.py (100%) rename {examples => handcraft}/efficientnet/schedule.py (100%) rename {examples => handcraft}/efficientnet/train.py (98%) rename {examples => handcraft}/efficientnet/utils.py (100%) rename {eval => handcraft/eval}/benchmark_gpt.sh (100%) rename {eval => handcraft/eval}/swin_infer_bs1_224_782Mfp32.sh (100%) rename {eval => handcraft/eval}/swin_infer_bs1_640_Gfp16.sh (100%) rename {eval => handcraft/eval}/swin_infer_bs2_224_782Mfp32.sh (100%) rename {eval => handcraft/eval}/swin_infer_bs2_640_Gfp16.sh (100%) rename {eval => handcraft/eval}/swin_infer_bs4_640_Gfp16.sh (100%) rename {eval => handcraft/eval}/swin_scaleup.sh (100%) rename {eval => handcraft/eval}/swin_train_fp16.sh (100%) rename {eval => handcraft/eval}/swin_train_fp32.sh (100%) rename {benchmark => handcraft}/megatron/gpt.py (96%) rename {benchmark => handcraft}/megatron/layers.py (100%) rename {benchmark => handcraft}/megatron/linears.py (96%) rename {benchmark => handcraft/megatron}/megatron_gpt_2.sh (100%) rename {benchmark => handcraft}/megatron/transformer.py (99%) rename {examples => handcraft}/swin/hybrid_schedule.py (100%) rename {examples => handcraft}/swin/layers.py (100%) rename {examples => handcraft}/swin/pmodule.py (100%) rename {examples => handcraft}/swin/schedule.py (100%) rename {examples => handcraft}/swin/swin_dt.py (99%) rename {examples => handcraft}/swin/swin_dwt.py (99%) rename {examples => handcraft}/swin/swin_dwt_infer.py (99%) rename {examples => handcraft}/swin/swin_flexflow.py (99%) rename {examples => handcraft}/swin/swin_hybrid.py (99%) rename {examples => handcraft}/swin/swin_pipe.py (99%) rename {examples => handcraft}/swin/swin_transformer.py (100%) diff --git a/benchmark/swin/layers.py b/benchmark/swin/layers.py deleted file mode 100644 index dda71ad9..00000000 --- a/benchmark/swin/layers.py +++ /dev/null @@ -1,213 +0,0 @@ -import torch -import torch.nn.functional as F -from torch.nn.parameter import Parameter - -from cube.profiler.timer import print_each_rank -from cube.runtime.resource import EnvResource - - -def _reduce(input_): - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) - if world_size == 1: - return input_ - group = EnvResource().tp_group - torch.distributed.all_reduce(input_, group=group) - return input_ - - -def _split(input_): - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) - rank = torch.distributed.get_rank(group=EnvResource().tp_group) - # Bypass the function if we are using only 1 GPU. - if world_size==1: - return input_ - last_dim = input_.dim() - 1 - last_dim_size = input_.size()[last_dim] // world_size - tensor_list = torch.split(input_, last_dim_size, dim=last_dim) - output = tensor_list[rank].contiguous() - return output - - -def _gather(input_): - """Gather tensors and concatinate along the last dimension.""" - - world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) - rank = torch.distributed.get_rank(group=EnvResource().tp_group) - # Bypass the function if we are using only 1 GPU. - if world_size==1: - return input_ - # Size and dimension. - last_dim = input_.dim() - 1 - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - group = EnvResource().tp_group - torch.distributed.all_gather(tensor_list, input_, group=group) - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() - return output - - -class ColumnInputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return input_ - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output) - - -class ColumnOutputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _gather(input_) - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output) - - -class RowInputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _split(input_) - - @staticmethod - def backward(ctx, grad_outputs): - return _gather(grad_outputs) - - -class RowOutputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _reduce(input_) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class ColumnParallelLinear(torch.nn.Module): - - def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False): - super().__init__() - self.input_size = input_size - self.output_size = output_size - self.full_input = full_input - self.full_output = full_output - - world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) - - # print_each_rank(f'> parallizing linear using column partition: ' - # f'{output_size} partitioned by {world_size} devices') - - # not if output size is smaller than world size, - # no parallel enbaled. Each device compute the same - if world_size > output_size: - world_size = 1 - - self.weight = Parameter(torch.empty( - int(self.output_size // world_size), - self.input_size, - )) - if bias: - self.bias = Parameter(torch.empty( - int(self.output_size // world_size), - )) - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def forward(self, input_): - bias = self.bias - if not self.full_input: - raise RuntimeError("Expected full tensor input") - input_parallel = ColumnInputAdapter.apply(input_) - output_parallel = F.linear(input_parallel, self.weight, bias) - if self.full_output: - output = ColumnOutputAdapter.apply(output_parallel) - else: - output = output_parallel - return output - - -class RowParallelLinear(torch.nn.Module): - - def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False): - super().__init__() - self.input_size = input_size - self.output_size = output_size - self.full_input = full_input - self.full_output = full_output - - world_size = torch.distributed.get_world_size(group=EnvResource().tp_group) - - # print_each_rank(f'> parallizing linear using row partition: ' - # f'{output_size} partitioned by {world_size} devices') - - # not if output size is smaller than world size, - # no parallel enbaled. Each device compute the same - if world_size > output_size: - world_size = 1 - - self.weight = Parameter(torch.empty( - self.output_size, - int(self.input_size // world_size), - )) - if bias: - self.bias = Parameter(torch.empty(self.output_size)) - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def forward(self, input_): - bias = self.bias - if self.full_input: - input_parallel = RowInputAdapter.apply(input_) - else: - input_parallel = input_ - output_parallel = F.linear(input_parallel, self.weight, bias) - if self.full_output: - output = RowOutputAdapter.apply(output_parallel) - else: - output = output_parallel - return output - - -class ShardEmbedding(torch.nn.Module): - - def __init__(self, num_embeddings, embedding_dim): - super().__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - - self.shard_num = torch.distributed.get_world_size(group=EnvResource().tp_group) - self.myshard = torch.distributed.get_rank(group=EnvResource().tp_group) - - shard_num_embeddings = self.num_embeddings // self.shard_num - self.vocab_start_index = shard_num_embeddings * self.myshard - self.vocab_end_index = self.vocab_start_index + shard_num_embeddings - - self.weight = torch.nn.Parameter( - torch.empty(shard_num_embeddings, self.embedding_dim) - ) - - def forward(self, input_): - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - output_parallel = F.embedding( - masked_input, self.weight, - None, None, 2., False, False - ) - output = RowOutputAdapter.apply(output_parallel) - return output diff --git a/benchmark/swin/swin_megatron.py b/benchmark/swin/swin_megatron.py deleted file mode 100644 index 9f3ab9e8..00000000 --- a/benchmark/swin/swin_megatron.py +++ /dev/null @@ -1,790 +0,0 @@ - -# -------------------------------------------------------- -# Modified from Swin-Transformer Repo -""" -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - benchmark/swin/swin_megatron.py -""" -# -------------------------------------------------------- - -from typing import Optional -import torch -import torch.nn as nn - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary - -import argparse - - -from benchmark.swin.layers import ColumnParallelLinear, RowParallelLinear - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class MegatronMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, full_input=True, full_output=False) - self.act = act_layer() - # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, full_input=False, full_output=True) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class MegatronWindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.global_num_heads = num_heads - group = cube.runtime.resource.EnvResource().tp_group - tp_world_size = torch.distributed.get_world_size(group=group) - if num_heads % tp_world_size != 0: - print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') - self.num_heads = num_heads // torch.distributed.get_world_size(group=group) - self.dim_heads = dim // self.global_num_heads - self.scale = qk_scale or self.dim_heads ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - # print(f'qkv embed dim: {dim}') - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False) - self.attn_drop = nn.Dropout(attn_drop) - # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, full_input=False, full_output=True) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - relative_position_bias = self.relative_position_bias_table[relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = MegatronWindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MegatronMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - x = x + drop_path(ffn, self.drop_path_p) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x): - for blk in self.blocks: - x = blk(x) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) - self.layers.append(layer) - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def forward_features(self, x): - x = self.patch_embed(x) - # if self.ape: - # x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - return x - - def forward(self, x): - # forward features - # x = self.forward_features(x) - x = self.patch_embed(x) - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - - x = self.head(x) - return x - - def flops(self): - flops = 0 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes - return flops - - -def train(args): - resource = cube.runtime.resource.EnvResource() - - # image batch input - N, C, H, W = [1, 3, 224, 224] - - # embed_dim, depths, num_heads, window_size = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 - # ] - - # 348.55 M - # embed_dim, depths, num_heads, window_size = [ - # 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 - # ] - - # 895.7 M Model - # embed_dim, depths, num_heads, window_size = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 - # ] - - # 2.01B model - embed_dim, depths, num_heads, window_size = [ - 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 - ] - - - model = SwinTransformer(embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size) - model = model.cuda() - memory_summary() - - # setup data parallel reducer - reducer = None - if args.dp > 1: - print('> initialize weight reducer') - reducer = resource.reducer - for param in model.parameters(): - reducer.add_param(param) - - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [N // args.dp, C, H, W]) - - def train_iter(model, dataloader): - img = next(dataloader) - loss = model(img) - loss = torch.sum(loss) - loss.backward() - if reducer is not None: - reducer.allreduce() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - # start training - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - - -if __name__ == '__main__': - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--tp', type=int, default=1, - help='tensor parallel size') - parser.add_argument('--dp', type=int, default=1, - help='data parallel size') - parser.add_argument('--pp', type=int, default=1, - help='pipeline parallel size') - parser.add_argument('--micro-bs', type=int, default=-1) - args = parser.parse_args() - - cube.init() - - # allocate resource - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - tp_size, tp_group_nums = args.tp, ndevs // args.tp - dp_size, dp_group_nums = args.dp, ndevs // args.dp - pp_size, pp_group_nums = args.pp, ndevs // args.pp - - if not pp_size * dp_size * tp_size == ndevs: - raise RuntimeError("Expected all devices are used") - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize data parallel group - all_data_parallel_group_ranks = list() - for i in range(pp_size): - start_rank = i * pp_group_nums - end_rank = (i + 1) * pp_group_nums - for j in range(tp_size): - ranks = list(range(start_rank + j, end_rank, tp_size)) - all_data_parallel_group_ranks.append(ranks) - # initialize groups - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - resource.dp_group = group - resource.reducer = cube.runtime.reducer.Reducer(ranks) - print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) - - # initialize pipelne parallel groups - for i in range(dp_size): - ranks = [data_parallel_group_ranks[i] - for data_parallel_group_ranks in all_data_parallel_group_ranks] - group = devs.get_group(ranks) - if myrank in ranks: - pp_ranks = ranks - resource.pp_group = group - print_each_rank(f'initialzed pipeline parallel group: {pp_ranks}', rank_only=myrank) - - # initialize tensor parallel groups - for i in range(tp_group_nums): - ranks = list(range(i * tp_size, (i + 1) * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - resource.tp_group = group - print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - train(args) diff --git a/examples/case_study/logic/naive_linear.py b/examples/case_study/logic/naive_linear.py deleted file mode 100644 index ef50e8d2..00000000 --- a/examples/case_study/logic/naive_linear.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -from torch.nn.parameter import Parameter -import math - -torch.manual_seed(121) - - -def linear(input, weight, bias=None): - output = torch._C._nn.linear(input, weight, bias) - return output - - -def apply_adam(params, grads, exp_avgs, exp_avg_sqs, steps, beta1, beta2, lr): - for i, param in enumerate(params): - - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step = steps[-1] - - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step - - exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8) - step_size = lr / bias_correction1 - param.addcdiv_(exp_avg, denom, value=-step_size) - - -if __name__ == '__main__': - - torch.cuda.set_device(0) - - # tensor definition - batch_size = 128 - out_features = 1024 - in_features = 1024 - - weight = torch.rand((out_features, in_features)).cuda().requires_grad_() - bias = torch.rand(out_features).cuda().requires_grad_() - input = torch.rand((batch_size, in_features)).cuda() - # print('weight: ', weight) - # print('bias: ', bias) - # print('input: ', input) - - ## Adam optimizer states -- 2x more weights volume - weight_exp_avg = torch.zeros_like( - weight, memory_format=torch.preserve_format - ) - weight_exp_avg_sq = torch.zeros_like( - weight, memory_format=torch.preserve_format - ) - bias_exp_avg = torch.zeros_like( - bias, memory_format=torch.preserve_format - ) - bias_exp_avg_sq = torch.zeros_like( - bias, memory_format=torch.preserve_format - ) - state_steps = list() - lr = 0.01 - beta1 = 0.5 - beta2 = 0.5 - - # iterations - for _ in range(4): - # ======= step1: forward ======= # - output = linear(input, weight, bias) - loss = torch.mean(output) - print(loss) - - # ======= step2: backward ======= # - loss.backward() - # print('weight grad: ', weight.grad.t()) - - # ======= step3: update ======= # - params = [weight, bias] - grads = [weight.grad, bias.grad] - exp_avgs = [weight_exp_avg, bias_exp_avg] - exp_avg_sqs = [weight_exp_avg_sq, bias_exp_avg_sq] - state_steps.append(len(state_steps)+1) - with torch.no_grad(): - apply_adam( - params, grads, exp_avgs, exp_avg_sqs, state_steps, - beta1, beta2, lr - ) - # zero out grad - weight.grad = None - bias.grad = None diff --git a/examples/case_study/models/transformer.py b/examples/case_study/models/transformer.py deleted file mode 100644 index 098a91a9..00000000 --- a/examples/case_study/models/transformer.py +++ /dev/null @@ -1,225 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F - - -class MultiHeadSelfAttention(nn.Module): - - def __init__(self, seq_len, embed_dim, heads, dropout): - super().__init__() - - self.seq_len = seq_len - self.embed_dim = embed_dim - self.num_heads = heads - self.dim_head = embed_dim // heads - self.scale = self.dim_head ** -0.5 - - self.weight_qkv = torch.nn.Parameter(torch.empty( - 3 * embed_dim, embed_dim - )) - self.weight_out = torch.nn.Parameter(torch.empty( - embed_dim, embed_dim - )) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, mask): - """ - x: [L, N, E]: seq_len, batch_size, embedding dimension - output: [L, N, E] - """ - bs = x.shape[1] - - # [L, N, E] -> [L, N, (3 * num_heads * dim_head)] - qkv = F.linear(x, self.weight_qkv, None) - # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 - qkv = qkv.chunk(3, dim=-1) - q, k, v = qkv - - # [L, N, (num_heads * dim_head)] -> [L, (N * num_heads), dim_head] - q = q.contiguous() - q = q.view(self.seq_len, (bs * self.num_heads), self.dim_head) - k = k.contiguous() - k = k.view(self.seq_len, (bs * self.num_heads), self.dim_head) - v = v.contiguous() - v = v.view(self.seq_len, (bs * self.num_heads), self.dim_head) - - # [L, N, (num_heads * dim_head)] -> [(N * num_heads), L, dim_head] - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - # [(N * num_heads), L, dim_head] -> [(N * num_heads), L, dim_head] - q = q * self.scale - # [(N * num_heads), L, dim_head] * [(N * num_heads), dim_head, L] - # -> [(N * num_heads), L, L] - attn = torch.bmm(q, k.transpose(-2, -1)) - - # [(N * num_heads), L, L] -> [N, num_heads, L, L] - attn = attn.view(bs, self.num_heads, self.seq_len, self.seq_len) - # [N, num_heads, L, L] -> [N, num_heads, L, L] - # attn += mask # pytorch official implementation - attn = attn.masked_fill_(mask, -100000.0) - # [N, num_heads, L, L] -> [(N * num_heads), L, L] - attn = attn.view((bs * self.num_heads), self.seq_len, self.seq_len) - - # [(N * num_heads), L, L] -> [(N * num_heads), L, L] - attn = F.softmax(attn, dim=-1) - - # [(N * num_heads), L, L] -> [(N * num_heads), L, L] - attn = self.dropout(attn) - # [(N * num_heads), L, L] * [(N * num_heads), L, dim_head] - # -> [(N * num_heads), L, dim_head] - output = torch.bmm(attn, v) - # [(N * num_heads), L, dim_head] -> [L, (N * num_heads), dim_head] - output = output.transpose(0, 1).contiguous() - # [L, (N * num_heads), dim_head] -> [L, N, (num_heads * dim_head)] - output = output.view(self.seq_len, bs, self.embed_dim) - # [L, N, (num_heads * dim_head)] * [(num_heads * dim_head), (num_heads * dim_head)] - # => [L, N, (num_heads * dim_head)] - output = F.linear(output, self.weight_out) - return output - - def _ref_forward(self, x, mask=True): - """ - X: [L, N, E]: seq_len, batch_size, embedding dimension - mask: whether to use mask - """ - if mask is not None: - ones = torch.ones( - (self.seq_len, self.seq_len), - device=torch.cuda.current_device() - ) - mask = torch.tril(ones) - mask = (mask < 0.5) - else: - mask = None - output, _ = F.multi_head_attention_forward( - query=x, - key=x, - value=x, - embed_dim_to_check=self.embed_dim, - num_heads=self.num_heads, - in_proj_weight=self.weight_qkv, - in_proj_bias=None, - bias_k = None, - bias_v = None, - add_zero_attn=False, - dropout_p=self.dropout.p, - out_proj_weight=self.weight_out, - out_proj_bias=None, - attn_mask=mask, - training=self.training, - need_weights=False, - ) - return output - - -class FFN(torch.nn.Module): - - def __init__(self, hidden_size: int): - super().__init__() - self.dense_h_to_4h = torch.nn.Linear( - hidden_size, 4 * hidden_size - ) - self.dense_4h_to_h = torch.nn.Linear( - 4 * hidden_size, hidden_size - ) - - def forward(self, hidden_states): - # [L, N, E] * [E, 4E] -> [L, N, 4E] - out = self.dense_h_to_4h(hidden_states) - # [L, N, 4E] -> [L, N, 4E] - out = F.gelu(out) - # [L, N, 4E] * [4E, E] -> [L, N, E] - out = self.dense_4h_to_h(out) - return out - - -class TransformerLayer(torch.nn.Module): - - def __init__(self, seq_len, hidden_size, head_num, dropout): - super().__init__() - # layer norm - self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - - self.attention = MultiHeadSelfAttention(seq_len, hidden_size, head_num, dropout) - self.attn_dropout = torch.nn.Dropout(dropout) - - self.mlp_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - self.mlp = FFN(hidden_size) - self.mlp_dropout = torch.nn.Dropout(dropout) - - def forward(self, hidden_states, attention_mask): - # Attention - in_attn_norm = self.input_layernorm(hidden_states) - attn_out = self.attention(in_attn_norm, attention_mask) - # residual - attn_out = self.attn_dropout(attn_out) - residual = attn_out + hidden_states - # MLP - in_mlp_norm = self.mlp_layernorm(residual) - mlp_out = self.mlp(in_mlp_norm) - # residual - mlp_out = self.mlp_dropout(mlp_out) - mlp_out = mlp_out + residual - return mlp_out - - -def get_attn_mask(batch_size: int, seq_len: int): - ones = torch.ones( - (batch_size, seq_len, seq_len), - device=torch.cuda.current_device() - ) - mask = torch.tril(ones) - mask = mask.view(batch_size, 1, seq_len, seq_len) - mask = (mask < 0.5) - return mask - - -def reset_parameter(model): - for param in model.parameters(): - torch.nn.init.uniform_(param) - - -def test_attention(): - L = 64 - N = 16 - E = 128 - n_heads = 8 - - model = MultiHeadSelfAttention(L, E, n_heads, dropout=0.0).cuda() - reset_parameter(model) - - x = torch.rand((L, N, E)).cuda() - mask = get_attn_mask(N, L).cuda() - - out = model(x, mask) - out_ref = model._ref_forward(x, mask) - - assert torch.allclose(out, out_ref) is True - print('test passed') - - -if __name__ == '__main__': - - L = 64 - N = 16 - E = 1024 - n_heads = 8 - - test_attention() - - model = TransformerLayer(L, E, n_heads, 0.5).cuda() - reset_parameter(model) - - x = torch.rand((L, N, E)).cuda() - mask = get_attn_mask(N, L).cuda() - - out = model(x, mask) - # print(out) - # print(out_ref) - # assert torch.allclose(out, out_ref) is True - print('Test passed') - module = torch.jit.script(model) - print(module.graph) - print(module.code) diff --git a/examples/case_study/policy/logical_code.py b/examples/case_study/policy/logical_code.py deleted file mode 100644 index fb03a747..00000000 --- a/examples/case_study/policy/logical_code.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F - -import argparse - -def sschedule(partial_dag, resources): pass -def tschedule(train_fn): pass -resources = None # available hardware resources - - -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): - super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, dim * mult), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(dim * mult, dim) - ) - - self.classifier = nn.Linear(dim, classes) - - def forward(self, data, label): - output = self.net(data) - output = self.classifier(output) - loss = F.cross_entropy(output, label) - return loss - - -def data_iter(gbs, dim, classes, length=1024, mbs=None): - mbs = mbs if mbs is not None else gbs - num_mb = gbs // mbs - for _ in range(length): - gbs_data = list() - gbs_label = list() - for _ in range(num_mb): - mbs_data = torch.randn((mbs, dim)) - mbs_label = torch.randint(0, classes, (mbs,)) - gbs_data.append(mbs_data) - gbs_label.append(mbs_label) - yield gbs_data, gbs_label - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--dim', type=int, default=1024) - parser.add_argument('--heads', type=int, default=16) - parser.add_argument('--gbs', type=int, default=64) - parser.add_argument('--mbs', type=int, default=4) - parser.add_argument('--classes', type=int, default=10) - args = parser.parse_args() - - model = FeedForward(args.dim, mult=args.heads, classes=args.classes) - # model = model.cuda() - - # spatial schedule - model = sschedule(model, resources) - # temporal schedule - @tschedule - def train_iter(data, label): - # forward - loss = model(data, label) - # backward - loss.backward() - # update - optimizer.step() - optimizer.zero_grad() - - optimizer = torch.optim.Adam( - model.parameters(), - lr=0.001, - betas=(0.9, 0.99), - weight_decay=0 - ) - - for (data, label) in data_iter(args.gbs, args.dim, args.classes, mbs=args.mbs): - train_iter(data, label) diff --git a/examples/case_study/policy/megatron_policy.py b/examples/case_study/policy/megatron_policy.py deleted file mode 100644 index 773dfeb6..00000000 --- a/examples/case_study/policy/megatron_policy.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import List - -from cube.schedule.su import SUType - - -def transform_policy(graph, resource): - - # suppose this is the policy config that both - # transformation and schedule policy know - tp_size = 8, - pp_size = 4, - dp_size = resource.ndev // (tp_size * pp_size) - num_micro_batch = 16 - - # each op is divided in (mp_dsize, dp_size) - # and put in (pp_size) stage - # TODO groups[stage][dp_group][tp_group] = devices (List[int]) - - # data + pipeline parallelism: first transform graph - for idx, op in enumerate(graph.nodes()): - algorithm = op.algorithm('data_parallel') - graph.partition( - op, algorithm, config=dict(chunk_size=num_micro_batch * dp_size) - ) - pp_stage = idx // (len(graph.nodes()) // pp_size) - op.tag('pp_stage', pp_stage) - - # data parallel - for op in graph.nodes(): - algorithm = op.algorithm('data_parallel') - graph.partition(op) - - # tensor parallel - # a transformer attention layer: - # [attention: col_split(mm + mm + mm) + row_split(mm)] - # a transformer feedforward layer: - # [feedforwrd: col_split(mm) + row_split(mm)] - for idx, op in enumerate(graph.nodes()): - # Attention block - # [1st linear -> 2nd linear) - if op_from_1st_to_2nd_linear(op): - # split column - tp_col_algo = op.logical_op.dist_algo(1) - graph.partition( - op = op, - algorithm = tp_col_algo, - config = dict(chunk_num=tp_size, uniform=True) - ) - # 2nd linear - elif op_is_2nd_linear(op): - # split row - tp_row_algo = op.logical_op.dist_algo(2) - graph.partition( - op = op, - algorithm = tp_row_algo, - config = dict(chunk_num=tp_size, uniform=True) - ) - # MLP block - # [3rd linear -> 4th linear] - elif op_from_3rd_to_4th_linear(op): - # split column - tp_col_algo = op.logical_op.dist_algo(1) - graph.partition( - op = op, - algorithm = tp_col_algo, - config = dict(chunk_num=tp_size, uniform=True) - ) - elif op_is_4th_linear(op): - # split row - tp_row_algo = op.logical_op.dist_algo(2) - graph.partition( - op = op, - algorithm = tp_row_algo, - config = dict(chunk_num=tp_size, uniform=True) - ) - return graph - - -def schedule_policy(su_graph, resource): - - # suppose this is the policy config that both - # transformation and schedule policy know - tp_size = 8, - pp_size = 4, - dp_size = resource.ndev // (tp_size * pp_size) - num_micro_batch = 16 - - # given tp, pp, dp, num mirco batch, set the device id - # for hierachical: [pipeline][data][tensor] = device (int) - dev_groups = set_device_id(tp_size, dp_size, pp_size, num_micro_batch) - - # put sus to forward-backward sequences: List[List[SU(op)]] - fb_op_sus = list() - for su in su_graph.sus(): - if su.stype == SUType.Forward or su.stype == SUType.Backward: - for fb_seq in fb_op_sus: - if fb_seq[-1].happen_before(su): - fb_seq.append(su) - break - else: - fb_op_sus.append([su]) - - # merge to stages: List[List[SU(stage sequential of ops)]] - fb_stage_sus = list() - assert len(fb_op_sus) == tp_size * dp_size * num_micro_batch - for dp in range(dp_size): - for tp in range(tp_size): - fb_stage_sus.append([]) - fb_sus = fb_op_sus[dp * dp_size + tp] - for idx, su in enumerate(fb_sus): - pp = idx // ( len(fb_sus) // pp_size) - device = dev_groups[pp][dp][tp] - su_graph.assign(su, device) - merged_su = None - for su in fb_sus: - if merged_su is None: - merged_su = su - fb_stage_sus[-1].append([su]) - else: - # same device op can be merged - merged_su = su_graph.merge(merged_su, su) - - num_stage = pp_size - f = lambda stage, micro_batch_id: fb_stage_sus[micro_batch_id][stage] - b = lambda stage, micro_batch_id: fb_stage_sus[micro_batch_id][num_stage + stage] - - sequence = list() - - # warmup: - for stage in range(num_stage): - for mid in range(stage): - sequence.append(f(stage, mid)) - - # steady + cooldown: - for mid in range(num_micro_batch): - # enqueue backward - for stage in range(num_stage-1, -1, -1): - sequence.append(b(stage, mid)) - # enqueue forward - for stage in range(num_stage): - f_mid = mid + 1 + num_stage - stage - if f_mid >= num_micro_batch: - continue - sequence.append(f(stage, f_mid)) - - # infor system the control dependency by topological assignment - su_graph.set_order(sequence) - return su_graph diff --git a/examples/case_study/policy/policy.py b/examples/case_study/policy/policy.py deleted file mode 100644 index bb52ab6a..00000000 --- a/examples/case_study/policy/policy.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch - -def select(tensor, indices, val_map_op=None, shape=None): pass - -def input_adapter(inputs, target): pass - -def iter_op(DAG): pass -def generate_for_each_rank(pDAG): pass - - -def sschedule_dp(pDAG, resources, input_tensors): - """ - Data Parallel Description - - Args: - * pDAG: (partial) logical computation graph - * Resources: Environment inlcuding devices, network topology etc - Returns: - * pDAGs (list[DAG]) execution (local & physical) DAG for each rank - """ - # rank [0,1,..., pp_size-1], [pp_size, ..., 2*pp_size - 1], ... - ndevs = resources.ndevs - # suppose 8 devices, 4 for pipeline, 2 for data parallel - dp_size = 2 - pp_size = 4 - for op in iter_op(pDAG): - for op_id, dist_op in enumerate(op.dist_candidates()): - # find the data parallelism - if is_data_parallelism(dist_op): - for tensor in dist_op.inputs + dist_op.outputs: - if isinstance(tensor.segment, SplitAxis): - # pipeline micro-batch = 4 - tensor.segment.chunk_num = dp_size * 4 - # translate to logical tensor segments - tensor.segment.translate() - dist_op.generate_ops() - # setup placement - stage = op_id // (len(pDAG) // pp_size) - for dp_id, sub_op in enumerate(dist_op.ops): - sub_op.device = (dp_id % dp_size) * pp_size + stage - # materialize -- call to the deploy - dist_op.materialize() - # generate input adapter - pDAG.replace(op, dist_op) - break - return pDAG - - -def tschedule_1f1b(actions, relations, resources): - """ - Pipeline 1f1b policy description -- each device order - - Actions: a list of actions - - relations: list[(Action1, Action2)]: a list of tuples indicate partial order - """ - num_stage = resources.n_gpus - num_microbatch = len(actions) / 2 / num_stage - - f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] - b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] - - # action in-device order - stage_order = list() - - for stage in range(num_stage): - order = list() - num_warmup_microbatch = num_stage - stage - 1 - num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch) - num_microbatch_remain = num_microbatch - num_warmup_microbatch - - # warmup - for mid in range(num_warmup_microbatch): - order.append(f(stage, mid)) - - # steady - for i in range(num_microbatch_remain): - f_mid = num_warmup_microbatch + i - b_mid = i - order.append(f(stage, f_mid)) - order.append(b(stage, b_mid)) - - # cooldown - for i in range(num_warmup_microbatch): - b_mid = num_microbatch_remain + i - order.append(b(stage, b_mid)) - - stage_order.append(order) - - return stage_order diff --git a/examples/case_study/policy/recompute.py b/examples/case_study/policy/recompute.py deleted file mode 100644 index 24ff55b4..00000000 --- a/examples/case_study/policy/recompute.py +++ /dev/null @@ -1,61 +0,0 @@ -from cube.schedule.su import SUType - -def choose_input(op, input_incarnation): pass -def choose_output(op, output_incarnation): pass -def create_incar(tensor_or_op): pass - - -def transformation_policy(graph, resource): - - def _recompute_ops(graph, ops): - """ - PyTorch Checkpointing - """ - tensors_incar = list() - ops_incar = list() - - for op in ops[:-1]: - op_incar = graph.create_incar(op) - ops_incar.append(op_incar) - for output in op.outputs(): - tensor_incar = graph.create_incar(output) - tensors_incar.append(tensor_incar) - graph.choose_output(ops_incar, tensor_incar) - for op in ops_incar[1:]: - for input in op.inputs(): - for input_incar in input.get_incar(): - if input_incar in tensors_incar: - graph.choose_input(op, input_incar) - # else keep in memory - for op in ops[1:]: - for idx, output in enumerate(op.outputs()): - succ_ops = graph.successors(op, idx) - succ_ops = [ - op for op in succ_ops if op.type == SUType.Backward - ] - for succ_op in succ_ops: - for input in succ_op.inputs(): - for input_incar in input.get_incar(): - if input_incar in tensors_incar: - graph.choose_input(succ_op, input_incar) - - # checkpointing tensor - chunk_num = 4 - # forward ops - fops = [node for node in graph.nodes() if node.type == SUType.Forward] - chunk_size = int(len(fops) // chunk_num) - for cid in range(chunk_num): - chunk_fops = fops[chunk_size * cid, chunk_size * (cid + 1)] - _recompute_ops(graph, chunk_fops) - - -def schedule_policy(sugraph, resource): - - for su in sugraph.sus(): - sugraph.assign(su, 0) - if su.is_incarnation(): - succ_sus = sugraph.successors(su) - for succ_su in succ_sus: - if sugraph.merge(su, succ_su): - break - sugraph.set_order(sugraph.random_topo_order()) diff --git a/examples/case_study/policy/zero.py b/examples/case_study/policy/zero.py deleted file mode 100644 index 960d024f..00000000 --- a/examples/case_study/policy/zero.py +++ /dev/null @@ -1,26 +0,0 @@ -from cube.schedule.su import SUType - -def transformation_policy(graph, resource): - - for op in graph.nodes(): - if op.type == SUType.Forward: - algorithm = op.algorithms('data_parallelism') - sub_graph = graph.partition(op, algorithm, config=dict(chunk_size=resource.ngpus)) - if op.type == SUType.Optimizer: - algorithm = op.algorithms('split_axis_0') - sub_graph = graph.partition(op, algorithm, config=dict(chunk_size=resource.ngpus)) - - return graph - - -def schedule_policy(sugraph, resource): - - semantic_ops = dict() - for su in sugraph.sus(): - if su.nodes(0).semantic_ops not in semantic_ops: - semantic_ops[su.nodes(0).semantic_ops] = list() - semantic_ops[su.nodes(0).semantic_ops].append(su) - for semantic_op in semantic_ops: - for idx, su in enumerate(semantic_ops[semantic_op]): - gpu_id = idx % resource.ngpus - sugraph.assign(su, gpu_id) diff --git a/examples/case_study/spatial_primitive.py b/examples/case_study/spatial_primitive.py deleted file mode 100644 index e4829469..00000000 --- a/examples/case_study/spatial_primitive.py +++ /dev/null @@ -1,208 +0,0 @@ -import torch -import os - -from functools import partial - -torch.manual_seed(121) - -class LogicalOp: pass -class PhyiscalOp: pass - -class LogicalTensor: pass -class PhyiscalTensor: pass - -# select from logical tensor with indices -> generate a logical tensor -def select(tensor: LogicalTensor, indices, val_map_op, shape) -> LogicalTensor: pass - -# deploy logical tensor to devices -def deploy(tensor: LogicalTensor, ranks) -> list(PhyiscalTensor): pass - -# merge logical tensors at `ranks` devices -def merge(tensor: LogicalTensor, ranks, val_reduce_op): pass - -# tensor movement: move physical tensor to rank -def move(tensor: PhyiscalTensor, rank): pass - -# tensor release: release the data in physical tensor inside tensor -def release(tensor: PhyiscalTensor): pass - -# tensor re-genrewate: bring back the data for the physical tensor -def generate(tensor: PhyiscalTensor, rank): pass - - - -## =============== tensor parallelism on matmul ============== ## - -def all_gather(tensors, dim): pass - - -# the logical op linear: -def linear(inputs, weight, bias) -> LogicalTensor: pass - - -def linear_tensor_parallel(inputs, weight, bias, output): - """ - inputs: (M, K) - weight: (N, K) - bias: (N,) - output: (M, N) - - Perform: (M, K) * (\delta N, K) + (\delta N,) = (M, \delta N) - """ - - M = 1024 - K = 1024 - N = 1024 - - # Tensor split -- system + policy generated - inputs = select( - tensor = inputs, - indices = (slice(0, M), slice(0, K)), - val_map_op = None, - shape = (M, K) - ) - - weights, biases, outputs = list(), list(), list() - for cid in range(4): - weights.append(select( - tensor = weight, - indices = (slice(cid * (N // 4), (cid + 1) * (N // 4)), slice(0, K)), - val_map_op = None, - shape = (N // 4, K) - )) - - biases.append(select( - tensor = bias, - indices = (slice(cid * (N // 4), (cid + 1) * (N // 4)),), - val_map_op = None, - shape = (N // 4,) - )) - - outputs.append(select( - tensor = output, - indices = (slice(slice(0, M), cid * (N // 4), (cid + 1) * (N // 4)),), - val_map_op = None, - shape = (M, N // 4) - )) - - # Algorithm -- Expert specified - for weight, bias, output in enumerate(zip(weights, biases, outputs)): - # physical tensor - chunk = torch._C._nn.linear(inputs, weight, bias) - # physical tensor fill in to logical tensor - output.fill(chunk) - - # Tensor deployment -- system + policy generated - inputs = deploy( - segment = inputs, - ranks = [0, 1, 2, 3] - ) - - for rank, (weight, bias) in enumerate(zip(weights, biases)): - weight = deploy( - segment = weight, - ranks = [rank], - ) - bias = deploy( - segment = bias, - ranks = [rank], - ) - - # Logical tensor merge -- system + policy generated - merge( - tensor = outputs, - ranks = [0, 1, 2, 3], - merge_op = partial(all_gather, dim=1) - ) - - -def linear_tensor_parallel_space(inputs, weight, bias, output): - """ - inputs: (M, K) - weight: (N, K) - bias: (N,) - output: (M, N) - - Perform: (M, K) * (\delta N, K) + (\delta N,) = (M, \delta N) - """ - - # no split - def Full(): pass - # split at axis - def SplitAxis(axis, chunk_num, overlap): pass - - # add constraints for inter-tensors - def add_constraint(condition): pass - - # ========= segmentation constraints ===========# - inputs.segment = Full() - weight.segment = SplitAxis( - axis=0, chunk_num=None, overlap=0 - ) - bias.segment = SplitAxis( - axis=0, chunk_num=None, overlap=0 - ) - add_constraint(bias.segment.chunk_num == weight.segment.chunk_num) - - output.segment = SplitAxis( - axis=1, chunk_num=None, overlap=0 - ) - add_constraint(output.segment.chunk_num == weight.layout.chunk_num) - - # ========= distributed algorithms ============# - for pweight, pbias, pout in zip(weight, bias, output): - pout.fill(linear(inputs, pweight, pbias)) - return output - - -## =============== tensor movement / re-generation ============== ## - -def custom_op(forward_fn, backward_fn): pass - -def offload(inputs: PhyiscalTensor, weights: list(PhyiscalTensor), ops: list(PhyiscalOp)): - """ - offload a feature_map after forward the 3rd op - retrieve (prefetch) the feature_map after backward the 5th op - """ - feature_maps = [inputs] - offload_step = 2 - retrieve_step = 4 - for step, (weight, op) in enumerate(zip(weights, ops)): - tensor = feature_maps[-1] - # retrieve - if step == retrieve_step: - feature_maps[-1] = custom_op( - forward_fn=partial((lambda input: input), input=feature_maps[-1]), - backward_fn=partial(move, feature_maps[offload_step + 1], rank=0) - ) - # op calculation - out = op(tensor, weight) - # offload - if step == offload_step: - move(tensor, rank=-1) - feature_maps.append(out) - - -def checkpoint(inputs: PhyiscalTensor, weights: list(PhyiscalTensor), ops: list(PhyiscalOp)): - """ - checkpoint a feature_map after forward the 3rd op - re-generate (possible for packing with other operator) after backward the 5th op - """ - feature_maps = [inputs] - release_step = 2 - recompute_step = 4 - released_tensor = None - for step, (weight, op) in enumerate(zip(weights, ops)): - tensor = feature_maps[-1] - # retrieve - if step == recompute_step: - feature_maps[-1] = custom_op( - forward_fn=partial((lambda input: input), input=feature_maps[-1]), - backward_fn=partial(generate, feature_maps[release_step + 1], rank=0) - ) - # op calculation - out = op(tensor, weight) - # offload - if step == release_step: - release(tensor) - feature_maps.append(out) diff --git a/examples/case_study/temporal_primitive.py b/examples/case_study/temporal_primitive.py deleted file mode 100644 index 1245c890..00000000 --- a/examples/case_study/temporal_primitive.py +++ /dev/null @@ -1,282 +0,0 @@ -import torch - -from functools import partial - - -def select(tensor, indices, val_map_op=None, shape=None): - pass - -## Abstractions and Primitivse ## - -class Action: pass - -def execute(action, **kwargs): - # action instance will automatically take flow-in results - # and select the chunked kwargs - return action(**kwargs) - -def add_flow(*actions): - # this will set all input actions with same flow-id - pass - - -## System Runtime units ## - -def run(schedule, num_microbs, *args): - """ - Take a list of actions and execute in list order - """ - myrank = torch.distributed.get_rank() - chunked_args = list() - for arg in args: - if torch.is_tensor(arg): - chunk_size = data.size(0) / num_microbs - arg = [ - select(arg, slice(chunk_size * 0, chunk_size * 1)), - select(arg, slice(chunk_size * 1, chunk_size * 2)), - select(arg, slice(chunk_size * 2, chunk_size * 3)), - select(arg, slice(chunk_size * 3, chunk_size * 4)) - ] - chunked_args.append(arg) - for action in schedule: - if action.device == myrank: - # wait for cross-device dependency (if have) - action.wait() - # execute - outs = execute(action, *tuple(args)) - return outs - -def check_consistency(sequence, actions, relations): pass - - -# Schedule example - -def naive_schedule(actions: list(Action), relations: set((Action, Action))) -> list(Action): - """ - Args: - actions: order specified by AI scientist (the reference semantic) - relations: set of action dependencies (action1, action2): action1 -> action2 - - Returns: - a execution sequence following the abstraction - """ - # placement - for action in actions: - action.device = 0 - # execution sequence - sequence = actions - return sequence - - -def pipeline_1f1b_schedule(actions, relations): - """ - Pipeline 1f1b policy description -- generate a sequence - - Actions: a list of actions - - relations: list[(Action1, Action2)]: a list of tuples indicate partial order - """ - - # suppose input actions are forward and backward of grad accumulation - # suppose in forward -> ... -> forward -> backward -> ... -> backward - num_stage = torch.distributed.get_world_size() - num_microbatch = len(actions) / 2 / num_stage - - f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] - b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] - - # action placement - for stage in range(num_stage): - for mid in range(num_microbatch): - f(stage, mid).device = torch.device.cuda(stage) - b(stage, mid).device = torch.device.cuda(stage) - - sequence = list() - - # warmup: - for stage in range(num_stage): - for mid in range(stage): - sequence.append(f(stage, mid)) - - # steady + cooldown: - for mid in range(num_microbatch): - # enqueue backward - for stage in range(num_stage-1, -1, -1): - sequence.append(b(stage, mid)) - # enqueue forward - for stage in range(num_stage): - f_mid = mid + 1 + num_stage - stage - if f_mid >= num_microbatch: - continue - sequence.append(f(stage, f_mid)) - assert check_consistency(sequence, actions, relations) - return sequence - - -def pipeline_1f1b_schedule(actions, relations): - """ - Pipeline 1f1b policy description -- each device order - - Actions: a list of actions - - relations: list[(Action1, Action2)]: a list of tuples indicate partial order - """ - num_stage = torch.distributed.get_world_size() - num_microbatch = len(actions) / 2 / num_stage - - f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] - b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] - - # action placement - for stage in range(num_stage): - for mid in range(num_microbatch): - f(stage, mid).device = torch.device.cuda(stage) - b(stage, mid).device = torch.device.cuda(stage) - - # action in-device order - stage_order = list() - - for stage in range(num_stage): - order = list() - num_warmup_microbatch = num_stage - stage - 1 - num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch) - num_microbatch_remain = num_microbatch - num_warmup_microbatch - - # warmup - for mid in range(num_warmup_microbatch): - order.append(f(stage, mid)) - - # steady - for i in range(num_microbatch_remain): - f_mid = num_warmup_microbatch + i - b_mid = i - order.append(f(stage, f_mid)) - order.append(b(stage, b_mid)) - - # cooldown - for i in range(num_warmup_microbatch): - b_mid = num_microbatch_remain + i - order.append(b(stage, b_mid)) - - stage_order.append(order) - - assert check_consistency(stage_order, actions, relations) - return stage_order - - - -if __name__ == '__main__': - - # define logical model / optimizer / data loader - class LogicalModel: pass - class Optimizer: pass - class DataLoader: pass - compute_loss = lambda output, label : output - - - model = LogicalModel() - optimizer = Optimizer(model.parameters()) - dataloader = DataLoader(bs=1024) - - for epoch in range(100): - for step, (data, label) in enumerate(dataloader): - # enqueue forward specfied by schedule and execute the first one - output = model(data) - # accessing partial output data without generation will rase warning - # pop forward until to generate the backward tensor - loss = compute_loss(output, label) - loss.backward() - - # loss = schedule(data=data) - optimizer.step() - # lr_scheduler.step() - optimizer.zero_grad() - print(loss) - - if (epoch + 1) % 4 == 0: - model.eval() - # evaluation - - - - -# ======== example sequences for all kinds of configuration ============= - -forward = lambda model, data: model(data) -backward = lambda grad, output: output.backward(grad) -update_gradient = lambda model, grad: model.update(grad) - - -def train_iter_grad_accumulate(model, datas, stage=2, micro_bs=4): - - out_s0_d0 = forward(model[0], datas[0]) - out_s1_d0 = forward(model[1], out_s0_d0) - grad_s1_d0 = backward(out_s1_d0) - grad_s0_d0 = backward(out_s0_d0, grad=grad_s1_d0) - - out_s0_d1 = forward(model[0], datas[1]) - out_s1_d1 = forward(model[1], out_s0_d1) - grad_s1_d1 = backward(out_s1_d1) - grad_s0_d1 = backward(out_s0_d0, grad=grad_s1_d1) - - out_s0_d2 = forward(model[0], datas[2]) - out_s1_d2 = forward(model[1], out_s0_d2) - grad_s1_d2 = backward(out_s1_d2) - grad_s0_d2 = backward(out_s0_d0, grad=grad_s1_d2) - - out_s0_d3 = forward(model[0], datas[3]) - out_s1_d3 = forward(model[1], out_s0_d3) - grad_s1_d3 = backward(out_s1_d3) - grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) - - update_gradient(model[0], model[0].weights.grad) - update_gradient(model[1], model[1].weights.grad) - - -def train_iter_1f1b(model, datas, stage=2, micro_bs=4): - - out_s0_d0 = forward(model[0], datas[0]) - out_s1_d0 = forward(model[1], out_s0_d0) - grad_s1_d0 = backward(out_s1_d0) - - out_s0_d1 = forward(model[0], datas[1]) - grad_s0_d0 = backward(out_s0_d0, grads=grad_s1_d0) - out_s1_d1 = forward(model[1], out_s0_d1) - grad_s1_d1 = backward(out_s1_d1) - - out_s0_d2 = forward(model[0], datas[2]) - grad_s0_d1 = backward(out_s0_d0, grad=grad_s1_d1) - out_s1_d2 = forward(model[1], out_s0_d2) - grad_s1_d2 = backward(out_s1_d2) - - out_s0_d3 = forward(model[0], datas[3]) - grad_s0_d2 = backward(out_s0_d0, grad=grad_s1_d2) - out_s1_d3 = forward(model[1], out_s0_d3) - grad_s1_d3 = backward(out_s1_d3) - update_gradient(model[1], model[1].weights.grad) - - grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) - update_gradient(model[0], model[0].weights.grad) - - -def train_iter_gpipe(model, datas, stage=2, micro_bs=4): - - out_s0_d0 = forward(model[0], datas[0]) - out_s1_d0 = forward(model[1], out_s0_d0) - out_s0_d1 = forward(model[0], datas[1]) - out_s1_d1 = forward(model[1], out_s0_d1) - out_s0_d2 = forward(model[0], datas[2]) - out_s1_d2 = forward(model[1], out_s0_d2) - out_s0_d3 = forward(model[0], datas[3]) - out_s1_d3 = forward(model[1], out_s0_d3) - - grad_s1_d0 = backward(out_s1_d0) - grad_s0_d0 = backward(out_s0_d0, grad=grad_s1_d0) - grad_s1_d1 = backward(out_s1_d1) - grad_s0_d1 = backward(out_s0_d0, grad=grad_s1_d1) - grad_s1_d2 = backward(out_s1_d2) - grad_s0_d2 = backward(out_s0_d0, grad=grad_s1_d2) - grad_s1_d3 = backward(out_s1_d3) - update_gradient(model[1], model[1].weights.grad) - grad_s0_d3 = backward(out_s0_d0, grad=grad_s1_d3) - update_gradient(model[0], model[0].weights.grad) diff --git a/examples/case_study/ultimate/grad_accumulation_linear.py b/examples/case_study/ultimate/grad_accumulation_linear.py deleted file mode 100644 index 659fcf7e..00000000 --- a/examples/case_study/ultimate/grad_accumulation_linear.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -import math - -torch.manual_seed(121) - - -def apply_adam(params, grads, exp_avgs, exp_avg_sqs, steps, beta1, beta2, lr): - for i, param in enumerate(params): - - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step = steps[-1] - - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step - - exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8) - step_size = lr / bias_correction1 - param.addcdiv_(exp_avg, denom, value=-step_size) - - - -if __name__ == '__main__': - - global_bs = 128 - bs = 32 - feats = 1024 - - weight = torch.randn((feats, feats)).cuda().requires_grad_() - bias = torch.randn((feats,)).cuda().requires_grad_() - - ## Adam optimizer states -- 2x more weights volume - weight_exp_avg = torch.zeros_like( - weight, memory_format=torch.preserve_format - ) - weight_exp_avg_sq = torch.zeros_like( - weight, memory_format=torch.preserve_format - ) - bias_exp_avg = torch.zeros_like( - bias, memory_format=torch.preserve_format - ) - bias_exp_avg_sq = torch.zeros_like( - bias, memory_format=torch.preserve_format - ) - state_steps = list() - lr = 0.01 - beta1 = 0.5 - beta2 = 0.5 - - inputs = [torch.randn((bs, feats)).cuda() for _ in range(16)] - # inputs = [torch.randn((bs, feats)).cuda()] * 16 # for debug - - update_interval = int(global_bs / bs) - tic = 0 - for input_data in inputs: - tic += 1 - - # ======= step1: forward ======= # - out = torch._C._nn.linear(input_data, weight, bias) - loss = torch.mean(out) / update_interval ## loss also need scale - print('loss: {}'.format(loss)) - - # ======= step2: backward ======= # - loss.backward() - # Note: during backward, PyTorch will do tensor.grad += computed_grad - # if tensor had gradient, this will do accumulation by default. - - # ======= step3: update ======= # - if tic % update_interval == 0: - params = [weight, bias] - grads = [weight.grad, bias.grad] - exp_avgs = [weight_exp_avg, bias_exp_avg] - exp_avg_sqs = [weight_exp_avg_sq, bias_exp_avg_sq] - state_steps.append(len(state_steps)+1) - with torch.no_grad(): - apply_adam( - params, grads, exp_avgs, exp_avg_sqs, state_steps, - beta1, beta2, lr - ) - # zero out grad - weight.grad = None - bias.grad = None diff --git a/examples/case_study/ultimate/model_partition_linear.py b/examples/case_study/ultimate/model_partition_linear.py deleted file mode 100644 index 75b262d3..00000000 --- a/examples/case_study/ultimate/model_partition_linear.py +++ /dev/null @@ -1,154 +0,0 @@ -"""Example Usage - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - examples/case_study/ultimate/model_partition_linear.py -""" - -import torch -from torch import nn -import os - -class Linears(nn.Module): - """ - Note in model creation, it will only construct model chunks - that belong to this rank - """ - - def __init__(self, features, op_num=4): - super().__init__() - self.ops = nn.ModuleList([]) - - myrank = torch.distributed.get_rank() - ngpus = torch.distributed.get_world_size() - op_num_per_rank = int(op_num / ngpus) - - for _ in range(op_num_per_rank): - self.ops.append(nn.Linear(features, features)) - - def forward(self, x): - out = x - for op in self.ops: - out = op(out) - return out - - -def is_last_stage(): - return torch.distributed.get_rank() == torch.distributed.get_world_size() - 1 - -#================= WhatToDO functions ==================# - -def forward_step(model, input_tensor): - output_tensor = model(input_tensor) - # last stage: calcuate loss - if is_last_stage(): - output_tensor = torch.sum(output_tensor) - print('loss: {}'.format(output_tensor)) - return output_tensor - - -def backward_step(input_tensor, output_tensor, output_tensor_grad): - """ - Calculate input tensor gradient - """ - if input_tensor is not None and input_tensor.requires_grad: - input_tensor.retain_grad() - torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) - input_tensor_grad = None - if input_tensor is not None and input_tensor.requires_grad: - input_tensor_grad = input_tensor.grad - return input_tensor_grad - -#================= WhatToDO functions ==================# - -#================= Between Stage functions ==================# - -def send(tensor, to_rank): - """ - send tensor to the target rank - """ - if to_rank < 0 or to_rank >= torch.distributed.get_world_size(): - return None - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, to_rank - ) - reqs = torch.distributed.batch_isend_irecv([send_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - - -def recv(shape, from_rank, boundary_tensor): - if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): - return boundary_tensor - tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device() - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, from_rank - ) - reqs = torch.distributed.batch_isend_irecv([recv_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - return tensor - -#================= Between Stage functions ==================# - - -#================= Scheduling ==================# - -def scheduling_naive(model, inputs, bs, feats): - - myrank = torch.distributed.get_rank() - - # ================ forward pass ================ # - # recv input data - input_tensor = recv(torch.Size([bs, feats]), myrank-1, inputs) - # forward - output_tensor = forward_step(model, input_tensor) - # send forward - send(output_tensor, myrank+1) - - # ================ backward pass ================ # - # recv backward - output_tensor_grad = recv(torch.Size([bs, feats]), myrank+1, None) - # backward - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad) - # send backward - send(input_tensor_grad, myrank-1) - - # ================ weight update ================ # - # xxx - -#================= Scheduling ==================# - - -if __name__ == '__main__': - - # initialize distributed env - local_rank = int(os.environ.get('LOCAL_RANK')) - torch.cuda.set_device(local_rank) - torch.distributed.init_process_group( - backend='nccl', - init_method='env://', - ) - myrank = torch.distributed.get_rank() - - bs = 32 - features = 10240 - - model = Linears(features, op_num=4).cuda() - - if myrank == 0: - inputs = torch.randn((bs, features)).cuda() - else: - inputs = None - - scheduling_naive(model, inputs, bs, features) diff --git a/examples/case_study/ultimate/offload_linear.py b/examples/case_study/ultimate/offload_linear.py deleted file mode 100644 index 7df14b87..00000000 --- a/examples/case_study/ultimate/offload_linear.py +++ /dev/null @@ -1,149 +0,0 @@ -import torch -import os - -torch.manual_seed(121) - -tensor_map = dict() - - -def swap_weight_grad_linear(input, weight, bias): - - ### Policy ### - - # op placement - op_device = torch.device('cuda:0') - - # tensor placement: this should be set at tensor creation stage - # note here if change this, we also need to change tensor init at main - weight.host_device = torch.device('cpu') - bias.host_device = torch.device('cpu') - - # grad placement: this can be set before running - grad_device = torch.device('cuda:0') - def grad_swap(grad): - grad.data = grad.detach().to(grad_device) - return grad - weight.register_hook(grad_swap) - bias.register_hook(grad_swap) - - ## Placement for a tensor swap in/out - ## where to swap in: op.device (op placement policy) - ## where to swap out: tensor.swap_to (policy) - - ## Timing when a tensor swapped in/out - ## Basic Time block (each op is a slot?) - ## Event-driven (tesnor access? on-demand? | dynamic scenario?) - - # Policy description - # op.device = torch.device('cuda:0') - # ... - - ##### - - class SwapLinear(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, bias): - - weight_id = id(weight) - bias_id = id(bias) - ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id)) - tensor_map[weight_id] = weight - tensor_map[bias_id] = bias - - # retrieve from cpu memory - if weight.device != op_device: - weight.data = weight.detach().to(op_device) - if bias.get_device() != op_device: - bias.data = bias.detach().to(op_device) - - # compute - output = torch._C._nn.linear(input, weight, bias) - - # offload to CPU - if weight.device != weight.host_device: - weight.data = weight.detach().to(weight.host_device) - if bias.device != bias.host_device: - bias.data = bias.detach().to(bias.host_device) - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight_id, bias_id = ctx.saved_tensors - weight = tensor_map[weight_id.item()] - bias = tensor_map[bias_id.item()] - - grad_input = grad_weight = grad_bas = None - if ctx.needs_input_grad[0]: - print('computing grad of input...') - # retrieve weight - if weight.device != op_device: - weight.data = weight.detach().to(op_device) - grad_input = grad_output.matmul(weight) - if weight.device != weight.host_device: - weight.data = weight.detach().to(weight.host_device) - if ctx.needs_input_grad[1]: - dim = grad_output.dim() - if dim > 2: - grad_weight = grad\ - .view(-1, grad_output.shape[-1])\ - .t()\ - .matmul(input.view(-1, input.shape[-1])) - else: - grad_weight = grad_output.t().matmul(input) - if ctx.needs_input_grad[2]: - grad_bias = grad_output.sum(0) - - ### Move gradient to it's tensor host device ### - ### WARNING: there will be up to 2 redundant I/O if we require - ### gradient to place differently with its tensor - if grad_weight is not None: - grad_weight.data = grad_weight.detach().to(weight.host_device) - if grad_bias is not None: - grad_bias.data = grad_bias.detach().to(bias.host_device) - - return grad_input, grad_weight, grad_bias - - output = SwapLinear.apply(input, weight, bias) - return output - - -if __name__ == '__main__': - - torch.cuda.set_device(0) - init_memory = torch.cuda.memory_allocated() - - # tensor definition - batch_size = 32 - out_features = 10240 - in_features = 10240 ## 100 MB weight - weight_1 = torch.rand((out_features, in_features)).requires_grad_() - bias_1 = torch.rand(out_features).requires_grad_() - input = torch.rand((batch_size, in_features)).cuda() - weight_2 = torch.rand((out_features, in_features)).requires_grad_() - bias_2 = torch.rand(out_features).requires_grad_() - - input_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - - # op compute - print('======== Offloading Single Device =======') - weight_swap_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - - output = swap_weight_grad_linear(input, weight_1, bias_1) - output = swap_weight_grad_linear(output, weight_2, bias_2) - loss = torch.mean(output) * 100 - print('loss: {}'.format(loss)) - loss.backward() - - finish_op_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - max_allocated = (torch.cuda.max_memory_allocated() - init_memory) / 1024 / 1024 - - # allocate tensor on gpu to see if swap workds - tmp = torch.rand((out_features, in_features)).cuda() - after_alloc_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - - print('Memory Consumption (MB):\n\t input-require: {:.2f}\n\t after swap weight: {:.2f}\n\t after op run {:.2f}\n\t max allocated: {:.2f}\n\t after allocate {:.2f}'.format( - input_memory, weight_swap_memory, finish_op_memory, max_allocated, after_alloc_memory)) - - # correctness verify - print('weight grad: ', weight_1.grad.t()) - print('======== Offloading Single Device =======') diff --git a/examples/case_study/ultimate/parallel_linear.py b/examples/case_study/ultimate/parallel_linear.py deleted file mode 100644 index 1f077c27..00000000 --- a/examples/case_study/ultimate/parallel_linear.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Example Usage - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - examples/case_study/parallel_linear.py -""" - -import torch -import os -from torch.nn.parameter import Parameter -torch.manual_seed(121) - -hooks = list() - -# tensor parallel - split weight in column -def linear_tensor_parallel(input, weight, bias): - ### Policy need to know ### - devices = [0, 1, 2, 3] # how many device to perform? - - ### Necessary information to know ### - rank = torch.distributed.get_rank() # which role I participate? - - ### Additional ops need to use ### -- TODO: System provided - class InputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return input_ - @staticmethod - def backward(ctx, grad_output): - return torch.distributed.all_reduce(grad_output) - - class OutputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_) - output = torch.cat(tensor_list, dim=-1) - return output - @staticmethod - def backward(ctx, grad_output): - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - tensor_list = torch.split( - grad_output, grad_output.size()[-1]//world_size, dim=-1 - ) - return tensor_list[rank].contiguous() - - ### Input Slice ### TODO: expert description on how to tile - weight = torch.chunk(weight, chunks=len(devices), dim=0)[rank].contiguous() - bias = torch.chunk(bias, chunks=len(devices), dim=0)[rank].contiguous() - - ### Input Adapter ### TODO: system generated according to segmentation - input = InputAdapter.apply(input) - - ### Forward ### TODO: expert description on how to compute - output = torch._C._nn.linear(input, weight, bias) - - ### Ouput Adapter ### TODO: system generated according to segmentation - # insert a forward + backward op at last (allgather - split) - output = OutputAdapter.apply(output) - return output - - -# data parallel -def linear_data_parallel(input, weight, bias): - ### Policy need to know ### - devices = [0, 1, 2, 3] # how many device to perform? - - ### Necessary information to know ### - rank = torch.distributed.get_rank() # which role I participate? - - ### Additional ops need to use ### - # -> torch.distributed.all_reduce at backward - - ### Input Slice ### TODO: expert description on how to tile - input = torch.chunk(input, chunks=len(devices), dim=0)[rank].contiguous() - - ### Input Adapter ### TODO: system generated according to segmentation - def grad_hook(grad): - torch.distributed.all_reduce(grad) - grad /= len(devices) - return grad - hw = weight.register_hook(grad_hook) - hb = bias.register_hook(grad_hook) - global hooks - hooks += [hw, hb] - - ### Forward ### TODO: expert description on how to compute - output = torch._C._nn.linear(input, weight, bias) - - ### Output Adapter ### TODO: system generated according to segmentation - return output - - -# tensor + data parallel -def linear_hybrid_tensor_data_parallel(input, weight, bias): - ### Policy need to know ### - tp_size = 2 # how many slices? which device? - dp_size = 2 - - ### Necessary information to execute ### - rank = torch.distributed.get_rank() # which role I participate? - - # data parallel group - dp_group = None - group = torch.distributed.new_group([0,2]) - if rank in [0, 2]: - dp_group = group - group = torch.distributed.new_group([1,3]) - if rank in [1, 3]: - dp_group = group - dp_rank = torch.distributed.get_rank(group=dp_group) - - # tensor parallel group - tp_group = None - group = torch.distributed.new_group([0,1]) - if rank in [0, 1]: - tp_group = group - group = torch.distributed.new_group([2,3]) - if rank in [2, 3]: - tp_group = group - tp_rank = torch.distributed.get_rank(group=tp_group) - tp_world_size = torch.distributed.get_world_size(group=tp_group) - print_each_rank( - 'rank global:tp:dp=[{}:{}:{}] | size global:tp:dp=[{}:{}:{}]'.format( - torch.distributed.get_rank(), - torch.distributed.get_rank(tp_group), - torch.distributed.get_rank(dp_group), - torch.distributed.get_world_size(), - torch.distributed.get_world_size(tp_group), - torch.distributed.get_world_size(dp_group) - )) - - ### Additional Ops ### - class InputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group): - ctx.constants = group - return input_ - @staticmethod - def backward(ctx, grad_output): - group = ctx.constants - return torch.distributed.all_reduce(grad_output, group=group), None - - class OutputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group, dim=-1): - world_size = torch.distributed.get_world_size(group=group) - rank = torch.distributed.get_rank(group=group) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=group) - output = torch.cat(tensor_list, dim=dim) - ctx.constants = (group, dim) - return output - @staticmethod - def backward(ctx, grad_output): - group, dim = ctx.constants - world_size = torch.distributed.get_world_size(group=group) - rank = torch.distributed.get_rank(group=group) - tensor_list = torch.split( - grad_output, grad_output.size()[-1]//world_size, dim=dim - ) - return tensor_list[rank].contiguous(), None, None - - ### Input Adapter - Slice ### TODO: expert description on how to tile - input = torch.chunk(input, chunks=dp_size, dim=0)[dp_rank].contiguous() - weight = torch.chunk(weight, chunks=tp_world_size, dim=0)[tp_rank].contiguous() - bias = torch.chunk(bias, chunks=tp_world_size, dim=0)[tp_rank].contiguous() - - ### Input Adapter - Data Parallel ### TODO: system generated according to segmentation - def grad_hook(grad): - torch.distributed.all_reduce(grad, group=dp_group) - grad /= dp_size - return grad - hw = weight.register_hook(grad_hook) - hb = bias.register_hook(grad_hook) - global hooks - hooks += [hw, hb] - - ### Input Adapter - Tensor Parallel ### TODO: system generated according to segmentation - input = InputAdapter.apply(input, tp_group) - - ### Forward ### TODO: expert description on how to compute - output = torch._C._nn.linear(input, weight, bias) - - ### Output Adapter - Tensor Parallel ### TODO: system generated according to segmentation - output = OutputAdapter.apply(output, tp_group, -1) - - ### Ouput Adapter - Data Parallel ### - ## No need - - return output - - - -######### Utility ############# -def print_each_rank(msg, selected_rank=None): - myrank = torch.distributed.get_rank() - for rank in range(torch.distributed.get_world_size()): - if selected_rank is None or myrank in selected_rank: - if myrank == rank: - print('rank [{}]: {}\n'.format(rank, msg)) - torch.distributed.barrier() - - -if __name__ == '__main__': - - local_rank = int(os.environ.get('LOCAL_RANK')) - torch.cuda.set_device(local_rank) - torch.distributed.init_process_group( - backend='nccl', - init_method='env://', - ) - - # tensor definition - batch_size = 32 - out_features = 10240 - in_features = 10240 - weight = torch.rand((out_features, in_features)).cuda().requires_grad_() - # print_each_rank('weight: {}'.format(weight)) - bias = torch.rand(out_features).cuda().requires_grad_() - # print_each_rank('bias: {}'.format(bias)) - input = torch.rand((batch_size, in_features)).cuda() - # print_each_rank('input: {}'.format(input)) - - # tensor parallel - print_each_rank('======== Model Parallel =========', [0]) - output = linear_tensor_parallel(input, weight, bias) - loss = torch.mean(output) * 100 - print_each_rank(loss) - loss.backward() - # note weight is created as transposed - print_each_rank('weight grad: {}'.format(weight.grad.t())) - print_each_rank('======== Model Parallel =========', [0]) - - # data parallel - weight.grad = None - bias.grad = None - print_each_rank('======== Data Parallel =========', [0]) - output = linear_data_parallel(input, weight, bias) - loss = torch.mean(output) * 100 - loss.backward() - print_each_rank('weight grad: {}'.format(weight.grad.t())) - print_each_rank('======== Data Parallel =========', [0]) - - # hybrid tensor-data parallel - weight.grad = None - bias.grad = None - for hook in hooks: - hook.remove() - print_each_rank('======== Data + Tensor Parallel =========', [0]) - output = linear_hybrid_tensor_data_parallel(input, weight, bias) - loss = torch.mean(output) * 100 - # print_each_rank(loss) - loss.backward() - print_each_rank('weight grad: {}'.format(weight.grad.t())) - print_each_rank('======== Data + Tensor Parallel =========', [0]) diff --git a/examples/case_study/ultimate/pipeline_linear.py b/examples/case_study/ultimate/pipeline_linear.py deleted file mode 100644 index 77bdfb25..00000000 --- a/examples/case_study/ultimate/pipeline_linear.py +++ /dev/null @@ -1,232 +0,0 @@ -"""Example Usage - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - examples/case_study/ultimate/pipeline_linear.py -""" - -import torch -from torch import nn -import os -import time - - -class Linears(nn.Module): - """ - Note in model creation, it will only construct model chunks - that belong to this rank - """ - - def __init__(self, features, op_num=4): - super().__init__() - self.ops = nn.ModuleList([]) - - myrank = torch.distributed.get_rank() - ngpus = torch.distributed.get_world_size() - op_num_per_rank = int(op_num / ngpus) - - for _ in range(op_num_per_rank): - self.ops.append(nn.Linear(features, features)) - - def forward(self, x): - out = x - for op in self.ops: - out = op(out) - return out - - -def is_first_stage(): - return torch.distributed.get_rank() == 0 - - -def is_last_stage(): - return torch.distributed.get_rank() == torch.distributed.get_world_size() - 1 - - -#================= WhatToDO functions ==================# - -def forward_step(model, input_tensor): - output_tensor = model(input_tensor) - # last stage: calcuate loss - if is_last_stage(): - output_tensor = torch.sum(output_tensor) - return output_tensor - - -def backward_step(input_tensor, output_tensor, output_tensor_grad): - """ - Calculate input tensor gradient - """ - if input_tensor is not None and input_tensor.requires_grad: - input_tensor.retain_grad() - torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) - input_tensor_grad = None - if input_tensor is not None and input_tensor.requires_grad: - input_tensor_grad = input_tensor.grad - return input_tensor_grad - -#================= WhatToDO functions ==================# - -#================= Between Stage functions ==================# - -def send(tensor, to_rank): - """ - send tensor to the target rank - """ - if to_rank < 0 or to_rank >= torch.distributed.get_world_size(): - return None - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, to_rank - ) - reqs = torch.distributed.batch_isend_irecv([send_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - - -def recv(shape, from_rank, boundary_tensor): - if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): - return boundary_tensor - tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device() - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, from_rank - ) - reqs = torch.distributed.batch_isend_irecv([recv_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - return tensor - - -def send_and_recv(send_tensor, recv_shape, rank, boundary_tensor): - if rank < 0 or rank >= torch.distributed.get_world_size(): - return boundary_tensor - recv_tensor = torch.empty( - recv_shape, requires_grad=True, device=torch.cuda.current_device() - ) - send_op = torch.distributed.P2POp( - torch.distributed.isend, send_tensor, rank - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_tensor, rank - ) - reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - return recv_tensor - -#================= Between Stage functions ==================# - - -#================= Scheduling ==================# - -def scheduling_1f1b(model, inputs, bs, feats, micro_bs): - myrank = torch.distributed.get_rank() - - num_microbatches = int(bs / micro_bs) - num_warmup_microbatches = \ - (torch.distributed.get_world_size() - - torch.distributed.get_rank() - 1) - num_warmup_remaining = num_microbatches - num_warmup_microbatches - - input_tensors = list() - output_tensors = list() - - if inputs is not None: - inputs = torch.chunk(inputs, chunks=num_microbatches, dim=0) - else: - inputs = [None] * num_microbatches - - # warmup forward pass - for i in range(num_warmup_microbatches): - # recv forward - print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) - input_tensor = recv(torch.Size([micro_bs, feats]), myrank-1, inputs[i]) - # forward - output_tensor = forward_step(model, input_tensor) - # send forward - print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) - send(output_tensor, myrank+1) - - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) - - # before running 1F1B, need to recieve first forward tensor - if num_warmup_remaining > 0: - # recv forward - print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) - input_tensor = recv(torch.Size([micro_bs, feats]), myrank-1, inputs[num_warmup_microbatches]) - - # run 1F1B - for i in range(num_warmup_remaining): - # forward - output_tensor = forward_step(model, input_tensor) - # send forward + recv backward grads - print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) - output_tensor_grad = send_and_recv( - output_tensor, torch.Size([micro_bs, feats]), myrank+1, None) - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) - # backward - input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) - input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad) - if i != (num_warmup_remaining-1): - # send backward grads + recv forward results - print('[1f1b] rank {}: step-{}: sending backward + recving forward...'.format(myrank, i)) - input_tensor = send_and_recv( - input_tensor_grad, torch.Size([micro_bs, feats]), myrank-1, inputs[num_warmup_microbatches+i+1]) - else: # last iteration - no more inputs - input_tensor = None - # send backward grads - print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) - send(input_tensor_grad, myrank-1) - - # cooldown gradient trans back - for i in range(num_warmup_microbatches): - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - # recv backward gradients - output_tensor_grad = recv(torch.Size([micro_bs, feats]), myrank+1, None) - # backward - input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad) - # send backward gradients - print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) - send(input_tensor_grad, myrank-1) - -#================= Scheduling ==================# - - -if __name__ == '__main__': - - # initialize distributed env - local_rank = int(os.environ.get('LOCAL_RANK')) - torch.cuda.set_device(local_rank) - torch.distributed.init_process_group( - backend='nccl', - init_method='env://', - ) - myrank = torch.distributed.get_rank() - - bs = 32 - micro_bs = 1 - features = 10240 - - model = Linears(features, op_num=4).cuda() - - if myrank == 0: - inputs = torch.randn((bs, features)).cuda() - else: - inputs = None - - for _ in range(50): - scheduling_1f1b(model, inputs, bs, features, micro_bs) - # torch.distributed.barrier() # for profiling only - # time.sleep(1) diff --git a/examples/case_study/ultimate/recompute_linear.py b/examples/case_study/ultimate/recompute_linear.py deleted file mode 100644 index f46fae58..00000000 --- a/examples/case_study/ultimate/recompute_linear.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch -import os - -torch.manual_seed(121) - -### Checkpoint PyTorch Implementation (Skip un-deterministic scenario) ### -# Note this implementation can only work with a module that consists -# multiple operators. This will won't work for one OP because the output -# for this module will be saved in next op -def checkpoint_module_linear(input, weight, bias): - - class Checkpoint(torch.autograd.Function): - """General class to wrapper op to enable checkpoint""" - @staticmethod - def forward(ctx, run_function, *args): - ctx.run_function = run_function - ctx.tensor_indices = [] - tensor_inputs = [] - for i, arg in enumerate(args): - if torch.is_tensor(arg): - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - ctx.save_for_backward(*tensor_inputs) - - with torch.no_grad(): - outputs = run_function(*args) - return outputs - @staticmethod - def backward(ctx, *args): - # retrieve what need to regenerate tensors - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors - # re-generate - for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i] - # detach inputs - detached_inputs = list() - for input in inputs: - if torch.is_tensor(input): - x = input.detach() - x.requires_grad = input.requires_grad - else: - x = input - detached_inputs.append(x) - detached_inputs = tuple(detached_inputs) - # generate output tensor - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - if torch.is_tensor(outputs): - outputs = (outputs,) - # run backward to tensors that require a grad - outputs_with_grad = list() - args_with_grad = list() - if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: - outputs_with_grad.append(outputs[i]) - args_with_grad.append(args[i]) - torch.autograd.backward(outputs_with_grad, args_with_grad) - grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None - for inp in detached_inputs) - return (None, None) + grads - - output = Checkpoint.apply(torch._C._nn.linear, input, weight, bias) - return output \ No newline at end of file diff --git a/examples/case_study/ultimate/zero_linear.py b/examples/case_study/ultimate/zero_linear.py deleted file mode 100644 index 1a73ee40..00000000 --- a/examples/case_study/ultimate/zero_linear.py +++ /dev/null @@ -1,215 +0,0 @@ -""" -Zero Redundancy Implementation - -Partition Weights / Gradients / Optimizer States across GPUs - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=62000 \ - --use_env \ - examples/case_study/ultimate/zero_linear.py -""" -import torch -import os -import math -torch.manual_seed(121) - -tensor_map = dict() - -def linear_zero(input, weight, bias): - ### weight / bias is partitioned ### - class ZeroLinear(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, bias): - - weight_id = id(weight) - bias_id = id(bias) - ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id)) - tensor_map[weight_id] = weight - tensor_map[bias_id] = bias - - # ======= all-gather parameters ========= # - device_num = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - # all-gather weight - weight_list = [torch.empty_like(weight) for _ in range(device_num)] - weight_list[rank] = weight - torch.distributed.all_gather(weight_list, weight) - weight_full = torch.cat(weight_list, dim=0).contiguous() - # all-gather bias - bias_list = [torch.empty_like(bias) for _ in range(device_num)] - bias_list[rank] = bias - torch.distributed.all_gather(bias_list, bias) - bias_full = torch.cat(bias_list, dim=0).contiguous() - # ======= all-gather parameters ========= # - - # compute: -> use full weight / bias - output = torch._C._nn.linear(input, weight_full, bias_full) - - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight_id, bias_id = ctx.saved_tensors - weight = tensor_map[weight_id.item()] - bias = tensor_map[bias_id.item()] - - grad_input = grad_weight = grad_bas = None - if ctx.needs_input_grad[0]: - # ========== all-gather weight =========== # - weight_list = [torch.empty_like(weight) for _ in range(device_num)] - weight_list[rank] = weight - torch.distributed.all_gather(weight_list, weight) - weight_full = torch.cat(weight_list, dim=0).contiguous() - # ========== all-gather weight =========== # - - grad_input = grad_output.matmul(weight_full) - - if ctx.needs_input_grad[1]: - dim = grad_output.dim() - if dim > 2: - grad_weight_full = grad\ - .view(-1, grad_output.shape[-1])\ - .t()\ - .matmul(input.view(-1, input.shape[-1])) - else: - grad_weight_full = grad_output.t().matmul(input) - if ctx.needs_input_grad[2]: - grad_bias_full = grad_output.sum(0) - - ## ========== reduce-scatter for data parallelism ========= ## - device_num = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - grad_weight_list = list(torch.chunk(grad_weight_full, chunks=device_num, dim=0)) - grad_weight = torch.empty_like(grad_weight_list[rank]) - torch.distributed.reduce_scatter(grad_weight, grad_weight_list) - grad_bias_list = list(torch.chunk(grad_bias_full, chunks=device_num, dim=0)) - grad_bias = torch.empty_like(grad_bias_list[rank]) - torch.distributed.reduce_scatter(grad_bias, grad_bias_list) - ## ========== reduce-scatter for data parallelism ========= ## - - return grad_input, grad_weight, grad_bias - - output = ZeroLinear.apply(input, weight, bias) - return output - - -def apply_adam(params, grads, exp_avgs, exp_avg_sqs, steps, beta1, beta2, lr): - for i, param in enumerate(params): - - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step = steps[-1] - - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step - - exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8) - step_size = lr / bias_correction1 - param.addcdiv_(exp_avg, denom, value=-step_size) - - -######### Utility ############# -def print_each_rank(msg, selected_rank=None): - myrank = torch.distributed.get_rank() - for rank in range(torch.distributed.get_world_size()): - if selected_rank is None or myrank in selected_rank: - if myrank == rank: - print('rank [{}]: {}\n'.format(rank, msg)) - torch.distributed.barrier() - - -if __name__ == '__main__': - - local_rank = int(os.environ.get('LOCAL_RANK')) - torch.cuda.set_device(local_rank) - torch.distributed.init_process_group( - backend='nccl', - init_method='env://', - ) - devices = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - # tensor definition - batch_size = 32 - out_features = 10240 - in_features = 10240 ## 100 MB weight - - # weight - weight = torch.chunk( - torch.rand((out_features, in_features)), - chunks=devices, - dim=0 - )[rank].contiguous().cuda().requires_grad_() - - # bias - bias = torch.chunk( - torch.rand((out_features,)), - chunks=devices, - dim=0 - )[rank].contiguous().cuda().requires_grad_() - - ## Adam optimizer states -- Zero-DP: the states are partitioned - weight_exp_avg = torch.zeros_like( - weight, memory_format=torch.preserve_format - ) - weight_exp_avg_sq = torch.zeros_like( - weight, memory_format=torch.preserve_format - ) - bias_exp_avg = torch.zeros_like( - bias, memory_format=torch.preserve_format - ) - bias_exp_avg_sq = torch.zeros_like( - bias, memory_format=torch.preserve_format - ) - state_steps = list() - lr = 0.01 - beta1 = 0.5 - beta2 = 0.5 - - # data - input = torch.rand((batch_size, in_features)).cuda() - - # op compute - print_each_rank('======== Zero-Redundancy =======', [0]) - - output = linear_zero(input, weight, bias) - loss = torch.mean(output) * 100 - print_each_rank('loss: {}'.format(loss)) - loss.backward() - - # adam optimizer - params = [weight, bias] - grads = [weight.grad, bias.grad] - exp_avgs = [weight_exp_avg, bias_exp_avg] - exp_avg_sqs = [weight_exp_avg_sq, bias_exp_avg_sq] - state_steps.append(len(state_steps)+1) - with torch.no_grad(): - apply_adam( - params, grads, exp_avgs, exp_avg_sqs, state_steps, - beta1, beta2, lr - ) - # zero out grad - weight.grad = None - bias.grad = None - - # finish_op_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - # max_allocated = (torch.cuda.max_memory_allocated() - init_memory) / 1024 / 1024 - - # allocate tensor on gpu to see if swap workds - # after_alloc_memory = (torch.cuda.memory_allocated() - init_memory) / 1024 / 1024 - - # print('Memory Consumption (MB):\n\t input-require: {:.2f}\n\t after swap weight: {:.2f}\n\t after op run {:.2f}\n\t max allocated: {:.2f}\n\t after allocate {:.2f}'.format( - # input_memory, weight_swap_memory, finish_op_memory, max_allocated, after_alloc_memory)) - - # correctness verify - output = linear_zero(input, weight, bias) - loss = torch.mean(output) * 100 - print_each_rank('loss: {}'.format(loss)) - print_each_rank('======== Zero-Redundancy =======', [0]) diff --git a/examples/ffn/ffn.py b/examples/feedforward/ffn.py similarity index 94% rename from examples/ffn/ffn.py rename to examples/feedforward/ffn.py index f822eb64..794cffba 100644 --- a/examples/ffn/ffn.py +++ b/examples/feedforward/ffn.py @@ -8,12 +8,12 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/ffn/ffn.py + examples/feedforward/ffn.py OMP_NUM_THREADS=4 torchrun --standalone \ --nproc_per_node=4 \ --nnodes=1 \ - examples/ffn/ffn.py + examples/feedforward/ffn.py OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ @@ -21,7 +21,7 @@ --rdzv_id=888 \ --rdzv_backend=c10d \ --rdzv_endpoint=worker0:8004 \ - examples/ffn/ffn.py + examples/feedforward/ffn.py """ import torch @@ -31,7 +31,7 @@ from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.ffn.policy.data import PAS +from examples.feedforward.policy.data import PAS class FFN(torch.nn.Module): diff --git a/examples/ffn/policy/data.py b/examples/feedforward/policy/data.py similarity index 100% rename from examples/ffn/policy/data.py rename to examples/feedforward/policy/data.py diff --git a/examples/ffn/policy/tensor.py b/examples/feedforward/policy/tensor.py similarity index 100% rename from examples/ffn/policy/tensor.py rename to examples/feedforward/policy/tensor.py diff --git a/examples/poc/pipeline.py b/examples/poc/pipeline.py deleted file mode 100644 index 0d6e6141..00000000 --- a/examples/poc/pipeline.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -This is to check whether backward can be stopped in the middle - -Verified by using `detach()`, `requires_grad_()` and `retain_grad()` -""" - - -import torch -from torch import nn - -torch.manual_seed(100) - - -class LinearModel(nn.Module): - - def __init__(self, dim): - super().__init__() - self.linear1 = nn.Linear(dim, dim) - self.linear2 = nn.Linear(dim, dim) - self.linear3 = nn.Linear(dim, dim) - self.linear4 = nn.Linear(dim, dim) - - def forward(self, x): - x2_ = None - - x1 = self.linear1(x) - - x2 = self.linear2(x1) - - x2_ = x2.detach() - x2_.requires_grad_() - x2_.retain_grad() - x3 = self.linear3(x2_) - - x4 = self.linear4(x3) - - return x4, x2, x2_ - - -if __name__ == '__main__': - - bs = 32 - dim = 1024 - - model = LinearModel(dim) - model = model.cuda() - - inputs = torch.randn((bs, dim), device=torch.device('cuda:0')) - - output, x2, x2_ = model(inputs) - loss = torch.sum(output) - - # check before backward grads - # print('before linear1 weight grad:\n{}'.format(model.linear1.weight.grad)) - # print('before linear2 weight grad:\n{}'.format(model.linear3.weight.grad)) - # print('before x2 tensor:\n{}'.format(x2.grad)) - # print('===============================') - assert model.linear1.weight.grad is None - assert model.linear2.weight.grad is None - - loss.backward() - assert model.linear1.weight.grad is None - assert torch.is_tensor(model.linear3.weight.grad) is True - # print('after linear1 weight grad :\n{}'.format(model.linear1.weight.grad)) - # print('after linear2 weight grad :\n{}'.format(model.linear3.weight.grad)) - # print('after x2 tensor:\n{}'.format(x2.grad)) - - torch.autograd.backward(x2, grad_tensors=x2_.grad) - assert torch.is_tensor(model.linear1.weight.grad) is True - # print('===============================') - # print('after autograd linear1 weight grad :\n{}'.format(model.linear1.weight.grad)) diff --git a/examples/poc/pipeline_space.py b/examples/poc/pipeline_space.py deleted file mode 100644 index 02f7c00a..00000000 --- a/examples/poc/pipeline_space.py +++ /dev/null @@ -1,271 +0,0 @@ -from cube.schedule.action import Action, add_flow -from cube.schedule.iterator import sequence_space, sequence_space_batched, placement_space -from cube.schedule.plan import ExecutionPlan -from cube.schedule.checker import correct_check - -import argparse -import re -import json -import time -import os -import multiprocessing as mp -from functools import partial - - -def get_semantic(forward_fn, backward_fn, num_stage, num_microbatch): - forward_time = 1 - backward_time = 2 - - actions = list() - relations = list() - for mid in range(num_microbatch): - # forward - for stage in range(num_stage): - action = Action(forward_fn) - action.est_latency = forward_time - action.est_memory = 1 - action.tag('fS{}D{}'.format(stage, mid)) - if stage != 0: - relation = (actions[-1], action) - add_flow(actions[-1], action) - relations.append(relation) - else: - action.fid = mid - actions.append(action) - # backward - for stage in range(num_stage): - action = Action(backward_fn) - action.est_latency = backward_time - action.est_memory = -1 - action.tag('bS{}D{}'.format(num_stage - 1 - stage, mid)) - # relation - relation = (actions[-1], action) - add_flow(actions[-1], action) - # append to relation sets - relations.append(relation) - actions.append(action) - return actions, relations - - -def get_stage_and_mid(action): - ids = re.findall(r"S(\d+)D(\d+)", action.name) - stage, mid = int(ids[0][0]), int(ids[0][1]) - return stage, mid - - -def fixed_placement(actions, ndevice, **kwargs): - for action in actions: - stage, _ = get_stage_and_mid(action) - action.device = stage % ndevice - yield actions - - -def full_grid_search(actions, relations, ndevice, nmb, outpath='./figs'): - """ - Search minimal time plan under the memory constraints - """ - - memory_buckets = dict() - for activation_num in range(1, nmb+1): - memory_buckets[activation_num] = None - - tic = time.time() - for cnt, seq in enumerate(sequence_space(actions, relations)): - for dev_num, dev_seq in enumerate(placement_space(seq, ndevice, fb_same=True)): - # print(f'on sequence > {dev_seq}') - execplan = ExecutionPlan(dev_seq, ndevice) - execplan.gen() - span = execplan.get_time() - memory = execplan.get_memory() - # update plan - for upper_mem in memory_buckets: - if memory <= upper_mem: - if memory_buckets[upper_mem] is None: - memory_buckets[upper_mem] = execplan - execplan.draw(outfile=os.path.join(outpath, f'{ndevice}nmb{nmb}dev.mem{memory}.png')) - if span < memory_buckets[upper_mem].get_time(): - memory_buckets[upper_mem] = execplan - execplan.draw(outfile=os.path.join(outpath, f'{ndevice}nmb{nmb}dev.mem{memory}.png')) - print(f'> found a better seq {seq} time {span} mem {memory}') - # input(f'>>> done on {dev_num+1} device placement ') - if (cnt+1) % 1000 == 0: - throughput = 1000 * (nmb ** ndevice) / (time.time() - tic) - tic = time.time() - print('> search [{}-{}] throughput {:.2f} spatial sequences / sec'.format(cnt+1-1000, cnt+1, throughput)) - # dump to json - print(f'> totally done search on {cnt+1} sequences') - for key in memory_buckets: - memory_buckets[key] = memory_buckets[key].to_json() - with open(os.path.join(outpath, 'results.json'), 'w') as outfile: - json.dump(memory_buckets, outfile) - - -def worker_search(seqs, nstage, ndevice, space_iter=placement_space): - sub_memory_buckets = dict() - for activation_num in range(1, 2*nstage+1): - sub_memory_buckets[activation_num] = None - for seq in seqs: - for dev_seq in space_iter(seq, ndevice, fb_same=True): - execplan = ExecutionPlan(dev_seq, ndevice) - execplan.gen() - span = execplan.get_time() - memory = execplan.get_memory() - # update plan - for upper_mem in sub_memory_buckets: - if memory <= upper_mem: - if sub_memory_buckets[upper_mem] is None: - sub_memory_buckets[upper_mem] = execplan - if span < sub_memory_buckets[upper_mem].get_time(): - sub_memory_buckets[upper_mem] = execplan - return sub_memory_buckets - - -def space_search_mp(actions, relations, nstage, nmb, ndevice, outpath, space_iter=placement_space, nworker=40): - """ - Search minimal time plan under the memory constraints - """ - pool = mp.Pool(processes=nworker) - - memory_buckets = dict() - for activation_num in range(1, 2*nstage+1): - memory_buckets[activation_num] = None - - def merge(sub_memory_buckets): - for upper_mem in sub_memory_buckets: - if sub_memory_buckets[upper_mem] is None: - continue - execplan = sub_memory_buckets[upper_mem] - span = execplan.get_time() - memory = execplan.get_memory() - if memory_buckets[upper_mem] is None: - memory_buckets[upper_mem] = execplan - execplan.draw(outfile=os.path.join(outpath, f'{nstage}stage.{nmb}nmb.{ndevice}dev.mem{memory}.png')) - print(f'> found a better seq {execplan.seq} time {span} mem {memory}') - if span < memory_buckets[upper_mem].get_time(): - memory_buckets[upper_mem] = execplan - execplan.draw(outfile=os.path.join(outpath, f'{nstage}stage.{nmb}nmb.{ndevice}dev.mem{memory}.png')) - print(f'> found a better seq {execplan.seq} time {span} mem {memory}') - - bs = (nworker, 256) - nseqs = 0 - for seqs in sequence_space_batched(actions, relations, bs=bs): - handles = list() - for wid in range(nworker): - handle = pool.apply_async(worker_search, args=(seqs[wid], nstage, ndevice, space_iter)) - handles.append(handle) - nseqs += sum([len(worker_seqs) for worker_seqs in seqs]) - print(f'assigned {nseqs} sequences') - for handle in handles: - sub_buckets = handle.get() - merge(sub_buckets) - - pool.close() - pool.join() - - # dump to json - print(f'> totally done search on {nseqs} sequences') - for key in memory_buckets: - if memory_buckets[key] is not None: - memory_buckets[key] = memory_buckets[key].to_json() - with open(os.path.join(outpath, 'results.json'), 'w') as outfile: - json.dump(memory_buckets, outfile) - - -def pipe_1f1b(actions, relations, nstage, nmb, ndevice): - num_stage = nstage - num_microbatch = nmb - - f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] - b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] - - # action placement - for stage in range(num_stage): - for mid in range(num_microbatch): - f(stage, mid).device = stage % ndevice - print(f(stage, mid), f'stage={stage}, mid={mid}, device={stage % ndevice}') - b(stage, mid).device = stage % ndevice - print(b(stage, mid), f'stage={stage}, mid={mid}') - - sequence = list() - - # warmup: - for stage in range(num_stage): - for mid in range(num_stage-stage): - sequence.append(f(stage, mid)) - - # steady + cooldown: - for mid in range(num_microbatch): - # enqueue backward - for stage in range(num_stage-1, -1, -1): - sequence.append(b(stage, mid)) - # enqueue forward - for stage in range(num_stage): - f_mid = mid + num_stage - stage - if f_mid >= num_microbatch: - continue - sequence.append(f(stage, f_mid)) - print(sequence) - assert correct_check(sequence, actions, relations) - execplan = ExecutionPlan(sequence, ndevice) - execplan.draw(outfile='./pipeline-1f1b.png') - - -def gpipe(actions, relations, nstage, nmb, ndevice): - num_stage = nstage - num_microbatch = nmb - - f = lambda stage, micro_batch_id: actions[2 * micro_batch_id * num_stage + stage] - b = lambda stage, micro_batch_id: actions[(2 * micro_batch_id + 1) * num_stage + num_stage - 1 - stage] - - # action placement - for stage in range(num_stage): - for mid in range(num_microbatch): - f(stage, mid).device = stage % ndevice - print(f(stage, mid), f'stage={stage}, mid={mid}, device={stage % ndevice}') - b(stage, mid).device = stage % ndevice - print(b(stage, mid), f'stage={stage}, mid={mid}') - - sequence = list() - - # warmup: - for stage in range(num_stage): - for mid in range(num_microbatch): - sequence.append(f(stage, mid)) - - # backward - for stage in range(num_stage): - for mid in range(num_microbatch): - sequence.append(b(num_stage - 1 - stage, mid)) - - print(sequence) - # assert correct_check(sequence, actions, relations) - execplan = ExecutionPlan(sequence, ndevice) - execplan.draw(outfile='./gpipe.png') - - -def forward(data): - pass - -def backward(grad): - pass - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--nstage', type=int, default=4, - help='number of stages') - parser.add_argument('--nmb', type=int, default=4, - help='number of micro-batch') - parser.add_argument('--ndev', type=int, default=4, - help='number of devices') - parser.add_argument('--full-placement', action='store_true', - help='device assignment for each action will be fully explored') - parser.add_argument('--outpath', type=str, default='/mydata/MagicCube/search/pipeline/') - args = parser.parse_args() - - actions, relations = get_semantic(forward, backward, args.nstage, args.nmb) - - # pipe_1f1b(actions, relations, args.nstage, args.nmb, args.ndev) - # gpipe(actions, relations, args.nstage, args.nmb, args.ndev) - space_iter = placement_space if args.full_placement else fixed_placement - space_search_mp(actions, relations, args.nstage, args.nmb, args.ndev, args.outpath, space_iter=space_iter) diff --git a/examples/poc/space_size.py b/examples/poc/space_size.py deleted file mode 100644 index fb1b4995..00000000 --- a/examples/poc/space_size.py +++ /dev/null @@ -1,57 +0,0 @@ -from cube.schedule.iterator import get_pipeline_seq_space_size - -import argparse - - -def get_seq_space_size(nstage, nmb): - """ - Calculate legal sequence number given num stage and num microbatch - - \prod \limits_{i=1}^{nmb} C(nstage, i*nstage) - - Args: - nstage: number of stages - nmb: number of micro batch - - Return: - total legal line - """ - return get_pipeline_seq_space_size(nstage, nmb) - - -def get_device_space_size(nstage, nmb, ndevice): - """ - Calculate legal spatial sequence number given num stage and num microbatch - - \prod \limits_{i=1}^{nmb} C(nstage, i*nstage) - - Args: - nstage: number of stages - nmb: number of micro batch - ndevice: number of device - - Return: - total legal line - """ - num_actions = nmb * nstage * 2 - device_space_size = ndevice ** num_actions - return device_space_size - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--nstage', type=int, default=4, - help='number of stages') - parser.add_argument('--nmb', type=int, default=4, - help='number of micro-batch') - parser.add_argument('--ndev', type=int, default=4, - help='number of devices') - args = parser.parse_args() - - seq_space = get_seq_space_size(args.nstage, args.nmb) - print('legal sequence space: {}'.format(seq_space)) - dev_space = get_device_space_size(args.nstage, args.nmb, args.ndev) - print('spatial space for one sequence: {}'.format(dev_space)) - total_space = seq_space * dev_space - print('total space: {}'.format(total_space)) diff --git a/examples/efficientnet/efficientnet.py b/handcraft/efficientnet/efficientnet.py similarity index 100% rename from examples/efficientnet/efficientnet.py rename to handcraft/efficientnet/efficientnet.py diff --git a/examples/efficientnet/schedule.py b/handcraft/efficientnet/schedule.py similarity index 100% rename from examples/efficientnet/schedule.py rename to handcraft/efficientnet/schedule.py diff --git a/examples/efficientnet/train.py b/handcraft/efficientnet/train.py similarity index 98% rename from examples/efficientnet/train.py rename to handcraft/efficientnet/train.py index a49b4ea7..b272033c 100644 --- a/examples/efficientnet/train.py +++ b/handcraft/efficientnet/train.py @@ -10,7 +10,7 @@ --pp 8 --gbs 32 --mbs 1 """ import torch -from examples.efficientnet.efficientnet import EfficientNet +from handcraft.efficientnet.efficientnet import EfficientNet import time import argparse @@ -18,7 +18,7 @@ from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary -from examples.efficientnet.schedule import is_last_stage, scheduling_1f1b +from handcraft.efficientnet.schedule import is_last_stage, scheduling_1f1b def model_partition(model, in_size): diff --git a/examples/efficientnet/utils.py b/handcraft/efficientnet/utils.py similarity index 100% rename from examples/efficientnet/utils.py rename to handcraft/efficientnet/utils.py diff --git a/eval/benchmark_gpt.sh b/handcraft/eval/benchmark_gpt.sh similarity index 100% rename from eval/benchmark_gpt.sh rename to handcraft/eval/benchmark_gpt.sh diff --git a/eval/swin_infer_bs1_224_782Mfp32.sh b/handcraft/eval/swin_infer_bs1_224_782Mfp32.sh similarity index 100% rename from eval/swin_infer_bs1_224_782Mfp32.sh rename to handcraft/eval/swin_infer_bs1_224_782Mfp32.sh diff --git a/eval/swin_infer_bs1_640_Gfp16.sh b/handcraft/eval/swin_infer_bs1_640_Gfp16.sh similarity index 100% rename from eval/swin_infer_bs1_640_Gfp16.sh rename to handcraft/eval/swin_infer_bs1_640_Gfp16.sh diff --git a/eval/swin_infer_bs2_224_782Mfp32.sh b/handcraft/eval/swin_infer_bs2_224_782Mfp32.sh similarity index 100% rename from eval/swin_infer_bs2_224_782Mfp32.sh rename to handcraft/eval/swin_infer_bs2_224_782Mfp32.sh diff --git a/eval/swin_infer_bs2_640_Gfp16.sh b/handcraft/eval/swin_infer_bs2_640_Gfp16.sh similarity index 100% rename from eval/swin_infer_bs2_640_Gfp16.sh rename to handcraft/eval/swin_infer_bs2_640_Gfp16.sh diff --git a/eval/swin_infer_bs4_640_Gfp16.sh b/handcraft/eval/swin_infer_bs4_640_Gfp16.sh similarity index 100% rename from eval/swin_infer_bs4_640_Gfp16.sh rename to handcraft/eval/swin_infer_bs4_640_Gfp16.sh diff --git a/eval/swin_scaleup.sh b/handcraft/eval/swin_scaleup.sh similarity index 100% rename from eval/swin_scaleup.sh rename to handcraft/eval/swin_scaleup.sh diff --git a/eval/swin_train_fp16.sh b/handcraft/eval/swin_train_fp16.sh similarity index 100% rename from eval/swin_train_fp16.sh rename to handcraft/eval/swin_train_fp16.sh diff --git a/eval/swin_train_fp32.sh b/handcraft/eval/swin_train_fp32.sh similarity index 100% rename from eval/swin_train_fp32.sh rename to handcraft/eval/swin_train_fp32.sh diff --git a/benchmark/megatron/gpt.py b/handcraft/megatron/gpt.py similarity index 96% rename from benchmark/megatron/gpt.py rename to handcraft/megatron/gpt.py index 861e074a..716d6ff9 100644 --- a/benchmark/megatron/gpt.py +++ b/handcraft/megatron/gpt.py @@ -8,14 +8,14 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - benchmark/megatron/gpt.py + handcraft/megatron/gpt.py """ import torch import torch.nn.functional as F import cube -from benchmark.megatron.layers import ColumnOutputAdapter, ShardEmbedding -from benchmark.megatron.transformer import TransformerLayer +from handcraft.megatron.layers import ColumnOutputAdapter, ShardEmbedding +from handcraft.megatron.transformer import TransformerLayer from cube.profiler import CudaTimer diff --git a/benchmark/megatron/layers.py b/handcraft/megatron/layers.py similarity index 100% rename from benchmark/megatron/layers.py rename to handcraft/megatron/layers.py diff --git a/benchmark/megatron/linears.py b/handcraft/megatron/linears.py similarity index 96% rename from benchmark/megatron/linears.py rename to handcraft/megatron/linears.py index 9c1baf50..23936f11 100644 --- a/benchmark/megatron/linears.py +++ b/handcraft/megatron/linears.py @@ -8,19 +8,19 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - benchmark/megatron/linears.py + handcraft/megatron/linears.py torchrun --standalone \ --nproc_per_node=4 \ --nnodes=1 \ - benchmark/megatron/linears.py + handcraft/megatron/linears.py """ import argparse import torch from torch import nn -from benchmark.megatron.layers import ColumnParallelLinear, RowParallelLinear +from handcraft.megatron.layers import ColumnParallelLinear, RowParallelLinear import cube from cube.profiler import CudaTimer diff --git a/benchmark/megatron_gpt_2.sh b/handcraft/megatron/megatron_gpt_2.sh similarity index 100% rename from benchmark/megatron_gpt_2.sh rename to handcraft/megatron/megatron_gpt_2.sh diff --git a/benchmark/megatron/transformer.py b/handcraft/megatron/transformer.py similarity index 99% rename from benchmark/megatron/transformer.py rename to handcraft/megatron/transformer.py index 9c6f5ac9..04e16ea1 100644 --- a/benchmark/megatron/transformer.py +++ b/handcraft/megatron/transformer.py @@ -15,7 +15,7 @@ from torch import nn import torch.nn.functional as F import cube -from benchmark.megatron.layers import ColumnParallelLinear, RowParallelLinear +from handcraft.megatron.layers import ColumnParallelLinear, RowParallelLinear from cube.profiler import CudaTimer diff --git a/examples/swin/hybrid_schedule.py b/handcraft/swin/hybrid_schedule.py similarity index 100% rename from examples/swin/hybrid_schedule.py rename to handcraft/swin/hybrid_schedule.py diff --git a/examples/swin/layers.py b/handcraft/swin/layers.py similarity index 100% rename from examples/swin/layers.py rename to handcraft/swin/layers.py diff --git a/examples/swin/pmodule.py b/handcraft/swin/pmodule.py similarity index 100% rename from examples/swin/pmodule.py rename to handcraft/swin/pmodule.py diff --git a/examples/swin/schedule.py b/handcraft/swin/schedule.py similarity index 100% rename from examples/swin/schedule.py rename to handcraft/swin/schedule.py diff --git a/examples/swin/swin_dt.py b/handcraft/swin/swin_dt.py similarity index 99% rename from examples/swin/swin_dt.py rename to handcraft/swin/swin_dt.py index 27107b49..6a9ed004 100644 --- a/examples/swin/swin_dt.py +++ b/handcraft/swin/swin_dt.py @@ -28,9 +28,9 @@ from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary from cube.runtime.device import DeviceGroup -from cube.runtime.reducer import Reducer +from cube.runtime.adapter.reducer import Reducer -from examples.swin.layers import ColumnParallelLinear, DPtoTP, RowParallelLinear, TPtoDP +from handcraft.swin.layers import ColumnParallelLinear, DPtoTP, RowParallelLinear, TPtoDP _dp_reducer: Dict[Tuple[int], Reducer] = dict() diff --git a/examples/swin/swin_dwt.py b/handcraft/swin/swin_dwt.py similarity index 99% rename from examples/swin/swin_dwt.py rename to handcraft/swin/swin_dwt.py index 10f9c769..e6f2f081 100644 --- a/examples/swin/swin_dwt.py +++ b/handcraft/swin/swin_dwt.py @@ -29,9 +29,9 @@ from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary from cube.runtime.device import DeviceGroup -from cube.runtime.reducer import Reducer +from cube.runtime.adapter.reducer import Reducer -from examples.swin.layers import ColumnParallelLinear, RowParallelLinear +from handcraft.swin.layers import ColumnParallelLinear, RowParallelLinear _wp_reducer: Dict[Tuple[int], Reducer] = dict() diff --git a/examples/swin/swin_dwt_infer.py b/handcraft/swin/swin_dwt_infer.py similarity index 99% rename from examples/swin/swin_dwt_infer.py rename to handcraft/swin/swin_dwt_infer.py index a4c96c3d..5c3c2e24 100644 --- a/examples/swin/swin_dwt_infer.py +++ b/handcraft/swin/swin_dwt_infer.py @@ -9,7 +9,7 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_dwt.py --bs 8 \ + handcraft/swin/swin_dwt.py --bs 8 \ --layer0 1 4 1 \ --layer1 1 4 1 \ --layer2 1 1 4 \ @@ -29,9 +29,9 @@ from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary from cube.runtime.device import DeviceGroup -from cube.runtime.reducer import Reducer +from cube.runtime.adapter.reducer import Reducer -from examples.swin.layers import ColumnParallelLinear, RowParallelLinear +from handcraft.swin.layers import ColumnParallelLinear, RowParallelLinear _wp_reducer: Dict[Tuple[int], Reducer] = dict() diff --git a/examples/swin/swin_flexflow.py b/handcraft/swin/swin_flexflow.py similarity index 99% rename from examples/swin/swin_flexflow.py rename to handcraft/swin/swin_flexflow.py index cb1f5fdf..3dc1f054 100644 --- a/examples/swin/swin_flexflow.py +++ b/handcraft/swin/swin_flexflow.py @@ -28,9 +28,9 @@ from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary from cube.runtime.device import DeviceGroup -from cube.runtime.reducer import Reducer +from cube.runtime.adapter.reducer import Reducer -from examples.swin.layers import ColumnParallelLinear, DPtoTP, ValueTPtoEleDP, RowParallelLinear, TPtoDP +from handcraft.swin.layers import ColumnParallelLinear, DPtoTP, ValueTPtoEleDP, RowParallelLinear, TPtoDP _dp_reducer: Dict[Tuple[int], Reducer] = dict() diff --git a/examples/swin/swin_hybrid.py b/handcraft/swin/swin_hybrid.py similarity index 99% rename from examples/swin/swin_hybrid.py rename to handcraft/swin/swin_hybrid.py index 6fada771..8ba3f967 100644 --- a/examples/swin/swin_hybrid.py +++ b/handcraft/swin/swin_hybrid.py @@ -9,7 +9,7 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_hybrid.py \ + handcraft/swin/swin_hybrid.py \ --layer0 8 1 1 \ --layer1 8 1 1 \ --layer2 8 1 1 \ @@ -23,7 +23,7 @@ --master_addr=worker-0 \ --master_port=8004 \ --use_env \ - examples/swin/swin_hybrid.py \ + handcraft/swin/swin_hybrid.py \ --layer0 2 8 1 \ --layer1 2 8 1 \ --layer2 2 8 1 \ @@ -45,14 +45,14 @@ from cube.runtime.device import DeviceGroup from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary -from cube.runtime.reducer import Reducer +from cube.runtime.adapter.reducer import Reducer import argparse -from examples.swin.hybrid_schedule import scheduling_1f1b, is_last_stage -from examples.swin.layers import ColumnParallelLinear, RowParallelLinear, DPtoTP, TPtoDP +from handcraft.swin.hybrid_schedule import scheduling_1f1b, is_last_stage +from handcraft.swin.layers import ColumnParallelLinear, RowParallelLinear, DPtoTP, TPtoDP -from examples.swin.pmodule import ParallelModule +from handcraft.swin.pmodule import ParallelModule _dp_reducer: Dict[Tuple[int], Reducer] = dict() diff --git a/examples/swin/swin_pipe.py b/handcraft/swin/swin_pipe.py similarity index 99% rename from examples/swin/swin_pipe.py rename to handcraft/swin/swin_pipe.py index 3bd429b9..ef1580e2 100644 --- a/examples/swin/swin_pipe.py +++ b/handcraft/swin/swin_pipe.py @@ -9,7 +9,7 @@ --master_addr=127.0.0.1 \ --master_port=8004 \ --use_env \ - examples/swin/swin_pipe.py --pp 8 --gbs 32 --mbs 4 + handcraft/swin/swin_pipe.py --pp 8 --gbs 32 --mbs 4 # V100-16GB: 8GPU: need checkpoint: 8 micro bs """ @@ -28,8 +28,8 @@ import argparse -from examples.swin.schedule import scheduling_1f1b, is_last_stage -from benchmark.swin.layers import ColumnParallelLinear, RowParallelLinear +from handcraft.swin.schedule import scheduling_1f1b, is_last_stage +from handcraft.swin.layers import ColumnParallelLinear, RowParallelLinear def drop_path(x, drop_prob: float = 0.): diff --git a/examples/swin/swin_transformer.py b/handcraft/swin/swin_transformer.py similarity index 100% rename from examples/swin/swin_transformer.py rename to handcraft/swin/swin_transformer.py From f024a896314b5e3f80e521919e209d71ab5a826a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 11:17:23 +0000 Subject: [PATCH 0547/1892] update data parallel --- examples/mlp/policy/data_parallel.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/mlp/policy/data_parallel.py b/examples/mlp/policy/data_parallel.py index e0bee742..452aa1de 100644 --- a/examples/mlp/policy/data_parallel.py +++ b/examples/mlp/policy/data_parallel.py @@ -7,14 +7,17 @@ def PAS(graph: IRGraph, resource): Linear Column Partition """ for node in graph.nodes(): - if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(num=resource.ngpus)) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + batch_dim = node.get_batch_dims()[0] + for node in graph.nodes(): + if isinstance(node, IRFwOperation): algo = node.algorithms('dim') - if algo: - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=0, num=resource.ngpus) - ) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) + sub_nodes = graph.partition( + node, algo, config=dict(idx=0, dim=batch_dim, num=resource.ngpus)) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) print(graph.extra_repr()) From c281d30b615d6a292c76f92b596cc9600d82b57a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 11:22:29 +0000 Subject: [PATCH 0548/1892] environ --- scripts/env-setup.sh | 13 ++++++++----- setup.py | 6 +++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index d8fd916f..5b7313e0 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -1,5 +1,5 @@ -echo using docker image nvcr.io/pytorch:pytorch-21.06-py3 +echo using docker image nvcr.io/pytorch:pytorch-21.12-py3 git config --global core.editor "vim" git config --global user.name "Zhiqi Lin" @@ -10,6 +10,7 @@ sudo git config --global user.name "Zhiqi Lin" sudo git config --global user.email "v-zhiql@microsoft.com" sudo chmod -R a+w /opt/conda +sudo apt-get install htop -y sudo apt-get install tmux -y sudo apt-get install psmisc -y sudo apt-get install lsof -y @@ -24,10 +25,10 @@ sudo apt-get install infiniband-diags -y # sudo rm packages-microsoft-prod.deb # install azcopy -wget https://azcopyvnext.azureedge.net/release20210616/azcopy_linux_amd64_10.11.0.tar.gz -O azcopy.tar.gz -tar -zxvf azcopy.tar.gz -sudo mv azcopy_linux_amd64_10.11.0/azcopy /usr/bin/ -rm -rf azcopy_linux_amd64_10.11.0 azcopy.tar.gz +# wget https://azcopyvnext.azureedge.net/release20210616/azcopy_linux_amd64_10.11.0.tar.gz -O azcopy.tar.gz +# tar -zxvf azcopy.tar.gz +# sudo mv azcopy_linux_amd64_10.11.0/azcopy /usr/bin/ +# rm -rf azcopy_linux_amd64_10.11.0 azcopy.tar.gz wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.tmux.conf -O ~/.tmux.conf wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.vimrc -O ~/.vimrc @@ -39,6 +40,8 @@ echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc # cmd for count code lines # find cube/ -name "*.py" -print0 | xargs -0 wc -l + +# training_daemon will disable torch.jit.script pip uninstall training_daemon -y python setup.py develop pip install -r requirements.txt diff --git a/setup.py b/setup.py index 41c696a8..3824b357 100644 --- a/setup.py +++ b/setup.py @@ -2,11 +2,11 @@ setuptools.setup( name= 'cube', - version= '0.1', + version= '0.2', author= 'Zhiqi Lin', author_email= 'v-zhiql@microsoft.com', - description= 'Magic Cube for configurable-DNN framework', - long_description= 'Magic Cube for configurable-DNN framework', + description= 'Parallelize DNN Traning from A Systematic Way', + long_description= 'Parallelize DNN Traning from A Systematic Way', packages= ['cube'], python_requires= '>=3.6', ) From 5317eb98a54d865b2fc9af702282e2385a1bc47e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jan 2022 11:40:11 +0000 Subject: [PATCH 0549/1892] dataloader batch size setup --- cube/compiler.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index b4bfcad5..a4b7844d 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -18,6 +18,7 @@ from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen from cube.profiler.timer import print_each_rank +from cube.runtime.syndata import CubeDataLoader class SemanticModel: @@ -57,7 +58,7 @@ def __call__(self, *args): return self.ir_graph(*args) -def compile(model: SemanticModel, dataloader, +def compile(model: SemanticModel, dataloader: CubeDataLoader, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None): """ AI Scientist calls like: @@ -87,6 +88,8 @@ def train_step(model, dataloader): """ if not isinstance(model, SemanticModel): raise TypeError("Expect Semantic Model") + if not isinstance(dataloader, CubeDataLoader): + raise TypeError("Expect dataloader derived from CubeDataLoader") if callable(PAS): PAS = (PAS,) @@ -174,19 +177,15 @@ def decorator(fn: Callable) -> Callable: attach=True ) - # get dataloader batch size - batch_size = dict() # {devid: batch size} - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - batch_dim = node.get_batch_dims()[0] - dev_batch_size = node.outputs(0).shape[batch_dim] - batch_size[node.device[0]] = dev_batch_size - all_batch_size = set([batch_size[dev] for dev in batch_size]) + # setup batch size + all_batch_size = set() + dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] + for dnode in dnodes: + bs = [out.shape[dim] for out, dim in zip(dnode.outputs(), dnode.get_batch_dims())] + all_batch_size.update(bs) if len(all_batch_size) != 1: - raise NotImplementedError("Heterogenous batch size it not supported") - batch_size = list(all_batch_size)[0] - # assume batch_size is always first dimension - batch_size = torch.tensor([batch_size], dtype=torch.int).cuda() + raise NotImplementedError("Heterogenous batch size is not supported") + batch_size = torch.tensor(list(all_batch_size), dtype=torch.int).cuda() compile_end = time.time() compile_time = compile_end - compile_start @@ -198,7 +197,7 @@ def decorator(fn: Callable) -> Callable: # reset dataloader torch.distributed.broadcast(batch_size, src=0) batch_size = batch_size.item() - print(f'> reseting dataloader batch size to {batch_size}') + print_each_rank(f'reseting dataloader batch size to {batch_size}') dataloader.reset(batch_size=batch_size) # load module From f3980ba84d9b35ad61222a41a5f4ea6ba3cc05ba Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 12 Jan 2022 08:43:51 +0000 Subject: [PATCH 0550/1892] adaptive fontsize for execplan figure --- cube/execplan/execplan.py | 43 +++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index d32ffd4b..00f34b50 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -94,11 +94,11 @@ def map2color(node): if isinstance(node, IRGraph): return map2color(node.nodes(0)) if isinstance(node, IRFwOperation): - return 'blue' + return '#4472C4' # excel blue if isinstance(node, IRBpOperation): - return 'orange' + return '#ED7D31' # excel orange if isinstance(node, IRAdapter): - return 'green' + return '#70AD47' # excel green def map2name(node): if isinstance(node, IRGraph): @@ -141,26 +141,27 @@ def map2name(node): if outfile is not None: import matplotlib.pyplot as plt from matplotlib.patches import Rectangle - plt.rcParams['figure.figsize'] = (12.0, 4.0) max_time = max( [tline[-1][1] for tline in device_timeline if len(tline) != 0] ) - + plt.rcParams['figure.figsize'] = (4.0 * max_time // ndevice, 4.0) fig, ax = plt.subplots() + renderer = fig.canvas.get_renderer() + + # xaxis ax.set_xlim((1, max_time)) plt.xticks(list(range(1, int(max_time)+1, 1))) ax.xaxis.grid(True, linestyle='--') - plt.xlabel('time') - # yaxis ax.set_ylim((0.5, len(self.devices())+0.5)) plt.yticks(list(range(1, len(self.devices())+1, 1))) ax.invert_yaxis() - plt.ylabel('device id') ax.set_aspect('equal') + fontsize = 100 + txts = list() for devid in range(ndevice): timeline = device_timeline[devid] nodes = device_nodes[devid] @@ -170,15 +171,35 @@ def map2name(node): # draw color = map2color(node) rec = Rectangle((start, devid + 0.5), end-start, 1, - color=color, ec='black', lw=1.5) + color=color, ec='black', lw=1.5) ax.add_artist(rec) rx, ry = rec.get_xy() cx = rx + rec.get_width() / 2.0 cy = ry + rec.get_height() / 2.0 anno = map2name(node) - ax.annotate(anno, (cx, cy), color='w', # weight='bold', - fontsize=10, ha='center', va='center') + txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') + + rbox = rec.get_window_extent(renderer) + for fs in range(40, 1, -2): + txt.set_fontsize(fs) + tbox = txt.get_window_extent(renderer) + if tbox.x0 >= rbox.x0 and tbox.x1 <= rbox.x1 and tbox.y0 >= rbox.y0 and tbox.y1 <= rbox.y1: + break + fontsize = min(fontsize, fs) + txts.append(txt) + + # set font size to same + for txt in txts: + txt.set_fontsize(fontsize) + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize) + plt.xlabel('Time Step', fontsize=fontsize) + plt.ylabel('Device ID', fontsize=fontsize) + # plt.grid() + plt.tight_layout() plt.savefig(outfile) From 4934266e40f0340731bd72dd4b787c96f5bb2b15 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 12 Jan 2022 11:58:11 +0000 Subject: [PATCH 0551/1892] add double data type --- cube/graph/parser/mapping.py | 1 + cube/ir/dtype.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index e7df2d25..91c8edd8 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -87,6 +87,7 @@ def map(dtype: torch.dtype): return DType2IRDType.kDtypeMap[dtype] kDtypeMap = { + torch.float64: ir.float64, torch.float32: ir.float32, torch.float : ir.float32, torch.float16: ir.float16, diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index b70c14f1..2bf0408a 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -1,6 +1,7 @@ from enum import Enum class IRDType(Enum): + float64 = 'float64' float16 = 'float16' float32 = 'float32' int64 = 'int64' @@ -12,6 +13,7 @@ class IRDType(Enum): unknown = 'unknown' +float64 = IRDType.float64 float16 = IRDType.float16 float32 = IRDType.float32 int64 = IRDType.int64 From 8ce215b0e64ba01153c7ebabdd37bb498b0ed370 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 12 Jan 2022 12:34:45 +0000 Subject: [PATCH 0552/1892] init conv2d def --- cube/graph/operator/function/einops.py | 6 ++- cube/graph/operator/function/function.py | 62 ++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index c948f571..24989b33 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -2,6 +2,7 @@ This operator class is highly inspired by eniops. """ import enum +import string from typing import List, Optional, Tuple from cube.ir.cten import IRTensor @@ -136,7 +137,7 @@ def algorithms(self, tag: Optional[str] = None): return template(self) return None - def parse(self, expr: str): + def parse(self, expr: str) -> Tuple[List[List[EinDim]], List[List[EinDim]]]: """ parse string like: b m k, b k n -> b m n @@ -154,6 +155,9 @@ def parse(self, expr: str): axises = list() for dim in input: reduce = EinDim.ReduceType.Sum if dim not in output else None + # a fixed numeric value indicates the axis is not splittable + if str.isnumeric(dim): + reduce = EinDim.ReduceType.Stay axises.append(EinDim(dim, reduce)) input_axises.append(axises) outputs = output.split(',') diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 4be024e7..c9eb3bd9 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -391,6 +391,68 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): op.set_output(idx, output) return op + +class Conv2D(IREinops): + """ + torch.conv2d(input, weight, bias, stride, padding, dialation, groups) + https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html?highlight=torch%20conv2d#torch.nn.functional.conv2d + """ + def __init__(self, signature, inputs, name='conv2d', **kwargs): + if len(inputs) != 7: + raise RuntimeError(f"expected 7 operators for conv2d but got {len(inputs)}") + super().__init__( + name, signature, + input_length=3, + output_length=1 + ) + for idx, input in enumerate(inputs[:3]): + self.set_input(idx, input) + self.kwargs['stride'] = inputs[3] + self.kwargs['padding'] = inputs[4] + self.kwargs['dilation'] = inputs[5] + self.kwargs['groups'] = inputs[6] + + def make_expression(self): + input = 'N I {iH} {iW}' + weight = 'O {group_channel} {kH} {kW}' + bias = 'O' + output = 'N O {oH} {oW}' + # parameters + groups = self.kwargs['groups'] + stride = self.kwargs['stride'] + padding = self.kwargs['padding'] + dilation = self.kwargs['dilation'] + kH = self.inputs(1).shape[2] + kW = self.inputs(1).shape[3] + + iH, iW = self.inputs(0).shape[0:2] + group_channel = self.inputs(0).shape[2] // groups + oH = (iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0] + 1 + oW = (iH + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1] + 1 + + input = input.format(iH=iH, iW=iW) + weight = weight.format(group_channel=group_channel, kH=kH, kW=kW) + output = output.format(oH=oH, oW=oW) + + expr = f'{input}, {weight}, {bias} -> {output}' + [idims, wdims, bdims], [odims] = self.parse(expr) + self.set_input_ein(0, idims) + self.set_input_ein(1, wdims) + if self.inputs(2) is not None: + self.set_input_ein(2, bdims) + self.set_output_ein(0, odims) + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + groups = self.kwargs['groups'] + stride = self.kwargs['stride'] + padding = self.kwargs['padding'] + dilation = self.kwargs['dilation'] + inputs += [groups, stride, padding, dilation] + op = Conv2D(self.signature, inputs, self.name) + op.set_output(0, outputs[0]) + return op + + # ===================== Cube Complex Operation ======================= class CubeComplexToQKV(IRFwOperation): From 464617fec926434c430086d09c77b3e48668c100 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Thu, 13 Jan 2022 01:20:07 +0000 Subject: [PATCH 0553/1892] Added README.md --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 00000000..0ca446aa --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +# Introduction +TODO: Give a short introduction of your project. Let this section explain the objectives or the motivation behind this project. + +# Getting Started +TODO: Guide users through getting your code up and running on their own system. In this section you can talk about: +1. Installation process +2. Software dependencies +3. Latest releases +4. API references + +# Build and Test +TODO: Describe and show how to build your code and run the tests. + +# Contribute +TODO: Explain how other users and developers can contribute to make your code better. + +If you want to learn more about creating good readme files then refer the following [guidelines](https://docs.microsoft.com/en-us/azure/devops/repos/git/create-a-readme?view=azure-devops). You can also seek inspiration from the below readme files: +- [ASP.NET Core](https://github.com/aspnet/Home) +- [Visual Studio Code](https://github.com/Microsoft/vscode) +- [Chakra Core](https://github.com/Microsoft/ChakraCore) \ No newline at end of file From dce35725d9aec915726c36d25522a0e50cc16056 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jan 2022 08:40:09 +0000 Subject: [PATCH 0554/1892] add elementwise operation --- cube/graph/operator/function/einops.py | 15 +++--- cube/graph/operator/function/function.py | 68 ++++++++++++++++++++---- cube/graph/parser/mapping.py | 6 +++ cube/graph/parser/parser.py | 9 ++-- 4 files changed, 80 insertions(+), 18 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 24989b33..0bbdb96c 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -17,7 +17,7 @@ class ReduceType(enum.Enum): Sum = 1 def __init__(self, name: str, reduce=None): - if not (str.isidentifier(name) or str.isdecimal(name) or name == '*'): + if not (str.isidentifier(name) or str.isnumeric(name) or name == '*'): raise ValueError("Einstein Axis name should be identifier") self.name: str = name self.reduce: Optional[EinDim.ReduceType] = reduce @@ -32,7 +32,7 @@ def is_reduce(self): return self.reduce == EinDim.ReduceType.Sum def __repr__(self): - return self.name if not self.is_reduce() else self.name + "'" + return self.name if not self.is_reduce() else self.name + "+" class IREinops(IRFwOperation): @@ -70,7 +70,7 @@ def infer_shape(self): self.make_expression() # check expression for input, ein_dims in zip(self.inputs(), self._ieins): - if len(ein_dims) == 0: + if len(ein_dims) == 0 or ein_dims is None: if isinstance(input, IRTensor): raise RuntimeError(f"{self}: {input} has no ein-dims but is a tensor") if len(ein_dims) != 0: @@ -144,12 +144,16 @@ def parse(self, expr: str) -> Tuple[List[List[EinDim]], List[List[EinDim]]]: """ if not isinstance(expr, str): raise TypeError("Expected string") - # remove space - expr = expr.replace(' ', '') + # split to inputs and outputs if expr.count('->') != 1: raise ValueError("string must contain one ->") + # split to each tensor input, output = expr.split('->') inputs = input.split(',') + outputs = output.split(',') + inputs = [[dim for dim in input.split(' ') if len(dim) != 0] for input in inputs] + outputs = [[dim for dim in output.split(' ') if len(dim) != 0] for output in outputs] + # parse each tensor input_axises = list() for input in inputs: axises = list() @@ -160,7 +164,6 @@ def parse(self, expr: str) -> Tuple[List[List[EinDim]], List[List[EinDim]]]: reduce = EinDim.ReduceType.Stay axises.append(EinDim(dim, reduce)) input_axises.append(axises) - outputs = output.split(',') output_axises = list() for output in outputs: axises = [EinDim(dim) for dim in output] diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index c9eb3bd9..93c8799c 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -101,16 +101,34 @@ def make_expression(self): """ dims = string.ascii_lowercase i1, i2 = self.inputs() - dim1 = [EinDim(dims[d]) for d in range(len(i1.shape))] - if isinstance(i2, IRTensor): - if i2.shape == i1.shape: - dim2 = dim1 + if isinstance(i1, IRTensor) and isinstance(i2, IRTensor): + shape1 = [EinDim(dims[d]) for d in range(len(i1.shape))] + shape2 = [EinDim(dims[d]) for d in range(len(i2.shape))] + if len(i1.shape) == len(i2.shape): + for idx, (dim1, dim2) in enumerate(zip(i1.shape, i2.shape)): + if dim1 != dim2: + shape1[idx] = EinDim(str(dim1), EinDim.ReduceType.Stay) + shape2[idx] = EinDim(str(dim2), EinDim.ReduceType.Stay) else: - raise NotImplementedError(f"Cannot match shape: {i1.shape} and {i2.shape}") - dim2 = list() - self.set_input_ein(0, dim1) - self.set_input_ein(1, dim2) - self.set_output_ein(0, dim1) + if len(i1.shape) == 1: + shape1[0].name = str(i1.shape[0]) + elif len(i2.shape) == 1: + shape2[0].name = str(i2.shape[0]) + out_shape = shape1 if i1.nele() > i2.nele() else shape2 + self.set_input_ein(0, shape1) + self.set_input_ein(1, shape2) + self.set_output_ein(0, out_shape) + else: + if isinstance(i1, IRTensor): + shape1 = [EinDim(dims[d]) for d in range(len(i1.shape))] + self.set_input_ein(0, shape1) + self.set_output_ein(0, shape1) + elif isinstance(i2, IRTensor): + shape2 = [EinDim(dims[d]) for d in range(len(i2.shape))] + self.set_input_ein(1, shape2) + self.set_output_ein(0, shape2) + else: + raise RuntimeError("both inputs {i1} and {i2} are not IRTensor") def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): elew = ElementWise(self.signature, inputs, self.name) @@ -148,6 +166,35 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): return add +class Sub(ElementWise): + """ + torch.add + """ + def __init__(self, signature, inputs, name='sub', **kwargs): + """ + Inputs: + inputs[0]: IRTensor + inputs[1]: other (IRTensor or Number) + inputs[2]: alpha (Number) + Outputs: + same shape as inputs[0] + """ + if len(inputs) != 3: + raise TypeError( + f"Add expected 3 inputs: [tensor, other, alpha], but got {inputs}" + ) + super().__init__(signature, inputs[:2], name=name) + alpha = inputs[2] + self.kwargs['alpha'] = alpha + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + inputs = inputs = self.kwags['alpha'] + add = Sub(self.signature, inputs, self.name) + for idx, output in enumerate(outputs): + add.set_output(idx, output) + return add + + class LayerNorm(IRFwOperation): def __init__(self, signature, inputs, name='layernorm', **kwargs): @@ -436,6 +483,9 @@ def make_expression(self): expr = f'{input}, {weight}, {bias} -> {output}' [idims, wdims, bdims], [odims] = self.parse(expr) + print(idims) + print(wdims) + print(bdims) self.set_input_ein(0, idims) self.set_input_ein(1, wdims) if self.inputs(2) is not None: diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 91c8edd8..caced447 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -52,14 +52,20 @@ def map(signature: str) -> IRFwOperation: __ttemplate('add') : function.Add, + __ttemplate('sub') : function.Sub, + __ttemplate('mul') : partial(function.ElementWise, name='mul'), + __ttemplate('div') : partial(function.ElementWise, name='div'), + __ttemplate('bmm') : function.BatchLinear, __ttemplate('sum') : function.Sum, __ttemplate('transpose') : function.Transpose, + __ttemplate('conv2d'): function.Conv2D, + # complex __customize('toqkv'): partial(function.CubeComplexToQKV, name='toqkv'), diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 71a63797..9184c2ec 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -63,9 +63,12 @@ def parse_module(module, # _ = input('>>>') if len(ir_nodes) != 0: for ir_node in ir_nodes: - ret = ir_node.infer_shape() - if not ret: - print(f'warning: {ir_node} cannot infer shape') + try: + ret = ir_node.infer_shape() + if not ret: + print(f'warning: {ir_node} cannot infer shape') + except Exception: + raise RuntimeError(f"Shape infer error at: {ir_node}") all_ir_nodes += ir_nodes # handle graph output -- Assuming all the output are tensors From 1411bf0f9d02a515f3cce7058d4611006fddd1c3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jan 2022 08:51:32 +0000 Subject: [PATCH 0555/1892] fix conv2d bug; support parser prim construct tuple --- cube/graph/operator/function/function.py | 2 +- cube/graph/parser/parser.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 93c8799c..b46c102f 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -472,7 +472,7 @@ def make_expression(self): kH = self.inputs(1).shape[2] kW = self.inputs(1).shape[3] - iH, iW = self.inputs(0).shape[0:2] + iH, iW = self.inputs(0).shape[2:4] group_channel = self.inputs(0).shape[2] // groups oH = (iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0] + 1 oW = (iH + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1] + 1 diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 9184c2ec..a578c0b9 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -74,6 +74,13 @@ def parse_module(module, # handle graph output -- Assuming all the output are tensors output_var_name = [output.debugName() for output in module.graph.outputs()] output_val = [frame.get_var(var_name) for var_name in output_var_name] + outputs = list() + for val in output_val: + if isinstance(val, list): + outputs += val + else: + outputs.append(val) + output_val = outputs frame.pop() return input_val, all_ir_nodes, output_val @@ -96,6 +103,8 @@ def ntype(node: torch._C.Node): return ScriptNodeKind.PrimListUnpack if node.kind() == 'prim::ListConstruct': return ScriptNodeKind.PrimListConstruct + if node.kind() == 'prim::TupleConstruct': + return ScriptNodeKind.PrimListConstruct if node.kind() == 'prim::TupleUnpack': return ScriptNodeKind.PrimTupleUnpack if node.kind() == 'prim::PythonOp': From 6ce98336f00b8b2f5c9f738724bc008f75fc42c6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jan 2022 09:30:48 +0000 Subject: [PATCH 0556/1892] fix grouping inference bug --- cube/execplan/planpass/grouping.py | 10 +++++----- cube/graph/operator/function/function.py | 3 --- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 3ccb17da..3bf89a86 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -41,6 +41,7 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: execplan.at(devid).insert(idx, subgraph) for node in pieces: execplan.at(devid).remove(node) + print(execplan) return execplan @staticmethod @@ -61,9 +62,7 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: fpieces, bpieces = list(), list() seq = execplan.sequence(devid) fnodes = [fnode for fnode in seq if isinstance(fnode, IRFwOperation)] - have_backward = all( - [isinstance(fnode.mirror, IRBpOperation) for fnode in fnodes] - ) + have_backward = all([fnode.mirror in seq for fnode in fnodes]) # training if have_backward: bnodes = [fnode.mirror for fnode in fnodes] @@ -80,15 +79,16 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: fpieces, bpieces = [fnode], [bnode] # inference else: - for fnode in fnodes: + for fnode in fnodes + [-1]: fconsecutive = Grouping.consecutive(seq, fpieces, fnode) if fconsecutive: fpieces.append(fnode) + bpieces.append(None) else: if len(fpieces) != 0: fgroups[devid].append(fpieces) bgroups[devid].append(None) - fpieces, bpieces = [fnode], list() + fpieces, bpieces = [fnode], [None] return fgroups, bgroups @staticmethod diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index b46c102f..13ba03d4 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -483,9 +483,6 @@ def make_expression(self): expr = f'{input}, {weight}, {bias} -> {output}' [idims, wdims, bdims], [odims] = self.parse(expr) - print(idims) - print(wdims) - print(bdims) self.set_input_ein(0, idims) self.set_input_ein(1, wdims) if self.inputs(2) is not None: From 6a771ee5d2c69ce7261d48a5a0805bf098d75cf8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jan 2022 09:31:12 +0000 Subject: [PATCH 0557/1892] fix bp op set data bug --- cube/graph/operator/operator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 8f6daf36..711f8b1c 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,7 +1,7 @@ from typing import Any, Optional, Union, List import copy -from cube.ir.cten import IRCell +from cube.ir.cten import IRCell, IRTensor from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory from cube.ir.unique import IDGenerator @@ -251,7 +251,8 @@ def set_data(self, data_index: int, val: Any): f"Set the input out of range ({data_index} >= {self.data_num})" ) val = copy.copy(val) - val.attach_cell(self) + if isinstance(val, IRTensor): + val.attach_cell(self) self._datas[data_index] = val return val From 31850f1049e5220d13a6297c21ea31ffdd85d201 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jan 2022 11:48:44 +0000 Subject: [PATCH 0558/1892] fix multi-same-input for one operator --- cube/graph/graph.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 56d6d999..0595dd5b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -191,12 +191,16 @@ def detach(self, node: IRCell, reset_dependency=False) -> int: raise KeyError(f"node {node} is not in graph.") ops = node.nodes() if isinstance(node, IRGraph) else [node] for op in ops: + removed = list() for input in op.inputs(): - if isinstance(input, IRSubTensor): + if isinstance(input, IRSubTensor) and input not in removed: input.parent.rm_consumer(op) + removed.append(input) + removed = list() for output in op.outputs(): - if isinstance(output, IRSubTensor): + if isinstance(output, IRSubTensor) and output not in removed: output.parent.rm_producer(op) + removed.append(output) index = self._nodes.index(node) self._nodes.pop(index) if reset_dependency: From 175066c416e9f29515878c09d24e1390a9847d33 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jan 2022 12:08:20 +0000 Subject: [PATCH 0559/1892] allow train iteration return variables --- cube/compiler.py | 13 +++++++++---- cube/execplan/planpass/grouping.py | 1 - cube/graph/graph.py | 2 +- cube/logics/translator.py | 4 ++-- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index a4b7844d..c8dec8b0 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,4 +1,4 @@ -from typing import Callable, Tuple, Union +from typing import Callable, Tuple, Union, Optional import torch import time @@ -58,7 +58,7 @@ def __call__(self, *args): return self.ir_graph(*args) -def compile(model: SemanticModel, dataloader: CubeDataLoader, +def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None): """ AI Scientist calls like: @@ -88,6 +88,9 @@ def train_step(model, dataloader): """ if not isinstance(model, SemanticModel): raise TypeError("Expect Semantic Model") + if dataloader is None: + # create empty dataloader + dataloader = cube.runtime.syndata.SynDataLoader(shapes=(),dtypes=()) if not isinstance(dataloader, CubeDataLoader): raise TypeError("Expect dataloader derived from CubeDataLoader") if callable(PAS): @@ -123,8 +126,10 @@ def decorator(fn: Callable) -> Callable: resource = cube.runtime.resource.EnvResource() # logic translator - fn(model_graph, ir_dataloader) - graph = LogicTranslator.gen_logic_graph() + outputs = fn(model_graph, ir_dataloader) + if outputs is None: + outputs = [] + graph = LogicTranslator.gen_logic_graph(outputs=outputs) if len(PAS) == 1: graph = PAS[0](graph, resource) diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 3bf89a86..5805f657 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -41,7 +41,6 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: execplan.at(devid).insert(idx, subgraph) for node in pieces: execplan.at(devid).remove(node) - print(execplan) return execplan @staticmethod diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 0595dd5b..f7f66c76 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -166,7 +166,7 @@ def subgraph(self, sub_nodes: List[IRCell]): for t in sub_outputs: if isinstance(t, IRSubTensor): # not consumed or used outside this subgraph - if t not in sub_inputs or t in remain_inputs: + if t not in sub_inputs or t in remain_inputs or t in self.outputs(): if t not in outputs: outputs.append(t) subgraph = IRGraph( diff --git a/cube/logics/translator.py b/cube/logics/translator.py index 6cf1442d..1ac493e7 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -10,12 +10,12 @@ class LogicTranslator: @staticmethod - def gen_logic_graph(): + def gen_logic_graph(outputs=None): """ Generate Training Logic Graph """ nodes = SchedulePool().nodes() - graph = IRGraph(nodes, inputs=[], outputs=None, module_name='LogicGraph') + graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') return graph @staticmethod From 16c2d581e4ef0221ffd0d83dc03087b16b7954e0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 17 Jan 2022 02:56:23 +0000 Subject: [PATCH 0560/1892] sci examples --- cube/codegen/codegen.py | 4 ++ examples/sci/policy/naive.py | 7 ++ examples/sci/sci.py | 120 +++++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 examples/sci/policy/naive.py create mode 100644 examples/sci/sci.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 98a4c145..2b99fff1 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -391,6 +391,10 @@ def gen(self, device: int, outfile=None, attach=False) -> str: name = self.node_naming(node) code = self.emit_node(node, name=name) fb.insert_body(code) + # return code + outputs = self.return_naming(self.execplan.graph.outputs()) + code = f'return {outputs}' + fb.insert_body(code) gencode += fb.code gencode += [''] diff --git a/examples/sci/policy/naive.py b/examples/sci/policy/naive.py new file mode 100644 index 00000000..31d2d2df --- /dev/null +++ b/examples/sci/policy/naive.py @@ -0,0 +1,7 @@ +from cube.graph import IRGraph + +def PAS(graph: IRGraph, resource): + for node in graph.nodes(): + graph.assign(node, 0) + print(graph.extra_repr()) + return graph \ No newline at end of file diff --git a/examples/sci/sci.py b/examples/sci/sci.py new file mode 100644 index 00000000..3e75ef68 --- /dev/null +++ b/examples/sci/sci.py @@ -0,0 +1,120 @@ +from typing import List + +import torch +import torch.nn.functional as F +import time + +torch.set_default_tensor_type(torch.DoubleTensor) + +import cube +from examples.sci.policy.naive import PAS + +""" +OMP_NUM_THREADS=4 torchrun --standalone \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/sci/sci.py +""" + + +class ScientificModel(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, r0: torch.Tensor, p: torch.Tensor, phi: torch.Tensor, + filter: torch.Tensor): + conv_out = F.conv2d(p, filter, padding=1) + alpha = torch.mul(r0, r0).sum() / torch.mul(p, conv_out).sum() + r1 = r0 - alpha * conv_out + # update + phi = phi + alpha * p + r1_sum = torch.mul(r1, r1).sum() + beta = r1_sum / torch.mul(r0, r0).sum() + p = r1 + beta * p + return r1, p, phi, r1_sum + + +class LoopVariables(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): + + shapes = [list(var.size()) for var in variables + constants] + dtypes = [var.dtype for var in variables + constants] + batch_dims = [0] * (len(variables) + len(constants)) + super().__init__(shapes, dtypes, batch_dims) + self.variables = list() + self.constants = list() + for var in variables: + if torch.is_tensor(var) and var.device != torch.cuda.current_device(): + var = var.cuda() + self.variables.append(var) + for const in constants: + if torch.is_tensor(const) and const.device != torch.cuda.current_device(): + const = const.cuda() + self.constants.append(const) + + def __iter__(self): + return self + + def update(self, variables: List[torch.Tensor] = None, constants: List[torch.Tensor] = None): + if variables is not None: + self.variables = variables + if constants is not None: + self.constants = constants + + def reset(self, batch_size): + pass + + def __next__(self): + if len(self.variables) + len(self.constants) == 1: + return (self.variables + self.constants)[0] + return tuple(self.variables + self.constants) + + +def train_loop(): + # initialize + N = 1024 * 2 + filter = torch.tensor( + [[0., 1., 0.], + [1., -4., 1.], + [0., 1., 0.]] + ).view(1, 1, 3, 3) + rho = F.conv2d(torch.ones((1, 1, N, N)), filter, padding=1) + phi = torch.zeros((1, 1, N, N)) + r0 = rho - F.conv2d(phi, filter, padding=1) + p = r0 + + varloader = LoopVariables(variables=[r0, p, phi], constants=[filter]) + model = ScientificModel() + model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes),) + + @cube.compile(model=model, dataloader=varloader, PAS=PAS) + def train_iter(model, dataloader): + r0, p, phi, filter = next(dataloader) + r0, p, phi, r1_sum = model(r0, p, phi, filter) + return r0, p, phi, r1_sum + model = model.get_gen_module() + + start = time.time() + + counter = 0 + while True: + counter += 1 + r0, p, phi, r1_sum = train_iter(model, varloader) + varloader.update(variables=[r0, p, phi]) + if counter % 100 == 0: + print('iters:\t', counter) + print('rnorm:\t', torch.sqrt(r1_sum)) + if torch.sqrt(r1_sum) < 1e-10: + print('**************** Converged ****************') + print('iters:\t', counter) + torch.cuda.synchronize() + print('time:\t', time.time() - start) + print('error:\t', torch.norm(phi - torch.ones((1, 1, N, N)).cuda())) + break + + +if __name__ == '__main__': + cube.init() + train_loop() \ No newline at end of file From 8f243a6dedab7b24e6293d641dae45f235f492fb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 17 Jan 2022 11:59:17 +0000 Subject: [PATCH 0561/1892] add check on return values --- cube/execplan/execplan.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 00f34b50..515a61ca 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -5,6 +5,7 @@ from cube.ir.cten import IRCell from cube.graph.graph import IRGraph +from cube.graph.tensor import IRFullTensor, IRSubTensor class ExectuionPlan: @@ -22,6 +23,19 @@ def __init__(self, graph: IRGraph): self.device_seq[device] = [node] else: self.device_seq[device].append(node) + # check whether graph output is replicated across device + # FIXME: should use adapter to generate communication for + # traning logic output + for output in graph.outputs(): + devices = self.devices() + ltensor: IRFullTensor = output.parent # logic tensor + if isinstance(output, IRSubTensor): + for ptensor, producer in zip(ltensor.ptensors, ltensor.producers): + if ptensor == output: + if producer.device[0] in devices: + devices.remove(producer.device[0]) + if len(devices) != 0: + raise NotImplementedError("Require return values of training logic is replicated across nodes.") def devices(self) -> List[int]: """ From 86bdab18c5eeb5209e5984b63dc9c6c8f868c4fa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 18 Jan 2022 09:09:06 +0000 Subject: [PATCH 0562/1892] init cutomize op registeration --- cube/graph/operator/function/function.py | 65 ++++++++++++++++++++++++ cube/graph/parser/__init__.py | 3 +- cube/graph/parser/mapping.py | 12 +++++ cube/graph/parser/register.py | 43 ++++++++++++++++ 4 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 cube/graph/parser/register.py diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 13ba03d4..447984ca 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -500,6 +500,71 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): return op +class CustomizeEinop(IREinops): + """ + Customize Einop + """ + def __init__(self, signature: str, inputs, name, **kwargs): + expected = ['anno', 'stay', 'kwarg_idx', 'kwarg_name'] + if not all([attr not in kwargs for attr in expected]): + raise KeyError("Expected anno, kwarg_idx, kwarg_name for UDF function") + self.anno: str = kwargs['anno'] + self.stay: List[str] = kwargs['stay'] + # get input output + input_anno, output_anno = self.anno.split('->') + ninputs = len(input_anno.split(',')) + noutputs = len(output_anno.split(',')) + self.kwarg_idx: List[int] = kwargs['kwarg_idx'] + self.kwarg_name: List[str] = kwargs['kwarg_name'] + kwarg_inputs = [inputs[idx] for idx in self.kwarg_idx] + op_inputs = [input for input in inputs if input not in kwarg_inputs] + if len(kwarg_inputs) + ninputs != len(inputs): + raise ValueError( + f"Got {len(inputs)} inputs but kwarg inputs" + f"({len(kwarg_inputs)}) + anno inputs ({ninputs}) doesn't match" + ) + super().__init__( + name, signature, + input_length=ninputs, + output_length=noutputs + ) + for name, kinput in zip(self.kwarg_name, kwarg_inputs): + self.kwargs[name] = kinput + for idx, input in enumerate(op_inputs): + self.set_input(idx, input) + + def make_expression(self): + ishapes, oshapes = self.parse(self.anno) + for idx, ishape in enumerate(ishapes): + for idim in ishape: + if idim.name in self.stay: + idim.reduce = EinDim.ReduceType.Stay + self.set_input_ein(idx, ishape) + for idx, oshape in enumerate(oshapes): + for odim in oshape: + if odim.name in self.stay: + odim.reduce = EinDim.ReduceType.Stay + self.set_output_ein(idx, oshape) + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + kwargs = dict( + anno = self.anno, + stay = self.stay, + kwarg_idx = self.kwarg_idx, + kwarg_name = self.kwarg_name, + ) + all_inputs = [None] * (len(self.inputs) + len(self.kwarg_idx)) + remain_idx = list(range(len(all_inputs))) + for idx, name in zip(self.kwarg_idx, self.kwarg_name): + all_inputs[idx] = self.kwargs[name] + remain_idx.remove(idx) + for idx, input in zip(remain_idx, inputs): + all_inputs[idx] = input + op = CustomizeEinop(self.signature, all_inputs, self.name, **kwargs) + for idx, output in enumerate(outputs): + op.set_output(idx, output) + return op + # ===================== Cube Complex Operation ======================= class CubeComplexToQKV(IRFwOperation): diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index ded01f96..cd7b3a25 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,2 +1,3 @@ from cube.graph.parser.parser import ScriptModuleParser -from cube.graph.parser.converter import convert_model, convert_dataloader \ No newline at end of file +from cube.graph.parser.converter import convert_model, convert_dataloader +from cube.graph.parser.register import register \ No newline at end of file diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index caced447..71ad8d34 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -25,6 +25,18 @@ def map(signature: str) -> IRFwOperation: # print(f'warning: {signature} is not recognized') # return partial(function.UnkownOperator, signature=signature) + @staticmethod + def register(signature: str, op: IRFwOperation): + """ + Register an operator + """ + if not isinstance(signature, str): + raise TypeError(f"Expected signature to be str but got {type(signature)}") + if signature in Sign2Op.kOpMap: + raise KeyError(f"function {signature} is already registered") + print(f'registering op {signature}...') + Sign2Op.kOpMap[signature] = op + # functional templates __ftemplate = lambda name: f'torch.nn.functional.{name}' diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py new file mode 100644 index 00000000..fac34fae --- /dev/null +++ b/cube/graph/parser/register.py @@ -0,0 +1,43 @@ +""" +Register cutomized function +""" + +from functools import partial +from typing import Callable, List +import inspect +import torch + +from cube.graph.operator.function import CustomizeEinop + +from cube.graph.parser.mapping import Sign2Op + + +def register(anno: str, stay: List[str] = None): + """ + Register a function with einop annotations. + """ + if stay is None: + stay = list() + + def decorator(fn: Callable): + if not callable(fn): + raise TypeError("Expected a function") + args = inspect.signature(fn) + arg_names = list(args.parameters.keys()) + arg_kind = [args.parameters[name].annotation for name in arg_names] + func_name = fn.__name__ + kwarg_idx = list() + kwarg_name = list() + for idx, (name, kind) in enumerate(zip(arg_names, arg_kind)): + if kind != torch.Tensor: + kwarg_name.append(name) + kwarg_idx.append(idx) + udfop = partial(CustomizeEinop, + name=func_name, + anno=anno, stay=stay, + kwarg_idx=kwarg_idx, kwarg_name=kwarg_name + ) + Sign2Op.register(func_name, udfop) + return fn + + return decorator From 4bed586c56bfa0a38c46eb25d80ebfebffe4e93f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 18 Jan 2022 09:09:27 +0000 Subject: [PATCH 0563/1892] attention using register-op --- examples/attention/attention.py | 110 ++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 40 deletions(-) diff --git a/examples/attention/attention.py b/examples/attention/attention.py index c68409b3..652e5905 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -24,9 +24,70 @@ from cube.profiler.timer import print_each_rank +@cube.graph.operator.register('L N E, (3 h d) E -> L N (h d)', stay=['L', 'd', 'E']) +def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, num_head: int, + scale: float, dropout: float, training: bool): + """ + L: sequence length + N: batch size + E: embedding size + x: hidden state: [L, N, E] + wqkv: qkv weight: [3 * (num_head * dim_head), E] + dropout: float + num_head: int + """ + L, N = x.shape[0], x.shape[1] + dim_head = wqkv.shape[0] // 3 // num_head + # L N E, (3 h d) E -> L N (3 h d) + qkv = F.linear(x, wqkv, None) + # L N (3 h d) -> L N (h d), L N (h d), L N (h d) + q, k, v = qkv.chunk(3, dim=-1) + # L N (h d) -> L (N h) d + q = q.contiguous().view(L, (N * num_head), dim_head) + # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) + # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) + # L (N h) d -> (N h) L d + q = q.transpose(0, 1) + # L (N h) d -> (N h) L d + k = k.transpose(0, 1) + # L (N h) d -> (N h) L d + v = v.transpose(0, 1) + # (N h) L d, 1 -> (N h) L d + q = q * scale + # (N h) L d -> (N h) d L + k = k.transpose(-2, -1) + # (N h) L d, (N h) d L -> (N h) L L + attn = torch.bmm(q, k) + + # attention mask + # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N, num_head), L, L) + + # (N h) L L -> (N h) L L + attn = F.softmax(attn, dim=-1) + # (N h) L L -> (N h) L L + if training: + attn = F.dropout(attn, dropout, True, False) + # (N h) L L, (N h) L d -> (N h) L d + output = torch.bmm(attn, v) + # (N h) L d -> L (N h) d + output = output.transpose(0, 1).contiguous() + # L (N h) d -> L N (h d) + output = output.view(L, N, num_head * dim_head) + return output + + class MultiHeadSelfAttention(nn.Module): - def __init__(self, seq_len, embed_dim, heads, dropout): + def __init__(self, seq_len, embed_dim, heads, dropout: float): super().__init__() self.seq_len = seq_len @@ -35,55 +96,24 @@ def __init__(self, seq_len, embed_dim, heads, dropout): self.dim_head = embed_dim // heads self.scale = self.dim_head ** -0.5 - self.weight_qkv = torch.nn.Parameter(torch.empty( + self.wqkv = torch.nn.Parameter(torch.empty( 3 * embed_dim, embed_dim )) - self.weight_out = torch.nn.Parameter(torch.empty( + self.wout = torch.nn.Parameter(torch.empty( embed_dim, embed_dim )) - self.dropout = nn.Dropout(dropout) + self.dropout = dropout def forward(self, x): """ x: [L, N, E]: seq_len, batch_size, embedding dimension output: [L, N, E] """ - # [L, N, E] -> 3 x [L, (N * num_head), dim_head] - q, k, v = cube.runtime.function.toqkv( - x, self.weight_qkv, self.num_head - ) - - # [L, (N * num_head), dim_head] -> [(N * num_head), L, dim_head] - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - # [(N * num_head), L, dim_head] -> [(N * num_head), L, dim_head] - q = q * self.scale - # [(N * num_head), L, dim_head] -> [(N * num_head), dim_head, L] - k = k.transpose(-2, -1) - # [(N * num_head), L, dim_head] * [(N * num_head), dim_head, L] - # -> [(N * num_head), L, L] - attn = torch.bmm(q, k) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = cube.runtime.function.tril_mask(attn, self.num_head) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = F.softmax(attn, dim=-1) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = self.dropout(attn) - # [(N * num_head), L, L] * [(N * num_head), L, dim_head] - # -> [(N * num_head), L, dim_head] - output = torch.bmm(attn, v) - - # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] - output = cube.runtime.function.attn_view(output, self.num_head) - - # [L, N, num_head * dim_head] * [E, embed_head * dim_head] - # -> [L, N, E] - output = F.linear(output, self.weight_out) + # L N E, (3 h d) E -> L N (h d) + output = attnfc1(x, self.wqkv, self.num_head, + self.scale, self.dropout, self.training) + # L N (h d), E (h d) -> L N E + output = F.linear(output, self.wout) loss = torch.sum(output) return loss From 0f4d25063de258e768d2d6ad7d0a76b5b57125ca Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 18 Jan 2022 10:52:20 +0000 Subject: [PATCH 0564/1892] fix register bug --- cube/graph/parser/mapping.py | 1 - cube/graph/parser/register.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 71ad8d34..8375b8ae 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -34,7 +34,6 @@ def register(signature: str, op: IRFwOperation): raise TypeError(f"Expected signature to be str but got {type(signature)}") if signature in Sign2Op.kOpMap: raise KeyError(f"function {signature} is already registered") - print(f'registering op {signature}...') Sign2Op.kOpMap[signature] = op # functional templates diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index fac34fae..c2d8b73c 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -32,6 +32,7 @@ def decorator(fn: Callable): if kind != torch.Tensor: kwarg_name.append(name) kwarg_idx.append(idx) + print(f'registering op {func_name} with {len(args.parameters) - len(kwarg_idx)} inputs and {len(kwarg_idx)} kwargs...') udfop = partial(CustomizeEinop, name=func_name, anno=anno, stay=stay, From 32cb4537b7ce699d7ff627aae37e25287f17a71c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 18 Jan 2022 11:26:55 +0000 Subject: [PATCH 0565/1892] update eindims --- cube/graph/operator/function/einops.py | 32 ++++++++++++++++++------ cube/graph/operator/function/function.py | 4 +-- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 0bbdb96c..05a31695 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -2,8 +2,7 @@ This operator class is highly inspired by eniops. """ import enum -import string -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from cube.ir.cten import IRTensor from cube.graph.operator.operator import IRFwOperation @@ -16,11 +15,30 @@ class ReduceType(enum.Enum): Stay = 0 # the dim is not allowed to be split Sum = 1 - def __init__(self, name: str, reduce=None): - if not (str.isidentifier(name) or str.isnumeric(name) or name == '*'): - raise ValueError("Einstein Axis name should be identifier") - self.name: str = name - self.reduce: Optional[EinDim.ReduceType] = reduce + def __init__(self, name: Union[str, List[str]], reduce=None): + if isinstance(name, str): + name = [name] + for n in name: + if not (str.isidentifier(n) or str.isnumeric(n) or n == '*'): + raise ValueError("Einstein Axis name should be identifier") + self._name: List[str] = name + self._reduce: Optional[EinDim.ReduceType] = reduce + + @property + def name(self) -> str: + if len(self._name) == 1: + return self._name[0] + return '(' + ' '.join(self._name) + ')' + + @property + def reduce(self) -> str: + return self._reduce + + @reduce.setter + def reduce(self, val): + if not isinstance(val, EinDim.ReduceType): + raise TypeError("Expected EinDim.ReduceType") + self._reduce = val def __eq__(self, other): if isinstance(other, EinDim): diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 447984ca..2f069732 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -506,8 +506,8 @@ class CustomizeEinop(IREinops): """ def __init__(self, signature: str, inputs, name, **kwargs): expected = ['anno', 'stay', 'kwarg_idx', 'kwarg_name'] - if not all([attr not in kwargs for attr in expected]): - raise KeyError("Expected anno, kwarg_idx, kwarg_name for UDF function") + if not all([attr in kwargs for attr in expected]): + raise KeyError("Expected anno, stay, kwarg_idx, kwarg_name for UDF function") self.anno: str = kwargs['anno'] self.stay: List[str] = kwargs['stay'] # get input output From c2061dbe4f60e989e813a520e048fb3ff7067c84 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 19 Jan 2022 01:57:34 +0000 Subject: [PATCH 0566/1892] einops name setting --- cube/graph/operator/function/einops.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 05a31695..16a53c73 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -30,6 +30,15 @@ def name(self) -> str: return self._name[0] return '(' + ' '.join(self._name) + ')' + @name.setter + def name(self, val: Union[str, List[str]]): + if isinstance(val, str): + self._name = [val] + elif all([isinstance(n, str) for n in val]): + self._name = list(val) + else: + raise TypeError("Expected Union[str, List[str] for name") + @property def reduce(self) -> str: return self._reduce From a704e1b5ef82e54e6aa6302116b8211d28f469ee Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jan 2022 07:20:45 +0000 Subject: [PATCH 0567/1892] use function to einops --- cube/algorithm/factory.py | 23 +- cube/algorithm/ops/einops.py | 22 +- cube/graph/operator/function/einops.py | 391 +++++++--- cube/graph/operator/function/function.py | 917 ++++------------------- 4 files changed, 426 insertions(+), 927 deletions(-) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index b1cc8c01..2a6f8579 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -69,8 +69,8 @@ def _load_predefined_algos(self): # self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') # self.register(elew.Add, elew.AddDimParallel, tag='dim') - import cube.algorithm.ops.layernorm as ln - self.register(ln.LayerNorm, ln.LayerNormDimParallel, tag='dim') + # import cube.algorithm.ops.layernorm as ln + # self.register(ln.LayerNorm, ln.LayerNormDimParallel, tag='dim') # import cube.algorithm.ops.activation as activation # self.register(activation.Activation, activation.ActivationDimParallel, tag='dim') @@ -80,24 +80,5 @@ def _load_predefined_algos(self): # import cube.algorithm.ops.reduce as reduce # self.register(reduce.Sum, reduce.SumDimParallel, tag='dim') - import cube.algorithm.ops.complex as complex - self.register(complex.CubeComplexToQKV, complex.CubeToQKVDataParallel, tag='data') - self.register(complex.CubeComplexToQKV, complex.CubeToQKVHeadParallel, tag='head') - - self.register(complex.CubeComplexTrilMask, complex.CubeTrilMaskDataParallel, tag='data') - self.register(complex.CubeComplexTrilMask, complex.CubeTrilMaskHeadParallel, tag='head') - - self.register(complex.CubeComplexAttnView, complex.CubeAttnViewDataParallel, tag='data') - self.register(complex.CubeComplexAttnView, complex.CubeAttnViewHeadParallel, tag='head') - - self.register(complex.CubeComplexSelfAttention, complex.CubeSelfAttentionDataParallel, tag='data') - self.register(complex.CubeComplexSelfAttention, complex.CubeSelfAttentionHeadParallel, tag='head') - - self.register(complex.CubeComplexFeedForward, complex.CubeFeedForwardDataParallel, tag='data') - self.register(complex.CubeComplexFeedForward, complex.CubeFeedForwardTensorParallel, tag='tensor') - - self.register(complex.CubeComplexEmbedding, complex.CubeEmbedDataParallel, tag='data') - self.register(complex.CubeComplexEmbedding, complex.CubeEmbedShardingParallel, tag='shard') - # import cube.algorithm.ops.memory as mem # self.register(mem.Transpose, mem.TransposeDimParallel, tag='dim') diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py index 423f2496..32413a9e 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/einops.py @@ -46,7 +46,7 @@ def satisfy(self, config: Dict): return False if node.inputs(idx).shape[dim] % num != 0: return False - if node._ieins[idx][dim].reduce == EinDim.ReduceType.Stay: + if node._iannos[idx][dim].reduce == EinDim.ReduceType.Stay: return False return True @@ -57,18 +57,18 @@ def instantiate(self, config: Dict) -> List[IREinops]: idx: int = config['idx'] dim: int = config['dim'] num: int = config['num'] - axis: EinDim = node._ieins[idx][dim] + edim: EinDim = node._iannos[idx][dim] # print(f'splitting: {node.einexpr()}') ins, ous = list(), list() for iidx, input in enumerate(node.inputs()): - if axis in node._ieins[iidx]: - dim = node._ieins[iidx].index(axis) + if edim in node._iannos[iidx]: + dim = node._iannos[iidx].index(edim) sub_tensors = split_axis(input, dim, num) ins.append(sub_tensors) else: - if axis.is_reduce(): + if edim.reduce[0] == EinDim.ReduceType.Sum: # print(f'Warning: value split on one input tensor in node{node._id}:{node.name} as reduce axis {axis} not appeared.') ins.append(split_value(input, num)) else: @@ -76,17 +76,17 @@ def instantiate(self, config: Dict) -> List[IREinops]: for oidx, output in enumerate(node.outputs()): # split on the non-reduce axis, the output value keeps same # but the output shape gets splitted - if axis in node._oeins[oidx]: - dim = node._oeins[oidx].index(axis) - if axis.is_reduce(): + if edim in node._oannos[oidx]: + dim = node._oannos[oidx].index(edim) + if edim.reduce[0] == EinDim.ReduceType.Sum: raise RuntimeError(f"Reduced axis {dim} appeared in output") sub_tensors = split_axis(output, dim, num) ous.append(sub_tensors) # split on the reduce axis, the output shape keeps same # but the output value get splitted else: - if not axis.is_reduce(): - raise RuntimeError(f"Expect axis {axis} to be reduced axis") + if edim.reduce[0] != EinDim.ReduceType.Sum: + raise RuntimeError(f"Expect axis {edim} to be reduced axis") sub_tensors = split_value(output, num) ous.append(sub_tensors) @@ -95,7 +95,7 @@ def instantiate(self, config: Dict) -> List[IREinops]: inputs = [t[nid] for t in ins] outputs = [t[nid] for t in ous] sub_node: IREinops = node.new(inputs, outputs) - sub_node.make_expression() + sub_node.infer_shape() sub_nodes.append(sub_node) return sub_nodes diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 16a53c73..777a9b9f 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -1,8 +1,12 @@ """ This operator class is highly inspired by eniops. """ +from typing import Callable, Dict, List, Union +from typing import Optional, Set, Tuple, Optional import enum -from typing import List, Optional, Tuple, Union +import re +import copy +import string from cube.ir.cten import IRTensor from cube.graph.operator.operator import IRFwOperation @@ -10,19 +14,46 @@ class EinDim: + """ + To represent a dimension, name = {identifier}{reducetype} + e.g., + ab^ means the dimension name is 'ab' and is a frozen dimension (cannot be split) + ab+ means the dimension name is 'ab' and this dimension is a reduce dimension + ['b', 'c+', 'd^'] means the dimension is composed by b, c, d + where b can be spatially partitioned (apear in output), c is a reduce dimension, + d is a frozen dimension (cannot be split) + """ class ReduceType(enum.Enum): - Stay = 0 # the dim is not allowed to be split - Sum = 1 + Spatial='' + Stay = '^' # the dim is not allowed to be split + Sum = '+' - def __init__(self, name: Union[str, List[str]], reduce=None): + def __init__(self, name: Union[str, List[str]]): if isinstance(name, str): name = [name] + self._name: List[str] = list() + self._reduce: List[EinDim.ReduceType] = list() + self._length: Dict[str, Optional[int]] = dict() for n in name: - if not (str.isidentifier(n) or str.isnumeric(n) or n == '*'): - raise ValueError("Einstein Axis name should be identifier") - self._name: List[str] = name - self._reduce: Optional[EinDim.ReduceType] = reduce + # complex name cannot have * + if len(name) > 1 and '*' in n: + raise ValueError("Einstein Axis name cannot have * for multiple inner-dimension") + # get reduce type + reduce = EinDim.ReduceType.Spatial + if n[-1] == EinDim.ReduceType.Sum.value: + reduce = EinDim.ReduceType.Sum + n = n[:-1] + elif n[-1] == EinDim.ReduceType.Stay: + reduce = EinDim.ReduceType.Stay + n = n[:-1] + # get identifier name + if len(n) == 0 or not (str.isidentifier(n) or str.isnumeric(n) or n == '*'): + raise ValueError(f"EinDim name {n} should be identifier") + self._name.append(n) + self._reduce.append(reduce) + for n in self._name: + self._length[n] = None @property def name(self) -> str: @@ -30,24 +61,14 @@ def name(self) -> str: return self._name[0] return '(' + ' '.join(self._name) + ')' - @name.setter - def name(self, val: Union[str, List[str]]): - if isinstance(val, str): - self._name = [val] - elif all([isinstance(n, str) for n in val]): - self._name = list(val) - else: - raise TypeError("Expected Union[str, List[str] for name") - @property def reduce(self) -> str: return self._reduce - - @reduce.setter - def reduce(self, val): - if not isinstance(val, EinDim.ReduceType): - raise TypeError("Expected EinDim.ReduceType") - self._reduce = val + + def setlen(self, anno: str, dim: int): + if anno not in self._name: + raise KeyError(f"Cannot find anno: {anno} in {self.name}") + self._length[anno] = dim def __eq__(self, other): if isinstance(other, EinDim): @@ -59,83 +80,260 @@ def is_reduce(self): return self.reduce == EinDim.ReduceType.Sum def __repr__(self): - return self.name if not self.is_reduce() else self.name + "+" + name_reduce = [name + reduce.value for name, reduce in zip(self._name, self._reduce)] + if len(self._name) == 0: + return self._name[0] + self._reduce[0].value + return '(' + ' '.join(name_reduce) + ')' -class IREinops(IRFwOperation): - """ - Einstein expression on operators like reshape, view, permute, reduce. - """ - def __init__(self, name: str, signature: str, input_length: int, output_length:int): - super().__init__(name, signature, input_length, output_length) - self._ieins = [list() for _ in range(input_length)] - self._oeins = [list() for _ in range(output_length)] +class EinopAnno: - def new(self, inputs, outputs, **kwargs): + def __init__(self, anno: str): """ - Create a new same operation given the inputs and outputs - - Each operator needs to implement this. + initializing annotations specfied in str, e.g., + a (b c) d, d k -> a (b c) k """ - raise NotImplementedError + if not isinstance(anno, str): + raise TypeError("Expected anno to be str") + self.anno = anno + if '->' not in self.anno: + raise ValueError("Expected -> in anno") + # to inputs and outputs + inputs, outputs = self.anno.split('->') + inputs = inputs.split(',') + outputs = outputs.split(',') + # to eindims + self._identifiers: Set[str] = set() + self.inputs: List[List[EinDim]] = [ + self.parse_shape(shape) for shape in inputs + ] + self.outputs: List[List[EinDim]] = [ + self.parse_shape(shape) for shape in outputs + ] - def make_expression(self): + def parse_shape(self, shape: str) -> List[EinDim]: """ - Set einstein-like expression assuming input shapes are given. - - Each operator needs to implement this. + parsing annotations like of a single shape, e.g., + a (b dim) d """ - raise NotImplementedError - - def infer_shape(self): + # => ['a', '(', 'b', 'dim', ')', 'd'] + shapes = list() + for group in re.split('\ +', shape): + if len(group) == 0: + continue + if '(' in group or ')' in group: + for group in re.split('([\(\)])', group): + if len(group) != 0: + shapes.append(group) + else: + shapes.append(group) + identifiers: List[List[str]] = list() + current_identifier = list() + bracket_group = False + for w in shapes: + if w == '(': + if bracket_group: + raise RuntimeError("brackets inside brackets not allowed") + bracket_group = True + identifiers.append(current_identifier) + current_identifier = list() + elif w == ')': + if not bracket_group: + raise RuntimeError("backets are not balanced at (") + bracket_group = False + identifiers.append(current_identifier) + current_identifier = list() + else: + if bracket_group: + current_identifier.append(w) + self._identifiers.add(w) + else: + if len(current_identifier) != 0: + identifiers.append(current_identifier) + current_identifier = [w] + self._identifiers.add(w) + if bracket_group: + raise RuntimeError("brackets are not balanced at )") + if len(current_identifier) != 0: + identifiers.append(current_identifier) + identifiers = [EinDim(identifer) for identifer in identifiers] + return identifiers + + def identifiers(self) -> Set[str]: + return copy.copy(self._identifiers) + + def __repr__(self) -> str: + inputs = ', '.join([repr(input) for input in self.inputs]) + outputs = ', '.join(repr(output) for output in self.outputs) + return inputs + ' -> ' + outputs + + + +class IREinops(IRFwOperation): + """ + Einstein-inspired notation operations + """ + def __init__(self, signature: str, annos: List[Union[str, Tuple[str, Callable]]], + inputs: List, name: str, **kwargs): + noutputs = set() + self._annos: List[EinopAnno] = list() + self._adapt: List[Union[Callable, None]] = list() + for anno in annos: + if isinstance(anno, tuple): + anno, adapt = anno + elif isinstance(anno, str): + adapt = None + else: + raise TypeError("Expected annos to be list of tuples of list of str") + anno = EinopAnno(anno) + self._annos.append(anno) + self._adapt.append(adapt) + noutputs.add(len(anno.outputs)) + self._iannos: List[List[EinDim]] = None + self._oannos: List[List[EinDim]] = None + + if len(noutputs) != 1: + raise ValueError("Annotations should have same output length") + super().__init__(name, signature, len(inputs), list(noutputs)[0]) + # set input + for idx, input in enumerate(inputs): + self.set_input(idx, input) + for name in kwargs: + self.kwargs[name] = kwargs[name] + + def infer_shape(self) -> bool: """ - Infer output value shape + Shape inference by mathcing dimension annotations. + Assume input shape is given """ - for input in self.inputs(): - if isinstance(input, IRTensor) and input.shape is None: - return False - self.make_expression() - # check expression - for input, ein_dims in zip(self.inputs(), self._ieins): - if len(ein_dims) == 0 or ein_dims is None: - if isinstance(input, IRTensor): - raise RuntimeError(f"{self}: {input} has no ein-dims but is a tensor") - if len(ein_dims) != 0: - if not isinstance(input, IRTensor): - raise RuntimeError(f"{self}: {input} has ein-dims but is not a tensor") - if len(input.shape) != len(ein_dims): - raise RuntimeError(f"input tensor ndims ({len(input.shape)}) != ein-dims ({len(ein_dims)})") + # try parsing given anno candidates + ret = False + for anno, adapt in zip(self._annos, self._adapt): + if adapt is not None: + anno = adapt(anno, self) + ret, iannos, oannos = self.parse(anno) + self._iannos = iannos + self._oannos = oannos + if ret: break + if not ret: + raise RuntimeError("No matching anno for given annos") + dimlen: Dict[str, int] = dict() + for input, ishape in zip(self.inputs(), self._iannos): + print(input.shape, ishape) + if not ((ishape is None and not isinstance(input, IRTensor)) or + len(ishape) == len(input.shape)): + raise RuntimeError(f"node {self._id}: error match input: {input.shape} and einshape: {ishape}") + for tdim, edim in zip(input.shape, ishape): + if len(edim._name) == 1: + if edim.name in dimlen and dimlen[edim.name] != tdim: + raise RuntimeError(f"op: {self.signature} has different shape for same dim annotation {edim.name}") + dimlen[edim.name] = tdim + edim.setlen(edim.name, tdim) + else: + toinfer = list() + accum = 1 + for name in edim._name: + if str.isnumeric(name): + accum *= int(name) + edim.setlen(name, int(name)) + dimlen[name] = int(name) + elif name in self.kwargs: + accum *= self.kwargs[name] + edim.setlen(name, self.kwargs[name]) + dimlen[name] = self.kwargs[name] + else: + toinfer.append(name) + if len(toinfer) > 1: + raise RuntimeError(f"Expected indication of dimension {toinfer} from kwargs") + if len(toinfer) == 1: + edim.setlen(toinfer[0], tdim // accum) + dimlen[toinfer[0]] = tdim // accum # figure output shape for oidx in range(len(self._outputs)): output_shape = list() - for oein in self._oeins[oidx]: - if str.isdecimal(oein.name): - output_shape.append(int(oein.name)) - continue - for iidx in range(len(self._inputs)): - if oein in self._ieins[iidx]: - input = self.inputs(iidx) - dim = self._ieins[iidx].index(oein) - output_shape.append(input.shape[dim]) - break + for odim in self._oannos[oidx]: + accum = 1 + for name in odim._name: + if str.isdecimal(name): + accum *= int(name) + else: + if name not in dimlen: + raise KeyError(f"Dim annotation {name} not in input") + accum *= dimlen[name] + odim.setlen(name, dimlen[name]) + output_shape.append(accum) self.outputs(oidx).shape = output_shape - return True + return ret - def set_input_ein(self, input_index: int, dims: List[EinDim]): + def new(self, inputs: List, outputs: List): """ - Set input einstein axis at input index + construct a new operator sharing same kwargs with new inputs + and outputs """ - if not all([isinstance(dim, EinDim) for dim in dims]): - raise TypeError("Expected Tuple[EinDim]") - self._ieins[input_index] = tuple(dims) + annos = list() + for anno, adapt in zip(self._annos, self._adapt): + annos.append((anno.anno, adapt)) + op = IREinops(self.signature, annos, inputs, self.name, **self.kwargs) + for idx, output in enumerate(outputs): + op.set_output(idx, output) + return op - def set_output_ein(self, output_index: int, dims: Tuple[EinDim]): + def parse(self, anno: EinopAnno): """ - Set output einstein axis at output index + parse annotations, assuming input tensor shape is given """ - if not all([isinstance(dim, EinDim) for dim in dims]): - raise TypeError("Expected Tuple[EinDim]") - self._oeins[output_index] = tuple(dims) + # copy + anno = EinopAnno(anno.anno) + if len(anno.inputs) != len(self.inputs()): + return False, None, None + identifiers = anno.identifiers() + # expand * + expand_dims = None + if '*' in identifiers: + for idx in range(len(anno.inputs)): + shape = anno.inputs[idx] + shape_anno = [dim.name for dim in shape] + if '*' in shape_anno: + start = shape_anno.index('*') + span = len(self.inputs(idx).shape) - len(shape) + 1 + if span <= 0: + if expand_dims is None: + expand_dims = list() + if len(expand_dims) > 0: + return False, None, None + anno.inputs[idx].remove(EinDim('*')) + if span > 0: + if expand_dims is None: + expand_dims = list() + unused_annos = [c for c in string.ascii_lowercase if c not in identifiers] + if len(unused_annos) < span: + raise RuntimeError("Too many introduced dimensions") + for dim in range(span): + expand_dims.append(EinDim([unused_annos[dim]])) + if len(expand_dims) != span: + return False, None, None + anno.inputs[idx] = anno.inputs[idx][:start] + expand_dims + anno.inputs[idx][start+1:] + for idx in range(len(anno.outputs)): + shape = anno.outputs[idx] + shape_anno = [dim.name for dim in shape] + if '*' in shape_anno: + if expand_dims is None: + raise RuntimeError("* should appear in inputs") + start = shape_anno.index('*') + span = len(expand_dims) + anno.outputs[idx] = anno.outputs[idx][:start] + expand_dims + anno.outputs[idx][start+1:] + # check dimension consistency + dimlen: Dict[str, int] = dict() + for shape, input in zip(anno.inputs, self.inputs()): + if not isinstance(input, IRTensor): + if shape.name != '1': + return False, None, None + for dim, nele in zip(shape, input.shape): + if dim.name in dimlen: + if nele != dimlen[dim.name]: + return False, None, None + dimlen[dim.name] = nele + return True, anno.inputs, anno.outputs def einexpr(self) -> str: inputs = list() @@ -163,36 +361,3 @@ def algorithms(self, tag: Optional[str] = None): template = factory.algorithms(IREinops, tag) return template(self) return None - - def parse(self, expr: str) -> Tuple[List[List[EinDim]], List[List[EinDim]]]: - """ - parse string like: - b m k, b k n -> b m n - """ - if not isinstance(expr, str): - raise TypeError("Expected string") - # split to inputs and outputs - if expr.count('->') != 1: - raise ValueError("string must contain one ->") - # split to each tensor - input, output = expr.split('->') - inputs = input.split(',') - outputs = output.split(',') - inputs = [[dim for dim in input.split(' ') if len(dim) != 0] for input in inputs] - outputs = [[dim for dim in output.split(' ') if len(dim) != 0] for output in outputs] - # parse each tensor - input_axises = list() - for input in inputs: - axises = list() - for dim in input: - reduce = EinDim.ReduceType.Sum if dim not in output else None - # a fixed numeric value indicates the axis is not splittable - if str.isnumeric(dim): - reduce = EinDim.ReduceType.Stay - axises.append(EinDim(dim, reduce)) - input_axises.append(axises) - output_axises = list() - for output in outputs: - axises = [EinDim(dim) for dim in output] - output_axises.append(axises) - return input_axises, output_axises diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 2f069732..f01313f8 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -1,786 +1,139 @@ -import copy -from typing import List -import string - -from cube.graph.operator import IRFwOperation -from cube.graph.operator.function.einops import EinDim, IREinops -from cube.ir.cten import IRTensor - - -class Linear(IREinops): - """ - b * k, n k -> b * n - """ - def __init__(self, signature, inputs, name='linear', **kwargs): - input, weight, bias = inputs - super().__init__( - name, signature, - input_length=3, - output_length=1 - ) - self.set_input(0, input) - self.set_input(1, weight) - self.set_input(2, bias) - - def make_expression(self): - expr = 'b * k, n k, n -> b * n' - [idims, wdims, bdims], [odims] = self.parse(expr) - if len(self.inputs(0).shape) == 2: - idims = [idims[0], idims[2]] - odims = [odims[0], odims[2]] - else: - extra_dims = list() - num_extra_dim = len(self.inputs(0).shape) - 2 - dims = [c for c in string.ascii_lowercase if c not in 'bkn'] - for num in range(num_extra_dim): - extra_dims.append(EinDim(dims[num])) - idims = [idims[0]] + extra_dims + [idims[-1]] - odims = [odims[0]] + extra_dims + [odims[-1]] - self.set_input_ein(0, idims) - self.set_input_ein(1, wdims) - if self.inputs(2) is not None: - self.set_input_ein(2, bdims) - self.set_output_ein(0, odims) - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - linear = Linear(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - linear.set_output(idx, output) - return linear - - -class BatchLinear(IREinops): - """ - b m k, b k n -> b m n - """ - def __init__(self, signature, inputs, name='bmm', **kwargs): - if len(inputs) != 2: - raise TypeError(f"Requires 2 inputs. But got {inputs}") - input1, input2 = inputs - super().__init__( - name, signature, - input_length=2, - output_length=1 - ) - self.set_input(0, input1) - self.set_input(1, input2) - - def make_expression(self): - expr = 'b m k, b k n -> b m n' - input_dims, output_dims = self.parse(expr) - for idx, input_dim in enumerate(input_dims): - self.set_input_ein(idx, input_dim) - for idx, output_dim in enumerate(output_dims): - self.set_output_ein(idx, output_dim) - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - bmm = BatchLinear(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - bmm.set_output(idx, output) - return bmm - - -class ElementWise(IREinops): - """ - *, _ -> * - """ - - def __init__(self, signature, inputs, name='elementwise', **kwargs): - if len(inputs) != 2: - raise TypeError(f"Expected 2 inputs but got {inputs}") - super().__init__( - name, signature, - input_length=2, - output_length=1 - ) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - - def make_expression(self): - """ - """ - dims = string.ascii_lowercase - i1, i2 = self.inputs() - if isinstance(i1, IRTensor) and isinstance(i2, IRTensor): - shape1 = [EinDim(dims[d]) for d in range(len(i1.shape))] - shape2 = [EinDim(dims[d]) for d in range(len(i2.shape))] - if len(i1.shape) == len(i2.shape): - for idx, (dim1, dim2) in enumerate(zip(i1.shape, i2.shape)): - if dim1 != dim2: - shape1[idx] = EinDim(str(dim1), EinDim.ReduceType.Stay) - shape2[idx] = EinDim(str(dim2), EinDim.ReduceType.Stay) - else: - if len(i1.shape) == 1: - shape1[0].name = str(i1.shape[0]) - elif len(i2.shape) == 1: - shape2[0].name = str(i2.shape[0]) - out_shape = shape1 if i1.nele() > i2.nele() else shape2 - self.set_input_ein(0, shape1) - self.set_input_ein(1, shape2) - self.set_output_ein(0, out_shape) - else: - if isinstance(i1, IRTensor): - shape1 = [EinDim(dims[d]) for d in range(len(i1.shape))] - self.set_input_ein(0, shape1) - self.set_output_ein(0, shape1) - elif isinstance(i2, IRTensor): - shape2 = [EinDim(dims[d]) for d in range(len(i2.shape))] - self.set_input_ein(1, shape2) - self.set_output_ein(0, shape2) - else: - raise RuntimeError("both inputs {i1} and {i2} are not IRTensor") - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - elew = ElementWise(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - elew.set_output(idx, output) - return elew - - -class Add(ElementWise): - """ - torch.add - """ - def __init__(self, signature, inputs, name='add', **kwargs): - """ - Inputs: - inputs[0]: IRTensor - inputs[1]: other (IRTensor or Number) - inputs[2]: alpha (Number) - Outputs: - same shape as inputs[0] - """ - if len(inputs) != 3: - raise TypeError( - f"Add expected 3 inputs: [tensor, other, alpha], but got {inputs}" - ) - super().__init__(signature, inputs[:2], name=name) - alpha = inputs[2] - self.kwargs['alpha'] = alpha - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - inputs = inputs = self.kwags['alpha'] - add = Add(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - add.set_output(idx, output) - return add - - -class Sub(ElementWise): - """ - torch.add - """ - def __init__(self, signature, inputs, name='sub', **kwargs): - """ - Inputs: - inputs[0]: IRTensor - inputs[1]: other (IRTensor or Number) - inputs[2]: alpha (Number) - Outputs: - same shape as inputs[0] - """ - if len(inputs) != 3: - raise TypeError( - f"Add expected 3 inputs: [tensor, other, alpha], but got {inputs}" - ) - super().__init__(signature, inputs[:2], name=name) - alpha = inputs[2] - self.kwargs['alpha'] = alpha - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - inputs = inputs = self.kwags['alpha'] - add = Sub(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - add.set_output(idx, output) - return add - - -class LayerNorm(IRFwOperation): - - def __init__(self, signature, inputs, name='layernorm', **kwargs): - - if len(inputs) != 5: - raise TypeError(f"Expected 5 inputs, but got: {inputs}") - input = inputs[0] - normalized_shape = inputs[1] - if not isinstance(normalized_shape, list): - raise TypeError(f"Expected list of int, but got: {type(normalized_shape)}") - weight = inputs[2] - bias = inputs[3] - eps = inputs[4] - - inputs = [input, normalized_shape, weight, bias] - super().__init__(name, signature, input_length=4, output_length=1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs['eps'] = eps - - def infer_shape(self): - if self.inputs(0).shape is None: - return False - self.outputs(0).shape = self.inputs(0).shape - return True - - -class Activation(IREinops): - """ - functions like GELU, RELU, Dropout. - - Exclude softmax - """ - - def __init__(self, signature, inputs, name='activation', **kwargs): - - if len(inputs) != 1: - raise TypeError("Expected single tensor input") - - super().__init__( - name, signature, - input_length=1, - output_length=1 - ) - self.set_input(0, inputs[0]) - - def make_expression(self): - """ - * -> * - """ - dims = string.ascii_lowercase - dim1 = [EinDim(dims[d]) for d in range(len(self.inputs(0).shape))] - self.set_input_ein(0, dim1) - self.set_output_ein(0, dim1) - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - op = Activation(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - op.set_output(idx, output) - return op - - -class GELU(Activation): - """ - torch.nn.functional.gelu(input, approximate: bool = False) - - Note `approximate` argument is new at pytorch version v1.11 - """ - def __init__(self, signature, inputs, name='gelu', **kwargs): - - super().__init__(signature, [inputs[0]], name) - if len(inputs) == 2: - self.kwargs['approximate'] = inputs[1] - self.set_input(0, inputs[0]) - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - if 'approximate' in self.kwargs: - inputs.append(self.kwargs['approximate']) - op = GELU(self.signature, inputs, self.name) - op.set_output(0, outputs[0]) - return op - - -class Dropout(Activation): - """ - torch.nn.functional.dropout - """ - def __init__(self, signature, inputs, name='dropout', **kwargs): - - if len(inputs) != 4: - raise TypeError(f"Expected 4 inputs but got {inputs}") - super().__init__(signature, [inputs[0]], name) - self.set_input(0, inputs[0]) - self.kwargs['p'] = inputs[1] - self.kwargs['training'] = inputs[2] - self.kwargs['inplace'] = inputs[3] - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - inputs = inputs + [self.kwargs['p'], self.kwargs['training'], self.kwargs['inplace']] - op = Dropout(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - op.set_output(idx, output) - return op - - -class Softmax(Activation): - - def __init__(self, signature, inputs, name='softmax', **kwargs): - - if len(inputs) != 4: - raise TypeError(f"Expected 4 inputs, but got: {inputs}") - - tensor, dim, stacklevel, dtype = inputs[0], inputs[1], inputs[2], inputs[3] - super().__init__(signature, inputs=[inputs[0]], name=name) - self.set_input(0, tensor) - self.kwargs['dim'] = dim - self.kwargs['_stacklevel'] = stacklevel - self.kwargs['dtype'] = dtype - - def make_expression(self): - super().make_expression() - dim = self.kwargs['dim'] - self._ieins[0][dim].reduce = EinDim.ReduceType.Stay - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - inputs = inputs + [self.kwargs['dim'], self.kwargs['_stacklevel'], self.kwargs['dtype']] - op = Dropout(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - op.set_output(idx, output) - return op - -# ===================== Loss Computation (Reduce) ========================= - -class Sum(IREinops): - """ - torch.sum - """ - def __init__(self, signature, inputs, name='sum', **kwargs): - - if len(inputs) <= 1: - raise TypeError(f"Expected at least 2 inputs, but got {inputs}") - if inputs[1] is not None and not isinstance(inputs[1], int): - raise TypeError(f"Expected inputs[1] to be None or int, but got {type(inputs[1])}") - - super().__init__( - name, signature, - input_length=1, - output_length=1 - ) - self.set_input(0, inputs[0]) - if inputs[1] is not None: - self.kwargs['dim'] = inputs[1] - if len(inputs) > 2: - self.kwargs['keepdim'] = inputs[2] - else: - self.kwargs['keepdim'] = False - - def make_expression(self): - """ - * -> 1 (no extra kwarg) - a b c -> a c (dim b) - a b c -> a 1 c (dim b and keepdim) - """ - reducedim = None if 'dim' not in self.kwargs else self.kwargs['dim'] - keepdim = False if 'keepdim' not in self.kwargs else self.kwargs['keepdim'] - input = self.inputs(0) - dims = string.ascii_lowercase - in_dim = [ - EinDim(dims[d]) for d in range(len(input.shape))] - ou_dim = copy.copy(in_dim) - if reducedim is not None: - in_dim[reducedim].reduce = EinDim.ReduceType.Sum - if keepdim: - ou_dim[reducedim] = EinDim('1') - else: - ou_dim.pop(reducedim) - else: - for dim in in_dim: - dim.reduce = EinDim.ReduceType.Sum - ou_dim = [EinDim('1')] - self.set_input_ein(0, in_dim) - self.set_output_ein(0, ou_dim) - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - reducedim = None if 'dim' not in self.kwargs else self.kwargs['dim'] - keepdim = False if 'keepdim' not in self.kwargs else self.kwargs['keepdim'] - inputs += [reducedim] - if reducedim is not None: - if keepdim: - inputs += [keepdim] - sum_op = Sum(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - sum_op.set_output(idx, output) - return sum_op - -# ========================= Memory Operation ========================== - -class Transpose(IREinops): - """ - torch.transpose - """ - def __init__(self, signature, inputs, name='transpose', **kwargs): - - if len(inputs) != 3: - raise RuntimeError("expected 3 inputs ") - - if not isinstance(inputs[1], int): - raise TypeError(f"Expected 1st input: int, but got {type(inputs[1])}") - if not isinstance(inputs[2], int): - raise TypeError(f"Expected 1st input: int, but got {type(inputs[2])}") - - super().__init__( - name, signature, - input_length=1, - output_length=1 - ) - self.set_input(0, inputs[0]) - self.kwargs['dim0'] = inputs[1] - self.kwargs['dim1'] = inputs[2] - - def make_expression(self): - """ - similar like a b c -> a c b - """ - dims = string.ascii_lowercase - dim0 = self.kwargs['dim0'] - dim1 = self.kwargs['dim1'] - input = self.inputs(0) - in_dim = [EinDim(dims[d]) for d in range(len(input.shape))] - ou_dim = copy.copy(in_dim) - ou_dim[dim0], ou_dim[dim1] = in_dim[dim1], in_dim[dim0] - self.set_input_ein(0, in_dim) - self.set_output_ein(0, ou_dim) - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - dim0 = self.kwargs['dim0'] - dim1 = self.kwargs['dim1'] - inputs += [dim0, dim1] - op = Transpose(self.signature, inputs, self.name) - for idx, output in enumerate(outputs): - op.set_output(idx, output) - return op - - -class Conv2D(IREinops): +from cube.graph.operator.function.einops import EinDim, EinopAnno, IREinops + + +def Linear(signature, inputs): + annos = [ + 'b * k+, n k+ -> b * n', # no bias + 'b * k+, n k+, n -> b * n' # have bias + ] + return IREinops(signature, annos, inputs, 'linear') + + +def BatchLinear(signature, inputs): + annos = [ + 'b m k, b k n -> b m n' + ] + return IREinops(signature, annos, inputs, 'bmm') + + +def Add(signature, inputs): + assert len(inputs) == 3 + inputs, alpha = inputs[0:2], inputs[2] + # TODO: support broadcast + annos = [ + '*, 1 -> *', + '1, * -> *', + '*, * -> *', + ] + return IREinops(signature, annos, inputs, 'add', alpha=alpha) + + +def Sub(signature, inputs): + assert len(inputs) == 3 + inputs, alpha = inputs[0:2], inputs[2] + # TODO: support broadcast + annos = [ + '*, 1 -> *', + '1, * -> *', + '*, * -> *', + ] + return IREinops(signature, annos, inputs, 'sub', alpha=alpha) + + +def Mul(signature, inputs): + annos = [ + '*, 1 -> *', + '1, * -> *', + '*, * -> *', + ] + return IREinops(signature, annos, inputs, 'mul') + + +def Div(signature, inputs): + annos = [ + '*, 1 -> *', + '1, * -> *', + '*, * -> *', + ] + return IREinops(signature, annos, inputs, 'div') + + +def GeLU(signature, inputs): + annos = ['* -> *'] + tensor = inputs[0:1] + if len(inputs) == 2: + # adapt for newest pytorch version + approximate = inputs[1] + return IREinops(signature, annos, tensor, 'gelu', + approximate=approximate) + else: + return IREinops(signature, annos, tensor, 'gelu') + + +def Softmax(signature, inputs): + annos = ['* -> *'] + tensor = inputs[0:1] + dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] + return IREinops(signature, annos, tensor, 'softmax', + dim=dim, _stacklevel=_stacklevel, dtype=dtype) + + +def Dropout(signature, inputs): + annos = [ + '* -> *' + ] + tensor = inputs[0:1] + p, training, inplace = inputs[1], inputs[2], inputs[3] + return IREinops(signature, annos, tensor, 'dropout', + p=p, traning=training, inplace=inplace) + + +def Sum(signature, inputs): + # TODO: support dim reduction + annos = [ + '* -> 1', + ] + tensor = inputs[0:1] + dim = inputs[1] + if dim is not None: + keepdim = inputs[2] if len(inputs) > 2 else False + return IREinops(signature, annos, tensor, 'sum', + dim=dim, keepdim=keepdim) + else: + return IREinops(signature, annos, tensor, 'sum') + + +def Transpose(signature, inputs): + def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: + dim0, dim1 = node.kwargs[0], node.kwargs[1] + anno.outputs[0][dim0], anno.outputs[0][dim1] = \ + anno.inputs[0][dim1], anno.inputs[0][dim0] + return anno + annos = [('* -> *', adapt),] + inputs, dim0, dim1 = inputs[0:1], inputs[1], inputs[2] + return IREinops(signature, annos, inputs, 'transpose', + dim0=dim0, dim1=dim1) + + +def Conv2D(signature, inputs): """ torch.conv2d(input, weight, bias, stride, padding, dialation, groups) https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html?highlight=torch%20conv2d#torch.nn.functional.conv2d """ - def __init__(self, signature, inputs, name='conv2d', **kwargs): - if len(inputs) != 7: - raise RuntimeError(f"expected 7 operators for conv2d but got {len(inputs)}") - super().__init__( - name, signature, - input_length=3, - output_length=1 - ) - for idx, input in enumerate(inputs[:3]): - self.set_input(idx, input) - self.kwargs['stride'] = inputs[3] - self.kwargs['padding'] = inputs[4] - self.kwargs['dilation'] = inputs[5] - self.kwargs['groups'] = inputs[6] - - def make_expression(self): - input = 'N I {iH} {iW}' - weight = 'O {group_channel} {kH} {kW}' - bias = 'O' - output = 'N O {oH} {oW}' - # parameters - groups = self.kwargs['groups'] - stride = self.kwargs['stride'] - padding = self.kwargs['padding'] - dilation = self.kwargs['dilation'] - kH = self.inputs(1).shape[2] - kW = self.inputs(1).shape[3] - - iH, iW = self.inputs(0).shape[2:4] - group_channel = self.inputs(0).shape[2] // groups - oH = (iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0] + 1 - oW = (iH + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1] + 1 - - input = input.format(iH=iH, iW=iW) - weight = weight.format(group_channel=group_channel, kH=kH, kW=kW) - output = output.format(oH=oH, oW=oW) - - expr = f'{input}, {weight}, {bias} -> {output}' - [idims, wdims, bdims], [odims] = self.parse(expr) - self.set_input_ein(0, idims) - self.set_input_ein(1, wdims) - if self.inputs(2) is not None: - self.set_input_ein(2, bdims) - self.set_output_ein(0, odims) - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - groups = self.kwargs['groups'] - stride = self.kwargs['stride'] - padding = self.kwargs['padding'] - dilation = self.kwargs['dilation'] - inputs += [groups, stride, padding, dilation] - op = Conv2D(self.signature, inputs, self.name) - op.set_output(0, outputs[0]) - return op - - -class CustomizeEinop(IREinops): - """ - Customize Einop - """ - def __init__(self, signature: str, inputs, name, **kwargs): - expected = ['anno', 'stay', 'kwarg_idx', 'kwarg_name'] - if not all([attr in kwargs for attr in expected]): - raise KeyError("Expected anno, stay, kwarg_idx, kwarg_name for UDF function") - self.anno: str = kwargs['anno'] - self.stay: List[str] = kwargs['stay'] - # get input output - input_anno, output_anno = self.anno.split('->') - ninputs = len(input_anno.split(',')) - noutputs = len(output_anno.split(',')) - self.kwarg_idx: List[int] = kwargs['kwarg_idx'] - self.kwarg_name: List[str] = kwargs['kwarg_name'] - kwarg_inputs = [inputs[idx] for idx in self.kwarg_idx] - op_inputs = [input for input in inputs if input not in kwarg_inputs] - if len(kwarg_inputs) + ninputs != len(inputs): - raise ValueError( - f"Got {len(inputs)} inputs but kwarg inputs" - f"({len(kwarg_inputs)}) + anno inputs ({ninputs}) doesn't match" - ) - super().__init__( - name, signature, - input_length=ninputs, - output_length=noutputs - ) - for name, kinput in zip(self.kwarg_name, kwarg_inputs): - self.kwargs[name] = kinput - for idx, input in enumerate(op_inputs): - self.set_input(idx, input) - - def make_expression(self): - ishapes, oshapes = self.parse(self.anno) - for idx, ishape in enumerate(ishapes): - for idim in ishape: - if idim.name in self.stay: - idim.reduce = EinDim.ReduceType.Stay - self.set_input_ein(idx, ishape) - for idx, oshape in enumerate(oshapes): - for odim in oshape: - if odim.name in self.stay: - odim.reduce = EinDim.ReduceType.Stay - self.set_output_ein(idx, oshape) - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - kwargs = dict( - anno = self.anno, - stay = self.stay, - kwarg_idx = self.kwarg_idx, - kwarg_name = self.kwarg_name, - ) - all_inputs = [None] * (len(self.inputs) + len(self.kwarg_idx)) - remain_idx = list(range(len(all_inputs))) - for idx, name in zip(self.kwarg_idx, self.kwarg_name): - all_inputs[idx] = self.kwargs[name] - remain_idx.remove(idx) - for idx, input in zip(remain_idx, inputs): - all_inputs[idx] = input - op = CustomizeEinop(self.signature, all_inputs, self.name, **kwargs) - for idx, output in enumerate(outputs): - op.set_output(idx, output) - return op - -# ===================== Cube Complex Operation ======================= - -class CubeComplexToQKV(IRFwOperation): - """ - Inputs: - hidden_state: [L, N, E] - weight: [3 * (num_head * dim_head), E] - num_head: int - dim_head: int - - where L = sequence length, N = batch size, E = num_head * dim_head - - Returns: - Q: [L, N * num_head, dim_head] - K: [L, N * num_head, dim_head] - V: [L, N * num_head, dim_head] - """ - def __init__(self, signature, inputs, name='toqkv', **kwargs): - if len(inputs) != 3: - raise TypeError(f"Expected 3 arguments but goit {inputs}") - qkv, weight = inputs[0], inputs[1] - super().__init__( - name, signature, - input_length=2, - output_length=3 - ) - self.set_input(0, qkv) - self.set_input(1, weight) - self.kwargs['num_head'] = inputs[2] - - def infer_shape(self): - if self.inputs(0).shape is None or self.inputs(1) is None: - return False - seqlen = self.inputs(0).shape[0] - bs = self.inputs(0).shape[1] - num_head = self.kwargs['num_head'] - dim_head = self.inputs(1).shape[0] // 3 // num_head - - shape = [seqlen, bs * num_head, dim_head] - for output in self.outputs(): - output.shape = shape - return True - - -class CubeComplexTrilMask(IRFwOperation): - """ - Inputs: - input: [N * num_head, L, L] - num_head: int - - Returns: - output: [N * num_head, L, L] - """ - def __init__(self, signature, inputs, name='trilmask', **kwargs): - if len(inputs) != 2: - raise TypeError("Expected 2 input") - tensor, num_head = inputs[0], inputs[1] - super().__init__( - name, signature, - input_length=1, - output_length=1 - ) - self.set_input(0, tensor) - self.kwargs['num_head'] = num_head - - def infer_shape(self): - if self.inputs(0).shape is None: - return False - self._outputs[0].shape = self.inputs(0).shape - return True - - -class CubeComplexAttnView(IRFwOperation): - """ - Inputs: - [N * num_head, L, dim_head] - - Outputs: - [L, N, num_head * dim_head] - """ - def __init__(self, signature, inputs, name='attn_view', **kwargs): - if len(inputs) != 2: - raise TypeError("Expected 2 input") - tensor, num_head = inputs[0], inputs[1] - super().__init__( - name, signature, - input_length=1, - output_length=1 - ) - self.set_input(0, tensor) - self.kwargs['num_head'] = num_head - - def infer_shape(self): - if self.inputs(0).shape is None: - return False - num_head = self.kwargs['num_head'] - bs = self.inputs(0).shape[0] // num_head - seqlen = self.inputs(0).shape[1] - dim_head = self.inputs(0).shape[2] - shape = [seqlen, bs, num_head * dim_head] - self._outputs[0].shape = shape - return True - - -class CubeComplexSelfAttention(IRFwOperation): - """ - Multi-Head Self-Attention. - - L: sequence length - N: batch size - E: embedding size - - Inputs: - hidden_state: [L, N, E] - w_qkv : [3 * num_head * dim_head, E] - w_out : [E, E] - num_head: int - dim_head: int - dropout_p: float - - Outputs: - hidden_state: [L, N, E] - """ - def __init__(self, signature, inputs, name='selfattn', **kwargs): - if len(inputs) != 6: - raise RuntimeError(f"Expected 6 inputs but got {input}") - num_head: int = inputs[3] - dim_head: int = inputs[4] - dropout_p: float = inputs[5] - super().__init__( - name, signature, - input_length = 3, - output_length = 1 - ) - for idx, tensor in enumerate(inputs[:3]): - self.set_input(idx, tensor) - self.kwargs['num_head'] = num_head - self.kwargs['dim_head'] = dim_head - self.kwargs['dropout_p'] = dropout_p - - def infer_shape(self): - if self.inputs(0).shape is None: - return False - self.outputs(0).shape = self.inputs(0).shape - return True - - -class CubeComplexFeedForward(IRFwOperation): - """ - FeedForward - - Inputs: - hidden_state: [L, N, E] - w_proj1: [4 * E, E] - w_bias1: [4 * E,] - w_porj2: [E, 4 * E] - w_bias2: [E,] - - Outputs: - hidden_state: [L, N, E] - """ - def __init__(self, signature, inputs, name='selfattn', **kwargs): - if len(inputs) != 5: - raise RuntimeError(f"Expected 6 inputs but got {inputs}") - super().__init__( - name, signature, - input_length = 5, - output_length = 1 - ) - for idx, tensor in enumerate(inputs): - self.set_input(idx, tensor) - - def infer_shape(self): - if self.inputs(0).shape is None: - return False - self.outputs(0).shape = self.inputs(0).shape - return True - - -class CubeComplexEmbedding(IRFwOperation): - """ - Embedding - """ - def __init__(self, signature, inputs, name='embedding', **kwargs): - if len(inputs) != 4: - raise RuntimeError(f"Expected 4 inputs but got {inputs}") - input, weight = inputs[0], inputs[1] - start, stop = inputs[2], inputs[3] - super().__init__( - name, signature, - input_length = 2, - output_length = 1 - ) - self.set_input(0, input) - self.set_input(1, weight) - self.kwargs['start'] = start - self.kwargs['stop'] = stop - - def infer_shape(self): - if self.inputs(0).shape is None or self.inputs(1).shape is None: - return False - self.outputs(0).shape = self.inputs(0).shape + [self.inputs(1).shape[1]] - return True - - -class UnkownOperator(IRFwOperation): - - def __init__(self, signature, inputs, name='unknown_op', n_outputs=None): - - super().__init__( - name, signature=signature, - input_length=len(inputs), - output_length=n_outputs, - ) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - - def infer_shape(self): - return False + def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: + iH, iW = node.inputs(0).shape[2:4] + stride = node.kwargs['stride'] + padding = node.kwargs['padding'] + dilation = node.kwargs['dilation'] + dH = node.inputs(1).shape[2] + dW = node.inputs(1).shape[3] + oH = (iH + 2 * padding[0] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 + oW = (iW + 2 * padding[1] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 + anno.outputs[0][2] = EinDim([str(oH)]) + anno.outputs[0][3] = EinDim([str(oW)]) + return anno + annos = [('N iC H W, oC GiC dH dW, oC -> N oC oH oW', adapt)] + tensors = inputs[0:3] + stride, padding, dilation, groups = inputs[3:] + return IREinops(signature, annos, tensors, 'conv2d', + stride=stride, padding=padding, dilation=dilation, groups=groups) From 26d9134f73c94d02d1c9e9614997f531b9ad082a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jan 2022 08:08:26 +0000 Subject: [PATCH 0568/1892] einop documents --- cube/graph/operator/function/einops.py | 66 +++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 777a9b9f..10e4b545 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -1,6 +1,54 @@ """ -This operator class is highly inspired by eniops. +This operator class is highly inspired by einops. + +* Annotating Dimensions: + + e.g., 'a+', 'ab^', 'cd', '(ab+ c^ d)', '64' + +A dimension of a tensor can be annotated by {identifier}{reduce} template. + +An `identifier` must be one of: + 1) symbolic annotation that must match with the criteria of python str.isidentifier. + 2) numeric string that must match with python str.isnumeric. This indicates the shape is the same value + numeric string will always have '^' reduction type + 3) '*': this special value indicates the dimension is dynamic will automatically get expanded given the shape + +A `reduce` can be a set of {'', '+', '^'}: + '' indicates this dimension will apear in output. + '+' indicates no this dimension will be reduced in output using sume + '^' means this dimension is out of scope, Einops will not handle this (cannot do split on it) + +A complex annotation for a dimension is using brackets, i.e., '(' and ')', to include +more inner-dimensions. The value of inner dimension must be (partially) indicated by function args (of same name) +so that letting system know (infer). + +* Annotating Operator: + +e.g., 'm k+, n k+ -> m n', '4 k+, k+ d -> 8 d', '* d^, s -> * s' + +An operator dimension can be annoted with input dimensions and output dimensions. +Same identifier indicates the same shape and semantically same dimension propagation. + +'->' seperates the inputs (left) and outputs (right) and ',' separates each input and output. +A shape needs to be annotated using dimension annotations with delimiters of (mulitple) space ' '. + +Dimension annotations in Output must apear in inputs, or using numeric string + +* Splitting Rule: + +Spatial Splitting (dimension with '' reduce type): + tensors that have this dimension will be splitted spatially. + tensors that don't have this dimension will be replicated. + +Numerical Splitting (dimension with '+' reduce type): + tensors that have this dimension will be splitted spatially, + tensors that don't have this dimension will be splitted numerically + +Illegal Splitting (dimension with '^' reduce type): + Illegal splitting algorithm on this dimension. + """ + from typing import Callable, Dict, List, Union from typing import Optional, Set, Tuple, Optional import enum @@ -26,8 +74,8 @@ class EinDim: class ReduceType(enum.Enum): Spatial='' - Stay = '^' # the dim is not allowed to be split Sum = '+' + Stay = '^' # the dim is not allowed to be split def __init__(self, name: Union[str, List[str]]): if isinstance(name, str): @@ -50,6 +98,8 @@ def __init__(self, name: Union[str, List[str]]): # get identifier name if len(n) == 0 or not (str.isidentifier(n) or str.isnumeric(n) or n == '*'): raise ValueError(f"EinDim name {n} should be identifier") + if str.isnumeric(n): + reduce = EinDim.ReduceType.Stay self._name.append(n) self._reduce.append(reduce) for n in self._name: @@ -61,6 +111,9 @@ def name(self) -> str: return self._name[0] return '(' + ' '.join(self._name) + ')' + def names(self) -> List[str]: + return copy.copy(self._name) + @property def reduce(self) -> str: return self._reduce @@ -219,7 +272,6 @@ def infer_shape(self) -> bool: raise RuntimeError("No matching anno for given annos") dimlen: Dict[str, int] = dict() for input, ishape in zip(self.inputs(), self._iannos): - print(input.shape, ishape) if not ((ishape is None and not isinstance(input, IRTensor)) or len(ishape) == len(input.shape)): raise RuntimeError(f"node {self._id}: error match input: {input.shape} and einshape: {ishape}") @@ -338,10 +390,10 @@ def parse(self, anno: EinopAnno): def einexpr(self) -> str: inputs = list() outputs = list() - for iein in self._ieins: - inputs.append(' '.join([repr(ein) for ein in iein])) - for oein in self._oeins: - outputs.append(' '.join([repr(ein) for ein in oein])) + for shape in self._iannos: + inputs.append(' '.join([repr(edim) for edim in shape])) + for shape in self._oannos: + outputs.append(' '.join([repr(edim) for edim in shape])) return ', '.join(inputs) + ' -> ' + ', '.join(outputs) def algorithms(self, tag: Optional[str] = None): From 1c20111b15c57dafe40c04045083991d3cd46a6b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jan 2022 08:24:23 +0000 Subject: [PATCH 0569/1892] fix einop parse bug --- cube/graph/operator/function/einops.py | 19 ++++++++++--------- cube/graph/operator/function/function.py | 7 ++++++- cube/graph/parser/mapping.py | 22 ++++------------------ cube/graph/parser/parser.py | 4 ++-- cube/graph/parser/register.py | 20 +++++++++++++++++--- 5 files changed, 39 insertions(+), 33 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 10e4b545..de135146 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -274,7 +274,7 @@ def infer_shape(self) -> bool: for input, ishape in zip(self.inputs(), self._iannos): if not ((ishape is None and not isinstance(input, IRTensor)) or len(ishape) == len(input.shape)): - raise RuntimeError(f"node {self._id}: error match input: {input.shape} and einshape: {ishape}") + raise RuntimeError(f"node {self._id} {self.signature}: error match input: {input.shape} and einshape: {ishape}") for tdim, edim in zip(input.shape, ishape): if len(edim._name) == 1: if edim.name in dimlen and dimlen[edim.name] != tdim: @@ -334,8 +334,6 @@ def parse(self, anno: EinopAnno): """ parse annotations, assuming input tensor shape is given """ - # copy - anno = EinopAnno(anno.anno) if len(anno.inputs) != len(self.inputs()): return False, None, None identifiers = anno.identifiers() @@ -378,13 +376,16 @@ def parse(self, anno: EinopAnno): dimlen: Dict[str, int] = dict() for shape, input in zip(anno.inputs, self.inputs()): if not isinstance(input, IRTensor): - if shape.name != '1': + if not (len(shape) != 1 and shape[0].name != '1'): return False, None, None - for dim, nele in zip(shape, input.shape): - if dim.name in dimlen: - if nele != dimlen[dim.name]: - return False, None, None - dimlen[dim.name] = nele + else: + if len(input.shape) != len(shape): + return False, None, None + for edim, nele in zip(shape, input.shape): + if edim.name in dimlen: + if nele != dimlen[edim.name]: + return False, None, None + dimlen[edim.name] = nele return True, anno.inputs, anno.outputs def einexpr(self) -> str: diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index f01313f8..daf6cd98 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -132,8 +132,13 @@ def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: anno.outputs[0][2] = EinDim([str(oH)]) anno.outputs[0][3] = EinDim([str(oW)]) return anno - annos = [('N iC H W, oC GiC dH dW, oC -> N oC oH oW', adapt)] + annos = [ + ('N iC H W, oC GiC dH dW, oC -> N oC oH oW', adapt), + ('N iC H W, oC GiC dH dW -> N oC oH oW', adapt), + ] tensors = inputs[0:3] + if tensors[-1] is None: + tensors = inputs[0:2] stride, padding, dilation, groups = inputs[3:] return IREinops(signature, annos, tensors, 'conv2d', stride=stride, padding=padding, dilation=dilation, groups=groups) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 8375b8ae..edb2c6bb 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -55,9 +55,9 @@ def register(signature: str, op: IRFwOperation): __ftemplate('dropout') : function.Dropout, - __ftemplate('gelu') : function.GELU, + __ftemplate('gelu') : function.GeLU, - __ftemplate('layer_norm'): function.LayerNorm, + # __ftemplate('layer_norm'): function.LayerNorm, # torch aten @@ -65,9 +65,9 @@ def register(signature: str, op: IRFwOperation): __ttemplate('sub') : function.Sub, - __ttemplate('mul') : partial(function.ElementWise, name='mul'), + __ttemplate('mul') : function.Mul, - __ttemplate('div') : partial(function.ElementWise, name='div'), + __ttemplate('div') : function.Div, __ttemplate('bmm') : function.BatchLinear, @@ -77,20 +77,6 @@ def register(signature: str, op: IRFwOperation): __ttemplate('conv2d'): function.Conv2D, - # complex - - __customize('toqkv'): partial(function.CubeComplexToQKV, name='toqkv'), - - __customize('tril_mask'): function.CubeComplexTrilMask, - - __customize('attn_view'): function.CubeComplexAttnView, - - __customize('self_attn'): function.CubeComplexSelfAttention, - - __customize('feedforward'): function.CubeComplexFeedForward, - - __customize('embedding'): function.CubeComplexEmbedding, - } diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index a578c0b9..351c53bd 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -178,7 +178,7 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: val = frame.get_var(var_name) input_vals.append(val) try: - ir_node = Sign2Op.map(fsig)(inputs=input_vals, n_outputs=len(outputs)) + ir_node = Sign2Op.map(fsig)(inputs=input_vals) except Exception: # print(module.code) raise RuntimeError(f"Parsing error of {node}") @@ -216,7 +216,7 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: input_val.append(val) # create IR node - ir_node = Sign2Op.map(fsig)(inputs=input_val, n_outputs=len(outputs)) + ir_node = Sign2Op.map(fsig)(inputs=input_val) if len(ir_node.outputs()) != len(outputs): raise RuntimeError( f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index c2d8b73c..62b652a9 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -7,7 +7,7 @@ import inspect import torch -from cube.graph.operator.function import CustomizeEinop +from cube.graph.operator.function.einops import IREinops from cube.graph.parser.mapping import Sign2Op @@ -15,6 +15,20 @@ def register(anno: str, stay: List[str] = None): """ Register a function with einop annotations. + + This function is cooperated with CustomizeEinop. + User needs to define a python function with type annotations + for each input argument. And user needs to pass dimension annotations + as well as (optional) frozen split dimensions (i.e., the dimensions cannot split). + + For EinDims containing brackets (e.g., (3 h d)), + user should have same argument name in the function definition + to help system infer each dim length, e.g., + + @cube.register('a (b c) -> (a b) c') + def funcname(x: torch.Tensor, b: int = 4): + xxx + """ if stay is None: stay = list() @@ -33,9 +47,9 @@ def decorator(fn: Callable): kwarg_name.append(name) kwarg_idx.append(idx) print(f'registering op {func_name} with {len(args.parameters) - len(kwarg_idx)} inputs and {len(kwarg_idx)} kwargs...') - udfop = partial(CustomizeEinop, + udfop = partial(IREinops, name=func_name, - anno=anno, stay=stay, + anno=[anno], kwarg_idx=kwarg_idx, kwarg_name=kwarg_name ) Sign2Op.register(func_name, udfop) From 53ed9bbf5a2e9dee2185f3b28a9826f64539f136 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jan 2022 08:24:38 +0000 Subject: [PATCH 0570/1892] update examples --- examples/attention/attention.py | 25 ++-- examples/attention/policy/tensor_parallel.py | 127 +++---------------- examples/{sci => poisson}/policy/naive.py | 0 examples/{sci => poisson}/sci.py | 4 +- 4 files changed, 38 insertions(+), 118 deletions(-) rename examples/{sci => poisson}/policy/naive.py (100%) rename examples/{sci => poisson}/sci.py (98%) diff --git a/examples/attention/attention.py b/examples/attention/attention.py index 652e5905..d324dbd6 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -9,6 +9,11 @@ --master_port=8004 \ --use_env \ examples/attention/attention.py + +OMP_NUM_THREADS=4 torchrun --standalone \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/attention/attention.py """ import torch @@ -17,15 +22,14 @@ import cube -from examples.attention.policy.data_parallel import transform_policy -from examples.attention.policy.data_parallel import schedule_policy +from examples.attention.policy.tensor_parallel import PAS from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -@cube.graph.operator.register('L N E, (3 h d) E -> L N (h d)', stay=['L', 'd', 'E']) -def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, num_head: int, +@cube.graph.parser.register('L N E, (3 h d) E -> L N (h d)', stay=['L', 'd', 'E']) +def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, h: int, scale: float, dropout: float, training: bool): """ L: sequence length @@ -34,8 +38,9 @@ def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, num_head: int, x: hidden state: [L, N, E] wqkv: qkv weight: [3 * (num_head * dim_head), E] dropout: float - num_head: int + h: int: number of heads """ + num_head = h L, N = x.shape[0], x.shape[1] dim_head = wqkv.shape[0] // 3 // num_head # L N E, (3 h d) E -> L N (3 h d) @@ -69,7 +74,7 @@ def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, num_head: int, mask = mask.view(N, 1, L, L) mask = (mask < 0.5) attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N, num_head), L, L) + attn = attn.view((N * num_head), L, L) # (N h) L L -> (N h) L L attn = F.softmax(attn, dim=-1) @@ -136,9 +141,13 @@ def train(): model, input_shapes=([L, N, E],), ) - dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([L, N, E],), + dtypes=(torch.float32,), + batch_dims=(1,) + ) - @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + @cube.compile(model, dataloader, policy=PAS) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) diff --git a/examples/attention/policy/tensor_parallel.py b/examples/attention/policy/tensor_parallel.py index df132025..14c3f6fe 100644 --- a/examples/attention/policy/tensor_parallel.py +++ b/examples/attention/policy/tensor_parallel.py @@ -1,110 +1,21 @@ from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation - - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using tensor parallel - """ - ndevs = resource.ngpus - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - assert len(fnodes) == 14 - - toqkv = fnodes[0] - q_t = fnodes[1] - k_t = fnodes[2] - v_t = fnodes[3] - q_scale = fnodes[4] - k_t2 = fnodes[5] - qk_bmm = fnodes[6] - mask = fnodes[7] - softmax = fnodes[8] - dropout = fnodes[9] - attnv_bmm = fnodes[10] - attnview = fnodes[11] - linear = fnodes[12] - loss = fnodes[13] - - all_sub_nodes = list() - - algo = toqkv.algorithms('head') - sub_nodes = graph.partition(toqkv, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = q_t.algorithms('dim') - sub_nodes = graph.partition(q_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = k_t.algorithms('dim') - sub_nodes = graph.partition(k_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = v_t.algorithms('dim') - sub_nodes = graph.partition(v_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = q_scale.algorithms('dim') - sub_nodes = graph.partition(q_scale, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = k_t2.algorithms('dim') - sub_nodes = graph.partition(k_t2, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = qk_bmm.algorithms('data') - sub_nodes = graph.partition(qk_bmm, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = mask.algorithms('head') - sub_nodes = graph.partition(mask, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = softmax.algorithms('dim') - sub_nodes = graph.partition(softmax, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = dropout.algorithms('dim') - sub_nodes = graph.partition(dropout, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = attnv_bmm.algorithms('data') - sub_nodes = graph.partition(attnv_bmm, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = attnview.algorithms('head') - sub_nodes = graph.partition(attnview, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = linear.algorithms('row') - sub_nodes = graph.partition(linear, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - sub_nodes = graph.replicate(loss, times=ndevs) - all_sub_nodes.append(sub_nodes) - - for sub_nodes in all_sub_nodes: - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - print(graph) - # assert False +from cube.graph.operator.operator import IRDataOperation, IRFwOperation + + +def PAS(graph: IRGraph, resource): + # data loader + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + graph.assign(node, list(range(resource.ngpus))) + fnodes = [isinstance(node, IRFwOperation) for node in graph.nodes()] + for idx, node in enumerate(fnodes): + if idx == 0: + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, config=dict(idx=0, dim=1, num=resource.ngpus) + ) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + print(graph.extra_repr()) return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - sugraph.partial_set_order(fsus, lazy=False) - return sugraph + diff --git a/examples/sci/policy/naive.py b/examples/poisson/policy/naive.py similarity index 100% rename from examples/sci/policy/naive.py rename to examples/poisson/policy/naive.py diff --git a/examples/sci/sci.py b/examples/poisson/sci.py similarity index 98% rename from examples/sci/sci.py rename to examples/poisson/sci.py index 3e75ef68..f6456010 100644 --- a/examples/sci/sci.py +++ b/examples/poisson/sci.py @@ -7,13 +7,13 @@ torch.set_default_tensor_type(torch.DoubleTensor) import cube -from examples.sci.policy.naive import PAS +from examples.poisson.policy.naive import PAS """ OMP_NUM_THREADS=4 torchrun --standalone \ --nproc_per_node=4 \ --nnodes=1 \ - examples/sci/sci.py + examples/poisson/sci.py """ From 1a3ffa569c409f516eac536a0dca827955525f31 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Jan 2022 06:23:05 +0000 Subject: [PATCH 0571/1892] fix replicate issue --- cube/graph/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index f7f66c76..ac2f89ad 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -309,7 +309,7 @@ def replicate(self, op: IRCell, times=1) -> Optional[List[IRCell]]: for idx, fnode in enumerate(fnodes): self.attach(fnode, fidx + idx) # insert backward - if isinstance(op.mirror, IRBpOperation): + if isinstance(op.mirror, IRBpOperation) and op.mirror in self.nodes(): for fnode in fnodes: fnode.gen_backward() bnodes = [fnode.mirror for fnode in fnodes][::-1] From 0d8900a1d197c30e47d7d7cefdd91a5489e80ac8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Jan 2022 07:05:52 +0000 Subject: [PATCH 0572/1892] PAS with replicate --- examples/poisson/policy/naive.py | 6 ++++-- examples/poisson/sci.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/poisson/policy/naive.py b/examples/poisson/policy/naive.py index 31d2d2df..8d6ff06e 100644 --- a/examples/poisson/policy/naive.py +++ b/examples/poisson/policy/naive.py @@ -2,6 +2,8 @@ def PAS(graph: IRGraph, resource): for node in graph.nodes(): - graph.assign(node, 0) + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) print(graph.extra_repr()) - return graph \ No newline at end of file + return graph diff --git a/examples/poisson/sci.py b/examples/poisson/sci.py index f6456010..222b1feb 100644 --- a/examples/poisson/sci.py +++ b/examples/poisson/sci.py @@ -11,7 +11,7 @@ """ OMP_NUM_THREADS=4 torchrun --standalone \ - --nproc_per_node=4 \ + --nproc_per_node=2 \ --nnodes=1 \ examples/poisson/sci.py """ From 902d5a14acd8c45b9108209ffc5ad6f3ae60376e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Jan 2022 07:06:25 +0000 Subject: [PATCH 0573/1892] conv2d use defined class --- cube/graph/operator/function/conv.py | 52 ++++++++++++++++++++++++ cube/graph/operator/function/function.py | 51 ++++++++++++++--------- 2 files changed, 84 insertions(+), 19 deletions(-) create mode 100644 cube/graph/operator/function/conv.py diff --git a/cube/graph/operator/function/conv.py b/cube/graph/operator/function/conv.py new file mode 100644 index 00000000..237aef94 --- /dev/null +++ b/cube/graph/operator/function/conv.py @@ -0,0 +1,52 @@ +from typing import List + +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + + +class IRConv2D(IRFwOperation): + + def __init__(self, signature: str, inputs: List[IRTensor], name: str, + **kwargs): + assert len(inputs) == 3, "Expected only input, weight, bias as inputs" + assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" + super().__init__(name, signature, 3, 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update(kwargs) + + def infer_shape(self) -> bool: + """ + Output shape inference given the input shapes + """ + if len(self.inputs(0).shape) == 0 or len(self.inputs(1).shape) == 0: + return False + N = self.inputs(0).shape[0] + iH, iW = self.inputs(0).shape[2:4] + oC = self.inputs(1).shape[0] + stride = self.kwargs['stride'] + padding = self.kwargs['padding'] + dilation = self.kwargs['dilation'] + dH = self.inputs(1).shape[2] + dW = self.inputs(1).shape[3] + oH = (iH + 2 * padding[0] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 + oW = (iW + 2 * padding[1] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 + shape = [N, oC, oH, oW] + self.outputs(0).shape = shape + return True + + def new(self, inputs: List, outputs: List): + """ + construct a new operator sharing same kwargs with new inputs + and outputs + """ + stride = self.kwargs['stride'] + padding = self.kwargs['padding'] + dilation = self.kwargs['dilation'] + groups = self.kwargs['groups'] + op = IRConv2D(self.signature, inputs, self.name, + stride=stride, padding=padding, dilation=dilation, groups=groups) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + op.infer_shape() + return op diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index daf6cd98..cacdc536 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -1,4 +1,5 @@ from cube.graph.operator.function.einops import EinDim, EinopAnno, IREinops +from cube.graph.operator.function.conv import IRConv2D def Linear(signature, inputs): @@ -115,30 +116,42 @@ def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: dim0=dim0, dim1=dim1) +# def Conv2D(signature, inputs): +# """ +# torch.conv2d(input, weight, bias, stride, padding, dialation, groups) +# https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html?highlight=torch%20conv2d#torch.nn.functional.conv2d +# """ +# def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: +# iH, iW = node.inputs(0).shape[2:4] +# stride = node.kwargs['stride'] +# padding = node.kwargs['padding'] +# dilation = node.kwargs['dilation'] +# dH = node.inputs(1).shape[2] +# dW = node.inputs(1).shape[3] +# oH = (iH + 2 * padding[0] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 +# oW = (iW + 2 * padding[1] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 +# anno.outputs[0][2] = EinDim([str(oH)]) +# anno.outputs[0][3] = EinDim([str(oW)]) +# return anno +# annos = [ +# ('N iC H W, oC GiC dH dW, oC -> N oC oH oW', adapt), +# ('N iC H W, oC GiC dH dW -> N oC oH oW', adapt), +# ] +# tensors = inputs[0:3] +# if tensors[-1] is None: +# tensors = inputs[0:2] +# stride, padding, dilation, groups = inputs[3:] +# return IREinops(signature, annos, tensors, 'conv2d', +# stride=stride, padding=padding, dilation=dilation, groups=groups) + + def Conv2D(signature, inputs): """ torch.conv2d(input, weight, bias, stride, padding, dialation, groups) https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html?highlight=torch%20conv2d#torch.nn.functional.conv2d """ - def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: - iH, iW = node.inputs(0).shape[2:4] - stride = node.kwargs['stride'] - padding = node.kwargs['padding'] - dilation = node.kwargs['dilation'] - dH = node.inputs(1).shape[2] - dW = node.inputs(1).shape[3] - oH = (iH + 2 * padding[0] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 - oW = (iW + 2 * padding[1] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 - anno.outputs[0][2] = EinDim([str(oH)]) - anno.outputs[0][3] = EinDim([str(oW)]) - return anno - annos = [ - ('N iC H W, oC GiC dH dW, oC -> N oC oH oW', adapt), - ('N iC H W, oC GiC dH dW -> N oC oH oW', adapt), - ] + assert len(inputs) == 7, f"Expected 7 inputs but only got {len(inputs)}" tensors = inputs[0:3] - if tensors[-1] is None: - tensors = inputs[0:2] stride, padding, dilation, groups = inputs[3:] - return IREinops(signature, annos, tensors, 'conv2d', + return IRConv2D(signature, tensors, 'conv2d', stride=stride, padding=padding, dilation=dilation, groups=groups) From b417bf21fba639e9bd30effa9b04939da05b92eb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Jan 2022 10:29:24 +0000 Subject: [PATCH 0574/1892] add conv2d dim split --- cube/algorithm/factory.py | 3 + cube/algorithm/ops/conv.py | 86 ++++++++++++++++++++++++ cube/graph/operator/function/function.py | 4 +- 3 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 cube/algorithm/ops/conv.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 2a6f8579..9a67bb11 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -65,6 +65,9 @@ def _load_predefined_algos(self): import cube.algorithm.ops.einops as einops self.register(einops.IREinops, einops.DimSplitEinops, tag='dim') + import cube.algorithm.ops.conv as conv + self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') + # import cube.algorithm.ops.elementwise as elew # self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') # self.register(elew.Add, elew.AddDimParallel, tag='dim') diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py new file mode 100644 index 00000000..897b4807 --- /dev/null +++ b/cube/algorithm/ops/conv.py @@ -0,0 +1,86 @@ +from typing import Dict + +from cube.algorithm.utils import split_axis, split_value +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.operator.function.conv import IRConv2D + + +class DimSplitConv2D(GenericDistAlgo): + """ + split Conv2D at dimension level + + (N iC H W) () + """ + + + def __init__(self, node: IRConv2D): + if not isinstance(node, IRConv2D): + raise TypeError(f"Expect IRConv2D") + super().__init__(node) + + def satisfy(self, config: Dict): + """ + config = dict(idx=int, dim=int, num=num) + + N iC H W, oC iC dH dW, oC -> N oC oH oW + + Splittable dimension: N, oC + Reduce dimension: oC + """ + for attr in ['idx', 'dim', 'num']: + if not attr in config: + raise KeyError("Expected idx, dim, num in the config") + node: IRConv2D = self.node + idx: int = config['idx'] + dim: int = config['dim'] + num: int = config['num'] + groups = node.kwargs['groups'] + # split N: + if (idx, dim) == (0, 0): + return node.inputs(0).shape[0] % num == 0 + # split oC + if (idx, dim) == (1, 0): + return node.inputs(1).shape[0] % num == 0 + # split iC + if (idx, dim) == (0, 1) or (idx, dim) == (1, 1): + return groups == 1 and node.inputs(1).shape[0] % 0 == num + + def instantiate(self, config: Dict): + if not self.satisfy(config): + return False + node: IRConv2D = self.node + idx: int = config['idx'] + dim: int = config['dim'] + num: int = config['num'] + + inputs, weights, bias = list(), list(), list() + outputs = list() + # split N + if (idx, dim) == (0, 0): + inputs = split_axis(node.inputs(0), axis=0, chunk_num=num) + weights = [node.inputs(1)] * num + bias = [node.inputs(2)] * num + outputs = split_axis(node.outputs(0), axis=0, chunk_num=num) + # split oC + if (idx, dim) == (1, 0): + inputs = [node.inputs(0)] * num + weights = split_axis(node.inputs(1), axis=0, chunk_num=num) + if node.inputs(2) is None: + bias = [None] * num + else: + bias = split_axis(node.inputs(2), axis=0, chunk_num=num) + outputs = split_axis(node.outputs(0), axis=1, chunk_num=num) + # split iC + if (idx, dim) == (0, 1) or (idx, dim) == (1, 1): + inputs = split_axis(node.inputs(0), axis=1, chunk_num=num) + weights = split_axis(node.inputs(1), axis=1, chunk_num=num) + if node.inputs(2) is None: + bias = [None] * num + else: + bias = split_value(node.inputs(2), chunk_num=num) + outputs = split_value(node.outputs(0), chunk_num=num) + subnodes = list() + for i, w, b, o in zip(inputs, weights, bias, outputs): + subnodes.append(node.new([i, w, b], [o])) + return subnodes diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index cacdc536..7f05f2d6 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -134,8 +134,8 @@ def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: # anno.outputs[0][3] = EinDim([str(oW)]) # return anno # annos = [ -# ('N iC H W, oC GiC dH dW, oC -> N oC oH oW', adapt), -# ('N iC H W, oC GiC dH dW -> N oC oH oW', adapt), +# ('N iC+ H^ W^, oC iC+ dH^ dW^, oC -> N oC oH^ oW^', adapt), +# ('N iC+ H^ W^, oC iC+ dH^ dW^ -> N oC oH^ oW^', adapt), # ] # tensors = inputs[0:3] # if tensors[-1] is None: From ea06ea0487cb46b22a6fe65580b67ee820e6bea1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Jan 2022 11:51:58 +0000 Subject: [PATCH 0575/1892] backward node removed if no backward called --- cube/algorithm/ops/conv.py | 30 ++++++++++++++++++++++++++++++ cube/graph/graph.py | 2 +- cube/ir/cten.py | 14 ++++++++------ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 897b4807..59819260 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -84,3 +84,33 @@ def instantiate(self, config: Dict): for i, w, b, o in zip(inputs, weights, bias, outputs): subnodes.append(node.new([i, w, b], [o])) return subnodes + + +class HaloSplitCon2D(GenericDistAlgo): + """ + Halo-exchange split + + N iC H W, oC iC dH dW, oC -> N oC oH oW + """ + + def __init__(self, node: IRConv2D): + if not isinstance(node, IRConv2D): + raise TypeError(f"Expect IRConv2D") + super().__init__(node) + + def satisfy(self, config: Dict): + for attr in ['idx', 'dim', 'num']: + if not attr in config: + raise KeyError("Expected idx, dim, num in the config") + node: IRConv2D = self.node + idx: int = config['idx'] + dim: int = config['dim'] + num: int = config['num'] + groups = node.kwargs['groups'] + stride = node.kwargs['groups'] + padding = node.kwargs['padding'] + dilation = node.kwargs['dilation'] + # split H + if (idx, dim) == (0, 2): + strideH = stride[0] + pass \ No newline at end of file diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ac2f89ad..f7f66c76 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -309,7 +309,7 @@ def replicate(self, op: IRCell, times=1) -> Optional[List[IRCell]]: for idx, fnode in enumerate(fnodes): self.attach(fnode, fidx + idx) # insert backward - if isinstance(op.mirror, IRBpOperation) and op.mirror in self.nodes(): + if isinstance(op.mirror, IRBpOperation): for fnode in fnodes: fnode.gen_backward() bnodes = [fnode.mirror for fnode in fnodes][::-1] diff --git a/cube/ir/cten.py b/cube/ir/cten.py index d913b50e..6d5de5df 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -104,12 +104,14 @@ def mirror(self, other): @staticmethod def make_pair(cell1, cell2): - if not isinstance(cell1, IRCell): - raise TypeError("Expected cell1 to be IRCell") - if not isinstance(cell2, IRCell): - raise TypeError("Expected cell2 to be IRCell") - cell1._mirror = cell2 - cell2._mirror = cell1 + if isinstance(cell1, IRCell): + cell1._mirror = cell2 + elif cell1 is not None: + raise TypeError("Expected cell1 to be IRCell or None") + if isinstance(cell2, IRCell): + cell2._mirror = cell1 + elif cell2 is not None: + raise TypeError("Expected cell2 to be IRCell or None") def on_device(self, device_id: int): """ From 63f48a9f68fbff3ddf25008764d6bcc133cf503c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Jan 2022 11:57:12 +0000 Subject: [PATCH 0576/1892] backward node removed if no backward called --- cube/compiler.py | 2 +- cube/logics/translator.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/cube/compiler.py b/cube/compiler.py index c8dec8b0..bbdcba39 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -189,7 +189,7 @@ def decorator(fn: Callable) -> Callable: bs = [out.shape[dim] for out, dim in zip(dnode.outputs(), dnode.get_batch_dims())] all_batch_size.update(bs) if len(all_batch_size) != 1: - raise NotImplementedError("Heterogenous batch size is not supported") + raise NotImplementedError(f"Heterogenous batch size {bs} is not supported") batch_size = torch.tensor(list(all_batch_size), dtype=torch.int).cuda() compile_end = time.time() diff --git a/cube/logics/translator.py b/cube/logics/translator.py index 1ac493e7..49d9c97b 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -1,3 +1,6 @@ +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRCell + from cube.logics.dataloader import IRDataLoader from cube.logics import model from cube.logics.pool import SchedulePool @@ -16,6 +19,12 @@ def gen_logic_graph(outputs=None): """ nodes = SchedulePool().nodes() graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') + # remove backward nodes if no backward is called + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + bnode = node.mirror + if bnode not in graph.nodes(): + IRCell.make_pair(node, None) return graph @staticmethod From 0109967f880abcdac935db980a8b0a94d9637946 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 26 Jan 2022 11:48:45 +0000 Subject: [PATCH 0577/1892] conv for H split --- cube/algorithm/factory.py | 1 + cube/algorithm/ops/conv.py | 72 +++++++++++++++++++++--- cube/algorithm/ops/layernorm.py | 63 --------------------- cube/algorithm/utils.py | 34 +++++++++-- cube/graph/operator/function/conv.py | 1 + cube/graph/operator/function/function.py | 5 ++ cube/runtime/function/__init__.py | 2 +- cube/runtime/function/function.py | 40 +++++++++++++ 8 files changed, 143 insertions(+), 75 deletions(-) delete mode 100644 cube/algorithm/ops/layernorm.py create mode 100644 cube/runtime/function/function.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 9a67bb11..e898d60e 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -67,6 +67,7 @@ def _load_predefined_algos(self): import cube.algorithm.ops.conv as conv self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') + self.register(conv.IRConv2D, conv.HaloSplitConv2D, tag='halo') # import cube.algorithm.ops.elementwise as elew # self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 59819260..5cfd2e17 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -1,6 +1,6 @@ from typing import Dict -from cube.algorithm.utils import split_axis, split_value +from cube.algorithm.utils import split_axis, split_axis_custom, split_value from cube.algorithm.generics import GenericDistAlgo from cube.graph.operator.function.conv import IRConv2D @@ -10,10 +10,9 @@ class DimSplitConv2D(GenericDistAlgo): """ split Conv2D at dimension level - (N iC H W) () + N iC H W, oC iC dH dW, oC -> N oC oH oW """ - def __init__(self, node: IRConv2D): if not isinstance(node, IRConv2D): raise TypeError(f"Expect IRConv2D") @@ -86,7 +85,7 @@ def instantiate(self, config: Dict): return subnodes -class HaloSplitCon2D(GenericDistAlgo): +class HaloSplitConv2D(GenericDistAlgo): """ Halo-exchange split @@ -103,14 +102,73 @@ def satisfy(self, config: Dict): if not attr in config: raise KeyError("Expected idx, dim, num in the config") node: IRConv2D = self.node + H, W = node.inputs(0).shape[2:] + idx: int = config['idx'] + dim: int = config['dim'] + num: int = config['num'] + stride = node.kwargs['stride'] + dilation = node.kwargs['dilation'] + # FIXME: stride + if stride != [1, 1]: + raise NotImplementedError("Splitting on stride != [1,1] is not supported") + if dilation != [1, 1]: + raise NotImplementedError("Splitting on dilation != [1,1] is not supported") + # split H + if (idx, dim) == (0, 2): + return H % num == 0 + # split W + if (idx, dim) == (0, 3): + return W % num == 0 + + def instantiate(self, config: Dict): + if not self.satisfy(config): + return None + node: IRConv2D = self.node + H, W = node.inputs(0).shape[2:] + dH, dW = node.inputs(1).shape[2:] + oH, oW = node.outputs(0).shape[2:] idx: int = config['idx'] dim: int = config['dim'] num: int = config['num'] groups = node.kwargs['groups'] - stride = node.kwargs['groups'] + stride = node.kwargs['stride'] padding = node.kwargs['padding'] dilation = node.kwargs['dilation'] # split H if (idx, dim) == (0, 2): - strideH = stride[0] - pass \ No newline at end of file + # input and padding + slicers = list() + pads = list() + for idx in range(num): + # input + start = max(0, H // num * idx - dH + 1) + stop = min(H, H // num * (idx + 1) + dH - 1) + slicers.append(slice(start, stop, 1)) + # padding + padl = padding[0] if start == 0 else 0 + padr = padding[1] if stop == H else 0 + pads.append([padl, padr, padding[2], padding[3]]) + inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) + # weight + weights = [node.inputs(1)] * num + # bias + bias = [node.inputs(2)] * num + # padding + pads.append([padl, padr, padding[2], padding[3]]) + # outputs + slicers = list() + for idx in range(num): + start = start = max(0, oH // num * idx - dH + 1) + stop = min(oH, oH // num * (idx + 1) + dH - 1) + slicers.append(slice(start, stop, 1)) + outputs = split_axis_custom(node.outputs(0), axis=dim, chunks=slicers) + # split W + if (idx, dim) == (0, 1): + raise NotImplementedError("Split on W is not supported yet") + sub_nodes = list() + for i, w, b, pad, o in zip(inputs, weights, bias, pads, outputs): + conv = IRConv2D(node.signature, [i, w, b], node.name, + stride=stride, padding=pad, dilation=dilation, groups=groups) + conv.set_output(0, o) + sub_nodes.append(conv) + return sub_nodes diff --git a/cube/algorithm/ops/layernorm.py b/cube/algorithm/ops/layernorm.py deleted file mode 100644 index 51a1f467..00000000 --- a/cube/algorithm/ops/layernorm.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Dict -import copy - -from cube.algorithm.utils import split_axis -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.operator.function import LayerNorm - - -_kWaitDecision = None - -class LayerNormDimParallel(GenericDistAlgo): - - def __init__(self, node: LayerNorm, dim=None): - if not isinstance(node, LayerNorm): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.ndim = len(node.inputs(0).shape) - last_ndims = len(node.inputs(1)) - self.stay_dims = list() - for dim in range(last_ndims): - self.stay_dims.append(self.ndim - dim - 1) - - self.chunk_num = _kWaitDecision - self.dim = dim - - def satisfy(self, config: Dict): - if 'dim' in config: - dim = config['dim'] - else: - if self.dim is None: - raise RuntimeError("Expected dim in config") - dim = self.dim - if dim < 0: - dim = self.ndim + dim - chunk_num = int(config['chunk_num']) - if dim in self.stay_dims: - return False - shape = self.input_shapes[0] - if dim >= 0 and dim < self.ndim and shape[dim] % chunk_num == 0: - return True - return False - - def instantiate(self, node: LayerNorm, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - if 'dim' in config: - self.dim = config['dim'] - - input = node.inputs(0) - sub_inputs = split_axis(input, self.dim, self.chunk_num) - - output = node.outputs(0) - sub_outputs = split_axis(output, self.dim, self.chunk_num) - - nodes = list() - for sub_input, sub_output in zip(sub_inputs, sub_outputs): - inputs = [sub_input] + node.inputs()[1:] + [node.kwargs['eps']] - sub_node = LayerNorm(node.signature, inputs, node.name) - sub_node.set_output(0, sub_output) - nodes.append(sub_node) - return nodes diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py index 67fae0b0..5e0bd4a2 100644 --- a/cube/algorithm/utils.py +++ b/cube/algorithm/utils.py @@ -1,8 +1,8 @@ +from typing import List, Union +from cube.graph.tensor import IRSubTensor -from cube.ir.cten import IRTensor - -def split_axis(tensor: IRTensor, axis: int, chunk_num: int): +def split_axis(tensor: IRSubTensor, axis: int, chunk_num: int): """ Split tensor along an axis. The axis can be positive or negative. """ @@ -34,7 +34,33 @@ def split_axis(tensor: IRTensor, axis: int, chunk_num: int): return sub_tensors -def split_value(tensor: IRTensor, chunk_num: int): +def split_axis_custom(tensor: IRSubTensor, axis: int, chunks: List[slice]): + """ + Split tensor along an axis with cutomized selection + """ + if axis < 0: + axis = len(tensor.shape) + axis + if axis >= len(tensor.shape): + raise RuntimeError(f"Axis should within dims ({axis} >= {len(tensor.shape)})") + chunk_num = len(chunks) + + slicers, shape = list(), list() + for nele in tensor.shape: + slicers.append(slice(0, nele, 1)) + shape.append(nele) + sub_tensors = list() + for cid in range(chunk_num): + slicers[axis] = chunks[cid] + shape[axis] = chunks[cid].stop - chunks[cid].start + sub_tensors.append(tensor.select( + indmap = tuple(slicers), + valmap = None, + shape = shape + )) + return sub_tensors + + +def split_value(tensor: IRSubTensor, chunk_num: int): # full shape shape_slicer = list() diff --git a/cube/graph/operator/function/conv.py b/cube/graph/operator/function/conv.py index 237aef94..ca148686 100644 --- a/cube/graph/operator/function/conv.py +++ b/cube/graph/operator/function/conv.py @@ -8,6 +8,7 @@ class IRConv2D(IRFwOperation): def __init__(self, signature: str, inputs: List[IRTensor], name: str, **kwargs): + signature = 'cube.runtime.function.conv2d' assert len(inputs) == 3, "Expected only input, weight, bias as inputs" assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" super().__init__(name, signature, 3, 1) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 7f05f2d6..2501584f 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -153,5 +153,10 @@ def Conv2D(signature, inputs): assert len(inputs) == 7, f"Expected 7 inputs but only got {len(inputs)}" tensors = inputs[0:3] stride, padding, dilation, groups = inputs[3:] + if isinstance(padding, int): + padding = [padding] * 4 + elif len(padding) == 2: + padH, padW = padding + padding = [padH, padH, padW, padW] return IRConv2D(signature, tensors, 'conv2d', stride=stride, padding=padding, dilation=dilation, groups=groups) diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py index 294c00ce..e86bbc76 100644 --- a/cube/runtime/function/__init__.py +++ b/cube/runtime/function/__init__.py @@ -1,3 +1,3 @@ -import cube.runtime.function.complex as complex from cube.runtime.function.complex import * from cube.runtime.function.dist import * +from cube.runtime.function.function import * \ No newline at end of file diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py new file mode 100644 index 00000000..06a6d0e6 --- /dev/null +++ b/cube/runtime/function/function.py @@ -0,0 +1,40 @@ +from typing import Optional, List +import torch +import torch.nn.functional as TorchF + + +def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + stride: int, padding: List[int], dilation, groups: int = 1): + """ + input: N iC H W + weight: oC iC dH dW + bias: oC + padding: int, List[int], e.g., 1, [1, 1], [1, 0, 1, 0] + """ + input = TorchF.pad(input, padding, 'constant', 0) + return TorchF.conv2d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) + + +def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): + """ + Embedding + + Inputs: + input: torch.Tensor [*] + weight: [vocab size, embed size] + start: int + stop: int + + Outputs: + output: [*, embed_size] + """ + input = input.long() + input_mask = (input < start) | (input >= stop) + masked_input = input.clone() - start + masked_input[input_mask] = 0 + output = TorchF.embedding( + masked_input, weight, + None, None, 2.0, False, False + ) + output[input_mask, :] = 0.0 + return output From f1e11c185508a2347bfaae520de5445ae6a99f7a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 28 Jan 2022 08:40:00 +0000 Subject: [PATCH 0578/1892] init tensor sub --- cube/graph/adapter/adapter.py | 8 +-- cube/graph/operator/function/conv.py | 4 +- cube/graph/tensor.py | 74 ++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 5 deletions(-) diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index ba87ea4b..c34579a9 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -452,11 +452,13 @@ def gen_merge(dst_tensor, intersections): break # cannot merge or add if out is None: - print(f'failed tensor: {dst_tensor}') + print(f'failed tensor: {dst_tensor.extra_repr()}') print(f'ptensor:') for tensor in dst_tensor.parent.ptensors: - print(f'node-{tensor._cell._id}: {tensor}') - print(f'intersections: {intersections}') + print(f'node-{tensor._cell._id}: {tensor.extra_repr()}') + print('intersections:') + for tensor in intersections: + print(f'{tensor.extra_repr()}') raise RuntimeError(f"Merge plan of tensor {dst_tensor} not found") return prims diff --git a/cube/graph/operator/function/conv.py b/cube/graph/operator/function/conv.py index ca148686..f9f5fd25 100644 --- a/cube/graph/operator/function/conv.py +++ b/cube/graph/operator/function/conv.py @@ -30,8 +30,8 @@ def infer_shape(self) -> bool: dilation = self.kwargs['dilation'] dH = self.inputs(1).shape[2] dW = self.inputs(1).shape[3] - oH = (iH + 2 * padding[0] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 - oW = (iW + 2 * padding[1] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 + oH = (iH + padding[0] + padding[1] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 + oW = (iW + padding[2] + padding[3] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 shape = [N, oC, oH, oW] self.outputs(0).shape = shape return True diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 05acbb5b..b0ac5c83 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -189,6 +189,66 @@ def __and__(self, other): raise NotImplementedError(f"not supported for differnt steps") return IndexMap(tuple(slices)) + def __sub__(self, other) -> Optional[List]: + """ + Get the remaining part. + We reuqire other should completely inside this tensor + and the remaining part should be only one tile, else + will return None + + Args: + other: IndexMap + + Returns: + IndexMap for the remaining part + """ + if not isinstance(other, IndexMap): + raise TypeError("Expected IndexMap") + if self.ndims != other.ndims: + return None + dim_common: List[List[slice]] = [list() for _ in range(self.ndims)] + dim_differ: List[List[slice]] = [list() for _ in range(self.ndims)] + for dim, (slicer1, slicer2) in enumerate(zip(self.get(), other.get())): + # self indices + start1, stop1 = slicer1.start, slicer1.stop + step1 = slicer1.step if slicer1.step else 1 + # other indices + start2, stop2 = slicer2.start, slicer2.stop + step2 = slicer2.step if slicer2.step else 1 + if step1 != 1 or step2 != 1: + return None + # no intersection + if min(stop1, stop2) <= max(start1, start2): + return None + # set common + start = max(start1, start2) + stop = min(stop1, stop2) + dim_common[dim].append(slice(start, stop, step1)) + # set difference + if start1 == start2: + if stop2 < stop1: + dim_differ[dim].append(slice(stop2, stop1, step1)) + elif stop1 == stop2: + if start1 < start2: + dim_differ.append(slice(start1, start2, step1)) + else: + raise NotImplementedError("Multipe indexmap is not supported") + indmaps = list() + splitdim = set() + slices = list() + for dim in range(self.ndims): + common = dim_common[dim] + differ = dim_differ[dim] + if len(common) + len(differ) != 1: + raise NotImplementedError("Multipe indexmap is not supported") + if len(differ) == 1: + splitdim.add(dim) + slices.append(differ[0]) + else: + slices.append(common[0]) + indmaps.append(IndexMap(tuple(slices))) + return indmaps + def __repr__(self): dscp = repr(self._indices) return dscp @@ -690,6 +750,20 @@ def common(self, other): raise NotImplementedError("Customized IRTensor not support") return None + def difference(self, other): + """ + Get differene part of sub-tensor + + Currently this requires tensor to be subset + + Args: + other: IRSubTensor + + Returns: + None for fail + """ + pass + def __repr__(self): anno = 't' if self.is_param(): From 85666b2ab39fd7129dfd93b43d85c6d1283c3f18 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Feb 2022 09:58:04 +0800 Subject: [PATCH 0579/1892] adapt to newest pytorch --- cube/runtime/device.py | 3 --- tests/runtime/test_nccl.py | 29 +++++++++++++++-------------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/cube/runtime/device.py b/cube/runtime/device.py index f6dfc470..789b2605 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -13,9 +13,6 @@ class __DeviceGroup: def __init__(self): torch.distributed.init_process_group( backend='nccl', - init_method='env://', - # world_size=device_num, - # init_method='tcp://' + '{master_ip}:{port}'.format(master_ip=master_ip, port=port) ) self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() diff --git a/tests/runtime/test_nccl.py b/tests/runtime/test_nccl.py index 348244c6..7184743b 100644 --- a/tests/runtime/test_nccl.py +++ b/tests/runtime/test_nccl.py @@ -1,12 +1,13 @@ """ Single node usage: -e.g., 8 GPUs -python -m torch.distributed.launch --nproc_per_node=4 test_nccl.py -Multi-node usage: -e.g., 2-node each with 8 GPUs -python -m torch.distributed.launch --nproc_per_node=8 --node_rank=0 --master_port=6000 --master_addr='master ip iddress' --nnodes=2 test_nccl.py -python -m torch.distributed.launch --nproc_per_node=8 --node_rank=1 --master_port=6000 --master_addr='master ip iddress' --nnodes=2 test_nccl.py +e.g., 4 GPUs + +OMP_NUM_THREADS=4 torchrun --standalone \ + --nproc_per_node=8 \ + --nnodes=1 \ + tests/runtime/test_nccl.py + """ import torch @@ -62,11 +63,11 @@ def test_allgather(size, local_rank): toc = time.perf_counter() -def benchmark(args): +def benchmark(args, local_rank): size = args.begin while size <= args.end: - # test_allgather(size * 1024 * 1024, args.local_rank) - test_nccl(size * 1024 * 1024, args.local_rank) # MB to B + # test_allgather(size * 1024 * 1024, local_rank) + test_nccl(size * 1024 * 1024, local_rank) # MB to B size *= 2 print_each_rank('test on nccl is done') @@ -78,12 +79,12 @@ def benchmark(args): help='start message size in MB') parser.add_argument('--end', type=int, default=64, help='end message size in MB') - parser.add_argument('--local_rank', type=int, required=True, - help='specified by torch.distributed.launch') args = parser.parse_args() + print('> initializing distributed environ...') torch.distributed.init_process_group(backend='nccl') - print_each_rank('local rank-{} launches'.format(args.local_rank)) + local_rank = int(os.environ.get('LOCAL_RANK')) + print_each_rank('local rank-{} launches'.format(local_rank)) - torch.cuda.set_device(args.local_rank) - benchmark(args) \ No newline at end of file + torch.cuda.set_device(local_rank) + benchmark(args, local_rank) \ No newline at end of file From e9a52eabd195822c0359bc171c07e75314fb9036 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Feb 2022 13:51:04 +0800 Subject: [PATCH 0580/1892] enable load existing generated code --- cube/codegen/syntax/blocks.py | 3 ++- cube/compiler.py | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/cube/codegen/syntax/blocks.py b/cube/codegen/syntax/blocks.py index 12459f90..bea6d974 100644 --- a/cube/codegen/syntax/blocks.py +++ b/cube/codegen/syntax/blocks.py @@ -23,7 +23,8 @@ def insert_body(self, code): def __exit__(self, exc_type, exc_value, exc_tb): # add indent for function block for idx in range(1, len(self.code)): - self.code[idx] = '\t' + self.code[idx] + # use 4 space as indent + self.code[idx] = ' ' + self.code[idx] if not exc_tb is None: print('Error detected in function block') diff --git a/cube/compiler.py b/cube/compiler.py index bbdcba39..d4d19896 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,6 +1,7 @@ from typing import Callable, Tuple, Union, Optional import torch import time +import os import cube @@ -59,7 +60,7 @@ def __call__(self, *args): def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, - PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None): + PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, override = True): """ AI Scientist calls like: @@ -118,6 +119,18 @@ def _load_tschedule_fn(filename) -> Callable: def decorator(fn: Callable) -> Callable: filename = 'gencode{}.py' batch_size = torch.tensor([-1], dtype=torch.int).cuda() + + if not override and os.path.exists(filename.format(myrank)): + filename = filename.format(myrank) + # TODO: set batch size + print('warning: dataloader batch size stay as default.') + # load module code + print_each_rank(f'loading existed module from {filename} ...') + model.load_module(filename) + # load schedule code + print_each_rank(f'loading existed schedule from {filename} ...') + return _load_tschedule_fn(filename) + if myrank == 0: compile_start = time.time() From 591b62a2ded20a4f2704008288fae68c0fa19be4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Feb 2022 14:37:22 +0800 Subject: [PATCH 0581/1892] fix conv2d H W partition --- cube/algorithm/ops/conv.py | 33 +++++++++++++++---------------- cube/runtime/function/function.py | 5 ++++- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 5cfd2e17..51dfe5a1 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -102,7 +102,7 @@ def satisfy(self, config: Dict): if not attr in config: raise KeyError("Expected idx, dim, num in the config") node: IRConv2D = self.node - H, W = node.inputs(0).shape[2:] + oH, oW = node.outputs(0).shape[2:] idx: int = config['idx'] dim: int = config['dim'] num: int = config['num'] @@ -115,11 +115,11 @@ def satisfy(self, config: Dict): raise NotImplementedError("Splitting on dilation != [1,1] is not supported") # split H if (idx, dim) == (0, 2): - return H % num == 0 + return oH % num == 0 # split W if (idx, dim) == (0, 3): - return W % num == 0 - + return oW % num == 0 + def instantiate(self, config: Dict): if not self.satisfy(config): return None @@ -139,15 +139,19 @@ def instantiate(self, config: Dict): # input and padding slicers = list() pads = list() - for idx in range(num): - # input - start = max(0, H // num * idx - dH + 1) - stop = min(H, H // num * (idx + 1) + dH - 1) - slicers.append(slice(start, stop, 1)) + start = 0 - padding[0] + for cid in range(num): # padding - padl = padding[0] if start == 0 else 0 - padr = padding[1] if stop == H else 0 + padl = padding[0] if cid == 0 else 0 + padr = padding[1] if cid == num - 1 else 0 pads.append([padl, padr, padding[2], padding[3]]) + # input -- FIXME: only work for stride=[1,1] + chunkH = oH // num + dilation[0] * (dH - 1) + stop = start + chunkH - padr + slicers.append(slice(max(0, start), min(H, stop))) + start = stop - dilation[0] * (dH - 1) + # start = 0 if cid == 0 else 1023 + # stop = 1025 if cid == 0 else H inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) # weight weights = [node.inputs(1)] * num @@ -156,12 +160,7 @@ def instantiate(self, config: Dict): # padding pads.append([padl, padr, padding[2], padding[3]]) # outputs - slicers = list() - for idx in range(num): - start = start = max(0, oH // num * idx - dH + 1) - stop = min(oH, oH // num * (idx + 1) + dH - 1) - slicers.append(slice(start, stop, 1)) - outputs = split_axis_custom(node.outputs(0), axis=dim, chunks=slicers) + outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) # split W if (idx, dim) == (0, 1): raise NotImplementedError("Split on W is not supported yet") diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 06a6d0e6..f5aa2315 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -9,8 +9,11 @@ def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso input: N iC H W weight: oC iC dH dW bias: oC - padding: int, List[int], e.g., 1, [1, 1], [1, 0, 1, 0] + padding: List[int, int, int, int]: [Htop, Hbottom, Wtop, Wbottom] or + List[int, int]: [Hside, Wside] """ + # switch H and W to match torch.nn.functional.pad + padding = padding[len(padding) // 2:] + padding[0:len(padding) // 2] input = TorchF.pad(input, padding, 'constant', 0) return TorchF.conv2d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) From 9f8c1e68c6c12644ede361bc2d37e7bd12d6f985 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Feb 2022 14:58:53 +0800 Subject: [PATCH 0582/1892] remove require grad in inference --- cube/ir/cten.py | 2 ++ cube/logics/translator.py | 8 ++++++++ examples/poisson/policy/naive.py | 8 +++++++- examples/poisson/sci.py | 6 +++--- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 6d5de5df..e32048f2 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -565,6 +565,8 @@ def __eq__(self, tensor): @property def shape(self): + if self._shape is None: + return [] return copy.copy(self._shape) @shape.setter diff --git a/cube/logics/translator.py b/cube/logics/translator.py index 49d9c97b..bca42fc8 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -25,6 +25,14 @@ def gen_logic_graph(outputs=None): bnode = node.mirror if bnode not in graph.nodes(): IRCell.make_pair(node, None) + for input in node.inputs(): + if isinstance(input, IRSubTensor): + input.grad = None + input.requires_grad = False + for output in node.outputs(): + if isinstance(output, IRSubTensor): + output.grad = None + output.requires_grad = False return graph @staticmethod diff --git a/examples/poisson/policy/naive.py b/examples/poisson/policy/naive.py index 8d6ff06e..df323a9c 100644 --- a/examples/poisson/policy/naive.py +++ b/examples/poisson/policy/naive.py @@ -1,8 +1,14 @@ from cube.graph import IRGraph +from cube.graph.operator.function import IRConv2D def PAS(graph: IRGraph, resource): for node in graph.nodes(): - sub_nodes = graph.replicate(node, times=resource.ngpus) + if isinstance(node, IRConv2D): + algo = node.algorithms('halo') + sub_nodes = graph.partition(node, algo, config=dict(idx=0, dim=2, num=resource.ngpus)) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + # sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) print(graph.extra_repr()) diff --git a/examples/poisson/sci.py b/examples/poisson/sci.py index 222b1feb..62e75322 100644 --- a/examples/poisson/sci.py +++ b/examples/poisson/sci.py @@ -4,13 +4,13 @@ import torch.nn.functional as F import time -torch.set_default_tensor_type(torch.DoubleTensor) +# torch.set_default_tensor_type(torch.DoubleTensor) import cube from examples.poisson.policy.naive import PAS """ -OMP_NUM_THREADS=4 torchrun --standalone \ +OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=2 \ --nnodes=1 \ examples/poisson/sci.py @@ -89,7 +89,7 @@ def train_loop(): model = ScientificModel() model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes),) - @cube.compile(model=model, dataloader=varloader, PAS=PAS) + @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) def train_iter(model, dataloader): r0, p, phi, filter = next(dataloader) r0, p, phi, r1_sum = model(r0, p, phi, filter) From 42111e34ce6f030e3053925a70265b32a31fe569 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Feb 2022 11:10:17 +0800 Subject: [PATCH 0583/1892] megatron pesudo code --- examples/gpt/policy/megatron.md | 74 +++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 examples/gpt/policy/megatron.md diff --git a/examples/gpt/policy/megatron.md b/examples/gpt/policy/megatron.md new file mode 100644 index 00000000..136a789f --- /dev/null +++ b/examples/gpt/policy/megatron.md @@ -0,0 +1,74 @@ + +``` +function PADataParallel(Graph G, Resource R, Config C): + for node in G.nodes() do + algorithm <- getPartitionAlgo(node, 'data parallelism') + subnodes <- G.partition(node, algorithm, C.data_parallel_size) + for dp_idx in 0 to C.data_parallel_size do + rank <- mapDpToRank(dp_idx, R) + G.assign(subnodes[dp_idx], rank) + return G + + +function PATensorParallel(Graph G, Resource R, Config C): + for node in G.nodes() do + algorithm <- getPartitionAlgo(node, 'tensor parallelism') + subnodes <- G.partition(node, algorithm, C.tensor_parallel_size) + for tp_idx in 0 to C.tensor_parallel_size do + rank <- mapTpToRank(tp_idx, R) + G.assign(subnodes[tp_idx], rank) + return G + + +function PAPipelineParallel(Graph G, Resource R, Config C): + + for node in G.nodes() do + algorithm <- getPartitionAlgo(node, 'data parallelism') + G.partition(node, algorithm, C.num_micro_batches) + + for node in G.nodes() do + stage_id <- getStageID(node, G, C.num_stages) // policy + rank <- mapStageToRank(stage_id, R) + G.assign(node, stage) + + groupStageAndMicroBatch(G, C.num_stages, C.num_micro_batches) + return G + + +function PSPipelineParallel(Graph G, Resource R, Config C): + // each node in G stands for a stage (sub-graph) + sequence <- EmptyArray[] + // warmup phase + for micro_batch_id in 0 to C.num_micro_batches do + for stage_id in 0 to C.num_stages - micro_batch_id do + node <- getForwardStage(G, micro_batch_id, stage_id) + arrayPush(sequence, node) + # steady and cooldown phase + for micro_batch_id in 0 to C.num_micro_batches do + // enqueue backward + for stage_id in C.num_stages to 0 do + node <- getBackwardStage(G, micro_batch_id, stage_id) + arrayPush(sequence, node) + // enqueue forward + for stage_id in 0 to C.num_stages do + mid <- micro_batch_id + C.num_stages - stage_id + if mid <= C.num_stages then + node <- getForwardStage(G, mid, stage_id) + arrayPush(sequence, node) + G.schedule(sequence) + return G + + +function Megatron(Graph G, Resource R, Config C): + // Resource split + R_data, R_pipe, R_tensor <- splitResource(R, C) + // split to stages + G <- PAPipelineParallel(G, R_pipe, C) + // inner stage: data + tensor parallelism + for stage in G.nodes: + PADataParallel(stage, R_data, C) + PATensorParallel(stage, R_tensor, C) + // inter stage: 1F1B scheduling + G <- PSPipelineParallel(G, R_pipe, C) + return G +``` \ No newline at end of file From 21ae1b5f3ebf31bdfc040920f551de03a9757a00 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Feb 2022 12:46:44 +0800 Subject: [PATCH 0584/1892] use default type to create tensor --- cube/graph/parser/parser.py | 6 ++++-- cube/ir/cten.py | 15 ++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 351c53bd..14538864 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -39,8 +39,9 @@ def parse_module(module, # handle graph input -- Assuming all the inputs are tensors input_var_name = [input.debugName() for input in module.graph.inputs()] + kDefaultType = DType2IRDType.map(torch.get_default_dtype()) for index, var_name in enumerate(input_var_name[1:]): # omit self - frame.add_var(var_name, IRFullTensor(name=var_name, requires_grad=False), graph_arg=index) + frame.add_var(var_name, IRFullTensor(name=var_name, requires_grad=False, dtype=kDefaultType), graph_arg=index) input_val = [frame.get_var(var_name) for var_name in input_var_name[1:]] # handle input shape @@ -397,7 +398,8 @@ def parse_prim_tupleunpack_node(node, module, frame) -> List[None]: dtype = output.type().str() var_name = output.debugName() if dtype == 'Tensor': - ir_tensor = IRFullTensor(name=var_name) + kDefaultType = DType2IRDType.map(torch.get_default_dtype()) + ir_tensor = IRFullTensor(name=var_name, dtype=kDefaultType) tuple_outs.append(ir_tensor) frame.add_var(var_name, ir_tensor) else: diff --git a/cube/ir/cten.py b/cube/ir/cten.py index e32048f2..0d54d424 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -235,11 +235,12 @@ def set_input(self, input_index: int, val: Any): # set tensor dst val.attach_cell(self) # set input value dtype - if self._dtype != IRDType.unknown: - val.dtype = self._dtype - # set cell dtype - elif val.dtype != IRDType.unknown: + if self._dtype == IRDType.unknown: self._dtype = val.dtype + for output in self.outputs(): + if isinstance(output, IRTensor): + output.dtype = self._dtype + val.dtype = self._dtype self._inputs[input_index] = val return val @@ -259,11 +260,7 @@ def set_output(self, output_index: int, val: Any): val = copy.copy(val) val.attach_cell(self) # set output value dtype - if self._dtype != IRDType.unknown: - val.dtype = self._dtype - # set cell dtype - elif val.dtype != IRDType.unknown: - self._dtype = val.dtype + val.dtype = self._dtype self._outputs[output_index] = val return val From b9352821ed0a30ea89745cfe9e345f9bb4174150 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Feb 2022 12:49:50 +0800 Subject: [PATCH 0585/1892] enable conv split at W dimension --- cube/algorithm/ops/conv.py | 28 ++++++++++++++++++++++++---- examples/poisson/policy/naive.py | 2 +- examples/poisson/sci.py | 2 +- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 51dfe5a1..1c593801 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -108,6 +108,8 @@ def satisfy(self, config: Dict): num: int = config['num'] stride = node.kwargs['stride'] dilation = node.kwargs['dilation'] + if dim not in [2, 3]: + return False # FIXME: stride if stride != [1, 1]: raise NotImplementedError("Splitting on stride != [1,1] is not supported") @@ -157,13 +159,31 @@ def instantiate(self, config: Dict): weights = [node.inputs(1)] * num # bias bias = [node.inputs(2)] * num - # padding - pads.append([padl, padr, padding[2], padding[3]]) # outputs outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) # split W - if (idx, dim) == (0, 1): - raise NotImplementedError("Split on W is not supported yet") + if (idx, dim) == (0, 3): + # input and padding + slicers = list() + pads = list() + start = 0 - padding[2] + for cid in range(num): + # padding + padt = padding[2] if cid == 0 else 0 + padb = padding[3] if cid == num - 1 else 0 + pads.append([padding[0], padding[1], padt, padb]) + # input -- FIXME: only work for stride=[1,1] + chunkH = oW // num + dilation[0] * (dH - 1) + stop = start + chunkH - padb + slicers.append(slice(max(0, start), min(H, stop))) + start = stop - dilation[0] * (dH - 1) + inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) + # weight + weights = [node.inputs(1)] * num + # bias + bias = [node.inputs(2)] * num + # outputs + outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) sub_nodes = list() for i, w, b, pad, o in zip(inputs, weights, bias, pads, outputs): conv = IRConv2D(node.signature, [i, w, b], node.name, diff --git a/examples/poisson/policy/naive.py b/examples/poisson/policy/naive.py index df323a9c..bb9f45bf 100644 --- a/examples/poisson/policy/naive.py +++ b/examples/poisson/policy/naive.py @@ -5,7 +5,7 @@ def PAS(graph: IRGraph, resource): for node in graph.nodes(): if isinstance(node, IRConv2D): algo = node.algorithms('halo') - sub_nodes = graph.partition(node, algo, config=dict(idx=0, dim=2, num=resource.ngpus)) + sub_nodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus)) else: sub_nodes = graph.replicate(node, times=resource.ngpus) # sub_nodes = graph.replicate(node, times=resource.ngpus) diff --git a/examples/poisson/sci.py b/examples/poisson/sci.py index 62e75322..f3501542 100644 --- a/examples/poisson/sci.py +++ b/examples/poisson/sci.py @@ -4,7 +4,7 @@ import torch.nn.functional as F import time -# torch.set_default_tensor_type(torch.DoubleTensor) +torch.set_default_tensor_type(torch.DoubleTensor) import cube from examples.poisson.policy.naive import PAS From 4dc9204cd867d7360294f4d03948bc26ab887ed4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Feb 2022 13:27:07 +0800 Subject: [PATCH 0586/1892] pseudo code for megatron --- examples/gpt/policy/megatron.md | 59 +++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/examples/gpt/policy/megatron.md b/examples/gpt/policy/megatron.md index 136a789f..86489a4d 100644 --- a/examples/gpt/policy/megatron.md +++ b/examples/gpt/policy/megatron.md @@ -1,7 +1,7 @@ ``` function PADataParallel(Graph G, Resource R, Config C): - for node in G.nodes() do + for node in G.nodes do algorithm <- getPartitionAlgo(node, 'data parallelism') subnodes <- G.partition(node, algorithm, C.data_parallel_size) for dp_idx in 0 to C.data_parallel_size do @@ -11,7 +11,7 @@ function PADataParallel(Graph G, Resource R, Config C): function PATensorParallel(Graph G, Resource R, Config C): - for node in G.nodes() do + for node in G.nodes do algorithm <- getPartitionAlgo(node, 'tensor parallelism') subnodes <- G.partition(node, algorithm, C.tensor_parallel_size) for tp_idx in 0 to C.tensor_parallel_size do @@ -22,16 +22,16 @@ function PATensorParallel(Graph G, Resource R, Config C): function PAPipelineParallel(Graph G, Resource R, Config C): - for node in G.nodes() do + for node in G.nodes do algorithm <- getPartitionAlgo(node, 'data parallelism') G.partition(node, algorithm, C.num_micro_batches) - for node in G.nodes() do - stage_id <- getStageID(node, G, C.num_stages) // policy + for node in G.nodes do + stage_id <- getStageID(node, G, C.pipeline_parallel_size) // policy rank <- mapStageToRank(stage_id, R) G.assign(node, stage) - groupStageAndMicroBatch(G, C.num_stages, C.num_micro_batches) + groupStageAndMicroBatch(G, C.pipeline_parallel_size, C.num_micro_batches) return G @@ -40,19 +40,19 @@ function PSPipelineParallel(Graph G, Resource R, Config C): sequence <- EmptyArray[] // warmup phase for micro_batch_id in 0 to C.num_micro_batches do - for stage_id in 0 to C.num_stages - micro_batch_id do + for stage_id in 0 to C.pipeline_parallel_size - micro_batch_id do node <- getForwardStage(G, micro_batch_id, stage_id) arrayPush(sequence, node) # steady and cooldown phase for micro_batch_id in 0 to C.num_micro_batches do // enqueue backward - for stage_id in C.num_stages to 0 do + for stage_id in C.pipeline_parallel_size to 0 do node <- getBackwardStage(G, micro_batch_id, stage_id) arrayPush(sequence, node) // enqueue forward - for stage_id in 0 to C.num_stages do - mid <- micro_batch_id + C.num_stages - stage_id - if mid <= C.num_stages then + for stage_id in 0 to C.pipeline_parallel_size do + mid <- micro_batch_id + C.pipeline_parallel_size - stage_id + if mid <= C.pipeline_parallel_size then node <- getForwardStage(G, mid, stage_id) arrayPush(sequence, node) G.schedule(sequence) @@ -60,15 +60,32 @@ function PSPipelineParallel(Graph G, Resource R, Config C): function Megatron(Graph G, Resource R, Config C): - // Resource split - R_data, R_pipe, R_tensor <- splitResource(R, C) - // split to stages - G <- PAPipelineParallel(G, R_pipe, C) - // inner stage: data + tensor parallelism - for stage in G.nodes: - PADataParallel(stage, R_data, C) - PATensorParallel(stage, R_tensor, C) - // inter stage: 1F1B scheduling - G <- PSPipelineParallel(G, R_pipe, C) + // Graph G: Dataflow graph containing operators as nodes + // Resource R: Environment Resource including GPU numbers and topology + // Config C: policy user configuration including: + // data_parallel_size, + // tensor_parallel_size, + // pipeline_parallel_size, + // num_micro_batches + + // Resource split: group resources + Rs <- splitResource(R, C) + R_pp <- getResourceForPP(Rs, C) + + // split to stages and micro-batches + G <- PAPipelineParallel(G, R_pp, C) + + // inter / inner stage scheduling: 1F1B scheduling + G <- PSPipelineParallel(G, R_pp, C) + + // inner stage parallelism: hybrid parallelism + for stage in G.nodes do + // data parallelism + R_dp <- getResourceForDP(Rs, stage_id) + PADataParallel(stage, R_dp, C) + // tensor parallelism + R_tp <- getResourceForTP(Rs, stage_id) + PATensorParallel(stage, R_tp, C) + return G ``` \ No newline at end of file From ceaeb244ac1f88ca49f4aab5134388f62cb104d9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Feb 2022 13:39:56 +0800 Subject: [PATCH 0587/1892] update megatron pseudo code --- examples/gpt/policy/megatron.md | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/gpt/policy/megatron.md b/examples/gpt/policy/megatron.md index 86489a4d..86954c76 100644 --- a/examples/gpt/policy/megatron.md +++ b/examples/gpt/policy/megatron.md @@ -31,23 +31,24 @@ function PAPipelineParallel(Graph G, Resource R, Config C): rank <- mapStageToRank(stage_id, R) G.assign(node, stage) + // group to a sub-graph (block): A microbatch on one stage groupStageAndMicroBatch(G, C.pipeline_parallel_size, C.num_micro_batches) return G function PSPipelineParallel(Graph G, Resource R, Config C): - // each node in G stands for a stage (sub-graph) + // each node in G stands for a block (sub-graph) sequence <- EmptyArray[] // warmup phase for micro_batch_id in 0 to C.num_micro_batches do for stage_id in 0 to C.pipeline_parallel_size - micro_batch_id do - node <- getForwardStage(G, micro_batch_id, stage_id) + node <- getForwardBlock(G, micro_batch_id, stage_id) arrayPush(sequence, node) # steady and cooldown phase for micro_batch_id in 0 to C.num_micro_batches do // enqueue backward for stage_id in C.pipeline_parallel_size to 0 do - node <- getBackwardStage(G, micro_batch_id, stage_id) + node <- getBackwardBlock(G, micro_batch_id, stage_id) arrayPush(sequence, node) // enqueue forward for stage_id in 0 to C.pipeline_parallel_size do @@ -72,20 +73,21 @@ function Megatron(Graph G, Resource R, Config C): Rs <- splitResource(R, C) R_pp <- getResourceForPP(Rs, C) - // split to stages and micro-batches + // group into blocks (each block is a microbatch on a stage) G <- PAPipelineParallel(G, R_pp, C) - // inter / inner stage scheduling: 1F1B scheduling + // inter block scheduling: 1F1B scheduling G <- PSPipelineParallel(G, R_pp, C) - // inner stage parallelism: hybrid parallelism - for stage in G.nodes do + // inner block parallelism: hybrid parallelism + for block in G.nodes do + stage_id <- getStageID(G, block) // data parallelism R_dp <- getResourceForDP(Rs, stage_id) - PADataParallel(stage, R_dp, C) + PADataParallel(block, R_dp, C) // tensor parallelism R_tp <- getResourceForTP(Rs, stage_id) - PATensorParallel(stage, R_tp, C) + PATensorParallel(block, R_tp, C) return G ``` \ No newline at end of file From 27773231506d881bebb64ceee941820d1c4fb413 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Feb 2022 17:18:19 +0800 Subject: [PATCH 0588/1892] remove device mapping --- examples/gpt/policy/megatron.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/gpt/policy/megatron.md b/examples/gpt/policy/megatron.md index 86954c76..79c35e8a 100644 --- a/examples/gpt/policy/megatron.md +++ b/examples/gpt/policy/megatron.md @@ -5,8 +5,7 @@ function PADataParallel(Graph G, Resource R, Config C): algorithm <- getPartitionAlgo(node, 'data parallelism') subnodes <- G.partition(node, algorithm, C.data_parallel_size) for dp_idx in 0 to C.data_parallel_size do - rank <- mapDpToRank(dp_idx, R) - G.assign(subnodes[dp_idx], rank) + G.assign(subnodes[dp_idx], dp_idx) return G @@ -15,8 +14,7 @@ function PATensorParallel(Graph G, Resource R, Config C): algorithm <- getPartitionAlgo(node, 'tensor parallelism') subnodes <- G.partition(node, algorithm, C.tensor_parallel_size) for tp_idx in 0 to C.tensor_parallel_size do - rank <- mapTpToRank(tp_idx, R) - G.assign(subnodes[tp_idx], rank) + G.assign(subnodes[tp_idx], tp_idx) return G @@ -28,8 +26,7 @@ function PAPipelineParallel(Graph G, Resource R, Config C): for node in G.nodes do stage_id <- getStageID(node, G, C.pipeline_parallel_size) // policy - rank <- mapStageToRank(stage_id, R) - G.assign(node, stage) + G.assign(node, stage_id) // group to a sub-graph (block): A microbatch on one stage groupStageAndMicroBatch(G, C.pipeline_parallel_size, C.num_micro_batches) From e9dbe74db18ce5027fb692b6fa7418bbbef65947 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 10 Feb 2022 09:17:01 +0800 Subject: [PATCH 0589/1892] hybrid partition on conv --- examples/poisson/policy/naive.py | 7 ++++++- examples/poisson/sci.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/poisson/policy/naive.py b/examples/poisson/policy/naive.py index bb9f45bf..0863b65e 100644 --- a/examples/poisson/policy/naive.py +++ b/examples/poisson/policy/naive.py @@ -4,8 +4,13 @@ def PAS(graph: IRGraph, resource): for node in graph.nodes(): if isinstance(node, IRConv2D): + sub_nodes = list() algo = node.algorithms('halo') - sub_nodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus)) + Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus // 2)) + for Wnode in Wnodes: + algo = Wnode.algorithms('halo') + Hnodes = graph.partition(Wnode, algo, config=dict(idx=0, dim=2, num=2)) + sub_nodes += Hnodes else: sub_nodes = graph.replicate(node, times=resource.ngpus) # sub_nodes = graph.replicate(node, times=resource.ngpus) diff --git a/examples/poisson/sci.py b/examples/poisson/sci.py index f3501542..8d78e90b 100644 --- a/examples/poisson/sci.py +++ b/examples/poisson/sci.py @@ -11,7 +11,7 @@ """ OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=2 \ + --nproc_per_node=4 \ --nnodes=1 \ examples/poisson/sci.py """ From 26fed107c53e7c813296d2d274d17d65ef40fe74 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 11 Feb 2022 18:54:36 +0800 Subject: [PATCH 0590/1892] grouping consecutive adapters --- cube/compiler.py | 8 ++++- cube/execplan/planpass/grouping.py | 56 ++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/cube/compiler.py b/cube/compiler.py index d4d19896..21a2ba23 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,3 +1,4 @@ +from email.headerregistry import Group from typing import Callable, Tuple, Union, Optional import torch import time @@ -13,7 +14,7 @@ from cube.logics.translator import LogicTranslator from cube.execplan import ExectuionPlan -from cube.execplan.planpass.grouping import Grouping +from cube.execplan.planpass.grouping import Grouping, GroupingAdapter from cube.execplan.planpass.fusion import P2PFusion from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -174,6 +175,11 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on p2pfusion operations: {:.2f} s'.format(span)) + start = time.time() + execplan = GroupingAdapter.apply(execplan) + span = time.time() - start + print('> planpass on grouping adapters : {:.2f} s'.format(span)) + # execplan.draw(outfile='execplan.png') if torch.distributed.is_initialized(): diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 5805f657..57fbf010 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -2,10 +2,12 @@ Operation grouping """ +from sqlite3 import adapt from typing import List, Dict, Tuple from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass +from cube.graph.adapter.adapter import IRAdapter from cube.graph.operator.operator import IRBpOperation, IRFwOperation from cube.ir.cten import IRCell @@ -112,3 +114,57 @@ def consecutive(seq: List[IRCell], pieces: List[IRCell], node: IRCell): if idx != max(pidx) + 1 and idx != min(pidx) - 1: return False return True + + +class GroupingAdapter(PlanPass): + + @staticmethod + def apply(execplan: ExectuionPlan) -> ExectuionPlan: + for devid in execplan.devices(): + groups: List[List[IRAdapter]] = GroupingAdapter.consecutive( + execplan.sequence(devid)) + for adapters in groups: + if len(adapters) <= 1: + continue + sprims, tprims, mprims = list(), list(), list() + inputs, idevices = list(), list() + outputs, odevices = list(), list() + for adapter in adapters: + sprims += adapter.prims(move=False, merge=False, coll=False) + tprims += adapter.prims(select=False, merge=False) + mprims += adapter.prims(select=False, move=False, coll=False) + for idx, input in enumerate(adapter.inputs()): + if devid in adapter.idevice(idx): + if input not in inputs: + inputs.append(input) + idevices.append(adapter.idevice(idx)) + for idx, output in enumerate(adapter.outputs()): + if devid in adapter.odevice(idx): + if output not in outputs: + outputs.append(output) + odevices.append(adapter.odevice(idx)) + prims = sprims + tprims + mprims + fused_adapter = IRAdapter(prims, + inputs = inputs, idevices = idevices, + outputs = outputs, odevices = odevices) + start = execplan.sequence(devid).index(adapters[0]) + end = execplan.sequence(devid).index(adapters[-1]) + for _ in range(end - start + 1): + execplan.at(devid).pop(start) + execplan.at(devid).insert(start, fused_adapter) + return execplan + + @staticmethod + def consecutive(seq: List[IRCell]) -> List[List[IRAdapter]]: + group = list() + curr = list() + curr_idx = -1 + for idx, node in enumerate(seq + [None]): + if isinstance(node, IRAdapter) and idx == curr_idx + 1: + curr.append(node) + else: + if len(curr) != 0: + group.append(curr) + curr = list() + curr_idx = idx + return group From f66389cdc68529874bfd4711a99b8d81302b5301 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 11 Feb 2022 18:55:25 +0800 Subject: [PATCH 0591/1892] remove useless import --- cube/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cube/compiler.py b/cube/compiler.py index 21a2ba23..63db166d 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,4 +1,3 @@ -from email.headerregistry import Group from typing import Callable, Tuple, Union, Optional import torch import time From 9180cfdcc18515acc4df52f5db1eece2ed0acef2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Feb 2022 15:29:14 +0800 Subject: [PATCH 0592/1892] P: keep assigned device set; execplan use analyze; fix einops * parse bug; enable scheduling search --- cube/execplan/execplan.py | 102 ++++++++++++-------- cube/graph/graph.py | 5 + cube/graph/operator/function/einops.py | 6 +- examples/mlp/policy/st_search.py | 123 +++++++++++++++++++++++++ requirements.txt | 3 +- 5 files changed, 200 insertions(+), 39 deletions(-) create mode 100644 examples/mlp/policy/st_search.py diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 515a61ca..44026918 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -1,4 +1,5 @@ -from typing import List, Optional +from ast import Call +from typing import Callable, List, Optional import copy from cube.graph.adapter.adapter import IRAdapter from cube.graph.operator.operator import IRBpOperation, IRFwOperation @@ -73,7 +74,11 @@ def set(self, device_id: int, seq: List[IRCell]): raise TypeError("Expected a list of Cell") self.device_seq[device_id] = seq - def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): + def analyze(self, + map2time: Optional[Callable] = None, + map2mem: Optional[Callable] = None, + map2name: Optional[Callable] = None, + outfile = None): """ Draw the execution timeline. @@ -90,19 +95,50 @@ def draw(self, spans: Optional[List[int]] = None, outfile='./execplan.png'): # timeline [ [ (start_time, end_time), ... ], ... ] device_timeline = [list() for _ in range(ndevice)] device_nodes = [list() for _ in range(ndevice)] - - def map2time(node): - if isinstance(node, IRGraph): - span = 0 - for node in node.nodes(): - span += map2time(node) - if isinstance(node, IRFwOperation): - return 1 - if isinstance(node, IRBpOperation): - return 2 - if isinstance(node, IRAdapter): - return 0.5 - return 0 + device_mem = [0] * ndevice + device_peak_mem = [0] * ndevice + + if map2time is None: + def map2time(node): + if isinstance(node, IRGraph): + span = 0 + for node in node.nodes(): + span += map2time(node) + if isinstance(node, IRFwOperation): + return 1 + if isinstance(node, IRBpOperation): + return 2 + if isinstance(node, IRAdapter): + return 0.5 + return 0 + + if map2mem is None: + def map2mem(node): + if isinstance(node, IRGraph): + peak_mem = 0 + curr_mem = 0 + for node in node.nodes(): + curr_mem += map2mem(node) + peak_mem = max(curr_mem, peak_mem) + if isinstance(node, IRFwOperation): + return 1 + if isinstance(node, IRBpOperation): + return -1 + return 0 + + if map2name is None: + def map2name(node): + if isinstance(node, IRGraph): + if all([isinstance(n, IRFwOperation) for n in node.nodes()]): + return f'f{node._id}' + if all([isinstance(n, IRBpOperation) for n in node.nodes()]): + if node.mirror is not None: + return f'b{node.mirror._id}' + if isinstance(node, IRFwOperation): + return f'f{node._id}' + if isinstance(node, IRBpOperation): + return f'b{node.mirror._id}' + return str(node._id) def map2color(node): if isinstance(node, IRGraph): @@ -114,25 +150,14 @@ def map2color(node): if isinstance(node, IRAdapter): return '#70AD47' # excel green - def map2name(node): - if isinstance(node, IRGraph): - if all([isinstance(n, IRFwOperation) for n in node.nodes()]): - return f'f{node._id}' - if all([isinstance(n, IRBpOperation) for n in node.nodes()]): - if node.mirror is not None: - return f'b{node.mirror._id}' - return str(node._id) - - if spans is None: - print("Using default timing: fwop=1, bwop=2, adapter=0.1") - spans = list() - for node in self.graph.nodes(): - span = map2time(node) - spans.append(span) - graph = self.graph - for node, span_time in zip(self.graph.nodes(), spans): + for node in self.graph.nodes(): + span, mem = map2time(node), map2mem(node) for device in node.device: + # memory + device_mem[device] += mem + if device_peak_mem[device] < device_mem[device]: + device_peak_mem[device] = device_mem[device] # tight execution if no dependency if len(device_timeline[device]) == 0: start_time = 1 @@ -148,17 +173,19 @@ def map2name(node): if graph.happen_before(other_node, node): start_time = max(start_time, end_time) break - device_timeline[device].append((start_time, start_time + span_time)) + device_timeline[device].append((start_time, start_time + span)) device_nodes[device].append(node) + max_time = max( + [tline[-1][1] for tline in device_timeline if len(tline) != 0] + ) + max_mem = max(device_peak_mem) + # draw the timeline if outfile is not None: import matplotlib.pyplot as plt from matplotlib.patches import Rectangle - max_time = max( - [tline[-1][1] for tline in device_timeline if len(tline) != 0] - ) plt.rcParams['figure.figsize'] = (4.0 * max_time // ndevice, 4.0) fig, ax = plt.subplots() renderer = fig.canvas.get_renderer() @@ -197,7 +224,7 @@ def map2name(node): for fs in range(40, 1, -2): txt.set_fontsize(fs) tbox = txt.get_window_extent(renderer) - if tbox.x0 >= rbox.x0 and tbox.x1 <= rbox.x1 and tbox.y0 >= rbox.y0 and tbox.y1 <= rbox.y1: + if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: break fontsize = min(fontsize, fs) txts.append(txt) @@ -216,6 +243,7 @@ def map2name(node): plt.tight_layout() plt.savefig(outfile) + return max_time, max_mem def __repr__(self): dscp = f'Execution Plan ({self.graph.name}):\n' diff --git a/cube/graph/graph.py b/cube/graph/graph.py index f7f66c76..d9b5c664 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -377,6 +377,11 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional bnode.update() self.attach(bnode, idx) updated.add(fnode._id) + # update device + for fnode in fnodes: + fnode.device = op.device + if isinstance(fnode.mirror, IRCell): + fnode.mirror.device = op.device self.reset_dependency() return fnodes diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index de135146..89136a05 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -359,7 +359,11 @@ def parse(self, anno: EinopAnno): if len(unused_annos) < span: raise RuntimeError("Too many introduced dimensions") for dim in range(span): - expand_dims.append(EinDim([unused_annos[dim]])) + if '*' not in anno.anno.split('->')[-1]: + anno_dim = EinDim([unused_annos[dim] + '+']) + else: + anno_dim = EinDim([unused_annos[dim]]) + expand_dims.append(anno_dim) if len(expand_dims) != span: return False, None, None anno.inputs[idx] = anno.inputs[idx][:start] + expand_dims + anno.inputs[idx][start+1:] diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py new file mode 100644 index 00000000..b148f252 --- /dev/null +++ b/examples/mlp/policy/st_search.py @@ -0,0 +1,123 @@ +import copy +from typing import Callable, List +import sys +from cube.graph.graph import IRGraph, IRFwOperation +from cube.graph.operator.operator import IRBpOperation, IRDataOperation +from cube.ir.cten import IRCell +from cube.execplan import ExectuionPlan + + +def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): + """ + Get ready-to-emit node list from remain node set + """ + ready = list() + for node in remain: + satisfy = True + for pre in node.predecessors(): + if pre not in seq: + satisfy = False + break + if satisfy: + # if len(seq) > 0 and len(seq[-1].device) != 0 and len(node.device) != 0: + # # no dependency pruning + # if seq[-1] not in node.predecessors(): + # if node.device[0] < seq[-1].device[0]: + # continue + ready.append(node) + return ready + + +def topo_sequence(nodes: List[IRCell], seq = None): + if seq is None: + seq = list() + if len(nodes) == 0: + yield seq + # initial entry + entry_nodes = ready_emit_set(remain=nodes, seq=seq) + if len(entry_nodes) == 0: + return None + for node in entry_nodes: + seq = seq + [node] + nid = nodes.index(node) + sub_nodes = nodes[:nid] + nodes[nid+1:] + for res in topo_sequence(sub_nodes, seq): + if res is None: + continue + yield res + seq = seq[:-1] + + +def stage_division(graph: IRGraph, node: IRCell, num_stages: int) -> int: + """ + Determine stage division + """ + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + num_fnodes = len(fnodes) + idx = fnodes.index(node) + stage = min(idx // (num_fnodes // num_stages), num_stages - 1) + return stage + + +def estimator(execplan: ExectuionPlan, map2time: Callable, map2mem: Callable): + """ + Estimate time + """ + max_time, max_mem = execplan.analyze(map2time=map2time, map2mem=map2mem) + return max_time, max_mem + +def PAS(graph: IRGraph, resource): + num_microbatch = 2 + num_stages = 2 + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + # split to micro batches + for node in fnodes: + stage = stage_division(graph, node, num_stages=num_stages) + graph.assign(node, stage) + for node in fnodes: + # partition at batch dimension + algo = node.algorithms('dim') + graph.partition( + node, algo, config=dict(idx=0, dim=0, num=num_microbatch)) + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + graph.assign(node, 0) + + def map2time(node: IRCell): + if isinstance(node, IRFwOperation): + return 1 + if isinstance(node, IRBpOperation): + return 2 + return 0 + + def map2mem(node: IRCell): + if isinstance(node, IRFwOperation): + return 1 + if isinstance(node, IRBpOperation): + return -1 + return 0 + + bucket_mem = list(range(len(fnodes) + 1)) + bucket_times = list([sys.maxsize for _ in range(len(fnodes) + 1)]) + bucket_seqs = [None] * (len(fnodes) + 1) + + print('start sorting...') + for idx, seq in enumerate(topo_sequence(graph.nodes())): + # seqrepr = [node._id for node in seq] + # print(seqrepr) + graph._nodes = seq + execplan = ExectuionPlan(graph) + span, mem = execplan.analyze(map2time=map2time, map2mem=map2mem) + bucket = bucket_mem.index(mem) + + # execplan.draw(outfile='out.png') + # print(span, mem) + # input('>>> ') + + if span < bucket_times[bucket]: + print(f'find better plan at mem budget {mem}: span: {span}') + bucket_times[bucket] = span + bucket_seqs[bucket] = copy.copy(seq) + execplan.analyze(map2time=map2time, outfile=f'plan.mem{mem}.png') + print(f'done search on {idx + 1} sequences') + assert False diff --git a/requirements.txt b/requirements.txt index f2466477..b4306a47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -z3-solver \ No newline at end of file +z3-solver +matplotlib \ No newline at end of file From 720cae91879f490bec5379f264c408da9610b800 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Feb 2022 16:19:50 +0800 Subject: [PATCH 0593/1892] add primitive: add schedule --- cube/graph/graph.py | 20 ++++++++++++ examples/mlp/policy/st_search.py | 54 +++++++++++++++++++++++--------- 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index d9b5c664..086e1f4b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -68,6 +68,8 @@ def __init__(self, def reset_dependency(self): """ Reset the node dataflow dependency + + Note all the predefined control dependencies will be removed. """ for node in self._nodes: node.clear_predecessor() @@ -498,6 +500,24 @@ def happen_before(self, node1: IRCell, node2: IRCell, skip=None) -> bool: return True return False + def add_schedule(self, nodes: List[IRCell]) -> bool: + """ + Add node happen before dependencies according to nodes list order + """ + if not all([isinstance(node, IRCell) for node in nodes]): + raise TypeError("Expected List[IRCell") + for idx in range(len(nodes) - 1): + prev = nodes[idx] + post = nodes[idx + 1] + if self.happen_before(post, prev): + return False + for idx in range(len(nodes) - 1): + prev = nodes[idx] + post = nodes[idx + 1] + prev.add_successor(output_index=-1, cell=post) + post.add_predecessor(input_index=-1, cell=prev) + return True + def set_order(self, seq: List[IRCell]): """ Set a topological order for IRGraph, which requires seq: diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index b148f252..890cb3c3 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -19,11 +19,11 @@ def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): satisfy = False break if satisfy: - # if len(seq) > 0 and len(seq[-1].device) != 0 and len(node.device) != 0: - # # no dependency pruning - # if seq[-1] not in node.predecessors(): - # if node.device[0] < seq[-1].device[0]: - # continue + if len(seq) > 0 and len(seq[-1].device) != 0 and len(node.device) != 0: + # pruning #1: filter out equal sequences + if seq[-1] not in node.predecessors(): + if node.device[0] < seq[-1].device[0]: + continue ready.append(node) return ready @@ -67,9 +67,24 @@ def estimator(execplan: ExectuionPlan, map2time: Callable, map2mem: Callable): return max_time, max_mem def PAS(graph: IRGraph, resource): - num_microbatch = 2 - num_stages = 2 + num_microbatch = 4 + num_stages = 4 + fstages = [list() for _ in range(num_microbatch * num_stages)] + + def f(micro_batch_id: int, stage_id: int): + return fstages[micro_batch_id * num_stages + stage_id] + + def b(micro_batch_id: int, stage_id: int): + fstage = f(micro_batch_id, stage_id) + bstage = [fnode.mirror for fnode in fstage][::-1] + return bstage + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + graph.assign(node, 0) + # split to micro batches for node in fnodes: stage = stage_division(graph, node, num_stages=num_stages) @@ -77,11 +92,20 @@ def PAS(graph: IRGraph, resource): for node in fnodes: # partition at batch dimension algo = node.algorithms('dim') - graph.partition( + sub_nodes = graph.partition( node, algo, config=dict(idx=0, dim=0, num=num_microbatch)) - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - graph.assign(node, 0) + for mid, sub_node in enumerate(sub_nodes): + f(mid, stage).append(sub_node) + + # pruning #2: symmetric microbatches, make micro-batch id smaller happen earlier + for sid in range(num_stages): + fops = list() + bops = list() + for mid in range(num_microbatch): + fops += f(mid, sid) + bops += b(mid, sid) + assert graph.add_schedule(fops) + assert graph.add_schedule(bops) def map2time(node: IRCell): if isinstance(node, IRFwOperation): @@ -97,9 +121,9 @@ def map2mem(node: IRCell): return -1 return 0 - bucket_mem = list(range(len(fnodes) + 1)) - bucket_times = list([sys.maxsize for _ in range(len(fnodes) + 1)]) - bucket_seqs = [None] * (len(fnodes) + 1) + bucket_mem = list(range(num_microbatch * len(fnodes) // num_stages + 1)) + bucket_times = list([sys.maxsize for _ in range(len(bucket_mem))]) + bucket_seqs = [None] * len(bucket_mem) print('start sorting...') for idx, seq in enumerate(topo_sequence(graph.nodes())): @@ -113,6 +137,8 @@ def map2mem(node: IRCell): # execplan.draw(outfile='out.png') # print(span, mem) # input('>>> ') + if (idx + 1) % 5000 == 0: + print(f'progress: searched {(idx + 1) // 1000}K seqs') if span < bucket_times[bucket]: print(f'find better plan at mem budget {mem}: span: {span}') From add1972231cc6a0f343b70a7cda843db8438e61e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Feb 2022 20:14:30 +0800 Subject: [PATCH 0594/1892] fix analyze bug by reseting dependencies --- cube/execplan/execplan.py | 5 +- examples/mlp/policy/st_search.py | 125 ++++++++++++++++++++++++------- 2 files changed, 100 insertions(+), 30 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 44026918..2c8a9587 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -90,7 +90,6 @@ def analyze(self, outfile: the output file name """ - self.graph.reset_dependency() ndevice = len(self.devices()) # timeline [ [ (start_time, end_time), ... ], ... ] device_timeline = [list() for _ in range(ndevice)] @@ -170,7 +169,7 @@ def map2color(node): continue for nid, (_, end_time) in enumerate(timeline[::-1]): other_node = dev_seq[::-1][nid] - if graph.happen_before(other_node, node): + if other_node in node.predecessors(): start_time = max(start_time, end_time) break device_timeline[device].append((start_time, start_time + span)) @@ -185,7 +184,7 @@ def map2color(node): if outfile is not None: import matplotlib.pyplot as plt from matplotlib.patches import Rectangle - + plt.close('all') plt.rcParams['figure.figsize'] = (4.0 * max_time // ndevice, 4.0) fig, ax = plt.subplots() renderer = fig.canvas.get_renderer() diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index 890cb3c3..793d8f85 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -6,6 +6,8 @@ from cube.ir.cten import IRCell from cube.execplan import ExectuionPlan +from multiprocessing import Pool + def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): """ @@ -48,11 +50,22 @@ def topo_sequence(nodes: List[IRCell], seq = None): seq = seq[:-1] -def stage_division(graph: IRGraph, node: IRCell, num_stages: int) -> int: +def topo_sequence_batch(nodes: List[IRCell], bs=1): + seqs = list() + for idx, seq in enumerate(topo_sequence(nodes)): + seqs.append(seq) + if len(seqs) % bs == 0: + print(f'dispatch {len(seqs)} seq...') + yield seqs + seqs = list() + if len(seqs) > 0: + yield seqs + + +def stage_division(fnodes: List[IRCell], node: IRCell, num_stages: int) -> int: """ Determine stage division """ - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] num_fnodes = len(fnodes) idx = fnodes.index(node) stage = min(idx // (num_fnodes // num_stages), num_stages - 1) @@ -66,8 +79,39 @@ def estimator(execplan: ExectuionPlan, map2time: Callable, map2mem: Callable): max_time, max_mem = execplan.analyze(map2time=map2time, map2mem=map2mem) return max_time, max_mem + +def worker(seqs: List[List[IRCell]], bucket_mem: List[int]): + def map2time(node: IRCell): + if isinstance(node, IRFwOperation): + return 1 + if isinstance(node, IRBpOperation): + return 2 + return 0 + + def map2mem(node: IRCell): + if isinstance(node, IRFwOperation): + return 1 + if isinstance(node, IRBpOperation): + return -1 + return 0 + + bucket_times = list([sys.maxsize for _ in range(len(bucket_mem))]) + bucket_seqs = [None] * len(bucket_mem) + graph = IRGraph([], [], [], 'search') + for seq in seqs: + graph._nodes = seq + # graph.reset_dependency() # this needs as in other process dependency will break + execplan = ExectuionPlan(graph) + span, mem = execplan.analyze(map2time=map2time, map2mem=map2mem) + bucket = bucket_mem.index(mem) + if span < bucket_times[bucket]: + bucket_times[bucket] = span + bucket_seqs[bucket] = copy.copy(seq) + return bucket_times, bucket_seqs + + def PAS(graph: IRGraph, resource): - num_microbatch = 4 + num_microbatch = 8 num_stages = 4 fstages = [list() for _ in range(num_microbatch * num_stages)] @@ -87,14 +131,12 @@ def b(micro_batch_id: int, stage_id: int): # split to micro batches for node in fnodes: - stage = stage_division(graph, node, num_stages=num_stages) - graph.assign(node, stage) - for node in fnodes: - # partition at batch dimension + stage = stage_division(fnodes, node, num_stages=num_stages) algo = node.algorithms('dim') sub_nodes = graph.partition( node, algo, config=dict(idx=0, dim=0, num=num_microbatch)) for mid, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, stage) f(mid, stage).append(sub_node) # pruning #2: symmetric microbatches, make micro-batch id smaller happen earlier @@ -126,24 +168,53 @@ def map2mem(node: IRCell): bucket_seqs = [None] * len(bucket_mem) print('start sorting...') - for idx, seq in enumerate(topo_sequence(graph.nodes())): - # seqrepr = [node._id for node in seq] - # print(seqrepr) - graph._nodes = seq - execplan = ExectuionPlan(graph) - span, mem = execplan.analyze(map2time=map2time, map2mem=map2mem) - bucket = bucket_mem.index(mem) - - # execplan.draw(outfile='out.png') - # print(span, mem) - # input('>>> ') - if (idx + 1) % 5000 == 0: - print(f'progress: searched {(idx + 1) // 1000}K seqs') - - if span < bucket_times[bucket]: - print(f'find better plan at mem budget {mem}: span: {span}') - bucket_times[bucket] = span - bucket_seqs[bucket] = copy.copy(seq) - execplan.analyze(map2time=map2time, outfile=f'plan.mem{mem}.png') - print(f'done search on {idx + 1} sequences') + + nproc = 24 + worker_samples = 1000 + pool = Pool(processes=nproc) + for idx, seqs in enumerate(topo_sequence_batch(graph.nodes(), bs=nproc * worker_samples)): + results = list() + for wid in range(nproc): + start = min(nproc * worker_samples, worker_samples* wid) + stop = min(nproc * worker_samples, start + worker_samples) + worker_seqs = seqs[start:stop] + results.append(pool.apply_async(worker, (worker_seqs, bucket_mem))) + results = map(lambda res: res.get(), results) + # merge results + for times, res_seqs in results: + for mem, (span_new, span_old) in enumerate(zip(times, bucket_times)): + if span_new < span_old: + print(f'find better plan at mem budget {mem}: span: {span_new}') + bucket_times[mem] = span_new + bucket_seqs[mem] = res_seqs[mem] + _graph = IRGraph([], [], [], 'search') + _graph._nodes = res_seqs[mem] + execplan = ExectuionPlan(_graph) + execplan.analyze(map2time=map2time, outfile=f'plan.mem{mem}.png') + if (idx + 1) % 1 == 0: + print(f'progress: searched {(idx + 1) * nproc * worker_samples} K sequences') + if len(seqs) != worker_samples: + num = idx + len(seqs) / (worker_samples * nproc) + print(f'done search on {int(num * nproc * worker_samples)} K sequences') assert False + + # _graph = IRGraph([], [], [], 'search') + # for idx, seq in enumerate(topo_sequence(graph.nodes())): + # _graph._nodes = seq + # execplan = ExectuionPlan(_graph) + # span, mem = execplan.analyze(map2time=map2time, map2mem=map2mem) + # bucket = bucket_mem.index(mem) + # + # # execplan.draw(outfile='out.png') + # # print(span, mem) + # # input('>>> ') + # if (idx + 1) % 5000 == 0: + # print(f'progress: searched {(idx + 1) // 1000}K seqs') + # + # if span < bucket_times[bucket]: + # print(f'find better plan at mem budget {mem}: span: {span}') + # bucket_times[bucket] = span + # bucket_seqs[bucket] = copy.copy(seq) + # execplan.analyze(map2time=map2time, outfile=f'plan.mem{mem}.png') + # print(f'done search on {idx + 1} sequences') + # assert False From 076a00d45a985caded9c3b1960bf473ed9f9819b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 23 Feb 2022 09:55:46 +0800 Subject: [PATCH 0595/1892] multi-process search --- examples/mlp/policy/st_search.py | 298 +++++++++++++++---------------- 1 file changed, 140 insertions(+), 158 deletions(-) diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index 793d8f85..e5fa671e 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -1,6 +1,5 @@ import copy -from typing import Callable, List -import sys +from typing import List, Tuple, Dict from cube.graph.graph import IRGraph, IRFwOperation from cube.graph.operator.operator import IRBpOperation, IRDataOperation from cube.ir.cten import IRCell @@ -9,126 +8,136 @@ from multiprocessing import Pool -def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): +class Estimator: """ - Get ready-to-emit node list from remain node set + A node tag is represented as (mem_weight, mem_activation, exec_time) """ - ready = list() - for node in remain: - satisfy = True - for pre in node.predecessors(): - if pre not in seq: - satisfy = False - break - if satisfy: - if len(seq) > 0 and len(seq[-1].device) != 0 and len(node.device) != 0: - # pruning #1: filter out equal sequences - if seq[-1] not in node.predecessors(): - if node.device[0] < seq[-1].device[0]: - continue - ready.append(node) - return ready - - -def topo_sequence(nodes: List[IRCell], seq = None): - if seq is None: - seq = list() - if len(nodes) == 0: - yield seq - # initial entry - entry_nodes = ready_emit_set(remain=nodes, seq=seq) - if len(entry_nodes) == 0: - return None - for node in entry_nodes: - seq = seq + [node] - nid = nodes.index(node) - sub_nodes = nodes[:nid] + nodes[nid+1:] - for res in topo_sequence(sub_nodes, seq): - if res is None: - continue - yield res - seq = seq[:-1] - - -def topo_sequence_batch(nodes: List[IRCell], bs=1): - seqs = list() - for idx, seq in enumerate(topo_sequence(nodes)): - seqs.append(seq) - if len(seqs) % bs == 0: - print(f'dispatch {len(seqs)} seq...') - yield seqs - seqs = list() - if len(seqs) > 0: - yield seqs - + @staticmethod + def taging(graph: IRGraph): + for node in graph.nodes(): + # tag: (mem_weight, mem_activation, span) + if isinstance(node, IRFwOperation): + node.cost = (0, 1, 1) + elif isinstance(node, IRBpOperation): + node.cost = (0, -1, 2) + else: + node.cost = (0, 0, 0) + + @staticmethod + def map2mem(node: IRCell): + if node.cost is not None: + mem_w, mem_a, span = node.cost + else: + mem_w, mem_a, span = 0, 0, 0 + return mem_w + mem_a -def stage_division(fnodes: List[IRCell], node: IRCell, num_stages: int) -> int: - """ - Determine stage division - """ - num_fnodes = len(fnodes) - idx = fnodes.index(node) - stage = min(idx // (num_fnodes // num_stages), num_stages - 1) - return stage + @staticmethod + def map2time(node: IRCell): + if node.cost is not None: + mem_w, mem_a, span = node.cost + else: + mem_w, mem_a, span = 0, 0, 0 + return span -def estimator(execplan: ExectuionPlan, map2time: Callable, map2mem: Callable): +class TSampler: """ - Estimate time + Schedule sampler """ - max_time, max_mem = execplan.analyze(map2time=map2time, map2mem=map2mem) - return max_time, max_mem - - -def worker(seqs: List[List[IRCell]], bucket_mem: List[int]): - def map2time(node: IRCell): - if isinstance(node, IRFwOperation): - return 1 - if isinstance(node, IRBpOperation): - return 2 - return 0 + @staticmethod + def topo_sequence_batch(nodes: List[IRCell], bs=1): + seqs = list() + for idx, seq in enumerate(TSampler.topo_sequence(nodes)): + seqs.append(seq) + if len(seqs) % bs == 0: + print(f'dispatch {len(seqs)} seq...') + yield seqs + seqs = list() + if len(seqs) > 0: + yield seqs - def map2mem(node: IRCell): - if isinstance(node, IRFwOperation): - return 1 - if isinstance(node, IRBpOperation): - return -1 - return 0 - - bucket_times = list([sys.maxsize for _ in range(len(bucket_mem))]) - bucket_seqs = [None] * len(bucket_mem) - graph = IRGraph([], [], [], 'search') - for seq in seqs: - graph._nodes = seq - # graph.reset_dependency() # this needs as in other process dependency will break - execplan = ExectuionPlan(graph) - span, mem = execplan.analyze(map2time=map2time, map2mem=map2mem) - bucket = bucket_mem.index(mem) - if span < bucket_times[bucket]: - bucket_times[bucket] = span - bucket_seqs[bucket] = copy.copy(seq) - return bucket_times, bucket_seqs + @staticmethod + def topo_sequence(nodes: List[IRCell], seq = None): + if seq is None: + seq = list() + if len(nodes) == 0: + yield seq + # initial entry + entry_nodes = TSampler.ready_emit_set(remain=nodes, seq=seq) + if len(entry_nodes) == 0: + return None + for node in entry_nodes: + seq = seq + [node] + nid = nodes.index(node) + sub_nodes = nodes[:nid] + nodes[nid+1:] + for res in TSampler.topo_sequence(sub_nodes, seq): + if res is None: + continue + yield res + seq = seq[:-1] + + @staticmethod + def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): + """ + Get ready-to-emit node list from remain node set + """ + ready = list() + for node in remain: + satisfy = True + for pre in node.predecessors(): + if pre not in seq: + satisfy = False + break + if satisfy: + if len(seq) > 0 and len(seq[-1].device) != 0 and len(node.device) != 0: + # pruning #1: filter out equal sequences + if seq[-1] not in node.predecessors(): + if node.device[0] < seq[-1].device[0]: + continue + ready.append(node) + return ready + + +class Searcher: + + @staticmethod + def run(seqs: List[List[IRCell]]): + # mem -> (time, seq) + bucket = dict() + graph = IRGraph([], [], [], 'search') + for seq in seqs: + graph._nodes = seq + # graph.reset_dependency() # this needs as in other process dependency will break + execplan = ExectuionPlan(graph) + span, mem = execplan.analyze(map2time=Estimator.map2time, map2mem=Estimator.map2mem) + if mem not in bucket: + bucket[mem] = (span, copy.copy(seq)) + elif bucket[mem][0] > span: + bucket[mem] = (span, copy.copy(seq)) + return bucket def PAS(graph: IRGraph, resource): - num_microbatch = 8 + num_microbatch = 4 num_stages = 4 - fstages = [list() for _ in range(num_microbatch * num_stages)] + # ============================ micro-batch / stage split ============================ + fstages = [list() for _ in range(num_microbatch * num_stages)] def f(micro_batch_id: int, stage_id: int): return fstages[micro_batch_id * num_stages + stage_id] - def b(micro_batch_id: int, stage_id: int): fstage = f(micro_batch_id, stage_id) bstage = [fnode.mirror for fnode in fstage][::-1] return bstage fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - graph.assign(node, 0) - + def stage_division(fnodes: List[IRCell], node: IRCell, num_stages: int) -> int: + """Determine stage division + """ + num_fnodes = len(fnodes) + idx = fnodes.index(node) + stage = min(idx // (num_fnodes // num_stages), num_stages - 1) + return stage # split to micro batches for node in fnodes: stage = stage_division(fnodes, node, num_stages=num_stages) @@ -138,6 +147,11 @@ def b(micro_batch_id: int, stage_id: int): for mid, sub_node in enumerate(sub_nodes): graph.assign(sub_node, stage) f(mid, stage).append(sub_node) + # data operator + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + graph.assign(node, 0) + # ============================ micro-batch / stage split ============================ # pruning #2: symmetric microbatches, make micro-batch id smaller happen earlier for sid in range(num_stages): @@ -149,72 +163,40 @@ def b(micro_batch_id: int, stage_id: int): assert graph.add_schedule(fops) assert graph.add_schedule(bops) - def map2time(node: IRCell): - if isinstance(node, IRFwOperation): - return 1 - if isinstance(node, IRBpOperation): - return 2 - return 0 - - def map2mem(node: IRCell): - if isinstance(node, IRFwOperation): - return 1 - if isinstance(node, IRBpOperation): - return -1 - return 0 - - bucket_mem = list(range(num_microbatch * len(fnodes) // num_stages + 1)) - bucket_times = list([sys.maxsize for _ in range(len(bucket_mem))]) - bucket_seqs = [None] * len(bucket_mem) + Estimator.taging(graph) + # memory (int) -> (time, seq) + bucket = dict() print('start sorting...') - nproc = 24 - worker_samples = 1000 + nproc, worker_samples = 32, 512 pool = Pool(processes=nproc) - for idx, seqs in enumerate(topo_sequence_batch(graph.nodes(), bs=nproc * worker_samples)): + _graph = IRGraph([], [], [], 'search') + for idx, seqs in enumerate(TSampler.topo_sequence_batch(graph.nodes(), bs=nproc * worker_samples)): results = list() for wid in range(nproc): - start = min(nproc * worker_samples, worker_samples* wid) - stop = min(nproc * worker_samples, start + worker_samples) + start = worker_samples* wid + stop = start + worker_samples worker_seqs = seqs[start:stop] - results.append(pool.apply_async(worker, (worker_seqs, bucket_mem))) - results = map(lambda res: res.get(), results) + results.append(pool.apply_async(Searcher.run, (worker_seqs,))) + results: List[Dict[int, Tuple[int, List]]] = map(lambda res: res.get(), results) # merge results - for times, res_seqs in results: - for mem, (span_new, span_old) in enumerate(zip(times, bucket_times)): - if span_new < span_old: - print(f'find better plan at mem budget {mem}: span: {span_new}') - bucket_times[mem] = span_new - bucket_seqs[mem] = res_seqs[mem] - _graph = IRGraph([], [], [], 'search') - _graph._nodes = res_seqs[mem] + for worker_bucket in results: + for mem, (span, seq) in worker_bucket.items(): + better = False + if mem not in bucket: + better = True + elif bucket[mem][0] > span: + better = True + if better: + print(f'find better plan at mem budget {mem}: span: {span}') + bucket[mem] = (span, seq) + _graph._nodes = seq execplan = ExectuionPlan(_graph) - execplan.analyze(map2time=map2time, outfile=f'plan.mem{mem}.png') + execplan.analyze(map2time=Estimator.map2time, outfile=f'plan.mem{mem}.png') if (idx + 1) % 1 == 0: - print(f'progress: searched {(idx + 1) * nproc * worker_samples} K sequences') + print(f'progress: searched {(idx) * nproc * worker_samples + len(seqs)} K sequences') if len(seqs) != worker_samples: - num = idx + len(seqs) / (worker_samples * nproc) - print(f'done search on {int(num * nproc * worker_samples)} K sequences') + num = idx * nproc * worker_samples + len(seqs) + print(f'done search on {num} K sequences') assert False - - # _graph = IRGraph([], [], [], 'search') - # for idx, seq in enumerate(topo_sequence(graph.nodes())): - # _graph._nodes = seq - # execplan = ExectuionPlan(_graph) - # span, mem = execplan.analyze(map2time=map2time, map2mem=map2mem) - # bucket = bucket_mem.index(mem) - # - # # execplan.draw(outfile='out.png') - # # print(span, mem) - # # input('>>> ') - # if (idx + 1) % 5000 == 0: - # print(f'progress: searched {(idx + 1) // 1000}K seqs') - # - # if span < bucket_times[bucket]: - # print(f'find better plan at mem budget {mem}: span: {span}') - # bucket_times[bucket] = span - # bucket_seqs[bucket] = copy.copy(seq) - # execplan.analyze(map2time=map2time, outfile=f'plan.mem{mem}.png') - # print(f'done search on {idx + 1} sequences') - # assert False From d1d5142141dc0a2a8169359d4875bfdce1e37b9e Mon Sep 17 00:00:00 2001 From: lynex Date: Wed, 23 Feb 2022 13:17:45 +0800 Subject: [PATCH 0596/1892] add wrf for example --- examples/wrf/wrf.py | 286 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 examples/wrf/wrf.py diff --git a/examples/wrf/wrf.py b/examples/wrf/wrf.py new file mode 100644 index 00000000..16099a82 --- /dev/null +++ b/examples/wrf/wrf.py @@ -0,0 +1,286 @@ +import torch +torch.set_default_tensor_type(torch.DoubleTensor) +from torch import nn +import torch.nn.functional as F +# from linalg import tridiagonal + + +class WRF(torch.nn.Module): + r"""WRF Model + + Args: + theta (Tensor): inital potential temperature, (nz + 2, ny, nx), including boundary condition + p_t (Tensor): pressure at model top, (ny, nx), if None all zeros + p_s (Tensor): pressure at surface, (ny, nx), if None all sea level pressure + u0 (Tensor): inital x flow, (nz, ny, nx + 1), if None all zeros, default to None. + v0 (Tensor): inital y flow, (nz, ny + 1, nx), if None all zeros, detault to None. + w0 (Tensor): inital z flow, (nz - 1, ny, nx), if None all zeros, default to None. + """ + + def __init__(self, dx, dy, nx, ny, nz, theta, p_t=None, p_s=None, u0=None, v0=None, w0=None, device='cuda'): + super().__init__() + self.device = device + + # constants + self.PREF = 1e5 # reference pressure, usually sea level pressure, Pa + self.Rd = 287 # gas constant for dry air, J/(kg*K) + self.g = 9.81 # the acceleration of gravity, m/s**2 + + # spatial discretization + self.dx, self.dy, self.dz, self.nx, self.ny, self.nz = dx, dy, 1. / (nz + 1), nx, ny, nz + + # agnostic variables + self.P_t = p_t if p_t else torch.ones((1, ny, nx), device=device) * self.PREF * 0.0 + self.P_s = p_s if p_s else torch.ones((1, ny, nx), device=device) * self.PREF + # pressure (nz, ny, nx) + self.P = torch.linspace(self.dz, 1 - self.dz, nz, device=device).view(nz, 1, 1) * \ + (self.P_s - self.P_t).view(1, ny, nx) + self.P_t + # Alpha (nz, ny, nx) + self.Alpha = self.Rd / self.PREF * theta[1:-1] * (self.P / self.PREF)**(-1/1.4) + + # prognostic variables + # Mu (nz, ny, nx) + self.Mu = torch.ones((nz, 1, 1), device=device) * (self.P_s - self.P_t).view(1, ny, nx) + # self.Mu_t = (self.P_s - self.P_t).view(1, ny, nx) + # self.Mu_s = (self.P_s - self.P_t).view(1, ny, nx) + # Phi (nz - 1, ny, nx) + Phi = torch.zeros((nz + 1, ny, nx), device=device) + Phi[:-1] = self.Mu * self.Alpha * self.dz + for i in range(nz - 1, -1, -1): + Phi[i] += Phi[i + 1] + self.Phi_t = Phi[0].view(1, ny, nx) + self.Phi_s = Phi[-1].view(1, ny, nx) + self.Phi = Phi[1:-1] + # Theta (nz, ny, nx) + self.theta_t = theta[0].view(1, ny, nx) + self.theta_s = theta[-1].view(1, ny, nx) + self.Theta = theta[1:-1] * self.Mu + # U (nz, ny, nx + 1) + self.U = u0 if u0 is not None else torch.zeros((nz, ny, nx + 1), device=device) + # V (nz, ny + 1, nx) + self.V = v0 if v0 is not None else torch.zeros((nz, ny + 1, nx), device=device) + # W (nz - 1, ny, nx) + self.W = w0 if w0 is not None else torch.zeros((nz - 1, ny, nx), device=device) + + def RHS(self, U, V, W, Theta, Mu, Phi): + # volecity + u = U / self.bar_x(self.pad_x(Mu)) + v = V / self.bar_y(self.pad_y(Mu)) + w = W / self.bar_z(Mu) + alpha = -self.delta_z(self.pad_z(Phi, self.Phi_t, self.Phi_s)) / Mu + self.Alpha = alpha + theta = Theta / Mu + p = self.PREF * (self.Rd * theta / self.PREF / alpha)**1.4 + omega = -w * self.g / self.bar_z(alpha) / self.bar_z(Mu) + Omega = omega * self.bar_z(Mu) + self.Omega = Omega + + # advection term + R_U = - self.delta_x(self.bar_x(self.pad_x(U)) * self.bar_x(self.pad_x(u))) \ + - self.delta_y(self.bar_x(self.pad_x(V)) * self.bar_y(self.pad_y(u))) \ + - self.delta_z(self.bar_x(self.pad_x(self.pad_z(Omega))) * self.bar_z(self.pad_z(u))) + + R_V = - self.delta_x(self.bar_y(self.pad_y(U)) * self.bar_x(self.pad_x(v))) \ + - self.delta_y(self.bar_y(self.pad_y(V)) * self.bar_y(self.pad_y(v))) \ + - self.delta_z(self.bar_y(self.pad_y(self.pad_z(Omega))) * self.bar_z(self.pad_z(v))) + + R_W = - self.delta_x(self.bar_z(U) * self.bar_x(self.pad_x(w))) \ + - self.delta_y(self.bar_z(V) * self.bar_y(self.pad_y(w))) \ + - self.delta_z(self.bar_z(self.pad_z(Omega)) * self.bar_z(self.pad_z(w))) + + R_Theta = - self.delta_x(U * self.bar_x(self.pad_x(theta))) \ + - self.delta_y(V * self.bar_y(self.pad_y(theta))) \ + - self.delta_z(self.pad_z(Omega) * self.bar_z(self.pad_z(theta))) + + R_Phi = - self.bar_z(self.bar_x(u)) * self.delta_x(self.bar_x(self.pad_x(Phi))) \ + - self.bar_z(self.bar_y(v)) * self.delta_y(self.bar_y(self.pad_y(Phi))) \ + - omega * self.delta_z(self.bar_z(self.pad_z(Phi, self.Phi_t, self.Phi_s))) + + R_Mu = - self.delta_x(U) - self.delta_y(V) - self.delta_z(self.pad_z(Omega)) + + # pressure term + R_U += - self.bar_x(self.pad_x(Mu)) * self.bar_x(self.pad_x(alpha)) * self.delta_x(self.pad_x(p)) \ + - (self.delta_z(self.bar_x(self.bar_z(self.pad_x(self.pad_z(p, self.P_t, self.P_s))))) * + self.delta_x(self.pad_x(self.bar_z(self.pad_z(Phi, self.Phi_t, self.Phi_s))))) + + R_V += - self.bar_y(self.pad_y(Mu)) * self.bar_y(self.pad_y(alpha)) * self.delta_y(self.pad_y(p)) \ + - (self.delta_z(self.bar_y(self.bar_z(self.pad_y(self.pad_z(p, self.P_t, self.P_s))))) * + self.delta_y(self.pad_y(self.bar_z(self.pad_z(Phi, self.Phi_t, self.Phi_s))))) + + R_W += self.g * (self.delta_z(p) - self.bar_z(Mu)) + + # gravity term + R_Phi += self.g * w + + # Coriolis term + # R_U += + 100 * self.bar_x(self.bar_y(self.pad_x(V))) \ + # - 100 * self.bar_x(self.bar_z(self.pad_x(self.pad_z(W)))) \ + # - u * self.bar_x(self.bar_z(self.pad_x(self.pad_z(W)))) / 6400. / 1000. + # R_V += - 100 * self.bar_x(self.bar_y(self.pad_y(U))) \ + # + 100 * self.bar_y(self.bar_z(self.pad_y(self.pad_z(W)))) \ + # - v * self.bar_y(self.bar_z(self.pad_y(self.pad_z(W)))) / 6400. / 1000. + + return R_U, R_V, R_W, R_Theta, R_Mu, R_Phi, + + def RK3_step(self, U, V, W, Theta, Mu, Phi, dt): + r"""One RK3 Step""" + R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U, V, W, Theta, Mu, Phi) + U_ = U + dt * R_U / 3 + V_ = V + dt * R_V / 3 + W_ = W + dt * R_W / 3 + Theta_ = Theta + dt * R_Theta / 3 + Mu_ = Mu + dt * R_Mu / 3 + Phi_ = Phi + dt * R_Phi / 3 + + R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_) + U_ = U + dt * R_U / 2 + V_ = V + dt * R_V / 2 + W_ = W + dt * R_W / 2 + Theta_ = Theta + dt * R_Theta / 2 + Mu_ = Mu + dt * R_Mu / 2 + Phi_ = Phi + dt * R_Phi / 2 + + R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_) + U += dt * R_U + V += dt * R_V + W += dt * R_W + Theta += dt * R_Theta + Mu += dt * R_Mu + Phi += dt * R_Phi + + return U, V, W, Theta, Mu, Phi + + def forward(self, dt): + self.U, self.V, self.W, self.Theta, self.Mu, self.Phi = \ + self.RK3_step(self.U, self.V, self.W, self.Theta, self.Mu, self.Phi, dt) + + def pad_x(self, X): + r"""Periodic boundary condition in x axis""" + return F.pad(X, (1, 1), "circular") + + def pad_y(self, X): + r"""Periodic boundary condition in y axis""" + Nz, Ny, Nx = X.shape + return F.pad(X.view(1, Nz, Ny, Nx), (0, 0, 1, 1), "circular").view(Nz, Ny + 2, Nx) + + def pad_z(self, X, top=None, surface=None): + r"""Dirichlet boundary condition in z axis""" + _, ny, nx = X.shape + top = top if top is not None else torch.zeros((1, ny, nx), device=X.device) + surface = surface if surface is not None else torch.zeros((1, ny, nx), device=X.device) + return torch.cat((top, X, surface), dim=0) + + def bar_x(self, X): + r"""Numerical scheme for X\bar^x + + Args: + X (Tensor): shape (Nz, Ny, Nx) + + Returns: + Tensor: X\bar^x with shape (Nz, Ny, Nx-1) + """ + Nz, Ny, Nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / 2. + + def delta_x(self, X): + r"""Numerical scheme for \delta_x X + + Args: + X (Tensor): shape (Nz, Ny, Nx) + + Returns: + Tensor: \delta_x X with shape (Nz, Ny, Nx-1) + """ + Nz, Ny, Nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / self.dx + + def bar_y(self, X): + r"""Numerical scheme for X\bar^y + + Args: + X (Tensor): shape (Nz, Ny, Nx) + + Returns: + Tensor: X\bar^y with shape (Nz, Ny-1, Nx) + """ + Nz, Ny, Nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / 2. + + def delta_y(self, X): + r"""Numerical scheme for \delta_y X + + Args: + X (Tensor): shape (Nz, Ny, Nx) + + Returns: + Tensor: \delta_y X with shape (Nz, Ny-1, Nx) + """ + Nz, Ny, Nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 2, 1) + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / self.dy + + def bar_z(self, X): + r"""Numerical scheme for X\bar^z + + Args: + X (Tensor): shape (Nz, Ny, Nx) + + Returns: + Tensor: X\bar^z with shape (Nz-1, Ny, Nx) + """ + Nz, Ny, Nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / 2. + + def delta_z(self, X): + r"""Numerical scheme for \delta_z X + + Args: + X (Tensor): shape (Nz, Ny, Nx) + + Returns: + Tensor: \delta_z X with shape (Nz-1, Ny, Nx) + """ + Nz, Ny, Nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / self.dz + + def _acoustic_step(self, ): + r"""One acustic step""" + pass + + +if __name__ == "__main__": + # simulation settings + nx = 201 + ny = 201 + nz = 201 + dx = 1e3 # m + dy = 1e3 # m + + x0 = 100e3 + y0 = 100e3 + grid_x, grid_y = torch.meshgrid(torch.linspace(0, 200e3, 201), torch.linspace(0, 200e3, 201)) + # 100K + theta = torch.linspace(0, 1, nz + 2).view(nz + 2, 1, 1) * torch.ones((1, ny, nx)) * 600. + 300 + theta += torch.linspace(1, 0, nz + 2).view(nz + 2, 1, 1) * \ + -100. * torch.exp(-0.5 * ((grid_x - x0)**2 + (grid_y - y0)**2) / 400e6).view(1, ny, nx) + # u0 = torch.ones((nz, ny, nx + 1)).cuda() + wrf = WRF(dx, dy, nx, ny, nz, theta.cuda()) + + import matplotlib.pyplot as plt + import numpy as np + + while True: + plt.cla() + cf = plt.contourf(wrf.Theta[:, 100, :].cpu().numpy(), levels=50, cmap='jet') + cb = plt.colorbar(cf) + plt.savefig('res.jpeg', dpi=300) + plt.clf() + input('stop') + + for i in range(1): + wrf(0.1) From 249cc7cb0b5792d0d676f2f69c1371eaaec055d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 24 Feb 2022 00:31:45 +0800 Subject: [PATCH 0597/1892] add sampler for spatial + temporal --- examples/mlp/policy/st_search.py | 119 ++++++++++++++++++++++++------- 1 file changed, 95 insertions(+), 24 deletions(-) diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index e5fa671e..ba5d50d9 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -1,4 +1,6 @@ import copy +from os import stat +import time from typing import List, Tuple, Dict from cube.graph.graph import IRGraph, IRFwOperation from cube.graph.operator.operator import IRBpOperation, IRDataOperation @@ -40,14 +42,27 @@ def map2time(node: IRCell): return span -class TSampler: +class Sampler: """ Schedule sampler """ @staticmethod - def topo_sequence_batch(nodes: List[IRCell], bs=1): + def sample(graph: IRGraph, n_microbatch: int, n_stage: int, n_worker: int, n_sample_per_worker: int): + # spatial assignment + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + assert n_microbatch * n_stage == len(fnodes), f"{n_microbatch * n_stage} != {len(fnodes)}" + for placement in Sampler.spatial(n_microbatch, n_stage, n_stage): + assert len(placement) == len(fnodes) + print(placement) + for fnode, devid in zip(fnodes, placement): + graph.assign(fnode, devid) + for seqs in Sampler.btemporal(graph.nodes(), bs=n_worker * n_sample_per_worker): + yield seqs + + @staticmethod + def btemporal(nodes: List[IRCell], bs=1): seqs = list() - for idx, seq in enumerate(TSampler.topo_sequence(nodes)): + for idx, seq in enumerate(Sampler.temporal(nodes)): seqs.append(seq) if len(seqs) % bs == 0: print(f'dispatch {len(seqs)} seq...') @@ -57,20 +72,20 @@ def topo_sequence_batch(nodes: List[IRCell], bs=1): yield seqs @staticmethod - def topo_sequence(nodes: List[IRCell], seq = None): + def temporal(nodes: List[IRCell], seq = None): if seq is None: seq = list() if len(nodes) == 0: yield seq # initial entry - entry_nodes = TSampler.ready_emit_set(remain=nodes, seq=seq) + entry_nodes = Sampler.ready_emit_set(remain=nodes, seq=seq) if len(entry_nodes) == 0: return None for node in entry_nodes: seq = seq + [node] nid = nodes.index(node) sub_nodes = nodes[:nid] + nodes[nid+1:] - for res in TSampler.topo_sequence(sub_nodes, seq): + for res in Sampler.temporal(sub_nodes, seq): if res is None: continue yield res @@ -97,6 +112,52 @@ def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): ready.append(node) return ready + @staticmethod + def spatial(num_microbatch: int, num_stage: int, num_device: int, placement = None): + # each device pick num_microbatch * num_stage // num_device blocks + per_device_nblocks = num_microbatch * num_stage // num_device + # placement each stage placement + placement = placement if placement is not None else [] + + if len(placement) == num_microbatch * num_stage: + bucket_min = [num_microbatch * num_stage] * num_device + for nid, devid in enumerate(placement): + bucket_min[devid] = min(bucket_min[devid], nid) + check = [bucket_min[idx + 1] - bucket_min[idx] for idx in range(num_device - 1)] + if min(check) < 0: + yield None + else: + yield placement + else: + # require strict increasing array [min(bucket) for bucket in buckets] + # bucket_min = list(range(num_microbatch * num_stage, num_microbatch * num_stage + num_device + 1)) + bucket_cnt = [0] * num_device + for nid, devid in enumerate(placement): + # bucket_min[devid] = min(nid, bucket_min[devid]) if bucket_min[devid] is not None else nid + bucket_cnt[devid] += 1 + for devid in range(num_device): + if bucket_cnt[devid] < per_device_nblocks: + placement = placement + [devid] + for seq in Sampler.spatial(num_microbatch, num_stage, num_device, placement): + if seq is None: + continue + yield seq + placement = placement[:-1] + # if bucket_cnt[devid] == per_device_nblocks: + # continue + # # try to place on devid + # new_min = min(bucket_min[devid], len(placement)) + # if bucket_min[devid + 1] < new_min: + # continue + # placement.append(devid) + # print(placement) + # input(">>>1 ") + # if len(placement) == num_microbatch * num_stage: + # yield placement + # for seq in Sampler.spatial(num_microbatch, num_stage, num_device, placement): + # yield seq + # placement = placement[:-1] + class Searcher: @@ -118,8 +179,10 @@ def run(seqs: List[List[IRCell]]): def PAS(graph: IRGraph, resource): - num_microbatch = 4 - num_stages = 4 + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + num_microbatch = len(fnodes) + num_stages = len(fnodes) + print(f'num-microbatch: {num_microbatch}, num-stages: {num_stages}') # ============================ micro-batch / stage split ============================ fstages = [list() for _ in range(num_microbatch * num_stages)] @@ -130,7 +193,6 @@ def b(micro_batch_id: int, stage_id: int): bstage = [fnode.mirror for fnode in fstage][::-1] return bstage - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] def stage_division(fnodes: List[IRCell], node: IRCell, num_stages: int) -> int: """Determine stage division """ @@ -147,22 +209,22 @@ def stage_division(fnodes: List[IRCell], node: IRCell, num_stages: int) -> int: for mid, sub_node in enumerate(sub_nodes): graph.assign(sub_node, stage) f(mid, stage).append(sub_node) - # data operator - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - graph.assign(node, 0) # ============================ micro-batch / stage split ============================ # pruning #2: symmetric microbatches, make micro-batch id smaller happen earlier - for sid in range(num_stages): - fops = list() - bops = list() - for mid in range(num_microbatch): - fops += f(mid, sid) - bops += b(mid, sid) - assert graph.add_schedule(fops) - assert graph.add_schedule(bops) + # for sid in range(num_stages): + # fops = list() + # bops = list() + # for mid in range(num_microbatch): + # fops += f(mid, sid) + # bops += b(mid, sid) + # assert graph.add_schedule(fops) + # assert graph.add_schedule(bops) + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + graph = IRGraph([], [], [], 'search') + graph._nodes = fnodes + [fnode.mirror for fnode in fnodes[::-1]] + graph.reset_dependency() Estimator.taging(graph) # memory (int) -> (time, seq) @@ -172,7 +234,8 @@ def stage_division(fnodes: List[IRCell], node: IRCell, num_stages: int) -> int: nproc, worker_samples = 32, 512 pool = Pool(processes=nproc) _graph = IRGraph([], [], [], 'search') - for idx, seqs in enumerate(TSampler.topo_sequence_batch(graph.nodes(), bs=nproc * worker_samples)): + for idx, seqs in enumerate(Sampler.sample(graph, num_microbatch, num_stages, nproc, worker_samples)): + tic = time.time() results = list() for wid in range(nproc): start = worker_samples* wid @@ -194,9 +257,17 @@ def stage_division(fnodes: List[IRCell], node: IRCell, num_stages: int) -> int: _graph._nodes = seq execplan = ExectuionPlan(_graph) execplan.analyze(map2time=Estimator.map2time, outfile=f'plan.mem{mem}.png') + toc = time.time() + throughput = round(len(seqs) / (toc - tic), 2) if (idx + 1) % 1 == 0: - print(f'progress: searched {(idx) * nproc * worker_samples + len(seqs)} K sequences') + print(f'progress: searched {(idx) * nproc * worker_samples + len(seqs)} sequences, throughput: {throughput} seqs/s') if len(seqs) != worker_samples: num = idx * nproc * worker_samples + len(seqs) - print(f'done search on {num} K sequences') + print(f'done search on {num} sequences') assert False + + +if __name__ == '__main__': + for idx, placement in enumerate(Sampler.spatial(3, 3, 3)): + print(placement) + print(f'total {idx + 1} seqs') From f07301048ee3f2757c8e1271e3682bfeac19ff5b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 24 Feb 2022 18:41:47 +0800 Subject: [PATCH 0598/1892] add sampler module --- cube/search/sampler.py | 305 +++++++++++++++++++++++++++++++ examples/mlp/policy/st_search.py | 298 ++++++------------------------ 2 files changed, 360 insertions(+), 243 deletions(-) create mode 100644 cube/search/sampler.py diff --git a/cube/search/sampler.py b/cube/search/sampler.py new file mode 100644 index 00000000..e468fad8 --- /dev/null +++ b/cube/search/sampler.py @@ -0,0 +1,305 @@ +""" +Micro-batch sampler for scheduling search +""" +from typing import Callable, Dict, List, Tuple +from cube.graph.graph import IRGraph, IRFwOperation +from cube.graph.operator.operator import IRBpOperation +from cube.ir.cten import IRCell +from cube.execplan import ExectuionPlan + +from multiprocessing import Pool +import numpy as np +import time, copy, math + + +class Estimator: + """ + A node cost is represented as (mem_weight, mem_activation, exec_time) + """ + @staticmethod + def taging(graph: IRGraph): + for node in graph.nodes(): + # tag: (mem_weight, mem_activation, span) + if isinstance(node, IRFwOperation): + node.cost = (0, 1, 1) + elif isinstance(node, IRBpOperation): + node.cost = (0, -1, 2) + else: + node.cost = (0, 0, 0) + + @staticmethod + def map2mem(node: IRCell): + if node.cost is not None: + mem_w, mem_a, span = node.cost + else: + mem_w, mem_a, span = 0, 0, 0 + return mem_w + mem_a + + @staticmethod + def map2time(node: IRCell): + if node.cost is not None: + mem_w, mem_a, span = node.cost + else: + mem_w, mem_a, span = 0, 0, 0 + return span + + +class Sampler: + """ + Schedule sampler + """ + @staticmethod + def sample(micro_seqs: List[List[IRCell]], n_microbatch: int, n_stage: int, n_device: int, + ssampler: Callable, tsampler: Callable): + assert len(micro_seqs) == n_microbatch + for seq in micro_seqs: + assert len(seq) // 2 == n_stage + graph = IRGraph([], [], [], 'search') + flatten_nodes = list() + for seq in micro_seqs: + flatten_nodes += seq + graph._nodes = flatten_nodes + for placements in ssampler(n_microbatch, n_stage, n_device): + print('seraching placement:\n', placements) + # assign to device + for mid in range(n_microbatch): + for devid, fnode in zip(placements[mid], micro_seqs[mid]): + graph.assign(fnode, devid) + for seqs in tsampler(graph.nodes()): + yield seqs + + +class TemporalSampler: + """ + Temporal sampler takes nodes (List[IRCell]) as input + """ + + @staticmethod + def btemporal(nodes: List[IRCell], bs=1): + seqs = list() + for idx, seq in enumerate(TemporalSampler.temporal(nodes)): + seqs.append(seq) + if len(seqs) % bs == 0: + print(f'dispatch {len(seqs)} seq...') + yield seqs + seqs = list() + if len(seqs) > 0: + yield seqs + + @staticmethod + def temporal(nodes: List[IRCell], seq = None): + if seq is None: + seq = list() + if len(nodes) == 0: + yield seq + # initial entry + entry_nodes = TemporalSampler.ready_emit_set(remain=nodes, seq=seq) + if len(entry_nodes) == 0: + return None + for node in entry_nodes: + seq = seq + [node] + nid = nodes.index(node) + sub_nodes = nodes[:nid] + nodes[nid+1:] + for res in TemporalSampler.temporal(sub_nodes, seq): + if res is None: + continue + yield res + seq = seq[:-1] + + @staticmethod + def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): + """ + Get ready-to-emit node list from remain node set + """ + ready = list() + for node in remain: + satisfy = True + for pre in node.predecessors(): + if pre not in seq: + satisfy = False + break + if satisfy: + if len(seq) > 0 and len(seq[-1].device) != 0 and len(node.device) != 0: + # pruning #1: filter out equal sequences + if seq[-1] not in node.predecessors(): + if node.device[0] < seq[-1].device[0]: + continue + ready.append(node) + return ready + + +class SpatialSampler: + """ + Spatial sampler takes (n_microbatch, n_stage, n_device) as input + """ + + @staticmethod + def full(n_microbatch: int, n_stage: int, n_device: int, placement = None): + # each device pick n_microbatch * n_stage // n_device blocks + per_device_nblocks = n_microbatch * n_stage // n_device + # placement each stage placement + placement = placement if placement is not None else [] + + if len(placement) == n_microbatch * n_stage: + bucket_min = [n_microbatch * n_stage] * n_device + for nid, devid in enumerate(placement): + bucket_min[devid] = min(bucket_min[devid], nid) + check = [bucket_min[idx + 1] - bucket_min[idx] for idx in range(n_device - 1)] + if min(check) < 0: + yield None + else: + yield placement + else: + # require strict increasing array [min(bucket) for bucket in buckets] + # bucket_min = list(range(n_microbatch * n_stage, n_microbatch * n_stage + n_device + 1)) + bucket_cnt = [0] * n_device + for nid, devid in enumerate(placement): + # bucket_min[devid] = min(nid, bucket_min[devid]) if bucket_min[devid] is not None else nid + bucket_cnt[devid] += 1 + for devid in range(n_device): + if bucket_cnt[devid] < per_device_nblocks: + placement = placement + [devid] + for seq in SpatialSampler.full(n_microbatch, n_stage, n_device, placement): + if seq is None: + continue + yield seq + placement = placement[:-1] + + @staticmethod + def same(n_microbatch: int, n_stage: int, n_device: int, wlimits: int): + """ + Same spatial placement for each micro-batch + """ + placements = [] + for _ in range(n_microbatch): + placement = [sid % n_device for sid in range(n_stage)] + placements.append(placement) + yield placements + + @staticmethod + def othogonal(n_microbatch: int, n_stage: int, n_device: int, + wlimits: int, balance = True, placements = None): + """ + Find most othogonal plans given weight_limits + + Yield: + List[microbatch][stage] = device (int) + """ + if balance: + nstages_per_dev = n_microbatch * n_stage // n_device + else: + nstages_per_dev = n_microbatch * n_stage + # wlimits = wlimits if wlimits < n_stage else n_stage + # placements = [] if placements is None else placements + wstatus = [set() for _ in range(n_device)] + bstatus = [0] * n_device + start_slots = np.array([n_stage] * n_device, dtype=int) + # if len(placements) == n_microbatch: + # yield placements + # else: + # for placement in placements: + # for sid, devid in enumerate(placement): + # wstatus[devid].add(sid) + # start_slots[devid] = min(sid, start_slots[devid]) + # bstatus[devid] += 1 + placements = [] + for _ in range(n_microbatch): + placement = list() + for sid in range(n_stage): + # get last starting device + for devid in np.argsort(start_slots)[::-1]: + if bstatus[devid] == nstages_per_dev: + continue + # try place + if sid not in wstatus[devid] and len(wstatus[devid]) == wlimits: + continue + placement = placement + [devid] + wstatus[devid].add(sid) + bstatus[devid] += 1 + start_slots[devid] = min(sid, start_slots[devid]) + break + if len(placement) != n_stage: + raise RuntimeError("Cannot find othogonal plans") + placements = placements + [placement] + # for seq in SpatialSampler.othogonal(n_microbatch, n_stage, n_device, wlimits, placements): + # yield seq + # placements = placements[:-1] + yield placements + + @staticmethod + def microbatch_placement(n_stage: int, n_device: int, + wlimits: int, placement = None, wstatus = None): + """ + Find microbatch placement + Yield: + List[stage] = device[int] + """ + placement = [] if placement is None else placement + wstatus = [0] * n_device if wstatus is None else wstatus + if len(placement) == n_stage: + yield placement + else: + for devid in range(n_device): + if wstatus[devid] == wlimits: + continue + placement = placement + [devid] + wstatus[devid] += 1 + for seq in SpatialSampler.microbatch_placement(n_stage, n_device, wlimits, placement, wstatus): + yield seq + wstatus[devid] -= 1 + placement = placement[:-1] + + +class Searcher: + + pool = Pool(processes=32) + + @staticmethod + def search(seqs: List[List[IRCell]], bucket: Dict, n_worker: int = 1) -> Dict[int, Tuple[int, List]]: + pool = Pool(processes=32) + # memory (int) -> (time, seq) + tic = time.time() + per_worker_seqs = int(math.ceil(len(seqs) / n_worker)) + worker_buckets = list() + for wid in range(n_worker): + start = wid * per_worker_seqs + stop = (wid + 1) * per_worker_seqs + worker_seqs = seqs[start:stop] + worker_buckets.append(pool.apply_async(Searcher._run, (worker_seqs,))) + worker_buckets: List[Dict] = map(lambda buck: buck.get(), worker_buckets) + # merge results + for worker_bucket in worker_buckets: + for mem, (span, seq) in worker_bucket.items(): + if mem in bucket and bucket[mem][0] < span: + continue + print(f'find better plan at mem budget {mem}: span: {span}') + bucket[mem] = (span, seq) + toc = time.time() + throughput = round(len(seqs) / (toc - tic), 2) + print(f'searched {len(seqs)} sequences... throughput: {throughput} seqs/s') + pool.close() + pool.join() + + @staticmethod + def _run(seqs: List[List[IRCell]]) -> Dict[int, Tuple[int, List]]: + """ + Worker run + """ + bucket = dict() + graph = IRGraph([], [], [], 'search') + for seq in seqs: + graph._nodes = seq + execplan = ExectuionPlan(graph) + span, mem = execplan.analyze(map2time=Estimator.map2time, map2mem=Estimator.map2mem) + if mem not in bucket: + bucket[mem] = (span, copy.copy(seq)) + elif bucket[mem][0] > span: + bucket[mem] = (span, copy.copy(seq)) + return bucket + + +if __name__ == '__main__': + + for idx, placement in enumerate(SpatialSampler.othogonal(n_microbatch=4, n_stage=4, n_device=4, wlimits=2)): + print(placement) + print(f'total {idx+1} placements') diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index ba5d50d9..b9b66fb7 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -1,269 +1,81 @@ -import copy -from os import stat -import time -from typing import List, Tuple, Dict +from functools import partial +from typing import List from cube.graph.graph import IRGraph, IRFwOperation -from cube.graph.operator.operator import IRBpOperation, IRDataOperation from cube.ir.cten import IRCell from cube.execplan import ExectuionPlan -from multiprocessing import Pool +from cube.search.sampler import Estimator, Sampler, SpatialSampler, TemporalSampler, Searcher -class Estimator: - """ - A node tag is represented as (mem_weight, mem_activation, exec_time) - """ - @staticmethod - def taging(graph: IRGraph): - for node in graph.nodes(): - # tag: (mem_weight, mem_activation, span) - if isinstance(node, IRFwOperation): - node.cost = (0, 1, 1) - elif isinstance(node, IRBpOperation): - node.cost = (0, -1, 2) - else: - node.cost = (0, 0, 0) - - @staticmethod - def map2mem(node: IRCell): - if node.cost is not None: - mem_w, mem_a, span = node.cost - else: - mem_w, mem_a, span = 0, 0, 0 - return mem_w + mem_a +class MicroBatchView: @staticmethod - def map2time(node: IRCell): - if node.cost is not None: - mem_w, mem_a, span = node.cost - else: - mem_w, mem_a, span = 0, 0, 0 - return span - - -class Sampler: - """ - Schedule sampler - """ - @staticmethod - def sample(graph: IRGraph, n_microbatch: int, n_stage: int, n_worker: int, n_sample_per_worker: int): - # spatial assignment - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - assert n_microbatch * n_stage == len(fnodes), f"{n_microbatch * n_stage} != {len(fnodes)}" - for placement in Sampler.spatial(n_microbatch, n_stage, n_stage): - assert len(placement) == len(fnodes) - print(placement) - for fnode, devid in zip(fnodes, placement): - graph.assign(fnode, devid) - for seqs in Sampler.btemporal(graph.nodes(), bs=n_worker * n_sample_per_worker): - yield seqs - - @staticmethod - def btemporal(nodes: List[IRCell], bs=1): - seqs = list() - for idx, seq in enumerate(Sampler.temporal(nodes)): - seqs.append(seq) - if len(seqs) % bs == 0: - print(f'dispatch {len(seqs)} seq...') - yield seqs - seqs = list() - if len(seqs) > 0: - yield seqs - - @staticmethod - def temporal(nodes: List[IRCell], seq = None): - if seq is None: - seq = list() - if len(nodes) == 0: - yield seq - # initial entry - entry_nodes = Sampler.ready_emit_set(remain=nodes, seq=seq) - if len(entry_nodes) == 0: - return None - for node in entry_nodes: - seq = seq + [node] - nid = nodes.index(node) - sub_nodes = nodes[:nid] + nodes[nid+1:] - for res in Sampler.temporal(sub_nodes, seq): - if res is None: - continue - yield res - seq = seq[:-1] + def node2stage(node: IRCell, fnodes: List[IRCell], n_stage: int): + num_fnodes = len(fnodes) + idx = fnodes.index(node) + stage = min(idx // (num_fnodes // n_stage), n_stage - 1) + return stage @staticmethod - def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): + def split(graph: IRGraph, n_microbatch: int) -> List[IRCell]: """ - Get ready-to-emit node list from remain node set + Split graph into micro-batch view """ - ready = list() - for node in remain: - satisfy = True - for pre in node.predecessors(): - if pre not in seq: - satisfy = False - break - if satisfy: - if len(seq) > 0 and len(seq[-1].device) != 0 and len(node.device) != 0: - # pruning #1: filter out equal sequences - if seq[-1] not in node.predecessors(): - if node.device[0] < seq[-1].device[0]: - continue - ready.append(node) - return ready - - @staticmethod - def spatial(num_microbatch: int, num_stage: int, num_device: int, placement = None): - # each device pick num_microbatch * num_stage // num_device blocks - per_device_nblocks = num_microbatch * num_stage // num_device - # placement each stage placement - placement = placement if placement is not None else [] - - if len(placement) == num_microbatch * num_stage: - bucket_min = [num_microbatch * num_stage] * num_device - for nid, devid in enumerate(placement): - bucket_min[devid] = min(bucket_min[devid], nid) - check = [bucket_min[idx + 1] - bucket_min[idx] for idx in range(num_device - 1)] - if min(check) < 0: - yield None - else: - yield placement - else: - # require strict increasing array [min(bucket) for bucket in buckets] - # bucket_min = list(range(num_microbatch * num_stage, num_microbatch * num_stage + num_device + 1)) - bucket_cnt = [0] * num_device - for nid, devid in enumerate(placement): - # bucket_min[devid] = min(nid, bucket_min[devid]) if bucket_min[devid] is not None else nid - bucket_cnt[devid] += 1 - for devid in range(num_device): - if bucket_cnt[devid] < per_device_nblocks: - placement = placement + [devid] - for seq in Sampler.spatial(num_microbatch, num_stage, num_device, placement): - if seq is None: - continue - yield seq - placement = placement[:-1] - # if bucket_cnt[devid] == per_device_nblocks: - # continue - # # try to place on devid - # new_min = min(bucket_min[devid], len(placement)) - # if bucket_min[devid + 1] < new_min: - # continue - # placement.append(devid) - # print(placement) - # input(">>>1 ") - # if len(placement) == num_microbatch * num_stage: - # yield placement - # for seq in Sampler.spatial(num_microbatch, num_stage, num_device, placement): - # yield seq - # placement = placement[:-1] - - -class Searcher: + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + micro_seqs = [list() for _ in range(n_microbatch)] + for node in fnodes: + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, config=dict(idx=0, dim=0, num=n_microbatch)) + for mid, sub_node in enumerate(sub_nodes): + micro_seqs[mid].append(sub_node) + for mid in range(n_microbatch): + micro_seqs[mid] = micro_seqs[mid] + [n.mirror for n in micro_seqs[mid][::-1]] + return micro_seqs @staticmethod - def run(seqs: List[List[IRCell]]): - # mem -> (time, seq) - bucket = dict() - graph = IRGraph([], [], [], 'search') - for seq in seqs: - graph._nodes = seq - # graph.reset_dependency() # this needs as in other process dependency will break - execplan = ExectuionPlan(graph) - span, mem = execplan.analyze(map2time=Estimator.map2time, map2mem=Estimator.map2mem) - if mem not in bucket: - bucket[mem] = (span, copy.copy(seq)) - elif bucket[mem][0] > span: - bucket[mem] = (span, copy.copy(seq)) - return bucket + def flatten(micro_seqs: List[List[IRCell]]): + flatten_nodes = list() + for seq in micro_seqs: + flatten_nodes += seq + return flatten_nodes def PAS(graph: IRGraph, resource): - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - num_microbatch = len(fnodes) - num_stages = len(fnodes) - print(f'num-microbatch: {num_microbatch}, num-stages: {num_stages}') - # ============================ micro-batch / stage split ============================ - fstages = [list() for _ in range(num_microbatch * num_stages)] - def f(micro_batch_id: int, stage_id: int): - return fstages[micro_batch_id * num_stages + stage_id] - def b(micro_batch_id: int, stage_id: int): - fstage = f(micro_batch_id, stage_id) - bstage = [fnode.mirror for fnode in fstage][::-1] - return bstage + # n_microbatch, n_stage, n_device + M, S, D = 4, 4, 4 - def stage_division(fnodes: List[IRCell], node: IRCell, num_stages: int) -> int: - """Determine stage division - """ - num_fnodes = len(fnodes) - idx = fnodes.index(node) - stage = min(idx // (num_fnodes // num_stages), num_stages - 1) - return stage - # split to micro batches - for node in fnodes: - stage = stage_division(fnodes, node, num_stages=num_stages) - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=0, num=num_microbatch)) - for mid, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, stage) - f(mid, stage).append(sub_node) - # ============================ micro-batch / stage split ============================ + # memory limits + wlimits = 1 + alimits = 4 - # pruning #2: symmetric microbatches, make micro-batch id smaller happen earlier - # for sid in range(num_stages): - # fops = list() - # bops = list() - # for mid in range(num_microbatch): - # fops += f(mid, sid) - # bops += b(mid, sid) - # assert graph.add_schedule(fops) - # assert graph.add_schedule(bops) + micro_seqs = MicroBatchView.split(graph, M) + assert len(micro_seqs) == M and len(micro_seqs[0]) // 2 == S + sgraph = IRGraph(MicroBatchView.flatten(micro_seqs), [], [], 'search') + Estimator.taging(sgraph) - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - graph = IRGraph([], [], [], 'search') - graph._nodes = fnodes + [fnode.mirror for fnode in fnodes[::-1]] - graph.reset_dependency() - Estimator.taging(graph) + # pruning + for sid in range(S): + # forward intra-device dependency + sgraph.add_schedule([micro_seqs[mid][sid] for mid in range(M)]) + # backward intra-device dependency + sgraph.add_schedule([micro_seqs[mid][sid+S] for mid in range(M)]) - # memory (int) -> (time, seq) - bucket = dict() - print('start sorting...') + n_worker, seq_per_worker = 32, 512 + tsampler = partial(TemporalSampler.btemporal, bs=n_worker*seq_per_worker) + ssampler = partial(SpatialSampler.same, wlimits=wlimits) - nproc, worker_samples = 32, 512 - pool = Pool(processes=nproc) - _graph = IRGraph([], [], [], 'search') - for idx, seqs in enumerate(Sampler.sample(graph, num_microbatch, num_stages, nproc, worker_samples)): - tic = time.time() - results = list() - for wid in range(nproc): - start = worker_samples* wid - stop = start + worker_samples - worker_seqs = seqs[start:stop] - results.append(pool.apply_async(Searcher.run, (worker_seqs,))) - results: List[Dict[int, Tuple[int, List]]] = map(lambda res: res.get(), results) - # merge results - for worker_bucket in results: - for mem, (span, seq) in worker_bucket.items(): - better = False - if mem not in bucket: - better = True - elif bucket[mem][0] > span: - better = True - if better: - print(f'find better plan at mem budget {mem}: span: {span}') - bucket[mem] = (span, seq) - _graph._nodes = seq - execplan = ExectuionPlan(_graph) - execplan.analyze(map2time=Estimator.map2time, outfile=f'plan.mem{mem}.png') - toc = time.time() - throughput = round(len(seqs) / (toc - tic), 2) - if (idx + 1) % 1 == 0: - print(f'progress: searched {(idx) * nproc * worker_samples + len(seqs)} sequences, throughput: {throughput} seqs/s') - if len(seqs) != worker_samples: - num = idx * nproc * worker_samples + len(seqs) - print(f'done search on {num} sequences') + bucket = dict() + cnt = 0 + for seqs in Sampler.sample(micro_seqs, M, S, D, ssampler, tsampler): + Searcher.search(seqs, bucket, n_worker=n_worker) + for mem, (span, seq) in bucket.items(): + sgraph._nodes = seq + execplan = ExectuionPlan(sgraph) + execplan.analyze(map2time=Estimator.map2time, outfile=f'plan.mem{mem}.png') + cnt += len(seqs) + print(f'done search on {cnt} sequences') assert False From ec06a8436036bf041f62edd29b987230cfee15d6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 25 Feb 2022 11:40:54 +0800 Subject: [PATCH 0599/1892] othogonal plan search --- cube/search/sampler.py | 133 ++++++++++++++++++++----------- examples/mlp/policy/st_search.py | 11 +-- 2 files changed, 89 insertions(+), 55 deletions(-) diff --git a/cube/search/sampler.py b/cube/search/sampler.py index e468fad8..053e2efc 100644 --- a/cube/search/sampler.py +++ b/cube/search/sampler.py @@ -58,14 +58,35 @@ def sample(micro_seqs: List[List[IRCell]], n_microbatch: int, n_stage: int, n_de flatten_nodes = list() for seq in micro_seqs: flatten_nodes += seq - graph._nodes = flatten_nodes - for placements in ssampler(n_microbatch, n_stage, n_device): + graph = IRGraph(flatten_nodes, [], [], 'search') + # graph._nodes = flatten_nodes + for sidx, placements in enumerate(ssampler(n_microbatch, n_stage, n_device)): print('seraching placement:\n', placements) # assign to device for mid in range(n_microbatch): for devid, fnode in zip(placements[mid], micro_seqs[mid]): graph.assign(fnode, devid) + + # pruning: add dependecies for micro-batches with same device assignment + graph.reset_dependency() + same_microbatch = dict() + for mid, placement in enumerate(placements): + placement = tuple(placement) + if placement not in same_microbatch: + same_microbatch[placement] = list() + same_microbatch[placement].append(mid) + for placement, mids in same_microbatch.items(): + if len(mids) > 1: + print(f'find {mids} microbatch same, add dependency') + for sid in range(len(placement)): + # add forward dependency + graph.add_schedule([micro_seqs[mid][sid] for mid in mids]) + # add backward dependency + graph.add_schedule([micro_seqs[mid][sid+len(placement)] for mid in mids]) + + # search for seqs in tsampler(graph.nodes()): + print(f'searching {len(seqs)} sequences under {sidx}-th placement') yield seqs @@ -77,10 +98,9 @@ class TemporalSampler: @staticmethod def btemporal(nodes: List[IRCell], bs=1): seqs = list() - for idx, seq in enumerate(TemporalSampler.temporal(nodes)): + for seq in TemporalSampler.temporal(nodes): seqs.append(seq) if len(seqs) % bs == 0: - print(f'dispatch {len(seqs)} seq...') yield seqs seqs = list() if len(seqs) > 0: @@ -178,53 +198,74 @@ def same(n_microbatch: int, n_stage: int, n_device: int, wlimits: int): @staticmethod def othogonal(n_microbatch: int, n_stage: int, n_device: int, - wlimits: int, balance = True, placements = None): + wlimits: int, status = None, placements = None): """ - Find most othogonal plans given weight_limits + Find othogonal plans given weight_limits Yield: List[microbatch][stage] = device (int) """ - if balance: - nstages_per_dev = n_microbatch * n_stage // n_device - else: - nstages_per_dev = n_microbatch * n_stage - # wlimits = wlimits if wlimits < n_stage else n_stage - # placements = [] if placements is None else placements - wstatus = [set() for _ in range(n_device)] - bstatus = [0] * n_device - start_slots = np.array([n_stage] * n_device, dtype=int) - # if len(placements) == n_microbatch: - # yield placements - # else: - # for placement in placements: - # for sid, devid in enumerate(placement): - # wstatus[devid].add(sid) - # start_slots[devid] = min(sid, start_slots[devid]) - # bstatus[devid] += 1 - placements = [] - for _ in range(n_microbatch): - placement = list() + # each element denotes number of block assigned + status = np.zeros((n_device, n_stage), dtype=int) if status is None else status + placements = [] if placements is None else placements + # repeat to reduce space + if len(placements) == wlimits: + for idx in range(n_microbatch - wlimits): + placements = placements + [copy.copy(placements[idx % wlimits])] + yield placements + # find othogonal placements + elif len(placements) == 0: + # fix the first one due to symmetric device + placements = placements + [[sid % n_device for sid in range(n_stage)]] for sid in range(n_stage): - # get last starting device - for devid in np.argsort(start_slots)[::-1]: - if bstatus[devid] == nstages_per_dev: - continue - # try place - if sid not in wstatus[devid] and len(wstatus[devid]) == wlimits: - continue - placement = placement + [devid] - wstatus[devid].add(sid) - bstatus[devid] += 1 - start_slots[devid] = min(sid, start_slots[devid]) - break - if len(placement) != n_stage: - raise RuntimeError("Cannot find othogonal plans") - placements = placements + [placement] - # for seq in SpatialSampler.othogonal(n_microbatch, n_stage, n_device, wlimits, placements): - # yield seq - # placements = placements[:-1] - yield placements + status[sid % n_device][sid] += 1 + for seqs in SpatialSampler.othogonal(n_microbatch, n_stage, n_device, + wlimits, status, placements): + yield seqs + else: + for placement in SpatialSampler.microbatch_othogonal(np.copy(status)): + placements = placements + [placement] + for sid, devid in enumerate(placement): + status[devid][sid] += 1 + for seqs in SpatialSampler.othogonal(n_microbatch, n_stage, n_device, + wlimits, status, placements): + yield seqs + for sid, devid in enumerate(placement): + status[devid][sid] -= 1 + placements = placements[:-1] + + @staticmethod + def microbatch_othogonal(status: np.ndarray, placement = None): + """ + status: + 2D array [n_device, n_stage], each element represents + how many stage blocks are assigned. + """ + n_device, n_stage = status.shape + assert n_stage == 4 + placement = [] if placement is None else placement + if len(placement) == n_stage: + # print(placement) + # input('>>>out') + yield placement + else: + sid = len(placement) + allocation = np.sum(status, axis=1) + min_alloc = np.min(allocation) + collision = status[:,sid] + valid = list() + for devid, coll in enumerate(collision): + if coll != 0 or allocation[devid] != min_alloc: + continue + valid.append(devid) + for devid in valid: + placement = placement + [devid] + status[devid][sid] += 1 + for seq in SpatialSampler.microbatch_othogonal(status, placement): + yield seq + status[devid][sid] -= 1 + placement = placement[:-1] + @staticmethod def microbatch_placement(n_stage: int, n_device: int, @@ -270,7 +311,7 @@ def search(seqs: List[List[IRCell]], bucket: Dict, n_worker: int = 1) -> Dict[in # merge results for worker_bucket in worker_buckets: for mem, (span, seq) in worker_bucket.items(): - if mem in bucket and bucket[mem][0] < span: + if mem in bucket and bucket[mem][0] <= span: continue print(f'find better plan at mem budget {mem}: span: {span}') bucket[mem] = (span, seq) diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index b9b66fb7..8bbb34e6 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -47,7 +47,7 @@ def PAS(graph: IRGraph, resource): M, S, D = 4, 4, 4 # memory limits - wlimits = 1 + wlimits = 2 alimits = 4 micro_seqs = MicroBatchView.split(graph, M) @@ -55,16 +55,9 @@ def PAS(graph: IRGraph, resource): sgraph = IRGraph(MicroBatchView.flatten(micro_seqs), [], [], 'search') Estimator.taging(sgraph) - # pruning - for sid in range(S): - # forward intra-device dependency - sgraph.add_schedule([micro_seqs[mid][sid] for mid in range(M)]) - # backward intra-device dependency - sgraph.add_schedule([micro_seqs[mid][sid+S] for mid in range(M)]) - n_worker, seq_per_worker = 32, 512 tsampler = partial(TemporalSampler.btemporal, bs=n_worker*seq_per_worker) - ssampler = partial(SpatialSampler.same, wlimits=wlimits) + ssampler = partial(SpatialSampler.othogonal, wlimits=wlimits) bucket = dict() cnt = 0 From 3627de18c071fb4f31fc71235456c965b6742810 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 27 Feb 2022 14:43:13 +0800 Subject: [PATCH 0600/1892] piper search algo for pipeline algorithm --- cube/search/piper.py | 160 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 cube/search/piper.py diff --git a/cube/search/piper.py b/cube/search/piper.py new file mode 100644 index 00000000..627e0590 --- /dev/null +++ b/cube/search/piper.py @@ -0,0 +1,160 @@ +""" +Abstraction layer for microb-batch execution plan merge. +""" + +from typing import Any, Dict, List, Tuple +import numpy as np + + +class MicroPlan: + + def __init__(self, plan: np.ndarray, name: str = None, summation=None): + """ + positions: + List of [spatial, temporal] slots to anchor the action + """ + assert len(plan.shape) == 2 + self.name = name + self.plan = plan + self.summation = [self] if summation is None else summation + + @property + def ndevs(self): + return self.plan.shape[0] + + @property + def nsteps(self): + return self.plan.shape[1] + + def valid(self) -> bool: + """ + Check runnability + """ + return np.max(self.plan) <= 1 + + def __add__(self, other): + if not isinstance(other, MicroPlan): + raise TypeError("Expect MicroPlan") + lhs, rhs = self, other + ndevs = max(lhs.ndevs, rhs.ndevs) + nsteps = max(lhs.nsteps, rhs.nsteps) + lhs_plan = np.pad( + lhs.plan, ((0, ndevs-lhs.ndevs),(0, nsteps-lhs.nsteps)) + ) + rhs_plan = np.pad( + rhs.plan, ((0, ndevs-rhs.ndevs), (0, nsteps-rhs.nsteps)) + ) + plan = lhs_plan + rhs_plan + if np.max(plan) <= 1: + return (True, MicroPlan(plan, summation=lhs.summation+rhs.summation)) + else: + # find conflict + sidx, tidx = (plan > 1).nonzero() + return (False, (sidx, tidx)) + + def shift(self, position: Tuple[int, int], distance: int) -> bool: + """ + shift the task at position to later (+) or previous (-) steps + + MicroPlan requires there is no more than one task on same temporal slot + + Args: + position: tuple of (spatial_idx (row), step_idx (column)) + """ + s, t = position + if self.plan[s][t] != 1: + raise KeyError("No task is on this possition") + if t + distance < 0: + return False + if distance == 0: + return True + if distance > 0: + slots = np.zeros((self.ndevs, distance), dtype=int) + self.plan = np.insert(self.plan, slice(t, t+distance), slots, axis=1) + return True + if distance < 0: + slots = self.plan[:,t+distance:t] + if np.max(slots) != 0: + return False + self.plan = np.delete(self.plan, slice(t+distance, t), axis=1) + return True + return False + + def __repr__(self): + return repr(self.plan) + + +def create_microbatch(n_stage: int, n_dev: int, placement: List[int], name=None): + plan = np.zeros((n_dev, n_stage * 2), dtype=int) + for sid, devid in enumerate(placement): + # forward + plan[devid, sid] += 1 + # backward + plan[devid, 2 * n_stage - 1 - sid] += 1 + return MicroPlan(plan, name) + + +def get_conflict(micros: List[MicroPlan], step: int): + """ + Get conflicting postition at temporal step T + """ + plans = [] + for micro in micros: + if step >= micro.nsteps: + plans.append(np.zeros((micro.ndevs, 1), dtype=int)) + else: + plans.append(micro.plan[:,step:step+1]) + # [ndev, nmicros] + plans = np.hstack(tuple(plans)) + # devid [int] -> (micro_id, step) + conflicts = dict() + # conflict device ids + devids = np.where(np.sum(plans, axis=1) > 1)[0] + for devid in devids: + positions = plans[devid].nonzero()[0] + positions = [(mid, step) for mid in positions] + conflicts[devid] = positions + return conflicts + + +def solve(micros: List[MicroPlan], conflicts: Dict[int, Tuple[int, int]]): + # always address first conflicts + print(f'solve conflicts: {conflicts}') + devid = list(conflicts.keys())[0] + mid, tid = conflicts[devid][0] + print(f'select device: {devid}, micro id: {mid}, step: {tid} to solve') + micros[mid].shift((devid, tid), 1) + print(f'shift results: microbatch-{mid}') + print(micros[mid]) + return (mid, devid, tid) + + +def search(n_microbatch: int, n_stage: int, n_dev: int): + placement = [sid % n_dev for sid in range(n_stage)] + micros = [create_microbatch(n_stage, n_dev, placement, name=mid) for mid in range(n_microbatch)] + tidx = 0 + #TODO: justify: why firstly sovle early-step conflicts + while tidx < max([micro.nsteps for micro in micros]): + while True: + # conflict point Dict[device_id, (mid, step_id)] + conflicts = get_conflict(micros, step=tidx) + if len(conflicts) > 0: + # solve conflicts + #TODO: justify: whom: which microbatch should apply shift + #TODO: justify: how: shift distance + solve(micros, conflicts) + else: + tidx += 1 + break + span = max([micro.nsteps for micro in micros]) + print(f'find plan: {span} steps') + for mid, micro in enumerate(micros): + print(f'microbatch-{mid}:') + print(micro) + + +if __name__ == '__main__': + num_microbatch = 4 + num_stage = 4 + num_device = 4 + search(num_microbatch, num_stage, num_device) From dad3a61ac603afa761bba35535ffa568e6054b95 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 27 Feb 2022 19:17:58 +0800 Subject: [PATCH 0601/1892] using stall primitive (the only one needed) --- cube/search/piper.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/cube/search/piper.py b/cube/search/piper.py index 627e0590..919c4e14 100644 --- a/cube/search/piper.py +++ b/cube/search/piper.py @@ -51,6 +51,14 @@ def __add__(self, other): # find conflict sidx, tidx = (plan > 1).nonzero() return (False, (sidx, tidx)) + + def stall(self, step: int): + """ + Primitive: insert a stall at stepline index `step` + """ + slots = np.zeros((self.ndevs, 1), dtype=int) + self.plan = np.insert(self.plan, slice(step, step+1), slots, axis=1) + return True def shift(self, position: Tuple[int, int], distance: int) -> bool: """ @@ -123,7 +131,8 @@ def solve(micros: List[MicroPlan], conflicts: Dict[int, Tuple[int, int]]): devid = list(conflicts.keys())[0] mid, tid = conflicts[devid][0] print(f'select device: {devid}, micro id: {mid}, step: {tid} to solve') - micros[mid].shift((devid, tid), 1) + micros[mid].stall(tid) + # micros[mid].shift((devid, tid), 1) print(f'shift results: microbatch-{mid}') print(micros[mid]) return (mid, devid, tid) From 48f1a2cbd9a216dc51809c2dd88fc61766f9b5c7 Mon Sep 17 00:00:00 2001 From: lynex Date: Tue, 1 Mar 2022 09:26:55 +0800 Subject: [PATCH 0602/1892] wrf example updated --- examples/wrf/wrf.py | 304 +++++++++++++++++++++++++++----------------- 1 file changed, 190 insertions(+), 114 deletions(-) diff --git a/examples/wrf/wrf.py b/examples/wrf/wrf.py index 16099a82..3a6e5f04 100644 --- a/examples/wrf/wrf.py +++ b/examples/wrf/wrf.py @@ -1,116 +1,127 @@ +from typing import List + import torch torch.set_default_tensor_type(torch.DoubleTensor) from torch import nn import torch.nn.functional as F # from linalg import tridiagonal +import cube +from examples.poisson.policy.naive import PAS + + +device = 'cuda' # + + + +def init(nx, ny, nz, dx, dy, dz, theta, PREF, Rd, g, p_t=None, p_s=None, u0=None, v0=None, w0=None, device='cuda'): + # spatial discretization + # dx, dy, dz, nx, ny, nz = dx, dy, 1. / (nz + 1), nx, ny, nz + # agnostic variables + P_t = p_t if p_t else torch.ones((1, ny, nx), device=device) * PREF * 0.0 + P_s = p_s if p_s else torch.ones((1, ny, nx), device=device) * PREF + # pressure (nz, ny, nx) + P = torch.linspace(dz, 1 - dz, nz, device=device).view(nz, 1, 1) * \ + (P_s - P_t).view(1, ny, nx) + P_t + # Alpha (nz, ny, nx) + Alpha = Rd / PREF * theta[1:-1] * (P / PREF) ** (-1 / 1.4) + # prognostic variables + # Mu (nz, ny, nx) + Mu = torch.ones((nz, 1, 1), device=device) * (P_s - P_t).view(1, ny, nx) + # Mu_t = (P_s - P_t).view(1, ny, nx) + # Mu_s = (P_s - P_t).view(1, ny, nx) + # Phi (nz - 1, ny, nx) + Phi = torch.zeros((nz + 1, ny, nx), device=device) + Phi[:-1] = Mu * Alpha * dz + for i in range(nz - 1, -1, -1): + Phi[i] += Phi[i + 1] + Phi_t = Phi[0].view(1, ny, nx) + Phi_s = Phi[-1].view(1, ny, nx) + Phi = Phi[1:-1] + # Theta (nz, ny, nx) + theta_t = theta[0].view(1, ny, nx) + theta_s = theta[-1].view(1, ny, nx) + Theta = theta[1:-1] * Mu + # U (nz, ny, nx + 1) + U = u0 if u0 is not None else torch.zeros((nz, ny, nx + 1), device=device) + # V (nz, ny + 1, nx) + V = v0 if v0 is not None else torch.zeros((nz, ny + 1, nx), device=device) + # W (nz - 1, ny, nx) + W = w0 if w0 is not None else torch.zeros((nz - 1, ny, nx), device=device) + + return U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s class WRF(torch.nn.Module): - r"""WRF Model - - Args: - theta (Tensor): inital potential temperature, (nz + 2, ny, nx), including boundary condition - p_t (Tensor): pressure at model top, (ny, nx), if None all zeros - p_s (Tensor): pressure at surface, (ny, nx), if None all sea level pressure - u0 (Tensor): inital x flow, (nz, ny, nx + 1), if None all zeros, default to None. - v0 (Tensor): inital y flow, (nz, ny + 1, nx), if None all zeros, detault to None. - w0 (Tensor): inital z flow, (nz - 1, ny, nx), if None all zeros, default to None. - """ - - def __init__(self, dx, dy, nx, ny, nz, theta, p_t=None, p_s=None, u0=None, v0=None, w0=None, device='cuda'): + + def __init__(self): super().__init__() - self.device = device - - # constants - self.PREF = 1e5 # reference pressure, usually sea level pressure, Pa - self.Rd = 287 # gas constant for dry air, J/(kg*K) - self.g = 9.81 # the acceleration of gravity, m/s**2 - - # spatial discretization - self.dx, self.dy, self.dz, self.nx, self.ny, self.nz = dx, dy, 1. / (nz + 1), nx, ny, nz - - # agnostic variables - self.P_t = p_t if p_t else torch.ones((1, ny, nx), device=device) * self.PREF * 0.0 - self.P_s = p_s if p_s else torch.ones((1, ny, nx), device=device) * self.PREF - # pressure (nz, ny, nx) - self.P = torch.linspace(self.dz, 1 - self.dz, nz, device=device).view(nz, 1, 1) * \ - (self.P_s - self.P_t).view(1, ny, nx) + self.P_t - # Alpha (nz, ny, nx) - self.Alpha = self.Rd / self.PREF * theta[1:-1] * (self.P / self.PREF)**(-1/1.4) - - # prognostic variables - # Mu (nz, ny, nx) - self.Mu = torch.ones((nz, 1, 1), device=device) * (self.P_s - self.P_t).view(1, ny, nx) - # self.Mu_t = (self.P_s - self.P_t).view(1, ny, nx) - # self.Mu_s = (self.P_s - self.P_t).view(1, ny, nx) - # Phi (nz - 1, ny, nx) - Phi = torch.zeros((nz + 1, ny, nx), device=device) - Phi[:-1] = self.Mu * self.Alpha * self.dz - for i in range(nz - 1, -1, -1): - Phi[i] += Phi[i + 1] - self.Phi_t = Phi[0].view(1, ny, nx) - self.Phi_s = Phi[-1].view(1, ny, nx) - self.Phi = Phi[1:-1] - # Theta (nz, ny, nx) - self.theta_t = theta[0].view(1, ny, nx) - self.theta_s = theta[-1].view(1, ny, nx) - self.Theta = theta[1:-1] * self.Mu - # U (nz, ny, nx + 1) - self.U = u0 if u0 is not None else torch.zeros((nz, ny, nx + 1), device=device) - # V (nz, ny + 1, nx) - self.V = v0 if v0 is not None else torch.zeros((nz, ny + 1, nx), device=device) - # W (nz - 1, ny, nx) - self.W = w0 if w0 is not None else torch.zeros((nz - 1, ny, nx), device=device) - - def RHS(self, U, V, W, Theta, Mu, Phi): + # self.method_name() + + # def forward(self, dt): + # self.U, self.V, self.W, self.Theta, self.Mu, self.Phi = \ + # self.RK3_step(self.U, self.V, self.W, self.Theta, self.Mu, self.Phi, dt) + def forward(self, + U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, + dx, dy, dz, + dt, PREF, Rd, g): + + U, V, W, Theta, Mu, Phi = self.RK3_step(U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) + return U, V, W, Theta, Mu, Phi + + def RHS(self, + U, V, W, Theta, Mu, Phi, + Phi_t, Phi_s, + P_t, P_s, dx, dy, dz, + PREF, Rd, g): + # volecity u = U / self.bar_x(self.pad_x(Mu)) v = V / self.bar_y(self.pad_y(Mu)) w = W / self.bar_z(Mu) - alpha = -self.delta_z(self.pad_z(Phi, self.Phi_t, self.Phi_s)) / Mu - self.Alpha = alpha + alpha = -self.delta_z(self.pad_z(Phi, Phi_t, Phi_s), dz) / Mu + Alpha = alpha theta = Theta / Mu - p = self.PREF * (self.Rd * theta / self.PREF / alpha)**1.4 - omega = -w * self.g / self.bar_z(alpha) / self.bar_z(Mu) + p = PREF * (Rd * theta / PREF / alpha)**1.4 + omega = -w * g / self.bar_z(alpha) / self.bar_z(Mu) Omega = omega * self.bar_z(Mu) - self.Omega = Omega + #Omega = Omega # advection term - R_U = - self.delta_x(self.bar_x(self.pad_x(U)) * self.bar_x(self.pad_x(u))) \ - - self.delta_y(self.bar_x(self.pad_x(V)) * self.bar_y(self.pad_y(u))) \ - - self.delta_z(self.bar_x(self.pad_x(self.pad_z(Omega))) * self.bar_z(self.pad_z(u))) + R_U = - self.delta_x(self.bar_x(self.pad_x(U)) * self.bar_x(self.pad_x(u)), dx) \ + - self.delta_y(self.bar_x(self.pad_x(V)) * self.bar_y(self.pad_y(u)), dy) \ + - self.delta_z(self.bar_x(self.pad_x(self.pad_z(Omega))) * self.bar_z(self.pad_z(u)), dz) - R_V = - self.delta_x(self.bar_y(self.pad_y(U)) * self.bar_x(self.pad_x(v))) \ - - self.delta_y(self.bar_y(self.pad_y(V)) * self.bar_y(self.pad_y(v))) \ - - self.delta_z(self.bar_y(self.pad_y(self.pad_z(Omega))) * self.bar_z(self.pad_z(v))) + R_V = - self.delta_x(self.bar_y(self.pad_y(U)) * self.bar_x(self.pad_x(v)), dx) \ + - self.delta_y(self.bar_y(self.pad_y(V)) * self.bar_y(self.pad_y(v)), dy) \ + - self.delta_z(self.bar_y(self.pad_y(self.pad_z(Omega))) * self.bar_z(self.pad_z(v)), dz) - R_W = - self.delta_x(self.bar_z(U) * self.bar_x(self.pad_x(w))) \ - - self.delta_y(self.bar_z(V) * self.bar_y(self.pad_y(w))) \ - - self.delta_z(self.bar_z(self.pad_z(Omega)) * self.bar_z(self.pad_z(w))) + R_W = - self.delta_x(self.bar_z(U) * self.bar_x(self.pad_x(w)), dx) \ + - self.delta_y(self.bar_z(V) * self.bar_y(self.pad_y(w)), dy) \ + - self.delta_z(self.bar_z(self.pad_z(Omega)) * self.bar_z(self.pad_z(w)), dz) - R_Theta = - self.delta_x(U * self.bar_x(self.pad_x(theta))) \ - - self.delta_y(V * self.bar_y(self.pad_y(theta))) \ - - self.delta_z(self.pad_z(Omega) * self.bar_z(self.pad_z(theta))) + R_Theta = - self.delta_x(U * self.bar_x(self.pad_x(theta)), dx) \ + - self.delta_y(V * self.bar_y(self.pad_y(theta)), dy) \ + - self.delta_z(self.pad_z(Omega) * self.bar_z(self.pad_z(theta)), dz) - R_Phi = - self.bar_z(self.bar_x(u)) * self.delta_x(self.bar_x(self.pad_x(Phi))) \ - - self.bar_z(self.bar_y(v)) * self.delta_y(self.bar_y(self.pad_y(Phi))) \ - - omega * self.delta_z(self.bar_z(self.pad_z(Phi, self.Phi_t, self.Phi_s))) + R_Phi = - self.bar_z(self.bar_x(u)) * self.delta_x(self.bar_x(self.pad_x(Phi)), dx) \ + - self.bar_z(self.bar_y(v)) * self.delta_y(self.bar_y(self.pad_y(Phi)), dy) \ + - Omega * self.delta_z(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s)), dz) - R_Mu = - self.delta_x(U) - self.delta_y(V) - self.delta_z(self.pad_z(Omega)) + R_Mu = - self.delta_x(U, dx) - self.delta_y(V, dy) - self.delta_z(self.pad_z(Omega), dz) # pressure term - R_U += - self.bar_x(self.pad_x(Mu)) * self.bar_x(self.pad_x(alpha)) * self.delta_x(self.pad_x(p)) \ - - (self.delta_z(self.bar_x(self.bar_z(self.pad_x(self.pad_z(p, self.P_t, self.P_s))))) * - self.delta_x(self.pad_x(self.bar_z(self.pad_z(Phi, self.Phi_t, self.Phi_s))))) + R_U += - self.bar_x(self.pad_x(Mu)) * self.bar_x(self.pad_x(alpha)) * self.delta_x(self.pad_x(p), dx) \ + - (self.delta_z(self.bar_x(self.bar_z(self.pad_x(self.pad_z(p, P_t, P_s)))), dz) * + self.delta_x(self.pad_x(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s))), dx)) - R_V += - self.bar_y(self.pad_y(Mu)) * self.bar_y(self.pad_y(alpha)) * self.delta_y(self.pad_y(p)) \ - - (self.delta_z(self.bar_y(self.bar_z(self.pad_y(self.pad_z(p, self.P_t, self.P_s))))) * - self.delta_y(self.pad_y(self.bar_z(self.pad_z(Phi, self.Phi_t, self.Phi_s))))) + R_V += - self.bar_y(self.pad_y(Mu)) * self.bar_y(self.pad_y(alpha)) * self.delta_y(self.pad_y(p), dy) \ + - (self.delta_z(self.bar_y(self.bar_z(self.pad_y(self.pad_z(p, P_t, P_s)))), dz) * + self.delta_y(self.pad_y(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s))), dy)) - R_W += self.g * (self.delta_z(p) - self.bar_z(Mu)) + R_W += g * (self.delta_z(p, dz) - self.bar_z(Mu)) # gravity term - R_Phi += self.g * w + R_Phi += g * w # Coriolis term # R_U += + 100 * self.bar_x(self.bar_y(self.pad_x(V))) \ @@ -120,11 +131,12 @@ def RHS(self, U, V, W, Theta, Mu, Phi): # + 100 * self.bar_y(self.bar_z(self.pad_y(self.pad_z(W)))) \ # - v * self.bar_y(self.bar_z(self.pad_y(self.pad_z(W)))) / 6400. / 1000. - return R_U, R_V, R_W, R_Theta, R_Mu, R_Phi, + return R_U, R_V, R_W, R_Theta, R_Mu, R_Phi #, Alpha, Omega - def RK3_step(self, U, V, W, Theta, Mu, Phi, dt): + # def RK3_step(self, U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, dt): + def RK3_step(self, U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g): r"""One RK3 Step""" - R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U, V, W, Theta, Mu, Phi) + R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) U_ = U + dt * R_U / 3 V_ = V + dt * R_V / 3 W_ = W + dt * R_W / 3 @@ -132,7 +144,7 @@ def RK3_step(self, U, V, W, Theta, Mu, Phi, dt): Mu_ = Mu + dt * R_Mu / 3 Phi_ = Phi + dt * R_Phi / 3 - R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_) + R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) U_ = U + dt * R_U / 2 V_ = V + dt * R_V / 2 W_ = W + dt * R_W / 2 @@ -140,7 +152,7 @@ def RK3_step(self, U, V, W, Theta, Mu, Phi, dt): Mu_ = Mu + dt * R_Mu / 2 Phi_ = Phi + dt * R_Phi / 2 - R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_) + R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) U += dt * R_U V += dt * R_V W += dt * R_W @@ -150,10 +162,6 @@ def RK3_step(self, U, V, W, Theta, Mu, Phi, dt): return U, V, W, Theta, Mu, Phi - def forward(self, dt): - self.U, self.V, self.W, self.Theta, self.Mu, self.Phi = \ - self.RK3_step(self.U, self.V, self.W, self.Theta, self.Mu, self.Phi, dt) - def pad_x(self, X): r"""Periodic boundary condition in x axis""" return F.pad(X, (1, 1), "circular") @@ -163,11 +171,12 @@ def pad_y(self, X): Nz, Ny, Nx = X.shape return F.pad(X.view(1, Nz, Ny, Nx), (0, 0, 1, 1), "circular").view(Nz, Ny + 2, Nx) - def pad_z(self, X, top=None, surface=None): + # TODO def pad_z(self, X, top=None, surface=None): + def pad_z(self, X, top=torch.Tensor(), surface=torch.Tensor()): r"""Dirichlet boundary condition in z axis""" _, ny, nx = X.shape - top = top if top is not None else torch.zeros((1, ny, nx), device=X.device) - surface = surface if surface is not None else torch.zeros((1, ny, nx), device=X.device) + top = torch.zeros((1, ny, nx), device=X.device) #TODO top = top if top is not None else torch.zeros((1, ny, nx), device=X.device) + surface = torch.zeros((1, ny, nx), device=X.device) #TODO surface = surface if surface is not None else torch.zeros((1, ny, nx), device=X.device) return torch.cat((top, X, surface), dim=0) def bar_x(self, X): @@ -183,7 +192,7 @@ def bar_x(self, X): filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / 2. - def delta_x(self, X): + def delta_x(self, X, dx): r"""Numerical scheme for \delta_x X Args: @@ -194,7 +203,7 @@ def delta_x(self, X): """ Nz, Ny, Nx = X.shape filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / self.dx + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / dx def bar_y(self, X): r"""Numerical scheme for X\bar^y @@ -209,7 +218,7 @@ def bar_y(self, X): filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / 2. - def delta_y(self, X): + def delta_y(self, X, dy): r"""Numerical scheme for \delta_y X Args: @@ -220,7 +229,7 @@ def delta_y(self, X): """ Nz, Ny, Nx = X.shape filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 2, 1) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / self.dy + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / dy def bar_z(self, X): r"""Numerical scheme for X\bar^z @@ -235,7 +244,7 @@ def bar_z(self, X): filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / 2. - def delta_z(self, X): + def delta_z(self, X, dz): r"""Numerical scheme for \delta_z X Args: @@ -246,20 +255,63 @@ def delta_z(self, X): """ Nz, Ny, Nx = X.shape filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / self.dz + return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / dz def _acoustic_step(self, ): r"""One acustic step""" pass +class LoopVariables(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): + + shapes = [list(var.size()) for var in variables + constants] + dtypes = [var.dtype for var in variables + constants] + batch_dims = [0] * (len(variables) + len(constants)) + super().__init__(shapes, dtypes, batch_dims) + self.variables = list() + self.constants = list() + for var in variables: + if torch.is_tensor(var) and var.device != torch.cuda.current_device(): + var = var.cuda() + self.variables.append(var) + for const in constants: + if torch.is_tensor(const) and const.device != torch.cuda.current_device(): + const = const.cuda() + self.constants.append(const) + + def __iter__(self): + return self + + def update(self, variables: List[torch.Tensor] = None, constants: List[torch.Tensor] = None): + if variables is not None: + self.variables = variables + if constants is not None: + self.constants = constants + + def reset(self, batch_size): + pass + + def __next__(self): + if len(self.variables) + len(self.constants) == 1: + return (self.variables + self.constants)[0] + return tuple(self.variables + self.constants) + if __name__ == "__main__": + cube.init() + # simulation settings nx = 201 ny = 201 nz = 201 - dx = 1e3 # m + dx = 1e3 # m dy = 1e3 # m + dz = 1. / (nz + 1) + # constants + PREF = torch.tensor(1e5) # reference pressure, usually sea level pressure, Pa + Rd = torch.tensor(287) # gas constant for dry air, J/(kg*K) + g = torch.tensor(9.81) # the acceleration of gravity, m/s**2 x0 = 100e3 y0 = 100e3 @@ -269,18 +321,42 @@ def _acoustic_step(self, ): theta += torch.linspace(1, 0, nz + 2).view(nz + 2, 1, 1) * \ -100. * torch.exp(-0.5 * ((grid_x - x0)**2 + (grid_y - y0)**2) / 400e6).view(1, ny, nx) # u0 = torch.ones((nz, ny, nx + 1)).cuda() - wrf = WRF(dx, dy, nx, ny, nz, theta.cuda()) + # wrf = WRF(dx, dy, nx, ny, nz, theta.cuda()) + theta = theta.cuda() + + dt = torch.tensor(0.1) + # nx = torch.tensor(nx) + # ny = torch.tensor(ny) + # nz = torch.tensor(nz) + dx = torch.tensor(dx) + dy = torch.tensor(dy) + dz = torch.tensor(dz) + + U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s = init(nx, ny, nz, dx, dy, dz, theta, PREF, Rd, g) + + varloader = LoopVariables(variables=[U, V, W, Theta, Mu, Phi, dt], constants=[Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g]) + model = WRF() + model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes), ) + + #TODO @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) + def train_iter(model, dataloader): + U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g = next(dataloader) + U, V, W, Theta, Mu, Phi = model(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, dt, PREF, Rd, g) + return U, V, W, Theta, Mu, Phi + #TODO model = model.get_gen_module() import matplotlib.pyplot as plt import numpy as np - while True: - plt.cla() - cf = plt.contourf(wrf.Theta[:, 100, :].cpu().numpy(), levels=50, cmap='jet') - cb = plt.colorbar(cf) - plt.savefig('res.jpeg', dpi=300) - plt.clf() - input('stop') + for iter in range (3): # while True: + # plt.cla() + # cf = plt.contourf(wrf.Theta[:, 100, :].cpu().numpy(), levels=50, cmap='jet') + # cb = plt.colorbar(cf) + # plt.savefig('res.jpeg', dpi=300) + # plt.clf() + # input('stop') + + # for i in range(1): + print("iter-{}...".format(iter)) + U, V, W, Theta, Mu, Phi = train_iter(model, varloader) # Phi_t, Phi_s, theta_t, theta_s - for i in range(1): - wrf(0.1) From a7053b6fb1e62996cf57e0c565e5edf0431c8eb3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 1 Mar 2022 10:29:13 +0800 Subject: [PATCH 0603/1892] enable parse on method --- cube/graph/parser/parser.py | 57 +++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 14538864..f7ba0e98 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -86,6 +86,49 @@ def parse_module(module, frame.pop() return input_val, all_ir_nodes, output_val + @staticmethod + def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): + """ + Parse module method + """ + + frame.push() + + input_var_name = [input.debugName() for input in method.graph.inputs()] + kDefaultType = DType2IRDType.map(torch.get_default_dtype()) + + for index, var_name in enumerate(input_var_name[1:]): # omit self + frame.add_var(var_name, IRFullTensor(name=var_name, requires_grad=False, dtype=kDefaultType), graph_arg=index) + + input_val = [frame.get_var(var_name) for var_name in input_var_name[1:]] + + all_ir_nodes: List[IRFwOperation] = list() + for node in method.graph.nodes(): + ir_nodes = ScriptModuleParser.parse_node(node, module, frame) + if len(ir_nodes) != 0: + for ir_node in ir_nodes: + try: + ret = ir_node.infer_shape() + if not ret: + print(f'warning: {ir_node} cannot infer shape') + except Exception: + raise RuntimeError(f"Shape infer error at: {ir_node}") + all_ir_nodes += ir_nodes + + # handle graph output -- Assuming all the output are tensors + output_var_name = [output.debugName() for output in method.graph.outputs()] + output_val = [frame.get_var(var_name) for var_name in output_var_name] + outputs = list() + for val in output_val: + if isinstance(val, list): + outputs += val + else: + outputs.append(val) + output_val = outputs + + frame.pop() + return input_val, all_ir_nodes, output_val + @staticmethod def ntype(node: torch._C.Node): if node.kind() == 'prim::GetAttr': @@ -242,8 +285,8 @@ def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: # forward label = node.s('name') - if label != 'forward': - raise RuntimeError(f"{node} is calling function {label} that is not `forward`") + # if label != 'forward': + # raise RuntimeError(f"{node} is calling function {label} that is not `forward`") # handle inputs -- in stack with reverse order for input in inputs[1:][::-1]: @@ -254,9 +297,13 @@ def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: # print(f'> {frame}') # recursively parse the module - module_label = node.inputsAt(0).node().s('name') - call_module = getattr(module, module_label) - _, ir_nodes, outputs_val = ScriptModuleParser.parse_module(call_module, frame=frame) + if node.inputsAt(0).debugName() == 'self': + call_module = module + else: + call_module = getattr(module, node.inputsAt(0).debugName()) + + call_method = getattr(call_module, label) + _, ir_nodes, outputs_val = ScriptModuleParser.parse_module_method(call_module, call_method, frame=frame) # pop out the frame frame.pop_param(times=len(inputs)-1) From 4c5139a1c16d5b9501a0675ee7b0fedeb3bf53b4 Mon Sep 17 00:00:00 2001 From: lynex Date: Tue, 1 Mar 2022 15:52:29 +0800 Subject: [PATCH 0604/1892] fix IRFullTensor print bug --- cube/ir/cten.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 0d54d424..d5a7add3 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -461,7 +461,10 @@ def detach_cell(self): @property def device(self) -> List[int]: - return self._cell.device + if self._cell: + return self._cell.device + else: + return None @device.setter def device(self, val: Union[int, List[int]]): From 571bbab86521b9a036791890e30162c9ee66a858 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 2 Mar 2022 18:39:00 +0800 Subject: [PATCH 0605/1892] remove test files --- tests/algorithm/test_activation.py | 75 --- tests/algorithm/test_bmm.py | 206 -------- tests/algorithm/test_complex.py | 543 --------------------- tests/algorithm/test_elementwise.py | 90 ---- tests/algorithm/test_factory.py | 15 - tests/algorithm/test_generics.py | 22 - tests/algorithm/test_layernorm.py | 80 --- tests/algorithm/test_linear_algo.py | 185 ------- tests/algorithm/test_memory.py | 46 -- tests/algorithm/test_reduce.py | 68 --- tests/codegen/test_codegen.py | 186 ------- tests/codegen/test_partition_codegen.py | 112 ----- tests/execplan/test_planpass_merge.py | 84 ---- tests/execplan/test_planpass_redundant.py | 78 --- tests/graph/parser/test_parse_attention.py | 115 ----- tests/graph/parser/test_parse_mlp.py | 70 --- tests/graph/test_function.py | 19 - tests/graph/test_graph.py | 184 ------- tests/graph/test_graph_partition.py | 157 ------ tests/graph/test_pas.py | 98 ---- tests/graph/test_tensor.py | 231 --------- tests/graph/test_tensor_grad.py | 153 ------ tests/ir/test_cell.py | 219 --------- tests/ir/test_tensor.py | 192 -------- tests/runtime/rollsplit.py | 164 ------- tests/runtime/test_group.py | 42 -- tests/runtime/test_nccl.py | 90 ---- tests/schedule/test_adapter_transform.py | 196 -------- tests/schedule/test_pool.py | 22 - tests/schedule/test_su.py | 99 ---- tests/schedule/test_sugraph.py | 249 ---------- tests/schedule/test_translator.py | 159 ------ tests/schedule/test_workflow.py | 103 ---- 33 files changed, 4352 deletions(-) delete mode 100644 tests/algorithm/test_activation.py delete mode 100644 tests/algorithm/test_bmm.py delete mode 100644 tests/algorithm/test_complex.py delete mode 100644 tests/algorithm/test_elementwise.py delete mode 100644 tests/algorithm/test_factory.py delete mode 100644 tests/algorithm/test_generics.py delete mode 100644 tests/algorithm/test_layernorm.py delete mode 100644 tests/algorithm/test_linear_algo.py delete mode 100644 tests/algorithm/test_memory.py delete mode 100644 tests/algorithm/test_reduce.py delete mode 100644 tests/codegen/test_codegen.py delete mode 100644 tests/codegen/test_partition_codegen.py delete mode 100644 tests/execplan/test_planpass_merge.py delete mode 100644 tests/execplan/test_planpass_redundant.py delete mode 100644 tests/graph/parser/test_parse_attention.py delete mode 100644 tests/graph/parser/test_parse_mlp.py delete mode 100644 tests/graph/test_function.py delete mode 100644 tests/graph/test_graph.py delete mode 100644 tests/graph/test_graph_partition.py delete mode 100644 tests/graph/test_pas.py delete mode 100644 tests/graph/test_tensor.py delete mode 100644 tests/graph/test_tensor_grad.py delete mode 100644 tests/ir/test_cell.py delete mode 100644 tests/ir/test_tensor.py delete mode 100644 tests/runtime/rollsplit.py delete mode 100644 tests/runtime/test_group.py delete mode 100644 tests/runtime/test_nccl.py delete mode 100644 tests/schedule/test_adapter_transform.py delete mode 100644 tests/schedule/test_pool.py delete mode 100644 tests/schedule/test_su.py delete mode 100644 tests/schedule/test_sugraph.py delete mode 100644 tests/schedule/test_translator.py delete mode 100644 tests/schedule/test_workflow.py diff --git a/tests/algorithm/test_activation.py b/tests/algorithm/test_activation.py deleted file mode 100644 index 4a5e9dfb..00000000 --- a/tests/algorithm/test_activation.py +++ /dev/null @@ -1,75 +0,0 @@ -import cube.algorithm.ops.activation as activation -from cube.graph.operator.function import Dropout -from cube.graph.tensor import IRFullTensor - - -def test_softmax_dim_parallel(): - - input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() - dim = -1 - stacklevel = 3 - dtype = None - - semantic_op = activation.Softmax( - signature = 'torch.nn.functional.softmax', - inputs = [input1, dim, stacklevel, dtype], - ) - semantic_op.infer_shape() - - op_dim = activation.SoftmaxDimParallel(semantic_op) - assert op_dim.dim is None - assert op_dim.chunk_num is None - - assert op_dim.satisfy(dict(dim=0, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) - assert not op_dim.satisfy(dict(dim=1, chunk_num=4)) - assert not op_dim.satisfy(dict(dim=-1, chunk_num=4)) - - nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) - for node in nodes: - print(node) - assert isinstance(node, activation.Softmax) - for input in node.inputs(): - print(input) - assert input.shape == [1024 // 4, 1024] - for output in node.outputs(): - print(output) - assert output.shape == [1024 // 4, 1024] - assert node.kwargs == semantic_op.kwargs - assert node.stay_dims == semantic_op.stay_dims - - -def test_dropout_dim_parallel(): - - input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() - p = 0.5 - training = True - inplace = False - - semantic_op = activation.Dropout( - signature = 'torch.nn.functional.softmax', - inputs = [input1, p, training, inplace], - ) - semantic_op.infer_shape() - - op_dim = activation.DropoutDimParallel(semantic_op) - assert op_dim.dim is None - assert op_dim.chunk_num is None - - assert op_dim.satisfy(dict(dim=0, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) - assert op_dim.satisfy(dict(dim=1, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-1, chunk_num=4)) - - nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) - for node in nodes: - print(node) - assert isinstance(node, activation.Dropout) - for input in node.inputs(): - print(input) - assert input.shape == [1024 // 4, 1024] - for output in node.outputs(): - print(output) - assert output.shape == [1024 // 4, 1024] - assert node.kwargs == semantic_op.kwargs - assert node.stay_dims == semantic_op.stay_dims diff --git a/tests/algorithm/test_bmm.py b/tests/algorithm/test_bmm.py deleted file mode 100644 index a161c0f2..00000000 --- a/tests/algorithm/test_bmm.py +++ /dev/null @@ -1,206 +0,0 @@ -import cube.algorithm.ops.bmm as bmm -from cube.graph.tensor import IRFullTensor, ValueMap - - -def test_bmm_data_parallel(): - - B = 64 # seq len - N = 256 # batch - M = 1024 # hiddend size = dim_head * num_head - P = 512 - input1 = IRFullTensor(shape=[B, N, M], name='hidden').tosub() - input2 = IRFullTensor(shape=[B, M, P], name='input2').tosub() - - semantic_op = bmm.BatchLinear( - signature='torch.bmm', - inputs = [input1, input2] - ) - semantic_op.infer_shape() - - bmm_dp = bmm.BatchLinearDataParallel(semantic_op) - - assert bmm_dp.chunk_num is None - - assert bmm_dp.satisfy(dict(chunk_num=8)) - assert not bmm_dp.satisfy(dict(chunk_num=9)) - - nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, bmm.BatchLinear) - - nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, bmm.BatchLinear) - - input1s = [node.inputs(0) for node in nodes] - print('inputs:') - for input in input1s: - print(input) - assert input.shape == [B // 4, N, M] - - input2s = [node.inputs(1) for node in nodes] - print('input2s:') - for input2 in input2s: - print(input2) - assert input2.shape == [B // 4, M, P] - - outputs = [node.outputs(0) for node in nodes] - for output in outputs: - print(output) - assert output.shape == [B // 4, N, P] - assert output.valmap == ValueMap(0, 1) - - -def test_bmm_n_parallel(): - - B = 64 # seq len - N = 256 # batch - M = 1024 # hiddend size = dim_head * num_head - P = 512 - input1 = IRFullTensor(shape=[B, N, M], name='hidden').tosub() - input2 = IRFullTensor(shape=[B, M, P], name='input2').tosub() - - semantic_op = bmm.BatchLinear( - signature='torch.bmm', - inputs = [input1, input2] - ) - semantic_op.infer_shape() - - bmm_dp = bmm.BatchLinearNParallel(semantic_op) - - assert bmm_dp.chunk_num is None - - assert bmm_dp.satisfy(dict(chunk_num=8)) - assert not bmm_dp.satisfy(dict(chunk_num=9)) - - nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, bmm.BatchLinear) - - nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, bmm.BatchLinear) - - input1s = [node.inputs(0) for node in nodes] - print('inputs:') - for input in input1s: - print(input) - assert input.shape == [B, N // 4, M] - - input2s = [node.inputs(1) for node in nodes] - print('input2s:') - for input2 in input2s: - print(input2) - assert input2.shape == [B, M, P] - - outputs = [node.outputs(0) for node in nodes] - for output in outputs: - print(output) - assert output.shape == [B, N // 4, P] - assert output.valmap == ValueMap(0, 1) - - -def test_bmm_m_parallel(): - - B = 64 - N = 256 - M = 1024 - P = 512 - input1 = IRFullTensor(shape=[B, N, M], name='input1').tosub() - input2 = IRFullTensor(shape=[B, M, P], name='input2').tosub() - - semantic_op = bmm.BatchLinear( - signature='torch.bmm', - inputs = [input1, input2] - ) - semantic_op.infer_shape() - - bmm_dp = bmm.BatchLinearMParallel(semantic_op) - - assert bmm_dp.chunk_num is None - - assert bmm_dp.satisfy(dict(chunk_num=8)) - assert not bmm_dp.satisfy(dict(chunk_num=9)) - - nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, bmm.BatchLinear) - - nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, bmm.BatchLinear) - - input1s = [node.inputs(0) for node in nodes] - print('inputs:') - for input in input1s: - print(input) - assert input.shape == [B, N, M // 4] - - input2s = [node.inputs(1) for node in nodes] - print('input2s:') - for input2 in input2s: - print(input2) - assert input2.shape == [B, M // 4, P] - - outputs = [node.outputs(0) for node in nodes] - for idx, output in enumerate(outputs): - print(output) - assert output.shape == [B, N, P] - assert output.valmap == ValueMap(idx, 4) - - -def test_bmm_p_parallel(): - - B = 64 # seq len - N = 256 # batch - M = 1024 # hiddend size = dim_head * num_head - P = 512 - input1 = IRFullTensor(shape=[B, N, M], name='hidden').tosub() - input2 = IRFullTensor(shape=[B, M, P], name='input2').tosub() - - semantic_op = bmm.BatchLinear( - signature='torch.bmm', - inputs = [input1, input2] - ) - semantic_op.infer_shape() - - bmm_dp = bmm.BatchLinearPParallel(semantic_op) - - assert bmm_dp.chunk_num is None - - assert bmm_dp.satisfy(dict(chunk_num=8)) - assert not bmm_dp.satisfy(dict(chunk_num=9)) - - nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, bmm.BatchLinear) - - nodes = bmm_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, bmm.BatchLinear) - - input1s = [node.inputs(0) for node in nodes] - print('inputs:') - for input in input1s: - print(input) - assert input.shape == [B, N, M] - - input2s = [node.inputs(1) for node in nodes] - print('input2s:') - for input2 in input2s: - print(input2) - assert input2.shape == [B, M, P // 4] - - outputs = [node.outputs(0) for node in nodes] - for output in outputs: - print(output) - assert output.shape == [B, N, P // 4] - assert output.valmap == ValueMap(0, 1) \ No newline at end of file diff --git a/tests/algorithm/test_complex.py b/tests/algorithm/test_complex.py deleted file mode 100644 index 03d86e37..00000000 --- a/tests/algorithm/test_complex.py +++ /dev/null @@ -1,543 +0,0 @@ -import cube.algorithm.ops.complex as complex -from cube.graph.tensor import IRFullTensor, ValueMap - - -def test_complex_toqkv_data_parallel(): - - L = 64 # seq len - N = 16 # batch - E = 1024 # hiddend size = dim_head * num_head - num_head = 8 - dim_head = E // num_head - input = IRFullTensor(shape=[L, N, E], name='hidden').tosub() - weight = IRFullTensor(shape=[3 * E, E], name='weight').tosub() - - semantic_op = complex.CubeComplexToQKV( - signature='cube.runtime.function.complex.toqkv', - inputs = [input, weight, num_head] - ) - semantic_op.infer_shape() - - qkv_dp = complex.CubeToQKVDataParallel(semantic_op) - - assert qkv_dp.chunk_num is None - - assert qkv_dp.satisfy(dict(chunk_num=8)) - assert not qkv_dp.satisfy(dict(chunk_num=32)) - - nodes = qkv_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexToQKV) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - assert input.shape == [L, N // 4, E] - weights = [node.inputs(1) for node in nodes] - - print('weights:') - for weight in weights: - print(weight) - assert weight.shape == [3 * E, E] - - sub_heads = [node.kwargs['num_head'] for node in nodes] - print('num_head:') - for nhead in sub_heads: - assert nhead == 8 - print(nhead) - - outputs = [node.outputs() for node in nodes] - print('outputs:') - for output in outputs: - q, k, v = output - print('q:', q) - print('k:', k) - print('v:', v) - assert q.shape == [L, N * num_head // 4, dim_head] - assert k.shape == [L, N * num_head // 4, dim_head] - assert v.shape == [L, N * num_head // 4, dim_head] - - -def test_complex_toqkv_head_parallel(): - - L = 64 # seq len - N = 16 # batch - E = 1024 # hiddend size = dim_head * num_head - num_head = 8 - dim_head = E // num_head - input = IRFullTensor(shape=[L, N, E], name='hidden').tosub() - weight = IRFullTensor(shape=[3 * E, E], name='weight').tosub() - - semantic_op = complex.CubeComplexToQKV( - signature='cube.runtime.function.complex.toqkv', - inputs = [input, weight, num_head] - ) - semantic_op.infer_shape() - - qkv_hp = complex.CubeToQKVHeadParallel(semantic_op) - - assert qkv_hp.chunk_num is None - - assert qkv_hp.satisfy(dict(chunk_num=8)) - assert not qkv_hp.satisfy(dict(chunk_num=32)) - - nodes = qkv_hp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexToQKV) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - assert input.shape == [L, N, E] - - weights = [node.inputs(1) for node in nodes] - print('weights:') - for weight in weights: - assert weight.shape == [3 * E // 4, E] - print(weight) - - sub_heads = [node.kwargs['num_head'] for node in nodes] - print('sub_heads:') - for nhead in sub_heads: - assert nhead == num_head // 4 - print(nhead) - - outputs = [node.outputs() for node in nodes] - print('outputs:') - for output in outputs: - q, k, v = output - print('q:', q) - print('k:', k) - print('v:', v) - assert q.shape == [L, N * num_head // 4, dim_head] - assert k.shape == [L, N * num_head // 4, dim_head] - assert v.shape == [L, N * num_head // 4, dim_head] - - -def test_complex_tril_mask_data_parallel(): - - L = 64 # seq len - N = 16 # batch - num_head = 8 - input = IRFullTensor(shape=[N * num_head, L, L], name='hidden').tosub() - - semantic_op = complex.CubeComplexTrilMask( - signature = 'cube.runtime.function.complex.trill_mask', - inputs = [input, num_head], - ) - semantic_op.infer_shape() - - mask_dp = complex.CubeTrilMaskDataParallel(semantic_op) - - assert mask_dp.chunk_num is None - - assert mask_dp.satisfy(dict(chunk_num=8)) - assert not mask_dp.satisfy(dict(chunk_num=32)) - - nodes = mask_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexTrilMask) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - assert input.shape == [N * num_head // 4, L, L] - - sub_heads = [node.kwargs['num_head'] for node in nodes] - print('num_head:') - for nhead in sub_heads: - assert nhead == 8 - print(nhead) - - outputs = [node.outputs(0) for node in nodes] - print('outputs:') - for output in outputs: - print(output) - assert output.shape == [N * num_head // 4, L, L] - - -def test_complex_tril_mask_head_parallel(): - - L = 64 # seq len - N = 16 # batch - num_head = 8 - input = IRFullTensor(shape=[N * num_head, L, L], name='hidden').tosub() - - semantic_op = complex.CubeComplexTrilMask( - signature = 'cube.runtime.function.complex.trill_mask', - inputs = [input, num_head], - ) - semantic_op.infer_shape() - - mask_hp = complex.CubeTrilMaskHeadParallel(semantic_op) - - assert mask_hp.chunk_num is None - - assert mask_hp.satisfy(dict(chunk_num=8)) - assert not mask_hp.satisfy(dict(chunk_num=32)) - - nodes = mask_hp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexTrilMask) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - assert input.shape == [N * num_head // 4, L, L] - - sub_heads = [node.kwargs['num_head'] for node in nodes] - print('num_head:') - for nhead in sub_heads: - assert nhead == num_head // 4 - print(nhead) - - outputs = [node.outputs(0) for node in nodes] - print('outputs:') - for output in outputs: - print(output) - assert output.shape == [N * num_head // 4, L, L] - - -def test_complex_attn_view_data_parallel(): - - L = 64 # seq len - N = 16 # batch - num_head = 8 - dim_head = 128 - input = IRFullTensor( - shape=[N * num_head, L, dim_head], name='hidden').tosub() - - semantic_op = complex.CubeComplexAttnView( - signature = 'cube.runtime.function.complex.trill_mask', - inputs = [input, num_head], - ) - semantic_op.infer_shape() - - mask_hp = complex.CubeAttnViewDataParallel(semantic_op) - - assert mask_hp.chunk_num is None - - assert mask_hp.satisfy(dict(chunk_num=8)) - assert not mask_hp.satisfy(dict(chunk_num=32)) - - nodes = mask_hp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexAttnView) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - assert input.shape == [N * num_head // 4, L, dim_head] - - sub_heads = [node.kwargs['num_head'] for node in nodes] - print('num_head:') - for nhead in sub_heads: - assert nhead == num_head - print(nhead) - - outputs = [node.outputs(0) for node in nodes] - print('outputs:') - for output in outputs: - print(output) - assert output.shape == [L, N // 4, num_head * dim_head] - - -def test_complex_attn_view_head_parallel(): - - L = 64 # seq len - N = 16 # batch - num_head = 8 - dim_head = 128 - input = IRFullTensor( - shape=[N * num_head, L, dim_head], name='hidden').tosub() - - semantic_op = complex.CubeComplexAttnView( - signature = 'cube.runtime.function.complex.trill_mask', - inputs = [input, num_head], - ) - semantic_op.infer_shape() - - mask_hp = complex.CubeAttnViewHeadParallel(semantic_op) - - assert mask_hp.chunk_num is None - - assert mask_hp.satisfy(dict(chunk_num=8)) - assert not mask_hp.satisfy(dict(chunk_num=32)) - - nodes = mask_hp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexAttnView) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - assert input.shape == [N * num_head // 4, L, dim_head] - - sub_heads = [node.kwargs['num_head'] for node in nodes] - print('num_head:') - for nhead in sub_heads: - assert nhead == num_head // 4 - print(nhead) - - outputs = [node.outputs(0) for node in nodes] - print('outputs:') - for output in outputs: - print(output) - assert output.shape == [L, N, num_head * dim_head // 4] - - -def test_complex_self_attention_head_parallel(): - L = 64 # seq len - N = 16 # batch - num_head = 8 - dim_head = 128 - E = num_head * dim_head - - input = IRFullTensor( - shape=[L, N, E], name='hidden').tosub() - w_qkv = IRFullTensor( - shape=[3 * num_head * dim_head, num_head * dim_head], name='wqkv').tosub() - w_out = IRFullTensor( - shape=[num_head * dim_head, num_head * dim_head], name='wout').tosub() - - semantic_op = complex.CubeComplexSelfAttention( - signature = 'cube.runtime.function.complex.self_attn', - inputs = [input, w_qkv, w_out, num_head, dim_head, 0.5], - ) - semantic_op.infer_shape() - - op_head = complex.CubeSelfAttentionHeadParallel(semantic_op) - - assert op_head.satisfy(config=dict(chunk_num=8)) - assert not op_head.satisfy(config=dict(chunk_num=16)) - - nodes = op_head.instantiate(semantic_op, config=dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexSelfAttention) - - for idx, node in enumerate(nodes): - assert node.outputs(0).shape == [L, N, E] - assert node.outputs(0).valmap == ValueMap(idx, 4) - assert node.kwargs['num_head'] == num_head // 4 - assert node.inputs(0).shape == [L, N, E] - assert node.inputs(1).shape == [3 * E // 4, E] - assert node.inputs(2).shape == [E, E // 4] - - -def test_complex_self_attention_data_parallel(): - L = 64 # seq len - N = 16 # batch - num_head = 8 - dim_head = 128 - E = num_head * dim_head - - input = IRFullTensor( - shape=[L, N, E], name='hidden').tosub() - w_qkv = IRFullTensor( - shape=[3 * num_head * dim_head, num_head * dim_head], name='wqkv').tosub() - w_out = IRFullTensor( - shape=[num_head * dim_head, num_head * dim_head], name='wout').tosub() - - semantic_op = complex.CubeComplexSelfAttention( - signature = 'cube.runtime.function.complex.self_attn', - inputs = [input, w_qkv, w_out, num_head, dim_head, 0.5], - ) - semantic_op.infer_shape() - - op_head = complex.CubeSelfAttentionDataParallel(semantic_op) - - assert op_head.satisfy(config=dict(chunk_num=8)) - assert not op_head.satisfy(config=dict(chunk_num=32)) - - nodes = op_head.instantiate(semantic_op, config=dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexSelfAttention) - - for idx, node in enumerate(nodes): - assert node.outputs(0).shape == [L, N // 4, E] - assert node.outputs(0).valmap == ValueMap(0, 1) - assert node.kwargs['num_head'] == num_head - assert node.inputs(0).shape == [L, N // 4, E] - assert node.inputs(1).shape == [3 * E, E] - assert node.inputs(2).shape == [E, E] - - -def test_complex_feedforward_tensor_parallel(): - L = 64 # seq len - N = 16 # batch - E = 1024 - - input = IRFullTensor( - shape=[L, N, E], name='hidden').tosub() - w_proj1 = IRFullTensor( - shape=[4 * E, E], name='proj1').tosub() - w_bias1 = IRFullTensor( - shape=[4 * E,], name='bias1').tosub() - w_proj2 = IRFullTensor( - shape=[E, 4 * E], name='proj2').tosub() - w_bias2 = IRFullTensor( - shape=[E,], name='bias2').tosub() - - semantic_op = complex.CubeComplexFeedForward( - signature = 'cube.runtime.function.complex.feedforward', - inputs = [input, w_proj1, w_bias1, w_proj2, w_bias2], - ) - semantic_op.infer_shape() - - op_head = complex.CubeFeedForwardTensorParallel(semantic_op) - - assert op_head.satisfy(config=dict(chunk_num=8)) - assert op_head.satisfy(config=dict(chunk_num=32)) - - nodes = op_head.instantiate(semantic_op, config=dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexFeedForward) - - for idx, node in enumerate(nodes): - assert node.outputs(0).shape == [L, N, E] - assert node.outputs(0).valmap == ValueMap(idx, 4) - assert node.inputs(0).shape == [L, N, E] - assert node.inputs(1).shape == [4 * E // 4, E] - assert node.inputs(2).shape == [4 * E // 4,] - assert node.inputs(3).shape == [E, 4 * E // 4] - assert node.inputs(4).shape == [E,] - assert node.inputs(4).valmap == ValueMap(idx, 4) - - -def test_complex_feedforward_data_parallel(): - L = 64 # seq len - N = 16 # batch - E = 1024 - - input = IRFullTensor( - shape=[L, N, E], name='hidden').tosub() - w_proj1 = IRFullTensor( - shape=[4 * E, E], name='proj1').tosub() - w_bias1 = IRFullTensor( - shape=[4 * E,], name='bias1').tosub() - w_proj2 = IRFullTensor( - shape=[E, 4 * E], name='proj2').tosub() - w_bias2 = IRFullTensor( - shape=[E,], name='bias2').tosub() - - semantic_op = complex.CubeComplexFeedForward( - signature = 'cube.runtime.function.complex.feedforward', - inputs = [input, w_proj1, w_bias1, w_proj2, w_bias2], - ) - semantic_op.infer_shape() - - op_head = complex.CubeFeedForwardDataParallel(semantic_op) - - assert op_head.satisfy(config=dict(chunk_num=8)) - assert not op_head.satisfy(config=dict(chunk_num=32)) - - nodes = op_head.instantiate(semantic_op, config=dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexFeedForward) - - for idx, node in enumerate(nodes): - assert node.outputs(0).shape == [L, N // 4, E] - assert node.outputs(0).valmap == ValueMap(0, 1) - assert node.inputs(0).shape == [L, N // 4, E] - assert node.inputs(1).shape == [4 * E, E] - assert node.inputs(2).shape == [4 * E,] - assert node.inputs(3).shape == [E, 4 * E] - assert node.inputs(4).shape == [E,] - - -def test_embed_shard_parallel(): - L = 64 # seq len - N = 16 # batch - vocab = 50304 - E = 1024 - - ids = IRFullTensor(shape=[L, N], name='hidden').tosub() - weight = IRFullTensor(shape=[vocab, E], name='hidden').tosub() - start = 0 - stop = vocab - - semantic_op = complex.CubeComplexEmbedding( - signature = 'cube.runtime.function.complex.embedding', - inputs = [ids, weight, start, stop] - ) - semantic_op.infer_shape() - - assert semantic_op.outputs(0).shape == [L, N, E] - - op_shard = complex.CubeEmbedShardingParallel(semantic_op) - - assert op_shard.satisfy(config=dict(chunk_num=8)) - assert op_shard.satisfy(config=dict(chunk_num=32)) - assert not op_shard.satisfy(config=dict(chunk_num=256)) - - nodes = op_shard.instantiate(semantic_op, config=dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexEmbedding) - - start = semantic_op.kwargs['start'] - stop = semantic_op.kwargs['stop'] - shard = (stop - start) // 4 - for idx, node in enumerate(nodes): - assert node.outputs(0).shape == [L, N, E] - assert node.outputs(0).valmap == ValueMap(idx, 4) - assert node.inputs(0).shape == [L, N] - assert node.inputs(1).shape == [vocab // 4, E] - assert node.kwargs['start'] == start + idx * shard - assert node.kwargs['stop'] == start + (idx + 1) * shard - - -def test_embed_shard_parallel(): - L = 64 # seq len - N = 16 # batch - vocab = 50304 - E = 1024 - - ids = IRFullTensor(shape=[L, N], name='hidden').tosub() - weight = IRFullTensor(shape=[vocab, E], name='hidden').tosub() - start = 0 - stop = vocab - - semantic_op = complex.CubeComplexEmbedding( - signature = 'cube.runtime.function.complex.embedding', - inputs = [ids, weight, start, stop] - ) - semantic_op.infer_shape() - - assert semantic_op.outputs(0).shape == [L, N, E] - - op_shard = complex.CubeEmbedDataParallel(semantic_op) - - assert op_shard.satisfy(config=dict(dim=1, chunk_num=8)) - assert not op_shard.satisfy(config=dict(dim=1, chunk_num=32)) - - nodes = op_shard.instantiate(semantic_op, config=dict(dim=1, chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, complex.CubeComplexEmbedding) - - start = semantic_op.kwargs['start'] - stop = semantic_op.kwargs['stop'] - for idx, node in enumerate(nodes): - assert node.outputs(0).shape == [L, N // 4, E] - assert node.outputs(0).valmap == ValueMap(0, 1) - assert node.inputs(0).shape == [L, N // 4] - assert node.inputs(1).shape == [vocab, E] - assert node.kwargs['start'] == start - assert node.kwargs['stop'] == stop diff --git a/tests/algorithm/test_elementwise.py b/tests/algorithm/test_elementwise.py deleted file mode 100644 index f41d1b67..00000000 --- a/tests/algorithm/test_elementwise.py +++ /dev/null @@ -1,90 +0,0 @@ -from cube.graph.operator.function import ElementWise -import cube.algorithm.ops.elementwise as elew -from cube.graph.tensor import IRFullTensor - - -def test_elementwise_dim_parallel(): - - input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() - input2 = IRFullTensor(shape=[1024, 1024], name='input2').tosub() - - semantic_op = ElementWise( - signature='torch.add', inputs=[input1, input2], name='add' - ) - semantic_op.infer_shape() - print('semantic op:') - print(semantic_op) - - op_dp = elew.ElementWiseDimParallel(semantic_op, dim=0) - - assert op_dp.chunk_num is None - - # test satisfy - assert op_dp.satisfy(dict(chunk_num = 4)) - assert not op_dp.satisfy(dict(chunk_num = 10)) - - nodes = op_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, ElementWise) - - for node in nodes: - print('=======') - print(node) - print('inputs:') - for input in node.inputs(): - print(input) - assert input.shape == [256, 1024] - print('outputs:') - for output in node.outputs(): - print(output) - assert output.shape == [256, 1024] - - op_dp = elew.ElementWiseDimParallel(semantic_op, dim=1) - nodes = op_dp.instantiate(semantic_op, dict(chunk_num=4)) - - for node in nodes: - print('=======') - print(node) - print('inputs:') - for input in node.inputs(): - print(input) - assert input.shape == [1024, 256] - print('outputs:') - for output in node.outputs(): - print(output) - assert output.shape == [1024, 256] - - -def test_add_dim_parallel(): - - input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() - input2 = IRFullTensor(shape=[1024, 1024], name='input2').tosub() - alpha = 1.0 - - semantic_op = elew.Add( - signature='torch.add', inputs=[input1, input2, alpha], name='add' - ) - semantic_op.infer_shape() - - dim_op = elew.AddDimParallel(semantic_op) - - assert dim_op.dim is None - assert dim_op.chunk_num is None - - assert dim_op.satisfy(config=dict(dim=1, chunk_num=4)) - assert dim_op.satisfy(config=dict(dim=-1, chunk_num=4)) - assert dim_op.satisfy(config=dict(dim=0, chunk_num=4)) - assert not dim_op.satisfy(config=dict(dim=2, chunk_num=4)) - - nodes = dim_op.instantiate(semantic_op, dict(dim=0, chunk_num=4)) - for node in nodes: - print(node) - assert isinstance(node, elew.Add) - for input in node.inputs(): - print(input) - assert input.shape == [1024 // 4, 1024] - for output in node.outputs(): - print(output) - assert output.shape == [1024 // 4, 1024] - assert node.kwargs == semantic_op.kwargs diff --git a/tests/algorithm/test_factory.py b/tests/algorithm/test_factory.py deleted file mode 100644 index b7453274..00000000 --- a/tests/algorithm/test_factory.py +++ /dev/null @@ -1,15 +0,0 @@ -from cube.algorithm.factory import DistAlgorithmFactory -from cube.graph.operator.function import Linear -from cube.algorithm.generics import GenericDistAlgo - - -def test_factory_init(): - factory = DistAlgorithmFactory() - assert len(factory.algorithms(Linear)) == 3 - - -def test_factory_tag(): - - factory = DistAlgorithmFactory() - dp = factory.algorithms(Linear, tag='data') - assert issubclass(dp, GenericDistAlgo) diff --git a/tests/algorithm/test_generics.py b/tests/algorithm/test_generics.py deleted file mode 100644 index 088729d2..00000000 --- a/tests/algorithm/test_generics.py +++ /dev/null @@ -1,22 +0,0 @@ -from cube.algorithm.generics import GenericDistAlgo -from cube.ir.cten import IRCell, IRTensor - - -def test_generic_algo_init(): - input1 = IRTensor(shape=[1024, 1024]) - input2 = IRTensor(shape=[1024, 1000]) - bias = None - cell = IRCell(name='test', signature='test', input_length=3, output_length=1) - cell.set_input(0, input1) - cell.set_input(1, input2) - cell.set_input(2, bias) - cell.outputs(0).shape = [1024, 1000] - - algo = GenericDistAlgo(cell) - assert algo.logic_op is IRCell - assert len(algo.input_shapes) == 3 - assert algo.input_shapes[0] == [1024, 1024] - assert algo.input_shapes[1] == [1024, 1000] - assert algo.input_shapes[2] is None - assert len(algo.output_shapes) == 1 - assert algo.output_shapes[0] == [1024, 1000] diff --git a/tests/algorithm/test_layernorm.py b/tests/algorithm/test_layernorm.py deleted file mode 100644 index 73174808..00000000 --- a/tests/algorithm/test_layernorm.py +++ /dev/null @@ -1,80 +0,0 @@ -from cube.graph.operator.function import ElementWise -import cube.algorithm.ops.layernorm as ln -from cube.graph.tensor import IRFullTensor - - -def test_elementwise_dim_parallel(): - - input1 = IRFullTensor(shape=[1024, 512, 256], name='input1').tosub() - normalized_shape = [256,] - weight = IRFullTensor(shape=[256], name='weight').tosub() - bias = IRFullTensor(shape=[256], name='bias').tosub() - eps = 1e-5 - - semantic_op = ln.LayerNorm( - signature='torch.nn.functional.layernorm', - inputs=[input1, normalized_shape, weight, bias, eps], - name='layernorm' - ) - semantic_op.infer_shape() - print('semantic op:') - print(semantic_op) - - op_dim = ln.LayerNormDimParallel(semantic_op) - - assert op_dim.chunk_num is None - - # test satisfy - assert op_dim.satisfy(dict(dim=0, chunk_num = 4)) - assert op_dim.satisfy(dict(dim=1, chunk_num = 8)) - assert not op_dim.satisfy(dict(dim=2, chunk_num = 8)) - - nodes = op_dim.instantiate(semantic_op, dict(dim=1, chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, ln.LayerNorm) - - for node in nodes: - print(node) - print('inputs:') - - input = node.inputs(0) - print(input) - assert input.shape == [1024, 512 // 4, 256] - - weight = node.inputs(2) - print(weight) - assert weight.shape == [256,] - - bias = node.inputs(3) - print(bias) - assert bias.shape == [256,] - - print('outputs:') - for output in node.outputs(): - print(output) - assert output.shape == [1024, 512 // 4, 256] - - op_dim = ln.LayerNormDimParallel(semantic_op, dim=0) - nodes = op_dim.instantiate(semantic_op, dict(chunk_num=4)) - - for node in nodes: - print(node) - print('inputs:') - - input = node.inputs(0) - print(input) - assert input.shape == [1024 // 4, 512, 256] - - weight = node.inputs(2) - print(weight) - assert weight.shape == [256,] - - bias = node.inputs(3) - print(bias) - assert bias.shape == [256,] - - print('outputs:') - for output in node.outputs(): - print(output) - assert input.shape == [1024 // 4, 512, 256] diff --git a/tests/algorithm/test_linear_algo.py b/tests/algorithm/test_linear_algo.py deleted file mode 100644 index 7fd58008..00000000 --- a/tests/algorithm/test_linear_algo.py +++ /dev/null @@ -1,185 +0,0 @@ -from cube.graph.operator.function import Linear -from cube.algorithm.ops.linear import LinearDataParallel -from cube.algorithm.ops.linear import LinearColumnWeight -from cube.algorithm.ops.linear import LinearRowWeight -from cube.graph.tensor import IRFullTensor, ValueMap - - -def test_linear_data_parallel(): - - input = IRFullTensor(shape=[1024, 1024], name='input').tosub() - weight = IRFullTensor(shape=[1000, 1024], name='weight').tosub() - bias = IRFullTensor(shape=[1000,], name='bias').tosub() - - semantic_op = Linear( - signature='torch.nn.functional.linear', - inputs = [input, weight, bias], - ) - semantic_op.infer_shape() - - linear_dp = LinearDataParallel(semantic_op) - - assert linear_dp.chunk_num is None - - # test satisfy - assert linear_dp.satisfy(dict(chunk_num=4)) - assert not linear_dp.satisfy(dict(chunk_num=10)) - - nodes = linear_dp.instantiate(semantic_op, dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, Linear) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - weights = [node.inputs(1) for node in nodes] - print('weights:') - for weight in weights: - print(weight) - biass = [node.inputs(2) for node in nodes] - print('bias:') - for bias in biass: - print(bias) - - for idx, x in enumerate(inputs): - assert x.shape == [256, 1024] - assert x.indmap.get()[0] == slice(256 * idx, 256 * (idx + 1), 1) - assert not inputs[0].overlap(inputs[1]) - assert not inputs[0].overlap(inputs[2]) - assert not inputs[0].overlap(inputs[3]) - - for w in weights: - assert w.shape == [1000, 1024] - assert w == weight - - for b in biass: - assert b.shape == [1000] - assert b == bias - - -def test_linear_column_weight(): - input = IRFullTensor(shape=[1024, 1024], name='input').tosub() - weight = IRFullTensor(shape=[1000, 1024], name='weight').tosub() - bias = IRFullTensor(shape=[1000,], name='bias').tosub() - - semantic_op = Linear( - signature='torch.nn.functional.linear', - inputs = [input, weight, bias], - ) - semantic_op.infer_shape() - - linear_col_weight = LinearColumnWeight(semantic_op) - - # test satisfy - assert linear_col_weight.satisfy(dict(chunk_num=4)) - assert linear_col_weight.satisfy(dict(chunk_num=10)) - assert not linear_col_weight.satisfy(dict(chunk_num=12)) - - nodes = linear_col_weight.instantiate(semantic_op, config=dict(chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, Linear) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - weights = [node.inputs(1) for node in nodes] - print('weights:') - for weight in weights: - print(weight) - biass = [node.inputs(2) for node in nodes] - print('bias:') - for bias in biass: - print(bias) - outputs = [node.outputs(0) for node in nodes] - print('output:') - for output in outputs: - print(output) - - for x in inputs: - assert x == input - - for idx, w in enumerate(weights): - assert w.shape == [250, 1024] - assert w.indmap.get()[0] == slice(250 * idx, 250 * (idx + 1), 1) - - for idx, b in enumerate(biass): - assert b.shape == [250] - assert b.indmap.get() == (slice(250 * idx, 250 * (idx + 1), 1),) - - for idx, output in enumerate(outputs): - assert output.shape == [1024, 250] - assert output.indmap.get()[0] == slice(0, 1024, 1) - assert output.indmap.get()[1] == slice(250 * idx, 250 * (idx + 1), 1) - - -def test_linear_row(): - input = IRFullTensor(shape=[1024, 1024], name='input').tosub() - weight = IRFullTensor(shape=[1000, 1024], name='weight').tosub() - bias = IRFullTensor(shape=[1000,], name='bias').tosub() - - semantic_op = Linear( - signature='torch.nn.functional.linear', - inputs = [input, weight, bias], - ) - semantic_op.infer_shape() - - input_shapes = list() - for input in semantic_op.inputs(): - input_shapes.append(input.shape) - - output_shapes = list() - for output in semantic_op.outputs(): - output_shapes.append(output.shape) - - linear_row_weight = LinearRowWeight(semantic_op) - - # test satisfy - assert linear_row_weight.satisfy(dict(chunk_num=4)) - assert not linear_row_weight.satisfy(dict(chunk_num=10)) - assert not linear_row_weight.satisfy(dict(chunk_num=12)) - - nodes = linear_row_weight.instantiate(semantic_op, config=dict(chunk_num=4)) - - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, Linear) - - inputs = [node.inputs(0) for node in nodes] - print('inputs:') - for input in inputs: - print(input) - weights = [node.inputs(1) for node in nodes] - print('weights:') - for weight in weights: - print(weight) - biass = [node.inputs(2) for node in nodes] - print('bias:') - for bias in biass: - print(bias) - outputs = [node.outputs(0) for node in nodes] - print('output:') - for output in outputs: - print(output) - - for idx, x in enumerate(inputs): - assert x.shape == [1024, 256] - assert x.indmap.get()[1] == slice(256 * idx, 256 * (idx + 1), 1) - assert x.valmap == ValueMap(0, 1) - - for idx, w in enumerate(weights): - assert w.shape == [1000, 256] - assert w.indmap.get()[1] == slice(256 * idx, 256 * (idx + 1), 1) - assert w.valmap == ValueMap(0, 1) - - for idx, b in enumerate(biass): - assert b.shape == [1000,] - assert b.indmap.get()[0] == slice(0, 1000, 1) - assert b.valmap == ValueMap(idx, 4) - - for idx, output in enumerate(outputs): - assert output.shape == [1024, 1000] - assert output.valmap == ValueMap(idx, 4) diff --git a/tests/algorithm/test_memory.py b/tests/algorithm/test_memory.py deleted file mode 100644 index ed1b5176..00000000 --- a/tests/algorithm/test_memory.py +++ /dev/null @@ -1,46 +0,0 @@ -import cube.algorithm.ops.memory as mem -from cube.graph.tensor import IRFullTensor, ValueMap - - -def test_transpose_dim_parallel(): - - M = 512 - N = 1024 - input1 = IRFullTensor(shape=[M, N], name='input1').tosub() - dim0 = 0 - dim1 = 1 - - semantic_op = mem.Transpose( - signature='torch.transpose', inputs=[input1, dim0, dim1], name='transpose' - ) - semantic_op.infer_shape() - print('semantic op:') - print(semantic_op) - - op_dim = mem.TransposeDimParallel(semantic_op) - assert op_dim.dim is None - assert op_dim.chunk_num is None - - # test satisfy - assert op_dim.satisfy(dict(dim=0, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) - assert op_dim.satisfy(dict(dim=1, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-1, chunk_num=4)) - - nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, mem.Transpose) - - for idx, node in enumerate(nodes): - print('=======') - print(node) - print('inputs:') - for input in node.inputs(): - print(input) - assert input.shape == [M // 4, N] - print('outputs:') - for output in node.outputs(): - print(output) - assert output.shape == [N, M // 4] - assert output.valmap == ValueMap(0, 1) diff --git a/tests/algorithm/test_reduce.py b/tests/algorithm/test_reduce.py deleted file mode 100644 index d2304616..00000000 --- a/tests/algorithm/test_reduce.py +++ /dev/null @@ -1,68 +0,0 @@ -import cube.algorithm.ops.reduce as reduce -from cube.graph.tensor import IRFullTensor, ValueMap - - -def test_reduce_dim_parallel(): - - input1 = IRFullTensor(shape=[1024, 1024], name='input1').tosub() - dim = None - - semantic_op = reduce.Sum( - signature='torch.sum', inputs=[input1, dim], name='add' - ) - semantic_op.infer_shape() - print('semantic op:') - print(semantic_op) - - op_dim = reduce.SumDimParallel(semantic_op) - assert op_dim.dim is None - assert op_dim.chunk_num is None - - # test satisfy - assert op_dim.satisfy(dict(dim=0, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) - assert op_dim.satisfy(dict(dim=1, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-1, chunk_num=4)) - - nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) - assert len(nodes) == 4 - for node in nodes: - assert isinstance(node, reduce.Sum) - - for idx, node in enumerate(nodes): - print('=======') - print(node) - print('inputs:') - for input in node.inputs(): - print(input) - assert input.shape == [256, 1024] - print('outputs:') - for output in node.outputs(): - print(output) - assert output.shape == [1] - assert output.valmap == ValueMap(idx, 4) - - - dim = 1 - semantic_op = reduce.Sum( - signature='torch.sum', inputs=[input1, dim], name='add' - ) - semantic_op.infer_shape() - assert op_dim.satisfy(dict(dim=0, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-2, chunk_num=4)) - assert op_dim.satisfy(dict(dim=1, chunk_num=4)) - assert op_dim.satisfy(dict(dim=-1, chunk_num=4)) - - op_dim = reduce.SumDimParallel(semantic_op) - nodes = op_dim.instantiate(semantic_op, dict(dim=0, chunk_num=4)) - for idx, node in enumerate(nodes): - print(node) - print('inputs:') - for input in node.inputs(): - print(input) - assert input.shape == [256, 1024] - print('outputs:') - for output in node.outputs(): - print(output) - assert output.shape == [256] - assert output.valmap == ValueMap(0, 1) diff --git a/tests/codegen/test_codegen.py b/tests/codegen/test_codegen.py deleted file mode 100644 index fc05862f..00000000 --- a/tests/codegen/test_codegen.py +++ /dev/null @@ -1,186 +0,0 @@ -from cube.graph.operator.operator import IRDataOperation, IRFwOperation -from cube.graph.tensor import IRFullTensor -from cube.graph.operator.function import Linear -from cube.graph.graph import IRGraph -from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraphGener -from cube.schedule.translator import IRDataLoader - -from cube.execplan import ExectuionPlan -from cube.execplan.planpass.redundant import RemoveRedundantAdapters -from cube.execplan.planpass.merge import MergeComputeSU - -import torch - -from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen - - -class FakeDataLoader: - def __init__(self, batch_size, num=640): - self.batch_size = batch_size - self.length = num - self.pos = 0 - def __iter__(self): - self.pos = 0 - return self - def __next__(self): - self.pos += 1 - if self.pos == self.length: - raise StopIteration - return torch.randn((self.batch_size, 1024)) - - -def construct_graph(): - - input = IRFullTensor(shape=[64,1024], name='data') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - weight4 = IRFullTensor(shape=[1024, 1024], name='weight') - bias4 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - # linear4 - linear4 = Linear( - name='linear4', - signature='torch.nn.functional.linear', - inputs= [linear3.outputs(0), weight4, bias4], - ) - linear4.infer_shape() - - graph = IRGraph( - nodes=[linear1, linear2, linear3, linear4], - input_tensors=[input], - output_tensors=linear4.outputs(), - module_name="Test" - ) - return graph - - -def test_model_gen(): - - SchedulePool().clear() - - grad_accum = 2 - - graph = construct_graph() - dataloader = IRDataLoader(FakeDataLoader(batch_size=64)) - - data = next(dataloader) - output = graph(data) - output.backward() - - nodes = SchedulePool().nodes() - graph = IRGraph(nodes, None, None, module_name='Test') - - for node in graph.nodes(): - if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): - algo = node.algorithms('data') - graph.partition(node, algo, config=dict(chunk_num=grad_accum)) - - # print(graph) - - sugraph = SUGraphGener.gen_sugraph(graph.nodes()) - - # print(sugraph) - - fb_seqs = list() - for fsu in sugraph.fsus(): - for fb_seq in fb_seqs: - for ksu in fb_seq[::-1]: - if sugraph.happen_before(ksu, fsu): - fb_seq.append(fsu) - break - else: - continue - break - else: - fb_seqs.append([fsu]) - - if len(fb_seqs) != grad_accum: - for idx, fb_seq in enumerate(fb_seqs): - print(f'> sequence {idx}:') - for su in fb_seq: - print(su) - assert len(fb_seqs) == grad_accum - - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - - for fb_seq in fb_seqs: - for idx, su in enumerate(fb_seq): - if idx < 2: - sugraph.assign(su, 0) - sugraph.assign(su.mirror, 0) - else: - sugraph.assign(su, 1) - sugraph.assign(su.mirror, 1) - - for fb_seq in fb_seqs: - fb_seq += [fsu.mirror for fsu in fb_seq][::-1] - - print('========= after asignment: ==========\n', sugraph) - - seqs = list() - for fb_seq in fb_seqs: - seqs += fb_seq - print('> seqs:') - for su in seqs: - print(su) - sugraph.partial_set_order(seqs) - - # print('========= after reorder: ==========\n', sugraph) - - execplan = ExectuionPlan(sugraph) - execplan = RemoveRedundantAdapters.apply(execplan) - - # print('========= after remove adapter: ==========\n', execplan) - - execplan = MergeComputeSU.apply(execplan) - # print('========= after merge small SU: ==========\n', execplan) - - mgener = ModelCodeGen(execplan) - tgener = ScheduleCodeGen(execplan) - - mcode0 = mgener.gen(device = 0) - tcode0 = tgener.gen(device = 0) - print('model code on device 0: ') - print(mcode0) - print('schedule code on device 0: ') - print(tcode0) - - mcode1 = mgener.gen(device = 1) - tcode1 = tgener.gen(device = 1) - print('model code on device 1: ') - print(mcode1) - print('schedule code on device 1: ') - print(tcode1) - - assert False diff --git a/tests/codegen/test_partition_codegen.py b/tests/codegen/test_partition_codegen.py deleted file mode 100644 index 8b55805c..00000000 --- a/tests/codegen/test_partition_codegen.py +++ /dev/null @@ -1,112 +0,0 @@ -from cube.graph.operator.operator import IRDataOperation, IRFwOperation -from cube.graph.tensor import IRFullTensor -from cube.graph.operator.function import Linear -from cube.graph.graph import IRGraph -from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraphGener -from cube.schedule.translator import IRDataLoader - -from cube.execplan import ExectuionPlan -from cube.execplan.planpass.redundant import RemoveRedundantAdapters -from cube.execplan.planpass.merge import MergeComputeSU - -from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen - - -def simple_linear(): - input1 = IRFullTensor(shape=[64,1024], name='data1') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input1, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - return [input1], [linear1, linear2, linear3], [linear3.outputs(0)] - - -def test_linear_col_codegen(): - - SchedulePool().clear() - ngpus = 2 - - inputs, ops, outputs = simple_linear() - linear1, linear2, linear3, = ops - graph = IRGraph(ops, inputs, outputs, 'MLP') - print(graph) - - inputs = [inputs[0].tosub()] - loss = graph(*inputs) - loss.backward() - - nodes = SchedulePool().nodes() - fbgraph = IRGraph(nodes, None, None, 'MLPFull') - print(fbgraph) - - # replace first linear by data parallel - algo = linear1.algorithms('column') - subnodes1 = fbgraph.partition(linear1, algo, config=dict(chunk_num=ngpus)) - - algo = linear2.algorithms('column') - subnodes2 = fbgraph.partition(linear2, algo, config=dict(chunk_num=ngpus)) - - algo = linear3.algorithms('column') - subnodes3 = fbgraph.partition(linear3, algo, config=dict(chunk_num=ngpus)) - - print(fbgraph) - - sugraph = SUGraphGener.gen_sugraph(fbgraph.nodes()) - algosu1 = sugraph.fsus()[:ngpus] - for idx, su in enumerate(algosu1): - sugraph.assign(su, idx) - sugraph.assign(su.mirror, idx) - algosu2 = sugraph.fsus()[ngpus: ngpus * 2] - for idx, su in enumerate(algosu2): - sugraph.assign(su, idx) - sugraph.assign(su.mirror, idx) - algosu3 = sugraph.fsus()[ngpus * 2: ngpus * 3] - for idx, su in enumerate(algosu3): - sugraph.assign(su, idx) - sugraph.assign(su.mirror, idx) - print(sugraph) - - execplan = ExectuionPlan(sugraph) - execplan = RemoveRedundantAdapters.apply(execplan) - - execplan = MergeComputeSU.apply(execplan) - - mgener = ModelCodeGen(execplan) - tgener = ScheduleCodeGen(execplan) - - for devid in range(ngpus): - mcode0 = mgener.gen(device=devid, outfile=f'test{devid}.py') - tcode0 = tgener.gen(device=devid, outfile=f'test{devid}.py', attach=True) - print(f'===> model code on device {devid}: ') - print(mcode0) - print(f'===> schedule code on device {devid}: ') - print(tcode0) - - assert False \ No newline at end of file diff --git a/tests/execplan/test_planpass_merge.py b/tests/execplan/test_planpass_merge.py deleted file mode 100644 index 215c9711..00000000 --- a/tests/execplan/test_planpass_merge.py +++ /dev/null @@ -1,84 +0,0 @@ -from cube.graph.tensor import IRFullTensor -from cube.graph.operator.function import Linear -from cube.graph.graph import IRGraph - -from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraphGener - -from cube.execplan import ExectuionPlan -from cube.execplan.planpass.redundant import RemoveRedundantAdapters -from cube.execplan.planpass.merge import MergeComputeSU - - -def construct_graph(): - - input = IRFullTensor(shape=[64,1024], name='data') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - graph = IRGraph( - nodes=[linear1, linear2, linear3], - input_tensors=[input], - output_tensors=linear3.outputs(), - module_name="Test" - ) - return graph - - -def test_planpass_merge(): - SchedulePool().clear() - - graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data').tosub() - output = graph(data) - output.backward() - - nodes = SchedulePool().nodes() - sugraph = SUGraphGener.gen_sugraph(nodes) - - for su in sugraph.sus(): - if su.stype != SUType.P2P: - sugraph.assign(su, 0) - - print('orignal:') - print(sugraph) - - execplan = ExectuionPlan(sugraph) - execplan = RemoveRedundantAdapters.apply(execplan) - execplan = MergeComputeSU.apply(execplan) - - print('merged:') - for devid in execplan.devices(): - print(f'> device {devid}') - for su in execplan.sequence(devid): - print(su) - assert su.stype != SUType.P2P - assert len(execplan.sequence(0)) == 2 \ No newline at end of file diff --git a/tests/execplan/test_planpass_redundant.py b/tests/execplan/test_planpass_redundant.py deleted file mode 100644 index 30cafb11..00000000 --- a/tests/execplan/test_planpass_redundant.py +++ /dev/null @@ -1,78 +0,0 @@ -from cube.graph.tensor import IRFullTensor -from cube.graph.operator.function import Linear -from cube.graph.graph import IRGraph - -from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraphGener - -from cube.execplan import ExectuionPlan -from cube.execplan.planpass.redundant import RemoveRedundantAdapters - - -def construct_graph(): - - input = IRFullTensor(shape=[64,1024], name='data') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - graph = IRGraph( - nodes=[linear1, linear2, linear3], - input_tensors=[input], - output_tensors=linear3.outputs(), - module_name="Test" - ) - return graph - - -def test_remove_adapter(): - - SchedulePool().clear() - - graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data').tosub() - output = graph(data) - output.backward() - - nodes = SchedulePool().nodes() - sugraph = SUGraphGener.gen_sugraph(nodes) - - for su in sugraph.sus(): - sugraph.assign(su, 0) - - execplan = ExectuionPlan(sugraph) - execplan = RemoveRedundantAdapters.apply(execplan) - - for devid in execplan.devices(): - print(f'> device {devid}') - for su in execplan.sequence(devid): - print(su) - assert su.stype != SUType.P2P - assert len(execplan.sequence(0)) == 6 \ No newline at end of file diff --git a/tests/graph/parser/test_parse_attention.py b/tests/graph/parser/test_parse_attention.py deleted file mode 100644 index c8cdf25c..00000000 --- a/tests/graph/parser/test_parse_attention.py +++ /dev/null @@ -1,115 +0,0 @@ -from torch import nn -import torch -import torch.nn.functional as F - -from cube.graph import parser -import cube - - -class MultiHeadSelfAttention(nn.Module): - - def __init__(self, bs, seq_len, embed_dim, heads, dropout): - super().__init__() - - self.batch_size = bs - self.seq_len = seq_len - self.embed_dim = embed_dim - self.num_heads = heads - self.dim_head = embed_dim // heads - self.scale = self.dim_head ** -0.5 - - self.weight_qkv = torch.nn.Parameter(torch.empty( - 3 * embed_dim, embed_dim - )) - self.weight_out = torch.nn.Parameter(torch.empty( - embed_dim, embed_dim - )) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): #, mask): - """ - x: [L, N, E]: seq_len, batch_size, embedding dimension - output: [L, N, E] - """ - bs = self.batch_size - - # [L, N, (num_heads * dim_head)], - # [(num_heads * dim_head), 3 * (num_heads * dim_head)] - # -> [L, N, (num_heads * dim_head * 3)] - # qkv = F.linear(x, self.weight_qkv, None) - - # # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 - # qkv = qkv.chunk(3, dim=-1) - # q, k, v = qkv - # - # # [L, N, (num_heads * dim_head)] -> [L, (N * num_heads), dim_head] - # q = q.contiguous() - # q = q.view(self.seq_len, (bs * self.num_heads), self.dim_head) - # k = k.contiguous() - # k = k.view(self.seq_len, (bs * self.num_heads), self.dim_head) - # v = v.contiguous() - # v = v.view(self.seq_len, (bs * self.num_heads), self.dim_head) - - # [L, N, E] -> 3 x [L, (N * num_heads), dim_head] - q, k, v = cube.runtime.function.toqkv( - x, self.weight_qkv, self.num_heads - ) - - # [L, (N * num_heads), dim_head] -> [(N * num_heads), L, dim_head] - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - # [(N * num_heads), L, dim_head] -> [(N * num_heads), L, dim_head] - q = q * self.scale - # [(N * num_heads), L, dim_head] * [(N * num_heads), dim_head, L] - # -> [(N * num_heads), L, L] - k = k.transpose(-2, -1) - attn = torch.bmm(q, k) - - # # [(N * num_heads), L, L] -> [N, num_heads, L, L] - # attn = attn.view(bs, self.num_heads, self.seq_len, self.seq_len) - # # [N, num_heads, L, L] -> [N, num_heads, L, L] - # attn = attn.masked_fill_(mask, -100000.0) - # # [N, num_heads, L, L] -> [(N * num_heads), L, L] - # attn = attn.view((bs * self.num_heads), self.seq_len, self.seq_len) - attn = cube.runtime.function.tril_mask(attn, self.num_heads) - - # [(N * num_heads), L, L] -> [(N * num_heads), L, L] - attn = F.softmax(attn, dim=-1) - - # [(N * num_heads), L, L] -> [(N * num_heads), L, L] - attn = self.dropout(attn) - # [(N * num_heads), L, L] * [(N * num_heads), L, dim_head] - # -> [(N * num_heads), L, dim_head] - output = torch.bmm(attn, v) - - # # [(N * num_heads), L, dim_head] -> [L, (N * num_heads), dim_head] - # output = output.transpose(0, 1) - # output = output.contiguous() - # # [L, (N * num_heads), dim_head] -> [L, N, (num_heads * dim_head)] - # output = output.view(self.seq_len, bs, self.embed_dim) - output = cube.runtime.function.attn_view(output, self.num_heads) - - # [L, N, (num_heads * dim_head)] * [(num_heads * dim_head), (num_heads * dim_head)] - # => [L, N, (num_heads * dim_head)] - output = F.linear(output, self.weight_out) - return output - - -def test_parse_attention(): - - L = 64 # seq len - N = 16 # batch - E = 1024 # hiddend size = dim_head * num_head - n_heads = 8 - - model = MultiHeadSelfAttention(N, L, E, n_heads, dropout=0.5) - module = torch.jit.script(model) - print(module.graph) - # print(module.code) - - graph = parser.convert(model, input_shapes=([L, N, E],)) - print(graph) - - assert False diff --git a/tests/graph/parser/test_parse_mlp.py b/tests/graph/parser/test_parse_mlp.py deleted file mode 100644 index e40f6aa4..00000000 --- a/tests/graph/parser/test_parse_mlp.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -from torch import nn - -import cube.graph.parser as parser -from cube.ir.cten import IRTensor -import cube.ir as ir - - -class FeedForward(nn.Module): - def __init__(self, dim, dropout=0., mult=16, classes=1000): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult, bias=False) - self.gelu = nn.GELU() - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim * mult, dim) - self.classifier = nn.Linear(dim, classes) - - def forward(self, data, x: int = 4): - output = self.linear1(data) - output = self.gelu(output) - output = self.dropout(output) - output = output + data - output = self.linear2(output) - output = self.classifier(output) - return output - - -model = FeedForward(dim=1024) - - -def test_parse_module(): - - graph = parser.convert(model, input_shapes=([1024,1024],[1,])) - print(graph) - assert len(graph.nodes()) == 6 - assert len(graph.inputs()) == 2 - assert len(graph.outputs()) == 1 - - node1, node2, node3, node4, node5, node6 = graph.nodes() - assert node1.signature == 'torch.nn.functional.linear' - assert node2.signature == 'torch.nn.functional.gelu' - assert node3.signature == 'torch.nn.functional.dropout' - assert node4.signature == 'torch.add' - assert node5.signature == 'torch.nn.functional.linear' - assert node6.signature == 'torch.nn.functional.linear' - - assert node1.inputs(2) is None - assert isinstance(node5.inputs(2), IRTensor) - - # dependency - assert node2.predecessors() == [node1] - assert node3.predecessors() == [node2] - assert node4.predecessors() == [node3] - assert node5.predecessors() == [node4] - assert node6.predecessors() == [node5] - assert node1.successors() == [node2] - assert node2.successors() == [node3] - assert node3.successors() == [node4] - assert node4.successors() == [node5] - assert node5.successors() == [node6] - - # dtype - for node in graph.nodes(): - assert node._dtype == ir.float32 - for val in node.inputs() + node.outputs(): - if isinstance(val, IRTensor): - val.dtype == ir.float32 - - assert graph.outputs(0).shape == [1024, 1000] - assert False \ No newline at end of file diff --git a/tests/graph/test_function.py b/tests/graph/test_function.py deleted file mode 100644 index 927be6ae..00000000 --- a/tests/graph/test_function.py +++ /dev/null @@ -1,19 +0,0 @@ -from cube.graph.operator.function import Linear -from cube.graph.tensor import IRFullTensor -from cube.algorithm.linear import LinearDataParallel - - -def test_linear_algo(): - - input = IRFullTensor(shape=[1024, 1024], name='input').tosub() - weight = IRFullTensor(shape=[1000, 1024], name='weight').tosub() - bias = IRFullTensor(shape=[1000,], name='bias').tosub() - - semantic_op = Linear( - signature='torch.nn.functional.linear', - inputs = [input, weight, bias], - ) - semantic_op.infer_shape() - - assert len(semantic_op.algorithms()) == 3 - assert isinstance(semantic_op.algorithms('data'), LinearDataParallel) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py deleted file mode 100644 index 66f7c6b8..00000000 --- a/tests/graph/test_graph.py +++ /dev/null @@ -1,184 +0,0 @@ -from cube.graph.graph import IRGraph -from cube.graph.operator.operator import IRBpOperation -from cube.graph.tensor import IRFullTensor, IRSubTensor -from cube.graph.operator.function import Linear -import cube.graph.gpass as gpass -from cube.ir.cten import IRTensor - - -def construct_model(): - - input = IRFullTensor(shape=[64,1024], name='data') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - # return [input], [ops], [output] - return [input], [linear1, linear2, linear3], [linear3.outputs(0)] - - -def test_graph_init(): - - inputs, ops, outputs = construct_model() - graph = IRGraph(ops, inputs, outputs, 'MLP') - print(graph) - - assert len(graph.inputs()) == 1 - assert len(graph.outputs()) == 1 - assert graph.name == 'MLP' - - all_inputs = list() - all_outputs = list() - for node in graph.nodes(): - all_inputs += node.inputs() - all_outputs += node.outputs() - - for input in all_inputs: - if isinstance(input, IRTensor): - assert isinstance(input, IRSubTensor) - for output in all_outputs: - if isinstance(output, IRTensor): - assert isinstance(output, IRSubTensor) - - # check inputs - for full_input, sub_input in zip(inputs, graph.inputs()): - assert full_input.overlap(sub_input) - assert full_input.shape == sub_input.shape - assert sub_input in all_inputs - for full_output, sub_output in zip(outputs, graph.outputs()): - assert full_output.overlap(sub_output) - assert full_output.shape == sub_output.shape - assert sub_output in all_outputs - - # check dependency - node1, node2, node3 = graph.nodes() - assert node2 in node1.successors() - assert node3 in node2.successors() - assert node1 in node2.predecessors() - assert node2 in node3.predecessors() - # one-hop test - assert node1 not in node3.predecessors() - assert node3 not in node1.successors() - # false test - assert node1 not in node2.successors() - assert node3 not in node2.predecessors() - - # weight test - params = graph.parameters() - assert len(params) == 5 - - -def test_graph_nodes(): - inputs, ops, outputs = construct_model() - graph = IRGraph(ops, inputs, outputs, 'MLP') - assert id(graph.nodes()) != id(graph.nodes()) - assert graph.nodes(1) == ops[1] - - -def test_graph_forward(): - inputs, ops, outputs = construct_model() - graph = IRGraph(ops, inputs, outputs, 'MLP') - - fgraph = gpass.forward(graph, *graph.inputs()) - print(fgraph) - - fparam_id = [param._id for param in fgraph.parameters()] - param_id = [param._id for param in graph.parameters()] - assert set(fparam_id) == set(param_id) - - for gnode, fnode in zip(graph.nodes(), fgraph.nodes()): - assert gnode.name == fnode.name - assert gnode.signature == fnode.signature - assert len(gnode.inputs()) == len(fnode.inputs()) - assert len(gnode.outputs()) == len(fnode.outputs()) - assert len(gnode.predecessors()) == len(fnode.predecessors()) - assert len(gnode.successors()) == len(fnode.successors()) - - # test backward - bnodes = [node.mirror for node in fgraph.nodes()][::-1] - bgraph = IRGraph(bnodes, None, None, module_name='backwards') - print(bgraph) - bnode1, bnode2, bnode3 = bnodes - for bnode in bnodes: - assert isinstance(bnode, IRBpOperation) - assert len(bnode.inputs()) == 4 - assert len(bnode.outputs()) == 3 - assert bnode2 in bnode1.successors() - assert bnode3 in bnode2.successors() - assert not bnode3 in bnode1.successors() - assert bnode1 in bnode2.predecessors() - assert bnode2 in bnode3.predecessors() - assert not bnode1 in bnode3.predecessors() - - -def test_graph_multi_forward(): - inputs, ops, outputs = construct_model() - graph = IRGraph(ops, inputs, outputs, 'MLP') - - def _gen_data(graph): - data = list() - for input in graph.inputs(): - data.append(input.parent.like().tosub()) - return data - fgraph1 = gpass.forward(graph, *_gen_data(graph)) - fgraph2 = gpass.forward(graph, *_gen_data(graph)) - print(fgraph1) - print(fgraph2) - assert fgraph1.inputs != fgraph2.inputs() - for node1, node2 in zip(fgraph1.nodes(), fgraph2.nodes()): - assert node1.inputs() != node2.inputs() - - -def test_graph_partition(): - - inputs, ops, outputs = construct_model() - graph = IRGraph(ops, inputs, outputs, 'MLP') - - node1, node2, node3 = graph.nodes() - - algo = node2.algorithms('data') - sub_nodes = graph.partition(node2, algo, config=dict(chunk_num=4)) - assert sub_nodes is not None - assert len(graph.nodes()) == 6 - dnode1, dnode2, dnode3, dnode4 = sub_nodes - assert dnode2 not in dnode1.successors() - assert dnode3 not in dnode1.successors() - assert dnode4 not in dnode1.successors() - - algo = node3.algorithms('column') - sub_nodes = graph.partition(node3, algo, config=dict(chunk_num=4)) - print(graph) - - cnode1, cnode2, cnode3, cnode4 = sub_nodes - for cnode in sub_nodes: - assert dnode1 in cnode.predecessors() - assert dnode2 in cnode.predecessors() - assert dnode3 in cnode.predecessors() - assert dnode4 in cnode.predecessors() - assert len(graph.nodes()) == 9 diff --git a/tests/graph/test_graph_partition.py b/tests/graph/test_graph_partition.py deleted file mode 100644 index ce9d4f10..00000000 --- a/tests/graph/test_graph_partition.py +++ /dev/null @@ -1,157 +0,0 @@ -import enum -from cube.graph.graph import IRGraph -from cube.graph.tensor import IRFullTensor, ValueMap -from cube.graph.operator.function import Linear, ElementWise -from cube.schedule.pool import SchedulePool -from cube.schedule.sugraph import SUGraphGener - - -def simple_linear(): - input1 = IRFullTensor(shape=[64,1024], name='data1') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input1, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - return [input1], [linear1, linear2, linear3], [linear3.outputs(0)] - - -def test_linear_dp_partition(): - - SchedulePool().clear() - - inputs, ops, outputs = simple_linear() - linear1, linear2, linear3, = ops - graph = IRGraph(ops, inputs, outputs, 'MLP') - print(graph) - - inputs = [inputs[0].tosub()] - loss = graph(*inputs) - loss.backward() - - nodes = SchedulePool().nodes() - fbgraph = IRGraph(nodes, None, None, 'MLPFull') - print(fbgraph) - - # replace first linear by data parallel - algo = linear1.algorithms('data') - subnodes = fbgraph.partition(linear1, algo, config=dict(chunk_num=4)) - - algo = linear2.algorithms('data') - subnodes = fbgraph.partition(linear2, algo, config=dict(chunk_num=4)) - - algo = linear3.algorithms('data') - subnodes = fbgraph.partition(linear3, algo, config=dict(chunk_num=4)) - - print(fbgraph) - for node in subnodes: - print(node) - print(node.mirror) - # assert False - -def test_linear_hybrid_partition(): - - SchedulePool().clear() - ngpus = 2 - - inputs, ops, outputs = simple_linear() - linear1, linear2, linear3, = ops - graph = IRGraph(ops, inputs, outputs, 'MLP') - print(graph) - - inputs = [inputs[0].tosub()] - loss = graph(*inputs) - loss.backward() - - nodes = SchedulePool().nodes() - fbgraph = IRGraph(nodes, None, None, 'MLPFull') - print(fbgraph) - - # replace first linear by data parallel - algo = linear1.algorithms('column') - subnodes1 = fbgraph.partition(linear1, algo, config=dict(chunk_num=ngpus)) - - algo = linear2.algorithms('column') - subnodes2 = fbgraph.partition(linear2, algo, config=dict(chunk_num=ngpus)) - - algo = linear3.algorithms('column') - subnodes3 = fbgraph.partition(linear3, algo, config=dict(chunk_num=ngpus)) - - print(fbgraph) - # for node in subnodes: - # print(node) - # print(node.mirror) - - sugraph = SUGraphGener.gen_sugraph(fbgraph.nodes()) - algosu1 = sugraph.fsus()[:ngpus] - for idx, su in enumerate(algosu1): - sugraph.assign(su, idx) - sugraph.assign(su.mirror, idx) - algosu2 = sugraph.fsus()[ngpus: ngpus * 2] - for idx, su in enumerate(algosu2): - sugraph.assign(su, idx) - sugraph.assign(su.mirror, idx) - algosu3 = sugraph.fsus()[ngpus * 2: ngpus * 3] - for idx, su in enumerate(algosu3): - sugraph.assign(su, idx) - sugraph.assign(su.mirror, idx) - print(sugraph) - - print('===== algo 1 =====') - for idx, su in enumerate(algosu1): - print('F:', su) - print('B:', su.mirror) - data_grad = su.mirror.outputs(0) - data_grad_ref = su.inputs(0).get_grad(su.nodes(0)) - print('grad :', data_grad) - print('grad ref:', data_grad_ref) - assert data_grad == data_grad_ref - assert data_grad.valmap == ValueMap(idx, ngpus) - - print('===== algo 2 =====') - for idx, su in enumerate(algosu2): - print('F:', su) - print('B:', su.mirror) - data_grad = su.mirror.outputs(0) - data_grad_ref = su.inputs(0).get_grad(su.nodes(0)) - print('grad :', data_grad) - print('grad ref:', data_grad_ref) - assert data_grad == data_grad_ref - assert data_grad.valmap == ValueMap(idx, ngpus) - - print('===== algo 3 =====') - for idx, su in enumerate(algosu3): - print('F:', su) - print('B:', su.mirror) - data_grad = su.mirror.outputs(0) - data_grad_ref = su.inputs(0).get_grad(su.nodes(0)) - print('grad :', data_grad) - print('grad ref:', data_grad_ref) - assert data_grad == data_grad_ref - assert data_grad.valmap == ValueMap(idx, ngpus) - - assert False diff --git a/tests/graph/test_pas.py b/tests/graph/test_pas.py deleted file mode 100644 index dad056cc..00000000 --- a/tests/graph/test_pas.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch -from torch import nn - -import cube -from cube.graph.adapter.adapter import IRAdapter -import cube.graph.parser as parser -from cube.logics.pool import SchedulePool -from cube.logics.translator import LogicTranslator -from cube.graph.adapter import AdapterGener -from cube.execplan import ExectuionPlan -from cube.execplan.planpass.grouping import Grouping - -from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen - -class MLP(nn.Module): - def __init__(self, dim, mult=4): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult) - self.linear2 = nn.Linear(dim * mult, dim) - self.linear3 = nn.Linear(dim, dim * mult) - self.linear4 = nn.Linear(dim * mult, dim) - - def forward(self, data): - output = self.linear1(data) - output = self.linear2(output) - output = self.linear3(output) - output = self.linear4(output) - loss = torch.sum(output) - return loss - -model = MLP(dim=1024) -dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [128, 1024]) -optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - -def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - - -def test_p(): - SchedulePool().clear() - - graph = parser.convert_model(model, input_shapes=([128, 1024],)) - loader = parser.convert_dataloader(dataloader) - - train_iter(graph, loader) - graph = LogicTranslator.gen_logic_graph() - # print(graph.extra_repr()) - # assert False - - node1, node2, node3, node4 = graph.nodes()[1:5] - for node in graph.nodes(): - graph.assign(node, rank=0) - algo = node2.algorithms('column') - subnodes = graph.partition(node2, algo, config=dict(chunk_num=4)) - for idx, subnode in enumerate(subnodes): - graph.assign(subnode, rank=idx) - - # print(graph.extra_repr()) - - graph = AdapterGener.gen(graph) - # print(graph) - # for node in graph.nodes(): - # if isinstance(node, IRAdapter): - # print(node.extra_repr()) - - execplan = ExectuionPlan(graph) - # print(execplan) - - execplan = Grouping.apply(execplan) - print(execplan) - # print(execplan.graph.extra_repr()) - - mcodegen = ModelCodeGen(execplan) - tcodegen = ScheduleCodeGen(execplan) - - mcode = mcodegen.gen(device=0) - tcode = tcodegen.gen(device=0) - print(mcode) - print(tcode) - - mcode = mcodegen.gen(device=1) - tcode = tcodegen.gen(device=1) - print(mcode) - print(tcode) - - mcode = mcodegen.gen(device=2) - tcode = tcodegen.gen(device=2) - print(mcode) - print(tcode) - mcode = mcodegen.gen(device=3) - tcode = tcodegen.gen(device=3) - print(mcode) - print(tcode) - - assert False diff --git a/tests/graph/test_tensor.py b/tests/graph/test_tensor.py deleted file mode 100644 index 66052ca9..00000000 --- a/tests/graph/test_tensor.py +++ /dev/null @@ -1,231 +0,0 @@ -import copy - -from cube.graph.tensor import IRFullTensor, IRSubTensor, ValueMap - - -def test_full_tensor_init(): - - tensor = IRFullTensor(shape=[1024,1024], name='full_tensor') - assert tensor.shape == [1024, 1024] - assert tensor.name == 'full_tensor' - -def test_full_tensor_constrcut(): - - tensor = IRFullTensor(shape=[1024,1024], name='full_tensor') - ctensor = copy.copy(tensor) - assert isinstance(ctensor, IRFullTensor) - -def test_full_tensor_select(): - - tensor = IRFullTensor(shape=[1024,1024], name='tensor') - assert len(tensor.segments()) == 0 - assert len(tensor.indmap()) == 0 - assert len(tensor.val_maps()) == 0 - - sub_tensor1 = tensor.select( - indmap = (slice(0, 1024), slice(0, 512)), - valmap = None, - shape = (1024, 512) - ) - - sub_tensor2 = tensor.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - - assert sub_tensor1.shape == (1024, 512) - assert sub_tensor1.name == 'tensor' - - assert sub_tensor2.shape == (1024, 512) - assert sub_tensor2.name == 'tensor' - - assert len(tensor.segments()) == 2 - assert len(tensor.indmap()) == 2 - assert len(tensor.val_maps()) == 2 - - -def test_full_tensor_overlap(): - - tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') - sub_tensor1 = tensor1.select( - indmap = (slice(0, 1024), slice(256, 1024)), - valmap = None, - shape = (1024, 768) - ) - - sub_tensor2 = tensor1.select( - indmap = (slice(0, 1024, 2), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - sub_tensor3 = tensor1.select( - indmap = (slice(1, 1024, 2), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - - tensor2 = IRFullTensor(shape=[1024,1024], name='tensor') - - assert tensor1.overlap(sub_tensor1) - assert tensor1.overlap(tensor1) - assert not tensor1.overlap(tensor2) - assert not tensor2.overlap(sub_tensor1) - - assert not sub_tensor2.overlap(sub_tensor3) - - -def test_sub_tensor_select(): - - tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') - sub_tensor1 = tensor1.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - sub_tensor2 = sub_tensor1.select( - indmap = (slice(512, 1024), slice(0, 256)), - valmap = None, - shape = (512, 256) - ) - sub_tensor3 = sub_tensor1.select( - indmap = (slice(512, 1024), slice(256, 512)), - valmap = None, - shape = (512, 256) - ) - - indmap = sub_tensor2.indmap.get() - assert indmap == (slice(512, 1024, 1), slice(512, 768, 1)) - indmap = sub_tensor3.indmap.get() - assert indmap == (slice(512, 1024, 1), slice(768, 1024, 1)) - - assert len(tensor1.segments()) == 3 - assert sub_tensor1 in tensor1.segments() - assert sub_tensor2 in tensor1.segments() - assert sub_tensor3 in tensor1.segments() - - -def test_sub_tensor_ind_overlap(): - - tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') - sub_tensor1 = tensor1.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - sub_tensor2 = sub_tensor1.select( - indmap = (slice(512, 1024), slice(0, 256)), - valmap = None, - shape = (512, 256) - ) - sub_tensor3 = sub_tensor1.select( - indmap = (slice(512, 1024), slice(256, 512)), - valmap = None, - shape = (512, 256) - ) - - assert sub_tensor1.overlap(sub_tensor2) - assert sub_tensor1.overlap(sub_tensor3) - assert not sub_tensor2.overlap(sub_tensor3) - - -def test_sub_tensor_val_overlap(): - tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') - sub_tensor1 = tensor1.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - sub_tensor2 = tensor1.select( - indmap = (slice(0, 1024), slice(0, 512)), - valmap = (0, 4), - shape = (1024, 512) - ) - sub_tensor3 = tensor1.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = (0, 4), - shape = (1024, 512) - ) - sub_tensor4 = tensor1.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = (1, 4), - shape = (1024, 512) - ) - - assert not sub_tensor1.overlap(sub_tensor2) - assert not sub_tensor2.overlap(sub_tensor3) - assert sub_tensor1.overlap(sub_tensor3) - assert sub_tensor1.overlap(sub_tensor4) - assert sub_tensor4.overlap(sub_tensor1) - assert not sub_tensor3.overlap(sub_tensor4) - -def test_sub_tensor_common(): - - tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') - sub_tensor_col1 = tensor1.select( - indmap = (slice(0, 1024), slice(0, 512)), - valmap = None, - shape = (1024, 512) - ) - sub_tensor_col2 = tensor1.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - sub_tensor_row1 = tensor1.select( - indmap = (slice(0, 512), slice(0, 1024)), - valmap = None, - shape = (512, 1024) - ) - sub_tensor_row2 = tensor1.select( - indmap = (slice(512, 1024), slice(0, 1024)), - valmap = None, - shape = (512, 1024) - ) - - lt = sub_tensor_col1.common(sub_tensor_row1) - rt = sub_tensor_col2.common(sub_tensor_row1) - lb = sub_tensor_row2.common(sub_tensor_col1) - rb = sub_tensor_row2.common(sub_tensor_col2) - - assert lt.indmap.get() == (slice(0, 512, 1), slice(0, 512, 1)) - assert rt.indmap.get() == (slice(0, 512, 1), slice(512, 1024, 1)) - assert lb.indmap.get() == (slice(512, 1024, 1), slice(0, 512, 1)) - assert rb.indmap.get() == (slice(512, 1024, 1), slice(512, 1024, 1)) - - -def test_sub_tensor_as_grad(): - tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') - sub_tensor1 = tensor1.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - - sub_tensor1.as_grad() - assert sub_tensor1.is_grad() - - sub_tensor2 = tensor1.select( - indmap = (slice(0, 1024), slice(0, 512)), - valmap = (0, 4), - shape = (1024, 512) - ) - assert sub_tensor2.is_grad() - - -def test_sub_tensor_copy(): - tensor1 = IRFullTensor(shape=[1024,1024], name='tensor') - sub_tensor1 = tensor1.select( - indmap = (slice(0, 1024), slice(512, 1024)), - valmap = None, - shape = (1024, 512) - ) - sub_tensor2 = tensor1.select( - indmap = (slice(0, 1024), slice(0, 512)), - valmap = (0, 4), - shape = (1024, 512) - ) - sub_tensor1.grads = [sub_tensor2] - cpy_tensor = copy.copy(sub_tensor1) - assert cpy_tensor.grads[0] == sub_tensor2 - diff --git a/tests/graph/test_tensor_grad.py b/tests/graph/test_tensor_grad.py deleted file mode 100644 index e1948814..00000000 --- a/tests/graph/test_tensor_grad.py +++ /dev/null @@ -1,153 +0,0 @@ -from cube.graph.graph import IRGraph -from cube.graph.tensor import IRFullTensor, ValueMap -from cube.graph.operator.function import Linear, ElementWise -import cube.graph.gpass as gpass -from cube.ir.cten import IRTensor - - -def construct_model(): - - input1 = IRFullTensor(shape=[64,1024], name='data1') - input2 = IRFullTensor(shape=[64,1024], name='data2') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input1, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - # linear4 - linear4 = Linear( - name='linear4', - signature='torch.nn.functional.linear', - inputs= [input2, weight1, bias1], - ) - linear4.infer_shape() - - # element-wise - add5 = ElementWise( - name='add', - signature='torch.add', - inputs=[linear2.outputs(0), linear3.outputs(0)] - ) - add5.infer_shape() - - # element-wise - add6 = ElementWise( - name='add', - signature='torch.add', - inputs=[add5.outputs(0), linear4.outputs(0)] - ) - add6.infer_shape() - - # return [input], [ops], [output] - return [input1, input2], [linear1, linear2, linear3, linear4, add5, add6], [add6.outputs(0)] - - -def test_tensor_grad(): - - inputs, ops, outputs = construct_model() - linear1, linear2, linear3, linear4, add5, add6 = ops - graph = IRGraph(ops, inputs, outputs, 'MLP') - print(graph) - - all_parent_tids = list() - all_parent_tensors = list() - for op in ops: - for input in op.inputs(): - if isinstance(input, IRTensor): - if input.parent._id not in all_parent_tids: - all_parent_tensors.append(input.parent) - - for pten in all_parent_tensors: - assert pten.grad is None - print(pten.name, pten) - cell_ids = [cell._id for cell in pten.consumers] - print('consumers id:', cell_ids) - print('') - - print('test grad:') - - input = linear1.inputs(0) - assert input.grad is None - gin = input.get_grad(linear1) - assert gin.valmap == ValueMap(0, 1) - print(gin.name, gin) - - weight = linear1.inputs(1) - gw = weight.get_grad(linear1) - assert gw.valmap == ValueMap(0, 2) - print(gw.name, gw) - - weight = linear4.inputs(1) - gw = weight.get_grad(linear4) - assert gw.valmap == ValueMap(1, 2) - print(gw.name, gw) - - out2 = linear2.outputs(0) - gout2 = out2.get_grad(linear2) - print(gout2.name, gout2) - assert gout2.valmap == ValueMap(0, 1) - gout2 = out2.get_grad(linear3) - print(gout2.name, gout2) - assert gout2.valmap == ValueMap(0, 2) - gout2 = out2.get_grad(add5) - print(gout2.name, gout2) - assert gout2.valmap == ValueMap(1, 2) - - out3 = linear3.outputs(0) - gout3 = out3.get_grad(linear3) - print(gout3.name, gout3) - assert gout3.valmap == ValueMap(0, 1) - gout3 = out3.get_grad(add5) - print(gout3.name, gout3) - assert gout3.valmap == ValueMap(0, 1) - - for node in graph.nodes(): - assert node.mirror is None - - print('test forward graph:') - inputs = [inputs[0].tosub(), inputs[1].tosub()] - graph = gpass.forward(graph, *inputs) - print(graph) - for node in graph.nodes()[::-1]: - print(node.mirror) - - gw1 = linear1.mirror.outputs(1) - assert gw1.is_grad() - print(gw1) - gw4 = linear4.mirror.outputs(1) - assert gw4.is_grad() - print(gw4) - - assert gw1.parent == gw4.parent - assert gw1.shape == gw4.shape - assert gw1.indmap == gw4.indmap - assert gw1.valmap != gw4.valmap - - # assert False diff --git a/tests/ir/test_cell.py b/tests/ir/test_cell.py deleted file mode 100644 index 30fb2857..00000000 --- a/tests/ir/test_cell.py +++ /dev/null @@ -1,219 +0,0 @@ -from cube.ir.cten import IRCell, IRTensor - - -def test_cell_init(): - - cell = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - cell2 = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - assert cell2._id != cell._id - - assert len(cell.device) == 0 - assert cell.name == 'cell_test' - assert cell.signature == 'torch.nn.functional.linear' - assert len(cell.inputs()) == 3 - assert len(cell.outputs()) == 1 - assert len(cell.device) == 0 - - -def test_cell_device(): - - cell = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - assert len(cell.device) == 0 - cell.device = 2 - assert len(cell.device) == 1 - assert cell.device[0] == 2 - assert cell.on_device(2) - assert not cell.on_device(3) - - cell.device = [2,3] - assert len(cell.device) == 2 - assert set(cell.device) == set([2, 3]) - assert cell.on_device(2) - assert cell.on_device(3) - assert not cell.on_device(4) - - -def test_cell_inputs(): - - cell = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - assert len(cell.inputs()) == 3 - for input in cell.inputs(): - assert input is None - - # the copy behavior - inputs = cell.inputs() - inputs[2] = 0 - assert cell.inputs(2) is None - - for idx in range(len(cell.inputs())): - assert cell.inputs(idx) is None - tensor = IRTensor(shape=[1024,], name='input') - cell.set_input(idx, tensor) - assert cell.inputs(idx) == tensor - - -def test_cell_outputs(): - - cell = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - assert len(cell.outputs()) == 1 - for output in cell.outputs(): - assert isinstance(output, IRTensor) - - # the copy behavior - outputs = cell.outputs() - outputs[0] = 4 - assert cell.outputs(0) != 4 - - for idx in range(len(cell.outputs())): - output = cell.outputs(idx) - tensor = IRTensor(shape=[1024,], name='output') - cell.set_output(0, tensor) - assert cell.outputs(0) == tensor - assert cell.outputs(0) != output - - -def test_cell_predecessor(): - - cell_prev = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - cell_post = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - assert len(cell_post.predecessors()) == 0 - assert len(cell_prev.predecessors()) == 0 - - cell_post.add_predecessor(1, cell_prev) - assert cell_prev in cell_post.predecessors() - assert len(cell_post.predecessors()) == 1 - assert cell_prev in cell_post.predecessors(1) - - assert len(cell_post.successors()) == 0 - - -def test_cell_successor(): - - cell_prev = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - cell_post = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - assert len(cell_prev.successors()) == 0 - assert len(cell_post.successors()) == 0 - - cell_prev.add_successor(0, cell_post) - assert cell_post in cell_prev.successors() - assert len(cell_prev.successors()) == 1 - assert cell_post in cell_prev.successors() - - assert len(cell_post.predecessors()) == 0 - - -def test_cell_get_inputs_and_outputs(): - - cell1 = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - input1 = IRTensor(shape=[1024, 1024]) - weight1 = IRTensor(shape=[1024, 1024]) - bias1 = IRTensor(shape=[1024,]) - - cell1.set_input(0, input1) - cell1.set_input(1, weight1) - cell1.set_input(2, bias1) - - - cell2 = IRCell( - name='cell_test', - signature='torch.nn.functional.linear', - input_length=3, - output_length=1 - ) - - input2 = IRTensor(shape=[1024, 1024]) - weight2 = IRTensor(shape=[1024, 1024]) - bias2 = IRTensor(shape=[1024,]) - - cell2.set_input(0, input2) - cell2.set_input(1, weight2) - cell2.set_input(2, bias2) - - inputs = IRCell.get_inputs([cell1, cell2]) - assert len(inputs) == 6 - assert input1 in inputs - assert weight1 in inputs - assert bias1 in inputs - assert input2 in inputs - assert weight2 in inputs - assert bias2 in inputs - - outputs = IRCell.get_outputs([cell1, cell2]) - assert len(outputs) == 2 - for output in cell1.outputs() + cell2.outputs(): - assert output in outputs - - # overlapped - cell2.set_input(1, weight1) - cell2.set_input(0, cell1.outputs(0)) - - inputs = IRCell.get_inputs([cell1, cell2]) - assert len(inputs) == 5 - assert input1 in inputs - assert weight1 in inputs - assert bias1 in inputs - assert bias2 in inputs - - outputs = IRCell.get_outputs([cell1, cell2]) - assert len(outputs) == 1 - assert cell2.outputs(0) in outputs - assert cell1.outputs(0) not in outputs diff --git a/tests/ir/test_tensor.py b/tests/ir/test_tensor.py deleted file mode 100644 index 6a172907..00000000 --- a/tests/ir/test_tensor.py +++ /dev/null @@ -1,192 +0,0 @@ -import copy - -from cube.ir.cten import IRTensor, IRCell - - -def test_tensor_init(): - - tensor1 = IRTensor() - tensor2 = IRTensor(shape=[1,2,3]) - tensor3 = IRTensor(shape=[1024], name='tensor') - - assert tensor1._id != tensor2._id - assert tensor2._id != tensor3._id - - assert tensor1.shape is None - assert tensor2.shape == [1,2,3] - assert tensor3.shape == [1024,] - - assert tensor1.name is None - assert tensor2.name is None - assert tensor3.name == 'tensor' - - assert len(tensor1.device) == 0 - assert len(tensor2.device) == 0 - - assert tensor1.requires_grad - assert tensor2.requires_grad - assert tensor3.requires_grad - - -def test_tensor_attach(): - - tensor1 = IRTensor() - tensor2 = IRTensor() - cell = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - - tensor1.attach_cell(cell) - assert cell in tensor1._cell - assert len(tensor1._cell) == 1 - assert len(tensor2._cell) == 0 - - tensor1.detach_cell(cell) - assert cell not in tensor1._cell - assert len(tensor1._cell) == 0 - - cell.set_input(0, tensor1) - cell.set_output(0, tensor1) - assert len(tensor1._cell) == 0 - assert len(cell.inputs(0)._cell) == 1 - - -def test_tensor_renew(): - - tensor1 = IRTensor(shape=[1024], name='renew_tensor') - cell = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - cell.set_input(0, tensor1) - tensor1 = cell.inputs(0) - - tensor2 = tensor1.renew() - assert tensor2.shape == tensor1.shape - assert tensor2.name == tensor1.name - assert tensor2 not in cell.inputs() - assert len(tensor2._cell) == 0 - assert tensor2.requires_grad == tensor1.requires_grad - - -def test_tensor_copy(): - - tensor1 = IRTensor(shape=[1024], name='renew_tensor') - cell = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - tensor1 = cell.set_input(0, tensor1) - - tensor2 = copy.copy(tensor1) - assert tensor2 == tensor1 - assert len(tensor2._cell) == 0 - - -def test_tensor_device(): - - tensor1 = IRTensor(shape=[1024], name='renew_tensor') - cell1 = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - cell2 = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - tensor1 = cell1.set_input(0, tensor1) - tensor2 = cell2.set_input(0, tensor1) - - assert tensor1 == tensor2 - - assert len(tensor1.device) == 0 - assert len(tensor2.device) == 0 - - cell1.device = 2 - assert tensor1.device == [2] - assert len(tensor2.device) == 0 - - cell2.device = 3 - assert tensor1.device == [2] - assert tensor2.device == [3] - - -def test_tensor_dst(): - tensor1 = IRTensor(shape=[1024], name='renew_tensor') - cell1 = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - cell2 = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - - cell1.set_input(0, tensor1) - cells = tensor1.dst([cell1, cell2]) - assert set(cells) == set([cell1]) - - cell2.set_input(0, tensor1) - cells = tensor1.dst([cell1, cell2]) - assert set(cells) == set([cell1, cell2]) - - -def test_tensor_src(): - tensor1 = IRTensor(shape=[1024], name='renew_tensor') - cell1 = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - cell2 = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - - cell1.set_output(0, tensor1) - cells = tensor1.src([cell1, cell2]) - assert set(cells) == set([cell1]) - - cell2.set_output(0, tensor1) - cells = tensor1.src([cell1, cell2]) - assert set(cells) == set([cell1, cell2]) - - -def test_tensor_is_leaf(): - tensor1 = IRTensor(shape=[1024], name='renew_tensor') - cell1 = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - cell2 = IRCell( - name='cell', - signature='any', - input_length=3, - output_length=1 - ) - cell1.set_input(0, tensor1) - assert tensor1.is_leaf([cell1]) - - cell2.set_input(0, cell1.outputs(0)) - assert cell2.outputs(0).is_leaf([cell1]) - assert not cell2.outputs(0).is_leaf([cell1, cell2]) diff --git a/tests/runtime/rollsplit.py b/tests/runtime/rollsplit.py deleted file mode 100644 index 21dc065a..00000000 --- a/tests/runtime/rollsplit.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -CUDA_VISIBLE_DEVICES=4,5,6,7 -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - tests/runtime/rollsplit.py - -""" - - -import torch - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - - -def test_roll_parallel(): - - # input size - group = None - world_size = torch.distributed.get_world_size(group=group) - input_size = [1, 224 // world_size, 224, 256] - # input_size = torch.arange(0, ) - input = torch.randn(input_size).cuda() * 10 - - CudaTimer().warmup(seconds=2) - - torch.distributed.barrier() - CudaTimer().start(field_name='roll_halo') - for _ in range(1000): - roll_out = cube.runtime.function.roll_dim_parallel( - input, (9 // 2), 1, list(range(world_size)), group - ) - CudaTimer().stop(field_name='roll_halo') - ref1 = roll_out - # print_each_rank(ref1, rank_only=0) - assert roll_out.shape == input.shape - span = CudaTimer().duration(times=1000, field_name='roll_halo') - print_each_rank('span on halo exchange: {:.2f} ms'.format(span)) - - - torch.distributed.barrier() - CudaTimer().start(field_name='roll_allgather') - for _ in range(1000): - roll_out = cube.runtime.function.roll_dim_allgather( - input, (9 // 2), 1, group - ) - CudaTimer().stop(field_name='roll_allgather') - ref2 = roll_out - # print_each_rank(ref2, rank_only=0) - span = CudaTimer().duration(times=1000, field_name='roll_allgather') - print_each_rank('span on allgather exchange: {:.2f} ms'.format(span)) - - if not torch.allclose(ref1, ref2, atol=1e-3, rtol=1e-3): - print('correctness test failed') - else: - print('correctness test passed') - - -def test_roll_grid_parallel(): - # input size - group = None - world_size = torch.distributed.get_world_size(group=group) - myrank = torch.distributed.get_rank() - assert world_size == 4 - # input_size = [1, 224 // 2, 224 // 2, 256] - # input = torch.randn(input_size).cuda() * 10 - input = torch.arange(myrank * 4, (myrank + 1) * 4).view(1, 2, 2, 1).float() - input = input.cuda() - print_each_rank(f'input: {input.view(-1)}') - - CudaTimer().warmup(seconds=2) - - torch.distributed.barrier() - CudaTimer().start(field_name='roll_halo') - for _ in range(1): - roll_out = cube.runtime.function.roll_grid_parallel( - input, (9 // 2, 9 // 2), (1, 2), 2, 2, group - ) - CudaTimer().stop(field_name='roll_halo') - ref1 = roll_out - # print_each_rank(ref1, rank_only=0) - assert roll_out.shape == input.shape - span = CudaTimer().duration(times=1000, field_name='roll_halo') - print_each_rank('span on halo exchange: {:.2f} ms'.format(span)) - - torch.distributed.barrier() - CudaTimer().start(field_name='roll_allgather') - for _ in range(1): - roll_out = cube.runtime.function.roll_dim_allgather( - input, (9 // 2), 1, group - ) - roll_out = cube.runtime.function.roll_dim_allgather( - roll_out, (9 // 2), 2, group - ) - CudaTimer().stop(field_name='roll_allgather') - ref2 = roll_out - # print_each_rank(ref2, rank_only=0) - span = CudaTimer().duration(times=1000, field_name='roll_allgather') - print_each_rank('span on allgather exchange: {:.2f} ms'.format(span)) - - -def test_roll_parallel_autograd(): - - group = None - world_size = torch.distributed.get_world_size(group=group) - input_size = [1, 224 // world_size, 224, 256] - # input_size = torch.arange(0, ) - input = torch.randn(input_size).cuda() * 10 - input = input.requires_grad_() - - out = cube.runtime.function.roll_dim_parallel( - input, (9 // 2), 1, group - ) - loss = torch.sum(out) - loss.backward() - print(loss) - print(input.grad) - - -def test_grid_partition(): - - group = None - world_size = torch.distributed.get_world_size(group=group) - assert world_size == 4 - input_size = [1, 56, 56, 256] - input = torch.randn(input_size).cuda() * 10 - input = input.requires_grad_() - out = cube.runtime.function.grid_partition(input, 2, 2, group = None) - print(out.shape) - assert out.shape == torch.Size([1, 56 // 2, 56 // 2, 256]) - loss = torch.sum(out) - loss.backward() - # print(input.grad) - - -def test_grid_collection(): - - group = None - world_size = torch.distributed.get_world_size(group=group) - assert world_size == 4 - input_size = [1, 56 // 2, 56 // 2, 256] - input = torch.randn(input_size).cuda() * 10 - input = input.requires_grad_() - out = cube.runtime.function.grid_collection(input, 2, 2, group = None) - assert out.shape == torch.Size([1, 56, 56, 256]) - loss = torch.sum(out) - loss.backward() - # print(input.grad) - - -if __name__ == '__main__': - - cube.init() - test_roll_parallel() - # test_roll_grid_parallel() - # test_roll_parallel_autograd() - # test_grid_partition() - # test_grid_collection() \ No newline at end of file diff --git a/tests/runtime/test_group.py b/tests/runtime/test_group.py deleted file mode 100644 index 21bd37d4..00000000 --- a/tests/runtime/test_group.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -Test this with: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=6000 \ - --use_env \ - tests/test_group.py -""" - -from cube.runtime.device import DeviceGroup - -import torch - - -def test_sub_group(): - - group = DeviceGroup() - myrank = group.rank - sub_group_1 = group.get_group([0,2]) - if myrank in [0,2]: - assert torch.distributed.get_rank(sub_group_1) in [0,1] - else: - assert torch.distributed.get_rank(sub_group_1) == -1 - - sub_group_2 = group.get_group([1,3]) - if myrank in [1,3]: - assert torch.distributed.get_rank(sub_group_2) in [0,1] - else: - assert torch.distributed.get_rank(sub_group_2) == -1 - # print(group) - - -if __name__ == '__main__': - - # init distributed - group = DeviceGroup() - - test_sub_group() diff --git a/tests/runtime/test_nccl.py b/tests/runtime/test_nccl.py deleted file mode 100644 index 7184743b..00000000 --- a/tests/runtime/test_nccl.py +++ /dev/null @@ -1,90 +0,0 @@ - -""" -Single node usage: -e.g., 4 GPUs - -OMP_NUM_THREADS=4 torchrun --standalone \ - --nproc_per_node=8 \ - --nnodes=1 \ - tests/runtime/test_nccl.py - -""" - -import torch -import time -import sys -import os -import argparse - - -def print_each_rank(msg, select=True, outfile=''): - myrank = torch.distributed.get_rank() - outfile = sys.stdout if outfile == '' else outfile - for rank in range(torch.distributed.get_world_size()): - if select: - if myrank == rank: - f = open(outfile, 'a') if outfile != sys.stdout else sys.stdout - f.write('rank [{}]: {}\n'.format(rank, msg)) - if outfile != sys.stdout: - f.close() - torch.distributed.barrier() - - -def test_nccl(size, local_rank): - msg = torch.ones((size,)).cuda() - # warm up - for _ in range(20): - out = torch.distributed.all_reduce(msg) - torch.cuda.synchronize() - # profile - tic = time.perf_counter() - for _ in range(100): - out = torch.distributed.all_reduce(msg) - torch.cuda.synchronize() - toc = time.perf_counter() - - span = (toc - tic) * 1000 / 100 # in ms - bandwidth = size / span / 1e6 # in GB/s - print_each_rank( - 'NCCL Allreduce | Msg Size: {:.0f} MB | Algo Bandwidth: {:.2f} GB/s'.format( - size / 1024 / 1024, bandwidth), - select=(local_rank==0), - ) - -def test_allgather(size, local_rank): - msg = torch.ones((size,)).cuda() - tensor_list = [torch.empty_like(msg) for _ in range(torch.distributed.get_world_size())] - - tic = time.perf_counter() - for _ in range(100): - out = torch.distributed.all_gather(tensor_list, msg) - torch.cuda.synchronize() - print_each_rank('Passed all-gather') - toc = time.perf_counter() - - -def benchmark(args, local_rank): - size = args.begin - while size <= args.end: - # test_allgather(size * 1024 * 1024, local_rank) - test_nccl(size * 1024 * 1024, local_rank) # MB to B - size *= 2 - print_each_rank('test on nccl is done') - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--begin', type=int, default=4, - help='start message size in MB') - parser.add_argument('--end', type=int, default=64, - help='end message size in MB') - args = parser.parse_args() - - print('> initializing distributed environ...') - torch.distributed.init_process_group(backend='nccl') - local_rank = int(os.environ.get('LOCAL_RANK')) - print_each_rank('local rank-{} launches'.format(local_rank)) - - torch.cuda.set_device(local_rank) - benchmark(args, local_rank) \ No newline at end of file diff --git a/tests/schedule/test_adapter_transform.py b/tests/schedule/test_adapter_transform.py deleted file mode 100644 index d0f97b44..00000000 --- a/tests/schedule/test_adapter_transform.py +++ /dev/null @@ -1,196 +0,0 @@ -from cube.schedule.adapter.transform import IRTransformType -from cube.schedule.adapter.transform import IRTensorTransform - -from cube.graph.tensor import IRFullTensor - - -def test_tensor_transform_select(): - - tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() - - tensor2 = tensor1.select( - indmap = (slice(0, 512), slice(0, 1024)), - valmap = (0, 1), - shape = [512, 1024] - ) - - tensor3 = tensor1.select( - indmap = (slice(512, 1024), slice(0, 1024)), - valmap = (0, 2), - shape = [512, 1024] - ) - - tensor4 = tensor3.select( - indmap = (slice(0, 256), slice(0, 512)), - valmap = (0, 1), - shape = [256, 512] - ) - - tensor5 = tensor3.select( - indmap = (slice(256, 512), slice(0, 512)), - valmap = (0, 1), - shape = [256, 512] - ) - - select1 = IRTensorTransform( - src_tensors=[tensor1], - dst_tensors=[tensor2, tensor3] - ) - assert len(select1.inputs()) == 1 - assert len(select1.outputs()) == 2 - assert select1.ttype == IRTransformType.Select - - print('> select1:', select1) - for prim in select1.trace(): - print(prim) - - select2 = IRTensorTransform( - src_tensors=[tensor3], - dst_tensors=[tensor4, tensor5] - ) - print('> select2:', select2) - for prim in select2.trace(): - print(prim) - assert False - - -def test_tensor_transform_merge(): - tensor0 = IRFullTensor(shape=[1024,1024], name='test1').tosub() - - tensor1 = tensor0.select( - indmap = (slice(0, 512), slice(0, 512)), - valmap = None, - shape = [256, 1024] - ) - - tensor2 = tensor0.select( - indmap = (slice(0, 512), slice(512, 1024)), - valmap = None, - shape = [256, 1024] - ) - - tensor3 = tensor0.select( - indmap = (slice(512, 1024), slice(0, 512)), - valmap = None, - shape = [256, 512] - ) - - tensor4 = tensor0.select( - indmap = (slice(512, 1024), slice(512, 1024)), - valmap = None, - shape = [256, 512] - ) - - tensor5 = tensor0.select( - indmap = (slice(512, 1024), slice(0, 1024)), - valmap = None, - shape = [256, 512] - ) - - merge1 = IRTensorTransform( - src_tensors=[tensor1, tensor2, tensor3, tensor4], - dst_tensors=[tensor0] - ) - assert len(merge1.inputs()) == 4 - assert len(merge1.outputs()) == 1 - assert merge1.ttype == IRTransformType.Merge - - print('> merge1:') - for prim in merge1.trace(): - print(prim) - assert merge1.trace()[-1].output == tensor0 - assert merge1.trace()[-1].output._id == tensor0._id - - merge2 = IRTensorTransform( - src_tensors=[tensor3, tensor4], - dst_tensors=[tensor5] - ) - print('> merge2:') - for prim in merge2.trace(): - print(prim) - assert merge2.trace()[-1].output == tensor5 - assert merge2.trace()[-1].output._id == tensor5._id - # assert False - - tensor6 = tensor0.select( - indmap = (slice(0, 256), slice(0, 1024)), - valmap = (0, 4), - shape = [256, 1024] - ) - tensor7 = tensor0.select( - indmap = (slice(0, 256), slice(0, 1024)), - valmap = (1, 4), - shape = [256, 1024] - ) - tensor8 = tensor0.select( - indmap = (slice(0, 256), slice(0, 1024)), - valmap = (2, 4), - shape = [256, 1024] - ) - tensor9 = tensor0.select( - indmap = (slice(0, 256), slice(0, 1024)), - valmap = (3, 4), - shape = [256, 1024] - ) - - tensor10 = tensor0.select( - indmap = (slice(0, 256), slice(0, 1024)), - valmap = (0, 1) - ) - - merge3 = IRTensorTransform( - src_tensors=[tensor6, tensor7, tensor8, tensor9], - dst_tensors=[tensor10] - ) - print('> merge3:') - for prim in merge3.trace(): - print(prim) - assert merge3.trace()[-1].output._id == tensor10._id - # assert False - - -def test_transform_identity(): - - tensor1 = IRFullTensor(shape=[1024,1024], name='test1').tosub() - - tensor2 = tensor1.select( - indmap = (slice(512, 1024), slice(0, 1024)), - valmap = None, - shape = [512, 1024] - ) - - tensor3 = tensor2.select( - indmap = (slice(0, 256), slice(0, 1024)), - valmap = None, - shape = [256, 1024] - ) - - tensor4 = tensor1.select( - indmap = (slice(512, 768), slice(0, 1024)), - valmap = None, - shape = [256, 1024] - ) - - tensor5 = tensor1.select( - indmap = (slice(512, 768), slice(0, 1024)), - valmap = None, - shape = [256, 1024] - ) - - select1 = IRTensorTransform( - src_tensors=[tensor2], - dst_tensors=[tensor4, tensor5] - ) - assert not select1.is_identity() - - select2 = IRTensorTransform( - src_tensors=[tensor3], - dst_tensors=[tensor4, tensor5] - ) - assert select2.is_identity() - - merge1 = IRTensorTransform( - src_tensors=[tensor4], - dst_tensors=[tensor5] - ) - assert merge1.is_identity() diff --git a/tests/schedule/test_pool.py b/tests/schedule/test_pool.py deleted file mode 100644 index d98813bc..00000000 --- a/tests/schedule/test_pool.py +++ /dev/null @@ -1,22 +0,0 @@ -from cube.schedule.pool import SchedulePool -from cube.schedule.su import SUType, ScheduleUnit - -from cube.ir.cten import IRCell - - -def test_schedule_pool(): - - SchedulePool().clear() - assert len(SchedulePool()._nodes) == 0 - assert len(SchedulePool().nodes()) == 0 - - cell = IRCell( - name='test', signature='test', input_length=4, output_length=2 - ) - SchedulePool().add_node(cell) - - assert len(SchedulePool()._nodes) == 1 - assert len(SchedulePool().nodes()) == 1 - - for record_node in SchedulePool().nodes(): - assert record_node == cell diff --git a/tests/schedule/test_su.py b/tests/schedule/test_su.py deleted file mode 100644 index d255c139..00000000 --- a/tests/schedule/test_su.py +++ /dev/null @@ -1,99 +0,0 @@ -import copy - -from cube.graph.tensor import IRFullTensor -from cube.graph.operator.function import Linear -from cube.graph.graph import IRGraph -from cube.ir.cten import IRCell - -from cube.schedule.su import SUType, ScheduleUnit - - -def construct_model(): - - input = IRFullTensor(shape=[64,1024], name='data') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - # return [input], [ops], [output] - return [input], [linear1, linear2, linear3], [linear3.outputs(0)] - - -def test_su_init(): - - inputs, nodes, outputs = construct_model() - graph = IRGraph(nodes, inputs, outputs, 'Test') - linear1, linear2, linear3 = nodes - - su1 = ScheduleUnit([linear1], stype=SUType.Forward) - assert len(su1.inputs()) == 3 - assert len(su1.outputs()) == 1 - assert su1.signature == SUType.Forward.value - - assert su1.mirror is None - assert su1.stype == SUType.Forward - assert su1._nodes == [linear1] - assert len(su1._send_in_adapters) == 3 - assert len(su1._recv_in_adapters) == 3 - assert len(su1._send_out_adapters) == 1 - assert len(su1._recv_out_adapters) == 1 - assert len(su1._ctrl_predecessors) == 0 - assert len(su1._ctrl_successors) == 0 - - su2 = ScheduleUnit([linear1, linear2], stype=SUType.Forward) - print('su2:', su2) - assert len(su2.inputs()) == 4 - assert len(su2.outputs()) == 1 - assert su2.signature == SUType.Forward.value - - su3 = ScheduleUnit([linear1, linear2, linear3], stype=SUType.Forward) - print('su3:', su3) - assert len(su3.inputs()) == 6 - assert len(su3.outputs()) == 1 - assert su3.signature == SUType.Forward.value - - -def test_su_copy(): - - inputs, nodes, outputs = construct_model() - graph = IRGraph(nodes, inputs, outputs, 'Test') - linear1, linear2, linear3 = nodes - - su1 = ScheduleUnit([linear1, linear2], stype=SUType.Forward) - su2 = ScheduleUnit([linear1, linear2, linear3], stype=SUType.Forward) - IRCell.make_pair(su1, su2) - - csu = copy.copy(su1) - assert csu.inputs() == su1.inputs() - assert csu.outputs() == su1.outputs() - - assert csu.mirror is not None - mirror = csu.mirror - assert mirror.inputs() == su2.inputs() - assert mirror.outputs() == su2.outputs() diff --git a/tests/schedule/test_sugraph.py b/tests/schedule/test_sugraph.py deleted file mode 100644 index 7b72a3e2..00000000 --- a/tests/schedule/test_sugraph.py +++ /dev/null @@ -1,249 +0,0 @@ -from cube.graph.tensor import IRFullTensor -from cube.graph.operator.function import Linear -from cube.graph.graph import IRGraph -from cube.ir.cten import IRCell - -from cube.schedule.su import SUType, ScheduleUnit -from cube.schedule.sugraph import SUGraph -from cube.schedule.adapter.comm import IRCommunication - - -def construct_graph(): - - input = IRFullTensor(shape=[64,1024], name='data') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - graph = IRGraph( - nodes=[linear1, linear2, linear3], - input_tensors=[input], - output_tensors=linear3.outputs(), - module_name="Test" - ) - return graph - - -def test_graph_init(): - - graph = construct_graph() - sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] - - sugraph = SUGraph(sus) - print(sugraph) - assert len(sugraph.inputs()) == 1 - assert len(sugraph.outputs()) == 1 - assert graph.inputs() == sugraph.inputs() - assert graph.outputs() == sugraph.outputs() - - assert sugraph.sequence == sus - - # test dependency - su1, su2, su3 = sus - assert su2 in su1.successors() - assert su3 in su2.successors() - assert su3 not in su1.successors() - assert su1 in su2.predecessors() - assert su1 in su2.predecessors(0) - assert su2 in su3.predecessors() - assert su1 not in su3.predecessors() - - -def test_sugraph_happen_before(): - - graph = construct_graph() - sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] - - sugraph = SUGraph(sus) - su1, su2, su3 = sugraph.sus() - - assert sugraph.happen_before(su1, su2) - assert not sugraph.happen_before(su2, su1) - assert sugraph.happen_before(su1, su3) - assert not sugraph.happen_before(su3, su1) - assert sugraph.happen_before(su2, su3) - assert not sugraph.happen_before(su3, su2) - - -def test_sugraph_merge(): - - graph = construct_graph() - sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] - - sugraph = SUGraph(sus) - su1, su2, su3 = sugraph.fsus() - - assert sugraph.merge(su1, su3) is None - - print('origin: ') - print(sugraph) - su12 = sugraph.merge(su1, su2) - print('merged: ') - print(sugraph) - assert sugraph.nnodes == 2 - assert len(su12.inputs()) == 4 - assert len(su12.outputs()) == 1 - assert len(su12.nodes()) == 2 - assert su12 in sugraph.sus() - assert su1 not in sugraph.sus() - assert su2 not in sugraph.sus() - assert sugraph.happen_before(su12, su3) - - -def test_sugraph_add_flow(): - - graph = construct_graph() - sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] - - sugraph = SUGraph(sus) - su1, su2, su3 = sugraph.sus() - - assert su1 not in su3.predecessors() - assert su3 not in su1.successors() - - assert not sugraph.add_flow(su3, su1) - - assert sugraph.add_flow(su1, su3) - assert su1 in su3.predecessors() - assert su3 in su1.successors() - - -def test_sugraph_assign1(): - - graph = construct_graph() - sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] - - su1, su2, su3 = sus - - # adapter between su1-su2 - send_op = IRCommunication( - send_tensors=[su1.outputs(0)], - send_ranks = [-1] - ) - recv_op = IRCommunication( - recv_tensors=[su1.outputs(0)], - recv_ranks = [-1] - ) - IRCell.make_pair(send_op, recv_op) - send_su12 = ScheduleUnit([send_op], SUType.P2P, name='send') - recv_su12 = ScheduleUnit([recv_op], SUType.P2P, name='recv') - su1._add_out_adapter(0, send_su12, recv_su12) - su2._add_in_adapter(0, send_su12, recv_su12) - - # adapter between su2-su3 - send_op = IRCommunication( - send_tensors=[su1.outputs(0)], - send_ranks = [-1] - ) - recv_op = IRCommunication( - recv_tensors=[su1.outputs(0)], - recv_ranks = [-1] - ) - IRCell.make_pair(send_op, recv_op) - send_su23 = ScheduleUnit([send_op], SUType.P2P, name='send') - recv_su23 = ScheduleUnit([recv_op], SUType.P2P, name='recv') - su2._add_out_adapter(0, send_su23, recv_su23) - su3._add_in_adapter(0, send_su23, recv_su23) - - sugraph = SUGraph( - [su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3] - ) - - assert sugraph.assign(su1, 0) - assert su1.device == [0] - assert send_su12.device == [0] - assert send_su12.nodes(0).send_ranks == [-1] - assert recv_su12.device == [] - assert recv_su12.nodes(0).recv_ranks == [0] - - assert sugraph.assign(su2, 1) - assert su1.device == [0] - assert send_su12.device == [0] - assert send_su12.nodes(0).send_ranks == [1] - assert recv_su12.device == [1] - assert recv_su12.nodes(0).recv_ranks == [0] - - assert sugraph.assign(su3, 1) - assert su3.device == [1] - assert send_su23.device == [1] - assert send_su23.nodes(0).send_ranks == [1] - assert recv_su23.device == [1] - assert recv_su23.nodes(0).recv_ranks == [1] - - assert not sugraph.assign(send_su12, 3) - - -def test_sugraph_assign2(): - - graph = construct_graph() - sus = [ScheduleUnit([node], SUType.Forward) for node in graph.nodes()] - - su1, su2, su3 = sus - - # adapter between su1-su2 - send_op = IRCommunication( - send_tensors=[su1.outputs(0)], - send_ranks = [-1] - ) - recv_op = IRCommunication( - recv_tensors=[su1.outputs(0)], - recv_ranks = [-1] - ) - IRCell.make_pair(send_op, recv_op) - send_su12 = ScheduleUnit([send_op], SUType.P2P, name='send') - recv_su12 = ScheduleUnit([recv_op], SUType.P2P, name='recv') - su1._add_out_adapter(0, send_su12, recv_su12) - su2._add_in_adapter(0, send_su12, recv_su12) - - # adapter between su2-su3 - send_op = IRCommunication( - send_tensors=[su1.outputs(0)], - send_ranks = [-1] - ) - recv_op = IRCommunication( - recv_tensors=[su1.outputs(0)], - recv_ranks = [-1] - ) - IRCell.make_pair(send_op, recv_op) - send_su23 = ScheduleUnit([send_op], SUType.P2P, name='send') - recv_su23 = ScheduleUnit([recv_op], SUType.P2P, name='recv') - su2._add_out_adapter(0, send_su23, recv_su23) - su3._add_in_adapter(0, send_su23, recv_su23) - - sugraph = SUGraph( - [su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3] - ) - - assert not sugraph.set_order( - [su2, send_su12, recv_su12, su1, send_su23, recv_su23, su3] - ) - - assert sugraph.set_order( - [su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3] - ) diff --git a/tests/schedule/test_translator.py b/tests/schedule/test_translator.py deleted file mode 100644 index abf8f84c..00000000 --- a/tests/schedule/test_translator.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation - -from cube.schedule.translator import LogicTranslator -from cube.schedule.translator import IRDataLoader -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph, SUGraphGener -from cube.schedule.pool import SchedulePool - -from cube.graph.tensor import IRFullTensor, IRSubTensor -from cube.graph.operator.function import Linear -from cube.graph.graph import IRGraph - - -class FakeDataLoader: - def __init__(self, batch_size, num=640): - self.batch_size = batch_size - self.length = num - self.pos = 0 - def __iter__(self): - self.pos = 0 - return self - def __next__(self): - self.pos += 1 - if self.pos == self.length: - raise StopIteration - return torch.randn((self.batch_size, 1024)) - - -def construct_graph(): - - input = IRFullTensor(shape=[64,1024], name='data') - weight1 = IRFullTensor(shape=[1024, 1024], name='weight') - bias1 = IRFullTensor(shape=[1024, 1024], name='bias') - weight2 = IRFullTensor(shape=[1024, 1024], name='weight') - weight3 = IRFullTensor(shape=[1024, 1024], name='weight') - bias3 = IRFullTensor(shape=[1024, 1024], name='bias') - - # linear1 - linear1 = Linear( - name='linear1', - signature='torch.nn.functional.linear', - inputs= [input, weight1, bias1], - ) - linear1.infer_shape() - - # linear2 - linear2 = Linear( - name='linear2', - signature='torch.nn.functional.linear', - inputs= [linear1.outputs(0), weight2, None], - ) - linear2.infer_shape() - - # linear3 - linear3 = Linear( - name='linear3', - signature='torch.nn.functional.linear', - inputs= [linear2.outputs(0), weight3, bias3], - ) - linear3.infer_shape() - - graph = IRGraph( - nodes=[linear1, linear2, linear3], - input_tensors=[input], - output_tensors=linear3.outputs(), - module_name="Test" - ) - return graph - - -def test_load_dataloader(): - - SchedulePool().clear() - dataloader = IRDataLoader(FakeDataLoader(batch_size=64)) - - data1 = next(dataloader) - assert isinstance(data1, IRSubTensor) - assert data1.shape == [64, 1024] - - data2 = next(dataloader) - assert len(SchedulePool().nodes()) == 2 - assert all([isinstance(node, IRDataOperation) for node in SchedulePool().nodes()]) - - data3 = LogicTranslator.load_data(dataloader) - assert isinstance(data1, IRSubTensor) - assert data1.shape == [64, 1024] - assert len(SchedulePool().nodes()) == 3 - assert all([isinstance(node, IRDataOperation) for node in SchedulePool().nodes()]) - - -def test_translator_forward(): - SchedulePool().clear() - - graph = construct_graph() - print(graph) - data = IRFullTensor(shape=[64,1024], name='data').tosub() - output = graph(data) - - assert isinstance(output, IRSubTensor) - assert output.shape == [64, 1024] - - nodes = SchedulePool().nodes() - assert len(nodes) == 3 - assert isinstance(SchedulePool().get_tape(output), list) - for node in nodes: - assert isinstance(node, IRFwOperation) - assert isinstance(node.mirror, IRBpOperation) - - -def test_translator_backward(): - SchedulePool().clear() - - graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data').tosub() - loss = graph(data) - loss.backward() - - nodes = SchedulePool().nodes() - for node in nodes: - print(node) - assert len(nodes) == 6 - fnodes = nodes[0:3] - bnodes = nodes[3:] - assert loss in bnodes[0].inputs() - for fnode, bnode in zip(fnodes, bnodes[::-1]): - assert fnode.mirror == bnode - assert bnode.mirror == fnode - - -def test_sugraph_gener_gen(): - SchedulePool().clear() - - graph = construct_graph() - data = IRFullTensor(shape=[64,1024], name='data').tosub() - output = graph(data) - - # forward adatpers - nodes = SchedulePool().nodes() - sugraph = SUGraphGener.gen_sugraph(nodes) - assert len(sugraph.sus()) == 7 - su1, send_su12, recv_su12, su2, send_su23, recv_su23, su3 = sugraph.sus() - assert su1.stype == SUType.Forward - assert su2.stype == SUType.Forward - assert su3.stype == SUType.Forward - assert send_su12.stype == SUType.P2P - assert recv_su12.stype == SUType.P2P - assert send_su23.stype == SUType.P2P - assert recv_su23.stype == SUType.P2P - - # backward adapters - output.backward() - nodes = SchedulePool().nodes() - sugraph = SUGraphGener.gen_sugraph(nodes) - for su in sugraph.sus(): - print(su) - # note loss will be the input to autograd, therefore - # have additional adapters - assert len(sugraph.sus()) == 14 diff --git a/tests/schedule/test_workflow.py b/tests/schedule/test_workflow.py deleted file mode 100644 index 71aa3313..00000000 --- a/tests/schedule/test_workflow.py +++ /dev/null @@ -1,103 +0,0 @@ -import torch -from torch import nn - -import cube -from cube.graph.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.pool import SchedulePool -from cube.schedule.sugraph import SUGraphGener -from cube.schedule.translator import IRDataLoader - - -class MLP(nn.Module): - def __init__(self, dim, mult=16): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult, bias=False) - self.linear2 = nn.Linear(dim * mult, dim) - self.linear3 = nn.Linear(dim, dim * mult, bias=False) - self.linear4 = nn.Linear(dim * mult, dim) - - def forward(self, data): - output = self.linear1(data) - output = self.linear2(output) - output = self.linear3(output) - output = self.linear4(output) - return output - - -class FakeDataLoader: - def __init__(self, shape, num=640): - self.shape = shape - self.length = num - self.pos = 0 - def __iter__(self): - self.pos = 0 - return self - def __next__(self): - self.pos += 1 - if self.pos == self.length: - raise StopIteration - return torch.randn(self.shape) - - -def test_semantic_model(): - dim = 1024 - model = MLP(dim=dim) - model = cube.SemanticModel( - model, - input_shapes=([64, dim],) - ) - assert isinstance(model.ir_graph, IRGraph) - assert model._loaded_module is None - - -def test_schedule(): - - SchedulePool().clear() - - dim = 1024 - batch_size = 64 - - model = MLP(dim=dim) - model = cube.SemanticModel( - model, - input_shapes=([batch_size, dim],) - ) - - dataloader = FakeDataLoader((batch_size, dim)) - dataloader = IRDataLoader(dataloader) - - def policy(sugraph, resources): - # dataloader - sugraph.assign(sugraph.sus(0), 0) - - fsus = [su for su in sugraph.sus() if su.stype == SUType.Forward] - for idx, fsu in enumerate(fsus): - bsu = fsu.mirror - if idx < 2: - sugraph.assign(fsu, 0) - sugraph.assign(bsu, 0) - else: - sugraph.assign(fsu, 1) - sugraph.assign(bsu, 1) - return sugraph - - def train_iter(model, dataloader): - num_micro_batch = 1 - for _ in range(num_micro_batch): - data = next(dataloader) - output = model(data) - output.backward() - - train_iter(model, dataloader) - - nodes = SchedulePool().nodes() - graph = IRGraph(nodes, None, None, 'testmodel') - print(graph) - - sugraph = SUGraphGener.gen_sugraph(nodes) - - sugraph = policy(sugraph, None) - print(sugraph) - - assert len(sugraph.sus()) == 23 From d5a8c85c0382e23f6cf400b222490e767224d4b1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 2 Mar 2022 19:52:37 +0800 Subject: [PATCH 0606/1892] working on tp 1f1b --- handcraft/pipeline/dummy.py | 76 ++++++++++ handcraft/pipeline/schedule.py | 265 +++++++++++++++++++++++++++++++++ 2 files changed, 341 insertions(+) create mode 100644 handcraft/pipeline/dummy.py create mode 100644 handcraft/pipeline/schedule.py diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py new file mode 100644 index 00000000..5aa83db0 --- /dev/null +++ b/handcraft/pipeline/dummy.py @@ -0,0 +1,76 @@ +""" +Dummy model + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + handcraft/pipeline/dummy.py +""" +import torch +import torch.nn.functional as F +import cube +from cube.runtime.device import DeviceGroup +from cube.runtime.syndata import SynDataLoader + +from handcraft.pipeline.schedule import schedule_tp_1f1b, schedule_naive + + +class DummyModel(torch.nn.Module): + + def __init__(self, dim: int, bs: int, stage_id: int, sharding=False): + + super().__init__() + self.bs = bs + self.dim = dim + self.is_last_stage = stage_id == DeviceGroup().world_size + if sharding: + chunk_num = torch.distributed.get_world_size() + self.weight = torch.nn.Parameter(torch.zeros((dim // chunk_num, dim))) + else: + self.weight = torch.nn.Parameter(torch.zeros((dim, dim))) + + def input_shape(self): + return (self.bs, self.dim, self.dim) + + def input_dtype(self): + return torch.float32 + + def forward(self, input): + output = F.linear(input, self.weight) + if self.is_last_stage: + output = torch.sum(output) + return output + + + +if __name__ == '__main__': + + cube.init() + rank = DeviceGroup().rank + + dim = 1024 + gbs = 32 + mbs = 8 + + # tp 1f1b + first_stage_model = DummyModel(dim, mbs, 0, sharding=True).cuda() + if rank == 0: + model = None + else: + model = DummyModel(dim, mbs, rank, sharding=False).cuda() + + # naive pipleline + # model = DummyModel(dim, mbs, sharding=False).cuda() + + dataloader = SynDataLoader( + shapes=([mbs, dim, dim],), + dtypes=(torch.float32, ), + batch_dims=(0,) + ) + + for step in range(128): + # schedule_naive(model, dataloader, gbs // mbs) + schedule_tp_1f1b(model, first_stage_model, dataloader, gbs // mbs) + if (step+1) % 10 == 0: + print(f'iteration: {step+1}/128') + \ No newline at end of file diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py new file mode 100644 index 00000000..0012e045 --- /dev/null +++ b/handcraft/pipeline/schedule.py @@ -0,0 +1,265 @@ +from typing import List +import torch + +from cube.profiler.timer import CudaTimer, print_each_rank +import cube.runtime.adapter.collectives as coll +from cube.runtime.device import DeviceGroup + + +def forward_step(model, *args, **kwargs): + """ + Forward pass + """ + CudaTimer().start("forward") + output = model(*args, **kwargs) + CudaTimer().stop("forward") + return output + + +def backward_step(input_tensors: List[torch.Tensor], + output_tensors: List[torch.Tensor], + output_tensor_grads: List[torch.Tensor]): + """ + Backward pass + """ + for tensor in input_tensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + tensor.retain_grad() + CudaTimer().start("backward") + torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) + CudaTimer().stop("backward") + input_tensor_grads = [] + for tensor in input_tensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + input_tensor_grads.append(tensor.grad) + else: + input_tensor_grads.append(None) + return input_tensor_grads + + +def is_first_stage(): + return DeviceGroup().rank == 0 + + +def is_last_stage(): + return DeviceGroup().rank == DeviceGroup().world_size - 1 + + +def recv_input(model, dataloader, prev_rank: int): + if is_first_stage(): + return next(dataloader) + else: + return coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + + +def schedule_naive(model, dataloader, num_microbatch: int): + rank = DeviceGroup().rank + next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size + prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size + for _ in range(num_microbatch): + # recv forward + if is_first_stage(): + input = next(dataloader) + else: + print(f'rank {rank} recving forward input...') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + # forward + output = forward_step(model, input) + # send forward + if not is_last_stage(): + print(f'rank {rank} sending forward output...') + coll.send(output, next_rank) + # recv backward + output_grad = None + if not is_last_stage(): + print(f'rank {rank} recving backward input...') + output_grad = coll.recv(output.size(), next_rank, output.dtype) + # backward + input_grad = backward_step([input], [output], [output_grad])[0] + # send backward + if not is_first_stage(): + print(f'rank {rank} sending backward output...') + coll.send(input_grad, prev_rank) + + +def schedule_tp_1f1b(model: torch.nn.Module, + first_stage_model: torch.nn.Module, + dataloader, + num_microbatch: int): + rank = DeviceGroup().rank + next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size + prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size + + input_tensors = list() + output_tensors = list() + + input_1st_tensors = list() + output_1st_tensors = list() + + def tp_forward(fmodel, dataloader) -> torch.Tensor: + input = next(dataloader) + output = forward_step(fmodel, input) + input_1st_tensors.append(input) + output_1st_tensors.append(output) + # gather + outputs = coll.gather([output], None, None, [1,0,2,3]) + if rank == 1: + outputs[0], outputs[1] = outputs[1], outputs[0] + output = torch.cat(tuple(outputs), dim=-1) + else: + output = None + return output + + def tp_backward(grad: torch.Tensor): + with torch.no_grad(): + grads = grad.chunk(4, dim=-1) + grads[0], grads[1] = grads[1], grads[0] + input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) + grad_1st = coll.scatter(grads, [output_1st.size()], [output_1st.dtype], [1,0,2,3]) + backward_step([input_1st], [output_1st], [grad_1st])[0] + + def recv_grad(output: torch.Tensor): + if is_last_stage(): + return None + return coll.recv(output.size(), next_rank, output.dtype) + + fofst = [0, 0, 0, -1][rank] + bofst = [0, -2, -2, -1][rank] + last_barrier_grad = None + for step in range(num_microbatch + 2): + torch.distributed.barrier() + print_each_rank('=========', rank_only=0) + fmid, bmid = step + fofst, step + bofst + # step1: tp forward + if 0 <= step and step <= num_microbatch - 1: + print(f'rank {rank} forward tp model ') + output_1st = tp_forward(first_stage_model, dataloader) + print(f'rank {rank} here') + # step2-1: backward + forward + if rank % 2 == 0: + # backward + forward + if rank == 0: pass + else: + if 0 <= bmid and bmid <= num_microbatch - 1: + input, output = input_tensors.pop(0), output_tensors.pop(0) + # recv output grad + print(f'rank {rank} recv backward grad ') + output_grad = recv_grad(output) + # backward + input_grad = backward_step([input], [output], [output_grad])[0] + # send input grad + print(f'rank {rank} send backward input ') + coll.send(input_grad, prev_rank) + if 0 <= fmid and fmid <= num_microbatch - 1: + # recv input + print(f'rank {rank} recv forward input ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + # forward step + output = forward_step(model, input) + # send output + print(f'rank {rank} send forward output ') + coll.send(output, next_rank) + input_tensors.append(input) + output_tensors.append(output) + # step2-2: forward + backward + #FIXME: warmup forward transimission + if rank % 2 == 1: + if bmid >= 1 and rank != 2: + # cross-barrier send grad + print(f'rank {rank} send backward input ') + coll.send(last_barrier_grad, prev_rank) + if 0 <= fmid and fmid <= num_microbatch - 1: + # recv input + input = output_1st + if rank != 1: + print(f'rank {rank} recv forward input ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + # forward + output = forward_step(model, input) + # send forward + if not is_last_stage(): + print(f'rank {rank} send forward output ') + coll.send(output, next_rank) + input_tensors.append(input) + output_tensors.append(output) + if 0 <= bmid and bmid <= num_microbatch - 1: + input, output = input_tensors.pop(0), output_tensors.pop(0) + # recv grad + print(f'rank {rank} recv backward grad ') + output_grad = recv_grad(output) + # backward + input_grad = backward_step([input], [output], [output_grad])[0] + # postpone send to next barrier + last_barrier_grad = input_grad + # step3: tp backward + if 0 <= (step-2) and (step-2) <= num_microbatch - 1: + print(f'rank {rank} backward tp model ') + tp_backward(last_barrier_grad) + + torch.distributed.barrier() + print_each_rank('=========', rank_only=0) + + +def schedule_1f1b(model: torch.nn.Module, + dataloader, + num_microbatch: int): + group = list(range(DeviceGroup().world_size)) + rank = DeviceGroup().rank + next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size + prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size + + input_tensors = list() + output_tensors = list() + + # warmup + num_warmup_microbatch = DeviceGroup().world_size - 1 - rank + for mid in range(num_warmup_microbatch): + # recv forward + input = recv_input(model, dataloader, prev_rank) + # forward + output = forward_step(model, input) + # send forward + coll.send(output, next_rank) + input_tensors.append(input) + output_tensors.append(output) + + num_warmup_remaining = num_microbatch - num_warmup_microbatch + if num_warmup_remaining > 0: + input = recv_input(model, dataloader, prev_rank) + + # steady + for i in range(num_warmup_microbatch): + # forward + output = forward_step(model, input) + # send forward + recv backward + grad = coll.sendrecv( + [output], + [list(output.size())], [output.dtype], + [next_rank], [next_rank] + )[0] + input_tensors.append(input) + output_tensors.append(output) + # backward + input, output = input_tensors.pop(0), output_tensors.pop(0) + input_grad = backward_step([input], [output], [grad]) + # send backward recv forward + if i != (num_warmup_remaining-1): + input = coll.sendrecv( + [input_grad], + (list(input.size()),), (input.dtype,), + [prev_rank], [prev_rank] + ) + else: + # send backward + coll.send(input_grad, prev_rank) + + # cooldown + for i in range(num_warmup_microbatch): + input, output = input_tensors.pop(0), output_tensors.pop(0) + # recv backward + grad = coll.recv(list(output.size()), next_rank, dtype=output.dtype) + # backward + grad = backward_step([input], [output], [grad]) + # send backward + coll.send(grad, prev_rank) + From 4132622d217564ea591e8a9854c7088e7c58f92c Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 3 Mar 2022 11:07:05 +0800 Subject: [PATCH 0607/1892] add node recog of Pad, Conv3D, EinOps(for reshape) to enable graph capture of wrf(slim) --- cube/graph/operator/function/conv.py | 56 +++++ cube/graph/operator/function/function.py | 52 +++++ cube/graph/operator/function/pad.py | 51 +++++ cube/graph/parser/mapping.py | 11 + cube/graph/parser/parser.py | 5 + examples/wrf/policy/naive.py | 20 ++ examples/wrf/wrf.py | 253 ++++++++++++++--------- 7 files changed, 351 insertions(+), 97 deletions(-) create mode 100644 cube/graph/operator/function/pad.py create mode 100644 examples/wrf/policy/naive.py diff --git a/cube/graph/operator/function/conv.py b/cube/graph/operator/function/conv.py index f9f5fd25..92213396 100644 --- a/cube/graph/operator/function/conv.py +++ b/cube/graph/operator/function/conv.py @@ -51,3 +51,59 @@ def new(self, inputs: List, outputs: List): op.set_output(0, outputs[0]) op.infer_shape() return op + + + +class IRConv3D(IRFwOperation): + + def __init__(self, signature: str, inputs: List[IRTensor], name: str, + **kwargs): + signature = 'cube.runtime.function.conv3d' + assert len(inputs) == 3, "Expected only input, weight, bias as inputs" + assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" + super().__init__(name, signature, 3, 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update(kwargs) + + def infer_shape(self) -> bool: + """ + Output shape inference given the input shapes + """ + if len(self.inputs(0).shape) == 0 or len(self.inputs(1).shape) == 0: + return False + N = self.inputs(0).shape[0] + iC = self.inputs(0).shape[1] + iT, iH, iW = self.inputs(0).shape[2:5] + + oC = self.inputs(1).shape[0] + stride = self.kwargs['stride'] + padding = self.kwargs['padding'] + dilation = self.kwargs['dilation'] + dT = self.inputs(1).shape[2] + dH = self.inputs(1).shape[3] + dW = self.inputs(1).shape[4] + + oT = (iT + 2 * padding[0] - dilation[0] * (dT - 1) - 1) // stride[0] + 1 + oH = (iH + 2 * padding[1] - dilation[1] * (dH - 1) - 1) // stride[1] + 1 + oW = (iW + 2 * padding[2] - dilation[2] * (dW - 1) - 1) // stride[2] + 1 + shape = [N, oC, oT, oH, oW] + + self.outputs(0).shape = shape + return True + + def new(self, inputs: List, outputs: List): + """ + construct a new operator sharing same kwargs with new inputs + and outputs + """ + stride = self.kwargs['stride'] + padding = self.kwargs['padding'] + dilation = self.kwargs['dilation'] + groups = self.kwargs['groups'] + op = IRConv3D(self.signature, inputs, self.name, + stride=stride, padding=padding, dilation=dilation, groups=groups) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + op.infer_shape() + return op \ No newline at end of file diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 2501584f..b65be0d7 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -1,5 +1,9 @@ from cube.graph.operator.function.einops import EinDim, EinopAnno, IREinops from cube.graph.operator.function.conv import IRConv2D +from cube.graph.operator.function.conv import IRConv3D +from cube.graph.operator.function.pad import IRPad +from cube.graph.operator.function.scripteinops import IRScriptEinOps + def Linear(signature, inputs): @@ -160,3 +164,51 @@ def Conv2D(signature, inputs): padding = [padH, padH, padW, padW] return IRConv2D(signature, tensors, 'conv2d', stride=stride, padding=padding, dilation=dilation, groups=groups) + + +def Conv3D(signature, inputs): + """ + conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor + https://pytorch.org/docs/stable/generated/torch.nn.functional.conv3d.html?highlight=conv3d#torch.nn.functional.conv3d + """ + assert len(inputs) == 7, f"Expected 7 inputs but only got {len(inputs)}" + tensors = inputs[0:3] + stride, padding, dilation, groups = inputs[3:] + if isinstance(padding, int): + padding = [padding] * 4 + elif len(padding) == 2: + padH, padW = padding + padding = [padH, padH, padW, padW] + return IRConv3D(signature, tensors, 'conv3d', + stride=stride, padding=padding, dilation=dilation, groups=groups) + +def Pad(signature, inputs): + """ + torch.nn.functional.pad(input, pad, mode='constant', value=0.0) + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad + :param signature: + :param inputs: + :return: + """ + # print("#Pad::inputs.len: {}".format(len(inputs))) + # idx = 0 + # for input in inputs: + # if idx >= 0: + # print("#Pad::input[{}]: {}".format(idx, input)) + # idx += 1 + tensors = inputs[0:1] + pad, mode, value = inputs[1:] + return IRPad(signature, tensors, 'pad', pad=pad, mode=mode, value=value) + +def ScriptEinOps(signature, inputs): + """ + apply_for_scriptable_torch(recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str) -> torch.Tensor: + https://github.com/arogozhnikov/einops/blob/master/einops/_torch_specific.py + :param signature: + :param inputs: + :return: + """ + recipe = inputs[0] + tensors = inputs[1:2] + reduction_type = inputs[2] + return IRScriptEinOps(signature, tensors, 'scripteinops', recipe=recipe, reduction_type=reduction_type) \ No newline at end of file diff --git a/cube/graph/operator/function/pad.py b/cube/graph/operator/function/pad.py new file mode 100644 index 00000000..db92f78b --- /dev/null +++ b/cube/graph/operator/function/pad.py @@ -0,0 +1,51 @@ +from typing import List + +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + +class IRPad(IRFwOperation): + def __init__(self, signature: str, inputs: List[IRTensor], name: str, + **kwargs): + # torch.nn.functional.pad(input, pad, mode='constant', value=0.0) + # pad: List[int] + assert len(inputs) == 1, "Expected only input, weight, bias as inputs" + assert len(kwargs) == 3, "Expected 2 kwargs: mode, value" + super().__init__(name, signature, 1, 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update(kwargs) + + def infer_shape(self) -> bool: + """ + Output shape inference given the input shapes + """ + if len(self.inputs(0).shape) == 0: + return False + + N = self.inputs(0).shape[0] + pad = self.kwargs['pad'] + mode = self.kwargs['mode'] + value = self.kwargs['value'] + assert (len(pad) % 2 == 0, "IRPad::infer_shape len(pad) % 2 == 0") + + shape = self.inputs(0).shape + for pad_idx, pad_size in enumerate(pad): + shape[-1 - (pad_idx // 2)] += pad_size + + self.outputs(0).shape = shape + return True + + def new(self, inputs: List, outputs: List): + """ + construct a new operator sharing same kwargs with new inputs + and outputs + """ + pad = self.kwargs['pad'] + mode = self.kwargs['mode'] + value = self.kwargs['value'] + op = IRPad(self.signature, inputs, self.name, + pad=pad, mode=mode, value=value) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + op.infer_shape() + return op diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index edb2c6bb..f8dc3978 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -45,6 +45,9 @@ def register(signature: str, op: IRFwOperation): # customized __customize = lambda name: f'cube.runtime.function.complex.{name}' + # einops + __einopsize = lambda name: f'einops._torch_specific.{name}' + kOpMap = { # torch nn functional @@ -57,6 +60,8 @@ def register(signature: str, op: IRFwOperation): __ftemplate('gelu') : function.GeLU, + __ftemplate('_pad'): function.Pad, + # __ftemplate('layer_norm'): function.LayerNorm, # torch aten @@ -77,6 +82,12 @@ def register(signature: str, op: IRFwOperation): __ttemplate('conv2d'): function.Conv2D, + __ttemplate('conv3d'): function.Conv3D, + + #einops + __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, + + } diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index f7ba0e98..b3ebd8f0 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -157,6 +157,7 @@ def ntype(node: torch._C.Node): @staticmethod def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation]: + # print("### parse_node {}".format(node)) """ Parse the node and return the IRFwOperation nodes """ @@ -365,7 +366,11 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: elif dtype == 'NoneType': frame.add_var(var_name, None) # module name or other things cannot handle + elif dtype == '__torch__.einops.einops.TransformRecipe': + recipe = getattr(module, label) + frame.add_var(var_name, recipe) else: + # print("### parse_prim_attr_node unknown: {}".format(dtype)) frame.add_var(var_name, label) return list() diff --git a/examples/wrf/policy/naive.py b/examples/wrf/policy/naive.py new file mode 100644 index 00000000..0863b65e --- /dev/null +++ b/examples/wrf/policy/naive.py @@ -0,0 +1,20 @@ +from cube.graph import IRGraph +from cube.graph.operator.function import IRConv2D + +def PAS(graph: IRGraph, resource): + for node in graph.nodes(): + if isinstance(node, IRConv2D): + sub_nodes = list() + algo = node.algorithms('halo') + Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus // 2)) + for Wnode in Wnodes: + algo = Wnode.algorithms('halo') + Hnodes = graph.partition(Wnode, algo, config=dict(idx=0, dim=2, num=2)) + sub_nodes += Hnodes + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + # sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + print(graph.extra_repr()) + return graph diff --git a/examples/wrf/wrf.py b/examples/wrf/wrf.py index 3a6e5f04..feddba5d 100644 --- a/examples/wrf/wrf.py +++ b/examples/wrf/wrf.py @@ -6,6 +6,14 @@ import torch.nn.functional as F # from linalg import tridiagonal +from einops import rearrange +from einops.layers.torch import Rearrange + +torch.jit.script(Rearrange('b c h w -> b h w c')) +print("torch einops 0") +torch.jit.script(Rearrange('(b0 b1 b2) c h w -> b0 b1 b2 h w c', b0=1, b1=1)) +print("torch einops 1") + import cube from examples.poisson.policy.naive import PAS @@ -13,6 +21,8 @@ device = 'cuda' # +def namestr(obj, namespace): + return [name for name in namespace if namespace[name] is obj] def init(nx, ny, nz, dx, dy, dz, theta, PREF, Rd, g, p_t=None, p_s=None, u0=None, v0=None, w0=None, device='cuda'): # spatial discretization @@ -49,13 +59,26 @@ def init(nx, ny, nz, dx, dy, dz, theta, PREF, Rd, g, p_t=None, p_s=None, u0=None # W (nz - 1, ny, nx) W = w0 if w0 is not None else torch.zeros((nz - 1, ny, nx), device=device) + # for var in [U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s]: + # print("### {} shape {}".format(namestr(var, globals()), var.shape)) return U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s class WRF(torch.nn.Module): def __init__(self): super().__init__() - # self.method_name() + self.bar_x_pre = Rearrange('(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # x = X.view(1, 1, Nz, Ny, Nx) + self.bar_x_post = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') + + self.pad_y_pre = Rearrange('(b0 Nz) Ny Nx -> b0 Nz Ny Nx', b0=1) #X.view(1, Nz, Ny, Nx) + self.pad_y_post = Rearrange('b0 Nz Ny Nx -> (b0 Nz) Ny Nx') #.view(Nz, Ny + 2, Nx) + + self.delta_z_pre = Rearrange('(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / dz + self.delta_z_post = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') + + self.delta_x_pre = self.delta_z_pre + self.delta_x_post = self.delta_z_post + # def forward(self, dt): # self.U, self.V, self.W, self.Theta, self.Mu, self.Phi = \ @@ -63,65 +86,76 @@ def __init__(self): def forward(self, U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, - dt, PREF, Rd, g): + dt, PREF, Rd, g, + bar_x_filter, delta_z_filter): - U, V, W, Theta, Mu, Phi = self.RK3_step(U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) + # U, V, W, Theta, Mu, Phi = self.RK3_step(U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter) + U = self.RK3_step(U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, + Rd, g, bar_x_filter, delta_z_filter) return U, V, W, Theta, Mu, Phi def RHS(self, U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, - PREF, Rd, g): + PREF, Rd, g, + bar_x_filter, + delta_z_filter): + + delta_x_filter = bar_x_filter # volecity - u = U / self.bar_x(self.pad_x(Mu)) - v = V / self.bar_y(self.pad_y(Mu)) - w = W / self.bar_z(Mu) - alpha = -self.delta_z(self.pad_z(Phi, Phi_t, Phi_s), dz) / Mu + u = U / self.bar_x(self.pad_x(Mu), bar_x_filter) + # v = V / self.bar_y(self.pad_y(Mu), bar_y_filter) + # w = W / self.bar_z(Mu) + alpha = self.delta_z(self.pad_z(Phi, Phi_t, Phi_s), dz, delta_z_filter) / Mu + #TODO recover me alpha = -self.delta_z(self.pad_z(Phi, Phi_t, Phi_s), dz, delta_z_filter) / Mu Alpha = alpha theta = Theta / Mu - p = PREF * (Rd * theta / PREF / alpha)**1.4 - omega = -w * g / self.bar_z(alpha) / self.bar_z(Mu) - Omega = omega * self.bar_z(Mu) + p = theta #TODO p = PREF * (Rd * theta / PREF / alpha)**1.4 + # omega = -w * g / self.bar_z(alpha) / self.bar_z(Mu) + # Omega = omega * self.bar_z(Mu) #Omega = Omega # advection term - R_U = - self.delta_x(self.bar_x(self.pad_x(U)) * self.bar_x(self.pad_x(u)), dx) \ - - self.delta_y(self.bar_x(self.pad_x(V)) * self.bar_y(self.pad_y(u)), dy) \ - - self.delta_z(self.bar_x(self.pad_x(self.pad_z(Omega))) * self.bar_z(self.pad_z(u)), dz) - - R_V = - self.delta_x(self.bar_y(self.pad_y(U)) * self.bar_x(self.pad_x(v)), dx) \ - - self.delta_y(self.bar_y(self.pad_y(V)) * self.bar_y(self.pad_y(v)), dy) \ - - self.delta_z(self.bar_y(self.pad_y(self.pad_z(Omega))) * self.bar_z(self.pad_z(v)), dz) - - R_W = - self.delta_x(self.bar_z(U) * self.bar_x(self.pad_x(w)), dx) \ - - self.delta_y(self.bar_z(V) * self.bar_y(self.pad_y(w)), dy) \ - - self.delta_z(self.bar_z(self.pad_z(Omega)) * self.bar_z(self.pad_z(w)), dz) - - R_Theta = - self.delta_x(U * self.bar_x(self.pad_x(theta)), dx) \ - - self.delta_y(V * self.bar_y(self.pad_y(theta)), dy) \ - - self.delta_z(self.pad_z(Omega) * self.bar_z(self.pad_z(theta)), dz) - - R_Phi = - self.bar_z(self.bar_x(u)) * self.delta_x(self.bar_x(self.pad_x(Phi)), dx) \ - - self.bar_z(self.bar_y(v)) * self.delta_y(self.bar_y(self.pad_y(Phi)), dy) \ - - Omega * self.delta_z(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s)), dz) - - R_Mu = - self.delta_x(U, dx) - self.delta_y(V, dy) - self.delta_z(self.pad_z(Omega), dz) + R_U = self.delta_x(self.bar_x(self.pad_x(U), bar_x_filter) * self.bar_x(self.pad_x(u), bar_x_filter), dx, delta_x_filter) + #- self.delta_x(self.bar_x(self.pad_x(U), bar_x_filter) * self.bar_x(self.pad_x(u), bar_x_filter), dx, delta_x_filter) + # \ + # - self.delta_y(self.bar_x(self.pad_x(V), bar_x_filter) * self.bar_y(self.pad_y(u)), dy) \ + # - self.delta_z(self.bar_x(self.pad_x(self.pad_z(Omega)), bar_x_filter) * self.bar_z(self.pad_z(u)), dz) + + # R_V = - self.delta_x(self.bar_y(self.pad_y(U)) * self.bar_x(self.pad_x(v)), dx) \ + # - self.delta_y(self.bar_y(self.pad_y(V)) * self.bar_y(self.pad_y(v)), dy) \ + # - self.delta_z(self.bar_y(self.pad_y(self.pad_z(Omega))) * self.bar_z(self.pad_z(v)), dz) + # + # R_W = - self.delta_x(self.bar_z(U) * self.bar_x(self.pad_x(w)), dx) \ + # - self.delta_y(self.bar_z(V) * self.bar_y(self.pad_y(w)), dy) \ + # - self.delta_z(self.bar_z(self.pad_z(Omega)) * self.bar_z(self.pad_z(w)), dz) + + # R_Theta = - self.delta_x(U * self.bar_x(self.pad_x(theta)), dx) \ + # - self.delta_y(V * self.bar_y(self.pad_y(theta)), dy) \ + # - self.delta_z(self.pad_z(Omega) * self.bar_z(self.pad_z(theta)), dz) + # + # R_Phi = - self.bar_z(self.bar_x(u)) * self.delta_x(self.bar_x(self.pad_x(Phi)), dx) \ + # - self.bar_z(self.bar_y(v)) * self.delta_y(self.bar_y(self.pad_y(Phi)), dy) \ + # - Omega * self.delta_z(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s)), dz) + # + # R_Mu = - self.delta_x(U, dx) - self.delta_y(V, dy) - self.delta_z(self.pad_z(Omega), dz) # pressure term - R_U += - self.bar_x(self.pad_x(Mu)) * self.bar_x(self.pad_x(alpha)) * self.delta_x(self.pad_x(p), dx) \ - - (self.delta_z(self.bar_x(self.bar_z(self.pad_x(self.pad_z(p, P_t, P_s)))), dz) * - self.delta_x(self.pad_x(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s))), dx)) + R_U = R_U - self.bar_x(self.pad_x(Mu), bar_x_filter) * self.bar_x(self.pad_x(alpha), bar_x_filter) * self.delta_x(self.pad_x(p), dx, delta_x_filter) + # \ + # - (self.delta_z(self.bar_x(self.bar_z(self.pad_x(self.pad_z(p, P_t, P_s))), bar_x_filter), dz) * + # self.delta_x(self.pad_x(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s))), dx)) - R_V += - self.bar_y(self.pad_y(Mu)) * self.bar_y(self.pad_y(alpha)) * self.delta_y(self.pad_y(p), dy) \ - - (self.delta_z(self.bar_y(self.bar_z(self.pad_y(self.pad_z(p, P_t, P_s)))), dz) * - self.delta_y(self.pad_y(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s))), dy)) + # R_V = R_V - self.bar_y(self.pad_y(Mu)) * self.bar_y(self.pad_y(alpha)) * self.delta_y(self.pad_y(p), dy) \ + # - (self.delta_z(self.bar_y(self.bar_z(self.pad_y(self.pad_z(p, P_t, P_s)))), dz) * + # self.delta_y(self.pad_y(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s))), dy)) + # + # R_W = R_W + g * (self.delta_z(p, dz) - self.bar_z(Mu)) - R_W += g * (self.delta_z(p, dz) - self.bar_z(Mu)) - - # gravity term - R_Phi += g * w + # # gravity term + # R_Phi = R_Phi + g * w # Coriolis term # R_U += + 100 * self.bar_x(self.bar_y(self.pad_x(V))) \ @@ -131,36 +165,37 @@ def RHS(self, # + 100 * self.bar_y(self.bar_z(self.pad_y(self.pad_z(W)))) \ # - v * self.bar_y(self.bar_z(self.pad_y(self.pad_z(W)))) / 6400. / 1000. - return R_U, R_V, R_W, R_Theta, R_Mu, R_Phi #, Alpha, Omega + return R_U #, R_V, R_W, R_Theta, R_Mu, R_Phi #, Alpha, Omega # def RK3_step(self, U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, dt): - def RK3_step(self, U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g): + def RK3_step(self, U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter): r"""One RK3 Step""" - R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) + # R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter) + R_U = self.RHS(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter) U_ = U + dt * R_U / 3 - V_ = V + dt * R_V / 3 - W_ = W + dt * R_W / 3 - Theta_ = Theta + dt * R_Theta / 3 - Mu_ = Mu + dt * R_Mu / 3 - Phi_ = Phi + dt * R_Phi / 3 - - R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) - U_ = U + dt * R_U / 2 - V_ = V + dt * R_V / 2 - W_ = W + dt * R_W / 2 - Theta_ = Theta + dt * R_Theta / 2 - Mu_ = Mu + dt * R_Mu / 2 - Phi_ = Phi + dt * R_Phi / 2 - - R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) - U += dt * R_U - V += dt * R_V - W += dt * R_W - Theta += dt * R_Theta - Mu += dt * R_Mu - Phi += dt * R_Phi - - return U, V, W, Theta, Mu, Phi + # V_ = V + dt * R_V / 3 + # W_ = W + dt * R_W / 3 + # Theta_ = Theta + dt * R_Theta / 3 + # Mu_ = Mu + dt * R_Mu / 3 + # Phi_ = Phi + dt * R_Phi / 3 + + # R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) + # U_ = U + dt * R_U / 2 + # V_ = V + dt * R_V / 2 + # W_ = W + dt * R_W / 2 + # Theta_ = Theta + dt * R_Theta / 2 + # Mu_ = Mu + dt * R_Mu / 2 + # Phi_ = Phi + dt * R_Phi / 2 + # + # R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) + U = U + R_U #TODO U = U + dt * R_U + # V = V + dt * R_V + # W = W + dt * R_W + # Theta = Theta + dt * R_Theta + # Mu = Mu + dt * R_Mu + # Phi = Phi + dt * R_Phi + + return U #TODO , V, W, Theta, Mu, Phi def pad_x(self, X): r"""Periodic boundary condition in x axis""" @@ -168,18 +203,23 @@ def pad_x(self, X): def pad_y(self, X): r"""Periodic boundary condition in y axis""" - Nz, Ny, Nx = X.shape - return F.pad(X.view(1, Nz, Ny, Nx), (0, 0, 1, 1), "circular").view(Nz, Ny + 2, Nx) + # Nz, Ny, Nx = X.shape + # return F.pad(X.view(1, Nz, Ny, Nx), (0, 0, 1, 1), "circular").view(Nz, Ny + 2, Nx) + x_ext = self.pad_y_pre(X) + x_pad = F.pad(x_ext, (0, 0, 1, 1), "circular") + x_unext = self.pad_y_post(x_pad) + return x_unext # TODO def pad_z(self, X, top=None, surface=None): def pad_z(self, X, top=torch.Tensor(), surface=torch.Tensor()): r"""Dirichlet boundary condition in z axis""" - _, ny, nx = X.shape - top = torch.zeros((1, ny, nx), device=X.device) #TODO top = top if top is not None else torch.zeros((1, ny, nx), device=X.device) - surface = torch.zeros((1, ny, nx), device=X.device) #TODO surface = surface if surface is not None else torch.zeros((1, ny, nx), device=X.device) - return torch.cat((top, X, surface), dim=0) + # _, ny, nx = X.shape + # top = torch.zeros((1, ny, nx), device=X.device) #TODO top = top if top is not None else torch.zeros((1, ny, nx), device=X.device) + # surface = torch.zeros((1, ny, nx), device=X.device) #TODO surface = surface if surface is not None else torch.zeros((1, ny, nx), device=X.device) + # return torch.cat((top, X, surface), dim=0) + return F.pad(X, (0, 0, 0, 0, 1, 1), "constant", 0.) - def bar_x(self, X): + def bar_x(self, X, filter): r"""Numerical scheme for X\bar^x Args: @@ -188,11 +228,17 @@ def bar_x(self, X): Returns: Tensor: X\bar^x with shape (Nz, Ny, Nx-1) """ - Nz, Ny, Nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / 2. - - def delta_x(self, X, dx): + # Nz, Ny, Nx = X.shape + # filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) # filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) + #TODO return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / 2. + x = self.bar_x_pre(X) #x = rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # x = X.view(1, 1, Nz, Ny, Nx) + convx = F.conv3d(x, filter) + convx2 = self.bar_x_post(convx) #rearrange(convx, 'b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') # convx2 = convx.view(Nz, Ny, Nx - 1) + convx3 = convx2 #TODO recover me / 2. + # convx3 = X # + return convx3 + + def delta_x(self, X, dx, filter): r"""Numerical scheme for \delta_x X Args: @@ -201,11 +247,17 @@ def delta_x(self, X, dx): Returns: Tensor: \delta_x X with shape (Nz, Ny, Nx-1) """ - Nz, Ny, Nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / dx + # Nz, Ny, Nx = X.shape + # filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) + # return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / dx + x_ext = self.delta_x_pre(X) + x_conv = F.conv3d(x_ext, filter) + x_unext = self.delta_x_post(x_conv) + return x_unext #TODO / dx + + - def bar_y(self, X): + def bar_y(self, X, filter): r"""Numerical scheme for X\bar^y Args: @@ -214,9 +266,9 @@ def bar_y(self, X): Returns: Tensor: X\bar^y with shape (Nz, Ny-1, Nx) """ - Nz, Ny, Nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / 2. + # Nz, Ny, Nx = X.shape + # filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) + # return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / 2. def delta_y(self, X, dy): r"""Numerical scheme for \delta_y X @@ -244,7 +296,7 @@ def bar_z(self, X): filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / 2. - def delta_z(self, X, dz): + def delta_z(self, X, dz, filter): r"""Numerical scheme for \delta_z X Args: @@ -253,9 +305,13 @@ def delta_z(self, X, dz): Returns: Tensor: \delta_z X with shape (Nz-1, Ny, Nx) """ - Nz, Ny, Nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / dz + # Nz, Ny, Nx = X.shape + # filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) + # return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / dz + x_ext = self.delta_z_pre(X) + x_conv = F.conv3d(x_ext, filter) + x_unext = self.delta_z_post(x_conv) + return x_unext #TODO / dz def _acoustic_step(self, ): r"""One acustic step""" @@ -265,7 +321,8 @@ def _acoustic_step(self, ): class LoopVariables(cube.runtime.syndata.CubeDataLoader): def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): - + # for var in variables + constants: + # print("### var = {}, type = {}".format(var, type(var))) shapes = [list(var.size()) for var in variables + constants] dtypes = [var.dtype for var in variables + constants] batch_dims = [0] * (len(variables) + len(constants)) @@ -333,17 +390,19 @@ def __next__(self): dz = torch.tensor(dz) U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s = init(nx, ny, nz, dx, dy, dz, theta, PREF, Rd, g) + bar_x_filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) + delta_z_filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) - varloader = LoopVariables(variables=[U, V, W, Theta, Mu, Phi, dt], constants=[Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g]) + varloader = LoopVariables(variables=[U, V, W, Theta, Mu, Phi, dt], constants=[Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter]) model = WRF() model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes), ) - #TODO @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) + @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) def train_iter(model, dataloader): - U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g = next(dataloader) - U, V, W, Theta, Mu, Phi = model(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, dt, PREF, Rd, g) + U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter , delta_z_filter= next(dataloader) + U, V, W, Theta, Mu, Phi = model(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, dt, PREF, Rd, g, bar_x_filter, delta_z_filter) return U, V, W, Theta, Mu, Phi - #TODO model = model.get_gen_module() + model = model.get_gen_module() import matplotlib.pyplot as plt import numpy as np From 849b15a3a6967d0f30aa6cb8a3d3835e3eb84548 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Mar 2022 13:26:26 +0800 Subject: [PATCH 0608/1892] add 4-gpu tp 1f1b version --- handcraft/pipeline/schedule.py | 178 +++++++++++++++++++++------------ 1 file changed, 115 insertions(+), 63 deletions(-) diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index 0012e045..d74f8a23 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -104,16 +104,21 @@ def tp_forward(fmodel, dataloader) -> torch.Tensor: # gather outputs = coll.gather([output], None, None, [1,0,2,3]) if rank == 1: - outputs[0], outputs[1] = outputs[1], outputs[0] - output = torch.cat(tuple(outputs), dim=-1) + with torch.no_grad(): + outputs[0], outputs[1] = outputs[1], outputs[0] + output = torch.cat(tuple(outputs), dim=-1) + output = output.requires_grad_() else: output = None return output def tp_backward(grad: torch.Tensor): - with torch.no_grad(): - grads = grad.chunk(4, dim=-1) - grads[0], grads[1] = grads[1], grads[0] + if rank == 1: + with torch.no_grad(): + grads = list(grad.chunk(4, dim=-1)) + grads[0], grads[1] = grads[1], grads[0] + else: + grads = None input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) grad_1st = coll.scatter(grads, [output_1st.size()], [output_1st.dtype], [1,0,2,3]) backward_step([input_1st], [output_1st], [grad_1st])[0] @@ -125,79 +130,126 @@ def recv_grad(output: torch.Tensor): fofst = [0, 0, 0, -1][rank] bofst = [0, -2, -2, -1][rank] - last_barrier_grad = None + last_backward = None + last_forward = None for step in range(num_microbatch + 2): torch.distributed.barrier() - print_each_rank('=========', rank_only=0) + # print_each_rank(f'=========begin rank {rank}=========') fmid, bmid = step + fofst, step + bofst + do_backward = 0 <= bmid and bmid <= num_microbatch - 1 + do_forward = 0 <= fmid and fmid <= num_microbatch - 1 + # step1: tp forward if 0 <= step and step <= num_microbatch - 1: - print(f'rank {rank} forward tp model ') + # print(f'rank {rank} forward tp model ') output_1st = tp_forward(first_stage_model, dataloader) - print(f'rank {rank} here') - # step2-1: backward + forward - if rank % 2 == 0: - # backward + forward - if rank == 0: pass - else: - if 0 <= bmid and bmid <= num_microbatch - 1: - input, output = input_tensors.pop(0), output_tensors.pop(0) - # recv output grad - print(f'rank {rank} recv backward grad ') - output_grad = recv_grad(output) - # backward - input_grad = backward_step([input], [output], [output_grad])[0] - # send input grad - print(f'rank {rank} send backward input ') - coll.send(input_grad, prev_rank) - if 0 <= fmid and fmid <= num_microbatch - 1: - # recv input - print(f'rank {rank} recv forward input ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - # forward step - output = forward_step(model, input) - # send output - print(f'rank {rank} send forward output ') - coll.send(output, next_rank) - input_tensors.append(input) - output_tensors.append(output) - # step2-2: forward + backward - #FIXME: warmup forward transimission - if rank % 2 == 1: - if bmid >= 1 and rank != 2: - # cross-barrier send grad - print(f'rank {rank} send backward input ') - coll.send(last_barrier_grad, prev_rank) - if 0 <= fmid and fmid <= num_microbatch - 1: - # recv input + + # step2: backward + forward + if rank == 0: + pass + + if rank == 2: + # inter-barrier + if do_backward and last_forward is not None: + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [last_forward], [model.output_shape()], [model.output_dtype()], + [next_rank], [next_rank] + )[0] + elif do_backward: + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + elif last_forward is not None: + # print(f'rank {rank} send forward output ') + coll.send(last_forward, next_rank) + + # backward + if do_backward: + input, output = input_tensors.pop(0), output_tensors.pop(0) + # backward + input_grad = backward_step([input], [output], [output_grad])[0] + + # intra-barrier + if do_backward and do_forward: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_backward: + # print(f'rank {rank} send backward grad ') + coll.send(input_grad, prev_rank) + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + + # forward + last_forward = None + if do_forward: + # forward step + output = forward_step(model, input) + input_tensors.append(input) + output_tensors.append(output) + last_forward = output + + if rank == 1: + + # forward + if do_forward: input = output_1st - if rank != 1: - print(f'rank {rank} recv forward input ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - # forward output = forward_step(model, input) - # send forward - if not is_last_stage(): - print(f'rank {rank} send forward output ') - coll.send(output, next_rank) input_tensors.append(input) output_tensors.append(output) - if 0 <= bmid and bmid <= num_microbatch - 1: + + # intra-barrier send recv + if do_forward and do_backward: + # send forward recv backward + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [output], [output.size()], [output.dtype], + [next_rank], [next_rank] + )[0] + elif do_forward: + # print(f'rank {rank} send forward output ') + coll.send(output, next_rank) + elif do_backward: + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + + # backward + if do_backward: input, output = input_tensors.pop(0), output_tensors.pop(0) - # recv grad - print(f'rank {rank} recv backward grad ') - output_grad = recv_grad(output) - # backward input_grad = backward_step([input], [output], [output_grad])[0] - # postpone send to next barrier - last_barrier_grad = input_grad + last_backward = input_grad + + if rank == 3: + + # inter-barrier + if do_forward and last_backward is not None: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + elif last_backward is not None: + # print(f'rank {rank} send backward grad ') + coll.send(last_backward, prev_rank) + + # backward + forward + if do_forward: + output = forward_step(model, input) + input_grad = backward_step([input], [output], [None,])[0] + last_backward = input_grad + # step3: tp backward if 0 <= (step-2) and (step-2) <= num_microbatch - 1: - print(f'rank {rank} backward tp model ') - tp_backward(last_barrier_grad) + # print(f'rank {rank} backward tp model ') + tp_backward(last_backward) - torch.distributed.barrier() - print_each_rank('=========', rank_only=0) + # print_each_rank(f'=========end rank {rank}=========') def schedule_1f1b(model: torch.nn.Module, From 10ac8aad7befb7e1021fa5eaf5e8f86d8292fae9 Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 3 Mar 2022 13:36:27 +0800 Subject: [PATCH 0609/1892] fix missing EinOps support --- cube/graph/operator/function/scripteinops.py | 50 ++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 cube/graph/operator/function/scripteinops.py diff --git a/cube/graph/operator/function/scripteinops.py b/cube/graph/operator/function/scripteinops.py new file mode 100644 index 00000000..14d37672 --- /dev/null +++ b/cube/graph/operator/function/scripteinops.py @@ -0,0 +1,50 @@ + +from typing import List + +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + +from einops.einops import _apply_recipe + +import torch + +class IRScriptEinOps(IRFwOperation): + + def __init__(self, signature: str, inputs: List[IRTensor], name: str, + **kwargs): + signature = 'einops._torch_specific.apply_for_scriptable_torch' #'cube.runtime.function.conv2d' + assert len(inputs) == 1, "Expected only input" + assert len(kwargs) == 2, "Expected 2 kwargs: recipe, reduction_type" + super().__init__(name, signature, 1, 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update(kwargs) + + def infer_shape(self) -> bool: + """ + Output shape inference given the input shapes + """ + if len(self.inputs(0).shape) == 0: + return False + + recipe = self.kwargs['recipe'] + reduction_type = self.kwargs['reduction_type'] + tmp_tensor = torch.zeros(self.inputs(0).shape) + tmp_output = _apply_recipe(recipe, tmp_tensor, reduction_type) + self.outputs(0).shape = list(tmp_output.shape) + return True + + def new(self, inputs: List, outputs: List): + """ + construct a new operator sharing same kwargs with new inputs + and outputs + """ + recipe = self.kwargs['recipe'] + reduction_type = self.kwargs['reduction_type'] + op = IRScriptEinOps(self.signature, inputs, self.name, + recipe=recipe, reduction_type=reduction_type) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + op.infer_shape() + return op + From ce86e41f77c217db90236455f989f06ccd375687 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Mar 2022 14:18:30 +0800 Subject: [PATCH 0610/1892] add gather and scatter collectives --- cube/runtime/adapter/collectives.py | 116 +++++++++++++++++++++++----- handcraft/pipeline/dummy.py | 65 ++++++++++++---- handcraft/pipeline/schedule.py | 8 +- 3 files changed, 148 insertions(+), 41 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 3d2d3c51..54732eaa 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -1,4 +1,5 @@ from typing import List +from unittest import defaultTestLoader import torch from cube.runtime.device import DeviceGroup @@ -56,29 +57,28 @@ def recv(shape: List[int], from_rank: int, dtype: torch.dtype): return tensor -def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): +def sendrecv(input_tensors: List[torch.Tensor], + output_shapes: List[List[int]], + output_dtypes: List[torch.dtype], + send_ranks: List[int], + recv_ranks: List[int]) -> List[torch.Tensor]: CudaTimer().start(field_name='comm') # print('sending and recving...') ops = list() - recv_tensors = list() - for tensor, ranks in zip(send_tensors, to_ranks): + outputs = list() + for tensor, rank in zip(input_tensors, send_ranks): if not torch.is_tensor(tensor): raise RuntimeError(f"Expected {tensor} to be tensor") - for rank in ranks: - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, rank - ) - ops.append(send_op) - for shape, ranks in zip(recv_shapes, from_ranks): - if len(ranks) != 1: - raise RuntimeError( - "Not supported for recving same tensor from multiple devices" - ) - rank = ranks[0] + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, rank + ) + ops.append(send_op) + for shape, dtype, rank in zip(output_shapes, output_dtypes, recv_ranks): tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device() + shape, dtype=dtype, + requires_grad=True, device=torch.cuda.current_device() ) - recv_tensors.append(tensor) + outputs.append(tensor) recv_op = torch.distributed.P2POp( torch.distributed.irecv, tensor, rank ) @@ -86,13 +86,9 @@ def send_and_recv(send_tensors, to_ranks, recv_shapes, from_ranks): reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: req.wait() - torch.cuda.synchronize() CudaTimer().stop(field_name='comm') - - if len(recv_tensors) == 0: return None - elif len(recv_tensors) == 1: return recv_tensors[0] - else: return tuple(recv_tensors) + return outputs ### Collective Universal Interface ### @@ -189,3 +185,81 @@ def broadcast(input_tensors: List[torch.Tensor], torch.distributed.broadcast(tensor, ranks[0], group=group) CudaTimer().stop(field_name='comm') return tensor + + +def gather(input_tensors: List[torch.Tensor], + output_shapes: List[List[int]], + output_dtypes: List[torch.dtype], + ranks: List[int]) -> List[torch.Tensor]: + """ + Gather. ranks[0] is the root + """ + CudaTimer().start(field_name='comm') + assert len(input_tensors) == 1 + input_tensor = input_tensors[0] + dst = ranks[0] + if DeviceGroup().rank == dst: + # recv + tensor_list = [input_tensor] + [torch.empty_like(input_tensor) for _ in range(len(ranks)-1)] + ops = list() + for rank, tensor in zip(ranks[1:], tensor_list[1:]): + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, rank + ) + ops.append(recv_op) + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + else: + # send + tensor_list = [] + send_op = torch.distributed.P2POp( + torch.distributed.isend, input_tensor, ranks[0] + ) + reqs = torch.distributed.batch_isend_irecv([send_op]) + for req in reqs: + req.wait() + CudaTimer().stop(field_name='comm') + return tensor_list + + +def scatter(input_tensors: List[torch.Tensor], + output_shapes: List[List[int]], + output_dtypes: List[torch.dtype], + ranks: List[int]) -> List[torch.Tensor]: + CudaTimer().start(field_name='comm') + output = None + src = ranks[0] + if DeviceGroup().rank == src: + # send + ops = list() + for rank, tensor in zip(ranks, input_tensors): + if rank == src: + output = tensor + else: + if not tensor.is_contiguous(): + with torch.no_grad(): + tensor = tensor.contiguous() + send_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, rank + ) + ops.append(send_op) + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + else: + # recv + assert len(output_shapes) == 1 and len(output_dtypes) == 1 + output = torch.empty( + output_shapes[0], dtype=output_dtypes[0], + requires_grad=True, device=torch.cuda.current_device() + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, output, src + ) + reqs = torch.distributed.batch_isend_irecv([recv_op]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return output diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py index 5aa83db0..a457aacb 100644 --- a/handcraft/pipeline/dummy.py +++ b/handcraft/pipeline/dummy.py @@ -4,15 +4,20 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - handcraft/pipeline/dummy.py + handcraft/pipeline/dummy.py --use-naive """ +from sys import argv +from threading import activeCount import torch import torch.nn.functional as F import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank from cube.runtime.device import DeviceGroup from cube.runtime.syndata import SynDataLoader from handcraft.pipeline.schedule import schedule_tp_1f1b, schedule_naive +import argparse class DummyModel(torch.nn.Module): @@ -22,7 +27,7 @@ def __init__(self, dim: int, bs: int, stage_id: int, sharding=False): super().__init__() self.bs = bs self.dim = dim - self.is_last_stage = stage_id == DeviceGroup().world_size + self.is_last_stage = stage_id == DeviceGroup().world_size - 1 if sharding: chunk_num = torch.distributed.get_world_size() self.weight = torch.nn.Parameter(torch.zeros((dim // chunk_num, dim))) @@ -32,9 +37,15 @@ def __init__(self, dim: int, bs: int, stage_id: int, sharding=False): def input_shape(self): return (self.bs, self.dim, self.dim) + def output_shape(self): + return (1,) if self.is_last_stage else (self.bs, self.dim, self.dim) + def input_dtype(self): return torch.float32 + def output_dtype(self): + return torch.float32 + def forward(self, input): output = F.linear(input, self.weight) if self.is_last_stage: @@ -45,22 +56,32 @@ def forward(self, input): if __name__ == '__main__': + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--use-naive', action='store_true', + help='use naive pipeline') + parser.add_argument('--use-tp1f1b', action='store_true', + help='use tensor parallel 1f1b') + args = parser.parse_args() + assert args.use_naive ^ args.use_tp1f1b, "Specify (only) 1 way pipeline" + cube.init() rank = DeviceGroup().rank - dim = 1024 - gbs = 32 + dim = 2048 + gbs = 512 mbs = 8 # tp 1f1b - first_stage_model = DummyModel(dim, mbs, 0, sharding=True).cuda() - if rank == 0: - model = None - else: - model = DummyModel(dim, mbs, rank, sharding=False).cuda() + if args.use_tp1f1b: + first_stage_model = DummyModel(dim, mbs, 0, sharding=True).cuda() + if rank == 0: + model = None + else: + model = DummyModel(dim, mbs, rank, sharding=False).cuda() - # naive pipleline - # model = DummyModel(dim, mbs, sharding=False).cuda() + if args.use_naive: + # naive pipleline + model = DummyModel(dim, mbs, rank, sharding=False).cuda() dataloader = SynDataLoader( shapes=([mbs, dim, dim],), @@ -68,9 +89,21 @@ def forward(self, input): batch_dims=(0,) ) - for step in range(128): - # schedule_naive(model, dataloader, gbs // mbs) - schedule_tp_1f1b(model, first_stage_model, dataloader, gbs // mbs) + iter_num = 64 + CudaTimer(enable=False).warmup() + for step in range(iter_num): + if step >= 20: + CudaTimer(enable=True).start('e2e') + + if args.use_tp1f1b: + schedule_tp_1f1b(model, first_stage_model, dataloader, gbs // mbs) + if args.use_naive: + schedule_naive(model, dataloader, gbs // mbs) + + if step >= 20: + CudaTimer().stop('e2e') if (step+1) % 10 == 0: - print(f'iteration: {step+1}/128') - \ No newline at end of file + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-20, field_name='e2e'))) diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index d74f8a23..1179ee49 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -61,24 +61,24 @@ def schedule_naive(model, dataloader, num_microbatch: int): if is_first_stage(): input = next(dataloader) else: - print(f'rank {rank} recving forward input...') + # print(f'rank {rank} recving forward input...') input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) # forward output = forward_step(model, input) # send forward if not is_last_stage(): - print(f'rank {rank} sending forward output...') + # print(f'rank {rank} sending forward output...') coll.send(output, next_rank) # recv backward output_grad = None if not is_last_stage(): - print(f'rank {rank} recving backward input...') + # print(f'rank {rank} recving backward input...') output_grad = coll.recv(output.size(), next_rank, output.dtype) # backward input_grad = backward_step([input], [output], [output_grad])[0] # send backward if not is_first_stage(): - print(f'rank {rank} sending backward output...') + # print(f'rank {rank} sending backward output...') coll.send(input_grad, prev_rank) From 39873740cbde7d63040b9f6289d039a63bcfc3c0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Mar 2022 14:42:38 +0800 Subject: [PATCH 0611/1892] increase data number --- handcraft/pipeline/dummy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py index a457aacb..739d4e05 100644 --- a/handcraft/pipeline/dummy.py +++ b/handcraft/pipeline/dummy.py @@ -86,7 +86,8 @@ def forward(self, input): dataloader = SynDataLoader( shapes=([mbs, dim, dim],), dtypes=(torch.float32, ), - batch_dims=(0,) + batch_dims=(0,), + length=128000 ) iter_num = 64 From bf1b6f51b1e0e411c8bb9a656330323cb2e5550e Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 3 Mar 2022 15:56:43 +0800 Subject: [PATCH 0612/1892] add atmosphere example --- examples/atmosphere/weather.py | 363 +++++++++++++++++++++++++++++++++ 1 file changed, 363 insertions(+) create mode 100644 examples/atmosphere/weather.py diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py new file mode 100644 index 00000000..81865427 --- /dev/null +++ b/examples/atmosphere/weather.py @@ -0,0 +1,363 @@ +import math +import torch +import numpy as np +import torch.nn.functional as F + +torch.set_default_tensor_type(torch.DoubleTensor) + + + +class Atmoshpere(torch.nn.Module): + def __init__(self, nz, ny, nx, dy, dx, x0, y0, device='cuda'): + super().__init__() + self.device = torch.device(device) + + # physics constant + self.g = 9.8 # acceleration of gravity, unit in m/s^2 + self.PSEA = 101325. # sea level pressure, unit in Pa + self.KAPPA = 0.286 # dimensionless + self.RE = 6.4e6 # radius of earth, unit in m + self.CPD = 1004.67 # specific heat of dry air at constant pressure J*kg^-1*K^-1 + # self.OMEGA = 7.292e-5 # angular speed of the Earth s^-1 + self.OMEGA = 1e-1 # angular speed of the Earth s^-1 + + # atmoshpere verticle profile + self.hight_profile = torch.tensor([ + 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, + 8.5, 9, 9.5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100 + ]).to(self.device) * 1e3 + + self.pressure_profile = torch.tensor([ + 1013.25, 1001.20, 989.45, 977.72, 966.11, 954.61, 943.22, 931.94, 920.77, 909.71, 898.80, 845.59, 795.0, + 746.9, 701.2, 657.8, 616.6, 577.5, 540.5, 505.4, 472.2, 440.7, 411.1, 383.0, 356.5, 331.5, 308.0, 285.8, + 265.0, 227.0, 194.0, 165.8, 141.7, 121.1, 103.5, 88.5, 75.7, 64.7, 55.3, 47.3, 40.5, 34.7, 29.7, 25.5, 21.9, + 18.8, 16.2, 13.9, 12.0, 10.3, 8.89, 7.67, 6.63, 5.75, 4.99, 4.33, 3.77, 3.29, 2.87, 2.51, 2.20, 1.93, 1.69, + 1.49, 1.31, 1.16, 1.02, 0.903, 0.903, 0.425, 0.220, 0.109, 0.0522, 0.0239, 0.0105, 0.0045, 0.0018, 0.00076, + 0.00032 + ]).to(self.device) * 1e2 + + self.temperature_profile = torch.tensor([ + 288.15, 287.50, 286.85, 286.20, 285.55, 284.90, 284.25, 283.60, 282.95, 282.30, 281.65, 278.40, 275.15, + 271.91, 268.66, 265.41, 262.17, 258.92, 255.68, 252.43, 249.19, 245.94, 242.70, 239.46, 236.22, 232.97, + 229.73, 226.49, 223.25, 216.78, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, + 217.58, 218.57, 219.57, 220.56, 221.55, 222.54, 223.54, 224.53, 225.52, 226.51, 227.50, 228.49, 230.97, + 233.74, 236.51, 239.28, 242.05, 244.82, 247.58, 250.35, 253.11, 255.88, 258.64, 261.40, 264.16, 266.93, + 269.68, 270.65, 270.65, 270.65, 260.77, 247.02, 233.29, 219.59, 208.40, 198.64, 188.89, 186.87, 188.42, + 195.08 + ]).to(self.device) + + self.density_profile = torch.tensor([ + 1.225, 1.213, 1.202, 1.190, 1.179, 1.167, 1.156, 1.145, 1.134, 1.123, 1.112, 1.058, 1.007, 0.957, 0.909, + 0.863, 0.819, 0.777, 0.736, 0.697, 0.660, 0.624, 0.590, 0.557, 0.526, 0.496, 0.467, 0.440, 0.414, 0.365, + 0.312, 0.267, 0.228, 0.195, 0.166, 0.142, 0.122, 0.104, 0.0889, 0.0757, 0.0645, 0.0550, 0.0469, 0.0401, + 0.0343, 0.0293, 0.0251, 0.0215, 0.0184, 0.0158, 0.0136, 0.0116, 0.00989, 0.00846, 0.00726, 0.00624, 0.00537, + 0.00463, 0.00400, 0.00346, 0.00299, 0.00260, 0.00226, 0.00197, 0.00171, 0.0015, 0.00132, 0.00116, 0.00103, + 5.7e-4, 3.1e-4, 1.6e-4, 8.3e-4, 4.0e-5, 1.8e-5, 8.2e-6, 3.4e-6, 7.5e-7, 5.6e-7 + ]).to(self.device) + + # simulation domain + self.nx = nx + self.ny = ny + self.nz = nz + self.dx = dx + self.dy = dy + self.dz = 1. / nz + self.x0 = x0 + self.y0 = y0 + + def init(self, ps, pt, zs): + self.Y = ( + torch.linspace(0, self.ny, self.ny + 1, device=self.device) * self.dy + self.y0 + ).view(1, self.ny + 1, 1) + self.deltaA = self.RE**2 * torch.cos(self.bar_y(self.Y) * 0.5).to(self.device) * self.dx * self.dy # (1, ny, 1) + self.f = 2 * self.OMEGA * torch.sin(self.bar_y(self.Y)) * torch.cos(self.bar_y(self.Y)) * self.RE # (nz, ny, nx) + + # vertical grids + pt = torch.tensor([pt]).view(1, 1, 1).to(self.device) + zt = self.hight_from_pressure(pt) + z = torch.linspace(1, 0, self.nz + 1, device=self.device).view(-1, 1, 1) * zt + p_ = self.pressure_from_hight(z) + self.sigma = (p_ - pt) / (p_[-1] - pt) # (nz + 1, 1, 1) + + # column pressure, with shape (1, ny, nx) + pi = (ps - pt).view(1, self.ny, self.nx) + + # potential temperature factor + p_ = pt + self.sigma * pi # (nz + 1, ny, nx) + self.P_ = (p_ / self.PSEA)**self.KAPPA # (nz + 1, ny, nx) + self.P = self.delta_z(p_ * self.P_) / self.delta_z(p_) / (1 + self.KAPPA) # (nz, ny, nx) + + # potential temperature + p = self.PSEA * self.P**(1 / self.KAPPA) + T = self.temperature_from_pressure(p) + theta = T / self.P + + # geopotential (nz, ny, nx) + self.phi = torch.zeros((self.nz, self.ny, self.nx), device=self.device) + self.zs = zs + + # vertical velocity + self.w = torch.zeros((self.nz + 1, self.ny, self.nx), device=self.device) + + return pi, theta + + def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): + # flux + F = self.bar_x(self.pad_x(pi)) * 0.5 * u * self.RE * self.dy # (nz, ny, nx + 1) + G = self.bar_y(self.pad_y(pi)) * 0.5 * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) + B = self.bar_y2(self.bar_x(self.pad_y(F))) / 12. # (nz, ny, nx) + C = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(G)))) / 12. # (nz, ny + 1, nx + 1) + D = self.bar_y2(self.pad_y(G)) + self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) + E = self.bar_y2(self.pad_y(G)) - self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) + Q = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(F)))) / 12. # (nz, ny + 1, nx + 1) + R = self.bar_x2(self.bar_y(self.pad_x(G))) / 12. # (nz, ny, nx) + S = self.bar_y(self.bar_x(self.pad_x(G))) + self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) + T = self.bar_y(self.bar_x(self.pad_x(G))) - self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) + + pi1 = pi0 - dt / self.deltaA * ((self.delta_x(F) + self.delta_y(G)) * self.dz).sum(axis=0) # (nz, ny, nx) + + # print('pi:', pi1.mean()) + + # update diagnostic variable w (nz + 1, ny, nx) + for i in range(1, self.nz + 1): + self.w[i] = - ((self.delta_x(F[:i]) + self.delta_y(G[:i])) * self.dz).sum(axis=0) / self.deltaA / pi1 \ + - self.sigma[i] * (pi1 - pi0) / dt / pi1 + + # print('w:', self.w.mean()) + + # update potential temperature theta (nz, ny, nx) + theta_ = self.pad_z( + (self.bar_z(self.P * theta) - self.delta_z(theta) * self.P_[1:-1]) / self.delta_z(self.P) + ) # (nz + 1, ny, nx) + theta1 = pi0 / pi1 * theta0 + dt / self.deltaA / pi1 * ( + (self.delta_x(F * self.bar_x(self.pad_x(theta))) + self.delta_y(G * self.bar_y(self.pad_y(theta)))) / 2. + + pi * self.deltaA * self.delta_z(self.w * theta_) / self.dz + + 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(theta)))) + ) + + # print('theta:', theta1.mean()) + + # update geopotential + self.phi[-1] = self.g * self.zs - self.CPD * (self.P[-1] - self.P_[-1]) * theta[-1] + for i in range(1, self.nz): + tmp = self.phi[-i] - self.CPD * (self.P_[-i - 1] - self.P[-i]) * theta[-i] + self.phi[-1 - i] = tmp - self.CPD * (self.P[-1 - i] - self.P_[-1 - i]) * theta[-1 - i] + + # print('phi:', self.phi.mean()) + + # update u (nz, ny, nx + 1) + pi0_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi0 * self.deltaA)))) / 8. # (nz, ny, nx + 1) + pi1_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi1 * self.deltaA)))) / 8. # (nz, ny, nx + 1) + pi_w_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny, nx + 1) + advec = ( + - self.delta_x(self.pad_x(B * self.bar_x(u))) + - self.delta_y(C * self.bar_y(self.pad_y(u))) + + self.delta_D(self.pad_x(D * self.bar_xy(self.pad_y(u)))) + + self.delta_E(self.pad_x(E * self.bar_xy(self.pad_y(u)))) + ) / 2. + trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(u)) * 0.5) / self.dz + press = - self.RE * self.dy * ( + self.delta_x(self.pad_x(self.phi)) * self.delta_x(self.pad_x(pi)) / 2. + + self.delta_x(self.pad_x(pi)) * 0.5 * self.CPD * self.bar_x(self.pad_x( + theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) + )) + ) + diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(u)))) + cori = self.RE * self.dx * self.dy * 0.25 * ( + self.bar_x(self.pad_x(pi * self.bar_y(v) * (self.f + 0. * self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) + ) * 0.0 + u1 = (pi0_deltaA * u0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA + # print('u1:', u1.mean()) + + # update v (nz, ny + 1, nx) + pi0_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi0 * self.deltaA)))) / 8. # (nz, ny + 1, nx) + pi1_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi1 * self.deltaA)))) / 8. # (nz, ny + 1, nx) + pi_w_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny + 1, nx) + advec = ( + - self.delta_x(Q * self.bar_x(self.pad_x(v))) + - self.delta_y(self.pad_y(R * self.bar_y(v))) + + self.delta_D(self.pad_y(S * self.bar_xy(self.pad_x(v)))) + + self.delta_E(self.pad_y(T * self.bar_xy(self.pad_x(v)))) + ) / 2. + trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(v)) * 0.5) / self.dz + press = - self.RE * self.dx * ( + self.delta_y(self.pad_y(self.phi)) * self.delta_y(self.pad_y(pi)) / 2. + + self.delta_y(self.pad_y(pi)) * 0.5 * self.CPD * self.bar_y(self.pad_y( + theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) + )) + ) + diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(v)))) + cori = - self.RE * self.dx * self.dy * 0.25 * ( + self.bar_y(self.pad_y(pi * self.bar_x(u) * (self.f + self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) + ) * 0.0 + v1 = (pi0_deltaA * v0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA + # print('v1:', v1.mean()) + + return pi1, theta1, u1, v1 + + def forward(self, dt, pi, theta, u, v): + pi_, theta_, u_, v_ = self.step(dt / 2, pi, theta, u, v, pi, theta, u, v) + return self.step(dt, pi_, theta_, u_, v_, pi, theta, u, v) + + def hight_from_pressure(self, p): + ind0 = torch.abs((p[None] - self.pressure_profile[(..., ) + (None, ) * len(p.shape)])).argmin(axis=0) + ind1 = (p > self.pressure_profile[ind0]) * 2 - 1 + ind0 + hight = (self.hight_profile[ind1] - self.hight_profile[ind0]) * (p - self.pressure_profile[ind0]) / ( + self.pressure_profile[ind1] - self.pressure_profile[ind0]) + self.hight_profile[ind0] + return hight + + def pressure_from_hight(self, z): + ind0 = torch.abs((z[None] - self.hight_profile[(..., ) + (None, ) * len(z.shape)])).argmin(axis=0) + ind1 = (self.hight_profile[ind0] > z) * 2 - 1 + ind0 + p = (self.pressure_profile[ind1] - self.pressure_profile[ind0]) * (z - self.hight_profile[ind0]) / \ + (self.hight_profile[ind1] - self.hight_profile[ind0]) + self.pressure_profile[ind0] + return p + + def temperature_from_pressure(self, p): + ind0 = torch.abs((p[None] - self.pressure_profile[(..., ) + (None, ) * len(p.shape)])).argmin(axis=0) + ind1 = (p > self.pressure_profile[ind0]) * 2 - 1 + ind0 + T = (self.temperature_profile[ind1] - self.temperature_profile[ind0]) * (p - self.pressure_profile[ind0]) / ( + self.pressure_profile[ind1] - self.pressure_profile[ind0]) + self.temperature_profile[ind0] + return T + + def pad_x(self, X): + return F.pad(X, (1, 1), "circular") + + def bar_x(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) + + def bar_x2(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 2., 1.], device=X.device).view(1, 1, 1, 1, 3) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 2) + + def delta_x(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) + + def pad_y(self, X): + nz, ny, nx = X.shape + return F.pad(X.view(1, nz, ny, nx), (0, 0, 1, 1), "circular").view(nz, ny + 2, nx) + + def bar_y(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) + + def bar_y2(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 2., 1.], device=X.device).view(1, 1, 1, 3, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 2, nx) + + def delta_y(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 2, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) + + def bar_z(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) + + def pad_z(self, X): + nz, ny, nx = X.shape + return F.pad(X.view(1, 1, nz, ny, nx), (0, 0, 0, 0, 1, 1)).view(nz + 2, ny, nx) + + def delta_z(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) + + def delta_D(self, X): + nz, ny, nx = X.shape + filter = torch.tensor( + [[1., 0.], + [0., -1.]], + device=X.device + ).view(1, 1, 1, 2, 2) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + + def delta_E(self, X): + nz, ny, nx = X.shape + filter = torch.tensor( + [[0., 1.], + [-1., 0.]], + device=X.device + ).view(1, 1, 1, 2, 2) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + + def bar_xy(self, X): + nz, ny, nx = X.shape + filter = torch.tensor( + [[1., 0.], + [0., 1.]], + device=X.device + ).view(1, 1, 1, 2, 2) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + + def laplas(self, X): + nz, ny, nx = X.shape + filter = torch.tensor( + [[[0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.]], + [[0., 1., 0.], + [1., -6, 1.], + [0., 1., 0.]], + [[0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.]]], + device=X.device + ).view(1, 1, 3, 3, 3) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 2, ny - 2, nx - 2) + + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + nz = 15 + ny = 100 + nx = 100 + dy = 1e-4 + dx = 1e-4 + x0 = 0.0 + y0 = 0.2 + + atmosphere = Atmoshpere(nz, ny, nx, dy, dx, x0, y0) + + xc = nx * dx / 2 + x0 + yc = ny * dy / 2 + y0 + X = torch.linspace(0, nx - 1, nx).view(1, 1, nx).cuda() * dx + x0 + Y = torch.linspace(0, ny - 1, ny).view(1, ny, 1).cuda() * dy + y0 + ps = torch.ones((1, ny, nx)).cuda() * atmosphere.PSEA - 300 * torch.exp( + - 1e-6 * ((atmosphere.RE * torch.cos((Y + yc) / 2)) * (X - xc))**2 + - 1e-6 * (atmosphere.RE * (Y - yc))**2) + pt = 250e2 + zs = torch.zeros((ny, nx)).cuda() + 10000 * torch.exp( + - 1e-6 * (atmosphere.RE * (X - nx * dx / 3 - x0))**2 + - 1e-6 * (atmosphere.RE * (Y - yc))**2) + + pi, theta = atmosphere.init(ps, pt, zs) + + u = torch.zeros((nz, ny, nx + 1)).cuda() + v = torch.zeros((nz, ny + 1, nx)).cuda() + + for i in range(100): + pi, theta, u, v = atmosphere(1., pi, theta, u, v) + + # ctf = plt.contourf(pi.view(ny, nx).numpy(), levels=50, cmap='jet') + plt.cla() + ct = plt.contour(zs.view(ny, nx).cpu().numpy(), levels=[7000]) + ctf = plt.contourf(u[3].cpu().numpy(), levels=50, cmap='jet') + plt.colorbar(ctf) + # plt.grid(True) + plt.tight_layout() + plt.savefig(f'res2/res{i}.jpeg', dpi=300) + plt.clf() + + print(i) \ No newline at end of file From fbc77115da5d5c551095b6e4eb3c066429db5f81 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Mar 2022 18:31:52 +0800 Subject: [PATCH 0613/1892] fix scalar tensor --- cube/logics/dataloader.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cube/logics/dataloader.py b/cube/logics/dataloader.py index cc8871b1..49a0c043 100644 --- a/cube/logics/dataloader.py +++ b/cube/logics/dataloader.py @@ -21,7 +21,11 @@ def __init__(self, dataloader: CubeDataLoader, dtype_map): for data in datas: if torch.is_tensor(data): self.dtypes.append(dtype_map.map(data.dtype)) - self.shapes.append(list(data.shape)) + shape = tuple(data.shape) + # special handler for scalar tensor shape + if len(shape) == 0: + shape = (1,) + self.shapes.append(shape) else: raise NotImplementedError("Data should be torch.Tensor") From a51d1db71fa4794dfde82a5909c9a09c32465c94 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 4 Mar 2022 11:22:28 +0800 Subject: [PATCH 0614/1892] support arbitrary device number --- handcraft/pipeline/dummy.py | 15 ++++++--- handcraft/pipeline/schedule.py | 57 ++++++++++++++++++++++++---------- 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py index 739d4e05..636d73d3 100644 --- a/handcraft/pipeline/dummy.py +++ b/handcraft/pipeline/dummy.py @@ -5,9 +5,12 @@ --nproc_per_node=4 \ --nnodes=1 \ handcraft/pipeline/dummy.py --use-naive + +OMP_NUM_THREADS=4 python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + handcraft/pipeline/dummy.py --use-naive """ -from sys import argv -from threading import activeCount import torch import torch.nn.functional as F import cube @@ -61,6 +64,8 @@ def forward(self, input): help='use naive pipeline') parser.add_argument('--use-tp1f1b', action='store_true', help='use tensor parallel 1f1b') + parser.add_argument('--nmb', type=int, default=4, + help='num of micro batch') args = parser.parse_args() assert args.use_naive ^ args.use_tp1f1b, "Specify (only) 1 way pipeline" @@ -68,8 +73,8 @@ def forward(self, input): rank = DeviceGroup().rank dim = 2048 - gbs = 512 mbs = 8 + gbs = mbs * args.nmb # tp 1f1b if args.use_tp1f1b: @@ -97,9 +102,9 @@ def forward(self, input): CudaTimer(enable=True).start('e2e') if args.use_tp1f1b: - schedule_tp_1f1b(model, first_stage_model, dataloader, gbs // mbs) + schedule_tp_1f1b(model, first_stage_model, dataloader, args.nmb, DeviceGroup().world_size) if args.use_naive: - schedule_naive(model, dataloader, gbs // mbs) + schedule_naive(model, dataloader, args.nmb) if step >= 20: CudaTimer().stop('e2e') diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index 1179ee49..6c8ce99e 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -85,7 +85,8 @@ def schedule_naive(model, dataloader, num_microbatch: int): def schedule_tp_1f1b(model: torch.nn.Module, first_stage_model: torch.nn.Module, dataloader, - num_microbatch: int): + num_microbatch: int, + num_stage: int): rank = DeviceGroup().rank next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size @@ -96,13 +97,16 @@ def schedule_tp_1f1b(model: torch.nn.Module, input_1st_tensors = list() output_1st_tensors = list() + gather_list = list(range(num_stage)) + gather_list[0], gather_list[1] = gather_list[1], gather_list[0] + def tp_forward(fmodel, dataloader) -> torch.Tensor: input = next(dataloader) output = forward_step(fmodel, input) input_1st_tensors.append(input) output_1st_tensors.append(output) # gather - outputs = coll.gather([output], None, None, [1,0,2,3]) + outputs = coll.gather([output], None, None, gather_list) if rank == 1: with torch.no_grad(): outputs[0], outputs[1] = outputs[1], outputs[0] @@ -115,21 +119,20 @@ def tp_forward(fmodel, dataloader) -> torch.Tensor: def tp_backward(grad: torch.Tensor): if rank == 1: with torch.no_grad(): - grads = list(grad.chunk(4, dim=-1)) + grads = list(grad.chunk(num_stage, dim=-1)) grads[0], grads[1] = grads[1], grads[0] else: grads = None input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) - grad_1st = coll.scatter(grads, [output_1st.size()], [output_1st.dtype], [1,0,2,3]) + grad_1st = coll.scatter(grads, [output_1st.size()], [output_1st.dtype], gather_list) backward_step([input_1st], [output_1st], [grad_1st])[0] - def recv_grad(output: torch.Tensor): - if is_last_stage(): - return None - return coll.recv(output.size(), next_rank, output.dtype) - - fofst = [0, 0, 0, -1][rank] - bofst = [0, -2, -2, -1][rank] + fofst = [0] + [-(step // 2) for step in range(num_stage-1)] + bofst = [0] + [-(num_stage - 2 - (step // 2)) for step in range(num_stage-1)] + # print(fofst) + # print(bofst) + fofst = fofst[rank] + bofst = bofst[rank] last_backward = None last_forward = None for step in range(num_microbatch + 2): @@ -148,7 +151,7 @@ def recv_grad(output: torch.Tensor): if rank == 0: pass - if rank == 2: + if rank != 0 and rank % 2 == 0: # inter-barrier if do_backward and last_forward is not None: # print(f'rank {rank} recv backward grad + send forward output ') @@ -222,7 +225,7 @@ def recv_grad(output: torch.Tensor): input_grad = backward_step([input], [output], [output_grad])[0] last_backward = input_grad - if rank == 3: + if rank != 1 and rank % 2 == 1: # inter-barrier if do_forward and last_backward is not None: @@ -238,14 +241,36 @@ def recv_grad(output: torch.Tensor): # print(f'rank {rank} send backward grad ') coll.send(last_backward, prev_rank) - # backward + forward + # forward if do_forward: output = forward_step(model, input) - input_grad = backward_step([input], [output], [None,])[0] + input_tensors.append(input) + output_tensors.append(output) + + # intra-barrier send recv + output_grad = None + if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): + # send forward recv backward + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [output], [output.size()], [output.dtype], + [next_rank], [next_rank] + )[0] + elif do_forward and not is_last_stage(): + # print(f'rank {rank} send forward output ') + coll.send(output, next_rank) + elif do_backward and not is_last_stage(): + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + + # backward + forward + if do_backward: + input, output = input_tensors.pop(0), output_tensors.pop(0) + input_grad = backward_step([input], [output], [output_grad])[0] last_backward = input_grad # step3: tp backward - if 0 <= (step-2) and (step-2) <= num_microbatch - 1: + if 0 <= (step-num_stage+2) and (step-num_stage+2) <= num_microbatch - 1: # print(f'rank {rank} backward tp model ') tp_backward(last_backward) From 98b1e12b2e914879fcd71d271928beca1ac17a21 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 4 Mar 2022 16:32:15 +0800 Subject: [PATCH 0615/1892] update with embedding layer --- handcraft/pipeline/dummy.py | 103 +++++++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 24 deletions(-) diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py index 636d73d3..2ea3de28 100644 --- a/handcraft/pipeline/dummy.py +++ b/handcraft/pipeline/dummy.py @@ -16,45 +16,96 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary from cube.runtime.device import DeviceGroup -from cube.runtime.syndata import SynDataLoader +from cube.runtime.syndata import SynDataLoader, SynTextDataLoader from handcraft.pipeline.schedule import schedule_tp_1f1b, schedule_naive import argparse +""" +Stage0: + Embedding [M, 1], [N, E] -> [M, E] + Linear [M, E], [E, E] -> [M, E] + +Stage Else: + Linear [M, E], [E, E] -> [M, E] + +Condition: N > 8M - E +""" + +class ReduceEmbed(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + torch.distributed.all_reduce(input) + return input + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + class DummyModel(torch.nn.Module): - def __init__(self, dim: int, bs: int, stage_id: int, sharding=False): + def __init__(self, M: int, N: int, E: int, stage_id: int, sharding=False): super().__init__() - self.bs = bs - self.dim = dim + self.M = M + self.N = N + self.E = E self.is_last_stage = stage_id == DeviceGroup().world_size - 1 - if sharding: - chunk_num = torch.distributed.get_world_size() - self.weight = torch.nn.Parameter(torch.zeros((dim // chunk_num, dim))) + self.is_first_stage = stage_id == 0 + self.sharding = sharding + # first stage + chunk_num = torch.distributed.get_world_size() if sharding else 1 + if self.is_first_stage: + self.vocab_start_index = N // chunk_num * stage_id + self.vocab_end_index = N // chunk_num * (stage_id + 1) + self.embed_weight = torch.nn.Parameter(torch.zeros((N // chunk_num, E))) + self.fc_weight = torch.nn.Parameter(torch.zeros((E // chunk_num, E))) else: - self.weight = torch.nn.Parameter(torch.zeros((dim, dim))) + self.fc_weight = torch.nn.Parameter(torch.zeros((E // chunk_num, E))) def input_shape(self): - return (self.bs, self.dim, self.dim) + if self.is_first_stage: + return (self.M,) + else: + return (self.M, self.E) def output_shape(self): - return (1,) if self.is_last_stage else (self.bs, self.dim, self.dim) + if self.is_last_stage: + return (1,) + else: + return (self.M, self.E) def input_dtype(self): - return torch.float32 + if self.is_first_stage: + return torch.int64 + else: + return torch.float32 def output_dtype(self): return torch.float32 - def forward(self, input): - output = F.linear(input, self.weight) + def forward(self, input: torch.Tensor): + if self.is_first_stage: + if self.sharding: + mask = (input < self.vocab_start_index) | \ + (input >= self.vocab_end_index) + input = input.clone() - self.vocab_start_index + input[mask] = 0 + input = F.embedding(input, self.embed_weight) + input = ReduceEmbed.apply(input) + else: + input = F.embedding(input, self.embed_weight) + + output = F.linear(input, self.fc_weight) + if self.is_last_stage: output = torch.sum(output) return output - if __name__ == '__main__': @@ -66,31 +117,34 @@ def forward(self, input): help='use tensor parallel 1f1b') parser.add_argument('--nmb', type=int, default=4, help='num of micro batch') + parser.add_argument('--M', type=int, default=4096, + help='M dimension length = sequence length') + parser.add_argument('--N', type=int, default=50257, + help='word number') + parser.add_argument('--E', type=int, default=2048, + help='E dimension length = hidden dimension length') args = parser.parse_args() assert args.use_naive ^ args.use_tp1f1b, "Specify (only) 1 way pipeline" + print(args) cube.init() rank = DeviceGroup().rank - dim = 2048 - mbs = 8 - gbs = mbs * args.nmb - # tp 1f1b if args.use_tp1f1b: - first_stage_model = DummyModel(dim, mbs, 0, sharding=True).cuda() + first_stage_model = DummyModel(args.M, args.N, args.E, 0, sharding=True).cuda() if rank == 0: model = None else: - model = DummyModel(dim, mbs, rank, sharding=False).cuda() + model = DummyModel(args.M, args.N, args.E, rank, sharding=False).cuda() if args.use_naive: # naive pipleline - model = DummyModel(dim, mbs, rank, sharding=False).cuda() + model = DummyModel(args.M, args.N, args.E, rank, sharding=False).cuda() - dataloader = SynDataLoader( - shapes=([mbs, dim, dim],), - dtypes=(torch.float32, ), + dataloader = SynTextDataLoader( + shapes=([args.M],), + dtypes=(torch.int64, ), batch_dims=(0,), length=128000 ) @@ -113,3 +167,4 @@ def forward(self, input): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-20, field_name='e2e'))) + memory_summary() From c5feedb8e45a8c912c67b4c146f602ad046c979e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 4 Mar 2022 17:13:34 +0800 Subject: [PATCH 0616/1892] add test script --- handcraft/pipeline/dummy.py | 14 +++++++++++++ handcraft/pipeline/run.sh | 40 +++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100755 handcraft/pipeline/run.sh diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py index 2ea3de28..a1e8e9e0 100644 --- a/handcraft/pipeline/dummy.py +++ b/handcraft/pipeline/dummy.py @@ -46,6 +46,18 @@ def backward(ctx, grad_output): return grad_output +class IdentityFoward(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + torch.distributed.all_reduce(grad_output) + return grad_output + + class DummyModel(torch.nn.Module): @@ -101,6 +113,8 @@ def forward(self, input: torch.Tensor): else: input = F.embedding(input, self.embed_weight) + if self.sharding: + input = IdentityFoward.apply(input) output = F.linear(input, self.fc_weight) if self.is_last_stage: diff --git a/handcraft/pipeline/run.sh b/handcraft/pipeline/run.sh new file mode 100755 index 00000000..f68d3370 --- /dev/null +++ b/handcraft/pipeline/run.sh @@ -0,0 +1,40 @@ +# 4 gpus + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-tp1f1b --nmb 4 > 4dev4nmb-tp1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-tp1f1b --nmb 8 > 4dev8nmb-tp1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-tp1f1b --nmb 64 > 4dev64nmb-tp1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-naive --nmb 4 > 4dev4nmb-naive.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-naive --nmb 8 > 4dev8nmb-naive.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-naive --nmb 64 > 4dev64nmb-naive.txt + +# 8 gpus + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-tp1f1b --nmb 8 > 8dev8nmb-tp1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-tp1f1b --nmb 16 > 8dev16nmb-tp1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-tp1f1b --nmb 128 > 8dev128nmb-tp1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-naive --nmb 8 > 8dev8nmb-naive.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-naive --nmb 16 > 8dev16nmb-naive.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-naive --nmb 128 > 8dev128nmb-naive.txt + From 7204341bca3f7280c1383645c5892a32634fce43 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 4 Mar 2022 20:37:14 +0800 Subject: [PATCH 0617/1892] fix scalar elementwise op with python types --- cube/graph/operator/function/einops.py | 84 ++++++++++++-------------- 1 file changed, 40 insertions(+), 44 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 89136a05..8b7a4df8 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -92,7 +92,7 @@ def __init__(self, name: Union[str, List[str]]): if n[-1] == EinDim.ReduceType.Sum.value: reduce = EinDim.ReduceType.Sum n = n[:-1] - elif n[-1] == EinDim.ReduceType.Stay: + elif n[-1] == EinDim.ReduceType.Stay.value: reduce = EinDim.ReduceType.Stay n = n[:-1] # get identifier name @@ -107,6 +107,9 @@ def __init__(self, name: Union[str, List[str]]): @property def name(self) -> str: + """ + Return identifier without reduce + """ if len(self._name) == 1: return self._name[0] return '(' + ' '.join(self._name) + ')' @@ -134,7 +137,7 @@ def is_reduce(self): def __repr__(self): name_reduce = [name + reduce.value for name, reduce in zip(self._name, self._reduce)] - if len(self._name) == 0: + if len(self._name) == 1: return self._name[0] + self._reduce[0].value return '(' + ' '.join(name_reduce) + ')' @@ -272,9 +275,10 @@ def infer_shape(self) -> bool: raise RuntimeError("No matching anno for given annos") dimlen: Dict[str, int] = dict() for input, ishape in zip(self.inputs(), self._iannos): - if not ((ishape is None and not isinstance(input, IRTensor)) or - len(ishape) == len(input.shape)): - raise RuntimeError(f"node {self._id} {self.signature}: error match input: {input.shape} and einshape: {ishape}") + if not isinstance(input, IRTensor): + continue + if len(ishape) != len(input.shape): + raise RuntimeError(f"node {self._id} {self.signature}: error match input: {input.shape} and ein_shape: {ishape}") for tdim, edim in zip(input.shape, ishape): if len(edim._name) == 1: if edim.name in dimlen and dimlen[edim.name] != tdim: @@ -337,55 +341,47 @@ def parse(self, anno: EinopAnno): if len(anno.inputs) != len(self.inputs()): return False, None, None identifiers = anno.identifiers() + # expand * expand_dims = None if '*' in identifiers: - for idx in range(len(anno.inputs)): - shape = anno.inputs[idx] - shape_anno = [dim.name for dim in shape] - if '*' in shape_anno: - start = shape_anno.index('*') - span = len(self.inputs(idx).shape) - len(shape) + 1 - if span <= 0: - if expand_dims is None: - expand_dims = list() - if len(expand_dims) > 0: - return False, None, None - anno.inputs[idx].remove(EinDim('*')) - if span > 0: - if expand_dims is None: - expand_dims = list() - unused_annos = [c for c in string.ascii_lowercase if c not in identifiers] - if len(unused_annos) < span: - raise RuntimeError("Too many introduced dimensions") - for dim in range(span): - if '*' not in anno.anno.split('->')[-1]: - anno_dim = EinDim([unused_annos[dim] + '+']) - else: - anno_dim = EinDim([unused_annos[dim]]) - expand_dims.append(anno_dim) - if len(expand_dims) != span: - return False, None, None - anno.inputs[idx] = anno.inputs[idx][:start] + expand_dims + anno.inputs[idx][start+1:] - for idx in range(len(anno.outputs)): - shape = anno.outputs[idx] - shape_anno = [dim.name for dim in shape] - if '*' in shape_anno: + # names + in_names = [[e.name for e in input] for input in anno.inputs] + out_names = [[e.name for e in out] for out in anno.outputs] + spatial = all(['*' in names for names in out_names]) + candicates = [c if spatial else c + '^' for c in string.ascii_lowercase if c not in identifiers] + # go through inputs + for idx, (names, input) in enumerate(zip(in_names, self.inputs())): + if '*' in names: + if not isinstance(input, IRTensor): + return False, None, None + pos = names.index('*') + span = len(self.inputs(idx).shape) - (len(names) - 1) + if expand_dims is not None and len(expand_dims) != span: + return False, None, None if expand_dims is None: - raise RuntimeError("* should appear in inputs") - start = shape_anno.index('*') - span = len(expand_dims) - anno.outputs[idx] = anno.outputs[idx][:start] + expand_dims + anno.outputs[idx][start+1:] + expand_dims = [] + if span > 0: + expand_dims = [EinDim(candicates[dim]) for dim in range(span)] + anno.inputs[idx] = anno.inputs[idx][:pos] + expand_dims + anno.inputs[idx][pos+1:] + # * should appear in inputs + if expand_dims is None: + return False, None, None + # go through outputs + for idx, names in enumerate(out_names): + if '*' in names: + pos = names.index('*') + anno.outputs[idx] = anno.outputs[idx][:pos] + expand_dims + anno.outputs[idx][pos+1:] # check dimension consistency dimlen: Dict[str, int] = dict() - for shape, input in zip(anno.inputs, self.inputs()): + for eshape, input in zip(anno.inputs, self.inputs()): if not isinstance(input, IRTensor): - if not (len(shape) != 1 and shape[0].name != '1'): + if not (len(eshape) == 1 and eshape[0].name == '1'): return False, None, None else: - if len(input.shape) != len(shape): + if len(input.shape) != len(eshape): return False, None, None - for edim, nele in zip(shape, input.shape): + for edim, nele in zip(eshape, input.shape): if edim.name in dimlen: if nele != dimlen[edim.name]: return False, None, None From 36b3df1b69a54a3f72b7f8c459895ca972b185ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 6 Mar 2022 12:44:30 +0800 Subject: [PATCH 0618/1892] fix compiler return single value bug --- cube/compiler.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cube/compiler.py b/cube/compiler.py index 63db166d..dcb92f7e 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -7,6 +7,7 @@ from cube.graph import parser from cube.graph.adapter.gen import AdapterGener +from cube.graph.graph import IRGraph from cube.graph.operator.operator import IRDataOperation from cube.logics.pool import SchedulePool @@ -142,6 +143,8 @@ def decorator(fn: Callable) -> Callable: outputs = fn(model_graph, ir_dataloader) if outputs is None: outputs = [] + elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = [outputs] graph = LogicTranslator.gen_logic_graph(outputs=outputs) if len(PAS) == 1: @@ -152,6 +155,9 @@ def decorator(fn: Callable) -> Callable: graph = A(graph, resource) graph = S(graph, resource) + if not isinstance(graph, IRGraph): + raise RuntimeError("Expected policy return IRGraph") + # check assignment and order for node in graph.nodes(): if len(node.device) == 0: @@ -179,7 +185,8 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on grouping adapters : {:.2f} s'.format(span)) - # execplan.draw(outfile='execplan.png') + execplan.graph.reset_dependency() + # execplan.analyze(outfile='execplan.png') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() From e3676157c8afdcfc6f73dc106f7c4e685c3b38fd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 6 Mar 2022 13:09:19 +0800 Subject: [PATCH 0619/1892] support torch.size --- cube/graph/parser/parser.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index b3ebd8f0..ea1ebf76 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -36,6 +36,8 @@ def parse_module(module, The overall entry to parse a torchscript graph module """ frame.push() + print(module.graph) + print(module.code) # handle graph input -- Assuming all the inputs are tensors input_var_name = [input.debugName() for input in module.graph.inputs()] @@ -254,11 +256,16 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: outputs = [output for output in node.outputs()] # handle inputs: - input_val = list() - for input in inputs: - var_name = input.debugName() - val = frame.get_var(var_name) - input_val.append(val) + input_val = [frame.get_var(input.debugName()) for input in inputs] + + # special handling on aten::size(tensor: tensor, dim: int) + if fsig == 'torch.size': + assert len(inputs) == 2 and len(outputs) == 1, \ + "Expected 2 inputs and 1 outputs for torch.size" + tensor, dim = input_val + output: int = tensor.shape[dim] + frame.add_var(outputs[0].debugName(), output) + return [] # create IR node ir_node = Sign2Op.map(fsig)(inputs=input_val) From 6002cd94f04895067b725e0569f9d4cecf7627a9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 6 Mar 2022 14:43:03 +0800 Subject: [PATCH 0620/1892] fix parsing bug of brackets --- cube/graph/operator/function/einops.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 8b7a4df8..4b2bf310 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -191,20 +191,22 @@ def parse_shape(self, shape: str) -> List[EinDim]: if bracket_group: raise RuntimeError("brackets inside brackets not allowed") bracket_group = True - identifiers.append(current_identifier) + if len(current_identifier) > 0: + identifiers.append(current_identifier) current_identifier = list() elif w == ')': if not bracket_group: raise RuntimeError("backets are not balanced at (") bracket_group = False - identifiers.append(current_identifier) + if len(current_identifier) > 0: + identifiers.append(current_identifier) current_identifier = list() else: if bracket_group: current_identifier.append(w) self._identifiers.add(w) else: - if len(current_identifier) != 0: + if len(current_identifier) > 0: identifiers.append(current_identifier) current_identifier = [w] self._identifiers.add(w) @@ -280,7 +282,7 @@ def infer_shape(self) -> bool: if len(ishape) != len(input.shape): raise RuntimeError(f"node {self._id} {self.signature}: error match input: {input.shape} and ein_shape: {ishape}") for tdim, edim in zip(input.shape, ishape): - if len(edim._name) == 1: + if len(edim.names()) == 1: if edim.name in dimlen and dimlen[edim.name] != tdim: raise RuntimeError(f"op: {self.signature} has different shape for same dim annotation {edim.name}") dimlen[edim.name] = tdim @@ -334,7 +336,7 @@ def new(self, inputs: List, outputs: List): op.set_output(idx, output) return op - def parse(self, anno: EinopAnno): + def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[EinDim]]]: """ parse annotations, assuming input tensor shape is given """ From fc32336a4602366765d638c4dc2600a6f858b8f0 Mon Sep 17 00:00:00 2001 From: lynex Date: Sun, 6 Mar 2022 16:06:27 +0800 Subject: [PATCH 0621/1892] Sum add dim support --- cube/graph/operator/function/function.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index b65be0d7..4842ea90 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -102,6 +102,10 @@ def Sum(signature, inputs): dim = inputs[1] if dim is not None: keepdim = inputs[2] if len(inputs) > 2 else False + dim_len = len(tensor[0].shape) + anno = "".join([f'b{i} ' for i in range(dim_len)]) + " -> " + "".join([f'b{i} ' if i not in dim else "" for i in range(dim_len)]) + annos.append(anno) + # print("### Sum::anno = {}", annos) return IREinops(signature, annos, tensor, 'sum', dim=dim, keepdim=keepdim) else: From 9e656164e86deeb85092a870ce0dc407414a236b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 6 Mar 2022 16:13:04 +0800 Subject: [PATCH 0622/1892] add view and reshape --- cube/graph/operator/function/function.py | 212 +++++++++++++++++++++-- cube/graph/parser/mapping.py | 7 +- 2 files changed, 206 insertions(+), 13 deletions(-) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 4842ea90..7fbb33a3 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -1,3 +1,8 @@ +from typing import Iterable, List, Optional, Union, Dict +import string +import copy + +from cube.ir.cten import IRTensor from cube.graph.operator.function.einops import EinDim, EinopAnno, IREinops from cube.graph.operator.function.conv import IRConv2D from cube.graph.operator.function.conv import IRConv3D @@ -5,6 +10,53 @@ from cube.graph.operator.function.scripteinops import IRScriptEinOps +def _create_eshape(shape: List[int], iterator: Optional[Iterable] = None, + reduce: EinDim.ReduceType = EinDim.ReduceType.Spatial) -> List[str]: + """ + Create dimension annotation given the shape and + letter iterator + """ + if iterator is None: + iterator = iter(string.ascii_lowercase) + return [next(iterator) + reduce.value for _ in range(len(shape))] + + +def _create_anno(ins: List[List[Union[str, List[str]]]], + ous: List[List[Union[str, List[str]]]]) -> str: + """ + Create annotation string + e.g., + ins = [ ['a', 'b', 'c+'], ['c+', ['d', 'e']] ] + ous = [ ['a', 'b', 'd', 'e'] ] + => + 'a b c+, c+ (d e) -> a b d e' + """ + in_annos = list() + ou_annos = list() + for shape in ins: + flatten = list() + for edim in shape: + if isinstance(edim, str): + flatten.append(edim) + # List + elif len(edim) == 1: + flatten.append(edim[0]) + else: + flatten.append('(' + ' '.join(edim) + ')') + in_annos.append(' '.join(flatten)) + for shape in ous: + flatten = list() + for edim in shape: + if isinstance(edim, str): + flatten.append(edim) + # List + elif len(edim) == 1: + flatten.append(edim[0]) + else: + flatten.append('(' + ' '.join(edim) + ')') + ou_annos.append(' '.join(flatten)) + return ', '.join(in_annos) + ' -> ' + ', '.join(ou_annos) + def Linear(signature, inputs): annos = [ @@ -24,24 +76,56 @@ def BatchLinear(signature, inputs): def Add(signature, inputs): assert len(inputs) == 3 inputs, alpha = inputs[0:2], inputs[2] - # TODO: support broadcast annos = [ '*, 1 -> *', '1, * -> *', '*, * -> *', ] + # broadcast + lhs, rhs = inputs + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ + len(lhs.shape) == len(rhs.shape): + if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): + # TODO: support spatial partitioning on broadcast dim + lshape = _create_eshape(lhs.shape) + rshape = copy.copy(lshape) + oshape = copy.copy(lshape) + for dim in range(len(lhs.shape)): + if lhs.shape[dim] < rhs.shape[dim]: + oshape[dim] = rshape[dim] + lshape[dim] = str(lhs.shape[dim]) + elif lhs.shape[dim] > rhs.shape[dim]: + oshape[dim] = lshape[dim] + rshape[dim] = str(rhs.shape[dim]) + annos = [_create_anno([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'add', alpha=alpha) def Sub(signature, inputs): assert len(inputs) == 3 inputs, alpha = inputs[0:2], inputs[2] - # TODO: support broadcast annos = [ '*, 1 -> *', '1, * -> *', '*, * -> *', ] + # broadcast + lhs, rhs = inputs + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ + len(lhs.shape) == len(rhs.shape): + if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): + # TODO: support spatial partitioning on broadcast dim + lshape = _create_eshape(lhs.shape) + rshape = copy.copy(lshape) + oshape = copy.copy(lshape) + for dim in range(len(lhs.shape)): + if lhs.shape[dim] < rhs.shape[dim]: + oshape[dim] = rshape[dim] + lshape[dim] = str(lhs.shape[dim]) + elif lhs.shape[dim] > rhs.shape[dim]: + oshape[dim] = lshape[dim] + rshape[dim] = str(rhs.shape[dim]) + annos = [_create_anno([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'sub', alpha=alpha) @@ -51,6 +135,23 @@ def Mul(signature, inputs): '1, * -> *', '*, * -> *', ] + # broadcast + lhs, rhs = inputs + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ + len(lhs.shape) == len(rhs.shape): + if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): + # TODO: support spatial partitioning on broadcast dim + lshape = _create_eshape(lhs.shape) + rshape = copy.copy(lshape) + oshape = copy.copy(lshape) + for dim in range(len(lhs.shape)): + if lhs.shape[dim] < rhs.shape[dim]: + oshape[dim] = rshape[dim] + lshape[dim] = str(lhs.shape[dim]) + elif lhs.shape[dim] > rhs.shape[dim]: + oshape[dim] = lshape[dim] + rshape[dim] = str(rhs.shape[dim]) + annos = [_create_anno([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'mul') @@ -60,6 +161,23 @@ def Div(signature, inputs): '1, * -> *', '*, * -> *', ] + # broadcast + lhs, rhs = inputs + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ + len(lhs.shape) == len(rhs.shape): + if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): + # TODO: support spatial partitioning on broadcast dim + lshape = _create_eshape(lhs.shape) + rshape = copy.copy(lshape) + oshape = copy.copy(lshape) + for dim in range(len(lhs.shape)): + if lhs.shape[dim] < rhs.shape[dim]: + oshape[dim] = rshape[dim] + lshape[dim] = str(lhs.shape[dim]) + elif lhs.shape[dim] > rhs.shape[dim]: + oshape[dim] = lshape[dim] + rshape[dim] = str(rhs.shape[dim]) + annos = [_create_anno([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'div') @@ -113,17 +231,91 @@ def Sum(signature, inputs): def Transpose(signature, inputs): - def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: - dim0, dim1 = node.kwargs[0], node.kwargs[1] - anno.outputs[0][dim0], anno.outputs[0][dim1] = \ - anno.inputs[0][dim1], anno.inputs[0][dim0] - return anno - annos = [('* -> *', adapt),] - inputs, dim0, dim1 = inputs[0:1], inputs[1], inputs[2] - return IREinops(signature, annos, inputs, 'transpose', + """ + out = torch.transpose(tensor, dim0, dim1) + """ + assert len(inputs) == 3 + input, dim0, dim1 = inputs + + edim_in = _create_eshape(input.shape) + edim_ou = copy.copy(edim_in) + edim_ou[dim0], edim_ou[dim1] = edim_ou[dim1], edim_ou[dim0] + anno = _create_anno([edim_in], [edim_ou]) + + return IREinops(signature, [anno], [input], 'transpose', dim0=dim0, dim1=dim1) +def View(signature, inputs): + """ + out = torch.Tensor.view(tensor: torch.Tensor, shape: List[int]) + """ + assert len(inputs) == 2 + input, shape = inputs + in_shape, ou_shape = list(input.shape), shape + print(in_shape, ou_shape) + + # shape check + def nele(shape, nele=1): + for dimlen in shape: nele *= dimlen + return nele + # handle '-1' in shape + cnt = nele(in_shape) + if -1 in ou_shape: + idx = ou_shape.index(-1) + ou_shape[idx] = cnt // (-nele(ou_shape)) + assert nele(in_shape) == nele(ou_shape), "shape mismatch" + # generate annotation + shape_map: Dict[str, int] = dict() + letters = iter(string.ascii_lowercase) + in_anno, ou_anno = [], [] + in_dim, ou_dim = 0, 0 + in_remain, ou_remain = in_shape[in_dim], ou_shape[ou_dim] + in_bracket, ou_bracket = [], [] + in_dimlen, ou_dimlen = 1, 1 + while True: + letter = next(letters) + dimlen = min(in_remain, ou_remain) + in_dimlen, ou_dimlen = in_dimlen * dimlen, ou_dimlen * dimlen + in_remain, ou_remain = in_remain // dimlen, ou_remain // dimlen + in_bracket.append(letter) + ou_bracket.append(letter) + shape_map[letter] = dimlen + if in_remain == 1: + in_anno.append(in_bracket) + in_bracket, in_dimlen = [], 1 + in_dim += 1 + if in_dim < len(in_shape): + in_remain = in_shape[in_dim] + if ou_remain == 1: + ou_anno.append(ou_bracket) + ou_bracket, ou_dimlen = [], 1 + ou_dim += 1 + if ou_dim < len(ou_shape): + ou_remain = ou_shape[ou_dim] + if in_dim == len(in_shape) and ou_dim == len(ou_shape): + break + # setup reduction: only first dimension can be spatially partitioned + spatial_in = set() + spatial_ou = set() + for in_bracket in in_anno: + spatial_in.add(in_bracket[0]) + for ou_bracket in ou_anno: + spatial_ou.add(ou_bracket[0]) + spatial = spatial_in.intersection(spatial_ou) + for bracket in in_anno + ou_anno: + for subdim, edim in enumerate(bracket): + if edim not in spatial: + bracket[subdim] = str(shape_map[edim]) + # bracket[subdim] = edim + '^' + anno = _create_anno([in_anno], [ou_anno]) + return IREinops(signature, [anno], [input], 'view', shape=shape) + + +def Reshape(signature, inputs): + return View(signature, inputs) + + # def Conv2D(signature, inputs): # """ # torch.conv2d(input, weight, bias, stride, padding, dialation, groups) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index f8dc3978..39546b8a 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -42,9 +42,6 @@ def register(signature: str, op: IRFwOperation): # tensor template __ttemplate = lambda name: f'torch.{name}' - # customized - __customize = lambda name: f'cube.runtime.function.complex.{name}' - # einops __einopsize = lambda name: f'einops._torch_specific.{name}' @@ -80,6 +77,10 @@ def register(signature: str, op: IRFwOperation): __ttemplate('transpose') : function.Transpose, + __ttemplate('view'): function.View, + + __ttemplate('reshape'): function.Reshape, + __ttemplate('conv2d'): function.Conv2D, __ttemplate('conv3d'): function.Conv3D, From a64800e9d3133020ee33932d76c700bd169758a1 Mon Sep 17 00:00:00 2001 From: lynex Date: Sun, 6 Mar 2022 20:17:01 +0800 Subject: [PATCH 0623/1892] debug weather --- cube/graph/operator/function/einops.py | 10 + examples/atmosphere/weather.py | 580 +++++++++++++++---------- 2 files changed, 371 insertions(+), 219 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 4b2bf310..ae1ca2f9 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -274,6 +274,7 @@ def infer_shape(self) -> bool: self._oannos = oannos if ret: break if not ret: + print(f'self._annos = {self._annos}, self._adapt = {self._adapt}') raise RuntimeError("No matching anno for given annos") dimlen: Dict[str, int] = dict() for input, ishape in zip(self.inputs(), self._iannos): @@ -340,9 +341,11 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei """ parse annotations, assuming input tensor shape is given """ + print(f"anno = {anno}; anno.inputs = {anno.inputs}; self.inputs = {self.inputs()}") if len(anno.inputs) != len(self.inputs()): return False, None, None identifiers = anno.identifiers() + print(f'identifiers = {identifiers}') # expand * expand_dims = None @@ -356,10 +359,12 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei for idx, (names, input) in enumerate(zip(in_names, self.inputs())): if '*' in names: if not isinstance(input, IRTensor): + print('Ln 362') return False, None, None pos = names.index('*') span = len(self.inputs(idx).shape) - (len(names) - 1) if expand_dims is not None and len(expand_dims) != span: + print('Ln 367') return False, None, None if expand_dims is None: expand_dims = [] @@ -368,6 +373,7 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei anno.inputs[idx] = anno.inputs[idx][:pos] + expand_dims + anno.inputs[idx][pos+1:] # * should appear in inputs if expand_dims is None: + print('Ln 376') return False, None, None # go through outputs for idx, names in enumerate(out_names): @@ -379,15 +385,19 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei for eshape, input in zip(anno.inputs, self.inputs()): if not isinstance(input, IRTensor): if not (len(eshape) == 1 and eshape[0].name == '1'): + print('Ln 388') return False, None, None else: if len(input.shape) != len(eshape): + print('Ln 392') return False, None, None for edim, nele in zip(eshape, input.shape): if edim.name in dimlen: if nele != dimlen[edim.name]: + print('Ln 397') return False, None, None dimlen[edim.name] = nele + print('Ln 400') return True, anno.inputs, anno.outputs def einexpr(self) -> str: diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index 81865427..1224382f 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -6,11 +6,19 @@ torch.set_default_tensor_type(torch.DoubleTensor) +from typing import List +import cube +from examples.poisson.policy.naive import PAS + +from einops.layers.torch import Rearrange class Atmoshpere(torch.nn.Module): - def __init__(self, nz, ny, nx, dy, dx, x0, y0, device='cuda'): + def __init__(self, + nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, + bar_x_filter, bar_y_filter, delta_x_filter, delta_y_filter, + device='cuda'): super().__init__() - self.device = torch.device(device) + #self.device = torch.device(device) # physics constant self.g = 9.8 # acceleration of gravity, unit in m/s^2 @@ -21,41 +29,6 @@ def __init__(self, nz, ny, nx, dy, dx, x0, y0, device='cuda'): # self.OMEGA = 7.292e-5 # angular speed of the Earth s^-1 self.OMEGA = 1e-1 # angular speed of the Earth s^-1 - # atmoshpere verticle profile - self.hight_profile = torch.tensor([ - 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, - 8.5, 9, 9.5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, - 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100 - ]).to(self.device) * 1e3 - - self.pressure_profile = torch.tensor([ - 1013.25, 1001.20, 989.45, 977.72, 966.11, 954.61, 943.22, 931.94, 920.77, 909.71, 898.80, 845.59, 795.0, - 746.9, 701.2, 657.8, 616.6, 577.5, 540.5, 505.4, 472.2, 440.7, 411.1, 383.0, 356.5, 331.5, 308.0, 285.8, - 265.0, 227.0, 194.0, 165.8, 141.7, 121.1, 103.5, 88.5, 75.7, 64.7, 55.3, 47.3, 40.5, 34.7, 29.7, 25.5, 21.9, - 18.8, 16.2, 13.9, 12.0, 10.3, 8.89, 7.67, 6.63, 5.75, 4.99, 4.33, 3.77, 3.29, 2.87, 2.51, 2.20, 1.93, 1.69, - 1.49, 1.31, 1.16, 1.02, 0.903, 0.903, 0.425, 0.220, 0.109, 0.0522, 0.0239, 0.0105, 0.0045, 0.0018, 0.00076, - 0.00032 - ]).to(self.device) * 1e2 - - self.temperature_profile = torch.tensor([ - 288.15, 287.50, 286.85, 286.20, 285.55, 284.90, 284.25, 283.60, 282.95, 282.30, 281.65, 278.40, 275.15, - 271.91, 268.66, 265.41, 262.17, 258.92, 255.68, 252.43, 249.19, 245.94, 242.70, 239.46, 236.22, 232.97, - 229.73, 226.49, 223.25, 216.78, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, - 217.58, 218.57, 219.57, 220.56, 221.55, 222.54, 223.54, 224.53, 225.52, 226.51, 227.50, 228.49, 230.97, - 233.74, 236.51, 239.28, 242.05, 244.82, 247.58, 250.35, 253.11, 255.88, 258.64, 261.40, 264.16, 266.93, - 269.68, 270.65, 270.65, 270.65, 260.77, 247.02, 233.29, 219.59, 208.40, 198.64, 188.89, 186.87, 188.42, - 195.08 - ]).to(self.device) - - self.density_profile = torch.tensor([ - 1.225, 1.213, 1.202, 1.190, 1.179, 1.167, 1.156, 1.145, 1.134, 1.123, 1.112, 1.058, 1.007, 0.957, 0.909, - 0.863, 0.819, 0.777, 0.736, 0.697, 0.660, 0.624, 0.590, 0.557, 0.526, 0.496, 0.467, 0.440, 0.414, 0.365, - 0.312, 0.267, 0.228, 0.195, 0.166, 0.142, 0.122, 0.104, 0.0889, 0.0757, 0.0645, 0.0550, 0.0469, 0.0401, - 0.0343, 0.0293, 0.0251, 0.0215, 0.0184, 0.0158, 0.0136, 0.0116, 0.00989, 0.00846, 0.00726, 0.00624, 0.00537, - 0.00463, 0.00400, 0.00346, 0.00299, 0.00260, 0.00226, 0.00197, 0.00171, 0.0015, 0.00132, 0.00116, 0.00103, - 5.7e-4, 3.1e-4, 1.6e-4, 8.3e-4, 4.0e-5, 1.8e-5, 8.2e-6, 3.4e-6, 7.5e-7, 5.6e-7 - ]).to(self.device) - # simulation domain self.nx = nx self.ny = ny @@ -66,201 +39,194 @@ def __init__(self, nz, ny, nx, dy, dx, x0, y0, device='cuda'): self.x0 = x0 self.y0 = y0 - def init(self, ps, pt, zs): - self.Y = ( - torch.linspace(0, self.ny, self.ny + 1, device=self.device) * self.dy + self.y0 - ).view(1, self.ny + 1, 1) - self.deltaA = self.RE**2 * torch.cos(self.bar_y(self.Y) * 0.5).to(self.device) * self.dx * self.dy # (1, ny, 1) - self.f = 2 * self.OMEGA * torch.sin(self.bar_y(self.Y)) * torch.cos(self.bar_y(self.Y)) * self.RE # (nz, ny, nx) - - # vertical grids - pt = torch.tensor([pt]).view(1, 1, 1).to(self.device) - zt = self.hight_from_pressure(pt) - z = torch.linspace(1, 0, self.nz + 1, device=self.device).view(-1, 1, 1) * zt - p_ = self.pressure_from_hight(z) - self.sigma = (p_ - pt) / (p_[-1] - pt) # (nz + 1, 1, 1) - - # column pressure, with shape (1, ny, nx) - pi = (ps - pt).view(1, self.ny, self.nx) - - # potential temperature factor - p_ = pt + self.sigma * pi # (nz + 1, ny, nx) - self.P_ = (p_ / self.PSEA)**self.KAPPA # (nz + 1, ny, nx) - self.P = self.delta_z(p_ * self.P_) / self.delta_z(p_) / (1 + self.KAPPA) # (nz, ny, nx) + self.deltaA = deltaA + self.Y = Y + self.f = f + self.sigma = sigma + self.P_ = P_ + self.P = P + self.phi = phi + self.zs = zs + self.w = w - # potential temperature - p = self.PSEA * self.P**(1 / self.KAPPA) - T = self.temperature_from_pressure(p) - theta = T / self.P + self.bar_x_filter = bar_x_filter + self.bar_y_filter = bar_y_filter + self.delta_x_filter = delta_x_filter + self.delta_y_filter = delta_y_filter - # geopotential (nz, ny, nx) - self.phi = torch.zeros((self.nz, self.ny, self.nx), device=self.device) - self.zs = zs + self.pre_conv3d_reshape = Rearrange('(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # x = X.view(1, 1, Nz, Ny, Nx) + self.post_conv3d_reshape = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') - # vertical velocity - self.w = torch.zeros((self.nz + 1, self.ny, self.nx), device=self.device) - return pi, theta def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # flux - F = self.bar_x(self.pad_x(pi)) * 0.5 * u * self.RE * self.dy # (nz, ny, nx + 1) - G = self.bar_y(self.pad_y(pi)) * 0.5 * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) - B = self.bar_y2(self.bar_x(self.pad_y(F))) / 12. # (nz, ny, nx) - C = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(G)))) / 12. # (nz, ny + 1, nx + 1) - D = self.bar_y2(self.pad_y(G)) + self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) - E = self.bar_y2(self.pad_y(G)) - self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) - Q = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(F)))) / 12. # (nz, ny + 1, nx + 1) - R = self.bar_x2(self.bar_y(self.pad_x(G))) / 12. # (nz, ny, nx) - S = self.bar_y(self.bar_x(self.pad_x(G))) + self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) - T = self.bar_y(self.bar_x(self.pad_x(G))) - self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) - - pi1 = pi0 - dt / self.deltaA * ((self.delta_x(F) + self.delta_y(G)) * self.dz).sum(axis=0) # (nz, ny, nx) + F = self.bar_x(self.pad_x(pi)) * u #* self.RE * self.dy # (nz, ny, nx + 1) + G = self.bar_y(self.pad_y(pi)) * v #* self.RE * self.dx # * torch.cos(self.Y) # (nz, ny + 1, nx) + # F = self.bar_x(self.pad_x(pi)) * 0.5 * u * self.RE * self.dy # (nz, ny, nx + 1) + # G = self.bar_y(self.pad_y(pi)) * 0.5 * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) + # B = self.bar_y2(self.bar_x(self.pad_y(F))) / 12. # (nz, ny, nx) + # C = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(G)))) / 12. # (nz, ny + 1, nx + 1) + # D = self.bar_y2(self.pad_y(G)) + self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) + # E = self.bar_y2(self.pad_y(G)) - self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) + # Q = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(F)))) / 12. # (nz, ny + 1, nx + 1) + # R = self.bar_x2(self.bar_y(self.pad_x(G))) / 12. # (nz, ny, nx) + # S = self.bar_y(self.bar_x(self.pad_x(G))) + self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) + # T = self.bar_y(self.bar_x(self.pad_x(G))) - self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) + + # pi1 = pi0 - dt / self.deltaA * ((self.delta_x(F) + self.delta_y(G)) * self.dz).sum(axis=0) # (nz, ny, nx) + tmp0 = ((self.delta_x(F) + self.delta_y(G)) * self.dz) + tmp1 = tmp0.sum(dim=0) + pi1 = pi0 - self.deltaA * tmp1 # pi1 = pi0 - dt / self.deltaA * tmp1 # print('pi:', pi1.mean()) - # update diagnostic variable w (nz + 1, ny, nx) - for i in range(1, self.nz + 1): - self.w[i] = - ((self.delta_x(F[:i]) + self.delta_y(G[:i])) * self.dz).sum(axis=0) / self.deltaA / pi1 \ - - self.sigma[i] * (pi1 - pi0) / dt / pi1 + # TODO a custom Op needed + # # update diagnostic variable w (nz + 1, ny, nx) + # for i in range(1, self.nz + 1): + # self.w[i] = - ((self.delta_x(F[:i]) + self.delta_y(G[:i])) * self.dz).sum(dim=0) / self.deltaA / pi1 \ + # - self.sigma[i] * (pi1 - pi0) / dt / pi1 # print('w:', self.w.mean()) - # update potential temperature theta (nz, ny, nx) - theta_ = self.pad_z( - (self.bar_z(self.P * theta) - self.delta_z(theta) * self.P_[1:-1]) / self.delta_z(self.P) - ) # (nz + 1, ny, nx) - theta1 = pi0 / pi1 * theta0 + dt / self.deltaA / pi1 * ( - (self.delta_x(F * self.bar_x(self.pad_x(theta))) + self.delta_y(G * self.bar_y(self.pad_y(theta)))) / 2. + - pi * self.deltaA * self.delta_z(self.w * theta_) / self.dz + - 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(theta)))) - ) + # # update potential temperature theta (nz, ny, nx) + # theta_ = self.pad_z( + # (self.bar_z(self.P * theta) - self.delta_z(theta) * self.P_[1:-1]) / self.delta_z(self.P) + # ) # (nz + 1, ny, nx) + # theta1 = pi0 / pi1 * theta0 + dt / self.deltaA / pi1 * ( + # (self.delta_x(F * self.bar_x(self.pad_x(theta))) + self.delta_y(G * self.bar_y(self.pad_y(theta)))) / 2. + + # pi * self.deltaA * self.delta_z(self.w * theta_) / self.dz + + # 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(theta)))) + # ) # print('theta:', theta1.mean()) - # update geopotential - self.phi[-1] = self.g * self.zs - self.CPD * (self.P[-1] - self.P_[-1]) * theta[-1] - for i in range(1, self.nz): - tmp = self.phi[-i] - self.CPD * (self.P_[-i - 1] - self.P[-i]) * theta[-i] - self.phi[-1 - i] = tmp - self.CPD * (self.P[-1 - i] - self.P_[-1 - i]) * theta[-1 - i] + # TODO a custom Op needed + # # update geopotential + # self.phi[-1] = self.g * self.zs - self.CPD * (self.P[-1] - self.P_[-1]) * theta[-1] + # for i in range(1, self.nz): + # tmp = self.phi[-i] - self.CPD * (self.P_[-i - 1] - self.P[-i]) * theta[-i] + # self.phi[-1 - i] = tmp - self.CPD * (self.P[-1 - i] - self.P_[-1 - i]) * theta[-1 - i] # print('phi:', self.phi.mean()) - # update u (nz, ny, nx + 1) - pi0_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi0 * self.deltaA)))) / 8. # (nz, ny, nx + 1) - pi1_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi1 * self.deltaA)))) / 8. # (nz, ny, nx + 1) - pi_w_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny, nx + 1) - advec = ( - - self.delta_x(self.pad_x(B * self.bar_x(u))) - - self.delta_y(C * self.bar_y(self.pad_y(u))) - + self.delta_D(self.pad_x(D * self.bar_xy(self.pad_y(u)))) - + self.delta_E(self.pad_x(E * self.bar_xy(self.pad_y(u)))) - ) / 2. - trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(u)) * 0.5) / self.dz - press = - self.RE * self.dy * ( - self.delta_x(self.pad_x(self.phi)) * self.delta_x(self.pad_x(pi)) / 2. + - self.delta_x(self.pad_x(pi)) * 0.5 * self.CPD * self.bar_x(self.pad_x( - theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) - )) - ) - diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(u)))) - cori = self.RE * self.dx * self.dy * 0.25 * ( - self.bar_x(self.pad_x(pi * self.bar_y(v) * (self.f + 0. * self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) - ) * 0.0 - u1 = (pi0_deltaA * u0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA - # print('u1:', u1.mean()) - - # update v (nz, ny + 1, nx) - pi0_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi0 * self.deltaA)))) / 8. # (nz, ny + 1, nx) - pi1_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi1 * self.deltaA)))) / 8. # (nz, ny + 1, nx) - pi_w_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny + 1, nx) - advec = ( - - self.delta_x(Q * self.bar_x(self.pad_x(v))) - - self.delta_y(self.pad_y(R * self.bar_y(v))) - + self.delta_D(self.pad_y(S * self.bar_xy(self.pad_x(v)))) - + self.delta_E(self.pad_y(T * self.bar_xy(self.pad_x(v)))) - ) / 2. - trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(v)) * 0.5) / self.dz - press = - self.RE * self.dx * ( - self.delta_y(self.pad_y(self.phi)) * self.delta_y(self.pad_y(pi)) / 2. + - self.delta_y(self.pad_y(pi)) * 0.5 * self.CPD * self.bar_y(self.pad_y( - theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) - )) - ) - diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(v)))) - cori = - self.RE * self.dx * self.dy * 0.25 * ( - self.bar_y(self.pad_y(pi * self.bar_x(u) * (self.f + self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) - ) * 0.0 - v1 = (pi0_deltaA * v0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA - # print('v1:', v1.mean()) - - return pi1, theta1, u1, v1 - - def forward(self, dt, pi, theta, u, v): - pi_, theta_, u_, v_ = self.step(dt / 2, pi, theta, u, v, pi, theta, u, v) - return self.step(dt, pi_, theta_, u_, v_, pi, theta, u, v) - - def hight_from_pressure(self, p): - ind0 = torch.abs((p[None] - self.pressure_profile[(..., ) + (None, ) * len(p.shape)])).argmin(axis=0) - ind1 = (p > self.pressure_profile[ind0]) * 2 - 1 + ind0 - hight = (self.hight_profile[ind1] - self.hight_profile[ind0]) * (p - self.pressure_profile[ind0]) / ( - self.pressure_profile[ind1] - self.pressure_profile[ind0]) + self.hight_profile[ind0] - return hight + # # update u (nz, ny, nx + 1) + # pi0_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi0 * self.deltaA)))) / 8. # (nz, ny, nx + 1) + # pi1_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi1 * self.deltaA)))) / 8. # (nz, ny, nx + 1) + # pi_w_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny, nx + 1) + # advec = ( + # - self.delta_x(self.pad_x(B * self.bar_x(u))) + # - self.delta_y(C * self.bar_y(self.pad_y(u))) + # + self.delta_D(self.pad_x(D * self.bar_xy(self.pad_y(u)))) + # + self.delta_E(self.pad_x(E * self.bar_xy(self.pad_y(u)))) + # ) / 2. + # trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(u)) * 0.5) / self.dz + # press = - self.RE * self.dy * ( + # self.delta_x(self.pad_x(self.phi)) * self.delta_x(self.pad_x(pi)) / 2. + + # self.delta_x(self.pad_x(pi)) * 0.5 * self.CPD * self.bar_x(self.pad_x( + # theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) + # )) + # ) + # diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(u)))) + # cori = self.RE * self.dx * self.dy * 0.25 * ( + # self.bar_x(self.pad_x(pi * self.bar_y(v) * (self.f + 0. * self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) + # ) * 0.0 + # u1 = (pi0_deltaA * u0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA + # # print('u1:', u1.mean()) + + # # update v (nz, ny + 1, nx) + # pi0_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi0 * self.deltaA)))) / 8. # (nz, ny + 1, nx) + # pi1_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi1 * self.deltaA)))) / 8. # (nz, ny + 1, nx) + # pi_w_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny + 1, nx) + # advec = ( + # - self.delta_x(Q * self.bar_x(self.pad_x(v))) + # - self.delta_y(self.pad_y(R * self.bar_y(v))) + # + self.delta_D(self.pad_y(S * self.bar_xy(self.pad_x(v)))) + # + self.delta_E(self.pad_y(T * self.bar_xy(self.pad_x(v)))) + # ) / 2. + # trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(v)) * 0.5) / self.dz + # press = - self.RE * self.dx * ( + # self.delta_y(self.pad_y(self.phi)) * self.delta_y(self.pad_y(pi)) / 2. + + # self.delta_y(self.pad_y(pi)) * 0.5 * self.CPD * self.bar_y(self.pad_y( + # theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) + # )) + # ) + # diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(v)))) + # cori = - self.RE * self.dx * self.dy * 0.25 * ( + # self.bar_y(self.pad_y(pi * self.bar_x(u) * (self.f + self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) + # ) * 0.0 + # v1 = (pi0_deltaA * v0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA + # # print('v1:', v1.mean()) + + # return pi1, theta1, u1, v1 + # return pi1, theta0, u0, v0 + return pi1, theta0, u0, v0 + + def forward(self, pi, theta, u, v, dt): + # pi_, theta_, u_, v_ = self.step(dt / 2., pi, theta, u, v, pi, theta, u, v) + # return self.step(dt, pi_, theta_, u_, v_, pi, theta, u, v) + + #TODO recover me + return self.step(dt, pi, theta, u, v, pi, theta, u, v) - def pressure_from_hight(self, z): - ind0 = torch.abs((z[None] - self.hight_profile[(..., ) + (None, ) * len(z.shape)])).argmin(axis=0) - ind1 = (self.hight_profile[ind0] > z) * 2 - 1 + ind0 - p = (self.pressure_profile[ind1] - self.pressure_profile[ind0]) * (z - self.hight_profile[ind0]) / \ - (self.hight_profile[ind1] - self.hight_profile[ind0]) + self.pressure_profile[ind0] - return p - - def temperature_from_pressure(self, p): - ind0 = torch.abs((p[None] - self.pressure_profile[(..., ) + (None, ) * len(p.shape)])).argmin(axis=0) - ind1 = (p > self.pressure_profile[ind0]) * 2 - 1 + ind0 - T = (self.temperature_profile[ind1] - self.temperature_profile[ind0]) * (p - self.pressure_profile[ind0]) / ( - self.pressure_profile[ind1] - self.pressure_profile[ind0]) + self.temperature_profile[ind0] - return T def pad_x(self, X): return F.pad(X, (1, 1), "circular") def bar_x(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) + # nz, ny, nx = X.shape + # filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) + filter = self.bar_x_filter + x_ext = self.pre_conv3d_reshape(X) + x_convd = F.conv3d(x_ext, filter) + x_unext = self.post_conv3d_reshape(x_convd) + return x_unext def bar_x2(self, X): nz, ny, nx = X.shape - filter = torch.tensor([1., 2., 1.], device=X.device).view(1, 1, 1, 1, 3) + filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 1, 3) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 2) def delta_x(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) + # nz, ny, nx = X.shape + # filter = torch.tensor([-1., 1.]).view(1, 1, 1, 1, 2) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) + filter = self.delta_x_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + def pad_y(self, X): - nz, ny, nx = X.shape - return F.pad(X.view(1, nz, ny, nx), (0, 0, 1, 1), "circular").view(nz, ny + 2, nx) + # nz, ny, nx = X.shape + #TODO check return F.pad(X.view(1, nz, ny, nx), (0, 0, 1, 1), "circular").view(nz, ny + 2, nx) + return F.pad(X, (0, 0, 1, 1), "circular") def bar_y(self, X): nz, ny, nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) + # filter = torch.tensor([1., 1.]).view(1, 1, 1, 2, 1) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) + filter = self.bar_y_filter + x_ext = self.pre_conv3d_reshape(X) + x_convd = F.conv3d(x_ext, filter) + x_unext = self.post_conv3d_reshape(x_convd) + return x_unext def bar_y2(self, X): nz, ny, nx = X.shape - filter = torch.tensor([1., 2., 1.], device=X.device).view(1, 1, 1, 3, 1) + filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 3, 1) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 2, nx) def delta_y(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 2, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) + # nz, ny, nx = X.shape + # filter = torch.tensor([-1., 1.]).view(1, 1, 1, 2, 1) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) + filter = self.delta_y_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + def bar_z(self, X): nz, ny, nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) + filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) def pad_z(self, X): @@ -269,15 +235,14 @@ def pad_z(self, X): def delta_z(self, X): nz, ny, nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) + filter = torch.tensor([-1., 1.]).view(1, 1, 2, 1, 1) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) def delta_D(self, X): nz, ny, nx = X.shape filter = torch.tensor( [[1., 0.], - [0., -1.]], - device=X.device + [0., -1.]] ).view(1, 1, 1, 2, 2) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) @@ -285,8 +250,7 @@ def delta_E(self, X): nz, ny, nx = X.shape filter = torch.tensor( [[0., 1.], - [-1., 0.]], - device=X.device + [-1., 0.]] ).view(1, 1, 1, 2, 2) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) @@ -294,8 +258,7 @@ def bar_xy(self, X): nz, ny, nx = X.shape filter = torch.tensor( [[1., 0.], - [0., 1.]], - device=X.device + [0., 1.]] ).view(1, 1, 1, 2, 2) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) @@ -311,14 +274,51 @@ def laplas(self, X): [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]]], - device=X.device ).view(1, 1, 3, 3, 3) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 2, ny - 2, nx - 2) +class LoopVariables(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): + # for var in variables + constants: + # print("### var = {}, type = {}".format(var, type(var))) + shapes = [list(var.size()) for var in variables + constants] + dtypes = [var.dtype for var in variables + constants] + batch_dims = [0] * (len(variables) + len(constants)) + super().__init__(shapes, dtypes, batch_dims) + self.variables = list() + self.constants = list() + for var in variables: + if torch.is_tensor(var) and var.device != torch.cuda.current_device(): + var = var.cuda() + self.variables.append(var) + for const in constants: + if torch.is_tensor(const) and const.device != torch.cuda.current_device(): + const = const.cuda() + self.constants.append(const) + + def __iter__(self): + return self + + def update(self, variables: List[torch.Tensor] = None, constants: List[torch.Tensor] = None): + if variables is not None: + self.variables = variables + if constants is not None: + self.constants = constants + + def reset(self, batch_size): + pass + + def __next__(self): + if len(self.variables) + len(self.constants) == 1: + return (self.variables + self.constants)[0] + return tuple(self.variables + self.constants) + if __name__ == "__main__": import matplotlib.pyplot as plt + cube.init() nz = 15 ny = 100 @@ -328,36 +328,178 @@ def laplas(self, X): x0 = 0.0 y0 = 0.2 - atmosphere = Atmoshpere(nz, ny, nx, dy, dx, x0, y0) + PSEA = 101325. # sea level pressure, unit in Pa + RE = 6.4e6 # radius of earth, unit in m xc = nx * dx / 2 + x0 yc = ny * dy / 2 + y0 - X = torch.linspace(0, nx - 1, nx).view(1, 1, nx).cuda() * dx + x0 - Y = torch.linspace(0, ny - 1, ny).view(1, ny, 1).cuda() * dy + y0 - ps = torch.ones((1, ny, nx)).cuda() * atmosphere.PSEA - 300 * torch.exp( - - 1e-6 * ((atmosphere.RE * torch.cos((Y + yc) / 2)) * (X - xc))**2 - - 1e-6 * (atmosphere.RE * (Y - yc))**2) + X = torch.linspace(0, nx - 1, nx).view(1, 1, nx) * dx + x0 + Y = torch.linspace(0, ny - 1, ny).view(1, ny, 1) * dy + y0 + ps = torch.ones((1, ny, nx)) * PSEA - 300 * torch.exp( + - 1e-6 * ((RE * torch.cos((Y + yc) / 2)) * (X - xc))**2 + - 1e-6 * (RE * (Y - yc))**2) pt = 250e2 - zs = torch.zeros((ny, nx)).cuda() + 10000 * torch.exp( - - 1e-6 * (atmosphere.RE * (X - nx * dx / 3 - x0))**2 - - 1e-6 * (atmosphere.RE * (Y - yc))**2) + zs = torch.zeros((ny, nx)) + 10000 * torch.exp( + - 1e-6 * (RE * (X - nx * dx / 3 - x0))**2 + - 1e-6 * (RE * (Y - yc))**2) + + u = torch.zeros((nz, ny, nx + 1)) + v = torch.zeros((nz, ny + 1, nx)) + + dt = torch.tensor(1.) + + # physics constant + g = 9.8 # acceleration of gravity, unit in m/s^2 + PSEA = 101325. # sea level pressure, unit in Pa + KAPPA = 0.286 # dimensionless + RE = 6.4e6 # radius of earth, unit in m + CPD = 1004.67 # specific heat of dry air at constant pressure J*kg^-1*K^-1 + # OMEGA = 7.292e-5 # angular speed of the Earth s^-1 + OMEGA = 1e-1 # angular speed of the Earth s^-1 + + # atmoshpere verticle profile + hight_profile = torch.tensor([ + 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, + 8.5, 9, 9.5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100 + ]) * 1e3 + + pressure_profile = torch.tensor([ + 1013.25, 1001.20, 989.45, 977.72, 966.11, 954.61, 943.22, 931.94, 920.77, 909.71, 898.80, 845.59, 795.0, + 746.9, 701.2, 657.8, 616.6, 577.5, 540.5, 505.4, 472.2, 440.7, 411.1, 383.0, 356.5, 331.5, 308.0, 285.8, + 265.0, 227.0, 194.0, 165.8, 141.7, 121.1, 103.5, 88.5, 75.7, 64.7, 55.3, 47.3, 40.5, 34.7, 29.7, 25.5, 21.9, + 18.8, 16.2, 13.9, 12.0, 10.3, 8.89, 7.67, 6.63, 5.75, 4.99, 4.33, 3.77, 3.29, 2.87, 2.51, 2.20, 1.93, 1.69, + 1.49, 1.31, 1.16, 1.02, 0.903, 0.903, 0.425, 0.220, 0.109, 0.0522, 0.0239, 0.0105, 0.0045, 0.0018, 0.00076, + 0.00032 + ]) * 1e2 + + temperature_profile = torch.tensor([ + 288.15, 287.50, 286.85, 286.20, 285.55, 284.90, 284.25, 283.60, 282.95, 282.30, 281.65, 278.40, 275.15, + 271.91, 268.66, 265.41, 262.17, 258.92, 255.68, 252.43, 249.19, 245.94, 242.70, 239.46, 236.22, 232.97, + 229.73, 226.49, 223.25, 216.78, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, + 217.58, 218.57, 219.57, 220.56, 221.55, 222.54, 223.54, 224.53, 225.52, 226.51, 227.50, 228.49, 230.97, + 233.74, 236.51, 239.28, 242.05, 244.82, 247.58, 250.35, 253.11, 255.88, 258.64, 261.40, 264.16, 266.93, + 269.68, 270.65, 270.65, 270.65, 260.77, 247.02, 233.29, 219.59, 208.40, 198.64, 188.89, 186.87, 188.42, + 195.08 + ]) + + density_profile = torch.tensor([ + 1.225, 1.213, 1.202, 1.190, 1.179, 1.167, 1.156, 1.145, 1.134, 1.123, 1.112, 1.058, 1.007, 0.957, 0.909, + 0.863, 0.819, 0.777, 0.736, 0.697, 0.660, 0.624, 0.590, 0.557, 0.526, 0.496, 0.467, 0.440, 0.414, 0.365, + 0.312, 0.267, 0.228, 0.195, 0.166, 0.142, 0.122, 0.104, 0.0889, 0.0757, 0.0645, 0.0550, 0.0469, 0.0401, + 0.0343, 0.0293, 0.0251, 0.0215, 0.0184, 0.0158, 0.0136, 0.0116, 0.00989, 0.00846, 0.00726, 0.00624, 0.00537, + 0.00463, 0.00400, 0.00346, 0.00299, 0.00260, 0.00226, 0.00197, 0.00171, 0.0015, 0.00132, 0.00116, 0.00103, + 5.7e-4, 3.1e-4, 1.6e-4, 8.3e-4, 4.0e-5, 1.8e-5, 8.2e-6, 3.4e-6, 7.5e-7, 5.6e-7 + ]) + + def hight_from_pressure(p): + ind0 = torch.abs((p[None] - pressure_profile[(..., ) + (None, ) * len(p.shape)])).argmin(axis=0) + ind1 = (p > pressure_profile[ind0]) * 2 - 1 + ind0 + hight = (hight_profile[ind1] - hight_profile[ind0]) * (p - pressure_profile[ind0]) / ( + pressure_profile[ind1] - pressure_profile[ind0]) + hight_profile[ind0] + return hight + + def pressure_from_hight(z): + ind0 = torch.abs((z[None] - hight_profile[(..., ) + (None, ) * len(z.shape)])).argmin(axis=0) + ind1 = (hight_profile[ind0] > z) * 2 - 1 + ind0 + p = (pressure_profile[ind1] - pressure_profile[ind0]) * (z - hight_profile[ind0]) / \ + (hight_profile[ind1] - hight_profile[ind0]) + pressure_profile[ind0] + return p + + def temperature_from_pressure(p): + ind0 = torch.abs((p[None] - pressure_profile[(..., ) + (None, ) * len(p.shape)])).argmin(axis=0) + ind1 = (p > pressure_profile[ind0]) * 2 - 1 + ind0 + T = (temperature_profile[ind1] - temperature_profile[ind0]) * (p - pressure_profile[ind0]) / ( + pressure_profile[ind1] - pressure_profile[ind0]) + temperature_profile[ind0] + return T + + def bar_y(X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 1.]).view(1, 1, 1, 2, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) + + def delta_z(X): + nz, ny, nx = X.shape + filter = torch.tensor([-1., 1.]).view(1, 1, 2, 1, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) + + def init(ps, pt, zs): + Y = ( + torch.linspace(0, ny, ny + 1) * dy + y0 + ).view(1, ny + 1, 1) + deltaA = RE**2 * torch.cos(bar_y(Y) * 0.5) * dx * dy # (1, ny, 1) + f = 2 * OMEGA * torch.sin(bar_y(Y)) * torch.cos(bar_y(Y)) * RE # (nz, ny, nx) + + # vertical grids + pt = torch.tensor([pt]).view(1, 1, 1) + zt = hight_from_pressure(pt) + z = torch.linspace(1, 0, nz + 1).view(-1, 1, 1) * zt + p_ = pressure_from_hight(z) + sigma = (p_ - pt) / (p_[-1] - pt) # (nz + 1, 1, 1) + + # column pressure, with shape (1, ny, nx) + pi = (ps - pt).view(1, ny, nx) + + # potential temperature factor + p_ = pt + sigma * pi # (nz + 1, ny, nx) + P_ = (p_ / PSEA)**KAPPA # (nz + 1, ny, nx) + P = delta_z(p_ * P_) / delta_z(p_) / (1 + KAPPA) # (nz, ny, nx) + + # potential temperature + p = PSEA * P**(1 / KAPPA) + T = temperature_from_pressure(p) + theta = T / P + + # geopotential (nz, ny, nx) + phi = torch.zeros((nz, ny, nx)) + zs = zs + + # vertical velocity + w = torch.zeros((nz + 1, ny, nx)) + + return pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w + + + pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w = init(ps, pt, zs) + print("[pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w]") + for var in [pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w]: + print(f'shape {var.shape}') + + bar_x_filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) + bar_y_filter = torch.tensor([1., 1.]).view(1, 1, 1, 2, 1) + delta_x_filter = torch.tensor([-1., 1.]).view(1, 1, 1, 1, 2) + delta_y_filter = torch.tensor([-1., 1.]).view(1, 1, 1, 2, 1) + + + model = Atmoshpere(nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, + bar_x_filter, bar_y_filter, delta_x_filter, delta_y_filter) + + print("[pi, theta, u, v, dt]") + for var in [pi, theta, u, v, dt]: + print(f'shape {var.shape}') - pi, theta = atmosphere.init(ps, pt, zs) + varloader = LoopVariables(variables=[pi, theta, u, v], constants=[dt]) + model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes)) - u = torch.zeros((nz, ny, nx + 1)).cuda() - v = torch.zeros((nz, ny + 1, nx)).cuda() + @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) + def train_iter(model, dataloader): + pi, theta, u, v, dt = next(dataloader) + pi, theta, u, v = model(pi, theta, u, v, dt) + return pi, theta, u, v + model = model.get_gen_module() - for i in range(100): - pi, theta, u, v = atmosphere(1., pi, theta, u, v) + for i in range(3): + print("iter-{}...".format(i)) + pi, theta, u, v = train_iter(model, varloader) - # ctf = plt.contourf(pi.view(ny, nx).numpy(), levels=50, cmap='jet') - plt.cla() - ct = plt.contour(zs.view(ny, nx).cpu().numpy(), levels=[7000]) - ctf = plt.contourf(u[3].cpu().numpy(), levels=50, cmap='jet') - plt.colorbar(ctf) - # plt.grid(True) - plt.tight_layout() - plt.savefig(f'res2/res{i}.jpeg', dpi=300) - plt.clf() + # # ctf = plt.contourf(pi.view(ny, nx).numpy(), levels=50, cmap='jet') + # plt.cla() + # ct = plt.contour(zs.view(ny, nx).cpu().numpy(), levels=[7000]) + # ctf = plt.contourf(u[3].cpu().numpy(), levels=50, cmap='jet') + # plt.colorbar(ctf) + # # plt.grid(True) + # plt.tight_layout() + # plt.savefig(f'res2/res{i}.jpeg', dpi=300) + # plt.clf() - print(i) \ No newline at end of file + # print(i) From 25a0f0641a8f70af8d2edafe28540ec11271bb42 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 6 Mar 2022 20:50:46 +0800 Subject: [PATCH 0624/1892] fix * expand bugs --- cube/graph/operator/function/einops.py | 18 ++++----- cube/graph/parser/parser.py | 52 +++++++++++++------------- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index ae1ca2f9..22ad464c 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -220,6 +220,13 @@ def parse_shape(self, shape: str) -> List[EinDim]: def identifiers(self) -> Set[str]: return copy.copy(self._identifiers) + def reset_identifiers(self): + self._identifiers = set() + for eshape in self.inputs + self.outputs: + for edim in eshape: + for name in edim.names(): + self._identifiers.add(name) + def __repr__(self) -> str: inputs = ', '.join([repr(input) for input in self.inputs]) outputs = ', '.join(repr(output) for output in self.outputs) @@ -274,7 +281,6 @@ def infer_shape(self) -> bool: self._oannos = oannos if ret: break if not ret: - print(f'self._annos = {self._annos}, self._adapt = {self._adapt}') raise RuntimeError("No matching anno for given annos") dimlen: Dict[str, int] = dict() for input, ishape in zip(self.inputs(), self._iannos): @@ -341,11 +347,9 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei """ parse annotations, assuming input tensor shape is given """ - print(f"anno = {anno}; anno.inputs = {anno.inputs}; self.inputs = {self.inputs()}") if len(anno.inputs) != len(self.inputs()): return False, None, None identifiers = anno.identifiers() - print(f'identifiers = {identifiers}') # expand * expand_dims = None @@ -359,12 +363,10 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei for idx, (names, input) in enumerate(zip(in_names, self.inputs())): if '*' in names: if not isinstance(input, IRTensor): - print('Ln 362') return False, None, None pos = names.index('*') span = len(self.inputs(idx).shape) - (len(names) - 1) if expand_dims is not None and len(expand_dims) != span: - print('Ln 367') return False, None, None if expand_dims is None: expand_dims = [] @@ -373,31 +375,27 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei anno.inputs[idx] = anno.inputs[idx][:pos] + expand_dims + anno.inputs[idx][pos+1:] # * should appear in inputs if expand_dims is None: - print('Ln 376') return False, None, None # go through outputs for idx, names in enumerate(out_names): if '*' in names: pos = names.index('*') anno.outputs[idx] = anno.outputs[idx][:pos] + expand_dims + anno.outputs[idx][pos+1:] + anno.reset_identifiers() # check dimension consistency dimlen: Dict[str, int] = dict() for eshape, input in zip(anno.inputs, self.inputs()): if not isinstance(input, IRTensor): if not (len(eshape) == 1 and eshape[0].name == '1'): - print('Ln 388') return False, None, None else: if len(input.shape) != len(eshape): - print('Ln 392') return False, None, None for edim, nele in zip(eshape, input.shape): if edim.name in dimlen: if nele != dimlen[edim.name]: - print('Ln 397') return False, None, None dimlen[edim.name] = nele - print('Ln 400') return True, anno.inputs, anno.outputs def einexpr(self) -> str: diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index ea1ebf76..30d19bce 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -36,8 +36,6 @@ def parse_module(module, The overall entry to parse a torchscript graph module """ frame.push() - print(module.graph) - print(module.code) # handle graph input -- Assuming all the inputs are tensors input_var_name = [input.debugName() for input in module.graph.inputs()] @@ -58,21 +56,20 @@ def parse_module(module, all_ir_nodes: List[IRFwOperation] = list() for node in module.graph.nodes(): - # debug info - # print(f'on parsing:\n\t{node}') ir_nodes = ScriptModuleParser.parse_node(node, module, frame) - # print(f'> {frame}') - # print(f'> {ir_nodes}') - # _ = input('>>>') - if len(ir_nodes) != 0: - for ir_node in ir_nodes: - try: - ret = ir_node.infer_shape() - if not ret: - print(f'warning: {ir_node} cannot infer shape') - except Exception: - raise RuntimeError(f"Shape infer error at: {ir_node}") - all_ir_nodes += ir_nodes + for ir_node in ir_nodes: + try: + ret = ir_node.infer_shape() + if not ret: + print(f'warning: {ir_node} cannot infer shape') + except Exception: + raise RuntimeError( + f"====== Shape Infer Error ====\n\n\n" + f"IR Node: {ir_node}\n\n" + f"Node:\n{node}\n" + f"====== Shape Infer Error ====\n\n\n" + ) + all_ir_nodes += ir_nodes # handle graph output -- Assuming all the output are tensors output_var_name = [output.debugName() for output in module.graph.outputs()] @@ -107,15 +104,20 @@ def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): all_ir_nodes: List[IRFwOperation] = list() for node in method.graph.nodes(): ir_nodes = ScriptModuleParser.parse_node(node, module, frame) - if len(ir_nodes) != 0: - for ir_node in ir_nodes: - try: - ret = ir_node.infer_shape() - if not ret: - print(f'warning: {ir_node} cannot infer shape') - except Exception: - raise RuntimeError(f"Shape infer error at: {ir_node}") - all_ir_nodes += ir_nodes + for ir_node in ir_nodes: + try: + ret = ir_node.infer_shape() + if not ret: + print(f'warning: {ir_node} cannot infer shape') + except Exception: + raise RuntimeError( + f"====== Shape Infer Error ====\n\n\n" + f"IR Node: {ir_node}\n\n" + f"Module:\n{module.code}\n\n" + f"Node:\n{node}\n" + f"====== Shape Infer Error ====\n\n\n" + ) + all_ir_nodes += ir_nodes # handle graph output -- Assuming all the output are tensors output_var_name = [output.debugName() for output in method.graph.outputs()] From bc06599f553c7d7769426c8f18a602317671dd6e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 6 Mar 2022 21:05:18 +0800 Subject: [PATCH 0625/1892] add 1f1b-pack --- handcraft/pipeline/dummy.py | 145 +++++++++++++++++++----- handcraft/pipeline/run.sh | 4 + handcraft/pipeline/schedule.py | 201 ++++++++++++++++++++++++--------- 3 files changed, 272 insertions(+), 78 deletions(-) diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py index a1e8e9e0..0ee82835 100644 --- a/handcraft/pipeline/dummy.py +++ b/handcraft/pipeline/dummy.py @@ -20,7 +20,7 @@ from cube.runtime.device import DeviceGroup from cube.runtime.syndata import SynDataLoader, SynTextDataLoader -from handcraft.pipeline.schedule import schedule_tp_1f1b, schedule_naive +from handcraft.pipeline.schedule import schedule_tp_1f1b, schedule_naive, schedule_tp_1f1b_pack import argparse """ @@ -58,31 +58,64 @@ def backward(ctx, grad_output): return grad_output +class DummyModelEmbed(torch.nn.Module): + + def __init__(self, M: int, N: int, E: int, stage_id: int, sharding=False): + super().__init__() + self.M = M + self.N = N + self.E = E + self.sharding = sharding + chunk_num = torch.distributed.get_world_size() if sharding else 1 + self.vocab_start_index = N // chunk_num * stage_id + self.vocab_end_index = N // chunk_num * (stage_id + 1) + self.embed_weight = torch.nn.Parameter(torch.zeros((N // chunk_num, E))) + + def input_shape(self): + return (self.M, ) + + def input_dtype(self): + return torch.int64 + + def output_shape(self): + return (self.M, self.E) + + def output_dtype(self): + return torch.float32 + + def forward(self, input: torch.Tensor): + if self.sharding: + mask = (input < self.vocab_start_index) | \ + (input >= self.vocab_end_index) + input = input.clone() - self.vocab_start_index + input[mask] = 0 + input = F.embedding(input, self.embed_weight) + input = ReduceEmbed.apply(input) + else: + input = F.embedding(input, self.embed_weight) + return input + class DummyModel(torch.nn.Module): - def __init__(self, M: int, N: int, E: int, stage_id: int, sharding=False): + def __init__(self, M: int, N: int, E: int, stage_id: int, + sharding=False, embed: torch.nn.Module = None): super().__init__() self.M = M self.N = N self.E = E self.is_last_stage = stage_id == DeviceGroup().world_size - 1 - self.is_first_stage = stage_id == 0 self.sharding = sharding + # mebed module + self.embed = embed # first stage chunk_num = torch.distributed.get_world_size() if sharding else 1 - if self.is_first_stage: - self.vocab_start_index = N // chunk_num * stage_id - self.vocab_end_index = N // chunk_num * (stage_id + 1) - self.embed_weight = torch.nn.Parameter(torch.zeros((N // chunk_num, E))) - self.fc_weight = torch.nn.Parameter(torch.zeros((E // chunk_num, E))) - else: - self.fc_weight = torch.nn.Parameter(torch.zeros((E // chunk_num, E))) + self.fc_weight = torch.nn.Parameter(torch.zeros((E // chunk_num, E))) def input_shape(self): - if self.is_first_stage: - return (self.M,) + if self.embed: + return self.embed.input_shape() else: return (self.M, self.E) @@ -93,8 +126,8 @@ def output_shape(self): return (self.M, self.E) def input_dtype(self): - if self.is_first_stage: - return torch.int64 + if self.embed: + return self.embed.input_dtype() else: return torch.float32 @@ -102,16 +135,8 @@ def output_dtype(self): return torch.float32 def forward(self, input: torch.Tensor): - if self.is_first_stage: - if self.sharding: - mask = (input < self.vocab_start_index) | \ - (input >= self.vocab_end_index) - input = input.clone() - self.vocab_start_index - input[mask] = 0 - input = F.embedding(input, self.embed_weight) - input = ReduceEmbed.apply(input) - else: - input = F.embedding(input, self.embed_weight) + if self.embed: + input = self.embed(input) if self.sharding: input = IdentityFoward.apply(input) @@ -122,6 +147,52 @@ def forward(self, input: torch.Tensor): return output +class DummyModelTP(torch.nn.Module): + + def __init__(self, M: int, N: int, E: int, stage_id: int): + super().__init__() + self.M = M + self.N = N + self.E = E + self.stages = DeviceGroup().world_size + + self.vocab_start_index = N // self.stages * stage_id + self.vocab_end_index = N // self.stages * (stage_id + 1) + self.embed_weight = torch.nn.Parameter(torch.zeros((N // self.stages, E))) + self.fc_weights = torch.nn.ParameterList() + for idx in range(self.stages): + if idx % 2 == 0: + self.fc_weights.append( + torch.nn.Parameter(torch.zeros((E // self.stages, E))) + ) + else: + self.fc_weights.append( + torch.nn.Parameter(torch.zeros((E, E // self.stages))) + ) + + def forward(self, input: torch.Tensor): + mask = (input < self.vocab_start_index) | \ + (input >= self.vocab_end_index) + input = input.clone() - self.vocab_start_index + input[mask] = 0 + input = F.embedding(input, self.embed_weight) + x = ReduceEmbed.apply(input) + for idx in range(self.stages): + # column partition + if idx % 2 == 0: + x = IdentityFoward.apply(x) + x = F.linear(x, self.fc_weights[idx]) + else: + x = ReduceEmbed.apply(x) + x = F.linear(x, self.fc_weights[idx]) + # reduce + if self.stages % 2 == 0: + x = ReduceEmbed.apply(x) + else: + raise RuntimeError("number of stages only supported to be mod 2 == 0") + return torch.sum(x) + + if __name__ == '__main__': parser = argparse.ArgumentParser(description='swin') @@ -129,6 +200,10 @@ def forward(self, input: torch.Tensor): help='use naive pipeline') parser.add_argument('--use-tp1f1b', action='store_true', help='use tensor parallel 1f1b') + parser.add_argument('--use-tp1f1b-pack', action='store_true', + help='use tensor parallel 1f1b') + parser.add_argument('--use-tp', action='store_true', + help='use pure tensor parallelism') parser.add_argument('--nmb', type=int, default=4, help='num of micro batch') parser.add_argument('--M', type=int, default=4096, @@ -138,7 +213,7 @@ def forward(self, input: torch.Tensor): parser.add_argument('--E', type=int, default=2048, help='E dimension length = hidden dimension length') args = parser.parse_args() - assert args.use_naive ^ args.use_tp1f1b, "Specify (only) 1 way pipeline" + print(args) cube.init() @@ -146,7 +221,8 @@ def forward(self, input: torch.Tensor): # tp 1f1b if args.use_tp1f1b: - first_stage_model = DummyModel(args.M, args.N, args.E, 0, sharding=True).cuda() + embed = DummyModelEmbed(args.M, args.N, args.E, 0, sharding=True).cuda() + first_stage_model = DummyModel(args.M, args.N, args.E, 0, sharding=True, embed=embed).cuda() if rank == 0: model = None else: @@ -154,8 +230,18 @@ def forward(self, input: torch.Tensor): if args.use_naive: # naive pipleline + embed = None + if rank == 0: + embed = DummyModelEmbed(args.M, args.N, args.E, 0, sharding=False).cuda() + model = DummyModel(args.M, args.N, args.E, rank, sharding=False, embed=embed).cuda() + + if args.use_tp1f1b_pack: + embed = DummyModelEmbed(args.M, args.N, args.E, 0, sharding=True).cuda() model = DummyModel(args.M, args.N, args.E, rank, sharding=False).cuda() + if args.use_tp: + model = DummyModelTP(args.M, args.N, args.E, rank).cuda() + dataloader = SynTextDataLoader( shapes=([args.M],), dtypes=(torch.int64, ), @@ -173,6 +259,13 @@ def forward(self, input: torch.Tensor): schedule_tp_1f1b(model, first_stage_model, dataloader, args.nmb, DeviceGroup().world_size) if args.use_naive: schedule_naive(model, dataloader, args.nmb) + if args.use_tp1f1b_pack: + schedule_tp_1f1b_pack(model, embed, dataloader, args.nmb, DeviceGroup().world_size) + if args.use_tp: + for _ in range(args.nmb): + data = next(dataloader) + loss = model(data) + loss.backward() if step >= 20: CudaTimer().stop('e2e') diff --git a/handcraft/pipeline/run.sh b/handcraft/pipeline/run.sh index f68d3370..ed60adc1 100755 --- a/handcraft/pipeline/run.sh +++ b/handcraft/pipeline/run.sh @@ -18,6 +18,10 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/pipeline/dummy.py --use-naive --nmb 64 > 4dev64nmb-naive.txt +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/pipeline/dummy.py --use-tp --nmb 64 > 4dev64nmb-tp.txt + + # 8 gpus OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index 6c8ce99e..0624c806 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -277,10 +277,11 @@ def tp_backward(grad: torch.Tensor): # print_each_rank(f'=========end rank {rank}=========') -def schedule_1f1b(model: torch.nn.Module, - dataloader, - num_microbatch: int): - group = list(range(DeviceGroup().world_size)) +def schedule_tp_1f1b_pack(model: torch.nn.Module, + first_stage_model: torch.nn.Module, + dataloader, + num_microbatch: int, + num_stage: int): rank = DeviceGroup().rank next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size @@ -288,55 +289,151 @@ def schedule_1f1b(model: torch.nn.Module, input_tensors = list() output_tensors = list() - # warmup - num_warmup_microbatch = DeviceGroup().world_size - 1 - rank - for mid in range(num_warmup_microbatch): - # recv forward - input = recv_input(model, dataloader, prev_rank) - # forward - output = forward_step(model, input) - # send forward - coll.send(output, next_rank) - input_tensors.append(input) - output_tensors.append(output) + input_1st_tensors = list() + output_1st_tensors = list() - num_warmup_remaining = num_microbatch - num_warmup_microbatch - if num_warmup_remaining > 0: - input = recv_input(model, dataloader, prev_rank) + def tp_forward(fmodel, dataloader) -> torch.Tensor: + input = next(dataloader) + #TODO: gather + output = forward_step(fmodel, input) + input_1st_tensors.append(input) + output_1st_tensors.append(output) + output = output.detach().requires_grad_() + return output - # steady - for i in range(num_warmup_microbatch): - # forward - output = forward_step(model, input) - # send forward + recv backward - grad = coll.sendrecv( - [output], - [list(output.size())], [output.dtype], - [next_rank], [next_rank] - )[0] - input_tensors.append(input) - output_tensors.append(output) - # backward - input, output = input_tensors.pop(0), output_tensors.pop(0) - input_grad = backward_step([input], [output], [grad]) - # send backward recv forward - if i != (num_warmup_remaining-1): - input = coll.sendrecv( - [input_grad], - (list(input.size()),), (input.dtype,), - [prev_rank], [prev_rank] - ) - else: - # send backward - coll.send(input_grad, prev_rank) + def tp_backward(grad: torch.Tensor): + input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) + if rank != 0: + grad = torch.empty_like(output_1st) + torch.distributed.broadcast(grad, src=0) + backward_step([input_1st], [output_1st], [grad])[0] - # cooldown - for i in range(num_warmup_microbatch): - input, output = input_tensors.pop(0), output_tensors.pop(0) - # recv backward - grad = coll.recv(list(output.size()), next_rank, dtype=output.dtype) - # backward - grad = backward_step([input], [output], [grad]) - # send backward - coll.send(grad, prev_rank) + fofst = [-(step // 2) for step in range(num_stage)] + bofst = [-(num_stage - 1 - (step // 2)) for step in range(num_stage)] + # print(fofst) + # print(bofst) + fofst = fofst[rank] + bofst = bofst[rank] + last_backward = None + last_forward = None + for step in range(num_microbatch + num_stage - 1): + torch.distributed.barrier() + # print_each_rank(f'=========begin rank {rank}=========') + fmid, bmid = step + fofst, step + bofst + do_backward = 0 <= bmid and bmid <= num_microbatch - 1 + do_forward = 0 <= fmid and fmid <= num_microbatch - 1 + + # step1: tp forward + if 0 <= step and step <= num_microbatch - 1: + # print(f'rank {rank} forward tp model ') + output_1st = tp_forward(first_stage_model, dataloader) + + # forward + backward + if rank % 2 == 0: + # inter-barrier + if rank == 0: + input = output_1st + else: + if do_forward and last_backward is not None: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + elif last_backward is not None: + # print(f'rank {rank} send backward grad ') + coll.send(last_backward, prev_rank) + + # forward + if do_forward: + output = forward_step(model, input) + input_tensors.append(input) + output_tensors.append(output) + + # intra-barrier send recv + output_grad = None + if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): + # send forward recv backward + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [output], [output.size()], [output.dtype], + [next_rank], [next_rank] + )[0] + elif do_forward and not is_last_stage(): + # print(f'rank {rank} send forward output ') + coll.send(output, next_rank) + elif do_backward and not is_last_stage(): + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + + # backward + if do_backward: + input, output = input_tensors.pop(0), output_tensors.pop(0) + input_grad = backward_step([input], [output], [output_grad])[0] + last_backward = input_grad + + # backward + forward + if rank % 2 == 1: + # inter-barrier + if is_last_stage(): + output_grad = None + else: + if do_backward and last_forward is not None: + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [last_forward], [model.output_shape()], [model.output_dtype()], + [next_rank], [next_rank] + )[0] + elif do_backward: + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + elif last_forward is not None: + # print(f'rank {rank} send forward output ') + coll.send(last_forward, next_rank) + + # backward + if do_backward: + input, output = input_tensors.pop(0), output_tensors.pop(0) + # backward + input_grad = backward_step([input], [output], [output_grad])[0] + + # intra-barrier + if do_backward and do_forward: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_backward: + # print(f'rank {rank} send backward grad ') + coll.send(input_grad, prev_rank) + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + + # forward + last_forward = None + if do_forward: + # forward step + output = forward_step(model, input) + input_tensors.append(input) + output_tensors.append(output) + last_forward = output + + # step3: tp backward + if 0 <= (step-num_stage+1) and (step-num_stage+1) <= num_microbatch - 1: + # print(f'rank {rank} backward tp model ') + tp_backward(last_backward) + + # print_each_rank(f'=========end rank {rank}: {step}=========') + + assert len(input_tensors) == 0 + assert len(output_tensors) == 0 + + assert len(input_1st_tensors) == 0 + assert len(output_1st_tensors) == 0 + # print_each_rank(f'=========end rank {rank}=========') \ No newline at end of file From 655a26ae413f47a363e049cb564b4f023c66c53a Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 7 Mar 2022 11:57:06 +0800 Subject: [PATCH 0626/1892] add neg sin cos, update weather.py --- cube/graph/operator/function/function.py | 35 +++ cube/graph/parser/mapping.py | 6 + examples/atmosphere/weather.py | 300 ++++++++++++++--------- 3 files changed, 223 insertions(+), 118 deletions(-) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 7fbb33a3..530c01ae 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -178,8 +178,43 @@ def Div(signature, inputs): oshape[dim] = lshape[dim] rshape[dim] = str(rhs.shape[dim]) annos = [_create_anno([lshape, rshape], [oshape])] + print(f"Div::annos = {annos}") return IREinops(signature, annos, inputs, 'div') +def Neg(signature, inputs): + annos = ['* -> *'] + tensor = inputs[0:1] + if len(inputs) == 2: + # adapt for newest pytorch version + approximate = inputs[1] + return IREinops(signature, annos, tensor, 'neg', + approximate=approximate) + else: + return IREinops(signature, annos, tensor, 'neg') + +def Sin(signature, inputs): + annos = ['* -> *'] + tensor = inputs[0:1] + if len(inputs) == 2: + # adapt for newest pytorch version + approximate = inputs[1] + return IREinops(signature, annos, tensor, 'sin', + approximate=approximate) + else: + return IREinops(signature, annos, tensor, 'sin') + + +def Cos(signature, inputs): + annos = ['* -> *'] + tensor = inputs[0:1] + if len(inputs) == 2: + # adapt for newest pytorch version + approximate = inputs[1] + return IREinops(signature, annos, tensor, 'cos', + approximate=approximate) + else: + return IREinops(signature, annos, tensor, 'cos') + def GeLU(signature, inputs): annos = ['* -> *'] diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 39546b8a..815d2b0e 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -71,6 +71,12 @@ def register(signature: str, op: IRFwOperation): __ttemplate('div') : function.Div, + __ttemplate('neg'): function.Neg, + + __ttemplate('sin'): function.Sin, + + __ttemplate('cos'): function.Cos, + __ttemplate('bmm') : function.BatchLinear, __ttemplate('sum') : function.Sum, diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index 1224382f..51dae7e9 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -15,7 +15,7 @@ class Atmoshpere(torch.nn.Module): def __init__(self, nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, - bar_x_filter, bar_y_filter, delta_x_filter, delta_y_filter, + bar_x_filter, bar_y_filter, delta_x_filter, delta_y_filter, bar_y2_filter,bar_x2_filter,bar_z_filter,delta_z_filter,delta_E_filter,laplas_filter,bar_xy_filter,delta_D_filter, device='cuda'): super().__init__() #self.device = torch.device(device) @@ -53,6 +53,14 @@ def __init__(self, self.bar_y_filter = bar_y_filter self.delta_x_filter = delta_x_filter self.delta_y_filter = delta_y_filter + self.bar_y2_filter = bar_y2_filter + self.bar_x2_filter = bar_x2_filter + self.bar_z_filter = bar_z_filter + self.delta_z_filter = delta_z_filter + self.delta_E_filter = delta_E_filter + self.laplas_filter = laplas_filter + self.bar_xy_filter = bar_xy_filter + self.delta_D_filter = delta_D_filter self.pre_conv3d_reshape = Rearrange('(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # x = X.view(1, 1, Nz, Ny, Nx) self.post_conv3d_reshape = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') @@ -61,23 +69,24 @@ def __init__(self, def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # flux - F = self.bar_x(self.pad_x(pi)) * u #* self.RE * self.dy # (nz, ny, nx + 1) - G = self.bar_y(self.pad_y(pi)) * v #* self.RE * self.dx # * torch.cos(self.Y) # (nz, ny + 1, nx) - # F = self.bar_x(self.pad_x(pi)) * 0.5 * u * self.RE * self.dy # (nz, ny, nx + 1) - # G = self.bar_y(self.pad_y(pi)) * 0.5 * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) - # B = self.bar_y2(self.bar_x(self.pad_y(F))) / 12. # (nz, ny, nx) - # C = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(G)))) / 12. # (nz, ny + 1, nx + 1) - # D = self.bar_y2(self.pad_y(G)) + self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) - # E = self.bar_y2(self.pad_y(G)) - self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) - # Q = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(F)))) / 12. # (nz, ny + 1, nx + 1) - # R = self.bar_x2(self.bar_y(self.pad_x(G))) / 12. # (nz, ny, nx) - # S = self.bar_y(self.bar_x(self.pad_x(G))) + self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) - # T = self.bar_y(self.bar_x(self.pad_x(G))) - self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) - - # pi1 = pi0 - dt / self.deltaA * ((self.delta_x(F) + self.delta_y(G)) * self.dz).sum(axis=0) # (nz, ny, nx) - tmp0 = ((self.delta_x(F) + self.delta_y(G)) * self.dz) - tmp1 = tmp0.sum(dim=0) - pi1 = pi0 - self.deltaA * tmp1 # pi1 = pi0 - dt / self.deltaA * tmp1 + # F = self.bar_x(self.pad_x(pi)) * u * self.RE * self.dy # (nz, ny, nx + 1) + # G = self.bar_y(self.pad_y(pi)) * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) + F = self.bar_x(self.pad_x(pi)) * 0.5 * u * self.RE * self.dy # (nz, ny, nx + 1) + G = self.bar_y(self.pad_y(pi)) * 0.5 * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) + B = self.bar_y2(self.bar_x(self.pad_y(F))) / 12. # (nz, ny, nx) + C = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(G)))) / 12. # (nz, ny + 1, nx + 1) + D = self.bar_y2(self.pad_y(G)) + self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) + E = self.bar_y2(self.pad_y(G)) - self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) + Q = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(F)))) / 12. # (nz, ny + 1, nx + 1) + R = self.bar_x2(self.bar_y(self.pad_x(G))) / 12. # (nz, ny, nx) + S = self.bar_y(self.bar_x(self.pad_x(G))) + self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) + T = self.bar_y(self.bar_x(self.pad_x(G))) - self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) + + pi1 = pi0 - dt / self.deltaA * ((self.delta_x(F) + self.delta_y(G)) * self.dz).sum(dim=0) #sum(axis=0) # (nz, ny, nx) + # tmp0 = ((self.delta_x(F) + self.delta_y(G)) * self.dz) + # tmp1 = tmp0.sum(dim=0) + # # pi1 = pi0 - self.deltaA * tmp1 # + # pi1 = pi0 - dt / self.deltaA * tmp1 # print('pi:', pi1.mean()) @@ -89,15 +98,16 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # print('w:', self.w.mean()) - # # update potential temperature theta (nz, ny, nx) + # update potential temperature theta (nz, ny, nx) # theta_ = self.pad_z( # (self.bar_z(self.P * theta) - self.delta_z(theta) * self.P_[1:-1]) / self.delta_z(self.P) # ) # (nz + 1, ny, nx) - # theta1 = pi0 / pi1 * theta0 + dt / self.deltaA / pi1 * ( - # (self.delta_x(F * self.bar_x(self.pad_x(theta))) + self.delta_y(G * self.bar_y(self.pad_y(theta)))) / 2. + - # pi * self.deltaA * self.delta_z(self.w * theta_) / self.dz + - # 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(theta)))) - # ) + theta_ = theta0 #TODO remove me + theta1 = pi0 / pi1 * theta0 + dt / self.deltaA / pi1 * ( + (self.delta_x(F * self.bar_x(self.pad_x(theta))) + self.delta_y(G * self.bar_y(self.pad_y(theta)))) / 2. + + pi * self.deltaA * self.delta_z(self.w * theta_) / self.dz + + 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(theta)))) + ) # print('theta:', theta1.mean()) @@ -110,64 +120,72 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # print('phi:', self.phi.mean()) - # # update u (nz, ny, nx + 1) - # pi0_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi0 * self.deltaA)))) / 8. # (nz, ny, nx + 1) - # pi1_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi1 * self.deltaA)))) / 8. # (nz, ny, nx + 1) - # pi_w_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny, nx + 1) - # advec = ( - # - self.delta_x(self.pad_x(B * self.bar_x(u))) - # - self.delta_y(C * self.bar_y(self.pad_y(u))) - # + self.delta_D(self.pad_x(D * self.bar_xy(self.pad_y(u)))) - # + self.delta_E(self.pad_x(E * self.bar_xy(self.pad_y(u)))) - # ) / 2. - # trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(u)) * 0.5) / self.dz - # press = - self.RE * self.dy * ( - # self.delta_x(self.pad_x(self.phi)) * self.delta_x(self.pad_x(pi)) / 2. + - # self.delta_x(self.pad_x(pi)) * 0.5 * self.CPD * self.bar_x(self.pad_x( - # theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) - # )) - # ) - # diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(u)))) - # cori = self.RE * self.dx * self.dy * 0.25 * ( - # self.bar_x(self.pad_x(pi * self.bar_y(v) * (self.f + 0. * self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) - # ) * 0.0 - # u1 = (pi0_deltaA * u0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA - # # print('u1:', u1.mean()) + # update u (nz, ny, nx + 1) + pi0_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi0 * self.deltaA)))) / 8. # (nz, ny, nx + 1) + pi1_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi1 * self.deltaA)))) / 8. # (nz, ny, nx + 1) + pi_w_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny, nx + 1) + advec = ( + - self.delta_x(self.pad_x(B * self.bar_x(u))) + - self.delta_y(C * self.bar_y(self.pad_y(u))) + + self.delta_D(self.pad_x(D * self.bar_xy(self.pad_y(u)))) + + self.delta_E(self.pad_x(E * self.bar_xy(self.pad_y(u)))) + ) / 2. + trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(u)) * 0.5) / self.dz + #TODO fixme press = - self.RE * self.dy * ( + press = self.dy * ( + self.delta_x(self.pad_x(self.phi)) * self.delta_x(self.pad_x(pi)) / 2. + + self.delta_x(self.pad_x(pi)) * 0.5 * self.CPD * self.bar_x(self.pad_x( + theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) + )) + ) + diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(u)))) + #TODO fixme cori = self.RE * self.dx * self.dy * 0.25 * ( + cori = self.dy * ( + self.bar_x(self.pad_x(pi * self.bar_y(v) * (self.f + 0. * self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) + ) * 0.0 + u1 = (pi0_deltaA * u0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA + # print('u1:', u1.mean()) # # update v (nz, ny + 1, nx) - # pi0_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi0 * self.deltaA)))) / 8. # (nz, ny + 1, nx) - # pi1_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi1 * self.deltaA)))) / 8. # (nz, ny + 1, nx) - # pi_w_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny + 1, nx) - # advec = ( - # - self.delta_x(Q * self.bar_x(self.pad_x(v))) - # - self.delta_y(self.pad_y(R * self.bar_y(v))) - # + self.delta_D(self.pad_y(S * self.bar_xy(self.pad_x(v)))) - # + self.delta_E(self.pad_y(T * self.bar_xy(self.pad_x(v)))) - # ) / 2. - # trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(v)) * 0.5) / self.dz - # press = - self.RE * self.dx * ( - # self.delta_y(self.pad_y(self.phi)) * self.delta_y(self.pad_y(pi)) / 2. + - # self.delta_y(self.pad_y(pi)) * 0.5 * self.CPD * self.bar_y(self.pad_y( - # theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) - # )) - # ) - # diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(v)))) - # cori = - self.RE * self.dx * self.dy * 0.25 * ( - # self.bar_y(self.pad_y(pi * self.bar_x(u) * (self.f + self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) - # ) * 0.0 - # v1 = (pi0_deltaA * v0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA + pi0_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi0 * self.deltaA)))) / 8. # (nz, ny + 1, nx) + pi1_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi1 * self.deltaA)))) / 8. # (nz, ny + 1, nx) + pi_w_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny + 1, nx) + advec = ( + - self.delta_x(Q * self.bar_x(self.pad_x(v))) + - self.delta_y(self.pad_y(R * self.bar_y(v))) + + self.delta_D(self.pad_y(S * self.bar_xy(self.pad_x(v)))) + + self.delta_E(self.pad_y(T * self.bar_xy(self.pad_x(v)))) + ) / 2. + trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(v)) * 0.5) / self.dz + #TODO fixme press = - self.RE * self.dx * ( + press = self.dx * ( + self.delta_y(self.pad_y(self.phi)) * self.delta_y(self.pad_y(pi)) / 2. + + self.delta_y(self.pad_y(pi)) * 0.5 * self.CPD * self.bar_y(self.pad_y( + theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) + )) + ) + diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(v)))) + #TODO fixme cori = - self.RE * self.dx * self.dy * 0.25 * ( + cori = self.dy * ( + self.bar_y(self.pad_y(pi * self.bar_x(u) * (self.f + self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) + ) * 0.0 + v1 = (pi0_deltaA * v0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA # # print('v1:', v1.mean()) # return pi1, theta1, u1, v1 # return pi1, theta0, u0, v0 - return pi1, theta0, u0, v0 + return pi1, theta1, u1, v1 def forward(self, pi, theta, u, v, dt): # pi_, theta_, u_, v_ = self.step(dt / 2., pi, theta, u, v, pi, theta, u, v) # return self.step(dt, pi_, theta_, u_, v_, pi, theta, u, v) #TODO recover me - return self.step(dt, pi, theta, u, v, pi, theta, u, v) + # pi1, theta0, u0, v0 = self.step(dt, pi, theta, u, v, pi, theta, u, v) + # return pi1, theta0, u0, v0 + pi1, theta1, u1, v1 = self.step(dt, pi, theta, u, v, pi, theta, u, v) + return pi1, theta1, u1, v1 + def pad_x(self, X): @@ -184,9 +202,11 @@ def bar_x(self, X): return x_unext def bar_x2(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 1, 3) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 2) + # nz, ny, nx = X.shape + # filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 1, 3) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 2) + filter = self.bar_x2_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) def delta_x(self, X): # nz, ny, nx = X.shape @@ -212,9 +232,11 @@ def bar_y(self, X): return x_unext def bar_y2(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 3, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 2, nx) + # nz, ny, nx = X.shape + # filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 3, 1) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 2, nx) + filter = self.bar_y2_filter #torch.tensor([1., 2., 1.]).view(1, 1, 1, 3, 1) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) def delta_y(self, X): # nz, ny, nx = X.shape @@ -223,59 +245,74 @@ def delta_y(self, X): filter = self.delta_y_filter return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) - def bar_z(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) + # nz, ny, nx = X.shape + # filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) + filter = self.bar_z_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) def pad_z(self, X): - nz, ny, nx = X.shape - return F.pad(X.view(1, 1, nz, ny, nx), (0, 0, 0, 0, 1, 1)).view(nz + 2, ny, nx) + # nz, ny, nx = X.shape + # return F.pad(X.view(1, 1, nz, ny, nx), (0, 0, 0, 0, 1, 1)).view(nz + 2, ny, nx) + return F.pad(X, (0, 0, 0, 0, 1, 1)) def delta_z(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([-1., 1.]).view(1, 1, 2, 1, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) + # nz, ny, nx = X.shape + # filter = torch.tensor([-1., 1.]).view(1, 1, 2, 1, 1) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) + filter = self.delta_z_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + def delta_D(self, X): - nz, ny, nx = X.shape - filter = torch.tensor( - [[1., 0.], - [0., -1.]] - ).view(1, 1, 1, 2, 2) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + # nz, ny, nx = X.shape + # filter = torch.tensor( + # [[1., 0.], + # [0., -1.]] + # ).view(1, 1, 1, 2, 2) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + filter = self.delta_D_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) def delta_E(self, X): - nz, ny, nx = X.shape - filter = torch.tensor( - [[0., 1.], - [-1., 0.]] - ).view(1, 1, 1, 2, 2) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + # nz, ny, nx = X.shape + # filter = torch.tensor( + # [[0., 1.], + # [-1., 0.]] + # ).view(1, 1, 1, 2, 2) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + filter = self.delta_E_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + def bar_xy(self, X): - nz, ny, nx = X.shape - filter = torch.tensor( - [[1., 0.], - [0., 1.]] - ).view(1, 1, 1, 2, 2) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + # nz, ny, nx = X.shape + # filter = torch.tensor( + # [[1., 0.], + # [0., 1.]] + # ).view(1, 1, 1, 2, 2) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) + filter = self.bar_xy_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + def laplas(self, X): - nz, ny, nx = X.shape - filter = torch.tensor( - [[[0., 0., 0.], - [0., 1., 0.], - [0., 0., 0.]], - [[0., 1., 0.], - [1., -6, 1.], - [0., 1., 0.]], - [[0., 0., 0.], - [0., 1., 0.], - [0., 0., 0.]]], - ).view(1, 1, 3, 3, 3) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 2, ny - 2, nx - 2) + # nz, ny, nx = X.shape + # filter = torch.tensor( + # [[[0., 0., 0.], + # [0., 1., 0.], + # [0., 0., 0.]], + # [[0., 1., 0.], + # [1., -6, 1.], + # [0., 1., 0.]], + # [[0., 0., 0.], + # [0., 1., 0.], + # [0., 0., 0.]]], + # ).view(1, 1, 3, 3, 3) + # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 2, ny - 2, nx - 2) + filter = self.laplas_filter + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) class LoopVariables(cube.runtime.syndata.CubeDataLoader): @@ -283,7 +320,7 @@ class LoopVariables(cube.runtime.syndata.CubeDataLoader): def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): # for var in variables + constants: # print("### var = {}, type = {}".format(var, type(var))) - shapes = [list(var.size()) for var in variables + constants] + shapes = [list(var.size() if len(var.size()) > 0 else [1]) for var in variables + constants] dtypes = [var.dtype for var in variables + constants] batch_dims = [0] * (len(variables) + len(constants)) super().__init__(shapes, dtypes, batch_dims) @@ -469,10 +506,37 @@ def init(ps, pt, zs): bar_y_filter = torch.tensor([1., 1.]).view(1, 1, 1, 2, 1) delta_x_filter = torch.tensor([-1., 1.]).view(1, 1, 1, 1, 2) delta_y_filter = torch.tensor([-1., 1.]).view(1, 1, 1, 2, 1) + bar_y2_filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 3, 1) + bar_x2_filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 1, 3) + bar_z_filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) + delta_z_filter = torch.tensor([-1., 1.]).view(1, 1, 2, 1, 1) + delta_E_filter = torch.tensor( + [[0., 1.], + [-1., 0.]] + ).view(1, 1, 1, 2, 2) + laplas_filter = torch.tensor( + [[[0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.]], + [[0., 1., 0.], + [1., -6, 1.], + [0., 1., 0.]], + [[0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.]]], + ).view(1, 1, 3, 3, 3) + bar_xy_filter = torch.tensor( + [[1., 0.], + [0., 1.]] + ).view(1, 1, 1, 2, 2) + delta_D_filter = torch.tensor( + [[1., 0.], + [0., -1.]] + ).view(1, 1, 1, 2, 2) model = Atmoshpere(nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, - bar_x_filter, bar_y_filter, delta_x_filter, delta_y_filter) + bar_x_filter, bar_y_filter, delta_x_filter, delta_y_filter, bar_y2_filter, bar_x2_filter, bar_z_filter, delta_z_filter, delta_E_filter, laplas_filter,bar_xy_filter,delta_D_filter) print("[pi, theta, u, v, dt]") for var in [pi, theta, u, v, dt]: From 9764b4a889e333dffe67680a04b1a7abfd3f6671 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Mar 2022 13:50:18 +0800 Subject: [PATCH 0627/1892] correctness verify --- handcraft/pipeline/dummy.py | 70 ++++++++++++++++++++++++---------- handcraft/pipeline/schedule.py | 19 ++++++++- 2 files changed, 67 insertions(+), 22 deletions(-) diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py index 0ee82835..41db0151 100644 --- a/handcraft/pipeline/dummy.py +++ b/handcraft/pipeline/dummy.py @@ -34,6 +34,8 @@ Condition: N > 8M - E """ +io_input = input + class ReduceEmbed(torch.autograd.Function): @staticmethod @@ -60,16 +62,17 @@ def backward(ctx, grad_output): class DummyModelEmbed(torch.nn.Module): - def __init__(self, M: int, N: int, E: int, stage_id: int, sharding=False): + def __init__(self, M: int, N: int, E: int, sharding=False): super().__init__() self.M = M self.N = N self.E = E self.sharding = sharding - chunk_num = torch.distributed.get_world_size() if sharding else 1 + chunk_num = DeviceGroup().world_size if sharding else 1 + stage_id = DeviceGroup().rank if sharding else 1 self.vocab_start_index = N // chunk_num * stage_id self.vocab_end_index = N // chunk_num * (stage_id + 1) - self.embed_weight = torch.nn.Parameter(torch.zeros((N // chunk_num, E))) + self.embed_weight = torch.nn.Parameter(torch.ones((N // chunk_num, E))) def input_shape(self): return (self.M, ) @@ -90,6 +93,7 @@ def forward(self, input: torch.Tensor): input = input.clone() - self.vocab_start_index input[mask] = 0 input = F.embedding(input, self.embed_weight) + input[mask, :] = 0.0 input = ReduceEmbed.apply(input) else: input = F.embedding(input, self.embed_weight) @@ -111,7 +115,7 @@ def __init__(self, M: int, N: int, E: int, stage_id: int, self.embed = embed # first stage chunk_num = torch.distributed.get_world_size() if sharding else 1 - self.fc_weight = torch.nn.Parameter(torch.zeros((E // chunk_num, E))) + self.fc_weight = torch.nn.Parameter(torch.ones((E // chunk_num, E)) / 10000) def input_shape(self): if self.embed: @@ -135,6 +139,7 @@ def output_dtype(self): return torch.float32 def forward(self, input: torch.Tensor): + # print(f'[{DeviceGroup().rank}] input: {input}, shape={input.size()}') if self.embed: input = self.embed(input) @@ -144,6 +149,7 @@ def forward(self, input: torch.Tensor): if self.is_last_stage: output = torch.sum(output) + # print(f'[{DeviceGroup().rank}] output: {output}, shape={output.size()}') return output @@ -158,39 +164,39 @@ def __init__(self, M: int, N: int, E: int, stage_id: int): self.vocab_start_index = N // self.stages * stage_id self.vocab_end_index = N // self.stages * (stage_id + 1) - self.embed_weight = torch.nn.Parameter(torch.zeros((N // self.stages, E))) + self.embed = DummyModelEmbed(M, N, E, sharding=True) self.fc_weights = torch.nn.ParameterList() for idx in range(self.stages): if idx % 2 == 0: self.fc_weights.append( - torch.nn.Parameter(torch.zeros((E // self.stages, E))) + torch.nn.Parameter(torch.ones((E // self.stages, E)) / 10000) ) else: self.fc_weights.append( - torch.nn.Parameter(torch.zeros((E, E // self.stages))) + torch.nn.Parameter(torch.ones((E, E // self.stages)) / 10000) ) def forward(self, input: torch.Tensor): - mask = (input < self.vocab_start_index) | \ - (input >= self.vocab_end_index) - input = input.clone() - self.vocab_start_index - input[mask] = 0 - input = F.embedding(input, self.embed_weight) - x = ReduceEmbed.apply(input) + x = self.embed(input) + # print(f'embed: {x}') for idx in range(self.stages): # column partition if idx % 2 == 0: x = IdentityFoward.apply(x) x = F.linear(x, self.fc_weights[idx]) else: - x = ReduceEmbed.apply(x) x = F.linear(x, self.fc_weights[idx]) + x = ReduceEmbed.apply(x) + # print(f'linear: {x}') # reduce - if self.stages % 2 == 0: - x = ReduceEmbed.apply(x) - else: + if self.stages % 2 != 0: raise RuntimeError("number of stages only supported to be mod 2 == 0") - return torch.sum(x) + loss = torch.sum(x) + # print(loss) + # if rank == 0: + # io_input(f'{step}>>>') + # torch.distributed.barrier() + return loss if __name__ == '__main__': @@ -221,26 +227,41 @@ def forward(self, input: torch.Tensor): # tp 1f1b if args.use_tp1f1b: - embed = DummyModelEmbed(args.M, args.N, args.E, 0, sharding=True).cuda() + embed = DummyModelEmbed(args.M, args.N, args.E, sharding=True).cuda() first_stage_model = DummyModel(args.M, args.N, args.E, 0, sharding=True, embed=embed).cuda() if rank == 0: model = None else: model = DummyModel(args.M, args.N, args.E, rank, sharding=False).cuda() + # optimizer + if rank == 0: + parameters = first_stage_model.parameters() + else: + parameters = list(first_stage_model.parameters()) + list(model.parameters()) + optimizer = torch.optim.Adam(parameters) if args.use_naive: # naive pipleline embed = None if rank == 0: - embed = DummyModelEmbed(args.M, args.N, args.E, 0, sharding=False).cuda() + embed = DummyModelEmbed(args.M, args.N, args.E, sharding=False).cuda() model = DummyModel(args.M, args.N, args.E, rank, sharding=False, embed=embed).cuda() + # optimizer + optimizer = torch.optim.Adam(model.parameters()) + if args.use_tp1f1b_pack: - embed = DummyModelEmbed(args.M, args.N, args.E, 0, sharding=True).cuda() + embed = DummyModelEmbed(args.M, args.N, args.E, sharding=True).cuda() model = DummyModel(args.M, args.N, args.E, rank, sharding=False).cuda() + optimizer = torch.optim.Adam(list(embed.parameters()) + list(model.parameters())) if args.use_tp: model = DummyModelTP(args.M, args.N, args.E, rank).cuda() + optimizer = torch.optim.Adam(model.parameters()) + + # 0.11GB + print_each_rank('model consumption') + memory_summary() dataloader = SynTextDataLoader( shapes=([args.M],), @@ -249,6 +270,10 @@ def forward(self, input: torch.Tensor): length=128000 ) + # 0.11GB + print_each_rank('model + dataloader consumption') + memory_summary() + iter_num = 64 CudaTimer(enable=False).warmup() for step in range(iter_num): @@ -266,6 +291,9 @@ def forward(self, input: torch.Tensor): data = next(dataloader) loss = model(data) loss.backward() + + optimizer.step() + optimizer.zero_grad() if step >= 20: CudaTimer().stop('e2e') diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index 0624c806..d9d0d0ee 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -2,9 +2,12 @@ import torch from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary import cube.runtime.adapter.collectives as coll from cube.runtime.device import DeviceGroup +io_input = input + def forward_step(model, *args, **kwargs): """ @@ -56,10 +59,11 @@ def schedule_naive(model, dataloader, num_microbatch: int): rank = DeviceGroup().rank next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size - for _ in range(num_microbatch): + for step in range(num_microbatch): # recv forward if is_first_stage(): input = next(dataloader) + print(input) else: # print(f'rank {rank} recving forward input...') input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) @@ -81,6 +85,11 @@ def schedule_naive(model, dataloader, num_microbatch: int): # print(f'rank {rank} sending backward output...') coll.send(input_grad, prev_rank) + # memory_summary() + # if rank == 0: + # io_input(f'{step}>>>') + # torch.distributed.barrier() + def schedule_tp_1f1b(model: torch.nn.Module, first_stage_model: torch.nn.Module, @@ -353,6 +362,9 @@ def tp_backward(grad: torch.Tensor): input_tensors.append(input) output_tensors.append(output) + # mem = torch.cuda.max_memory_allocated() + # print(f'rank {rank}: {mem / 1024 / 1024 / 1024} GB forward') + # intra-barrier send recv output_grad = None if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): @@ -428,6 +440,11 @@ def tp_backward(grad: torch.Tensor): # print(f'rank {rank} backward tp model ') tp_backward(last_backward) + # memory_summary() + # print_each_rank(f'{len(input_1st_tensors)}, {len(input_tensors)}') + # if rank == 0: + # io_input(f'{step}>>>') + # torch.distributed.barrier() # print_each_rank(f'=========end rank {rank}: {step}=========') assert len(input_tensors) == 0 From 99efe79600b31b8471446638475beb7633c8030e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Mar 2022 13:50:47 +0800 Subject: [PATCH 0628/1892] fix sendrecv bug --- cube/runtime/adapter/collectives.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 54732eaa..8300b80d 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -58,10 +58,10 @@ def recv(shape: List[int], from_rank: int, dtype: torch.dtype): def sendrecv(input_tensors: List[torch.Tensor], - output_shapes: List[List[int]], - output_dtypes: List[torch.dtype], - send_ranks: List[int], - recv_ranks: List[int]) -> List[torch.Tensor]: + output_shapes: List[List[int]], + output_dtypes: List[torch.dtype], + send_ranks: List[int], + recv_ranks: List[int]) -> List[torch.Tensor]: CudaTimer().start(field_name='comm') # print('sending and recving...') ops = list() @@ -86,7 +86,7 @@ def sendrecv(input_tensors: List[torch.Tensor], reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: req.wait() - + torch.cuda.synchronize() CudaTimer().stop(field_name='comm') return outputs From f5f0d498d7bef08053c99b8379af196278df4293 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Mar 2022 14:02:03 +0800 Subject: [PATCH 0629/1892] fix gather bug --- cube/runtime/adapter/collectives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 8300b80d..647ffcd3 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -219,6 +219,7 @@ def gather(input_tensors: List[torch.Tensor], reqs = torch.distributed.batch_isend_irecv([send_op]) for req in reqs: req.wait() + torch.cuda.synchronize() CudaTimer().stop(field_name='comm') return tensor_list From e5d926a7ad5d2504e95c581a629aa304fbdf400a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Mar 2022 14:02:41 +0800 Subject: [PATCH 0630/1892] tp1f1b correctness verify --- handcraft/pipeline/schedule.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index d9d0d0ee..26f8776b 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -283,6 +283,10 @@ def tp_backward(grad: torch.Tensor): # print(f'rank {rank} backward tp model ') tp_backward(last_backward) + # if rank == 0: + # io_input(f'{step}>>>') + # torch.distributed.barrier() + # print_each_rank(f'=========end rank {rank}=========') From dc7d1302bb8f186e336943c981faa672633910d2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Mar 2022 14:08:05 +0800 Subject: [PATCH 0631/1892] setup script --- handcraft/pipeline/run.sh | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/handcraft/pipeline/run.sh b/handcraft/pipeline/run.sh index ed60adc1..080b7306 100755 --- a/handcraft/pipeline/run.sh +++ b/handcraft/pipeline/run.sh @@ -1,44 +1,28 @@ # 4 gpus OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp1f1b --nmb 4 > 4dev4nmb-tp1f1b.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp1f1b --nmb 8 > 4dev8nmb-tp1f1b.txt + handcraft/pipeline/dummy.py --use-tp --nmb 64 > 4dev64nmb-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/pipeline/dummy.py --use-tp1f1b --nmb 64 > 4dev64nmb-tp1f1b.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive --nmb 4 > 4dev4nmb-naive.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive --nmb 8 > 4dev8nmb-naive.txt + handcraft/pipeline/dummy.py --use-tp1f1b-pack --nmb 64 > 4dev64nmb-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/pipeline/dummy.py --use-naive --nmb 64 > 4dev64nmb-naive.txt -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp --nmb 64 > 4dev64nmb-tp.txt - # 8 gpus OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp1f1b --nmb 8 > 8dev8nmb-tp1f1b.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp1f1b --nmb 16 > 8dev16nmb-tp1f1b.txt + handcraft/pipeline/dummy.py --use-tp --nmb 128 > 8dev128nmb-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ handcraft/pipeline/dummy.py --use-tp1f1b --nmb 128 > 8dev128nmb-tp1f1b.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive --nmb 8 > 8dev8nmb-naive.txt + handcraft/pipeline/dummy.py --use-tp1f1b-pack --nmb 128 > 8dev128nmb-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive --nmb 16 > 8dev16nmb-naive.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive --nmb 128 > 8dev128nmb-naive.txt - + handcraft/pipeline/dummy.py --use-naive --nmb 128 > 8dev128nmb-naive.txt \ No newline at end of file From 28677430e178fbd74f86d44483176853e3a8349a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Mar 2022 14:14:47 +0800 Subject: [PATCH 0632/1892] scale bug in tp1f1b --- handcraft/pipeline/schedule.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index 26f8776b..5c9e4237 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -144,7 +144,7 @@ def tp_backward(grad: torch.Tensor): bofst = bofst[rank] last_backward = None last_forward = None - for step in range(num_microbatch + 2): + for step in range(num_microbatch + num_stage - 2): torch.distributed.barrier() # print_each_rank(f'=========begin rank {rank}=========') fmid, bmid = step + fofst, step + bofst @@ -287,6 +287,11 @@ def tp_backward(grad: torch.Tensor): # io_input(f'{step}>>>') # torch.distributed.barrier() + assert len(input_tensors) == 0 + assert len(output_tensors) == 0 + assert len(input_1st_tensors) == 0 + assert len(output_1st_tensors) == 0 + # print_each_rank(f'=========end rank {rank}=========') @@ -453,7 +458,6 @@ def tp_backward(grad: torch.Tensor): assert len(input_tensors) == 0 assert len(output_tensors) == 0 - assert len(input_1st_tensors) == 0 assert len(output_1st_tensors) == 0 From 4f9a7c2aab407ca07e0e1b8b0e91ea9e691d6b4b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Mar 2022 14:51:19 +0800 Subject: [PATCH 0633/1892] fix scaling bug --- handcraft/pipeline/schedule.py | 66 ++++++++++------------------------ 1 file changed, 19 insertions(+), 47 deletions(-) diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index 5c9e4237..bfbc57f7 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -63,7 +63,6 @@ def schedule_naive(model, dataloader, num_microbatch: int): # recv forward if is_first_stage(): input = next(dataloader) - print(input) else: # print(f'rank {rank} recving forward input...') input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) @@ -160,7 +159,7 @@ def tp_backward(grad: torch.Tensor): if rank == 0: pass - if rank != 0 and rank % 2 == 0: + if rank % 2 == 0 and rank != 0: # inter-barrier if do_backward and last_forward is not None: # print(f'rank {rank} recv backward grad + send forward output ') @@ -203,52 +202,24 @@ def tp_backward(grad: torch.Tensor): input_tensors.append(input) output_tensors.append(output) last_forward = output - - if rank == 1: - - # forward - if do_forward: - input = output_1st - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) - - # intra-barrier send recv - if do_forward and do_backward: - # send forward recv backward - # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [output], [output.size()], [output.dtype], - [next_rank], [next_rank] - )[0] - elif do_forward: - # print(f'rank {rank} send forward output ') - coll.send(output, next_rank) - elif do_backward: - # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) - - # backward - if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) - input_grad = backward_step([input], [output], [output_grad])[0] - last_backward = input_grad - if rank != 1 and rank % 2 == 1: - + if rank % 2 == 1: # inter-barrier - if do_forward and last_backward is not None: - # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] - elif do_forward: - # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - elif last_backward is not None: - # print(f'rank {rank} send backward grad ') - coll.send(last_backward, prev_rank) + if rank == 1: + input = output_1st + else: + if do_forward and last_backward is not None: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + elif last_backward is not None: + # print(f'rank {rank} send backward grad ') + coll.send(last_backward, prev_rank) # forward if do_forward: @@ -273,6 +244,7 @@ def tp_backward(grad: torch.Tensor): output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) # backward + forward + last_backward = None if do_backward: input, output = input_tensors.pop(0), output_tensors.pop(0) input_grad = backward_step([input], [output], [output_grad])[0] @@ -391,6 +363,7 @@ def tp_backward(grad: torch.Tensor): output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) # backward + last_backward = None if do_backward: input, output = input_tensors.pop(0), output_tensors.pop(0) input_grad = backward_step([input], [output], [output_grad])[0] @@ -450,7 +423,6 @@ def tp_backward(grad: torch.Tensor): tp_backward(last_backward) # memory_summary() - # print_each_rank(f'{len(input_1st_tensors)}, {len(input_tensors)}') # if rank == 0: # io_input(f'{step}>>>') # torch.distributed.barrier() From c3c74ed400c8e077f45b906528c927626a2b76ce Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 7 Mar 2022 15:55:34 +0800 Subject: [PATCH 0634/1892] weather.py updated --- examples/atmosphere/weather.py | 142 ++++++++------------------------- 1 file changed, 34 insertions(+), 108 deletions(-) diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index 51dae7e9..862ab0d6 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -15,7 +15,9 @@ class Atmoshpere(torch.nn.Module): def __init__(self, nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, - bar_x_filter, bar_y_filter, delta_x_filter, delta_y_filter, bar_y2_filter,bar_x2_filter,bar_z_filter,delta_z_filter,delta_E_filter,laplas_filter,bar_xy_filter,delta_D_filter, + bar_x_filter, bar_y_filter, bar_z_filter, + bar_x2_filter, bar_y2_filter, bar_xy_filter, + delta_x_filter, delta_y_filter, delta_z_filter, delta_D_filter, delta_E_filter, laplas_filter, device='cuda'): super().__init__() #self.device = torch.device(device) @@ -66,11 +68,8 @@ def __init__(self, self.post_conv3d_reshape = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') - def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # flux - # F = self.bar_x(self.pad_x(pi)) * u * self.RE * self.dy # (nz, ny, nx + 1) - # G = self.bar_y(self.pad_y(pi)) * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) F = self.bar_x(self.pad_x(pi)) * 0.5 * u * self.RE * self.dy # (nz, ny, nx + 1) G = self.bar_y(self.pad_y(pi)) * 0.5 * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) B = self.bar_y2(self.bar_x(self.pad_y(F))) / 12. # (nz, ny, nx) @@ -83,11 +82,6 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): T = self.bar_y(self.bar_x(self.pad_x(G))) - self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) pi1 = pi0 - dt / self.deltaA * ((self.delta_x(F) + self.delta_y(G)) * self.dz).sum(dim=0) #sum(axis=0) # (nz, ny, nx) - # tmp0 = ((self.delta_x(F) + self.delta_y(G)) * self.dz) - # tmp1 = tmp0.sum(dim=0) - # # pi1 = pi0 - self.deltaA * tmp1 # - # pi1 = pi0 - dt / self.deltaA * tmp1 - # print('pi:', pi1.mean()) # TODO a custom Op needed @@ -176,147 +170,76 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # return pi1, theta0, u0, v0 return pi1, theta1, u1, v1 - def forward(self, pi, theta, u, v, dt): - # pi_, theta_, u_, v_ = self.step(dt / 2., pi, theta, u, v, pi, theta, u, v) - # return self.step(dt, pi_, theta_, u_, v_, pi, theta, u, v) - #TODO recover me - # pi1, theta0, u0, v0 = self.step(dt, pi, theta, u, v, pi, theta, u, v) - # return pi1, theta0, u0, v0 - pi1, theta1, u1, v1 = self.step(dt, pi, theta, u, v, pi, theta, u, v) + def forward(self, pi, theta, u, v, dt): + pi_, theta_, u_, v_ = self.step(dt / 2., pi, theta, u, v, pi, theta, u, v) + pi1, theta1, u1, v1 = self.step(dt, pi_, theta_, u_, v_, pi, theta, u, v) return pi1, theta1, u1, v1 - def pad_x(self, X): return F.pad(X, (1, 1), "circular") + def bar_x(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) - filter = self.bar_x_filter - x_ext = self.pre_conv3d_reshape(X) - x_convd = F.conv3d(x_ext, filter) - x_unext = self.post_conv3d_reshape(x_convd) - return x_unext + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_x_filter)) + def bar_x2(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 1, 3) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 2) - filter = self.bar_x2_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_x2_filter)) + def delta_x(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor([-1., 1.]).view(1, 1, 1, 1, 2) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) - filter = self.delta_x_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_x_filter)) def pad_y(self, X): - # nz, ny, nx = X.shape #TODO check return F.pad(X.view(1, nz, ny, nx), (0, 0, 1, 1), "circular").view(nz, ny + 2, nx) return F.pad(X, (0, 0, 1, 1), "circular") + def bar_y(self, X): - nz, ny, nx = X.shape - # filter = torch.tensor([1., 1.]).view(1, 1, 1, 2, 1) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) - filter = self.bar_y_filter - x_ext = self.pre_conv3d_reshape(X) - x_convd = F.conv3d(x_ext, filter) - x_unext = self.post_conv3d_reshape(x_convd) - return x_unext + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_y_filter)) + def bar_y2(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 3, 1) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 2, nx) - filter = self.bar_y2_filter #torch.tensor([1., 2., 1.]).view(1, 1, 1, 3, 1) - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_y2_filter)) + def delta_y(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor([-1., 1.]).view(1, 1, 1, 2, 1) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) - filter = self.delta_y_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_y_filter)) + def bar_z(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) - filter = self.bar_z_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_z_filter)) + def pad_z(self, X): - # nz, ny, nx = X.shape # return F.pad(X.view(1, 1, nz, ny, nx), (0, 0, 0, 0, 1, 1)).view(nz + 2, ny, nx) return F.pad(X, (0, 0, 0, 0, 1, 1)) + def delta_z(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor([-1., 1.]).view(1, 1, 2, 1, 1) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) - filter = self.delta_z_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_z_filter)) def delta_D(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor( - # [[1., 0.], - # [0., -1.]] - # ).view(1, 1, 1, 2, 2) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) - filter = self.delta_D_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_D_filter)) + def delta_E(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor( - # [[0., 1.], - # [-1., 0.]] - # ).view(1, 1, 1, 2, 2) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) - filter = self.delta_E_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_E_filter)) def bar_xy(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor( - # [[1., 0.], - # [0., 1.]] - # ).view(1, 1, 1, 2, 2) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx - 1) - filter = self.bar_xy_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_xy_filter)) def laplas(self, X): - # nz, ny, nx = X.shape - # filter = torch.tensor( - # [[[0., 0., 0.], - # [0., 1., 0.], - # [0., 0., 0.]], - # [[0., 1., 0.], - # [1., -6, 1.], - # [0., 1., 0.]], - # [[0., 0., 0.], - # [0., 1., 0.], - # [0., 0., 0.]]], - # ).view(1, 1, 3, 3, 3) - # return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 2, ny - 2, nx - 2) - filter = self.laplas_filter - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), filter)) + return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.laplas_filter)) class LoopVariables(cube.runtime.syndata.CubeDataLoader): - def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): # for var in variables + constants: # print("### var = {}, type = {}".format(var, type(var))) @@ -338,15 +261,18 @@ def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]) def __iter__(self): return self + def update(self, variables: List[torch.Tensor] = None, constants: List[torch.Tensor] = None): if variables is not None: self.variables = variables if constants is not None: self.constants = constants + def reset(self, batch_size): pass + def __next__(self): if len(self.variables) + len(self.constants) == 1: return (self.variables + self.constants)[0] @@ -496,7 +422,6 @@ def init(ps, pt, zs): return pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w - pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w = init(ps, pt, zs) print("[pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w]") for var in [pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w]: @@ -534,9 +459,10 @@ def init(ps, pt, zs): [0., -1.]] ).view(1, 1, 1, 2, 2) - model = Atmoshpere(nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, - bar_x_filter, bar_y_filter, delta_x_filter, delta_y_filter, bar_y2_filter, bar_x2_filter, bar_z_filter, delta_z_filter, delta_E_filter, laplas_filter,bar_xy_filter,delta_D_filter) + bar_x_filter, bar_y_filter, bar_z_filter, + bar_x2_filter, bar_y2_filter, bar_xy_filter, + delta_x_filter, delta_y_filter, delta_z_filter, delta_D_filter, delta_E_filter, laplas_filter) print("[pi, theta, u, v, dt]") for var in [pi, theta, u, v, dt]: From ac12c910f18ea60ab058fa2ddc82b663823569eb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Mar 2022 13:39:17 +0800 Subject: [PATCH 0635/1892] paser handle on tupleunpack --- cube/graph/operator/function/einops.py | 37 +++-- cube/graph/operator/function/function.py | 5 +- cube/graph/parser/frame.py | 4 +- cube/graph/parser/parser.py | 172 ++++++++++------------- 4 files changed, 94 insertions(+), 124 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 22ad464c..f0cbe1b5 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -147,7 +147,7 @@ class EinopAnno: def __init__(self, anno: str): """ initializing annotations specfied in str, e.g., - a (b c) d, d k -> a (b c) k + a (b c) d+, d+ k -> a (b c) k """ if not isinstance(anno, str): raise TypeError("Expected anno to be str") @@ -166,13 +166,14 @@ def __init__(self, anno: str): self.outputs: List[List[EinDim]] = [ self.parse_shape(shape) for shape in outputs ] + self.reset_identifiers() def parse_shape(self, shape: str) -> List[EinDim]: """ parsing annotations like of a single shape, e.g., - a (b dim) d + a (b+ dim) d^ """ - # => ['a', '(', 'b', 'dim', ')', 'd'] + # => ['a', '(', 'b+', 'dim', ')', 'd^'] shapes = list() for group in re.split('\ +', shape): if len(group) == 0: @@ -183,7 +184,7 @@ def parse_shape(self, shape: str) -> List[EinDim]: shapes.append(group) else: shapes.append(group) - identifiers: List[List[str]] = list() + edims: List[List[str]] = list() current_identifier = list() bracket_group = False for w in shapes: @@ -192,30 +193,28 @@ def parse_shape(self, shape: str) -> List[EinDim]: raise RuntimeError("brackets inside brackets not allowed") bracket_group = True if len(current_identifier) > 0: - identifiers.append(current_identifier) + edims.append(current_identifier) current_identifier = list() elif w == ')': if not bracket_group: raise RuntimeError("backets are not balanced at (") bracket_group = False if len(current_identifier) > 0: - identifiers.append(current_identifier) + edims.append(current_identifier) current_identifier = list() else: if bracket_group: current_identifier.append(w) - self._identifiers.add(w) else: if len(current_identifier) > 0: - identifiers.append(current_identifier) + edims.append(current_identifier) current_identifier = [w] - self._identifiers.add(w) if bracket_group: raise RuntimeError("brackets are not balanced at )") if len(current_identifier) != 0: - identifiers.append(current_identifier) - identifiers = [EinDim(identifer) for identifer in identifiers] - return identifiers + edims.append(current_identifier) + edims = [EinDim(edim) for edim in edims] + return edims def identifiers(self) -> Set[str]: return copy.copy(self._identifiers) @@ -355,29 +354,29 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei expand_dims = None if '*' in identifiers: # names - in_names = [[e.name for e in input] for input in anno.inputs] - out_names = [[e.name for e in out] for out in anno.outputs] - spatial = all(['*' in names for names in out_names]) - candicates = [c if spatial else c + '^' for c in string.ascii_lowercase if c not in identifiers] + candicates = [c for c in string.ascii_lowercase if c not in identifiers] # go through inputs - for idx, (names, input) in enumerate(zip(in_names, self.inputs())): + for idx, (eshape, input) in enumerate(zip(anno.inputs, self.inputs())): + names = [edim.name for edim in eshape] if '*' in names: if not isinstance(input, IRTensor): return False, None, None pos = names.index('*') + split = eshape[pos].reduce[0].value span = len(self.inputs(idx).shape) - (len(names) - 1) if expand_dims is not None and len(expand_dims) != span: return False, None, None if expand_dims is None: expand_dims = [] if span > 0: - expand_dims = [EinDim(candicates[dim]) for dim in range(span)] + expand_dims = [EinDim(candicates[dim]+split) for dim in range(span)] anno.inputs[idx] = anno.inputs[idx][:pos] + expand_dims + anno.inputs[idx][pos+1:] # * should appear in inputs if expand_dims is None: return False, None, None # go through outputs - for idx, names in enumerate(out_names): + for idx, eshape in enumerate(anno.outputs): + names = [edim.name for edim in eshape] if '*' in names: pos = names.index('*') anno.outputs[idx] = anno.outputs[idx][:pos] + expand_dims + anno.outputs[idx][pos+1:] diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 530c01ae..d928e9c7 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -178,7 +178,6 @@ def Div(signature, inputs): oshape[dim] = lshape[dim] rshape[dim] = str(rhs.shape[dim]) annos = [_create_anno([lshape, rshape], [oshape])] - print(f"Div::annos = {annos}") return IREinops(signature, annos, inputs, 'div') def Neg(signature, inputs): @@ -249,7 +248,7 @@ def Dropout(signature, inputs): def Sum(signature, inputs): # TODO: support dim reduction annos = [ - '* -> 1', + '*+ -> 1', ] tensor = inputs[0:1] dim = inputs[1] @@ -258,7 +257,6 @@ def Sum(signature, inputs): dim_len = len(tensor[0].shape) anno = "".join([f'b{i} ' for i in range(dim_len)]) + " -> " + "".join([f'b{i} ' if i not in dim else "" for i in range(dim_len)]) annos.append(anno) - # print("### Sum::anno = {}", annos) return IREinops(signature, annos, tensor, 'sum', dim=dim, keepdim=keepdim) else: @@ -288,7 +286,6 @@ def View(signature, inputs): assert len(inputs) == 2 input, shape = inputs in_shape, ou_shape = list(input.shape), shape - print(in_shape, ou_shape) # shape check def nele(shape, nele=1): diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index 72feba1d..bf19c035 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -35,8 +35,8 @@ def add_var(self, var_name: str, val: Any, graph_arg: int = -1): val: variable content graph_arg (int): indicate whether it is an argument of the graph. - If is 0, is not a graph arg. - If > 0, is a graph arg, will try to find + If == -1, is not a graph arg. + If >= 0, is a graph arg, will try to find val from previous frame """ if not isinstance(var_name, str): diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 30d19bce..64a572a0 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -1,7 +1,7 @@ import torch import enum import re -from typing import List, Tuple, Optional, Union +from typing import List, Tuple, Optional from cube.graph import IRFwOperation from cube.graph.tensor import IRFullTensor @@ -30,30 +30,31 @@ class ScriptModuleParser: @staticmethod def parse_module(module, input_shapes: Optional[ Tuple[List[int],] ] = None, - frame: Frame = Frame()) \ + frame: Frame = None) \ -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: """ The overall entry to parse a torchscript graph module """ + frame = frame if frame is not None else Frame() frame.push() + inputs = list(module.graph.inputs())[1:] + if input_shapes is not None and len(input_shapes) != len(inputs): + raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + # handle graph input -- Assuming all the inputs are tensors - input_var_name = [input.debugName() for input in module.graph.inputs()] kDefaultType = DType2IRDType.map(torch.get_default_dtype()) - for index, var_name in enumerate(input_var_name[1:]): # omit self - frame.add_var(var_name, IRFullTensor(name=var_name, requires_grad=False, dtype=kDefaultType), graph_arg=index) - input_val = [frame.get_var(var_name) for var_name in input_var_name[1:]] - - # handle input shape - if input_shapes: - if len(input_val) != len(input_shapes): - raise RuntimeError( - f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(input_val)})" - ) - for shape, val in zip(input_shapes, input_val): - if isinstance(val, IRFullTensor): - val.shape = shape + for idx, input in enumerate(inputs): + if isinstance(input.type(), torch._C.TensorType): + shape = None if input_shapes is None else input_shapes[idx] + dtype = kDefaultType + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.debugName()) + else: + raise NotImplementedError("Graph inputs only accepts Tensor") + frame.add_var(input.debugName(), val, graph_arg=idx) + input_val = [frame.get_var(input.debugName()) for input in inputs] + # handle nodes all_ir_nodes: List[IRFwOperation] = list() for node in module.graph.nodes(): ir_nodes = ScriptModuleParser.parse_node(node, module, frame) @@ -70,10 +71,12 @@ def parse_module(module, f"====== Shape Infer Error ====\n\n\n" ) all_ir_nodes += ir_nodes - - # handle graph output -- Assuming all the output are tensors + + # handle outputs output_var_name = [output.debugName() for output in module.graph.outputs()] output_val = [frame.get_var(var_name) for var_name in output_var_name] + + # flatten output_val outputs = list() for val in output_val: if isinstance(val, list): @@ -90,7 +93,6 @@ def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): """ Parse module method """ - frame.push() input_var_name = [input.debugName() for input in method.graph.inputs()] @@ -119,16 +121,9 @@ def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): ) all_ir_nodes += ir_nodes - # handle graph output -- Assuming all the output are tensors + # handle graph output output_var_name = [output.debugName() for output in method.graph.outputs()] output_val = [frame.get_var(var_name) for var_name in output_var_name] - outputs = list() - for val in output_val: - if isinstance(val, list): - outputs += val - else: - outputs.append(val) - output_val = outputs frame.pop() return input_val, all_ir_nodes, output_val @@ -165,30 +160,29 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] """ Parse the node and return the IRFwOperation nodes """ + node_type = ScriptModuleParser.ntype(node) try: - node_type = ScriptModuleParser.ntype(node) - except RuntimeError: - print(module.graph) - raise RuntimeError("Unsupported node kind {node.kind()} found in parsing. See above graph.") - if node_type == ScriptNodeKind.PrimCallFunction: - return ScriptModuleParser.parse_prim_function_node(node, module, frame) - if node_type == ScriptNodeKind.AtenOp: - return ScriptModuleParser.parse_aten_node(node, module, frame) - if node_type == ScriptNodeKind.PrimCallMethod: - return ScriptModuleParser.parse_prim_method_node(node, module, frame) - if node_type == ScriptNodeKind.PrimGetAttr: - return ScriptModuleParser.parse_prim_attr_node(node, module, frame) - if node_type == ScriptNodeKind.PrimConstant: - return ScriptModuleParser.parse_prim_constant_node(node, module, frame) - if node_type == ScriptNodeKind.PrimListConstruct: - return ScriptModuleParser.parse_prim_list_construct_node(node, module, frame) - if node_type == ScriptNodeKind.PrimListUnpack: - return ScriptModuleParser.parse_prim_listunpack_node(node, module, frame) - if node_type == ScriptNodeKind.PrimTupleUnpack: - return list() # tuple unpack should only be used in prim function node - if node_type == ScriptNodeKind.PrimPythonOp: - return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) - raise NotImplementedError(f"Un-supported node type {node_type}") + if node_type == ScriptNodeKind.PrimCallFunction: + return ScriptModuleParser.parse_prim_function_node(node, module, frame) + if node_type == ScriptNodeKind.AtenOp: + return ScriptModuleParser.parse_aten_node(node, module, frame) + if node_type == ScriptNodeKind.PrimCallMethod: + return ScriptModuleParser.parse_prim_method_node(node, module, frame) + if node_type == ScriptNodeKind.PrimGetAttr: + return ScriptModuleParser.parse_prim_attr_node(node, module, frame) + if node_type == ScriptNodeKind.PrimConstant: + return ScriptModuleParser.parse_prim_constant_node(node, module, frame) + if node_type == ScriptNodeKind.PrimListConstruct: + return ScriptModuleParser.parse_prim_list_construct_node(node, module, frame) + if node_type == ScriptNodeKind.PrimListUnpack: + return ScriptModuleParser.parse_prim_listunpack_node(node, module, frame) + if node_type == ScriptNodeKind.PrimTupleUnpack: + return ScriptModuleParser.parse_prim_tupleunpack_node(node, module, frame) + if node_type == ScriptNodeKind.PrimPythonOp: + return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) + raise NotImplementedError(f"Un-supported node type {node_type}") + except Exception: + raise RuntimeError(f"\n\nParsing error at node:\n\t{node}\n") @staticmethod def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: @@ -196,54 +190,46 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: parse node like: Tensor = prim::CallFunction(%5, %input.1, %3, %4) %5 : Function = prim::Constant[name="linear"]() + %12 : (Tensor, Tensor) = prim::CallFunction(%5, %x1.1, %x2.1) """ inputs = [input for input in node.inputs()] - outputs = [output for output in node.outputs()] - outputs: List[Union[torch._C.Value, IRFullTensor]] = list() - for output in node.outputs(): - # unpack the output type - if isinstance(output.type(), torch._C.TupleType): - for unpack_node in module.graph.nodes(): - if ScriptModuleParser.ntype(unpack_node) == ScriptNodeKind.PrimTupleUnpack: - if output in unpack_node.inputs(): - ScriptModuleParser.parse_prim_tupleunpack_node(unpack_node, module, frame) - break - tuple_outputs = frame.get_var(output.debugName()) - outputs += tuple_outputs - else: - outputs.append(output) - - # handle function node + # get signature fnode = node.inputsAt(0).node() if not ScriptModuleParser.ntype(fnode) == ScriptNodeKind.PrimConstant: raise RuntimeError(f"Found unexpected function call node: {fnode}") fsig = frame.get_var(inputs[0].debugName()) - # handle inputs + # get inputs input_vals = list() for index, input in enumerate(inputs[1:]): var_name = input.debugName() val = frame.get_var(var_name) input_vals.append(val) - try: - ir_node = Sign2Op.map(fsig)(inputs=input_vals) - except Exception: - # print(module.code) - raise RuntimeError(f"Parsing error of {node}") - if len(ir_node.outputs()) != len(outputs): + # map to IR operator + ir_node = Sign2Op.map(fsig)(inputs=input_vals) + + # push output in the frame + # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) + # : >>> dir(a) + # : >>> a.elements() # [TensorType, TensorType] + cnt = 0 + for output in node.outputs(): + if isinstance(output.type(), torch._C.TupleType): + tuplen = len(output.type().elements()) + ir_output = [ir_node.outputs(idx) for idx in range(cnt, cnt+tuplen)] + cnt += tuplen + else: + ir_output = ir_node.outputs(cnt) + cnt += 1 + frame.add_var(output.debugName(), ir_output) + + if cnt != len(ir_node.outputs()): raise RuntimeError( - f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" + f"Parse fail: {fsig} has {cnt} outputs != pre-defined {len(ir_node.outputs())}" ) - # handle outputs - for index, output in enumerate(outputs): - if isinstance(output, IRFullTensor): - ir_node.set_output(index, output) - else: - frame.add_var(output.debugName(), ir_node.outputs(index)) - return [ir_node] @staticmethod @@ -295,17 +281,12 @@ def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: # forward label = node.s('name') - # if label != 'forward': - # raise RuntimeError(f"{node} is calling function {label} that is not `forward`") - # handle inputs -- in stack with reverse order for input in inputs[1:][::-1]: var_name = input.debugName() val = frame.get_var(var_name) frame.push_param(var_name) - # print(f'> {frame}') - # recursively parse the module if node.inputsAt(0).debugName() == 'self': call_module = module @@ -443,7 +424,7 @@ def parse_prim_listunpack_node(node, module, frame: Frame) -> List[None]: raise NotImplementedError @staticmethod - def parse_prim_tupleunpack_node(node, module, frame) -> List[None]: + def parse_prim_tupleunpack_node(node, module, frame: Frame) -> List[None]: """ Parse script module node like: %q.1 : Tensor, %k.1 : Tensor, %v.1 : Tensor = prim::TupleUnpack(%11) @@ -454,18 +435,11 @@ def parse_prim_tupleunpack_node(node, module, frame) -> List[None]: raise RuntimeError("Find UnpackTuple has more than one input") if len(outputs) == 1: raise RuntimeError("Find UnpackTuple has only one output") - tuple_outs = list() - for output in outputs: - dtype = output.type().str() - var_name = output.debugName() - if dtype == 'Tensor': - kDefaultType = DType2IRDType.map(torch.get_default_dtype()) - ir_tensor = IRFullTensor(name=var_name, dtype=kDefaultType) - tuple_outs.append(ir_tensor) - frame.add_var(var_name, ir_tensor) - else: - raise NotImplementedError - frame.add_var(inputs[0].debugName(), tuple_outs) + tuple_inputs = frame.get_var(inputs[0].debugName()) + if len(tuple_inputs) != len(outputs): + raise RuntimeError("Expected unpacked tuple number have same length of tupled input") + for output, val in zip(outputs, tuple_inputs): + frame.add_var(output.debugName(), val) return list() @staticmethod From b1b986b7b7849f6b302f393e88664e140d04f665 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Mar 2022 14:42:51 +0800 Subject: [PATCH 0636/1892] einop re-structure --- cube/graph/operator/function/einops.py | 132 ++++++++++--------------- 1 file changed, 54 insertions(+), 78 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index f0cbe1b5..794d98c9 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -49,7 +49,7 @@ """ -from typing import Callable, Dict, List, Union +from typing import Any, Dict, List, Union from typing import Optional, Set, Tuple, Optional import enum import re @@ -237,28 +237,24 @@ class IREinops(IRFwOperation): """ Einstein-inspired notation operations """ - def __init__(self, signature: str, annos: List[Union[str, Tuple[str, Callable]]], + def __init__(self, signature: str, annos: Tuple[str], inputs: List, name: str, **kwargs): - noutputs = set() - self._annos: List[EinopAnno] = list() - self._adapt: List[Union[Callable, None]] = list() - for anno in annos: - if isinstance(anno, tuple): - anno, adapt = anno - elif isinstance(anno, str): - adapt = None - else: - raise TypeError("Expected annos to be list of tuples of list of str") - anno = EinopAnno(anno) - self._annos.append(anno) - self._adapt.append(adapt) - noutputs.add(len(anno.outputs)) + self._annos_candidates: List[str] = tuple(annos) self._iannos: List[List[EinDim]] = None self._oannos: List[List[EinDim]] = None - if len(noutputs) != 1: - raise ValueError("Annotations should have same output length") - super().__init__(name, signature, len(inputs), list(noutputs)[0]) + for anno in self._annos_candidates: + anno = EinopAnno(anno) + # expand * and check shape dimension consistency + if self.parse(inputs, anno): + self._iannos = anno.inputs + self._oannos = anno.outputs + break + else: + raise RuntimeError("No matching anno for given annos") + + n_outputs = len(self._oannos) + super().__init__(name, signature, len(inputs), n_outputs) # set input for idx, input in enumerate(inputs): self.set_input(idx, input) @@ -267,105 +263,84 @@ def __init__(self, signature: str, annos: List[Union[str, Tuple[str, Callable]]] def infer_shape(self) -> bool: """ - Shape inference by mathcing dimension annotations. - Assume input shape is given + Shape inference using the matched annotation """ - # try parsing given anno candidates - ret = False - for anno, adapt in zip(self._annos, self._adapt): - if adapt is not None: - anno = adapt(anno, self) - ret, iannos, oannos = self.parse(anno) - self._iannos = iannos - self._oannos = oannos - if ret: break - if not ret: - raise RuntimeError("No matching anno for given annos") dimlen: Dict[str, int] = dict() for input, ishape in zip(self.inputs(), self._iannos): if not isinstance(input, IRTensor): continue - if len(ishape) != len(input.shape): - raise RuntimeError(f"node {self._id} {self.signature}: error match input: {input.shape} and ein_shape: {ishape}") - for tdim, edim in zip(input.shape, ishape): + for tdim, edim in zip(input.shape, ishape): if len(edim.names()) == 1: - if edim.name in dimlen and dimlen[edim.name] != tdim: - raise RuntimeError(f"op: {self.signature} has different shape for same dim annotation {edim.name}") dimlen[edim.name] = tdim - edim.setlen(edim.name, tdim) - else: - toinfer = list() - accum = 1 - for name in edim._name: - if str.isnumeric(name): - accum *= int(name) - edim.setlen(name, int(name)) - dimlen[name] = int(name) - elif name in self.kwargs: - accum *= self.kwargs[name] - edim.setlen(name, self.kwargs[name]) - dimlen[name] = self.kwargs[name] - else: - toinfer.append(name) - if len(toinfer) > 1: - raise RuntimeError(f"Expected indication of dimension {toinfer} from kwargs") - if len(toinfer) == 1: - edim.setlen(toinfer[0], tdim // accum) - dimlen[toinfer[0]] = tdim // accum + continue + # infer hidden dim shape + toinfer = None + accum = 1 + for name in edim.names(): + if str.isnumeric(name): + accum *= int(name) + dimlen[name] = int(name) + elif name in self.kwargs: + accum *= self.kwargs[name] + dimlen[name] = self.kwargs[name] + else: + if toinfer is not None: + raise RuntimeError(f"Too many dimensions need to be inferred") + toinfer = name + if toinfer is not None: + dimlen[toinfer] = tdim // accum # figure output shape for oidx in range(len(self._outputs)): output_shape = list() for odim in self._oannos[oidx]: accum = 1 - for name in odim._name: - if str.isdecimal(name): + for name in odim.names(): + if str.isnumeric(name): accum *= int(name) else: if name not in dimlen: raise KeyError(f"Dim annotation {name} not in input") accum *= dimlen[name] - odim.setlen(name, dimlen[name]) output_shape.append(accum) self.outputs(oidx).shape = output_shape - return ret + return True def new(self, inputs: List, outputs: List): """ construct a new operator sharing same kwargs with new inputs and outputs """ - annos = list() - for anno, adapt in zip(self._annos, self._adapt): - annos.append((anno.anno, adapt)) + annos = self._annos_candidates op = IREinops(self.signature, annos, inputs, self.name, **self.kwargs) for idx, output in enumerate(outputs): op.set_output(idx, output) return op - def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[EinDim]]]: + def parse(self, inputs: List[Any], anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[EinDim]]]: """ parse annotations, assuming input tensor shape is given """ - if len(anno.inputs) != len(self.inputs()): - return False, None, None identifiers = anno.identifiers() + # input shape match + if len(anno.inputs) != len(inputs): + return False + # expand * expand_dims = None if '*' in identifiers: - # names candicates = [c for c in string.ascii_lowercase if c not in identifiers] # go through inputs - for idx, (eshape, input) in enumerate(zip(anno.inputs, self.inputs())): + for idx, (eshape, input) in enumerate(zip(anno.inputs, inputs)): names = [edim.name for edim in eshape] if '*' in names: if not isinstance(input, IRTensor): - return False, None, None + return False pos = names.index('*') split = eshape[pos].reduce[0].value - span = len(self.inputs(idx).shape) - (len(names) - 1) + span = len(inputs[idx].shape) - (len(names) - 1) if expand_dims is not None and len(expand_dims) != span: - return False, None, None + return False if expand_dims is None: expand_dims = [] if span > 0: @@ -373,7 +348,7 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei anno.inputs[idx] = anno.inputs[idx][:pos] + expand_dims + anno.inputs[idx][pos+1:] # * should appear in inputs if expand_dims is None: - return False, None, None + return False # go through outputs for idx, eshape in enumerate(anno.outputs): names = [edim.name for edim in eshape] @@ -381,21 +356,22 @@ def parse(self, anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[Ei pos = names.index('*') anno.outputs[idx] = anno.outputs[idx][:pos] + expand_dims + anno.outputs[idx][pos+1:] anno.reset_identifiers() + # check dimension consistency dimlen: Dict[str, int] = dict() - for eshape, input in zip(anno.inputs, self.inputs()): + for eshape, input in zip(anno.inputs, inputs): if not isinstance(input, IRTensor): if not (len(eshape) == 1 and eshape[0].name == '1'): - return False, None, None + return False else: if len(input.shape) != len(eshape): - return False, None, None + return False for edim, nele in zip(eshape, input.shape): if edim.name in dimlen: if nele != dimlen[edim.name]: - return False, None, None + return False dimlen[edim.name] = nele - return True, anno.inputs, anno.outputs + return True def einexpr(self) -> str: inputs = list() From af4a8982aaf24b1238f845a9a914beb1a40c1087 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Mar 2022 16:06:19 +0800 Subject: [PATCH 0637/1892] enable operator registration --- cube/graph/operator/function/einops.py | 7 +++- cube/graph/operator/function/function.py | 4 +- cube/graph/parser/register.py | 52 +++++++++++------------- 3 files changed, 33 insertions(+), 30 deletions(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 794d98c9..34632b96 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -251,7 +251,12 @@ def __init__(self, signature: str, annos: Tuple[str], self._oannos = anno.outputs break else: - raise RuntimeError("No matching anno for given annos") + raise RuntimeError( + f"no matching anno for given annos." + f"op: {signature}\n" + f"inputs: {inputs}\n" + f"annos: {annos}\n" + ) n_outputs = len(self._oannos) super().__init__(name, signature, len(inputs), n_outputs) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index d928e9c7..42791041 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -3,7 +3,7 @@ import copy from cube.ir.cten import IRTensor -from cube.graph.operator.function.einops import EinDim, EinopAnno, IREinops +from cube.graph.operator.function.einops import EinDim, IREinops from cube.graph.operator.function.conv import IRConv2D from cube.graph.operator.function.conv import IRConv3D from cube.graph.operator.function.pad import IRPad @@ -63,6 +63,8 @@ def Linear(signature, inputs): 'b * k+, n k+ -> b * n', # no bias 'b * k+, n k+, n -> b * n' # have bias ] + if inputs[2] is None: + inputs = inputs[0:2] return IREinops(signature, annos, inputs, 'linear') diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 62b652a9..b011a0e9 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -2,8 +2,7 @@ Register cutomized function """ -from functools import partial -from typing import Callable, List +from typing import Any, Callable, List import inspect import torch @@ -12,47 +11,44 @@ from cube.graph.parser.mapping import Sign2Op -def register(anno: str, stay: List[str] = None): +def register(anno: str): """ Register a function with einop annotations. - This function is cooperated with CustomizeEinop. - User needs to define a python function with type annotations - for each input argument. And user needs to pass dimension annotations - as well as (optional) frozen split dimensions (i.e., the dimensions cannot split). + This function is cooperated with IREinOp. + User needs to define a python function that satisfies + 1). Has type annotations for each input + 2). Tensor inputs goes first then other inputs - For EinDims containing brackets (e.g., (3 h d)), - user should have same argument name in the function definition - to help system infer each dim length, e.g., + For EinDims containing brackets (e.g., (3 h d)) that can not be + inferred by system, user should have same argument name in the + function definition to help system infer each dim length, e.g., @cube.register('a (b c) -> (a b) c') def funcname(x: torch.Tensor, b: int = 4): xxx - """ - if stay is None: - stay = list() - def decorator(fn: Callable): if not callable(fn): raise TypeError("Expected a function") + fsig = fn.__name__ args = inspect.signature(fn) arg_names = list(args.parameters.keys()) arg_kind = [args.parameters[name].annotation for name in arg_names] - func_name = fn.__name__ - kwarg_idx = list() - kwarg_name = list() - for idx, (name, kind) in enumerate(zip(arg_names, arg_kind)): - if kind != torch.Tensor: - kwarg_name.append(name) - kwarg_idx.append(idx) - print(f'registering op {func_name} with {len(args.parameters) - len(kwarg_idx)} inputs and {len(kwarg_idx)} kwargs...') - udfop = partial(IREinops, - name=func_name, - anno=[anno], - kwarg_idx=kwarg_idx, kwarg_name=kwarg_name - ) - Sign2Op.register(func_name, udfop) + kwarg_names = [name for (name, kind) in zip(arg_names, arg_kind) if kind != torch.Tensor] + nkwargs = len(kwarg_names) + ninputs = len(arg_names) - len(kwarg_names) + + def udfop(signature: str, inputs: List[Any]): + tensors = inputs[:ninputs] + kwarg_vals = inputs[ninputs:] + kwargs = dict() + for name, val in zip(kwarg_names, kwarg_vals): + kwargs[name] = val + return IREinops(signature, [anno], tensors, **kwargs, name=fsig) + + print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') + Sign2Op.register(fsig, udfop) return fn return decorator From 133652818a6f8b670650f1a04b09d836a64008bb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Mar 2022 16:06:51 +0800 Subject: [PATCH 0638/1892] registration example --- examples/attention/attention.py | 13 +- examples/attention/policy/data_parallel.py | 119 ------------------- examples/attention/policy/naive.py | 7 ++ examples/attention/policy/no_parallel.py | 23 ---- examples/attention/policy/tensor_parallel.py | 21 ---- 5 files changed, 13 insertions(+), 170 deletions(-) delete mode 100644 examples/attention/policy/data_parallel.py create mode 100644 examples/attention/policy/naive.py delete mode 100644 examples/attention/policy/no_parallel.py delete mode 100644 examples/attention/policy/tensor_parallel.py diff --git a/examples/attention/attention.py b/examples/attention/attention.py index d324dbd6..8c405455 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -10,7 +10,7 @@ --use_env \ examples/attention/attention.py -OMP_NUM_THREADS=4 torchrun --standalone \ +OMP_NUM_THREADS=1 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ examples/attention/attention.py @@ -19,16 +19,15 @@ import torch from torch import nn import torch.nn.functional as F -import cube - - -from examples.attention.policy.tensor_parallel import PAS +import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank +from examples.attention.policy.naive import PAS + -@cube.graph.parser.register('L N E, (3 h d) E -> L N (h d)', stay=['L', 'd', 'E']) +@cube.graph.parser.register('L^ N E^, (3 h d^) E^ -> L^ N (h d^)') def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, h: int, scale: float, dropout: float, training: bool): """ @@ -147,7 +146,7 @@ def train(): batch_dims=(1,) ) - @cube.compile(model, dataloader, policy=PAS) + @cube.compile(model, dataloader, PAS=PAS) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) diff --git a/examples/attention/policy/data_parallel.py b/examples/attention/policy/data_parallel.py deleted file mode 100644 index ab8aaf26..00000000 --- a/examples/attention/policy/data_parallel.py +++ /dev/null @@ -1,119 +0,0 @@ -from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation, IRDataOperation - - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using tensor parallel - """ - ndevs = resource.ngpus - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - assert len(fnodes) == 14 - - toqkv = fnodes[0] - q_t = fnodes[1] - k_t = fnodes[2] - v_t = fnodes[3] - q_scale = fnodes[4] - k_t2 = fnodes[5] - qk_bmm = fnodes[6] - mask = fnodes[7] - softmax = fnodes[8] - dropout = fnodes[9] - attnv_bmm = fnodes[10] - attnview = fnodes[11] - linear = fnodes[12] - loss = fnodes[13] - - all_sub_nodes = list() - - algo = toqkv.algorithms('data') - sub_nodes = graph.partition(toqkv, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = q_t.algorithms('dim') - sub_nodes = graph.partition(q_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = k_t.algorithms('dim') - sub_nodes = graph.partition(k_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = v_t.algorithms('dim') - sub_nodes = graph.partition(v_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = q_scale.algorithms('dim') - sub_nodes = graph.partition(q_scale, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = k_t2.algorithms('dim') - sub_nodes = graph.partition(k_t2, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = qk_bmm.algorithms('data') - sub_nodes = graph.partition(qk_bmm, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = mask.algorithms('head') - sub_nodes = graph.partition(mask, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = softmax.algorithms('dim') - sub_nodes = graph.partition(softmax, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = dropout.algorithms('dim') - sub_nodes = graph.partition(dropout, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = attnv_bmm.algorithms('data') - sub_nodes = graph.partition(attnv_bmm, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = attnview.algorithms('data') - sub_nodes = graph.partition(attnview, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = linear.algorithms('data') - sub_nodes = graph.partition(linear, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = loss.algorithms('dim') - sub_nodes = graph.partition(loss, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - # data loader - dataloaders = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] - for data_op in dataloaders: - algo = data_op.algorithms('data') - sub_nodes = graph.partition(data_op, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - for sub_nodes in all_sub_nodes: - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - print(graph) - # assert False - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - sugraph.partial_set_order(fsus, lazy=False) - return sugraph diff --git a/examples/attention/policy/naive.py b/examples/attention/policy/naive.py new file mode 100644 index 00000000..2d7fe3fe --- /dev/null +++ b/examples/attention/policy/naive.py @@ -0,0 +1,7 @@ +from cube.graph import IRGraph +from cube.graph.operator.operator import IRDataOperation, IRFwOperation + + +def PAS(graph: IRGraph, resource): + print(graph.extra_repr()) + return graph diff --git a/examples/attention/policy/no_parallel.py b/examples/attention/policy/no_parallel.py deleted file mode 100644 index 702c02b7..00000000 --- a/examples/attention/policy/no_parallel.py +++ /dev/null @@ -1,23 +0,0 @@ -from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph - - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using column parallel - """ - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - for su in sugraph.fsus(): - sugraph.assign(su, 0) - sugraph.assign(su.mirror, 0) - return sugraph diff --git a/examples/attention/policy/tensor_parallel.py b/examples/attention/policy/tensor_parallel.py deleted file mode 100644 index 14c3f6fe..00000000 --- a/examples/attention/policy/tensor_parallel.py +++ /dev/null @@ -1,21 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation - - -def PAS(graph: IRGraph, resource): - # data loader - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - graph.assign(node, list(range(resource.ngpus))) - fnodes = [isinstance(node, IRFwOperation) for node in graph.nodes()] - for idx, node in enumerate(fnodes): - if idx == 0: - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=1, num=resource.ngpus) - ) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - print(graph.extra_repr()) - return graph - From b8370bed851833a86426076b79f34c72a5b33e6c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Mar 2022 18:05:16 +0800 Subject: [PATCH 0639/1892] codegen on customized function --- cube/codegen/codegen.py | 5 +++++ cube/graph/parser/mapping.py | 8 ++++++-- cube/graph/parser/register.py | 5 ++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 2b99fff1..ed4fc56c 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -4,6 +4,7 @@ from typing import Dict, List, Any, Tuple import torch import copy +from cube.graph.parser.mapping import Sign2Op from cube.ir.cten import IRCell, IRTensor from cube.ir.dtype import IRDType @@ -62,6 +63,10 @@ def __init__(self, execplan: ExectuionPlan): self.init_code: List[str] = [ '\n\n########## Generated Model Code ###########', 'import torch', 'import cube', '', ''] + # customized op code + for _, op_impl in Sign2Op.kOpCodeDef.items(): + self.init_code.append(op_impl) + self.init_code += ['', ''] # module init code self.declare_region: List[str] = list() # module forward code diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 815d2b0e..cc33ac51 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -2,6 +2,7 @@ Mapping of Signature -> IROperator """ +from typing import Dict import torch from functools import partial @@ -26,7 +27,7 @@ def map(signature: str) -> IRFwOperation: # return partial(function.UnkownOperator, signature=signature) @staticmethod - def register(signature: str, op: IRFwOperation): + def register(signature: str, op: IRFwOperation, code): """ Register an operator """ @@ -35,6 +36,7 @@ def register(signature: str, op: IRFwOperation): if signature in Sign2Op.kOpMap: raise KeyError(f"function {signature} is already registered") Sign2Op.kOpMap[signature] = op + Sign2Op.kOpCodeDef[signature] = code # functional templates __ftemplate = lambda name: f'torch.nn.functional.{name}' @@ -94,9 +96,11 @@ def register(signature: str, op: IRFwOperation): #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, - } + # customized operator code: signature -> code + kOpCodeDef: Dict[str, str] = {} + class DType2IRDType: diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index b011a0e9..df38fcf2 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -38,6 +38,9 @@ def decorator(fn: Callable): kwarg_names = [name for (name, kind) in zip(arg_names, arg_kind) if kind != torch.Tensor] nkwargs = len(kwarg_names) ninputs = len(arg_names) - len(kwarg_names) + # get customized op code + code = inspect.getsource(fn) + code = code[code.index('def'):] def udfop(signature: str, inputs: List[Any]): tensors = inputs[:ninputs] @@ -48,7 +51,7 @@ def udfop(signature: str, inputs: List[Any]): return IREinops(signature, [anno], tensors, **kwargs, name=fsig) print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') - Sign2Op.register(fsig, udfop) + Sign2Op.register(fsig, udfop, code) return fn return decorator From 60a7116d6b9f1aa07298ef7dc1fc7c9d57bbf1f7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Mar 2022 18:05:46 +0800 Subject: [PATCH 0640/1892] use torch interface --- examples/attention/attention.py | 9 ++++----- examples/attention/policy/naive.py | 6 +++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/attention/attention.py b/examples/attention/attention.py index 8c405455..add204f7 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -18,7 +18,6 @@ import torch from torch import nn -import torch.nn.functional as F import cube from cube.profiler import CudaTimer @@ -43,7 +42,7 @@ def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, h: int, L, N = x.shape[0], x.shape[1] dim_head = wqkv.shape[0] // 3 // num_head # L N E, (3 h d) E -> L N (3 h d) - qkv = F.linear(x, wqkv, None) + qkv = torch.nn.functional.linear(x, wqkv, None) # L N (3 h d) -> L N (h d), L N (h d), L N (h d) q, k, v = qkv.chunk(3, dim=-1) # L N (h d) -> L (N h) d @@ -76,10 +75,10 @@ def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, h: int, attn = attn.view((N * num_head), L, L) # (N h) L L -> (N h) L L - attn = F.softmax(attn, dim=-1) + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L if training: - attn = F.dropout(attn, dropout, True, False) + attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L, (N h) L d -> (N h) L d output = torch.bmm(attn, v) # (N h) L d -> L (N h) d @@ -117,7 +116,7 @@ def forward(self, x): output = attnfc1(x, self.wqkv, self.num_head, self.scale, self.dropout, self.training) # L N (h d), E (h d) -> L N E - output = F.linear(output, self.wout) + output = torch.nn.functional.linear(output, self.wout) loss = torch.sum(output) return loss diff --git a/examples/attention/policy/naive.py b/examples/attention/policy/naive.py index 2d7fe3fe..caf38e07 100644 --- a/examples/attention/policy/naive.py +++ b/examples/attention/policy/naive.py @@ -1,7 +1,11 @@ from cube.graph import IRGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation def PAS(graph: IRGraph, resource): print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) return graph From a6410d4bce132297c0991ab087f28a310b279292 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Mar 2022 19:36:44 +0800 Subject: [PATCH 0641/1892] add layer_norm --- cube/graph/operator/function/einops.py | 2 ++ cube/graph/operator/function/function.py | 16 +++++++++++++++- cube/graph/parser/mapping.py | 2 ++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/cube/graph/operator/function/einops.py b/cube/graph/operator/function/einops.py index 34632b96..8ffa3132 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/operator/function/einops.py @@ -365,6 +365,8 @@ def parse(self, inputs: List[Any], anno: EinopAnno) -> Tuple[bool, List[List[Ein # check dimension consistency dimlen: Dict[str, int] = dict() for eshape, input in zip(anno.inputs, inputs): + if input is None: + continue if not isinstance(input, IRTensor): if not (len(eshape) == 1 and eshape[0].name == '1'): return False diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 42791041..4190c731 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -2,6 +2,8 @@ import string import copy +from numpy import isin + from cube.ir.cten import IRTensor from cube.graph.operator.function.einops import EinDim, IREinops from cube.graph.operator.function.conv import IRConv2D @@ -244,7 +246,19 @@ def Dropout(signature, inputs): tensor = inputs[0:1] p, training, inplace = inputs[1], inputs[2], inputs[3] return IREinops(signature, annos, tensor, 'dropout', - p=p, traning=training, inplace=inplace) + p=p, training=training, inplace=inplace) + + +def LayerNorm(signature, inputs): + input, normalized_shape, weight, bias, eps = inputs + if len(normalized_shape) != 1: + raise NotImplementedError("Only support normalized_shape to be int") + annos = [ + f'N *, 1, {normalized_shape[0]}, {normalized_shape[0]} -> N *', + f'N *, 1, 1, 1 -> N *' + ] + return IREinops(signature, annos, [input, normalized_shape, weight, bias], + 'layernorm', eps=eps) def Sum(signature, inputs): diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index cc33ac51..2bb764ba 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -61,6 +61,8 @@ def register(signature: str, op: IRFwOperation, code): __ftemplate('_pad'): function.Pad, + __ftemplate('layer_norm'): function.LayerNorm, + # __ftemplate('layer_norm'): function.LayerNorm, # torch aten From 7236ece0305b6a484b82288eed04b43e8c1c8449 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Mar 2022 19:37:58 +0800 Subject: [PATCH 0642/1892] transformer model (without residual) --- examples/attention/attention.py | 6 +- .../transformer/policy/megatron_parallel.py | 73 ------- examples/transformer/policy/naive.py | 11 + examples/transformer/policy/no_parallel.py | 23 --- .../transformer/policy/pipeline_parallel.py | 109 ---------- .../transformer/policy/tensor_parallel.py | 161 --------------- examples/transformer/transformer.py | 192 ------------------ examples/transformer/transformers.py | 163 ++++++++++----- 8 files changed, 124 insertions(+), 614 deletions(-) delete mode 100644 examples/transformer/policy/megatron_parallel.py create mode 100644 examples/transformer/policy/naive.py delete mode 100644 examples/transformer/policy/no_parallel.py delete mode 100644 examples/transformer/policy/pipeline_parallel.py delete mode 100644 examples/transformer/policy/tensor_parallel.py delete mode 100644 examples/transformer/transformer.py diff --git a/examples/attention/attention.py b/examples/attention/attention.py index add204f7..c7873501 100644 --- a/examples/attention/attention.py +++ b/examples/attention/attention.py @@ -21,6 +21,7 @@ import cube from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary from cube.profiler.timer import print_each_rank from examples.attention.policy.naive import PAS @@ -154,12 +155,12 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - CudaTimer().warmup() + CudaTimer(enable=False).warmup() torch.distributed.barrier() iter_num = 128 for step in range(iter_num): if step >= 40: - CudaTimer().start('e2e') + CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() @@ -170,6 +171,7 @@ def train_iter(model, dataloader): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-40, field_name='e2e'))) + memory_summary() if __name__ == '__main__': diff --git a/examples/transformer/policy/megatron_parallel.py b/examples/transformer/policy/megatron_parallel.py deleted file mode 100644 index 15a00cb6..00000000 --- a/examples/transformer/policy/megatron_parallel.py +++ /dev/null @@ -1,73 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.operator.function import CubeComplexFeedForward, CubeComplexSelfAttention -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation - - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using tensor parallel - """ - print('> transforming graph...') - ndevs = resource.ngpus - dp = 1 - tp = ndevs // dp - - # dataloader - - dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] - for dnode in dnodes: - algo = dnode.algorithms('data') - dp_nodes = graph.partition(dnode, algo, config=dict(chunk_num=dp)) - for idx, dp_node in enumerate(dp_nodes): - dp_node.tag = idx * tp - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - for fnode in fnodes: - sub_nodes = list() - if isinstance(fnode, CubeComplexSelfAttention): - algo = fnode.algorithms('data') - dp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=dp)) - for dp_node in dp_nodes: - algo = dp_node.algorithms('head') - tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) - sub_nodes += tp_nodes - elif isinstance(fnode, CubeComplexFeedForward): - algo = fnode.algorithms('data') - dp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=dp)) - for dp_node in dp_nodes: - algo = dp_node.algorithms('tensor') - tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) - sub_nodes += tp_nodes - else: - # note replicate should put in the last due to bugs: - algo = fnode.algorithms('dim') - dp_nodes = graph.partition(fnode, algo, config=dict(dim=1, chunk_num=dp)) - for dp_node in dp_nodes: - rep_nodes = graph.replicate(dp_node, times=tp) - sub_nodes += rep_nodes - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - # print(graph) - # assert False - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - sugraph.partial_set_order(fsus, lazy=False) - return sugraph diff --git a/examples/transformer/policy/naive.py b/examples/transformer/policy/naive.py new file mode 100644 index 00000000..eb3a3516 --- /dev/null +++ b/examples/transformer/policy/naive.py @@ -0,0 +1,11 @@ +from cube.graph import IRGraph +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation + + +def PAS(graph: IRGraph, resource): + print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + return graph \ No newline at end of file diff --git a/examples/transformer/policy/no_parallel.py b/examples/transformer/policy/no_parallel.py deleted file mode 100644 index 702c02b7..00000000 --- a/examples/transformer/policy/no_parallel.py +++ /dev/null @@ -1,23 +0,0 @@ -from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph - - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using column parallel - """ - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - for su in sugraph.fsus(): - sugraph.assign(su, 0) - sugraph.assign(su.mirror, 0) - return sugraph diff --git a/examples/transformer/policy/pipeline_parallel.py b/examples/transformer/policy/pipeline_parallel.py deleted file mode 100644 index 366fb23b..00000000 --- a/examples/transformer/policy/pipeline_parallel.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import List - -from cube.schedule.su import SUType, ScheduleUnit -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation - -_batch_size = 8 -_micro_batch_size = 1 - -def transform_policy(graph, resource): - """ - The transformation policy transposes linear using data parallel - """ - print('> transforming graph...') - micro_batch_num = _batch_size // _micro_batch_size - - for node in graph.nodes(): - if isinstance(node, IRDataOperation) or isinstance(node, IRFwOperation): - algo = node.algorithms('data') - if algo is not None: - sub_nodes = graph.partition(node, algo, config=dict(chunk_num=micro_batch_num)) - else: - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, config=dict(dim=1, chunk_num=micro_batch_num)) - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - print('> [Done] transforming graph...') - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy - """ - print('> scheduling su graph...') - num_micro_batch = _batch_size // _micro_batch_size - # each device is a stage - num_stage = resource.ngpus - - fseqs: List[List[ScheduleUnit]] = [list() for _ in range(num_micro_batch)] - fbseqs: List[List[ScheduleUnit]] = [list() for _ in range(num_micro_batch)] - - for fsu in sugraph.fsus(): - micro_bs_id = fsu.tag[0] - fseqs[micro_bs_id].append(fsu) - - for micro_bs_id, fseq in enumerate(fbseqs): - bseq = [fsu.mirror for fsu in fseq][::-1] - fbseqs[micro_bs_id] = fseq + bseq - - print(f'> collect {len(fseqs)} forward-backward sequence') - - # fstages[micro_batch_id][stage] = fstages[micro_batch_id * num_stage + stage] - fstages: List[List[ScheduleUnit]] = [ - list() for _ in range(num_micro_batch * num_stage) - ] - - def f(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: - return fstages[micro_batch_id * num_stage + stage_id] - - def b(micro_batch_id: int, stage_id: int) -> List[ScheduleUnit]: - fstage = f(micro_batch_id, stage_id) - bstage = [fsu.mirror for fsu in fstage][::-1] - return bstage - - # assign su to stages - for micro_bid, fseq in enumerate(fseqs): - chunk_num = int(len(fseq) // resource.ngpus) - for idx, fsu in enumerate(fseq): - stage = min(int(idx // chunk_num), num_stage - 1) - fstages[micro_bid * num_stage + stage].append(fsu) - - # stage device assignment - for micro_bid in range(num_micro_batch): - for stage in range(num_stage): - for su in f(micro_bid, stage): - sugraph.assign(su, stage) - sugraph.assign(su.mirror, stage) - - # device assignment - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - - # 1f1b scheduling - seqs = list() - - # warmup - for stage in range(num_stage): - for mid in range(stage): - seqs += f(mid, stage) - - # steady + cooldown: - for mid in range(num_micro_batch): - # enqueue backward - for stage in range(num_stage-1, -1, -1): - seqs += b(mid, stage) - # enqueue forward - for stage in range(num_stage): - f_mid = mid + 1 + num_stage - stage - if f_mid >= num_micro_batch: - continue - seqs += f(f_mid, stage) - - sugraph.partial_set_order(seqs) - - print('> [Done] scheduling su graph') - # print(sugraph) - return sugraph diff --git a/examples/transformer/policy/tensor_parallel.py b/examples/transformer/policy/tensor_parallel.py deleted file mode 100644 index 0cedcb4c..00000000 --- a/examples/transformer/policy/tensor_parallel.py +++ /dev/null @@ -1,161 +0,0 @@ -from cube.graph import IRGraph -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRFwOperation - - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using tensor parallel - """ - print('> transforming graph...') - ndevs = resource.ngpus - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - assert len(fnodes) == 23 - - attn_ln = fnodes[0] - - toqkv = fnodes[1] - q_t = fnodes[2] - k_t = fnodes[3] - v_t = fnodes[4] - q_scale = fnodes[5] - k_t2 = fnodes[6] - qk_bmm = fnodes[7] - mask = fnodes[8] - softmax = fnodes[9] - attn_dropout = fnodes[10] - attnv_bmm = fnodes[11] - attnview = fnodes[12] - linear = fnodes[13] - - attn_post_dropout = fnodes[14] - attn_residual = fnodes[15] - - ffn_ln = fnodes[16] - ffn_linear1 = fnodes[17] - ffn_gelu = fnodes[18] - ffn_linear2 = fnodes[19] - - ffn_post_dropout = fnodes[20] - ffn_post_residual = fnodes[21] - - loss = fnodes[22] - - - all_sub_nodes = list() - - # ============== attention ============ - sub_nodes = graph.replicate(attn_ln, times=resource.ngpus) - all_sub_nodes.append(sub_nodes) - - algo = toqkv.algorithms('head') - sub_nodes = graph.partition(toqkv, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = q_t.algorithms('dim') - sub_nodes = graph.partition(q_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = k_t.algorithms('dim') - sub_nodes = graph.partition(k_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = v_t.algorithms('dim') - sub_nodes = graph.partition(v_t, algo, config=dict(dim=1, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = q_scale.algorithms('dim') - sub_nodes = graph.partition(q_scale, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = k_t2.algorithms('dim') - sub_nodes = graph.partition(k_t2, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = qk_bmm.algorithms('data') - sub_nodes = graph.partition(qk_bmm, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = mask.algorithms('head') - sub_nodes = graph.partition(mask, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = softmax.algorithms('dim') - sub_nodes = graph.partition(softmax, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = attn_dropout.algorithms('dim') - sub_nodes = graph.partition(attn_dropout, algo, config=dict(dim=0, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = attnv_bmm.algorithms('data') - sub_nodes = graph.partition(attnv_bmm, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = attnview.algorithms('head') - sub_nodes = graph.partition(attnview, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = linear.algorithms('row') - sub_nodes = graph.partition(linear, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - # ========== between attention and mlp =============== - sub_nodes = graph.replicate(attn_post_dropout, times=resource.ngpus) - all_sub_nodes.append(sub_nodes) - - sub_nodes = graph.replicate(attn_residual, times=resource.ngpus) - all_sub_nodes.append(sub_nodes) - - sub_nodes = graph.replicate(ffn_ln, times=resource.ngpus) - all_sub_nodes.append(sub_nodes) - - # =========== mlp =========== - algo = ffn_linear1.algorithms('column') - sub_nodes = graph.partition(ffn_linear1, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = ffn_gelu.algorithms('dim') - sub_nodes = graph.partition(ffn_gelu, algo, config=dict(dim=2, chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - algo = ffn_linear2.algorithms('row') - sub_nodes = graph.partition(ffn_linear2, algo, config=dict(chunk_num=ndevs)) - all_sub_nodes.append(sub_nodes) - - # ========== post mlp ======== - sub_nodes = graph.replicate(ffn_post_dropout, times=resource.ngpus) - all_sub_nodes.append(sub_nodes) - - sub_nodes = graph.replicate(ffn_post_residual, times=resource.ngpus) - all_sub_nodes.append(sub_nodes) - - # =========== loss =========== - sub_nodes = graph.replicate(loss, times=ndevs) - all_sub_nodes.append(sub_nodes) - - for sub_nodes in all_sub_nodes: - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - print(graph) - # assert False - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - sugraph.assign(su, 0) - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - sugraph.partial_set_order(fsus, lazy=False) - return sugraph diff --git a/examples/transformer/transformer.py b/examples/transformer/transformer.py deleted file mode 100644 index 78759d9a..00000000 --- a/examples/transformer/transformer.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/transformer/transformer.py -""" - -import torch -from torch import nn -import torch.nn.functional as F -import cube - - -from examples.transformer.policy.tensor_parallel import transform_policy -from examples.transformer.policy.tensor_parallel import schedule_policy - -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - - -class MultiHeadSelfAttention(nn.Module): - - def __init__(self, seq_len, embed_dim, heads, dropout): - super().__init__() - - self.seq_len = seq_len - self.embed_dim = embed_dim - self.num_head = heads - self.dim_head = embed_dim // heads - self.scale = self.dim_head ** -0.5 - - self.weight_qkv = torch.nn.Parameter(torch.empty( - 3 * embed_dim, embed_dim - )) - self.weight_out = torch.nn.Parameter(torch.empty( - embed_dim, embed_dim - )) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - """ - x: [L, N, E]: seq_len, batch_size, embedding dimension - output: [L, N, E] - """ - # [L, N, E] -> 3 x [L, (N * num_head), dim_head] - q, k, v = cube.runtime.function.toqkv( - x, self.weight_qkv, self.num_head - ) - - # [L, (N * num_head), dim_head] -> [(N * num_head), L, dim_head] - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - # [(N * num_head), L, dim_head] -> [(N * num_head), L, dim_head] - q = q * self.scale - # [(N * num_head), L, dim_head] -> [(N * num_head), dim_head, L] - k = k.transpose(-2, -1) - # [(N * num_head), L, dim_head] * [(N * num_head), dim_head, L] - # -> [(N * num_head), L, L] - attn = torch.bmm(q, k) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = cube.runtime.function.tril_mask(attn, self.num_head) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = F.softmax(attn, dim=-1) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = self.dropout(attn) - # [(N * num_head), L, L] * [(N * num_head), L, dim_head] - # -> [(N * num_head), L, dim_head] - output = torch.bmm(attn, v) - - # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] - output = cube.runtime.function.attn_view(output, self.num_head) - - # [L, N, num_head * dim_head] * [E, embed_head * dim_head] - # -> [L, N, E] - output = F.linear(output, self.weight_out) - return output - - -class FFN(torch.nn.Module): - - def __init__(self, hidden_size: int): - super().__init__() - self.dense_h_to_4h = torch.nn.Linear( - hidden_size, 4 * hidden_size - ) - self.dense_4h_to_h = torch.nn.Linear( - 4 * hidden_size, hidden_size - ) - - def forward(self, hidden_states): - # [L, N, E] * [E, 4E] -> [L, N, 4E] - out = self.dense_h_to_4h(hidden_states) - # [L, N, 4E] -> [L, N, 4E] - out = F.gelu(out) - # [L, N, 4E] * [4E, E] -> [L, N, E] - out = self.dense_4h_to_h(out) - return out - - -class TransformerLayer(torch.nn.Module): - - def __init__(self, seq_len, hidden_size, head_num, dropout): - super().__init__() - # layer norm - self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - - self.attention = MultiHeadSelfAttention(seq_len, hidden_size, head_num, dropout) - self.attn_dropout = torch.nn.Dropout(dropout) - - self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - self.ffn = FFN(hidden_size) - self.ffn_dropout = torch.nn.Dropout(dropout) - - def forward(self, hidden_states): - # Attention - in_attn_norm = self.input_layernorm(hidden_states) - attn_out = self.attention(in_attn_norm) - # residual - attn_out = self.attn_dropout(attn_out) - residual = attn_out + hidden_states - # ffn - in_ffn_norm = self.ffn_layernorm(residual) - ffn_out = self.ffn(in_ffn_norm) - # residual - ffn_out = self.ffn_dropout(ffn_out) - ffn_out = ffn_out + residual - - loss = torch.sum(ffn_out) - return loss - - -def train(): - L = 512 # seq len - N = 32 # batch size - # configs: [hidden size, num_head] - # E, num_head = [1536, 16] # 1.2B model - # E, num_head = [1920, 20] # 2.5B model - # E, num_head = [2304, 24] # 4.2B model - E, num_head = [3072, 32] # 8.7B model - - - model = TransformerLayer( - seq_len=L, hidden_size=E, head_num=num_head, dropout=0.5 - ) - model = cube.SemanticModel( - model, input_shapes=([L, N, E],), - ) - - dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) - - @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - model = model.get_gen_module() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - - -if __name__ == '__main__': - - cube.init() - train() diff --git a/examples/transformer/transformers.py b/examples/transformer/transformers.py index 65f04721..bfa82c5c 100644 --- a/examples/transformer/transformers.py +++ b/examples/transformer/transformers.py @@ -1,6 +1,11 @@ """ example: +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/transformer/transformers.py + python -m torch.distributed.launch \ --nproc_per_node=4 \ --nnodes=1 \ @@ -15,78 +20,124 @@ from torch import nn import cube -from examples.transformer.policy.megatron_parallel import transform_policy -from examples.transformer.policy.megatron_parallel import schedule_policy +from examples.transformer.policy.naive import PAS from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank from cube.profiler.memory import memory_summary +@cube.graph.parser.register('L^ N E^, (3 h d^) E^ -> L^ N (h d^)') +def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, h: int, + scale: float, dropout: float, training: bool): + """ + L: sequence length + N: batch size + E: embedding size + x: hidden state: [L, N, E] + wqkv: qkv weight: [3 * (num_head * dim_head), E] + dropout: float + h: int: number of heads + """ + num_head = h + L, N = x.shape[0], x.shape[1] + dim_head = wqkv.shape[0] // 3 // num_head + # L N E, (3 h d) E -> L N (3 h d) + qkv = torch.nn.functional.linear(x, wqkv, None) + # L N (3 h d) -> L N (h d), L N (h d), L N (h d) + q, k, v = qkv.chunk(3, dim=-1) + # L N (h d) -> L (N h) d + q = q.contiguous().view(L, (N * num_head), dim_head) + # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) + # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) + # L (N h) d -> (N h) L d + q = q.transpose(0, 1) + # L (N h) d -> (N h) L d + k = k.transpose(0, 1) + # L (N h) d -> (N h) L d + v = v.transpose(0, 1) + # (N h) L d, 1 -> (N h) L d + q = q * scale + # (N h) L d -> (N h) d L + k = k.transpose(-2, -1) + # (N h) L d, (N h) d L -> (N h) L L + attn = torch.bmm(q, k) + + # attention mask + # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + # (N h) L L -> (N h) L L + attn = torch.nn.functional.softmax(attn, dim=-1) + # (N h) L L -> (N h) L L + if training: + attn = torch.nn.functional.dropout(attn, dropout, True, False) + # (N h) L L, (N h) L d -> (N h) L d + output = torch.bmm(attn, v) + # (N h) L d -> L (N h) d + output = output.transpose(0, 1).contiguous() + # L (N h) d -> L N (h d) + output = output.view(L, N, num_head * dim_head) + return output + + class MultiHeadSelfAttention(nn.Module): - def __init__(self, embed_dim, heads, dropout): + def __init__(self, embed_dim, heads, dropout: float): super().__init__() - + self.embed_dim = embed_dim self.num_head = heads self.dim_head = embed_dim // heads - self.dropout = dropout + self.scale = self.dim_head ** -0.5 - self.weight_qkv = torch.nn.Parameter(torch.empty( + self.wqkv = torch.nn.Parameter(torch.empty( 3 * embed_dim, embed_dim )) - self.weight_out = torch.nn.Parameter(torch.empty( + self.wout = torch.nn.Parameter(torch.empty( embed_dim, embed_dim )) + self.dropout = dropout def forward(self, x): """ - Multi-Head Self-Attention. - - L: sequence length - N: batch size - E: embedding size - - Inputs: - hidden_state: [L, N, E] - w_qkv : [3 * num_head * dim_head, E] - w_out : [E, E] - - Outputs: - hidden_state: [L, N, E] + x: [L, N, E]: seq_len, batch_size, embedding dimension + output: [L, N, E] """ - - hidden_state = cube.runtime.function.complex.self_attn( - x, self.weight_qkv, self.weight_out, - self.num_head, self.dim_head, self.dropout - ) - return hidden_state + # L N E, (3 h d) E -> L N (h d) + output = attnfc1(x, self.wqkv, self.num_head, + self.scale, self.dropout, self.training) + # L N (h d), E (h d) -> L N E + output = torch.nn.functional.linear(output, self.wout) + return output class FFN(torch.nn.Module): def __init__(self, hidden_size: int): super().__init__() - self.proj1_weight = torch.nn.Parameter( - torch.empty(4 * hidden_size, hidden_size) - ) - self.proj1_bias = torch.nn.Parameter( - torch.empty(4 * hidden_size) + self.dense_h_to_4h = torch.nn.Linear( + hidden_size, 4 * hidden_size ) - self.proj2_weight = torch.nn.Parameter( - torch.empty(hidden_size, 4 * hidden_size) - ) - self.proj2_bias = torch.nn.Parameter( - torch.empty(hidden_size) + self.dense_4h_to_h = torch.nn.Linear( + 4 * hidden_size, hidden_size ) def forward(self, hidden_states): - hidden_states = cube.runtime.function.complex.feedforward( - hidden_states, - self.proj1_weight, self.proj1_bias, - self.proj2_weight, self.proj2_bias - ) - return hidden_states + # [L, N, E] * [E, 4E] -> [L, N, 4E] + out = self.dense_h_to_4h(hidden_states) + # [L, N, 4E] -> [L, N, 4E] + out = torch.nn.functional.gelu(out) + # [L, N, 4E] * [4E, E] -> [L, N, E] + out = self.dense_4h_to_h(out) + return out class TransformerLayer(torch.nn.Module): @@ -109,13 +160,15 @@ def forward(self, hidden_states): attn_out = self.attention(in_attn_norm) # residual attn_out = self.attn_dropout(attn_out) - residual = attn_out + hidden_states + # TODO: enable residual + residual = attn_out # + hidden_states # ffn in_ffn_norm = self.ffn_layernorm(residual) ffn_out = self.ffn(in_ffn_norm) # residual ffn_out = self.ffn_dropout(ffn_out) - ffn_out = ffn_out + residual + # TODO: enable residual + ffn_out = ffn_out # + residual return ffn_out @@ -143,9 +196,10 @@ def train(): L = 512 # seq len N = 8 # batch size # configs: [hidden size, num_head] - # E, num_head = [2304, 24, 24] # 1.7B model - E, num_head, layers = [3072, 32, 30] # 3.6B model - # E, num_head, layers = [4096, 32, 36] # 7.5B model + # E, num_head = [1536, 16] # 1.2B model + # E, num_head = [1920, 20] # 2.5B model + # E, num_head = [2304, 24] # 4.2B model + E, num_head = [3072, 32] # 8.7B model model = Transformers( @@ -155,9 +209,13 @@ def train(): model, input_shapes=([L, N, E],), ) - dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([L, N, E],), + dtypes=(torch.float32,), + batch_dims=(1,) + ) - @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) + @cube.compile(model, dataloader, PAS=PAS) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) @@ -180,11 +238,8 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) memory_summary() if __name__ == '__main__': From 1b66ba3862fffa07fabad7ed98eb941941198c9d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Mar 2022 10:27:50 +0800 Subject: [PATCH 0643/1892] device group creates hybrid groups --- cube/runtime/device.py | 30 +++++- cube/runtime/function/__init__.py | 1 - cube/runtime/function/complex.py | 168 ------------------------------ 3 files changed, 29 insertions(+), 170 deletions(-) delete mode 100644 cube/runtime/function/complex.py diff --git a/cube/runtime/device.py b/cube/runtime/device.py index 789b2605..0e667ed8 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -1,7 +1,8 @@ """ Communication group settings among devices """ - +from typing import List +import numpy as np import torch import os @@ -50,6 +51,33 @@ def get_group(self, ranks): self.groups[rank_bits] = torch.distributed.new_group(list(ranks)) return self.groups[rank_bits] + def create_hybrid(self, group_num: List[int]) -> List[List[int]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + if cnt != self.world_size: + raise RuntimeError("product of group_num should be same with total device number") + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = np.prod(np.delete(group_num, dim)) + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + for ranks in grid_dim: + # initialize group + _ = self.get_group(ranks) + if self.rank in ranks: + outputs.append(ranks) + assert len(outputs) == len(group_num) + return outputs + + @staticmethod def bitmap(ranks): """ diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py index e86bbc76..bcd6790e 100644 --- a/cube/runtime/function/__init__.py +++ b/cube/runtime/function/__init__.py @@ -1,3 +1,2 @@ -from cube.runtime.function.complex import * from cube.runtime.function.dist import * from cube.runtime.function.function import * \ No newline at end of file diff --git a/cube/runtime/function/complex.py b/cube/runtime/function/complex.py deleted file mode 100644 index 84111c69..00000000 --- a/cube/runtime/function/complex.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import torch.nn.functional as F - - -def toqkv(hidden_state: torch.Tensor, weight: torch.Tensor, - num_head: int): - """ - Inputs: - hidden_state: [L, N, E] - weight: [3 * (num_head * dim_head), E] - num_head: int - - where L = sequence length, N = batch size, E = num_head * dim_head - - Returns: - Q: [L, N * num_head, dim_head] - K: [L, N * num_head, dim_head] - V: [L, N * num_head, dim_head] - """ - seqlen = hidden_state.shape[0] - bs = hidden_state.shape[1] - dim_head = weight.shape[0] // 3 // num_head - qkv = F.linear(hidden_state, weight, None) - qkv = qkv.chunk(3, dim=-1) - q, k, v = qkv - q = q.contiguous() - q = q.view(seqlen, (bs * num_head), dim_head) - k = k.contiguous() - k = k.view(seqlen, (bs * num_head), dim_head) - v = v.contiguous() - v = v.view(seqlen, (bs * num_head), dim_head) - return q, k, v - - -def tril_mask(input: torch.Tensor, num_head: int): - """ - Inputs: - input: [N * num_head, L, L] - num_head: int - - Returns: - output: [N * num_head, L, L] - """ - bs: int = input.shape[0] // num_head - seqlen: int = input.shape[2] - input = input.view(bs, num_head, seqlen, seqlen) - # set up mask - ones = torch.ones( - (bs, seqlen, seqlen), - device=input.device, - ) - mask = torch.tril(ones) - mask = mask.view(bs, 1, seqlen, seqlen) - mask = (mask < 0.5) - # mask - masked_input = input.masked_fill_(mask, -10000.0) - masked_input = masked_input.view((bs * num_head), seqlen, seqlen) - return masked_input - - -def attn_view(input: torch.Tensor, num_head: int): - """ - Inputs: - [N * num_head, L, dim_head] - - Outputs: - [L, N, num_head * dim_head] - """ - bs: int = input.shape[0] // num_head - seqlen: int = input.shape[1] - dim_head = input.shape[2] - # [(N * num_head), L, dim_head] -> [L, (N * num_head), dim_head] - input = input.transpose(0, 1).contiguous() - # [L, (N * num_head), dim_head] -> [L, N, (num_head * dim_head)] - input = input.view(seqlen, bs, num_head * dim_head) - return input - - -def self_attn(hidden_state, w_qkv, w_out, - num_head: int, dim_head: int, - dropout_p: float): - """ - Multi-Head Self-Attention. - - L: sequence length - N: batch size - E: embedding size - - Inputs: - hidden_state: [L, N, E] - w_qkv : [3 * num_head * dim_head, E] - w_out : [E, E] - - Outputs: - hidden_state: [L, N, E] - """ - scale = dim_head ** -0.5 - seqlen = hidden_state.shape[0] - bs = hidden_state.shape[1] - - qkv = F.linear(hidden_state, w_qkv, None) - qkv = qkv.chunk(3, dim=-1) - q, k, v = qkv - q = q.contiguous() - q = q.view(seqlen, (bs * num_head), dim_head) - k = k.contiguous() - k = k.view(seqlen, (bs * num_head), dim_head) - v = v.contiguous() - v = v.view(seqlen, (bs * num_head), dim_head) - - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - q = q * scale - k = k.transpose(-2, -1) - attn = torch.bmm(q, k) - - attn = tril_mask(attn, num_head) - attn = F.softmax(attn, dim=-1) - attn = F.dropout(attn, dropout_p, True, False) - output = torch.bmm(attn, v) - output = attn_view(output, num_head) - - output = F.linear(output, w_out, None) - return output - - -def feedforward(hidden_state, w_proj1, w_bias1, w_proj2, w_bias2): - """ - FeedForward - - Inputs: - hidden_state: [L, N, E] - w_proj1: [4 * E, E] - w_bias1: [4 * E,] - w_porj2: [E, 4 * E] - w_bias2: [E,] - """ - out = F.linear(hidden_state, w_proj1, w_bias1) - out = F.gelu(out) - out = F.linear(out, w_proj2, w_bias2) - return out - - -def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): - """ - Embedding - - Inputs: - input: torch.Tensor [*] - weight: [vocab size, embed size] - start: int - stop: int - - Outputs: - output: [*, embed_size] - """ - input = input.long() - input_mask = (input < start) | (input >= stop) - masked_input = input.clone() - start - masked_input[input_mask] = 0 - output = F.embedding( - masked_input, weight, - None, None, 2.0, False, False - ) - output[input_mask, :] = 0.0 - return output From a97839abe04072128f41ff2490b3e9208cc741bc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Mar 2022 12:20:45 +0800 Subject: [PATCH 0644/1892] add hybrid benchmark --- handcraft/pipeline/dummy_hybrid.py | 283 +++++++++++++++++++++++++++++ handcraft/pipeline/schedule.py | 29 ++- 2 files changed, 305 insertions(+), 7 deletions(-) create mode 100644 handcraft/pipeline/dummy_hybrid.py diff --git a/handcraft/pipeline/dummy_hybrid.py b/handcraft/pipeline/dummy_hybrid.py new file mode 100644 index 00000000..e7ff9282 --- /dev/null +++ b/handcraft/pipeline/dummy_hybrid.py @@ -0,0 +1,283 @@ +""" +Dummy model + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + handcraft/pipeline/dummy_hybrid.py --tp-size 2 --pp-size 2 + +OMP_NUM_THREADS=4 python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + handcraft/pipeline/dummy_hybrid.py --use-naive +""" +import torch +import torch.nn.functional as F +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary +from cube.runtime.device import DeviceGroup +from cube.runtime.syndata import SynDataLoader, SynTextDataLoader + +from handcraft.pipeline.schedule import schedule_tp_1f1b, schedule_naive, schedule_tp_1f1b_pack +import argparse + +""" +Stage0: + Embedding [M, 1], [N, E] -> [M, E] + Linear [M, E], [E, E] -> [M, E] + +Stage Else: + Linear [M, E], [E, E] -> [M, E] + +Condition: N > 8M - E +""" + +_tp_group = None +_pp_group = None +_pp_next_rank = None +_pp_prev_rank = None + +io_input = input + +class ReduceEmbed(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + global _tp_group + torch.distributed.all_reduce(input, group=_tp_group) + return input + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class IdentityFoward(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + global _tp_group + torch.distributed.all_reduce(grad_output, group=_tp_group) + return grad_output + + +class DummyModelEmbed(torch.nn.Module): + + def __init__(self, M: int, N: int, E: int, tp_group = None, pp_group = None): + super().__init__() + self.M = M + self.N = N + self.E = E + + self.tp_group = tp_group + self.tp_size = torch.distributed.get_world_size(tp_group) + self.pp_group = pp_group + + shard_id = torch.distributed.get_rank(tp_group) + self.vocab_start_index = N // self.tp_size * shard_id + self.vocab_end_index = N // self.tp_size * (shard_id + 1) + self.embed_weight = torch.nn.Parameter(torch.ones((N // self.tp_size, E))) + + def input_shape(self): + return (self.M, ) + + def input_dtype(self): + return torch.int64 + + def output_shape(self): + return (self.M, self.E) + + def output_dtype(self): + return torch.float32 + + def forward(self, input: torch.Tensor): + if self.tp_size > 1: + mask = (input < self.vocab_start_index) | \ + (input >= self.vocab_end_index) + input = input.clone() - self.vocab_start_index + input[mask] = 0 + input = F.embedding(input, self.embed_weight) + input[mask, :] = 0.0 + input = ReduceEmbed.apply(input) + else: + input = F.embedding(input, self.embed_weight) + return input + + +class DummyModel(torch.nn.Module): + + def __init__(self, M: int, N: int, E: int): + + super().__init__() + self.M = M + self.N = N + self.E = E + + # group + global _tp_group + self.tp_group = _tp_group + global _pp_group + self.pp_group = _pp_group + + self.pp_stage = torch.distributed.get_rank(_pp_group) + self.is_first_pp_stage = self.pp_stage == 0 + self.is_last_stage = self.pp_stage == torch.distributed.get_world_size(_pp_group) - 1 + + self.tp_size = torch.distributed.get_world_size(_tp_group) + + # mebed module + if self.is_first_pp_stage: + self.embed = DummyModelEmbed(M, N, E, self.tp_group, self.pp_group) + else: + self.embed = None + + total_fc_num = torch.distributed.get_world_size() + fc_weights = list() + input_shapes = list() + output_shapes = list() + shard_types = list() + for idx in range(total_fc_num): + if idx % 2 == 0: + fc_weights.append( + torch.nn.Parameter(torch.ones((E // self.tp_size, E)) / 10000) + ) + input_shapes.append((M, E)) + output_shapes.append((M, E // self.tp_size)) + shard_types.append('col') + else: + fc_weights.append( + torch.nn.Parameter(torch.ones((E, E // self.tp_size)) / 10000) + ) + input_shapes.append((M, E // self.tp_size)) + output_shapes.append((M, E)) + shard_types.append('row') + + self.fc_num = total_fc_num // torch.distributed.get_world_size(_pp_group) + self.fc_weights = torch.nn.ParameterList( + fc_weights[self.fc_num * self.pp_stage : self.fc_num * (self.pp_stage + 1)] + ) + self.ins = input_shapes[self.fc_num * self.pp_stage : self.fc_num * (self.pp_stage + 1)] + self.ous = output_shapes[self.fc_num * self.pp_stage : self.fc_num * (self.pp_stage + 1)] + self.shard_types = shard_types[self.fc_num * self.pp_stage : self.fc_num * (self.pp_stage + 1)] + print_each_rank(f'initializing with {self.fc_num} fcs: {self.shard_types}') + + + def input_shape(self): + if self.embed: + return self.embed.input_shape() + else: + return self.ins[0] + + def output_shape(self): + if self.is_last_stage: + return (1,) + else: + return self.ous[-1] + + def input_dtype(self): + if self.embed: + return self.embed.input_dtype() + else: + return torch.float32 + + def output_dtype(self): + return torch.float32 + + def forward(self, input: torch.Tensor): + # print(f'[{DeviceGroup().rank}] input: {input}, shape={input.size()}') + if self.embed: + x = self.embed(input) + else: + x = input + + for stype, weight in zip(self.shard_types, self.fc_weights): + # column partition + if stype == 'col': + x = IdentityFoward.apply(x) + x = F.linear(x, weight) + elif stype == 'row': + x = F.linear(x, weight) + x = ReduceEmbed.apply(x) + else: + assert False + + if self.is_last_stage: + x = torch.sum(x) + # print(x) + # print(f'[{DeviceGroup().rank}] output: {output}, shape={output.size()}') + return x + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--tp-size', type=int, + help='tensor parallelism size') + parser.add_argument('--pp-size', type=int, + help='pipeline parallelism size') + parser.add_argument('--nmb', type=int, default=4, + help='num of micro batch') + parser.add_argument('--M', type=int, default=4096, + help='M dimension length = sequence length') + parser.add_argument('--N', type=int, default=50257, + help='word number') + parser.add_argument('--E', type=int, default=2048, + help='E dimension length = hidden dimension length') + args = parser.parse_args() + + print(args) + + cube.init() + rank = DeviceGroup().rank + pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) + + if _pp_group is None: + _pp_group = DeviceGroup().get_group(pp_ranks) + idx = pp_ranks.index(DeviceGroup().rank) + _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] + _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] + if _tp_group is None: + _tp_group = DeviceGroup().get_group(tp_ranks) + + + model = DummyModel(args.M, args.N, args.E).cuda() + optimizer = torch.optim.Adam(model.parameters()) + + # 0.11GB + print_each_rank('model consumption') + memory_summary() + + dataloader = SynTextDataLoader( + shapes=([args.M],), + dtypes=(torch.int64, ), + batch_dims=(0,), + length=128000 + ) + + # 0.11GB + print_each_rank('model + dataloader consumption') + memory_summary() + + iter_num = 64 + CudaTimer(enable=False).warmup() + for step in range(iter_num): + if step >= 20: + CudaTimer(enable=True).start('e2e') + + schedule_naive(model, dataloader, args.nmb, [_pp_prev_rank, _pp_next_rank]) + optimizer.step() + optimizer.zero_grad() + + if step >= 20: + CudaTimer().stop('e2e') + if (step+1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-20, field_name='e2e'))) + memory_summary() diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py index bfbc57f7..8db0c28d 100644 --- a/handcraft/pipeline/schedule.py +++ b/handcraft/pipeline/schedule.py @@ -6,9 +6,17 @@ import cube.runtime.adapter.collectives as coll from cube.runtime.device import DeviceGroup +from torch.distributed.distributed_c10d import _get_global_rank io_input = input +def get_global_rank(group, group_rank): + if group is None: + return group_rank + else: + return _get_global_rank(group, group_rank) + + def forward_step(model, *args, **kwargs): """ Forward pass @@ -55,13 +63,20 @@ def recv_input(model, dataloader, prev_rank: int): return coll.recv(model.input_shape(), prev_rank, model.input_dtype()) -def schedule_naive(model, dataloader, num_microbatch: int): +def schedule_naive(model, dataloader, num_microbatch: int, neighbors = None): rank = DeviceGroup().rank - next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size - prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size + if neighbors is None: + prev_rank = (rank - 1) % DeviceGroup().world_size + next_rank = (rank + 1) % DeviceGroup().world_size + else: + prev_rank, next_rank = neighbors + + is_first_stage = rank < prev_rank + is_last_stage = rank > next_rank + for step in range(num_microbatch): # recv forward - if is_first_stage(): + if is_first_stage: input = next(dataloader) else: # print(f'rank {rank} recving forward input...') @@ -69,18 +84,18 @@ def schedule_naive(model, dataloader, num_microbatch: int): # forward output = forward_step(model, input) # send forward - if not is_last_stage(): + if not is_last_stage: # print(f'rank {rank} sending forward output...') coll.send(output, next_rank) # recv backward output_grad = None - if not is_last_stage(): + if not is_last_stage: # print(f'rank {rank} recving backward input...') output_grad = coll.recv(output.size(), next_rank, output.dtype) # backward input_grad = backward_step([input], [output], [output_grad])[0] # send backward - if not is_first_stage(): + if not is_first_stage: # print(f'rank {rank} sending backward output...') coll.send(input_grad, prev_rank) From 17dab0846d2a4c015fc97db11d3c735c604cf1ff Mon Sep 17 00:00:00 2001 From: lynex Date: Fri, 11 Mar 2022 18:15:59 +0800 Subject: [PATCH 0645/1892] naive PAS for atmosphere weather --- examples/atmosphere/policy/naive.py | 11 +++++++++++ examples/atmosphere/weather.py | 4 +--- 2 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 examples/atmosphere/policy/naive.py diff --git a/examples/atmosphere/policy/naive.py b/examples/atmosphere/policy/naive.py new file mode 100644 index 00000000..caf38e07 --- /dev/null +++ b/examples/atmosphere/policy/naive.py @@ -0,0 +1,11 @@ +from cube.graph import IRGraph +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation + + +def PAS(graph: IRGraph, resource): + print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + return graph diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index 862ab0d6..b63159f7 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -8,7 +8,7 @@ from typing import List import cube -from examples.poisson.policy.naive import PAS +from examples.atmosphere.policy.naive import PAS from einops.layers.torch import Rearrange @@ -166,8 +166,6 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): v1 = (pi0_deltaA * v0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA # # print('v1:', v1.mean()) - # return pi1, theta1, u1, v1 - # return pi1, theta0, u0, v0 return pi1, theta1, u1, v1 From cbba4bb37062df4c55d81d391772433dc66e55e5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 11 Mar 2022 20:20:39 +0800 Subject: [PATCH 0646/1892] add mbart model --- examples/mbart/mbart.py | 415 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 415 insertions(+) create mode 100644 examples/mbart/mbart.py diff --git a/examples/mbart/mbart.py b/examples/mbart/mbart.py new file mode 100644 index 00000000..1eccf1b7 --- /dev/null +++ b/examples/mbart/mbart.py @@ -0,0 +1,415 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/mbart/mbart.py +""" + +from typing import Optional +import argparse +import math +import torch + +import cube +from cube.runtime.syndata import SynTextDataLoader +from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary +from cube.profiler.timer import print_each_rank + + +# fairseq task +# translation_from_pretrained_bart + +# fairseq criterion +# label_smoothed_cross_entropy, --label_smoothing = 0.2 + +class Config: + + num_embeddings = 500027 + + encoder_embed_path = None + encoder_embed_dim = 1024 + encoder_ffn_embed_dim = 4 * 1024 + encoder_layers = 12 + encoder_attention_heads = 16 + encoder_normalize_before = True + encoder_learned_pos = True + + decoder_embed_path = None + decoder_embed_dim = 1024 + decoder_ffn_embed_dim = 4 * 1024 + decoder_layers = 12 + decoder_attention_heads = 16 + decoder_normalize_before = True + decoder_learned_pos = True + cross_self_attention = False + no_cross_attention = False + + attention_dropout = 0.0 + activation_dropout = 0.0 + dropout = 0.1 + + max_target_positions = 1024 + max_source_positions = 1024 + adaptive_softmax_cutoff = None + adaptive_softmax_dropout = 0 + + share_decoder_input_output_embed = True + share_all_embeddings = True + + decoder_output_dim = 1024 # same with decorder_embed_dim + decoder_input_dim = 1024 # same with decorder_embed_dim + + no_scale_embedding = False # True in bart large + layernorm_embedding = True + activation_fn = 'gelu' + pooler_activation_fn = 'tanh' + pooler_dropout = 0.0 + + +def attn_fn(query: torch.Tensor, key: torch.Tensor, + wq: torch.Tensor, wq_bias: Optional[torch.Tensor], + wk: torch.Tensor, wk_bias: Optional[torch.Tensor], + wv: torch.Tensor, wv_bias: Optional[torch.Tensor], + wout: torch.Tensor, wout_bias: Optional[torch.Tensor], + h: int, scale: float, dropout: float, mask=True): + """ + query, key: (L, N, E) = (seqlen, batch size, embed_dim) + wq, wk, wv weight: [(num_head * dim_head), E] + dropout: float + h: int: number of heads + """ + num_head = h + L, N = query.size(0), query.size(1) + dim_head = wq.size(0) // num_head + + q = torch.nn.functional.linear(query, wq, wq_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(key, wk, wk_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(key, wv, wv_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + if mask: # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, wout, wout_bias) # L N (h d), E E -> L N E + return output + + +class MultiheadAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, dropout=0.0, bias=True): + super().__init__() + self.kdim = embed_dim + self.vdim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # K + self.k_proj = torch.nn.Parameter(torch.empty(embed_dim, self.kdim)) + if bias: + self.k_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.k_bias = None + # V + self.v_proj = torch.nn.Parameter(torch.empty(embed_dim, self.vdim)) + if bias: + self.v_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.v_bias = None + # Q + self.q_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + if bias: + self.q_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.q_bias = None + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + if bias: + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.out_bias = None + + def forward_encoder_decoder_attn(self, query: torch.Tensor, key: torch.Tensor): + # tgt_len, bsz, embed_dim = query.size() + # q = torch.nn.functional.linear(query, self.q_proj, self.q_bias) + # k = torch.nn.functional.linear(key, self.k_proj, self.k_bias) + # v = torch.nn.functional.linear(key, self.v_proj, self.v_bias) + # q = q * self.scaling + # q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + # k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + # v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + # attn_weights = torch.bmm(q, k.transpose(1, 2)) + # # TODO: here needs a mask + # attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + # attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout_p) + # attn = torch.bmm(attn_probs, v) + # attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + # attn = torch.nn.functional.linear(attn, self.out_proj, self.out_bias) + return attn_fn(query, key, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) + + def forward_self_attn(self, query): + return attn_fn(query, query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) + + +class EncoderLayer(torch.nn.Module): + + def __init__(self, cfg: Config): + + super().__init__() + self.self_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.encoder_attention_heads, cfg.attention_dropout) + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.dropout = torch.nn.Dropout(p=cfg.dropout) + self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + self.fc1 = torch.nn.Linear(cfg.encoder_embed_dim, cfg.encoder_ffn_embed_dim) + self.fc2 = torch.nn.Linear(cfg.encoder_ffn_embed_dim, cfg.encoder_embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + def forward(self, x): # , encoder_padding_mask: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None): + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn.forward_self_attn(x) + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + x = self.fc1(x) + x = torch.nn.functional.gelu(x) + x = self.activation_dropout(x) + x = self.fc2(x) + + x = self.dropout(x) + x = x + residual + return x + + +class Encoder(torch.nn.Module): + + def __init__(self, cfg: Config, embed_tokens: torch.nn.Module): + super().__init__() + self.dropout = torch.nn.Dropout(cfg.dropout) + self.max_source_positions = cfg.max_source_positions + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(cfg.encoder_embed_dim) + self.embed_positions = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) + self.layernorm_embedding = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.layers = torch.nn.ModuleList([]) + self.layers.extend( + [EncoderLayer(cfg) for _ in range(cfg.encoder_layers)] + ) + # normalize before + self.layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + def forward(self, src_tokens: torch.Tensor): + token_embedding = self.embed_tokens(src_tokens) + embed = self.embed_scale * token_embedding + + x = embed + self.embed_positions.weight # self.embed_positions(src_tokens) + x = self.layernorm_embedding(x) + x = self.dropout(x) + + x = x.transpose(0, 1) + for layer in self.layers: + x = layer(x) # encoder_padding_mask if has_pads else None) + x = self.layer_norm(x) + return x + + +class DecoderLayer(torch.nn.Module): + + def __init__(self, cfg: Config): + + super().__init__() + self.dropout = torch.nn.Dropout(p=cfg.dropout) + self.self_attn = MultiheadAttention(cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) + self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + # encoder atten + self.encoder_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) + self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + # self.encoder_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + self.fc1 = torch.nn.Linear(cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim) + self.fc2 = torch.nn.Linear(cfg.decoder_ffn_embed_dim, cfg.decoder_embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + def forward(self, x, encoder_out): # encoder_padding_mask): + residual = x + # normalize before + x = self.self_attn_layer_norm(x) + + # self attention + x = self.self_attn.forward_self_attn(x) + x = self.dropout(x) + x = residual + x + + # encoder attn + residual = x + # normalize before + x = self.encoder_attn_layer_norm(x) + x = self.encoder_attn.forward_encoder_decoder_attn(x, encoder_out) + x = self.dropout(x) + x = x + residual + + residual = x + # normalize before + x = self.final_layer_norm(x) + x = self.fc1(x) + x = torch.nn.functional.gelu(x) + x = self.activation_dropout(x) + x = self.fc2(x) + x = self.dropout(x) + x = x + residual + return x + + +class Decoder(torch.nn.Module): + + def __init__(self, cfg: Config, embed_tokens: torch.nn.Module): + super().__init__() + self.dropout = torch.nn.Dropout(cfg.dropout) + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(cfg.decoder_embed_dim) + self.embed_positions = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) + self.layernorm_embedding = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.layers = torch.nn.ModuleList([]) + self.layers.extend( + [DecoderLayer(cfg) for _ in range(cfg.decoder_layers)] + ) + self.layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + def forward(self, prev_output_tokens: torch.Tensor, enc: torch.Tensor): + positions = self.embed_positions.weight # self.embed_positions(prev_output_tokens) + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + x = x + positions + x = self.layernorm_embedding(x) + x = self.dropout(x) + # B T C -> T B C + x = x.transpose(0, 1) + # decoder layers + for layer in self.layers: + x = layer(x, enc) + x = self.layer_norm(x) + # T x B x C -> B x T x C + x = x.transpose(0, 1) + # B T C, N, C -> B T N + x = torch.nn.functional.linear(x, self.embed_tokens.weight) + return x + +# label_smoothed_cross_entropy +def criterion(output: torch.Tensor, prev_output_tokens: torch.Tensor, label_smoothing: float = 0.2): + target = prev_output_tokens[:, 1:] + # fairseq.criterions.label_smoothed_cross_entory + # model.get_normalized_probs + lprobs = torch.nn.functional.softmax(output, dim=-1) + # fairseq.criterions.label_smoothed_nll_loss + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = label_smoothing / (lprobs.size(-1) - 1) + loss = (1.0 - label_smoothing - eps_i) * nll_loss + eps_i * smooth_loss + return loss + + +class mBART(torch.nn.Module): + + def __init__(self, cfg: Config): + super().__init__() + emb = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) + self.encoder: Encoder = Encoder(cfg, emb) + self.decoder: Decoder = Decoder(cfg, emb) + + def forward(self, src_tokens, prev_output_tokens): + encoder_out = self.encoder(src_tokens) + decoder_out = self.decoder(prev_output_tokens, encoder_out) + loss = criterion(decoder_out, prev_output_tokens) + return loss + + +def train(): + bs = 1 + + cfg = Config() + model = mBART(cfg).cuda() + + print_each_rank('model weight consumpition:') + memory_summary() + + dataloader = SynTextDataLoader( + shapes=( + [bs, cfg.max_source_positions], + [bs, cfg.max_target_positions] + ), + dtypes=(torch.int64, torch.int64), + batch_dims=(0,0,) + ) + + def train_iter(model, dataloader): + model.eval() + src_tokens, prev_output_tokens = next(dataloader) + loss = model(src_tokens, prev_output_tokens) + loss.backward() + + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + CudaTimer(enable=False).warmup() + iter_num = 128 + for step in range(iter_num): + if step >= 40: + CudaTimer(enable=True).start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= 40: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + memory_summary() + + +if __name__ == '__main__': + + cube.init() + train() \ No newline at end of file From a18439b950171125e7e002eb37c8ed19bb6d59d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 13 Mar 2022 12:05:57 +0800 Subject: [PATCH 0647/1892] fix dataloader for scientific model --- cube/algorithm/ops/dataloader.py | 5 ++ cube/compiler.py | 22 ++++---- cube/graph/operator/operator.py | 6 +- cube/logics/dataloader.py | 46 ++++++--------- cube/runtime/syndata.py | 97 +++++++++++++++++++++++++++----- 5 files changed, 117 insertions(+), 59 deletions(-) diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py index 44b61075..25c22140 100644 --- a/cube/algorithm/ops/dataloader.py +++ b/cube/algorithm/ops/dataloader.py @@ -26,6 +26,11 @@ def satisfy(self, config: Dict): node: IRDataOperation = self.node num: int = config['num'] dims: List[int] = node.get_batch_dims() + # check batch size + all_batch_size = set([output.shape[dim] for dim, output in zip(dims, node.outputs())]) + # batch size not same -- indicate a scientific model + if len(all_batch_size) != 1: + return False for dim, output in zip(dims, node.outputs()): if output.shape[dim] % num != 0: return False diff --git a/cube/compiler.py b/cube/compiler.py index dcb92f7e..b062a5f5 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -20,7 +20,8 @@ from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen from cube.profiler.timer import print_each_rank -from cube.runtime.syndata import CubeDataLoader +from cube.runtime.syndata import CubeDataLoader, SciLoopVariables + class SemanticModel: @@ -208,14 +209,15 @@ def decorator(fn: Callable) -> Callable: ) # setup batch size - all_batch_size = set() - dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] - for dnode in dnodes: - bs = [out.shape[dim] for out, dim in zip(dnode.outputs(), dnode.get_batch_dims())] - all_batch_size.update(bs) - if len(all_batch_size) != 1: - raise NotImplementedError(f"Heterogenous batch size {bs} is not supported") - batch_size = torch.tensor(list(all_batch_size), dtype=torch.int).cuda() + if not isinstance(dataloader, SciLoopVariables): + all_batch_size = set() + dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] + for dnode in dnodes: + bs = [out.shape[dim] for out, dim in zip(dnode.outputs(), dnode.get_batch_dims())] + all_batch_size.update(bs) + if len(all_batch_size) != 1: + raise NotImplementedError(f"Heterogenous batch size {bs} is not supported") + batch_size = torch.tensor(list(all_batch_size), dtype=torch.int).cuda() compile_end = time.time() compile_time = compile_end - compile_start @@ -228,7 +230,7 @@ def decorator(fn: Callable) -> Callable: torch.distributed.broadcast(batch_size, src=0) batch_size = batch_size.item() print_each_rank(f'reseting dataloader batch size to {batch_size}') - dataloader.reset(batch_size=batch_size) + dataloader.set_batch_size(batch_size) # load module filename = filename.format(myrank) diff --git a/cube/graph/operator/operator.py b/cube/graph/operator/operator.py index 711f8b1c..1f626189 100644 --- a/cube/graph/operator/operator.py +++ b/cube/graph/operator/operator.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, List +from typing import Any, Optional, Tuple, Union, List import copy from cube.ir.cten import IRCell, IRTensor @@ -298,9 +298,7 @@ def module_repr(self) -> str: class IRDataOperation(IRCell): - def __init__(self, data_num: int, batch_dims: List[int], name='dataloader'): - if not isinstance(batch_dims, list): - raise RuntimeError("Expected batch dims to be a list") + def __init__(self, data_num: int, batch_dims: Tuple[int], name='dataloader'): if len(batch_dims) != data_num: raise RuntimeError("Expected each output data has a specified batch dim") signature = 'dataloader.__next__' diff --git a/cube/logics/dataloader.py b/cube/logics/dataloader.py index 49a0c043..4ad72e1d 100644 --- a/cube/logics/dataloader.py +++ b/cube/logics/dataloader.py @@ -1,39 +1,25 @@ -import copy -from typing import Optional - -import torch - +from typing import Tuple from cube.runtime.syndata import CubeDataLoader class IRDataLoader: def __init__(self, dataloader: CubeDataLoader, dtype_map): - self.dataloader = iter(dataloader) - self.batch_dims = dataloader.get_batch_dims() - self.dtypes = list() - self.shapes = list() - - datas = next(dataloader) - if not isinstance(datas, tuple): - datas = (datas,) - - for data in datas: - if torch.is_tensor(data): - self.dtypes.append(dtype_map.map(data.dtype)) - shape = tuple(data.shape) - # special handler for scalar tensor shape - if len(shape) == 0: - shape = (1,) - self.shapes.append(shape) - else: - raise NotImplementedError("Data should be torch.Tensor") - - def get_batch_dims(self, idx: Optional[int] = None) -> int: - if idx is None: - return copy.copy(self.batch_dims) - else: - return self.batch_dims[idx] + if not isinstance(dataloader, CubeDataLoader): + raise TypeError("Expected data loader derived from CubeDataLoader") + self.dataloader: CubeDataLoader = iter(dataloader) + self.dtypes = [dtype_map.map(dtype) for dtype in dataloader.dtypes] + self.shapes = [list(shape) for shape in dataloader.shapes] + + def get_batch_dims(self) -> Tuple[int]: + return tuple(self.dataloader.batch_dims) + + def get_batch_size(self) -> int: + return self.dataloader.get_batch_size() + + def set_batch_size(self, bs: int): + self.dataloader.set_batch_size(bs) + return def __iter__(self): return self diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index bb40d9c8..514761cd 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -2,8 +2,7 @@ Synthetic Data Loader """ -from typing import List, Optional, Tuple -import copy +from typing import Any, List, Optional, Tuple import torch @@ -14,7 +13,7 @@ class CubeDataLoader: r""" Cube Dataloader """ - def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype], batch_dims: Tuple[int]): + def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype], batch_dims: Tuple[int] = None): """ shapes Tuple[Tuple[int]]: The shape for each data @@ -27,28 +26,95 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype], batch_d raise TypeError("Expected each shape in shapes to be a list") if len(shapes) != len(batch_dims) or len(shapes) != len(dtypes): raise TypeError("Expected number batch dim and dtypes to len(shapes)") - self.shapes = shapes + self.shapes = tuple([list(shape) for shape in shapes]) self.dtypes = dtypes - self.batch_dims = batch_dims + self.batch_dims = (0,) * len(self.shapes) if batch_dims is None else batch_dims - def get_batch_dims(self, idx: Optional[int] = None) -> int: + def get_batch_size(self) -> int: """ - Get batch dimension for idx-th data + get batch size """ - if idx is not None: - return self.batch_dims[idx] - else: - return list(self.batch_dims) + all_batch_size = set([shape[dim] for shape, dim in zip(self.shapes, self.batch_dims)]) + if len(all_batch_size) != 1: + raise ValueError("Heterogenous batch size in dataloader") + return list(all_batch_size)[0] - def reset(self, batch_size: int): + def set_batch_size(self, batch_size: int): """ - Reset batch size + set batch size """ - for bdim, shape in zip(self.batch_dims, self.shapes): - shape[bdim] = batch_size + self.batch_size = batch_size + for shape, dim in zip(self.shapes, self.batch_dims): + shape[dim] = batch_size print(f'> data loader output shape change to: {self.shapes}') +class SciLoopVariables(CubeDataLoader): + r"""Scientific loop variable loader + """ + def __init__(self, variables: List[Any], constants: List[Any]): + shapes = [] + dtypes = [] + for var in variables + constants: + if torch.is_tensor(var): + shapes.append(list(var.size()) if len(var.size()) != 0 else [1,]) + dtypes.append(var.dtype) + else: + shapes.append([1,]) + dtypes.append(type(var)) + batch_dims = [-1] * (len(variables) + len(constants)) + super().__init__(shapes, dtypes, batch_dims) + self.variables = list() + self.constants = list() + for var in variables: + if torch.is_tensor(var) and var.device != torch.cuda.current_device(): + var = var.cuda() + self.variables.append(var) + for const in constants: + if torch.is_tensor(const) and const.device != torch.cuda.current_device(): + const = const.cuda() + self.constants.append(const) + + def get_batch_size(self) -> int: + return 0 + + def set_batch_size(self, batch_size: int): + return + + def __iter__(self): + return self + + def __next__(self): + if len(self.variables) + len(self.constants) == 1: + return (self.variables + self.constants)[0] + return tuple(self.variables + self.constants) + + def update(self, variables: Optional[List[Any]] = None, constants: Optional[List[Any]] = None): + """ + Update variables and constants + """ + if variables is not None: + if len(variables) != len(self.variables): + raise ValueError(f"Expected {len(self.shapes)} but only got {len(variables)} varaibales to update") + for var, expected_shape in zip(variables, self.shapes): + expected_shape = tuple(expected_shape) + if not torch.is_tensor(var) and expected_shape != (1,): + raise ValueError(f"Non-tensor variable: Expected shape is (1,)") + if torch.is_tensor(var) and tuple(var.size()) != expected_shape: + raise ValueError(f"Shape update mismatch: var: {var.size()} != expected: {expected_shape}") + self.variables = variables + if constants is not None: + if len(constants) != len(self.constants): + raise ValueError(f"Expected {len(self.shapes)} but only got {len(constants)} varaibales to update") + for const, expected_shape in zip(constants, self.shapes): + expected_shape = tuple(expected_shape) + if not torch.is_tensor(const) and expected_shape != (1,): + raise ValueError(f"Non-tensor constant: Expected shape is (1,)") + if torch.is_tensor(const) and tuple(const.size()) != expected_shape: + raise ValueError(f"Shape update mismatch: const: {const.size()} != expected: {expected_shape}") + self.constants = constants + + class SynDataLoader(CubeDataLoader): r""" Synthetic dataloader to produce tensors @@ -109,6 +175,7 @@ def __next__(self): class SynTextDataLoader(SynDataLoader): def set_data_buffer(self, buffer_num=4, text_num=50257): + torch.manual_seed(0) self.datas = list() self._buffer_num = buffer_num for _ in range(self._buffer_num): From 9dc8c2d5e219e7327a669f2a4586a10cf74cb6b9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 13 Mar 2022 12:06:22 +0800 Subject: [PATCH 0648/1892] fix code gen for kwargs to be str --- cube/codegen/codegen.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index ed4fc56c..616579bb 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -225,7 +225,10 @@ def emit_op_call(self, node: IRFwOperation): inputs = [self.tensor_naming(t) for t in node.inputs()] kwargs = list() for key in node.kwargs: - code = f'{key}={node.kwargs[key]}' + val = node.kwargs[key] + if isinstance(val, str) and 'self.' not in val: + val = '"' + val + '"' + code = f'{key}={val}' kwargs.append(code) inputs += kwargs inputs = ', '.join(inputs) From d5051f835361597223723106103b05af47337c34 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 13 Mar 2022 12:07:15 +0800 Subject: [PATCH 0649/1892] switch to use runtime sci dataloader --- examples/poisson/sci.py | 5 +++-- examples/wrf/wrf.py | 40 ++-------------------------------------- 2 files changed, 5 insertions(+), 40 deletions(-) diff --git a/examples/poisson/sci.py b/examples/poisson/sci.py index 8d78e90b..7fd0ec39 100644 --- a/examples/poisson/sci.py +++ b/examples/poisson/sci.py @@ -6,12 +6,13 @@ torch.set_default_tensor_type(torch.DoubleTensor) +from cube.runtime.syndata import SciLoopVariables import cube from examples.poisson.policy.naive import PAS """ OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ + --nproc_per_node=1 \ --nnodes=1 \ examples/poisson/sci.py """ @@ -85,7 +86,7 @@ def train_loop(): r0 = rho - F.conv2d(phi, filter, padding=1) p = r0 - varloader = LoopVariables(variables=[r0, p, phi], constants=[filter]) + varloader = SciLoopVariables(variables=[r0, p, phi], constants=[filter]) model = ScientificModel() model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes),) diff --git a/examples/wrf/wrf.py b/examples/wrf/wrf.py index feddba5d..3f2d4857 100644 --- a/examples/wrf/wrf.py +++ b/examples/wrf/wrf.py @@ -6,6 +6,7 @@ import torch.nn.functional as F # from linalg import tridiagonal +from cube.runtime.syndata import SciLoopVariables from einops import rearrange from einops.layers.torch import Rearrange @@ -318,43 +319,6 @@ def _acoustic_step(self, ): pass -class LoopVariables(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): - # for var in variables + constants: - # print("### var = {}, type = {}".format(var, type(var))) - shapes = [list(var.size()) for var in variables + constants] - dtypes = [var.dtype for var in variables + constants] - batch_dims = [0] * (len(variables) + len(constants)) - super().__init__(shapes, dtypes, batch_dims) - self.variables = list() - self.constants = list() - for var in variables: - if torch.is_tensor(var) and var.device != torch.cuda.current_device(): - var = var.cuda() - self.variables.append(var) - for const in constants: - if torch.is_tensor(const) and const.device != torch.cuda.current_device(): - const = const.cuda() - self.constants.append(const) - - def __iter__(self): - return self - - def update(self, variables: List[torch.Tensor] = None, constants: List[torch.Tensor] = None): - if variables is not None: - self.variables = variables - if constants is not None: - self.constants = constants - - def reset(self, batch_size): - pass - - def __next__(self): - if len(self.variables) + len(self.constants) == 1: - return (self.variables + self.constants)[0] - return tuple(self.variables + self.constants) - if __name__ == "__main__": cube.init() @@ -393,7 +357,7 @@ def __next__(self): bar_x_filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) delta_z_filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) - varloader = LoopVariables(variables=[U, V, W, Theta, Mu, Phi, dt], constants=[Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter]) + varloader = SciLoopVariables(variables=[U, V, W, Theta, Mu, Phi, dt], constants=[Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter]) model = WRF() model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes), ) From 17674a7048937cf3291e3100d826e188d42c13bb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 13 Mar 2022 13:50:51 +0800 Subject: [PATCH 0650/1892] fix torch.size --- cube/graph/parser/parser.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 64a572a0..94cfa53e 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -142,14 +142,14 @@ def ntype(node: torch._C.Node): return ScriptNodeKind.AtenOp if node.kind() == 'prim::If': return ScriptNodeKind.PrimIf - if node.kind() == 'prim::ListUnpack': - return ScriptNodeKind.PrimListUnpack if node.kind() == 'prim::ListConstruct': return ScriptNodeKind.PrimListConstruct if node.kind() == 'prim::TupleConstruct': return ScriptNodeKind.PrimListConstruct + if node.kind() == 'prim::ListUnpack': + return ScriptNodeKind.PrimListUnpack if node.kind() == 'prim::TupleUnpack': - return ScriptNodeKind.PrimTupleUnpack + return ScriptNodeKind.PrimListUnpack if node.kind() == 'prim::PythonOp': return ScriptNodeKind.PrimPythonOp raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @@ -175,9 +175,7 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == ScriptNodeKind.PrimListConstruct: return ScriptModuleParser.parse_prim_list_construct_node(node, module, frame) if node_type == ScriptNodeKind.PrimListUnpack: - return ScriptModuleParser.parse_prim_listunpack_node(node, module, frame) - if node_type == ScriptNodeKind.PrimTupleUnpack: - return ScriptModuleParser.parse_prim_tupleunpack_node(node, module, frame) + return ScriptModuleParser.parse_prim_list_unpack_node(node, module, frame) if node_type == ScriptNodeKind.PrimPythonOp: return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) raise NotImplementedError(f"Un-supported node type {node_type}") @@ -248,10 +246,12 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: # special handling on aten::size(tensor: tensor, dim: int) if fsig == 'torch.size': - assert len(inputs) == 2 and len(outputs) == 1, \ - "Expected 2 inputs and 1 outputs for torch.size" - tensor, dim = input_val - output: int = tensor.shape[dim] + if len(inputs) == 2: + tensor, dim = input_val + output: int = tensor.shape[dim] + else: + tensor = input_val[0] + output: List[int] = list(tensor.shape) frame.add_var(outputs[0].debugName(), output) return [] @@ -420,11 +420,7 @@ def parse_prim_list_construct_node(node, module, frame: Frame) -> List[None]: return list() @staticmethod - def parse_prim_listunpack_node(node, module, frame: Frame) -> List[None]: - raise NotImplementedError - - @staticmethod - def parse_prim_tupleunpack_node(node, module, frame: Frame) -> List[None]: + def parse_prim_list_unpack_node(node, module, frame: Frame) -> List[None]: """ Parse script module node like: %q.1 : Tensor, %k.1 : Tensor, %v.1 : Tensor = prim::TupleUnpack(%11) From 3867296b835ecb0dcc882a62f7aa47f79404e730 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 13 Mar 2022 14:26:28 +0800 Subject: [PATCH 0651/1892] view assertation --- cube/graph/operator/function/function.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 4190c731..e4ee991c 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -2,8 +2,6 @@ import string import copy -from numpy import isin - from cube.ir.cten import IRTensor from cube.graph.operator.function.einops import EinDim, IREinops from cube.graph.operator.function.conv import IRConv2D @@ -301,6 +299,8 @@ def View(signature, inputs): """ assert len(inputs) == 2 input, shape = inputs + if not all([isinstance(dim, int) for dim in shape]): + raise TypeError("Expected tensor.view has static int shape") in_shape, ou_shape = list(input.shape), shape # shape check From 95b4b3208a97eaca2b7440a04095ff6d423f2e5d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 13 Mar 2022 14:26:46 +0800 Subject: [PATCH 0652/1892] change to sci loop vars --- examples/atmosphere/weather.py | 43 ++-------------------------------- 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index b63159f7..946d4cdf 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -8,6 +8,7 @@ from typing import List import cube +from cube.runtime.syndata import SciLoopVariables from examples.atmosphere.policy.naive import PAS from einops.layers.torch import Rearrange @@ -237,46 +238,6 @@ def laplas(self, X): return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.laplas_filter)) -class LoopVariables(cube.runtime.syndata.CubeDataLoader): - def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): - # for var in variables + constants: - # print("### var = {}, type = {}".format(var, type(var))) - shapes = [list(var.size() if len(var.size()) > 0 else [1]) for var in variables + constants] - dtypes = [var.dtype for var in variables + constants] - batch_dims = [0] * (len(variables) + len(constants)) - super().__init__(shapes, dtypes, batch_dims) - self.variables = list() - self.constants = list() - for var in variables: - if torch.is_tensor(var) and var.device != torch.cuda.current_device(): - var = var.cuda() - self.variables.append(var) - for const in constants: - if torch.is_tensor(const) and const.device != torch.cuda.current_device(): - const = const.cuda() - self.constants.append(const) - - def __iter__(self): - return self - - - def update(self, variables: List[torch.Tensor] = None, constants: List[torch.Tensor] = None): - if variables is not None: - self.variables = variables - if constants is not None: - self.constants = constants - - - def reset(self, batch_size): - pass - - - def __next__(self): - if len(self.variables) + len(self.constants) == 1: - return (self.variables + self.constants)[0] - return tuple(self.variables + self.constants) - - if __name__ == "__main__": import matplotlib.pyplot as plt cube.init() @@ -466,7 +427,7 @@ def init(ps, pt, zs): for var in [pi, theta, u, v, dt]: print(f'shape {var.shape}') - varloader = LoopVariables(variables=[pi, theta, u, v], constants=[dt]) + varloader = SciLoopVariables(variables=[pi, theta, u, v], constants=[dt]) model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes)) @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) From a7de5dab4336be366f2ccf26b4a2104d74837f7f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 13 Mar 2022 14:55:47 +0800 Subject: [PATCH 0653/1892] add model memory inspect --- cube/profiler/memory.py | 68 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index ad2f9aff..263edb52 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -1,3 +1,4 @@ +from typing import Any, List import torch from cube.profiler.timer import print_each_rank @@ -9,3 +10,70 @@ def memory_summary(): print_each_rank( '{:.2f}GB memory consumption'.format(mem / 1024 / 1024 / 1024), ) + + +def model_summary(model: torch.nn.Module, inputs: List[Any], do_eval=False, max_depth=6): + """ + Benchmakr memory consumption for each module. + This could only be called before any other forward/backward + + New attributes will be assigned to each module: + + * _summary_depth (Int) + * _summary_begin_end (Boolean) + * _summary_memory_state (Int) + + Make sure all of these attributes are not used in modules. + """ + torch.cuda.empty_cache() + static_memory = torch.cuda.memory_allocated() + print_each_rank( + 'static model: {:,.2f} MB'.format(static_memory / 1024 / 1024), rank_only=0) + + stat = dict(depth=0) + def before_forward(module, input): + module._summary_depth = stat['depth'] + module._summary_begin_end = False + if len(list(module.children())) != 0: + if stat['depth'] + 1 < max_depth: + name = module.__class__.__name__ + module._summary_begin_end = True + prefix = ' ' * module._summary_depth + '[Begin] > ' + print_each_rank(prefix + '{}:'.format(name), rank_only=0) + if module._summary_depth < max_depth: + module._summary_memory_state = torch.cuda.memory_allocated() + stat['depth'] += 1 + + + def after_forward(module, input, output): + stat['depth'] -= 1 + if module._summary_depth >= max_depth: + return + name = module.__class__.__name__ + torch.cuda.empty_cache() + curr_memory = torch.cuda.memory_allocated() + mem_consumption = curr_memory - module._summary_memory_state + mem_consumption = mem_consumption / 1024 / 1024 + + n_params = sum([p.data.numel() for p in list(module.parameters())]) + + prefix = ' ' * module._summary_depth + prefix += '[End] > ' if module._summary_begin_end else '> ' + print_each_rank( + prefix + '{}: Mem {:,.2f} MB, Params: {:,} ({:,.2f} MB if fp32)'.format( + name, mem_consumption, n_params, n_params / 1024 / 1024 * 4), rank_only=0) + + handle_pre = torch.nn.modules.module.register_module_forward_pre_hook(before_forward) + handle_after = torch.nn.modules.module.register_module_forward_hook(after_forward) + + if do_eval: + model.eval() + else: + model.train() + _ = model(*inputs) + + handle_pre.remove() + handle_after.remove() + + if stat['depth'] != 0: + raise ValueError("Internal Error: depth {} not to 0".format(stat['depth'])) From a0ced110db174dc430b67269be377a5d9e38cd11 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 13 Mar 2022 16:32:28 +0800 Subject: [PATCH 0654/1892] bug fix on mbart --- examples/mbart/mbart.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/mbart/mbart.py b/examples/mbart/mbart.py index 1eccf1b7..51565458 100644 --- a/examples/mbart/mbart.py +++ b/examples/mbart/mbart.py @@ -15,7 +15,7 @@ import cube from cube.runtime.syndata import SynTextDataLoader from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary +from cube.profiler.memory import memory_summary, model_summary from cube.profiler.timer import print_each_rank @@ -27,7 +27,7 @@ class Config: - num_embeddings = 500027 + num_embeddings = 2500027 encoder_embed_path = None encoder_embed_dim = 1024 @@ -152,6 +152,14 @@ def __init__(self, embed_dim: int, num_heads: int, dropout=0.0, bias=True): else: self.out_bias = None + def forward(self, query: torch.Tensor, key: torch.Tensor): + return attn_fn(query, key, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) + def forward_encoder_decoder_attn(self, query: torch.Tensor, key: torch.Tensor): # tgt_len, bsz, embed_dim = query.size() # q = torch.nn.functional.linear(query, self.q_proj, self.q_bias) @@ -200,7 +208,7 @@ def __init__(self, cfg: Config): def forward(self, x): # , encoder_padding_mask: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None): residual = x x = self.self_attn_layer_norm(x) - x = self.self_attn.forward_self_attn(x) + x = self.self_attn(x, x) x = self.dropout(x) x = x + residual @@ -273,7 +281,7 @@ def forward(self, x, encoder_out): # encoder_padding_mask): x = self.self_attn_layer_norm(x) # self attention - x = self.self_attn.forward_self_attn(x) + x = self.self_attn(x, x) x = self.dropout(x) x = residual + x @@ -281,7 +289,7 @@ def forward(self, x, encoder_out): # encoder_padding_mask): residual = x # normalize before x = self.encoder_attn_layer_norm(x) - x = self.encoder_attn.forward_encoder_decoder_attn(x, encoder_out) + x = self.encoder_attn(x, encoder_out) x = self.dropout(x) x = x + residual @@ -384,15 +392,16 @@ def train(): ) def train_iter(model, dataloader): - model.eval() + # model.eval() src_tokens, prev_output_tokens = next(dataloader) + # model_summary(model, (src_tokens, prev_output_tokens)) loss = model(src_tokens, prev_output_tokens) loss.backward() optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) CudaTimer(enable=False).warmup() - iter_num = 128 + iter_num = 1 for step in range(iter_num): if step >= 40: CudaTimer(enable=True).start('e2e') @@ -404,8 +413,8 @@ def train_iter(model, dataloader): if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) + # print_each_rank('e2e time (ms) per iteration: {} ms'.format( + # CudaTimer().duration(iter_num-40, field_name='e2e'))) memory_summary() From 6e6aa13528fbb12927b9ec2032b9446369fdb0ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 14 Mar 2022 10:05:30 +0800 Subject: [PATCH 0655/1892] mbart embedding length --- examples/mbart/mbart.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/mbart/mbart.py b/examples/mbart/mbart.py index 51565458..a69e3380 100644 --- a/examples/mbart/mbart.py +++ b/examples/mbart/mbart.py @@ -27,7 +27,7 @@ class Config: - num_embeddings = 2500027 + num_embeddings = 250027 encoder_embed_path = None encoder_embed_dim = 1024 @@ -394,9 +394,9 @@ def train(): def train_iter(model, dataloader): # model.eval() src_tokens, prev_output_tokens = next(dataloader) - # model_summary(model, (src_tokens, prev_output_tokens)) - loss = model(src_tokens, prev_output_tokens) - loss.backward() + model_summary(model, (src_tokens, prev_output_tokens)) + # loss = model(src_tokens, prev_output_tokens) + # loss.backward() optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) @@ -406,6 +406,7 @@ def train_iter(model, dataloader): if step >= 40: CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) + break optimizer.step() optimizer.zero_grad() if step >= 40: From c43081ccde6f2e9bed0b4543abf34c2bdfd96750 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 14 Mar 2022 17:26:04 +0800 Subject: [PATCH 0656/1892] init with pipeline-naive scheduling --- handcraft/mbart/mbart.py | 667 ++++++++++++++++++++++++++++++++++++ handcraft/mbart/schedule.py | 575 +++++++++++++++++++++++++++++++ handcraft/mbart/tp.py | 88 +++++ 3 files changed, 1330 insertions(+) create mode 100644 handcraft/mbart/mbart.py create mode 100644 handcraft/mbart/schedule.py create mode 100644 handcraft/mbart/tp.py diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py new file mode 100644 index 00000000..bec0789d --- /dev/null +++ b/handcraft/mbart/mbart.py @@ -0,0 +1,667 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + handcraft/mbart/mbart.py --pp-size 4 --tp-size 1 --nmb 4 +""" + +from typing import Optional +import argparse +import math +import torch + +import cube +from cube.runtime.device import DeviceGroup +from cube.runtime.syndata import SynTextDataLoader +from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary, model_summary +from cube.profiler.timer import print_each_rank + +from handcraft.mbart.schedule import schedule_naive + +_tp_group = -1 +_pp_group = -1 +_pp_embed_group = -1 +_pp_next_rank = None +_pp_prev_rank = None + +# fairseq task +# translation_from_pretrained_bart + +# fairseq criterion +# label_smoothed_cross_entropy, --label_smoothing = 0.2 + +class Config: + + num_embeddings = 250027 + + encoder_embed_path = None + encoder_embed_dim = 1024 + encoder_ffn_embed_dim = 4 * 1024 + encoder_layers = 12 + encoder_attention_heads = 16 + encoder_normalize_before = True + encoder_learned_pos = True + + decoder_embed_path = None + decoder_embed_dim = 1024 + decoder_ffn_embed_dim = 4 * 1024 + decoder_layers = 12 + decoder_attention_heads = 16 + decoder_normalize_before = True + decoder_learned_pos = True + cross_self_attention = False + no_cross_attention = False + + attention_dropout = 0.0 + activation_dropout = 0.0 + dropout = 0.1 + + max_target_positions = 1024 + max_source_positions = 1024 + adaptive_softmax_cutoff = None + adaptive_softmax_dropout = 0 + + share_decoder_input_output_embed = True + share_all_embeddings = True + + decoder_output_dim = 1024 # same with decorder_embed_dim + decoder_input_dim = 1024 # same with decorder_embed_dim + + no_scale_embedding = False # True in bart large + layernorm_embedding = True + activation_fn = 'gelu' + pooler_activation_fn = 'tanh' + pooler_dropout = 0.0 + + +def attn_fn(query: torch.Tensor, key: torch.Tensor, + wq: torch.Tensor, wq_bias: Optional[torch.Tensor], + wk: torch.Tensor, wk_bias: Optional[torch.Tensor], + wv: torch.Tensor, wv_bias: Optional[torch.Tensor], + wout: torch.Tensor, wout_bias: Optional[torch.Tensor], + h: int, scale: float, dropout: float, mask=True): + """ + query, key: (L, N, E) = (seqlen, batch size, embed_dim) + wq, wk, wv weight: [(num_head * dim_head), E] + dropout: float + h: int: number of heads + """ + num_head = h + L, N = query.size(0), query.size(1) + dim_head = wq.size(0) // num_head + + q = torch.nn.functional.linear(query, wq, wq_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(key, wk, wk_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(key, wv, wv_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + if mask: # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, wout, wout_bias) # L N (h d), E E -> L N E + return output + + +class MultiheadAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, dropout=0.0, bias=True): + super().__init__() + self.kdim = embed_dim + self.vdim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # K + self.k_proj = torch.nn.Parameter(torch.empty(embed_dim, self.kdim)) + if bias: + self.k_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.k_bias = None + # V + self.v_proj = torch.nn.Parameter(torch.empty(embed_dim, self.vdim)) + if bias: + self.v_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.v_bias = None + # Q + self.q_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + if bias: + self.q_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.q_bias = None + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + if bias: + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.out_bias = None + + def forward(self, query: torch.Tensor, key: torch.Tensor): + return attn_fn(query, key, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) + + def forward_encoder_decoder_attn(self, query: torch.Tensor, key: torch.Tensor): + # tgt_len, bsz, embed_dim = query.size() + # q = torch.nn.functional.linear(query, self.q_proj, self.q_bias) + # k = torch.nn.functional.linear(key, self.k_proj, self.k_bias) + # v = torch.nn.functional.linear(key, self.v_proj, self.v_bias) + # q = q * self.scaling + # q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + # k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + # v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + # attn_weights = torch.bmm(q, k.transpose(1, 2)) + # # TODO: here needs a mask + # attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + # attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout_p) + # attn = torch.bmm(attn_probs, v) + # attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + # attn = torch.nn.functional.linear(attn, self.out_proj, self.out_bias) + return attn_fn(query, key, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) + + def forward_self_attn(self, query): + return attn_fn(query, query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) + + +class EncoderLayer(torch.nn.Module): + + def __init__(self, cfg: Config): + + super().__init__() + self.cfg = cfg + self.self_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.encoder_attention_heads, cfg.attention_dropout) + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.dropout = torch.nn.Dropout(p=cfg.dropout) + self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + self.fc1 = torch.nn.Linear(cfg.encoder_embed_dim, cfg.encoder_ffn_embed_dim) + self.fc2 = torch.nn.Linear(cfg.encoder_ffn_embed_dim, cfg.encoder_embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + def input_shape(self): + # L, N, E + return (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + + def output_shape(self): + # L N E + return (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + + def input_dtype(self): + return torch.float32 + + def output_dtype(self): + return torch.float32 + + def forward(self, x): # , encoder_padding_mask: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None): + # print(f'encoder layer: x: {x.size()}') + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x, x) + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + x = self.fc1(x) + x = torch.nn.functional.gelu(x) + x = self.activation_dropout(x) + x = self.fc2(x) + + x = self.dropout(x) + x = x + residual + return x + + +class Encoder(torch.nn.Module): + + def __init__(self, cfg: Config, embed_tokens: torch.nn.Embedding): + super().__init__() + self.dropout = torch.nn.Dropout(cfg.dropout) + self.max_source_positions = cfg.max_source_positions + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(cfg.encoder_embed_dim) + self.embed_positions = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) + self.layernorm_embedding = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.layers = torch.nn.ModuleList([]) + self.layers.extend( + [EncoderLayer(cfg) for _ in range(cfg.encoder_layers)] + ) + # normalize before + self.layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + def forward(self, src_tokens: torch.Tensor): + token_embedding = torch.nn.functional.embedding(src_tokens, self.embed_tokens.weight) # self.embed_tokens(src_tokens) + embed = self.embed_scale * token_embedding + + x = embed + self.embed_positions.weight # self.embed_positions(src_tokens) + x = self.layernorm_embedding(x) + x = self.dropout(x) + + x = x.transpose(0, 1) + for layer in self.layers: + x = layer(x) # encoder_padding_mask if has_pads else None) + x = self.layer_norm(x) + return x + + +class DecoderLayer(torch.nn.Module): + + def __init__(self, cfg: Config): + + super().__init__() + self.cfg = cfg + self.dropout = torch.nn.Dropout(p=cfg.dropout) + self.self_attn = MultiheadAttention(cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) + self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + # encoder atten + self.encoder_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) + self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + # self.encoder_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + self.fc1 = torch.nn.Linear(cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim) + self.fc2 = torch.nn.Linear(cfg.decoder_ffn_embed_dim, cfg.decoder_embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + def input_shape(self): + return ( + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + ) + + def output_shape(self): + return ( + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + ) + + def input_dtype(self): + return (torch.float32, torch.float32) + + def output_dtype(self): + return (torch.float32, torch.float32) + + def forward(self, x, encoder_out): # encoder_padding_mask): + # print(f'decoder layer: x: {x.size()}, encoder_out: {encoder_out.size()}') + residual = x + # normalize before + x = self.self_attn_layer_norm(x) + + # self attention + x = self.self_attn(x, x) + x = self.dropout(x) + x = residual + x + + # encoder attn + residual = x + # normalize before + x = self.encoder_attn_layer_norm(x) + x = self.encoder_attn(x, encoder_out) + x = self.dropout(x) + x = x + residual + + residual = x + # normalize before + x = self.final_layer_norm(x) + x = self.fc1(x) + x = torch.nn.functional.gelu(x) + x = self.activation_dropout(x) + x = self.fc2(x) + x = self.dropout(x) + x = x + residual + return x, encoder_out + + +class Decoder(torch.nn.Module): + + def __init__(self, cfg: Config, embed_tokens: torch.nn.Embedding): + super().__init__() + self.dropout = torch.nn.Dropout(cfg.dropout) + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(cfg.decoder_embed_dim) + self.embed_positions = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) + self.layernorm_embedding = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.layers = torch.nn.ModuleList([]) + self.layers.extend( + [DecoderLayer(cfg) for _ in range(cfg.decoder_layers)] + ) + self.layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + def forward(self, prev_output_tokens: torch.Tensor, enc: torch.Tensor): + positions = self.embed_positions.weight # self.embed_positions(prev_output_tokens) + embed = torch.nn.functional.embedding(prev_output_tokens, self.embed_tokens.weight) + x = self.embed_scale * embed + x = x + positions + x = self.layernorm_embedding(x) + x = self.dropout(x) + # B T C -> T B C + x = x.transpose(0, 1) + # decoder layers + for layer in self.layers: + x, enc = layer(x, enc) + x = self.layer_norm(x) + # T x B x C -> B x T x C + x = x.transpose(0, 1) + # B T C, N, C -> B T N + x = torch.nn.functional.linear(x, self.embed_tokens.weight) + return x + +# label_smoothed_cross_entropy +def criterion(output: torch.Tensor, prev_output_tokens: torch.Tensor, label_smoothing: float = 0.2): + target = prev_output_tokens[:, 1:] + # fairseq.criterions.label_smoothed_cross_entory + # model.get_normalized_probs + lprobs = torch.nn.functional.softmax(output, dim=-1) + # fairseq.criterions.label_smoothed_nll_loss + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = label_smoothing / (lprobs.size(-1) - 1) + loss = (1.0 - label_smoothing - eps_i) * nll_loss + eps_i * smooth_loss + return loss + + +class mBARTFull(torch.nn.Module): + + def __init__(self, cfg: Config, dataloader): + super().__init__() + self.cfg = cfg + self.dataloader = iter(dataloader) + + self.rank = DeviceGroup().rank + + global _pp_group + self.pp_group = _pp_group + self.total_layers = cfg.encoder_layers + cfg.decoder_layers + + self.pp_stage = torch.distributed.get_rank(_pp_group) + self.num_stages = torch.distributed.get_world_size(_pp_group) + + self.layer_start = self.total_layers // self.num_stages * self.pp_stage + self.layer_end = self.total_layers // self.num_stages * (self.pp_stage + 1) + + self.encoder_preprocess = (self.pp_stage == 0) + self.encoder_forward = (self.layer_start < cfg.encoder_layers) + self.decoder_preprocess = (self.pp_stage == self.num_stages // 2) + self.decoder_forward = (self.layer_start >= cfg.encoder_layers) + self.loss_compute = (self.pp_stage == self.num_stages - 1) + + self.encoder_layer_start = self.layer_start + self.encoder_layer_end = min(self.layer_end, cfg.encoder_layers) + + self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) + self.decoder_layer_end = self.layer_end + + self.emb = None + + # encoder preprocess + if self.encoder_preprocess: + print(f'[{self.rank}]: initializing preprocess encoder parameters') + self.emb = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) + self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) + self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) + self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + # encoders + if self.encoder_forward: + print(f'[{self.rank}]: initializing {self.encoder_layer_end - self.encoder_layer_start} encoder layers') + self.encoders = torch.nn.ModuleList([EncoderLayer(cfg) for _ in range(self.encoder_layer_end - self.encoder_layer_start)]) + + # decoder preprocess + if self.decoder_preprocess: + print(f'[{self.rank}]: initializing preprocess decoder parameters') + if self.emb is None: + self.emb = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) + self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) + self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) + self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + # decoders + if self.decoder_forward: + print(f'[{self.rank}]: initializing {self.decoder_layer_end - self.decoder_layer_start} decoder layers') + self.decoders = torch.nn.ModuleList([DecoderLayer(cfg) for _ in range(self.decoder_layer_end - self.decoder_layer_start)]) + + # compute loss + if self.loss_compute: + print(f'[{self.rank}]: will compute loss') + if self.emb is None: + self.emb = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + + def input_shape(self): + if self.encoder_preprocess: + # src_tokens, prev_output_tokens + return () + elif self.encoder_forward: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + ) + elif self.decoder_preprocess: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + ) + elif self.decoder_forward: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + ) + else: + assert False + + def output_shape(self): + shape = None + if self.encoder_preprocess or self.encoder_forward: + shape = (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + if self.decoder_preprocess or self.decoder_forward: + shape = ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + ) + if self.loss_compute: + shape = ((1,),) + assert shape is not None + return shape + + def input_dtype(self): + if self.encoder_preprocess: + return () + elif self.encoder_forward: + return (torch.float32,) + elif self.decoder_preprocess: + return (torch.float32,) + elif self.decoder_forward: + return (torch.float32, torch.float32) + else: + assert False + + def output_dtype(self): + dtype = None + if self.encoder_preprocess or self.encoder_forward: + dtype = (torch.float32,) + if self.decoder_preprocess or self.decoder_forward: + dtype = (torch.float32, torch.float32) + if self.loss_compute: + dtype = ((1,),) + assert dtype is not None + return dtype + + + def forward(self, enc=None, dec=None): + """ + x1: src_tokens or encoder output/input + x2: prev_output_tokens or decoder output/input + """ + src_tokens, prev_output_tokens = None, None + # encoder preprocess + if self.encoder_preprocess: + src_tokens, prev_output_tokens = next(self.dataloader) + token_embedding = torch.nn.functional.embedding(src_tokens, self.emb.weight) + embed = self.embed_scale_encoder * token_embedding + x = embed + self.embed_positions_encoder.weight + x = self.layernorm_embedding_encoder(x) + x = torch.nn.functional.dropout(x, p=0.0) + enc = x.transpose(0, 1) + output = (enc,) + + # forward encoder + if self.encoder_forward: + for layer in self.encoders: + enc = layer(enc) # encoder_padding_mask if has_pads else None) + output = (enc,) + + # decoder preprocess + if self.decoder_preprocess: + enc = self.layer_norm_encoder(enc) + if prev_output_tokens is None: + _, prev_output_tokens = next(self.dataloader) + embed = torch.nn.functional.embedding(prev_output_tokens, self.emb.weight) + embed = self.embed_scale_decoder * embed + embed = embed + self.embed_positions_decoder.weight + embed = self.layernorm_embedding_decoder(embed) + embed = torch.nn.functional.dropout(embed, p=0.0) + dec = embed.transpose(0, 1) + output = (enc, dec) + + # forward decoder + if self.decoder_forward: + for layer in self.decoders: + dec, enc = layer(dec, enc) + output = (enc, dec) + + # postprocess + if self.loss_compute: + if prev_output_tokens is None: + _, prev_output_tokens = next(self.dataloader) + dec = self.layer_norm_decoder(dec) + dec = dec.transpose(0, 1) + dec = torch.nn.functional.linear(dec, self.emb.weight) + loss = criterion(dec, prev_output_tokens) + output = (loss,) + + return output + + +def reduce_embed(model, pp_embed_group): + """ + Embedding gradients needs to be reduced across pipeline stages + """ + if isinstance(model.emb, torch.nn.Module): + grad = model.emb.weight.grad + else: + grad = None + if grad is not None: + torch.distributed.all_reduce(grad, group=pp_embed_group) + torch.cuda.synchronize() + + + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--tp-size', type=int, + help='tensor parallelism size') + parser.add_argument('--pp-size', type=int, + help='pipeline parallelism size') + parser.add_argument('--nmb', type=int, default=4, + help='num of micro batch') + args = parser.parse_args() + + print(args) + + cube.init() + pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) + print_each_rank(f'my pp ranks: {pp_ranks}') + print_each_rank(f'my tp_ranks: {tp_ranks}') + + if _pp_group == -1: + _pp_group = DeviceGroup().get_group(pp_ranks) + idx = pp_ranks.index(DeviceGroup().rank) + _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] + _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] + if _tp_group == -1: + _tp_group = DeviceGroup().get_group(tp_ranks) + + # create embed group: first encoder, first decoder, last stage + # FIXME: only work for tp_size = 1 + embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2], pp_ranks[-1]] + embed_ranks = list(set(embed_ranks)) + _pp_embed_group = DeviceGroup().get_group(embed_ranks) + + + cfg = Config() + dataloader = SynTextDataLoader( + shapes=( + [1, cfg.max_source_positions], + [1, cfg.max_target_positions] + ), + dtypes=(torch.int64, torch.int64), + batch_dims=(0,0,) + ) + model = mBARTFull(cfg, dataloader).cuda() + print_each_rank('model weight consumpition:') + memory_summary() + + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + CudaTimer(enable=False).warmup() + iter_num = 64 + for step in range(iter_num): + if step >= 20: + CudaTimer(enable=True).start('e2e') + schedule_naive(model, args.nmb, (_pp_prev_rank, _pp_next_rank)) + reduce_embed(model, _pp_embed_group) + optimizer.step() + optimizer.zero_grad() + if step >= 20: + CudaTimer().stop('e2e') + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-20, field_name='e2e'))) + memory_summary() diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py new file mode 100644 index 00000000..bda0b989 --- /dev/null +++ b/handcraft/mbart/schedule.py @@ -0,0 +1,575 @@ +from typing import List, Tuple +import torch + +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary +import cube.runtime.adapter.collectives as coll +from cube.runtime.device import DeviceGroup + +io_input = input + +def forward_step(model, *args, **kwargs): + """ + Forward pass + """ + CudaTimer().start("forward") + output = model(*args, **kwargs) + CudaTimer().stop("forward") + return output + + +def backward_step(input_tensors: List[torch.Tensor], + output_tensors: List[torch.Tensor], + output_tensor_grads: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Backward pass + """ + for tensor in input_tensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + tensor.retain_grad() + CudaTimer().start("backward") + torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) + CudaTimer().stop("backward") + input_tensor_grads = [] + for tensor in input_tensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + input_tensor_grads.append(tensor.grad) + else: + input_tensor_grads.append(None) + return input_tensor_grads + + +def recv_forward(model, prev_rank: int) -> List[torch.Tensor]: + CudaTimer().start(field_name='comm') + shapes = model.input_shape() + dtypes = model.input_dtype() + if len(shapes) == 0: + return () + # print(f'rank {DeviceGroup().rank} recving forward: {shapes}') + tensors = [ + torch.empty( + shape, requires_grad=True, dtype=dtype, + device=torch.cuda.current_device() + ) for shape, dtype in zip(shapes, dtypes) + ] + recv_ops = [ + torch.distributed.P2POp( + torch.distributed.irecv, tensor, prev_rank + ) for tensor in tensors + ] + reqs = torch.distributed.batch_isend_irecv(recv_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return tensors + + +def recv_backward(model, next_rank: int) -> List[torch.Tensor]: + CudaTimer().start(field_name='comm') + shapes = model.output_shape() + dtypes = model.output_dtype() + if len(shapes) == 0: + return () + # print(f'rank {DeviceGroup().rank} recving backward: {shapes}') + tensors = [ + torch.empty( + shape, requires_grad=False, dtype=dtype, + device=torch.cuda.current_device() + ) for shape, dtype in zip(shapes, dtypes) + ] + recv_ops = [ + torch.distributed.P2POp( + torch.distributed.irecv, tensor, next_rank + ) for tensor in tensors + ] + reqs = torch.distributed.batch_isend_irecv(recv_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return tensors + + +def send_forward(outputs: List[torch.Tensor], next_rank: int): + if len(outputs) == 0: + return + CudaTimer().start(field_name='comm') + # print(f'rank {DeviceGroup().rank} sending forward: {[tuple(t.size()) for t in outputs]}') + send_ops = [ + torch.distributed.P2POp( + torch.distributed.isend, tensor, next_rank + ) for tensor in outputs + ] + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + + +def send_backward(grads: List[torch.Tensor], prev_rank: int): + if len(grads) == 0: + return + CudaTimer().start(field_name='comm') + # print(f'rank {DeviceGroup().rank} sending backward: {[tuple(t.size()) for t in grads]}') + send_ops = [ + torch.distributed.P2POp( + torch.distributed.isend, tensor, prev_rank + ) for tensor in grads + ] + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + + +def send_forward_recv_backward(outputs, model, next_rank: int) -> List[torch.Tensor]: + CudaTimer().start(field_name='comm') + shapes = model.output_shape() + dtypes = model.output_dtype() + # print(f'rank {DeviceGroup().rank} sending forward: {[tuple(t.size()) for t in outputs]} recving backward {shapes}') + ops = list() + # send forward outputs + send_ops = [ + torch.distributed.P2POp( + torch.distributed.isend, tensor, next_rank + ) for tensor in outputs + ] + ops += send_ops + # recv backward inputs + tensors = [ + torch.empty( + shape, requires_grad=True, dtype=dtype, + device=torch.cuda.current_device() + ) for shape, dtype in zip(shapes, dtypes) + ] + recv_ops = [ + torch.distributed.P2POp( + torch.distributed.irecv, tensor, next_rank + ) for tensor in tensors + ] + ops += recv_ops + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return tensors + + +def send_backward_recv_forward(grads, model, prev_rank: int) -> List[torch.Tensor]: + CudaTimer().start(field_name='comm') + shapes = model.input_shape() + dtypes = model.input_dtype() + # print(f'rank {DeviceGroup().rank} sending backward: {[tuple(t.size()) for t in grads]} recving forward {shapes}') + ops = list() + # send backward gradients + send_ops = [ + torch.distributed.P2POp( + torch.distributed.isend, tensor, prev_rank + ) for tensor in grads + ] + ops += send_ops + # recv forward inputs + tensors = [ + torch.empty( + shape, requires_grad=True, dtype=dtype, + device=torch.cuda.current_device() + ) for shape, dtype in zip(shapes, dtypes) + ] + recv_ops = [ + torch.distributed.P2POp( + torch.distributed.irecv, tensor, prev_rank + ) for tensor in tensors + ] + ops += recv_ops + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return tensors + + + +def schedule_naive(model, num_microbatch: int, neighbors: Tuple[int, int]): + """ + neighbors: (prev_rank: int, next_rank: int) + """ + rank = DeviceGroup().rank + prev_rank, next_rank = neighbors + + is_first_stage = rank < prev_rank + is_last_stage = rank > next_rank + + for step in range(num_microbatch): + # print(f'rank {rank} recving forward input...') + inputs = () if is_first_stage else recv_forward(model, prev_rank) + # forward + outputs = forward_step(model, *inputs) + # send forward + if not is_last_stage: + # print(f'rank {rank} sending forward output...') + send_forward(outputs, next_rank) + # recv backward + # print(f'rank {rank} recving backward input...') + output_grads = (None,) if is_last_stage else recv_backward(model, next_rank) + # backward + input_grads = backward_step(inputs, outputs, output_grads) + # send backward + if not is_first_stage: + # print(f'rank {rank} sending backward output...') + send_backward(input_grads, prev_rank) + + # memory_summary() + # if rank == 0: + # io_input(f'{step}>>>') + # torch.distributed.barrier() + + +def schedule_tp_1f1b(model: torch.nn.Module, + first_stage_model: torch.nn.Module, + dataloader, + num_microbatch: int, + num_stage: int): + rank = DeviceGroup().rank + next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size + prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size + + input_tensors = list() + output_tensors = list() + + input_1st_tensors = list() + output_1st_tensors = list() + + gather_list = list(range(num_stage)) + gather_list[0], gather_list[1] = gather_list[1], gather_list[0] + + def tp_forward(fmodel, dataloader) -> torch.Tensor: + input = next(dataloader) + output = forward_step(fmodel, input) + input_1st_tensors.append(input) + output_1st_tensors.append(output) + # gather + outputs = coll.gather([output], None, None, gather_list) + if rank == 1: + with torch.no_grad(): + outputs[0], outputs[1] = outputs[1], outputs[0] + output = torch.cat(tuple(outputs), dim=-1) + output = output.requires_grad_() + else: + output = None + return output + + def tp_backward(grad: torch.Tensor): + if rank == 1: + with torch.no_grad(): + grads = list(grad.chunk(num_stage, dim=-1)) + grads[0], grads[1] = grads[1], grads[0] + else: + grads = None + input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) + grad_1st = coll.scatter(grads, [output_1st.size()], [output_1st.dtype], gather_list) + backward_step([input_1st], [output_1st], [grad_1st])[0] + + fofst = [0] + [-(step // 2) for step in range(num_stage-1)] + bofst = [0] + [-(num_stage - 2 - (step // 2)) for step in range(num_stage-1)] + # print(fofst) + # print(bofst) + fofst = fofst[rank] + bofst = bofst[rank] + last_backward = None + last_forward = None + for step in range(num_microbatch + num_stage - 2): + torch.distributed.barrier() + # print_each_rank(f'=========begin rank {rank}=========') + fmid, bmid = step + fofst, step + bofst + do_backward = 0 <= bmid and bmid <= num_microbatch - 1 + do_forward = 0 <= fmid and fmid <= num_microbatch - 1 + + # step1: tp forward + if 0 <= step and step <= num_microbatch - 1: + # print(f'rank {rank} forward tp model ') + output_1st = tp_forward(first_stage_model, dataloader) + + # step2: backward + forward + if rank == 0: + pass + + if rank % 2 == 0 and rank != 0: + # inter-barrier + if do_backward and last_forward is not None: + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [last_forward], [model.output_shape()], [model.output_dtype()], + [next_rank], [next_rank] + )[0] + elif do_backward: + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + elif last_forward is not None: + # print(f'rank {rank} send forward output ') + coll.send(last_forward, next_rank) + + # backward + if do_backward: + input, output = input_tensors.pop(0), output_tensors.pop(0) + # backward + input_grad = backward_step([input], [output], [output_grad])[0] + + # intra-barrier + if do_backward and do_forward: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_backward: + # print(f'rank {rank} send backward grad ') + coll.send(input_grad, prev_rank) + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + + # forward + last_forward = None + if do_forward: + # forward step + output = forward_step(model, input) + input_tensors.append(input) + output_tensors.append(output) + last_forward = output + + if rank % 2 == 1: + # inter-barrier + if rank == 1: + input = output_1st + else: + if do_forward and last_backward is not None: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + elif last_backward is not None: + # print(f'rank {rank} send backward grad ') + coll.send(last_backward, prev_rank) + + # forward + if do_forward: + output = forward_step(model, input) + input_tensors.append(input) + output_tensors.append(output) + + # intra-barrier send recv + output_grad = None + if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): + # send forward recv backward + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [output], [output.size()], [output.dtype], + [next_rank], [next_rank] + )[0] + elif do_forward and not is_last_stage(): + # print(f'rank {rank} send forward output ') + coll.send(output, next_rank) + elif do_backward and not is_last_stage(): + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + + # backward + forward + last_backward = None + if do_backward: + input, output = input_tensors.pop(0), output_tensors.pop(0) + input_grad = backward_step([input], [output], [output_grad])[0] + last_backward = input_grad + + # step3: tp backward + if 0 <= (step-num_stage+2) and (step-num_stage+2) <= num_microbatch - 1: + # print(f'rank {rank} backward tp model ') + tp_backward(last_backward) + + # if rank == 0: + # io_input(f'{step}>>>') + # torch.distributed.barrier() + + assert len(input_tensors) == 0 + assert len(output_tensors) == 0 + assert len(input_1st_tensors) == 0 + assert len(output_1st_tensors) == 0 + + # print_each_rank(f'=========end rank {rank}=========') + + +def schedule_tp_1f1b_pack(model: torch.nn.Module, + first_stage_model: torch.nn.Module, + dataloader, + num_microbatch: int, + num_stage: int): + rank = DeviceGroup().rank + next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size + prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size + + input_tensors = list() + output_tensors = list() + + input_1st_tensors = list() + output_1st_tensors = list() + + def tp_forward(fmodel, dataloader) -> torch.Tensor: + input = next(dataloader) + #TODO: gather + output = forward_step(fmodel, input) + input_1st_tensors.append(input) + output_1st_tensors.append(output) + output = output.detach().requires_grad_() + return output + + def tp_backward(grad: torch.Tensor): + input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) + if rank != 0: + grad = torch.empty_like(output_1st) + torch.distributed.broadcast(grad, src=0) + backward_step([input_1st], [output_1st], [grad])[0] + + fofst = [-(step // 2) for step in range(num_stage)] + bofst = [-(num_stage - 1 - (step // 2)) for step in range(num_stage)] + # print(fofst) + # print(bofst) + fofst = fofst[rank] + bofst = bofst[rank] + last_backward = None + last_forward = None + for step in range(num_microbatch + num_stage - 1): + torch.distributed.barrier() + # print_each_rank(f'=========begin rank {rank}=========') + fmid, bmid = step + fofst, step + bofst + do_backward = 0 <= bmid and bmid <= num_microbatch - 1 + do_forward = 0 <= fmid and fmid <= num_microbatch - 1 + + # step1: tp forward + if 0 <= step and step <= num_microbatch - 1: + # print(f'rank {rank} forward tp model ') + output_1st = tp_forward(first_stage_model, dataloader) + + # forward + backward + if rank % 2 == 0: + # inter-barrier + if rank == 0: + input = output_1st + else: + if do_forward and last_backward is not None: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + elif last_backward is not None: + # print(f'rank {rank} send backward grad ') + coll.send(last_backward, prev_rank) + + # forward + if do_forward: + output = forward_step(model, input) + input_tensors.append(input) + output_tensors.append(output) + + # mem = torch.cuda.max_memory_allocated() + # print(f'rank {rank}: {mem / 1024 / 1024 / 1024} GB forward') + + # intra-barrier send recv + output_grad = None + if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): + # send forward recv backward + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [output], [output.size()], [output.dtype], + [next_rank], [next_rank] + )[0] + elif do_forward and not is_last_stage(): + # print(f'rank {rank} send forward output ') + coll.send(output, next_rank) + elif do_backward and not is_last_stage(): + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + + # backward + last_backward = None + if do_backward: + input, output = input_tensors.pop(0), output_tensors.pop(0) + input_grad = backward_step([input], [output], [output_grad])[0] + last_backward = input_grad + + # backward + forward + if rank % 2 == 1: + # inter-barrier + if is_last_stage(): + output_grad = None + else: + if do_backward and last_forward is not None: + # print(f'rank {rank} recv backward grad + send forward output ') + output_grad = coll.sendrecv( + [last_forward], [model.output_shape()], [model.output_dtype()], + [next_rank], [next_rank] + )[0] + elif do_backward: + # print(f'rank {rank} recv backward grad ') + output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + elif last_forward is not None: + # print(f'rank {rank} send forward output ') + coll.send(last_forward, next_rank) + + # backward + if do_backward: + input, output = input_tensors.pop(0), output_tensors.pop(0) + # backward + input_grad = backward_step([input], [output], [output_grad])[0] + + # intra-barrier + if do_backward and do_forward: + # print(f'rank {rank} send backward grad + recv forward output ') + input = coll.sendrecv( + [input_grad], [model.input_shape()], [model.input_dtype()], + [prev_rank], [prev_rank] + )[0] + elif do_backward: + # print(f'rank {rank} send backward grad ') + coll.send(input_grad, prev_rank) + elif do_forward: + # print(f'rank {rank} recv forward output ') + input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + + # forward + last_forward = None + if do_forward: + # forward step + output = forward_step(model, input) + input_tensors.append(input) + output_tensors.append(output) + last_forward = output + + # step3: tp backward + if 0 <= (step-num_stage+1) and (step-num_stage+1) <= num_microbatch - 1: + # print(f'rank {rank} backward tp model ') + tp_backward(last_backward) + + # memory_summary() + # if rank == 0: + # io_input(f'{step}>>>') + # torch.distributed.barrier() + # print_each_rank(f'=========end rank {rank}: {step}=========') + + assert len(input_tensors) == 0 + assert len(output_tensors) == 0 + assert len(input_1st_tensors) == 0 + assert len(output_1st_tensors) == 0 + + # print_each_rank(f'=========end rank {rank}=========') \ No newline at end of file diff --git a/handcraft/mbart/tp.py b/handcraft/mbart/tp.py new file mode 100644 index 00000000..b37c71a6 --- /dev/null +++ b/handcraft/mbart/tp.py @@ -0,0 +1,88 @@ +from typing import Tuple +import torch + + +class Reduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, group): + torch.distributed.all_reduce(input, group=group) + return input + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class IdentityFoward(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, group): + ctx._group = group + return input + + @staticmethod + def backward(ctx, grad_output): + torch.distributed.all_reduce(grad_output, group=ctx._group) + return grad_output, None + + +def shard_linear_col(input, weight, bias, group): + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return torch.nn.functional(input, weight, bias) + input = IdentityFoward.apply(input, group) + return torch.nn.functional(input, weight, bias) + + +def shard_linear_row(input, weight, bias, group): + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return torch.nn.functional(input, weight, bias) + out = torch.nn.functional(input, weight, bias) + out = Reduce.apply(out, group) + return out + + +class DummyModelEmbed(torch.nn.Module): + + def __init__(self, num_embeddings: int, embedding_dim: int, + input_shape: Tuple[int, int], group = None): + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.input_shape = input_shape + + self.tp_group = group + self.tp_size = torch.distributed.get_world_size(group) + + shard_id = torch.distributed.get_rank(group) + self.vocab_start_index = num_embeddings // self.tp_size * shard_id + self.vocab_end_index = num_embeddings // self.tp_size * (shard_id + 1) + self.embed_weight = torch.nn.Parameter(torch.ones((num_embeddings // self.tp_size, embedding_dim))) + + def input_shape(self): + return self.input_shape + + def input_dtype(self): + return torch.int64 + + def output_shape(self): + return self.input_shape + (self.embedding_dim,) + + def output_dtype(self): + return torch.float32 + + def forward(self, input: torch.Tensor): + if self.tp_size > 1: + mask = (input < self.vocab_start_index) | \ + (input >= self.vocab_end_index) + input = input.clone() - self.vocab_start_index + input[mask] = 0 + input = torch.nn.functional.embedding(input, self.embed_weight) + input[mask, :] = 0.0 + input = Reduce.apply(input, self.tp_group) + else: + input = torch.nn.functional.embedding(input, self.embed_weight) + return input From 0863a634ee903e57d8af104c959809a6bf0084c3 Mon Sep 17 00:00:00 2001 From: lynex Date: Tue, 15 Mar 2022 14:58:14 +0800 Subject: [PATCH 0657/1892] 1) enable single process dev/debug mode with env SINGLE_DEV_MODE=True 2) Can finish gen code for example/atmosphere/weather.py if edit parser.py: # if isinstance(tensor, torch.nn.Parameter): # ir_tensor.as_param() ir_tensor.as_param() #TODO remove me handle non grad in gen.py fix einops.Rearrange code gen with pickle + cube.runtime.function.einops fix PyTorch circular pad with workaround fix skipping theta_ shape bug --- cube/graph/adapter/gen.py | 3 ++ cube/graph/operator/function/conv.py | 2 +- cube/graph/operator/function/function.py | 4 ++- cube/graph/operator/function/scripteinops.py | 13 ++++---- cube/profiler/timer.py | 5 ++++ cube/runtime/device.py | 31 +++++++++++++------- cube/runtime/function/function.py | 7 +++++ cube/runtime/resource.py | 7 ++++- examples/atmosphere/weather.py | 9 ++++-- 9 files changed, 60 insertions(+), 21 deletions(-) diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index 5e9a29cd..0a9b1112 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -83,6 +83,9 @@ def gen_weight_reducer(graph: IRGraph) -> IRGraph: for input in fnode.inputs(): if isinstance(input, IRSubTensor) and input.is_param(): grad = input.grad + if grad is None: #TODO remove me, for weather.py test + print(f'WARNING: skipping non grad of {fnode}') + continue # nothing to sync if grad.valmap == ValueMap(0, 1): continue diff --git a/cube/graph/operator/function/conv.py b/cube/graph/operator/function/conv.py index 92213396..9bb2b636 100644 --- a/cube/graph/operator/function/conv.py +++ b/cube/graph/operator/function/conv.py @@ -58,7 +58,7 @@ class IRConv3D(IRFwOperation): def __init__(self, signature: str, inputs: List[IRTensor], name: str, **kwargs): - signature = 'cube.runtime.function.conv3d' + #TODO signature = 'cube.runtime.function.conv3d' assert len(inputs) == 3, "Expected only input, weight, bias as inputs" assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" super().__init__(name, signature, 3, 1) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index e4ee991c..c09a958b 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -455,4 +455,6 @@ def ScriptEinOps(signature, inputs): recipe = inputs[0] tensors = inputs[1:2] reduction_type = inputs[2] - return IRScriptEinOps(signature, tensors, 'scripteinops', recipe=recipe, reduction_type=reduction_type) \ No newline at end of file + import pickle + recipe_str = pickle.dumps(recipe) + return IRScriptEinOps(signature, tensors, 'scripteinops', recipe_str=recipe_str, reduction_type=reduction_type) \ No newline at end of file diff --git a/cube/graph/operator/function/scripteinops.py b/cube/graph/operator/function/scripteinops.py index 14d37672..b8f89e0a 100644 --- a/cube/graph/operator/function/scripteinops.py +++ b/cube/graph/operator/function/scripteinops.py @@ -12,9 +12,9 @@ class IRScriptEinOps(IRFwOperation): def __init__(self, signature: str, inputs: List[IRTensor], name: str, **kwargs): - signature = 'einops._torch_specific.apply_for_scriptable_torch' #'cube.runtime.function.conv2d' + signature = 'cube.runtime.function.einops' assert len(inputs) == 1, "Expected only input" - assert len(kwargs) == 2, "Expected 2 kwargs: recipe, reduction_type" + assert len(kwargs) == 2, "Expected 2 kwargs: recipe_str, reduction_type" super().__init__(name, signature, 1, 1) for idx, input in enumerate(inputs): self.set_input(idx, input) @@ -27,7 +27,10 @@ def infer_shape(self) -> bool: if len(self.inputs(0).shape) == 0: return False - recipe = self.kwargs['recipe'] + recipe_str = self.kwargs['recipe_str'] + import pickle + recipe = pickle.loads(recipe_str) + reduction_type = self.kwargs['reduction_type'] tmp_tensor = torch.zeros(self.inputs(0).shape) tmp_output = _apply_recipe(recipe, tmp_tensor, reduction_type) @@ -39,10 +42,10 @@ def new(self, inputs: List, outputs: List): construct a new operator sharing same kwargs with new inputs and outputs """ - recipe = self.kwargs['recipe'] + recipe_str = self.kwargs['recipe_str'] reduction_type = self.kwargs['reduction_type'] op = IRScriptEinOps(self.signature, inputs, self.name, - recipe=recipe, reduction_type=reduction_type) + recipe_str=recipe_str, reduction_type=reduction_type) assert len(outputs) == 1 op.set_output(0, outputs[0]) op.infer_shape() diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index f5e12e13..cd90da65 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -5,6 +5,11 @@ def print_each_rank(msg, rank_only=None, outfile=''): + import os + single_device_mode = os.environ.get('SINGLE_DEV_MODE') + if single_device_mode: + return + myrank = torch.distributed.get_rank() outfile = sys.stdout if outfile == '' else outfile for rank in range(torch.distributed.get_world_size()): diff --git a/cube/runtime/device.py b/cube/runtime/device.py index 0e667ed8..6b6cf157 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -12,16 +12,27 @@ class DeviceGroup: class __DeviceGroup: def __init__(self): - torch.distributed.init_process_group( - backend='nccl', - ) - self.rank = torch.distributed.get_rank() - self.world_size = torch.distributed.get_world_size() - # assume each node has the same device number - self.local_rank = int(os.environ.get('LOCAL_RANK')) - self.node_id = self.rank // torch.cuda.device_count() - self.groups = dict() - torch.cuda.set_device(self.local_rank) + single_device_mode = os.environ.get('SINGLE_DEV_MODE') + print(f'single_device_mode = {single_device_mode}') + if single_device_mode: + print(f"DeviceGroup init using single device mode...") + self.rank = 0 + self.world_size = 1 + self.local_rank = 0 + self.node_id = 0 + self.groups = dict() + torch.cuda.set_device(0) + else: + torch.distributed.init_process_group( + backend='nccl', + ) + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + # assume each node has the same device number + self.local_rank = int(os.environ.get('LOCAL_RANK')) + self.node_id = self.rank // torch.cuda.device_count() + self.groups = dict() + torch.cuda.set_device(self.local_rank) instance = None diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index f5aa2315..4a01c9ad 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -41,3 +41,10 @@ def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): ) output[input_mask, :] = 0.0 return output + +def einops(input: torch.Tensor, recipe_str, reduction_type: str): + import pickle + recipe = pickle.loads(recipe_str) + from einops.einops import _apply_recipe + output = _apply_recipe(recipe, input, reduction_type) + return output \ No newline at end of file diff --git a/cube/runtime/resource.py b/cube/runtime/resource.py index 63a5275c..06dc1c2f 100644 --- a/cube/runtime/resource.py +++ b/cube/runtime/resource.py @@ -3,6 +3,7 @@ """ import torch +import os class EnvResource: @@ -11,7 +12,11 @@ class __EnvResource: def __init__(self): # number of gpus - self.ngpus = torch.distributed.get_world_size() + single_device_mode = os.environ.get('SINGLE_DEV_MODE') + if single_device_mode: + self.ngpus = 1 + else: + self.ngpus = torch.distributed.get_world_size() # device topology self.topo = None diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index 946d4cdf..981bdd38 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -68,6 +68,9 @@ def __init__(self, self.pre_conv3d_reshape = Rearrange('(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # x = X.view(1, 1, Nz, Ny, Nx) self.post_conv3d_reshape = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') + self.pre_pady_reshape = Rearrange('(b0 Nz) Ny Nx -> b0 Nz Ny Nx', b0=1) + self.post_pady_reshape = Rearrange('b0 Nz Ny Nx -> (b0 Nz) Ny Nx') + def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # flux @@ -97,7 +100,7 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # theta_ = self.pad_z( # (self.bar_z(self.P * theta) - self.delta_z(theta) * self.P_[1:-1]) / self.delta_z(self.P) # ) # (nz + 1, ny, nx) - theta_ = theta0 #TODO remove me + theta_ = self.pad_z(self.bar_z(theta0)) #theta0 #TODO remove me theta1 = pi0 / pi1 * theta0 + dt / self.deltaA / pi1 * ( (self.delta_x(F * self.bar_x(self.pad_x(theta))) + self.delta_y(G * self.bar_y(self.pad_y(theta)))) / 2. + pi * self.deltaA * self.delta_z(self.w * theta_) / self.dz + @@ -194,7 +197,7 @@ def delta_x(self, X): def pad_y(self, X): #TODO check return F.pad(X.view(1, nz, ny, nx), (0, 0, 1, 1), "circular").view(nz, ny + 2, nx) - return F.pad(X, (0, 0, 1, 1), "circular") + return self.post_pady_reshape(F.pad(self.pre_pady_reshape(X), (0, 0, 1, 1), "circular")) def bar_y(self, X): @@ -451,4 +454,4 @@ def train_iter(model, dataloader): # plt.savefig(f'res2/res{i}.jpeg', dpi=300) # plt.clf() - # print(i) + print(f'pi = {pi}; theta = {theta}; u = {u}; v = {v}') From d97f62479f2e60737ea18879c5edd3957d42eed2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 15 Mar 2022 20:36:49 +0800 Subject: [PATCH 0658/1892] add tp1f1b-pack --- handcraft/mbart/mbart.py | 314 +++++++++++++++++++-------- handcraft/mbart/schedule.py | 417 +++++++++++++----------------------- handcraft/mbart/tp.py | 110 +++++----- 3 files changed, 433 insertions(+), 408 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index bec0789d..10c846cf 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -19,7 +19,8 @@ from cube.profiler.memory import memory_summary, model_summary from cube.profiler.timer import print_each_rank -from handcraft.mbart.schedule import schedule_naive +from handcraft.mbart.schedule import schedule_naive, schedule_tp_1f1b_pack +from handcraft.mbart.tp import AllGatherScatter, ParallelEmbed, BroadcastReduce, ReduceBroadcast _tp_group = -1 _pp_group = -1 @@ -403,13 +404,116 @@ def criterion(output: torch.Tensor, prev_output_tokens: torch.Tensor, label_smoo return loss -class mBARTFull(torch.nn.Module): +class ShardHeadTail(torch.nn.Module): - def __init__(self, cfg: Config, dataloader): + def __init__(self, cfg: Config, group=-1): + """ + group = -1 means no tensor parallelism + """ super().__init__() self.cfg = cfg - self.dataloader = iter(dataloader) + self.group = group + self.shard_num = torch.distributed.get_world_size(group) if group != -1 else 1 + self.shard_idx = torch.distributed.get_rank(group) if group != -1 else 0 + if self.shard_num > 0: + print(f'[{torch.distributed.get_rank()}]: initialize sharding embed (x{self.shard_num})') + self.vocab_start_index = self.cfg.num_embeddings // self.shard_num * self.shard_idx + self.vocab_end_index = self.cfg.num_embeddings // self.shard_num * (self.shard_idx + 1) + self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.shard_num, self.cfg.encoder_embed_dim))) + + # encoder-preprocess + self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) + self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) + self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + # decoder-preprocess + self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) + self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) + self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + # post-proces + + self._inputs = (None, None) + + def set_inputs(self, *inputs): + self._inputs = inputs + + def criterion_input_shape(self): + return ( + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + ) + + def embed_lookup(self, tokens, dst: Optional[int] = None): + if self.shard_num > 1: + mask = (tokens < self.vocab_start_index) | \ + (tokens >= self.vocab_end_index) + tokens = tokens.clone() - self.vocab_start_index + tokens[mask] = 0 + embed = torch.nn.functional.embedding(tokens, self.weight) + embed[mask, :] = 0.0 + embed = ReduceBroadcast.apply(embed, dst, self.group) + else: + embed = torch.nn.functional.embedding(tokens, self.weight) + return embed + + def encoder_preprocess(self, dst: Optional[int] = None): + source_tokens, _ = self._inputs + source_embed = self.embed_lookup(source_tokens, dst) + embed = self.embed_scale_encoder * source_embed + x = embed + self.embed_positions_encoder.weight + x = self.layernorm_embedding_encoder(x) + x = torch.nn.functional.dropout(x, p=0.0) + enc = x.transpose(0, 1) + return (enc,) + + def decoder_preprocess(self, dst: Optional[int] = None): + _, prev_output_tokens = self._inputs + target_emb = self.embed_lookup(prev_output_tokens, dst) + embed = self.embed_scale_decoder * target_emb + embed = embed + self.embed_positions_decoder.weight + embed = self.layernorm_embedding_decoder(embed) + embed = torch.nn.functional.dropout(embed, p=0.0) + dec = embed.transpose(0, 1) + return (dec,) + + def postprocess(self, output, src: Optional[int] = None): + _, prev_output_tokens = self._inputs + if self.group == -1: + output = self.layer_norm_decoder(output) + output = output.transpose(0, 1) + output = torch.nn.functional.linear(output, self.weight) + loss = criterion(output, prev_output_tokens) + return (loss,) + else: + assert src is not None + if self.shard_idx != src: + output = torch.empty( + self.criterion_input_shape()[0], + dtype=torch.float32, + requires_grad=True, + device=torch.cuda.current_device() + ) + output = output.transpose(0, 1) + output = BroadcastReduce.apply(output, src, self.group) + # return (torch.sum(output.contiguous()),) + output = torch.nn.functional.linear(output, self.weight) + output = AllGatherScatter.apply(output, -1, self.group) + loss = criterion(output, prev_output_tokens) + return (loss,) + + + +class mBARTFull(torch.nn.Module): + + def __init__(self, cfg: Config, + encoder_preprocess=True, + decoder_preprocess=True, + post_process=True, shard=True): + super().__init__() + self.cfg = cfg + self._preprocess = [None, None] # enc, dec + self.rank = DeviceGroup().rank global _pp_group @@ -422,11 +526,15 @@ def __init__(self, cfg: Config, dataloader): self.layer_start = self.total_layers // self.num_stages * self.pp_stage self.layer_end = self.total_layers // self.num_stages * (self.pp_stage + 1) - self.encoder_preprocess = (self.pp_stage == 0) - self.encoder_forward = (self.layer_start < cfg.encoder_layers) - self.decoder_preprocess = (self.pp_stage == self.num_stages // 2) - self.decoder_forward = (self.layer_start >= cfg.encoder_layers) - self.loss_compute = (self.pp_stage == self.num_stages - 1) + self.encoder_preprocess = encoder_preprocess + self.encoder_forward = (self.layer_start < cfg.encoder_layers) + + self.decoder_preprocess = decoder_preprocess + self.decoder_first_stage = self.layer_start == cfg.encoder_layers + self.decoder_forward = (self.layer_start >= cfg.encoder_layers) + self.decoder_last_stage = (self.layer_end == cfg.encoder_layers + cfg.decoder_layers) + + self.postprocess = post_process self.encoder_layer_start = self.layer_start self.encoder_layer_end = min(self.layer_end, cfg.encoder_layers) @@ -434,47 +542,35 @@ def __init__(self, cfg: Config, dataloader): self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) self.decoder_layer_end = self.layer_end - self.emb = None - - # encoder preprocess - if self.encoder_preprocess: - print(f'[{self.rank}]: initializing preprocess encoder parameters') - self.emb = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) - self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) - self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) - self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + if encoder_preprocess or decoder_preprocess or post_process or shard: + self.headtail = ShardHeadTail(cfg, group = None if shard else -1) + else: + self.headtail = None # encoders if self.encoder_forward: print(f'[{self.rank}]: initializing {self.encoder_layer_end - self.encoder_layer_start} encoder layers') self.encoders = torch.nn.ModuleList([EncoderLayer(cfg) for _ in range(self.encoder_layer_end - self.encoder_layer_start)]) - - # decoder preprocess - if self.decoder_preprocess: - print(f'[{self.rank}]: initializing preprocess decoder parameters') - if self.emb is None: - self.emb = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) - self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) - self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) - self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) - self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) - + if self.encoder_layer_end == cfg.encoder_layers: + self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + else: + self.layer_norm_encoder = None + # decoders if self.decoder_forward: print(f'[{self.rank}]: initializing {self.decoder_layer_end - self.decoder_layer_start} decoder layers') self.decoders = torch.nn.ModuleList([DecoderLayer(cfg) for _ in range(self.decoder_layer_end - self.decoder_layer_start)]) + if self.decoder_layer_end == cfg.encoder_layers + cfg.decoder_layers: + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + else: + self.layer_norm_decoder = None - # compute loss - if self.loss_compute: + # postpross + if self.postprocess: print(f'[{self.rank}]: will compute loss') - if self.emb is None: - self.emb = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) - def input_shape(self): if self.encoder_preprocess: - # src_tokens, prev_output_tokens return () elif self.encoder_forward: return ( @@ -484,24 +580,38 @@ def input_shape(self): return ( (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), ) + elif self.decoder_first_stage: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + ) elif self.decoder_forward: return ( (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), ) + elif self.decoder_last_stage: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + ) else: - assert False + assert False, "post-process is not allowed to be a single stage" def output_shape(self): shape = None if self.encoder_preprocess or self.encoder_forward: shape = (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + # decoder preprocess is not allowed to be a single stage if self.decoder_preprocess or self.decoder_forward: shape = ( (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), ) - if self.loss_compute: + if self.decoder_last_stage: + shape = ( + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + ) + if self.postprocess: shape = ((1,),) assert shape is not None return shape @@ -523,64 +633,77 @@ def output_dtype(self): if self.encoder_preprocess or self.encoder_forward: dtype = (torch.float32,) if self.decoder_preprocess or self.decoder_forward: - dtype = (torch.float32, torch.float32) - if self.loss_compute: - dtype = ((1,),) + if self.pp_stage == self.num_stages - 1: + dtype = (torch.float32,) + else: + dtype = (torch.float32, torch.float32) + if self.postprocess: + dtype = ((torch.float32,),) assert dtype is not None return dtype + def set_inputs(self, *inputs): + assert len(inputs) == 2 + if self.headtail is not None: + self.headtail.set_inputs(*inputs) + + def set_preprocess(self, enc=None, dec=None): + if enc is not None: + self._preprocess[0] = enc + if dec is not None: + self._preprocess[1] = dec + + def forward_encoder_preprocess(self, dst=None): + return self.headtail.encoder_preprocess(dst) + + def forward_decoder_preprocess(self, dst=None): + return self.headtail.decoder_preprocess(dst) + + def forward_postprocess(self, dec, src=None): + return self.headtail.postprocess(dec, src) def forward(self, enc=None, dec=None): """ - x1: src_tokens or encoder output/input - x2: prev_output_tokens or decoder output/input + enc: encoder input/output + dec: decoder output/input """ - src_tokens, prev_output_tokens = None, None + pre_enc, pre_dec = self._preprocess + enc = pre_enc if enc is None else enc + dec = pre_dec if dec is None else dec + # encoder preprocess if self.encoder_preprocess: - src_tokens, prev_output_tokens = next(self.dataloader) - token_embedding = torch.nn.functional.embedding(src_tokens, self.emb.weight) - embed = self.embed_scale_encoder * token_embedding - x = embed + self.embed_positions_encoder.weight - x = self.layernorm_embedding_encoder(x) - x = torch.nn.functional.dropout(x, p=0.0) - enc = x.transpose(0, 1) - output = (enc,) + output = self.forward_encoder_preprocess(dst=None) + enc = output[0] # forward encoder if self.encoder_forward: for layer in self.encoders: enc = layer(enc) # encoder_padding_mask if has_pads else None) + if self.layer_norm_encoder is not None: + enc = self.layer_norm_encoder(enc) output = (enc,) # decoder preprocess if self.decoder_preprocess: - enc = self.layer_norm_encoder(enc) - if prev_output_tokens is None: - _, prev_output_tokens = next(self.dataloader) - embed = torch.nn.functional.embedding(prev_output_tokens, self.emb.weight) - embed = self.embed_scale_decoder * embed - embed = embed + self.embed_positions_decoder.weight - embed = self.layernorm_embedding_decoder(embed) - embed = torch.nn.functional.dropout(embed, p=0.0) - dec = embed.transpose(0, 1) - output = (enc, dec) + output = self.forward_decoder_preprocess(dst=None) + dec = output[0] # forward decoder if self.decoder_forward: + dec = pre_dec if dec is None else dec for layer in self.decoders: dec, enc = layer(dec, enc) - output = (enc, dec) + if self.layer_norm_decoder is not None: + dec = self.layer_norm_decoder(dec) + output = (dec,) + else: + output = (enc, dec) # postprocess - if self.loss_compute: - if prev_output_tokens is None: - _, prev_output_tokens = next(self.dataloader) - dec = self.layer_norm_decoder(dec) - dec = dec.transpose(0, 1) - dec = torch.nn.functional.linear(dec, self.emb.weight) - loss = criterion(dec, prev_output_tokens) - output = (loss,) + if self.postprocess: + output = self.forward_postprocess(dec) + loss = output[0] return output @@ -590,7 +713,7 @@ def reduce_embed(model, pp_embed_group): Embedding gradients needs to be reduced across pipeline stages """ if isinstance(model.emb, torch.nn.Module): - grad = model.emb.weight.grad + grad = model.emb.get_weight().grad else: grad = None if grad is not None: @@ -598,39 +721,39 @@ def reduce_embed(model, pp_embed_group): torch.cuda.synchronize() - - if __name__ == '__main__': parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--tp-size', type=int, - help='tensor parallelism size') - parser.add_argument('--pp-size', type=int, - help='pipeline parallelism size') parser.add_argument('--nmb', type=int, default=4, help='num of micro batch') + parser.add_argument('--use-naive', action='store_true', + help='use naive pipeline') + parser.add_argument('--use-tp1f1b-pack', action='store_true', + help='use tensor parallel 1f1b') args = parser.parse_args() print(args) cube.init() - pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) + pp_ranks = list(range(DeviceGroup().world_size)) + # pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) print_each_rank(f'my pp ranks: {pp_ranks}') - print_each_rank(f'my tp_ranks: {tp_ranks}') if _pp_group == -1: _pp_group = DeviceGroup().get_group(pp_ranks) idx = pp_ranks.index(DeviceGroup().rank) _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] - if _tp_group == -1: - _tp_group = DeviceGroup().get_group(tp_ranks) + is_first_stage = idx == 0 + is_first_decoder_stage = idx == len(pp_ranks) // 2 + is_last_stage = idx == len(pp_ranks) - 1 # create embed group: first encoder, first decoder, last stage # FIXME: only work for tp_size = 1 - embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2], pp_ranks[-1]] - embed_ranks = list(set(embed_ranks)) - _pp_embed_group = DeviceGroup().get_group(embed_ranks) + if args.use_naive: + embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2], pp_ranks[-1]] + embed_ranks = list(set(embed_ranks)) + _pp_embed_group = DeviceGroup().get_group(embed_ranks) cfg = Config() @@ -642,7 +765,14 @@ def reduce_embed(model, pp_embed_group): dtypes=(torch.int64, torch.int64), batch_dims=(0,0,) ) - model = mBARTFull(cfg, dataloader).cuda() + if args.use_naive: + encoder_preprocess = is_first_stage + decoder_preprocess = is_first_decoder_stage + postprocess = is_last_stage + model = mBARTFull(cfg, encoder_preprocess, decoder_preprocess, postprocess, shard=False).cuda() + else: + model = mBARTFull(cfg, False, False, False, shard=True).cuda() + print_each_rank('model weight consumpition:') memory_summary() @@ -653,8 +783,16 @@ def reduce_embed(model, pp_embed_group): for step in range(iter_num): if step >= 20: CudaTimer(enable=True).start('e2e') - schedule_naive(model, args.nmb, (_pp_prev_rank, _pp_next_rank)) - reduce_embed(model, _pp_embed_group) + if args.use_naive: + schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) + reduce_embed(model, _pp_embed_group) + if args.use_tp1f1b_pack: + schedule_tp_1f1b_pack( + model, iter(dataloader), + args.nmb, len(pp_ranks), (_pp_prev_rank, _pp_next_rank) + ) + if step == 0: + print('passed 1st iteration') optimizer.step() optimizer.zero_grad() if step >= 20: diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py index bda0b989..22597928 100644 --- a/handcraft/mbart/schedule.py +++ b/handcraft/mbart/schedule.py @@ -151,7 +151,7 @@ def send_forward_recv_backward(outputs, model, next_rank: int) -> List[torch.Ten ) for tensor in tensors ] ops += recv_ops - reqs = torch.distributed.batch_isend_irecv(send_ops) + reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: req.wait() torch.cuda.synchronize() @@ -185,7 +185,7 @@ def send_backward_recv_forward(grads, model, prev_rank: int) -> List[torch.Tenso ) for tensor in tensors ] ops += recv_ops - reqs = torch.distributed.batch_isend_irecv(send_ops) + reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: req.wait() torch.cuda.synchronize() @@ -194,7 +194,7 @@ def send_backward_recv_forward(grads, model, prev_rank: int) -> List[torch.Tenso -def schedule_naive(model, num_microbatch: int, neighbors: Tuple[int, int]): +def schedule_naive(model, dataloader, num_microbatch: int, neighbors: Tuple[int, int]): """ neighbors: (prev_rank: int, next_rank: int) """ @@ -205,6 +205,7 @@ def schedule_naive(model, num_microbatch: int, neighbors: Tuple[int, int]): is_last_stage = rank > next_rank for step in range(num_microbatch): + model.set_inputs(*next(dataloader)) # print(f'rank {rank} recving forward input...') inputs = () if is_first_stage else recv_forward(model, prev_rank) # forward @@ -229,213 +230,69 @@ def schedule_naive(model, num_microbatch: int, neighbors: Tuple[int, int]): # torch.distributed.barrier() -def schedule_tp_1f1b(model: torch.nn.Module, - first_stage_model: torch.nn.Module, - dataloader, - num_microbatch: int, - num_stage: int): +def schedule_tp_1f1b_pack(model: torch.nn.Module, + dataloader, + num_microbatch: int, + num_stage: int, + neighbors: Tuple[int, int]): rank = DeviceGroup().rank - next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size - prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size + prev_rank, next_rank = neighbors + + is_first_stage = rank < prev_rank + # FIXME: only work for pure pipeline + is_first_decoder_stage = (rank == num_stage // 2) + is_last_stage = rank > next_rank + last_stage = torch.distributed.get_world_size() - 1 input_tensors = list() output_tensors = list() - input_1st_tensors = list() - output_1st_tensors = list() - - gather_list = list(range(num_stage)) - gather_list[0], gather_list[1] = gather_list[1], gather_list[0] - - def tp_forward(fmodel, dataloader) -> torch.Tensor: - input = next(dataloader) - output = forward_step(fmodel, input) - input_1st_tensors.append(input) - output_1st_tensors.append(output) - # gather - outputs = coll.gather([output], None, None, gather_list) - if rank == 1: - with torch.no_grad(): - outputs[0], outputs[1] = outputs[1], outputs[0] - output = torch.cat(tuple(outputs), dim=-1) - output = output.requires_grad_() - else: - output = None - return output - - def tp_backward(grad: torch.Tensor): - if rank == 1: - with torch.no_grad(): - grads = list(grad.chunk(num_stage, dim=-1)) - grads[0], grads[1] = grads[1], grads[0] + input_head_tensors = list() + output_head_tensors = list() + + def tp_head_forward() -> torch.Tensor: + src_tokens, prev_output_tokens = next(dataloader) + model.set_inputs(*(src_tokens, prev_output_tokens)) + enc = model.forward_encoder_preprocess(dst=0)[0] + dec = model.forward_decoder_preprocess(dst=num_stage // 2)[0] + input_head_tensors.append((src_tokens, prev_output_tokens)) + output_head_tensors.append((enc, dec)) + enc = enc.detach().requires_grad_() + dec = dec.detach().requires_grad_() + # FIXME: this will change decoder input + if is_first_stage: + model.set_preprocess(enc=enc) + if is_first_decoder_stage: + model.set_preprocess(dec=dec) + if is_first_stage: + return (enc,) + if is_first_decoder_stage: + return (dec,) else: - grads = None - input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) - grad_1st = coll.scatter(grads, [output_1st.size()], [output_1st.dtype], gather_list) - backward_step([input_1st], [output_1st], [grad_1st])[0] - - fofst = [0] + [-(step // 2) for step in range(num_stage-1)] - bofst = [0] + [-(num_stage - 2 - (step // 2)) for step in range(num_stage-1)] - # print(fofst) - # print(bofst) - fofst = fofst[rank] - bofst = bofst[rank] - last_backward = None - last_forward = None - for step in range(num_microbatch + num_stage - 2): - torch.distributed.barrier() - # print_each_rank(f'=========begin rank {rank}=========') - fmid, bmid = step + fofst, step + bofst - do_backward = 0 <= bmid and bmid <= num_microbatch - 1 - do_forward = 0 <= fmid and fmid <= num_microbatch - 1 - - # step1: tp forward - if 0 <= step and step <= num_microbatch - 1: - # print(f'rank {rank} forward tp model ') - output_1st = tp_forward(first_stage_model, dataloader) - - # step2: backward + forward - if rank == 0: - pass - - if rank % 2 == 0 and rank != 0: - # inter-barrier - if do_backward and last_forward is not None: - # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [last_forward], [model.output_shape()], [model.output_dtype()], - [next_rank], [next_rank] - )[0] - elif do_backward: - # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) - elif last_forward is not None: - # print(f'rank {rank} send forward output ') - coll.send(last_forward, next_rank) - - # backward - if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) - # backward - input_grad = backward_step([input], [output], [output_grad])[0] - - # intra-barrier - if do_backward and do_forward: - # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] - elif do_backward: - # print(f'rank {rank} send backward grad ') - coll.send(input_grad, prev_rank) - elif do_forward: - # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - - # forward - last_forward = None - if do_forward: - # forward step - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) - last_forward = output - - if rank % 2 == 1: - # inter-barrier - if rank == 1: - input = output_1st - else: - if do_forward and last_backward is not None: - # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] - elif do_forward: - # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - elif last_backward is not None: - # print(f'rank {rank} send backward grad ') - coll.send(last_backward, prev_rank) - - # forward - if do_forward: - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) + return () - # intra-barrier send recv - output_grad = None - if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): - # send forward recv backward - # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [output], [output.size()], [output.dtype], - [next_rank], [next_rank] - )[0] - elif do_forward and not is_last_stage(): - # print(f'rank {rank} send forward output ') - coll.send(output, next_rank) - elif do_backward and not is_last_stage(): - # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) - - # backward + forward - last_backward = None - if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) - input_grad = backward_step([input], [output], [output_grad])[0] - last_backward = input_grad - - # step3: tp backward - if 0 <= (step-num_stage+2) and (step-num_stage+2) <= num_microbatch - 1: - # print(f'rank {rank} backward tp model ') - tp_backward(last_backward) - - # if rank == 0: - # io_input(f'{step}>>>') - # torch.distributed.barrier() - - assert len(input_tensors) == 0 - assert len(output_tensors) == 0 - assert len(input_1st_tensors) == 0 - assert len(output_1st_tensors) == 0 - - # print_each_rank(f'=========end rank {rank}=========') - - -def schedule_tp_1f1b_pack(model: torch.nn.Module, - first_stage_model: torch.nn.Module, - dataloader, - num_microbatch: int, - num_stage: int): - rank = DeviceGroup().rank - next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size - prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size - - input_tensors = list() - output_tensors = list() - - input_1st_tensors = list() - output_1st_tensors = list() - - def tp_forward(fmodel, dataloader) -> torch.Tensor: - input = next(dataloader) - #TODO: gather - output = forward_step(fmodel, input) - input_1st_tensors.append(input) - output_1st_tensors.append(output) - output = output.detach().requires_grad_() - return output - - def tp_backward(grad: torch.Tensor): - input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) - if rank != 0: - grad = torch.empty_like(output_1st) - torch.distributed.broadcast(grad, src=0) - backward_step([input_1st], [output_1st], [grad])[0] + def tp_head_backward(grads: Tuple[torch.Tensor]): + inputs_head, outputs_head = input_head_tensors.pop(0), output_head_tensors.pop(0) + # encoder backward + enc, dec = outputs_head + if not is_first_stage: + grads = (torch.empty_like(enc),) + # decoder backward + backward_step((), (enc,), grads) + #FIXME: grads is using enc gradient!!! + if not is_first_decoder_stage: + grads = (torch.empty_like(dec),) + backward_step((), (dec,), grads) + + def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): + dec = None + if is_last_stage: + assert len(outputs) == 1 + dec = outputs[0] + dec = dec.detach().requires_grad_() + loss = model.forward_postprocess(dec, src=last_stage) + grads = backward_step((dec,), loss, (None,)) + return grads fofst = [-(step // 2) for step in range(num_stage)] bofst = [-(num_stage - 1 - (step // 2)) for step in range(num_stage)] @@ -443,8 +300,9 @@ def tp_backward(grad: torch.Tensor): # print(bofst) fofst = fofst[rank] bofst = bofst[rank] - last_backward = None - last_forward = None + last_backward = (None,) + last_forward = (None,) + tail_grads = (None,) for step in range(num_microbatch + num_stage - 1): torch.distributed.barrier() # print_each_rank(f'=========begin rank {rank}=========') @@ -455,121 +313,142 @@ def tp_backward(grad: torch.Tensor): # step1: tp forward if 0 <= step and step <= num_microbatch - 1: # print(f'rank {rank} forward tp model ') - output_1st = tp_forward(first_stage_model, dataloader) + inputs = tp_head_forward() # forward + backward if rank % 2 == 0: # inter-barrier - if rank == 0: - input = output_1st + if is_first_stage: + inputs = inputs else: - if do_forward and last_backward is not None: + if do_forward and last_backward != (None,): # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] + inputs = send_backward_recv_forward(last_backward, model, prev_rank) + # input = coll.sendrecv( + # [input_grad], [model.input_shape()], [model.input_dtype()], + # [prev_rank], [prev_rank] + # )[0] elif do_forward: # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - elif last_backward is not None: + inputs = recv_forward(model, prev_rank) + # input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + elif last_backward != (None,): # print(f'rank {rank} send backward grad ') - coll.send(last_backward, prev_rank) + send_backward(last_backward, prev_rank) + # coll.send(last_backward, prev_rank) # forward if do_forward: - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) + input_tensors.append(inputs) + if is_first_stage: + inputs = () + outputs = forward_step(model, *inputs) + output_tensors.append(outputs) # mem = torch.cuda.max_memory_allocated() # print(f'rank {rank}: {mem / 1024 / 1024 / 1024} GB forward') # intra-barrier send recv - output_grad = None - if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): + output_grads = (None,) + if (do_forward and not is_last_stage) and (do_backward and not is_last_stage): # send forward recv backward # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [output], [output.size()], [output.dtype], - [next_rank], [next_rank] - )[0] - elif do_forward and not is_last_stage(): + output_grads = send_forward_recv_backward(outputs, model, next_rank) + # output_grads = coll.sendrecv( + # [output], [output.size()], [output.dtype], + # [next_rank], [next_rank] + # )[0] + elif do_forward and not is_last_stage: # print(f'rank {rank} send forward output ') - coll.send(output, next_rank) - elif do_backward and not is_last_stage(): + send_forward(outputs, next_rank) + # coll.send(output, next_rank) + elif do_backward and not is_last_stage: # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + output_grads = recv_backward(model, next_rank) + # output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) # backward - last_backward = None + last_backward = (None,) if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) - input_grad = backward_step([input], [output], [output_grad])[0] - last_backward = input_grad + inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) + input_grads = backward_step(inputs, outputs, output_grads) + last_backward = input_grads # backward + forward if rank % 2 == 1: # inter-barrier - if is_last_stage(): - output_grad = None + if is_last_stage: + output_grads = tail_grads else: - if do_backward and last_forward is not None: + if do_backward and last_forward != (None,): # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [last_forward], [model.output_shape()], [model.output_dtype()], - [next_rank], [next_rank] - )[0] + output_grads = send_forward_recv_backward(last_forward, model, next_rank) + # output_grad = coll.sendrecv( + # [last_forward], [model.output_shape()], [model.output_dtype()], + # [next_rank], [next_rank] + # )[0] elif do_backward: # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) - elif last_forward is not None: + output_grads = recv_backward(model, next_rank) + # output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) + elif last_forward != (None,): # print(f'rank {rank} send forward output ') - coll.send(last_forward, next_rank) + send_forward(last_forward, next_rank) + # coll.send(last_forward, next_rank) # backward + last_backward = (None,) if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) + inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) # backward - input_grad = backward_step([input], [output], [output_grad])[0] + input_grads = backward_step(inputs, outputs, output_grads) + last_backward = input_grads # intra-barrier if do_backward and do_forward: # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] + inputs = send_backward_recv_forward(input_grads, model, prev_rank) + # input = coll.sendrecv( + # [input_grad], [model.input_shape()], [model.input_dtype()], + # [prev_rank], [prev_rank] + # )[0] elif do_backward: # print(f'rank {rank} send backward grad ') - coll.send(input_grad, prev_rank) + send_backward(input_grads, prev_rank) + # coll.send(input_grad, prev_rank) elif do_forward: # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) + inputs = recv_forward(model, prev_rank) + # input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) # forward - last_forward = None + last_forward = (None,) if do_forward: # forward step - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) - last_forward = output - - # step3: tp backward - if 0 <= (step-num_stage+1) and (step-num_stage+1) <= num_microbatch - 1: - # print(f'rank {rank} backward tp model ') - tp_backward(last_backward) - - # memory_summary() - # if rank == 0: - # io_input(f'{step}>>>') - # torch.distributed.barrier() - # print_each_rank(f'=========end rank {rank}: {step}=========') + outputs = forward_step(model, *inputs) + input_tensors.append(inputs) + output_tensors.append(outputs) + last_forward = outputs + + # tp tail forward-backward + last_stage_mid = step - (num_stage - 1) // 2 + if 0 <= last_stage_mid and last_stage_mid <= num_microbatch - 1: + tail_grads = tp_tail_forward_backward(last_forward) + + # step 4: tp encoder and decoder backward + encoder_mid = step + 1 - num_stage + if 0 <= encoder_mid and encoder_mid <= num_microbatch - 1: + tp_head_backward(last_backward) + + memory_summary() + if rank == 0: + io_input(f'{step}>>>') + torch.distributed.barrier() + print_each_rank(f'=========end rank {rank}: {step}=========') assert len(input_tensors) == 0 assert len(output_tensors) == 0 - assert len(input_1st_tensors) == 0 - assert len(output_1st_tensors) == 0 + assert len(input_head_tensors) == 0 + assert len(output_head_tensors) == 0 # print_each_rank(f'=========end rank {rank}=========') \ No newline at end of file diff --git a/handcraft/mbart/tp.py b/handcraft/mbart/tp.py index b37c71a6..ff6e9302 100644 --- a/handcraft/mbart/tp.py +++ b/handcraft/mbart/tp.py @@ -2,7 +2,7 @@ import torch -class Reduce(torch.autograd.Function): +class AllReduceIdentity(torch.autograd.Function): @staticmethod def forward(ctx, input, group): @@ -14,7 +14,7 @@ def backward(ctx, grad_output): return grad_output, None -class IdentityFoward(torch.autograd.Function): +class IdentityAllreduce(torch.autograd.Function): @staticmethod def forward(ctx, input, group): @@ -27,62 +27,70 @@ def backward(ctx, grad_output): return grad_output, None -def shard_linear_col(input, weight, bias, group): - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return torch.nn.functional(input, weight, bias) - input = IdentityFoward.apply(input, group) - return torch.nn.functional(input, weight, bias) +class AllGatherScatter(torch.autograd.Function): + @staticmethod + def forward(ctx, input, dim, group): + ctx._group = group + ctx._dim = dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input + rank = torch.distributed.get_rank(group) + tensor_list = [torch.empty_like(input) for _ in range(world_size)] + tensor_list[rank] = input + torch.distributed.all_gather(tensor_list, input, group=group) + output = torch.cat(tensor_list, dim=dim).contiguous() + return output -def shard_linear_row(input, weight, bias, group): - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return torch.nn.functional(input, weight, bias) - out = torch.nn.functional(input, weight, bias) - out = Reduce.apply(out, group) - return out - - -class DummyModelEmbed(torch.nn.Module): - - def __init__(self, num_embeddings: int, embedding_dim: int, - input_shape: Tuple[int, int], group = None): - super().__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.input_shape = input_shape + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + group = ctx._group + dim = ctx._dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output + input_list = grad_output.chunk(world_size, dim=dim) + rank = torch.distributed.get_rank(group) + grad = input_list[rank].contiguous() + return grad, None, None - self.tp_group = group - self.tp_size = torch.distributed.get_world_size(group) - shard_id = torch.distributed.get_rank(group) - self.vocab_start_index = num_embeddings // self.tp_size * shard_id - self.vocab_end_index = num_embeddings // self.tp_size * (shard_id + 1) - self.embed_weight = torch.nn.Parameter(torch.ones((num_embeddings // self.tp_size, embedding_dim))) +class ReduceBroadcast(torch.autograd.Function): - def input_shape(self): - return self.input_shape + @staticmethod + def forward(ctx, input, dst: int, group=None): + ctx._dst = dst + ctx._group = group + torch.distributed.reduce(input, dst, group=group) + torch.cuda.synchronize() + return input - def input_dtype(self): - return torch.int64 + @staticmethod + def backward(ctx, grad_output): + src = ctx._dst + group = ctx._group + torch.distributed.broadcast(grad_output, src, group=group) + torch.cuda.synchronize() + return grad_output, None, None - def output_shape(self): - return self.input_shape + (self.embedding_dim,) - def output_dtype(self): - return torch.float32 +class BroadcastReduce(torch.autograd.Function): - def forward(self, input: torch.Tensor): - if self.tp_size > 1: - mask = (input < self.vocab_start_index) | \ - (input >= self.vocab_end_index) - input = input.clone() - self.vocab_start_index - input[mask] = 0 - input = torch.nn.functional.embedding(input, self.embed_weight) - input[mask, :] = 0.0 - input = Reduce.apply(input, self.tp_group) - else: - input = torch.nn.functional.embedding(input, self.embed_weight) + @staticmethod + def forward(ctx, input, src: int, group=None): + ctx._src = src + ctx._group = group + torch.distributed.broadcast(input, src, group=group) + torch.cuda.synchronize() return input + + @staticmethod + def backward(ctx, grad_output): + dst = ctx._src + group = ctx._group + if not grad_output.is_contiguous(): + grad_output = grad_output.contiguous() + torch.distributed.reduce(grad_output, dst, group=group) + torch.cuda.synchronize() + return grad_output, None, None From 1bf00fc2191eee9474dda28db7b056a11efd7f96 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 16 Mar 2022 09:24:44 +0800 Subject: [PATCH 0659/1892] fix naive bug --- handcraft/mbart/mbart.py | 15 +++++++-------- handcraft/mbart/run.sh | 18 ++++++++++++++++++ handcraft/mbart/schedule.py | 10 +++++----- 3 files changed, 30 insertions(+), 13 deletions(-) create mode 100644 handcraft/mbart/run.sh diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 10c846cf..6e97e446 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -20,7 +20,7 @@ from cube.profiler.timer import print_each_rank from handcraft.mbart.schedule import schedule_naive, schedule_tp_1f1b_pack -from handcraft.mbart.tp import AllGatherScatter, ParallelEmbed, BroadcastReduce, ReduceBroadcast +from handcraft.mbart.tp import AllGatherScatter, BroadcastReduce, ReduceBroadcast _tp_group = -1 _pp_group = -1 @@ -480,7 +480,6 @@ def decoder_preprocess(self, dst: Optional[int] = None): def postprocess(self, output, src: Optional[int] = None): _, prev_output_tokens = self._inputs if self.group == -1: - output = self.layer_norm_decoder(output) output = output.transpose(0, 1) output = torch.nn.functional.linear(output, self.weight) loss = criterion(output, prev_output_tokens) @@ -712,8 +711,8 @@ def reduce_embed(model, pp_embed_group): """ Embedding gradients needs to be reduced across pipeline stages """ - if isinstance(model.emb, torch.nn.Module): - grad = model.emb.get_weight().grad + if isinstance(model.headtail, torch.nn.Module): + grad = model.headtail.weight.grad else: grad = None if grad is not None: @@ -779,9 +778,9 @@ def reduce_embed(model, pp_embed_group): optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) CudaTimer(enable=False).warmup() - iter_num = 64 + iter_num = 32 for step in range(iter_num): - if step >= 20: + if step >= 10: CudaTimer(enable=True).start('e2e') if args.use_naive: schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) @@ -795,11 +794,11 @@ def reduce_embed(model, pp_embed_group): print('passed 1st iteration') optimizer.step() optimizer.zero_grad() - if step >= 20: + if step >= 10: CudaTimer().stop('e2e') if (step + 1) % 10 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-20, field_name='e2e'))) + CudaTimer().duration(iter_num-10, field_name='e2e'))) memory_summary() diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh new file mode 100644 index 00000000..c8fbe7ee --- /dev/null +++ b/handcraft/mbart/run.sh @@ -0,0 +1,18 @@ +# 4 gpus + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 64 > 4dev64nmb-tp1f1b-pack.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart.py --use-naive --nmb 64 > 4dev64nmb-naive.txt + +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/pipeline/dummy_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > 4dev64nmb-2tp2pp.txt + +# 8 gpus + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 128 > 8dev128nmb-tp1f1b-pack.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart.py --use-naive --nmb 128 > 8dev128nmb-naive.txt diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py index 22597928..e01484a9 100644 --- a/handcraft/mbart/schedule.py +++ b/handcraft/mbart/schedule.py @@ -440,11 +440,11 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): if 0 <= encoder_mid and encoder_mid <= num_microbatch - 1: tp_head_backward(last_backward) - memory_summary() - if rank == 0: - io_input(f'{step}>>>') - torch.distributed.barrier() - print_each_rank(f'=========end rank {rank}: {step}=========') + # memory_summary() + # if rank == 0: + # io_input(f'{step}>>>') + # torch.distributed.barrier() + # print_each_rank(f'=========end rank {rank}: {step}=========') assert len(input_tensors) == 0 assert len(output_tensors) == 0 From 25aaba74726300626c20883983a03c01145cc080 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 16 Mar 2022 13:41:14 +0800 Subject: [PATCH 0660/1892] add 1f1b scheduling --- handcraft/mbart/mbart.py | 12 ++++-- handcraft/mbart/run.sh | 0 handcraft/mbart/schedule.py | 73 ++++++++++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 4 deletions(-) mode change 100644 => 100755 handcraft/mbart/run.sh diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 6e97e446..5db5166b 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -19,7 +19,7 @@ from cube.profiler.memory import memory_summary, model_summary from cube.profiler.timer import print_each_rank -from handcraft.mbart.schedule import schedule_naive, schedule_tp_1f1b_pack +from handcraft.mbart.schedule import schedule_naive, schedule_1f1b, schedule_tp_1f1b_pack from handcraft.mbart.tp import AllGatherScatter, BroadcastReduce, ReduceBroadcast _tp_group = -1 @@ -727,6 +727,8 @@ def reduce_embed(model, pp_embed_group): help='num of micro batch') parser.add_argument('--use-naive', action='store_true', help='use naive pipeline') + parser.add_argument('--use-1f1b', action='store_true', + help='use 1f1b scheduling') parser.add_argument('--use-tp1f1b-pack', action='store_true', help='use tensor parallel 1f1b') args = parser.parse_args() @@ -749,7 +751,7 @@ def reduce_embed(model, pp_embed_group): # create embed group: first encoder, first decoder, last stage # FIXME: only work for tp_size = 1 - if args.use_naive: + if args.use_naive or args.use_1f1b: embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2], pp_ranks[-1]] embed_ranks = list(set(embed_ranks)) _pp_embed_group = DeviceGroup().get_group(embed_ranks) @@ -764,7 +766,7 @@ def reduce_embed(model, pp_embed_group): dtypes=(torch.int64, torch.int64), batch_dims=(0,0,) ) - if args.use_naive: + if args.use_naive or args.use_1f1b: encoder_preprocess = is_first_stage decoder_preprocess = is_first_decoder_stage postprocess = is_last_stage @@ -785,6 +787,10 @@ def reduce_embed(model, pp_embed_group): if args.use_naive: schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) reduce_embed(model, _pp_embed_group) + if args.use_1f1b: + for _ in range(args.nmb // 2): + schedule_1f1b(model, iter(dataloader), 2, len(pp_ranks), (_pp_prev_rank, _pp_next_rank)) + reduce_embed(model, _pp_embed_group) if args.use_tp1f1b_pack: schedule_tp_1f1b_pack( model, iter(dataloader), diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh old mode 100644 new mode 100755 diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py index e01484a9..c2f618f1 100644 --- a/handcraft/mbart/schedule.py +++ b/handcraft/mbart/schedule.py @@ -1,3 +1,4 @@ +from turtle import forward from typing import List, Tuple import torch @@ -451,4 +452,74 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): assert len(input_head_tensors) == 0 assert len(output_head_tensors) == 0 - # print_each_rank(f'=========end rank {rank}=========') \ No newline at end of file + # print_each_rank(f'=========end rank {rank}=========') + + +def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors): + + rank = torch.distributed.get_rank() + prev_rank, next_rank = neighbors + is_first_stage = rank < prev_rank + is_last_stage = rank > next_rank + + num_warmup_microbatches = num_stage - 1 - rank + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatch) + num_warmup_remaining = num_microbatch - num_warmup_microbatches + + input_tensors = list() + output_tensors = list() + + # warmup + for i in range(num_warmup_microbatches): + model.set_inputs(*next(dataloader)) + # recv forward + inputs = () if is_first_stage else recv_forward(model, prev_rank) + # forward + outputs = forward_step(model, *inputs) + # send forward + send_forward(outputs, next_rank) + input_tensors.append(inputs) + output_tensors.append(outputs) + + # before running 1f1b: need to recv first forward tensor + if num_warmup_remaining > 0: + model.set_inputs(*next(dataloader)) + inputs = () if is_first_stage else recv_forward(model, prev_rank) + + # run 1f1b + for i in range(num_warmup_remaining): + model.set_inputs(*next(dataloader)) + # forward + outputs = forward_step(model, *inputs) + input_tensors.append(inputs) + output_tensors.append(outputs) + + # send forward recv backward + grads = (None,) + if not is_last_stage: + grads = send_forward_recv_backward(outputs, model, next_rank) + + # backward + inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) + input_grads = backward_step(inputs, outputs, grads) + + # send backward + inputs = () + if not is_first_stage: + if i != (num_warmup_remaining-1): + # send backward recv forward + inputs = send_backward_recv_forward(input_grads, model, prev_rank) + else: + # send backward + send_backward(input_grads, prev_rank) + + # cooldown + for i in range(num_warmup_microbatches): + inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) + # recv backward + grads = (None,) if is_last_stage else recv_backward(model, next_rank) + # backward + input_grads = backward_step(inputs, outputs, grads) + # send backward + if not is_first_stage: + send_backward(input_grads, prev_rank) From 68f9fb96bc44749974f805970d530c4432b52d04 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 16 Mar 2022 19:17:53 +0800 Subject: [PATCH 0661/1892] remove useless --- cube/algorithm/ops/complex.py | 658 ---------------------------------- 1 file changed, 658 deletions(-) delete mode 100644 cube/algorithm/ops/complex.py diff --git a/cube/algorithm/ops/complex.py b/cube/algorithm/ops/complex.py deleted file mode 100644 index 9e985d65..00000000 --- a/cube/algorithm/ops/complex.py +++ /dev/null @@ -1,658 +0,0 @@ -from typing import Dict - -from cube.algorithm.utils import split_axis, split_value -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.operator.function import CubeComplexToQKV -from cube.graph.operator.function import CubeComplexTrilMask -from cube.graph.operator.function import CubeComplexAttnView -from cube.graph.operator.function import CubeComplexSelfAttention -from cube.graph.operator.function import CubeComplexFeedForward -from cube.graph.operator.function import CubeComplexEmbedding - - -_kWaitDecision = None - - -class CubeToQKVDataParallel(GenericDistAlgo): - """ - Inputs: - hidden_state: [L, N, E] - weight: [3 * (num_head * dim_head), E] - num_head: int - - where L = sequence length, N = batch size, E = num_head * dim_head - - Returns: - Q: [L, N * num_head, dim_head] - K: [L, N * num_head, dim_head] - V: [L, N * num_head, dim_head] - """ - def __init__(self, node: CubeComplexToQKV): - - if not isinstance(node, CubeComplexToQKV): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.num_head = node.kwargs['num_head'] - self.bs = node.inputs(0).shape[1] - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.bs % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - hidden_size, weight = node.inputs() - q, k, v = node.outputs() - - ins = split_axis(hidden_size, 1, self.chunk_num) - qs = split_axis(q, 1, self.chunk_num) - ks = split_axis(k, 1, self.chunk_num) - vs = split_axis(v, 1, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ins[idx], weight, self.num_head] - node = CubeComplexToQKV( - signature = 'cube.runtime.function.complex.toqkv', - inputs = inputs, - name = 'toqkv' - ) - node.set_output(0, qs[idx]) - node.set_output(1, ks[idx]) - node.set_output(2, vs[idx]) - nodes.append(node) - return nodes - - -class CubeToQKVHeadParallel(GenericDistAlgo): - """ - Inputs: - hidden_state: [L, N, E] (seqlen, batch size, num_head * dim_head) - weight: [E * 3, E] - num_head: int - - Returns: - Q: [L, N * num_head, dim_head] - K: [L, N * num_head, dim_head] - V: [L, N * num_head, dim_head] - """ - def __init__(self, node: CubeComplexToQKV): - - if not isinstance(node, CubeComplexToQKV): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.num_head = node.kwargs['num_head'] - self.bs = node.inputs(0).shape[1] - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.num_head % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - hidden_state, weight = node.inputs() - q, k, v = node.outputs() - - ws = split_axis(weight, 0, self.chunk_num) - qs = split_axis(q, 1, self.chunk_num) - ks = split_axis(k, 1, self.chunk_num) - vs = split_axis(v, 1, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [hidden_state, ws[idx], self.num_head // self.chunk_num] - node = CubeComplexToQKV( - signature = 'cube.runtime.function.complex.toqkv', - inputs = inputs, - name = 'toqkv' - ) - node.set_output(0, qs[idx]) - node.set_output(1, ks[idx]) - node.set_output(2, vs[idx]) - nodes.append(node) - return nodes - - -class CubeTrilMaskDataParallel(GenericDistAlgo): - """ - Inputs: - input: [N * num_head, L, L] - num_head: int - - Returns: - output: [N * num_head, L, L] - """ - def __init__(self, node: CubeComplexTrilMask): - - if not isinstance(node, CubeComplexTrilMask): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.num_head = node.kwargs['num_head'] - self.bs = node.inputs(0).shape[0] // self.num_head - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.bs % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - hidden_size = node.inputs(0) - masked_out = node.outputs(0) - - ins = split_axis(hidden_size, 0, self.chunk_num) - ous = split_axis(masked_out, 0, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ins[idx], self.num_head] - node = CubeComplexTrilMask( - signature = 'cube.runtime.function.complex.tril_mask', - inputs = inputs, - name = 'tril_mask' - ) - node.set_output(0, ous[idx]) - nodes.append(node) - return nodes - - -class CubeTrilMaskHeadParallel(GenericDistAlgo): - """ - Inputs: - input: [N * num_head, L, L] - num_head: int - - Returns: - output: [N * num_head, L, L] - """ - def __init__(self, node: CubeComplexTrilMask): - - if not isinstance(node, CubeComplexTrilMask): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.num_head = node.kwargs['num_head'] - self.bs = node.inputs(0).shape[0] // self.num_head - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.num_head % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - hidden_size = node.inputs(0) - masked_out = node.outputs(0) - - ins = split_axis(hidden_size, 0, self.chunk_num) - ous = split_axis(masked_out, 0, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ins[idx], self.num_head // self.chunk_num] - node = CubeComplexTrilMask( - signature = 'cube.runtime.function.complex.tril_mask', - inputs = inputs, - name = 'tril_mask' - ) - node.set_output(0, ous[idx]) - nodes.append(node) - return nodes - - -class CubeAttnViewDataParallel(GenericDistAlgo): - """ - Inputs: - [N * num_head, L, dim_head] - - Outputs: - [L, N, num_head * dim_head] - """ - def __init__(self, node: CubeComplexAttnView): - if not isinstance(node, CubeComplexAttnView): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.num_head = node.kwargs['num_head'] - self.bs = node.inputs(0).shape[0] // self.num_head - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.bs % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - attn = node.inputs(0) - out = node.outputs(0) - - ins = split_axis(attn, 0, self.chunk_num) - ous = split_axis(out, 1, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ins[idx], self.num_head] - node = CubeComplexAttnView( - signature = 'cube.runtime.function.complex.attn_view', - inputs = inputs, - name = 'attn_view' - ) - node.set_output(0, ous[idx]) - nodes.append(node) - return nodes - - -class CubeAttnViewHeadParallel(GenericDistAlgo): - """ - Inputs: - [N * num_head, L, dim_head] - - Outputs: - [L, N, num_head * dim_head] - """ - def __init__(self, node: CubeComplexAttnView): - if not isinstance(node, CubeComplexAttnView): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.num_head = node.kwargs['num_head'] - self.bs = node.inputs(0).shape[0] // self.num_head - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.num_head % chunk_num == 0: - return True - return False - - def instantiate(self, node, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - attn = node.inputs(0) - out = node.outputs(0) - - ins = split_axis(attn, 0, self.chunk_num) - ous = split_axis(out, 2, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ins[idx], self.num_head // self.chunk_num] - node = CubeComplexAttnView( - signature = 'cube.runtime.function.complex.attn_view', - inputs = inputs, - name = 'attn_view' - ) - node.set_output(0, ous[idx]) - nodes.append(node) - return nodes - - -class CubeSelfAttentionHeadParallel(GenericDistAlgo): - """ - Multi-Head Self-Attention. - - L: sequence length - N: batch size - E: embedding size - - Inputs: - hidden_state: [L, N, E] - w_qkv : [3 * num_head * dim_head, E] - w_out : [E, E] - num_head: int - dim_head: int - dropout_p: float - - Outputs: - hidden_state: [L, N, E] - """ - def __init__(self, node: CubeComplexSelfAttention): - if not isinstance(node, CubeComplexSelfAttention): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.num_head = node.kwargs['num_head'] - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.num_head % chunk_num == 0: - return True - return False - - def instantiate(self, node: CubeComplexSelfAttention, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - hidden_state = node.inputs(0) - w_qkv = node.inputs(1) - w_out = node.inputs(2) - num_head = node.kwargs['num_head'] - dim_head = node.kwargs['dim_head'] - dropout_p = node.kwargs['dropout_p'] - out = node.outputs(0) - - - w_qkvs = split_axis(w_qkv, 0, self.chunk_num) - w_outs = split_axis(w_out, 1, self.chunk_num) - ous = split_value(out, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ - hidden_state, w_qkvs[idx], w_outs[idx], - num_head // self.chunk_num, dim_head, dropout_p - ] - node = CubeComplexSelfAttention( - signature = 'cube.runtime.function.complex.self_attn', - inputs = inputs, - ) - node.set_output(0, ous[idx]) - nodes.append(node) - return nodes - - -class CubeSelfAttentionDataParallel(GenericDistAlgo): - """ - Multi-Head Self-Attention. - - L: sequence length - N: batch size - E: embedding size - - Inputs: - hidden_state: [L, N, E] - w_qkv : [3 * num_head * dim_head, E] - w_out : [E, E] - num_head: int - dim_head: int - dropout_p: float - - Outputs: - hidden_state: [L, N, E] - """ - def __init__(self, node: CubeComplexSelfAttention): - if not isinstance(node, CubeComplexSelfAttention): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.bs = node.inputs(0).shape[1] - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.bs % chunk_num == 0: - return True - return False - - def instantiate(self, node: CubeComplexSelfAttention, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - hidden_state = node.inputs(0) - w_qkv = node.inputs(1) - w_out = node.inputs(2) - num_head = node.kwargs['num_head'] - dim_head = node.kwargs['dim_head'] - dropout_p = node.kwargs['dropout_p'] - out = node.outputs(0) - - ins = split_axis(hidden_state, 1, self.chunk_num) - ous = split_axis(out, 1, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ - ins[idx], w_qkv, w_out, - num_head, dim_head, dropout_p - ] - node = CubeComplexSelfAttention( - signature = 'cube.runtime.function.complex.self_attn', - inputs = inputs, - ) - node.set_output(0, ous[idx]) - nodes.append(node) - return nodes - - -class CubeFeedForwardTensorParallel(GenericDistAlgo): - """ - FeedForward - - Inputs: - hidden_state: [L, N, E] - w_proj1: [4 * E, E] - w_bias1: [4 * E,] - w_porj2: [E, 4 * E] - w_bias2: [E,] - - Outputs: - hidden_state: [L, N, E] - """ - def __init__(self, node: CubeComplexFeedForward): - if not isinstance(node, CubeComplexFeedForward): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.embed_size = node.inputs(1).shape[0] - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.embed_size % chunk_num == 0: - return True - return False - - def instantiate(self, node: CubeComplexFeedForward, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - hidden_state = node.inputs(0) - w_proj1 = node.inputs(1) - w_bias1 = node.inputs(2) - w_proj2 = node.inputs(3) - w_bias2 = node.inputs(4) - - out = node.outputs(0) - - w_proj1s = split_axis(w_proj1, 0, self.chunk_num) - w_bias1s = split_axis(w_bias1, 0, self.chunk_num) - w_proj2s = split_axis(w_proj2, 1, self.chunk_num) - w_bias2s = split_value(w_bias2, self.chunk_num) - - outs = split_value(out, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ - hidden_state, - w_proj1s[idx], w_bias1s[idx], - w_proj2s[idx], w_bias2s[idx] - ] - node = CubeComplexFeedForward( - signature = 'cube.runtime.function.complex.feedforward', - inputs = inputs, - ) - node.set_output(0, outs[idx]) - nodes.append(node) - return nodes - - -class CubeFeedForwardDataParallel(GenericDistAlgo): - """ - FeedForward - - Inputs: - hidden_state: [L, N, E] - w_proj1: [4 * E, E] - w_bias1: [4 * E,] - w_porj2: [E, 4 * E] - w_bias2: [E,] - - Outputs: - hidden_state: [L, N, E] - """ - def __init__(self, node: CubeComplexFeedForward): - if not isinstance(node, CubeComplexFeedForward): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.bs = node.inputs(0).shape[1] - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.bs % chunk_num == 0: - return True - return False - - def instantiate(self, node: CubeComplexFeedForward, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - hidden_state = node.inputs(0) - w_proj1 = node.inputs(1) - w_bias1 = node.inputs(2) - w_proj2 = node.inputs(3) - w_bias2 = node.inputs(4) - out = node.outputs(0) - - ins = split_axis(hidden_state, 1, self.chunk_num) - outs = split_axis(out, 1, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ - ins[idx], - w_proj1, w_bias1, - w_proj2, w_bias2, - ] - node = CubeComplexFeedForward( - signature = 'cube.runtime.function.complex.feedforward', - inputs = inputs, - ) - node.set_output(0, outs[idx]) - nodes.append(node) - return nodes - - -class CubeEmbedDataParallel(GenericDistAlgo): - - def __init__(self, node: CubeComplexEmbedding): - if not isinstance(node, CubeComplexEmbedding): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.dims = node.inputs(0).shape - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - dim = int(config['dim']) - if dim >= len(self.dims): - return False - if chunk_num > 0 and self.dims[dim] % chunk_num == 0: - return True - return False - - def instantiate(self, node: CubeComplexEmbedding, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - dim = int(config['dim']) - - input = node.inputs(0) - weight = node.inputs(1) - start = node.kwargs['start'] - stop = node.kwargs['stop'] - - out = node.outputs(0) - - ins = split_axis(input, dim, self.chunk_num) - outs = split_axis(out, dim, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - inputs = [ - ins[idx], weight, start, stop - ] - node = CubeComplexEmbedding( - signature = 'cube.runtime.function.complex.embedding', - inputs = inputs, - ) - node.set_output(0, outs[idx]) - nodes.append(node) - return nodes - - -class CubeEmbedShardingParallel(GenericDistAlgo): - - def __init__(self, node: CubeComplexEmbedding): - if not isinstance(node, CubeComplexEmbedding): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - self.chunk_num = _kWaitDecision - self.vocabs = node.inputs(1).shape[0] - - def satisfy(self, config: Dict): - chunk_num = int(config['chunk_num']) - if chunk_num > 0 and self.vocabs % chunk_num == 0: - return True - return False - - def instantiate(self, node: CubeComplexEmbedding, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") - self.chunk_num = int(config['chunk_num']) - - input = node.inputs(0) - weight = node.inputs(1) - start = node.kwargs['start'] - stop = node.kwargs['stop'] - shard = (stop - start) // self.chunk_num - - out = node.outputs(0) - - ws = split_axis(weight, 0, self.chunk_num) - outs = split_value(out, self.chunk_num) - - nodes = list() - for idx in range(self.chunk_num): - shard_start = start + shard * idx - shard_stop = shard_start + shard - inputs = [ - input, ws[idx], shard_start, shard_stop - ] - node = CubeComplexEmbedding( - signature = 'cube.runtime.function.complex.embedding', - inputs = inputs, - ) - node.set_output(0, outs[idx]) - nodes.append(node) - return nodes From 99353897c5c5dd9b7b4c25f3ee2cb0740906362d Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 17 Mar 2022 11:26:34 +0800 Subject: [PATCH 0662/1892] fix torch 1.11 torch.linear error --- cube/graph/operator/function/function.py | 5 +++++ cube/graph/parser/mapping.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index c09a958b..f27fbe89 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -59,6 +59,11 @@ def _create_anno(ins: List[List[Union[str, List[str]]]], def Linear(signature, inputs): + if signature == 'torch.linear': + import warnings + warnings.warn(f'signature {signature} replaced into torch.nn.functional.linear') + signature = 'torch.nn.functional.linear' + annos = [ 'b * k+, n k+ -> b * n', # no bias 'b * k+, n k+, n -> b * n' # have bias diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 2bb764ba..e3dd9d43 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -95,6 +95,9 @@ def register(signature: str, op: IRFwOperation, code): __ttemplate('conv3d'): function.Conv3D, + #pytorch1.11 + __ttemplate('linear'): function.Linear, + #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, From ec72b29c2abc4930237147971ddd21751707f9d4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Mar 2022 14:44:56 +0800 Subject: [PATCH 0663/1892] switch to sentence classification task --- handcraft/mbart/mbart.py | 100 +++++++++++++++-------------- handcraft/mbart/schedule.py | 121 +++++++++++++++--------------------- 2 files changed, 99 insertions(+), 122 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 5db5166b..df3dc61c 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -77,6 +77,9 @@ class Config: pooler_activation_fn = 'tanh' pooler_dropout = 0.0 + # classification task + num_classes = 2 + def attn_fn(query: torch.Tensor, key: torch.Tensor, wq: torch.Tensor, wq_bias: Optional[torch.Tensor], @@ -404,6 +407,36 @@ def criterion(output: torch.Tensor, prev_output_tokens: torch.Tensor, label_smoo return loss +class MBartClassificationHead(torch.nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.num_classes = num_classes + self.dense = torch.nn.Linear(input_dim, inner_dim) + self.dropout = torch.nn.Dropout(p=pooler_dropout) + self.out_proj = torch.nn.Linear(inner_dim, num_classes) + self.loss_fct = torch.nn.CrossEntropyLoss() + + def forward(self, dec: torch.Tensor, labels): + # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] + dec = dec.transpose(0, 1)[:,-1,:] + sentence_represent = dec + hidden_states = self.dropout(sentence_represent) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + logits = self.out_proj(hidden_states) + loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) + return (loss,) + + class ShardHeadTail(torch.nn.Module): def __init__(self, cfg: Config, group=-1): @@ -434,16 +467,11 @@ def __init__(self, cfg: Config, group=-1): # post-proces - self._inputs = (None, None) + self._inputs = (None, ) def set_inputs(self, *inputs): self._inputs = inputs - def criterion_input_shape(self): - return ( - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), - ) - def embed_lookup(self, tokens, dst: Optional[int] = None): if self.shard_num > 1: mask = (tokens < self.vocab_start_index) | \ @@ -458,7 +486,7 @@ def embed_lookup(self, tokens, dst: Optional[int] = None): return embed def encoder_preprocess(self, dst: Optional[int] = None): - source_tokens, _ = self._inputs + source_tokens = self._inputs[0] source_embed = self.embed_lookup(source_tokens, dst) embed = self.embed_scale_encoder * source_embed x = embed + self.embed_positions_encoder.weight @@ -468,7 +496,7 @@ def encoder_preprocess(self, dst: Optional[int] = None): return (enc,) def decoder_preprocess(self, dst: Optional[int] = None): - _, prev_output_tokens = self._inputs + prev_output_tokens = self._inputs[0] target_emb = self.embed_lookup(prev_output_tokens, dst) embed = self.embed_scale_decoder * target_emb embed = embed + self.embed_positions_decoder.weight @@ -477,31 +505,6 @@ def decoder_preprocess(self, dst: Optional[int] = None): dec = embed.transpose(0, 1) return (dec,) - def postprocess(self, output, src: Optional[int] = None): - _, prev_output_tokens = self._inputs - if self.group == -1: - output = output.transpose(0, 1) - output = torch.nn.functional.linear(output, self.weight) - loss = criterion(output, prev_output_tokens) - return (loss,) - else: - assert src is not None - if self.shard_idx != src: - output = torch.empty( - self.criterion_input_shape()[0], - dtype=torch.float32, - requires_grad=True, - device=torch.cuda.current_device() - ) - output = output.transpose(0, 1) - output = BroadcastReduce.apply(output, src, self.group) - # return (torch.sum(output.contiguous()),) - output = torch.nn.functional.linear(output, self.weight) - output = AllGatherScatter.apply(output, -1, self.group) - loss = criterion(output, prev_output_tokens) - return (loss,) - - class mBARTFull(torch.nn.Module): @@ -511,6 +514,7 @@ def __init__(self, cfg: Config, post_process=True, shard=True): super().__init__() self.cfg = cfg + self.dummy_labels = torch.tensor([1]).cuda() self._preprocess = [None, None] # enc, dec self.rank = DeviceGroup().rank @@ -567,6 +571,7 @@ def __init__(self, cfg: Config, # postpross if self.postprocess: print(f'[{self.rank}]: will compute loss') + self.head = MBartClassificationHead(cfg.decoder_embed_dim, 1024, cfg.num_classes, 0.0) def input_shape(self): if self.encoder_preprocess: @@ -588,13 +593,9 @@ def input_shape(self): (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), ) - elif self.decoder_last_stage: - return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), - ) - else: - assert False, "post-process is not allowed to be a single stage" + elif self.postprocess: + return ((1,),) + assert False def output_shape(self): shape = None @@ -606,12 +607,10 @@ def output_shape(self): (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), ) - if self.decoder_last_stage: + if self.postprocess: shape = ( - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (1,), ) - if self.postprocess: - shape = ((1,),) assert shape is not None return shape @@ -642,7 +641,7 @@ def output_dtype(self): return dtype def set_inputs(self, *inputs): - assert len(inputs) == 2 + assert len(inputs) == 1 if self.headtail is not None: self.headtail.set_inputs(*inputs) @@ -658,8 +657,8 @@ def forward_encoder_preprocess(self, dst=None): def forward_decoder_preprocess(self, dst=None): return self.headtail.decoder_preprocess(dst) - def forward_postprocess(self, dec, src=None): - return self.headtail.postprocess(dec, src) + def forward_postprocess(self, dec): + return self.head(dec, self.dummy_labels) def forward(self, enc=None, dec=None): """ @@ -761,10 +760,9 @@ def reduce_embed(model, pp_embed_group): dataloader = SynTextDataLoader( shapes=( [1, cfg.max_source_positions], - [1, cfg.max_target_positions] ), - dtypes=(torch.int64, torch.int64), - batch_dims=(0,0,) + dtypes=(torch.int64,), + batch_dims=(0,) ) if args.use_naive or args.use_1f1b: encoder_preprocess = is_first_stage @@ -772,7 +770,7 @@ def reduce_embed(model, pp_embed_group): postprocess = is_last_stage model = mBARTFull(cfg, encoder_preprocess, decoder_preprocess, postprocess, shard=False).cuda() else: - model = mBARTFull(cfg, False, False, False, shard=True).cuda() + model = mBARTFull(cfg, False, False, is_last_stage, shard=True).cuda() print_each_rank('model weight consumpition:') memory_summary() diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py index c2f618f1..21a2e9e5 100644 --- a/handcraft/mbart/schedule.py +++ b/handcraft/mbart/schedule.py @@ -1,10 +1,8 @@ -from turtle import forward from typing import List, Tuple import torch from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary -import cube.runtime.adapter.collectives as coll from cube.runtime.device import DeviceGroup io_input = input @@ -206,7 +204,7 @@ def schedule_naive(model, dataloader, num_microbatch: int, neighbors: Tuple[int, is_last_stage = rank > next_rank for step in range(num_microbatch): - model.set_inputs(*next(dataloader)) + model.set_inputs(next(dataloader)) # print(f'rank {rank} recving forward input...') inputs = () if is_first_stage else recv_forward(model, prev_rank) # forward @@ -243,58 +241,56 @@ def schedule_tp_1f1b_pack(model: torch.nn.Module, # FIXME: only work for pure pipeline is_first_decoder_stage = (rank == num_stage // 2) is_last_stage = rank > next_rank - last_stage = torch.distributed.get_world_size() - 1 input_tensors = list() output_tensors = list() - input_head_tensors = list() - output_head_tensors = list() + input_encoder_tensors = list() + output_encoder_tensors = list() + input_decoder_tensors = list() + output_decoder_tensors = list() - def tp_head_forward() -> torch.Tensor: - src_tokens, prev_output_tokens = next(dataloader) - model.set_inputs(*(src_tokens, prev_output_tokens)) + def tp_encoder_preprocess() -> torch.Tensor: + tokens = next(dataloader) + model.set_inputs(tokens) enc = model.forward_encoder_preprocess(dst=0)[0] - dec = model.forward_decoder_preprocess(dst=num_stage // 2)[0] - input_head_tensors.append((src_tokens, prev_output_tokens)) - output_head_tensors.append((enc, dec)) + input_encoder_tensors.append((tokens,)) + output_encoder_tensors.append((enc,)) enc = enc.detach().requires_grad_() - dec = dec.detach().requires_grad_() - # FIXME: this will change decoder input if is_first_stage: model.set_preprocess(enc=enc) - if is_first_decoder_stage: - model.set_preprocess(dec=dec) - if is_first_stage: return (enc,) + return () + + def tp_decoder_preprocess() -> torch.Tensor: + tokens = next(dataloader) + model.set_inputs(tokens) + dec = model.forward_decoder_preprocess(dst=num_stage // 2)[0] + input_decoder_tensors.append((tokens,)) + output_decoder_tensors.append((dec,)) + dec = dec.detach().requires_grad_() if is_first_decoder_stage: + model.set_preprocess(dec=dec) return (dec,) - else: - return () + return () - def tp_head_backward(grads: Tuple[torch.Tensor]): - inputs_head, outputs_head = input_head_tensors.pop(0), output_head_tensors.pop(0) + def tp_encoder_backward(grads: Tuple[torch.Tensor]): + inputs_head, outputs_head = input_encoder_tensors.pop(0), output_encoder_tensors.pop(0) # encoder backward - enc, dec = outputs_head + enc = outputs_head[0] if not is_first_stage: grads = (torch.empty_like(enc),) # decoder backward backward_step((), (enc,), grads) - #FIXME: grads is using enc gradient!!! + + def tp_decoder_backward(grads: Tuple[torch.Tensor]): + inputs_head, outputs_head = input_decoder_tensors.pop(0), output_decoder_tensors.pop(0) + # decoder backward + dec = outputs_head[0] if not is_first_decoder_stage: grads = (torch.empty_like(dec),) backward_step((), (dec,), grads) - def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): - dec = None - if is_last_stage: - assert len(outputs) == 1 - dec = outputs[0] - dec = dec.detach().requires_grad_() - loss = model.forward_postprocess(dec, src=last_stage) - grads = backward_step((dec,), loss, (None,)) - return grads - fofst = [-(step // 2) for step in range(num_stage)] bofst = [-(num_stage - 1 - (step // 2)) for step in range(num_stage)] # print(fofst) @@ -303,20 +299,26 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): bofst = bofst[rank] last_backward = (None,) last_forward = (None,) - tail_grads = (None,) for step in range(num_microbatch + num_stage - 1): torch.distributed.barrier() # print_each_rank(f'=========begin rank {rank}=========') fmid, bmid = step + fofst, step + bofst + decoder_fmid = step - num_stage // 2 // 2 + encoder_bmid = step + 1 - num_stage // 2 * 2 + decoder_bmid = step + 1 - int(num_stage // 2 * 1.5) do_backward = 0 <= bmid and bmid <= num_microbatch - 1 do_forward = 0 <= fmid and fmid <= num_microbatch - 1 - # step1: tp forward + # step1: tp encoder forward if 0 <= step and step <= num_microbatch - 1: # print(f'rank {rank} forward tp model ') - inputs = tp_head_forward() + inputs = tp_encoder_preprocess() + + # step2: tp decoder forward + if 0 <= decoder_fmid and decoder_fmid <= num_microbatch - 1: + tp_decoder_preprocess() - # forward + backward + # step 3: forward + backward if rank % 2 == 0: # inter-barrier if is_first_stage: @@ -325,18 +327,12 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): if do_forward and last_backward != (None,): # print(f'rank {rank} send backward grad + recv forward output ') inputs = send_backward_recv_forward(last_backward, model, prev_rank) - # input = coll.sendrecv( - # [input_grad], [model.input_shape()], [model.input_dtype()], - # [prev_rank], [prev_rank] - # )[0] elif do_forward: # print(f'rank {rank} recv forward output ') inputs = recv_forward(model, prev_rank) - # input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) elif last_backward != (None,): # print(f'rank {rank} send backward grad ') send_backward(last_backward, prev_rank) - # coll.send(last_backward, prev_rank) # forward if do_forward: @@ -355,18 +351,12 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): # send forward recv backward # print(f'rank {rank} recv backward grad + send forward output ') output_grads = send_forward_recv_backward(outputs, model, next_rank) - # output_grads = coll.sendrecv( - # [output], [output.size()], [output.dtype], - # [next_rank], [next_rank] - # )[0] elif do_forward and not is_last_stage: # print(f'rank {rank} send forward output ') send_forward(outputs, next_rank) - # coll.send(output, next_rank) elif do_backward and not is_last_stage: # print(f'rank {rank} recv backward grad ') output_grads = recv_backward(model, next_rank) - # output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) # backward last_backward = (None,) @@ -379,23 +369,17 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): if rank % 2 == 1: # inter-barrier if is_last_stage: - output_grads = tail_grads + output_grads = (None,) else: if do_backward and last_forward != (None,): # print(f'rank {rank} recv backward grad + send forward output ') output_grads = send_forward_recv_backward(last_forward, model, next_rank) - # output_grad = coll.sendrecv( - # [last_forward], [model.output_shape()], [model.output_dtype()], - # [next_rank], [next_rank] - # )[0] elif do_backward: # print(f'rank {rank} recv backward grad ') output_grads = recv_backward(model, next_rank) - # output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) elif last_forward != (None,): # print(f'rank {rank} send forward output ') send_forward(last_forward, next_rank) - # coll.send(last_forward, next_rank) # backward last_backward = (None,) @@ -409,18 +393,12 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): if do_backward and do_forward: # print(f'rank {rank} send backward grad + recv forward output ') inputs = send_backward_recv_forward(input_grads, model, prev_rank) - # input = coll.sendrecv( - # [input_grad], [model.input_shape()], [model.input_dtype()], - # [prev_rank], [prev_rank] - # )[0] elif do_backward: # print(f'rank {rank} send backward grad ') send_backward(input_grads, prev_rank) - # coll.send(input_grad, prev_rank) elif do_forward: # print(f'rank {rank} recv forward output ') inputs = recv_forward(model, prev_rank) - # input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) # forward last_forward = (None,) @@ -432,14 +410,13 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): last_forward = outputs # tp tail forward-backward - last_stage_mid = step - (num_stage - 1) // 2 - if 0 <= last_stage_mid and last_stage_mid <= num_microbatch - 1: - tail_grads = tp_tail_forward_backward(last_forward) + if 0 <= decoder_bmid and decoder_bmid <= num_microbatch - 1: + # FIXME: currently use encoder grad + tp_decoder_backward(last_backward) # step 4: tp encoder and decoder backward - encoder_mid = step + 1 - num_stage - if 0 <= encoder_mid and encoder_mid <= num_microbatch - 1: - tp_head_backward(last_backward) + if 0 <= encoder_bmid and encoder_bmid <= num_microbatch - 1: + tp_encoder_backward(last_backward) # memory_summary() # if rank == 0: @@ -449,8 +426,10 @@ def tp_tail_forward_backward(outputs: Tuple[torch.Tensor]): assert len(input_tensors) == 0 assert len(output_tensors) == 0 - assert len(input_head_tensors) == 0 - assert len(output_head_tensors) == 0 + assert len(input_encoder_tensors) == 0 + assert len(output_encoder_tensors) == 0 + assert len(input_decoder_tensors) == 0 + assert len(output_decoder_tensors) == 0 # print_each_rank(f'=========end rank {rank}=========') @@ -471,7 +450,7 @@ def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors): # warmup for i in range(num_warmup_microbatches): - model.set_inputs(*next(dataloader)) + model.set_inputs(next(dataloader)) # recv forward inputs = () if is_first_stage else recv_forward(model, prev_rank) # forward From 8e37bab86cf0fbffd4a1aa8d634337ef257687c7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Mar 2022 14:52:41 +0800 Subject: [PATCH 0664/1892] fix naive bug --- handcraft/mbart/mbart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index df3dc61c..887a19ea 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -751,7 +751,7 @@ def reduce_embed(model, pp_embed_group): # create embed group: first encoder, first decoder, last stage # FIXME: only work for tp_size = 1 if args.use_naive or args.use_1f1b: - embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2], pp_ranks[-1]] + embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2]] embed_ranks = list(set(embed_ranks)) _pp_embed_group = DeviceGroup().get_group(embed_ranks) @@ -796,6 +796,7 @@ def reduce_embed(model, pp_embed_group): ) if step == 0: print('passed 1st iteration') + memory_summary() optimizer.step() optimizer.zero_grad() if step >= 10: From f74010d030ba41b9cc7c66919a906c07bf7e7526 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Mar 2022 19:04:37 +0800 Subject: [PATCH 0665/1892] fix return none bug --- cube/codegen/codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 616579bb..de35f9e2 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -477,7 +477,7 @@ def tuple_naming(self, tensors: List[Any]) -> str: def return_naming(self, tensors: List[Any]) -> str: tensors = [self.tensor_naming(t) for t in tensors] if len(tensors) == 0: - tensors = '_' + tensors = '' else: tensors = ', '.join(tensors) return tensors From 86020026abcd394929e698484c4a7873b69833f1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Mar 2022 20:33:37 +0800 Subject: [PATCH 0666/1892] hybrid benchmark --- handcraft/mbart/mbart.py | 2 +- handcraft/mbart/mbart_hybrid.py | 730 ++++++++++++++++++++++++++++++++ handcraft/mbart/run.sh | 23 +- 3 files changed, 750 insertions(+), 5 deletions(-) create mode 100644 handcraft/mbart/mbart_hybrid.py diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 887a19ea..9fa0e150 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - handcraft/mbart/mbart.py --pp-size 4 --tp-size 1 --nmb 4 + handcraft/mbart/mbart_hybrid.py --pp-size 4 --tp-size 1 --nmb 4 """ from typing import Optional diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py new file mode 100644 index 00000000..78ead61e --- /dev/null +++ b/handcraft/mbart/mbart_hybrid.py @@ -0,0 +1,730 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + handcraft/mbart/mbart.py --pp-size 4 --tp-size 1 --nmb 4 +""" + +from typing import Optional +import argparse +import math +import torch + +import cube +from cube.runtime.device import DeviceGroup +from cube.runtime.syndata import SynTextDataLoader +from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary, model_summary +from cube.profiler.timer import print_each_rank + +from handcraft.mbart.schedule import schedule_naive, schedule_1f1b, schedule_tp_1f1b_pack +from handcraft.mbart.tp import AllGatherScatter, AllReduceIdentity, BroadcastReduce, IdentityAllreduce, ReduceBroadcast + +_tp_group = -1 +_pp_group = -1 +_pp_embed_group = -1 +_pp_next_rank = None +_pp_prev_rank = None + +# fairseq task +# translation_from_pretrained_bart + +# fairseq criterion +# label_smoothed_cross_entropy, --label_smoothing = 0.2 + +class Config: + + num_embeddings = 250027 + + encoder_embed_path = None + encoder_embed_dim = 1024 + encoder_ffn_embed_dim = 4 * 1024 + encoder_layers = 12 + encoder_attention_heads = 16 + encoder_normalize_before = True + encoder_learned_pos = True + + decoder_embed_path = None + decoder_embed_dim = 1024 + decoder_ffn_embed_dim = 4 * 1024 + decoder_layers = 12 + decoder_attention_heads = 16 + decoder_normalize_before = True + decoder_learned_pos = True + cross_self_attention = False + no_cross_attention = False + + attention_dropout = 0.0 + activation_dropout = 0.0 + dropout = 0.1 + + max_target_positions = 1024 + max_source_positions = 1024 + adaptive_softmax_cutoff = None + adaptive_softmax_dropout = 0 + + share_decoder_input_output_embed = True + share_all_embeddings = True + + decoder_output_dim = 1024 # same with decorder_embed_dim + decoder_input_dim = 1024 # same with decorder_embed_dim + + no_scale_embedding = False # True in bart large + layernorm_embedding = True + activation_fn = 'gelu' + pooler_activation_fn = 'tanh' + pooler_dropout = 0.0 + + # classification task + num_classes = 2 + + +def attn_fn(query: torch.Tensor, key: torch.Tensor, + wq: torch.Tensor, wq_bias: Optional[torch.Tensor], + wk: torch.Tensor, wk_bias: Optional[torch.Tensor], + wv: torch.Tensor, wv_bias: Optional[torch.Tensor], + wout: torch.Tensor, wout_bias: Optional[torch.Tensor], + h: int, scale: float, dropout: float, mask=True): + """ + query, key: (L, N, E) = (seqlen, batch size, embed_dim) + wq, wk, wv weight: [(num_head * dim_head), E] + dropout: float + h: int: number of heads + """ + num_head = h + L, N = query.size(0), query.size(1) + dim_head = wq.size(0) // num_head + + q = torch.nn.functional.linear(query, wq, wq_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(key, wk, wk_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(key, wv, wv_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + if mask: # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, wout, wout_bias) # L N (h d), E E -> L N E + return output + + +class MultiheadAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, dropout=0.0, bias=True): + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + self.qdim = embed_dim + self.kdim = embed_dim + self.vdim = embed_dim + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.num_heads = num_heads + self.dropout_p = dropout + # K + self.k_proj = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size, self.kdim)) + if bias: + self.k_bias = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size)) + else: + self.k_bias = None + # V + self.v_proj = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size, self.vdim)) + if bias: + self.v_bias = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size)) + else: + self.v_bias = None + # Q + self.q_proj = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size, self.qdim)) + if bias: + self.q_bias = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size)) + else: + self.q_bias = None + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim // self.tp_size)) + if bias: + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.out_bias = None + + def forward(self, query: torch.Tensor, key: torch.Tensor): + if key is not query: + key = IdentityAllreduce.apply(key, self.tp_group) + query = IdentityAllreduce.apply(query, self.tp_group) + attn = attn_fn(query, key, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) + attn = AllReduceIdentity.apply(attn, self.tp_group) + return attn + + +class EncoderLayer(torch.nn.Module): + + def __init__(self, cfg: Config): + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + self.cfg = cfg + self.self_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.encoder_attention_heads, cfg.attention_dropout) + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.dropout = torch.nn.Dropout(p=cfg.dropout) + self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + self.fc1 = torch.nn.Linear(cfg.encoder_embed_dim, cfg.encoder_ffn_embed_dim // self.tp_size) + self.fc2 = torch.nn.Linear(cfg.encoder_ffn_embed_dim // self.tp_size, cfg.encoder_embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + def input_shape(self): + # L, N, E + return (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + + def output_shape(self): + # L N E + return (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + + def input_dtype(self): + return torch.float32 + + def output_dtype(self): + return torch.float32 + + def forward(self, x): # , encoder_padding_mask: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None): + # print(f'encoder layer: x: {x.size()}') + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x, x) + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + if self.tp_size > 1: + x = IdentityAllreduce.apply(x, self.tp_group) + x = self.fc1(x) + x = torch.nn.functional.gelu(x) + x = self.activation_dropout(x) + x = self.fc2(x) + if self.tp_size > 1: + x = AllReduceIdentity.apply(x, self.tp_group) + + x = self.dropout(x) + x = x + residual + return x + + +class DecoderLayer(torch.nn.Module): + + def __init__(self, cfg: Config): + + super().__init__() + + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + self.cfg = cfg + self.dropout = torch.nn.Dropout(p=cfg.dropout) + self.self_attn = MultiheadAttention(cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) + self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + # encoder atten + self.encoder_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) + self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + self.fc1 = torch.nn.Linear(cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim // self.tp_size) + self.fc2 = torch.nn.Linear(cfg.decoder_ffn_embed_dim // self.tp_size, cfg.decoder_embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + def input_shape(self): + return ( + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + ) + + def output_shape(self): + return ( + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + ) + + def input_dtype(self): + return (torch.float32, torch.float32) + + def output_dtype(self): + return (torch.float32, torch.float32) + + def forward(self, x, encoder_out): # encoder_padding_mask): + # print(f'decoder layer: x: {x.size()}, encoder_out: {encoder_out.size()}') + residual = x + # normalize before + x = self.self_attn_layer_norm(x) + + # self attention + x = self.self_attn(x, x) + x = self.dropout(x) + x = residual + x + + # encoder attn + residual = x + # normalize before + x = self.encoder_attn_layer_norm(x) + x = self.encoder_attn(x, encoder_out) + x = self.dropout(x) + x = x + residual + + residual = x + # normalize before + x = self.final_layer_norm(x) + if self.tp_size > 1: + x = IdentityAllreduce.apply(x, self.tp_group) + x = self.fc1(x) + x = torch.nn.functional.gelu(x) + x = self.activation_dropout(x) + x = self.fc2(x) + if self.tp_size > 1: + x = AllReduceIdentity.apply(x, self.tp_group) + x = self.dropout(x) + x = x + residual + return x, encoder_out + + +class MBartClassificationHead(torch.nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + self.num_classes = num_classes + self.dense = torch.nn.Linear(input_dim, inner_dim // self.tp_size) + self.dropout = torch.nn.Dropout(p=pooler_dropout) + self.out_proj = torch.nn.Linear(inner_dim // self.tp_size, num_classes) + self.loss_fct = torch.nn.CrossEntropyLoss() + + def forward(self, dec: torch.Tensor, labels): + # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] + dec = dec.transpose(0, 1)[:,-1,:] + sentence_represent = dec + hidden_states = self.dropout(sentence_represent) + if self.tp_size > 1: + hidden_states = IdentityAllreduce.apply(hidden_states, self.tp_group) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + logits = self.out_proj(hidden_states) + if self.tp_size > 1: + logits = AllReduceIdentity.apply(logits, self.tp_group) + loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) + return (loss,) + + +class ShardHeadTail(torch.nn.Module): + + def __init__(self, cfg: Config): + """ + group = -1 means no tensor parallelism + """ + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + self.tp_idx = 0 if _tp_group == -1 else torch.distributed.get_rank(_tp_group) + + self.cfg = cfg + if self.tp_size > 0: + print(f'[{torch.distributed.get_rank()}]: initialize sharding embed (x{self.tp_size})') + + self.vocab_start_index = self.cfg.num_embeddings // self.tp_size * self.tp_idx + self.vocab_end_index = self.cfg.num_embeddings // self.tp_size * (self.tp_idx + 1) + self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.tp_size, self.cfg.encoder_embed_dim))) + + # encoder-preprocess + self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) + self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) + self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + # decoder-preprocess + self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) + self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) + self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + self._inputs = (None, ) + + def set_inputs(self, *inputs): + self._inputs = inputs + + def embed_lookup(self, tokens): + if self.tp_size > 1: + mask = (tokens < self.vocab_start_index) | \ + (tokens >= self.vocab_end_index) + tokens = tokens.clone() - self.vocab_start_index + tokens[mask] = 0 + embed = torch.nn.functional.embedding(tokens, self.weight) + embed[mask, :] = 0.0 + embed = AllReduceIdentity.apply(embed, self.tp_group) + else: + embed = torch.nn.functional.embedding(tokens, self.weight) + return embed + + def encoder_preprocess(self): + source_tokens = self._inputs[0] + source_embed = self.embed_lookup(source_tokens) + embed = self.embed_scale_encoder * source_embed + x = embed + self.embed_positions_encoder.weight + x = self.layernorm_embedding_encoder(x) + x = torch.nn.functional.dropout(x, p=0.0) + enc = x.transpose(0, 1) + return (enc,) + + def decoder_preprocess(self): + prev_output_tokens = self._inputs[0] + target_emb = self.embed_lookup(prev_output_tokens) + embed = self.embed_scale_decoder * target_emb + embed = embed + self.embed_positions_decoder.weight + embed = self.layernorm_embedding_decoder(embed) + embed = torch.nn.functional.dropout(embed, p=0.0) + dec = embed.transpose(0, 1) + return (dec,) + + +class mBARTFull(torch.nn.Module): + + def __init__(self, cfg: Config, + encoder_preprocess=True, + decoder_preprocess=True, + post_process=True): + super().__init__() + self.cfg = cfg + self.dummy_labels = torch.tensor([1]).cuda() + self._preprocess = [None, None] # enc, dec + + self.rank = DeviceGroup().rank + + global _pp_group + self.pp_group = _pp_group + self.total_layers = cfg.encoder_layers + cfg.decoder_layers + + self.pp_stage = torch.distributed.get_rank(_pp_group) + self.num_stages = torch.distributed.get_world_size(_pp_group) + + self.layer_start = self.total_layers // self.num_stages * self.pp_stage + self.layer_end = self.total_layers // self.num_stages * (self.pp_stage + 1) + + self.encoder_preprocess = encoder_preprocess + self.encoder_forward = (self.layer_start < cfg.encoder_layers) + + self.decoder_preprocess = decoder_preprocess + self.decoder_first_stage = self.layer_start == cfg.encoder_layers + self.decoder_forward = (self.layer_end > cfg.encoder_layers) + self.decoder_last_stage = (self.layer_end == cfg.encoder_layers + cfg.decoder_layers) + + self.postprocess = post_process + + self.encoder_layer_start = self.layer_start + self.encoder_layer_end = min(self.layer_end, cfg.encoder_layers) + + self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) + self.decoder_layer_end = self.layer_end + + if encoder_preprocess or decoder_preprocess or post_process: + self.headtail = ShardHeadTail(cfg) + else: + self.headtail = None + + # encoders + if self.encoder_forward: + print(f'[{self.rank}]: initializing {self.encoder_layer_end - self.encoder_layer_start} encoder layers') + self.encoders = torch.nn.ModuleList([EncoderLayer(cfg) for _ in range(self.encoder_layer_end - self.encoder_layer_start)]) + if self.encoder_layer_end == cfg.encoder_layers: + self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + else: + self.layer_norm_encoder = None + + # decoders + if self.decoder_forward: + print(f'[{self.rank}]: initializing {self.decoder_layer_end - self.decoder_layer_start} decoder layers') + self.decoders = torch.nn.ModuleList([DecoderLayer(cfg) for _ in range(self.decoder_layer_end - self.decoder_layer_start)]) + if self.decoder_layer_end == cfg.encoder_layers + cfg.decoder_layers: + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + else: + self.layer_norm_decoder = None + + # postpross + if self.postprocess: + print(f'[{self.rank}]: will compute loss') + self.head = MBartClassificationHead(cfg.decoder_embed_dim, 1024, cfg.num_classes, 0.0) + + def input_shape(self): + if self.encoder_preprocess: + return () + elif self.encoder_forward: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + ) + elif self.decoder_preprocess: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + ) + elif self.decoder_first_stage: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + ) + elif self.decoder_forward: + return ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + ) + elif self.postprocess: + return ((1,),) + assert False + + def output_shape(self): + shape = None + if self.encoder_preprocess or self.encoder_forward: + shape = (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + # decoder preprocess is not allowed to be a single stage + if self.decoder_preprocess or self.decoder_forward: + shape = ( + (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + ) + if self.postprocess: + shape = ( + (1,), + ) + assert shape is not None + return shape + + def input_dtype(self): + if self.encoder_preprocess: + return () + elif self.encoder_forward: + return (torch.float32,) + elif self.decoder_preprocess: + return (torch.float32,) + elif self.decoder_forward: + return (torch.float32, torch.float32) + else: + assert False + + def output_dtype(self): + dtype = None + if self.encoder_preprocess or self.encoder_forward: + dtype = (torch.float32,) + if self.decoder_preprocess or self.decoder_forward: + if self.pp_stage == self.num_stages - 1: + dtype = (torch.float32,) + else: + dtype = (torch.float32, torch.float32) + if self.postprocess: + dtype = ((torch.float32,),) + assert dtype is not None + return dtype + + def set_inputs(self, *inputs): + assert len(inputs) == 1 + if self.headtail is not None: + self.headtail.set_inputs(*inputs) + + def set_preprocess(self, enc=None, dec=None): + if enc is not None: + self._preprocess[0] = enc + if dec is not None: + self._preprocess[1] = dec + + def forward_encoder_preprocess(self): + return self.headtail.encoder_preprocess() + + def forward_decoder_preprocess(self): + return self.headtail.decoder_preprocess() + + def forward_postprocess(self, dec): + return self.head(dec, self.dummy_labels) + + def forward(self, enc=None, dec=None): + """ + enc: encoder input/output + dec: decoder output/input + """ + pre_enc, pre_dec = self._preprocess + enc = pre_enc if enc is None else enc + dec = pre_dec if dec is None else dec + + # encoder preprocess + if self.encoder_preprocess: + output = self.forward_encoder_preprocess() + enc = output[0] + + # forward encoder + if self.encoder_forward: + for layer in self.encoders: + enc = layer(enc) # encoder_padding_mask if has_pads else None) + if self.layer_norm_encoder is not None: + enc = self.layer_norm_encoder(enc) + output = (enc,) + + # decoder preprocess + if self.decoder_preprocess: + output = self.forward_decoder_preprocess() + dec = output[0] + + # forward decoder + if self.decoder_forward: + dec = pre_dec if dec is None else dec + for layer in self.decoders: + dec, enc = layer(dec, enc) + if self.layer_norm_decoder is not None: + dec = self.layer_norm_decoder(dec) + output = (dec,) + else: + output = (enc, dec) + + # postprocess + if self.postprocess: + output = self.forward_postprocess(dec) + loss = output[0] + + return output + + +def reduce_embed(model, pp_embed_group): + """ + Embedding gradients needs to be reduced across pipeline stages + """ + if isinstance(model.headtail, torch.nn.Module): + grad = model.headtail.weight.grad + else: + grad = None + if grad is not None: + torch.distributed.all_reduce(grad, group=pp_embed_group) + torch.cuda.synchronize() + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='swin') + parser.add_argument('--nmb', type=int, default=4, + help='num of micro batch') + parser.add_argument('--pp-size', type=int, default=1, + help='use pipeline parallelism') + parser.add_argument('--tp-size', type=int, default=1, + help='use tensor parallelism') + args = parser.parse_args() + + print(args) + + cube.init() + pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) + print_each_rank(f'my pp ranks: {pp_ranks}') + print_each_rank(f'my tp ranks: {tp_ranks}') + + if _tp_group == -1: + _tp_group = DeviceGroup().get_group(tp_ranks) + + if _pp_group == -1: + _pp_group = DeviceGroup().get_group(pp_ranks) + idx = pp_ranks.index(DeviceGroup().rank) + _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] + _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] + is_first_stage = idx == 0 + is_first_decoder_stage = idx == len(pp_ranks) // 2 + is_last_stage = idx == len(pp_ranks) - 1 + + if len(pp_ranks) > 1: + pranks = [torch.zeros((args.pp_size,), dtype=torch.int, device=torch.cuda.current_device()) for _ in range(args.tp_size)] + prank = torch.tensor(pp_ranks, dtype=torch.int).cuda() + pranks[torch.distributed.get_rank(_pp_group)] = prank + torch.distributed.all_gather(pranks, prank, group=_tp_group) + torch.cuda.synchronize() + print_each_rank(f'allgather-pp ranks: {pranks}') + + for prank in pranks: + prank = prank.tolist() + embed_ranks = [prank[0], prank[len(prank) // 2]] + embed_ranks = list(set(embed_ranks)) + group = DeviceGroup().get_group(embed_ranks) + if torch.distributed.get_rank(_tp_group) in prank: + print(f'embedding group: {embed_ranks}') + _pp_embed_group = group + assert _pp_embed_group != -1 + + cfg = Config() + dataloader = SynTextDataLoader( + shapes=( + [1, cfg.max_source_positions], + ), + dtypes=(torch.int64,), + batch_dims=(0,) + ) + dataloader = iter(dataloader) + + + if args.pp_size > 1: + encoder_preprocess = is_first_stage + decoder_preprocess = is_first_decoder_stage + postprocess = is_last_stage + model = mBARTFull(cfg, encoder_preprocess, decoder_preprocess, postprocess).cuda() + else: + model = mBARTFull(cfg, True, True, True).cuda() + + print_each_rank('model weight consumpition:') + memory_summary() + + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + CudaTimer(enable=False).warmup() + iter_num = 32 + for step in range(iter_num): + if step >= 10: + CudaTimer(enable=True).start('e2e') + if args.pp_size > 1: + schedule_naive(model, dataloader, args.nmb, (_pp_prev_rank, _pp_next_rank)) + reduce_embed(model, _pp_embed_group) + else: + for _ in range(args.nmb): + model.set_inputs(next(dataloader)) + loss = model()[0] + loss.backward() + if step == 0: + print('passed 1st iteration') + memory_summary() + optimizer.step() + optimizer.zero_grad() + if step >= 10: + CudaTimer().stop('e2e') + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-10, field_name='e2e'))) + memory_summary() diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh index c8fbe7ee..bd15c737 100755 --- a/handcraft/mbart/run.sh +++ b/handcraft/mbart/run.sh @@ -1,18 +1,33 @@ # 4 gpus +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 64 > 4dev64nmb-tp1f1b-pack.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-naive --nmb 64 > 4dev64nmb-naive.txt + OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 64 > 4dev64nmb-tp1f1b-pack.txt + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 64 > 4dev64nmb-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-naive --nmb 64 > 4dev64nmb-naive.txt + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > 4dev64nmb-tp2pp2.txt # OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ # handcraft/pipeline/dummy_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > 4dev64nmb-2tp2pp.txt # 8 gpus +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 128 > 8dev128nmb-tp1f1b-pack.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-naive --nmb 128 > 8dev128nmb-naive.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 128 > 8dev128nmb-tp.txt + OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 128 > 8dev128nmb-tp1f1b-pack.txt + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 128 > 8dev128nmb-tp4pp2.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-naive --nmb 128 > 8dev128nmb-naive.txt + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 128 > 8dev128nmb-tp2pp4.txt \ No newline at end of file From 4ef902d522dbc6cb25ad78002169dd90f16f7726 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Mar 2022 20:51:14 +0800 Subject: [PATCH 0667/1892] fix dataloader bug --- handcraft/mbart/mbart_hybrid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 78ead61e..517658c4 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -686,7 +686,6 @@ def reduce_embed(model, pp_embed_group): dtypes=(torch.int64,), batch_dims=(0,) ) - dataloader = iter(dataloader) if args.pp_size > 1: @@ -708,11 +707,12 @@ def reduce_embed(model, pp_embed_group): if step >= 10: CudaTimer(enable=True).start('e2e') if args.pp_size > 1: - schedule_naive(model, dataloader, args.nmb, (_pp_prev_rank, _pp_next_rank)) + schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) reduce_embed(model, _pp_embed_group) else: + loader = iter(dataloader) for _ in range(args.nmb): - model.set_inputs(next(dataloader)) + model.set_inputs(next(loader)) loss = model()[0] loss.backward() if step == 0: From 6b2e0f235a88ade223f3f299f2f4c903517e8ae5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 18 Mar 2022 21:21:59 +0800 Subject: [PATCH 0668/1892] add mbart model --- cube/profiler/memory.py | 5 +- examples/feedforward/ffn.py | 112 ------ examples/feedforward/policy/data.py | 25 -- examples/feedforward/policy/tensor.py | 32 -- examples/gpt/gpt.py | 265 -------------- examples/gpt/policy/megatron.md | 90 ----- examples/gpt/policy/megatron_parallel.py | 138 -------- examples/mbart/mbart.py | 425 ----------------------- examples/nlp/blocks/attention.py | 162 +++++++++ examples/nlp/blocks/decoder.py | 43 +++ examples/nlp/blocks/encoder.py | 29 ++ examples/nlp/blocks/mlp.py | 32 ++ examples/nlp/gpt/model.py | 101 ++++++ examples/nlp/gpt/train.py | 69 ++++ examples/nlp/mbart/model.py | 201 +++++++++++ examples/nlp/mbart/train.py | 69 ++++ 16 files changed, 710 insertions(+), 1088 deletions(-) delete mode 100644 examples/feedforward/ffn.py delete mode 100644 examples/feedforward/policy/data.py delete mode 100644 examples/feedforward/policy/tensor.py delete mode 100644 examples/gpt/gpt.py delete mode 100644 examples/gpt/policy/megatron.md delete mode 100644 examples/gpt/policy/megatron_parallel.py delete mode 100644 examples/mbart/mbart.py create mode 100644 examples/nlp/blocks/attention.py create mode 100644 examples/nlp/blocks/decoder.py create mode 100644 examples/nlp/blocks/encoder.py create mode 100644 examples/nlp/blocks/mlp.py create mode 100644 examples/nlp/gpt/model.py create mode 100644 examples/nlp/gpt/train.py create mode 100644 examples/nlp/mbart/model.py create mode 100644 examples/nlp/mbart/train.py diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index 263edb52..e9cb8a13 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -29,7 +29,10 @@ def model_summary(model: torch.nn.Module, inputs: List[Any], do_eval=False, max_ static_memory = torch.cuda.memory_allocated() print_each_rank( 'static model: {:,.2f} MB'.format(static_memory / 1024 / 1024), rank_only=0) - + nparams = sum([param.numel() for param in model.parameters()]) + print_each_rank( + 'model paramters: {:,.2f} M'.format(nparams / 1000000), rank_only=0) + stat = dict(depth=0) def before_forward(module, input): module._summary_depth = stat['depth'] diff --git a/examples/feedforward/ffn.py b/examples/feedforward/ffn.py deleted file mode 100644 index 794cffba..00000000 --- a/examples/feedforward/ffn.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/feedforward/ffn.py - -OMP_NUM_THREADS=4 torchrun --standalone \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/feedforward/ffn.py - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --rdzv_id=888 \ - --rdzv_backend=c10d \ - --rdzv_endpoint=worker0:8004 \ - examples/feedforward/ffn.py -""" - -import torch -import torch.nn.functional as F - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - -from examples.feedforward.policy.data import PAS - - -class FFN(torch.nn.Module): - - def __init__(self, hidden_size: int): - super().__init__() - self.dense_h_to_4h = torch.nn.Linear( - hidden_size, 4 * hidden_size - ) - self.dense_4h_to_h = torch.nn.Linear( - 4 * hidden_size, hidden_size - ) - - def forward(self, hidden_states): - # [L, N, E] * [E, 4E] -> [L, N, 4E] - out = self.dense_h_to_4h(hidden_states) - # [L, N, 4E] -> [L, N, 4E] - out = F.gelu(out) - # [L, N, 4E] * [4E, E] -> [L, N, E] - out = self.dense_4h_to_h(out) - - loss = torch.sum(out) - return loss - - -def train(): - L = 512 # seq len - N = 32 # batch size - # configs: [hidden size, num_head] - # E, num_head = [1536, 16] # 1.2B model - # E, num_head = [1920, 20] # 2.5B model - # E, num_head = [2304, 24] # 4.2B model - E, num_head = [3072, 32] # 8.7B model - - - model = FFN(hidden_size=E) - model = cube.SemanticModel( - model, input_shapes=([L, N, E],), - ) - - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([L, N, E],), - dtypes=(torch.float32,), - batch_dims=(1,) - ) - - @cube.compile(model, dataloader, PAS=PAS) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - model = model.get_gen_module() - - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-40) - - -if __name__ == '__main__': - - cube.init() - train() diff --git a/examples/feedforward/policy/data.py b/examples/feedforward/policy/data.py deleted file mode 100644 index 80200b62..00000000 --- a/examples/feedforward/policy/data.py +++ /dev/null @@ -1,25 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.operator import IRFwOperation, IRDataOperation - - -def PAS(graph: IRGraph, resource): - """ - Data Parallel - """ - # data operation - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(num=resource.ngpus)) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - batch_dim = node.get_batch_dims()[0] - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=batch_dim, num=resource.ngpus) - ) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - return graph diff --git a/examples/feedforward/policy/tensor.py b/examples/feedforward/policy/tensor.py deleted file mode 100644 index 42e22fcd..00000000 --- a/examples/feedforward/policy/tensor.py +++ /dev/null @@ -1,32 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.operator import IRFwOperation, IRDataOperation - - -def PAS(graph: IRGraph, resource): - """ - Hybrid parallel - """ - # data operation replication - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - # forward operation - configs = [ - dict(idx=1, dim=0, num=resource.ngpus), # linear col - dict(idx=0, dim=-1, num=resource.ngpus), # gelu col - dict(idx=0, dim=-1, num=resource.ngpus), # linear row - dict(idx=0, dim=-1, num=resource.ngpus), # sum - ] - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - assert len(fnodes) == len(configs) - for fnode, config in zip(fnodes, configs): - algo = fnode.algorithms('dim') - sub_nodes = graph.partition( - fnode, algo, config=config - ) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - return graph diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py deleted file mode 100644 index 72ae4d1c..00000000 --- a/examples/gpt/gpt.py +++ /dev/null @@ -1,265 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/gpt/gpt.py -""" - -import torch -from torch import nn -import torch.nn.functional as F -import cube - - -from examples.gpt.policy.megatron_parallel import transform_policy -from examples.gpt.policy.megatron_parallel import schedule_policy - -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary - - -class MultiHeadSelfAttention(nn.Module): - - def __init__(self, embed_dim, heads, dropout): - super().__init__() - - self.num_head = heads - self.dim_head = embed_dim // heads - self.dropout = dropout - - self.weight_qkv = torch.nn.Parameter(torch.empty( - 3 * embed_dim, embed_dim - )) - self.weight_out = torch.nn.Parameter(torch.empty( - embed_dim, embed_dim - )) - - def forward(self, x): - """ - Multi-Head Self-Attention. - - L: sequence length - N: batch size - E: embedding size - - Inputs: - hidden_state: [L, N, E] - w_qkv : [3 * num_head * dim_head, E] - w_out : [E, E] - - Outputs: - hidden_state: [L, N, E] - """ - - hidden_state = cube.runtime.function.complex.self_attn( - x, self.weight_qkv, self.weight_out, - self.num_head, self.dim_head, self.dropout - ) - return hidden_state - - -class FFN(torch.nn.Module): - - def __init__(self, hidden_size: int): - super().__init__() - self.proj1_weight = torch.nn.Parameter( - torch.empty(4 * hidden_size, hidden_size) - ) - self.proj1_bias = torch.nn.Parameter( - torch.empty(4 * hidden_size) - ) - self.proj2_weight = torch.nn.Parameter( - torch.empty(hidden_size, 4 * hidden_size) - ) - self.proj2_bias = torch.nn.Parameter( - torch.empty(hidden_size) - ) - - def forward(self, hidden_states): - hidden_states = cube.runtime.function.complex.feedforward( - hidden_states, - self.proj1_weight, self.proj1_bias, - self.proj2_weight, self.proj2_bias - ) - return hidden_states - - -class TransformerLayer(torch.nn.Module): - - def __init__(self, hidden_size, num_head, dropout): - super().__init__() - # layer norm - self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - - self.attention = MultiHeadSelfAttention(hidden_size, num_head, dropout) - self.attn_dropout = torch.nn.Dropout(dropout) - - self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - self.ffn = FFN(hidden_size) - self.ffn_dropout = torch.nn.Dropout(dropout) - - def forward(self, hidden_states): - # Attention - in_attn_norm = self.input_layernorm(hidden_states) - attn_out = self.attention(in_attn_norm) - # residual - attn_out = self.attn_dropout(attn_out) - residual = attn_out + hidden_states - # ffn - in_ffn_norm = self.ffn_layernorm(residual) - ffn_out = self.ffn(in_ffn_norm) - # residual - ffn_out = self.ffn_dropout(ffn_out) - ffn_out = ffn_out + residual - return ffn_out - - -class GPT(torch.nn.Module): - - def __init__(self, hidden_size, vocab_size, seqlen_size, - bs, seqlen, num_head, num_layers: int): - super().__init__() - - self.num_layers = num_layers - self.bs = bs - self.seqlen = seqlen - self.ntoken = 1.0 / self.bs * self.seqlen - - # embeddings - self.vocab_size = vocab_size - self.vocab_embed_weight = torch.nn.Parameter( - torch.empty(vocab_size, hidden_size) - ) - self.seqlen_size = seqlen_size - self.pos_embed_weight = torch.nn.Parameter( - torch.empty(seqlen_size, hidden_size) - ) - - self.embed_dropout = torch.nn.Dropout(0.5) - - # transformer layers - # self.layers = torch.nn.ModuleList( - # [TransformerLayer(seqlen, hidden_size, num_head, 0.5) for _ in range(num_layers)] - # ) - self.transform1 = TransformerLayer(hidden_size, num_head, 0.5) - self.transform2 = TransformerLayer(hidden_size, num_head, 0.5) - self.transform3 = TransformerLayer(hidden_size, num_head, 0.5) - self.transform4 = TransformerLayer(hidden_size, num_head, 0.5) - - # final linear - self.final_layernorm = torch.nn.LayerNorm( - hidden_size, 1e-5 - ) - - def forward(self, input_ids, position_ids): - """ - input_ids: - [bs, seqlen] - position_ids: - [bs, seqlen] - """ - - # preprocess: embedding - # [bs, seqlen] -> [bs, seqlen, hidden size] - words_embeddings = cube.runtime.function.embedding( - input_ids, self.vocab_embed_weight, 0, self.vocab_size - ) - # [bs, seqlen] -> [bs, seqlen, hidden size] - position_embeddings = cube.runtime.function.embedding( - position_ids, self.pos_embed_weight, 0, self.seqlen_size - ) - embeddings = words_embeddings + position_embeddings - encoder_input = self.embed_dropout(embeddings) - - # [bs, seqlen, hidden size] -> [seqlen, bs, hidden size] - hidden_states = encoder_input.transpose(0, 1) #.contiguous() - - # transformer - # [seqlen, bs, hidden size] -> [seqlen, bs, hidden size] - # for layer in self.layers: - # hidden_states = layer(hidden_states) - hidden_states = self.transform1(hidden_states) - hidden_states = self.transform2(hidden_states) - hidden_states = self.transform3(hidden_states) - hidden_states = self.transform4(hidden_states) - - hidden_states = self.final_layernorm(hidden_states) - - # post process - # [seqlen, bs, hidden size] -> [bs, seqlen, hidden size] - hidden_states = hidden_states.transpose(0, 1) # .contiguous() - # [bs, seqlen, hidden size] * [self.vocab_size, hidden size] - # => [bs, seqlen, self.vocab_size] - logits = F.linear(hidden_states, self.vocab_embed_weight) - - # loss # for verification, the mask is ommitted - # [bs, seqlen, self.vocab_size] -> [1] - loss = torch.sum(logits) - # loss = loss * self.ntoken - - return loss - - -def train(): - L = 512 # seq len - N = 8 # batch size - # configs: [hidden size, num_head] - # E, num_head = [2304, 24, 24] # 1.7B model - E, num_head, layers = [3072, 32, 30] # 3.6B model - # E, num_head, layers = [4096, 32, 36] # 7.5B model - - - model = GPT( - hidden_size=E, vocab_size=50304, seqlen_size=L, - bs=N, seqlen=L, num_head=num_head, num_layers=layers - ) - model = cube.SemanticModel( - model, input_shapes=([N, L], [N, L],), - ) - - dataloader = cube.runtime.syndata.SynTextDataLoader(1280, [0, 0], [N, L], [N, L]) - - @cube.compile(model, dataloader, policy=(transform_policy, schedule_policy)) - def train_iter(model, dataloader): - input_ids, position_ids = next(dataloader) - loss = model(input_ids, position_ids) - loss.backward() - model = model.get_gen_module() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on iteration') - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - - -if __name__ == '__main__': - - cube.init() - train() diff --git a/examples/gpt/policy/megatron.md b/examples/gpt/policy/megatron.md deleted file mode 100644 index 79c35e8a..00000000 --- a/examples/gpt/policy/megatron.md +++ /dev/null @@ -1,90 +0,0 @@ - -``` -function PADataParallel(Graph G, Resource R, Config C): - for node in G.nodes do - algorithm <- getPartitionAlgo(node, 'data parallelism') - subnodes <- G.partition(node, algorithm, C.data_parallel_size) - for dp_idx in 0 to C.data_parallel_size do - G.assign(subnodes[dp_idx], dp_idx) - return G - - -function PATensorParallel(Graph G, Resource R, Config C): - for node in G.nodes do - algorithm <- getPartitionAlgo(node, 'tensor parallelism') - subnodes <- G.partition(node, algorithm, C.tensor_parallel_size) - for tp_idx in 0 to C.tensor_parallel_size do - G.assign(subnodes[tp_idx], tp_idx) - return G - - -function PAPipelineParallel(Graph G, Resource R, Config C): - - for node in G.nodes do - algorithm <- getPartitionAlgo(node, 'data parallelism') - G.partition(node, algorithm, C.num_micro_batches) - - for node in G.nodes do - stage_id <- getStageID(node, G, C.pipeline_parallel_size) // policy - G.assign(node, stage_id) - - // group to a sub-graph (block): A microbatch on one stage - groupStageAndMicroBatch(G, C.pipeline_parallel_size, C.num_micro_batches) - return G - - -function PSPipelineParallel(Graph G, Resource R, Config C): - // each node in G stands for a block (sub-graph) - sequence <- EmptyArray[] - // warmup phase - for micro_batch_id in 0 to C.num_micro_batches do - for stage_id in 0 to C.pipeline_parallel_size - micro_batch_id do - node <- getForwardBlock(G, micro_batch_id, stage_id) - arrayPush(sequence, node) - # steady and cooldown phase - for micro_batch_id in 0 to C.num_micro_batches do - // enqueue backward - for stage_id in C.pipeline_parallel_size to 0 do - node <- getBackwardBlock(G, micro_batch_id, stage_id) - arrayPush(sequence, node) - // enqueue forward - for stage_id in 0 to C.pipeline_parallel_size do - mid <- micro_batch_id + C.pipeline_parallel_size - stage_id - if mid <= C.pipeline_parallel_size then - node <- getForwardStage(G, mid, stage_id) - arrayPush(sequence, node) - G.schedule(sequence) - return G - - -function Megatron(Graph G, Resource R, Config C): - // Graph G: Dataflow graph containing operators as nodes - // Resource R: Environment Resource including GPU numbers and topology - // Config C: policy user configuration including: - // data_parallel_size, - // tensor_parallel_size, - // pipeline_parallel_size, - // num_micro_batches - - // Resource split: group resources - Rs <- splitResource(R, C) - R_pp <- getResourceForPP(Rs, C) - - // group into blocks (each block is a microbatch on a stage) - G <- PAPipelineParallel(G, R_pp, C) - - // inter block scheduling: 1F1B scheduling - G <- PSPipelineParallel(G, R_pp, C) - - // inner block parallelism: hybrid parallelism - for block in G.nodes do - stage_id <- getStageID(G, block) - // data parallelism - R_dp <- getResourceForDP(Rs, stage_id) - PADataParallel(block, R_dp, C) - // tensor parallelism - R_tp <- getResourceForTP(Rs, stage_id) - PATensorParallel(block, R_tp, C) - - return G -``` \ No newline at end of file diff --git a/examples/gpt/policy/megatron_parallel.py b/examples/gpt/policy/megatron_parallel.py deleted file mode 100644 index f43158a4..00000000 --- a/examples/gpt/policy/megatron_parallel.py +++ /dev/null @@ -1,138 +0,0 @@ -import time - -from cube.graph import IRGraph -from cube.graph.operator.function import CubeComplexEmbedding, Linear, Sum -from cube.graph.operator.function import CubeComplexFeedForward -from cube.graph.operator.function import CubeComplexSelfAttention -from cube.graph.operator.function import Transpose -from cube.schedule.su import SUType -from cube.schedule.sugraph import SUGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation - - -def transform_policy(graph: IRGraph, resource): - """ - The transformation policy transposes linear using tensor parallel - """ - print('> transforming graph...') - ndevs = resource.ngpus - dp = 1 - tp = ndevs // dp - - # dataloader - - dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] - for dnode in dnodes: - sub_nodes = list() - algo = dnode.algorithms('data') - dp_nodes = graph.partition(dnode, algo, config=dict(chunk_num=dp)) - for dp_node in dp_nodes: - tp_nodes = graph.replicate(dp_node, times=tp) - sub_nodes += tp_nodes - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # preprocess before transformer - for fnode in fnodes[:5]: - sub_nodes = list() - if isinstance(fnode, CubeComplexEmbedding): - algo = fnode.algorithms('data') - dp_nodes = graph.partition(fnode, algo, config=dict(dim=0, chunk_num=dp)) - if dp_nodes[0].inputs(1).shape[0] >= 50000: - for dp_node in dp_nodes: - algo = dp_node.algorithms('shard') - tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) - sub_nodes += tp_nodes - else: - for dp_node in dp_nodes: - tp_nodes = graph.replicate(dp_node, times=tp) - sub_nodes += tp_nodes - else: - algo = fnode.algorithms('dim') - assert algo - dp_nodes = graph.partition(fnode, algo, config=dict(dim=0, chunk_num=dp)) - for dp_node in dp_nodes: - tp_nodes = graph.replicate(dp_node, times=tp) - sub_nodes += tp_nodes - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - # transformers - for fnode in fnodes[5:-3]: - sub_nodes = list() - if isinstance(fnode, CubeComplexSelfAttention): - algo = fnode.algorithms('data') - dp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=dp)) - for dp_node in dp_nodes: - algo = dp_node.algorithms('head') - tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) - sub_nodes += tp_nodes - elif isinstance(fnode, CubeComplexFeedForward): - algo = fnode.algorithms('data') - dp_nodes = graph.partition(fnode, algo, config=dict(chunk_num=dp)) - for dp_node in dp_nodes: - algo = dp_node.algorithms('tensor') - tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) - sub_nodes += tp_nodes - else: - # note replicate should put in the last due to bugs: - algo = fnode.algorithms('dim') - dp_nodes = graph.partition(fnode, algo, config=dict(dim=1, chunk_num=dp)) - for dp_node in dp_nodes: - rep_nodes = graph.replicate(dp_node, times=tp) - sub_nodes += rep_nodes - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - # post-process - for fnode in fnodes[-3:]: - sub_nodes = list() - if isinstance(fnode, Transpose): - algo = fnode.algorithms('dim') - dp_nodes = graph.partition(fnode, algo, config=dict(dim=1, chunk_num=dp)) - for dp_node in dp_nodes: - rep_nodes = graph.replicate(dp_node, times=tp) - sub_nodes += rep_nodes - elif isinstance(fnode, Linear): - algo = fnode.algorithms('data') - dp_nodes = graph.partition(fnode, algo, config=dict(dim=0, chunk_num=dp)) - for dp_node in dp_nodes: - algo = dp_node.algorithms('column') - tp_nodes = graph.partition(dp_node, algo, config=dict(chunk_num=tp)) - sub_nodes += tp_nodes - else: - rep_nodes = graph.replicate(fnode, times=ndevs) - sub_nodes += rep_nodes - for idx, sub_node in enumerate(sub_nodes): - sub_node.tag = idx - - # print(graph) - # assert False - return graph - - -def schedule_policy(sugraph: SUGraph, resource): - """ - The schedule policy assign devices - """ - print('> scheduling SU...') - start_time = time.time() - - for su in sugraph.sus(): - if su.stype == SUType.Dataloader: - devid = su.tag[0] - sugraph.assign(su, devid) - print('> [scheduling] assign device...') - for su in sugraph.fsus(): - devid = su.tag[0] - sugraph.assign(su, devid) - sugraph.assign(su.mirror, devid) - fsus = sugraph.fsus() - print('> [scheduling] setting schedule order...') - sugraph.partial_set_order(fsus, lazy=False) - - span = time.time() - start_time - print('> Done scheduling: {:.2f} seconds'.format(span)) - return sugraph diff --git a/examples/mbart/mbart.py b/examples/mbart/mbart.py deleted file mode 100644 index a69e3380..00000000 --- a/examples/mbart/mbart.py +++ /dev/null @@ -1,425 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - examples/mbart/mbart.py -""" - -from typing import Optional -import argparse -import math -import torch - -import cube -from cube.runtime.syndata import SynTextDataLoader -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary, model_summary -from cube.profiler.timer import print_each_rank - - -# fairseq task -# translation_from_pretrained_bart - -# fairseq criterion -# label_smoothed_cross_entropy, --label_smoothing = 0.2 - -class Config: - - num_embeddings = 250027 - - encoder_embed_path = None - encoder_embed_dim = 1024 - encoder_ffn_embed_dim = 4 * 1024 - encoder_layers = 12 - encoder_attention_heads = 16 - encoder_normalize_before = True - encoder_learned_pos = True - - decoder_embed_path = None - decoder_embed_dim = 1024 - decoder_ffn_embed_dim = 4 * 1024 - decoder_layers = 12 - decoder_attention_heads = 16 - decoder_normalize_before = True - decoder_learned_pos = True - cross_self_attention = False - no_cross_attention = False - - attention_dropout = 0.0 - activation_dropout = 0.0 - dropout = 0.1 - - max_target_positions = 1024 - max_source_positions = 1024 - adaptive_softmax_cutoff = None - adaptive_softmax_dropout = 0 - - share_decoder_input_output_embed = True - share_all_embeddings = True - - decoder_output_dim = 1024 # same with decorder_embed_dim - decoder_input_dim = 1024 # same with decorder_embed_dim - - no_scale_embedding = False # True in bart large - layernorm_embedding = True - activation_fn = 'gelu' - pooler_activation_fn = 'tanh' - pooler_dropout = 0.0 - - -def attn_fn(query: torch.Tensor, key: torch.Tensor, - wq: torch.Tensor, wq_bias: Optional[torch.Tensor], - wk: torch.Tensor, wk_bias: Optional[torch.Tensor], - wv: torch.Tensor, wv_bias: Optional[torch.Tensor], - wout: torch.Tensor, wout_bias: Optional[torch.Tensor], - h: int, scale: float, dropout: float, mask=True): - """ - query, key: (L, N, E) = (seqlen, batch size, embed_dim) - wq, wk, wv weight: [(num_head * dim_head), E] - dropout: float - h: int: number of heads - """ - num_head = h - L, N = query.size(0), query.size(1) - dim_head = wq.size(0) // num_head - - q = torch.nn.functional.linear(query, wq, wq_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(key, wk, wk_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(key, wv, wv_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # attention mask - if mask: # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L -> (N h) L L - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, wout, wout_bias) # L N (h d), E E -> L N E - return output - - -class MultiheadAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, dropout=0.0, bias=True): - super().__init__() - self.kdim = embed_dim - self.vdim = embed_dim - self.num_heads = num_heads - self.head_dim = embed_dim // num_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # K - self.k_proj = torch.nn.Parameter(torch.empty(embed_dim, self.kdim)) - if bias: - self.k_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.k_bias = None - # V - self.v_proj = torch.nn.Parameter(torch.empty(embed_dim, self.vdim)) - if bias: - self.v_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.v_bias = None - # Q - self.q_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) - if bias: - self.q_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.q_bias = None - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) - if bias: - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.out_bias = None - - def forward(self, query: torch.Tensor, key: torch.Tensor): - return attn_fn(query, key, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) - - def forward_encoder_decoder_attn(self, query: torch.Tensor, key: torch.Tensor): - # tgt_len, bsz, embed_dim = query.size() - # q = torch.nn.functional.linear(query, self.q_proj, self.q_bias) - # k = torch.nn.functional.linear(key, self.k_proj, self.k_bias) - # v = torch.nn.functional.linear(key, self.v_proj, self.v_bias) - # q = q * self.scaling - # q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) - # k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) - # v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) - # attn_weights = torch.bmm(q, k.transpose(1, 2)) - # # TODO: here needs a mask - # attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - # attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout_p) - # attn = torch.bmm(attn_probs, v) - # attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) - # attn = torch.nn.functional.linear(attn, self.out_proj, self.out_bias) - return attn_fn(query, key, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) - - def forward_self_attn(self, query): - return attn_fn(query, query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) - - -class EncoderLayer(torch.nn.Module): - - def __init__(self, cfg: Config): - - super().__init__() - self.self_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.encoder_attention_heads, cfg.attention_dropout) - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) - self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.fc1 = torch.nn.Linear(cfg.encoder_embed_dim, cfg.encoder_ffn_embed_dim) - self.fc2 = torch.nn.Linear(cfg.encoder_ffn_embed_dim, cfg.encoder_embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) - - def forward(self, x): # , encoder_padding_mask: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None): - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x, x) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - x = self.fc1(x) - x = torch.nn.functional.gelu(x) - x = self.activation_dropout(x) - x = self.fc2(x) - - x = self.dropout(x) - x = x + residual - return x - - -class Encoder(torch.nn.Module): - - def __init__(self, cfg: Config, embed_tokens: torch.nn.Module): - super().__init__() - self.dropout = torch.nn.Dropout(cfg.dropout) - self.max_source_positions = cfg.max_source_positions - self.embed_tokens = embed_tokens - self.embed_scale = math.sqrt(cfg.encoder_embed_dim) - self.embed_positions = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) - self.layernorm_embedding = torch.nn.LayerNorm(cfg.encoder_embed_dim) - self.layers = torch.nn.ModuleList([]) - self.layers.extend( - [EncoderLayer(cfg) for _ in range(cfg.encoder_layers)] - ) - # normalize before - self.layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) - - def forward(self, src_tokens: torch.Tensor): - token_embedding = self.embed_tokens(src_tokens) - embed = self.embed_scale * token_embedding - - x = embed + self.embed_positions.weight # self.embed_positions(src_tokens) - x = self.layernorm_embedding(x) - x = self.dropout(x) - - x = x.transpose(0, 1) - for layer in self.layers: - x = layer(x) # encoder_padding_mask if has_pads else None) - x = self.layer_norm(x) - return x - - -class DecoderLayer(torch.nn.Module): - - def __init__(self, cfg: Config): - - super().__init__() - self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.self_attn = MultiheadAttention(cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) - self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) - # encoder atten - self.encoder_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) - self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) - # self.encoder_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) - - self.fc1 = torch.nn.Linear(cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim) - self.fc2 = torch.nn.Linear(cfg.decoder_ffn_embed_dim, cfg.decoder_embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) - - def forward(self, x, encoder_out): # encoder_padding_mask): - residual = x - # normalize before - x = self.self_attn_layer_norm(x) - - # self attention - x = self.self_attn(x, x) - x = self.dropout(x) - x = residual + x - - # encoder attn - residual = x - # normalize before - x = self.encoder_attn_layer_norm(x) - x = self.encoder_attn(x, encoder_out) - x = self.dropout(x) - x = x + residual - - residual = x - # normalize before - x = self.final_layer_norm(x) - x = self.fc1(x) - x = torch.nn.functional.gelu(x) - x = self.activation_dropout(x) - x = self.fc2(x) - x = self.dropout(x) - x = x + residual - return x - - -class Decoder(torch.nn.Module): - - def __init__(self, cfg: Config, embed_tokens: torch.nn.Module): - super().__init__() - self.dropout = torch.nn.Dropout(cfg.dropout) - self.embed_tokens = embed_tokens - self.embed_scale = math.sqrt(cfg.decoder_embed_dim) - self.embed_positions = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) - self.layernorm_embedding = torch.nn.LayerNorm(cfg.decoder_embed_dim) - self.layers = torch.nn.ModuleList([]) - self.layers.extend( - [DecoderLayer(cfg) for _ in range(cfg.decoder_layers)] - ) - self.layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) - - def forward(self, prev_output_tokens: torch.Tensor, enc: torch.Tensor): - positions = self.embed_positions.weight # self.embed_positions(prev_output_tokens) - x = self.embed_scale * self.embed_tokens(prev_output_tokens) - x = x + positions - x = self.layernorm_embedding(x) - x = self.dropout(x) - # B T C -> T B C - x = x.transpose(0, 1) - # decoder layers - for layer in self.layers: - x = layer(x, enc) - x = self.layer_norm(x) - # T x B x C -> B x T x C - x = x.transpose(0, 1) - # B T C, N, C -> B T N - x = torch.nn.functional.linear(x, self.embed_tokens.weight) - return x - -# label_smoothed_cross_entropy -def criterion(output: torch.Tensor, prev_output_tokens: torch.Tensor, label_smoothing: float = 0.2): - target = prev_output_tokens[:, 1:] - # fairseq.criterions.label_smoothed_cross_entory - # model.get_normalized_probs - lprobs = torch.nn.functional.softmax(output, dim=-1) - # fairseq.criterions.label_smoothed_nll_loss - if target.dim() == lprobs.dim() - 1: - target = target.unsqueeze(-1) - nll_loss = -lprobs.gather(dim=-1, index=target) - smooth_loss = -lprobs.sum(dim=-1, keepdim=True) - nll_loss = nll_loss.squeeze(-1) - smooth_loss = smooth_loss.squeeze(-1) - nll_loss = nll_loss.sum() - smooth_loss = smooth_loss.sum() - eps_i = label_smoothing / (lprobs.size(-1) - 1) - loss = (1.0 - label_smoothing - eps_i) * nll_loss + eps_i * smooth_loss - return loss - - -class mBART(torch.nn.Module): - - def __init__(self, cfg: Config): - super().__init__() - emb = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) - self.encoder: Encoder = Encoder(cfg, emb) - self.decoder: Decoder = Decoder(cfg, emb) - - def forward(self, src_tokens, prev_output_tokens): - encoder_out = self.encoder(src_tokens) - decoder_out = self.decoder(prev_output_tokens, encoder_out) - loss = criterion(decoder_out, prev_output_tokens) - return loss - - -def train(): - bs = 1 - - cfg = Config() - model = mBART(cfg).cuda() - - print_each_rank('model weight consumpition:') - memory_summary() - - dataloader = SynTextDataLoader( - shapes=( - [bs, cfg.max_source_positions], - [bs, cfg.max_target_positions] - ), - dtypes=(torch.int64, torch.int64), - batch_dims=(0,0,) - ) - - def train_iter(model, dataloader): - # model.eval() - src_tokens, prev_output_tokens = next(dataloader) - model_summary(model, (src_tokens, prev_output_tokens)) - # loss = model(src_tokens, prev_output_tokens) - # loss.backward() - - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - CudaTimer(enable=False).warmup() - iter_num = 1 - for step in range(iter_num): - if step >= 40: - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - break - optimizer.step() - optimizer.zero_grad() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - # print_each_rank('e2e time (ms) per iteration: {} ms'.format( - # CudaTimer().duration(iter_num-40, field_name='e2e'))) - memory_summary() - - -if __name__ == '__main__': - - cube.init() - train() \ No newline at end of file diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py new file mode 100644 index 00000000..656f80fb --- /dev/null +++ b/examples/nlp/blocks/attention.py @@ -0,0 +1,162 @@ +import torch +import cube + + +@cube.graph.parser.register('L^ N E^, H+ E^, H+, H+ E^, H+, H+ E^, H+, E^ H+, E^ -> L^ N E') +def self_attention(query: torch.Tensor, + q_proj: torch.Tensor, q_bias: torch.Tensor, + k_proj: torch.Tensor, k_bias: torch.Tensor, + v_proj: torch.Tensor, v_bias: torch.Tensor, + out_proj: torch.Tensor, out_bias: torch.Tensor, + h: int, scale: float, dropout_p: float, mask=True): + num_head = h + L, N = query.size(0), query.size(1) + dim_head = q_proj.size(0) // num_head + + q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + if mask: # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + return output + + +@cube.graph.parser.register('L^ N E^, L^ N E^, H+ E^, H+, H+ E^, H+, H+ E^, H+, E^ H+, E^ -> L^ N E') +def cross_attention(query: torch.Tensor, key: torch.Tensor, + q_proj: torch.Tensor, q_bias: torch.Tensor, + k_proj: torch.Tensor, k_bias: torch.Tensor, + v_proj: torch.Tensor, v_bias: torch.Tensor, + out_proj: torch.Tensor, out_bias: torch.Tensor, + h: int, scale: float, dropout_p: float, mask=True): + num_head = h + L, N = query.size(0), query.size(1) + dim_head = q_proj.size(0) // num_head + + q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(key, k_proj, k_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(key, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + if mask: # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + return output + + +class MultiHeadSelfAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias=True): + super().__init__() + self.kdim = embed_dim + self.vdim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # Q + self.q_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + if bias: + self.q_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.q_bias = None + # K + self.k_proj = torch.nn.Parameter(torch.empty(embed_dim, self.kdim)) + self.k_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + # V + self.v_proj = torch.nn.Parameter(torch.empty(embed_dim, self.vdim)) + self.v_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + + def forward(self, query): + return self_attention( + query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p, mask=True + ) + + +class MultiHeadCrossAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias=True): + super().__init__() + self.kdim = embed_dim + self.vdim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # Q + self.q_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + if bias: + self.q_bias = torch.nn.Parameter(torch.empty(embed_dim)) + else: + self.q_bias = None + # K + self.k_proj = torch.nn.Parameter(torch.empty(embed_dim, self.kdim)) + self.k_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + # V + self.v_proj = torch.nn.Parameter(torch.empty(embed_dim, self.vdim)) + self.v_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + + def forward(self, query: torch.Tensor, key: torch.Tensor): + return cross_attention( + query, key, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p, mask=True + ) diff --git a/examples/nlp/blocks/decoder.py b/examples/nlp/blocks/decoder.py new file mode 100644 index 00000000..ee5f5767 --- /dev/null +++ b/examples/nlp/blocks/decoder.py @@ -0,0 +1,43 @@ +import torch +from examples.nlp.blocks.attention import MultiHeadCrossAttention, MultiHeadSelfAttention +from examples.nlp.blocks.mlp import MLP + + +class DecoderLayer(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, ffn_embed_dim: int, + dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): + super().__init__() + self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) + self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, atten_dropout) + + self.cross_attn_layer_norm = torch.nn.LayerNorm(embed_dim) + self.cross_attn = MultiHeadCrossAttention(embed_dim, num_heads, atten_dropout) + + self.dropout = torch.nn.Dropout(p=dropout) + + self.mlp = MLP(embed_dim, ffn_embed_dim, activation_dropout) + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x) + + x = self.dropout(x) + x = x + residual + + residual = x + x = self.cross_attn_layer_norm(x) + x = self.cross_attn(x, encoder_output) + + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + x = self.mlp(x) + + x = self.dropout(x) + x = x + residual + return x diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py new file mode 100644 index 00000000..ce22d6e9 --- /dev/null +++ b/examples/nlp/blocks/encoder.py @@ -0,0 +1,29 @@ +import torch +from examples.nlp.blocks.attention import MultiHeadSelfAttention +from examples.nlp.blocks.mlp import MLP + + +class EncoderLayer(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, ffn_embed_dim: int, + dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): + super().__init__() + self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, atten_dropout) + self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) + self.dropout = torch.nn.Dropout(p=dropout) + self.mlp = MLP(embed_dim, ffn_embed_dim, activation_dropout) + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x) + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + x = self.mlp(x) + x = self.dropout(x) + x = x + residual + return x diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py new file mode 100644 index 00000000..28a7e441 --- /dev/null +++ b/examples/nlp/blocks/mlp.py @@ -0,0 +1,32 @@ +import torch +import cube + + +@cube.graph.parser.register('L^ N E^, H+ E^, H+, E H+, E -> L^ N E') +def feedforward(x: torch.Tensor, + proj1: torch.Tensor, proj1_bias: torch.Tensor, + proj2: torch.Tensor, proj2_bias: torch.Tensor, + dropout: float) -> torch.Tensor: + x = torch.nn.functional.linear(x, proj1, proj1_bias) + x = torch.nn.functional.gelu(x) + x = torch.nn.functional.dropout(x, dropout, True, False) + x = torch.nn.functional.linear(x, proj2, proj2_bias) + return x + + +class MLP(torch.nn.Module): + + def __init__(self, embed_dim, hidden_dim, dropout: float, bias=True): + super().__init__() + self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) + self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) + self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) + self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) + self.dropout = dropout + + def forward(self, x: torch.Tensor): + x = feedforward(x, + self.proj1, self.proj1_bias, + self.proj2, self.proj2_bias, + self.dropout) + return x \ No newline at end of file diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py new file mode 100644 index 00000000..a4de26be --- /dev/null +++ b/examples/nlp/gpt/model.py @@ -0,0 +1,101 @@ +import torch +import math + +from examples.nlp.blocks.encoder import EncoderLayer + +import cube + + +class Config: + + num_embeddings = 50304 + seqlen = 512 + + # 1.7B model + embed_dim = 2304 + layers = 24 + attention_heads = 24 + + # 3.6B model + # embed_dim = 3072 + # layers = 32 + # attention_heads = 32 + + # 7.5B model + # embed_dim = 4096 + # layers = 32 + # attention_heads = 36 + + ffn_embed_dim = embed_dim * 4 + dropout = 0.0 + attn_dropout = 0.0 + activation_dropout = 0.0 + + +class GPT(torch.nn.Module): + + def __init__(self): + super().__init__() + cfg = Config() + + self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) + self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) + self.embed_dropout = torch.nn.Dropout() + + self.layers = torch.nn.ModuleList( + [EncoderLayer( + cfg.embed_dim, cfg.attention_heads, cfg.ffn_embed_dim, + cfg.dropout, cfg.attn_dropout, cfg.activation_dropout + ) for _ in range(cfg.layers)] + ) + self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): + + embed = self.embed(input_ids) + pos_embed = self.position(position_ids) + embed = embed + pos_embed + embed = self.embed_dropout(embed) + enc = embed.transpose(0, 1) + + for layer in self.layers: + enc = layer(enc) + enc = self.final_layernorm(enc) + + logits = torch.nn.functional.linear(enc, self.embed.weight) + # simplified + loss = torch.sum(logits) + return loss + + +class GPTDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + + self.bs = batch_size + self.cfg = Config() + super().__init__( + shapes=([batch_size, self.cfg.seqlen], + [batch_size, self.cfg.seqlen], + ), + dtypes=(torch.int64, torch.int64), + batch_dims=(0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + input_ids = torch.randint( + 0, self.cfg.num_embeddings, + size=(self.bs, self.cfg.seqlen), + dtype=torch.int64, device=torch.cuda.current_device() + ) + position_ids = torch.arange( + 0, self.cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() + ).repeat(self.bs) + return (input_ids, position_ids) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] \ No newline at end of file diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py new file mode 100644 index 00000000..d10676b1 --- /dev/null +++ b/examples/nlp/gpt/train.py @@ -0,0 +1,69 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/nlp/gpt/train.py +""" + + +import torch + +from examples.nlp.gpt.model import GPT +from examples.nlp.gpt.model import GPTDataLoader + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary, model_summary + + +def train(): + + batch_size = 1 + + model = GPT().cuda() + dataloader = GPTDataLoader(batch_size) + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + print_each_rank('model weight consumpition:') + memory_summary() + + def train_iter(model, dataloader): + input_ids, position_ids = next(dataloader) + loss = model(input_ids, position_ids) + loss.backward() + + CudaTimer(enable=False).warmup() + iter_num = 64 + for step in range(iter_num): + + if step == 0: + model_summary(model, next(dataloader)) + + if step >= 20: + CudaTimer(enable=True).start('e2e') + + # training + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + if step >= 20: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('passed first iteration') + + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + memory_summary() + + +if __name__ == '__main__': + + cube.init() + train() \ No newline at end of file diff --git a/examples/nlp/mbart/model.py b/examples/nlp/mbart/model.py new file mode 100644 index 00000000..57faafdb --- /dev/null +++ b/examples/nlp/mbart/model.py @@ -0,0 +1,201 @@ +import torch +import math + +from examples.nlp.blocks.encoder import EncoderLayer +from examples.nlp.blocks.decoder import DecoderLayer + +import cube + + +class Config: + + # source and target + max_source_positions = 1024 + max_target_positions = 1024 + + num_embeddings = 250027 + + encoder_embed_dim = 1024 + encoder_ffn_embed_dim = 4 * 1024 + encoder_layers = 12 + encoder_attention_heads = 16 + + decoder_embed_dim = 1024 + decoder_ffn_embed_dim = 4 * 1024 + decoder_layers = 12 + decoder_attention_heads = 16 + + attention_dropout = 0.0 + dropout = 0.1 + activation_dropout = 0.0 + + pad_token_id = 1 + eos_token_id = 2 + + # classification task + num_classes = 3 + + +class PositionalEmbedding(torch.nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, seq_len: int): + positions = torch.arange( + 0, seq_len, dtype=torch.long, device=torch.cuda.current_device() + ) + return super().forward(positions + self.offset) + + +class MBartClassificationHead(torch.nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + + self.num_classes = num_classes + self.dense = torch.nn.Linear(input_dim, inner_dim) + self.dropout = torch.nn.Dropout(p=pooler_dropout) + self.out_proj = torch.nn.Linear(inner_dim, num_classes) + self.loss_fct = torch.nn.CrossEntropyLoss() + + def forward(self, dec: torch.Tensor, labels): + # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] + dec = dec[:,-1,:] + sentence_represent = dec + hidden_states = self.dropout(sentence_represent) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + logits = self.out_proj(hidden_states) + loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) + return loss + + +class MBartForSentenceClassification(torch.nn.Module): + + def __init__(self): + super().__init__() + cfg = Config() + # embedding + self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) + + # encoder embedding + self.encoder_position = PositionalEmbedding(cfg.max_source_positions, cfg.encoder_embed_dim) + self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) + self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + # encoder layers + self.encoders = torch.nn.ModuleList( + [EncoderLayer( + cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.decoder_ffn_embed_dim, + cfg.dropout, cfg.attention_dropout, cfg.activation_dropout + ) for _ in range(cfg.decoder_layers)] + ) + self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + # decoder embedding + self.decoder_position = PositionalEmbedding(cfg.max_target_positions, cfg.decoder_embed_dim) + self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) + self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + + # decoder layers + self.decoders = torch.nn.ModuleList( + [DecoderLayer( + cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.decoder_ffn_embed_dim, + cfg.dropout, cfg.attention_dropout, cfg.activation_dropout + ) for _ in range(cfg.decoder_layers)] + ) + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + + self.head = MBartClassificationHead(cfg.decoder_embed_dim, 1024, cfg.num_classes, 0.0) + + def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor, labels: torch.Tensor): + + # encoder embedding + enc_emb = self.embed(input_ids) + enc_emb = enc_emb * self.embed_scale_encoder + enc_emb = enc_emb + self.encoder_position(input_ids.size(1)) + enc_emb = self.layernorm_embedding_encoder(enc_emb) + enc_emb = torch.nn.functional.dropout(enc_emb, p=0.0) + enc = enc_emb.transpose(0, 1) + + # encoder layers + for layer in self.encoders: + enc = layer(enc) + enc = self.layer_norm_encoder(enc) + + # decoder embedding + dec_emb = self.embed(decoder_input_ids) + dec_emb = dec_emb * self.embed_scale_decoder + dec_emb = dec_emb + self.decoder_position(decoder_input_ids.size(1)) + dec_emb = self.layernorm_embedding_decoder(dec_emb) + dec_emb = torch.nn.functional.dropout(dec_emb, p=0.0) + dec = dec_emb.transpose(0, 1) + + # decoder layers + for layer in self.decoders: + dec = layer(dec, enc) + dec = self.layer_norm_decoder(dec) + dec = dec.transpose(0, 1) + + # head + loss = self.head(dec, labels) + return loss + + +class MBartDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + + self.bs = batch_size + self.cfg = Config() + super().__init__( + shapes=([batch_size, self.cfg.max_source_positions,], + [batch_size, self.cfg.max_target_positions], + [batch_size] + ), + dtypes=(torch.int64, torch.int64, torch.int64), + batch_dims=(0, 0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + input_ids = torch.randint( + 0, self.cfg.num_embeddings, + size=(self.bs, self.cfg.max_source_positions), + dtype=torch.int64, device=torch.cuda.current_device() + ) + decoder_input_ids = MBartDataLoader.shift_tokens_right(input_ids, self.cfg.pad_token_id) + labels = torch.randint( + 0, self.cfg.num_classes, + size=(self.bs,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + return (input_ids, decoder_input_ids, labels) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + @staticmethod + def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): + prev_output_tokens = input_ids.clone() + prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) + index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) + decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() + prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() + prev_output_tokens[:, 0] = decoder_start_tokens + return prev_output_tokens + diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py new file mode 100644 index 00000000..f8eadeeb --- /dev/null +++ b/examples/nlp/mbart/train.py @@ -0,0 +1,69 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/nlp/mbart/train.py +""" + + +import torch + +from examples.nlp.mbart.model import MBartForSentenceClassification +from examples.nlp.mbart.model import MBartDataLoader + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary, model_summary + + +def train(): + + batch_size = 1 + + model = MBartForSentenceClassification().cuda() + dataloader = MBartDataLoader(batch_size) + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + print_each_rank('model weight consumpition:') + memory_summary() + + def train_iter(model, dataloader): + input_ids, decoder_input_ids, labels = next(dataloader) + loss = model(input_ids, decoder_input_ids, labels) + loss.backward() + + CudaTimer(enable=False).warmup() + iter_num = 64 + for step in range(iter_num): + + if step == 0: + model_summary(model, next(dataloader)) + + if step >= 20: + CudaTimer(enable=True).start('e2e') + + # training + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + if step >= 20: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('passed first iteration') + + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + memory_summary() + + +if __name__ == '__main__': + + cube.init() + train() \ No newline at end of file From 398d2703dbee53c6f6dc1c0a56de189ac5c7283b Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 21 Mar 2022 10:45:17 +0800 Subject: [PATCH 0669/1892] add custom_op strip_2_borders for example atmosphere-weather --- cube/graph/operator/function/customops.py | 70 +++++++++++++++++++++++ cube/graph/operator/function/function.py | 18 +++++- cube/graph/parser/mapping.py | 6 ++ cube/runtime/function/function.py | 26 ++++++++- examples/atmosphere/weather.py | 22 +++---- examples/custom_ops.py | 24 ++++++++ 6 files changed, 153 insertions(+), 13 deletions(-) create mode 100644 cube/graph/operator/function/customops.py create mode 100644 examples/custom_ops.py diff --git a/cube/graph/operator/function/customops.py b/cube/graph/operator/function/customops.py new file mode 100644 index 00000000..1fc0d1e7 --- /dev/null +++ b/cube/graph/operator/function/customops.py @@ -0,0 +1,70 @@ +from typing import List + +import cube.runtime.function +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + +class IRCustomOps(IRFwOperation): + def __init__(self, signature: str, inputs: List[IRTensor], name: str, + **kwargs): + # torch.nn.functional.pad(input, pad, mode='constant', value=0.0) + # pad: List[int] + if signature == 'examples.custom_ops.strip_2_borders': + signature = signature.replace('examples.custom_ops', 'cube.runtime.function')#'cube.runtime.function.strip_2_borders' + assert len(inputs) == 1, "Expected only input, weight, bias as inputs" + assert len(kwargs) == 0, "Expected 0 kwargs: " + super().__init__(name, signature, 1, 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + elif signature == 'examples.custom_ops.update_diag': + signature = signature.replace('examples.custom_ops', 'cube.runtime.function') + assert len(inputs) == 9, "Expected only input, weight, bias as inputs" + assert len(kwargs) == 2, "Expected 0 kwargs: " + super().__init__(name, signature, 1, 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update(kwargs) + else: + raise RuntimeError(f'IRCustomOps::__init__ unknown signature: {self.signature}') + + + def infer_shape(self) -> bool: + """ + Output shape inference given the input shapes + """ + if self.signature.endswith('strip_2_borders'): + if len(self.inputs(0).shape) == 0: + return False + shape = self.inputs(0).shape + shape[0] = shape[0]-2 + self.outputs(0).shape = shape + return True + elif self.signature.endswith('update_diag'): + shape = self.inputs(0).shape + print(f'### {self.signature} in.shape = {shape}') + self.outputs(0).shape = shape + print(f'### {self.signature} out.shape = {shape}') + return True + else: + raise RuntimeError(f'IRCustomOps::infer_shape unknown signature: {self.signature}') + + def new(self, inputs: List, outputs: List): + """ + construct a new operator sharing same kwargs with new inputs + and outputs + """ + if self.signature.endswith('strip_2_borders'): + op = IRCustomOps(self.signature, inputs, self.name,) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + op.infer_shape() + return op + elif self.signature.endswith('update_diag'): + op = IRCustomOps(self.signature, inputs, self.name, self.kwargs) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + op.infer_shape() + return op + else: + raise RuntimeError(f'IRCustomOps::new unknown signature: {self.signature}') + diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index f27fbe89..1641fd25 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -8,6 +8,7 @@ from cube.graph.operator.function.conv import IRConv3D from cube.graph.operator.function.pad import IRPad from cube.graph.operator.function.scripteinops import IRScriptEinOps +from cube.graph.operator.function.customops import IRCustomOps def _create_eshape(shape: List[int], iterator: Optional[Iterable] = None, @@ -462,4 +463,19 @@ def ScriptEinOps(signature, inputs): reduction_type = inputs[2] import pickle recipe_str = pickle.dumps(recipe) - return IRScriptEinOps(signature, tensors, 'scripteinops', recipe_str=recipe_str, reduction_type=reduction_type) \ No newline at end of file + return IRScriptEinOps(signature, tensors, 'scripteinops', recipe_str=recipe_str, reduction_type=reduction_type) + + +def CustomOps(signature, inputs): + if signature == 'examples.custom_ops.strip_2_borders': + tensors = inputs[0:1] + print(f'CustomOps:tensors[0] = {tensors[0]}') + return IRCustomOps(signature, tensors, 'custom_ops') + elif signature == 'example.custom_ops.update_diag': + tensors = inputs[0:9] + dz = inputs[9] + dt = inputs[10] + return IRCustomOps(signature, tensors, 'custom_ops', dz=dz, dt=dt) + else: + import warnings + warnings.warn(f"ERROR Unknown custom op, signature{signature}") \ No newline at end of file diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index e3dd9d43..fda88819 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -47,6 +47,9 @@ def register(signature: str, op: IRFwOperation, code): # einops __einopsize = lambda name: f'einops._torch_specific.{name}' + # custom ops + __customops = lambda name: f'examples.custom_ops.{name}' + kOpMap = { # torch nn functional @@ -101,6 +104,9 @@ def register(signature: str, op: IRFwOperation, code): #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, + #custom ops + __customops('strip_2_borders'): function.CustomOps, + __customops('update_diag'): function.CustomOps, } # customized operator code: signature -> code diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 4a01c9ad..ded6a9e8 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -47,4 +47,28 @@ def einops(input: torch.Tensor, recipe_str, reduction_type: str): recipe = pickle.loads(recipe_str) from einops.einops import _apply_recipe output = _apply_recipe(recipe, input, reduction_type) - return output \ No newline at end of file + return output + +############### custom op ################# +#TODO move me +def update_diag(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, + delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, + pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dz, dt): + def pre_conv3d_reshape(X): + return einops.einops.rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) + def post_conv3d_reshape(X): + return einops.einops.rearrange(X, 'b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') + def delta_x(X): + return post_conv3d_reshape(F.conv3d(pre_conv3d_reshape(X), delta_x_filter)) + def delta_y(X): + return post_conv3d_reshape(F.conv3d(pre_conv3d_reshape(X), delta_y_filter)) + # update diagnostic variable w (nz + 1, ny, nx) + for i in range(1, w.shape[0]): + w[i] = - ((delta_x(F[:i]) + delta_y(G[:i])) * dz).sum(dim=0) / deltaA / pi1 \ + - sigma[i] * (pi1 - pi0) / dt / pi1 + + return w + +def strip_2_borders(w: torch.Tensor): + return w[1:-1] + diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index 981bdd38..ff884518 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -1,18 +1,17 @@ -import math import torch -import numpy as np import torch.nn.functional as F torch.set_default_tensor_type(torch.DoubleTensor) - -from typing import List import cube from cube.runtime.syndata import SciLoopVariables from examples.atmosphere.policy.naive import PAS from einops.layers.torch import Rearrange +#custom ops +import examples.custom_ops as custom_ops + class Atmoshpere(torch.nn.Module): def __init__(self, nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, @@ -88,19 +87,21 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): pi1 = pi0 - dt / self.deltaA * ((self.delta_x(F) + self.delta_y(G)) * self.dz).sum(dim=0) #sum(axis=0) # (nz, ny, nx) # print('pi:', pi1.mean()) - # TODO a custom Op needed + # # update diagnostic variable w (nz + 1, ny, nx) # for i in range(1, self.nz + 1): # self.w[i] = - ((self.delta_x(F[:i]) + self.delta_y(G[:i])) * self.dz).sum(dim=0) / self.deltaA / pi1 \ # - self.sigma[i] * (pi1 - pi0) / dt / pi1 - + # TODO fix this custom Op + # self.w = custom_ops.update_diag(self.w, F, G, self.delta_x_filter, self.delta_y_filter, self.deltaA, + # pi0, pi1, self.sigma, self.dz, dt) # print('w:', self.w.mean()) # update potential temperature theta (nz, ny, nx) - # theta_ = self.pad_z( - # (self.bar_z(self.P * theta) - self.delta_z(theta) * self.P_[1:-1]) / self.delta_z(self.P) - # ) # (nz + 1, ny, nx) - theta_ = self.pad_z(self.bar_z(theta0)) #theta0 #TODO remove me + theta_ = self.pad_z( + (self.bar_z(self.P * theta) - self.delta_z(theta) * custom_ops.strip_2_borders(self.P_)) / self.delta_z(self.P) + ) # (nz + 1, ny, nx) + theta1 = pi0 / pi1 * theta0 + dt / self.deltaA / pi1 * ( (self.delta_x(F * self.bar_x(self.pad_x(theta))) + self.delta_y(G * self.bar_y(self.pad_y(theta)))) / 2. + pi * self.deltaA * self.delta_z(self.w * theta_) / self.dz + @@ -242,7 +243,6 @@ def laplas(self, X): if __name__ == "__main__": - import matplotlib.pyplot as plt cube.init() nz = 15 diff --git a/examples/custom_ops.py b/examples/custom_ops.py new file mode 100644 index 00000000..7ad1df3a --- /dev/null +++ b/examples/custom_ops.py @@ -0,0 +1,24 @@ +import torch +import einops + +############### custom op ################# +def update_diag(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, + delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, + pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dz:float, dt:torch.Tensor): + # def pre_conv3d_reshape(X): + # return einops.einops.rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) + # def post_conv3d_reshape(X): + # return einops.einops.rearrange(X, 'b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') + # def delta_x(X): + # return post_conv3d_reshape(F.conv3d(pre_conv3d_reshape(X), delta_x_filter)) + # def delta_y(X): + # return post_conv3d_reshape(F.conv3d(pre_conv3d_reshape(X), delta_y_filter)) + # # update diagnostic variable w (nz + 1, ny, nx) + # for i in range(1, w.shape[0]): + # w[i] = - ((delta_x(F[:i]) + delta_y(G[:i])) * dz).sum(dim=0) / deltaA / pi1 \ + # - sigma[i] * (pi1 - pi0) / dt / pi1 + + return w + +def strip_2_borders(w: torch.Tensor): + return w[1:-1] \ No newline at end of file From 0315acae1189506fe48e40743f9f17cc2debe662 Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 21 Mar 2022 11:09:34 +0800 Subject: [PATCH 0670/1892] add custom_op update_diag for example atmosphere-weather --- cube/graph/operator/function/customops.py | 6 +++--- cube/graph/operator/function/function.py | 12 ++++++------ cube/runtime/function/function.py | 2 +- examples/atmosphere/weather.py | 6 +++--- examples/custom_ops.py | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cube/graph/operator/function/customops.py b/cube/graph/operator/function/customops.py index 1fc0d1e7..23541107 100644 --- a/cube/graph/operator/function/customops.py +++ b/cube/graph/operator/function/customops.py @@ -18,9 +18,9 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, self.set_input(idx, input) elif signature == 'examples.custom_ops.update_diag': signature = signature.replace('examples.custom_ops', 'cube.runtime.function') - assert len(inputs) == 9, "Expected only input, weight, bias as inputs" - assert len(kwargs) == 2, "Expected 0 kwargs: " - super().__init__(name, signature, 1, 1) + assert len(inputs) == 10, "Expected only input, weight, bias as inputs" + assert len(kwargs) == 1, "Expected 0 kwargs: " + super().__init__(name, signature, len(inputs), 1) for idx, input in enumerate(inputs): self.set_input(idx, input) self.kwargs.update(kwargs) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 1641fd25..79df2c54 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -471,11 +471,11 @@ def CustomOps(signature, inputs): tensors = inputs[0:1] print(f'CustomOps:tensors[0] = {tensors[0]}') return IRCustomOps(signature, tensors, 'custom_ops') - elif signature == 'example.custom_ops.update_diag': - tensors = inputs[0:9] - dz = inputs[9] - dt = inputs[10] - return IRCustomOps(signature, tensors, 'custom_ops', dz=dz, dt=dt) + elif signature == 'examples.custom_ops.update_diag': + tensors = inputs[0:10] + # dt = inputs[9] + dz = inputs[10] + return IRCustomOps(signature, tensors, 'custom_ops', dz=dz) else: import warnings - warnings.warn(f"ERROR Unknown custom op, signature{signature}") \ No newline at end of file + warnings.warn(f"ERROR Unknown custom op, signature {signature}") \ No newline at end of file diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index ded6a9e8..364e4beb 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -53,7 +53,7 @@ def einops(input: torch.Tensor, recipe_str, reduction_type: str): #TODO move me def update_diag(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, - pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dz, dt): + pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dt, dz): def pre_conv3d_reshape(X): return einops.einops.rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) def post_conv3d_reshape(X): diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index ff884518..021def86 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -92,9 +92,9 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # for i in range(1, self.nz + 1): # self.w[i] = - ((self.delta_x(F[:i]) + self.delta_y(G[:i])) * self.dz).sum(dim=0) / self.deltaA / pi1 \ # - self.sigma[i] * (pi1 - pi0) / dt / pi1 - # TODO fix this custom Op - # self.w = custom_ops.update_diag(self.w, F, G, self.delta_x_filter, self.delta_y_filter, self.deltaA, - # pi0, pi1, self.sigma, self.dz, dt) + # TODO fix SetAttr for "self.w =" + w = custom_ops.update_diag(self.w, F, G, self.delta_x_filter, self.delta_y_filter, self.deltaA, + pi0, pi1, self.sigma, dt, self.dz) # print('w:', self.w.mean()) # update potential temperature theta (nz, ny, nx) diff --git a/examples/custom_ops.py b/examples/custom_ops.py index 7ad1df3a..699c4e8d 100644 --- a/examples/custom_ops.py +++ b/examples/custom_ops.py @@ -4,7 +4,7 @@ ############### custom op ################# def update_diag(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, - pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dz:float, dt:torch.Tensor): + pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dt:torch.Tensor, dz:float): # def pre_conv3d_reshape(X): # return einops.einops.rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # def post_conv3d_reshape(X): From 781d05b427c15c269e184a4d22334ee6e8b5d9f5 Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 21 Mar 2022 11:20:25 +0800 Subject: [PATCH 0671/1892] add custom_op update_diag nit fix --- cube/graph/operator/function/customops.py | 2 -- cube/runtime/function/function.py | 6 ++++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cube/graph/operator/function/customops.py b/cube/graph/operator/function/customops.py index 23541107..2f6da84a 100644 --- a/cube/graph/operator/function/customops.py +++ b/cube/graph/operator/function/customops.py @@ -41,9 +41,7 @@ def infer_shape(self) -> bool: return True elif self.signature.endswith('update_diag'): shape = self.inputs(0).shape - print(f'### {self.signature} in.shape = {shape}') self.outputs(0).shape = shape - print(f'### {self.signature} out.shape = {shape}') return True else: raise RuntimeError(f'IRCustomOps::infer_shape unknown signature: {self.signature}') diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 364e4beb..1fa967aa 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -54,15 +54,17 @@ def einops(input: torch.Tensor, recipe_str, reduction_type: str): def update_diag(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dt, dz): + import einops.einops def pre_conv3d_reshape(X): return einops.einops.rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) def post_conv3d_reshape(X): return einops.einops.rearrange(X, 'b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') def delta_x(X): - return post_conv3d_reshape(F.conv3d(pre_conv3d_reshape(X), delta_x_filter)) + return post_conv3d_reshape(torch.nn.functional.conv3d(pre_conv3d_reshape(X), delta_x_filter)) def delta_y(X): - return post_conv3d_reshape(F.conv3d(pre_conv3d_reshape(X), delta_y_filter)) + return post_conv3d_reshape(torch.nn.functional.conv3d(pre_conv3d_reshape(X), delta_y_filter)) # update diagnostic variable w (nz + 1, ny, nx) + w.detach_() #to prevent ERROR: A leaf Variable that requires grad is being used in an in-place operation. for i in range(1, w.shape[0]): w[i] = - ((delta_x(F[:i]) + delta_y(G[:i])) * dz).sum(dim=0) / deltaA / pi1 \ - sigma[i] * (pi1 - pi0) / dt / pi1 From d36bf65ba13e1f6a45662dcd3296b53d076dd23d Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 21 Mar 2022 13:28:33 +0800 Subject: [PATCH 0672/1892] 1. add custom_op update_geopotential_ 2. rename custom_ops that modifies content from abc to abc_ --- cube/graph/operator/function/customops.py | 24 ++++++++++++++++++++--- cube/graph/operator/function/function.py | 9 ++++++++- cube/graph/parser/mapping.py | 3 ++- cube/runtime/function/function.py | 15 +++++++++++++- examples/atmosphere/weather.py | 4 ++-- examples/custom_ops.py | 12 +++++++++++- 6 files changed, 58 insertions(+), 9 deletions(-) diff --git a/cube/graph/operator/function/customops.py b/cube/graph/operator/function/customops.py index 2f6da84a..c3f71649 100644 --- a/cube/graph/operator/function/customops.py +++ b/cube/graph/operator/function/customops.py @@ -16,7 +16,7 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, super().__init__(name, signature, 1, 1) for idx, input in enumerate(inputs): self.set_input(idx, input) - elif signature == 'examples.custom_ops.update_diag': + elif signature == 'examples.custom_ops.update_diag_': signature = signature.replace('examples.custom_ops', 'cube.runtime.function') assert len(inputs) == 10, "Expected only input, weight, bias as inputs" assert len(kwargs) == 1, "Expected 0 kwargs: " @@ -24,6 +24,14 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, for idx, input in enumerate(inputs): self.set_input(idx, input) self.kwargs.update(kwargs) + elif signature == 'examples.custom_ops.update_geopotential_': + signature = signature.replace('examples.custom_ops', 'cube.runtime.function') + assert len(inputs) == 5, "Expected only input, weight, bias as inputs" + assert len(kwargs) == 3, "Expected 0 kwargs: " + super().__init__(name, signature, len(inputs), 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update(kwargs) else: raise RuntimeError(f'IRCustomOps::__init__ unknown signature: {self.signature}') @@ -39,7 +47,11 @@ def infer_shape(self) -> bool: shape[0] = shape[0]-2 self.outputs(0).shape = shape return True - elif self.signature.endswith('update_diag'): + elif self.signature.endswith('update_diag_'): + shape = self.inputs(0).shape + self.outputs(0).shape = shape + return True + elif self.signature.endswith('update_geopotential_'): shape = self.inputs(0).shape self.outputs(0).shape = shape return True @@ -57,7 +69,13 @@ def new(self, inputs: List, outputs: List): op.set_output(0, outputs[0]) op.infer_shape() return op - elif self.signature.endswith('update_diag'): + elif self.signature.endswith('update_diag_'): + op = IRCustomOps(self.signature, inputs, self.name, self.kwargs) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + op.infer_shape() + return op + elif self.signature.endswith('update_geopotential_'): op = IRCustomOps(self.signature, inputs, self.name, self.kwargs) assert len(outputs) == 1 op.set_output(0, outputs[0]) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 79df2c54..c721b6c7 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -471,11 +471,18 @@ def CustomOps(signature, inputs): tensors = inputs[0:1] print(f'CustomOps:tensors[0] = {tensors[0]}') return IRCustomOps(signature, tensors, 'custom_ops') - elif signature == 'examples.custom_ops.update_diag': + elif signature == 'examples.custom_ops.update_diag_': tensors = inputs[0:10] # dt = inputs[9] dz = inputs[10] return IRCustomOps(signature, tensors, 'custom_ops', dz=dz) + elif signature == 'examples.custom_ops.update_geopotential_': + tensors = inputs[0:5] + g = inputs[5] + CPD = inputs[6] + nz = inputs[7] + return IRCustomOps(signature, tensors, 'custom_ops', g=g, CPD=CPD, nz=nz) + else: import warnings warnings.warn(f"ERROR Unknown custom op, signature {signature}") \ No newline at end of file diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index fda88819..b984ac94 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -106,7 +106,8 @@ def register(signature: str, op: IRFwOperation, code): #custom ops __customops('strip_2_borders'): function.CustomOps, - __customops('update_diag'): function.CustomOps, + __customops('update_diag_'): function.CustomOps, + __customops('update_geopotential_'): function.CustomOps, } # customized operator code: signature -> code diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 1fa967aa..546844ed 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -51,7 +51,7 @@ def einops(input: torch.Tensor, recipe_str, reduction_type: str): ############### custom op ################# #TODO move me -def update_diag(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, +def update_diag_(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dt, dz): import einops.einops @@ -64,6 +64,8 @@ def delta_x(X): def delta_y(X): return post_conv3d_reshape(torch.nn.functional.conv3d(pre_conv3d_reshape(X), delta_y_filter)) # update diagnostic variable w (nz + 1, ny, nx) + import warnings + warnings.warn("detaching w in update_diag_...") w.detach_() #to prevent ERROR: A leaf Variable that requires grad is being used in an in-place operation. for i in range(1, w.shape[0]): w[i] = - ((delta_x(F[:i]) + delta_y(G[:i])) * dz).sum(dim=0) / deltaA / pi1 \ @@ -71,6 +73,17 @@ def delta_y(X): return w +def update_geopotential_(phi: torch.Tensor, zs: torch.Tensor, P: torch.Tensor, P_: torch.Tensor, theta: torch.Tensor, g, CPD, nz): + import warnings + warnings.warn("detaching phi in update_geopotential_...") + phi.detach_() + phi[-1] = g * zs - CPD * (P[-1] - P_[-1]) * theta[-1] + for i in range(1, nz): + tmp = phi[-i] - CPD * (P_[-i - 1] - P[-i]) * theta[-i] + phi[-1 - i] = tmp - CPD * (P[-1 - i] - P_[-1 - i]) * theta[-1 - i] + + return phi + def strip_2_borders(w: torch.Tensor): return w[1:-1] diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py index 021def86..12dc8261 100644 --- a/examples/atmosphere/weather.py +++ b/examples/atmosphere/weather.py @@ -93,7 +93,7 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # self.w[i] = - ((self.delta_x(F[:i]) + self.delta_y(G[:i])) * self.dz).sum(dim=0) / self.deltaA / pi1 \ # - self.sigma[i] * (pi1 - pi0) / dt / pi1 # TODO fix SetAttr for "self.w =" - w = custom_ops.update_diag(self.w, F, G, self.delta_x_filter, self.delta_y_filter, self.deltaA, + custom_ops.update_diag_(self.w, F, G, self.delta_x_filter, self.delta_y_filter, self.deltaA, pi0, pi1, self.sigma, dt, self.dz) # print('w:', self.w.mean()) @@ -116,7 +116,7 @@ def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): # for i in range(1, self.nz): # tmp = self.phi[-i] - self.CPD * (self.P_[-i - 1] - self.P[-i]) * theta[-i] # self.phi[-1 - i] = tmp - self.CPD * (self.P[-1 - i] - self.P_[-1 - i]) * theta[-1 - i] - + custom_ops.update_geopotential_(self.phi, self.zs, self.P, self.P_, theta, self.g, self.CPD, self.nz) # print('phi:', self.phi.mean()) # update u (nz, ny, nx + 1) diff --git a/examples/custom_ops.py b/examples/custom_ops.py index 699c4e8d..151c15ef 100644 --- a/examples/custom_ops.py +++ b/examples/custom_ops.py @@ -2,9 +2,10 @@ import einops ############### custom op ################# -def update_diag(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, +def update_diag_(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dt:torch.Tensor, dz:float): + #NOTE place holder # def pre_conv3d_reshape(X): # return einops.einops.rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # def post_conv3d_reshape(X): @@ -20,5 +21,14 @@ def update_diag(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, return w +def update_geopotential_(phi: torch.Tensor, zs: torch.Tensor, P: torch.Tensor, P_: torch.Tensor, theta: torch.Tensor, + g:float, CPD:float, nz:int): + # NOTE place holder + # phi[-1] = g * zs - CPD * (P[-1] - P_[-1]) * theta[-1] + # for i in range(1, nz): + # tmp = phi[-i] - CPD * (P_[-i - 1] - P[-i]) * theta[-i] + # phi[-1 - i] = tmp - CPD * (P[-1 - i] - P_[-1 - i]) * theta[-1 - i] + return phi + def strip_2_borders(w: torch.Tensor): return w[1:-1] \ No newline at end of file From 2ed1aa71a92ca3d659d2ce41988be0ffe311991d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 21 Mar 2022 15:43:47 +0800 Subject: [PATCH 0673/1892] switch to 3b model and enable re-compute --- handcraft/mbart/mbart.py | 281 +++++++++----------------------- handcraft/mbart/mbart_hybrid.py | 170 ++++++++----------- handcraft/mbart/run.sh | 38 +++-- handcraft/mbart/schedule.py | 4 +- 4 files changed, 174 insertions(+), 319 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 9fa0e150..4ab27bac 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -11,6 +11,7 @@ import argparse import math import torch +from torch.utils import checkpoint import cube from cube.runtime.device import DeviceGroup @@ -20,9 +21,9 @@ from cube.profiler.timer import print_each_rank from handcraft.mbart.schedule import schedule_naive, schedule_1f1b, schedule_tp_1f1b_pack -from handcraft.mbart.tp import AllGatherScatter, BroadcastReduce, ReduceBroadcast +from handcraft.mbart.tp import ReduceBroadcast + -_tp_group = -1 _pp_group = -1 _pp_embed_group = -1 _pp_next_rank = None @@ -38,23 +39,20 @@ class Config: num_embeddings = 250027 - encoder_embed_path = None - encoder_embed_dim = 1024 - encoder_ffn_embed_dim = 4 * 1024 + # d_ff + decoder_layers = 12 encoder_layers = 12 - encoder_attention_heads = 16 - encoder_normalize_before = True - encoder_learned_pos = True + embed_dim = 1024 - decoder_embed_path = None - decoder_embed_dim = 1024 - decoder_ffn_embed_dim = 4 * 1024 - decoder_layers = 12 - decoder_attention_heads = 16 - decoder_normalize_before = True - decoder_learned_pos = True - cross_self_attention = False - no_cross_attention = False + # 610M model -> original setting + # attention_heads = 16 + # attention_inner_dim = attention_heads * 64 + # ffn_dim = 4 * embed_dim + + # t5-3b config + attention_heads = 32 + attention_inner_dim = attention_heads * 128 + ffn_dim = 16384 attention_dropout = 0.0 activation_dropout = 0.0 @@ -68,17 +66,9 @@ class Config: share_decoder_input_output_embed = True share_all_embeddings = True - decoder_output_dim = 1024 # same with decorder_embed_dim - decoder_input_dim = 1024 # same with decorder_embed_dim - - no_scale_embedding = False # True in bart large - layernorm_embedding = True - activation_fn = 'gelu' - pooler_activation_fn = 'tanh' - pooler_dropout = 0.0 - # classification task - num_classes = 2 + pooler_dropout = 0.0 + num_classes = 3 def attn_fn(query: torch.Tensor, key: torch.Tensor, @@ -131,63 +121,35 @@ def attn_fn(query: torch.Tensor, key: torch.Tensor, class MultiheadAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout=0.0, bias=True): + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, bias=True): super().__init__() - self.kdim = embed_dim - self.vdim = embed_dim + self.inner_dim = inner_dim self.num_heads = num_heads - self.head_dim = embed_dim // num_heads + self.head_dim = inner_dim // num_heads self.scaling = self.head_dim ** -0.5 self.dropout_p = dropout # K - self.k_proj = torch.nn.Parameter(torch.empty(embed_dim, self.kdim)) - if bias: - self.k_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.k_bias = None + self.k_proj = torch.nn.Parameter(torch.empty(self.inner_dim, embed_dim)) + self.k_bias = torch.nn.Parameter(torch.empty(self.inner_dim)) if bias else None # V - self.v_proj = torch.nn.Parameter(torch.empty(embed_dim, self.vdim)) - if bias: - self.v_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.v_bias = None + self.v_proj = torch.nn.Parameter(torch.empty(self.inner_dim, embed_dim)) + self.v_bias = torch.nn.Parameter(torch.empty(self.inner_dim)) if bias else None # Q - self.q_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) - if bias: - self.q_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.q_bias = None + self.q_proj = torch.nn.Parameter(torch.empty(self.inner_dim, embed_dim)) + self.q_bias = torch.nn.Parameter(torch.empty(self.inner_dim)) if bias else None # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) - if bias: - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.out_bias = None + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, self.inner_dim)) + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None def forward(self, query: torch.Tensor, key: torch.Tensor): return attn_fn(query, key, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) def forward_encoder_decoder_attn(self, query: torch.Tensor, key: torch.Tensor): - # tgt_len, bsz, embed_dim = query.size() - # q = torch.nn.functional.linear(query, self.q_proj, self.q_bias) - # k = torch.nn.functional.linear(key, self.k_proj, self.k_bias) - # v = torch.nn.functional.linear(key, self.v_proj, self.v_bias) - # q = q * self.scaling - # q = q.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) - # k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) - # v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) - # attn_weights = torch.bmm(q, k.transpose(1, 2)) - # # TODO: here needs a mask - # attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - # attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout_p) - # attn = torch.bmm(attn_probs, v) - # attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) - # attn = torch.nn.functional.linear(attn, self.out_proj, self.out_bias) return attn_fn(query, key, self.q_proj, self.q_bias, self.k_proj, self.k_bias, @@ -210,21 +172,21 @@ def __init__(self, cfg: Config): super().__init__() self.cfg = cfg - self.self_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.encoder_attention_heads, cfg.attention_dropout) - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) self.dropout = torch.nn.Dropout(p=cfg.dropout) self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.fc1 = torch.nn.Linear(cfg.encoder_embed_dim, cfg.encoder_ffn_embed_dim) - self.fc2 = torch.nn.Linear(cfg.encoder_ffn_embed_dim, cfg.encoder_embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim) + self.fc2 = torch.nn.Linear(cfg.ffn_dim, cfg.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) def input_shape(self): # L, N, E - return (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + return (self.cfg.max_source_positions, 1, self.cfg.embed_dim) def output_shape(self): # L N E - return (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + return (self.cfg.max_source_positions, 1, self.cfg.embed_dim) def input_dtype(self): return torch.float32 @@ -252,38 +214,6 @@ def forward(self, x): # , encoder_padding_mask: Optional[torch.Tensor], attn_ma return x -class Encoder(torch.nn.Module): - - def __init__(self, cfg: Config, embed_tokens: torch.nn.Embedding): - super().__init__() - self.dropout = torch.nn.Dropout(cfg.dropout) - self.max_source_positions = cfg.max_source_positions - self.embed_tokens = embed_tokens - self.embed_scale = math.sqrt(cfg.encoder_embed_dim) - self.embed_positions = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) - self.layernorm_embedding = torch.nn.LayerNorm(cfg.encoder_embed_dim) - self.layers = torch.nn.ModuleList([]) - self.layers.extend( - [EncoderLayer(cfg) for _ in range(cfg.encoder_layers)] - ) - # normalize before - self.layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) - - def forward(self, src_tokens: torch.Tensor): - token_embedding = torch.nn.functional.embedding(src_tokens, self.embed_tokens.weight) # self.embed_tokens(src_tokens) - embed = self.embed_scale * token_embedding - - x = embed + self.embed_positions.weight # self.embed_positions(src_tokens) - x = self.layernorm_embedding(x) - x = self.dropout(x) - - x = x.transpose(0, 1) - for layer in self.layers: - x = layer(x) # encoder_padding_mask if has_pads else None) - x = self.layer_norm(x) - return x - - class DecoderLayer(torch.nn.Module): def __init__(self, cfg: Config): @@ -291,29 +221,28 @@ def __init__(self, cfg: Config): super().__init__() self.cfg = cfg self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.self_attn = MultiheadAttention(cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) + self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) # encoder atten - self.encoder_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) - self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) - # self.encoder_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.encoder_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) + self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - self.fc1 = torch.nn.Linear(cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim) - self.fc2 = torch.nn.Linear(cfg.decoder_ffn_embed_dim, cfg.decoder_embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim) + self.fc2 = torch.nn.Linear(cfg.ffn_dim, cfg.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) def input_shape(self): return ( - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + (self.cfg.max_target_positions, 1, self.cfg.embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim) ) def output_shape(self): return ( - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + (self.cfg.max_target_positions, 1, self.cfg.embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim) ) def input_dtype(self): @@ -322,10 +251,9 @@ def input_dtype(self): def output_dtype(self): return (torch.float32, torch.float32) - def forward(self, x, encoder_out): # encoder_padding_mask): + def forward(self, x, encoder_out): # print(f'decoder layer: x: {x.size()}, encoder_out: {encoder_out.size()}') residual = x - # normalize before x = self.self_attn_layer_norm(x) # self attention @@ -335,14 +263,12 @@ def forward(self, x, encoder_out): # encoder_padding_mask): # encoder attn residual = x - # normalize before x = self.encoder_attn_layer_norm(x) x = self.encoder_attn(x, encoder_out) x = self.dropout(x) x = x + residual residual = x - # normalize before x = self.final_layer_norm(x) x = self.fc1(x) x = torch.nn.functional.gelu(x) @@ -353,60 +279,6 @@ def forward(self, x, encoder_out): # encoder_padding_mask): return x, encoder_out -class Decoder(torch.nn.Module): - - def __init__(self, cfg: Config, embed_tokens: torch.nn.Embedding): - super().__init__() - self.dropout = torch.nn.Dropout(cfg.dropout) - self.embed_tokens = embed_tokens - self.embed_scale = math.sqrt(cfg.decoder_embed_dim) - self.embed_positions = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) - self.layernorm_embedding = torch.nn.LayerNorm(cfg.decoder_embed_dim) - self.layers = torch.nn.ModuleList([]) - self.layers.extend( - [DecoderLayer(cfg) for _ in range(cfg.decoder_layers)] - ) - self.layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) - - def forward(self, prev_output_tokens: torch.Tensor, enc: torch.Tensor): - positions = self.embed_positions.weight # self.embed_positions(prev_output_tokens) - embed = torch.nn.functional.embedding(prev_output_tokens, self.embed_tokens.weight) - x = self.embed_scale * embed - x = x + positions - x = self.layernorm_embedding(x) - x = self.dropout(x) - # B T C -> T B C - x = x.transpose(0, 1) - # decoder layers - for layer in self.layers: - x, enc = layer(x, enc) - x = self.layer_norm(x) - # T x B x C -> B x T x C - x = x.transpose(0, 1) - # B T C, N, C -> B T N - x = torch.nn.functional.linear(x, self.embed_tokens.weight) - return x - -# label_smoothed_cross_entropy -def criterion(output: torch.Tensor, prev_output_tokens: torch.Tensor, label_smoothing: float = 0.2): - target = prev_output_tokens[:, 1:] - # fairseq.criterions.label_smoothed_cross_entory - # model.get_normalized_probs - lprobs = torch.nn.functional.softmax(output, dim=-1) - # fairseq.criterions.label_smoothed_nll_loss - if target.dim() == lprobs.dim() - 1: - target = target.unsqueeze(-1) - nll_loss = -lprobs.gather(dim=-1, index=target) - smooth_loss = -lprobs.sum(dim=-1, keepdim=True) - nll_loss = nll_loss.squeeze(-1) - smooth_loss = smooth_loss.squeeze(-1) - nll_loss = nll_loss.sum() - smooth_loss = smooth_loss.sum() - eps_i = label_smoothing / (lprobs.size(-1) - 1) - loss = (1.0 - label_smoothing - eps_i) * nll_loss + eps_i * smooth_loss - return loss - - class MBartClassificationHead(torch.nn.Module): """Head for sentence-level classification tasks.""" @@ -437,7 +309,7 @@ def forward(self, dec: torch.Tensor, labels): return (loss,) -class ShardHeadTail(torch.nn.Module): +class ShardEmbed(torch.nn.Module): def __init__(self, cfg: Config, group=-1): """ @@ -453,17 +325,17 @@ def __init__(self, cfg: Config, group=-1): self.vocab_start_index = self.cfg.num_embeddings // self.shard_num * self.shard_idx self.vocab_end_index = self.cfg.num_embeddings // self.shard_num * (self.shard_idx + 1) - self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.shard_num, self.cfg.encoder_embed_dim))) + self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.shard_num, self.cfg.embed_dim))) # encoder-preprocess - self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) - self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) - self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.embed_dim) + self.embed_scale_encoder = math.sqrt(cfg.embed_dim) + self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.embed_dim) # decoder-preprocess - self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) - self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) - self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.embed_scale_decoder = math.sqrt(cfg.embed_dim) + self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.embed_dim) + self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.embed_dim) # post-proces @@ -546,7 +418,7 @@ def __init__(self, cfg: Config, self.decoder_layer_end = self.layer_end if encoder_preprocess or decoder_preprocess or post_process or shard: - self.headtail = ShardHeadTail(cfg, group = None if shard else -1) + self.headtail = ShardEmbed(cfg, group = None if shard else -1) else: self.headtail = None @@ -555,7 +427,7 @@ def __init__(self, cfg: Config, print(f'[{self.rank}]: initializing {self.encoder_layer_end - self.encoder_layer_start} encoder layers') self.encoders = torch.nn.ModuleList([EncoderLayer(cfg) for _ in range(self.encoder_layer_end - self.encoder_layer_start)]) if self.encoder_layer_end == cfg.encoder_layers: - self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.layer_norm_encoder = torch.nn.LayerNorm(cfg.embed_dim) else: self.layer_norm_encoder = None @@ -564,34 +436,34 @@ def __init__(self, cfg: Config, print(f'[{self.rank}]: initializing {self.decoder_layer_end - self.decoder_layer_start} decoder layers') self.decoders = torch.nn.ModuleList([DecoderLayer(cfg) for _ in range(self.decoder_layer_end - self.decoder_layer_start)]) if self.decoder_layer_end == cfg.encoder_layers + cfg.decoder_layers: - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.embed_dim) else: self.layer_norm_decoder = None # postpross if self.postprocess: print(f'[{self.rank}]: will compute loss') - self.head = MBartClassificationHead(cfg.decoder_embed_dim, 1024, cfg.num_classes, 0.0) + self.head = MBartClassificationHead(cfg.embed_dim, 1024, cfg.num_classes, 0.0) def input_shape(self): if self.encoder_preprocess: return () elif self.encoder_forward: return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), ) elif self.decoder_preprocess: return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), ) elif self.decoder_first_stage: return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), ) elif self.decoder_forward: return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.embed_dim), ) elif self.postprocess: return ((1,),) @@ -600,12 +472,12 @@ def input_shape(self): def output_shape(self): shape = None if self.encoder_preprocess or self.encoder_forward: - shape = (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + shape = (self.cfg.max_source_positions, 1, self.cfg.embed_dim), # decoder preprocess is not allowed to be a single stage if self.decoder_preprocess or self.decoder_forward: shape = ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.embed_dim), ) if self.postprocess: shape = ( @@ -677,7 +549,8 @@ def forward(self, enc=None, dec=None): # forward encoder if self.encoder_forward: for layer in self.encoders: - enc = layer(enc) # encoder_padding_mask if has_pads else None) + enc = checkpoint.checkpoint(layer, enc) + # enc = layer(enc) if self.layer_norm_encoder is not None: enc = self.layer_norm_encoder(enc) output = (enc,) @@ -691,7 +564,8 @@ def forward(self, enc=None, dec=None): if self.decoder_forward: dec = pre_dec if dec is None else dec for layer in self.decoders: - dec, enc = layer(dec, enc) + dec, enc = checkpoint.checkpoint(layer, dec, enc) + # dec, enc = layer(dec, enc) if self.layer_norm_decoder is not None: dec = self.layer_norm_decoder(dec) output = (dec,) @@ -799,6 +673,9 @@ def reduce_embed(model, pp_embed_group): memory_summary() optimizer.step() optimizer.zero_grad() + if step == 0: + print('memory after optimizer') + memory_summary() if step >= 10: CudaTimer().stop('e2e') if (step + 1) % 10 == 0: diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 517658c4..e39c4b73 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -11,6 +11,7 @@ import argparse import math import torch +from torch.utils import checkpoint import cube from cube.runtime.device import DeviceGroup @@ -20,7 +21,7 @@ from cube.profiler.timer import print_each_rank from handcraft.mbart.schedule import schedule_naive, schedule_1f1b, schedule_tp_1f1b_pack -from handcraft.mbart.tp import AllGatherScatter, AllReduceIdentity, BroadcastReduce, IdentityAllreduce, ReduceBroadcast +from handcraft.mbart.tp import AllReduceIdentity, IdentityAllreduce, ReduceBroadcast _tp_group = -1 _pp_group = -1 @@ -28,33 +29,25 @@ _pp_next_rank = None _pp_prev_rank = None -# fairseq task -# translation_from_pretrained_bart - -# fairseq criterion -# label_smoothed_cross_entropy, --label_smoothing = 0.2 class Config: num_embeddings = 250027 - encoder_embed_path = None - encoder_embed_dim = 1024 - encoder_ffn_embed_dim = 4 * 1024 + # d_ff + decoder_layers = 12 encoder_layers = 12 - encoder_attention_heads = 16 - encoder_normalize_before = True - encoder_learned_pos = True + embed_dim = 1024 - decoder_embed_path = None - decoder_embed_dim = 1024 - decoder_ffn_embed_dim = 4 * 1024 - decoder_layers = 12 - decoder_attention_heads = 16 - decoder_normalize_before = True - decoder_learned_pos = True - cross_self_attention = False - no_cross_attention = False + # 610M model -> original setting + # attention_heads = 16 + # attention_inner_dim = attention_heads * 64 + # ffn_dim = 4 * embed_dim + + # t5-3b config + attention_heads = 32 + attention_inner_dim = attention_heads * 128 + ffn_dim = 16384 attention_dropout = 0.0 activation_dropout = 0.0 @@ -68,17 +61,9 @@ class Config: share_decoder_input_output_embed = True share_all_embeddings = True - decoder_output_dim = 1024 # same with decorder_embed_dim - decoder_input_dim = 1024 # same with decorder_embed_dim - - no_scale_embedding = False # True in bart large - layernorm_embedding = True - activation_fn = 'gelu' - pooler_activation_fn = 'tanh' - pooler_dropout = 0.0 - # classification task - num_classes = 2 + pooler_dropout = 0.0 + num_classes = 3 def attn_fn(query: torch.Tensor, key: torch.Tensor, @@ -131,42 +116,29 @@ def attn_fn(query: torch.Tensor, key: torch.Tensor, class MultiheadAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout=0.0, bias=True): + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, bias=True): super().__init__() self.tp_group = _tp_group self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - self.qdim = embed_dim - self.kdim = embed_dim - self.vdim = embed_dim - self.head_dim = embed_dim // num_heads - self.scaling = self.head_dim ** -0.5 + self.inner_dim = inner_dim self.num_heads = num_heads + self.head_dim = inner_dim // num_heads + self.scaling = self.head_dim ** -0.5 self.dropout_p = dropout # K - self.k_proj = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size, self.kdim)) - if bias: - self.k_bias = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size)) - else: - self.k_bias = None + self.k_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) + self.k_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None # V - self.v_proj = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size, self.vdim)) - if bias: - self.v_bias = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size)) - else: - self.v_bias = None + self.v_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) + self.v_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None # Q - self.q_proj = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size, self.qdim)) - if bias: - self.q_bias = torch.nn.Parameter(torch.empty(embed_dim // self.tp_size)) - else: - self.q_bias = None + self.q_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) + self.q_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim // self.tp_size)) - if bias: - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.out_bias = None + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, self.inner_dim // self.tp_size)) + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + def forward(self, query: torch.Tensor, key: torch.Tensor): if key is not query: @@ -190,21 +162,21 @@ def __init__(self, cfg: Config): self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) self.cfg = cfg - self.self_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.encoder_attention_heads, cfg.attention_dropout) - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) self.dropout = torch.nn.Dropout(p=cfg.dropout) self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.fc1 = torch.nn.Linear(cfg.encoder_embed_dim, cfg.encoder_ffn_embed_dim // self.tp_size) - self.fc2 = torch.nn.Linear(cfg.encoder_ffn_embed_dim // self.tp_size, cfg.encoder_embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) + self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) def input_shape(self): # L, N, E - return (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + return (self.cfg.max_source_positions, 1, self.cfg.embed_dim) def output_shape(self): # L N E - return (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + return (self.cfg.max_source_positions, 1, self.cfg.embed_dim) def input_dtype(self): return torch.float32 @@ -247,28 +219,28 @@ def __init__(self, cfg: Config): self.cfg = cfg self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.self_attn = MultiheadAttention(cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) + self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) # encoder atten - self.encoder_attn = MultiheadAttention(cfg.encoder_embed_dim, cfg.decoder_attention_heads, cfg.attention_dropout) - self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.encoder_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) + self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - self.fc1 = torch.nn.Linear(cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim // self.tp_size) - self.fc2 = torch.nn.Linear(cfg.decoder_ffn_embed_dim // self.tp_size, cfg.decoder_embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) + self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) def input_shape(self): return ( - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + (self.cfg.max_target_positions, 1, self.cfg.embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim) ) def output_shape(self): return ( - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim) + (self.cfg.max_target_positions, 1, self.cfg.embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim) ) def input_dtype(self): @@ -349,7 +321,7 @@ def forward(self, dec: torch.Tensor, labels): return (loss,) -class ShardHeadTail(torch.nn.Module): +class SharedEmbed(torch.nn.Module): def __init__(self, cfg: Config): """ @@ -366,17 +338,17 @@ def __init__(self, cfg: Config): self.vocab_start_index = self.cfg.num_embeddings // self.tp_size * self.tp_idx self.vocab_end_index = self.cfg.num_embeddings // self.tp_size * (self.tp_idx + 1) - self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.tp_size, self.cfg.encoder_embed_dim))) + self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.tp_size, self.cfg.embed_dim))) # encoder-preprocess - self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.encoder_embed_dim) - self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) - self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.embed_dim) + self.embed_scale_encoder = math.sqrt(cfg.embed_dim) + self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.embed_dim) # decoder-preprocess - self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) - self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.decoder_embed_dim) - self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.embed_scale_decoder = math.sqrt(cfg.embed_dim) + self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.embed_dim) + self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.embed_dim) self._inputs = (None, ) @@ -457,7 +429,7 @@ def __init__(self, cfg: Config, self.decoder_layer_end = self.layer_end if encoder_preprocess or decoder_preprocess or post_process: - self.headtail = ShardHeadTail(cfg) + self.headtail = SharedEmbed(cfg) else: self.headtail = None @@ -466,7 +438,7 @@ def __init__(self, cfg: Config, print(f'[{self.rank}]: initializing {self.encoder_layer_end - self.encoder_layer_start} encoder layers') self.encoders = torch.nn.ModuleList([EncoderLayer(cfg) for _ in range(self.encoder_layer_end - self.encoder_layer_start)]) if self.encoder_layer_end == cfg.encoder_layers: - self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.layer_norm_encoder = torch.nn.LayerNorm(cfg.embed_dim) else: self.layer_norm_encoder = None @@ -475,34 +447,34 @@ def __init__(self, cfg: Config, print(f'[{self.rank}]: initializing {self.decoder_layer_end - self.decoder_layer_start} decoder layers') self.decoders = torch.nn.ModuleList([DecoderLayer(cfg) for _ in range(self.decoder_layer_end - self.decoder_layer_start)]) if self.decoder_layer_end == cfg.encoder_layers + cfg.decoder_layers: - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.embed_dim) else: self.layer_norm_decoder = None # postpross if self.postprocess: print(f'[{self.rank}]: will compute loss') - self.head = MBartClassificationHead(cfg.decoder_embed_dim, 1024, cfg.num_classes, 0.0) + self.head = MBartClassificationHead(cfg.embed_dim, 1024, cfg.num_classes, 0.0) def input_shape(self): if self.encoder_preprocess: return () elif self.encoder_forward: return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), ) elif self.decoder_preprocess: return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), ) elif self.decoder_first_stage: return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), ) elif self.decoder_forward: return ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.embed_dim), ) elif self.postprocess: return ((1,),) @@ -511,12 +483,12 @@ def input_shape(self): def output_shape(self): shape = None if self.encoder_preprocess or self.encoder_forward: - shape = (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), + shape = (self.cfg.max_source_positions, 1, self.cfg.embed_dim), # decoder preprocess is not allowed to be a single stage if self.decoder_preprocess or self.decoder_forward: shape = ( - (self.cfg.max_source_positions, 1, self.cfg.encoder_embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.decoder_embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim), + (self.cfg.max_target_positions, 1, self.cfg.embed_dim), ) if self.postprocess: shape = ( @@ -588,7 +560,8 @@ def forward(self, enc=None, dec=None): # forward encoder if self.encoder_forward: for layer in self.encoders: - enc = layer(enc) # encoder_padding_mask if has_pads else None) + enc = checkpoint.checkpoint(layer, enc) + # enc = layer(enc) if self.layer_norm_encoder is not None: enc = self.layer_norm_encoder(enc) output = (enc,) @@ -602,7 +575,8 @@ def forward(self, enc=None, dec=None): if self.decoder_forward: dec = pre_dec if dec is None else dec for layer in self.decoders: - dec, enc = layer(dec, enc) + dec, enc = checkpoint.checkpoint(layer, dec, enc) + # dec, enc = layer(dec, enc) if self.layer_norm_decoder is not None: dec = self.layer_norm_decoder(dec) output = (dec,) @@ -663,7 +637,7 @@ def reduce_embed(model, pp_embed_group): if len(pp_ranks) > 1: pranks = [torch.zeros((args.pp_size,), dtype=torch.int, device=torch.cuda.current_device()) for _ in range(args.tp_size)] prank = torch.tensor(pp_ranks, dtype=torch.int).cuda() - pranks[torch.distributed.get_rank(_pp_group)] = prank + pranks[torch.distributed.get_rank(_tp_group)] = prank torch.distributed.all_gather(pranks, prank, group=_tp_group) torch.cuda.synchronize() print_each_rank(f'allgather-pp ranks: {pranks}') diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh index bd15c737..4d524a6a 100755 --- a/handcraft/mbart/run.sh +++ b/handcraft/mbart/run.sh @@ -1,33 +1,37 @@ +evaldir=eval/3b-checkpoint + +mkdir -p ${evaldir} + # 4 gpus -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 64 > 4dev64nmb-tp1f1b-pack.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-naive --nmb 64 > 4dev64nmb-naive.txt +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 64 > ${evaldir}/4dev64nmb-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 64 > 4dev64nmb-tp.txt + handcraft/mbart/mbart.py --use-naive --nmb 64 > ${evaldir}/4dev64nmb-naive.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > 4dev64nmb-tp2pp2.txt + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 64 > ${evaldir}/4dev64nmb-tp.txt -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/pipeline/dummy_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > 4dev64nmb-2tp2pp.txt +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > ${evaldir}/4dev64nmb-tp2pp2.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/pipeline/dummy_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > ${evaldir}/4dev64nmb-2tp2pp.txt # 8 gpus -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 128 > 8dev128nmb-tp1f1b-pack.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-naive --nmb 128 > 8dev128nmb-naive.txt +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 128 > ${evaldir}/8dev128nmb-tp1f1b-pack.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart.py --use-naive --nmb 128 > ${evaldir}/8dev128nmb-naive.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 128 > 8dev128nmb-tp.txt + handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 128 > ${evaldir}/8dev128nmb-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 128 > 8dev128nmb-tp4pp2.txt + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 128 > ${evaldir}/8dev128nmb-tp4pp2.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 128 > 8dev128nmb-tp2pp4.txt \ No newline at end of file + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 128 > ${evaldir}/8dev128nmb-tp2pp4.txt diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py index 21a2e9e5..c51e5dfd 100644 --- a/handcraft/mbart/schedule.py +++ b/handcraft/mbart/schedule.py @@ -462,12 +462,12 @@ def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors): # before running 1f1b: need to recv first forward tensor if num_warmup_remaining > 0: - model.set_inputs(*next(dataloader)) + model.set_inputs(next(dataloader)) inputs = () if is_first_stage else recv_forward(model, prev_rank) # run 1f1b for i in range(num_warmup_remaining): - model.set_inputs(*next(dataloader)) + model.set_inputs(next(dataloader)) # forward outputs = forward_step(model, *inputs) input_tensors.append(inputs) From a8dc975c6406d5bab5cea00bec409fc07f1c9b02 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 22 Mar 2022 14:24:22 +0800 Subject: [PATCH 0674/1892] mbart scaling --- handcraft/mbart/mbart.py | 81 ++++++++++++++------- handcraft/mbart/mbart_hybrid.py | 121 ++++++++++++++++++++++---------- 2 files changed, 139 insertions(+), 63 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 4ab27bac..d044b067 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -37,22 +37,26 @@ class Config: - num_embeddings = 250027 - - # d_ff - decoder_layers = 12 - encoder_layers = 12 - embed_dim = 1024 # 610M model -> original setting + # num_embeddings = 250027 + # decoder_layers = 12 + # encoder_layers = 12 + # embed_dim = 1024 # attention_heads = 16 # attention_inner_dim = attention_heads * 64 # ffn_dim = 4 * embed_dim - # t5-3b config - attention_heads = 32 - attention_inner_dim = attention_heads * 128 - ffn_dim = 16384 + scale = 2 + scale_p = scale * 0.25 + + num_embeddings = 250027 + int(250027*scale_p) + decoder_layers = 12 + int(12*scale_p) + encoder_layers = 12 + int(12*scale_p) + embed_dim = 1024 + int(1024*scale_p) + attention_heads = 16 + int(16*scale_p) + attention_inner_dim = attention_heads * 64 + ffn_dim = 4 * embed_dim attention_dropout = 0.0 activation_dropout = 0.0 @@ -60,11 +64,6 @@ class Config: max_target_positions = 1024 max_source_positions = 1024 - adaptive_softmax_cutoff = None - adaptive_softmax_dropout = 0 - - share_decoder_input_output_embed = True - share_all_embeddings = True # classification task pooler_dropout = 0.0 @@ -119,6 +118,19 @@ def attn_fn(query: torch.Tensor, key: torch.Tensor, return output +class PositionalEmbedding(torch.nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, seq_len: int): + positions = torch.arange( + 0, seq_len, dtype=torch.long, device=torch.cuda.current_device() + ) + return super().forward(positions + self.offset) + + class MultiheadAttention(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, bias=True): @@ -328,17 +340,15 @@ def __init__(self, cfg: Config, group=-1): self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.shard_num, self.cfg.embed_dim))) # encoder-preprocess - self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.embed_dim) + self.embed_positions_encoder = PositionalEmbedding(cfg.max_source_positions, cfg.embed_dim) self.embed_scale_encoder = math.sqrt(cfg.embed_dim) self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.embed_dim) # decoder-preprocess self.embed_scale_decoder = math.sqrt(cfg.embed_dim) - self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.embed_dim) + self.embed_positions_decoder = PositionalEmbedding(cfg.max_target_positions, cfg.embed_dim) self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.embed_dim) - # post-proces - self._inputs = (None, ) def set_inputs(self, *inputs): @@ -361,7 +371,7 @@ def encoder_preprocess(self, dst: Optional[int] = None): source_tokens = self._inputs[0] source_embed = self.embed_lookup(source_tokens, dst) embed = self.embed_scale_encoder * source_embed - x = embed + self.embed_positions_encoder.weight + x = embed + self.embed_positions_encoder(source_tokens.size(1)) x = self.layernorm_embedding_encoder(x) x = torch.nn.functional.dropout(x, p=0.0) enc = x.transpose(0, 1) @@ -371,7 +381,7 @@ def decoder_preprocess(self, dst: Optional[int] = None): prev_output_tokens = self._inputs[0] target_emb = self.embed_lookup(prev_output_tokens, dst) embed = self.embed_scale_decoder * target_emb - embed = embed + self.embed_positions_decoder.weight + embed = embed + self.embed_positions_decoder(prev_output_tokens.size(1)) embed = self.layernorm_embedding_decoder(embed) embed = torch.nn.functional.dropout(embed, p=0.0) dec = embed.transpose(0, 1) @@ -398,8 +408,25 @@ def __init__(self, cfg: Config, self.pp_stage = torch.distributed.get_rank(_pp_group) self.num_stages = torch.distributed.get_world_size(_pp_group) - self.layer_start = self.total_layers // self.num_stages * self.pp_stage - self.layer_end = self.total_layers // self.num_stages * (self.pp_stage + 1) + encoder_stages = self.num_stages // 2 + decoder_stages = self.num_stages // 2 + if self.pp_stage < self.num_stages // 2: + encoder_stages = self.num_stages // 2 + chunk = cfg.encoder_layers // encoder_stages + remain = cfg.encoder_layers % encoder_stages + layers = [chunk] * encoder_stages + for idx in range(remain): + layers[-idx] += 1 + self.layer_start = sum(layers[0:self.pp_stage]) + self.layer_end = self.layer_start + layers[self.pp_stage] + if self.pp_stage >= self.num_stages // 2: + chunk = cfg.decoder_layers // decoder_stages + remain = cfg.decoder_layers % decoder_stages + layers = [chunk] * decoder_stages + for idx in range(remain): + layers[-idx] += 1 + self.layer_start = cfg.encoder_layers + sum(layers[0:self.pp_stage-encoder_stages]) + self.layer_end = self.layer_start + layers[self.pp_stage-encoder_stages] self.encoder_preprocess = encoder_preprocess self.encoder_forward = (self.layer_start < cfg.encoder_layers) @@ -549,8 +576,8 @@ def forward(self, enc=None, dec=None): # forward encoder if self.encoder_forward: for layer in self.encoders: - enc = checkpoint.checkpoint(layer, enc) - # enc = layer(enc) + # enc = checkpoint.checkpoint(layer, enc) + enc = layer(enc) if self.layer_norm_encoder is not None: enc = self.layer_norm_encoder(enc) output = (enc,) @@ -564,8 +591,8 @@ def forward(self, enc=None, dec=None): if self.decoder_forward: dec = pre_dec if dec is None else dec for layer in self.decoders: - dec, enc = checkpoint.checkpoint(layer, dec, enc) - # dec, enc = layer(dec, enc) + # dec, enc = checkpoint.checkpoint(layer, dec, enc) + dec, enc = layer(dec, enc) if self.layer_norm_decoder is not None: dec = self.layer_norm_decoder(dec) output = (dec,) diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index e39c4b73..4499fd1e 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -32,22 +32,16 @@ class Config: - num_embeddings = 250027 + scale = 2 + scale_p = scale * 0.25 - # d_ff - decoder_layers = 12 - encoder_layers = 12 - embed_dim = 1024 - - # 610M model -> original setting - # attention_heads = 16 - # attention_inner_dim = attention_heads * 64 - # ffn_dim = 4 * embed_dim - - # t5-3b config - attention_heads = 32 - attention_inner_dim = attention_heads * 128 - ffn_dim = 16384 + num_embeddings = 250027 + int(250027*scale_p) + decoder_layers = 12 + int(12*scale_p) + encoder_layers = 12 + int(12*scale_p) + embed_dim = 1024 + int(1024*scale_p) + attention_heads = 16 + int(16*scale_p) + attention_inner_dim = attention_heads * 64 + ffn_dim = 4 * embed_dim attention_dropout = 0.0 activation_dropout = 0.0 @@ -154,6 +148,19 @@ def forward(self, query: torch.Tensor, key: torch.Tensor): return attn +class PositionalEmbedding(torch.nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, seq_len: int): + positions = torch.arange( + 0, seq_len, dtype=torch.long, device=torch.cuda.current_device() + ) + return super().forward(positions + self.offset) + + class EncoderLayer(torch.nn.Module): def __init__(self, cfg: Config): @@ -323,7 +330,7 @@ def forward(self, dec: torch.Tensor, labels): class SharedEmbed(torch.nn.Module): - def __init__(self, cfg: Config): + def __init__(self, cfg: Config, embed_cpu=False): """ group = -1 means no tensor parallelism """ @@ -331,29 +338,33 @@ def __init__(self, cfg: Config): self.tp_group = _tp_group self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) self.tp_idx = 0 if _tp_group == -1 else torch.distributed.get_rank(_tp_group) - - self.cfg = cfg if self.tp_size > 0: print(f'[{torch.distributed.get_rank()}]: initialize sharding embed (x{self.tp_size})') + + self.embed_cpu = embed_cpu + self.cfg = cfg self.vocab_start_index = self.cfg.num_embeddings // self.tp_size * self.tp_idx self.vocab_end_index = self.cfg.num_embeddings // self.tp_size * (self.tp_idx + 1) self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.tp_size, self.cfg.embed_dim))) # encoder-preprocess - self.embed_positions_encoder = torch.nn.Embedding(cfg.max_source_positions, cfg.embed_dim) + self.embed_positions_encoder = PositionalEmbedding(cfg.max_source_positions, cfg.embed_dim) self.embed_scale_encoder = math.sqrt(cfg.embed_dim) self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.embed_dim) # decoder-preprocess self.embed_scale_decoder = math.sqrt(cfg.embed_dim) - self.embed_positions_decoder = torch.nn.Embedding(cfg.max_source_positions, cfg.embed_dim) + self.embed_positions_decoder = PositionalEmbedding(cfg.max_target_positions, cfg.embed_dim) self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.embed_dim) self._inputs = (None, ) def set_inputs(self, *inputs): - self._inputs = inputs + if self.embed_cpu: + self._inputs = [input.cpu() for input in inputs] + else: + self._inputs = inputs def embed_lookup(self, tokens): if self.tp_size > 1: @@ -363,16 +374,23 @@ def embed_lookup(self, tokens): tokens[mask] = 0 embed = torch.nn.functional.embedding(tokens, self.weight) embed[mask, :] = 0.0 + if self.embed_cpu: + embed = embed.cuda() embed = AllReduceIdentity.apply(embed, self.tp_group) else: embed = torch.nn.functional.embedding(tokens, self.weight) + if self.embed_cpu: + embed = embed.cuda() return embed def encoder_preprocess(self): source_tokens = self._inputs[0] + seq_len = source_tokens.size(1) + assert seq_len == self.cfg.max_source_positions + source_embed = self.embed_lookup(source_tokens) embed = self.embed_scale_encoder * source_embed - x = embed + self.embed_positions_encoder.weight + x = embed + self.embed_positions_encoder(seq_len) x = self.layernorm_embedding_encoder(x) x = torch.nn.functional.dropout(x, p=0.0) enc = x.transpose(0, 1) @@ -380,9 +398,12 @@ def encoder_preprocess(self): def decoder_preprocess(self): prev_output_tokens = self._inputs[0] + seq_len = prev_output_tokens.size(1) + assert seq_len == self.cfg.max_source_positions + target_emb = self.embed_lookup(prev_output_tokens) embed = self.embed_scale_decoder * target_emb - embed = embed + self.embed_positions_decoder.weight + embed = embed + self.embed_positions_encoder(seq_len) embed = self.layernorm_embedding_decoder(embed) embed = torch.nn.functional.dropout(embed, p=0.0) dec = embed.transpose(0, 1) @@ -394,7 +415,8 @@ class mBARTFull(torch.nn.Module): def __init__(self, cfg: Config, encoder_preprocess=True, decoder_preprocess=True, - post_process=True): + post_process=True, + embed_cpu=False): super().__init__() self.cfg = cfg self.dummy_labels = torch.tensor([1]).cuda() @@ -409,8 +431,25 @@ def __init__(self, cfg: Config, self.pp_stage = torch.distributed.get_rank(_pp_group) self.num_stages = torch.distributed.get_world_size(_pp_group) - self.layer_start = self.total_layers // self.num_stages * self.pp_stage - self.layer_end = self.total_layers // self.num_stages * (self.pp_stage + 1) + encoder_stages = self.num_stages // 2 + decoder_stages = self.num_stages // 2 + if self.pp_stage < self.num_stages // 2: + encoder_stages = self.num_stages // 2 + chunk = cfg.encoder_layers // encoder_stages + remain = cfg.encoder_layers % encoder_stages + layers = [chunk] * encoder_stages + for idx in range(remain): + layers[-idx] += 1 + self.layer_start = sum(layers[0:self.pp_stage]) + self.layer_end = self.layer_start + layers[self.pp_stage] + if self.pp_stage >= self.num_stages // 2: + chunk = cfg.decoder_layers // decoder_stages + remain = cfg.decoder_layers % decoder_stages + layers = [chunk] * decoder_stages + for idx in range(remain): + layers[-idx] += 1 + self.layer_start = cfg.encoder_layers + sum(layers[0:self.pp_stage-encoder_stages]) + self.layer_end = self.layer_start + layers[self.pp_stage-encoder_stages] self.encoder_preprocess = encoder_preprocess self.encoder_forward = (self.layer_start < cfg.encoder_layers) @@ -428,8 +467,8 @@ def __init__(self, cfg: Config, self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) self.decoder_layer_end = self.layer_end - if encoder_preprocess or decoder_preprocess or post_process: - self.headtail = SharedEmbed(cfg) + if encoder_preprocess or decoder_preprocess: + self.headtail = SharedEmbed(cfg, embed_cpu) else: self.headtail = None @@ -560,8 +599,8 @@ def forward(self, enc=None, dec=None): # forward encoder if self.encoder_forward: for layer in self.encoders: - enc = checkpoint.checkpoint(layer, enc) - # enc = layer(enc) + # enc = checkpoint.checkpoint(layer, enc) + enc = layer(enc) if self.layer_norm_encoder is not None: enc = self.layer_norm_encoder(enc) output = (enc,) @@ -575,8 +614,8 @@ def forward(self, enc=None, dec=None): if self.decoder_forward: dec = pre_dec if dec is None else dec for layer in self.decoders: - dec, enc = checkpoint.checkpoint(layer, dec, enc) - # dec, enc = layer(dec, enc) + # dec, enc = checkpoint.checkpoint(layer, dec, enc) + dec, enc = layer(dec, enc) if self.layer_norm_decoder is not None: dec = self.layer_norm_decoder(dec) output = (dec,) @@ -613,6 +652,8 @@ def reduce_embed(model, pp_embed_group): help='use pipeline parallelism') parser.add_argument('--tp-size', type=int, default=1, help='use tensor parallelism') + parser.add_argument('--embed-cpu', action='store_true', + help='put embedding inside CPU') args = parser.parse_args() print(args) @@ -653,6 +694,7 @@ def reduce_embed(model, pp_embed_group): assert _pp_embed_group != -1 cfg = Config() + print_each_rank(cfg, rank_only=0) dataloader = SynTextDataLoader( shapes=( [1, cfg.max_source_positions], @@ -666,9 +708,13 @@ def reduce_embed(model, pp_embed_group): encoder_preprocess = is_first_stage decoder_preprocess = is_first_decoder_stage postprocess = is_last_stage - model = mBARTFull(cfg, encoder_preprocess, decoder_preprocess, postprocess).cuda() + model = mBARTFull(cfg, encoder_preprocess, decoder_preprocess, postprocess, args.embed_cpu).cuda() else: - model = mBARTFull(cfg, True, True, True).cuda() + model = mBARTFull(cfg, True, True, True, args.embd_cpu).cuda() + + if args.embed_cpu: + if model.headtail is not None: + model.headtail.weight = torch.nn.Parameter(model.headtail.weight.cpu()) print_each_rank('model weight consumpition:') memory_summary() @@ -681,8 +727,11 @@ def reduce_embed(model, pp_embed_group): if step >= 10: CudaTimer(enable=True).start('e2e') if args.pp_size > 1: - schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) - reduce_embed(model, _pp_embed_group) + schedule_1f1b(model, iter(dataloader), args.nmb, args.pp_size, (_pp_prev_rank, _pp_next_rank)) + # schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) + # TODO: support gradient allreduce in cpu + if not args.embed_cpu: + reduce_embed(model, _pp_embed_group) else: loader = iter(dataloader) for _ in range(args.nmb): From 00cd232e56c2c7d4d3167da8fa6053a5e19e6155 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 22 Mar 2022 15:23:22 +0800 Subject: [PATCH 0675/1892] mbart mode scale --- handcraft/mbart/mbart.py | 40 ++++++++++++++++++++-------------------- handcraft/mbart/tp.py | 20 +++++++++++++++++++- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index d044b067..f911f1f7 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -29,6 +29,20 @@ _pp_next_rank = None _pp_prev_rank = None +parser = argparse.ArgumentParser(description='swin') +parser.add_argument('--scale', type=int, default=0, + help='scale of model, 0 is original one.') +parser.add_argument('--nmb', type=int, default=4, + help='num of micro batch') +parser.add_argument('--use-naive', action='store_true', + help='use naive pipeline') +parser.add_argument('--use-1f1b', action='store_true', + help='use 1f1b scheduling') +parser.add_argument('--use-tp1f1b-pack', action='store_true', + help='use tensor parallel 1f1b') +args = parser.parse_args() +print(args) + # fairseq task # translation_from_pretrained_bart @@ -444,7 +458,7 @@ def __init__(self, cfg: Config, self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) self.decoder_layer_end = self.layer_end - if encoder_preprocess or decoder_preprocess or post_process or shard: + if encoder_preprocess or decoder_preprocess or shard: self.headtail = ShardEmbed(cfg, group = None if shard else -1) else: self.headtail = None @@ -622,22 +636,8 @@ def reduce_embed(model, pp_embed_group): if __name__ == '__main__': - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--nmb', type=int, default=4, - help='num of micro batch') - parser.add_argument('--use-naive', action='store_true', - help='use naive pipeline') - parser.add_argument('--use-1f1b', action='store_true', - help='use 1f1b scheduling') - parser.add_argument('--use-tp1f1b-pack', action='store_true', - help='use tensor parallel 1f1b') - args = parser.parse_args() - - print(args) - cube.init() pp_ranks = list(range(DeviceGroup().world_size)) - # pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) print_each_rank(f'my pp ranks: {pp_ranks}') if _pp_group == -1: @@ -679,9 +679,9 @@ def reduce_embed(model, pp_embed_group): optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) CudaTimer(enable=False).warmup() - iter_num = 32 + iter_num = 10 for step in range(iter_num): - if step >= 10: + if step >= 3: CudaTimer(enable=True).start('e2e') if args.use_naive: schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) @@ -703,11 +703,11 @@ def reduce_embed(model, pp_embed_group): if step == 0: print('memory after optimizer') memory_summary() - if step >= 10: + if step >= 3: CudaTimer().stop('e2e') - if (step + 1) % 10 == 0: + if (step + 1) % 3 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-10, field_name='e2e'))) + CudaTimer().duration(iter_num-3, field_name='e2e'))) memory_summary() diff --git a/handcraft/mbart/tp.py b/handcraft/mbart/tp.py index ff6e9302..01c80641 100644 --- a/handcraft/mbart/tp.py +++ b/handcraft/mbart/tp.py @@ -6,6 +6,9 @@ class AllReduceIdentity(torch.autograd.Function): @staticmethod def forward(ctx, input, group): + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input torch.distributed.all_reduce(input, group=group) return input @@ -23,6 +26,9 @@ def forward(ctx, input, group): @staticmethod def backward(ctx, grad_output): + world_size = torch.distributed.get_world_size(ctx._group) + if world_size == 1: + return grad_output, None torch.distributed.all_reduce(grad_output, group=ctx._group) return grad_output, None @@ -59,9 +65,12 @@ def backward(ctx, grad_output: torch.Tensor): class ReduceBroadcast(torch.autograd.Function): @staticmethod - def forward(ctx, input, dst: int, group=None): + def forward(ctx, input, dst: int, group): ctx._dst = dst ctx._group = group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input torch.distributed.reduce(input, dst, group=group) torch.cuda.synchronize() return input @@ -70,6 +79,9 @@ def forward(ctx, input, dst: int, group=None): def backward(ctx, grad_output): src = ctx._dst group = ctx._group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output, None, None torch.distributed.broadcast(grad_output, src, group=group) torch.cuda.synchronize() return grad_output, None, None @@ -81,6 +93,9 @@ class BroadcastReduce(torch.autograd.Function): def forward(ctx, input, src: int, group=None): ctx._src = src ctx._group = group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input torch.distributed.broadcast(input, src, group=group) torch.cuda.synchronize() return input @@ -89,6 +104,9 @@ def forward(ctx, input, src: int, group=None): def backward(ctx, grad_output): dst = ctx._src group = ctx._group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output, None, None if not grad_output.is_contiguous(): grad_output = grad_output.contiguous() torch.distributed.reduce(grad_output, dst, group=group) From 44f72cd4a13d0953ab48a087ef67dafe157eb229 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 22 Mar 2022 15:24:15 +0800 Subject: [PATCH 0676/1892] scale model --- handcraft/mbart/mbart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index f911f1f7..5fdbbe86 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -61,7 +61,7 @@ class Config: # attention_inner_dim = attention_heads * 64 # ffn_dim = 4 * embed_dim - scale = 2 + scale = args.scale scale_p = scale * 0.25 num_embeddings = 250027 + int(250027*scale_p) From 5aed4beb7cb958c49c4074ce35adfe12ed4d09d4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 22 Mar 2022 19:57:53 +0800 Subject: [PATCH 0677/1892] scale model --- handcraft/mbart/mbart.py | 58 ++++++-------- handcraft/mbart/mbart_hybrid.py | 130 ++++++++++++++++---------------- 2 files changed, 92 insertions(+), 96 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 5fdbbe86..1465abce 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -20,7 +20,7 @@ from cube.profiler.memory import memory_summary, model_summary from cube.profiler.timer import print_each_rank -from handcraft.mbart.schedule import schedule_naive, schedule_1f1b, schedule_tp_1f1b_pack +from handcraft.mbart.schedule import schedule_1f1b, schedule_tp_1f1b_pack from handcraft.mbart.tp import ReduceBroadcast @@ -32,10 +32,10 @@ parser = argparse.ArgumentParser(description='swin') parser.add_argument('--scale', type=int, default=0, help='scale of model, 0 is original one.') -parser.add_argument('--nmb', type=int, default=4, +parser.add_argument('--nmb', type=int, help='num of micro batch') -parser.add_argument('--use-naive', action='store_true', - help='use naive pipeline') +parser.add_argument('--iter-nmb', type=int, default=0, + help='num of micro batch per scheduling iteration (1f1b only)') parser.add_argument('--use-1f1b', action='store_true', help='use 1f1b scheduling') parser.add_argument('--use-tp1f1b-pack', action='store_true', @@ -43,11 +43,25 @@ args = parser.parse_args() print(args) -# fairseq task -# translation_from_pretrained_bart +cube.init() +pp_ranks = list(range(DeviceGroup().world_size)) +print_each_rank(f'my pp ranks: {pp_ranks}') + +if _pp_group == -1: + _pp_group = DeviceGroup().get_group(pp_ranks) + idx = pp_ranks.index(DeviceGroup().rank) + _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] + _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] + is_first_stage = idx == 0 + is_first_decoder_stage = idx == len(pp_ranks) // 2 + is_last_stage = idx == len(pp_ranks) - 1 + +# create embed group: first encoder, first decoder, last stage +if args.use_naive or args.use_1f1b: + embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2]] + embed_ranks = list(set(embed_ranks)) + _pp_embed_group = DeviceGroup().get_group(embed_ranks) -# fairseq criterion -# label_smoothed_cross_entropy, --label_smoothing = 0.2 class Config: @@ -636,26 +650,6 @@ def reduce_embed(model, pp_embed_group): if __name__ == '__main__': - cube.init() - pp_ranks = list(range(DeviceGroup().world_size)) - print_each_rank(f'my pp ranks: {pp_ranks}') - - if _pp_group == -1: - _pp_group = DeviceGroup().get_group(pp_ranks) - idx = pp_ranks.index(DeviceGroup().rank) - _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] - _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] - is_first_stage = idx == 0 - is_first_decoder_stage = idx == len(pp_ranks) // 2 - is_last_stage = idx == len(pp_ranks) - 1 - - # create embed group: first encoder, first decoder, last stage - # FIXME: only work for tp_size = 1 - if args.use_naive or args.use_1f1b: - embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2]] - embed_ranks = list(set(embed_ranks)) - _pp_embed_group = DeviceGroup().get_group(embed_ranks) - cfg = Config() dataloader = SynTextDataLoader( @@ -683,12 +677,10 @@ def reduce_embed(model, pp_embed_group): for step in range(iter_num): if step >= 3: CudaTimer(enable=True).start('e2e') - if args.use_naive: - schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) - reduce_embed(model, _pp_embed_group) if args.use_1f1b: - for _ in range(args.nmb // 2): - schedule_1f1b(model, iter(dataloader), 2, len(pp_ranks), (_pp_prev_rank, _pp_next_rank)) + iter_num = args.iter_nmb + for _ in range(args.nmb // args.iter_nmb): + schedule_1f1b(model, iter(dataloader), args.iter_nmb, len(pp_ranks), (_pp_prev_rank, _pp_next_rank)) reduce_embed(model, _pp_embed_group) if args.use_tp1f1b_pack: schedule_tp_1f1b_pack( diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 4499fd1e..58429b14 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -4,7 +4,8 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - handcraft/mbart/mbart.py --pp-size 4 --tp-size 1 --nmb 4 + handcraft/mbart/mbart_hybrid.py --pp-size 4 --tp-size 1\ + --nmb 4 --scale 0 --iter-nmb 4 """ from typing import Optional @@ -20,7 +21,7 @@ from cube.profiler.memory import memory_summary, model_summary from cube.profiler.timer import print_each_rank -from handcraft.mbart.schedule import schedule_naive, schedule_1f1b, schedule_tp_1f1b_pack +from handcraft.mbart.schedule import schedule_1f1b, schedule_tp_1f1b_pack from handcraft.mbart.tp import AllReduceIdentity, IdentityAllreduce, ReduceBroadcast _tp_group = -1 @@ -30,9 +31,62 @@ _pp_prev_rank = None +parser = argparse.ArgumentParser(description='swin') +parser.add_argument('--nmb', type=int, default=4, + help='num of micro batch') +parser.add_argument('--scale', type=int, default=0, + help='scale of model, 0 is original one.') +parser.add_argument('--pp-size', type=int, default=1, + help='use pipeline parallelism') +parser.add_argument('--tp-size', type=int, default=1, + help='use tensor parallelism') +parser.add_argument('--embed-cpu', action='store_true', + help='put embedding inside CPU') +parser.add_argument('--iter-nmb', type=int, default=0, + help='num of micro batch per scheduling iteration (1f1b only)') +args = parser.parse_args() +print(args) + +cube.init() +pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) +print_each_rank(f'my pp ranks: {pp_ranks}') +print_each_rank(f'my tp ranks: {tp_ranks}') + +if _tp_group == -1: + _tp_group = DeviceGroup().get_group(tp_ranks) + +if _pp_group == -1: + _pp_group = DeviceGroup().get_group(pp_ranks) + idx = pp_ranks.index(DeviceGroup().rank) + _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] + _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] + is_first_stage = idx == 0 + is_first_decoder_stage = idx == len(pp_ranks) // 2 + is_last_stage = idx == len(pp_ranks) - 1 + +if len(pp_ranks) > 1: + pranks = [torch.zeros((args.pp_size,), dtype=torch.int, device=torch.cuda.current_device()) for _ in range(args.tp_size)] + prank = torch.tensor(pp_ranks, dtype=torch.int).cuda() + pranks[torch.distributed.get_rank(_tp_group)] = prank + torch.distributed.all_gather(pranks, prank, group=_tp_group) + torch.cuda.synchronize() + print_each_rank(f'allgather-pp ranks: {pranks}') + + for prank in pranks: + prank = prank.tolist() + embed_ranks = [prank[0], prank[len(prank) // 2]] + embed_ranks = list(set(embed_ranks)) + group = DeviceGroup().get_group(embed_ranks) + if torch.distributed.get_rank(_tp_group) in prank: + print(f'embedding group: {embed_ranks}') + _pp_embed_group = group + assert _pp_embed_group != -1 + + + class Config: - scale = 2 + scale = args.scale scale_p = scale * 0.25 num_embeddings = 250027 + int(250027*scale_p) @@ -49,11 +103,6 @@ class Config: max_target_positions = 1024 max_source_positions = 1024 - adaptive_softmax_cutoff = None - adaptive_softmax_dropout = 0 - - share_decoder_input_output_embed = True - share_all_embeddings = True # classification task pooler_dropout = 0.0 @@ -645,54 +694,6 @@ def reduce_embed(model, pp_embed_group): if __name__ == '__main__': - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--nmb', type=int, default=4, - help='num of micro batch') - parser.add_argument('--pp-size', type=int, default=1, - help='use pipeline parallelism') - parser.add_argument('--tp-size', type=int, default=1, - help='use tensor parallelism') - parser.add_argument('--embed-cpu', action='store_true', - help='put embedding inside CPU') - args = parser.parse_args() - - print(args) - - cube.init() - pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) - print_each_rank(f'my pp ranks: {pp_ranks}') - print_each_rank(f'my tp ranks: {tp_ranks}') - - if _tp_group == -1: - _tp_group = DeviceGroup().get_group(tp_ranks) - - if _pp_group == -1: - _pp_group = DeviceGroup().get_group(pp_ranks) - idx = pp_ranks.index(DeviceGroup().rank) - _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] - _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] - is_first_stage = idx == 0 - is_first_decoder_stage = idx == len(pp_ranks) // 2 - is_last_stage = idx == len(pp_ranks) - 1 - - if len(pp_ranks) > 1: - pranks = [torch.zeros((args.pp_size,), dtype=torch.int, device=torch.cuda.current_device()) for _ in range(args.tp_size)] - prank = torch.tensor(pp_ranks, dtype=torch.int).cuda() - pranks[torch.distributed.get_rank(_tp_group)] = prank - torch.distributed.all_gather(pranks, prank, group=_tp_group) - torch.cuda.synchronize() - print_each_rank(f'allgather-pp ranks: {pranks}') - - for prank in pranks: - prank = prank.tolist() - embed_ranks = [prank[0], prank[len(prank) // 2]] - embed_ranks = list(set(embed_ranks)) - group = DeviceGroup().get_group(embed_ranks) - if torch.distributed.get_rank(_tp_group) in prank: - print(f'embedding group: {embed_ranks}') - _pp_embed_group = group - assert _pp_embed_group != -1 - cfg = Config() print_each_rank(cfg, rank_only=0) dataloader = SynTextDataLoader( @@ -722,13 +723,13 @@ def reduce_embed(model, pp_embed_group): optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) CudaTimer(enable=False).warmup() - iter_num = 32 + iter_num = 10 for step in range(iter_num): - if step >= 10: + if step >= 3: CudaTimer(enable=True).start('e2e') if args.pp_size > 1: - schedule_1f1b(model, iter(dataloader), args.nmb, args.pp_size, (_pp_prev_rank, _pp_next_rank)) - # schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) + for _ in range(args.nmb // args.iter_nmb): + schedule_1f1b(model, iter(dataloader), args.iter_nmb, args.pp_size, (_pp_prev_rank, _pp_next_rank)) # TODO: support gradient allreduce in cpu if not args.embed_cpu: reduce_embed(model, _pp_embed_group) @@ -743,11 +744,14 @@ def reduce_embed(model, pp_embed_group): memory_summary() optimizer.step() optimizer.zero_grad() - if step >= 10: + if step >= 3: CudaTimer().stop('e2e') - if (step + 1) % 10 == 0: + if step == 0: + print_each_rank('after optimizer') + memory_summary() + if (step + 1) % 3 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-10, field_name='e2e'))) + CudaTimer().duration(iter_num-3, field_name='e2e'))) memory_summary() From 259e0c8fc798d2acf101192d3481c86bf43bc467 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 22 Mar 2022 20:05:19 +0800 Subject: [PATCH 0678/1892] fix pure tp bug --- handcraft/mbart/mbart_hybrid.py | 42 ++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 58429b14..3692d2af 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -480,25 +480,29 @@ def __init__(self, cfg: Config, self.pp_stage = torch.distributed.get_rank(_pp_group) self.num_stages = torch.distributed.get_world_size(_pp_group) - encoder_stages = self.num_stages // 2 - decoder_stages = self.num_stages // 2 - if self.pp_stage < self.num_stages // 2: + if self.num_stages >= 2: encoder_stages = self.num_stages // 2 - chunk = cfg.encoder_layers // encoder_stages - remain = cfg.encoder_layers % encoder_stages - layers = [chunk] * encoder_stages - for idx in range(remain): - layers[-idx] += 1 - self.layer_start = sum(layers[0:self.pp_stage]) - self.layer_end = self.layer_start + layers[self.pp_stage] - if self.pp_stage >= self.num_stages // 2: - chunk = cfg.decoder_layers // decoder_stages - remain = cfg.decoder_layers % decoder_stages - layers = [chunk] * decoder_stages - for idx in range(remain): - layers[-idx] += 1 - self.layer_start = cfg.encoder_layers + sum(layers[0:self.pp_stage-encoder_stages]) - self.layer_end = self.layer_start + layers[self.pp_stage-encoder_stages] + decoder_stages = self.num_stages // 2 + if self.pp_stage < self.num_stages // 2: + encoder_stages = self.num_stages // 2 + chunk = cfg.encoder_layers // encoder_stages + remain = cfg.encoder_layers % encoder_stages + layers = [chunk] * encoder_stages + for idx in range(remain): + layers[-idx] += 1 + self.layer_start = sum(layers[0:self.pp_stage]) + self.layer_end = self.layer_start + layers[self.pp_stage] + if self.pp_stage >= self.num_stages // 2: + chunk = cfg.decoder_layers // decoder_stages + remain = cfg.decoder_layers % decoder_stages + layers = [chunk] * decoder_stages + for idx in range(remain): + layers[-idx] += 1 + self.layer_start = cfg.encoder_layers + sum(layers[0:self.pp_stage-encoder_stages]) + self.layer_end = self.layer_start + layers[self.pp_stage-encoder_stages] + else: + self.layer_start = 0 + self.layer_end = cfg.encoder_layers + cfg.decoder_layers self.encoder_preprocess = encoder_preprocess self.encoder_forward = (self.layer_start < cfg.encoder_layers) @@ -711,7 +715,7 @@ def reduce_embed(model, pp_embed_group): postprocess = is_last_stage model = mBARTFull(cfg, encoder_preprocess, decoder_preprocess, postprocess, args.embed_cpu).cuda() else: - model = mBARTFull(cfg, True, True, True, args.embd_cpu).cuda() + model = mBARTFull(cfg, True, True, True, args.embed_cpu).cuda() if args.embed_cpu: if model.headtail is not None: From 6c50dd07670be25b9de150d3d23cdca557427027 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 22 Mar 2022 23:53:53 +0800 Subject: [PATCH 0679/1892] 1f1b supports hybrid --- handcraft/mbart/mbart_hybrid.py | 7 ++++--- handcraft/mbart/schedule.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 3692d2af..7a1fcdfe 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -21,7 +21,7 @@ from cube.profiler.memory import memory_summary, model_summary from cube.profiler.timer import print_each_rank -from handcraft.mbart.schedule import schedule_1f1b, schedule_tp_1f1b_pack +from handcraft.mbart.schedule import schedule_naive, schedule_1f1b from handcraft.mbart.tp import AllReduceIdentity, IdentityAllreduce, ReduceBroadcast _tp_group = -1 @@ -165,8 +165,8 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) self.inner_dim = inner_dim - self.num_heads = num_heads self.head_dim = inner_dim // num_heads + self.num_heads = num_heads // self.tp_size self.scaling = self.head_dim ** -0.5 self.dropout_p = dropout # K @@ -732,8 +732,9 @@ def reduce_embed(model, pp_embed_group): if step >= 3: CudaTimer(enable=True).start('e2e') if args.pp_size > 1: + # schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) for _ in range(args.nmb // args.iter_nmb): - schedule_1f1b(model, iter(dataloader), args.iter_nmb, args.pp_size, (_pp_prev_rank, _pp_next_rank)) + schedule_1f1b(model, iter(dataloader), args.iter_nmb, args.pp_size, (_pp_prev_rank, _pp_next_rank), group=_pp_group) # TODO: support gradient allreduce in cpu if not args.embed_cpu: reduce_embed(model, _pp_embed_group) diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py index c51e5dfd..ee96cdf1 100644 --- a/handcraft/mbart/schedule.py +++ b/handcraft/mbart/schedule.py @@ -434,14 +434,14 @@ def tp_decoder_backward(grads: Tuple[torch.Tensor]): # print_each_rank(f'=========end rank {rank}=========') -def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors): +def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors, group=None): rank = torch.distributed.get_rank() prev_rank, next_rank = neighbors is_first_stage = rank < prev_rank is_last_stage = rank > next_rank - num_warmup_microbatches = num_stage - 1 - rank + num_warmup_microbatches = num_stage - 1 - torch.distributed.get_rank(group=group) num_warmup_microbatches = min(num_warmup_microbatches, num_microbatch) num_warmup_remaining = num_microbatch - num_warmup_microbatches From 04ddf2db3e9119a7d8d0907c9ff41909b21c57fb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 23 Mar 2022 11:31:13 +0800 Subject: [PATCH 0680/1892] running config --- handcraft/mbart/run.sh | 50 ++++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh index 4d524a6a..bff16bb2 100755 --- a/handcraft/mbart/run.sh +++ b/handcraft/mbart/run.sh @@ -1,37 +1,55 @@ -evaldir=eval/3b-checkpoint +evaldir=eval/mbart-scale mkdir -p ${evaldir} # 4 gpus OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 64 > ${evaldir}/4dev64nmb-tp1f1b-pack.txt + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ + --scale 2 > ${evaldir}/4dev256nmb-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-naive --nmb 64 > ${evaldir}/4dev64nmb-naive.txt + handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ + --scale 2 > ${evaldir}/4dev256nmb-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 64 > ${evaldir}/4dev64nmb-tp.txt + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ + --scale 2 --iter-nmb 256 > ${evaldir}/4dev256nmb-tp2pp2.txt -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > ${evaldir}/4dev64nmb-tp2pp2.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy_hybrid.py --tp-size 2 --pp-size 2 --nmb 64 > ${evaldir}/4dev64nmb-2tp2pp.txt # 8 gpus OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 128 > ${evaldir}/8dev128nmb-tp1f1b-pack.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-naive --nmb 128 > ${evaldir}/8dev128nmb-naive.txt + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ + --scale 4 > ${evaldir}/8dev256nmb-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 128 > ${evaldir}/8dev128nmb-tp.txt + handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ + --scale 4 > ${evaldir}/8dev256nmb-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 128 > ${evaldir}/8dev128nmb-tp4pp2.txt + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ + --scale 4 --iter-nmb 256 > ${evaldir}/8dev128nmb-tp4pp2.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 128 > ${evaldir}/8dev128nmb-tp2pp4.txt + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ + --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp2pp4.txt + + +# 16 gpus + +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ +# --scale 6 > ${evaldir}/16dev256nmb-tp1f1b.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 4 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp4pp4.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 2 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp8pp2.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt From 542a8bff25a5cc872a84be777883502e14f11357 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 23 Mar 2022 13:18:06 +0800 Subject: [PATCH 0681/1892] fix error --- handcraft/mbart/mbart.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 1465abce..733a846a 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -57,7 +57,7 @@ is_last_stage = idx == len(pp_ranks) - 1 # create embed group: first encoder, first decoder, last stage -if args.use_naive or args.use_1f1b: +if args.use_1f1b: embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2]] embed_ranks = list(set(embed_ranks)) _pp_embed_group = DeviceGroup().get_group(embed_ranks) @@ -659,7 +659,7 @@ def reduce_embed(model, pp_embed_group): dtypes=(torch.int64,), batch_dims=(0,) ) - if args.use_naive or args.use_1f1b: + if args.use_1f1b: encoder_preprocess = is_first_stage decoder_preprocess = is_first_decoder_stage postprocess = is_last_stage @@ -678,7 +678,6 @@ def reduce_embed(model, pp_embed_group): if step >= 3: CudaTimer(enable=True).start('e2e') if args.use_1f1b: - iter_num = args.iter_nmb for _ in range(args.nmb // args.iter_nmb): schedule_1f1b(model, iter(dataloader), args.iter_nmb, len(pp_ranks), (_pp_prev_rank, _pp_next_rank)) reduce_embed(model, _pp_embed_group) From c86c6a0b8425e3a41afaab8b4c0169e3359d9145 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 23 Mar 2022 13:52:25 +0800 Subject: [PATCH 0682/1892] fix script --- handcraft/mbart/run.sh | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh index bff16bb2..b224eea6 100755 --- a/handcraft/mbart/run.sh +++ b/handcraft/mbart/run.sh @@ -10,6 +10,10 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ + --scale 2 > ${evaldir}/4dev256nmb-1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ --scale 2 > ${evaldir}/4dev256nmb-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ @@ -24,7 +28,11 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ --scale 4 > ${evaldir}/8dev256nmb-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ + handcraft/mbart/mbart.py --use-1f1b --nmb 1 \ + --scale 4 > ${evaldir}/8dev256nmb-1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ --scale 4 > ${evaldir}/8dev256nmb-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ From bcaf650a7998ea63876d644fee6930285096b8fd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 23 Mar 2022 13:58:34 +0800 Subject: [PATCH 0683/1892] fix script --- handcraft/mbart/run.sh | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh index b224eea6..c1f1997e 100755 --- a/handcraft/mbart/run.sh +++ b/handcraft/mbart/run.sh @@ -10,7 +10,7 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ - --scale 2 > ${evaldir}/4dev256nmb-1f1b.txt + --scale 2 --iter-nmb 256 > ${evaldir}/4dev256nmb-1f1b.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ @@ -28,8 +28,8 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ --scale 4 > ${evaldir}/8dev256nmb-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-1f1b --nmb 1 \ - --scale 4 > ${evaldir}/8dev256nmb-1f1b.txt + handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ + --scale 4 --iter-nmb 1 > ${evaldir}/8dev256nmb-1f1b.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ @@ -46,18 +46,22 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ # 16 gpus -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ -# --scale 6 > ${evaldir}/16dev256nmb-tp1f1b.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 4 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp4pp4.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 2 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp8pp2.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt +OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ + --scale 6 > ${evaldir}/16dev256nmb-tp1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 16 --pp-size 1 --nmb 256 \ + --scale 6 > ${evaldir}/16dev256nmb-tp.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 4 --nmb 256 \ + --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp4pp4.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 2 --nmb 256 \ + --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp8pp2.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ + --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt From aec9261c06b6167b95f02d4fcb2240d57c4928fd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 24 Mar 2022 12:25:26 +0000 Subject: [PATCH 0684/1892] enable recompute --- handcraft/mbart/mbart.py | 25 ++++++---- handcraft/mbart/mbart_hybrid.py | 25 +++++++--- handcraft/mbart/run-recompute.sh | 67 ++++++++++++++++++++++++++ handcraft/mbart/run.sh | 82 ++++++++++++++++---------------- handcraft/mbart/schedule.py | 75 +++++++++++++++++++++++------ 5 files changed, 204 insertions(+), 70 deletions(-) create mode 100755 handcraft/mbart/run-recompute.sh diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 733a846a..e02f4781 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -40,6 +40,8 @@ help='use 1f1b scheduling') parser.add_argument('--use-tp1f1b-pack', action='store_true', help='use tensor parallel 1f1b') +parser.add_argument('--use-recompute', action='store_true', + help='use recompute for a stage') args = parser.parse_args() print(args) @@ -78,7 +80,7 @@ class Config: scale = args.scale scale_p = scale * 0.25 - num_embeddings = 250027 + int(250027*scale_p) + num_embeddings = 500000 # 250027 + int(250027*scale_p) decoder_layers = 12 + int(12*scale_p) encoder_layers = 12 + int(12*scale_p) embed_dim = 1024 + int(1024*scale_p) @@ -673,18 +675,25 @@ def reduce_embed(model, pp_embed_group): optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) CudaTimer(enable=False).warmup() - iter_num = 10 + iter_num = 6 for step in range(iter_num): - if step >= 3: + if step >= 2: CudaTimer(enable=True).start('e2e') if args.use_1f1b: for _ in range(args.nmb // args.iter_nmb): - schedule_1f1b(model, iter(dataloader), args.iter_nmb, len(pp_ranks), (_pp_prev_rank, _pp_next_rank)) + schedule_1f1b( + model, iter(dataloader), + args.iter_nmb, len(pp_ranks), + (_pp_prev_rank, _pp_next_rank), + recompute=args.use_recompute, + ) reduce_embed(model, _pp_embed_group) if args.use_tp1f1b_pack: schedule_tp_1f1b_pack( model, iter(dataloader), - args.nmb, len(pp_ranks), (_pp_prev_rank, _pp_next_rank) + args.nmb, len(pp_ranks), + (_pp_prev_rank, _pp_next_rank), + recompute=args.use_recompute, ) if step == 0: print('passed 1st iteration') @@ -694,11 +703,11 @@ def reduce_embed(model, pp_embed_group): if step == 0: print('memory after optimizer') memory_summary() - if step >= 3: + if step >= 2: CudaTimer().stop('e2e') - if (step + 1) % 3 == 0: + if (step + 1) % 2 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-3, field_name='e2e'))) + CudaTimer().duration(iter_num-2, field_name='e2e'))) memory_summary() diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 7a1fcdfe..9e8ec320 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -8,6 +8,7 @@ --nmb 4 --scale 0 --iter-nmb 4 """ +import re from typing import Optional import argparse import math @@ -44,6 +45,8 @@ help='put embedding inside CPU') parser.add_argument('--iter-nmb', type=int, default=0, help='num of micro batch per scheduling iteration (1f1b only)') +parser.add_argument('--use-recompute', action='store_true', + help='use recompute for a stage') args = parser.parse_args() print(args) @@ -89,7 +92,7 @@ class Config: scale = args.scale scale_p = scale * 0.25 - num_embeddings = 250027 + int(250027*scale_p) + num_embeddings = 500000 # 250027 + int(250027*scale_p) decoder_layers = 12 + int(12*scale_p) encoder_layers = 12 + int(12*scale_p) embed_dim = 1024 + int(1024*scale_p) @@ -727,14 +730,22 @@ def reduce_embed(model, pp_embed_group): optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) CudaTimer(enable=False).warmup() - iter_num = 10 + iter_num = 6 for step in range(iter_num): - if step >= 3: + if step >= 2: CudaTimer(enable=True).start('e2e') if args.pp_size > 1: # schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) for _ in range(args.nmb // args.iter_nmb): - schedule_1f1b(model, iter(dataloader), args.iter_nmb, args.pp_size, (_pp_prev_rank, _pp_next_rank), group=_pp_group) + schedule_1f1b( + model, + iter(dataloader), + args.iter_nmb, + args.pp_size, + (_pp_prev_rank, _pp_next_rank), + group=_pp_group, + recompute=args.recompute + ) # TODO: support gradient allreduce in cpu if not args.embed_cpu: reduce_embed(model, _pp_embed_group) @@ -749,14 +760,14 @@ def reduce_embed(model, pp_embed_group): memory_summary() optimizer.step() optimizer.zero_grad() - if step >= 3: + if step >= 2: CudaTimer().stop('e2e') if step == 0: print_each_rank('after optimizer') memory_summary() - if (step + 1) % 3 == 0: + if (step + 1) % 2 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-3, field_name='e2e'))) + CudaTimer().duration(iter_num-2, field_name='e2e'))) memory_summary() diff --git a/handcraft/mbart/run-recompute.sh b/handcraft/mbart/run-recompute.sh new file mode 100755 index 00000000..f7683c6d --- /dev/null +++ b/handcraft/mbart/run-recompute.sh @@ -0,0 +1,67 @@ +evaldir=/data/MagicCube/scale-mbart-recompute + +mkdir -p ${evaldir} + +# 4 gpus + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ + --scale 3 --use-recompute > ${evaldir}/4dev256nmb-tp1f1b-pack.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ + --scale 3 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ + --scale 3 --use-recompute > ${evaldir}/4dev256nmb-tp.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ + --scale 3 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-tp2pp2.txt + + +# 8 gpus + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ + --scale 4 --use-recompute > ${evaldir}/8dev256nmb-tp1f1b-pack.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ + --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-1f1b.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ + --scale 4 --use-recompute > ${evaldir}/8dev256nmb-tp.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ + --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-tp4pp2.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ + --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-tp2pp4.txt + + +# 16 gpus + +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ +# --scale 6 > ${evaldir}/16dev256nmb-tp1f1b.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 16 --pp-size 1 --nmb 256 \ +# --scale 6 > ${evaldir}/16dev256nmb-tp.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 4 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp4pp4.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 2 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp8pp2.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh index c1f1997e..59665fd0 100755 --- a/handcraft/mbart/run.sh +++ b/handcraft/mbart/run.sh @@ -1,4 +1,4 @@ -evaldir=eval/mbart-scale +evaldir=/data/MagicCube/scale-mbart mkdir -p ${evaldir} @@ -6,19 +6,19 @@ mkdir -p ${evaldir} OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 2 > ${evaldir}/4dev256nmb-tp1f1b-pack.txt - + --scale 3 > ${evaldir}/4dev256nmb-tp1f1b-pack.txt +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ - --scale 2 --iter-nmb 256 > ${evaldir}/4dev256nmb-1f1b.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ - --scale 2 > ${evaldir}/4dev256nmb-tp.txt + --scale 3 --iter-nmb 1 > ${evaldir}/4dev256nmb-1f1b.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ +# --scale 3 > ${evaldir}/4dev256nmb-tp.txt -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ - --scale 2 --iter-nmb 256 > ${evaldir}/4dev256nmb-tp2pp2.txt +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ +# --scale 3 --iter-nmb 256 > ${evaldir}/4dev256nmb-tp2pp2.txt # 8 gpus @@ -31,37 +31,37 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ --scale 4 --iter-nmb 1 > ${evaldir}/8dev256nmb-1f1b.txt -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ - --scale 4 > ${evaldir}/8dev256nmb-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ - --scale 4 --iter-nmb 256 > ${evaldir}/8dev128nmb-tp4pp2.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ - --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp2pp4.txt +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ +# --scale 4 > ${evaldir}/8dev256nmb-tp.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ +# --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp4pp2.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ +# --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp2pp4.txt # 16 gpus -OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 6 > ${evaldir}/16dev256nmb-tp1f1b.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 16 --pp-size 1 --nmb 256 \ - --scale 6 > ${evaldir}/16dev256nmb-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 4 --nmb 256 \ - --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp4pp4.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 2 --nmb 256 \ - --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp8pp2.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ - --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ +# --scale 6 > ${evaldir}/16dev256nmb-tp1f1b.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 16 --pp-size 1 --nmb 256 \ +# --scale 6 > ${evaldir}/16dev256nmb-tp.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 4 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp4pp4.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 2 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp8pp2.txt +# +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ +# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py index ee96cdf1..70732c3b 100644 --- a/handcraft/mbart/schedule.py +++ b/handcraft/mbart/schedule.py @@ -233,7 +233,8 @@ def schedule_tp_1f1b_pack(model: torch.nn.Module, dataloader, num_microbatch: int, num_stage: int, - neighbors: Tuple[int, int]): + neighbors: Tuple[int, int], + recompute=False): rank = DeviceGroup().rank prev_rank, next_rank = neighbors @@ -339,11 +340,20 @@ def tp_decoder_backward(grads: Tuple[torch.Tensor]): input_tensors.append(inputs) if is_first_stage: inputs = () - outputs = forward_step(model, *inputs) - output_tensors.append(outputs) - - # mem = torch.cuda.max_memory_allocated() - # print(f'rank {rank}: {mem / 1024 / 1024 / 1024} GB forward') + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs) + output_tensors.append(None) + else: + outputs = forward_step(model, *inputs) + output_tensors.append(outputs) + + # recompute if backward is needed + if do_backward: + inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) + if recompute: + assert outputs is None + outputs = forward_step(model, *inputs) # intra-barrier send recv output_grads = (None,) @@ -361,7 +371,7 @@ def tp_decoder_backward(grads: Tuple[torch.Tensor]): # backward last_backward = (None,) if do_backward: - inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) + # inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) input_grads = backward_step(inputs, outputs, output_grads) last_backward = input_grads @@ -404,11 +414,26 @@ def tp_decoder_backward(grads: Tuple[torch.Tensor]): last_forward = (None,) if do_forward: # forward step - outputs = forward_step(model, *inputs) input_tensors.append(inputs) - output_tensors.append(outputs) + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs) + output_tensors.append(None) + else: + outputs = forward_step(model, *inputs) + output_tensors.append(outputs) last_forward = outputs + next_backward = 0 <= (bmid+1) and (bmid+1) <= num_microbatch - 1 + if next_backward: + if recompute: + inputs, outputs = input_tensors[0], output_tensors[0] + assert outputs is None + outputs = forward_step(model, *inputs) + input_tensors[0] = inputs + output_tensors[0] = outputs + + # tp tail forward-backward if 0 <= decoder_bmid and decoder_bmid <= num_microbatch - 1: # FIXME: currently use encoder grad @@ -434,7 +459,13 @@ def tp_decoder_backward(grads: Tuple[torch.Tensor]): # print_each_rank(f'=========end rank {rank}=========') -def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors, group=None): +def schedule_1f1b(model: torch.nn.Module, + dataloader, + num_microbatch: int, + num_stage: int, + neighbors: Tuple[int, int], + group=None, + recompute=False): rank = torch.distributed.get_rank() prev_rank, next_rank = neighbors @@ -454,11 +485,16 @@ def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors, group # recv forward inputs = () if is_first_stage else recv_forward(model, prev_rank) # forward - outputs = forward_step(model, *inputs) + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs) + output_tensors.append(None) + else: + outputs = forward_step(model, *inputs) + output_tensors.append(outputs) # send forward send_forward(outputs, next_rank) input_tensors.append(inputs) - output_tensors.append(outputs) # before running 1f1b: need to recv first forward tensor if num_warmup_remaining > 0: @@ -469,9 +505,14 @@ def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors, group for i in range(num_warmup_remaining): model.set_inputs(next(dataloader)) # forward - outputs = forward_step(model, *inputs) + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs) + output_tensors.append(None) + else: + outputs = forward_step(model, *inputs) + output_tensors.append(outputs) input_tensors.append(inputs) - output_tensors.append(outputs) # send forward recv backward grads = (None,) @@ -480,6 +521,9 @@ def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors, group # backward inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) + if recompute: + assert outputs is None + outputs = forward_step(model, *inputs) input_grads = backward_step(inputs, outputs, grads) # send backward @@ -498,6 +542,9 @@ def schedule_1f1b(model, dataloader, num_microbatch, num_stage, neighbors, group # recv backward grads = (None,) if is_last_stage else recv_backward(model, next_rank) # backward + if recompute: + assert outputs is None + outputs = forward_step(model, *inputs) input_grads = backward_step(inputs, outputs, grads) # send backward if not is_first_stage: From 3c1c2f4b0d39a4db6e31c90a9ce3d5451a462fa6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 24 Mar 2022 14:09:33 +0000 Subject: [PATCH 0685/1892] fix recompute issue --- handcraft/mbart/mbart_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 9e8ec320..795edad8 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -744,7 +744,7 @@ def reduce_embed(model, pp_embed_group): args.pp_size, (_pp_prev_rank, _pp_next_rank), group=_pp_group, - recompute=args.recompute + recompute=args.use_recompute ) # TODO: support gradient allreduce in cpu if not args.embed_cpu: From 9cf5e6a535a6065df29bc3fc85ba426dd55a8445 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 25 Mar 2022 15:00:57 +0000 Subject: [PATCH 0686/1892] model scale --- handcraft/mbart/mbart.py | 34 ++++++---- handcraft/mbart/mbart_hybrid.py | 50 ++++++++------ handcraft/mbart/run-recompute-arch.sh | 76 ++++++++++++++++++++++ handcraft/mbart/run-recompute.sh | 93 +++++++++++++++++++++++---- handcraft/mbart/run.sh | 41 ++++++------ handcraft/mbart/test.py | 45 +++++++++++++ 6 files changed, 276 insertions(+), 63 deletions(-) create mode 100755 handcraft/mbart/run-recompute-arch.sh create mode 100644 handcraft/mbart/test.py diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index e02f4781..b488871b 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -4,20 +4,19 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --pp-size 4 --tp-size 1 --nmb 4 + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 4 --use-recompute """ from typing import Optional import argparse import math import torch -from torch.utils import checkpoint import cube from cube.runtime.device import DeviceGroup from cube.runtime.syndata import SynTextDataLoader from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary, model_summary +from cube.profiler.memory import memory_summary from cube.profiler.timer import print_each_rank from handcraft.mbart.schedule import schedule_1f1b, schedule_tp_1f1b_pack @@ -30,12 +29,19 @@ _pp_prev_rank = None parser = argparse.ArgumentParser(description='swin') -parser.add_argument('--scale', type=int, default=0, - help='scale of model, 0 is original one.') +# model arch +parser.add_argument('--layers', type=int, default=12, + help='number encoder/decoder of layers') +parser.add_argument('--hidden-size', type=int, default=1024, + help='hidden size') +parser.add_argument('--heads', type=int, default=16, + help='number of heads') +# training config parser.add_argument('--nmb', type=int, help='num of micro batch') parser.add_argument('--iter-nmb', type=int, default=0, help='num of micro batch per scheduling iteration (1f1b only)') +# parallelism parser.add_argument('--use-1f1b', action='store_true', help='use 1f1b scheduling') parser.add_argument('--use-tp1f1b-pack', action='store_true', @@ -77,14 +83,19 @@ class Config: # attention_inner_dim = attention_heads * 64 # ffn_dim = 4 * embed_dim - scale = args.scale - scale_p = scale * 0.25 + # scale = args.scale + # scale_p = scale * 0.25 num_embeddings = 500000 # 250027 + int(250027*scale_p) - decoder_layers = 12 + int(12*scale_p) - encoder_layers = 12 + int(12*scale_p) - embed_dim = 1024 + int(1024*scale_p) - attention_heads = 16 + int(16*scale_p) + # decoder_layers = 12 + int(12*scale_p) + # encoder_layers = 12 + int(12*scale_p) + # embed_dim = 1024 + int(1024*scale_p) + # attention_heads = 16 + int(16*scale_p) if scale < 6 else 40 + 8*(scale-6) + decoder_layers = args.layers + encoder_layers = args.layers + embed_dim = args.hidden_size + attention_heads = args.heads + attention_inner_dim = attention_heads * 64 ffn_dim = 4 * embed_dim @@ -654,6 +665,7 @@ def reduce_embed(model, pp_embed_group): cfg = Config() + print_each_rank(f'enc/dec layer#: {cfg.encoder_layers}, embed_dim#: {cfg.embed_dim}, heads#: {cfg.attention_heads}, ffn_dim#: {cfg.ffn_dim}', rank_only=0) dataloader = SynTextDataLoader( shapes=( [1, cfg.max_source_positions], diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 795edad8..69e46fb3 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -4,26 +4,24 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --pp-size 4 --tp-size 1\ + handcraft/mbart/mbart_hybrid.py --pp-size 2 --tp-size 2\ --nmb 4 --scale 0 --iter-nmb 4 """ -import re from typing import Optional import argparse import math import torch -from torch.utils import checkpoint import cube from cube.runtime.device import DeviceGroup from cube.runtime.syndata import SynTextDataLoader from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary, model_summary +from cube.profiler.memory import memory_summary from cube.profiler.timer import print_each_rank -from handcraft.mbart.schedule import schedule_naive, schedule_1f1b -from handcraft.mbart.tp import AllReduceIdentity, IdentityAllreduce, ReduceBroadcast +from handcraft.mbart.schedule import schedule_1f1b +from handcraft.mbart.tp import AllReduceIdentity, IdentityAllreduce _tp_group = -1 _pp_group = -1 @@ -32,19 +30,28 @@ _pp_prev_rank = None -parser = argparse.ArgumentParser(description='swin') +parser = argparse.ArgumentParser(description='mbart hybrid') + +# model arch +parser.add_argument('--layers', type=int, default=12, + help='number encoder/decoder of layers') +parser.add_argument('--hidden-size', type=int, default=1024, + help='hidden size') +parser.add_argument('--heads', type=int, default=16, + help='number of heads') +# training config parser.add_argument('--nmb', type=int, default=4, help='num of micro batch') -parser.add_argument('--scale', type=int, default=0, - help='scale of model, 0 is original one.') +parser.add_argument('--iter-nmb', type=int, default=0, + help='num of micro batch per scheduling iteration') + +# parallelism parser.add_argument('--pp-size', type=int, default=1, help='use pipeline parallelism') parser.add_argument('--tp-size', type=int, default=1, help='use tensor parallelism') parser.add_argument('--embed-cpu', action='store_true', help='put embedding inside CPU') -parser.add_argument('--iter-nmb', type=int, default=0, - help='num of micro batch per scheduling iteration (1f1b only)') parser.add_argument('--use-recompute', action='store_true', help='use recompute for a stage') args = parser.parse_args() @@ -73,7 +80,7 @@ pranks[torch.distributed.get_rank(_tp_group)] = prank torch.distributed.all_gather(pranks, prank, group=_tp_group) torch.cuda.synchronize() - print_each_rank(f'allgather-pp ranks: {pranks}') + # print_each_rank(f'allgather-pp ranks: {pranks}') for prank in pranks: prank = prank.tolist() @@ -89,14 +96,19 @@ class Config: - scale = args.scale - scale_p = scale * 0.25 + # scale = args.scale + # scale_p = scale * 0.25 num_embeddings = 500000 # 250027 + int(250027*scale_p) - decoder_layers = 12 + int(12*scale_p) - encoder_layers = 12 + int(12*scale_p) - embed_dim = 1024 + int(1024*scale_p) - attention_heads = 16 + int(16*scale_p) + # decoder_layers = 12 + int(12*scale_p) + # encoder_layers = 12 + int(12*scale_p) + # embed_dim = 1024 + int(1024*scale_p) + # attention_heads = 16 + int(16*scale_p) if scale < 6 else 40 + 8*(scale-6) + decoder_layers = args.layers + encoder_layers = args.layers + embed_dim = args.hidden_size + attention_heads = args.heads + attention_inner_dim = attention_heads * 64 ffn_dim = 4 * embed_dim @@ -702,7 +714,7 @@ def reduce_embed(model, pp_embed_group): if __name__ == '__main__': cfg = Config() - print_each_rank(cfg, rank_only=0) + print_each_rank(f'enc/dec layer#: {cfg.encoder_layers}, embed_dim#: {cfg.embed_dim}, heads#: {cfg.attention_heads}, ffn_dim#: {cfg.ffn_dim}', rank_only=0) dataloader = SynTextDataLoader( shapes=( [1, cfg.max_source_positions], diff --git a/handcraft/mbart/run-recompute-arch.sh b/handcraft/mbart/run-recompute-arch.sh new file mode 100755 index 00000000..8682e54e --- /dev/null +++ b/handcraft/mbart/run-recompute-arch.sh @@ -0,0 +1,76 @@ +layers=24 +hidden=4096 +heads=32 +gpus=8 + +evaldir=eval/mbart-v100-32gb-pcie-recompute +mkdir -p ${evaldir} + + +# TP-1F1B +echo 'testing mixture-1f1b' +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + + +# # Pure 1F1B +# echo 'testing pure 1f1b' +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ +# handcraft/mbart/mbart.py \ +# --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ +# --use-1f1b --nmb 256 --iter-nmb 256\ +# --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b.txt + +# Pure TP +echo 'testing pure tensor parallelism' +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + +# # Hybrid TP-1F1B -- 4 GPU +# if [ ${gpus} == 4 ] +# then +# echo 'testing hybrid tp:pp=2:2' +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py \ +# --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ +# --tp-size 2 --pp-size 2 --nmb 256 --iter-nmb 256 \ +# --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt +# sleep 5 +# killall python +# sleep 5 +# killall python +# fi +# +# # Hybrid TP-1F1B -- 8 GPU +# if [ ${gpus} == 8 ] +# then +# echo 'testing hybrid tp:pp=4:2' +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py \ +# --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ +# --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ +# --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt +# sleep 5 +# killall python +# sleep 5 +# killall python +# +# echo 'testing hybrid tp:pp=2:4' +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py \ +# --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ +# --tp-size 2 --pp-size 4 --nmb 256 --iter-nmb 256 \ +# --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt +# sleep 5 +# killall python +# sleep 5 +# killall python +# fi + +python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/run-recompute.sh b/handcraft/mbart/run-recompute.sh index f7683c6d..0631d3f1 100755 --- a/handcraft/mbart/run-recompute.sh +++ b/handcraft/mbart/run-recompute.sh @@ -2,47 +2,109 @@ evaldir=/data/MagicCube/scale-mbart-recompute mkdir -p ${evaldir} -# 4 gpus +# 4 gpus recompute scale=3,4,5 + +# TP-1F1B +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ +# --scale 3 --use-recompute > ${evaldir}/4dev256nmb-scale3-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 3 --use-recompute > ${evaldir}/4dev256nmb-tp1f1b-pack.txt + --scale 4 --use-recompute > ${evaldir}/4dev256nmb-scale4-tp1f1b-pack.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ + --scale 5 --use-recompute > ${evaldir}/4dev256nmb-scale5-tp1f1b-pack.txt + +# Pure 1F1B +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ +# --scale 3 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale3-1f1b.txt + +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ +# --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale4-1f1b.txt + +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ +# --scale 5 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale5-1f1b.txt + +# Pure TP +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ +# --scale 3 --use-recompute > ${evaldir}/4dev256nmb-scale3-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ - --scale 3 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-1f1b.txt + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ + --scale 4 --use-recompute > ${evaldir}/4dev256nmb-scale4-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ - --scale 3 --use-recompute > ${evaldir}/4dev256nmb-tp.txt + --scale 5 --use-recompute > ${evaldir}/4dev256nmb-scale5-tp.txt + +# Hybrid TP-PP: TP=2, PP=2 + +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ +# --scale 3 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale3-tp2pp2.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ - --scale 3 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-tp2pp2.txt + --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale4-tp2pp2.txt + +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ +# --scale 5 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale5-tp2pp2.txt -# 8 gpus +# 8 gpus recompute scale=6,7 +# TP-1F1B OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 4 --use-recompute > ${evaldir}/8dev256nmb-tp1f1b-pack.txt + --scale 6 --use-recompute > ${evaldir}/8dev256nmb-scale6-tp1f1b-pack.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ - --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-1f1b.txt + handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ + --scale 7 --use-recompute > ${evaldir}/8dev256nmb-scale7-tp1f1b-pack.txt +# Pure 1F1B +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ +# --scale 6 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb--scale6-1f1b.txt + +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ +# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ +# --scale 6 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb--scale7-1f1b.txt + + +# Pure TP OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ - --scale 4 --use-recompute > ${evaldir}/8dev256nmb-tp.txt + --scale 6 --use-recompute > ${evaldir}/8dev256nmb-scale6-tp.txt OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ - --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-tp4pp2.txt + handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ + --scale 7 --use-recompute > ${evaldir}/8dev256nmb-scale7-tp.txt + +# Hybrid TP-PP: TP2-PP4, TP4-PP2 OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ - --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-tp2pp4.txt + --scale 6 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-scale6-tp2pp4.txt +# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ +# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ +# --scale 7 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-scale7-tp2pp4.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ + --scale 6 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-scale6-tp4pp2.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ + --scale 7 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-scale7-tp4pp2.txt # 16 gpus @@ -65,3 +127,6 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ # OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ # handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ # --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt + +echo 'done!!!' +python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh index 59665fd0..60cfa3a3 100755 --- a/handcraft/mbart/run.sh +++ b/handcraft/mbart/run.sh @@ -7,18 +7,18 @@ mkdir -p ${evaldir} OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ --scale 3 > ${evaldir}/4dev256nmb-tp1f1b-pack.txt -# + OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ --scale 3 --iter-nmb 1 > ${evaldir}/4dev256nmb-1f1b.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ -# --scale 3 > ${evaldir}/4dev256nmb-tp.txt -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ -# --scale 3 --iter-nmb 256 > ${evaldir}/4dev256nmb-tp2pp2.txt +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ + --scale 3 > ${evaldir}/4dev256nmb-tp.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ + --scale 3 --iter-nmb 256 > ${evaldir}/4dev256nmb-tp2pp2.txt # 8 gpus @@ -31,17 +31,17 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ --scale 4 --iter-nmb 1 > ${evaldir}/8dev256nmb-1f1b.txt -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ -# --scale 4 > ${evaldir}/8dev256nmb-tp.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ -# --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp4pp2.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ -# --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp2pp4.txt +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ + --scale 4 > ${evaldir}/8dev256nmb-tp.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ + --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp4pp2.txt + +OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ + --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp2pp4.txt # 16 gpus @@ -65,3 +65,6 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ # OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ # handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ # --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt + +echo 'done!!!' +python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/test.py b/handcraft/mbart/test.py new file mode 100644 index 00000000..db9da84b --- /dev/null +++ b/handcraft/mbart/test.py @@ -0,0 +1,45 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 handcraft/mbart/test.py +""" + +import torch +import cube +from cube.profiler.memory import memory_summary, model_summary +from handcraft.mbart.tp import AllReduceIdentity, IdentityAllreduce + +scale = 7 +embed_dim = 1024 + int(1024 * (scale * 0.25)) +print(f'embed dim = {embed_dim}') + + + +class Model(torch.nn.Module): + + def __init__(self): + super().__init__() + self.embed = torch.nn.Embedding(500000, embed_dim) + + def forward(self, x): + out = self.embed(x) + loss = torch.sum(out) + return loss + +cube.init() +print('loading...') +model = Model().cuda() +input_ids = torch.randint(0, 25000, (1, 1024), dtype=torch.int).cuda() + +optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + +print('training...') +for _ in range(3): + + loss = model(input_ids) + loss.backward() + optimizer.step() + +memory_summary() From 0e68d16f961615e1eb2bb66d11fc097f1b953b46 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 27 Mar 2022 08:00:34 +0000 Subject: [PATCH 0687/1892] full scripts to run v100-32gb experiments --- .../mbart/run-recompute-full-v100-32gb.sh | 301 ++++++++++++++++++ handcraft/mbart/run-recompute.sh | 132 -------- 2 files changed, 301 insertions(+), 132 deletions(-) create mode 100755 handcraft/mbart/run-recompute-full-v100-32gb.sh delete mode 100755 handcraft/mbart/run-recompute.sh diff --git a/handcraft/mbart/run-recompute-full-v100-32gb.sh b/handcraft/mbart/run-recompute-full-v100-32gb.sh new file mode 100755 index 00000000..a3cfd790 --- /dev/null +++ b/handcraft/mbart/run-recompute-full-v100-32gb.sh @@ -0,0 +1,301 @@ +evaldir=/data/MagicCube/scale-mbart-recompute + +mkdir -p ${evaldir} + +# ================================================= +# 4 gpus: arch layer 21,21, hidden 1792, heads 28 +# ================================================= +layers=21 +hidden=1792 +heads=28 +gpus=4 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b.txt + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + +echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 2 --pp-size 2 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt + + +# ================================================= +# 4 gpus: arch layer 24,24, hidden 2048, heads 32 +# ================================================= +layers=24 +hidden=2048 +heads=32 +gpus=4 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" +echo "Will be OOM" + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + +echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 2 --pp-size 2 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt + + +# ================================================= +# 4 gpus: arch layer 24,24, hidden 2560, heads 32 +# ================================================= +layers=24 +hidden=2560 +heads=32 +gpus=4 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" +echo "Will be OOM" + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + +echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" +echo "Will be OOM" + + +# ================================================= +# 4 gpus: arch layer 18,18, hidden 3072, heads 32 +# ================================================= +layers=18 +hidden=3072 +heads=32 +gpus=4 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" +echo "Will be OOM" + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + +echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" +echo "Will be OOM" + + +# ================================================= +# 4 gpus: arch layer 27,27, hidden 3072, heads 32 +# ================================================= +layers=27 +hidden=2304 +heads=36 +gpus=4 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" +echo "Will be OOM" + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + +echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" +echo "Will be OOM" + + +# ================================================= +# 8 gpus: arch layer 24,24, hidden 2048, heads 32 +# ================================================= +layers=24 +hidden=2048 +heads=32 +gpus=8 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b.txt + + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + +echo "testing pure tensor parallelism 2x4: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 2 --pp-size 4 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt + +echo "testing tensor x pipeline parallelism 2x4: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt + +# ================================================= +# 8 gpus: arch layer 30,30, hidden 2560, heads 40 +# ================================================= +layers=30 +hidden=2560 +heads=40 +gpus=8 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + +echo "testing pure tensor parallelism 2x4: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 2 --pp-size 4 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt + +echo "testing tensor x pipeline parallelism 2x4: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt + + +# ================================================= +# 8 gpus: arch layer 33,33, hidden 2816, heads 40 +# ================================================= +layers=33 +hidden=2816 +heads=48 +gpus=8 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + + +echo "testing tensor x pipeline parallelism 4x2: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt + +# ================================================= +# 8 gpus: arch layer 24,24, hidden 4096, heads 32 +# ================================================= +layers=24 +hidden=4096 +heads=32 +gpus=8 + +echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-tp1f1b-pack --nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt + +echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + + +echo "testing tensor x pipeline parallelism 4x2: L${layers}E${hidden}H${heads}" +echo "Will OOM" + + +echo 'done!!!' +# python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/run-recompute.sh b/handcraft/mbart/run-recompute.sh deleted file mode 100755 index 0631d3f1..00000000 --- a/handcraft/mbart/run-recompute.sh +++ /dev/null @@ -1,132 +0,0 @@ -evaldir=/data/MagicCube/scale-mbart-recompute - -mkdir -p ${evaldir} - -# 4 gpus recompute scale=3,4,5 - -# TP-1F1B -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ -# --scale 3 --use-recompute > ${evaldir}/4dev256nmb-scale3-tp1f1b-pack.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 4 --use-recompute > ${evaldir}/4dev256nmb-scale4-tp1f1b-pack.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 5 --use-recompute > ${evaldir}/4dev256nmb-scale5-tp1f1b-pack.txt - -# Pure 1F1B -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ -# --scale 3 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale3-1f1b.txt - -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ -# --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale4-1f1b.txt - -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ -# --scale 5 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale5-1f1b.txt - -# Pure TP -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ -# --scale 3 --use-recompute > ${evaldir}/4dev256nmb-scale3-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ - --scale 4 --use-recompute > ${evaldir}/4dev256nmb-scale4-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ - --scale 5 --use-recompute > ${evaldir}/4dev256nmb-scale5-tp.txt - -# Hybrid TP-PP: TP=2, PP=2 - -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ -# --scale 3 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale3-tp2pp2.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ - --scale 4 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale4-tp2pp2.txt - -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ -# --scale 5 --iter-nmb 256 --use-recompute > ${evaldir}/4dev256nmb-scale5-tp2pp2.txt - - -# 8 gpus recompute scale=6,7 - -# TP-1F1B -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 6 --use-recompute > ${evaldir}/8dev256nmb-scale6-tp1f1b-pack.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 7 --use-recompute > ${evaldir}/8dev256nmb-scale7-tp1f1b-pack.txt - -# Pure 1F1B -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ -# --scale 6 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb--scale6-1f1b.txt - -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ -# --scale 6 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb--scale7-1f1b.txt - - -# Pure TP -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ - --scale 6 --use-recompute > ${evaldir}/8dev256nmb-scale6-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ - --scale 7 --use-recompute > ${evaldir}/8dev256nmb-scale7-tp.txt - - -# Hybrid TP-PP: TP2-PP4, TP4-PP2 -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ - --scale 6 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-scale6-tp2pp4.txt - -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ -# --scale 7 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-scale7-tp2pp4.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ - --scale 6 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-scale6-tp4pp2.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ - --scale 7 --iter-nmb 256 --use-recompute > ${evaldir}/8dev256nmb-scale7-tp4pp2.txt - -# 16 gpus - -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ -# --scale 6 > ${evaldir}/16dev256nmb-tp1f1b.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 16 --pp-size 1 --nmb 256 \ -# --scale 6 > ${evaldir}/16dev256nmb-tp.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 4 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp4pp4.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 2 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp8pp2.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt - -echo 'done!!!' -python scripts/keep.py --gpus 8 From ad1b638b188084097d5c2f79f1745dbd4213903b Mon Sep 17 00:00:00 2001 From: lynex Date: Sun, 27 Mar 2022 16:13:24 +0800 Subject: [PATCH 0688/1892] dim-partitioning algo for pad --- cube/algorithm/factory.py | 2 + cube/algorithm/ops/pad.py | 91 ++++++++++++++++++++ cube/algorithm/utils.py | 2 +- cube/graph/adapter/adapter.py | 6 +- cube/graph/graph.py | 5 +- cube/graph/operator/function/pad.py | 7 +- cube/runtime/function/function.py | 3 +- cube/runtime/syndata.py | 1 + examples/atmosphere/algo_test.py | 106 ++++++++++++++++++++++++ examples/atmosphere/policy/replicate.py | 13 +++ examples/atmosphere/policy/split.py | 40 +++++++++ 11 files changed, 266 insertions(+), 10 deletions(-) create mode 100644 cube/algorithm/ops/pad.py create mode 100644 examples/atmosphere/algo_test.py create mode 100644 examples/atmosphere/policy/replicate.py create mode 100644 examples/atmosphere/policy/split.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index e898d60e..183ee24b 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -69,6 +69,8 @@ def _load_predefined_algos(self): self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') self.register(conv.IRConv2D, conv.HaloSplitConv2D, tag='halo') + import cube.algorithm.ops.pad as pad + self.register(pad.IRPad, pad.DimSplitPad, tag='dim') # import cube.algorithm.ops.elementwise as elew # self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') # self.register(elew.Add, elew.AddDimParallel, tag='dim') diff --git a/cube/algorithm/ops/pad.py b/cube/algorithm/ops/pad.py new file mode 100644 index 00000000..47c47438 --- /dev/null +++ b/cube/algorithm/ops/pad.py @@ -0,0 +1,91 @@ +from typing import Dict + +from cube.algorithm.utils import split_axis, split_axis_custom, split_value +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.operator.function.pad import IRPad + +class DimSplitPad(GenericDistAlgo): + """ + split Pad at dimension level + + """ + def __init__(self, node: IRPad): + if not isinstance(node, IRPad): + raise TypeError(f"Expect IRConv2D") + super().__init__(node) + + def satisfy(self, config: Dict): + """ + config = dict(idx=int, dim=int, num=num) + + """ + for attr in ['dim', 'num']: + if not attr in config: + raise KeyError("Expected dim, num in the config") + node: IRPad = self.node + dim: int = config['dim'] + num: int = config['num'] + pad = node.kwargs['pad'] + mode = node.kwargs['mode'] + value = node.kwargs['value'] + assert len(pad) % 2 == 0 + pad_dim_count = len(pad) / 2 + + # split non-pad dim + if dim < len(node.inputs(0).shape) - pad_dim_count: + return node.inputs(0).shape[dim] % num == 0 + # split pad dim + else: + dim_in_pad = len(node.inputs(0).shape) - 1 - dim + return (node.inputs(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) % num == 0 + + def instantiate(self, config: Dict): + if not self.satisfy(config): + return False + node: IRPad = self.node + dim: int = config['dim'] + num: int = config['num'] + pad = node.kwargs['pad'] + mode = node.kwargs['mode'] + value = node.kwargs['value'] + pad_dim_count = len(pad) / 2 + + inputs = list() + outputs = list() + subnodes = list() + + # split non-pad dim + if dim < len(node.inputs(0).shape) - pad_dim_count: + inputs = split_axis(node.inputs(0), axis=dim, chunk_num=num) + outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) + for i, o in zip(inputs, outputs): + subnodes.append(node.new([i], [o])) + else: # split pad dim + inputs = split_axis(node.inputs(0), axis=dim, chunk_num=num) + slicers = list() + pads = list() + dim_in_pad = len(node.inputs(0).shape) - 1 - dim + global_padl = pad[dim_in_pad * 2] + global_padr = pad[dim_in_pad * 2 + 1] + chunk_size = (node.outputs(0).shape[dim] - global_padl - global_padr) // num + start = 0 + for cid in range(num): + padl = global_padl if cid == 0 else 0 + padr = global_padr if cid == num-1 else 0 + + cur_pad = pad.copy() + cur_pad[dim_in_pad * 2] = padl + cur_pad[dim_in_pad * 2 + 1] = padr + pads.append(cur_pad) + + stop = start + padl + padr + chunk_size + slicers.append(slice(max(0, start), min(node.outputs(0).shape[dim], stop))) + start = stop + + outputs = split_axis_custom(node.outputs(0), axis=dim, chunks=slicers) + + for i, o, p in zip(inputs, outputs, pads): + subnodes.append(node.new([i], [o], pad=p)) + + return subnodes \ No newline at end of file diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py index 5e0bd4a2..ef3a266a 100644 --- a/cube/algorithm/utils.py +++ b/cube/algorithm/utils.py @@ -36,7 +36,7 @@ def split_axis(tensor: IRSubTensor, axis: int, chunk_num: int): def split_axis_custom(tensor: IRSubTensor, axis: int, chunks: List[slice]): """ - Split tensor along an axis with cutomized selection + Split tensor along an axis with customized selection """ if axis < 0: axis = len(tensor.shape) + axis diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index c34579a9..74d50129 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -342,7 +342,9 @@ def gen_select(dst_tensor): odevice = otensor.device # local and remote adapter in-tensor - local, remote = list(), list() + # local_remote instead of local + remote to preserve inputs order + # TODO check order as may affect merging result + local, remote, local_and_remote = list(), list(), otensor.parent.ptensors for ptensor in otensor.parent.ptensors: if ptensor.device == odevice: local.append(ptensor) @@ -356,7 +358,7 @@ def gen_select(dst_tensor): return inputs, intersections, prims # check local + remote - for itensor in local + remote: + for itensor in local_and_remote: #local + remote: if not itensor.overlap(otensor): continue diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 086e1f4b..d544d08f 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -291,7 +291,7 @@ def get_outputs(nodes: List[IRCell]): ## Parallel Policy Primitives ## - def replicate(self, op: IRCell, times=1) -> Optional[List[IRCell]]: + def replicate(self, op: IRCell, times=1, reset_dependency=True) -> Optional[List[IRCell]]: """ Replicate a forward or data operation multiple times. @@ -318,7 +318,8 @@ def replicate(self, op: IRCell, times=1) -> Optional[List[IRCell]]: bidx = self.nodes().index(op.mirror) for idx, bnode in enumerate(bnodes): self.attach(bnode, bidx + idx) - self.reset_dependency() + if reset_dependency: + self.reset_dependency() return [op] + fnodes def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: diff --git a/cube/graph/operator/function/pad.py b/cube/graph/operator/function/pad.py index db92f78b..e1f96cf6 100644 --- a/cube/graph/operator/function/pad.py +++ b/cube/graph/operator/function/pad.py @@ -35,16 +35,17 @@ def infer_shape(self) -> bool: self.outputs(0).shape = shape return True - def new(self, inputs: List, outputs: List): + def new(self, inputs: List, outputs: List, pad = None): """ construct a new operator sharing same kwargs with new inputs and outputs """ - pad = self.kwargs['pad'] + if pad == None: + pad = self.kwargs['pad'] mode = self.kwargs['mode'] value = self.kwargs['value'] op = IRPad(self.signature, inputs, self.name, - pad=pad, mode=mode, value=value) + pad=pad, mode=mode, value=value) assert len(outputs) == 1 op.set_output(0, outputs[0]) op.infer_shape() diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 546844ed..68703764 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -85,5 +85,4 @@ def update_geopotential_(phi: torch.Tensor, zs: torch.Tensor, P: torch.Tensor, P return phi def strip_2_borders(w: torch.Tensor): - return w[1:-1] - + return w[1:-1] \ No newline at end of file diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 514761cd..73374165 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -150,6 +150,7 @@ def __iter__(self): return self def set_data_buffer(self, buffer_num = 4): + torch.manual_seed(0) self.datas = list() self._buffer_num = buffer_num for _ in range(self._buffer_num): diff --git a/examples/atmosphere/algo_test.py b/examples/atmosphere/algo_test.py new file mode 100644 index 00000000..666f97e4 --- /dev/null +++ b/examples/atmosphere/algo_test.py @@ -0,0 +1,106 @@ +""" +example: + +python -m torch.distributed.launch \ + --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=127.0.0.1 \ + --master_port=8004 \ + --use_env \ + examples/mlp/linears.py + +OMP_NUM_THREADS=4 torchrun --standalone \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --rdzv_id=888 \ + --rdzv_backend=c10d \ + --rdzv_endpoint=worker0:8004 \ + examples/mlp/linears.py +""" + +import torch +from torch import nn + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from examples.atmosphere.policy.split import PAS +import torch.nn.functional as F + + +# from examples.mlp.policy.col_parallel import P, A, S +# PAS = (P, A, S) + +# =================== Semantic Model Description ==================== + +class MLP(nn.Module): + def __init__(self, dim, mult=1): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult) + + + def forward(self, data): + a = self.linear1(data) + paded = F.pad(a, (1, 1), "constant", 8.8) + output = paded + 0 + # loss = torch.sum(output) + # return loss + return output + + +def train(): + batch_size = 4 + dim = 4 + + model = MLP(dim=dim) + model = cube.SemanticModel( + model, input_shapes=([batch_size, dim],), + ) + + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([batch_size, dim],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) + + @cube.compile(model, dataloader, PAS=PAS, override=True) + def train_iter(model, dataloader): + data = next(dataloader) + # loss = model(data) + # loss.backward() + output = model(data) + return output + + model = model.get_gen_module() + + # optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + # CudaTimer(enable=False).warmup() + torch.distributed.barrier() + iter_num = 1 + for step in range(iter_num): + # if step >= 40: + # CudaTimer(enable=True).start('e2e') + output = train_iter(model, dataloader) + # optimizer.step() + # optimizer.zero_grad() + # if step >= 40: + # CudaTimer().stop('e2e') + # if (step + 1) % 20 == 0: + # print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + print(f'output = {output}') + + # print_each_rank('e2e time (ms) per iteration: {} ms'.format( + # CudaTimer().duration(iter_num - 40, field_name='e2e'))) + # CudaTimer().print_all(times=iter_num - 40) + + +if __name__ == '__main__': + cube.init() + train() \ No newline at end of file diff --git a/examples/atmosphere/policy/replicate.py b/examples/atmosphere/policy/replicate.py new file mode 100644 index 00000000..3653786d --- /dev/null +++ b/examples/atmosphere/policy/replicate.py @@ -0,0 +1,13 @@ +from cube.graph import IRGraph +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation + + +def PAS(graph: IRGraph, resource): + print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus, reset_dependency=False) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph diff --git a/examples/atmosphere/policy/split.py b/examples/atmosphere/policy/split.py new file mode 100644 index 00000000..d58d05a9 --- /dev/null +++ b/examples/atmosphere/policy/split.py @@ -0,0 +1,40 @@ +from cube.graph import IRGraph +from cube.graph.adapter.adapter import IRAdapter +from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.graph.operator.function.conv import IRConv3D +from cube.graph.operator.function.pad import IRPad + +def PAS(graph: IRGraph, resource): + print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + if isinstance(node, IRDataOperation): + print(f'### IRDataOperation = {node}') + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, config=dict(num=resource.ngpus)) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + elif isinstance(node, IRPad): + print(f'### IRPad = {node}') + sub_nodes = list() + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, config=dict(dim=1, num=min(2, resource.ngpus))) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + elif isinstance(node, IRConv3D): + print(f'### IRConv3D = {node}') + sub_nodes = list() + algo = node.algorithms('halo') + Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus // 2)) + for Wnode in Wnodes: + algo = Wnode.algorithms('halo') + Hnodes = graph.partition(Wnode, algo, config=dict(idx=0, dim=2, num=2)) + sub_nodes += Hnodes + else: + print(f'### to-replicate = {node}') + sub_nodes = graph.replicate(node, times=resource.ngpus, reset_dependency=False) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + else: + print(f'### non-IRBpOperation = {node}') + return graph From 968b06e363cf8a4fc707395eeac2229267495374 Mon Sep 17 00:00:00 2001 From: lynex Date: Sun, 27 Mar 2022 20:27:04 +0800 Subject: [PATCH 0689/1892] halo-partitioning algo for conv3d --- cube/algorithm/factory.py | 1 + cube/algorithm/ops/conv.py | 111 +++++++++++++++++++++++++++ cube/graph/operator/function/conv.py | 2 +- cube/runtime/function/function.py | 15 ++++ examples/atmosphere/algo_test.py | 61 +++++++++++---- examples/atmosphere/policy/split.py | 14 ++-- 6 files changed, 184 insertions(+), 20 deletions(-) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 183ee24b..71545ec8 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -68,6 +68,7 @@ def _load_predefined_algos(self): import cube.algorithm.ops.conv as conv self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') self.register(conv.IRConv2D, conv.HaloSplitConv2D, tag='halo') + self.register(conv.IRConv3D, conv.HaloSplitConv3D, tag='halo') import cube.algorithm.ops.pad as pad self.register(pad.IRPad, pad.DimSplitPad, tag='dim') diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 1c593801..295c679b 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -4,6 +4,7 @@ from cube.algorithm.generics import GenericDistAlgo from cube.graph.operator.function.conv import IRConv2D +from cube.graph.operator.function.conv import IRConv3D class DimSplitConv2D(GenericDistAlgo): @@ -191,3 +192,113 @@ def instantiate(self, config: Dict): conv.set_output(0, o) sub_nodes.append(conv) return sub_nodes + + + +class HaloSplitConv3D(GenericDistAlgo): + """ + Halo-exchange split + + N iC D H W, oC iC dH dW, oC -> N oC oD oH oW + (dim-N is optional) + """ + + def __init__(self, node: IRConv3D): + if not isinstance(node, IRConv3D): + raise TypeError(f"Expect IRConv2D") + super().__init__(node) + + def satisfy(self, config: Dict): + for attr in ['idx', 'dim', 'num']: + if not attr in config: + raise KeyError("Expected idx, dim, num in the config") + node: IRConv3D = self.node + oD, oH, oW = node.outputs(0).shape[2:] + idx: int = config['idx'] + dim: int = config['dim'] + num: int = config['num'] + stride = node.kwargs['stride'] + dilation = node.kwargs['dilation'] + if dim not in [2, 3]: + return False + # FIXME: stride + if stride != [1, 1, 1]: + raise NotImplementedError("Splitting on stride != [1,1] is not supported") + if dilation != [1, 1, 1]: + raise NotImplementedError("Splitting on dilation != [1,1] is not supported") + # split H + if (idx, dim) == (0, 2): + return oH % num == 0 + # split W + if (idx, dim) == (0, 3): + return oW % num == 0 + + def instantiate(self, config: Dict): + if not self.satisfy(config): + return None + node: IRConv3D = self.node + D, H, W = node.inputs(0).shape[2:] + dD, dH, dW = node.inputs(1).shape[2:] + oD, oH, oW = node.outputs(0).shape[2:] + idx: int = config['idx'] + dim: int = config['dim'] + num: int = config['num'] + groups = node.kwargs['groups'] + stride = node.kwargs['stride'] + padding = node.kwargs['padding'] + dilation = node.kwargs['dilation'] + # split H + if (idx, dim) == (0, 2): + # input and padding + slicers = list() + pads = list() + start = 0 - padding[0] + for cid in range(num): + # padding + padl = padding[1] if cid == 0 else 0 + padr = padding[1] if cid == num - 1 else 0 + pads.append([padding[0], padding[0], padl, padr, padding[2], padding[2]]) + # input -- FIXME: only work for stride=[1,1] + chunkH = oH // num + dilation[0] * (dH - 1) + stop = start + chunkH - padr + slicers.append(slice(max(0, start), min(H, stop))) + start = stop - dilation[0] * (dH - 1) + # start = 0 if cid == 0 else 1023 + # stop = 1025 if cid == 0 else H + inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) + # weight + weights = [node.inputs(1)] * num + # bias + bias = [node.inputs(2)] * num + # outputs + outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) + # split W + if (idx, dim) == (0, 3): + # input and padding + slicers = list() + pads = list() + start = 0 - padding[2] + for cid in range(num): + # padding + padt = padding[2] if cid == 0 else 0 + padb = padding[2] if cid == num - 1 else 0 + pads.append([padding[0], padding[0], padding[1], padding[1], padt, padb]) + # input -- FIXME: only work for stride=[1,1] + chunkH = oW // num + dilation[0] * (dH - 1) + stop = start + chunkH - padb + slicers.append(slice(max(0, start), min(H, stop))) + start = stop - dilation[0] * (dH - 1) + inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) + # weight + weights = [node.inputs(1)] * num + # bias + bias = [node.inputs(2)] * num + # outputs + outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) + sub_nodes = list() + for i, w, b, pad, o in zip(inputs, weights, bias, pads, outputs): + conv = IRConv3D(node.signature, [i, w, b], node.name, + stride=stride, padding=pad, dilation=dilation, groups=groups) + conv.set_output(0, o) + sub_nodes.append(conv) + return sub_nodes diff --git a/cube/graph/operator/function/conv.py b/cube/graph/operator/function/conv.py index 9bb2b636..92213396 100644 --- a/cube/graph/operator/function/conv.py +++ b/cube/graph/operator/function/conv.py @@ -58,7 +58,7 @@ class IRConv3D(IRFwOperation): def __init__(self, signature: str, inputs: List[IRTensor], name: str, **kwargs): - #TODO signature = 'cube.runtime.function.conv3d' + signature = 'cube.runtime.function.conv3d' assert len(inputs) == 3, "Expected only input, weight, bias as inputs" assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" super().__init__(name, signature, 3, 1) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 68703764..fe25c860 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -17,6 +17,21 @@ def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso input = TorchF.pad(input, padding, 'constant', 0) return TorchF.conv2d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) +def conv3d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + stride: int, padding: List[int], dilation, groups: int = 1): + """ + input: N iC D H W, + weight: oC iC dH dW, oC + bias: oC + padding: List[int, int, int, int]: [Htop, Hbottom, Wtop, Wbottom] or + List[int, int]: [Hside, Wside] + + output: N oC oD oH oW + """ + # switch D, H and W to match torch.nn.functional.pad + padding = [padding[(2 + i) // 2 * (-2) + (i % 2)] for i in range(len(padding))] + input = TorchF.pad(input, padding, 'constant', 0) + return TorchF.conv3d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): """ diff --git a/examples/atmosphere/algo_test.py b/examples/atmosphere/algo_test.py index 666f97e4..603321f9 100644 --- a/examples/atmosphere/algo_test.py +++ b/examples/atmosphere/algo_test.py @@ -40,11 +40,10 @@ # =================== Semantic Model Description ==================== class MLP(nn.Module): - def __init__(self, dim, mult=1): + def __init__(self, dim, mult=1, filter=None): super().__init__() self.linear1 = nn.Linear(dim, dim * mult) - def forward(self, data): a = self.linear1(data) paded = F.pad(a, (1, 1), "constant", 8.8) @@ -53,21 +52,55 @@ def forward(self, data): # return loss return output +class ConvModel(nn.Module): + def __init__(self, dim, mult=1, filter=None): + super().__init__() + # self.linear1 = nn.Linear(dim, dim * mult) + self.filter = filter -def train(): - batch_size = 4 - dim = 4 + def forward(self, data): + # a = self.linear1(data) + # paded = F.pad(a, (1, 1), "constant", 8.8) + # output = paded + 0 + added = data + 1.0 + convd = torch.nn.functional.conv3d(added, self.filter, padding=[1,1,1]) + output = convd + 0 - model = MLP(dim=dim) - model = cube.SemanticModel( - model, input_shapes=([batch_size, dim],), - ) + # loss = torch.sum(output) + # return loss + return output - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, dim],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) +def train(): + batch_size = 2 + dim = 4 + in_channel, out_channel = 2, 2 + dimT, dimH, dimW = 2, 4, 4 + kT, kH, kW = 1, 3, 3 + + to_test = "MLP" + to_test = "Conv3d" + if to_test == "MLP": + model = MLP(dim=dim) + model = cube.SemanticModel( + model, input_shapes=([batch_size, dim],), + ) + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([batch_size, dim],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) + elif to_test == "Conv3d": + filter = torch.randn(out_channel, in_channel, kT, kH, kW) + model = ConvModel(dim=dim, filter=filter) + model = cube.SemanticModel( + model, input_shapes=([batch_size, in_channel, dimT, dimH, dimW],), + ) + + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([batch_size, in_channel, dimT, dimH, dimW],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) @cube.compile(model, dataloader, PAS=PAS, override=True) def train_iter(model, dataloader): diff --git a/examples/atmosphere/policy/split.py b/examples/atmosphere/policy/split.py index d58d05a9..0fecdc6a 100644 --- a/examples/atmosphere/policy/split.py +++ b/examples/atmosphere/policy/split.py @@ -25,11 +25,15 @@ def PAS(graph: IRGraph, resource): print(f'### IRConv3D = {node}') sub_nodes = list() algo = node.algorithms('halo') - Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus // 2)) - for Wnode in Wnodes: - algo = Wnode.algorithms('halo') - Hnodes = graph.partition(Wnode, algo, config=dict(idx=0, dim=2, num=2)) - sub_nodes += Hnodes + # Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus // 2)) + Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus)) + # for Wnode in Wnodes: + # algo = Wnode.algorithms('halo') + # Hnodes = graph.partition(Wnode, algo, config=dict(idx=0, dim=2, num=2)) + # sub_nodes += Hnodes + sub_nodes += Wnodes #TODO remove temp + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) else: print(f'### to-replicate = {node}') sub_nodes = graph.replicate(node, times=resource.ngpus, reset_dependency=False) From d6c04d9424414a90da654c1253a014a1fa1c8c35 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 27 Mar 2022 13:27:22 +0000 Subject: [PATCH 0690/1892] print all metrics --- handcraft/mbart/mbart.py | 1 + handcraft/mbart/mbart_hybrid.py | 1 + 2 files changed, 2 insertions(+) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index b488871b..31d330e0 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -722,4 +722,5 @@ def reduce_embed(model, pp_embed_group): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-2, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-2) memory_summary() diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index 69e46fb3..d941d3ed 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -782,4 +782,5 @@ def reduce_embed(model, pp_embed_group): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-2, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-2) memory_summary() From f1fa2b9941e0b61f75fd2bd774a936fb03909fb1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 27 Mar 2022 21:28:49 +0800 Subject: [PATCH 0691/1892] add swap to mbart --- handcraft/mbart/mbart.py | 45 ++++++++++--- handcraft/mbart/swap.py | 137 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 9 deletions(-) create mode 100644 handcraft/mbart/swap.py diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index b488871b..0076d699 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -4,7 +4,10 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 4 --use-recompute + handcraft/mbart/mbart.py \ + --layers 12 --hidden-size 1024 --heads 16 \ + --use-1f1b --nmb 4 --iter-nmb 4 \ + --use-recompute --use-swap """ from typing import Optional @@ -20,6 +23,7 @@ from cube.profiler.timer import print_each_rank from handcraft.mbart.schedule import schedule_1f1b, schedule_tp_1f1b_pack +from handcraft.mbart.swap import SwapEmbed, get_swap_parameters from handcraft.mbart.tp import ReduceBroadcast @@ -48,6 +52,8 @@ help='use tensor parallel 1f1b') parser.add_argument('--use-recompute', action='store_true', help='use recompute for a stage') +parser.add_argument('--use-swap', action='store_true', + help='use embedding swap (1f1b only)') args = parser.parse_args() print(args) @@ -364,7 +370,7 @@ def forward(self, dec: torch.Tensor, labels): class ShardEmbed(torch.nn.Module): - def __init__(self, cfg: Config, group=-1): + def __init__(self, cfg: Config, group=-1, swap=False): """ group = -1 means no tensor parallelism """ @@ -375,10 +381,15 @@ def __init__(self, cfg: Config, group=-1): self.shard_idx = torch.distributed.get_rank(group) if group != -1 else 0 if self.shard_num > 0: print(f'[{torch.distributed.get_rank()}]: initialize sharding embed (x{self.shard_num})') + assert not (swap and self.shard_idx > 1), "only 1f1b can use swap" - self.vocab_start_index = self.cfg.num_embeddings // self.shard_num * self.shard_idx - self.vocab_end_index = self.cfg.num_embeddings // self.shard_num * (self.shard_idx + 1) - self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.shard_num, self.cfg.embed_dim))) + self.swap = swap + if swap: + self.embed = SwapEmbed(self.cfg.num_embeddings, self.cfg.embed_dim) + else: + self.vocab_start_index = self.cfg.num_embeddings // self.shard_num * self.shard_idx + self.vocab_end_index = self.cfg.num_embeddings // self.shard_num * (self.shard_idx + 1) + self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.shard_num, self.cfg.embed_dim))) # encoder-preprocess self.embed_positions_encoder = PositionalEmbedding(cfg.max_source_positions, cfg.embed_dim) @@ -396,7 +407,9 @@ def set_inputs(self, *inputs): self._inputs = inputs def embed_lookup(self, tokens, dst: Optional[int] = None): - if self.shard_num > 1: + if self.swap: + embed = self.embed(tokens) + elif self.shard_num > 1: mask = (tokens < self.vocab_start_index) | \ (tokens >= self.vocab_end_index) tokens = tokens.clone() - self.vocab_start_index @@ -486,7 +499,7 @@ def __init__(self, cfg: Config, self.decoder_layer_end = self.layer_end if encoder_preprocess or decoder_preprocess or shard: - self.headtail = ShardEmbed(cfg, group = None if shard else -1) + self.headtail = ShardEmbed(cfg, group = None if shard else -1, swap=args.use_swap) else: self.headtail = None @@ -653,11 +666,21 @@ def reduce_embed(model, pp_embed_group): Embedding gradients needs to be reduced across pipeline stages """ if isinstance(model.headtail, torch.nn.Module): - grad = model.headtail.weight.grad + if model.headtail.swap: + with torch.no_grad(): + grad = model.headtail.embed.weight.grad + grad = grad.data.cuda() + else: + grad = model.headtail.weight.grad else: grad = None if grad is not None: torch.distributed.all_reduce(grad, group=pp_embed_group) + if isinstance(model.headtail, torch.nn.Module): + if model.headtail.swap: + with torch.no_grad(): + grad = grad.cpu() + model.headtail.embed.weight.grad = grad torch.cuda.synchronize() @@ -684,7 +707,11 @@ def reduce_embed(model, pp_embed_group): print_each_rank('model weight consumpition:') memory_summary() - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + if args.use_swap: + parameters = get_swap_parameters() + list(model.parameters()) + else: + parameters = model.parameters() + optimizer = torch.optim.Adam(parameters, lr=3e-05, betas=(0.9, 0.98)) CudaTimer(enable=False).warmup() iter_num = 6 diff --git a/handcraft/mbart/swap.py b/handcraft/mbart/swap.py new file mode 100644 index 00000000..70617647 --- /dev/null +++ b/handcraft/mbart/swap.py @@ -0,0 +1,137 @@ +from typing import List +import torch + +_param_map = dict() + + +def get_swap_parameters() -> List[torch.nn.Parameter]: + global _param_map + return list(_param_map.values()) + + +class _SwapEmbed(torch.autograd.Function): + + @staticmethod + def forward(ctx, input: torch.Tensor, weight_id: int, fake: torch.nn.Parameter): + # the fake parameter is preventing no grad fn + ctx.save_for_backward(input, fake) + ctx.weight_id = weight_id + + global _param_map + weight = _param_map[weight_id] + ctx.num_embeddings, ctx.embedding_dim = weight.size() + ctx.weight_dtype = weight.dtype + + with torch.no_grad(): + # swap in + weight.data = weight.detach().cuda() + # compute + output = torch.nn.functional.embedding(input, weight) + # swap out + weight.data = weight.detach().cpu() + + return output + + @staticmethod + def backward(ctx, grad_output): + print(f'debug: >> {torch.distributed.get_rank()} embed backward here') + (input, fake) = ctx.saved_tensors + + global _param_map + weight = _param_map[ctx.weight_id] + + # swap in + with torch.no_grad(): + weight = weight.data.cuda().requires_grad_() + # compute + with torch.enable_grad(): + output = torch.nn.functional.embedding(input, weight) + torch.autograd.backward((output,), (grad_output,)) + # swap out + assert weight.grad is not None + with torch.no_grad(): + grad = weight.grad.data.cpu() + weight = weight.data.cpu().requires_grad_() + weight.grad = grad + + _param_map[ctx.weight_id] = weight + fake_grad = torch.zeros_like(fake) + return None, None, fake_grad + + +class SwapEmbed(torch.nn.Module): + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx=None): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + assert padding_idx >= 0 + self.padding_idx = self.num_embeddings + padding_idx + else: + self.padding_idx = padding_idx + + _weight = torch.nn.Parameter( + torch.empty(num_embeddings, embedding_dim, requires_grad=True) + ) + self.weight_id = id(_weight) + # the fake parameter is preventing no grad fn + self.fake = torch.nn.Parameter(torch.empty((1,), requires_grad=True)) + global _param_map + _param_map[self.weight_id] = _weight + + def forward(self, input): + return _SwapEmbed.apply(input, self.weight_id, self.fake) + + @property + def weight(self): + global _param_map + return _param_map[self.weight_id] + + +if __name__ == '__main__': + + import cube + from cube.profiler.memory import model_summary + cube.init() + + class Model(torch.nn.Module): + + def __init__(self): + super().__init__() + # self.model1 = torch.nn.Embedding(250000, 1024) + self.model1 = SwapEmbed(250000, 1024) + self.model2 = SwapEmbed(250000, 1024) + # self.model2 = torch.nn.Embedding(250000, 1024) + self.model3 = torch.nn.Embedding(250000, 1024) + + def forward(self, input_ids): + out1 = self.model1(input_ids) + # assert out1.grad_fn is not None + out1 = out1 * 10 + # out2 = checkpoint.checkpoint(self.model2, input_ids) + out2 = self.model2(input_ids) + out2 = out2 / 10 + out3 = self.model3(input_ids) + out3 = -out3 + return torch.sum(out1 + out2 + out3) + + model = Model().cuda() + model.train() + + input_ids = torch.randint( + 0, 25000, (128, 1024), + dtype=torch.int, + device=torch.cuda.current_device(), + ) + + model_summary(model, (input_ids,)) + + loss = model(input_ids) + print(loss) + loss.backward() + + print(model.model1.weight.grad) From 01b3629b60a4a7604b0e740aa23e0988a08ef31a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 27 Mar 2022 23:27:47 +0800 Subject: [PATCH 0692/1892] optimize swap --- handcraft/mbart/mbart.py | 5 ++--- handcraft/mbart/swap.py | 33 +++++++++++++++++++++------------ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 17edfeb0..8de1ea5a 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -669,7 +669,7 @@ def reduce_embed(model, pp_embed_group): if model.headtail.swap: with torch.no_grad(): grad = model.headtail.embed.weight.grad - grad = grad.data.cuda() + grad = grad.cuda() else: grad = model.headtail.weight.grad else: @@ -679,8 +679,7 @@ def reduce_embed(model, pp_embed_group): if isinstance(model.headtail, torch.nn.Module): if model.headtail.swap: with torch.no_grad(): - grad = grad.cpu() - model.headtail.embed.weight.grad = grad + model.headtail.embed.weight.grad.copy_(grad) torch.cuda.synchronize() diff --git a/handcraft/mbart/swap.py b/handcraft/mbart/swap.py index 70617647..e41ef9e9 100644 --- a/handcraft/mbart/swap.py +++ b/handcraft/mbart/swap.py @@ -24,17 +24,22 @@ def forward(ctx, input: torch.Tensor, weight_id: int, fake: torch.nn.Parameter): with torch.no_grad(): # swap in - weight.data = weight.detach().cuda() + weight_gpu = torch.empty( + weight.size(), dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=True + ) + weight_gpu.copy_(weight) # compute - output = torch.nn.functional.embedding(input, weight) + output = torch.nn.functional.embedding(input, weight_gpu) # swap out - weight.data = weight.detach().cpu() + del weight_gpu return output @staticmethod def backward(ctx, grad_output): - print(f'debug: >> {torch.distributed.get_rank()} embed backward here') + # print(f'debug: >> {torch.distributed.get_rank()} embed backward here') (input, fake) = ctx.saved_tensors global _param_map @@ -42,19 +47,22 @@ def backward(ctx, grad_output): # swap in with torch.no_grad(): - weight = weight.data.cuda().requires_grad_() + weight_gpu = torch.empty( + weight.size(), dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=True + ) + weight_gpu.copy_(weight) # compute with torch.enable_grad(): - output = torch.nn.functional.embedding(input, weight) + output = torch.nn.functional.embedding(input, weight_gpu) torch.autograd.backward((output,), (grad_output,)) # swap out - assert weight.grad is not None + assert weight_gpu.grad is not None with torch.no_grad(): - grad = weight.grad.data.cpu() - weight = weight.data.cpu().requires_grad_() - weight.grad = grad + weight.grad.copy_(weight_gpu.grad) + del weight_gpu - _param_map[ctx.weight_id] = weight fake_grad = torch.zeros_like(fake) return None, None, fake_grad @@ -75,8 +83,9 @@ def __init__(self, self.padding_idx = padding_idx _weight = torch.nn.Parameter( - torch.empty(num_embeddings, embedding_dim, requires_grad=True) + torch.empty(num_embeddings, embedding_dim, requires_grad=True, pin_memory=True) ) + _weight.grad = torch.zeros_like(_weight, requires_grad=False, pin_memory=True) self.weight_id = id(_weight) # the fake parameter is preventing no grad fn self.fake = torch.nn.Parameter(torch.empty((1,), requires_grad=True)) From 4b29cc5325dc10edbeb17f4a1ac5f9db6dd60f9a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 27 Mar 2022 15:31:30 +0000 Subject: [PATCH 0693/1892] add timer --- handcraft/mbart/tp.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/handcraft/mbart/tp.py b/handcraft/mbart/tp.py index 01c80641..b8f9bf06 100644 --- a/handcraft/mbart/tp.py +++ b/handcraft/mbart/tp.py @@ -1,5 +1,5 @@ -from typing import Tuple import torch +from cube.profiler.timer import CudaTimer class AllReduceIdentity(torch.autograd.Function): @@ -9,7 +9,9 @@ def forward(ctx, input, group): world_size = torch.distributed.get_world_size(group) if world_size == 1: return input + CudaTimer().start(field_name='comm') torch.distributed.all_reduce(input, group=group) + CudaTimer().stop(field_name='comm') return input @staticmethod @@ -29,7 +31,9 @@ def backward(ctx, grad_output): world_size = torch.distributed.get_world_size(ctx._group) if world_size == 1: return grad_output, None + CudaTimer().start(field_name='comm') torch.distributed.all_reduce(grad_output, group=ctx._group) + CudaTimer().stop(field_name='comm') return grad_output, None @@ -42,11 +46,13 @@ def forward(ctx, input, dim, group): world_size = torch.distributed.get_world_size(group) if world_size == 1: return input + CudaTimer().start(field_name='comm') rank = torch.distributed.get_rank(group) tensor_list = [torch.empty_like(input) for _ in range(world_size)] tensor_list[rank] = input torch.distributed.all_gather(tensor_list, input, group=group) output = torch.cat(tensor_list, dim=dim).contiguous() + CudaTimer().stop(field_name='comm') return output @staticmethod @@ -56,9 +62,11 @@ def backward(ctx, grad_output: torch.Tensor): world_size = torch.distributed.get_world_size(group) if world_size == 1: return grad_output + CudaTimer().start(field_name='comm') input_list = grad_output.chunk(world_size, dim=dim) rank = torch.distributed.get_rank(group) grad = input_list[rank].contiguous() + CudaTimer().stop(field_name='comm') return grad, None, None @@ -71,8 +79,10 @@ def forward(ctx, input, dst: int, group): world_size = torch.distributed.get_world_size(group) if world_size == 1: return input + CudaTimer().start(field_name='comm') torch.distributed.reduce(input, dst, group=group) torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') return input @staticmethod @@ -82,8 +92,10 @@ def backward(ctx, grad_output): world_size = torch.distributed.get_world_size(group) if world_size == 1: return grad_output, None, None + CudaTimer().start(field_name='comm') torch.distributed.broadcast(grad_output, src, group=group) torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') return grad_output, None, None @@ -96,8 +108,10 @@ def forward(ctx, input, src: int, group=None): world_size = torch.distributed.get_world_size(group) if world_size == 1: return input + CudaTimer().start(field_name='comm') torch.distributed.broadcast(input, src, group=group) torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') return input @staticmethod @@ -107,8 +121,10 @@ def backward(ctx, grad_output): world_size = torch.distributed.get_world_size(group) if world_size == 1: return grad_output, None, None + CudaTimer().start(field_name='comm') if not grad_output.is_contiguous(): grad_output = grad_output.contiguous() torch.distributed.reduce(grad_output, dst, group=group) torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') return grad_output, None, None From fdf4cd49a02dc660a82f22037149e4d4faf7a6d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 28 Mar 2022 01:12:54 +0800 Subject: [PATCH 0694/1892] non-uniform partition --- handcraft/mbart/mbart.py | 107 +++++++++++++++++++------------- handcraft/mbart/mbart_hybrid.py | 99 ++++++++++++++--------------- 2 files changed, 113 insertions(+), 93 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 8de1ea5a..9b526070 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -31,6 +31,7 @@ _pp_embed_group = -1 _pp_next_rank = None _pp_prev_rank = None +_layer_divisions = [] parser = argparse.ArgumentParser(description='swin') # model arch @@ -66,14 +67,57 @@ idx = pp_ranks.index(DeviceGroup().rank) _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] - is_first_stage = idx == 0 - is_first_decoder_stage = idx == len(pp_ranks) // 2 - is_last_stage = idx == len(pp_ranks) - 1 + + encoder_time = [1] * args.layers + decoder_time = [2] * args.layers + times = encoder_time + decoder_time + num_stages = torch.distributed.get_world_size(_pp_group) + budget = sum(times) // num_stages + print_each_rank(f'budget: {budget}', rank_only=0) + start, end = 0, 1 + for idx in range(num_stages): + accum = times[start] + assert end <= args.layers * 2 + while end != args.layers * 2: + accum += times[end] + if accum > budget: + break + end += 1 + if idx == num_stages - 1: + end = args.layers * 2 + _layer_divisions.append((start, end)) + start, end = end, end+1 + + # uniform division algorithm: + # num_stages = torch.distributed.get_world_size(_pp_group) + # chunk = args.layers // (num_stages // 2) + # encoder_nlayers = [chunk] * (num_stages // 2) + # for idx in range(args.layers % (num_stages // 2)): + # encoder_nlayers[-idx] += 1 + # encoder_layers = [ + # (sum(encoder_nlayers[:rank]), + # sum(encoder_nlayers[:rank+1])) for rank in range(num_stages // 2) + # ] + # decoder_layers = [ + # (args.layers + sum(encoder_nlayers[:rank]), + # args.layers + sum(encoder_nlayers[:rank+1])) for rank in range(num_stages // 2) + # ] + # _layer_divisions = encoder_layers + decoder_layers + + print_each_rank(f'layer divisions: {_layer_divisions}', rank_only=0) + -# create embed group: first encoder, first decoder, last stage +# create embed group: first encoder, first decoder if args.use_1f1b: - embed_ranks = [pp_ranks[0], pp_ranks[len(pp_ranks) // 2]] - embed_ranks = list(set(embed_ranks)) + encoder_preprocess = 0 + decoder_preprocess = None + for rank in range(len(pp_ranks)): + start, end = _layer_divisions[rank] + if start <= args.layers and end > args.layers: + decoder_preprocess = rank + break + assert decoder_preprocess is not None + embed_ranks = [encoder_preprocess, decoder_preprocess] _pp_embed_group = DeviceGroup().get_group(embed_ranks) @@ -89,9 +133,6 @@ class Config: # attention_inner_dim = attention_heads * 64 # ffn_dim = 4 * embed_dim - # scale = args.scale - # scale_p = scale * 0.25 - num_embeddings = 500000 # 250027 + int(250027*scale_p) # decoder_layers = 12 + int(12*scale_p) # encoder_layers = 12 + int(12*scale_p) @@ -444,10 +485,7 @@ def decoder_preprocess(self, dst: Optional[int] = None): class mBARTFull(torch.nn.Module): - def __init__(self, cfg: Config, - encoder_preprocess=True, - decoder_preprocess=True, - post_process=True, shard=True): + def __init__(self, cfg: Config, shard=True): super().__init__() self.cfg = cfg self.dummy_labels = torch.tensor([1]).cuda() @@ -461,36 +499,17 @@ def __init__(self, cfg: Config, self.pp_stage = torch.distributed.get_rank(_pp_group) self.num_stages = torch.distributed.get_world_size(_pp_group) + self.layer_start, self.layer_end = _layer_divisions[self.pp_stage] - encoder_stages = self.num_stages // 2 - decoder_stages = self.num_stages // 2 - if self.pp_stage < self.num_stages // 2: - encoder_stages = self.num_stages // 2 - chunk = cfg.encoder_layers // encoder_stages - remain = cfg.encoder_layers % encoder_stages - layers = [chunk] * encoder_stages - for idx in range(remain): - layers[-idx] += 1 - self.layer_start = sum(layers[0:self.pp_stage]) - self.layer_end = self.layer_start + layers[self.pp_stage] - if self.pp_stage >= self.num_stages // 2: - chunk = cfg.decoder_layers // decoder_stages - remain = cfg.decoder_layers % decoder_stages - layers = [chunk] * decoder_stages - for idx in range(remain): - layers[-idx] += 1 - self.layer_start = cfg.encoder_layers + sum(layers[0:self.pp_stage-encoder_stages]) - self.layer_end = self.layer_start + layers[self.pp_stage-encoder_stages] - - self.encoder_preprocess = encoder_preprocess + self.encoder_preprocess = self.layer_start == 0 if not shard else False self.encoder_forward = (self.layer_start < cfg.encoder_layers) - self.decoder_preprocess = decoder_preprocess - self.decoder_first_stage = self.layer_start == cfg.encoder_layers - self.decoder_forward = (self.layer_start >= cfg.encoder_layers) + self.decoder_first_stage = self.layer_start <= cfg.encoder_layers and self.layer_end > cfg.encoder_layers + self.decoder_preprocess = self.decoder_first_stage if not shard else False + self.decoder_forward = (self.layer_end > cfg.encoder_layers) self.decoder_last_stage = (self.layer_end == cfg.encoder_layers + cfg.decoder_layers) - self.postprocess = post_process + self.postprocess = self.decoder_last_stage self.encoder_layer_start = self.layer_start self.encoder_layer_end = min(self.layer_end, cfg.encoder_layers) @@ -498,7 +517,7 @@ def __init__(self, cfg: Config, self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) self.decoder_layer_end = self.layer_end - if encoder_preprocess or decoder_preprocess or shard: + if self.encoder_preprocess or self.decoder_preprocess or shard: self.headtail = ShardEmbed(cfg, group = None if shard else -1, swap=args.use_swap) else: self.headtail = None @@ -675,7 +694,10 @@ def reduce_embed(model, pp_embed_group): else: grad = None if grad is not None: + CudaTimer().start('comm') torch.distributed.all_reduce(grad, group=pp_embed_group) + torch.cuda.synchronize() + CudaTimer().stop('comm') if isinstance(model.headtail, torch.nn.Module): if model.headtail.swap: with torch.no_grad(): @@ -696,12 +718,9 @@ def reduce_embed(model, pp_embed_group): batch_dims=(0,) ) if args.use_1f1b: - encoder_preprocess = is_first_stage - decoder_preprocess = is_first_decoder_stage - postprocess = is_last_stage - model = mBARTFull(cfg, encoder_preprocess, decoder_preprocess, postprocess, shard=False).cuda() + model = mBARTFull(cfg, shard=False).cuda() else: - model = mBARTFull(cfg, False, False, is_last_stage, shard=True).cuda() + model = mBARTFull(cfg, shard=True).cuda() print_each_rank('model weight consumpition:') memory_summary() diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index d941d3ed..df8a575a 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -4,8 +4,10 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --pp-size 2 --tp-size 2\ - --nmb 4 --scale 0 --iter-nmb 4 + handcraft/mbart/mbart_hybrid.py \ + --layers 12 --hidden-size 1024 --heads 16 \ + --pp-size 2 --tp-size 2 --nmb 4 --iter-nmb 4 \ + --use-recompute """ from typing import Optional @@ -28,6 +30,7 @@ _pp_embed_group = -1 _pp_next_rank = None _pp_prev_rank = None +_layer_divisions = [] parser = argparse.ArgumentParser(description='mbart hybrid') @@ -70,22 +73,50 @@ idx = pp_ranks.index(DeviceGroup().rank) _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] - is_first_stage = idx == 0 - is_first_decoder_stage = idx == len(pp_ranks) // 2 - is_last_stage = idx == len(pp_ranks) - 1 + + encoder_time = [1] * args.layers + decoder_time = [2] * args.layers + times = encoder_time + decoder_time + num_stages = torch.distributed.get_world_size(_pp_group) + budget = sum(times) // num_stages + print_each_rank(f'budget: {budget}', rank_only=0) + start, end = 0, 1 + for idx in range(num_stages): + accum = times[start] + assert end <= args.layers * 2 + while end != args.layers * 2: + accum += times[end] + if accum > budget: + break + end += 1 + if idx == num_stages - 1: + end = args.layers * 2 + _layer_divisions.append((start, end)) + start, end = end, end+1 + print_each_rank(f'layer divisions: {_layer_divisions}', rank_only=0) if len(pp_ranks) > 1: - pranks = [torch.zeros((args.pp_size,), dtype=torch.int, device=torch.cuda.current_device()) for _ in range(args.tp_size)] + pranks = [torch.zeros( + (args.pp_size,), dtype=torch.int, device=torch.cuda.current_device()) for _ in range(args.tp_size) + ] prank = torch.tensor(pp_ranks, dtype=torch.int).cuda() pranks[torch.distributed.get_rank(_tp_group)] = prank torch.distributed.all_gather(pranks, prank, group=_tp_group) torch.cuda.synchronize() # print_each_rank(f'allgather-pp ranks: {pranks}') - + encoder_preprocess_tp = 0 + decoder_preprocess_tp = None + for rank in range(len(pp_ranks)): + start, end = _layer_divisions[rank] + if start <= args.layers and end > args.layers: + decoder_preprocess_tp = rank + break + assert decoder_preprocess_tp is not None for prank in pranks: prank = prank.tolist() - embed_ranks = [prank[0], prank[len(prank) // 2]] + embed_ranks = [prank[encoder_preprocess_tp], prank[decoder_preprocess_tp]] embed_ranks = list(set(embed_ranks)) + print_each_rank(f'init embed group: {embed_ranks}') group = DeviceGroup().get_group(embed_ranks) if torch.distributed.get_rank(_tp_group) in prank: print(f'embedding group: {embed_ranks}') @@ -476,11 +507,7 @@ def decoder_preprocess(self): class mBARTFull(torch.nn.Module): - def __init__(self, cfg: Config, - encoder_preprocess=True, - decoder_preprocess=True, - post_process=True, - embed_cpu=False): + def __init__(self, cfg: Config, embed_cpu=False): super().__init__() self.cfg = cfg self.dummy_labels = torch.tensor([1]).cuda() @@ -494,40 +521,17 @@ def __init__(self, cfg: Config, self.pp_stage = torch.distributed.get_rank(_pp_group) self.num_stages = torch.distributed.get_world_size(_pp_group) + self.layer_start, self.layer_end = _layer_divisions[self.pp_stage] - if self.num_stages >= 2: - encoder_stages = self.num_stages // 2 - decoder_stages = self.num_stages // 2 - if self.pp_stage < self.num_stages // 2: - encoder_stages = self.num_stages // 2 - chunk = cfg.encoder_layers // encoder_stages - remain = cfg.encoder_layers % encoder_stages - layers = [chunk] * encoder_stages - for idx in range(remain): - layers[-idx] += 1 - self.layer_start = sum(layers[0:self.pp_stage]) - self.layer_end = self.layer_start + layers[self.pp_stage] - if self.pp_stage >= self.num_stages // 2: - chunk = cfg.decoder_layers // decoder_stages - remain = cfg.decoder_layers % decoder_stages - layers = [chunk] * decoder_stages - for idx in range(remain): - layers[-idx] += 1 - self.layer_start = cfg.encoder_layers + sum(layers[0:self.pp_stage-encoder_stages]) - self.layer_end = self.layer_start + layers[self.pp_stage-encoder_stages] - else: - self.layer_start = 0 - self.layer_end = cfg.encoder_layers + cfg.decoder_layers - - self.encoder_preprocess = encoder_preprocess + self.encoder_preprocess = self.layer_start == 0 self.encoder_forward = (self.layer_start < cfg.encoder_layers) - self.decoder_preprocess = decoder_preprocess - self.decoder_first_stage = self.layer_start == cfg.encoder_layers + self.decoder_first_stage = self.layer_start <= cfg.encoder_layers and self.layer_end > cfg.encoder_layers + self.decoder_preprocess = self.decoder_first_stage self.decoder_forward = (self.layer_end > cfg.encoder_layers) self.decoder_last_stage = (self.layer_end == cfg.encoder_layers + cfg.decoder_layers) - self.postprocess = post_process + self.postprocess = self.decoder_last_stage self.encoder_layer_start = self.layer_start self.encoder_layer_end = min(self.layer_end, cfg.encoder_layers) @@ -535,7 +539,7 @@ def __init__(self, cfg: Config, self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) self.decoder_layer_end = self.layer_end - if encoder_preprocess or decoder_preprocess: + if self.encoder_preprocess or self.decoder_preprocess: self.headtail = SharedEmbed(cfg, embed_cpu) else: self.headtail = None @@ -707,7 +711,10 @@ def reduce_embed(model, pp_embed_group): else: grad = None if grad is not None: + CudaTimer().start('comm') torch.distributed.all_reduce(grad, group=pp_embed_group) + torch.cuda.synchronize() + CudaTimer().stop('comm') torch.cuda.synchronize() @@ -724,13 +731,7 @@ def reduce_embed(model, pp_embed_group): ) - if args.pp_size > 1: - encoder_preprocess = is_first_stage - decoder_preprocess = is_first_decoder_stage - postprocess = is_last_stage - model = mBARTFull(cfg, encoder_preprocess, decoder_preprocess, postprocess, args.embed_cpu).cuda() - else: - model = mBARTFull(cfg, True, True, True, args.embed_cpu).cuda() + model = mBARTFull(cfg, args.embed_cpu).cuda() if args.embed_cpu: if model.headtail is not None: From ed6bbdeb5be7cef1a5064bcc48fa04659601ca26 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 28 Mar 2022 10:56:54 +0800 Subject: [PATCH 0695/1892] switch back to uniform partition --- handcraft/mbart/mbart.py | 2 +- handcraft/mbart/mbart_hybrid.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py index 9b526070..d17edf1d 100644 --- a/handcraft/mbart/mbart.py +++ b/handcraft/mbart/mbart.py @@ -69,7 +69,7 @@ _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] encoder_time = [1] * args.layers - decoder_time = [2] * args.layers + decoder_time = [1] * args.layers times = encoder_time + decoder_time num_stages = torch.distributed.get_world_size(_pp_group) budget = sum(times) // num_stages diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py index df8a575a..16e76895 100644 --- a/handcraft/mbart/mbart_hybrid.py +++ b/handcraft/mbart/mbart_hybrid.py @@ -75,7 +75,7 @@ _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] encoder_time = [1] * args.layers - decoder_time = [2] * args.layers + decoder_time = [1] * args.layers times = encoder_time + decoder_time num_stages = torch.distributed.get_world_size(_pp_group) budget = sum(times) // num_stages From 7ea2d966e20fed84dd95fa95c547b4c25f40ac47 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 28 Mar 2022 21:02:22 +0800 Subject: [PATCH 0696/1892] add swin transformer example --- examples/vision/swin/model.py | 675 ++++++++++++++++++++++++++++++++++ examples/vision/swin/train.py | 118 ++++++ 2 files changed, 793 insertions(+) create mode 100644 examples/vision/swin/model.py create mode 100644 examples/vision/swin/train.py diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py new file mode 100644 index 00000000..ee6edb3f --- /dev/null +++ b/examples/vision/swin/model.py @@ -0,0 +1,675 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu + +# The file is merged with source code from timm +# -------------------------------------------------------- +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +import cube + + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +class DropPath(torch.nn.Module): + + def __init__(self, drop_prob: float): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + if self.drop_prob == 0. or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.criterion = nn.CrossEntropyLoss() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x, labels: torch.Tensor): + x = self.forward_features(x) + x = self.head(x) + loss = self.criterion(x, labels) + return loss + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int, img_size: int, num_classes: int): + + self.bs = batch_size + self.img_size = img_size + self.num_classes = num_classes + super().__init__( + shapes=([batch_size, 3, img_size, img_size,], + [batch_size], + ), + dtypes=(torch.float, torch.int), + batch_dims=(0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + img = torch.rand( + *(self.bs, 3, self.img_size, self.img_size), + dtype=torch.float, + device=torch.cuda.current_device() + ) + labels = torch.randint( + 0, self.num_classes, + size=(self.bs,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + return (img, labels) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] \ No newline at end of file diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py new file mode 100644 index 00000000..ae6e350e --- /dev/null +++ b/examples/vision/swin/train.py @@ -0,0 +1,118 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/vision/swin/train.py +""" + +import torch +from examples.vision.swin.model import SwinTransformer, ImageDataLoader + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary, model_summary + + +class Config: + + # swin-large 201M + embed_dim = 192 + depths = [2, 2, 18, 2] + num_heads = [6, 12, 24, 48] + + # swin-huge: 2.5B + # embed_dim = 512 + # depths = [2, 2, 42, 2] + # num_heads = [16, 32, 64, 128] + + mlp_ratio = 4 + qkv_bias = True + qk_scale = None + + drop_path_rate = 0.2 + drop_rate = 0.2 + + + # 224 x 224 + img_size = 224 + window_size = 7 + + # 640 x 640 + img_size = 640 + window_size = 40 + + # 1536 x 1536 + # img_size = 1536 + # window_size = 48 + + num_classes = 1000 + + +def train(): + + batch_size = 1 + + cfg = Config() + model = SwinTransformer(img_size=cfg.img_size, + patch_size=4, + in_chans=3, + num_classes=cfg.num_classes, + embed_dim=cfg.embed_dim, + depths=cfg.depths, + num_heads=cfg.num_heads, + window_size=cfg.window_size, + mlp_ratio=cfg.mlp_ratio, + qkv_bias=cfg.qkv_bias, + qk_scale=cfg.qk_scale, + drop_rate=cfg.drop_rate, + drop_path_rate=cfg.drop_path_rate, + ape=False, + patch_norm=True, + use_checkpoint=False) + + model = model.cuda() + dataloader = ImageDataLoader(batch_size, cfg.img_size, cfg.num_classes) + optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) + + print_each_rank('model weight consumpition:') + memory_summary() + + def train_iter(model, dataloader): + imgs, labels = next(dataloader) + loss = model(imgs, labels) + loss.backward() + + CudaTimer(enable=False).warmup() + iter_num = 10 + for step in range(iter_num): + + if step == 0: + model_summary(model, next(dataloader)) + + if step >= 4: + CudaTimer(enable=True).start('e2e') + + # training + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + if step >= 4: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('passed first iteration') + + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-4, field_name='e2e'))) + memory_summary() + +if __name__ == '__main__': + + cube.init() + train() From 3ea3dc6dd379d02f9d8b166a221821337157808d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 02:21:19 +0000 Subject: [PATCH 0697/1892] log experiments --- handcraft/mbart/run-1f1b-swap.sh | 100 ++++++++++++++++++ handcraft/mbart/run-recompute-arch.sh | 94 ++++++++-------- .../mbart/run-recompute-full-v100-32gb.sh | 2 +- handcraft/mbart/run.sh | 70 ------------ handcraft/mbart/test.py | 45 -------- 5 files changed, 148 insertions(+), 163 deletions(-) create mode 100755 handcraft/mbart/run-1f1b-swap.sh delete mode 100755 handcraft/mbart/run.sh delete mode 100644 handcraft/mbart/test.py diff --git a/handcraft/mbart/run-1f1b-swap.sh b/handcraft/mbart/run-1f1b-swap.sh new file mode 100755 index 00000000..2acb589f --- /dev/null +++ b/handcraft/mbart/run-1f1b-swap.sh @@ -0,0 +1,100 @@ +evaldir=eval/mbart-v100-32gb-pcie-recompute + +mkdir -p ${evaldir} + +# ================================================= +# 4 gpus: arch layer 21,21, hidden 1792, heads 28 +# ================================================= +layers=24 +hidden=2048 +heads=32 +gpus=4 + +echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt + + + +layers=24 +hidden=2560 +heads=32 +gpus=4 + +echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt + + +layers=18 +hidden=3072 +heads=32 +gpus=4 + +echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt + + +layers=27 +hidden=2304 +heads=36 +gpus=4 + +echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt + + +layers=30 +hidden=2560 +heads=40 +gpus=8 + +echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt + + +layers=33 +hidden=2816 +heads=48 +gpus=8 + +echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt + + +layers=24 +hidden=4096 +heads=32 +gpus=8 + +echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt + + +# python scripts/keep.py --gpus 8 \ No newline at end of file diff --git a/handcraft/mbart/run-recompute-arch.sh b/handcraft/mbart/run-recompute-arch.sh index 8682e54e..4b90ea89 100755 --- a/handcraft/mbart/run-recompute-arch.sh +++ b/handcraft/mbart/run-recompute-arch.sh @@ -1,5 +1,5 @@ layers=24 -hidden=4096 +hidden=2560 heads=32 gpus=8 @@ -17,12 +17,12 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ # # Pure 1F1B -# echo 'testing pure 1f1b' -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ -# handcraft/mbart/mbart.py \ -# --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ -# --use-1f1b --nmb 256 --iter-nmb 256\ -# --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b.txt +echo 'testing pure 1f1b' +OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --use-1f1b --nmb 256 --iter-nmb 256\ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b.txt # Pure TP echo 'testing pure tensor parallelism' @@ -33,44 +33,44 @@ OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt # # Hybrid TP-1F1B -- 4 GPU -# if [ ${gpus} == 4 ] -# then -# echo 'testing hybrid tp:pp=2:2' -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py \ -# --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ -# --tp-size 2 --pp-size 2 --nmb 256 --iter-nmb 256 \ -# --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt -# sleep 5 -# killall python -# sleep 5 -# killall python -# fi -# -# # Hybrid TP-1F1B -- 8 GPU -# if [ ${gpus} == 8 ] -# then -# echo 'testing hybrid tp:pp=4:2' -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py \ -# --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ -# --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ -# --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt -# sleep 5 -# killall python -# sleep 5 -# killall python -# -# echo 'testing hybrid tp:pp=2:4' -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py \ -# --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ -# --tp-size 2 --pp-size 4 --nmb 256 --iter-nmb 256 \ -# --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt -# sleep 5 -# killall python -# sleep 5 -# killall python -# fi +if [ ${gpus} == 4 ] +then + echo 'testing hybrid tp:pp=2:2' + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 2 --pp-size 2 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt + sleep 5 + killall python + sleep 5 + killall python +fi + +# Hybrid TP-1F1B -- 8 GPU +if [ ${gpus} == 8 ] +then + echo 'testing hybrid tp:pp=4:2' + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt + sleep 5 + killall python + sleep 5 + killall python + + echo 'testing hybrid tp:pp=2:4' + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --tp-size 2 --pp-size 4 --nmb 256 --iter-nmb 256 \ + --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt + sleep 5 + killall python + sleep 5 + killall python +fi -python scripts/keep.py --gpus 8 +# python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/run-recompute-full-v100-32gb.sh b/handcraft/mbart/run-recompute-full-v100-32gb.sh index a3cfd790..adddc92d 100755 --- a/handcraft/mbart/run-recompute-full-v100-32gb.sh +++ b/handcraft/mbart/run-recompute-full-v100-32gb.sh @@ -1,4 +1,4 @@ -evaldir=/data/MagicCube/scale-mbart-recompute +evaldir=eval/mbart-v100-32gb-pcie-recompute mkdir -p ${evaldir} diff --git a/handcraft/mbart/run.sh b/handcraft/mbart/run.sh deleted file mode 100755 index 60cfa3a3..00000000 --- a/handcraft/mbart/run.sh +++ /dev/null @@ -1,70 +0,0 @@ -evaldir=/data/MagicCube/scale-mbart - -mkdir -p ${evaldir} - -# 4 gpus - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 3 > ${evaldir}/4dev256nmb-tp1f1b-pack.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ - --scale 3 --iter-nmb 1 > ${evaldir}/4dev256nmb-1f1b.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 1 --nmb 256 \ - --scale 3 > ${evaldir}/4dev256nmb-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 2 --nmb 256 \ - --scale 3 --iter-nmb 256 > ${evaldir}/4dev256nmb-tp2pp2.txt - - -# 8 gpus - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ - --scale 4 > ${evaldir}/8dev256nmb-tp1f1b-pack.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart.py --use-1f1b --nmb 256 \ - --scale 4 --iter-nmb 1 > ${evaldir}/8dev256nmb-1f1b.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 1 --nmb 256 \ - --scale 4 > ${evaldir}/8dev256nmb-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 2 --nmb 256 \ - --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp4pp2.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 4 --nmb 256 \ - --scale 4 --iter-nmb 256 > ${evaldir}/8dev256nmb-tp2pp4.txt - - -# 16 gpus - -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart.py --use-tp1f1b-pack --nmb 256 \ -# --scale 6 > ${evaldir}/16dev256nmb-tp1f1b.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 16 --pp-size 1 --nmb 256 \ -# --scale 6 > ${evaldir}/16dev256nmb-tp.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 4 --pp-size 4 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp4pp4.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 8 --pp-size 2 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp8pp2.txt -# -# OMP_NUM_THREADS=4 torchrun --nproc_per_node=16 --nnodes=1 \ -# handcraft/mbart/mbart_hybrid.py --tp-size 2 --pp-size 8 --nmb 256 \ -# --scale 6 --iter-nmb 256 > ${evaldir}/16dev256nmb-tp2pp8.txt - -echo 'done!!!' -python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/test.py b/handcraft/mbart/test.py deleted file mode 100644 index db9da84b..00000000 --- a/handcraft/mbart/test.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 handcraft/mbart/test.py -""" - -import torch -import cube -from cube.profiler.memory import memory_summary, model_summary -from handcraft.mbart.tp import AllReduceIdentity, IdentityAllreduce - -scale = 7 -embed_dim = 1024 + int(1024 * (scale * 0.25)) -print(f'embed dim = {embed_dim}') - - - -class Model(torch.nn.Module): - - def __init__(self): - super().__init__() - self.embed = torch.nn.Embedding(500000, embed_dim) - - def forward(self, x): - out = self.embed(x) - loss = torch.sum(out) - return loss - -cube.init() -print('loading...') -model = Model().cuda() -input_ids = torch.randint(0, 25000, (1, 1024), dtype=torch.int).cuda() - -optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - -print('training...') -for _ in range(3): - - loss = model(input_ids) - loss.backward() - optimizer.step() - -memory_summary() From 9a4a6a7742cf1eb13c8bf1db3496f468b067adb3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 02:23:42 +0000 Subject: [PATCH 0698/1892] add runtime tests --- tests/test_nccl.py | 103 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tests/test_nccl.py diff --git a/tests/test_nccl.py b/tests/test_nccl.py new file mode 100644 index 00000000..d3000b2f --- /dev/null +++ b/tests/test_nccl.py @@ -0,0 +1,103 @@ + +""" + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=1 \ + tests/test_nccl.py + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + tests/test_nccl.py + +OMP_NUM_THREADS=4 python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + tests/test_nccl.py + +""" + +import torch +import time +import sys +import os +import argparse + + +def print_each_rank(msg, select=True, outfile=''): + myrank = torch.distributed.get_rank() + outfile = sys.stdout if outfile == '' else outfile + for rank in range(torch.distributed.get_world_size()): + if select: + if myrank == rank: + f = open(outfile, 'a') if outfile != sys.stdout else sys.stdout + f.write('rank [{}]: {}\n'.format(rank, msg)) + if outfile != sys.stdout: + f.close() + torch.distributed.barrier() + + +def test_nccl(size, local_rank): + msg = torch.ones((size,)).cuda() + # warm up + for _ in range(20): + out = torch.distributed.all_reduce(msg) + torch.cuda.synchronize() + # profile + tic = time.perf_counter() + for _ in range(100): + out = torch.distributed.all_reduce(msg) + torch.cuda.synchronize() + toc = time.perf_counter() + + span = (toc - tic) * 1000 / 100 # in ms + bandwidth = size / span / 1e6 # in GB/s + print_each_rank( + 'NCCL Allreduce | Msg Size: {:.0f} MB | Algo Bandwidth: {:.2f} GB/s'.format( + size / 1024 / 1024, bandwidth), + select=(local_rank==0), + ) + +def test_allgather(size, local_rank): + msg = torch.ones((size,)).cuda() + tensor_list = [torch.empty_like(msg) for _ in range(torch.distributed.get_world_size())] + + tic = time.perf_counter() + for _ in range(100): + out = torch.distributed.all_gather(tensor_list, msg) + torch.cuda.synchronize() + print_each_rank('Passed all-gather') + toc = time.perf_counter() + + +def benchmark(args): + size = args.begin + while size <= args.end: + # test_allgather(size * 1024 * 1024, args.local_rank) + test_nccl(size * 1024 * 1024, args.local_rank) # MB to B + size *= 2 + print_each_rank('test on nccl is done') + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--begin', type=int, default=4, + help='start message size in MB') + parser.add_argument('--end', type=int, default=64, + help='end message size in MB') + args = parser.parse_args() + + torch.distributed.init_process_group(backend='nccl') + print(f'{torch.distributed.get_rank()} launches') + + args.local_rank = int(os.environ.get('LOCAL_RANK')) + torch.cuda.set_device(args.local_rank) + benchmark(args) From c5cc7f49e0323db58b0ee02bb418d38c3822cfb2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 02:25:06 +0000 Subject: [PATCH 0699/1892] remove dummy --- handcraft/pipeline/dummy.py | 305 ------------------- handcraft/pipeline/dummy_hybrid.py | 283 ------------------ handcraft/pipeline/run.sh | 28 -- handcraft/pipeline/schedule.py | 451 ----------------------------- 4 files changed, 1067 deletions(-) delete mode 100644 handcraft/pipeline/dummy.py delete mode 100644 handcraft/pipeline/dummy_hybrid.py delete mode 100755 handcraft/pipeline/run.sh delete mode 100644 handcraft/pipeline/schedule.py diff --git a/handcraft/pipeline/dummy.py b/handcraft/pipeline/dummy.py deleted file mode 100644 index 41db0151..00000000 --- a/handcraft/pipeline/dummy.py +++ /dev/null @@ -1,305 +0,0 @@ -""" -Dummy model - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive - -OMP_NUM_THREADS=4 python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive -""" -import torch -import torch.nn.functional as F -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup -from cube.runtime.syndata import SynDataLoader, SynTextDataLoader - -from handcraft.pipeline.schedule import schedule_tp_1f1b, schedule_naive, schedule_tp_1f1b_pack -import argparse - -""" -Stage0: - Embedding [M, 1], [N, E] -> [M, E] - Linear [M, E], [E, E] -> [M, E] - -Stage Else: - Linear [M, E], [E, E] -> [M, E] - -Condition: N > 8M - E -""" - -io_input = input - -class ReduceEmbed(torch.autograd.Function): - - @staticmethod - def forward(ctx, input): - torch.distributed.all_reduce(input) - return input - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class IdentityFoward(torch.autograd.Function): - - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad_output): - torch.distributed.all_reduce(grad_output) - return grad_output - - -class DummyModelEmbed(torch.nn.Module): - - def __init__(self, M: int, N: int, E: int, sharding=False): - super().__init__() - self.M = M - self.N = N - self.E = E - self.sharding = sharding - chunk_num = DeviceGroup().world_size if sharding else 1 - stage_id = DeviceGroup().rank if sharding else 1 - self.vocab_start_index = N // chunk_num * stage_id - self.vocab_end_index = N // chunk_num * (stage_id + 1) - self.embed_weight = torch.nn.Parameter(torch.ones((N // chunk_num, E))) - - def input_shape(self): - return (self.M, ) - - def input_dtype(self): - return torch.int64 - - def output_shape(self): - return (self.M, self.E) - - def output_dtype(self): - return torch.float32 - - def forward(self, input: torch.Tensor): - if self.sharding: - mask = (input < self.vocab_start_index) | \ - (input >= self.vocab_end_index) - input = input.clone() - self.vocab_start_index - input[mask] = 0 - input = F.embedding(input, self.embed_weight) - input[mask, :] = 0.0 - input = ReduceEmbed.apply(input) - else: - input = F.embedding(input, self.embed_weight) - return input - - -class DummyModel(torch.nn.Module): - - def __init__(self, M: int, N: int, E: int, stage_id: int, - sharding=False, embed: torch.nn.Module = None): - - super().__init__() - self.M = M - self.N = N - self.E = E - self.is_last_stage = stage_id == DeviceGroup().world_size - 1 - self.sharding = sharding - # mebed module - self.embed = embed - # first stage - chunk_num = torch.distributed.get_world_size() if sharding else 1 - self.fc_weight = torch.nn.Parameter(torch.ones((E // chunk_num, E)) / 10000) - - def input_shape(self): - if self.embed: - return self.embed.input_shape() - else: - return (self.M, self.E) - - def output_shape(self): - if self.is_last_stage: - return (1,) - else: - return (self.M, self.E) - - def input_dtype(self): - if self.embed: - return self.embed.input_dtype() - else: - return torch.float32 - - def output_dtype(self): - return torch.float32 - - def forward(self, input: torch.Tensor): - # print(f'[{DeviceGroup().rank}] input: {input}, shape={input.size()}') - if self.embed: - input = self.embed(input) - - if self.sharding: - input = IdentityFoward.apply(input) - output = F.linear(input, self.fc_weight) - - if self.is_last_stage: - output = torch.sum(output) - # print(f'[{DeviceGroup().rank}] output: {output}, shape={output.size()}') - return output - - -class DummyModelTP(torch.nn.Module): - - def __init__(self, M: int, N: int, E: int, stage_id: int): - super().__init__() - self.M = M - self.N = N - self.E = E - self.stages = DeviceGroup().world_size - - self.vocab_start_index = N // self.stages * stage_id - self.vocab_end_index = N // self.stages * (stage_id + 1) - self.embed = DummyModelEmbed(M, N, E, sharding=True) - self.fc_weights = torch.nn.ParameterList() - for idx in range(self.stages): - if idx % 2 == 0: - self.fc_weights.append( - torch.nn.Parameter(torch.ones((E // self.stages, E)) / 10000) - ) - else: - self.fc_weights.append( - torch.nn.Parameter(torch.ones((E, E // self.stages)) / 10000) - ) - - def forward(self, input: torch.Tensor): - x = self.embed(input) - # print(f'embed: {x}') - for idx in range(self.stages): - # column partition - if idx % 2 == 0: - x = IdentityFoward.apply(x) - x = F.linear(x, self.fc_weights[idx]) - else: - x = F.linear(x, self.fc_weights[idx]) - x = ReduceEmbed.apply(x) - # print(f'linear: {x}') - # reduce - if self.stages % 2 != 0: - raise RuntimeError("number of stages only supported to be mod 2 == 0") - loss = torch.sum(x) - # print(loss) - # if rank == 0: - # io_input(f'{step}>>>') - # torch.distributed.barrier() - return loss - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--use-naive', action='store_true', - help='use naive pipeline') - parser.add_argument('--use-tp1f1b', action='store_true', - help='use tensor parallel 1f1b') - parser.add_argument('--use-tp1f1b-pack', action='store_true', - help='use tensor parallel 1f1b') - parser.add_argument('--use-tp', action='store_true', - help='use pure tensor parallelism') - parser.add_argument('--nmb', type=int, default=4, - help='num of micro batch') - parser.add_argument('--M', type=int, default=4096, - help='M dimension length = sequence length') - parser.add_argument('--N', type=int, default=50257, - help='word number') - parser.add_argument('--E', type=int, default=2048, - help='E dimension length = hidden dimension length') - args = parser.parse_args() - - print(args) - - cube.init() - rank = DeviceGroup().rank - - # tp 1f1b - if args.use_tp1f1b: - embed = DummyModelEmbed(args.M, args.N, args.E, sharding=True).cuda() - first_stage_model = DummyModel(args.M, args.N, args.E, 0, sharding=True, embed=embed).cuda() - if rank == 0: - model = None - else: - model = DummyModel(args.M, args.N, args.E, rank, sharding=False).cuda() - # optimizer - if rank == 0: - parameters = first_stage_model.parameters() - else: - parameters = list(first_stage_model.parameters()) + list(model.parameters()) - optimizer = torch.optim.Adam(parameters) - - if args.use_naive: - # naive pipleline - embed = None - if rank == 0: - embed = DummyModelEmbed(args.M, args.N, args.E, sharding=False).cuda() - model = DummyModel(args.M, args.N, args.E, rank, sharding=False, embed=embed).cuda() - # optimizer - optimizer = torch.optim.Adam(model.parameters()) - - - if args.use_tp1f1b_pack: - embed = DummyModelEmbed(args.M, args.N, args.E, sharding=True).cuda() - model = DummyModel(args.M, args.N, args.E, rank, sharding=False).cuda() - optimizer = torch.optim.Adam(list(embed.parameters()) + list(model.parameters())) - - if args.use_tp: - model = DummyModelTP(args.M, args.N, args.E, rank).cuda() - optimizer = torch.optim.Adam(model.parameters()) - - # 0.11GB - print_each_rank('model consumption') - memory_summary() - - dataloader = SynTextDataLoader( - shapes=([args.M],), - dtypes=(torch.int64, ), - batch_dims=(0,), - length=128000 - ) - - # 0.11GB - print_each_rank('model + dataloader consumption') - memory_summary() - - iter_num = 64 - CudaTimer(enable=False).warmup() - for step in range(iter_num): - if step >= 20: - CudaTimer(enable=True).start('e2e') - - if args.use_tp1f1b: - schedule_tp_1f1b(model, first_stage_model, dataloader, args.nmb, DeviceGroup().world_size) - if args.use_naive: - schedule_naive(model, dataloader, args.nmb) - if args.use_tp1f1b_pack: - schedule_tp_1f1b_pack(model, embed, dataloader, args.nmb, DeviceGroup().world_size) - if args.use_tp: - for _ in range(args.nmb): - data = next(dataloader) - loss = model(data) - loss.backward() - - optimizer.step() - optimizer.zero_grad() - - if step >= 20: - CudaTimer().stop('e2e') - if (step+1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-20, field_name='e2e'))) - memory_summary() diff --git a/handcraft/pipeline/dummy_hybrid.py b/handcraft/pipeline/dummy_hybrid.py deleted file mode 100644 index e7ff9282..00000000 --- a/handcraft/pipeline/dummy_hybrid.py +++ /dev/null @@ -1,283 +0,0 @@ -""" -Dummy model - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/pipeline/dummy_hybrid.py --tp-size 2 --pp-size 2 - -OMP_NUM_THREADS=4 python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/pipeline/dummy_hybrid.py --use-naive -""" -import torch -import torch.nn.functional as F -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup -from cube.runtime.syndata import SynDataLoader, SynTextDataLoader - -from handcraft.pipeline.schedule import schedule_tp_1f1b, schedule_naive, schedule_tp_1f1b_pack -import argparse - -""" -Stage0: - Embedding [M, 1], [N, E] -> [M, E] - Linear [M, E], [E, E] -> [M, E] - -Stage Else: - Linear [M, E], [E, E] -> [M, E] - -Condition: N > 8M - E -""" - -_tp_group = None -_pp_group = None -_pp_next_rank = None -_pp_prev_rank = None - -io_input = input - -class ReduceEmbed(torch.autograd.Function): - - @staticmethod - def forward(ctx, input): - global _tp_group - torch.distributed.all_reduce(input, group=_tp_group) - return input - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class IdentityFoward(torch.autograd.Function): - - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad_output): - global _tp_group - torch.distributed.all_reduce(grad_output, group=_tp_group) - return grad_output - - -class DummyModelEmbed(torch.nn.Module): - - def __init__(self, M: int, N: int, E: int, tp_group = None, pp_group = None): - super().__init__() - self.M = M - self.N = N - self.E = E - - self.tp_group = tp_group - self.tp_size = torch.distributed.get_world_size(tp_group) - self.pp_group = pp_group - - shard_id = torch.distributed.get_rank(tp_group) - self.vocab_start_index = N // self.tp_size * shard_id - self.vocab_end_index = N // self.tp_size * (shard_id + 1) - self.embed_weight = torch.nn.Parameter(torch.ones((N // self.tp_size, E))) - - def input_shape(self): - return (self.M, ) - - def input_dtype(self): - return torch.int64 - - def output_shape(self): - return (self.M, self.E) - - def output_dtype(self): - return torch.float32 - - def forward(self, input: torch.Tensor): - if self.tp_size > 1: - mask = (input < self.vocab_start_index) | \ - (input >= self.vocab_end_index) - input = input.clone() - self.vocab_start_index - input[mask] = 0 - input = F.embedding(input, self.embed_weight) - input[mask, :] = 0.0 - input = ReduceEmbed.apply(input) - else: - input = F.embedding(input, self.embed_weight) - return input - - -class DummyModel(torch.nn.Module): - - def __init__(self, M: int, N: int, E: int): - - super().__init__() - self.M = M - self.N = N - self.E = E - - # group - global _tp_group - self.tp_group = _tp_group - global _pp_group - self.pp_group = _pp_group - - self.pp_stage = torch.distributed.get_rank(_pp_group) - self.is_first_pp_stage = self.pp_stage == 0 - self.is_last_stage = self.pp_stage == torch.distributed.get_world_size(_pp_group) - 1 - - self.tp_size = torch.distributed.get_world_size(_tp_group) - - # mebed module - if self.is_first_pp_stage: - self.embed = DummyModelEmbed(M, N, E, self.tp_group, self.pp_group) - else: - self.embed = None - - total_fc_num = torch.distributed.get_world_size() - fc_weights = list() - input_shapes = list() - output_shapes = list() - shard_types = list() - for idx in range(total_fc_num): - if idx % 2 == 0: - fc_weights.append( - torch.nn.Parameter(torch.ones((E // self.tp_size, E)) / 10000) - ) - input_shapes.append((M, E)) - output_shapes.append((M, E // self.tp_size)) - shard_types.append('col') - else: - fc_weights.append( - torch.nn.Parameter(torch.ones((E, E // self.tp_size)) / 10000) - ) - input_shapes.append((M, E // self.tp_size)) - output_shapes.append((M, E)) - shard_types.append('row') - - self.fc_num = total_fc_num // torch.distributed.get_world_size(_pp_group) - self.fc_weights = torch.nn.ParameterList( - fc_weights[self.fc_num * self.pp_stage : self.fc_num * (self.pp_stage + 1)] - ) - self.ins = input_shapes[self.fc_num * self.pp_stage : self.fc_num * (self.pp_stage + 1)] - self.ous = output_shapes[self.fc_num * self.pp_stage : self.fc_num * (self.pp_stage + 1)] - self.shard_types = shard_types[self.fc_num * self.pp_stage : self.fc_num * (self.pp_stage + 1)] - print_each_rank(f'initializing with {self.fc_num} fcs: {self.shard_types}') - - - def input_shape(self): - if self.embed: - return self.embed.input_shape() - else: - return self.ins[0] - - def output_shape(self): - if self.is_last_stage: - return (1,) - else: - return self.ous[-1] - - def input_dtype(self): - if self.embed: - return self.embed.input_dtype() - else: - return torch.float32 - - def output_dtype(self): - return torch.float32 - - def forward(self, input: torch.Tensor): - # print(f'[{DeviceGroup().rank}] input: {input}, shape={input.size()}') - if self.embed: - x = self.embed(input) - else: - x = input - - for stype, weight in zip(self.shard_types, self.fc_weights): - # column partition - if stype == 'col': - x = IdentityFoward.apply(x) - x = F.linear(x, weight) - elif stype == 'row': - x = F.linear(x, weight) - x = ReduceEmbed.apply(x) - else: - assert False - - if self.is_last_stage: - x = torch.sum(x) - # print(x) - # print(f'[{DeviceGroup().rank}] output: {output}, shape={output.size()}') - return x - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--tp-size', type=int, - help='tensor parallelism size') - parser.add_argument('--pp-size', type=int, - help='pipeline parallelism size') - parser.add_argument('--nmb', type=int, default=4, - help='num of micro batch') - parser.add_argument('--M', type=int, default=4096, - help='M dimension length = sequence length') - parser.add_argument('--N', type=int, default=50257, - help='word number') - parser.add_argument('--E', type=int, default=2048, - help='E dimension length = hidden dimension length') - args = parser.parse_args() - - print(args) - - cube.init() - rank = DeviceGroup().rank - pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) - - if _pp_group is None: - _pp_group = DeviceGroup().get_group(pp_ranks) - idx = pp_ranks.index(DeviceGroup().rank) - _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] - _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] - if _tp_group is None: - _tp_group = DeviceGroup().get_group(tp_ranks) - - - model = DummyModel(args.M, args.N, args.E).cuda() - optimizer = torch.optim.Adam(model.parameters()) - - # 0.11GB - print_each_rank('model consumption') - memory_summary() - - dataloader = SynTextDataLoader( - shapes=([args.M],), - dtypes=(torch.int64, ), - batch_dims=(0,), - length=128000 - ) - - # 0.11GB - print_each_rank('model + dataloader consumption') - memory_summary() - - iter_num = 64 - CudaTimer(enable=False).warmup() - for step in range(iter_num): - if step >= 20: - CudaTimer(enable=True).start('e2e') - - schedule_naive(model, dataloader, args.nmb, [_pp_prev_rank, _pp_next_rank]) - optimizer.step() - optimizer.zero_grad() - - if step >= 20: - CudaTimer().stop('e2e') - if (step+1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-20, field_name='e2e'))) - memory_summary() diff --git a/handcraft/pipeline/run.sh b/handcraft/pipeline/run.sh deleted file mode 100755 index 080b7306..00000000 --- a/handcraft/pipeline/run.sh +++ /dev/null @@ -1,28 +0,0 @@ -# 4 gpus - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp --nmb 64 > 4dev64nmb-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp1f1b --nmb 64 > 4dev64nmb-tp1f1b.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp1f1b-pack --nmb 64 > 4dev64nmb-tp1f1b-pack.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive --nmb 64 > 4dev64nmb-naive.txt - - -# 8 gpus - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp --nmb 128 > 8dev128nmb-tp.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp1f1b --nmb 128 > 8dev128nmb-tp1f1b.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-tp1f1b-pack --nmb 128 > 8dev128nmb-tp1f1b-pack.txt - -OMP_NUM_THREADS=4 torchrun --nproc_per_node=8 --nnodes=1 \ - handcraft/pipeline/dummy.py --use-naive --nmb 128 > 8dev128nmb-naive.txt \ No newline at end of file diff --git a/handcraft/pipeline/schedule.py b/handcraft/pipeline/schedule.py deleted file mode 100644 index 8db0c28d..00000000 --- a/handcraft/pipeline/schedule.py +++ /dev/null @@ -1,451 +0,0 @@ -from typing import List -import torch - -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary -import cube.runtime.adapter.collectives as coll -from cube.runtime.device import DeviceGroup - -from torch.distributed.distributed_c10d import _get_global_rank -io_input = input - - -def get_global_rank(group, group_rank): - if group is None: - return group_rank - else: - return _get_global_rank(group, group_rank) - - -def forward_step(model, *args, **kwargs): - """ - Forward pass - """ - CudaTimer().start("forward") - output = model(*args, **kwargs) - CudaTimer().stop("forward") - return output - - -def backward_step(input_tensors: List[torch.Tensor], - output_tensors: List[torch.Tensor], - output_tensor_grads: List[torch.Tensor]): - """ - Backward pass - """ - for tensor in input_tensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - tensor.retain_grad() - CudaTimer().start("backward") - torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) - CudaTimer().stop("backward") - input_tensor_grads = [] - for tensor in input_tensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - input_tensor_grads.append(tensor.grad) - else: - input_tensor_grads.append(None) - return input_tensor_grads - - -def is_first_stage(): - return DeviceGroup().rank == 0 - - -def is_last_stage(): - return DeviceGroup().rank == DeviceGroup().world_size - 1 - - -def recv_input(model, dataloader, prev_rank: int): - if is_first_stage(): - return next(dataloader) - else: - return coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - - -def schedule_naive(model, dataloader, num_microbatch: int, neighbors = None): - rank = DeviceGroup().rank - if neighbors is None: - prev_rank = (rank - 1) % DeviceGroup().world_size - next_rank = (rank + 1) % DeviceGroup().world_size - else: - prev_rank, next_rank = neighbors - - is_first_stage = rank < prev_rank - is_last_stage = rank > next_rank - - for step in range(num_microbatch): - # recv forward - if is_first_stage: - input = next(dataloader) - else: - # print(f'rank {rank} recving forward input...') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - # forward - output = forward_step(model, input) - # send forward - if not is_last_stage: - # print(f'rank {rank} sending forward output...') - coll.send(output, next_rank) - # recv backward - output_grad = None - if not is_last_stage: - # print(f'rank {rank} recving backward input...') - output_grad = coll.recv(output.size(), next_rank, output.dtype) - # backward - input_grad = backward_step([input], [output], [output_grad])[0] - # send backward - if not is_first_stage: - # print(f'rank {rank} sending backward output...') - coll.send(input_grad, prev_rank) - - # memory_summary() - # if rank == 0: - # io_input(f'{step}>>>') - # torch.distributed.barrier() - - -def schedule_tp_1f1b(model: torch.nn.Module, - first_stage_model: torch.nn.Module, - dataloader, - num_microbatch: int, - num_stage: int): - rank = DeviceGroup().rank - next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size - prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size - - input_tensors = list() - output_tensors = list() - - input_1st_tensors = list() - output_1st_tensors = list() - - gather_list = list(range(num_stage)) - gather_list[0], gather_list[1] = gather_list[1], gather_list[0] - - def tp_forward(fmodel, dataloader) -> torch.Tensor: - input = next(dataloader) - output = forward_step(fmodel, input) - input_1st_tensors.append(input) - output_1st_tensors.append(output) - # gather - outputs = coll.gather([output], None, None, gather_list) - if rank == 1: - with torch.no_grad(): - outputs[0], outputs[1] = outputs[1], outputs[0] - output = torch.cat(tuple(outputs), dim=-1) - output = output.requires_grad_() - else: - output = None - return output - - def tp_backward(grad: torch.Tensor): - if rank == 1: - with torch.no_grad(): - grads = list(grad.chunk(num_stage, dim=-1)) - grads[0], grads[1] = grads[1], grads[0] - else: - grads = None - input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) - grad_1st = coll.scatter(grads, [output_1st.size()], [output_1st.dtype], gather_list) - backward_step([input_1st], [output_1st], [grad_1st])[0] - - fofst = [0] + [-(step // 2) for step in range(num_stage-1)] - bofst = [0] + [-(num_stage - 2 - (step // 2)) for step in range(num_stage-1)] - # print(fofst) - # print(bofst) - fofst = fofst[rank] - bofst = bofst[rank] - last_backward = None - last_forward = None - for step in range(num_microbatch + num_stage - 2): - torch.distributed.barrier() - # print_each_rank(f'=========begin rank {rank}=========') - fmid, bmid = step + fofst, step + bofst - do_backward = 0 <= bmid and bmid <= num_microbatch - 1 - do_forward = 0 <= fmid and fmid <= num_microbatch - 1 - - # step1: tp forward - if 0 <= step and step <= num_microbatch - 1: - # print(f'rank {rank} forward tp model ') - output_1st = tp_forward(first_stage_model, dataloader) - - # step2: backward + forward - if rank == 0: - pass - - if rank % 2 == 0 and rank != 0: - # inter-barrier - if do_backward and last_forward is not None: - # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [last_forward], [model.output_shape()], [model.output_dtype()], - [next_rank], [next_rank] - )[0] - elif do_backward: - # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) - elif last_forward is not None: - # print(f'rank {rank} send forward output ') - coll.send(last_forward, next_rank) - - # backward - if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) - # backward - input_grad = backward_step([input], [output], [output_grad])[0] - - # intra-barrier - if do_backward and do_forward: - # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] - elif do_backward: - # print(f'rank {rank} send backward grad ') - coll.send(input_grad, prev_rank) - elif do_forward: - # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - - # forward - last_forward = None - if do_forward: - # forward step - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) - last_forward = output - - if rank % 2 == 1: - # inter-barrier - if rank == 1: - input = output_1st - else: - if do_forward and last_backward is not None: - # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] - elif do_forward: - # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - elif last_backward is not None: - # print(f'rank {rank} send backward grad ') - coll.send(last_backward, prev_rank) - - # forward - if do_forward: - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) - - # intra-barrier send recv - output_grad = None - if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): - # send forward recv backward - # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [output], [output.size()], [output.dtype], - [next_rank], [next_rank] - )[0] - elif do_forward and not is_last_stage(): - # print(f'rank {rank} send forward output ') - coll.send(output, next_rank) - elif do_backward and not is_last_stage(): - # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) - - # backward + forward - last_backward = None - if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) - input_grad = backward_step([input], [output], [output_grad])[0] - last_backward = input_grad - - # step3: tp backward - if 0 <= (step-num_stage+2) and (step-num_stage+2) <= num_microbatch - 1: - # print(f'rank {rank} backward tp model ') - tp_backward(last_backward) - - # if rank == 0: - # io_input(f'{step}>>>') - # torch.distributed.barrier() - - assert len(input_tensors) == 0 - assert len(output_tensors) == 0 - assert len(input_1st_tensors) == 0 - assert len(output_1st_tensors) == 0 - - # print_each_rank(f'=========end rank {rank}=========') - - -def schedule_tp_1f1b_pack(model: torch.nn.Module, - first_stage_model: torch.nn.Module, - dataloader, - num_microbatch: int, - num_stage: int): - rank = DeviceGroup().rank - next_rank = (DeviceGroup().rank + 1) % DeviceGroup().world_size - prev_rank = (DeviceGroup().rank - 1) % DeviceGroup().world_size - - input_tensors = list() - output_tensors = list() - - input_1st_tensors = list() - output_1st_tensors = list() - - def tp_forward(fmodel, dataloader) -> torch.Tensor: - input = next(dataloader) - #TODO: gather - output = forward_step(fmodel, input) - input_1st_tensors.append(input) - output_1st_tensors.append(output) - output = output.detach().requires_grad_() - return output - - def tp_backward(grad: torch.Tensor): - input_1st, output_1st = input_1st_tensors.pop(0), output_1st_tensors.pop(0) - if rank != 0: - grad = torch.empty_like(output_1st) - torch.distributed.broadcast(grad, src=0) - backward_step([input_1st], [output_1st], [grad])[0] - - fofst = [-(step // 2) for step in range(num_stage)] - bofst = [-(num_stage - 1 - (step // 2)) for step in range(num_stage)] - # print(fofst) - # print(bofst) - fofst = fofst[rank] - bofst = bofst[rank] - last_backward = None - last_forward = None - for step in range(num_microbatch + num_stage - 1): - torch.distributed.barrier() - # print_each_rank(f'=========begin rank {rank}=========') - fmid, bmid = step + fofst, step + bofst - do_backward = 0 <= bmid and bmid <= num_microbatch - 1 - do_forward = 0 <= fmid and fmid <= num_microbatch - 1 - - # step1: tp forward - if 0 <= step and step <= num_microbatch - 1: - # print(f'rank {rank} forward tp model ') - output_1st = tp_forward(first_stage_model, dataloader) - - # forward + backward - if rank % 2 == 0: - # inter-barrier - if rank == 0: - input = output_1st - else: - if do_forward and last_backward is not None: - # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] - elif do_forward: - # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - elif last_backward is not None: - # print(f'rank {rank} send backward grad ') - coll.send(last_backward, prev_rank) - - # forward - if do_forward: - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) - - # mem = torch.cuda.max_memory_allocated() - # print(f'rank {rank}: {mem / 1024 / 1024 / 1024} GB forward') - - # intra-barrier send recv - output_grad = None - if (do_forward and not is_last_stage()) and (do_backward and not is_last_stage()): - # send forward recv backward - # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [output], [output.size()], [output.dtype], - [next_rank], [next_rank] - )[0] - elif do_forward and not is_last_stage(): - # print(f'rank {rank} send forward output ') - coll.send(output, next_rank) - elif do_backward and not is_last_stage(): - # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) - - # backward - last_backward = None - if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) - input_grad = backward_step([input], [output], [output_grad])[0] - last_backward = input_grad - - # backward + forward - if rank % 2 == 1: - # inter-barrier - if is_last_stage(): - output_grad = None - else: - if do_backward and last_forward is not None: - # print(f'rank {rank} recv backward grad + send forward output ') - output_grad = coll.sendrecv( - [last_forward], [model.output_shape()], [model.output_dtype()], - [next_rank], [next_rank] - )[0] - elif do_backward: - # print(f'rank {rank} recv backward grad ') - output_grad = coll.recv(model.output_shape(), next_rank, model.output_dtype()) - elif last_forward is not None: - # print(f'rank {rank} send forward output ') - coll.send(last_forward, next_rank) - - # backward - if do_backward: - input, output = input_tensors.pop(0), output_tensors.pop(0) - # backward - input_grad = backward_step([input], [output], [output_grad])[0] - - # intra-barrier - if do_backward and do_forward: - # print(f'rank {rank} send backward grad + recv forward output ') - input = coll.sendrecv( - [input_grad], [model.input_shape()], [model.input_dtype()], - [prev_rank], [prev_rank] - )[0] - elif do_backward: - # print(f'rank {rank} send backward grad ') - coll.send(input_grad, prev_rank) - elif do_forward: - # print(f'rank {rank} recv forward output ') - input = coll.recv(model.input_shape(), prev_rank, model.input_dtype()) - - # forward - last_forward = None - if do_forward: - # forward step - output = forward_step(model, input) - input_tensors.append(input) - output_tensors.append(output) - last_forward = output - - # step3: tp backward - if 0 <= (step-num_stage+1) and (step-num_stage+1) <= num_microbatch - 1: - # print(f'rank {rank} backward tp model ') - tp_backward(last_backward) - - # memory_summary() - # if rank == 0: - # io_input(f'{step}>>>') - # torch.distributed.barrier() - # print_each_rank(f'=========end rank {rank}: {step}=========') - - assert len(input_tensors) == 0 - assert len(output_tensors) == 0 - assert len(input_1st_tensors) == 0 - assert len(output_1st_tensors) == 0 - - # print_each_rank(f'=========end rank {rank}=========') \ No newline at end of file From 31ab3a7653c5f9501fc199423f5e4c99a5719910 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 10:30:54 +0800 Subject: [PATCH 0700/1892] playground for micro-benchmark test --- handcraft/playground/test.sh | 144 ++++++++++++ handcraft/playground/transformers.py | 326 +++++++++++++++++++++++++++ 2 files changed, 470 insertions(+) create mode 100755 handcraft/playground/test.sh create mode 100644 handcraft/playground/transformers.py diff --git a/handcraft/playground/test.sh b/handcraft/playground/test.sh new file mode 100755 index 00000000..43a95b88 --- /dev/null +++ b/handcraft/playground/test.sh @@ -0,0 +1,144 @@ +datadir=eval/sharding +mkdir -p ${datadir} + +# hidden=768 +# heads=12 +# +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=1 \ +# --nnodes=1 \ +# handcraft/playground/transformers.py \ +# --hidden-size ${hidden} --heads ${heads} \ +# --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt +# +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=1 \ +# --nnodes=1 \ +# handcraft/playground/transformers.py \ +# --hidden-size ${hidden} --heads ${heads} \ +# --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt + + +hidden=1024 +heads=16 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt + + +hidden=1536 +heads=16 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt + +hidden=2304 +heads=24 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt + + +hidden=2560 +heads=32 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt + + + +hidden=4096 +heads=32 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt + +hidden=5120 +heads=40 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt + + +hidden=12288 +heads=96 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/playground/transformers.py \ + --hidden-size ${hidden} --heads ${heads} \ + --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt \ No newline at end of file diff --git a/handcraft/playground/transformers.py b/handcraft/playground/transformers.py new file mode 100644 index 00000000..9b46ad50 --- /dev/null +++ b/handcraft/playground/transformers.py @@ -0,0 +1,326 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/dummy/transformers.py +""" +import torch +from torch.utils import checkpoint + +import cube +from cube.profiler.memory import memory_summary, model_summary +from cube.profiler.timer import CudaTimer, print_each_rank +import argparse + + +parser = argparse.ArgumentParser(description='transformer') +# model arch +parser.add_argument('--layers', type=int, default=4, + help='number encoder/decoder of layers') +parser.add_argument('--hidden-size', type=int, default=1024, + help='hidden size') +parser.add_argument('--heads', type=int, default=16, + help='number of heads') +parser.add_argument('--bs', type=int, default=8, + help='number of heads') +# parallelism +parser.add_argument('--seq', type=int, default=1, + help='sharding sequential execution') +args = parser.parse_args() +print(args) + + +cube.init() + + +class Config: + + layers = args.layers + embed_dim = args.hidden_size + num_heads = args.heads + ffn_dim = embed_dim * 4 + + seqlen = 1024 + + +def self_attention(query: torch.Tensor, + q_proj: torch.Tensor, q_bias: torch.Tensor, + k_proj: torch.Tensor, k_bias: torch.Tensor, + v_proj: torch.Tensor, v_bias: torch.Tensor, + out_proj: torch.Tensor, out_bias: torch.Tensor, + h: int, scale: float, dropout_p: float, mask=True): + num_head = h + L, N = query.size(0), query.size(1) + dim_head = q_proj.size(0) // num_head + + q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + if mask: # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + return output + + +class MultiHeadSelfAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): + super().__init__() + self.inner_dim = inner_dim + self.num_heads = num_heads + self.head_dim = inner_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # Q + self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + # K + self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + # V + self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + + def forward(self, query): + # x = self_attention( + # query, + # self.q_proj, self.q_bias, + # self.k_proj, self.k_bias, + # self.v_proj, self.v_bias, + # self.out_proj, self.out_bias, + # self.num_heads, self.scaling, self.dropout_p, mask=True + # ) + x = checkpoint.checkpoint( + self_attention, + query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p, True + ) + return x + + +class SeqMHSA(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, + inner_dim: int, dropout: float = 0.0, + bias=True, shard_num=4): + super().__init__() + assert num_heads % shard_num == 0 + print_each_rank(f'using sequence MHSA: sharding size: {shard_num}') + self.layers = torch.nn.ModuleList( + MultiHeadSelfAttention( + embed_dim, + num_heads // shard_num, + inner_dim // shard_num, + dropout, + bias + ) for _ in range(shard_num) + ) + + def forward(self, x): + out_sum = None + for layer in self.layers: + out = layer(x) + out_sum = out if out_sum is None else out_sum + out + return out_sum + + +def feedforward(x: torch.Tensor, + proj1: torch.Tensor, proj1_bias: torch.Tensor, + proj2: torch.Tensor, proj2_bias: torch.Tensor, + dropout: float) -> torch.Tensor: + x = torch.nn.functional.linear(x, proj1, proj1_bias) + x = torch.nn.functional.gelu(x) + x = torch.nn.functional.dropout(x, dropout, True, False) + x = torch.nn.functional.linear(x, proj2, proj2_bias) + return x + + +class MLP(torch.nn.Module): + + def __init__(self, embed_dim, hidden_dim, dropout: float, bias=True): + super().__init__() + self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) + self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) + self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) + self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) + self.dropout = dropout + + def forward(self, x: torch.Tensor): + x = checkpoint.checkpoint( + feedforward, + x, self.proj1, self.proj1_bias, self.proj2, self.proj2_bias, self.dropout + ) + # x = feedforward(x, + # self.proj1, self.proj1_bias, + # self.proj2, self.proj2_bias, + # self.dropout) + return x + + +class SeqMLP(torch.nn.Module): + + def __init__(self, embed_dim, hidden_dim, dropout: float, + bias=True, shard_num = 4): + super().__init__() + print_each_rank(f'using sequence MLP: sharding size: {shard_num}') + assert hidden_dim % shard_num == 0 + self.layers = torch.nn.ModuleList( + [MLP(embed_dim, hidden_dim // shard_num, dropout, bias) for _ in range(shard_num)] + ) + + def forward(self, x: torch.Tensor): + out_sum = None + for layer in self.layers: + out = layer(x) + out_sum = out if out_sum is None else out_sum + out + return out_sum + + +class TransformerLayer(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, attn_inner_dim: int, ffn_embed_dim: int, + dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): + super().__init__() + + if args.seq > 1: + self.self_attn = SeqMHSA(embed_dim, num_heads, attn_inner_dim, atten_dropout, shard_num=args.seq) + else: + self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_inner_dim, atten_dropout) + + self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) + self.dropout = torch.nn.Dropout(p=dropout) + + if args.seq > 1: + self.mlp = SeqMLP(embed_dim, ffn_embed_dim, activation_dropout, shard_num=args.seq) + else: + self.mlp = MLP(embed_dim, ffn_embed_dim, activation_dropout) + + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x) + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + x = self.mlp(x) + x = self.dropout(x) + x = x + residual + return x + + +class Model(torch.nn.Module): + + def __init__(self): + super().__init__() + self.cfg = Config() + self.layers = torch.nn.ModuleList( + [TransformerLayer( + self.cfg.embed_dim, + self.cfg.num_heads, + self.cfg.embed_dim, + self.cfg.ffn_dim + ) for _ in range(self.cfg.layers)] + ) + + def forward(self, x: torch.Tensor): # L N E + + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) + return loss + + +class DataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + + self.bs = batch_size + self.cfg = Config() + super().__init__( + shapes=([self.cfg.seqlen, batch_size, self.cfg.embed_dim],), + dtypes=(torch.float,), + batch_dims=(1,) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + inputs = torch.randn( + *(self.cfg.seqlen, self.bs, self.cfg.embed_dim), + dtype=torch.float, + device=torch.cuda.current_device(), + requires_grad=True + ) + return (inputs,) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + + +if __name__ == '__main__': + + dataloader = DataLoader(batch_size=args.bs) + model = Model().cuda() + model.train() + # optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + optimizer = torch.optim.SGD(model.parameters(), lr=3e-05) + + CudaTimer(enable=False).warmup() + torch.distributed.barrier() + iter_num = 10 + for step in range(iter_num): + dataloader = iter(dataloader) + if step >= 4: + CudaTimer(enable=True).start('e2e') + if step == 0: + model_summary(model, next(dataloader)) + loss = model(*next(dataloader)) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if step >= 4: + CudaTimer().stop('e2e') + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-4, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-4) + memory_summary() \ No newline at end of file From e5de86d54d226bf4f96b64d61c8ef731b3c9fa1b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 14:53:43 +0800 Subject: [PATCH 0701/1892] update requirement --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b4306a47..4e316892 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ z3-solver -matplotlib \ No newline at end of file +matplotlib +einops \ No newline at end of file From 3268730c7d5502c73d52c8883231e91427ffef7d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 15:28:42 +0800 Subject: [PATCH 0702/1892] fix codegen bug --- cube/codegen/codegen.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index de35f9e2..c5c11450 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -393,6 +393,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: + fb.insert_body('_ = None') if len(device_nodes) == 0: fb.insert_body('pass') for node in device_nodes: @@ -477,7 +478,7 @@ def tuple_naming(self, tensors: List[Any]) -> str: def return_naming(self, tensors: List[Any]) -> str: tensors = [self.tensor_naming(t) for t in tensors] if len(tensors) == 0: - tensors = '' + tensors = '_' else: tensors = ', '.join(tensors) return tensors From 984cc3f4a2ae1f85f0c3f5feb41ce0b68e5a11d1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 19:38:59 +0800 Subject: [PATCH 0703/1892] add colocate-sharding for swin --- cube/runtime/adapter/distnn.py | 130 ++ handcraft/eval/benchmark_gpt.sh | 56 - handcraft/eval/swin_infer_bs1_224_782Mfp32.sh | 93 -- handcraft/eval/swin_infer_bs1_640_Gfp16.sh | 112 -- handcraft/eval/swin_infer_bs2_224_782Mfp32.sh | 107 -- handcraft/eval/swin_infer_bs2_640_Gfp16.sh | 112 -- handcraft/eval/swin_infer_bs4_640_Gfp16.sh | 112 -- handcraft/eval/swin_scaleup.sh | 30 - handcraft/eval/swin_train_fp16.sh | 257 ---- handcraft/eval/swin_train_fp32.sh | 243 ---- handcraft/module/schedule.py | 324 +++++ handcraft/module/stage.py | 123 ++ handcraft/swin/hybrid_schedule.py | 251 ---- handcraft/swin/layers.py | 326 ----- handcraft/swin/pmodule.py | 65 - handcraft/swin/schedule.py | 234 ---- handcraft/swin/swin_dt.py | 966 --------------- handcraft/swin/swin_dwt.py | 979 --------------- handcraft/swin/swin_dwt_infer.py | 954 --------------- handcraft/swin/swin_flexflow.py | 993 --------------- handcraft/swin/swin_hybrid.py | 1086 ----------------- handcraft/swin/swin_pipe.py | 872 ------------- handcraft/swin/swin_transformer.py | 696 ----------- handcraft/swin/train.py | 873 +++++++++++++ handcraft/swin/utils.py | 106 ++ 25 files changed, 1556 insertions(+), 8544 deletions(-) create mode 100644 cube/runtime/adapter/distnn.py delete mode 100755 handcraft/eval/benchmark_gpt.sh delete mode 100755 handcraft/eval/swin_infer_bs1_224_782Mfp32.sh delete mode 100755 handcraft/eval/swin_infer_bs1_640_Gfp16.sh delete mode 100755 handcraft/eval/swin_infer_bs2_224_782Mfp32.sh delete mode 100755 handcraft/eval/swin_infer_bs2_640_Gfp16.sh delete mode 100755 handcraft/eval/swin_infer_bs4_640_Gfp16.sh delete mode 100644 handcraft/eval/swin_scaleup.sh delete mode 100755 handcraft/eval/swin_train_fp16.sh delete mode 100755 handcraft/eval/swin_train_fp32.sh create mode 100644 handcraft/module/schedule.py create mode 100644 handcraft/module/stage.py delete mode 100644 handcraft/swin/hybrid_schedule.py delete mode 100644 handcraft/swin/layers.py delete mode 100644 handcraft/swin/pmodule.py delete mode 100644 handcraft/swin/schedule.py delete mode 100644 handcraft/swin/swin_dt.py delete mode 100644 handcraft/swin/swin_dwt.py delete mode 100644 handcraft/swin/swin_dwt_infer.py delete mode 100644 handcraft/swin/swin_flexflow.py delete mode 100644 handcraft/swin/swin_hybrid.py delete mode 100644 handcraft/swin/swin_pipe.py delete mode 100644 handcraft/swin/swin_transformer.py create mode 100644 handcraft/swin/train.py create mode 100644 handcraft/swin/utils.py diff --git a/cube/runtime/adapter/distnn.py b/cube/runtime/adapter/distnn.py new file mode 100644 index 00000000..b8f9bf06 --- /dev/null +++ b/cube/runtime/adapter/distnn.py @@ -0,0 +1,130 @@ +import torch +from cube.profiler.timer import CudaTimer + + +class AllReduceIdentity(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, group): + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input + CudaTimer().start(field_name='comm') + torch.distributed.all_reduce(input, group=group) + CudaTimer().stop(field_name='comm') + return input + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class IdentityAllreduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, group): + ctx._group = group + return input + + @staticmethod + def backward(ctx, grad_output): + world_size = torch.distributed.get_world_size(ctx._group) + if world_size == 1: + return grad_output, None + CudaTimer().start(field_name='comm') + torch.distributed.all_reduce(grad_output, group=ctx._group) + CudaTimer().stop(field_name='comm') + return grad_output, None + + +class AllGatherScatter(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, dim, group): + ctx._group = group + ctx._dim = dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input + CudaTimer().start(field_name='comm') + rank = torch.distributed.get_rank(group) + tensor_list = [torch.empty_like(input) for _ in range(world_size)] + tensor_list[rank] = input + torch.distributed.all_gather(tensor_list, input, group=group) + output = torch.cat(tensor_list, dim=dim).contiguous() + CudaTimer().stop(field_name='comm') + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + group = ctx._group + dim = ctx._dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output + CudaTimer().start(field_name='comm') + input_list = grad_output.chunk(world_size, dim=dim) + rank = torch.distributed.get_rank(group) + grad = input_list[rank].contiguous() + CudaTimer().stop(field_name='comm') + return grad, None, None + + +class ReduceBroadcast(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, dst: int, group): + ctx._dst = dst + ctx._group = group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input + CudaTimer().start(field_name='comm') + torch.distributed.reduce(input, dst, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input + + @staticmethod + def backward(ctx, grad_output): + src = ctx._dst + group = ctx._group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output, None, None + CudaTimer().start(field_name='comm') + torch.distributed.broadcast(grad_output, src, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return grad_output, None, None + + +class BroadcastReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, src: int, group=None): + ctx._src = src + ctx._group = group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input + CudaTimer().start(field_name='comm') + torch.distributed.broadcast(input, src, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input + + @staticmethod + def backward(ctx, grad_output): + dst = ctx._src + group = ctx._group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output, None, None + CudaTimer().start(field_name='comm') + if not grad_output.is_contiguous(): + grad_output = grad_output.contiguous() + torch.distributed.reduce(grad_output, dst, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return grad_output, None, None diff --git a/handcraft/eval/benchmark_gpt.sh b/handcraft/eval/benchmark_gpt.sh deleted file mode 100755 index 8d528214..00000000 --- a/handcraft/eval/benchmark_gpt.sh +++ /dev/null @@ -1,56 +0,0 @@ - -echo benchmarking gpt megatron hybrid parallelism... - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - benchmark/megatron/gpt.py > /mydata/MagicCube/expdata/8B.2V100.Megatron.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - benchmark/megatron/gpt.py > /mydata/MagicCube/expdata/8B.4V100.Megatron.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - benchmark/megatron/gpt.py > /mydata/MagicCube/expdata/8B.8V100.Megatron.txt - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/gpt/gpt.py > /mydata/MagicCube/expdata/8B.2V100.Cube.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/gpt/gpt.py > /mydata/MagicCube/expdata/8B.4V100.Cube.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/gpt/gpt.py > /mydata/MagicCube/expdata/8B.8V100.Cube.txt diff --git a/handcraft/eval/swin_infer_bs1_224_782Mfp32.sh b/handcraft/eval/swin_infer_bs1_224_782Mfp32.sh deleted file mode 100755 index fe73f429..00000000 --- a/handcraft/eval/swin_infer_bs1_224_782Mfp32.sh +++ /dev/null @@ -1,93 +0,0 @@ - -logfile=expinfer_224_782M_fp32_bs1 - -mkdir -p ${logfile} - -# ================== Maximal Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 1 \ - --layer1 1 1 1 \ - --layer2 1 1 1 \ - --layer3 1 1 1 \ - > ${logfile}/1gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 2 \ - --layer1 1 1 2 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - > ${logfile}/2gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 4 \ - --layer1 1 1 4 \ - --layer2 1 1 4 \ - --layer3 1 1 4 \ - > ${logfile}/4gpu_tp.txt - - -# ================== Window + Tensor Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 2 1 \ - --layer1 1 2 1 \ - --layer2 1 2 1 \ - --layer3 1 1 2 \ - > ${logfile}/2gpu_2wp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 4 1 \ - --layer1 1 4 1 \ - --layer2 1 4 1 \ - --layer3 1 1 4 \ - > ${logfile}/4gpu_4wp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 8 1 \ - --layer1 1 8 1 \ - --layer2 1 4 2 \ - --layer3 1 1 8 \ - > ${logfile}/8gpu_8wp8tp.txt - diff --git a/handcraft/eval/swin_infer_bs1_640_Gfp16.sh b/handcraft/eval/swin_infer_bs1_640_Gfp16.sh deleted file mode 100755 index 7fefa488..00000000 --- a/handcraft/eval/swin_infer_bs1_640_Gfp16.sh +++ /dev/null @@ -1,112 +0,0 @@ -mkdir -p expinfer_Gfp16_bs1 - -# ================== Maximal Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 1 \ - --layer1 1 1 1 \ - --layer2 1 1 1 \ - --layer3 1 1 1 \ - --fp16 \ - > expinfer_Gfp16_bs1/1gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 2 \ - --layer1 1 1 2 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs1/2gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 4 \ - --layer1 1 1 4 \ - --layer2 1 1 4 \ - --layer3 1 1 4 \ - --fp16 \ - > expinfer_Gfp16_bs1/4gpu_tp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 1 8 \ - --layer1 1 1 8 \ - --layer2 1 1 8 \ - --layer3 1 1 8 \ - --fp16 \ - > expinfer_Gfp16_bs1/8gpu_tp.txt - - -# ================== Window + Tensor Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 2 1 \ - --layer1 1 2 1 \ - --layer2 1 2 1 \ - --layer3 1 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs1/2gpu_2wp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 4 1 \ - --layer1 1 4 1 \ - --layer2 1 4 1 \ - --layer3 1 1 4 \ - --fp16 \ - > expinfer_Gfp16_bs1/4gpu_4wp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 1 \ - --layer0 1 8 1 \ - --layer1 1 8 1 \ - --layer2 1 4 2 \ - --layer3 1 1 8 \ - --fp16 \ - > expinfer_Gfp16_bs1/8gpu_8wp8tp.txt - diff --git a/handcraft/eval/swin_infer_bs2_224_782Mfp32.sh b/handcraft/eval/swin_infer_bs2_224_782Mfp32.sh deleted file mode 100755 index 628947b8..00000000 --- a/handcraft/eval/swin_infer_bs2_224_782Mfp32.sh +++ /dev/null @@ -1,107 +0,0 @@ - -logfile=expinfer_224_782M_fp32_bs2 - -mkdir -p ${logfile} - -# ================== Maximal Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 1 1 1 \ - --layer1 1 1 1 \ - --layer2 1 1 1 \ - --layer3 1 1 1 \ - > ${logfile}/1gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 1 1 \ - --layer1 2 1 1 \ - --layer2 2 1 1 \ - --layer3 2 1 1 \ - > ${logfile}/2gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 1 2 \ - --layer1 2 1 2 \ - --layer2 2 1 2 \ - --layer3 2 1 2 \ - > ${logfile}/4gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 1 4 \ - --layer1 2 1 4 \ - --layer2 2 1 4 \ - --layer3 2 1 4 \ - > ${logfile}/8gpu_tp.txt - - -# ================== Window + Tensor Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 1 1 \ - --layer1 2 1 1 \ - --layer2 2 1 1 \ - --layer3 2 1 1 \ - > ${logfile}/2gpu_2wp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 2 1 \ - --layer1 2 2 1 \ - --layer2 2 2 1 \ - --layer3 2 1 2 \ - > ${logfile}/4gpu_4wp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 4 1 \ - --layer1 2 4 1 \ - --layer2 2 4 1 \ - --layer3 2 1 4 \ - > ${logfile}/8gpu_8wp8tp.txt - diff --git a/handcraft/eval/swin_infer_bs2_640_Gfp16.sh b/handcraft/eval/swin_infer_bs2_640_Gfp16.sh deleted file mode 100755 index 91c5e7ae..00000000 --- a/handcraft/eval/swin_infer_bs2_640_Gfp16.sh +++ /dev/null @@ -1,112 +0,0 @@ -mkdir -p expinfer_Gfp16_bs2 - -# ================== Maximal Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 1 1 1 \ - --layer1 1 1 1 \ - --layer2 1 1 1 \ - --layer3 1 1 1 \ - --fp16 \ - > expinfer_Gfp16_bs2/1gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 1 1 \ - --layer1 2 1 1 \ - --layer2 2 1 1 \ - --layer3 2 1 1 \ - --fp16 \ - > expinfer_Gfp16_bs2/2gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 1 2 \ - --layer1 2 1 2 \ - --layer2 2 1 2 \ - --layer3 2 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs2/4gpu_tp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 1 4 \ - --layer1 2 1 4 \ - --layer2 2 1 4 \ - --layer3 2 1 4 \ - --fp16 \ - > expinfer_Gfp16_bs2/8gpu_tp.txt - - -# ================== Window + Tensor Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 1 2 1 \ - --layer1 1 2 1 \ - --layer2 1 2 1 \ - --layer3 1 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs2/2gpu_2wp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 2 1 \ - --layer1 2 2 1 \ - --layer2 2 2 1 \ - --layer3 2 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs2/4gpu_4wp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 2 \ - --layer0 2 4 1 \ - --layer1 2 4 1 \ - --layer2 2 4 1 \ - --layer3 2 1 4 \ - --fp16 \ - > expinfer_Gfp16_bs2/8gpu_8wp8tp.txt - diff --git a/handcraft/eval/swin_infer_bs4_640_Gfp16.sh b/handcraft/eval/swin_infer_bs4_640_Gfp16.sh deleted file mode 100755 index 74eb8195..00000000 --- a/handcraft/eval/swin_infer_bs4_640_Gfp16.sh +++ /dev/null @@ -1,112 +0,0 @@ -mkdir -p expinfer_Gfp16_bs4 - -# ================== Maximal Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 1 1 1 \ - --layer1 1 1 1 \ - --layer2 1 1 1 \ - --layer3 1 1 1 \ - --fp16 \ - > expinfer_Gfp16_bs4/1gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 2 1 1 \ - --layer1 2 1 1 \ - --layer2 2 1 1 \ - --layer3 2 1 1 \ - --fp16 \ - > expinfer_Gfp16_bs4/2gpu_tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 4 1 1 \ - --layer1 4 1 1 \ - --layer2 4 1 1 \ - --layer3 4 1 1 \ - --fp16 \ - > expinfer_Gfp16_bs4/4gpu_tp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 4 1 2 \ - --layer1 4 1 2 \ - --layer2 4 1 2 \ - --layer3 4 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs4/8gpu_tp.txt - - -# ================== Window + Tensor Parallel =============== - -# python -m torch.distributed.launch \ -# --nproc_per_node=2 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt_infer.py --bs 4 \ -# --layer0 4 1 1 \ -# --layer1 4 1 1 \ -# --layer2 4 1 1 \ -# --layer3 4 1 1 \ -# --fp16 \ -# > expinfer_Gfp16_bs4/2gpu_2wp2tp.txt -# -# python -m torch.distributed.launch \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt_infer.py --bs 4 \ -# --layer0 4 1 1 \ -# --layer1 4 1 1 \ -# --layer2 4 1 1 \ -# --layer3 4 1 1 \ -# --fp16 \ -# > expinfer_Gfp16_bs4/4gpu_4wp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt_infer.py --bs 4 \ - --layer0 4 2 1 \ - --layer1 4 2 1 \ - --layer2 4 2 1 \ - --layer3 4 1 2 \ - --fp16 \ - > expinfer_Gfp16_bs4/8gpu_8wp8tp.txt - diff --git a/handcraft/eval/swin_scaleup.sh b/handcraft/eval/swin_scaleup.sh deleted file mode 100644 index 6c4de61d..00000000 --- a/handcraft/eval/swin_scaleup.sh +++ /dev/null @@ -1,30 +0,0 @@ - -# Swin cube maximal scaling -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=$NID \ - --master_addr=worker-0 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_hybrid.py \ - --layer0 2 8 1 \ - --layer1 2 1 8 \ - --layer2 2 1 8 \ - --layer3 2 1 8 \ - --gbs 8 --mbs 8 - -# Swin Megatron maximal scaling -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=$NID \ - --master_addr=worker-0 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_hybrid.py \ - --layer0 2 8 1 \ - --layer1 2 1 8 \ - --layer2 2 1 8 \ - --layer3 2 1 8 \ - --gbs 8 --mbs 8 diff --git a/handcraft/eval/swin_train_fp16.sh b/handcraft/eval/swin_train_fp16.sh deleted file mode 100755 index b6ef102a..00000000 --- a/handcraft/eval/swin_train_fp16.sh +++ /dev/null @@ -1,257 +0,0 @@ -bs=$1 - -logfile=exptrain_782M_bs${bs}_fp32 - -mkdir -p ${logfile} - -# ================== Megatron Policy Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 1 2 \ - --layer1 1 1 2 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - --fp16 \ - > ${logfile}/2gpu_maxdp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 2 1 2 \ - --layer1 2 1 2 \ - --layer2 2 1 2 \ - --layer3 2 1 2 \ - --fp16 \ - > ${logfile}/4gpu_maxdp2tp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 4 1 2 \ - --layer1 4 1 2 \ - --layer2 4 1 2 \ - --layer3 4 1 2 \ - --fp16 \ - > ${logfile}/8gpu_maxdp2tp.txt - -# ================== Maximal Tensor Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 1 2 \ - --layer1 1 1 2 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - --fp16 \ - > ${logfile}/2gpu_maxtp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 1 4 \ - --layer1 1 1 4 \ - --layer2 1 1 4 \ - --layer3 1 1 4 \ - --fp16 \ - > ${logfile}/4gpu_maxtp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 2 1 4 \ - --layer1 2 1 4 \ - --layer2 2 1 4 \ - --layer3 2 1 4 \ - --fp16 \ - > ${logfile}/8gpu_maxtp.txt - -# ================== Window + Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 1 1 \ - --layer1 1 1 1 \ - --layer2 1 1 1 \ - --layer3 1 1 1 \ - --fp16 \ - > ${logfile}/single.txt - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 2 1 \ - --layer1 1 2 1 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - --fp16 \ - > ${logfile}/2gpu_2wp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 4 1 \ - --layer1 1 4 1 \ - --layer2 1 1 4 \ - --layer3 1 1 4 \ - --fp16 \ - > ${logfile}/4gpu_4wp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 8 1 \ - --layer1 1 1 8 \ - --layer2 1 1 8 \ - --layer3 1 1 8 \ - --fp16 \ - > ${logfile}/8gpu_8wp8tp.txt - - -# ================== Data + Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dt.py --bs ${bs} \ - --layer0 2 1 \ - --layer1 2 1 \ - --layer2 1 2 \ - --layer3 1 2 \ - --fp16 \ - > ${logfile}/2gpu_dt_2dp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dt.py --bs ${bs} \ - --layer0 4 1 \ - --layer1 4 1 \ - --layer2 1 4 \ - --layer3 1 4 \ - --fp16 \ - > ${logfile}/4gpu_dt_4dp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dt.py --bs ${bs} \ - --layer0 8 1 \ - --layer1 8 1 \ - --layer2 1 8 \ - --layer3 1 8 \ - --fp16 \ - > ${logfile}/8gpu_dt_8dp8tp.txt - - -# ========================== Data Parallel ====================== # - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 2 1 1 \ - --layer1 2 1 1 \ - --layer2 2 1 1 \ - --layer3 2 1 1 \ - --fp16 \ - > ${logfile}/2gpu_maxdp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 4 1 1 \ - --layer1 4 1 1 \ - --layer2 4 1 1 \ - --layer3 4 1 1 \ - --fp16 \ - > ${logfile}/4gpu_maxdp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 8 1 1 \ - --layer1 8 1 1 \ - --layer2 8 1 1 \ - --layer3 8 1 1 \ - --fp16 \ - > ${logfile}/8gpu_maxdp.txt \ No newline at end of file diff --git a/handcraft/eval/swin_train_fp32.sh b/handcraft/eval/swin_train_fp32.sh deleted file mode 100755 index 97bc677e..00000000 --- a/handcraft/eval/swin_train_fp32.sh +++ /dev/null @@ -1,243 +0,0 @@ -bs=$1 - -logfile=exptrain_782M_bs${bs}_fp32_384 - -mkdir -p ${logfile} - -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 1 1 \ - --layer1 1 1 1 \ - --layer2 1 1 1 \ - --layer3 1 1 1 \ - > ${logfile}/single.txt - -# ================== Megatron Policy Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 1 2 \ - --layer1 1 1 2 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - > ${logfile}/2gpu_maxdp2tp.txt - -# python -m torch.distributed.launch \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt.py --bs ${bs} \ -# --layer0 2 1 2 \ -# --layer1 2 1 2 \ -# --layer2 2 1 2 \ -# --layer3 2 1 2 \ -# > ${logfile}/4gpu_maxdp2tp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 4 1 2 \ - --layer1 4 1 2 \ - --layer2 4 1 2 \ - --layer3 4 1 2 \ - > ${logfile}/8gpu_maxdp2tp.txt - -# ================== Maximal Tensor Parallel =============== - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 1 2 \ - --layer1 1 1 2 \ - --layer2 1 1 2 \ - --layer3 1 1 2 \ - > ${logfile}/2gpu_maxtp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 1 1 4 \ - --layer1 1 1 4 \ - --layer2 1 1 4 \ - --layer3 1 1 4 \ - > ${logfile}/4gpu_maxtp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 2 1 4 \ - --layer1 2 1 4 \ - --layer2 2 1 4 \ - --layer3 2 1 4 \ - > ${logfile}/8gpu_maxtp.txt - -# ================== Window + Tensor Parallel =============== - -# python -m torch.distributed.launch \ -# --nproc_per_node=2 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt.py --bs ${bs} \ -# --layer0 1 2 1 \ -# --layer1 1 2 1 \ -# --layer2 1 1 2 \ -# --layer3 1 1 2 \ -# > ${logfile}/2gpu_2wp2tp.txt -# -# python -m torch.distributed.launch \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt.py --bs ${bs} \ -# --layer0 1 4 1 \ -# --layer1 1 4 1 \ -# --layer2 1 1 4 \ -# --layer3 1 1 4 \ -# > exptrain_782M_bs8_fp32/4gpu_4wp4tp.txt -# -# python -m torch.distributed.launch \ -# --nproc_per_node=8 \ -# --nnodes=1 \ -# --node_rank=0 \ -# --master_addr=127.0.0.1 \ -# --master_port=8004 \ -# --use_env \ -# examples/swin/swin_dwt.py --bs ${bs} \ -# --layer0 1 8 1 \ -# --layer1 1 1 8 \ -# --layer2 1 1 8 \ -# --layer3 1 1 8 \ -# > ${logfile}/8gpu_8wp8tp.txt - - -# ================== Data + Tensor Parallel =============== -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dt.py --bs ${bs} \ - --layer0 2 1 \ - --layer1 2 1 \ - --layer2 1 2 \ - --layer3 1 2 \ - > ${logfile}/2gpu_dt_2dp2tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dt.py --bs ${bs} \ - --layer0 4 1 \ - --layer1 4 1 \ - --layer2 1 4 \ - --layer3 1 4 \ - > ${logfile}/4gpu_dt_4dp4tp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dt.py --bs ${bs} \ - --layer0 8 1 \ - --layer1 8 1 \ - --layer2 1 8 \ - --layer3 1 8 \ - > ${logfile}/8gpu_dt_8dp8tp.txt - - -# ================== Pure Data Parallel ============= - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 2 1 1 \ - --layer1 2 1 1 \ - --layer2 2 1 1 \ - --layer3 2 1 1 \ - > ${logfile}/2gpu_maxdp.txt - - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 4 1 1 \ - --layer1 4 1 1 \ - --layer2 4 1 1 \ - --layer3 4 1 1 \ - > ${logfile}/4gpu_maxdp.txt - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs ${bs} \ - --layer0 8 1 1 \ - --layer1 8 1 1 \ - --layer2 8 1 1 \ - --layer3 8 1 1 \ - > ${logfile}/8gpu_maxdp.txt diff --git a/handcraft/module/schedule.py b/handcraft/module/schedule.py new file mode 100644 index 00000000..e9d9dfd1 --- /dev/null +++ b/handcraft/module/schedule.py @@ -0,0 +1,324 @@ +from typing import List, Tuple +import torch + +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary +from cube.runtime.device import DeviceGroup + +from handcraft.module.stage import PipeStage + +io_input = input + +def forward_step(model, *args, **kwargs): + """ + Forward pass + """ + CudaTimer().start("forward") + outputs = model(*args, **kwargs) + if not isinstance(outputs, tuple): + outputs = (outputs, ) + CudaTimer().stop("forward") + return outputs + + +def backward_step(input_tensors: List[torch.Tensor], + output_tensors: List[torch.Tensor], + output_tensor_grads: List[torch.Tensor]) -> List[torch.Tensor]: + """ + Backward pass + """ + for tensor in input_tensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + tensor.retain_grad() + CudaTimer().start("backward") + torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) + CudaTimer().stop("backward") + input_tensor_grads = [] + for tensor in input_tensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + input_tensor_grads.append(tensor.grad) + else: + input_tensor_grads.append(None) + return input_tensor_grads + + +def recv_forward(model: PipeStage, prev_rank: int) -> List[torch.Tensor]: + shapes, dtypes = model.inputs_info + assert len(shapes) == len(dtypes) + assert isinstance(prev_rank, int), "Expected prev_rank to be int" + # print(f'rank {DeviceGroup().rank} recving forward: {shapes}, {dtypes}') + if len(shapes) == 0: return () + + CudaTimer().start(field_name='comm') + tensors = [ + torch.empty( + shape, requires_grad=True, dtype=dtype, + device=torch.cuda.current_device() + ) for shape, dtype in zip(shapes, dtypes) + ] + recv_ops = [ + torch.distributed.P2POp( + torch.distributed.irecv, tensor, prev_rank + ) for tensor in tensors + ] + reqs = torch.distributed.batch_isend_irecv(recv_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return tensors + + +def recv_backward(model: PipeStage, next_rank: int) -> List[torch.Tensor]: + shapes, dtypes = model.outputs_info + assert len(shapes) == len(dtypes) + assert isinstance(next_rank, int), "Expected next_rank to be int" + # print(f'rank {DeviceGroup().rank} recving backward: {shapes}') + if len(shapes) == 0: return () + + CudaTimer().start(field_name='comm') + tensors = [ + torch.empty( + shape, requires_grad=False, dtype=dtype, + device=torch.cuda.current_device() + ) for shape, dtype in zip(shapes, dtypes) + ] + recv_ops = [ + torch.distributed.P2POp( + torch.distributed.irecv, tensor, next_rank + ) for tensor in tensors + ] + reqs = torch.distributed.batch_isend_irecv(recv_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return tensors + + +def send_forward(outputs: List[torch.Tensor], next_rank: int): + assert all([torch.is_tensor(out) for out in outputs]), "Expected List[Tensor]" + assert isinstance(next_rank, int), "Expected next_rank to be int" + if len(outputs) == 0: return + # print(f'rank {DeviceGroup().rank} sending forward: {[tuple(t.size()) for t in outputs]}') + + CudaTimer().start(field_name='comm') + send_ops = [ + torch.distributed.P2POp( + torch.distributed.isend, tensor, next_rank + ) for tensor in outputs + ] + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + + +def send_backward(grads: List[torch.Tensor], prev_rank: int): + assert all([torch.is_tensor(grad) for grad in grads]), "Expected List[Tensor]" + assert isinstance(prev_rank, int), "Expected prev_rank to be int" + if len(grads) == 0: return + CudaTimer().start(field_name='comm') + # print(f'rank {DeviceGroup().rank} sending backward: {[tuple(t.size()) for t in grads]}') + + send_ops = [ + torch.distributed.P2POp( + torch.distributed.isend, tensor, prev_rank + ) for tensor in grads + ] + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + + +def send_forward_recv_backward(outputs, model: PipeStage, next_rank: int) -> List[torch.Tensor]: + assert all([torch.is_tensor(out) for out in outputs]), "Expected List[Tensor]" + assert isinstance(next_rank, int), "Expected next_rank to be int" + shapes, dtypes = model.outputs_info + assert len(shapes) == len(dtypes) + # print(f'rank {DeviceGroup().rank} sending forward: {[tuple(t.size()) for t in outputs]} recving backward {shapes}') + + CudaTimer().start(field_name='comm') + ops = list() + # send forward outputs + send_ops = [ + torch.distributed.P2POp( + torch.distributed.isend, tensor, next_rank + ) for tensor in outputs + ] + ops += send_ops + # recv backward inputs + tensors = [ + torch.empty( + shape, requires_grad=True, dtype=dtype, + device=torch.cuda.current_device() + ) for shape, dtype in zip(shapes, dtypes) + ] + recv_ops = [ + torch.distributed.P2POp( + torch.distributed.irecv, tensor, next_rank + ) for tensor in tensors + ] + ops += recv_ops + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return tensors + + +def send_backward_recv_forward(grads, model: PipeStage, prev_rank: int) -> List[torch.Tensor]: + assert all([torch.is_tensor(grad) for grad in grads]), "Expected List[Tensor]" + assert isinstance(prev_rank, int), "Expected prev_rank to be int" + shapes, dtypes = model.inputs_info + assert len(shapes) == len(dtypes) + # print(f'rank {DeviceGroup().rank} sending backward: {[tuple(t.size()) for t in grads]} recving forward {shapes}') + + CudaTimer().start(field_name='comm') + ops = list() + # send backward gradients + send_ops = [ + torch.distributed.P2POp( + torch.distributed.isend, tensor, prev_rank + ) for tensor in grads + ] + ops += send_ops + # recv forward inputs + tensors = [ + torch.empty( + shape, requires_grad=True, dtype=dtype, + device=torch.cuda.current_device() + ) for shape, dtype in zip(shapes, dtypes) + ] + recv_ops = [ + torch.distributed.P2POp( + torch.distributed.irecv, tensor, prev_rank + ) for tensor in tensors + ] + ops += recv_ops + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return tensors + + + +def schedule_naive(model: PipeStage, dataloader, num_microbatch: int): + """ + neighbors: (prev_rank: int, next_rank: int) + """ + prev_rank = model.prev_stage_global_grank + next_rank = model.next_stage_global_rank + + for _ in range(num_microbatch): + model.data = next(dataloader) + # print(f'rank {rank} recving forward input...') + inputs = () if model.is_first_stage else recv_forward(model, prev_rank) + # forward + outputs = forward_step(model, *inputs) + # send forward + if not model.is_last_stage: + # print(f'rank {rank} sending forward output...') + send_forward(outputs, next_rank) + # recv backward + # print(f'rank {rank} recving backward input...') + output_grads = (None,) if model.is_last_stage else recv_backward(model, next_rank) + # backward + input_grads = backward_step(inputs, outputs, output_grads) + # send backward + if not model.is_first_stage: + # print(f'rank {rank} sending backward output...') + send_backward(input_grads, prev_rank) + + +def schedule_1f1b(model: PipeStage, + dataloader, + num_microbatch: int, + recompute=False): + + num_stage = model.num_stages + prev_rank = model.prev_stage_global_grank + next_rank = model.next_stage_global_rank + + num_warmup_microbatches = num_stage - 1 - model.stage_local_rank + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatch) + num_warmup_remaining = num_microbatch - num_warmup_microbatches + + # warmup + for i in range(num_warmup_microbatches): + model.data = next(dataloader) + # recv forward + inputs = () if model.is_first_stage else recv_forward(model, prev_rank) + # forward + model.push(inputs, 'inputs') + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs) + model.push(None, 'outputs') + else: + outputs = forward_step(model, *inputs) + model.push(outputs, 'outputs') + # send forward + send_forward(outputs, next_rank) + + # before running 1f1b: need to recv first forward tensor + if num_warmup_remaining > 0: + model.data = next(dataloader) + inputs = () if model.is_first_stage else recv_forward(model, prev_rank) + + # run 1f1b + for i in range(num_warmup_remaining): + model.data = next(dataloader) + # forward + model.push(inputs, 'inputs') + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs) + model.push(None, 'outputs') + else: + outputs = forward_step(model, *inputs) + model.push(outputs, 'outputs') + + # send forward recv backward + grads = (None,) + if not model.is_last_stage: + grads = send_forward_recv_backward(outputs, model, next_rank) + + # backward + inputs, outputs = model.pop('inputs'), model.pop('outputs') + if recompute: + assert outputs is None + outputs = forward_step(model, *inputs) + input_grads = backward_step(inputs, outputs, grads) + + # send backward + inputs = () + if not model.is_first_stage: + if i != (num_warmup_remaining-1): + # send backward recv forward + inputs = send_backward_recv_forward(input_grads, model, prev_rank) + else: + # send backward + send_backward(input_grads, prev_rank) + + # cooldown + for i in range(num_warmup_microbatches): + inputs, outputs = model.pop('inputs'), model.pop('outputs') + # recv backward + grads = (None,) if model.is_last_stage else recv_backward(model, next_rank) + # backward + if recompute: + assert outputs is None + outputs = forward_step(model, *inputs) + input_grads = backward_step(inputs, outputs, grads) + # send backward + if not model.is_first_stage: + send_backward(input_grads, prev_rank) + + model.assert_empty_cached() \ No newline at end of file diff --git a/handcraft/module/stage.py b/handcraft/module/stage.py new file mode 100644 index 00000000..b2939a50 --- /dev/null +++ b/handcraft/module/stage.py @@ -0,0 +1,123 @@ +from typing import Any, List, Tuple +import torch + + +class PipeStage(torch.nn.Module): + + def __init__(self): + super().__init__() + self._cached = dict() + self._data = () + self._input_shapes = () + self._input_dtypes = () + self._output_shapes = () + self._output_dtypes = () + + # pipeline information + self._num_stages = None + self._is_first_stage = None + self._is_last_stage = None + self._stage_grank = None # global rank + self._next_grank = None # global rank + self._prev_grank = None # global rank + self._stage_lrank = None # local rank + self._next_lrank = None # local rank + self._prev_lrank = None # local rank + + @property + def is_first_stage(self) -> bool: + return self._is_first_stage + + @property + def is_last_stage(self) -> bool: + return self._is_last_stage + + @property + def next_stage_global_rank(self) -> int: + return self._next_grank + + @property + def prev_stage_global_grank(self) -> int: + return self._prev_grank + + @property + def stage_global_rank(self) -> int: + return self._stage_grank + + @property + def next_stage_local_rank(self) -> int: + return self._next_lrank + + @property + def prev_stage_local_rank(self) -> int: + return self._prev_lrank + + @property + def stage_local_rank(self) -> int: + return self._stage_lrank + + @property + def num_stages(self): + return self._num_stages + + def set_pipeline(self, group_global_ranks: Tuple[int]): + """ + Setup pipeline information given global ranks. + Note NCCL group should be initialized outside + """ + if len(group_global_ranks) == 0: + group_global_ranks = (torch.distributed.get_rank(),) + self._num_stages = len(group_global_ranks) + self._stage_grank = torch.distributed.get_rank() + self._stage_lrank = group_global_ranks.index(self._stage_grank) + + self._next_grank = group_global_ranks[(self._stage_lrank+1) % self.num_stages] + self._prev_grank = group_global_ranks[(self._stage_lrank-1) % self.num_stages] + + self._next_lrank = (self._stage_lrank+1) % self.num_stages + self._prev_lrank = (self._stage_lrank-1) % self.num_stages + + self._is_first_stage = self._stage_lrank == 0 + self._is_last_stage = self._stage_lrank == self.num_stages - 1 + + def pop(self, region: str = 'default') -> Any: + return self._cached[region].pop(0) + + def push(self, val: Any, region: str = 'default'): + if region not in self._cached: + self._cached[region] = [] + return self._cached[region].append(val) + + def assert_empty_cached(self): + for key, vals in self._cached.items(): + assert len(vals) == 0, f"key {key} still has {len(vals)} values" + + @property + def inputs_info(self) -> Tuple[Tuple, Tuple]: + """ + return input shapes and dtypes + """ + return self._input_shapes, self._input_dtypes + + @inputs_info.setter + def inputs_info(self, shapes_dtypes: Tuple[Tuple, Tuple]): + self._input_shapes, self._input_dtypes = shapes_dtypes + + @property + def outputs_info(self) -> Tuple[Tuple, Tuple]: + """ + return output shapes and dtypes + """ + return self._output_shapes, self._output_dtypes + + @outputs_info.setter + def outputs_info(self, shapes_dtypes: Tuple[Tuple, Tuple]): + self._output_shapes, self._output_dtypes = shapes_dtypes + + @property + def data(self) -> Tuple: + return self._data + + @data.setter + def data(self, datas: Tuple): + self._data = datas diff --git a/handcraft/swin/hybrid_schedule.py b/handcraft/swin/hybrid_schedule.py deleted file mode 100644 index 4b474fca..00000000 --- a/handcraft/swin/hybrid_schedule.py +++ /dev/null @@ -1,251 +0,0 @@ -import torch - -from torch.distributed.distributed_c10d import _get_global_rank -from cube.profiler.timer import CudaTimer - - -def get_global_rank(group, group_rank): - if group is None: - return group_rank - else: - return _get_global_rank(group, group_rank) - - -def is_last_stage(group): - return torch.distributed.get_rank(group=group) == torch.distributed.get_world_size(group=group) - 1 - - -#================= WhatToDO functions ==================# - -def forward_step(model, image, trans_input=None): - CudaTimer().start("forward") - output = model(image, trans_input) - CudaTimer().stop("forward") - return output - - -def backward_step(feature_map, output_tensor, output_tensor_grad): - """ - Calculate input tensor gradient - """ - if feature_map is not None and feature_map.requires_grad: - feature_map.retain_grad() - CudaTimer().start("backward") - torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) - CudaTimer().stop("backward") - input_tensor_grad = None - if feature_map is not None and feature_map.requires_grad: - input_tensor_grad = feature_map.grad - return input_tensor_grad - -#================= WhatToDO functions ==================# - -#================= Between Stage functions ==================# - -def send(tensors, to_rank, group): - """ - send tensor to the target rank - """ - if to_rank < 0 or to_rank >= torch.distributed.get_world_size(group): - return None - if group is not None: - to_rank = get_global_rank(group, to_rank) - # print(f'send: {torch.distributed.get_rank()} -> {to_rank}: {tensors[0].shape}') - assert isinstance(tensors, list) or isinstance(tensors, tuple) - CudaTimer().start("send") - reqs = list() - for tensor in tensors: - if tensor is None: - continue - elif torch.is_tensor(tensor): - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, to_rank - ) - reqs.append(send_op) - else: - raise RuntimeError("Expected tensor or None") - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("send") - - -def recv(shapes, from_rank, dtype, group): - if from_rank < 0 or from_rank >= torch.distributed.get_world_size(group): - return [None] * len(shapes) - assert isinstance(shapes, list) or isinstance(shapes, tuple) - if group is not None: - from_rank = get_global_rank(group, from_rank) - # print(f'recv: {torch.distributed.get_rank()} <- {from_rank}: {shapes}') - CudaTimer().start("recv") - reqs = list() - recved_tensors = list() - for shape in shapes: - if shape is None: - recved_tensors.append(None) - continue - tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device(), - dtype=dtype - ) - recved_tensors.append(tensor) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, from_rank - ) - reqs.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("recv") - return recved_tensors - - -def send_and_recv(send_tensors, recv_shapes, rank, dtype, group): - if rank < 0 or rank >= torch.distributed.get_world_size(group): - return [None] * len(recv_shapes) - if group is not None: - rank = get_global_rank(group, rank) - # print(f'exchange: {torch.distributed.get_rank()} <-> {rank}: {recv_shapes}') - assert isinstance(send_tensors, list) or isinstance(send_tensors, tuple) - assert isinstance(recv_shapes, list) or isinstance(recv_shapes, tuple) - CudaTimer().start("send_recv") - reqs = list() - recved_tensors = list() - for tensor in send_tensors: - if tensor is None: - continue - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, rank - ) - reqs.append(send_op) - for shape in recv_shapes: - if shape is None: - recved_tensors.append(None) - continue - recv_tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device(), - dtype=dtype - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_tensor, rank - ) - recved_tensors.append(recv_tensor) - reqs.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("send_recv") - return recved_tensors - -#================= Between Stage functions ==================# - -def split_batch(inputs, num_microbatches): - """ - Split a mini-batch to micro-batches - """ - assert isinstance(inputs, list) or isinstance(inputs, tuple) - input_chunks = list() - for feature_map in inputs: - if torch.is_tensor(feature_map): - feature_map = torch.chunk(feature_map, chunks=num_microbatches, dim=0) - else: - feature_map = [feature_map] * num_microbatches - input_chunks.append(feature_map) - micro_batches = list() - for micro_data in zip(*tuple(input_chunks)): - micro_batches.append(micro_data) - return micro_batches - - -#================= Scheduling ==================# - -def scheduling_1f1b(model, inputs, bs, micro_bs, dtype, group): - myrank = torch.distributed.get_rank(group) - - num_microbatches = int(bs / micro_bs) - num_warmup_microbatches = \ - (torch.distributed.get_world_size(group) - - torch.distributed.get_rank(group) - 1) - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_warmup_remaining = num_microbatches - num_warmup_microbatches - - input_tensors = list() - output_tensors = list() - - inputs = split_batch(inputs, num_microbatches) - - # warmup forward pass - for i in range(num_warmup_microbatches): - # recv forward - # print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) - feature_map = recv( - (torch.Size(model.in_size),), myrank-1, dtype, group - )[0] - image = inputs[i][0] - # forward - output_tensor = forward_step(model, image, feature_map) - # send forward - # print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) - send((output_tensor,), myrank+1, group) - - input_tensors.append(feature_map) - output_tensors.append(output_tensor) - - # before running 1F1B, need to recieve first forward tensor - if num_warmup_remaining > 0: - # recv forward - # print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) - feature_map = recv( - (torch.Size(model.in_size),), myrank-1, dtype, group - )[0] - image = inputs[num_warmup_microbatches][0] - - # run 1F1B - for i in range(num_warmup_remaining): - # forward - output_tensor = forward_step(model, image, feature_map) - # send forward + recv backward grads - # print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) - output_tensor_grad = send_and_recv( - (output_tensor,), - (torch.Size(model.out_size),), - myrank+1, dtype, group - )[0] - input_tensors.append(feature_map) - output_tensors.append(output_tensor) - # backward - feature_map, output_tensor = input_tensors.pop(0), output_tensors.pop(0) - input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) - if i != (num_warmup_remaining-1): - # send backward grads + recv forward results - # print('[1f1b] rank {}: step-{}: sending backward + recving forward...'.format(myrank, i)) - feature_map = send_and_recv( - (input_tensor_grad,), - (torch.Size(model.in_size),), - myrank-1, dtype, group - )[0] - image = inputs[num_warmup_microbatches+i+1][0] - else: # last iteration - no more inputs - feature_map = None - # send backward grads - # print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) - send((input_tensor_grad,), myrank-1, group) - - # cooldown gradient trans back - for i in range(num_warmup_microbatches): - feature_map = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - # recv backward gradients - output_tensor_grad = recv( - (torch.Size(model.out_size),), myrank+1, dtype, group - )[0] - # backward - input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) - # send backward gradients - # print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) - send((input_tensor_grad,), myrank-1, group) - -#================= Scheduling ==================# \ No newline at end of file diff --git a/handcraft/swin/layers.py b/handcraft/swin/layers.py deleted file mode 100644 index d9b80e2e..00000000 --- a/handcraft/swin/layers.py +++ /dev/null @@ -1,326 +0,0 @@ -import torch -from torch import autograd -import torch.nn.functional as F -from torch.nn.parameter import Parameter -from cube.profiler.timer import CudaTimer - - -def _reduce(input_, group): - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - CudaTimer().start(field_name='tp_allreduce') - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - CudaTimer().stop(field_name='tp_allreduce') - return input_ - torch.distributed.all_reduce(input_, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='tp_allreduce') - return input_ - - -def _split(input_, group, dim=-1): - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - world_size = torch.distributed.get_world_size(group=group) - rank = torch.distributed.get_rank(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size==1: - return input_ - dim_size = input_.size()[dim] // world_size - tensor_list = torch.split(input_, dim_size, dim=dim) - output = tensor_list[rank].contiguous() - return output - - -def _gather(input_, group, dim=-1): - """Gather tensors and concatinate along the last dimension.""" - CudaTimer().start(field_name='tp_allgather') - - world_size = torch.distributed.get_world_size(group=group) - rank = torch.distributed.get_rank(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size==1: - CudaTimer().stop(field_name='tp_allgather') - return input_ - # Size and dimension. - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=group) - torch.cuda.synchronize() - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=dim).contiguous() - - CudaTimer().stop(field_name='tp_allgather') - return output - -def _scatter(input_, group, dim=0): - """Reduce-Scatter tensor""" - CudaTimer().start(field_name='tp_reduce_scatter') - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - CudaTimer().stop(field_name='tp_reduce_scatter') - return input_ - rank = torch.distributed.get_rank(group=group) - tensor_list = list(torch.chunk(input_, world_size, dim)) - # for idx, tensor in enumerate(tensor_list): - # tensor_list[idx] = tensor.contiguous() - torch.distributed.reduce_scatter(tensor_list[rank], tensor_list, group=group) - CudaTimer().stop(field_name='tp_reduce_scatter') - return tensor_list[rank] - - -class ColumnInputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group): - ctx.group = group - return input_ - @staticmethod - def backward(ctx, grad_output): - group = ctx.group - return _reduce(grad_output, group), None - - -class ColumnOutputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group): - ctx.group = group - return _gather(input_, group) - @staticmethod - def backward(ctx, grad_output): - group = ctx.group - return _split(grad_output, group), None - - -class RowInputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group): - ctx.group = group - return _split(input_, group) - - @staticmethod - def backward(ctx, grad_outputs): - group = ctx.group - return _gather(grad_outputs, group), None - - -class RowOutputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group): - ctx.group = group - return _reduce(input_, group) - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -class DPtoTPAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group): - """ - split - """ - ctx.group = group - return _gather(input_, group, dim=0) - - @staticmethod - def backward(ctx, grad_output): - """ - reduce-scatter - """ - group = ctx.group - return _split(grad_output, group, dim=0), None - -class ValueTPtoEleDPAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group): - """ - Reduce Scatter - """ - ctx.group = group - return _scatter(input_, group, dim=0) - - @staticmethod - def backward(ctx, grad_output): - """ - Allgather - """ - group = ctx.group - return _gather(grad_output, group, dim=0), None - - -class TPtoDPAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, group): - """ - Reduce-scatter - """ - ctx.group = group - return _split(input_, group, dim=0) - - @staticmethod - def backward(ctx, grad_output): - """ - all-gather - """ - group = ctx.group - return _gather(grad_output, group, dim=0), None - - - - -class ColumnParallelLinear(torch.nn.Module): - - def __init__(self, input_size, output_size, bias=True, in_adapter=True, out_adapter=True, tp_group=-1): - super().__init__() - assert tp_group != -1 - self.input_size = input_size - self.output_size = output_size - self.in_adapter = in_adapter - self.out_adapter = out_adapter - - self.group = tp_group - self.world_size = torch.distributed.get_world_size(group=self.group) - - # print_each_rank(f'> parallizing linear using column partition: ' - # f'{output_size} partitioned by {world_size} devices') - - # not if output size is smaller than world size, - # no parallel enbaled. Each device compute the same - if self.world_size > output_size: - raise RuntimeError - - self.weight = Parameter(torch.empty( - int(self.output_size // self.world_size), - self.input_size, - )) - if bias: - self.bias = Parameter(torch.empty( - int(self.output_size // self.world_size), - )) - else: - self.bias = None - - def forward(self, input_): - - if self.in_adapter and self.world_size > 1: - input_ = ColumnInputAdapter.apply(input_, self.group) - - output = F.linear(input_, self.weight, self.bias) - - if self.out_adapter and self.world_size > 1: - output = ColumnOutputAdapter.apply(output, self.group) - - return output - - -class RowParallelLinear(torch.nn.Module): - - def __init__(self, input_size, output_size, bias=True, in_adapter=True, out_adapter=True, tp_group=-1): - super().__init__() - assert tp_group != -1 - self.input_size = input_size - self.output_size = output_size - self.in_adapter = in_adapter - self.out_adapter = out_adapter - - self.group = tp_group - self.world_size = torch.distributed.get_world_size(group=self.group) - - # print_each_rank(f'> parallizing linear using row partition: ' - # f'{output_size} partitioned by {world_size} devices') - - # not if output size is smaller than world size, - # no parallel enbaled. Each device compute the same - if self.world_size > input_size: - raise RuntimeError - - self.weight = Parameter(torch.empty( - self.output_size, - int(self.input_size // self.world_size), - )) - if bias: - self.bias = Parameter(torch.empty(self.output_size)) - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def forward(self, input_): - bias = self.bias - if self.in_adapter and self.world_size > 1: - input_ = RowInputAdapter.apply(input_, self.group) - - output = F.linear(input_, self.weight, bias) - - if self.out_adapter and self.world_size > 1: - output = RowOutputAdapter.apply(output, self.group) - - return output - - -class ShardEmbedding(torch.nn.Module): - - def __init__(self, num_embeddings, embedding_dim, tp_group): - super().__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - - self.group = tp_group - self.shard_num = torch.distributed.get_world_size(group=self.group) - self.myshard = torch.distributed.get_rank(group=self.group) - - shard_num_embeddings = self.num_embeddings // self.shard_num - self.vocab_start_index = shard_num_embeddings * self.myshard - self.vocab_end_index = self.vocab_start_index + shard_num_embeddings - - self.weight = torch.nn.Parameter( - torch.empty(shard_num_embeddings, self.embedding_dim) - ) - - def forward(self, input_): - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - output_parallel = F.embedding( - masked_input, self.weight, - None, None, 2., False, False - ) - output = RowOutputAdapter.apply(output_parallel, self.group) - return output - - -class DPtoTP(torch.nn.Module): - - def __init__(self, dp_group): - super().__init__() - self.group = dp_group - - def forward(self, input_): - return DPtoTPAdapter.apply(input_, self.group) - - -class TPtoDP(torch.nn.Module): - - def __init__(self, tp_group): - super().__init__() - self.group = tp_group - - def forward(self, input_): - return TPtoDPAdapter.apply(input_, self.group) - - -class ValueTPtoEleDP(torch.nn.Module): - - def __init__(self, tp_group): - super().__init__() - self.group = tp_group - - def forward(self, input_): - return ValueTPtoEleDPAdapter.apply(input_, self.group) diff --git a/handcraft/swin/pmodule.py b/handcraft/swin/pmodule.py deleted file mode 100644 index 1ab0387e..00000000 --- a/handcraft/swin/pmodule.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import List - -import torch - -from cube.runtime.device import DeviceGroup - - -class ParallelModule(torch.nn.Module): - - def __init__(self, pp_ranks: List[int] = list(), - dp_ranks: List[int] = list(), - tp_ranks: List[int] = list()): - - super().__init__() - self._pp_ranks = tuple(pp_ranks) - self._pp_group = DeviceGroup().get_group(pp_ranks) - - self._dp_ranks = tuple(dp_ranks) - self._dp_group = DeviceGroup().get_group(dp_ranks) - - self._tp_ranks = tuple(tp_ranks) - self._tp_group = DeviceGroup().get_group(tp_ranks) - - self.in_size = None - self.out_size = None - - @property - def pp_ranks(self): - return self._pp_ranks - - @property - def pp_group(self): - return self._pp_group - - def use_pp(self): - return len(self._pp_ranks) > 1 - - @property - def dp_ranks(self): - return self._dp_ranks - - @property - def dp_group(self): - return self._dp_group - - def use_dp(self): - return len(self._dp_ranks) > 1 - - @property - def tp_ranks(self): - return self._tp_ranks - - @property - def tp_group(self): - return self._tp_group - - @property - def use_tp(self): - return len(self._tp_ranks) > 1 - - def set_in_size(self, size: List[int]): - self.in_size = size - - def set_out_size(self, size: List[int]): - self.out_size = size \ No newline at end of file diff --git a/handcraft/swin/schedule.py b/handcraft/swin/schedule.py deleted file mode 100644 index 5b271bbc..00000000 --- a/handcraft/swin/schedule.py +++ /dev/null @@ -1,234 +0,0 @@ -import torch - -from cube.profiler.timer import CudaTimer - - -def is_last_stage(): - return torch.distributed.get_rank() == torch.distributed.get_world_size() - 1 - - -#================= WhatToDO functions ==================# - -def forward_step(model, image, trans_input=None): - CudaTimer().start("forward") - output = model(image, trans_input) - CudaTimer().stop("forward") - return output - - -def backward_step(feature_map, output_tensor, output_tensor_grad): - """ - Calculate input tensor gradient - """ - if feature_map is not None and feature_map.requires_grad: - feature_map.retain_grad() - CudaTimer().start("backward") - torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) - CudaTimer().stop("backward") - input_tensor_grad = None - if feature_map is not None and feature_map.requires_grad: - input_tensor_grad = feature_map.grad - return input_tensor_grad - -#================= WhatToDO functions ==================# - -#================= Between Stage functions ==================# - -def send(tensors, to_rank): - """ - send tensor to the target rank - """ - if to_rank < 0 or to_rank >= torch.distributed.get_world_size(): - return None - assert isinstance(tensors, list) or isinstance(tensors, tuple) - CudaTimer().start("send") - reqs = list() - for tensor in tensors: - if tensor is None: - continue - elif torch.is_tensor(tensor): - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, to_rank - ) - reqs.append(send_op) - else: - raise RuntimeError("Expected tensor or None") - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("send") - - -def recv(shapes, from_rank, dtype=torch.float): - if from_rank < 0 or from_rank >= torch.distributed.get_world_size(): - return [None] * len(shapes) - assert isinstance(shapes, list) or isinstance(shapes, tuple) - CudaTimer().start("recv") - reqs = list() - recved_tensors = list() - for shape in shapes: - if shape is None: - recved_tensors.append(None) - continue - tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device(), - dtype=dtype - ) - recved_tensors.append(tensor) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, from_rank - ) - reqs.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("recv") - return recved_tensors - - -def send_and_recv(send_tensors, recv_shapes, rank, dtype=torch.float): - if rank < 0 or rank >= torch.distributed.get_world_size(): - return [None] * len(recv_shapes) - assert isinstance(send_tensors, list) or isinstance(send_tensors, tuple) - assert isinstance(recv_shapes, list) or isinstance(recv_shapes, tuple) - CudaTimer().start("send_recv") - reqs = list() - recved_tensors = list() - for tensor in send_tensors: - if tensor is None: - continue - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, rank - ) - reqs.append(send_op) - for shape in recv_shapes: - if shape is None: - recved_tensors.append(None) - continue - recv_tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device(), - dtype=dtype - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_tensor, rank - ) - recved_tensors.append(recv_tensor) - reqs.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("send_recv") - return recved_tensors - -#================= Between Stage functions ==================# - -def split_batch(inputs, num_microbatches): - """ - Split a mini-batch to micro-batches - """ - assert isinstance(inputs, list) or isinstance(inputs, tuple) - input_chunks = list() - for feature_map in inputs: - if torch.is_tensor(feature_map): - feature_map = torch.chunk(feature_map, chunks=num_microbatches, dim=0) - else: - feature_map = [feature_map] * num_microbatches - input_chunks.append(feature_map) - micro_batches = list() - for micro_data in zip(*tuple(input_chunks)): - micro_batches.append(micro_data) - return micro_batches - - -#================= Scheduling ==================# - -def scheduling_1f1b(model, inputs, bs, micro_bs, dtype=torch.float): - myrank = torch.distributed.get_rank() - - num_microbatches = int(bs / micro_bs) - num_warmup_microbatches = \ - (torch.distributed.get_world_size() - - torch.distributed.get_rank() - 1) - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_warmup_remaining = num_microbatches - num_warmup_microbatches - - input_tensors = list() - output_tensors = list() - - inputs = split_batch(inputs, num_microbatches) - - # warmup forward pass - for i in range(num_warmup_microbatches): - # recv forward - # print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) - feature_map = recv( - (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype - )[0] - image = inputs[i][0] - # forward - output_tensor = forward_step(model, image, feature_map) - # send forward - # print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) - send((output_tensor,), myrank+1) - - input_tensors.append(feature_map) - output_tensors.append(output_tensor) - - # before running 1F1B, need to recieve first forward tensor - if num_warmup_remaining > 0: - # recv forward - # print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) - feature_map = recv( - (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype - )[0] - image = inputs[num_warmup_microbatches][0] - - # run 1F1B - for i in range(num_warmup_remaining): - # forward - output_tensor = forward_step(model, image, feature_map) - # send forward + recv backward grads - # print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) - output_tensor_grad = send_and_recv( - (output_tensor,), - (torch.Size([micro_bs] + model.out_size),), - myrank+1, dtype - )[0] - input_tensors.append(feature_map) - output_tensors.append(output_tensor) - # backward - feature_map, output_tensor = input_tensors.pop(0), output_tensors.pop(0) - input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) - if i != (num_warmup_remaining-1): - # send backward grads + recv forward results - # print('[1f1b] rank {}: step-{}: sending backward + recving forward...'.format(myrank, i)) - feature_map = send_and_recv( - (input_tensor_grad,), - (torch.Size([micro_bs] + model.in_size),), - myrank-1, dtype - )[0] - image = inputs[num_warmup_microbatches+i+1][0] - else: # last iteration - no more inputs - feature_map = None - # send backward grads - # print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) - send((input_tensor_grad,), myrank-1) - - # cooldown gradient trans back - for i in range(num_warmup_microbatches): - feature_map = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - # recv backward gradients - output_tensor_grad = recv( - (torch.Size([micro_bs] + model.out_size),), myrank+1, dtype - )[0] - # backward - input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) - # send backward gradients - # print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) - send((input_tensor_grad,), myrank-1) - -#================= Scheduling ==================# \ No newline at end of file diff --git a/handcraft/swin/swin_dt.py b/handcraft/swin/swin_dt.py deleted file mode 100644 index 6a9ed004..00000000 --- a/handcraft/swin/swin_dt.py +++ /dev/null @@ -1,966 +0,0 @@ - -# -------------------------------------------------------- -# Modified from Swin-Transformer Repo -""" -python -m torch.distributed.launch \ - --nproc_per_node=1 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dt.py --bs 16 \ - --layer0 1 1 \ - --layer1 1 1 \ - --layer2 1 1 \ - --layer3 1 1 -""" -# -------------------------------------------------------- - -from typing import Dict, Optional, Tuple -import torch -import torch.nn as nn -import argparse -import time - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.reducer import Reducer - -from handcraft.swin.layers import ColumnParallelLinear, DPtoTP, RowParallelLinear, TPtoDP - -_dp_reducer: Dict[Tuple[int], Reducer] = dict() - - -def setup_device_group(tp: int, dp: int, layer_id: int): - """ - Layer wise device group initialize - - Returns: - - """ - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - if not tp * dp == ndevs: - raise RuntimeError("Expected same device number") - - assert tp == 1 or dp == 1, "Currently hybrid not supported" - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize tensor parallel groups - for i in range(dp): - ranks = list(range(i * tp, (i + 1) * tp)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - # initialize data parallel groups - for i in range(tp): - ranks = list(range(i, ndevs, tp)) - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) - return tp_ranks, dp_ranks - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class MegatronMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tp_group=-1): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=tp_group) - self.act = act_layer() - # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=True, tp_group=tp_group) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class MegatronWindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., tp_group=-1): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.global_num_heads = num_heads - - tp_world_size = torch.distributed.get_world_size(group=tp_group) - if num_heads % tp_world_size != 0: - raise RuntimeError(f'detecting un-even num head {num_heads} partition to {tp_world_size}') - self.num_heads = num_heads // tp_world_size - - self.dim_heads = dim // self.global_num_heads - self.scale = qk_scale or self.dim_heads ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # relative position index - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - self.register_buffer('relative_position_index', relative_position_index) - - # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) - self.attn_drop = nn.Dropout(attn_drop) - # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=True, tp_group=tp_group) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, - tp_group=-1): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = MegatronWindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - tp_group=tp_group) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MegatronMlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, drop=drop, - tp_group=tp_group - ) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - x = x + drop_path(ffn, self.drop_path_p) - - return x - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, - tp_group=-1, layer_id=-1): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - # build blocks - self.blocks = nn.ModuleList() - for i in range(depth): - block = SwinTransformerBlock( - dim=dim, input_resolution=self.input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - tp_group=tp_group, - ) - self.blocks.append(block) - - def forward(self, x): - for blk in self.blocks: - x = blk(x) - return x - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, fp16=False, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - - # ====================== depth 0 =========================== - pconfig = pconfigs[0] - l0_tp_ranks, l0_dp_ranks = setup_device_group(**pconfig) - tp_group = DeviceGroup().get_group(l0_tp_ranks) - - input_resolution = ( - patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) - ) - self.basic_layer0 = BasicLayer( - dim=int(embed_dim * 2 ** 0), - input_resolution=input_resolution, - depth=depths[0], - num_heads=num_heads[0], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:0]):sum(depths[:0 + 1])], - norm_layer=norm_layer, - tp_group=tp_group, - ) - - if len(l0_dp_ranks) > 1: - dp_ranks = tuple(l0_dp_ranks) - if dp_ranks not in _dp_reducer: - _dp_reducer[dp_ranks] = Reducer(dp_ranks) - for param in self.patch_embed.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.basic_layer0.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - - # ====================== depth 1 =========================== - pconfig = pconfigs[1] - l1_tp_ranks, l1_dp_ranks = setup_device_group(**pconfig) - tp_group = DeviceGroup().get_group(l1_tp_ranks) - - # adapter - if len(l0_dp_ranks) > 1 and len(l1_tp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter01 = DPtoTP(DeviceGroup().get_group(l0_dp_ranks)) - elif len(l0_tp_ranks) > 1 and len(l1_dp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter01 = TPtoDP(DeviceGroup().get_group(l0_tp_ranks)) - else: - self.adapter01 = torch.nn.Identity() - - self.merging0 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 0), norm_layer=norm_layer - ) - - input_resolution = ( - patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) - ) - self.basic_layer1 = BasicLayer( - dim=int(embed_dim * 2 ** 1), - input_resolution=input_resolution, - depth=depths[1], - num_heads=num_heads[1], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:1]):sum(depths[:1 + 1])], - norm_layer=norm_layer, - tp_group=tp_group, - ) - - if len(l1_dp_ranks) > 1: - dp_ranks = tuple(l1_dp_ranks) - if dp_ranks not in _dp_reducer: - _dp_reducer[dp_ranks] = Reducer(dp_ranks) - for param in self.basic_layer1.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.merging0.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - - # ====================== depth 2 =========================== - pconfig = pconfigs[2] - l2_tp_ranks, l2_dp_ranks = setup_device_group(**pconfig) - tp_group = DeviceGroup().get_group(l2_tp_ranks) - - # adapter - if len(l1_dp_ranks) > 1 and len(l2_tp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter12 = DPtoTP(DeviceGroup().get_group(l1_dp_ranks)) - elif len(l1_tp_ranks) > 1 and len(l2_dp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter12 = TPtoDP(DeviceGroup().get_group(l1_tp_ranks)) - else: - self.adapter12 = torch.nn.Identity() - - - self.merging1 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 1), norm_layer=norm_layer - ) - - input_resolution = ( - patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) - ) - self.basic_layer2 = BasicLayer( - dim=int(embed_dim * 2 ** 2), - input_resolution=input_resolution, - depth=depths[2], - num_heads=num_heads[2], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:2]):sum(depths[:2 + 1])], - norm_layer=norm_layer, - tp_group=tp_group - ) - - if len(l2_dp_ranks) > 1: - dp_ranks = tuple(l2_dp_ranks) - if dp_ranks not in _dp_reducer: - _dp_reducer[dp_ranks] = Reducer(dp_ranks) - for param in self.basic_layer2.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.merging1.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - # ====================== depth 3 =========================== - pconfig = pconfigs[3] - l3_tp_ranks, l3_dp_ranks = setup_device_group(**pconfig) - tp_group = DeviceGroup().get_group(l3_tp_ranks) - - # adapter - if len(l2_dp_ranks) > 1 and len(l3_tp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter23 = DPtoTP(DeviceGroup().get_group(l2_dp_ranks)) - elif len(l2_tp_ranks) > 1 and len(l3_dp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter23 = TPtoDP(DeviceGroup().get_group(l2_tp_ranks)) - else: - self.adapter23 = torch.nn.Identity() - - self.merging2 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 2), norm_layer=norm_layer - ) - - self.basic_layer3 = BasicLayer( - dim=int(embed_dim * 2 ** 3), - input_resolution=(patches_resolution[0] // (2 ** 3), - patches_resolution[1] // (2 ** 3)), - depth=depths[3], - num_heads=num_heads[3], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:3]):sum(depths[:3 + 1])], - norm_layer=norm_layer, - tp_group=tp_group - ) - - if len(l3_dp_ranks) > 1: - dp_ranks = tuple(l3_dp_ranks) - if dp_ranks not in _dp_reducer: - _dp_reducer[dp_ranks] = Reducer(dp_ranks) - for param in self.basic_layer3.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.merging2.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - if len(l3_dp_ranks) > 1: - dp_ranks = tuple(l3_dp_ranks) - for param in self.norm.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.head.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, x): - x = self.patch_embed(x) - x = self.pos_drop(x) - - CudaTimer().start('basic_layer0') - x = self.basic_layer0(x) - CudaTimer().start('adapter') - x = self.adapter01(x) - CudaTimer().stop('adapter') - x = self.merging0(x) - CudaTimer().stop('basic_layer0') - - CudaTimer().start('basic_layer1') - x = self.basic_layer1(x) - CudaTimer().start('adapter') - x = self.adapter12(x) - CudaTimer().stop('adapter') - x = self.merging1(x) - CudaTimer().stop('basic_layer1') - - CudaTimer().start('basic_layer2') - x = self.basic_layer2(x) - CudaTimer().start('adapter') - x = self.adapter23(x) - CudaTimer().stop('adapter') - x = self.merging2(x) - CudaTimer().stop('basic_layer2') - - CudaTimer().start('basic_layer3') - x = self.basic_layer3(x) - CudaTimer().stop('basic_layer3') - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C L - x = torch.flatten(x, 1) - - x = self.head(x) - return x - - -def train(args, pconfigs): - - # dim_head is always 32 - - # img resolution, windows size: 224, 384, 518, 640 - C, H, W, window_size = [3, 224, 224, 7] - # C, H, W, window_size = [3, 384, 384, 12] - # C, H, W, window_size = [3, 518, 518, ?] - # C, H, W, window_size = [3, 640, 640, 20] - - # image batch size - N = args.bs - - # Swin-Tiny - # embed_dim, depths, num_heads = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24] - # ] - - # SwinV2-B: 87 M - # embed_dim, depths, num_heads = [ - # 128, [2, 2, 18, 2], [4, 8, 16, 32] - # ] - - # SwinV2-L: 196 M - # embed_dim, depths, num_heads = [ - # 192, [2, 2, 18, 2], [6, 12, 24, 48] - # ] - - # SwinV2-H: 657 M - embed_dim, depths, num_heads = [ - 352, [2, 2, 18, 2], [11, 22, 44, 88] - ] - - # # SwinV2-H modified: 782 M - embed_dim, depths, num_heads = [ - 384, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # # head dim 32 -> 48 - # embed_dim, depths, num_heads = [ - # 576, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 64 -- too much - # embed_dim, depths, num_heads = [ - # 768, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 80 - # embed_dim, depths, num_heads = [ - # 960, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 96 - # embed_dim, depths, num_heads = [ - # 1152, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # # head dim 32 -> 112 - # embed_dim, depths, num_heads = [ - # 1344, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 128 - # embed_dim, depths, num_heads = [ - # 1536, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 144 - # embed_dim, depths, num_heads = [ - # 1728, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 160 - # embed_dim, depths, num_heads = [ - # 1920, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - - - model = SwinTransformer(img_size = H, - embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size, - pconfigs = pconfigs) - if args.fp16: - print_each_rank('use half precision') - model = model.half() - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - model = model.cuda() - memory_summary() - - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [N // args.dp, C, H, W]) - - if args.fp16: - data_buff = [[e.half() for e in data] for data in dataloader.datas] - dataloader.datas = data_buff - - def train_iter(model, dataloader): - img = next(dataloader) - loss = model(img) - loss = torch.sum(loss) - loss.backward() - CudaTimer().start('dp_allreduce') - for ranks in _dp_reducer: - reducer = _dp_reducer[ranks] - reducer.allreduce() - CudaTimer().stop('dp_allreduce') - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - # start training - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - span = 0 - iter_num = 60 - for step in range(iter_num): - if step >= 20: - torch.cuda.synchronize() - start = time.time() - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 20: - torch.cuda.synchronize() - stop = time.time() - span += (stop - start) * 1000 - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = span / (iter_num-20) - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - CudaTimer().print_all(times=iter_num-20) - - -if __name__ == '__main__': - - cube.init() - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--layer0', type=int, nargs='+', - help='data, tensor parallel config') - parser.add_argument('--layer1', type=int, nargs='+', - help='data, tensor parallel config') - parser.add_argument('--layer2', type=int, nargs='+', - help='data, tensor parallel config') - parser.add_argument('--layer3', type=int, nargs='+', - help='data, tensor parallel config') - parser.add_argument('--bs', type=int, default=1, - help='bs') - parser.add_argument('--fp16', action='store_true', dest='fp16') - args = parser.parse_args() - - assert len(args.layer0) == 2 - assert len(args.layer1) == 2 - assert len(args.layer2) == 2 - assert len(args.layer3) == 2 - - # data parallel should be same - args.dp = args.layer0[0] - - pconfigs = [ - dict(layer_id=0, dp=args.layer0[0], tp=args.layer0[1]), # basic layer 0 - dict(layer_id=1, dp=args.layer1[0], tp=args.layer1[1]), # basic layer 1 - dict(layer_id=2, dp=args.layer2[0], tp=args.layer2[1]), # basic layer 2 - dict(layer_id=3, dp=args.layer3[0], tp=args.layer3[1]), # basic layer 3 - ] - - print_each_rank(pconfigs, rank_only=0) - train(args, pconfigs) diff --git a/handcraft/swin/swin_dwt.py b/handcraft/swin/swin_dwt.py deleted file mode 100644 index e6f2f081..00000000 --- a/handcraft/swin/swin_dwt.py +++ /dev/null @@ -1,979 +0,0 @@ - -# -------------------------------------------------------- -# Modified from Swin-Transformer Repo -""" -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_dwt.py --bs 8 \ - --layer0 1 4 1 \ - --layer1 1 4 1 \ - --layer2 1 1 4 \ - --layer3 1 1 4 - -""" -# -------------------------------------------------------- - -from typing import Dict, Optional, Tuple -import torch -import torch.nn as nn -import argparse -import time - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.reducer import Reducer - -from handcraft.swin.layers import ColumnParallelLinear, RowParallelLinear - - -_wp_reducer: Dict[Tuple[int], Reducer] = dict() -_dp_reducer: Dict[Tuple[int], Reducer] = dict() - - -def setup_device_group(tp: int, wp: int, dp: int, layer_id: int): - """ - Layer wise device group initialize - - Returns: - - """ - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - tp_size, tp_group_nums = tp, ndevs // tp - wp_size, wp_group_nums = wp, ndevs // wp - dp_size, dp_group_nums = dp, ndevs // dp - - if not tp_size * wp_size * dp_size == ndevs: - raise RuntimeError("Expected all devices are used") - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize tensor parallel groups - for i in range(tp_group_nums): - ranks = list(range(i * tp_size, (i + 1) * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - # initialize wp parallel group - all_wp_parallel_group_ranks = list() - for i in range(dp_size): - start_rank = i * dp_group_nums - end_rank = (i + 1) * dp_group_nums - for j in range(tp_size): - ranks = list(range(start_rank + j, end_rank, tp_size)) - all_wp_parallel_group_ranks.append(ranks) - # initialize groups - group = devs.get_group(ranks) - if myrank in ranks: - wp_ranks = ranks - _wp_reducer[tuple(ranks)] = Reducer(ranks) - print_each_rank(f'layer {layer_id}: initialzed window parallel group: {wp_ranks}', rank_only=myrank) - - # initialize data parallel groups - start_rank = 0 - end_rank = ndevs - for i in range(wp_size * tp_size): - ranks = list(range(i, ndevs, wp_size * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - _dp_reducer[tuple(ranks)] = Reducer(ranks) - print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) - return tp_ranks, wp_ranks, dp_ranks - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class MegatronMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tp_group=-1): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=tp_group) - self.act = act_layer() - # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=True, tp_group=tp_group) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class MegatronWindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., tp_group=-1): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.global_num_heads = num_heads - - tp_world_size = torch.distributed.get_world_size(group=tp_group) - if num_heads % tp_world_size != 0: - raise RuntimeError(f'detecting un-even num head {num_heads} partition to {tp_world_size}') - self.num_heads = num_heads // tp_world_size - - self.dim_heads = dim // self.global_num_heads - self.scale = qk_scale or self.dim_heads ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # relative position index - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - self.register_buffer('relative_position_index', relative_position_index) - - # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - # print(f'qkv embed dim: {dim}') - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) - self.attn_drop = nn.Dropout(attn_drop) - # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=True, tp_group=tp_group) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, - tp_group=-1, wp_plans=-1): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = MegatronWindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - tp_group=tp_group) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MegatronMlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, drop=drop, - tp_group=tp_group - ) - - self.wp_group, self.wp_nH_ranks, self.wp_nW_ranks = wp_plans - self.use_wp = torch.distributed.get_world_size(self.wp_group) != 1 - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - if self.use_wp: - shifted_x = cube.runtime.function.roll_grid_parallel( - x, (-self.shift_size, -self.shift_size), (1,2), - self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group - ) - else: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - if self.use_wp: - x = cube.runtime.function.roll_grid_parallel( - shifted_x, (self.shift_size, self.shift_size), (1,2), - self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group - ) - else: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - x = x + drop_path(ffn, self.drop_path_p) - - return x - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, - tp=1, wp=1, dp=1, layer_id=-1): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - self.resource = cube.runtime.resource.EnvResource() - tp_ranks, wp_ranks, dp_ranks = setup_device_group(tp, wp, dp, layer_id) - tp_group = DeviceGroup().get_group(tp_ranks) - wp_group = DeviceGroup().get_group(wp_ranks) - wp_nH_ranks = [-1] - wp_nW_ranks = [-1] - - # window parallel - self.wp_resolution = input_resolution - if wp > 1: - H, W = self.input_resolution - nH = 1 - nW = wp // nH - while nH <= nW: - if H % nH != 0 or W % nW != 0: - nW = nW // 2 - nH = int(nH * 2) - else: - break - if nH > nW: - raise RuntimeError(f"layer {layer_id}: Cannot window partition plan") - print_each_rank(f"layer {layer_id}: Find partition plan: H{H} // {nH}, W{W} // {nW}") - self.wp_resolution = (H // nH, W // nW) - self.wp_group = wp_group - # wp_group multi dim shift ranks - for i in range(nH): - ranks = list(range(i * nW, (i + 1) * nW)) - if torch.distributed.get_rank(wp_group) in ranks: - wp_nW_ranks = ranks - break - for i in range(nW): - ranks = list(range(i, wp, nW)) - if torch.distributed.get_rank(wp_group) in ranks: - wp_nH_ranks = ranks - break - assert wp_nH_ranks != [-1] - assert wp_nW_ranks != [-1] - print_each_rank(f'window parallel nH group ranks: {wp_nH_ranks}') - print_each_rank(f'window parallel nW group ranks: {wp_nW_ranks}') - - # build blocks - self.blocks = nn.ModuleList() - for i in range(depth): - block = SwinTransformerBlock( - dim=dim, input_resolution=self.wp_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - tp_group=tp_group, wp_plans=(wp_group, wp_nH_ranks, wp_nW_ranks) - ) - self.blocks.append(block) - - self.wp_preprocess = False - self.wp_postprocess = False - if wp > 1: - for param in self.blocks.parameters(): - _wp_reducer[tuple(wp_ranks)].add_param(param) - self.wp_preprocess = True - self.wp_postprocess = True - - def forward(self, x): - if self.wp_preprocess: - oH, oW = self.input_resolution - pH, pW = self.wp_resolution - x = x.view(-1, oH, oW, self.dim) - x = cube.runtime.function.grid_partition(x, oH // pH, oW // pW, group=self.wp_group) - x = x.view(-1, pH * pW, self.dim).contiguous() - - for blk in self.blocks: - x = blk(x) - - if self.wp_postprocess: - oH, oW = self.input_resolution - pH, pW = self.wp_resolution - x = x.view(-1, pH, pW, self.dim) - x = cube.runtime.function.grid_collection(x, oH // pH, oW // pW, group=self.wp_group) - x = x.view(-1, oH * oW, self.dim) - return x - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - - # ====================== depth 0 =========================== - pconfig = pconfigs[0] - input_resolution = ( - patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) - ) - self.basic_layer0 = BasicLayer( - dim=int(embed_dim * 2 ** 0), - input_resolution=input_resolution, - depth=depths[0], - num_heads=num_heads[0], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:0]):sum(depths[:0 + 1])], - norm_layer=norm_layer, - **pconfig, - ) - - self.merging0 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 0), norm_layer=norm_layer - ) - - # ====================== depth 1 =========================== - pconfig = pconfigs[1] - input_resolution = ( - patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) - ) - self.basic_layer1 = BasicLayer( - dim=int(embed_dim * 2 ** 1), - input_resolution=input_resolution, - depth=depths[1], - num_heads=num_heads[1], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:1]):sum(depths[:1 + 1])], - norm_layer=norm_layer, - **pconfig, - ) - - self.merging1 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 1), norm_layer=norm_layer - ) - - - # ====================== depth 2 =========================== - pconfig = pconfigs[2] - input_resolution = ( - patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) - ) - self.basic_layer2 = BasicLayer( - dim=int(embed_dim * 2 ** 2), - input_resolution=input_resolution, - depth=depths[2], - num_heads=num_heads[2], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:2]):sum(depths[:2 + 1])], - norm_layer=norm_layer, - **pconfig - ) - - self.merging2 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 2), norm_layer=norm_layer - ) - - # ====================== depth 3 =========================== - pconfig = pconfigs[3] - self.basic_layer3 = BasicLayer( - dim=int(embed_dim * 2 ** 3), - input_resolution=(patches_resolution[0] // (2 ** 3), - patches_resolution[1] // (2 ** 3)), - depth=depths[3], - num_heads=num_heads[3], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:3]):sum(depths[:3 + 1])], - norm_layer=norm_layer, - **pconfig - ) - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, x): - x = self.patch_embed(x) - x = self.pos_drop(x) - - CudaTimer().start('basic_layer0') - x = self.basic_layer0(x) - CudaTimer().stop('basic_layer0') - x = self.merging0(x) - CudaTimer().start('basic_layer1') - x = self.basic_layer1(x) - CudaTimer().stop('basic_layer1') - x = self.merging1(x) - CudaTimer().start('basic_layer2') - x = self.basic_layer2(x) - CudaTimer().stop('basic_layer2') - x = self.merging2(x) - CudaTimer().start('basic_layer3') - x = self.basic_layer3(x) - CudaTimer().stop('basic_layer3') - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C L - x = torch.flatten(x, 1) - - x = self.head(x) - return x - - -def train(args, pconfigs): - - # dim_head is always 32 - - # img resolution, windows size: 224, 384, 518, 640 - C, H, W, window_size = [3, 224, 224, 7] - # C, H, W, window_size = [3, 384, 384, 12] - # C, H, W, window_size = [3, 518, 518, ?] - # C, H, W, window_size = [4, 640, 640, 20] - - # image batch size - N = args.bs - - # Swin-Tiny - # embed_dim, depths, num_heads = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24] - # ] - - # SwinV2-B: 87 M - # embed_dim, depths, num_heads = [ - # 128, [2, 2, 18, 2], [4, 8, 16, 32] - # ] - - # SwinV2-L: 196 M - # embed_dim, depths, num_heads = [ - # 192, [2, 2, 18, 2], [6, 12, 24, 48] - # ] - - # SwinV2-H: 657 M - # embed_dim, depths, num_heads = [ - # 352, [2, 2, 18, 2], [11, 22, 44, 88] - # ] - - # SwinV2-H modified: 782 M - embed_dim, depths, num_heads = [ - 384, [2, 2, 18, 2], [12, 24, 48, 96] - ] - - # SwinV2-G: 2.5B Model - # embed_dim, depths, num_heads = [ - # 512, [2, 2, 42, 2], [16, 32, 64, 128] - # ] - - # 895.7 M Model - # embed_dim, depths, num_heads = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96] - # ] - - - # 2.01B model - # embed_dim, depths, num_heads = [ - # 576, [2, 2, 22, 2], [12, 24, 48, 96] - # ] - - - model = SwinTransformer(img_size = H, - embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size, - pconfigs = pconfigs) - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - if args.fp16: - print_each_rank('use half model') - model = model.half() - model = model.cuda() - memory_summary() - - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [N // args.dp, C, H, W]) - - if args.fp16: - data_buff = [[e.half() for e in data] for data in dataloader.datas] - dataloader.datas = data_buff - - if args.dp > 1: - assert len(_dp_reducer) == 1 - reducer = None - for ranks in _dp_reducer: - reducer = _dp_reducer[ranks] - for param in model.parameters(): - reduced = False - for wp_ranks in _wp_reducer: - if param in _wp_reducer[wp_ranks]._params: - reduced = True - break - if not reduced: - reducer.add_param(param) - - def train_iter(model, dataloader): - img = next(dataloader) - loss = model(img) - loss = torch.sum(loss) - loss.backward() - CudaTimer().start('wp_allreduce') - for ranks in _wp_reducer: - reducer = _wp_reducer[ranks] - reducer.allreduce() - CudaTimer().stop('wp_allreduce') - CudaTimer().start('dp_allreduce') - for ranks in _dp_reducer: - reducer = _dp_reducer[ranks] - reducer.allreduce() - CudaTimer().stop('dp_allreduce') - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - # start training - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - span = 0 - iter_num = 60 - for step in range(iter_num): - if step >= 20: - torch.cuda.synchronize() - start = time.time() - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 20: - torch.cuda.synchronize() - stop = time.time() - span += (stop - start) * 1000 - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = span / (iter_num-20) - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - CudaTimer().print_all(times=iter_num-20) - - -if __name__ == '__main__': - - cube.init() - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--layer0', type=int, nargs='+', - help='data, window tensor parallel config') - parser.add_argument('--layer1', type=int, nargs='+', - help='data, window tensor parallel config') - parser.add_argument('--layer2', type=int, nargs='+', - help='data, window tensor parallel config') - parser.add_argument('--layer3', type=int, nargs='+', - help='data, window tensor parallel config') - parser.add_argument('--bs', type=int, default=1, - help='bs') - parser.add_argument('--fp16', action='store_true', dest='fp16') - args = parser.parse_args() - - assert len(args.layer0) == 3 - assert len(args.layer1) == 3 - assert len(args.layer2) == 3 - assert len(args.layer3) == 3 - - # data parallel should be same - assert args.layer0[0] == args.layer1[0] and args.layer1[0] == args.layer2[0] and args.layer2[0] == args.layer3[0] - args.dp = args.layer0[0] - - pconfigs = [ - dict(layer_id=0, dp=args.layer0[0], wp=args.layer0[1], tp=args.layer0[2]), # basic layer 0 - dict(layer_id=1, dp=args.layer1[0], wp=args.layer1[1], tp=args.layer1[2]), # basic layer 1 - dict(layer_id=2, dp=args.layer2[0], wp=args.layer2[1], tp=args.layer2[2]), # basic layer 2 # prob at 8:1? - dict(layer_id=3, dp=args.layer3[0], wp=args.layer3[1], tp=args.layer3[2]), # basic layer 3 - ] - - # pconfigs = [ - # dict(layer_id=0, tp=4, wp=1, dp=args.dp), # basic layer 0 - # dict(layer_id=1, tp=4, wp=1, dp=args.dp), # basic layer 1 - # dict(layer_id=2, tp=4, wp=1, dp=args.dp), # basic layer 2 # prob at 8:1? - # dict(layer_id=3, tp=4, wp=1, dp=args.dp), # basic layer 3 - # ] - - print_each_rank(pconfigs, rank_only=0) - train(args, pconfigs) diff --git a/handcraft/swin/swin_dwt_infer.py b/handcraft/swin/swin_dwt_infer.py deleted file mode 100644 index 5c3c2e24..00000000 --- a/handcraft/swin/swin_dwt_infer.py +++ /dev/null @@ -1,954 +0,0 @@ - -# -------------------------------------------------------- -# Modified from Swin-Transformer Repo -""" -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - handcraft/swin/swin_dwt.py --bs 8 \ - --layer0 1 4 1 \ - --layer1 1 4 1 \ - --layer2 1 1 4 \ - --layer3 1 1 4 - -""" -# -------------------------------------------------------- - -from typing import Dict, Optional, Tuple -import torch -import torch.nn as nn -import argparse -import time - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.reducer import Reducer - -from handcraft.swin.layers import ColumnParallelLinear, RowParallelLinear - - -_wp_reducer: Dict[Tuple[int], Reducer] = dict() -_dp_reducer: Dict[Tuple[int], Reducer] = dict() - - -def setup_device_group(tp: int, wp: int, dp: int, layer_id: int): - """ - Layer wise device group initialize - - Returns: - - """ - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - tp_size, tp_group_nums = tp, ndevs // tp - wp_size, wp_group_nums = wp, ndevs // wp - dp_size, dp_group_nums = dp, ndevs // dp - - if not tp_size * wp_size * dp_size == ndevs: - raise RuntimeError("Expected all devices are used") - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize tensor parallel groups - for i in range(tp_group_nums): - ranks = list(range(i * tp_size, (i + 1) * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - # initialize wp parallel group - all_wp_parallel_group_ranks = list() - for i in range(dp_size): - start_rank = i * dp_group_nums - end_rank = (i + 1) * dp_group_nums - for j in range(tp_size): - ranks = list(range(start_rank + j, end_rank, tp_size)) - all_wp_parallel_group_ranks.append(ranks) - # initialize groups - group = devs.get_group(ranks) - if myrank in ranks: - wp_ranks = ranks - _wp_reducer[tuple(ranks)] = Reducer(ranks) - print_each_rank(f'layer {layer_id}: initialzed window parallel group: {wp_ranks}', rank_only=myrank) - - # initialize data parallel groups - start_rank = 0 - end_rank = ndevs - for i in range(wp_size * tp_size): - ranks = list(range(i, ndevs, wp_size * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - _dp_reducer[tuple(ranks)] = Reducer(ranks) - print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) - return tp_ranks, wp_ranks, dp_ranks - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class MegatronMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tp_group=-1): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=tp_group) - self.act = act_layer() - # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=True, tp_group=tp_group) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class MegatronWindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., tp_group=-1): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.global_num_heads = num_heads - - tp_world_size = torch.distributed.get_world_size(group=tp_group) - if num_heads % tp_world_size != 0: - raise RuntimeError(f'detecting un-even num head {num_heads} partition to {tp_world_size}') - self.num_heads = num_heads // tp_world_size - - self.dim_heads = dim // self.global_num_heads - self.scale = qk_scale or self.dim_heads ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # relative position index - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - self.register_buffer('relative_position_index', relative_position_index) - - # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - # print(f'qkv embed dim: {dim}') - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) - self.attn_drop = nn.Dropout(attn_drop) - # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=True, tp_group=tp_group) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, - tp_group=-1, wp_plans=-1, layer_id=-1): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - self.wp_group, self.wp_nH_ranks, self.wp_nW_ranks = wp_plans - # if min(self.input_resolution) <= self.window_size: - # # if window size is larger than input resolution, we don't partition windows - # self.shift_size = 0 - # self.window_size = min(self.input_resolution) - if layer_id == 3: - print('set shift size to 0') - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = MegatronWindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - tp_group=tp_group) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MegatronMlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, drop=drop, - tp_group=tp_group - ) - - self.use_wp = torch.distributed.get_world_size(self.wp_group) != 1 - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - if self.use_wp: - shifted_x = cube.runtime.function.roll_grid_parallel( - x, (-self.shift_size, -self.shift_size), (1,2), - self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group - ) - else: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - if self.use_wp: - x = cube.runtime.function.roll_grid_parallel( - shifted_x, (self.shift_size, self.shift_size), (1,2), - self.wp_nH_ranks, self.wp_nW_ranks, self.wp_group - ) - else: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - x = x + drop_path(ffn, self.drop_path_p) - - return x - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, - tp=1, wp=1, dp=1, layer_id=-1): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - self.resource = cube.runtime.resource.EnvResource() - tp_ranks, wp_ranks, dp_ranks = setup_device_group(tp, wp, dp, layer_id) - tp_group = DeviceGroup().get_group(tp_ranks) - wp_group = DeviceGroup().get_group(wp_ranks) - wp_nH_ranks = [-1] - wp_nW_ranks = [-1] - - # window parallel - self.wp_resolution = input_resolution - if wp > 1: - H, W = self.input_resolution - nH = 1 - nW = wp // nH - while nH <= nW: - if H % nH != 0 or W % nW != 0 or (H // nH) % window_size != 0 or (W // nW) % window_size != 0: - nW = nW // 2 - nH = int(nH * 2) - else: - break - if nH > nW: - raise RuntimeError(f"layer {layer_id}: Cannot window partition plan") - print_each_rank(f"layer {layer_id}: Find partition plan: H{H} // {nH}, W{W} // {nW}") - self.wp_resolution = (H // nH, W // nW) - self.wp_group = wp_group - # wp_group multi dim shift ranks - for i in range(nH): - ranks = list(range(i * nW, (i + 1) * nW)) - if torch.distributed.get_rank(wp_group) in ranks: - wp_nW_ranks = ranks - break - for i in range(nW): - ranks = list(range(i, wp, nW)) - if torch.distributed.get_rank(wp_group) in ranks: - wp_nH_ranks = ranks - break - assert wp_nH_ranks != [-1] - assert wp_nW_ranks != [-1] - print_each_rank(f'window parallel nH group ranks: {wp_nH_ranks}') - print_each_rank(f'window parallel nW group ranks: {wp_nW_ranks}') - - # build blocks - self.blocks = nn.ModuleList() - for i in range(depth): - block = SwinTransformerBlock( - dim=dim, input_resolution=self.wp_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - tp_group=tp_group, wp_plans=(wp_group, wp_nH_ranks, wp_nW_ranks), - layer_id = layer_id - ) - self.blocks.append(block) - - self.wp_preprocess = False - self.wp_postprocess = False - if wp > 1: - for param in self.blocks.parameters(): - _wp_reducer[tuple(wp_ranks)].add_param(param) - self.wp_preprocess = True - self.wp_postprocess = True - - def forward(self, x): - if self.wp_preprocess: - oH, oW = self.input_resolution - pH, pW = self.wp_resolution - x = x.view(-1, oH, oW, self.dim) - x = cube.runtime.function.grid_partition(x, oH // pH, oW // pW, group=self.wp_group) - x = x.view(-1, pH * pW, self.dim).contiguous() - - for blk in self.blocks: - x = blk(x) - - if self.wp_postprocess: - oH, oW = self.input_resolution - pH, pW = self.wp_resolution - x = x.view(-1, pH, pW, self.dim) - x = cube.runtime.function.grid_collection(x, oH // pH, oW // pW, group=self.wp_group) - x = x.view(-1, oH * oW, self.dim) - return x - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - - # ====================== depth 0 =========================== - pconfig = pconfigs[0] - input_resolution = ( - patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) - ) - self.basic_layer0 = BasicLayer( - dim=int(embed_dim * 2 ** 0), - input_resolution=input_resolution, - depth=depths[0], - num_heads=num_heads[0], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:0]):sum(depths[:0 + 1])], - norm_layer=norm_layer, - **pconfig, - ) - - self.merging0 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 0), norm_layer=norm_layer - ) - - # ====================== depth 1 =========================== - pconfig = pconfigs[1] - input_resolution = ( - patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) - ) - self.basic_layer1 = BasicLayer( - dim=int(embed_dim * 2 ** 1), - input_resolution=input_resolution, - depth=depths[1], - num_heads=num_heads[1], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:1]):sum(depths[:1 + 1])], - norm_layer=norm_layer, - **pconfig, - ) - - self.merging1 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 1), norm_layer=norm_layer - ) - - - # ====================== depth 2 =========================== - pconfig = pconfigs[2] - input_resolution = ( - patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) - ) - self.basic_layer2 = BasicLayer( - dim=int(embed_dim * 2 ** 2), - input_resolution=input_resolution, - depth=depths[2], - num_heads=num_heads[2], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:2]):sum(depths[:2 + 1])], - norm_layer=norm_layer, - **pconfig - ) - - self.merging2 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 2), norm_layer=norm_layer - ) - - # ====================== depth 3 =========================== - pconfig = pconfigs[3] - self.basic_layer3 = BasicLayer( - dim=int(embed_dim * 2 ** 3), - input_resolution=(patches_resolution[0] // (2 ** 3), - patches_resolution[1] // (2 ** 3)), - depth=depths[3], - num_heads=num_heads[3], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:3]):sum(depths[:3 + 1])], - norm_layer=norm_layer, - **pconfig - ) - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, x): - x = self.patch_embed(x) - x = self.pos_drop(x) - - CudaTimer().start('basic_layer0') - x = self.basic_layer0(x) - CudaTimer().stop('basic_layer0') - x = self.merging0(x) - CudaTimer().start('basic_layer1') - x = self.basic_layer1(x) - CudaTimer().stop('basic_layer1') - x = self.merging1(x) - CudaTimer().start('basic_layer2') - x = self.basic_layer2(x) - CudaTimer().stop('basic_layer2') - x = self.merging2(x) - CudaTimer().start('basic_layer3') - x = self.basic_layer3(x) - CudaTimer().stop('basic_layer3') - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C L - x = torch.flatten(x, 1) - - x = self.head(x) - return x - - -def train(args, pconfigs): - - # dim_head is always 32 - - # img resolution, windows size: 224, 384, 518, 640 - C, H, W, window_size = [3, 224, 224, 7] - # C, H, W, window_size = [3, 384, 384, 12] - # C, H, W, window_size = [3, 518, 518, ?] - # C, H, W, window_size = [3, 640, 640, 20] - # C, H, W, window_size = [3, 1536, 1536, 48] - - # image batch size - N = args.bs - - # Swin-Tiny - # embed_dim, depths, num_heads = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24] - # ] - - # SwinV2-B: 87 M - # embed_dim, depths, num_heads = [ - # 128, [2, 2, 18, 2], [4, 8, 16, 32] - # ] - - # SwinV2-L: 196 M - # embed_dim, depths, num_heads = [ - # 192, [2, 2, 18, 2], [6, 12, 24, 48] - # ] - - # SwinV2-H: 657 M - # embed_dim, depths, num_heads = [ - # 352, [2, 2, 18, 2], [11, 22, 44, 88] - # ] - - # SwinV2-H modified: 782 M - embed_dim, depths, num_heads = [ - 384, [2, 2, 18, 2], [12, 24, 48, 96] - ] - - # SwinV2-G: 2.5B Model - # embed_dim, depths, num_heads = [ - # 512, [2, 2, 42, 2], [16, 32, 64, 128] - # ] - - # 895.7 M Model - # embed_dim, depths, num_heads = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96] - # ] - - - # 2.01B model - # embed_dim, depths, num_heads = [ - # 576, [2, 2, 22, 2], [12, 24, 48, 96] - # ] - - print_each_rank( - f'Test setting: Resolution {H}, Embed {embed_dim}, depths: {depths}, heads: {num_heads}', - rank_only=0 - ) - - - model = SwinTransformer(img_size = H, - embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size, - pconfigs = pconfigs) - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - if args.fp16: - print_each_rank('use half model') - model = model.half() - model = model.cuda() - memory_summary() - - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [N // args.dp, C, H, W]) - - if args.fp16: - data_buff = [[e.half() for e in data] for data in dataloader.datas] - dataloader.datas = data_buff - - model.eval() - def infer_iter(model, dataloader): - with torch.no_grad(): - img = next(dataloader) - loss = model(img) - - # start training - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - span = 0 - iter_num = 60 - for step in range(iter_num): - if step >= 20: - torch.cuda.synchronize() - start = time.time() - CudaTimer(enable=True).start('e2e') - infer_iter(model, dataloader) - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 20: - torch.cuda.synchronize() - stop = time.time() - span += (stop - start) * 1000 - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = span / (iter_num-20) - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - CudaTimer().print_all(times=iter_num-20) - - -if __name__ == '__main__': - - cube.init() - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--layer0', type=int, nargs='+', - help='data, window tensor parallel config') - parser.add_argument('--layer1', type=int, nargs='+', - help='data, window tensor parallel config') - parser.add_argument('--layer2', type=int, nargs='+', - help='data, window tensor parallel config') - parser.add_argument('--layer3', type=int, nargs='+', - help='data, window tensor parallel config') - parser.add_argument('--bs', type=int, default=1, - help='bs') - parser.add_argument('--fp16', action='store_true', dest='fp16') - args = parser.parse_args() - - assert len(args.layer0) == 3 - assert len(args.layer1) == 3 - assert len(args.layer2) == 3 - assert len(args.layer3) == 3 - - # data parallel should be same - assert args.layer0[0] == args.layer1[0] and args.layer1[0] == args.layer2[0] and args.layer2[0] == args.layer3[0] - args.dp = args.layer0[0] - - pconfigs = [ - dict(layer_id=0, dp=args.layer0[0], wp=args.layer0[1], tp=args.layer0[2]), # basic layer 0 - dict(layer_id=1, dp=args.layer1[0], wp=args.layer1[1], tp=args.layer1[2]), # basic layer 1 - dict(layer_id=2, dp=args.layer2[0], wp=args.layer2[1], tp=args.layer2[2]), # basic layer 2 # prob at 8:1? - dict(layer_id=3, dp=args.layer3[0], wp=args.layer3[1], tp=args.layer3[2]), # basic layer 3 - ] - - print_each_rank(pconfigs, rank_only=0) - train(args, pconfigs) diff --git a/handcraft/swin/swin_flexflow.py b/handcraft/swin/swin_flexflow.py deleted file mode 100644 index 3dc1f054..00000000 --- a/handcraft/swin/swin_flexflow.py +++ /dev/null @@ -1,993 +0,0 @@ - -# -------------------------------------------------------- -# Modified from Swin-Transformer Repo -""" -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/swin/swin_flexflow.py --bs 8 \ - --layer0 8 1 \ - --layer1 8 1 \ - --layer2 1 8 \ - --layer3 1 8 -""" -# -------------------------------------------------------- - -from typing import Dict, Optional, Tuple -import torch -import torch.nn as nn -import argparse -import time - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.reducer import Reducer - -from handcraft.swin.layers import ColumnParallelLinear, DPtoTP, ValueTPtoEleDP, RowParallelLinear, TPtoDP - -_dp_reducer: Dict[Tuple[int], Reducer] = dict() - - -def setup_device_group(tp: int, dp: int, layer_id: int): - """ - Layer wise device group initialize - - Returns: - - """ - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - if not tp * dp == ndevs: - raise RuntimeError("Expected same device number") - - assert tp == 1 or dp == 1, "Currently hybrid not supported" - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize tensor parallel groups - for i in range(dp): - ranks = list(range(i * tp, (i + 1) * tp)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - # initialize data parallel groups - for i in range(tp): - ranks = list(range(i, ndevs, tp)) - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) - return tp_ranks, dp_ranks - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class MegatronMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tp_group=-1): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=tp_group) - self.act = act_layer() - # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=False, tp_group=tp_group) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class MegatronWindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., tp_group=-1): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.global_num_heads = num_heads - - tp_world_size = torch.distributed.get_world_size(group=tp_group) - if num_heads % tp_world_size != 0: - raise RuntimeError(f'detecting un-even num head {num_heads} partition to {tp_world_size}') - self.num_heads = num_heads // tp_world_size - - self.dim_heads = dim // self.global_num_heads - self.scale = qk_scale or self.dim_heads ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # relative position index - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - self.register_buffer('relative_position_index', relative_position_index) - - # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=tp_group) - self.attn_drop = nn.Dropout(attn_drop) - # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=False, tp_group=tp_group) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, - tp_group=-1): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = MegatronWindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - tp_group=tp_group) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MegatronMlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, drop=drop, - tp_group=tp_group - ) - - self.partition_all_op = True - if self.partition_all_op and torch.distributed.get_world_size(tp_group) > 1: - print('> enabled all-op partitioning...') - self.val_tp_to_dp = ValueTPtoEleDP(tp_group) - self.tp_to_dp = TPtoDP(tp_group) - self.dp_to_tp = DPtoTP(tp_group) - else: - self.tp_to_dp = None - self.dp_to_tp = None - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - - if self.partition_all_op and self.tp_to_dp is not None: - x = self.val_tp_to_dp(x) - shortcut = self.tp_to_dp(shortcut) - - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - - if self.partition_all_op and self.dp_to_tp is not None: - x = self.dp_to_tp(x) - - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - - if self.partition_all_op and self.tp_to_dp is not None: - x = self.val_tp_to_dp(x) - ffn = self.tp_to_dp(ffn) - - x = x + drop_path(ffn, self.drop_path_p) - - if self.partition_all_op and self.dp_to_tp is not None: - x = self.dp_to_tp(x) - - return x - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, - tp_group=-1, layer_id=-1): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - # build blocks - self.blocks = nn.ModuleList() - for i in range(depth): - block = SwinTransformerBlock( - dim=dim, input_resolution=self.input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - tp_group=tp_group, - ) - self.blocks.append(block) - - def forward(self, x): - for blk in self.blocks: - x = blk(x) - return x - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, pconfigs=None, fp16=False, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - - # ====================== depth 0 =========================== - pconfig = pconfigs[0] - l0_tp_ranks, l0_dp_ranks = setup_device_group(**pconfig) - tp_group = DeviceGroup().get_group(l0_tp_ranks) - - input_resolution = ( - patches_resolution[0] // (2 ** 0), patches_resolution[1] // (2 ** 0) - ) - self.basic_layer0 = BasicLayer( - dim=int(embed_dim * 2 ** 0), - input_resolution=input_resolution, - depth=depths[0], - num_heads=num_heads[0], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:0]):sum(depths[:0 + 1])], - norm_layer=norm_layer, - tp_group=tp_group, - ) - - if len(l0_dp_ranks) > 1: - dp_ranks = tuple(l0_dp_ranks) - if dp_ranks not in _dp_reducer: - _dp_reducer[dp_ranks] = Reducer(dp_ranks) - for param in self.patch_embed.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.basic_layer0.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - - # ====================== depth 1 =========================== - pconfig = pconfigs[1] - l1_tp_ranks, l1_dp_ranks = setup_device_group(**pconfig) - tp_group = DeviceGroup().get_group(l1_tp_ranks) - - # adapter - if len(l0_dp_ranks) > 1 and len(l1_tp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter01 = DPtoTP(DeviceGroup().get_group(l0_dp_ranks)) - elif len(l0_tp_ranks) > 1 and len(l1_dp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter01 = TPtoDP(DeviceGroup().get_group(l0_tp_ranks)) - else: - self.adapter01 = torch.nn.Identity() - - self.merging0 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 0), norm_layer=norm_layer - ) - - input_resolution = ( - patches_resolution[0] // (2 ** 1), patches_resolution[1] // (2 ** 1) - ) - self.basic_layer1 = BasicLayer( - dim=int(embed_dim * 2 ** 1), - input_resolution=input_resolution, - depth=depths[1], - num_heads=num_heads[1], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:1]):sum(depths[:1 + 1])], - norm_layer=norm_layer, - tp_group=tp_group, - ) - - if len(l1_dp_ranks) > 1: - dp_ranks = tuple(l1_dp_ranks) - if dp_ranks not in _dp_reducer: - _dp_reducer[dp_ranks] = Reducer(dp_ranks) - for param in self.basic_layer1.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.merging0.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - - # ====================== depth 2 =========================== - pconfig = pconfigs[2] - l2_tp_ranks, l2_dp_ranks = setup_device_group(**pconfig) - tp_group = DeviceGroup().get_group(l2_tp_ranks) - - # adapter - if len(l1_dp_ranks) > 1 and len(l2_tp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter12 = DPtoTP(DeviceGroup().get_group(l1_dp_ranks)) - elif len(l1_tp_ranks) > 1 and len(l2_dp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter12 = TPtoDP(DeviceGroup().get_group(l1_tp_ranks)) - else: - self.adapter12 = torch.nn.Identity() - - - self.merging1 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 1), norm_layer=norm_layer - ) - - input_resolution = ( - patches_resolution[0] // (2 ** 2), patches_resolution[1] // (2 ** 2) - ) - self.basic_layer2 = BasicLayer( - dim=int(embed_dim * 2 ** 2), - input_resolution=input_resolution, - depth=depths[2], - num_heads=num_heads[2], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:2]):sum(depths[:2 + 1])], - norm_layer=norm_layer, - tp_group=tp_group - ) - - if len(l2_dp_ranks) > 1: - dp_ranks = tuple(l2_dp_ranks) - if dp_ranks not in _dp_reducer: - _dp_reducer[dp_ranks] = Reducer(dp_ranks) - for param in self.basic_layer2.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.merging1.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - # ====================== depth 3 =========================== - pconfig = pconfigs[3] - l3_tp_ranks, l3_dp_ranks = setup_device_group(**pconfig) - tp_group = DeviceGroup().get_group(l3_tp_ranks) - - # adapter - if len(l2_dp_ranks) > 1 and len(l3_tp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter23 = DPtoTP(DeviceGroup().get_group(l2_dp_ranks)) - elif len(l2_tp_ranks) > 1 and len(l3_dp_ranks) > 1: - print_each_rank('add dp to tp adapters') - self.adapter23 = TPtoDP(DeviceGroup().get_group(l2_tp_ranks)) - else: - self.adapter23 = torch.nn.Identity() - - self.merging2 = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** 2), norm_layer=norm_layer - ) - - self.basic_layer3 = BasicLayer( - dim=int(embed_dim * 2 ** 3), - input_resolution=(patches_resolution[0] // (2 ** 3), - patches_resolution[1] // (2 ** 3)), - depth=depths[3], - num_heads=num_heads[3], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:3]):sum(depths[:3 + 1])], - norm_layer=norm_layer, - tp_group=tp_group - ) - - if len(l3_dp_ranks) > 1: - dp_ranks = tuple(l3_dp_ranks) - if dp_ranks not in _dp_reducer: - _dp_reducer[dp_ranks] = Reducer(dp_ranks) - for param in self.basic_layer3.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.merging2.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - if len(l3_dp_ranks) > 1: - dp_ranks = tuple(l3_dp_ranks) - for param in self.norm.parameters(): - _dp_reducer[dp_ranks].add_param(param) - for param in self.head.parameters(): - _dp_reducer[dp_ranks].add_param(param) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, x): - x = self.patch_embed(x) - x = self.pos_drop(x) - - CudaTimer().start('basic_layer0') - x = self.basic_layer0(x) - CudaTimer().start('adapter') - x = self.adapter01(x) - CudaTimer().stop('adapter') - x = self.merging0(x) - CudaTimer().stop('basic_layer0') - - CudaTimer().start('basic_layer1') - x = self.basic_layer1(x) - CudaTimer().start('adapter') - x = self.adapter12(x) - CudaTimer().stop('adapter') - x = self.merging1(x) - CudaTimer().stop('basic_layer1') - - CudaTimer().start('basic_layer2') - x = self.basic_layer2(x) - CudaTimer().start('adapter') - x = self.adapter23(x) - CudaTimer().stop('adapter') - x = self.merging2(x) - CudaTimer().stop('basic_layer2') - - CudaTimer().start('basic_layer3') - x = self.basic_layer3(x) - CudaTimer().stop('basic_layer3') - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C L - x = torch.flatten(x, 1) - - x = self.head(x) - return x - - -def train(args, pconfigs): - - # dim_head is always 32 - - # img resolution, windows size: 224, 384, 518, 640 - C, H, W, window_size = [3, 224, 224, 7] - # C, H, W, window_size = [3, 384, 384, 12] - # C, H, W, window_size = [3, 518, 518, ?] - # C, H, W, window_size = [3, 640, 640, 20] - - # image batch size - N = args.bs - - # Swin-Tiny - # embed_dim, depths, num_heads = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24] - # ] - - # SwinV2-B: 87 M - # embed_dim, depths, num_heads = [ - # 128, [2, 2, 18, 2], [4, 8, 16, 32] - # ] - - # SwinV2-L: 196 M - # embed_dim, depths, num_heads = [ - # 192, [2, 2, 18, 2], [6, 12, 24, 48] - # ] - - # SwinV2-H: 657 M - embed_dim, depths, num_heads = [ - 352, [2, 2, 18, 2], [11, 22, 44, 88] - ] - - # # SwinV2-H modified: 782 M - embed_dim, depths, num_heads = [ - 384, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # # head dim 32 -> 48 - # embed_dim, depths, num_heads = [ - # 576, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 64 -- too much - # embed_dim, depths, num_heads = [ - # 768, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 80 - # embed_dim, depths, num_heads = [ - # 960, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 96 - # embed_dim, depths, num_heads = [ - # 1152, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # # head dim 32 -> 112 - # embed_dim, depths, num_heads = [ - # 1344, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 128 - # embed_dim, depths, num_heads = [ - # 1536, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 144 - # embed_dim, depths, num_heads = [ - # 1728, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - # # head dim 32 -> 160 - # embed_dim, depths, num_heads = [ - # 1920, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - - - model = SwinTransformer(img_size = H, - embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size, - pconfigs = pconfigs) - if args.fp16: - print_each_rank('use half precision') - model = model.half() - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - model = model.cuda() - memory_summary() - - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [N // args.dp, C, H, W]) - - if args.fp16: - data_buff = [[e.half() for e in data] for data in dataloader.datas] - dataloader.datas = data_buff - - def train_iter(model, dataloader): - img = next(dataloader) - loss = model(img) - loss = torch.sum(loss) - loss.backward() - CudaTimer().start('dp_allreduce') - for ranks in _dp_reducer: - reducer = _dp_reducer[ranks] - reducer.allreduce() - CudaTimer().stop('dp_allreduce') - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - # start training - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - span = 0 - iter_num = 60 - for step in range(iter_num): - if step >= 20: - torch.cuda.synchronize() - start = time.time() - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 20: - torch.cuda.synchronize() - stop = time.time() - span += (stop - start) * 1000 - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = span / (iter_num-20) - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - CudaTimer().print_all(times=iter_num-20) - - -if __name__ == '__main__': - - cube.init() - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--layer0', type=int, nargs='+', - help='data, tensor parallel config') - parser.add_argument('--layer1', type=int, nargs='+', - help='data, tensor parallel config') - parser.add_argument('--layer2', type=int, nargs='+', - help='data, tensor parallel config') - parser.add_argument('--layer3', type=int, nargs='+', - help='data, tensor parallel config') - parser.add_argument('--bs', type=int, default=1, - help='bs') - parser.add_argument('--fp16', action='store_true', dest='fp16') - args = parser.parse_args() - - assert len(args.layer0) == 2 - assert len(args.layer1) == 2 - assert len(args.layer2) == 2 - assert len(args.layer3) == 2 - - # data parallel should be same - args.dp = args.layer0[0] - - pconfigs = [ - dict(layer_id=0, dp=args.layer0[0], tp=args.layer0[1]), # basic layer 0 - dict(layer_id=1, dp=args.layer1[0], tp=args.layer1[1]), # basic layer 1 - dict(layer_id=2, dp=args.layer2[0], tp=args.layer2[1]), # basic layer 2 - dict(layer_id=3, dp=args.layer3[0], tp=args.layer3[1]), # basic layer 3 - ] - - print_each_rank(pconfigs, rank_only=0) - train(args, pconfigs) diff --git a/handcraft/swin/swin_hybrid.py b/handcraft/swin/swin_hybrid.py deleted file mode 100644 index 8ba3f967..00000000 --- a/handcraft/swin/swin_hybrid.py +++ /dev/null @@ -1,1086 +0,0 @@ - -# -------------------------------------------------------- -# Modified from Swin-Transformer Repo -""" -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - handcraft/swin/swin_hybrid.py \ - --layer0 8 1 1 \ - --layer1 8 1 1 \ - --layer2 8 1 1 \ - --layer3 8 1 1 \ - --gbs 1 --mbs 1 - -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=$NID \ - --master_addr=worker-0 \ - --master_port=8004 \ - --use_env \ - handcraft/swin/swin_hybrid.py \ - --layer0 2 8 1 \ - --layer1 2 8 1 \ - --layer2 2 8 1 \ - --layer3 2 8 1 \ - --gbs 8 --mbs 8 - -# V100-16GB: 8GPU: need checkpoint: 8 micro bs -""" -# -------------------------------------------------------- - -from typing import Optional, Dict, Tuple -import torch -import torch.nn as nn -import torch.utils.checkpoint as checkpoint - - -import cube -from cube.profiler import CudaTimer -from cube.runtime.device import DeviceGroup -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.adapter.reducer import Reducer - -import argparse - -from handcraft.swin.hybrid_schedule import scheduling_1f1b, is_last_stage -from handcraft.swin.layers import ColumnParallelLinear, RowParallelLinear, DPtoTP, TPtoDP - -from handcraft.swin.pmodule import ParallelModule - - -_dp_reducer: Dict[Tuple[int], Reducer] = dict() - - -def setup_device_group(pp: int, dp: int, tp: int, layer_id: int): - """ - Layer wise device group initialize - - Returns: - - """ - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - if not pp * tp * dp == ndevs: - raise RuntimeError("Expected same device number") - - # assert tp == 1 or dp == 1, "Currently hybrid not supported" - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize tensor parallel groups - for i in range(ndevs // tp): - ranks = list(range(i * tp, (i + 1) * tp)) - if len(ranks) > 1: - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - print_each_rank(f'layer {layer_id}: initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - # initialize data parallel groups - for i in range(pp): - start_rank = i * ndevs // pp - end_rank = (i+1) * ndevs // pp - for j in range(tp): - ranks = list(range(start_rank + j, end_rank, tp)) - if len(ranks) > 1: - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - _dp_reducer[tuple(dp_ranks)] = Reducer(dp_ranks) - print_each_rank(f'layer {layer_id}: initialzed data parallel group: {dp_ranks}', rank_only=myrank) - - # initialize pipeline parallel groups - for i in range(dp * tp): - ranks = list(range(i, ndevs, tp * dp)) - if len(ranks) > 1: - group = devs.get_group(ranks) - if myrank in ranks: - pp_ranks = ranks - print_each_rank(f'layer {layer_id}: initialized pipeline parallel group: {pp_ranks}') - - return pp_ranks, dp_ranks, tp_ranks - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class MegatronMlp(ParallelModule): - def __init__(self, in_features, hidden_features=None, out_features=None, - act_layer=nn.GELU, drop=0., - pp_ranks=-1, tp_ranks=-1, dp_ranks=-1): - super().__init__( - pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks - ) - out_features = out_features or in_features - hidden_features = hidden_features or in_features - # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, in_adapter=True, out_adapter=False, tp_group=self.tp_group) - self.act = act_layer() - # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, in_adapter=False, out_adapter=True, tp_group=self.tp_group) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class MegatronWindowAttention(ParallelModule): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0, - pp_ranks=-1, tp_ranks=-1, dp_ranks=-1): - - super().__init__( - pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks - ) - self.dim = dim - self.window_size = window_size # Wh, Ww - self.global_num_heads = num_heads - - tp_world_size = torch.distributed.get_world_size(group=self.tp_group) - if num_heads % tp_world_size != 0: - print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') - self.num_heads = num_heads // tp_world_size - self.dim_heads = dim // self.global_num_heads - self.scale = qk_scale or self.dim_heads ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # relative position index - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - self.register_buffer('relative_position_index', relative_position_index) - - - # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - # print(f'qkv embed dim: {dim}') - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, in_adapter=True, out_adapter=False, tp_group=self.tp_group) - self.attn_drop = nn.Dropout(attn_drop) - # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, in_adapter=False, out_adapter=True, tp_group=self.tp_group) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(ParallelModule): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, - pp_ranks=-1, tp_ranks=-1, dp_ranks=-1, fw_bs=-1): - super().__init__( - pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks - ) - - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = MegatronWindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MegatronMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, - pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - H, W = self.input_resolution - - assert fw_bs // len(dp_ranks) != 0 - self.set_in_size([fw_bs // len(dp_ranks), H * W, self.dim]) - self.set_out_size([fw_bs // len(dp_ranks), H * W, self.dim]) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - x = x + drop_path(ffn, self.drop_path_p) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(ParallelModule): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, - pp_ranks=-1, tp_ranks=-1, dp_ranks=-1, fw_bs=-1): - super().__init__( - pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks - ) - - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - H, W = self.input_resolution - - assert fw_bs // len(dp_ranks) != 0 - self.set_in_size([fw_bs // len(dp_ranks), H * W, self.dim]) - self.set_out_size([fw_bs // len(dp_ranks), H // 2 * W // 2, self.dim * 2]) - - def forward(self, x): - """ - x: B, H*W, C - """ - assert list(x.shape) == self.in_size - - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - assert list(x.shape) == self.out_size - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(ParallelModule): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, - pp_ranks=-1, tp_ranks=-1, dp_ranks=-1, layer_id=-1, fw_bs=-1): - - super().__init__( - pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks - ) - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - # build blocks - self.blocks = nn.ModuleList([]) - for i in range(depth): - block = SwinTransformerBlock( - dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pp_ranks=pp_ranks, dp_ranks=dp_ranks, tp_ranks=tp_ranks, fw_bs=fw_bs - ) - self.blocks.append(block) - - def forward(self, x): - raise RuntimeError("Error call here") - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - pconfigs=None, fw_bs=-1, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - - tp_ranks, dp_ranks, pp_ranks = list(), list(), list() - for i in range(4): - pconfig = pconfigs[i] - layer_pp_ranks, layer_dp_ranks, layer_tp_ranks = setup_device_group(**pconfig) - tp_ranks.append(layer_tp_ranks) - dp_ranks.append(layer_dp_ranks) - pp_ranks.append(layer_pp_ranks) - - # build network layers - layers = nn.ModuleList() - for i_layer in range(self.num_layers): - pconfig = pconfigs[i_layer] - layer_tp_ranks, layer_dp_ranks = tp_ranks[i_layer], dp_ranks[i_layer] - - if i_layer != self.num_layers - 1: - next_layer_tp_ranks = tp_ranks[i_layer + 1] - next_layer_dp_ranks = dp_ranks[i_layer + 1] - else: - next_layer_dp_ranks = list() - next_layer_tp_ranks = list() - - input_resolution = ( - patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer) - ) - layer = BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - input_resolution=input_resolution, - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - pp_ranks=pp_ranks[i_layer], dp_ranks=dp_ranks[i_layer], tp_ranks=tp_ranks[i_layer], - fw_bs=fw_bs - ) - - for block in layer.blocks: - layers.append(block) - - if i_layer < self.num_layers - 1: - merging = PatchMerging( - input_resolution, dim=int(embed_dim * 2 ** i_layer), - norm_layer=norm_layer, - pp_ranks=pp_ranks[i_layer], dp_ranks=dp_ranks[i_layer], tp_ranks=tp_ranks[i_layer], - fw_bs = fw_bs, - ) - layers.append(merging) - else: - merging = None - - # adapter - if len(layer_dp_ranks) == 1 and len(layer_tp_ranks) > 1 \ - and len(next_layer_dp_ranks) > 1 and len(next_layer_tp_ranks) == 1: - print_each_rank('add tp to dp adapters') - adapter = TPtoDP(DeviceGroup().get_group(next_layer_dp_ranks)) - adapter.in_size = layers[-1].out_size - out_size = [size for size in layers[-1].out_size] - out_size[0] = out_size[0] // len(next_layer_dp_ranks) - adapter.out_size = out_size - elif len(layer_tp_ranks) == 1 and len(layer_dp_ranks) > 1 \ - and len(next_layer_tp_ranks) > 1 and len(next_layer_dp_ranks) == 1: - print_each_rank('add dp to tp adapters') - adapter = DPtoTP(DeviceGroup().get_group(next_layer_tp_ranks)) - adapter.in_size = layers[-1].out_size - out_size = [size for size in layers[-1].out_size] - out_size[0] = out_size[0] * len(layer_dp_ranks) - adapter.out_size = out_size - elif len(layer_tp_ranks) == len(next_layer_tp_ranks) and \ - len(layer_dp_ranks) == len(next_layer_dp_ranks): - adapter = torch.nn.Identity() - adapter.in_size = layers[-1].out_size - adapter.out_size = layers[-1].out_size - layers.append(adapter) - - - # ================ Pipeline Parallel Region ====================== - self.pp_group = DeviceGroup().get_group(pp_ranks[0]) - pp_rank = torch.distributed.get_rank(self.pp_group) - pp_size = torch.distributed.get_world_size(self.pp_group) - - assert len(layers) == 31 - - for block in layers: - print_each_rank(f'> block: {type(block).__name__}: in {block.in_size}, out: {block.out_size}', rank_only=0) - - chunk = len(layers) // pp_size - if len(layers) % pp_size != 0: - remain = len(layers) % pp_size - if pp_rank < remain: - start = pp_rank * (chunk+1) - chunk = chunk + 1 - else: - start = remain * (chunk + 1) + (pp_rank - remain) * chunk - else: - start = pp_rank * chunk - stop = start + chunk - - # self.use_checkpoint = [False] * (stop - start) - self.use_checkpoint = [True] * (stop - start) - - # 8gpu layer assign - # layer_split = [5, 5, 4, 3, 3, 3, 3, 5] # original - # layer_split = [3, 3, 3, 3, 3, 4, 4, 4] - # assert sum(layer_split) == 31 - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - # self.use_checkpoint = [False] * (stop - start) - # for idx in range(stop - start): - # if pp_rank == 0: - # if idx < 1: - # self.use_checkpoint[idx] = True - - # 4 stage layer assign - # layer_split = [8, 8, 7, 8] # original - # layer_split = [6, 7, 7, 7] - - # assert sum(layer_split) == 31 - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - - print_each_rank(f'layer start -> end: {start} -> {stop}') - print_each_rank(self.use_checkpoint) - self.layers = layers[start:stop] - - local_chunk = list() - for block in self.layers: - local_chunk.append(f'{type(block).__name__}: in: {block.in_size}; out: {block.out_size}') - local_chunk = '\n'.join(local_chunk) - print_each_rank('local chunk:\n' + local_chunk) - - self.in_size = self.layers[0].in_size - assert isinstance(self.in_size, list) - self.out_size = self.layers[-1].out_size - assert isinstance(self.out_size, list) - - self.preprocess = False - if pp_rank == 0: - self.preprocess = True - self.in_size = [in_chans, img_size, img_size] - self.postprocess = False - if is_last_stage(self.pp_group): - self.postprocess = True - self.out_size = [1,] - - if self.postprocess: - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - - # =================== Data Parallel ======================== - - self.split_data = len(dp_ranks[0]) - - # preprocess data parallel region - if self.preprocess and len(dp_ranks[0]) > 1: - for param in self.patch_embed.parameters(): - _dp_reducer[tuple(dp_ranks[0])].add_param(param) - - # block data parallel region - for block in self.layers: - if isinstance(block, ParallelModule): - if block.use_dp(): - for param in block.parameters(): - _dp_reducer[block.dp_ranks].add_param(param) - - # postprocess data parallel region - if self.postprocess and len(dp_ranks[-1]) > 1: - for param in self.norm.parameters(): - _dp_reducer[tuple(dp_ranks[-1])].add_param(param) - for param in self.head.parameters(): - _dp_reducer[tuple(dp_ranks[-1])].add_param(param) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, image: torch.Tensor, feature_map=None): - - if self.preprocess: - with torch.no_grad(): - # FIXME: should select corresponding chunk - image = image.chunk(self.split_data, 0)[0] - x = self.patch_embed(image) - x = self.pos_drop(x) - feature_map = x - - for layer, use_checkpoint in zip(self.layers, self.use_checkpoint): - if use_checkpoint: - feature_map = checkpoint.checkpoint(layer, feature_map) - else: - feature_map = layer(feature_map) - x = feature_map - - if self.postprocess: - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C L - x = torch.flatten(x, 1) - x = self.head(x) - # simulate for simplicity - x = torch.sum(x) - return x - - def flops(self): - flops = 0 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes - return flops - - -def train(args, pconfigs): - - # dim_head is always 32 - - # img resolution, windows size: 224, 384, 518, 640 - C, H, W, window_size = [3, 224, 224, 7] - # C, H, W, window_size = [3, 384, 384, 12] - # C, H, W, window_size = [3, 518, 518, ?] - # C, H, W, window_size = [3, 640, 640, 20] - # C, H, W, window_size = [3, 1536, 1536, 48] - - # image batch size - N = args.gbs - - # Swin-Tiny - # embed_dim, depths, num_heads = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24] - # ] - - # SwinV2-B: 87 M - # embed_dim, depths, num_heads = [ - # 128, [2, 2, 18, 2], [4, 8, 16, 32] - # ] - - # SwinV2-L: 196 M - # embed_dim, depths, num_heads = [ - # 192, [2, 2, 18, 2], [6, 12, 24, 48] - # ] - - # SwinV2-H: 657 M - # embed_dim, depths, num_heads = [ - # 352, [2, 2, 18, 2], [11, 22, 44, 88] - # ] - - # SwinV2-H modified: 782 M - embed_dim, depths, num_heads = [ - 384, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # # head dim 32 -> 48 - embed_dim, depths, num_heads = [ - 576, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # # head dim 32 -> 64 -- too much - embed_dim, depths, num_heads = [ - 768, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # # head dim 32 -> 80 - embed_dim, depths, num_heads = [ - 960, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # # head dim 32 -> 96 - embed_dim, depths, num_heads = [ - 1152, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # # head dim 32 -> 112 - embed_dim, depths, num_heads = [ - 1344, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # head dim 32 -> 128 - embed_dim, depths, num_heads = [ - 1536, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # head dim 32 -> 144 - embed_dim, depths, num_heads = [ - 1728, [2, 2, 18, 2], [12, 24, 48, 96] - ] - # head dim 32 -> 160 - # embed_dim, depths, num_heads = [ - # 1920, [2, 2, 18, 2], [12, 24, 48, 96] - # ] - - # SwinV2-G: 2.5B Model - # embed_dim, depths, num_heads = [ - # 512, [2, 2, 42, 2], [16, 32, 64, 128] - # ] - - # 895.7 M Model - # embed_dim, depths, num_heads = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96] - # ] - - - # 2.01B model - # embed_dim, depths, num_heads = [ - # 576, [2, 2, 22, 2], [12, 24, 48, 96] - # ] - - print_each_rank( - f'config: embed_dim: {embed_dim}, depths: {depths}, num_heads: {num_heads}' - ) - - - model = SwinTransformer(img_size = H, - embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size, - pconfigs = pconfigs, - fw_bs = args.mbs) - model = model.cuda() - memory_summary() - - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [args.gbs, C, H, W]) - dataloader.set_data_buffer(buffer_num=2) - - def train_iter(model, dataloader): - img = next(dataloader) - scheduling_1f1b(model, [img], args.gbs, args.mbs, torch.float, model.pp_group) - torch.distributed.barrier() - CudaTimer().start('dp_allreduce') - for ranks in _dp_reducer: - reducer = _dp_reducer[ranks] - reducer.allreduce() - CudaTimer().stop('dp_allreduce') - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - # start training - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - iter_num = 20 - for step in range(iter_num): - if step >= 10: - CudaTimer(enable=True).start('e2e') - torch.distributed.barrier() - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - torch.distributed.barrier() - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 10: - CudaTimer().stop('e2e') - if (step + 1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-10, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - - CudaTimer().print_all(times=iter_num-10) - memory_summary() - - -if __name__ == '__main__': - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--layer0', type=int, nargs='+', - help='pipeline, data, tensor parallel config') - parser.add_argument('--layer1', type=int, nargs='+', - help='pipeline, data, tensor parallel config') - parser.add_argument('--layer2', type=int, nargs='+', - help='pipeline, data, tensor parallel config') - parser.add_argument('--layer3', type=int, nargs='+', - help='pipeline, data, tensor parallel config') - parser.add_argument('--gbs', type=int, default=-1) - parser.add_argument('--mbs', type=int, default=-1) - args = parser.parse_args() - - cube.init() - - # allocate resource - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - args.pp = args.layer0[0] - - pconfigs = [ - dict(layer_id=0, pp=args.layer0[0], dp=args.layer0[1], tp=args.layer0[2]), # basic layer 0 - dict(layer_id=1, pp=args.layer0[0], dp=args.layer1[1], tp=args.layer1[2]), # basic layer 1 - dict(layer_id=2, pp=args.layer0[0], dp=args.layer2[1], tp=args.layer2[2]), # basic layer 2 - dict(layer_id=3, pp=args.layer0[0], dp=args.layer3[1], tp=args.layer3[2]), # basic layer 3 - ] - - train(args, pconfigs) diff --git a/handcraft/swin/swin_pipe.py b/handcraft/swin/swin_pipe.py deleted file mode 100644 index ef1580e2..00000000 --- a/handcraft/swin/swin_pipe.py +++ /dev/null @@ -1,872 +0,0 @@ - -# -------------------------------------------------------- -# Modified from Swin-Transformer Repo -""" -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - handcraft/swin/swin_pipe.py --pp 8 --gbs 32 --mbs 4 - -# V100-16GB: 8GPU: need checkpoint: 8 micro bs -""" -# -------------------------------------------------------- - -from typing import Optional -import torch -import torch.nn as nn -import torch.utils.checkpoint as checkpoint - - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary - -import argparse - -from handcraft.swin.schedule import scheduling_1f1b, is_last_stage -from handcraft.swin.layers import ColumnParallelLinear, RowParallelLinear - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class MegatronMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - # self.fc1 = nn.Linear(in_features, hidden_features) - self.fc1 = ColumnParallelLinear(in_features, hidden_features, full_input=True, full_output=False) - self.act = act_layer() - # self.fc2 = nn.Linear(hidden_features, out_features) - self.fc2 = RowParallelLinear(hidden_features, out_features, full_input=False, full_output=True) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class MegatronWindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.global_num_heads = num_heads - group = cube.runtime.resource.EnvResource().tp_group - tp_world_size = torch.distributed.get_world_size(group=group) - if num_heads % tp_world_size != 0: - print(f'detecting un-even num head {num_heads} partition to {tp_world_size}') - self.num_heads = num_heads // torch.distributed.get_world_size(group=group) - self.dim_heads = dim // self.global_num_heads - self.scale = qk_scale or self.dim_heads ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), self.num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # relative position index - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - self.register_buffer('relative_position_index', relative_position_index) - - - # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - # print(f'qkv embed dim: {dim}') - self.qkv = ColumnParallelLinear(dim, dim * 3, bias=qkv_bias, full_input=True, full_output=False) - self.attn_drop = nn.Dropout(attn_drop) - # self.proj = nn.Linear(dim, dim) - self.proj = RowParallelLinear(dim, dim, full_input=False, full_output=True) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.dim_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, self.num_heads * self.dim_heads) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = MegatronWindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MegatronMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - H, W = self.input_resolution - self.in_size = [H * W, self.dim] - self.out_size = [H * W, self.dim] - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - x = x + drop_path(ffn, self.drop_path_p) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - H, W = self.input_resolution - self.in_size = [H * W, self.dim] - self.out_size = [H // 2 * W // 2, self.dim * 2] - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, module_lists=None): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - # build blocks - for i in range(depth): - block = SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - module_lists.append(block) - - # patch merging layer - if downsample is not None: - merging = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - module_lists.append(merging) - - def forward(self, x): - raise RuntimeError("Error call here") - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - layers = nn.ModuleList() - for i_layer in range(self.num_layers): - _ = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - module_lists=layers) - - # pipeline stage - pp_rank = torch.distributed.get_rank() - pp_size = torch.distributed.get_world_size() - - chunk = len(layers) // pp_size - if len(layers) % pp_size != 0: - remain = len(layers) % pp_size - if pp_rank < remain: - start = pp_rank * (chunk+1) - chunk = chunk + 1 - else: - start = remain * (chunk + 1) + (pp_rank - remain) * chunk - else: - start = pp_rank * chunk - stop = start + chunk - - # self.use_checkpoint = [True] * (stop - start) - # self.use_checkpoint = [False] * (stop - start) - - # 8gpu layer assign - # layer_split = [3, 4, 3, 3, 3, 4, 3, 4] - # assert sum(layer_split) == 27 - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - # self.use_checkpoint = [False] * (stop - start) - # for idx in range(stop - start): - # if pp_rank == 0: - # if idx < 1: - # self.use_checkpoint[idx] = True - - # 4Ggpu layer assign - # layer_split = [6, 7, 7, 7] - # assert sum(layer_split) == 27 - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - - print_each_rank(f'layer start -> end: {start} -> {stop}') - print_each_rank(self.use_checkpoint) - self.layers = layers[start:stop] - print_each_rank([str(type(layer)) + '\n' for layer in self.layers]) - - self.in_size = self.layers[0].in_size - assert isinstance(self.in_size, list) - self.out_size = self.layers[-1].out_size - assert isinstance(self.out_size, list) - - - self.preprocess = False - if pp_rank == 0: - self.preprocess = True - self.in_size = [in_chans, img_size, img_size] - self.postprocess = False - if is_last_stage(): - self.postprocess = True - self.out_size = [1,] - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, image, feature_map=None): - - if self.preprocess: - x = self.patch_embed(image) - x = self.pos_drop(x) - feature_map = x - - for layer, use_checkpoint in zip(self.layers, self.use_checkpoint): - if use_checkpoint: - feature_map = checkpoint.checkpoint(layer, feature_map) - else: - feature_map = layer(feature_map) - x = feature_map - - if self.postprocess: - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C L - x = torch.flatten(x, 1) - x = self.head(x) - # simulate for simplicity - x = torch.sum(x) - return x - - def flops(self): - flops = 0 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes - return flops - - -def train(args): - resource = cube.runtime.resource.EnvResource() - - # dim_head is always 32 - - # img resolution, windows size: 224, 384, 518, 640 - C, H, W, window_size = [3, 224, 224, 7] - # C, H, W, window_size = [3, 384, 384, 12] - # C, H, W, window_size = [3, 518, 518, ?] - # C, H, W, window_size = [3, 640, 640, 20] - # C, H, W, window_size = [3, 1536, 1536, 48] - - # image batch size - N = args.gbs - - # Swin-Tiny - # embed_dim, depths, num_heads = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24] - # ] - - # SwinV2-B: 87 M - # embed_dim, depths, num_heads = [ - # 128, [2, 2, 18, 2], [4, 8, 16, 32] - # ] - - # SwinV2-L: 196 M - # embed_dim, depths, num_heads = [ - # 192, [2, 2, 18, 2], [6, 12, 24, 48] - # ] - - # SwinV2-H: 657 M - # embed_dim, depths, num_heads = [ - # 352, [2, 2, 18, 2], [11, 22, 44, 88] - # ] - - # SwinV2-H modified: 782 M - embed_dim, depths, num_heads = [ - 384, [2, 2, 18, 2], [12, 24, 48, 96] - ] - - # SwinV2-G: 2.5B Model - # embed_dim, depths, num_heads = [ - # 512, [2, 2, 42, 2], [16, 32, 64, 128] - # ] - - # 895.7 M Model - # embed_dim, depths, num_heads = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96] - # ] - - - # 2.01B model - # embed_dim, depths, num_heads = [ - # 576, [2, 2, 22, 2], [12, 24, 48, 96] - # ] - - - model = SwinTransformer(img_size = H, - embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size) - model = model.cuda() - memory_summary() - - # setup data parallel reducer - # reducer = None - assert args.dp == 1 - - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [N // args.dp, C, H, W]) - dataloader.set_data_buffer(buffer_num=16) - - def train_iter(model, dataloader): - img = next(dataloader) - scheduling_1f1b(model, [img], args.gbs, args.mbs, dtype=torch.float) - # if reducer is not None: - # reducer.allreduce() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - # start training - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 20 - for step in range(iter_num): - if step >= 10: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 10: - CudaTimer().stop('e2e') - if (step + 1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-10, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - - CudaTimer().print_all(times=iter_num-10) - memory_summary() - - -if __name__ == '__main__': - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--tp', type=int, default=1, - help='tensor parallel size') - parser.add_argument('--dp', type=int, default=1, - help='data parallel size') - parser.add_argument('--pp', type=int, default=1, - help='pipeline parallel size') - parser.add_argument('--gbs', type=int, default=-1) - parser.add_argument('--mbs', type=int, default=-1) - args = parser.parse_args() - - cube.init() - - # allocate resource - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - tp_size, tp_group_nums = args.tp, ndevs // args.tp - dp_size, dp_group_nums = args.dp, ndevs // args.dp - pp_size, pp_group_nums = args.pp, ndevs // args.pp - - if not pp_size * dp_size * tp_size == ndevs: - raise RuntimeError("Expected all devices are used") - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize data parallel group - all_data_parallel_group_ranks = list() - for i in range(pp_size): - start_rank = i * pp_group_nums - end_rank = (i + 1) * pp_group_nums - for j in range(tp_size): - ranks = list(range(start_rank + j, end_rank, tp_size)) - all_data_parallel_group_ranks.append(ranks) - # initialize groups - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - resource.dp_group = group - resource.reducer = cube.runtime.reducer.Reducer(ranks) - print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) - - # initialize pipelne parallel groups - for i in range(dp_size): - ranks = [data_parallel_group_ranks[i] - for data_parallel_group_ranks in all_data_parallel_group_ranks] - group = devs.get_group(ranks) - if myrank in ranks: - pp_ranks = ranks - resource.pp_group = group - print_each_rank(f'initialzed pipeline parallel group: {pp_ranks}', rank_only=myrank) - - # initialize tensor parallel groups - for i in range(tp_group_nums): - ranks = list(range(i * tp_size, (i + 1) * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - resource.tp_group = group - print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - train(args) diff --git a/handcraft/swin/swin_transformer.py b/handcraft/swin/swin_transformer.py deleted file mode 100644 index 4d4f0763..00000000 --- a/handcraft/swin/swin_transformer.py +++ /dev/null @@ -1,696 +0,0 @@ - -# -------------------------------------------------------- -# Swin Transformer -# Copyright (c) 2021 Microsoft -# Licensed under The MIT License [see LICENSE for details] -# Written by Ze Liu - -# Copied and modified from -# -------------------------------------------------------- - -from typing import Optional -import torch -import torch.nn as nn - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary - - -def drop_path(x, drop_prob: float = 0.): - if drop_prob == 0.: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x: torch.Tensor, window_size: int): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - # [B, H_window_num, window_size, W_window_num, window_size, C] - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - # [B, H_window_num, W_window_num, window_size, window_size, C] - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() - # [B * H_windows_num * W_window_size, window_size, window_size, C] - windows = windows.view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def window_position_index(window_size_h: int, window_size_w: int): - coords_h = torch.arange(window_size_h) - coords_w = torch.arange(window_size_w) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size_h - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size_w - 1 - relative_coords[:, :, 0] *= 2 * window_size_w - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww, Wh*Ww - return relative_position_index - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask: Optional[torch.Tensor] = None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_index = window_position_index(self.window_size[0], self.window_size[1]) - relative_position_bias = self.relative_position_bias_table[relative_position_index] - # [Wh * Ww, Wh * Ww, nH] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path_p = drop_path - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - # [B, H, W, C] -> [B * num_windows, window_size_h, windows_size_w, C] - x_windows = window_partition(shifted_x, self.window_size) - # -> [B * num_windows, window_size_h * windows_size_w, C] - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # same in/out: [B * num_windows, window_size_h * windows_size_w, C] - attn_windows = self.attn(x_windows, mask=self.attn_mask) - - # merge windows - # [B * num_windows, w_h * w_w, C] -> [B * num_windows, w_h, w_w, C] - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - # [B * num_windows, window_size_h, windows_size_w, C] -> [B, H', W', C] - shifted_x = window_reverse(attn_windows, self.window_size, H, W) - - # reverse cyclic shift - # [B, H', W', C] -> [B, H, W, C] - x = shifted_x - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - # [B, H, W, C] -> [B, H * W, C] - x = x.view(B, H * W, C) - # [B, H * W, C] -> [B, H * W, C] - x = shortcut + drop_path(x, self.drop_path_p) - # FFN - # [B, H * W, C] -> [B, H * W, C] - ffn = self.norm2(x) - # [B, H * W, C] -> [B, H * W, C] - ffn = self.mlp(ffn) - # [B, H * W, C] + [B, H * W, C] -> [B, H * W, C] - x = x + drop_path(ffn, self.drop_path_p) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x): - for blk in self.blocks: - x = blk(x) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - # self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - # if self.ape: - # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - # trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) - self.layers.append(layer) - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def forward_features(self, x): - x = self.patch_embed(x) - # if self.ape: - # x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - return x - - def forward(self, x): - # forward features - # x = self.forward_features(x) - x = self.patch_embed(x) - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - - x = self.head(x) - return x - - def flops(self): - flops = 0 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes - return flops - - -def train(): - - # image batch input - N, C, H, W = [1, 3, 224, 224] - - # embed_dim, depths, num_heads, window_size = [ - # 96, [2, 2, 6, 2], [3, 6, 12, 24], 7 - # ] - - # 348.55 M - embed_dim, depths, num_heads, window_size = [ - 256, [2, 2, 18, 2], [8, 16, 32, 64], 7 - ] - - # 895.7 M Model -- 224x224 - # embed_dim, depths, num_heads, window_size = [ - # 384, [2, 2, 22, 2], [12, 24, 48, 96], 7 - # ] - - # 2.01B model - # embed_dim, depths, num_heads, window_size = [ - # 576, [2, 2, 22, 2], [12, 24, 48, 96], 7 - # ] - - - model = SwinTransformer(embed_dim = embed_dim, - depths = depths, - num_heads = num_heads, - window_size = window_size) - - - module = torch.jit.script(model) - print(module.graph) - # print(parser.ScriptModuleParser.flatten(module, depth=2)) - - model = model.cuda() - - dataloader = cube.runtime.syndata.SynDataLoader(1280, [0], [N, C, H, W]) - - def train_iter(model, dataloader): - img = next(dataloader) - loss = model(img) - loss = torch.sum(loss) - loss.backward() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on iteration') - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - - -if __name__ == '__main__': - - cube.init() - train() diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py new file mode 100644 index 00000000..3e34bcb2 --- /dev/null +++ b/handcraft/swin/train.py @@ -0,0 +1,873 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers 18 --dim 192 --heads 6 \ + --pp-size 1 --tp-size 1 --dp-size 1 \ + --bs 4 --micro-bs 1 --coshard 1 --fp16 +""" + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary, model_summary +from cube.runtime.adapter.reducer import Reducer +from cube.runtime.device import DeviceGroup +from cube.runtime.adapter.distnn import IdentityAllreduce, AllReduceIdentity +from handcraft.module.schedule import schedule_1f1b +from handcraft.module.stage import PipeStage +from handcraft.swin.utils import create_position_bias, trunc_normal_, window_partition, window_reverse, DropPath + +import argparse + + +parser = argparse.ArgumentParser(description='swin') + +# model arch +parser.add_argument('--layers', type=int, default=18, + help='third stage layer depths. default large') +parser.add_argument('--dim', type=int, default=192, + help='input channel of first stage') +parser.add_argument('--heads', type=int, default=6, + help='head num of first stage') +# data +parser.add_argument('--img-size', type=int, default=640, + help='image size, can be 224, 640, 1536') +parser.add_argument('--window-size', type=int, default=40, + help='image size, can be 7, 40, 48') +# training +parser.add_argument('--bs', type=int, default=256, + help='batch size') +parser.add_argument('--micro-bs', type=int, default=1, + help='micro batch size') +parser.add_argument('--pp-size', type=int, default=1, + help='pipeline parallelism size') +parser.add_argument('--tp-size', type=int, default=1, + help='tensor parallelism size') +parser.add_argument('--dp-size', type=int, default=1, + help='data parallelism size') +parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b'], + help='scheduling algorithm') +parser.add_argument('--coshard', type=int, default=1) +parser.add_argument('--fp16', action='store_true', default='') + +args = parser.parse_args() +print(args) + +_tp_group = -1 + +_dp_group = -1 +_dp_reducer = None + +_pp_group = -1 +_pp_global_ranks = () +_schedule = schedule_1f1b +_layer_divisions = [] + +cube.init() +dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( + [args.dp_size, args.pp_size, args.tp_size] +) + +if len(dp_ranks) != 1: + print_each_rank(f'initializing dp ranks: {dp_ranks}') + _dp_group = DeviceGroup().get_group(dp_ranks) + _dp_reducer = Reducer(dp_ranks) + +if len(tp_ranks) != 1: + print_each_rank(f'initializing tp ranks: {tp_ranks}') + _tp_group = DeviceGroup().get_group(tp_ranks) + +if len(pp_ranks) != 1: + print_each_rank(f'initializing pp ranks: {pp_ranks}') + _pp_group = DeviceGroup().get_group(pp_ranks) + _pp_global_ranks = tuple(pp_ranks) + + # layer division + nlayers = 2 + 2 + args.layers + 2 + 3 # 3 is patch merging layers + times = ([1] * 2 + [0]) + \ + ([1] * 2 + [0]) + \ + ([1] * args.layers + [0]) + \ + ([1] * 2) + num_stages = len(pp_ranks) + budget = sum(times) // num_stages + print_each_rank(f'budget: {budget}', rank_only=0) + start, end = 0, 1 + for idx in range(num_stages): + accum = times[start] + assert end <= nlayers + while end != nlayers: + accum += times[end] + if accum > budget: + break + end += 1 + if idx == num_stages - 1: + end = nlayers + _layer_divisions.append((start, end)) + start, end = end, end+1 +else: + _layer_divisions = [(0, 2 + 2 + args.layers + 2 + 3)] +print_each_rank(f'layer divisions: {_layer_divisions}', rank_only=0) + + +class Config: + + embed_dim = args.dim + depths = [2, 2, args.layers, 2] + num_heads = [args.heads, args.heads * 2, args.heads * 4, args.heads * 8] + + mlp_ratio = 4 + qkv_bias = True + qk_scale = None + drop_path_rate = 0.2 + drop_rate = 0.2 + + img_size = args.img_size + window_size = args.window_size + num_classes = 1000 + + +class Mlp(torch.nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + self._tp_group = _tp_group + self._tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features // self._tp_size) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features // self._tp_size, out_features) + self.drop = nn.Dropout(drop) + + def forward_(self, x): + if self._tp_size > 1: + x = IdentityAllreduce.apply(x, self._tp_group) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + if self._tp_size > 1: + x = AllReduceIdentity.apply(x, self._tp_group) + return x + + def forward(self, x): + x = checkpoint.checkpoint(self.forward_, x) + # x = self.forward_(x) + return x + + +class SeqMlp(torch.nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., + coshard=1): + super().__init__() + self.coshard = coshard + assert hidden_features is not None + assert hidden_features % coshard == 0 + self.mlps = torch.nn.ModuleList( + [Mlp(in_features, hidden_features // coshard, out_features, act_layer, drop) for _ in range(coshard)] + ) + + def forward(self, x): + outs = None + for mlp in self.mlps: + x_out = mlp(x) + outs = x_out if outs is None else outs + x_out + return outs + + +class WindowAttention(torch.nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, inner_dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self._tp_group = _tp_group + self._tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + self.dim = dim + self.window_size = window_size # Wh, Ww + self.head_dim = inner_dim // num_heads + assert num_heads % self._tp_size == 0 + self.num_heads = num_heads // self._tp_size + self.scale = qk_scale or self.head_dim ** -0.5 + + # define define a parameter table of relative position bias + table, index = create_position_bias(self.window_size, self.num_heads) + self.relative_position_bias_table = table + self.register_buffer("relative_position_index", index) + + self.qkv = nn.Linear(dim, inner_dim // self._tp_size * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(inner_dim // self._tp_size, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward_(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + if self._tp_size > 1: + x = IdentityAllreduce.apply(x, self._tp_group) + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + if self._tp_size > 1: + x = AllReduceIdentity.apply(x, self._tp_group) + + return x + + def forward(self, x, mask=None): + x = checkpoint.checkpoint(self.forward_, x, mask) + # x = self.forward_(x, mask) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SeqWindowAttention(torch.nn.Module): + + def __init__(self, dim, inner_dim, window_size, num_heads, + qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + coshard=1): + super().__init__() + self.coshard = coshard + assert inner_dim % coshard == 0 + self.attns = torch.nn.ModuleList( + [WindowAttention( + dim, inner_dim // coshard, window_size, num_heads // coshard, + qkv_bias, qk_scale, attn_drop, proj_drop) for _ in range(coshard)] + ) + + def forward(self, x, mask=None): + outs = None + for attn in self.attns: + x_out = attn(x, mask) + outs = x_out if outs is None else outs + x_out + return outs + + +class SwinTransformerBlock(PipeStage): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_coshard=False, layer_id=None): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + if not use_coshard or args.coshard == 1: + self.attn = WindowAttention( + dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + else: + print(f'use colocate-sharding: {args.coshard}') + self.attn = SeqWindowAttention( + dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, coshard=args.coshard) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + if not use_coshard or args.coshard == 1: + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + else: + self.mlp = SeqMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, coshard=args.coshard) + + H, W = self.input_resolution + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + assert args.micro_bs // args.dp_size != 0 + self.inputs_info = ( + ((args.micro_bs // args.dp_size, H * W, self.dim),), + (torch.float32 if not args.fp16 else torch.float16,) + ) + self.outputs_info = ( + ((args.micro_bs // args.dp_size, H * W, self.dim),), + (torch.float32 if not args.fp16 else torch.float16,) + ) + self.layer_id = layer_id # for profiling + + def forward(self, x): + CudaTimer().start(f'layer{self.layer_id}') + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + CudaTimer().stop(f'layer{self.layer_id}') + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(PipeStage): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + H, W = self.input_resolution + assert args.micro_bs // args.dp_size != 0 + self.inputs_info = ( + ((args.micro_bs // args.dp_size, H * W, self.dim),), + (torch.float32 if not args.fp16 else torch.float16,) + ) + self.outputs_info = ( + ((args.micro_bs // args.dp_size, (H // 2) * (W // 2), self.dim * 2),), + (torch.float32 if not args.fp16 else torch.float16,) + ) + + def forward_(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def forward(self, x): + x = checkpoint.checkpoint(self.forward_, x) + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +def create_basic_layter(dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, layer_id=None): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + # swin transformer layers + blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, use_coshard=depth==2, layer_id=layer_id) + for i in range(depth)]) + # patch merging layer + if downsample is not None: + downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + blocks.append(downsample) + return blocks + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(PipeStage): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, **kwargs): + super().__init__() + self.set_pipeline(_pp_global_ranks) + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + self.patches_resolution = (img_size // patch_size, img_size // patch_size) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + blocks = create_basic_layter(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(self.patches_resolution[0] // (2 ** i_layer), + self.patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + layer_id=i_layer) + self.layers += blocks + + # pipeline split layers + start, end = _layer_divisions[self.stage_local_rank] + print_each_rank(f'initializing layer ranging from [{start}, {end})') + self.layers = self.layers[start:end] + + self.inputs_info = self.layers[0].inputs_info + self.outputs_info = self.layers[-1].outputs_info + + # preprocess + if self.is_first_stage: + print(f'rank [{torch.distributed.get_rank()}]: initializing pre-process...') + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + # dropout + self.pos_drop = nn.Dropout(p=drop_rate) + + self.inputs_info = ((), ()) + + # post-process + if self.is_last_stage: + print(f'rank [{torch.distributed.get_rank()}]: initializing post-process...') + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.criterion = nn.CrossEntropyLoss() + + self.outputs_info = ( + (1,), + torch.float32 if args.fp16 else torch.float16 + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward(self, x = None): + if self.is_first_stage: + CudaTimer().start('pre-process') + x, _ = self.data + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + CudaTimer().stop('pre-process') + + for layer in self.layers: + x = layer(x) + + if self.is_last_stage: + CudaTimer().start('post-process') + _, labels = self.data + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + x = self.criterion(x, labels) + CudaTimer().stop('post-process') + + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int, img_size: int, num_classes: int): + + self.bs = batch_size + self.img_size = img_size + self.num_classes = num_classes + super().__init__( + shapes=([batch_size, 3, img_size, img_size,], + [batch_size], + ), + dtypes=(torch.float, torch.int), + batch_dims=(0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + img = torch.rand( + *(self.bs, 3, self.img_size, self.img_size), + dtype=torch.float, + device=torch.cuda.current_device() + ) + labels = torch.randint( + 0, self.num_classes, + size=(self.bs,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + return (img, labels) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + +def train(): + + cfg = Config() + model = SwinTransformer(img_size=cfg.img_size, + patch_size=4, + in_chans=3, + num_classes=cfg.num_classes, + embed_dim=cfg.embed_dim, + depths=cfg.depths, + num_heads=cfg.num_heads, + window_size=cfg.window_size, + mlp_ratio=cfg.mlp_ratio, + qkv_bias=cfg.qkv_bias, + qk_scale=cfg.qk_scale, + drop_rate=cfg.drop_rate, + drop_path_rate=cfg.drop_path_rate, + ape=False, + patch_norm=True, + use_checkpoint=False) + model = model.cuda() + dataloader = ImageDataLoader(args.micro_bs, cfg.img_size, cfg.num_classes) + if _dp_reducer is not None: + _dp_reducer.add_param(model.parameters()) + optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) + + print_each_rank('model weight consumpition:') + memory_summary() + + def train_iter(model, dataloader): + num_microbatch = args.bs // args.micro_bs + if _pp_group != -1: + _schedule(model, dataloader, num_microbatch) + else: + for _ in range(num_microbatch): + model.data = next(dataloader) + loss = model() + loss.backward() + if _dp_reducer is not None: + _dp_reducer.allreduce() + + CudaTimer(enable=False).warmup() + iter_num = 10 + for step in range(iter_num): + + # if step == 0: + # model_summary(model, next(dataloader)) + + if step >= 4: + CudaTimer(enable=True).start('e2e') + + # training + train_iter(model, dataloader) + + if step == 0: + print_each_rank('passed first iteration', rank_only=0) + print_each_rank('memory consumption before optimizer:', rank_only=0) + memory_summary() + + optimizer.step() + optimizer.zero_grad() + + if step >= 4: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('memory consumption after optimizer:', rank_only=0) + memory_summary() + + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-4, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-4) + memory_summary() + +train() diff --git a/handcraft/swin/utils.py b/handcraft/swin/utils.py new file mode 100644 index 00000000..dada9d68 --- /dev/null +++ b/handcraft/swin/utils.py @@ -0,0 +1,106 @@ +from typing import Tuple +import warnings + +import torch +import math + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class DropPath(torch.nn.Module): + + def __init__(self, drop_prob: float): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + if self.drop_prob == 0. or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +def create_position_bias(window_size: Tuple[int, int], num_heads: int): + relative_position_bias_table = torch.nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + trunc_normal_(relative_position_bias_table, std=.02) + return relative_position_bias_table, relative_position_index From 5d99946982b4404e71c1580bd2657b011e43f851 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 19:54:01 +0800 Subject: [PATCH 0704/1892] fix fp16 and dp bug --- handcraft/swin/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 3e34bcb2..b34de544 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -768,7 +768,7 @@ def __init__(self, batch_size: int, img_size: int, num_classes: int): shapes=([batch_size, 3, img_size, img_size,], [batch_size], ), - dtypes=(torch.float, torch.int), + dtypes=(torch.float if not args.fp16 else torch.float16, torch.int), batch_dims=(0, 0) ) self.samples = [self.random_sample()] @@ -776,7 +776,7 @@ def __init__(self, batch_size: int, img_size: int, num_classes: int): def random_sample(self): img = torch.rand( *(self.bs, 3, self.img_size, self.img_size), - dtype=torch.float, + dtype=torch.float if not args.fp16 else torch.float16, device=torch.cuda.current_device() ) labels = torch.randint( @@ -813,10 +813,13 @@ def train(): ape=False, patch_norm=True, use_checkpoint=False) + if args.fp16: + model = model.half() model = model.cuda() dataloader = ImageDataLoader(args.micro_bs, cfg.img_size, cfg.num_classes) if _dp_reducer is not None: - _dp_reducer.add_param(model.parameters()) + for param in model.parameters(): + _dp_reducer.add_param(param) optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) print_each_rank('model weight consumpition:') From 452c4d8cb542cb525078e108a059e9974a9e33fa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 12:39:29 +0000 Subject: [PATCH 0705/1892] coshard only at first layer --- handcraft/swin/test.sh | 32 +++++++++++++++++++++++++++ handcraft/swin/train.py | 49 ++++++++++++++++++++++++++--------------- 2 files changed, 63 insertions(+), 18 deletions(-) create mode 100755 handcraft/swin/test.sh diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh new file mode 100755 index 00000000..252f1f22 --- /dev/null +++ b/handcraft/swin/test.sh @@ -0,0 +1,32 @@ +# swin transformer constant head dim == 32 + +evaldir=eval/swin-coshard +mkdir -p ${evaldir} + +# Swin-Giant +layers=42 +dim=512 +heads=16 +img_size=1536 +window_size=48 +coshard=16 +gpus=4 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 16 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 1 --tp-size ${gpus} --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index b34de544..85e65cc0 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -91,10 +91,10 @@ # layer division nlayers = 2 + 2 + args.layers + 2 + 3 # 3 is patch merging layers - times = ([1] * 2 + [0]) + \ - ([1] * 2 + [0]) + \ - ([1] * args.layers + [0]) + \ - ([1] * 2) + times = ([2039/2] * 2 + [0]) + \ + ([1118/2] * 2 + [0]) + \ + ([5474/4] * args.layers + [0]) + \ + ([510/2] * 2) num_stages = len(pp_ranks) budget = sum(times) // num_stages print_each_rank(f'budget: {budget}', rank_only=0) @@ -158,9 +158,11 @@ def forward_(self, x): x = AllReduceIdentity.apply(x, self._tp_group) return x - def forward(self, x): - x = checkpoint.checkpoint(self.forward_, x) - # x = self.forward_(x) + def forward(self, x, recompute=True): + if recompute: + x = checkpoint.checkpoint(self.forward_, x) + else: + x = self.forward_(x) return x @@ -176,10 +178,10 @@ def __init__(self, in_features, hidden_features=None, out_features=None, [Mlp(in_features, hidden_features // coshard, out_features, act_layer, drop) for _ in range(coshard)] ) - def forward(self, x): + def forward(self, x, recompute=True): outs = None for mlp in self.mlps: - x_out = mlp(x) + x_out = mlp(x, recompute=recompute) outs = x_out if outs is None else outs + x_out return outs @@ -263,9 +265,11 @@ def forward_(self, x, mask=None): return x - def forward(self, x, mask=None): - x = checkpoint.checkpoint(self.forward_, x, mask) - # x = self.forward_(x, mask) + def forward(self, x, mask=None, recompute=True): + if recompute: + x = checkpoint.checkpoint(self.forward_, x, mask) + else: + x = self.forward_(x, mask) return x def extra_repr(self) -> str: @@ -299,10 +303,10 @@ def __init__(self, dim, inner_dim, window_size, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) for _ in range(coshard)] ) - def forward(self, x, mask=None): + def forward(self, x, mask=None, recompute=True): outs = None for attn in self.attns: - x_out = attn(x, mask) + x_out = attn(x, mask, recompute) outs = x_out if outs is None else outs + x_out return outs @@ -396,8 +400,7 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 ) self.layer_id = layer_id # for profiling - def forward(self, x): - CudaTimer().start(f'layer{self.layer_id}') + def forward_(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" @@ -417,7 +420,7 @@ def forward(self, x): x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows = self.attn(x_windows, mask=self.attn_mask, recompute=self.layer_id != 2) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) @@ -432,7 +435,17 @@ def forward(self, x): # FFN x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path(self.mlp(self.norm2(x), recompute=self.layer_id != 2)) + return x + + def forward(self, x): + CudaTimer().start(f'layer{self.layer_id}') + # layer-wise recompute + if self.layer_id == 2: + x = checkpoint.checkpoint(self.forward_, x) + # attention/mlp-wise recompute + else: + x = self.forward_(x) CudaTimer().stop(f'layer{self.layer_id}') return x From 6838f20ef3e324fdb696335eb43c116cc67bf41e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 12:54:15 +0000 Subject: [PATCH 0706/1892] update profiled metric --- handcraft/swin/test.sh | 2 ++ handcraft/swin/train.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index 252f1f22..20a6bae9 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -30,3 +30,5 @@ OMP_NUM_THREADS=4 torchrun \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt + +python scripts/keep.py --gpus 8 diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 85e65cc0..903bf06b 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -93,7 +93,7 @@ nlayers = 2 + 2 + args.layers + 2 + 3 # 3 is patch merging layers times = ([2039/2] * 2 + [0]) + \ ([1118/2] * 2 + [0]) + \ - ([5474/4] * args.layers + [0]) + \ + ([2910/8] * args.layers + [0]) + \ ([510/2] * 2) num_stages = len(pp_ranks) budget = sum(times) // num_stages From 570e0ca2e666ea77f72ac5e48dcd22ef288af7dc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 29 Mar 2022 15:35:21 +0000 Subject: [PATCH 0707/1892] setup test scripts --- handcraft/swin/test.sh | 132 +++++++++++++++++++++++++++++++++++++++- handcraft/swin/train.py | 10 +-- 2 files changed, 135 insertions(+), 7 deletions(-) diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index 20a6bae9..f5d231f2 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -3,7 +3,85 @@ evaldir=eval/swin-coshard mkdir -p ${evaldir} -# Swin-Giant + +layers=18 +dim=192 +heads=6 +img_size=1536 +window_size=48 +coshard=6 +gpus=4 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 1 --tp-size ${gpus} --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt + + + +layers=26 +dim=384 +heads=12 +img_size=1536 +window_size=48 +coshard=12 +gpus=4 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 1 --tp-size ${gpus} --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt + + + layers=42 dim=512 heads=16 @@ -19,7 +97,7 @@ OMP_NUM_THREADS=4 torchrun \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 16 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt + --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt OMP_NUM_THREADS=4 torchrun \ @@ -31,4 +109,54 @@ OMP_NUM_THREADS=4 torchrun \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt + + + +layers=50 +dim=768 +heads=24 +img_size=1536 +window_size=48 +coshard=16 +gpus=8 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 1 --tp-size ${gpus} --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt + + python scripts/keep.py --gpus 8 diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 903bf06b..54e503fa 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -851,13 +851,13 @@ def train_iter(model, dataloader): _dp_reducer.allreduce() CudaTimer(enable=False).warmup() - iter_num = 10 + iter_num = 6 for step in range(iter_num): # if step == 0: # model_summary(model, next(dataloader)) - if step >= 4: + if step >= 2: CudaTimer(enable=True).start('e2e') # training @@ -871,7 +871,7 @@ def train_iter(model, dataloader): optimizer.step() optimizer.zero_grad() - if step >= 4: + if step >= 2: CudaTimer().stop('e2e') if step == 0: @@ -882,8 +882,8 @@ def train_iter(model, dataloader): print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-4, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-4) + CudaTimer().duration(iter_num-2, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-2) memory_summary() train() From 9dbd10f39d9a910de7c9921c5401f216168e76ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Mar 2022 02:10:14 +0000 Subject: [PATCH 0708/1892] add test script --- handcraft/swin/test.sh | 206 +++++++++++------------------------------ 1 file changed, 56 insertions(+), 150 deletions(-) diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index f5d231f2..aba8623c 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -4,159 +4,65 @@ evaldir=eval/swin-coshard mkdir -p ${evaldir} -layers=18 -dim=192 -heads=6 img_size=1536 window_size=48 -coshard=6 -gpus=4 -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt - - - -layers=26 -dim=384 -heads=12 -img_size=1536 -window_size=48 -coshard=12 -gpus=4 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt - - - -layers=42 -dim=512 -heads=16 -img_size=1536 -window_size=48 -coshard=16 -gpus=4 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt - - - -layers=50 -dim=768 -heads=24 -img_size=1536 -window_size=48 -coshard=16 -gpus=8 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt +test() +{ + layers=$1 + dim=$2 + heads=$3 + coshard=$4 + gpus=$5 + + echo "testing ${gpus}-dev: PP-Coshard${coshard}: L${layers}E${dim}H${heads}" + echo "OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt" + + echo "testing ${gpus}-dev: TP-Coshard1: L${layers}E${dim}H${heads}" + echo "OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 1 --tp-size ${gpus} --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt" + + echo "testing ${gpus}-dev: PP-Coshard1: L${layers}E${dim}H${heads}" + echo "OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt" + + killall python + sleep 5 + killall python +} + +# test Layers Dim Heads Coshard GPUs +test 18 256 8 8 4 +test 18 512 16 16 4 +test 18 768 24 24 4 + +test 26 512 16 16 4 +test 26 768 24 24 4 +test 26 1024 32 32 4 + +test 34 256 8 8 8 +test 34 512 16 16 8 +test 34 768 24 24 8 +test 34 1024 32 32 8 python scripts/keep.py --gpus 8 From ad6404df6f812f2fdea618e0021e2d5de7041d8a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Mar 2022 02:12:20 +0000 Subject: [PATCH 0709/1892] fix test bug --- handcraft/swin/test.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index aba8623c..ed08a0e5 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -17,34 +17,34 @@ test() gpus=$5 echo "testing ${gpus}-dev: PP-Coshard${coshard}: L${layers}E${dim}H${heads}" - echo "OMP_NUM_THREADS=4 torchrun \ + OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt" + --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt echo "testing ${gpus}-dev: TP-Coshard1: L${layers}E${dim}H${heads}" - echo "OMP_NUM_THREADS=4 torchrun \ + OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt" + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt echo "testing ${gpus}-dev: PP-Coshard1: L${layers}E${dim}H${heads}" - echo "OMP_NUM_THREADS=4 torchrun \ + OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt" + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt killall python sleep 5 From 76a7d93a56b2bc3d6df310231ad622bfb269adbd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Mar 2022 10:20:39 +0800 Subject: [PATCH 0710/1892] fix assert --- cube/graph/operator/function/pad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/operator/function/pad.py b/cube/graph/operator/function/pad.py index e1f96cf6..3b18a9c2 100644 --- a/cube/graph/operator/function/pad.py +++ b/cube/graph/operator/function/pad.py @@ -26,7 +26,7 @@ def infer_shape(self) -> bool: pad = self.kwargs['pad'] mode = self.kwargs['mode'] value = self.kwargs['value'] - assert (len(pad) % 2 == 0, "IRPad::infer_shape len(pad) % 2 == 0") + assert len(pad) % 2 == 0, "IRPad::infer_shape len(pad) % 2 == 0" shape = self.inputs(0).shape for pad_idx, pad_size in enumerate(pad): From 3dd194418fb3d1ebf93b0736d5e6f5a02adf0f58 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Mar 2022 03:13:05 +0000 Subject: [PATCH 0711/1892] add profile results --- handcraft/swin/train.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 54e503fa..26e78fd0 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -5,9 +5,9 @@ --nproc_per_node=1 \ --nnodes=1 \ handcraft/swin/train.py \ - --layers 18 --dim 192 --heads 6 \ + --layers 18 --dim 256 --heads 8 \ --pp-size 1 --tp-size 1 --dp-size 1 \ - --bs 4 --micro-bs 1 --coshard 1 --fp16 + --bs 4 --micro-bs 1 --coshard 8 --fp16 """ import torch @@ -37,9 +37,9 @@ parser.add_argument('--heads', type=int, default=6, help='head num of first stage') # data -parser.add_argument('--img-size', type=int, default=640, +parser.add_argument('--img-size', type=int, default=1536, help='image size, can be 224, 640, 1536') -parser.add_argument('--window-size', type=int, default=40, +parser.add_argument('--window-size', type=int, default=48, help='image size, can be 7, 40, 48') # training parser.add_argument('--bs', type=int, default=256, @@ -91,10 +91,33 @@ # layer division nlayers = 2 + 2 + args.layers + 2 + 3 # 3 is patch merging layers - times = ([2039/2] * 2 + [0]) + \ - ([1118/2] * 2 + [0]) + \ - ([2910/8] * args.layers + [0]) + \ - ([510/2] * 2) + # metrics for V100-32GB-PCIe + if args.dim == 256: # OK! + times = ([109.93] * 2 + [0]) + \ + ([60.34] * 2 + [0]) + \ + ([43.18] * args.layers + [0]) + \ + ([27.51] * 2) + elif args.dim == 512: # OK! + times = ([255.10] * 2 + [0]) + \ + ([139.92] * 2 + [0]) + \ + ([90.98] * args.layers + [0]) + \ + ([63.78] * 2) + elif args.dim == 768: # OK! + times = ([440.5] * 2 + [0]) + \ + ([241.4] * 2 + [0]) + \ + ([145.7] * args.layers + [0]) + \ + ([108.9] * 2) + elif args.dim == 1024: # TP needed + times = ([255.10] * 2 + [0]) + \ + ([139.92] * 2 + [0]) + \ + ([90.98] * args.layers + [0]) + \ + ([63.78] * 2) + else: + print_each_rank('WARNING: NO Metric Logged!!') + times = ([1] * 2 + [0]) + \ + ([1] * 2 + [0]) + \ + ([1] * args.layers + [0]) + \ + ([1] * 2) num_stages = len(pp_ranks) budget = sum(times) // num_stages print_each_rank(f'budget: {budget}', rank_only=0) From 9d7709dd8073845f2cdd86185ba0e3df2645b8f2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Mar 2022 05:01:55 +0000 Subject: [PATCH 0712/1892] better pipeline stages --- handcraft/swin/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 26e78fd0..9a7b381c 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -126,9 +126,9 @@ accum = times[start] assert end <= nlayers while end != nlayers: - accum += times[end] - if accum > budget: + if budget - accum < 0.5 * times[end]: break + accum += times[end] end += 1 if idx == num_stages - 1: end = nlayers From 91476efe67ad441f5ab5181a8782322dd3d9966d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Mar 2022 15:18:19 +0000 Subject: [PATCH 0713/1892] add test for hybrid tp --- handcraft/swin/test.sh | 55 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index ed08a0e5..36a3e12c 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -36,6 +36,56 @@ test() --pp-size 1 --tp-size ${gpus} --dp-size 1 \ --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt + # # Hybrid TP-1F1B -- 4 GPU + if [ ${gpus} == 4 ] + then + echo "testing ${gpus}-dev: TP2-PP2: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 2 --tp-size 2 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2-coshard1.txt + sleep 5 + killall python + sleep 5 + killall python + fi + + # Hybrid TP-1F1B -- 8 GPU + if [ ${gpus} == 8 ] + then + echo "testing ${gpus}-dev: TP4-PP2: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 2 --tp-size 4 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2-coshard1.txt + sleep 5 + killall python + sleep 5 + killall python + + echo "testing ${gpus}-dev: TP2-PP4: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 4 --tp-size 2 --dp-size 1 \ + --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4-coshard1.txt + sleep 5 + killall python + sleep 5 + killall python + fi + echo "testing ${gpus}-dev: PP-Coshard1: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ @@ -56,13 +106,14 @@ test 18 256 8 8 4 test 18 512 16 16 4 test 18 768 24 24 4 +test 26 256 8 8 4 test 26 512 16 16 4 test 26 768 24 24 4 -test 26 1024 32 32 4 +# test 26 1024 32 32 4 test 34 256 8 8 8 test 34 512 16 16 8 test 34 768 24 24 8 -test 34 1024 32 32 8 +# test 34 1024 32 32 8 python scripts/keep.py --gpus 8 From 4133a509f443c632f784d25ce67422f99c56461a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 31 Mar 2022 08:06:47 +0000 Subject: [PATCH 0714/1892] swin cosharding --- cube/runtime/adapter/distnn.py | 2 +- handcraft/swin/test.sh | 152 +++++++++++++++++++++++++++------ handcraft/swin/train.py | 81 ++++++++++++------ 3 files changed, 184 insertions(+), 51 deletions(-) diff --git a/cube/runtime/adapter/distnn.py b/cube/runtime/adapter/distnn.py index b8f9bf06..d41ddb30 100644 --- a/cube/runtime/adapter/distnn.py +++ b/cube/runtime/adapter/distnn.py @@ -37,7 +37,7 @@ def backward(ctx, grad_output): return grad_output, None -class AllGatherScatter(torch.autograd.Function): +class AllGatherSplit(torch.autograd.Function): @staticmethod def forward(ctx, input, dim, group): diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index 36a3e12c..c6abeb75 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -8,15 +8,14 @@ img_size=1536 window_size=48 -test() +test_naive_pp() { layers=$1 dim=$2 heads=$3 - coshard=$4 - gpus=$5 + gpus=$4 - echo "testing ${gpus}-dev: PP-Coshard${coshard}: L${layers}E${dim}H${heads}" + echo "testing ${gpus}-dev: Pure PP${coshard}: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ @@ -24,9 +23,21 @@ test() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard ${coshard} --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard${coshard}.txt + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}.txt + sleep 5 + killall python + sleep 5 + killall python +} - echo "testing ${gpus}-dev: TP-Coshard1: L${layers}E${dim}H${heads}" +test_naive_tp() +{ + layers=$1 + dim=$2 + heads=$3 + gpus=$4 + + echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ @@ -34,9 +45,20 @@ test() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}-coshard1.txt + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_naive_hybrid_tp_pp() +{ + layers=$1 + dim=$2 + heads=$3 + gpus=$4 - # # Hybrid TP-1F1B -- 4 GPU if [ ${gpus} == 4 ] then echo "testing ${gpus}-dev: TP2-PP2: L${layers}E${dim}H${heads}" @@ -47,7 +69,7 @@ test() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 2 --tp-size 2 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2-coshard1.txt + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2.txt sleep 5 killall python sleep 5 @@ -65,7 +87,7 @@ test() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 2 --tp-size 4 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2-coshard1.txt + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2.txt sleep 5 killall python sleep 5 @@ -79,14 +101,22 @@ test() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 4 --tp-size 2 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4-coshard1.txt + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4.txt sleep 5 killall python sleep 5 killall python fi +} + +test_coshard_pp() +{ + layers=$1 + dim=$2 + heads=$3 + gpus=$4 - echo "testing ${gpus}-dev: PP-Coshard1: L${layers}E${dim}H${heads}" + echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ @@ -94,26 +124,96 @@ test() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --coshard 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard1.txt - + --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard.txt + sleep 5 killall python sleep 5 killall python } -# test Layers Dim Heads Coshard GPUs -test 18 256 8 8 4 -test 18 512 16 16 4 -test 18 768 24 24 4 +test_coshard_hybrid_tp_pp() +{ + layers=$1 + dim=$2 + heads=$3 + gpus=$4 + + if [ ${gpus} == 4 ] + then + echo "testing ${gpus}-dev: TP2-PP2: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 2 --tp-size 2 --dp-size 1 \ + --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2.txt + sleep 5 + killall python + sleep 5 + killall python + fi + + # Hybrid TP-1F1B -- 8 GPU + if [ ${gpus} == 8 ] + then + echo "testing ${gpus}-dev: TP4-PP2: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 2 --tp-size 4 --dp-size 1 \ + --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2.txt + sleep 5 + killall python + sleep 5 + killall python + + echo "testing ${gpus}-dev: TP2-PP4: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 4 --tp-size 2 --dp-size 1 \ + --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4.txt + sleep 5 + killall python + sleep 5 + killall python + fi +} + +test_all() +{ + layers=$1 + dim=$2 + heads=$3 + gpus=$4 + test_naive_pp $layers $dim $heads $gpus + test_naive_tp $layers $dim $heads $gpus + test_naive_hybrid_tp_pp $layers $dim $heads $gpus + test_coshard_pp $layers $dim $heads $gpus +} + + +# test Layers Dim Heads GPUs +test_all 18 256 8 4 +test_all 18 512 16 4 +test_all 18 768 24 4 -test 26 256 8 8 4 -test 26 512 16 16 4 -test 26 768 24 24 4 -# test 26 1024 32 32 4 +test_all 26 256 8 8 4 +test_all 26 512 16 16 4 +test_all 26 768 24 24 4 +test_all 26 1024 32 32 4 -test 34 256 8 8 8 -test 34 512 16 16 8 -test 34 768 24 24 8 -# test 34 1024 32 32 8 +test_all 34 256 8 8 +test_all 34 512 16 8 +test_all 34 768 24 8 +test_all 34 1024 32 8 python scripts/keep.py --gpus 8 diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 9a7b381c..66204888 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -1,13 +1,14 @@ """ example: +gpus=4 OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ + --nproc_per_node=${gpus} \ --nnodes=1 \ handcraft/swin/train.py \ --layers 18 --dim 256 --heads 8 \ - --pp-size 1 --tp-size 1 --dp-size 1 \ - --bs 4 --micro-bs 1 --coshard 8 --fp16 + --pp-size 4 --tp-size 1 --dp-size 1 \ + --bs ${gpus} --micro-bs 1 --use-coshard --fp16 """ import torch @@ -19,7 +20,7 @@ from cube.profiler.memory import memory_summary, model_summary from cube.runtime.adapter.reducer import Reducer from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.distnn import IdentityAllreduce, AllReduceIdentity +from cube.runtime.adapter.distnn import IdentityAllreduce, AllReduceIdentity, AllGatherSplit from handcraft.module.schedule import schedule_1f1b from handcraft.module.stage import PipeStage from handcraft.swin.utils import create_position_bias, trunc_normal_, window_partition, window_reverse, DropPath @@ -54,8 +55,8 @@ help='data parallelism size') parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b'], help='scheduling algorithm') -parser.add_argument('--coshard', type=int, default=1) -parser.add_argument('--fp16', action='store_true', default='') +parser.add_argument('--use-coshard', action='store_true', default=False) +parser.add_argument('--fp16', action='store_true', default=False) args = parser.parse_args() print(args) @@ -126,7 +127,7 @@ accum = times[start] assert end <= nlayers while end != nlayers: - if budget - accum < 0.5 * times[end]: + if times[end] > 0 and budget - accum < 0.5 * times[end]: break accum += times[end] end += 1 @@ -318,15 +319,37 @@ def __init__(self, dim, inner_dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., coshard=1): super().__init__() + assert (num_heads // args.tp_size) % coshard == 0 + # only coshard num heads of first two stages self.coshard = coshard - assert inner_dim % coshard == 0 self.attns = torch.nn.ModuleList( [WindowAttention( - dim, inner_dim // coshard, window_size, num_heads // coshard, - qkv_bias, qk_scale, attn_drop, proj_drop) for _ in range(coshard)] + dim, inner_dim // self.coshard, window_size, num_heads // self.coshard, + qkv_bias, qk_scale, attn_drop, proj_drop) for _ in range(self.coshard)] ) def forward(self, x, mask=None, recompute=True): + + # ===> sharding from both window and heads + # B = x.size(0) + # if B % 2 == 0: + # xs = torch.chunk(x, 2, dim=0) + # masks = torch.chunk(mask, 2, dim=0) if mask is not None else (None,) * 2 + # else: + # xs = (x,) + # masks = (mask,) + # outs = [] + # for bid, (cx, cmask) in enumerate(zip(xs, masks)): + # for attn in self.attns: + # cx_out = attn(cx, cmask, recompute) + # if len(outs) < bid + 1: + # outs.append(cx_out) + # else: + # outs[bid] = outs[bid] + cx_out + # outs = torch.concat(tuple(outs), dim=0) + # return outs + + # ===> sharding only from heads outs = None for attn in self.attns: x_out = attn(x, mask, recompute) @@ -369,23 +392,25 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) - if not use_coshard or args.coshard == 1: + if not use_coshard or layer_id in [2,3]: self.attn = WindowAttention( dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) else: - print(f'use colocate-sharding: {args.coshard}') + coshard = num_heads // args.tp_size + print(f'Swin-stage-{layer_id} using coshard {coshard}') self.attn = SeqWindowAttention( dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, coshard=args.coshard) + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, coshard=coshard) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - if not use_coshard or args.coshard == 1: + if not use_coshard or layer_id in [2,3]: self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) else: - self.mlp = SeqMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, coshard=args.coshard) + coshard = num_heads // args.tp_size + self.mlp = SeqMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, coshard=coshard) H, W = self.input_resolution if self.shift_size > 0: @@ -421,7 +446,8 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 ((args.micro_bs // args.dp_size, H * W, self.dim),), (torch.float32 if not args.fp16 else torch.float16,) ) - self.layer_id = layer_id # for profiling + self.layer_id = layer_id + self.inner_recompute = False if not use_coshard else layer_id in [0,1] def forward_(self, x): H, W = self.input_resolution @@ -443,7 +469,7 @@ def forward_(self, x): x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask, recompute=self.layer_id != 2) # nW*B, window_size*window_size, C + attn_windows = self.attn(x_windows, mask=self.attn_mask, recompute=self.inner_recompute) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) @@ -458,13 +484,13 @@ def forward_(self, x): # FFN x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x), recompute=self.layer_id != 2)) + x = x + self.drop_path(self.mlp(self.norm2(x), recompute=self.inner_recompute)) return x def forward(self, x): CudaTimer().start(f'layer{self.layer_id}') # layer-wise recompute - if self.layer_id == 2: + if not self.inner_recompute: x = checkpoint.checkpoint(self.forward_, x) # attention/mlp-wise recompute else: @@ -583,7 +609,7 @@ def create_basic_layter(dim, input_resolution, depth, num_heads, window_size, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, use_coshard=depth==2, layer_id=layer_id) + norm_layer=norm_layer, use_coshard=args.use_coshard, layer_id=layer_id) for i in range(depth)]) # patch merging layer if downsample is not None: @@ -775,9 +801,15 @@ def forward(self, x = None): if self.is_last_stage: CudaTimer().start('post-process') _, labels = self.data - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) + + def _post_process(x): + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + x = self.head(x) + return x + + x = checkpoint.checkpoint(_post_process, x) x = self.criterion(x, labels) CudaTimer().stop('post-process') @@ -878,7 +910,8 @@ def train_iter(model, dataloader): for step in range(iter_num): # if step == 0: - # model_summary(model, next(dataloader)) + # model.data = next(dataloader) + # model_summary(model, (), rank_only=1) if step >= 2: CudaTimer(enable=True).start('e2e') From 43bf1fe81b337669993a481dcde99d1ed5d40040 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 31 Mar 2022 09:17:17 +0000 Subject: [PATCH 0715/1892] fix cosharding tp-pp hybrid bug --- handcraft/swin/train.py | 45 ++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 66204888..842a6fe8 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -1,14 +1,14 @@ """ example: -gpus=4 +gpus=16 OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ handcraft/swin/train.py \ - --layers 18 --dim 256 --heads 8 \ - --pp-size 4 --tp-size 1 --dp-size 1 \ - --bs ${gpus} --micro-bs 1 --use-coshard --fp16 + --bs ${gpus} --micro-bs 1 --fp16 \ + --dp-size 1 --pp-size 16 --tp-size 1 \ + --layers 42 --dim 1024 --heads 32 --use-coshard """ import torch @@ -195,18 +195,31 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., coshard=1): super().__init__() + self._tp_group = _tp_group + self._tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + self.coshard = coshard assert hidden_features is not None assert hidden_features % coshard == 0 self.mlps = torch.nn.ModuleList( [Mlp(in_features, hidden_features // coshard, out_features, act_layer, drop) for _ in range(coshard)] ) + # remove tp communication inside each mlp as it will be + # done outside here + for mlp in self.mlps: + mlp._tp_size = 1 def forward(self, x, recompute=True): + if self._tp_size > 1: + x = IdentityAllreduce.apply(x, self._tp_group) + outs = None for mlp in self.mlps: x_out = mlp(x, recompute=recompute) outs = x_out if outs is None else outs + x_out + + if self._tp_size > 1: + outs = AllReduceIdentity.apply(outs, self._tp_group) return outs @@ -319,7 +332,9 @@ def __init__(self, dim, inner_dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., coshard=1): super().__init__() - assert (num_heads // args.tp_size) % coshard == 0 + self._tp_group = _tp_group + self._tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + assert (num_heads // args.tp_size) % coshard == 0 # only coshard num heads of first two stages self.coshard = coshard self.attns = torch.nn.ModuleList( @@ -327,6 +342,10 @@ def __init__(self, dim, inner_dim, window_size, num_heads, dim, inner_dim // self.coshard, window_size, num_heads // self.coshard, qkv_bias, qk_scale, attn_drop, proj_drop) for _ in range(self.coshard)] ) + # remove communication inside each attention as it will be + # done outside here + for attn in self.attns: + attn._tp_size = 1 def forward(self, x, mask=None, recompute=True): @@ -350,10 +369,16 @@ def forward(self, x, mask=None, recompute=True): # return outs # ===> sharding only from heads + if self._tp_size > 1: + x = IdentityAllreduce.apply(x, self._tp_group) + outs = None for attn in self.attns: x_out = attn(x, mask, recompute) outs = x_out if outs is None else outs + x_out + + if self._tp_size > 1: + outs = AllReduceIdentity.apply(outs, self._tp_group) return outs @@ -398,7 +423,7 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) else: coshard = num_heads // args.tp_size - print(f'Swin-stage-{layer_id} using coshard {coshard}') + print_each_rank(f'Swin-stage-{layer_id} using coshard {coshard}', rank_only=0) self.attn = SeqWindowAttention( dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, coshard=coshard) @@ -919,10 +944,10 @@ def train_iter(model, dataloader): # training train_iter(model, dataloader) - if step == 0: - print_each_rank('passed first iteration', rank_only=0) - print_each_rank('memory consumption before optimizer:', rank_only=0) - memory_summary() + # if step == 0: + # print_each_rank('passed first iteration', rank_only=0) + # print_each_rank('memory consumption before optimizer:', rank_only=0) + # memory_summary() optimizer.step() optimizer.zero_grad() From 0bf9903d587bba5de5be78153778ab1b1c6ebf9d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Apr 2022 01:58:32 +0000 Subject: [PATCH 0716/1892] add test for multiple node --- handcraft/swin/test-multi-node.sh | 203 ++++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100755 handcraft/swin/test-multi-node.sh diff --git a/handcraft/swin/test-multi-node.sh b/handcraft/swin/test-multi-node.sh new file mode 100755 index 00000000..1feec5f7 --- /dev/null +++ b/handcraft/swin/test-multi-node.sh @@ -0,0 +1,203 @@ +# swin transformer constant head dim == 32 + +evaldir=eval/swin-coshard +mkdir -p ${evaldir} + + +img_size=1536 +window_size=48 + + +test_naive_pp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + + echo "testing ${gpus}-dev: Pure PP${coshard}: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_naive_tp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + + echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 1 --tp-size ${gpus} --dp-size 1 \ + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_naive_hybrid_tp_pp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + + # Hybrid TP-1F1B -- 16 GPU + if [ ${gpus} == 16 ] + then + echo "testing ${gpus}-dev: TP8-PP2: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 2 --tp-size 8 --dp-size 1 \ + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp8pp2.txt + sleep 5 + killall python + sleep 5 + killall python + + echo "testing ${gpus}-dev: TP2-PP4: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 4 --tp-size 4 --dp-size 1 \ + --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp4.txt + sleep 5 + killall python + sleep 5 + killall python + fi +} + +test_coshard_pp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + + echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_coshard_hybrid_tp_pp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + + # Hybrid TP-1F1B -- 8 GPU + if [ ${gpus} == 16 ] + then + echo "testing ${gpus}-dev: TP8-PP2: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 2 --tp-size 8 --dp-size 1 \ + --bs 64 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp8pp2-coshard.txt + sleep 5 + killall python + sleep 5 + killall python + + # echo "testing ${gpus}-dev: TP4-PP4: L${layers}E${dim}H${heads}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=${nodes} \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ + # handcraft/swin/train.py \ + # --layers ${layers} --dim ${dim} --heads ${heads} \ + # --img-size ${img_size} --window-size ${window_size} \ + # --pp-size 4 --tp-size 4 --dp-size 1 \ + # --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp4-coshard.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + fi +} + +test_all() +{ + layers=$1 + dim=$2 + heads=$3 + gpus=$4 + test_naive_pp $layers $dim $heads $gpus + test_naive_tp $layers $dim $heads $gpus + test_naive_hybrid_tp_pp $layers $dim $heads $gpus + test_coshard_pp $layers $dim $heads $gpus +} + + +# test Layers Dim Heads Nodes GPUs +# test_naive_tp 42 1024 32 2 8 +test_coshard_hybrid_tp_pp 42 1024 32 2 16 + +# test_naive_tp 50 1024 32 2 8 +test_coshard_hybrid_tp_pp 50 1024 32 2 16 + +# test_naive_tp 34 1024 32 2 8 +test_coshard_hybrid_tp_pp 34 1024 32 2 16 + +python scripts/keep.py --gpus 8 From 919af91e2a7109d1bc25ca0147ce4c04a93932df Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Apr 2022 16:51:09 +0800 Subject: [PATCH 0717/1892] correctness verify --- handcraft/mbart/train.py | 781 +++++++++++++++++++++++++++++++++++ handcraft/module/schedule.py | 211 +++++++++- handcraft/module/stage.py | 9 + 3 files changed, 997 insertions(+), 4 deletions(-) create mode 100644 handcraft/mbart/train.py diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py new file mode 100644 index 00000000..d2a8164c --- /dev/null +++ b/handcraft/mbart/train.py @@ -0,0 +1,781 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + handcraft/mbart/train.py \ + --layers 12 --hidden-size 1024 --heads 16 \ + --dp-size 1 --pp-size 4 --tp-size 1 \ + --bs 4 --micro-bs 1 --schedule 1f1b +""" + +from typing import Optional +import argparse +import math +import numpy as np +import torch + +import cube +from cube.runtime.device import DeviceGroup +from cube.runtime.adapter.reducer import Reducer +from cube.runtime.adapter.distnn import ReduceBroadcast, AllReduceIdentity, IdentityAllreduce + +from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary +from cube.profiler.timer import print_each_rank + + +from handcraft.module.schedule import schedule_1f1b, schedule_tp1f1b +from handcraft.module.stage import PipeStage +from handcraft.mbart.swap import SwapEmbed, get_swap_parameters + +torch.manual_seed(0) +np.random.seed(0) + + +parser = argparse.ArgumentParser(description='mbart') +# model arch +parser.add_argument('--layers', type=int, default=12, + help='number encoder/decoder of layers') +parser.add_argument('--hidden-size', type=int, default=1024, + help='hidden size') +parser.add_argument('--heads', type=int, default=16, + help='number of heads') +# training config +parser.add_argument('--bs', type=int, default=256, + help='num of micro batch') +parser.add_argument('--micro-bs', type=int, default=1, + help='micro batch size') +parser.add_argument('--fp16', action='store_true', default=False) +# parallelism +parser.add_argument('--pp-size', type=int, default=1, + help='pipeline parallelism size') +parser.add_argument('--tp-size', type=int, default=1, + help='tensor parallelism size') +parser.add_argument('--dp-size', type=int, default=1, + help='data parallelism size') +parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b', 'tp1f1b'], + help='scheduling algorithm') +parser.add_argument('--use-swap', action='store_true', default=False, + help='swap on embedding weight') + +args = parser.parse_args() +print(args) + +_tp_group = -1 + +_dp_group = -1 +_dp_reducer = None + +_pp_group = -1 +_pp_global_ranks = () +_first_encoder_stage = 0 +_first_decoder_stage = 0 +_layer_divisions = [] + +_schedule = schedule_1f1b if args.schedule == '1f1b' else schedule_tp1f1b +if args.schedule == 'tp1f1b': + assert args.tp_size == 1 and args.dp_size == 1, "tp1f1b only supports pure pipeline" + +_pp_embed_group = -1 + +cube.init() +dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( + [args.dp_size, args.pp_size, args.tp_size] +) + +if len(dp_ranks) != 1: + assert False, "DP is not supported yet" + print_each_rank(f'initializing dp ranks: {dp_ranks}') + _dp_group = DeviceGroup().get_group(dp_ranks) + _dp_reducer = Reducer(dp_ranks) + +if len(tp_ranks) != 1: + print_each_rank(f'initializing tp ranks: {tp_ranks}') + _tp_group = DeviceGroup().get_group(tp_ranks) + +if len(pp_ranks) != 1: + print_each_rank(f'initializing pp ranks: {pp_ranks}') + _pp_group = DeviceGroup().get_group(pp_ranks) + _pp_global_ranks = tuple(pp_ranks) + + # layer division + encoder_time = [1] * args.layers + decoder_time = [1] * args.layers + times = encoder_time + decoder_time + num_stages = torch.distributed.get_world_size(_pp_group) + budget = sum(times) // num_stages + print_each_rank(f'budget: {budget}', rank_only=0) + start, end = 0, 1 + for idx in range(num_stages): + accum = times[start] + assert end <= args.layers * 2 + while end != args.layers * 2: + accum += times[end] + if accum > budget: + break + end += 1 + if idx == num_stages - 1: + end = args.layers * 2 + _layer_divisions.append((start, end)) + if start <= args.layers and end > args.layers: + _first_decoder_stage = idx + start, end = end, end+1 + assert _first_decoder_stage != _first_encoder_stage, "Not supported yet" +else: + _layer_divisions = [(0, args.layers * 2)] +print_each_rank( + f"layer divisions: {_layer_divisions} | " + f"first encoder stage: {_first_encoder_stage} | " + f"first decoder stage: {_first_decoder_stage}", rank_only=0 +) + + +# create embed group: first encoder, first decoder +if args.schedule == '1f1b' and args.pp_size > 1: + grid = np.arange( + args.pp_size * args.tp_size).reshape((args.pp_size, args.tp_size)) + encoder_preprocess = grid[_first_encoder_stage,:] + decoder_preprocess = grid[_first_decoder_stage,:] + embed_ranks = np.vstack((encoder_preprocess, decoder_preprocess)) + grank = torch.distributed.get_rank() + for gid in range(args.tp_size): + embed_rank = embed_ranks[:,gid] + embed_rank = np.squeeze(embed_rank).tolist() + print_each_rank(f'creating embed group: {embed_rank}') + group = DeviceGroup().get_group(embed_rank) + if grank in embed_rank: + print(f'rank [{grank}]: embedding group: {embed_rank}') + _pp_embed_group = group + + +class Config: + + num_embeddings = 500000 + decoder_layers = args.layers + encoder_layers = args.layers + embed_dim = args.hidden_size + attention_heads = args.heads + + attention_inner_dim = attention_heads * 64 + ffn_dim = 4 * embed_dim + + attention_dropout = 0.0 # for correctness veirfication + activation_dropout = 0.0 # for correctness veirfication + dropout = 0.0 # for correctness veirfication + + max_target_positions = 1024 + max_source_positions = 1024 + + # classification task + pooler_dropout = 0.0 + num_classes = 3 + + +def attn_fn(query: torch.Tensor, key: torch.Tensor, + wq: torch.Tensor, wq_bias: Optional[torch.Tensor], + wk: torch.Tensor, wk_bias: Optional[torch.Tensor], + wv: torch.Tensor, wv_bias: Optional[torch.Tensor], + wout: torch.Tensor, wout_bias: Optional[torch.Tensor], + h: int, scale: float, dropout: float, mask=True): + """ + query, key: (L, N, E) = (seqlen, batch size, embed_dim) + wq, wk, wv weight: [(num_head * dim_head), E] + dropout: float + h: int: number of heads + """ + num_head = h + L, N = query.size(0), query.size(1) + dim_head = wq.size(0) // num_head + + q = torch.nn.functional.linear(query, wq, wq_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(key, wk, wk_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(key, wv, wv_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + if mask: # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, wout, wout_bias) # L N (h d), E E -> L N E + return output + + +class PositionalEmbedding(torch.nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, seq_len: int): + positions = torch.arange( + 0, seq_len, dtype=torch.long, device=torch.cuda.current_device() + ) + return super().forward(positions + self.offset) + + +class MultiheadAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, bias=True): + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + self.inner_dim = inner_dim + self.head_dim = inner_dim // num_heads + self.num_heads = num_heads // self.tp_size + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # K + self.k_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) + self.k_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None + # V + self.v_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) + self.v_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None + # Q + self.q_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) + self.q_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, self.inner_dim // self.tp_size)) + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + + def forward(self, query: torch.Tensor, key: torch.Tensor): + if self.tp_size > 1: + if key is not query: + key = IdentityAllreduce.apply(key, self.tp_group) + query = IdentityAllreduce.apply(query, self.tp_group) + attn = attn_fn(query, key, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p) + if self.tp_size > 1: + attn = AllReduceIdentity.apply(attn, self.tp_group) + return attn + + +class EncoderLayer(PipeStage): + + def __init__(self, cfg: Config): + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + self.cfg = cfg + self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) + self.dropout = torch.nn.Dropout(p=cfg.dropout) + self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim) + self.fc2 = torch.nn.Linear(cfg.ffn_dim, cfg.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) + + self.inputs_info = ( + ((self.cfg.max_source_positions, 1, self.cfg.embed_dim),), + (torch.float32 if not args.fp16 else torch.float16,) + ) + self.outputs_info = ( + ((self.cfg.max_source_positions, 1, self.cfg.embed_dim),), + (torch.float32 if not args.fp16 else torch.float16,) + ) + + def forward(self, x): + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x, x) + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + + if self.tp_size > 1: + x = IdentityAllreduce.apply(x, self.tp_group) + x = self.fc1(x) + x = torch.nn.functional.gelu(x) + x = self.activation_dropout(x) + x = self.fc2(x) + if self.tp_size > 1: + x = AllReduceIdentity.apply(x, self.tp_group) + + x = self.dropout(x) + + x = x + residual + return x + + +class DecoderLayer(PipeStage): + + def __init__(self, cfg: Config): + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + self.cfg = cfg + self.dropout = torch.nn.Dropout(p=cfg.dropout) + self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) + self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + + self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) + # encoder atten + self.encoder_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) + self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) + + self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim) + self.fc2 = torch.nn.Linear(cfg.ffn_dim, cfg.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) + + self.inputs_info = ( + ((self.cfg.max_target_positions, 1, self.cfg.embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim)), + (torch.float32 if not args.fp16 else torch.float16, + torch.float32 if not args.fp16 else torch.float16,) + ) + self.outputs_info = ( + ((self.cfg.max_target_positions, 1, self.cfg.embed_dim), + (self.cfg.max_source_positions, 1, self.cfg.embed_dim)), + (torch.float32 if not args.fp16 else torch.float16, + torch.float32 if not args.fp16 else torch.float16,) + ) + + def forward(self, x, encoder_out): + # print(f'decoder layer: x: {x.size()}, encoder_out: {encoder_out.size()}') + residual = x + x = self.self_attn_layer_norm(x) + + # self attention + x = self.self_attn(x, x) + x = self.dropout(x) + x = residual + x + + # encoder attn + residual = x + x = self.encoder_attn_layer_norm(x) + x = self.encoder_attn(x, encoder_out) + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + + if self.tp_size > 1: + x = IdentityAllreduce.apply(x, self.tp_group) + x = self.fc1(x) + x = torch.nn.functional.gelu(x) + x = self.activation_dropout(x) + x = self.fc2(x) + if self.tp_size > 1: + x = AllReduceIdentity.apply(x, self.tp_group) + + x = self.dropout(x) + x = x + residual + return x, encoder_out + + +class MBartClassificationHead(torch.nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.num_classes = num_classes + self.dense = torch.nn.Linear(input_dim, inner_dim) + self.dropout = torch.nn.Dropout(p=pooler_dropout) + self.out_proj = torch.nn.Linear(inner_dim, num_classes) + self.loss_fct = torch.nn.CrossEntropyLoss() + + def forward(self, dec: torch.Tensor, labels): + # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] + dec = dec.transpose(0, 1)[:,-1,:] + sentence_represent = dec + hidden_states = self.dropout(sentence_represent) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + logits = self.out_proj(hidden_states) + loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) + return loss + + +class ShardEmbed(torch.nn.Module): + + def __init__(self, cfg: Config): + super().__init__() + self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) + self.tp_id = 0 if self.tp_group == -1 else torch.distributed.get_rank(self.tp_group) + + self.cfg = cfg + print(f'initialize sharding embed (x{self.tp_size})') + + self.swap = args.use_swap + if self.swap: + assert args.schedule == '1f1b', "only 1f1b can use swap" + self.embed = SwapEmbed(self.cfg.num_embeddings, self.cfg.embed_dim) + else: + self.vocab_start_index = self.cfg.num_embeddings // self.tp_size * self.tp_id + self.vocab_end_index = self.cfg.num_embeddings // self.tp_size * (self.tp_id + 1) + self.weight = torch.nn.Parameter( + torch.ones((self.cfg.num_embeddings // self.tp_size, self.cfg.embed_dim)) + ) + + # encoder-preprocess + self.embed_positions_encoder = PositionalEmbedding(cfg.max_source_positions, cfg.embed_dim) + self.embed_scale_encoder = math.sqrt(cfg.embed_dim) + self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.embed_dim) + + # decoder-preprocess + self.embed_scale_decoder = math.sqrt(cfg.embed_dim) + self.embed_positions_decoder = PositionalEmbedding(cfg.max_target_positions, cfg.embed_dim) + self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.embed_dim) + + self.init_weight() + + def embed_lookup(self, tokens, dst: Optional[int] = None): + """ + Embedding lookup + if dst is None, use all + """ + if self.swap: + embed = self.embed(tokens) + elif self.tp_size > 1: + mask = (tokens < self.vocab_start_index) | \ + (tokens >= self.vocab_end_index) + tokens = tokens.clone() - self.vocab_start_index + tokens[mask] = 0 + embed = torch.nn.functional.embedding(tokens, self.weight) + embed[mask, :] = 0.0 + if dst is None: + assert _tp_group != -1 + embed = AllReduceIdentity.apply(embed, self.tp_group) + else: + assert self.tp_group is None # args.sharding = True + embed = ReduceBroadcast.apply(embed, dst, None) + else: + embed = torch.nn.functional.embedding(tokens, self.weight) + return embed + + def forward(self, tokens, encoder=False, decoder=False, dst: Optional[int] = None): + """ + If dst is not None: the embedding is sharded across all devices + using tp1f1b, and hence requires a Reduce on the target rank. + """ + assert encoder ^ decoder, "can only be either encoder or decoder" + embed = self.embed_lookup(tokens, dst) + x = embed + self.embed_positions_encoder(embed.size(1)) + if encoder: + x = self.layernorm_embedding_encoder(x) + if decoder: + x = self.layernorm_embedding_decoder(x) + x = torch.nn.functional.dropout(x, p=0.0) + x = x.transpose(0, 1) + return x + + def init_weight(self): + for param in self.parameters(): + torch.nn.init.constant_(param, 0.1) + + +class MBart(PipeStage): + + def __init__(self, cfg: Config): + super().__init__() + self.set_pipeline(_pp_global_ranks) + self.first_encoder_stage = _first_encoder_stage + self.first_decoder_stage = _first_decoder_stage + + self.cfg = cfg + encoders = [EncoderLayer(cfg) for _ in range(self.cfg.encoder_layers)] + decoders = [DecoderLayer(cfg) for _ in range(self.cfg.decoder_layers)] + + start, end = _layer_divisions[self.stage_local_rank] + print_each_rank(f'initializing layer ranging from [{start}, {end})') + + self.encoder_preprocess = self.is_first_stage + self.encoder_forward = start < cfg.encoder_layers + self.decoder_preprocess = start <= cfg.encoder_layers and end > cfg.encoder_layers + self.decoder_forward = end > cfg.encoder_layers + self.sharding = args.schedule == 'tp1f1b' + print_each_rank( + f"encoder: (pre: {self.encoder_preprocess}) {self.encoder_forward} | " + f"decoder (pre: {self.decoder_preprocess}) {self.decoder_forward} | " + f"post-process: {self.is_last_stage} | sharding {self.sharding}" + ) + + inputs_info = None + outputs_info = None + + self.embed: ShardEmbed = None + if self.encoder_preprocess: + self.embed = ShardEmbed(cfg) if self.embed is None else self.embed + inputs_info = ((), ()) if inputs_info is None else inputs_info + + if self.encoder_forward: + self.encoders = torch.nn.ModuleList( + encoders[start:min(end, cfg.encoder_layers)] + ) + self.layer_norm_encoder = None + if self.decoder_preprocess or self.decoder_forward: + self.layer_norm_encoder = torch.nn.LayerNorm(cfg.embed_dim) + + inputs_info = self.encoders[0].inputs_info if inputs_info is None else inputs_info + outputs_info = self.encoders[-1].outputs_info + + if self.decoder_preprocess: + self.embed = ShardEmbed(cfg) if self.embed is None else self.embed + inputs_info = encoders[-1].outputs_info if inputs_info is None else inputs_info + + if self.decoder_forward: + self.decoders = torch.nn.ModuleList( + decoders[max(0, start-cfg.encoder_layers): end-cfg.encoder_layers] + ) + self.layer_norm_decoder = None + if self.is_last_stage: + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.embed_dim) + + inputs_info = self.decoders[0].inputs_info if inputs_info is None else inputs_info + outputs_info = self.decoders[-1].outputs_info + + if self.is_last_stage: + self.head = MBartClassificationHead(cfg.embed_dim, 1024, cfg.num_classes, 0.0) + outputs_info = ((1,), torch.float32 if not args.fp16 else torch.float16) + + if self.sharding: + self.embed = ShardEmbed(cfg) if self.embed is None else self.embed + + assert inputs_info is not None + assert outputs_info is not None + self.inputs_info = inputs_info + self.outputs_info = outputs_info + print_each_rank(f'stage: inputs: {inputs_info} | outputs: {outputs_info}') + self.init_weight() + + def init_weight(self): + for param in self.parameters(): + torch.nn.init.constant_(param, 0.01) + + def forward_encoder_shard(self): + """ + Return detached outputs with enabled gradient + """ + source_tokens, _, _ = self.data + enc = self.embed(source_tokens, encoder=True, dst=self.first_encoder_stage) + model.push(enc, 'encoder_sharding_output') + if self.stage_global_rank == self.first_encoder_stage: + enc = enc.detach().requires_grad_() + self.push(enc, 'encoder_preprocess') + return enc + + def forward_decoder_shard(self): + """ + Return detached outputs with enabled gradient + """ + _, prev_tokens, _ = self.data + dec = self.embed(prev_tokens, decoder=True, dst=self.first_decoder_stage) + model.push(dec, 'decoder_sharding_output') + if self.stage_global_rank == self.first_decoder_stage: + dec = dec.detach().requires_grad_() + self.push(dec, 'decoder_preprocess') + return dec + + def forward(self, enc=None, dec=None, recompute=False): + """ + enc: encoder input + dec: decoder input + recompute: outside control for tp1f1b + """ + if self.encoder_preprocess: + if self.sharding: + if recompute: + enc = self.get_last('encoder_preprocess') + else: + enc = self.pop('encoder_preprocess') + else: + source_tokens, _, _ = self.data + enc = self.embed(source_tokens, encoder=True) + + if self.encoder_forward: + for layer in self.encoders: + enc = layer(enc) + if self.layer_norm_encoder is not None: + enc = self.layer_norm_encoder(enc) + output = enc + + if self.decoder_preprocess: + if self.sharding: + if recompute: + dec = self.get_last('decoder_preprocess') + else: + dec = self.pop('decoder_preprocess') + else: + _, prev_tokens, _ = self.data + dec = self.embed(prev_tokens, decoder=True) + + if self.decoder_forward: + assert enc is not None + for layer in self.decoders: + dec, enc = layer(dec, enc) + if self.layer_norm_decoder is not None: + dec = self.layer_norm_decoder(dec) + output = (enc, dec) + + if self.is_last_stage: + _, _, label = self.data + loss = self.head(dec, label) + output = loss + + return output + + +def reduce_embed(model: MBart, pp_embed_group): + """ + Embedding gradients needs to be reduced across pipeline stages + """ + if pp_embed_group == -1: + return + if isinstance(model.embed, torch.nn.Module): + if model.embed.swap: + with torch.no_grad(): + grad = model.embed.weight.grad + grad = grad.cuda() + else: + grad = model.embed.weight.grad + else: + grad = None + if grad is not None: + CudaTimer().start('comm') + torch.distributed.all_reduce(grad, group=pp_embed_group) + torch.cuda.synchronize() + CudaTimer().stop('comm') + if isinstance(model.embed, torch.nn.Module): + if model.embed.swap: + with torch.no_grad(): + model.embed.embed.weight.grad.copy_(grad) + torch.cuda.synchronize() + + +class MBartDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int, cfg: Config): + self.bs = batch_size + self.cfg = cfg + super().__init__( + shapes=( + [batch_size, cfg.max_source_positions,], + [batch_size, cfg.max_target_positions,], + [batch_size,] + ), + dtypes=( + torch.int64, + torch.int64, + torch.int, + ), + batch_dims=(0, 0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + source_token = torch.randint( + 0, 25000, + size=(self.bs, cfg.max_source_positions,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + target_token = torch.randint( + 0, 25000, + size=(self.bs, cfg.max_target_positions,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + labels = torch.randint( + 0, self.cfg.num_classes, + size=(self.bs,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + return (source_token, target_token, labels) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + +if __name__ == '__main__': + + + cfg = Config() + print_each_rank(f'enc/dec layer#: {cfg.encoder_layers}, embed_dim#: {cfg.embed_dim}, heads#: {cfg.attention_heads}, ffn_dim#: {cfg.ffn_dim}', rank_only=0) + + model = MBart(cfg) + model = model.half().cuda() if args.fp16 else model.cuda() + + dataloader = MBartDataLoader(args.micro_bs, cfg) + + parameters = get_swap_parameters() + list(model.parameters()) if args.use_swap else model.parameters() + optimizer = torch.optim.Adam(parameters, lr=3e-05, betas=(0.9, 0.98)) + + print_each_rank('model weight consumpition:') + memory_summary() + + CudaTimer(enable=False).warmup() + iter_num = 6 + for step in range(iter_num): + if step >= 2: + CudaTimer(enable=True).start('e2e') + + # train 1 step + num_microbatch = args.bs // args.micro_bs + if args.pp_size > 1: + _schedule(model, dataloader, num_microbatch, recompute=True) + reduce_embed(model, _pp_embed_group) + else: + for _ in range(num_microbatch): + model.data = next(dataloader) + loss = model() + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + if step >= 2: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('memory after optimizer:', rank_only=0) + memory_summary() + + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-2, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-2) + memory_summary() diff --git a/handcraft/module/schedule.py b/handcraft/module/schedule.py index e9d9dfd1..d2362bdf 100644 --- a/handcraft/module/schedule.py +++ b/handcraft/module/schedule.py @@ -1,9 +1,7 @@ -from typing import List, Tuple +from typing import List import torch from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup from handcraft.module.stage import PipeStage @@ -281,6 +279,9 @@ def schedule_1f1b(model: PipeStage, with torch.no_grad(): outputs = forward_step(model, *inputs) model.push(None, 'outputs') + # correctness checkprint + # if model.is_last_stage: + # print(outputs) else: outputs = forward_step(model, *inputs) model.push(outputs, 'outputs') @@ -321,4 +322,206 @@ def schedule_1f1b(model: PipeStage, if not model.is_first_stage: send_backward(input_grads, prev_rank) - model.assert_empty_cached() \ No newline at end of file + model.assert_empty_cached() + + +def schedule_tp1f1b(model: PipeStage, + dataloader, + num_microbatch: int, + recompute=False): + + def tp_encoder_preprocess(model: PipeStage) -> torch.Tensor: + model.data = next(dataloader) + enc = model.forward_encoder_shard() + return (enc,) + + def tp_decoder_preprocess(model: PipeStage) -> torch.Tensor: + model.data = next(dataloader) + dec = model.forward_decoder_shard() + return (dec,) + + def tp_encoder_backward(model: PipeStage): + enc = model.pop('encoder_sharding_output') + if model.stage_local_rank == model.first_encoder_stage: + grads = model.pop('encoder_sharding_grad') + else: + grads = (torch.empty_like(enc),) + backward_step((), (enc,), grads) + + def tp_decoder_backward(model: PipeStage): + dec = model.pop('decoder_sharding_output') + if model.stage_local_rank == model.first_decoder_stage: + grads = model.pop('decoder_sharding_grad') + else: + grads = (torch.empty_like(dec),) + backward_step((), (dec,), grads) + + num_stage = model.num_stages + rank = model.stage_local_rank + prev_rank = model.prev_stage_global_grank + next_rank = model.next_stage_global_rank + fofst = [-(step // 2) for step in range(num_stage)] + bofst = [-(num_stage - 1 - (step // 2)) for step in range(num_stage)] + + fofst = fofst[model.stage_local_rank] + bofst = bofst[model.stage_local_rank] + last_backward = (None,) + last_forward = (None,) + + for step in range(num_microbatch + num_stage - 1): + fmid, bmid = step + fofst, step + bofst + encoder_fmid = step + decoder_fmid = step - num_stage // 2 // 2 + encoder_bmid = step + 1 - num_stage // 2 * 2 + decoder_bmid = step + 1 - int(num_stage // 2 * 1.5) + do_backward = 0 <= bmid and bmid <= num_microbatch - 1 + do_forward = 0 <= fmid and fmid <= num_microbatch - 1 + + # step1: tp encoder forward + encoder_inputs = None + if 0 <= encoder_fmid and encoder_fmid <= num_microbatch - 1: + encoder_inputs = tp_encoder_preprocess(model) + # step2: tp decoder forward + decoder_inputs = None + if 0 <= decoder_fmid and decoder_fmid <= num_microbatch - 1: + decoder_inputs = tp_decoder_preprocess(model) + + # step 3: forward + backward + if rank % 2 == 0: + # inter-barrier + inputs = () + if not model.is_first_stage: + if do_forward and last_backward != (None,): + # print(f'rank {rank} send backward grad + recv forward output ') + inputs = send_backward_recv_forward(last_backward, model, prev_rank) + elif do_forward: + # print(f'rank {rank} recv forward output ') + inputs = recv_forward(model, prev_rank) + elif last_backward != (None,): + # print(f'rank {rank} send backward grad ') + send_backward(last_backward, prev_rank) + + # forward + if do_forward: + + if model.stage_local_rank == model.first_encoder_stage and encoder_inputs is not None: + model.push(encoder_inputs, 'inputs') + elif model.stage_local_rank == model.first_decoder_stage and decoder_inputs is not None: + assert len(inputs) == 1 and len(decoder_inputs) == 1 + model.push((inputs[0], decoder_inputs[0]), 'inputs') + else: + model.push(inputs, 'inputs') + + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs, recompute=True) + model.push(None, 'outputs') + else: + outputs = forward_step(model, *inputs) + model.push(outputs, 'outputs') + + # recompute if backward is needed + if do_backward: + inputs, outputs_bp = model.pop('inputs'), model.pop('outputs') + if recompute: + assert outputs_bp is None + outputs_bp = forward_step(model, *inputs) + + # intra-barrier send recv + output_grads = (None,) + if (do_forward and not model.is_last_stage) and (do_backward and not model.is_last_stage): + # send forward recv backward + # print(f'rank {rank} recv backward grad + send forward output ') + output_grads = send_forward_recv_backward(outputs, model, next_rank) + elif do_forward and not model.is_last_stage: + # print(f'rank {rank} send forward output ') + send_forward(outputs, next_rank) + elif do_backward and not model.is_last_stage: + # print(f'rank {rank} recv backward grad ') + output_grads = recv_backward(model, next_rank) + + # backward + last_backward = (None,) + if do_backward: + # inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) + input_grads = backward_step(inputs, outputs_bp, output_grads) + + if model.stage_local_rank == model.first_encoder_stage: + assert len(input_grads) == 1 + model.push(input_grads, 'encoder_sharding_grad') + elif model.stage_local_rank == model.first_decoder_stage: + assert len(input_grads) == 2 + model.push((input_grads[1],), 'decoder_sharding_grad') + input_grads = (input_grads[0],) + last_backward = input_grads + + # step 3: backward + forward + if rank % 2 == 1: + # inter-barrier + if model.is_last_stage: + output_grads = (None,) + else: + if do_backward and last_forward != (None,): + # print(f'rank {rank} recv backward grad + send forward output ') + output_grads = send_forward_recv_backward(last_forward, model, next_rank) + elif do_backward: + # print(f'rank {rank} recv backward grad ') + output_grads = recv_backward(model, next_rank) + elif last_forward != (None,): + # print(f'rank {rank} send forward output ') + send_forward(last_forward, next_rank) + + # backward + last_backward = (None,) + if do_backward: + inputs, outputs_bp = model.pop('inputs'), model.pop('outputs') + # backward + input_grads = backward_step(inputs, outputs_bp, output_grads) + last_backward = input_grads + + # intra-barrier + if do_backward and do_forward: + # print(f'rank {rank} send backward grad + recv forward output ') + inputs = send_backward_recv_forward(input_grads, model, prev_rank) + elif do_backward: + # print(f'rank {rank} send backward grad ') + send_backward(input_grads, prev_rank) + elif do_forward: + # print(f'rank {rank} recv forward output ') + inputs = recv_forward(model, prev_rank) + + # forward + last_forward = (None,) + if do_forward: + # forward step + model.push(inputs, 'inputs') + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs, recompute=True) + model.push(None, 'outputs') + # correctness check print + # if model.is_last_stage: + # print(outputs) + else: + outputs = forward_step(model, *inputs) + model.push(outputs, 'outputs') + last_forward = outputs + + next_backward = 0 <= (bmid+1) and (bmid+1) <= num_microbatch - 1 + if next_backward: + if recompute: + inputs, outputs_bp = model.pop('inputs'), model.pop('outputs') + assert outputs_bp is None + outputs = forward_step(model, *inputs) + model.push_ahead(inputs, 'inputs') + model.push_ahead(outputs, 'outputs') + + # step 4: sharding decoder backward + if 0 <= decoder_bmid and decoder_bmid <= num_microbatch - 1: + tp_decoder_backward(model) + + # step 5: sharding encoder backward + if 0 <= encoder_bmid and encoder_bmid <= num_microbatch - 1: + tp_encoder_backward(model) + + model.assert_empty_cached() diff --git a/handcraft/module/stage.py b/handcraft/module/stage.py index b2939a50..fd32a90a 100644 --- a/handcraft/module/stage.py +++ b/handcraft/module/stage.py @@ -80,9 +80,18 @@ def set_pipeline(self, group_global_ranks: Tuple[int]): self._is_first_stage = self._stage_lrank == 0 self._is_last_stage = self._stage_lrank == self.num_stages - 1 + def get_last(self, region: str = 'default') -> Any: + return self._cached[region][-1] + + def pop_last(self, region: str = 'default') -> Any: + return self._cached[region].pop(-1) + def pop(self, region: str = 'default') -> Any: return self._cached[region].pop(0) + def push_ahead(self, val: Any, region: str = 'default') -> Any: + self._cached[region] = [val] + self._cached[region] + def push(self, val: Any, region: str = 'default'): if region not in self._cached: self._cached[region] = [] From 1e288061ebc995b0a8b7a375709f630ff9c664db Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Apr 2022 17:14:24 +0800 Subject: [PATCH 0718/1892] renew mbart --- handcraft/mbart/mbart.py | 771 ----------------- handcraft/mbart/mbart_hybrid.py | 787 ------------------ handcraft/mbart/run-1f1b-swap.sh | 100 --- handcraft/mbart/run-recompute-arch.sh | 76 -- .../mbart/run-recompute-full-v100-32gb.sh | 301 ------- handcraft/mbart/schedule.py | 551 ------------ handcraft/mbart/test-fp32.sh | 203 +++++ handcraft/mbart/tp.py | 130 --- 8 files changed, 203 insertions(+), 2716 deletions(-) delete mode 100644 handcraft/mbart/mbart.py delete mode 100644 handcraft/mbart/mbart_hybrid.py delete mode 100755 handcraft/mbart/run-1f1b-swap.sh delete mode 100755 handcraft/mbart/run-recompute-arch.sh delete mode 100755 handcraft/mbart/run-recompute-full-v100-32gb.sh delete mode 100644 handcraft/mbart/schedule.py create mode 100755 handcraft/mbart/test-fp32.sh delete mode 100644 handcraft/mbart/tp.py diff --git a/handcraft/mbart/mbart.py b/handcraft/mbart/mbart.py deleted file mode 100644 index d17edf1d..00000000 --- a/handcraft/mbart/mbart.py +++ /dev/null @@ -1,771 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers 12 --hidden-size 1024 --heads 16 \ - --use-1f1b --nmb 4 --iter-nmb 4 \ - --use-recompute --use-swap -""" - -from typing import Optional -import argparse -import math -import torch - -import cube -from cube.runtime.device import DeviceGroup -from cube.runtime.syndata import SynTextDataLoader -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - -from handcraft.mbart.schedule import schedule_1f1b, schedule_tp_1f1b_pack -from handcraft.mbart.swap import SwapEmbed, get_swap_parameters -from handcraft.mbart.tp import ReduceBroadcast - - -_pp_group = -1 -_pp_embed_group = -1 -_pp_next_rank = None -_pp_prev_rank = None -_layer_divisions = [] - -parser = argparse.ArgumentParser(description='swin') -# model arch -parser.add_argument('--layers', type=int, default=12, - help='number encoder/decoder of layers') -parser.add_argument('--hidden-size', type=int, default=1024, - help='hidden size') -parser.add_argument('--heads', type=int, default=16, - help='number of heads') -# training config -parser.add_argument('--nmb', type=int, - help='num of micro batch') -parser.add_argument('--iter-nmb', type=int, default=0, - help='num of micro batch per scheduling iteration (1f1b only)') -# parallelism -parser.add_argument('--use-1f1b', action='store_true', - help='use 1f1b scheduling') -parser.add_argument('--use-tp1f1b-pack', action='store_true', - help='use tensor parallel 1f1b') -parser.add_argument('--use-recompute', action='store_true', - help='use recompute for a stage') -parser.add_argument('--use-swap', action='store_true', - help='use embedding swap (1f1b only)') -args = parser.parse_args() -print(args) - -cube.init() -pp_ranks = list(range(DeviceGroup().world_size)) -print_each_rank(f'my pp ranks: {pp_ranks}') - -if _pp_group == -1: - _pp_group = DeviceGroup().get_group(pp_ranks) - idx = pp_ranks.index(DeviceGroup().rank) - _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] - _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] - - encoder_time = [1] * args.layers - decoder_time = [1] * args.layers - times = encoder_time + decoder_time - num_stages = torch.distributed.get_world_size(_pp_group) - budget = sum(times) // num_stages - print_each_rank(f'budget: {budget}', rank_only=0) - start, end = 0, 1 - for idx in range(num_stages): - accum = times[start] - assert end <= args.layers * 2 - while end != args.layers * 2: - accum += times[end] - if accum > budget: - break - end += 1 - if idx == num_stages - 1: - end = args.layers * 2 - _layer_divisions.append((start, end)) - start, end = end, end+1 - - # uniform division algorithm: - # num_stages = torch.distributed.get_world_size(_pp_group) - # chunk = args.layers // (num_stages // 2) - # encoder_nlayers = [chunk] * (num_stages // 2) - # for idx in range(args.layers % (num_stages // 2)): - # encoder_nlayers[-idx] += 1 - # encoder_layers = [ - # (sum(encoder_nlayers[:rank]), - # sum(encoder_nlayers[:rank+1])) for rank in range(num_stages // 2) - # ] - # decoder_layers = [ - # (args.layers + sum(encoder_nlayers[:rank]), - # args.layers + sum(encoder_nlayers[:rank+1])) for rank in range(num_stages // 2) - # ] - # _layer_divisions = encoder_layers + decoder_layers - - print_each_rank(f'layer divisions: {_layer_divisions}', rank_only=0) - - -# create embed group: first encoder, first decoder -if args.use_1f1b: - encoder_preprocess = 0 - decoder_preprocess = None - for rank in range(len(pp_ranks)): - start, end = _layer_divisions[rank] - if start <= args.layers and end > args.layers: - decoder_preprocess = rank - break - assert decoder_preprocess is not None - embed_ranks = [encoder_preprocess, decoder_preprocess] - _pp_embed_group = DeviceGroup().get_group(embed_ranks) - - -class Config: - - - # 610M model -> original setting - # num_embeddings = 250027 - # decoder_layers = 12 - # encoder_layers = 12 - # embed_dim = 1024 - # attention_heads = 16 - # attention_inner_dim = attention_heads * 64 - # ffn_dim = 4 * embed_dim - - num_embeddings = 500000 # 250027 + int(250027*scale_p) - # decoder_layers = 12 + int(12*scale_p) - # encoder_layers = 12 + int(12*scale_p) - # embed_dim = 1024 + int(1024*scale_p) - # attention_heads = 16 + int(16*scale_p) if scale < 6 else 40 + 8*(scale-6) - decoder_layers = args.layers - encoder_layers = args.layers - embed_dim = args.hidden_size - attention_heads = args.heads - - attention_inner_dim = attention_heads * 64 - ffn_dim = 4 * embed_dim - - attention_dropout = 0.0 - activation_dropout = 0.0 - dropout = 0.1 - - max_target_positions = 1024 - max_source_positions = 1024 - - # classification task - pooler_dropout = 0.0 - num_classes = 3 - - -def attn_fn(query: torch.Tensor, key: torch.Tensor, - wq: torch.Tensor, wq_bias: Optional[torch.Tensor], - wk: torch.Tensor, wk_bias: Optional[torch.Tensor], - wv: torch.Tensor, wv_bias: Optional[torch.Tensor], - wout: torch.Tensor, wout_bias: Optional[torch.Tensor], - h: int, scale: float, dropout: float, mask=True): - """ - query, key: (L, N, E) = (seqlen, batch size, embed_dim) - wq, wk, wv weight: [(num_head * dim_head), E] - dropout: float - h: int: number of heads - """ - num_head = h - L, N = query.size(0), query.size(1) - dim_head = wq.size(0) // num_head - - q = torch.nn.functional.linear(query, wq, wq_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(key, wk, wk_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(key, wv, wv_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # attention mask - if mask: # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L -> (N h) L L - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, wout, wout_bias) # L N (h d), E E -> L N E - return output - - -class PositionalEmbedding(torch.nn.Embedding): - - def __init__(self, num_embeddings: int, embedding_dim: int): - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward(self, seq_len: int): - positions = torch.arange( - 0, seq_len, dtype=torch.long, device=torch.cuda.current_device() - ) - return super().forward(positions + self.offset) - - -class MultiheadAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, bias=True): - super().__init__() - self.inner_dim = inner_dim - self.num_heads = num_heads - self.head_dim = inner_dim // num_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # K - self.k_proj = torch.nn.Parameter(torch.empty(self.inner_dim, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(self.inner_dim)) if bias else None - # V - self.v_proj = torch.nn.Parameter(torch.empty(self.inner_dim, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(self.inner_dim)) if bias else None - # Q - self.q_proj = torch.nn.Parameter(torch.empty(self.inner_dim, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(self.inner_dim)) if bias else None - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, self.inner_dim)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None - - def forward(self, query: torch.Tensor, key: torch.Tensor): - return attn_fn(query, key, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) - - def forward_encoder_decoder_attn(self, query: torch.Tensor, key: torch.Tensor): - return attn_fn(query, key, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) - - def forward_self_attn(self, query): - return attn_fn(query, query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) - - -class EncoderLayer(torch.nn.Module): - - def __init__(self, cfg: Config): - - super().__init__() - self.cfg = cfg - self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim) - self.fc2 = torch.nn.Linear(cfg.ffn_dim, cfg.embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - def input_shape(self): - # L, N, E - return (self.cfg.max_source_positions, 1, self.cfg.embed_dim) - - def output_shape(self): - # L N E - return (self.cfg.max_source_positions, 1, self.cfg.embed_dim) - - def input_dtype(self): - return torch.float32 - - def output_dtype(self): - return torch.float32 - - def forward(self, x): # , encoder_padding_mask: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None): - # print(f'encoder layer: x: {x.size()}') - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x, x) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - x = self.fc1(x) - x = torch.nn.functional.gelu(x) - x = self.activation_dropout(x) - x = self.fc2(x) - - x = self.dropout(x) - x = x + residual - return x - - -class DecoderLayer(torch.nn.Module): - - def __init__(self, cfg: Config): - - super().__init__() - self.cfg = cfg - self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - # encoder atten - self.encoder_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim) - self.fc2 = torch.nn.Linear(cfg.ffn_dim, cfg.embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - def input_shape(self): - return ( - (self.cfg.max_target_positions, 1, self.cfg.embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.embed_dim) - ) - - def output_shape(self): - return ( - (self.cfg.max_target_positions, 1, self.cfg.embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.embed_dim) - ) - - def input_dtype(self): - return (torch.float32, torch.float32) - - def output_dtype(self): - return (torch.float32, torch.float32) - - def forward(self, x, encoder_out): - # print(f'decoder layer: x: {x.size()}, encoder_out: {encoder_out.size()}') - residual = x - x = self.self_attn_layer_norm(x) - - # self attention - x = self.self_attn(x, x) - x = self.dropout(x) - x = residual + x - - # encoder attn - residual = x - x = self.encoder_attn_layer_norm(x) - x = self.encoder_attn(x, encoder_out) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - x = self.fc1(x) - x = torch.nn.functional.gelu(x) - x = self.activation_dropout(x) - x = self.fc2(x) - x = self.dropout(x) - x = x + residual - return x, encoder_out - - -class MBartClassificationHead(torch.nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - self.num_classes = num_classes - self.dense = torch.nn.Linear(input_dim, inner_dim) - self.dropout = torch.nn.Dropout(p=pooler_dropout) - self.out_proj = torch.nn.Linear(inner_dim, num_classes) - self.loss_fct = torch.nn.CrossEntropyLoss() - - def forward(self, dec: torch.Tensor, labels): - # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] - dec = dec.transpose(0, 1)[:,-1,:] - sentence_represent = dec - hidden_states = self.dropout(sentence_represent) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - logits = self.out_proj(hidden_states) - loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) - return (loss,) - - -class ShardEmbed(torch.nn.Module): - - def __init__(self, cfg: Config, group=-1, swap=False): - """ - group = -1 means no tensor parallelism - """ - super().__init__() - self.cfg = cfg - self.group = group - self.shard_num = torch.distributed.get_world_size(group) if group != -1 else 1 - self.shard_idx = torch.distributed.get_rank(group) if group != -1 else 0 - if self.shard_num > 0: - print(f'[{torch.distributed.get_rank()}]: initialize sharding embed (x{self.shard_num})') - assert not (swap and self.shard_idx > 1), "only 1f1b can use swap" - - self.swap = swap - if swap: - self.embed = SwapEmbed(self.cfg.num_embeddings, self.cfg.embed_dim) - else: - self.vocab_start_index = self.cfg.num_embeddings // self.shard_num * self.shard_idx - self.vocab_end_index = self.cfg.num_embeddings // self.shard_num * (self.shard_idx + 1) - self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.shard_num, self.cfg.embed_dim))) - - # encoder-preprocess - self.embed_positions_encoder = PositionalEmbedding(cfg.max_source_positions, cfg.embed_dim) - self.embed_scale_encoder = math.sqrt(cfg.embed_dim) - self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.embed_dim) - - # decoder-preprocess - self.embed_scale_decoder = math.sqrt(cfg.embed_dim) - self.embed_positions_decoder = PositionalEmbedding(cfg.max_target_positions, cfg.embed_dim) - self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.embed_dim) - - self._inputs = (None, ) - - def set_inputs(self, *inputs): - self._inputs = inputs - - def embed_lookup(self, tokens, dst: Optional[int] = None): - if self.swap: - embed = self.embed(tokens) - elif self.shard_num > 1: - mask = (tokens < self.vocab_start_index) | \ - (tokens >= self.vocab_end_index) - tokens = tokens.clone() - self.vocab_start_index - tokens[mask] = 0 - embed = torch.nn.functional.embedding(tokens, self.weight) - embed[mask, :] = 0.0 - embed = ReduceBroadcast.apply(embed, dst, self.group) - else: - embed = torch.nn.functional.embedding(tokens, self.weight) - return embed - - def encoder_preprocess(self, dst: Optional[int] = None): - source_tokens = self._inputs[0] - source_embed = self.embed_lookup(source_tokens, dst) - embed = self.embed_scale_encoder * source_embed - x = embed + self.embed_positions_encoder(source_tokens.size(1)) - x = self.layernorm_embedding_encoder(x) - x = torch.nn.functional.dropout(x, p=0.0) - enc = x.transpose(0, 1) - return (enc,) - - def decoder_preprocess(self, dst: Optional[int] = None): - prev_output_tokens = self._inputs[0] - target_emb = self.embed_lookup(prev_output_tokens, dst) - embed = self.embed_scale_decoder * target_emb - embed = embed + self.embed_positions_decoder(prev_output_tokens.size(1)) - embed = self.layernorm_embedding_decoder(embed) - embed = torch.nn.functional.dropout(embed, p=0.0) - dec = embed.transpose(0, 1) - return (dec,) - - -class mBARTFull(torch.nn.Module): - - def __init__(self, cfg: Config, shard=True): - super().__init__() - self.cfg = cfg - self.dummy_labels = torch.tensor([1]).cuda() - self._preprocess = [None, None] # enc, dec - - self.rank = DeviceGroup().rank - - global _pp_group - self.pp_group = _pp_group - self.total_layers = cfg.encoder_layers + cfg.decoder_layers - - self.pp_stage = torch.distributed.get_rank(_pp_group) - self.num_stages = torch.distributed.get_world_size(_pp_group) - self.layer_start, self.layer_end = _layer_divisions[self.pp_stage] - - self.encoder_preprocess = self.layer_start == 0 if not shard else False - self.encoder_forward = (self.layer_start < cfg.encoder_layers) - - self.decoder_first_stage = self.layer_start <= cfg.encoder_layers and self.layer_end > cfg.encoder_layers - self.decoder_preprocess = self.decoder_first_stage if not shard else False - self.decoder_forward = (self.layer_end > cfg.encoder_layers) - self.decoder_last_stage = (self.layer_end == cfg.encoder_layers + cfg.decoder_layers) - - self.postprocess = self.decoder_last_stage - - self.encoder_layer_start = self.layer_start - self.encoder_layer_end = min(self.layer_end, cfg.encoder_layers) - - self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) - self.decoder_layer_end = self.layer_end - - if self.encoder_preprocess or self.decoder_preprocess or shard: - self.headtail = ShardEmbed(cfg, group = None if shard else -1, swap=args.use_swap) - else: - self.headtail = None - - # encoders - if self.encoder_forward: - print(f'[{self.rank}]: initializing {self.encoder_layer_end - self.encoder_layer_start} encoder layers') - self.encoders = torch.nn.ModuleList([EncoderLayer(cfg) for _ in range(self.encoder_layer_end - self.encoder_layer_start)]) - if self.encoder_layer_end == cfg.encoder_layers: - self.layer_norm_encoder = torch.nn.LayerNorm(cfg.embed_dim) - else: - self.layer_norm_encoder = None - - # decoders - if self.decoder_forward: - print(f'[{self.rank}]: initializing {self.decoder_layer_end - self.decoder_layer_start} decoder layers') - self.decoders = torch.nn.ModuleList([DecoderLayer(cfg) for _ in range(self.decoder_layer_end - self.decoder_layer_start)]) - if self.decoder_layer_end == cfg.encoder_layers + cfg.decoder_layers: - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.embed_dim) - else: - self.layer_norm_decoder = None - - # postpross - if self.postprocess: - print(f'[{self.rank}]: will compute loss') - self.head = MBartClassificationHead(cfg.embed_dim, 1024, cfg.num_classes, 0.0) - - def input_shape(self): - if self.encoder_preprocess: - return () - elif self.encoder_forward: - return ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - ) - elif self.decoder_preprocess: - return ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - ) - elif self.decoder_first_stage: - return ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - ) - elif self.decoder_forward: - return ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.embed_dim), - ) - elif self.postprocess: - return ((1,),) - assert False - - def output_shape(self): - shape = None - if self.encoder_preprocess or self.encoder_forward: - shape = (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - # decoder preprocess is not allowed to be a single stage - if self.decoder_preprocess or self.decoder_forward: - shape = ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.embed_dim), - ) - if self.postprocess: - shape = ( - (1,), - ) - assert shape is not None - return shape - - def input_dtype(self): - if self.encoder_preprocess: - return () - elif self.encoder_forward: - return (torch.float32,) - elif self.decoder_preprocess: - return (torch.float32,) - elif self.decoder_forward: - return (torch.float32, torch.float32) - else: - assert False - - def output_dtype(self): - dtype = None - if self.encoder_preprocess or self.encoder_forward: - dtype = (torch.float32,) - if self.decoder_preprocess or self.decoder_forward: - if self.pp_stage == self.num_stages - 1: - dtype = (torch.float32,) - else: - dtype = (torch.float32, torch.float32) - if self.postprocess: - dtype = ((torch.float32,),) - assert dtype is not None - return dtype - - def set_inputs(self, *inputs): - assert len(inputs) == 1 - if self.headtail is not None: - self.headtail.set_inputs(*inputs) - - def set_preprocess(self, enc=None, dec=None): - if enc is not None: - self._preprocess[0] = enc - if dec is not None: - self._preprocess[1] = dec - - def forward_encoder_preprocess(self, dst=None): - return self.headtail.encoder_preprocess(dst) - - def forward_decoder_preprocess(self, dst=None): - return self.headtail.decoder_preprocess(dst) - - def forward_postprocess(self, dec): - return self.head(dec, self.dummy_labels) - - def forward(self, enc=None, dec=None): - """ - enc: encoder input/output - dec: decoder output/input - """ - pre_enc, pre_dec = self._preprocess - enc = pre_enc if enc is None else enc - dec = pre_dec if dec is None else dec - - # encoder preprocess - if self.encoder_preprocess: - output = self.forward_encoder_preprocess(dst=None) - enc = output[0] - - # forward encoder - if self.encoder_forward: - for layer in self.encoders: - # enc = checkpoint.checkpoint(layer, enc) - enc = layer(enc) - if self.layer_norm_encoder is not None: - enc = self.layer_norm_encoder(enc) - output = (enc,) - - # decoder preprocess - if self.decoder_preprocess: - output = self.forward_decoder_preprocess(dst=None) - dec = output[0] - - # forward decoder - if self.decoder_forward: - dec = pre_dec if dec is None else dec - for layer in self.decoders: - # dec, enc = checkpoint.checkpoint(layer, dec, enc) - dec, enc = layer(dec, enc) - if self.layer_norm_decoder is not None: - dec = self.layer_norm_decoder(dec) - output = (dec,) - else: - output = (enc, dec) - - # postprocess - if self.postprocess: - output = self.forward_postprocess(dec) - loss = output[0] - - return output - - -def reduce_embed(model, pp_embed_group): - """ - Embedding gradients needs to be reduced across pipeline stages - """ - if isinstance(model.headtail, torch.nn.Module): - if model.headtail.swap: - with torch.no_grad(): - grad = model.headtail.embed.weight.grad - grad = grad.cuda() - else: - grad = model.headtail.weight.grad - else: - grad = None - if grad is not None: - CudaTimer().start('comm') - torch.distributed.all_reduce(grad, group=pp_embed_group) - torch.cuda.synchronize() - CudaTimer().stop('comm') - if isinstance(model.headtail, torch.nn.Module): - if model.headtail.swap: - with torch.no_grad(): - model.headtail.embed.weight.grad.copy_(grad) - torch.cuda.synchronize() - - -if __name__ == '__main__': - - - cfg = Config() - print_each_rank(f'enc/dec layer#: {cfg.encoder_layers}, embed_dim#: {cfg.embed_dim}, heads#: {cfg.attention_heads}, ffn_dim#: {cfg.ffn_dim}', rank_only=0) - dataloader = SynTextDataLoader( - shapes=( - [1, cfg.max_source_positions], - ), - dtypes=(torch.int64,), - batch_dims=(0,) - ) - if args.use_1f1b: - model = mBARTFull(cfg, shard=False).cuda() - else: - model = mBARTFull(cfg, shard=True).cuda() - - print_each_rank('model weight consumpition:') - memory_summary() - - if args.use_swap: - parameters = get_swap_parameters() + list(model.parameters()) - else: - parameters = model.parameters() - optimizer = torch.optim.Adam(parameters, lr=3e-05, betas=(0.9, 0.98)) - - CudaTimer(enable=False).warmup() - iter_num = 6 - for step in range(iter_num): - if step >= 2: - CudaTimer(enable=True).start('e2e') - if args.use_1f1b: - for _ in range(args.nmb // args.iter_nmb): - schedule_1f1b( - model, iter(dataloader), - args.iter_nmb, len(pp_ranks), - (_pp_prev_rank, _pp_next_rank), - recompute=args.use_recompute, - ) - reduce_embed(model, _pp_embed_group) - if args.use_tp1f1b_pack: - schedule_tp_1f1b_pack( - model, iter(dataloader), - args.nmb, len(pp_ranks), - (_pp_prev_rank, _pp_next_rank), - recompute=args.use_recompute, - ) - if step == 0: - print('passed 1st iteration') - memory_summary() - optimizer.step() - optimizer.zero_grad() - if step == 0: - print('memory after optimizer') - memory_summary() - if step >= 2: - CudaTimer().stop('e2e') - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-2, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-2) - memory_summary() diff --git a/handcraft/mbart/mbart_hybrid.py b/handcraft/mbart/mbart_hybrid.py deleted file mode 100644 index 16e76895..00000000 --- a/handcraft/mbart/mbart_hybrid.py +++ /dev/null @@ -1,787 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers 12 --hidden-size 1024 --heads 16 \ - --pp-size 2 --tp-size 2 --nmb 4 --iter-nmb 4 \ - --use-recompute -""" - -from typing import Optional -import argparse -import math -import torch - -import cube -from cube.runtime.device import DeviceGroup -from cube.runtime.syndata import SynTextDataLoader -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - -from handcraft.mbart.schedule import schedule_1f1b -from handcraft.mbart.tp import AllReduceIdentity, IdentityAllreduce - -_tp_group = -1 -_pp_group = -1 -_pp_embed_group = -1 -_pp_next_rank = None -_pp_prev_rank = None -_layer_divisions = [] - - -parser = argparse.ArgumentParser(description='mbart hybrid') - -# model arch -parser.add_argument('--layers', type=int, default=12, - help='number encoder/decoder of layers') -parser.add_argument('--hidden-size', type=int, default=1024, - help='hidden size') -parser.add_argument('--heads', type=int, default=16, - help='number of heads') -# training config -parser.add_argument('--nmb', type=int, default=4, - help='num of micro batch') -parser.add_argument('--iter-nmb', type=int, default=0, - help='num of micro batch per scheduling iteration') - -# parallelism -parser.add_argument('--pp-size', type=int, default=1, - help='use pipeline parallelism') -parser.add_argument('--tp-size', type=int, default=1, - help='use tensor parallelism') -parser.add_argument('--embed-cpu', action='store_true', - help='put embedding inside CPU') -parser.add_argument('--use-recompute', action='store_true', - help='use recompute for a stage') -args = parser.parse_args() -print(args) - -cube.init() -pp_ranks, tp_ranks = DeviceGroup().create_hybrid([args.pp_size, args.tp_size]) -print_each_rank(f'my pp ranks: {pp_ranks}') -print_each_rank(f'my tp ranks: {tp_ranks}') - -if _tp_group == -1: - _tp_group = DeviceGroup().get_group(tp_ranks) - -if _pp_group == -1: - _pp_group = DeviceGroup().get_group(pp_ranks) - idx = pp_ranks.index(DeviceGroup().rank) - _pp_next_rank = pp_ranks[(idx+1) % len(pp_ranks)] - _pp_prev_rank = pp_ranks[(idx-1) % len(pp_ranks)] - - encoder_time = [1] * args.layers - decoder_time = [1] * args.layers - times = encoder_time + decoder_time - num_stages = torch.distributed.get_world_size(_pp_group) - budget = sum(times) // num_stages - print_each_rank(f'budget: {budget}', rank_only=0) - start, end = 0, 1 - for idx in range(num_stages): - accum = times[start] - assert end <= args.layers * 2 - while end != args.layers * 2: - accum += times[end] - if accum > budget: - break - end += 1 - if idx == num_stages - 1: - end = args.layers * 2 - _layer_divisions.append((start, end)) - start, end = end, end+1 - print_each_rank(f'layer divisions: {_layer_divisions}', rank_only=0) - -if len(pp_ranks) > 1: - pranks = [torch.zeros( - (args.pp_size,), dtype=torch.int, device=torch.cuda.current_device()) for _ in range(args.tp_size) - ] - prank = torch.tensor(pp_ranks, dtype=torch.int).cuda() - pranks[torch.distributed.get_rank(_tp_group)] = prank - torch.distributed.all_gather(pranks, prank, group=_tp_group) - torch.cuda.synchronize() - # print_each_rank(f'allgather-pp ranks: {pranks}') - encoder_preprocess_tp = 0 - decoder_preprocess_tp = None - for rank in range(len(pp_ranks)): - start, end = _layer_divisions[rank] - if start <= args.layers and end > args.layers: - decoder_preprocess_tp = rank - break - assert decoder_preprocess_tp is not None - for prank in pranks: - prank = prank.tolist() - embed_ranks = [prank[encoder_preprocess_tp], prank[decoder_preprocess_tp]] - embed_ranks = list(set(embed_ranks)) - print_each_rank(f'init embed group: {embed_ranks}') - group = DeviceGroup().get_group(embed_ranks) - if torch.distributed.get_rank(_tp_group) in prank: - print(f'embedding group: {embed_ranks}') - _pp_embed_group = group - assert _pp_embed_group != -1 - - - -class Config: - - # scale = args.scale - # scale_p = scale * 0.25 - - num_embeddings = 500000 # 250027 + int(250027*scale_p) - # decoder_layers = 12 + int(12*scale_p) - # encoder_layers = 12 + int(12*scale_p) - # embed_dim = 1024 + int(1024*scale_p) - # attention_heads = 16 + int(16*scale_p) if scale < 6 else 40 + 8*(scale-6) - decoder_layers = args.layers - encoder_layers = args.layers - embed_dim = args.hidden_size - attention_heads = args.heads - - attention_inner_dim = attention_heads * 64 - ffn_dim = 4 * embed_dim - - attention_dropout = 0.0 - activation_dropout = 0.0 - dropout = 0.1 - - max_target_positions = 1024 - max_source_positions = 1024 - - # classification task - pooler_dropout = 0.0 - num_classes = 3 - - -def attn_fn(query: torch.Tensor, key: torch.Tensor, - wq: torch.Tensor, wq_bias: Optional[torch.Tensor], - wk: torch.Tensor, wk_bias: Optional[torch.Tensor], - wv: torch.Tensor, wv_bias: Optional[torch.Tensor], - wout: torch.Tensor, wout_bias: Optional[torch.Tensor], - h: int, scale: float, dropout: float, mask=True): - """ - query, key: (L, N, E) = (seqlen, batch size, embed_dim) - wq, wk, wv weight: [(num_head * dim_head), E] - dropout: float - h: int: number of heads - """ - num_head = h - L, N = query.size(0), query.size(1) - dim_head = wq.size(0) // num_head - - q = torch.nn.functional.linear(query, wq, wq_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(key, wk, wk_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(key, wv, wv_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # attention mask - if mask: # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L -> (N h) L L - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, wout, wout_bias) # L N (h d), E E -> L N E - return output - - -class MultiheadAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, bias=True): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.inner_dim = inner_dim - self.head_dim = inner_dim // num_heads - self.num_heads = num_heads // self.tp_size - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # K - self.k_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None - # V - self.v_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None - # Q - self.q_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, self.inner_dim // self.tp_size)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None - - - def forward(self, query: torch.Tensor, key: torch.Tensor): - if key is not query: - key = IdentityAllreduce.apply(key, self.tp_group) - query = IdentityAllreduce.apply(query, self.tp_group) - attn = attn_fn(query, key, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) - attn = AllReduceIdentity.apply(attn, self.tp_group) - return attn - - -class PositionalEmbedding(torch.nn.Embedding): - - def __init__(self, num_embeddings: int, embedding_dim: int): - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward(self, seq_len: int): - positions = torch.arange( - 0, seq_len, dtype=torch.long, device=torch.cuda.current_device() - ) - return super().forward(positions + self.offset) - - -class EncoderLayer(torch.nn.Module): - - def __init__(self, cfg: Config): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.cfg = cfg - self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) - self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - def input_shape(self): - # L, N, E - return (self.cfg.max_source_positions, 1, self.cfg.embed_dim) - - def output_shape(self): - # L N E - return (self.cfg.max_source_positions, 1, self.cfg.embed_dim) - - def input_dtype(self): - return torch.float32 - - def output_dtype(self): - return torch.float32 - - def forward(self, x): # , encoder_padding_mask: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None): - # print(f'encoder layer: x: {x.size()}') - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x, x) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - if self.tp_size > 1: - x = IdentityAllreduce.apply(x, self.tp_group) - x = self.fc1(x) - x = torch.nn.functional.gelu(x) - x = self.activation_dropout(x) - x = self.fc2(x) - if self.tp_size > 1: - x = AllReduceIdentity.apply(x, self.tp_group) - - x = self.dropout(x) - x = x + residual - return x - - -class DecoderLayer(torch.nn.Module): - - def __init__(self, cfg: Config): - - super().__init__() - - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.cfg = cfg - self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - # encoder atten - self.encoder_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) - self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - def input_shape(self): - return ( - (self.cfg.max_target_positions, 1, self.cfg.embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.embed_dim) - ) - - def output_shape(self): - return ( - (self.cfg.max_target_positions, 1, self.cfg.embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.embed_dim) - ) - - def input_dtype(self): - return (torch.float32, torch.float32) - - def output_dtype(self): - return (torch.float32, torch.float32) - - def forward(self, x, encoder_out): # encoder_padding_mask): - # print(f'decoder layer: x: {x.size()}, encoder_out: {encoder_out.size()}') - residual = x - # normalize before - x = self.self_attn_layer_norm(x) - - # self attention - x = self.self_attn(x, x) - x = self.dropout(x) - x = residual + x - - # encoder attn - residual = x - # normalize before - x = self.encoder_attn_layer_norm(x) - x = self.encoder_attn(x, encoder_out) - x = self.dropout(x) - x = x + residual - - residual = x - # normalize before - x = self.final_layer_norm(x) - if self.tp_size > 1: - x = IdentityAllreduce.apply(x, self.tp_group) - x = self.fc1(x) - x = torch.nn.functional.gelu(x) - x = self.activation_dropout(x) - x = self.fc2(x) - if self.tp_size > 1: - x = AllReduceIdentity.apply(x, self.tp_group) - x = self.dropout(x) - x = x + residual - return x, encoder_out - - -class MBartClassificationHead(torch.nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.num_classes = num_classes - self.dense = torch.nn.Linear(input_dim, inner_dim // self.tp_size) - self.dropout = torch.nn.Dropout(p=pooler_dropout) - self.out_proj = torch.nn.Linear(inner_dim // self.tp_size, num_classes) - self.loss_fct = torch.nn.CrossEntropyLoss() - - def forward(self, dec: torch.Tensor, labels): - # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] - dec = dec.transpose(0, 1)[:,-1,:] - sentence_represent = dec - hidden_states = self.dropout(sentence_represent) - if self.tp_size > 1: - hidden_states = IdentityAllreduce.apply(hidden_states, self.tp_group) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - logits = self.out_proj(hidden_states) - if self.tp_size > 1: - logits = AllReduceIdentity.apply(logits, self.tp_group) - loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) - return (loss,) - - -class SharedEmbed(torch.nn.Module): - - def __init__(self, cfg: Config, embed_cpu=False): - """ - group = -1 means no tensor parallelism - """ - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - self.tp_idx = 0 if _tp_group == -1 else torch.distributed.get_rank(_tp_group) - if self.tp_size > 0: - print(f'[{torch.distributed.get_rank()}]: initialize sharding embed (x{self.tp_size})') - - self.embed_cpu = embed_cpu - self.cfg = cfg - - self.vocab_start_index = self.cfg.num_embeddings // self.tp_size * self.tp_idx - self.vocab_end_index = self.cfg.num_embeddings // self.tp_size * (self.tp_idx + 1) - self.weight = torch.nn.Parameter(torch.ones((self.cfg.num_embeddings // self.tp_size, self.cfg.embed_dim))) - - # encoder-preprocess - self.embed_positions_encoder = PositionalEmbedding(cfg.max_source_positions, cfg.embed_dim) - self.embed_scale_encoder = math.sqrt(cfg.embed_dim) - self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.embed_dim) - - # decoder-preprocess - self.embed_scale_decoder = math.sqrt(cfg.embed_dim) - self.embed_positions_decoder = PositionalEmbedding(cfg.max_target_positions, cfg.embed_dim) - self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.embed_dim) - - self._inputs = (None, ) - - def set_inputs(self, *inputs): - if self.embed_cpu: - self._inputs = [input.cpu() for input in inputs] - else: - self._inputs = inputs - - def embed_lookup(self, tokens): - if self.tp_size > 1: - mask = (tokens < self.vocab_start_index) | \ - (tokens >= self.vocab_end_index) - tokens = tokens.clone() - self.vocab_start_index - tokens[mask] = 0 - embed = torch.nn.functional.embedding(tokens, self.weight) - embed[mask, :] = 0.0 - if self.embed_cpu: - embed = embed.cuda() - embed = AllReduceIdentity.apply(embed, self.tp_group) - else: - embed = torch.nn.functional.embedding(tokens, self.weight) - if self.embed_cpu: - embed = embed.cuda() - return embed - - def encoder_preprocess(self): - source_tokens = self._inputs[0] - seq_len = source_tokens.size(1) - assert seq_len == self.cfg.max_source_positions - - source_embed = self.embed_lookup(source_tokens) - embed = self.embed_scale_encoder * source_embed - x = embed + self.embed_positions_encoder(seq_len) - x = self.layernorm_embedding_encoder(x) - x = torch.nn.functional.dropout(x, p=0.0) - enc = x.transpose(0, 1) - return (enc,) - - def decoder_preprocess(self): - prev_output_tokens = self._inputs[0] - seq_len = prev_output_tokens.size(1) - assert seq_len == self.cfg.max_source_positions - - target_emb = self.embed_lookup(prev_output_tokens) - embed = self.embed_scale_decoder * target_emb - embed = embed + self.embed_positions_encoder(seq_len) - embed = self.layernorm_embedding_decoder(embed) - embed = torch.nn.functional.dropout(embed, p=0.0) - dec = embed.transpose(0, 1) - return (dec,) - - -class mBARTFull(torch.nn.Module): - - def __init__(self, cfg: Config, embed_cpu=False): - super().__init__() - self.cfg = cfg - self.dummy_labels = torch.tensor([1]).cuda() - self._preprocess = [None, None] # enc, dec - - self.rank = DeviceGroup().rank - - global _pp_group - self.pp_group = _pp_group - self.total_layers = cfg.encoder_layers + cfg.decoder_layers - - self.pp_stage = torch.distributed.get_rank(_pp_group) - self.num_stages = torch.distributed.get_world_size(_pp_group) - self.layer_start, self.layer_end = _layer_divisions[self.pp_stage] - - self.encoder_preprocess = self.layer_start == 0 - self.encoder_forward = (self.layer_start < cfg.encoder_layers) - - self.decoder_first_stage = self.layer_start <= cfg.encoder_layers and self.layer_end > cfg.encoder_layers - self.decoder_preprocess = self.decoder_first_stage - self.decoder_forward = (self.layer_end > cfg.encoder_layers) - self.decoder_last_stage = (self.layer_end == cfg.encoder_layers + cfg.decoder_layers) - - self.postprocess = self.decoder_last_stage - - self.encoder_layer_start = self.layer_start - self.encoder_layer_end = min(self.layer_end, cfg.encoder_layers) - - self.decoder_layer_start = max(cfg.encoder_layers, self.layer_start) - self.decoder_layer_end = self.layer_end - - if self.encoder_preprocess or self.decoder_preprocess: - self.headtail = SharedEmbed(cfg, embed_cpu) - else: - self.headtail = None - - # encoders - if self.encoder_forward: - print(f'[{self.rank}]: initializing {self.encoder_layer_end - self.encoder_layer_start} encoder layers') - self.encoders = torch.nn.ModuleList([EncoderLayer(cfg) for _ in range(self.encoder_layer_end - self.encoder_layer_start)]) - if self.encoder_layer_end == cfg.encoder_layers: - self.layer_norm_encoder = torch.nn.LayerNorm(cfg.embed_dim) - else: - self.layer_norm_encoder = None - - # decoders - if self.decoder_forward: - print(f'[{self.rank}]: initializing {self.decoder_layer_end - self.decoder_layer_start} decoder layers') - self.decoders = torch.nn.ModuleList([DecoderLayer(cfg) for _ in range(self.decoder_layer_end - self.decoder_layer_start)]) - if self.decoder_layer_end == cfg.encoder_layers + cfg.decoder_layers: - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.embed_dim) - else: - self.layer_norm_decoder = None - - # postpross - if self.postprocess: - print(f'[{self.rank}]: will compute loss') - self.head = MBartClassificationHead(cfg.embed_dim, 1024, cfg.num_classes, 0.0) - - def input_shape(self): - if self.encoder_preprocess: - return () - elif self.encoder_forward: - return ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - ) - elif self.decoder_preprocess: - return ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - ) - elif self.decoder_first_stage: - return ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - ) - elif self.decoder_forward: - return ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.embed_dim), - ) - elif self.postprocess: - return ((1,),) - assert False - - def output_shape(self): - shape = None - if self.encoder_preprocess or self.encoder_forward: - shape = (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - # decoder preprocess is not allowed to be a single stage - if self.decoder_preprocess or self.decoder_forward: - shape = ( - (self.cfg.max_source_positions, 1, self.cfg.embed_dim), - (self.cfg.max_target_positions, 1, self.cfg.embed_dim), - ) - if self.postprocess: - shape = ( - (1,), - ) - assert shape is not None - return shape - - def input_dtype(self): - if self.encoder_preprocess: - return () - elif self.encoder_forward: - return (torch.float32,) - elif self.decoder_preprocess: - return (torch.float32,) - elif self.decoder_forward: - return (torch.float32, torch.float32) - else: - assert False - - def output_dtype(self): - dtype = None - if self.encoder_preprocess or self.encoder_forward: - dtype = (torch.float32,) - if self.decoder_preprocess or self.decoder_forward: - if self.pp_stage == self.num_stages - 1: - dtype = (torch.float32,) - else: - dtype = (torch.float32, torch.float32) - if self.postprocess: - dtype = ((torch.float32,),) - assert dtype is not None - return dtype - - def set_inputs(self, *inputs): - assert len(inputs) == 1 - if self.headtail is not None: - self.headtail.set_inputs(*inputs) - - def set_preprocess(self, enc=None, dec=None): - if enc is not None: - self._preprocess[0] = enc - if dec is not None: - self._preprocess[1] = dec - - def forward_encoder_preprocess(self): - return self.headtail.encoder_preprocess() - - def forward_decoder_preprocess(self): - return self.headtail.decoder_preprocess() - - def forward_postprocess(self, dec): - return self.head(dec, self.dummy_labels) - - def forward(self, enc=None, dec=None): - """ - enc: encoder input/output - dec: decoder output/input - """ - pre_enc, pre_dec = self._preprocess - enc = pre_enc if enc is None else enc - dec = pre_dec if dec is None else dec - - # encoder preprocess - if self.encoder_preprocess: - output = self.forward_encoder_preprocess() - enc = output[0] - - # forward encoder - if self.encoder_forward: - for layer in self.encoders: - # enc = checkpoint.checkpoint(layer, enc) - enc = layer(enc) - if self.layer_norm_encoder is not None: - enc = self.layer_norm_encoder(enc) - output = (enc,) - - # decoder preprocess - if self.decoder_preprocess: - output = self.forward_decoder_preprocess() - dec = output[0] - - # forward decoder - if self.decoder_forward: - dec = pre_dec if dec is None else dec - for layer in self.decoders: - # dec, enc = checkpoint.checkpoint(layer, dec, enc) - dec, enc = layer(dec, enc) - if self.layer_norm_decoder is not None: - dec = self.layer_norm_decoder(dec) - output = (dec,) - else: - output = (enc, dec) - - # postprocess - if self.postprocess: - output = self.forward_postprocess(dec) - loss = output[0] - - return output - - -def reduce_embed(model, pp_embed_group): - """ - Embedding gradients needs to be reduced across pipeline stages - """ - if isinstance(model.headtail, torch.nn.Module): - grad = model.headtail.weight.grad - else: - grad = None - if grad is not None: - CudaTimer().start('comm') - torch.distributed.all_reduce(grad, group=pp_embed_group) - torch.cuda.synchronize() - CudaTimer().stop('comm') - torch.cuda.synchronize() - - -if __name__ == '__main__': - - cfg = Config() - print_each_rank(f'enc/dec layer#: {cfg.encoder_layers}, embed_dim#: {cfg.embed_dim}, heads#: {cfg.attention_heads}, ffn_dim#: {cfg.ffn_dim}', rank_only=0) - dataloader = SynTextDataLoader( - shapes=( - [1, cfg.max_source_positions], - ), - dtypes=(torch.int64,), - batch_dims=(0,) - ) - - - model = mBARTFull(cfg, args.embed_cpu).cuda() - - if args.embed_cpu: - if model.headtail is not None: - model.headtail.weight = torch.nn.Parameter(model.headtail.weight.cpu()) - - print_each_rank('model weight consumpition:') - memory_summary() - - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - CudaTimer(enable=False).warmup() - iter_num = 6 - for step in range(iter_num): - if step >= 2: - CudaTimer(enable=True).start('e2e') - if args.pp_size > 1: - # schedule_naive(model, iter(dataloader), args.nmb, (_pp_prev_rank, _pp_next_rank)) - for _ in range(args.nmb // args.iter_nmb): - schedule_1f1b( - model, - iter(dataloader), - args.iter_nmb, - args.pp_size, - (_pp_prev_rank, _pp_next_rank), - group=_pp_group, - recompute=args.use_recompute - ) - # TODO: support gradient allreduce in cpu - if not args.embed_cpu: - reduce_embed(model, _pp_embed_group) - else: - loader = iter(dataloader) - for _ in range(args.nmb): - model.set_inputs(next(loader)) - loss = model()[0] - loss.backward() - if step == 0: - print('passed 1st iteration') - memory_summary() - optimizer.step() - optimizer.zero_grad() - if step >= 2: - CudaTimer().stop('e2e') - if step == 0: - print_each_rank('after optimizer') - memory_summary() - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-2, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-2) - memory_summary() diff --git a/handcraft/mbart/run-1f1b-swap.sh b/handcraft/mbart/run-1f1b-swap.sh deleted file mode 100755 index 2acb589f..00000000 --- a/handcraft/mbart/run-1f1b-swap.sh +++ /dev/null @@ -1,100 +0,0 @@ -evaldir=eval/mbart-v100-32gb-pcie-recompute - -mkdir -p ${evaldir} - -# ================================================= -# 4 gpus: arch layer 21,21, hidden 1792, heads 28 -# ================================================= -layers=24 -hidden=2048 -heads=32 -gpus=4 - -echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt - - - -layers=24 -hidden=2560 -heads=32 -gpus=4 - -echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt - - -layers=18 -hidden=3072 -heads=32 -gpus=4 - -echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt - - -layers=27 -hidden=2304 -heads=36 -gpus=4 - -echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt - - -layers=30 -hidden=2560 -heads=40 -gpus=8 - -echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt - - -layers=33 -hidden=2816 -heads=48 -gpus=8 - -echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt - - -layers=24 -hidden=4096 -heads=32 -gpus=8 - -echo "testing pure-1f1b-swap: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b-swap.txt - - -# python scripts/keep.py --gpus 8 \ No newline at end of file diff --git a/handcraft/mbart/run-recompute-arch.sh b/handcraft/mbart/run-recompute-arch.sh deleted file mode 100755 index 4b90ea89..00000000 --- a/handcraft/mbart/run-recompute-arch.sh +++ /dev/null @@ -1,76 +0,0 @@ -layers=24 -hidden=2560 -heads=32 -gpus=8 - -evaldir=eval/mbart-v100-32gb-pcie-recompute -mkdir -p ${evaldir} - - -# TP-1F1B -echo 'testing mixture-1f1b' -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - - -# # Pure 1F1B -echo 'testing pure 1f1b' -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b.txt - -# Pure TP -echo 'testing pure tensor parallelism' -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - -# # Hybrid TP-1F1B -- 4 GPU -if [ ${gpus} == 4 ] -then - echo 'testing hybrid tp:pp=2:2' - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 2 --pp-size 2 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt - sleep 5 - killall python - sleep 5 - killall python -fi - -# Hybrid TP-1F1B -- 8 GPU -if [ ${gpus} == 8 ] -then - echo 'testing hybrid tp:pp=4:2' - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt - sleep 5 - killall python - sleep 5 - killall python - - echo 'testing hybrid tp:pp=2:4' - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 2 --pp-size 4 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt - sleep 5 - killall python - sleep 5 - killall python -fi - -# python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/run-recompute-full-v100-32gb.sh b/handcraft/mbart/run-recompute-full-v100-32gb.sh deleted file mode 100755 index adddc92d..00000000 --- a/handcraft/mbart/run-recompute-full-v100-32gb.sh +++ /dev/null @@ -1,301 +0,0 @@ -evaldir=eval/mbart-v100-32gb-pcie-recompute - -mkdir -p ${evaldir} - -# ================================================= -# 4 gpus: arch layer 21,21, hidden 1792, heads 28 -# ================================================= -layers=21 -hidden=1792 -heads=28 -gpus=4 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b.txt - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - -echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 2 --pp-size 2 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt - - -# ================================================= -# 4 gpus: arch layer 24,24, hidden 2048, heads 32 -# ================================================= -layers=24 -hidden=2048 -heads=32 -gpus=4 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" -echo "Will be OOM" - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - -echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 2 --pp-size 2 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt - - -# ================================================= -# 4 gpus: arch layer 24,24, hidden 2560, heads 32 -# ================================================= -layers=24 -hidden=2560 -heads=32 -gpus=4 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" -echo "Will be OOM" - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - -echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" -echo "Will be OOM" - - -# ================================================= -# 4 gpus: arch layer 18,18, hidden 3072, heads 32 -# ================================================= -layers=18 -hidden=3072 -heads=32 -gpus=4 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" -echo "Will be OOM" - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - -echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" -echo "Will be OOM" - - -# ================================================= -# 4 gpus: arch layer 27,27, hidden 3072, heads 32 -# ================================================= -layers=27 -hidden=2304 -heads=36 -gpus=4 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" -echo "Will be OOM" - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - -echo "testing tensor x pipeline parallelism 2x2: L${layers}E${hidden}H${heads}" -echo "Will be OOM" - - -# ================================================= -# 8 gpus: arch layer 24,24, hidden 2048, heads 32 -# ================================================= -layers=24 -hidden=2048 -heads=32 -gpus=8 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-1f1b --nmb 256 --iter-nmb 256\ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-1f1b.txt - - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - -echo "testing pure tensor parallelism 2x4: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 2 --pp-size 4 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt - -echo "testing tensor x pipeline parallelism 2x4: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt - -# ================================================= -# 8 gpus: arch layer 30,30, hidden 2560, heads 40 -# ================================================= -layers=30 -hidden=2560 -heads=40 -gpus=8 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - -echo "testing pure tensor parallelism 2x4: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 2 --pp-size 4 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt - -echo "testing tensor x pipeline parallelism 2x4: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt - - -# ================================================= -# 8 gpus: arch layer 33,33, hidden 2816, heads 40 -# ================================================= -layers=33 -hidden=2816 -heads=48 -gpus=8 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - - -echo "testing tensor x pipeline parallelism 4x2: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 4 --pp-size 2 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt - -# ================================================= -# 8 gpus: arch layer 24,24, hidden 4096, heads 32 -# ================================================= -layers=24 -hidden=4096 -heads=32 -gpus=8 - -echo "testing mixture-1f1b: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --use-tp1f1b-pack --nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b-pack.txt - -echo "testing pure tensor parallelism: L${layers}E${hidden}H${heads}" -OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size ${gpus} --pp-size 1 --nmb 256 --iter-nmb 256 \ - --use-recompute > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - - -echo "testing tensor x pipeline parallelism 4x2: L${layers}E${hidden}H${heads}" -echo "Will OOM" - - -echo 'done!!!' -# python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/schedule.py b/handcraft/mbart/schedule.py deleted file mode 100644 index 70732c3b..00000000 --- a/handcraft/mbart/schedule.py +++ /dev/null @@ -1,551 +0,0 @@ -from typing import List, Tuple -import torch - -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.device import DeviceGroup - -io_input = input - -def forward_step(model, *args, **kwargs): - """ - Forward pass - """ - CudaTimer().start("forward") - output = model(*args, **kwargs) - CudaTimer().stop("forward") - return output - - -def backward_step(input_tensors: List[torch.Tensor], - output_tensors: List[torch.Tensor], - output_tensor_grads: List[torch.Tensor]) -> List[torch.Tensor]: - """ - Backward pass - """ - for tensor in input_tensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - tensor.retain_grad() - CudaTimer().start("backward") - torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) - CudaTimer().stop("backward") - input_tensor_grads = [] - for tensor in input_tensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - input_tensor_grads.append(tensor.grad) - else: - input_tensor_grads.append(None) - return input_tensor_grads - - -def recv_forward(model, prev_rank: int) -> List[torch.Tensor]: - CudaTimer().start(field_name='comm') - shapes = model.input_shape() - dtypes = model.input_dtype() - if len(shapes) == 0: - return () - # print(f'rank {DeviceGroup().rank} recving forward: {shapes}') - tensors = [ - torch.empty( - shape, requires_grad=True, dtype=dtype, - device=torch.cuda.current_device() - ) for shape, dtype in zip(shapes, dtypes) - ] - recv_ops = [ - torch.distributed.P2POp( - torch.distributed.irecv, tensor, prev_rank - ) for tensor in tensors - ] - reqs = torch.distributed.batch_isend_irecv(recv_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return tensors - - -def recv_backward(model, next_rank: int) -> List[torch.Tensor]: - CudaTimer().start(field_name='comm') - shapes = model.output_shape() - dtypes = model.output_dtype() - if len(shapes) == 0: - return () - # print(f'rank {DeviceGroup().rank} recving backward: {shapes}') - tensors = [ - torch.empty( - shape, requires_grad=False, dtype=dtype, - device=torch.cuda.current_device() - ) for shape, dtype in zip(shapes, dtypes) - ] - recv_ops = [ - torch.distributed.P2POp( - torch.distributed.irecv, tensor, next_rank - ) for tensor in tensors - ] - reqs = torch.distributed.batch_isend_irecv(recv_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return tensors - - -def send_forward(outputs: List[torch.Tensor], next_rank: int): - if len(outputs) == 0: - return - CudaTimer().start(field_name='comm') - # print(f'rank {DeviceGroup().rank} sending forward: {[tuple(t.size()) for t in outputs]}') - send_ops = [ - torch.distributed.P2POp( - torch.distributed.isend, tensor, next_rank - ) for tensor in outputs - ] - reqs = torch.distributed.batch_isend_irecv(send_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - - -def send_backward(grads: List[torch.Tensor], prev_rank: int): - if len(grads) == 0: - return - CudaTimer().start(field_name='comm') - # print(f'rank {DeviceGroup().rank} sending backward: {[tuple(t.size()) for t in grads]}') - send_ops = [ - torch.distributed.P2POp( - torch.distributed.isend, tensor, prev_rank - ) for tensor in grads - ] - reqs = torch.distributed.batch_isend_irecv(send_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - - -def send_forward_recv_backward(outputs, model, next_rank: int) -> List[torch.Tensor]: - CudaTimer().start(field_name='comm') - shapes = model.output_shape() - dtypes = model.output_dtype() - # print(f'rank {DeviceGroup().rank} sending forward: {[tuple(t.size()) for t in outputs]} recving backward {shapes}') - ops = list() - # send forward outputs - send_ops = [ - torch.distributed.P2POp( - torch.distributed.isend, tensor, next_rank - ) for tensor in outputs - ] - ops += send_ops - # recv backward inputs - tensors = [ - torch.empty( - shape, requires_grad=True, dtype=dtype, - device=torch.cuda.current_device() - ) for shape, dtype in zip(shapes, dtypes) - ] - recv_ops = [ - torch.distributed.P2POp( - torch.distributed.irecv, tensor, next_rank - ) for tensor in tensors - ] - ops += recv_ops - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return tensors - - -def send_backward_recv_forward(grads, model, prev_rank: int) -> List[torch.Tensor]: - CudaTimer().start(field_name='comm') - shapes = model.input_shape() - dtypes = model.input_dtype() - # print(f'rank {DeviceGroup().rank} sending backward: {[tuple(t.size()) for t in grads]} recving forward {shapes}') - ops = list() - # send backward gradients - send_ops = [ - torch.distributed.P2POp( - torch.distributed.isend, tensor, prev_rank - ) for tensor in grads - ] - ops += send_ops - # recv forward inputs - tensors = [ - torch.empty( - shape, requires_grad=True, dtype=dtype, - device=torch.cuda.current_device() - ) for shape, dtype in zip(shapes, dtypes) - ] - recv_ops = [ - torch.distributed.P2POp( - torch.distributed.irecv, tensor, prev_rank - ) for tensor in tensors - ] - ops += recv_ops - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return tensors - - - -def schedule_naive(model, dataloader, num_microbatch: int, neighbors: Tuple[int, int]): - """ - neighbors: (prev_rank: int, next_rank: int) - """ - rank = DeviceGroup().rank - prev_rank, next_rank = neighbors - - is_first_stage = rank < prev_rank - is_last_stage = rank > next_rank - - for step in range(num_microbatch): - model.set_inputs(next(dataloader)) - # print(f'rank {rank} recving forward input...') - inputs = () if is_first_stage else recv_forward(model, prev_rank) - # forward - outputs = forward_step(model, *inputs) - # send forward - if not is_last_stage: - # print(f'rank {rank} sending forward output...') - send_forward(outputs, next_rank) - # recv backward - # print(f'rank {rank} recving backward input...') - output_grads = (None,) if is_last_stage else recv_backward(model, next_rank) - # backward - input_grads = backward_step(inputs, outputs, output_grads) - # send backward - if not is_first_stage: - # print(f'rank {rank} sending backward output...') - send_backward(input_grads, prev_rank) - - # memory_summary() - # if rank == 0: - # io_input(f'{step}>>>') - # torch.distributed.barrier() - - -def schedule_tp_1f1b_pack(model: torch.nn.Module, - dataloader, - num_microbatch: int, - num_stage: int, - neighbors: Tuple[int, int], - recompute=False): - rank = DeviceGroup().rank - prev_rank, next_rank = neighbors - - is_first_stage = rank < prev_rank - # FIXME: only work for pure pipeline - is_first_decoder_stage = (rank == num_stage // 2) - is_last_stage = rank > next_rank - - input_tensors = list() - output_tensors = list() - - input_encoder_tensors = list() - output_encoder_tensors = list() - input_decoder_tensors = list() - output_decoder_tensors = list() - - def tp_encoder_preprocess() -> torch.Tensor: - tokens = next(dataloader) - model.set_inputs(tokens) - enc = model.forward_encoder_preprocess(dst=0)[0] - input_encoder_tensors.append((tokens,)) - output_encoder_tensors.append((enc,)) - enc = enc.detach().requires_grad_() - if is_first_stage: - model.set_preprocess(enc=enc) - return (enc,) - return () - - def tp_decoder_preprocess() -> torch.Tensor: - tokens = next(dataloader) - model.set_inputs(tokens) - dec = model.forward_decoder_preprocess(dst=num_stage // 2)[0] - input_decoder_tensors.append((tokens,)) - output_decoder_tensors.append((dec,)) - dec = dec.detach().requires_grad_() - if is_first_decoder_stage: - model.set_preprocess(dec=dec) - return (dec,) - return () - - def tp_encoder_backward(grads: Tuple[torch.Tensor]): - inputs_head, outputs_head = input_encoder_tensors.pop(0), output_encoder_tensors.pop(0) - # encoder backward - enc = outputs_head[0] - if not is_first_stage: - grads = (torch.empty_like(enc),) - # decoder backward - backward_step((), (enc,), grads) - - def tp_decoder_backward(grads: Tuple[torch.Tensor]): - inputs_head, outputs_head = input_decoder_tensors.pop(0), output_decoder_tensors.pop(0) - # decoder backward - dec = outputs_head[0] - if not is_first_decoder_stage: - grads = (torch.empty_like(dec),) - backward_step((), (dec,), grads) - - fofst = [-(step // 2) for step in range(num_stage)] - bofst = [-(num_stage - 1 - (step // 2)) for step in range(num_stage)] - # print(fofst) - # print(bofst) - fofst = fofst[rank] - bofst = bofst[rank] - last_backward = (None,) - last_forward = (None,) - for step in range(num_microbatch + num_stage - 1): - torch.distributed.barrier() - # print_each_rank(f'=========begin rank {rank}=========') - fmid, bmid = step + fofst, step + bofst - decoder_fmid = step - num_stage // 2 // 2 - encoder_bmid = step + 1 - num_stage // 2 * 2 - decoder_bmid = step + 1 - int(num_stage // 2 * 1.5) - do_backward = 0 <= bmid and bmid <= num_microbatch - 1 - do_forward = 0 <= fmid and fmid <= num_microbatch - 1 - - # step1: tp encoder forward - if 0 <= step and step <= num_microbatch - 1: - # print(f'rank {rank} forward tp model ') - inputs = tp_encoder_preprocess() - - # step2: tp decoder forward - if 0 <= decoder_fmid and decoder_fmid <= num_microbatch - 1: - tp_decoder_preprocess() - - # step 3: forward + backward - if rank % 2 == 0: - # inter-barrier - if is_first_stage: - inputs = inputs - else: - if do_forward and last_backward != (None,): - # print(f'rank {rank} send backward grad + recv forward output ') - inputs = send_backward_recv_forward(last_backward, model, prev_rank) - elif do_forward: - # print(f'rank {rank} recv forward output ') - inputs = recv_forward(model, prev_rank) - elif last_backward != (None,): - # print(f'rank {rank} send backward grad ') - send_backward(last_backward, prev_rank) - - # forward - if do_forward: - input_tensors.append(inputs) - if is_first_stage: - inputs = () - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs) - output_tensors.append(None) - else: - outputs = forward_step(model, *inputs) - output_tensors.append(outputs) - - # recompute if backward is needed - if do_backward: - inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) - if recompute: - assert outputs is None - outputs = forward_step(model, *inputs) - - # intra-barrier send recv - output_grads = (None,) - if (do_forward and not is_last_stage) and (do_backward and not is_last_stage): - # send forward recv backward - # print(f'rank {rank} recv backward grad + send forward output ') - output_grads = send_forward_recv_backward(outputs, model, next_rank) - elif do_forward and not is_last_stage: - # print(f'rank {rank} send forward output ') - send_forward(outputs, next_rank) - elif do_backward and not is_last_stage: - # print(f'rank {rank} recv backward grad ') - output_grads = recv_backward(model, next_rank) - - # backward - last_backward = (None,) - if do_backward: - # inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) - input_grads = backward_step(inputs, outputs, output_grads) - last_backward = input_grads - - # backward + forward - if rank % 2 == 1: - # inter-barrier - if is_last_stage: - output_grads = (None,) - else: - if do_backward and last_forward != (None,): - # print(f'rank {rank} recv backward grad + send forward output ') - output_grads = send_forward_recv_backward(last_forward, model, next_rank) - elif do_backward: - # print(f'rank {rank} recv backward grad ') - output_grads = recv_backward(model, next_rank) - elif last_forward != (None,): - # print(f'rank {rank} send forward output ') - send_forward(last_forward, next_rank) - - # backward - last_backward = (None,) - if do_backward: - inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) - # backward - input_grads = backward_step(inputs, outputs, output_grads) - last_backward = input_grads - - # intra-barrier - if do_backward and do_forward: - # print(f'rank {rank} send backward grad + recv forward output ') - inputs = send_backward_recv_forward(input_grads, model, prev_rank) - elif do_backward: - # print(f'rank {rank} send backward grad ') - send_backward(input_grads, prev_rank) - elif do_forward: - # print(f'rank {rank} recv forward output ') - inputs = recv_forward(model, prev_rank) - - # forward - last_forward = (None,) - if do_forward: - # forward step - input_tensors.append(inputs) - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs) - output_tensors.append(None) - else: - outputs = forward_step(model, *inputs) - output_tensors.append(outputs) - last_forward = outputs - - next_backward = 0 <= (bmid+1) and (bmid+1) <= num_microbatch - 1 - if next_backward: - if recompute: - inputs, outputs = input_tensors[0], output_tensors[0] - assert outputs is None - outputs = forward_step(model, *inputs) - input_tensors[0] = inputs - output_tensors[0] = outputs - - - # tp tail forward-backward - if 0 <= decoder_bmid and decoder_bmid <= num_microbatch - 1: - # FIXME: currently use encoder grad - tp_decoder_backward(last_backward) - - # step 4: tp encoder and decoder backward - if 0 <= encoder_bmid and encoder_bmid <= num_microbatch - 1: - tp_encoder_backward(last_backward) - - # memory_summary() - # if rank == 0: - # io_input(f'{step}>>>') - # torch.distributed.barrier() - # print_each_rank(f'=========end rank {rank}: {step}=========') - - assert len(input_tensors) == 0 - assert len(output_tensors) == 0 - assert len(input_encoder_tensors) == 0 - assert len(output_encoder_tensors) == 0 - assert len(input_decoder_tensors) == 0 - assert len(output_decoder_tensors) == 0 - - # print_each_rank(f'=========end rank {rank}=========') - - -def schedule_1f1b(model: torch.nn.Module, - dataloader, - num_microbatch: int, - num_stage: int, - neighbors: Tuple[int, int], - group=None, - recompute=False): - - rank = torch.distributed.get_rank() - prev_rank, next_rank = neighbors - is_first_stage = rank < prev_rank - is_last_stage = rank > next_rank - - num_warmup_microbatches = num_stage - 1 - torch.distributed.get_rank(group=group) - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatch) - num_warmup_remaining = num_microbatch - num_warmup_microbatches - - input_tensors = list() - output_tensors = list() - - # warmup - for i in range(num_warmup_microbatches): - model.set_inputs(next(dataloader)) - # recv forward - inputs = () if is_first_stage else recv_forward(model, prev_rank) - # forward - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs) - output_tensors.append(None) - else: - outputs = forward_step(model, *inputs) - output_tensors.append(outputs) - # send forward - send_forward(outputs, next_rank) - input_tensors.append(inputs) - - # before running 1f1b: need to recv first forward tensor - if num_warmup_remaining > 0: - model.set_inputs(next(dataloader)) - inputs = () if is_first_stage else recv_forward(model, prev_rank) - - # run 1f1b - for i in range(num_warmup_remaining): - model.set_inputs(next(dataloader)) - # forward - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs) - output_tensors.append(None) - else: - outputs = forward_step(model, *inputs) - output_tensors.append(outputs) - input_tensors.append(inputs) - - # send forward recv backward - grads = (None,) - if not is_last_stage: - grads = send_forward_recv_backward(outputs, model, next_rank) - - # backward - inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) - if recompute: - assert outputs is None - outputs = forward_step(model, *inputs) - input_grads = backward_step(inputs, outputs, grads) - - # send backward - inputs = () - if not is_first_stage: - if i != (num_warmup_remaining-1): - # send backward recv forward - inputs = send_backward_recv_forward(input_grads, model, prev_rank) - else: - # send backward - send_backward(input_grads, prev_rank) - - # cooldown - for i in range(num_warmup_microbatches): - inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) - # recv backward - grads = (None,) if is_last_stage else recv_backward(model, next_rank) - # backward - if recompute: - assert outputs is None - outputs = forward_step(model, *inputs) - input_grads = backward_step(inputs, outputs, grads) - # send backward - if not is_first_stage: - send_backward(input_grads, prev_rank) diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh new file mode 100755 index 00000000..dc56e8d4 --- /dev/null +++ b/handcraft/mbart/test-fp32.sh @@ -0,0 +1,203 @@ +evaldir=eval/mbart-v100-32gb-pcie-recompute +mkdir -p ${evaldir} + + +test_mix_tp_1f1b() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev mixture-1f1b: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs 256 --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule tp1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_tp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure tp: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs 256 --micro-bs 1 \ + --pp-size 1 --tp-size ${gpus} \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_pp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure pp: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs 256 --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_pp_swap() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure pp swap: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs 256 --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule 1f1b --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp-swap.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_hybrid_tp_pp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + + if [ ${gpus} == 4 ] + then + echo "testing ${gpus}-dev tp:pp=2:2 | L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs 256 --micro-bs 1 \ + --pp-size 2 --tp-size 2 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt + sleep 5 + killall python + sleep 5 + killall python + fi + + if [ ${gpus} == 8 ] + then + echo "testing ${gpus}-dev tp:pp=4:2 | L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs 256 --micro-bs 1 \ + --pp-size 2 --tp-size 4 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt + sleep 5 + killall python + sleep 5 + killall python + + echo "testing ${gpus}-dev tp:pp=2:4 | L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs 256 --micro-bs 1 \ + --pp-size 2 --tp-size 4 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt + sleep 5 + killall python + sleep 5 + killall python + fi +} + +# ================================================= +# 4 gpus: arch layer 21,21, hidden 1792, heads 28 +# ================================================= +# test_mixtp_1f1b 21 1792 28 4 +# test_tp 21 1792 28 4 +# test_pp 21 1792 28 4 +# test_hybrid_tp_pp 21 1792 28 4 + +# ================================================= +# 4 gpus: arch layer 24,24, hidden 2048, heads 32 +# ================================================= +# test_mixtp_1f1b 24 2048 32 4 +# test_tp 24 2048 32 4 +# test_pp 24 2048 32 4 +# test_hybrid_tp_pp 24 2048 32 4 + +# ================================================= +# 4 gpus: arch layer 24,24, hidden 2560, heads 32 +# ================================================= +# test_mixtp_1f1b 24 2560 32 4 +# test_tp 24 2560 32 4 +# test_pp 24 2560 32 4 +# test_hybrid_tp_pp 24 2560 32 4 + +# ================================================= +# 4 gpus: arch layer 18,18, hidden 3072, heads 32 +# ================================================= +# test_mixtp_1f1b 18 3072 32 4 +# test_tp 18 3072 32 4 +# test_pp 18 3072 32 4 +# test_hybrid_tp_pp 18 3072 32 4 + +# ================================================= +# 4 gpus: arch layer 27,27, hidden 2304, heads 36 +# ================================================= +test_mixtp_1f1b 27 2304 36 4 +test_tp 27 2304 36 4 +# test_pp 27 2304 36 4 +# test_hybrid_tp_pp 27 2304 36 4 + +# ================================================= +# 8 gpus: arch layer 24,24, hidden 2048, heads 32 +# ================================================= +# test_mixtp_1f1b 24 2048 32 8 +# test_tp 24 2048 32 8 +# test_pp 24 2048 32 8 +# test_hybrid_tp_pp 24 2048 32 8 + +# ================================================= +# 8 gpus: arch layer 30,30, hidden 2560, heads 40 +# ================================================= +# test_mixtp_1f1b 30 2560 40 8 +# test_tp 30 2560 40 8 +# test_pp 30 2560 40 8 +# test_hybrid_tp_pp 30 2560 40 8 + +# ================================================= +# 8 gpus: arch layer 33,33, hidden 2816, heads 48 +# ================================================= +test_mixtp_1f1b 33 2816 48 8 +test_tp 33 2816 48 8 +# test_pp 33 2816 48 8 +# test_hybrid_tp_pp 33 2816 48 8 + +# ================================================= +# 8 gpus: arch layer 24,24, hidden 4096, heads 32 +# ================================================= +test_mixtp_1f1b 24 4096 32 8 +test_tp 24 4096 32 8 +# test_pp 24 4096 32 8 +# test_hybrid_tp_pp 24 4096 32 8 + +python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/tp.py b/handcraft/mbart/tp.py deleted file mode 100644 index b8f9bf06..00000000 --- a/handcraft/mbart/tp.py +++ /dev/null @@ -1,130 +0,0 @@ -import torch -from cube.profiler.timer import CudaTimer - - -class AllReduceIdentity(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, group): - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input - CudaTimer().start(field_name='comm') - torch.distributed.all_reduce(input, group=group) - CudaTimer().stop(field_name='comm') - return input - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -class IdentityAllreduce(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, group): - ctx._group = group - return input - - @staticmethod - def backward(ctx, grad_output): - world_size = torch.distributed.get_world_size(ctx._group) - if world_size == 1: - return grad_output, None - CudaTimer().start(field_name='comm') - torch.distributed.all_reduce(grad_output, group=ctx._group) - CudaTimer().stop(field_name='comm') - return grad_output, None - - -class AllGatherScatter(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, dim, group): - ctx._group = group - ctx._dim = dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input - CudaTimer().start(field_name='comm') - rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(input) for _ in range(world_size)] - tensor_list[rank] = input - torch.distributed.all_gather(tensor_list, input, group=group) - output = torch.cat(tensor_list, dim=dim).contiguous() - CudaTimer().stop(field_name='comm') - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - group = ctx._group - dim = ctx._dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output - CudaTimer().start(field_name='comm') - input_list = grad_output.chunk(world_size, dim=dim) - rank = torch.distributed.get_rank(group) - grad = input_list[rank].contiguous() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class ReduceBroadcast(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, dst: int, group): - ctx._dst = dst - ctx._group = group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input - CudaTimer().start(field_name='comm') - torch.distributed.reduce(input, dst, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input - - @staticmethod - def backward(ctx, grad_output): - src = ctx._dst - group = ctx._group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output, None, None - CudaTimer().start(field_name='comm') - torch.distributed.broadcast(grad_output, src, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return grad_output, None, None - - -class BroadcastReduce(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, src: int, group=None): - ctx._src = src - ctx._group = group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input - CudaTimer().start(field_name='comm') - torch.distributed.broadcast(input, src, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input - - @staticmethod - def backward(ctx, grad_output): - dst = ctx._src - group = ctx._group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output, None, None - CudaTimer().start(field_name='comm') - if not grad_output.is_contiguous(): - grad_output = grad_output.contiguous() - torch.distributed.reduce(grad_output, dst, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return grad_output, None, None From f8e170714979a48f1d1c54589703fd5ce52846e9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Apr 2022 09:20:40 +0000 Subject: [PATCH 0719/1892] test bug fix --- handcraft/mbart/test-fp32.sh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh index dc56e8d4..87a22f37 100755 --- a/handcraft/mbart/test-fp32.sh +++ b/handcraft/mbart/test-fp32.sh @@ -131,7 +131,7 @@ test_hybrid_tp_pp() # ================================================= # 4 gpus: arch layer 21,21, hidden 1792, heads 28 # ================================================= -# test_mixtp_1f1b 21 1792 28 4 +# test_mix_tp_1f1b 21 1792 28 4 # test_tp 21 1792 28 4 # test_pp 21 1792 28 4 # test_hybrid_tp_pp 21 1792 28 4 @@ -139,7 +139,7 @@ test_hybrid_tp_pp() # ================================================= # 4 gpus: arch layer 24,24, hidden 2048, heads 32 # ================================================= -# test_mixtp_1f1b 24 2048 32 4 +# test_mix_tp_1f1b 24 2048 32 4 # test_tp 24 2048 32 4 # test_pp 24 2048 32 4 # test_hybrid_tp_pp 24 2048 32 4 @@ -147,7 +147,7 @@ test_hybrid_tp_pp() # ================================================= # 4 gpus: arch layer 24,24, hidden 2560, heads 32 # ================================================= -# test_mixtp_1f1b 24 2560 32 4 +# test_mix_tp_1f1b 24 2560 32 4 # test_tp 24 2560 32 4 # test_pp 24 2560 32 4 # test_hybrid_tp_pp 24 2560 32 4 @@ -155,7 +155,7 @@ test_hybrid_tp_pp() # ================================================= # 4 gpus: arch layer 18,18, hidden 3072, heads 32 # ================================================= -# test_mixtp_1f1b 18 3072 32 4 +# test_mix_tp_1f1b 18 3072 32 4 # test_tp 18 3072 32 4 # test_pp 18 3072 32 4 # test_hybrid_tp_pp 18 3072 32 4 @@ -163,7 +163,7 @@ test_hybrid_tp_pp() # ================================================= # 4 gpus: arch layer 27,27, hidden 2304, heads 36 # ================================================= -test_mixtp_1f1b 27 2304 36 4 +test_mix_tp_1f1b 27 2304 36 4 test_tp 27 2304 36 4 # test_pp 27 2304 36 4 # test_hybrid_tp_pp 27 2304 36 4 @@ -171,7 +171,7 @@ test_tp 27 2304 36 4 # ================================================= # 8 gpus: arch layer 24,24, hidden 2048, heads 32 # ================================================= -# test_mixtp_1f1b 24 2048 32 8 +# test_mix_tp_1f1b 24 2048 32 8 # test_tp 24 2048 32 8 # test_pp 24 2048 32 8 # test_hybrid_tp_pp 24 2048 32 8 @@ -179,7 +179,7 @@ test_tp 27 2304 36 4 # ================================================= # 8 gpus: arch layer 30,30, hidden 2560, heads 40 # ================================================= -# test_mixtp_1f1b 30 2560 40 8 +# test_mix_tp_1f1b 30 2560 40 8 # test_tp 30 2560 40 8 # test_pp 30 2560 40 8 # test_hybrid_tp_pp 30 2560 40 8 @@ -187,7 +187,7 @@ test_tp 27 2304 36 4 # ================================================= # 8 gpus: arch layer 33,33, hidden 2816, heads 48 # ================================================= -test_mixtp_1f1b 33 2816 48 8 +test_mix_tp_1f1b 33 2816 48 8 test_tp 33 2816 48 8 # test_pp 33 2816 48 8 # test_hybrid_tp_pp 33 2816 48 8 @@ -195,7 +195,7 @@ test_tp 33 2816 48 8 # ================================================= # 8 gpus: arch layer 24,24, hidden 4096, heads 32 # ================================================= -test_mixtp_1f1b 24 4096 32 8 +test_mix_tp_1f1b 24 4096 32 8 test_tp 24 4096 32 8 # test_pp 24 4096 32 8 # test_hybrid_tp_pp 24 4096 32 8 From 05f4ef5179c1cf6d4673771d15329d953f95423f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Apr 2022 10:36:43 +0000 Subject: [PATCH 0720/1892] fix tp bug --- handcraft/mbart/test-fp32.sh | 26 +++++++++++++++-------- handcraft/mbart/train.py | 41 +++++++++++++++--------------------- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh index 87a22f37..d1beec94 100755 --- a/handcraft/mbart/test-fp32.sh +++ b/handcraft/mbart/test-fp32.sh @@ -1,4 +1,4 @@ -evaldir=eval/mbart-v100-32gb-pcie-recompute +evaldir=eval/mbart-fp32-v100-32gb mkdir -p ${evaldir} @@ -12,7 +12,7 @@ test_mix_tp_1f1b() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 256 --micro-bs 1 \ + --bs 8 --micro-bs 1 \ --pp-size ${gpus} --tp-size 1 \ --schedule tp1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt sleep 5 @@ -31,7 +31,7 @@ test_tp() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 256 --micro-bs 1 \ + --bs 8 --micro-bs 1 \ --pp-size 1 --tp-size ${gpus} \ --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt sleep 5 @@ -128,6 +128,14 @@ test_hybrid_tp_pp() fi } + +test_mix_tp_1f1b 27 2304 36 4 +test_tp 27 2304 36 4 +test_mix_tp_1f1b 33 2816 48 8 +test_tp 33 2816 48 8 +test_mix_tp_1f1b 24 4096 32 8 +test_tp 24 4096 32 8 + # ================================================= # 4 gpus: arch layer 21,21, hidden 1792, heads 28 # ================================================= @@ -163,8 +171,8 @@ test_hybrid_tp_pp() # ================================================= # 4 gpus: arch layer 27,27, hidden 2304, heads 36 # ================================================= -test_mix_tp_1f1b 27 2304 36 4 -test_tp 27 2304 36 4 +# test_mix_tp_1f1b 27 2304 36 4 +# test_tp 27 2304 36 4 # test_pp 27 2304 36 4 # test_hybrid_tp_pp 27 2304 36 4 @@ -187,16 +195,16 @@ test_tp 27 2304 36 4 # ================================================= # 8 gpus: arch layer 33,33, hidden 2816, heads 48 # ================================================= -test_mix_tp_1f1b 33 2816 48 8 -test_tp 33 2816 48 8 +# test_mix_tp_1f1b 33 2816 48 8 +# test_tp 33 2816 48 8 # test_pp 33 2816 48 8 # test_hybrid_tp_pp 33 2816 48 8 # ================================================= # 8 gpus: arch layer 24,24, hidden 4096, heads 32 # ================================================= -test_mix_tp_1f1b 24 4096 32 8 -test_tp 24 4096 32 8 +# test_mix_tp_1f1b 24 4096 32 8 +# test_tp 24 4096 32 8 # test_pp 24 4096 32 8 # test_hybrid_tp_pp 24 4096 32 8 diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index d2a8164c..40a00f03 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -101,27 +101,20 @@ _pp_global_ranks = tuple(pp_ranks) # layer division - encoder_time = [1] * args.layers - decoder_time = [1] * args.layers - times = encoder_time + decoder_time - num_stages = torch.distributed.get_world_size(_pp_group) - budget = sum(times) // num_stages - print_each_rank(f'budget: {budget}', rank_only=0) - start, end = 0, 1 - for idx in range(num_stages): - accum = times[start] - assert end <= args.layers * 2 - while end != args.layers * 2: - accum += times[end] - if accum > budget: - break - end += 1 - if idx == num_stages - 1: - end = args.layers * 2 - _layer_divisions.append((start, end)) + chunk_num = args.layers // (args.pp_size // 2) + layers = [chunk_num] * (args.pp_size // 2) + for idx in range(args.layers % chunk_num): + layers[-1-idx] += 1 + layer_num_per_dev = layers + layers + start = 0 + layer_scopes, start = [], 0 + for sid in range(args.pp_size): + end = start + layer_num_per_dev[sid] + layer_scopes.append((start, end)) if start <= args.layers and end > args.layers: - _first_decoder_stage = idx - start, end = end, end+1 + _first_decoder_stage = sid + start = end + _layer_divisions = layer_scopes assert _first_decoder_stage != _first_encoder_stage, "Not supported yet" else: _layer_divisions = [(0, args.layers * 2)] @@ -287,8 +280,8 @@ def __init__(self, cfg: Config): self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) self.dropout = torch.nn.Dropout(p=cfg.dropout) self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim) - self.fc2 = torch.nn.Linear(cfg.ffn_dim, cfg.embed_dim) + self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) + self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) self.inputs_info = ( @@ -342,8 +335,8 @@ def __init__(self, cfg: Config): self.encoder_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim) - self.fc2 = torch.nn.Linear(cfg.ffn_dim, cfg.embed_dim) + self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) + self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) self.inputs_info = ( From 08aaa31977b201a7065d1893898836a273e6b6df Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Apr 2022 10:50:40 +0000 Subject: [PATCH 0721/1892] avoid oom --- handcraft/mbart/test-fp32.sh | 36 ++++++++++++++++++++---------------- handcraft/mbart/train.py | 4 ++-- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh index d1beec94..e1c8659a 100755 --- a/handcraft/mbart/test-fp32.sh +++ b/handcraft/mbart/test-fp32.sh @@ -1,6 +1,7 @@ evaldir=eval/mbart-fp32-v100-32gb mkdir -p ${evaldir} +bs=256 test_mix_tp_1f1b() { @@ -12,7 +13,7 @@ test_mix_tp_1f1b() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 8 --micro-bs 1 \ + --bs ${bs} --micro-bs 1 \ --pp-size ${gpus} --tp-size 1 \ --schedule tp1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt sleep 5 @@ -31,7 +32,7 @@ test_tp() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 8 --micro-bs 1 \ + --bs ${bs} --micro-bs 1 \ --pp-size 1 --tp-size ${gpus} \ --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt sleep 5 @@ -50,7 +51,7 @@ test_pp() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 256 --micro-bs 1 \ + --bs ${bs} --micro-bs 1 \ --pp-size ${gpus} --tp-size 1 \ --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp.txt sleep 5 @@ -69,7 +70,7 @@ test_pp_swap() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 256 --micro-bs 1 \ + --bs ${bs} --micro-bs 1 \ --pp-size ${gpus} --tp-size 1 \ --schedule 1f1b --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp-swap.txt sleep 5 @@ -91,7 +92,7 @@ test_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/mbart_hybrid.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 256 --micro-bs 1 \ + --bs ${bs} --micro-bs 1 \ --pp-size 2 --tp-size 2 \ --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt sleep 5 @@ -106,7 +107,7 @@ test_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/mbart_hybrid.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 256 --micro-bs 1 \ + --bs ${bs} --micro-bs 1 \ --pp-size 2 --tp-size 4 \ --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt sleep 5 @@ -118,7 +119,7 @@ test_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/mbart_hybrid.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 256 --micro-bs 1 \ + --bs ${bs} --micro-bs 1 \ --pp-size 2 --tp-size 4 \ --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt sleep 5 @@ -129,6 +130,9 @@ test_hybrid_tp_pp() } +# ================================================= +# selected experiments +# ================================================= test_mix_tp_1f1b 27 2304 36 4 test_tp 27 2304 36 4 test_mix_tp_1f1b 33 2816 48 8 @@ -139,7 +143,7 @@ test_tp 24 4096 32 8 # ================================================= # 4 gpus: arch layer 21,21, hidden 1792, heads 28 # ================================================= -# test_mix_tp_1f1b 21 1792 28 4 +# test_mix_tp_1f1b 21 1792 28 4 # test_tp 21 1792 28 4 # test_pp 21 1792 28 4 # test_hybrid_tp_pp 21 1792 28 4 @@ -147,7 +151,7 @@ test_tp 24 4096 32 8 # ================================================= # 4 gpus: arch layer 24,24, hidden 2048, heads 32 # ================================================= -# test_mix_tp_1f1b 24 2048 32 4 +# test_mix_tp_1f1b 24 2048 32 4 # test_tp 24 2048 32 4 # test_pp 24 2048 32 4 # test_hybrid_tp_pp 24 2048 32 4 @@ -155,7 +159,7 @@ test_tp 24 4096 32 8 # ================================================= # 4 gpus: arch layer 24,24, hidden 2560, heads 32 # ================================================= -# test_mix_tp_1f1b 24 2560 32 4 +# test_mix_tp_1f1b 24 2560 32 4 # test_tp 24 2560 32 4 # test_pp 24 2560 32 4 # test_hybrid_tp_pp 24 2560 32 4 @@ -163,7 +167,7 @@ test_tp 24 4096 32 8 # ================================================= # 4 gpus: arch layer 18,18, hidden 3072, heads 32 # ================================================= -# test_mix_tp_1f1b 18 3072 32 4 +# test_mix_tp_1f1b 18 3072 32 4 # test_tp 18 3072 32 4 # test_pp 18 3072 32 4 # test_hybrid_tp_pp 18 3072 32 4 @@ -171,7 +175,7 @@ test_tp 24 4096 32 8 # ================================================= # 4 gpus: arch layer 27,27, hidden 2304, heads 36 # ================================================= -# test_mix_tp_1f1b 27 2304 36 4 +# test_mix_tp_1f1b 27 2304 36 4 # test_tp 27 2304 36 4 # test_pp 27 2304 36 4 # test_hybrid_tp_pp 27 2304 36 4 @@ -179,7 +183,7 @@ test_tp 24 4096 32 8 # ================================================= # 8 gpus: arch layer 24,24, hidden 2048, heads 32 # ================================================= -# test_mix_tp_1f1b 24 2048 32 8 +# test_mix_tp_1f1b 24 2048 32 8 # test_tp 24 2048 32 8 # test_pp 24 2048 32 8 # test_hybrid_tp_pp 24 2048 32 8 @@ -187,7 +191,7 @@ test_tp 24 4096 32 8 # ================================================= # 8 gpus: arch layer 30,30, hidden 2560, heads 40 # ================================================= -# test_mix_tp_1f1b 30 2560 40 8 +# test_mix_tp_1f1b 30 2560 40 8 # test_tp 30 2560 40 8 # test_pp 30 2560 40 8 # test_hybrid_tp_pp 30 2560 40 8 @@ -195,7 +199,7 @@ test_tp 24 4096 32 8 # ================================================= # 8 gpus: arch layer 33,33, hidden 2816, heads 48 # ================================================= -# test_mix_tp_1f1b 33 2816 48 8 +# test_mix_tp_1f1b 33 2816 48 8 # test_tp 33 2816 48 8 # test_pp 33 2816 48 8 # test_hybrid_tp_pp 33 2816 48 8 @@ -203,7 +207,7 @@ test_tp 24 4096 32 8 # ================================================= # 8 gpus: arch layer 24,24, hidden 4096, heads 32 # ================================================= -# test_mix_tp_1f1b 24 4096 32 8 +# test_mix_tp_1f1b 24 4096 32 8 # test_tp 24 4096 32 8 # test_pp 24 4096 32 8 # test_hybrid_tp_pp 24 4096 32 8 diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index 40a00f03..f52b2e11 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -104,7 +104,7 @@ chunk_num = args.layers // (args.pp_size // 2) layers = [chunk_num] * (args.pp_size // 2) for idx in range(args.layers % chunk_num): - layers[-1-idx] += 1 + layers[-2-idx] += 1 layer_num_per_dev = layers + layers start = 0 layer_scopes, start = [], 0 @@ -738,7 +738,7 @@ def __next__(self): print_each_rank('model weight consumpition:') memory_summary() - CudaTimer(enable=False).warmup() + CudaTimer(enable=False) iter_num = 6 for step in range(iter_num): if step >= 2: From c8e0ad06852a012009021784dc0dc39afd002cc0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Apr 2022 13:11:59 +0000 Subject: [PATCH 0722/1892] fix relative index bug --- handcraft/swin/test-multi-node.sh | 10 +++--- handcraft/swin/test.sh | 46 +++++++++++++-------------- handcraft/swin/train.py | 52 +++++++++++++++++++++---------- handcraft/swin/utils.py | 33 +++++++++++++------- 4 files changed, 81 insertions(+), 60 deletions(-) diff --git a/handcraft/swin/test-multi-node.sh b/handcraft/swin/test-multi-node.sh index 1feec5f7..6e3bb84a 100755 --- a/handcraft/swin/test-multi-node.sh +++ b/handcraft/swin/test-multi-node.sh @@ -190,14 +190,12 @@ test_all() } -# test Layers Dim Heads Nodes GPUs -# test_naive_tp 42 1024 32 2 8 +# ================================================= +# selected experiments +# ================================================= test_coshard_hybrid_tp_pp 42 1024 32 2 16 - -# test_naive_tp 50 1024 32 2 8 test_coshard_hybrid_tp_pp 50 1024 32 2 16 - -# test_naive_tp 34 1024 32 2 8 +test_naive_tp 34 1024 32 2 8 test_coshard_hybrid_tp_pp 34 1024 32 2 16 python scripts/keep.py --gpus 8 diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index c6abeb75..3646e0e1 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -3,7 +3,7 @@ evaldir=eval/swin-coshard mkdir -p ${evaldir} - +bs=256 img_size=1536 window_size=48 @@ -23,7 +23,7 @@ test_naive_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}.txt sleep 5 killall python sleep 5 @@ -45,7 +45,7 @@ test_naive_tp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt sleep 5 killall python sleep 5 @@ -69,7 +69,7 @@ test_naive_hybrid_tp_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 2 --tp-size 2 --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2.txt sleep 5 killall python sleep 5 @@ -87,7 +87,7 @@ test_naive_hybrid_tp_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 2 --tp-size 4 --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2.txt sleep 5 killall python sleep 5 @@ -101,7 +101,7 @@ test_naive_hybrid_tp_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 4 --tp-size 2 --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4.txt sleep 5 killall python sleep 5 @@ -116,7 +116,7 @@ test_coshard_pp() heads=$3 gpus=$4 - echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" + echo "testing ${gpus}-dev: Coshard PP: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ @@ -124,7 +124,7 @@ test_coshard_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard.txt + --bs ${bs} --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard.txt sleep 5 killall python sleep 5 @@ -148,7 +148,8 @@ test_coshard_hybrid_tp_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 2 --tp-size 2 --dp-size 1 \ - --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2.txt + --bs ${bs} --micro-bs 1 --use-coshard \ + --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2-coshard.txt sleep 5 killall python sleep 5 @@ -166,7 +167,8 @@ test_coshard_hybrid_tp_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 2 --tp-size 4 --dp-size 1 \ - --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2.txt + --bs ${bs} --micro-bs 1 --use-coshard \ + --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2-coshard.txt sleep 5 killall python sleep 5 @@ -180,7 +182,8 @@ test_coshard_hybrid_tp_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 4 --tp-size 2 --dp-size 1 \ - --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4.txt + --bs ${bs} --micro-bs 1 --use-coshard \ + --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4-coshard.txt sleep 5 killall python sleep 5 @@ -200,20 +203,13 @@ test_all() test_coshard_pp $layers $dim $heads $gpus } +# ================================================= +# selected experiments +# ================================================= +test_coshard_pp 26 512 16 4 +test_naive_tp 26 512 16 4 +test_coshard_pp 34 768 24 8 +test_naive_tp 34 768 24 8 -# test Layers Dim Heads GPUs -test_all 18 256 8 4 -test_all 18 512 16 4 -test_all 18 768 24 4 - -test_all 26 256 8 8 4 -test_all 26 512 16 16 4 -test_all 26 768 24 24 4 -test_all 26 1024 32 32 4 - -test_all 34 256 8 8 -test_all 34 512 16 8 -test_all 34 768 24 8 -test_all 34 1024 32 8 python scripts/keep.py --gpus 8 diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 842a6fe8..9df6ffed 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -23,7 +23,7 @@ from cube.runtime.adapter.distnn import IdentityAllreduce, AllReduceIdentity, AllGatherSplit from handcraft.module.schedule import schedule_1f1b from handcraft.module.stage import PipeStage -from handcraft.swin.utils import create_position_bias, trunc_normal_, window_partition, window_reverse, DropPath +from handcraft.swin.utils import create_position_bias, create_position_index, trunc_normal_, window_partition, window_reverse, DropPath import argparse @@ -236,7 +236,9 @@ class WindowAttention(torch.nn.Module): proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ - def __init__(self, dim, inner_dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, inner_dim, window_size, num_heads, + qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_index=True): super().__init__() self._tp_group = _tp_group @@ -250,24 +252,35 @@ def __init__(self, dim, inner_dim, window_size, num_heads, qkv_bias=True, qk_sca self.scale = qk_scale or self.head_dim ** -0.5 # define define a parameter table of relative position bias - table, index = create_position_bias(self.window_size, self.num_heads) + table = create_position_bias(self.window_size, self.num_heads) self.relative_position_bias_table = table - self.register_buffer("relative_position_index", index) + if position_index: + index = create_position_index(window_size, cuda=False) + self.register_buffer("relative_position_index", index) + else: + self.relative_position_index = None self.qkv = nn.Linear(dim, inner_dim // self._tp_size * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(inner_dim // self._tp_size, dim) self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) - def forward_(self, x, mask=None): + def forward_(self, x, mask=None, position_index=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ + assert self.relative_position_index is None ^ position_index is None + if position_index is not None: + relative_position_index = self.position_index + else: + relative_position_index = self.relative_position_index + + if position_index is None: + relative_position_index = create_position_index(self.window_size, cuda=True) + if self._tp_size > 1: x = IdentityAllreduce.apply(x, self._tp_group) @@ -278,7 +291,7 @@ def forward_(self, x, mask=None): q = q * self.scale attn = (q @ k.transpose(-2, -1)) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + relative_position_bias = self.relative_position_bias_table[relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) @@ -302,11 +315,11 @@ def forward_(self, x, mask=None): return x - def forward(self, x, mask=None, recompute=True): + def forward(self, x, mask=None, position_index=None, recompute=True): if recompute: - x = checkpoint.checkpoint(self.forward_, x, mask) + x = checkpoint.checkpoint(self.forward_, x, mask, position_index) else: - x = self.forward_(x, mask) + x = self.forward_(x, mask, position_index) return x def extra_repr(self) -> str: @@ -340,10 +353,13 @@ def __init__(self, dim, inner_dim, window_size, num_heads, self.attns = torch.nn.ModuleList( [WindowAttention( dim, inner_dim // self.coshard, window_size, num_heads // self.coshard, - qkv_bias, qk_scale, attn_drop, proj_drop) for _ in range(self.coshard)] + qkv_bias, qk_scale, attn_drop, proj_drop, False) for _ in range(self.coshard)] ) - # remove communication inside each attention as it will be - # done outside here + # 1) remove communication inside each attention as it will be + # done outside here + # 2) share same relative position index + index = create_position_index(window_size, cuda=False) + self.register_buffer("relative_position_index", index) for attn in self.attns: attn._tp_size = 1 @@ -374,7 +390,7 @@ def forward(self, x, mask=None, recompute=True): outs = None for attn in self.attns: - x_out = attn(x, mask, recompute) + x_out = attn(x, mask, self.relative_position_index, recompute) outs = x_out if outs is None else outs + x_out if self._tp_size > 1: @@ -423,6 +439,7 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) else: coshard = num_heads // args.tp_size + coshard = coshard // 2 if layer_id > 0 else coshard print_each_rank(f'Swin-stage-{layer_id} using coshard {coshard}', rank_only=0) self.attn = SeqWindowAttention( dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, @@ -435,6 +452,7 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) else: coshard = num_heads // args.tp_size + coshard = coshard // 2 if layer_id > 0 else coshard self.mlp = SeqMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, coshard=coshard) H, W = self.input_resolution @@ -915,7 +933,7 @@ def train(): _dp_reducer.add_param(param) optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) - print_each_rank('model weight consumpition:') + print_each_rank('model weight consumpition:', rank_only=0) memory_summary() def train_iter(model, dataloader): @@ -930,7 +948,7 @@ def train_iter(model, dataloader): if _dp_reducer is not None: _dp_reducer.allreduce() - CudaTimer(enable=False).warmup() + CudaTimer(enable=False) iter_num = 6 for step in range(iter_num): diff --git a/handcraft/swin/utils.py b/handcraft/swin/utils.py index dada9d68..2472b44d 100644 --- a/handcraft/swin/utils.py +++ b/handcraft/swin/utils.py @@ -91,16 +91,25 @@ def forward(self, x): def create_position_bias(window_size: Tuple[int, int], num_heads: int): relative_position_bias_table = torch.nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(window_size[0]) - coords_w = torch.arange(window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww trunc_normal_(relative_position_bias_table, std=.02) - return relative_position_bias_table, relative_position_index + return relative_position_bias_table + + +def create_position_index(window_size: Tuple[int, int], cuda=False): + # get pair-wise relative position index for each token inside the window + with torch.no_grad(): + if cuda: + coords_h = torch.arange(window_size[0], device=torch.cuda.current_device()) + coords_w = torch.arange(window_size[1], device=torch.cuda.current_device()) + else: + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + return relative_position_index From 5b561f68da61a769494bfff2207ef71dec90c1eb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Apr 2022 13:19:45 +0000 Subject: [PATCH 0723/1892] swin fix position index bug --- handcraft/swin/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 9df6ffed..25d6392f 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -272,9 +272,9 @@ def forward_(self, x, mask=None, position_index=None): x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ - assert self.relative_position_index is None ^ position_index is None + assert (self.relative_position_index is None) ^ (position_index is None) if position_index is not None: - relative_position_index = self.position_index + relative_position_index = position_index else: relative_position_index = self.relative_position_index From 7906ed722d055d59d3e62697a3c2a4f4c3c3cd8f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 2 Apr 2022 04:47:23 +0000 Subject: [PATCH 0724/1892] inner-cosharding --- handcraft/swin/test-multi-node.sh | 66 +++++++++++++++++-------------- handcraft/swin/test.sh | 3 ++ handcraft/swin/train.py | 43 ++++++++++---------- 3 files changed, 61 insertions(+), 51 deletions(-) diff --git a/handcraft/swin/test-multi-node.sh b/handcraft/swin/test-multi-node.sh index 6e3bb84a..566d9da8 100755 --- a/handcraft/swin/test-multi-node.sh +++ b/handcraft/swin/test-multi-node.sh @@ -6,6 +6,7 @@ mkdir -p ${evaldir} img_size=1536 window_size=48 +bs=256 test_naive_pp() @@ -27,7 +28,7 @@ test_naive_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}.txt sleep 5 killall python sleep 5 @@ -53,7 +54,7 @@ test_naive_tp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt sleep 5 killall python sleep 5 @@ -82,13 +83,13 @@ test_naive_hybrid_tp_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 2 --tp-size 8 --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp8pp2.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp8pp2.txt sleep 5 killall python sleep 5 killall python - echo "testing ${gpus}-dev: TP2-PP4: L${layers}E${dim}H${heads}" + echo "testing ${gpus}-dev: TP4-PP4: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ @@ -96,7 +97,7 @@ test_naive_hybrid_tp_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 4 --tp-size 4 --dp-size 1 \ - --bs 256 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp4.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp4.txt sleep 5 killall python sleep 5 @@ -111,6 +112,7 @@ test_coshard_pp() heads=$3 nodes=$4 gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ @@ -123,7 +125,8 @@ test_coshard_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard.txt + --bs ${bs} --micro-bs 1 --use-coshard + --fp16 > ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt sleep 5 killall python sleep 5 @@ -137,28 +140,12 @@ test_coshard_hybrid_tp_pp() heads=$3 nodes=$4 gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} # Hybrid TP-1F1B -- 8 GPU if [ ${gpus} == 16 ] then - echo "testing ${gpus}-dev: TP8-PP2: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 2 --tp-size 8 --dp-size 1 \ - --bs 64 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp8pp2-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - - # echo "testing ${gpus}-dev: TP4-PP4: L${layers}E${dim}H${heads}" + # echo "testing ${gpus}-dev: TP8-PP2: L${layers}E${dim}H${heads}" # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=${nodes} \ @@ -168,12 +155,30 @@ test_coshard_hybrid_tp_pp() # handcraft/swin/train.py \ # --layers ${layers} --dim ${dim} --heads ${heads} \ # --img-size ${img_size} --window-size ${window_size} \ - # --pp-size 4 --tp-size 4 --dp-size 1 \ - # --bs 256 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp4-coshard.txt + # --pp-size 2 --tp-size 8 --dp-size 1 \ + # --bs 64 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp8pp2-coshard.txt # sleep 5 # killall python # sleep 5 # killall python + + echo "testing coshard ${gpus}-dev: TP4-PP4: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 4 --tp-size 4 --dp-size 1 \ + --bs ${bs} --micro-bs 1 --use-coshard --use-inner-coshard \ + --fp16 > ${evaldir}/${gpus}dev-${arch}-tp4pp4-coshard.txt + sleep 5 + killall python + sleep 5 + killall python fi } @@ -193,9 +198,10 @@ test_all() # ================================================= # selected experiments # ================================================= -test_coshard_hybrid_tp_pp 42 1024 32 2 16 -test_coshard_hybrid_tp_pp 50 1024 32 2 16 -test_naive_tp 34 1024 32 2 8 -test_coshard_hybrid_tp_pp 34 1024 32 2 16 + +# test_naive_tp 42 1024 32 2 16 +test_coshard_hybrid_tp_pp 42 1024 32 2 16 +# test_naive_tp 50 1024 32 2 16 +test_coshard_hybrid_tp_pp 50 1024 32 2 16 python scripts/keep.py --gpus 8 diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index 3646e0e1..d023c3cb 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -211,5 +211,8 @@ test_naive_tp 26 512 16 4 test_coshard_pp 34 768 24 8 test_naive_tp 34 768 24 8 +# DGX-2 testing cases +# test_coshard_hybrid_tp_pp 42 1024 32 16 + python scripts/keep.py --gpus 8 diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 25d6392f..60d9a0be 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -22,7 +22,7 @@ from cube.runtime.device import DeviceGroup from cube.runtime.adapter.distnn import IdentityAllreduce, AllReduceIdentity, AllGatherSplit from handcraft.module.schedule import schedule_1f1b -from handcraft.module.stage import PipeStage +from handcraft.module.stage import PipeStage, layer_division from handcraft.swin.utils import create_position_bias, create_position_index, trunc_normal_, window_partition, window_reverse, DropPath import argparse @@ -55,7 +55,10 @@ help='data parallelism size') parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b'], help='scheduling algorithm') -parser.add_argument('--use-coshard', action='store_true', default=False) +parser.add_argument('--use-coshard', action='store_true', default=False, + help='enable this will split head but co-locate them with re-compute') +parser.add_argument('--use-inner-coshard', action='store_true', default=False, + help='enable this will shard bmm in attention of q @ k') parser.add_argument('--fp16', action='store_true', default=False) args = parser.parse_args() @@ -120,21 +123,12 @@ ([1] * args.layers + [0]) + \ ([1] * 2) num_stages = len(pp_ranks) - budget = sum(times) // num_stages - print_each_rank(f'budget: {budget}', rank_only=0) - start, end = 0, 1 - for idx in range(num_stages): - accum = times[start] - assert end <= nlayers - while end != nlayers: - if times[end] > 0 and budget - accum < 0.5 * times[end]: - break - accum += times[end] - end += 1 - if idx == num_stages - 1: - end = nlayers - _layer_divisions.append((start, end)) - start, end = end, end+1 + _layer_divisions = layer_division(times, num_stages) + # specific rules for stage division in order to fit in memory + if args.dim == 1024 and args.tp_size == 4: + if _layer_divisions[0][1] > 8: + remain_times = times[8:] + _layer_divisions = [(0, 8)] + layer_division(remain_times, num_stages-1, start_id=8) else: _layer_divisions = [(0, 2 + 2 + args.layers + 2 + 3)] print_each_rank(f'layer divisions: {_layer_divisions}', rank_only=0) @@ -278,9 +272,6 @@ def forward_(self, x, mask=None, position_index=None): else: relative_position_index = self.relative_position_index - if position_index is None: - relative_position_index = create_position_index(self.window_size, cuda=True) - if self._tp_size > 1: x = IdentityAllreduce.apply(x, self._tp_group) @@ -289,7 +280,17 @@ def forward_(self, x, mask=None, position_index=None): q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + + k = k.transpose(-2, -1) + # inner coshard by splitting windows + if args.use_inner_coshard and (B_ == 64 or B_ == 16): + chunk_num = B_ // 4 + attn = [] + for shard_q, shard_k in zip(torch.chunk(q, chunks=chunk_num, dim=0), torch.chunk(k, chunks=chunk_num, dim=0)): + attn.append(shard_q @ shard_k) + attn = torch.concat(tuple(attn), dim=0) + else: + attn = (q @ k) relative_position_bias = self.relative_position_bias_table[relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH From 964b4537d214736404df21b290c9679011538290 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 2 Apr 2022 04:50:26 +0000 Subject: [PATCH 0725/1892] layer division algorithm --- handcraft/module/stage.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/handcraft/module/stage.py b/handcraft/module/stage.py index fd32a90a..39359167 100644 --- a/handcraft/module/stage.py +++ b/handcraft/module/stage.py @@ -130,3 +130,29 @@ def data(self) -> Tuple: @data.setter def data(self, datas: Tuple): self._data = datas + + +def layer_division(times: List[int], num_stages: int, start_id: int = 0): + """ + Computation balance division + """ + divisions = [] + budget = sum(times) / num_stages + nlayers = len(times) + start, end = 0, 1 + for idx in range(num_stages): + accum = times[start] + assert end <= nlayers + while end != nlayers: + if times[end] > 0 and budget - accum < 0.5 * times[end]: + break + accum += times[end] + end += 1 + if idx == num_stages - 1: + end = nlayers + divisions.append((start, end)) + start, end = end, end+1 + for sid in range(num_stages): + start, end = divisions[sid] + divisions[sid] = (start+start_id, end+start_id) + return divisions From c3d9b1d1cb16796fa8bf2d4c1b94df5cdaefd89d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 2 Apr 2022 08:12:10 +0000 Subject: [PATCH 0726/1892] add script --- handcraft/swin/test-multi-node.sh | 14 ++++++++++---- handcraft/swin/test.sh | 17 +++++++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/handcraft/swin/test-multi-node.sh b/handcraft/swin/test-multi-node.sh index 566d9da8..7ac0a825 100755 --- a/handcraft/swin/test-multi-node.sh +++ b/handcraft/swin/test-multi-node.sh @@ -91,8 +91,11 @@ test_naive_hybrid_tp_pp() echo "testing ${gpus}-dev: TP4-PP4: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -199,9 +202,12 @@ test_all() # selected experiments # ================================================= -# test_naive_tp 42 1024 32 2 16 +test_naive_tp 42 1024 32 2 16 test_coshard_hybrid_tp_pp 42 1024 32 2 16 -# test_naive_tp 50 1024 32 2 16 +# test_naive_hybrid_tp_pp 42 1024 32 2 16 # -> OOM + +test_naive_tp 50 1024 32 2 16 test_coshard_hybrid_tp_pp 50 1024 32 2 16 +# test_naive_hybrid_tp_pp 50 1024 32 2 16 # -> OOM python scripts/keep.py --gpus 8 diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index d023c3cb..b736cf57 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -206,10 +206,19 @@ test_all() # ================================================= # selected experiments # ================================================= -test_coshard_pp 26 512 16 4 -test_naive_tp 26 512 16 4 -test_coshard_pp 34 768 24 8 -test_naive_tp 34 768 24 8 +test_coshard_pp 26 512 16 4 +test_naive_tp 26 512 16 4 +# test_naive_hybrid_tp_pp 26 512 16 4 # --> OOM + +test_coshard_pp 34 512 16 8 +test_naive_tp 34 512 16 8 +test_naive_hybrid_tp_pp 34 512 16 8 + +test_coshard_pp 42 768 24 8 +test_naive_tp 42 768 24 8 +# test_naive_hybrid_tp_pp 42 768 24 8 # --> OOM + + # DGX-2 testing cases # test_coshard_hybrid_tp_pp 42 1024 32 16 From 3af03f4d16432d183a2e5e6db25f36fa85b97362 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 2 Apr 2022 12:26:51 +0000 Subject: [PATCH 0727/1892] add params# --- handcraft/swin/train.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 60d9a0be..39db2d4e 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -645,16 +645,16 @@ def create_basic_layter(dim, input_resolution, depth, num_heads, window_size, use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ # swin transformer layers - blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, use_coshard=args.use_coshard, layer_id=layer_id) - for i in range(depth)]) + blocks = [SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, use_coshard=args.use_coshard, layer_id=layer_id) + for i in range(depth)] # patch merging layer if downsample is not None: downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) @@ -755,7 +755,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers - self.layers = nn.ModuleList() + layers = [] for i_layer in range(self.num_layers): blocks = create_basic_layter(dim=int(embed_dim * 2 ** i_layer), input_resolution=(self.patches_resolution[0] // (2 ** i_layer), @@ -770,12 +770,12 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, layer_id=i_layer) - self.layers += blocks + layers += blocks # pipeline split layers start, end = _layer_divisions[self.stage_local_rank] - print_each_rank(f'initializing layer ranging from [{start}, {end})') - self.layers = self.layers[start:end] + self.layers = torch.nn.ModuleList(layers[start:end]) + print_each_rank(f'initialized layers ({len(self.layers)}) ranging from [{start}, {end})') self.inputs_info = self.layers[0].inputs_info self.outputs_info = self.layers[-1].outputs_info From b31d8f43038cbbc58cb8f062d0c6b62dc2d36a61 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 2 Apr 2022 12:28:29 +0000 Subject: [PATCH 0728/1892] add params --- handcraft/swin/train.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 60d9a0be..fe62491b 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -1,14 +1,14 @@ """ example: -gpus=16 +gpus=1 OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ handcraft/swin/train.py \ - --bs ${gpus} --micro-bs 1 --fp16 \ - --dp-size 1 --pp-size 16 --tp-size 1 \ - --layers 42 --dim 1024 --heads 32 --use-coshard + --bs 32 --micro-bs 1 --fp16 \ + --dp-size 1 --pp-size 1 --tp-size 1 \ + --layers 18 --dim 128 --heads 4 """ import torch @@ -329,14 +329,14 @@ def extra_repr(self) -> str: def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim + # qkv = self.qkv(x) # M K N + flops += N * self.dim * (3 * self.head_dim * self.num_heads) # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N + flops += self.num_heads * (N * self.head_dim * N) # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) + flops += self.num_heads * N * N * self.head_dim # x = self.proj(x) - flops += N * self.dim * self.dim + flops += N * self.head_dim * self.num_heads * self.dim return flops @@ -398,6 +398,12 @@ def forward(self, x, mask=None, recompute=True): outs = AllReduceIdentity.apply(outs, self._tp_group) return outs + def flops(self, N): + flops = 0 + for attn in self.attns: + flops += attn.flops(N) + return flops + class SwinTransformerBlock(PipeStage): r""" Swin Transformer Block. @@ -925,6 +931,8 @@ def train(): ape=False, patch_norm=True, use_checkpoint=False) + nparams = sum([param.numel() for param in model.parameters()]) + print_each_rank(f'Model Params#: {nparams}') if args.fp16: model = model.half() model = model.cuda() From 9b94a2228f142f9d697f287bfc7151d07bffea55 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 2 Apr 2022 15:28:39 +0000 Subject: [PATCH 0729/1892] add multi-node test --- handcraft/mbart/test-multinode-fp32.sh | 172 +++++++++++++++++++++++++ 1 file changed, 172 insertions(+) create mode 100644 handcraft/mbart/test-multinode-fp32.sh diff --git a/handcraft/mbart/test-multinode-fp32.sh b/handcraft/mbart/test-multinode-fp32.sh new file mode 100644 index 00000000..b1d1c6b0 --- /dev/null +++ b/handcraft/mbart/test-multinode-fp32.sh @@ -0,0 +1,172 @@ +evaldir=eval/mbart-fp32-v100-32gb +mkdir -p ${evaldir} + +bs=256 + +test_mix_tp_1f1b() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev mixture-1f1b: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule tp1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_tp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure tp: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size 1 --tp-size ${gpus} \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_pp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure pp: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_pp_swap() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure pp swap: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule 1f1b --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp-swap.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_hybrid_tp_pp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + + if [ ${gpus} == 16 ] + then + echo "testing ${gpus}-dev tp:pp=2:2 | L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size 2 --tp-size 8 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt + sleep 5 + killall python + sleep 5 + killall python + + echo "testing ${gpus}-dev tp:pp=2:2 | L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size 4 --tp-size 4 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp4.txt + sleep 5 + killall python + sleep 5 + killall python + + echo "testing ${gpus}-dev tp:pp=2:2 | L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/mbart/mbart_hybrid.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size 8 --tp-size 2 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt + sleep 5 + killall python + sleep 5 + killall python + fi +} + + +# ================================================= +# selected experiments +# ================================================= +test_mix_tp_1f1b 32 5120 40 16 +test_tp 32 5120 40 16 +# test_hybrid_tp_pp 32 5120 40 16 # --> OOM + +python scripts/keep.py --gpus 8 From c2c62aeb2385819eb1a031441cd924e535b69b1e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 3 Apr 2022 09:43:08 +0000 Subject: [PATCH 0730/1892] fix swin dp coshard bugs --- handcraft/swin/test.sh | 43 ++++++++++++++++++++++++++++++++--------- handcraft/swin/train.py | 14 +++++++------- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index b736cf57..4853624d 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -3,7 +3,7 @@ evaldir=eval/swin-coshard mkdir -p ${evaldir} -bs=256 +bs=8 img_size=1536 window_size=48 @@ -131,6 +131,29 @@ test_coshard_pp() killall python } +test_coshard_dp() +{ + layers=$1 + dim=$2 + heads=$3 + gpus=$4 + + echo "testing ${gpus}-dev: Coshard DP: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 1 --tp-size 1 --dp-size ${gpus} \ + --bs ${bs} --micro-bs 1 --use-coshard \ + --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-dp${gpus}-coshard.txt + sleep 5 + killall python + sleep 5 + killall python +} + test_coshard_hybrid_tp_pp() { layers=$1 @@ -206,19 +229,21 @@ test_all() # ================================================= # selected experiments # ================================================= -test_coshard_pp 26 512 16 4 -test_naive_tp 26 512 16 4 -# test_naive_hybrid_tp_pp 26 512 16 4 # --> OOM +test_coshard_dp 18 256 8 2 -test_coshard_pp 34 512 16 8 -test_naive_tp 34 512 16 8 -test_naive_hybrid_tp_pp 34 512 16 8 +# test_coshard_dp 26 512 16 4 +# test_coshard_pp 26 512 16 4 +# test_naive_tp 26 512 16 4 +# test_naive_hybrid_tp_pp 26 512 16 4 # --> OOM -test_coshard_pp 42 768 24 8 -test_naive_tp 42 768 24 8 +# test_coshard_pp 42 768 24 8 +# test_naive_tp 42 768 24 8 # test_naive_hybrid_tp_pp 42 768 24 8 # --> OOM +# test_coshard_pp 34 512 16 8 +# test_naive_tp 34 512 16 8 +# test_naive_hybrid_tp_pp 34 512 16 8 # DGX-2 testing cases # test_coshard_hybrid_tp_pp 42 1024 32 16 diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 4c8cf937..5ec938ed 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -487,13 +487,13 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 self.register_buffer("attn_mask", attn_mask) - assert args.micro_bs // args.dp_size != 0 + assert args.bs // (args.micro_bs * args.dp_size) != 0 self.inputs_info = ( - ((args.micro_bs // args.dp_size, H * W, self.dim),), + ((args.micro_bs, H * W, self.dim),), (torch.float32 if not args.fp16 else torch.float16,) ) self.outputs_info = ( - ((args.micro_bs // args.dp_size, H * W, self.dim),), + ((args.micro_bs, H * W, self.dim),), (torch.float32 if not args.fp16 else torch.float16,) ) self.layer_id = layer_id @@ -583,13 +583,13 @@ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): self.norm = norm_layer(4 * dim) H, W = self.input_resolution - assert args.micro_bs // args.dp_size != 0 + assert args.bs // (args.micro_bs * args.dp_size) != 0 self.inputs_info = ( - ((args.micro_bs // args.dp_size, H * W, self.dim),), + ((args.micro_bs, H * W, self.dim),), (torch.float32 if not args.fp16 else torch.float16,) ) self.outputs_info = ( - ((args.micro_bs // args.dp_size, (H // 2) * (W // 2), self.dim * 2),), + ((args.micro_bs, (H // 2) * (W // 2), self.dim * 2),), (torch.float32 if not args.fp16 else torch.float16,) ) @@ -946,7 +946,7 @@ def train(): memory_summary() def train_iter(model, dataloader): - num_microbatch = args.bs // args.micro_bs + num_microbatch = args.bs // (args.dp_size * args.micro_bs) if _pp_group != -1: _schedule(model, dataloader, num_microbatch) else: From 1777efed548754ce92ff3174d71088b15f45a5d0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 3 Apr 2022 10:35:08 +0000 Subject: [PATCH 0731/1892] test for 2 gpus and 4 gpus --- handcraft/mbart/test-fp32.sh | 16 ++++++++++------ handcraft/mbart/train.py | 6 +++--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh index e1c8659a..73da9974 100755 --- a/handcraft/mbart/test-fp32.sh +++ b/handcraft/mbart/test-fp32.sh @@ -133,12 +133,16 @@ test_hybrid_tp_pp() # ================================================= # selected experiments # ================================================= -test_mix_tp_1f1b 27 2304 36 4 -test_tp 27 2304 36 4 -test_mix_tp_1f1b 33 2816 48 8 -test_tp 33 2816 48 8 -test_mix_tp_1f1b 24 4096 32 8 -test_tp 24 4096 32 8 +test_tp 8 2048 16 2 +test_mix_tp_1f1b 8 2048 16 2 +test_hybrid_tp_pp 8 2048 16 2 + +test_mix_tp_1f1b 16 3072 24 4 +test_tp 16 3072 24 4 + +# test_mix_tp_1f1b 24 4096 32 8 +# test_tp 24 4096 32 8 + # ================================================= # 4 gpus: arch layer 21,21, hidden 1792, heads 28 diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index f52b2e11..f4a60ff7 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -5,9 +5,9 @@ --nproc_per_node=4 \ --nnodes=1 \ handcraft/mbart/train.py \ - --layers 12 --hidden-size 1024 --heads 16 \ - --dp-size 1 --pp-size 4 --tp-size 1 \ - --bs 4 --micro-bs 1 --schedule 1f1b + --layers 16 --hidden-size 3072 --heads 32 \ + --pp-size 2 --tp-size 2 \ + --bs 16 --micro-bs 1 --schedule 1f1b """ from typing import Optional From 7ae4cddedaf2f0f6b205031e640b0bf96f27c607 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 3 Apr 2022 12:19:29 +0000 Subject: [PATCH 0732/1892] add full test scripts --- handcraft/mbart/test-fp32.sh | 88 ++++---------------------- handcraft/mbart/test-multinode-fp32.sh | 6 ++ handcraft/swin/test.sh | 34 +++++----- 3 files changed, 36 insertions(+), 92 deletions(-) diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh index 73da9974..4cd2266f 100755 --- a/handcraft/mbart/test-fp32.sh +++ b/handcraft/mbart/test-fp32.sh @@ -133,87 +133,21 @@ test_hybrid_tp_pp() # ================================================= # selected experiments # ================================================= -test_tp 8 2048 16 2 -test_mix_tp_1f1b 8 2048 16 2 -test_hybrid_tp_pp 8 2048 16 2 +# test_tp 8 2048 16 2 +# test_mix_tp_1f1b 8 2048 16 2 +# test_hybrid_tp_pp 8 2048 16 2 test_mix_tp_1f1b 16 3072 24 4 test_tp 16 3072 24 4 +test_mix_tp_1f1b 16 3072 24 8 +test_tp 16 3072 24 8 +# test_mix_tp_1f1b 16 3072 24 16 +# test_tp 16 3072 24 16 -# test_mix_tp_1f1b 24 4096 32 8 -# test_tp 24 4096 32 8 - - -# ================================================= -# 4 gpus: arch layer 21,21, hidden 1792, heads 28 -# ================================================= -# test_mix_tp_1f1b 21 1792 28 4 -# test_tp 21 1792 28 4 -# test_pp 21 1792 28 4 -# test_hybrid_tp_pp 21 1792 28 4 - -# ================================================= -# 4 gpus: arch layer 24,24, hidden 2048, heads 32 -# ================================================= -# test_mix_tp_1f1b 24 2048 32 4 -# test_tp 24 2048 32 4 -# test_pp 24 2048 32 4 -# test_hybrid_tp_pp 24 2048 32 4 - -# ================================================= -# 4 gpus: arch layer 24,24, hidden 2560, heads 32 -# ================================================= -# test_mix_tp_1f1b 24 2560 32 4 -# test_tp 24 2560 32 4 -# test_pp 24 2560 32 4 -# test_hybrid_tp_pp 24 2560 32 4 - -# ================================================= -# 4 gpus: arch layer 18,18, hidden 3072, heads 32 -# ================================================= -# test_mix_tp_1f1b 18 3072 32 4 -# test_tp 18 3072 32 4 -# test_pp 18 3072 32 4 -# test_hybrid_tp_pp 18 3072 32 4 - -# ================================================= -# 4 gpus: arch layer 27,27, hidden 2304, heads 36 -# ================================================= -# test_mix_tp_1f1b 27 2304 36 4 -# test_tp 27 2304 36 4 -# test_pp 27 2304 36 4 -# test_hybrid_tp_pp 27 2304 36 4 - -# ================================================= -# 8 gpus: arch layer 24,24, hidden 2048, heads 32 -# ================================================= -# test_mix_tp_1f1b 24 2048 32 8 -# test_tp 24 2048 32 8 -# test_pp 24 2048 32 8 -# test_hybrid_tp_pp 24 2048 32 8 - -# ================================================= -# 8 gpus: arch layer 30,30, hidden 2560, heads 40 -# ================================================= -# test_mix_tp_1f1b 30 2560 40 8 -# test_tp 30 2560 40 8 -# test_pp 30 2560 40 8 -# test_hybrid_tp_pp 30 2560 40 8 - -# ================================================= -# 8 gpus: arch layer 33,33, hidden 2816, heads 48 -# ================================================= -# test_mix_tp_1f1b 33 2816 48 8 -# test_tp 33 2816 48 8 -# test_pp 33 2816 48 8 -# test_hybrid_tp_pp 33 2816 48 8 +test_mix_tp_1f1b 16 3072 24 4 +test_tp 16 3072 24 4 -# ================================================= -# 8 gpus: arch layer 24,24, hidden 4096, heads 32 -# ================================================= -# test_mix_tp_1f1b 24 4096 32 8 -# test_tp 24 4096 32 8 -# test_pp 24 4096 32 8 -# test_hybrid_tp_pp 24 4096 32 8 +test_mix_tp_1f1b 24 4096 32 8 +test_tp 24 4096 32 8 python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/test-multinode-fp32.sh b/handcraft/mbart/test-multinode-fp32.sh index b1d1c6b0..4d5c0905 100644 --- a/handcraft/mbart/test-multinode-fp32.sh +++ b/handcraft/mbart/test-multinode-fp32.sh @@ -165,6 +165,12 @@ test_hybrid_tp_pp() # ================================================= # selected experiments # ================================================= + +# strong scalability test +test_mix_tp_1f1b 16 3072 24 16 +test_tp 16 3072 24 16 + +# model scaling test test_mix_tp_1f1b 32 5120 40 16 test_tp 32 5120 40 16 # test_hybrid_tp_pp 32 5120 40 16 # --> OOM diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index 4853624d..a4bdcf12 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -3,7 +3,7 @@ evaldir=eval/swin-coshard mkdir -p ${evaldir} -bs=8 +bs=256 img_size=1536 window_size=48 @@ -229,24 +229,28 @@ test_all() # ================================================= # selected experiments # ================================================= +test_naive_tp 6 96 3 1 +test_naive_tp 10 128 4 1 +test_naive_tp 14 192 6 1 +# test_naive_tp 18 256 8 1 # --> OOM +# test_naive_tp 26 512 16 1 # --> OOM +test_coshard_pp 6 96 3 1 +test_coshard_pp 10 128 4 1 +test_coshard_pp 14 192 6 1 +test_coshard_pp 18 256 8 1 +test_coshard_pp 26 512 16 1 + test_coshard_dp 18 256 8 2 +test_naive_tp 18 256 8 2 +test_naive_hybrid_tp_pp 18 256 8 2 -# test_coshard_dp 26 512 16 4 -# test_coshard_pp 26 512 16 4 -# test_naive_tp 26 512 16 4 +test_coshard_dp 26 512 16 4 +test_coshard_pp 26 512 16 4 +test_naive_tp 26 512 16 4 # test_naive_hybrid_tp_pp 26 512 16 4 # --> OOM -# test_coshard_pp 42 768 24 8 -# test_naive_tp 42 768 24 8 +test_coshard_pp 42 768 24 8 +test_naive_tp 42 768 24 8 # test_naive_hybrid_tp_pp 42 768 24 8 # --> OOM - -# test_coshard_pp 34 512 16 8 -# test_naive_tp 34 512 16 8 -# test_naive_hybrid_tp_pp 34 512 16 8 - -# DGX-2 testing cases -# test_coshard_hybrid_tp_pp 42 1024 32 16 - - python scripts/keep.py --gpus 8 From e7d7a7ebdaa7f999c8875bcfd87c3e01d15125b1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 4 Apr 2022 09:33:41 +0000 Subject: [PATCH 0733/1892] add schedule for tp1f1b 2-stage --- handcraft/mbart/test-fp32.sh | 8 +- handcraft/module/schedule.py | 171 +++++++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 4 deletions(-) diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh index 4cd2266f..a64e6043 100755 --- a/handcraft/mbart/test-fp32.sh +++ b/handcraft/mbart/test-fp32.sh @@ -1,7 +1,7 @@ evaldir=eval/mbart-fp32-v100-32gb mkdir -p ${evaldir} -bs=256 +bs=16 test_mix_tp_1f1b() { @@ -133,9 +133,9 @@ test_hybrid_tp_pp() # ================================================= # selected experiments # ================================================= -# test_tp 8 2048 16 2 -# test_mix_tp_1f1b 8 2048 16 2 -# test_hybrid_tp_pp 8 2048 16 2 +test_tp 8 2048 16 2 +test_mix_tp_1f1b 8 2048 16 2 +test_hybrid_tp_pp 8 2048 16 2 test_mix_tp_1f1b 16 3072 24 4 test_tp 16 3072 24 4 diff --git a/handcraft/module/schedule.py b/handcraft/module/schedule.py index d2362bdf..5f2fc700 100644 --- a/handcraft/module/schedule.py +++ b/handcraft/module/schedule.py @@ -325,10 +325,181 @@ def schedule_1f1b(model: PipeStage, model.assert_empty_cached() +def schedule_tp1f1b_pp2(model: PipeStage, + dataloader, + num_microbatch: int, + recompute=False): + def tp_encoder_preprocess(model: PipeStage) -> torch.Tensor: + model.data = next(dataloader) + enc = model.forward_encoder_shard() + return (enc,) + + def tp_decoder_preprocess(model: PipeStage) -> torch.Tensor: + model.data = next(dataloader) + dec = model.forward_decoder_shard() + return (dec,) + + def tp_encoder_backward(model: PipeStage): + enc = model.pop('encoder_sharding_output') + if model.stage_local_rank == model.first_encoder_stage: + grads = model.pop('encoder_sharding_grad') + else: + grads = (torch.empty_like(enc),) + backward_step((), (enc,), grads) + + def tp_decoder_backward(model: PipeStage): + dec = model.pop('decoder_sharding_output') + if model.stage_local_rank == model.first_decoder_stage: + grads = model.pop('decoder_sharding_grad') + else: + grads = (torch.empty_like(dec),) + backward_step((), (dec,), grads) + + num_stage = model.num_stages + rank = model.stage_local_rank + prev_rank = model.prev_stage_global_grank + next_rank = model.next_stage_global_rank + + output_grads = (None,) + inputs = () + for step in range(num_microbatch * 2 + 2): + + encoder_fmid = step // 2 + encoder_bmid = step - 2 + decoder_fmid = step - 1 + decoder_bmid = step - 3 + + # step1: forward sharding 0 + if step % 2 == 0: + encoder_fmid = step // 2 + encoder_inputs = None + if 0 <= encoder_fmid and encoder_fmid <= num_microbatch - 1: + encoder_inputs = tp_encoder_preprocess(model) + # step1: forward sharding 1 + if step % 2 == 1: + decoder_fmid = (step - 1) // 2 + decoder_inputs = None + if 0 <= decoder_fmid and decoder_fmid <= num_microbatch - 1: + decoder_inputs = tp_decoder_preprocess(model) + + if rank % 2 == 0: + # do forward + if step % 2 == 0: + fmid = step // 2 + do_forward = 0 <= fmid and fmid <= num_microbatch - 1 + if do_forward: + model.push(encoder_inputs, 'inputs') + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *(), recompute=True) + model.push(None, 'outputs') + else: + outputs = forward_step(model, *()) + model.push(outputs, 'outputs') + + # recompute + next_bmid = (step + 1 - 3) // 2 if step+1 >= 3 else -1 + do_next_backward = 0 <= next_bmid and next_bmid <= num_microbatch - 1 + if recompute and do_next_backward : + outputs_bp = model.pop('outputs') + assert outputs_bp is None + outputs_bp = forward_step(model, *()) + model.push_ahead(outputs_bp, 'outputs') + + # send forward recv backward + if do_forward and do_next_backward: + # print(f'rank {rank}: step {step}: send forward recv backward') + output_grads = send_forward_recv_backward(outputs, model, next_rank) + elif do_next_backward: + # print(f'rank {rank}: step {step}: recv backward') + output_grads = recv_backward(model, next_rank) + elif do_forward: + # print(f'rank {rank}: step {step}: send forward') + send_forward(outputs, next_rank) + + # do backward + else: + bmid = (step - 3) // 2 if step >= 3 else -1 + if 0 <= bmid and bmid <= num_microbatch - 1: + inputs, outputs = model.pop('inputs'), model.pop('outputs') + input_grads = backward_step(inputs, outputs, output_grads) + output_grads = (None,) + assert len(input_grads) == 1 + model.push(input_grads, 'encoder_sharding_grad') + + if rank % 2 == 1: + # do backward + if step % 2 == 0: + bmid = (step - 2) // 2 if step >= 2 else -1 + do_backward = 0 <= bmid and bmid <= num_microbatch - 1 + + # backward + if do_backward: + inputs, outputs = model.pop('inputs'), model.pop('outputs') + assert output_grads == (None,) + input_grads = backward_step(inputs, outputs, output_grads) + assert len(inputs) == 2 + model.push((input_grads[1],), 'decoder_sharding_grad') + input_grads = (input_grads[0],) + + # send backward recv forward + next_fmid = (step + 1 - 1) // 2 + do_next_forward = 0 <= next_fmid and next_fmid <= num_microbatch - 1 + if do_backward and do_next_forward: + # print(f'rank {rank}: step {step}: send backward recv forward') + inputs = send_backward_recv_forward(input_grads, model, prev_rank) + elif do_next_forward: + # print(f'rank {rank}: step {step}: recv forward') + inputs = recv_forward(model, prev_rank) + elif do_backward: + # print(f'rank {rank}: step {step}: send backward') + send_backward(input_grads, prev_rank) + # do forward + else: + # forward + fmid = (step - 1) // 2 + if 0 <= fmid and fmid <= num_microbatch - 1: + assert inputs != () + model.push((inputs[0], decoder_inputs[0]), 'inputs') + if recompute: + with torch.no_grad(): + outputs = forward_step(model, *inputs, recompute=True) + model.push(None, 'outputs') + else: + outputs = forward_step(model, *inputs) + model.push(outputs, 'outputs') + + # recompute + if recompute: + inputs, outputs = model.pop('inputs'), model.pop('outputs') + assert outputs is None + outputs = forward_step(model, *inputs) + model.push_ahead(inputs, 'inputs') + model.push_ahead(outputs, 'outputs') + + + # step3: backward sharding 1 + if step % 2 == 0: + decoder_bmid = (step - 2) // 2 + if 0 <= decoder_bmid and decoder_bmid <= num_microbatch - 1: + tp_decoder_backward(model) + + # step3: backward sharding 0 + if step % 2 == 1: + encoder_bmid = (step - 3) // 2 + if 0 <= encoder_bmid and encoder_bmid <= num_microbatch - 1: + tp_encoder_backward(model) + + model.assert_empty_cached() + + def schedule_tp1f1b(model: PipeStage, dataloader, num_microbatch: int, recompute=False): + # special cases for pipeline stage == 2 + if model.num_stages == 2: + return schedule_tp1f1b_pp2(model, dataloader, num_microbatch, recompute) def tp_encoder_preprocess(model: PipeStage) -> torch.Tensor: model.data = next(dataloader) From fd4f77a73d12b10ab95bcfc1917221643ea9466d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 4 Apr 2022 09:36:30 +0000 Subject: [PATCH 0734/1892] bs back to normal --- handcraft/mbart/test-fp32.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh index a64e6043..badf06be 100755 --- a/handcraft/mbart/test-fp32.sh +++ b/handcraft/mbart/test-fp32.sh @@ -1,7 +1,7 @@ evaldir=eval/mbart-fp32-v100-32gb mkdir -p ${evaldir} -bs=16 +bs=256 test_mix_tp_1f1b() { From 6a265adc9df6b6463015a7b2a544a1070c2235d2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 4 Apr 2022 12:41:25 +0000 Subject: [PATCH 0735/1892] add assertation --- handcraft/mbart/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index f4a60ff7..643b48f3 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -94,6 +94,7 @@ if len(tp_ranks) != 1: print_each_rank(f'initializing tp ranks: {tp_ranks}') _tp_group = DeviceGroup().get_group(tp_ranks) + assert args.heads % args.tp_size == 0, "cannot be divided by tp-size" if len(pp_ranks) != 1: print_each_rank(f'initializing pp ranks: {pp_ranks}') @@ -728,6 +729,8 @@ def __next__(self): print_each_rank(f'enc/dec layer#: {cfg.encoder_layers}, embed_dim#: {cfg.embed_dim}, heads#: {cfg.attention_heads}, ffn_dim#: {cfg.ffn_dim}', rank_only=0) model = MBart(cfg) + nparams = sum([param.numel() for param in model.parameters()]) + print_each_rank('model params: [{nparams}]. Launching model...') model = model.half().cuda() if args.fp16 else model.cuda() dataloader = MBartDataLoader(args.micro_bs, cfg) @@ -739,6 +742,7 @@ def __next__(self): memory_summary() CudaTimer(enable=False) + torch.distributed.barrier() iter_num = 6 for step in range(iter_num): if step >= 2: From 2a242c433524e3ee6f0e769d81db09b4ddd19bae Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 4 Apr 2022 14:47:30 +0000 Subject: [PATCH 0736/1892] fix division bug --- handcraft/mbart/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index 643b48f3..b95daed6 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -104,7 +104,7 @@ # layer division chunk_num = args.layers // (args.pp_size // 2) layers = [chunk_num] * (args.pp_size // 2) - for idx in range(args.layers % chunk_num): + for idx in range(args.layers % args.pp_size): layers[-2-idx] += 1 layer_num_per_dev = layers + layers start = 0 @@ -730,7 +730,7 @@ def __next__(self): model = MBart(cfg) nparams = sum([param.numel() for param in model.parameters()]) - print_each_rank('model params: [{nparams}]. Launching model...') + print_each_rank(f'model params: [{nparams}]. Launching model...') model = model.half().cuda() if args.fp16 else model.cuda() dataloader = MBartDataLoader(args.micro_bs, cfg) From 777a81d90945dbe8e0d3e0f5d3cb0c152718adf6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 5 Apr 2022 04:31:36 +0000 Subject: [PATCH 0737/1892] recompute for tensor parallelism --- handcraft/mbart/train.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index b95daed6..8e879416 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -15,6 +15,7 @@ import math import numpy as np import torch +import torch.utils.checkpoint as checkpoint import cube from cube.runtime.device import DeviceGroup @@ -104,7 +105,7 @@ # layer division chunk_num = args.layers // (args.pp_size // 2) layers = [chunk_num] * (args.pp_size // 2) - for idx in range(args.layers % args.pp_size): + for idx in range(args.layers % (args.pp_size // 2)): layers[-2-idx] += 1 layer_num_per_dev = layers + layers start = 0 @@ -615,8 +616,15 @@ def forward(self, enc=None, dec=None, recompute=False): enc = self.embed(source_tokens, encoder=True) if self.encoder_forward: - for layer in self.encoders: - enc = layer(enc) + if args.pp_size == 1: + def encoder_forward(enc): + for layer in self.encoders: + enc = layer(enc) + return enc + enc = checkpoint.checkpoint(encoder_forward, enc) + else: + for layer in self.encoders: + enc = layer(enc) if self.layer_norm_encoder is not None: enc = self.layer_norm_encoder(enc) output = enc @@ -632,9 +640,17 @@ def forward(self, enc=None, dec=None, recompute=False): dec = self.embed(prev_tokens, decoder=True) if self.decoder_forward: - assert enc is not None - for layer in self.decoders: - dec, enc = layer(dec, enc) + if args.pp_size == 1: + def decoder_forward(enc, dec): + assert enc is not None + for layer in self.decoders: + dec, enc = layer(dec, enc) + return enc, dec + enc, dec = checkpoint.checkpoint(decoder_forward, enc, dec) + else: + assert enc is not None + for layer in self.decoders: + dec, enc = layer(dec, enc) if self.layer_norm_decoder is not None: dec = self.layer_norm_decoder(dec) output = (enc, dec) From b70bbac69996fd18ff8d13a1fa0c778369b1c276 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 5 Apr 2022 07:46:46 +0000 Subject: [PATCH 0738/1892] add tflops calculation --- handcraft/mbart/train.py | 78 ++++++++++++++++++++++++++++++++++--- handcraft/swin/train.py | 83 +++++++++++++++++++++++++--------------- 2 files changed, 126 insertions(+), 35 deletions(-) diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index 8e879416..df9654be 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -2,12 +2,11 @@ example: OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ + --nproc_per_node=1 \ --nnodes=1 \ handcraft/mbart/train.py \ - --layers 16 --hidden-size 3072 --heads 32 \ - --pp-size 2 --tp-size 2 \ - --bs 16 --micro-bs 1 --schedule 1f1b + --layers 8 --hidden-size 2048 --heads 16 \ + --bs 1 --micro-bs 1 --schedule 1f1b """ from typing import Optional @@ -236,6 +235,7 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, self.tp_group = _tp_group self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + self.embed_dim = embed_dim self.inner_dim = inner_dim self.head_dim = inner_dim // num_heads self.num_heads = num_heads // self.tp_size @@ -269,6 +269,22 @@ def forward(self, query: torch.Tensor, key: torch.Tensor): attn = AllReduceIdentity.apply(attn, self.tp_group) return attn + def flops(self, seqlen: int): + """ + Get forward-pass FLOPs for 1 micro-batch + """ + attn_flops = dict( + kqv=3 * seqlen * self.embed_dim * self.head_dim * self.num_heads, + kqv_bias=3 * seqlen * self.head_dim * self.num_heads, + q_scale=seqlen * self.num_heads * self.head_dim, # (N h) L d, 1 -> (N h) L d + attn_score=self.num_heads * seqlen * self.head_dim * seqlen, # (N h) L d, (N h) d L -> (N h) L L + attn_softmax=5 * self.num_heads * seqlen * seqlen, # (N h) L L + attn_dropout=self.num_heads * seqlen * seqlen, # (N h) L L -> (N h) L L + attn_output=self.num_heads * seqlen * seqlen * self.head_dim, # (N h) L L, (N h) L d -> (N h) L d + out_proj=seqlen * self.num_heads * self.head_dim * self.embed_dim, # L N (h d), E (h d) -> L N E + ) + return sum(attn_flops.values()) + class EncoderLayer(PipeStage): @@ -282,6 +298,7 @@ def __init__(self, cfg: Config): self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) self.dropout = torch.nn.Dropout(p=cfg.dropout) self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) + self.hidden_dim = cfg.ffn_dim // self.tp_size self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) @@ -319,6 +336,23 @@ def forward(self, x): x = x + residual return x + def flops(self): + seqlen = self.cfg.max_source_positions + enc_flops = dict( + attn_layer_norm=5 * seqlen * self.cfg.embed_dim, # (L, N, E) + attn=self.self_attn.flops(seqlen), + dropout=seqlen * self.cfg.embed_dim, # (L, N, E) + attn_residual=seqlen * self.cfg.embed_dim, + fc_layer_norm=5 * seqlen * self.cfg.embed_dim, # (L, N, E) + fc1=seqlen * self.cfg.embed_dim * self.hidden_dim, # L N E, E hidden -> L N hidden + gelu=8 * seqlen * self.hidden_dim, + fc_inner_dropout=seqlen * self.hidden_dim, + fc2=seqlen * self.cfg.embed_dim * self.hidden_dim, # L N hidden, hidden E -> L N E + fc_dropout=seqlen * self.cfg.embed_dim, + fc_residual=seqlen * self.cfg.embed_dim, + ) + return sum(enc_flops.values()) + class DecoderLayer(PipeStage): @@ -337,6 +371,7 @@ def __init__(self, cfg: Config): self.encoder_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) + self.hidden_dim = cfg.ffn_dim // self.tp_size self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) @@ -387,6 +422,23 @@ def forward(self, x, encoder_out): x = x + residual return x, encoder_out + def flops(self): + seqlen = self.cfg.max_target_positions + dec_flops = dict( + attn_layer_norm=0, # ignore + attn=self.self_attn.flops(seqlen) * 2, # self attention + cross attention + dropout=seqlen * self.cfg.embed_dim, # (L, N, E) + attn_residual=seqlen * self.cfg.embed_dim, + fc_layer_norm=0, # ignore + fc1=seqlen * self.cfg.embed_dim * self.hidden_dim, # L N E, E hidden -> L N hidden + gelu=seqlen * self.hidden_dim, + fc_inner_dropout=seqlen * self.hidden_dim, + fc2=seqlen * self.cfg.embed_dim * self.hidden_dim, # L N hidden, hidden E -> L N E + fc_dropout=seqlen * self.cfg.embed_dim, + fc_residual=seqlen * self.cfg.embed_dim, + ) + return sum(dec_flops.values()) + class MBartClassificationHead(torch.nn.Module): """Head for sentence-level classification tasks.""" @@ -417,6 +469,9 @@ def forward(self, dec: torch.Tensor, labels): loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) return loss + def flops(self): + return 0 # ignore + class ShardEmbed(torch.nn.Module): @@ -492,6 +547,10 @@ def forward(self, tokens, encoder=False, decoder=False, dst: Optional[int] = Non x = x.transpose(0, 1) return x + def flops(self): + # ignore + return 0 + def init_weight(self): for param in self.parameters(): torch.nn.init.constant_(param, 0.1) @@ -662,6 +721,13 @@ def decoder_forward(enc, dec): return output + def flops(self): + enc_flops = sum([enc.flops() for enc in self.encoders]) + enc_layernorm = 5 * self.cfg.max_source_positions * self.cfg.embed_dim + dec_flops = sum([dec.flops() for dec in self.decoders]) + dec_layernorm = 5 * self.cfg.max_target_positions * self.cfg.embed_dim + return enc_flops + enc_layernorm + dec_flops + dec_layernorm + def reduce_embed(model: MBart, pp_embed_group): """ @@ -746,7 +812,9 @@ def __next__(self): model = MBart(cfg) nparams = sum([param.numel() for param in model.parameters()]) - print_each_rank(f'model params: [{nparams}]. Launching model...') + forward_flops = model.flops() + flops = forward_flops * 4 # forward + re-compute forward + backward (=2 forward flops) + print_each_rank(f'model params: {nparams} | FLOPs: {flops}. Launching model...') model = model.half().cuda() if args.fp16 else model.cuda() dataloader = MBartDataLoader(args.micro_bs, cfg) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 5ec938ed..432f596d 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -1,14 +1,13 @@ """ example: -gpus=1 OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ + --nproc_per_node=1 \ --nnodes=1 \ handcraft/swin/train.py \ - --bs 32 --micro-bs 1 --fp16 \ + --bs 1 --micro-bs 1 --fp16 \ --dp-size 1 --pp-size 1 --tp-size 1 \ - --layers 18 --dim 128 --heads 4 + --layers 10 --dim 128 --heads 4 """ import torch @@ -159,6 +158,8 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay out_features = out_features or in_features hidden_features = hidden_features or in_features + self.in_features = in_features + self.hidden_features = hidden_features // self._tp_size self.fc1 = nn.Linear(in_features, hidden_features // self._tp_size) self.act = act_layer() self.fc2 = nn.Linear(hidden_features // self._tp_size, out_features) @@ -183,6 +184,16 @@ def forward(self, x, recompute=True): x = self.forward_(x) return x + def flops(self, seqlen: int): + mlp_flops = dict( + fc1=seqlen * self.in_features * self.hidden_features, + act=8 * seqlen * self.hidden_features, + drop=seqlen * self.hidden_features, + fc2=seqlen * self.hidden_features * self.in_features, + final_drop=seqlen * self.in_features, + ) + return sum(mlp_flops.values()) + class SeqMlp(torch.nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, @@ -216,6 +227,9 @@ def forward(self, x, recompute=True): outs = AllReduceIdentity.apply(outs, self._tp_group) return outs + def flops(self, seqlen: int): + return sum([mlp.flops(seqlen) for mlp in self.mlps]) + class WindowAttention(torch.nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. @@ -326,18 +340,21 @@ def forward(self, x, mask=None, position_index=None, recompute=True): def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) # M K N - flops += N * self.dim * (3 * self.head_dim * self.num_heads) - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * (N * self.head_dim * N) - # x = (attn @ v) - flops += self.num_heads * N * N * self.head_dim - # x = self.proj(x) - flops += N * self.head_dim * self.num_heads * self.dim - return flops + def flops(self, seqlen: int): + # calculate flops for one window + # seqlen is window size * window size + attn_flops = dict( + kqv=3 * seqlen * self.dim * self.head_dim * self.num_heads, + kqv_bias= 3 * seqlen * self.head_dim * self.num_heads, + q_scale=seqlen * self.num_heads * self.head_dim, + attn_score=self.num_heads * seqlen * self.head_dim * seqlen, # q @ k + position_index=self.num_heads * seqlen * seqlen, + attn_softmax=5 * self.num_heads * seqlen * seqlen, + attn_dropout=self.num_heads * seqlen * seqlen, + attn_output=self.num_heads * seqlen * seqlen * self.head_dim, # attn @ v + out_proj=seqlen * self.num_heads * self.head_dim * self.dim # self.proj(x) + ) + return sum(attn_flops.values()) class SeqWindowAttention(torch.nn.Module): @@ -398,10 +415,10 @@ def forward(self, x, mask=None, recompute=True): outs = AllReduceIdentity.apply(outs, self._tp_group) return outs - def flops(self, N): + def flops(self, seqlen: int): flops = 0 for attn in self.attns: - flops += attn.flops(N) + flops += attn.flops(seqlen) return flops @@ -553,18 +570,22 @@ def extra_repr(self) -> str: f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): - flops = 0 H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops + num_windows = H * W / self.window_size / self.window_size + block_flops = dict( + norm1=5 * H * W * self.dim, + roll1=0, # ignore + window_partition=0, # ignore + attn=num_windows * self.attn.flops(self.window_size * self.window_size), + roll2=0, # ignore + attn_dropout=H * W * self.dim, + atnn_residual=H * W * self.dim, + norm2=5 * H * W * self.dim, + mlp=self.mlp.flops(H * W), + mlp_drop=H * W * self.dim, + mlp_residual=H * W * self.dim, + ) + return sum(block_flops.values()) class PatchMerging(PipeStage): @@ -932,7 +953,9 @@ def train(): patch_norm=True, use_checkpoint=False) nparams = sum([param.numel() for param in model.parameters()]) - print_each_rank(f'Model Params#: {nparams}') + forward_flops = model.flops() + tflops = forward_flops * 4 / (1e12) # forward + recompute-forward + backward (2x) + print_each_rank(f'Model Params#: {nparams} | TFlops: {tflops}') if args.fp16: model = model.half() model = model.cuda() From 672f154ea39b1557c7fad87c5f61a80692ef9c01 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 5 Apr 2022 07:50:34 +0000 Subject: [PATCH 0739/1892] add test scripts --- ...t-multinode-fp32.sh => test-2node-fp32.sh} | 81 ++++---- handcraft/mbart/test-4node-fp32.sh | 174 ++++++++++++++++++ 2 files changed, 215 insertions(+), 40 deletions(-) rename handcraft/mbart/{test-multinode-fp32.sh => test-2node-fp32.sh} (67%) mode change 100644 => 100755 create mode 100755 handcraft/mbart/test-4node-fp32.sh diff --git a/handcraft/mbart/test-multinode-fp32.sh b/handcraft/mbart/test-2node-fp32.sh old mode 100644 new mode 100755 similarity index 67% rename from handcraft/mbart/test-multinode-fp32.sh rename to handcraft/mbart/test-2node-fp32.sh index 4d5c0905..5accdb10 --- a/handcraft/mbart/test-multinode-fp32.sh +++ b/handcraft/mbart/test-2node-fp32.sh @@ -108,14 +108,14 @@ test_hybrid_tp_pp() if [ ${gpus} == 16 ] then - echo "testing ${gpus}-dev tp:pp=2:2 | L${layers}E${hidden}H${heads}" + echo "testing ${gpus}-dev tp:pp=8:2 | L${layers}E${hidden}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=2 \ --node_rank=${NODE_RANK} \ --master_addr="${MASTER_IP}" \ --master_port=${MASTER_PORT} \ - handcraft/mbart/mbart_hybrid.py \ + handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs ${bs} --micro-bs 1 \ --pp-size 2 --tp-size 8 \ @@ -125,39 +125,39 @@ test_hybrid_tp_pp() sleep 5 killall python - echo "testing ${gpus}-dev tp:pp=2:2 | L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size 4 --tp-size 4 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp4.txt - sleep 5 - killall python - sleep 5 - killall python - - echo "testing ${gpus}-dev tp:pp=2:2 | L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size 8 --tp-size 2 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt - sleep 5 - killall python - sleep 5 - killall python + # echo "testing ${gpus}-dev tp:pp=4:4 | L${layers}E${hidden}H${heads}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=2 \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ + # handcraft/mbart/train.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --bs ${bs} --micro-bs 1 \ + # --pp-size 4 --tp-size 4 \ + # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp4.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + # + # echo "testing ${gpus}-dev tp:pp=2:8 | L${layers}E${hidden}H${heads}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=2 \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ + # handcraft/mbart/train.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --bs ${bs} --micro-bs 1 \ + # --pp-size 8 --tp-size 2 \ + # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt + # sleep 5 + # killall python + # sleep 5 + # killall python fi } @@ -167,12 +167,13 @@ test_hybrid_tp_pp() # ================================================= # strong scalability test -test_mix_tp_1f1b 16 3072 24 16 -test_tp 16 3072 24 16 +# test_mix_tp_1f1b 16 3072 24 16 +# test_tp 16 3072 24 16 # model scaling test -test_mix_tp_1f1b 32 5120 40 16 -test_tp 32 5120 40 16 -# test_hybrid_tp_pp 32 5120 40 16 # --> OOM +test_mix_tp_1f1b 36 5120 32 16 +test_tp 36 5120 32 16 +# test_hybrid_tp_pp 40 5120 32 16 # --> OOM +# test_hybrid_tp_pp 36 5120 32 16 # --> OOM python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/test-4node-fp32.sh b/handcraft/mbart/test-4node-fp32.sh new file mode 100755 index 00000000..857170a5 --- /dev/null +++ b/handcraft/mbart/test-4node-fp32.sh @@ -0,0 +1,174 @@ +evaldir=eval/mbart-fp32-v100-32gb +mkdir -p ${evaldir} + +bs=256 + +test_mix_tp_1f1b() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev mixture-1f1b: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=4 \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule tp1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_tp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure tp: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=4 \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size 1 --tp-size ${gpus} \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_pp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure pp: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=4 \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_pp_swap() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + echo "testing ${gpus}-dev pure pp swap: L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=4 \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size ${gpus} --tp-size 1 \ + --schedule 1f1b --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp-swap.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_hybrid_tp_pp() +{ + layers=$1 + hidden=$2 + heads=$3 + gpus=$4 + + if [ ${gpus} == 32 ] + then + echo "testing ${gpus}-dev tp:pp=16:2 | L${layers}E${hidden}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=4 \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/mbart/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --bs ${bs} --micro-bs 1 \ + --pp-size 2 --tp-size 16 \ + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt + sleep 5 + killall python + sleep 5 + killall python + + # echo "testing ${gpus}-dev tp:pp=8:4 | L${layers}E${hidden}H${heads}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=4 \ + # --node_rank=${REMOTE_NODE_RANK} \ + # --master_addr="${REMOTE_MASTER_IP}" \ + # --master_port=${REMOTE_MASTER_PORT} \ + # handcraft/mbart/mbart_hybrid.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --bs ${bs} --micro-bs 1 \ + # --pp-size 4 --tp-size 8 \ + # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp4.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + # + # echo "testing ${gpus}-dev tp:pp=4:8 | L${layers}E${hidden}H${heads}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=4 \ + # --node_rank=${REMOTE_NODE_RANK} \ + # --master_addr="${REMOTE_MASTER_IP}" \ + # --master_port=${REMOTE_MASTER_PORT} \ + # handcraft/mbart/mbart_hybrid.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --bs ${bs} --micro-bs 1 \ + # --pp-size 8 --tp-size 4 \ + # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + fi +} + + +# ================================================= +# selected experiments +# ================================================= + + +# test_hybrid_tp_pp 40 8192 64 32 +test_mix_tp_1f1b 40 8192 64 32 +# test_tp 40 6144 48 32 + +python scripts/keep.py --gpus 8 From cc468da7fc39da3081130190a49d06b936dcc180 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 5 Apr 2022 08:20:10 +0000 Subject: [PATCH 0740/1892] initialize only part models --- handcraft/mbart/train.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index df9654be..10b120b3 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -565,8 +565,6 @@ def __init__(self, cfg: Config): self.first_decoder_stage = _first_decoder_stage self.cfg = cfg - encoders = [EncoderLayer(cfg) for _ in range(self.cfg.encoder_layers)] - decoders = [DecoderLayer(cfg) for _ in range(self.cfg.decoder_layers)] start, end = _layer_divisions[self.stage_local_rank] print_each_rank(f'initializing layer ranging from [{start}, {end})') @@ -590,11 +588,12 @@ def __init__(self, cfg: Config): self.embed = ShardEmbed(cfg) if self.embed is None else self.embed inputs_info = ((), ()) if inputs_info is None else inputs_info + self.encoders = [] + self.layer_norm_encoder = None if self.encoder_forward: - self.encoders = torch.nn.ModuleList( - encoders[start:min(end, cfg.encoder_layers)] - ) - self.layer_norm_encoder = None + encoders = [EncoderLayer(cfg) for _ in range(min(end, cfg.encoder_layers) - start)] + assert len(encoders) == end - start + self.encoders = torch.nn.ModuleList(encoders) if self.decoder_preprocess or self.decoder_forward: self.layer_norm_encoder = torch.nn.LayerNorm(cfg.embed_dim) @@ -602,14 +601,16 @@ def __init__(self, cfg: Config): outputs_info = self.encoders[-1].outputs_info if self.decoder_preprocess: + _encoder = EncoderLayer(cfg) self.embed = ShardEmbed(cfg) if self.embed is None else self.embed - inputs_info = encoders[-1].outputs_info if inputs_info is None else inputs_info + inputs_info = _encoder.outputs_info if inputs_info is None else inputs_info + self.decoders = [] + self.layer_norm_decoder = None if self.decoder_forward: - self.decoders = torch.nn.ModuleList( - decoders[max(0, start-cfg.encoder_layers): end-cfg.encoder_layers] - ) - self.layer_norm_decoder = None + decoders = [DecoderLayer(cfg) for _ in range(end - max(cfg.encoder_layers, start))] + assert len(decoders) == end - start, f"end: {end}, start: {start}" + self.decoders = torch.nn.ModuleList(decoders) if self.is_last_stage: self.layer_norm_decoder = torch.nn.LayerNorm(cfg.embed_dim) @@ -723,9 +724,9 @@ def decoder_forward(enc, dec): def flops(self): enc_flops = sum([enc.flops() for enc in self.encoders]) - enc_layernorm = 5 * self.cfg.max_source_positions * self.cfg.embed_dim + enc_layernorm = 5 * self.cfg.max_source_positions * self.cfg.embed_dim if self.layer_norm_decoder is None else 0 dec_flops = sum([dec.flops() for dec in self.decoders]) - dec_layernorm = 5 * self.cfg.max_target_positions * self.cfg.embed_dim + dec_layernorm = 5 * self.cfg.max_target_positions * self.cfg.embed_dim if self.layer_norm_decoder is None else 0 return enc_flops + enc_layernorm + dec_flops + dec_layernorm From 4fc08de687f78b254356ff71f13b42264ced3841 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 5 Apr 2022 11:34:36 +0000 Subject: [PATCH 0741/1892] fix pure tp --- handcraft/mbart/train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index 10b120b3..58139b0f 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -592,7 +592,6 @@ def __init__(self, cfg: Config): self.layer_norm_encoder = None if self.encoder_forward: encoders = [EncoderLayer(cfg) for _ in range(min(end, cfg.encoder_layers) - start)] - assert len(encoders) == end - start self.encoders = torch.nn.ModuleList(encoders) if self.decoder_preprocess or self.decoder_forward: self.layer_norm_encoder = torch.nn.LayerNorm(cfg.embed_dim) @@ -609,7 +608,6 @@ def __init__(self, cfg: Config): self.layer_norm_decoder = None if self.decoder_forward: decoders = [DecoderLayer(cfg) for _ in range(end - max(cfg.encoder_layers, start))] - assert len(decoders) == end - start, f"end: {end}, start: {start}" self.decoders = torch.nn.ModuleList(decoders) if self.is_last_stage: self.layer_norm_decoder = torch.nn.LayerNorm(cfg.embed_dim) @@ -814,8 +812,8 @@ def __next__(self): model = MBart(cfg) nparams = sum([param.numel() for param in model.parameters()]) forward_flops = model.flops() - flops = forward_flops * 4 # forward + re-compute forward + backward (=2 forward flops) - print_each_rank(f'model params: {nparams} | FLOPs: {flops}. Launching model...') + tflops = forward_flops * 4 / 1e12 # forward + re-compute forward + backward (=2 forward flops) + print_each_rank(f'model params: {nparams} | TFLOPs: {tflops}. Launching model...') model = model.half().cuda() if args.fp16 else model.cuda() dataloader = MBartDataLoader(args.micro_bs, cfg) From 953561f67ef20bdfe96f881b5860af14a7c0b063 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 6 Apr 2022 11:50:59 +0000 Subject: [PATCH 0742/1892] env-setup scripts --- scripts/env-setup.sh | 5 +++-- scripts/sync-itp.sh | 23 +++++++++++++++++++++++ scripts/sync-singularity.sh | 25 +++++++++++++++++++++++++ scripts/sync.sh | 4 ---- scripts/sync4.sh | 5 ----- 5 files changed, 51 insertions(+), 11 deletions(-) create mode 100755 scripts/sync-itp.sh create mode 100755 scripts/sync-singularity.sh delete mode 100755 scripts/sync.sh delete mode 100755 scripts/sync4.sh diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh index 5b7313e0..c9ecf29a 100755 --- a/scripts/env-setup.sh +++ b/scripts/env-setup.sh @@ -10,11 +10,13 @@ sudo git config --global user.name "Zhiqi Lin" sudo git config --global user.email "v-zhiql@microsoft.com" sudo chmod -R a+w /opt/conda +sudo apt-get update sudo apt-get install htop -y sudo apt-get install tmux -y sudo apt-get install psmisc -y sudo apt-get install lsof -y -sudo apt-get install infiniband-diags -y +sudo apt-get install infiniband-diags -y # ibstatus => check ib link +sudo apt-get install net-tools -y # ifconfig # install blob # sudo apt-get install lsb-release -y @@ -45,4 +47,3 @@ echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc pip uninstall training_daemon -y python setup.py develop pip install -r requirements.txt - diff --git a/scripts/sync-itp.sh b/scripts/sync-itp.sh new file mode 100755 index 00000000..588a2a99 --- /dev/null +++ b/scripts/sync-itp.sh @@ -0,0 +1,23 @@ +# ============= ITP Variables ============ +# NODE_RANK +# MASTER_IP +# MASTER_PORT +# ============= ITP Variables ============ + +node_num=$1 + +if [ ${node_num} == 4 ] +then + scp -r /workspace/MagicCube/handcraft worker-1:/workspace/MagicCube/ + scp -r /workspace/MagicCube/cube worker-1:/workspace/MagicCube/ + scp -r /workspace/MagicCube/handcraft worker-2:/workspace/MagicCube/ + scp -r /workspace/MagicCube/cube worker-2:/workspace/MagicCube/ + scp -r /workspace/MagicCube/handcraft worker-3:/workspace/MagicCube/ + scp -r /workspace/MagicCube/cube worker-3:/workspace/MagicCube/ +fi + +if [ ${node_num} == 2 ] +then + scp -r /workspace/MagicCube/handcraft worker-1:/workspace/MagicCube/ + scp -r /workspace/MagicCube/cube worker-1:/workspace/MagicCube/ +fi \ No newline at end of file diff --git a/scripts/sync-singularity.sh b/scripts/sync-singularity.sh new file mode 100755 index 00000000..f8d4cf5d --- /dev/null +++ b/scripts/sync-singularity.sh @@ -0,0 +1,25 @@ + +# ============= Singularity Variables ============ +# NODE_RANK +# MASTER_ADDR +# MASTER_PORT +# ============= Singularity Variables ============ + +node_num=$1 + +if [ ${node_num} == 4 ] +then + scp -r /workspace/MagicCube/handcraft node-1:/workspace/MagicCube/ + scp -r /workspace/MagicCube/cube node-1:/workspace/MagicCube/ + scp -r /workspace/MagicCube/handcraft node-2:/workspace/MagicCube/ + scp -r /workspace/MagicCube/cube node-2:/workspace/MagicCube/ + scp -r /workspace/MagicCube/handcraft node-3:/workspace/MagicCube/ + scp -r /workspace/MagicCube/cube node-3:/workspace/MagicCube/ +fi + +if [ ${node_num} == 2 ] +then + scp -r /workspace/MagicCube/handcraft node-1:/workspace/MagicCube/ + scp -r /workspace/MagicCube/cube node-1:/workspace/MagicCube/ +fi + diff --git a/scripts/sync.sh b/scripts/sync.sh deleted file mode 100755 index 16737ab6..00000000 --- a/scripts/sync.sh +++ /dev/null @@ -1,4 +0,0 @@ - -# usually worker-1 -worker_name=$1 -scp -r /workspace/MagicCube/examples ${worker_name}:/workspace/MagicCube/ \ No newline at end of file diff --git a/scripts/sync4.sh b/scripts/sync4.sh deleted file mode 100755 index 1edf2a53..00000000 --- a/scripts/sync4.sh +++ /dev/null @@ -1,5 +0,0 @@ - -scp -r /workspace/MagicCube/examples worker-1:/workspace/MagicCube/ -scp -r /workspace/MagicCube/examples worker-2:/workspace/MagicCube/ -scp -r /workspace/MagicCube/examples worker-3:/workspace/MagicCube/ -scp -r /workspace/MagicCube/examples worker-4:/workspace/MagicCube/ \ No newline at end of file From c902ea342201d3b7f85b6327786a157e2ec0630a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 6 Apr 2022 11:51:36 +0000 Subject: [PATCH 0743/1892] change from init directly --- handcraft/swin/train.py | 42 +++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 432f596d..4cbd6bd9 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -464,7 +464,7 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 else: coshard = num_heads // args.tp_size coshard = coshard // 2 if layer_id > 0 else coshard - print_each_rank(f'Swin-stage-{layer_id} using coshard {coshard}', rank_only=0) + print(f'rank [{torch.distributed.get_rank()}]: Swin-stage-{layer_id} using coshard {coshard}') self.attn = SeqWindowAttention( dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, coshard=coshard) @@ -653,7 +653,8 @@ def flops(self): def create_basic_layter(dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, layer_id=None): + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, + layer_id=None, start_id=0): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. @@ -675,7 +676,7 @@ def create_basic_layter(dim, input_resolution, depth, num_heads, window_size, blocks = [SwinTransformerBlock( dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, + shift_size=0 if ((i + start_id) % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, @@ -782,12 +783,24 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers + total_layers = [3, 3, depths[2] + 1, 2] + # pipeline split layers + start, end = _layer_divisions[self.stage_local_rank] layers = [] for i_layer in range(self.num_layers): + layer_start = sum(total_layers[:i_layer]) + layer_end = sum(total_layers[:i_layer+1]) + if max(layer_start, start) >= min(layer_end, end): + continue + have_downsample = start < layer_end and layer_end <= end and i_layer < self.num_layers - 1 + layer_start_id = max(layer_start, start) - layer_start + layer_num = min(layer_end, end) - max(layer_start, start) + layer_num = layer_num if not have_downsample else layer_num - 1 + assert layer_num >= 1 blocks = create_basic_layter(dim=int(embed_dim * 2 ** i_layer), input_resolution=(self.patches_resolution[0] // (2 ** i_layer), self.patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], + depth=layer_num, num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, @@ -795,14 +808,13 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - layer_id=i_layer) + downsample=PatchMerging if have_downsample else None, + layer_id=i_layer, start_id=layer_start_id) layers += blocks - - # pipeline split layers - start, end = _layer_divisions[self.stage_local_rank] - self.layers = torch.nn.ModuleList(layers[start:end]) - print_each_rank(f'initialized layers ({len(self.layers)}) ranging from [{start}, {end})') + assert (end - start) == len(layers), f"layer num not equal, [{start}, {end}) != {len(layers)} " + torch.distributed.barrier() + self.layers = torch.nn.ModuleList(layers) + print_each_rank(f'initialized {len(self.layers)} layers ranging from [{start}, {end})') self.inputs_info = self.layers[0].inputs_info self.outputs_info = self.layers[-1].outputs_info @@ -888,11 +900,13 @@ def _post_process(x): def flops(self): flops = 0 - flops += self.patch_embed.flops() + if self.is_first_stage: + flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes + if self.is_last_stage: + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes return flops From 1648f717e3b357e159cfb13a8745113cd609ab21 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 6 Apr 2022 12:07:49 +0000 Subject: [PATCH 0744/1892] metric --- handcraft/swin/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 4cbd6bd9..a78dad7a 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -110,7 +110,7 @@ ([241.4] * 2 + [0]) + \ ([145.7] * args.layers + [0]) + \ ([108.9] * 2) - elif args.dim == 1024: # TP needed + elif args.dim >= 1024: # TP needed times = ([255.10] * 2 + [0]) + \ ([139.92] * 2 + [0]) + \ ([90.98] * args.layers + [0]) + \ From 253b41908f4003184923093421b62fe8c60236b2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 6 Apr 2022 14:17:07 +0000 Subject: [PATCH 0745/1892] add 4node training script --- handcraft/mbart/test-2node-fp32.sh | 3 +- handcraft/mbart/test-4node-fp32.sh | 44 ++- handcraft/module/stage.py | 10 +- .../{test-multi-node.sh => test-2node.sh} | 0 handcraft/swin/test-4node.sh | 288 ++++++++++++++++++ handcraft/swin/test.sh | 13 + handcraft/swin/train.py | 16 +- 7 files changed, 363 insertions(+), 11 deletions(-) rename handcraft/swin/{test-multi-node.sh => test-2node.sh} (100%) create mode 100755 handcraft/swin/test-4node.sh diff --git a/handcraft/mbart/test-2node-fp32.sh b/handcraft/mbart/test-2node-fp32.sh index 5accdb10..9c01e0d9 100755 --- a/handcraft/mbart/test-2node-fp32.sh +++ b/handcraft/mbart/test-2node-fp32.sh @@ -173,7 +173,8 @@ test_hybrid_tp_pp() # model scaling test test_mix_tp_1f1b 36 5120 32 16 test_tp 36 5120 32 16 -# test_hybrid_tp_pp 40 5120 32 16 # --> OOM # test_hybrid_tp_pp 36 5120 32 16 # --> OOM +# test_mix_tp_1f1b 40 5120 40 16 + python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/test-4node-fp32.sh b/handcraft/mbart/test-4node-fp32.sh index 857170a5..e4a56303 100755 --- a/handcraft/mbart/test-4node-fp32.sh +++ b/handcraft/mbart/test-4node-fp32.sh @@ -42,7 +42,7 @@ test_tp() --master_port=${REMOTE_MASTER_PORT} \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ + --bs 8 --micro-bs 1 \ --pp-size 1 --tp-size ${gpus} \ --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt sleep 5 @@ -119,7 +119,7 @@ test_hybrid_tp_pp() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs ${bs} --micro-bs 1 \ --pp-size 2 --tp-size 16 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt + --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp16pp2.txt sleep 5 killall python sleep 5 @@ -136,7 +136,7 @@ test_hybrid_tp_pp() # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --bs ${bs} --micro-bs 1 \ # --pp-size 4 --tp-size 8 \ - # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp4.txt + # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp4.txt # sleep 5 # killall python # sleep 5 @@ -153,7 +153,7 @@ test_hybrid_tp_pp() # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --bs ${bs} --micro-bs 1 \ # --pp-size 8 --tp-size 4 \ - # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt + # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp8.txt # sleep 5 # killall python # sleep 5 @@ -167,8 +167,38 @@ test_hybrid_tp_pp() # ================================================= -# test_hybrid_tp_pp 40 8192 64 32 -test_mix_tp_1f1b 40 8192 64 32 -# test_tp 40 6144 48 32 +test_mix_tp_1f1b 48 6144 32 32 +test_tp 48 6144 32 32 +test_hybrid_tp_pp 48 6144 32 32 python scripts/keep.py --gpus 8 + +# OOM: --layers 64 --hidden-size 6144 --heads 32 +# OOM: --layers 52 --hidden-size 6144 --heads 32 -- 29.64GB +# SUC: --layers 48 --hidden-size 6144 --heads 32 -- 29.64GB +# SUC: --layers 48 --hidden-size 5120 --heads 32 + +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=8 \ +# --nnodes=4 \ +# --node_rank=${REMOTE_NODE_RANK} \ +# --master_addr="${REMOTE_MASTER_IP}" \ +# --master_port=${REMOTE_MASTER_PORT} \ +# handcraft/mbart/train.py \ +# --layers 48 --hidden-size 6144 --heads 32 \ +# --bs 32 --micro-bs 1 \ +# --pp-size 32 --tp-size 1 \ +# --schedule tp1f1b +# +# +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=8 \ +# --nnodes=4 \ +# --node_rank=${REMOTE_NODE_RANK} \ +# --master_addr="${REMOTE_MASTER_IP}" \ +# --master_port=${REMOTE_MASTER_PORT} \ +# handcraft/mbart/train.py \ +# --layers 52 --hidden-size 6144 --heads 32 \ +# --bs 4 --micro-bs 1 \ +# --pp-size 1 --tp-size 32 \ +# --schedule 1f1b \ No newline at end of file diff --git a/handcraft/module/stage.py b/handcraft/module/stage.py index 39359167..9341a086 100644 --- a/handcraft/module/stage.py +++ b/handcraft/module/stage.py @@ -132,7 +132,7 @@ def data(self, datas: Tuple): self._data = datas -def layer_division(times: List[int], num_stages: int, start_id: int = 0): +def layer_division(times: List[int], num_stages: int, start_id: int = 0, limits: List[int] = None): """ Computation balance division """ @@ -140,10 +140,16 @@ def layer_division(times: List[int], num_stages: int, start_id: int = 0): budget = sum(times) / num_stages nlayers = len(times) start, end = 0, 1 + if limits is None: + limits = [None] * num_stages + else: + assert len(limits) == num_stages for idx in range(num_stages): accum = times[start] assert end <= nlayers while end != nlayers: + if limits[idx] is not None and (end - start) == limits[idx]: + break if times[end] > 0 and budget - accum < 0.5 * times[end]: break accum += times[end] @@ -151,6 +157,8 @@ def layer_division(times: List[int], num_stages: int, start_id: int = 0): if idx == num_stages - 1: end = nlayers divisions.append((start, end)) + if idx != num_stages - 1: + budget = sum(times[end:]) / (num_stages - 1 - idx) start, end = end, end+1 for sid in range(num_stages): start, end = divisions[sid] diff --git a/handcraft/swin/test-multi-node.sh b/handcraft/swin/test-2node.sh similarity index 100% rename from handcraft/swin/test-multi-node.sh rename to handcraft/swin/test-2node.sh diff --git a/handcraft/swin/test-4node.sh b/handcraft/swin/test-4node.sh new file mode 100755 index 00000000..7d83cd2f --- /dev/null +++ b/handcraft/swin/test-4node.sh @@ -0,0 +1,288 @@ +# swin transformer constant head dim == 32 + +evaldir=eval/swin-coshard +mkdir -p ${evaldir} + + +img_size=1536 +window_size=48 +bs=256 + + +test_naive_pp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} + + echo "testing ${gpus}-dev: Pure PP${coshard}: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-pp${gpus}.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_naive_tp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} + + echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 1 --tp-size ${gpus} --dp-size 1 \ + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_naive_hybrid_tp_pp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} + + # Hybrid TP-1F1B -- 16 GPU + if [ ${gpus} == 32 ] + then + echo "testing ${gpus}-dev: TP16-PP2: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 2 --tp-size 16 --dp-size 1 \ + --bs ${bs} --micro-bs 1 \ + --fp16 > ${evaldir}/${gpus}dev-${arch}-tp8pp2.txt + sleep 5 + killall python + sleep 5 + killall python + + # echo "testing ${gpus}-dev: TP8-PP4: L${layers}E${dim}H${heads}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=${nodes} \ + # --node_rank=${REMOTE_NODE_RANK} \ + # --master_addr="${REMOTE_MASTER_IP}" \ + # --master_port=${REMOTE_MASTER_PORT} \ + # handcraft/swin/train.py \ + # --layers ${layers} --dim ${dim} --heads ${heads} \ + # --img-size ${img_size} --window-size ${window_size} \ + # --pp-size 4 --tp-size 8 --dp-size 1 \ + # --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp4.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + fi +} + +test_coshard_pp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} + + echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size ${gpus} --tp-size 1 --dp-size 1 \ + --bs ${bs} --micro-bs 1 --use-coshard + --fp16 > ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_coshard_hybrid_tp_pp() +{ + layers=$1 + dim=$2 + heads=$3 + nodes=$4 + gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} + + # Hybrid TP-1F1B -- 8 GPU + if [ ${gpus} == 32 ] + then + # echo "testing ${gpus}-dev: TP16-PP2: L${layers}E${dim}H${heads}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=${nodes} \ + # --node_rank=${REMOTE_NODE_RANK} \ + # --master_addr="${REMOTE_MASTER_IP}" \ + # --master_port=${REMOTE_MASTER_PORT} \ + # handcraft/swin/train.py \ + # --layers ${layers} --dim ${dim} --heads ${heads} \ + # --img-size ${img_size} --window-size ${window_size} \ + # --pp-size 2 --tp-size 16 --dp-size 1 \ + # --bs 64 --micro-bs 1 --use-coshard --use-inner-coshard \ + # --fp16 > ${evaldir}/${gpus}dev-${arch}-tp16pp2-coshard.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + + echo "testing coshard ${gpus}-dev: TP8-PP4: L${layers}E${dim}H${heads}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=${nodes} \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ + handcraft/swin/train.py \ + --layers ${layers} --dim ${dim} --heads ${heads} \ + --img-size ${img_size} --window-size ${window_size} \ + --pp-size 4 --tp-size 8 --dp-size 1 \ + --bs ${bs} --micro-bs 1 --use-coshard --use-inner-coshard \ + --fp16 > ${evaldir}/${gpus}dev-${arch}-tp8pp4-coshard.txt + sleep 5 + killall python + sleep 5 + killall python + + # echo "testing coshard ${gpus}-dev: TP4-PP8: L${layers}E${dim}H${heads}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=${nodes} \ + # --node_rank=${REMOTE_NODE_RANK} \ + # --master_addr="${REMOTE_MASTER_IP}" \ + # --master_port=${REMOTE_MASTER_PORT} \ + # handcraft/swin/train.py \ + # --layers ${layers} --dim ${dim} --heads ${heads} \ + # --img-size ${img_size} --window-size ${window_size} \ + # --pp-size 8 --tp-size 4 --dp-size 1 \ + # --bs ${bs} --micro-bs 1 --use-coshard --use-inner-coshard \ + # --fp16 > ${evaldir}/${gpus}dev-${arch}-tp8pp4-coshard.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + fi +} + +test_all() +{ + layers=$1 + dim=$2 + heads=$3 + gpus=$4 + test_naive_pp $layers $dim $heads $gpus + test_naive_tp $layers $dim $heads $gpus + test_naive_hybrid_tp_pp $layers $dim $heads $gpus + test_coshard_pp $layers $dim $heads $gpus +} + + +# ================================================= +# selected experiments +# ================================================= + +test_naive_tp 58 1536 32 4 32 +test_coshard_hybrid_tp_pp 58 1536 32 4 32 +# test_naive_hybrid_tp_pp 58 1536 32 4 32 # -> OOM + +python scripts/keep.py --gpus 8 + + + +# ============ exp +# Fail: 50 1280 32 | COSHARD-TP: TP4PP8 Fail TP: ? Hybrid-TP: ? +# TEST: 50 1536 32 | COSHARD-TP: ? TP4PP8 ALL Fail TP8PP4 SUC TP: ? Hybrid-TP: ? +# TEST: 58 1536 32 | COSHARD-TP: ? TP4PP8 ? Fail TP8PP4 SUC TP: ? Hybrid-TP: ? +# FAIL: 50 1536 64 | COSHARD-TP: ? TP8PP4 Fail TP: ? Hybrid-TP: ? +# FAIL: 50 2048 64 | COSHARD-TP: ? TP8PP4 Fail TP: ? Hybrid-TP: ? + + +# coshard +# layers=58 +# dim=1536 +# heads=32 +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=8 \ +# --nnodes=4 \ +# --node_rank=${REMOTE_NODE_RANK} \ +# --master_addr="${REMOTE_MASTER_IP}" \ +# --master_port=${REMOTE_MASTER_PORT} \ +# handcraft/swin/train.py \ +# --layers ${layers} --dim ${dim} --heads ${heads} \ +# --img-size 1536 --window-size 48 \ +# --pp-size 8 --tp-size 4 --dp-size 1 \ +# --bs 8 --micro-bs 1 --use-coshard --use-inner-coshard \ +# --fp16 +# +# # hybrid tp +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=8 \ +# --nnodes=4 \ +# --node_rank=${REMOTE_NODE_RANK} \ +# --master_addr="${REMOTE_MASTER_IP}" \ +# --master_port=${REMOTE_MASTER_PORT} \ +# handcraft/swin/train.py \ +# --layers ${layers} --dim ${dim} --heads ${heads} \ +# --img-size 1536 --window-size 48 \ +# --pp-size 2 --tp-size 16 --dp-size 1 \ +# --bs 4 --micro-bs 1 \ +# --fp16 +# +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=8 \ +# --nnodes=4 \ +# --node_rank=${REMOTE_NODE_RANK} \ +# --master_addr="${REMOTE_MASTER_IP}" \ +# --master_port=${REMOTE_MASTER_PORT} \ +# handcraft/swin/train.py \ +# --layers ${layers} --dim ${dim} --heads ${heads} \ +# --img-size 1536 --window-size 48 \ +# --pp-size 1 --tp-size 32 --dp-size 1 \ +# --bs 2 --micro-bs 1 --fp16 +# +# clear +# killall python \ No newline at end of file diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index a4bdcf12..8a5f344f 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -254,3 +254,16 @@ test_naive_tp 42 768 24 8 # test_naive_hybrid_tp_pp 42 768 24 8 # --> OOM python scripts/keep.py --gpus 8 + + +# for test +# coshard-pp +# gpus=4 +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=${gpus} \ +# --nnodes=1 \ +# handcraft/swin/train.py \ +# --layers 26 --dim 512 --heads 16 \ +# --img-size 1536 --window-size 48 \ +# --pp-size ${gpus} --tp-size 1 --dp-size 1 \ +# --bs 16 --micro-bs 1 --use-coshard --fp16 \ No newline at end of file diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index a78dad7a..f47e00b5 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -125,12 +125,24 @@ _layer_divisions = layer_division(times, num_stages) # specific rules for stage division in order to fit in memory if args.dim == 1024 and args.tp_size == 4: + # first stage if _layer_divisions[0][1] > 8: remain_times = times[8:] _layer_divisions = [(0, 8)] + layer_division(remain_times, num_stages-1, start_id=8) + if args.dim == 1536 and args.tp_size == 8: + limits = [None] * args.pp_size + limits[0] = 4 + _layer_divisions = layer_division(times, num_stages, limits=limits) + # if args.dim == 1536 and args.tp_size == 4: + # limits = [None] * args.pp_size + # limits[0] = 1 + # limits[1] = 3 + # limits[2] = 9 + # _layer_divisions = layer_division(times, num_stages, limits=limits) + else: _layer_divisions = [(0, 2 + 2 + args.layers + 2 + 3)] -print_each_rank(f'layer divisions: {_layer_divisions}', rank_only=0) +print_each_rank(f'layer divisions: {_layer_divisions}') class Config: @@ -796,7 +808,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, layer_start_id = max(layer_start, start) - layer_start layer_num = min(layer_end, end) - max(layer_start, start) layer_num = layer_num if not have_downsample else layer_num - 1 - assert layer_num >= 1 + assert layer_num >= 0 blocks = create_basic_layter(dim=int(embed_dim * 2 ** i_layer), input_resolution=(self.patches_resolution[0] // (2 ** i_layer), self.patches_resolution[1] // (2 ** i_layer)), From 4fe7f248a81d49e13b41bf2d749b24854c87fceb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 03:32:24 +0000 Subject: [PATCH 0746/1892] tp use less bs --- handcraft/swin/test-4node.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handcraft/swin/test-4node.sh b/handcraft/swin/test-4node.sh index 7d83cd2f..8770122c 100755 --- a/handcraft/swin/test-4node.sh +++ b/handcraft/swin/test-4node.sh @@ -56,7 +56,7 @@ test_naive_tp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt + --bs 16 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt sleep 5 killall python sleep 5 @@ -285,4 +285,4 @@ python scripts/keep.py --gpus 8 # --bs 2 --micro-bs 1 --fp16 # # clear -# killall python \ No newline at end of file +# killall python From f6088e30aaec0487dca43613bb6fdd8f3b68d680 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 09:18:33 +0000 Subject: [PATCH 0747/1892] add gpt model --- handcraft/gpt3/train.py | 519 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 519 insertions(+) create mode 100644 handcraft/gpt3/train.py diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py new file mode 100644 index 00000000..7f15bdab --- /dev/null +++ b/handcraft/gpt3/train.py @@ -0,0 +1,519 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers 4 --hidden-size 1024 --heads 16 \ + --bs 1 --micro-bs 1 --schedule 1f1b +""" + +import torch +import cube +import math +import numpy as np + +from cube.runtime.device import DeviceGroup +from cube.runtime.adapter.reducer import Reducer +from cube.runtime.adapter.distnn import AllReduceIdentity, IdentityAllreduce, AllGatherSplit + +from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary +from cube.profiler.timer import print_each_rank + +from handcraft.module.schedule import schedule_1f1b +from handcraft.module.stage import PipeStage + +import argparse + +torch.manual_seed(0) +np.random.seed(0) + +parser = argparse.ArgumentParser(description='gpt3') +# model arch +parser.add_argument('--layers', type=int, default=12, + help='number encoder/decoder of layers') +parser.add_argument('--hidden-size', type=int, default=1024, + help='hidden size') +parser.add_argument('--heads', type=int, default=16, + help='number of heads') +# training config +parser.add_argument('--bs', type=int, default=256, + help='num of micro batch') +parser.add_argument('--micro-bs', type=int, default=1, + help='micro batch size') +parser.add_argument('--fp16', action='store_true', default=False) +# parallelism +parser.add_argument('--pp-size', type=int, default=1, + help='pipeline parallelism size') +parser.add_argument('--tp-size', type=int, default=1, + help='tensor parallelism size') +parser.add_argument('--dp-size', type=int, default=1, + help='data parallelism size') +parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b'], + help='scheduling algorithm') +args = parser.parse_args() +print(args) + +_tp_group = -1 + +_dp_group = -1 +_dp_reducer = None + +_pp_group = -1 +_pp_global_ranks = () +_layer_divisions = [] + +_schedule = schedule_1f1b + +_pp_embed_group = -1 +cube.init() +dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( + [args.dp_size, args.pp_size, args.tp_size] +) + +if len(dp_ranks) != 1: + print_each_rank(f'initializing dp ranks: {dp_ranks}') + _dp_group = DeviceGroup().get_group(dp_ranks) + _dp_reducer = Reducer(dp_ranks) + +if len(tp_ranks) != 1: + print_each_rank(f'initializing tp ranks: {tp_ranks}') + _tp_group = DeviceGroup().get_group(tp_ranks) + assert args.heads % args.tp_size == 0, "cannot be divided by tp-size" + +if len(pp_ranks) != 1: + print_each_rank(f'initializing pp ranks: {pp_ranks}') + _pp_group = DeviceGroup().get_group(pp_ranks) + _pp_global_ranks = tuple(pp_ranks) + _layer_divisions = layer_division([1] * args.layers, args.pp_size) +else: + _layer_divisions = [(0, args.layers)] + +if args.schedule == '1f1b' and args.pp_size > 1: + grid = np.arange(args.dp_size, args.pp_size * args.tp_size).reshape( + (args.dp_size, args.pp_size, args.tp_size)) + for dp_rank in range(args.dp_size): + embed_ranks = np.vstack((grid[dp_rank, 0, :], grid[dp_rank, -1, :])) + grank = torch.distributed.get_rank() + for gid in range(args.tp_size): + embed_rank = embed_ranks[:,gid] + embed_rank = np.squeeze(embed_rank).tolist() + print_each_rank(f'creating embed group: {embed_rank}') + group = DeviceGroup().get_group(embed_rank) + if grank in embed_rank: + print(f'rank [{grank}]: embedding group: {embed_rank}') + _pp_embed_group = group + + +class Config: + vocab_size = 50273 + seqlen = 1024 + layers = args.layers + heads = args.heads + hidden_size = args.hidden_size + +config = Config() + + +class MLP(torch.nn.Module): + + def __init__(self): + super().__init__() + self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) + + self.dense_h_to_4h = torch.nn.Linear( + config.hidden_size, config.hidden_size * 4 // self.tp_size + ) + + self.dense_4h_to_h = torch.nn.Linear( + config.hidden_size * 4 // self.tp_size, config.hidden_size + ) + + def forward(self, hidden_states): + if self.tp_size > 1: + hidden_states = IdentityAllreduce.apply(hidden_states, self.tp_group) + x = self.dense_h_to_4h(hidden_states) + x = torch.nn.functional.gelu(x) + x = self.dense_4h_to_h(x) + if self.tp_size > 1: + x = AllReduceIdentity.apply(x, self.tp_group) + return x + + +class Attention(torch.nn.Module): + + def __init__(self): + super().__init__() + self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) + + self.num_heads = config.heads // self.tp_size + self.head_dim = config.hidden_size // config.heads + projection_size = self.num_heads * self.head_dim + + self.query_key_value = torch.nn.Linear( + config.hidden_size, + 3 * projection_size, + ) + self.softmax = torch.nn.Softmax(dim=-1) + self.norm_factor = math.sqrt(self.head_dim) + self.dense = torch.nn.Linear( + projection_size, config.hidden_size + ) + + def forward(self, x, mask): + if self.tp_size > 1: + x = IdentityAllreduce.apply(x, self.tp_group) + + mixed_x_layer = self.query_key_value(x) + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_heads, 3 * self.head_dim) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + query_layer, key_layer, value_layer = \ + torch.chunk(mixed_x_layer, 3, dim=-1) + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + matmul_result = torch.empty( + output_size[0]*output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=torch.cuda.current_device()) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # attention scores and attention mask [b, np, sq, sk] + if mask is not None: + attention_scores.masked_fill_(mask, -10000.0) + attention_probs = self.softmax(attention_scores) + + output_size = (value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), + output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.head_dim * self.num_heads,) + context_layer = context_layer.view(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + output = self.dense(context_layer) + if self.tp_size > 1: + output = AllReduceIdentity.apply(output, self.tp_group) + return output + + +class Embedding(torch.nn.Module): + + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__() + self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) + self.tp_id = 0 if self.tp_group == -1 else torch.distributed.get_rank(self.tp_group) + + self.vocab_start_index = num_embeddings // self.tp_size * self.tp_id + self.vocab_end_index = num_embeddings // self.tp_size * (self.tp_id + 1) + self.weight = torch.nn.Parameter( + torch.ones((num_embeddings // self.tp_size, embedding_dim)) + ) + + def forward(self, tokens): + """ + Embedding lookup + if dst is None, use all + """ + if self.tp_size > 1: + mask = (tokens < self.vocab_start_index) | \ + (tokens >= self.vocab_end_index) + tokens = tokens.clone() - self.vocab_start_index + tokens[mask] = 0 + embed = torch.nn.functional.embedding(tokens, self.weight) + embed[mask, :] = 0.0 + embed = AllReduceIdentity.apply(embed, self.tp_group) + else: + embed = torch.nn.functional.embedding(tokens, self.weight) + return embed + + +class TransformerLayer(PipeStage): + + def __init__(self): + super().__init__() + self.input_layernorm = torch.nn.LayerNorm(config.hidden_size) + self.self_attention = Attention() + self.hidden_dropout = 0.0 + self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size) + self.mlp = MLP() + + # sq, b, h + self.inputs_info = ( + ((config.seqlen, args.micro_bs, config.hidden_size),), + (torch.float16 if args.fp16 else torch.float32) + ) + self.outputs_info = ( + ((config.seqlen, args.micro_bs, config.hidden_size),), + (torch.float16 if args.fp16 else torch.float32) + ) + + + def forward(self, hidden_states, attention_mask): + + layernrom_output = self.input_layernorm(hidden_states) + + attention_output = self.self_attention(layernrom_output, attention_mask) + + residual = hidden_states + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = layernorm_input + residual + layernrom_output = self.post_attention_layernorm(layernorm_input) + + mlp_output = self.mlp(layernrom_output) + + residual = layernorm_input + output = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + output = layernorm_input + residual + return output + + +class Pooler(torch.nn.Module): + + def __init__(self): + super().__init__() + self.dense = troch.nn.Linear(config.hidden_size, config.hidden_size) + + def forward(self, hidden_states, sequence_index=0): + pooled = hidden_states[:, sequence_index, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled + + +class GPT3(PipeStage): + + def __init__(self): + super().__init__() + self.set_pipeline(pp_ranks) + self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) + + self.word_embeddings = None + if self.is_first_stage: + self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = torch.nn.Embedding( + config.seqlen, config.hidden_size + ) + self.embedding_dropout = torch.nn.Dropout(0.0) + + start, end = _layer_divisions[self.stage_local_rank] + layers = [TransformerLayer() for _ in range(end - start)] + self.layers = torch.nn.ModuleList(layers) + + if self.is_last_stage: + self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) if self.word_embeddings is None else self.word_embeddings + self.final_layernorm = torch.nn.LayerNorm(config.hidden_size) + + def forward(self, hidden_states = None): + # data + # input_ids, position_ids, atten_mask, loss_mask + + # preprocess + if self.is_first_stage: + input_ids, position_ids, _, _ = self.data + word_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = word_embeddings + position_embeddings + embeddings = self.embedding_dropout(embeddings) + hidden_states = embeddings + hidden_states = hidden_states.transpose(0, 1).contiguous() + + + assert hidden_states is not None + _, _, attention_mask, _ = self.data + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + outputs = hidden_states + + # postprocess + if self.is_last_stage: + labels, _, _, loss_mask = self.data + + hidden_states = hidden_states.transpose(0, 1).contiguous() + hidden_states = self.final_layernorm(hidden_states) + + if self.tp_size > 1: + hidden_states = IdentityAllreduce.apply(hidden_states, self.tp_group) + logits = torch.nn.functional.linear(hidden_states, self.word_embeddings.weight) + if self.tp_size > 1: + logits = AllGatherSplit.apply(logits, -1, self.tp_group) + + # minor changes from + # https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/pretrain_gpt.py#L75 + logits = logits.float() + logits = logits.view(config.seqlen, -1) + labels = labels.view(-1) + loss = torch.nn.functional.cross_entropy(logits, labels) + outputs = loss + + return outputs + + +class GPT3DataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + self.bs = batch_size + super().__init__( + shapes=( + [batch_size, config.seqlen,], + [batch_size, config.seqlen,], + [batch_size, config.seqlen,], + [batch_size, config.seqlen,], + ), + dtypes=( + torch.int64, + torch.int64, + torch.float16 if args.fp16 else torch.float, + torch.float16 if args.fp16 else torch.float, + ), + batch_dims=(0, 0, 0, 0) + ) + self.samples = [self.random_sample()] + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + def random_sample(self): + input_ids = torch.randint( + 0, 25000, + size=(self.bs, config.seqlen,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + attention_mask, loss_mask, position_ids = self.get_ltor_masks_and_position_ids(input_ids) + return (input_ids, position_ids, attention_mask, loss_mask) + + def get_ltor_masks_and_position_ids(self, input_ids): + """ + Build masks and position id for left to right model. + https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/utils.py#L81 + """ + # Extract batch size and sequence length. + seq_length = config.seqlen + # Attention mask (lower triangular). + mask_dtype = torch.float16 if args.fp16 else torch.float32 + attention_mask = torch.tril( + torch.ones((args.micro_bs, seq_length, seq_length), dtype=mask_dtype, device=torch.cuda.current_device()) + ).view(args.micro_bs, 1, seq_length, seq_length) + + # Loss mask. + loss_mask = torch.ones(input_ids.size(), device=input_ids.device) + eod_token = 2 + loss_mask[input_ids == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, + device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + return attention_mask, loss_mask, position_ids + + +if __name__ == '__main__': + + model = GPT3() + nparams = sum([param.numel() for param in model.parameters()]) + # forward_flops = model.flops() + tflops = 0 # forward_flops * 4 / 1e12 # forward + re-compute forward + backward (=2 forward flops) + print_each_rank(f'model params (M): {nparams / 1e6} | TFLOPs: {tflops}. Launching model...') + model = model.half().cuda() if args.fp16 else model.cuda() + + dataloader = GPT3DataLoader(args.micro_bs) + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + print_each_rank('model weight consumpition:') + memory_summary() + + CudaTimer(enable=False) + torch.distributed.barrier() + iter_num = 6 + for step in range(iter_num): + if step >= 2: + CudaTimer(enable=True).start('e2e') + + # train 1 step + num_microbatch = args.bs // args.micro_bs + if args.pp_size > 1: + _schedule(model, dataloader, num_microbatch) + reduce_embed(model, _pp_embed_group) + else: + for _ in range(num_microbatch): + model.data = next(dataloader) + loss = model() + loss.backward() + + if _dp_reducer is not None: + _dp_reducer.allreduce() + + optimizer.step() + optimizer.zero_grad() + + if step >= 2: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('memory after optimizer:', rank_only=0) + memory_summary() + + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-2, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-2) + memory_summary() From 9500de2199430eaac5e2f045840469ba9b39131f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 09:20:02 +0000 Subject: [PATCH 0748/1892] remove legacy code --- handcraft/megatron/gpt.py | 154 -------------------- handcraft/megatron/layers.py | 190 ------------------------ handcraft/megatron/linears.py | 132 ----------------- handcraft/megatron/megatron_gpt_2.sh | 53 ------- handcraft/megatron/transformer.py | 207 --------------------------- 5 files changed, 736 deletions(-) delete mode 100644 handcraft/megatron/gpt.py delete mode 100644 handcraft/megatron/layers.py delete mode 100644 handcraft/megatron/linears.py delete mode 100755 handcraft/megatron/megatron_gpt_2.sh delete mode 100644 handcraft/megatron/transformer.py diff --git a/handcraft/megatron/gpt.py b/handcraft/megatron/gpt.py deleted file mode 100644 index 716d6ff9..00000000 --- a/handcraft/megatron/gpt.py +++ /dev/null @@ -1,154 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - handcraft/megatron/gpt.py -""" - -import torch -import torch.nn.functional as F -import cube -from handcraft.megatron.layers import ColumnOutputAdapter, ShardEmbedding -from handcraft.megatron.transformer import TransformerLayer - - -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - - -class GPT(torch.nn.Module): - - def __init__(self, hidden_size, vocab_size, seqlen_size, - bs, seqlen, num_head, num_layers: int): - super().__init__() - - self.num_layers = num_layers - self.bs = bs - self.seqlen = seqlen - self.ntoken = 1.0 / self.bs * self.seqlen - - # embeddings - - self.vocab_size = vocab_size - self.vocab_embedding = ShardEmbedding(self.vocab_size, hidden_size) - self.seqlen_size = seqlen_size - self.pos_embed_weight = torch.nn.Parameter( - torch.empty(seqlen_size, hidden_size) - ) - - self.embed_dropout = torch.nn.Dropout(0.5) - - self.transform1 = TransformerLayer(seqlen, hidden_size, num_head, 0.5) - self.transform2 = TransformerLayer(seqlen, hidden_size, num_head, 0.5) - self.transform3 = TransformerLayer(seqlen, hidden_size, num_head, 0.5) - self.transform4 = TransformerLayer(seqlen, hidden_size, num_head, 0.5) - - # final linear - self.final_layernorm = torch.nn.LayerNorm( - hidden_size, 1e-5 - ) - - def forward(self, input_ids, position_ids): - """ - input_ids: - [bs, seqlen] - position_ids: - [bs, seqlen] - """ - - # preprocess: embedding - # [bs, seqlen] -> [bs, seqlen, hidden size] - words_embeddings = self.vocab_embedding(input_ids) - - # [bs, seqlen] -> [bs, seqlen, hidden size] - position_embeddings = cube.runtime.function.embedding( - position_ids, self.pos_embed_weight, 0, self.seqlen_size - ) - embeddings = words_embeddings + position_embeddings - encoder_input = self.embed_dropout(embeddings) - - # [bs, seqlen, hidden size] -> [seqlen, bs, hidden size] - hidden_states = encoder_input.transpose(0, 1) - - hidden_states = self.transform1(hidden_states) - hidden_states = self.transform2(hidden_states) - hidden_states = self.transform3(hidden_states) - hidden_states = self.transform4(hidden_states) - - hidden_states = self.final_layernorm(hidden_states) - - # post process - hidden_states = hidden_states.transpose(0, 1) # .contiguous() - logits = F.linear(hidden_states, self.vocab_embedding.weight) - # all gather - logits = ColumnOutputAdapter.apply(logits) - - # loss # for verification, the mask is ommitted - # [bs, seqlen, self.vocab_size] -> [1] - loss = torch.sum(logits) - # loss = loss * self.ntoken - return loss - -def train(): - L = 512 # seq len - N = 8 # batch size - # configs: [hidden size, num_head] - # E, num_head = [2304, 24, 24] # 1.7B model - E, num_head, layers = [3072, 32, 30] # 3.6B model - # E, num_head, layers = [4096, 32, 36] # 7.5B model - - print_each_rank('config: L={}, N={}, E={}, num-head={}'.format( - L, N, E, num_head - )) - - - model = GPT( - hidden_size=E, vocab_size=50304, seqlen_size=L, - bs=N, seqlen=L, num_head=num_head, num_layers=layers - ).cuda() - - dataloader = cube.runtime.syndata.SynTextDataLoader(1280, [0, 0], [N, L], [N, L]) - - def train_iter(model, dataloader): - input_ids, position_ids = next(dataloader) - torch.distributed.broadcast(input_ids, 0) - torch.distributed.broadcast(position_ids, 0) - torch.cuda.synchronize() - loss = model(input_ids, position_ids) - loss.backward() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - memory_summary() - - -if __name__ == '__main__': - - cube.init() - train() \ No newline at end of file diff --git a/handcraft/megatron/layers.py b/handcraft/megatron/layers.py deleted file mode 100644 index 3ec1efd2..00000000 --- a/handcraft/megatron/layers.py +++ /dev/null @@ -1,190 +0,0 @@ -import torch -import torch.nn.functional as F -from torch.nn.parameter import Parameter - - -def _reduce(input_): - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size() - if world_size == 1: - return input_ - torch.distributed.all_reduce(input_, group=None) - return input_ - - -def _split(input_): - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - # Bypass the function if we are using only 1 GPU. - if world_size==1: - return input_ - last_dim = input_.dim() - 1 - last_dim_size = input_.size()[last_dim] // world_size - tensor_list = torch.split(input_, last_dim_size, dim=last_dim) - output = tensor_list[rank].contiguous() - return output - - -def _gather(input_): - """Gather tensors and concatinate along the last dimension.""" - - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - # Bypass the function if we are using only 1 GPU. - if world_size==1: - return input_ - # Size and dimension. - last_dim = input_.dim() - 1 - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=None) - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() - return output - - -class ColumnInputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return input_ - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output) - - -class ColumnOutputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _gather(input_) - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output) - - -class RowInputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _split(input_) - - @staticmethod - def backward(ctx, grad_outputs): - return _gather(grad_outputs) - - -class RowOutputAdapter(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _reduce(input_) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class ColumnParallelLinear(torch.nn.Module): - - def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False): - super().__init__() - self.input_size = input_size - self.output_size = output_size - self.full_input = full_input - self.full_output = full_output - - world_size = torch.distributed.get_world_size() - self.weight = Parameter(torch.empty( - int(self.output_size // world_size), - self.input_size, - )) - if bias: - self.bias = Parameter(torch.empty( - int(self.output_size // world_size), - )) - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def forward(self, input_): - bias = self.bias - if not self.full_input: - raise RuntimeError("Expected full tensor input") - input_parallel = ColumnInputAdapter.apply(input_) - output_parallel = F.linear(input_parallel, self.weight, bias) - if self.full_output: - output = ColumnOutputAdapter.apply(output_parallel) - else: - output = output_parallel - return output - - -class RowParallelLinear(torch.nn.Module): - - def __init__(self, input_size, output_size, bias=True, full_input=True, full_output=False): - super().__init__() - self.input_size = input_size - self.output_size = output_size - self.full_input = full_input - self.full_output = full_output - - world_size = torch.distributed.get_world_size() - self.weight = Parameter(torch.empty( - self.output_size, - int(self.input_size // world_size), - )) - if bias: - self.bias = Parameter(torch.empty(self.output_size)) - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def forward(self, input_): - bias = self.bias - if self.full_input: - input_parallel = RowInputAdapter.apply(input_) - else: - input_parallel = input_ - output_parallel = F.linear(input_parallel, self.weight, bias) - if self.full_output: - output = RowOutputAdapter.apply(output_parallel) - else: - output = output_parallel - return output - - -class ShardEmbedding(torch.nn.Module): - - def __init__(self, num_embeddings, embedding_dim): - super().__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - - self.shard_num = torch.distributed.get_world_size() - self.myshard = torch.distributed.get_rank() - - shard_num_embeddings = self.num_embeddings // self.shard_num - self.vocab_start_index = shard_num_embeddings * self.myshard - self.vocab_end_index = self.vocab_start_index + shard_num_embeddings - - self.weight = torch.nn.Parameter( - torch.empty(shard_num_embeddings, self.embedding_dim) - ) - - def forward(self, input_): - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - output_parallel = F.embedding( - masked_input, self.weight, - None, None, 2., False, False - ) - output = RowOutputAdapter.apply(output_parallel) - return output diff --git a/handcraft/megatron/linears.py b/handcraft/megatron/linears.py deleted file mode 100644 index 23936f11..00000000 --- a/handcraft/megatron/linears.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - handcraft/megatron/linears.py - -torchrun --standalone \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/megatron/linears.py -""" - -import argparse - -import torch -from torch import nn -from handcraft.megatron.layers import ColumnParallelLinear, RowParallelLinear - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - - -class ColumnMLP(nn.Module): - def __init__(self, dim, mult=1): - super().__init__() - self.linear1 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=True) - self.linear2 = ColumnParallelLinear(dim * mult, dim, full_input=True, full_output=True) - self.linear3 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=True) - self.linear4 = ColumnParallelLinear(dim * mult, dim, full_input=True, full_output=True) - - def forward(self, data): - output = self.linear1(data) - output = self.linear2(output) - output = self.linear3(output) - output = self.linear4(output) - loss = torch.sum(output) - return loss - - -class RowMLP(nn.Module): - def __init__(self, dim, mult=1): - super().__init__() - self.linear1 = RowParallelLinear(dim, dim * mult, full_input=True, full_output=True) - self.linear2 = RowParallelLinear(dim * mult, dim, full_input=True, full_output=True) - self.linear3 = RowParallelLinear(dim, dim * mult, full_input=True, full_output=True) - self.linear4 = RowParallelLinear(dim * mult, dim, full_input=True, full_output=True) - - def forward(self, data): - output = self.linear1(data) - output = self.linear2(output) - output = self.linear3(output) - output = self.linear4(output) - loss = torch.sum(output) - return loss - - -class HybridMLP(nn.Module): - def __init__(self, dim, mult=1): - super().__init__() - self.linear1 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=False) - self.linear2 = RowParallelLinear(dim * mult, dim, full_input=False, full_output=True) - self.linear3 = ColumnParallelLinear(dim, dim * mult, full_input=True, full_output=False) - self.linear4 = RowParallelLinear(dim * mult, dim, full_input=False, full_output=True) - - def forward(self, data): - output = self.linear1(data) - output = self.linear2(output) - output = self.linear3(output) - output = self.linear4(output) - loss = torch.sum(output) - return loss - - -def train(args): - - batch_size = 8192 - dim = 8192 - - # model = ColumnMLP(dim=dim).cuda() - # model = RowMLP(dim=dim).cuda() - model = HybridMLP(dim=dim).cuda() - - for param in model.parameters(): - torch.nn.init.uniform_(param) - - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, dim],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - def train_iter(model, dataloader): - data = next(dataloader) - # torch.distributed.broadcast(data, 0) - # torch.cuda.synchronize() - loss = model(data) - loss.backward() - - CudaTimer().warmup(seconds=1.0) - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='inspect') - parser.add_argument('--bs', type=int, default=128) - args = parser.parse_args() - - cube.init() - train(args) diff --git a/handcraft/megatron/megatron_gpt_2.sh b/handcraft/megatron/megatron_gpt_2.sh deleted file mode 100755 index 08b52245..00000000 --- a/handcraft/megatron/megatron_gpt_2.sh +++ /dev/null @@ -1,53 +0,0 @@ -#! /bin/bash - -# Runs the "345M" parameter model - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=62001 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -DATA_PATH=/mydata/LargeModel/GPT-2/webtext2/my-gpt2_text_document - -DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" - -## Optional Config ## -# --checkpoint-activations \ -# NCCL_P2P_DISABLE=1 -# --fp16 - -rm -rf /workspace/Megatron-LM/megatron/fused_kernels/build - -python -m torch.distributed.launch $DISTRIBUTED_ARGS \ - /workspace/Megatron-LM/pretrain_gpt.py \ - --checkpoint-activations \ - --tensor-model-parallel-size 1 \ - --pipeline-model-parallel-size 8 \ - --num-layers 24 \ - --hidden-size 2304 \ - --num-attention-heads 24 \ - --micro-batch-size 1 \ - --global-batch-size 64 \ - --seq-length 1024 \ - --max-position-embeddings 1024 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --data-path $DATA_PATH \ - --vocab-file /mydata/LargeModel/GPT-2/gpt2-vocab.json \ - --merge-file /mydata/LargeModel/GPT-2/gpt2-merges.txt \ - --data-impl mmap \ - --split 949,50,1 \ - --distributed-backend nccl \ - --lr 0.00015 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --clip-grad 1.0 \ - --lr-warmup-fraction .01 \ - --no-masked-softmax-fusion \ - --no-bias-dropout-fusion \ - --no-bias-gelu-fusion \ - --log-interval 10 diff --git a/handcraft/megatron/transformer.py b/handcraft/megatron/transformer.py deleted file mode 100644 index 04e16ea1..00000000 --- a/handcraft/megatron/transformer.py +++ /dev/null @@ -1,207 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - benchmark/megatron/transformer.py -""" - -import torch -from torch import nn -import torch.nn.functional as F -import cube -from handcraft.megatron.layers import ColumnParallelLinear, RowParallelLinear - - -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - - -class MultiHeadSelfAttention(nn.Module): - - def __init__(self, seq_len, embed_dim, heads, dropout): - super().__init__() - - self.seq_len = seq_len - self.embed_dim = embed_dim - self.num_head = heads - self.dim_head = embed_dim // heads - self.scale = self.dim_head ** -0.5 - - self.world_size = torch.distributed.get_world_size() - - self.toqkv = ColumnParallelLinear( - embed_dim, 3 * embed_dim, bias=False, - full_input=True, full_output=False - ) - self.out = RowParallelLinear( - embed_dim, embed_dim, bias=False, - full_input=False, full_output=True - ) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - """ - x: [L, N, E]: seq_len, batch_size, embedding dimension - output: [L, N, E] - """ - bs = x.shape[1] - # [L, N, E] -> [L, N, (3 * num_heads * dim_head)] - qkv = self.toqkv(x) - # [L, N, E] -> [L, N, (num_heads * dim_head)] x 3 - qkv = qkv.chunk(3, dim=-1) - q, k, v = qkv - q = q.contiguous() - q = q.view(self.seq_len, (bs * self.num_head // self.world_size), self.dim_head) - k = k.contiguous() - k = k.view(self.seq_len, (bs * self.num_head // self.world_size), self.dim_head) - v = v.contiguous() - v = v.view(self.seq_len, (bs * self.num_head // self.world_size), self.dim_head) - - # [L, (N * num_head), dim_head] -> [(N * num_head), L, dim_head] - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - # [(N * num_head), L, dim_head] -> [(N * num_head), L, dim_head] - q = q * self.scale - # [(N * num_head), L, dim_head] -> [(N * num_head), dim_head, L] - k = k.transpose(-2, -1) - # [(N * num_head), L, dim_head] * [(N * num_head), dim_head, L] - # -> [(N * num_head), L, L] - attn = torch.bmm(q, k) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = cube.runtime.function.tril_mask( - attn, self.num_head // self.world_size - ) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = F.softmax(attn, dim=-1) - - # [(N * num_head), L, L] -> [(N * num_head), L, L] - attn = self.dropout(attn) - # [(N * num_head), L, L] * [(N * num_head), L, dim_head] - # -> [(N * num_head), L, dim_head] - output = torch.bmm(attn, v) - - # [(N * num_head), L, dim_head] -> [L, N, num_head * dim_head] - output = cube.runtime.function.attn_view( - output, self.num_head // self.world_size - ) - - # [L, N, num_head * dim_head] * [E, embed_head * dim_head] - # -> [L, N, E] - output = self.out(output) - return output - - -class FFN(torch.nn.Module): - - def __init__(self, hidden_size: int): - super().__init__() - self.dense_h_to_4h = ColumnParallelLinear( - hidden_size, 4 * hidden_size, - full_input=True, full_output=False - ) - self.dense_4h_to_h = RowParallelLinear( - 4 * hidden_size, hidden_size, - full_input=False, full_output=True - ) - - def forward(self, hidden_states): - # [L, N, E] * [E, 4E] -> [L, N, 4E] - out = self.dense_h_to_4h(hidden_states) - # [L, N, 4E] -> [L, N, 4E] - out = F.gelu(out) - # [L, N, 4E] * [4E, E] -> [L, N, E] - out = self.dense_4h_to_h(out) - return out - - -class TransformerLayer(torch.nn.Module): - - def __init__(self, seq_len, hidden_size, head_num, dropout): - super().__init__() - # layer norm - self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - - self.attention = MultiHeadSelfAttention(seq_len, hidden_size, head_num, dropout) - self.attn_dropout = torch.nn.Dropout(dropout) - - self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - self.ffn = FFN(hidden_size) - self.ffn_dropout = torch.nn.Dropout(dropout) - - def forward(self, hidden_states): - # Attention - in_attn_norm = self.input_layernorm(hidden_states) - attn_out = self.attention(in_attn_norm) - # residual - attn_out = self.attn_dropout(attn_out) - residual = attn_out + hidden_states - # ffn - in_ffn_norm = self.ffn_layernorm(residual) - ffn_out = self.ffn(in_ffn_norm) - # residual - ffn_out = self.ffn_dropout(ffn_out) - ffn_out = ffn_out + residual - return ffn_out - - -def train(): - L = 512 # seq len - N = 32 # batch size - # configs: [hidden size, num_head] - # E, num_head = [1536, 16] # 1.2B model - # E, num_head = [1920, 20] # 2.5B model - # E, num_head = [2304, 24] # 4.2B model - E, num_head = [3072, 32] # 8.7B model - - - model = TransformerLayer( - seq_len=L, hidden_size=E, head_num=num_head, dropout=0.5 - ).cuda() - - dataloader = cube.runtime.syndata.SynDataLoader(1280, [1], [L, N, E]) - - def train_iter(model, dataloader): - data = next(dataloader) - torch.distributed.broadcast(data, 0) - torch.cuda.synchronize() - out = model(data) - loss = torch.sum(out) - loss.backward() - - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-40, field_name='e2e') - throughput = N / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - - -if __name__ == '__main__': - - cube.init() - train() From d459a00f586ae0e5904fefe042cbc999b442dad8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 09:35:08 +0000 Subject: [PATCH 0749/1892] pipeline --- handcraft/gpt3/train.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 7f15bdab..330f98fd 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -1,9 +1,10 @@ """ OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ handcraft/gpt3/train.py \ - --layers 4 --hidden-size 1024 --heads 16 \ + --layers 16 --hidden-size 1024 --heads 16 \ + --dp-size 1 --tp-size 1 --pp-size 4 \ --bs 1 --micro-bs 1 --schedule 1f1b """ @@ -21,7 +22,7 @@ from cube.profiler.timer import print_each_rank from handcraft.module.schedule import schedule_1f1b -from handcraft.module.stage import PipeStage +from handcraft.module.stage import PipeStage, layer_division import argparse @@ -66,6 +67,7 @@ _schedule = schedule_1f1b _pp_embed_group = -1 +_pp_embed_reducer = None cube.init() dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( [args.dp_size, args.pp_size, args.tp_size] @@ -88,9 +90,10 @@ _layer_divisions = layer_division([1] * args.layers, args.pp_size) else: _layer_divisions = [(0, args.layers)] +print_each_rank(f'layer divisions: {_layer_divisions}') if args.schedule == '1f1b' and args.pp_size > 1: - grid = np.arange(args.dp_size, args.pp_size * args.tp_size).reshape( + grid = np.arange(args.dp_size * args.pp_size * args.tp_size).reshape( (args.dp_size, args.pp_size, args.tp_size)) for dp_rank in range(args.dp_size): embed_ranks = np.vstack((grid[dp_rank, 0, :], grid[dp_rank, -1, :])) @@ -103,6 +106,7 @@ if grank in embed_rank: print(f'rank [{grank}]: embedding group: {embed_rank}') _pp_embed_group = group + _pp_embed_reducer = Reducer(embed_rank) class Config: @@ -291,11 +295,11 @@ def __init__(self): # sq, b, h self.inputs_info = ( ((config.seqlen, args.micro_bs, config.hidden_size),), - (torch.float16 if args.fp16 else torch.float32) + (torch.float16 if args.fp16 else torch.float32,) ) self.outputs_info = ( ((config.seqlen, args.micro_bs, config.hidden_size),), - (torch.float16 if args.fp16 else torch.float32) + (torch.float16 if args.fp16 else torch.float32,) ) @@ -339,21 +343,39 @@ def __init__(self): self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) + inputs_info = None + outputs_info = None + self.word_embeddings = None if self.is_first_stage: + print(f'rank [{torch.distributed.get_rank()}]: initializing preprocess...') self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) self.position_embeddings = torch.nn.Embedding( config.seqlen, config.hidden_size ) self.embedding_dropout = torch.nn.Dropout(0.0) + + inputs_info = ((), ()) if inputs_info is None else inputs_info start, end = _layer_divisions[self.stage_local_rank] + print_each_rank(f'initializing layers [{start}, {end})...') layers = [TransformerLayer() for _ in range(end - start)] self.layers = torch.nn.ModuleList(layers) + inputs_info = self.layers[0].inputs_info if inputs_info is None else inputs_info + outputs_info = self.layers[-1].outputs_info + if self.is_last_stage: + print(f'rank [{torch.distributed.get_rank()}]: initializing postprocess...') self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) if self.word_embeddings is None else self.word_embeddings self.final_layernorm = torch.nn.LayerNorm(config.hidden_size) + outputs_info = ((1,), (torch.float32,)) + + assert inputs_info is not None + assert outputs_info is not None + self.inputs_info = inputs_info + self.outputs_info = outputs_info + print_each_rank(f'stage: inputs: {inputs_info} | outputs: {outputs_info}') def forward(self, hidden_states = None): # data From 6e0dfd5dc774d7437ce9e7b144b443979c66a440 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 10:12:36 +0000 Subject: [PATCH 0750/1892] 1f1b enabled --- handcraft/gpt3/train.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 330f98fd..c2a97589 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -497,6 +497,11 @@ def get_ltor_masks_and_position_ids(self, input_ids): dataloader = GPT3DataLoader(args.micro_bs) optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + if _pp_embed_reducer is not None: + _pp_embed_reducer.add_param(model.word_embeddings.weight) + if _dp_reducer is not None: + for param in model.parameters(): + _dp_reducer.add_param(param) print_each_rank('model weight consumpition:') memory_summary() @@ -512,12 +517,14 @@ def get_ltor_masks_and_position_ids(self, input_ids): num_microbatch = args.bs // args.micro_bs if args.pp_size > 1: _schedule(model, dataloader, num_microbatch) - reduce_embed(model, _pp_embed_group) else: for _ in range(num_microbatch): model.data = next(dataloader) loss = model() loss.backward() + + if _pp_embed_reducer is not None: + _pp_embed_reducer.allreduce() if _dp_reducer is not None: _dp_reducer.allreduce() From 83344e3486c17b6d29e3831cd7e52bcd65c5342b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 12:53:03 +0000 Subject: [PATCH 0751/1892] coshard --- handcraft/gpt3/train.py | 183 ++++++++++++++++++++++++++++++++-------- 1 file changed, 148 insertions(+), 35 deletions(-) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index c2a97589..74233bdd 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -1,14 +1,30 @@ """ +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers 24 --hidden-size 2048 --heads 32 \ + --dp-size 1 --tp-size 1 --pp-size 1 \ + --seqlen 8192 --bs 8 --micro-bs 1 --fp16 + OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ handcraft/gpt3/train.py \ - --layers 16 --hidden-size 1024 --heads 16 \ + --layers 32 --hidden-size 4096 --heads 32 \ --dp-size 1 --tp-size 1 --pp-size 4 \ - --bs 1 --micro-bs 1 --schedule 1f1b + --seqlen 1024 --bs 8 --micro-bs 1 --fp16 + +350M: --layers 24 --hidden-size 1024 --heads 16 \ +1.3B: --layers 24 --hidden-size 2048 --heads 32 \ +2.6B: --layers 32 --hidden-size 2560 --heads 32 \ +6.7B: --layers 32 --hidden-size 4096 --heads 32 \ +15 B: --layers 48 --hidden-size 5120 --heads 32 \ +39 B: --layers 48 --hidden-size 8192 --heads 64 \ """ import torch +import torch.utils.checkpoint as checkpoint import cube import math import numpy as np @@ -37,6 +53,8 @@ help='hidden size') parser.add_argument('--heads', type=int, default=16, help='number of heads') +parser.add_argument('--seqlen', type=int, default=1024, + help='sequence length') # training config parser.add_argument('--bs', type=int, default=256, help='num of micro batch') @@ -52,6 +70,10 @@ help='data parallelism size') parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b'], help='scheduling algorithm') +parser.add_argument('--use-coshard', action='store_true', default=False) +parser.add_argument('--coshard-num', type=int, default=4, + help='if use coshard, the coshard number') + args = parser.parse_args() print(args) @@ -111,7 +133,7 @@ class Config: vocab_size = 50273 - seqlen = 1024 + seqlen = args.seqlen layers = args.layers heads = args.heads hidden_size = args.hidden_size @@ -121,20 +143,21 @@ class Config: class MLP(torch.nn.Module): - def __init__(self): + def __init__(self, hidden_dim: int = None): super().__init__() - self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_group = _tp_group self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) + hidden_dim = config.hidden_size * 4 if hidden_dim is None else hidden_dim self.dense_h_to_4h = torch.nn.Linear( - config.hidden_size, config.hidden_size * 4 // self.tp_size + config.hidden_size, hidden_dim // self.tp_size ) self.dense_4h_to_h = torch.nn.Linear( - config.hidden_size * 4 // self.tp_size, config.hidden_size + hidden_dim // self.tp_size, config.hidden_size ) - def forward(self, hidden_states): + def forward_(self, hidden_states): if self.tp_size > 1: hidden_states = IdentityAllreduce.apply(hidden_states, self.tp_group) x = self.dense_h_to_4h(hidden_states) @@ -144,15 +167,50 @@ def forward(self, hidden_states): x = AllReduceIdentity.apply(x, self.tp_group) return x + def forward(self, hidden_states, recompute=False): + if recompute: + x = checkpoint.checkpoint(self.forward_, hidden_states) + else: + x = self.forward_(hidden_states) + return x -class Attention(torch.nn.Module): + +class SeqMLP(torch.nn.Module): def __init__(self): super().__init__() - self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_group = _tp_group self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) - self.num_heads = config.heads // self.tp_size + coshard = args.coshard_num + assert (config.hidden_size * 4) % (self.tp_size * coshard) == 0 + hidden_dim = config.hidden_size * 4 // coshard + self.mlps = torch.nn.ModuleList([MLP(hidden_dim) for _ in range(coshard)]) + for mlp in self.mlps: + mlp.tp_size = 1 + + def forward(self, x, recompute=True): + if self.tp_size > 1: + x = IdentityAllreduce.apply(x, self.tp_group) + + outs = None + for mlp in self.mlps: + x_out = mlp(x, recompute=recompute) + outs = x_out if outs is None else outs + x_out + + if self.tp_size > 1: + outs = AllReduceIdentity.apply(outs, self.tp_group) + return outs + + +class Attention(torch.nn.Module): + + def __init__(self, num_heads: int = None): + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) + + self.num_heads = num_heads if num_heads is not None else config.heads // self.tp_size self.head_dim = config.hidden_size // config.heads projection_size = self.num_heads * self.head_dim @@ -166,29 +224,32 @@ def __init__(self): projection_size, config.hidden_size ) - def forward(self, x, mask): + def forward_(self, x, mask): + # x: [seqlen, bs, hidden], np: head num | hn: head dim if self.tp_size > 1: x = IdentityAllreduce.apply(x, self.tp_group) + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(x) new_tensor_shape = mixed_x_layer.size()[:-1] + \ (self.num_heads, 3 * self.head_dim) + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] query_layer, key_layer, value_layer = \ torch.chunk(mixed_x_layer, 3, dim=-1) - # [b, np, sq, sk] + # [b, np, seqlen, seqlen] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - # [sq, b, np, hn] -> [sq, b * np, hn] + # [seqlen, b, np, hn] -> [seqlen, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] + # [seqlen, b, np, hn] -> [seqlen, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) @@ -199,17 +260,17 @@ def forward(self, x, mask): dtype=query_layer.dtype, device=torch.cuda.current_device()) - # Raw attention scores. [b * np, sq, sk] + # Raw attention scores. [b * np, seqlen, seqlen] matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * np, seqlen, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, seqlen] beta=0.0, alpha=(1.0/self.norm_factor)) - # change view to [b, np, sq, sk] + # change view to [b, np, seqlen, seqlen] attention_scores = matmul_result.view(*output_size) - # attention scores and attention mask [b, np, sq, sk] + # attention scores and attention mask [b, np, seqlen, seqlen] if mask is not None: attention_scores.masked_fill_(mask, -10000.0) attention_probs = self.softmax(attention_scores) @@ -219,42 +280,79 @@ def forward(self, x, mask): query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] + # change view [seqlen, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] + # change view [b * np, seqlen, seqlen] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] + # matmul: [b * np, seqlen, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] + # change view [b, np, seqlen, hn] context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] + # [b, np, seqlen, hn] --> [seqlen, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] + # [seqlen, b, np, hn] --> [seqlen, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ (self.head_dim * self.num_heads,) context_layer = context_layer.view(*new_context_layer_shape) # ================= - # Output. [sq, b, h] + # Output. [seqlen, b, h] # ================= output = self.dense(context_layer) if self.tp_size > 1: output = AllReduceIdentity.apply(output, self.tp_group) return output + def forward(self, x, mask, recompute=False): + if recompute: + x = checkpoint.checkpoint(self.forward_, x, mask) + else: + x = self.forward_(x, mask) + return x + + +class SeqAttention(torch.nn.Module): + + def __init__(self): + super().__init__() + self.tp_group = _tp_group + self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) + + coshard = args.coshard_num + assert config.heads % (coshard * self.tp_size) == 0 + self.shard_num_heads = config.heads // coshard + self.attns = torch.nn.ModuleList( + [Attention(self.shard_num_heads) for _ in range(coshard)] + ) + for attn in self.attns: + attn._tp_size = 1 + + def forward(self, x, mask, recompute=True): + if self.tp_size > 1: + x = IdentityAllreduce.apply(x, self.tp_group) + + outs = None + for attn in self.attns: + x_out = attn(x, mask, recompute) + outs = x_out if outs is None else outs + x_out + + if self.tp_size > 1: + outs = AllReduceIdentity.apply(outs, self.tp_group) + return outs + class Embedding(torch.nn.Module): def __init__(self, num_embeddings: int, embedding_dim: int): super().__init__() - self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_group = _tp_group self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) self.tp_id = 0 if self.tp_group == -1 else torch.distributed.get_rank(self.tp_group) @@ -287,12 +385,21 @@ class TransformerLayer(PipeStage): def __init__(self): super().__init__() self.input_layernorm = torch.nn.LayerNorm(config.hidden_size) - self.self_attention = Attention() + if args.use_coshard: + print('use cosharding attention...') + self.self_attention = SeqAttention() + else: + self.self_attention = Attention() + self.hidden_dropout = 0.0 self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size) - self.mlp = MLP() + if args.use_coshard: + print('use cosharding mlp...') + self.mlp = SeqMLP() + else: + self.mlp = MLP() - # sq, b, h + # seqlen, b, h self.inputs_info = ( ((config.seqlen, args.micro_bs, config.hidden_size),), (torch.float16 if args.fp16 else torch.float32,) @@ -340,7 +447,7 @@ class GPT3(PipeStage): def __init__(self): super().__init__() self.set_pipeline(pp_ranks) - self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group + self.tp_group = _tp_group self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) inputs_info = None @@ -389,13 +496,19 @@ def forward(self, hidden_states = None): embeddings = word_embeddings + position_embeddings embeddings = self.embedding_dropout(embeddings) hidden_states = embeddings + # [seqlen, bs, hidden] hidden_states = hidden_states.transpose(0, 1).contiguous() assert hidden_states is not None _, _, attention_mask, _ = self.data for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask) + if args.use_coshard: + # inner recompute + hidden_states = layer(hidden_states, attention_mask) + else: + # block recompute + hidden_states = checkpoint.checkpoint(layer, hidden_states, attention_mask) outputs = hidden_states # postprocess @@ -414,7 +527,7 @@ def forward(self, hidden_states = None): # minor changes from # https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/pretrain_gpt.py#L75 logits = logits.float() - logits = logits.view(config.seqlen, -1) + logits = logits.view(args.micro_bs * config.seqlen, -1) labels = labels.view(-1) loss = torch.nn.functional.cross_entropy(logits, labels) outputs = loss From 64ab83e221ef1247e3835bd537565f9178822e30 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 13:19:55 +0000 Subject: [PATCH 0752/1892] add email notification --- handcraft/mbart/test-2node-fp32.sh | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/handcraft/mbart/test-2node-fp32.sh b/handcraft/mbart/test-2node-fp32.sh index 9c01e0d9..ece03892 100755 --- a/handcraft/mbart/test-2node-fp32.sh +++ b/handcraft/mbart/test-2node-fp32.sh @@ -1,6 +1,8 @@ evaldir=eval/mbart-fp32-v100-32gb mkdir -p ${evaldir} +wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py + bs=256 test_mix_tp_1f1b() @@ -9,6 +11,8 @@ test_mix_tp_1f1b() hidden=$2 heads=$3 gpus=$4 + arch=${gpus}dev-L${layers}E${hidden}H${heads} + echo "testing ${gpus}-dev mixture-1f1b: L${layers}E${hidden}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ @@ -20,11 +24,14 @@ test_mix_tp_1f1b() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs ${bs} --micro-bs 1 \ --pp-size ${gpus} --tp-size 1 \ - --schedule tp1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt + --schedule tp1f1b > ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt sleep 5 killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results MBart Mixture-1f1b | Node ${Node_RANK} | ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt" + --file ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt } test_tp() @@ -33,6 +40,8 @@ test_tp() hidden=$2 heads=$3 gpus=$4 + arch=L${layers}E${hidden}H${heads} + echo "testing ${gpus}-dev pure tp: L${layers}E${hidden}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ @@ -44,11 +53,14 @@ test_tp() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs ${bs} --micro-bs 1 \ --pp-size 1 --tp-size ${gpus} \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + --schedule 1f1b > ${evaldir}/${gpus}dev-${arch}-tp.txt sleep 5 killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results MBart TP | Node Rank ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp.txt" + --file ${evaldir}/${gpus}dev-${arch}-tp.txt } test_pp() From 7931377fe8ddc305b6290e2a6dac3fe29f4fcaeb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 13:21:53 +0000 Subject: [PATCH 0753/1892] lower bs for pure tp --- handcraft/mbart/test-2node-fp32.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handcraft/mbart/test-2node-fp32.sh b/handcraft/mbart/test-2node-fp32.sh index ece03892..9cc0b948 100755 --- a/handcraft/mbart/test-2node-fp32.sh +++ b/handcraft/mbart/test-2node-fp32.sh @@ -51,7 +51,7 @@ test_tp() --master_port=${MASTER_PORT} \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ + --bs 16 --micro-bs 1 \ --pp-size 1 --tp-size ${gpus} \ --schedule 1f1b > ${evaldir}/${gpus}dev-${arch}-tp.txt sleep 5 From 5a8206914c083cf7175c02f7f9e201171e529653 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 7 Apr 2022 15:38:38 +0000 Subject: [PATCH 0754/1892] fix cosharding bug --- handcraft/gpt3/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 74233bdd..145c93d0 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -210,7 +210,7 @@ def __init__(self, num_heads: int = None): self.tp_group = _tp_group self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) - self.num_heads = num_heads if num_heads is not None else config.heads // self.tp_size + self.num_heads = (config.heads if num_heads is None else num_heads) // self.tp_size self.head_dim = config.hidden_size // config.heads projection_size = self.num_heads * self.head_dim @@ -332,7 +332,7 @@ def __init__(self): [Attention(self.shard_num_heads) for _ in range(coshard)] ) for attn in self.attns: - attn._tp_size = 1 + attn.tp_size = 1 def forward(self, x, mask, recompute=True): if self.tp_size > 1: From 313b7d98cbabe5b8e24761fe1c149944524e383b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 01:48:07 +0000 Subject: [PATCH 0755/1892] update script --- handcraft/mbart/test-2node-fp32.sh | 7 +-- handcraft/mbart/test-4node-fp32.sh | 66 ++++++++++++++---------- handcraft/mbart/test-fp32.sh | 2 +- handcraft/swin/test-2node.sh | 35 +++++++++---- handcraft/swin/test-4node.sh | 80 +++++++++++++++++------------- scripts/aggregate.sh | 23 +++++++++ 6 files changed, 137 insertions(+), 76 deletions(-) create mode 100755 scripts/aggregate.sh diff --git a/handcraft/mbart/test-2node-fp32.sh b/handcraft/mbart/test-2node-fp32.sh index 9cc0b948..70d386be 100755 --- a/handcraft/mbart/test-2node-fp32.sh +++ b/handcraft/mbart/test-2node-fp32.sh @@ -1,6 +1,7 @@ evaldir=eval/mbart-fp32-v100-32gb mkdir -p ${evaldir} +rm -f notify.py wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py bs=256 @@ -11,7 +12,7 @@ test_mix_tp_1f1b() hidden=$2 heads=$3 gpus=$4 - arch=${gpus}dev-L${layers}E${hidden}H${heads} + arch=L${layers}E${hidden}H${heads} echo "testing ${gpus}-dev mixture-1f1b: L${layers}E${hidden}H${heads}" OMP_NUM_THREADS=4 torchrun \ @@ -30,7 +31,7 @@ test_mix_tp_1f1b() sleep 5 killall python python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results MBart Mixture-1f1b | Node ${Node_RANK} | ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt" + --msg "Test Results MBart Mixture-1f1b | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt" \ --file ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt } @@ -59,7 +60,7 @@ test_tp() sleep 5 killall python python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results MBart TP | Node Rank ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp.txt" + --msg "Test Results MBart TP | Node Rank ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp.txt" \ --file ${evaldir}/${gpus}dev-${arch}-tp.txt } diff --git a/handcraft/mbart/test-4node-fp32.sh b/handcraft/mbart/test-4node-fp32.sh index e4a56303..89673ba5 100755 --- a/handcraft/mbart/test-4node-fp32.sh +++ b/handcraft/mbart/test-4node-fp32.sh @@ -3,6 +3,9 @@ mkdir -p ${evaldir} bs=256 +rm -f notify.py +wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py + test_mix_tp_1f1b() { layers=$1 @@ -13,9 +16,9 @@ test_mix_tp_1f1b() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=4 \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs ${bs} --micro-bs 1 \ @@ -25,6 +28,9 @@ test_mix_tp_1f1b() killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results MBart Mixture-1f1b | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt" \ + --file ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt } test_tp() @@ -37,9 +43,9 @@ test_tp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=4 \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs 8 --micro-bs 1 \ @@ -49,6 +55,9 @@ test_tp() killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results MBart Pure TP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt" \ + --file ${evaldir}/${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt } test_pp() @@ -61,9 +70,9 @@ test_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=4 \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs ${bs} --micro-bs 1 \ @@ -85,9 +94,9 @@ test_pp_swap() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=4 \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs ${bs} --micro-bs 1 \ @@ -112,9 +121,9 @@ test_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=4 \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --bs ${bs} --micro-bs 1 \ @@ -124,14 +133,17 @@ test_hybrid_tp_pp() killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results MBart TP16-PP2 | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp16pp2.txt" \ + --file ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp16pp2.txt # echo "testing ${gpus}-dev tp:pp=8:4 | L${layers}E${hidden}H${heads}" # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=4 \ - # --node_rank=${REMOTE_NODE_RANK} \ - # --master_addr="${REMOTE_MASTER_IP}" \ - # --master_port=${REMOTE_MASTER_PORT} \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ # handcraft/mbart/mbart_hybrid.py \ # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --bs ${bs} --micro-bs 1 \ @@ -146,9 +158,9 @@ test_hybrid_tp_pp() # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=4 \ - # --node_rank=${REMOTE_NODE_RANK} \ - # --master_addr="${REMOTE_MASTER_IP}" \ - # --master_port=${REMOTE_MASTER_PORT} \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ # handcraft/mbart/mbart_hybrid.py \ # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --bs ${bs} --micro-bs 1 \ @@ -181,9 +193,9 @@ python scripts/keep.py --gpus 8 # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=4 \ -# --node_rank=${REMOTE_NODE_RANK} \ -# --master_addr="${REMOTE_MASTER_IP}" \ -# --master_port=${REMOTE_MASTER_PORT} \ +# --node_rank=${NODE_RANK} \ +# --master_addr="${MASTER_IP}" \ +# --master_port=${MASTER_PORT} \ # handcraft/mbart/train.py \ # --layers 48 --hidden-size 6144 --heads 32 \ # --bs 32 --micro-bs 1 \ @@ -194,9 +206,9 @@ python scripts/keep.py --gpus 8 # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=4 \ -# --node_rank=${REMOTE_NODE_RANK} \ -# --master_addr="${REMOTE_MASTER_IP}" \ -# --master_port=${REMOTE_MASTER_PORT} \ +# --node_rank=${NODE_RANK} \ +# --master_addr="${MASTER_IP}" \ +# --master_port=${MASTER_PORT} \ # handcraft/mbart/train.py \ # --layers 52 --hidden-size 6144 --heads 32 \ # --bs 4 --micro-bs 1 \ diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh index badf06be..27f20ed3 100755 --- a/handcraft/mbart/test-fp32.sh +++ b/handcraft/mbart/test-fp32.sh @@ -32,7 +32,7 @@ test_tp() OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ handcraft/mbart/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ + --bs 16 --micro-bs 1 \ --pp-size 1 --tp-size ${gpus} \ --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt sleep 5 diff --git a/handcraft/swin/test-2node.sh b/handcraft/swin/test-2node.sh index 7ac0a825..73deb18f 100755 --- a/handcraft/swin/test-2node.sh +++ b/handcraft/swin/test-2node.sh @@ -3,6 +3,8 @@ evaldir=eval/swin-coshard mkdir -p ${evaldir} +rm -f notify.py +wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py img_size=1536 window_size=48 @@ -16,6 +18,7 @@ test_naive_pp() heads=$3 nodes=$4 gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} echo "testing ${gpus}-dev: Pure PP${coshard}: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ @@ -28,11 +31,14 @@ test_naive_pp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}.txt + --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-pp${gpus}.txt sleep 5 killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results Swin PP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-pp${gpus}.txt" \ + --file ${evaldir}/${gpus}dev-${arch}-pp${gpus}.txt } test_naive_tp() @@ -42,6 +48,7 @@ test_naive_tp() heads=$3 nodes=$4 gpus=$5 + arch=L${layers}E${dim}H${heads}-${img_size} echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ @@ -54,11 +61,14 @@ test_naive_tp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt + --bs 16 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt sleep 5 killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results Swin TP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt" \ + --file ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt } test_naive_hybrid_tp_pp() @@ -76,9 +86,9 @@ test_naive_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -93,9 +103,9 @@ test_naive_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -169,9 +179,9 @@ test_coshard_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ + --node_rank=${REMOTE_NODE_RANK} \ + --master_addr="${REMOTE_MASTER_IP}" \ + --master_port=${REMOTE_MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -182,6 +192,9 @@ test_coshard_hybrid_tp_pp() killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results Swin TP4-PP4+Coshard | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp4pp4-coshard.txt" \ + --file ${evaldir}/${gpus}dev-${arch}-tp4pp4-coshard.txt fi } diff --git a/handcraft/swin/test-4node.sh b/handcraft/swin/test-4node.sh index 8770122c..ab132b77 100755 --- a/handcraft/swin/test-4node.sh +++ b/handcraft/swin/test-4node.sh @@ -8,6 +8,9 @@ img_size=1536 window_size=48 bs=256 +rm -f notify.py +wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py + test_naive_pp() { @@ -22,9 +25,9 @@ test_naive_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -49,9 +52,9 @@ test_naive_tp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -61,6 +64,9 @@ test_naive_tp() killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results Swin TP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt" \ + --file ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt } test_naive_hybrid_tp_pp() @@ -79,9 +85,9 @@ test_naive_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -97,9 +103,9 @@ test_naive_hybrid_tp_pp() # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=${nodes} \ - # --node_rank=${REMOTE_NODE_RANK} \ - # --master_addr="${REMOTE_MASTER_IP}" \ - # --master_port=${REMOTE_MASTER_PORT} \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ # handcraft/swin/train.py \ # --layers ${layers} --dim ${dim} --heads ${heads} \ # --img-size ${img_size} --window-size ${window_size} \ @@ -121,13 +127,13 @@ test_coshard_pp() gpus=$5 arch=L${layers}E${dim}H${heads}-${img_size} - echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" + echo "testing ${gpus}-dev: Coshard PP: L${layers}E${dim}H${heads}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -138,6 +144,9 @@ test_coshard_pp() killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results Swin Coshard PP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt" \ + --file ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt } test_coshard_hybrid_tp_pp() @@ -156,9 +165,9 @@ test_coshard_hybrid_tp_pp() # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=${nodes} \ - # --node_rank=${REMOTE_NODE_RANK} \ - # --master_addr="${REMOTE_MASTER_IP}" \ - # --master_port=${REMOTE_MASTER_PORT} \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ # handcraft/swin/train.py \ # --layers ${layers} --dim ${dim} --heads ${heads} \ # --img-size ${img_size} --window-size ${window_size} \ @@ -174,9 +183,9 @@ test_coshard_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -187,14 +196,17 @@ test_coshard_hybrid_tp_pp() killall python sleep 5 killall python + python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ + --msg "Test Results Swin Coshard TP8-PP4 | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt" \ + --file ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt # echo "testing coshard ${gpus}-dev: TP4-PP8: L${layers}E${dim}H${heads}" # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=${nodes} \ - # --node_rank=${REMOTE_NODE_RANK} \ - # --master_addr="${REMOTE_MASTER_IP}" \ - # --master_port=${REMOTE_MASTER_PORT} \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ # handcraft/swin/train.py \ # --layers ${layers} --dim ${dim} --heads ${heads} \ # --img-size ${img_size} --window-size ${window_size} \ @@ -248,9 +260,9 @@ python scripts/keep.py --gpus 8 # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=4 \ -# --node_rank=${REMOTE_NODE_RANK} \ -# --master_addr="${REMOTE_MASTER_IP}" \ -# --master_port=${REMOTE_MASTER_PORT} \ +# --node_rank=${NODE_RANK} \ +# --master_addr="${MASTER_IP}" \ +# --master_port=${MASTER_PORT} \ # handcraft/swin/train.py \ # --layers ${layers} --dim ${dim} --heads ${heads} \ # --img-size 1536 --window-size 48 \ @@ -262,9 +274,9 @@ python scripts/keep.py --gpus 8 # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=4 \ -# --node_rank=${REMOTE_NODE_RANK} \ -# --master_addr="${REMOTE_MASTER_IP}" \ -# --master_port=${REMOTE_MASTER_PORT} \ +# --node_rank=${NODE_RANK} \ +# --master_addr="${MASTER_IP}" \ +# --master_port=${MASTER_PORT} \ # handcraft/swin/train.py \ # --layers ${layers} --dim ${dim} --heads ${heads} \ # --img-size 1536 --window-size 48 \ @@ -275,9 +287,9 @@ python scripts/keep.py --gpus 8 # OMP_NUM_THREADS=4 torchrun \ # --nproc_per_node=8 \ # --nnodes=4 \ -# --node_rank=${REMOTE_NODE_RANK} \ -# --master_addr="${REMOTE_MASTER_IP}" \ -# --master_port=${REMOTE_MASTER_PORT} \ +# --node_rank=${NODE_RANK} \ +# --master_addr="${MASTER_IP}" \ +# --master_port=${MASTER_PORT} \ # handcraft/swin/train.py \ # --layers ${layers} --dim ${dim} --heads ${heads} \ # --img-size 1536 --window-size 48 \ diff --git a/scripts/aggregate.sh b/scripts/aggregate.sh new file mode 100755 index 00000000..a5d03ab5 --- /dev/null +++ b/scripts/aggregate.sh @@ -0,0 +1,23 @@ +# ============= ITP Variables ============ +# NODE_RANK +# MASTER_IP +# MASTER_PORT +# ============= ITP Variables ============ + +node_num=$1 + +if [ ${node_num} == 4 ] +then + mkdir -p /workspace/MagicCube/eval/worker-1 + scp -r worker-1:/workspace/MagicCube/eval/ /workspace/MagicCube/eval/worker-1 + mkdir -p /workspace/MagicCube/eval/worker-2 + scp -r worker-2:/workspace/MagicCube/eval/ /workspace/MagicCube/eval/worker-2 + mkdir -p /workspace/MagicCube/eval/worker-3 + scp -r worker-3:/workspace/MagicCube/eval/ /workspace/MagicCube/eval/worker-3 +fi + +if [ ${node_num} == 2 ] +then + mkdir -p /workspace/MagicCube/eval/worker-1 + scp -r worker-1:/workspace/MagicCube/eval/ workspace/MagicCube/eval/worker-1 +fi From aa3e35022e97eb08ac4e51b43b7d42752e5f333e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 02:23:30 +0000 Subject: [PATCH 0756/1892] fix send bug --- handcraft/mbart/test-4node-fp32.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handcraft/mbart/test-4node-fp32.sh b/handcraft/mbart/test-4node-fp32.sh index 89673ba5..83ece579 100755 --- a/handcraft/mbart/test-4node-fp32.sh +++ b/handcraft/mbart/test-4node-fp32.sh @@ -57,7 +57,7 @@ test_tp() killall python python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ --msg "Test Results MBart Pure TP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt" \ - --file ${evaldir}/${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt + --file ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt } test_pp() From b73d9d8d59682ef506ce8407b41ecbb876f190d4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 02:57:06 +0000 Subject: [PATCH 0757/1892] script fix --- handcraft/swin/test-2node.sh | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/handcraft/swin/test-2node.sh b/handcraft/swin/test-2node.sh index 73deb18f..47f3f433 100755 --- a/handcraft/swin/test-2node.sh +++ b/handcraft/swin/test-2node.sh @@ -86,9 +86,9 @@ test_naive_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -103,9 +103,9 @@ test_naive_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -179,9 +179,9 @@ test_coshard_hybrid_tp_pp() OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=${nodes} \ - --node_rank=${REMOTE_NODE_RANK} \ - --master_addr="${REMOTE_MASTER_IP}" \ - --master_port=${REMOTE_MASTER_PORT} \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ handcraft/swin/train.py \ --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ @@ -215,12 +215,21 @@ test_all() # selected experiments # ================================================= -test_naive_tp 42 1024 32 2 16 -test_coshard_hybrid_tp_pp 42 1024 32 2 16 -# test_naive_hybrid_tp_pp 42 1024 32 2 16 # -> OOM - test_naive_tp 50 1024 32 2 16 test_coshard_hybrid_tp_pp 50 1024 32 2 16 # test_naive_hybrid_tp_pp 50 1024 32 2 16 # -> OOM python scripts/keep.py --gpus 8 + +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=8 \ +# --nnodes=2 \ +# --node_rank=${NODE_RANK} \ +# --master_addr="${MASTER_IP}" \ +# --master_port=${MASTER_PORT} \ +# handcraft/swin/train.py \ +# --layers 50 --dim 1024 --heads 32 \ +# --img-size 1536 --window-size 48 \ +# --pp-size 4 --tp-size 4 --dp-size 1 \ +# --bs 256 --micro-bs 1 --use-coshard --use-inner-coshard \ +# --fp16 From 23c3c3c7feb31337c0dfce2ca5d8f4039b6c053c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 07:41:02 +0000 Subject: [PATCH 0758/1892] weird OOM in another testbed --- handcraft/swin/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index f47e00b5..7821492b 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -1031,6 +1031,9 @@ def train_iter(model, dataloader): if step >= 2: CudaTimer().stop('e2e') + torch.cuda.empty_cache() + torch.distributed.barrier() + if step == 0: print_each_rank('memory consumption after optimizer:', rank_only=0) memory_summary() From 187375f8674676a8dd1d510b92b02c1da88643fb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 10:53:09 +0000 Subject: [PATCH 0759/1892] fix test script error --- handcraft/swin/test.sh | 3 +-- handcraft/swin/train.py | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh index 8a5f344f..57deff28 100755 --- a/handcraft/swin/test.sh +++ b/handcraft/swin/test.sh @@ -45,7 +45,7 @@ test_naive_tp() --layers ${layers} --dim ${dim} --heads ${heads} \ --img-size ${img_size} --window-size ${window_size} \ --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt + --bs 16 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt sleep 5 killall python sleep 5 @@ -244,7 +244,6 @@ test_coshard_dp 18 256 8 2 test_naive_tp 18 256 8 2 test_naive_hybrid_tp_pp 18 256 8 2 -test_coshard_dp 26 512 16 4 test_coshard_pp 26 512 16 4 test_naive_tp 26 512 16 4 # test_naive_hybrid_tp_pp 26 512 16 4 # --> OOM diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index f47e00b5..83ffb5f7 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -1010,10 +1010,6 @@ def train_iter(model, dataloader): iter_num = 6 for step in range(iter_num): - # if step == 0: - # model.data = next(dataloader) - # model_summary(model, (), rank_only=1) - if step >= 2: CudaTimer(enable=True).start('e2e') @@ -1031,6 +1027,9 @@ def train_iter(model, dataloader): if step >= 2: CudaTimer().stop('e2e') + torch.cuda.empty_cache() + torch.distributed.barrier() + if step == 0: print_each_rank('memory consumption after optimizer:', rank_only=0) memory_summary() From bcdb308d9cba0ee19b175792086a374832f5897e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 10:56:00 +0000 Subject: [PATCH 0760/1892] add test scripts --- handcraft/gpt3/test-1gpu.sh | 66 ++++++++ handcraft/gpt3/test-1node.sh | 283 +++++++++++++++++++++++++++++++++++ 2 files changed, 349 insertions(+) create mode 100755 handcraft/gpt3/test-1gpu.sh create mode 100755 handcraft/gpt3/test-1node.sh diff --git a/handcraft/gpt3/test-1gpu.sh b/handcraft/gpt3/test-1gpu.sh new file mode 100755 index 00000000..adb125af --- /dev/null +++ b/handcraft/gpt3/test-1gpu.sh @@ -0,0 +1,66 @@ +#### +# Single Node Model Scaling Test +#### +evaldir=eval/gpt3-coshard-v100-32gb +mkdir -p ${evaldir} + +bs=4 + +test_naive() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing naive (recompute): ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 > ${evaldir}/1dev-${arch}-naive.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_coshard() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing coshard: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ + --use-coshard --coshard-num 8 > ${evaldir}/1dev-${arch}-coshard.txt + sleep 5 + killall python + sleep 5 + killall python +} + + +# test_naive 24 2048 32 2048 +# test_naive 24 2048 32 4096 +# test_naive 24 2048 32 8192 +# # test_naive 24 2048 32 12288 # --> OOM +# # test_naive 24 2048 32 16384 # --> OOM +# +# test_coshard 24 2048 32 2048 +# test_coshard 24 2048 32 4096 +# test_coshard 24 2048 32 8192 +test_coshard 24 2048 32 12288 +# test_coshard 24 2048 32 16384 +# test_coshard 24 2048 32 20480 +# test_coshard 24 2048 32 24576 diff --git a/handcraft/gpt3/test-1node.sh b/handcraft/gpt3/test-1node.sh new file mode 100755 index 00000000..ae8850d9 --- /dev/null +++ b/handcraft/gpt3/test-1node.sh @@ -0,0 +1,283 @@ +#### +# Single Node Model Scaling Test +#### +evaldir=eval/gpt3-coshard-v100-32gb +mkdir -p ${evaldir} + +bs=256 + +test_pp() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing pipeline 1f1b: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --pp-size ${gpus} --tp-size 1 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-pp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_pp_coshard() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing coshard: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --pp-size ${gpus} --tp-size 1 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ + --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-pp-coshard.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_tp() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing tp: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --pp-size ${gpus} --tp-size 1 \ + --seqlen ${seqlen} --bs 16 --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_hybrid() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + if [ ${gpus} == 4 ] + then + echo "testing hybrid: tp:pp=2:2 : ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --pp-size 2 --dp-size 2 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp2pp2.txt + sleep 5 + killall python + sleep 5 + killall python + fi + + if [ ${gpus} == 8 ] + then + # echo "testing hybrid: dp:pp=4:2 : ${arch}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=${gpus} \ + # --nnodes=1 \ + # handcraft/gpt3/train.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --dp-size 4 --pp-size 2 \ + # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + # --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp4pp2.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + + echo "testing hybrid: tp:pp=2:4 : ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size 2 --pp-size 4 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp2pp4.txt + sleep 5 + killall python + sleep 5 + killall python + fi +} + + +test_dp() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing dp: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size ${gpus} --pp-size 1 --tp-size 1 \ + --seqlen ${seqlen} --bs 16 --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_dp_coshard() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing DP coshard: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size ${gpus} --pp-size 1 --tp-size 1 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ + --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp-coshard.txt + sleep 5 + killall python + sleep 5 + killall python +} + + +test_hybrid_coshard() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + if [ ${gpus} == 4 ] + then + echo "testing coshard hybrid: dp:pp=2:2 : ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size 2 --pp-size 2 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-tp2pp2-coshard.txt + sleep 5 + killall python + sleep 5 + killall python + fi + + if [ ${gpus} == 8 ] + then + echo "testing hybrid: dp:pp=4:2 : ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size 4 --pp-size 2 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-tp4pp2-coshard.txt + sleep 5 + killall python + sleep 5 + killall python + + # echo "testing hybrid: dp:pp=2:4 : ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=${gpus} \ + --nnodes=1 \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size 2 --pp-size 4 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-tp2pp4.txt + sleep 5 + killall python + sleep 5 + killall python + fi +} + +# 2.6B +test_dp_coshard 32 2560 32 12288 4 +test_pp 32 2560 32 12288 4 +test_hybrid 32 2560 32 12288 4 + +# 6.7B +test_hybrid 32 4096 32 8192 8 # pp2dp4 OOM, pp4dp2: 26.06GB +test_hybrid_coshard 32 4096 32 8192 8 # pp2dp4: 20.4GB +test_pp_coshard 32 4096 32 12288 8 # 16.73GB +test_hybrid_coshard 32 4096 32 12288 8 # pp2dp4: OOM, dp2pp4: 25.17GB + + +# =========================== + +# test_pp 24 8192 64 2048 8 # 15.45 GB +# test_pp 24 8192 64 4096 8 # 22.84 GB +# test_pp 24 8192 64 8192 8 # OOM +# test_tp 24 8192 64 8192 8 + +# 2.6B +# test_pp_coshard 32 2560 32 2048 1 # 12.24 GB +# test_pp 32 2560 32 2048 1 # can run +# test_pp 32 2560 32 4096 1 # 15.5GB +# test_pp 32 2560 32 8192 1 # 28.38 GB +# test_dp 32 2560 32 8192 4 # 28.38 GB + + +# 6.7B +# test_dp 32 4096 32 4096 8 # OOM +# test_hybrid 32 4096 32 4096 8 # 18.99GB +# test_hybrid 32 4096 32 8192 8 # pp2dp4 oom, pp4dp2: 26.06GB +# test_dp_coshard 32 4096 32 8192 8 # OOM +# test_hybrid_coshard 32 4096 32 8192 8 # pp2dp4: 20.4GB +# test_hybrid 32 4096 32 12288 8 # all OOM +# test_pp 32 4096 32 12288 8 # OOM +# test_pp_coshard 32 4096 32 12288 8 # 16.73GB +# test_hybrid_coshard 32 4096 32 12288 8 # dp4pp2 OOM, dp2pp4: 25.17GB From 7c6c7e6f716479f52273b1feffbac2cac22acc56 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 13:45:54 +0000 Subject: [PATCH 0761/1892] add test script for 1node --- handcraft/gpt3/test-1node.sh | 5 +++-- handcraft/gpt3/train.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/handcraft/gpt3/test-1node.sh b/handcraft/gpt3/test-1node.sh index ae8850d9..3dc75d42 100755 --- a/handcraft/gpt3/test-1node.sh +++ b/handcraft/gpt3/test-1node.sh @@ -4,7 +4,7 @@ evaldir=eval/gpt3-coshard-v100-32gb mkdir -p ${evaldir} -bs=256 +bs=8 test_pp() { @@ -252,7 +252,6 @@ test_hybrid 32 2560 32 12288 4 # 6.7B test_hybrid 32 4096 32 8192 8 # pp2dp4 OOM, pp4dp2: 26.06GB test_hybrid_coshard 32 4096 32 8192 8 # pp2dp4: 20.4GB -test_pp_coshard 32 4096 32 12288 8 # 16.73GB test_hybrid_coshard 32 4096 32 12288 8 # pp2dp4: OOM, dp2pp4: 25.17GB @@ -281,3 +280,5 @@ test_hybrid_coshard 32 4096 32 12288 8 # pp2dp4: OOM, dp2pp4: 25.17GB # test_pp 32 4096 32 12288 8 # OOM # test_pp_coshard 32 4096 32 12288 8 # 16.73GB # test_hybrid_coshard 32 4096 32 12288 8 # dp4pp2 OOM, dp2pp4: 25.17GB + +# 15B diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 145c93d0..6cc5e5da 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -386,7 +386,7 @@ def __init__(self): super().__init__() self.input_layernorm = torch.nn.LayerNorm(config.hidden_size) if args.use_coshard: - print('use cosharding attention...') + # print('use cosharding attention...') self.self_attention = SeqAttention() else: self.self_attention = Attention() @@ -394,7 +394,7 @@ def __init__(self): self.hidden_dropout = 0.0 self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size) if args.use_coshard: - print('use cosharding mlp...') + # print('use cosharding mlp...') self.mlp = SeqMLP() else: self.mlp = MLP() From db664b949379f1a61fef0175a8f9f5a53ac8c4ad Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 13:48:15 +0000 Subject: [PATCH 0762/1892] enable log --- handcraft/gpt3/test-1node.sh | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/handcraft/gpt3/test-1node.sh b/handcraft/gpt3/test-1node.sh index 3dc75d42..ab81e522 100755 --- a/handcraft/gpt3/test-1node.sh +++ b/handcraft/gpt3/test-1node.sh @@ -23,7 +23,7 @@ test_pp() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --pp-size ${gpus} --tp-size 1 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-pp.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-pp.txt sleep 5 killall python sleep 5 @@ -47,7 +47,7 @@ test_pp_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --pp-size ${gpus} --tp-size 1 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-pp-coshard.txt + --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-pp-coshard.txt sleep 5 killall python sleep 5 @@ -71,7 +71,7 @@ test_tp() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --pp-size ${gpus} --tp-size 1 \ --seqlen ${seqlen} --bs 16 --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-tp.txt sleep 5 killall python sleep 5 @@ -97,7 +97,7 @@ test_hybrid() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --pp-size 2 --dp-size 2 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp2pp2.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-tp2pp2.txt sleep 5 killall python sleep 5 @@ -114,7 +114,7 @@ test_hybrid() # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --dp-size 4 --pp-size 2 \ # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp4pp2.txt + # --fp16 > ${evaldir}/${gpus}dev-${arch}-tp4pp2.txt # sleep 5 # killall python # sleep 5 @@ -128,7 +128,7 @@ test_hybrid() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 2 --pp-size 4 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp2pp4.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-tp2pp4.txt sleep 5 killall python sleep 5 @@ -154,7 +154,7 @@ test_dp() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size ${gpus} --pp-size 1 --tp-size 1 \ --seqlen ${seqlen} --bs 16 --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-dp.txt sleep 5 killall python sleep 5 @@ -178,7 +178,7 @@ test_dp_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size ${gpus} --pp-size 1 --tp-size 1 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp-coshard.txt + --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp-coshard.txt sleep 5 killall python sleep 5 @@ -205,7 +205,7 @@ test_hybrid_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 2 --pp-size 2 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-tp2pp2-coshard.txt + --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-tp2pp2-coshard.txt sleep 5 killall python sleep 5 @@ -222,7 +222,7 @@ test_hybrid_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 4 --pp-size 2 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-tp4pp2-coshard.txt + --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-tp4pp2-coshard.txt sleep 5 killall python sleep 5 @@ -236,7 +236,7 @@ test_hybrid_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 2 --pp-size 4 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-tp2pp4.txt + --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-tp2pp4.txt sleep 5 killall python sleep 5 From 030603b51fa31448c1a001415eddb794a492b041 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 14:22:22 +0000 Subject: [PATCH 0763/1892] test for 2 node case --- handcraft/gpt3/test-1node.sh | 16 +- handcraft/gpt3/test-2node.sh | 302 +++++++++++++++++++++++++++++++++++ 2 files changed, 310 insertions(+), 8 deletions(-) create mode 100755 handcraft/gpt3/test-2node.sh diff --git a/handcraft/gpt3/test-1node.sh b/handcraft/gpt3/test-1node.sh index ab81e522..b68fc266 100755 --- a/handcraft/gpt3/test-1node.sh +++ b/handcraft/gpt3/test-1node.sh @@ -4,7 +4,7 @@ evaldir=eval/gpt3-coshard-v100-32gb mkdir -p ${evaldir} -bs=8 +bs=256 test_pp() { @@ -97,7 +97,7 @@ test_hybrid() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --pp-size 2 --dp-size 2 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-tp2pp2.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-dp2pp2.txt sleep 5 killall python sleep 5 @@ -114,13 +114,13 @@ test_hybrid() # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --dp-size 4 --pp-size 2 \ # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 > ${evaldir}/${gpus}dev-${arch}-tp4pp2.txt + # --fp16 > ${evaldir}/${gpus}dev-${arch}-dp4pp2.txt # sleep 5 # killall python # sleep 5 # killall python - echo "testing hybrid: tp:pp=2:4 : ${arch}" + echo "testing hybrid: dp:pp=2:4 : ${arch}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=${gpus} \ --nnodes=1 \ @@ -128,7 +128,7 @@ test_hybrid() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 2 --pp-size 4 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-tp2pp4.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-dp2pp4.txt sleep 5 killall python sleep 5 @@ -205,7 +205,7 @@ test_hybrid_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 2 --pp-size 2 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-tp2pp2-coshard.txt + --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp2pp2-coshard.txt sleep 5 killall python sleep 5 @@ -222,7 +222,7 @@ test_hybrid_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 4 --pp-size 2 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-tp4pp2-coshard.txt + --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp4pp2-coshard.txt sleep 5 killall python sleep 5 @@ -236,7 +236,7 @@ test_hybrid_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 2 --pp-size 4 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-tp2pp4.txt + --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp2pp4-coshard.txt sleep 5 killall python sleep 5 diff --git a/handcraft/gpt3/test-2node.sh b/handcraft/gpt3/test-2node.sh new file mode 100755 index 00000000..88c517d9 --- /dev/null +++ b/handcraft/gpt3/test-2node.sh @@ -0,0 +1,302 @@ +#### +# 2-Node Model Scaling Test +#### +evaldir=eval/gpt3-coshard-v100-32gb +mkdir -p ${evaldir} + +bs=256 + +test_pp() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing pipeline 1f1b: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --pp-size ${gpus} --tp-size 1 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-pp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_pp_coshard() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing coshard: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --pp-size ${gpus} --tp-size 1 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ + --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-pp-coshard.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_tp() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing tp: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --pp-size ${gpus} --tp-size 1 \ + --seqlen ${seqlen} --bs 16 --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_hybrid() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing hybrid: dp:pp=2:8 : ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size 2 --pp-size 8 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp2pp8.txt + sleep 5 + killall python + sleep 5 + killall python + + # echo "testing hybrid: dp:pp=4:4 : ${arch}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=2 \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ + # handcraft/gpt3/train.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --dp-size 4 --pp-size 4 \ + # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + # --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp4pp4.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + # + # echo "testing hybrid: dp:pp=8:2 : ${arch}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=2 \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ + # handcraft/gpt3/train.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --dp-size 8 --pp-size 2 \ + # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + # --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp8pp2.txt + # sleep 5 + # killall python + # sleep 5 + # killall python +} + + +test_dp() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing dp: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size ${gpus} --pp-size 1 --tp-size 1 \ + --seqlen ${seqlen} --bs 16 --micro-bs 1 \ + --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp.txt + sleep 5 + killall python + sleep 5 + killall python +} + +test_dp_coshard() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + echo "testing DP coshard: ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size ${gpus} --pp-size 1 --tp-size 1 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ + --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp-coshard.txt + sleep 5 + killall python + sleep 5 + killall python +} + + +test_hybrid_coshard() +{ + layers=$1 + hidden=$2 + heads=$3 + seqlen=$4 + gpus=$5 + arch=L${layers}E${hidden}H${heads}-seq${seqlen} + + # echo "testing coshard hybrid: dp:pp=2:8 : ${arch}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=2 \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ + # handcraft/gpt3/train.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --dp-size 2 --pp-size 8 \ + # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + # --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp2pp8-coshard.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + + echo "testing coshard hybrid: dp:pp=4:4 : ${arch}" + OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=${NODE_RANK} \ + --master_addr="${MASTER_IP}" \ + --master_port=${MASTER_PORT} \ + handcraft/gpt3/train.py \ + --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + --dp-size 4 --pp-size 4 \ + --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp4pp4-coshard.txt + sleep 5 + killall python + sleep 5 + killall python + + # echo "testing coshard hybrid: dp:pp=8:2 : ${arch}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=2 \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ + # handcraft/gpt3/train.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --dp-size 8 --pp-size 2 \ + # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + # --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp4pp4-coshard.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + +} + +# 15B +test_pp 48 5120 32 8192 16 +test_hybrid_coshard 48 5120 32 8192 16 + +# =========================== + +# test_pp 24 8192 64 2048 8 # 15.45 GB +# test_pp 24 8192 64 4096 8 # 22.84 GB +# test_pp 24 8192 64 8192 8 # OOM +# test_tp 24 8192 64 8192 8 + +# 2.6B +# test_pp_coshard 32 2560 32 2048 1 # 12.24 GB +# test_pp 32 2560 32 2048 1 # can run +# test_pp 32 2560 32 4096 1 # 15.5GB +# test_pp 32 2560 32 8192 1 # 28.38 GB +# test_dp 32 2560 32 8192 4 # 28.38 GB + + +# 6.7B +# test_dp 32 4096 32 4096 8 # OOM +# test_hybrid 32 4096 32 4096 8 # 18.99GB +# test_hybrid 32 4096 32 8192 8 # pp2dp4 oom, pp4dp2: 26.06GB +# test_dp_coshard 32 4096 32 8192 8 # OOM +# test_hybrid_coshard 32 4096 32 8192 8 # pp2dp4: 20.4GB +# test_hybrid 32 4096 32 12288 8 # all OOM +# test_pp 32 4096 32 12288 8 # OOM +# test_pp_coshard 32 4096 32 12288 8 # 16.73GB +# test_hybrid_coshard 32 4096 32 12288 8 # dp4pp2 OOM, dp2pp4: 25.17GB + +# 15B +# test_hybrid 48 5120 32 4096 16 -> pp8tp2 15.62GB +# test_hybrid 48 5120 32 8192 16 # OOM +# test_hybrid_coshard 48 5120 32 8192 16 \ No newline at end of file From 7696485e8e09bc7be12f357038b39840bcbacbff Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 14:25:20 +0000 Subject: [PATCH 0764/1892] enable log --- handcraft/gpt3/test-2node.sh | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/handcraft/gpt3/test-2node.sh b/handcraft/gpt3/test-2node.sh index 88c517d9..47b08cd3 100755 --- a/handcraft/gpt3/test-2node.sh +++ b/handcraft/gpt3/test-2node.sh @@ -26,7 +26,7 @@ test_pp() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --pp-size ${gpus} --tp-size 1 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-pp.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-pp.txt sleep 5 killall python sleep 5 @@ -53,7 +53,7 @@ test_pp_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --pp-size ${gpus} --tp-size 1 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-pp-coshard.txt + --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-pp-coshard.txt sleep 5 killall python sleep 5 @@ -80,7 +80,7 @@ test_tp() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --pp-size ${gpus} --tp-size 1 \ --seqlen ${seqlen} --bs 16 --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-tp.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-tp.txt sleep 5 killall python sleep 5 @@ -107,7 +107,7 @@ test_hybrid() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 2 --pp-size 8 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp2pp8.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-dp2pp8.txt sleep 5 killall python sleep 5 @@ -124,7 +124,7 @@ test_hybrid() # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --dp-size 4 --pp-size 4 \ # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp4pp4.txt + # --fp16 > ${evaldir}/${gpus}dev-${arch}-dp4pp4.txt # sleep 5 # killall python # sleep 5 @@ -141,7 +141,7 @@ test_hybrid() # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --dp-size 8 --pp-size 2 \ # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp8pp2.txt + # --fp16 > ${evaldir}/${gpus}dev-${arch}-dp8pp2.txt # sleep 5 # killall python # sleep 5 @@ -169,7 +169,7 @@ test_dp() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size ${gpus} --pp-size 1 --tp-size 1 \ --seqlen ${seqlen} --bs 16 --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-dp.txt sleep 5 killall python sleep 5 @@ -196,7 +196,7 @@ test_dp_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size ${gpus} --pp-size 1 --tp-size 1 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp-coshard.txt + --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp-coshard.txt sleep 5 killall python sleep 5 @@ -224,7 +224,7 @@ test_hybrid_coshard() # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --dp-size 2 --pp-size 8 \ # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp2pp8-coshard.txt + # --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp2pp8-coshard.txt # sleep 5 # killall python # sleep 5 @@ -241,7 +241,7 @@ test_hybrid_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size 4 --pp-size 4 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp4pp4-coshard.txt + --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp4pp4-coshard.txt sleep 5 killall python sleep 5 @@ -258,7 +258,7 @@ test_hybrid_coshard() # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ # --dp-size 8 --pp-size 2 \ # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp4pp4-coshard.txt + # --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp4pp4-coshard.txt # sleep 5 # killall python # sleep 5 From b00f08bf088ff8b2e6987cd4223e2cef6e02cbae Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 15:09:52 +0000 Subject: [PATCH 0765/1892] strange OOM in some platform --- handcraft/gpt3/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 6cc5e5da..6a2d5622 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -648,6 +648,9 @@ def get_ltor_masks_and_position_ids(self, input_ids): if step >= 2: CudaTimer().stop('e2e') + torch.cuda.empty_cache() + torch.distributed.barrier() + if step == 0: print_each_rank('memory after optimizer:', rank_only=0) memory_summary() From b8a9fed964106f8eb384c442cc5e8227234932c0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 8 Apr 2022 16:37:03 +0000 Subject: [PATCH 0766/1892] fix test script --- handcraft/gpt3/test-2node.sh | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/handcraft/gpt3/test-2node.sh b/handcraft/gpt3/test-2node.sh index 47b08cd3..346be8ae 100755 --- a/handcraft/gpt3/test-2node.sh +++ b/handcraft/gpt3/test-2node.sh @@ -96,7 +96,24 @@ test_hybrid() gpus=$5 arch=L${layers}E${hidden}H${heads}-seq${seqlen} - echo "testing hybrid: dp:pp=2:8 : ${arch}" + # echo "testing hybrid: dp:pp=2:8 : ${arch}" + # OMP_NUM_THREADS=4 torchrun \ + # --nproc_per_node=8 \ + # --nnodes=2 \ + # --node_rank=${NODE_RANK} \ + # --master_addr="${MASTER_IP}" \ + # --master_port=${MASTER_PORT} \ + # handcraft/gpt3/train.py \ + # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ + # --dp-size 2 --pp-size 8 \ + # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ + # --fp16 > ${evaldir}/${gpus}dev-${arch}-dp2pp8.txt + # sleep 5 + # killall python + # sleep 5 + # killall python + + echo "testing hybrid: tp:pp=2:8 : ${arch}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=2 \ @@ -105,9 +122,9 @@ test_hybrid() --master_port=${MASTER_PORT} \ handcraft/gpt3/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size 2 --pp-size 8 \ + --tp-size 2 --pp-size 8 \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-dp2pp8.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-tp2pp8.txt sleep 5 killall python sleep 5 @@ -267,7 +284,7 @@ test_hybrid_coshard() } # 15B -test_pp 48 5120 32 8192 16 +test_hybrid 48 5120 32 8192 16 test_hybrid_coshard 48 5120 32 8192 16 # =========================== @@ -298,5 +315,9 @@ test_hybrid_coshard 48 5120 32 8192 16 # 15B # test_hybrid 48 5120 32 4096 16 -> pp8tp2 15.62GB -# test_hybrid 48 5120 32 8192 16 # OOM -# test_hybrid_coshard 48 5120 32 8192 16 \ No newline at end of file +# test_hybrid 48 5120 32 8192 16 # pp-dp OOM, pp8tp2: can run +# test_pp 48 5120 32 8192 16 # OOM +# test_hybrid_coshard 48 5120 32 8192 16 # can run +# test_hybrid 48 5120 32 6144 16 # pp8dp2 can run +# test_hybrid 32 4096 32 12288 16 # pp8dp2 OOM, pp8tp2 25.10G +# test_pp 32 4096 32 12288 16 # OOM \ No newline at end of file From 5db9c87aafc71b5632de9b4c4b94d85c2f804929 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 9 Apr 2022 02:09:48 +0000 Subject: [PATCH 0767/1892] fix dp bug --- handcraft/gpt3/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 6a2d5622..2ebe1f64 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -627,7 +627,7 @@ def get_ltor_masks_and_position_ids(self, input_ids): CudaTimer(enable=True).start('e2e') # train 1 step - num_microbatch = args.bs // args.micro_bs + num_microbatch = args.bs // (args.micro_bs * args.dp_size) if args.pp_size > 1: _schedule(model, dataloader, num_microbatch) else: From 267b089b203fa82a950b31bbd6dde9a5b05947d2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 9 Apr 2022 03:17:22 +0000 Subject: [PATCH 0768/1892] 16 gpu full test --- handcraft/gpt3/test-2node.sh | 291 +++++------------------------------ 1 file changed, 36 insertions(+), 255 deletions(-) diff --git a/handcraft/gpt3/test-2node.sh b/handcraft/gpt3/test-2node.sh index 346be8ae..76368925 100755 --- a/handcraft/gpt3/test-2node.sh +++ b/handcraft/gpt3/test-2node.sh @@ -6,86 +6,6 @@ mkdir -p ${evaldir} bs=256 -test_pp() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing pipeline 1f1b: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --pp-size ${gpus} --tp-size 1 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-pp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_pp_coshard() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing coshard: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --pp-size ${gpus} --tp-size 1 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-pp-coshard.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_tp() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing tp: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --pp-size ${gpus} --tp-size 1 \ - --seqlen ${seqlen} --bs 16 --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-tp.txt - sleep 5 - killall python - sleep 5 - killall python -} test_hybrid() { @@ -94,26 +14,12 @@ test_hybrid() heads=$3 seqlen=$4 gpus=$5 + dp=$6 + pp=$7 + tp=$8 arch=L${layers}E${hidden}H${heads}-seq${seqlen} - # echo "testing hybrid: dp:pp=2:8 : ${arch}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=2 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/gpt3/train.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --dp-size 2 --pp-size 8 \ - # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 > ${evaldir}/${gpus}dev-${arch}-dp2pp8.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - - echo "testing hybrid: tp:pp=2:8 : ${arch}" + echo "testing hybrid: dp:pp:tp=${dp}:${pp}:${tp} : ${arch}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=2 \ @@ -122,98 +28,9 @@ test_hybrid() --master_port=${MASTER_PORT} \ handcraft/gpt3/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --tp-size 2 --pp-size 8 \ + --dp-size ${dp} --pp-size ${pp} --tp-size ${tp} \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-tp2pp8.txt - sleep 5 - killall python - sleep 5 - killall python - - # echo "testing hybrid: dp:pp=4:4 : ${arch}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=2 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/gpt3/train.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --dp-size 4 --pp-size 4 \ - # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 > ${evaldir}/${gpus}dev-${arch}-dp4pp4.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - # - # echo "testing hybrid: dp:pp=8:2 : ${arch}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=2 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/gpt3/train.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --dp-size 8 --pp-size 2 \ - # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 > ${evaldir}/${gpus}dev-${arch}-dp8pp2.txt - # sleep 5 - # killall python - # sleep 5 - # killall python -} - - -test_dp() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing dp: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size ${gpus} --pp-size 1 --tp-size 1 \ - --seqlen ${seqlen} --bs 16 --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-dp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_dp_coshard() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing DP coshard: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size ${gpus} --pp-size 1 --tp-size 1 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp-coshard.txt + --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp${dp}pp${pp}tp${tp}.txt sleep 5 killall python sleep 5 @@ -228,26 +45,12 @@ test_hybrid_coshard() heads=$3 seqlen=$4 gpus=$5 + dp=$6 + pp=$7 + tp=$8 arch=L${layers}E${hidden}H${heads}-seq${seqlen} - # echo "testing coshard hybrid: dp:pp=2:8 : ${arch}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=2 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/gpt3/train.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --dp-size 2 --pp-size 8 \ - # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp2pp8-coshard.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - - echo "testing coshard hybrid: dp:pp=4:4 : ${arch}" + echo "testing coshard hybrid: dp:pp=${dp}:${pp}:${tp} : ${arch}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=2 \ @@ -256,68 +59,46 @@ test_hybrid_coshard() --master_port=${MASTER_PORT} \ handcraft/gpt3/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size 4 --pp-size 4 \ + --dp-size ${dp} --pp-size ${pp} --tp-size ${tp} \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp4pp4-coshard.txt + --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp${dp}pp${pp}tp${tp}-coshard.txt sleep 5 killall python sleep 5 killall python - - # echo "testing coshard hybrid: dp:pp=8:2 : ${arch}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=2 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/gpt3/train.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --dp-size 8 --pp-size 2 \ - # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp4pp4-coshard.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - } # 15B -test_hybrid 48 5120 32 8192 16 -test_hybrid_coshard 48 5120 32 8192 16 -# =========================== +test_hybrid 48 5120 32 2048 16 4 4 1 # dp8pp2 OOM dp4pp4 18.91GB +test_hybrid_coshard 48 5120 32 2048 16 4 4 1 # dp4pp4 20.93 -# test_pp 24 8192 64 2048 8 # 15.45 GB -# test_pp 24 8192 64 4096 8 # 22.84 GB -# test_pp 24 8192 64 8192 8 # OOM -# test_tp 24 8192 64 8192 8 +# test_hybrid 48 5120 32 4096 16 16 1 1 # pp16 +# test_hybrid 48 5120 32 4096 16 1 1 16 # tp16 +test_hybrid 48 5120 32 4096 16 4 4 1 # dp4pp4 15.62 +test_hybrid_coshard 48 5120 32 4096 16 4 4 1 # dp4pp4 20.93 -# 2.6B -# test_pp_coshard 32 2560 32 2048 1 # 12.24 GB -# test_pp 32 2560 32 2048 1 # can run -# test_pp 32 2560 32 4096 1 # 15.5GB -# test_pp 32 2560 32 8192 1 # 28.38 GB -# test_dp 32 2560 32 8192 4 # 28.38 GB +# test_hybrid 48 5120 32 8192 16 16 1 1 # pp16 OOM +test_hybrid 48 5120 32 8192 16 1 8 2 # pp8tp2 # pp2tp2 17.17GB +test_hybrid_coshard 48 5120 32 8192 16 4 4 1 # dp4pp4 # dp4pp4 26.73GB +test_hybrid 48 5120 32 12288 16 1 4 4 # pp8tp2 OOM pp4tp4 20.29GB +test_hybrid_coshard 48 5120 32 12288 16 2 8 1 # dp4pp4 OOM dp2pp8 26.88GB -# 6.7B -# test_dp 32 4096 32 4096 8 # OOM -# test_hybrid 32 4096 32 4096 8 # 18.99GB -# test_hybrid 32 4096 32 8192 8 # pp2dp4 oom, pp4dp2: 26.06GB -# test_dp_coshard 32 4096 32 8192 8 # OOM -# test_hybrid_coshard 32 4096 32 8192 8 # pp2dp4: 20.4GB -# test_hybrid 32 4096 32 12288 8 # all OOM -# test_pp 32 4096 32 12288 8 # OOM -# test_pp_coshard 32 4096 32 12288 8 # 16.73GB -# test_hybrid_coshard 32 4096 32 12288 8 # dp4pp2 OOM, dp2pp4: 25.17GB +# =========================== # 15B -# test_hybrid 48 5120 32 4096 16 -> pp8tp2 15.62GB +# test_hybrid 48 5120 32 2048 16 4 4 1 # dp8pp2 OOM dp4pp4 18.91GB + +# test_pp 48 5120 32 4096 16 # 12.42GB +# test_hybrid 48 5120 32 4096 16 2 8 1 # dp2pp8 15.62GB +# test_hybrid 48 5120 32 4096 16 4 4 1 # dp4pp4 15.62 +# test_hybrid 48 5120 32 4096 16 8 2 1 # dp8pp2 OOM +# test_hybrid_coshard 48 5120 32 4096 16 4 4 1 # dp16 OOM dp8pp2 OOM dp4pp4 can run + # test_hybrid 48 5120 32 8192 16 # pp-dp OOM, pp8tp2: can run # test_pp 48 5120 32 8192 16 # OOM -# test_hybrid_coshard 48 5120 32 8192 16 # can run -# test_hybrid 48 5120 32 6144 16 # pp8dp2 can run -# test_hybrid 32 4096 32 12288 16 # pp8dp2 OOM, pp8tp2 25.10G -# test_pp 32 4096 32 12288 16 # OOM \ No newline at end of file +# test_hybrid_coshard 48 5120 32 8192 16 4 4 1 # dp4pp4 + +# test_hybrid 48 5120 32 12288 16 1 4 4 # pp8tp2 OOM pp4tp4 20.29GB +# test_hybrid_coshard 48 5120 32 12288 16 2 8 1 # dp4pp4 OOM dp2pp8 26.88GB From 4297a13c8c4fba74695be55f877bd4dee4f73a02 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 9 Apr 2022 08:01:27 +0000 Subject: [PATCH 0769/1892] add test for 16g mem constraints --- handcraft/gpt3/test-1gpu.sh | 15 ++++--- handcraft/gpt3/test-2node.sh | 39 +++++++++++------- handcraft/gpt3/train.py | 78 ++++++++++++++++++++++++++++++++++-- 3 files changed, 110 insertions(+), 22 deletions(-) diff --git a/handcraft/gpt3/test-1gpu.sh b/handcraft/gpt3/test-1gpu.sh index adb125af..3e89c256 100755 --- a/handcraft/gpt3/test-1gpu.sh +++ b/handcraft/gpt3/test-1gpu.sh @@ -21,7 +21,7 @@ test_naive() handcraft/gpt3/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/1dev-${arch}-naive.txt + --fp16 # > ${evaldir}/1dev-${arch}-naive.txt sleep 5 killall python sleep 5 @@ -43,7 +43,7 @@ test_coshard() handcraft/gpt3/train.py \ --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 > ${evaldir}/1dev-${arch}-coshard.txt + --use-coshard --coshard-num 8 # > ${evaldir}/1dev-${arch}-coshard.txt sleep 5 killall python sleep 5 @@ -51,16 +51,21 @@ test_coshard() } +test_naive 48 5120 32 2048 +test_naive 48 5120 32 4096 +test_naive 48 5120 32 8192 +test_naive 48 5120 32 12288 + # test_naive 24 2048 32 2048 # test_naive 24 2048 32 4096 # test_naive 24 2048 32 8192 -# # test_naive 24 2048 32 12288 # --> OOM -# # test_naive 24 2048 32 16384 # --> OOM +# # test_naive 24 2048 32 12288 # --# > OOM +# # test_naive 24 2048 32 16384 # --# > OOM # # test_coshard 24 2048 32 2048 # test_coshard 24 2048 32 4096 # test_coshard 24 2048 32 8192 -test_coshard 24 2048 32 12288 +# test_coshard 24 2048 32 12288 # test_coshard 24 2048 32 16384 # test_coshard 24 2048 32 20480 # test_coshard 24 2048 32 24576 diff --git a/handcraft/gpt3/test-2node.sh b/handcraft/gpt3/test-2node.sh index 76368925..4fce6138 100755 --- a/handcraft/gpt3/test-2node.sh +++ b/handcraft/gpt3/test-2node.sh @@ -30,7 +30,7 @@ test_hybrid() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size ${dp} --pp-size ${pp} --tp-size ${tp} \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 # > ${evaldir}/${gpus}dev-${arch}-dp${dp}pp${pp}tp${tp}.txt + --fp16 > ${evaldir}/${gpus}dev-${arch}-dp${dp}pp${pp}tp${tp}.txt sleep 5 killall python sleep 5 @@ -61,7 +61,7 @@ test_hybrid_coshard() --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ --dp-size ${dp} --pp-size ${pp} --tp-size ${tp} \ --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 # > ${evaldir}/${gpus}dev-${arch}-dp${dp}pp${pp}tp${tp}-coshard.txt + --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp${dp}pp${pp}tp${tp}-coshard.txt sleep 5 killall python sleep 5 @@ -70,20 +70,28 @@ test_hybrid_coshard() # 15B -test_hybrid 48 5120 32 2048 16 4 4 1 # dp8pp2 OOM dp4pp4 18.91GB -test_hybrid_coshard 48 5120 32 2048 16 4 4 1 # dp4pp4 20.93 - -# test_hybrid 48 5120 32 4096 16 16 1 1 # pp16 -# test_hybrid 48 5120 32 4096 16 1 1 16 # tp16 -test_hybrid 48 5120 32 4096 16 4 4 1 # dp4pp4 15.62 -test_hybrid_coshard 48 5120 32 4096 16 4 4 1 # dp4pp4 20.93 +# test_hybrid 48 5120 32 2048 16 4 4 1 # dp8pp2 OOM dp4pp4 18.91GB +# test_hybrid_coshard 48 5120 32 2048 16 4 4 1 # dp4pp4 20.93 +# +# test_hybrid 48 5120 32 4096 16 4 4 1 # dp4pp4 15.62 +# test_hybrid_coshard 48 5120 32 4096 16 4 4 1 # dp4pp4 20.93 +# +# test_hybrid 48 5120 32 8192 16 1 8 2 # pp8tp2 # pp2tp2 17.17GB +# test_hybrid_coshard 48 5120 32 8192 16 4 4 1 # dp4pp4 # dp4pp4 26.73GB +# +# test_hybrid 48 5120 32 12288 16 1 4 4 # pp8tp2 OOM pp4tp4 20.29GB +# test_hybrid_coshard 48 5120 32 12288 16 2 8 1 # dp4pp4 OOM dp2pp8 26.88GB -# test_hybrid 48 5120 32 8192 16 16 1 1 # pp16 OOM -test_hybrid 48 5120 32 8192 16 1 8 2 # pp8tp2 # pp2tp2 17.17GB -test_hybrid_coshard 48 5120 32 8192 16 4 4 1 # dp4pp4 # dp4pp4 26.73GB +# 6.7B 251.35 TFLOPS +test_hybrid 32 4096 32 2048 16 4 4 1 # dp8pp2 OOM dp4pp4 9.29GB +test_hybrid_coshard 32 4096 32 2048 16 4 4 1 # dp4pp4 +test_hybrid 32 4096 32 4096 16 4 4 1 # dp8pp2 OOM, dp4pp4 13.05GB +test_hybrid_coshard 32 4096 32 4096 16 4 4 1 # dp4pp4 10.45, dp8pp2 OOM +test_hybrid 32 4096 32 8192 16 1 8 2 # dp4pp4 OOM dp2pp8 OOM pp16 OOM pp8tp2 13.46GB +test_hybrid_coshard 32 4096 32 8192 16 4 4 1 # dp4pp4 14.38 +# test_hybrid 32 4096 32 12288 16 1 1 16 # pp8tp2 OOM pp4tp4 OOM pp2tp8 OOM pp1tp16 OOM +test_hybrid_coshard 32 4096 32 12288 16 2 4 2 # dp2pp8 OOM dp2pp4tp2 13.31GB -test_hybrid 48 5120 32 12288 16 1 4 4 # pp8tp2 OOM pp4tp4 20.29GB -test_hybrid_coshard 48 5120 32 12288 16 2 8 1 # dp4pp4 OOM dp2pp8 26.88GB # =========================== @@ -102,3 +110,6 @@ test_hybrid_coshard 48 5120 32 12288 16 2 8 1 # dp4pp4 OOM dp2pp8 26.88GB # test_hybrid 48 5120 32 12288 16 1 4 4 # pp8tp2 OOM pp4tp4 20.29GB # test_hybrid_coshard 48 5120 32 12288 16 2 8 1 # dp4pp4 OOM dp2pp8 26.88GB + +python scripts/keep.py --gpus 8 +killall python \ No newline at end of file diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 2ebe1f64..99e8ff86 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -91,6 +91,9 @@ _pp_embed_group = -1 _pp_embed_reducer = None cube.init() +print_each_rank('setting memory constraints to 16GB') +torch.cuda.set_per_process_memory_fraction(0.5) + dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( [args.dp_size, args.pp_size, args.tp_size] ) @@ -174,6 +177,14 @@ def forward(self, hidden_states, recompute=False): x = self.forward_(hidden_states) return x + def flops(self): + mlp_flops = dict( + fc1=config.seqlen * config.hidden_size * config.hidden_size * 4 // self.tp_size, + gelu=8 * config.seqlen * config.hidden_size * 4 // self.tp_size, + fc2=config.seqlen * (config.hidden_size * 4 // self.tp_size) * config.hidden_size, + ) + return sum(mlp_flops.values()) + class SeqMLP(torch.nn.Module): @@ -202,6 +213,9 @@ def forward(self, x, recompute=True): outs = AllReduceIdentity.apply(outs, self.tp_group) return outs + def flops(self): + return sum([mlp.flops() for mlp in self.mlps]) + class Attention(torch.nn.Module): @@ -317,6 +331,19 @@ def forward(self, x, mask, recompute=False): x = self.forward_(x, mask) return x + def flops(self): + seqlen = config.seqlen + attn_flops = dict( + kqv=3 * seqlen * config.hidden_size * self.head_dim * self.num_heads, + kqv_bias=3 * seqlen * self.head_dim * self.num_heads, + q_scale=seqlen * self.num_heads * self.head_dim, # (N h) L d, 1 -> (N h) L d + attn_score=self.num_heads * seqlen * self.head_dim * seqlen, # (N h) L d, (N h) d L -> (N h) L L + attn_softmax=5 * self.num_heads * seqlen * seqlen, # (N h) L L + attn_output=self.num_heads * seqlen * seqlen * self.head_dim, # (N h) L L, (N h) L d -> (N h) L d + out_proj=seqlen * self.num_heads * self.head_dim * config.hidden_size, # L N (h d), E (h d) -> L N E + ) + return sum(attn_flops.values()) + class SeqAttention(torch.nn.Module): @@ -347,6 +374,9 @@ def forward(self, x, mask, recompute=True): outs = AllReduceIdentity.apply(outs, self.tp_group) return outs + def flops(self): + return sum([attn.flops() for attn in self.attns]) + class Embedding(torch.nn.Module): @@ -428,6 +458,20 @@ def forward(self, hidden_states, attention_mask): output = layernorm_input + residual return output + def flops(self): + seqlen = config.seqlen + transformer_flops = dict( + attn_layer_norm=5 * seqlen * config.hidden_size, # (L, N, E) + attn=self.self_attention.flops(), + dropout=seqlen * config.hidden_size, # (L, N, E) + attn_residual=seqlen * config.hidden_size, + fc_layer_norm=5 * seqlen * config.hidden_size, # (L, N, E) + mlp=self.mlp.flops(), + fc_dropout=seqlen * config.hidden_size, + fc_residual=seqlen * config.hidden_size, + ) + return sum(transformer_flops.values()) + class Pooler(torch.nn.Module): @@ -534,6 +578,18 @@ def forward(self, hidden_states = None): return outputs + def flops(self): + flops = 0 + if self.is_first_stage: + # ignore + flops += 0 + # transformer layers + flops += sum([t.flops() for t in self.layers]) + if self.is_last_stage: + # logits + flops += config.seqlen * config.hidden_size * config.vocab_size + return flops + class GPT3DataLoader(cube.runtime.syndata.CubeDataLoader): @@ -599,13 +655,29 @@ def get_ltor_masks_and_position_ids(self, input_ids): return attention_mask, loss_mask, position_ids +def get_alpa_tflops(): + batch_size = 1 + seq_len = config.seqlen + hidden_size = config.hidden_size + num_layers = config.layers + vocab_size = config.vocab_size + factor = 96 # if checkpoint_activations else 72 + total_flop = factor * batch_size * seq_len * (hidden_size ** 2) * num_layers * \ + (1 + seq_len / (6 * hidden_size)) \ + + 6 * batch_size * seq_len * hidden_size * vocab_size + # Note: if we use dot to compute forward embedding + # then the last term in total_flops should be + # "+ 10 * batch_size * seq_len * hidden_size * vocab_size". + tflops = total_flop / 1e12 # total_flop / latency / num_gpus / 1e12 + return tflops + + if __name__ == '__main__': + print_each_rank(f'alpa calculated TFLOPs: {get_alpa_tflops()}', rank_only=0) model = GPT3() nparams = sum([param.numel() for param in model.parameters()]) - # forward_flops = model.flops() - tflops = 0 # forward_flops * 4 / 1e12 # forward + re-compute forward + backward (=2 forward flops) - print_each_rank(f'model params (M): {nparams / 1e6} | TFLOPs: {tflops}. Launching model...') + print_each_rank(f'model params (M): {nparams / 1e6}. Launching model...') model = model.half().cuda() if args.fp16 else model.cuda() dataloader = GPT3DataLoader(args.micro_bs) From f53f3622310943443b6ae1c86b064d3580f18882 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 9 Apr 2022 11:17:02 +0000 Subject: [PATCH 0770/1892] node nodes script --- handcraft/gpt3/test-2node.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/handcraft/gpt3/test-2node.sh b/handcraft/gpt3/test-2node.sh index 4fce6138..fff7cc8a 100755 --- a/handcraft/gpt3/test-2node.sh +++ b/handcraft/gpt3/test-2node.sh @@ -1,7 +1,7 @@ #### # 2-Node Model Scaling Test #### -evaldir=eval/gpt3-coshard-v100-32gb +evaldir=eval/gpt3-coshard-v100-16gb mkdir -p ${evaldir} bs=256 @@ -50,7 +50,7 @@ test_hybrid_coshard() tp=$8 arch=L${layers}E${hidden}H${heads}-seq${seqlen} - echo "testing coshard hybrid: dp:pp=${dp}:${pp}:${tp} : ${arch}" + echo "testing coshard hybrid: dp:pp:tp=${dp}:${pp}:${tp} : ${arch}" OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=2 \ @@ -90,7 +90,7 @@ test_hybrid_coshard 32 4096 32 4096 16 4 4 1 # dp4pp4 10.45, dp8pp2 OOM test_hybrid 32 4096 32 8192 16 1 8 2 # dp4pp4 OOM dp2pp8 OOM pp16 OOM pp8tp2 13.46GB test_hybrid_coshard 32 4096 32 8192 16 4 4 1 # dp4pp4 14.38 # test_hybrid 32 4096 32 12288 16 1 1 16 # pp8tp2 OOM pp4tp4 OOM pp2tp8 OOM pp1tp16 OOM -test_hybrid_coshard 32 4096 32 12288 16 2 4 2 # dp2pp8 OOM dp2pp4tp2 13.31GB +test_hybrid_coshard 32 4096 32 12288 16 1 4 4 # dp2pp8 OOM dp2pp4tp2 13.31GB # =========================== From 1478549da57627449a9429502245b96c10ffb57d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 9 Apr 2022 11:17:25 +0000 Subject: [PATCH 0771/1892] keep without too many log --- scripts/keep.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/keep.py b/scripts/keep.py index 3054c7f2..b580a83e 100644 --- a/scripts/keep.py +++ b/scripts/keep.py @@ -35,22 +35,23 @@ def keep(rank, args): a = torch.rand((8192, 8192)).cuda() b = torch.rand((8192, 8192)).cuda() + print(f'benchmarking {args.gpus} gpus...') while True: tic = time.time() for _ in range(5000): c = a * b torch.cuda.synchronize() toc = time.time() - if rank == 0: - print('benchmark 8K matmul: time span: {}ms'.format((toc - tic) * 1000 / 5000)) + # if rank == 0: + # print('benchmark 8K matmul: time span: {}ms'.format((toc - tic) * 1000 / 5000)) time.sleep(args.interval) while True: util = get_gpu_util(rank) if util <= 10: break - print('rank {}: find gpu busy, keep sleeping...'.format(rank)) + # print('rank {}: find gpu busy, keep sleeping...'.format(rank)) time.sleep(args.interval) - print('rank {} gets up'.format(rank)) + # print('rank {} gets up'.format(rank)) if __name__ == '__main__': From c14c61edcc7c0ceda3eed8033ee43be8c7b15abb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 10 Apr 2022 08:03:38 +0000 Subject: [PATCH 0772/1892] add textnas --- handcraft/textnas/dataloader.py | 144 ++++++++++++++++++ handcraft/textnas/dataset.sh | 6 + handcraft/textnas/ops.py | 240 ++++++++++++++++++++++++++++++ handcraft/textnas/train.py | 255 ++++++++++++++++++++++++++++++++ 4 files changed, 645 insertions(+) create mode 100644 handcraft/textnas/dataloader.py create mode 100644 handcraft/textnas/dataset.sh create mode 100644 handcraft/textnas/ops.py create mode 100644 handcraft/textnas/train.py diff --git a/handcraft/textnas/dataloader.py b/handcraft/textnas/dataloader.py new file mode 100644 index 00000000..c724c192 --- /dev/null +++ b/handcraft/textnas/dataloader.py @@ -0,0 +1,144 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +""" +For test: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/textnas/dataloader.py +""" + + +import os +import numpy as np +import torch +from torch.utils import data +import threading +from transformers import BertModel, BertTokenizer +import collections +import time + +import cube +from cube.runtime.device import DeviceGroup +from cube.profiler import CudaTimer + + +def read_sst_2(data_path='./SST-2', max_input_length=64, min_count=1): + sentences, labels = [], [] + assert os.path.exists(data_path) + dataset_train = os.path.join(data_path, 'train.tsv') + with open(dataset_train, 'r') as f: + lines = f.readlines()[1:] # skip first + for line in lines: + sentence, label = line.split('\t') + sentence = sentence.strip() + label = int(label.strip()) + sentences.append(sentence) + labels.append(label) + return sentences, labels + + +class SSTDataset(data.Dataset): + def __init__(self): + self.sents, self.labels = read_sst_2() + print(f'> loaded SST dataset: train length: {len(self.sents)}') + + def __getitem__(self, index): + return self.sents[index], self.labels[index] + + def __len__(self): + return len(self.sents) + + +class SharedDataLoader(object): + def __init__(self, batch_size, replicate=True, **kwargs): + self.replicate = replicate + self.has_model = self.replicate or (DeviceGroup().rank == 0) + + dataset = SSTDataset() + dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, **kwargs) + self.dataloader = dataloader + + if self.has_model: + self.model = BertModel.from_pretrained('bert-base-uncased').cuda() + self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + else: + self.model = None + self.tokenizer = None + + self.max_queue = 32 + self.input_size = (batch_size, 64, 768) + self.batch_size = batch_size + self.length = len(dataset) // batch_size + + def __iter__(self): + self.counter = 0 + self.shared_queue = collections.deque() + self._dataloader_iter = iter(self.dataloader) + if self.has_model and (not self.replicate): + # sharing mode: all models share the same dataloader + print('starting pipeline to produce datas') + self.workers = threading.Thread(target=self._pipe).start() + return self + + def __len__(self): + return len(self.dataloader) + + def get_data(self): + if self.replicate: + CudaTimer().start('bert') + text, label = next(self._dataloader_iter) + text = torch.tensor([self.tokenizer.encode(t, max_length=64, padding='max_length') for t in text]).cuda() + mask = text > 0 + with torch.no_grad(): + output = self.model(text)['last_hidden_state'] + label = label.cuda() + if self.replicate: + CudaTimer().stop('bert') + return output, mask, label + + def _pipe(self): + while True: + while len(self.shared_queue) >= self.max_queue: + time.sleep(0.2) + # print('sample data...') + datas = self.get_data() + # print(datas) + self.shared_queue.append(datas) + + def __next__(self): + self.counter += 1 + if self.counter >= len(self): + raise StopIteration + if self.replicate: + # replicate mode: each gpu has a dataloader + text, masks, labels = self.get_data() + else: + # sharing mode: all models share the same dataloader + if self.has_model: + while not self.shared_queue: + time.sleep(0.1) + text, masks, labels = self.shared_queue.popleft() + assert torch.is_tensor(text) + masks = masks.float() + else: + text = torch.zeros(self.input_size, dtype=torch.float, device="cuda") + labels = torch.zeros(self.batch_size, dtype=torch.long, device="cuda") + masks = torch.zeros(self.input_size[:2], dtype=torch.float, device="cuda") + CudaTimer().start('get_data') + torch.distributed.broadcast(text, 0) + torch.distributed.broadcast(labels, 0) + torch.distributed.broadcast(masks, 0) + CudaTimer().stop('get_data') + masks = masks.bool() + return text, masks, labels + + +if __name__ == '__main__': + + cube.init() + dataloader = SharedDataLoader(32, replicate=True) + for datas in dataloader: + print(f'get data: {[data.size() for data in datas]}') + input('>>>') diff --git a/handcraft/textnas/dataset.sh b/handcraft/textnas/dataset.sh new file mode 100644 index 00000000..2c76689e --- /dev/null +++ b/handcraft/textnas/dataset.sh @@ -0,0 +1,6 @@ + +echo 'downloading SST-2 dataset...' +wget https://dl.fbaipublicfiles.com/glue/data/SST-2.zip +unzip SST-2.zip +rm SST-2.zip + diff --git a/handcraft/textnas/ops.py b/handcraft/textnas/ops.py new file mode 100644 index 00000000..4e88943b --- /dev/null +++ b/handcraft/textnas/ops.py @@ -0,0 +1,240 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +import torch.nn.functional as F +from torch import nn + + +INF = 1E10 +EPS = 1E-12 + +def get_length(mask): + length = torch.sum(mask, 1) + length = length.long() + return length + + +class Mask(nn.Module): + + def forward(self, seq, mask): + # seq: (N, C, L) + # mask: (N, L) + seq_mask = torch.unsqueeze(mask, 2) + seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2) + return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq)) + + +class BatchNorm(nn.Module): + def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True): + super(BatchNorm, self).__init__() + self.mask_opt = Mask() + self.mask_opt1 = Mask() + self.pre_mask = pre_mask + self.post_mask = post_mask + self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine) + + def forward(self, seq, mask): + if self.pre_mask: + seq = self.mask_opt(seq, mask) + seq = self.bn(seq) + if self.post_mask: + seq = self.mask_opt1(seq, mask) + return seq + + +class ConvBN(nn.Module): + + def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob, + pre_mask, post_mask, with_bn=True, with_relu=True): + super(ConvBN, self).__init__() + self.mask_opt = Mask() + self.pre_mask = pre_mask + self.post_mask = post_mask + self.with_bn = with_bn + self.with_relu = with_relu + self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True, + padding=(kernal_size - 1) // 2) + self.dropout = nn.Dropout(p=(1 - cnn_keep_prob)) + + if with_bn: + self.bn = BatchNorm(out_channels, not post_mask, True) + + if with_relu: + self.relu = nn.ReLU() + + def forward(self, seq, mask): + if self.pre_mask: + seq = self.mask_opt(seq, mask) + seq = self.conv(seq) + if self.post_mask: + seq = self.mask_opt(seq, mask) + if self.with_bn: + seq = self.bn(seq, mask) + if self.with_relu: + seq = self.relu(seq) + seq = self.dropout(seq) + return seq + + +class AvgPool(nn.Module): + def __init__(self, kernal_size, pre_mask, post_mask): + super(AvgPool, self).__init__() + self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2) + self.pre_mask = pre_mask + self.post_mask = post_mask + self.mask_opt = Mask() + + def forward(self, seq, mask): + if self.pre_mask: + seq = self.mask_opt(seq, mask) + seq = self.avg_pool(seq) + if self.post_mask: + seq = self.mask_opt(seq, mask) + return seq + + +class MaxPool(nn.Module): + def __init__(self, kernal_size, pre_mask, post_mask): + super(MaxPool, self).__init__() + self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2) + self.pre_mask = pre_mask + self.post_mask = post_mask + self.mask_opt = Mask() + + def forward(self, seq, mask): + if self.pre_mask: + seq = self.mask_opt(seq, mask) + seq = self.max_pool(seq) + if self.post_mask: + seq = self.mask_opt(seq, mask) + return seq + + +class Attention(nn.Module): + def __init__(self, num_units, num_heads, keep_prob, is_mask): + super(Attention, self).__init__() + self.num_units = num_units + self.num_heads = num_heads + self.keep_prob = keep_prob + self.is_mask = is_mask + + self.linear_q = nn.Linear(num_units, num_units) + self.linear_k = nn.Linear(num_units, num_units) + self.linear_v = nn.Linear(num_units, num_units) + + self.bn = BatchNorm(num_units, True, is_mask) + self.dropout = nn.Dropout(p=1 - self.keep_prob) + + def forward(self, seq, mask): + in_c = seq.size()[1] + seq = torch.transpose(seq, 1, 2) # (N, L, C) + queries = seq + keys = seq + num_heads = self.num_heads + + # T_q = T_k = L + Q = F.relu(self.linear_q(seq)) # (N, T_q, C) + K = F.relu(self.linear_k(seq)) # (N, T_k, C) + V = F.relu(self.linear_v(seq)) # (N, T_k, C) + + # Split and concat + Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0) # (h*N, T_q, C/h) + K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h) + V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h) + + # Multiplication + outputs = torch.matmul(Q_, K_.transpose(1, 2)) # (h*N, T_q, T_k) + # Scale + outputs = outputs / (K_.size()[-1] ** 0.5) + # Key Masking + key_masks = mask.repeat(num_heads, 1) # (h*N, T_k) + key_masks = torch.unsqueeze(key_masks, 1) # (h*N, 1, T_k) + key_masks = key_masks.repeat(1, queries.size()[1], 1) # (h*N, T_q, T_k) + + paddings = torch.ones_like(outputs) * (-INF) # extremely small value + outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs) + + query_masks = mask.repeat(num_heads, 1) # (h*N, T_q) + query_masks = torch.unsqueeze(query_masks, -1) # (h*N, T_q, 1) + query_masks = query_masks.repeat(1, 1, keys.size()[1]).float() # (h*N, T_q, T_k) + + att_scores = F.softmax(outputs, dim=-1) * query_masks # (h*N, T_q, T_k) + att_scores = self.dropout(att_scores) + + # Weighted sum + x_outputs = torch.matmul(att_scores, V_) # (h*N, T_q, C/h) + # Restore shape + x_outputs = torch.cat( + torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0), + dim=2) # (N, T_q, C) + + x = torch.transpose(x_outputs, 1, 2) # (N, C, L) + x = self.bn(x, mask) + + return x + + +class RNN(nn.Module): + def __init__(self, hidden_size, output_keep_prob): + super(RNN, self).__init__() + self.hidden_size = hidden_size + self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_keep_prob = output_keep_prob + + self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob)) + + def forward(self, seq, mask): + # seq: (N, C, L) + # mask: (N, L) + max_len = seq.size()[2] + length = get_length(mask) + seq = torch.transpose(seq, 1, 2) # to (N, L, C) + packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True, + enforce_sorted=False) + outputs, _ = self.bid_rnn(packed_seq) + outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, + total_length=max_len)[0] + outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2) # (N, L, C) + outputs = self.out_dropout(outputs) # output dropout + return torch.transpose(outputs, 1, 2) # back to: (N, C, L) + + +class LinearCombine(nn.Module): + def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False): + super(LinearCombine, self).__init__() + self.layers_num = layers_num + self.trainable = trainable + self.input_aware = input_aware + self.word_level = word_level + + if input_aware: + raise NotImplementedError("Input aware is not supported.") + self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num), + requires_grad=trainable) + + def forward(self, seq): + nw = F.softmax(self.w, dim=0) + seq = torch.mul(seq, nw) + seq = torch.sum(seq, dim=0) + return seq + + +class GlobalAvgPool(nn.Module): + def forward(self, x, mask): + x = torch.sum(x, 2) + length = torch.sum(mask, 1, keepdim=True).float() + length += torch.eq(length, 0.0).float() * EPS + length = length.repeat(1, x.size()[1]) + x /= length + return x + + +class GlobalMaxPool(nn.Module): + def forward(self, x, mask): + mask = torch.eq(mask.float(), 0.0).long() + mask = torch.unsqueeze(mask, dim=1).repeat(1, x.size()[1], 1) + mask *= -INF + x += mask + x, _ = torch.max(x + mask, 2) + return x \ No newline at end of file diff --git a/handcraft/textnas/train.py b/handcraft/textnas/train.py new file mode 100644 index 00000000..409349cd --- /dev/null +++ b/handcraft/textnas/train.py @@ -0,0 +1,255 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + handcraft/textnas/train.py \ + --bs 128 --models 12 --schedule pipe +""" + +import numpy as np +import torch +import torch.nn as nn +import argparse + +import cube +from cube.runtime.device import DeviceGroup +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +from cube.profiler.memory import memory_summary + +from handcraft.textnas.ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm, GlobalMaxPool, GlobalAvgPool +from handcraft.textnas.dataloader import SharedDataLoader + + +cube.init() + +parser = argparse.ArgumentParser(description='textnas') +parser.add_argument('--schedule', type=str, default='replicate', choices=['replicate', 'pipe'], + help='scheduling algorithm. model: train model with replicated dataloader. pipe: train model with shared dataloader') +parser.add_argument('--models', type=int, default=1, + help='number of models to be trained in total') +parser.add_argument('--bs', type=int, default=128, + help='num of micro batch') +args = parser.parse_args() +print(args) + + +_model_divisions = [] +if args.schedule == 'replicate': + num_trainers = DeviceGroup().world_size + num_model_per_device = args.models // num_trainers + _model_divisions = [num_model_per_device] * num_trainers + for idx in range(args.models % num_trainers): + _model_divisions[-1-idx] += 1 +if args.schedule == 'pipe': + num_trainers = DeviceGroup().world_size - 1 + num_model_per_device = args.models // num_trainers + _model_divisions = [0] + [num_model_per_device] * num_trainers + for idx in range(args.models % num_trainers): + _model_divisions[-1-idx] += 1 +print_each_rank(f'model number placements: {_model_divisions}') + + +class WrapperOp(nn.Module): + def __init__(self, op_choice, input_args): + super(WrapperOp, self).__init__() + self.op_choice = op_choice + self.input_args = input_args + self.op = None + + def conv_shortcut(kernel_size, hidden_units, cnn_keep_prob): + return ConvBN(kernel_size, hidden_units, hidden_units, + cnn_keep_prob, False, True) + + if op_choice == 'conv_shortcut1': + self.op = conv_shortcut(*input_args) + elif op_choice == 'conv_shortcut3': + self.op = conv_shortcut(*input_args) + elif op_choice == 'conv_shortcut5': + self.op = conv_shortcut(*input_args) + elif op_choice == 'conv_shortcut7': + self.op = conv_shortcut(*input_args) + elif op_choice == 'AvgPool': + self.op = AvgPool(3, False, True) + elif op_choice == 'MaxPool': + self.op = MaxPool(3, False, True) + elif op_choice == 'RNN': + self.op = RNN(*input_args) + elif op_choice == 'Attention': + self.op = Attention(*input_args) + else: + raise + + def forward(self, prec, mask): + return self.op(prec, mask) + + +class Layer(nn.Module): + def __init__(self, key, prev_keys, hidden_units, choose_from_k, cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask): + super(Layer, self).__init__() + + self.n_candidates = len(prev_keys) + if self.n_candidates: + #===self.prec = mutables.InputChoice(choose_from=prev_keys[-choose_from_k:], n_chosen=1) + self.prec = 1 + else: + # first layer, skip input choice + self.prec = None + '''self.op = mutables.LayerChoice([ + conv_shortcut(1), + conv_shortcut(3), + conv_shortcut(5), + conv_shortcut(7), + AvgPool(3, False, True), + MaxPool(3, False, True), + RNN(hidden_units, lstm_keep_prob), + Attention(hidden_units, 4, att_keep_prob, att_mask) + ])''' + #self.op = conv_shortcut(1) + #self.op = Attention(hidden_units, 4, att_keep_prob, att_mask) + #self.op = RNN(hidden_units, lstm_keep_prob) + #self.op = WrapperOp('RNN', [hidden_units, lstm_keep_prob]) + #self.op = WrapperOp('Attention', [hidden_units, 4, att_keep_prob, att_mask]) + #self.op = WrapperOp('MaxPool', [3, False, True]) + #self.op = WrapperOp('AvgPool', [3, False, True]) + #self.op = WrapperOp('conv_shortcut7', [7, hidden_units, cnn_keep_prob]) + #self.op = WrapperOp('conv_shortcut5', [5, hidden_units, cnn_keep_prob]) + #self.op = WrapperOp('conv_shortcut3', [3, hidden_units, cnn_keep_prob]) + self.op = WrapperOp('conv_shortcut1', [1, hidden_units, cnn_keep_prob]) + if self.n_candidates: + #===self.skipconnect = mutables.InputChoice(choose_from=prev_keys) + self.skipconnect = 1 + else: + self.skipconnect = None + self.bn = BatchNorm(hidden_units, False, True) + + self.prec_n_candidates = choose_from_k + self.skip_n_candidates = len(prev_keys) + + def forward(self, last_layer, prev_layers, mask): + # pass an extra last_layer to deal with layer 0 (prev_layers is empty) + if self.prec is None: + prec = last_layer + else: + #===prec = self.prec(prev_layers[-self.prec.n_candidates:]) # skip first + x = min(len(prev_layers), self.prec_n_candidates) + prec = prev_layers[-x] # skip first + out = self.op(prec, mask) + if self.skipconnect is not None: + #===connection = self.skipconnect(prev_layers[-self.skipconnect.n_candidates:]) + connection = prev_layers[-self.skip_n_candidates] + if connection is not None: + out = out + connection + out = self.bn(out, mask) + return out + + +class Model(nn.Module): + def __init__(self, embedding_dim=768, hidden_units=256, num_layers=24, num_classes=5, choose_from_k=5, + lstm_keep_prob=0.5, cnn_keep_prob=0.5, att_keep_prob=0.5, att_mask=True, + embed_keep_prob=0.5, final_output_keep_prob=1.0, global_pool="avg"): + super(Model, self).__init__() + + # self.embedding = nn.Embedding.from_pretrained(embedding, freeze=False) + self.hidden_units = hidden_units + self.num_layers = num_layers + self.num_classes = num_classes + + self.init_conv = ConvBN(1, embedding_dim, hidden_units, cnn_keep_prob, False, True) + + self.layers = nn.ModuleList() + candidate_keys_pool = [] + for layer_id in range(self.num_layers): + k = "layer_{}".format(layer_id) + self.layers.append(Layer(k, candidate_keys_pool, hidden_units, choose_from_k, + cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask)) + candidate_keys_pool.append(k) + + self.linear_combine = LinearCombine(self.num_layers) + self.linear_out = nn.Linear(self.hidden_units, self.num_classes) + + self.embed_dropout = nn.Dropout(p=1 - embed_keep_prob) + self.output_dropout = nn.Dropout(p=1 - final_output_keep_prob) + + assert global_pool in ["max", "avg"] + if global_pool == "max": + self.global_pool = GlobalMaxPool() + elif global_pool == "avg": + self.global_pool = GlobalAvgPool() + + self.criterion = torch.nn.CrossEntropyLoss() + + def forward(self, inputs, mask, labels): + # sent_ids, mask = inputs + # seq = self.embedding(sent_ids.long()) + seq = self.embed_dropout(inputs) + + seq = torch.transpose(seq, 1, 2) # from (N, L, C) -> (N, C, L) + + x = self.init_conv(seq, mask) + prev_layers = [] + + for layer in self.layers: + x = layer(x, prev_layers, mask) + prev_layers.append(x) + + x = self.linear_combine(torch.stack(prev_layers)) + x = self.global_pool(x, mask) + x = self.output_dropout(x) + x = self.linear_out(x) + loss = self.criterion(x, labels) + return loss + + +if __name__ == '__main__': + + # initialize models + num_model = _model_divisions[DeviceGroup().rank] + print_each_rank(f'initializing {num_model} models...') + models = [Model().cuda() for _ in range(num_model)] + + # initialize dataloaders + if args.schedule == 'replicate': + dataloader = SharedDataLoader(args.bs, replicate=True) + elif args.schedule == 'pipe': + dataloader = SharedDataLoader(args.bs, replicate=False) + else: + assert False + dataloader = iter(dataloader) + + # initialize optimizer + optimizers = [ + torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) for model in models + ] + + CudaTimer(enable=False) + torch.distributed.barrier() + iter_num = 32 + for step in range(iter_num): + if step >= 8: + CudaTimer(enable=True).start('e2e') + + text, masks, labels = next(dataloader) + for model, optimizer in zip(models, optimizers): + CudaTimer().start('nas-model') + loss = model(text, masks, labels) + loss.backward() + optimizer.step() + optimizer.zero_grad() + CudaTimer().stop('nas-model') + + if step >= 8: + CudaTimer().stop('e2e') + + if step == 0: + torch.distributed.barrier() + print_each_rank('memory after optimizer:', rank_only=0) + memory_summary() + + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-8, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-8) + memory_summary() From 359b27a6cc918cebc91ffb5650ac20994c7a8d99 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 10 Apr 2022 12:00:22 +0000 Subject: [PATCH 0773/1892] add non-uniform partition --- handcraft/textnas/dataset.sh | 0 handcraft/textnas/train.py | 37 ++++++++++++++++++++++++++++++------ 2 files changed, 31 insertions(+), 6 deletions(-) mode change 100644 => 100755 handcraft/textnas/dataset.sh diff --git a/handcraft/textnas/dataset.sh b/handcraft/textnas/dataset.sh old mode 100644 new mode 100755 diff --git a/handcraft/textnas/train.py b/handcraft/textnas/train.py index 409349cd..fa8b7d72 100644 --- a/handcraft/textnas/train.py +++ b/handcraft/textnas/train.py @@ -19,6 +19,7 @@ from handcraft.textnas.ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm, GlobalMaxPool, GlobalAvgPool from handcraft.textnas.dataloader import SharedDataLoader +from handcraft.module.stage import layer_division cube.init() @@ -29,7 +30,9 @@ parser.add_argument('--models', type=int, default=1, help='number of models to be trained in total') parser.add_argument('--bs', type=int, default=128, - help='num of micro batch') + help='number of micro batch (default: paper setting)') +parser.add_argument('--non-uniform', action='store_true', default=False, + help='use non-uniform partition that Bert-allocated GPU can also have models') args = parser.parse_args() print(args) @@ -43,10 +46,16 @@ _model_divisions[-1-idx] += 1 if args.schedule == 'pipe': num_trainers = DeviceGroup().world_size - 1 - num_model_per_device = args.models // num_trainers - _model_divisions = [0] + [num_model_per_device] * num_trainers - for idx in range(args.models % num_trainers): - _model_divisions[-1-idx] += 1 + if args.non_uniform: + times = [160.65] + [78.79] * args.models + _model_divisions = layer_division(times, DeviceGroup().world_size) + _model_divisions = [end-start for start, end in _model_divisions] + _model_divisions[0] -= 1 + else: + num_model_per_device = args.models // num_trainers + _model_divisions = [0] + [num_model_per_device] * num_trainers + for idx in range(args.models % num_trainers): + _model_divisions[-1-idx] += 1 print_each_rank(f'model number placements: {_model_divisions}') @@ -228,7 +237,12 @@ def forward(self, inputs, mask, labels): for step in range(iter_num): if step >= 8: CudaTimer(enable=True).start('e2e') - + # if args.schedule == 'replicate': + # # retiarii baseline + # for _ in range(len(models)): + # text, masks, labels = next(dataloader) + # else: + # text, masks, labels = next(dataloader) text, masks, labels = next(dataloader) for model, optimizer in zip(models, optimizers): CudaTimer().start('nas-model') @@ -237,6 +251,17 @@ def forward(self, inputs, mask, labels): optimizer.step() optimizer.zero_grad() CudaTimer().stop('nas-model') + + # CudaTimer().start('nas-model') + # losses = [] + # for model in models: + # losses.append(model(text, masks, labels)) + # for loss in losses: + # loss.backward() + # for optimizer in optimizers: + # optimizer.step() + # optimizer.zero_grad() + # CudaTimer().stop('nas-model') if step >= 8: CudaTimer().stop('e2e') From c4e52792ded9d59397389749552f0794017a2494 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 10 Apr 2022 13:33:19 +0000 Subject: [PATCH 0774/1892] longer test --- handcraft/textnas/train.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/handcraft/textnas/train.py b/handcraft/textnas/train.py index fa8b7d72..44b6e89d 100644 --- a/handcraft/textnas/train.py +++ b/handcraft/textnas/train.py @@ -47,7 +47,7 @@ if args.schedule == 'pipe': num_trainers = DeviceGroup().world_size - 1 if args.non_uniform: - times = [160.65] + [78.79] * args.models + times = [160] + [80] * args.models _model_divisions = layer_division(times, DeviceGroup().world_size) _model_divisions = [end-start for start, end in _model_divisions] _model_divisions[0] -= 1 @@ -224,7 +224,6 @@ def forward(self, inputs, mask, labels): dataloader = SharedDataLoader(args.bs, replicate=False) else: assert False - dataloader = iter(dataloader) # initialize optimizer optimizers = [ @@ -233,9 +232,10 @@ def forward(self, inputs, mask, labels): CudaTimer(enable=False) torch.distributed.barrier() - iter_num = 32 + dataloader = iter(dataloader) + iter_num = 64 for step in range(iter_num): - if step >= 8: + if step >= 16: CudaTimer(enable=True).start('e2e') # if args.schedule == 'replicate': # # retiarii baseline @@ -263,7 +263,7 @@ def forward(self, inputs, mask, labels): # optimizer.zero_grad() # CudaTimer().stop('nas-model') - if step >= 8: + if step >= 16: CudaTimer().stop('e2e') if step == 0: @@ -275,6 +275,6 @@ def forward(self, inputs, mask, labels): print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-8, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-8) + CudaTimer().duration(iter_num-16, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-16) memory_summary() From 05659990936ce97969699ecd687984c196bba9ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 12 Apr 2022 08:20:17 +0000 Subject: [PATCH 0775/1892] add feature loading parital tensor from local --- cube/graph/adapter/adapter.py | 36 +++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index 74d50129..b2281dc3 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -352,10 +352,38 @@ def gen_select(dst_tensor): remote.append(ptensor) # first check local in tensor - if otensor in local: - intersections.append(otensor) - inputs.append(otensor) - return inputs, intersections, prims + for tensor in local: + common = tensor.common(otensor) + if tensor == otensor: + intersections.append(tensor) + inputs.append(tensor) + return inputs, intersections, prims + elif common == otensor: + # index map + indmap = list() + islicers = tensor.indmap.get() + oslicers = common.indmap.get() + for islicer, oslicer in zip(islicers, oslicers): + istart, istop, istep = islicer.start, islicer.stop, islicer.step + ostart, ostop, ostep = oslicer.start, oslicer.stop, oslicer.step + # relative offset + start = ostart - istart + stop = start + ostop - ostart + slicer = slice(start, stop, ostep) + indmap.append(slicer) + # value map must be same + if tensor.valmap != common.valmap: + break + valmap = ValueMap(0, 1) + prim = SelectPrim(tensor, indmap, valmap, common.shape, common) + prims.append(prim) + intersections.append(otensor) + inputs.append(tensor) + return inputs, intersections, prims + # if otensor in local: + # intersections.append(otensor) + # inputs.append(otensor) + # return inputs, intersections, prims # check local + remote for itensor in local_and_remote: #local + remote: From fccdec0ca85f660f0d595cf20a330150af7be5f0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 13 Apr 2022 06:43:00 +0000 Subject: [PATCH 0776/1892] init bigbird --- handcraft/bigbird/sparse_attn.py | 970 +++++++++++++++++++++++++++++++ 1 file changed, 970 insertions(+) create mode 100644 handcraft/bigbird/sparse_attn.py diff --git a/handcraft/bigbird/sparse_attn.py b/handcraft/bigbird/sparse_attn.py new file mode 100644 index 00000000..ebae5566 --- /dev/null +++ b/handcraft/bigbird/sparse_attn.py @@ -0,0 +1,970 @@ +""" +BigBird paper +https://papers.nips.cc/paper/2020/file/c8512d142a2d849725f31a9a7a361ab9-Paper.pdf + +Understanding blog: +https://github.com/huggingface/blog/blob/main/big-bird.md +""" + +import torch +import cube + +class Config: + + num_attention_heads = 32 + hidden_size = 4096 + all_head_sie = hidden_size + max_position_embeddings = 4096 # seqlen + num_random_blocks = 3 + block_size=4 + use_bias = True + +config = Config() + + +class BigBirdBlockSparseAttention(nn.Module): + def __init__(self, seed=None): + super().__init__() + + self.max_seqlen = config.max_position_embeddings + self.seed = seed + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.num_random_blocks = config.num_random_blocks + self.block_size = config.block_size + + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + output_attentions=None, + ): + # Currently this `class` can't be used in decoder. + + batch_size, seqlen, _ = hidden_states.size() + to_seq_length = from_seq_length = seqlen + from_block_size = to_block_size = self.block_size + + assert from_seq_length % from_block_size == 0, "Query sided sequence length must be multiple of block size" + assert to_seq_length % to_block_size == 0, "Key/Value sided sequence length must be multiple of block size" + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + context_layer, attention_probs = self.bigbird_block_sparse_attention( + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + self.num_attention_heads, + self.num_random_blocks, + self.attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_length, + to_seq_length, + seed=self.seed, + plan_from_length=None, + plan_num_rand_blocks=None, + output_attentions=output_attentions, + ) + + context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + @staticmethod + def torch_bmm_nd(inp_1, inp_2, ndim=None): + """Fast nd matrix multiplication""" + # faster replacement of torch.einsum ("bhqk,bhkd->bhqd") + return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view( + inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1]) + ) + + @staticmethod + def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None): + """Fast nd matrix multiplication with transpose""" + # faster replacement of torch.einsum (bhqd,bhkd->bhqk) + return torch.bmm( + inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2) + ).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2])) + + def bigbird_block_sparse_attention( + self, + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + n_heads, + n_rand_blocks, + attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_len, + to_seq_len, + seed, + plan_from_length, + plan_num_rand_blocks, + output_attentions, + ): + + # BigBird block-sparse attention as suggested in paper + + # ITC: + # global tokens: 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # ETC: + # global tokens: extra_globals_tokens + 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # Note: + # 1) Currently, ETC is not supported. + # 2) Window size is fixed to 3 blocks & it can be changed only by + # changing `block_size`. + # 3) Number of global blocks are fixed (2 blocks here) & global tokens can be + # controlled only by `block_size`. + + # attention is calculated separately for q[0], q[1], q[2:-2], q[-2], q[-1] in order to use special trick of shifting tokens (for calculating sliding attention) + # hence following code can be divided into 5 parts. + + if from_seq_len // from_block_size != to_seq_len // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + + rsqrt_d = 1 / math.sqrt(attention_head_size) + bsz = batch_size + attn_mask_penalty = -10000.0 + + # generate random attention and corresponding masks + np.random.seed(seed) + if from_seq_len in [1024, 3072, 4096]: # old plans used in paper + rand_attn = [ + self._bigbird_block_rand_mask( + self.max_seqlen, self.max_seqlen, from_block_size, to_block_size, n_rand_blocks, last_idx=1024 + )[: (from_seq_len // from_block_size - 2)] + for _ in range(n_heads) + ] + else: + if plan_from_length is None: + plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan( + from_seq_len, from_block_size, n_rand_blocks + ) + + rand_attn = self._bigbird_block_rand_mask_with_head( + from_seq_length=from_seq_len, + to_seq_length=to_seq_len, + from_block_size=from_block_size, + to_block_size=to_block_size, + num_heads=n_heads, + plan_from_length=plan_from_length, + plan_num_rand_blocks=plan_num_rand_blocks, + ) + + rand_attn = np.stack(rand_attn, axis=0) + rand_attn = torch.tensor(rand_attn, device=query_layer.device, dtype=torch.long) + rand_attn.unsqueeze_(0) + rand_attn = torch.cat([rand_attn for _ in range(batch_size)], dim=0) + + rand_mask = self._create_rand_mask_from_inputs( + from_blocked_mask, to_blocked_mask, rand_attn, n_heads, n_rand_blocks, bsz, from_seq_len, from_block_size + ) + + blocked_query_matrix = query_layer.view(bsz, n_heads, from_seq_len // from_block_size, from_block_size, -1) + blocked_key_matrix = key_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + blocked_value_matrix = value_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + + # preparing block for randn attn + gathered_key = self.torch_gather_b2(blocked_key_matrix, rand_attn) + gathered_key = gathered_key.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + gathered_value = self.torch_gather_b2(blocked_value_matrix, rand_attn) + gathered_value = gathered_value.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + + # 1st PART + # 1st block (global block) attention scores + # q[0] x (k[0], k[1], k[2], k[3], k[4] .... ) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + first_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 0], key_layer, ndim=4) + + first_product = first_product * rsqrt_d + first_product += (1.0 - to_mask) * attn_mask_penalty + first_attn_weights = nn.functional.softmax( + first_product, dim=-1 + ) # [bsz, n_heads, from_block_size, to_seq_len] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4) + first_context_layer.unsqueeze_(2) + + # 2nd PART + # 2nd block attention scores + # q[1] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> 2nd, 3rd blocks + # global key blocks -> 1st block + + second_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, 1], + blocked_key_matrix[:, :, 2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + second_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, 1], + blocked_value_matrix[:, :, 2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 1], second_key_mat, ndim=4) + second_seq_pad = torch.cat( + [ + to_mask[:, :, :, : 3 * to_block_size], + to_mask[:, :, :, -to_block_size:], + to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_rand_pad = torch.cat( + [ + rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, 0], + ], + dim=3, + ) + second_product = second_product * rsqrt_d + second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty + second_attn_weights = nn.functional.softmax( + second_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_context_layer = self.torch_bmm_nd(second_attn_weights, second_value_mat, ndim=4) + + second_context_layer.unsqueeze_(2) + + # 3rd PART + # Middle blocks attention scores + # q[-2:2] x (sliding_keys, random_keys, global_keys) + # sliding attn is calculated using special trick of shifting tokens as discussed in paper + # random keys are generated by taking random indices as per `rand_attn` + # global keys -> 1st & last block + + exp_blocked_key_matrix = torch.cat( + [blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1]], dim=3 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + exp_blocked_value_matrix = torch.cat( + [blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1]], + dim=3, + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + middle_query_matrix = blocked_query_matrix[:, :, 2:-2] + + # sliding attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [b, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + inner_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, exp_blocked_key_matrix, ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, 3*to_block_size] + inner_band_product = inner_band_product * rsqrt_d + + # randn attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + rand_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, gathered_key[:, :, 1:-1], ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] + rand_band_product = rand_band_product * rsqrt_d + + # Including 1st block (since it's global) + first_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + first_band_product = first_band_product * rsqrt_d + + # Including last block (since it's global) + last_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + last_band_product = last_band_product * rsqrt_d + + # masking padded tokens + inner_band_product += (1.0 - band_mask) * attn_mask_penalty + first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * attn_mask_penalty + last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * attn_mask_penalty + rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty + + # completing attention scores matrix for all q[-2:2] + band_product = torch.cat( + [first_band_product, inner_band_product, rand_band_product, last_band_product], dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # safely doing softmax since attention matrix is completed + attn_weights = nn.functional.softmax( + band_product, dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # contribution of sliding keys + # [bsz, n_heads, m//from_block_size-4, from_block_size, 3*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + context_layer = self.torch_bmm_nd( + attn_weights[:, :, :, :, to_block_size : 4 * to_block_size], exp_blocked_value_matrix, ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of random keys + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + context_layer += self.torch_bmm_nd( + attn_weights[:, :, :, :, 4 * to_block_size : -to_block_size], gathered_value[:, :, 1:-1], ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of global keys + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, :to_block_size], blocked_value_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, -to_block_size:], blocked_value_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # 4th PART + # last 2nd token attention scores + # q[-2] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> last 3 blocks + # global key block -> 1st block + # random key block -> based on indices stored in `randn_attn` + + second_last_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, -3], + blocked_key_matrix[:, :, -2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+n_random_blocks)*to_block_size, -1] + second_last_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, -3], + blocked_value_matrix[:, :, -2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+r)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -2], second_last_key_mat, ndim=4) + second_last_seq_pad = torch.cat( + [ + to_mask[:, :, :, :to_block_size], + to_mask[:, :, :, -3 * to_block_size :], + to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_last_rand_pad = torch.cat( + [ + rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, -1], + ], + dim=3, + ) + second_last_product = second_last_product * rsqrt_d + second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty + second_last_attn_weights = nn.functional.softmax( + second_last_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_last_context_layer = self.torch_bmm_nd(second_last_attn_weights, second_last_value_mat, ndim=4) + second_last_context_layer.unsqueeze_(2) + + # 5th PART + # last block (global) attention scores + # q[-1] x (k[0], k[1], k[2], k[3], .... ) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) + last_product = last_product * rsqrt_d + last_product += (1.0 - to_mask) * attn_mask_penalty + last_attn_weights = nn.functional.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4) + last_context_layer.unsqueeze_(2) + + # combining representations of all tokens + context_layer = torch.cat( + [first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer], + dim=2, + ) + context_layer = context_layer.view((bsz, n_heads, from_seq_len, -1)) * from_mask + context_layer = torch.transpose(context_layer, 1, 2) + + # this is just for visualizing; forward pass doesn't depend on following code + if output_attentions: + # TODO(PVP): need to verify if below code is correct + attention_probs = torch.zeros( + bsz, n_heads, from_seq_len, to_seq_len, dtype=torch.float, device=context_layer.device + ) + + # 1st query block + # corresponding to `first_context_layer` + attention_probs[:, :, :from_block_size, :] = first_attn_weights # all keys global + + # 2nd query block + # corresponding to `second_context_layer` + attention_probs[:, :, from_block_size : 2 * from_block_size, : 3 * to_block_size] = second_attn_weights[ + :, :, :, : 3 * to_block_size + ] # 1st three key blocks (global + sliding) + attention_probs[:, :, from_block_size : 2 * from_block_size, -to_block_size:] = second_attn_weights[ + :, :, :, 3 * to_block_size : 4 * to_block_size + ] # last key block (global) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, 1, :, i2[0]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Middle query blocks + # corresponding to `context_layer` + # sliding keys + for q_idx in range(from_seq_len // from_block_size - 4): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + )[:, :, 2:-2, :, 1:-1, :] + right_slice = attn_weights[:, :, q_idx, :, to_block_size : 4 * to_block_size] + attn_probs_view[:, :, q_idx, :, q_idx : q_idx + 3, :] = right_slice.view( + bsz, n_heads, from_block_size, 3, to_block_size + ) # inner_band_product + # global keys (corresponding to 1st key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, :to_block_size] = attn_weights[ + :, :, :, :, :to_block_size + ].view( + bsz, n_heads, -1, to_block_size + ) # first_band_product + # global keys (corresponding to last key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, -to_block_size:] = attn_weights[ + :, :, :, :, -to_block_size: + ].view( + bsz, n_heads, -1, to_block_size + ) # last_band_product + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + for q_idx in range(1, len(i2) - 1): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[q_idx - 1, :, 4 * to_block_size : -to_block_size] + attn_probs_view[p1, p2, q_idx + 1, :, i2[q_idx]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Second-last query block + # corresponding to `second_last_context_layer` + attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[ + :, :, :, :to_block_size + ] # 1st key block (global) + attention_probs[ + :, :, -2 * from_block_size : -from_block_size, -3 * to_block_size : + ] = second_last_attn_weights[ + :, :, :, to_block_size : 4 * to_block_size + ] # last three blocks (global + sliding) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, -2, :, i2[-1]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # last query block + # corresponding to `last_context_layer` + attention_probs[:, :, -from_block_size:, :] = last_attn_weights # all keys global + + else: + attention_probs = None + + return context_layer, attention_probs + + @staticmethod + def torch_gather_b2(params, indices): + # this operation is equivalent to tf.gather when batch_dims=2 + + if params.shape[:2] != indices.shape[:2]: + raise ValueError( + f"Make sure that the first two dimensions of params and indices are identical, \ + but they are params: {params.shape[:2]} vs. indices: {params.shape[:2]}" + ) + num_indices_to_gather = indices.shape[-2] * indices.shape[-1] + num_indices_to_pick_from = params.shape[2] + + indices_shift = ( + torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device) + // num_indices_to_gather + * num_indices_to_pick_from + ) + + flattened_indices = indices.view(-1) + indices_shift + flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1]) + + out_flattened = flattened_params.index_select(0, flattened_indices) + + out = out_flattened.reshape(params.shape[:2] + (num_indices_to_gather,) + params.shape[3:]) + return out + + @staticmethod + def _create_rand_mask_from_inputs( + from_blocked_mask, + to_blocked_mask, + rand_attn, + num_attention_heads, + num_rand_blocks, + batch_size, + from_seq_length, + from_block_size, + ): + """ + Create 3D attention mask from a 2D tensor mask. + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + rand_attn: [batch_size, num_attention_heads, + from_seq_length//from_block_size-2, num_rand_blocks] + num_attention_heads: int. Number of attention heads. + num_rand_blocks: int. Number of random chunks per row. + batch_size: int. Batch size for computation. + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + Returns: + float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2, + from_block_size, num_rand_blocks*to_block_size]. + """ + num_windows = from_seq_length // from_block_size - 2 + rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]) + rand_mask = rand_mask.view(batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size) + rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask) + return rand_mask + + @staticmethod + def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks): + """ + Gives the plan of where to put random attention. + Args: + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + num_rand_blocks: int. Number of random chunks per row. + Returns: + plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for + each block + """ + + plan_from_length = [] + plan_num_rand_blocks = [] + if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(0) + elif (num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks // 2) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2)) + else: + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks) + + return plan_from_length, plan_num_rand_blocks + + @staticmethod + def _bigbird_block_rand_mask( + from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1 + ): + """ + Create adjacency list of random attention. + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_rand_blocks: int. Number of random chunks per row. + last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence, + if positive then num_rand_blocks blocks chosen only up to last_idx. + Returns: + adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks + """ + # using this method when from_seq_length in [1024, 3072, 4096] + + assert ( + from_seq_length // from_block_size == to_seq_length // to_block_size + ), "Error the number of blocks needs to be same!" + + rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32) + middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32) + last = to_seq_length // to_block_size - 1 + if last_idx > (2 * to_block_size): + last = (last_idx // to_block_size) - 1 + + r = num_rand_blocks # shorthand + for i in range(1, from_seq_length // from_block_size - 1): + start = i - 2 + end = i + if i == 1: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r] + elif i == 2: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r] + elif i == from_seq_length // from_block_size - 3: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] + # Missing -3: should have been sliced till last-3 + elif i == from_seq_length // from_block_size - 2: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] + # Missing -4: should have been sliced till last-4 + else: + if start > last: + start = last + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] + elif (end + 1) == last: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] + else: + rand_attn[i - 1, :] = np.random.permutation( + np.concatenate((middle_seq[:start], middle_seq[end + 1 : last])) + )[:r] + return rand_attn + + def _bigbird_block_rand_mask_with_head( + self, + from_seq_length, + to_seq_length, + from_block_size, + to_block_size, + num_heads, + plan_from_length, + plan_num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_top=1, + global_block_bottom=1, + global_block_left=1, + global_block_right=1, + ): + """ + Create adjacency list of random attention. + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_heads: int. total number of heads. + plan_from_length: list. plan from length where num_random_blocks are chosen from. + plan_num_rand_blocks: list. number of rand blocks within the plan. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_top: int. number of blocks at the top. + global_block_bottom: int. number of blocks at the bottom. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + Returns: + adjacency list of size num_head where each element is of size from_seq_length//from_block_size-2 by + num_rand_blocks + """ + # using this method when from_seq_length not in [1024, 3072, 4096] + + assert ( + from_seq_length // from_block_size == to_seq_length // to_block_size + ), "Error the number of blocks needs to be same!" + + assert from_seq_length in plan_from_length, "Error from sequence length not in plan!" + + # Total number of blocks in the mmask + num_blocks = from_seq_length // from_block_size + # Number of blocks per plan + plan_block_length = np.array(plan_from_length) // from_block_size + # till when to follow plan + max_plan_idx = plan_from_length.index(from_seq_length) + # Random Attention adjacency list + rand_attn = [ + np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32) + for i in range(num_heads) + ] + + # We will go iteratively over the plan blocks and pick random number of + # Attention blocks from the legally allowed blocks + for plan_idx in range(max_plan_idx + 1): + rnd_r_cnt = 0 + if plan_idx > 0: + # set the row for all from_blocks starting from 0 to + # plan_block_length[plan_idx-1] + # column indx start fromm plan_block_length[plan_idx-1] and ends at + # plan_block_length[plan_idx] + if plan_num_rand_blocks[plan_idx] > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=plan_block_length[plan_idx - 1], + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for pl_id in range(plan_idx): + if plan_num_rand_blocks[pl_id] == 0: + continue + for blk_rw_idx in range(plan_block_length[plan_idx - 1], plan_block_length[plan_idx]): + rnd_r_cnt = 0 + to_start_block_id = 0 + if pl_id > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id])) + to_start_block_id = plan_block_length[pl_id - 1] + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: pl_id + 1])) + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[pl_id], + num_rand_blocks=plan_num_rand_blocks[pl_id], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + if plan_num_rand_blocks[plan_idx] == 0: + continue + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + from_start_block_id = global_block_top + to_start_block_id = 0 + if plan_idx > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + from_start_block_id = plan_block_length[plan_idx - 1] + to_start_block_id = plan_block_length[plan_idx - 1] + + for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for nh in range(num_heads): + rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] + + return rand_attn + + @staticmethod + def _get_single_block_row_attention( + block_id, + to_start_block_id, + to_end_block_id, + num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_left=1, + global_block_right=1, + ): + """ + For a single row block get random row attention. + Args: + block_id: int. block id of row. + to_start_block_id: int. random attention column start id. + to_end_block_id: int. random attention column end id. + num_rand_blocks: int. number of random blocks to be selected. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + Returns: + row containing the random attention vector of size num_rand_blocks. + """ + # list of to_blocks from which to choose random attention + to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32) + # permute the blocks + perm_block = np.random.permutation(to_block_list) + + # illegal blocks for the current block id, using window + illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1)) + + # Add blocks at the start and at the end + illegal_blocks.extend(list(range(global_block_left))) + illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id))) + + # The second from_block cannot choose random attention on second last to_block + if block_id == 1: + illegal_blocks.append(to_end_block_id - 2) + + # The second last from_block cannot choose random attention on second to_block + if block_id == to_end_block_id - 2: + illegal_blocks.append(1) + + selected_random_blokcs = [] + + for i in range(to_end_block_id - to_start_block_id): + if perm_block[i] not in illegal_blocks: + selected_random_blokcs.append(perm_block[i]) + if len(selected_random_blokcs) == num_rand_blocks: + break + return np.array(selected_random_blokcs, dtype=np.int32) + + +# class BigBirdDataLoader(cube.runtime.syndata.CubeDataLoader): +# +# def __init__(self, batch_size: int): +# self.bs = batch_size +# super().__init__( +# shapes=( +# [batch_size, ] +# ) +# ) + + + +if __name__ == '__main__': + + model = BigBirdBlockSparseAttention() + nparams = sum([param.numel() for param in model.parameters()]) + print_each_rank(f'model params (M): {nparams / 1e6}. Launching model...') + model = model.half().cuda() if args.fp16 else model.cuda() + + dataloader = GPT3DataLoader(args.micro_bs) + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + print_each_rank('model weight consumpition:') + memory_summary() + + CudaTimer(enable=False) + torch.distributed.barrier() + iter_num = 6 + for step in range(iter_num): + if step >= 2: + CudaTimer(enable=True).start('e2e') + + # train 1 step + num_microbatch = args.bs // (args.micro_bs * args.dp_size) + if args.pp_size > 1: + _schedule(model, dataloader, num_microbatch) + else: + for _ in range(num_microbatch): + model.data = next(dataloader) + loss = model() + loss.backward() + + if _pp_embed_reducer is not None: + _pp_embed_reducer.allreduce() + + if _dp_reducer is not None: + _dp_reducer.allreduce() + + optimizer.step() + optimizer.zero_grad() + + if step >= 2: + CudaTimer().stop('e2e') + + torch.cuda.empty_cache() + torch.distributed.barrier() + + if step == 0: + print_each_rank('memory after optimizer:', rank_only=0) + memory_summary() + + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-2, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-2) + memory_summary() \ No newline at end of file From 72af2146e295880478fecf3438ba18ec9322d92f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 13 Apr 2022 12:29:08 +0000 Subject: [PATCH 0777/1892] sparse attention --- handcraft/bigbird/sparse_attn.py | 344 +++++++++++++++---------------- 1 file changed, 171 insertions(+), 173 deletions(-) diff --git a/handcraft/bigbird/sparse_attn.py b/handcraft/bigbird/sparse_attn.py index ebae5566..7282ac34 100644 --- a/handcraft/bigbird/sparse_attn.py +++ b/handcraft/bigbird/sparse_attn.py @@ -4,29 +4,104 @@ Understanding blog: https://github.com/huggingface/blog/blob/main/big-bird.md + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/bigbird/sparse_attn.py \ + --hidden-size 4096 --heads 32 --seqlen 4096 \ + --bs 8 --fp16 """ import torch +import torch.nn as nn import cube +import math +import numpy as np + +import argparse +from cube.runtime.device import DeviceGroup +from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary +from cube.profiler.timer import print_each_rank + + + +parser = argparse.ArgumentParser(description='sparse_attention') + +parser.add_argument('--hidden-size', type=int, default=4096, + help='hidden size') +parser.add_argument('--heads', type=int, default=32, + help='number of heads') +parser.add_argument('--seqlen', type=int, default=3096, + help='sequence length') +parser.add_argument('--blk-size', type=int, default=64, + help='sequence length') +# training config +parser.add_argument('--bs', type=int, default=256, + help='num of micro batch') +parser.add_argument('--fp16', action='store_true', default=False) +parser.add_argument('--sparse', action='store_true', default=False) +args = parser.parse_args() +print(args) +cube.init() + class Config: - num_attention_heads = 32 - hidden_size = 4096 + num_attention_heads = args.heads + hidden_size = args.hidden_size all_head_sie = hidden_size - max_position_embeddings = 4096 # seqlen + seqlen = args.seqlen # seqlen num_random_blocks = 3 - block_size=4 + block_size=args.blk_size use_bias = True config = Config() +def create_mask(): + batch_size = args.bs + seq_length = config.seqlen + block_size = config.block_size + attention_mask = torch.ones(((batch_size, seq_length)), device=torch.cuda.current_device()) + assert ( + seq_length % block_size == 0 + ), f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block size is {block_size}." + + def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): + """ + Create 3D attention mask from a 2D tensor mask. + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + Returns: + float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, + 3*to_block_size]. + """ + exp_blocked_to_pad = torch.cat( + [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], dim=2 + ) + band_mask = torch.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad) + band_mask.unsqueeze_(1) + return band_mask + + blocked_encoder_mask = attention_mask.view(batch_size, seq_length // block_size, block_size) + band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) + from_mask = attention_mask.view(batch_size, 1, seq_length, 1) + to_mask = attention_mask.view(batch_size, 1, 1, seq_length) + return blocked_encoder_mask, band_mask, from_mask, to_mask + + + class BigBirdBlockSparseAttention(nn.Module): def __init__(self, seed=None): super().__init__() - self.max_seqlen = config.max_position_embeddings + self.max_seqlen = config.seqlen self.seed = seed if config.hidden_size % config.num_attention_heads != 0: @@ -51,17 +126,10 @@ def transpose_for_scores(self, x): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward( - self, - hidden_states, - band_mask=None, - from_mask=None, - to_mask=None, - from_blocked_mask=None, - to_blocked_mask=None, - output_attentions=None, - ): + def forward(self, hidden_states): # Currently this `class` can't be used in decoder. + blocked_mask, band_mask, from_mask, to_mask = create_mask() + from_blocked_mask = to_blocked_mask = blocked_mask batch_size, seqlen, _ = hidden_states.size() to_seq_length = from_seq_length = seqlen @@ -74,7 +142,7 @@ def forward( key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - context_layer, attention_probs = self.bigbird_block_sparse_attention( + context_layer = self.bigbird_block_sparse_attention( query_layer, key_layer, value_layer, @@ -94,13 +162,10 @@ def forward( seed=self.seed, plan_from_length=None, plan_num_rand_blocks=None, - output_attentions=output_attentions, ) context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs + return context_layer @staticmethod def torch_bmm_nd(inp_1, inp_2, ndim=None): @@ -126,8 +191,8 @@ def bigbird_block_sparse_attention( band_mask, from_mask, to_mask, - from_blocked_mask, - to_blocked_mask, + from_blocked_mask, # same with blocked encoder mask + to_blocked_mask, # same with blocked encoder mask n_heads, n_rand_blocks, attention_head_size, @@ -139,7 +204,6 @@ def bigbird_block_sparse_attention( seed, plan_from_length, plan_num_rand_blocks, - output_attentions, ): # BigBird block-sparse attention as suggested in paper @@ -446,126 +510,7 @@ def bigbird_block_sparse_attention( context_layer = context_layer.view((bsz, n_heads, from_seq_len, -1)) * from_mask context_layer = torch.transpose(context_layer, 1, 2) - # this is just for visualizing; forward pass doesn't depend on following code - if output_attentions: - # TODO(PVP): need to verify if below code is correct - attention_probs = torch.zeros( - bsz, n_heads, from_seq_len, to_seq_len, dtype=torch.float, device=context_layer.device - ) - - # 1st query block - # corresponding to `first_context_layer` - attention_probs[:, :, :from_block_size, :] = first_attn_weights # all keys global - - # 2nd query block - # corresponding to `second_context_layer` - attention_probs[:, :, from_block_size : 2 * from_block_size, : 3 * to_block_size] = second_attn_weights[ - :, :, :, : 3 * to_block_size - ] # 1st three key blocks (global + sliding) - attention_probs[:, :, from_block_size : 2 * from_block_size, -to_block_size:] = second_attn_weights[ - :, :, :, 3 * to_block_size : 4 * to_block_size - ] # last key block (global) - # random keys - for p1, i1, w1 in zip(range(bsz), rand_attn, second_attn_weights): - # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch - for p2, i2, w2 in zip(range(n_heads), i1, w1): - # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads - attn_probs_view = attention_probs.view( - bsz, - n_heads, - from_seq_len // from_block_size, - from_block_size, - to_seq_len // to_block_size, - to_block_size, - ) - right_slice = w2[:, 4 * to_block_size :] - attn_probs_view[p1, p2, 1, :, i2[0]] = right_slice.view( - from_block_size, n_rand_blocks, to_block_size - ) - - # Middle query blocks - # corresponding to `context_layer` - # sliding keys - for q_idx in range(from_seq_len // from_block_size - 4): - attn_probs_view = attention_probs.view( - bsz, - n_heads, - from_seq_len // from_block_size, - from_block_size, - to_seq_len // to_block_size, - to_block_size, - )[:, :, 2:-2, :, 1:-1, :] - right_slice = attn_weights[:, :, q_idx, :, to_block_size : 4 * to_block_size] - attn_probs_view[:, :, q_idx, :, q_idx : q_idx + 3, :] = right_slice.view( - bsz, n_heads, from_block_size, 3, to_block_size - ) # inner_band_product - # global keys (corresponding to 1st key block) - attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, :to_block_size] = attn_weights[ - :, :, :, :, :to_block_size - ].view( - bsz, n_heads, -1, to_block_size - ) # first_band_product - # global keys (corresponding to last key block) - attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, -to_block_size:] = attn_weights[ - :, :, :, :, -to_block_size: - ].view( - bsz, n_heads, -1, to_block_size - ) # last_band_product - # random keys - for p1, i1, w1 in zip(range(bsz), rand_attn, attn_weights): - # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch - for p2, i2, w2 in zip(range(n_heads), i1, w1): - # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads - for q_idx in range(1, len(i2) - 1): - attn_probs_view = attention_probs.view( - bsz, - n_heads, - from_seq_len // from_block_size, - from_block_size, - to_seq_len // to_block_size, - to_block_size, - ) - right_slice = w2[q_idx - 1, :, 4 * to_block_size : -to_block_size] - attn_probs_view[p1, p2, q_idx + 1, :, i2[q_idx]] = right_slice.view( - from_block_size, n_rand_blocks, to_block_size - ) - - # Second-last query block - # corresponding to `second_last_context_layer` - attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[ - :, :, :, :to_block_size - ] # 1st key block (global) - attention_probs[ - :, :, -2 * from_block_size : -from_block_size, -3 * to_block_size : - ] = second_last_attn_weights[ - :, :, :, to_block_size : 4 * to_block_size - ] # last three blocks (global + sliding) - # random keys - for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights): - # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch - for p2, i2, w2 in zip(range(n_heads), i1, w1): - # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads - attn_probs_view = attention_probs.view( - bsz, - n_heads, - from_seq_len // from_block_size, - from_block_size, - to_seq_len // to_block_size, - to_block_size, - ) - right_slice = w2[:, 4 * to_block_size :] - attn_probs_view[p1, p2, -2, :, i2[-1]] = right_slice.view( - from_block_size, n_rand_blocks, to_block_size - ) - - # last query block - # corresponding to `last_context_layer` - attention_probs[:, :, -from_block_size:, :] = last_attn_weights # all keys global - - else: - attention_probs = None - - return context_layer, attention_probs + return context_layer @staticmethod def torch_gather_b2(params, indices): @@ -899,16 +844,74 @@ def _get_single_block_row_attention( break return np.array(selected_random_blokcs, dtype=np.int32) + @torch.no_grad() + def forward_dense(self, hidden_state): + N, L = hidden_state.size(0), hidden_state.size(1) + num_head = self.num_attention_heads + dim_head = self.attention_head_size + scale = 1 / math.sqrt(self.attention_head_size) + + # bs, seq, emb -> seq, bs, emb + hidden_state = hidden_state.transpose(0, 1) + q = self.query(hidden_state) + k = self.key(hidden_state) + v = self.value(hidden_state) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + mask = ones # torch.tril(ones) + mask = mask.view(N, 1, L, L) + mask = (mask < 0.5) + attn = attn.masked_fill_(mask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, 0.0, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + # output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + output = output.transpose(0, 1).contiguous() + return output + + + +class BigBirdDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + self.bs = batch_size + super().__init__( + shapes=( + [args.bs, config.seqlen, config.hidden_size], + ), + dtypes=(torch.float16 if args.fp16 else torch.float,), + batch_dims=(0,) + ) + self.samples = [self.random_sample()] -# class BigBirdDataLoader(cube.runtime.syndata.CubeDataLoader): -# -# def __init__(self, batch_size: int): -# self.bs = batch_size -# super().__init__( -# shapes=( -# [batch_size, ] -# ) -# ) + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + def random_sample(self): + hidden_state = torch.randn( + self.bs, config.seqlen, config.hidden_size, + dtype=torch.float16 if args.fp16 else torch.float, + device=torch.cuda.current_device() + ) + return hidden_state @@ -918,8 +921,9 @@ def _get_single_block_row_attention( nparams = sum([param.numel() for param in model.parameters()]) print_each_rank(f'model params (M): {nparams / 1e6}. Launching model...') model = model.half().cuda() if args.fp16 else model.cuda() + model.eval() - dataloader = GPT3DataLoader(args.micro_bs) + dataloader = BigBirdDataLoader(args.bs) optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) print_each_rank('model weight consumpition:') @@ -927,31 +931,25 @@ def _get_single_block_row_attention( CudaTimer(enable=False) torch.distributed.barrier() - iter_num = 6 + iter_num =32 for step in range(iter_num): - if step >= 2: + if step >= 8: CudaTimer(enable=True).start('e2e') # train 1 step - num_microbatch = args.bs // (args.micro_bs * args.dp_size) - if args.pp_size > 1: - _schedule(model, dataloader, num_microbatch) - else: - for _ in range(num_microbatch): - model.data = next(dataloader) - loss = model() - loss.backward() - - if _pp_embed_reducer is not None: - _pp_embed_reducer.allreduce() - - if _dp_reducer is not None: - _dp_reducer.allreduce() + # num_microbatch = 1 + with torch.no_grad(): + data = next(dataloader) + if args.sparse: + out = model(data) + else: + out = model.forward_dense(data) + # loss.backward() optimizer.step() optimizer.zero_grad() - if step >= 2: + if step >= 8: CudaTimer().stop('e2e') torch.cuda.empty_cache() @@ -961,10 +959,10 @@ def _get_single_block_row_attention( print_each_rank('memory after optimizer:', rank_only=0) memory_summary() - if (step + 1) % 2 == 0: + if (step + 1) % 8 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-2, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-2) + CudaTimer().duration(iter_num-8, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-8) memory_summary() \ No newline at end of file From 1d1dbf68911bf4bc30d3b1575da38d6ce80b628e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 14 Apr 2022 06:53:07 +0000 Subject: [PATCH 0778/1892] add handcraft sparse attention --- handcraft/bigbird/sparse.py | 426 ++++++++++++++++++++++++++++++++++++ 1 file changed, 426 insertions(+) create mode 100644 handcraft/bigbird/sparse.py diff --git a/handcraft/bigbird/sparse.py b/handcraft/bigbird/sparse.py new file mode 100644 index 00000000..7aa39f98 --- /dev/null +++ b/handcraft/bigbird/sparse.py @@ -0,0 +1,426 @@ +""" +BigBird paper +https://papers.nips.cc/paper/2020/file/c8512d142a2d849725f31a9a7a361ab9-Paper.pdf + +Understanding blog: +https://github.com/huggingface/blog/blob/main/big-bird.md + + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + handcraft/bigbird/sparse.py \ + --hidden-size 4096 --heads 32 --seqlen 4096 \ + --bs 8 --fp16 +""" + +import torch +import torch.nn as nn +import cube +import math +import numpy as np + +import argparse +from cube.runtime.device import DeviceGroup +from cube.profiler import CudaTimer +from cube.profiler.memory import memory_summary +from cube.profiler.timer import print_each_rank + +from cube.runtime.adapter.distnn import AllGatherSplit, IdentityAllreduce + + + +parser = argparse.ArgumentParser(description='sparse_attention') + +parser.add_argument('--hidden-size', type=int, default=4096, + help='hidden size') +parser.add_argument('--heads', type=int, default=32, + help='number of heads') +parser.add_argument('--seqlen', type=int, default=4096, + help='sequence length') +parser.add_argument('--blk-size', type=int, default=64, + help='sequence length') +# parallelism +parser.add_argument('--tp-size', type=int, default=1, + help='tensor parallelism size') +# training config +parser.add_argument('--bs', type=int, default=256, + help='num of micro batch') +parser.add_argument('--fp16', action='store_true', default=False) +parser.add_argument('--sparse', action='store_true', default=False) +args = parser.parse_args() + +print(args) +cube.init() + +tp_ranks = list(range(args.tp_size)) +_tp_group = -1 +_tp_size = len(tp_ranks) +_tp_rank = DeviceGroup().rank +if len(tp_ranks) > 1: + print_each_rank(f'initializing tp ranks: {tp_ranks}') + _tp_group = DeviceGroup().get_group(tp_ranks) + + +class Config: + + num_attention_heads = args.heads + hidden_size = args.hidden_size + all_head_sie = hidden_size + seqlen = args.seqlen # seqlen + num_random_blocks = 3 + block_size=args.blk_size + use_bias = True + +config = Config() + + +def bmm(tensor1: torch.Tensor, tensor2: torch.Tensor, ndim: int): + # print(f'bmm: {tensor1.size()} {tensor2.size()}') + return torch.bmm( + tensor1.reshape((-1,) + tensor1.shape[-2:]), + tensor2.reshape((-1,) + tensor2.shape[-2:]) + ).view(tensor1.shape[: ndim - 2] + (tensor1.shape[ndim - 2], tensor2.shape[ndim-1])) + + +def stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + h: int, block_size: int): + # print('start stride qk') + # q, k, v: (N h) L d + num_head = h + L, N = q.size(1), q.size(0) // h + dim_head = q.size(2) + assert L % block_size == 0 + + q = q.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d + k = k.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d + v = v.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d + + # stride diagnal [2:] + middle_q = q[:, 2:] # (N h) (nblock-2) blksize d + # (N h) nblock-3 (3 blksize) d + sliding_keys = torch.cat((k[:, 1:-2], k[:, 2:-1], k[:, 3:]), dim=2) + # (N h) 1 blksize d + pad_k_zero = torch.zeros_like(k[:,-3:-2]) + # (N h) 1 (3 blksize) d + sliding_bottom_keys = torch.cat((pad_k_zero, k[:,-2:-1], k[:,-1:]), dim=2) + # (N h) (nblock-2) (3 blksize) d + sliding_keys = torch.cat((sliding_keys, sliding_bottom_keys), dim=1) + # (N h) (nblock-2) d (3 blksize) + sliding_keys = sliding_keys.transpose(2, 3) + + # (N h) (nblock-3) (3 blksize) d + stride_vals = torch.cat((v[:, 1:-2], v[:, 2:-1], k[:, 3:]), dim=2) + # (N h) 1 (3 blksize) d + stride_bottom_vals = torch.cat((pad_k_zero, v[:,-2:-1], v[:,-1:]), dim=2) + # (N h) (nblock-2) (3 blksize) d + stride_vals = torch.cat((stride_vals, stride_bottom_vals), dim=1) + + # (N h) (nblock-2) blksize (3 blksize) + qk = bmm(middle_q, sliding_keys, ndim=4) + # (N h) ((nblock-2) blksize) (3 blksize) + qk = qk.view(N * h, -1, block_size * 3) + return qk, stride_vals + + +def global_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + h: int, block_size: int): + # print('start global qk') + # q, k, v: (N h) L d + num_head = h + L, N = q.size(1), q.size(0) // h + dim_head = q.size(2) + assert L % block_size == 0 + + # first two row + head_q = q[:, :2 * block_size] # (N h) (2 blocksize) d + head_k = k.transpose(1, 2) # (N h) d L + head = bmm(head_q, head_k, ndim=3) # (N h) (2 blocksize) L + # (N h) L d + head_v = v + + # remain first two column + col_q = q[:, 2 * block_size:] # (N h) ((nblock-2) blocksize) d + col_k = k[:, :2 * block_size].transpose(1, 2) # (N h) d (2 blocksize) + # (N h) (2 blksize) d + col_v = v[:, :2 * block_size] + col = bmm(col_q, col_k, ndim=3) # (N h) ((nblock-2) blocksize) (2 blocksize) + return head, head_v, col, col_v + + +def randn_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + h: int, block_size: int): + # q, k, v: (N h) L d + rand_num = 2 + num_head = h + L, N = q.size(1), q.size(0) // h + dim_head = q.size(2) + # nblock-2 2 + indices = torch.randint( + 2, L // block_size, (L // block_size-2, rand_num), + dtype=torch.int64, device=torch.cuda.current_device() + ) + # (N h) nblock blksize d + k = k.view(N * num_head, L // block_size, block_size, dim_head) + v = v.view(N * num_head, L // block_size, block_size, dim_head) + # (N h) nblock-2 (2 blksize) d + gathered_k = torch.cat( + (k[:,indices[:,0]], k[:,indices[:,1]]), dim=2 + ) + # (N h) nblock-2 (2 blksize) d + gathered_v = torch.cat( + (v[:,indices[:,0]], v[:,indices[:,1]]), dim=2 + ) + # (N h) nblock blksize d + q = q.view(N * num_head, L // block_size, block_size, dim_head) + # (N h) nblock-2 blksize d + q = q[:,2:] + # (N h) nblock-2 blksize (2 blksize) + qk = bmm(q, gathered_k.transpose(2,3), ndim=4) + # (N h) ((nblock-2) blksize) (2 blksize) + qk = qk.view(N * h, -1, 2 * block_size) + return qk, gathered_v + + +def sparse_attn(query: torch.Tensor, + q_proj: torch.Tensor, q_bias: torch.Tensor, + k_proj: torch.Tensor, k_bias: torch.Tensor, + v_proj: torch.Tensor, v_bias: torch.Tensor, + out_proj: torch.Tensor, out_bias: torch.Tensor, + h: int, scale: float, dropout_p: float, block_size: int): + num_head = h + L, N = query.size(0), query.size(1) + dim_head = q_proj.size(0) // num_head + + q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + + # sqk: (N h) ((nblock-2) blksize) (3 blksize) + # sqk_v: (N h) (nblock-2) (3 blksize) d + sqk, sqk_v = stride_qk(q, k, v, h, block_size=block_size) + # head: (N h) (2 blocksize) L + # head_v: (N h) L d + # col: (N h) ((nblock-2) blocksize) (2 blocksize) + # col_v: (N h) (2 blksize) d + head, head_v, col, col_v = global_qk(q, k, v, h, block_size=block_size) + # rqk: (N h) ((nblock-2) blksize) (2 blksize) + # rqk_v: (N h) (nblock-2) (2 blksize) d + rqk, rqk_v = randn_qk(q, k, v, h, block_size=block_size) + + # (N h) ((nblock-2) blksize) L + head_attn = torch.nn.functional.softmax(head, dim=-1) + head_attn = torch.nn.functional.dropout(head_attn, dropout_p, True, False) + + # (N h) ((nblock-2) blksize) (7 blksize) + middle_attn = torch.cat((col, sqk, rqk), dim=-1) + # (N h) ((nblock-2) blksize) (7 blksize) + middle_attn = torch.nn.functional.softmax(middle_attn, dim=-1) + middle_attn = torch.nn.functional.dropout(middle_attn, dropout_p, True, False) + + # select just for performance test + # (N h) (2 blocksize) L, (N h) L d -> (N h) (2 blksize) d + head_output = bmm(head_attn, v, ndim=3) + + # global col v: (N h) ((nblock-2) blksize) (2 blksize), (N h) (2 blksize) d + # :-> (N h) (L-(2 blksize)) d + middle_output = bmm(middle_attn[:,:,:2 * block_size], col_v, ndim=3) + + middle_stride = middle_attn[:,:,2*block_size:5*block_size].view( + N * h, L // block_size - 2, block_size, 3 * block_size + ) + # stide v: (N h) (nblock-2) blksize (3 blksize), (N h) (nblock-2) (3 blksize) d + # : -> (N h) (nblock-2) blksize d + middle_stride_output = bmm(middle_stride, sqk_v, ndim=4) + middle_output += middle_stride_output.view(N * h, -1, sqk_v.size(-1)) + + # (N h) (nblock-2) blksize (2 blksize) + middle_rand = middle_attn[:,:,5*block_size:].view( + N * h, L // block_size - 2, block_size, 2 * block_size + ) + # rand v: (N h) (nblock-2) blksize (2 blksize), (N h) (nblock-2) (2 blksize) d + # -> (N h) (nblock-2) blksize d + middle_rand_output = bmm(middle_rand, rqk_v, ndim=4) + middle_output += middle_rand_output.view(N * h, -1, rqk_v.size(-1)) + + # (N h) (2 blksize) d, (N h) ((nblock-2) blksize) d -> (N h) L d + output = torch.cat((head_output, middle_output), dim=1) + + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + return output + + +def dense_attn(query: torch.Tensor, + q_proj: torch.Tensor, q_bias: torch.Tensor, + k_proj: torch.Tensor, k_bias: torch.Tensor, + v_proj: torch.Tensor, v_bias: torch.Tensor, + out_proj: torch.Tensor, out_bias: torch.Tensor, + h: int, scale: float, dropout_p: float, block_size: int): + num_head = h + L, N = query.size(0), query.size(1) + dim_head = q_proj.size(0) // num_head + + q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # attention mask + attention_mask = torch.ones(((N, L)), device=torch.cuda.current_device()) + attention_mask = attention_mask.view(N, L // block_size, block_size) + exp_blocked_to_pad = torch.cat( + [attention_mask[:, 1:-3], attention_mask[:, 2:-2], attention_mask[:, 3:-1]], dim=2 + ) + band_mask = torch.einsum("blq,blk->blqk", attention_mask[:, 2:-2], exp_blocked_to_pad) + band_mask.unsqueeze_(1) + band_mask = band_mask < 0.5 + # attn.masked_fill_(band_mask, -1000.0) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + return output + + +class MultiHeadSelfAttention(torch.nn.Module): + + def __init__(self): + super().__init__() + self.kdim = config.hidden_size + self.vdim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = 0.0 + self.block_size = config.block_size + # Q + self.q_proj = torch.nn.Parameter(torch.empty(config.hidden_size, config.hidden_size)) + self.q_bias = torch.nn.Parameter(torch.empty(config.hidden_size)) + # K + self.k_proj = torch.nn.Parameter(torch.empty(config.hidden_size, self.kdim)) + self.k_bias = torch.nn.Parameter(torch.empty(config.hidden_size)) + # V + self.v_proj = torch.nn.Parameter(torch.empty(config.hidden_size, self.vdim)) + self.v_bias = torch.nn.Parameter(torch.empty(config.hidden_size)) + # Out + self.out_proj = torch.nn.Parameter(torch.empty(config.hidden_size, config.hidden_size)) + self.out_bias = torch.nn.Parameter(torch.empty(config.hidden_size)) + + def forward(self, query): + if args.sparse: + return sparse_attn( + query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p, self.block_size + ) + else: + return dense_attn( + query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p, self.block_size + ) + + +class AttnDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + self.bs = batch_size + super().__init__( + shapes=( + [config.seqlen, args.bs, config.hidden_size], + ), + dtypes=(torch.float16 if args.fp16 else torch.float,), + batch_dims=(0,) + ) + self.samples = [self.random_sample()] + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + def random_sample(self): + hidden_state = torch.randn( + config.seqlen, self.bs, config.hidden_size, + dtype=torch.float16 if args.fp16 else torch.float, + device=torch.cuda.current_device() + ) + return hidden_state + + + +if __name__ == '__main__': + + model = MultiHeadSelfAttention() + nparams = sum([param.numel() for param in model.parameters()]) + print_each_rank(f'model params (M): {nparams / 1e6}. Launching model...') + model = model.half().cuda() if args.fp16 else model.cuda() + model.eval() + + dataloader = AttnDataLoader(args.bs) + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + print_each_rank('model weight consumpition:') + memory_summary() + + CudaTimer(enable=False) + torch.distributed.barrier() + iter_num =32 + for step in range(iter_num): + if step >= 8: + CudaTimer(enable=True).start('e2e') + + # train 1 step + # num_microbatch = 1 + with torch.no_grad(): + data = next(dataloader) + out = model(data) + # loss.backward() + + optimizer.step() + optimizer.zero_grad() + + if step >= 8: + CudaTimer().stop('e2e') + + torch.cuda.empty_cache() + torch.distributed.barrier() + + if step == 0: + print_each_rank('memory after optimizer:', rank_only=0) + memory_summary() + + if (step + 1) % 8 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-8, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-8) + memory_summary() \ No newline at end of file From ba4c056d3f8192274354d031b88b3f9d11c1fae3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 14 Apr 2022 08:07:19 +0000 Subject: [PATCH 0779/1892] add sparse attention profiling --- handcraft/bigbird/sparse.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/handcraft/bigbird/sparse.py b/handcraft/bigbird/sparse.py index 7aa39f98..3202a763 100644 --- a/handcraft/bigbird/sparse.py +++ b/handcraft/bigbird/sparse.py @@ -85,6 +85,7 @@ def bmm(tensor1: torch.Tensor, tensor2: torch.Tensor, ndim: int): def stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, h: int, block_size: int): + CudaTimer().start('stride_qk') # print('start stride qk') # q, k, v: (N h) L d num_head = h @@ -120,11 +121,13 @@ def stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qk = bmm(middle_q, sliding_keys, ndim=4) # (N h) ((nblock-2) blksize) (3 blksize) qk = qk.view(N * h, -1, block_size * 3) + CudaTimer().stop('stride_qk') return qk, stride_vals def global_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, h: int, block_size: int): + CudaTimer().start('global_qk') # print('start global qk') # q, k, v: (N h) L d num_head = h @@ -145,11 +148,13 @@ def global_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # (N h) (2 blksize) d col_v = v[:, :2 * block_size] col = bmm(col_q, col_k, ndim=3) # (N h) ((nblock-2) blocksize) (2 blocksize) + CudaTimer().stop('global_qk') return head, head_v, col, col_v def randn_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, h: int, block_size: int): + CudaTimer().start('rand_qk') # q, k, v: (N h) L d rand_num = 2 num_head = h @@ -179,6 +184,7 @@ def randn_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qk = bmm(q, gathered_k.transpose(2,3), ndim=4) # (N h) ((nblock-2) blksize) (2 blksize) qk = qk.view(N * h, -1, 2 * block_size) + CudaTimer().stop('rand_qk') return qk, gathered_v @@ -192,6 +198,7 @@ def sparse_attn(query: torch.Tensor, L, N = query.size(0), query.size(1) dim_head = q_proj.size(0) // num_head + CudaTimer().start('to_qkv') q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) @@ -202,6 +209,7 @@ def sparse_attn(query: torch.Tensor, k = k.transpose(0, 1) # L (N h) d -> (N h) L d v = v.transpose(0, 1) # L (N h) d -> (N h) L d q = q * scale # (N h) L d, 1 -> (N h) L d + CudaTimer().stop('to_qkv') # sqk: (N h) ((nblock-2) blksize) (3 blksize) # sqk_v: (N h) (nblock-2) (3 blksize) d @@ -216,6 +224,7 @@ def sparse_attn(query: torch.Tensor, rqk, rqk_v = randn_qk(q, k, v, h, block_size=block_size) # (N h) ((nblock-2) blksize) L + CudaTimer().start('all_softmax') head_attn = torch.nn.functional.softmax(head, dim=-1) head_attn = torch.nn.functional.dropout(head_attn, dropout_p, True, False) @@ -224,15 +233,18 @@ def sparse_attn(query: torch.Tensor, # (N h) ((nblock-2) blksize) (7 blksize) middle_attn = torch.nn.functional.softmax(middle_attn, dim=-1) middle_attn = torch.nn.functional.dropout(middle_attn, dropout_p, True, False) + CudaTimer().stop('all_softmax') - # select just for performance test + CudaTimer().start('global_qk') # (N h) (2 blocksize) L, (N h) L d -> (N h) (2 blksize) d head_output = bmm(head_attn, v, ndim=3) # global col v: (N h) ((nblock-2) blksize) (2 blksize), (N h) (2 blksize) d # :-> (N h) (L-(2 blksize)) d middle_output = bmm(middle_attn[:,:,:2 * block_size], col_v, ndim=3) + CudaTimer().stop('global_qk') + CudaTimer().start('stride_qk') middle_stride = middle_attn[:,:,2*block_size:5*block_size].view( N * h, L // block_size - 2, block_size, 3 * block_size ) @@ -240,8 +252,10 @@ def sparse_attn(query: torch.Tensor, # : -> (N h) (nblock-2) blksize d middle_stride_output = bmm(middle_stride, sqk_v, ndim=4) middle_output += middle_stride_output.view(N * h, -1, sqk_v.size(-1)) + CudaTimer().stop('stride_qk') # (N h) (nblock-2) blksize (2 blksize) + CudaTimer().start('rand_qk') middle_rand = middle_attn[:,:,5*block_size:].view( N * h, L // block_size - 2, block_size, 2 * block_size ) @@ -249,6 +263,7 @@ def sparse_attn(query: torch.Tensor, # -> (N h) (nblock-2) blksize d middle_rand_output = bmm(middle_rand, rqk_v, ndim=4) middle_output += middle_rand_output.view(N * h, -1, rqk_v.size(-1)) + CudaTimer().stop('rand_qk') # (N h) (2 blksize) d, (N h) ((nblock-2) blksize) d -> (N h) L d output = torch.cat((head_output, middle_output), dim=1) From 4f183faf3e3e2450e5fff04b75e5c393880299a7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 14 Apr 2022 09:54:37 +0000 Subject: [PATCH 0780/1892] add missing dropout --- handcraft/gpt3/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 99e8ff86..4c864734 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -91,8 +91,8 @@ _pp_embed_group = -1 _pp_embed_reducer = None cube.init() -print_each_rank('setting memory constraints to 16GB') -torch.cuda.set_per_process_memory_fraction(0.5) +# print_each_rank('setting memory constraints to 16GB') +# torch.cuda.set_per_process_memory_fraction(0.5) dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( [args.dp_size, args.pp_size, args.tp_size] @@ -288,6 +288,7 @@ def forward_(self, x, mask): if mask is not None: attention_scores.masked_fill_(mask, -10000.0) attention_probs = self.softmax(attention_scores) + attention_probs = torch.nn.functional.dropout(attention_probs, 0.0) output_size = (value_layer.size(1), value_layer.size(2), From 95e7344b1836697875a8878fe1b1fd2c6d4273d2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 14 Apr 2022 12:38:02 +0000 Subject: [PATCH 0781/1892] enable sparse parallel --- handcraft/bigbird/sparse.py | 293 +++++++++++++++++++++++++++++------- 1 file changed, 238 insertions(+), 55 deletions(-) diff --git a/handcraft/bigbird/sparse.py b/handcraft/bigbird/sparse.py index 3202a763..8874555f 100644 --- a/handcraft/bigbird/sparse.py +++ b/handcraft/bigbird/sparse.py @@ -40,9 +40,6 @@ help='sequence length') parser.add_argument('--blk-size', type=int, default=64, help='sequence length') -# parallelism -parser.add_argument('--tp-size', type=int, default=1, - help='tensor parallelism size') # training config parser.add_argument('--bs', type=int, default=256, help='num of micro batch') @@ -53,7 +50,7 @@ print(args) cube.init() -tp_ranks = list(range(args.tp_size)) +tp_ranks = list(range(DeviceGroup().world_size)) _tp_group = -1 _tp_size = len(tp_ranks) _tp_rank = DeviceGroup().rank @@ -85,7 +82,7 @@ def bmm(tensor1: torch.Tensor, tensor2: torch.Tensor, ndim: int): def stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, h: int, block_size: int): - CudaTimer().start('stride_qk') + CudaTimer().start('stride_q@k') # print('start stride qk') # q, k, v: (N h) L d num_head = h @@ -95,39 +92,81 @@ def stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q = q.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d k = k.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d - v = v.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d # stride diagnal [2:] middle_q = q[:, 2:] # (N h) (nblock-2) blksize d # (N h) nblock-3 (3 blksize) d sliding_keys = torch.cat((k[:, 1:-2], k[:, 2:-1], k[:, 3:]), dim=2) # (N h) 1 blksize d - pad_k_zero = torch.zeros_like(k[:,-3:-2]) + pad_zero = torch.zeros_like(k[:,-3:-2]) # (N h) 1 (3 blksize) d - sliding_bottom_keys = torch.cat((pad_k_zero, k[:,-2:-1], k[:,-1:]), dim=2) + sliding_bottom_keys = torch.cat((pad_zero, k[:,-2:-1], k[:,-1:]), dim=2) # (N h) (nblock-2) (3 blksize) d sliding_keys = torch.cat((sliding_keys, sliding_bottom_keys), dim=1) # (N h) (nblock-2) d (3 blksize) sliding_keys = sliding_keys.transpose(2, 3) - # (N h) (nblock-3) (3 blksize) d - stride_vals = torch.cat((v[:, 1:-2], v[:, 2:-1], k[:, 3:]), dim=2) - # (N h) 1 (3 blksize) d - stride_bottom_vals = torch.cat((pad_k_zero, v[:,-2:-1], v[:,-1:]), dim=2) - # (N h) (nblock-2) (3 blksize) d - stride_vals = torch.cat((stride_vals, stride_bottom_vals), dim=1) - # (N h) (nblock-2) blksize (3 blksize) qk = bmm(middle_q, sliding_keys, ndim=4) # (N h) ((nblock-2) blksize) (3 blksize) qk = qk.view(N * h, -1, block_size * 3) - CudaTimer().stop('stride_qk') - return qk, stride_vals + CudaTimer().stop('stride_q@k') + return qk + + +def parallel_stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + h: int, block_size: int, start: int, end: int): + CudaTimer().start('parallel_stride_q@k') + # print('start stride qk') + # q, k, v: (N h) L d + num_head = h + L, N = q.size(1), q.size(0) // h + dim_head = q.size(2) + assert L % block_size == 0 + + q = q.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d + k = k.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d + + # (N h) 1 blksize d + pad_zero = torch.zeros_like(k[:,-1:]) + # (N h) (end-start) blksize d + middle_q = q[:,2+start:2+end] + if end + 2 == L // block_size: + # (N h) (end-start) blksize d + k_right = torch.cat((k[:,start+3:], pad_zero), dim=1) + else: + k_right = k[:,start+3:end+3] + # (N h) (end-start) (3 blksize) d + sliding_keys = torch.cat((k[:, start+1:end+1], k[:, start+2:end+2], k_right), dim=2) + # (N h) (end-start) blksize (3 blksize) + qk = bmm(middle_q, sliding_keys.transpose(2, 3), ndim=4) + # (N h) (nblock-2) blksize (3 blksize) + qk = torch.nn.functional.pad( + qk, (0,0,0,0,start,L//block_size-2-end), 'constant', 0) + # (N h) ((nblock-2) blksize) (3 blksize) + qk = qk.view(N * h, -1, block_size * 3) + CudaTimer().stop('parallel_stride_q@k') + return qk + + +def stride_v(v: torch.Tensor, h: int, block_size: int): + L, N, dim_head = v.size(1), v.size(0) // h, v.size(2) + assert L % block_size == 0 + v = v.view(N * h, L // block_size, block_size, dim_head) + # (N h) 1 blksize d + pad_zero = torch.zeros_like(v[:,-3:-2]) + # (N h) (nblock-3) (3 blksize) d + stride_vals = torch.cat((v[:, 1:-2], v[:, 2:-1], v[:, 3:]), dim=2) + # (N h) 1 (3 blksize) d + stride_bottom_vals = torch.cat((v[:,-2:-1], v[:,-1:], pad_zero), dim=2) + # (N h) (nblock-2) (3 blksize) d + stride_vals = torch.cat((stride_vals, stride_bottom_vals), dim=1) + return stride_vals def global_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, h: int, block_size: int): - CudaTimer().start('global_qk') + CudaTimer().start('global_q@k') # print('start global qk') # q, k, v: (N h) L d num_head = h @@ -148,15 +187,16 @@ def global_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # (N h) (2 blksize) d col_v = v[:, :2 * block_size] col = bmm(col_q, col_k, ndim=3) # (N h) ((nblock-2) blocksize) (2 blocksize) - CudaTimer().stop('global_qk') + CudaTimer().stop('global_q@k') return head, head_v, col, col_v def randn_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - h: int, block_size: int): - CudaTimer().start('rand_qk') + h: int, block_size: int, rand_num: int = 2): + CudaTimer().start('rand_q@k') + torch.manual_seed(0) # q, k, v: (N h) L d - rand_num = 2 + # rand_num = 2 num_head = h L, N = q.size(1), q.size(0) // h dim_head = q.size(2) @@ -167,25 +207,39 @@ def randn_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, ) # (N h) nblock blksize d k = k.view(N * num_head, L // block_size, block_size, dim_head) - v = v.view(N * num_head, L // block_size, block_size, dim_head) - # (N h) nblock-2 (2 blksize) d - gathered_k = torch.cat( - (k[:,indices[:,0]], k[:,indices[:,1]]), dim=2 - ) - # (N h) nblock-2 (2 blksize) d - gathered_v = torch.cat( - (v[:,indices[:,0]], v[:,indices[:,1]]), dim=2 - ) + + # Optimize: remove for loop, use direct index can greatly speedup + # (N h) nblock-2 (randnum blksize) d + keys = tuple(k[:,indices[:,idx]] for idx in range(rand_num)) + gathered_k = torch.cat(keys, dim=2) + # (N h) nblock blksize d q = q.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock-2 blksize d q = q[:,2:] - # (N h) nblock-2 blksize (2 blksize) - qk = bmm(q, gathered_k.transpose(2,3), ndim=4) - # (N h) ((nblock-2) blksize) (2 blksize) - qk = qk.view(N * h, -1, 2 * block_size) - CudaTimer().stop('rand_qk') - return qk, gathered_v + # (N h) nblock-2 blksize (randnum blksize) + qk = bmm(q, gathered_k.transpose(2, 3), ndim=4) + # (N h) ((nblock-2) blksize) (randnum blksize) + qk = qk.view(N * h, -1, rand_num * block_size) + CudaTimer().stop('rand_q@k') + return qk + + +def randn_v(v: torch.Tensor, h: int, block_size: int, rand_num: int = 2): + # v: (N h) L d + # CudaTimer().start('rand_v') + torch.manual_seed(0) + L, N, dim_head = v.size(1), v.size(0) // h, v.size(2) + # nblock-2 2 + indices = torch.randint( + 2, L // block_size, (L // block_size-2, rand_num), + dtype=torch.int64, device=torch.cuda.current_device() + ) + v = v.view(N * h, L // block_size, block_size, dim_head) + vals = tuple(v[:,indices[:,idx]] for idx in range(rand_num)) + gathered_v = torch.cat(vals, dim=2) + # CudaTimer().stop('rand_v') + return gathered_v def sparse_attn(query: torch.Tensor, @@ -194,6 +248,7 @@ def sparse_attn(query: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, out_proj: torch.Tensor, out_bias: torch.Tensor, h: int, scale: float, dropout_p: float, block_size: int): + rand_num = 2 num_head = h L, N = query.size(0), query.size(1) dim_head = q_proj.size(0) // num_head @@ -211,17 +266,125 @@ def sparse_attn(query: torch.Tensor, q = q * scale # (N h) L d, 1 -> (N h) L d CudaTimer().stop('to_qkv') + CudaTimer().start('q@k') # sqk: (N h) ((nblock-2) blksize) (3 blksize) + sqk = stride_qk(q, k, v, h, block_size=block_size) + # head: (N h) (2 blocksize) L + # head_v: (N h) L d + # col: (N h) ((nblock-2) blocksize) (2 blocksize) + # col_v: (N h) (2 blksize) d + head, head_v, col, col_v = global_qk(q, k, v, h, block_size=block_size) + # rqk: (N h) ((nblock-2) blksize) (2 blksize) + rqk = randn_qk(q, k, v, h, block_size=block_size) + CudaTimer().stop('q@k') + # sqk_v: (N h) (nblock-2) (3 blksize) d - sqk, sqk_v = stride_qk(q, k, v, h, block_size=block_size) + sqk_v = stride_v(v, h, block_size) + # rqk_v: (N h) (nblock-2) (2 blksize) d + rqk_v = randn_v(v, h, block_size) + + # (N h) ((nblock-2) blksize) L + CudaTimer().start('all_softmax') + head_attn = torch.nn.functional.softmax(head, dim=-1) + head_attn = torch.nn.functional.dropout(head_attn, dropout_p, True, False) + + # (N h) ((nblock-2) blksize) (7 blksize) + middle_attn = torch.cat((col, sqk, rqk), dim=-1) + # (N h) ((nblock-2) blksize) (7 blksize) + middle_attn = torch.nn.functional.softmax(middle_attn, dim=-1) + middle_attn = torch.nn.functional.dropout(middle_attn, dropout_p, True, False) + CudaTimer().stop('all_softmax') + + CudaTimer().start('global_qk@v') + # (N h) (2 blocksize) L, (N h) L d -> (N h) (2 blksize) d + head_output = bmm(head_attn, v, ndim=3) + + # global col v: (N h) ((nblock-2) blksize) (2 blksize), (N h) (2 blksize) d + # :-> (N h) (L-(2 blksize)) d + middle_output = bmm(middle_attn[:,:,:2 * block_size], col_v, ndim=3) + CudaTimer().stop('global_qk@v') + + CudaTimer().start('stride_qk@v') + middle_stride = middle_attn[:,:,2*block_size:5*block_size].view( + N * h, L // block_size - 2, block_size, 3 * block_size + ) + # stide v: (N h) (nblock-2) blksize (3 blksize), (N h) (nblock-2) (3 blksize) d + # : -> (N h) (nblock-2) blksize d + middle_stride_output = bmm(middle_stride, sqk_v, ndim=4) + middle_output += middle_stride_output.view(N * h, -1, sqk_v.size(-1)) + CudaTimer().stop('stride_qk@v') + + # (N h) (nblock-2) blksize (randnum blksize) + CudaTimer().start('rand_qk@v') + middle_rand = middle_attn[:,:,5*block_size:].view( + N * h, L // block_size - 2, block_size, rand_num * block_size + ) + # rand v: (N h) (nblock-2) blksize (2 blksize), (N h) (nblock-2) (2 blksize) d + # -> (N h) (nblock-2) blksize d + middle_rand_output = bmm(middle_rand, rqk_v, ndim=4) + middle_output += middle_rand_output.view(N * h, -1, rqk_v.size(-1)) + CudaTimer().stop('rand_qk@v') + + # (N h) (2 blksize) d, (N h) ((nblock-2) blksize) d -> (N h) L d + output = torch.cat((head_output, middle_output), dim=1) + + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + return output + + +def parallel_sparse_attn(query: torch.Tensor, + q_proj: torch.Tensor, q_bias: torch.Tensor, + k_proj: torch.Tensor, k_bias: torch.Tensor, + v_proj: torch.Tensor, v_bias: torch.Tensor, + out_proj: torch.Tensor, out_bias: torch.Tensor, + h: int, scale: float, dropout_p: float, block_size: int): + rand_num = 2 + num_head = h + L, N = query.size(0), query.size(1) + dim_head = q_proj.size(0) // num_head + + CudaTimer().start('to_qkv') + q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + CudaTimer().stop('to_qkv') + + nblocks = L // block_size - 2 + scope = [nblocks // _tp_size] * _tp_size + for idx in range(nblocks % _tp_size): + scope[idx] += 1 + start, end = sum(scope[:_tp_rank]), sum(scope[:_tp_rank+1]) + + CudaTimer().start('q@k') + # sqk: (N h) ((nblock-2) blksize) (3 blksize) + sqk = parallel_stride_qk( + q, k, v, h, block_size, start, end) + req = torch.distributed.all_reduce(sqk, async_op=True) + # head: (N h) (2 blocksize) L # head_v: (N h) L d # col: (N h) ((nblock-2) blocksize) (2 blocksize) # col_v: (N h) (2 blksize) d head, head_v, col, col_v = global_qk(q, k, v, h, block_size=block_size) + # rqk: (N h) ((nblock-2) blksize) (2 blksize) # rqk_v: (N h) (nblock-2) (2 blksize) d - rqk, rqk_v = randn_qk(q, k, v, h, block_size=block_size) + rqk = randn_qk(q, k, v, h, block_size=block_size) + req.wait() + CudaTimer().stop('q@k') + + # sqk_v: (N h) (nblock-2) (3 blksize) d + sqk_v = stride_v(v, h, block_size) + rqk_v = randn_v(v, h, block_size) # (N h) ((nblock-2) blksize) L CudaTimer().start('all_softmax') @@ -235,16 +398,16 @@ def sparse_attn(query: torch.Tensor, middle_attn = torch.nn.functional.dropout(middle_attn, dropout_p, True, False) CudaTimer().stop('all_softmax') - CudaTimer().start('global_qk') + CudaTimer().start('global_qk@v') # (N h) (2 blocksize) L, (N h) L d -> (N h) (2 blksize) d head_output = bmm(head_attn, v, ndim=3) # global col v: (N h) ((nblock-2) blksize) (2 blksize), (N h) (2 blksize) d # :-> (N h) (L-(2 blksize)) d middle_output = bmm(middle_attn[:,:,:2 * block_size], col_v, ndim=3) - CudaTimer().stop('global_qk') + CudaTimer().stop('global_qk@v') - CudaTimer().start('stride_qk') + CudaTimer().start('stride_qk@v') middle_stride = middle_attn[:,:,2*block_size:5*block_size].view( N * h, L // block_size - 2, block_size, 3 * block_size ) @@ -252,25 +415,27 @@ def sparse_attn(query: torch.Tensor, # : -> (N h) (nblock-2) blksize d middle_stride_output = bmm(middle_stride, sqk_v, ndim=4) middle_output += middle_stride_output.view(N * h, -1, sqk_v.size(-1)) - CudaTimer().stop('stride_qk') + CudaTimer().stop('stride_qk@v') - # (N h) (nblock-2) blksize (2 blksize) - CudaTimer().start('rand_qk') + # (N h) (nblock-2) blksize (randnum blksize) + CudaTimer().start('rand_qk@v') middle_rand = middle_attn[:,:,5*block_size:].view( - N * h, L // block_size - 2, block_size, 2 * block_size + N * h, L // block_size - 2, block_size, rand_num * block_size ) # rand v: (N h) (nblock-2) blksize (2 blksize), (N h) (nblock-2) (2 blksize) d # -> (N h) (nblock-2) blksize d middle_rand_output = bmm(middle_rand, rqk_v, ndim=4) middle_output += middle_rand_output.view(N * h, -1, rqk_v.size(-1)) - CudaTimer().stop('rand_qk') + CudaTimer().stop('rand_qk@v') # (N h) (2 blksize) d, (N h) ((nblock-2) blksize) d -> (N h) L d output = torch.cat((head_output, middle_output), dim=1) + CudaTimer().start('out_proj') output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + CudaTimer().stop('out_proj') return output @@ -294,8 +459,10 @@ def dense_attn(query: torch.Tensor, k = k.transpose(0, 1) # L (N h) d -> (N h) L d v = v.transpose(0, 1) # L (N h) d -> (N h) L d q = q * scale # (N h) L d, 1 -> (N h) L d + CudaTimer().start('q@k') k = k.transpose(1, 2) # (N h) L d -> (N h) d L attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + CudaTimer().stop('q@k') # attention mask attention_mask = torch.ones(((N, L)), device=torch.cuda.current_device()) @@ -308,9 +475,15 @@ def dense_attn(query: torch.Tensor, band_mask = band_mask < 0.5 # attn.masked_fill_(band_mask, -1000.0) + CudaTimer().start('all_softmax') attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L + CudaTimer().stop('all_softmax') + + CudaTimer().start('qk@v') output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + CudaTimer().stop('qk@v') + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E @@ -343,14 +516,24 @@ def __init__(self): def forward(self, query): if args.sparse: - return sparse_attn( - query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p, self.block_size - ) + if _tp_size > 1: + return parallel_sparse_attn( + query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p, self.block_size + ) + else: + return sparse_attn( + query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p, self.block_size + ) else: return dense_attn( query, From 03b016a83648141b25cc398899e85c9162812d7f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 15 Apr 2022 03:01:39 +0000 Subject: [PATCH 0782/1892] fatal bug fix: mlp not backward --- handcraft/gpt3/train.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 4c864734..a2275065 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -135,7 +135,7 @@ class Config: - vocab_size = 50273 + vocab_size = 50432 seqlen = args.seqlen layers = args.layers heads = args.heads @@ -390,7 +390,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.vocab_start_index = num_embeddings // self.tp_size * self.tp_id self.vocab_end_index = num_embeddings // self.tp_size * (self.tp_id + 1) self.weight = torch.nn.Parameter( - torch.ones((num_embeddings // self.tp_size, embedding_dim)) + torch.ones((num_embeddings // self.tp_size, embedding_dim), requires_grad=True) ) def forward(self, tokens): @@ -443,20 +443,20 @@ def __init__(self): def forward(self, hidden_states, attention_mask): - layernrom_output = self.input_layernorm(hidden_states) + layernorm_output = self.input_layernorm(hidden_states) - attention_output = self.self_attention(layernrom_output, attention_mask) + attention_output = self.self_attention(layernorm_output, attention_mask) residual = hidden_states - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout) layernorm_input = layernorm_input + residual - layernrom_output = self.post_attention_layernorm(layernorm_input) + layernorm_output = self.post_attention_layernorm(layernorm_input) - mlp_output = self.mlp(layernrom_output) + mlp_output = self.mlp(layernorm_output) residual = layernorm_input - output = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - output = layernorm_input + residual + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout) + output = output + residual return output def flops(self): @@ -682,7 +682,7 @@ def get_alpa_tflops(): model = model.half().cuda() if args.fp16 else model.cuda() dataloader = GPT3DataLoader(args.micro_bs) - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.02, lr=3e-05, betas=(0.9, 0.98)) if _pp_embed_reducer is not None: _pp_embed_reducer.add_param(model.word_embeddings.weight) if _dp_reducer is not None: From 6adb11b17e5544a17cdc080f59ff5cde332a526e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 15 Apr 2022 12:49:21 +0000 Subject: [PATCH 0783/1892] fix bugs in testing performance --- handcraft/bigbird/sparse.py | 176 +++++++++++------------------------- 1 file changed, 51 insertions(+), 125 deletions(-) diff --git a/handcraft/bigbird/sparse.py b/handcraft/bigbird/sparse.py index 8874555f..b5822faa 100644 --- a/handcraft/bigbird/sparse.py +++ b/handcraft/bigbird/sparse.py @@ -10,8 +10,8 @@ --nproc_per_node=1 \ --nnodes=1 \ handcraft/bigbird/sparse.py \ - --hidden-size 4096 --heads 32 --seqlen 4096 \ - --bs 8 --fp16 + --hidden-size 5120 --heads 32 --seqlen 12288 \ + --bs 1 --fp16 """ import torch @@ -65,23 +65,24 @@ class Config: hidden_size = args.hidden_size all_head_sie = hidden_size seqlen = args.seqlen # seqlen - num_random_blocks = 3 + num_random_blocks = 2 block_size=args.blk_size use_bias = True config = Config() -def bmm(tensor1: torch.Tensor, tensor2: torch.Tensor, ndim: int): +def bmm(tensor1: torch.Tensor, tensor2: torch.Tensor, ndim: int, out=None): # print(f'bmm: {tensor1.size()} {tensor2.size()}') return torch.bmm( tensor1.reshape((-1,) + tensor1.shape[-2:]), - tensor2.reshape((-1,) + tensor2.shape[-2:]) + tensor2.reshape((-1,) + tensor2.shape[-2:]), + out=out ).view(tensor1.shape[: ndim - 2] + (tensor1.shape[ndim - 2], tensor2.shape[ndim-1])) def stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - h: int, block_size: int): + h: int, block_size: int, out=None): CudaTimer().start('stride_q@k') # print('start stride qk') # q, k, v: (N h) L d @@ -107,9 +108,9 @@ def stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sliding_keys = sliding_keys.transpose(2, 3) # (N h) (nblock-2) blksize (3 blksize) - qk = bmm(middle_q, sliding_keys, ndim=4) + out = bmm(middle_q, sliding_keys, ndim=4, out=out) # (N h) ((nblock-2) blksize) (3 blksize) - qk = qk.view(N * h, -1, block_size * 3) + qk = out.view(N * h, -1, block_size * 3) CudaTimer().stop('stride_q@k') return qk @@ -192,7 +193,7 @@ def global_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def randn_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - h: int, block_size: int, rand_num: int = 2): + h: int, block_size: int, rand_num: int = 2, out=None): CudaTimer().start('rand_q@k') torch.manual_seed(0) # q, k, v: (N h) L d @@ -218,9 +219,9 @@ def randn_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # (N h) nblock-2 blksize d q = q[:,2:] # (N h) nblock-2 blksize (randnum blksize) - qk = bmm(q, gathered_k.transpose(2, 3), ndim=4) + out = bmm(q, gathered_k.transpose(2, 3), ndim=4, out=out) # (N h) ((nblock-2) blksize) (randnum blksize) - qk = qk.view(N * h, -1, rand_num * block_size) + qk = out.view(N * h, -1, rand_num * block_size) CudaTimer().stop('rand_q@k') return qk @@ -252,6 +253,7 @@ def sparse_attn(query: torch.Tensor, num_head = h L, N = query.size(0), query.size(1) dim_head = q_proj.size(0) // num_head + nblocks = L // block_size CudaTimer().start('to_qkv') q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) @@ -266,138 +268,49 @@ def sparse_attn(query: torch.Tensor, q = q * scale # (N h) L d, 1 -> (N h) L d CudaTimer().stop('to_qkv') + # sqk = torch.empty( + # N * h, (nblocks - 2) * block_size, 3 * block_size, + # dtype=torch.float16 if args.fp16 else torch.float32, + # device=torch.cuda.current_device() + # ) + # + # rqk = torch.empty( + # N * h, (nblocks - 2) * block_size, 2 * block_size, + # dtype=torch.float16 if args.fp16 else torch.float32, + # device=torch.cuda.current_device() + # ) + # we don't need pre-allocation as memory are sufficient + sqk = rqk = None + CudaTimer().start('q@k') # sqk: (N h) ((nblock-2) blksize) (3 blksize) - sqk = stride_qk(q, k, v, h, block_size=block_size) + sqk = stride_qk(q, k, v, h, block_size=block_size, out=sqk) # head: (N h) (2 blocksize) L # head_v: (N h) L d # col: (N h) ((nblock-2) blocksize) (2 blocksize) # col_v: (N h) (2 blksize) d head, head_v, col, col_v = global_qk(q, k, v, h, block_size=block_size) # rqk: (N h) ((nblock-2) blksize) (2 blksize) - rqk = randn_qk(q, k, v, h, block_size=block_size) + rqk = randn_qk(q, k, v, h, block_size=block_size, out=rqk) + # (N h) ((nblock-2) blksize) (7 blksize) + middle_attn = torch.cat((col, sqk, rqk), dim=-1) CudaTimer().stop('q@k') - # sqk_v: (N h) (nblock-2) (3 blksize) d - sqk_v = stride_v(v, h, block_size) - # rqk_v: (N h) (nblock-2) (2 blksize) d - rqk_v = randn_v(v, h, block_size) - - # (N h) ((nblock-2) blksize) L CudaTimer().start('all_softmax') + # (N h) ((nblock-2) blksize) L head_attn = torch.nn.functional.softmax(head, dim=-1) head_attn = torch.nn.functional.dropout(head_attn, dropout_p, True, False) - - # (N h) ((nblock-2) blksize) (7 blksize) - middle_attn = torch.cat((col, sqk, rqk), dim=-1) # (N h) ((nblock-2) blksize) (7 blksize) middle_attn = torch.nn.functional.softmax(middle_attn, dim=-1) middle_attn = torch.nn.functional.dropout(middle_attn, dropout_p, True, False) CudaTimer().stop('all_softmax') - CudaTimer().start('global_qk@v') - # (N h) (2 blocksize) L, (N h) L d -> (N h) (2 blksize) d - head_output = bmm(head_attn, v, ndim=3) - - # global col v: (N h) ((nblock-2) blksize) (2 blksize), (N h) (2 blksize) d - # :-> (N h) (L-(2 blksize)) d - middle_output = bmm(middle_attn[:,:,:2 * block_size], col_v, ndim=3) - CudaTimer().stop('global_qk@v') - - CudaTimer().start('stride_qk@v') - middle_stride = middle_attn[:,:,2*block_size:5*block_size].view( - N * h, L // block_size - 2, block_size, 3 * block_size - ) - # stide v: (N h) (nblock-2) blksize (3 blksize), (N h) (nblock-2) (3 blksize) d - # : -> (N h) (nblock-2) blksize d - middle_stride_output = bmm(middle_stride, sqk_v, ndim=4) - middle_output += middle_stride_output.view(N * h, -1, sqk_v.size(-1)) - CudaTimer().stop('stride_qk@v') - - # (N h) (nblock-2) blksize (randnum blksize) - CudaTimer().start('rand_qk@v') - middle_rand = middle_attn[:,:,5*block_size:].view( - N * h, L // block_size - 2, block_size, rand_num * block_size - ) - # rand v: (N h) (nblock-2) blksize (2 blksize), (N h) (nblock-2) (2 blksize) d - # -> (N h) (nblock-2) blksize d - middle_rand_output = bmm(middle_rand, rqk_v, ndim=4) - middle_output += middle_rand_output.view(N * h, -1, rqk_v.size(-1)) - CudaTimer().stop('rand_qk@v') - - # (N h) (2 blksize) d, (N h) ((nblock-2) blksize) d -> (N h) L d - output = torch.cat((head_output, middle_output), dim=1) - - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E - return output - - -def parallel_sparse_attn(query: torch.Tensor, - q_proj: torch.Tensor, q_bias: torch.Tensor, - k_proj: torch.Tensor, k_bias: torch.Tensor, - v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, block_size: int): - rand_num = 2 - num_head = h - L, N = query.size(0), query.size(1) - dim_head = q_proj.size(0) // num_head - - CudaTimer().start('to_qkv') - q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - CudaTimer().stop('to_qkv') - - nblocks = L // block_size - 2 - scope = [nblocks // _tp_size] * _tp_size - for idx in range(nblocks % _tp_size): - scope[idx] += 1 - start, end = sum(scope[:_tp_rank]), sum(scope[:_tp_rank+1]) - - CudaTimer().start('q@k') - # sqk: (N h) ((nblock-2) blksize) (3 blksize) - sqk = parallel_stride_qk( - q, k, v, h, block_size, start, end) - req = torch.distributed.all_reduce(sqk, async_op=True) - - # head: (N h) (2 blocksize) L - # head_v: (N h) L d - # col: (N h) ((nblock-2) blocksize) (2 blocksize) - # col_v: (N h) (2 blksize) d - head, head_v, col, col_v = global_qk(q, k, v, h, block_size=block_size) - - # rqk: (N h) ((nblock-2) blksize) (2 blksize) - # rqk_v: (N h) (nblock-2) (2 blksize) d - rqk = randn_qk(q, k, v, h, block_size=block_size) - req.wait() - CudaTimer().stop('q@k') - + CudaTimer().start('qk@v') # sqk_v: (N h) (nblock-2) (3 blksize) d sqk_v = stride_v(v, h, block_size) + # rqk_v: (N h) (nblock-2) (2 blksize) d rqk_v = randn_v(v, h, block_size) - # (N h) ((nblock-2) blksize) L - CudaTimer().start('all_softmax') - head_attn = torch.nn.functional.softmax(head, dim=-1) - head_attn = torch.nn.functional.dropout(head_attn, dropout_p, True, False) - - # (N h) ((nblock-2) blksize) (7 blksize) - middle_attn = torch.cat((col, sqk, rqk), dim=-1) - # (N h) ((nblock-2) blksize) (7 blksize) - middle_attn = torch.nn.functional.softmax(middle_attn, dim=-1) - middle_attn = torch.nn.functional.dropout(middle_attn, dropout_p, True, False) - CudaTimer().stop('all_softmax') - CudaTimer().start('global_qk@v') # (N h) (2 blocksize) L, (N h) L d -> (N h) (2 blksize) d head_output = bmm(head_attn, v, ndim=3) @@ -430,6 +343,7 @@ def parallel_sparse_attn(query: torch.Tensor, # (N h) (2 blksize) d, (N h) ((nblock-2) blksize) d -> (N h) L d output = torch.cat((head_output, middle_output), dim=1) + CudaTimer().stop('qk@v') CudaTimer().start('out_proj') output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d @@ -449,19 +363,29 @@ def dense_attn(query: torch.Tensor, L, N = query.size(0), query.size(1) dim_head = q_proj.size(0) // num_head + CudaTimer().start('to_qkv') q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q.transpose(0, 1).contiguous() # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d q = q * scale # (N h) L d, 1 -> (N h) L d + CudaTimer().stop('to_qkv') + # k = k.transpose(1, 2).contiguous() # (N h) L d -> (N h) d L + + CudaTimer().start('allocation') + attn = torch.empty( + (N * h), L, L, dtype=torch.float16 if args.fp16 else args.fp32, + device=torch.cuda.current_device() + ) + CudaTimer().stop('allocation') + CudaTimer().start('q@k') - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + attn = torch.bmm(q, k.transpose(1, 2), out=attn) # (N h) L d, (N h) d L -> (N h) L L CudaTimer().stop('q@k') # attention mask @@ -484,9 +408,11 @@ def dense_attn(query: torch.Tensor, output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d CudaTimer().stop('qk@v') + CudaTimer().start('out_proj') output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + CudaTimer().stop('out_proj') return output From 9739d25db0867703f42af6fbb81c77cc19156905 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 17 Apr 2022 11:20:37 +0000 Subject: [PATCH 0784/1892] test attention memory consumpiton --- handcraft/gpt3/train.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index a2275065..895d8368 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -239,6 +239,9 @@ def __init__(self, num_heads: int = None): ) def forward_(self, x, mask): + # to test attention memory consumpiton: enable this + # start_mem = torch.cuda.memory_allocated() + # x: [seqlen, bs, hidden], np: head num | hn: head dim if self.tp_size > 1: x = IdentityAllreduce.apply(x, self.tp_group) @@ -323,6 +326,11 @@ def forward_(self, x, mask): output = self.dense(context_layer) if self.tp_size > 1: output = AllReduceIdentity.apply(output, self.tp_group) + + # to test attention memory consumpiton: enable this + # end_mem = torch.cuda.memory_allocated() + # print(f'mem: attention memory: {(end_mem - start_mem) / 1024 / 1024} MB') + return output def forward(self, x, mask, recompute=False): From 916058617496929fb5372599858c1b513e2c7e1f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 19 Apr 2022 04:29:14 +0000 Subject: [PATCH 0785/1892] remove allocation overhead --- handcraft/bigbird/sparse.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/handcraft/bigbird/sparse.py b/handcraft/bigbird/sparse.py index b5822faa..05ae5997 100644 --- a/handcraft/bigbird/sparse.py +++ b/handcraft/bigbird/sparse.py @@ -10,7 +10,7 @@ --nproc_per_node=1 \ --nnodes=1 \ handcraft/bigbird/sparse.py \ - --hidden-size 5120 --heads 32 --seqlen 12288 \ + --hidden-size 5120 --heads 32 --seqlen 4096 \ --bs 1 --fp16 """ @@ -268,19 +268,19 @@ def sparse_attn(query: torch.Tensor, q = q * scale # (N h) L d, 1 -> (N h) L d CudaTimer().stop('to_qkv') - # sqk = torch.empty( - # N * h, (nblocks - 2) * block_size, 3 * block_size, - # dtype=torch.float16 if args.fp16 else torch.float32, - # device=torch.cuda.current_device() - # ) - # - # rqk = torch.empty( - # N * h, (nblocks - 2) * block_size, 2 * block_size, - # dtype=torch.float16 if args.fp16 else torch.float32, - # device=torch.cuda.current_device() - # ) + sqk = torch.empty( + N * h, (nblocks - 2) * block_size, 3 * block_size, + dtype=torch.float16 if args.fp16 else torch.float32, + device=torch.cuda.current_device() + ) + + rqk = torch.empty( + N * h, (nblocks - 2) * block_size, 2 * block_size, + dtype=torch.float16 if args.fp16 else torch.float32, + device=torch.cuda.current_device() + ) # we don't need pre-allocation as memory are sufficient - sqk = rqk = None + # sqk = rqk = None CudaTimer().start('q@k') # sqk: (N h) ((nblock-2) blksize) (3 blksize) @@ -297,7 +297,7 @@ def sparse_attn(query: torch.Tensor, CudaTimer().stop('q@k') CudaTimer().start('all_softmax') - # (N h) ((nblock-2) blksize) L + # (N h) (2 blksize) L head_attn = torch.nn.functional.softmax(head, dim=-1) head_attn = torch.nn.functional.dropout(head_attn, dropout_p, True, False) # (N h) ((nblock-2) blksize) (7 blksize) @@ -370,7 +370,7 @@ def dense_attn(query: torch.Tensor, q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d + # v = v.transpose(0, 1) # L (N h) d -> (N h) L d q = q.transpose(0, 1).contiguous() # L (N h) d -> (N h) L d k = k.transpose(0, 1) # L (N h) d -> (N h) L d q = q * scale # (N h) L d, 1 -> (N h) L d @@ -405,7 +405,7 @@ def dense_attn(query: torch.Tensor, CudaTimer().stop('all_softmax') CudaTimer().start('qk@v') - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = torch.bmm(attn, v.transpose(0, 1)) # (N h) L L, (N h) L d -> (N h) L d CudaTimer().stop('qk@v') CudaTimer().start('out_proj') @@ -514,7 +514,7 @@ def random_sample(self): print_each_rank('model weight consumpition:') memory_summary() - CudaTimer(enable=False) + CudaTimer(enable=False).warmup(2) torch.distributed.barrier() iter_num =32 for step in range(iter_num): @@ -534,7 +534,7 @@ def random_sample(self): if step >= 8: CudaTimer().stop('e2e') - torch.cuda.empty_cache() + # torch.cuda.empty_cache() torch.distributed.barrier() if step == 0: From d695de1b9b880b315a8d13dd5941b910897e43be Mon Sep 17 00:00:00 2001 From: lynex Date: Wed, 27 Apr 2022 11:27:54 +0800 Subject: [PATCH 0786/1892] runtime.conv3d fix --- cube/runtime/function/function.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index fe25c860..c0d27aaa 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -29,8 +29,8 @@ def conv3d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso output: N oC oD oH oW """ # switch D, H and W to match torch.nn.functional.pad - padding = [padding[(2 + i) // 2 * (-2) + (i % 2)] for i in range(len(padding))] - input = TorchF.pad(input, padding, 'constant', 0) + pad_padding = [padding[-1 - (i // 2)] for i in range(len(padding) * 2)] + input = TorchF.pad(input, pad_padding, 'constant', 0) return TorchF.conv3d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): From 796b5b6498a6b0da65c0be348e56382e17d31a24 Mon Sep 17 00:00:00 2001 From: lynex Date: Wed, 27 Apr 2022 13:10:45 +0800 Subject: [PATCH 0787/1892] add example wrf2 --- examples/wrf/wrf2.py | 451 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 examples/wrf/wrf2.py diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py new file mode 100644 index 00000000..6d0ffff0 --- /dev/null +++ b/examples/wrf/wrf2.py @@ -0,0 +1,451 @@ +import torch +import torch.nn.functional as F + +torch.set_default_tensor_type(torch.DoubleTensor) + + +class WRF(torch.nn.Module): + def __init__(self, dt, ntau, nz, ny, nx, dz, dy, dx, device): + super().__init__() + # simulation domain settings + self.dt = dt + self.ntau = ntau + self.nx = nx + self.ny = ny + self.nz = nz + self.delta_x = dx + self.delta_y = dy + self.delta_z = dz + + # physics constant + self.g = 9.8 # acceleration of gravity, unit in m/s^2 + self.GAMMA = 1.4 # the ratio of heat capacities for dry air + self.PREF = 101325. # sea level pressure, unit in Pa + self.RD = 287. # gas constant for dry air J*kg^-1*K^-1 + self.RE = 6.4e6 # radius of earth, unit in m + self.OMEGA = 7.292e-5 # angular speed of the Earth s^-1 + + self.device = torch.device(device) + + def init(self, theta, Ptop=250e2): + eta = torch.linspace(0, 1, self.nz + 1, device=self.device) + pi = self.PREF - Ptop + p0 = Ptop + pi * eta + self.p0 = ((p0[:-1] + p0[1:]) / 2).view(self.nz, 1, 1) * torch.ones((1, self.ny, self.nx), device=self.device) + + self.mu0 = torch.ones((self.nz, self.ny, self.nx), device=self.device) * pi + mu1 = torch.zeros((self.nz, self.ny, self.nx), device=self.device) + + self.alpha0 = (self.RD * theta) / self.PREF * (self.p0 / self.PREF)**(-1. / self.GAMMA) + + phi0 = torch.zeros((self.nz + 1, self.ny, self.nx), device=self.device) + phi1 = torch.zeros((self.nz - 1, self.ny, self.nx), device=self.device) + phi0[-1] = self.alpha0[-1] * self.mu0[-1] + for i in range(self.nz - 1, -1, -1): + phi0[i] = self.alpha0[i] * self.mu0[i] * self.delta_z + phi0[i + 1] + self.phi0 = phi0[1:-1] # phi0 with shape (nz - 1, ny, nx) + self.phit = phi0[0].view(1, self.ny, self.nx) # model top hight + self.phis = phi0[-1].view(1, self.ny, self.nx) # earth surface hight + + self.ztop = (self.phit / self.g).view(1, self.ny, self.nx) + + Theta = theta * self.mu0 + + return Theta, phi1, mu1 + + def forward(self, U, V, W, O, Theta, phi1, mu1): + r""" + Args: + U (Tensor): (nz, ny, nx - 1) + V (Tensor): (nz, ny - 1, nx) + W (Tensor): (nz - 1, ny, nx) + O (Tensor): (nz - 1, ny, nx) + Theta (Tensor): (nz, ny, nx) + phi1 (Tensor): (nz - 1, ny, nx) + mu1 (Tensor): (nz, ny, nx) + """ + R_U, R_V, R_W, R_Theta, R_phi, R_mu = self.RHS(U, V, W, O, Theta, phi1, mu1) + U_, V_, W_, O_, Theta_, phi1_, mu1_ = \ + self.step(self.dt / 3, 1, + U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) + + R_U, R_V, R_W, R_Theta, R_phi, R_mu = self.RHS(U_, V_, W_, O_, Theta_, phi1_, mu1_) + # U_, V_, W_, O_, Theta_, phi1_, mu1_ = \ + # self._step(self.dt / 2, self.ntau // 2, + # U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) + U_, V_, W_, O_, Theta_, phi1_, mu1_ = \ + self.step(self.dt / self.ntau, self.ntau // 2, + U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) + + R_U, R_V, R_W, R_Theta, R_phi, R_mu = self.RHS(U_, V_, W_, O_, Theta_, phi1_, mu1_) + U, V, W, O, Theta, phi1, mu1 = \ + self.step(self.dt / self.ntau, self.ntau, + U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) + # U, V, W, O, Theta, phi1, mu1 = \ + # self._step(self.dt, self.ntau, + # U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) + + # print() + # print('R_U\t', R_U.abs().min(), R_U.abs().max()) + # print('R_V\t', R_V.abs().min(), R_V.abs().max()) + # print('R_W\t', R_W.abs().min(), R_W.abs().max()) + # print('R_phi\t', R_phi.abs().min(), R_phi.abs().max()) + # print('R_mu\t', R_mu.abs().min(), R_mu.abs().max()) + # print('R_Theta\t', R_Theta.abs().min(), R_Theta.abs().max()) + + return U, V, W, O, Theta, phi1, mu1 + + def _step(self, dtau, ntau, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu): + U += R_U * dtau + V += R_V * dtau + W += R_W * dtau + Theta += R_Theta * dtau + phi1 += R_phi * dtau + mu1 += R_mu * dtau + + phi = phi1 + self.phi0 + O = self.g * W / self.dz(self.bz(self.pzphi(phi))) + + return U, V, W, O, Theta, phi1, mu1 + + def step(self, dtau, ntau, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu): + # initialize perturbed varibles + U2 = torch.zeros(U.shape, device=self.device) + V2 = torch.zeros(V.shape, device=self.device) + W2 = torch.zeros(W.shape, device=self.device) + O2 = torch.zeros(O.shape, device=self.device) + Theta2 = torch.zeros(Theta.shape, device=self.device) + phi2 = torch.zeros(phi1.shape, device=self.device) + mu2 = torch.zeros(mu1.shape, device=self.device) + pi2 = torch.zeros((self.ny, self.nx), device=self.device) + + phi = self.phi0 + phi1 + mu = self.mu0 + mu1 + alpha = - self.dz(self.pzphi(phi)) / mu + p = self.PREF * (self.RD * Theta / mu / self.PREF / alpha)**self.GAMMA + + for i in range(ntau): + U2, V2, W2, O2, Theta2, phi2, mu2, pi2 = \ + self.ac_step(dtau, + U2, V2, W2, O2, Theta2, phi2, mu2, pi2, + R_U, R_V, R_W, R_Theta, R_phi, R_mu, + U, V, Theta, phi, mu, alpha, p) + + return U + U2, V + V2, W + W2, O + O2, Theta + Theta2, phi1 + phi2, mu1 + mu2 + + def ac_step(self, dtau, + U2, V2, W2, O2, Theta2, phi2, mu2, pi2, + R_U, R_V, R_W, R_Theta, R_phi, R_mu, + U, V, Theta, phi, mu, alpha, p): + r"""one acoustic step""" + # diagnostic variables + alpha2 = - (self.dz(self.pz(phi2)) + alpha * mu2) / mu + cs2 = self.GAMMA * p * alpha # square of sound speed + C = cs2 / mu / alpha**2 + p2 = self.GAMMA * p * (Theta2 / Theta - alpha2 / alpha - mu2 / mu) + theta = Theta / mu + + # prognostic variables + U2_ = U2 + dtau * ( + R_U - self.bx(mu) * ( + self.bx(alpha) * self.dx(p2) + + self.bx(alpha2) * self.dx(self.p0) + + self.dx(self.bz(self.pz(phi2)))) - + self.dx(self.bz(self.pzphi(phi))) * (self.dz(self.bx(self.pzp1(self.bz(p2)))) - self.bx(mu2)) + ) + V2_ = V2 + dtau * ( + R_V - self.by(mu) * ( + self.by(alpha) * self.dy(p2) + + self.by(alpha2) * self.dy(self.p0) + + self.dy(self.bz(self.pz(phi2)))) - + self.dy(self.bz(self.pzphi(phi))) * (self.dz(self.by(self.pzp1(self.bz(p2)))) - self.by(mu2)) + ) + + # W2_ = W2 + dtau * R_W + # O2_ = self.g * W2_ / self.dz(self.bz(self.pzphi(phi))) + # mu2_ = mu2 + dtau * R_mu + + dpi2 = - (self.dx(self.px(U2_ + U)) + self.dy(self.py(V2_ + V))).sum(0) * self.delta_z + pi2 = pi2 + dpi2 * dtau + + O2_ = torch.zeros(O2.shape, device=O2.device) + mu2_ = torch.zeros(mu2.shape, device=mu2.device) + for i in range(1, O2.shape[0] + 1): + O2_[-i] = i * self.delta_z * dpi2 + \ + (self.dx(self.px(U2_)) + self.dy(self.py(V2_)) - R_mu)[-i:].view( + -1, self.ny, self.nx).sum(0) * self.delta_z + for i in range(mu2.shape[0]): + mu2_[i] = pi2 + + # self.O2_ = O2_ + + Theta2_ = Theta2 + dtau * ( + R_Theta + - self.dx(self.px(U2_ * self.bx(theta))) + - self.dy(self.py(V2_ * self.by(theta))) + - self.dz(self.pz(O2_ * self.bz(theta))) + ) + # print('Theta2_:\t', Theta2_.min(), Theta2_.max()) + # Theta2_ = torch.zeros(Theta2_.shape, device=Theta2_.device) + + def f(x): + phi2_ = phi2 + dtau * ( + R_phi - (O2_ * self.dz(self.bz(self.pzphi(phi))) - self.g * (x + W2) * 0.5) / self.bz(mu)) + return ( + R_W + ( + self.dz(C * self.dz(self.pz(phi2_))) + self.dz(self.GAMMA * p * Theta2_ / Theta) - self.bz(mu2_) + + self.dz(C * self.dz(self.pz(phi2))) + self.dz(self.GAMMA * p * Theta2 / Theta) - self.bz(mu2) + ) * 0.5 * self.g + ) * dtau + W2 - x + + W2_ = self.solve_tridiagonal(f) + if torch.abs(f(W2_)).max() > 1e-6: + print("Triangular solver warning:\t", torch.abs(f(W2_)).max()) + W2_ = W2_ / (1 + self.damping(phi, 0.2, self.ztop * 0.75) * dtau) + # print((1 + self.damping(phi, 0.8, self.ztop * 0.75) * dtau)[:, 64, 64]) + + # W2_ = W2 + dtau * R_W + + phi2_ = phi2 + dtau * ( + R_phi - (O2_ * self.dz(self.bz(self.pzphi(phi))) - self.g * (W2_ + W2) * 0.5) / self.bz(mu)) + # phi2_ = phi2 + dtau * R_phi + + return U2_, V2_, W2_, O2_, Theta2_, phi2_, mu2_, pi2 + + def damping(self, phi, gamma, zd): + z = phi / self.g + res = gamma * torch.sin(torch.pi / 2 * (1 - (self.ztop - z) / (self.ztop - zd)))**2 + return res * z.gt(zd).double() + + def RHS(self, U, V, W, O, Theta, phi1, mu1): + mu = self.mu0 + mu1 + phi = self.phi0 + phi1 + alpha = - self.dz(self.pzphi(phi)) / mu + alpha1 = alpha - self.alpha0 + theta = Theta / mu + p = self.PREF * (self.RD * theta / self.PREF / alpha)**self.GAMMA + p1 = p - self.p0 + + u = U / self.bx(mu) + v = V / self.by(mu) + w = W / self.bz(mu) + + R_U = ( + # pressure term + - self.bx(mu) * ( + + self.dx(self.bz(self.pz(phi1))) + + self.bx(alpha) * self.dx(p1) + + self.bx(alpha1) * self.dx(self.p0)) + - self.dx(self.bz(self.pzphi(phi))) * (self.dz(self.bx(self.pzp1(self.bz(p1)))) - self.bx(mu1)) + # advection term + - self.dx(self.bx(self.px(U * u))) + - self.dy(self.bx(self.py(V * self.by(self.bx(self.px(u)))))) + - self.dz(self.bx(self.pz(O * self.bz(self.bx(self.px(u)))))) + ) + R_V = ( + # pressure term + - self.by(mu) * ( + + self.dy(self.bz(self.pz(phi1))) + + self.by(alpha) * self.dy(p1) + + self.by(alpha1) * self.dy(self.p0)) + - self.dy(self.bz(self.pzphi(phi))) * (self.dz(self.by(self.pzp1(self.bz(p1)))) - self.by(mu1)) + # advection term + - self.dx(self.by(self.px(U * self.bx(self.by(self.py(v)))))) + - self.dy(self.by(self.py(V * v))) + - self.dz(self.by(self.pz(O * self.bz(self.by(self.py(v)))))) + ) + R_W = ( + # pressure term + + self.g * (self.dz(p1) - self.bz(self.mu0) * 0.0) - self.bz(mu1) * self.g + # advection term + - self.dx(self.px(self.bz(U) * self.bx(w))) + - self.dy(self.py(self.bz(V) * self.by(w))) + - self.dz(self.bz(self.pz(O * w))) + ) + R_Theta = ( + - self.dx(self.px(U * self.bx(theta))) + - self.dy(self.py(V * self.by(theta))) + - self.dz(self.pz(O * self.bz(theta))) + ) + R_phi = ( + # advection term + - self.bx(self.px(self.bz(U) * self.dx(phi))) + - self.by(self.py(self.bz(V) * self.dy(phi))) + - O * self.dz(self.bz(self.pzphi(phi))) + # gravity term + + self.g * W + ) / self.bz(mu) + R_mu = ( + # advection term + - self.dx(self.px(U)) + - self.dy(self.py(V)) + - self.dz(self.pz(O)) + ) + + return R_U, R_V, R_W, R_Theta, R_phi, R_mu + + def dx(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) / self.delta_x + + def dy(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 2, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) / self.delta_y + + def dz(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) / self.delta_z + + def px(self, X): + return F.pad(X, (1, 1), "circular") + + def py(self, X): + nz, ny, nx = X.shape + return F.pad(X.view(1, nz, ny, nx), (0, 0, 1, 1), "constant").view(nz, ny + 2, nx) + + def pz(self, X): + nz, ny, nx = X.shape + return F.pad(X.view(1, 1, nz, ny, nx), (0, 0, 0, 0, 1, 1), "constant").view(nz + 2, ny, nx) + + def pzphi(self, X): + """pad phi in z axis""" + return torch.cat((self.phit, X, self.phis), 0) + + def pzp1(self, X): + """pad p1 in z axis""" + nz, ny, nx = X.shape + p1t = torch.zeros((1, ny, nx), device=X.device) + p1s = X[-1].view(1, ny, nx) + return torch.cat((p1t, X, p1s), 0) + + def bx(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) / 2. + + def by(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) / 2. + + def bz(self, X): + nz, ny, nx = X.shape + filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) + return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) / 2. + + def solve_tridiagonal(self, f): + r"""Solve tridiagonal system f(x) = Ax - b = 0 + + Args: + f (Callable, return Tensor): Tridiagonal system (nz - 1, ny, nx) -> (nz - 1, ny, nx) + + Returns: + Tensor: Solution of the linear system with shape (D, H, W) + """ + b = - f(torch.zeros((self.nz - 1, self.ny, self.nx), device=self.device)) + + idx0 = torch.tensor([1., 0, 0], device=self.device).view(3, 1, 1) + idx0 = idx0.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] + r0 = f(idx0) + b + + idx1 = torch.tensor([0., 1, 0], device=self.device).view(3, 1, 1) + idx1 = idx1.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] + r1 = f(idx1) + b + + idx2 = torch.tensor([0., 0, 1], device=self.device).view(3, 1, 1) + idx2 = idx2.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] + r2 = f(idx2) + b + + d = (torch.stack([r0, r1, r2], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1) + l = (torch.stack([r2, r0, r1], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1)[1:] + u = (torch.stack([r1, r2, r0], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1)[:-1] + + # forward sweep + for i in range(1, d.shape[0]): + w = l[i - 1] / d[i - 1] + d[i] = d[i] - w * u[i - 1] + b[i] = b[i] - w * b[i - 1] + + # backward substitution + x = torch.zeros(b.shape, device=b.device) + x[-1] = b[-1] / d[-1] + for i in range(x.shape[0] - 2, -1, -1): + x[i] = (b[i] - u[i] * x[i + 1]) / d[i] + + return x + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + from matplotlib.ticker import ScalarFormatter + + nz = 16 + dz = 1. / 16 + ny = 128 + dy = 1e3 + nx = 128 + dx = 1e3 + + dt = 1. + + x = torch.linspace(-1., 1., 128).cuda() + y = torch.linspace(-1., 1., 128).cuda() + theta0 = (torch.linspace(1, 0, nz).cuda() * 500 + 300).view(nz, 1, 1) * torch.ones((1, ny, nx)).cuda() + theta1 = torch.exp(-0.5 * (x / 0.1)**2).view(1, 1, nx) * torch.exp(-0.5 * (y / 0.1)**2).view(1, ny, 1) * 0.01 + wrf = WRF(dt, 10, nz, ny, nx, dz, dy, dx, 'cuda') + Theta, phi1, mu1 = wrf.init(theta0) + Theta += theta1 * wrf.mu0 + + U = torch.zeros((nz, ny, nx - 1)).cuda() + V = torch.zeros((nz, ny - 1, nx)).cuda() + W = torch.zeros((nz - 1, ny, nx)).cuda() + O = torch.zeros((nz - 1, ny, nx)).cuda() + + for i in range(10): + U, V, W, O, Theta, phi1, mu1 = wrf(U, V, W, O, Theta, phi1, mu1) + mu = wrf.mu0 + mu1 + u = U / wrf.bx(mu) + v = V / wrf.by(mu) + w = W / wrf.bz(mu) + o = O / wrf.bz(mu) + theta = Theta / mu + + interval = 1 + if i % interval == 0: + # plt.cla() + # fig, ax = plt.subplots(2, 3, figsize=(12, 6)) + # + # ctf = ax[0, 0].contourf(u[nz // 2, :, :].cpu().numpy(), levels=50, cmap='jet') + # ax[0, 0].set_title('u') + # plt.colorbar(ctf, ax=ax[0, 0], format='%.1e') + # + # ctf = ax[0, 1].contourf(v[nz // 2, :, :].cpu().numpy(), levels=50, cmap='jet') + # ax[0, 1].set_title('v') + # plt.colorbar(ctf, ax=ax[0, 1], format='%.1e') + # + # ctf = ax[0, 2].contourf(w[:, 32, :].cpu().numpy(), levels=50, cmap='jet') + # ax[0, 2].set_title('w') + # plt.colorbar(ctf, ax=ax[0, 2], format='%.1e') + # + # ctf = ax[1, 0].contourf(mu1[nz // 2, :, :].cpu().numpy(), levels=50, cmap='jet') + # ax[1, 0].set_title(r'$\mu^\prime$') + # plt.colorbar(ctf, ax=ax[1, 0], format='%.1e') + # + # ctf = ax[1, 1].contourf(phi1[:, 32, :].cpu().numpy(), levels=50, cmap='jet') + # ax[1, 1].set_title(r'$\phi^\prime$') + # plt.colorbar(ctf, ax=ax[1, 1], format='%.1e') + # + # ctf = ax[1, 2].contourf(o[:, 32, :].cpu().numpy(), levels=50, cmap='jet') + # ax[1, 2].set_title(r'$\omega$') + # plt.colorbar(ctf, ax=ax[1, 2], format='%.1e') + # + # fig.text(0.01, 0.95, f't={i * dt}s', size=18) + # plt.tight_layout() + # plt.savefig(f'res/res{i // interval}.jpeg', dpi=300) + # plt.close() + # plt.clf() + + print(i) From 6b291573ad000bac49ad4e1a925738522d9bc78e Mon Sep 17 00:00:00 2001 From: lynex Date: Fri, 29 Apr 2022 17:27:39 +0800 Subject: [PATCH 0788/1892] plan revisit with example --- .../playground/dag/graph_manipulation.py | 429 ++++++++++++++++++ 1 file changed, 429 insertions(+) create mode 100644 handcraft/playground/dag/graph_manipulation.py diff --git a/handcraft/playground/dag/graph_manipulation.py b/handcraft/playground/dag/graph_manipulation.py new file mode 100644 index 00000000..19909793 --- /dev/null +++ b/handcraft/playground/dag/graph_manipulation.py @@ -0,0 +1,429 @@ +from enum import Enum + +# class NodeType(Enum): +# UNKNOWN = 0 +# DATALOADER = 1 +# FORWARD = 2 +# BACKWARD_A = 3 +# BACKWARD_W = 4 +# OPTIMIZER = 5 + +nodeList = [] +global_node_id = -1 + + +def new_node_id(): + global global_node_id + global_node_id += 1 + return global_node_id + + +def last_node(last_step=1): + assert len(nodeList) >= last_step + return nodeList[-last_step] + + + + +class AlgorithmMgr: + batch_split: str + replica: str + + def __init__(self): + self.batch_split = 'batch_split' + self.replica = 'replica' + +class Operator: + algo: AlgorithmMgr + def __init__(self): + self.algo = AlgorithmMgr() + +class Node: + id: int + inputs: [] + outputs: [] + removed: bool + op: Operator + + def __init__(self): + super().__init__() + self.removed = False + self.id = new_node_id() + self.inputs = [] + self.outputs = [] + self.op = Operator() + nodeList.append(self) + + def spawn(self, portion: str=None): + node = self.__class__() #create same type + node.inputs = [t.spawn() for t in self.inputs] + node.outputs = [t.spawn() for t in self.outputs] + return node + + def __str__(self): + return "Node({}), {}\tinput:{}\toutput:{} ".format( + self.id, str(type(self)).lstrip(''), + '\t'.join([str(x) for x in self.inputs] if len(self.inputs) > 0 else ""), + '\t'.join([str(x) for x in self.outputs] if len(self.outputs) > 0 else "")) + + +class NodeData(Node): + def __init__(self): + super(NodeData, self).__init__() + # self.type = NodeType.DATALOADER + + +class NodeFwd(Node): + def __init__(self): + super(NodeFwd, self).__init__() + # self.type = NodeType.FORWARD + + +class NodeBwd(Node): + def __init__(self): + super(NodeBwd, self).__init__() + + +class NodeBwdA(NodeBwd): + def __init__(self): + super(NodeBwd, self).__init__() + # self.type = NodeType.BACKWARD_A + + +class NodeBwdW(NodeBwd): + def __init__(self): + super(NodeBwd, self).__init__() + # self.type = NodeType.BACKWARD_W + + +class NodeOpt(Node): + def __init__(self): + super(NodeOpt, self).__init__() + # self.type = NodeType.OPTIMIZER + + +# for logic tensor +class TensorType(Enum): + UNKNOWN = 0 + WEIGHT = 1 + WEIGHT_UPDATED = 2 + ACTIVATION = 3 + GRADIENT_A = 4 + GRADIENT_W = 5 + OPTIMIZER_STATE = 6 + LOSS = 7 + + +logicTensorList = [] +global_logic_tensor_id = -1 + + +def new_logic_tensor_id(): + global global_logic_tensor_id + global_logic_tensor_id += 1 + return global_logic_tensor_id + + +def last_logic_tensor(last_step=1): + assert len(logicTensorList) >= last_step + return logicTensorList[-last_step] + + +class LogicTensor: + id: int + type: TensorType + + def __init__(self, tensor_type=TensorType.UNKNOWN): + super().__init__() + self.id = new_logic_tensor_id() + self.type = tensor_type + + +tensorList = [] +global_tensor_id = -1 + + +def new_tensor_id(): + global global_tensor_id + global_tensor_id += 1 + return global_tensor_id + + +def last_tensor(last_step=1): + assert len(tensorList) >= last_step + return tensorList[-last_step] + + +class Tensor: + id: int + logic: LogicTensor + portion: str + + def new(self): + pass + + def __init__(self, tensor_type=TensorType.UNKNOWN, exist_tensor=None, portion=None): + super().__init__() + self.id = new_tensor_id() + if exist_tensor is None: + self.logic = LogicTensor(tensor_type) + self.portion = 'full' + else: + self.logic = exist_tensor.logic + self.portion = exist_tensor.portion + if portion is not None: + self.portion += '>' + portion + tensorList.append(self) + + def __getattr__(self, attr): + if(attr == 'type'): + return self.logic.type + else: + return self.attr + + def __str__(self): + return ("Tensor({}), {} of ({} {})".format( + self.id, + self.portion, + self.logic.id, + str(self.type).lstrip('TensorType.'))) + + def spawn(self, portion:str=None): + return Tensor(exist_tensor=self, portion=portion) + +class Graph: + nodes: [] + + def find_input(self, node: Node, tensor_type: TensorType): + ret = list(filter(lambda x: x.type == tensor_type, node.inputs)) + assert len(ret) > 0 + return ret[0] + + def create_sample_graph(self): + op_num = 2 + + for idx in range(1): # sample data loader + node = NodeData() # Node(NodeType.DATALOADER) + node.outputs.append(Tensor(TensorType.ACTIVATION)) + self.nodes.append(node) + + for idx in range(op_num): # forward ops + node = NodeFwd() # Node(NodeType.FORWARD) + node.inputs.append(last_tensor()) + node.inputs.append(Tensor(TensorType.WEIGHT)) + node.outputs.append(Tensor(TensorType.ACTIVATION)) + self.nodes.append(node) + + for idx in range(1): # label data loader + node = NodeData() # Node(NodeType.DATALOADER) + node.outputs.append(Tensor(TensorType.ACTIVATION)) + self.nodes.append(node) + + for idx in range(1): # loss + node = NodeFwd() # Node(NodeType.FORWARD) + node.inputs.append(last_tensor()) + node.outputs.append(Tensor(TensorType.LOSS)) + self.nodes.append(node) + + for fwd_node in list(filter(lambda x: type(x) is NodeFwd, self.nodes))[::-1]: # backward ops + out_gradient = last_tensor() + if len(fwd_node.inputs) == 2: + # computing weight's gradient + node = NodeBwdW() # Node(NodeType.BACKWARD_W) + node.inputs.append(out_gradient) # out_g_act + node.inputs.append(self.find_input(fwd_node, TensorType.WEIGHT)) + node.outputs.append(Tensor(TensorType.GRADIENT_W)) + self.nodes.append(node) + if len(fwd_node.inputs) >= 1: + # computing activation's gradient + node = NodeBwdA() # Node(NodeType.BACKWARD_A) + node.inputs.append(out_gradient) + node.inputs.append(self.find_input(fwd_node, TensorType.ACTIVATION)) + node.outputs.append(Tensor(TensorType.GRADIENT_A)) + self.nodes.append(node) + else: + assert False + + for bwd_w_node in list(filter(lambda x: type(x) is NodeBwdW, self.nodes)): # optimizer + node = NodeOpt() # Node(NodeType.OPTIMIZER) + node.inputs.append(self.find_input(bwd_w_node, TensorType.WEIGHT)) # WEIGHT + node.inputs.append(bwd_w_node.outputs[0]) + node.inputs.append(Tensor(TensorType.OPTIMIZER_STATE)) + node.outputs.append(Tensor(TensorType.WEIGHT_UPDATED)) + self.nodes.append(node) + + def __init__(self, create_sample=False): + super().__init__() + self.nodes = [] + + if create_sample: + self.create_sample_graph() + + def __str__(self): + # for node in self.nodes: + return '\n'.join([str(x) if not x.removed else "DEL "+str(x) for x in self.nodes]) + + +graph = Graph(create_sample=True) +print('graph = \n{}'.format(graph)) +global_new_graph = Graph() + +print('nodeList[{}] = \n{}'.format(len(nodeList), nodeList)) +print('tensorList[{}] = \n{}'.format(len(tensorList), tensorList)) + + +class Config: + num: int + + +class Device: + pass + + +class Parallelizer: + def run(self, g: Graph, config: Config) -> Graph: + return None + + +def trans(node: Node, algo, num: int) -> [Node]: + node.removed = True + nodes = [node.spawn() for i in range(num)] + if algo == 'replica': + global_new_graph.nodes.extend(nodes) + return nodes + elif algo == 'batch_split': + for idx, nd in enumerate(nodes): + for ts in nd.inputs + nd.outputs: + ts.portion += '>batch-{}/{}'.format(idx, num) + global_new_graph.nodes.extend(nodes) + return nodes + else: + assert False + + + +def sched_s(node: Node, dev: Device) -> None: + print("sched_s...") + pass + + +def sched_t(node_before: Node, node_after: Node) -> bool: + print("sched_t...") + pass + + +def sched_t(nodes: [Node]) -> bool: + pass + + +def set_affinity(): + pass + +from collections import namedtuple +def idxzip(list: []): + Entry = namedtuple('Entry', ['idx', 'item']) + # return [{'idx': i, 'item': x} for i, x in enumerate(list)] + return [Entry(i, x) for i, x in enumerate(list)] + + +### TODO how about Tx in flexflow GSPMD etc.? + +# traditional data-parallel process: +# tx start +# 1. replicated graph g -> g' * N +# 2. change batch-size of g' +# 3. insert gradient allreduce (manually, can auto-gen in our sys) +# tx end + + +class DataParallelParallelizer(Parallelizer): + def run(self, g: Graph, config: Config) -> Graph: + global global_new_graph + global_new_graph.nodes.clear() + + for node in g.nodes: + if isinstance(node, (NodeData, NodeFwd, NodeBwd)): + nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit + map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + elif isinstance(node, (NodeOpt)): + nodes = trans(node, node.op.algo.replica, config.num) + map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + else: + print(node) + print(type(node)) + assert False + + global_new_graph.nodes.extend([nd for nd in graph.nodes if not nd.removed]) + return global_new_graph + + +class DataParallelZeROParallelizer(Parallelizer): + def run(self, g: Graph, config: Config) -> Graph: + for node in g.nodes: + if type(node) in [NodeData, NodeFwd, NodeBwd]: + nodes = node.trans(node.op.algo.batch_split, config.num) # by batch-dim-slit + map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + elif type(node) in [NodeOpt]: + nodes = node.trans(node.op.algo.split, config.num) + map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + else: + assert False + + +class GradientAccumulationParallelizer(Parallelizer): + def run(self, g: Graph, config: Config) -> Graph: + for node in g.nodes: + if type(node) in [NodeData, NodeFwd, NodeBwd]: + nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit + sched_t(nodes) # sequential order + elif type(node) in [NodeOpt]: + pass + else: + assert False + + +def node_to_stage(g: Graph, config: Config) -> {}: # return node->stage mapping + pass + + +class GPipeParallelizer(Parallelizer): + def run(self, g: Graph, config: Config) -> Graph: + n2stage = node_to_stage(g, config) + for node in g.nodes: + device = n2stage(node) + if type(node) in [NodeData, NodeFwd, NodeBwd]: + nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit + sched_t(nodes) # sequential order + map(lambda x: sched_s(node=x, dev=device), nodes) # assign same stage device + elif type(node) in [NodeOpt]: + sched_s(node, device) + else: + assert False + + +class TensorParallelParallelizer(Parallelizer): + def run(self, g: Graph, config: Config) -> Graph: + for node in g.nodes: + if type(node) in [NodeFwd, NodeBwd, NodeOpt]: + nodes = trans(node, node.op.algo.tensor_split, config.num) # by tensor-dim-slit + map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + elif type(node) in [NodeData]: + nodes = trans(node, node.op.algo.replica, config.num) + map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + else: + assert False + + +class Recompute(Parallelizer): + def run(self, g: Graph, config: Config) -> Graph: + for node in g.nodes: + if type(node) in [NodeFwd]: + nodes = trans(node, node.op.algo.replica, 2) + set_affinity() # break dependencies op0.fwd -> op1.fwd; op0.fwd' -> op0.bwd + + +para = DataParallelParallelizer() +config = Config() +config.num = 2 +global_new_graph = para.run(graph, config) +print('new_graph = \n{}'.format(global_new_graph)) From cdefd41251ce0262fb5807f07609032138b5e976 Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 5 May 2022 11:54:28 +0800 Subject: [PATCH 0789/1892] nit api --- .../playground/dag/graph_manipulation.py | 192 +++++++++++++++--- 1 file changed, 159 insertions(+), 33 deletions(-) diff --git a/handcraft/playground/dag/graph_manipulation.py b/handcraft/playground/dag/graph_manipulation.py index 19909793..b7d931fc 100644 --- a/handcraft/playground/dag/graph_manipulation.py +++ b/handcraft/playground/dag/graph_manipulation.py @@ -1,4 +1,5 @@ from enum import Enum +import sys # class NodeType(Enum): # UNKNOWN = 0 @@ -8,6 +9,18 @@ # BACKWARD_W = 4 # OPTIMIZER = 5 +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + nodeList = [] global_node_id = -1 @@ -32,6 +45,8 @@ class AlgorithmMgr: def __init__(self): self.batch_split = 'batch_split' self.replica = 'replica' + self.split = 'split' + self.tensor_split = 'tensor_split' class Operator: algo: AlgorithmMgr @@ -67,6 +82,10 @@ def __str__(self): '\t'.join([str(x) for x in self.outputs] if len(self.outputs) > 0 else "")) + def slim(self): + return "Node({}), {}".format( + self.id, str(type(self)).lstrip('')) + class NodeData(Node): def __init__(self): super(NodeData, self).__init__() @@ -268,8 +287,8 @@ def __str__(self): print('graph = \n{}'.format(graph)) global_new_graph = Graph() -print('nodeList[{}] = \n{}'.format(len(nodeList), nodeList)) -print('tensorList[{}] = \n{}'.format(len(tensorList), tensorList)) +# print('nodeList[{}] = \n{}'.format(len(nodeList), nodeList)) +# print('tensorList[{}] = \n{}'.format(len(tensorList), tensorList)) class Config: @@ -297,28 +316,45 @@ def trans(node: Node, algo, num: int) -> [Node]: ts.portion += '>batch-{}/{}'.format(idx, num) global_new_graph.nodes.extend(nodes) return nodes + elif algo == 'split': #elementwise split + for idx, nd in enumerate(nodes): + for ts in nd.inputs + nd.outputs: + ts.portion += '>flat-{}/{}'.format(idx, num) + global_new_graph.nodes.extend(nodes) + return nodes + elif algo == 'tensor_split': + for idx, nd in enumerate(nodes): + for ts in nd.inputs + nd.outputs: + ts.portion += '>tensor-{}/{}'.format(idx, num) + global_new_graph.nodes.extend(nodes) + return nodes else: assert False - def sched_s(node: Node, dev: Device) -> None: - print("sched_s...") + print("{}sched_s...{} @ {}{}".format(bcolors.OKGREEN, node.slim(), dev, bcolors.ENDC)) pass -def sched_t(node_before: Node, node_after: Node) -> bool: - print("sched_t...") - pass +def sched_t_pair(node_before: Node, node_after: Node) -> bool: + print("{}sched_t...{}-> {}{}".format(bcolors.OKBLUE, node_before.slim(), node_after.slim(), bcolors.ENDC)) + #TODO legal check + return True def sched_t(nodes: [Node]) -> bool: - pass + for i in range (len(nodes) - 1): + if not sched_t_pair(nodes[i], nodes[i+1]): + return False + return True -def set_affinity(): +def set_affinity(producer_node, consumer_node): + print("{}affinity...{}-> {}{}".format(bcolors.OKCYAN, producer_node.slim(), consumer_node.slim(), bcolors.ENDC)) pass + from collections import namedtuple def idxzip(list: []): Entry = namedtuple('Entry', ['idx', 'item']) @@ -326,6 +362,11 @@ def idxzip(list: []): return [Entry(i, x) for i, x in enumerate(list)] +def xmap(func, iterables): + if list(map(func, iterables)) is None: + print('xmap ERROR') + + ### TODO how about Tx in flexflow GSPMD etc.? # traditional data-parallel process: @@ -341,89 +382,174 @@ def run(self, g: Graph, config: Config) -> Graph: global global_new_graph global_new_graph.nodes.clear() + # ---------------- for node in g.nodes: if isinstance(node, (NodeData, NodeFwd, NodeBwd)): - nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit - map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-split + xmap(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) elif isinstance(node, (NodeOpt)): - nodes = trans(node, node.op.algo.replica, config.num) - map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + nodes = trans(node, node.op.algo.replica, config.num) #replicated optimizers + xmap(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) else: print(node) print(type(node)) assert False + # ---------------- - global_new_graph.nodes.extend([nd for nd in graph.nodes if not nd.removed]) + global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] return global_new_graph class DataParallelZeROParallelizer(Parallelizer): def run(self, g: Graph, config: Config) -> Graph: + global global_new_graph + global_new_graph.nodes.clear() + + # ---------------- for node in g.nodes: - if type(node) in [NodeData, NodeFwd, NodeBwd]: - nodes = node.trans(node.op.algo.batch_split, config.num) # by batch-dim-slit + if isinstance(node, (NodeData, NodeFwd, NodeBwd)): + nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-split map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) - elif type(node) in [NodeOpt]: - nodes = node.trans(node.op.algo.split, config.num) + elif isinstance(node, (NodeOpt)): + nodes = trans(node, node.op.algo.split, config.num) #split optimizers map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) else: assert False + # ---------------- + + global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] + return global_new_graph class GradientAccumulationParallelizer(Parallelizer): def run(self, g: Graph, config: Config) -> Graph: + global global_new_graph + global_new_graph.nodes.clear() + + # ---------------- for node in g.nodes: - if type(node) in [NodeData, NodeFwd, NodeBwd]: + if isinstance(node, (NodeData, NodeFwd, NodeBwd)): nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit sched_t(nodes) # sequential order - elif type(node) in [NodeOpt]: + elif isinstance(node, (NodeOpt)): pass else: assert False + # ---------------- + + global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] + return global_new_graph def node_to_stage(g: Graph, config: Config) -> {}: # return node->stage mapping - pass + ret = {} + nodes = g.nodes # TODO topo forward traversal + fwd_node = list(filter(lambda x: type(x) is NodeFwd, nodes)) + + per_stage_size = len(nodes) // config.stages + for node in nodes: + # TODO replace dummy assignment + ret[node] = 0 + + return ret class GPipeParallelizer(Parallelizer): def run(self, g: Graph, config: Config) -> Graph: + global global_new_graph + global_new_graph.nodes.clear() + + # ---------------- n2stage = node_to_stage(g, config) for node in g.nodes: - device = n2stage(node) - if type(node) in [NodeData, NodeFwd, NodeBwd]: + device = n2stage[node] + if isinstance(node, (NodeData, NodeFwd, NodeBwd)): nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit sched_t(nodes) # sequential order - map(lambda x: sched_s(node=x, dev=device), nodes) # assign same stage device - elif type(node) in [NodeOpt]: + xmap(lambda x: sched_s(node=x, dev=device), nodes) # assign same stage device + elif isinstance(node, (NodeOpt)): sched_s(node, device) else: assert False + # ---------------- + + global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] + return global_new_graph class TensorParallelParallelizer(Parallelizer): def run(self, g: Graph, config: Config) -> Graph: + global global_new_graph + global_new_graph.nodes.clear() + + # ---------------- for node in g.nodes: - if type(node) in [NodeFwd, NodeBwd, NodeOpt]: + if isinstance(node, (NodeFwd, NodeBwd, NodeOpt)): nodes = trans(node, node.op.algo.tensor_split, config.num) # by tensor-dim-slit - map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) - elif type(node) in [NodeData]: + xmap(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + elif isinstance(node, (NodeData)): nodes = trans(node, node.op.algo.replica, config.num) - map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + xmap(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) else: assert False + # ---------------- + + global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] + return global_new_graph + + +def find_consumers(graph: Graph, tensor: Tensor): + ret = [] + for node in graph.nodes: + if any([input_tensor.logic == tensor.logic for input_tensor in node.inputs]): + ret.append(node) + return ret + +def find_producers(graph: Graph, tensor: Tensor): + ret = [] + for node in graph.nodes: + if any([output_tensor.logic == tensor.logic for output_tensor in node.outputs]): + ret.append(node) + return ret class Recompute(Parallelizer): def run(self, g: Graph, config: Config) -> Graph: + global global_new_graph + global_new_graph.nodes.clear() + + # ---------------- for node in g.nodes: - if type(node) in [NodeFwd]: - nodes = trans(node, node.op.algo.replica, 2) - set_affinity() # break dependencies op0.fwd -> op1.fwd; op0.fwd' -> op0.bwd + if isinstance(node, (NodeFwd)): + origin_fwd, recompute_fwd = trans(node, node.op.algo.replica, 2) + consumers = find_consumers(g, origin_fwd.outputs[0]) + for consumer in consumers: + if isinstance(consumer, NodeFwd): + set_affinity(origin_fwd, consumer) # break dependencies op0.fwd -> op1.fwd; op0.fwd' -> op0.bwd + else: + set_affinity(recompute_fwd, consumer) # break dependencies op0.fwd -> op1.fwd; op0.fwd' -> op0.bwd + producers = list(filter(lambda x: isinstance(x, NodeBwd), find_producers(g, consumer.inputs[0]))) + for producer in producers: + sched_t_pair(producer, recompute_fwd) + # ---------------- + + global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] + return global_new_graph + +class ActivationSwap(Parallelizer): + def run(self, g: Graph, config: Config) -> Graph: + pass + +# para = DataParallelParallelizer() +# para = DataParallelZeROParallelizer() +# para = GradientAccumulationParallelizer() +# para = GPipeParallelizer() +# para = TensorParallelParallelizer() +para = Recompute() -para = DataParallelParallelizer() config = Config() config.num = 2 +config.stages = 2 global_new_graph = para.run(graph, config) print('new_graph = \n{}'.format(global_new_graph)) From 234dcdaea163d1f029f266770d055541c7cc21de Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 May 2022 16:53:25 +0800 Subject: [PATCH 0790/1892] init composer --- cube/search/{__init__.PY => __init__.py} | 0 cube/search/composer.py | 283 +++++++++++++++++++++++ requirements.txt | 5 +- 3 files changed, 285 insertions(+), 3 deletions(-) rename cube/search/{__init__.PY => __init__.py} (100%) create mode 100644 cube/search/composer.py diff --git a/cube/search/__init__.PY b/cube/search/__init__.py similarity index 100% rename from cube/search/__init__.PY rename to cube/search/__init__.py diff --git a/cube/search/composer.py b/cube/search/composer.py new file mode 100644 index 00000000..3b9a40e0 --- /dev/null +++ b/cube/search/composer.py @@ -0,0 +1,283 @@ +""" +Abstraction layer for microb-batch execution plan merge. +""" + +from typing import Dict, List, Tuple +import numpy as np +from enum import Enum + + +class Block: + """ + Execution block for a MicroPlan + """ + + class BType(Enum): + FW = 'forward' + BW = 'backward' + + def __init__(self, mid: int, pos: Tuple[int, int], btype: BType): + self.mid: int = mid + self.type = btype + self._position = tuple(pos) + # dependency track + self.before: List[Block] = list() + self.after: List[Block] = list() + + @property + def position(self): + return self._position + + @position.setter + def position(self, pos: Tuple[int, int]): + if len(pos) != 2: + raise ValueError("Expected positition to be Tuple[int, int]") + self._position = pos + + @staticmethod + def add_dependency(before, after): + if not (isinstance(before, Block) and isinstance(after, Block)): + raise ValueError("Expected before and after to be Block") + if after not in before.after: + before.after.append(after) + if before not in after.before: + after.before.append(before) + + def __repr__(self): + return f'f{self.mid}' if self.type == Block.BType.FW else f'b{self.mid}' + + +class MicroPlan: + + def __init__(self, mid: int, ndevs: int): + """ + Create an empty microbatch execution plan + + mid: microbatch id + ndevs: number of devices + """ + self.mid = mid + self.blocks: Dict[Tuple[int, int], Block] = dict() + self.execplan = np.zeros((ndevs, ndevs * 2), dtype=int) + + @property + def ndevs(self): + return self.execplan.shape[0] + + @property + def nsteps(self): + return self.execplan.shape[1] + + def expand_to(self, nsteps: int): + if self.nsteps < nsteps: + extend = nsteps - self.nsteps + self.execplan = np.pad(self.execplan, ((0,0),(0,extend))) + + def block(self, dev: int, step: int): + if (dev, step) not in self.blocks: + return None + return self.blocks[(dev, step)] + + def add_block(self, pos: Tuple[int, int], btype: Block.BType) -> Block: + """ + Add a execution block + """ + dev, step = pos + if dev >= self.ndevs: + raise ValueError("device out of scope") + if step >= self.nsteps: + self.expand_to(step + 1) + if self.execplan[dev, step] != 0: + raise ValueError(f"Postition {pos} already has blocks") + block = Block(self.mid, pos, btype) + self.execplan[dev, step] += 1 + self.blocks[(dev, step)] = block + return block + + def add_dependency(self, blocks: List[Block]): + """ + Add dependent blocks: + block[0] < block[1] < block[2] < ... + """ + for idx in range(len(blocks) - 1): + Block.add_dependency(blocks[idx], blocks[idx+1]) + + def shift(self, block: Block): + """ + The primitive during search + """ + # check block in this plan + if block not in self.blocks.values(): + raise ValueError("Block not in this micro plan") + dev, step = block.position + for after_block in block.after: + if step + 1 == after_block.position[1]: + self.shift(after_block) + self.execplan[dev, step] = 0 + if step + 1 >= self.nsteps: + self.expand_to(self.nsteps + 1) + self.execplan[dev, step+1] = 1 + # update block and self.blocks + block.position = (dev, step+1) + del self.blocks[(dev, step)] + self.blocks[(dev, step+1)] = block + + def unshift(self, block: Block): + """ + reverse shift, for search only + """ + dev, step = block.position + if step == 0: + raise ValueError("unshift a block with step = 0") + # shift back + self.execplan[dev, step] = 0 + self.execplan[dev, step-1] = 1 + block.position = (dev, step-1) + del self.blocks[(dev, step)] + self.blocks[(dev, step-1)] = 1 + # shift back shifted blocks + for after_block in block.after: + if step + 1 == after_block.position[1]: + self.unshift(after_block) + + def __repr__(self): + namelen = 2 + self.mid // 10 + dscp = '' + for dev in range(self.ndevs): + for step in range(self.nsteps): + block = self.block(dev, step) + if block is None: + dscp += '-' * namelen + ' ' + else: + # TODO: 2 replace to namelen + dscp += '{: <2}'.format(repr(block)) + ' ' + dscp += '\n' + return dscp + + +class SchedulePlan: + + def __init__(self, micros: List[MicroPlan]): + self.micros = micros + + # get schedules + max_steps = max(micro.nsteps for micro in micros) + for micro in micros: + micro.expand_to(max_steps) + plans = tuple(micro.execplan for micro in micros) + schedule = np.sum(np.stack(plans, axis=-1), axis=-1) + if len(np.where(self.schedule > 1)[0]) > 0: + raise ValueError("micro plans are not composable") + # cut off redundant steps + for idx in range(schedule.shape[1]): + if np.sum(schedule[:, -idx-1]) != 0: + break + self.schedule = schedule[:, :-idx] if idx > 0 else schedule + + # get blocks + self.blocks = dict() + for micro in micros: + self.blocks.update(micro.blocks) + + @property + def ndevs(self): + return self.schedule.shape[0] + + @property + def nsteps(self): + return self.schedule.shape[1] + + def block(self, dev: int, step: int): + if (dev, step) not in self.blocks: + return None + return self.blocks[(dev, step)] + + @staticmethod + def composable(micros: List[MicroPlan]) -> bool: + max_steps = max(micro.nsteps for micro in micros) + for micro in micros: + micro.expand_to(max_steps) + plans = tuple(micro.execplan for micro in micros) + schedule = np.sum(np.stack(plans, axis=-1), axis=-1) + devids = np.where(schedule > 1)[0] + return len(devids) == 0 + + @staticmethod + def conflict(micros: List[MicroPlan], step: int) -> bool: + max_steps = max(micro.nsteps for micro in micros) + for micro in micros: + micro.expand_to(max_steps) + plans = tuple(micro.execplan[:,step] for micro in micros) + schedule = np.sum(np.stack(plans, axis=-1), axis=-1) + cmicros = [] + cblocks = [] + devids, steps = np.where(schedule > 1) + for dev, step in zip(devids, steps): + for micro in micros: + if micro.block[dev, step] is not None: + cmicros.append(micro) + cblocks.append(micro.block[dev, step]) + return cmicros, cblocks + + def __repr__(self): + nmicros = len(self.micros) + namelen = 2 + nmicros // 10 + dscp = '' + for dev in range(self.ndevs): + for step in range(self.nsteps): + block = self.block(dev, step) + if block is None: + dscp += '-' * namelen + ' ' + else: + # TODO: 2 replace to namelen + dscp += '{: <2}'.format(repr(block)) + ' ' + dscp += '\n' + return dscp + + def visualize(self, outfile=None): + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + plt.close('all') + + +class Composer: + + @staticmethod + def premise(fn, ndevs: int): + micros = fn(ndevs) + return micros + + @staticmethod + def schedule(micros, step=0): + # DFS search + while not SchedulePlan.composable(micros): + cmicros, cblocks = SchedulePlan.conflict(micros, step) + if len(cmicros) == 0: + step += 1 + else: + for micro, block in zip(cmicros, cblocks): + micro.shift(block) + Composer.schedule(micros, step=step) + micro.unshift(block) + print(f'search a plan with step {step}') + + + +if __name__ == '__main__': + + def uniform_staging(ndevs: int, nmicros=4): + micros = [] + for mid in range(nmicros): + micro = MicroPlan(mid, ndevs) + fblocks = [micro.add_block((sid, sid), Block.BType.FW) for sid in range(ndevs)] + bblocks = [micro.add_block((ndevs-1-sid, sid+ndevs), Block.BType.BW) for sid in range(ndevs)] + blocks = fblocks + bblocks + micro.add_dependency(blocks) + micros.append(micro) + return micros + + ndevs = 4 + micros = Composer.premise(uniform_staging, ndevs) + for mid, micro in enumerate(micros): + print(f'Microbatch #{mid}:') + print(micro) diff --git a/requirements.txt b/requirements.txt index 4e316892..f32bdd38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ -z3-solver -matplotlib -einops \ No newline at end of file +einops +matplotlib \ No newline at end of file From 49dab7d5d3ae3c4c2d539c79666b96a41c193e68 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 May 2022 16:54:27 +0800 Subject: [PATCH 0791/1892] remove legacy --- cube/search/piper.py | 169 ------------------------------------------- 1 file changed, 169 deletions(-) delete mode 100644 cube/search/piper.py diff --git a/cube/search/piper.py b/cube/search/piper.py deleted file mode 100644 index 919c4e14..00000000 --- a/cube/search/piper.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Abstraction layer for microb-batch execution plan merge. -""" - -from typing import Any, Dict, List, Tuple -import numpy as np - - -class MicroPlan: - - def __init__(self, plan: np.ndarray, name: str = None, summation=None): - """ - positions: - List of [spatial, temporal] slots to anchor the action - """ - assert len(plan.shape) == 2 - self.name = name - self.plan = plan - self.summation = [self] if summation is None else summation - - @property - def ndevs(self): - return self.plan.shape[0] - - @property - def nsteps(self): - return self.plan.shape[1] - - def valid(self) -> bool: - """ - Check runnability - """ - return np.max(self.plan) <= 1 - - def __add__(self, other): - if not isinstance(other, MicroPlan): - raise TypeError("Expect MicroPlan") - lhs, rhs = self, other - ndevs = max(lhs.ndevs, rhs.ndevs) - nsteps = max(lhs.nsteps, rhs.nsteps) - lhs_plan = np.pad( - lhs.plan, ((0, ndevs-lhs.ndevs),(0, nsteps-lhs.nsteps)) - ) - rhs_plan = np.pad( - rhs.plan, ((0, ndevs-rhs.ndevs), (0, nsteps-rhs.nsteps)) - ) - plan = lhs_plan + rhs_plan - if np.max(plan) <= 1: - return (True, MicroPlan(plan, summation=lhs.summation+rhs.summation)) - else: - # find conflict - sidx, tidx = (plan > 1).nonzero() - return (False, (sidx, tidx)) - - def stall(self, step: int): - """ - Primitive: insert a stall at stepline index `step` - """ - slots = np.zeros((self.ndevs, 1), dtype=int) - self.plan = np.insert(self.plan, slice(step, step+1), slots, axis=1) - return True - - def shift(self, position: Tuple[int, int], distance: int) -> bool: - """ - shift the task at position to later (+) or previous (-) steps - - MicroPlan requires there is no more than one task on same temporal slot - - Args: - position: tuple of (spatial_idx (row), step_idx (column)) - """ - s, t = position - if self.plan[s][t] != 1: - raise KeyError("No task is on this possition") - if t + distance < 0: - return False - if distance == 0: - return True - if distance > 0: - slots = np.zeros((self.ndevs, distance), dtype=int) - self.plan = np.insert(self.plan, slice(t, t+distance), slots, axis=1) - return True - if distance < 0: - slots = self.plan[:,t+distance:t] - if np.max(slots) != 0: - return False - self.plan = np.delete(self.plan, slice(t+distance, t), axis=1) - return True - return False - - def __repr__(self): - return repr(self.plan) - - -def create_microbatch(n_stage: int, n_dev: int, placement: List[int], name=None): - plan = np.zeros((n_dev, n_stage * 2), dtype=int) - for sid, devid in enumerate(placement): - # forward - plan[devid, sid] += 1 - # backward - plan[devid, 2 * n_stage - 1 - sid] += 1 - return MicroPlan(plan, name) - - -def get_conflict(micros: List[MicroPlan], step: int): - """ - Get conflicting postition at temporal step T - """ - plans = [] - for micro in micros: - if step >= micro.nsteps: - plans.append(np.zeros((micro.ndevs, 1), dtype=int)) - else: - plans.append(micro.plan[:,step:step+1]) - # [ndev, nmicros] - plans = np.hstack(tuple(plans)) - # devid [int] -> (micro_id, step) - conflicts = dict() - # conflict device ids - devids = np.where(np.sum(plans, axis=1) > 1)[0] - for devid in devids: - positions = plans[devid].nonzero()[0] - positions = [(mid, step) for mid in positions] - conflicts[devid] = positions - return conflicts - - -def solve(micros: List[MicroPlan], conflicts: Dict[int, Tuple[int, int]]): - # always address first conflicts - print(f'solve conflicts: {conflicts}') - devid = list(conflicts.keys())[0] - mid, tid = conflicts[devid][0] - print(f'select device: {devid}, micro id: {mid}, step: {tid} to solve') - micros[mid].stall(tid) - # micros[mid].shift((devid, tid), 1) - print(f'shift results: microbatch-{mid}') - print(micros[mid]) - return (mid, devid, tid) - - -def search(n_microbatch: int, n_stage: int, n_dev: int): - placement = [sid % n_dev for sid in range(n_stage)] - micros = [create_microbatch(n_stage, n_dev, placement, name=mid) for mid in range(n_microbatch)] - tidx = 0 - #TODO: justify: why firstly sovle early-step conflicts - while tidx < max([micro.nsteps for micro in micros]): - while True: - # conflict point Dict[device_id, (mid, step_id)] - conflicts = get_conflict(micros, step=tidx) - if len(conflicts) > 0: - # solve conflicts - #TODO: justify: whom: which microbatch should apply shift - #TODO: justify: how: shift distance - solve(micros, conflicts) - else: - tidx += 1 - break - span = max([micro.nsteps for micro in micros]) - print(f'find plan: {span} steps') - for mid, micro in enumerate(micros): - print(f'microbatch-{mid}:') - print(micro) - - -if __name__ == '__main__': - num_microbatch = 4 - num_stage = 4 - num_device = 4 - search(num_microbatch, num_stage, num_device) From a93044e5797dfc98c4997562e0708a6eb2a42f53 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 6 May 2022 17:08:53 +0800 Subject: [PATCH 0792/1892] add 1F1b example --- cube/search/composer.py | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 3b9a40e0..024a645f 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -166,7 +166,7 @@ def __init__(self, micros: List[MicroPlan]): micro.expand_to(max_steps) plans = tuple(micro.execplan for micro in micros) schedule = np.sum(np.stack(plans, axis=-1), axis=-1) - if len(np.where(self.schedule > 1)[0]) > 0: + if len(np.where(schedule > 1)[0]) > 0: raise ValueError("micro plans are not composable") # cut off redundant steps for idx in range(schedule.shape[1]): @@ -275,9 +275,35 @@ def uniform_staging(ndevs: int, nmicros=4): micro.add_dependency(blocks) micros.append(micro) return micros + + def compose_1F1B(ndevs, nmicros): + # premise + micros = uniform_staging(ndevs, nmicros) + print('premise micros:') + for micro in micros: + print(micro) + # shift + for mid, micro in enumerate(micros): + block = micro.block(0, 0) + for _ in range(2 * mid): + micro.shift(block) + print('shifted micros:') + for micro in micros: + print(micro) + assert SchedulePlan.composable(micros) + schedule = SchedulePlan(micros) + print(f'schedule (step={schedule.nsteps}):') + print(schedule) + return schedule + ndevs = 4 - micros = Composer.premise(uniform_staging, ndevs) - for mid, micro in enumerate(micros): - print(f'Microbatch #{mid}:') - print(micro) + nmicros = 8 + + # for test + # micros = Composer.premise(uniform_staging, ndevs) + # for mid, micro in enumerate(micros): + # print(f'Microbatch #{mid}:') + # print(micro) + + compose_1F1B(ndevs, nmicros) From 265e41b006a9df31d0c00dc1d05373a4d86d5bf0 Mon Sep 17 00:00:00 2001 From: lynex Date: Fri, 6 May 2022 21:02:24 +0800 Subject: [PATCH 0793/1892] add graph trans --- .../playground/dag/graph_manipulation.py | 41 +++++++-------- handcraft/playground/dag/graph_trans.py | 51 +++++++++++++++++++ 2 files changed, 69 insertions(+), 23 deletions(-) create mode 100644 handcraft/playground/dag/graph_trans.py diff --git a/handcraft/playground/dag/graph_manipulation.py b/handcraft/playground/dag/graph_manipulation.py index b7d931fc..ea076a73 100644 --- a/handcraft/playground/dag/graph_manipulation.py +++ b/handcraft/playground/dag/graph_manipulation.py @@ -295,8 +295,9 @@ class Config: num: int -class Device: - pass +class Device(int): + def __init__(self, x, base=10): + super().__init__(x, base) class Parallelizer: @@ -356,17 +357,12 @@ def set_affinity(producer_node, consumer_node): from collections import namedtuple -def idxzip(list: []): +def index_enumerate(list: []): Entry = namedtuple('Entry', ['idx', 'item']) # return [{'idx': i, 'item': x} for i, x in enumerate(list)] return [Entry(i, x) for i, x in enumerate(list)] -def xmap(func, iterables): - if list(map(func, iterables)) is None: - print('xmap ERROR') - - ### TODO how about Tx in flexflow GSPMD etc.? # traditional data-parallel process: @@ -386,13 +382,11 @@ def run(self, g: Graph, config: Config) -> Graph: for node in g.nodes: if isinstance(node, (NodeData, NodeFwd, NodeBwd)): nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-split - xmap(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] elif isinstance(node, (NodeOpt)): - nodes = trans(node, node.op.algo.replica, config.num) #replicated optimizers - xmap(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + nodes = trans(node, node.op.algo.replica, config.num) # replicated optimizers + [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] else: - print(node) - print(type(node)) assert False # ---------------- @@ -409,10 +403,10 @@ def run(self, g: Graph, config: Config) -> Graph: for node in g.nodes: if isinstance(node, (NodeData, NodeFwd, NodeBwd)): nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-split - map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] elif isinstance(node, (NodeOpt)): - nodes = trans(node, node.op.algo.split, config.num) #split optimizers - map(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + nodes = trans(node, node.op.algo.split, config.num) # split optimizers + [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] else: assert False # ---------------- @@ -466,7 +460,7 @@ def run(self, g: Graph, config: Config) -> Graph: if isinstance(node, (NodeData, NodeFwd, NodeBwd)): nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit sched_t(nodes) # sequential order - xmap(lambda x: sched_s(node=x, dev=device), nodes) # assign same stage device + [sched_s(node=x, dev=device) for x in nodes] # assign same stage device elif isinstance(node, (NodeOpt)): sched_s(node, device) else: @@ -486,10 +480,10 @@ def run(self, g: Graph, config: Config) -> Graph: for node in g.nodes: if isinstance(node, (NodeFwd, NodeBwd, NodeOpt)): nodes = trans(node, node.op.algo.tensor_split, config.num) # by tensor-dim-slit - xmap(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] elif isinstance(node, (NodeData)): nodes = trans(node, node.op.algo.replica, config.num) - xmap(lambda x: sched_s(node=x.item, dev=x.idx), idxzip(nodes)) + [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] else: assert False # ---------------- @@ -525,9 +519,9 @@ def run(self, g: Graph, config: Config) -> Graph: consumers = find_consumers(g, origin_fwd.outputs[0]) for consumer in consumers: if isinstance(consumer, NodeFwd): - set_affinity(origin_fwd, consumer) # break dependencies op0.fwd -> op1.fwd; op0.fwd' -> op0.bwd + set_affinity(origin_fwd, consumer) # break dependencies op0.fwd -> op1.fwd; else: - set_affinity(recompute_fwd, consumer) # break dependencies op0.fwd -> op1.fwd; op0.fwd' -> op0.bwd + set_affinity(recompute_fwd, consumer) # break dependencies op0.fwd' -> op0.bwd producers = list(filter(lambda x: isinstance(x, NodeBwd), find_producers(g, consumer.inputs[0]))) for producer in producers: sched_t_pair(producer, recompute_fwd) @@ -538,14 +532,15 @@ def run(self, g: Graph, config: Config) -> Graph: class ActivationSwap(Parallelizer): def run(self, g: Graph, config: Config) -> Graph: + #TODO activate consuming NodeBwd -> Identity(CPU) + NodeBwd pass -# para = DataParallelParallelizer() +para = DataParallelParallelizer() # para = DataParallelZeROParallelizer() # para = GradientAccumulationParallelizer() # para = GPipeParallelizer() # para = TensorParallelParallelizer() -para = Recompute() +# para = Recompute() config = Config() diff --git a/handcraft/playground/dag/graph_trans.py b/handcraft/playground/dag/graph_trans.py new file mode 100644 index 00000000..320dade0 --- /dev/null +++ b/handcraft/playground/dag/graph_trans.py @@ -0,0 +1,51 @@ +from graph_manipulation import * + + + +# general transformations +''' +Op := I -> Op (pre-identity) +Op := Op -> I (post-identity) +Op := Op, Op (replicate) +''' + +# batch transformation (due to operator sample-wise) +''' +DataLoader + split (output)activation + +OperatorForward + split (input)activation + replica (input)weight + split (output)activation* + +OperatorBackward-(activation's gradient) + split (input)d-activation* + replica (input)weight + split (output)d-activation + +OperatorBackward-(weight's gradient) + split (input)d-activation* + split (input)activation + value-split (to-reduce) (output)d-weight +''' + +# non-batch transformation (operator semantic aware) +''' +elementwise operators (including optimizers) + arbitrary same split on inputs and outputs + +MatMul [M, K]*[K, N] => [M, N] + 1. split M or N (e.g., cases with M or N as batch-dim) + 2. split reducing dim K: [M, K/2]*[K/2, N] => value-split [M, N] + +Conv2D + 1. split (input) image H, W => halo exchange then local Conv2D, split (output) image + 2. split (input) filter out-channel-dim => Conv2D on replicated image with partial filter, value-split (output) image + +(more cases) ... +''' + + +def trans(node, )->Node: + pass \ No newline at end of file From 5b64850b08ba4ed236594fe5f445037e15ab41bc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 7 May 2022 10:23:04 +0800 Subject: [PATCH 0794/1892] add visualization --- cube/search/composer.py | 66 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 024a645f..3bcde5f9 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -237,7 +237,66 @@ def __repr__(self): def visualize(self, outfile=None): import matplotlib.pyplot as plt from matplotlib.patches import Rectangle - plt.close('all') + from matplotlib.ticker import AutoMinorLocator + plt.close('all') + fig, ax = plt.subplots(figsize=(4 * self.nsteps // self.ndevs, 4)) + renderer = fig.canvas.get_renderer() + + # xaxis + ax.set_xlim((0, self.nsteps)) + plt.xticks( + ticks=np.arange(0.5, self.nsteps+0.5, 1.0, dtype=float), + labels=np.arange(1, self.nsteps+1, 1, dtype=int) + ) + minor_locator = AutoMinorLocator(2) + plt.gca().xaxis.set_minor_locator(minor_locator) + ax.xaxis.grid(which='minor', linestyle='--') + # yaxis + ax.set_ylim((0.5, self.ndevs+0.5)) + plt.yticks(np.arange(1, self.ndevs+1, 1, dtype=int)) + ax.invert_yaxis() + + fontsize = [40] + txts = list() + def draw_block(block: Block, fontsize): + color = '#4472C4' if block.type == Block.BType.FW else '#ED7D31' + dev, step = block.position + rec = Rectangle((step, dev+0.5), 1, 1, color=color, ec='black', lw=1.5) + ax.add_artist(rec) + rx, ry = rec.get_xy() + cx = rx + rec.get_width() / 2.0 + cy = ry + rec.get_height() / 2.0 + anno = str(block.mid) + txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') + rbox = rec.get_window_extent(renderer) + for fs in range(fontsize[0], 1, -2): + txt.set_fontsize(fs) + tbox = txt.get_window_extent(renderer) + if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: + break + fontsize[0] = min(fontsize[0], fs) + txts.append(txt) + + for dev in range(self.ndevs): + for step in range(self.nsteps): + block = self.block(dev, step) + if block is not None: + draw_block(block, fontsize) + # set fontsize to same + fontsize = fontsize[0] + for txt in txts: + txt.set_fontsize(fontsize) + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize) + plt.xlabel('Time Step', fontsize=fontsize) + plt.ylabel('Device', fontsize=fontsize) + plt.tight_layout() + if outfile: + plt.savefig(outfile) + else: + plt.show() class Composer: @@ -298,7 +357,7 @@ def compose_1F1B(ndevs, nmicros): ndevs = 4 - nmicros = 8 + nmicros = 4 # for test # micros = Composer.premise(uniform_staging, ndevs) @@ -306,4 +365,5 @@ def compose_1F1B(ndevs, nmicros): # print(f'Microbatch #{mid}:') # print(micro) - compose_1F1B(ndevs, nmicros) + schedule = compose_1F1B(ndevs, nmicros) + schedule.visualize('out.png') From ebcc88fe1b77033f5133eef01f4ea89247197fd3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 7 May 2022 10:46:07 +0800 Subject: [PATCH 0795/1892] add base class --- cube/search/composer.py | 274 +++++++++++++++++++--------------------- 1 file changed, 132 insertions(+), 142 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 3bcde5f9..cf78ae4a 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -47,37 +47,131 @@ def __repr__(self): return f'f{self.mid}' if self.type == Block.BType.FW else f'b{self.mid}' -class MicroPlan: +class PlanBase: - def __init__(self, mid: int, ndevs: int): - """ - Create an empty microbatch execution plan - - mid: microbatch id - ndevs: number of devices - """ - self.mid = mid + def __init__(self, ndevs: int): self.blocks: Dict[Tuple[int, int], Block] = dict() - self.execplan = np.zeros((ndevs, ndevs * 2), dtype=int) + self.plan = np.zeros((ndevs, ndevs * 2), dtype=int) @property def ndevs(self): - return self.execplan.shape[0] + return self.plan.shape[0] @property def nsteps(self): - return self.execplan.shape[1] - - def expand_to(self, nsteps: int): - if self.nsteps < nsteps: - extend = nsteps - self.nsteps - self.execplan = np.pad(self.execplan, ((0,0),(0,extend))) + return self.plan.shape[1] def block(self, dev: int, step: int): if (dev, step) not in self.blocks: return None return self.blocks[(dev, step)] + def squeeze(self): + """ + remove redundant steps + """ + execflag = np.sum(self.plan, axis=1) + for idx in range(self.nsteps): + if execflag[-idx-1] != 0: + break + self.plan = self.plan[:, :-idx] if idx > 0 else self.plan + + def __repr__(self): + namelen = 2 + dscp = '' + for dev in range(self.ndevs): + for step in range(self.nsteps): + block = self.block(dev, step) + if block is None: + dscp += '-' * namelen + ' ' + else: + # TODO: 2 replace to namelen + dscp += '{: <2}'.format(repr(block)) + ' ' + dscp += '\n' + return dscp + + def visualize(self, outfile=None): + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + from matplotlib.ticker import AutoMinorLocator + plt.close('all') + fig, ax = plt.subplots(figsize=(4 * self.nsteps // self.ndevs, 4)) + renderer = fig.canvas.get_renderer() + + # xaxis + ax.set_xlim((0, self.nsteps)) + plt.xticks( + ticks=np.arange(0.5, self.nsteps+0.5, 1.0, dtype=float), + labels=np.arange(1, self.nsteps+1, 1, dtype=int) + ) + minor_locator = AutoMinorLocator(2) + plt.gca().xaxis.set_minor_locator(minor_locator) + ax.xaxis.grid(which='minor', linestyle='--') + # yaxis + ax.set_ylim((0.5, self.ndevs+0.5)) + plt.yticks(np.arange(1, self.ndevs+1, 1, dtype=int)) + ax.invert_yaxis() + + fontsize = [40] + txts = list() + def draw_block(block: Block, fontsize): + color = '#4472C4' if block.type == Block.BType.FW else '#ED7D31' + dev, step = block.position + rec = Rectangle((step, dev+0.5), 1, 1, color=color, ec='black', lw=1.5) + ax.add_artist(rec) + rx, ry = rec.get_xy() + cx = rx + rec.get_width() / 2.0 + cy = ry + rec.get_height() / 2.0 + anno = str(block.mid) + txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') + rbox = rec.get_window_extent(renderer) + for fs in range(fontsize[0], 1, -2): + txt.set_fontsize(fs) + tbox = txt.get_window_extent(renderer) + if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: + break + fontsize[0] = min(fontsize[0], fs) + txts.append(txt) + + for dev in range(self.ndevs): + for step in range(self.nsteps): + block = self.block(dev, step) + if block is not None: + draw_block(block, fontsize) + # set fontsize to same + fontsize = fontsize[0] + for txt in txts: + txt.set_fontsize(fontsize) + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize) + plt.xlabel('Time Step', fontsize=fontsize) + plt.ylabel('Device', fontsize=fontsize) + plt.tight_layout() + if outfile: + plt.savefig(outfile) + else: + plt.show() + + +class MicroPlan(PlanBase): + + def __init__(self, mid: int, ndevs: int): + """ + Create an empty microbatch execution plan + + mid: microbatch id + ndevs: number of devices + """ + super().__init__(ndevs) + self.mid = mid + + def expand_to(self, nsteps: int): + if self.nsteps < nsteps: + extend = nsteps - self.nsteps + self.plan = np.pad(self.plan, ((0,0),(0,extend))) + def add_block(self, pos: Tuple[int, int], btype: Block.BType) -> Block: """ Add a execution block @@ -87,10 +181,10 @@ def add_block(self, pos: Tuple[int, int], btype: Block.BType) -> Block: raise ValueError("device out of scope") if step >= self.nsteps: self.expand_to(step + 1) - if self.execplan[dev, step] != 0: + if self.plan[dev, step] != 0: raise ValueError(f"Postition {pos} already has blocks") block = Block(self.mid, pos, btype) - self.execplan[dev, step] += 1 + self.plan[dev, step] += 1 self.blocks[(dev, step)] = block return block @@ -113,10 +207,10 @@ def shift(self, block: Block): for after_block in block.after: if step + 1 == after_block.position[1]: self.shift(after_block) - self.execplan[dev, step] = 0 + self.plan[dev, step] = 0 if step + 1 >= self.nsteps: self.expand_to(self.nsteps + 1) - self.execplan[dev, step+1] = 1 + self.plan[dev, step+1] = 1 # update block and self.blocks block.position = (dev, step+1) del self.blocks[(dev, step)] @@ -130,8 +224,8 @@ def unshift(self, block: Block): if step == 0: raise ValueError("unshift a block with step = 0") # shift back - self.execplan[dev, step] = 0 - self.execplan[dev, step-1] = 1 + self.plan[dev, step] = 0 + self.plan[dev, step-1] = 1 block.position = (dev, step-1) del self.blocks[(dev, step)] self.blocks[(dev, step-1)] = 1 @@ -140,64 +234,39 @@ def unshift(self, block: Block): if step + 1 == after_block.position[1]: self.unshift(after_block) - def __repr__(self): - namelen = 2 + self.mid // 10 - dscp = '' - for dev in range(self.ndevs): - for step in range(self.nsteps): - block = self.block(dev, step) - if block is None: - dscp += '-' * namelen + ' ' - else: - # TODO: 2 replace to namelen - dscp += '{: <2}'.format(repr(block)) + ' ' - dscp += '\n' - return dscp - -class SchedulePlan: +class SchedulePlan(PlanBase): def __init__(self, micros: List[MicroPlan]): + ndevs = [micro.ndevs for micro in micros] + if len(set(ndevs)) != 1: + raise ValueError(f"Device number not same: {ndevs}") + ndevs = ndevs[0] + + super().__init__(ndevs) self.micros = micros - # get schedules + # get schedule plans max_steps = max(micro.nsteps for micro in micros) for micro in micros: micro.expand_to(max_steps) - plans = tuple(micro.execplan for micro in micros) + plans = tuple(micro.plan for micro in micros) schedule = np.sum(np.stack(plans, axis=-1), axis=-1) if len(np.where(schedule > 1)[0]) > 0: raise ValueError("micro plans are not composable") - # cut off redundant steps - for idx in range(schedule.shape[1]): - if np.sum(schedule[:, -idx-1]) != 0: - break - self.schedule = schedule[:, :-idx] if idx > 0 else schedule + self.plan = schedule + self.squeeze() - # get blocks - self.blocks = dict() + # set blocks for micro in micros: self.blocks.update(micro.blocks) - @property - def ndevs(self): - return self.schedule.shape[0] - - @property - def nsteps(self): - return self.schedule.shape[1] - - def block(self, dev: int, step: int): - if (dev, step) not in self.blocks: - return None - return self.blocks[(dev, step)] - @staticmethod def composable(micros: List[MicroPlan]) -> bool: max_steps = max(micro.nsteps for micro in micros) for micro in micros: micro.expand_to(max_steps) - plans = tuple(micro.execplan for micro in micros) + plans = tuple(micro.plan for micro in micros) schedule = np.sum(np.stack(plans, axis=-1), axis=-1) devids = np.where(schedule > 1)[0] return len(devids) == 0 @@ -207,7 +276,7 @@ def conflict(micros: List[MicroPlan], step: int) -> bool: max_steps = max(micro.nsteps for micro in micros) for micro in micros: micro.expand_to(max_steps) - plans = tuple(micro.execplan[:,step] for micro in micros) + plans = tuple(micro.plan[:,step] for micro in micros) schedule = np.sum(np.stack(plans, axis=-1), axis=-1) cmicros = [] cblocks = [] @@ -219,85 +288,6 @@ def conflict(micros: List[MicroPlan], step: int) -> bool: cblocks.append(micro.block[dev, step]) return cmicros, cblocks - def __repr__(self): - nmicros = len(self.micros) - namelen = 2 + nmicros // 10 - dscp = '' - for dev in range(self.ndevs): - for step in range(self.nsteps): - block = self.block(dev, step) - if block is None: - dscp += '-' * namelen + ' ' - else: - # TODO: 2 replace to namelen - dscp += '{: <2}'.format(repr(block)) + ' ' - dscp += '\n' - return dscp - - def visualize(self, outfile=None): - import matplotlib.pyplot as plt - from matplotlib.patches import Rectangle - from matplotlib.ticker import AutoMinorLocator - plt.close('all') - fig, ax = plt.subplots(figsize=(4 * self.nsteps // self.ndevs, 4)) - renderer = fig.canvas.get_renderer() - - # xaxis - ax.set_xlim((0, self.nsteps)) - plt.xticks( - ticks=np.arange(0.5, self.nsteps+0.5, 1.0, dtype=float), - labels=np.arange(1, self.nsteps+1, 1, dtype=int) - ) - minor_locator = AutoMinorLocator(2) - plt.gca().xaxis.set_minor_locator(minor_locator) - ax.xaxis.grid(which='minor', linestyle='--') - # yaxis - ax.set_ylim((0.5, self.ndevs+0.5)) - plt.yticks(np.arange(1, self.ndevs+1, 1, dtype=int)) - ax.invert_yaxis() - - fontsize = [40] - txts = list() - def draw_block(block: Block, fontsize): - color = '#4472C4' if block.type == Block.BType.FW else '#ED7D31' - dev, step = block.position - rec = Rectangle((step, dev+0.5), 1, 1, color=color, ec='black', lw=1.5) - ax.add_artist(rec) - rx, ry = rec.get_xy() - cx = rx + rec.get_width() / 2.0 - cy = ry + rec.get_height() / 2.0 - anno = str(block.mid) - txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') - rbox = rec.get_window_extent(renderer) - for fs in range(fontsize[0], 1, -2): - txt.set_fontsize(fs) - tbox = txt.get_window_extent(renderer) - if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: - break - fontsize[0] = min(fontsize[0], fs) - txts.append(txt) - - for dev in range(self.ndevs): - for step in range(self.nsteps): - block = self.block(dev, step) - if block is not None: - draw_block(block, fontsize) - # set fontsize to same - fontsize = fontsize[0] - for txt in txts: - txt.set_fontsize(fontsize) - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(fontsize) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(fontsize) - plt.xlabel('Time Step', fontsize=fontsize) - plt.ylabel('Device', fontsize=fontsize) - plt.tight_layout() - if outfile: - plt.savefig(outfile) - else: - plt.show() - class Composer: @@ -357,7 +347,7 @@ def compose_1F1B(ndevs, nmicros): ndevs = 4 - nmicros = 4 + nmicros = 8 # for test # micros = Composer.premise(uniform_staging, ndevs) From 129d138884f63cbabad3ed041f546d2673d7df00 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 7 May 2022 17:18:04 +0800 Subject: [PATCH 0796/1892] prepare for search --- cube/search/composer.py | 181 +++++++++++++++++++++++++++------------- 1 file changed, 121 insertions(+), 60 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index cf78ae4a..635d4af5 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -2,7 +2,7 @@ Abstraction layer for microb-batch execution plan merge. """ -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import numpy as np from enum import Enum @@ -16,24 +16,13 @@ class BType(Enum): FW = 'forward' BW = 'backward' - def __init__(self, mid: int, pos: Tuple[int, int], btype: BType): + def __init__(self, mid: int, btype: BType): self.mid: int = mid self.type = btype - self._position = tuple(pos) # dependency track self.before: List[Block] = list() self.after: List[Block] = list() - @property - def position(self): - return self._position - - @position.setter - def position(self, pos: Tuple[int, int]): - if len(pos) != 2: - raise ValueError("Expected positition to be Tuple[int, int]") - self._position = pos - @staticmethod def add_dependency(before, after): if not (isinstance(before, Block) and isinstance(after, Block)): @@ -51,6 +40,7 @@ class PlanBase: def __init__(self, ndevs: int): self.blocks: Dict[Tuple[int, int], Block] = dict() + self.positions: Dict[int, Tuple[int, int]] = dict() self.plan = np.zeros((ndevs, ndevs * 2), dtype=int) @property @@ -62,9 +52,22 @@ def nsteps(self): return self.plan.shape[1] def block(self, dev: int, step: int): + """ + Get block given a position + """ if (dev, step) not in self.blocks: return None return self.blocks[(dev, step)] + + def position(self, block: Block) -> Optional[Tuple[int, int]]: + """ + Get (dev, step) position given a block. + If block not in this plan, return None + """ + if id(block) in self.positions: + return self.positions[id(block)] + else: + return None def squeeze(self): """ @@ -114,9 +117,9 @@ def visualize(self, outfile=None): fontsize = [40] txts = list() - def draw_block(block: Block, fontsize): + def draw_block(block: Block, position: Tuple[int, int], fontsize): color = '#4472C4' if block.type == Block.BType.FW else '#ED7D31' - dev, step = block.position + dev, step = position rec = Rectangle((step, dev+0.5), 1, 1, color=color, ec='black', lw=1.5) ax.add_artist(rec) rx, ry = rec.get_xy() @@ -137,7 +140,7 @@ def draw_block(block: Block, fontsize): for step in range(self.nsteps): block = self.block(dev, step) if block is not None: - draw_block(block, fontsize) + draw_block(block, self.position(block), fontsize) # set fontsize to same fontsize = fontsize[0] for txt in txts: @@ -183,9 +186,10 @@ def add_block(self, pos: Tuple[int, int], btype: Block.BType) -> Block: self.expand_to(step + 1) if self.plan[dev, step] != 0: raise ValueError(f"Postition {pos} already has blocks") - block = Block(self.mid, pos, btype) + block = Block(self.mid, btype) self.plan[dev, step] += 1 self.blocks[(dev, step)] = block + self.positions[id(block)] = (dev, step) return block def add_dependency(self, blocks: List[Block]): @@ -196,43 +200,55 @@ def add_dependency(self, blocks: List[Block]): for idx in range(len(blocks) - 1): Block.add_dependency(blocks[idx], blocks[idx+1]) - def shift(self, block: Block): + def copy(self): + """ + copy a micro plan + """ + micro = MicroPlan(self.ndevs, self.mid) + micro.plan = np.array(self.plan, copy=True) + micro.blocks.update(self.blocks) + micro.positions.update(self.positions) + return micro + + def shift(self, block: Block, inplace=True): """ - The primitive during search + The primitive: shift a block by pushing one step later """ + micro = self if inplace else self.copy() # check block in this plan - if block not in self.blocks.values(): + if block not in micro.blocks.values(): raise ValueError("Block not in this micro plan") - dev, step = block.position + dev, step = micro.position(block) for after_block in block.after: - if step + 1 == after_block.position[1]: - self.shift(after_block) - self.plan[dev, step] = 0 - if step + 1 >= self.nsteps: - self.expand_to(self.nsteps + 1) - self.plan[dev, step+1] = 1 - # update block and self.blocks - block.position = (dev, step+1) - del self.blocks[(dev, step)] - self.blocks[(dev, step+1)] = block - - def unshift(self, block: Block): + if step + 1 == micro.position(after_block)[1]: + micro.shift(after_block, inplace=True) + micro.plan[dev, step] = 0 + if step + 1 >= micro.nsteps: + micro.expand_to(micro.nsteps + 1) + micro.plan[dev, step+1] = 1 + # update blocks and positions + del micro.blocks[(dev, step)] + micro.blocks[(dev, step+1)] = block + micro.positions[id(block)] = (dev, step+1) + + def unshift(self, block: Block, inplace=True): """ reverse shift, for search only """ - dev, step = block.position + micro = self if inplace else self.copy() + dev, step = micro.position(block) if step == 0: raise ValueError("unshift a block with step = 0") # shift back - self.plan[dev, step] = 0 - self.plan[dev, step-1] = 1 - block.position = (dev, step-1) - del self.blocks[(dev, step)] - self.blocks[(dev, step-1)] = 1 + micro.plan[dev, step] = 0 + micro.plan[dev, step-1] = 1 + del micro.blocks[(dev, step)] + micro.blocks[(dev, step-1)] = 1 + micro.positions[id(block)] = (dev, step-1) # shift back shifted blocks for after_block in block.after: - if step + 1 == after_block.position[1]: - self.unshift(after_block) + if step + 1 == micro.position(after_block)[1]: + micro.unshift(after_block, inplace=True) class SchedulePlan(PlanBase): @@ -257,9 +273,10 @@ def __init__(self, micros: List[MicroPlan]): self.plan = schedule self.squeeze() - # set blocks + # set blocks and positions for micro in micros: self.blocks.update(micro.blocks) + self.positions.update(micro.positions) @staticmethod def composable(micros: List[MicroPlan]) -> bool: @@ -272,21 +289,25 @@ def composable(micros: List[MicroPlan]) -> bool: return len(devids) == 0 @staticmethod - def conflict(micros: List[MicroPlan], step: int) -> bool: + def conflict(micros: List[MicroPlan], step: int) -> Dict[int, List[Tuple[MicroPlan, Block]]]: + """ + Get conflict blocks at `step`. + Return the conflicted (MicroPlan, Block) grouped by device id + """ max_steps = max(micro.nsteps for micro in micros) for micro in micros: micro.expand_to(max_steps) plans = tuple(micro.plan[:,step] for micro in micros) schedule = np.sum(np.stack(plans, axis=-1), axis=-1) - cmicros = [] - cblocks = [] - devids, steps = np.where(schedule > 1) - for dev, step in zip(devids, steps): + conflicts = dict() + devids = np.where(schedule > 1)[0] + for dev in devids: + conflicts[dev] = [] for micro in micros: - if micro.block[dev, step] is not None: - cmicros.append(micro) - cblocks.append(micro.block[dev, step]) - return cmicros, cblocks + cblock = micro.block[dev, step] + if cblock is not None: + conflicts[dev].append((micro, cblock)) + return conflicts class Composer: @@ -296,20 +317,60 @@ def premise(fn, ndevs: int): micros = fn(ndevs) return micros + @staticmethod - def schedule(micros, step=0): - # DFS search + def bfs_schedule(micros: List[MicroPlan]): + step = 0 + prev_step_trace: List[List[Tuple[MicroPlan, Block]]] = [[]] + next_step_trace: List[List[Tuple[MicroPlan, Block]]] = [] while not SchedulePlan.composable(micros): - cmicros, cblocks = SchedulePlan.conflict(micros, step) - if len(cmicros) == 0: - step += 1 - else: - for micro, block in zip(cmicros, cblocks): + for trace in prev_step_trace: + # move accroding to trace + for micro, block in trace: micro.shift(block) - Composer.schedule(micros, step=step) + # get and solve conflicts + conflicts = SchedulePlan.conflict(micros, step) + new_trace = [] # TODO + search_devs = [] + for dev, microblocks in conflicts.items(): + cmicros = [micro for (micro, _) in microblocks] + cblocks = [block for (_, block) in microblocks] + if Composer.same_plans(cmicros, start_step=step): + for cmicro, cblock in zip(cmicros[1:], cblocks[1:]): + trace.append((cmicro, cblock)) + else: + search_devs.append(dev) + if len(search_devs) == 0: + next_step_trace.append(trace) + else: + # TODO + pass + # move back according to trace + for micro, block in trace[::-1]: micro.unshift(block) - print(f'search a plan with step {step}') + if len(conflicts) == 0: + continue + + @staticmethod + def same_plans(micros: List[MicroPlan], start_step: int = 0) -> bool: + Composer.to_same_step(micros) + plans = [micro.plan[:,start_step:] for micro in micros] + plan = plans[0] + for other in plans[1:]: + if not np.array_equal(plan, other): + return False + return True + + @staticmethod + def to_same_step(micros: List[MicroPlan]): + """ + extend micros to same step + """ + nsteps = max(micro.nsteps for micro in micros) + for micro in micros: + micro.expand_to(nsteps) + return micros if __name__ == '__main__': From e84416dbaeb08e3345948077d3ed0d57a182b409 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 7 May 2022 19:53:29 +0800 Subject: [PATCH 0797/1892] add bfs seearch --- cube/search/composer.py | 125 ++++++++++++++++++++++++++++++---------- 1 file changed, 94 insertions(+), 31 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 635d4af5..4f453960 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -2,9 +2,10 @@ Abstraction layer for microb-batch execution plan merge. """ -from typing import Dict, List, Tuple, Optional +from typing import Any, Dict, List, Tuple, Optional import numpy as np from enum import Enum +import time class Block: @@ -204,7 +205,7 @@ def copy(self): """ copy a micro plan """ - micro = MicroPlan(self.ndevs, self.mid) + micro = MicroPlan(self.mid, self.ndevs) micro.plan = np.array(self.plan, copy=True) micro.blocks.update(self.blocks) micro.positions.update(self.positions) @@ -230,6 +231,7 @@ def shift(self, block: Block, inplace=True): del micro.blocks[(dev, step)] micro.blocks[(dev, step+1)] = block micro.positions[id(block)] = (dev, step+1) + return micro def unshift(self, block: Block, inplace=True): """ @@ -249,6 +251,7 @@ def unshift(self, block: Block, inplace=True): for after_block in block.after: if step + 1 == micro.position(after_block)[1]: micro.unshift(after_block, inplace=True) + return micro class SchedulePlan(PlanBase): @@ -304,7 +307,7 @@ def conflict(micros: List[MicroPlan], step: int) -> Dict[int, List[Tuple[MicroPl for dev in devids: conflicts[dev] = [] for micro in micros: - cblock = micro.block[dev, step] + cblock = micro.block(dev, step) if cblock is not None: conflicts[dev].append((micro, cblock)) return conflicts @@ -313,43 +316,68 @@ def conflict(micros: List[MicroPlan], step: int) -> Dict[int, List[Tuple[MicroPl class Composer: @staticmethod - def premise(fn, ndevs: int): - micros = fn(ndevs) + def premise(fn, ndevs: int, nmicros: int): + micros = fn(ndevs, nmicros) return micros @staticmethod def bfs_schedule(micros: List[MicroPlan]): + micros.sort(key=lambda m: m.mid) step = 0 - prev_step_trace: List[List[Tuple[MicroPlan, Block]]] = [[]] - next_step_trace: List[List[Tuple[MicroPlan, Block]]] = [] - while not SchedulePlan.composable(micros): - for trace in prev_step_trace: - # move accroding to trace - for micro, block in trace: - micro.shift(block) + prev: List[List[MicroPlan]] = [micros] + next: List[List[MicroPlan]] = [] + output: List[List[MicroPlan]] = [] + while True: + find = False + print(f'solving step {step}, candidates {len(prev)}...') + for micros in prev: # get and solve conflicts conflicts = SchedulePlan.conflict(micros, step) - new_trace = [] # TODO + # input(f'conflicts: dev: {list(conflicts.keys())}, mids: {[[conf[0].mid for conf in c] for c in conflicts.values()]} | >>>') search_devs = [] + # direct shift on symetrics for dev, microblocks in conflicts.items(): cmicros = [micro for (micro, _) in microblocks] cblocks = [block for (_, block) in microblocks] if Composer.same_plans(cmicros, start_step=step): for cmicro, cblock in zip(cmicros[1:], cblocks[1:]): - trace.append((cmicro, cblock)) + # print(f'shift(micro{cmicro.mid}, block<{cmicro.position(cblock)}>)') + cmicro.shift(cblock, inplace=True) else: search_devs.append(dev) if len(search_devs) == 0: - next_step_trace.append(trace) + micros = [micro.copy() for micro in micros] + next.append(micros) + if SchedulePlan.composable(micros): + output.append(micros) + find = True + # search space using different shift choice else: - # TODO - pass - # move back according to trace - for micro, block in trace[::-1]: - micro.unshift(block) - if len(conflicts) == 0: - continue + slots = [[micro.mid for (micro, _) in conflicts[dev]] for dev in search_devs] + # input(f'search devs: {search_devs}, slots: {slots} | >>>') + for keep_mids in Composer.otho_iter(slots): + shifted_micros = [micro.copy() for micro in micros] + shift_mids = [ + [mid for mid in slot if mid != kmid] for kmid, slot in zip(keep_mids, slots) + ] + for dev, mids in zip(search_devs, shift_mids): + for mid in mids: + block = micros[mid].block(dev, step) + # print(f'shift(micro{mid}, block<{(dev, step)}>)') + shifted_micros[mid] = micros[mid].shift(block, inplace=False) + next.append(shifted_micros) + if SchedulePlan.composable(shifted_micros): + output.append(shifted_micros) + shifted_micros=None + find = True + prev, next = next, [] + if find: + prev = output + break + step += 1 + schedules = [SchedulePlan(micros) for micros in prev] + return schedules @staticmethod @@ -372,6 +400,27 @@ def to_same_step(micros: List[MicroPlan]): micro.expand_to(nsteps) return micros + @staticmethod + def otho_iter(slots: List[List[Any]]): + """ + othogonal pickers + + item for each slot can be randomly selected + """ + if len(slots) == 0: + yield [] + return + slot = slots[0] + if len(slots) == 1: + for item in slot: + yield [item] + else: + slots = slots[1:] + for item in slot: + for res in Composer.otho_iter(slots): + yield [item] + res + return + if __name__ == '__main__': @@ -405,16 +454,30 @@ def compose_1F1B(ndevs, nmicros): print(f'schedule (step={schedule.nsteps}):') print(schedule) return schedule + + def search(ndevs, nmicros): + # premise + micros = Composer.premise(uniform_staging, ndevs, nmicros) + + # search shift + tic = time.time() + schedules = Composer.bfs_schedule(micros) + toc = time.time() + print('search done. time {:.2f}s'.format(toc - tic)) + + + steps = set(schedule.nsteps for schedule in schedules) + assert len(steps) == 1, f"got un-consistent step set: {steps}" + nsteps = list(steps)[0] + print(f'find {len(schedules)} step-optimal schedules of step: {nsteps}') + for idx, schedule in enumerate(schedules): + print(f'Schedule #{idx+1}:') + print(schedule) ndevs = 4 - nmicros = 8 - - # for test - # micros = Composer.premise(uniform_staging, ndevs) - # for mid, micro in enumerate(micros): - # print(f'Microbatch #{mid}:') - # print(micro) + nmicros = 4 - schedule = compose_1F1B(ndevs, nmicros) - schedule.visualize('out.png') + # schedule = compose_1F1B(ndevs, nmicros) + # schedule.visualize('out.png') + search(ndevs, nmicros) From 81f1d8d67fdc7e2dd6d412473a47605e3a210bb1 Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 9 May 2022 09:45:04 +0800 Subject: [PATCH 0798/1892] nit --- handcraft/playground/dag/data_parallel_raw.py | 89 +++++++++++++++++++ .../playground/dag/graph_manipulation.py | 12 +-- handcraft/playground/dag/graph_trans.py | 9 +- 3 files changed, 97 insertions(+), 13 deletions(-) create mode 100644 handcraft/playground/dag/data_parallel_raw.py diff --git a/handcraft/playground/dag/data_parallel_raw.py b/handcraft/playground/dag/data_parallel_raw.py new file mode 100644 index 00000000..869d9de5 --- /dev/null +++ b/handcraft/playground/dag/data_parallel_raw.py @@ -0,0 +1,89 @@ +from graph_manipulation import * + +''' +[dataflow graph]: the logic of one training iteration +input: samples, weight tensors, optimizer state tensors +output: (updated) weight tensors, (updated) optimizer state tensors + +(assumption) a DL model description compatible with different batch-size + ref 1 ONNX: https://github.com/onnx/onnx/issues/2182 + ref 2 ... + +''' + +''' +graph manipulation as data-parallel +option 1: deep copy graph and manually decide adjustment for each node/tensor +option 2: manually decide manipulation for each node/tensor +option 3: using node-role, e.g., DataNode/Fwd/Bwd split; Optimizer replicate; weight gradient all-reduce before apply +option 4: using tensor info. e.g., tensors with +''' +def data_parallel_raw(g: Graph, dev_num: int, option: int): + print(g) + + if option == 1: #per node manipulation following oracle + def magic_func(node) -> bool: + pass + + # 1. multiply operators + for node in g.nodes: + if magic_func(node) == 'split op': + new_nodes = [] + for i in range(dev_num): + new_node_inputs = [] + for ts in node.inputs: + if magic_func(ts) == 'split tensor': + new_node_inputs.append(split(ts, dev_num)) + elif magic_func(ts) == 'replicate tensor': + new_node_inputs.append(ts) + #TODO connect split tensor + new_nodes.append(Node(type=node.type, inputs=new_node_inputs)) #insert new node + elif magic_func(node) == 'replicate op': + new_nodes = [node] * dev_num + + g.replace(node, new_nodes) # TODO connect + + #2. inserting gradient averaging + for node in g.nodes: + new_allreduce_node = None + gradient = None + for ts in node.inputs: + if magic_func(ts): + new_allreduce_node = Node(type=allreduce, inputs=ts) + gradient = ts + break + new_node = Node(node.type, node.inputs - gradient + new_allreduce_node.output) + g.replace(node, [new_allreduce_node, new_node]) + + elif option == 2: #replicate entire graph and adjust + def magic_func(node) -> bool: + pass + + # deep copy graph + graphs = [g.deepcopy() for i in range(dev_num)] + # reset batch size for each new graph, leveraging resizable DFG description (<-assumption) + for graph in graphs: + graph.batch_size = g.batch_size // dev_num + # inserting gradient averaging + for graph in graphs: + for node in graph.nodes: + new_allreduce_node = None + gradient = None + for ts in node.inputs: + if magic_func(ts): + new_allreduce_node = Node(type=allreduce, inputs=ts) + gradient = ts + break + new_node = Node(type=node.type, inputs=node.inputs - gradient + new_allreduce_node.output) + graph.replace(node, [new_allreduce_node, new_node]) + elif option == 3: #node role based manipulation + pass + + elif option == 4: #tensor dimention info based manipulation + pass + + + + + +data_parallel_raw(graph) \ No newline at end of file diff --git a/handcraft/playground/dag/graph_manipulation.py b/handcraft/playground/dag/graph_manipulation.py index ea076a73..397b3657 100644 --- a/handcraft/playground/dag/graph_manipulation.py +++ b/handcraft/playground/dag/graph_manipulation.py @@ -284,7 +284,7 @@ def __str__(self): graph = Graph(create_sample=True) -print('graph = \n{}'.format(graph)) +# print('graph = \n{}'.format(graph)) global_new_graph = Graph() # print('nodeList[{}] = \n{}'.format(len(nodeList), nodeList)) @@ -543,8 +543,8 @@ def run(self, g: Graph, config: Config) -> Graph: # para = Recompute() -config = Config() -config.num = 2 -config.stages = 2 -global_new_graph = para.run(graph, config) -print('new_graph = \n{}'.format(global_new_graph)) +# config = Config() +# config.num = 2 +# config.stages = 2 +# global_new_graph = para.run(graph, config) +# print('new_graph = \n{}'.format(global_new_graph)) diff --git a/handcraft/playground/dag/graph_trans.py b/handcraft/playground/dag/graph_trans.py index 320dade0..4134acec 100644 --- a/handcraft/playground/dag/graph_trans.py +++ b/handcraft/playground/dag/graph_trans.py @@ -1,7 +1,3 @@ -from graph_manipulation import * - - - # general transformations ''' Op := I -> Op (pre-identity) @@ -9,7 +5,7 @@ Op := Op, Op (replicate) ''' -# batch transformation (due to operator sample-wise) +# batch transformation (due to DL operators are sample-wise) ''' DataLoader split (output)activation @@ -47,5 +43,4 @@ ''' -def trans(node, )->Node: - pass \ No newline at end of file + From ccc98fb0ce2216774e8b94c732784fc90e2688b4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 9 May 2022 12:48:48 +0800 Subject: [PATCH 0799/1892] add pruning techs for same microbatches --- cube/search/composer.py | 75 +++++++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 21 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 4f453960..330db261 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -324,12 +324,13 @@ def premise(fn, ndevs: int, nmicros: int): @staticmethod def bfs_schedule(micros: List[MicroPlan]): micros.sort(key=lambda m: m.mid) + same_micros = Composer.same_plans(micros, start_step=0) step = 0 + opt_step = sum(micro.nsteps for micro in micros) # initial prev: List[List[MicroPlan]] = [micros] next: List[List[MicroPlan]] = [] - output: List[List[MicroPlan]] = [] - while True: - find = False + schedules: List[SchedulePlan] = [] + while step <= opt_step: print(f'solving step {step}, candidates {len(prev)}...') for micros in prev: # get and solve conflicts @@ -348,35 +349,51 @@ def bfs_schedule(micros: List[MicroPlan]): search_devs.append(dev) if len(search_devs) == 0: micros = [micro.copy() for micro in micros] - next.append(micros) if SchedulePlan.composable(micros): - output.append(micros) - find = True + schedule = SchedulePlan(micros) + schedules.append(schedule) + if schedule.nsteps < opt_step: + print(f'find fewer steps: {schedule.nsteps}') + opt_step = min(opt_step, schedule.nsteps) + else: + next.append(micros) + # search space using different shift choice else: slots = [[micro.mid for (micro, _) in conflicts[dev]] for dev in search_devs] + # prune for symmetric micro batch + keep_slots = None + if same_micros: + # pruning: always shifts on later micro batches on forward and backward + keep_slots = [] + for dev in search_devs: + mids = [] + fmids = [micro.mid for (micro, block) in conflicts[dev] if block.type == Block.BType.FW] + bmids = [micro.mid for (micro, block) in conflicts[dev] if block.type == Block.BType.BW] + if len(fmids) > 0: + mids.append(min(fmids)) + if len(bmids) > 0: + mids.append(min(bmids)) + keep_slots.append(mids) # input(f'search devs: {search_devs}, slots: {slots} | >>>') - for keep_mids in Composer.otho_iter(slots): + for shift_mids in Composer.iter_shifts(slots, keep_slots): shifted_micros = [micro.copy() for micro in micros] - shift_mids = [ - [mid for mid in slot if mid != kmid] for kmid, slot in zip(keep_mids, slots) - ] for dev, mids in zip(search_devs, shift_mids): for mid in mids: block = micros[mid].block(dev, step) # print(f'shift(micro{mid}, block<{(dev, step)}>)') shifted_micros[mid] = micros[mid].shift(block, inplace=False) - next.append(shifted_micros) if SchedulePlan.composable(shifted_micros): - output.append(shifted_micros) - shifted_micros=None - find = True + schedule = SchedulePlan(shifted_micros) + schedules.append(schedule) + if schedule.nsteps < opt_step: + print(f'find fewer steps: {schedule.nsteps}') + opt_step = min(opt_step, schedule.nsteps) + else: + next.append(shifted_micros) prev, next = next, [] - if find: - prev = output - break step += 1 - schedules = [SchedulePlan(micros) for micros in prev] + schedules = [schedule for schedule in schedules if schedule.nsteps == opt_step] return schedules @@ -401,11 +418,11 @@ def to_same_step(micros: List[MicroPlan]): return micros @staticmethod - def otho_iter(slots: List[List[Any]]): + def iter_bucket(slots: List[List[Any]]): """ othogonal pickers - item for each slot can be randomly selected + item for each slot can be iteratively selected """ if len(slots) == 0: yield [] @@ -417,9 +434,25 @@ def otho_iter(slots: List[List[Any]]): else: slots = slots[1:] for item in slot: - for res in Composer.otho_iter(slots): + for res in Composer.iter_bucket(slots): yield [item] + res return + + @staticmethod + def iter_shifts(conflicts: List[List[int]], keep_candidates: List[List[int]] = None) -> List[List[int]]: + """ + conflicts: the conflicted micro ids grouped by devices. + keep_candidates: the candidates that can be preserved with no shift. By default, None + indicates every block has the opportunity to be kept. + Yield: + microbatch ids need to be shifted + """ + keep_candidates = conflicts if keep_candidates is None else keep_candidates + for keep_mids in Composer.iter_bucket(keep_candidates): + shift_mids = [ + [mid for mid in slot if mid != kmid] for kmid, slot in zip(keep_mids, conflicts) + ] + yield shift_mids if __name__ == '__main__': From 94ad0c0eed66d609d44ce27304ee9a2ca3ef5de6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 9 May 2022 16:37:43 +0800 Subject: [PATCH 0800/1892] add memory optimization tech --- cube/search/composer.py | 58 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 330db261..37c56e9a 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -322,7 +322,8 @@ def premise(fn, ndevs: int, nmicros: int): @staticmethod - def bfs_schedule(micros: List[MicroPlan]): + def bfs_schedule(micros: List[MicroPlan], mem_opt=True): + total_status = 1 micros.sort(key=lambda m: m.mid) same_micros = Composer.same_plans(micros, start_step=0) step = 0 @@ -330,7 +331,7 @@ def bfs_schedule(micros: List[MicroPlan]): prev: List[List[MicroPlan]] = [micros] next: List[List[MicroPlan]] = [] schedules: List[SchedulePlan] = [] - while step <= opt_step: + while step < opt_step: print(f'solving step {step}, candidates {len(prev)}...') for micros in prev: # get and solve conflicts @@ -391,9 +392,14 @@ def bfs_schedule(micros: List[MicroPlan]): opt_step = min(opt_step, schedule.nsteps) else: next.append(shifted_micros) + total_status += len(next) prev, next = next, [] step += 1 + total_status += len(schedules) schedules = [schedule for schedule in schedules if schedule.nsteps == opt_step] + if mem_opt: + schedules = [SchedulePlan(Composer.memory_opt(schedule.micros)) for schedule in schedules] + print(f'searched {total_status} status.') return schedules @@ -454,10 +460,50 @@ def iter_shifts(conflicts: List[List[int]], keep_candidates: List[List[int]] = N ] yield shift_mids + @staticmethod + def memory_opt(micros: List[MicroPlan]): + """ + optimize memory given a schedule plan. + The micros are composable. + """ + nsteps = max(micro.nsteps for micro in micros) + for step in range(nsteps-1, -1, -1): + micros = Composer.memory_opt_step(micros, step) + return micros + + @staticmethod + def memory_opt_step(micros: List[MicroPlan], step: int): + splan = sum(micro.plan for micro in micros) + free_steps = [np.where(splan[dev,:] == 0)[0] for dev in range(micros[0].ndevs)] + for micro in micros: + devs = np.where(micro.plan[:,step] > 0)[0] + for dev in devs: + # find non-critical forward blocks + block = micro.block(dev, step) + if block.type != Block.BType.FW: + continue + maxstep = min(micro.position(nblock)[1] for nblock in block.after) - 1 + if maxstep == step: # no room for shift => critical + continue + # find maximal shift distance + maxshift = None + for t in range(maxstep, step, -1): + if t in free_steps[dev]: + maxshift = t - step + break + # apply shift by `distance` times + if maxshift is not None: + for _ in range(maxshift): + micro.shift(block, inplace=True) + return micros + if __name__ == '__main__': def uniform_staging(ndevs: int, nmicros=4): + """ + shape be can "v" or "^" + """ micros = [] for mid in range(nmicros): micro = MicroPlan(mid, ndevs) @@ -466,7 +512,7 @@ def uniform_staging(ndevs: int, nmicros=4): blocks = fblocks + bblocks micro.add_dependency(blocks) micros.append(micro) - return micros + return micros def compose_1F1B(ndevs, nmicros): # premise @@ -488,7 +534,7 @@ def compose_1F1B(ndevs, nmicros): print(schedule) return schedule - def search(ndevs, nmicros): + def search(ndevs, nmicros, visualize=False): # premise micros = Composer.premise(uniform_staging, ndevs, nmicros) @@ -502,10 +548,12 @@ def search(ndevs, nmicros): steps = set(schedule.nsteps for schedule in schedules) assert len(steps) == 1, f"got un-consistent step set: {steps}" nsteps = list(steps)[0] - print(f'find {len(schedules)} step-optimal schedules of step: {nsteps}') + print(f'find {len(schedules)} step-optimal plans (step={nsteps})') for idx, schedule in enumerate(schedules): print(f'Schedule #{idx+1}:') print(schedule) + if visualize: + schedule.visualize(f'planlog/plan{idx+1}.png') ndevs = 4 From 1310b7e7753a5232eb0dccffce603908b99fa806 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Tue, 10 May 2022 08:50:21 +0800 Subject: [PATCH 0801/1892] nit --- handcraft/playground/dag/data_parallel_raw.py | 153 +++++++++++------- 1 file changed, 97 insertions(+), 56 deletions(-) diff --git a/handcraft/playground/dag/data_parallel_raw.py b/handcraft/playground/dag/data_parallel_raw.py index 869d9de5..0f8ab7ff 100644 --- a/handcraft/playground/dag/data_parallel_raw.py +++ b/handcraft/playground/dag/data_parallel_raw.py @@ -2,84 +2,125 @@ ''' [dataflow graph]: the logic of one training iteration -input: samples, weight tensors, optimizer state tensors -output: (updated) weight tensors, (updated) optimizer state tensors +DFG input: samples, weight tensors, optimizer state tensors +DFG output: (updated) weight tensors, (updated) optimizer state tensors -(assumption) a DL model description compatible with different batch-size +data tensors as edges, produced by one operator and consumed by one or more operator(s) +operators as nodes, consumes one or more tensor(s), (mostly) producing one tensor + +(assumption resizable batch) same DL model description for different batch-sizes (batch-size as variable) ref 1 ONNX: https://github.com/onnx/onnx/issues/2182 ref 2 ... -''' +/////////////// +graph manipulation as data-parallel, 4 method options +option 1: manually decide manipulation for each node/tensor following an oracle that knows everything +option 2: deep copy graph and manually decide adjustment for each node/tensor +option 3: using node-role, e.g., DataNode/Fwd/Bwd split; Optimizer replicate; weight's gradient all-reduce before used by Optimizers +option 4: using tensor info. e.g., tensors with batch-dim will split, operators and other tensors adapt accordingly ''' -graph manipulation as data-parallel -option 1: deep copy graph and manually decide adjustment for each node/tensor -option 2: manually decide manipulation for each node/tensor -option 3: using node-role, e.g., DataNode/Fwd/Bwd split; Optimizer replicate; weight gradient all-reduce before apply -option 4: using tensor info. e.g., tensors with -''' -def data_parallel_raw(g: Graph, dev_num: int, option: int): - print(g) - - if option == 1: #per node manipulation following oracle - def magic_func(node) -> bool: - pass +def data_parallel_raw(g: Graph, device_num: int, method: int): + def oracle_func(*args) -> bool: + pass - # 1. multiply operators + if method == 'raw graph manipulation': #per node manipulation following oracle's instruction + # 1. multiply operators for ``parallel'' in data-parllelism for node in g.nodes: - if magic_func(node) == 'split op': - new_nodes = [] - for i in range(dev_num): - new_node_inputs = [] - for ts in node.inputs: - if magic_func(ts) == 'split tensor': - new_node_inputs.append(split(ts, dev_num)) - elif magic_func(ts) == 'replicate tensor': - new_node_inputs.append(ts) - #TODO connect split tensor - new_nodes.append(Node(type=node.type, inputs=new_node_inputs)) #insert new node - elif magic_func(node) == 'replicate op': - new_nodes = [node] * dev_num - - g.replace(node, new_nodes) # TODO connect - - #2. inserting gradient averaging + new_nodes = [] + for device_id in range(device_num): + new_node_inputs = [] + for ts in node.inputs: + # find corresponding input tensor, which is another new operator's (sliced/replicated...) output + new_input = oracle_func(node, ts, device_id, device_num).query("find_new_input") + new_node_inputs.append(new_input) + + new_node_outputs = [] + for ts in node.outputs: + # new out tensor of the same shape (if replicate) or 1/N (if slice on certain dim) + new_output_shape = oracle_func(node, ts, device_id, device_num).query("new_output_shape") + new_output = Tensor(new_output_shape) # create new tensor as output (will be another operator(s)'s input) + new_node_outputs.append(new_output) + + new_node_type = oracle_func(node).query("new_node_type") + # create new node, with device info + new_node = Node(type=new_node_type, inputs=new_node_inputs, outputs=new_node_outputs, + device=device_id) + new_nodes.append(new_node) + + g.replace(node, new_nodes) #replacing with new nodes + + # 2. inserting gradient averaging for node in g.nodes: new_allreduce_node = None - gradient = None + input_to_replace = None for ts in node.inputs: - if magic_func(ts): - new_allreduce_node = Node(type=allreduce, inputs=ts) - gradient = ts + if oracle_func(ts).query('insert allreduce here'): + new_allreduce_node = Node(type='allreduce', inputs=ts) + input_to_replace = ts break - new_node = Node(node.type, node.inputs - gradient + new_allreduce_node.output) + + new_node = Node(type=node.type, inputs=node.inputs - input_to_replace + new_allreduce_node.output, + outputs=node.outputs) g.replace(node, [new_allreduce_node, new_node]) - elif option == 2: #replicate entire graph and adjust - def magic_func(node) -> bool: - pass + elif method == 'replicate graph and adjust': #replicate entire graph and adjust, similar to approaches of Horovod and PyTorch DDP + # 1. deep copy graph + graphs = [g.deepcopy() for i in range(device_num)] - # deep copy graph - graphs = [g.deepcopy() for i in range(dev_num)] - # reset batch size for each new graph, leveraging resizable DFG description (<-assumption) - for graph in graphs: - graph.batch_size = g.batch_size // dev_num - # inserting gradient averaging + # 2. reset batch size for each new graph, leveraging model description resizable batch (<-assumption) + # input or output shape inferred from shape_inference, representing split (1/N shape) or replicated (unchanged shape) + for index, graph in enumerate(graphs): + graph.arguments.batch_size = g.arguments.batch_size // device_num + graph.to_device(device=index) + + # 3. inserting gradient averaging for graph in graphs: for node in graph.nodes: new_allreduce_node = None - gradient = None + input_to_replace = None for ts in node.inputs: - if magic_func(ts): - new_allreduce_node = Node(type=allreduce, inputs=ts) - gradient = ts + if oracle_func(ts).query('insert allreduce here'): + new_allreduce_node = Node(type=allreduce, inputs=ts, outputs=Tensor(node.outputs.shape)) + input_to_replace = ts break - new_node = Node(type=node.type, inputs=node.inputs - gradient + new_allreduce_node.output) + + new_node = Node(type=node.type, inputs=node.inputs - input_to_replace + new_allreduce_node.output, + outputs=node.outputs) graph.replace(node, [new_allreduce_node, new_node]) - elif option == 3: #node role based manipulation - pass - elif option == 4: #tensor dimention info based manipulation + elif method == 3: #node role based manipulation + for node in g.nodes: + if isinstance(node, (NodeData)): + new_nodes = [ + Node(type=node.type, inputs=None, + config=node.config.reset_batch_size(node.config.batch_size // device_num), + outputs=node.outputs.shape[0] // device_num + node.outputs.shape[1:]) for + device_id in range(device_num)] + elif isinstance(node, (NodeFwd, NodeBwdA)): + new_nodes = [ + # assume inputs[0] as activation (for NodeFwd) or activation's gradient (for NodeBwdA) and inputs[1] as weight + Node(type=node.type, + inputs=[oracle_func(node, node.inputs[0], device_id, device_num).query("find_new_input"), #batch-split + oracle_func(node, node.inputs[1], device_id, device_num).query("find_new_input")], #replicated + outputs=Tensor(node.outputs.shape[0] // device_num + node.outputs.shape[1:])) #output batch-split + for device_id in range(device_num)] + elif isinstance(node, (NodeBwdW)): #backward that computing weight's gradient + # assume inputs[0] as activation's gradient and inputs[1] as activation, both with batch-dim + new_nodes = [ + Node(type=node.type, + inputs=[oracle_func(node, node.inputs[0], device_id, device_num).query("find_new_input"), #batch-split + oracle_func(node, node.inputs[1], device_id, device_num).query("find_new_input")], #batch-split + outputs=[Tensor(node.outputs.shape[0])]) #shape unchanged, but only 1/N value + for device_id in range(device_num)] + elif isinstance(node, (NodeOpt)): + new_nodes = trans(node, algo.replica, device_num) # replicated optimizers + [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] + + g.replace(node, new_nodes) + + #omit device assign and allreduce insertion + elif method == 4: #tensor dimention info based manipulation pass From 5306111519928f5acda8fd50575e36d011524061 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 11 May 2022 13:55:19 +0800 Subject: [PATCH 0802/1892] premise extends to multiple devices --- cube/search/composer.py | 286 +++++++++++++++++++++------------------- 1 file changed, 154 insertions(+), 132 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 37c56e9a..e612058b 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -2,7 +2,7 @@ Abstraction layer for microb-batch execution plan merge. """ -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Tuple, Optional, Union import numpy as np from enum import Enum import time @@ -40,7 +40,7 @@ def __repr__(self): class PlanBase: def __init__(self, ndevs: int): - self.blocks: Dict[Tuple[int, int], Block] = dict() + self.blocks: Dict[Tuple[Tuple[int], int], Block] = dict() self.positions: Dict[int, Tuple[int, int]] = dict() self.plan = np.zeros((ndevs, ndevs * 2), dtype=int) @@ -60,7 +60,7 @@ def block(self, dev: int, step: int): return None return self.blocks[(dev, step)] - def position(self, block: Block) -> Optional[Tuple[int, int]]: + def position(self, block: Block) -> Optional[Tuple[Tuple[int], int]]: """ Get (dev, step) position given a block. If block not in this plan, return None @@ -118,24 +118,25 @@ def visualize(self, outfile=None): fontsize = [40] txts = list() - def draw_block(block: Block, position: Tuple[int, int], fontsize): + def draw_block(block: Block, position: Tuple[Tuple[int], int], fontsize): color = '#4472C4' if block.type == Block.BType.FW else '#ED7D31' - dev, step = position - rec = Rectangle((step, dev+0.5), 1, 1, color=color, ec='black', lw=1.5) - ax.add_artist(rec) - rx, ry = rec.get_xy() - cx = rx + rec.get_width() / 2.0 - cy = ry + rec.get_height() / 2.0 - anno = str(block.mid) - txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') - rbox = rec.get_window_extent(renderer) - for fs in range(fontsize[0], 1, -2): - txt.set_fontsize(fs) - tbox = txt.get_window_extent(renderer) - if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: - break - fontsize[0] = min(fontsize[0], fs) - txts.append(txt) + devs, step = position + for dev in devs: + rec = Rectangle((step, dev+0.5), 1, 1, color=color, ec='black', lw=1.5) + ax.add_artist(rec) + rx, ry = rec.get_xy() + cx = rx + rec.get_width() / 2.0 + cy = ry + rec.get_height() / 2.0 + anno = str(block.mid) + txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') + rbox = rec.get_window_extent(renderer) + for fs in range(fontsize[0], 1, -2): + txt.set_fontsize(fs) + tbox = txt.get_window_extent(renderer) + if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: + break + fontsize[0] = min(fontsize[0], fs) + txts.append(txt) for dev in range(self.ndevs): for step in range(self.nsteps): @@ -176,21 +177,27 @@ def expand_to(self, nsteps: int): extend = nsteps - self.nsteps self.plan = np.pad(self.plan, ((0,0),(0,extend))) - def add_block(self, pos: Tuple[int, int], btype: Block.BType) -> Block: + def add_block(self, pos: Tuple[Union[int, Tuple[int]], int], btype: Block.BType) -> Block: """ - Add a execution block + Add a execution block. + pos: [dev(s), step] """ - dev, step = pos - if dev >= self.ndevs: + devs, step = pos + if isinstance(devs, int): + devs = (devs,) + else: + devs = tuple(devs) + if max(devs) >= self.ndevs: raise ValueError("device out of scope") if step >= self.nsteps: self.expand_to(step + 1) - if self.plan[dev, step] != 0: + if not all([self.plan[dev, step] == 0 for dev in devs]): raise ValueError(f"Postition {pos} already has blocks") block = Block(self.mid, btype) - self.plan[dev, step] += 1 - self.blocks[(dev, step)] = block - self.positions[id(block)] = (dev, step) + for dev in devs: + self.plan[dev, step] += 1 + self.blocks[(dev, step)] = block + self.positions[id(block)] = (devs, step) return block def add_dependency(self, blocks: List[Block]): @@ -219,18 +226,25 @@ def shift(self, block: Block, inplace=True): # check block in this plan if block not in micro.blocks.values(): raise ValueError("Block not in this micro plan") - dev, step = micro.position(block) + devs, step = micro.position(block) + # shift later blocks for after_block in block.after: if step + 1 == micro.position(after_block)[1]: micro.shift(after_block, inplace=True) - micro.plan[dev, step] = 0 - if step + 1 >= micro.nsteps: - micro.expand_to(micro.nsteps + 1) - micro.plan[dev, step+1] = 1 - # update blocks and positions - del micro.blocks[(dev, step)] - micro.blocks[(dev, step+1)] = block - micro.positions[id(block)] = (dev, step+1) + for dev in devs: + next_block = self.block(dev, step+1) + if next_block is not None: + micro.shift(next_block, inplace=True) + # shift this one + for dev in devs: + micro.plan[dev, step] = 0 + if step + 1 >= micro.nsteps: + micro.expand_to(micro.nsteps + 1) + micro.plan[dev, step+1] = 1 + # update blocks and positions + del micro.blocks[(dev, step)] + micro.blocks[(dev, step+1)] = block + micro.positions[id(block)] = (devs, step+1) return micro def unshift(self, block: Block, inplace=True): @@ -238,19 +252,21 @@ def unshift(self, block: Block, inplace=True): reverse shift, for search only """ micro = self if inplace else self.copy() - dev, step = micro.position(block) + devs, step = micro.position(block) if step == 0: raise ValueError("unshift a block with step = 0") # shift back - micro.plan[dev, step] = 0 - micro.plan[dev, step-1] = 1 - del micro.blocks[(dev, step)] - micro.blocks[(dev, step-1)] = 1 - micro.positions[id(block)] = (dev, step-1) + for dev in devs: + micro.plan[dev, step] = 0 + micro.plan[dev, step-1] = 1 + del micro.blocks[(dev, step)] + micro.blocks[(dev, step-1)] = 1 + micro.positions[id(block)] = (devs, step-1) # shift back shifted blocks for after_block in block.after: if step + 1 == micro.position(after_block)[1]: micro.unshift(after_block, inplace=True) + # TODO: how can I know the independent blocks got shifted? return micro @@ -320,7 +336,6 @@ def premise(fn, ndevs: int, nmicros: int): micros = fn(ndevs, nmicros) return micros - @staticmethod def bfs_schedule(micros: List[MicroPlan], mem_opt=True): total_status = 1 @@ -337,61 +352,21 @@ def bfs_schedule(micros: List[MicroPlan], mem_opt=True): # get and solve conflicts conflicts = SchedulePlan.conflict(micros, step) # input(f'conflicts: dev: {list(conflicts.keys())}, mids: {[[conf[0].mid for conf in c] for c in conflicts.values()]} | >>>') - search_devs = [] - # direct shift on symetrics - for dev, microblocks in conflicts.items(): - cmicros = [micro for (micro, _) in microblocks] - cblocks = [block for (_, block) in microblocks] - if Composer.same_plans(cmicros, start_step=step): - for cmicro, cblock in zip(cmicros[1:], cblocks[1:]): - # print(f'shift(micro{cmicro.mid}, block<{cmicro.position(cblock)}>)') - cmicro.shift(cblock, inplace=True) - else: - search_devs.append(dev) - if len(search_devs) == 0: - micros = [micro.copy() for micro in micros] - if SchedulePlan.composable(micros): - schedule = SchedulePlan(micros) + for shifts in Composer.iter_shifts(conflicts, step, prune_same_micro=True, keep_early_fw=same_micros): + # print(f"step {step}: {shifts}") + shifted_micros = [micro.copy() for micro in micros] + for cblock in shifts: + cmid = cblock.mid + cmicro = shifted_micros[cmid] + cmicro.shift(cblock, inplace=True) + if SchedulePlan.composable(shifted_micros): + schedule = SchedulePlan(shifted_micros) schedules.append(schedule) if schedule.nsteps < opt_step: print(f'find fewer steps: {schedule.nsteps}') opt_step = min(opt_step, schedule.nsteps) else: - next.append(micros) - - # search space using different shift choice - else: - slots = [[micro.mid for (micro, _) in conflicts[dev]] for dev in search_devs] - # prune for symmetric micro batch - keep_slots = None - if same_micros: - # pruning: always shifts on later micro batches on forward and backward - keep_slots = [] - for dev in search_devs: - mids = [] - fmids = [micro.mid for (micro, block) in conflicts[dev] if block.type == Block.BType.FW] - bmids = [micro.mid for (micro, block) in conflicts[dev] if block.type == Block.BType.BW] - if len(fmids) > 0: - mids.append(min(fmids)) - if len(bmids) > 0: - mids.append(min(bmids)) - keep_slots.append(mids) - # input(f'search devs: {search_devs}, slots: {slots} | >>>') - for shift_mids in Composer.iter_shifts(slots, keep_slots): - shifted_micros = [micro.copy() for micro in micros] - for dev, mids in zip(search_devs, shift_mids): - for mid in mids: - block = micros[mid].block(dev, step) - # print(f'shift(micro{mid}, block<{(dev, step)}>)') - shifted_micros[mid] = micros[mid].shift(block, inplace=False) - if SchedulePlan.composable(shifted_micros): - schedule = SchedulePlan(shifted_micros) - schedules.append(schedule) - if schedule.nsteps < opt_step: - print(f'find fewer steps: {schedule.nsteps}') - opt_step = min(opt_step, schedule.nsteps) - else: - next.append(shifted_micros) + next.append(shifted_micros) total_status += len(next) prev, next = next, [] step += 1 @@ -402,7 +377,6 @@ def bfs_schedule(micros: List[MicroPlan], mem_opt=True): print(f'searched {total_status} status.') return schedules - @staticmethod def same_plans(micros: List[MicroPlan], start_step: int = 0) -> bool: Composer.to_same_step(micros) @@ -424,41 +398,84 @@ def to_same_step(micros: List[MicroPlan]): return micros @staticmethod - def iter_bucket(slots: List[List[Any]]): + def iter_shifts(conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], + step: int, + prune_same_micro = True, + keep_early_fw = False) -> List[Block]: """ - othogonal pickers - - item for each slot can be iteratively selected + Enumerate shifted blocks to resolve conflicts on step `step`. """ - if len(slots) == 0: - yield [] - return - slot = slots[0] - if len(slots) == 1: - for item in slot: - yield [item] - else: - slots = slots[1:] - for item in slot: - for res in Composer.iter_bucket(slots): - yield [item] + res - return - - @staticmethod - def iter_shifts(conflicts: List[List[int]], keep_candidates: List[List[int]] = None) -> List[List[int]]: - """ - conflicts: the conflicted micro ids grouped by devices. - keep_candidates: the candidates that can be preserved with no shift. By default, None - indicates every block has the opportunity to be kept. - Yield: - microbatch ids need to be shifted - """ - keep_candidates = conflicts if keep_candidates is None else keep_candidates - for keep_mids in Composer.iter_bucket(keep_candidates): - shift_mids = [ - [mid for mid in slot if mid != kmid] for kmid, slot in zip(keep_mids, conflicts) - ] - yield shift_mids + devs = list(conflicts.keys()) + prev_shifts: List[List[Block]] = [[],] + next_shifts: List[List[Block]] = [] + for dev in devs: + for shifts in prev_shifts: + cmicros = [c[0] for c in conflicts[dev]] + cblocks = [c[1] for c in conflicts[dev]] + # since a same block can be on multiple devices (e.g., tensor parallel) + # we need to remove shifted blocks if it is decided before + for sblock in shifts: + if sblock in cblocks: + idx = cblocks.index(sblock) + cblocks = cblocks[:idx] + cblocks[idx+1:] + cmicros = cmicros[:idx] + cmicros[idx+1:] + if len(cblocks) <= 1: + continue + # pruning tech (for same micro plans): keep forward and backward blocks with smallest mid + if keep_early_fw: + fcblocks = [cblock for cblock in cblocks if cblock.type == Block.BType.FW] + bcblocks = [cblock for cblock in cblocks if cblock.type == Block.BType.BW] + candidates = [] + if len(fcblocks) > 0: + candidates.append(fcblocks[0]) + if len(bcblocks) > 0: + candidates.append(bcblocks[0]) + for kblock in candidates: + idx = cblocks.index(kblock) + # keep blocks on the idx while shifts the rest + nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] + # if the reserved block executes on multiple devices, + # then the rest device must shift all other blocks + for odev in cmicros[idx].position(kblock)[0]: + if odev != dev: + for _, ocblock in conflicts[odev]: + if ocblock != cblocks[idx] and ocblock not in nshifts: + nshifts.append(ocblock) + next_shifts.append(shifts + cblocks[:idx] + cblocks[idx+1:]) + # pruning tech: if micro plan is same, keep the first block + elif prune_same_micro: + if Composer.same_plans(cmicros, start_step=step): + shifts = shifts + cblocks[1:] + next_shifts.append(shifts) + else: + for idx in range(len(cblocks)): + # keep blocks on the idx while shifts the rest + nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] + # if the reserved block executes on multiple devices, + # then the rest device must shift all other blocks + for odev in cmicros[idx].position(cblocks[idx])[0]: + if odev != dev: + for _, ocblock in conflicts[odev]: + if ocblock != cblocks[idx] and ocblock not in nshifts: + nshifts.append(ocblock) + next_shifts.append(nshifts) + # full conflicted space search + else: + for idx in range(len(cblocks)): + # keep blocks on the idx while shifts the rest + nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] + # if the reserved block executes on multiple devices, + # then the rest device must shift all other blocks + for odev in cmicros[idx].position(cblocks[idx])[0]: + if odev != dev: + for _, ocblock in conflicts[odev]: + if ocblock != cblocks[idx] and ocblock not in nshifts: + nshifts.append(ocblock) + next_shifts.append(nshifts) + + prev_shifts, next_shifts = next_shifts, [] + for shifts in prev_shifts: + yield shifts @staticmethod def memory_opt(micros: List[MicroPlan]): @@ -477,18 +494,23 @@ def memory_opt_step(micros: List[MicroPlan], step: int): free_steps = [np.where(splan[dev,:] == 0)[0] for dev in range(micros[0].ndevs)] for micro in micros: devs = np.where(micro.plan[:,step] > 0)[0] + fblocks = [] + # find forward blocks for dev in devs: - # find non-critical forward blocks block = micro.block(dev, step) if block.type != Block.BType.FW: continue + if block not in fblocks: + fblocks.append(block) + # find non-critical forward blocks + for block in fblocks: maxstep = min(micro.position(nblock)[1] for nblock in block.after) - 1 if maxstep == step: # no room for shift => critical continue # find maximal shift distance maxshift = None for t in range(maxstep, step, -1): - if t in free_steps[dev]: + if all([t in free_steps[dev] for dev in micro.position(block)[0]]): maxshift = t - step break # apply shift by `distance` times @@ -540,7 +562,7 @@ def search(ndevs, nmicros, visualize=False): # search shift tic = time.time() - schedules = Composer.bfs_schedule(micros) + schedules = Composer.bfs_schedule(micros, mem_opt=True) toc = time.time() print('search done. time {:.2f}s'.format(toc - tic)) @@ -561,4 +583,4 @@ def search(ndevs, nmicros, visualize=False): # schedule = compose_1F1B(ndevs, nmicros) # schedule.visualize('out.png') - search(ndevs, nmicros) + search(ndevs, nmicros, visualize=False) From b933538f5dc1982178823b72daf5ecf699bf2182 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 11 May 2022 16:01:39 +0800 Subject: [PATCH 0803/1892] add block hashing --- cube/search/composer.py | 197 ++++++++++++++++++++++++++++------------ 1 file changed, 141 insertions(+), 56 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index e612058b..8b8323e4 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -1,8 +1,10 @@ """ +The Tetris. + Abstraction layer for microb-batch execution plan merge. """ -from typing import Any, Dict, List, Tuple, Optional, Union +from typing import Any, Callable, Dict, List, Tuple, Optional, Union import numpy as np from enum import Enum import time @@ -340,7 +342,7 @@ def premise(fn, ndevs: int, nmicros: int): def bfs_schedule(micros: List[MicroPlan], mem_opt=True): total_status = 1 micros.sort(key=lambda m: m.mid) - same_micros = Composer.same_plans(micros, start_step=0) + block_hash = Composer.construct_hash(micros) # False # Composer.same_plans(micros, start_step=0) step = 0 opt_step = sum(micro.nsteps for micro in micros) # initial prev: List[List[MicroPlan]] = [micros] @@ -352,13 +354,17 @@ def bfs_schedule(micros: List[MicroPlan], mem_opt=True): # get and solve conflicts conflicts = SchedulePlan.conflict(micros, step) # input(f'conflicts: dev: {list(conflicts.keys())}, mids: {[[conf[0].mid for conf in c] for c in conflicts.values()]} | >>>') - for shifts in Composer.iter_shifts(conflicts, step, prune_same_micro=True, keep_early_fw=same_micros): + for shifts in Composer.iter_shifts(conflicts, step, prune_same_micro=True, block_hash=block_hash): # print(f"step {step}: {shifts}") shifted_micros = [micro.copy() for micro in micros] for cblock in shifts: cmid = cblock.mid cmicro = shifted_micros[cmid] cmicro.shift(cblock, inplace=True) + # print(f"solved results: ") + # for micro in shifted_micros: + # print(f'microbatch #{micro.mid}:') + # print(micro) if SchedulePlan.composable(shifted_micros): schedule = SchedulePlan(shifted_micros) schedules.append(schedule) @@ -386,7 +392,42 @@ def same_plans(micros: List[MicroPlan], start_step: int = 0) -> bool: if not np.array_equal(plan, other): return False return True - + + @staticmethod + def construct_hash(micros: List[MicroPlan]) -> Callable: + """ + construct a hashing function to map "same" blocks into a same integer. + + The "same" blocks refer to the same-position blocks of same micro plans. + """ + # group same micro plans + same_plans: List[List[MicroPlan]] = [[]] + for micro in micros: + for smicros in same_plans: + if Composer.same_plans(smicros + [micro]): + smicros.append(micro) + break + else: + same_plans.append([micro]) + print(f'detecting {len(same_plans)} same-microplan groups: {[[plan.mid for plan in smicros] for smicros in same_plans]}') + # for each micro plan group, group same hash functions + gid = 0 + block2gid: Dict[int, int] = dict() + for smicros in same_plans: + positions: Dict[Tuple[Tuple[int], int], List[Block]] = dict() + for micro in smicros: + for pos, block in micro.blocks.items(): + if pos not in positions: + positions[pos] = [] + positions[pos].append(block) + for blocks in positions.values(): + for block in blocks: + block2gid[id(block)] = gid + gid += 1 + def blockhash(block: Block) -> int: + return block2gid[id(block)] + return blockhash + @staticmethod def to_same_step(micros: List[MicroPlan]): """ @@ -401,7 +442,7 @@ def to_same_step(micros: List[MicroPlan]): def iter_shifts(conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], step: int, prune_same_micro = True, - keep_early_fw = False) -> List[Block]: + block_hash = Union[None, Callable]) -> List[Block]: """ Enumerate shifted blocks to resolve conflicts on step `step`. """ @@ -420,59 +461,58 @@ def iter_shifts(conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], cblocks = cblocks[:idx] + cblocks[idx+1:] cmicros = cmicros[:idx] + cmicros[idx+1:] if len(cblocks) <= 1: + next_shifts.append(shifts) continue - # pruning tech (for same micro plans): keep forward and backward blocks with smallest mid - if keep_early_fw: - fcblocks = [cblock for cblock in cblocks if cblock.type == Block.BType.FW] - bcblocks = [cblock for cblock in cblocks if cblock.type == Block.BType.BW] - candidates = [] - if len(fcblocks) > 0: - candidates.append(fcblocks[0]) - if len(bcblocks) > 0: - candidates.append(bcblocks[0]) - for kblock in candidates: - idx = cblocks.index(kblock) - # keep blocks on the idx while shifts the rest - nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] - # if the reserved block executes on multiple devices, - # then the rest device must shift all other blocks - for odev in cmicros[idx].position(kblock)[0]: - if odev != dev: - for _, ocblock in conflicts[odev]: - if ocblock != cblocks[idx] and ocblock not in nshifts: - nshifts.append(ocblock) - next_shifts.append(shifts + cblocks[:idx] + cblocks[idx+1:]) - # pruning tech: if micro plan is same, keep the first block - elif prune_same_micro: - if Composer.same_plans(cmicros, start_step=step): - shifts = shifts + cblocks[1:] - next_shifts.append(shifts) - else: - for idx in range(len(cblocks)): - # keep blocks on the idx while shifts the rest - nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] - # if the reserved block executes on multiple devices, - # then the rest device must shift all other blocks - for odev in cmicros[idx].position(cblocks[idx])[0]: - if odev != dev: - for _, ocblock in conflicts[odev]: - if ocblock != cblocks[idx] and ocblock not in nshifts: - nshifts.append(ocblock) - next_shifts.append(nshifts) - # full conflicted space search + + candidates = [] + if block_hash is not None: + gids = [block_hash(cblock) for cblock in cblocks] + for gid in set(gids): + gblocks = [cblock for cblock, cgid in zip(cblocks, gids) if cgid == gid] + gmids = [gblock.mid for gblock in gblocks] + idx = gmids.index(min(gmids)) + candidates.append(gblocks[idx]) else: - for idx in range(len(cblocks)): - # keep blocks on the idx while shifts the rest - nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] - # if the reserved block executes on multiple devices, - # then the rest device must shift all other blocks - for odev in cmicros[idx].position(cblocks[idx])[0]: - if odev != dev: - for _, ocblock in conflicts[odev]: - if ocblock != cblocks[idx] and ocblock not in nshifts: - nshifts.append(ocblock) - next_shifts.append(nshifts) + candidates = cblocks + + if prune_same_micro: + if Composer.same_plans(cmicros, start_step=step): + candidates = [candidates[0]] + # pruning tech (for same micro plans): keep forward and backward blocks with smallest mid + # if keep_early_fw: + # fcblocks = [cblock for cblock in cblocks if cblock.type == Block.BType.FW] + # bcblocks = [cblock for cblock in cblocks if cblock.type == Block.BType.BW] + # candidates = [] + # if len(fcblocks) > 0: + # candidates.append(fcblocks[0]) + # if len(bcblocks) > 0: + # candidates.append(bcblocks[0]) + # for kblock in candidates: + # idx = cblocks.index(kblock) + # # keep blocks on the idx while shifts the rest + # nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] + # # if the reserved block executes on multiple devices, + # # then the rest device must shift all other blocks + # for odev in cmicros[idx].position(kblock)[0]: + # if odev != dev and odev in conflicts: + # for _, ocblock in conflicts[odev]: + # if ocblock != cblocks[idx] and ocblock not in nshifts: + # nshifts.append(ocblock) + # next_shifts.append(shifts + cblocks[:idx] + cblocks[idx+1:]) + # pruning tech: if micro plan is same, keep the first block + for kblock in candidates: + idx = cblocks.index(kblock) + # keep blocks on the idx while shifts the rest + nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] + # if the reserved block executes on multiple devices, + # then the rest device must shift all other blocks + for odev in cmicros[idx].position(kblock)[0]: + if odev != dev and odev in conflicts: + for _, ocblock in conflicts[odev]: + if ocblock != cblocks[idx] and ocblock not in nshifts: + nshifts.append(ocblock) + next_shifts.append(nshifts) prev_shifts, next_shifts = next_shifts, [] for shifts in prev_shifts: yield shifts @@ -524,7 +564,10 @@ def memory_opt_step(micros: List[MicroPlan], step: int): def uniform_staging(ndevs: int, nmicros=4): """ - shape be can "v" or "^" + f b + f b + f b + f b """ micros = [] for mid in range(nmicros): @@ -535,6 +578,32 @@ def uniform_staging(ndevs: int, nmicros=4): micro.add_dependency(blocks) micros.append(micro) return micros + + def mbart_staging(ndevs: int, nmicros=4): + """ + f f f b b b + f f f b b b + f f f b b b + f f f b b b + """ + micros = [] + for mid in range(nmicros): + micro = MicroPlan(mid, ndevs) + fblocks = [] + bblocks = [] + for step in range(ndevs+2): + if step in [0, ndevs // 2+1]: + fblock = micro.add_block((tuple(range(ndevs)), step), Block.BType.FW) + bblock = micro.add_block((tuple(range(ndevs)), (ndevs+2)*2-1-step), Block.BType.BW) + else: + dev = step - 1 if step < ndevs//2+1 else step - 2 + fblock = micro.add_block((dev, step), Block.BType.FW) + bblock = micro.add_block((dev, (ndevs+2)*2-1-step), Block.BType.BW) + fblocks.append(fblock) + bblocks.append(bblock) + micro.add_dependency(fblocks+bblocks[::-1]) + micros.append(micro) + return micros def compose_1F1B(ndevs, nmicros): # premise @@ -559,6 +628,12 @@ def compose_1F1B(ndevs, nmicros): def search(ndevs, nmicros, visualize=False): # premise micros = Composer.premise(uniform_staging, ndevs, nmicros) + # micros = Composer.premise(mbart_staging, ndevs, nmicros) + print('============== Premise ================') + for idx, micro in enumerate(micros): + print(f'microbatch #{idx}:') + print(micro) + print('============== Premise ================') # search shift tic = time.time() @@ -584,3 +659,13 @@ def search(ndevs, nmicros, visualize=False): # schedule = compose_1F1B(ndevs, nmicros) # schedule.visualize('out.png') search(ndevs, nmicros, visualize=False) + + # micros = mbart_staging(ndevs, nmicros) + # for idx, micro in enumerate(micros): + # print(f'microbatch #{idx}:') + # print(micro) + # + # micros[0].shift(micros[0].block(0, 0)) + # micros[0].shift(micros[0].block(0, 2)) + # micros[0].shift(micros[0].block(0, 5)) + # print(micros[0]) \ No newline at end of file From b11c6810e47ea82cdab9b900a6705a38f4fa3cef Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 11 May 2022 16:20:49 +0800 Subject: [PATCH 0804/1892] add chimera bipipe example --- cube/search/composer.py | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 8b8323e4..77682ed5 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -564,10 +564,10 @@ def memory_opt_step(micros: List[MicroPlan], step: int): def uniform_staging(ndevs: int, nmicros=4): """ - f b - f b - f b - f b + f b + f b + f b + f b """ micros = [] for mid in range(nmicros): @@ -604,6 +604,31 @@ def mbart_staging(ndevs: int, nmicros=4): micro.add_dependency(fblocks+bblocks[::-1]) micros.append(micro) return micros + + def chimera_staging(ndevs: int, nmicros: int): + """ + f b f b + f b f b + f b f b + f b f b + """ + micros = [] + assert nmicros % 2 == 0, "require microbatch# can be divided by 2." + for mid in range(nmicros // 2): # V shape + micro = MicroPlan(mid, ndevs) + fblocks = [micro.add_block((sid, sid), Block.BType.FW) for sid in range(ndevs)] + bblocks = [micro.add_block((ndevs-1-sid, sid+ndevs), Block.BType.BW) for sid in range(ndevs)] + blocks = fblocks + bblocks + micro.add_dependency(blocks) + micros.append(micro) + for mid in range(nmicros // 2): # ^ shape + micro = MicroPlan(mid + nmicros // 2, ndevs) + fblocks = [micro.add_block((ndevs-1-sid, sid), Block.BType.FW) for sid in range(ndevs)] + bblocks = [micro.add_block((sid, sid+ndevs), Block.BType.BW) for sid in range(ndevs)] + blocks = fblocks + bblocks + micro.add_dependency(blocks) + micros.append(micro) + return micros def compose_1F1B(ndevs, nmicros): # premise @@ -627,7 +652,8 @@ def compose_1F1B(ndevs, nmicros): def search(ndevs, nmicros, visualize=False): # premise - micros = Composer.premise(uniform_staging, ndevs, nmicros) + # micros = Composer.premise(uniform_staging, ndevs, nmicros) + micros = Composer.premise(chimera_staging, ndevs, nmicros) # micros = Composer.premise(mbart_staging, ndevs, nmicros) print('============== Premise ================') for idx, micro in enumerate(micros): From bfcf51682e46708bed88c5e06ab5994c003b518a Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Wed, 11 May 2022 17:14:23 +0800 Subject: [PATCH 0805/1892] make wrf2 scriptable --- .gitignore | 5 +- cube/graph/parser/converter.py | 3 +- examples/wrf/wrf2.py | 101 +++++++++++++++++++++------------ 3 files changed, 70 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index c0f52faf..6988bde4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ __pycache__ -*.egg-info \ No newline at end of file +*.egg-info + +.vs/ +.vscode/ \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index dc7377d1..c85ae8f3 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -15,7 +15,8 @@ def convert_model(model: torch.nn.Module, """ try: smodule = torch.jit.script(model) - except Exception: + except Exception as ex: + print(ex) raise RuntimeError("Cannot convert module into torchscript moudle.") module_name = smodule.original_name inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py index 6d0ffff0..7c932998 100644 --- a/examples/wrf/wrf2.py +++ b/examples/wrf/wrf2.py @@ -1,8 +1,12 @@ import torch import torch.nn.functional as F +from cube.runtime.syndata import SciLoopVariables + torch.set_default_tensor_type(torch.DoubleTensor) +import cube + class WRF(torch.nn.Module): def __init__(self, dt, ntau, nz, ny, nx, dz, dy, dx, device): @@ -108,7 +112,7 @@ def _step(self, dtau, ntau, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta return U, V, W, O, Theta, phi1, mu1 - def step(self, dtau, ntau, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu): + def step(self, dtau:float, ntau:int, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu): # initialize perturbed varibles U2 = torch.zeros(U.shape, device=self.device) V2 = torch.zeros(V.shape, device=self.device) @@ -133,7 +137,7 @@ def step(self, dtau, ntau, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, return U + U2, V + V2, W + W2, O + O2, Theta + Theta2, phi1 + phi2, mu1 + mu2 - def ac_step(self, dtau, + def ac_step(self, dtau:float, U2, V2, W2, O2, Theta2, phi2, mu2, pi2, R_U, R_V, R_W, R_Theta, R_phi, R_mu, U, V, Theta, phi, mu, alpha, p): @@ -188,19 +192,12 @@ def ac_step(self, dtau, # print('Theta2_:\t', Theta2_.min(), Theta2_.max()) # Theta2_ = torch.zeros(Theta2_.shape, device=Theta2_.device) - def f(x): - phi2_ = phi2 + dtau * ( - R_phi - (O2_ * self.dz(self.bz(self.pzphi(phi))) - self.g * (x + W2) * 0.5) / self.bz(mu)) - return ( - R_W + ( - self.dz(C * self.dz(self.pz(phi2_))) + self.dz(self.GAMMA * p * Theta2_ / Theta) - self.bz(mu2_) + - self.dz(C * self.dz(self.pz(phi2))) + self.dz(self.GAMMA * p * Theta2 / Theta) - self.bz(mu2) - ) * 0.5 * self.g - ) * dtau + W2 - x - - W2_ = self.solve_tridiagonal(f) - if torch.abs(f(W2_)).max() > 1e-6: - print("Triangular solver warning:\t", torch.abs(f(W2_)).max()) + W2_ = self.solve_tridiagonal_( + phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, + O2_, C, Theta2_, mu2_) + + #if torch.abs(f(W2_)).max() > 1e-6: + # print("Triangular solver warning:\t", torch.abs(f(W2_)).max()) W2_ = W2_ / (1 + self.damping(phi, 0.2, self.ztop * 0.75) * dtau) # print((1 + self.damping(phi, 0.8, self.ztop * 0.75) * dtau)[:, 64, 64]) @@ -212,7 +209,7 @@ def f(x): return U2_, V2_, W2_, O2_, Theta2_, phi2_, mu2_, pi2 - def damping(self, phi, gamma, zd): + def damping(self, phi, gamma:float, zd): z = phi / self.g res = gamma * torch.sin(torch.pi / 2 * (1 - (self.ztop - z) / (self.ztop - zd)))**2 return res * z.gt(zd).double() @@ -233,7 +230,7 @@ def RHS(self, U, V, W, O, Theta, phi1, mu1): R_U = ( # pressure term - self.bx(mu) * ( - + self.dx(self.bz(self.pz(phi1))) + self.dx(self.bz(self.pz(phi1))) + self.bx(alpha) * self.dx(p1) + self.bx(alpha1) * self.dx(self.p0)) - self.dx(self.bz(self.pzphi(phi))) * (self.dz(self.bx(self.pzp1(self.bz(p1)))) - self.bx(mu1)) @@ -245,7 +242,7 @@ def RHS(self, U, V, W, O, Theta, phi1, mu1): R_V = ( # pressure term - self.by(mu) * ( - + self.dy(self.bz(self.pz(phi1))) + self.dy(self.bz(self.pz(phi1))) + self.by(alpha) * self.dy(p1) + self.by(alpha1) * self.dy(self.p0)) - self.dy(self.bz(self.pzphi(phi))) * (self.dz(self.by(self.pzp1(self.bz(p1)))) - self.by(mu1)) @@ -256,7 +253,8 @@ def RHS(self, U, V, W, O, Theta, phi1, mu1): ) R_W = ( # pressure term - + self.g * (self.dz(p1) - self.bz(self.mu0) * 0.0) - self.bz(mu1) * self.g + #+ self.g * (self.dz(p1) - self.bz(self.mu0) * 0.0) - self.bz(mu1) * self.g + self.g * (self.dz(p1) - self.bz(self.mu0) * 0.0) - self.bz(mu1) * self.g # advection term - self.dx(self.px(self.bz(U) * self.bx(w))) - self.dy(self.py(self.bz(V) * self.by(w))) @@ -336,28 +334,47 @@ def bz(self, X): filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) / 2. - def solve_tridiagonal(self, f): - r"""Solve tridiagonal system f(x) = Ax - b = 0 - - Args: - f (Callable, return Tensor): Tridiagonal system (nz - 1, ny, nx) -> (nz - 1, ny, nx) - - Returns: - Tensor: Solution of the linear system with shape (D, H, W) - """ - b = - f(torch.zeros((self.nz - 1, self.ny, self.nx), device=self.device)) - - idx0 = torch.tensor([1., 0, 0], device=self.device).view(3, 1, 1) + def tridiagonal_system(self, + phi2, dtau:float, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, + O2_, C, Theta2_, mu2_, + x): + phi2_ = phi2 + dtau * ( + R_phi - (O2_ * self.dz(self.bz(self.pzphi(phi))) - self.g * (x + W2) * 0.5) / self.bz(mu)) + return ( + R_W + ( + self.dz(C * self.dz(self.pz(phi2_))) + self.dz(self.GAMMA * p * Theta2_ / Theta) - self.bz(mu2_) + + self.dz(C * self.dz(self.pz(phi2))) + self.dz(self.GAMMA * p * Theta2 / Theta) - self.bz(mu2) + ) * 0.5 * self.g + ) * dtau + W2 - x + + def solve_tridiagonal_(self, + phi2, dtau:float, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, + O2_, C, Theta2_, mu2_): + b = - self.tridiagonal_system( + phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, + O2_, C, Theta2_, mu2_, + torch.zeros((self.nz - 1, self.ny, self.nx), device=self.device)) + + idx0 = torch.tensor([1., 0., 0.], device=self.device).view(3, 1, 1) idx0 = idx0.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] - r0 = f(idx0) + b + r0 = self.tridiagonal_system( + phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, + O2_, C, Theta2_, mu2_, + idx0) + b - idx1 = torch.tensor([0., 1, 0], device=self.device).view(3, 1, 1) + idx1 = torch.tensor([0., 1., 0.], device=self.device).view(3, 1, 1) idx1 = idx1.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] - r1 = f(idx1) + b + r1 = self.tridiagonal_system( + phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, + O2_, C, Theta2_, mu2_, + idx1) + b - idx2 = torch.tensor([0., 0, 1], device=self.device).view(3, 1, 1) + idx2 = torch.tensor([0., 0., 1.], device=self.device).view(3, 1, 1) idx2 = idx2.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] - r2 = f(idx2) + b + r2 = self.tridiagonal_system( + phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, + O2_, C, Theta2_, mu2_, + idx2) + b d = (torch.stack([r0, r1, r2], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1) l = (torch.stack([r2, r0, r1], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1)[1:] @@ -404,8 +421,18 @@ def solve_tridiagonal(self, f): W = torch.zeros((nz - 1, ny, nx)).cuda() O = torch.zeros((nz - 1, ny, nx)).cuda() + varloader = SciLoopVariables(variables=[U, V, W, O, Theta, phi1, mu1], constants=[]) + model = cube.SemanticModel(wrf, input_shapes=tuple(varloader.shapes)) + + @cube.compile(model=model, dataloader=varloader) + def train_iter(model, dataloader): + U, V, W, O, Theta, phi1, mu1 = dataloader + U, V, W, O, Theta, phi1, mu1 = model(U, V, W, O, Theta, phi1, mu1) + return U, V, W, O, Theta, phi1, mu1 + model = model.get_gen_module() + for i in range(10): - U, V, W, O, Theta, phi1, mu1 = wrf(U, V, W, O, Theta, phi1, mu1) + U, V, W, O, Theta, phi1, mu1 = train_iter(model, varloader) mu = wrf.mu0 + mu1 u = U / wrf.bx(mu) v = V / wrf.by(mu) From 870f50687897f115fc310f00db3b015ebaa6e051 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 12 May 2022 02:04:55 +0800 Subject: [PATCH 0806/1892] handle torch.cat --- cube/graph/operator/function/cat.py | 52 ++++++++++++++++++++++++ cube/graph/operator/function/function.py | 10 +++++ cube/graph/parser/mapping.py | 2 + 3 files changed, 64 insertions(+) create mode 100644 cube/graph/operator/function/cat.py diff --git a/cube/graph/operator/function/cat.py b/cube/graph/operator/function/cat.py new file mode 100644 index 00000000..0d6024ce --- /dev/null +++ b/cube/graph/operator/function/cat.py @@ -0,0 +1,52 @@ +from copy import copy +from typing import List + +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + +class IRCat(IRFwOperation): + def __init__(self, signature: str, inputs: List[IRTensor], name: str, + **kwargs): + # torch.cat(inputs:List[Tensor], dim:int) -> Tensor + # REMARK: the input to 'cat' is a tensor list, so 'inputs' parameter directly reflects the singleton list containing that list, + # so the meaning of param 'inputs' is sligtly different from other IRXXXOp. + assert len(inputs) > 0, "TODO handle zero inputs" + assert len(kwargs) == 1, "Expected 1 kwargs: dim" + + super().__init__(name, signature, len(inputs), 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update(kwargs) + self._cat_count = len(inputs) + + def infer_shape(self) -> bool: + """ + Output shape inference given the input shapes + """ + dim = self.kwargs['dim'] + + # validation + # TODO how about zero inputs? + tensors : List[IRTensor] = self.inputs(None) # None for all inputs + + # Shape without the dim-th component + s0 : list = None + for i, tensor in enumerate(tensors): + s : list = copy(tensor.shape) # avoid mutating the original shape + + if len(s) == 0: + # Any shape unknown + return False + + s.pop(dim) + if i == 0: + s0 = s + else: + if s != s0: + # Inconsistent input shape + return False + + sumLen : int = sum(t.shape[dim] for t in tensors) + s0.insert(dim, sumLen) + self.outputs(0).shape = s0 + return True diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index c721b6c7..0018e66a 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -1,6 +1,7 @@ from typing import Iterable, List, Optional, Union, Dict import string import copy +from cube.graph.operator.function.cat import IRCat from cube.ir.cten import IRTensor from cube.graph.operator.function.einops import EinDim, IREinops @@ -450,6 +451,15 @@ def Pad(signature, inputs): pad, mode, value = inputs[1:] return IRPad(signature, tensors, 'pad', pad=pad, mode=mode, value=value) +def Cat(signature, inputs): + """ + torch.cat(inputs: List[Tensor], dim: int) -> Tensor + """ + tensors : List[IRTensor] + dim : int + tensors, dim = inputs + return IRCat(signature, tensors, 'cat', dim=dim) + def ScriptEinOps(signature, inputs): """ apply_for_scriptable_torch(recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str) -> torch.Tensor: diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index b984ac94..85c7ffb6 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -101,6 +101,8 @@ def register(signature: str, op: IRFwOperation, code): #pytorch1.11 __ttemplate('linear'): function.Linear, + __ttemplate('cat'): function.Cat, + #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, From 5d0c2557164be6e548c2f614ff7da2ee2840a2d6 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 12 May 2022 02:53:17 +0800 Subject: [PATCH 0807/1892] starting handle IR construct prim::device/aten::tensor --- cube/graph/parser/parser.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 94cfa53e..6a4477ad 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -23,6 +23,7 @@ class ScriptNodeKind(enum.Enum): PrimListUnpack = 8 PrimTupleUnpack = 9 PrimPythonOp = 10 + PrimGetDevice = 11 class ScriptModuleParser: @@ -152,6 +153,8 @@ def ntype(node: torch._C.Node): return ScriptNodeKind.PrimListUnpack if node.kind() == 'prim::PythonOp': return ScriptNodeKind.PrimPythonOp + if node.kind() == 'prim::device': + return ScriptNodeKind.PrimGetDevice raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @staticmethod @@ -178,6 +181,8 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] return ScriptModuleParser.parse_prim_list_unpack_node(node, module, frame) if node_type == ScriptNodeKind.PrimPythonOp: return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) + if node_type == ScriptNodeKind.PrimGetDevice: + return ScriptModuleParser.parse_prim_get_device_node(node, module, frame) raise NotImplementedError(f"Un-supported node type {node_type}") except Exception: raise RuntimeError(f"\n\nParsing error at node:\n\t{node}\n") @@ -254,6 +259,15 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: output: List[int] = list(tensor.shape) frame.add_var(outputs[0].debugName(), output) return [] + # aten::tensor(elems: List[T], dtype:ScalarType, device:Device, requires_grad:bool) -> Tensor + elif fsig == 'torch.tensor': + # originally 'aten::tensor' + var_name = outputs[0].debugName() + elems, dtype, device, requires_grad = input_val + kDefaultType = DType2IRDType.map(dtype) + ir_tensor = IRFullTensor(shape=[len(elems)], name=var_name, requires_grad=requires_grad, dtype=kDefaultType) + frame.add_var(var_name, ir_tensor) + return [] # create IR node ir_node = Sign2Op.map(fsig)(inputs=input_val) @@ -443,6 +457,18 @@ def parse_prim_python_op_node(node, module, frame): raise NotImplementedError("Cannot support torch.jit.ignore") print(dir(node)) + @staticmethod + def parse_prim_get_device_node(node, module, frame): + inputs = list(node.inputs()) + outputs = list(node.outputs()) + if len(inputs) != 1: + raise RuntimeError("Find prim::device has not exactly one input") + if len(outputs) != 1: + raise RuntimeError("Find prim::device has not exactly one output") + input = frame.get_var(inputs[0].debugName()) + frame.add_var(outputs[0].debugName(), "TODO DEFINE THE DEVICE") + return [] + @staticmethod def flatten(smodule, depth=0): """ From 9ae7bc59e395ef37081b34aa304f16de05f9d8ea Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 12 May 2022 12:22:39 +0800 Subject: [PATCH 0808/1892] erased prim::Device --- cube/graph/parser/parser.py | 43 ++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 6a4477ad..585a440a 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -1,7 +1,7 @@ import torch import enum import re -from typing import List, Tuple, Optional +from typing import Any, List, Tuple, Optional from cube.graph import IRFwOperation from cube.graph.tensor import IRFullTensor @@ -11,6 +11,9 @@ _refmodule = torch.nn.Module() +class ErasedDevice: + pass + class ScriptNodeKind(enum.Enum): PrimGetAttr = 1 @@ -23,7 +26,7 @@ class ScriptNodeKind(enum.Enum): PrimListUnpack = 8 PrimTupleUnpack = 9 PrimPythonOp = 10 - PrimGetDevice = 11 + PrimDevice = 11 # erased class ScriptModuleParser: @@ -154,7 +157,7 @@ def ntype(node: torch._C.Node): if node.kind() == 'prim::PythonOp': return ScriptNodeKind.PrimPythonOp if node.kind() == 'prim::device': - return ScriptNodeKind.PrimGetDevice + return ScriptNodeKind.PrimDevice raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @staticmethod @@ -181,8 +184,11 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] return ScriptModuleParser.parse_prim_list_unpack_node(node, module, frame) if node_type == ScriptNodeKind.PrimPythonOp: return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) - if node_type == ScriptNodeKind.PrimGetDevice: - return ScriptModuleParser.parse_prim_get_device_node(node, module, frame) + + # TODO bother assigning all ignored prim functions new NodeKinds? + if node_type == ScriptNodeKind.PrimDevice: + return ScriptModuleParser.parse_value_erased_node(node, module, frame, [ErasedDevice()]) + raise NotImplementedError(f"Un-supported node type {node_type}") except Exception: raise RuntimeError(f"\n\nParsing error at node:\n\t{node}\n") @@ -259,13 +265,17 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: output: List[int] = list(tensor.shape) frame.add_var(outputs[0].debugName(), output) return [] - # aten::tensor(elems: List[T], dtype:ScalarType, device:Device, requires_grad:bool) -> Tensor + # aten::tensor(elems: List^{n:Nat}[T], dtype:Optional[ScalarType], device:Device, requires_grad:bool) -> Tensor elif fsig == 'torch.tensor': # originally 'aten::tensor' var_name = outputs[0].debugName() - elems, dtype, device, requires_grad = input_val - kDefaultType = DType2IRDType.map(dtype) - ir_tensor = IRFullTensor(shape=[len(elems)], name=var_name, requires_grad=requires_grad, dtype=kDefaultType) + elems, dtype, erased_device, requires_grad = input_val + + # dtype may be None, in PyTorch it's to infer dtype from 'elems'. + if dtype == None: + dtype = DType2IRDType.map(torch.get_default_dtype()) + + ir_tensor = IRFullTensor(shape=[len(elems)], name=var_name, requires_grad=requires_grad, dtype=dtype) frame.add_var(var_name, ir_tensor) return [] @@ -458,17 +468,16 @@ def parse_prim_python_op_node(node, module, frame): print(dir(node)) @staticmethod - def parse_prim_get_device_node(node, module, frame): - inputs = list(node.inputs()) + def parse_value_erased_node(node, module, frame, erased_vals: List[Any]): outputs = list(node.outputs()) - if len(inputs) != 1: - raise RuntimeError("Find prim::device has not exactly one input") - if len(outputs) != 1: - raise RuntimeError("Find prim::device has not exactly one output") - input = frame.get_var(inputs[0].debugName()) - frame.add_var(outputs[0].debugName(), "TODO DEFINE THE DEVICE") + + assert len(outputs) == len(erased_vals) + for output, erased_val in zip(outputs, erased_vals): + frame.add_var(output.debugName(), erased_vals) return [] + + @staticmethod def flatten(smodule, depth=0): """ From 3b8d5d32d641d5fddc55e300887076591ae4cd09 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 12 May 2022 13:38:33 +0800 Subject: [PATCH 0809/1892] fix variable ref --- cube/graph/parser/parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 585a440a..aa739e1e 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -472,8 +472,8 @@ def parse_value_erased_node(node, module, frame, erased_vals: List[Any]): outputs = list(node.outputs()) assert len(outputs) == len(erased_vals) - for output, erased_val in zip(outputs, erased_vals): - frame.add_var(output.debugName(), erased_vals) + for output, ev in zip(outputs, erased_vals): + frame.add_var(output.debugName(), ev) return [] From 7ffeea6a804ea5578f1bfdef42a0bd6913a7f2a7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 12 May 2022 17:51:46 +0800 Subject: [PATCH 0810/1892] fix view bug --- cube/graph/operator/function/function.py | 121 ++++++++++++++++------- 1 file changed, 87 insertions(+), 34 deletions(-) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index c721b6c7..9a91b24a 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -309,61 +309,114 @@ def View(signature, inputs): raise TypeError("Expected tensor.view has static int shape") in_shape, ou_shape = list(input.shape), shape - # shape check + # infer -1 def nele(shape, nele=1): for dimlen in shape: nele *= dimlen return nele - # handle '-1' in shape + cnt = nele(in_shape) if -1 in ou_shape: idx = ou_shape.index(-1) ou_shape[idx] = cnt // (-nele(ou_shape)) assert nele(in_shape) == nele(ou_shape), "shape mismatch" + # generate annotation - shape_map: Dict[str, int] = dict() + rest_inshape = [dimlen for dimlen in in_shape] + rest_oushape = [dimlen for dimlen in ou_shape] + chain = [] + can_bucket = True + while len(rest_inshape) != 0 or len(rest_oushape) != 0: + if len(rest_inshape) == 0: + chain = chain + rest_oushape + rest_oushape = [] + elif len(rest_oushape) == 0: + chain = chain + rest_inshape + rest_inshape = [] + else: + dimlen = min(rest_inshape[0], rest_oushape[0]) + if max(rest_inshape[0], rest_oushape[0]) % dimlen == 0: + chain.append(dimlen) + if dimlen == rest_inshape[0]: + rest_inshape.pop(0) + else: + rest_inshape[0] = rest_inshape[0] // dimlen + if dimlen == rest_oushape[0]: + rest_oushape.pop(0) + else: + rest_oushape[0] = rest_oushape[0] // dimlen + else: + can_bucket = False + print(rest_inshape, rest_oushape) + # assert False + break + letters = iter(string.ascii_lowercase) - in_anno, ou_anno = [], [] - in_dim, ou_dim = 0, 0 - in_remain, ou_remain = in_shape[in_dim], ou_shape[ou_dim] - in_bracket, ou_bracket = [], [] - in_dimlen, ou_dimlen = 1, 1 - while True: - letter = next(letters) - dimlen = min(in_remain, ou_remain) - in_dimlen, ou_dimlen = in_dimlen * dimlen, ou_dimlen * dimlen - in_remain, ou_remain = in_remain // dimlen, ou_remain // dimlen - in_bracket.append(letter) - ou_bracket.append(letter) - shape_map[letter] = dimlen - if in_remain == 1: - in_anno.append(in_bracket) - in_bracket, in_dimlen = [], 1 - in_dim += 1 - if in_dim < len(in_shape): - in_remain = in_shape[in_dim] - if ou_remain == 1: - ou_anno.append(ou_bracket) - ou_bracket, ou_dimlen = [], 1 - ou_dim += 1 - if ou_dim < len(ou_shape): - ou_remain = ou_shape[ou_dim] - if in_dim == len(in_shape) and ou_dim == len(ou_shape): - break - # setup reduction: only first dimension can be spatially partitioned + if can_bucket: + inchain = ouchain = chain + inedims = ouedims = edims = [next(letters) for _ in chain] + else: + inchain, ouchain = in_shape, ou_shape + inedims = [str(dimlen) for dimlen in in_shape] + ouedims = [str(dimlen) for dimlen in ou_shape] + chain = inchain + ouchain + edims = inedims + ouedims + + shape_map: Dict[str, int] = {edim: eshape for (edim, eshape) in zip(edims, chain)} + print(inchain) + + # generate input and output shape annotations + def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[str]]: + anno = [] + dimidx = 0 + for idx, dimlen in enumerate(shape): + elements, bracket = 1, [] + maxele = len(chain) - dimidx - (len(shape) - 1 - idx) + while True: + if len(bracket) == maxele: + assert elements == dimlen, f"internal match error1: {bracket}" + break + if dimidx >= len(chain) or elements * chain[dimidx] > dimlen: + assert elements == dimlen, f"internal match error2: {bracket}" + break + else: + elements *= chain[dimidx] + bracket.append(edims[dimidx]) + dimidx += 1 + anno.append(bracket) + return anno + + in_anno = buckets(in_shape, inchain, inedims) + ou_anno = buckets(ou_shape, ouchain, ouedims) + print(in_anno, ou_anno) + + # postprocess on dimlen == 1 + shape_map['1'] = 1 + for bracket in in_anno + ou_anno: + for subdim, edim in enumerate(bracket): + if shape_map[edim] == 1: + bracket[subdim] = str(shape_map[edim]) + + # find out the axis that can be partitioned spatial_in = set() spatial_ou = set() for in_bracket in in_anno: - spatial_in.add(in_bracket[0]) + for edim in in_bracket: + if edim != '1': + spatial_in.add(edim) + break for ou_bracket in ou_anno: - spatial_ou.add(ou_bracket[0]) + for edim in ou_bracket: + spatial_ou.add(edim) spatial = spatial_in.intersection(spatial_ou) + for bracket in in_anno + ou_anno: for subdim, edim in enumerate(bracket): if edim not in spatial: bracket[subdim] = str(shape_map[edim]) # bracket[subdim] = edim + '^' anno = _create_anno([in_anno], [ou_anno]) - return IREinops(signature, [anno], [input], 'view', shape=shape) + print(f'torch.view anno: {anno}') + return IREinops(signature, [anno], [input], 'view', shape=tuple(shape)) def Reshape(signature, inputs): From cbd8ada6347036a4efa73fba860d519f4708b7f8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 12 May 2022 17:54:56 +0800 Subject: [PATCH 0811/1892] clear view debug info --- cube/graph/operator/function/function.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 60cccb4a..ab7b8bf2 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -347,8 +347,6 @@ def nele(shape, nele=1): rest_oushape[0] = rest_oushape[0] // dimlen else: can_bucket = False - print(rest_inshape, rest_oushape) - # assert False break letters = iter(string.ascii_lowercase) @@ -361,9 +359,7 @@ def nele(shape, nele=1): ouedims = [str(dimlen) for dimlen in ou_shape] chain = inchain + ouchain edims = inedims + ouedims - shape_map: Dict[str, int] = {edim: eshape for (edim, eshape) in zip(edims, chain)} - print(inchain) # generate input and output shape annotations def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[str]]: @@ -388,7 +384,6 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s in_anno = buckets(in_shape, inchain, inedims) ou_anno = buckets(ou_shape, ouchain, ouedims) - print(in_anno, ou_anno) # postprocess on dimlen == 1 shape_map['1'] = 1 @@ -416,7 +411,6 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s bracket[subdim] = str(shape_map[edim]) # bracket[subdim] = edim + '^' anno = _create_anno([in_anno], [ou_anno]) - print(f'torch.view anno: {anno}') return IREinops(signature, [anno], [input], 'view', shape=tuple(shape)) From 76ae2cc6089c07d4fe139e24db571fc5661a068a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 17 May 2022 14:37:58 +0800 Subject: [PATCH 0812/1892] fix execplan view --- cube/execplan/execplan.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 2c8a9587..d3e2af69 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -1,6 +1,7 @@ -from ast import Call from typing import Callable, List, Optional import copy +import numpy as np + from cube.graph.adapter.adapter import IRAdapter from cube.graph.operator.operator import IRBpOperation, IRFwOperation @@ -149,7 +150,6 @@ def map2color(node): if isinstance(node, IRAdapter): return '#70AD47' # excel green - graph = self.graph for node in self.graph.nodes(): span, mem = map2time(node), map2mem(node) for device in node.device: @@ -179,11 +179,13 @@ def map2color(node): [tline[-1][1] for tline in device_timeline if len(tline) != 0] ) max_mem = max(device_peak_mem) + # max_mem = sum(device_peak_mem) # draw the timeline if outfile is not None: import matplotlib.pyplot as plt from matplotlib.patches import Rectangle + from matplotlib.ticker import AutoMinorLocator plt.close('all') plt.rcParams['figure.figsize'] = (4.0 * max_time // ndevice, 4.0) fig, ax = plt.subplots() @@ -191,16 +193,19 @@ def map2color(node): # xaxis ax.set_xlim((1, max_time)) - plt.xticks(list(range(1, int(max_time)+1, 1))) - ax.xaxis.grid(True, linestyle='--') + plt.xticks( + ticks=np.arange(1.5, max_time+0.5, 1.0, dtype=float), + labels=np.arange(1, max_time, 1, dtype=int) + ) + minor_locator = AutoMinorLocator(2) + plt.gca().xaxis.set_minor_locator(minor_locator) + ax.xaxis.grid(which='minor', linestyle='--') # yaxis ax.set_ylim((0.5, len(self.devices())+0.5)) plt.yticks(list(range(1, len(self.devices())+1, 1))) ax.invert_yaxis() - ax.set_aspect('equal') - - fontsize = 100 + fontsize = 40 txts = list() for devid in range(ndevice): timeline = device_timeline[devid] @@ -220,7 +225,7 @@ def map2color(node): txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') rbox = rec.get_window_extent(renderer) - for fs in range(40, 1, -2): + for fs in range(fontsize, 1, -2): txt.set_fontsize(fs) tbox = txt.get_window_extent(renderer) if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: From 8ef4c5f70aeff602b03abad340f66567b26f6eae Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 18 May 2022 16:58:25 +0800 Subject: [PATCH 0813/1892] fix dataloader bug --- cube/runtime/syndata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 73374165..be91c80b 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -160,8 +160,8 @@ def set_data_buffer(self, buffer_num = 4): datas.append(data) self.datas.append(datas) - def reset(self, batch_size: int): - super().reset(batch_size) + def set_batch_size(self, batch_size: int): + super().set_batch_size(batch_size) self.set_data_buffer() def __next__(self): From e3287a1ca3ec8c90b3801255237bd33bcd9303f6 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 19 May 2022 15:13:55 +0800 Subject: [PATCH 0814/1892] add value-evaluation for arithmetic ops --- cube/graph/operator/function/function.py | 89 +++++++++++++++++++++--- cube/graph/parser/mapping.py | 8 ++- cube/graph/parser/parser.py | 41 ++++++++--- 3 files changed, 115 insertions(+), 23 deletions(-) diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index ab7b8bf2..1e5f53d6 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -83,15 +83,28 @@ def BatchLinear(signature, inputs): def Add(signature, inputs): - assert len(inputs) == 3 - inputs, alpha = inputs[0:2], inputs[2] + if len(inputs) == 2: + kwargs = {} + elif len(inputs) == 3: + alpha = inputs[2] + kwargs = {'alpha': alpha} + inputs = inputs[0:2] + else: + raise RuntimeError("The number of inputs must be 2 or 3") + + lhs, rhs = inputs + + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + # In this case there won't be an 'alpha' parameter. + assert not('alpha' in kwargs) + return lhs + rhs + annos = [ '*, 1 -> *', '1, * -> *', '*, * -> *', ] # broadcast - lhs, rhs = inputs if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ len(lhs.shape) == len(rhs.shape): if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): @@ -107,19 +120,33 @@ def Add(signature, inputs): oshape[dim] = lshape[dim] rshape[dim] = str(rhs.shape[dim]) annos = [_create_anno([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, 'add', alpha=alpha) + return IREinops(signature, annos, inputs, 'add', **kwargs) def Sub(signature, inputs): - assert len(inputs) == 3 - inputs, alpha = inputs[0:2], inputs[2] + if len(inputs) == 2: + alpha = 1 + kwargs = {} + elif len(inputs) == 3: + alpha = inputs[2] + kwargs = {'alpha': alpha} + inputs = inputs[0:2] + else: + raise RuntimeError("The number of inputs must be 2 or 3") + + lhs, rhs = inputs + + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + # In this case there won't be an 'alpha' parameter. + assert not('alpha' in kwargs) + return lhs - rhs + annos = [ '*, 1 -> *', '1, * -> *', '*, * -> *', ] # broadcast - lhs, rhs = inputs if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ len(lhs.shape) == len(rhs.shape): if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): @@ -135,17 +162,21 @@ def Sub(signature, inputs): oshape[dim] = lshape[dim] rshape[dim] = str(rhs.shape[dim]) annos = [_create_anno([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, 'sub', alpha=alpha) + return IREinops(signature, annos, inputs, 'sub', **kwargs) def Mul(signature, inputs): + lhs, rhs = inputs + + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + return lhs * rhs + annos = [ '*, 1 -> *', '1, * -> *', '*, * -> *', ] # broadcast - lhs, rhs = inputs if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ len(lhs.shape) == len(rhs.shape): if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): @@ -165,13 +196,20 @@ def Mul(signature, inputs): def Div(signature, inputs): + lhs, rhs = inputs + + if isinstance(lhs, int) and isinstance(rhs, int): + # only if both operands are int, do we do floor division. + return lhs // rhs + elif isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + return lhs / rhs + annos = [ '*, 1 -> *', '1, * -> *', '*, * -> *', ] # broadcast - lhs, rhs = inputs if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ len(lhs.shape) == len(rhs.shape): if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): @@ -189,6 +227,37 @@ def Div(signature, inputs): annos = [_create_anno([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'div') + +def Pow(signature, inputs): + lhs, rhs = inputs + + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + return lhs ** rhs + + annos = [ + '*, 1 -> *', + '1, * -> *', + '*, * -> *', + ] + # broadcast + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ + len(lhs.shape) == len(rhs.shape): + if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): + # TODO: support spatial partitioning on broadcast dim + lshape = _create_eshape(lhs.shape) + rshape = copy.copy(lshape) + oshape = copy.copy(lshape) + for dim in range(len(lhs.shape)): + if lhs.shape[dim] < rhs.shape[dim]: + oshape[dim] = rshape[dim] + lshape[dim] = str(lhs.shape[dim]) + elif lhs.shape[dim] > rhs.shape[dim]: + oshape[dim] = lshape[dim] + rshape[dim] = str(rhs.shape[dim]) + annos = [_create_anno([lshape, rshape], [oshape])] + return IREinops(signature, annos, inputs, 'pow') + + def Neg(signature, inputs): annos = ['* -> *'] tensor = inputs[0:1] diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 85c7ffb6..eaede8bf 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -2,7 +2,7 @@ Mapping of Signature -> IROperator """ -from typing import Dict +from typing import Any, Callable, Dict, Union import torch from functools import partial @@ -15,7 +15,7 @@ class Sign2Op: @staticmethod - def map(signature: str) -> IRFwOperation: + def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: """ Map the signature to GenericLogicalOp """ @@ -27,7 +27,7 @@ def map(signature: str) -> IRFwOperation: # return partial(function.UnkownOperator, signature=signature) @staticmethod - def register(signature: str, op: IRFwOperation, code): + def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]], code): """ Register an operator """ @@ -80,6 +80,8 @@ def register(signature: str, op: IRFwOperation, code): __ttemplate('neg'): function.Neg, + __ttemplate('pow'): function.Pow, + __ttemplate('sin'): function.Sin, __ttemplate('cos'): function.Cos, diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index aa739e1e..14d09052 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -265,6 +265,15 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: output: List[int] = list(tensor.shape) frame.add_var(outputs[0].debugName(), output) return [] + + # aten::__getitem__.t(t[](a) list, int idx) -> t(*)" + # REMARK List-type only. '__getitem__' cannot serve as accessor to tensor element. + elif fsig == 'torch.__getitem__': + # NOTE there are other overloadings of '__getitem__' for 'str'(i.e. char list), 'Dict(t)' in TorchScript + container, index = input_val + frame.add_var(outputs[0].debugName(), container[index]) + return [] + # aten::tensor(elems: List^{n:Nat}[T], dtype:Optional[ScalarType], device:Device, requires_grad:bool) -> Tensor elif fsig == 'torch.tensor': # originally 'aten::tensor' @@ -279,18 +288,30 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: frame.add_var(var_name, ir_tensor) return [] - # create IR node - ir_node = Sign2Op.map(fsig)(inputs=input_val) - if len(ir_node.outputs()) != len(outputs): - raise RuntimeError( - f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" - ) + # May be a symbolic object i.e. IRFwOperation, + # or, occasionally this node can be statically evaluated, therefore a concrete value + result = Sign2Op.map(fsig)(inputs=input_val) - # handle outputs - for index, output in enumerate(outputs): - frame.add_var(output.debugName(), ir_node.outputs(index)) + if isinstance(result, IRFwOperation): + # to create IR node - return [ir_node] + ir_node = result + if len(ir_node.outputs()) != len(outputs): + raise RuntimeError( + f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" + ) + + # handle outputs + for index, output in enumerate(outputs): + frame.add_var(output.debugName(), ir_node.outputs(index)) + + return [ir_node] + + else: + # concrete value. + assert len(outputs) == 1, "Cases with multiple outputs are only List/Tuple-Unpack and handled specially" + frame.add_var(outputs[0].debugName(), result) + return [] @staticmethod def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: From bb0cfbb6c40871eaab959ec2230fc91bc5fa829e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 19 May 2022 21:15:18 +0800 Subject: [PATCH 0815/1892] add var reuse --- cube/codegen/codegen.py | 44 ++++++++++++++++++++++++++----- cube/codegen/register.py | 57 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 7 deletions(-) create mode 100644 cube/codegen/register.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index c5c11450..b7f93df9 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,7 +1,7 @@ """ Generate Pytorch code given the model DAG and the transformation config """ -from typing import Dict, List, Any, Tuple +from typing import Dict, List, Any, Tuple, Union import torch import copy from cube.graph.parser.mapping import Sign2Op @@ -19,6 +19,7 @@ from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock +from cube.codegen.register import VarManager class CodeGen: @@ -377,12 +378,14 @@ def __init__(self, execplan: ExectuionPlan): 'import torch', 'import cube', ''] # module member name self.symbols = SymbolTable() + self.vars = VarManager() def gen(self, device: int, outfile=None, attach=False) -> str: """ Generate scheduling code based on the given sus """ gencode = copy.copy(self.init_code) + self.vars = VarManager() device_nodes = self.execplan.sequence(device) for idx, node in enumerate(device_nodes): @@ -390,6 +393,28 @@ def gen(self, device: int, outfile=None, attach=False) -> str: node = node.dispatch(rank=device) device_nodes[idx] = node + def refcount(tensor, node) -> int: + idx = device_nodes.index(node) + refcnt = 0 + for ref_node in device_nodes[idx+1:]: + if isinstance(ref_node, IRGraph): + if all([isinstance(rnode, IRFwOperation) for rnode in ref_node.nodes()]): + if tensor in ref_node.inputs(): + refcnt += 1 + else: + finputs = ref_node.mirror.inputs() + foutputs = ref_node.mirror.outputs() + grad_in = [t.grad for t in foutputs] + if tensor in finputs + foutputs + grad_in: + refcnt += 1 + else: + if tensor in ref_node.inputs(): + refcnt += 1 + return refcnt + + for node in device_nodes: + print(f'dev{device}: {node}') + # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: @@ -400,6 +425,12 @@ def gen(self, device: int, outfile=None, attach=False) -> str: name = self.node_naming(node) code = self.emit_node(node, name=name) fb.insert_body(code) + # free unused tensor + for tensor in node.inputs() + node.outputs(): + if isinstance(tensor, IRSubTensor) and not tensor.is_param(): + refcnt = refcount(tensor, node) + if refcnt == 0: + self.vars.free(tensor) # return code outputs = self.return_naming(self.execplan.graph.outputs()) code = f'return {outputs}' @@ -483,14 +514,13 @@ def return_naming(self, tensors: List[Any]) -> str: tensors = ', '.join(tensors) return tensors - def tensor_naming(self, tensor: Any): + def tensor_naming(self, tensor: Union[IRSubTensor, Any]): """ Generate tensor name. Will add prefix 'model.' for parameters """ - name = super().tensor_naming(tensor) - if isinstance(tensor, IRSubTensor): - if tensor.is_param(): - name = 'model.' + name - return name + if isinstance(tensor, IRSubTensor) and tensor.is_param(): + return 'model.' + self.vars.allocate(tensor) + else: + return self.vars.allocate(tensor) diff --git a/cube/codegen/register.py b/cube/codegen/register.py new file mode 100644 index 00000000..ba812b5c --- /dev/null +++ b/cube/codegen/register.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, List, Union + +from cube.ir.cten import IRTensor + + +class VarManager: + """ + Tensor naming reuse engine for saving memory + """ + + def __init__(self): + # the unique id + self.nid = 0 + self.slots: List[int] = list() + # original tensor id -> new tensor id mapping + self.tmap: Dict[int, int] = dict() + + def free(self, tensor: Union[IRTensor, Any]): + """ + Free a tensor + """ + if isinstance(tensor, IRTensor): + assert tensor._id in self.tmap, f"Double free on tensor {tensor}" + reg = self.tmap[tensor._id] + del self.tmap[tensor._id] + self.slots.append(reg) + + def allocate(self, tensor: Union[IRTensor, Any]) -> str: + """ + Allocate a tensor name for the tensor. + New tensors will be allocated by available + unique ids freed by other tensor. + Existing teensor will get the allocated name. + """ + if isinstance(tensor, IRTensor): + ttype = 'g' if tensor.is_grad() else 't' + # param is graph attribute, don't need allocation + if tensor.is_param(): + return f'{tensor.name}_{tensor._id}' + if tensor._id in self.tmap: + # fetch the original one + reg = self.tmap[tensor._id] + return f'{ttype}{reg}' + else: + # allocate a new one + if len(self.slots) == 0: + reg = self.nid + self.nid += 1 + else: + reg = self.slots.pop(-1) + self.tmap[tensor._id] = reg + return f'{ttype}{reg}' + else: + return str(tensor) + + + From 83cc8df45332fa9eea8f8f4a8eeed2f6742b4ed8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 13:46:11 +0800 Subject: [PATCH 0816/1892] full tensor records operator in execution order --- cube/graph/graph.py | 83 ++++++++++++++++++++++++++++++-------------- cube/graph/tensor.py | 46 +++++++++++++++--------- 2 files changed, 86 insertions(+), 43 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index d544d08f..22e10eeb 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -13,7 +13,7 @@ from cube.ir.cten import IRTensor, IRCell from cube.graph.operator.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.graph.adapter.adapter import IRAdapter -from cube.graph.tensor import IRSubTensor +from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.algorithm.generics import GenericDistAlgo @@ -36,6 +36,7 @@ def __init__(self, self._nodes: List[IRCell] = list() self._parameters = list() + self._full_tensors: Dict[int, IRFullTensor] = dict() if inputs is None: inputs = IRGraph.get_inputs(nodes) @@ -54,15 +55,22 @@ def __init__(self, for idx, tensor in enumerate(outputs): self.set_output(idx, tensor) + # set parameters and full tensors + for node in nodes: + for tensor in node.inputs() + node.outputs(): + if isinstance(tensor, IRSubTensor): + pid = tensor.parent._id + self._full_tensors[pid] = tensor.parent + if tensor.is_param(): + self._parameters.append(input) + + for ftensor in self._full_tensors.values(): + ftensor.clear_producer_consumer() + # insert node from nodes for idx, node in enumerate(nodes): self.attach(node, idx) - # set parameter - for node in self._nodes: - for input in node.inputs(): - if isinstance(input, IRTensor) and input.is_param(): - self._parameters.append(input) self.reset_dependency() def reset_dependency(self): @@ -105,6 +113,12 @@ def parameters(self): """ return copy.copy(self._parameters) + def full_tensors(self): + """ + Return full tensor list + """ + return list(self._full_tensors.values()) + def nodes(self, index: Optional[int] = None): """ Get node at position index @@ -191,20 +205,14 @@ def detach(self, node: IRCell, reset_dependency=False) -> int: """ if node not in self.nodes(): raise KeyError(f"node {node} is not in graph.") - ops = node.nodes() if isinstance(node, IRGraph) else [node] - for op in ops: - removed = list() - for input in op.inputs(): - if isinstance(input, IRSubTensor) and input not in removed: - input.parent.rm_consumer(op) - removed.append(input) - removed = list() - for output in op.outputs(): - if isinstance(output, IRSubTensor) and output not in removed: - output.parent.rm_producer(op) - removed.append(output) index = self._nodes.index(node) self._nodes.pop(index) + for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor): + itensor.parent.rm_consumer(node) + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor): + otensor.parent.rm_producer(node) if reset_dependency: self.reset_dependency() return index @@ -218,15 +226,27 @@ def attach(self, node: IRCell, index, reset_dependency=False): """ if node in self.nodes(): raise KeyError(f"node {node} is already in graph.") - ops = node.nodes() if isinstance(node, IRGraph) else [node] - for op in ops: - for input in op.inputs(): - if isinstance(input, IRSubTensor): - input.parent.add_consumer(op, input) - for output in op.outputs(): - if isinstance(output, IRSubTensor): - output.parent.add_producer(op, output) self._nodes.insert(index, node) + # update consumer + for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor): + idx = 0 + for consumer in itensor.parent.consumers: + if self.nodes().index(consumer) < index: + idx += 1 + else: + break + itensor.parent.add_consumer(node, itensor, idx) + # update producer + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor): + idx = 0 + for producer in otensor.parent.producers: + if self.nodes().index(producer) < index: + idx += 1 + else: + break + otensor.parent.add_producer(node, otensor, idx) if reset_dependency: self.reset_dependency() return @@ -624,4 +644,13 @@ def extra_repr(self): return dscp def module_repr(self): - return repr(self) \ No newline at end of file + return repr(self) + + +class IRSegment(IRCell): + """ + A segment refers to a piece of workload of IRGraph + """ + + def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRTensor]): + pass diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index b0ac5c83..34056c8d 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -27,9 +27,6 @@ import cube.ir as ir -__all__ = ['IndexMap', 'ValueMap', 'IRFullTensor', 'IRSubTensor'] - - class IndexMap: def __init__(self, indmap): @@ -338,7 +335,7 @@ def __repr__(self): return f'({self.idx}/{self.chunk_num})' -def _to_indmap(indmap: Union[Tuple, IndexMap]): +def _to_indmap(indmap: Union[Tuple, IndexMap]) -> IndexMap: if not isinstance(indmap, tuple) and not isinstance(indmap, IndexMap): raise TypeError("Expected indmap to be tuple or IndexMap") if isinstance(indmap, tuple): @@ -346,7 +343,7 @@ def _to_indmap(indmap: Union[Tuple, IndexMap]): return indmap -def _to_value_map(valmap: Union[Tuple, ValueMap, None]): +def _to_value_map(valmap: Union[Tuple, ValueMap, None]) -> ValueMap: if not isinstance(valmap, tuple) and \ not isinstance(valmap, ValueMap) and \ not valmap is None: @@ -361,6 +358,13 @@ def _to_value_map(valmap: Union[Tuple, ValueMap, None]): class IRFullTensor(IRTensor): + """ + Full (logic) Tensor intermeidate representation. + + It records its Sub (physical) Tensors with corresponding + producer operators and consumer operators following + the sequentail execution order by its graph. + """ def __init__(self, shape=None, name=None, requires_grad=True, dtype=ir.float32): @@ -419,33 +423,41 @@ def ctensors(self): """ return self._ctensors - def add_producer(self, cell: IRCell, tensor: IRTensor): + def add_producer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): raise TypeError("Expect an IRCell and an IRTensor") - if cell not in self.producers: - self.producers.append(cell) - self.ptensors.append(tensor) + assert cell not in self._producers, f"{cell} already exists as producer" + self._producers.insert(idx, cell) + self._ptensors.insert(idx, tensor) - def add_consumer(self, cell: IRCell, tensor: IRTensor): + def add_consumer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): raise TypeError("Expect an IRCell and an IRTensor") - if cell not in self.consumers: - self.consumers.append(cell) - self.ctensors.append(tensor) + assert cell not in self._consumers, f"{cell} already exists as consumer" + self._consumers.insert(idx, cell) + self._ctensors.insert(idx, tensor) - def rm_producer(self, cell: IRCell): + def rm_producer(self, cell: IRCell) -> int: if cell not in self.producers: raise KeyError(f"Cell {cell} not found in producer") idx = self.producers.index(cell) self.producers.pop(idx) self.ptensors.pop(idx) + return idx - def rm_consumer(self, cell: IRCell): + def rm_consumer(self, cell: IRCell) -> int: if cell not in self.consumers: raise KeyError(f"Cell {cell} not found in producer") idx = self.consumers.index(cell) self.consumers.pop(idx) self.ctensors.pop(idx) + return idx + + def clear_producer_consumer(self) -> int: + self._producers = [] + self._ptensors = [] + self._consumers = [] + self._ctensors = [] def subtensors(self): """ @@ -568,7 +580,9 @@ def __repr__(self): class IRSubTensor(IRTensor): - def __init__(self, full_tensor: IRTensor, indmap, valmap: Optional[ValueMap] =None, shape=None): + def __init__(self, full_tensor: IRTensor, + indmap: List[Union[Tuple, IndexMap]], + valmap: Optional[ValueMap] = None, shape=None): """ Create an IRSubTensor. From 36ca38eca0cfa42b20977d24965fdf0068a77850 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 13:47:15 +0800 Subject: [PATCH 0817/1892] update with differentiable collectives --- cube/runtime/adapter/collectives.py | 1 - cube/runtime/adapter/distnn.py | 195 +++++++++++++++++++++++++--- 2 files changed, 175 insertions(+), 21 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 647ffcd3..5f829138 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -1,5 +1,4 @@ from typing import List -from unittest import defaultTestLoader import torch from cube.runtime.device import DeviceGroup diff --git a/cube/runtime/adapter/distnn.py b/cube/runtime/adapter/distnn.py index d41ddb30..c15649d2 100644 --- a/cube/runtime/adapter/distnn.py +++ b/cube/runtime/adapter/distnn.py @@ -1,18 +1,98 @@ +from typing import List import torch + from cube.profiler.timer import CudaTimer +from cube.runtime.device import DeviceGroup + + +class SendRecv(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dst: int, group): + CudaTimer().start(field_name='comm') + ctx._tsize = input_.size() + ctx._tdtype = input_.dtype + ctx._src = dst + if not input_.is_contiguous(): + input_ = input_.contiguous() + sendop = torch.distributed.P2POp( + torch.distributed.isend, input_, dst + ) + reqs = torch.distributed.batch_isend_irecv([sendop]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, _grad: torch.Tensor): + CudaTimer().start(field_name='comm') + size = ctx._tsize + dtype = ctx._tdtype + src = ctx._src + grad = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) + recvop = torch.distributed.P2POp( + torch.distributed.irecv, grad, src + ) + reqs = torch.distributed.batch_isend_irecv([recvop]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return grad, None, None + + +class RecvSend(torch.autograd.Function): + + @staticmethod + def forward(ctx, size, dtype, src: int, ranks: List[int]): + CudaTimer().start(field_name='comm') + ctx._tsize = size + ctx._tdtype = dtype + ctx._dst = src + input_ = torch.empty( + size, dtype=dtype, device=torch.cuda.current_device(), + requires_grad=True) + recvop = torch.distributed.P2POp( + torch.distributed.irecv, input_, src + ) + reqs = torch.distributed.batch_isend_irecv([recvop]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad: torch.Tensor): + CudaTimer().start(field_name='comm') + dst = ctx._dst + if not grad.is_contiguous(): + grad = grad.contiguous() + sendop = torch.distributed.P2POp( + torch.distributed.isend, grad, dst + ) + reqs = torch.distributed.batch_isend_irecv([sendop]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return None, None, None, None class AllReduceIdentity(torch.autograd.Function): @staticmethod - def forward(ctx, input, group): + def forward(ctx, input_, ranks: List[int]): + group = DeviceGroup().get_group(ranks) world_size = torch.distributed.get_world_size(group) if world_size == 1: - return input + return input_ CudaTimer().start(field_name='comm') - torch.distributed.all_reduce(input, group=group) + torch.distributed.all_reduce(input_, group=group) CudaTimer().stop(field_name='comm') - return input + return input_ @staticmethod def backward(ctx, grad_output): @@ -22,10 +102,11 @@ def backward(ctx, grad_output): class IdentityAllreduce(torch.autograd.Function): @staticmethod - def forward(ctx, input, group): + def forward(ctx, input_, ranks: List[int]): + group = DeviceGroup().get_group(ranks) ctx._group = group - return input - + return input_ + @staticmethod def backward(ctx, grad_output): world_size = torch.distributed.get_world_size(ctx._group) @@ -37,20 +118,58 @@ def backward(ctx, grad_output): return grad_output, None +class ReduceScatterAllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dim: int, ranks: List[int]): + group = DeviceGroup().get_group(ranks) + ctx._group = group + ctx._dim = dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_, + CudaTimer().start(field_name='comm') + input_tensors = input_.chunk(world_size, dim) + rank = torch.distributed.get_rank(group) + input_ = torch.empty_like(input_tensors[rank], requires_grad=True) + torch.distributed.reduce_scatter( + input_, input_tensors, group=group + ) + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output): + group = ctx._group + dim = ctx._dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output + CudaTimer().start(field_name='comm') + rank = torch.distributed.get_rank(group) + tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] + tensor_list[rank] = grad_output + torch.distributed.all_gather(tensor_list, grad_output, group=group) + grad = torch.cat(tensor_list, dim=dim).contiguous() + CudaTimer().stop(field_name='comm') + return grad, None, None + + class AllGatherSplit(torch.autograd.Function): @staticmethod - def forward(ctx, input, dim, group): + def forward(ctx, input_: torch.Tensor, dim: int, ranks: List[int]): + group = DeviceGroup().get_group(ranks) ctx._group = group ctx._dim = dim world_size = torch.distributed.get_world_size(group) if world_size == 1: - return input + return input_ CudaTimer().start(field_name='comm') rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(input) for _ in range(world_size)] - tensor_list[rank] = input - torch.distributed.all_gather(tensor_list, input, group=group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) output = torch.cat(tensor_list, dim=dim).contiguous() CudaTimer().stop(field_name='comm') return output @@ -70,20 +189,55 @@ def backward(ctx, grad_output: torch.Tensor): return grad, None, None +class SplitAllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dim: int, ranks: List[int]): + group = DeviceGroup().get_group(ranks) + ctx._group = group + ctx._dim = dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + CudaTimer().start(field_name='comm') + input_list = input_.chunk(world_size, dim=dim) + rank = torch.distributed.get_rank(group) + input_ = input_list[rank].contiguous() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + group = ctx._group + dim = ctx._dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output + CudaTimer().start(field_name='comm') + rank = torch.distributed.get_rank(group) + tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] + tensor_list[rank] = grad_output + torch.distributed.all_gather(tensor_list, grad_output, group=group) + grad = torch.cat(tensor_list, dim=dim).contiguous() + CudaTimer().stop(field_name='comm') + return grad, None, None + + class ReduceBroadcast(torch.autograd.Function): @staticmethod - def forward(ctx, input, dst: int, group): + def forward(ctx, input_: torch.Tensor, dst: int, ranks: List[int]): + group = DeviceGroup().get_group(ranks) ctx._dst = dst ctx._group = group world_size = torch.distributed.get_world_size(group) if world_size == 1: - return input + return input_ CudaTimer().start(field_name='comm') - torch.distributed.reduce(input, dst, group=group) + torch.distributed.reduce(input_, dst, group=group) torch.cuda.synchronize() CudaTimer().stop(field_name='comm') - return input + return input_ @staticmethod def backward(ctx, grad_output): @@ -102,17 +256,18 @@ def backward(ctx, grad_output): class BroadcastReduce(torch.autograd.Function): @staticmethod - def forward(ctx, input, src: int, group=None): + def forward(ctx, input_: torch.Tensor, src: int, ranks: List[int]): + group = DeviceGroup().get_group(ranks) ctx._src = src ctx._group = group world_size = torch.distributed.get_world_size(group) if world_size == 1: - return input + return input_ CudaTimer().start(field_name='comm') - torch.distributed.broadcast(input, src, group=group) + torch.distributed.broadcast(input_, src, group=group) torch.cuda.synchronize() CudaTimer().stop(field_name='comm') - return input + return input_ @staticmethod def backward(ctx, grad_output): From 8385ddfe2e267d78a03d961c7efa8af4ff0a7824 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 13:47:50 +0800 Subject: [PATCH 0818/1892] add primitive for adapter --- cube/graph/adapter/prim.py | 339 +++++++++++++++++++++++++++++++++++++ 1 file changed, 339 insertions(+) create mode 100644 cube/graph/adapter/prim.py diff --git a/cube/graph/adapter/prim.py b/cube/graph/adapter/prim.py new file mode 100644 index 00000000..3b9bf112 --- /dev/null +++ b/cube/graph/adapter/prim.py @@ -0,0 +1,339 @@ +""" +The primitive used for IRAdapter +""" + +from typing import Callable, List, Optional, Union +import copy + +from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap + +# the general adapter primitive class +class IRAdapterPrim: + + def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): + self._inputs = inputs + self._outputs = outputs + self._device = [] + self.kwargs = dict() + + def inputs(self, idx: Optional[int] = None): + assert idx is None or isinstance(idx, int), "expected idx to be None or int" + if idx is None: + return copy.copy(self._inputs) + else: + return self._inputs[idx] + + def outputs(self, idx: Optional[int] = None): + assert idx is None or isinstance(idx, int), "expected idx to be None or int" + if idx is None: + return copy.copy(self._outputs) + else: + return self._outputs[idx] + + def dispatch(self, devid: int): + return self + + @property + def device(self) -> List[int]: + return copy.copy(self._device) + + @device.setter + def device(self, devs: Union[int, List[int]]): + if isinstance(devs, int): + devs = [devs] + self._device = devs + +# spatial abstract primitive +class SpatialPrim(IRAdapterPrim): + """ + basic class for representing spatial primitives + """ + def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): + super().__init__(inputs, outputs) + + +# numerical abstract primitive +class ValuePrim(IRAdapterPrim): + """ + basic class for representing numerical primitives + """ + def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): + super().__init__(inputs, outputs) + +# communication abstract primitive +class CommPrim(IRAdapterPrim): + """ + communication primitive + """ + def __init__(self, + itensors: List[IRSubTensor], + otensors: List[IRSubTensor]): + super().__init__(itensors, otensors) + devices = [] + for t in itensors + otensors: + devices += t.device + self.device = list(set(devices)) + + def dispatch(self, devid: int): + """ + dispatch to a given device + """ + raise NotImplementedError + + def __repr__(self) -> str: + dscp = f'{self.outputs()} = {self.signature}({self.inputs()})' + return dscp + +# ====================================================== + +class SelectPrim(SpatialPrim): + + def __init__(self, + itensor: IRSubTensor, + indmap: IndexMap, valmap: ValueMap, + otensor: IRSubTensor): + super().__init__([itensor], [otensor]) + self.indmap = indmap + self.valmap = valmap + self.device = itensor.device + + def __repr__(self): + dscp = f'{self.outputs(0)} = select({self.inputs(0)})' + return dscp + + +class SplitDimPrim(SpatialPrim): + """ + split dimension + """ + def __init__(self, itensor: IRSubTensor, dim: int, + otensors: List[IRSubTensor]): + super().__init__([itensor], otensors) + self.dim = dim + self.device = itensor.device + + +class MergeDimPrim(SpatialPrim): + """ + concatenate dimension + """ + def __init__(self, itensors: List[IRSubTensor], dim: int, + otensor: IRSubTensor) -> None: + assert all(itensor.device == itensors[0].device for itensor in itensors), "device not same" + super().__init__(itensors, [otensor]) + self.dim = dim + self.device = itensors[0].device + +# numerical primitive + +class ReducePrim(ValuePrim): + + def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): + assert all(itensor.device == itensors[0].device for itensor in itensors), "device not same" + super().__init__(itensors, [otensor]) + self.reduce = '+' + self.device = itensors[0].device + +# communication primitive + +class SendPrim(CommPrim): + """ + P2P send prim + """ + def __init__(self, tensor, dst: int): + super().__init__([tensor], [tensor]) + self.kwargs['dst'] = dst + + def dispatch(self, devid: int): + assert devid == self.device[0], f"device {devid} not applied for this comm primitive" + return SendPrim(self.inputs(0), self.kwargs['dst']) + + def __repr__(self) -> str: + return f"{self.inputs(0)} = send({self.inputs(0)}, dst={self.kwargs['dst']}" + + +class RecvPrim(CommPrim): + """ + P2P recv prim + """ + def __init__(self, tensor, src: int): + super().__init__([], [tensor]) + self.kwargs['src'] = src + self.kwargs['shape'] = tensor.shape + self.kwargs['dtype'] = tensor.dtype + + def dispatch(self, devid: int): + assert devid == self.device[0], f"device {devid} not applied for this comm primitive" + return RecvPrim(self.outputs(0), self.kwargs['src']) + + def __repr__(self) -> str: + return f"{self.outputs(0)} = recv(shape={self.kwargs['shape']}, dtype={self.kwargs['dtype']}, dst={self.kwargs['dst']}" + + +class MovePrim(CommPrim): + """ + P2P send/recv, non-differentiable + """ + def __init__(self, tensor: IRSubTensor, src: int, dst: int): + super().__init__([tensor], [tensor]) + self.kwargs['src'] = src + self.kwargs['dst'] = dst + + def dispatch(self, devid: int) -> Union[SendPrim, RecvPrim]: + if devid == self.kwargs['src']: + return SendPrim(self.inputs(0), self.kwargs['devid']) + if devid == self.kwargs['dst']: + return RecvPrim(self.inputs(0), self.kwargs['src']) + raise ValueError(f"device {devid} is not src ({self.kwargs['src']}) or ({self.kwargs['dst']})") + + def __repr__(self): + dscp = f'move({self.inputs(0)}, from={self.src}, to={self.dst})' + return dscp + + +class CollectivePrim(CommPrim): + """ + Collective primitive, non-differentiable + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + super().__init__(itensors, otensors) + self.kwargs['ranks'] = self.device + for arg, val in kwargs.items(): + self.kwargs[arg] = val + + def dispatch(self, devid: int, init_method: Callable): + """ + dispatch to a given device + """ + assert devid in self.device, f"device {devid} not applied for this comm primitive" + itensors = [itensor for itensor in self.inputs() if devid in itensor.device] + otensors = [otensor for otensor in self.outputs() if devid in otensor.device] + prim = init_method(itensors, otensors, **self.kwargs) + return prim + + +class AllReducePrim(CollectivePrim): + """ + non-differentiable allreduce + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): + super().__init__(itensors, otensors) + + +class AllGatherPrim(CollectivePrim): + """ + non-differentiabl all-to-all + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): + super().__init__(itensors, otensors) + + +class ReduceScatterPrim(CollectivePrim): + """ + non-differential reduce-scatter + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): + super().__init__(itensors, otensors) + + +class BroadcastPrim(CollectivePrim): + """ + non-differential reduce-scatter + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], src: int): + super().__init__(itensors, otensors, src=src) + + +class ReducePrim(CollectivePrim): + """ + non-differential reduce prim + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dst: int): + super().__init__(itensors, otensors, dst=dst) + + +class AllToAllPrim(CollectivePrim): + """ + non-differentiable all-to-all + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idim: int, odim: int): + """ + itensors: each rank hosts one tensor splitted by idim + otensors: each rank hosts one tensor splitted by odim + idim != odim + """ + super().__init__(itensors, otensors, idim=idim, odim=odim) + + +class DiffCollectivePrim(CollectivePrim): + """ + Differentiable collective primitive + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + """ + differentiable collectives + """ + super().__init__(itensors, otensors, **kwargs) + + +class AllReduceIdentityPrim(DiffCollectivePrim): + """ + forward: allreduce. + backward: identity + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): + super().__init__(itensors, otensors) + + +class IdentityAllreducePrim(DiffCollectivePrim): + """ + forward: identity + backward: allreduce + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): + super().__init__(itensors, otensors) + + +class ReduceScatterAllGatherPrim(DiffCollectivePrim): + """ + forward: reduce-scatter + backward: all-gather + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): + super().__init__(itensors, otensors, dim=dim) + + +class AllGatherSplitPrim(DiffCollectivePrim): + """ + forward: all-gather + backward: split + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): + super().__init__(itensors, otensors, dim=dim) + + +class SplitAllGatherPrim(DiffCollectivePrim): + """ + forward: split + backward: all-gather + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): + super().__init__(itensors, otensors, dim=dim) + + +class ReduceBroadcastPrim(DiffCollectivePrim): + """ + forward: broadcast + backward: reduce + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dst: int): + super().__init__(itensors, otensors, dst=dst) + + +class BroadcastRedducePrim(DiffCollectivePrim): + """ + forward: broadcast + backward: reduce + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], src: int): + super().__init__(itensors, otensors, src=src) From a3e0ece329a654e13bdb123964d3b8fedd729699 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 13:49:35 +0800 Subject: [PATCH 0819/1892] update nlp blocks and models --- examples/nlp/blocks/attention.py | 50 ++++++++++++++------------------ examples/nlp/blocks/decoder.py | 13 ++++++--- examples/nlp/blocks/encoder.py | 9 ++++-- examples/nlp/blocks/mlp.py | 5 ++-- examples/nlp/gpt/model.py | 10 ++++--- 5 files changed, 45 insertions(+), 42 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 656f80fb..4798ed45 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -8,7 +8,7 @@ def self_attention(query: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, out_proj: torch.Tensor, out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, mask=True): + h: int, scale: float, dropout_p: float, mask: bool = True): num_head = h L, N = query.size(0), query.size(1) dim_head = q_proj.size(0) // num_head @@ -90,28 +90,24 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, class MultiHeadSelfAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias=True): + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): super().__init__() - self.kdim = embed_dim - self.vdim = embed_dim + self.inner_dim = inner_dim self.num_heads = num_heads - self.head_dim = embed_dim // num_heads + self.head_dim = inner_dim // num_heads self.scaling = self.head_dim ** -0.5 self.dropout_p = dropout # Q - self.q_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) - if bias: - self.q_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.q_bias = None + self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None # K - self.k_proj = torch.nn.Parameter(torch.empty(embed_dim, self.kdim)) - self.k_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None # V - self.v_proj = torch.nn.Parameter(torch.empty(embed_dim, self.vdim)) - self.v_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None def forward(self, query): @@ -127,28 +123,24 @@ def forward(self, query): class MultiHeadCrossAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias=True): + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): super().__init__() - self.kdim = embed_dim - self.vdim = embed_dim + self.inner_dim = inner_dim self.num_heads = num_heads - self.head_dim = embed_dim // num_heads + self.head_dim = inner_dim // num_heads self.scaling = self.head_dim ** -0.5 self.dropout_p = dropout # Q - self.q_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) - if bias: - self.q_bias = torch.nn.Parameter(torch.empty(embed_dim)) - else: - self.q_bias = None + self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None # K - self.k_proj = torch.nn.Parameter(torch.empty(embed_dim, self.kdim)) - self.k_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None # V - self.v_proj = torch.nn.Parameter(torch.empty(embed_dim, self.vdim)) - self.v_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, embed_dim)) + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None def forward(self, query: torch.Tensor, key: torch.Tensor): diff --git a/examples/nlp/blocks/decoder.py b/examples/nlp/blocks/decoder.py index ee5f5767..ea4b57de 100644 --- a/examples/nlp/blocks/decoder.py +++ b/examples/nlp/blocks/decoder.py @@ -5,18 +5,23 @@ class DecoderLayer(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, ffn_embed_dim: int, + def __init__(self, embed_dim: int, num_heads: int, + attn_hidden_dim: int, ffn_hidden_dim: int, dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): super().__init__() self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, atten_dropout) + self.self_attn = MultiHeadSelfAttention( + embed_dim, num_heads, attn_hidden_dim, atten_dropout + ) self.cross_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.cross_attn = MultiHeadCrossAttention(embed_dim, num_heads, atten_dropout) + self.cross_attn = MultiHeadCrossAttention( + embed_dim, num_heads, attn_hidden_dim, atten_dropout + ) self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_embed_dim, activation_dropout) + self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) self.final_layer_norm = torch.nn.LayerNorm(embed_dim) def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor: diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index ce22d6e9..83645f61 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -5,13 +5,16 @@ class EncoderLayer(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, ffn_embed_dim: int, + def __init__(self, embed_dim: int, num_heads: int, + attn_hidden_dim: int, ffn_hidden_dim: int, dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): super().__init__() - self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, atten_dropout) + self.self_attn = MultiHeadSelfAttention( + embed_dim, num_heads, attn_hidden_dim, atten_dropout + ) self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_embed_dim, activation_dropout) + self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) self.final_layer_norm = torch.nn.LayerNorm(embed_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py index 28a7e441..1f9472ea 100644 --- a/examples/nlp/blocks/mlp.py +++ b/examples/nlp/blocks/mlp.py @@ -16,8 +16,9 @@ def feedforward(x: torch.Tensor, class MLP(torch.nn.Module): - def __init__(self, embed_dim, hidden_dim, dropout: float, bias=True): + def __init__(self, embed_dim: int, hidden_dim: int, dropout: float, bias=True): super().__init__() + print((hidden_dim, embed_dim)) self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) @@ -29,4 +30,4 @@ def forward(self, x: torch.Tensor): self.proj1, self.proj1_bias, self.proj2, self.proj2_bias, self.dropout) - return x \ No newline at end of file + return x diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index a4de26be..73e24739 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -1,5 +1,5 @@ import torch -import math + from examples.nlp.blocks.encoder import EncoderLayer @@ -13,7 +13,7 @@ class Config: # 1.7B model embed_dim = 2304 - layers = 24 + layers = 8 # 24 attention_heads = 24 # 3.6B model @@ -26,7 +26,8 @@ class Config: # layers = 32 # attention_heads = 36 - ffn_embed_dim = embed_dim * 4 + attn_hidden_dim = embed_dim + ffn_hidden_dim = embed_dim * 4 dropout = 0.0 attn_dropout = 0.0 activation_dropout = 0.0 @@ -44,7 +45,8 @@ def __init__(self): self.layers = torch.nn.ModuleList( [EncoderLayer( - cfg.embed_dim, cfg.attention_heads, cfg.ffn_embed_dim, + cfg.embed_dim, cfg.attention_heads, + cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.dropout, cfg.attn_dropout, cfg.activation_dropout ) for _ in range(cfg.layers)] ) From 602af92bb2333550e608cd83a1712b85923de0fb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 13:51:07 +0800 Subject: [PATCH 0820/1892] update running command --- examples/mlp/linears.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index d3b2e024..2039085f 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -1,16 +1,7 @@ """ example: -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/mlp/linears.py - -OMP_NUM_THREADS=4 torchrun --standalone \ +OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ examples/mlp/linears.py @@ -32,9 +23,6 @@ from cube.profiler.timer import print_each_rank from examples.mlp.policy.optimal import PAS -# from examples.mlp.policy.col_parallel import P, A, S -# PAS = (P, A, S) - # =================== Semantic Model Description ==================== class MLP(nn.Module): @@ -88,21 +76,22 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() torch.distributed.barrier() - iter_num = 128 + iter_num = 64 + warmup = 20 for step in range(iter_num): - if step >= 40: + if step >= warmup: CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() - if step >= 40: + if step >= warmup: CudaTimer().stop('e2e') if (step + 1) % 20 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-40) + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) if __name__ == '__main__': From d1bdb8fdacc35244a3d33ca558cae1b3ad22d928 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 13:59:56 +0800 Subject: [PATCH 0821/1892] clean code and update readme --- README.md | 17 ++++--- cube/__init__.py | 32 ++++++++++++ cube/execplan/planpass/grouping.py | 2 - cube/runtime/executor.py | 82 ++++++++++++++++++++++++------ cube/search/iterator.py | 19 +++++++ cube/search/sampler.py | 52 +++++++++++++------ examples/mlp/policy/st_search.py | 3 +- 7 files changed, 165 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 54faa364..fc7a2347 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,11 @@ # MagicCube -AI System Compiler to compile a semantic (single-device) model to distributed model using policies specified by System Expert. +AI System Compiler to map a semantic (single-device) model into distributed execution using policies specified by System Expert. + +## Prerequisite + +* Python >= 3.7 +* PyTorch >= 1.9 ## Install @@ -14,12 +19,8 @@ python setup.py develop * [Micro Benchmark] Run a mutiple MLP Model ```sh -python -m torch.distributed.launch \ - --nproc_per_node=2 \ +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/linears.py + examples/mlp/linears.pys ``` diff --git a/cube/__init__.py b/cube/__init__.py index 2effa9f0..534f909f 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -7,3 +7,35 @@ def init(): _ = runtime.device.DeviceGroup() _ = runtime.resource.EnvResource() + + + +# ================== Experimental Feature ======================= + +# import threading + +# _message_context = None + +# def handle_request(): +# manager = runtime.executor.MessageManager() +# while True: +# req = manager.pull() +# if isinstance(req, int): +# break +# req.wait() + +# def init_manager(): +# global _message_context +# _ = runtime.executor.MessageManager() +# _message_context = threading.Thread(target=handle_request) +# _message_context.start() + + +# def finish_manager(): +# """ +# Clear message manager +# """ +# global _message_context +# manager = runtime.executor.MessageManager() +# manager.push(-1) +# _message_context.join() diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 57fbf010..139dc03a 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -1,8 +1,6 @@ """ Operation grouping """ - -from sqlite3 import adapt from typing import List, Dict, Tuple from cube.execplan import ExectuionPlan diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 809b29ed..24d1d85a 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -58,19 +58,69 @@ def backward(input_tensors : List[torch.Tensor], else: return tuple(grads) -def backwardV2(input_tensors: List[torch.Tensor], output_tensors, output_tensor_grads): - inputs = list() - for input in enumerate(input_tensors): - # skip returning parameters - if torch.is_tensor(input) and not isinstance(input, torch.nn.Parameter): - inputs.append(inputs) - for tensor in input_tensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - tensor.retain_grad() - torch.autograd.backward( - output_tensors, - grad_tensors=output_tensor_grads, - inputs=input_tensors - ) - grads = [input.grad for input in inputs] - return grads +# def backward(input_tensors: List[torch.Tensor], +# output_tensors: List[torch.Tensor], +# output_tensor_grads: List[torch.Tensor]) -> Tuple[torch.Tensor]: +# """ +# Backward Procedure. +# +# input_tensors: List[torch.Tensor]: +# tensors that their gradient need to be computed, including parameters. +# Correspoinding forward input tensors. +# +# output_tensors: +# tensors that start for gradient backward computation. +# Corresponding to forward output tensors. +# +# output_tensor_grads: +# gradient tensors corresponding to output_tensors. +# +# Returns: +# gradient in order of non-parameter tensors in input_tensors. +# (Note parameter tnesors already have gradient accumulated at .grad attribute) +# """ +# if len(output_tensors) == 0: +# return None +# inputs = list() +# for input_ in input_tensors: +# if torch.is_tensor(input_) and not isinstance(input_, torch.nn.Parameter): +# if input_.requires_grad: +# input_.retain_grad() +# inputs.append(input_) +# torch.autograd.backward( +# output_tensors, +# grad_tensors=output_tensor_grads, +# ) +# grads = tuple(input_.grad for input_ in inputs) +# if len(grads) == 0: return None +# elif len(grads) == 1: return grads[0] +# else: return tuple(grads) + +### =================== Experimental Feature ======================= + +# import queue +# +# +# class MessageManager: +# """ +# message manager to make send as async calls. +# """ +# +# class __MessageManager: +# def __init__(self): +# self._reqs = queue.Queue(maxsize=128) +# +# instance = None +# +# def __init__(self): +# if not MessageManager.instance: +# MessageManager.instance = MessageManager.__MessageManager() +# +# def __getattr__(self, name): +# return getattr(self.instance, name) +# +# def push(self, req): +# self.instance._reqs.put(req, block=True, timeout=None) +# +# def pull(self): +# return self.instance._reqs.get(block=True, timeout=None) diff --git a/cube/search/iterator.py b/cube/search/iterator.py index e1fab246..1b698a00 100644 --- a/cube/search/iterator.py +++ b/cube/search/iterator.py @@ -47,6 +47,25 @@ def factorization(K: int, num=1): for res in factorization(K // i, num-1): yield [i] + res +def diff_balls_diff_boxes(nballs: int, nboxes: int, remain = None, placement = None): + balls_per_box = nballs // nboxes + if placement is None and remain is None: + # placement[ball_id] = box_id + placement = [] + # remain slots: remain_slots[box_id] = int + remain = [balls_per_box] * nboxes + if len(placement) == nballs: + yield placement + for box_id, remain_balls in enumerate(remain): + if remain_balls > 0: + placement.append(box_id) + remain[box_id] -= 1 + for seq in diff_balls_diff_boxes(nballs, nboxes, remain, placement): + yield seq + remain[box_id] += 1 + placement = placement[:-1] + + if __name__ == '__main__': diff --git a/cube/search/sampler.py b/cube/search/sampler.py index 053e2efc..f7eae7d9 100644 --- a/cube/search/sampler.py +++ b/cube/search/sampler.py @@ -50,7 +50,7 @@ class Sampler: """ @staticmethod def sample(micro_seqs: List[List[IRCell]], n_microbatch: int, n_stage: int, n_device: int, - ssampler: Callable, tsampler: Callable): + ssampler: Callable, tsampler: Callable, wlimits: int, alimits: int): assert len(micro_seqs) == n_microbatch for seq in micro_seqs: assert len(seq) // 2 == n_stage @@ -68,21 +68,43 @@ def sample(micro_seqs: List[List[IRCell]], n_microbatch: int, n_stage: int, n_de graph.assign(fnode, devid) # pruning: add dependecies for micro-batches with same device assignment + # this pruning guarantees the optimal + # graph.reset_dependency() + # same_microbatch = dict() + # for mid, placement in enumerate(placements): + # placement = tuple(placement) + # if placement not in same_microbatch: + # same_microbatch[placement] = list() + # same_microbatch[placement].append(mid) + # for placement, mids in same_microbatch.items(): + # if len(mids) > 1: + # print(f'find {mids} microbatch same, add dependency') + # for sid in range(len(placement)): + # # add forward dependency + # graph.add_schedule([micro_seqs[mid][sid] for mid in mids]) + # # add backward dependency + # graph.add_schedule([micro_seqs[mid][sid+len(placement)] for mid in mids]) + + # pruning graph.reset_dependency() - same_microbatch = dict() - for mid, placement in enumerate(placements): - placement = tuple(placement) - if placement not in same_microbatch: - same_microbatch[placement] = list() - same_microbatch[placement].append(mid) - for placement, mids in same_microbatch.items(): - if len(mids) > 1: - print(f'find {mids} microbatch same, add dependency') - for sid in range(len(placement)): - # add forward dependency - graph.add_schedule([micro_seqs[mid][sid] for mid in mids]) - # add backward dependency - graph.add_schedule([micro_seqs[mid][sid+len(placement)] for mid in mids]) + forders = [[] for _ in range(n_device)] + # n_device x n_stage + borders = [[[] for _ in range(n_stage)] for _ in range(n_device)] + for sid in range(n_stage): + for mid in range(min(n_microbatch, alimits)): + devid = placements[mid][sid] + forders[devid].append((mid, sid)) + borders[devid][n_stage - 1 - sid].append(mid) + for devid, order in enumerate(forders): + fseq = list() + for mid, sid in order: + fseq.append(micro_seqs[mid][sid]) + graph.add_schedule(fseq) + for devid, order in enumerate(borders): + bseq = list() + for sid in range(n_stage): + bseq += [micro_seqs[mid][n_stage-1-sid] for mid in order[sid]] + graph.add_schedule(bseq) # search for seqs in tsampler(graph.nodes()): diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index 8bbb34e6..3f60ceed 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -42,6 +42,7 @@ def flatten(micro_seqs: List[List[IRCell]]): def PAS(graph: IRGraph, resource): + print(graph.extra_repr()) # n_microbatch, n_stage, n_device M, S, D = 4, 4, 4 @@ -61,7 +62,7 @@ def PAS(graph: IRGraph, resource): bucket = dict() cnt = 0 - for seqs in Sampler.sample(micro_seqs, M, S, D, ssampler, tsampler): + for seqs in Sampler.sample(micro_seqs, M, S, D, ssampler, tsampler, wlimits, alimits): Searcher.search(seqs, bucket, n_worker=n_worker) for mem, (span, seq) in bucket.items(): sgraph._nodes = seq From 8e4dfd74e357a5b2c5311c4f8c831ecbb2f4bb84 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 14:01:12 +0800 Subject: [PATCH 0822/1892] composer clean code --- cube/search/composer.py | 34 +++++++--------------------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/cube/search/composer.py b/cube/search/composer.py index 77682ed5..cd20e7f3 100644 --- a/cube/search/composer.py +++ b/cube/search/composer.py @@ -479,28 +479,6 @@ def iter_shifts(conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], if Composer.same_plans(cmicros, start_step=step): candidates = [candidates[0]] - # pruning tech (for same micro plans): keep forward and backward blocks with smallest mid - # if keep_early_fw: - # fcblocks = [cblock for cblock in cblocks if cblock.type == Block.BType.FW] - # bcblocks = [cblock for cblock in cblocks if cblock.type == Block.BType.BW] - # candidates = [] - # if len(fcblocks) > 0: - # candidates.append(fcblocks[0]) - # if len(bcblocks) > 0: - # candidates.append(bcblocks[0]) - # for kblock in candidates: - # idx = cblocks.index(kblock) - # # keep blocks on the idx while shifts the rest - # nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] - # # if the reserved block executes on multiple devices, - # # then the rest device must shift all other blocks - # for odev in cmicros[idx].position(kblock)[0]: - # if odev != dev and odev in conflicts: - # for _, ocblock in conflicts[odev]: - # if ocblock != cblocks[idx] and ocblock not in nshifts: - # nshifts.append(ocblock) - # next_shifts.append(shifts + cblocks[:idx] + cblocks[idx+1:]) - # pruning tech: if micro plan is same, keep the first block for kblock in candidates: idx = cblocks.index(kblock) # keep blocks on the idx while shifts the rest @@ -562,7 +540,7 @@ def memory_opt_step(micros: List[MicroPlan], step: int): if __name__ == '__main__': - def uniform_staging(ndevs: int, nmicros=4): + def uniform_staging(ndevs: int, nmicros=4) -> List[MicroPlan]: """ f b f b @@ -579,7 +557,7 @@ def uniform_staging(ndevs: int, nmicros=4): micros.append(micro) return micros - def mbart_staging(ndevs: int, nmicros=4): + def mbart_staging(ndevs: int, nmicros=4) -> List[MicroPlan]: """ f f f b b b f f f b b b @@ -605,7 +583,7 @@ def mbart_staging(ndevs: int, nmicros=4): micros.append(micro) return micros - def chimera_staging(ndevs: int, nmicros: int): + def chimera_staging(ndevs: int, nmicros: int) -> List[MicroPlan]: """ f b f b f b f b @@ -653,12 +631,14 @@ def compose_1F1B(ndevs, nmicros): def search(ndevs, nmicros, visualize=False): # premise # micros = Composer.premise(uniform_staging, ndevs, nmicros) - micros = Composer.premise(chimera_staging, ndevs, nmicros) - # micros = Composer.premise(mbart_staging, ndevs, nmicros) + # micros = Composer.premise(chimera_staging, ndevs, nmicros) + micros = Composer.premise(mbart_staging, ndevs, nmicros) print('============== Premise ================') for idx, micro in enumerate(micros): print(f'microbatch #{idx}:') print(micro) + if visualize: + micro.visualize(f'planlog/micro{idx}.png') print('============== Premise ================') # search shift From 3621b02c86c347113b0e1ca9e42fe53fcb989eca Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 16:50:08 +0800 Subject: [PATCH 0823/1892] fix tensor grad shape bug --- cube/graph/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 34056c8d..9308f33f 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -38,7 +38,7 @@ def __init__(self, indmap): raise NotImplementedError( "Only support for sliced index mapping" ) - self._indices = indmap + self._indices: List[slice] = indmap def __eq__(self, other): if isinstance(other, IndexMap): From c0bf60939d06943bc0332052dbc39e4fc42eb90f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 23 May 2022 16:51:19 +0800 Subject: [PATCH 0824/1892] fix tensor grad shape bug --- cube/ir/cten.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index d5a7add3..c9cee3ef 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -577,6 +577,8 @@ def shape(self, val): not all([isinstance(size, int) for size in val]): raise RuntimeError("Expected shape to be list[int]") self._shape = copy.copy(list(val)) + if self.grad is not None: + self.grad.shape = copy.copy(list(val)) def nele(self) -> int: """ From d1bc0c6be5d69c5f44da99240a99fd4484b28245 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 24 May 2022 18:04:45 +0800 Subject: [PATCH 0825/1892] add grid layout --- cube/graph/adapter/layout.py | 344 +++++++++++++++++++++++++++++++++++ 1 file changed, 344 insertions(+) create mode 100644 cube/graph/adapter/layout.py diff --git a/cube/graph/adapter/layout.py b/cube/graph/adapter/layout.py new file mode 100644 index 00000000..41894592 --- /dev/null +++ b/cube/graph/adapter/layout.py @@ -0,0 +1,344 @@ +from typing import Dict, List, Tuple, Union, Optional +import copy +from matplotlib.style import available +import numpy as np + +from cube.graph.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap + + +class GridLayout: + """ + This class assumes a full-tensor can only be + uniformly partitioned / replicated on dimensions and values. + + A partition plan N-dim tensor layout can be represented as + : R (replica), V (value), dim_i (dimension) + """ + + def __init__(self, ftensor: IRFullTensor, subtensors: List[IRSubTensor], mats: np.ndarray): + """ + ftensor: N-dim FullTensor + subtensors: List[IRSubTensors] + mats: Array[IRSubTensor]: + (2+N)-dim matrix, with index respect to + """ + self.ftensor = ftensor + self.subtensors = subtensors + self._tindex: Dict[int, List[int]] = dict() + self._mats = mats + + @property + def R(self) -> int: + return self._mats.shape[0] + + @property + def V(self) -> int: + return self._mats.shape[1] + + @property + def D(self) -> Tuple[int]: + return tuple(self._mats.shape[2:]) + + @property + def vec(self) -> Tuple[int]: + return tuple(self._mats.shape) + + @property + def ndims(self): + return len(self._mats.shape) + + # def index(self, subtensor: IRSubTensor) -> List[int]: + # """ + # Get index of of the subtensor + # """ + # assert id(subtensor) in self._tindex, f"tensor: {subtensor} not found" + # return copy.copy(self._tindex(id(subtensor))) + + def get(self, r: bool = False, v: bool = False, d: Union[bool, int]=False) -> List[IRSubTensor]: + if r: + nchunks = self.R + idx = 0 + elif v: + nchunks = self.V + idx = 1 + elif isinstance(d, int): + nchunks = self.D[d] + idx = 2 + d + else: + raise ValueError("r, v, d should set at least one") + axes = list(range(idx)) + list(range(idx+1, self.ndims)) + [idx] + mat = np.transpose(self._mats, axes).reshape((-1, nchunks)) + for i in mat.shape[0]: + yield mat[i] + + # ====== primitives ===== # + + def d2r(self, dim: int, chunks: int): + """ + dimension to replica: allgather + """ + layout = list(self.vec) + assert layout[2+dim] % chunks == 0, f"not dividable dim: {layout[2+dim]} // {chunks}" + layout[0] = layout[0] * chunks + layout[2+dim] = layout[2+dim] // chunks + return grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + + # all_tensors = [] + # for tensors in self.get(d=dim): + # assert len(tensors) % chunks == 0, "not dividable dim and chunks" + # tensors = tensors.reshape((-1, chunks)) + # for group_tensors in tensors: # go through each row + # indmap = [] + # for idim in range(self.ndims): + # if idim != dim: + # indmap.append(group_tensors[0].valmap.get()[idim]) + # else: + # slicer = slice( + # group_tensors[0].indmap.get()[idim].start, + # group_tensors[-1].indmap.get()[idim].stop, 1 + # ) + # indmap.append(slicer) + # valmap = group_tensors[0].valmap + # for tensor in group_tensors: + # gtensor = self.ftensor.select(indmap, tuple(valmap)) + # gtensor._cell = tensor._cell # set device + # all_tensors.append(gtensor) + # return GridLayout(self.ftensor, all_tensors) + + def d2d(self, from_dim: int, to_dim: int, chunks: int): + """ + dimension to dimension: all-to-all + """ + layout = list(self.vec) + assert layout[2+from_dim] % chunks == 0, f"not dividable dim: {layout[2+from_dim]} // {chunks}" + layout[2+from_dim] = layout[2+from_dim] // chunks + layout[2+to_dim] = layout[2+to_dim] * chunks + return grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + + # if from_dim == to_dim: + # return self + # all_tensors = [] + # for tensors in self.get(d=from_dim): + # assert len(tensors) % chunks == 0, "not dividable dim and chunks" + # tensors = tensors.reshape((-1, chunks)) + # for group_tensors in tensors: + # for cid, tensor in enumerate(group_tensors): + # indmap = [] + # for dim in range(self.ndims): + # # from_dim gets nchunks larger + # if dim == from_dim: + # slicer = slice( + # group_tensors[0].indmap.get()[dim].start, + # group_tensors[-1].indmap.get()[dim].stop, 1 + # ) + # indmap.append(slicer) + # # to_dim gets nchunks smaller + # elif dim == to_dim: + # nele = tensor.shape[dim] // chunks + # start = tensor.indmap.get()[dim].start + # slicer = slice( + # start + nele * cid, + # start + nele * (cid + 1), 1 + # ) + # indmap.append(slicer) + # # others keep unchanged + # else: + # indmap = tensor.indmap.get()[dim] + # valmap = tensor.valmap + # ttensor = self.ftensor.select(indmap, tuple(valmap)) + # ttensor._cell = tensor + # all_tensors.append(ttensor) + # return GridLayout(self.ftensor, all_tensors) + + def v2r(self, chunks: int): + """ + value to replica: all-reduce + """ + layout = list(self.vec) + assert layout[1] % chunks == 0, f"not dividable value chunks: {layout[1]} // {chunks}" + layout[1] = layout[1] // chunks + layout[0] = layout[0] * chunks + return grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + + + def v2d(self, dim: int, chunks: int): + """ + value to dimension: reduce-scatter + """ + layout = list(self.vec) + assert layout[1] % chunks == 0, f"not dividable value chunks: {layout[0]} // {chunks}" + layout[1] = layout[1] // chunks + layout[2+dim] = layout[2+dim] * chunks + return grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + + + def r2d(self, dim: int, chunks: int): + """ + replica to dimension: split + """ + layout = list(self.vec) + assert layout[0] % chunks == 0, f"not dividable replica: {layout[0]} // {chunks}" + layout[0] = layout[0] // chunks + layout[2+dim] = layout[2+dim] * chunks + return grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + + # ================ solution ============= # + + def path(self, dst) -> List: + """ + find path ways from this layout to the target layout + + order: R -> V -> S + """ + def step(layout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLayout: + if dec_idx >= 2 and inc_idx == 0: # d2r + return layout.d2r(dec_idx-2, chunks) + if dec_idx >= 2 and inc_idx >= 2: # d2d + return layout.d2d(dec_idx-2, inc_idx-2, chunks) + if dec_idx == 1 and inc_idx == 0: # v2r + return layout.v2r(chunks) + if dec_idx == 1 and inc_idx >= 2: # v2d + return layout.v2d(inc_idx-2, chunks) + if dec_idx == 0 and inc_idx >= 2: # r2d + return layout.r2d(inc_idx-2, chunks) + raise RuntimeError("Cannot find primitive. Report as a bug") + + paths: List[GridLayout] = [self] + dst: GridLayout = dst + while paths[-1].vec != dst.vec: + src: GridLayout = paths[-1] + inc_idx, dec_idx = None, None + for idx, (schunk, dchunk) in enumerate(zip(src.vec, dst.vec)): + if schunk != dchunk: + print(f'src: {src.vec}, dst: {dst.vec}') + if schunk < dchunk: + inc_idx = idx # src should increase chunks on idx-dim + need_chunks = dchunk // schunk if dchunk % schunk == 0 else dchunk + for dec_idx in range(inc_idx+1, self.ndims): + # print(f'{dec_idx}/{self.ndims}') + if src.vec[dec_idx] > dst.vec[dec_idx]: + if src.vec[dec_idx] % dst.vec[dec_idx] != 0: + available_chunks = dst.vec[dec_idx] + else: + available_chunks = src.vec[dec_idx] // dst.vec[dec_idx] + chunks = min(available_chunks, need_chunks) + break + else: + raise RuntimeError("Cannot find feassible dimension. Report this as a bug.") + else: + dec_idx = idx + need_chunks = schunk // dchunk if schunk % dchunk == 0 else schunk + for inc_idx in range(dec_idx+1, self.ndims): + if src.vec[inc_idx] < dst.vec[inc_idx]: + if dst.vec[inc_idx] % src.vec[inc_idx] != 0: + available_chunks = dst.vec[inc_idx] + else: + available_chunks = dst.vec[inc_idx] // src.vec[inc_idx] + chunks = min(available_chunks, need_chunks) + break + else: + raise RuntimeError("Cannot find feassible dimension. Report this as a bug.") + print(chunks, need_chunks) + layout = step(src, dec_idx, inc_idx, chunks) + paths.append(layout) + break + return paths + + def __repr__(self): + dscp = f'T{self.ftensor._id}' + return dscp + + +def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int]) -> np.ndarray: + """ + partition a ftensor using grid layout of + """ + mats = np.empty([r, v] + dims, dtype=IRSubTensor) + all_subtensors = [] + + def iter_idx(dims: List[int]) -> Tuple[int]: + if len(dims) == 0: + yield () + else: + for i in range(dims[0]): + for indices in iter_idx(dims[1:]): + yield (i,) + indices + # generate tensor for each index + for indices in iter_idx([v,]+dims): + valmap = ValueMap(indices[0], v) + indmap = [] + shape = [] + for dim, (nchunk, index) in enumerate(zip(dims, indices[1:])): + assert ftensor.shape[dim] % nchunk == 0, f"not dividable for {nchunk} chunks over dim {dim}" + csize = ftensor.shape[dim] // nchunk + start = csize * index + indmap.append(slice(start, start+csize, 1)) + shape.append(csize) + subtensor = ftensor.select(tuple(indmap), valmap, shape) + # replicate + subtensors = [copy.copy(subtensor) for _ in range(r)] + all_subtensors += subtensors + mats[(slice(None),)+indices] = np.array(subtensors, dtype=IRSubTensor) + return GridLayout(ftensor, all_subtensors, mats) + + +def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]) -> Optional[GridLayout]: + _replica: int = None + _value: int = None + _dims: List[int] = [None] * len(ftensor.shape) + _tindex: Dict[int, List[int]] = dict() + + ndims = len(ftensor.shape) + + replicas: Dict[int, List[IRSubTensor]] = dict() + vchunks: set = set() + dchunks: List[set] = [set() for _ in range(ndims)] + + for subtensor in subtensors: + tid = id(subtensor) + # set up replica + if subtensor._id not in replicas: + replicas[subtensor._id] = [] + _tindex[tid] = [len(replicas[subtensor._id])] + replicas[subtensor._id].append(subtensor) + # setup value + _tindex[tid].append(subtensor.valmap.idx) + vchunks.add(subtensor.valmap.chunk_num) + # setup dimensions + for dim in range(ndims): + snele = subtensor.shape[dim] + start = subtensor.indmap.get()[dim].start + fnele = ftensor.shape[dim] + if fnele % snele != 0 or start % snele != 0: + raise RuntimeError(f"dimension split error: full nele: {fnele}, sub nele: {snele}, start: {start}") + dchunks[dim].add(fnele // snele) + _tindex[tid].append(start // snele) + # replica (R) + nreplicas = set(len(ts) for ts in replicas.values()) + if len(nreplicas) != 1: + raise RuntimeError(f"different replicas: {nreplicas}") + _replica = list(nreplicas)[0] + # value (V) + nchunks = set(t.valmap.chunk_num for t in subtensors) + if len(nchunks) != 1: + raise RuntimeError(f"different value split: {nchunks}") + _value = list(nchunks)[0] + # dimension (D) + for dim in range(ndims): + if len(dchunks[dim]) != 1: + raise RuntimeError(f"different dimension split: {dchunks[dim]}") + _dims[dim] = list(dchunks[dim])[0] + + # set matrix + mats = np.empty([_replica, _value] + _dims, dtype=IRSubTensor) + for subtensor in subtensors: + idx = tuple(_tindex[id(subtensor)]) + assert mats[idx] is None, f"repeating entry. mutiple same {subtensor}" + mats[tuple(idx)] = subtensor + assert not (mats == None).any(), "at least one entry not set" + return GridLayout(ftensor, subtensors, mats) From d2f14632b8cda19e1aa46a91d2ac98fa51a5cba8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 25 May 2022 09:58:35 +0800 Subject: [PATCH 0826/1892] add test for grid --- cube/graph/adapter/layout.py | 337 ++++++++++++++++++----------------- tests/test_grid.py | 17 ++ 2 files changed, 190 insertions(+), 164 deletions(-) create mode 100644 tests/test_grid.py diff --git a/cube/graph/adapter/layout.py b/cube/graph/adapter/layout.py index 41894592..0bdfe9eb 100644 --- a/cube/graph/adapter/layout.py +++ b/cube/graph/adapter/layout.py @@ -1,6 +1,5 @@ -from typing import Dict, List, Tuple, Union, Optional +from typing import Dict, List, Tuple, Union import copy -from matplotlib.style import available import numpy as np from cube.graph.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap @@ -47,12 +46,16 @@ def vec(self) -> Tuple[int]: def ndims(self): return len(self._mats.shape) - # def index(self, subtensor: IRSubTensor) -> List[int]: - # """ - # Get index of of the subtensor - # """ - # assert id(subtensor) in self._tindex, f"tensor: {subtensor} not found" - # return copy.copy(self._tindex(id(subtensor))) + @property + def mat(self): + return self._mats + + def index(self, subtensor: IRSubTensor) -> Tuple[int]: + """ + Get index of of the subtensor + """ + assert id(subtensor) in self._tindex, f"tensor: {subtensor} not found" + return tuple(self._tindex(id(subtensor))) def get(self, r: bool = False, v: bool = False, d: Union[bool, int]=False) -> List[IRSubTensor]: if r: @@ -81,30 +84,14 @@ def d2r(self, dim: int, chunks: int): assert layout[2+dim] % chunks == 0, f"not dividable dim: {layout[2+dim]} // {chunks}" layout[0] = layout[0] * chunks layout[2+dim] = layout[2+dim] // chunks - return grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) - - # all_tensors = [] - # for tensors in self.get(d=dim): - # assert len(tensors) % chunks == 0, "not dividable dim and chunks" - # tensors = tensors.reshape((-1, chunks)) - # for group_tensors in tensors: # go through each row - # indmap = [] - # for idim in range(self.ndims): - # if idim != dim: - # indmap.append(group_tensors[0].valmap.get()[idim]) - # else: - # slicer = slice( - # group_tensors[0].indmap.get()[idim].start, - # group_tensors[-1].indmap.get()[idim].stop, 1 - # ) - # indmap.append(slicer) - # valmap = group_tensors[0].valmap - # for tensor in group_tensors: - # gtensor = self.ftensor.select(indmap, tuple(valmap)) - # gtensor._cell = tensor._cell # set device - # all_tensors.append(gtensor) - # return GridLayout(self.ftensor, all_tensors) + glayout = GridLayout.grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + # set device + gmat = GridLayout.transpose(glayout.mat, 0, 2+dim) + omat = GridLayout.transpose(self.mat, 0, 2+dim) + for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): + gtensor._cell = otensor._cell + return glayout def d2d(self, from_dim: int, to_dim: int, chunks: int): """ @@ -114,43 +101,14 @@ def d2d(self, from_dim: int, to_dim: int, chunks: int): assert layout[2+from_dim] % chunks == 0, f"not dividable dim: {layout[2+from_dim]} // {chunks}" layout[2+from_dim] = layout[2+from_dim] // chunks layout[2+to_dim] = layout[2+to_dim] * chunks - return grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) - - # if from_dim == to_dim: - # return self - # all_tensors = [] - # for tensors in self.get(d=from_dim): - # assert len(tensors) % chunks == 0, "not dividable dim and chunks" - # tensors = tensors.reshape((-1, chunks)) - # for group_tensors in tensors: - # for cid, tensor in enumerate(group_tensors): - # indmap = [] - # for dim in range(self.ndims): - # # from_dim gets nchunks larger - # if dim == from_dim: - # slicer = slice( - # group_tensors[0].indmap.get()[dim].start, - # group_tensors[-1].indmap.get()[dim].stop, 1 - # ) - # indmap.append(slicer) - # # to_dim gets nchunks smaller - # elif dim == to_dim: - # nele = tensor.shape[dim] // chunks - # start = tensor.indmap.get()[dim].start - # slicer = slice( - # start + nele * cid, - # start + nele * (cid + 1), 1 - # ) - # indmap.append(slicer) - # # others keep unchanged - # else: - # indmap = tensor.indmap.get()[dim] - # valmap = tensor.valmap - # ttensor = self.ftensor.select(indmap, tuple(valmap)) - # ttensor._cell = tensor - # all_tensors.append(ttensor) - # return GridLayout(self.ftensor, all_tensors) + glayout = GridLayout.grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + # set device + gmat = GridLayout.transpose(glayout.mat, 2+to_dim, 2+from_dim) + omat = GridLayout.transpose(self.mat, 2+to_dim, 2+from_dim) + for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): + gtensor._cell = otensor._cell + return glayout def v2r(self, chunks: int): """ @@ -160,8 +118,14 @@ def v2r(self, chunks: int): assert layout[1] % chunks == 0, f"not dividable value chunks: {layout[1]} // {chunks}" layout[1] = layout[1] // chunks layout[0] = layout[0] * chunks - return grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) + glayout = GridLayout.grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + # set device + gmat = GridLayout.transpose(glayout.mat, 0, 1) + omat = GridLayout.transpose(self.mat, 0, 1) + for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): + gtensor._cell = otensor._cell + return glayout def v2d(self, dim: int, chunks: int): @@ -172,8 +136,14 @@ def v2d(self, dim: int, chunks: int): assert layout[1] % chunks == 0, f"not dividable value chunks: {layout[0]} // {chunks}" layout[1] = layout[1] // chunks layout[2+dim] = layout[2+dim] * chunks - return grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) + glayout = GridLayout.grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + # set device + gmat = GridLayout.transpose(glayout.mat, 2+dim, 1) + omat = GridLayout.transpose(self.mat, 2+dim, 1) + for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): + gtensor._cell = otensor._cell + return glayout def r2d(self, dim: int, chunks: int): @@ -184,8 +154,14 @@ def r2d(self, dim: int, chunks: int): assert layout[0] % chunks == 0, f"not dividable replica: {layout[0]} // {chunks}" layout[0] = layout[0] // chunks layout[2+dim] = layout[2+dim] * chunks - return grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) + glayout = GridLayout.grid(self.ftensor, + r=layout[0], v=layout[1], dims=layout[2:]) + # set device + gmat = GridLayout.transpose(glayout.mat, 2+dim, 0) + omat = GridLayout.transpose(self.mat, 2+dim, 0) + for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): + gtensor._cell = otensor._cell + return glayout # ================ solution ============= # @@ -215,7 +191,7 @@ def step(layout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLay inc_idx, dec_idx = None, None for idx, (schunk, dchunk) in enumerate(zip(src.vec, dst.vec)): if schunk != dchunk: - print(f'src: {src.vec}, dst: {dst.vec}') + # print(f'src: {src.vec}, dst: {dst.vec}') if schunk < dchunk: inc_idx = idx # src should increase chunks on idx-dim need_chunks = dchunk // schunk if dchunk % schunk == 0 else dchunk @@ -243,7 +219,7 @@ def step(layout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLay break else: raise RuntimeError("Cannot find feassible dimension. Report this as a bug.") - print(chunks, need_chunks) + # print(chunks, need_chunks) layout = step(src, dec_idx, inc_idx, chunks) paths.append(layout) break @@ -253,92 +229,125 @@ def __repr__(self): dscp = f'T{self.ftensor._id}' return dscp + def print_dev_tensors(self): + """ + print each device hold tensors. + """ + devices: Dict[int, List[IRSubTensor]] = dict() + for tensor in self.subtensors: + assert len(tensor.device) == 1, f"got tensor device: {tensor.device}" + if tensor.device[0] not in devices: + devices[tensor.device[0]] = [] + devices[tensor.device[0]].append(tensor) + devs = list(devices.keys()) + devs.sort() + for dev in devs: + print(f'dev{dev}: {devices[dev]}') + + @staticmethod + def transpose(mat: np.ndarray, dim0: int, dim1: int): + """ + put the dim0 and dim1 of the mat to the last two dims + """ + ndims = len(mat.shape) + axes = list(range(ndims)) + assert dim0 < ndims and dim1 < ndims, "dim0 or dim1 out of index" + axes.pop(max(dim0, dim1)) + axes.pop(min(dim0, dim1)) + axes += [dim0, dim1] + return np.transpose(mat, axes) + + @staticmethod + def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int]): + """ + partition a ftensor using grid layout of + """ + mats = np.empty([r, v] + dims, dtype=IRSubTensor) + all_subtensors = [] + + def iter_idx(dims: List[int]) -> Tuple[int]: + if len(dims) == 0: + yield () + else: + for i in range(dims[0]): + for indices in iter_idx(dims[1:]): + yield (i,) + indices + # generate tensor for each index + for indices in iter_idx([v,]+dims): + valmap = ValueMap(indices[0], v) + indmap = [] + shape = [] + for dim, (nchunk, index) in enumerate(zip(dims, indices[1:])): + assert ftensor.shape[dim] % nchunk == 0, f"not dividable for {nchunk} chunks over dim {dim}" + csize = ftensor.shape[dim] // nchunk + start = csize * index + indmap.append(slice(start, start+csize, 1)) + shape.append(csize) + subtensor = ftensor.select(tuple(indmap), valmap, shape) + # replicate + subtensors = [copy.copy(subtensor) for _ in range(r)] + all_subtensors += subtensors + mats[(slice(None),)+indices] = np.array(subtensors, dtype=IRSubTensor) + return GridLayout(ftensor, all_subtensors, mats) + + @staticmethod + def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): + """ + convert ftensor and subtensors into a GridLayout. -def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int]) -> np.ndarray: - """ - partition a ftensor using grid layout of - """ - mats = np.empty([r, v] + dims, dtype=IRSubTensor) - all_subtensors = [] - - def iter_idx(dims: List[int]) -> Tuple[int]: - if len(dims) == 0: - yield () - else: - for i in range(dims[0]): - for indices in iter_idx(dims[1:]): - yield (i,) + indices - # generate tensor for each index - for indices in iter_idx([v,]+dims): - valmap = ValueMap(indices[0], v) - indmap = [] - shape = [] - for dim, (nchunk, index) in enumerate(zip(dims, indices[1:])): - assert ftensor.shape[dim] % nchunk == 0, f"not dividable for {nchunk} chunks over dim {dim}" - csize = ftensor.shape[dim] // nchunk - start = csize * index - indmap.append(slice(start, start+csize, 1)) - shape.append(csize) - subtensor = ftensor.select(tuple(indmap), valmap, shape) - # replicate - subtensors = [copy.copy(subtensor) for _ in range(r)] - all_subtensors += subtensors - mats[(slice(None),)+indices] = np.array(subtensors, dtype=IRSubTensor) - return GridLayout(ftensor, all_subtensors, mats) - - -def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]) -> Optional[GridLayout]: - _replica: int = None - _value: int = None - _dims: List[int] = [None] * len(ftensor.shape) - _tindex: Dict[int, List[int]] = dict() - - ndims = len(ftensor.shape) - - replicas: Dict[int, List[IRSubTensor]] = dict() - vchunks: set = set() - dchunks: List[set] = [set() for _ in range(ndims)] - - for subtensor in subtensors: - tid = id(subtensor) - # set up replica - if subtensor._id not in replicas: - replicas[subtensor._id] = [] - _tindex[tid] = [len(replicas[subtensor._id])] - replicas[subtensor._id].append(subtensor) - # setup value - _tindex[tid].append(subtensor.valmap.idx) - vchunks.add(subtensor.valmap.chunk_num) - # setup dimensions + If failed, raise error + """ + _replica: int = None + _value: int = None + _dims: List[int] = [None] * len(ftensor.shape) + _tindex: Dict[int, List[int]] = dict() + + ndims = len(ftensor.shape) + + replicas: Dict[int, List[IRSubTensor]] = dict() + vchunks: set = set() + dchunks: List[set] = [set() for _ in range(ndims)] + + for subtensor in subtensors: + tid = id(subtensor) + # set up replica + if subtensor._id not in replicas: + replicas[subtensor._id] = [] + _tindex[tid] = [len(replicas[subtensor._id])] + replicas[subtensor._id].append(subtensor) + # setup value + _tindex[tid].append(subtensor.valmap.idx) + vchunks.add(subtensor.valmap.chunk_num) + # setup dimensions + for dim in range(ndims): + snele = subtensor.shape[dim] + start = subtensor.indmap.get()[dim].start + fnele = ftensor.shape[dim] + if fnele % snele != 0 or start % snele != 0: + raise RuntimeError(f"dimension split error: full nele: {fnele}, sub nele: {snele}, start: {start}") + dchunks[dim].add(fnele // snele) + _tindex[tid].append(start // snele) + # replica (R) + nreplicas = set(len(ts) for ts in replicas.values()) + if len(nreplicas) != 1: + raise RuntimeError(f"different replicas: {nreplicas}") + _replica = list(nreplicas)[0] + # value (V) + nchunks = set(t.valmap.chunk_num for t in subtensors) + if len(nchunks) != 1: + raise RuntimeError(f"different value split: {nchunks}") + _value = list(nchunks)[0] + # dimension (D) for dim in range(ndims): - snele = subtensor.shape[dim] - start = subtensor.indmap.get()[dim].start - fnele = ftensor.shape[dim] - if fnele % snele != 0 or start % snele != 0: - raise RuntimeError(f"dimension split error: full nele: {fnele}, sub nele: {snele}, start: {start}") - dchunks[dim].add(fnele // snele) - _tindex[tid].append(start // snele) - # replica (R) - nreplicas = set(len(ts) for ts in replicas.values()) - if len(nreplicas) != 1: - raise RuntimeError(f"different replicas: {nreplicas}") - _replica = list(nreplicas)[0] - # value (V) - nchunks = set(t.valmap.chunk_num for t in subtensors) - if len(nchunks) != 1: - raise RuntimeError(f"different value split: {nchunks}") - _value = list(nchunks)[0] - # dimension (D) - for dim in range(ndims): - if len(dchunks[dim]) != 1: - raise RuntimeError(f"different dimension split: {dchunks[dim]}") - _dims[dim] = list(dchunks[dim])[0] - - # set matrix - mats = np.empty([_replica, _value] + _dims, dtype=IRSubTensor) - for subtensor in subtensors: - idx = tuple(_tindex[id(subtensor)]) - assert mats[idx] is None, f"repeating entry. mutiple same {subtensor}" - mats[tuple(idx)] = subtensor - assert not (mats == None).any(), "at least one entry not set" - return GridLayout(ftensor, subtensors, mats) + if len(dchunks[dim]) != 1: + raise RuntimeError(f"different dimension split: {dchunks[dim]}") + _dims[dim] = list(dchunks[dim])[0] + + # set matrix + mats = np.empty([_replica, _value] + _dims, dtype=IRSubTensor) + for subtensor in subtensors: + idx = tuple(_tindex[id(subtensor)]) + assert mats[idx] is None, f"repeating entry. mutiple same {subtensor}" + mats[tuple(idx)] = subtensor + assert not (mats == None).any(), "at least one entry not set" + return GridLayout(ftensor, subtensors, mats) diff --git a/tests/test_grid.py b/tests/test_grid.py new file mode 100644 index 00000000..76e5ac06 --- /dev/null +++ b/tests/test_grid.py @@ -0,0 +1,17 @@ +from cube.graph.adapter.layout import GridLayout +from cube.graph.tensor import IRFullTensor + +def test_grid(): + + tensor = IRFullTensor(shape=[8192,8192], name='src') + + src = GridLayout.grid(tensor, r=2, v=2, dims=[0, 0]) + dst = GridLayout.grid(tensor, r=4, v=1, dims=[0, 0]) + + path = src.path(dst) + for grid in path: + print(grid) + + +if __name__ == '__main__': + test_grid() From 60d8436b5476c7357da9a446b8dc560c9b421480 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 25 May 2022 11:03:57 +0800 Subject: [PATCH 0827/1892] fix device assignment --- cube/graph/adapter/layout.py | 39 ++++++++---------------------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/cube/graph/adapter/layout.py b/cube/graph/adapter/layout.py index 0bdfe9eb..2679e80f 100644 --- a/cube/graph/adapter/layout.py +++ b/cube/graph/adapter/layout.py @@ -23,7 +23,6 @@ def __init__(self, ftensor: IRFullTensor, subtensors: List[IRSubTensor], mats: n """ self.ftensor = ftensor self.subtensors = subtensors - self._tindex: Dict[int, List[int]] = dict() self._mats = mats @property @@ -50,30 +49,6 @@ def ndims(self): def mat(self): return self._mats - def index(self, subtensor: IRSubTensor) -> Tuple[int]: - """ - Get index of of the subtensor - """ - assert id(subtensor) in self._tindex, f"tensor: {subtensor} not found" - return tuple(self._tindex(id(subtensor))) - - def get(self, r: bool = False, v: bool = False, d: Union[bool, int]=False) -> List[IRSubTensor]: - if r: - nchunks = self.R - idx = 0 - elif v: - nchunks = self.V - idx = 1 - elif isinstance(d, int): - nchunks = self.D[d] - idx = 2 + d - else: - raise ValueError("r, v, d should set at least one") - axes = list(range(idx)) + list(range(idx+1, self.ndims)) + [idx] - mat = np.transpose(self._mats, axes).reshape((-1, nchunks)) - for i in mat.shape[0]: - yield mat[i] - # ====== primitives ===== # def d2r(self, dim: int, chunks: int): @@ -87,8 +62,8 @@ def d2r(self, dim: int, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - gmat = GridLayout.transpose(glayout.mat, 0, 2+dim) omat = GridLayout.transpose(self.mat, 0, 2+dim) + gmat = GridLayout.transpose(glayout.mat, 2+dim, 0) for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): gtensor._cell = otensor._cell return glayout @@ -104,8 +79,8 @@ def d2d(self, from_dim: int, to_dim: int, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - gmat = GridLayout.transpose(glayout.mat, 2+to_dim, 2+from_dim) omat = GridLayout.transpose(self.mat, 2+to_dim, 2+from_dim) + gmat = GridLayout.transpose(glayout.mat, 2+from_dim, 2+to_dim) for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): gtensor._cell = otensor._cell return glayout @@ -121,8 +96,8 @@ def v2r(self, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - gmat = GridLayout.transpose(glayout.mat, 0, 1) omat = GridLayout.transpose(self.mat, 0, 1) + gmat = GridLayout.transpose(glayout.mat, 1, 0) for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): gtensor._cell = otensor._cell return glayout @@ -139,8 +114,8 @@ def v2d(self, dim: int, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - gmat = GridLayout.transpose(glayout.mat, 2+dim, 1) omat = GridLayout.transpose(self.mat, 2+dim, 1) + gmat = GridLayout.transpose(glayout.mat, 1, 2+dim) for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): gtensor._cell = otensor._cell return glayout @@ -157,8 +132,8 @@ def r2d(self, dim: int, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - gmat = GridLayout.transpose(glayout.mat, 2+dim, 0) omat = GridLayout.transpose(self.mat, 2+dim, 0) + gmat = GridLayout.transpose(glayout.mat, 0, 2+dim) for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): gtensor._cell = otensor._cell return glayout @@ -242,7 +217,9 @@ def print_dev_tensors(self): devs = list(devices.keys()) devs.sort() for dev in devs: - print(f'dev{dev}: {devices[dev]}') + print(f'dev{dev}:') + for tensor in devices[dev]: + print(f'\t{tensor.extra_repr()}') @staticmethod def transpose(mat: np.ndarray, dim0: int, dim1: int): From 90dd393294d930313de4d20b8569710e07738df5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 25 May 2022 14:50:47 +0800 Subject: [PATCH 0828/1892] generate communication primitives --- cube/graph/adapter/layout.py | 103 ++++++++++++++++++++++------------- cube/graph/adapter/prim.py | 46 +++++++++++++--- cube/graph/tensor.py | 1 + cube/ir/cten.py | 8 +-- tests/test_grid.py | 8 ++- 5 files changed, 112 insertions(+), 54 deletions(-) diff --git a/cube/graph/adapter/layout.py b/cube/graph/adapter/layout.py index 2679e80f..7b631def 100644 --- a/cube/graph/adapter/layout.py +++ b/cube/graph/adapter/layout.py @@ -1,8 +1,15 @@ -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple import copy import numpy as np -from cube.graph.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from cube.graph.tensor import IRFullTensor, IRSubTensor +from cube.graph.tensor import IndexMap, ValueMap + +from cube.graph.adapter.prim import AllGatherPrim # d2r +from cube.graph.adapter.prim import AllToAllPrim # d2d +from cube.graph.adapter.prim import AllReducePrim # v2r +from cube.graph.adapter.prim import ReduceScatterPrim # v2d +from cube.graph.adapter.prim import SplitDropDimPrim # r2d class GridLayout: @@ -62,11 +69,16 @@ def d2r(self, dim: int, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - omat = GridLayout.transpose(self.mat, 0, 2+dim) - gmat = GridLayout.transpose(glayout.mat, 2+dim, 0) - for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): - gtensor._cell = otensor._cell - return glayout + imat = GridLayout.transpose(self.mat, 0, 2+dim) + omat = GridLayout.transpose(glayout.mat, 2+dim, 0) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor._cell = itensor._cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + print(itensors) + print(otensors) + prims.append(AllGatherPrim(itensors, otensors, dim)) + return glayout, prims def d2d(self, from_dim: int, to_dim: int, chunks: int): """ @@ -79,11 +91,14 @@ def d2d(self, from_dim: int, to_dim: int, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - omat = GridLayout.transpose(self.mat, 2+to_dim, 2+from_dim) - gmat = GridLayout.transpose(glayout.mat, 2+from_dim, 2+to_dim) - for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): - gtensor._cell = otensor._cell - return glayout + imat = GridLayout.transpose(self.mat, 2+to_dim, 2+from_dim) + omat = GridLayout.transpose(glayout.mat, 2+from_dim, 2+to_dim) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor._cell = itensor._cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(AllToAllPrim(itensors, otensors, from_dim, to_dim)) + return glayout, prims def v2r(self, chunks: int): """ @@ -96,12 +111,14 @@ def v2r(self, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - omat = GridLayout.transpose(self.mat, 0, 1) - gmat = GridLayout.transpose(glayout.mat, 1, 0) - for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): - gtensor._cell = otensor._cell - return glayout - + imat = GridLayout.transpose(self.mat, 0, 1) + omat = GridLayout.transpose(glayout.mat, 1, 0) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor._cell = itensor._cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(AllReducePrim(itensors, otensors)) + return glayout, prims def v2d(self, dim: int, chunks: int): """ @@ -114,12 +131,14 @@ def v2d(self, dim: int, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - omat = GridLayout.transpose(self.mat, 2+dim, 1) - gmat = GridLayout.transpose(glayout.mat, 1, 2+dim) - for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): - gtensor._cell = otensor._cell - return glayout - + imat = GridLayout.transpose(self.mat, 2+dim, 1) + omat = GridLayout.transpose(glayout.mat, 1, 2+dim) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor._cell = itensor._cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(ReduceScatterPrim(itensors, otensors, dim)) + return glayout, prims def r2d(self, dim: int, chunks: int): """ @@ -132,11 +151,15 @@ def r2d(self, dim: int, chunks: int): glayout = GridLayout.grid(self.ftensor, r=layout[0], v=layout[1], dims=layout[2:]) # set device - omat = GridLayout.transpose(self.mat, 2+dim, 0) - gmat = GridLayout.transpose(glayout.mat, 0, 2+dim) - for gtensor, otensor in zip(gmat.flatten(), omat.flatten()): - gtensor._cell = otensor._cell - return glayout + imat = GridLayout.transpose(self.mat, 2+dim, 0) + omat = GridLayout.transpose(glayout.mat, 0, 2+dim) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor._cell = itensor._cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + for idx, (itensor, otensor) in enumerate(zip(itensors, otensors)): + prims.append(SplitDropDimPrim(itensor, otensor, dim, chunks, idx)) + return glayout, prims # ================ solution ============= # @@ -146,19 +169,20 @@ def path(self, dst) -> List: order: R -> V -> S """ - def step(layout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLayout: + def step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLayout: if dec_idx >= 2 and inc_idx == 0: # d2r - return layout.d2r(dec_idx-2, chunks) + return ilayout.d2r(dec_idx-2, chunks) if dec_idx >= 2 and inc_idx >= 2: # d2d - return layout.d2d(dec_idx-2, inc_idx-2, chunks) + return ilayout.d2d(dec_idx-2, inc_idx-2, chunks) if dec_idx == 1 and inc_idx == 0: # v2r - return layout.v2r(chunks) + return ilayout.v2r(chunks) if dec_idx == 1 and inc_idx >= 2: # v2d - return layout.v2d(inc_idx-2, chunks) + return ilayout.v2d(inc_idx-2, chunks) if dec_idx == 0 and inc_idx >= 2: # r2d - return layout.r2d(inc_idx-2, chunks) + return ilayout.r2d(inc_idx-2, chunks) raise RuntimeError("Cannot find primitive. Report as a bug") - + + comm_prims = [] paths: List[GridLayout] = [self] dst: GridLayout = dst while paths[-1].vec != dst.vec: @@ -195,10 +219,11 @@ def step(layout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLay else: raise RuntimeError("Cannot find feassible dimension. Report this as a bug.") # print(chunks, need_chunks) - layout = step(src, dec_idx, inc_idx, chunks) - paths.append(layout) + olayout, oprims = step(src, dec_idx, inc_idx, chunks) + paths.append(olayout) + comm_prims += oprims break - return paths + return paths, comm_prims def __repr__(self): dscp = f'T{self.ftensor._id}' diff --git a/cube/graph/adapter/prim.py b/cube/graph/adapter/prim.py index 3b9bf112..e1a7e141 100644 --- a/cube/graph/adapter/prim.py +++ b/cube/graph/adapter/prim.py @@ -11,8 +11,8 @@ class IRAdapterPrim: def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): - self._inputs = inputs - self._outputs = outputs + self._inputs = list(inputs) + self._outputs = list(outputs) self._device = [] self.kwargs = dict() @@ -70,7 +70,7 @@ def __init__(self, otensors: List[IRSubTensor]): super().__init__(itensors, otensors) devices = [] - for t in itensors + otensors: + for t in list(itensors) + list(otensors): devices += t.device self.device = list(set(devices)) @@ -113,6 +113,23 @@ def __init__(self, itensor: IRSubTensor, dim: int, self.device = itensor.device +class SplitDropDimPrim(SpatialPrim): + """ + split dimension in n chunks and take idx-th chunk + """ + def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor, + dim: int, chunks: int, idx: int): + assert 0 <=idx and idx < chunks, "idx out of scope" + super().__init__([itensor], [otensor]) + self.dim = dim + self.chunks = chunks + self.idx = idx + self.device = itensor.device + + def __repr__(self) -> str: + return f'dev{self.device}: {self.outputs(0)} = split(dim={self.dim}, chunks={self.chunks}, idx={self.idx})' + + class MergeDimPrim(SpatialPrim): """ concatenate dimension @@ -126,7 +143,7 @@ def __init__(self, itensors: List[IRSubTensor], dim: int, # numerical primitive -class ReducePrim(ValuePrim): +class SumPrim(ValuePrim): def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): assert all(itensor.device == itensors[0].device for itensor in itensors), "device not same" @@ -219,21 +236,31 @@ class AllReducePrim(CollectivePrim): def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): super().__init__(itensors, otensors) + def __repr__(self) -> str: + return f'dev{self.device}: {self.outputs()} = all_reduce({self.inputs()}' + class AllGatherPrim(CollectivePrim): """ non-differentiabl all-to-all """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): - super().__init__(itensors, otensors) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): + super().__init__(itensors, otensors, dim=dim) + + def __repr__(self) -> str: + return f'dev{self.device}: {self.outputs()} = all_gather({self.inputs()})' class ReduceScatterPrim(CollectivePrim): """ non-differential reduce-scatter """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): - super().__init__(itensors, otensors) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): + super().__init__(itensors, otensors, dim=dim) + + def __repr__(self) -> str: + return f'dev{self.device}: {self.outputs()} = reduce_scatter({self.inputs()})' + class BroadcastPrim(CollectivePrim): @@ -264,6 +291,9 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idi """ super().__init__(itensors, otensors, idim=idim, odim=odim) + def __repr__(self) -> str: + return f"dev{self.device}: {self.outputs()} = all_to_all({self.inputs}, idim={self.kwargs['idm']}, odim={self.kwargs['odim']})" + class DiffCollectivePrim(CollectivePrim): """ diff --git a/cube/graph/tensor.py b/cube/graph/tensor.py index 9308f33f..41c53eaa 100644 --- a/cube/graph/tensor.py +++ b/cube/graph/tensor.py @@ -518,6 +518,7 @@ def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, # return tensor to keep id same for same sub tensor for sub_tensor in self.subtensors(): if sub_tensor.indmap == indmap and sub_tensor.valmap == valmap: + sub_tensor = copy.copy(sub_tensor) return sub_tensor sub_tensor = IRSubTensor(self, indmap, valmap, shape) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index c9cee3ef..0d1781d0 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -403,9 +403,9 @@ class IRTensor: _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad', '_dtype'] - def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown): + def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): - self._id: int = IDGenerator().gen_tensor_id() + self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() self._shape: Optional(List[int]) = shape self.name = name if name else 'tensor' @@ -464,7 +464,7 @@ def device(self) -> List[int]: if self._cell: return self._cell.device else: - return None + return [] @device.setter def device(self, val: Union[int, List[int]]): @@ -551,7 +551,7 @@ def __copy__(self): Returns: tensor """ - tensor = IRTensor(self._shape, self.name) + tensor = IRTensor(self._shape, self.name, tid=self._id) for key in self.__dict__: setattr(tensor, key, getattr(self, key)) # clear attached cells diff --git a/tests/test_grid.py b/tests/test_grid.py index 76e5ac06..75d815f1 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -5,12 +5,14 @@ def test_grid(): tensor = IRFullTensor(shape=[8192,8192], name='src') - src = GridLayout.grid(tensor, r=2, v=2, dims=[0, 0]) - dst = GridLayout.grid(tensor, r=4, v=1, dims=[0, 0]) + src = GridLayout.grid(tensor, r=2, v=2, dims=[1, 1]) + dst = GridLayout.grid(tensor, r=2, v=1, dims=[2, 1]) - path = src.path(dst) + path, prims = src.path(dst) for grid in path: print(grid) + for prim in prims: + print(prim) if __name__ == '__main__': From cbdf9fa280a04940fd389daf9992d0ab63648232 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 26 May 2022 23:47:28 +0800 Subject: [PATCH 0829/1892] auto replace operator --- cube/graph/adapter/layout.py | 39 ++++++++++++++++++++++++++++++------ cube/graph/adapter/prim.py | 27 +++++++++++++++++-------- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/cube/graph/adapter/layout.py b/cube/graph/adapter/layout.py index 7b631def..5cd8b1b9 100644 --- a/cube/graph/adapter/layout.py +++ b/cube/graph/adapter/layout.py @@ -75,8 +75,6 @@ def d2r(self, dim: int, chunks: int): otensor._cell = itensor._cell prims = [] for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - print(itensors) - print(otensors) prims.append(AllGatherPrim(itensors, otensors, dim)) return glayout, prims @@ -163,11 +161,24 @@ def r2d(self, dim: int, chunks: int): # ================ solution ============= # - def path(self, dst) -> List: + def path(self, dst, auto_replace: bool = False) -> Tuple: """ - find path ways from this layout to the target layout - - order: R -> V -> S + Find a path from self to destination GridLayout using + primitivies. This implementation uses search order of + R -> V -> S. + + Args: + dst: GridLayout + auto_replace: bool + If true, the consumer operator may be replaced + to match the device assignment. + + Return: + paths: List[GridLayout] + the search path from source GridLayout (self) + to destination GridLayout (self) + comm_prims: List[IRAdapterPrim] + communication primitives for translation """ def step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLayout: if dec_idx >= 2 and inc_idx == 0: # d2r @@ -223,6 +234,22 @@ def step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLa paths.append(olayout) comm_prims += oprims break + if auto_replace: + replaced = False + reorder : Dict[str, Tuple[int, int]] = dict() + for itensor, otensor in zip(paths[-1].mat.flatten(), dst.mat.flatten()): + assert len(itensor.device) == 1 and len(otensor.device) == 1, \ + "Expect tensor only has one device. Report this as a bug" + if itensor.device != otensor.device: + inode, onode = itensor._cell, otensor._cell + reorder[f'{onode.name}-{onode._id}'] = (onode.device[0], inode.device[0]) + onode.device = inode.device + if onode.mirror is not None: + onode.mirror.device = inode.device + replaced = True + if replaced: + print(f'warning: a better device placement is found and set for op {reorder}') + return paths, comm_prims def __repr__(self): diff --git a/cube/graph/adapter/prim.py b/cube/graph/adapter/prim.py index e1a7e141..512e945d 100644 --- a/cube/graph/adapter/prim.py +++ b/cube/graph/adapter/prim.py @@ -2,7 +2,7 @@ The primitive used for IRAdapter """ -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import copy from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap @@ -15,6 +15,7 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): self._outputs = list(outputs) self._device = [] self.kwargs = dict() + self.signature = None def inputs(self, idx: Optional[int] = None): assert idx is None or isinstance(idx, int), "expected idx to be None or int" @@ -31,6 +32,8 @@ def outputs(self, idx: Optional[int] = None): return self._outputs[idx] def dispatch(self, devid: int): + if devid not in self.device: + return None return self @property @@ -109,7 +112,7 @@ class SplitDimPrim(SpatialPrim): def __init__(self, itensor: IRSubTensor, dim: int, otensors: List[IRSubTensor]): super().__init__([itensor], otensors) - self.dim = dim + self.kwargs['dim'] = dim self.device = itensor.device @@ -121,13 +124,14 @@ def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor, dim: int, chunks: int, idx: int): assert 0 <=idx and idx < chunks, "idx out of scope" super().__init__([itensor], [otensor]) - self.dim = dim - self.chunks = chunks - self.idx = idx + self.kwargs['dim'] = dim + self.kwargs['chunks'] = chunks + self.kwargs['idx'] = idx self.device = itensor.device + self.signature = 'cube.runtime.adapter.collectives.split_drop_dim' def __repr__(self) -> str: - return f'dev{self.device}: {self.outputs(0)} = split(dim={self.dim}, chunks={self.chunks}, idx={self.idx})' + return f"dev{self.device}: {self.outputs(0)} = split(dim={self.kwargs['dim']}, chunks={self.kwargs['chunks']}, idx={self.kwargs['idx']})" class MergeDimPrim(SpatialPrim): @@ -218,14 +222,17 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k for arg, val in kwargs.items(): self.kwargs[arg] = val - def dispatch(self, devid: int, init_method: Callable): + def dispatch(self, devid: int) -> Optional[CommPrim]: """ dispatch to a given device """ + if devid not in self.device: + return None assert devid in self.device, f"device {devid} not applied for this comm primitive" itensors = [itensor for itensor in self.inputs() if devid in itensor.device] otensors = [otensor for otensor in self.outputs() if devid in otensor.device] - prim = init_method(itensors, otensors, **self.kwargs) + prim = CollectivePrim(itensors, otensors, **self.kwargs) + prim.signature = self.signature return prim @@ -235,6 +242,7 @@ class AllReducePrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): super().__init__(itensors, otensors) + self.signature = 'cube.runtime.adapter.collectives.all_reduce' def __repr__(self) -> str: return f'dev{self.device}: {self.outputs()} = all_reduce({self.inputs()}' @@ -246,6 +254,7 @@ class AllGatherPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): super().__init__(itensors, otensors, dim=dim) + self.signature = 'cube.runtime.adapter.collectives.all_gather' def __repr__(self) -> str: return f'dev{self.device}: {self.outputs()} = all_gather({self.inputs()})' @@ -257,6 +266,7 @@ class ReduceScatterPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): super().__init__(itensors, otensors, dim=dim) + self.signature = 'cube.runtime.adapter.collectives.reduce_scatter' def __repr__(self) -> str: return f'dev{self.device}: {self.outputs()} = reduce_scatter({self.inputs()})' @@ -290,6 +300,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idi idim != odim """ super().__init__(itensors, otensors, idim=idim, odim=odim) + self.signature = 'cube.runtime.adapter.collectives.all_to_all' def __repr__(self) -> str: return f"dev{self.device}: {self.outputs()} = all_to_all({self.inputs}, idim={self.kwargs['idm']}, odim={self.kwargs['odim']})" From 1e4ac229d8c77fcfcb4a468f66ad625872b01c0d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 27 May 2022 16:10:52 +0800 Subject: [PATCH 0830/1892] new adapter with grid layout integrated --- cube/codegen/codegen.py | 141 +++---- cube/compiler.py | 24 +- cube/graph/adapter/adapter.py | 550 ++++------------------------ cube/graph/adapter/gen.py | 332 ++++++++++++++--- cube/graph/adapter/layout.py | 34 +- cube/graph/adapter/prim.py | 100 ++--- cube/graph/graph.py | 2 +- cube/runtime/adapter/__init__.py | 10 +- cube/runtime/adapter/collectives.py | 99 +++-- cube/runtime/adapter/transform.py | 71 ++-- 10 files changed, 614 insertions(+), 749 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index b7f93df9..aa18c4cf 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -11,8 +11,7 @@ from cube.graph.tensor import IRSubTensor from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.graph.adapter.adapter import CollectivePrim, IRAdapter, SelectPrim, MovePrim, MergePrim -from cube.graph.adapter.adapter import IRWeightReducer +from cube.graph.adapter.adapter import IRWeightReducer, IRAdapter from cube.graph.graph import IRGraph from cube.execplan import ExectuionPlan @@ -99,10 +98,10 @@ def init_comm_groups(self): # collect groups from p2p fusion adapters = [n for n in graph.nodes() if isinstance(n, IRAdapter)] for adapter in adapters: - for prim in adapter.prims(): - if not isinstance(prim, CollectivePrim): + for prim in adapter.prims: + if len(prim.device) == 1: continue - ranks = list(prim.group) + ranks = list(prim.device) ranks.sort() ranks = tuple(ranks) if ranks not in comm_groups: @@ -135,7 +134,8 @@ def gen(self, device: int, outfile=None, attach=False) -> str: elif isinstance(node, IRFwOperation): self.emit_op_call(node) elif isinstance(node, IRAdapter): - node = node.dispatch(rank=device) + node = node.dispatch(device) + print(node.extra_repr()) self.emit_adapter_call(node) elif isinstance(node, IRWeightReducer): self.emit_reducer_init(node) @@ -246,66 +246,77 @@ def emit_adapter_call(self, node: IRAdapter): """ Emit adapter call """ - if len(node.device) != 1: - raise RuntimeError("Expected IRAdapter to be dispatched") - rank = node.device[0] - for prim in node.prims(): - # emit select - if isinstance(prim, SelectPrim): - sign = 'cube.runtime.adapter.select({tensor}, {indmap}, {valmap})' - input = self.tensor_naming(prim.tensor) - output = self.tensor_naming(prim.output) - valmap = (prim.valmap.idx, prim.valmap.chunk_num) - code = f'{output} = {sign.format(tensor=input, indmap=prim.indmap, valmap=valmap)}' - self.forward_region.append(code) - # emit move - elif isinstance(prim, MovePrim): - send_sign = 'cube.runtime.adapter.send({tensor}, {send_rank})' - recv_sign = 'cube.runtime.adapter.recv({shape}, {from_rank}, {dtype})' - tensor = self.tensor_naming(prim.tensor) - # send - if rank == prim.from_rank: - code = f'{send_sign.format(tensor=tensor, send_rank=prim.to_rank)}' - self.forward_region.append(code) - # recv - elif rank == prim.to_rank: - output = self.tensor_naming(prim.tensor) - dtype = self.dtype_map(prim.dtype) - code = f'{tensor} = {recv_sign.format(shape=prim.shape, from_rank=prim.from_rank, dtype=dtype)}' - self.forward_region.append(code) - # emit merge - elif isinstance(prim, MergePrim): - sign = 'cube.runtime.adapter.merge({tensors}, {concat}, {add})' - inputs = self.tuple_naming(prim.tensors) - output = self.tensor_naming(prim.output) - code = f'{output} = {sign.format(tensors=inputs, concat=prim.concat, add=prim.add)}' - self.forward_region.append(code) - # emit collectives - elif isinstance(prim, CollectivePrim): - sign = 'cube.runtime.adapter.{ctype}({input_tensors}, {output_shapes}, {output_dtypes}, {group})' - inputs = self.tuple_naming(prim.inputs) - outputs = self.return_naming(prim.outputs) - dtypes = None - if prim.output_dtypes is not None: - dtypes = [self.dtype_map(dtype) for dtype in prim.output_dtypes] - dtypes = self.tuple_naming(dtypes) - body = sign.format( - ctype=prim.ctype.value, - input_tensors = inputs, - output_shapes = prim.output_shapes, - output_dtypes = dtypes, - group=prim.group - ) - code = f'{outputs} = {body}' - self.forward_region.append(code) + assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" + # rank = node.device[0] + for prim in node.prims: + # print(f'generating prim: {prim}') + if len(prim.inputs()) == 1: + itensors = self.tensor_naming(prim.inputs()[0]) else: - raise TypeError(f"Unkown primitive types {type(prim)} of Adapter") + itensors = self.tuple_naming(prim.inputs()) + kwargs = list() + for name, val in prim.kwargs.items(): + kwargs.append(f'{name}={val}') + kwargs = ', '.join(kwargs) + outputs = self.return_naming(prim.outputs()) + code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' + self.forward_region.append(code) + # # emit select + # if isinstance(prim, SelectPrim): + # sign = 'cube.runtime.adapter.select({tensor}, {indmap}, {valmap})' + # input = self.tensor_naming(prim.tensor) + # output = self.tensor_naming(prim.output) + # valmap = (prim.valmap.idx, prim.valmap.chunk_num) + # code = f'{output} = {sign.format(tensor=input, indmap=prim.indmap, valmap=valmap)}' + # self.forward_region.append(code) + # # emit move + # elif isinstance(prim, MovePrim): + # send_sign = 'cube.runtime.adapter.send({tensor}, {send_rank})' + # recv_sign = 'cube.runtime.adapter.recv({shape}, {from_rank}, {dtype})' + # tensor = self.tensor_naming(prim.tensor) + # # send + # if rank == prim.from_rank: + # code = f'{send_sign.format(tensor=tensor, send_rank=prim.to_rank)}' + # self.forward_region.append(code) + # # recv + # elif rank == prim.to_rank: + # output = self.tensor_naming(prim.tensor) + # dtype = self.dtype_map(prim.dtype) + # code = f'{tensor} = {recv_sign.format(shape=prim.shape, from_rank=prim.from_rank, dtype=dtype)}' + # self.forward_region.append(code) + # # emit merge + # elif isinstance(prim, MergePrim): + # sign = 'cube.runtime.adapter.merge({tensors}, {concat}, {add})' + # inputs = self.tuple_naming(prim.tensors) + # output = self.tensor_naming(prim.output) + # code = f'{output} = {sign.format(tensors=inputs, concat=prim.concat, add=prim.add)}' + # self.forward_region.append(code) + # # emit collectives + # elif isinstance(prim, CollectivePrim): + # sign = 'cube.runtime.adapter.{ctype}({input_tensors}, {output_shapes}, {output_dtypes}, {group})' + # inputs = self.tuple_naming(prim.inputs) + # outputs = self.return_naming(prim.outputs) + # dtypes = None + # if prim.output_dtypes is not None: + # dtypes = [self.dtype_map(dtype) for dtype in prim.output_dtypes] + # dtypes = self.tuple_naming(dtypes) + # body = sign.format( + # ctype=prim.ctype.value, + # input_tensors = inputs, + # output_shapes = prim.output_shapes, + # output_dtypes = dtypes, + # group=prim.group + # ) + # code = f'{outputs} = {body}' + # self.forward_region.append(code) + # else: + # raise TypeError(f"Unkown primitive types {type(prim)} of Adapter") # requires grad generation - sign = '{output} = {output}.contiguous().requires_grad_()' - for output in node.outputs(): - if isinstance(output, IRSubTensor) and output.requires_grad: - code = sign.format(output=self.tensor_naming(output)) - self.forward_region.append(code) + # sign = '{output} = {output}.contiguous().requires_grad_()' + # for output in node.outputs(): + # if isinstance(output, IRSubTensor) and output.requires_grad: + # code = sign.format(output=self.tensor_naming(output)) + # self.forward_region.append(code) def emit_reducer_init(self, node: IRWeightReducer): # reducer init interface @@ -390,7 +401,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: device_nodes = self.execplan.sequence(device) for idx, node in enumerate(device_nodes): if isinstance(node, IRAdapter): - node = node.dispatch(rank=device) + node = node.dispatch(device) device_nodes[idx] = node def refcount(tensor, node) -> int: diff --git a/cube/compiler.py b/cube/compiler.py index b062a5f5..d08cc51e 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -6,7 +6,7 @@ import cube from cube.graph import parser -from cube.graph.adapter.gen import AdapterGener +from cube.graph.adapter.gen import IRAdapterGener from cube.graph.graph import IRGraph from cube.graph.operator.operator import IRDataOperation @@ -14,8 +14,7 @@ from cube.logics.translator import LogicTranslator from cube.execplan import ExectuionPlan -from cube.execplan.planpass.grouping import Grouping, GroupingAdapter -from cube.execplan.planpass.fusion import P2PFusion +from cube.execplan.planpass.grouping import Grouping from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -165,7 +164,8 @@ def decorator(fn: Callable) -> Callable: raise RuntimeError(f"Node {node} device is not set") # generate adapter - graph = AdapterGener.gen(graph) + # graph = AdapterGener.gen(graph) + graph = IRAdapterGener.gen(graph) # to execution plan execplan = ExectuionPlan(graph) @@ -176,15 +176,15 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on grouping operations: {:.2f} s'.format(span)) - start = time.time() - execplan = P2PFusion.apply(execplan) - span = time.time() - start - print('> planpass on p2pfusion operations: {:.2f} s'.format(span)) + # start = time.time() + # execplan = P2PFusion.apply(execplan) + # span = time.time() - start + # print('> planpass on p2pfusion operations: {:.2f} s'.format(span)) - start = time.time() - execplan = GroupingAdapter.apply(execplan) - span = time.time() - start - print('> planpass on grouping adapters : {:.2f} s'.format(span)) + # start = time.time() + # execplan = GroupingAdapter.apply(execplan) + # span = time.time() - start + # print('> planpass on grouping adapters : {:.2f} s'.format(span)) execplan.graph.reset_dependency() # execplan.analyze(outfile='execplan.png') diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index b2281dc3..1dda4841 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -1,519 +1,115 @@ - -from enum import Enum -from typing import List, Optional, Tuple +from typing import List, Optional import copy -import numpy as np -from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap +from cube.graph.adapter.prim import IRAdapterPrim, IdentityPrim +from cube.graph.tensor import IRSubTensor from cube.ir.cten import IRCell -from cube.ir.dtype import IRDType - - -class SelectPrim: - - def __init__(self, tensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, - shape: List[int], output: IRSubTensor): - self.tensor = tensor - self.indmap = indmap - self.valmap = valmap - self.shape = shape - self.output = output - self.device: List[int] = tensor.device - - def __repr__(self): - dscp = f'{self.output} = select({self.tensor})' - return dscp - - -class MovePrim: - - def __init__(self, tensor: IRSubTensor, from_rank: int, to_rank: int): - self.tensor = tensor - self.from_rank = from_rank - self.to_rank = to_rank - self.shape = tensor.shape - self.dtype = tensor.dtype - self.device: List[int] = [from_rank, to_rank] - - def __repr__(self): - dscp = f'move({self.tensor}, from={self.from_rank}, to={self.to_rank})' - return dscp - - -class CollectivePrim: - - class Type(Enum): - AllReduce = 'all_reduce' - AllGather = 'all_gather' - ReduceScatter = 'reduce_scatter' - Broadcast = 'broadcast' - - def __init__(self, ctype: Enum, - device: Tuple[int], - group: Tuple[int], - inputs: Tuple[IRSubTensor] = None, - input_shapes: Tuple[Tuple[int]] = None, - input_dtypes: Tuple[IRDType] = None, - outputs: Tuple[IRSubTensor] = None, - output_shapes: Tuple[Tuple[int]] = None, - output_dtypes: Tuple[IRDType] = None): - """ - inputs: - the collective input tensors. Including remote tensors. - src_ranks: - the tensor rank for each corresponding input tensor - outputs: - the collective output tensors. Including remote tensors. - dst_ranks: - the tensor rank for each corresponding output tensor - device: - the collective to be performed rank. - Note n-device collective will have n CollectivePrim, - each needs to be assigned with a single device rank. - """ - self.ctype = ctype - # inputs - self.inputs: Tuple[IRSubTensor] = tuple(inputs) if inputs is not None else list() - self.input_shapes: Tuple[IRSubTensor] = input_shapes - self.input_dtypes: Tuple[IRDType] = input_dtypes - # outputs - self.outputs: Tuple[IRSubTensor] = outputs if outputs is not None else list() - self.output_shapes: List[IRSubTensor] = output_shapes - self.output_dtypes: List[IRDType] = output_dtypes - # communication group - self.group: Tuple[int] = tuple(group) - # device - self.device = tuple(device) - - def __repr__(self): - dscp = f'{self.outputs} = {self.ctype.value}(inputs={self.inputs}, group={self.group})' - return dscp - - -class MergePrim: - def __init__(self, tensors: List[IRSubTensor], - output: IRSubTensor, device: List[int], - concat: Optional[int] = None, add: bool = False): - if not ((concat is not None) ^ (add is True)): # xor condition - raise RuntimeError("Expected concat or add") - self.tensors = tensors - self.concat = concat - self.add = add - self.output = output - # re-order tensor - if isinstance(concat, int): - slicers = [tensor.indmap.get()[concat] for tensor in tensors] - starts = np.array([slicer.start for slicer in slicers], dtype=int) - sorted_idx = np.argsort(starts) - tensors = np.array(tensors)[sorted_idx] - self.tensors = tensors.tolist() - self.device: List[int] = device - - def set_output(self, output: IRSubTensor): - self.output = output - - @staticmethod - def concat(tensor1: IRSubTensor, tensor2: IRSubTensor) -> Optional[Tuple[IRSubTensor, int]]: - """ - Check if two tensor can be merged. - If they can be merged, return the merge index - """ - if not isinstance(tensor1, IRSubTensor) or not isinstance(tensor2, IRSubTensor): - raise TypeError("Expected two tensors") - if tensor1.overlap(tensor2): - return None - if tensor1.parent != tensor2.parent: - return None - if tensor1.valmap != tensor2.valmap: - return None - indices1 = tensor1.indmap.get() - indices2 = tensor2.indmap.get() - indmap = list() - if len(indices1) != len(indices2): - return None - axis = None - for dim, (slicer1, slicer2) in enumerate(zip(indices1, indices2)): - if slicer1 != slicer2: - start1, stop1, step1 = slicer1.start, slicer1.stop, slicer1.step - start2, stop2, step2 = slicer2.start, slicer2.stop, slicer2.step - if step1 != step2: - return None - if axis is not None: - return None - if start1 < start2 and stop1 == start2: - axis = dim - indmap.append(slice(start1, stop2, step1)) - elif start1 > start2 and start1 == stop2: - axis = dim - indmap.append(slice(start2, stop1, step1)) - else: - return None - else: - indmap.append(slicer1) - shapes = list() - for idx, (nele1, nele2) in enumerate(zip(tensor1.shape, tensor2.shape)): - nele = nele1 if idx != axis else nele1 + nele2 - shapes.append(nele) - mtensor = tensor1.parent.select( - indmap = tuple(indmap), - valmap = tensor1.valmap, - shape = shapes - ) - return mtensor, axis - - @staticmethod - def add(tensor1: IRSubTensor, tensor2: IRSubTensor) -> Optional[IRSubTensor]: - if not isinstance(tensor1, IRSubTensor) or not isinstance(tensor2, IRSubTensor): - raise TypeError("Expected two tensors") - if tensor1.overlap(tensor2): - return None - if tensor1.parent != tensor2.parent: - return None - if tensor1.indmap != tensor2.indmap: - return None - if tensor1.valmap.chunk_num != tensor2.valmap.chunk_num: - return None - chunk_num = tensor1.valmap.chunk_num - idx1, idx2 = tensor1.valmap.idx, tensor2.valmap.idx - if chunk_num % 2 != 0: - return None - chunk_num = int(chunk_num // 2) - if int(idx1 // 2) != int(idx2 // 2): - return None - idx = int(idx1 // 2) - mtensor = tensor1.parent.select( - indmap = tensor1.indmap, - valmap = (idx, chunk_num), - shape = tensor1.shape - ) - return mtensor - - def __repr__(self): - dscp = f'{self.output} = merge({self.tensors}, axis={self.concat}, add={self.add})' - return dscp class IRAdapter(IRCell): - """ - Tensor Adapter for each operator. - - A Tensor Adapter has three stages: - * Select: select produced tensors - * Move: transfer the produced tensors - * Merge: merge the produced tensors - """ - - def __init__(self, prims, - inputs: List[IRSubTensor], idevices: List[List[int]], - outputs: List[IRSubTensor], odevices: List[List[int]]): - - self._prims = prims - self._idevices = tuple(idevices) - self._odevices = tuple(odevices) + def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): super().__init__( name='adapter', signature='adapter', input_length=len(inputs), output_length=len(outputs), init_outputs=False ) - for idx, tensor in enumerate(inputs): - self.set_input(idx, tensor) - for idx, tensor in enumerate(outputs): - self.set_output(idx, tensor) + # we don't use input and output setter as this will + # change tensor device info + self._inputs = inputs + self._outputs = outputs + + self._prims: Optional[List[IRAdapterPrim]] = None + self._differentiable = False + + @property + def prims(self) -> List[IRAdapterPrim]: + if self.is_forward: + if self.differentiable(): + return self.diffcolls + else: + return self.forward + else: + if self.differentiable(): + # not able to see + return [] + else: + return self.backward - # set up device - device = set() - for prim in self._prims: - device.update(prim.device) - self.device = list(device) + @property + def prims(self) -> List[IRAdapterPrim]: + return copy.copy(self._prims) - def prims(self, select=True, move=True, merge=True, coll=True): - """ - Return prim list - """ - prims = list() - for prim in self._prims: - if select and isinstance(prim, SelectPrim): - prims.append(prim) - if move and isinstance(prim, MovePrim): - prims.append(prim) - if merge and isinstance(prim, MergePrim): - prims.append(prim) - if coll and isinstance(prim, CollectivePrim): - prims.append(prim) - return prims + @prims.setter + def prims(self, prims: List[IRAdapterPrim]): + assert all(isinstance(prim, IRAdapterPrim) for prim in prims), "Expect List[IRAdapterPrim]" + self._prims = prims + self.update_device() - def idevice(self, input_index: int = None) -> List[int]: + @property + def differentiable(self) -> bool: """ - Get device for input tensor at input index. - - Returns: - device: List[int] + return if the adapter is using differentiable primitives """ - if isinstance(input_index, int): - return self._idevices[input_index] - else: - return copy.copy(self._idevices) + return self._differentiable - def odevice(self, output_index: int = None) -> List[int]: - """ - Get device for output tensor at output index. + @differentiable.setter + def differentiable(self, val: bool): + self._differentiable = val - Returns: - device: List[int] - """ - if isinstance(output_index, int): - return self._odevices[output_index] - else: - return copy.copy(self._odevices) + def update_device(self): + device = set() + for prim in self.prims: + device.update(prim.device) + self.device = list(device) - def dispatch(self, rank: int): + def dispatch(self, devid: int): """ Get Adapter for a specific rank Returns: IRAdapter """ - if not isinstance(rank, int): - raise TypeError(f"Expected rank to be int but got {rank}") - prims = list() - for prim in self.prims(): - if rank in prim.device: - prims.append(prim) - inputs, idevs = list(), list() - for input, devs in zip(self.inputs(), self._idevices): - if rank in devs: - inputs.append(input) - idevs.append(devs) - outputs, odevs = list(), list() - for output, devs in zip(self.outputs(), self._odevices): - if rank in devs: - outputs.append(output) - odevs.append(devs) - adapter = IRAdapter(prims, inputs, idevs, outputs, odevs) + assert isinstance(devid, int), f"Expect devid to be int but got {devid}" + prims = [prim.dispatch(devid) for prim in self.prims] + prims = [prim for prim in prims if prim is not None] + # get inputs + inputs = [] + for itensor in self.inputs(): + if devid in itensor.device: + inputs.append(itensor) + outputs = [] + for otensor in self.outputs(): + if devid in otensor.device: + outputs.append(otensor) + # insert identity prims + if len(prims) == 0: + assert len(inputs) == len(outputs) and all(itensor in outputs for itensor in inputs), \ + "input/output tensor not match for empty prims" + for itensor in inputs: + prims.append(IdentityPrim(itensor)) + # dispatch + adapter = IRAdapter(inputs, outputs) + adapter.prims = prims adapter.name = self.name adapter._id = self._id - adapter.device = rank return adapter - def update_device(self): - """ - Update device (needed when adapter content changes, e.g., P2PFusion) - """ - device = set() - for prim in self._prims: - device.update(prim.device) - self.device = list(device) - - def is_identity(self): - """ - Check if the adapter does nothing - - Returns: - Boolean - """ - return len(self._prims) == 0 - - @staticmethod - def gen(dst_tensor: IRSubTensor): - # print(f'generating adapter for: {dst_tensor}') - if not isinstance(dst_tensor, IRSubTensor): - raise RuntimeError("Expected IRSubTensor") - inputs, intersections, select_prims = IRAdapter.gen_select(dst_tensor) - move_prims = IRAdapter.gen_move(dst_tensor, intersections) - merge_prims = IRAdapter.gen_merge(dst_tensor, intersections) - prims = select_prims + move_prims + merge_prims - idevs = [t.device for t in inputs] - odevs = [dst_tensor.device] - return IRAdapter(prims, inputs, idevs, [dst_tensor], odevs) - - @staticmethod - def gen_select(dst_tensor): - - # TODO: consider previous adapter output as later adapter in-tensor - # for residual cases - - inputs = list() - intersections = list() - prims = list() - - otensor = dst_tensor - odevice = otensor.device - - # local and remote adapter in-tensor - # local_remote instead of local + remote to preserve inputs order - # TODO check order as may affect merging result - local, remote, local_and_remote = list(), list(), otensor.parent.ptensors - for ptensor in otensor.parent.ptensors: - if ptensor.device == odevice: - local.append(ptensor) - else: - remote.append(ptensor) - - # first check local in tensor - for tensor in local: - common = tensor.common(otensor) - if tensor == otensor: - intersections.append(tensor) - inputs.append(tensor) - return inputs, intersections, prims - elif common == otensor: - # index map - indmap = list() - islicers = tensor.indmap.get() - oslicers = common.indmap.get() - for islicer, oslicer in zip(islicers, oslicers): - istart, istop, istep = islicer.start, islicer.stop, islicer.step - ostart, ostop, ostep = oslicer.start, oslicer.stop, oslicer.step - # relative offset - start = ostart - istart - stop = start + ostop - ostart - slicer = slice(start, stop, ostep) - indmap.append(slicer) - # value map must be same - if tensor.valmap != common.valmap: - break - valmap = ValueMap(0, 1) - prim = SelectPrim(tensor, indmap, valmap, common.shape, common) - prims.append(prim) - intersections.append(otensor) - inputs.append(tensor) - return inputs, intersections, prims - # if otensor in local: - # intersections.append(otensor) - # inputs.append(otensor) - # return inputs, intersections, prims - - # check local + remote - for itensor in local_and_remote: #local + remote: - if not itensor.overlap(otensor): - continue - - # intersection - common: IRSubTensor = otensor.common(itensor) - common.attach_cell(itensor._cell) - intersections.append(common) - inputs.append(itensor) - if common == itensor: - continue - - islicers = itensor.indmap.get() - oslicers = common.indmap.get() - # index map - indmap = list() - for islicer, oslicer in zip(islicers, oslicers): - istart, istop, istep = islicer.start, islicer.stop, islicer.step - ostart, ostop, ostep = oslicer.start, oslicer.stop, oslicer.step - if ostep % istep != 0: - raise RuntimeError("Step condition fails") - # relative offset - start = ostart - istart - stop = start + ostop - ostart - slicer = slice(start, stop, ostep) - indmap.append(slicer) - # value map - if itensor.valmap == common.valmap: - valmap = ValueMap(0, 1) - elif itensor.valmap == ValueMap(0, 1): - valmap = common.valmap - else: - print('from: ', itensor) - print('to : ', common) - raise NotImplementedError( - f"Not supported value select: {input.valmap} -> {common.valmap}" - ) - prim = SelectPrim(itensor, indmap, valmap, common.shape, common) - prims.append(prim) - # TODO: check union == otensor - if common == otensor: - break - - return inputs, intersections, prims - - @staticmethod - def gen_move(dst_tensor, intersections): - prims = list() - odevice = dst_tensor.device - for tensor in intersections: - if tensor.device != odevice: - if len(tensor.device) != 1 or len(odevice) != 1: - raise RuntimeError( - f"Expected tensor on a single device but got {tensor.device} and {odevice}" - ) - prim = MovePrim(tensor, from_rank=tensor.device[0], to_rank=odevice[0]) - prims.append(prim) - return prims - - @staticmethod - def gen_merge(dst_tensor, intersections): - prims = list() - output = dst_tensor - remain_tensors = copy.copy(intersections) - if output in remain_tensors: - return prims - out = None - while out != output: - out = None - merged = False - for idx1 in range(len(remain_tensors) - 1): - for idx2 in range(idx1 + 1, len(remain_tensors)): - tensor1 = remain_tensors[idx1] - tensor2 = remain_tensors[idx2] - # try concat - out = MergePrim.concat(tensor1, tensor2) - if out is not None: - out, concat_dim = out - prim = MergePrim([tensor1, tensor2], out, output.device, concat_dim, False) - prims.append(prim) - merged = True - break - # try add - out = MergePrim.add(tensor1, tensor2) - if out is not None: - prim = MergePrim([tensor1, tensor2], out, output.device, None, True) - prims.append(prim) - merged = True - break - if merged: - remain_tensors.remove(tensor1) - remain_tensors.remove(tensor2) - remain_tensors.append(out) - break - # cannot merge or add - if out is None: - print(f'failed tensor: {dst_tensor.extra_repr()}') - print(f'ptensor:') - for tensor in dst_tensor.parent.ptensors: - print(f'node-{tensor._cell._id}: {tensor.extra_repr()}') - print('intersections:') - for tensor in intersections: - print(f'{tensor.extra_repr()}') - raise RuntimeError(f"Merge plan of tensor {dst_tensor} not found") - return prims - def __repr__(self): - dscp = f'Adapter{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' - return dscp + return f'Adapter-{self._id}{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' - def module_repr(self) -> str: - return repr(self) - - def extra_repr(self): - """ - Detailed information - """ - dscp = repr(self) + ':\n' - # select - for prim in self._prims: - dscp += '\t' + repr(prim) + '\n' + def extra_repr(self) -> str: + dscp = f'Adapter-{self._id}[{self.device}](inputs={self.inputs()}, outputs={self.outputs()})\n' + for prim in self.prims: + dscp += repr(prim) + '\n' return dscp class IRWeightReducer(IRCell): def __init__(self, weights: List[IRSubTensor], name='reducer'): - if not all([isinstance(w, IRSubTensor) and w.is_param() for w in weights]): + if not all(isinstance(w, IRSubTensor) and w.is_param() for w in weights): raise RuntimeError("Expected a list of gradient IRSubTensor") signature = None super().__init__(name, signature, len(weights), 0) diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index 0a9b1112..4b827edb 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -1,17 +1,25 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple +import warnings +import copy from cube.graph.graph import IRGraph -from cube.graph.tensor import IRSubTensor, ValueMap +from cube.graph.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap from cube.graph.adapter.adapter import IRAdapter, IRWeightReducer + from cube.graph.operator.operator import IRBpOperation, IRFwOperation +from cube.ir.cten import IRCell + +from cube.graph.adapter.prim import IRAdapterPrim +from cube.graph.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim +from cube.graph.adapter.layout import GridLayout -class AdapterGener: +class IRAdapterGener: @staticmethod - def gen(graph: IRGraph, eager=True) -> IRGraph: + def gen(graph: IRGraph) -> IRGraph: """ - Generate tensor adapter for both intermediate tensors and weights + Generate tensor adapter for both activations and weights Args: graph: IRGraph. @@ -30,47 +38,12 @@ def gen(graph: IRGraph, eager=True) -> IRGraph: idx = graph.detach(node) node.update() graph.attach(node, idx) - graph = AdapterGener.gen_activation_adapter(graph, eager) - graph = AdapterGener.gen_weight_reducer(graph) + graph = IRAdapterGener.gen_activation(graph) + graph = IRAdapterGener.gen_weight(graph) return graph @staticmethod - def gen_activation_adapter(graph: IRGraph, eager=True) -> IRGraph: - all_adapters = list() - # generate adapter for non-weight values - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - for input in node.inputs(): - if not isinstance(input, IRSubTensor): - continue - # skip parameter - if input.is_param(): - continue - adapter = IRAdapter.gen(input) - if not adapter.is_identity(): - all_adapters.append(adapter) - idx = graph.nodes().index(node) - graph._nodes.insert(idx, adapter) - if isinstance(node, IRBpOperation): - for grad in node.inputs(): - if not isinstance(grad, IRSubTensor): - continue - adapter = IRAdapter.gen(grad) - if not adapter.is_identity(): - all_adapters.append(adapter) - idx = graph.nodes().index(node) - graph._nodes.insert(idx, adapter) - graph.reset_dependency() - if eager: - seq = graph.nodes() - for adapter in all_adapters: - seq.remove(adapter) - graph.partial_set_order(seq, eager=True) - return graph - - - @staticmethod - def gen_weight_reducer(graph: IRGraph) -> IRGraph: + def gen_weight(graph: IRGraph) -> IRGraph: # step 1: get weight and gradient # weights: Dict[weight_id: int, IRSubTensor] # grads : Dict[weight_id: int, Dict[device: int, List[grad: IRSubTensor]]] @@ -83,8 +56,7 @@ def gen_weight_reducer(graph: IRGraph) -> IRGraph: for input in fnode.inputs(): if isinstance(input, IRSubTensor) and input.is_param(): grad = input.grad - if grad is None: #TODO remove me, for weather.py test - print(f'WARNING: skipping non grad of {fnode}') + if grad is None: continue # nothing to sync if grad.valmap == ValueMap(0, 1): @@ -119,4 +91,272 @@ def gen_weight_reducer(graph: IRGraph) -> IRGraph: opt_op = IRWeightReducer(weights) opt_op.device = list(ranks) graph._nodes.append(opt_op) - return graph \ No newline at end of file + return graph + + @staticmethod + def gen_activation(graph: IRGraph) -> IRGraph: + for ftensor in graph.full_tensors(): + # backward will gen in forward + if ftensor.is_param() or ftensor.is_grad(): + continue + adapters = IRAdapterGener.gen_fulltensor(ftensor) + if len(adapters) == 0: + continue + for fadapter in adapters: + # insert forward adapter + idx = min([graph.nodes().index(c) for c in ftensor.consumers]) + graph._nodes.insert(idx, fadapter) + # insert backward adapter + grad: Optional[IRFullTensor] = ftensor.grad + if grad is not None: + badapter: IRAdapter = fadapter.mirror + idx = min([graph.nodes().index(c) for c in grad.consumers]) + graph._nodes.insert(idx, badapter) + return graph + + @staticmethod + def gen_fulltensor(ftensor: IRFullTensor) -> List[IRAdapter]: + # print(f'analyzing ftensor: {ftensor}') + # print(f'ptensors: {ftensor.ptensors}') + # print(f'ctensors: {ftensor.ctensors}') + if len(ftensor.consumers) == 0: + return [] + pdevs = set() + for pnode in ftensor.producers: + pdevs.update(pnode.device) + cdevs = set() + for cnode in ftensor.consumers: + cdevs.update(cnode.device) + # sharing devices + if pdevs == cdevs: + return IRAdapterGener.gen_gridlayout(ftensor) + # no-sharing devices + # elif len(pdevs.intersection(cdevs)) == 0: + # print(f'detect no intersection') + # return [] + # general cases + warnings.warn('UserWarning: the adapter is generated using inefficient P2P send/recv') + fprims, bprims = [], [] + for subtensor in ftensor.ctensors: + fprims += IRAdapterGener.gen_subtensor(subtensor) + fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) + fadapter.prims = fprims + # print(fadapter.extra_repr()) + grad: IRFullTensor = ftensor.grad + if grad is not None: + for subtensor in grad.ctensors: + bprims += IRAdapterGener.gen_subtensor(subtensor) + badapter = IRAdapter(grad.ptensors, grad.ctensors) + badapter.prims = bprims + # print(badapter.extra_repr()) + IRCell.make_pair(fadapter, badapter) + if len(fprims) == 0 and len(bprims) == 0: + return [] + return [fadapter] + + @staticmethod + def gen_gridlayout(ftensor: IRFullTensor) -> List[IRAdapter]: + """ + Generate adapters for connecting producer with consumer with + shared devices for forward and backward. + + ftensor: IRFullTensor: forward full tensor. + """ + # producer grid layout + ilayout = GridLayout.togrid(ftensor, ftensor.ptensors) + # reorder ctensors to match with ptensors + devs = [ptensor.device for ptensor in ilayout.mat.flatten()] + ctensors = [None] * len(devs) + for ctensor in ftensor.ctensors: + idx = devs.index(ctensor.device) + assert ctensors[idx] is None, "same device of different tensors" + ctensors[idx] = ctensor + # consumer grid layout + olayout = GridLayout.togrid(ftensor, ctensors) + # print(f'forward full tensor: {ftensor}\n producer: {ilayout}, consumer: {olayout}') + # find path + paths, fprims = ilayout.path(olayout, auto_replace=True) + + # re-assign the operator if miss-ordered + names, from_dev, to_dev = [], [], [] + reorder : Dict[str, Tuple[int, int]] = dict() + for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): + assert len(itensor.device) == 1 and len(otensor.device) == 1, \ + "Expect tensor only has one device. Report this as a bug" + if itensor.device != otensor.device: + inode, onode = itensor._cell, otensor._cell + names.append(f'{onode.name}{onode._id}') + from_dev.append(onode.device[0]) + to_dev.append(inode.device[0]) + onode.device = inode.device + if onode.mirror is not None: + onode.mirror.device = inode.device + if len(reorder) > 0: + warnings.warn(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') + + # print('find path:') + # for path in paths: print(path) + # print('comm prims:') + # for prim in fprims: print(prim) + fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) + fadapter.prims = fprims + + # generate backward + grad: IRFullTensor = ftensor.grad + bprims = [] + if grad is not None: + # reorder ptensors to match with forward + ptensors = [None] * len(devs) + for ptensor in grad.ptensors: + idx = devs.index(ptensor.device) + assert ptensors[idx] is None, "same device of different tensors" + ptensors[idx] = ptensor + ilayout = GridLayout.togrid(grad, ptensors) + olayout = GridLayout.togrid(grad, grad.ctensors) + # print(f'backward full tensor: {grad}\n producer: {ilayout}, consumer: {olayout}') + paths, bprims = ilayout.path(olayout) + # check the device order + for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): + assert len(itensor.device) == len(otensor.device), "backward device not match" + # print('find path:') + # for path in paths: print(path) + # print('comm prims') + # for prim in bprims: print(prim) + badapter = IRAdapter(grad.ptensors, grad.ctensors) + badapter.prims = bprims + IRCell.make_pair(fadapter, badapter) + if len(fprims) == 0 and len(bprims) == 0: + return [] + # print('=====') + return [fadapter] + + @staticmethod + def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: + """ + Generate communication prims for a sub-tensor. + The subtensor should be a IRSubTensor of consumer. + + The generation takes three stages: select, move, merge + """ + ftensor = subtensor.parent + # category to local tensor and remote tensor + local = [t for t in ftensor.ptensors if t.device == subtensor.device] + remote = [t for t in ftensor.ptensors if t.device != subtensor.device] + prims = [] + + # ==== select ==== # + intersections = [] + # check local + for tensor in local: + common = tensor.common(subtensor) + if tensor == subtensor: + return prims + elif common == subtensor: + indmap = [] + for islicer, oslicer in zip(tensor.indmap.get(), common.indmap.get()): + start = oslicer.start - islicer.start + stop = start + oslicer.stop - oslicer.start + indmap.append(slice(start, stop, 1)) + valmap = ValueMap(0, 1) + common.attach_cell(subtensor._cell) + prims.append(SelectPrim(tensor, indmap, valmap, common)) + return prims + # check local + remote + if len(intersections) == 0: + for itensor in local+remote: + if not itensor.overlap(subtensor): + continue + common = itensor.common(subtensor) + common.attach_cell(itensor._cell) + print(f'get common: {common.extra_repr()}') + intersections.append(common) + if common == itensor: + continue + indmap = [] + for islicer, oslicer in zip(itensor.indmap.get(), common.indmap.get()): + start = oslicer.start - islicer.start + stop = start + oslicer.stop - oslicer.start + indmap.append(slice(start, stop, 1)) + assert itensor.valmap == common.valmap or itensor.valmap == ValueMap(0,1), \ + f"Not supported value select: {itensor.valmap} -> {common.valmap}" + valmap = ValueMap(0, 1) + prims.append(SelectPrim(itensor, indmap, valmap, common)) + # TODO: check union == subtensor + if common == subtensor: + break + print(intersections) + # ====== move ===== # + tmoved = [] + for tensor in intersections: + assert len(tensor.device) == 1 and len(subtensor.device) == 1, "Expected only one device." + mtensor = tensor + if tensor.device != subtensor.device: + mtensor = copy.copy(tensor) + mtensor.attach_cell(subtensor._cell) + prims.append(MovePrim(tensor, mtensor)) + tmoved.append(mtensor) + + # ===== merge ===== # + remain_tensors: List[IRSubTensor] = copy.copy(tmoved) + if subtensor in remain_tensors: + return prims + out = None + while out != subtensor: + out, merged = None, False + for idx1 in range(len(remain_tensors) - 1): + for idx2 in range(idx1, len(remain_tensors)): + t1, t2 = remain_tensors[idx1], remain_tensors[idx2] + # check reducable + if t1.indmap == t2.indmap and t1.valmap.chunk_num == t2.valmap.chunk_num: + vid1, vid2 = t1.valmap.idx, t2.valmap.idx + # sum e.g., 0,1 but not 1,2 + if min(vid1, vid2) % 2 == 0 and abs(vid1-vid2) == 1: + vid = min(vid1, vid2) // 2 + valmap = ValueMap(vid, t1.valmap.chunk_num // 2) + out = subtensor.parent.select(t1.indmap, valmap, t1.shape) + out.attach_cell(subtensor._cell) + prims.append(SumPrim([t1, t2], out)) + merged = True + break + # try merge dimension + elif t1.valmap == t2.valmap: + cat_dim: Dict[int, List[IRSubTensor]] = dict() + indmap = [] + for dim, (s1, s2) in enumerate(zip(t1.indmap.get(), t2.indmap.get())): + if s1 != s2: + if min(s1.stop, s2.stop) == max(s1.start, s2.start): + if s1.start < s2.start: + cat_dim[dim] = [t1, t2] + else: + cat_dim[dim] = [t1, t2] + indmap.append(slice(min(s1.start, s2.start), max(s1.stop, s2.stop), 1)) + else: + cat_dim[dim] = None + indmap.append(None) + else: + indmap.append(s1) + if None in indmap: + continue + indmap = IndexMap(tuple(indmap)) + valmap = t1.valmap + out = t1.parent.select(indmap, valmap, indmap.shape) + out.attach_cell(subtensor._cell) + cdim = list(cat_dim.keys())[0] + prims.append(MergeDimPrim(cat_dim[cdim], out, cdim)) + merged = True + break + if merged: + remain_tensors.remove(t1) + remain_tensors.remove(t2) + remain_tensors.append(out) + break + if out is None: + ptensors = '\n\t'.join(t.extra_repr() for t in ftensor.ptensors) + raise RuntimeError( + f"Fail to build adapter.\n" + f"FullTensor:{ftensor}\n" + f"Producers:\n\t{ptensors}\n" + f"SubTensor:\n\t{subtensor.extra_repr()}" + ) + return prims + \ No newline at end of file diff --git a/cube/graph/adapter/layout.py b/cube/graph/adapter/layout.py index 5cd8b1b9..b2864e64 100644 --- a/cube/graph/adapter/layout.py +++ b/cube/graph/adapter/layout.py @@ -9,7 +9,7 @@ from cube.graph.adapter.prim import AllToAllPrim # d2d from cube.graph.adapter.prim import AllReducePrim # v2r from cube.graph.adapter.prim import ReduceScatterPrim # v2d -from cube.graph.adapter.prim import SplitDropDimPrim # r2d +from cube.graph.adapter.prim import ChunkPrim # r2d class GridLayout: @@ -156,7 +156,7 @@ def r2d(self, dim: int, chunks: int): prims = [] for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): for idx, (itensor, otensor) in enumerate(zip(itensors, otensors)): - prims.append(SplitDropDimPrim(itensor, otensor, dim, chunks, idx)) + prims.append(ChunkPrim(itensor, otensor, dim, chunks, idx)) return glayout, prims # ================ solution ============= # @@ -234,21 +234,21 @@ def step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLa paths.append(olayout) comm_prims += oprims break - if auto_replace: - replaced = False - reorder : Dict[str, Tuple[int, int]] = dict() - for itensor, otensor in zip(paths[-1].mat.flatten(), dst.mat.flatten()): - assert len(itensor.device) == 1 and len(otensor.device) == 1, \ - "Expect tensor only has one device. Report this as a bug" - if itensor.device != otensor.device: - inode, onode = itensor._cell, otensor._cell - reorder[f'{onode.name}-{onode._id}'] = (onode.device[0], inode.device[0]) - onode.device = inode.device - if onode.mirror is not None: - onode.mirror.device = inode.device - replaced = True - if replaced: - print(f'warning: a better device placement is found and set for op {reorder}') + # if auto_replace: + # replaced = False + # reorder : Dict[str, Tuple[int, int]] = dict() + # for itensor, otensor in zip(paths[-1].mat.flatten(), dst.mat.flatten()): + # assert len(itensor.device) == 1 and len(otensor.device) == 1, \ + # "Expect tensor only has one device. Report this as a bug" + # if itensor.device != otensor.device: + # inode, onode = itensor._cell, otensor._cell + # reorder[f'{onode.name}-{onode._id}'] = (onode.device[0], inode.device[0]) + # onode.device = inode.device + # if onode.mirror is not None: + # onode.mirror.device = inode.device + # replaced = True + # if replaced: + # print(f'warning: a better device placement is found and set for op {reorder}') return paths, comm_prims diff --git a/cube/graph/adapter/prim.py b/cube/graph/adapter/prim.py index 512e945d..95618b29 100644 --- a/cube/graph/adapter/prim.py +++ b/cube/graph/adapter/prim.py @@ -53,6 +53,7 @@ class SpatialPrim(IRAdapterPrim): """ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): super().__init__(inputs, outputs) + self.device = list(set(t.device[0] for t in inputs)) # numerical abstract primitive @@ -62,6 +63,8 @@ class ValuePrim(IRAdapterPrim): """ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): super().__init__(inputs, outputs) + self.device = list(set(t.device[0] for t in inputs)) + # communication abstract primitive class CommPrim(IRAdapterPrim): @@ -89,6 +92,17 @@ def __repr__(self) -> str: # ====================================================== +class IdentityPrim(SpatialPrim): + + def __init__(self, itensor: IRSubTensor): + super().__init__([itensor], [itensor]) + self.signature = 'cube.runtime.adapter.identity' + + def __repr__(self): + dscp = f"{self.outputs(0)} = identity({self.inputs(0)})" + return dscp + + class SelectPrim(SpatialPrim): def __init__(self, @@ -96,27 +110,16 @@ def __init__(self, indmap: IndexMap, valmap: ValueMap, otensor: IRSubTensor): super().__init__([itensor], [otensor]) - self.indmap = indmap - self.valmap = valmap - self.device = itensor.device + self.kwargs['indmap'] = indmap + self.kwargs['valmap'] = (valmap.idx, valmap.chunk_num) + self.signature = f"cube.runtime.adapter.select" def __repr__(self): - dscp = f'{self.outputs(0)} = select({self.inputs(0)})' + dscp = f"{self.outputs(0)} = select({self.inputs(0)}, indmap={self.kwargs['indmap']}, valmap={self.kwargs['valmap']})" return dscp -class SplitDimPrim(SpatialPrim): - """ - split dimension - """ - def __init__(self, itensor: IRSubTensor, dim: int, - otensors: List[IRSubTensor]): - super().__init__([itensor], otensors) - self.kwargs['dim'] = dim - self.device = itensor.device - - -class SplitDropDimPrim(SpatialPrim): +class ChunkPrim(SpatialPrim): """ split dimension in n chunks and take idx-th chunk """ @@ -127,8 +130,7 @@ def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor, self.kwargs['dim'] = dim self.kwargs['chunks'] = chunks self.kwargs['idx'] = idx - self.device = itensor.device - self.signature = 'cube.runtime.adapter.collectives.split_drop_dim' + self.signature = 'cube.runtime.adapter.chunk' def __repr__(self) -> str: return f"dev{self.device}: {self.outputs(0)} = split(dim={self.kwargs['dim']}, chunks={self.kwargs['chunks']}, idx={self.kwargs['idx']})" @@ -138,12 +140,14 @@ class MergeDimPrim(SpatialPrim): """ concatenate dimension """ - def __init__(self, itensors: List[IRSubTensor], dim: int, - otensor: IRSubTensor) -> None: + def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, dim: int) -> None: assert all(itensor.device == itensors[0].device for itensor in itensors), "device not same" super().__init__(itensors, [otensor]) - self.dim = dim - self.device = itensors[0].device + self.kwargs['dim'] = dim + self.signature = 'cube.runtime.adapter.smerge' + + def __repr__(self) -> str: + return f"dev{self.device}: {self.outputs(0)} = concat({self.inputs()}, dim={self.kwargs['dim']})" # numerical primitive @@ -152,8 +156,10 @@ class SumPrim(ValuePrim): def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): assert all(itensor.device == itensors[0].device for itensor in itensors), "device not same" super().__init__(itensors, [otensor]) - self.reduce = '+' - self.device = itensors[0].device + self.signature = 'cube.runtime.adapter.vmerge' + + def __repr__(self) -> str: + return f"dev{self.device}: {self.outputs(0)} = add({self.inputs()})" # communication primitive @@ -164,10 +170,8 @@ class SendPrim(CommPrim): def __init__(self, tensor, dst: int): super().__init__([tensor], [tensor]) self.kwargs['dst'] = dst - - def dispatch(self, devid: int): - assert devid == self.device[0], f"device {devid} not applied for this comm primitive" - return SendPrim(self.inputs(0), self.kwargs['dst']) + self.device = tensor.device + self.signature = 'cube.runtime.adapter.send' def __repr__(self) -> str: return f"{self.inputs(0)} = send({self.inputs(0)}, dst={self.kwargs['dst']}" @@ -177,38 +181,37 @@ class RecvPrim(CommPrim): """ P2P recv prim """ - def __init__(self, tensor, src: int): + def __init__(self, tensor: IRSubTensor, src: int): super().__init__([], [tensor]) - self.kwargs['src'] = src self.kwargs['shape'] = tensor.shape - self.kwargs['dtype'] = tensor.dtype - - def dispatch(self, devid: int): - assert devid == self.device[0], f"device {devid} not applied for this comm primitive" - return RecvPrim(self.outputs(0), self.kwargs['src']) + self.kwargs['dtype'] = 'torch.' + tensor.dtype.value + self.kwargs['src'] = src + self.device = tensor.device + self.signature = 'cube.runtime.adapter.recv' def __repr__(self) -> str: - return f"{self.outputs(0)} = recv(shape={self.kwargs['shape']}, dtype={self.kwargs['dtype']}, dst={self.kwargs['dst']}" + return f"{self.outputs(0)} = recv(shape={self.kwargs['shape']}, dtype={self.kwargs['dtype']}, src={self.kwargs['src']}" class MovePrim(CommPrim): """ P2P send/recv, non-differentiable """ - def __init__(self, tensor: IRSubTensor, src: int, dst: int): - super().__init__([tensor], [tensor]) - self.kwargs['src'] = src - self.kwargs['dst'] = dst + def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor): + super().__init__([itensor], [otensor]) + self.kwargs['src'] = itensor.device[0] + self.kwargs['dst'] = otensor.device[0] + self.device = itensor.device + otensor.device def dispatch(self, devid: int) -> Union[SendPrim, RecvPrim]: if devid == self.kwargs['src']: - return SendPrim(self.inputs(0), self.kwargs['devid']) + return SendPrim(self.inputs(0), self.kwargs['dst']) if devid == self.kwargs['dst']: - return RecvPrim(self.inputs(0), self.kwargs['src']) - raise ValueError(f"device {devid} is not src ({self.kwargs['src']}) or ({self.kwargs['dst']})") + return RecvPrim(self.outputs(0), self.kwargs['src']) + return None def __repr__(self): - dscp = f'move({self.inputs(0)}, from={self.src}, to={self.dst})' + dscp = f"move({self.inputs(0)}, src={self.kwargs['src']}, dst={self.kwargs['dst']})" return dscp @@ -242,7 +245,7 @@ class AllReducePrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): super().__init__(itensors, otensors) - self.signature = 'cube.runtime.adapter.collectives.all_reduce' + self.signature = 'cube.runtime.adapter.all_reduce' def __repr__(self) -> str: return f'dev{self.device}: {self.outputs()} = all_reduce({self.inputs()}' @@ -254,7 +257,7 @@ class AllGatherPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): super().__init__(itensors, otensors, dim=dim) - self.signature = 'cube.runtime.adapter.collectives.all_gather' + self.signature = 'cube.runtime.adapter.all_gather' def __repr__(self) -> str: return f'dev{self.device}: {self.outputs()} = all_gather({self.inputs()})' @@ -266,13 +269,12 @@ class ReduceScatterPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): super().__init__(itensors, otensors, dim=dim) - self.signature = 'cube.runtime.adapter.collectives.reduce_scatter' + self.signature = 'cube.runtime.adapter.reduce_scatter' def __repr__(self) -> str: return f'dev{self.device}: {self.outputs()} = reduce_scatter({self.inputs()})' - class BroadcastPrim(CollectivePrim): """ non-differential reduce-scatter @@ -300,7 +302,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idi idim != odim """ super().__init__(itensors, otensors, idim=idim, odim=odim) - self.signature = 'cube.runtime.adapter.collectives.all_to_all' + self.signature = 'cube.runtime.adapter.all_to_all' def __repr__(self) -> str: return f"dev{self.device}: {self.outputs()} = all_to_all({self.inputs}, idim={self.kwargs['idm']}, odim={self.kwargs['odim']})" diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 22e10eeb..ca388e08 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -329,7 +329,7 @@ def replicate(self, op: IRCell, times=1, reset_dependency=True) -> Optional[List # insert forward fidx = self.nodes().index(op) for idx, fnode in enumerate(fnodes): - self.attach(fnode, fidx + idx) + self.attach(fnode, fidx + idx + 1) # insert backward if isinstance(op.mirror, IRBpOperation): for fnode in fnodes: diff --git a/cube/runtime/adapter/__init__.py b/cube/runtime/adapter/__init__.py index 0e4324f7..ed7e9fc0 100644 --- a/cube/runtime/adapter/__init__.py +++ b/cube/runtime/adapter/__init__.py @@ -1,11 +1,5 @@ -# communications -from cube.runtime.adapter.collectives import send, recv -from cube.runtime.adapter.collectives import all_gather, all_reduce -from cube.runtime.adapter.collectives import reduce_scatter, broadcast - -# transformations -from cube.runtime.adapter.transform import select -from cube.runtime.adapter.transform import merge +from cube.runtime.adapter.collectives import * +from cube.runtime.adapter.transform import * # reducer from cube.runtime.adapter.reducer import Reducer diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 5f829138..054f2562 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -1,11 +1,11 @@ -from typing import List +from typing import List, Tuple import torch from cube.runtime.device import DeviceGroup from cube.profiler.timer import CudaTimer, print_each_rank -def send(tensor: torch.Tensor, to_rank: int): +def send(tensor: torch.Tensor, dst: int): """ send tensor to the remote devices. Each tensor can be sent to multiple devices @@ -21,7 +21,7 @@ def send(tensor: torch.Tensor, to_rank: int): if not tensor.is_contiguous(): tensor = tensor.contiguous() send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, to_rank + torch.distributed.isend, tensor, dst ) send_ops.append(send_op) reqs = torch.distributed.batch_isend_irecv(send_ops) @@ -29,9 +29,10 @@ def send(tensor: torch.Tensor, to_rank: int): req.wait() torch.cuda.synchronize() CudaTimer().stop(field_name='comm') + return tensor -def recv(shape: List[int], from_rank: int, dtype: torch.dtype): +def recv(tensors: List[torch.Tensor], shape: List[int], dtype: torch.dtype, src: int): # print(f'{torch.distributed.get_rank()}: recving...') CudaTimer().start(field_name='comm') ## synthetic ## @@ -46,7 +47,7 @@ def recv(shape: List[int], from_rank: int, dtype: torch.dtype): device=torch.cuda.current_device() ) recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, from_rank + torch.distributed.irecv, tensor, src ) reqs = torch.distributed.batch_isend_irecv([recv_op]) for req in reqs: @@ -90,75 +91,73 @@ def sendrecv(input_tensors: List[torch.Tensor], return outputs -### Collective Universal Interface ### -# def universal(input_tensors: List[torch.Tensor], -# output_shapes: List[List[int]], -# output_dtypes: List[torch.dtype], -# ranks: List[int]) - - -def all_reduce(input_tensors: List[torch.Tensor], - output_shapes: List[List[int]], - output_dtypes: List[torch.dtype], +def all_reduce(itensor: torch.Tensor, ranks: List[int]) -> torch.Tensor: """ Allreduce """ CudaTimer().start(field_name='comm') - assert len(input_tensors) == 1 - tensor = input_tensors[0] - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - tensor = tensor.detach() - tensor = tensor.requires_grad_() + if not itensor.is_contiguous(): + itensor = itensor.contiguous() + itensor = itensor.detach().requires_grad_() group = DeviceGroup().get_group(ranks) - torch.distributed.all_reduce(tensor, group=group) - + torch.distributed.all_reduce(itensor, group=group) + torch.cuda.synchronize() CudaTimer().stop(field_name='comm') - return tensor + return itensor -def all_gather(input_tensors: List[torch.Tensor], - output_shapes: List[List[int]], - output_dtypes: List[torch.dtype], - ranks: List[int]) -> List[torch.Tensor]: +def all_gather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: """ Allgather """ CudaTimer().start(field_name='comm') - assert len(input_tensors) == 1 - tensor = input_tensors[0] - if not tensor.is_contiguous(): - tensor = tensor.contiguous() + if not itensor.is_contiguous(): + itensor = itensor.contiguous() group = DeviceGroup().get_group(ranks) - tensor_list = [torch.empty_like(tensor) for _ in ranks] - idx = ranks.index(DeviceGroup().rank) - tensor_list[idx] = tensor - torch.distributed.all_gather(tensor_list, tensor, group=group) + tensor_list = [torch.empty_like(itensor) for _ in ranks] + tensor_list[torch.distributed.get_rank(group)] = itensor.data + torch.distributed.all_gather(tensor_list, itensor, group=group) + torch.cuda.synchronize() + # concat + otensor = torch.concat(tuple(tensor_list), dim=dim).requires_grad_() CudaTimer().stop(field_name='comm') - return tensor_list + return otensor -def reduce_scatter(input_tensors: List[torch.Tensor], - output_shapes: List[List[int]], - output_dtypes: List[torch.dtype], - ranks: List[int]) -> List[torch.Tensor]: +def reduce_scatter(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: """ ReduceScatter """ CudaTimer().start(field_name='comm') - input_tensors = list(input_tensors) - for idx, tensor in enumerate(input_tensors): + itensors = list(itensor.chunk(len(ranks), dim)) + for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): - input_tensors[idx] = tensor.contiguous() + itensors[idx] = tensor.contiguous() group = DeviceGroup().get_group(ranks) - idx = ranks.index(DeviceGroup().rank) - output = torch.empty_like(input_tensors[idx], requires_grad=True) - torch.distributed.reduce_scatter( - output, input_tensors, group=group - ) + otensor = torch.empty_like(itensors[0], requires_grad=True) + torch.distributed.reduce_scatter(otensor, itensors, group=group) + torch.cuda.synchronize() CudaTimer().stop(field_name='comm') - return output + return otensor + + +def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: + """ + All-to-all + """ + CudaTimer().start(field_name='comm') + itensors = list(itensor.chunk(len(ranks), dim=odim)) + for idx, tensor in enumerate(itensors): + if not tensor.is_contiguous(): + itensors[idx] = tensor.contiguous() + otensors = [torch.empty_like(t) for t in itensors] + group = DeviceGroup().get_group(ranks) + torch.distributed.all_to_all(otensors, itensors, group=group) + torch.cuda.synchronize() + otensor = torch.concat(tuple(otensors), dim=idim).requires_grad_() + CudaTimer().stop(field_name='comm') + return otensor def broadcast(input_tensors: List[torch.Tensor], diff --git a/cube/runtime/adapter/transform.py b/cube/runtime/adapter/transform.py index 63fe24b4..806e331a 100644 --- a/cube/runtime/adapter/transform.py +++ b/cube/runtime/adapter/transform.py @@ -2,43 +2,66 @@ Adapter: Tensor Transformation """ -from typing import List, Tuple, Optional +from typing import List, Tuple import torch +def identity(tensor: torch.Tensor): + """ + identity + """ + with torch.no_grad(): + tensor = tensor.detach().requires_grad_() + return tensor + + def select(tensor: torch.Tensor, indmap: Tuple[slice], valmap: Tuple[int, int]) -> torch.Tensor: - + """ + Select a part of tensor spatially and numerically. + """ with torch.no_grad(): sub_tensor = tensor[indmap] if valmap != (0, 1): sub_tensor = sub_tensor / valmap[1] - sub_tensor = sub_tensor.contiguous() + sub_tensor = sub_tensor.detach().requires_grad_() return sub_tensor -def merge(tensors: List[torch.Tensor], - concat: Optional[int] = None, - add: bool = False): + +def chunk(itensor: torch.Tensor, dim: int, chunks: int, idx: int) -> torch.Tensor: """ - Runtime primitive to finish tensor transformation. + split dimension in n chunks and take idx-th chunk + """ + with torch.no_grad(): + otensor = itensor.chunk(chunks, dim)[idx] + otensor = otensor.detach().requires_grad_() + return otensor - Warning: No contiguous is called!!! need to explicitly called - before communication + +def smerge(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: + """ + Runtime primitive of spatial merge. + Concatenate the tensors along a dimension Args: tensors: a list of torch tensor - concat: Optional[int]: the dimension to merge - add: bool: whether to perform value merge - """ - if not ((concat is not None) ^ (add is True)): # xor condition - raise RuntimeError("Expected concat or add") - if concat is not None: - with torch.no_grad(): - out = torch.cat(tensors, concat) - return out - if add is not None: - with torch.no_grad(): - out = tensors[0] - for tensor in tensors[1:]: - out = out + tensor - return out + dim: the dimension to concatenate. + """ + with torch.no_grad(): + out = torch.concat(tuple(tensors), dim).requires_grad_() + return out + + +def vmerge(tensors: List[torch.Tensor]) -> torch.Tensor: + """ + Runtime primitives of numerical merge. + Sum the tensors. + + Args: + tensors: a list of torch tensor + """ + with torch.no_grad(): + out = tensors[0] + for tensor in tensors[1:]: + out = out + tensor + return out.requires_grad_() From bd813571aa1b547d5834b6a20a5b7701ebbfe023 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 27 May 2022 17:02:17 +0800 Subject: [PATCH 0831/1892] fix empty adapter dispatch bug --- cube/codegen/codegen.py | 4 ---- cube/graph/adapter/adapter.py | 12 +++++------- cube/graph/adapter/gen.py | 2 +- cube/runtime/adapter/transform.py | 23 +++++++++++++++++++---- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index aa18c4cf..9d52b3a2 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -135,7 +135,6 @@ def gen(self, device: int, outfile=None, attach=False) -> str: self.emit_op_call(node) elif isinstance(node, IRAdapter): node = node.dispatch(device) - print(node.extra_repr()) self.emit_adapter_call(node) elif isinstance(node, IRWeightReducer): self.emit_reducer_init(node) @@ -423,9 +422,6 @@ def refcount(tensor, node) -> int: refcnt += 1 return refcnt - for node in device_nodes: - print(f'dev{device}: {node}') - # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: diff --git a/cube/graph/adapter/adapter.py b/cube/graph/adapter/adapter.py index 1dda4841..6611247f 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/graph/adapter/adapter.py @@ -23,6 +23,11 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): self._prims: Optional[List[IRAdapterPrim]] = None self._differentiable = False + device = set() + for tensor in inputs + outputs: + device.update(set(tensor.device)) + self.device = list(device) + @property def prims(self) -> List[IRAdapterPrim]: if self.is_forward: @@ -45,7 +50,6 @@ def prims(self) -> List[IRAdapterPrim]: def prims(self, prims: List[IRAdapterPrim]): assert all(isinstance(prim, IRAdapterPrim) for prim in prims), "Expect List[IRAdapterPrim]" self._prims = prims - self.update_device() @property def differentiable(self) -> bool: @@ -58,12 +62,6 @@ def differentiable(self) -> bool: def differentiable(self, val: bool): self._differentiable = val - def update_device(self): - device = set() - for prim in self.prims: - device.update(prim.device) - self.device = list(device) - def dispatch(self, devid: int): """ Get Adapter for a specific rank diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index 4b827edb..5cd1c130 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -135,7 +135,7 @@ def gen_fulltensor(ftensor: IRFullTensor) -> List[IRAdapter]: # print(f'detect no intersection') # return [] # general cases - warnings.warn('UserWarning: the adapter is generated using inefficient P2P send/recv') + warnings.warn('The adapter is generated using inefficient P2P send/recv') fprims, bprims = [], [] for subtensor in ftensor.ctensors: fprims += IRAdapterGener.gen_subtensor(subtensor) diff --git a/cube/runtime/adapter/transform.py b/cube/runtime/adapter/transform.py index 806e331a..8f8194b5 100644 --- a/cube/runtime/adapter/transform.py +++ b/cube/runtime/adapter/transform.py @@ -10,8 +10,11 @@ def identity(tensor: torch.Tensor): """ identity """ + require_grad = tensor.requires_grad with torch.no_grad(): - tensor = tensor.detach().requires_grad_() + tensor = tensor.detach() + if require_grad: + tensor = tensor.requires_grad_() return tensor @@ -20,11 +23,14 @@ def select(tensor: torch.Tensor, """ Select a part of tensor spatially and numerically. """ + require_grad = tensor.requires_grad with torch.no_grad(): sub_tensor = tensor[indmap] if valmap != (0, 1): sub_tensor = sub_tensor / valmap[1] - sub_tensor = sub_tensor.detach().requires_grad_() + sub_tensor = sub_tensor.detach() + if require_grad: + sub_tensor = sub_tensor.requires_grad_() return sub_tensor @@ -32,9 +38,12 @@ def chunk(itensor: torch.Tensor, dim: int, chunks: int, idx: int) -> torch.Tenso """ split dimension in n chunks and take idx-th chunk """ + require_grad = itensor.requires_grad with torch.no_grad(): otensor = itensor.chunk(chunks, dim)[idx] - otensor = otensor.detach().requires_grad_() + otensor = otensor.detach() + if require_grad: + otensor = otensor.requires_grad_() return otensor @@ -47,8 +56,11 @@ def smerge(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: tensors: a list of torch tensor dim: the dimension to concatenate. """ + require_grad = any(t.require_grad for t in tensors) with torch.no_grad(): out = torch.concat(tuple(tensors), dim).requires_grad_() + if require_grad: + out = out.requires_grad_() return out @@ -60,8 +72,11 @@ def vmerge(tensors: List[torch.Tensor]) -> torch.Tensor: Args: tensors: a list of torch tensor """ + require_grad = any(t.require_grad for t in tensors) with torch.no_grad(): out = tensors[0] for tensor in tensors[1:]: out = out + tensor - return out.requires_grad_() + if require_grad: + out = out.requires_grad_() + return out From f74f26056d486ee0dfa8c0cdd578a6bd1bb80432 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 27 May 2022 18:37:56 +0800 Subject: [PATCH 0832/1892] fix chunk primitive --- cube/graph/adapter/gen.py | 2 +- cube/graph/adapter/layout.py | 21 +++------------------ cube/graph/adapter/prim.py | 13 +++++++------ cube/runtime/adapter/transform.py | 7 +++++-- 4 files changed, 16 insertions(+), 27 deletions(-) diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index 5cd1c130..d5e5435d 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -175,7 +175,7 @@ def gen_gridlayout(ftensor: IRFullTensor) -> List[IRAdapter]: olayout = GridLayout.togrid(ftensor, ctensors) # print(f'forward full tensor: {ftensor}\n producer: {ilayout}, consumer: {olayout}') # find path - paths, fprims = ilayout.path(olayout, auto_replace=True) + paths, fprims = ilayout.path(olayout) # re-assign the operator if miss-ordered names, from_dev, to_dev = [], [], [] diff --git a/cube/graph/adapter/layout.py b/cube/graph/adapter/layout.py index b2864e64..60499657 100644 --- a/cube/graph/adapter/layout.py +++ b/cube/graph/adapter/layout.py @@ -155,13 +155,14 @@ def r2d(self, dim: int, chunks: int): otensor._cell = itensor._cell prims = [] for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + ranks = tuple(t.device[0] for t in itensors) for idx, (itensor, otensor) in enumerate(zip(itensors, otensors)): - prims.append(ChunkPrim(itensor, otensor, dim, chunks, idx)) + prims.append(ChunkPrim(itensor, otensor, dim, ranks)) return glayout, prims # ================ solution ============= # - def path(self, dst, auto_replace: bool = False) -> Tuple: + def path(self, dst) -> Tuple: """ Find a path from self to destination GridLayout using primitivies. This implementation uses search order of @@ -234,22 +235,6 @@ def step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLa paths.append(olayout) comm_prims += oprims break - # if auto_replace: - # replaced = False - # reorder : Dict[str, Tuple[int, int]] = dict() - # for itensor, otensor in zip(paths[-1].mat.flatten(), dst.mat.flatten()): - # assert len(itensor.device) == 1 and len(otensor.device) == 1, \ - # "Expect tensor only has one device. Report this as a bug" - # if itensor.device != otensor.device: - # inode, onode = itensor._cell, otensor._cell - # reorder[f'{onode.name}-{onode._id}'] = (onode.device[0], inode.device[0]) - # onode.device = inode.device - # if onode.mirror is not None: - # onode.mirror.device = inode.device - # replaced = True - # if replaced: - # print(f'warning: a better device placement is found and set for op {reorder}') - return paths, comm_prims def __repr__(self): diff --git a/cube/graph/adapter/prim.py b/cube/graph/adapter/prim.py index 95618b29..06e0c61e 100644 --- a/cube/graph/adapter/prim.py +++ b/cube/graph/adapter/prim.py @@ -2,7 +2,7 @@ The primitive used for IRAdapter """ -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import copy from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap @@ -124,16 +124,17 @@ class ChunkPrim(SpatialPrim): split dimension in n chunks and take idx-th chunk """ def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor, - dim: int, chunks: int, idx: int): - assert 0 <=idx and idx < chunks, "idx out of scope" + dim: int, ranks: Tuple[int]): + assert itensor.device[0] in ranks, "idx out of scope" super().__init__([itensor], [otensor]) self.kwargs['dim'] = dim - self.kwargs['chunks'] = chunks - self.kwargs['idx'] = idx + self.kwargs['ranks'] = ranks self.signature = 'cube.runtime.adapter.chunk' def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs(0)} = split(dim={self.kwargs['dim']}, chunks={self.kwargs['chunks']}, idx={self.kwargs['idx']})" + chunks = len(self.kwargs['ranks']) + idx = self.kwargs['ranks'].index(self.device[0]) + return f"dev{self.device}: {self.outputs(0)} = split(dim={self.kwargs['dim']}, chunks={chunks}, idx={idx})" class MergeDimPrim(SpatialPrim): diff --git a/cube/runtime/adapter/transform.py b/cube/runtime/adapter/transform.py index 8f8194b5..392ccfd4 100644 --- a/cube/runtime/adapter/transform.py +++ b/cube/runtime/adapter/transform.py @@ -34,13 +34,16 @@ def select(tensor: torch.Tensor, return sub_tensor -def chunk(itensor: torch.Tensor, dim: int, chunks: int, idx: int) -> torch.Tensor: +def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: """ split dimension in n chunks and take idx-th chunk + + ranks (Tuple[int]): the order of split tensor. """ + idx = ranks.index(torch.distributed.get_rank()) require_grad = itensor.requires_grad with torch.no_grad(): - otensor = itensor.chunk(chunks, dim)[idx] + otensor = itensor.chunk(len(ranks), dim)[idx] otensor = otensor.detach() if require_grad: otensor = otensor.requires_grad_() From 379b817067bf1ed9fdb858686892fbb59f23b238 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Fri, 27 May 2022 14:09:02 +0000 Subject: [PATCH 0833/1892] Merged PR 1381: Add ops and adapt the model to parse WRF Changes: - unroll Python loops (in TorchIR, `prim::Loop`) whose upper bound can be statically evaluated - added CubeIR ops involved in WRF TODO - type inference since some ops involves non-floating dtypes, like the comparison ops. - IRGraph construction won't terminate for WRF graph, check that. --- cube/graph/operator/function/cat.py | 27 +++- cube/graph/operator/function/creators.py | 41 +++++ cube/graph/operator/function/function.py | 192 +++++++++++++++++++++-- cube/graph/operator/function/repeat.py | 40 +++++ cube/graph/operator/function/scatter.py | 60 +++++++ cube/graph/operator/function/select.py | 78 +++++++++ cube/graph/parser/frame.py | 22 ++- cube/graph/parser/mapping.py | 24 +++ cube/graph/parser/parser.py | 178 +++++++++++++++++++-- examples/wrf/wrf2.py | 19 ++- requirements.txt | 3 +- tests/test_prim_loop.py | 156 ++++++++++++++++++ 12 files changed, 805 insertions(+), 35 deletions(-) create mode 100644 cube/graph/operator/function/creators.py create mode 100644 cube/graph/operator/function/repeat.py create mode 100644 cube/graph/operator/function/scatter.py create mode 100644 cube/graph/operator/function/select.py create mode 100644 tests/test_prim_loop.py diff --git a/cube/graph/operator/function/cat.py b/cube/graph/operator/function/cat.py index 0d6024ce..a674c6b4 100644 --- a/cube/graph/operator/function/cat.py +++ b/cube/graph/operator/function/cat.py @@ -1,4 +1,5 @@ from copy import copy +import itertools from typing import List from cube.graph.operator.operator import IRFwOperation @@ -17,7 +18,6 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, for idx, input in enumerate(inputs): self.set_input(idx, input) self.kwargs.update(kwargs) - self._cat_count = len(inputs) def infer_shape(self) -> bool: """ @@ -50,3 +50,28 @@ def infer_shape(self) -> bool: s0.insert(dim, sumLen) self.outputs(0).shape = s0 return True + + +class IRStack(IRFwOperation): + def __init__(self, signature: str, inputs: List[IRTensor], name: str, dim: int): + # torch.stack(inputs:List[Tensor], dim:int) -> Tensor + assert len(inputs) > 0 + + super().__init__(name, signature, len(inputs), 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update({"dim": dim}) + + def infer_shape(self) -> bool: + dim = self.kwargs['dim'] + tensors : List[IRTensor] = self.inputs(None) # None for all inputs + + # `stack` requires all input tensors to have the same shape + if len(set(t.shape for t in tensors)) != 1: + return False + + shp : list = tensors[0].shape.copy() + shp.insert(dim, len(tensors)) + self.outputs(0).shape = shp + return True + diff --git a/cube/graph/operator/function/creators.py b/cube/graph/operator/function/creators.py new file mode 100644 index 00000000..1cf76eef --- /dev/null +++ b/cube/graph/operator/function/creators.py @@ -0,0 +1,41 @@ +from copy import copy +from typing import List + +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + +class IRZeros(IRFwOperation): + def __init__(self, signature: str, shape: List[int], name: str): + + # The shape information must be statically known integer values + assert all(isinstance(dim, int) for dim in shape) + + super().__init__(name, signature, input_length=0, output_length=1) + self._shape = copy(shape) + + def infer_shape(self) -> bool: + self.outputs(0).shape = self._shape + return True + + +#class IRNewTensor(IRFwOperation): +# def __init__(self, signature: str, data, name:str): +# pass +# def infer_shape(self) -> bool: +# pass + + +# `aten::to` has several overloading, which one should be dispatched is determined by the argument types +# See +# https://github.com/pytorch/pytorch/blob/483bb4f0cb273f42f655aa30eee6a1fbbaba69b0/torch/csrc/jit/runtime/register_prim_ops.cpp#L1057 +# https://github.com/pytorch/pytorch/blob/483bb4f0cb273f42f655aa30eee6a1fbbaba69b0/torch/csrc/jit/runtime/register_prim_ops.cpp#L2215 +class IRToTensor(IRFwOperation): + def __init__(self, signature: str, inputs, name:str): + super().__init__(name, signature, input_length=1, output_length=1) + self.set_input(0, inputs[0]) + + def infer_shape(self) -> bool: + self.outputs(0).shape = self.inputs(0).shape + return True + + diff --git a/cube/graph/operator/function/function.py b/cube/graph/operator/function/function.py index 1e5f53d6..483e95cb 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/operator/function/function.py @@ -1,7 +1,7 @@ -from typing import Iterable, List, Optional, Union, Dict +from typing import Any, Iterable, List, Optional, Tuple, Union, Dict import string import copy -from cube.graph.operator.function.cat import IRCat +import numpy from cube.ir.cten import IRTensor from cube.graph.operator.function.einops import EinDim, IREinops @@ -10,6 +10,11 @@ from cube.graph.operator.function.pad import IRPad from cube.graph.operator.function.scripteinops import IRScriptEinOps from cube.graph.operator.function.customops import IRCustomOps +from cube.graph.operator.function.cat import IRCat, IRStack +from cube.graph.operator.function.creators import IRToTensor, IRZeros +from cube.graph.operator.function.select import IRSelect, IRSlice +from cube.graph.operator.function.scatter import IRSelectScatter +from cube.graph.operator.function.repeat import IRRepeat def _create_eshape(shape: List[int], iterator: Optional[Iterable] = None, @@ -82,6 +87,56 @@ def BatchLinear(signature, inputs): return IREinops(signature, annos, inputs, 'bmm') +def Zeros(signature, + inputs: Tuple[ List[int], Optional[Any], Optional[Any], 'ErasedDevice', Optional[bool] ]): + # zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + # + # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of + # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. + + shape, dtype, layout, _erased_device, pin_memory = inputs + + # TODO parameters to support, currently they are all None + assert dtype is None + assert layout is None + assert pin_memory is None + + for dim, i in enumerate(shape): + if not isinstance(dim, int) and not dim >= 0: + raise RuntimeWarning(f"The {i}-th component of the shape must be non-negative integer") + return IRZeros(signature, shape, 'zeros') + + +def NewTensor(signature, + inputs: Tuple[ list, Optional[Any], 'ErasedDevice', bool ]): + # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor + # + # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of + # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. + + data, dtype, _erased_device, requires_grad = inputs + + # TODO parameters to support, currently they are all None + assert dtype is None + assert requires_grad == False + + arr = numpy.array(data) + + # ints or floats of any precision, e.g. i8, i64, f16, f32 + # and the specified array is regular/non-ragged. + # Otherwise NumPy would decide the element type as _o_bject. + if not arr.dtype.kind in ['i','f']: + raise RuntimeError("The specified data to create new tensor must be ints or floats") + + # TODO temporarily fake creation with Zeros + shape = list(arr.shape) + return IRZeros(signature, shape, 'tensor') + +def ToTensor(signature, + inputs: Tuple[ IRTensor, ... ]): + tensors = inputs[0:1] + return IRToTensor(signature, tensors, 'to') + def Add(signature, inputs): if len(inputs) == 2: kwargs = {} @@ -198,10 +253,9 @@ def Mul(signature, inputs): def Div(signature, inputs): lhs, rhs = inputs - if isinstance(lhs, int) and isinstance(rhs, int): - # only if both operands are int, do we do floor division. - return lhs // rhs - elif isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + # For `aten::div` we always do floating division, even operands are both ints. + # TorchScript would dispatch frontend `a // b` to another op `aten::floordiv`. return lhs / rhs annos = [ @@ -228,6 +282,36 @@ def Div(signature, inputs): return IREinops(signature, annos, inputs, 'div') +def FloorDiv(signature, inputs): + lhs, rhs = inputs + + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + return lhs // rhs + + annos = [ + '*, 1 -> *', + '1, * -> *', + '*, * -> *', + ] + # broadcast + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ + len(lhs.shape) == len(rhs.shape): + if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): + # TODO: support spatial partitioning on broadcast dim + lshape = _create_eshape(lhs.shape) + rshape = copy.copy(lshape) + oshape = copy.copy(lshape) + for dim in range(len(lhs.shape)): + if lhs.shape[dim] < rhs.shape[dim]: + oshape[dim] = rshape[dim] + lshape[dim] = str(lhs.shape[dim]) + elif lhs.shape[dim] > rhs.shape[dim]: + oshape[dim] = lshape[dim] + rshape[dim] = str(rhs.shape[dim]) + annos = [_create_anno([lshape, rshape], [oshape])] + return IREinops(signature, annos, inputs, 'floordiv') + + def Pow(signature, inputs): lhs, rhs = inputs @@ -258,16 +342,59 @@ def Pow(signature, inputs): return IREinops(signature, annos, inputs, 'pow') +# if both operands are scalars, returns bool. +# if one operand is a tensor, returns a broadcasted tensor with dtype being bool. +def comparison_einops(f, name, signature, inputs): + # f : (Scalar, Scalar) -> bool + assert len(inputs) == 2 + lhs, rhs = inputs + + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + return f(lhs, rhs) + + annos = [ + '*, 1 -> *', + '1, * -> *', + '*, * -> *', + ] + # broadcast + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ + len(lhs.shape) == len(rhs.shape): + if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): + # TODO: support spatial partitioning on broadcast dim + lshape = _create_eshape(lhs.shape) + rshape = copy.copy(lshape) + oshape = copy.copy(lshape) + for dim in range(len(lhs.shape)): + if lhs.shape[dim] < rhs.shape[dim]: + oshape[dim] = rshape[dim] + lshape[dim] = str(lhs.shape[dim]) + elif lhs.shape[dim] > rhs.shape[dim]: + oshape[dim] = lshape[dim] + rshape[dim] = str(rhs.shape[dim]) + annos = [_create_anno([lshape, rshape], [oshape])] + return IREinops(signature, annos, inputs, name) + + def Neg(signature, inputs): - annos = ['* -> *'] - tensor = inputs[0:1] - if len(inputs) == 2: + if len(inputs) == 1: + kwargs = {} + elif len(inputs) == 2: # adapt for newest pytorch version approximate = inputs[1] - return IREinops(signature, annos, tensor, 'neg', - approximate=approximate) + kwargs = {'approximate': approximate} + + inputs = inputs[0:1] else: - return IREinops(signature, annos, tensor, 'neg') + raise RuntimeError("The number of inputs must be 1 or 2") + + arg, = inputs + if isinstance(arg, (int, float)): + assert not('approximate' in kwargs) + return -arg + + annos = ['* -> *'] + return IREinops(signature, annos, inputs, 'neg', **kwargs) def Sin(signature, inputs): annos = ['* -> *'] @@ -570,12 +697,53 @@ def Pad(signature, inputs): def Cat(signature, inputs): """ torch.cat(inputs: List[Tensor], dim: int) -> Tensor + + e.g. cat(tensor([2,3]), tensor([2,3])).shape == [4,3] """ tensors : List[IRTensor] dim : int tensors, dim = inputs return IRCat(signature, tensors, 'cat', dim=dim) +def Stack(signature, inputs: Tuple[List[IRTensor], int]): + """ + torch.stack(inputs: List[Tensor], dim: int) -> Tensor + + e.g. stack(tensor([2,3]), tensor([2,3])).shape == [2,2,3] + """ + tensors, dim = inputs + return IRCat(signature, tensors, 'stack', dim=dim) + +def Select(signature, inputs: Tuple[IRTensor, int, int]): + """ + torch.select(self:Tensor, dim:int, index:int) -> Tensor + """ + tensor, dim, index = inputs + return IRSelect(signature, [tensor], 'select', dim, index) + +def Slice(signature, inputs: Tuple[IRTensor, int, Optional[int], Optional[int], int]): + """ + aten::slice(input:Tensor, dim:int, start:Optional[int], end:Optional[int], step:int) -> Tensor + """ + tensor, dim, start, end, step = inputs + return IRSlice(signature, [tensor], 'slice', dim, start, end, step) + +def SelectScatter(signature, inputs:Tuple[IRTensor, IRTensor, int, int]): + """ + torch.select_scatter(self:Tensor, input:Tensor, dim:int, index:int) -> Tensor + """ + self, input, dim, index = inputs + return IRSelectScatter(signature, [self, input], 'scatter_select', dim, index) + + +def Repeat(signature, inputs:Tuple[IRTensor, List[int]]): + """ + torch.repeat(tensor:Tensor, repeats: List[int]) -> Tensor + """ + tensor, repeats = inputs + return IRRepeat(signature, [tensor], 'repeat', repeats) + + def ScriptEinOps(signature, inputs): """ apply_for_scriptable_torch(recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str) -> torch.Tensor: diff --git a/cube/graph/operator/function/repeat.py b/cube/graph/operator/function/repeat.py new file mode 100644 index 00000000..79705e35 --- /dev/null +++ b/cube/graph/operator/function/repeat.py @@ -0,0 +1,40 @@ +from typing import List, Optional, Tuple +import itertools + +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + +class IRRepeat(IRFwOperation): + """ + torch.repeat(tensor:Tensor, repeats: List[int]) -> Tensor + """ + + def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, repeats:List[int]): + assert len(inputs) == 1 + assert isinstance(repeats, list) + assert all(isinstance(r, int) for r in repeats) + + super().__init__(name, signature, 1, 1) + self.set_input(0, inputs[0]) + self.kwargs.update({"repeats": repeats}) + + def infer_shape(self) -> bool: + shp_self : List[int] = self.inputs(0).shape + if len(shp_self) == 0: + return False + + repeats : List[int] = self.kwargs["repeats"] + + # This API broadcasts the input tensor if the specified `repeats:list` is longer than the shape. + s1 = shp_self.copy() + s1.reverse() + s2 = repeats.copy() + s2.reverse() + + # Multiply from the end + shp = [d1 * d2 for d1, d2 in itertools.zip_longest(s1, s2, fillvalue=1)] + shp.reverse() + + self.outputs(0).shape = shp + return True + diff --git a/cube/graph/operator/function/scatter.py b/cube/graph/operator/function/scatter.py new file mode 100644 index 00000000..fa11f264 --- /dev/null +++ b/cube/graph/operator/function/scatter.py @@ -0,0 +1,60 @@ +from copy import copy +from typing import List, Optional, Tuple + +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + +class IRSelectScatter(IRFwOperation): + """ + torch.select_scatter(self:Tensor, input:Tensor, dim:int, index:int) -> Tensor + + identical to: + ``` + x = self.copy() # Assume N-d tensor. + view = x.select(dim, index) # View and input are (N-1)-d tensors. + view.copy_(input) # See REMARK! + return x + ``` + + REMARK: + Unlike the `copy_` API in the identical code snippet above, + `select_scatter` (as well as other scatter family APIs) are NOT broadcastable, + namely it requires the `input` tensor to embed is an exactly (N-1)-dimensional tensor. + + But in-place Python code like + ``` + self[index] = input + ``` + involves broadcasting, so `input` can has any broadcastable shapes to `self.shape.pop(dim)`, + including being scalars. + """ + + def __init__(self, signature: str, inputs:Tuple[IRTensor, IRTensor], name: str, dim:int, index:int): + assert len(inputs) == 2 + + super().__init__(name, signature, 2, 1) + self.set_input(0, inputs[0]) + self.set_input(1, inputs[1]) + self.kwargs.update({"dim": dim, "index": index}) + + def infer_shape(self) -> bool: + shp_self : List[int] = self.inputs(0).shape + if len(shp_self) == 0: + return False + + shp_input = self.inputs(1).shape + + if len(shp_input) == 0: + print("The 0-length input shape is ambiguous, may be uninferrable or just of a 0-d tensor") + elif len(shp_input) > 0: + dim: int = self.kwargs["dim"] + copy_shp = shp_self.copy() + copy_shp.pop(dim) + if copy_shp != shp_input: + raise RuntimeError(f"self shape {shp_self} and input shape {shp_input} with dim={dim} mismatch") + + s2 = copy(shp_self) + self.outputs(0).shape = s2 + return True + + diff --git a/cube/graph/operator/function/select.py b/cube/graph/operator/function/select.py new file mode 100644 index 00000000..ee22f15a --- /dev/null +++ b/cube/graph/operator/function/select.py @@ -0,0 +1,78 @@ +from copy import copy +from typing import List, Optional, Tuple + +from cube.graph.operator.operator import IRFwOperation +from cube.ir.cten import IRTensor + +class IRSelect(IRFwOperation): + """ + torch.select(input:Tensor, dim:int, index:int) -> Tensor + """ + def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, dim:int, index:int): + assert len(inputs) == 1 + + super().__init__(name, signature, 1, 1) + self.set_input(0, inputs[0]) + self.kwargs.update({"dim": dim, "index": index}) + + def infer_shape(self) -> bool: + s : List[int] = self.inputs(0).shape + if len(s) == 0: + return False + + dim = self.kwargs["dim"] + + s2 = copy(s) + s2.pop(dim) + self.outputs(0).shape = s2 + + return True + + +class IRSlice(IRFwOperation): + """ + aten::slice(input:Tensor, dim:int=0, start:Optional[int]=None, end:Optional[int]=None, step:int=1) -> Tensor + """ + + def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, + dim:int, start:Optional[int], end:Optional[int], step:int): + assert len(inputs) == 1 + + super().__init__(name, signature, 1, 1) + self.set_input(0, inputs[0]) + self.kwargs.update({"dim": dim, "start": start, "end": end, "step": step}) + + def infer_shape(self) -> bool: + s : List[int] = self.inputs(0).shape + if len(s) == 0: + return False + + dim : int = self.kwargs["dim"] + start : Optional[int] = self.kwargs["start"] + end : Optional[int] = self.kwargs["end"] + step : int = self.kwargs["step"] + + if start is None: + start = 0 + if end is None: + end = 2 ** 64 + + dim_len = s[dim] + + def clip(offset): + if offset < 0: + offset += dim_len + return min(dim_len, max(0, offset)) + + start = clip(start) + end = clip(end) + + sliced_dim_len = len(range(start, end, step)) + s2 = s.copy() + s2[dim] = sliced_dim_len + self.outputs(0).shape = s2 + + return True + + +# torch.gather(input:Tensor, dim:int, index:LongTensor, *, sparse_grad=False, out=None) -> Tensor diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index bf19c035..19f95ea4 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -12,11 +12,21 @@ def __init__(self): self._vars: List[dict[str, Any]] = list() self._var_stack: List[str] = list() - def push(self): + def push(self, inherit_from_top=False): """ This should only be called when step in a module + + Args: + inherit_from_top (bool): + whether to make all already defined variables in the top frame + accessible to the evaluation procedure + (e.g. references to such variables won't cause VarNotFound exception). """ - self._vars.append(OrderedDict()) + if inherit_from_top: + assert len(self._vars) > 0 + self._vars.append(self._vars[-1].copy()) + else: + self._vars.append(OrderedDict()) def pop(self): """ @@ -35,9 +45,13 @@ def add_var(self, var_name: str, val: Any, graph_arg: int = -1): val: variable content graph_arg (int): indicate whether it is an argument of the graph. + If == -1, is not a graph arg. - If >= 0, is a graph arg, will try to find - val from previous frame + + If >= 0, is a graph arg, will try to find val from previous frame, + by associating the names of the formal parameters of the callee function + and the names of the arguments passed-in. + (then look up the values of the arguments in the previous frame) """ if not isinstance(var_name, str): raise RuntimeError("Expected var_name is str") diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index eaede8bf..44f14665 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Union import torch +import operator from functools import partial import cube.graph.operator.function as function @@ -70,6 +71,11 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # torch aten + # creators + __ttemplate('zeros'): function.Zeros, + __ttemplate('tensor'): function.NewTensor, + __ttemplate('to'): function.ToTensor, + __ttemplate('add') : function.Add, __ttemplate('sub') : function.Sub, @@ -78,8 +84,15 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('div') : function.Div, + __ttemplate('floordiv') : function.FloorDiv, + __ttemplate('neg'): function.Neg, + __ttemplate('gt'): partial(function.comparison_einops, operator.gt, 'gt'), + __ttemplate('lt'): partial(function.comparison_einops, operator.lt, 'lt'), + __ttemplate('ge'): partial(function.comparison_einops, operator.ge, 'ge'), + __ttemplate('le'): partial(function.comparison_einops, operator.le, 'le'), + __ttemplate('pow'): function.Pow, __ttemplate('sin'): function.Sin, @@ -100,11 +113,22 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('conv3d'): function.Conv3D, + __ttemplate('select'): function.Select, + + __ttemplate('slice'): function.Slice, + + #pytorch1.11 + __ttemplate('select_scatter'): function.SelectScatter, + + __ttemplate('repeat'): function.Repeat, + #pytorch1.11 __ttemplate('linear'): function.Linear, __ttemplate('cat'): function.Cat, + __ttemplate('stack'): function.Stack, + #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 14d09052..23ec8cc0 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -1,3 +1,4 @@ +from multiprocessing.synchronize import Condition import torch import enum import re @@ -27,6 +28,7 @@ class ScriptNodeKind(enum.Enum): PrimTupleUnpack = 9 PrimPythonOp = 10 PrimDevice = 11 # erased + PrimLoop = 12 class ScriptModuleParser: @@ -146,6 +148,8 @@ def ntype(node: torch._C.Node): return ScriptNodeKind.AtenOp if node.kind() == 'prim::If': return ScriptNodeKind.PrimIf + if node.kind() == 'prim::Loop': + return ScriptNodeKind.PrimLoop if node.kind() == 'prim::ListConstruct': return ScriptNodeKind.PrimListConstruct if node.kind() == 'prim::TupleConstruct': @@ -184,6 +188,8 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] return ScriptModuleParser.parse_prim_list_unpack_node(node, module, frame) if node_type == ScriptNodeKind.PrimPythonOp: return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) + if node_type == ScriptNodeKind.PrimLoop: + return ScriptModuleParser.parse_prim_loop_node(node, module, frame) # TODO bother assigning all ignored prim functions new NodeKinds? if node_type == ScriptNodeKind.PrimDevice: @@ -274,18 +280,16 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: frame.add_var(outputs[0].debugName(), container[index]) return [] - # aten::tensor(elems: List^{n:Nat}[T], dtype:Optional[ScalarType], device:Device, requires_grad:bool) -> Tensor - elif fsig == 'torch.tensor': - # originally 'aten::tensor' - var_name = outputs[0].debugName() - elems, dtype, erased_device, requires_grad = input_val - - # dtype may be None, in PyTorch it's to infer dtype from 'elems'. - if dtype == None: - dtype = DType2IRDType.map(torch.get_default_dtype()) + elif fsig == 'torch.__range_length': + lo, hi, step = input_val + rng_len = ScriptModuleParser.aten___range_length(lo, hi, step) + frame.add_var(outputs[0].debugName(), rng_len) + return [] - ir_tensor = IRFullTensor(shape=[len(elems)], name=var_name, requires_grad=requires_grad, dtype=dtype) - frame.add_var(var_name, ir_tensor) + elif fsig == 'torch.__derive_index': + index, start, step = input_val + derived = ScriptModuleParser.aten___derive_index(index, start, step) + frame.add_var(outputs[0].debugName(), derived) return [] # May be a symbolic object i.e. IRFwOperation, @@ -448,6 +452,125 @@ def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: """ raise NotImplementedError("Dynamic Graph is not supported yet") + @staticmethod + def parse_prim_loop_node(node, module, frame) -> List[IRFwOperation]: + """ + Inputs: + %max_iter_count : int + %init_condition : bool + %x_1 : T_1 + ... + %x_N : T_N + %dependencies : R + + Syntax: + %y_1 : T_1, ..., %y_N : T_N = prim::Loop(%max_iter_count, %init_condition, %x_1, ..., %x_N) + block0(%iter_step : int, %p_1 : T_1, ..., %p_N : T_N): + ... + %r_1 : T_1 = some_func(%x_1, %dependencies) + ... + %r_N : T_N = ... + %next_condition : bool = ... + -> (%next_condition, %r_1, ..., %r_N) + + REMARK: + - Outer variables (%dependencies) may be referenced in the Loop-body/subgraph, this is AKA _free variables_. + In contrast, a standalone TorchScript function/graph will have all variables, + including its parameters, defined within its scope. + + In other words, functions/graphs have no free variables. + + Semantics: + - The next step is evaluated if both (%iter_step < %max_iter_count) and (%next_condition == True). + - (%y_1, ..., %y_N) are bound to the last (%r_1, ..., %r_N) returned. + If no step is ever evaluated, they are (%x_1, ..., %x_N). + """ + inputs : List[torch._C.Value] = list(node.inputs()) + outputs : List[torch._C.Value] = list(node.outputs()) + + in_vals = [frame.get_var(input.debugName()) for input in inputs] + + max_iter_count, init_condition = in_vals[0:2] + if not isinstance(max_iter_count, int): + raise RuntimeError("The upper bound of the loop must be able to be statically evaluated") + if not isinstance(init_condition, bool): + raise RuntimeError("The init condition of the loop must be able to be statically evaluated") + + # type: Subgraph + loop_block : torch._C.Block = list(node.blocks())[0] + + body_in_vars : torch._C.Value = list(loop_block.inputs()) + iter_step_var = body_in_vars[0] + p_vars = body_in_vars[1:] + + body_out_vars = list(loop_block.outputs()) + + step = 0 + condition = init_condition + loop_carried_vals = in_vals[2:] + + all_ir_nodes : List[IRFwOperation] = [] + + while step < max_iter_count and condition: + + # create the context for evaluating the body, and bind loop variables %iter_step, %p_1, ... + + # Defensively we don't let variables defined in the Loop body subgraph pollute the outer graph. + # So we'd better duplicate all existing variables into a new frame (namely 'inherit_from_top'), + # and clean up this new frame after the interpretation of the whole loop execution. + frame.push(inherit_from_top=True) + + frame.add_var(iter_step_var.debugName(), step) + + # At the evaluation of each step, we cannot call Frame's 'push_param(var_name)' and 'add_var(var_name, val, graph_arg=N)' APIs, + # because all intermediate loop-carried values do not have syntactically static names. + # + # For the sake of isolation, we don't bind carried values onto {y_i}s variables and overwrite the binding + # during evaluation, either. + assert len(p_vars) == len(loop_carried_vals) + for p_var, carried_val in zip(p_vars, loop_carried_vals): + frame.add_var(p_var.debugName(), carried_val) + + # evaluate the body block + for subnode in loop_block.nodes(): + subnode : torch._C.Node + + sub_ir_nodes : List[IRFwOperation] = ScriptModuleParser.parse_node(subnode, module, frame) + + for ir_node in sub_ir_nodes: + try: + ret = ir_node.infer_shape() + if not ret: + print(f'warning: {ir_node} cannot infer shape') + except Exception: + raise RuntimeError( + f"====== Shape Infer Error ====\n\n\n" + f"IR Node: {ir_node}\n\n" + f"Module:\n{module.code}\n\n" + f"Node:\n{node}\n" + f"====== Shape Infer Error ====\n\n\n" + ) + + all_ir_nodes += sub_ir_nodes + + # rebind for next step and clean-ups + step_result_vals = [frame.get_var(body_out_var.debugName()) for body_out_var in body_out_vars] + condition = step_result_vals[0] + loop_carried_vals = step_result_vals[1:] + step += 1 + + frame.pop() + + if not isinstance(condition, bool): + raise RuntimeError(f"At the {step}-th step the condition is not evaluated to a constant bool") + + assert len(outputs) == len(loop_carried_vals) + for output, y_val in zip(outputs, loop_carried_vals): + frame.add_var(output.debugName(), y_val) + + return all_ir_nodes + + @staticmethod def parse_prim_list_construct_node(node, module, frame: Frame) -> List[None]: """ @@ -518,4 +641,37 @@ def flatten(smodule, depth=0): submodule = getattr(smodule, label) ScriptModuleParser.flatten(submodule, depth+1) + @staticmethod + def aten___range_length(lo, hi, step): + """ + aten::__range_length(int lo, int hi, int step) -> int + + Python loops + ``` + for i in range(L, H, S): + use(i) + ``` + will be translated to TorchScript + ``` + _c = aten::__range_length(L, H, S) + for _k < _c: + i = aten::__derive_index(k, L, S) + use(i) + ``` + """ + if not (isinstance(lo, int) and isinstance(hi, int) and isinstance(step, int)): + raise RuntimeError("All inputs to __range_length must be statically evaluated") + if step == 0: + raise RuntimeError("Step cannot be zero") + + return len(range(lo, hi, step)) + + @staticmethod + def aten___derive_index(index, start, step): + if not (isinstance(index, int) and isinstance(start, int) and isinstance(step, int)): + raise RuntimeError("All inputs to __derive_index must be statically evaluated") + + return start + index * step + + diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py index 7c932998..7142fb90 100644 --- a/examples/wrf/wrf2.py +++ b/examples/wrf/wrf2.py @@ -174,12 +174,15 @@ def ac_step(self, dtau:float, O2_ = torch.zeros(O2.shape, device=O2.device) mu2_ = torch.zeros(mu2.shape, device=mu2.device) + for i in range(1, O2.shape[0] + 1): - O2_[-i] = i * self.delta_z * dpi2 + \ + sub = i * self.delta_z * dpi2 + \ (self.dx(self.px(U2_)) + self.dy(self.py(V2_)) - R_mu)[-i:].view( -1, self.ny, self.nx).sum(0) * self.delta_z + O2_ = O2_.select_scatter(sub, dim=0, index=-i) + for i in range(mu2.shape[0]): - mu2_[i] = pi2 + mu2_ = mu2_.select_scatter(pi2, dim=0, index=i) # self.O2_ = O2_ @@ -383,14 +386,18 @@ def solve_tridiagonal_(self, # forward sweep for i in range(1, d.shape[0]): w = l[i - 1] / d[i - 1] - d[i] = d[i] - w * u[i - 1] - b[i] = b[i] - w * b[i - 1] + + d_i = d[i] - w * u[i - 1] + b_i = b[i] - w * b[i - 1] + + d = d.select_scatter(d_i, dim=0, index=i) + b = b.select_scatter(b_i, dim=0, index=i) # backward substitution x = torch.zeros(b.shape, device=b.device) - x[-1] = b[-1] / d[-1] + x.select_scatter(b[-1] / d[-1], dim=0, index=-1) for i in range(x.shape[0] - 2, -1, -1): - x[i] = (b[i] - u[i] * x[i + 1]) / d[i] + x.select_scatter( (b[i] - u[i] * x[i + 1]) / d[i], dim=0, index=i) return x diff --git a/requirements.txt b/requirements.txt index f32bdd38..e83a8c22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ einops -matplotlib \ No newline at end of file +matplotlib +pytest \ No newline at end of file diff --git a/tests/test_prim_loop.py b/tests/test_prim_loop.py new file mode 100644 index 00000000..8f4c14ed --- /dev/null +++ b/tests/test_prim_loop.py @@ -0,0 +1,156 @@ +# run tests: +# pytest ./tests/test_prim_loop.py + +import pytest +import torch +import cube + +from cube.graph.parser.frame import Frame +from cube.graph.parser.parser import ScriptModuleParser +from cube.graph.tensor import IRFullTensor +from cube import ir + +# Stub objects: +# - A stub object for 'ScriptModule' should have members: +# -- entry_method_normally_forward: Stub[ScriptMethod] +# -- code: str # only to avoid AttributeError, could be empty +# and optionally: +# -- other_script_method: Stub[ScriptMethod] +# +# - A stub object for 'ScriptMethod' should have fields: +# -- graph: torch._C.Graph + +class StubScriptMethod(object): + def __init__(self, graph: torch._C.Graph) -> None: + self.graph = graph + +# REMARK: +# 'torch._C.parse_ir' will change local variable names into unique-number ID, e.g. +# graph(%p: int): +# %local = ... +# becomes: +# graph(%p: int): +# %1 = ... + +def out_var_name0(g): + return next(g.outputs()).debugName() + +def test_simple_unroll_evaluation(): + g = torch._C.parse_ir(''' + graph(%a : int): + %ub : int = prim::Constant[value=100]() + %truth : bool = prim::Constant[value=1]() + %z : int = prim::Loop(%ub, %truth, %a) + block0(%step : int, %p : int): + %r : int = aten::add(%step, %p) + -> (%truth, %r) + return (%z) + ''') + frame = Frame() + frame.push() + frame.add_var("a", 0) + + for node in g.nodes(): + ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) + assert len(ir_nodes) == 0 + + # %z becomes %3 + assert frame.get_var(out_var_name0(g)) == (0+99)*100//2 + +def test_unroll_with_structural_info(): + g = torch._C.parse_ir(''' + graph(%a : Tensor): + %ub : int = prim::Constant[value=3]() + %truth : bool = prim::Constant[value=1]() + %i0 : int = prim::Constant[value=0]() + %z : Tensor = prim::Loop(%ub, %truth, %a) + block0(%step : int, %p : Tensor): + %ts : Tensor[] = prim::ListConstruct(%p, %p) # at each step, double the 0-th dim + %r : Tensor = aten::cat(%ts, %i0) + -> (%truth, %r) + return (%z) + ''') + frame = Frame() + frame.push() + + t_a = IRFullTensor(shape=[2,3]) + frame.add_var("a", t_a) + + all_ir_nodes = [] + for node in g.nodes(): + ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) + all_ir_nodes += ir_nodes + + assert len(all_ir_nodes) == 3 + + p = t_a + for i in range(3): + ir_node = all_ir_nodes[i] + in_val_1, in_val_2 = ir_node.inputs() + assert in_val_1 == p + assert in_val_2 == p + + out_val = ir_node.outputs(0) + assert out_val.shape == [ 2**(i+2) , 3] + + p = out_val + + +def test_nested_unroll(): + ''' + The outer loop has 3 steps, and the inner loop has 3 steps too. + ''' + + subp = torch._C.parse_ir(''' + graph(%self: int, %a : Tensor): + %ub : int = prim::Constant[value=3]() + %truth : bool = prim::Constant[value=1]() + %i0 : int = prim::Constant[value=0]() + %z : Tensor = prim::Loop(%ub, %truth, %a) + block0(%step : int, %p : Tensor): + %ts : Tensor[] = prim::ListConstruct(%p, %p) # at each step, double the 0-th dim + %r : Tensor = aten::cat(%ts, %i0) + -> (%truth, %r) + return (%z) + ''') + main = torch._C.parse_ir(''' + graph(%self: int, %a : Tensor): + %ub : int = prim::Constant[value=3]() + %truth : bool = prim::Constant[value=1]() + %z : Tensor = prim::Loop(%ub, %truth, %a) + block0(%step : int, %p : Tensor): + %r : Tensor = prim::CallMethod[name="subp"](%self, %p) + -> (%truth, %r) + return (%z) + ''') + + class StubScriptModule(object): + def __init__(self) -> None: + self.main = StubScriptMethod(main) + self.subp = StubScriptMethod(subp) + module = StubScriptModule() + + frame = Frame() + frame.push() + + t_a = IRFullTensor(shape=[2,3]) + frame.add_var("a", t_a) + + all_ir_nodes = [] + for node in main.nodes(): + ir_nodes = ScriptModuleParser.parse_node(node, module, frame) + all_ir_nodes += ir_nodes + + assert len(all_ir_nodes) == 9 + + p = t_a + for i in range(9): + ir_node = all_ir_nodes[i] + in_val_1, in_val_2 = ir_node.inputs() + assert in_val_1 == p + assert in_val_2 == p + + out_val = ir_node.outputs(0) + assert out_val.shape == [ 2**(i+2) , 3] + + p = out_val \ No newline at end of file From d89d9d097f24c804811bd6c9547e1084705768aa Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Fri, 27 May 2022 22:10:40 +0800 Subject: [PATCH 0834/1892] remove bad import --- cube/graph/parser/parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 23ec8cc0..93384d06 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -1,4 +1,3 @@ -from multiprocessing.synchronize import Condition import torch import enum import re From 99a127647e2f6ca3272588e7cb0451f0a1a15cfc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 28 May 2022 21:11:46 +0800 Subject: [PATCH 0835/1892] differentiable communication support --- cube/codegen/codegen.py | 98 +---- cube/compiler.py | 11 +- cube/execplan/execplan.py | 68 ++-- cube/execplan/planpass/fusion.py | 578 ++++----------------------- cube/execplan/planpass/grouping.py | 83 +--- cube/execplan/planpass/torchadapt.py | 220 ---------- cube/graph/adapter/gen.py | 11 +- cube/graph/adapter/layout.py | 7 +- cube/graph/adapter/prim.py | 195 ++++----- cube/graph/graph.py | 113 ++++-- cube/runtime/adapter/__init__.py | 3 +- cube/runtime/adapter/collectives.py | 17 + cube/runtime/adapter/distnn.py | 285 ------------- cube/runtime/adapter/nn.py | 290 ++++++++++++++ cube/runtime/adapter/transform.py | 20 +- 15 files changed, 657 insertions(+), 1342 deletions(-) delete mode 100644 cube/execplan/planpass/torchadapt.py delete mode 100644 cube/runtime/adapter/distnn.py create mode 100644 cube/runtime/adapter/nn.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 9d52b3a2..622ab9a0 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -12,7 +12,7 @@ from cube.graph.tensor import IRSubTensor from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation from cube.graph.adapter.adapter import IRWeightReducer, IRAdapter -from cube.graph.graph import IRGraph +from cube.graph.graph import IRGraph, IRSegment from cube.execplan import ExectuionPlan @@ -125,16 +125,15 @@ def gen(self, device: int, outfile=None, attach=False) -> str: self.init_comm_groups() # parse graph body - for node in self.execplan.sequence(device): - if isinstance(node, IRGraph): + for node in self.execplan.seq(device): + if isinstance(node, IRSegment): # skip backward ir graph - if all([isinstance(n, IRBpOperation) for n in node.nodes()]): + if not node.forward: continue - self.emit_graph_call(node) + self.emit_segment_call(node) elif isinstance(node, IRFwOperation): self.emit_op_call(node) elif isinstance(node, IRAdapter): - node = node.dispatch(device) self.emit_adapter_call(node) elif isinstance(node, IRWeightReducer): self.emit_reducer_init(node) @@ -211,11 +210,17 @@ def emit_node_declare(self, node: IRCell): self.symbols.create(self.tensor_naming(output)) return - def emit_graph_call(self, graph: IRGraph): + def emit_segment_call(self, graph: IRGraph): + """ + Emit IRSegment code + """ for node in graph.nodes(): - if isinstance(node, IRBpOperation): - raise RuntimeError("IRBpOperation is not expected in GenModel") - self.emit_op_call(node) + if isinstance(node, IRFwOperation): + self.emit_op_call(node) + elif isinstance(node, IRAdapter): + self.emit_adapter_call(node) + else: + raise RuntimeError(f"unexpected type {type(node)} in forward graph:\n{graph.extra_repr()}") def emit_op_call(self, node: IRFwOperation): """ @@ -246,9 +251,7 @@ def emit_adapter_call(self, node: IRAdapter): Emit adapter call """ assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" - # rank = node.device[0] for prim in node.prims: - # print(f'generating prim: {prim}') if len(prim.inputs()) == 1: itensors = self.tensor_naming(prim.inputs()[0]) else: @@ -260,62 +263,6 @@ def emit_adapter_call(self, node: IRAdapter): outputs = self.return_naming(prim.outputs()) code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' self.forward_region.append(code) - # # emit select - # if isinstance(prim, SelectPrim): - # sign = 'cube.runtime.adapter.select({tensor}, {indmap}, {valmap})' - # input = self.tensor_naming(prim.tensor) - # output = self.tensor_naming(prim.output) - # valmap = (prim.valmap.idx, prim.valmap.chunk_num) - # code = f'{output} = {sign.format(tensor=input, indmap=prim.indmap, valmap=valmap)}' - # self.forward_region.append(code) - # # emit move - # elif isinstance(prim, MovePrim): - # send_sign = 'cube.runtime.adapter.send({tensor}, {send_rank})' - # recv_sign = 'cube.runtime.adapter.recv({shape}, {from_rank}, {dtype})' - # tensor = self.tensor_naming(prim.tensor) - # # send - # if rank == prim.from_rank: - # code = f'{send_sign.format(tensor=tensor, send_rank=prim.to_rank)}' - # self.forward_region.append(code) - # # recv - # elif rank == prim.to_rank: - # output = self.tensor_naming(prim.tensor) - # dtype = self.dtype_map(prim.dtype) - # code = f'{tensor} = {recv_sign.format(shape=prim.shape, from_rank=prim.from_rank, dtype=dtype)}' - # self.forward_region.append(code) - # # emit merge - # elif isinstance(prim, MergePrim): - # sign = 'cube.runtime.adapter.merge({tensors}, {concat}, {add})' - # inputs = self.tuple_naming(prim.tensors) - # output = self.tensor_naming(prim.output) - # code = f'{output} = {sign.format(tensors=inputs, concat=prim.concat, add=prim.add)}' - # self.forward_region.append(code) - # # emit collectives - # elif isinstance(prim, CollectivePrim): - # sign = 'cube.runtime.adapter.{ctype}({input_tensors}, {output_shapes}, {output_dtypes}, {group})' - # inputs = self.tuple_naming(prim.inputs) - # outputs = self.return_naming(prim.outputs) - # dtypes = None - # if prim.output_dtypes is not None: - # dtypes = [self.dtype_map(dtype) for dtype in prim.output_dtypes] - # dtypes = self.tuple_naming(dtypes) - # body = sign.format( - # ctype=prim.ctype.value, - # input_tensors = inputs, - # output_shapes = prim.output_shapes, - # output_dtypes = dtypes, - # group=prim.group - # ) - # code = f'{outputs} = {body}' - # self.forward_region.append(code) - # else: - # raise TypeError(f"Unkown primitive types {type(prim)} of Adapter") - # requires grad generation - # sign = '{output} = {output}.contiguous().requires_grad_()' - # for output in node.outputs(): - # if isinstance(output, IRSubTensor) and output.requires_grad: - # code = sign.format(output=self.tensor_naming(output)) - # self.forward_region.append(code) def emit_reducer_init(self, node: IRWeightReducer): # reducer init interface @@ -397,18 +344,14 @@ def gen(self, device: int, outfile=None, attach=False) -> str: gencode = copy.copy(self.init_code) self.vars = VarManager() - device_nodes = self.execplan.sequence(device) - for idx, node in enumerate(device_nodes): - if isinstance(node, IRAdapter): - node = node.dispatch(device) - device_nodes[idx] = node + device_nodes = self.execplan.seq(device) def refcount(tensor, node) -> int: idx = device_nodes.index(node) refcnt = 0 for ref_node in device_nodes[idx+1:]: - if isinstance(ref_node, IRGraph): - if all([isinstance(rnode, IRFwOperation) for rnode in ref_node.nodes()]): + if isinstance(ref_node, IRSegment): + if ref_node.forward: if tensor in ref_node.inputs(): refcnt += 1 else: @@ -464,10 +407,9 @@ def emit_node(self, node: IRCell, name: str) -> List[str]: inputs = self.tuple_naming(inputs) outputs = self.return_naming(outputs) - if isinstance(node, IRGraph): - is_backward = all([isinstance(n, IRBpOperation) for n in node.nodes()]) + if isinstance(node, IRSegment): # emit forward - if not is_backward: + if node.forward: body = fsign.format(model=f'model.{name}', inputs=inputs) code = f'{outputs} = {body}' # emit backward diff --git a/cube/compiler.py b/cube/compiler.py index d08cc51e..990edbd6 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -14,6 +14,7 @@ from cube.logics.translator import LogicTranslator from cube.execplan import ExectuionPlan +from cube.execplan.planpass.fusion import DiffFusion from cube.execplan.planpass.grouping import Grouping from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen @@ -171,16 +172,16 @@ def decorator(fn: Callable) -> Callable: execplan = ExectuionPlan(graph) # plan pass for communication optimization + start = time.time() + execplan = DiffFusion.apply(execplan) + span = time.time() - start + print('> planpass on diff-fusion operations: {:.2f} s'.format(span)) + start = time.time() execplan = Grouping.apply(execplan) span = time.time() - start print('> planpass on grouping operations: {:.2f} s'.format(span)) - # start = time.time() - # execplan = P2PFusion.apply(execplan) - # span = time.time() - start - # print('> planpass on p2pfusion operations: {:.2f} s'.format(span)) - # start = time.time() # execplan = GroupingAdapter.apply(execplan) # span = time.time() - start diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index d3e2af69..5b19ffb6 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional import copy import numpy as np @@ -13,18 +13,38 @@ class ExectuionPlan: def __init__(self, graph: IRGraph): - if not isinstance(graph, IRGraph): - raise TypeError("Expected a list of ScheduleUnit") - self.graph = graph - self.device_seq = dict() + assert isinstance(graph, IRGraph), "Expected an IRGraph" + self._graph = graph + self._seq: Dict[int, List[IRCell]] = dict() + + # execution sequence for each device for node in graph.nodes(): if len(node.device) == 0: raise RuntimeError(f"Node device not set: {node}") for device in node.device: - if device not in self.device_seq: - self.device_seq[device] = [node] - else: - self.device_seq[device].append(node) + if device not in self._seq: + self._seq[device] = [] + self._seq[device].append(node) + + # adapter dispatch + for devid in self.devices(): + adapters = [node for node in self.at(devid) if isinstance(node, IRAdapter)] + while len(adapters) > 0: + fadapter = adapters[0] + badapter: Optional[IRAdapter] = fadapter.mirror + fnode = fadapter.dispatch(devid) + fidx = self.at(devid).index(fadapter) + self.at(devid)[fidx] = fnode + if badapter: + bnode = badapter.dispatch(devid) + IRCell.make_pair(fnode, bnode) + bidx = self.at(devid).index(badapter) + self.at(devid)[bidx] = bnode + # remove un-dispatched adapter + adapters.pop(0) + if badapter: + adapters.remove(badapter) + # check whether graph output is replicated across device # FIXME: should use adapter to generate communication for # traning logic output @@ -39,41 +59,43 @@ def __init__(self, graph: IRGraph): if len(devices) != 0: raise NotImplementedError("Require return values of training logic is replicated across nodes.") + @property + def graph(self) -> IRGraph: + return self._graph + def devices(self) -> List[int]: """ Get device set """ - devices = list(self.device_seq.keys()) + devices = list(self._seq.keys()) devices.sort() return devices - def sequence(self, device_id: int) -> List[IRCell]: + def seq(self, devid: int) -> List[IRCell]: """ - Get a copy of execution sequence for device id + Get a view of execution sequence for device id Note changing the list content will not change the execution plan. """ - if device_id not in self.device_seq: - return list() - return copy.copy(self.device_seq[device_id]) + assert devid in self._seq, f"device id {devid} not exists" + return copy.copy(self._seq[devid]) - def at(self, device_id: int) -> List[IRCell]: + def at(self, devid: int) -> List[IRCell]: """ Access the sequence for device id Note changing the list content will change the execution plan. """ - if device_id not in self.device_seq: - return list() - return self.device_seq[device_id] + assert devid in self._seq, f"device id {devid} not exists" + return self._seq[devid] - def set(self, device_id: int, seq: List[IRCell]): + def set(self, devid: int, seq: List[IRCell]): """ Set device sequence """ if not all([isinstance(su, IRCell) for su in seq]): raise TypeError("Expected a list of Cell") - self.device_seq[device_id] = seq + self._seq[devid] = seq def analyze(self, map2time: Optional[Callable] = None, @@ -253,6 +275,6 @@ def __repr__(self): dscp = f'Execution Plan ({self.graph.name}):\n' for devid in self.devices(): dscp += f'====> Device {devid}:\n' - for node in self.sequence(devid): - dscp += f'{node.module_repr()}\n' + for node in self._seq(devid): + dscp += f'{node}\n' return dscp diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index b9b9c771..b30eaece 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -1,522 +1,98 @@ from typing import List -import copy - -# debug only -# import sys -# if tid == tensor_id: print(f'out line: {sys._getframe().f_lineno}') - -from cube.graph.tensor import IRSubTensor, ValueMap from cube.graph.adapter.adapter import IRAdapter -from cube.graph.adapter.adapter import CollectivePrim from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass -# FIXME: all fusions don't consider input order! -# May get incorrect result in some cases. +from cube.graph.adapter.prim import IRAdapterPrim +from cube.graph.adapter.prim import AllReducePrim, AllGatherPrim, ReduceScatterPrim, AllToAllPrim +from cube.graph.adapter.prim import IdentityPrim, ChunkPrim +from cube.graph.adapter.prim import IdentityAllreducePrim, AllReduceIdentityPrim, AllReduceAllReducePrim +from cube.graph.adapter.prim import AllGatherReduceScatterPrim, ReduceScatterAllGatherPrim +from cube.graph.adapter.prim import SplitAllGatherPrim, AllGatherSplitPrim +from cube.graph.adapter.prim import AllToAllAllToAllPrim -# FIXME: all fusions don't check if the communication can be happened at -# the same time - -class P2PFusion(PlanPass): +class DiffFusion(PlanPass): @staticmethod def apply(execplan: ExectuionPlan) -> ExectuionPlan: - adapters = list() - for node in execplan.graph.nodes(): - if isinstance(node, IRAdapter): - adapters.append(node) - matchers = [ - P2PFusion.allreduce_matcher, - P2PFusion.allgather_matcher, - P2PFusion.reducescatter_matcher, - P2PFusion.broadcast_matcher, - ] - for matcher in matchers: - matcher(execplan, adapters) - # update adapter devices - for node in execplan.graph.nodes(): - if isinstance(node, IRAdapter): - node.update_device() - for devid in execplan.devices(): - for node in execplan.sequence(devid): - if isinstance(node, IRAdapter): - if devid not in node.device: - execplan.at(devid).remove(node) - return execplan - - @staticmethod - def allreduce_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): - """ - Allreduce semantic: - - Given a list of adapters: - 1). [Num] each adapter has different one input and same one output - 2). [Dev] inputs/outputs among adapters are from different devices - 3). [Dev] adapters have same device. adapters# is same to device set. - 4). [Indmap] inputs among adapters has same index-map with output. - 5). [Valmap] inputs have parital value-map. Output has full value-map - """ - outputs, groups = P2PFusion.group_by_output(all_adapters) - for tid in outputs: - cond = True - adapters: List[IRAdapter] = groups[tid] - # condition 1) - if not P2PFusion._check_multi_inputs(adapters): - continue - if not P2PFusion._check_same_inputs(adapters): - continue - # condition 2) - if not P2PFusion._check_different_inputs_devices(adapters, among=False): - continue - if not P2PFusion._check_different_outputs_devices(adapters, among=True): - continue - # condition 3) - for adapter in adapters: - if len(adapters) != len(adapter.device): - cond = False - break - if not cond: continue - # condition 4) - for adapter in adapters: - if not P2PFusion._check_indmap_same(adapter.inputs() + adapter.outputs()): - cond = False - break - if not cond: continue - # condition 5) - for adapter in adapters: - if not P2PFusion._check_valmap_no_overlap(adapter.inputs()): - cond = False - break - if not cond: continue - for adapter in adapters: - if adapter.outputs(0).valmap != ValueMap(0, 1): - cond = False - break - if not cond: continue - # generate - print(f'generating allreduce for tensor: {outputs[tid]} ...') - for adapter in adapters: - device = adapter.odevice(0) - input_idx = adapter.idevice().index(device) - inputs = [adapter.inputs(input_idx)] - coll = CollectivePrim( - ctype = CollectivePrim.Type.AllReduce, - device = device, - group = adapter.device, - inputs = inputs, - outputs = adapter.outputs(), - ) - adapter._prims = [coll] - for adapter in adapters: - all_adapters.remove(adapter) - - @staticmethod - def allgather_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): - """ - Allgather semantic: - - Given a list of adapters: - 1). [Num] each adapter has same multiple inputs and same one output - 2). [Dev] inputs/outputs among adapters are from different device. - 3). [Dev] adapters have same device. adapters# is same to device set. - 4). [Indmap] inputs inside one adapter are not overlapped - 5). [Valmap] each input value-map is same with output valuemap - """ - outputs, groups = P2PFusion.group_by_output(all_adapters) - for tid in outputs: - adapters: List[IRAdapter] = groups[tid] - cond = True - # condition 1) - if not P2PFusion._check_multi_inputs(adapters): - continue - if not P2PFusion._check_same_inputs(adapters): - continue - # condition 2) - if not P2PFusion._check_different_inputs_devices(adapters, among=False): - continue - if not P2PFusion._check_different_outputs_devices(adapters, among=True): - continue - # condition 3) - for adapter in adapters: - if len(adapters) != len(adapter.device): - cond = False - break - if not cond: continue - # condition 4) - for adapter in adapters: - if not P2PFusion._check_indmap_no_overlap(adapter.inputs()): - cond = False - break - if not cond: continue - # condition 5) - for adapter in adapters: - if not P2PFusion._check_valmap_same(adapter.inputs() + adapter.outputs()): - cond = False - break - if not cond: continue - # gen allgather - print(f'generating allgather for tensor: {outputs[tid]} ...') - for adapter in adapters: - device = adapter.odevice(0) - input_idx = adapter.idevice().index(device) - inputs = [adapter.inputs(input_idx)] - coll = CollectivePrim( - ctype = CollectivePrim.Type.AllGather, - device = device, - group = adapter.device, - inputs = inputs, - input_shapes = None, - input_dtypes = None, - outputs = adapter.inputs(), - output_shapes = None, - output_dtypes = None, - ) - # merge prim still keeps, remove select and move prims - prims = [coll] + adapter.prims(select=False, move=False, coll=False) - adapter._prims = prims - for adapter in adapters: - all_adapters.remove(adapter) - - @staticmethod - def reducescatter_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): - """ - ReduceScatter semantic: - - Given a list of adapters: - 1). [Num] each adapter has same multiple input and different one output - 2). [Dev] inputs/outputs among adapters are from different devices - 3). [Dev] adapters have same device. adapters# is same to device set - 4). [Indmap] inputs of each adapter have same index-map - 5). [Indmap] outputs among adapters have different index-map - 6). [Valmap] inputs of each adapter have different partial val-map - 7). [Valmap] outputs among adapters have same Full val-map - """ - inputs, groups = P2PFusion.group_by_input(all_adapters) - for tids in inputs: - adapters: List[IRAdapter] = groups[tids] - cond = True - # cond 1) - otids = [adapter.outputs(0)._id for adapter in adapters] - if len(set(otids)) != len(adapters): - continue - # cond 2) - if not P2PFusion._check_different_inputs_devices(adapters, among=False): - continue - if not P2PFusion._check_different_outputs_devices(adapters, among=True): - continue - # cond 3) - for adapter in adapters: - if len(adapters) != len(adapter.device): - cond = False - break - if not cond: continue - # cond 4) - for adapter in adapters: - if not P2PFusion._check_indmap_same(adapter.inputs()): - cond = False - break - if not cond: continue - # cond 5) - outputs = [adapter.outputs(0) for adapter in adapters] - if not P2PFusion._check_indmap_no_overlap(outputs): - continue - # cond 6) - for adapter in adapters: - if not P2PFusion._check_valmap_no_overlap(adapter.inputs()): - cond = False - break - if not cond: continue - # cond 7) - for adapter in adapters: - if adapter.outputs(0).valmap != ValueMap(0, 1): - cond = False - break - if not cond: continue - # gen reduce-scatter - print(f'generating reduce-scatter for tensor: {tids} ...') - all_select_prims = list() - for adapter in adapters: - all_select_prims += adapter.prims(move=False, merge=False, coll=False) - for adapter in adapters: - device = adapter.odevice(0) - sprims = [prim for prim in all_select_prims if prim.device == device] - if len(sprims) != len(adapters): - raise RuntimeError(f"got {len(sprims)} (!={len(adapters)}) select prims for reduce-scatter") - inputs = [sprim.output for sprim in sprims] - coll = CollectivePrim( - ctype = CollectivePrim.Type.ReduceScatter, - device = device, - group = adapter.device, - inputs = inputs, - outputs = adapter.outputs(), - ) - prims = sprims + [coll] - adapter._prims = prims - for adapter in adapters: - all_adapters.remove(adapter) - - @staticmethod - def broadcast_matcher(execplan: ExectuionPlan, all_adapters: List[IRAdapter]): - """ - Broadcast semantic: - - Given a list of adapters: - 1). [Num] each adapter has same one input and one output. input = output. - 2). [Dev] inputs among adapters are from a same device. - 3). [Dev] outputs among adapters are from different devices - """ - outputs, groups = P2PFusion.group_by_output(all_adapters) - for tid in outputs: - adapters: List[IRAdapter] = groups[tid] - cond = True - # note send can also be broadcast. We skip this case - if len(adapters) <= 2: - continue - # cond 1) - if not P2PFusion._check_same_inputs(adapters): - continue - if not P2PFusion._check_single_inputs(adapters): - continue - for adapter in adapters: - if adapter.inputs(0) != adapter.outputs(0): - cond = False - break - if not cond: continue - # cond 2) - root_device = set() - for adapter in adapters: - root_device.update(P2PFusion._get_input_devices(adapter)) - if len(root_device) != 1: - continue - # cond 3) - if not P2PFusion._check_different_outputs_devices(adapters, among=True): - continue - # gen broadcast - print(f'generating broadcast for tensor: {outputs[tid]} ...') - # put root rank to the first - root = list(root_device)[0] - group = set() - for adapter in adapters: - group.update(P2PFusion._get_output_devices(adapter)) - group = [root] + list(group) - # input - tensor = adapters[0].inputs(0) - - prims = list() - for device in group: - inputs = [tensor] if device == root else None - output_shapes = [tensor.shape] - output_dtypes = [tensor.dtype] - coll = CollectivePrim( - ctype = CollectivePrim.Type.Broadcast, - device = [device], - group = group, - inputs = inputs, - outputs = [tensor], - output_shapes = output_shapes, - output_dtypes = output_dtypes - ) - prims.append(coll) - - # add aditional adapter to root node - root_adapter = IRAdapter( - prims = [prims[0]], - inputs=[tensor], idevices=[[root],], - outputs=[tensor], odevices=[[root],] - ) - # insert into graph and execution plan - index = min([execplan.graph.nodes().index(n) for n in adapters]) - execplan.graph._nodes.insert(index, root_adapter) - seq = [node for node in execplan.graph.nodes() if root in node.device] - execplan.set(root, seq) - for adapter in adapters: - device = adapter.odevice(0)[0] - prim = prims[group.index(device)] - adapter._prims = [prim] - - for adapter in adapters: - all_adapters.remove(adapter) + def is_forward_adapter(adapter: IRAdapter): + return all(not t.is_grad() for t in adapter.inputs()) - # Utilities - @staticmethod - def group_by_output(adapters: List[IRAdapter]): - """ - Group the adapters by same output tensor - """ - tensors = dict() # tensor_id -> tensor - groups = dict() # tensor_id -> List[IRAdapter] - for adapter in adapters: - if len(adapter.outputs()) != 1: - raise RuntimeError("Expected only one output") - tensor = adapter.outputs(0) - tid = tensor._id - if tid not in tensors: - tensors[tid] = tensor - groups[tid] = list() - groups[tid].append(adapter) - return tensors, groups + cnt = 0 + for devid in execplan.devices(): + for node in execplan.seq(devid): + if isinstance(node, IRAdapter) and is_forward_adapter(node): + ret = DiffFusion.nnfuse(node) + cnt = cnt+1 if ret else cnt + print(f'successfully generate {cnt} differentiable adapters') + return execplan @staticmethod - def group_by_input(adapters: List[IRAdapter]): + def nnfuse(fadapter: IRAdapter) -> bool: """ - Group the adapters by same input tensor(s) - """ - tensors = dict() # Tuple[tensor_id] -> tensor - groups = dict() # Tuple[tensor_id] -> List[IRAdapter] - for adapter in adapters: - tids = [tensor._id for tensor in adapter.inputs()] - tids.sort() - tids = tuple(tids) - if tids not in tensors: - tensors[tids] = tensors - groups[tids] = list() - groups[tids].append(adapter) - return tensors, groups + Fuse the forward adapter with its backward adapter into differentiable + communications. Note this is an inplacement update - @staticmethod - def _check_same_inputs(adapters: List[IRAdapter]): + Return: + success: boolean """ - Check if the inputs are same among adapters - """ - input_ids = list() - for adapter in adapters: - tids = [t._id for t in adapter.inputs()] - tids.sort() - input_ids.append(tids) - ninputs = [len(tids) for tids in input_ids] - # number of inputs not same - if len(set(ninputs)) != 1: + if not isinstance(fadapter.mirror, IRAdapter): return False - # input ids not same - for tids in zip(*input_ids): - if len(set(tids)) != 1: - return False - return True - - @staticmethod - def _check_multi_inputs(adapters: List[IRAdapter]): - for adapter in adapters: - if len(adapter.inputs()) <= 1: - return False - return True - - @staticmethod - def _check_single_inputs(adapters: List[IRAdapter]): - for adapter in adapters: - if len(adapter.inputs()) != 1: - return False - return True - - @staticmethod - def _get_input_devices(adapter: IRAdapter) -> List[int]: - """ - Return sorted device list for all inputs - """ - device = set() - for idevice in adapter.idevice(): - device.update(idevice) - device = list(device) - device.sort() - return device - - @staticmethod - def _get_output_devices(adapter: IRAdapter) -> List[int]: - """ - Return sorted device list for all outputs - """ - device = set() - for odevice in adapter.odevice(): - device.update(odevice) - device = list(device) - device.sort() - return device - - @staticmethod - def _check_different_inputs_devices(adapters: List[IRAdapter], among: bool): - if among: - adapter_devices = list() - for adapter in adapters: - device = P2PFusion._get_input_devices(adapter) - adapter_devices.append(tuple(device)) - if len(set(adapter_devices)) != len(adapters): - return False - return True - else: - for adapter in adapters: - device = P2PFusion._get_input_devices(adapter) - # assume each tensor is attached to one deivce - if len(device) != len(adapter.inputs()): - return False - return True - - @staticmethod - def _check_different_outputs_devices(adapters: List[IRAdapter], among: bool): - if among: - adapter_devices = list() - for adapter in adapters: - device = set() - for odevice in adapter.odevice(): - device.update(odevice) - device = list(device) - device.sort() - adapter_devices.append(tuple(device)) - if len(set(adapter_devices)) != len(adapters): - return False - return True - else: - for adapter in adapters: - device = set() - for odevice in adapter.odevice(): - device.update(odevice) - # assume each tensor is attached to one deivce - if len(device) != len(adapter.outputs()): - return False - return True - - @staticmethod - def _check_indmap_same(tensors: List[IRSubTensor]): - if len(tensors) == 0: - return True - indmap = tensors[0].indmap - for tensor in tensors[1:]: - if tensor.indmap != indmap: - return False - return True - - @staticmethod - def _check_indmap_no_overlap(tensors: List[IRSubTensor]): - if len(tensors) == 0: - return True - for idx1 in range(len(tensors) - 1): - for idx2 in range(idx1 + 1, len(tensors)): - t1 = tensors[idx1] - t2 = tensors[idx2] - if t1.indmap.overlap(t2.indmap): - return False - return True - - @staticmethod - def _check_valmap_same(tensors: List[IRSubTensor]): - if len(tensors) == 0: - return True - valmap = tensors[0].valmap - for tensor in tensors[1:]: - if tensor.valmap != valmap: - return False - return True - - @staticmethod - def _check_valmap_no_overlap(tensors: List[IRSubTensor]): - if len(tensors) == 0: + badapter: IRAdapter = fadapter.mirror + fprims, bprims = fadapter.prims, badapter.prims + + def is_allreduce(prims: List[IRAdapterPrim]) -> bool: + return len(prims) == 1 and all(isinstance(prim, AllReducePrim) for prim in prims) + + def is_identity(prims: List[IRAdapterPrim]) -> bool: + return len(prims) == 1 and all(isinstance(prim, IdentityPrim) for prim in prims) + + def is_redsca(prims: List[IRAdapterPrim]) -> bool: # reduce-scatter + return len(prims) == 1 and all(isinstance(prim, ReduceScatterPrim) for prim in prims) + + def is_allgather(prims: List[IRAdapterPrim]) -> bool: + return len(prims) == 1 and all(isinstance(prim, AllGatherPrim) for prim in prims) + + def is_chunk(prims: List[IRAdapterPrim]) -> bool: + return len(prims) == 1 and all(isinstance(prim, ChunkPrim) for prim in prims) + + def is_alltoall(prims: List[IRAdapterPrim]) -> bool: + return len(prims) == 1 and all(isinstance(prim, AllToAllPrim) for prim in prims) + + prims = None + # allreduce-identity + if is_allreduce(fprims) and is_identity(bprims): + prims = [AllReduceIdentityPrim(p.inputs(), p.outputs(), **p.kwargs) for p in fprims] + # identity-allreduce + elif is_identity(fprims) and is_allreduce(bprims): + prims = [IdentityAllreducePrim(p.inputs(), p.outputs(), **bprims[0].kwargs) for p in fprims] + # allreduce-allreduce + elif is_allreduce(fprims) and is_allreduce(bprims): + prims = [AllReduceAllReducePrim(p.inputs(), p.outputs(), **p.kwargs) for p in fprims] + # allgather-reducescatter + elif is_allgather(fprims) and is_redsca(bprims): + prims = [AllGatherReduceScatterPrim(p.inputs(), p.outputs(), **p.kwargs) for p in fprims] + # reducescatter-allgather + elif is_redsca(fprims) and is_allgather(bprims): + prims = [ReduceScatterAllGatherPrim(p.inputs(), p.outputs(), **p.kwargs) for p in fprims] + # allgather-chunk + elif is_allgather(fprims) and is_chunk(bprims): + prims = [AllGatherSplitPrim(p.inputs(), p.outputs(), **p.kwargs) for p in fprims] + # chunk-allgather + elif is_chunk(fprims) and is_allgather(bprims): + prims = [SplitAllGatherPrim(p.inputs(), p.outputs(), **p.kwargs) for p in fprims] + # all-to-all + elif is_alltoall(fprims) and is_alltoall(bprims): + prims = [AllToAllAllToAllPrim(p.inputs(), p.outputs(), **p.kwargs) for p in fprims] + + if prims is not None: + fadapter.prims = prims + badapter.prims = prims + fadapter.differentiable = True + badapter.differentiable = True return True - for idx1 in range(len(tensors) - 1): - for idx2 in range(idx1 + 1, len(tensors)): - t1 = tensors[idx1] - t2 = tensors[idx2] - if t1.valmap.overlap(t2.valmap): - return False - return True + return False diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 139dc03a..b89add5c 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -22,22 +22,15 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: fgroups, bgroups = Grouping.group(execplan) for devid in execplan.devices(): for fpieces, bpieces in zip(fgroups[devid], bgroups[devid]): - fsubgraph = graph.subgraph(fpieces) - fsubgraph.device = devid + fsubgraph = graph.segment(fpieces) if bpieces is not None: - bsubgraph = graph.subgraph(bpieces) - bsubgraph.device = devid + bsubgraph = graph.segment(bpieces) IRCell.make_pair(fsubgraph, bsubgraph) subgraphs = [fsubgraph] if bpieces is None else [fsubgraph, bsubgraph] for subgraph in subgraphs: - pieces = subgraph.nodes() - # update graph: replace the nodes with the subgraph - idx = graph.nodes().index(pieces[0]) - graph._nodes.insert(idx, subgraph) - for node in pieces: - graph._nodes.remove(node) # update execution plan: replace the nodes with the subgraph - idx = execplan.sequence(devid).index(pieces[0]) + pieces = subgraph.nodes() + idx = execplan.seq(devid).index(pieces[0]) execplan.at(devid).insert(idx, subgraph) for node in pieces: execplan.at(devid).remove(node) @@ -55,13 +48,21 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: Returns: Tuple: (fgroups, bgroups) """ + def is_forward_adapter(adapter: IRAdapter) -> bool: + return all(not t.is_grad() for t in adapter.inputs()) + fgroups, bgroups = dict(), dict() for devid in execplan.devices(): fgroups[devid], bgroups[devid] = list(), list() fpieces, bpieces = list(), list() - seq = execplan.sequence(devid) - fnodes = [fnode for fnode in seq if isinstance(fnode, IRFwOperation)] - have_backward = all([fnode.mirror in seq for fnode in fnodes]) + seq = execplan.seq(devid) + fnodes = [] + for fnode in seq: + if isinstance(fnode, IRFwOperation): + fnodes.append(fnode) + if isinstance(fnode, IRAdapter) and fnode.differentiable and is_forward_adapter(fnode): + fnodes.append(fnode) + have_backward = all(fnode.mirror in seq for fnode in fnodes) # training if have_backward: bnodes = [fnode.mirror for fnode in fnodes] @@ -112,57 +113,3 @@ def consecutive(seq: List[IRCell], pieces: List[IRCell], node: IRCell): if idx != max(pidx) + 1 and idx != min(pidx) - 1: return False return True - - -class GroupingAdapter(PlanPass): - - @staticmethod - def apply(execplan: ExectuionPlan) -> ExectuionPlan: - for devid in execplan.devices(): - groups: List[List[IRAdapter]] = GroupingAdapter.consecutive( - execplan.sequence(devid)) - for adapters in groups: - if len(adapters) <= 1: - continue - sprims, tprims, mprims = list(), list(), list() - inputs, idevices = list(), list() - outputs, odevices = list(), list() - for adapter in adapters: - sprims += adapter.prims(move=False, merge=False, coll=False) - tprims += adapter.prims(select=False, merge=False) - mprims += adapter.prims(select=False, move=False, coll=False) - for idx, input in enumerate(adapter.inputs()): - if devid in adapter.idevice(idx): - if input not in inputs: - inputs.append(input) - idevices.append(adapter.idevice(idx)) - for idx, output in enumerate(adapter.outputs()): - if devid in adapter.odevice(idx): - if output not in outputs: - outputs.append(output) - odevices.append(adapter.odevice(idx)) - prims = sprims + tprims + mprims - fused_adapter = IRAdapter(prims, - inputs = inputs, idevices = idevices, - outputs = outputs, odevices = odevices) - start = execplan.sequence(devid).index(adapters[0]) - end = execplan.sequence(devid).index(adapters[-1]) - for _ in range(end - start + 1): - execplan.at(devid).pop(start) - execplan.at(devid).insert(start, fused_adapter) - return execplan - - @staticmethod - def consecutive(seq: List[IRCell]) -> List[List[IRAdapter]]: - group = list() - curr = list() - curr_idx = -1 - for idx, node in enumerate(seq + [None]): - if isinstance(node, IRAdapter) and idx == curr_idx + 1: - curr.append(node) - else: - if len(curr) != 0: - group.append(curr) - curr = list() - curr_idx = idx - return group diff --git a/cube/execplan/planpass/torchadapt.py b/cube/execplan/planpass/torchadapt.py deleted file mode 100644 index b1cf460e..00000000 --- a/cube/execplan/planpass/torchadapt.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -PyTorch Adapter for multi-branch reference - -If a tensor is the input for multiple operators: - - the gradient of this tensor will be value splitted for each op-backward. - -However, in pytorch, the gradient is accumulated by default, this -will cause inconsistent behaviour for transoform SU when the referred -operators are on the same device or not. - -For the situation when the referred operators are on different devices: - Nothing happens - -For the situation when the referred operators are on same device: - The gradient will change to match `auto accumulation` semantics. - For first referred op: grad will be set to ValueMap(idx, num_referred_devices) - For other referred op: grad is set to None -""" - -from typing import Dict - -from cube.execplan import ExectuionPlan -from cube.graph.tensor import IRSubTensor, ValueMap -from cube.schedule.adapter.transform import IRTensorTransform -from cube.schedule.su import SUType, ScheduleUnit -from cube.execplan.planpass.planpass import PlanPass - - -class TorchRefAdapter(PlanPass): - - @staticmethod - def apply(execplan: ExectuionPlan): - # same device multiple reference - multiref_fsus, multiref_fnodes = TorchRefAdapter.multi_ref_cells(execplan) - for tid in multiref_fsus: - print(f'multi-referred tensor id: {tid}') - for devid in multiref_fsus[tid]: - for fsu in multiref_fsus[tid][devid]: - print(f'dev {devid}: {fsu}') - - - for tid in multiref_fsus: - # check chunk num for each device - total_ops = set() - for devid in multiref_fnodes[tid]: - for op in multiref_fnodes[tid][devid]: - total_ops.add(op._id) - total_ops = list(total_ops) - num_ops = len(total_ops) - # how many ops are computed for each device - dev_ops = dict() - for devid in multiref_fnodes[tid]: - op_index = list() - for op in multiref_fnodes[tid][devid]: - op_index.append(total_ops.index(op._id)) - cnt = len(op_index) - if cnt != 1 and cnt != num_ops: - raise NotImplementedError("Only support even chunk for multi-ref") - dev_ops[devid] = op_index - - for idx, devid in enumerate(multiref_fsus[tid]): - # the value map should be op_num / total_ops - op_index = dev_ops[devid] - if len(op_index) == num_ops: - grad_idx, grad_num = 0, 1 - elif len(op_index) == 1: - grad_idx, grad_num = op_index[0], num_ops - - # the first forward, the last backward - fsu = multiref_fsus[tid][devid][0] - ftensor = None - for input in fsu.inputs(): - if isinstance(input, IRSubTensor): - if input._id == tid: - ftensor = input - break - if ftensor is None: - raise RuntimeError("Internal Error: fsu not found input tensor") - grad = ftensor.parent.grad.select( - indmap = ftensor.indmap, - valmap = ValueMap(grad_idx, grad_num), - shape = ftensor.shape - ) - rm_grad = TorchRefAdapter.set_grad(fsu, ftensor, grad) - TorchRefAdapter.replace_all(execplan, rm_grad, grad, devid) - - # all the other reference place: set grad to none - for fsu in multiref_fsus[tid][devid][1:]: - rm_grad = TorchRefAdapter.set_grad(fsu, ftensor, grad=None) - TorchRefAdapter.replace_all(execplan, rm_grad, None, devid) - - print(execplan) - - # reset select and merge adapters - for devid in execplan.devices(): - for idx, su in enumerate(execplan.sequence(devid)): - if su.stype == SUType.Transform: - ins = [input for input in su.inputs() if input is not None] - ous = [ou for ou in su.outputs() if ou is not None] - if len(ins) < len(su.inputs()) or len(ous) < len(su.outputs()): - for ou in ous: - if ou in ins: - break - trans = IRTensorTransform( - src_tensors=ins, dst_tensors=ous - ) - trans_su = ScheduleUnit([trans], SUType.Transform, name='trans') - trans_su.device = devid - if len(trans_su.outputs()) == 0: - # meaning outputs in inputs - execplan.at(devid).remove(su) - execplan.sugraph.sequence.remove(su) - else: - execplan.at(devid)[idx] = trans_su - suidx = execplan.sugraph.sequence.index(su) - execplan.sugraph.sequence[suidx] = trans_su - execplan.sugraph.reset_dependency(execplan.sugraph.sus()) - return execplan - - @staticmethod - def multi_ref_cells(execplan: ExectuionPlan) -> Dict: - """ - Return: - { - sub_tensor id: - device id: - [forward su or forward node] - } - """ - fnodes = dict() - fsus = dict() - for devid in execplan.devices(): - for fsu in execplan.sequence(devid): - if fsu.stype == SUType.Forward: - for input in fsu.inputs(): - if isinstance(input, IRSubTensor): - tid = input._id - if tid not in fnodes: - fnodes[tid] = dict() - fsus[tid] = dict() - if devid not in fnodes[tid]: - fnodes[tid][devid] = list() - fsus[tid][devid] = list() - fsus[tid][devid].append(fsu) - for node in fsu.nodes(): - if input in node.inputs(): - fnodes[tid][devid].append(node) - multiref_fnodes = dict() - multiref_sus = dict() - for tid in fnodes: - for devid in fnodes[tid]: - if len(fnodes[tid][devid]) != 1: - multiref_sus[tid] = fnodes[tid] - multiref_fnodes[tid] = fsus[tid] - break - return multiref_fnodes, multiref_sus - - - @staticmethod - def set_grad(fsu: ScheduleUnit, input: IRSubTensor, grad): - """ - Return removed grad - """ - if not isinstance(fsu, ScheduleUnit) or fsu.stype != SUType.Forward: - raise TypeError("Require SU to be forward SU") - # forward SU - findex = fsu.inputs().index(input) - fsu.inputs(findex).grad = grad - if not len(fsu.nodes()) == 1: - raise RuntimeError("TorchAdapt should call before merge") - fnode = fsu.nodes(0) - findex = fnode.inputs().index(input) - fnode.inputs(findex).grad = grad - # backward SU - bsu = fsu.mirror - bindex = bsu.inputs().index(input) - bin = bsu.inputs(bindex) - try: - gindex = bsu.outputs().index(bin.grad) - except ValueError: - raise RuntimeError( - (f"Internal Error: cannot find given grad in bsu: {bsu}:\n" - f"gradient given tensor: {bin}, grad: {bin.grad}") - ) - removed_grad = bin.grad - bin.grad = grad - bsu.set_output(gindex, grad) - return removed_grad - - @staticmethod - def replace_all(execplan: ExectuionPlan, src: IRSubTensor, dst, devid: int): - for su in execplan.sequence(devid): - # pair removement for p2p will already remove su - if su not in execplan.at(devid): - continue - rm_su = None - if src in su.inputs(): - if len(su.inputs()) == 1 and dst is None: - execplan.at(devid).remove(su) - execplan.sugraph.sequence.remove(su) - rm_su = su - else: - index = su.inputs().index(src) - su.set_input(index, dst) - if src in su.outputs(): - if len(su.outputs()) == 1 and dst is None: - execplan.at(devid).remove(su) - execplan.sugraph.sequence.remove(su) - rm_su = su - else: - index = su.outputs().index(src) - su.set_output(index, dst) - # pair removement - if rm_su is not None and rm_su.stype == SUType.P2P: - mirror = rm_su.mirror - dev = mirror.device[0] - if mirror in execplan.at(dev): - execplan.at(dev).remove(mirror) - execplan.sugraph.sequence.remove(mirror) diff --git a/cube/graph/adapter/gen.py b/cube/graph/adapter/gen.py index d5e5435d..4102ab08 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/adapter/gen.py @@ -127,13 +127,16 @@ def gen_fulltensor(ftensor: IRFullTensor) -> List[IRAdapter]: cdevs = set() for cnode in ftensor.consumers: cdevs.update(cnode.device) + # sharing devices if pdevs == cdevs: return IRAdapterGener.gen_gridlayout(ftensor) + # no-sharing devices # elif len(pdevs.intersection(cdevs)) == 0: # print(f'detect no intersection') # return [] + # general cases warnings.warn('The adapter is generated using inefficient P2P send/recv') fprims, bprims = [], [] @@ -141,14 +144,12 @@ def gen_fulltensor(ftensor: IRFullTensor) -> List[IRAdapter]: fprims += IRAdapterGener.gen_subtensor(subtensor) fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) fadapter.prims = fprims - # print(fadapter.extra_repr()) grad: IRFullTensor = ftensor.grad if grad is not None: for subtensor in grad.ctensors: bprims += IRAdapterGener.gen_subtensor(subtensor) badapter = IRAdapter(grad.ptensors, grad.ctensors) badapter.prims = bprims - # print(badapter.extra_repr()) IRCell.make_pair(fadapter, badapter) if len(fprims) == 0 and len(bprims) == 0: return [] @@ -227,7 +228,6 @@ def gen_gridlayout(ftensor: IRFullTensor) -> List[IRAdapter]: IRCell.make_pair(fadapter, badapter) if len(fprims) == 0 and len(bprims) == 0: return [] - # print('=====') return [fadapter] @staticmethod @@ -268,7 +268,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: continue common = itensor.common(subtensor) common.attach_cell(itensor._cell) - print(f'get common: {common.extra_repr()}') + # print(f'get common: {common.extra_repr()}') intersections.append(common) if common == itensor: continue @@ -284,7 +284,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: # TODO: check union == subtensor if common == subtensor: break - print(intersections) + # print(intersections) # ====== move ===== # tmoved = [] for tensor in intersections: @@ -359,4 +359,3 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: f"SubTensor:\n\t{subtensor.extra_repr()}" ) return prims - \ No newline at end of file diff --git a/cube/graph/adapter/layout.py b/cube/graph/adapter/layout.py index 60499657..e490c2b3 100644 --- a/cube/graph/adapter/layout.py +++ b/cube/graph/adapter/layout.py @@ -155,9 +155,10 @@ def r2d(self, dim: int, chunks: int): otensor._cell = itensor._cell prims = [] for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - ranks = tuple(t.device[0] for t in itensors) - for idx, (itensor, otensor) in enumerate(zip(itensors, otensors)): - prims.append(ChunkPrim(itensor, otensor, dim, ranks)) + prims.append(ChunkPrim(itensors, otensors, dim)) + # ranks = tuple(t.device[0] for t in itensors) + # for idx, (itensor, otensor) in enumerate(zip(itensors, otensors)): + # prims.append(ChunkPrim(itensor, otensor, dim, ranks)) return glayout, prims # ================ solution ============= # diff --git a/cube/graph/adapter/prim.py b/cube/graph/adapter/prim.py index 06e0c61e..e6262c22 100644 --- a/cube/graph/adapter/prim.py +++ b/cube/graph/adapter/prim.py @@ -2,19 +2,22 @@ The primitive used for IRAdapter """ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import copy from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap + # the general adapter primitive class class IRAdapterPrim: - def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): + def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor], **kwargs): self._inputs = list(inputs) self._outputs = list(outputs) self._device = [] self.kwargs = dict() + for arg, val in kwargs.items(): + self.kwargs[arg] = val self.signature = None def inputs(self, idx: Optional[int] = None): @@ -51,8 +54,8 @@ class SpatialPrim(IRAdapterPrim): """ basic class for representing spatial primitives """ - def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): - super().__init__(inputs, outputs) + def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor], **kwargs): + super().__init__(inputs, outputs, **kwargs) self.device = list(set(t.device[0] for t in inputs)) @@ -71,10 +74,8 @@ class CommPrim(IRAdapterPrim): """ communication primitive """ - def __init__(self, - itensors: List[IRSubTensor], - otensors: List[IRSubTensor]): - super().__init__(itensors, otensors) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + super().__init__(itensors, otensors, **kwargs) devices = [] for t in list(itensors) + list(otensors): devices += t.device @@ -109,9 +110,7 @@ def __init__(self, itensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, otensor: IRSubTensor): - super().__init__([itensor], [otensor]) - self.kwargs['indmap'] = indmap - self.kwargs['valmap'] = (valmap.idx, valmap.chunk_num) + super().__init__([itensor], [otensor], indmap=indmap, valmap=(valmap.idx, valmap.chunk_num)) self.signature = f"cube.runtime.adapter.select" def __repr__(self): @@ -119,32 +118,13 @@ def __repr__(self): return dscp -class ChunkPrim(SpatialPrim): - """ - split dimension in n chunks and take idx-th chunk - """ - def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor, - dim: int, ranks: Tuple[int]): - assert itensor.device[0] in ranks, "idx out of scope" - super().__init__([itensor], [otensor]) - self.kwargs['dim'] = dim - self.kwargs['ranks'] = ranks - self.signature = 'cube.runtime.adapter.chunk' - - def __repr__(self) -> str: - chunks = len(self.kwargs['ranks']) - idx = self.kwargs['ranks'].index(self.device[0]) - return f"dev{self.device}: {self.outputs(0)} = split(dim={self.kwargs['dim']}, chunks={chunks}, idx={idx})" - - class MergeDimPrim(SpatialPrim): """ concatenate dimension """ def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, dim: int) -> None: assert all(itensor.device == itensors[0].device for itensor in itensors), "device not same" - super().__init__(itensors, [otensor]) - self.kwargs['dim'] = dim + super().__init__(itensors, [otensor], dim=dim) self.signature = 'cube.runtime.adapter.smerge' def __repr__(self) -> str: @@ -169,9 +149,7 @@ class SendPrim(CommPrim): P2P send prim """ def __init__(self, tensor, dst: int): - super().__init__([tensor], [tensor]) - self.kwargs['dst'] = dst - self.device = tensor.device + super().__init__([tensor], [tensor], dst=dst) self.signature = 'cube.runtime.adapter.send' def __repr__(self) -> str: @@ -183,11 +161,8 @@ class RecvPrim(CommPrim): P2P recv prim """ def __init__(self, tensor: IRSubTensor, src: int): - super().__init__([], [tensor]) - self.kwargs['shape'] = tensor.shape - self.kwargs['dtype'] = 'torch.' + tensor.dtype.value - self.kwargs['src'] = src - self.device = tensor.device + super().__init__([], [tensor], + shape=tensor.shape, dtype='torch.'+tensor.dtype.value, src=src) self.signature = 'cube.runtime.adapter.recv' def __repr__(self) -> str: @@ -199,10 +174,8 @@ class MovePrim(CommPrim): P2P send/recv, non-differentiable """ def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor): - super().__init__([itensor], [otensor]) - self.kwargs['src'] = itensor.device[0] - self.kwargs['dst'] = otensor.device[0] - self.device = itensor.device + otensor.device + assert itensor.device != otensor.device, "no movement detected." + super().__init__([itensor], [otensor], src=itensor.device[0], dst=otensor.device[0]) def dispatch(self, devid: int) -> Union[SendPrim, RecvPrim]: if devid == self.kwargs['src']: @@ -221,10 +194,9 @@ class CollectivePrim(CommPrim): Collective primitive, non-differentiable """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): - super().__init__(itensors, otensors) - self.kwargs['ranks'] = self.device - for arg, val in kwargs.items(): - self.kwargs[arg] = val + super().__init__(itensors, otensors, **kwargs) + if 'ranks' not in self.kwargs: + self.kwargs['ranks'] = self.device def dispatch(self, devid: int) -> Optional[CommPrim]: """ @@ -235,7 +207,7 @@ def dispatch(self, devid: int) -> Optional[CommPrim]: assert devid in self.device, f"device {devid} not applied for this comm primitive" itensors = [itensor for itensor in self.inputs() if devid in itensor.device] otensors = [otensor for otensor in self.outputs() if devid in otensor.device] - prim = CollectivePrim(itensors, otensors, **self.kwargs) + prim = type(self)(itensors, otensors, **self.kwargs) prim.signature = self.signature return prim @@ -244,8 +216,8 @@ class AllReducePrim(CollectivePrim): """ non-differentiable allreduce """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): - super().__init__(itensors, otensors) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + super().__init__(itensors, otensors, **kwargs) self.signature = 'cube.runtime.adapter.all_reduce' def __repr__(self) -> str: @@ -256,8 +228,8 @@ class AllGatherPrim(CollectivePrim): """ non-differentiabl all-to-all """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): - super().__init__(itensors, otensors, dim=dim) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): + super().__init__(itensors, otensors, dim=dim, **kwargs) self.signature = 'cube.runtime.adapter.all_gather' def __repr__(self) -> str: @@ -268,8 +240,8 @@ class ReduceScatterPrim(CollectivePrim): """ non-differential reduce-scatter """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): - super().__init__(itensors, otensors, dim=dim) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): + super().__init__(itensors, otensors, dim=dim, **kwargs) self.signature = 'cube.runtime.adapter.reduce_scatter' def __repr__(self) -> str: @@ -280,104 +252,149 @@ class BroadcastPrim(CollectivePrim): """ non-differential reduce-scatter """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], src: int): - super().__init__(itensors, otensors, src=src) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], src: int, **kwargs): + super().__init__(itensors, otensors, src=src, **kwargs) class ReducePrim(CollectivePrim): """ non-differential reduce prim """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dst: int): - super().__init__(itensors, otensors, dst=dst) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dst: int, **kwargs): + super().__init__(itensors, otensors, dst=dst, **kwargs) class AllToAllPrim(CollectivePrim): """ non-differentiable all-to-all """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idim: int, odim: int): + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idim: int, odim: int, **kwargs): """ itensors: each rank hosts one tensor splitted by idim otensors: each rank hosts one tensor splitted by odim idim != odim """ - super().__init__(itensors, otensors, idim=idim, odim=odim) + super().__init__(itensors, otensors, idim=idim, odim=odim, **kwargs) self.signature = 'cube.runtime.adapter.all_to_all' def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs()} = all_to_all({self.inputs}, idim={self.kwargs['idm']}, odim={self.kwargs['odim']})" + return f"dev{self.device}: {self.outputs()} = all_to_all({self.inputs()}, idim={self.kwargs['idm']}, odim={self.kwargs['odim']})" -class DiffCollectivePrim(CollectivePrim): +class ChunkPrim(CollectivePrim): """ - Differentiable collective primitive + split dimension in n chunks and take idx-th chunk """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): - """ - differentiable collectives - """ - super().__init__(itensors, otensors, **kwargs) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): + super().__init__(itensors, otensors, dim=dim, **kwargs) + self.signature = 'cube.runtime.adapter.chunk' + + def __repr__(self) -> str: + return f"dev{self.device}: {self.outputs()} = split({self.inputs()}, dim={self.kwargs['dim']})" -class AllReduceIdentityPrim(DiffCollectivePrim): +class AllReduceIdentityPrim(AllReducePrim): """ forward: allreduce. backward: identity """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): - super().__init__(itensors, otensors) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + super().__init__(itensors, otensors, **kwargs) + self.signature = 'cube.runtime.adapter.nn.allreduce_identity' + + def __repr__(self) -> str: + return f"dev{self.device}: {self.outputs()} = nn.allreduce_identity({self.inputs()})" -class IdentityAllreducePrim(DiffCollectivePrim): +class IdentityAllreducePrim(AllReducePrim): """ forward: identity backward: allreduce """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor]): - super().__init__(itensors, otensors) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + super().__init__(itensors, otensors, **kwargs) + self.signature = 'cube.runtime.adapter.nn.identity_allreduce' + + def __repr__(self) -> str: + return f"dev{self.device}: {self.outputs()} = nn.identity_allreduce({self.inputs()})" + + +class AllReduceAllReducePrim(AllReducePrim): + """ + forward: allreduce + backward: allreduce + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + super().__init__(itensors, otensors, **kwargs) + self.signature = 'cube.runtime.adapter.nn.allreduce_allreduce' + def __repr__(self) -> str: + return f"dev{self.device}: {self.outputs} = nn.allreduce_allreduce({self.inputs()}" -class ReduceScatterAllGatherPrim(DiffCollectivePrim): + +class ReduceScatterAllGatherPrim(ReduceScatterPrim): """ forward: reduce-scatter backward: all-gather """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): - super().__init__(itensors, otensors, dim=dim) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): + super().__init__(itensors, otensors, dim, **kwargs) + self.signature = 'cube.runtime.adapter.nn.reducescatter_allgather' + + +class AllGatherReduceScatterPrim(AllGatherPrim): + """ + forward: all-gather + backward: reduce-scatter + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): + super().__init__(itensors, otensors, dim, **kwargs) + self.signature = 'cube.runtime.adapter.nn.allgather_reducescatter' -class AllGatherSplitPrim(DiffCollectivePrim): +class AllGatherSplitPrim(AllGatherPrim): """ forward: all-gather backward: split """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): - super().__init__(itensors, otensors, dim=dim) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): + super().__init__(itensors, otensors, dim, **kwargs) + self.signature = 'cube.runtime.adapter.nn.allgather_split' -class SplitAllGatherPrim(DiffCollectivePrim): +class SplitAllGatherPrim(AllGatherPrim): """ forward: split backward: all-gather """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int): - super().__init__(itensors, otensors, dim=dim) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): + super().__init__(itensors, otensors, dim, **kwargs) + self.signature = 'cube.runtime.adapter.nn.allgather_split' + + +class AllToAllAllToAllPrim(AllToAllPrim): + """ + forward: all-to-all + backward: all-to-all + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idim: int, odim: int, **kwargs): + super().__init__(itensors, otensors, idim, odim, **kwargs) + self.signature = 'cube.runtime.adapter.nn.alltoall_alltoall' -class ReduceBroadcastPrim(DiffCollectivePrim): +class ReduceBroadcastPrim(CollectivePrim): """ forward: broadcast backward: reduce """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dst: int): - super().__init__(itensors, otensors, dst=dst) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dst: int, **kwargs): + super().__init__(itensors, otensors, dst=dst, **kwargs) -class BroadcastRedducePrim(DiffCollectivePrim): +class BroadcastRedducePrim(CollectivePrim): """ forward: broadcast backward: reduce """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], src: int): - super().__init__(itensors, otensors, src=src) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], src: int, **kwargs): + super().__init__(itensors, otensors, src=src, **kwargs) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ca388e08..08440b19 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -18,7 +18,47 @@ from cube.algorithm.generics import GenericDistAlgo -__all__ = ['IRGraph'] +class IRSegment(IRCell): + """ + A segment refers to a piece of workload of IRGraph + """ + + def __init__(self, nodes: List[IRCell], inputs: List[IRSubTensor], outputs: List[IRSubTensor]): + self._nodes = nodes + super().__init__('segment', '', len(inputs), len(outputs), init_outputs=False) + for idx, val in enumerate(inputs): + self.set_input(idx, val) + for idx, val in enumerate(outputs): + self.set_output(idx, val) + # setup device + device = set() + for node in nodes: + device.update(node.device) + self.device = list(device) + # setup whether forward + fnodes = any(isinstance(n, IRFwOperation) for n in nodes) + bnodes = any(isinstance(n, IRBpOperation) for n in nodes) + assert not (fnodes and bnodes), "An IRSegment cannot have both forward nodes and backward nodes" + self._forward = fnodes + + @property + def forward(self) -> bool: + return self._forward + + def nodes(self, idx: Optional[int] = None) -> Union[IRCell, List[IRCell]]: + if isinstance(idx, int): + return self._nodes[idx] + else: + return copy.copy(self._nodes) + + def __repr__(self): + return f'Segment{self._id}(inputs={self.inputs()}, outputs={self.outputs()})' + + def extra_repr(self) -> str: + dscp = repr(self) + for node in self.nodes(): + dscp += '\n\t' + repr(node) + return dscp class IRGraph(IRCell): @@ -154,44 +194,31 @@ def __call__(self, *args): """ return self.forward(*args) - def subgraph(self, sub_nodes: List[IRCell]): + def segment(self, nodes: List[IRCell]) -> IRSegment: """ - Create a subgraph with sub nodes. + Create a segment (sub-graph) with part of the nodes. Return: - IRGraph - """ - sub_inputs = list() - sub_outputs = list() - for node in sub_nodes: - sub_inputs += node.inputs() - sub_outputs += node.outputs() - remain_inputs = list() - remain_outputs = list() - for node in self.nodes(): - if node in sub_nodes: - continue - remain_inputs += node.inputs() - remain_outputs += node.outputs() - inputs = list() - outputs = list() - for t in sub_inputs: - if isinstance(t, IRSubTensor) and t not in sub_outputs: - if t not in inputs: - inputs.append(t) - for t in sub_outputs: - if isinstance(t, IRSubTensor): - # not consumed or used outside this subgraph - if t not in sub_inputs or t in remain_inputs or t in self.outputs(): - if t not in outputs: - outputs.append(t) - subgraph = IRGraph( - nodes = sub_nodes, - inputs = inputs, - outputs = outputs, - module_name = 'segment' - ) - return subgraph + IRSegment + """ + inputs, outputs = [], [] + for node in nodes: + # update inputs + itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] + for itensor in itensors: + producers = [p for p in itensor.parent.producers if p.device == node.device] + # no producer means a weight + if len(producers) == 0 or any(p not in nodes for p in producers): + inputs.append(itensor) + # update outputs + otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + for otensor in otensors: + consumers = [c for c in otensor.parent.consumers if c.device == node.device] + # no consumer usually means the loss + if len(consumers) == 0 or any(c not in nodes for c in consumers): + outputs.append(otensor) + segment = IRSegment(nodes, inputs, outputs) + return segment def detach(self, node: IRCell, reset_dependency=False) -> int: """ @@ -207,6 +234,8 @@ def detach(self, node: IRCell, reset_dependency=False) -> int: raise KeyError(f"node {node} is not in graph.") index = self._nodes.index(node) self._nodes.pop(index) + if isinstance(node, IRAdapter): + return index for itensor in node.inputs(): if isinstance(itensor, IRSubTensor): itensor.parent.rm_consumer(node) @@ -222,11 +251,14 @@ def attach(self, node: IRCell, index, reset_dependency=False): Attach (insert) a node into current graph at node index. All the used input and output tensors inside the node are - recorded in consumed and produced tensor list. + recorded in consumed and produced tensor list. Adapter node + will not record the consumer and producer. """ if node in self.nodes(): raise KeyError(f"node {node} is already in graph.") self._nodes.insert(index, node) + if isinstance(node, IRAdapter): + return # update consumer for itensor in node.inputs(): if isinstance(itensor, IRSubTensor): @@ -647,10 +679,3 @@ def module_repr(self): return repr(self) -class IRSegment(IRCell): - """ - A segment refers to a piece of workload of IRGraph - """ - - def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRTensor]): - pass diff --git a/cube/runtime/adapter/__init__.py b/cube/runtime/adapter/__init__.py index ed7e9fc0..6eb54da8 100644 --- a/cube/runtime/adapter/__init__.py +++ b/cube/runtime/adapter/__init__.py @@ -1,5 +1,4 @@ from cube.runtime.adapter.collectives import * from cube.runtime.adapter.transform import * - -# reducer +from cube.runtime.adapter import nn from cube.runtime.adapter.reducer import Reducer diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 054f2562..00e2114f 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -160,6 +160,23 @@ def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) - return otensor +def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: + """ + split dimension in n chunks and take idx-th chunk + + ranks (Tuple[int]): the order of split tensor. + """ + group = DeviceGroup().get_group(ranks) + idx = torch.distributed.get_rank(group) + require_grad = itensor.requires_grad + with torch.no_grad(): + otensor = itensor.chunk(len(ranks), dim)[idx] + otensor = otensor.detach() + if require_grad: + otensor = otensor.requires_grad_() + return otensor + + def broadcast(input_tensors: List[torch.Tensor], output_shapes: List[List[int]], output_dtypes: List[torch.dtype], diff --git a/cube/runtime/adapter/distnn.py b/cube/runtime/adapter/distnn.py deleted file mode 100644 index c15649d2..00000000 --- a/cube/runtime/adapter/distnn.py +++ /dev/null @@ -1,285 +0,0 @@ -from typing import List -import torch - -from cube.profiler.timer import CudaTimer -from cube.runtime.device import DeviceGroup - - -class SendRecv(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dst: int, group): - CudaTimer().start(field_name='comm') - ctx._tsize = input_.size() - ctx._tdtype = input_.dtype - ctx._src = dst - if not input_.is_contiguous(): - input_ = input_.contiguous() - sendop = torch.distributed.P2POp( - torch.distributed.isend, input_, dst - ) - reqs = torch.distributed.batch_isend_irecv([sendop]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, _grad: torch.Tensor): - CudaTimer().start(field_name='comm') - size = ctx._tsize - dtype = ctx._tdtype - src = ctx._src - grad = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) - recvop = torch.distributed.P2POp( - torch.distributed.irecv, grad, src - ) - reqs = torch.distributed.batch_isend_irecv([recvop]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class RecvSend(torch.autograd.Function): - - @staticmethod - def forward(ctx, size, dtype, src: int, ranks: List[int]): - CudaTimer().start(field_name='comm') - ctx._tsize = size - ctx._tdtype = dtype - ctx._dst = src - input_ = torch.empty( - size, dtype=dtype, device=torch.cuda.current_device(), - requires_grad=True) - recvop = torch.distributed.P2POp( - torch.distributed.irecv, input_, src - ) - reqs = torch.distributed.batch_isend_irecv([recvop]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad: torch.Tensor): - CudaTimer().start(field_name='comm') - dst = ctx._dst - if not grad.is_contiguous(): - grad = grad.contiguous() - sendop = torch.distributed.P2POp( - torch.distributed.isend, grad, dst - ) - reqs = torch.distributed.batch_isend_irecv([sendop]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return None, None, None, None - - -class AllReduceIdentity(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, ranks: List[int]): - group = DeviceGroup().get_group(ranks) - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - torch.distributed.all_reduce(input_, group=group) - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -class IdentityAllreduce(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, ranks: List[int]): - group = DeviceGroup().get_group(ranks) - ctx._group = group - return input_ - - @staticmethod - def backward(ctx, grad_output): - world_size = torch.distributed.get_world_size(ctx._group) - if world_size == 1: - return grad_output, None - CudaTimer().start(field_name='comm') - torch.distributed.all_reduce(grad_output, group=ctx._group) - CudaTimer().stop(field_name='comm') - return grad_output, None - - -class ReduceScatterAllGather(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dim: int, ranks: List[int]): - group = DeviceGroup().get_group(ranks) - ctx._group = group - ctx._dim = dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_, - CudaTimer().start(field_name='comm') - input_tensors = input_.chunk(world_size, dim) - rank = torch.distributed.get_rank(group) - input_ = torch.empty_like(input_tensors[rank], requires_grad=True) - torch.distributed.reduce_scatter( - input_, input_tensors, group=group - ) - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output): - group = ctx._group - dim = ctx._dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output - CudaTimer().start(field_name='comm') - rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] - tensor_list[rank] = grad_output - torch.distributed.all_gather(tensor_list, grad_output, group=group) - grad = torch.cat(tensor_list, dim=dim).contiguous() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class AllGatherSplit(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dim: int, ranks: List[int]): - group = DeviceGroup().get_group(ranks) - ctx._group = group - ctx._dim = dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=group) - output = torch.cat(tensor_list, dim=dim).contiguous() - CudaTimer().stop(field_name='comm') - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - group = ctx._group - dim = ctx._dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output - CudaTimer().start(field_name='comm') - input_list = grad_output.chunk(world_size, dim=dim) - rank = torch.distributed.get_rank(group) - grad = input_list[rank].contiguous() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class SplitAllGather(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dim: int, ranks: List[int]): - group = DeviceGroup().get_group(ranks) - ctx._group = group - ctx._dim = dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - input_list = input_.chunk(world_size, dim=dim) - rank = torch.distributed.get_rank(group) - input_ = input_list[rank].contiguous() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - group = ctx._group - dim = ctx._dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output - CudaTimer().start(field_name='comm') - rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] - tensor_list[rank] = grad_output - torch.distributed.all_gather(tensor_list, grad_output, group=group) - grad = torch.cat(tensor_list, dim=dim).contiguous() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class ReduceBroadcast(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dst: int, ranks: List[int]): - group = DeviceGroup().get_group(ranks) - ctx._dst = dst - ctx._group = group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - torch.distributed.reduce(input_, dst, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output): - src = ctx._dst - group = ctx._group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output, None, None - CudaTimer().start(field_name='comm') - torch.distributed.broadcast(grad_output, src, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return grad_output, None, None - - -class BroadcastReduce(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, src: int, ranks: List[int]): - group = DeviceGroup().get_group(ranks) - ctx._src = src - ctx._group = group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - torch.distributed.broadcast(input_, src, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output): - dst = ctx._src - group = ctx._group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output, None, None - CudaTimer().start(field_name='comm') - if not grad_output.is_contiguous(): - grad_output = grad_output.contiguous() - torch.distributed.reduce(grad_output, dst, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return grad_output, None, None diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py new file mode 100644 index 00000000..3132883b --- /dev/null +++ b/cube/runtime/adapter/nn.py @@ -0,0 +1,290 @@ +from typing import List, Tuple +import torch + +from cube.profiler.timer import CudaTimer +from cube.runtime.adapter.collectives import all_reduce +from cube.runtime.device import DeviceGroup + + +def _allreduce(itensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: + CudaTimer().start(field_name='comm') + if not itensor.is_contiguous(): + itensor = itensor.contiguous() + group = DeviceGroup().get_group(ranks) + torch.distributed.all_reduce(itensor, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return itensor + + +def _allgather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: + CudaTimer().start(field_name='comm') + if not itensor.is_contiguous(): + itensor = itensor.contiguous() + group = DeviceGroup().get_group(ranks) + tensor_list = [torch.empty_like(itensor) for _ in ranks] + tensor_list[torch.distributed.get_rank(group)] = itensor.data + torch.distributed.all_gather(tensor_list, itensor, group=group) + torch.cuda.synchronize() + # concat + otensor = torch.concat(tuple(tensor_list), dim=dim).requires_grad_() + CudaTimer().stop(field_name='comm') + return otensor + + +def _reducescatter(itensor: torch.Tensor, dim:int, ranks: Tuple[int]) -> torch.Tensor: + CudaTimer().start(field_name='comm') + itensors = list(itensor.chunk(len(ranks), dim)) + for idx, tensor in enumerate(itensors): + if not tensor.is_contiguous(): + itensors[idx] = tensor.contiguous() + group = DeviceGroup().get_group(ranks) + otensor = torch.empty_like(itensors[0]) + torch.distributed.reduce_scatter(otensor, itensors, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return otensor + + +def _alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: + + CudaTimer().start(field_name='comm') + itensors = list(itensor.chunk(len(ranks), dim=odim)) + for idx, tensor in enumerate(itensors): + if not tensor.is_contiguous(): + itensors[idx] = tensor.contiguous() + otensors = [torch.empty_like(t) for t in itensors] + group = DeviceGroup().get_group(ranks) + torch.distributed.all_to_all(otensors, itensors, group=group) + torch.cuda.synchronize() + otensor = torch.concat(tuple(otensors), dim=idim) + CudaTimer().stop(field_name='comm') + return otensor + + +def _chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: + """ + split dimension in n chunks and take idx-th chunk + + ranks (Tuple[int]): the order of split tensor. + """ + group = DeviceGroup().get_group(ranks) + idx = torch.distributed.get_rank(group) + return itensor.chunk(len(ranks), dim)[idx] + + +class AllReduceIdentity(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, ranks: Tuple[int]): + return _allreduce(itensor, ranks) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +def allreduce_identity(tensor: torch.Tensor, ranks: List[int]): + return AllReduceIdentity.apply(tensor, ranks) + + +class IdentityAllreduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, ranks: Tuple[int]): + ctx._ranks = ranks + return itensor + + @staticmethod + def backward(ctx, grad: torch.Tensor): + ranks = ctx._ranks + grad = _allreduce(grad, ranks) + return grad, None + + +def identity_allreduce(tensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: + return IdentityAllreduce.apply(tensor, ranks) + + +class AllReduceAllReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, ranks: Tuple[int]): + ctx._ranks = ranks + otensor = _allreduce(itensor, ranks) + return otensor + + @staticmethod + def backward(ctx, grad: torch.Tensor): + ranks = ctx._ranks + grad = _allreduce(grad, ranks) + return grad, None + + +def allreduce_allreduce(tensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: + return AllReduceAllReduce.apply(tensor, ranks) + + +class ReduceScatterAllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int]): + ctx._ranks = ranks + ctx._dim = dim + return _reducescatter(itensor, dim, ranks) + + @staticmethod + def backward(ctx, grad: torch.Tensor): + ranks = ctx._ranks + dim = ctx._dim + grad = _allgather(grad, dim, ranks) + return grad, None, None + + +def reducescatter_allgather(tensor: torch.Tensor, dim: int, ranks: List[int]): + return ReduceScatterAllGather.apply(tensor, dim, ranks) + + +class AllGatherReduceScatter(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int]): + ctx._ranks = ranks + ctx._dim = dim + return _allgather(itensor, dim, ranks) + + @staticmethod + def backward(ctx, grad: torch.Tensor): + ranks = ctx._ranks + dim = ctx._dim + grad = _reducescatter(grad, dim, ranks) + return grad, None, None + + +def allgather_reducescatter(tensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: + return AllGatherReduceScatter.apply(tensor, dim, ranks) + + +class AllGatherSplit(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int]): + ctx._ranks = ranks + ctx._dim = dim + return _allgather(itensor, dim, ranks) + + @staticmethod + def backward(ctx, grad: torch.Tensor): + ranks = ctx._ranks + dim = ctx._dim + return _chunk(grad, dim, ranks), None, None + + +def allgather_split(tensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: + return AllGatherSplit.apply(tensor, dim, ranks) + + +class SplitAllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int], group): + """ + ranks should be the global rank + """ + ctx._ranks = ranks + ctx._dim = dim + return _chunk(itensor, dim, ranks) + + @staticmethod + def backward(ctx, grad: torch.Tensor): + ranks = ctx._ranks + dim = ctx._dim + grad = _allgather(grad, dim, ranks) + return grad, None, None + + +def chunk_allgather(tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: + return SplitAllGather.apply(tensor, dim, ranks) + + +class AllToAllAllToAll(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]): + ctx._ranks = ranks + ctx._idim = idim + ctx._odim = odim + return _alltoall(itensor, idim, odim, ranks) + + @staticmethod + def backward(ctx, grad: torch.Tensor): + ranks = ctx._ranks + idim, odim = ctx._idim, ctx._odim + grad = _alltoall(grad, odim, idim, ranks) + return grad, None, None, None + + +def alltoall_alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: + return AllToAllAllToAll.apply(itensor, idim, odim, ranks) + + +class ReduceBroadcast(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dst: int, ranks: List[int]): + group = DeviceGroup().get_group(ranks) + ctx._dst = dst + ctx._group = group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + CudaTimer().start(field_name='comm') + torch.distributed.reduce(input_, dst, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output): + src = ctx._dst + group = ctx._group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output, None, None + CudaTimer().start(field_name='comm') + torch.distributed.broadcast(grad_output, src, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return grad_output, None, None + + +class BroadcastReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, src: int, ranks: List[int]): + group = DeviceGroup().get_group(ranks) + ctx._src = src + ctx._group = group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + CudaTimer().start(field_name='comm') + torch.distributed.broadcast(input_, src, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output): + dst = ctx._src + group = ctx._group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output, None, None + CudaTimer().start(field_name='comm') + if not grad_output.is_contiguous(): + grad_output = grad_output.contiguous() + torch.distributed.reduce(grad_output, dst, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return grad_output, None, None diff --git a/cube/runtime/adapter/transform.py b/cube/runtime/adapter/transform.py index 392ccfd4..506bd14a 100644 --- a/cube/runtime/adapter/transform.py +++ b/cube/runtime/adapter/transform.py @@ -34,22 +34,6 @@ def select(tensor: torch.Tensor, return sub_tensor -def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: - """ - split dimension in n chunks and take idx-th chunk - - ranks (Tuple[int]): the order of split tensor. - """ - idx = ranks.index(torch.distributed.get_rank()) - require_grad = itensor.requires_grad - with torch.no_grad(): - otensor = itensor.chunk(len(ranks), dim)[idx] - otensor = otensor.detach() - if require_grad: - otensor = otensor.requires_grad_() - return otensor - - def smerge(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: """ Runtime primitive of spatial merge. @@ -59,7 +43,7 @@ def smerge(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: tensors: a list of torch tensor dim: the dimension to concatenate. """ - require_grad = any(t.require_grad for t in tensors) + require_grad = any(t.requires_grad for t in tensors) with torch.no_grad(): out = torch.concat(tuple(tensors), dim).requires_grad_() if require_grad: @@ -75,7 +59,7 @@ def vmerge(tensors: List[torch.Tensor]) -> torch.Tensor: Args: tensors: a list of torch tensor """ - require_grad = any(t.require_grad for t in tensors) + require_grad = any(t.requires_grad for t in tensors) with torch.no_grad(): out = tensors[0] for tensor in tensors[1:]: From 2b17f7c7ea4a3d9439a88a2e1587413143bf80eb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 28 May 2022 23:57:33 +0800 Subject: [PATCH 0836/1892] re-organize code --- cube/algorithm/ops/conv.py | 4 +-- cube/algorithm/ops/dataloader.py | 2 +- cube/algorithm/ops/einops.py | 2 +- cube/algorithm/ops/pad.py | 2 +- cube/algorithm/utils.py | 2 +- cube/codegen/codegen.py | 6 ++--- cube/compiler.py | 4 +-- cube/execplan/execplan.py | 6 ++--- cube/execplan/planpass/fusion.py | 16 +++++------ cube/execplan/planpass/grouping.py | 4 +-- cube/graph/__init__.py | 3 --- cube/graph/function/__init__.py | 2 ++ cube/graph/{operator => }/function/cat.py | 2 +- cube/graph/{operator => }/function/conv.py | 2 +- .../graph/{operator => }/function/creators.py | 2 +- .../{operator => }/function/customops.py | 3 +-- cube/graph/{operator => }/function/einops.py | 2 +- .../graph/{operator => }/function/function.py | 22 +++++++-------- cube/graph/{operator => }/function/pad.py | 2 +- cube/graph/{operator => }/function/repeat.py | 2 +- cube/graph/{operator => }/function/scatter.py | 2 +- .../{operator => }/function/scripteinops.py | 2 +- cube/graph/{operator => }/function/select.py | 2 +- cube/graph/{adapter => gener}/__init__.py | 0 cube/graph/{adapter => gener}/gen.py | 12 ++++----- cube/graph/{adapter => gener}/layout.py | 16 +++++------ cube/graph/graph.py | 6 ++--- cube/graph/operator/__init__.py | 3 --- cube/graph/operator/function/__init__.py | 2 -- cube/graph/parser/converter.py | 2 +- cube/graph/parser/mapping.py | 4 +-- cube/graph/parser/parser.py | 4 +-- cube/graph/parser/register.py | 2 +- cube/ir/__init__.py | 5 +++- cube/ir/adapter/__init__.py | 1 + cube/{graph => ir}/adapter/adapter.py | 4 +-- cube/{graph => ir}/adapter/prim.py | 2 +- cube/{graph/operator => ir}/operator.py | 27 +------------------ cube/{graph => ir}/tensor.py | 4 +-- cube/logics/model.py | 4 +-- cube/logics/translator.py | 8 +++--- cube/profiler/estimator.py | 6 ++--- cube/search/sampler.py | 4 +-- examples/atmosphere/policy/naive.py | 4 +-- examples/atmosphere/policy/replicate.py | 4 +-- examples/atmosphere/policy/split.py | 7 +++-- examples/attention/policy/naive.py | 4 +-- examples/mlp/policy/col_parallel.py | 2 +- examples/mlp/policy/data_parallel.py | 2 +- examples/mlp/policy/hybrid_parallel.py | 2 +- examples/mlp/policy/megatron.py | 2 +- examples/mlp/policy/optimal.py | 2 +- examples/mlp/policy/pipe1f1b_parallel.py | 2 +- examples/mlp/policy/pipe_parallel.py | 2 +- examples/mlp/policy/row_parallel.py | 2 +- examples/mlp/policy/search.py | 2 +- examples/mlp/policy/st_search.py | 3 ++- examples/poisson/policy/naive.py | 2 +- examples/transformer/policy/naive.py | 4 +-- examples/wrf/policy/naive.py | 2 +- tests/test_grid.py | 4 +-- tests/test_prim_loop.py | 2 +- 62 files changed, 119 insertions(+), 147 deletions(-) create mode 100644 cube/graph/function/__init__.py rename cube/graph/{operator => }/function/cat.py (97%) rename cube/graph/{operator => }/function/conv.py (98%) rename cube/graph/{operator => }/function/creators.py (96%) rename cube/graph/{operator => }/function/customops.py (97%) rename cube/graph/{operator => }/function/einops.py (99%) rename cube/graph/{operator => }/function/function.py (97%) rename cube/graph/{operator => }/function/pad.py (96%) rename cube/graph/{operator => }/function/repeat.py (95%) rename cube/graph/{operator => }/function/scatter.py (97%) rename cube/graph/{operator => }/function/scripteinops.py (96%) rename cube/graph/{operator => }/function/select.py (97%) rename cube/graph/{adapter => gener}/__init__.py (100%) rename cube/graph/{adapter => gener}/gen.py (97%) rename cube/graph/{adapter => gener}/layout.py (97%) delete mode 100644 cube/graph/operator/__init__.py delete mode 100644 cube/graph/operator/function/__init__.py create mode 100644 cube/ir/adapter/__init__.py rename cube/{graph => ir}/adapter/adapter.py (97%) rename cube/{graph => ir}/adapter/prim.py (99%) rename cube/{graph/operator => ir}/operator.py (91%) rename cube/{graph => ir}/tensor.py (99%) diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 295c679b..9fa8ce3e 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -3,8 +3,8 @@ from cube.algorithm.utils import split_axis, split_axis_custom, split_value from cube.algorithm.generics import GenericDistAlgo -from cube.graph.operator.function.conv import IRConv2D -from cube.graph.operator.function.conv import IRConv3D +from cube.graph.function.conv import IRConv2D +from cube.graph.function.conv import IRConv3D class DimSplitConv2D(GenericDistAlgo): diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py index 25c22140..2f5d8386 100644 --- a/cube/algorithm/ops/dataloader.py +++ b/cube/algorithm/ops/dataloader.py @@ -3,7 +3,7 @@ from cube.algorithm.utils import split_axis from cube.algorithm.generics import GenericDistAlgo -from cube.graph.operator.operator import IRDataOperation +from cube.ir.operator import IRDataOperation class DPDataLoader(GenericDistAlgo): diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py index 32413a9e..faacbbbb 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/einops.py @@ -3,7 +3,7 @@ from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo -from cube.graph.operator.function import IREinops, EinDim +from cube.graph.function import IREinops, EinDim class DimSplitEinops(GenericDistAlgo): diff --git a/cube/algorithm/ops/pad.py b/cube/algorithm/ops/pad.py index 47c47438..50890907 100644 --- a/cube/algorithm/ops/pad.py +++ b/cube/algorithm/ops/pad.py @@ -3,7 +3,7 @@ from cube.algorithm.utils import split_axis, split_axis_custom, split_value from cube.algorithm.generics import GenericDistAlgo -from cube.graph.operator.function.pad import IRPad +from cube.graph.function.pad import IRPad class DimSplitPad(GenericDistAlgo): """ diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py index ef3a266a..7d43e84c 100644 --- a/cube/algorithm/utils.py +++ b/cube/algorithm/utils.py @@ -1,5 +1,5 @@ from typing import List, Union -from cube.graph.tensor import IRSubTensor +from cube.ir.tensor import IRSubTensor def split_axis(tensor: IRSubTensor, axis: int, chunk_num: int): diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 622ab9a0..4dd4317d 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -9,9 +9,9 @@ from cube.ir.cten import IRCell, IRTensor from cube.ir.dtype import IRDType -from cube.graph.tensor import IRSubTensor -from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.graph.adapter.adapter import IRWeightReducer, IRAdapter +from cube.ir.tensor import IRSubTensor +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.adapter import IRWeightReducer, IRAdapter from cube.graph.graph import IRGraph, IRSegment from cube.execplan import ExectuionPlan diff --git a/cube/compiler.py b/cube/compiler.py index 990edbd6..ded621b9 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -6,9 +6,9 @@ import cube from cube.graph import parser -from cube.graph.adapter.gen import IRAdapterGener +from cube.graph.gener.gen import IRAdapterGener from cube.graph.graph import IRGraph -from cube.graph.operator.operator import IRDataOperation +from cube.ir.operator import IRDataOperation from cube.logics.pool import SchedulePool from cube.logics.translator import LogicTranslator diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 5b19ffb6..41918234 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -2,12 +2,12 @@ import copy import numpy as np -from cube.graph.adapter.adapter import IRAdapter -from cube.graph.operator.operator import IRBpOperation, IRFwOperation +from cube.ir.adapter import IRAdapter +from cube.ir.operator import IRBpOperation, IRFwOperation from cube.ir.cten import IRCell from cube.graph.graph import IRGraph -from cube.graph.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor class ExectuionPlan: diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index b30eaece..b52449bf 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -1,17 +1,17 @@ from typing import List -from cube.graph.adapter.adapter import IRAdapter +from cube.ir.adapter import IRAdapter from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass -from cube.graph.adapter.prim import IRAdapterPrim -from cube.graph.adapter.prim import AllReducePrim, AllGatherPrim, ReduceScatterPrim, AllToAllPrim -from cube.graph.adapter.prim import IdentityPrim, ChunkPrim -from cube.graph.adapter.prim import IdentityAllreducePrim, AllReduceIdentityPrim, AllReduceAllReducePrim -from cube.graph.adapter.prim import AllGatherReduceScatterPrim, ReduceScatterAllGatherPrim -from cube.graph.adapter.prim import SplitAllGatherPrim, AllGatherSplitPrim -from cube.graph.adapter.prim import AllToAllAllToAllPrim +from cube.ir.adapter.prim import IRAdapterPrim +from cube.ir.adapter.prim import AllReducePrim, AllGatherPrim, ReduceScatterPrim, AllToAllPrim +from cube.ir.adapter.prim import IdentityPrim, ChunkPrim +from cube.ir.adapter.prim import IdentityAllreducePrim, AllReduceIdentityPrim, AllReduceAllReducePrim +from cube.ir.adapter.prim import AllGatherReduceScatterPrim, ReduceScatterAllGatherPrim +from cube.ir.adapter.prim import SplitAllGatherPrim, AllGatherSplitPrim +from cube.ir.adapter.prim import AllToAllAllToAllPrim class DiffFusion(PlanPass): diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index b89add5c..a53f753a 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -5,8 +5,8 @@ from cube.execplan import ExectuionPlan from cube.execplan.planpass.planpass import PlanPass -from cube.graph.adapter.adapter import IRAdapter -from cube.graph.operator.operator import IRBpOperation, IRFwOperation +from cube.ir.adapter import IRAdapter +from cube.ir.operator import IRBpOperation, IRFwOperation from cube.ir.cten import IRCell diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py index 527a5d19..ec86b08f 100644 --- a/cube/graph/__init__.py +++ b/cube/graph/__init__.py @@ -1,5 +1,2 @@ from cube.graph.graph import IRGraph -from cube.graph.operator import IRFwOperation -from cube.graph.tensor import IRFullTensor, IRSubTensor from cube.graph import parser - diff --git a/cube/graph/function/__init__.py b/cube/graph/function/__init__.py new file mode 100644 index 00000000..93f14214 --- /dev/null +++ b/cube/graph/function/__init__.py @@ -0,0 +1,2 @@ +from cube.graph.function.einops import EinDim, IREinops +from cube.graph.function.function import * \ No newline at end of file diff --git a/cube/graph/operator/function/cat.py b/cube/graph/function/cat.py similarity index 97% rename from cube/graph/operator/function/cat.py rename to cube/graph/function/cat.py index a674c6b4..374e314a 100644 --- a/cube/graph/operator/function/cat.py +++ b/cube/graph/function/cat.py @@ -2,7 +2,7 @@ import itertools from typing import List -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor class IRCat(IRFwOperation): diff --git a/cube/graph/operator/function/conv.py b/cube/graph/function/conv.py similarity index 98% rename from cube/graph/operator/function/conv.py rename to cube/graph/function/conv.py index 92213396..f115f8fa 100644 --- a/cube/graph/operator/function/conv.py +++ b/cube/graph/function/conv.py @@ -1,6 +1,6 @@ from typing import List -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor diff --git a/cube/graph/operator/function/creators.py b/cube/graph/function/creators.py similarity index 96% rename from cube/graph/operator/function/creators.py rename to cube/graph/function/creators.py index 1cf76eef..dded4b93 100644 --- a/cube/graph/operator/function/creators.py +++ b/cube/graph/function/creators.py @@ -1,7 +1,7 @@ from copy import copy from typing import List -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor class IRZeros(IRFwOperation): diff --git a/cube/graph/operator/function/customops.py b/cube/graph/function/customops.py similarity index 97% rename from cube/graph/operator/function/customops.py rename to cube/graph/function/customops.py index c3f71649..c8572d22 100644 --- a/cube/graph/operator/function/customops.py +++ b/cube/graph/function/customops.py @@ -1,7 +1,6 @@ from typing import List -import cube.runtime.function -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor class IRCustomOps(IRFwOperation): diff --git a/cube/graph/operator/function/einops.py b/cube/graph/function/einops.py similarity index 99% rename from cube/graph/operator/function/einops.py rename to cube/graph/function/einops.py index 8ffa3132..a33dc445 100644 --- a/cube/graph/operator/function/einops.py +++ b/cube/graph/function/einops.py @@ -57,7 +57,7 @@ import string from cube.ir.cten import IRTensor -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.algorithm.factory import DistAlgorithmFactory diff --git a/cube/graph/operator/function/function.py b/cube/graph/function/function.py similarity index 97% rename from cube/graph/operator/function/function.py rename to cube/graph/function/function.py index 483e95cb..ab10aa9c 100644 --- a/cube/graph/operator/function/function.py +++ b/cube/graph/function/function.py @@ -4,17 +4,17 @@ import numpy from cube.ir.cten import IRTensor -from cube.graph.operator.function.einops import EinDim, IREinops -from cube.graph.operator.function.conv import IRConv2D -from cube.graph.operator.function.conv import IRConv3D -from cube.graph.operator.function.pad import IRPad -from cube.graph.operator.function.scripteinops import IRScriptEinOps -from cube.graph.operator.function.customops import IRCustomOps -from cube.graph.operator.function.cat import IRCat, IRStack -from cube.graph.operator.function.creators import IRToTensor, IRZeros -from cube.graph.operator.function.select import IRSelect, IRSlice -from cube.graph.operator.function.scatter import IRSelectScatter -from cube.graph.operator.function.repeat import IRRepeat +from cube.graph.function.einops import EinDim, IREinops +from cube.graph.function.conv import IRConv2D +from cube.graph.function.conv import IRConv3D +from cube.graph.function.pad import IRPad +from cube.graph.function.scripteinops import IRScriptEinOps +from cube.graph.function.customops import IRCustomOps +from cube.graph.function.cat import IRCat, IRStack +from cube.graph.function.creators import IRToTensor, IRZeros +from cube.graph.function.select import IRSelect, IRSlice +from cube.graph.function.scatter import IRSelectScatter +from cube.graph.function.repeat import IRRepeat def _create_eshape(shape: List[int], iterator: Optional[Iterable] = None, diff --git a/cube/graph/operator/function/pad.py b/cube/graph/function/pad.py similarity index 96% rename from cube/graph/operator/function/pad.py rename to cube/graph/function/pad.py index 3b18a9c2..0c3147c0 100644 --- a/cube/graph/operator/function/pad.py +++ b/cube/graph/function/pad.py @@ -1,6 +1,6 @@ from typing import List -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor class IRPad(IRFwOperation): diff --git a/cube/graph/operator/function/repeat.py b/cube/graph/function/repeat.py similarity index 95% rename from cube/graph/operator/function/repeat.py rename to cube/graph/function/repeat.py index 79705e35..0d2e650f 100644 --- a/cube/graph/operator/function/repeat.py +++ b/cube/graph/function/repeat.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple import itertools -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor class IRRepeat(IRFwOperation): diff --git a/cube/graph/operator/function/scatter.py b/cube/graph/function/scatter.py similarity index 97% rename from cube/graph/operator/function/scatter.py rename to cube/graph/function/scatter.py index fa11f264..bc1d19d0 100644 --- a/cube/graph/operator/function/scatter.py +++ b/cube/graph/function/scatter.py @@ -1,7 +1,7 @@ from copy import copy from typing import List, Optional, Tuple -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor class IRSelectScatter(IRFwOperation): diff --git a/cube/graph/operator/function/scripteinops.py b/cube/graph/function/scripteinops.py similarity index 96% rename from cube/graph/operator/function/scripteinops.py rename to cube/graph/function/scripteinops.py index b8f89e0a..fc078fef 100644 --- a/cube/graph/operator/function/scripteinops.py +++ b/cube/graph/function/scripteinops.py @@ -1,7 +1,7 @@ from typing import List -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor from einops.einops import _apply_recipe diff --git a/cube/graph/operator/function/select.py b/cube/graph/function/select.py similarity index 97% rename from cube/graph/operator/function/select.py rename to cube/graph/function/select.py index ee22f15a..d21096e2 100644 --- a/cube/graph/operator/function/select.py +++ b/cube/graph/function/select.py @@ -1,7 +1,7 @@ from copy import copy from typing import List, Optional, Tuple -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor class IRSelect(IRFwOperation): diff --git a/cube/graph/adapter/__init__.py b/cube/graph/gener/__init__.py similarity index 100% rename from cube/graph/adapter/__init__.py rename to cube/graph/gener/__init__.py diff --git a/cube/graph/adapter/gen.py b/cube/graph/gener/gen.py similarity index 97% rename from cube/graph/adapter/gen.py rename to cube/graph/gener/gen.py index 4102ab08..f6db5785 100644 --- a/cube/graph/adapter/gen.py +++ b/cube/graph/gener/gen.py @@ -3,15 +3,15 @@ import copy from cube.graph.graph import IRGraph -from cube.graph.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap -from cube.graph.adapter.adapter import IRAdapter, IRWeightReducer +from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from cube.ir.adapter import IRAdapter, IRWeightReducer -from cube.graph.operator.operator import IRBpOperation, IRFwOperation +from cube.ir.operator import IRBpOperation, IRFwOperation from cube.ir.cten import IRCell -from cube.graph.adapter.prim import IRAdapterPrim -from cube.graph.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim -from cube.graph.adapter.layout import GridLayout +from cube.ir.adapter.prim import IRAdapterPrim +from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim +from cube.graph.gener.layout import GridLayout class IRAdapterGener: diff --git a/cube/graph/adapter/layout.py b/cube/graph/gener/layout.py similarity index 97% rename from cube/graph/adapter/layout.py rename to cube/graph/gener/layout.py index e490c2b3..44814f6c 100644 --- a/cube/graph/adapter/layout.py +++ b/cube/graph/gener/layout.py @@ -2,14 +2,14 @@ import copy import numpy as np -from cube.graph.tensor import IRFullTensor, IRSubTensor -from cube.graph.tensor import IndexMap, ValueMap - -from cube.graph.adapter.prim import AllGatherPrim # d2r -from cube.graph.adapter.prim import AllToAllPrim # d2d -from cube.graph.adapter.prim import AllReducePrim # v2r -from cube.graph.adapter.prim import ReduceScatterPrim # v2d -from cube.graph.adapter.prim import ChunkPrim # r2d +from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IndexMap, ValueMap + +from cube.ir.adapter.prim import AllGatherPrim # d2r +from cube.ir.adapter.prim import AllToAllPrim # d2d +from cube.ir.adapter.prim import AllReducePrim # v2r +from cube.ir.adapter.prim import ReduceScatterPrim # v2d +from cube.ir.adapter.prim import ChunkPrim # r2d class GridLayout: diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 08440b19..0aca94e1 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -11,9 +11,9 @@ import copy from cube.ir.cten import IRTensor, IRCell -from cube.graph.operator.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.graph.adapter.adapter import IRAdapter -from cube.graph.tensor import IRFullTensor, IRSubTensor +from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation +from cube.ir.adapter import IRAdapter +from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.algorithm.generics import GenericDistAlgo diff --git a/cube/graph/operator/__init__.py b/cube/graph/operator/__init__.py deleted file mode 100644 index 80f9ae2e..00000000 --- a/cube/graph/operator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cube.graph.operator.operator import IRFwOperation -from cube.graph.operator.operator import IRBpOperation -from cube.graph.operator.operator import IRDataOperation \ No newline at end of file diff --git a/cube/graph/operator/function/__init__.py b/cube/graph/operator/function/__init__.py deleted file mode 100644 index 53aa6b6c..00000000 --- a/cube/graph/operator/function/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from cube.graph.operator.function.einops import EinDim, IREinops -from cube.graph.operator.function.function import * \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index c85ae8f3..6a90b2b2 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,7 +1,7 @@ from typing import Optional, List from cube.ir.cten import IRTensor -from cube.graph.tensor import IRFullTensor +from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser from cube.graph import IRGraph from cube.logics.dataloader import IRDataLoader diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 44f14665..dd1432f1 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -8,8 +8,8 @@ import operator from functools import partial -import cube.graph.operator.function as function -from cube.graph.operator.operator import IRFwOperation +import cube.graph.function as function +from cube.ir.operator import IRFwOperation import cube.ir as ir diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 93384d06..79a215d9 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -3,8 +3,8 @@ import re from typing import Any, List, Tuple, Optional -from cube.graph import IRFwOperation -from cube.graph.tensor import IRFullTensor +from cube.ir.operator import IRFwOperation +from cube.ir.tensor import IRFullTensor from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import Sign2Op, DType2IRDType diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index df38fcf2..6efbeb4c 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -6,7 +6,7 @@ import inspect import torch -from cube.graph.operator.function.einops import IREinops +from cube.graph.function.einops import IREinops from cube.graph.parser.mapping import Sign2Op diff --git a/cube/ir/__init__.py b/cube/ir/__init__.py index 10053459..23ad9584 100644 --- a/cube/ir/__init__.py +++ b/cube/ir/__init__.py @@ -1,2 +1,5 @@ +from cube.ir.dtype import * from cube.ir.cten import IRTensor, IRCell -from cube.ir.dtype import * \ No newline at end of file +from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.operator import IRFwOperation, IRBpOperation, IRDataOperation +from cube.ir.adapter.adapter import IRAdapter diff --git a/cube/ir/adapter/__init__.py b/cube/ir/adapter/__init__.py new file mode 100644 index 00000000..553b5db3 --- /dev/null +++ b/cube/ir/adapter/__init__.py @@ -0,0 +1 @@ +from cube.ir.adapter.adapter import IRAdapter, IRWeightReducer diff --git a/cube/graph/adapter/adapter.py b/cube/ir/adapter/adapter.py similarity index 97% rename from cube/graph/adapter/adapter.py rename to cube/ir/adapter/adapter.py index 6611247f..68e5161c 100644 --- a/cube/graph/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -1,8 +1,8 @@ from typing import List, Optional import copy -from cube.graph.adapter.prim import IRAdapterPrim, IdentityPrim -from cube.graph.tensor import IRSubTensor +from cube.ir.adapter.prim import IRAdapterPrim, IdentityPrim +from cube.ir.tensor import IRSubTensor from cube.ir.cten import IRCell diff --git a/cube/graph/adapter/prim.py b/cube/ir/adapter/prim.py similarity index 99% rename from cube/graph/adapter/prim.py rename to cube/ir/adapter/prim.py index e6262c22..3ad950ef 100644 --- a/cube/graph/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union import copy -from cube.graph.tensor import IRSubTensor, IndexMap, ValueMap +from cube.ir.tensor import IRSubTensor, IndexMap, ValueMap # the general adapter primitive class diff --git a/cube/graph/operator/operator.py b/cube/ir/operator.py similarity index 91% rename from cube/graph/operator/operator.py rename to cube/ir/operator.py index 1f626189..2b8ac283 100644 --- a/cube/graph/operator/operator.py +++ b/cube/ir/operator.py @@ -2,14 +2,11 @@ import copy from cube.ir.cten import IRCell, IRTensor -from cube.graph.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory from cube.ir.unique import IDGenerator -__all__ = ['IRFwOperation', 'IRBpOperation', 'IRDataOperation'] - - class BaseOperator: def __init__(self, name: str, signature: str, @@ -25,28 +22,6 @@ def infer_shape(self): """ raise NotImplementedError - # def set_input(self, input_index: int, val: Any): - # # remove the consumer - # old_val = self.inputs(input_index) - # if isinstance(old_val, IRSubTensor): - # old_val.parent.rm_consumer(self) - # # add the consumer - # val = super().set_input(input_index, val) - # if isinstance(val, IRSubTensor): - # val.parent.add_consumer(self, val) - # return val - - # def set_output(self, output_index: int, val: Any): - # # remove the producer - # old_val = self.outputs(output_index) - # if isinstance(old_val, IRSubTensor): - # old_val.parent.rm_producer(self) - # # add the producer - # val = super().set_output(output_index, val) - # if isinstance(val, IRSubTensor): - # val.parent.add_producer(self, val) - # return val - def replicate(self): """ Replicate the Operation diff --git a/cube/graph/tensor.py b/cube/ir/tensor.py similarity index 99% rename from cube/graph/tensor.py rename to cube/ir/tensor.py index 41c53eaa..f8bbb29b 100644 --- a/cube/graph/tensor.py +++ b/cube/ir/tensor.py @@ -24,7 +24,7 @@ import math from cube.ir.cten import IRCell, IRTensor -import cube.ir as ir +import cube.ir.dtype as irdtype class IndexMap: @@ -366,7 +366,7 @@ class IRFullTensor(IRTensor): the sequentail execution order by its graph. """ - def __init__(self, shape=None, name=None, requires_grad=True, dtype=ir.float32): + def __init__(self, shape=None, name=None, requires_grad=True, dtype=irdtype.float32): super().__init__(shape, name, dtype) diff --git a/cube/logics/model.py b/cube/logics/model.py index b390c974..7d0d4442 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -2,8 +2,8 @@ import copy from cube.graph.graph import IRGraph -from cube.graph.tensor import IRSubTensor -from cube.graph.operator import IRFwOperation +from cube.ir.tensor import IRSubTensor +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor diff --git a/cube/logics/translator.py b/cube/logics/translator.py index bca42fc8..966caf60 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -1,13 +1,13 @@ -from cube.graph.operator.operator import IRFwOperation +from cube.ir.operator import IRFwOperation, IRDataOperation from cube.ir.cten import IRCell +from cube.ir.tensor import IRFullTensor, IRSubTensor + +from cube.graph.graph import IRGraph from cube.logics.dataloader import IRDataLoader from cube.logics import model from cube.logics.pool import SchedulePool -from cube.graph.graph import IRGraph -from cube.graph.tensor import IRFullTensor, IRSubTensor -from cube.graph.operator import IRDataOperation class LogicTranslator: diff --git a/cube/profiler/estimator.py b/cube/profiler/estimator.py index 6cec2514..5aec2d05 100644 --- a/cube/profiler/estimator.py +++ b/cube/profiler/estimator.py @@ -1,7 +1,7 @@ -from cube.graph.operator.operator import IRBpOperation, IRFwOperation -from cube.graph.tensor import IRSubTensor, ValueMap -from cube.graph.adapter.adapter import IRAdapter +from cube.ir.operator import IRBpOperation, IRFwOperation +from cube.ir.tensor import IRSubTensor, ValueMap +from cube.ir.adapter import IRAdapter from cube.graph import IRGraph from cube.ir.cten import IRCell, IRTensor diff --git a/cube/search/sampler.py b/cube/search/sampler.py index f7eae7d9..08aa873f 100644 --- a/cube/search/sampler.py +++ b/cube/search/sampler.py @@ -2,8 +2,8 @@ Micro-batch sampler for scheduling search """ from typing import Callable, Dict, List, Tuple -from cube.graph.graph import IRGraph, IRFwOperation -from cube.graph.operator.operator import IRBpOperation +from cube.graph.graph import IRGraph +from cube.ir.operator import IRFwOperation, IRBpOperation from cube.ir.cten import IRCell from cube.execplan import ExectuionPlan diff --git a/examples/atmosphere/policy/naive.py b/examples/atmosphere/policy/naive.py index caf38e07..535c0de1 100644 --- a/examples/atmosphere/policy/naive.py +++ b/examples/atmosphere/policy/naive.py @@ -1,6 +1,6 @@ from cube.graph import IRGraph -from cube.graph.adapter.adapter import IRAdapter -from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.adapter import IRAdapter +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation def PAS(graph: IRGraph, resource): diff --git a/examples/atmosphere/policy/replicate.py b/examples/atmosphere/policy/replicate.py index 3653786d..e2099fa6 100644 --- a/examples/atmosphere/policy/replicate.py +++ b/examples/atmosphere/policy/replicate.py @@ -1,6 +1,6 @@ from cube.graph import IRGraph -from cube.graph.adapter.adapter import IRAdapter -from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.adapter.adapter import IRAdapter +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation def PAS(graph: IRGraph, resource): diff --git a/examples/atmosphere/policy/split.py b/examples/atmosphere/policy/split.py index 0fecdc6a..cf4d8eb8 100644 --- a/examples/atmosphere/policy/split.py +++ b/examples/atmosphere/policy/split.py @@ -1,8 +1,7 @@ from cube.graph import IRGraph -from cube.graph.adapter.adapter import IRAdapter -from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.graph.operator.function.conv import IRConv3D -from cube.graph.operator.function.pad import IRPad +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.graph.function.conv import IRConv3D +from cube.graph.function.pad import IRPad def PAS(graph: IRGraph, resource): print(graph.extra_repr()) diff --git a/examples/attention/policy/naive.py b/examples/attention/policy/naive.py index caf38e07..164de037 100644 --- a/examples/attention/policy/naive.py +++ b/examples/attention/policy/naive.py @@ -1,6 +1,6 @@ from cube.graph import IRGraph -from cube.graph.adapter.adapter import IRAdapter -from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.adapter.adapter import IRAdapter +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation def PAS(graph: IRGraph, resource): diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py index b7855732..6ea6a7de 100644 --- a/examples/mlp/policy/col_parallel.py +++ b/examples/mlp/policy/col_parallel.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation def P(graph, resource): diff --git a/examples/mlp/policy/data_parallel.py b/examples/mlp/policy/data_parallel.py index 452aa1de..e224a0b6 100644 --- a/examples/mlp/policy/data_parallel.py +++ b/examples/mlp/policy/data_parallel.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation def PAS(graph: IRGraph, resource): diff --git a/examples/mlp/policy/hybrid_parallel.py b/examples/mlp/policy/hybrid_parallel.py index b09f9a40..84230022 100644 --- a/examples/mlp/policy/hybrid_parallel.py +++ b/examples/mlp/policy/hybrid_parallel.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation def PAS(graph: IRGraph, resource): diff --git a/examples/mlp/policy/megatron.py b/examples/mlp/policy/megatron.py index c5a276fd..99eadc64 100644 --- a/examples/mlp/policy/megatron.py +++ b/examples/mlp/policy/megatron.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.graph.operator.operator import IRFwOperation, IRDataOperation +from cube.ir.operator import IRFwOperation, IRDataOperation def PAS(graph: IRGraph, resource): diff --git a/examples/mlp/policy/optimal.py b/examples/mlp/policy/optimal.py index e94523b8..fcc6d8ea 100644 --- a/examples/mlp/policy/optimal.py +++ b/examples/mlp/policy/optimal.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.graph.operator import IRFwOperation, IRDataOperation +from cube.ir.operator import IRFwOperation, IRDataOperation def PAS(graph: IRGraph, resource): diff --git a/examples/mlp/policy/pipe1f1b_parallel.py b/examples/mlp/policy/pipe1f1b_parallel.py index bde03c16..d0fea253 100644 --- a/examples/mlp/policy/pipe1f1b_parallel.py +++ b/examples/mlp/policy/pipe1f1b_parallel.py @@ -1,5 +1,5 @@ from cube.graph.graph import IRGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation def PAS(graph: IRGraph, resource): diff --git a/examples/mlp/policy/pipe_parallel.py b/examples/mlp/policy/pipe_parallel.py index 6ee4d674..e50ca93f 100644 --- a/examples/mlp/policy/pipe_parallel.py +++ b/examples/mlp/policy/pipe_parallel.py @@ -1,7 +1,7 @@ import math import random -from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation def PAS(graph, resource): diff --git a/examples/mlp/policy/row_parallel.py b/examples/mlp/policy/row_parallel.py index 71e5c0c5..012f12aa 100644 --- a/examples/mlp/policy/row_parallel.py +++ b/examples/mlp/policy/row_parallel.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.graph.operator.operator import IRFwOperation, IRDataOperation +from cube.ir.operator import IRFwOperation, IRDataOperation def PAS(graph: IRGraph, resource): diff --git a/examples/mlp/policy/search.py b/examples/mlp/policy/search.py index 5715e8f7..ef1fa13c 100644 --- a/examples/mlp/policy/search.py +++ b/examples/mlp/policy/search.py @@ -4,7 +4,7 @@ from itertools import combinations from cube.graph import IRGraph -from cube.graph.operator.operator import IRDataOperation, IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation import cube.search.iterator as iterator from cube.profiler.estimator import Estimator diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index 3f60ceed..eacfa9e7 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -1,6 +1,7 @@ from functools import partial from typing import List -from cube.graph.graph import IRGraph, IRFwOperation +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRCell from cube.execplan import ExectuionPlan diff --git a/examples/poisson/policy/naive.py b/examples/poisson/policy/naive.py index 0863b65e..d58e2d5e 100644 --- a/examples/poisson/policy/naive.py +++ b/examples/poisson/policy/naive.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.graph.operator.function import IRConv2D +from cube.graph.function import IRConv2D def PAS(graph: IRGraph, resource): for node in graph.nodes(): diff --git a/examples/transformer/policy/naive.py b/examples/transformer/policy/naive.py index eb3a3516..250e4ae7 100644 --- a/examples/transformer/policy/naive.py +++ b/examples/transformer/policy/naive.py @@ -1,6 +1,6 @@ from cube.graph import IRGraph -from cube.graph.adapter.adapter import IRAdapter -from cube.graph.operator.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.adapter.adapter import IRAdapter +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation def PAS(graph: IRGraph, resource): diff --git a/examples/wrf/policy/naive.py b/examples/wrf/policy/naive.py index 0863b65e..d58e2d5e 100644 --- a/examples/wrf/policy/naive.py +++ b/examples/wrf/policy/naive.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.graph.operator.function import IRConv2D +from cube.graph.function import IRConv2D def PAS(graph: IRGraph, resource): for node in graph.nodes(): diff --git a/tests/test_grid.py b/tests/test_grid.py index 75d815f1..fb0164e3 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,5 +1,5 @@ -from cube.graph.adapter.layout import GridLayout -from cube.graph.tensor import IRFullTensor +from cube.graph.gener.layout import GridLayout +from cube.ir.tensor import IRFullTensor def test_grid(): diff --git a/tests/test_prim_loop.py b/tests/test_prim_loop.py index 8f4c14ed..523bb9a0 100644 --- a/tests/test_prim_loop.py +++ b/tests/test_prim_loop.py @@ -7,7 +7,7 @@ from cube.graph.parser.frame import Frame from cube.graph.parser.parser import ScriptModuleParser -from cube.graph.tensor import IRFullTensor +from cube.ir.tensor import IRFullTensor from cube import ir # Stub objects: From d221804a07f7ba76ac6a5e3bd47e42c690a0be4a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 May 2022 14:31:57 +0800 Subject: [PATCH 0837/1892] fix signature --- cube/execplan/planpass/fusion.py | 9 ++++----- cube/execplan/planpass/grouping.py | 8 ++------ cube/ir/adapter/adapter.py | 13 +++++++++++++ cube/ir/adapter/prim.py | 2 +- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index b52449bf..9cd4769f 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -18,14 +18,13 @@ class DiffFusion(PlanPass): @staticmethod def apply(execplan: ExectuionPlan) -> ExectuionPlan: - - def is_forward_adapter(adapter: IRAdapter): - return all(not t.is_grad() for t in adapter.inputs()) - + """ + Fuse the non-differentiable adapters into differentiable adapters. + """ cnt = 0 for devid in execplan.devices(): for node in execplan.seq(devid): - if isinstance(node, IRAdapter) and is_forward_adapter(node): + if isinstance(node, IRAdapter) and node.forward and not node.differentiable: ret = DiffFusion.nnfuse(node) cnt = cnt+1 if ret else cnt print(f'successfully generate {cnt} differentiable adapters') diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index a53f753a..64cc254a 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -15,8 +15,7 @@ class Grouping(PlanPass): @staticmethod def apply(execplan: ExectuionPlan) -> ExectuionPlan: """ - Group contiguous forward and contiguous backward - into subgraph + Group contiguous differentiable operators segments """ graph = execplan.graph fgroups, bgroups = Grouping.group(execplan) @@ -48,9 +47,6 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: Returns: Tuple: (fgroups, bgroups) """ - def is_forward_adapter(adapter: IRAdapter) -> bool: - return all(not t.is_grad() for t in adapter.inputs()) - fgroups, bgroups = dict(), dict() for devid in execplan.devices(): fgroups[devid], bgroups[devid] = list(), list() @@ -60,7 +56,7 @@ def is_forward_adapter(adapter: IRAdapter) -> bool: for fnode in seq: if isinstance(fnode, IRFwOperation): fnodes.append(fnode) - if isinstance(fnode, IRAdapter) and fnode.differentiable and is_forward_adapter(fnode): + if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.forward: fnodes.append(fnode) have_backward = all(fnode.mirror in seq for fnode in fnodes) # training diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index 68e5161c..16143809 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -28,6 +28,12 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): device.update(set(tensor.device)) self.device = list(device) + # setup whether this adapter is for forward stage + is_fw = any(not t.is_grad() for t in self.inputs() + self.outputs()) + is_bw = any(t.is_grad() for t in self.inputs() + self.outputs()) + assert not (is_fw and is_bw), "An IRAdapter cannot serve for both forward and backward stage" + self._forward = is_fw + @property def prims(self) -> List[IRAdapterPrim]: if self.is_forward: @@ -62,6 +68,13 @@ def differentiable(self) -> bool: def differentiable(self, val: bool): self._differentiable = val + @property + def forward(self) -> bool: + """ + return True if this adapter serves in forward stage. + """ + return self._forward + def dispatch(self, devid: int): """ Get Adapter for a specific rank diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index 3ad950ef..a5efaac8 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -369,7 +369,7 @@ class SplitAllGatherPrim(AllGatherPrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): super().__init__(itensors, otensors, dim, **kwargs) - self.signature = 'cube.runtime.adapter.nn.allgather_split' + self.signature = 'cube.runtime.adapter.nn.split_allgather' class AllToAllAllToAllPrim(AllToAllPrim): From 8f6440ff7372847d99bdbe9fee2c06a10752cd53 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 31 May 2022 17:39:37 +0800 Subject: [PATCH 0838/1892] add 1f1b scheduling plan --- cube/codegen/codegen.py | 54 +++++++----- cube/compiler.py | 22 ++--- cube/execplan/execplan.py | 65 +++++++------- cube/execplan/planpass/fusion.py | 12 +++ cube/graph/graph.py | 131 +++++++++++++++++++++-------- cube/graph/schedule/__init__.py | 1 + cube/graph/schedule/sched1f1b.py | 118 ++++++++++++++++++++++++++ cube/graph/schedule/strategy.py | 55 ++++++++++++ cube/ir/adapter/adapter.py | 16 ++-- cube/runtime/schedule/__init__.py | 1 + cube/runtime/schedule/sched1f1b.py | 103 +++++++++++++++++++++++ cube/runtime/schedule/strategy.py | 108 ++++++++++++++++++++++++ 12 files changed, 581 insertions(+), 105 deletions(-) create mode 100644 cube/graph/schedule/__init__.py create mode 100644 cube/graph/schedule/sched1f1b.py create mode 100644 cube/graph/schedule/strategy.py create mode 100644 cube/runtime/schedule/__init__.py create mode 100644 cube/runtime/schedule/sched1f1b.py create mode 100644 cube/runtime/schedule/strategy.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 4dd4317d..1c052357 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -12,7 +12,9 @@ from cube.ir.tensor import IRSubTensor from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from cube.ir.adapter import IRWeightReducer, IRAdapter +from cube.ir.adapter.prim import CollectivePrim from cube.graph.graph import IRGraph, IRSegment +from cube.graph.schedule import IRScheduleStrategy from cube.execplan import ExectuionPlan @@ -96,16 +98,15 @@ def init_comm_groups(self): if ranks not in comm_groups: comm_groups.append(ranks) # collect groups from p2p fusion - adapters = [n for n in graph.nodes() if isinstance(n, IRAdapter)] + adapters = [n for n in graph.flatten() if isinstance(n, IRAdapter)] for adapter in adapters: for prim in adapter.prims: - if len(prim.device) == 1: - continue - ranks = list(prim.device) - ranks.sort() - ranks = tuple(ranks) - if ranks not in comm_groups: - comm_groups.append(ranks) + if isinstance(prim, CollectivePrim): + ranks = list(prim.kwargs['ranks']) + ranks.sort() + ranks = tuple(ranks) + if ranks not in comm_groups: + comm_groups.append(ranks) # create communication group self.declare_region.append('# communication groups') for ranks in comm_groups: @@ -365,22 +366,26 @@ def refcount(tensor, node) -> int: refcnt += 1 return refcnt - # generate code with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: fb.insert_body('_ = None') + # body code if len(device_nodes) == 0: fb.insert_body('pass') - for node in device_nodes: - name = self.node_naming(node) - code = self.emit_node(node, name=name) + elif self.execplan.graph.schedule_plan: + code = self.emit_schedule_plan(self.execplan.graph.schedule_plan, device) fb.insert_body(code) - # free unused tensor - for tensor in node.inputs() + node.outputs(): - if isinstance(tensor, IRSubTensor) and not tensor.is_param(): - refcnt = refcount(tensor, node) - if refcnt == 0: - self.vars.free(tensor) + else: + for node in device_nodes: + name = self.node_naming(node) + code = self.emit_node(node, name=name) + fb.insert_body(code) + # free unused tensor + for tensor in node.inputs() + node.outputs(): + if isinstance(tensor, IRSubTensor) and not tensor.is_param(): + refcnt = refcount(tensor, node) + if refcnt == 0: + self.vars.free(tensor) # return code outputs = self.return_naming(self.execplan.graph.outputs()) code = f'return {outputs}' @@ -395,7 +400,18 @@ def refcount(tensor, node) -> int: f.write(code) return code - def emit_node(self, node: IRCell, name: str) -> List[str]: + def emit_schedule_plan(self, schedplan: IRScheduleStrategy, devid: int): + signature = schedplan.signature + kwargs: Dict[str, Any] = schedplan.kwargs(devid) + strkwargs = dict() + for kwarg, val in kwargs.items(): + name = str(val) if not isinstance(val, IRCell) else 'model.'+self.node_naming(val) + strkwargs[kwarg] = name + code = ', '.join(f'{kwarg}={name}' for kwarg, name in strkwargs.items()) + code = f'{signature}({code})' + return code + + def emit_node(self, node: IRCell, name: str) -> str: """ Emit node / subgraph code """ diff --git a/cube/compiler.py b/cube/compiler.py index ded621b9..c24ead97 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -165,9 +165,12 @@ def decorator(fn: Callable) -> Callable: raise RuntimeError(f"Node {node} device is not set") # generate adapter - # graph = AdapterGener.gen(graph) graph = IRAdapterGener.gen(graph) + if graph.schedule_plan: + graph = graph.schedule_plan.apply(graph) + print(graph.schedule_plan) + # to execution plan execplan = ExectuionPlan(graph) @@ -177,18 +180,15 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on diff-fusion operations: {:.2f} s'.format(span)) - start = time.time() - execplan = Grouping.apply(execplan) - span = time.time() - start - print('> planpass on grouping operations: {:.2f} s'.format(span)) - - # start = time.time() - # execplan = GroupingAdapter.apply(execplan) - # span = time.time() - start - # print('> planpass on grouping adapters : {:.2f} s'.format(span)) + # plan pass for computation grouping + if not graph.schedule_plan: + start = time.time() + execplan = Grouping.apply(execplan) + span = time.time() - start + print('> planpass on grouping operations: {:.2f} s'.format(span)) execplan.graph.reset_dependency() - # execplan.analyze(outfile='execplan.png') + execplan.analyze(outfile='execplan.png') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 41918234..691257c1 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -2,12 +2,10 @@ import copy import numpy as np +from cube.ir.cten import IRCell from cube.ir.adapter import IRAdapter from cube.ir.operator import IRBpOperation, IRFwOperation - -from cube.ir.cten import IRCell -from cube.graph.graph import IRGraph -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.graph.graph import IRGraph, IRSegment class ExectuionPlan: @@ -26,38 +24,26 @@ def __init__(self, graph: IRGraph): self._seq[device] = [] self._seq[device].append(node) - # adapter dispatch + # adapter/segment dispatch for devid in self.devices(): - adapters = [node for node in self.at(devid) if isinstance(node, IRAdapter)] - while len(adapters) > 0: - fadapter = adapters[0] - badapter: Optional[IRAdapter] = fadapter.mirror - fnode = fadapter.dispatch(devid) - fidx = self.at(devid).index(fadapter) - self.at(devid)[fidx] = fnode - if badapter: - bnode = badapter.dispatch(devid) - IRCell.make_pair(fnode, bnode) - bidx = self.at(devid).index(badapter) - self.at(devid)[bidx] = bnode - # remove un-dispatched adapter - adapters.pop(0) - if badapter: - adapters.remove(badapter) + nodes = [node for node in self.at(devid) if isinstance(node, (IRAdapter, IRSegment))] + while len(nodes) > 0: + # dispatch + fnode = nodes[0] + fidx = self.at(devid).index(fnode) + fnode_dev = fnode.dispatch(devid) + self.at(devid)[fidx] = fnode_dev + nodes.pop(0) + if fnode.mirror is not None: + bidx = self.at(devid).index(fnode.mirror) + nodes.remove(fnode.mirror) + self.at(devid)[bidx] = fnode_dev.mirror - # check whether graph output is replicated across device - # FIXME: should use adapter to generate communication for - # traning logic output + # TODO: adapter support for return consistency for output in graph.outputs(): - devices = self.devices() - ltensor: IRFullTensor = output.parent # logic tensor - if isinstance(output, IRSubTensor): - for ptensor, producer in zip(ltensor.ptensors, ltensor.producers): - if ptensor == output: - if producer.device[0] in devices: - devices.remove(producer.device[0]) - if len(devices) != 0: - raise NotImplementedError("Require return values of training logic is replicated across nodes.") + for devid in self.devices(): + ptensors = [pt for pt in output.parent.ptensors if pt == output and devid in pt.device] + assert len(ptensors) >= 1, f"Missing full graph output tensor {output} in device {devid}" @property def graph(self) -> IRGraph: @@ -89,6 +75,19 @@ def at(self, devid: int) -> List[IRCell]: assert devid in self._seq, f"device id {devid} not exists" return self._seq[devid] + def flatten(self, devid: int) -> List[IRCell]: + """ + Flatten the sequence by expanding segments + """ + assert devid in self._seq, f"device id {devid} not exists" + nodes = [] + for node in self._seq[devid]: + if isinstance(node, IRSegment): + nodes += node.nodes() + else: + nodes.append(node) + return nodes + def set(self, devid: int, seq: List[IRCell]): """ Set device sequence diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 9cd4769f..5e1cb2d1 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -1,4 +1,5 @@ from typing import List +from cube.graph.graph import IRSegment from cube.ir.adapter import IRAdapter @@ -27,6 +28,17 @@ def apply(execplan: ExectuionPlan) -> ExectuionPlan: if isinstance(node, IRAdapter) and node.forward and not node.differentiable: ret = DiffFusion.nnfuse(node) cnt = cnt+1 if ret else cnt + if isinstance(node, IRSegment) and node.forward: + for fnode in node.nodes(): + if isinstance(fnode, IRAdapter) and node.forward and not fnode.differentiable: + ret = DiffFusion.nnfuse(fnode) + if not ret: + raise NotImplementedError( + f"adapter within IRSegment cannot fuse to differientiable adapter" + f"\nforward: {fnode.extra_repr()}" + f"\nbackward: {fnode.mirror.extra_repr()}" + ) + cnt = cnt + 1 print(f'successfully generate {cnt} differentiable adapters') return execplan diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 0aca94e1..ea07f194 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,7 +7,7 @@ will be inserted at scheduling time. """ -from typing import Union, Tuple, List, Optional, Dict +from typing import Any, Union, Tuple, List, Optional, Dict import copy from cube.ir.cten import IRTensor, IRCell @@ -20,12 +20,16 @@ class IRSegment(IRCell): """ - A segment refers to a piece of workload of IRGraph + A distributed sub-graph representing a piece of workload in parent IRGraph """ def __init__(self, nodes: List[IRCell], inputs: List[IRSubTensor], outputs: List[IRSubTensor]): - self._nodes = nodes super().__init__('segment', '', len(inputs), len(outputs), init_outputs=False) + + self._nodes = nodes + self._idevice = [t.device for t in inputs] + self._odevice = [t.device for t in outputs] + for idx, val in enumerate(inputs): self.set_input(idx, val) for idx, val in enumerate(outputs): @@ -51,8 +55,36 @@ def nodes(self, idx: Optional[int] = None) -> Union[IRCell, List[IRCell]]: else: return copy.copy(self._nodes) + def dispatch(self, devid: int, for_mirror=True) -> Optional[IRCell]: + """ + Instantiate from distributed representation to a + device-specific sub-graph. + + The mirror will also be dispatched if it is not None. + + Return the dispatched segment + """ + if devid not in self.device: + return None + if len(self.device) == 1 and self.device == [devid]: + return self + itensors = [t for t, device in zip(self.inputs(), self._idevice) if devid in device] + otensors = [t for t, device in zip(self.outputs(), self._odevice) if devid in device] + nodes = [n for n in self.nodes() if devid in n.device] + for idx, adapter in enumerate(nodes): + if isinstance(adapter, IRAdapter): + nodes[idx] = adapter.dispatch(devid) + fseg = IRSegment(nodes, itensors, otensors) + fseg._id = self._id + # dispatch for mirror + if for_mirror and isinstance(self.mirror, IRSegment): + bseg = self.mirror.dispatch(devid, for_mirror=False) + IRCell.make_pair(fseg, bseg) + return fseg + def __repr__(self): - return f'Segment{self._id}(inputs={self.inputs()}, outputs={self.outputs()})' + name = ('f' if self.forward else 'b') + 'Segment' + return f'{name}{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' def extra_repr(self) -> str: dscp = repr(self) @@ -63,9 +95,8 @@ def extra_repr(self) -> str: class IRGraph(IRCell): """ - PyTorch IR Graph - - The IR Graph only contains forward graph + IR Graph. The hyperGraph for representing distributed + graph. """ def __init__(self, @@ -78,6 +109,8 @@ def __init__(self, self._parameters = list() self._full_tensors: Dict[int, IRFullTensor] = dict() + self._schedule_strategy = None + if inputs is None: inputs = IRGraph.get_inputs(nodes) if outputs is None: @@ -113,6 +146,14 @@ def __init__(self, self.reset_dependency() + @property + def schedule_plan(self) -> Optional[Any]: + return self._schedule_strategy + + @schedule_plan.setter + def schedule_plan(self, val: Optional[Any]): + self._schedule_strategy = val + def reset_dependency(self): """ Reset the node dataflow dependency @@ -122,30 +163,19 @@ def reset_dependency(self): for node in self._nodes: node.clear_predecessor() node.clear_successor() - # set node predecessors and successors - for src_idx in range(len(self._nodes)): - src_node = self._nodes[src_idx] - for dst_node in self._nodes[src_idx+1:]: - # we don't consider dependencies among adapter - if isinstance(src_node, IRAdapter) and isinstance(dst_node, IRAdapter): - continue - for out_idx, out_tensor in enumerate(src_node.outputs()): - if not isinstance(out_tensor, IRTensor): - continue - for in_idx, in_tensor in enumerate(dst_node.inputs()): - if not isinstance(in_tensor, IRTensor): - continue - if out_tensor.overlap(in_tensor): - src_node.add_successor(out_idx, dst_node) - dst_node.add_predecessor(in_idx, src_node) - # set mirror as control dependency - for idx1, node1 in enumerate(self._nodes): - node2 = node1.mirror - if isinstance(node2, IRCell) and node2 in self._nodes: - idx2 = self._nodes.index(node2) - if idx1 < idx2: - node1.add_successor(-1, node2) - node2.add_predecessor(-1, node1) + # TODO: adapter dependency not set + for ftensor in self._full_tensors.values(): + for ptensor, producer in zip(ftensor.ptensors, ftensor.producers): + for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + if ptensor.overlap(ctensor): + pidx = producer.outputs().index(ptensor) + cidx = consumer.inputs().index(ctensor) + producer.add_successor(pidx, consumer) + consumer.add_predecessor(cidx, producer) + # set mirror as control dependency + if producer.mirror and isinstance(producer, IRFwOperation): + producer.add_successor(-1, producer) + producer.mirror.add_predecessor(-1, producer) def parameters(self): """ @@ -159,7 +189,7 @@ def full_tensors(self): """ return list(self._full_tensors.values()) - def nodes(self, index: Optional[int] = None): + def nodes(self, index: Optional[int] = None) -> Union[IRCell, List[IRCell]]: """ Get node at position index """ @@ -203,23 +233,40 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: """ inputs, outputs = [], [] for node in nodes: + assert not isinstance(node, IRSegment), 'A segment cannot be in other segments' # update inputs itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] for itensor in itensors: - producers = [p for p in itensor.parent.producers if p.device == node.device] - # no producer means a weight + producers = [p for p in itensor.parent.producers if set(p.device).issubset(set(node.device))] + # no producer means a weight or cross device-group if len(producers) == 0 or any(p not in nodes for p in producers): inputs.append(itensor) # update outputs otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] for otensor in otensors: - consumers = [c for c in otensor.parent.consumers if c.device == node.device] - # no consumer usually means the loss + consumers = [c for c in otensor.parent.consumers if set(c.device).issubset(set(node.device))] + # no consumer usually means the loss or cross device-group if len(consumers) == 0 or any(c not in nodes for c in consumers): outputs.append(otensor) segment = IRSegment(nodes, inputs, outputs) return segment + def group(self, nodes: List[IRCell]) -> IRSegment: + """ + Group consecutive nodes into IRSegment. + + Currently this interface will break the dependency, + it can only be used after user policy + """ + allnodes = self.nodes() + indices = [allnodes.index(n) for n in nodes] + minidx, maxidx = min(indices), max(indices) + assert maxidx - minidx + 1 == len(nodes), "nodes are not consecutive" + segment = self.segment(nodes) + self._nodes = allnodes[:minidx] + [segment] + allnodes[maxidx+1:] + # FIXME: set segment dependnecy + return segment + def detach(self, node: IRCell, reset_dependency=False) -> int: """ Detach (remove) a node from current graph. @@ -283,6 +330,18 @@ def attach(self, node: IRCell, index, reset_dependency=False): self.reset_dependency() return + def flatten(self) -> List[IRCell]: + """ + Flattent the graph by expanding nodes + """ + nodes = [] + for node in self.nodes(): + if isinstance(node, IRSegment): + nodes += node.nodes() + else: + nodes.append(node) + return nodes + @staticmethod def get_inputs(nodes: List[IRCell]): """ diff --git a/cube/graph/schedule/__init__.py b/cube/graph/schedule/__init__.py new file mode 100644 index 00000000..3712eea3 --- /dev/null +++ b/cube/graph/schedule/__init__.py @@ -0,0 +1 @@ +from cube.graph.schedule.strategy import IRScheduleStrategy \ No newline at end of file diff --git a/cube/graph/schedule/sched1f1b.py b/cube/graph/schedule/sched1f1b.py new file mode 100644 index 00000000..b7825d0d --- /dev/null +++ b/cube/graph/schedule/sched1f1b.py @@ -0,0 +1,118 @@ + +from typing import Dict, Tuple +from cube.ir.adapter.adapter import IRAdapter +from cube.ir.cten import IRCell + +from cube.graph.graph import IRGraph, IRSegment +from cube.graph.schedule import IRScheduleStrategy + + +class IRSchedule1F1B(IRScheduleStrategy): + """ + 1F1B Scheduling + + This requires a micro-batch can be grouped into continguous segments + which are placed on distinct device groups (refered as a stage): + + [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] + [Recv-Backward] Backward-Segment [Send-Backward] + """ + + def __init__(self, num_microbatch: int, devmesh: Tuple[Tuple[int]], recompute=False): + super().__init__(num_microbatch, devmesh) + self.signature = 'cube.runtime.schedule.Schedule1F1B.run' + # forward body + self.segment = dict() + # forward send + self.sfadapter = dict() + # forward recv + self.rfadapter = dict() + # backard send + self.sbadapter = dict() + # backward recv + self.rbadapter = dict() + # num_stage + self.num_stages = len(devmesh) + # stage id + self.stage_id = dict() + # recompute + self.recompute = recompute + + def apply(self, graph: IRGraph) -> IRGraph: + graph = IRSchedule1F1B.segmentation(graph, self.devmesh) + for stage_id, devices in enumerate(self.devmesh): + for devid in devices: + nodes = [n for n in graph.nodes() if devid in n.device] + # forward body + fsegments = [seg for seg in nodes if isinstance(seg, IRSegment) and seg.forward] + assert len(fsegments) == 1, "find more than one segment." + fsegment = fsegments[0] + self.segment[devid] = fsegment + fidx = nodes.index(fsegment) + bidx = nodes.index(fsegment.mirror) + # adapters + adapters = [adapter for adapter in nodes if isinstance(adapter, IRAdapter)] + # forward sends + forward_sends = [n for n in adapters if n.forward and nodes.index(n) > fidx] + if stage_id == self.num_stages - 1: + assert len(forward_sends) == 0, f"stage: {stage_id}: last stage should not send forward outputs" + self.sfadapter[devid] = None + else: + assert len(forward_sends) == 1, f"stage: {stage_id}: last stage should not send forward outputs" + self.sfadapter[devid] = forward_sends[0] + # forward recvs + forward_recvs = [n for n in adapters if n.forward and nodes.index(n) < fidx] + if stage_id == 0: + assert len(forward_recvs) == 0, f"stage: {stage_id}: first stage should not recv inputs" + self.rfadapter[devid] = None + else: + assert len(forward_recvs) == 1, f"stage: {stage_id}: non-first stage should recv 1 inputs" + self.rfadapter[devid] = forward_recvs[0] + # backward sends + backward_sends = [n for n in adapters if not n.forward and nodes.index(n) > bidx] + if stage_id == 0: + assert len(backward_sends) == 0, f"stage: {stage_id}: first stage should not send back gradient" + self.sbadapter[devid] = None + else: + assert len(backward_sends) == 1, f"stage: {stage_id}: non-first stage should not send back gradient" + self.sbadapter[devid] = backward_sends[0] + # backward recvs + backward_recvs = [n for n in adapters if not n.forward and nodes.index(n) < bidx] + if stage_id == self.num_stages - 1: + assert len(backward_recvs) == 0, f"stage: {stage_id}: last stage should not recv gradient" + self.rbadapter[devid] = None + else: + assert len(backward_recvs) == 1, f"stage: {stage_id}: non-last stage should recv 1 gradient" + self.rbadapter[devid] = backward_recvs[0] + # stage id + self.stage_id[devid] = stage_id + return graph + + def kwargs(self, devid: int) -> Dict[str, IRCell]: + """ + return kwargs for runtime caller + """ + return dict( + segment = self.segment[devid], + sfadapter = self.sfadapter[devid], + rfadapter = self.rfadapter[devid], + sbadapter = self.sbadapter[devid], + rbadapter = self.rbadapter[devid], + dataloader = 'dataloader', + stage_id = self.stage_id[devid], + num_stages = self.num_stages, + num_microbatch = self.num_microbatch, + recompute = self.recompute + ) + + def __repr__(self) -> str: + dscp = '' + for mesh in self.devmesh: + dscp += (f"1F1B-Schedule-stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" + f" segment = {self.segment[mesh[0]]}\n" + f" send-fw = {self.sfadapter[mesh[0]]}\n" + f" recv-fw = {self.rfadapter[mesh[0]]}\n" + f" send-bw = {self.sbadapter[mesh[0]]}\n" + f" recv-bw = {self.rbadapter[mesh[0]]}\n" + f")\n") + return dscp diff --git a/cube/graph/schedule/strategy.py b/cube/graph/schedule/strategy.py new file mode 100644 index 00000000..5c254e44 --- /dev/null +++ b/cube/graph/schedule/strategy.py @@ -0,0 +1,55 @@ +from typing import Tuple, Dict, Any +from cube.graph.graph import IRGraph +from cube.ir.adapter.adapter import IRAdapter +from cube.ir.cten import IRCell +from cube.ir.operator import IRFwOperation + + +class IRScheduleStrategy: + + def __init__(self, num_microbatch: int, devmesh: Tuple[Tuple[int]]) -> None: + self.num_microbatch = num_microbatch + self.devmesh = devmesh + self.signature: str = '' + + def apply(self, graph: IRGraph) -> IRGraph: + raise NotImplementedError + + def kwargs(self, device: int) -> Dict[str, Any]: + raise NotImplementedError + + @staticmethod + def segmentation(graph: IRGraph, devmesh: Tuple[Tuple[int]]) -> IRGraph: + """ + Utilities for grouping operators into segments with device mesh + """ + stages = [[] for _ in range(len(devmesh))] + for node in graph.nodes(): + for meshid, devices in enumerate(devmesh): + if set(node.device).issubset(set(devices)): + stages[meshid].append(node) + break + # grouping + for stage in stages: + fconsecutive, bconsecutive = [], [] + for node in stage: + if isinstance(node, IRFwOperation) or (isinstance(node, IRAdapter) and node.forward): + fconsecutive.append(node) + if node.mirror: + bconsecutive.append(node.mirror) + else: + assert len(fconsecutive) == len(bconsecutive) or len(bconsecutive) == 0, 'mismatch number of forward and backward operators.' + if len(fconsecutive) != 0: + fsegment = graph.group(fconsecutive) + if len(bconsecutive) != 0: + bsegment = graph.group(bconsecutive[::-1]) + IRCell.make_pair(fsegment, bsegment) + fconsecutive, bconsecutive = [], [] + return graph + + @staticmethod + def merging(graph: IRGraph) -> IRGraph: + """ + merge the adapters into one + """ + pass diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index 16143809..a3fa1a06 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -75,7 +75,7 @@ def forward(self) -> bool: """ return self._forward - def dispatch(self, devid: int): + def dispatch(self, devid: int, for_mirror=True): """ Get Adapter for a specific rank @@ -101,11 +101,15 @@ def dispatch(self, devid: int): for itensor in inputs: prims.append(IdentityPrim(itensor)) # dispatch - adapter = IRAdapter(inputs, outputs) - adapter.prims = prims - adapter.name = self.name - adapter._id = self._id - return adapter + fadapter = IRAdapter(inputs, outputs) + fadapter.prims = prims + fadapter.name = self.name + fadapter._id = self._id + # dispatch for mirror + if for_mirror and isinstance(self.mirror, IRAdapter): + badapter = self.mirror.dispatch(devid, for_mirror=False) + IRCell.make_pair(fadapter, badapter) + return fadapter def __repr__(self): return f'Adapter-{self._id}{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' diff --git a/cube/runtime/schedule/__init__.py b/cube/runtime/schedule/__init__.py new file mode 100644 index 00000000..b2db67e5 --- /dev/null +++ b/cube/runtime/schedule/__init__.py @@ -0,0 +1 @@ +from cube.runtime.schedule.sched1f1b import Schedule1F1B \ No newline at end of file diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py new file mode 100644 index 00000000..555208a8 --- /dev/null +++ b/cube/runtime/schedule/sched1f1b.py @@ -0,0 +1,103 @@ +from typing import Callable, Iterable +import torch + +from cube.runtime.schedule.strategy import ScheduleABC + + +class Schedule1F1B(ScheduleABC): + + @staticmethod + def run(segment: Callable, # forward body + rfadapter: Callable, # recv_forward adapter + sfadapter: Callable, # send_forward adapter + rbadapter: Callable, # recv_backward adapter + sbadapter: Callable, # send_backward adapter + dataloader: Iterable, + stage_id: int, + num_stages: int, + num_microbatch: int, + recompute=False): + + num_warmup_microbatches = num_stages - 1 - stage_id + num_warmup_remaining = num_microbatch - num_warmup_microbatches + + # warmup + for _ in range(num_warmup_microbatches): + # recv forward + # print(f'rank[{torch.distributed.get_rank()}]: line26: recving forward') + inputs = Schedule1F1B.adapter_step(rfadapter) + inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs + # forward + Schedule1F1B.push_tail('inputs', inputs) + if recompute: + with torch.no_grad(): + outputs = Schedule1F1B.forward_step(segment, *inputs) + Schedule1F1B.push_tail('outputs', None) + else: + # print(f'rank[{torch.distributed.get_rank()}]: line36: forward') + outputs = Schedule1F1B.forward_step(segment, *inputs) + Schedule1F1B.push_tail('outputs', outputs) + # send forward + # print(f'rank[{torch.distributed.get_rank()}]: line40 send forward') + Schedule1F1B.adapter_step(sfadapter, *outputs) + + if num_warmup_remaining > 0: + # print(f'rank[{torch.distributed.get_rank()}]: line44 recv forward') + inputs = Schedule1F1B.adapter_step(rfadapter) + inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs + + # steady + for i in range(num_warmup_remaining): + # forward + Schedule1F1B.push_tail('inputs', inputs) + if recompute: + with torch.no_grad(): + outputs = Schedule1F1B.forward_step(segment, *inputs) + Schedule1F1B.push_tail('outputs', None) + else: + # print(f'rank[{torch.distributed.get_rank()}]: line 57 forward') + outputs = Schedule1F1B.forward_step(segment, *inputs) + Schedule1F1B.push_tail('outputs', outputs) + + # send forward recv backward + # print(f'rank[{torch.distributed.get_rank()}]: line62 send forward recv backward') + grads = Schedule1F1B.exchange(sfadapter, rbadapter, stage_id, *outputs) + grads = (None,) if len(grads) == 0 else grads + + # backward + inputs, outputs = Schedule1F1B.pop_head('inputs'), Schedule1F1B.pop_head('outputs') + if recompute: + assert outputs is None + outputs = Schedule1F1B.forward_step(segment, *inputs) + # print(f'rank[{torch.distributed.get_rank()}]: line71 backward') + input_grads = Schedule1F1B.backward_step(inputs, outputs, grads) + + # send backward recv forward + if i != num_warmup_remaining - 1: + # print(f'rank[{torch.distributed.get_rank()}]: line77 send backward recv forward') + inputs = Schedule1F1B.exchange(sbadapter, rfadapter, stage_id, *input_grads) + inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs + else: + # send backward + # print(f'rank[{torch.distributed.get_rank()}]: line82 send backward') + Schedule1F1B.adapter_step(sbadapter, *input_grads) + + # cooldown + for i in range(num_warmup_microbatches): + inputs, outputs = Schedule1F1B.pop_head('inputs'), Schedule1F1B.pop_head('outputs') + # recv backward + # print(f'rank[{torch.distributed.get_rank()}]: line89 recv backward') + grads = Schedule1F1B.adapter_step(rbadapter) + grads = (None,) if len(grads) == 0 else grads + # backward + if recompute: + assert outputs is None + outputs = Schedule1F1B.forward_step(segment, *inputs) + # print(f'rank[{torch.distributed.get_rank()}]: line96 backward') + input_grads = Schedule1F1B.backward_step(inputs, outputs, grads) + # send backward + # print(f'rank[{torch.distributed.get_rank()}]: line99 send backward') + Schedule1F1B.adapter_step(sbadapter, *input_grads) + + Schedule1F1B.assert_empty() + # print(f'rank[{torch.distributed.get_rank()}]: ok here') diff --git a/cube/runtime/schedule/strategy.py b/cube/runtime/schedule/strategy.py new file mode 100644 index 00000000..3b78c89d --- /dev/null +++ b/cube/runtime/schedule/strategy.py @@ -0,0 +1,108 @@ +from typing import Any, Callable, Dict, Iterable, List +import torch + +from cube.profiler.timer import CudaTimer + + +class ScheduleABC: + + status: Dict[str, List[torch.Tensor]] = dict() + + @staticmethod + def forward_step(segment: Callable, *args, **kwargs): + """ + forward pass + """ + CudaTimer().start('forward') + outputs = segment(*args, **kwargs) + CudaTimer().stop('forward') + if not isinstance(outputs, tuple): + outputs = (outputs,) + return outputs + + @staticmethod + def backward_step(itensors: List[torch.Tensor], + otensors: List[torch.Tensor], + otensor_grads: List[torch.Tensor]) -> List[torch.Tensor]: + """ + backward pass + """ + for tensor in itensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + tensor.retain_grad() + CudaTimer().start("backward") + torch.autograd.backward(otensors, grad_tensors=otensor_grads) + CudaTimer().stop("backward") + itensor_grads = [] + for tensor in itensors: + if torch.is_tensor(tensor) and tensor.requires_grad: + itensor_grads.append(tensor.grad) + else: + itensor_grads.append(None) + return tuple(itensor_grads) + + @staticmethod + def dataloader_step(dataloader: Iterable): + data = next(dataloader) + if not isinstance(data, tuple): + data = (data,) + return data + + @staticmethod + def adapter_step(adapter: Callable, *args): + """ + adapter pass + """ + if adapter is None: return () + CudaTimer().start('adapter') + outputs = adapter(*args) + CudaTimer().stop('adapter') + if not isinstance(outputs, tuple): + outputs = (outputs,) + return outputs + + @staticmethod + def exchange(sadapter: Callable, radapter: Callable, stage_id: int, *args): + """ + send adapter and recv adapter + """ + # TODO: optimize with batch operators + if stage_id % 2 == 0: + ScheduleABC.adapter_step(sadapter, *args) + outs = ScheduleABC.adapter_step(radapter) + else: + outs = ScheduleABC.adapter_step(radapter) + ScheduleABC.adapter_step(sadapter, *args) + return outs + + @staticmethod + def push_tail(name: str, val: Any): + if name not in ScheduleABC.status: + ScheduleABC.status[name] = [] + ScheduleABC.status[name].append(val) + + @staticmethod + def push_head(name: str, val: Any): + if name not in ScheduleABC.status: + ScheduleABC.status[name] = [] + ScheduleABC.status[name].insert(0, val) + + @staticmethod + def pop_head(name: str): + assert name in ScheduleABC.status, f"{name} is empty" + out = ScheduleABC.status[name].pop(-1) + if len(ScheduleABC.status[name]) == 0: + del ScheduleABC.status[name] + return out + + @staticmethod + def pop_tail(name: str): + assert name in ScheduleABC.status, f"{name} is empty" + out = ScheduleABC.status[name].pop(0) + if len(ScheduleABC.status[name]) == 0: + del ScheduleABC.status + return out + + @staticmethod + def assert_empty(): + assert len(ScheduleABC.status) == 0, f"status is not empty. Got field {list(ScheduleABC.status.keys())}" From 49c3809116b9623952f0d154bc0ea0d9d62ee38b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 31 May 2022 17:40:08 +0800 Subject: [PATCH 0839/1892] add 1f1b scheduling example --- examples/mlp/policy/megatron_pptp.py | 67 ++++++++++++++++++++++++ examples/mlp/policy/pipe1f1b_parallel.py | 60 --------------------- 2 files changed, 67 insertions(+), 60 deletions(-) create mode 100644 examples/mlp/policy/megatron_pptp.py delete mode 100644 examples/mlp/policy/pipe1f1b_parallel.py diff --git a/examples/mlp/policy/megatron_pptp.py b/examples/mlp/policy/megatron_pptp.py new file mode 100644 index 00000000..9d567842 --- /dev/null +++ b/examples/mlp/policy/megatron_pptp.py @@ -0,0 +1,67 @@ +from typing import Tuple +import numpy as np + +from cube.graph.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.schedule.sched1f1b import IRSchedule1F1B + + +def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + + +def PAS(graph: IRGraph, resource): + + # assert resource.ngpus == 8, "should apply on 8 gpus" + num_stage = 4 + num_tp = resource.ngpus // num_stage + num_microbatch = resource.ngpus + + _, tp_mesh = create_mesh(resource.ngpus, (num_stage, num_tp)) + print(f'> pipeline-tensor parallel group: {tp_mesh}') + assert len(tp_mesh) == num_stage + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + node2stage = lambda node: min(fnodes.index(node) // (len(fnodes) // num_stage), num_stage-1) + + for idx, node in enumerate(fnodes): + # get tensor parallel group + sid = node2stage(node) + tp_group = tp_mesh[sid] + # partition + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, dict(idx=1, dim=idx%2, num=num_tp)) + if tp_nodes is None: + tp_nodes = graph.replicate(node, times=num_tp) + # assign + for devid, node in zip(tp_group, tp_nodes): + graph.assign(node, devid) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + mesh = tp_mesh[0] + rnodes = graph.replicate(node, times=num_tp) + for devid, rnode in zip(mesh, rnodes): + graph.assign(rnode, devid) + # setup schedule to 1F1B + schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) + graph.schedule_plan = schedule + return graph diff --git a/examples/mlp/policy/pipe1f1b_parallel.py b/examples/mlp/policy/pipe1f1b_parallel.py deleted file mode 100644 index d0fea253..00000000 --- a/examples/mlp/policy/pipe1f1b_parallel.py +++ /dev/null @@ -1,60 +0,0 @@ -from cube.graph.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation - - -def PAS(graph: IRGraph, resource): - """ - 1F1B scheduling - """ - - num_micro_batch = resource.ngpus - num_stage = resource.ngpus - - fstages = [list() for _ in range(num_micro_batch * num_stage)] - - def f(micro_batch_id: int, stage_id: int): - return fstages[micro_batch_id * num_stage + stage_id] - - def b(micro_batch_id: int, stage_id: int): - fstage = f(micro_batch_id, stage_id) - bstage = [fnode.mirror for fnode in fstage][::-1] - return bstage - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - stage_op_num = len(fnodes) // num_stage - for idx, node in enumerate(fnodes): - stage = min(idx // stage_op_num, num_stage - 1) - # partition at batch dimension - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=0, num=num_micro_batch)) - for mid, sub_node in enumerate(sub_nodes): - f(mid, stage).append(sub_node) - graph.assign(sub_node, stage) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - graph.assign(node, 0) - - # 1F1B scheduling - seqs = list() - # warmup - for mid in range(num_micro_batch): - for stage in range(num_stage - mid): - seqs += f(mid, stage) - # steady + cooldown: - for mid in range(num_micro_batch): - # enqueue backward - for stage in range(num_stage-1, -1, -1): - seqs += b(mid, stage) - # enqueue forward - for stage in range(num_stage): - f_mid = mid + num_stage - stage - if f_mid >= num_micro_batch: - continue - seqs += f(f_mid, stage) - for node in seqs: - print(node) - graph.partial_set_order(seqs) - - return graph From 6245e034951833c7242a0a83e94c8e4734943d8b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Jun 2022 09:43:44 +0800 Subject: [PATCH 0840/1892] add solver based solution --- cube/{search => tetris}/composer.py | 2 +- cube/tetris/solver.py | 244 ++++++++++++++++++++++++++++ 2 files changed, 245 insertions(+), 1 deletion(-) rename cube/{search => tetris}/composer.py (99%) create mode 100644 cube/tetris/solver.py diff --git a/cube/search/composer.py b/cube/tetris/composer.py similarity index 99% rename from cube/search/composer.py rename to cube/tetris/composer.py index cd20e7f3..acbfea75 100644 --- a/cube/search/composer.py +++ b/cube/tetris/composer.py @@ -632,7 +632,7 @@ def search(ndevs, nmicros, visualize=False): # premise # micros = Composer.premise(uniform_staging, ndevs, nmicros) # micros = Composer.premise(chimera_staging, ndevs, nmicros) - micros = Composer.premise(mbart_staging, ndevs, nmicros) + micros = Composer.premise(uniform_staging, ndevs, nmicros) print('============== Premise ================') for idx, micro in enumerate(micros): print(f'microbatch #{idx}:') diff --git a/cube/tetris/solver.py b/cube/tetris/solver.py new file mode 100644 index 00000000..86007a67 --- /dev/null +++ b/cube/tetris/solver.py @@ -0,0 +1,244 @@ +""" +A solver based solution for scheduling plan +""" + +from typing import List, Optional, Tuple +from enum import Enum + +from z3 import * +import time +import copy + + +gsolver = Solver() + + +class Block: + + class BType(Enum): + FW = 'forward' + BW = 'backward' + + def __init__(self, mid: int, btype: BType, name: str, mem=1): + global _uid + global gsolver + self.mid = mid + self.step = Int(name) + self.memory = mem if btype == Block.BType.FW else 0-mem + gsolver.add(self.step >= 1) + self.btype = btype + + @staticmethod + def add_dependency(blk1, blk2): + """ + add dependency: blk1 -> blk2 + """ + global gsolver + gsolver.add(blk1.step < blk2.step) + + def __repr__(self): + return f'f{self.mid}' if self.btype == Block.BType.FW else f'b{self.mid}' + + +class SchedulePlan: + + def __init__(self, ndevs: int) -> None: + + self._blocks: List[List[Block]] = [[] for _ in range(ndevs)] + self.ndevs = ndevs + self._nsteps = None + self._mem = None + self._solution: Optional[z3.z3.ModelRef] = None + + @property + def nblocks(self) -> int: + return sum(len(blks) for blks in self._blocks) + + @property + def nsteps(self) -> int: + if self._solution is None: + return -1 + return self._solution.eval(self._nsteps).as_long() + + @property + def mem(self) -> int: + if self._mem is None: + return -1 + return self._solution.eval(self._mem).as_long() + + def blocks(self, devid: Optional[int] = None) -> List[Block]: + if isinstance(devid, int): + return copy.copy(self._blocks[devid]) + else: + allblocks = [] + for blks in self._blocks: + allblocks += blks + return allblocks + + def position(self, block: Block) -> Tuple[int, int]: + """ + get block position (device, time) after the search + """ + # device + for devid in range(self.ndevs): + if block in self.blocks(devid): + break + else: + assert False, 'block not in schedule plan' + # time step + step = None + if self._solution is not None: + step = self._solution[block.step] + return (devid, step) + + def add_block(self, block: Block, device: int): + global gsolver + for blk in self._blocks[device]: + gsolver.add(blk.step != block.step) + self._blocks[device].append(block) + # set plan step variable + if self._nsteps is None: + self._nsteps = block.step + else: + self._nsteps = If(block.step > self._nsteps, block.step, self._nsteps) + + def set_memory(self): + nblocks = max(len(blks) for blks in self._blocks) + # mems = [IntVector(f'memdev{devid}', nblocks) for devid in range(self.ndevs)] + peaks = [] + for devid in range(self.ndevs): + peak = 0 + curr = 0 + for step in range(0, nblocks): + mem = 0 + for block in self.blocks(devid): + mem = If(block.step == step, block.memory, mem) + curr = mem + curr + peak = If(curr > peak, curr, peak) + peaks.append(peak) + # global peak + globalpeak = peaks[0] + for devid in range(1, self.ndevs): + globalpeak = If(peaks[devid] > globalpeak, peaks[devid], globalpeak) + self._mem = globalpeak + return globalpeak + + def set_solution(self, solution: z3.z3.ModelRef): + self._solution = solution + + def solve(self): + global gsolver + tic = time.time() + opt_step = max(len(blks) for blks in self._blocks) + while True: + gsolver.push() + gsolver.add(self._nsteps == opt_step) + if gsolver.check() == sat: + print(f'find optimal step in {opt_step} steps') + solution = gsolver.model() + self.set_solution(solution) + gsolver.pop() + break + else: + print(f'fail to find solution for {opt_step} steps') + gsolver.pop() + opt_step += 1 + toc = time.time() + print('search time: {:.2f} seconds'.format(toc-tic)) + print('solution:') + print(self) + + # search for optimal memory + tic = time.time() + opt_mem = 1 + self.set_memory() + gsolver.push() + gsolver.add(self._nsteps == opt_step) + while True: + gsolver.push() + gsolver.add(self._mem == opt_mem) + if gsolver.check() == sat: + print(f'find optimal memory {opt_mem}') + solution = gsolver.model() + self.set_solution(solution) + gsolver.pop() + break + else: + print(f'fail to find solution for memory {opt_mem}') + gsolver.pop() + opt_mem += 1 + gsolver.pop() + toc = time.time() + print('search memory time: {:.2f} seconds'.format(toc-tic)) + print('solution:\n', self) + + # self.iter_space(opt_step) + + def iter_space(self, nsteps: int, memory: int): + """ + iterate all solutions find by solver + """ + global gsolver + gsolver.push() + gsolver.add(self._nsteps == nsteps) + models = [] + while gsolver.check() == sat: + model = gsolver.model() + models.append(model) + block = [] + for d in model: + assert not d.arity() > 0, 'uniterpreted functions found' + c = d() + block.append(c != model[d]) + gsolver.add(Or(block)) + if len(models) % 10 == 0: + print(f'find {len(models)} solutions..') + gsolver.pop() + + + def __repr__(self) -> str: + if self._solution is None: + return 'Unsolved Schedule Plan.' + namelen = 2 + dscp = '' + for devid in range(self.ndevs): + blocks = self.blocks(devid) + steps = [self.position(blk)[1] for blk in blocks] + for step in range(1, self.nsteps+1): + if step not in steps: + dscp += '-' * namelen + ' ' + else: + idx = steps.index(step) + dscp += '{: <2}'.format(repr(blocks[idx])) + ' ' + dscp += '\n' + return dscp + + +if __name__ == '__main__': + + def uniform_staging(ndevs: int, nmicros) -> SchedulePlan: + """ + f b + f b + f b + f b + """ + sched = SchedulePlan(ndevs) + for mid in range(nmicros): + fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}') for devid in range(ndevs)] + bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}') for devid in range(ndevs)][::-1] + blocks = fblocks + bblocks + for idx in range(ndevs * 2 - 1): + Block.add_dependency(blocks[idx], blocks[idx+1]) + for devid in range(ndevs): + sched.add_block(fblocks[devid], devid) + sched.add_block(bblocks[ndevs-1-devid], devid) + return sched + + ndevs = 4 + nmicros = 4 + + sched = uniform_staging(ndevs, nmicros) + sched.solve() + + From 7fac25d9d0a38b9a91b359585f2d05b7a7eb24f8 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Mon, 6 Jun 2022 11:50:22 +0800 Subject: [PATCH 0841/1892] add cube.init to WRF2 --- examples/wrf/wrf2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py index 7142fb90..add2ad43 100644 --- a/examples/wrf/wrf2.py +++ b/examples/wrf/wrf2.py @@ -406,6 +406,8 @@ def solve_tridiagonal_(self, import matplotlib.pyplot as plt from matplotlib.ticker import ScalarFormatter + cube.init() + nz = 16 dz = 1. / 16 ny = 128 From 4fd10424f6166de70fbc4d83f52df0324b7a70fa Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Mon, 6 Jun 2022 11:52:41 +0800 Subject: [PATCH 0842/1892] add missing next --- examples/wrf/wrf2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py index add2ad43..ef1f7a82 100644 --- a/examples/wrf/wrf2.py +++ b/examples/wrf/wrf2.py @@ -435,7 +435,7 @@ def solve_tridiagonal_(self, @cube.compile(model=model, dataloader=varloader) def train_iter(model, dataloader): - U, V, W, O, Theta, phi1, mu1 = dataloader + U, V, W, O, Theta, phi1, mu1 = next(dataloader) U, V, W, O, Theta, phi1, mu1 = model(U, V, W, O, Theta, phi1, mu1) return U, V, W, O, Theta, phi1, mu1 model = model.get_gen_module() From 57f627f5ad0ff05420376a7d99221d583ffc648f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Jun 2022 12:23:37 +0800 Subject: [PATCH 0843/1892] fix bug: same operands issue on adding consumer and producer --- cube/graph/graph.py | 48 ++++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ea07f194..c93d9223 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -283,9 +283,19 @@ def detach(self, node: IRCell, reset_dependency=False) -> int: self._nodes.pop(index) if isinstance(node, IRAdapter): return index + # update consumer + itensors = [] for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor) and itensor not in itensors: + itensors.append(itensor) + for itensor in itensors: if isinstance(itensor, IRSubTensor): itensor.parent.rm_consumer(node) + # update producer + otensors = [] + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor) and otensor not in otensors: + otensors.append(otensor) for otensor in node.outputs(): if isinstance(otensor, IRSubTensor): otensor.parent.rm_producer(node) @@ -307,25 +317,31 @@ def attach(self, node: IRCell, index, reset_dependency=False): if isinstance(node, IRAdapter): return # update consumer + itensors = [] for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor): - idx = 0 - for consumer in itensor.parent.consumers: - if self.nodes().index(consumer) < index: - idx += 1 - else: - break - itensor.parent.add_consumer(node, itensor, idx) + if isinstance(itensor, IRSubTensor) and itensor not in itensors: + itensors.append(itensor) + for itensor in itensors: + idx = 0 + for consumer in itensor.parent.consumers: + if self.nodes().index(consumer) < index: + idx += 1 + else: + break + itensor.parent.add_consumer(node, itensor, idx) # update producer + otensors = [] for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor): - idx = 0 - for producer in otensor.parent.producers: - if self.nodes().index(producer) < index: - idx += 1 - else: - break - otensor.parent.add_producer(node, otensor, idx) + if isinstance(otensor, IRSubTensor) and otensor not in otensors: + otensors.append(otensor) + for otensor in otensors: + idx = 0 + for producer in otensor.parent.producers: + if self.nodes().index(producer) < index: + idx += 1 + else: + break + otensor.parent.add_producer(node, otensor, idx) if reset_dependency: self.reset_dependency() return From 0181a015771a7c3759d216c5e193048e031ba570 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Jun 2022 13:24:54 +0800 Subject: [PATCH 0844/1892] fix sum shape inference on dim reduction --- cube/graph/function/function.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index ab10aa9c..e5e2a18a 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -2,6 +2,7 @@ import string import copy import numpy +from torch import sign from cube.ir.cten import IRTensor from cube.graph.function.einops import EinDim, IREinops @@ -463,21 +464,26 @@ def LayerNorm(signature, inputs): def Sum(signature, inputs): - # TODO: support dim reduction - annos = [ - '*+ -> 1', - ] - tensor = inputs[0:1] + tensor = inputs[0] dim = inputs[1] + einput = _create_eshape(tensor.shape) + eoutput = copy.copy(einput) + if dim is not None: + keepdim = inputs[2] + sort_dim = list(dim) + sort_dim.sort() + for dimidx in sort_dim[::-1]: + eoutput.pop(dimidx) + einput[dimidx] = einput[dimidx] + '+' + else: + eoutput = ['1'] + # every dimension is reduced + einput = [edim + '+' for edim in einput] + anno = _create_anno([einput], [eoutput]) if dim is not None: - keepdim = inputs[2] if len(inputs) > 2 else False - dim_len = len(tensor[0].shape) - anno = "".join([f'b{i} ' for i in range(dim_len)]) + " -> " + "".join([f'b{i} ' if i not in dim else "" for i in range(dim_len)]) - annos.append(anno) - return IREinops(signature, annos, tensor, 'sum', - dim=dim, keepdim=keepdim) + return IREinops(signature, [anno], [tensor], 'sum', dim=dim, keepdim=keepdim) else: - return IREinops(signature, annos, tensor, 'sum') + return IREinops(signature, [anno], [tensor], 'sum') def Transpose(signature, inputs): From c780bd5893aeaf0eb32f35cbbde0a96136d065fe Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Jun 2022 13:26:11 +0800 Subject: [PATCH 0845/1892] clear needless import --- cube/graph/function/function.py | 1 - cube/graph/graph.py | 12 +++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index e5e2a18a..b6334b87 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -2,7 +2,6 @@ import string import copy import numpy -from torch import sign from cube.ir.cten import IRTensor from cube.graph.function.einops import EinDim, IREinops diff --git a/cube/graph/graph.py b/cube/graph/graph.py index c93d9223..5c326e13 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -284,21 +284,19 @@ def detach(self, node: IRCell, reset_dependency=False) -> int: if isinstance(node, IRAdapter): return index # update consumer - itensors = [] + itensors: List[IRSubTensor] = [] for itensor in node.inputs(): if isinstance(itensor, IRSubTensor) and itensor not in itensors: itensors.append(itensor) for itensor in itensors: - if isinstance(itensor, IRSubTensor): - itensor.parent.rm_consumer(node) + itensor.parent.rm_consumer(node) # update producer - otensors = [] + otensors: List[IRSubTensor] = [] for otensor in node.outputs(): if isinstance(otensor, IRSubTensor) and otensor not in otensors: otensors.append(otensor) - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor): - otensor.parent.rm_producer(node) + for otensor in otensors: + otensor.parent.rm_producer(node) if reset_dependency: self.reset_dependency() return index From c433647937e6eb365b14989e7c0ee27134b0349c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Jun 2022 14:55:00 +0800 Subject: [PATCH 0846/1892] support block on multiple device (tensor parallelism) --- cube/tetris/solver.py | 91 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 11 deletions(-) diff --git a/cube/tetris/solver.py b/cube/tetris/solver.py index 86007a67..11fdb70e 100644 --- a/cube/tetris/solver.py +++ b/cube/tetris/solver.py @@ -2,7 +2,7 @@ A solver based solution for scheduling plan """ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from enum import Enum from z3 import * @@ -10,6 +10,7 @@ import copy + gsolver = Solver() @@ -22,6 +23,7 @@ class BType(Enum): def __init__(self, mid: int, btype: BType, name: str, mem=1): global _uid global gsolver + self.name = name self.mid = mid self.step = Int(name) self.memory = mem if btype == Block.BType.FW else 0-mem @@ -91,11 +93,13 @@ def position(self, block: Block) -> Tuple[int, int]: step = self._solution[block.step] return (devid, step) - def add_block(self, block: Block, device: int): + def add_block(self, block: Block, devices: Tuple[int]): global gsolver - for blk in self._blocks[device]: - gsolver.add(blk.step != block.step) - self._blocks[device].append(block) + devices = (devices,) if isinstance(devices, int) else devices + for device in devices: + for blk in self._blocks[device]: + gsolver.add(blk.step != block.step) + self._blocks[device].append(block) # set plan step variable if self._nsteps is None: self._nsteps = block.step @@ -130,7 +134,9 @@ def solve(self): global gsolver tic = time.time() opt_step = max(len(blks) for blks in self._blocks) + max_step = self.nblocks while True: + assert opt_step <= max_step, "out of step boundary. consider this as a bug." gsolver.push() gsolver.add(self._nsteps == opt_step) if gsolver.check() == sat: @@ -170,17 +176,20 @@ def solve(self): gsolver.pop() toc = time.time() print('search memory time: {:.2f} seconds'.format(toc-tic)) - print('solution:\n', self) + print('solution:') + print(self) - # self.iter_space(opt_step) + self.iter_space(opt_step, opt_mem) - def iter_space(self, nsteps: int, memory: int): + def iter_space(self, nsteps: int, memory: int = None): """ iterate all solutions find by solver """ global gsolver gsolver.push() gsolver.add(self._nsteps == nsteps) + if memory is not None: + gsolver.add(self._mem == memory) models = [] while gsolver.check() == sat: model = gsolver.model() @@ -194,6 +203,7 @@ def iter_space(self, nsteps: int, memory: int): if len(models) % 10 == 0: print(f'find {len(models)} solutions..') gsolver.pop() + print(f'find {len(models)} possible models') def __repr__(self) -> str: @@ -235,10 +245,69 @@ def uniform_staging(ndevs: int, nmicros) -> SchedulePlan: sched.add_block(bblocks[ndevs-1-devid], devid) return sched + def chimera_staging(ndevs: int, nmicros: int) -> SchedulePlan: + """ + f b f b + f b f b + f b f b + f b f b + """ + sched = SchedulePlan(ndevs) + assert nmicros % 2 == 0, "require microbatch# can be devided by 2" + for mid in range(nmicros // 2): # V shape + fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}', mem=devid+1) for devid in range(ndevs)] + bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}', mem=devid+1) for devid in range(ndevs-1,-1,-1)] + blocks = fblocks + bblocks + for idx in range(ndevs * 2 - 1): + Block.add_dependency(blocks[idx], blocks[idx+1]) + for devid in range(ndevs): + sched.add_block(fblocks[devid], devid) + sched.add_block(bblocks[ndevs-1-devid], devid) + for mid in range(nmicros // 2): # ^ shape + mid = mid + nmicros // 2 + fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}', mem=ndevs-devid) for devid in range(ndevs-1,-1,-1)] + bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}', mem=ndevs-devid) for devid in range(ndevs)] + blocks = fblocks + bblocks + for idx in range(ndevs * 2 - 1): + Block.add_dependency(blocks[idx], blocks[idx+1]) + for devid in range(ndevs): + sched.add_block(fblocks[ndevs-1-devid], devid) + sched.add_block(bblocks[devid], devid) + return sched + + def mbart_staging(ndevs: int, nmicros: int) -> SchedulePlan: + """ + f f f b b b + f f f b b b + f f f b b b + f f f b b b + """ + sched = SchedulePlan(ndevs) + for mid in range(nmicros): + fblocks = [] + bblocks = [] + for step in range(ndevs+2): + if step in [0, ndevs // 2 + 1]: + fdevid = bdevid = tuple(range(ndevs)) + fblock = Block(mid, Block.BType.FW, f'fe{step}{mid}devall', mem=4) + bblock = Block(mid, Block.BType.BW, f'be{step}{mid}devall', mem=4) + else: + fdevid = bdevid = step - 1 if step < ndevs // 2 + 1 else step - 2 + fblock = Block(mid, Block.BType.FW, f'f{mid}dev{fdevid}', mem=1) + bblock = Block(mid, Block.BType.BW, f'b{mid}dev{bdevid}', mem=1) + fblocks.append(fblock) + bblocks.append(bblock) + sched.add_block(fblock, fdevid) + sched.add_block(bblock, bdevid) + blocks = fblocks + bblocks[::-1] + for idx in range((ndevs + 2) * 2 - 1): + Block.add_dependency(blocks[idx], blocks[idx+1]) + return sched + ndevs = 4 nmicros = 4 - sched = uniform_staging(ndevs, nmicros) + # sched = uniform_staging(ndevs, nmicros) + # sched = chimera_staging(ndevs, nmicros) + sched = mbart_staging(ndevs, nmicros) # ndev=4, nmicro=4 => solution: step=32 sched.solve() - - From 053a5b8647649af7cdb82d64bebfa9658d8d75d2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 8 Jun 2022 09:35:18 +0800 Subject: [PATCH 0847/1892] searching with decreasing step and memory --- cube/tetris/solver.py | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/cube/tetris/solver.py b/cube/tetris/solver.py index 11fdb70e..520c0d23 100644 --- a/cube/tetris/solver.py +++ b/cube/tetris/solver.py @@ -130,37 +130,44 @@ def set_memory(self): def set_solution(self, solution: z3.z3.ModelRef): self._solution = solution - def solve(self): + def solve(self, decrease = True): global gsolver tic = time.time() - opt_step = max(len(blks) for blks in self._blocks) + min_step = max(len(blks) for blks in self._blocks) max_step = self.nblocks + opt_step = max_step if decrease else min_step while True: - assert opt_step <= max_step, "out of step boundary. consider this as a bug." + assert min_step <= opt_step and opt_step <= max_step, "out of step boundary. consider this as a bug." gsolver.push() gsolver.add(self._nsteps == opt_step) if gsolver.check() == sat: - print(f'find optimal step in {opt_step} steps') + print(f'find scheduling plan in {opt_step} steps') solution = gsolver.model() self.set_solution(solution) gsolver.pop() - break + if not decrease: break else: print(f'fail to find solution for {opt_step} steps') gsolver.pop() - opt_step += 1 + if decrease: + opt_step += 1 + break + opt_step = opt_step - 1 if decrease else opt_step + 1 toc = time.time() - print('search time: {:.2f} seconds'.format(toc-tic)) + print('search time: {:.2f} seconds. find optimal step: {}'.format(toc-tic, opt_step)) print('solution:') print(self) # search for optimal memory tic = time.time() - opt_mem = 1 + min_mem = max(min(blk.memory for blk in blks if blk.btype == Block.BType.FW) for blks in self._blocks) + max_mem = max(sum(blk.memory for blk in blks if blk.btype == Block.BType.FW) for blks in self._blocks) + opt_mem = max_mem if decrease else min_mem self.set_memory() gsolver.push() gsolver.add(self._nsteps == opt_step) while True: + assert min_mem <= opt_mem and opt_mem <= max_mem, "out of memory boundary. consider this as a bug" gsolver.push() gsolver.add(self._mem == opt_mem) if gsolver.check() == sat: @@ -168,18 +175,21 @@ def solve(self): solution = gsolver.model() self.set_solution(solution) gsolver.pop() - break + if not decrease: break else: print(f'fail to find solution for memory {opt_mem}') gsolver.pop() - opt_mem += 1 + if decrease: + opt_mem += 1 + break + opt_mem = opt_mem - 1 if decrease else opt_mem + 1 gsolver.pop() toc = time.time() - print('search memory time: {:.2f} seconds'.format(toc-tic)) + print('search memory time: {:.2f} seconds. opt-memory: {}'.format(toc-tic, opt_mem)) print('solution:') print(self) - self.iter_space(opt_step, opt_mem) + # self.iter_space(opt_step, opt_mem) def iter_space(self, nsteps: int, memory: int = None): """ @@ -307,7 +317,7 @@ def mbart_staging(ndevs: int, nmicros: int) -> SchedulePlan: ndevs = 4 nmicros = 4 - # sched = uniform_staging(ndevs, nmicros) + sched = uniform_staging(ndevs, nmicros) # sched = chimera_staging(ndevs, nmicros) - sched = mbart_staging(ndevs, nmicros) # ndev=4, nmicro=4 => solution: step=32 + # sched = mbart_staging(ndevs, nmicros) # ndev=4, nmicro=4 => solution: step=32 sched.solve() From 1c05765880c9efa41bfc5955ecf17bbf41f38410 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 8 Jun 2022 14:56:05 +0800 Subject: [PATCH 0848/1892] decreasing search algorithm --- cube/tetris/composer.py | 30 ++++++++++++++---------------- cube/tetris/solver.py | 18 +++++++++++------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/cube/tetris/composer.py b/cube/tetris/composer.py index acbfea75..f1cbaa11 100644 --- a/cube/tetris/composer.py +++ b/cube/tetris/composer.py @@ -339,10 +339,10 @@ def premise(fn, ndevs: int, nmicros: int): return micros @staticmethod - def bfs_schedule(micros: List[MicroPlan], mem_opt=True): + def bfs_schedule(micros: List[MicroPlan], mem_opt=True, prune_symmetric=True): total_status = 1 micros.sort(key=lambda m: m.mid) - block_hash = Composer.construct_hash(micros) # False # Composer.same_plans(micros, start_step=0) + block_hash = Composer.construct_hash(micros) if prune_symmetric else None step = 0 opt_step = sum(micro.nsteps for micro in micros) # initial prev: List[List[MicroPlan]] = [micros] @@ -354,7 +354,7 @@ def bfs_schedule(micros: List[MicroPlan], mem_opt=True): # get and solve conflicts conflicts = SchedulePlan.conflict(micros, step) # input(f'conflicts: dev: {list(conflicts.keys())}, mids: {[[conf[0].mid for conf in c] for c in conflicts.values()]} | >>>') - for shifts in Composer.iter_shifts(conflicts, step, prune_same_micro=True, block_hash=block_hash): + for shifts in Composer.iter_shifts(conflicts, block_hash=block_hash): # print(f"step {step}: {shifts}") shifted_micros = [micro.copy() for micro in micros] for cblock in shifts: @@ -440,8 +440,6 @@ def to_same_step(micros: List[MicroPlan]): @staticmethod def iter_shifts(conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], - step: int, - prune_same_micro = True, block_hash = Union[None, Callable]) -> List[Block]: """ Enumerate shifted blocks to resolve conflicts on step `step`. @@ -475,9 +473,9 @@ def iter_shifts(conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], else: candidates = cblocks - if prune_same_micro: - if Composer.same_plans(cmicros, start_step=step): - candidates = [candidates[0]] + # if prune_same_micro: + # if Composer.same_plans(cmicros, start_step=step): + # candidates = [candidates[0]] for kblock in candidates: idx = cblocks.index(kblock) @@ -631,8 +629,8 @@ def compose_1F1B(ndevs, nmicros): def search(ndevs, nmicros, visualize=False): # premise # micros = Composer.premise(uniform_staging, ndevs, nmicros) - # micros = Composer.premise(chimera_staging, ndevs, nmicros) - micros = Composer.premise(uniform_staging, ndevs, nmicros) + micros = Composer.premise(chimera_staging, ndevs, nmicros) + # micros = Composer.premise(mbart_staging, ndevs, nmicros) print('============== Premise ================') for idx, micro in enumerate(micros): print(f'microbatch #{idx}:') @@ -643,7 +641,7 @@ def search(ndevs, nmicros, visualize=False): # search shift tic = time.time() - schedules = Composer.bfs_schedule(micros, mem_opt=True) + schedules = Composer.bfs_schedule(micros, mem_opt=True, prune_symmetric=True) toc = time.time() print('search done. time {:.2f}s'.format(toc - tic)) @@ -652,11 +650,11 @@ def search(ndevs, nmicros, visualize=False): assert len(steps) == 1, f"got un-consistent step set: {steps}" nsteps = list(steps)[0] print(f'find {len(schedules)} step-optimal plans (step={nsteps})') - for idx, schedule in enumerate(schedules): - print(f'Schedule #{idx+1}:') - print(schedule) - if visualize: - schedule.visualize(f'planlog/plan{idx+1}.png') + # for idx, schedule in enumerate(schedules): + # print(f'Schedule #{idx+1}:') + # print(schedule) + # if visualize: + # schedule.visualize(f'planlog/plan{idx+1}.png') ndevs = 4 diff --git a/cube/tetris/solver.py b/cube/tetris/solver.py index 520c0d23..58a3633c 100644 --- a/cube/tetris/solver.py +++ b/cube/tetris/solver.py @@ -2,7 +2,7 @@ A solver based solution for scheduling plan """ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple from enum import Enum from z3 import * @@ -171,7 +171,7 @@ def solve(self, decrease = True): gsolver.push() gsolver.add(self._mem == opt_mem) if gsolver.check() == sat: - print(f'find optimal memory {opt_mem}') + print(f'find scheduling plan in {opt_mem} memory') solution = gsolver.model() self.set_solution(solution) gsolver.pop() @@ -189,7 +189,11 @@ def solve(self, decrease = True): print('solution:') print(self) - # self.iter_space(opt_step, opt_mem) + tic = time.time() + self.iter_space(opt_step, opt_mem) + toc = time.time() + print('iterate all plans: {:.2f} seconds.'.format(toc-tic)) + def iter_space(self, nsteps: int, memory: int = None): """ @@ -210,7 +214,7 @@ def iter_space(self, nsteps: int, memory: int = None): c = d() block.append(c != model[d]) gsolver.add(Or(block)) - if len(models) % 10 == 0: + if len(models) % 100 == 0: print(f'find {len(models)} solutions..') gsolver.pop() print(f'find {len(models)} possible models') @@ -317,7 +321,7 @@ def mbart_staging(ndevs: int, nmicros: int) -> SchedulePlan: ndevs = 4 nmicros = 4 - sched = uniform_staging(ndevs, nmicros) + # sched = uniform_staging(ndevs, nmicros) # sched = chimera_staging(ndevs, nmicros) - # sched = mbart_staging(ndevs, nmicros) # ndev=4, nmicro=4 => solution: step=32 - sched.solve() + sched = mbart_staging(ndevs, nmicros) # ndev=4, nmicro=4 => solution: step=32 + sched.solve(decrease=True) From d055a1fd9232b60eb73005d94443492f40733ae7 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 9 Jun 2022 08:36:38 +0000 Subject: [PATCH 0849/1892] Merged PR 1382: Fake WRF loop upper bounds Fake WRF loop upper bounds by introducing explicit variables, which also ease the control the magnitude of the unrolled graph. However, due to the complicity of the WRF model itself, even to reduce the faked UBs to 1, 1, 1, there will still be ~4200 nodes. (in contrast, the original UBs result in ~23000 nodes). --- examples/wrf/wrf2.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py index ef1f7a82..a57fa706 100644 --- a/examples/wrf/wrf2.py +++ b/examples/wrf/wrf2.py @@ -31,6 +31,14 @@ def __init__(self, dt, ntau, nz, ny, nx, dz, dy, dx, device): self.device = torch.device(device) + # TODO remove these testing parameters + # These three are to control the size of the unrolled graph, and they are related to the three layers of the nested loops, respectively. + # The magnitude is almost decided by `ntau` only. + self._step_fake_ntau = 1 + self._ac_step_fake_ub = 2 + self._solver_fake_ub = 2 + + def init(self, theta, Ptop=250e2): eta = torch.linspace(0, 1, self.nz + 1, device=self.device) pi = self.PREF - Ptop @@ -128,6 +136,9 @@ def step(self, dtau:float, ntau:int, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W alpha = - self.dz(self.pzphi(phi)) / mu p = self.PREF * (self.RD * Theta / mu / self.PREF / alpha)**self.GAMMA + # TODO fake upper bound + ntau = self._step_fake_ntau + for i in range(ntau): U2, V2, W2, O2, Theta2, phi2, mu2, pi2 = \ self.ac_step(dtau, @@ -175,13 +186,17 @@ def ac_step(self, dtau:float, O2_ = torch.zeros(O2.shape, device=O2.device) mu2_ = torch.zeros(mu2.shape, device=mu2.device) - for i in range(1, O2.shape[0] + 1): + # TODO fake upper bound + #for i in range(1, O2.shape[0] + 1): + for i in range(1, self._ac_step_fake_ub): sub = i * self.delta_z * dpi2 + \ (self.dx(self.px(U2_)) + self.dy(self.py(V2_)) - R_mu)[-i:].view( -1, self.ny, self.nx).sum(0) * self.delta_z O2_ = O2_.select_scatter(sub, dim=0, index=-i) - for i in range(mu2.shape[0]): + # TODO fake upper bound + #for i in range(mu2.shape[0]): + for i in range(1, self._ac_step_fake_ub): mu2_ = mu2_.select_scatter(pi2, dim=0, index=i) # self.O2_ = O2_ @@ -384,7 +399,10 @@ def solve_tridiagonal_(self, u = (torch.stack([r1, r2, r0], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1)[:-1] # forward sweep - for i in range(1, d.shape[0]): + + # TODO fake upper bound + #for i in range(1, d.shape[0]): + for i in range(1, self._solver_fake_ub): w = l[i - 1] / d[i - 1] d_i = d[i] - w * u[i - 1] @@ -396,7 +414,10 @@ def solve_tridiagonal_(self, # backward substitution x = torch.zeros(b.shape, device=b.device) x.select_scatter(b[-1] / d[-1], dim=0, index=-1) - for i in range(x.shape[0] - 2, -1, -1): + + # TODO fake upper bound + #for i in range(x.shape[0] - 2, -1, -1): + for i in range(1, self._solver_fake_ub): x.select_scatter( (b[i] - u[i] * x[i + 1]) / d[i], dim=0, index=i) return x From 443dd5fb900e848271ce877f2263ae6be097de2e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 10 Jun 2022 13:40:06 +0800 Subject: [PATCH 0850/1892] fix IRSegment bug on dupilcate inputs/outputs; fix IRAdapter gener on single device grid layout; fix tensor view signature remove compiler draw execution plan --- README.md | 2 +- cube/compiler.py | 4 ++-- cube/graph/function/function.py | 1 + cube/graph/gener/gen.py | 2 +- cube/graph/graph.py | 6 ++++-- cube/graph/parser/parser.py | 5 +++++ cube/runtime/__init__.py | 1 + 7 files changed, 15 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index fc7a2347..9566acb9 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ AI System Compiler to map a semantic (single-device) model into distributed exec ## Prerequisite * Python >= 3.7 -* PyTorch >= 1.9 +* PyTorch >= 1.11 ## Install diff --git a/cube/compiler.py b/cube/compiler.py index c24ead97..349fea68 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -187,8 +187,8 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on grouping operations: {:.2f} s'.format(span)) - execplan.graph.reset_dependency() - execplan.analyze(outfile='execplan.png') + # execplan.graph.reset_dependency() + # execplan.analyze(outfile='execplan.png') if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index b6334b87..b816db4d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -612,6 +612,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s bracket[subdim] = str(shape_map[edim]) # bracket[subdim] = edim + '^' anno = _create_anno([in_anno], [ou_anno]) + signature = 'torch.Tensor.view' return IREinops(signature, [anno], [input], 'view', shape=tuple(shape)) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index f6db5785..db455293 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -129,7 +129,7 @@ def gen_fulltensor(ftensor: IRFullTensor) -> List[IRAdapter]: cdevs.update(cnode.device) # sharing devices - if pdevs == cdevs: + if pdevs == cdevs and len(pdevs) > 1: return IRAdapterGener.gen_gridlayout(ftensor) # no-sharing devices diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5c326e13..24690a61 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -240,14 +240,16 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: producers = [p for p in itensor.parent.producers if set(p.device).issubset(set(node.device))] # no producer means a weight or cross device-group if len(producers) == 0 or any(p not in nodes for p in producers): - inputs.append(itensor) + if itensor not in inputs: + inputs.append(itensor) # update outputs otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] for otensor in otensors: consumers = [c for c in otensor.parent.consumers if set(c.device).issubset(set(node.device))] # no consumer usually means the loss or cross device-group if len(consumers) == 0 or any(c not in nodes for c in consumers): - outputs.append(otensor) + if otensor not in outputs: + outputs.append(otensor) segment = IRSegment(nodes, inputs, outputs) return segment diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 79a215d9..7ca249ca 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -8,6 +8,8 @@ from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import Sign2Op, DType2IRDType +import warnings + _refmodule = torch.nn.Module() @@ -391,6 +393,9 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: ) if isinstance(tensor, torch.nn.Parameter): ir_tensor.as_param() + else: + warnings.warn('Detected non-parameter tensor as graph attribute. Regard them as parameters') + ir_tensor.as_param() frame.add_var(var_name, ir_tensor) # symbolic attributes elif dtype in ['bool', 'int', 'float']: diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index 1929a791..e883bb2c 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -5,3 +5,4 @@ from cube.runtime import resource from cube.runtime import module from cube.runtime import function +from cube.runtime import schedule From e59f3f58da8d896eef416e33276c6c131eced224 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 10 Jun 2022 13:54:06 +0800 Subject: [PATCH 0851/1892] naive PAS --- examples/wrf/policy/naive.py | 14 +------------- examples/wrf/wrf2.py | 3 ++- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/examples/wrf/policy/naive.py b/examples/wrf/policy/naive.py index d58e2d5e..9d1b7d97 100644 --- a/examples/wrf/policy/naive.py +++ b/examples/wrf/policy/naive.py @@ -3,18 +3,6 @@ def PAS(graph: IRGraph, resource): for node in graph.nodes(): - if isinstance(node, IRConv2D): - sub_nodes = list() - algo = node.algorithms('halo') - Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus // 2)) - for Wnode in Wnodes: - algo = Wnode.algorithms('halo') - Hnodes = graph.partition(Wnode, algo, config=dict(idx=0, dim=2, num=2)) - sub_nodes += Hnodes - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - # sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) + graph.assign(node, 0) print(graph.extra_repr()) return graph diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py index a57fa706..e79ff9ff 100644 --- a/examples/wrf/wrf2.py +++ b/examples/wrf/wrf2.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from cube.runtime.syndata import SciLoopVariables +from examples.wrf.policy.naive import PAS torch.set_default_tensor_type(torch.DoubleTensor) @@ -454,7 +455,7 @@ def solve_tridiagonal_(self, varloader = SciLoopVariables(variables=[U, V, W, O, Theta, phi1, mu1], constants=[]) model = cube.SemanticModel(wrf, input_shapes=tuple(varloader.shapes)) - @cube.compile(model=model, dataloader=varloader) + @cube.compile(model=model, dataloader=varloader, PAS=PAS) def train_iter(model, dataloader): U, V, W, O, Theta, phi1, mu1 = next(dataloader) U, V, W, O, Theta, phi1, mu1 = model(U, V, W, O, Theta, phi1, mu1) From 3ea3aa39a40ce3987c4b778ebbc93534b2b106a2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 10 Jun 2022 19:44:40 +0800 Subject: [PATCH 0852/1892] fix adapter gener bug; fix refcount bug for graph output --- cube/codegen/codegen.py | 18 +++++++++------ cube/graph/function/function.py | 7 ++++++ cube/graph/gener/gen.py | 8 +++++-- cube/graph/graph.py | 2 ++ cube/runtime/function/function.py | 7 ++++++ examples/poisson/sci.py | 37 ------------------------------- 6 files changed, 33 insertions(+), 46 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 1c052357..244fce3f 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -347,24 +347,28 @@ def gen(self, device: int, outfile=None, attach=False) -> str: device_nodes = self.execplan.seq(device) - def refcount(tensor, node) -> int: + def later_ref(tensor, node) -> bool: + """ + check whether the output tensor of the node need to be later used. + """ idx = device_nodes.index(node) - refcnt = 0 + if tensor in self.execplan.graph.outputs(): + return True for ref_node in device_nodes[idx+1:]: if isinstance(ref_node, IRSegment): if ref_node.forward: if tensor in ref_node.inputs(): - refcnt += 1 + return True else: finputs = ref_node.mirror.inputs() foutputs = ref_node.mirror.outputs() grad_in = [t.grad for t in foutputs] if tensor in finputs + foutputs + grad_in: - refcnt += 1 + return True else: if tensor in ref_node.inputs(): - refcnt += 1 - return refcnt + return True + return False with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: @@ -383,7 +387,7 @@ def refcount(tensor, node) -> int: # free unused tensor for tensor in node.inputs() + node.outputs(): if isinstance(tensor, IRSubTensor) and not tensor.is_param(): - refcnt = refcount(tensor, node) + refcnt = later_ref(tensor, node) if refcnt == 0: self.vars.free(tensor) # return code diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index b816db4d..6c3d2230 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -65,6 +65,13 @@ def _create_anno(ins: List[List[Union[str, List[str]]]], return ', '.join(in_annos) + ' -> ' + ', '.join(ou_annos) +def Identity(signature, inputs): + signature = 'cube.runtime.function.identity' + eshape = _create_eshape(inputs[0].shape) + anno = _create_anno([eshape], [eshape]) + return IREinops(signature, [anno], inputs, 'identity') + + def Linear(signature, inputs): if signature == 'torch.linear': import warnings diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index db455293..7a0fc65b 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -40,6 +40,7 @@ def gen(graph: IRGraph) -> IRGraph: graph.attach(node, idx) graph = IRAdapterGener.gen_activation(graph) graph = IRAdapterGener.gen_weight(graph) + # TODO: generate adapter for graph outputs return graph @staticmethod @@ -129,7 +130,10 @@ def gen_fulltensor(ftensor: IRFullTensor) -> List[IRAdapter]: cdevs.update(cnode.device) # sharing devices - if pdevs == cdevs and len(pdevs) > 1: + if pdevs == cdevs and len(pdevs) > 1 and \ + len(pdevs) == len(ftensor.producers) and \ + len(cdevs) == len(ftensor.consumers): + # TODO: enable tensor fusion of tensors on same device return IRAdapterGener.gen_gridlayout(ftensor) # no-sharing devices @@ -205,7 +209,7 @@ def gen_gridlayout(ftensor: IRFullTensor) -> List[IRAdapter]: # generate backward grad: IRFullTensor = ftensor.grad bprims = [] - if grad is not None: + if grad is not None and (len(grad.ptensors) != 0 or len(grad.ctensors) != 0): # reorder ptensors to match with forward ptensors = [None] * len(devs) for ptensor in grad.ptensors: diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 24690a61..8977f4c0 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -245,6 +245,8 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: # update outputs otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] for otensor in otensors: + if otensor in self.outputs(): + outputs.append(otensor) consumers = [c for c in otensor.parent.consumers if set(c.device).issubset(set(node.device))] # no consumer usually means the loss or cross device-group if len(consumers) == 0 or any(c not in nodes for c in consumers): diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index c0d27aaa..55bc5c34 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -3,6 +3,13 @@ import torch.nn.functional as TorchF +def identity(tensor: torch.Tensor) -> torch.Tensor: + """ + identity forward + """ + return tensor + + def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], stride: int, padding: List[int], dilation, groups: int = 1): """ diff --git a/examples/poisson/sci.py b/examples/poisson/sci.py index 7fd0ec39..47e7b900 100644 --- a/examples/poisson/sci.py +++ b/examples/poisson/sci.py @@ -36,43 +36,6 @@ def forward(self, r0: torch.Tensor, p: torch.Tensor, phi: torch.Tensor, return r1, p, phi, r1_sum -class LoopVariables(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, variables: List[torch.Tensor], constants: List[torch.Tensor]): - - shapes = [list(var.size()) for var in variables + constants] - dtypes = [var.dtype for var in variables + constants] - batch_dims = [0] * (len(variables) + len(constants)) - super().__init__(shapes, dtypes, batch_dims) - self.variables = list() - self.constants = list() - for var in variables: - if torch.is_tensor(var) and var.device != torch.cuda.current_device(): - var = var.cuda() - self.variables.append(var) - for const in constants: - if torch.is_tensor(const) and const.device != torch.cuda.current_device(): - const = const.cuda() - self.constants.append(const) - - def __iter__(self): - return self - - def update(self, variables: List[torch.Tensor] = None, constants: List[torch.Tensor] = None): - if variables is not None: - self.variables = variables - if constants is not None: - self.constants = constants - - def reset(self, batch_size): - pass - - def __next__(self): - if len(self.variables) + len(self.constants) == 1: - return (self.variables + self.constants)[0] - return tuple(self.variables + self.constants) - - def train_loop(): # initialize N = 1024 * 2 From bdc37fde0fa55ff1b36c831da0d2584aee9814f8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 13 Jun 2022 13:16:52 +0800 Subject: [PATCH 0853/1892] fix graph attribute multi-reference bug --- cube/graph/parser/frame.py | 60 +++++++++++++++++++++++++++++-------- cube/graph/parser/parser.py | 40 ++++++++++++++----------- 2 files changed, 71 insertions(+), 29 deletions(-) diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index 19f95ea4..9350ef4b 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -11,10 +11,13 @@ def __init__(self): # var name -> value (IRTesnor, deterministic) self._vars: List[dict[str, Any]] = list() self._var_stack: List[str] = list() + # module attributes + self._attributes: List[dict[str, Any]] = list() - def push(self, inherit_from_top=False): + def push_var(self, inherit_from_top=False): """ - This should only be called when step in a module + Push a new variable frame as current variable frame. + This should only be called when stepping in a module or method. Args: inherit_from_top (bool): @@ -28,9 +31,10 @@ def push(self, inherit_from_top=False): else: self._vars.append(OrderedDict()) - def pop(self): + def pop_var(self): """ - This should only be called step out a module + Pop the current variable frame. + This should only be called when steping out a module or method. """ if len(self._vars) == 0: raise RuntimeError("Try to pop stack with 0 depth") @@ -44,14 +48,10 @@ def add_var(self, var_name: str, val: Any, graph_arg: int = -1): var_name (str): variable name (unique) val: variable content graph_arg (int): - indicate whether it is an argument of the graph. - - If == -1, is not a graph arg. - - If >= 0, is a graph arg, will try to find val from previous frame, - by associating the names of the formal parameters of the callee function - and the names of the arguments passed-in. - (then look up the values of the arguments in the previous frame) + indicate whether it is an argument of the graph. -1 indicates not an argument. + If >= 0, is a graph arg, will try to find val from variable stack, + and link the name of the argument name from the callee function + to the names of the argument passed-in. """ if not isinstance(var_name, str): raise RuntimeError("Expected var_name is str") @@ -90,6 +90,42 @@ def get_var(self, var_name: str) -> Any: return self._vars[-1][var_name] raise KeyError(f"Cannot find var name {var_name}") + def push_attr(self): + """ + Push a new module attribut frame as current frame. + This should only be called when stepping in the graph. + """ + self._attributes.append(OrderedDict()) + + def pop_attr(self): + """ + Pop the current module attribute frame. + This should only be called when stepping out the graph. + """ + self._attributes.pop() + + def add_attr(self, name: str, val: Any): + """ + Add module attribute + """ + if name in self._attributes[-1]: + raise KeyError("Try to add an already existed attributed") + self._attributes[-1][name] = val + + def get_attr(self, name: str) -> Any: + """ + Get module attribute by name + """ + if name not in self._attributes[-1]: + raise KeyError(f"Cannot find var name {name}") + return self._attributes[-1][name] + + def has_attr(self, name: str) -> bool: + """ + Return if `name` exists in current attributes + """ + return name in self._attributes[-1] + def push_param(self, var_name): """ push var name to the method stack diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 7ca249ca..87991aa9 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -43,7 +43,8 @@ def parse_module(module, The overall entry to parse a torchscript graph module """ frame = frame if frame is not None else Frame() - frame.push() + frame.push_var() + frame.push_attr() inputs = list(module.graph.inputs())[1:] if input_shapes is not None and len(input_shapes) != len(inputs): @@ -92,7 +93,8 @@ def parse_module(module, outputs.append(val) output_val = outputs - frame.pop() + frame.pop_var() + frame.pop_attr() return input_val, all_ir_nodes, output_val @staticmethod @@ -100,7 +102,7 @@ def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): """ Parse module method """ - frame.push() + frame.push_var() input_var_name = [input.debugName() for input in method.graph.inputs()] kDefaultType = DType2IRDType.map(torch.get_default_dtype()) @@ -132,7 +134,7 @@ def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): output_var_name = [output.debugName() for output in method.graph.outputs()] output_val = [frame.get_var(var_name) for var_name in output_var_name] - frame.pop() + frame.pop_var() return input_val, all_ir_nodes, output_val @staticmethod @@ -357,7 +359,7 @@ def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: return ir_nodes @staticmethod - def parse_prim_attr_node(node, module, frame) -> List[None]: + def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: """ Parse script module node like: %2 :__torch__.torch.nn.modules.linear.___torch_mangle_0.Linear = prim::GetAttr[name="linear1"](%self) @@ -386,16 +388,20 @@ def parse_prim_attr_node(node, module, frame) -> List[None]: if dtype == 'Tensor': tensor = getattr(module, label) shape = list(tensor.shape) - ir_tensor = IRFullTensor( - name=label, shape=shape, - requires_grad=tensor.requires_grad, - dtype=DType2IRDType.map(tensor.dtype) - ) - if isinstance(tensor, torch.nn.Parameter): - ir_tensor.as_param() + if frame.has_attr(label): + ir_tensor = frame.get_attr(label) else: - warnings.warn('Detected non-parameter tensor as graph attribute. Regard them as parameters') - ir_tensor.as_param() + ir_tensor = IRFullTensor( + name=label, shape=shape, + requires_grad=tensor.requires_grad, + dtype=DType2IRDType.map(tensor.dtype) + ) + if isinstance(tensor, torch.nn.Parameter): + ir_tensor.as_param() + else: + warnings.warn('Detected non-parameter tensor as graph attribute. Regard them as parameters') + ir_tensor.as_param() + frame.add_attr(label, ir_tensor) frame.add_var(var_name, ir_tensor) # symbolic attributes elif dtype in ['bool', 'int', 'float']: @@ -457,7 +463,7 @@ def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: raise NotImplementedError("Dynamic Graph is not supported yet") @staticmethod - def parse_prim_loop_node(node, module, frame) -> List[IRFwOperation]: + def parse_prim_loop_node(node, module, frame: Frame) -> List[IRFwOperation]: """ Inputs: %max_iter_count : int @@ -522,7 +528,7 @@ def parse_prim_loop_node(node, module, frame) -> List[IRFwOperation]: # Defensively we don't let variables defined in the Loop body subgraph pollute the outer graph. # So we'd better duplicate all existing variables into a new frame (namely 'inherit_from_top'), # and clean up this new frame after the interpretation of the whole loop execution. - frame.push(inherit_from_top=True) + frame.push_var(inherit_from_top=True) frame.add_var(iter_step_var.debugName(), step) @@ -563,7 +569,7 @@ def parse_prim_loop_node(node, module, frame) -> List[IRFwOperation]: loop_carried_vals = step_result_vals[1:] step += 1 - frame.pop() + frame.pop_var() if not isinstance(condition, bool): raise RuntimeError(f"At the {step}-th step the condition is not evaluated to a constant bool") From d5b9d5b77fd60c660b8e495818e95b7aa3d842d5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Jun 2022 10:25:10 +0800 Subject: [PATCH 0854/1892] fix graph attribute method call bug --- cube/graph/parser/parser.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 87991aa9..72892f38 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -340,14 +340,19 @@ def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: frame.push_param(var_name) # recursively parse the module - if node.inputsAt(0).debugName() == 'self': + self_module = node.inputsAt(0).debugName() == 'self' + if self_module: call_module = module else: call_module = getattr(module, node.inputsAt(0).debugName()) + frame.push_attr() call_method = getattr(call_module, label) _, ir_nodes, outputs_val = ScriptModuleParser.parse_module_method(call_module, call_method, frame=frame) + if not self_module: + frame.pop_attr() + # pop out the frame frame.pop_param(times=len(inputs)-1) From eab99466abde812efe840b29645490d59c53e304 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Jun 2022 10:32:32 +0800 Subject: [PATCH 0855/1892] fix require grad bugs --- cube/codegen/codegen.py | 7 +++-- cube/ir/cten.py | 35 +++++---------------- cube/ir/tensor.py | 57 ++++++++++++++++++++++++++-------- cube/logics/model.py | 65 +++++++++++++++------------------------ cube/logics/translator.py | 27 ++++++++-------- cube/runtime/executor.py | 8 +++-- 6 files changed, 99 insertions(+), 100 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 244fce3f..4d69880f 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -419,18 +419,19 @@ def emit_node(self, node: IRCell, name: str) -> str: """ Emit node / subgraph code """ - fsign = 'cube.runtime.executor.fexecute({model}, *{inputs})' + fsign = 'cube.runtime.executor.fexecute({model}, *{inputs}, requires_grad={req_grad})' bsign = 'cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' inputs = [self.tensor_naming(t) for t in node.inputs() if not t.is_param()] outputs = [self.tensor_naming(t) for t in node.outputs()] + req_grad = any(t.requires_grad is not None for t in outputs if isinstance(t, IRTensor)) inputs = self.tuple_naming(inputs) outputs = self.return_naming(outputs) if isinstance(node, IRSegment): # emit forward if node.forward: - body = fsign.format(model=f'model.{name}', inputs=inputs) + body = fsign.format(model=f'model.{name}', inputs=inputs, req_grad=req_grad) code = f'{outputs} = {body}' # emit backward else: @@ -459,7 +460,7 @@ def emit_node(self, node: IRCell, name: str) -> str: code = f'{outputs} = next(dataloader)' elif isinstance(node, IRAdapter): - body = fsign.format(model=f'model.{name}', inputs=inputs) + body = fsign.format(model=f'model.{name}', inputs=inputs, req_grad=req_grad) code = f'{outputs} = {body}' elif isinstance(node, IRWeightReducer): diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 0d1781d0..1c5572e4 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -14,7 +14,7 @@ """ -from typing import List, Union, Optional, Any +from typing import List, Tuple, Union, Optional, Any import copy from cube.ir.unique import IDGenerator @@ -406,7 +406,7 @@ class IRTensor: def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() - self._shape: Optional(List[int]) = shape + self._shape: Optional(List[int]) = () if shape is None else tuple(shape) self.name = name if name else 'tensor' # device @@ -472,18 +472,6 @@ def device(self, val: Union[int, List[int]]): "tensor placement is not allowed to set manually" ) - @property - def requires_grad(self): - return self._requires_grad - - @requires_grad.setter - def requires_grad(self, requires: bool): - if not isinstance(requires, bool): - raise TypeError("Expected bool") - self._requires_grad = requires - if not requires: - self.grad = None - def as_param(self): """ Set the tensor as trainable parameter @@ -564,23 +552,16 @@ def __eq__(self, tensor): return self._id == tensor._id @property - def shape(self): - if self._shape is None: - return [] - return copy.copy(self._shape) + def shape(self) -> Tuple[int]: + return list(self._shape) @shape.setter - def shape(self, val): - if self._shape is not None and self._shape != val: - raise RuntimeError("Try to change shape") - if not isinstance(val, list) or \ - not all([isinstance(size, int) for size in val]): - raise RuntimeError("Expected shape to be list[int]") - self._shape = copy.copy(list(val)) + def shape(self, val: Tuple[int]): + self._shape = tuple(val) if self.grad is not None: - self.grad.shape = copy.copy(list(val)) + self.grad.shape = tuple(val) - def nele(self) -> int: + def nelement(self) -> int: """ Get total number of element in the tensor. """ diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index f8bbb29b..c60aaec4 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -382,9 +382,6 @@ def __init__(self, shape=None, name=None, requires_grad=True, dtype=irdtype.floa self._segments : List[IRSubTensor] = list() self.requires_grad = requires_grad - if requires_grad: - grad = IRFullTensor(shape, 'g' + self.name, False).as_grad() - self.grad = grad def __copy__(self): """ @@ -465,6 +462,34 @@ def subtensors(self): """ return copy.copy(self._segments) + @property + def grad(self) -> Optional[IRTensor]: + return self._grad + + @grad.setter + def grad(self, val: Optional[IRTensor]): + assert isinstance(val, IRFullTensor) or val is None, f"grad can only be IRFullTensor or None, but got {val}" + self._grad = val + if val is None: + self._requires_grad = False + for t in self.producers + self.consumers: + t.grad = None + else: + assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." + self._requires_grad = True + + @property + def requires_grad(self): + return self._requires_grad + + @requires_grad.setter + def requires_grad(self, val: bool): + self._requires_grad = val + if val and self.grad is None: + self._grad = IRFullTensor(self.shape, 'g' + self.name, False).as_grad() + elif not val and self.grad is not None: + self._grad = None + def as_param(self): """ Set the tensor as trainable parameter @@ -472,14 +497,11 @@ def as_param(self): self.requires_grad = True self._is_param = True self._is_grad = False - # for sub_tensor in self.ptensors + self.ctensors: - # sub_tensor.as_param() def as_grad(self): + self._requires_grad = False self._is_param = False self._is_grad = True - # for sub_tensor in self.ptensors + self.ctensors: - # sub_tensor.as_grad() return self def like(self): @@ -605,6 +627,20 @@ def __init__(self, full_tensor: IRTensor, # val map self._valmap = _to_value_map(valmap) + @property + def grad(self) -> Optional[IRTensor]: + return self._grad + + @grad.setter + def grad(self, val: Optional[IRTensor]): + assert isinstance(val, IRSubTensor) or val is None, f"grad can only be IRFullTensor or None, but got {val}" + self._grad = val + if val is None: + self._requires_grad = False + else: + assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." + self._requires_grad = True + def __eq__(self, other): if isinstance(other, IRFullTensor): @@ -651,18 +687,15 @@ def __copy__(self): tensor._cell = None return tensor - def get_grad(self, fcell: IRCell): + def get_grad(self, fcell: IRCell) -> Optional[IRTensor]: """ Get gradient of this tensor which is associated by a forward cell """ - if not self.requires_grad: + if self.parent.grad is None: self.grad = None return None full_grad = self.parent.grad - if full_grad is None: - self.grad = None - return None if self in fcell.inputs(): ref_cell_ids = list() for dst_cell in self.parent.consumers: diff --git a/cube/logics/model.py b/cube/logics/model.py index 7d0d4442..130a1a0a 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List import copy from cube.graph.graph import IRGraph @@ -47,49 +47,32 @@ def _check_is_sub_tensor(self, tensor): raise TypeError("Tensor only allows to be SubTensor") -def forward(graph, *args) -> IRGraph: +def forward(graph: IRGraph, *args) -> IRGraph: """ Forward the IRGraph, replacing all the intermediate tensors """ if not isinstance(graph, IRGraph): - raise TypeError("Forwarding requires IRGraph") - - gener = _TensorGener() - - for input, arg in zip(graph.inputs(), args): - gener.set_map(input, arg) - - fnodes = list() - - # generate forward nodes - for fnode in graph.nodes(): - fidx = graph.detach(fnode) - inputs = fnode.inputs() - outputs = fnode.outputs() - # fnode = copy.copy(fnode) - fnode : IRFwOperation = fnode - fnode._inputs = inputs - fnode._outputs = outputs - # set forward inputs - for idx, val in enumerate(inputs): - fnode.set_input(idx, gener.renew(val)) - # set forward outputs - for idx, val in enumerate(outputs): - fnode.set_output(idx, gener.renew(val)) - graph.attach(fnode, fidx) - fnodes.append(fnode) - - # reverse is only to make op id looks consecutive - for fnode in graph.nodes()[::-1]: + raise TypeError("Requires IRGraph for forward") + # align graph with input tensors + itensors: List[IRSubTensor] = graph.inputs() + for idx, (itensor, arg) in enumerate(zip(itensors, args)): + graph.set_input(idx, arg) + for producer in copy.copy(itensor.parent.producers): + pidx = graph.detach(producer) + while itensor in producer.outputs(): + oidx = producer.outputs().index(itensor) + producer.set_output(oidx, arg) + graph.attach(producer, pidx) + for consumer in copy.copy(itensor.parent.consumers): + cidx = graph.detach(consumer) + while itensor in consumer.inputs(): + iidx = consumer.inputs().index(itensor) + consumer.set_input(iidx, arg) + graph.attach(consumer, cidx) + while itensor in graph.outputs(): + oidx = graph.outputs().index(itensor) + graph.set_output(oidx, arg) + # generate backward reverse is only to make op id looks consecutive + for fnode in [n for n in graph.nodes() if isinstance(n, IRFwOperation)][::-1]: fnode.gen_backward() - - inputs = [gener.renew(input) for input in graph.inputs()] - outputs = [gener.renew(output) for output in graph.outputs()] - - for idx, input in enumerate(inputs): - graph.set_input(idx, input) - for idx, output in enumerate(outputs): - graph.set_output(idx, output) - - # fgraph = IRGraph(fnodes, inputs, outputs, graph.name) return graph diff --git a/cube/logics/translator.py b/cube/logics/translator.py index 966caf60..20a5286c 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -20,19 +20,16 @@ def gen_logic_graph(outputs=None): nodes = SchedulePool().nodes() graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') # remove backward nodes if no backward is called - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - bnode = node.mirror - if bnode not in graph.nodes(): - IRCell.make_pair(node, None) - for input in node.inputs(): - if isinstance(input, IRSubTensor): - input.grad = None - input.requires_grad = False - for output in node.outputs(): - if isinstance(output, IRSubTensor): - output.grad = None - output.requires_grad = False + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + for fnode in fnodes: + if fnode.mirror not in graph.nodes(): + IRCell.make_pair(fnode, None) + for itensor in fnode.inputs(): + if isinstance(itensor, IRSubTensor): + itensor.grad = None + for otensor in fnode.outputs(): + if isinstance(otensor, IRSubTensor): + otensor.parent.grad = None return graph @staticmethod @@ -83,10 +80,10 @@ def backward(loss: IRSubTensor): trace = SchedulePool().get_tape(loss) if trace is None: raise RuntimeError("No forward detected") - if not loss.shape == [1]: + if loss.nelement() != 1: raise RuntimeError("backward can only perform on the scaler tensor") # grad should be None or 1.0 - loss.parent.requires_grad = False + loss.parent._grad = None for node in trace: for output in node.outputs(): if loss.overlap(output): diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 24d1d85a..07b5d184 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -6,11 +6,15 @@ import torch -def fexecute(subgraph: Callable, *input_tensors: Tuple[Any]): +def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): """ forward the sub-graph. """ - outputs = subgraph(*input_tensors) + if not requires_grad: + with torch.no_grad(): + outputs = subgraph(*input_tensors) + else: + outputs = subgraph(*input_tensors) # print('forwarding... ') return outputs From f22299306bfbe488cb236d94932fc3a9c15ff7b5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Jun 2022 10:33:17 +0800 Subject: [PATCH 0856/1892] fix adapter multi output bug --- cube/graph/gener/gen.py | 23 +++++++++++++++++------ cube/graph/graph.py | 2 ++ cube/ir/adapter/adapter.py | 4 ++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 7a0fc65b..12bb7401 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -108,9 +108,9 @@ def gen_activation(graph: IRGraph) -> IRGraph: idx = min([graph.nodes().index(c) for c in ftensor.consumers]) graph._nodes.insert(idx, fadapter) # insert backward adapter - grad: Optional[IRFullTensor] = ftensor.grad - if grad is not None: - badapter: IRAdapter = fadapter.mirror + badapter: IRAdapter = fadapter.mirror + if badapter is not None: + grad: Optional[IRFullTensor] = ftensor.grad idx = min([graph.nodes().index(c) for c in grad.consumers]) graph._nodes.insert(idx, badapter) return graph @@ -149,7 +149,8 @@ def gen_fulltensor(ftensor: IRFullTensor) -> List[IRAdapter]: fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) fadapter.prims = fprims grad: IRFullTensor = ftensor.grad - if grad is not None: + # TODO: understand why grad cannot be None in inference-only + if grad is not None and (len(grad.ptensors) != 0 or len(grad.ctensors) != 0): for subtensor in grad.ctensors: bprims += IRAdapterGener.gen_subtensor(subtensor) badapter = IRAdapter(grad.ptensors, grad.ctensors) @@ -246,6 +247,16 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: # category to local tensor and remote tensor local = [t for t in ftensor.ptensors if t.device == subtensor.device] remote = [t for t in ftensor.ptensors if t.device != subtensor.device] + # consumers before this consumer can also be considered as input + cidx = ftensor.consumers.index(subtensor._cell) + for ctensor in ftensor.ctensors[:cidx]: + if subtensor.device == ctensor.device: + if ctensor not in local: + local.append(ctensor) + # TODO: also consider consumers on other devices + # else: + # if ctensor not in remote: + # remote.append(ctensor) prims = [] # ==== select ==== # @@ -254,7 +265,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: for tensor in local: common = tensor.common(subtensor) if tensor == subtensor: - return prims + return [] elif common == subtensor: indmap = [] for islicer, oslicer in zip(tensor.indmap.get(), common.indmap.get()): @@ -339,7 +350,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: indmap.append(None) else: indmap.append(s1) - if None in indmap: + if None in indmap or len(cat_dim) > 1: continue indmap = IndexMap(tuple(indmap)) valmap = t1.valmap diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 8977f4c0..380b8d38 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -240,6 +240,7 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: producers = [p for p in itensor.parent.producers if set(p.device).issubset(set(node.device))] # no producer means a weight or cross device-group if len(producers) == 0 or any(p not in nodes for p in producers): + # FIXME: itensor should also consider device difference if itensor not in inputs: inputs.append(itensor) # update outputs @@ -250,6 +251,7 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: consumers = [c for c in otensor.parent.consumers if set(c.device).issubset(set(node.device))] # no consumer usually means the loss or cross device-group if len(consumers) == 0 or any(c not in nodes for c in consumers): + # FIXME: otensor should also consider device difference if otensor not in outputs: outputs.append(otensor) segment = IRSegment(nodes, inputs, outputs) diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index a3fa1a06..bd0d09b5 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -88,11 +88,11 @@ def dispatch(self, devid: int, for_mirror=True): # get inputs inputs = [] for itensor in self.inputs(): - if devid in itensor.device: + if devid in itensor.device and itensor not in inputs: inputs.append(itensor) outputs = [] for otensor in self.outputs(): - if devid in otensor.device: + if devid in otensor.device and otensor not in outputs: outputs.append(otensor) # insert identity prims if len(prims) == 0: From f71c5002cf43b12dacbbcb37ba30825be1a0380a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Jun 2022 11:08:14 +0800 Subject: [PATCH 0857/1892] fix bug dim merge primitive --- cube/graph/gener/gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 12bb7401..06cfa8ae 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -343,7 +343,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: if s1.start < s2.start: cat_dim[dim] = [t1, t2] else: - cat_dim[dim] = [t1, t2] + cat_dim[dim] = [t2, t1] indmap.append(slice(min(s1.start, s2.start), max(s1.stop, s2.stop), 1)) else: cat_dim[dim] = None From 3fa2107b55d2a8a61892ae5cdb43de420930ef7f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Jun 2022 16:20:15 +0800 Subject: [PATCH 0858/1892] fix gradient flag for loss --- cube/codegen/codegen.py | 8 ++++++-- cube/ir/tensor.py | 35 +++++++++++++++++++++-------------- cube/logics/translator.py | 28 ++++++++++++++-------------- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 4d69880f..09814cd8 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -424,7 +424,7 @@ def emit_node(self, node: IRCell, name: str) -> str: inputs = [self.tensor_naming(t) for t in node.inputs() if not t.is_param()] outputs = [self.tensor_naming(t) for t in node.outputs()] - req_grad = any(t.requires_grad is not None for t in outputs if isinstance(t, IRTensor)) + req_grad = any(t.requires_grad for t in node.outputs() if isinstance(t, IRTensor)) inputs = self.tuple_naming(inputs) outputs = self.return_naming(outputs) @@ -438,6 +438,10 @@ def emit_node(self, node: IRCell, name: str) -> str: finputs = [t for t in node.mirror.inputs() if t.requires_grad] foutputs = node.mirror.outputs() inputs = [t.grad for t in foutputs] + for idx, itensor in enumerate(inputs): + if isinstance(itensor, float): + assert itensor == 1.0, "Loss gradient should be 1.0" + inputs[idx] = None outputs = [t.grad for t in finputs] # remove weight gradient in outputs for input in finputs: @@ -464,7 +468,7 @@ def emit_node(self, node: IRCell, name: str) -> str: code = f'{outputs} = {body}' elif isinstance(node, IRWeightReducer): - body = fsign.format(model=f'model.{name}', inputs='()') + body = fsign.format(model=f'model.{name}', inputs='()', req_grad=req_grad) code = f'{outputs} = {body}' else: diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index c60aaec4..686205f7 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -463,20 +463,25 @@ def subtensors(self): return copy.copy(self._segments) @property - def grad(self) -> Optional[IRTensor]: + def grad(self) -> Optional[Union[IRTensor, float]]: return self._grad @grad.setter - def grad(self, val: Optional[IRTensor]): - assert isinstance(val, IRFullTensor) or val is None, f"grad can only be IRFullTensor or None, but got {val}" + def grad(self, val: Optional[Union[IRTensor, float]]): + """ + int indicates the tensor is the loss tensor. + """ + assert isinstance(val, (IRFullTensor, float)) or val is None, f"grad can only be IRFullTensor or None, but got {val}" self._grad = val if val is None: self._requires_grad = False - for t in self.producers + self.consumers: + for t in self.ptensors + self.ctensors: t.grad = None - else: + elif isinstance(val, IRFullTensor): assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." self._requires_grad = True + else: + self._requires_grad = True @property def requires_grad(self): @@ -486,9 +491,9 @@ def requires_grad(self): def requires_grad(self, val: bool): self._requires_grad = val if val and self.grad is None: - self._grad = IRFullTensor(self.shape, 'g' + self.name, False).as_grad() + self.grad = IRFullTensor(self.shape, 'g' + self.name, False).as_grad() elif not val and self.grad is not None: - self._grad = None + self.grad = None def as_param(self): """ @@ -628,18 +633,20 @@ def __init__(self, full_tensor: IRTensor, self._valmap = _to_value_map(valmap) @property - def grad(self) -> Optional[IRTensor]: + def grad(self) -> Optional[Union[IRTensor, float]]: return self._grad @grad.setter - def grad(self, val: Optional[IRTensor]): - assert isinstance(val, IRSubTensor) or val is None, f"grad can only be IRFullTensor or None, but got {val}" + def grad(self, val: Optional[Union[IRTensor, float]]): + assert isinstance(val, (IRSubTensor, float)) or val is None, f"grad can only be IRFullTensor or None, but got {val}" self._grad = val if val is None: self._requires_grad = False - else: + elif isinstance(val, IRSubTensor): assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." self._requires_grad = True + else: + self._requires_grad = True def __eq__(self, other): @@ -692,10 +699,10 @@ def get_grad(self, fcell: IRCell) -> Optional[IRTensor]: Get gradient of this tensor which is associated by a forward cell """ - if self.parent.grad is None: - self.grad = None - return None full_grad = self.parent.grad + if full_grad is None or isinstance(full_grad, float): + self.grad = full_grad + return full_grad if self in fcell.inputs(): ref_cell_ids = list() for dst_cell in self.parent.consumers: diff --git a/cube/logics/translator.py b/cube/logics/translator.py index 20a5286c..08768df9 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -1,4 +1,4 @@ -from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -19,17 +19,17 @@ def gen_logic_graph(outputs=None): """ nodes = SchedulePool().nodes() graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') + has_bp = any(n for n in graph.nodes() if isinstance(n, IRBpOperation)) + if has_bp: + assert (fnode.mirror in graph.nodes() for node in graph.nodes() if isinstance(node, IRFwOperation)), \ + "Training requires all nodes have backward." + return graph # remove backward nodes if no backward is called fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] for fnode in fnodes: - if fnode.mirror not in graph.nodes(): - IRCell.make_pair(fnode, None) - for itensor in fnode.inputs(): - if isinstance(itensor, IRSubTensor): - itensor.grad = None - for otensor in fnode.outputs(): - if isinstance(otensor, IRSubTensor): - otensor.parent.grad = None + IRCell.make_pair(fnode, None) + for ftensor in graph.full_tensors(): + ftensor.requires_grad = False return graph @staticmethod @@ -83,11 +83,11 @@ def backward(loss: IRSubTensor): if loss.nelement() != 1: raise RuntimeError("backward can only perform on the scaler tensor") # grad should be None or 1.0 - loss.parent._grad = None - for node in trace: - for output in node.outputs(): - if loss.overlap(output): - node.mirror.update() + loss.parent.grad = 1.0 + for node in loss.parent.producers: + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor) and otensor.overlap(loss): + loss.grad = loss.parent.grad for node in trace[::-1]: SchedulePool().add_node(node.mirror) From a39473d5e4feb8619887df3e1b6d092b90d1d765 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Wed, 15 Jun 2022 07:31:13 +0000 Subject: [PATCH 0859/1892] Merged PR 1383: Codegen WRF model Make WRF model codegen pass: - Add another surface where IRNode is code-gen-ed into frontend/PyTorch invocation (namely, anti-lowering). Some TorchScript-internal ops need to be handled specially on both the arguments and the syntax. - Move some definitions out of the original modules for clearer reference graph and avoid cyclic reference. --- cube/codegen/codegen.py | 15 +-- cube/codegen/frontend_mapping.py | 154 ++++++++++++++++++++++++++++++ cube/graph/function/creators.py | 15 ++- cube/graph/function/function.py | 88 +++++++++++++---- cube/graph/function/scatter.py | 1 - cube/graph/parser/mapping.py | 30 +----- cube/graph/torch_dtype_mapping.py | 77 +++++++++++++++ 7 files changed, 323 insertions(+), 57 deletions(-) create mode 100644 cube/codegen/frontend_mapping.py create mode 100644 cube/graph/torch_dtype_mapping.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 09814cd8..c255ce5b 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -21,6 +21,7 @@ from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock from cube.codegen.register import VarManager +from cube.codegen.frontend_mapping import Sign2EmitRule class CodeGen: @@ -227,18 +228,18 @@ def emit_op_call(self, node: IRFwOperation): """ Emit op forward code """ - op_code = node.signature + signature = node.signature inputs = [self.tensor_naming(t) for t in node.inputs()] - kwargs = list() + kwargs = {} for key in node.kwargs: val = node.kwargs[key] if isinstance(val, str) and 'self.' not in val: val = '"' + val + '"' - code = f'{key}={val}' - kwargs.append(code) - inputs += kwargs - inputs = ', '.join(inputs) - body = f'{op_code}({inputs})' + kwargs[key] = val + + emit_rule = Sign2EmitRule.map(signature) + body = emit_rule(node, inputs, kwargs) + if len(node.outputs()) == 0: code = body else: diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py new file mode 100644 index 00000000..41307467 --- /dev/null +++ b/cube/codegen/frontend_mapping.py @@ -0,0 +1,154 @@ +# Some operators should be specially handled during codegen to the frontend code, +# here we define the customized rule for code emisson. + +from typing import Any, Callable, Dict, List, Optional + +from cube import ir +from cube.ir.cten import IRTensor +from cube.ir.dtype import IRDType +from cube.ir.operator import IRFwOperation + +import torch + +class Sign2EmitRule: + + @staticmethod + def map(signature:str) -> Callable[[IRFwOperation, List[str], Dict[str, Any]], str]: + """ + The definition of the emit rule is like: + + ``` + def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: + x_var, h_var, c_var = arg_vars + return f"lstm({x_var}, [{h_var}, {c_var}], OTHER_ARG_VARS)" + ``` + + 'arg_vars' are inputs (all are Tensor-typed) variable names as string, e.g., ["x", "y"] + 'kw_pairs' are dict whose values has been preprocessed and can be directly stringified, + e.g., {"dim":1, "layout"="nchw"} + """ + return Sign2EmitRule._signMap.get(signature) or Sign2EmitRule._common_rule_join_all + + # By default, we flatten all args and join them by "," + # this includes ops with a fixed number of parameters like 'add(x,y)', + # or ops allowing multiple parameters at the frontend like 'block_diag(t1,t2' + @staticmethod + def _common_rule_join_all(node:IRFwOperation, arg_vars:List[str], kw_pairs:dict) -> str: + signature = node.signature + + kw_assigns = list() + for key, val in kw_pairs.items(): + code = f'{key}={val}' + kw_assigns.append(code) + + args = ", ".join(arg_vars + kw_assigns) + return f"{signature}({args})" + + @staticmethod + def _common_rule_input_as_list(node:IRFwOperation, arg_vars:List[str], kw_pairs:dict) -> str: + signature = node.signature + + kw_assigns = list() + for key, val in kw_pairs.items(): + code = f'{key}={val}' + kw_assigns.append(code) + + args = ", ".join(arg_vars) + kwargs = ", ".join(kw_assigns) + return f"{signature}([{args}], {kwargs})" + + @staticmethod + def emit_slice(node, arg_vars:list, kw_pairs:dict) -> str: + """ + The op is: + aten::slice(input:Tensor, dim:int=0, start:Optional[int]=None, end:Optional[int]=None, step:int=1) -> Tensor + + but at the frontend such an invocation must be rewritten as 'x[:, l:h:s, :, :]' + depending on the 'input's rank and the 'dim' value. + """ + out_tensors : list = node.outputs() + assert len(out_tensors) == 1 + out_tensor : IRTensor = out_tensors[0] + + assert len(arg_vars) == 1 + in_tensor_var : str = arg_vars[0] + + dim : int = kw_pairs["dim"] + start : Optional[int] = kw_pairs["start"] + end : Optional[int] = kw_pairs["end"] + step : int = kw_pairs["step"] + + rank = len(out_tensor.shape) + subscript_components = [":"] * rank + + slice_str = f"{start or ''}:{end or ''}:{step}" + subscript_components[dim] = slice_str + + return f"{in_tensor_var}[{', '.join(subscript_components)}]" + + + # TODO consider making the IR-Torch conversion like IRDType2TorchDType intrinsic to codegen, + # so that we don't need to ad hoc do the conversion as in these emission functions. + # Also, we'd better limit the complexity of the values in 'kw_pairs' so we know for sure we have + # done all necessary conversion. + # + # Basically to convert internal 'IRDType' to frontend 'torch.dtype' + @staticmethod + def emit_zeros(node, arg_vars:list, kw_pairs:dict) -> str: + """ + zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + """ + kw_pairs = kw_pairs.copy() + if 'dtype' in kw_pairs: + ir_dtype : IRDType = kw_pairs['dtype'] + if ir_dtype is not None: + kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) + + # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. + assert 'device' not in kw_pairs + kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. + + assert len(arg_vars) == 0 + return Sign2EmitRule._common_rule_join_all(node, arg_vars, kw_pairs) + + # Basically to convert internal 'IRDType' to frontend 'torch.dtype' + @staticmethod + def emit_to(node, arg_vars:list, kw_pairs:dict) -> str: + kw_pairs = kw_pairs.copy() + + # Unlike 'zeros' who has 'ScalarType? dtype', 'to' has a non-nullable 'dtype'. + ir_dtype : IRDType = kw_pairs['dtype'] + assert ir_dtype is not None + kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) + + return Sign2EmitRule._common_rule_join_all(node, arg_vars, kw_pairs) + + + _signMap = { + 'torch.cat': _common_rule_input_as_list, + 'torch.stack': _common_rule_input_as_list, + + 'torch.slice': emit_slice, + 'torch.zeros': emit_zeros, + 'torch.Tensor.to': emit_to, + } + + +# The reverse mapping of DType2IRDType in /graph/parser/mapping.py +class IRDType2DType: + + @staticmethod + def map(ir_dtype:IRDType) -> torch.dtype: + return IRDType2DType._map[ir_dtype] # subscript/[]-access will throw if not found + + _map = { + ir.float64: torch.float64, + ir.float32: torch.float32, + ir.float16: torch.float16, + ir.uint8: torch.uint8, + ir.int8: torch.int8, + ir.int16: torch.int16, + ir.int32: torch.int32, + ir.int64: torch.int64, + ir.boolean: torch.bool + } diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index dded4b93..1b098154 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -1,20 +1,24 @@ from copy import copy -from typing import List +from typing import List, Optional +from cube.ir.dtype import IRDType from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor class IRZeros(IRFwOperation): - def __init__(self, signature: str, shape: List[int], name: str): + def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:Optional[IRDType]=None): # The shape information must be statically known integer values assert all(isinstance(dim, int) for dim in shape) super().__init__(name, signature, input_length=0, output_length=1) - self._shape = copy(shape) + + # The positional argument to specify the shape is actually called 'size'. + self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) def infer_shape(self) -> bool: - self.outputs(0).shape = self._shape + shape : list = copy(self.kwargs["size"]) + self.outputs(0).shape = shape return True @@ -30,9 +34,10 @@ def infer_shape(self) -> bool: # https://github.com/pytorch/pytorch/blob/483bb4f0cb273f42f655aa30eee6a1fbbaba69b0/torch/csrc/jit/runtime/register_prim_ops.cpp#L1057 # https://github.com/pytorch/pytorch/blob/483bb4f0cb273f42f655aa30eee6a1fbbaba69b0/torch/csrc/jit/runtime/register_prim_ops.cpp#L2215 class IRToTensor(IRFwOperation): - def __init__(self, signature: str, inputs, name:str): + def __init__(self, signature: str, inputs, name:str, ir_dtype:IRDType): super().__init__(name, signature, input_length=1, output_length=1) self.set_input(0, inputs[0]) + self.kwargs.update({"dtype": ir_dtype}) def infer_shape(self) -> bool: self.outputs(0).shape = self.inputs(0).shape diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 6c3d2230..b710080c 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1,7 +1,8 @@ from typing import Any, Iterable, List, Optional, Tuple, Union, Dict import string import copy -import numpy +import torch +import warnings from cube.ir.cten import IRTensor from cube.graph.function.einops import EinDim, IREinops @@ -15,6 +16,8 @@ from cube.graph.function.select import IRSelect, IRSlice from cube.graph.function.scatter import IRSelectScatter from cube.graph.function.repeat import IRRepeat +from cube.ir.dtype import IRDType +from cube.graph.torch_dtype_mapping import DType2IRDType, TorchScalarTypeEnumMap def _create_eshape(shape: List[int], iterator: Optional[Iterable] = None, @@ -101,17 +104,21 @@ def Zeros(signature, # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. - shape, dtype, layout, _erased_device, pin_memory = inputs + size, dtype, layout, _erased_device, pin_memory = inputs # TODO parameters to support, currently they are all None assert dtype is None assert layout is None assert pin_memory is None - for dim, i in enumerate(shape): + ir_dtype : Optional[IRDType] = None + if dtype is not None: + ir_dtype = DType2IRDType.map(dtype) + + for dim, i in enumerate(size): if not isinstance(dim, int) and not dim >= 0: - raise RuntimeWarning(f"The {i}-th component of the shape must be non-negative integer") - return IRZeros(signature, shape, 'zeros') + raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") + return IRZeros(signature, size, 'zeros', ir_dtype) def NewTensor(signature, @@ -127,22 +134,58 @@ def NewTensor(signature, assert dtype is None assert requires_grad == False - arr = numpy.array(data) + ir_dtype : Optional[IRDType] = None + if dtype is not None: + ir_dtype = DType2IRDType.map(dtype) - # ints or floats of any precision, e.g. i8, i64, f16, f32 - # and the specified array is regular/non-ragged. - # Otherwise NumPy would decide the element type as _o_bject. - if not arr.dtype.kind in ['i','f']: - raise RuntimeError("The specified data to create new tensor must be ints or floats") + # if 'data' is not: + # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 + # 2) non-ragged + # ... then this call will throw. + arr = torch.tensor(data, dtype=dtype) # TODO temporarily fake creation with Zeros + # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', + # but since we have omitted the 'data', we must do type inferrence ourselves, + # only in this way we get correct dtype e.g. ints or bools. shape = list(arr.shape) - return IRZeros(signature, shape, 'tensor') + torch_inferred_dtype = arr.dtype + ir_dtype = DType2IRDType.map(torch_inferred_dtype) + signature = 'torch.zeros' + return IRZeros(signature, shape, 'tensor', ir_dtype=ir_dtype) def ToTensor(signature, inputs: Tuple[ IRTensor, ... ]): - tensors = inputs[0:1] - return IRToTensor(signature, tensors, 'to') + """ + 'aten::to' has many overloadings that need resolution, + they differ by both the arity and the type of the argument (possibly at the same position): + + ``` + aten::to.device(Tensor self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor): + aten::to.dtype(Tensor self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor): + aten::to.dtype_layout(Tensor self, *, int dtype, int layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor): + aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor): + aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)): + aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)): + aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)): + ``` + ... where the 'int? dtype' is the underlying type for the enum 'ScalarType'. + """ + + # in our case we only care the overloading 'to.dtype' (arity=5) + assert len(inputs) == 5 + tensor : IRTensor + dtype_underlying : int + non_blocking : bool + copy : bool + opt_memory_format : Optional[int] + tensor, dtype_underlying, non_blocking, copy, opt_memory_format = inputs + + dtype : torch.dtype = TorchScalarTypeEnumMap.map(dtype_underlying) + ir_dtype : IRDType = DType2IRDType.map(dtype) + + signature = 'torch.Tensor.to' + return IRToTensor(signature, [tensor], 'to', ir_dtype=ir_dtype) def Add(signature, inputs): if len(inputs) == 2: @@ -510,7 +553,7 @@ def Transpose(signature, inputs): def View(signature, inputs): """ - out = torch.Tensor.view(tensor: torch.Tensor, shape: List[int]) + out = torch.Tensor.view(tensor: torch.Tensor, size: List[int]) """ assert len(inputs) == 2 input, shape = inputs @@ -620,10 +663,19 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s # bracket[subdim] = edim + '^' anno = _create_anno([in_anno], [ou_anno]) signature = 'torch.Tensor.view' - return IREinops(signature, [anno], [input], 'view', shape=tuple(shape)) + return IREinops(signature, [anno], [input], 'view', size=tuple(shape)) def Reshape(signature, inputs): + """ + torch.reshape(Tensor self, int[] shape) -> Tensor + """ + + warnings.warn(""" + 'torch.reshape' is currently dispatched to 'torch.Tensor.view', + but 'reshape' has keyword parameter 'shape' while 'view' has 'size'. + ArgumentMissing error may be raised during codegen.""") + return View(signature, inputs) @@ -754,6 +806,10 @@ def Repeat(signature, inputs:Tuple[IRTensor, List[int]]): torch.repeat(tensor:Tensor, repeats: List[int]) -> Tensor """ tensor, repeats = inputs + + assert signature == 'torch.repeat' # this is the API in TorchScript + signature = 'torch.Tensor.repeat' # this is the API in Python frontend and is not a Tensor member method + return IRRepeat(signature, [tensor], 'repeat', repeats) diff --git a/cube/graph/function/scatter.py b/cube/graph/function/scatter.py index bc1d19d0..a4116dc7 100644 --- a/cube/graph/function/scatter.py +++ b/cube/graph/function/scatter.py @@ -57,4 +57,3 @@ def infer_shape(self) -> bool: self.outputs(0).shape = s2 return True - diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index dd1432f1..37d3aa34 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -10,8 +10,9 @@ import cube.graph.function as function from cube.ir.operator import IRFwOperation -import cube.ir as ir +# TODO this is a backwards-compatible alias +from cube.graph.torch_dtype_mapping import DType2IRDType class Sign2Op: @@ -140,30 +141,3 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # customized operator code: signature -> code kOpCodeDef: Dict[str, str] = {} - - -class DType2IRDType: - - @staticmethod - def map(dtype: torch.dtype): - """ - Map the torch dtype to IRDType - """ - return DType2IRDType.kDtypeMap[dtype] - - kDtypeMap = { - torch.float64: ir.float64, - torch.float32: ir.float32, - torch.float : ir.float32, - torch.float16: ir.float16, - torch.half : ir.float16, - torch.uint8 : ir.uint8, - torch.int8 : ir.int8, - torch.int16 : ir.int16, - torch.short : ir.int16, - torch.int32 : ir.int32, - torch.int : ir.int32, - torch.int64 : ir.int64, - torch.long : ir.int64, - torch.bool : ir.boolean - } diff --git a/cube/graph/torch_dtype_mapping.py b/cube/graph/torch_dtype_mapping.py new file mode 100644 index 00000000..0787c913 --- /dev/null +++ b/cube/graph/torch_dtype_mapping.py @@ -0,0 +1,77 @@ +from cube import ir +import torch + +class DType2IRDType: + + @staticmethod + def map(dtype: torch.dtype): + """ + Map the torch dtype to IRDType + """ + return DType2IRDType.kDtypeMap[dtype] + + kDtypeMap = { + torch.double: ir.float64, + torch.float64: ir.float64, + torch.float32: ir.float32, + torch.float : ir.float32, + torch.float16: ir.float16, + torch.half : ir.float16, + torch.uint8 : ir.uint8, + torch.int8 : ir.int8, + torch.int16 : ir.int16, + torch.short : ir.int16, + torch.int32 : ir.int32, + torch.int : ir.int32, + torch.int64 : ir.int64, + torch.long : ir.int64, + torch.bool : ir.boolean + } + + +# see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h +# +# ScalarType enum is totally a PyTorch-internal object. Neither itself nor its underlying ints +# are accessible from its Python frontend. +class TorchScalarTypeEnumMap: + + @staticmethod + def map(underlying: int) -> torch.dtype: + + assert isinstance(underlying, int), """ + This function is to convert an underlying 'int' for a Torch-internal 'at::ScalarType' enum + to its corresponding Python-frontend 'torch.dtype' enum. + """ + + dtype = TorchScalarTypeEnumMap._fields[underlying] + + assert dtype is not None, f""" + Referenced to an unsupported ScalarType with underlying int being {underlying} + """ + + return dtype + + # Less used dtypes are masked out because PyTorch keeps **exposing and hiding** them recently + # from a view of Python frontend. + _fields = [ + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.half, + torch.float32, + torch.float64, + None, #torch.complex32, # complexHalf + None, #torch.complex64, # complexFloat + None, #torch.complex128, # complexDouble + torch.bool, + None, #torch.qint8, + None, #torch.quint8, + None, #torch.qint32, + None, #torch.bfloat16, + None, #torch.quint4x2, + None, #torch.quint2x4, + ] + + assert len(_fields) == 18, "Do not remove any item, mask it out with None" \ No newline at end of file From 2b32c5bd4d1e247da474500ae5315df58e1305f0 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Wed, 15 Jun 2022 08:53:05 +0000 Subject: [PATCH 0860/1892] Merged PR 1384: fix compatibility issues fix compatibility issues by fixing to `python==3.7` (dev env) and `torch==1.11+cu113` --- README.md | 3 +- cube/codegen/frontend_mapping.py | 185 +++++++++++++++---------------- requirements.txt | 5 +- 3 files changed, 96 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 9566acb9..8484cb53 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,8 @@ AI System Compiler to map a semantic (single-device) model into distributed exec ## Prerequisite * Python >= 3.7 -* PyTorch >= 1.11 + +> Install Python 3.7 in the development environment for widest compatibility. ## Install diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index 41307467..e14ecf19 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -10,6 +10,95 @@ import torch +# By default, we flatten all args and join them by "," +# this includes ops with a fixed number of parameters like 'add(x,y)', +# or ops allowing multiple parameters at the frontend like 'block_diag(t1,t2' +def _common_rule_join_all(node:IRFwOperation, arg_vars:List[str], kw_pairs:dict) -> str: + signature = node.signature + + kw_assigns = list() + for key, val in kw_pairs.items(): + code = f'{key}={val}' + kw_assigns.append(code) + + args = ", ".join(arg_vars + kw_assigns) + return f"{signature}({args})" + +def _common_rule_input_as_list(node:IRFwOperation, arg_vars:List[str], kw_pairs:dict) -> str: + signature = node.signature + + kw_assigns = list() + for key, val in kw_pairs.items(): + code = f'{key}={val}' + kw_assigns.append(code) + + args = ", ".join(arg_vars) + kwargs = ", ".join(kw_assigns) + return f"{signature}([{args}], {kwargs})" + +def emit_slice(node, arg_vars:list, kw_pairs:dict) -> str: + """ + The op is: + aten::slice(input:Tensor, dim:int=0, start:Optional[int]=None, end:Optional[int]=None, step:int=1) -> Tensor + + but at the frontend such an invocation must be rewritten as 'x[:, l:h:s, :, :]' + depending on the 'input's rank and the 'dim' value. + """ + out_tensors : list = node.outputs() + assert len(out_tensors) == 1 + out_tensor : IRTensor = out_tensors[0] + + assert len(arg_vars) == 1 + in_tensor_var : str = arg_vars[0] + + dim : int = kw_pairs["dim"] + start : Optional[int] = kw_pairs["start"] + end : Optional[int] = kw_pairs["end"] + step : int = kw_pairs["step"] + + rank = len(out_tensor.shape) + subscript_components = [":"] * rank + + slice_str = f"{start or ''}:{end or ''}:{step}" + subscript_components[dim] = slice_str + + return f"{in_tensor_var}[{', '.join(subscript_components)}]" + + +# TODO consider making the IR-Torch conversion like IRDType2TorchDType intrinsic to codegen, +# so that we don't need to ad hoc do the conversion as in these emission functions. +# Also, we'd better limit the complexity of the values in 'kw_pairs' so we know for sure we have +# done all necessary conversion. +# +# Basically to convert internal 'IRDType' to frontend 'torch.dtype' +def emit_zeros(node, arg_vars:list, kw_pairs:dict) -> str: + """ + zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + """ + kw_pairs = kw_pairs.copy() + if 'dtype' in kw_pairs: + ir_dtype : IRDType = kw_pairs['dtype'] + if ir_dtype is not None: + kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) + + # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. + assert 'device' not in kw_pairs + kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. + + assert len(arg_vars) == 0 + return _common_rule_join_all(node, arg_vars, kw_pairs) + +# Basically to convert internal 'IRDType' to frontend 'torch.dtype' +def emit_to(node, arg_vars:list, kw_pairs:dict) -> str: + kw_pairs = kw_pairs.copy() + + # Unlike 'zeros' who has 'ScalarType? dtype', 'to' has a non-nullable 'dtype'. + ir_dtype : IRDType = kw_pairs['dtype'] + assert ir_dtype is not None + kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) + + return _common_rule_join_all(node, arg_vars, kw_pairs) + class Sign2EmitRule: @staticmethod @@ -27,101 +116,7 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: 'kw_pairs' are dict whose values has been preprocessed and can be directly stringified, e.g., {"dim":1, "layout"="nchw"} """ - return Sign2EmitRule._signMap.get(signature) or Sign2EmitRule._common_rule_join_all - - # By default, we flatten all args and join them by "," - # this includes ops with a fixed number of parameters like 'add(x,y)', - # or ops allowing multiple parameters at the frontend like 'block_diag(t1,t2' - @staticmethod - def _common_rule_join_all(node:IRFwOperation, arg_vars:List[str], kw_pairs:dict) -> str: - signature = node.signature - - kw_assigns = list() - for key, val in kw_pairs.items(): - code = f'{key}={val}' - kw_assigns.append(code) - - args = ", ".join(arg_vars + kw_assigns) - return f"{signature}({args})" - - @staticmethod - def _common_rule_input_as_list(node:IRFwOperation, arg_vars:List[str], kw_pairs:dict) -> str: - signature = node.signature - - kw_assigns = list() - for key, val in kw_pairs.items(): - code = f'{key}={val}' - kw_assigns.append(code) - - args = ", ".join(arg_vars) - kwargs = ", ".join(kw_assigns) - return f"{signature}([{args}], {kwargs})" - - @staticmethod - def emit_slice(node, arg_vars:list, kw_pairs:dict) -> str: - """ - The op is: - aten::slice(input:Tensor, dim:int=0, start:Optional[int]=None, end:Optional[int]=None, step:int=1) -> Tensor - - but at the frontend such an invocation must be rewritten as 'x[:, l:h:s, :, :]' - depending on the 'input's rank and the 'dim' value. - """ - out_tensors : list = node.outputs() - assert len(out_tensors) == 1 - out_tensor : IRTensor = out_tensors[0] - - assert len(arg_vars) == 1 - in_tensor_var : str = arg_vars[0] - - dim : int = kw_pairs["dim"] - start : Optional[int] = kw_pairs["start"] - end : Optional[int] = kw_pairs["end"] - step : int = kw_pairs["step"] - - rank = len(out_tensor.shape) - subscript_components = [":"] * rank - - slice_str = f"{start or ''}:{end or ''}:{step}" - subscript_components[dim] = slice_str - - return f"{in_tensor_var}[{', '.join(subscript_components)}]" - - - # TODO consider making the IR-Torch conversion like IRDType2TorchDType intrinsic to codegen, - # so that we don't need to ad hoc do the conversion as in these emission functions. - # Also, we'd better limit the complexity of the values in 'kw_pairs' so we know for sure we have - # done all necessary conversion. - # - # Basically to convert internal 'IRDType' to frontend 'torch.dtype' - @staticmethod - def emit_zeros(node, arg_vars:list, kw_pairs:dict) -> str: - """ - zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - """ - kw_pairs = kw_pairs.copy() - if 'dtype' in kw_pairs: - ir_dtype : IRDType = kw_pairs['dtype'] - if ir_dtype is not None: - kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) - - # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. - assert 'device' not in kw_pairs - kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. - - assert len(arg_vars) == 0 - return Sign2EmitRule._common_rule_join_all(node, arg_vars, kw_pairs) - - # Basically to convert internal 'IRDType' to frontend 'torch.dtype' - @staticmethod - def emit_to(node, arg_vars:list, kw_pairs:dict) -> str: - kw_pairs = kw_pairs.copy() - - # Unlike 'zeros' who has 'ScalarType? dtype', 'to' has a non-nullable 'dtype'. - ir_dtype : IRDType = kw_pairs['dtype'] - assert ir_dtype is not None - kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) - - return Sign2EmitRule._common_rule_join_all(node, arg_vars, kw_pairs) + return Sign2EmitRule._signMap.get(signature) or _common_rule_join_all _signMap = { diff --git a/requirements.txt b/requirements.txt index e83a8c22..7da97869 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ einops matplotlib -pytest \ No newline at end of file +pytest + +--find-links https://download.pytorch.org/whl/torch_stable.html +torch==1.11.0+cu113 \ No newline at end of file From 6f1c30ed40171043beb301ea758451e18f248c09 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Jun 2022 11:38:07 +0800 Subject: [PATCH 0861/1892] gradient update --- cube/ir/cten.py | 121 +++++++++----------------------------- cube/ir/operator.py | 78 +++++++++--------------- cube/ir/tensor.py | 86 +++++++++++++++------------ cube/logics/translator.py | 13 ++-- 4 files changed, 110 insertions(+), 188 deletions(-) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 1c5572e4..41c1e96a 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -401,38 +401,29 @@ class IRTensor: and will be translated to None in code generation. """ - _attr = ['name', '_is_param', '_requires_grad', '_is_grad', '_grad', '_dtype'] + _attr = ['name', '_is_param', '_is_grad', '_requires_grad', '_dtype'] def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() - self._shape: Optional(List[int]) = () if shape is None else tuple(shape) - self.name = name if name else 'tensor' + self._shape: Tuple[int] = () if shape is None else tuple(shape) + self.name: str = name if name else 'tensor' # device self._cell: Optional[IRCell] = None self._dtype: IRDType = dtype + self._is_param: bool = False + self._is_grad: bool = False - self._requires_grad = True - self._is_param = False - - self._is_grad = False - self._grad = None # the gradient of this tensor - self._data = None # the tensor of this gradient belongs to + # tensor gradient + self._requires_grad: bool = True + self._grad: Optional[Union[IRTensor, float]] = None @property - def requires_grad(self): - return self._requires_grad - - @requires_grad.setter - def requires_grad(self, val: bool): - self._requires_grad = val - - @property - def dtype(self): + def dtype(self) -> IRDType: """ - Data type + Tensor data type """ return self._dtype @@ -445,6 +436,15 @@ def dtype(self, val: IRDType): raise TypeError(f"Expected IRDType but got {val}") self._dtype = val + @property + def cell(self) -> Optional[IRCell]: + return self._cell + + @cell.setter + def cell(self, val: Optional[IRCell]): + assert isinstance(val, IRCell) or val is None, "Expected cell to be Optional[IRCell]" + self._cell = val + def attach_cell(self, cell: IRCell): """ Attach to a cell, to be with input or output @@ -453,12 +453,6 @@ def attach_cell(self, cell: IRCell): raise TypeError("Expected an IRCell") self._cell = cell - def detach_cell(self): - """ - Detach from a cell - """ - self._cell = None - @property def device(self) -> List[int]: if self._cell: @@ -476,7 +470,8 @@ def as_param(self): """ Set the tensor as trainable parameter """ - self.requires_grad = True + assert self._grad is not None, "missing grad tensor" + self._requires_grad = True self._is_grad = False self._is_param = True return self @@ -487,25 +482,6 @@ def is_param(self): """ return self._is_param - @property - def data(self): - return self._data - - @property - def grad(self): - return self._grad - - @grad.setter - def grad(self, grad): - if grad is None: - self._grad = grad - return - elif not isinstance(grad, IRTensor): - raise TypeError("grad can only be None or Tensor") - self.requires_grad = True - self._grad = grad - grad._data = self - def as_grad(self): self._is_param = False self._is_grad = True @@ -514,22 +490,9 @@ def as_grad(self): def is_grad(self): return self._is_grad - def renew(self): - """ - Renew a new tensor with same name and shape, - but with a different new id - - Returns: - tensor - """ - tensor = IRTensor(self._shape, self.name) - new_id = tensor._id - for key in self.__dict__: - setattr(tensor, key, getattr(self, key)) - # clear attached cells - tensor._cell = list() - tensor._id = new_id - return tensor + @property + def requires_grad(self) -> bool: + return self._requires_grad def __copy__(self): """ @@ -543,7 +506,7 @@ def __copy__(self): for key in self.__dict__: setattr(tensor, key, getattr(self, key)) # clear attached cells - tensor._cell = list() + tensor.cell = None return tensor def __eq__(self, tensor): @@ -558,8 +521,8 @@ def shape(self) -> Tuple[int]: @shape.setter def shape(self, val: Tuple[int]): self._shape = tuple(val) - if self.grad is not None: - self.grad.shape = tuple(val) + if isinstance(self._grad, IRTensor): + self._grad.shape = tuple(val) def nelement(self) -> int: """ @@ -572,36 +535,6 @@ def nelement(self) -> int: cnt *= num return cnt - def src(self, cells: List[IRCell]) -> List[IRCell]: - """ - Return all the cells that will generate this tensor - """ - src_cells = list() - for cell in cells: - if not isinstance(cell, IRCell): - raise TypeError("Expected cells to be List[IRCell]") - if self in cell.outputs(): - src_cells.append(cell) - return src_cells - - def dst(self, cells: List[IRCell]) -> List[IRCell]: - """ - Return all the cells that will generate this tensor - """ - dst_cells = list() - for cell in cells: - if not isinstance(cell, IRCell): - raise TypeError("Expected cells to be List[IRCell]") - if self in cell.inputs(): - dst_cells.append(cell) - return dst_cells - - def is_leaf(self, cells: List[IRCell]): - """ - Check if it is a leaf tensor (parameter or input data) - """ - return len(self.src(cells)) == 0 - def backward(self): """ Autograd backward on the tensor diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 2b8ac283..c3e36ad4 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -7,16 +7,20 @@ from cube.ir.unique import IDGenerator -class BaseOperator: +class IRBaseOp(IRCell): def __init__(self, name: str, signature: str, - input_length: int, output_length: int, - init_outputs=False): - super().__init__(name, signature, - input_length, output_length, - init_outputs=init_outputs) - - def infer_shape(self): + inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): + super().__init__(name, signature, len(inputs), len(outputs), init_outputs=False) + self.kwargs = kwargs + assert all(isinstance(t, IRTensor) for t in inputs), "expect all inputs to be IRTensors" + assert all(isinstance(t, IRTensor) for t in outputs), "expect all outputs to be IRTensors" + for idx, itensor in enumerate(inputs): + self.set_input(idx, itensor) + for idx, otensor in enumerate(outputs): + self.set_output(idx, otensor) + + def infer_shape(self) -> bool: """ Infer output value shape """ @@ -26,21 +30,10 @@ def replicate(self): """ Replicate the Operation """ - cpy = copy.copy(self) - cpy._device = list() - cpy._id = IDGenerator().gen_cell_id() - # reset input and output - cpy._inputs = [None] * len(self.inputs()) - for idx, input in enumerate(self.inputs()): - cpy.set_input(idx, input) - cpy._outputs = [None] * len(self.outputs()) - for idx, output in enumerate(self.outputs()): - cpy.set_output(idx, output) - cpy._mirror = None - cpy._tag = None - cpy.clear_predecessor() - cpy.clear_successor() - return cpy + node = type(self)( + self.name, self.signature, self.inputs(), self.outputs(), **self.kwargs) + node._id = self._id + return node class IRFwOperation(IRCell): @@ -131,16 +124,12 @@ def gen_backward(self): data_num=len(self.inputs()), grad_num=len(self.outputs()) ) - for idx, input in enumerate(self.inputs()): - grad = None - if isinstance(input, IRSubTensor): - grad = input.get_grad(self) - input.grad = grad - bnode.set_data(idx, input) + for idx, itensor in enumerate(self.inputs()): + grad = itensor.grad if isinstance(itensor, IRSubTensor) else None + bnode.set_data(idx, itensor) bnode.set_output(idx, grad) - for idx, output in enumerate(self.outputs()): - grad = output.get_grad(self) - output.grad = grad + for idx, otensor in enumerate(self.outputs()): + grad = otensor.grad if isinstance(otensor, IRSubTensor) else None bnode.set_input(idx, grad) IRCell.make_pair(self, bnode) return bnode @@ -245,31 +234,20 @@ def update(self): graph.attach(node, idx) ``` """ - fnode = self.mirror - for idx, input in enumerate(fnode.inputs()): - grad = None - if isinstance(input, IRSubTensor): - grad = input.get_grad(fnode) - self.set_data(idx, input) + fnode: IRFwOperation = self.mirror + assert isinstance(fnode, IRFwOperation), "Cannot find corresponding IRFwOperation" + for idx, itensor in enumerate(fnode.inputs()): + grad = itensor.grad if isinstance(itensor, IRSubTensor) else None + self.set_data(idx, itensor) self.set_output(idx, grad) - for idx, output in enumerate(fnode.outputs()): - grad = output.get_grad(fnode) + for idx, otensor in enumerate(fnode.outputs()): + grad = otensor.grad if isinstance(otensor, IRSubTensor) else None self.set_input(idx, grad) def __repr__(self): dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, inputs={self.inputs()}, datas={self.datas()}, outputs={self.outputs()})' return dscp - def module_repr(self) -> str: - """ - Weight-hidden string representation - """ - ins = [t for t in self.datas() if isinstance(t, IRSubTensor) and not t.is_param()] - outs = [t.grad for t in ins] - assert all([out in self.outputs() for out in outs]) - dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, inputs={self.inputs()}, outputs={outs})' - return dscp - class IRDataOperation(IRCell): diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 686205f7..111e96e3 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -433,6 +433,8 @@ def add_consumer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): assert cell not in self._consumers, f"{cell} already exists as consumer" self._consumers.insert(idx, cell) self._ctensors.insert(idx, tensor) + for t in self.ctensors: + t._dirty_grad = True def rm_producer(self, cell: IRCell) -> int: if cell not in self.producers: @@ -473,15 +475,11 @@ def grad(self, val: Optional[Union[IRTensor, float]]): """ assert isinstance(val, (IRFullTensor, float)) or val is None, f"grad can only be IRFullTensor or None, but got {val}" self._grad = val - if val is None: - self._requires_grad = False - for t in self.ptensors + self.ctensors: - t.grad = None - elif isinstance(val, IRFullTensor): + self._requires_grad = False if val is None else True + if isinstance(val, IRFullTensor): assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." - self._requires_grad = True - else: - self._requires_grad = True + for tensor in self.ctensors + self.ptensors: + tensor._dirty_grad = True @property def requires_grad(self): @@ -494,6 +492,8 @@ def requires_grad(self, val: bool): self.grad = IRFullTensor(self.shape, 'g' + self.name, False).as_grad() elif not val and self.grad is not None: self.grad = None + for tensor in self.ctensors + self.ptensors: + tensor._dirty_grad = True def as_param(self): """ @@ -504,7 +504,7 @@ def as_param(self): self._is_grad = False def as_grad(self): - self._requires_grad = False + self.requires_grad = False self._is_param = False self._is_grad = True return self @@ -551,8 +551,8 @@ def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, sub_tensor = IRSubTensor(self, indmap, valmap, shape) for attr in IRFullTensor._attr: setattr(sub_tensor, attr, getattr(self, attr)) - sub_tensor.grad = None self._segments.append(sub_tensor) + sub_tensor._dirty_grad = True return sub_tensor def overlap(self, other): @@ -632,21 +632,8 @@ def __init__(self, full_tensor: IRTensor, # val map self._valmap = _to_value_map(valmap) - @property - def grad(self) -> Optional[Union[IRTensor, float]]: - return self._grad - - @grad.setter - def grad(self, val: Optional[Union[IRTensor, float]]): - assert isinstance(val, (IRSubTensor, float)) or val is None, f"grad can only be IRFullTensor or None, but got {val}" - self._grad = val - if val is None: - self._requires_grad = False - elif isinstance(val, IRSubTensor): - assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." - self._requires_grad = True - else: - self._requires_grad = True + # grad flag + self._dirty_grad = True def __eq__(self, other): @@ -676,7 +663,7 @@ def indmap(self) -> IndexMap: return copy.copy(self._indmap) @property - def valmap(self): + def valmap(self) -> ValueMap: return copy.copy(self._valmap) def __copy__(self): @@ -694,16 +681,30 @@ def __copy__(self): tensor._cell = None return tensor - def get_grad(self, fcell: IRCell) -> Optional[IRTensor]: + @property + def grad(self) -> Optional[Union[IRTensor, float]]: """ Get gradient of this tensor which is associated by a forward cell + + Gradient can be: + - None: if the tensor doesn't require gradient + - 1.0: if the tensor is loss tensor + - IRSubTensor: if the tensor requires gradient and not the loss tensor + + Gradient cannot be set and can only be inferred by its IRFullTensor. """ + if not self._dirty_grad: + return self._grad + + assert isinstance(self._cell, IRCell), "No cell attached to this tensor." full_grad = self.parent.grad + # None indicate the tensor doesn't need grad. + # float means this tensor is a loss tensor if full_grad is None or isinstance(full_grad, float): - self.grad = full_grad - return full_grad - if self in fcell.inputs(): + self._grad = full_grad + # this tensor is consumed + elif self in self._cell.inputs(): ref_cell_ids = list() for dst_cell in self.parent.consumers: for input in dst_cell.inputs(): @@ -712,27 +713,36 @@ def get_grad(self, fcell: IRCell) -> Optional[IRTensor]: break ref_times = len(ref_cell_ids) if ref_times == 0: - raise RuntimeError("Internal Error: ref time is 0") - idx = ref_cell_ids.index(fcell._id) + raise RuntimeError("Internal error: consumer doesn't have the operator attached to this tensor") + idx = ref_cell_ids.index(self._cell._id) grad = full_grad.select( indmap = self.indmap, valmap = (idx, ref_times), shape = self.shape ) - self.grad = grad + self._grad = grad + self._dirty_grad = False return grad - elif self in fcell.outputs(): + # this tensor is produced + elif self in self._cell.outputs(): grad = full_grad.select( indmap = self.indmap, valmap = (0, 1), shape = self.shape ) - self.grad = grad - return grad + self._grad = grad else: - raise RuntimeError(f"{self} not found in cell {fcell}") + raise RuntimeError("visiting a tensor grad that potentially generated by IRAdapter") + self._dirty_grad = False + self._requires_grad = False if full_grad is None else True + return self._grad + + @property + def requires_grad(self) -> bool: + _ = self.grad + return self._requires_grad - def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, None], shape=None): + def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, None], shape=None) -> IRTensor: """ Select an IRSubTensor diff --git a/cube/logics/translator.py b/cube/logics/translator.py index 08768df9..a1d9bf0e 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -21,7 +21,7 @@ def gen_logic_graph(outputs=None): graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') has_bp = any(n for n in graph.nodes() if isinstance(n, IRBpOperation)) if has_bp: - assert (fnode.mirror in graph.nodes() for node in graph.nodes() if isinstance(node, IRFwOperation)), \ + assert all(fnode.mirror in graph.nodes() for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)), \ "Training requires all nodes have backward." return graph # remove backward nodes if no backward is called @@ -30,6 +30,11 @@ def gen_logic_graph(outputs=None): IRCell.make_pair(fnode, None) for ftensor in graph.full_tensors(): ftensor.requires_grad = False + #TODO: ad hoc fix on operators with multiple same input tensors + for node in graph.nodes(): + for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor): + itensor._dirty_grad = True return graph @staticmethod @@ -82,12 +87,8 @@ def backward(loss: IRSubTensor): raise RuntimeError("No forward detected") if loss.nelement() != 1: raise RuntimeError("backward can only perform on the scaler tensor") - # grad should be None or 1.0 + # loss tensor grad should be 1.0 loss.parent.grad = 1.0 - for node in loss.parent.producers: - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor) and otensor.overlap(loss): - loss.grad = loss.parent.grad for node in trace[::-1]: SchedulePool().add_node(node.mirror) From a921e60ba7de8f46f2677eefab3d0df9a0c7c4d8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Jun 2022 11:39:06 +0800 Subject: [PATCH 0862/1892] parser support module list --- cube/graph/function/function.py | 15 +++++++++++++++ cube/graph/parser/mapping.py | 2 ++ cube/graph/parser/parser.py | 31 ++++++++++++++++++++----------- cube/runtime/function/function.py | 8 +++++--- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index b710080c..24ed5b26 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -813,6 +813,21 @@ def Repeat(signature, inputs:Tuple[IRTensor, List[int]]): return IRRepeat(signature, [tensor], 'repeat', repeats) +def Embedding(signature, inputs: List): + """ + torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False) + """ + signature = 'cube.runtime.function.embedding' + itensor, weight = inputs[:2] + padding_idx = inputs[3] + start, stop = 0, weight.shape[0] + letters = iter(string.ascii_lowercase) + ishapes = [_create_eshape(itensor.shape, letters), _create_eshape(weight.shape, letters)] + oshapes = [ishapes[0] + [ishapes[1][-1]]] + anno = _create_anno(ishapes, oshapes) + return IREinops(signature, [anno], [itensor, weight], 'embedding', padding_idx=padding_idx, start=start, stop=stop) + + def ScriptEinOps(signature, inputs): """ apply_for_scriptable_torch(recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str) -> torch.Tensor: diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 37d3aa34..c7813bef 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -68,6 +68,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('layer_norm'): function.LayerNorm, + __ftemplate('embedding'): function.Embedding, + # __ftemplate('layer_norm'): function.LayerNorm, # torch aten diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 72892f38..53315564 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -226,6 +226,8 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: input_vals.append(val) # map to IR operator + if 'torch' not in fsig: # indicate a customized operator + fsig = fsig.split('.')[-1] ir_node = Sign2Op.map(fsig)(inputs=input_vals) # push output in the frame @@ -344,7 +346,9 @@ def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: if self_module: call_module = module else: - call_module = getattr(module, node.inputsAt(0).debugName()) + call_module = frame.get_var(node.inputsAt(0).debugName()) + assert isinstance(call_module, torch.nn.Module), "the call module is not torch.nn.Module" + # call_module = getattr(module, node.inputsAt(0).debugName()) frame.push_attr() call_method = getattr(call_module, label) @@ -369,7 +373,9 @@ def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: Parse script module node like: %2 :__torch__.torch.nn.modules.linear.___torch_mangle_0.Linear = prim::GetAttr[name="linear1"](%self) %3 : Tensor = prim::GetAttr[name="weight"](%self) - The __torch__.torch.nn.modules.* will be ignored + Or: + %embed.1 : __torch__.torch.nn.modules.sparse.Embedding = prim::GetAttr[name="embed"](%self) + %embed.3 : Tensor = prim::CallMethod[name="forward"](%embed.1, %input_ids.1) This will add frame with the variable name and it's value @@ -382,8 +388,10 @@ def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: Empty list """ global _refmodule - if node.inputsAt(0).debugName() != 'self': - raise RuntimeError(f"Fail to parse {node} due to missing %self") + + module_name = node.inputsAt(0).debugName() + module = module if module_name == 'self' else frame.get_var(module_name) + assert isinstance(module, torch.nn.Module) label = node.s('name') var_name = node.outputsAt(0).debugName() @@ -414,18 +422,19 @@ def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: val = 'self.' + label else: val = getattr(module, label) - # print(f'get: var_name {var_name}: {val}') frame.add_var(var_name, val) # NoneType elif dtype == 'NoneType': frame.add_var(var_name, None) - # module name or other things cannot handle - elif dtype == '__torch__.einops.einops.TransformRecipe': - recipe = getattr(module, label) - frame.add_var(var_name, recipe) else: - # print("### parse_prim_attr_node unknown: {}".format(dtype)) - frame.add_var(var_name, label) + if isinstance(module, torch.nn.ModuleList): + if str.isdecimal(label): + val = module[int(label)] + else: + val = getattr(module, label) + else: + val = getattr(module, label) + frame.add_var(var_name, val) return list() @staticmethod diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 55bc5c34..4900ca3b 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -40,7 +40,8 @@ def conv3d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso input = TorchF.pad(input, pad_padding, 'constant', 0) return TorchF.conv3d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) -def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): + +def embedding(input: torch.Tensor, weight: torch.Tensor, padding_idx: Optional[int], start: int, stop: int): """ Embedding @@ -58,12 +59,13 @@ def embedding(input: torch.Tensor, weight: torch.Tensor, start: int, stop: int): masked_input = input.clone() - start masked_input[input_mask] = 0 output = TorchF.embedding( - masked_input, weight, - None, None, 2.0, False, False + masked_input, weight, padding_idx, + None, 2.0, False, False ) output[input_mask, :] = 0.0 return output + def einops(input: torch.Tensor, recipe_str, reduction_type: str): import pickle recipe = pickle.loads(recipe_str) From 93a866eb025b5d4c7c362ebb9e85065869556d96 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Jun 2022 11:39:28 +0800 Subject: [PATCH 0863/1892] support gpt training --- examples/nlp/blocks/attention.py | 12 ++++++------ examples/nlp/blocks/encoder.py | 6 ++++-- examples/nlp/blocks/mlp.py | 1 - examples/nlp/gpt/policy/naive.py | 8 ++++++++ examples/nlp/gpt/train.py | 11 ++++++++--- 5 files changed, 26 insertions(+), 12 deletions(-) create mode 100644 examples/nlp/gpt/policy/naive.py diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 4798ed45..bf9d9471 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -2,7 +2,7 @@ import cube -@cube.graph.parser.register('L^ N E^, H+ E^, H+, H+ E^, H+, H+ E^, H+, E^ H+, E^ -> L^ N E') +@cube.graph.parser.register('L^ N E^, H+ E^, H+, H+ E^, H+, H+ E^, H+, E^ H+, E^ -> L^ N E^') def self_attention(query: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, @@ -30,10 +30,10 @@ def self_attention(query: torch.Tensor, if mask: # (N h) L L -> (N h) L L attn = attn.view(N, num_head, L, L) ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) + amask = torch.tril(ones) + amask = amask.view(N, 1, L, L) + amask = (amask < 0.5) + attn = attn.masked_fill_(amask, -10000.0) attn = attn.view((N * num_head), L, L) attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L @@ -45,7 +45,7 @@ def self_attention(query: torch.Tensor, return output -@cube.graph.parser.register('L^ N E^, L^ N E^, H+ E^, H+, H+ E^, H+, H+ E^, H+, E^ H+, E^ -> L^ N E') +@cube.graph.parser.register('L^ N E^, L^ N E^, H+ E^, H+, H+ E^, H+, H+ E^, H+, E^ H+, E^ -> L^ N E^') def cross_attention(query: torch.Tensor, key: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 83645f61..1e082fd9 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -1,6 +1,7 @@ import torch from examples.nlp.blocks.attention import MultiHeadSelfAttention from examples.nlp.blocks.mlp import MLP +import warnings class EncoderLayer(torch.nn.Module): @@ -16,17 +17,18 @@ def __init__(self, embed_dim: int, num_heads: int, self.dropout = torch.nn.Dropout(p=dropout) self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + warnings.warn('residual is disabled in encoder block') def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = self.self_attn_layer_norm(x) x = self.self_attn(x) x = self.dropout(x) - x = x + residual + # x = x + residual residual = x x = self.final_layer_norm(x) x = self.mlp(x) x = self.dropout(x) - x = x + residual + # x = x + residual return x diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py index 1f9472ea..961f7214 100644 --- a/examples/nlp/blocks/mlp.py +++ b/examples/nlp/blocks/mlp.py @@ -18,7 +18,6 @@ class MLP(torch.nn.Module): def __init__(self, embed_dim: int, hidden_dim: int, dropout: float, bias=True): super().__init__() - print((hidden_dim, embed_dim)) self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) diff --git a/examples/nlp/gpt/policy/naive.py b/examples/nlp/gpt/policy/naive.py new file mode 100644 index 00000000..10de3596 --- /dev/null +++ b/examples/nlp/gpt/policy/naive.py @@ -0,0 +1,8 @@ +from cube.graph import IRGraph + +def PAS(graph: IRGraph, resource): + # print(graph.extra_repr()) + for node in graph.nodes(): + graph.assign(node, 0) + # print(graph.extra_repr()) + return graph \ No newline at end of file diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index d10676b1..6c0000fb 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -17,29 +17,34 @@ from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary, model_summary +from examples.nlp.gpt.policy.naive import PAS + def train(): batch_size = 1 - model = GPT().cuda() + model = GPT() dataloader = GPTDataLoader(batch_size) optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) print_each_rank('model weight consumpition:') memory_summary() + model = cube.SemanticModel(model, dataloader.shapes) + @cube.compile(model, dataloader, PAS=PAS, override=True) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) loss.backward() + model = model.get_gen_module() CudaTimer(enable=False).warmup() iter_num = 64 for step in range(iter_num): - if step == 0: - model_summary(model, next(dataloader)) + # if step == 0: + # model_summary(model, next(dataloader)) if step >= 20: CudaTimer(enable=True).start('e2e') From bcef5158d648ab6de512013332c78284bd417084 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 16 Jun 2022 11:28:16 +0000 Subject: [PATCH 0864/1892] Merged PR 1385: Make WRF fake loop upper bounds adjustable Now setting fake UBs to -1 will recover the original loop size. Otherwise any (>=0) UBs will enable the fake size. This is enabled by unrolling `prim::If` if and only if the condition is a static bool. --- cube/graph/parser/parser.py | 65 ++++++++++++++++++++++++++++++++++--- examples/wrf/wrf2.py | 42 ++++++++++++------------ 2 files changed, 82 insertions(+), 25 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 53315564..7d9e3226 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -191,6 +191,8 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] return ScriptModuleParser.parse_prim_list_unpack_node(node, module, frame) if node_type == ScriptNodeKind.PrimPythonOp: return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) + if node_type == ScriptNodeKind.PrimIf: + return ScriptModuleParser.parse_prim_if_node(node, module, frame) if node_type == ScriptNodeKind.PrimLoop: return ScriptModuleParser.parse_prim_loop_node(node, module, frame) @@ -468,13 +470,68 @@ def parse_prim_constant_node(node, module, frame) -> List[None]: def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: """ Parse script module node like - %output2 : Tensor = prim::If(%15) # /tmp/ipykernel_27188/2459450745.py:13:8 + %output1 : Tensor, %output2 : Tensor = prim::If(%15) # /tmp/ipykernel_27188/2459450745.py:13:8 block0(): - -> (%output1.1) + -> (%1, %2) block1(): - -> (%output2.1) + -> (%3, %4) + + and the only input (e.g. %15) must be of type bool. """ - raise NotImplementedError("Dynamic Graph is not supported yet") + + inputs : List[torch._C.Value] = list(node.inputs()) + outputs : List[torch._C.Value] = list(node.outputs()) + + assert len(inputs) == 1 + in_val = frame.get_var(inputs[0].debugName()) + if not isinstance(in_val, bool): + raise RuntimeError("Dynamic Graph is not supported yet") + + # type: torch._C.Block + true_block, false_block = node.blocks() + chosen_block : torch._C.Block = true_block if in_val else false_block + body_out_vars = list(chosen_block.outputs()) + + all_ir_nodes : List[IRFwOperation] = [] + + # Evaluate the 'eval_block' in a new frame, to isolate within-block variables from + # polluting the current frame. And we'll manually bind all resultant variables later on. + frame.push_var(inherit_from_top=True) + + # prim::If's blocks do not have any subgraph parameters, directly evaluate the body + for subnode in chosen_block.nodes(): + subnode : torch._C.Node + + sub_ir_nodes : List[IRFwOperation] = ScriptModuleParser.parse_node(subnode, module, frame) + + for ir_node in sub_ir_nodes: + try: + ret = ir_node.infer_shape() + if not ret: + print(f'warning: {ir_node} cannot infer shape') + except Exception: + raise RuntimeError( + f"====== Shape Infer Error ====\n\n\n" + f"IR Node: {ir_node}\n\n" + f"Module:\n{module.code}\n\n" + f"Node:\n{node}\n" + f"====== Shape Infer Error ====\n\n\n" + ) + + all_ir_nodes += sub_ir_nodes + + # retrieve the block's resultant values + result_vals = [frame.get_var(body_out_var.debugName()) for body_out_var in body_out_vars] + + # clean up + frame.pop_var() + + # bind the prim:If's resultant variables + assert len(result_vals) == len(outputs) + for output, out_val in zip(outputs, result_vals): + frame.add_var(output.debugName(), out_val) + + return all_ir_nodes @staticmethod def parse_prim_loop_node(node, module, frame: Frame) -> List[IRFwOperation]: diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py index e79ff9ff..56766f08 100644 --- a/examples/wrf/wrf2.py +++ b/examples/wrf/wrf2.py @@ -32,12 +32,17 @@ def __init__(self, dt, ntau, nz, ny, nx, dz, dy, dx, device): self.device = torch.device(device) - # TODO remove these testing parameters - # These three are to control the size of the unrolled graph, and they are related to the three layers of the nested loops, respectively. - # The magnitude is almost decided by `ntau` only. + # These three fields are to control the size of the unrolled graph + # by faking the loop upper bounds (UB), + # and they are related to the three layers of the nested loops, respectively. + # + # By setting to -1 we can recover the original loop upper bound. + # + # NOTE The magnitude is almost decided by `ntau` only. The final graph size may vary + # from ~4k (all fake UBs are 1) to ~23k (all fake UBs are -1, i.e. the original) self._step_fake_ntau = 1 - self._ac_step_fake_ub = 2 - self._solver_fake_ub = 2 + self._ac_step_fake_ub = 1 + self._solver_fake_ub = 1 def init(self, theta, Ptop=250e2): @@ -137,8 +142,7 @@ def step(self, dtau:float, ntau:int, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W alpha = - self.dz(self.pzphi(phi)) / mu p = self.PREF * (self.RD * Theta / mu / self.PREF / alpha)**self.GAMMA - # TODO fake upper bound - ntau = self._step_fake_ntau + ntau = self._step_fake_ntau if self._step_fake_ntau >= 0 else ntau for i in range(ntau): U2, V2, W2, O2, Theta2, phi2, mu2, pi2 = \ @@ -187,17 +191,15 @@ def ac_step(self, dtau:float, O2_ = torch.zeros(O2.shape, device=O2.device) mu2_ = torch.zeros(mu2.shape, device=mu2.device) - # TODO fake upper bound - #for i in range(1, O2.shape[0] + 1): - for i in range(1, self._ac_step_fake_ub): + _ctrl_O2_ub = self._ac_step_fake_ub + 1 if self._ac_step_fake_ub >= 0 else O2.shape[0] + 1 + for i in range(1, _ctrl_O2_ub): sub = i * self.delta_z * dpi2 + \ (self.dx(self.px(U2_)) + self.dy(self.py(V2_)) - R_mu)[-i:].view( -1, self.ny, self.nx).sum(0) * self.delta_z O2_ = O2_.select_scatter(sub, dim=0, index=-i) - # TODO fake upper bound - #for i in range(mu2.shape[0]): - for i in range(1, self._ac_step_fake_ub): + _ctrl_mu2_ub = self._ac_step_fake_ub if self._ac_step_fake_ub >= 0 else mu2.shape[0] + for i in range(_ctrl_mu2_ub): mu2_ = mu2_.select_scatter(pi2, dim=0, index=i) # self.O2_ = O2_ @@ -401,9 +403,8 @@ def solve_tridiagonal_(self, # forward sweep - # TODO fake upper bound - #for i in range(1, d.shape[0]): - for i in range(1, self._solver_fake_ub): + _ctrl_d_ub = self._solver_fake_ub + 1 if self._solver_fake_ub >= 0 else d.shape[0] + for i in range(1, _ctrl_d_ub): w = l[i - 1] / d[i - 1] d_i = d[i] - w * u[i - 1] @@ -414,12 +415,11 @@ def solve_tridiagonal_(self, # backward substitution x = torch.zeros(b.shape, device=b.device) - x.select_scatter(b[-1] / d[-1], dim=0, index=-1) + x = x.select_scatter(b[-1] / d[-1], dim=0, index=-1) - # TODO fake upper bound - #for i in range(x.shape[0] - 2, -1, -1): - for i in range(1, self._solver_fake_ub): - x.select_scatter( (b[i] - u[i] * x[i + 1]) / d[i], dim=0, index=i) + _ctrl_x_range_start = self._solver_fake_ub - 1 if self._solver_fake_ub >= 0 else x.shape[0] - 2 + for i in range(_ctrl_x_range_start, -1, -1): + x = x.select_scatter( (b[i] - u[i] * x[i + 1]) / d[i], dim=0, index=i) return x From 7e8a6a81b693eea376152f5882d578d343342c86 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Jun 2022 20:55:25 +0800 Subject: [PATCH 0865/1892] add memory constraints --- cube/tetris/composer.py | 199 +++++++++++++++++++++++++--------------- 1 file changed, 125 insertions(+), 74 deletions(-) diff --git a/cube/tetris/composer.py b/cube/tetris/composer.py index f1cbaa11..d73dc278 100644 --- a/cube/tetris/composer.py +++ b/cube/tetris/composer.py @@ -4,7 +4,7 @@ Abstraction layer for microb-batch execution plan merge. """ -from typing import Any, Callable, Dict, List, Tuple, Optional, Union +from typing import Callable, Dict, List, Tuple, Optional, Union import numpy as np from enum import Enum import time @@ -19,9 +19,10 @@ class BType(Enum): FW = 'forward' BW = 'backward' - def __init__(self, mid: int, btype: BType): + def __init__(self, mid: int, btype: BType, mem: int = 1): self.mid: int = mid - self.type = btype + self.btype = btype + self.memory = abs(mem) if btype == Block.BType.FW else 0-abs(mem) # dependency track self.before: List[Block] = list() self.after: List[Block] = list() @@ -36,14 +37,16 @@ def add_dependency(before, after): after.before.append(before) def __repr__(self): - return f'f{self.mid}' if self.type == Block.BType.FW else f'b{self.mid}' + return f'f{self.mid}' if self.btype == Block.BType.FW else f'b{self.mid}' class PlanBase: def __init__(self, ndevs: int): - self.blocks: Dict[Tuple[Tuple[int], int], Block] = dict() - self.positions: Dict[int, Tuple[int, int]] = dict() + # (device, step) -> block + self.blocks: Dict[Tuple[int, int], Block] = dict() + # block id -> ((device,), step) + self.positions: Dict[int, Tuple[Tuple[int], int]] = dict() self.plan = np.zeros((ndevs, ndevs * 2), dtype=int) @property @@ -54,6 +57,25 @@ def ndevs(self): def nsteps(self): return self.plan.shape[1] + def memory(self, devid: Optional[int] = None) -> Union[List[int], int]: + """ + Get memory of the this plan + """ + if isinstance(devid, int): + memory = 0 + peak_mem = memory + for step, have_block in enumerate(self.plan[devid]): + have_block = have_block != 0 + if have_block: + memory += self.block(devid, step).memory + peak_mem = max(peak_mem, memory) + return peak_mem + else: + dev_peak_mem = [] + for devid in range(self.ndevs): + dev_peak_mem.append(self.memory(devid)) + return dev_peak_mem + def block(self, dev: int, step: int): """ Get block given a position @@ -76,7 +98,7 @@ def squeeze(self): """ remove redundant steps """ - execflag = np.sum(self.plan, axis=1) + execflag = np.sum(self.plan, axis=0) for idx in range(self.nsteps): if execflag[-idx-1] != 0: break @@ -121,7 +143,7 @@ def visualize(self, outfile=None): fontsize = [40] txts = list() def draw_block(block: Block, position: Tuple[Tuple[int], int], fontsize): - color = '#4472C4' if block.type == Block.BType.FW else '#ED7D31' + color = '#4472C4' if block.btype == Block.BType.FW else '#ED7D31' devs, step = position for dev in devs: rec = Rectangle((step, dev+0.5), 1, 1, color=color, ec='black', lw=1.5) @@ -314,6 +336,8 @@ def conflict(micros: List[MicroPlan], step: int) -> Dict[int, List[Tuple[MicroPl """ Get conflict blocks at `step`. Return the conflicted (MicroPlan, Block) grouped by device id + + This assumes micros are composable < step - 1 """ max_steps = max(micro.nsteps for micro in micros) for micro in micros: @@ -339,7 +363,10 @@ def premise(fn, ndevs: int, nmicros: int): return micros @staticmethod - def bfs_schedule(micros: List[MicroPlan], mem_opt=True, prune_symmetric=True): + def bfs_schedule(micros: List[MicroPlan], + mem_constraints: Union[int, Tuple[int]], + mem_opt=True, prune_symmetric=True): + mem_constraints = [mem_constraints] * micros[0].ndevs if isinstance(mem_constraints, int) else mem_constraints total_status = 1 micros.sort(key=lambda m: m.mid) block_hash = Composer.construct_hash(micros) if prune_symmetric else None @@ -353,8 +380,11 @@ def bfs_schedule(micros: List[MicroPlan], mem_opt=True, prune_symmetric=True): for micros in prev: # get and solve conflicts conflicts = SchedulePlan.conflict(micros, step) + if len(conflicts) == 0: + next.append(micros) + continue # input(f'conflicts: dev: {list(conflicts.keys())}, mids: {[[conf[0].mid for conf in c] for c in conflicts.values()]} | >>>') - for shifts in Composer.iter_shifts(conflicts, block_hash=block_hash): + for shifts in Composer.iter_shifts(micros, conflicts, step, mem_constraints, block_hash=block_hash): # print(f"step {step}: {shifts}") shifted_micros = [micro.copy() for micro in micros] for cblock in shifts: @@ -439,58 +469,89 @@ def to_same_step(micros: List[MicroPlan]): return micros @staticmethod - def iter_shifts(conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], + def memory(micros: List[MicroPlan], to_step: int) -> List[int]: + """ + Get memory consumption from of step [0, `to_step`) + """ + if to_step == 0: + return [0] * micros[0].ndevs + micros = [micro.copy() for micro in micros] + for micro in micros: + micro.expand_to(to_step-1) + micro.plan = micro.plan[:,:to_step] + for block in set(micro.blocks.values()): + devices, step = micro.position(block) + if step >= to_step: + for devid in devices: + del micro.blocks[(devid, step)] + del micro.positions[id(block)] + sched_plan = SchedulePlan(micros) + return sched_plan.memory() + + @staticmethod + def iter_shifts(micros: List[MicroPlan], + conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], + step: int, + memory_constraints: List[int], block_hash = Union[None, Callable]) -> List[Block]: """ Enumerate shifted blocks to resolve conflicts on step `step`. """ devs = list(conflicts.keys()) - prev_shifts: List[List[Block]] = [[],] - next_shifts: List[List[Block]] = [] - for dev in devs: - for shifts in prev_shifts: - cmicros = [c[0] for c in conflicts[dev]] - cblocks = [c[1] for c in conflicts[dev]] - # since a same block can be on multiple devices (e.g., tensor parallel) - # we need to remove shifted blocks if it is decided before - for sblock in shifts: - if sblock in cblocks: - idx = cblocks.index(sblock) - cblocks = cblocks[:idx] + cblocks[idx+1:] - cmicros = cmicros[:idx] + cmicros[idx+1:] - if len(cblocks) <= 1: - next_shifts.append(shifts) - continue + prev_keep: List[Dict[int, Block]] = [{devid: None for devid in devs}] + next_keep: List[Dict[int, Block]] = [] + + # memory constraints to decide keep candidates + keep_candidates: Dict[int, List[Tuple[MicroPlan, Block]]] = {devid: [] for devid in devs} + memory = Composer.memory(micros, step) + for devid in devs: + cmicro_blocks = conflicts[devid] + for cmicro, cblock in cmicro_blocks: + for cdev in cmicro.position(cblock)[0]: + if memory[cdev] + cblock.memory <= memory_constraints[cdev]: + keep_candidates[devid].append((cmicro, cblock)) + # print(memory, memory_constraints) - candidates = [] - if block_hash is not None: - gids = [block_hash(cblock) for cblock in cblocks] - for gid in set(gids): - gblocks = [cblock for cblock, cgid in zip(cblocks, gids) if cgid == gid] - gmids = [gblock.mid for gblock in gblocks] - idx = gmids.index(min(gmids)) - candidates.append(gblocks[idx]) + for dev in devs: + for keeps in prev_keep: + cmicros = [c[0] for c in keep_candidates[dev]] + cblocks = [c[1] for c in keep_candidates[dev]] + if keeps[dev] is None: + # get candidate by pruning the symetric block + candidates = [] + if block_hash is not None: + gids = [block_hash(cblock) for cblock in cblocks] + for gid in set(gids): + gblocks = [cblock for cblock, cgid in zip(cblocks, gids) if cgid == gid] + gmids = [gblock.mid for gblock in gblocks] + idx = gmids.index(min(gmids)) + candidates.append(gblocks[idx]) + else: + candidates = cblocks + for kblock in candidates: + idx = cblocks.index(kblock) + kmicro = cmicros[idx] + empty = True + for kdev in kmicro.position(kblock)[0]: + if kdev in keeps and keeps[kdev] is not None: + empty = False + break + if empty: + new_keeps = {devid: blk for devid, blk in keeps.items()} + for kdev in kmicro.position(kblock)[0]: + if kdev in new_keeps: + new_keeps[kdev] = kblock + next_keep.append(new_keeps) else: - candidates = cblocks - - # if prune_same_micro: - # if Composer.same_plans(cmicros, start_step=step): - # candidates = [candidates[0]] - - for kblock in candidates: - idx = cblocks.index(kblock) - # keep blocks on the idx while shifts the rest - nshifts = shifts + cblocks[:idx] + cblocks[idx+1:] - # if the reserved block executes on multiple devices, - # then the rest device must shift all other blocks - for odev in cmicros[idx].position(kblock)[0]: - if odev != dev and odev in conflicts: - for _, ocblock in conflicts[odev]: - if ocblock != cblocks[idx] and ocblock not in nshifts: - nshifts.append(ocblock) - next_shifts.append(nshifts) - prev_shifts, next_shifts = next_shifts, [] - for shifts in prev_shifts: + next_keep.append(keeps) + prev_keep, next_keep = next_keep, [] + + for keeps in prev_keep: + shifts = [] + for devid, block in keeps.items(): + for cmicro, cblock in conflicts[devid]: + if cblock != block and cblock not in shifts: + shifts.append(cblock) yield shifts @staticmethod @@ -514,7 +575,7 @@ def memory_opt_step(micros: List[MicroPlan], step: int): # find forward blocks for dev in devs: block = micro.block(dev, step) - if block.type != Block.BType.FW: + if block.btype != Block.BType.FW: continue if block not in fblocks: fblocks.append(block) @@ -628,8 +689,8 @@ def compose_1F1B(ndevs, nmicros): def search(ndevs, nmicros, visualize=False): # premise - # micros = Composer.premise(uniform_staging, ndevs, nmicros) - micros = Composer.premise(chimera_staging, ndevs, nmicros) + micros = Composer.premise(uniform_staging, ndevs, nmicros) + # micros = Composer.premise(chimera_staging, ndevs, nmicros) # micros = Composer.premise(mbart_staging, ndevs, nmicros) print('============== Premise ================') for idx, micro in enumerate(micros): @@ -641,7 +702,7 @@ def search(ndevs, nmicros, visualize=False): # search shift tic = time.time() - schedules = Composer.bfs_schedule(micros, mem_opt=True, prune_symmetric=True) + schedules = Composer.bfs_schedule(micros, 10, mem_opt=False, prune_symmetric=True) toc = time.time() print('search done. time {:.2f}s'.format(toc - tic)) @@ -650,11 +711,11 @@ def search(ndevs, nmicros, visualize=False): assert len(steps) == 1, f"got un-consistent step set: {steps}" nsteps = list(steps)[0] print(f'find {len(schedules)} step-optimal plans (step={nsteps})') - # for idx, schedule in enumerate(schedules): - # print(f'Schedule #{idx+1}:') - # print(schedule) - # if visualize: - # schedule.visualize(f'planlog/plan{idx+1}.png') + for idx, schedule in enumerate(schedules): + print(f'Schedule #{idx+1}:') + print(schedule) + if visualize: + schedule.visualize(f'planlog/plan{idx+1}.png') ndevs = 4 @@ -663,13 +724,3 @@ def search(ndevs, nmicros, visualize=False): # schedule = compose_1F1B(ndevs, nmicros) # schedule.visualize('out.png') search(ndevs, nmicros, visualize=False) - - # micros = mbart_staging(ndevs, nmicros) - # for idx, micro in enumerate(micros): - # print(f'microbatch #{idx}:') - # print(micro) - # - # micros[0].shift(micros[0].block(0, 0)) - # micros[0].shift(micros[0].block(0, 2)) - # micros[0].shift(micros[0].block(0, 5)) - # print(micros[0]) \ No newline at end of file From 13670854fcd3055e210aaa46114ab4fc0c7178aa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Jun 2022 13:38:28 +0800 Subject: [PATCH 0866/1892] update composer by considering memory constraints --- cube/tetris/composer.py | 200 ++++++++++++++++++++++++---------------- cube/tetris/solver.py | 62 ++++++++++--- 2 files changed, 174 insertions(+), 88 deletions(-) diff --git a/cube/tetris/composer.py b/cube/tetris/composer.py index d73dc278..9537876b 100644 --- a/cube/tetris/composer.py +++ b/cube/tetris/composer.py @@ -242,6 +242,25 @@ def copy(self): micro.positions.update(self.positions) return micro + def select(self, begin_step: int, stop_step: int): + """ + select micro plans of [begin, stop) + """ + micro = MicroPlan(self.mid, self.ndevs) + micro.expand_to(stop_step) + all_blocks = [] + for block in self.blocks.values(): + if block not in all_blocks: + all_blocks.append(block) + for block in all_blocks: + devs, step = self.position(block) + if begin_step <= step and step < stop_step: + for dev in devs: + micro.plan[dev, step] += 1 + micro.blocks[(dev, step)] = block + micro.positions[id(block)] = (devs, step) + return micro + def shift(self, block: Block, inplace=True): """ The primitive: shift a block by pushing one step later @@ -322,37 +341,79 @@ def __init__(self, micros: List[MicroPlan]): self.positions.update(micro.positions) @staticmethod - def composable(micros: List[MicroPlan]) -> bool: + def composable(micros: List[MicroPlan], mem_constraints: Optional[List[int]] = None) -> bool: + # check execution conflicts max_steps = max(micro.nsteps for micro in micros) for micro in micros: micro.expand_to(max_steps) plans = tuple(micro.plan for micro in micros) schedule = np.sum(np.stack(plans, axis=-1), axis=-1) devids = np.where(schedule > 1)[0] - return len(devids) == 0 + if len(devids) != 0: + return False + # check memory conflicts + if mem_constraints is not None: + mems = SchedulePlan(micros).memory() + for mem, bound in zip(mems, mem_constraints): + if mem > bound: + return False + return True @staticmethod - def conflict(micros: List[MicroPlan], step: int) -> Dict[int, List[Tuple[MicroPlan, Block]]]: + def conflict(micros: List[MicroPlan], step: int, memory_constraints: List[int]) -> Dict[int, List[Tuple[MicroPlan, Block]]]: """ Get conflict blocks at `step`. Return the conflicted (MicroPlan, Block) grouped by device id - This assumes micros are composable < step - 1 - """ + This assumes micros are composable for steps < step + """ + ndevs = micros[0].ndevs + # find memory conflicted blocks + mem_conflicts: Dict[int, List[Block]] = dict() + curr_memory = [] + for devid in range(ndevs): + mem = 0 + for t in range(step): + for micro in micros: + block = micro.block(devid, t) + if block is not None: + mem += block.memory + curr_memory.append(mem) + for devid in range(ndevs): + for micro in micros: + block = micro.block(devid, step) + if block is not None: + if curr_memory[devid] + block.memory > memory_constraints[devid]: + for dev in micro.position(block)[0]: + if dev not in mem_conflicts: + mem_conflicts[dev] = [] + if block not in mem_conflicts[dev]: + mem_conflicts[dev].append(block) + + # find execution conflicted blocks + exe_conflicts: Dict[int, List[Block]] = dict() max_steps = max(micro.nsteps for micro in micros) for micro in micros: micro.expand_to(max_steps) plans = tuple(micro.plan[:,step] for micro in micros) schedule = np.sum(np.stack(plans, axis=-1), axis=-1) - conflicts = dict() devids = np.where(schedule > 1)[0] for dev in devids: - conflicts[dev] = [] + exe_conflicts[dev] = [] for micro in micros: cblock = micro.block(dev, step) if cblock is not None: - conflicts[dev].append((micro, cblock)) - return conflicts + exe_conflicts[dev].append(cblock) + + # consistent device set + devices = set(list(exe_conflicts.keys()) + list(mem_conflicts.keys())) + for devid in devices: + if devid not in mem_conflicts: + mem_conflicts[devid] = [] + if devid not in exe_conflicts: + exe_conflicts[devid] = [] + # print(exe_conflicts, mem_conflicts) + return exe_conflicts, mem_conflicts class Composer: @@ -379,12 +440,12 @@ def bfs_schedule(micros: List[MicroPlan], print(f'solving step {step}, candidates {len(prev)}...') for micros in prev: # get and solve conflicts - conflicts = SchedulePlan.conflict(micros, step) - if len(conflicts) == 0: + exe_conflicts, mem_conflicts = SchedulePlan.conflict(micros, step, mem_constraints) + if len(exe_conflicts) == 0 and len(mem_conflicts) == 0: next.append(micros) continue # input(f'conflicts: dev: {list(conflicts.keys())}, mids: {[[conf[0].mid for conf in c] for c in conflicts.values()]} | >>>') - for shifts in Composer.iter_shifts(micros, conflicts, step, mem_constraints, block_hash=block_hash): + for shifts in Composer.iter_shifts(micros, exe_conflicts, mem_conflicts, block_hash=block_hash): # print(f"step {step}: {shifts}") shifted_micros = [micro.copy() for micro in micros] for cblock in shifts: @@ -395,14 +456,21 @@ def bfs_schedule(micros: List[MicroPlan], # for micro in shifted_micros: # print(f'microbatch #{micro.mid}:') # print(micro) - if SchedulePlan.composable(shifted_micros): + if SchedulePlan.composable(shifted_micros, mem_constraints): schedule = SchedulePlan(shifted_micros) schedules.append(schedule) if schedule.nsteps < opt_step: print(f'find fewer steps: {schedule.nsteps}') opt_step = min(opt_step, schedule.nsteps) else: - next.append(shifted_micros) + # pruning technique: discard plans that exceed opt_step + discard = False + for m in shifted_micros: + if m.nsteps > opt_step: + discard = True + break + if not discard: + next.append(shifted_micros) total_status += len(next) prev, next = next, [] step += 1 @@ -468,54 +536,28 @@ def to_same_step(micros: List[MicroPlan]): micro.expand_to(nsteps) return micros - @staticmethod - def memory(micros: List[MicroPlan], to_step: int) -> List[int]: - """ - Get memory consumption from of step [0, `to_step`) - """ - if to_step == 0: - return [0] * micros[0].ndevs - micros = [micro.copy() for micro in micros] - for micro in micros: - micro.expand_to(to_step-1) - micro.plan = micro.plan[:,:to_step] - for block in set(micro.blocks.values()): - devices, step = micro.position(block) - if step >= to_step: - for devid in devices: - del micro.blocks[(devid, step)] - del micro.positions[id(block)] - sched_plan = SchedulePlan(micros) - return sched_plan.memory() - @staticmethod def iter_shifts(micros: List[MicroPlan], - conflicts: Dict[int, List[Tuple[MicroPlan, Block]]], - step: int, - memory_constraints: List[int], - block_hash = Union[None, Callable]) -> List[Block]: + exe_conflicts: Dict[int, List[Block]], + mem_conflicts: Dict[int, List[Block]], + block_hash: Optional[Callable] = None) -> List[Block]: """ Enumerate shifted blocks to resolve conflicts on step `step`. """ - devs = list(conflicts.keys()) + devs = tuple(exe_conflicts.keys()) prev_keep: List[Dict[int, Block]] = [{devid: None for devid in devs}] next_keep: List[Dict[int, Block]] = [] - # memory constraints to decide keep candidates - keep_candidates: Dict[int, List[Tuple[MicroPlan, Block]]] = {devid: [] for devid in devs} - memory = Composer.memory(micros, step) + keep_candidates: Dict[int, List[Block]] = {devid: [] for devid in devs} for devid in devs: - cmicro_blocks = conflicts[devid] - for cmicro, cblock in cmicro_blocks: - for cdev in cmicro.position(cblock)[0]: - if memory[cdev] + cblock.memory <= memory_constraints[cdev]: - keep_candidates[devid].append((cmicro, cblock)) - # print(memory, memory_constraints) + for cblock in exe_conflicts[devid]: + if cblock not in mem_conflicts[devid]: + keep_candidates[devid].append(cblock) for dev in devs: for keeps in prev_keep: - cmicros = [c[0] for c in keep_candidates[dev]] - cblocks = [c[1] for c in keep_candidates[dev]] + cblocks = [c for c in keep_candidates[dev]] + cmicros = [micros[c.mid] for c in cblocks] if keeps[dev] is None: # get candidate by pruning the symetric block candidates = [] @@ -528,29 +570,33 @@ def iter_shifts(micros: List[MicroPlan], candidates.append(gblocks[idx]) else: candidates = cblocks - for kblock in candidates: - idx = cblocks.index(kblock) - kmicro = cmicros[idx] - empty = True - for kdev in kmicro.position(kblock)[0]: - if kdev in keeps and keeps[kdev] is not None: - empty = False - break - if empty: - new_keeps = {devid: blk for devid, blk in keeps.items()} + if len(candidates) == 0: + next_keep.append(keeps) + else: + for kblock in candidates: + idx = cblocks.index(kblock) + kmicro = cmicros[idx] + empty = True for kdev in kmicro.position(kblock)[0]: - if kdev in new_keeps: - new_keeps[kdev] = kblock - next_keep.append(new_keeps) + if kdev in keeps and keeps[kdev] is not None: + empty = False + break + if empty: + new_keeps = {devid: blk for devid, blk in keeps.items()} + for kdev in kmicro.position(kblock)[0]: + if kdev in new_keeps: + new_keeps[kdev] = kblock + next_keep.append(new_keeps) else: next_keep.append(keeps) prev_keep, next_keep = next_keep, [] for keeps in prev_keep: shifts = [] - for devid, block in keeps.items(): - for cmicro, cblock in conflicts[devid]: - if cblock != block and cblock not in shifts: + for devid in devs: + kblock = keeps[devid] + for cblock in exe_conflicts[devid] + mem_conflicts[devid]: + if kblock != cblock and cblock not in shifts: shifts.append(cblock) yield shifts @@ -687,11 +733,11 @@ def compose_1F1B(ndevs, nmicros): print(schedule) return schedule - def search(ndevs, nmicros, visualize=False): + def search(ndevs, nmicros, mem_constraints: int, visualize=False): # premise - micros = Composer.premise(uniform_staging, ndevs, nmicros) + # micros = Composer.premise(uniform_staging, ndevs, nmicros) # micros = Composer.premise(chimera_staging, ndevs, nmicros) - # micros = Composer.premise(mbart_staging, ndevs, nmicros) + micros = Composer.premise(mbart_staging, ndevs, nmicros) print('============== Premise ================') for idx, micro in enumerate(micros): print(f'microbatch #{idx}:') @@ -702,20 +748,20 @@ def search(ndevs, nmicros, visualize=False): # search shift tic = time.time() - schedules = Composer.bfs_schedule(micros, 10, mem_opt=False, prune_symmetric=True) + schedules = Composer.bfs_schedule(micros, mem_constraints=mem_constraints, mem_opt=False, prune_symmetric=True) toc = time.time() print('search done. time {:.2f}s'.format(toc - tic)) - steps = set(schedule.nsteps for schedule in schedules) assert len(steps) == 1, f"got un-consistent step set: {steps}" nsteps = list(steps)[0] print(f'find {len(schedules)} step-optimal plans (step={nsteps})') - for idx, schedule in enumerate(schedules): - print(f'Schedule #{idx+1}:') - print(schedule) - if visualize: - schedule.visualize(f'planlog/plan{idx+1}.png') + print(f'one solution:\n{schedules[0]}\n{schedules[0].memory()}') + # for idx, schedule in enumerate(schedules): + # print(f'Schedule #{idx+1}:') + # print(schedule) + # if visualize: + # schedule.visualize(f'planlog/plan{idx+1}.png') ndevs = 4 @@ -723,4 +769,4 @@ def search(ndevs, nmicros, visualize=False): # schedule = compose_1F1B(ndevs, nmicros) # schedule.visualize('out.png') - search(ndevs, nmicros, visualize=False) + search(ndevs, nmicros, mem_constraints=10, visualize=False) diff --git a/cube/tetris/solver.py b/cube/tetris/solver.py index 58a3633c..07f06d1b 100644 --- a/cube/tetris/solver.py +++ b/cube/tetris/solver.py @@ -194,6 +194,46 @@ def solve(self, decrease = True): toc = time.time() print('iterate all plans: {:.2f} seconds.'.format(toc-tic)) + def solve_mconstraints(self, memory: int, decrease=True): + global gsolver + tic = time.time() + min_step = max(len(blks) for blks in self._blocks) + max_step = self.nblocks + opt_step = max_step if decrease else min_step + + self.set_memory() + + # memory constraints + gsolver.push() + gsolver.add(self._mem <= memory) + # find optimal step + while True: + assert min_step <= opt_step and opt_step <= max_step, "out of step boundary. consider this as a bug." + gsolver.push() + gsolver.add(self._nsteps == opt_step) + if gsolver.check() == sat: + print(f'find scheduling plan in {opt_step} steps') + solution = gsolver.model() + self.set_solution(solution) + gsolver.pop() + if not decrease: break + else: + print(f'fail to find solution for {opt_step} steps') + gsolver.pop() + if decrease: + opt_step += 1 + break + opt_step = opt_step - 1 if decrease else opt_step + 1 + toc = time.time() + print('search time: {:.2f} seconds. find optimal step: {}'.format(toc-tic, opt_step)) + print('solution:') + print(self) + + tic = time.time() + self.iter_space(opt_step) + toc = time.time() + print('iterate all plans: {:.2f} seconds.'.format(toc-tic)) + def iter_space(self, nsteps: int, memory: int = None): """ @@ -269,8 +309,8 @@ def chimera_staging(ndevs: int, nmicros: int) -> SchedulePlan: sched = SchedulePlan(ndevs) assert nmicros % 2 == 0, "require microbatch# can be devided by 2" for mid in range(nmicros // 2): # V shape - fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}', mem=devid+1) for devid in range(ndevs)] - bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}', mem=devid+1) for devid in range(ndevs-1,-1,-1)] + fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}', mem=1) for devid in range(ndevs)] + bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}', mem=1) for devid in range(ndevs-1,-1,-1)] blocks = fblocks + bblocks for idx in range(ndevs * 2 - 1): Block.add_dependency(blocks[idx], blocks[idx+1]) @@ -279,8 +319,8 @@ def chimera_staging(ndevs: int, nmicros: int) -> SchedulePlan: sched.add_block(bblocks[ndevs-1-devid], devid) for mid in range(nmicros // 2): # ^ shape mid = mid + nmicros // 2 - fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}', mem=ndevs-devid) for devid in range(ndevs-1,-1,-1)] - bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}', mem=ndevs-devid) for devid in range(ndevs)] + fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}', mem=1) for devid in range(ndevs-1,-1,-1)] + bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}', mem=1) for devid in range(ndevs)] blocks = fblocks + bblocks for idx in range(ndevs * 2 - 1): Block.add_dependency(blocks[idx], blocks[idx+1]) @@ -303,8 +343,8 @@ def mbart_staging(ndevs: int, nmicros: int) -> SchedulePlan: for step in range(ndevs+2): if step in [0, ndevs // 2 + 1]: fdevid = bdevid = tuple(range(ndevs)) - fblock = Block(mid, Block.BType.FW, f'fe{step}{mid}devall', mem=4) - bblock = Block(mid, Block.BType.BW, f'be{step}{mid}devall', mem=4) + fblock = Block(mid, Block.BType.FW, f'fe{step}{mid}devall', mem=1) + bblock = Block(mid, Block.BType.BW, f'be{step}{mid}devall', mem=1) else: fdevid = bdevid = step - 1 if step < ndevs // 2 + 1 else step - 2 fblock = Block(mid, Block.BType.FW, f'f{mid}dev{fdevid}', mem=1) @@ -318,10 +358,10 @@ def mbart_staging(ndevs: int, nmicros: int) -> SchedulePlan: Block.add_dependency(blocks[idx], blocks[idx+1]) return sched - ndevs = 4 - nmicros = 4 + ndevs = 8 + nmicros = 8 - # sched = uniform_staging(ndevs, nmicros) + sched = uniform_staging(ndevs, nmicros) # sched = chimera_staging(ndevs, nmicros) - sched = mbart_staging(ndevs, nmicros) # ndev=4, nmicro=4 => solution: step=32 - sched.solve(decrease=True) + # sched = mbart_staging(ndevs, nmicros) # ndev=4, nmicro=4 => solution: step=30 + sched.solve_mconstraints(memory=ndevs, decrease=True) From f88f7b472257853a6abfbcd4e5f6cbffd4299801 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Jun 2022 16:20:59 +0800 Subject: [PATCH 0867/1892] gradient error log --- cube/ir/tensor.py | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 111e96e3..57f7cbc7 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -684,37 +684,42 @@ def __copy__(self): @property def grad(self) -> Optional[Union[IRTensor, float]]: """ - Get gradient of this tensor which is associated by a - forward cell + Get gradient of this tensor. Gradient can be: - - None: if the tensor doesn't require gradient - - 1.0: if the tensor is loss tensor - - IRSubTensor: if the tensor requires gradient and not the loss tensor + - None: the tensor doesn't require gradient + - 1.0: the tensor is loss tensor (scalar) + - IRSubTensor: the tensor requires gradient and is not the loss tensor (scalar) Gradient cannot be set and can only be inferred by its IRFullTensor. + The gradient will be lazy updated when its IRFullTensor gets + new consumed / produced tensors """ if not self._dirty_grad: return self._grad - assert isinstance(self._cell, IRCell), "No cell attached to this tensor." + assert isinstance(self.cell, IRCell), "No cell attached to this tensor." full_grad = self.parent.grad - # None indicate the tensor doesn't need grad. - # float means this tensor is a loss tensor if full_grad is None or isinstance(full_grad, float): self._grad = full_grad # this tensor is consumed - elif self in self._cell.inputs(): - ref_cell_ids = list() - for dst_cell in self.parent.consumers: - for input in dst_cell.inputs(): - if self.overlap(input) and dst_cell._id not in ref_cell_ids: - ref_cell_ids.append(dst_cell._id) - break - ref_times = len(ref_cell_ids) + elif self in self.cell.inputs(): + ref_consumers = list() + for consumer in self.parent.consumers: + for itensor in consumer.inputs(): + if self.overlap(itensor): + assert itensor == self, \ + "partial overlapping of consumed tensors is not supported during backward" + # replicated nodes will have same node id + if consumer._id not in ref_consumers: + ref_consumers.append(consumer._id) + # if one node has multiple same tensors, + # will consider them as one + break + ref_times = len(ref_consumers) if ref_times == 0: raise RuntimeError("Internal error: consumer doesn't have the operator attached to this tensor") - idx = ref_cell_ids.index(self._cell._id) + idx = ref_consumers.index(self.cell._id) grad = full_grad.select( indmap = self.indmap, valmap = (idx, ref_times), @@ -724,7 +729,7 @@ def grad(self) -> Optional[Union[IRTensor, float]]: self._dirty_grad = False return grad # this tensor is produced - elif self in self._cell.outputs(): + elif self in self.cell.outputs(): grad = full_grad.select( indmap = self.indmap, valmap = (0, 1), @@ -732,7 +737,7 @@ def grad(self) -> Optional[Union[IRTensor, float]]: ) self._grad = grad else: - raise RuntimeError("visiting a tensor grad that potentially generated by IRAdapter") + raise RuntimeError("Visit graidient of a tensor that is potentially generated by IRAdapter") self._dirty_grad = False self._requires_grad = False if full_grad is None else True return self._grad From 9728fd8f98bed9ab9445010a111f1688aabf5275 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Jun 2022 16:21:23 +0800 Subject: [PATCH 0868/1892] search space for guodong's work --- examples/gsearch/blocks.py | 157 +++++++++++++++++++++++++++ examples/gsearch/gpt/model.py | 103 ++++++++++++++++++ examples/gsearch/gpt/policy/naive.py | 8 ++ examples/gsearch/gpt/train.py | 74 +++++++++++++ 4 files changed, 342 insertions(+) create mode 100644 examples/gsearch/blocks.py create mode 100644 examples/gsearch/gpt/model.py create mode 100644 examples/gsearch/gpt/policy/naive.py create mode 100644 examples/gsearch/gpt/train.py diff --git a/examples/gsearch/blocks.py b/examples/gsearch/blocks.py new file mode 100644 index 00000000..05281ffb --- /dev/null +++ b/examples/gsearch/blocks.py @@ -0,0 +1,157 @@ +import torch +import cube +import warnings + + +@cube.graph.parser.register('L N E+, (h d) E+, (h d), (h d) E+, (h d), (h d) E+, (h d) -> N h L d, N h L d, N h L d') +def attn_qkv(query: torch.Tensor, + q_proj: torch.Tensor, q_bias: torch.Tensor, + k_proj: torch.Tensor, k_bias: torch.Tensor, + v_proj: torch.Tensor, v_bias: torch.Tensor, + h: int, scale: float): + L, N = query.size(0), query.size(1) + d = q_proj.size(0) // h + + q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + q = q.contiguous().view(L, (N * h), d) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * h), d) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * h), d) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + q = q.view(N, h, L, d) + k = k.view(N, h, L, d) + v = v.view(N, h, L, d) + return q, k, v + + +@cube.graph.parser.register('N h L d, N h L d -> N h L L') +def attn_score(q: torch.Tensor, k: torch.Tensor, h: int, mask: bool = True): + N, num_head, L, d = q.size() + assert num_head == h + q = q.view(-1, L, d) + k = k.view(-1, L, d) + k = k.transpose(1, 2) + attn = torch.bmm(q, k) + attn = attn.view(N, h, L, L) + # attention mask + if mask: + ones = torch.ones((N, L, L), device=attn.device) + amask = torch.tril(ones) + amask = amask.view(N, 1, L, L) + amask = (amask < 0.5) + attn = attn.masked_fill_(amask, -10000.0) + return attn + + +@cube.graph.parser.register('N h L^ L^ -> N h L^ L^') +def attn_softmax(attn: torch.Tensor): + N, h, L, L = attn.size() + attn = attn.view((N * h), L, L) + attn = torch.nn.functional.softmax(attn, dim=-1) + return attn.view(N, h, L, L) + + +@cube.graph.parser.register('N h L L, N h L d -> L N (h d)') +def attn_context(attn: torch.Tensor, v: torch.Tensor): + N, h, L, d = v.size() + attn = attn.view((N * h), L, L) + v = v.view((N * h), L, d) + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, h * d) # (N h) L d -> L N (h d) + return output + + +class MultiHeadSelfAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): + super().__init__() + self.inner_dim = inner_dim + self.num_heads = num_heads + self.head_dim = inner_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # Q + self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + # K + self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + # V + self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) + self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + + def forward(self, query: torch.Tensor): + # QKV + q, k, v = attn_qkv( + query, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.num_heads, self.scaling + ) + # AttentionScore + attn = attn_score(q, k, self.num_heads, mask=True) + # softmax + attn = attn_softmax(attn) + # dropout + attn = torch.nn.functional.dropout(attn, self.dropout_p, True, False) # (N h) L L -> (N h) L L + # context + context = attn_context(attn, v) + # DenseOutput + output = torch.nn.functional.linear(context, self.out_proj, self.out_bias) # L N (h d), E E -> L N E + return output + + +class MLP(torch.nn.Module): + + def __init__(self, embed_dim: int, hidden_dim: int, dropout: float): + super().__init__() + self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) + self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) + self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) + self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) + self.dropout = dropout + + def forward(self, x: torch.Tensor): + x = torch.nn.functional.linear(x, self.proj1, self.proj1_bias) + x = torch.nn.functional.gelu(x) + x = torch.nn.functional.linear(x, self.proj2, self.proj2_bias) + x = torch.nn.functional.dropout(x, self.dropout, True, False) + return x + + +class EncoderLayer(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, + attn_hidden_dim: int, ffn_hidden_dim: int, + dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): + super().__init__() + self.self_attn = MultiHeadSelfAttention( + embed_dim, num_heads, attn_hidden_dim, atten_dropout + ) + self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) + self.dropout = torch.nn.Dropout(p=dropout) + self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + warnings.warn('residual is disabled in encoder block') + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x) + x = self.dropout(x) + # x = x + residual + + residual = x + x = self.final_layer_norm(x) + x = self.mlp(x) + # x = x + residual + return x diff --git a/examples/gsearch/gpt/model.py b/examples/gsearch/gpt/model.py new file mode 100644 index 00000000..b8de4b12 --- /dev/null +++ b/examples/gsearch/gpt/model.py @@ -0,0 +1,103 @@ +import torch + + +from examples.gsearch.blocks import EncoderLayer + +import cube + + +class Config: + + num_embeddings = 50304 + seqlen = 512 + + # 1.7B model + embed_dim = 2304 + layers = 8 # 24 + attention_heads = 24 + + # 3.6B model + # embed_dim = 3072 + # layers = 32 + # attention_heads = 32 + + # 7.5B model + # embed_dim = 4096 + # layers = 32 + # attention_heads = 36 + + attn_hidden_dim = embed_dim + ffn_hidden_dim = embed_dim * 4 + dropout = 0.0 + attn_dropout = 0.0 + activation_dropout = 0.0 + + +class GPT(torch.nn.Module): + + def __init__(self): + super().__init__() + cfg = Config() + + self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) + self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) + self.embed_dropout = torch.nn.Dropout() + + self.layers = torch.nn.ModuleList( + [EncoderLayer( + cfg.embed_dim, cfg.attention_heads, + cfg.attn_hidden_dim, cfg.ffn_hidden_dim, + cfg.dropout, cfg.attn_dropout, cfg.activation_dropout + ) for _ in range(cfg.layers)] + ) + self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): + + embed = self.embed(input_ids) + pos_embed = self.position(position_ids) + embed = embed + pos_embed + embed = self.embed_dropout(embed) + enc = embed.transpose(0, 1) + + for layer in self.layers: + enc = layer(enc) + enc = self.final_layernorm(enc) + + logits = torch.nn.functional.linear(enc, self.embed.weight) + # simplified + loss = torch.sum(logits) + return loss + + +class GPTDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + + self.bs = batch_size + self.cfg = Config() + super().__init__( + shapes=([batch_size, self.cfg.seqlen], + [batch_size, self.cfg.seqlen], + ), + dtypes=(torch.int64, torch.int64), + batch_dims=(0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + input_ids = torch.randint( + 0, self.cfg.num_embeddings, + size=(self.bs, self.cfg.seqlen), + dtype=torch.int64, device=torch.cuda.current_device() + ) + position_ids = torch.arange( + 0, self.cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() + ).repeat(self.bs) + return (input_ids, position_ids) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] \ No newline at end of file diff --git a/examples/gsearch/gpt/policy/naive.py b/examples/gsearch/gpt/policy/naive.py new file mode 100644 index 00000000..10de3596 --- /dev/null +++ b/examples/gsearch/gpt/policy/naive.py @@ -0,0 +1,8 @@ +from cube.graph import IRGraph + +def PAS(graph: IRGraph, resource): + # print(graph.extra_repr()) + for node in graph.nodes(): + graph.assign(node, 0) + # print(graph.extra_repr()) + return graph \ No newline at end of file diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py new file mode 100644 index 00000000..bc6758f5 --- /dev/null +++ b/examples/gsearch/gpt/train.py @@ -0,0 +1,74 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/gsearch/gpt/train.py +""" + + +import torch + +from examples.gsearch.gpt.model import GPT +from examples.gsearch.gpt.model import GPTDataLoader + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary, model_summary + +from examples.nlp.gpt.policy.naive import PAS + + +def train(): + + batch_size = 1 + + model = GPT() + dataloader = GPTDataLoader(batch_size) + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + print_each_rank('model weight consumpition:') + memory_summary() + + model = cube.SemanticModel(model, dataloader.shapes) + @cube.compile(model, dataloader, PAS=PAS, override=True) + def train_iter(model, dataloader): + input_ids, position_ids = next(dataloader) + loss = model(input_ids, position_ids) + loss.backward() + model = model.get_gen_module() + + CudaTimer(enable=False).warmup() + iter_num = 64 + for step in range(iter_num): + + # if step == 0: + # model_summary(model, next(dataloader)) + + if step >= 20: + CudaTimer(enable=True).start('e2e') + + # training + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + if step >= 20: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('passed first iteration') + + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-40, field_name='e2e'))) + memory_summary() + + +if __name__ == '__main__': + + cube.init() + train() \ No newline at end of file From 568a300ab6ebf77c670da4df1e9225bc66d60c53 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Jun 2022 16:22:17 +0800 Subject: [PATCH 0869/1892] gelu for torch v1.11 --- cube/graph/function/function.py | 7 ++----- cube/graph/parser/mapping.py | 1 + 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 24ed5b26..3a67fffb 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -76,11 +76,7 @@ def Identity(signature, inputs): def Linear(signature, inputs): - if signature == 'torch.linear': - import warnings - warnings.warn(f'signature {signature} replaced into torch.nn.functional.linear') - signature = 'torch.nn.functional.linear' - + signature = 'torch.nn.functional.linear' annos = [ 'b * k+, n k+ -> b * n', # no bias 'b * k+, n k+, n -> b * n' # have bias @@ -472,6 +468,7 @@ def Cos(signature, inputs): def GeLU(signature, inputs): annos = ['* -> *'] + signature = 'torch.nn.functional.gelu' tensor = inputs[0:1] if len(inputs) == 2: # adapt for newest pytorch version diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index c7813bef..736ac8e1 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -63,6 +63,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('dropout') : function.Dropout, __ftemplate('gelu') : function.GeLU, + __ttemplate('gelu') : function.GeLU, __ftemplate('_pad'): function.Pad, From 8fd56e14fdf621aa49f84a2a6d3732fbbe1ff241 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Jun 2022 00:19:18 +0800 Subject: [PATCH 0870/1892] new eniops --- cube/algorithm/generics.py | 40 +- cube/algorithm/ops/conv.py | 53 +-- cube/algorithm/ops/dataloader.py | 24 +- cube/algorithm/ops/einops.py | 125 +++--- cube/algorithm/ops/pad.py | 18 +- cube/graph/function/__init__.py | 2 +- cube/graph/function/einops.py | 702 +++++++++++++++++++++---------- cube/graph/function/function.py | 298 +++++-------- cube/graph/graph.py | 6 +- cube/graph/parser/register.py | 8 +- cube/ir/tensor.py | 3 +- 11 files changed, 698 insertions(+), 581 deletions(-) diff --git a/cube/algorithm/generics.py b/cube/algorithm/generics.py index e6d9cef2..b5e52fd6 100644 --- a/cube/algorithm/generics.py +++ b/cube/algorithm/generics.py @@ -1,26 +1,14 @@ -from typing import Dict +from typing import List, Optional -from cube.ir.cten import IRCell, IRTensor +from cube.ir.cten import IRCell class GenericDistAlgo: + """! + Generic distributed algorithm that partitions a node into sub-nodes. + """ def __init__(self, node: IRCell): - """ - Layout is the community distribution requirement for input and - output logical tensors. - - Format is the dimension ordering based on the logical format, - `None` indicates the format is consistent with logical op, - otherwise should be a list of integers like torch.Tensor.permute() - on the logical required format. - - # TODO: - input_format (list[list[int], None]): - input dim order compare with logical definition - output_format (list[list[int], None]): - output dim order compare with logical definition - """ if not isinstance(node, IRCell): raise TypeError("Expected node to be IRCell") self._node = node @@ -29,14 +17,22 @@ def __init__(self, node: IRCell): def node(self) -> IRCell: return self._node - def satisfy(self, config: Dict): - """ + def satisfy(self, **config) -> bool: + """! Check if the config satisfies instantiation conditions + + @param config Dict: configuration for the algorithm, like number of partitioned chunks. + + @return satisfy bool: True if the configuration can satisfy for this node """ raise NotImplementedError - def instantiate(self, config: Dict): - """ + def instantiate(self, **config) -> Optional[List[IRCell]]: + """! Instantiate the algorithm given the config + + @param config Dict: configuration for the algorithm, like number of partitioned chunks. + + @return sub_nodes Optional[List[IRCell]]: if sucess, the partitioned sub nodes, else None """ - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 9fa8ce3e..03729aec 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -19,22 +19,17 @@ def __init__(self, node: IRConv2D): raise TypeError(f"Expect IRConv2D") super().__init__(node) - def satisfy(self, config: Dict): - """ - config = dict(idx=int, dim=int, num=num) + def satisfy(self, idx: int, dim: int, num: int): + """! + Dimension split on Conv2D operator: N iC H W, oC iC dH dW, oC -> N oC oH oW Splittable dimension: N, oC Reduce dimension: oC """ - for attr in ['idx', 'dim', 'num']: - if not attr in config: - raise KeyError("Expected idx, dim, num in the config") + assert all(isinstance(t, int) for t in [idx, dim, num]), "idx, dim and num should be integer" node: IRConv2D = self.node - idx: int = config['idx'] - dim: int = config['dim'] - num: int = config['num'] groups = node.kwargs['groups'] # split N: if (idx, dim) == (0, 0): @@ -46,14 +41,10 @@ def satisfy(self, config: Dict): if (idx, dim) == (0, 1) or (idx, dim) == (1, 1): return groups == 1 and node.inputs(1).shape[0] % 0 == num - def instantiate(self, config: Dict): - if not self.satisfy(config): + def instantiate(self, idx: int, dim: int, num: int): + if not self.satisfy(idx, dim, num): return False node: IRConv2D = self.node - idx: int = config['idx'] - dim: int = config['dim'] - num: int = config['num'] - inputs, weights, bias = list(), list(), list() outputs = list() # split N @@ -98,15 +89,10 @@ def __init__(self, node: IRConv2D): raise TypeError(f"Expect IRConv2D") super().__init__(node) - def satisfy(self, config: Dict): - for attr in ['idx', 'dim', 'num']: - if not attr in config: - raise KeyError("Expected idx, dim, num in the config") + def satisfy(self, idx: int, dim: int, num: int): + assert all(isinstance(t, int) for t in [idx, dim, num]), "idx, dim and num should be integer" node: IRConv2D = self.node oH, oW = node.outputs(0).shape[2:] - idx: int = config['idx'] - dim: int = config['dim'] - num: int = config['num'] stride = node.kwargs['stride'] dilation = node.kwargs['dilation'] if dim not in [2, 3]: @@ -123,16 +109,13 @@ def satisfy(self, config: Dict): if (idx, dim) == (0, 3): return oW % num == 0 - def instantiate(self, config: Dict): - if not self.satisfy(config): + def instantiate(self, idx: int, dim: int, num: int): + if not self.satisfy(idx, dim, num): return None node: IRConv2D = self.node H, W = node.inputs(0).shape[2:] dH, dW = node.inputs(1).shape[2:] oH, oW = node.outputs(0).shape[2:] - idx: int = config['idx'] - dim: int = config['dim'] - num: int = config['num'] groups = node.kwargs['groups'] stride = node.kwargs['stride'] padding = node.kwargs['padding'] @@ -208,15 +191,10 @@ def __init__(self, node: IRConv3D): raise TypeError(f"Expect IRConv2D") super().__init__(node) - def satisfy(self, config: Dict): - for attr in ['idx', 'dim', 'num']: - if not attr in config: - raise KeyError("Expected idx, dim, num in the config") + def satisfy(self, idx: int, dim: int, num: int): + assert all(isinstance(t, int) for t in [idx, dim, num]), "idx, dim and num should be integer" node: IRConv3D = self.node oD, oH, oW = node.outputs(0).shape[2:] - idx: int = config['idx'] - dim: int = config['dim'] - num: int = config['num'] stride = node.kwargs['stride'] dilation = node.kwargs['dilation'] if dim not in [2, 3]: @@ -233,16 +211,13 @@ def satisfy(self, config: Dict): if (idx, dim) == (0, 3): return oW % num == 0 - def instantiate(self, config: Dict): - if not self.satisfy(config): + def instantiate(self, idx: int, dim: int, num: int): + if not self.satisfy(idx, dim, num): return None node: IRConv3D = self.node D, H, W = node.inputs(0).shape[2:] dD, dH, dW = node.inputs(1).shape[2:] oD, oH, oW = node.outputs(0).shape[2:] - idx: int = config['idx'] - dim: int = config['dim'] - num: int = config['num'] groups = node.kwargs['groups'] stride = node.kwargs['stride'] padding = node.kwargs['padding'] diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py index 2f5d8386..9eec536b 100644 --- a/cube/algorithm/ops/dataloader.py +++ b/cube/algorithm/ops/dataloader.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import List import copy from cube.algorithm.utils import split_axis @@ -14,17 +14,14 @@ def __init__(self, node: IRDataOperation): raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") super().__init__(node) - def satisfy(self, config: Dict): + def satisfy(self, num: int): """ - config = dict(dim=int) - num: int - number of chunks to partition + Check whether the condition satisfies. + + @param num int: number of chunks to partition """ - for attr in ['num']: - if not attr in config: - raise KeyError("Expected idx, dim, num in the config") + node: IRDataOperation = self.node - num: int = config['num'] dims: List[int] = node.get_batch_dims() # check batch size all_batch_size = set([output.shape[dim] for dim, output in zip(dims, node.outputs())]) @@ -36,13 +33,12 @@ def satisfy(self, config: Dict): return False return True - def instantiate(self, config: Dict): - if not self.satisfy(config): - raise RuntimeError("Instantiate failed. Condition not satisfied.") + def instantiate(self, num: int): + if not self.satisfy(num): + return False node: IRDataOperation = self.node - num: int = config['num'] dims: List[int] = node.get_batch_dims() - + outputs = list() for dim, output in zip(dims, node.outputs()): output = split_axis(output, dim, num) diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py index faacbbbb..29e43103 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/einops.py @@ -1,9 +1,9 @@ -from typing import List, Dict +from typing import List, Dict, Optional from cube.algorithm.utils import split_axis, split_value from cube.algorithm.generics import GenericDistAlgo -from cube.graph.function import IREinops, EinDim +from cube.graph.function.einops import IREinops, DimAnno class DimSplitEinops(GenericDistAlgo): @@ -21,73 +21,78 @@ def __init__(self, node: IREinops): if not isinstance(node, IREinops): raise TypeError(f"Expect IREinops") super().__init__(node) + self._adim: str = None + self._reduce: DimAnno.ReduceType = None - def satisfy(self, config: Dict): + def satisfy(self, idx: int, dim: int, num: int) -> bool: """ - config = dict(idx=int, dim=int) + Check whether the condition satisfies. - idx: int - input index - dim: int - dimension of index-th input - num: int - number of chunks to partition + @param idx int: input index + @param dim int: input dimension + @param num int: chunks to partition the dimension + + @return satisfy bool: true if can be partitioned, elsewise false. """ - for attr in ['idx', 'dim', 'num']: - if not attr in config: - raise KeyError("Expected idx, dim, num in the config") + assert all(isinstance(cond, int) for cond in [idx, dim, num]), "expect int condition" node: IREinops = self.node - idx: int = config['idx'] - dim: int = config['dim'] - num: int = config['num'] - if not (isinstance(idx, int) and abs(idx) < len(node.inputs())): + + ninputs = len(node.inputs()) + if idx >= ninputs or idx < 0-ninputs: return False if node.inputs(idx).shape is None or abs(dim) >= len(node.inputs(idx).shape): return False - if node.inputs(idx).shape[dim] % num != 0: + # due to implementation limits, we only partition the first annotated dimension + # for inner-dimension cases. + self._adim: str = node.anno.inputs(idx).dims[dim].identifiers[0] + self._reduce: DimAnno.ReduceType = node.anno.inputs(idx).dims[dim].reduces[0] + dimlen = node.anno.getlen(self._adim) + if self._reduce == DimAnno.ReduceType.Freeze: return False - if node._iannos[idx][dim].reduce == EinDim.ReduceType.Stay: + if dimlen % num != 0: return False return True - def instantiate(self, config: Dict) -> List[IREinops]: - if not self.satisfy(config): + def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IREinops]]: + if not self.satisfy(idx, dim, num): return False node: IREinops = self.node - idx: int = config['idx'] - dim: int = config['dim'] - num: int = config['num'] - edim: EinDim = node._iannos[idx][dim] - - # print(f'splitting: {node.einexpr()}') + print(node.anno, f'partition: {self._adim}; reduce: {self._reduce.value}') ins, ous = list(), list() - for iidx, input in enumerate(node.inputs()): - if edim in node._iannos[iidx]: - dim = node._iannos[iidx].index(edim) - sub_tensors = split_axis(input, dim, num) + for iidx, itensor in enumerate(node.inputs()): + shape_anno = node.anno.inputs(iidx) + split_dims = shape_anno.getdims(self._adim) + assert len(split_dims) <= 1, f"find split dims ({self._adim}) more than 1: {shape_anno}" + if len(split_dims) == 1: + dim = split_dims[0] + # split axis + sub_tensors = split_axis(itensor, dim, num) ins.append(sub_tensors) else: - if edim.reduce[0] == EinDim.ReduceType.Sum: - # print(f'Warning: value split on one input tensor in node{node._id}:{node.name} as reduce axis {axis} not appeared.') - ins.append(split_value(input, num)) + # replicate if no split dimension of this tensor + # ins.append([itensor] * num) + # ad-hoc FIXME: for linear function Ax+b of splitting reduction dimension, b should + # be splitted by value dimension. + if self._reduce == DimAnno.ReduceType.Sum: + ins.append(split_value(itensor, num)) else: - ins.append([input] * num) - for oidx, output in enumerate(node.outputs()): - # split on the non-reduce axis, the output value keeps same - # but the output shape gets splitted - if edim in node._oannos[oidx]: - dim = node._oannos[oidx].index(edim) - if edim.reduce[0] == EinDim.ReduceType.Sum: - raise RuntimeError(f"Reduced axis {dim} appeared in output") - sub_tensors = split_axis(output, dim, num) + ins.append([itensor] * num) + + for oidx, otensor in enumerate(node.outputs()): + shape_anno = node.anno.outputs(oidx) + split_dims = shape_anno.getdims(self._adim) + assert len(split_dims) <= 1, f"find split dims ({self._adim}) more than 1: {shape_anno}" + # split axis + if self._reduce == DimAnno.ReduceType.Dim: + assert len(split_dims) == 1, f"expect only one spatial dimension in output tensor but got {len(split_dims)}" + dim = split_dims[0] + sub_tensors = split_axis(otensor, dim, num) ous.append(sub_tensors) - # split on the reduce axis, the output shape keeps same - # but the output value get splitted + # split numerical dimension else: - if edim.reduce[0] != EinDim.ReduceType.Sum: - raise RuntimeError(f"Expect axis {edim} to be reduced axis") - sub_tensors = split_value(output, num) + assert len(split_dims) == 0, f"expect no numerical dimension in output tensor but got {len(split_dims)}" + sub_tensors = split_value(otensor, num) ous.append(sub_tensors) sub_nodes = list() @@ -98,27 +103,3 @@ def instantiate(self, config: Dict) -> List[IREinops]: sub_node.infer_shape() sub_nodes.append(sub_node) return sub_nodes - - def space(self, num_device: int) -> List[Dict[str, int]]: - """ - Return a list of possible configurations - given the number of devices - """ - possible_idx = list() - possible_dim = list() - num = num_device - dims = list() - node: IREinops = self.node - for idx, eindims in enumerate(node._ieins): - for dim, eindim in enumerate(eindims): - if eindim.reduce != EinDim.ReduceType.Stay: - if eindim not in dims: - dims.append(eindim) - possible_idx.append(idx) - possible_dim.append(dim) - possible_configs = list() - for idx, dim in zip(possible_idx, possible_dim): - config = dict(idx=idx, dim=dim, num=num) - if self.satisfy(config): - possible_configs.append(config) - return possible_configs diff --git a/cube/algorithm/ops/pad.py b/cube/algorithm/ops/pad.py index 50890907..f325d7c5 100644 --- a/cube/algorithm/ops/pad.py +++ b/cube/algorithm/ops/pad.py @@ -1,6 +1,4 @@ -from typing import Dict - -from cube.algorithm.utils import split_axis, split_axis_custom, split_value +from cube.algorithm.utils import split_axis, split_axis_custom from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.pad import IRPad @@ -15,17 +13,13 @@ def __init__(self, node: IRPad): raise TypeError(f"Expect IRConv2D") super().__init__(node) - def satisfy(self, config: Dict): + def satisfy(self, dim: int, num: int): """ config = dict(idx=int, dim=int, num=num) """ - for attr in ['dim', 'num']: - if not attr in config: - raise KeyError("Expected dim, num in the config") + assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" node: IRPad = self.node - dim: int = config['dim'] - num: int = config['num'] pad = node.kwargs['pad'] mode = node.kwargs['mode'] value = node.kwargs['value'] @@ -40,12 +34,10 @@ def satisfy(self, config: Dict): dim_in_pad = len(node.inputs(0).shape) - 1 - dim return (node.inputs(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) % num == 0 - def instantiate(self, config: Dict): - if not self.satisfy(config): + def instantiate(self, dim: int, num: int): + if not self.satisfy(dim, num): return False node: IRPad = self.node - dim: int = config['dim'] - num: int = config['num'] pad = node.kwargs['pad'] mode = node.kwargs['mode'] value = node.kwargs['value'] diff --git a/cube/graph/function/__init__.py b/cube/graph/function/__init__.py index 93f14214..8dede842 100644 --- a/cube/graph/function/__init__.py +++ b/cube/graph/function/__init__.py @@ -1,2 +1,2 @@ -from cube.graph.function.einops import EinDim, IREinops +from cube.graph.function.einops import IREinops from cube.graph.function.function import * \ No newline at end of file diff --git a/cube/graph/function/einops.py b/cube/graph/function/einops.py index a33dc445..aad441a0 100644 --- a/cube/graph/function/einops.py +++ b/cube/graph/function/einops.py @@ -1,59 +1,70 @@ """ -This operator class is highly inspired by einops. +Dimension Annotion Operations. -* Annotating Dimensions: +An operator has (multiple) input tensors and (multiple) output tensors. +Each tensor can be annotated with dimension annotations (DimAnno) using `identifiers`. +The same `identifier` indicates the they have the same real length. + +* Dimension Annotation: e.g., 'a+', 'ab^', 'cd', '(ab+ c^ d)', '64' -A dimension of a tensor can be annotated by {identifier}{reduce} template. +A dimension of a tensor can be annotated by {identifier}{reduction} template. An `identifier` must be one of: 1) symbolic annotation that must match with the criteria of python str.isidentifier. - 2) numeric string that must match with python str.isnumeric. This indicates the shape is the same value - numeric string will always have '^' reduction type - 3) '*': this special value indicates the dimension is dynamic will automatically get expanded given the shape + 2) numeric string that must match with python str.isdecimal. This indicates the shape is the same value + numeric string will always have '^' reduction type' + +Special identifier: + 1) '*': this special identifier indicates the dimension is dynamic, which will automatically get expanded given the shape + 2) '?': this special identifier indicates the value is not a tensor, which will be ignored + +A `reduction` can be a set of {'', '+', '^'}: + '' indicates this dimension can be partitioned, and each output should have this dimension. + '+' indicates this dimension can be partitioned, and each ouutput doesn't have this and need to do sum-reduction. + '^' means this dimension cannot be partitioned. + +A dimension can also be annotated with inner-dimensions using brackets, i.e., '(' and ')'. +The value of inner dimension needs to be inferrable, or indicated by function args (of same name). -A `reduce` can be a set of {'', '+', '^'}: - '' indicates this dimension will apear in output. - '+' indicates no this dimension will be reduced in output using sume - '^' means this dimension is out of scope, Einops will not handle this (cannot do split on it) +* Shape Annotation: -A complex annotation for a dimension is using brackets, i.e., '(' and ')', to include -more inner-dimensions. The value of inner dimension must be (partially) indicated by function args (of same name) -so that letting system know (infer). + e.g., 'a (c+ d^) e' -* Annotating Operator: +A shape annotation consists of dimension annotation separated by (multiple) spaces. -e.g., 'm k+, n k+ -> m n', '4 k+, k+ d -> 8 d', '* d^, s -> * s' -An operator dimension can be annoted with input dimensions and output dimensions. -Same identifier indicates the same shape and semantically same dimension propagation. +* Operator Annotation: -'->' seperates the inputs (left) and outputs (right) and ',' separates each input and output. -A shape needs to be annotated using dimension annotations with delimiters of (mulitple) space ' '. + e.g., 'm k+, n k+ -> m n', '4 k+, k+ d -> 8 d', '* d^, s -> * s' -Dimension annotations in Output must apear in inputs, or using numeric string + An operator can be annotated with input shape annotations and output shape annotations. -* Splitting Rule: + '->' seperates the inputs (left) and outputs (right) and ',' separates each input and output tensor. -Spatial Splitting (dimension with '' reduce type): - tensors that have this dimension will be splitted spatially. - tensors that don't have this dimension will be replicated. + Identifiers in output tensor annotation needs to be + 1) apearred in input tensor annotations + 2) using numeric string -Numerical Splitting (dimension with '+' reduce type): - tensors that have this dimension will be splitted spatially, - tensors that don't have this dimension will be splitted numerically +* Operator Partitioning Rule: -Illegal Splitting (dimension with '^' reduce type): - Illegal splitting algorithm on this dimension. + 1) Spatial Partition (dimension with '' reduce type): + tensors can be uniformly partitioned on dimensions having spatial reduction type. + other tensors in the operator that don't have this dimension will be replicated. + + 2) Value Partition (dimension with '+' reduce type): + * tensors can be uniformly partition on dimensions having numerical reduction type + * other tensors in the the operator that don't have this dimension will be partitioned numerically. + + 3) Illegal Splitting (dimension with '^' reduce type): + * tensors can not be partitioned on dimensions having '^' reduction type. """ -from typing import Any, Dict, List, Union -from typing import Optional, Set, Tuple, Optional +from typing import Dict, Iterable, List, Union, Optional, Set, Tuple, Optional import enum import re -import copy import string from cube.ir.cten import IRTensor @@ -61,7 +72,10 @@ from cube.algorithm.factory import DistAlgorithmFactory -class EinDim: +_kSpecialIdentifiers = ('*', '?') + + +class DimAnno: """ To represent a dimension, name = {identifier}{reducetype} e.g., @@ -73,109 +87,156 @@ class EinDim: """ class ReduceType(enum.Enum): - Spatial='' + Dim = '' Sum = '+' - Stay = '^' # the dim is not allowed to be split - - def __init__(self, name: Union[str, List[str]]): - if isinstance(name, str): - name = [name] - self._name: List[str] = list() - self._reduce: List[EinDim.ReduceType] = list() - self._length: Dict[str, Optional[int]] = dict() - for n in name: - # complex name cannot have * - if len(name) > 1 and '*' in n: - raise ValueError("Einstein Axis name cannot have * for multiple inner-dimension") - # get reduce type - reduce = EinDim.ReduceType.Spatial - if n[-1] == EinDim.ReduceType.Sum.value: - reduce = EinDim.ReduceType.Sum - n = n[:-1] - elif n[-1] == EinDim.ReduceType.Stay.value: - reduce = EinDim.ReduceType.Stay - n = n[:-1] - # get identifier name - if len(n) == 0 or not (str.isidentifier(n) or str.isnumeric(n) or n == '*'): - raise ValueError(f"EinDim name {n} should be identifier") - if str.isnumeric(n): - reduce = EinDim.ReduceType.Stay - self._name.append(n) - self._reduce.append(reduce) - for n in self._name: - self._length[n] = None + Freeze = '^' # the dim is not allowed to be split + + def __init__(self, name: Union[str, Tuple[str]]): + identifiers, reduces = DimAnno.parse(name) + self._identifiers: Tuple[str] = identifiers + self._reduces: Tuple[DimAnno.ReduceType] = reduces @property def name(self) -> str: """ Return identifier without reduce """ - if len(self._name) == 1: - return self._name[0] - return '(' + ' '.join(self._name) + ')' + if len(self._identifiers) == 1: + return self._identifiers[0] + return '(' + ' '.join(self._identifiers) + ')' - def names(self) -> List[str]: - return copy.copy(self._name) + def length(self, identifier: str) -> Optional[int]: + """ + Return the integer of identifer + """ + assert identifier in self._identifiers, f"identifier {identifier} not in {self}" + return self._length[identifier] @property - def reduce(self) -> str: - return self._reduce + def identifiers(self) -> Tuple[str]: + return self._identifiers - def setlen(self, anno: str, dim: int): - if anno not in self._name: - raise KeyError(f"Cannot find anno: {anno} in {self.name}") - self._length[anno] = dim + @property + def reduces(self) -> Tuple[ReduceType]: + return self._reduces def __eq__(self, other): - if isinstance(other, EinDim): + if isinstance(other, DimAnno): if other.name == self.name: return True return False - def is_reduce(self): - return self.reduce == EinDim.ReduceType.Sum - def __repr__(self): - name_reduce = [name + reduce.value for name, reduce in zip(self._name, self._reduce)] - if len(self._name) == 1: - return self._name[0] + self._reduce[0].value + name_reduce = [name + reduce.value for name, reduce in zip(self._identifiers, self._reduces)] + if len(name_reduce) == 1: + return name_reduce[0] return '(' + ' '.join(name_reduce) + ')' + @staticmethod + def parse(anno: Union[str, Tuple[str]]) -> Tuple[Tuple[str], Tuple[ReduceType]]: + assert isinstance(anno, str) or all(isinstance(n, str) for n in anno), \ + "Expect anno to be str or Tuple[str]" + if isinstance(anno, str): + anno = (anno,) + if len(anno) > 1 and any(i in anno for i in _kSpecialIdentifiers): + raise ValueError(f"Dim annotation cannot have {_kSpecialIdentifiers} as partial dimension.") + identifiers, reduces = [], [] + for identifier in anno: + # get reduce type + reduce = DimAnno.ReduceType.Dim + if identifier[-1] == DimAnno.ReduceType.Sum.value: + reduce = DimAnno.ReduceType.Sum + identifier = identifier[:-1] + elif identifier[-1] == DimAnno.ReduceType.Freeze.value: + reduce = DimAnno.ReduceType.Freeze + identifier = identifier[:-1] + # get identifier name + assert str.isdecimal(identifier) or str.isidentifier(identifier) or identifier in _kSpecialIdentifiers, \ + f"identifier can only be integer or python identifier but got {identifier}" + # integer will always have stay reduction type + if str.isdecimal(identifier): + reduce = DimAnno.ReduceType.Freeze + identifiers.append(identifier) + reduces.append(reduce) + return tuple(identifiers), tuple(reduces) + + +class ShapeAnno: + """ + Shape annotation + + e.g., a (b+ dim) d^ + """ -class EinopAnno: + def __init__(self, dim_annos: Union[str, Tuple[DimAnno]]): + assert isinstance(dim_annos, str) or all(isinstance(adim, DimAnno) for adim in dim_annos), \ + f"dim_annos must be str or Tuple[DimAnno] but got {dim_annos}" + if isinstance(dim_annos, str): + dim_annos = ShapeAnno.parse(dim_annos) + self._dims: Tuple[DimAnno] = dim_annos - def __init__(self, anno: str): + @property + def dims(self) -> Tuple[DimAnno]: + return self._dims + + @property + def ndims(self) -> int: + """! + Get dimension number """ - initializing annotations specfied in str, e.g., - a (b c) d+, d+ k -> a (b c) k + return len(self._dims) + + def getdims(self, identifier: str) -> List[int]: + """! + Get dims that has the identifier + + @param identifier str: the query identifier + + @return dims List[int]: dimensions that contain the identifier """ - if not isinstance(anno, str): - raise TypeError("Expected anno to be str") - self.anno = anno - if '->' not in self.anno: - raise ValueError("Expected -> in anno") - # to inputs and outputs - inputs, outputs = self.anno.split('->') - inputs = inputs.split(',') - outputs = outputs.split(',') - # to eindims - self._identifiers: Set[str] = set() - self.inputs: List[List[EinDim]] = [ - self.parse_shape(shape) for shape in inputs - ] - self.outputs: List[List[EinDim]] = [ - self.parse_shape(shape) for shape in outputs - ] - self.reset_identifiers() + dims = [] + for dim, dim_anno in enumerate(self.dims): + if identifier in dim_anno.identifiers: + dims.append(dim) + return dims + + def __getitem__(self, dim: int) -> DimAnno: + assert isinstance(dim, int), "indexing only support int, but got {dim}" + assert dim < len(self._dims), f"dim {dim} out of boudary {len(self._dims)}" + return self._dims[dim] + + def __setitem__(self, index: int, dim_anno: Union[DimAnno, str]): + assert isinstance(index, int), "Expected index to be int" + assert isinstance(dim_anno, (DimAnno, str)), "Expected DimAnno or str" + if isinstance(dim_anno, str): + dim_anno = DimAnno(dim_anno) + self._dims[index] = dim_anno + + def __repr__(self) -> str: + return ' '.join(repr(dim) for dim in self._dims) + + @property + def ignore(self) -> bool: + """! + Check if the shape should be ignored, i.e., annotation is '?'. - def parse_shape(self, shape: str) -> List[EinDim]: + @return is_ignored bool: True if the shape should ignore else False """ - parsing annotations like of a single shape, e.g., + return self.ndims == 1 and self._dims[0].name == '?' + + @staticmethod + def parse(shape_anno: str) -> Tuple[DimAnno]: + """ + Parse annotations like of a single shape, e.g., a (b+ dim) d^ + + @param shape str: shape annotation + + @return dim_annos Tuple[DimAnno]: tuple of dimension annotations """ # => ['a', '(', 'b+', 'dim', ')', 'd^'] shapes = list() - for group in re.split('\ +', shape): + for group in re.split('\ +', shape_anno): if len(group) == 0: continue if '(' in group or ')' in group: @@ -189,15 +250,13 @@ def parse_shape(self, shape: str) -> List[EinDim]: bracket_group = False for w in shapes: if w == '(': - if bracket_group: - raise RuntimeError("brackets inside brackets not allowed") + assert not bracket_group, "Syntax Error: brackets inside brackets not allowed" bracket_group = True if len(current_identifier) > 0: edims.append(current_identifier) current_identifier = list() elif w == ')': - if not bracket_group: - raise RuntimeError("backets are not balanced at (") + assert bracket_group, "Syntax Error: backets are not balanced at (" bracket_group = False if len(current_identifier) > 0: edims.append(current_identifier) @@ -209,28 +268,211 @@ def parse_shape(self, shape: str) -> List[EinDim]: if len(current_identifier) > 0: edims.append(current_identifier) current_identifier = [w] - if bracket_group: - raise RuntimeError("brackets are not balanced at )") + assert not bracket_group, "Syntax Error: brackets are not balanced at )" if len(current_identifier) != 0: edims.append(current_identifier) - edims = [EinDim(edim) for edim in edims] - return edims + dim_annos = tuple(DimAnno(edim) for edim in edims) + return dim_annos + @staticmethod + def create_shape_str(shape: Tuple[int], reduction: str = '', iterator: Optional[Iterable] = None) -> List[str]: + """ + Create dimension string annotation given the shape and identity iterator + e.g., ['a+', 'b+', 'c+'] + + @param shape List[int]: tensor shape + @param iterator Optional[Iterable]: identity iterators. If None, use string.ascii_lowercase + @param reduce (str): reduction type must be in '', '+' or '^' + + @return strs List[str]: each element in strs represents a dimension + """ + if iterator is None: + iterator = iter(string.ascii_lowercase) + return [next(iterator) + reduction for _ in range(len(shape))] + + +class OpAnno: + """ + Operator annotation. + + e.g., a (b c) d+, d+ k -> a (b c) k + + """ + + def __init__(self, anno: Union[str, Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]]]): + assert isinstance(anno, str) or \ + (len(anno) == 2 and all(isinstance(ashape, ShapeAnno) for ashape in list(anno[0]) + list(anno[1]))), \ + "Expected anno to be str or (inputs: [ShapeAnno], outputs: [ShapeAnno])" + if isinstance(anno, str): + anno = OpAnno.parse(anno) + inputs, outputs = anno + self._inputs: Tuple[ShapeAnno] = tuple(inputs) + self._outputs: Tuple[ShapeAnno] = tuple(outputs) + self._identifiers: Dict[str, int] = dict() + self.reset_identifiers() + + @property def identifiers(self) -> Set[str]: - return copy.copy(self._identifiers) + """! + Get all identifier set + + @return identifiers Set[str] + """ + return tuple(self._identifiers.keys()) + + def inputs(self, index: Optional[int] = None) -> Union[ShapeAnno, Tuple[ShapeAnno]]: + """! + Get shape annotation of index-th input. + If index is None, will return all shape annotations + + @param index Optional[int]: the index of input. + + @return shape_annos Union[ShapeAnno, Tuple[ShapeAnno]]: the shape annotation + """ + assert index is None or index < len(self._inputs), "index out of boundary" + if index is None: + return self._inputs + else: + return self._inputs[index] + + def set_input(self, index: int, shape_anno: Union[str, ShapeAnno]): + """ + set the shape of index-th input tensors + """ + assert isinstance(shape_anno, (str, ShapeAnno)), f"must be str or ShapeAnno but got {shape_anno}" + assert index is None or index < len(self._inputs), "index out of boundary" + inputs = list(self._inputs) + inputs[index] = shape_anno if isinstance(shape_anno, ShapeAnno) else ShapeAnno(shape_anno) + self._inputs = tuple(inputs) + + def outputs(self, index: Optional[int] = None) -> Union[ShapeAnno, Tuple[ShapeAnno]]: + assert index is None or index < len(self._outputs), "index out of boundary" + if index is None: + return self._outputs + else: + return self._outputs[index] + + def set_output(self, index: int, shape_anno: Union[str, ShapeAnno]): + """ + set the shape of index-th input tensors + """ + assert isinstance(shape_anno, (str, ShapeAnno)), f"must be str or ShapeAnno but got {shape_anno}" + assert index is None or index < len(self._outputs), "index out of boundary" + outputs = list(self._outputs) + outputs[index] = shape_anno if isinstance(shape_anno, ShapeAnno) else ShapeAnno(shape_anno) + self._outputs = tuple(outputs) def reset_identifiers(self): - self._identifiers = set() - for eshape in self.inputs + self.outputs: - for edim in eshape: - for name in edim.names(): - self._identifiers.add(name) + """! + Reset identifier set. + + @return None + """ + self._identifiers = dict() + shape_annos = list(self._inputs) + list(self._outputs) + for ashape in shape_annos: + for adim in ashape.dims: + self._identifiers.update({identifier: None for identifier in adim.identifiers}) + for identifier in self._identifiers.keys(): + if str.isdecimal(identifier): + self._identifiers[identifier] = int(identifier) + + def setlen(self, identifier: str, length: int, override=False) -> bool: + """! + Set identifier length + + @param identifier str: identifier name + @param length int: the real length of identifier + @param override bool: if True will always set length, else will check if the existing length matches the new length + + @return success True if sucessfully set else False + """ + assert identifier in self._identifiers, f"{identifier} not int identifier set {self._identifiers}" + if not override: + prelen = self._identifiers[identifier] + if prelen is not None and prelen != length: + return False + self._identifiers[identifier] = length + return True + + def getlen(self, identifier: str) -> Optional[int]: + """! + Get identifier length + + @param identifier str: identifier name + + @return length Optional[int]: the length of identifier + """ + assert identifier in self._identifiers, f"{identifier} not int identifier set {self._identifiers}" + return self._identifiers[identifier] def __repr__(self) -> str: - inputs = ', '.join([repr(input) for input in self.inputs]) - outputs = ', '.join(repr(output) for output in self.outputs) + inputs = ', '.join(repr(input) for input in self.inputs()) + outputs = ', '.join(repr(output) for output in self.outputs()) return inputs + ' -> ' + outputs + @staticmethod + def parse(anno: str) -> Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]]: + """! + Parse op annotation string to input shape annos and output shape annos. + + @param anno str: operator annotation + + @return (inputs, outputs) Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]] + """ + # to inputs and outputs + if '->' not in anno: + raise ValueError("Syntax Error: Expected -> in operator anno") + inputs, outputs = anno.split('->') + inputs = inputs.split(',') + outputs = outputs.split(',') + # to ShapeAnnos + inputs: Tuple[ShapeAnno] = tuple(ShapeAnno(shape) for shape in inputs) + outputs: Tuple[ShapeAnno] = tuple(ShapeAnno(shape) for shape in outputs) + return inputs, outputs + + @staticmethod + def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], + ous: Tuple[Tuple[Union[str, Tuple[str]]]]) -> str: + """! + Create operator annotation string + e.g., + ins = [ ['a', 'b', 'c+'], ['c+', ['d', 'e']] ] + ous = [ ['a', 'b', 'd', 'e'] ] + => + 'a b c+, c+ (d e) -> a b d e' + + @param ins Tuple[Tuple[Union[str, Tuple[str]]]: input identifier list + @param ous Tuple[Tuple[Union[str, Tuple[str]]]: output identifier list + + @return anno str: operator annotation + """ + in_annos = list() + ou_annos = list() + for shape in ins: + flatten = list() + for edim in shape: + if isinstance(edim, str): + flatten.append(edim) + # List + elif len(edim) == 1: + flatten.append(edim[0]) + else: + flatten.append('(' + ' '.join(edim) + ')') + in_annos.append(' '.join(flatten)) + for shape in ous: + flatten = list() + for edim in shape: + if isinstance(edim, str): + flatten.append(edim) + # List + elif len(edim) == 1: + flatten.append(edim[0]) + else: + flatten.append('(' + ' '.join(edim) + ')') + ou_annos.append(' '.join(flatten)) + return ', '.join(in_annos) + ' -> ' + ', '.join(ou_annos) + class IREinops(IRFwOperation): @@ -238,17 +480,29 @@ class IREinops(IRFwOperation): Einstein-inspired notation operations """ def __init__(self, signature: str, annos: Tuple[str], - inputs: List, name: str, **kwargs): + inputs: List[IRTensor], name: str, **kwargs): + """! + Create a IRDimops + + @param signature str: operator signature + @param annos List[str]: annotation candidates + @param inputs List[IRTensor]: input tensor list + @param name str: the name of the operator + @param kwargs: the kwarg non-tensor parameters + """ + assert all(isinstance(anno, str) for anno in annos), "Expect annos to be List[str]" self._annos_candidates: List[str] = tuple(annos) - self._iannos: List[List[EinDim]] = None - self._oannos: List[List[EinDim]] = None + self._anno: OpAnno = None + self._iannos: List[ShapeAnno] = None + self._oannos: List[ShapeAnno] = None for anno in self._annos_candidates: - anno = EinopAnno(anno) + anno = OpAnno(anno) # expand * and check shape dimension consistency - if self.parse(inputs, anno): - self._iannos = anno.inputs - self._oannos = anno.outputs + if self.align(inputs, anno, **kwargs): + self._iannos = anno.inputs() + self._oannos = anno.outputs() + self._anno = anno break else: raise RuntimeError( @@ -266,54 +520,61 @@ def __init__(self, signature: str, annos: Tuple[str], for name in kwargs: self.kwargs[name] = kwargs[name] + @property + def anno(self) -> OpAnno: + return self._anno + + def ianno(self, index: int) -> Tuple[DimAnno]: + """! + Get index-th input tensor shape annotation + + @param index int: the input index + + @return dim_annos Tuple[DimAnno]: a tuple that each element is a dimension annotation + """ + assert index < len(self.inputs()), "index out of boudary" + return tuple(self._iannos[index]) + + def oanno(self, index: int) -> Tuple[DimAnno]: + """! + Get index-th output tensor shape annotation + + @param index int: the output index + + @return dim_annos Tuple[DimAnno]: a tuple that each element is a dimension annotation + """ + assert index < len(self.outputs()), "index out of boudary" + return self._oannos[index] + def infer_shape(self) -> bool: """ Shape inference using the matched annotation + + @return sucess: True if successfully inferred shape """ - dimlen: Dict[str, int] = dict() - for input, ishape in zip(self.inputs(), self._iannos): - if not isinstance(input, IRTensor): - continue - for tdim, edim in zip(input.shape, ishape): - if len(edim.names()) == 1: - dimlen[edim.name] = tdim - continue - # infer hidden dim shape - toinfer = None + for oidx, otensor in enumerate(self.outputs()): + shape_anno = self.oanno(oidx) + shape = [] + for odim in range(shape_anno.ndims): accum = 1 - for name in edim.names(): - if str.isnumeric(name): - accum *= int(name) - dimlen[name] = int(name) - elif name in self.kwargs: - accum *= self.kwargs[name] - dimlen[name] = self.kwargs[name] - else: - if toinfer is not None: - raise RuntimeError(f"Too many dimensions need to be inferred") - toinfer = name - if toinfer is not None: - dimlen[toinfer] = tdim // accum - # figure output shape - for oidx in range(len(self._outputs)): - output_shape = list() - for odim in self._oannos[oidx]: - accum = 1 - for name in odim.names(): - if str.isnumeric(name): - accum *= int(name) - else: - if name not in dimlen: - raise KeyError(f"Dim annotation {name} not in input") - accum *= dimlen[name] - output_shape.append(accum) - self.outputs(oidx).shape = output_shape + for identifier in shape_anno[odim].identifiers: + accum *= self.anno.getlen(identifier) + shape.append(accum) + otensor.shape = shape + # print(f'=> sign: {self.signature} anno: {self.anno}\n' + # f'=> inputs: {self.inputs()}\n' + # f'=> outputs: {self.outputs()}') return True - def new(self, inputs: List, outputs: List): - """ - construct a new operator sharing same kwargs with new inputs + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + """! + Construct a new operator sharing same kwargs with new inputs and outputs + + @param inputs List[IRTensor]: input tensors + @param outputs List[IRTensor]: output tensors + + @return op IRDimop: the new constructed operator """ annos = self._annos_candidates op = IREinops(self.signature, annos, inputs, self.name, **self.kwargs) @@ -321,74 +582,89 @@ def new(self, inputs: List, outputs: List): op.set_output(idx, output) return op - def parse(self, inputs: List[Any], anno: EinopAnno) -> Tuple[bool, List[List[EinDim]], List[List[EinDim]]]: - """ - parse annotations, assuming input tensor shape is given - """ - identifiers = anno.identifiers() + def align(self, inputs: List[IRTensor], op_anno: OpAnno, **kwargs) -> bool: + """! + Align input tensor shapes to the operator annotation. + + @param inputs List[IRTensor]: input tensor list + @param op_anno OpAnno: operator annotation + @return success True if align success else False + """ + identifiers = op_anno.identifiers # input shape match - if len(anno.inputs) != len(inputs): + if len(op_anno.inputs()) != len(inputs): return False - # expand * expand_dims = None if '*' in identifiers: candicates = [c for c in string.ascii_lowercase if c not in identifiers] # go through inputs - for idx, (eshape, input) in enumerate(zip(anno.inputs, inputs)): - names = [edim.name for edim in eshape] + for idx, (ashape, itensor) in enumerate(zip(op_anno.inputs(), inputs)): + names = [dim_anno.name for dim_anno in ashape.dims] if '*' in names: - if not isinstance(input, IRTensor): + if not isinstance(itensor, IRTensor): return False pos = names.index('*') - split = eshape[pos].reduce[0].value - span = len(inputs[idx].shape) - (len(names) - 1) - if expand_dims is not None and len(expand_dims) != span: + reduce = ashape[pos].reduces[0].value + ndims = len(inputs[idx].shape) - (len(names) - 1) + if expand_dims is not None and len(expand_dims) != ndims: return False if expand_dims is None: expand_dims = [] - if span > 0: - expand_dims = [EinDim(candicates[dim]+split) for dim in range(span)] - anno.inputs[idx] = anno.inputs[idx][:pos] + expand_dims + anno.inputs[idx][pos+1:] + if ndims > 0: + expand_dims = list(DimAnno(candicates[dim] + reduce) for dim in range(ndims)) + shape_anno = list(op_anno.inputs(idx).dims[:pos]) + expand_dims + list(op_anno.inputs(idx).dims[pos+1:]) + shape_anno = ShapeAnno(tuple(shape_anno)) + op_anno.set_input(idx, shape_anno) # * should appear in inputs - if expand_dims is None: - return False + assert expand_dims is not None, f"Syntax Error: {op_anno}: * should also appear in inputs" # go through outputs - for idx, eshape in enumerate(anno.outputs): - names = [edim.name for edim in eshape] + for idx, shape_anno in enumerate(op_anno.outputs()): + names = [dim_anno.name for dim_anno in shape_anno.dims] if '*' in names: pos = names.index('*') - anno.outputs[idx] = anno.outputs[idx][:pos] + expand_dims + anno.outputs[idx][pos+1:] - anno.reset_identifiers() + shape_anno = list(op_anno.outputs(idx).dims[:pos]) + expand_dims + list(op_anno.outputs(idx).dims[pos+1:]) + shape_anno = ShapeAnno(tuple(shape_anno)) + op_anno.set_output(idx, shape_anno) + op_anno.reset_identifiers() # check dimension consistency - dimlen: Dict[str, int] = dict() - for eshape, input in zip(anno.inputs, inputs): - if input is None: + for ashape, itensor in zip(op_anno.inputs(), inputs): + if not (isinstance(itensor, IRTensor) ^ ashape.ignore): + return False + if not isinstance(itensor, IRTensor): continue - if not isinstance(input, IRTensor): - if not (len(eshape) == 1 and eshape[0].name == '1'): - return False - else: - if len(input.shape) != len(eshape): + if ashape.ndims != len(itensor.shape): + return False + for adim, dimlen in zip(ashape.dims, itensor.shape): + ret = True + identifiers = adim.identifiers + if len(identifiers) == 1: + ret = op_anno.setlen(identifiers[0], dimlen) + else: + toinfer, accum = [], 1 + for identifier in identifiers: + length = op_anno.getlen(identifier) + if length is None: + if identifier not in kwargs: + toinfer.append(identifier) + else: + assert isinstance(kwargs[identifier], int), "require integer for annotation inference" + ret = op_anno.setlen(identifier, kwargs[identifier]) + accum *= kwargs[identifier] + else: + accum *= length + if len(toinfer) == 0 and accum != dimlen: + return False + assert len(toinfer) <= 1, f"Syntax Error {op_anno}: cannot infer hidden dim: {adim}" + if len(toinfer) == 1: + assert dimlen % accum == 0, f"{dimlen} % {accum} != 0" + ret = op_anno.setlen(toinfer[0], dimlen // accum) + if not ret: return False - for edim, nele in zip(eshape, input.shape): - if edim.name in dimlen: - if nele != dimlen[edim.name]: - return False - dimlen[edim.name] = nele return True - def einexpr(self) -> str: - inputs = list() - outputs = list() - for shape in self._iannos: - inputs.append(' '.join([repr(edim) for edim in shape])) - for shape in self._oannos: - outputs.append(' '.join([repr(edim) for edim in shape])) - return ', '.join(inputs) + ' -> ' + ', '.join(outputs) - def algorithms(self, tag: Optional[str] = None): factory = DistAlgorithmFactory() if tag is None: diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 3a67fffb..73e9368d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1,11 +1,11 @@ -from typing import Any, Iterable, List, Optional, Tuple, Union, Dict +from typing import Any, List, Optional, Tuple, Dict import string import copy import torch import warnings from cube.ir.cten import IRTensor -from cube.graph.function.einops import EinDim, IREinops +from cube.graph.function.einops import ShapeAnno, OpAnno, IREinops from cube.graph.function.conv import IRConv2D from cube.graph.function.conv import IRConv3D from cube.graph.function.pad import IRPad @@ -20,58 +20,10 @@ from cube.graph.torch_dtype_mapping import DType2IRDType, TorchScalarTypeEnumMap -def _create_eshape(shape: List[int], iterator: Optional[Iterable] = None, - reduce: EinDim.ReduceType = EinDim.ReduceType.Spatial) -> List[str]: - """ - Create dimension annotation given the shape and - letter iterator - """ - if iterator is None: - iterator = iter(string.ascii_lowercase) - return [next(iterator) + reduce.value for _ in range(len(shape))] - - -def _create_anno(ins: List[List[Union[str, List[str]]]], - ous: List[List[Union[str, List[str]]]]) -> str: - """ - Create annotation string - e.g., - ins = [ ['a', 'b', 'c+'], ['c+', ['d', 'e']] ] - ous = [ ['a', 'b', 'd', 'e'] ] - => - 'a b c+, c+ (d e) -> a b d e' - """ - in_annos = list() - ou_annos = list() - for shape in ins: - flatten = list() - for edim in shape: - if isinstance(edim, str): - flatten.append(edim) - # List - elif len(edim) == 1: - flatten.append(edim[0]) - else: - flatten.append('(' + ' '.join(edim) + ')') - in_annos.append(' '.join(flatten)) - for shape in ous: - flatten = list() - for edim in shape: - if isinstance(edim, str): - flatten.append(edim) - # List - elif len(edim) == 1: - flatten.append(edim[0]) - else: - flatten.append('(' + ' '.join(edim) + ')') - ou_annos.append(' '.join(flatten)) - return ', '.join(in_annos) + ' -> ' + ', '.join(ou_annos) - - def Identity(signature, inputs): signature = 'cube.runtime.function.identity' - eshape = _create_eshape(inputs[0].shape) - anno = _create_anno([eshape], [eshape]) + eshape = ShapeAnno.create_shape_str(inputs[0].shape) + anno = OpAnno.create_op_str([eshape], [eshape]) return IREinops(signature, [anno], inputs, 'identity') @@ -183,6 +135,51 @@ def ToTensor(signature, signature = 'torch.Tensor.to' return IRToTensor(signature, [tensor], 'to', ir_dtype=ir_dtype) + +def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: + """! + Create shape annotations for element wise operator following broadcastable rules: + https://pytorch.org/docs/stable/notes/broadcasting.html + + @param lhs IRTensor: the lhs input tensor + @param rhs IRTensor: the rhs input tensor + + @return lhs_shape, rhs_shape, out_shape: the lhs, rhs and output shape annotation + """ + lndims, rndims = len(lhs.shape), len(rhs.shape) + # init lhs_shape and rhs_shape annotation string + shape_anno = ShapeAnno.create_shape_str(lhs.shape if lndims > rndims else rhs.shape) + lhs_shape = shape_anno[0-lndims:] + rhs_shape = shape_anno[0-rndims:] + # expand dimensions for empty dimensions + lofst = max(lndims, rndims) - lndims + lshape = [1] * lofst + list(lhs.shape) + rofst = max(lndims, rndims) - rndims + rshape = [1] * rofst + list(rhs.shape) + # init out_shape + out_shape = [] + for dim in range(len(lshape)): + ldim_anno = None if dim - lofst < 0 else lhs_shape[dim-lofst] + rdim_anno = None if dim - rofst < 0 else rhs_shape[dim-rofst] + if lshape[dim] == rshape[dim]: + assert rdim_anno is not None or ldim_anno is not None + out_shape.append(rdim_anno if rdim_anno is not None else ldim_anno) + elif lshape[dim] == 1: + assert rdim_anno is not None + out_shape.append(rdim_anno) + if ldim_anno is not None: + lhs_shape[dim-lofst] = '1' + elif rshape[dim] == 1: + assert ldim_anno is not None + out_shape.append(ldim_anno) + if rdim_anno is not None: + rhs_shape[dim-rofst] = '1' + else: + raise ValueError(f"cannot broadcast lhs: {lhs.shape} and rhs: {rhs.shape}") + # print(lhs.shape, rhs.shape, lhs_shape, rhs_shape, out_shape) + return lhs_shape, rhs_shape, out_shape + + def Add(signature, inputs): if len(inputs) == 2: kwargs = {} @@ -201,29 +198,16 @@ def Add(signature, inputs): return lhs + rhs annos = [ - '*, 1 -> *', - '1, * -> *', - '*, * -> *', + '*, ? -> *', + '?, * -> *', ] - # broadcast - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ - len(lhs.shape) == len(rhs.shape): - if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): - # TODO: support spatial partitioning on broadcast dim - lshape = _create_eshape(lhs.shape) - rshape = copy.copy(lshape) - oshape = copy.copy(lshape) - for dim in range(len(lhs.shape)): - if lhs.shape[dim] < rhs.shape[dim]: - oshape[dim] = rshape[dim] - lshape[dim] = str(lhs.shape[dim]) - elif lhs.shape[dim] > rhs.shape[dim]: - oshape[dim] = lshape[dim] - rshape[dim] = str(rhs.shape[dim]) - annos = [_create_anno([lshape, rshape], [oshape])] + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): + lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'add', **kwargs) + def Sub(signature, inputs): if len(inputs) == 2: alpha = 1 @@ -243,26 +227,12 @@ def Sub(signature, inputs): return lhs - rhs annos = [ - '*, 1 -> *', - '1, * -> *', - '*, * -> *', + '*, ? -> *', + '?, * -> *', ] - # broadcast - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ - len(lhs.shape) == len(rhs.shape): - if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): - # TODO: support spatial partitioning on broadcast dim - lshape = _create_eshape(lhs.shape) - rshape = copy.copy(lshape) - oshape = copy.copy(lshape) - for dim in range(len(lhs.shape)): - if lhs.shape[dim] < rhs.shape[dim]: - oshape[dim] = rshape[dim] - lshape[dim] = str(lhs.shape[dim]) - elif lhs.shape[dim] > rhs.shape[dim]: - oshape[dim] = lshape[dim] - rshape[dim] = str(rhs.shape[dim]) - annos = [_create_anno([lshape, rshape], [oshape])] + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): + lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'sub', **kwargs) @@ -273,58 +243,29 @@ def Mul(signature, inputs): return lhs * rhs annos = [ - '*, 1 -> *', - '1, * -> *', - '*, * -> *', + '*, ? -> *', + '?, * -> *', ] - # broadcast - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ - len(lhs.shape) == len(rhs.shape): - if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): - # TODO: support spatial partitioning on broadcast dim - lshape = _create_eshape(lhs.shape) - rshape = copy.copy(lshape) - oshape = copy.copy(lshape) - for dim in range(len(lhs.shape)): - if lhs.shape[dim] < rhs.shape[dim]: - oshape[dim] = rshape[dim] - lshape[dim] = str(lhs.shape[dim]) - elif lhs.shape[dim] > rhs.shape[dim]: - oshape[dim] = lshape[dim] - rshape[dim] = str(rhs.shape[dim]) - annos = [_create_anno([lshape, rshape], [oshape])] + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): + lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'mul') def Div(signature, inputs): lhs, rhs = inputs - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): # For `aten::div` we always do floating division, even operands are both ints. # TorchScript would dispatch frontend `a // b` to another op `aten::floordiv`. return lhs / rhs annos = [ - '*, 1 -> *', - '1, * -> *', - '*, * -> *', + '*, ? -> *', + '?, * -> *', ] - # broadcast - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ - len(lhs.shape) == len(rhs.shape): - if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): - # TODO: support spatial partitioning on broadcast dim - lshape = _create_eshape(lhs.shape) - rshape = copy.copy(lshape) - oshape = copy.copy(lshape) - for dim in range(len(lhs.shape)): - if lhs.shape[dim] < rhs.shape[dim]: - oshape[dim] = rshape[dim] - lshape[dim] = str(lhs.shape[dim]) - elif lhs.shape[dim] > rhs.shape[dim]: - oshape[dim] = lshape[dim] - rshape[dim] = str(rhs.shape[dim]) - annos = [_create_anno([lshape, rshape], [oshape])] + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): + lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'div') @@ -335,26 +276,12 @@ def FloorDiv(signature, inputs): return lhs // rhs annos = [ - '*, 1 -> *', - '1, * -> *', - '*, * -> *', + '*, ? -> *', + '?, * -> *', ] - # broadcast - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ - len(lhs.shape) == len(rhs.shape): - if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): - # TODO: support spatial partitioning on broadcast dim - lshape = _create_eshape(lhs.shape) - rshape = copy.copy(lshape) - oshape = copy.copy(lshape) - for dim in range(len(lhs.shape)): - if lhs.shape[dim] < rhs.shape[dim]: - oshape[dim] = rshape[dim] - lshape[dim] = str(lhs.shape[dim]) - elif lhs.shape[dim] > rhs.shape[dim]: - oshape[dim] = lshape[dim] - rshape[dim] = str(rhs.shape[dim]) - annos = [_create_anno([lshape, rshape], [oshape])] + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): + lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'floordiv') @@ -365,26 +292,12 @@ def Pow(signature, inputs): return lhs ** rhs annos = [ - '*, 1 -> *', - '1, * -> *', - '*, * -> *', + '*, ? -> *', + '?, * -> *', ] - # broadcast - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ - len(lhs.shape) == len(rhs.shape): - if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): - # TODO: support spatial partitioning on broadcast dim - lshape = _create_eshape(lhs.shape) - rshape = copy.copy(lshape) - oshape = copy.copy(lshape) - for dim in range(len(lhs.shape)): - if lhs.shape[dim] < rhs.shape[dim]: - oshape[dim] = rshape[dim] - lshape[dim] = str(lhs.shape[dim]) - elif lhs.shape[dim] > rhs.shape[dim]: - oshape[dim] = lshape[dim] - rshape[dim] = str(rhs.shape[dim]) - annos = [_create_anno([lshape, rshape], [oshape])] + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): + lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, 'pow') @@ -399,26 +312,12 @@ def comparison_einops(f, name, signature, inputs): return f(lhs, rhs) annos = [ - '*, 1 -> *', - '1, * -> *', - '*, * -> *', + '*, ? -> *', + '?, * -> *', ] - # broadcast - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor) and \ - len(lhs.shape) == len(rhs.shape): - if not all([l == r for l, r in zip(lhs.shape, rhs.shape)]): - # TODO: support spatial partitioning on broadcast dim - lshape = _create_eshape(lhs.shape) - rshape = copy.copy(lshape) - oshape = copy.copy(lshape) - for dim in range(len(lhs.shape)): - if lhs.shape[dim] < rhs.shape[dim]: - oshape[dim] = rshape[dim] - lshape[dim] = str(lhs.shape[dim]) - elif lhs.shape[dim] > rhs.shape[dim]: - oshape[dim] = lshape[dim] - rshape[dim] = str(rhs.shape[dim]) - annos = [_create_anno([lshape, rshape], [oshape])] + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): + lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] return IREinops(signature, annos, inputs, name) @@ -502,8 +401,8 @@ def LayerNorm(signature, inputs): if len(normalized_shape) != 1: raise NotImplementedError("Only support normalized_shape to be int") annos = [ - f'N *, 1, {normalized_shape[0]}, {normalized_shape[0]} -> N *', - f'N *, 1, 1, 1 -> N *' + f'N *, ?, {normalized_shape[0]}, {normalized_shape[0]} -> N *', + f'N *, ?, ?, ? -> N *' ] return IREinops(signature, annos, [input, normalized_shape, weight, bias], 'layernorm', eps=eps) @@ -512,7 +411,7 @@ def LayerNorm(signature, inputs): def Sum(signature, inputs): tensor = inputs[0] dim = inputs[1] - einput = _create_eshape(tensor.shape) + einput = ShapeAnno.create_shape_str(tensor.shape) eoutput = copy.copy(einput) if dim is not None: keepdim = inputs[2] @@ -525,7 +424,7 @@ def Sum(signature, inputs): eoutput = ['1'] # every dimension is reduced einput = [edim + '+' for edim in einput] - anno = _create_anno([einput], [eoutput]) + anno = OpAnno.create_op_str([einput], [eoutput]) if dim is not None: return IREinops(signature, [anno], [tensor], 'sum', dim=dim, keepdim=keepdim) else: @@ -539,10 +438,10 @@ def Transpose(signature, inputs): assert len(inputs) == 3 input, dim0, dim1 = inputs - edim_in = _create_eshape(input.shape) + edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = copy.copy(edim_in) edim_ou[dim0], edim_ou[dim1] = edim_ou[dim1], edim_ou[dim0] - anno = _create_anno([edim_in], [edim_ou]) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IREinops(signature, [anno], [input], 'transpose', dim0=dim0, dim1=dim1) @@ -658,7 +557,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s if edim not in spatial: bracket[subdim] = str(shape_map[edim]) # bracket[subdim] = edim + '^' - anno = _create_anno([in_anno], [ou_anno]) + anno = OpAnno.create_op_str([in_anno], [ou_anno]) signature = 'torch.Tensor.view' return IREinops(signature, [anno], [input], 'view', size=tuple(shape)) @@ -681,7 +580,7 @@ def Reshape(signature, inputs): # torch.conv2d(input, weight, bias, stride, padding, dialation, groups) # https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html?highlight=torch%20conv2d#torch.nn.functional.conv2d # """ -# def adapt(anno: EinopAnno, node: IREinops) -> EinopAnno: +# def adapt(anno: OpAnno, node: IREinops) -> OpAnno: # iH, iW = node.inputs(0).shape[2:4] # stride = node.kwargs['stride'] # padding = node.kwargs['padding'] @@ -690,8 +589,8 @@ def Reshape(signature, inputs): # dW = node.inputs(1).shape[3] # oH = (iH + 2 * padding[0] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 # oW = (iW + 2 * padding[1] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 -# anno.outputs[0][2] = EinDim([str(oH)]) -# anno.outputs[0][3] = EinDim([str(oW)]) +# anno.outputs[0][2] = DimAnno([str(oH)]) +# anno.outputs[0][3] = DimAnno([str(oW)]) # return anno # annos = [ # ('N iC+ H^ W^, oC iC+ dH^ dW^, oC -> N oC oH^ oW^', adapt), @@ -819,9 +718,12 @@ def Embedding(signature, inputs: List): padding_idx = inputs[3] start, stop = 0, weight.shape[0] letters = iter(string.ascii_lowercase) - ishapes = [_create_eshape(itensor.shape, letters), _create_eshape(weight.shape, letters)] + ishapes = [ + ShapeAnno.create_shape_str(itensor.shape, iterator=letters), + ShapeAnno.create_shape_str(weight.shape, iterator=letters) + ] oshapes = [ishapes[0] + [ishapes[1][-1]]] - anno = _create_anno(ishapes, oshapes) + anno = OpAnno.create_op_str(ishapes, oshapes) return IREinops(signature, [anno], [itensor, weight], 'embedding', padding_idx=padding_idx, start=start, stop=stop) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 380b8d38..c2576a85 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -453,7 +453,7 @@ def replicate(self, op: IRCell, times=1, reset_dependency=True) -> Optional[List self.reset_dependency() return [op] + fnodes - def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional[List[IRCell]]: + def partition(self, op: IRCell, algo: GenericDistAlgo, **config) -> Optional[List[IRCell]]: """ Partition an operator (op) by using op partition algorithm (algo) and its configuration (config). @@ -477,9 +477,9 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, config: Dict) -> Optional if algo.node != op: return None - if not algo.satisfy(config): + if not algo.satisfy(**config): return None - fnodes = algo.instantiate(config) + fnodes = algo.instantiate(**config) #FIXME: we don't allow non-weight input to be splitted in value for fnode in fnodes: diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 6efbeb4c..e167242e 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -2,7 +2,7 @@ Register cutomized function """ -from typing import Any, Callable, List +from typing import Any, Callable, List, Optional import inspect import torch @@ -11,7 +11,7 @@ from cube.graph.parser.mapping import Sign2Op -def register(anno: str): +def register(anno: str, name: Optional[str] = None): """ Register a function with einop annotations. @@ -20,7 +20,7 @@ def register(anno: str): 1). Has type annotations for each input 2). Tensor inputs goes first then other inputs - For EinDims containing brackets (e.g., (3 h d)) that can not be + For DimAnnos containing brackets (e.g., (3 h d)) that can not be inferred by system, user should have same argument name in the function definition to help system infer each dim length, e.g., @@ -31,7 +31,7 @@ def funcname(x: torch.Tensor, b: int = 4): def decorator(fn: Callable): if not callable(fn): raise TypeError("Expected a function") - fsig = fn.__name__ + fsig = fn.__name__ if name is None else name args = inspect.signature(fn) arg_names = list(args.parameters.keys()) arg_kind = [args.parameters[name].annotation for name in arg_names] diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 57f7cbc7..dfec9c66 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -708,8 +708,7 @@ def grad(self) -> Optional[Union[IRTensor, float]]: for consumer in self.parent.consumers: for itensor in consumer.inputs(): if self.overlap(itensor): - assert itensor == self, \ - "partial overlapping of consumed tensors is not supported during backward" + # TODO: we should guarantee in final status itensor == self # replicated nodes will have same node id if consumer._id not in ref_consumers: ref_consumers.append(consumer._id) From 98f860bf797fae275845cc428e8363f5652dfba9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Jun 2022 00:20:33 +0800 Subject: [PATCH 0871/1892] update examples with new partition interface --- examples/gsearch/blocks.py | 8 +- examples/gsearch/gpt/policy/naive.py | 8 - examples/gsearch/gpt/policy/spmd.py | 48 +++++ examples/gsearch/gpt/train.py | 3 +- examples/mlp/linears.py | 12 +- examples/mlp/policy/col_parallel.py | 63 ------ examples/mlp/policy/data_parallel.py | 24 --- examples/mlp/policy/hybrid_parallel.py | 24 --- examples/mlp/policy/megatron.py | 33 --- examples/mlp/policy/no_parallel.py | 4 - examples/mlp/policy/optimal.py | 60 ------ examples/mlp/policy/row_parallel.py | 25 --- examples/mlp/policy/spmd.py | 189 ++++++++++++++++++ examples/poisson/policy/{naive.py => spmd.py} | 17 +- examples/poisson/sci.py | 16 +- 15 files changed, 263 insertions(+), 271 deletions(-) delete mode 100644 examples/gsearch/gpt/policy/naive.py create mode 100644 examples/gsearch/gpt/policy/spmd.py delete mode 100644 examples/mlp/policy/col_parallel.py delete mode 100644 examples/mlp/policy/data_parallel.py delete mode 100644 examples/mlp/policy/hybrid_parallel.py delete mode 100644 examples/mlp/policy/megatron.py delete mode 100644 examples/mlp/policy/no_parallel.py delete mode 100644 examples/mlp/policy/optimal.py delete mode 100644 examples/mlp/policy/row_parallel.py create mode 100644 examples/mlp/policy/spmd.py rename examples/poisson/policy/{naive.py => spmd.py} (52%) diff --git a/examples/gsearch/blocks.py b/examples/gsearch/blocks.py index 05281ffb..c34f0b16 100644 --- a/examples/gsearch/blocks.py +++ b/examples/gsearch/blocks.py @@ -3,7 +3,7 @@ import warnings -@cube.graph.parser.register('L N E+, (h d) E+, (h d), (h d) E+, (h d), (h d) E+, (h d) -> N h L d, N h L d, N h L d') +@cube.graph.parser.register('L N E+, (h d) E+, (h d), (h d) E+, (h d), (h d) E+, (h d) -> N h L d, N h L d, N h L d', name='attn_qkv') def attn_qkv(query: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, @@ -28,7 +28,7 @@ def attn_qkv(query: torch.Tensor, return q, k, v -@cube.graph.parser.register('N h L d, N h L d -> N h L L') +@cube.graph.parser.register('N h L d+, N h L d+ -> N h L L', name='attn_score') def attn_score(q: torch.Tensor, k: torch.Tensor, h: int, mask: bool = True): N, num_head, L, d = q.size() assert num_head == h @@ -47,7 +47,7 @@ def attn_score(q: torch.Tensor, k: torch.Tensor, h: int, mask: bool = True): return attn -@cube.graph.parser.register('N h L^ L^ -> N h L^ L^') +@cube.graph.parser.register('N h L K^ -> N h L K^', name='attn_softmax') def attn_softmax(attn: torch.Tensor): N, h, L, L = attn.size() attn = attn.view((N * h), L, L) @@ -55,7 +55,7 @@ def attn_softmax(attn: torch.Tensor): return attn.view(N, h, L, L) -@cube.graph.parser.register('N h L L, N h L d -> L N (h d)') +@cube.graph.parser.register('N h L K+, N h K+ d -> L N (h d)', name='attn_context') def attn_context(attn: torch.Tensor, v: torch.Tensor): N, h, L, d = v.size() attn = attn.view((N * h), L, L) diff --git a/examples/gsearch/gpt/policy/naive.py b/examples/gsearch/gpt/policy/naive.py deleted file mode 100644 index 10de3596..00000000 --- a/examples/gsearch/gpt/policy/naive.py +++ /dev/null @@ -1,8 +0,0 @@ -from cube.graph import IRGraph - -def PAS(graph: IRGraph, resource): - # print(graph.extra_repr()) - for node in graph.nodes(): - graph.assign(node, 0) - # print(graph.extra_repr()) - return graph \ No newline at end of file diff --git a/examples/gsearch/gpt/policy/spmd.py b/examples/gsearch/gpt/policy/spmd.py new file mode 100644 index 00000000..b21d0d35 --- /dev/null +++ b/examples/gsearch/gpt/policy/spmd.py @@ -0,0 +1,48 @@ +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation + + +def PASReplica(graph: IRGraph, resource): + assert resource.ngpus == 1 + print(graph.extra_repr()) + for node in graph.nodes(): + graph.assign(node, 0) + # print(graph.extra_repr()) + return graph + + +def PASMegatron(graph: IRGraph, resource): + + tp_size = resource.ngpus + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + def tensor_parallelism(node, idx: int, dim: int, num: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, config=dict(idx=idx, dim=dim, num=num)) + assert all(isinstance(n, IRFwOperation) for n in sub_nodes), f"Fail to partition node {node}" + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return sub_nodes + + qkvs = [node for node in fnodes if node.name == 'attn_qkv'] + for qkv in qkvs: + tensor_parallelism(qkv, idx=1, dim=0, num=tp_size) + + scores = [node for node in fnodes if node.name == 'attn_score'] + for score in scores: + tensor_parallelism(score, idx=0, dim=1, num=tp_size) + + softmaxs = [node for node in fnodes if node.name == 'attn_softmax'] + for softmax in softmaxs: + tensor_parallelism(softmax, idx=0, dim=1, num=tp_size) + + contexts = [node for node in fnodes if node.name == 'attn_context'] + for context in contexts: + tensor_parallelism(context, idx=0, dim=1, num=tp_size) + + for node in graph.nodes(): + if isinstance(node, IRFwOperation) and len(node.device) == 0: + rnodes = graph.replicate(node, times=tp_size) + for idx, rnode in enumerate(rnodes): + graph.assign(rnode, idx) + return graph diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py index bc6758f5..f019ef13 100644 --- a/examples/gsearch/gpt/train.py +++ b/examples/gsearch/gpt/train.py @@ -12,13 +12,12 @@ from examples.gsearch.gpt.model import GPT from examples.gsearch.gpt.model import GPTDataLoader +from examples.gsearch.gpt.policy.spmd import PASReplica as PAS import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary, model_summary -from examples.nlp.gpt.policy.naive import PAS - def train(): diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 2039085f..47622401 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -5,14 +5,6 @@ --nproc_per_node=4 \ --nnodes=1 \ examples/mlp/linears.py - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --rdzv_id=888 \ - --rdzv_backend=c10d \ - --rdzv_endpoint=worker0:8004 \ - examples/mlp/linears.py """ import torch @@ -21,7 +13,7 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.optimal import PAS +from examples.mlp.policy.spmd import PASMegatron as PAS # =================== Semantic Model Description ==================== @@ -51,7 +43,7 @@ def forward(self, data): def train(): - batch_size = 8192 + batch_size = 256 dim = 8192 model = MLP(dim=dim) diff --git a/examples/mlp/policy/col_parallel.py b/examples/mlp/policy/col_parallel.py deleted file mode 100644 index 6ea6a7de..00000000 --- a/examples/mlp/policy/col_parallel.py +++ /dev/null @@ -1,63 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation - - -def P(graph, resource): - """ - P policy - """ - for node in graph.nodes(): - if isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation): - algo = node.algorithms('dim') - if algo: - sub_nodes = graph.partition( - node, algo, config=dict(idx=1, dim=0, num=resource.ngpus) - ) - else: - # graph.assign(node, list(range(resource.ngpus))) - sub_nodes = graph.replicate(node, times=resource.ngpus) - # device hint - for idx, node in enumerate(sub_nodes): - node.tag = idx - return graph - - -def A(graph, resource): - """ - A policy - """ - for node in graph.nodes(): - if node.tag is not None: - device = node.tag - graph.assign(node, device) - return graph - - -def S(graph, resource): - """ - Schedule Policy. => use default schedule - """ - return graph - - -def PAS(graph: IRGraph, resource): - """ - Linear Column Partition - """ - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=1, dim=0, num=resource.ngpus) - ) - if sub_nodes is None: # partition fails - # graph.assign(node, list(range(resource.ngpus))) - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - print(graph.extra_repr()) - return graph diff --git a/examples/mlp/policy/data_parallel.py b/examples/mlp/policy/data_parallel.py deleted file mode 100644 index e224a0b6..00000000 --- a/examples/mlp/policy/data_parallel.py +++ /dev/null @@ -1,24 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation - - -def PAS(graph: IRGraph, resource): - """ - Linear Column Partition - """ - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(num=resource.ngpus)) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - batch_dim = node.get_batch_dims()[0] - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=batch_dim, num=resource.ngpus)) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - print(graph.extra_repr()) - return graph diff --git a/examples/mlp/policy/hybrid_parallel.py b/examples/mlp/policy/hybrid_parallel.py deleted file mode 100644 index 84230022..00000000 --- a/examples/mlp/policy/hybrid_parallel.py +++ /dev/null @@ -1,24 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation - - -def PAS(graph: IRGraph, resource): - """ - Linear Hybrid Partition - """ - for idx, node in enumerate(graph.nodes()): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=1, dim=(idx+1)%2, num=resource.ngpus) - ) - if sub_nodes is None: # partition fails - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - print(graph.extra_repr()) - return graph diff --git a/examples/mlp/policy/megatron.py b/examples/mlp/policy/megatron.py deleted file mode 100644 index 99eadc64..00000000 --- a/examples/mlp/policy/megatron.py +++ /dev/null @@ -1,33 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation - - -def PAS(graph: IRGraph, resource): - """ - Linear Hybrid + Nested Partition - """ - tp = 4 - dp = resource.ngpus // tp - for idx, node in enumerate(graph.nodes()): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - continue - if isinstance(node, IRFwOperation): - sub_nodes = list() - algo = node.algorithms('dim') - tp_nodes = graph.partition( - node, algo, config=dict(idx=1, dim=(idx+1)%2, num=tp) - ) - if tp_nodes is not None: - for tp_node in tp_nodes: - algo = tp_node.algorithms('dim') - dp_nodes = graph.partition(tp_node, algo, config=dict(idx=0, dim=0, num=dp)) - sub_nodes += dp_nodes - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - # print(graph.extra_repr()) - return graph diff --git a/examples/mlp/policy/no_parallel.py b/examples/mlp/policy/no_parallel.py deleted file mode 100644 index 18e3fade..00000000 --- a/examples/mlp/policy/no_parallel.py +++ /dev/null @@ -1,4 +0,0 @@ -from cube.graph import IRGraph - -def PAS(graph: IRGraph, resource): - return graph diff --git a/examples/mlp/policy/optimal.py b/examples/mlp/policy/optimal.py deleted file mode 100644 index fcc6d8ea..00000000 --- a/examples/mlp/policy/optimal.py +++ /dev/null @@ -1,60 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation - - -def PAS(graph: IRGraph, resource): - - assert resource.ngpus == 4, "the optimal plan is for 4 GPU case." - - # replicate data operation - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - - # replicate loss operation - fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] - loss = fnodes[-1] - sub_nodes = graph.replicate(loss, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - - fnodes = fnodes[:-1] - # linear0 config - config0 = [ - None, - dict(idx=1, dim=0, num=4) # col - ] - # linear1 config - config1 = [ - dict(idx=0, dim=1, num=2), # row - dict(idx=1, dim=0, num=2), # col - ] - # linear2 config - config2 = [ - dict(idx=0, dim=0, num=2), # dat - dict(idx=0, dim=1, num=2), # row - ] - # linear3 config - config3 = [ - dict(idx=0, dim=0, num=2), # dat - dict(idx=0, dim=1, num=2), # row - ] - configs = [config0, config1, config2, config3] - assert len(fnodes) == len(configs) - for fnode, config in zip(fnodes, configs): - all_nodes = [fnode] - for conf in config: - if conf is None: - continue - sub_nodes = list() - for node in all_nodes: - algo = node.algorithms('dim') - nodes = graph.partition(node, algo, conf) - sub_nodes += nodes - all_nodes = sub_nodes - assert len(all_nodes) == 4 - for idx, node in enumerate(all_nodes): - graph.assign(node, idx) - return graph diff --git a/examples/mlp/policy/row_parallel.py b/examples/mlp/policy/row_parallel.py deleted file mode 100644 index 012f12aa..00000000 --- a/examples/mlp/policy/row_parallel.py +++ /dev/null @@ -1,25 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation - - -def PAS(graph: IRGraph, resource): - """ - Linear Column Partition - """ - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=1, dim=1, num=resource.ngpus) - ) - if sub_nodes is None: # partition fails - # graph.assign(node, list(range(resource.ngpus))) - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - print(graph.extra_repr()) - return graph diff --git a/examples/mlp/policy/spmd.py b/examples/mlp/policy/spmd.py new file mode 100644 index 00000000..6cbfa308 --- /dev/null +++ b/examples/mlp/policy/spmd.py @@ -0,0 +1,189 @@ +from cube.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation + + +def PASSingle(graph: IRGraph, resource): + """ + Single device + """ + assert resource.ngpus == 1, "only apply for single gpu case" + for node in graph.nodes(): + graph.assign(node, 0) + return graph + + +def PASData(graph: IRGraph, resource): + """ + Data Parallel + """ + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + batch_dim = node.get_batch_dims()[0] + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=0, dim=batch_dim, num=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def PASCol(graph: IRGraph, resource): + """ + Linear Column Parallel + """ + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + if isinstance(node, IRFwOperation): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=0, num=resource.ngpus + ) + if sub_nodes is None: # partition fails + # graph.assign(node, list(range(resource.ngpus))) + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def PASRow(graph: IRGraph, resource): + """ + Linear Column Parallel + """ + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + if isinstance(node, IRFwOperation): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=1, num=resource.ngpus + ) + if sub_nodes is None: # partition fails + # graph.assign(node, list(range(resource.ngpus))) + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def PASHybrid(graph: IRGraph, resource): + """ + Linear Hybrid Parallelism (Megatron) + """ + for idx, node in enumerate(graph.nodes()): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + if isinstance(node, IRFwOperation): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=(idx+1)%2, num=resource.ngpus + ) + if sub_nodes is None: # partition fails + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + print(graph.extra_repr()) + return graph + + +def PASMegatron(graph: IRGraph, resource): + """ + Tensor + Data Parallelism + """ + tp = 2 + dp = resource.ngpus // tp + for idx, node in enumerate(graph.nodes()): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + continue + if isinstance(node, IRFwOperation): + sub_nodes = list() + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=(idx+1)%2, num=tp) + if tp_nodes is not None: + for tp_node in tp_nodes: + algo = tp_node.algorithms('dim') + dp_nodes = graph.partition(tp_node, algo, idx=0, dim=0, num=dp) + sub_nodes += dp_nodes + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + # print(graph.extra_repr()) + return graph + + +def PASOptimal(graph: IRGraph, resource): + """ + Square Linear optimal parallelism (4GPU) + """ + assert resource.ngpus == 4, "only apply to 4 GPU case" + + # replicate data operation + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + + # replicate loss operation + fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] + loss = fnodes[-1] + sub_nodes = graph.replicate(loss, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + + fnodes = fnodes[:-1] + # linear0 config + config0 = [ + None, + dict(idx=1, dim=0, num=4) # col + ] + # linear1 config + config1 = [ + dict(idx=0, dim=1, num=2), # row + dict(idx=1, dim=0, num=2), # col + ] + # linear2 config + config2 = [ + dict(idx=0, dim=0, num=2), # dat + dict(idx=0, dim=1, num=2), # row + ] + # linear3 config + config3 = [ + dict(idx=0, dim=0, num=2), # dat + dict(idx=0, dim=1, num=2), # row + ] + configs = [config0, config1, config2, config3] + assert len(fnodes) == len(configs) + for fnode, config in zip(fnodes, configs): + all_nodes = [fnode] + for conf in config: + if conf is None: + continue + sub_nodes = list() + for node in all_nodes: + algo = node.algorithms('dim') + nodes = graph.partition(node, algo, **conf) + sub_nodes += nodes + all_nodes = sub_nodes + assert len(all_nodes) == 4 + for idx, node in enumerate(all_nodes): + graph.assign(node, idx) + return graph + diff --git a/examples/poisson/policy/naive.py b/examples/poisson/policy/spmd.py similarity index 52% rename from examples/poisson/policy/naive.py rename to examples/poisson/policy/spmd.py index d58e2d5e..03720306 100644 --- a/examples/poisson/policy/naive.py +++ b/examples/poisson/policy/spmd.py @@ -1,20 +1,27 @@ from cube.graph import IRGraph from cube.graph.function import IRConv2D -def PAS(graph: IRGraph, resource): + +def PASReplica(graph: IRGraph, resource) -> IRGraph: + for node in graph.nodes(): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph + + +def PASHaloConv(graph: IRGraph, resource) -> IRGraph: for node in graph.nodes(): if isinstance(node, IRConv2D): sub_nodes = list() algo = node.algorithms('halo') - Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus // 2)) + Wnodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus // 2) for Wnode in Wnodes: algo = Wnode.algorithms('halo') - Hnodes = graph.partition(Wnode, algo, config=dict(idx=0, dim=2, num=2)) + Hnodes = graph.partition(Wnode, algo, idx=0, dim=2, num=2) sub_nodes += Hnodes else: sub_nodes = graph.replicate(node, times=resource.ngpus) - # sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) - print(graph.extra_repr()) return graph diff --git a/examples/poisson/sci.py b/examples/poisson/sci.py index 47e7b900..e6415f0e 100644 --- a/examples/poisson/sci.py +++ b/examples/poisson/sci.py @@ -1,4 +1,9 @@ -from typing import List +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/poisson/sci.py +""" import torch import torch.nn.functional as F @@ -8,14 +13,7 @@ from cube.runtime.syndata import SciLoopVariables import cube -from examples.poisson.policy.naive import PAS - -""" -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - examples/poisson/sci.py -""" +from examples.poisson.policy.spmd import PASHaloConv as PAS class ScientificModel(torch.nn.Module): From 4af35477d166df3372471503310937652461e072 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Jun 2022 13:36:56 +0800 Subject: [PATCH 0872/1892] fix bugs for einop partition on kwarg dimensions --- cube/algorithm/ops/einops.py | 7 +++++-- cube/graph/function/einops.py | 7 +++++-- cube/graph/gener/layout.py | 7 ++++++- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py index 29e43103..f24f86f5 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/einops.py @@ -57,7 +57,7 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IREinops]]: if not self.satisfy(idx, dim, num): return False node: IREinops = self.node - print(node.anno, f'partition: {self._adim}; reduce: {self._reduce.value}') + print(f'{node.anno} | => partition: {self._adim} reduce: {self._reduce.value}') ins, ous = list(), list() for iidx, itensor in enumerate(node.inputs()): @@ -99,7 +99,10 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IREinops]]: for nid in range(num): inputs = [t[nid] for t in ins] outputs = [t[nid] for t in ous] - sub_node: IREinops = node.new(inputs, outputs) + updated_kwargs = dict() + if self._adim in node.kwargs and isinstance(node.kwargs[self._adim], int): + updated_kwargs[self._adim] = node.kwargs[self._adim] // num + sub_node: IREinops = node.new(inputs, outputs, **updated_kwargs) sub_node.infer_shape() sub_nodes.append(sub_node) return sub_nodes diff --git a/cube/graph/function/einops.py b/cube/graph/function/einops.py index aad441a0..de72088b 100644 --- a/cube/graph/function/einops.py +++ b/cube/graph/function/einops.py @@ -65,6 +65,7 @@ from typing import Dict, Iterable, List, Union, Optional, Set, Tuple, Optional import enum import re +import copy import string from cube.ir.cten import IRTensor @@ -566,7 +567,7 @@ def infer_shape(self) -> bool: # f'=> outputs: {self.outputs()}') return True - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + def new(self, inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): """! Construct a new operator sharing same kwargs with new inputs and outputs @@ -577,7 +578,9 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): @return op IRDimop: the new constructed operator """ annos = self._annos_candidates - op = IREinops(self.signature, annos, inputs, self.name, **self.kwargs) + updated_kwargs = copy.copy(self.kwargs) + updated_kwargs.update(kwargs) + op = IREinops(self.signature, annos, inputs, self.name, **updated_kwargs) for idx, output in enumerate(outputs): op.set_output(idx, output) return op diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 44814f6c..1c56b4a4 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -339,7 +339,12 @@ def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): start = subtensor.indmap.get()[dim].start fnele = ftensor.shape[dim] if fnele % snele != 0 or start % snele != 0: - raise RuntimeError(f"dimension split error: full nele: {fnele}, sub nele: {snele}, start: {start}") + print(subtensor, dim) + raise RuntimeError( + f"dimension split error:\n" + f"Full Tensor: {ftensor}\n" + f"full nele: {fnele}, sub nele: {snele}, start: {start}" + ) dchunks[dim].add(fnele // snele) _tindex[tid].append(start // snele) # replica (R) From 705e57060b5eed09e486c837bef41e50ee8241a3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Jun 2022 13:37:22 +0800 Subject: [PATCH 0873/1892] fix split_allgather interface --- cube/runtime/adapter/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index 3132883b..bc4f99c0 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -187,7 +187,7 @@ def allgather_split(tensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch. class SplitAllGather(torch.autograd.Function): @staticmethod - def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int], group): + def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int]): """ ranks should be the global rank """ @@ -203,7 +203,7 @@ def backward(ctx, grad: torch.Tensor): return grad, None, None -def chunk_allgather(tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: +def split_allgather(tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: return SplitAllGather.apply(tensor, dim, ranks) From 30c0d8cb31de6ec29efd49c841326a8cfe2be8d4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Jun 2022 13:38:05 +0800 Subject: [PATCH 0874/1892] enable comment generation --- cube/codegen/codegen.py | 3 +++ cube/ir/cten.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index c255ce5b..72398610 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -228,6 +228,9 @@ def emit_op_call(self, node: IRFwOperation): """ Emit op forward code """ + # insert comment + if node.comment is not None: + self.forward_region.append(f'# {node.comment}') signature = node.signature inputs = [self.tensor_naming(t) for t in node.inputs()] kwargs = {} diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 41c1e96a..c8b18428 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -69,7 +69,9 @@ def __init__(self, self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length+1)] self._mirror = None - self._tag = None + + # the comment for code generation + self._comment: Optional[str] = None # def __eq__(self, other): # if isinstance(other, IRCell): @@ -361,15 +363,16 @@ def get_outputs(cells): return outputs @property - def tag(self) -> Any: - return self._tag + def comment(self) -> Any: + return self._comment - @tag.setter - def tag(self, info: Any): + @comment.setter + def comment(self, info: str): """ Tag an info to the cell """ - self._tag = info + assert isinstance(info, str), "comment only allowed to be string" + self._comment = info def __repr__(self): """ From 791b2758baa79735c95cb1e22b59ba0a200f1393 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Jun 2022 13:38:34 +0800 Subject: [PATCH 0875/1892] megatron attention tensor parallelism --- examples/gsearch/blocks.py | 16 ++++++++++++++-- examples/gsearch/gpt/policy/spmd.py | 24 ++++++++++++++++++------ examples/gsearch/gpt/train.py | 4 ++-- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/examples/gsearch/blocks.py b/examples/gsearch/blocks.py index c34f0b16..96a44b48 100644 --- a/examples/gsearch/blocks.py +++ b/examples/gsearch/blocks.py @@ -55,6 +55,11 @@ def attn_softmax(attn: torch.Tensor): return attn.view(N, h, L, L) +@cube.graph.parser.register('N h L L -> N h L L', name='attn_dropout') +def attn_dropout(attn: torch.Tensor, dropout_p: float): + return torch.nn.functional.dropout(attn, dropout_p, True, False) + + @cube.graph.parser.register('N h L K+, N h K+ d -> L N (h d)', name='attn_context') def attn_context(attn: torch.Tensor, v: torch.Tensor): N, h, L, d = v.size() @@ -66,6 +71,11 @@ def attn_context(attn: torch.Tensor, v: torch.Tensor): return output +@cube.graph.parser.register('L N hd+, E hd+, E -> L N E', name='attn_dense_out') +def attn_dense_out(context: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): + return torch.nn.functional.linear(context, weight, bias) + + class MultiHeadSelfAttention(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): @@ -102,11 +112,13 @@ def forward(self, query: torch.Tensor): # softmax attn = attn_softmax(attn) # dropout - attn = torch.nn.functional.dropout(attn, self.dropout_p, True, False) # (N h) L L -> (N h) L L + attn = attn_dropout(attn, self.dropout_p) # N h L L -> N h L L + # attn = torch.nn.functional.dropout(attn, self.dropout_p, True, False) # N h L L -> N h L L # context context = attn_context(attn, v) # DenseOutput - output = torch.nn.functional.linear(context, self.out_proj, self.out_bias) # L N (h d), E E -> L N E + # output = torch.nn.functional.linear(context, self.out_proj, self.out_bias) # L N (h d), E E -> L N E + output = attn_dense_out(context, self.out_proj, self.out_bias) # L N (h d), E E -> L N E return output diff --git a/examples/gsearch/gpt/policy/spmd.py b/examples/gsearch/gpt/policy/spmd.py index b21d0d35..8cd49120 100644 --- a/examples/gsearch/gpt/policy/spmd.py +++ b/examples/gsearch/gpt/policy/spmd.py @@ -1,5 +1,5 @@ from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation def PASReplica(graph: IRGraph, resource): @@ -16,17 +16,20 @@ def PASMegatron(graph: IRGraph, resource): tp_size = resource.ngpus fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - def tensor_parallelism(node, idx: int, dim: int, num: int): + def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, config=dict(idx=idx, dim=dim, num=num)) + sub_nodes = graph.partition(node, algo, **configs) + if isinstance(comment, str): + for sub_node in sub_nodes: + sub_node.comment = comment assert all(isinstance(n, IRFwOperation) for n in sub_nodes), f"Fail to partition node {node}" for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) return sub_nodes qkvs = [node for node in fnodes if node.name == 'attn_qkv'] - for qkv in qkvs: - tensor_parallelism(qkv, idx=1, dim=0, num=tp_size) + for idx, qkv in enumerate(qkvs): + tensor_parallelism(qkv, f'====> start of transformer {idx}', idx=1, dim=0, num=tp_size) scores = [node for node in fnodes if node.name == 'attn_score'] for score in scores: @@ -36,13 +39,22 @@ def tensor_parallelism(node, idx: int, dim: int, num: int): for softmax in softmaxs: tensor_parallelism(softmax, idx=0, dim=1, num=tp_size) + dropouts = [node for node in fnodes if node.name == 'attn_dropout'] + for dropout in dropouts: + tensor_parallelism(dropout, idx=0, dim=1, num=tp_size) + contexts = [node for node in fnodes if node.name == 'attn_context'] for context in contexts: tensor_parallelism(context, idx=0, dim=1, num=tp_size) + dense_outs = [node for node in fnodes if node.name == 'attn_dense_out'] + for dense in dense_outs: + tensor_parallelism(dense, idx=0, dim=2, num=tp_size) + for node in graph.nodes(): - if isinstance(node, IRFwOperation) and len(node.device) == 0: + if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: rnodes = graph.replicate(node, times=tp_size) for idx, rnode in enumerate(rnodes): graph.assign(rnode, idx) + print(graph.extra_repr()) return graph diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py index f019ef13..a8e22879 100644 --- a/examples/gsearch/gpt/train.py +++ b/examples/gsearch/gpt/train.py @@ -2,7 +2,7 @@ example: OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ examples/gsearch/gpt/train.py """ @@ -12,7 +12,7 @@ from examples.gsearch.gpt.model import GPT from examples.gsearch.gpt.model import GPTDataLoader -from examples.gsearch.gpt.policy.spmd import PASReplica as PAS +from examples.gsearch.gpt.policy.spmd import PASMegatron as PAS import cube from cube.profiler.timer import CudaTimer, print_each_rank From b3d4b18c036a5d648cbc02d29d4f88192d635e2d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Jun 2022 21:20:14 +0800 Subject: [PATCH 0876/1892] init recomputation --- cube/codegen/codegen.py | 3 + cube/graph/graph.py | 7 ++ cube/ir/adapter/adapter.py | 26 ++++++ cube/ir/operator.py | 173 +++++++++++++++++-------------------- 4 files changed, 115 insertions(+), 94 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 72398610..76814883 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -435,6 +435,9 @@ def emit_node(self, node: IRCell, name: str) -> str: if isinstance(node, IRSegment): # emit forward if node.forward: + recompute = any(isinstance(n.recompute, int) for n in node.nodes()) + if recompute: + raise NotImplementedError("recompute mechanism is not supported") body = fsign.format(model=f'model.{name}', inputs=inputs, req_grad=req_grad) code = f'{outputs} = {body}' # emit backward diff --git a/cube/graph/graph.py b/cube/graph/graph.py index c2576a85..8d785ca7 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -11,6 +11,7 @@ import copy from cube.ir.cten import IRTensor, IRCell +from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.ir.adapter import IRAdapter from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -650,6 +651,12 @@ def add_schedule(self, nodes: List[IRCell]) -> bool: post.add_predecessor(input_index=-1, cell=prev) return True + def recompute(self, nodes: List[IRFwOperation]): + assert all(isinstance(fnode, IRFwOperation) for fnode in nodes), "require forward operations" + recompute_group_id = IDGenerator().gen_cell_id() + for fnode in nodes: + fnode.recompute = recompute_group_id + def set_order(self, seq: List[IRCell]): """ Set a topological order for IRGraph, which requires seq: diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index bd0d09b5..b90f3fbe 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -28,6 +28,9 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): device.update(set(tensor.device)) self.device = list(device) + # recompute group id + self._recompute = None + # setup whether this adapter is for forward stage is_fw = any(not t.is_grad() for t in self.inputs() + self.outputs()) is_bw = any(t.is_grad() for t in self.inputs() + self.outputs()) @@ -75,6 +78,29 @@ def forward(self) -> bool: """ return self._forward + @property + def recompute(self) -> Optional[int]: + """! + Get recompute group id. + To enable recompute, a recompute group refers to a sequence of operators that + will perform recompute optimization. + + @return group_id Optional[int]: None if no recompute, else a group id. + """ + return self._recompute + + @recompute.setter + def recompute(self, group_id: Optional[int]): + """! + Set recompute group + + @param group_id Optional[int]: recompute group id. None indicates no group is applied + """ + assert group_id is None or isinstance(group_id, int), "Expect None or int" + if isinstance(group_id, int) and self._recompute is not None: + assert self._recompute == group_id, "The operator is set to recompute in another recompute group." + self._recompute = group_id + def dispatch(self, devid: int, for_mirror=True): """ Get Adapter for a specific rank diff --git a/cube/ir/operator.py b/cube/ir/operator.py index c3e36ad4..7ee1bf2d 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -37,23 +37,27 @@ def replicate(self): class IRFwOperation(IRCell): + """ + Forward operation + """ def __init__(self, name: str, signature: str, input_length: int, output_length: int): - """ - Create a node with name (variable name) and module type (module_name) + """! + Create a forward operation. - Args: - name (str): the op semantic name - signature (str): the op signature, e.g., torch.functional.nn.linear - input_length (int): the number of inputs for the op - output_length (int): the number of outputs for the op + @param name str: the name of forward operation + @param signature str: the signature of the forward operation + @param input_length int: number of inputs + @param output_length int: number of outputs """ # additional argument self.kwargs = dict() + # recompute schedule + self._recompute = None super().__init__(name, signature, input_length, output_length, init_outputs=False) outputs = [IRFullTensor() for _ in range(output_length)] for idx, output in enumerate(outputs): @@ -65,6 +69,29 @@ def infer_shape(self): """ raise NotImplementedError + @property + def recompute(self) -> Optional[int]: + """! + Get recompute group id. + To enable recompute, a recompute group refers to a sequence of operators that + will perform recompute optimization. + + @return group_id Optional[int]: None if no recompute, else a group id. + """ + return self._recompute + + @recompute.setter + def recompute(self, group_id: Optional[int]): + """! + Set recompute group + + @param group_id Optional[int]: recompute group id. None indicates no group is applied + """ + assert group_id is None or isinstance(group_id, int), "Expect None or int" + if isinstance(group_id, int) and self._recompute is not None: + assert self._recompute == group_id, "The operator is set to recompute in another recompute group." + self._recompute = group_id + def algorithms(self, tag: Optional[str] = None): """ get algorithm from algorithm factory @@ -88,12 +115,14 @@ def algorithms(self, tag: Optional[str] = None): return template(self) def replicate(self): - """ - Replicate the Operation + """! + Replicate the forward operation. + The operator id, recompute and comment attribute will also be replicated. + + @return replica IRFwOperation: the replicated operator """ cpy = copy.copy(self) cpy._device = list() - # cpy._id = IDGenerator().gen_cell_id() # reset input and output cpy._inputs = [None] * len(self.inputs()) for idx, input in enumerate(self.inputs()): @@ -102,13 +131,12 @@ def replicate(self): for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None - cpy._tag = None cpy.clear_predecessor() cpy.clear_successor() return cpy - def gen_backward(self): - """ + def gen_backward(self) -> IRCell: + """! Generate backward operator for this forward operator. Note by calling this API, this forward operator must be @@ -120,18 +148,7 @@ def gen_backward(self): if self.mirror is not None: raise RuntimeError( "Backward Op already generated. Use self.mirror.update() instead.") - bnode = IRBpOperation( - data_num=len(self.inputs()), - grad_num=len(self.outputs()) - ) - for idx, itensor in enumerate(self.inputs()): - grad = itensor.grad if isinstance(itensor, IRSubTensor) else None - bnode.set_data(idx, itensor) - bnode.set_output(idx, grad) - for idx, otensor in enumerate(self.outputs()): - grad = otensor.grad if isinstance(otensor, IRSubTensor) else None - bnode.set_input(idx, grad) - IRCell.make_pair(self, bnode) + bnode = IRBpOperation(self) return bnode def __repr__(self): @@ -150,75 +167,26 @@ def module_repr(self) -> str: class IRBpOperation(IRCell): + """ + Backward operation + """ - def __init__(self, data_num: int, grad_num, name='backward'): + def __init__(self, fwop: IRFwOperation): """ - Args: - data_num (int): corresponding forward input length - grad_num (int): corresponding forward output length + Create dummy backward node for forward inputs and forward outputs + + @param fwop IRFwOperation: forward operator """ - signature = 'torch.autograd.backward' - self.data_num = data_num - self.grad_num = grad_num - self._datas = [None] * data_num + assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" + finputs, foutputs = fwop.inputs(), fwop.outputs() super().__init__( - name, signature, - input_length=grad_num, - output_length=data_num, - init_outputs=False + 'backward', 'torch.autograd.grad', + len(foutputs), len(finputs), init_outputs=False ) - - def replicate(self): - """ - Replicate the backward op - """ - cpy = copy.copy(self) - cpy._device = list() - cpy._id = IDGenerator().gen_cell_id() - # reset input and output - cpy._inputs = [None] * len(self.inputs()) - for idx, input in enumerate(self.inputs()): - cpy.set_input(idx, input) - cpy._outputs = [None] * len(self.outputs()) - for idx, output in enumerate(self.outputs()): - cpy.set_output(idx, output) - cpy._mirror = None - cpy._tag = None - cpy.clear_predecessor() - cpy.clear_successor() - return cpy - - def datas(self, index: Optional[int] = None) -> Union[List[Any], Any]: - """ - Forward inputs - """ - if index is None: - return copy.copy(self._datas[:self.data_num]) - if index >= self.data_num: - raise RuntimeError( - f"Set the input out of range ({index} >= {self.data_num})" - ) - return self._datas[index] - - def set_data(self, data_index: int, val: Any): - """ - Set the node inputs[input_index] with the tensor - - Args: - val: Union[IRTensor, Any] - - Return: - the set tensor - """ - if data_index >= self.data_num: - raise RuntimeError( - f"Set the input out of range ({data_index} >= {self.data_num})" - ) - val = copy.copy(val) - if isinstance(val, IRTensor): - val.attach_cell(self) - self._datas[data_index] = val - return val + # pair forward op and backward op + IRCell.make_pair(self, fwop) + # set inputs and outputs + self.update() def update(self): """ @@ -238,14 +206,32 @@ def update(self): assert isinstance(fnode, IRFwOperation), "Cannot find corresponding IRFwOperation" for idx, itensor in enumerate(fnode.inputs()): grad = itensor.grad if isinstance(itensor, IRSubTensor) else None - self.set_data(idx, itensor) self.set_output(idx, grad) for idx, otensor in enumerate(fnode.outputs()): grad = otensor.grad if isinstance(otensor, IRSubTensor) else None self.set_input(idx, grad) - def __repr__(self): - dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, inputs={self.inputs()}, datas={self.datas()}, outputs={self.outputs()})' + def replicate(self): + """ + Replicate the backward op + """ + cpy = copy.copy(self) + cpy._device = list() + cpy._id = IDGenerator().gen_cell_id() + # reset input and output + cpy._inputs = [None] * len(self.inputs()) + for idx, input in enumerate(self.inputs()): + cpy.set_input(idx, input) + cpy._outputs = [None] * len(self.outputs()) + for idx, output in enumerate(self.outputs()): + cpy.set_output(idx, output) + cpy._mirror = None + cpy.clear_predecessor() + cpy.clear_successor() + return cpy + + def __repr__(self) -> str: + dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, inputs={self.inputs()}, outputs={self.outputs()})' return dscp @@ -256,7 +242,7 @@ def __init__(self, data_num: int, batch_dims: Tuple[int], name='dataloader'): raise RuntimeError("Expected each output data has a specified batch dim") signature = 'dataloader.__next__' super().__init__(name, signature, 0, data_num) - self.batch_dims = batch_dims + self.batch_dims = tuple(batch_dims) def replicate(self): """ @@ -273,7 +259,6 @@ def replicate(self): for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None - cpy._tag = None cpy.clear_predecessor() cpy.clear_successor() return cpy From c6f106ecf8cfda320cb6a9989a18b202d82c05bc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Jun 2022 21:20:34 +0800 Subject: [PATCH 0877/1892] init recompute policy --- examples/attention/attention.py | 180 ---------------------------- examples/attention/policy/naive.py | 11 -- examples/gsearch/gpt/policy/spmd.py | 21 +++- examples/gsearch/gpt/train.py | 1 + 4 files changed, 21 insertions(+), 192 deletions(-) delete mode 100644 examples/attention/attention.py delete mode 100644 examples/attention/policy/naive.py diff --git a/examples/attention/attention.py b/examples/attention/attention.py deleted file mode 100644 index c7873501..00000000 --- a/examples/attention/attention.py +++ /dev/null @@ -1,180 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/attention/attention.py - -OMP_NUM_THREADS=1 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/attention/attention.py -""" - -import torch -from torch import nn - -import cube -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - -from examples.attention.policy.naive import PAS - - -@cube.graph.parser.register('L^ N E^, (3 h d^) E^ -> L^ N (h d^)') -def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, h: int, - scale: float, dropout: float, training: bool): - """ - L: sequence length - N: batch size - E: embedding size - x: hidden state: [L, N, E] - wqkv: qkv weight: [3 * (num_head * dim_head), E] - dropout: float - h: int: number of heads - """ - num_head = h - L, N = x.shape[0], x.shape[1] - dim_head = wqkv.shape[0] // 3 // num_head - # L N E, (3 h d) E -> L N (3 h d) - qkv = torch.nn.functional.linear(x, wqkv, None) - # L N (3 h d) -> L N (h d), L N (h d), L N (h d) - q, k, v = qkv.chunk(3, dim=-1) - # L N (h d) -> L (N h) d - q = q.contiguous().view(L, (N * num_head), dim_head) - # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) - # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) - # L (N h) d -> (N h) L d - q = q.transpose(0, 1) - # L (N h) d -> (N h) L d - k = k.transpose(0, 1) - # L (N h) d -> (N h) L d - v = v.transpose(0, 1) - # (N h) L d, 1 -> (N h) L d - q = q * scale - # (N h) L d -> (N h) d L - k = k.transpose(-2, -1) - # (N h) L d, (N h) d L -> (N h) L L - attn = torch.bmm(q, k) - - # attention mask - # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N * num_head), L, L) - - # (N h) L L -> (N h) L L - attn = torch.nn.functional.softmax(attn, dim=-1) - # (N h) L L -> (N h) L L - if training: - attn = torch.nn.functional.dropout(attn, dropout, True, False) - # (N h) L L, (N h) L d -> (N h) L d - output = torch.bmm(attn, v) - # (N h) L d -> L (N h) d - output = output.transpose(0, 1).contiguous() - # L (N h) d -> L N (h d) - output = output.view(L, N, num_head * dim_head) - return output - - -class MultiHeadSelfAttention(nn.Module): - - def __init__(self, seq_len, embed_dim, heads, dropout: float): - super().__init__() - - self.seq_len = seq_len - self.embed_dim = embed_dim - self.num_head = heads - self.dim_head = embed_dim // heads - self.scale = self.dim_head ** -0.5 - - self.wqkv = torch.nn.Parameter(torch.empty( - 3 * embed_dim, embed_dim - )) - self.wout = torch.nn.Parameter(torch.empty( - embed_dim, embed_dim - )) - self.dropout = dropout - - def forward(self, x): - """ - x: [L, N, E]: seq_len, batch_size, embedding dimension - output: [L, N, E] - """ - # L N E, (3 h d) E -> L N (h d) - output = attnfc1(x, self.wqkv, self.num_head, - self.scale, self.dropout, self.training) - # L N (h d), E (h d) -> L N E - output = torch.nn.functional.linear(output, self.wout) - - loss = torch.sum(output) - return loss - - -def train(): - L = 512 # seq len - N = 32 # batch size - # configs: [hidden size, num_head] - # E, num_head = [1536, 16] # 1.2B model - # E, num_head = [1920, 20] # 2.5B model - # E, num_head = [2304, 24] # 4.2B model - E, num_head = [3072, 32] # 8.7B model - - - model = MultiHeadSelfAttention( - seq_len=L, embed_dim=E, heads=num_head, dropout=0.5 - ) - model = cube.SemanticModel( - model, input_shapes=([L, N, E],), - ) - - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([L, N, E],), - dtypes=(torch.float32,), - batch_dims=(1,) - ) - - @cube.compile(model, dataloader, PAS=PAS) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - model = model.get_gen_module() - - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - memory_summary() - - -if __name__ == '__main__': - - cube.init() - train() diff --git a/examples/attention/policy/naive.py b/examples/attention/policy/naive.py deleted file mode 100644 index 164de037..00000000 --- a/examples/attention/policy/naive.py +++ /dev/null @@ -1,11 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.adapter.adapter import IRAdapter -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation - - -def PAS(graph: IRGraph, resource): - print(graph.extra_repr()) - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - return graph diff --git a/examples/gsearch/gpt/policy/spmd.py b/examples/gsearch/gpt/policy/spmd.py index 8cd49120..3a6b4f62 100644 --- a/examples/gsearch/gpt/policy/spmd.py +++ b/examples/gsearch/gpt/policy/spmd.py @@ -3,6 +3,9 @@ def PASReplica(graph: IRGraph, resource): + """ + Single device test + """ assert resource.ngpus == 1 print(graph.extra_repr()) for node in graph.nodes(): @@ -12,7 +15,9 @@ def PASReplica(graph: IRGraph, resource): def PASMegatron(graph: IRGraph, resource): - + """ + Megatron tensor parallelism (attention) + """ tp_size = resource.ngpus fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] @@ -58,3 +63,17 @@ def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): graph.assign(rnode, idx) print(graph.extra_repr()) return graph + + +def PASRecompute(graph: IRGraph, resource): + """ + Recompute parallelism test + """ + assert resource.ngpus == 1 + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + graph.recompute(fnodes) + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py index a8e22879..07d0ac0a 100644 --- a/examples/gsearch/gpt/train.py +++ b/examples/gsearch/gpt/train.py @@ -64,6 +64,7 @@ def train_iter(model, dataloader): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-40, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-40) memory_summary() From 7eb191249aa3cb0c51c851879eb4e52c3b11659a Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Mon, 20 Jun 2022 06:38:42 +0000 Subject: [PATCH 0878/1892] Merged PR 1386: Rename ExecutionPlan there seemed to be a typo in the class name, rename it. --- cube/codegen/codegen.py | 10 +++++----- cube/compiler.py | 4 ++-- cube/execplan/__init__.py | 2 +- cube/execplan/execplan.py | 2 +- cube/execplan/planpass/fusion.py | 4 ++-- cube/execplan/planpass/grouping.py | 4 ++-- cube/execplan/planpass/planpass.py | 4 ++-- cube/search/sampler.py | 4 ++-- examples/mlp/policy/st_search.py | 4 ++-- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 76814883..6ace2048 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -16,7 +16,7 @@ from cube.graph.graph import IRGraph, IRSegment from cube.graph.schedule import IRScheduleStrategy -from cube.execplan import ExectuionPlan +from cube.execplan import ExecutionPlan from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -28,8 +28,8 @@ class CodeGen: """ Generate code for the model """ - def __init__(self, execplan: ExectuionPlan): - if not isinstance(execplan, ExectuionPlan): + def __init__(self, execplan: ExecutionPlan): + if not isinstance(execplan, ExecutionPlan): raise TypeError("execplan should be ExecutionPlan") self.execplan = execplan @@ -60,7 +60,7 @@ class ModelCodeGen(CodeGen): Generate model code """ - def __init__(self, execplan: ExectuionPlan): + def __init__(self, execplan: ExecutionPlan): super().__init__(execplan) # model full code self.init_code: List[str] = [ @@ -332,7 +332,7 @@ def clear(self): class ScheduleCodeGen(CodeGen): - def __init__(self, execplan: ExectuionPlan): + def __init__(self, execplan: ExecutionPlan): super().__init__(execplan) # model full code self.init_code: List[str] = [ diff --git a/cube/compiler.py b/cube/compiler.py index 349fea68..51827810 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -13,7 +13,7 @@ from cube.logics.pool import SchedulePool from cube.logics.translator import LogicTranslator -from cube.execplan import ExectuionPlan +from cube.execplan import ExecutionPlan from cube.execplan.planpass.fusion import DiffFusion from cube.execplan.planpass.grouping import Grouping @@ -172,7 +172,7 @@ def decorator(fn: Callable) -> Callable: print(graph.schedule_plan) # to execution plan - execplan = ExectuionPlan(graph) + execplan = ExecutionPlan(graph) # plan pass for communication optimization start = time.time() diff --git a/cube/execplan/__init__.py b/cube/execplan/__init__.py index a1160701..c6d0899c 100644 --- a/cube/execplan/__init__.py +++ b/cube/execplan/__init__.py @@ -1 +1 @@ -from cube.execplan.execplan import ExectuionPlan \ No newline at end of file +from cube.execplan.execplan import ExecutionPlan \ No newline at end of file diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 691257c1..4e286401 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -8,7 +8,7 @@ from cube.graph.graph import IRGraph, IRSegment -class ExectuionPlan: +class ExecutionPlan: def __init__(self, graph: IRGraph): assert isinstance(graph, IRGraph), "Expected an IRGraph" diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 5e1cb2d1..83a07e2d 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -3,7 +3,7 @@ from cube.ir.adapter import IRAdapter -from cube.execplan import ExectuionPlan +from cube.execplan import ExecutionPlan from cube.execplan.planpass.planpass import PlanPass from cube.ir.adapter.prim import IRAdapterPrim @@ -18,7 +18,7 @@ class DiffFusion(PlanPass): @staticmethod - def apply(execplan: ExectuionPlan) -> ExectuionPlan: + def apply(execplan: ExecutionPlan) -> ExecutionPlan: """ Fuse the non-differentiable adapters into differentiable adapters. """ diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 64cc254a..d2c20802 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -3,7 +3,7 @@ """ from typing import List, Dict, Tuple -from cube.execplan import ExectuionPlan +from cube.execplan import ExecutionPlan from cube.execplan.planpass.planpass import PlanPass from cube.ir.adapter import IRAdapter from cube.ir.operator import IRBpOperation, IRFwOperation @@ -13,7 +13,7 @@ class Grouping(PlanPass): @staticmethod - def apply(execplan: ExectuionPlan) -> ExectuionPlan: + def apply(execplan: ExecutionPlan) -> ExecutionPlan: """ Group contiguous differentiable operators segments """ diff --git a/cube/execplan/planpass/planpass.py b/cube/execplan/planpass/planpass.py index 558de959..3d079b2b 100644 --- a/cube/execplan/planpass/planpass.py +++ b/cube/execplan/planpass/planpass.py @@ -1,8 +1,8 @@ -from cube.execplan import ExectuionPlan +from cube.execplan import ExecutionPlan class PlanPass: @staticmethod - def apply(execplan: ExectuionPlan) -> ExectuionPlan: + def apply(execplan: ExecutionPlan) -> ExecutionPlan: raise NotImplementedError diff --git a/cube/search/sampler.py b/cube/search/sampler.py index 08aa873f..41e32b4e 100644 --- a/cube/search/sampler.py +++ b/cube/search/sampler.py @@ -5,7 +5,7 @@ from cube.graph.graph import IRGraph from cube.ir.operator import IRFwOperation, IRBpOperation from cube.ir.cten import IRCell -from cube.execplan import ExectuionPlan +from cube.execplan import ExecutionPlan from multiprocessing import Pool import numpy as np @@ -352,7 +352,7 @@ def _run(seqs: List[List[IRCell]]) -> Dict[int, Tuple[int, List]]: graph = IRGraph([], [], [], 'search') for seq in seqs: graph._nodes = seq - execplan = ExectuionPlan(graph) + execplan = ExecutionPlan(graph) span, mem = execplan.analyze(map2time=Estimator.map2time, map2mem=Estimator.map2mem) if mem not in bucket: bucket[mem] = (span, copy.copy(seq)) diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py index eacfa9e7..3aa140cf 100644 --- a/examples/mlp/policy/st_search.py +++ b/examples/mlp/policy/st_search.py @@ -3,7 +3,7 @@ from cube.graph import IRGraph from cube.ir.operator import IRFwOperation from cube.ir.cten import IRCell -from cube.execplan import ExectuionPlan +from cube.execplan import ExecutionPlan from cube.search.sampler import Estimator, Sampler, SpatialSampler, TemporalSampler, Searcher @@ -67,7 +67,7 @@ def PAS(graph: IRGraph, resource): Searcher.search(seqs, bucket, n_worker=n_worker) for mem, (span, seq) in bucket.items(): sgraph._nodes = seq - execplan = ExectuionPlan(sgraph) + execplan = ExecutionPlan(sgraph) execplan.analyze(map2time=Estimator.map2time, outfile=f'plan.mem{mem}.png') cnt += len(seqs) print(f'done search on {cnt} sequences') From 0aad3fd191511dbcc6c15f89e314e91aaaffb781 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 20 Jun 2022 17:16:22 +0800 Subject: [PATCH 0879/1892] enable residual with different partition layout --- cube/algorithm/ops/einops.py | 43 ++++++----- cube/graph/function/einops.py | 7 -- cube/graph/gener/gen.py | 32 +++++--- cube/graph/graph.py | 126 +++++++++++++++++------------- cube/ir/cten.py | 40 ++++++---- cube/ir/tensor.py | 141 ++++++++++++++++++++++++++-------- cube/logics/model.py | 9 +++ cube/logics/translator.py | 2 +- 8 files changed, 259 insertions(+), 141 deletions(-) diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/einops.py index f24f86f5..be4be240 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/einops.py @@ -1,9 +1,8 @@ -from typing import List, Dict, Optional - -from cube.algorithm.utils import split_axis, split_value +from typing import List, Optional from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.einops import IREinops, DimAnno +from cube.ir.tensor import IRSubTensor class DimSplitEinops(GenericDistAlgo): @@ -38,10 +37,11 @@ def satisfy(self, idx: int, dim: int, num: int) -> bool: node: IREinops = self.node ninputs = len(node.inputs()) - if idx >= ninputs or idx < 0-ninputs: - return False - if node.inputs(idx).shape is None or abs(dim) >= len(node.inputs(idx).shape): - return False + idx = idx if idx >= 0 else idx + ninputs + assert idx < ninputs, f"index out of boundary: {idx} >= {ninputs}" + assert isinstance(node.inputs(idx), IRSubTensor), f"partitioning on a non-tensor input" + dim = dim if dim >= 0 else dim + node.inputs(idx).ndims + assert dim < node.inputs(idx).ndims, f"dimension output of boundary: {dim} >= {node.inputs(idx).ndims}" # due to implementation limits, we only partition the first annotated dimension # for inner-dimension cases. self._adim: str = node.anno.inputs(idx).dims[dim].identifiers[0] @@ -54,32 +54,39 @@ def satisfy(self, idx: int, dim: int, num: int) -> bool: return True def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IREinops]]: - if not self.satisfy(idx, dim, num): - return False + node: IREinops = self.node - print(f'{node.anno} | => partition: {self._adim} reduce: {self._reduce.value}') + satisfy = self.satisfy(idx, dim, num) + print(f'partition {node.name}: {node.anno} | dim: {self._adim} reduce: {self._reduce.value}') + if not satisfy: + return None ins, ous = list(), list() for iidx, itensor in enumerate(node.inputs()): + if not isinstance(itensor, IRSubTensor): + ins.append([itensor] * num) + continue shape_anno = node.anno.inputs(iidx) split_dims = shape_anno.getdims(self._adim) assert len(split_dims) <= 1, f"find split dims ({self._adim}) more than 1: {shape_anno}" if len(split_dims) == 1: dim = split_dims[0] # split axis - sub_tensors = split_axis(itensor, dim, num) - ins.append(sub_tensors) + ins.append(itensor.split_dim(dim, num)) else: # replicate if no split dimension of this tensor # ins.append([itensor] * num) # ad-hoc FIXME: for linear function Ax+b of splitting reduction dimension, b should # be splitted by value dimension. if self._reduce == DimAnno.ReduceType.Sum: - ins.append(split_value(itensor, num)) + ins.append(itensor.split_val(num)) else: - ins.append([itensor] * num) - + ins.append(itensor.replicate(num)) + for oidx, otensor in enumerate(node.outputs()): + if not isinstance(otensor, IRSubTensor): + ous.append([otensor] * num) + continue shape_anno = node.anno.outputs(oidx) split_dims = shape_anno.getdims(self._adim) assert len(split_dims) <= 1, f"find split dims ({self._adim}) more than 1: {shape_anno}" @@ -87,13 +94,11 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IREinops]]: if self._reduce == DimAnno.ReduceType.Dim: assert len(split_dims) == 1, f"expect only one spatial dimension in output tensor but got {len(split_dims)}" dim = split_dims[0] - sub_tensors = split_axis(otensor, dim, num) - ous.append(sub_tensors) + ous.append(otensor.split_dim(dim, num)) # split numerical dimension else: assert len(split_dims) == 0, f"expect no numerical dimension in output tensor but got {len(split_dims)}" - sub_tensors = split_value(otensor, num) - ous.append(sub_tensors) + ous.append(otensor.split_val(num)) sub_nodes = list() for nid in range(num): diff --git a/cube/graph/function/einops.py b/cube/graph/function/einops.py index de72088b..46cd498b 100644 --- a/cube/graph/function/einops.py +++ b/cube/graph/function/einops.py @@ -106,13 +106,6 @@ def name(self) -> str: return self._identifiers[0] return '(' + ' '.join(self._identifiers) + ')' - def length(self, identifier: str) -> Optional[int]: - """ - Return the integer of identifer - """ - assert identifier in self._identifiers, f"identifier {identifier} not in {self}" - return self._length[identifier] - @property def identifiers(self) -> Tuple[str]: return self._identifiers diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 06cfa8ae..1675e36c 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -96,6 +96,14 @@ def gen_weight(graph: IRGraph) -> IRGraph: @staticmethod def gen_activation(graph: IRGraph) -> IRGraph: + """! + Generate adapter for activation tensors. + The forward/backward adapter is inserted before the first consumers of its full tensor. + + @param graph IRGraph: the graph the requires for adapter. + + @return graph IRGraph: the (inplace) modified graph with activation adapters. + """ for ftensor in graph.full_tensors(): # backward will gen in forward if ftensor.is_param() or ftensor.is_grad(): @@ -103,16 +111,18 @@ def gen_activation(graph: IRGraph) -> IRGraph: adapters = IRAdapterGener.gen_fulltensor(ftensor) if len(adapters) == 0: continue + # insert forward adapter + fidx = min([graph.nodes().index(c) for c in ftensor.consumers]) + for fadapter in adapters: + graph._nodes.insert(fidx, fadapter) + # insert bacward adapter + bidx = None if ftensor.grad is None else min([graph.nodes().index(c) for c in ftensor.grad.consumers]) for fadapter in adapters: - # insert forward adapter - idx = min([graph.nodes().index(c) for c in ftensor.consumers]) - graph._nodes.insert(idx, fadapter) # insert backward adapter badapter: IRAdapter = fadapter.mirror if badapter is not None: - grad: Optional[IRFullTensor] = ftensor.grad - idx = min([graph.nodes().index(c) for c in grad.consumers]) - graph._nodes.insert(idx, badapter) + assert isinstance(bidx, int), "have backward adapter but no gradient required." + graph._nodes.insert(bidx, badapter) return graph @staticmethod @@ -273,7 +283,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: stop = start + oslicer.stop - oslicer.start indmap.append(slice(start, stop, 1)) valmap = ValueMap(0, 1) - common.attach_cell(subtensor._cell) + common.cell = subtensor.cell prims.append(SelectPrim(tensor, indmap, valmap, common)) return prims # check local + remote @@ -282,7 +292,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: if not itensor.overlap(subtensor): continue common = itensor.common(subtensor) - common.attach_cell(itensor._cell) + common.cell = itensor.cell # print(f'get common: {common.extra_repr()}') intersections.append(common) if common == itensor: @@ -307,7 +317,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: mtensor = tensor if tensor.device != subtensor.device: mtensor = copy.copy(tensor) - mtensor.attach_cell(subtensor._cell) + mtensor.cell = subtensor.cell prims.append(MovePrim(tensor, mtensor)) tmoved.append(mtensor) @@ -329,7 +339,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: vid = min(vid1, vid2) // 2 valmap = ValueMap(vid, t1.valmap.chunk_num // 2) out = subtensor.parent.select(t1.indmap, valmap, t1.shape) - out.attach_cell(subtensor._cell) + out.cell = subtensor.cell prims.append(SumPrim([t1, t2], out)) merged = True break @@ -355,7 +365,7 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: indmap = IndexMap(tuple(indmap)) valmap = t1.valmap out = t1.parent.select(indmap, valmap, indmap.shape) - out.attach_cell(subtensor._cell) + out.cell = subtensor.cell cdim = list(cat_dim.keys())[0] prims.append(MergeDimPrim(cat_dim[cdim], out, cdim)) merged = True diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 8d785ca7..b5cf6b1f 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -322,11 +322,13 @@ def attach(self, node: IRCell, index, reset_dependency=False): if isinstance(node, IRAdapter): return # update consumer - itensors = [] + itensors: List[IRSubTensor] = [] for itensor in node.inputs(): if isinstance(itensor, IRSubTensor) and itensor not in itensors: itensors.append(itensor) for itensor in itensors: + if itensor.parent._id not in self._full_tensors: + self._full_tensors[itensor.parent._id] = itensor.parent idx = 0 for consumer in itensor.parent.consumers: if self.nodes().index(consumer) < index: @@ -335,11 +337,13 @@ def attach(self, node: IRCell, index, reset_dependency=False): break itensor.parent.add_consumer(node, itensor, idx) # update producer - otensors = [] + otensors: List[IRSubTensor] = [] for otensor in node.outputs(): if isinstance(otensor, IRSubTensor) and otensor not in otensors: otensors.append(otensor) for otensor in otensors: + if otensor.parent._id not in self._full_tensors: + self._full_tensors[itensor.parent._id] = itensor.parent idx = 0 for producer in otensor.parent.producers: if self.nodes().index(producer) < index: @@ -421,91 +425,95 @@ def get_outputs(nodes: List[IRCell]): outputs.append(output) return outputs - ## Parallel Policy Primitives ## + ##### Partition Primitives ##### - def replicate(self, op: IRCell, times=1, reset_dependency=True) -> Optional[List[IRCell]]: + def replicate(self, node: Union[IRFwOperation, IRDataOperation], + times=1, reset_dependency=True) -> Optional[List[IRCell]]: """ - Replicate a forward or data operation multiple times. + Partition Primitive: + - replicate: replicate a forward or data operation multiple times. + + Each input and output will be replicated with no gradient accumulation. The backward of the forward operation will automatically be replicated. + + @param: node: Union[IRFwOperation, IRDataOperation] """ - if not (isinstance(op, IRFwOperation) or isinstance(op, IRDataOperation)): + if not isinstance(node, (IRFwOperation, IRDataOperation)): raise TypeError("Expected op to be forward op or data op") if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") - if op not in self.nodes(): - raise RuntimeError(f"Op {op} not exsits") + if node not in self.nodes(): + raise RuntimeError(f"Op {node} not exsits") - fnodes = [op.replicate() for _ in range(times - 1)] + fnodes = [node.replicate() for _ in range(times - 1)] # insert forward - fidx = self.nodes().index(op) + fidx = self.nodes().index(node) for idx, fnode in enumerate(fnodes): self.attach(fnode, fidx + idx + 1) # insert backward - if isinstance(op.mirror, IRBpOperation): + if isinstance(node.mirror, IRBpOperation): for fnode in fnodes: fnode.gen_backward() bnodes = [fnode.mirror for fnode in fnodes][::-1] - bidx = self.nodes().index(op.mirror) + bidx = self.nodes().index(node.mirror) for idx, bnode in enumerate(bnodes): self.attach(bnode, bidx + idx) if reset_dependency: self.reset_dependency() - return [op] + fnodes + return [node] + fnodes - def partition(self, op: IRCell, algo: GenericDistAlgo, **config) -> Optional[List[IRCell]]: + def partition(self, node: Union[IRFwOperation, IRDataOperation], + algo: GenericDistAlgo, **config) -> Optional[List[IRCell]]: """ - Partition an operator (op) by using - op partition algorithm (algo) and its configuration (config). - Note the backward op-partition will be automatically done. - - Args: - op: cell to be partitioned - algo: generic distributed algorithm related to the op - config: dict + Partition Primitive: + - partition: partition a forward or data operation using algorithms. + + The backward of the forward operation will automaticall be partitioned. + + Requirement to partition algorithm: + if backward is required, the algorithm can only transform tensors in: + replicate: results in gradient accumulation + split dimensionL no gradient accumulation + split value (outputs only): no gradient accumulation + + Difference of partition and replicate primitive: + Both primitive may replicate the tensors, but `replicate` will not do gradient + accumulation while `partition` will always require gradient accumulation on + replicated tensors. + + @param node Union[IRFwOperation, IRDataOperation]: the node to partition + @param algo GenericDistAlgo: the partition algorithm related to the node + @param config Dict[str, Any]: the algorithm configuration, e.g., partition number - Returns: - nodes: List[IRCell] if partitioned successfully. - None if failed - """ - if not isinstance(algo, GenericDistAlgo): - raise TypeError("Expected algo to be GenericDistAlgo") - if op not in self.nodes(): - raise RuntimeError(f"Not Exist: {op}") - if not (isinstance(op, IRFwOperation) or isinstance(op, IRDataOperation)): + @return Optional[IRCell]: partitioned sub-nodes or None (fail to partition) + """ + assert isinstance(algo, GenericDistAlgo) and node == algo.node, \ + "The partition algorithm is not initialized for this node" + if node not in self.nodes(): + raise RuntimeError(f"Not Exist: {node}") + if not (isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation)): raise ValueError("Only allow op to be forward op or data op.") - if algo.node != op: - return None - if not algo.satisfy(**config): - return None + # get partitioned sub-nodes fnodes = algo.instantiate(**config) + if fnodes is None: return fnodes - #FIXME: we don't allow non-weight input to be splitted in value - for fnode in fnodes: - for input in fnode.inputs(): - if isinstance(input, IRSubTensor): - if input.valmap.chunk_num != 1 and not input.is_param(): - raise NotImplementedError( - f"Not support feature-map {input} to be splitted in value as input" - ) # update forward - findex = self.detach(op) + findex = self.detach(node) for idx, fnode in enumerate(fnodes): self.attach(fnode, findex + idx) # update backward - if isinstance(op.mirror, IRBpOperation): - bindex = self.detach(op.mirror) + if isinstance(node.mirror, IRBpOperation): + bindex = self.detach(node.mirror) bnodes = [fnode.gen_backward() for fnode in fnodes][::-1] for idx, bnode in enumerate(bnodes): self.attach(bnode, bindex + idx) # update gradient updated = set() - for input in op.inputs(): - if not isinstance(input, IRSubTensor): - continue - for fnode in input.parent.consumers: + for itensor in [t for t in node.inputs() if isinstance(t, IRSubTensor)]: + for fnode in itensor.parent.consumers: bnode = fnode.mirror if isinstance(bnode, IRBpOperation) and fnode._id not in updated: idx = self.detach(bnode) @@ -514,10 +522,9 @@ def partition(self, op: IRCell, algo: GenericDistAlgo, **config) -> Optional[Lis updated.add(fnode._id) # update device for fnode in fnodes: - fnode.device = op.device + fnode.device = node.device if isinstance(fnode.mirror, IRCell): - fnode.mirror.device = op.device - self.reset_dependency() + fnode.mirror.device = node.device return fnodes def merge(self, nodes: List[IRCell], target_node: IRCell): @@ -578,7 +585,7 @@ def merge(self, nodes: List[IRCell], target_node: IRCell): def identity(self, input_tensor, dst_op): raise NotImplementedError - ## Assign Policy Primitives ## + ## Spatial Primitives ## def assign(self, op: IRCell, ranks: Union[int, List[int]]): """ @@ -651,11 +658,20 @@ def add_schedule(self, nodes: List[IRCell]) -> bool: post.add_predecessor(input_index=-1, cell=prev) return True - def recompute(self, nodes: List[IRFwOperation]): + def recompute(self, nodes: List[IRFwOperation]) -> bool: + """! + Recompute a set of nodes. The forward nodes will be assigned with a unique + recompute group id. A forward not can not be recomputed in different recompute groups. + + @param nodes List[IRFwOperation]: nodes for a recompute group + + @return success boolean: always success + """ assert all(isinstance(fnode, IRFwOperation) for fnode in nodes), "require forward operations" recompute_group_id = IDGenerator().gen_cell_id() for fnode in nodes: fnode.recompute = recompute_group_id + return True def set_order(self, seq: List[IRCell]): """ diff --git a/cube/ir/cten.py b/cube/ir/cten.py index c8b18428..2a3a46dd 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -61,7 +61,7 @@ def __init__(self, if init_outputs: self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] for tensor in self._outputs: - tensor.attach_cell(self) + tensor.cell = self # destination cells. [-1] for control dependency self._successors: List[List[IRCell]] = [list() for _ in range(output_length+1)] @@ -73,11 +73,6 @@ def __init__(self, # the comment for code generation self._comment: Optional[str] = None - # def __eq__(self, other): - # if isinstance(other, IRCell): - # return self._id == other._id - # return False - @property def device(self): return copy.copy(self._device) @@ -235,7 +230,7 @@ def set_input(self, input_index: int, val: Any): # copy the val val = copy.copy(val) # set tensor dst - val.attach_cell(self) + val.cell = self # set input value dtype if self._dtype == IRDType.unknown: self._dtype = val.dtype @@ -260,7 +255,7 @@ def set_output(self, output_index: int, val: Any): ) if isinstance(val, IRTensor): val = copy.copy(val) - val.attach_cell(self) + val.cell = self # set output value dtype val.dtype = self._dtype self._outputs[output_index] = val @@ -404,7 +399,7 @@ class IRTensor: and will be translated to None in code generation. """ - _attr = ['name', '_is_param', '_is_grad', '_requires_grad', '_dtype'] + _attr = ['name', '_is_param', '_is_grad', '_requires_grad', '_dtype', '_grad_accum'] def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): @@ -422,6 +417,8 @@ def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): # tensor gradient self._requires_grad: bool = True self._grad: Optional[Union[IRTensor, float]] = None + # multi-reference id + self._grad_accum: Tuple[int] = (0, 1) @property def dtype(self) -> IRDType: @@ -448,14 +445,6 @@ def cell(self, val: Optional[IRCell]): assert isinstance(val, IRCell) or val is None, "Expected cell to be Optional[IRCell]" self._cell = val - def attach_cell(self, cell: IRCell): - """ - Attach to a cell, to be with input or output - """ - if not isinstance(cell, IRCell): - raise TypeError("Expected an IRCell") - self._cell = cell - @property def device(self) -> List[int]: if self._cell: @@ -469,6 +458,23 @@ def device(self, val: Union[int, List[int]]): "tensor placement is not allowed to set manually" ) + @property + def grad_accum(self) -> Tuple[int, int]: + return self._grad_accum + + @grad_accum.setter + def grad_accum(self, accum: Optional[Tuple[int, int]]): + """! + Set gradient accumulation: (idx, chunks) + """ + if accum is None: + self._grad_accum + else: + assert len(accum) == 2 and all(isinstance(acc, int) for acc in accum), \ + "Expected accum to be [int, int]: [idx, chunks]" + assert accum[0] < accum[1] + self._grad_accum = tuple(accum) + def as_param(self): """ Set the tensor as trainable parameter diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index dfec9c66..4e817fd1 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -24,7 +24,7 @@ import math from cube.ir.cten import IRCell, IRTensor -import cube.ir.dtype as irdtype +import cube.ir.dtype as irdtype class IndexMap: @@ -428,6 +428,15 @@ def add_producer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): self._ptensors.insert(idx, tensor) def add_consumer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): + """! + Add the tensor and its operator into consumer list. + The tensor should be in cell.inputs() + + @param cell IRCell: node to be consumer + @param tensor IRTensor: tensor to be consumed tensors + @param idx int: the index to be inserted + """ + assert tensor in cell.inputs(), f"tensor {tensor} not in node: {cell} inputs" if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): raise TypeError("Expect an IRCell and an IRTensor") assert cell not in self._consumers, f"{cell} already exists as consumer" @@ -509,19 +518,6 @@ def as_grad(self): self._is_grad = True return self - def like(self): - """ - Create a new tensor with same name and shape, - but with a different new id - - Returns: - tensor - """ - tensor = IRFullTensor(self._shape, self.name) - for attr in IRFullTensor._attr: - setattr(tensor, attr, getattr(self, attr)) - return tensor - def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, None], shape: List[int]): """ Select a SubTensor from FullTensor. @@ -666,6 +662,10 @@ def indmap(self) -> IndexMap: def valmap(self) -> ValueMap: return copy.copy(self._valmap) + @property + def ndims(self) -> int: + return len(self.shape) + def __copy__(self): """ Copy the tensor that will have the exactly same id @@ -704,24 +704,25 @@ def grad(self) -> Optional[Union[IRTensor, float]]: self._grad = full_grad # this tensor is consumed elif self in self.cell.inputs(): - ref_consumers = list() - for consumer in self.parent.consumers: - for itensor in consumer.inputs(): - if self.overlap(itensor): - # TODO: we should guarantee in final status itensor == self - # replicated nodes will have same node id - if consumer._id not in ref_consumers: - ref_consumers.append(consumer._id) - # if one node has multiple same tensors, - # will consider them as one - break - ref_times = len(ref_consumers) - if ref_times == 0: - raise RuntimeError("Internal error: consumer doesn't have the operator attached to this tensor") - idx = ref_consumers.index(self.cell._id) + # ref_consumers = list() + # for consumer in self.parent.consumers: + # for itensor in consumer.inputs(): + # if self.overlap(itensor): + # # TODO: we should guarantee in final status itensor == self + # # replicated nodes will have same node id + # if consumer._id not in ref_consumers: + # ref_consumers.append(consumer._id) + # # if one node has multiple same tensors, + # # will consider them as one + # break + # ref_times = len(ref_consumers) + # if ref_times == 0: + # raise RuntimeError("Internal error: consumer doesn't have the operator attached to this tensor") + # idx = ref_consumers.index(self.cell._id) + assert self.grad_accum is not None, "not supported for gradient accumulation" grad = full_grad.select( indmap = self.indmap, - valmap = (idx, ref_times), + valmap = self.grad_accum, #(idx, ref_times), shape = self.shape ) self._grad = grad @@ -746,6 +747,8 @@ def requires_grad(self) -> bool: _ = self.grad return self._requires_grad + # partition primitives + def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, None], shape=None) -> IRTensor: """ Select an IRSubTensor @@ -770,9 +773,85 @@ def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, valmap = self.valmap.map(sub_valmap) sub_tensor = self.parent.select(index_map, valmap, shape) + sub_tensor.grad_accum = None return sub_tensor - def overlap(self, other): + def replicate(self, num: int) -> List[IRTensor]: + """! + Partition primitive + - replicate: replicate the tensor. + + @return tensor IRTensor: the copied tensor + """ + aidx, chunks = self.grad_accum + tensors = [] + for idx in range(num): + tensor = copy.copy(self) + tensor.grad_accum = (aidx * num + idx, chunks * num) + tensors.append(tensor) + return tensors + + def split_dim(self, dim: int, num: int) -> List[IRTensor]: + """ + Partition primitive: + split_dim: uniformly split the tensor along a dimension. + + @param dim int: the dimension to get partitioned + @param num int: the number of sub-tensor generated + + @return sub_tensors List[IRSubTensor]: the generated sub-tensors + """ + dim = dim + self.ndims if dim < 0 else dim + assert dim < self.ndims, f"Dim should within ndims but {dim} >= {self.ndims})" + assert self.shape[dim] % num == 0, f"Expected dimension can be split: {self.shape[dim]} % {num} != 0" + chunk_size = self.shape[dim] // num + + shape_slicer = list() + chunk_shape = list() + for tdim, nele in enumerate(self.shape): + if tdim != dim: + shape_slicer.append(slice(0, nele, 1)) + chunk_shape.append(nele) + else: + shape_slicer.append(None) + chunk_shape.append(chunk_size) + sub_tensors = list() + for cid in range(num): + shape_slicer[dim] = slice(chunk_size * cid, chunk_size * (cid + 1), 1) + sub_tensor = self.select( + indmap = tuple(shape_slicer), + valmap = None, + shape = chunk_shape + ) + sub_tensor.grad_accum = self.grad_accum + sub_tensors.append(sub_tensor) + return sub_tensors + + def split_val(self, num: int) -> List[IRTensor]: + """! + Partition primitive: + split_val: uniformly split the tensor value. + + @param num int: the number of sub-tensor generated + + @return sub_tensors List[IRSubTensor]: the generated sub-tensors + """ + # full shape + shape_slicer = list() + for nele in self.shape: + shape_slicer.append(slice(0, nele, 1)) + sub_tensors = list() + for idx in range(num): + sub_tensor = self.select( + indmap = tuple(shape_slicer), + valmap = (idx, num), + shape = self.shape + ) + sub_tensor.grad_accum = self.grad_accum + sub_tensors.append(sub_tensor) + return sub_tensors + + def overlap(self, other) -> bool: """ Check if the two tensor is overlapped. diff --git a/cube/logics/model.py b/cube/logics/model.py index 130a1a0a..180ffce7 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -72,6 +72,15 @@ def forward(graph: IRGraph, *args) -> IRGraph: while itensor in graph.outputs(): oidx = graph.outputs().index(itensor) graph.set_output(oidx, arg) + # setup gradient accum + for ftensor in graph.full_tensors(): + naccum = len(ftensor.ctensors) + for idx, ctensor in enumerate(ftensor.ctensors): + ctensor.grad_accum = (idx, naccum) + # actually producer doesn't need to know accumulation + naccum = len(ftensor.producers) + for idx, ptensor in enumerate(ftensor.ptensors): + ptensor.grad_accum = (idx, naccum) # generate backward reverse is only to make op id looks consecutive for fnode in [n for n in graph.nodes() if isinstance(n, IRFwOperation)][::-1]: fnode.gen_backward() diff --git a/cube/logics/translator.py b/cube/logics/translator.py index a1d9bf0e..0b3c32fd 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -13,7 +13,7 @@ class LogicTranslator: @staticmethod - def gen_logic_graph(outputs=None): + def gen_logic_graph(outputs=None) -> IRGraph: """ Generate Training Logic Graph """ From b04dbb735e59c62cd9567ad30b0b800dd7eb61fd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 20 Jun 2022 17:24:14 +0800 Subject: [PATCH 0880/1892] update policy --- examples/mlp/policy/spmd.py | 116 +++++++++++++++++------------------- 1 file changed, 55 insertions(+), 61 deletions(-) diff --git a/examples/mlp/policy/spmd.py b/examples/mlp/policy/spmd.py index 6cbfa308..3520496a 100644 --- a/examples/mlp/policy/spmd.py +++ b/examples/mlp/policy/spmd.py @@ -37,21 +37,20 @@ def PASCol(graph: IRGraph, resource): """ Linear Column Parallel """ + linears = [node for node in graph.nodes() if node.name == 'linear'] + for idx, node in enumerate(linears): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=0, num=resource.ngpus + ) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=1, dim=0, num=resource.ngpus - ) - if sub_nodes is None: # partition fails - # graph.assign(node, list(range(resource.ngpus))) - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) + if isinstance(node, (IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) return graph @@ -59,21 +58,20 @@ def PASRow(graph: IRGraph, resource): """ Linear Column Parallel """ + linears = [node for node in graph.nodes() if node.name == 'linear'] + for idx, node in enumerate(linears): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=1, num=resource.ngpus + ) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=1, dim=1, num=resource.ngpus - ) - if sub_nodes is None: # partition fails - # graph.assign(node, list(range(resource.ngpus))) - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) + if isinstance(node, (IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) return graph @@ -81,20 +79,18 @@ def PASHybrid(graph: IRGraph, resource): """ Linear Hybrid Parallelism (Megatron) """ - for idx, node in enumerate(graph.nodes()): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=1, dim=(idx+1)%2, num=resource.ngpus - ) - if sub_nodes is None: # partition fails - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) + linears = [node for node in graph.nodes() if node.name == 'linear'] + for idx, node in enumerate(linears): + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=resource.ngpus) + for idx, node in enumerate(tp_nodes): + graph.assign(node, idx) + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) print(graph.extra_repr()) return graph @@ -105,25 +101,23 @@ def PASMegatron(graph: IRGraph, resource): """ tp = 2 dp = resource.ngpus // tp - for idx, node in enumerate(graph.nodes()): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - continue - if isinstance(node, IRFwOperation): - sub_nodes = list() - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=(idx+1)%2, num=tp) - if tp_nodes is not None: - for tp_node in tp_nodes: - algo = tp_node.algorithms('dim') - dp_nodes = graph.partition(tp_node, algo, idx=0, dim=0, num=dp) - sub_nodes += dp_nodes - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) + linears = [node for node in graph.nodes() if node.name == 'linear'] + for idx, node in enumerate(linears): + sub_nodes = [] + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=tp) + for tp_node in tp_nodes: + algo = tp_node.algorithms('dim') + dp_nodes = graph.partition(tp_node, algo, idx=0, dim=0, num=dp) + sub_nodes += dp_nodes + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) # print(graph.extra_repr()) return graph From 37d88e8d56ca54e9a48e4c0672efff8ae109144c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 20 Jun 2022 18:17:15 +0800 Subject: [PATCH 0881/1892] update pipeline --- examples/mlp/policy/megatron_pptp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/mlp/policy/megatron_pptp.py b/examples/mlp/policy/megatron_pptp.py index 9d567842..87efe7ef 100644 --- a/examples/mlp/policy/megatron_pptp.py +++ b/examples/mlp/policy/megatron_pptp.py @@ -47,9 +47,10 @@ def PAS(graph: IRGraph, resource): sid = node2stage(node) tp_group = tp_mesh[sid] # partition - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, dict(idx=1, dim=idx%2, num=num_tp)) - if tp_nodes is None: + if node.name == 'linear': + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=num_tp) + else: tp_nodes = graph.replicate(node, times=num_tp) # assign for devid, node in zip(tp_group, tp_nodes): From 867551ba2373eb0668f01a24aef9d73642562bcf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 21 Jun 2022 11:03:33 +0800 Subject: [PATCH 0882/1892] add test to primitive --- tests/test_rvd_prim.py | 134 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 tests/test_rvd_prim.py diff --git a/tests/test_rvd_prim.py b/tests/test_rvd_prim.py new file mode 100644 index 00000000..0ecf66f6 --- /dev/null +++ b/tests/test_rvd_prim.py @@ -0,0 +1,134 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=1 \ + tests/test_rvd_prim.py --prims all + +""" + +from typing import Callable +import cube +import torch +import time +import argparse +from cube.profiler.timer import CudaTimer, print_each_rank + +from cube.runtime.adapter.collectives import all_reduce, all_gather, reduce_scatter, all_to_all +from cube.runtime.device import DeviceGroup + + +def prim_allreduce(itensor, ranks, dim0=None, dim1=None): + return all_reduce(itensor, ranks) + + +def bw_allreduce(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * 2 * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_allgather(itensor, ranks, dim0=0, dim1=None): + return all_gather(itensor, dim0, ranks) + + +def bw_allgather(itensor: torch.Tensor, ranks, sec_per_call: float): + ndevs = len(ranks) + msg_size = itensor.nelement() * 4 / 1e9 * ndevs + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_reducescatter(itensor, ranks, dim0=0, dim1=None): + return reduce_scatter(itensor, dim0, ranks) + + +def bw_reducescatter(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_alltoall(itensor, ranks, dim0=0, dim1=1): + return all_to_all(itensor, dim0, dim1, ranks) + + +def bw_alltoall(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_bw(prim: Callable, bandwidth: Callable, ranks, size, warmup=100, profile=100): + if 'allgather' in prim.__name__: + size = size // len(ranks) + tensor: torch.Tensor = torch.zeros(size, device=torch.cuda.current_device()) + tensor = tensor.view(256, -1).contiguous() + torch.distributed.barrier() + # warm up + for _ in range(warmup): + _ = prim(tensor, ranks) + # profile + torch.cuda.synchronize() + torch.distributed.barrier() + tic = time.perf_counter() + for _ in range(profile): + _ = prim(tensor, ranks) + torch.cuda.synchronize() + toc = time.perf_counter() + + span = (toc - tic) / profile # seconds + msg_size = tensor.nelement() * 4 // 1024 // 1024 # MB + if 'allgather' in prim.__name__: + msg_size = len(ranks) * tensor.nelement() * 4 // 1024 // 1024 # MB + algo_bw, bus_bw = bandwidth(tensor, ranks, span) + print_each_rank( + '{} msg {} : MBwall-time(ms) algo-bw(GB/s) bus-bw(GB/s) {:.2f} {:.2f} {:.2f}'.format( + prim.__name__, msg_size, span*1000, algo_bw, bus_bw + ), rank_only=0 + ) + + +if __name__ == '__main__': + + cube.init() + + parser = argparse.ArgumentParser(description='comm primitive') + parser.add_argument('--prims', type=str, nargs='+', action='append', + help='prims: all, allreduce, reducescatter, allgather, alltoall') + parser.add_argument('--begin', type=int, default=1, + help='start message size in MB') + parser.add_argument('--end', type=int, default=256, + help='end message size in MB') + args = parser.parse_args() + args.prims = args.prims[0] + + prims, bws = [], [] + if 'allrecuce' in args.prims or 'all' in args.prims: + prims.append(prim_allreduce) + bws.append(bw_allreduce) + if 'allgather' in args.prims or 'all' in args.prims: + prims.append(prim_allgather) + bws.append(bw_allgather) + if 'reducescatter' in args.prims or 'all' in args.prims: + prims.append(prim_reducescatter) + bws.append(bw_reducescatter) + if 'alltoall' in args.prims or 'all' in args.prims: + prims.append(prim_alltoall) + bws.append(bw_alltoall) + + ranks = tuple(range(DeviceGroup().world_size)) + CudaTimer(enable=False) + for prim, bw in zip(prims, bws): + print_each_rank(f'====> test start {prim.__name__}', rank_only=0) + size = args.begin + while size <= args.end: + prim_bw(prim, bw, ranks, size * 1024 * 1024 // 4) + size *= 2 + print_each_rank(f'====> test finish {prim.__name__}', rank_only=0) From 3fac23b3459e3a3196fc54a44f8e50fee0040a2b Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 23 Jun 2022 11:44:52 +0800 Subject: [PATCH 0883/1892] update Readme with non-install run mode and single device debug mode --- README.md | 36 +++++++++++++++++++++++++++++++----- cube/compiler.py | 3 ++- examples/mlp/linears.py | 3 ++- examples/mlp/policy/spmd.py | 2 +- 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 8484cb53..6e14a58c 100644 --- a/README.md +++ b/README.md @@ -8,20 +8,46 @@ AI System Compiler to map a semantic (single-device) model into distributed exec > Install Python 3.7 in the development environment for widest compatibility. -## Install +Install dependent packages +```shell +pip install -r requirements.txt +``` + +## Option 1: Quick Start without Installation + +* ### Run on repo root path: +```sh +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py +``` + +[comment]: <> (UDA_VISIBLE_DEVICES=7 PYTHONPATH=.:$PYTHONPATH python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 ./examples/wrf/wrf2.py) + +* ### Debug for model parsing check on single Device +```shell +PYTHONPATH=.:$PYTHONPATH SINGLE_DEV_MODE=1 python examples/mlp/linears.py +``` + + +--- + +## Option 2: Install for Run + +* ### Install ```python pip install -r requirements.txt python setup.py develop ``` -## Run Examples - -* [Micro Benchmark] Run a mutiple MLP Model +* ### Run Example +[Micro Benchmark] Run a mutiple MLP Model ```sh OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.pys + examples/mlp/linears.py ``` diff --git a/cube/compiler.py b/cube/compiler.py index 51827810..f002a366 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -228,7 +228,8 @@ def decorator(fn: Callable) -> Callable: torch.distributed.barrier() # reset dataloader - torch.distributed.broadcast(batch_size, src=0) + if torch.distributed.is_initialized(): + torch.distributed.broadcast(batch_size, src=0) batch_size = batch_size.item() print_each_rank(f'reseting dataloader batch size to {batch_size}') dataloader.set_batch_size(batch_size) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 47622401..b544bfed 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -67,7 +67,8 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) CudaTimer(enable=False).warmup() - torch.distributed.barrier() + if torch.distributed.is_initialized(): + torch.distributed.barrier() iter_num = 64 warmup = 20 for step in range(iter_num): diff --git a/examples/mlp/policy/spmd.py b/examples/mlp/policy/spmd.py index 3520496a..42df3f59 100644 --- a/examples/mlp/policy/spmd.py +++ b/examples/mlp/policy/spmd.py @@ -99,7 +99,7 @@ def PASMegatron(graph: IRGraph, resource): """ Tensor + Data Parallelism """ - tp = 2 + tp = min(2, resource.ngpus) dp = resource.ngpus // tp linears = [node for node in graph.nodes() if node.name == 'linear'] for idx, node in enumerate(linears): From bc57873ca2f0ebe72d5382c35ae8d1b34e1331fd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Jun 2022 15:03:53 +0800 Subject: [PATCH 0884/1892] enable residual with updated abstraction --- cube/graph/function/function.py | 14 + cube/graph/gener/gen.py | 370 +++++++------- cube/graph/gener/layout.py | 15 +- cube/graph/graph.py | 121 +++-- cube/graph/parser/converter.py | 23 +- cube/ir/adapter/adapter.py | 16 +- cube/ir/adapter/prim.py | 2 +- cube/ir/cten.py | 39 +- cube/ir/operator.py | 20 +- cube/ir/tensor.py | 825 +++++++++++++----------------- cube/logics/model.py | 48 -- cube/runtime/function/function.py | 10 +- 12 files changed, 682 insertions(+), 821 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 73e9368d..7e4d7995 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -5,6 +5,7 @@ import warnings from cube.ir.cten import IRTensor +from cube.ir.tensor import IRFullTensor from cube.graph.function.einops import ShapeAnno, OpAnno, IREinops from cube.graph.function.conv import IRConv2D from cube.graph.function.conv import IRConv3D @@ -727,6 +728,19 @@ def Embedding(signature, inputs: List): return IREinops(signature, [anno], [itensor, weight], 'embedding', padding_idx=padding_idx, start=start, stop=stop) +def MultiRef(signature, inputs: List[IRFullTensor]): + """ + cube.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] + """ + signature = 'cube.runtime.function.multiref' + itensor, times = inputs + assert isinstance(itensor, IRFullTensor), "require all inputs to be IRSubTensor" + assert isinstance(times, int), "require int for second input" + anno = '* -> ' + ', '.join('*' for _ in range(times)) + node = IREinops(signature, [anno], [itensor], 'multiref', times=times) + return node + + def ScriptEinOps(signature, inputs): """ apply_for_scriptable_torch(recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str) -> torch.Tensor: diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 1675e36c..aec27c59 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -7,7 +7,6 @@ from cube.ir.adapter import IRAdapter, IRWeightReducer from cube.ir.operator import IRBpOperation, IRFwOperation -from cube.ir.cten import IRCell from cube.ir.adapter.prim import IRAdapterPrim from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim @@ -54,26 +53,25 @@ def gen_weight(graph: IRGraph) -> IRGraph: if not isinstance(fnode, IRFwOperation): continue devid = fnode.device[0] - for input in fnode.inputs(): - if isinstance(input, IRSubTensor) and input.is_param(): - grad = input.grad - if grad is None: - continue + for wtensor in fnode.inputs(): + if isinstance(wtensor, IRSubTensor) and wtensor.is_param(): + grad: Optional[IRSubTensor] = wtensor.grad + if grad is None: continue # nothing to sync - if grad.valmap == ValueMap(0, 1): + if grad.valmap == (0, 1): continue - if input._id not in grads: - grads[input._id] = dict() - weights[input._id] = input - if devid not in grads[input._id]: - grads[input._id][devid] = list() - if grad in grads[input._id][devid]: + if wtensor._id not in grads: + grads[wtensor._id] = dict() + weights[wtensor._id] = wtensor + if devid not in grads[wtensor._id]: + grads[wtensor._id][devid] = list() + if grad in grads[wtensor._id][devid]: raise RuntimeError( "Find two same gradient (not expected). " "This is usually due to replicated node assigned to same device. " f"\nCheck node:\n\t{fnode}" ) - grads[input._id][devid].append(grad) + grads[wtensor._id][devid].append(grad) # step 2: generate reducers. # reducers: tuple(ranks): List[weight] reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() @@ -108,75 +106,90 @@ def gen_activation(graph: IRGraph) -> IRGraph: # backward will gen in forward if ftensor.is_param() or ftensor.is_grad(): continue - adapters = IRAdapterGener.gen_fulltensor(ftensor) - if len(adapters) == 0: + # no consumer usually mean loss + if len(ftensor.consumers) == 0: + continue + # no require for communication + if len(ftensor.consumers) == 1 and len(ftensor.producers) == 0 and \ + ftensor.consumers[0].device == ftensor.producers[0].device: + continue + + # print(f'==> analyzing full tensor: {ftensor}') + # print('producer:') + # for ptensor in ftensor.ptensors: + # print(ptensor, 'device:', ptensor.device) + # print('consumer') + # for ctensor in ftensor.ctensors: + # print(ctensor, 'device:', ctensor.device) + # print('') + + ptensors, ctensors = ftensor.ptensors, ftensor.ctensors + pdevs = tuple(ptensor.device[0] for ptensor in ptensors) + cdevs = tuple(ctensor.device[0] for ctensor in ctensors) + + fadapter = None + # Case 1: sharing device (in-shard) + if set(pdevs) == set(cdevs) and len(pdevs) > 1 and \ + len(set(pdevs)) == len(ptensors) and len(set(cdevs)) == len(ctensors): + fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) + + # Case 2: sperating device (cross-shard) + if len(set(pdevs).intersection(cdevs)) == 0: + pass + + # Case 3: General cases + # warnings.warn('The adapter is generated using + if fadapter is None: + fadapter = IRAdapterGener.gen_general(ftensor) + + badapter: Optional[IRAdapter] = fadapter.mirror + + if (badapter is not None and len(fadapter.prims) == 0 and len(badapter.prims) == 0) or \ + (badapter is None and len(fadapter.prims) == 0): continue + # insert forward adapter - fidx = min([graph.nodes().index(c) for c in ftensor.consumers]) - for fadapter in adapters: - graph._nodes.insert(fidx, fadapter) - # insert bacward adapter - bidx = None if ftensor.grad is None else min([graph.nodes().index(c) for c in ftensor.grad.consumers]) - for fadapter in adapters: - # insert backward adapter - badapter: IRAdapter = fadapter.mirror - if badapter is not None: - assert isinstance(bidx, int), "have backward adapter but no gradient required." - graph._nodes.insert(bidx, badapter) + fidx = min([graph.nodes().index(consumer) for consumer in ftensor.consumers]) + graph._nodes.insert(fidx, fadapter) + + # insert backward + if badapter is not None: + bidx = min(graph.nodes().index(consumer) for consumer in ftensor.grad.consumers) + graph._nodes.insert(bidx, badapter) return graph @staticmethod - def gen_fulltensor(ftensor: IRFullTensor) -> List[IRAdapter]: - # print(f'analyzing ftensor: {ftensor}') - # print(f'ptensors: {ftensor.ptensors}') - # print(f'ctensors: {ftensor.ctensors}') - if len(ftensor.consumers) == 0: - return [] - pdevs = set() - for pnode in ftensor.producers: - pdevs.update(pnode.device) - cdevs = set() - for cnode in ftensor.consumers: - cdevs.update(cnode.device) - - # sharing devices - if pdevs == cdevs and len(pdevs) > 1 and \ - len(pdevs) == len(ftensor.producers) and \ - len(cdevs) == len(ftensor.consumers): - # TODO: enable tensor fusion of tensors on same device - return IRAdapterGener.gen_gridlayout(ftensor) - - # no-sharing devices - # elif len(pdevs.intersection(cdevs)) == 0: - # print(f'detect no intersection') - # return [] - - # general cases - warnings.warn('The adapter is generated using inefficient P2P send/recv') - fprims, bprims = [], [] - for subtensor in ftensor.ctensors: - fprims += IRAdapterGener.gen_subtensor(subtensor) - fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) - fadapter.prims = fprims - grad: IRFullTensor = ftensor.grad - # TODO: understand why grad cannot be None in inference-only - if grad is not None and (len(grad.ptensors) != 0 or len(grad.ctensors) != 0): - for subtensor in grad.ctensors: - bprims += IRAdapterGener.gen_subtensor(subtensor) - badapter = IRAdapter(grad.ptensors, grad.ctensors) - badapter.prims = bprims - IRCell.make_pair(fadapter, badapter) - if len(fprims) == 0 and len(bprims) == 0: - return [] - return [fadapter] + def gen_fulltensor(ftensor: IRFullTensor, allow_reorder=False) -> Optional[IRAdapter]: + """ + Generate forward / backward adapter for fulltensor + """ + ptensors, ctensors = ftensor.ptensors, ftensor.ctensors + pdevs = tuple(ptensor.device[0] for ptensor in ptensors) + cdevs = tuple(ctensor.device[0] for ctensor in ctensors) + + # Case 1: sharing device (in-shard) + if set(pdevs) == set(cdevs) and len(pdevs) > 1 and \ + len(set(pdevs)) == len(ptensors) and len(set(cdevs)) == len(ctensors): + return IRAdapterGener.gen_in_shard(ftensor, allow_reorder) + + # Case 2: sperating device (cross-shard) + if len(set(pdevs).intersection(cdevs)) == 0: + pass + + # Case 3: General cases + # warnings.warn('The adapter is generated using inefficient P2P send/recv') + return IRAdapterGener.gen_general(ftensor) @staticmethod - def gen_gridlayout(ftensor: IRFullTensor) -> List[IRAdapter]: + def gen_in_shard(ftensor: IRFullTensor, allow_reorder=False) -> Optional[IRAdapter]: """ - Generate adapters for connecting producer with consumer with - shared devices for forward and backward. + Generate communication for sharing devices (SPMD-like) + + @param ftensor: IRFullTensor + @param ptensors: List[IRSubTensor]: produced subtensors + @param ctensors: List[IRSubTensor]: consumed subtensors - ftensor: IRFullTensor: forward full tensor. + @return adapter Optional[IRAdapter]: generated adapter. """ # producer grid layout ilayout = GridLayout.togrid(ftensor, ftensor.ptensors) @@ -185,35 +198,32 @@ def gen_gridlayout(ftensor: IRFullTensor) -> List[IRAdapter]: ctensors = [None] * len(devs) for ctensor in ftensor.ctensors: idx = devs.index(ctensor.device) - assert ctensors[idx] is None, "same device of different tensors" ctensors[idx] = ctensor + assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" # consumer grid layout olayout = GridLayout.togrid(ftensor, ctensors) - # print(f'forward full tensor: {ftensor}\n producer: {ilayout}, consumer: {olayout}') # find path paths, fprims = ilayout.path(olayout) # re-assign the operator if miss-ordered names, from_dev, to_dev = [], [], [] - reorder : Dict[str, Tuple[int, int]] = dict() for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): assert len(itensor.device) == 1 and len(otensor.device) == 1, \ "Expect tensor only has one device. Report this as a bug" if itensor.device != otensor.device: - inode, onode = itensor._cell, otensor._cell - names.append(f'{onode.name}{onode._id}') + inode, onode = itensor.cell, otensor.cell + names.append(f'{onode.name}{onode.cid}') from_dev.append(onode.device[0]) to_dev.append(inode.device[0]) - onode.device = inode.device - if onode.mirror is not None: - onode.mirror.device = inode.device - if len(reorder) > 0: - warnings.warn(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') - - # print('find path:') - # for path in paths: print(path) - # print('comm prims:') - # for prim in fprims: print(prim) + if allow_reorder: + onode.device = inode.device + if onode.mirror is not None: + onode.mirror.device = inode.device + else: + raise RuntimeError("device mismatch. Try to enable reorder") + if len(names) > 0: + print(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') + fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) fadapter.prims = fprims @@ -229,145 +239,113 @@ def gen_gridlayout(ftensor: IRFullTensor) -> List[IRAdapter]: ptensors[idx] = ptensor ilayout = GridLayout.togrid(grad, ptensors) olayout = GridLayout.togrid(grad, grad.ctensors) - # print(f'backward full tensor: {grad}\n producer: {ilayout}, consumer: {olayout}') paths, bprims = ilayout.path(olayout) # check the device order for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): assert len(itensor.device) == len(otensor.device), "backward device not match" - # print('find path:') - # for path in paths: print(path) - # print('comm prims') - # for prim in bprims: print(prim) badapter = IRAdapter(grad.ptensors, grad.ctensors) badapter.prims = bprims - IRCell.make_pair(fadapter, badapter) - if len(fprims) == 0 and len(bprims) == 0: - return [] - return [fadapter] + IRAdapter.make_pair(fadapter, badapter) + + return fadapter + + @staticmethod + def gen_cross_shard(ftensor: IRFullTensor, ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> Optional[IRAdapter]: + pass + + @staticmethod + def gen_general(ftensor: IRFullTensor) -> IRAdapter: + fprims = [] + for ctensor in ftensor.ctensors: + fprims += IRAdapterGener.gen_subtensor(ctensor, ftensor.ptensors) + fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) + if ftensor.grad is not None: + bprims = [] + for cgrad in ftensor.grad.ctensors: + bprims += IRAdapterGener.gen_subtensor(cgrad, ftensor.grad.ptensors) + badapter = IRAdapter(ftensor.grad.ptensors, ftensor.grad.ctensors) + IRAdapter.make_pair(fadapter, badapter) + return fadapter @staticmethod - def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: + def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRAdapterPrim]: """ - Generate communication prims for a sub-tensor. - The subtensor should be a IRSubTensor of consumer. + Generate communiction primitives for ctensor - The generation takes three stages: select, move, merge + @param ctensor IRSubTensor: the consumed tensor as destination + @param ptensors List[IRSubTensor]: the produced tensors as source + + @return prims List[IRAdapterPrim]: the primitives for adapter """ - ftensor = subtensor.parent # category to local tensor and remote tensor - local = [t for t in ftensor.ptensors if t.device == subtensor.device] - remote = [t for t in ftensor.ptensors if t.device != subtensor.device] - # consumers before this consumer can also be considered as input - cidx = ftensor.consumers.index(subtensor._cell) - for ctensor in ftensor.ctensors[:cidx]: - if subtensor.device == ctensor.device: - if ctensor not in local: - local.append(ctensor) - # TODO: also consider consumers on other devices - # else: - # if ctensor not in remote: - # remote.append(ctensor) + local = [t for t in ptensors if t.device == ctensor.device] + remote = [t for t in ptensors if t.device != ctensor.device] prims = [] # ==== select ==== # intersections = [] # check local - for tensor in local: - common = tensor.common(subtensor) - if tensor == subtensor: + for itensor in local+remote: + if itensor.device == ctensor.device and itensor == ctensor: return [] - elif common == subtensor: - indmap = [] - for islicer, oslicer in zip(tensor.indmap.get(), common.indmap.get()): - start = oslicer.start - islicer.start - stop = start + oslicer.stop - oslicer.start - indmap.append(slice(start, stop, 1)) - valmap = ValueMap(0, 1) - common.cell = subtensor.cell - prims.append(SelectPrim(tensor, indmap, valmap, common)) - return prims - # check local + remote - if len(intersections) == 0: - for itensor in local+remote: - if not itensor.overlap(subtensor): - continue - common = itensor.common(subtensor) - common.cell = itensor.cell - # print(f'get common: {common.extra_repr()}') - intersections.append(common) - if common == itensor: - continue - indmap = [] - for islicer, oslicer in zip(itensor.indmap.get(), common.indmap.get()): - start = oslicer.start - islicer.start - stop = start + oslicer.stop - oslicer.start - indmap.append(slice(start, stop, 1)) - assert itensor.valmap == common.valmap or itensor.valmap == ValueMap(0,1), \ - f"Not supported value select: {itensor.valmap} -> {common.valmap}" - valmap = ValueMap(0, 1) - prims.append(SelectPrim(itensor, indmap, valmap, common)) - # TODO: check union == subtensor - if common == subtensor: - break + common: Optional[IRSubTensor] = itensor.common(ctensor) + if common is None: + continue + common.cell = itensor.cell + intersections.append(common) + # create select primitive + indmap = [] + for dim in range(itensor.ndims): + (s1, e1), (s2, e2) = itensor.indmap[dim], common.indmap[dim] + start = s2 - s1 + end = start + e2 - s2 + indmap.append((start, end)) + indmap = IndexMap(tuple(indmap)) + assert itensor.valmap == common.valmap, "Value map not same" + valmap = ValueMap((0, 1)) + select_prim = SelectPrim(itensor, indmap, valmap, common) + if itensor.device == ctensor.device and common == ctensor: + return [select_prim] + prims.append(select_prim) + # TODO: check union == subtensor + if common == ctensor: + break + # print(intersections) # ====== move ===== # tmoved = [] for tensor in intersections: - assert len(tensor.device) == 1 and len(subtensor.device) == 1, "Expected only one device." + assert len(tensor.device) == 1 and len(ctensor.device) == 1, "Expected only one device." mtensor = tensor - if tensor.device != subtensor.device: + if tensor.device != ctensor.device: mtensor = copy.copy(tensor) - mtensor.cell = subtensor.cell + mtensor.cell = ctensor.cell prims.append(MovePrim(tensor, mtensor)) tmoved.append(mtensor) # ===== merge ===== # remain_tensors: List[IRSubTensor] = copy.copy(tmoved) - if subtensor in remain_tensors: + if ctensor in remain_tensors: return prims out = None - while out != subtensor: + while out != ctensor: out, merged = None, False for idx1 in range(len(remain_tensors) - 1): for idx2 in range(idx1, len(remain_tensors)): t1, t2 = remain_tensors[idx1], remain_tensors[idx2] - # check reducable - if t1.indmap == t2.indmap and t1.valmap.chunk_num == t2.valmap.chunk_num: - vid1, vid2 = t1.valmap.idx, t2.valmap.idx - # sum e.g., 0,1 but not 1,2 - if min(vid1, vid2) % 2 == 0 and abs(vid1-vid2) == 1: - vid = min(vid1, vid2) // 2 - valmap = ValueMap(vid, t1.valmap.chunk_num // 2) - out = subtensor.parent.select(t1.indmap, valmap, t1.shape) - out.cell = subtensor.cell - prims.append(SumPrim([t1, t2], out)) - merged = True - break - # try merge dimension - elif t1.valmap == t2.valmap: - cat_dim: Dict[int, List[IRSubTensor]] = dict() - indmap = [] - for dim, (s1, s2) in enumerate(zip(t1.indmap.get(), t2.indmap.get())): - if s1 != s2: - if min(s1.stop, s2.stop) == max(s1.start, s2.start): - if s1.start < s2.start: - cat_dim[dim] = [t1, t2] - else: - cat_dim[dim] = [t2, t1] - indmap.append(slice(min(s1.start, s2.start), max(s1.stop, s2.stop), 1)) - else: - cat_dim[dim] = None - indmap.append(None) - else: - indmap.append(s1) - if None in indmap or len(cat_dim) > 1: - continue - indmap = IndexMap(tuple(indmap)) - valmap = t1.valmap - out = t1.parent.select(indmap, valmap, indmap.shape) - out.cell = subtensor.cell - cdim = list(cat_dim.keys())[0] - prims.append(MergeDimPrim(cat_dim[cdim], out, cdim)) + catdim = t1.catdims(t2) + if catdim is not None: + tensors = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] + out = tensors[0].concat(tensors[1], dim=catdim) + out.cell = ctensor.cell + prims.append(MergeDimPrim(tensors, out, catdim)) + merged = True + break + # reduction + if t1.accumable(t2): + out = t1.accum(t2) + out.cell = ctensor.cell + prims.append(SumPrim([t1, t2], out)) merged = True break if merged: @@ -376,11 +354,11 @@ def gen_subtensor(subtensor: IRSubTensor) -> List[IRAdapterPrim]: remain_tensors.append(out) break if out is None: - ptensors = '\n\t'.join(t.extra_repr() for t in ftensor.ptensors) + ptensors = '\n\t'.join(t.extra_repr() for t in ptensors) raise RuntimeError( f"Fail to build adapter.\n" - f"FullTensor:{ftensor}\n" + f"FullTensor:{ctensor.parent}\n" f"Producers:\n\t{ptensors}\n" - f"SubTensor:\n\t{subtensor.extra_repr()}" + f"SubTensor:\n\t{ctensor.extra_repr()}" ) return prims diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 1c56b4a4..822d0458 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -289,16 +289,16 @@ def iter_idx(dims: List[int]) -> Tuple[int]: yield (i,) + indices # generate tensor for each index for indices in iter_idx([v,]+dims): - valmap = ValueMap(indices[0], v) + valmap = ValueMap((indices[0], v)) indmap = [] shape = [] for dim, (nchunk, index) in enumerate(zip(dims, indices[1:])): assert ftensor.shape[dim] % nchunk == 0, f"not dividable for {nchunk} chunks over dim {dim}" csize = ftensor.shape[dim] // nchunk start = csize * index - indmap.append(slice(start, start+csize, 1)) + indmap.append((start, start+csize)) shape.append(csize) - subtensor = ftensor.select(tuple(indmap), valmap, shape) + subtensor = ftensor.select(tuple(indmap), valmap) # replicate subtensors = [copy.copy(subtensor) for _ in range(r)] all_subtensors += subtensors @@ -331,15 +331,14 @@ def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): _tindex[tid] = [len(replicas[subtensor._id])] replicas[subtensor._id].append(subtensor) # setup value - _tindex[tid].append(subtensor.valmap.idx) - vchunks.add(subtensor.valmap.chunk_num) + _tindex[tid].append(subtensor.valmap[0]) + vchunks.add(subtensor.valmap[1]) # setup dimensions for dim in range(ndims): snele = subtensor.shape[dim] - start = subtensor.indmap.get()[dim].start + start = subtensor.indmap[dim][0] fnele = ftensor.shape[dim] if fnele % snele != 0 or start % snele != 0: - print(subtensor, dim) raise RuntimeError( f"dimension split error:\n" f"Full Tensor: {ftensor}\n" @@ -353,7 +352,7 @@ def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): raise RuntimeError(f"different replicas: {nreplicas}") _replica = list(nreplicas)[0] # value (V) - nchunks = set(t.valmap.chunk_num for t in subtensors) + nchunks = set(t.valmap[1] for t in subtensors) if len(nchunks) != 1: raise RuntimeError(f"different value split: {nchunks}") _value = list(nchunks)[0] diff --git a/cube/graph/graph.py b/cube/graph/graph.py index b5cf6b1f..671ded7b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -9,6 +9,7 @@ from typing import Any, Union, Tuple, List, Optional, Dict import copy +from cube.graph.function.function import MultiRef from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator @@ -425,10 +426,67 @@ def get_outputs(nodes: List[IRCell]): outputs.append(output) return outputs + @staticmethod + def from_logic_graph(nodes: List[IRCell], + inputs: List[IRFullTensor], outputs: List[IRFullTensor], + module_name: str): + # handle multi-consumed tensor + consumers: Dict[IRFullTensor, List[IRCell]] = dict() + producers: Dict[IRFullTensor, IRCell] = dict() + for node in nodes: + ftensors = set() + for ftensor in node.inputs(): + # remove redundant tensors within an operator + if isinstance(ftensor, IRFullTensor) and ftensor._id not in ftensors: + ftensors.add(ftensor._id) + if ftensor not in consumers: + consumers[ftensor] = [] + consumers[ftensor].append(node) + for ftensor in node.outputs(): + if isinstance(ftensor, IRFullTensor): + producers[ftensor] = node + for ftensor, cnodes in consumers.items(): + if len(cnodes) == 1: continue + itensors = [ftensor.like() for _ in range(len(cnodes))] + for itensor, consumer in zip(itensors, cnodes): + while ftensor in consumer.inputs(): + idx = consumer.inputs().index(ftensor) + consumer.set_input(idx, itensor) + # create and insert multiref operation + multiref = MultiRef(None, [ftensor, len(cnodes)]) + for idx, itensor in enumerate(itensors): + multiref.set_output(idx, itensor) + multiref.infer_shape() + idx = nodes.index(producers[ftensor]) + 1 if ftensor in producers else 0 + nodes.insert(idx, multiref) + + # instantiate graph inputs / outputs + for idx, tensor in enumerate(inputs): + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + inputs[idx] = tensor + for idx, tensor in enumerate(outputs): + if isinstance(tensor, IRFullTensor): + tensor = tensor.tosub() + outputs[idx] = tensor + + # instantiate to subtensor + for node in nodes: + for idx, ftensor in enumerate(node.inputs()): + ftensors = set() + if isinstance(ftensor, IRFullTensor): + subtensor = ftensor.tosub() + node.set_input(idx, subtensor) + for idx, ftensor in enumerate(node.outputs()): + if isinstance(ftensor, IRFullTensor): + subtensor = ftensor.tosub() + node.set_output(idx, subtensor) + graph = IRGraph(nodes, inputs, outputs, module_name) + return graph + ##### Partition Primitives ##### - def replicate(self, node: Union[IRFwOperation, IRDataOperation], - times=1, reset_dependency=True) -> Optional[List[IRCell]]: + def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> List[IRCell]: """ Partition Primitive: - replicate: replicate a forward or data operation multiple times. @@ -446,23 +504,22 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], if node not in self.nodes(): raise RuntimeError(f"Op {node} not exsits") - - fnodes = [node.replicate() for _ in range(times - 1)] + + fidx = self.detach(node) + fnodes = [node.replicate() for _ in range(times)] # insert forward - fidx = self.nodes().index(node) for idx, fnode in enumerate(fnodes): - self.attach(fnode, fidx + idx + 1) + self.attach(fnode, fidx + idx) # insert backward if isinstance(node.mirror, IRBpOperation): + bidx = self.detach(node.mirror) for fnode in fnodes: fnode.gen_backward() bnodes = [fnode.mirror for fnode in fnodes][::-1] - bidx = self.nodes().index(node.mirror) for idx, bnode in enumerate(bnodes): self.attach(bnode, bidx + idx) - if reset_dependency: - self.reset_dependency() - return [node] + fnodes + #TODO: dependency set + return fnodes def partition(self, node: Union[IRFwOperation, IRDataOperation], algo: GenericDistAlgo, **config) -> Optional[List[IRCell]]: @@ -582,12 +639,10 @@ def merge(self, nodes: List[IRCell], target_node: IRCell): updated.add(fnode._id) return True - def identity(self, input_tensor, dst_op): - raise NotImplementedError - ## Spatial Primitives ## - def assign(self, op: IRCell, ranks: Union[int, List[int]]): + def assign(self, node: Union[IRFwOperation, IRBpOperation], + ranks: Union[int, Tuple[int]]): """ Assign an operator (subgraph) to (multiple) rank(s). @@ -597,25 +652,20 @@ def assign(self, op: IRCell, ranks: Union[int, List[int]]): Corresponding backward operators (if have) will also be replicated and assigned to the same device with it's forward operator - Returns: - True if assigned successfully. - False if not. - """ - if op not in self._nodes: - raise KeyError(f"{op} is not in the graph") - if isinstance(ranks, int): - ranks = [ranks] - if not all([isinstance(rank, int) for rank in ranks]): - raise TypeError("Expected rank to be int") - if len(ranks) > 1: - ops = self.replicate(op, times=len(ranks)) - else: - ops = [op] - for op, rank in zip(ops, ranks): - op.device = rank - # pytorch requirement: forward + backward happened on same device - if op.mirror is not None: - op.mirror.device = rank + @param node Union[IRFwOperation, IRBpOperation]: operator + @param ranks Tuple[int, Tuple[int]]: assigned ranks + + @return sucess bool: always true + """ + assert node in self._nodes, f"{node} is not in the graph" + ranks = (ranks,) if isinstance(ranks, int) else ranks + assert all([isinstance(rank, int) for rank in ranks]), "Expected rank to be int" + nodes = [node] if len(ranks) == 1 else self.replicate(node, times=len(ranks)) + for node, rank in zip(nodes, ranks): + node.device = rank + if isinstance(node.mirror, IRBpOperation): + bnode = node.mirror + bnode.device = rank return True ## Schedule Policy Primitives ## @@ -767,12 +817,13 @@ def extra_repr(self): dscp += f"Inputs: {self.inputs()}\n" # nodes for node in self._nodes: - succ_node_ids = [node._id for node in node.successors()] + # succ_node_ids = [node._id for node in node.successors()] # succ_node_ids = [None] * len(node.outputs()) # for out_idx in range(len(node.outputs())): # node_list = [snode._id for snode in node.successors(out_idx)] # succ_node_ids[out_idx] = node_list - dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}" + # dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}" + dscp += f"\n{node}" # outputs dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" return dscp diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 6a90b2b2..bde06b93 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,6 +1,5 @@ from typing import Optional, List -from cube.ir.cten import IRTensor from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser from cube.graph import IRGraph @@ -21,27 +20,9 @@ def convert_model(model: torch.nn.Module, module_name = smodule.original_name inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) for input in inputs: - if isinstance(input, IRTensor): + if isinstance(input, IRFullTensor): input.requires_grad = False - # convert to SubTensor - for idx, tensor in enumerate(inputs): - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - inputs[idx] = tensor - for idx, tensor in enumerate(outputs): - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - outputs[idx] = tensor - for node in nodes: - for idx, tensor in enumerate(node.inputs()): - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - node.set_input(idx, tensor) - for idx, tensor in enumerate(node.outputs()): - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - node.set_output(idx, tensor) - graph = IRGraph(nodes, inputs, outputs, module_name) + graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) return graph diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index b90f3fbe..ff40018a 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -20,7 +20,7 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): self._inputs = inputs self._outputs = outputs - self._prims: Optional[List[IRAdapterPrim]] = None + self._prims: List[IRAdapterPrim] = [] self._differentiable = False device = set() @@ -37,20 +37,6 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): assert not (is_fw and is_bw), "An IRAdapter cannot serve for both forward and backward stage" self._forward = is_fw - @property - def prims(self) -> List[IRAdapterPrim]: - if self.is_forward: - if self.differentiable(): - return self.diffcolls - else: - return self.forward - else: - if self.differentiable(): - # not able to see - return [] - else: - return self.backward - @property def prims(self) -> List[IRAdapterPrim]: return copy.copy(self._prims) diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index a5efaac8..223b2040 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -110,7 +110,7 @@ def __init__(self, itensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, otensor: IRSubTensor): - super().__init__([itensor], [otensor], indmap=indmap, valmap=(valmap.idx, valmap.chunk_num)) + super().__init__([itensor], [otensor], indmap=indmap, valmap=valmap) self.signature = f"cube.runtime.adapter.select" def __repr__(self): diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 2a3a46dd..764bf93f 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -73,6 +73,15 @@ def __init__(self, # the comment for code generation self._comment: Optional[str] = None + @property + def cid(self) -> int: + """ + Get cell id + + @return cid int: the cell id. + """ + return self._id + @property def device(self): return copy.copy(self._device) @@ -399,7 +408,7 @@ class IRTensor: and will be translated to None in code generation. """ - _attr = ['name', '_is_param', '_is_grad', '_requires_grad', '_dtype', '_grad_accum'] + _meta = ['name', '_is_param', '_is_grad', '_requires_grad', '_dtype'] def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): @@ -417,8 +426,15 @@ def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): # tensor gradient self._requires_grad: bool = True self._grad: Optional[Union[IRTensor, float]] = None - # multi-reference id - self._grad_accum: Tuple[int] = (0, 1) + + @property + def tid(self) -> int: + """ + Get tensor id + + @return cid int: the tensor id. + """ + return self._id @property def dtype(self) -> IRDType: @@ -458,23 +474,6 @@ def device(self, val: Union[int, List[int]]): "tensor placement is not allowed to set manually" ) - @property - def grad_accum(self) -> Tuple[int, int]: - return self._grad_accum - - @grad_accum.setter - def grad_accum(self, accum: Optional[Tuple[int, int]]): - """! - Set gradient accumulation: (idx, chunks) - """ - if accum is None: - self._grad_accum - else: - assert len(accum) == 2 and all(isinstance(acc, int) for acc in accum), \ - "Expected accum to be [int, int]: [idx, chunks]" - assert accum[0] < accum[1] - self._grad_accum = tuple(accum) - def as_param(self): """ Set the tensor as trainable parameter diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 7ee1bf2d..24834c4f 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -151,18 +151,12 @@ def gen_backward(self) -> IRCell: bnode = IRBpOperation(self) return bnode - def __repr__(self): - sign = self.signature.split('.')[-1] - dscp = f'FwOp{self._id}-{self.device}(sign={sign}, inputs={self.inputs()}, outputs={self.outputs()})' - return dscp - - def module_repr(self) -> str: - """ - Weight-hidden string representation - """ + def __repr__(self) -> str: sign = self.signature.split('.')[-1] ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_param()] - dscp = f'FwOp{self._id}-{self.device}(sign={sign}, inputs={ins}, outputs={self.outputs()})' + dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " + f"inputs={ins}, " + f"outputs={self.outputs()})") return dscp @@ -231,7 +225,9 @@ def replicate(self): return cpy def __repr__(self) -> str: - dscp = f'BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, inputs={self.inputs()}, outputs={self.outputs()})' + dscp = (f"BwOp{self._id}-{self.device}(FwOp{self.mirror._id}, " + f"inputs={self.inputs()}, " + f"outputs={self.outputs()})") return dscp @@ -295,7 +291,7 @@ def algorithms(self, tag: Optional[str] = None): return template(self) def __repr__(self): - dscp = f'DataLoader{self._id}-{self.device}(outputs={self.outputs()})' + dscp = (f"DataLoader{self._id}-{self.device}(outputs={self.outputs()})") return dscp def module_repr(self) -> str: diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 4e817fd1..1026c49e 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -18,46 +18,46 @@ val is always (0/1) """ - -from typing import List, Optional, Union, Tuple -import copy -import math +from typing import List, Optional, Union, Tuple, NewType, Dict from cube.ir.cten import IRCell, IRTensor import cube.ir.dtype as irdtype +StartEnd = NewType('[start:end)', Tuple[int, int]) +IdxChunk = NewType('(index, chunks)', Tuple[int, int]) -class IndexMap: - def __init__(self, indmap): +class IndexMap: - if not isinstance(indmap, tuple): - raise TypeError("Expected indmap to be a tuple") + def __init__(self, indmap: Tuple[StartEnd]): + """! + Create an index map. - if not all([isinstance(s, slice) for s in indmap]): - raise NotImplementedError( - "Only support for sliced index mapping" - ) - self._indices: List[slice] = indmap + @param indmap Union[Tuple[StartEnd], IndexMap]: index range [start, end) for each dimension + + @return indmap IndexMap: the created new instance of index map. + """ + if isinstance(indmap, IndexMap): + indmap = indmap.indices + assert all(isinstance(dim, tuple) and len(dim) == 2 for dim in indmap), "expected Tuple[Tuple[int, int]]" + self._indices: Tuple[StartEnd] = tuple(indmap) + self._shape = tuple(end - start for (start, end) in self._indices) def __eq__(self, other): if isinstance(other, IndexMap): if self.ndims != self.ndims: return False - for myslicer, oslicer in zip(self.get(), other.get()): - mstart, mstop = myslicer.start, myslicer.stop - mstep = myslicer.step if myslicer.step is not None else 1 - ostart, ostop = oslicer.start, oslicer.stop - ostep = oslicer.step if oslicer.step is not None else 1 - if mstart != ostart or mstop != ostop or mstep != ostep: + for dim in range(self.ndims): + if self.indices[dim] != other.indices[dim]: return False return True return False - def get(self): - """ - Get indmap - """ + def __hash__(self) -> int: + return hash(tuple([self.ndims]+list(self._indices))) + + @property + def indices(self) -> Tuple[StartEnd]: return self._indices @property @@ -67,74 +67,40 @@ def ndims(self) -> int: """ return len(self._indices) - @property - def neles(self) -> int: - """ - Number of elements of the index map - """ - nelements = 1 - for slicer in self._indices: - count = slicer.stop - slicer.start - if slicer.step: - count = int(count // slicer.step) - nelements *= count - return nelements - @property def shape(self) -> List[int]: """ Get the shape of the slice """ - shape = list() - for slicer in self._indices: - count = slicer.stop - slicer.start - if slicer.step: - count = int(count // slicer.step) - shape.append(count) - return shape + return self._shape - def map(self, submap): - """ + def map(self, submap: Tuple[StartEnd]): + """! Map from the current indmap by sub_indices. - Args: - sub_indices: IndexMap - - Returns: - sub_indices: IndexMap + @param submap Union[Tuple[StartEnd], IndexMap]: IndexMap of this IndexMap + @return indmap IndexMap: the mapped IndexMap """ - if not isinstance(submap, IndexMap): - raise TypeError("Expected IndexMap") - if self.ndims != submap.ndims: - raise ValueError("Expected same length of sub_indices") - - # e.g., (slice(0, M), slice(0, int(K // 2)) + submap: IndexMap = IndexMap(submap) + assert self.ndims == submap.ndims, "Expected same dimensions of submap" sub = list() - for dim_indices, dim_sub_indices in zip(self.get(), submap.get()): - start, stop = dim_indices.start, dim_indices.stop - step = dim_indices.step if dim_indices.step else 1 - - sub_start, sub_stop = dim_sub_indices.start, dim_sub_indices.stop - sub_step = dim_sub_indices.step if dim_sub_indices.step else 1 - - new_start = start + sub_start - new_stop = new_start + sub_stop - sub_start - new_step = step * sub_step - if new_stop > stop: - raise ValueError("Trying to map a index out of range") - sub.append(slice(new_start, new_stop, new_step)) + for dim in range(self.ndims): + s1, e1 = self.indices[dim] + s2, e2 = submap.indices[dim] + start = s1 + s2 + end = start + e2 - s2 + assert end <= e1, f"select out of boundary at dim {dim}: ({self})[{submap}]" + sub.append((start, end)) return IndexMap(tuple(sub)) - def overlap(self, other): + def overlap(self, other) -> bool: """ Check if this indmap overlapped with the other - Args: - other: IndexMap + @param other IndexMap - Returns: - Boolean: True has overlap, otherwise False + @return overlap bool: True has overlap, otherwise False """ if not isinstance(other, IndexMap): raise TypeError("Expected IndexMap") @@ -142,112 +108,34 @@ def overlap(self, other): if other.ndims != self.ndims: raise TypeError("Expected same dimension") - for slicer1, slicer2 in zip(self.get(), other.get()): - start1, stop1 = slicer1.start, slicer1.stop - step1 = slicer1.step if slicer1.step else 1 - - start2, stop2 = slicer2.start, slicer2.stop - step2 = slicer2.step if slicer2.step else 1 - - if step1 == step2: - if min(stop1, stop2) <= max(start1, start2): - return False - elif start1 % step1 != start2 % step2: - return False - else: - raise NotImplementedError(f"not supported for differnt steps") + for dim in range(self.ndims): + start1, end1 = self.indices[dim] + start2, end2 = other.indices[dim] + if min(end1, end2) <= max(start1, start2): + return False return True def __and__(self, other): - """ + """! Get the common part - Args: - other: IndexMap + @param other IndexMap: the other one - Returns: - IndexMap for the common part + @return indexmap IndexMap: index map for the common part """ if not self.overlap(other): return None - slices = list() - for slicer1, slicer2 in zip(self.get(), other.get()): - start1, stop1 = slicer1.start, slicer1.stop - step1 = slicer1.step if slicer1.step else 1 - - start2, stop2 = slicer2.start, slicer2.stop - step2 = slicer2.step if slicer2.step else 1 - - if step1 == step2: - start = max(start1, start2) - stop = min(stop1, stop2) - slices.append(slice(start, stop, step1)) - else: - raise NotImplementedError(f"not supported for differnt steps") - return IndexMap(tuple(slices)) - - def __sub__(self, other) -> Optional[List]: - """ - Get the remaining part. - We reuqire other should completely inside this tensor - and the remaining part should be only one tile, else - will return None - - Args: - other: IndexMap - - Returns: - IndexMap for the remaining part - """ - if not isinstance(other, IndexMap): - raise TypeError("Expected IndexMap") - if self.ndims != other.ndims: - return None - dim_common: List[List[slice]] = [list() for _ in range(self.ndims)] - dim_differ: List[List[slice]] = [list() for _ in range(self.ndims)] - for dim, (slicer1, slicer2) in enumerate(zip(self.get(), other.get())): - # self indices - start1, stop1 = slicer1.start, slicer1.stop - step1 = slicer1.step if slicer1.step else 1 - # other indices - start2, stop2 = slicer2.start, slicer2.stop - step2 = slicer2.step if slicer2.step else 1 - if step1 != 1 or step2 != 1: - return None - # no intersection - if min(stop1, stop2) <= max(start1, start2): - return None - # set common - start = max(start1, start2) - stop = min(stop1, stop2) - dim_common[dim].append(slice(start, stop, step1)) - # set difference - if start1 == start2: - if stop2 < stop1: - dim_differ[dim].append(slice(stop2, stop1, step1)) - elif stop1 == stop2: - if start1 < start2: - dim_differ.append(slice(start1, start2, step1)) - else: - raise NotImplementedError("Multipe indexmap is not supported") - indmaps = list() - splitdim = set() - slices = list() + tile = [] for dim in range(self.ndims): - common = dim_common[dim] - differ = dim_differ[dim] - if len(common) + len(differ) != 1: - raise NotImplementedError("Multipe indexmap is not supported") - if len(differ) == 1: - splitdim.add(dim) - slices.append(differ[0]) - else: - slices.append(common[0]) - indmaps.append(IndexMap(tuple(slices))) - return indmaps + start1, end1 = self.indices[dim] + start2, end2 = other.indices[dim] + start = max(start1, start2) + end = min(end1, end2) + tile.append((start, end)) + return IndexMap(tuple(tile)) - def __repr__(self): - dscp = repr(self._indices) + def __repr__(self) -> str: + dscp = ','.join(f'{start}:{end}' for (start, end) in self.indices) return dscp @@ -255,106 +143,102 @@ class ValueMap: r""" Represent the value split. - Value is represented as a summation of several variables - - value = \sigma_{i=1}^{chunk_num} a_i - - two tensors consider as same value mapping: - they have same chunk num and share the same a_i (idx) - - Note we regard these mapping as same: - 1.0 = 0.9 (a1) + 0.1 (a2) - 1.0 = 0.4 (a1) + 0.6 (a2) + replica: the replicated group: + different replicated operator (no gradient accumulation) stands for different group - The mapping doesn't consider what a1 really contains, but only - consider the variable (a) itself and number of variable. + weight: the partitioned but tensor replicated group: + different replicated tensor (gradient accumulation) stands for different group """ - def __init__(self, idx: int, chunk_num: int): - if idx >= chunk_num or idx < 0: - raise ValueError(f"Expected idx {idx} in [0, {chunk_num})") - self._idx = idx - self._chunk_num = chunk_num + def __init__(self, weight: IdxChunk): + """ + Create a value map. + @param weight Union[IdxChunk, ValueMap]: the (idx, nchunks) - @property - def idx(self): - return self._idx + @return valmap ValueMap: a new instance. + """ + if isinstance(weight, ValueMap): + weight = weight.weight + assert len(weight) == 2 and all(isinstance(i, int) for i in weight), \ + "expected weight to be (idx, nchunks)" + self._weight = weight @property - def chunk_num(self): - return self._chunk_num - - def map(self, sub_map): - if not isinstance(sub_map, ValueMap): - raise TypeError("Expected sub_map to be ValueMap") - idx = self.idx * sub_map.chunk_num + sub_map.idx - chunk_num = self.chunk_num * sub_map.chunk_num - return ValueMap(idx, chunk_num) - - def overlap(self, other): + def weight(self) -> IdxChunk: + """! + Get value partitioned chunks in tha accumulcated group + """ + return self._weight + + def overlap(self, other) -> bool: + """! + Check on value overlapping. + Note the overlap can only be within a same accumulation group and + a same replication group. + """ if not isinstance(other, ValueMap): raise TypeError("Expected ValueMap") - if self.chunk_num == other.chunk_num: - return self.idx == other.idx - else: - if self.chunk_num == 1 or other.chunk_num == 1: - return True - else: - chk1, chk2 = self.chunk_num, other.chunk_num - time1 = int(chk2 / math.gcd(chk1, chk2)) - time2 = int(chk1 / math.gcd(chk1, chk2)) - span1 = (self.idx * time1, self.idx * time1 + time1) - span2 = (other.idx * time2, other.idx * time2 + time2) - if max(span1[0], span2[0]) < min(span1[1], span2[1]): - return True - else: - return False + idx1, nchunk1 = self.weight + idx2, nchunk2 = self.weight + span1 = (idx1 * nchunk2, idx1 * nchunk2 + nchunk2) + span2 = (idx2 * nchunk1, idx2 * nchunk1 + nchunk1) + if max(span1[0], span2[0]) < min(span1[1], span2[1]): + return True + return False - def __eq__(self, other): + def __eq__(self, other) -> bool: + """! + Check whether tensor is same to other tensor. + Note we treat tensors in different replica region as different + tensors, also they may have same data in reality. + """ if isinstance(other, ValueMap): - if other.idx == self.idx and other.chunk_num == self.chunk_num: - return True + return other.weight == self.weight return False + def __hash__(self) -> int: + return hash(self._weight) + + def map(self, submap: IdxChunk): + """! + Select the value chunk at position (idx, chunk) given the current view + No change will make for the replica group. + + @param idnmap Union[ValueMap, IdxChunk]: the (index, chunk) for current view + + @return valmap ValueMap: the selected one + """ + if isinstance(submap, ValueMap): + submap: IdxChunk = submap.weight + idx, chunk = self.weight + sub_idx, sub_chunk = submap + idx = idx * sub_chunk + sub_idx + chunk = sub_chunk * chunk + return ValueMap((idx, chunk)) + def __and__(self, other): """ Find the common part + + @param other ValueMap + + @return Optional[None] """ if not isinstance(other, ValueMap): raise TypeError("Expected ValueMap for & operator") if not self.overlap(other): return None - if self.chunk_num == other.chunk_num: - return ValueMap(self.idx, self.chunk_num) - if self.chunk_num == 1: - return ValueMap(other.idx, other.chunk_num) + if self.weight[1] == other.weight[1]: + return ValueMap(self.weight) + if self.weight[1] == 1: + return ValueMap(other.weight) + elif other.weight[1] == 1: + return ValueMap(self.weight) else: - return ValueMap(self.idx, self.chunk_num) + raise ValueError(f"Not supported common value map: {self}, {other}") def __repr__(self): - return f'({self.idx}/{self.chunk_num})' - - -def _to_indmap(indmap: Union[Tuple, IndexMap]) -> IndexMap: - if not isinstance(indmap, tuple) and not isinstance(indmap, IndexMap): - raise TypeError("Expected indmap to be tuple or IndexMap") - if isinstance(indmap, tuple): - indmap = IndexMap(indmap) - return indmap - - -def _to_value_map(valmap: Union[Tuple, ValueMap, None]) -> ValueMap: - if not isinstance(valmap, tuple) and \ - not isinstance(valmap, ValueMap) and \ - not valmap is None: - raise TypeError("Expected valmap to be tuple, IndexMap or None") - if valmap is None: - valmap = ValueMap(0, 1) - elif isinstance(valmap, tuple): - if len(valmap) != 2: - raise ValueError("Expected tuple to be (idx, chunk_num)") - valmap = ValueMap(*valmap) - return valmap + return f'({self.weight[0]}/{self.weight[1]})' class IRFullTensor(IRTensor): @@ -379,10 +263,13 @@ def __init__(self, shape=None, name=None, requires_grad=True, dtype=irdtype.floa self._ctensors : List[IRSubTensor] = list() # record all created sub_tensors - self._segments : List[IRSubTensor] = list() + self._segments : Dict[(ValueMap, IndexMap), int] = dict() self.requires_grad = requires_grad + def __hash__(self) -> int: + return self._id + def __copy__(self): """ Full tensor should only exist one instance per id @@ -392,6 +279,15 @@ def __copy__(self): """ return self + def like(self): + """! + Create a IRFullTensor with same meta data but a different id. + + @return tensor IRFullTensor: the created tensor + """ + tensor = IRFullTensor(self.shape, self.name, self.requires_grad, self.dtype) + return tensor + @property def producers(self) -> List[IRCell]: """ @@ -467,12 +363,6 @@ def clear_producer_consumer(self) -> int: self._consumers = [] self._ctensors = [] - def subtensors(self): - """ - Get created sub-tensors of this tensor. - """ - return copy.copy(self._segments) - @property def grad(self) -> Optional[Union[IRTensor, float]]: return self._grad @@ -518,82 +408,41 @@ def as_grad(self): self._is_grad = True return self - def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, None], shape: List[int]): - """ + def select(self, indmap: IndexMap, valmap: ValueMap): + """! Select a SubTensor from FullTensor. - Note due to implementation issue, one value in the full tensor - cannot be splitted by different valmap - - Args: - indmap: the index of this tensor's index - - valmap: how the tensor mapped from original value - - shape: the sub_tensor shape. + @param indmap IndexMap: the index range of this tensor + @param valmap ValueMap: the value range of this tensor - Returns: - IRSubTensor + @return subtensor IRSubTensor: the selected SubTensor """ - indmap = _to_indmap(indmap) - valmap = _to_value_map(valmap) - + indmap, valmap = IndexMap(indmap), ValueMap(valmap) + keys = (indmap, valmap) + # print(f'key: {keys}, hash {hash(keys)}') # return tensor to keep id same for same sub tensor - for sub_tensor in self.subtensors(): - if sub_tensor.indmap == indmap and sub_tensor.valmap == valmap: - sub_tensor = copy.copy(sub_tensor) - return sub_tensor - - sub_tensor = IRSubTensor(self, indmap, valmap, shape) - for attr in IRFullTensor._attr: - setattr(sub_tensor, attr, getattr(self, attr)) - self._segments.append(sub_tensor) - sub_tensor._dirty_grad = True - return sub_tensor - - def overlap(self, other): - """ - Check if the two tensor is overlapped. - - Returns: - True if they are sharing co-located position in - the full tensor, otherwise False - """ - if not isinstance(other, IRTensor): - raise TypeError("Expected Tensor") - if isinstance(other, IRFullTensor): - return self == other - elif isinstance(other, IRSubTensor): - return other.parent == self + if keys in self._segments: + tid = self._segments[keys] + sub_tensor = IRSubTensor(self, indmap, valmap, tid=tid) else: - raise TypeError("Customized IRTensor not support") - - def common(self, other) -> Optional[IRTensor]: - """ - Get the common sub-tensor - - Args: - IRTensor - - Returns: - None for not overlap, - else IRSubTensor or IRFullTensor - """ - return other if self.overlap(other) else None + sub_tensor = IRSubTensor(self, indmap, valmap) + self._segments[keys] = sub_tensor.tid + return sub_tensor def tosub(self): - """ + """! Convert to SubTensor by selecting all indmap and full value + + @return sub_tensor IRSubTensor: the sub-tensor """ if self.shape is None: raise RuntimeError("Expected know shape") - slicers = list() - for dim_len in self.shape: - slicers.append(slice(0, dim_len, 1)) + indmap = [] + for dimlen in self.shape: + indmap.append((0, dimlen)) sub_tensor = self.select( - indmap=tuple(slicers), - valmap=None, - shape=self.shape + indmap=tuple(indmap), + valmap=(0, 1), ) return sub_tensor @@ -604,45 +453,38 @@ def __repr__(self): class IRSubTensor(IRTensor): - def __init__(self, full_tensor: IRTensor, - indmap: List[Union[Tuple, IndexMap]], - valmap: Optional[ValueMap] = None, shape=None): + def __init__(self, ftensor: IRFullTensor, + indmap: Union[Tuple[StartEnd], IndexMap], + valmap: Union[Tuple[StartEnd], ValueMap], + **kwargs): """ Create an IRSubTensor. - Args: - full_tensor: the full tensor - indmap: index list - valmap: the value operation to merge SubTensors into one + @param ftensor IRFullTensor: the full tensor + @param indmap IndexMap: index map + @param valmap ValueMap: value map """ - if not isinstance(full_tensor, IRFullTensor): - raise TypeError(f"Expected IRFullTensor but got {full_tensor}") - super().__init__(shape=shape, name=full_tensor.name) - + indmap, valmap = IndexMap(indmap), ValueMap(valmap) + assert isinstance(ftensor, IRFullTensor), "Expcted ftensor to be IRFullTensor" + super().__init__(shape=indmap.shape, name=ftensor.name, **kwargs) + for attr in IRFullTensor._meta: + setattr(self, attr, getattr(ftensor, attr)) + self.cell = None # the full tensor - self._full_tensor = full_tensor - + self._full_tensor = ftensor # the index from full_tensor - self._indmap = _to_indmap(indmap) - + self._indmap: IndexMap = indmap # val map - self._valmap = _to_value_map(valmap) - + self._valmap: ValueMap = valmap # grad flag self._dirty_grad = True - def __eq__(self, other): - - if isinstance(other, IRFullTensor): - return self.parent == other and \ - self.shape == other.shape and \ - self.valmap == ValueMap(0, 1) + def __eq__(self, other) -> bool: if isinstance(other, IRSubTensor): - return self.parent == other.parent and \ - self.indmap == other.indmap and \ - self.valmap == other.valmap and \ - self.shape == other.shape - return False + return self._id == other._id + + def __hash__(self) -> int: + return self._id @property def parent(self) -> IRFullTensor: @@ -652,29 +494,129 @@ def parent(self) -> IRFullTensor: return self._full_tensor @property - def indmap(self) -> IndexMap: + def indmap(self) -> Tuple[StartEnd]: """ - Return indmap list mapped to the full tensor + Get index range of each dimension of this tensor in its full tensor + + @return indices Tuple[StartEnd]: indices """ - return copy.copy(self._indmap) + return self._indmap.indices @property - def valmap(self) -> ValueMap: - return copy.copy(self._valmap) + def valmap(self) -> IdxChunk: + """ + Get value range of this tensor in tis full tensor + + @return idxchunk IdxChunk: (idx, nchunks) + """ + return self._valmap.weight @property def ndims(self) -> int: return len(self.shape) + def splitdims(self) -> Tuple[int]: + """! + Get partitioned dimensions + + @return dims int: the partitioned dimension. + """ + return tuple( + dim for dim in range(self.ndims) if self.shape[dim] != self.parent.shape[dim] + ) + + def catdims(self, other: IRTensor) -> Optional[int]: + """! + Get concatable dimensions with other IRSubTensor + + @parm other IRSubTensor + @return dim int: the concatable dimension. None means no such dimension + """ + assert isinstance(other, IRSubTensor), "expected IRSubTensor" + if other.parent != self.parent or self.valmap != other.valmap: + return None + cat_dim: int = None + for dim in range(self.ndims): + if self.indmap[dim] != other.indmap[dim]: + s1, e1 = self.indmap[dim] + s2, e2 = self.indmap[dim] + if min(e1, e2) == max(s1, s2): + if cat_dim is not None: + return None + else: + cat_dim = dim + else: + return None + return cat_dim + + def concat(self, other: IRTensor, dim: int) -> IRTensor: + """! + concat dimension with other IRSubTensor. The concatenate + order will follow the index map order. + + @param other IRSubTensor + @param dim int: the concat dimension + @return tensor IRSubTensor: the concatenated tensor + """ + assert isinstance(other, IRSubTensor), "expected IRSubTensor" + assert self.parent == other.valmap and self.valmap == other.valmap + indmap = [] + for cdim in range(self.ndims): + if cdim == dim: + (s1, e1), (s2, e2) = self.indmap[cdim], other.indmap[cdim] + assert min(e1, e2) == max(s1, s2), f"fail to concat: {cdim} should be concatable" + indmap.append((min(s1, s2), max(e1, e2))) + else: + assert self.indmap[cdim] == other.indmap[cdim], f"fail to concat: {cdim} should be same" + indmap.append(self.indmap[cdim]) + valmap = self.valmap + tensor = self.parent.select(tuple(indmap), valmap) + return tensor + + def accumable(self, tensors: Union[IRTensor, List[IRTensor]]) -> bool: + """! + Check whether tensors are accumable with this tensor + + @param: tensors Union[IRTensor, List[IRTensor]] + @return accumable bool: True if accumable + """ + tensors: List[IRSubTensor] = [tensors,] if isinstance(tensors, IRSubTensor) else tensors + assert all(isinstance(t, IRSubTensor) for t in tensors), "Expected IRSubTensor or List[IRSubTensor]" + if any(t.parent != self.parent for t in tensors) or any(t.indmap != self.indmap for t in tensors): + return False + if any(t.indmap != self.indmap for t in tensors): + return False + if any(t.valmap[1] != self.valmap[1] for t in tensors): + return False + return self.valmap[1] % (len(tensors) + 1) == 0 + + def accum(self, tensors: Union[IRTensor, List[IRTensor]]) -> IRTensor: + """! + Accumulate tensor on value dimension. + The replica id will be + + @param: tensors Union[IRTensor, List[IRTensor]] + @return tensor IRSubTensor: accumulated tensor + """ + tensors: List[IRSubTensor] = [tensors,] if isinstance(tensors, IRSubTensor) else tensors + assert self.accumable(tensors), "Not accumable" + nreduce = len(tensors) + 1 + assert self.valmap[1] % nreduce == 0 + # TODO: make accum more robust + cid = min(t.valmap[0] for t in [self] + tensors) // nreduce + valmap = (cid, self.valmap[1] // nreduce) + indmap = self.indmap + tensor = self.parent.select(indmap, valmap) + return tensor + def __copy__(self): """ Copy the tensor that will have the exactly same id except the empty attached cell - Returns: - tensor + @return tensor IRSubTensor: the same tensor in a new instance """ - tensor = IRSubTensor(self.parent, self.indmap, self.valmap, self._shape) + tensor = IRSubTensor(self.parent, self.indmap, self.valmap, tid=self.tid) for key in self.__dict__: setattr(tensor, key, getattr(self, key)) # clear attached cells @@ -704,26 +646,16 @@ def grad(self) -> Optional[Union[IRTensor, float]]: self._grad = full_grad # this tensor is consumed elif self in self.cell.inputs(): - # ref_consumers = list() - # for consumer in self.parent.consumers: - # for itensor in consumer.inputs(): - # if self.overlap(itensor): - # # TODO: we should guarantee in final status itensor == self - # # replicated nodes will have same node id - # if consumer._id not in ref_consumers: - # ref_consumers.append(consumer._id) - # # if one node has multiple same tensors, - # # will consider them as one - # break - # ref_times = len(ref_consumers) - # if ref_times == 0: - # raise RuntimeError("Internal error: consumer doesn't have the operator attached to this tensor") - # idx = ref_consumers.index(self.cell._id) - assert self.grad_accum is not None, "not supported for gradient accumulation" + # for backard, we assume in final distributed graph, + # each tensor can be represented as nested + consumers = [] + for ctensor, consumer in zip(self.parent.ctensors, self.parent.consumers): + if ctensor == self and consumer.cid not in consumers: + consumers.append(consumer.cid) + valmap = (consumers.index(self.cell.cid), len(consumers)) grad = full_grad.select( indmap = self.indmap, - valmap = self.grad_accum, #(idx, ref_times), - shape = self.shape + valmap = valmap, ) self._grad = grad self._dirty_grad = False @@ -733,47 +665,38 @@ def grad(self) -> Optional[Union[IRTensor, float]]: grad = full_grad.select( indmap = self.indmap, valmap = (0, 1), - shape = self.shape ) self._grad = grad else: - raise RuntimeError("Visit graidient of a tensor that is potentially generated by IRAdapter") + raise RuntimeError("Visit gradient of a tensor that is potentially generated by IRAdapter") self._dirty_grad = False self._requires_grad = False if full_grad is None else True return self._grad @property def requires_grad(self) -> bool: - _ = self.grad - return self._requires_grad + return self.parent._requires_grad # partition primitives - def select(self, indmap: Union[Tuple, IndexMap], valmap: Union[Tuple, ValueMap, None], shape=None) -> IRTensor: + def select(self, + indmap: Union[Tuple[StartEnd], IndexMap], + valmap: Union[IdxChunk, ValueMap]) -> IRTensor: """ Select an IRSubTensor - Args: - indmap: the index of this tensor's index + @param indmap IndexMap: the index map of this tensor's index - valmap: the value operation to merge - co-located indmap of SubTensors into one + @param valmap ValueMap: the value map of this tensor's value - shape: the sub_tensor shape - - Returns: - IRSubTensor + @return subtensor IRSubTensor: the selected tensor """ - sub_ind_map = _to_indmap(indmap) - sub_valmap = _to_value_map(valmap) - + indmap, valmap = IndexMap(indmap), ValueMap(valmap) # index mapping - index_map = self.indmap.map(sub_ind_map) + indmap = self._indmap.map(indmap) # value mapping - valmap = self.valmap.map(sub_valmap) - - sub_tensor = self.parent.select(index_map, valmap, shape) - sub_tensor.grad_accum = None + valmap = self._valmap.map(valmap) + sub_tensor = self.parent.select(indmap, valmap) return sub_tensor def replicate(self, num: int) -> List[IRTensor]: @@ -783,11 +706,12 @@ def replicate(self, num: int) -> List[IRTensor]: @return tensor IRTensor: the copied tensor """ - aidx, chunks = self.grad_accum tensors = [] - for idx in range(num): - tensor = copy.copy(self) - tensor.grad_accum = (aidx * num + idx, chunks * num) + for _ in range(num): + tensor = self.parent.select( + indmap=self.indmap, + valmap=self.valmap, + ) tensors.append(tensor) return tensors @@ -806,24 +730,19 @@ def split_dim(self, dim: int, num: int) -> List[IRTensor]: assert self.shape[dim] % num == 0, f"Expected dimension can be split: {self.shape[dim]} % {num} != 0" chunk_size = self.shape[dim] // num - shape_slicer = list() - chunk_shape = list() + indmap = [] for tdim, nele in enumerate(self.shape): if tdim != dim: - shape_slicer.append(slice(0, nele, 1)) - chunk_shape.append(nele) + indmap.append((0, nele)) else: - shape_slicer.append(None) - chunk_shape.append(chunk_size) + indmap.append(None) sub_tensors = list() for cid in range(num): - shape_slicer[dim] = slice(chunk_size * cid, chunk_size * (cid + 1), 1) + indmap[dim] = (chunk_size * cid, chunk_size * (cid + 1)) sub_tensor = self.select( - indmap = tuple(shape_slicer), - valmap = None, - shape = chunk_shape + indmap=tuple(indmap), + valmap=(0,1), ) - sub_tensor.grad_accum = self.grad_accum sub_tensors.append(sub_tensor) return sub_tensors @@ -837,90 +756,68 @@ def split_val(self, num: int) -> List[IRTensor]: @return sub_tensors List[IRSubTensor]: the generated sub-tensors """ # full shape - shape_slicer = list() + indmap = [] for nele in self.shape: - shape_slicer.append(slice(0, nele, 1)) + indmap.append((0, nele)) sub_tensors = list() for idx in range(num): + valmap = self._valmap.map((idx, num)) sub_tensor = self.select( - indmap = tuple(shape_slicer), - valmap = (idx, num), - shape = self.shape + indmap=tuple(indmap), + valmap=valmap, ) - sub_tensor.grad_accum = self.grad_accum sub_tensors.append(sub_tensor) return sub_tensors def overlap(self, other) -> bool: - """ - Check if the two tensor is overlapped. + """! + Check whether the two subtensors are overlapped. - Returns: - True if they are sharing co-located position in - the full tensor, otherwise False + @param other IRSubTensor + + @return overlapped bool: True if they are overlapped else False """ - if not isinstance(other, IRTensor): - return False - if isinstance(other, IRFullTensor): - return self.parent == other - elif isinstance(other, IRSubTensor): + if isinstance(other, IRSubTensor): if self.parent != other.parent: return False - return self.indmap.overlap(other.indmap) and \ - self.valmap.overlap(other.valmap) - else: - raise TypeError("Customized IRTensor not support") + return self._indmap.overlap(other._indmap) and \ + self._valmap.overlap(other._valmap) + return False - def common(self, other): - """ + def common(self, other) -> Optional[IRTensor]: + """! Get the common sub-tensor - Args: - IRTensor + @param other IRSubTensor - Returns: - None for not overlap, - else IRSubTensor or IRFullTensor + @return subtensor Optional[IRSubTensor]: the common sub-tensor. + If not common region, return None """ if self.overlap(other): - if isinstance(other, IRFullTensor): - return self - elif isinstance(other, IRSubTensor): - indmap = self.indmap & other.indmap - valmap = self.valmap & other.valmap - sub_tensor = self.parent.select( - indmap = indmap, - valmap = valmap, - shape = indmap.shape - ) - return sub_tensor - else: - raise NotImplementedError("Customized IRTensor not support") + indmap = self._indmap & other._indmap + valmap = self._valmap & other._valmap + sub_tensor = self.parent.select( + indmap = indmap, + valmap = valmap, + ) + return sub_tensor return None - def difference(self, other): - """ - Get differene part of sub-tensor - - Currently this requires tensor to be subset - - Args: - other: IRSubTensor - - Returns: - None for fail - """ - pass - - def __repr__(self): + def __repr__(self) -> str: anno = 't' if self.is_param(): anno = 'w' if self.is_grad(): anno = 'g' - dscp = f'{anno}{self._id}(p{self.parent._id},{self.shape},{self.valmap})' + split_dims = self.splitdims() + dscp = f'{anno}{self._id}(p{self.parent._id},{self.shape},d{split_dims},v{self._valmap})' return dscp - def extra_repr(self): - dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device}, ind={self.indmap}, val={self.valmap})' + def extra_repr(self) -> str: + anno = 't' + if self.is_param(): + anno = 'w' + if self.is_grad(): + anno = 'g' + dscp = f'{anno}{self._id}(id={self._id}, shape={self.shape}, dev={self.device}, ind=[{self._indmap}], val={self._valmap})' return dscp diff --git a/cube/logics/model.py b/cube/logics/model.py index 180ffce7..2eb8c059 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -8,45 +8,6 @@ from cube.ir.cten import IRTensor -__all__ = ['forward'] - - -class _TensorGener: - - def __init__(self): - self.symbol = dict() - - def renew(self, val: Any, keep_param=True): - self._check_is_sub_tensor(val) - if not isinstance(val, IRTensor): - return val - if keep_param and val.is_param(): - return val - if val.parent._id not in self.symbol: - self.symbol[val.parent._id] = val.parent.like() - new_val = self.symbol[val.parent._id].select( - indmap=val.indmap, - valmap=val.valmap, - shape=val.shape - ) - return new_val - - def set_map(self, origin: Any, new: Any): - self._check_is_sub_tensor(origin) - self._check_is_sub_tensor(new) - if isinstance(origin, IRSubTensor): - tid = origin.parent._id - if isinstance(new, IRSubTensor): - self.symbol[tid] = new.parent - return - self.symbol[tid] = new - - def _check_is_sub_tensor(self, tensor): - if isinstance(tensor, IRTensor): - if not isinstance(tensor, IRSubTensor): - raise TypeError("Tensor only allows to be SubTensor") - - def forward(graph: IRGraph, *args) -> IRGraph: """ Forward the IRGraph, replacing all the intermediate tensors @@ -72,15 +33,6 @@ def forward(graph: IRGraph, *args) -> IRGraph: while itensor in graph.outputs(): oidx = graph.outputs().index(itensor) graph.set_output(oidx, arg) - # setup gradient accum - for ftensor in graph.full_tensors(): - naccum = len(ftensor.ctensors) - for idx, ctensor in enumerate(ftensor.ctensors): - ctensor.grad_accum = (idx, naccum) - # actually producer doesn't need to know accumulation - naccum = len(ftensor.producers) - for idx, ptensor in enumerate(ftensor.ptensors): - ptensor.grad_accum = (idx, naccum) # generate backward reverse is only to make op id looks consecutive for fnode in [n for n in graph.nodes() if isinstance(n, IRFwOperation)][::-1]: fnode.gen_backward() diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 4900ca3b..38d02363 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Tuple import torch import torch.nn.functional as TorchF @@ -10,6 +10,14 @@ def identity(tensor: torch.Tensor) -> torch.Tensor: return tensor +def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: + """ + identity forward. Create multiple same tensor. + """ + assert times > 1, "multiref can only be used for num of tensor >= 2" + return tuple([tensor] * times) + + def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], stride: int, padding: List[int], dilation, groups: int = 1): """ From a3d86610ed11baee19e3a32f24f3e2f28a004d07 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Jun 2022 15:38:08 +0800 Subject: [PATCH 0885/1892] fix code bug and easy for debug --- cube/codegen/register.py | 4 ++-- cube/execplan/execplan.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/codegen/register.py b/cube/codegen/register.py index ba812b5c..0f38a371 100644 --- a/cube/codegen/register.py +++ b/cube/codegen/register.py @@ -40,7 +40,6 @@ def allocate(self, tensor: Union[IRTensor, Any]) -> str: if tensor._id in self.tmap: # fetch the original one reg = self.tmap[tensor._id] - return f'{ttype}{reg}' else: # allocate a new one if len(self.slots) == 0: @@ -49,7 +48,8 @@ def allocate(self, tensor: Union[IRTensor, Any]) -> str: else: reg = self.slots.pop(-1) self.tmap[tensor._id] = reg - return f'{ttype}{reg}' + # reg = tensor._id # => enable this for debug + return f'{ttype}{reg}' else: return str(tensor) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 4e286401..6335e9cc 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -274,6 +274,6 @@ def __repr__(self): dscp = f'Execution Plan ({self.graph.name}):\n' for devid in self.devices(): dscp += f'====> Device {devid}:\n' - for node in self._seq(devid): + for node in self._seq[devid]: dscp += f'{node}\n' return dscp From 37574bbe6ec4da04d13ec0df497068c0072fdad4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Jun 2022 16:55:08 +0800 Subject: [PATCH 0886/1892] fix general generation bug --- cube/graph/gener/gen.py | 41 ++++++++++--------------------- cube/ir/adapter/prim.py | 3 +++ cube/ir/tensor.py | 12 ++++----- cube/runtime/adapter/transform.py | 6 ++--- 4 files changed, 25 insertions(+), 37 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index aec27c59..f59ae9f1 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -129,9 +129,9 @@ def gen_activation(graph: IRGraph) -> IRGraph: fadapter = None # Case 1: sharing device (in-shard) - if set(pdevs) == set(cdevs) and len(pdevs) > 1 and \ - len(set(pdevs)) == len(ptensors) and len(set(cdevs)) == len(ctensors): - fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) + # if set(pdevs) == set(cdevs) and len(pdevs) > 1 and \ + # len(set(pdevs)) == len(ptensors) and len(set(cdevs)) == len(ctensors): + # fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) # Case 2: sperating device (cross-shard) if len(set(pdevs).intersection(cdevs)) == 0: @@ -158,28 +158,6 @@ def gen_activation(graph: IRGraph) -> IRGraph: graph._nodes.insert(bidx, badapter) return graph - @staticmethod - def gen_fulltensor(ftensor: IRFullTensor, allow_reorder=False) -> Optional[IRAdapter]: - """ - Generate forward / backward adapter for fulltensor - """ - ptensors, ctensors = ftensor.ptensors, ftensor.ctensors - pdevs = tuple(ptensor.device[0] for ptensor in ptensors) - cdevs = tuple(ctensor.device[0] for ctensor in ctensors) - - # Case 1: sharing device (in-shard) - if set(pdevs) == set(cdevs) and len(pdevs) > 1 and \ - len(set(pdevs)) == len(ptensors) and len(set(cdevs)) == len(ctensors): - return IRAdapterGener.gen_in_shard(ftensor, allow_reorder) - - # Case 2: sperating device (cross-shard) - if len(set(pdevs).intersection(cdevs)) == 0: - pass - - # Case 3: General cases - # warnings.warn('The adapter is generated using inefficient P2P send/recv') - return IRAdapterGener.gen_general(ftensor) - @staticmethod def gen_in_shard(ftensor: IRFullTensor, allow_reorder=False) -> Optional[IRAdapter]: """ @@ -259,11 +237,13 @@ def gen_general(ftensor: IRFullTensor) -> IRAdapter: for ctensor in ftensor.ctensors: fprims += IRAdapterGener.gen_subtensor(ctensor, ftensor.ptensors) fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) + fadapter.prims = fprims if ftensor.grad is not None: bprims = [] for cgrad in ftensor.grad.ctensors: bprims += IRAdapterGener.gen_subtensor(cgrad, ftensor.grad.ptensors) badapter = IRAdapter(ftensor.grad.ptensors, ftensor.grad.ctensors) + badapter.prims = bprims IRAdapter.make_pair(fadapter, badapter) return fadapter @@ -331,9 +311,9 @@ def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRA while out != ctensor: out, merged = None, False for idx1 in range(len(remain_tensors) - 1): - for idx2 in range(idx1, len(remain_tensors)): + for idx2 in range(idx1+1, len(remain_tensors)): t1, t2 = remain_tensors[idx1], remain_tensors[idx2] - catdim = t1.catdims(t2) + catdim = t1.catdim(t2) if catdim is not None: tensors = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] out = tensors[0].concat(tensors[1], dim=catdim) @@ -355,10 +335,15 @@ def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRA break if out is None: ptensors = '\n\t'.join(t.extra_repr() for t in ptensors) + remain = '\n\t'.join(t.extra_repr() for t in remain_tensors) + print(remain_tensors[0].extra_repr()) + print(remain_tensors[1].extra_repr()) + print('cadim:', remain_tensors[0].catdim(remain_tensors[1])) raise RuntimeError( f"Fail to build adapter.\n" f"FullTensor:{ctensor.parent}\n" f"Producers:\n\t{ptensors}\n" - f"SubTensor:\n\t{ctensor.extra_repr()}" + f"SubTensor:\n\t{ctensor.extra_repr()}\n" + f"Remain Tensor:\n\t{remain}" ) return prims diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index 223b2040..f85861a2 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -110,6 +110,9 @@ def __init__(self, itensor: IRSubTensor, indmap: IndexMap, valmap: ValueMap, otensor: IRSubTensor): + indmap = IndexMap(indmap).indices + indmap = tuple(slice(s, e) for s, e in indmap) + valmap = ValueMap(valmap).weight[1] super().__init__([itensor], [otensor], indmap=indmap, valmap=valmap) self.signature = f"cube.runtime.adapter.select" diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 1026c49e..d83f0065 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -525,7 +525,7 @@ def splitdims(self) -> Tuple[int]: dim for dim in range(self.ndims) if self.shape[dim] != self.parent.shape[dim] ) - def catdims(self, other: IRTensor) -> Optional[int]: + def catdim(self, other: IRTensor) -> Optional[int]: """! Get concatable dimensions with other IRSubTensor @@ -539,12 +539,12 @@ def catdims(self, other: IRTensor) -> Optional[int]: for dim in range(self.ndims): if self.indmap[dim] != other.indmap[dim]: s1, e1 = self.indmap[dim] - s2, e2 = self.indmap[dim] + s2, e2 = other.indmap[dim] if min(e1, e2) == max(s1, s2): - if cat_dim is not None: - return None - else: + if cat_dim is None: cat_dim = dim + else: + return None else: return None return cat_dim @@ -559,7 +559,7 @@ def concat(self, other: IRTensor, dim: int) -> IRTensor: @return tensor IRSubTensor: the concatenated tensor """ assert isinstance(other, IRSubTensor), "expected IRSubTensor" - assert self.parent == other.valmap and self.valmap == other.valmap + assert self.parent == other.parent and self.valmap == other.valmap indmap = [] for cdim in range(self.ndims): if cdim == dim: diff --git a/cube/runtime/adapter/transform.py b/cube/runtime/adapter/transform.py index 506bd14a..016d3d4a 100644 --- a/cube/runtime/adapter/transform.py +++ b/cube/runtime/adapter/transform.py @@ -19,15 +19,15 @@ def identity(tensor: torch.Tensor): def select(tensor: torch.Tensor, - indmap: Tuple[slice], valmap: Tuple[int, int]) -> torch.Tensor: + indmap: Tuple[slice], valmap: int) -> torch.Tensor: """ Select a part of tensor spatially and numerically. """ require_grad = tensor.requires_grad with torch.no_grad(): sub_tensor = tensor[indmap] - if valmap != (0, 1): - sub_tensor = sub_tensor / valmap[1] + if valmap != 1: + sub_tensor = sub_tensor / valmap sub_tensor = sub_tensor.detach() if require_grad: sub_tensor = sub_tensor.requires_grad_() From 69ce5227a6742c78b91cf94a700ec0db65791e70 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Jun 2022 18:50:14 +0800 Subject: [PATCH 0887/1892] grid adapter communication --- cube/graph/gener/gen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index f59ae9f1..0618dd70 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -129,9 +129,9 @@ def gen_activation(graph: IRGraph) -> IRGraph: fadapter = None # Case 1: sharing device (in-shard) - # if set(pdevs) == set(cdevs) and len(pdevs) > 1 and \ - # len(set(pdevs)) == len(ptensors) and len(set(cdevs)) == len(ctensors): - # fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) + if set(pdevs) == set(cdevs) and len(pdevs) > 1 and \ + len(set(pdevs)) == len(ptensors) and len(set(cdevs)) == len(ctensors): + fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) # Case 2: sperating device (cross-shard) if len(set(pdevs).intersection(cdevs)) == 0: From 08cc930510f98682df1f08eb443feae5789f7c5b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Jun 2022 19:42:11 +0800 Subject: [PATCH 0888/1892] autograd adapter code gen --- cube/codegen/codegen.py | 96 +++++++++++++++++++++++++++----- cube/execplan/planpass/fusion.py | 29 ++++++---- cube/graph/gener/gen.py | 10 +++- cube/ir/adapter/adapter.py | 4 ++ 4 files changed, 110 insertions(+), 29 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 6ace2048..955dcecb 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -12,7 +12,7 @@ from cube.ir.tensor import IRSubTensor from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from cube.ir.adapter import IRWeightReducer, IRAdapter -from cube.ir.adapter.prim import CollectivePrim +from cube.ir.adapter.prim import CollectivePrim, IRAdapterPrim from cube.graph.graph import IRGraph, IRSegment from cube.graph.schedule import IRScheduleStrategy @@ -54,6 +54,77 @@ def tensor_naming(self, tensor: Any) -> str: name = str(tensor) return name + def tuple_naming(self, tensors: List[Any]) -> str: + tensors = [self.tensor_naming(t) for t in tensors] + tensors = '(' + ', '.join(tensors + ['']) + ')' + return tensors + + def return_naming(self, tensors: List[Any]) -> str: + tensors = [self.tensor_naming(t) for t in tensors] + if len(tensors) == 0: + tensors = '_' + else: + tensors = ', '.join(tensors) + return tensors + + +class AutogradAdapterCodeGen(CodeGen): + """ + Generate autograd adapter code (PyTorch) + """ + def __init__(self): + + self.fw_ins: List[IRSubTensor] = list() + self.fw_body: List[str] = list() + self.fw_ous: List[IRSubTensor] = list() + + self.bw_ins: List[IRSubTensor] = list() + self.bw_body: List[str] = list() + self.bw_ous: List[IRSubTensor] = list() + + def emit_prim(self, prim: IRAdapterPrim) -> str: + if len(prim.inputs()) == 1: + itensors = self.tensor_naming(prim.inputs()[0]) + else: + itensors = self.tuple_naming(prim.inputs()) + kwargs = list() + for name, val in prim.kwargs.items(): + kwargs.append(f'{name}={val}') + kwargs = ', '.join(kwargs) + outputs = self.return_naming(prim.outputs()) + code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' + return code + + def gen(self, fadapter: IRAdapter) -> List[str]: + assert fadapter.forward and fadapter.differentiable and fadapter.custom, "generate autograd for a non-differentiable adapter" + assert fadapter.mirror is not None + name = AutogradAdapterCodeGen.name(fadapter) + with ClassBlock(class_name=name, derived=['torch.autograd.Function']) as cb: + # forward + cb.insert_body('@staticmethod') + finputs = [self.tensor_naming(t) for t in fadapter.inputs()] + with FunctionBlock(func_name='forward', args=['ctx']+finputs) as fw: + for prim in fadapter.prims: + fw.insert_body(self.emit_prim(prim)) + outputs = self.return_naming(fadapter.outputs()) + fw.insert_body(f'return {outputs}') + cb.insert_body(fw.code) + # backward + cb.insert_body('@staticmethod') + badapter: IRAdapter = fadapter.mirror + binputs = [self.tensor_naming(t) for t in badapter.inputs()] + with FunctionBlock(func_name='backward', args=['ctx']+binputs) as bw: + for prim in badapter.prims: + bw.insert_body(self.emit_prim(prim)) + outputs = self.return_naming(badapter.outputs()) + bw.insert_body(f'return {outputs}') + cb.insert_body(bw.code) + return cb.code + + @staticmethod + def name(adapter: IRAdapter) -> str: + return f'Adapter{adapter.cid}' + class ModelCodeGen(CodeGen): """ @@ -123,6 +194,13 @@ def gen(self, device: int, outfile=None, attach=False) -> str: node_args: List[List[str]] = list() gen_nodes: List[IRCell] = list() + # init customized adapter + for seg in [seg for seg in self.execplan.seq(device) if isinstance(seg, IRSegment)]: + for adapter in [n for n in seg.nodes() if isinstance(n, IRAdapter)]: + if adapter.forward and adapter.differentiable and adapter.custom: + gencode += AutogradAdapterCodeGen().gen(adapter) + ['', ''] + adapter.signature = AutogradAdapterCodeGen.name(adapter) + '.apply' + # initialize communication groups self.init_comm_groups() @@ -256,7 +334,8 @@ def emit_adapter_call(self, node: IRAdapter): Emit adapter call """ assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" - for prim in node.prims: + prims = [node] if node.differentiable and node.custom else [prim for prim in node.prims] + for prim in prims: if len(prim.inputs()) == 1: itensors = self.tensor_naming(prim.inputs()[0]) else: @@ -292,19 +371,6 @@ def emit_reducer_call(self, node: IRWeightReducer): call_code = f'{reducer_name}.allreduce()' self.forward_region.append(call_code) - def return_naming(self, tensors: List[Any]) -> str: - tensors = [self.tensor_naming(t) for t in tensors] - if len(tensors) == 0: - tensors = '_' - else: - tensors = ', '.join(tensors) - return tensors - - def tuple_naming(self, tensors: List[Any]) -> str: - tensors = [self.tensor_naming(t) for t in tensors] - tensors = '(' + ', '.join(tensors + ['']) + ')' - return tensors - def tensor_naming(self, tensor: Any): """ Generate tensor name. diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 83a07e2d..ac599c71 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -25,20 +25,23 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: cnt = 0 for devid in execplan.devices(): for node in execplan.seq(devid): - if isinstance(node, IRAdapter) and node.forward and not node.differentiable: - ret = DiffFusion.nnfuse(node) - cnt = cnt+1 if ret else cnt + if isinstance(node, IRAdapter): + print(node, node.cid) + if node.forward: + ret = DiffFusion.nnfuse(node) + cnt = cnt+1 if ret else cnt if isinstance(node, IRSegment) and node.forward: for fnode in node.nodes(): - if isinstance(fnode, IRAdapter) and node.forward and not fnode.differentiable: - ret = DiffFusion.nnfuse(fnode) - if not ret: - raise NotImplementedError( - f"adapter within IRSegment cannot fuse to differientiable adapter" - f"\nforward: {fnode.extra_repr()}" - f"\nbackward: {fnode.mirror.extra_repr()}" - ) - cnt = cnt + 1 + if isinstance(fnode, IRAdapter): + if node.forward: + ret = DiffFusion.nnfuse(fnode) + if not ret: + raise NotImplementedError( + f"adapter within IRSegment cannot fuse to differientiable adapter" + f"\nforward: {fnode.extra_repr()}" + f"\nbackward: {fnode.mirror.extra_repr()}" + ) + cnt = cnt + 1 print(f'successfully generate {cnt} differentiable adapters') return execplan @@ -103,7 +106,9 @@ def is_alltoall(prims: List[IRAdapterPrim]) -> bool: if prims is not None: fadapter.prims = prims badapter.prims = prims + fadapter.custom = False fadapter.differentiable = True + badapter.custom = False badapter.differentiable = True return True return False diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 0618dd70..c1bc8672 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -129,8 +129,9 @@ def gen_activation(graph: IRGraph) -> IRGraph: fadapter = None # Case 1: sharing device (in-shard) - if set(pdevs) == set(cdevs) and len(pdevs) > 1 and \ - len(set(pdevs)) == len(ptensors) and len(set(cdevs)) == len(ctensors): + inshard = set(pdevs) == set(cdevs) and len(ptensors) == len(ctensors) and \ + len(pdevs) == len(ptensors) + if inshard and len(pdevs) > 1: fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) # Case 2: sperating device (cross-shard) @@ -148,6 +149,11 @@ def gen_activation(graph: IRGraph) -> IRGraph: (badapter is None and len(fadapter.prims) == 0): continue + # set differentiable for autograd generation + if inshard and badapter is not None: + fadapter.differentiable = True + badapter.differentiable = True + # insert forward adapter fidx = min([graph.nodes().index(consumer) for consumer in ftensor.consumers]) graph._nodes.insert(fidx, fadapter) diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index ff40018a..8455a242 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -15,6 +15,7 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): output_length=len(outputs), init_outputs=False ) + self.kwargs = dict() # we don't use input and output setter as this will # change tensor device info self._inputs = inputs @@ -22,6 +23,7 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): self._prims: List[IRAdapterPrim] = [] self._differentiable = False + self.custom = True device = set() for tensor in inputs + outputs: @@ -117,6 +119,8 @@ def dispatch(self, devid: int, for_mirror=True): fadapter.prims = prims fadapter.name = self.name fadapter._id = self._id + fadapter.differentiable = self.differentiable + fadapter.custom = self.custom # dispatch for mirror if for_mirror and isinstance(self.mirror, IRAdapter): badapter = self.mirror.dispatch(devid, for_mirror=False) From de4367a40224c2594d089ae4fbb411dd351641b3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Jun 2022 20:06:06 +0800 Subject: [PATCH 0889/1892] switch to general generation if fail for in shard gen; fix conv partition --- cube/algorithm/ops/conv.py | 75 ++++++++++++++++++++++++-------------- cube/graph/gener/gen.py | 19 +++++----- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 03729aec..d67db580 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -1,12 +1,32 @@ -from typing import Dict -from cube.algorithm.utils import split_axis, split_axis_custom, split_value +from typing import List, Tuple + +from cube.ir.tensor import IRSubTensor from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.conv import IRConv2D from cube.graph.function.conv import IRConv3D +def _split_axis_custom(tensor: IRSubTensor, dim: int, chunks: List[Tuple[int, int]]): + """ + Split tensor along an axis with customized selection + """ + dim = len(tensor.shape) + dim if dim < 0 else dim + assert dim < len(tensor.shape), f"dim should within ndims ({dim} >= {tensor.ndims})" + chunk_num = len(chunks) + indmap = list() + for nele in tensor.shape: + indmap.append((0, nele)) + sub_tensors = list() + for cid in range(chunk_num): + indmap[dim] = chunks[cid] + sub_tensors.append(tensor.select( + indmap=tuple(indmap), valmap=(0,1) + )) + return sub_tensors + + class DimSplitConv2D(GenericDistAlgo): """ split Conv2D at dimension level @@ -49,28 +69,28 @@ def instantiate(self, idx: int, dim: int, num: int): outputs = list() # split N if (idx, dim) == (0, 0): - inputs = split_axis(node.inputs(0), axis=0, chunk_num=num) + inputs = node.inputs(0).split_dim(dim, num) weights = [node.inputs(1)] * num bias = [node.inputs(2)] * num - outputs = split_axis(node.outputs(0), axis=0, chunk_num=num) + outputs = node.outputs(0).split_dim(dim, num) # split oC if (idx, dim) == (1, 0): inputs = [node.inputs(0)] * num - weights = split_axis(node.inputs(1), axis=0, chunk_num=num) + weights = node.inputs(1).split_dim(dim, num) if node.inputs(2) is None: bias = [None] * num else: - bias = split_axis(node.inputs(2), axis=0, chunk_num=num) - outputs = split_axis(node.outputs(0), axis=1, chunk_num=num) + bias = node.inputs(2).split_dim(dim, num) + outputs = node.outputs(0).split_dim(dim=1, num=num) # split iC if (idx, dim) == (0, 1) or (idx, dim) == (1, 1): - inputs = split_axis(node.inputs(0), axis=1, chunk_num=num) - weights = split_axis(node.inputs(1), axis=1, chunk_num=num) + inputs = node.inputs(0).split_dim(dim, num) + weights = node.inputs(1).split_dim(dim, num) if node.inputs(2) is None: bias = [None] * num else: - bias = split_value(node.inputs(2), chunk_num=num) - outputs = split_value(node.outputs(0), chunk_num=num) + bias = node.inputs(2).split_val(num) + outputs = node.outputs(0).split_val(num) subnodes = list() for i, w, b, o in zip(inputs, weights, bias, outputs): subnodes.append(node.new([i, w, b], [o])) @@ -123,7 +143,7 @@ def instantiate(self, idx: int, dim: int, num: int): # split H if (idx, dim) == (0, 2): # input and padding - slicers = list() + indmap = list() pads = list() start = 0 - padding[0] for cid in range(num): @@ -134,21 +154,21 @@ def instantiate(self, idx: int, dim: int, num: int): # input -- FIXME: only work for stride=[1,1] chunkH = oH // num + dilation[0] * (dH - 1) stop = start + chunkH - padr - slicers.append(slice(max(0, start), min(H, stop))) + indmap.append((max(0, start), min(H, stop))) start = stop - dilation[0] * (dH - 1) # start = 0 if cid == 0 else 1023 # stop = 1025 if cid == 0 else H - inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) + inputs = _split_axis_custom(node.inputs(0), dim=dim, chunks=tuple(indmap)) # weight weights = [node.inputs(1)] * num # bias bias = [node.inputs(2)] * num # outputs - outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) + outputs = node.outputs(0).split_dim(dim, num) # split W if (idx, dim) == (0, 3): # input and padding - slicers = list() + indmap = list() pads = list() start = 0 - padding[2] for cid in range(num): @@ -159,15 +179,15 @@ def instantiate(self, idx: int, dim: int, num: int): # input -- FIXME: only work for stride=[1,1] chunkH = oW // num + dilation[0] * (dH - 1) stop = start + chunkH - padb - slicers.append(slice(max(0, start), min(H, stop))) + indmap.append((max(0, start), min(H, stop))) start = stop - dilation[0] * (dH - 1) - inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) + inputs = _split_axis_custom(node.inputs(0), dim=dim, chunks=tuple(indmap)) # weight weights = [node.inputs(1)] * num # bias bias = [node.inputs(2)] * num # outputs - outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) + outputs = node.outputs(0).split_dim(dim, num) sub_nodes = list() for i, w, b, pad, o in zip(inputs, weights, bias, pads, outputs): conv = IRConv2D(node.signature, [i, w, b], node.name, @@ -177,7 +197,6 @@ def instantiate(self, idx: int, dim: int, num: int): return sub_nodes - class HaloSplitConv3D(GenericDistAlgo): """ Halo-exchange split @@ -225,7 +244,7 @@ def instantiate(self, idx: int, dim: int, num: int): # split H if (idx, dim) == (0, 2): # input and padding - slicers = list() + indmap = list() pads = list() start = 0 - padding[0] for cid in range(num): @@ -236,21 +255,21 @@ def instantiate(self, idx: int, dim: int, num: int): # input -- FIXME: only work for stride=[1,1] chunkH = oH // num + dilation[0] * (dH - 1) stop = start + chunkH - padr - slicers.append(slice(max(0, start), min(H, stop))) + indmap.append((max(0, start), min(H, stop))) start = stop - dilation[0] * (dH - 1) # start = 0 if cid == 0 else 1023 # stop = 1025 if cid == 0 else H - inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) + inputs = _split_axis_custom(node.inputs(0), dim=dim, chunks=indmap) # weight weights = [node.inputs(1)] * num # bias bias = [node.inputs(2)] * num # outputs - outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) + outputs = node.outputs(0).split_dim(dim, num) # split W if (idx, dim) == (0, 3): # input and padding - slicers = list() + indmap = list() pads = list() start = 0 - padding[2] for cid in range(num): @@ -261,15 +280,15 @@ def instantiate(self, idx: int, dim: int, num: int): # input -- FIXME: only work for stride=[1,1] chunkH = oW // num + dilation[0] * (dH - 1) stop = start + chunkH - padb - slicers.append(slice(max(0, start), min(H, stop))) + indmap.append((max(0, start), min(H, stop))) start = stop - dilation[0] * (dH - 1) - inputs = split_axis_custom(node.inputs(0), axis=dim, chunks=slicers) + inputs = _split_axis_custom(node.inputs(0), dim=dim, chunks=indmap) # weight weights = [node.inputs(1)] * num # bias bias = [node.inputs(2)] * num # outputs - outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) + outputs = node.outputs(0).split_dim(dim, num) sub_nodes = list() for i, w, b, pad, o in zip(inputs, weights, bias, pads, outputs): conv = IRConv3D(node.signature, [i, w, b], node.name, diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index c1bc8672..8078cfb7 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -114,15 +114,6 @@ def gen_activation(graph: IRGraph) -> IRGraph: ftensor.consumers[0].device == ftensor.producers[0].device: continue - # print(f'==> analyzing full tensor: {ftensor}') - # print('producer:') - # for ptensor in ftensor.ptensors: - # print(ptensor, 'device:', ptensor.device) - # print('consumer') - # for ctensor in ftensor.ctensors: - # print(ctensor, 'device:', ctensor.device) - # print('') - ptensors, ctensors = ftensor.ptensors, ftensor.ctensors pdevs = tuple(ptensor.device[0] for ptensor in ptensors) cdevs = tuple(ctensor.device[0] for ctensor in ctensors) @@ -132,7 +123,15 @@ def gen_activation(graph: IRGraph) -> IRGraph: inshard = set(pdevs) == set(cdevs) and len(ptensors) == len(ctensors) and \ len(pdevs) == len(ptensors) if inshard and len(pdevs) > 1: - fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) + try: + fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) + except Exception as e: + fadapter = None + print( + f"full tensor: {ftensor} cannot use grid generation.\n" + f"Reason: {str(e)}\n" + f"Switch to general P2P communication." + ) # Case 2: sperating device (cross-shard) if len(set(pdevs).intersection(cdevs)) == 0: From 311113352890bb5cf0cfe0101e80e1e73a2105b3 Mon Sep 17 00:00:00 2001 From: lynex Date: Fri, 24 Jun 2022 13:13:33 +0800 Subject: [PATCH 0890/1892] update wrf1 example --- examples/wrf/policy/hw_halo.py | 22 ++++++++++++++++++++++ examples/wrf/wrf.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 examples/wrf/policy/hw_halo.py diff --git a/examples/wrf/policy/hw_halo.py b/examples/wrf/policy/hw_halo.py new file mode 100644 index 00000000..48aedc92 --- /dev/null +++ b/examples/wrf/policy/hw_halo.py @@ -0,0 +1,22 @@ +from cube.graph import IRGraph +from cube.graph.function import IRConv2D, IRConv3D + +def PAS(graph: IRGraph, resource): + for node in graph.nodes(): +# graph.assign(node, 0) + if isinstance(node, IRConv3D): + sub_nodes = list() + algo = node.algorithms('halo') + Wnodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus // 2) + for Wnode in Wnodes: + algo = Wnode.algorithms('halo') + Hnodes = graph.partition(Wnode, algo, idx=0, dim=2, num=2) + sub_nodes += Hnodes + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + # sub_nodes = graph.replicate(node, times=resource.ngpus) + + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + print(graph.extra_repr()) + return graph diff --git a/examples/wrf/wrf.py b/examples/wrf/wrf.py index 3f2d4857..47acdad9 100644 --- a/examples/wrf/wrf.py +++ b/examples/wrf/wrf.py @@ -16,7 +16,7 @@ print("torch einops 1") import cube -from examples.poisson.policy.naive import PAS +from examples.wrf.policy.hw_halo import PAS device = 'cuda' # From e435c4b6883ca8a9bf193210757df583200dad3a Mon Sep 17 00:00:00 2001 From: Zijian Ding Date: Sun, 26 Jun 2022 18:05:44 +0800 Subject: [PATCH 0891/1892] fix bugs and allow non-uniform partition in HaloSplitConv3D --- cube/algorithm/ops/conv.py | 38 +++++++++++++++++++++++--------------- cube/ir/tensor.py | 7 +++++-- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index d67db580..2bec034c 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -207,7 +207,7 @@ class HaloSplitConv3D(GenericDistAlgo): def __init__(self, node: IRConv3D): if not isinstance(node, IRConv3D): - raise TypeError(f"Expect IRConv2D") + raise TypeError(f"Expect IRConv3D") super().__init__(node) def satisfy(self, idx: int, dim: int, num: int): @@ -225,10 +225,12 @@ def satisfy(self, idx: int, dim: int, num: int): raise NotImplementedError("Splitting on dilation != [1,1] is not supported") # split H if (idx, dim) == (0, 2): - return oH % num == 0 + return oH >= num + # return oH % num == 0 # split W if (idx, dim) == (0, 3): - return oW % num == 0 + return oW >= num + # return oW % num == 0 def instantiate(self, idx: int, dim: int, num: int): if not self.satisfy(idx, dim, num): @@ -247,48 +249,54 @@ def instantiate(self, idx: int, dim: int, num: int): indmap = list() pads = list() start = 0 - padding[0] + addone_num = oH % num for cid in range(num): - # padding padl = padding[1] if cid == 0 else 0 padr = padding[1] if cid == num - 1 else 0 - pads.append([padding[0], padding[0], padl, padr, padding[2], padding[2]]) + # padding -- FIXME: padding here is not correct, only work for pad=[0,..,0] + pads.append([padding[0], padl, padr, padding[2], padding[2]]) # input -- FIXME: only work for stride=[1,1] chunkH = oH // num + dilation[0] * (dH - 1) - stop = start + chunkH - padr + addone = int(cid < addone_num) + stop = start + chunkH - padr + addone + # stop = start + chunkH - padr indmap.append((max(0, start), min(H, stop))) start = stop - dilation[0] * (dH - 1) # start = 0 if cid == 0 else 1023 # stop = 1025 if cid == 0 else H - inputs = _split_axis_custom(node.inputs(0), dim=dim, chunks=indmap) + inputs = _split_axis_custom(node.inputs(0), dim=dim+1, chunks=indmap) # weight weights = [node.inputs(1)] * num # bias bias = [node.inputs(2)] * num # outputs - outputs = node.outputs(0).split_dim(dim, num) + outputs = node.outputs(0).split_dim(dim+1, num) # split W if (idx, dim) == (0, 3): # input and padding indmap = list() pads = list() start = 0 - padding[2] + addone_num = oW % num for cid in range(num): # padding padt = padding[2] if cid == 0 else 0 padb = padding[2] if cid == num - 1 else 0 - pads.append([padding[0], padding[0], padding[1], padding[1], padt, padb]) + # padding -- FIXME: padding here is not correct, only work for pad=[0,..,0] + pads.append([padding[0], padding[1], padding[1], padt, padb]) # input -- FIXME: only work for stride=[1,1] - chunkH = oW // num + dilation[0] * (dH - 1) - stop = start + chunkH - padb - indmap.append((max(0, start), min(H, stop))) - start = stop - dilation[0] * (dH - 1) - inputs = _split_axis_custom(node.inputs(0), dim=dim, chunks=indmap) + chunkH = oW // num + dilation[0] * (dW - 1) + addone = int(cid < addone_num) + stop = start + chunkH - padb + addone + indmap.append((max(0, start), min(W, stop))) + start = stop - dilation[0] * (dW - 1) + inputs = _split_axis_custom(node.inputs(0), dim=dim+1, chunks=indmap) # weight weights = [node.inputs(1)] * num # bias bias = [node.inputs(2)] * num # outputs - outputs = node.outputs(0).split_dim(dim, num) + outputs = node.outputs(0).split_dim(dim+1, num) sub_nodes = list() for i, w, b, pad, o in zip(inputs, weights, bias, pads, outputs): conv = IRConv3D(node.signature, [i, w, b], node.name, diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index d83f0065..02e72232 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -727,8 +727,9 @@ def split_dim(self, dim: int, num: int) -> List[IRTensor]: """ dim = dim + self.ndims if dim < 0 else dim assert dim < self.ndims, f"Dim should within ndims but {dim} >= {self.ndims})" - assert self.shape[dim] % num == 0, f"Expected dimension can be split: {self.shape[dim]} % {num} != 0" + # assert self.shape[dim] % num == 0, f"Expected dimension can be split: {self.shape[dim]} % {num} != 0" chunk_size = self.shape[dim] // num + addone_num = self.shape[dim] % num indmap = [] for tdim, nele in enumerate(self.shape): @@ -738,7 +739,9 @@ def split_dim(self, dim: int, num: int) -> List[IRTensor]: indmap.append(None) sub_tensors = list() for cid in range(num): - indmap[dim] = (chunk_size * cid, chunk_size * (cid + 1)) + num_prev_addone = addone_num if cid >= addone_num else cid + addone = int(cid < addone_num) + indmap[dim] = (chunk_size * cid + num_prev_addone, chunk_size * (cid+1) + addone + num_prev_addone) sub_tensor = self.select( indmap=tuple(indmap), valmap=(0,1), From 5b43edc4fc35cf86c910a26fd2ce670d30413fdb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 26 Jun 2022 22:44:27 +0800 Subject: [PATCH 0892/1892] enbale residual; full megatron tensor parallelism --- examples/gsearch/blocks.py | 24 ++++++++++++++++++------ examples/gsearch/gpt/policy/spmd.py | 17 ++++++++++++++++- examples/gsearch/gpt/train.py | 2 +- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/examples/gsearch/blocks.py b/examples/gsearch/blocks.py index 96a44b48..e112be60 100644 --- a/examples/gsearch/blocks.py +++ b/examples/gsearch/blocks.py @@ -1,6 +1,5 @@ import torch import cube -import warnings @cube.graph.parser.register('L N E+, (h d) E+, (h d), (h d) E+, (h d), (h d) E+, (h d) -> N h L d, N h L d, N h L d', name='attn_qkv') @@ -76,6 +75,17 @@ def attn_dense_out(context: torch.Tensor, weight: torch.Tensor, bias: torch.Tens return torch.nn.functional.linear(context, weight, bias) +@cube.graph.parser.register('L N E+, inner E+, inner -> L N inner', name='mlp_linear1') +def mlp_linear1(x: torch.Tensor, proj: torch.Tensor, bias: torch.Tensor): + return torch.nn.functional.linear(x, proj, bias) + + +@cube.graph.parser.register('L N inner+, E inner+, E -> L N E', name='mlp_linear2') +def mlp_linear2(x: torch.Tensor, proj: torch.Tensor, bias: torch.Tensor): + return torch.nn.functional.linear(x, proj, bias) + + + class MultiHeadSelfAttention(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): @@ -133,9 +143,11 @@ def __init__(self, embed_dim: int, hidden_dim: int, dropout: float): self.dropout = dropout def forward(self, x: torch.Tensor): - x = torch.nn.functional.linear(x, self.proj1, self.proj1_bias) + # L N E, inner E -> L N inner + x = mlp_linear1(x, self.proj1, self.proj1_bias) + # L N inner -> L N inner x = torch.nn.functional.gelu(x) - x = torch.nn.functional.linear(x, self.proj2, self.proj2_bias) + x = mlp_linear2(x, self.proj2, self.proj2_bias) x = torch.nn.functional.dropout(x, self.dropout, True, False) return x @@ -153,17 +165,17 @@ def __init__(self, embed_dim: int, num_heads: int, self.dropout = torch.nn.Dropout(p=dropout) self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - warnings.warn('residual is disabled in encoder block') + # warnings.warn('residual is disabled in encoder block') def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = self.self_attn_layer_norm(x) x = self.self_attn(x) x = self.dropout(x) - # x = x + residual + x = x + residual residual = x x = self.final_layer_norm(x) x = self.mlp(x) - # x = x + residual + x = x + residual return x diff --git a/examples/gsearch/gpt/policy/spmd.py b/examples/gsearch/gpt/policy/spmd.py index 3a6b4f62..d0c73200 100644 --- a/examples/gsearch/gpt/policy/spmd.py +++ b/examples/gsearch/gpt/policy/spmd.py @@ -14,7 +14,7 @@ def PASReplica(graph: IRGraph, resource): return graph -def PASMegatron(graph: IRGraph, resource): +def PASMegatronTP(graph: IRGraph, resource): """ Megatron tensor parallelism (attention) """ @@ -32,6 +32,7 @@ def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): graph.assign(sub_node, idx) return sub_nodes + # ============ Attention =============== qkvs = [node for node in fnodes if node.name == 'attn_qkv'] for idx, qkv in enumerate(qkvs): tensor_parallelism(qkv, f'====> start of transformer {idx}', idx=1, dim=0, num=tp_size) @@ -56,6 +57,20 @@ def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): for dense in dense_outs: tensor_parallelism(dense, idx=0, dim=2, num=tp_size) + # ============= MLP =================== + linear1s = [node for node in fnodes if node.name == 'mlp_linear1'] + for mlp_linear1 in linear1s: + tensor_parallelism(mlp_linear1, idx=1, dim=0, num=tp_size) + + gelus = [node for node in fnodes if node.name == 'gelu'] + for gelu in gelus: + tensor_parallelism(gelu, idx=0, dim=2, num=tp_size) + + linear2s = [node for node in fnodes if node.name == 'mlp_linear2'] + for mlp_linear2 in linear2s: + tensor_parallelism(mlp_linear2, idx=0, dim=2, num=tp_size) + + # replicate others for node in graph.nodes(): if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: rnodes = graph.replicate(node, times=tp_size) diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py index 07d0ac0a..7bab77e3 100644 --- a/examples/gsearch/gpt/train.py +++ b/examples/gsearch/gpt/train.py @@ -12,7 +12,7 @@ from examples.gsearch.gpt.model import GPT from examples.gsearch.gpt.model import GPTDataLoader -from examples.gsearch.gpt.policy.spmd import PASMegatron as PAS +from examples.gsearch.gpt.policy.spmd import PASMegatronTP as PAS import cube from cube.profiler.timer import CudaTimer, print_each_rank From 0d259e98814b6f9895e360b60ec33a19edea3926 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Jun 2022 09:41:42 +0800 Subject: [PATCH 0893/1892] update mpmd example --- examples/mlp/policy/mpmd.py | 91 +++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 examples/mlp/policy/mpmd.py diff --git a/examples/mlp/policy/mpmd.py b/examples/mlp/policy/mpmd.py new file mode 100644 index 00000000..ed754407 --- /dev/null +++ b/examples/mlp/policy/mpmd.py @@ -0,0 +1,91 @@ +import random +from typing import Tuple +import numpy as np + +from cube.graph.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.schedule.sched1f1b import IRSchedule1F1B + + +def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + + +def PASRandom(graph, resource): + """ + Random pipeline + """ + assert len(graph.nodes()) // 2 >= resource.ngpus, "not enough operator number." + remain_device = set(range(resource.ngpus)) + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if len(remain_device) != 0: + idx = random.randint(0, len(remain_device) - 1) + device = list(remain_device)[idx] + remain_device.remove(device) + else: + device = random.randint(0, resource.ngpus - 1) + graph.assign(node, device) + elif isinstance(node, IRDataOperation): + device = random.randint(0, resource.ngpus - 1) + graph.assign(node, device) + print(graph.extra_repr()) + return graph + + +def PAS1F1B(graph: IRGraph, resource): + + # assert resource.ngpus == 8, "should apply on 8 gpus" + num_stage = 4 + num_tp = resource.ngpus // num_stage + num_microbatch = resource.ngpus + + _, tp_mesh = _create_mesh(resource.ngpus, (num_stage, num_tp)) + print(f'> pipeline-tensor parallel group: {tp_mesh}') + assert len(tp_mesh) == num_stage + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + node2stage = lambda node: min(fnodes.index(node) // (len(fnodes) // num_stage), num_stage-1) + + for idx, node in enumerate(fnodes): + # get tensor parallel group + sid = node2stage(node) + tp_group = tp_mesh[sid] + # partition + if node.name == 'linear': + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=num_tp) + else: + tp_nodes = graph.replicate(node, times=num_tp) + # assign + for devid, node in zip(tp_group, tp_nodes): + graph.assign(node, devid) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + mesh = tp_mesh[0] + rnodes = graph.replicate(node, times=num_tp) + for devid, rnode in zip(mesh, rnodes): + graph.assign(rnode, devid) + # setup schedule to 1F1B + schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) + graph.schedule_plan = schedule + return graph \ No newline at end of file From 7c0676888e33307faccf654677a9d6f0871104ba Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Jun 2022 09:42:38 +0800 Subject: [PATCH 0894/1892] remove print and add assertation --- cube/execplan/planpass/fusion.py | 1 - cube/graph/graph.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index ac599c71..9a09b055 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -26,7 +26,6 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: for devid in execplan.devices(): for node in execplan.seq(devid): if isinstance(node, IRAdapter): - print(node, node.cid) if node.forward: ret = DiffFusion.nnfuse(node) cnt = cnt+1 if ret else cnt diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 671ded7b..5d50d918 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -657,6 +657,7 @@ def assign(self, node: Union[IRFwOperation, IRBpOperation], @return sucess bool: always true """ + assert isinstance(node, (IRFwOperation, IRDataOperation)), "Only forward and data operation can be assigned to device." assert node in self._nodes, f"{node} is not in the graph" ranks = (ranks,) if isinstance(ranks, int) else ranks assert all([isinstance(rank, int) for rank in ranks]), "Expected rank to be int" From a8685947c53af88ca159a8e9963910cab780b4f2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Jun 2022 09:43:37 +0800 Subject: [PATCH 0895/1892] remove unnecessary select in adapter --- cube/graph/gener/gen.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 8078cfb7..7c6d502e 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -279,19 +279,20 @@ def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRA common.cell = itensor.cell intersections.append(common) # create select primitive - indmap = [] - for dim in range(itensor.ndims): - (s1, e1), (s2, e2) = itensor.indmap[dim], common.indmap[dim] - start = s2 - s1 - end = start + e2 - s2 - indmap.append((start, end)) - indmap = IndexMap(tuple(indmap)) - assert itensor.valmap == common.valmap, "Value map not same" - valmap = ValueMap((0, 1)) - select_prim = SelectPrim(itensor, indmap, valmap, common) + if common != itensor: + indmap = [] + for dim in range(itensor.ndims): + (s1, e1), (s2, e2) = itensor.indmap[dim], common.indmap[dim] + start = s2 - s1 + end = start + e2 - s2 + indmap.append((start, end)) + indmap = IndexMap(tuple(indmap)) + assert itensor.valmap == common.valmap, "Value map not same" + valmap = ValueMap((0, 1)) + select_prim = SelectPrim(itensor, indmap, valmap, common) + prims.append(select_prim) if itensor.device == ctensor.device and common == ctensor: return [select_prim] - prims.append(select_prim) # TODO: check union == subtensor if common == ctensor: break From be0e242f95ac3b20c656778f9043a63e9299bc62 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Jun 2022 09:45:14 +0800 Subject: [PATCH 0896/1892] clean old policy --- examples/mlp/policy/megatron_pptp.py | 68 ---------------------------- examples/mlp/policy/pipe_parallel.py | 26 ----------- 2 files changed, 94 deletions(-) delete mode 100644 examples/mlp/policy/megatron_pptp.py delete mode 100644 examples/mlp/policy/pipe_parallel.py diff --git a/examples/mlp/policy/megatron_pptp.py b/examples/mlp/policy/megatron_pptp.py deleted file mode 100644 index 87efe7ef..00000000 --- a/examples/mlp/policy/megatron_pptp.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Tuple -import numpy as np - -from cube.graph.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.schedule.sched1f1b import IRSchedule1F1B - - -def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: - """ - Create hybrid (nested) groups given the each group number. - - The product of group_num should be same with total devices. - """ - group_num = np.array(group_num) - cnt = np.prod(group_num) - assert cnt == ngpus, 'total device not match' - grid = np.arange(cnt).reshape(tuple(group_num)) - dims = list(range(len(group_num))) - outputs = [] - for dim, num in enumerate(group_num): - remain = ngpus // num - order = tuple(dims[:dim] + dims[dim+1:] + [dim]) - grid_dim = np.transpose(grid, order).reshape((remain,num)) - grid_dim = grid_dim.tolist() - outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) - assert len(outputs) == len(group_num) - return tuple(outputs) - - -def PAS(graph: IRGraph, resource): - - # assert resource.ngpus == 8, "should apply on 8 gpus" - num_stage = 4 - num_tp = resource.ngpus // num_stage - num_microbatch = resource.ngpus - - _, tp_mesh = create_mesh(resource.ngpus, (num_stage, num_tp)) - print(f'> pipeline-tensor parallel group: {tp_mesh}') - assert len(tp_mesh) == num_stage - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - node2stage = lambda node: min(fnodes.index(node) // (len(fnodes) // num_stage), num_stage-1) - - for idx, node in enumerate(fnodes): - # get tensor parallel group - sid = node2stage(node) - tp_group = tp_mesh[sid] - # partition - if node.name == 'linear': - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=num_tp) - else: - tp_nodes = graph.replicate(node, times=num_tp) - # assign - for devid, node in zip(tp_group, tp_nodes): - graph.assign(node, devid) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - mesh = tp_mesh[0] - rnodes = graph.replicate(node, times=num_tp) - for devid, rnode in zip(mesh, rnodes): - graph.assign(rnode, devid) - # setup schedule to 1F1B - schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) - graph.schedule_plan = schedule - return graph diff --git a/examples/mlp/policy/pipe_parallel.py b/examples/mlp/policy/pipe_parallel.py deleted file mode 100644 index e50ca93f..00000000 --- a/examples/mlp/policy/pipe_parallel.py +++ /dev/null @@ -1,26 +0,0 @@ -import math -import random - -from cube.ir.operator import IRDataOperation, IRFwOperation - - -def PAS(graph, resource): - """ - Random pipeline - """ - micro_batch_num = resource.ngpus - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - device = random.randint(0, resource.ngpus - 1) - graph.assign(node, device) - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=0, num=micro_batch_num)) - if sub_nodes is None: - sub_nodes = [node] - for idx, sub_node in enumerate(sub_nodes): - device = random.randint(0, resource.ngpus - 1) - graph.assign(sub_node, device) - print(graph.extra_repr()) - return graph From 43d83ff0bb15434ead04697fbf859722f3b5afbe Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Jun 2022 10:34:59 +0800 Subject: [PATCH 0897/1892] add mpmd 1f1b policy --- examples/gsearch/gpt/policy/mpmd.py | 102 ++++++++++++++++++++++++++++ examples/gsearch/gpt/policy/spmd.py | 9 ++- examples/gsearch/gpt/train.py | 9 +-- 3 files changed, 114 insertions(+), 6 deletions(-) create mode 100644 examples/gsearch/gpt/policy/mpmd.py diff --git a/examples/gsearch/gpt/policy/mpmd.py b/examples/gsearch/gpt/policy/mpmd.py new file mode 100644 index 00000000..0b4b08a6 --- /dev/null +++ b/examples/gsearch/gpt/policy/mpmd.py @@ -0,0 +1,102 @@ +from typing import List, Tuple +import numpy as np + +from cube.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.schedule.sched1f1b import IRSchedule1F1B + + +def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + + +def PASRoundRobin(graph: IRGraph, resource): + """ + 1F1B scheduling + """ + + def org_transformer_layer(graph: IRGraph) -> List[List[IRFwOperation]]: + multiref_idx = [ + fidx for fidx, node in enumerate(graph.nodes()) if \ + isinstance(node, IRFwOperation) and node.name == 'multiref' + ] + assert len(multiref_idx) % 2 == 0, "un-recognized transormer structure" + transformers = [] + last_fidx = [fidx for fidx, node in enumerate(graph.nodes()) if isinstance(node, IRFwOperation)][-1] + for idx in range(0, len(multiref_idx), 2): + graph.nodes()[multiref_idx[idx]].comment = f'===> start of transformer {idx // 2}' + start = multiref_idx[idx] if idx != 0 else 0 + end = multiref_idx[idx+2] if idx+2 < len(multiref_idx) else last_fidx+1 + transformers.append(graph.nodes()[start:end]) + return transformers + + transformers = org_transformer_layer(graph) + for lid, fnodes in enumerate(transformers): + stage_id = lid % resource.ngpus + print(f'assigning {lid}-th transformer layter to stage {stage_id}') + for fnode in fnodes: + graph.assign(fnode, stage_id) + + for node in graph.nodes(): + if len(node.device) == 0: + graph.assign(node, 0) + + return graph + + +def PAS1F1B(graph: IRGraph, resource): + """ + 1F1B scheduling + """ + num_stage = resource.ngpus + num_microbatch = resource.ngpus + + _, stage_mesh = _create_mesh(resource.ngpus, (num_stage, 1)) + + def org_transformer_layer(graph: IRGraph) -> List[List[IRFwOperation]]: + multiref_idx = [ + fidx for fidx, node in enumerate(graph.nodes()) if \ + isinstance(node, IRFwOperation) and node.name == 'multiref' + ] + assert len(multiref_idx) % 2 == 0, "un-recognized transormer structure" + transformers = [] + last_fidx = [fidx for fidx, node in enumerate(graph.nodes()) if isinstance(node, IRFwOperation)][-1] + for idx in range(0, len(multiref_idx), 2): + graph.nodes()[multiref_idx[idx]].comment = f'===> start of transformer {idx // 2}' + start = multiref_idx[idx] if idx != 0 else 0 + end = multiref_idx[idx+2] if idx+2 < len(multiref_idx) else last_fidx+1 + transformers.append(graph.nodes()[start:end]) + return transformers + + transformers = org_transformer_layer(graph) + for lid, fnodes in enumerate(transformers): + stage_id = min(lid // (len(transformers) // resource.ngpus), num_stage-1) + print(f'assigning {lid}-th transformer layter to stage {stage_id}') + for fnode in fnodes: + graph.assign(fnode, stage_id) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + graph.assign(node, 0) + + schedule = IRSchedule1F1B(num_microbatch, stage_mesh, recompute=False) + graph.schedule_plan = schedule + return graph \ No newline at end of file diff --git a/examples/gsearch/gpt/policy/spmd.py b/examples/gsearch/gpt/policy/spmd.py index d0c73200..8212abc8 100644 --- a/examples/gsearch/gpt/policy/spmd.py +++ b/examples/gsearch/gpt/policy/spmd.py @@ -32,10 +32,15 @@ def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): graph.assign(sub_node, idx) return sub_nodes + # annotating code structure + multirefs = [node for node in graph.nodes() if isinstance(node, IRFwOperation) and node.name == 'multiref'] + for idx in range(0, len(multirefs), 2): + multirefs[idx].comment = f'====> start of transformer {idx // 2}' + # ============ Attention =============== qkvs = [node for node in fnodes if node.name == 'attn_qkv'] for idx, qkv in enumerate(qkvs): - tensor_parallelism(qkv, f'====> start of transformer {idx}', idx=1, dim=0, num=tp_size) + tensor_parallelism(qkv, idx=1, dim=0, num=tp_size) scores = [node for node in fnodes if node.name == 'attn_score'] for score in scores: @@ -76,7 +81,7 @@ def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): rnodes = graph.replicate(node, times=tp_size) for idx, rnode in enumerate(rnodes): graph.assign(rnode, idx) - print(graph.extra_repr()) + # print(graph.extra_repr()) return graph diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py index 7bab77e3..78a0228b 100644 --- a/examples/gsearch/gpt/train.py +++ b/examples/gsearch/gpt/train.py @@ -10,14 +10,15 @@ import torch -from examples.gsearch.gpt.model import GPT -from examples.gsearch.gpt.model import GPTDataLoader -from examples.gsearch.gpt.policy.spmd import PASMegatronTP as PAS - import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary, model_summary +from examples.gsearch.gpt.model import GPT +from examples.gsearch.gpt.model import GPTDataLoader +from examples.gsearch.gpt.policy.spmd import PASMegatronTP as PAS +# from examples.gsearch.gpt.policy.mpmd import PAS1F1B as PAS + def train(): From 84e6ab73711797e3e63a92a728da815aa48215ed Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Jun 2022 10:35:55 +0800 Subject: [PATCH 0898/1892] add comment inherit --- cube/graph/graph.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5d50d918..2c97d5f9 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -527,7 +527,8 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], Partition Primitive: - partition: partition a forward or data operation using algorithms. - The backward of the forward operation will automaticall be partitioned. + The comment in the node will be inherited to partitioned nodes. + The backward of the forward operation will be automatically partitioned. Requirement to partition algorithm: if backward is required, the algorithm can only transform tensors in: @@ -561,12 +562,16 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], findex = self.detach(node) for idx, fnode in enumerate(fnodes): self.attach(fnode, findex + idx) + if isinstance(node.comment, str): + fnode.comment = node.comment # update backward if isinstance(node.mirror, IRBpOperation): bindex = self.detach(node.mirror) bnodes = [fnode.gen_backward() for fnode in fnodes][::-1] for idx, bnode in enumerate(bnodes): self.attach(bnode, bindex + idx) + if isinstance(node.mirror.comment, str): + bnode.comment = node.mirror.comment # update gradient updated = set() for itensor in [t for t in node.inputs() if isinstance(t, IRSubTensor)]: @@ -657,7 +662,7 @@ def assign(self, node: Union[IRFwOperation, IRBpOperation], @return sucess bool: always true """ - assert isinstance(node, (IRFwOperation, IRDataOperation)), "Only forward and data operation can be assigned to device." + assert isinstance(node, (IRFwOperation, IRDataOperation)), f"Only forward and data operation can be assigned to device, but got {node}" assert node in self._nodes, f"{node} is not in the graph" ranks = (ranks,) if isinstance(ranks, int) else ranks assert all([isinstance(rank, int) for rank in ranks]), "Expected rank to be int" From a4de4e145f2119d2f0a07b9b33dd20619974c8d7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Jun 2022 15:27:41 +0800 Subject: [PATCH 0899/1892] add test script for all the cases --- examples/gsearch/gpt/policy/spmd.py | 3 +- examples/gsearch/gpt/train.py | 32 ++++++++++--- examples/mlp/linears.py | 32 ++++++++++--- examples/mlp/policy/spmd.py | 3 +- tests/test_examples.sh | 74 +++++++++++++++++++++++++++++ 5 files changed, 129 insertions(+), 15 deletions(-) create mode 100755 tests/test_examples.sh diff --git a/examples/gsearch/gpt/policy/spmd.py b/examples/gsearch/gpt/policy/spmd.py index 8212abc8..964958b1 100644 --- a/examples/gsearch/gpt/policy/spmd.py +++ b/examples/gsearch/gpt/policy/spmd.py @@ -9,7 +9,8 @@ def PASReplica(graph: IRGraph, resource): assert resource.ngpus == 1 print(graph.extra_repr()) for node in graph.nodes(): - graph.assign(node, 0) + if isinstance(node, (IRDataOperation, IRFwOperation)): + graph.assign(node, 0) # print(graph.extra_repr()) return graph diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py index 78a0228b..d2fa68d3 100644 --- a/examples/gsearch/gpt/train.py +++ b/examples/gsearch/gpt/train.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/gsearch/gpt/train.py + examples/gsearch/gpt/train.py --policy PASMegatronTP """ @@ -16,8 +16,29 @@ from examples.gsearch.gpt.model import GPT from examples.gsearch.gpt.model import GPTDataLoader -from examples.gsearch.gpt.policy.spmd import PASMegatronTP as PAS -# from examples.gsearch.gpt.policy.mpmd import PAS1F1B as PAS + +import examples.gsearch.gpt.policy.spmd as spmd +import examples.gsearch.gpt.policy.mpmd as mpmd + +import argparse +parser = argparse.ArgumentParser(description='comm primitive') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +args = parser.parse_args() + +cube.init() + +# set up policy +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + def train(): @@ -69,7 +90,4 @@ def train_iter(model, dataloader): memory_summary() -if __name__ == '__main__': - - cube.init() - train() \ No newline at end of file +train() \ No newline at end of file diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index b544bfed..b7084296 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py + examples/mlp/linears.py --policy PASMegatron """ import torch @@ -13,7 +13,30 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from examples.mlp.policy.spmd import PASMegatron as PAS + +import examples.mlp.policy.spmd as spmd +import examples.mlp.policy.mpmd as mpmd + +import argparse + +parser = argparse.ArgumentParser(description='comm primitive') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +args = parser.parse_args() + +cube.init() + +# set up policy +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + # =================== Semantic Model Description ==================== @@ -87,7 +110,4 @@ def train_iter(model, dataloader): CudaTimer().print_all(times=iter_num-warmup) -if __name__ == '__main__': - - cube.init() - train() \ No newline at end of file +train() \ No newline at end of file diff --git a/examples/mlp/policy/spmd.py b/examples/mlp/policy/spmd.py index 42df3f59..db541ad2 100644 --- a/examples/mlp/policy/spmd.py +++ b/examples/mlp/policy/spmd.py @@ -8,7 +8,8 @@ def PASSingle(graph: IRGraph, resource): """ assert resource.ngpus == 1, "only apply for single gpu case" for node in graph.nodes(): - graph.assign(node, 0) + if isinstance(node, (IRDataOperation, IRFwOperation)): + graph.assign(node, 0) return graph diff --git a/tests/test_examples.sh b/tests/test_examples.sh new file mode 100755 index 00000000..02e456d7 --- /dev/null +++ b/tests/test_examples.sh @@ -0,0 +1,74 @@ + + +# test MLP + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/mlp/linears.py --policy PASSingle + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py --policy PASData + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py --policy PASCol + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py --policy PASRow + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py --policy PASHybrid + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py --policy PASMegatron + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py --policy PASOptimal + + +# test GSearch + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/gsearch/gpt/train.py --policy PASReplica + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/gsearch/gpt/train.py --policy PASMegatronTP + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/gsearch/gpt/train.py --policy PASRoundRobin + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/gsearch/gpt/train.py --policy PAS1F1B + + +# test scientific model + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/poisson/sci.py + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/wrf/wrf2.py From 76c00453528e03c5e823edc1a242409e36489bda Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Jun 2022 15:28:13 +0800 Subject: [PATCH 0900/1892] fix dataloader split data bug --- cube/algorithm/ops/dataloader.py | 3 +-- cube/algorithm/ops/pad.py | 33 ++++++++++++++++++++++++++------ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py index 9eec536b..da10cf68 100644 --- a/cube/algorithm/ops/dataloader.py +++ b/cube/algorithm/ops/dataloader.py @@ -1,7 +1,6 @@ from typing import List import copy -from cube.algorithm.utils import split_axis from cube.algorithm.generics import GenericDistAlgo from cube.ir.operator import IRDataOperation @@ -41,7 +40,7 @@ def instantiate(self, num: int): outputs = list() for dim, output in zip(dims, node.outputs()): - output = split_axis(output, dim, num) + output = output.split_dim(dim, num) outputs.append(output) nodes = list() diff --git a/cube/algorithm/ops/pad.py b/cube/algorithm/ops/pad.py index f325d7c5..61b74966 100644 --- a/cube/algorithm/ops/pad.py +++ b/cube/algorithm/ops/pad.py @@ -1,7 +1,28 @@ -from cube.algorithm.utils import split_axis, split_axis_custom +from typing import List, Tuple from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.pad import IRPad +from cube.ir.tensor import IRSubTensor + + +def _split_axis_custom(tensor: IRSubTensor, dim: int, chunks: List[Tuple[int, int]]): + """ + Split tensor along an axis with customized selection + """ + dim = len(tensor.shape) + dim if dim < 0 else dim + assert dim < len(tensor.shape), f"dim should within ndims ({dim} >= {tensor.ndims})" + chunk_num = len(chunks) + indmap = list() + for nele in tensor.shape: + indmap.append((0, nele)) + sub_tensors = list() + for cid in range(chunk_num): + indmap[dim] = chunks[cid] + sub_tensors.append(tensor.select( + indmap=tuple(indmap), valmap=(0,1) + )) + return sub_tensors + class DimSplitPad(GenericDistAlgo): """ @@ -49,12 +70,12 @@ def instantiate(self, dim: int, num: int): # split non-pad dim if dim < len(node.inputs(0).shape) - pad_dim_count: - inputs = split_axis(node.inputs(0), axis=dim, chunk_num=num) - outputs = split_axis(node.outputs(0), axis=dim, chunk_num=num) + inputs = node.inputs(0).split_dim(dim, num) + outputs = node.outputs(0).split_dim(dim, num) for i, o in zip(inputs, outputs): subnodes.append(node.new([i], [o])) else: # split pad dim - inputs = split_axis(node.inputs(0), axis=dim, chunk_num=num) + inputs = node.inputs(0).split_dim(dim, num) slicers = list() pads = list() dim_in_pad = len(node.inputs(0).shape) - 1 - dim @@ -72,10 +93,10 @@ def instantiate(self, dim: int, num: int): pads.append(cur_pad) stop = start + padl + padr + chunk_size - slicers.append(slice(max(0, start), min(node.outputs(0).shape[dim], stop))) + slicers.append((max(0, start), min(node.outputs(0).shape[dim], stop))) start = stop - outputs = split_axis_custom(node.outputs(0), axis=dim, chunks=slicers) + outputs = _split_axis_custom(node.outputs(0), dim, tuple(slicers)) for i, o, p in zip(inputs, outputs, pads): subnodes.append(node.new([i], [o], pad=p)) From 21210dcb459008b7e1f3c75e85456763d2695f87 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Mon, 27 Jun 2022 08:45:27 +0000 Subject: [PATCH 0901/1892] Merged PR 1390: Ensure when entering CubeIR each no-input ops have dtype PyTorch frontend allows ops like `zeros` to have no `dtype` specified, where the global `dtype` will be set by default. But in CubeIR we must have dtype so that send/recv adapters work correctly. Add the process in lowering part. --- cube/graph/function/creators.py | 15 ++++++++++++++- cube/graph/function/function.py | 12 ++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index 1b098154..0691429e 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -6,13 +6,18 @@ from cube.ir.cten import IRTensor class IRZeros(IRFwOperation): - def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:Optional[IRDType]=None): + def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): # The shape information must be statically known integer values assert all(isinstance(dim, int) for dim in shape) + assert isinstance(ir_dtype, IRDType) super().__init__(name, signature, input_length=0, output_length=1) + # Customize output's dtype only after 'super().__init__' and 'self.set_input', + # otherwise it gets overwritten. + self.outputs(0).dtype = ir_dtype + # The positional argument to specify the shape is actually called 'size'. self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) @@ -35,8 +40,16 @@ def infer_shape(self) -> bool: # https://github.com/pytorch/pytorch/blob/483bb4f0cb273f42f655aa30eee6a1fbbaba69b0/torch/csrc/jit/runtime/register_prim_ops.cpp#L2215 class IRToTensor(IRFwOperation): def __init__(self, signature: str, inputs, name:str, ir_dtype:IRDType): + + assert isinstance(ir_dtype, IRDType) + super().__init__(name, signature, input_length=1, output_length=1) self.set_input(0, inputs[0]) + + # Customize output's dtype only after 'super().__init__' and 'self.set_input', + # otherwise it gets overwritten. + self.outputs(0).dtype = ir_dtype + self.kwargs.update({"dtype": ir_dtype}) def infer_shape(self) -> bool: diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 7e4d7995..90e40471 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -56,13 +56,14 @@ def Zeros(signature, size, dtype, layout, _erased_device, pin_memory = inputs # TODO parameters to support, currently they are all None - assert dtype is None assert layout is None assert pin_memory is None - ir_dtype : Optional[IRDType] = None + ir_dtype : IRDType if dtype is not None: ir_dtype = DType2IRDType.map(dtype) + else: + ir_dtype = DType2IRDType.map(torch.get_default_dtype()) for dim, i in enumerate(size): if not isinstance(dim, int) and not dim >= 0: @@ -80,12 +81,13 @@ def NewTensor(signature, data, dtype, _erased_device, requires_grad = inputs # TODO parameters to support, currently they are all None - assert dtype is None assert requires_grad == False - ir_dtype : Optional[IRDType] = None + ir_dtype : IRDType if dtype is not None: ir_dtype = DType2IRDType.map(dtype) + else: + ir_dtype = DType2IRDType.map(torch.get_default_dtype()) # if 'data' is not: # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 @@ -98,8 +100,6 @@ def NewTensor(signature, # but since we have omitted the 'data', we must do type inferrence ourselves, # only in this way we get correct dtype e.g. ints or bools. shape = list(arr.shape) - torch_inferred_dtype = arr.dtype - ir_dtype = DType2IRDType.map(torch_inferred_dtype) signature = 'torch.zeros' return IRZeros(signature, shape, 'tensor', ir_dtype=ir_dtype) From 1af428a84c01ff3e5b3b8d5e8a1ecd9c6ab749a8 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Mon, 27 Jun 2022 14:40:28 +0000 Subject: [PATCH 0902/1892] Merged PR 1391: Add faster execplan::grouping implementation Add faster execplan::grouping implementation for inference only. Add environment variable `SCINTIFIC_COMPUTING` to turn it on. Make execplan::grouping complexity nearly linear to (unrolled) graph size, that 1k nodes ~= 1s. --- README.md | 28 +++++++++++ cube/execplan/planpass/grouping.py | 79 ++++++++++++++++++++++++------ tests/test_execplan_grouping.py | 66 +++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 tests/test_execplan_grouping.py diff --git a/README.md b/README.md index 6e14a58c..7be5f602 100644 --- a/README.md +++ b/README.md @@ -51,3 +51,31 @@ OMP_NUM_THREADS=4 torchrun \ --nnodes=1 \ examples/mlp/linears.py ``` + +## Profile + +### Use cProfile + +Due to the multi-process architecture of `torch.distributed.launch`, instead of directly using +the command-line interface of cProfile, we need to exactly specify the scope to profile, like: + +```python +import cProfile +prof = cProfile.Profile() +prof.enable() + +# our code to profile goes here +@cube.compile(...) +def iter(dataloader): + x, y = next(dataloader) + z = model(x, y) + return z +for i in range(N): + iter(...) +# our code ends + +prof.disabled() +pr.dump_stats('cube_%d.prof' % torch.distributed.get_rank()) # or use TID/PID, if to profile multi-thread/-process program. +``` + +After the modification, run the Python file using the same command line with `torchrun` as usual. diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index d2c20802..d3cbca65 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -1,7 +1,8 @@ """ Operation grouping """ -from typing import List, Dict, Tuple +import os +from typing import List, Dict, Optional, Tuple from cube.execplan import ExecutionPlan from cube.execplan.planpass.planpass import PlanPass @@ -9,9 +10,32 @@ from cube.ir.operator import IRBpOperation, IRFwOperation from cube.ir.cten import IRCell +SCIENTIFIC_COMPUTING = 'SCIENTIFIC_COMPUTING' +_use_new_grouping_algo:Optional[bool] = None -class Grouping(PlanPass): +def _set_use_new_grouping_algo(use_new_grouping_algo:Optional[bool]) -> None: + """ + Set the internal flag whether to use a new grouping algorithm which is faster for grouping forward-only graphs, + especially for workloads from scientific-computing domains. + + Parameters: + - use_new_grouping_algo (bool): + 'True' to force the use of the new grouping algorithm. + 'False' to force the use of the old grouping algorithm. + 'None' to use the new grouping algorithm if the environment variable 'SCIENTIFIC_COMPUTING' exists. + """ + assert use_new_grouping_algo is None or isinstance(use_new_grouping_algo, bool) + global _use_new_grouping_algo + _use_new_grouping_algo = use_new_grouping_algo +def _get_use_new_grouping_algo() -> bool: + if _use_new_grouping_algo is None: + return SCIENTIFIC_COMPUTING in os.environ + else: + return _use_new_grouping_algo + + +class Grouping(PlanPass): @staticmethod def apply(execplan: ExecutionPlan) -> ExecutionPlan: """ @@ -53,10 +77,16 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: fpieces, bpieces = list(), list() seq = execplan.seq(devid) fnodes = [] - for fnode in seq: + + def is_forward_node(fnode): if isinstance(fnode, IRFwOperation): - fnodes.append(fnode) + return True if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.forward: + return True + return False + + for fnode in seq: + if is_forward_node(fnode): fnodes.append(fnode) have_backward = all(fnode.mirror in seq for fnode in fnodes) # training @@ -75,16 +105,37 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: fpieces, bpieces = [fnode], [bnode] # inference else: - for fnode in fnodes + [-1]: - fconsecutive = Grouping.consecutive(seq, fpieces, fnode) - if fconsecutive: - fpieces.append(fnode) - bpieces.append(None) - else: - if len(fpieces) != 0: - fgroups[devid].append(fpieces) - bgroups[devid].append(None) - fpieces, bpieces = [fnode], [None] + if _get_use_new_grouping_algo(): + + for fnode in seq: + if is_forward_node(fnode): + fpieces.append(fnode) + else: + if len(fpieces) != 0: + fgroups[devid].append(fpieces) + bgroups[devid].append(None) + + # If the fnode is not a "forward node", e.g. it's DataOp node, don't add it into the group. + fpieces = [] + # 'bpieces' is never filled or returned in the inference mode + + if len(fpieces) != 0: + fgroups[devid].append(fpieces) + bgroups[devid].append(None) + + else: # Not using new algo + + for fnode in fnodes + [-1]: + fconsecutive = Grouping.consecutive(seq, fpieces, fnode) + if fconsecutive: + fpieces.append(fnode) + bpieces.append(None) + else: + if len(fpieces) != 0: + fgroups[devid].append(fpieces) + bgroups[devid].append(None) + fpieces, bpieces = [fnode], [None] + return fgroups, bgroups @staticmethod diff --git a/tests/test_execplan_grouping.py b/tests/test_execplan_grouping.py new file mode 100644 index 00000000..9323a6aa --- /dev/null +++ b/tests/test_execplan_grouping.py @@ -0,0 +1,66 @@ +# run tests: +# pytest ./tests/test_execplan_grouping.py + +from typing import Dict, List + +import pytest +from cube.execplan.planpass import grouping +from cube.execplan.planpass.grouping import Grouping +from cube.ir.cten import IRCell +from cube.ir.operator import IRDataOperation, IRFwOperation + +# Stub object for 'cube.execplan.ExecPlan' +# Commonly the devices are like [0,1,2,...] +class StubExecPlan(): + def __init__(self, devices:List[int], seq:Dict[int, List[IRCell]]) -> None: + assert all(devid in seq for devid in devices) + self._devices = devices + self._seq = seq + + def devices(self): + return self._devices + def seq(self, devid:int): + return self._seq[devid] + +# With these settings, all tests here are run twice, with 'grouping._get_new...algo' returning True or False, respectively. +# And all the setting ups and the recovery of this flag happen in the background. +# +# By runninng tests in both environments, we can check the consistency of the old and new algorithms. +@pytest.fixture(params=[True, False], autouse=True) +def setup_and_cleanup(request:pytest.FixtureRequest) -> None: + flag = grouping._get_use_new_grouping_algo() + grouping._set_use_new_grouping_algo(request.param) + yield + grouping._set_use_new_grouping_algo(flag) + + +def test_grouping_forward_single_group(): + execplan = StubExecPlan([0], {0: [IRFwOperation(f"op{i}", f"sign{i}", i, i) for i in range(1, 10)] }) + # each type: Dict[DeviceIdInt, List[List[IRCell]] ] + fwgroups, bpgroups = Grouping.group(execplan) + + assert len(fwgroups) == 1 # one device + assert len(fwgroups[0]) == 1 # one group + assert all(fnode.name == f"op{i+1}" for i, fnode in enumerate(fwgroups[0][0])) + + assert len(bpgroups) == 1 + assert len(bpgroups[0]) == 1 + assert bpgroups[0][0] is None + + +def test_grouping_forward_interleaving_excluded_nodes(): + execplan = StubExecPlan([0], {0: [ + IRFwOperation(f"op{i}", f"sign{i}", i, i) if i % 2 == 0 + else IRDataOperation(i, (2,)*i) # IRDataOperation is the IRCell to exclude from the group + for i in range(1, 9) # [1,2,...,8] + ] }) + # each type: Dict[DeviceIdInt, List[List[IRCell]] ] + fwgroups, bpgroups = Grouping.group(execplan) + + assert len(fwgroups) == 1 + assert len(fwgroups[0]) == 4 + assert all(len(fwgroup) == 1 and fwgroup[0].name == f"op{i}" for fwgroup, i in zip(fwgroups[0], [2,4,6,8])) + + assert len(bpgroups) == 1 + assert len(bpgroups[0]) == 4 + assert all(bpgroup is None for bpgroup in bpgroups[0]) \ No newline at end of file From a78921736c5224ac1726121f1bbb06f45a91f41f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 28 Jun 2022 19:37:54 +0800 Subject: [PATCH 0903/1892] remove useless utils --- cube/algorithm/utils.py | 79 ----------------------------------------- 1 file changed, 79 deletions(-) delete mode 100644 cube/algorithm/utils.py diff --git a/cube/algorithm/utils.py b/cube/algorithm/utils.py deleted file mode 100644 index 7d43e84c..00000000 --- a/cube/algorithm/utils.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import List, Union -from cube.ir.tensor import IRSubTensor - - -def split_axis(tensor: IRSubTensor, axis: int, chunk_num: int): - """ - Split tensor along an axis. The axis can be positive or negative. - """ - if axis < 0: - axis = len(tensor.shape) + axis - if axis >= len(tensor.shape): - raise RuntimeError(f"Axis should within dims ({axis} >= {len(tensor.shape)})") - - chunk_size = int(tensor.shape[axis] // chunk_num) - - shape_slicer = list() - chunk_shape = list() - for dim, nele in enumerate(tensor.shape): - if dim != axis: - shape_slicer.append(slice(0, nele, 1)) - chunk_shape.append(nele) - else: - shape_slicer.append(None) - chunk_shape.append(chunk_size) - - sub_tensors = list() - for cid in range(chunk_num): - shape_slicer[axis] = slice(chunk_size * cid, chunk_size * (cid + 1), 1) - sub_tensors.append(tensor.select( - indmap = tuple(shape_slicer), - valmap = None, - shape = chunk_shape - )) - return sub_tensors - - -def split_axis_custom(tensor: IRSubTensor, axis: int, chunks: List[slice]): - """ - Split tensor along an axis with customized selection - """ - if axis < 0: - axis = len(tensor.shape) + axis - if axis >= len(tensor.shape): - raise RuntimeError(f"Axis should within dims ({axis} >= {len(tensor.shape)})") - chunk_num = len(chunks) - - slicers, shape = list(), list() - for nele in tensor.shape: - slicers.append(slice(0, nele, 1)) - shape.append(nele) - sub_tensors = list() - for cid in range(chunk_num): - slicers[axis] = chunks[cid] - shape[axis] = chunks[cid].stop - chunks[cid].start - sub_tensors.append(tensor.select( - indmap = tuple(slicers), - valmap = None, - shape = shape - )) - return sub_tensors - - -def split_value(tensor: IRSubTensor, chunk_num: int): - - # full shape - shape_slicer = list() - for nele in tensor.shape: - shape_slicer.append(slice(0, nele, 1)) - - sub_tensors = list() - for idx in range(chunk_num): - sub_tensor = tensor.select( - indmap = tuple(shape_slicer), - valmap = (idx, chunk_num), - shape = tensor.shape - ) - sub_tensors.append(sub_tensor) - - return sub_tensors From ea41963ca432915265eeba274347216b56e61538 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Jun 2022 10:03:33 +0800 Subject: [PATCH 0904/1892] hotfix: GPT data loader position input shape --- examples/gsearch/gpt/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/gsearch/gpt/model.py b/examples/gsearch/gpt/model.py index b8de4b12..817bf1ff 100644 --- a/examples/gsearch/gpt/model.py +++ b/examples/gsearch/gpt/model.py @@ -9,7 +9,7 @@ class Config: num_embeddings = 50304 - seqlen = 512 + seqlen = 1024 # 1.7B model embed_dim = 2304 @@ -93,7 +93,7 @@ def random_sample(self): ) position_ids = torch.arange( 0, self.cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() - ).repeat(self.bs) + ).repeat(self.bs).view((self.bs, -1)) return (input_ids, position_ids) def __iter__(self): From f3b7397f58a753a7ab0b985bfe333622301fc00b Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 30 Jun 2022 06:50:11 +0000 Subject: [PATCH 0905/1892] Merged PR 1392: Add 'torch.ones' and dispatch 'torch.tensor' to it Add 'torch.ones' and dispatch 'torch.tensor' to it --- cube/codegen/frontend_mapping.py | 18 ++++++++++++++++++ cube/graph/function/creators.py | 20 ++++++++++++++++++++ cube/graph/function/function.py | 26 +++++++++++++++++++++++--- cube/graph/parser/mapping.py | 1 + 4 files changed, 62 insertions(+), 3 deletions(-) diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index e14ecf19..8a6ae7d2 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -88,6 +88,23 @@ def emit_zeros(node, arg_vars:list, kw_pairs:dict) -> str: assert len(arg_vars) == 0 return _common_rule_join_all(node, arg_vars, kw_pairs) +def emit_ones(node, arg_vars:list, kw_pairs:dict) -> str: + """ + ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + """ + kw_pairs = kw_pairs.copy() + if 'dtype' in kw_pairs: + ir_dtype : IRDType = kw_pairs['dtype'] + if ir_dtype is not None: + kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) + + # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. + assert 'device' not in kw_pairs + kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. + + assert len(arg_vars) == 0 + return _common_rule_join_all(node, arg_vars, kw_pairs) + # Basically to convert internal 'IRDType' to frontend 'torch.dtype' def emit_to(node, arg_vars:list, kw_pairs:dict) -> str: kw_pairs = kw_pairs.copy() @@ -125,6 +142,7 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: 'torch.slice': emit_slice, 'torch.zeros': emit_zeros, + 'torch.ones': emit_ones, 'torch.Tensor.to': emit_to, } diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index 0691429e..c822a0bd 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -26,6 +26,26 @@ def infer_shape(self) -> bool: self.outputs(0).shape = shape return True +class IROnes(IRFwOperation): + def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): + + # The shape information must be statically known integer values + assert all(isinstance(dim, int) for dim in shape) + assert isinstance(ir_dtype, IRDType) + + super().__init__(name, signature, input_length=0, output_length=1) + + # Customize output's dtype only after 'super().__init__' and 'self.set_input', + # otherwise it gets overwritten. + self.outputs(0).dtype = ir_dtype + + # The positional argument to specify the shape is actually called 'size'. + self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) + + def infer_shape(self) -> bool: + shape : list = copy(self.kwargs["size"]) + self.outputs(0).shape = shape + return True #class IRNewTensor(IRFwOperation): # def __init__(self, signature: str, data, name:str): diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 90e40471..f107b91d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -13,7 +13,7 @@ from cube.graph.function.scripteinops import IRScriptEinOps from cube.graph.function.customops import IRCustomOps from cube.graph.function.cat import IRCat, IRStack -from cube.graph.function.creators import IRToTensor, IRZeros +from cube.graph.function.creators import IROnes, IRToTensor, IRZeros from cube.graph.function.select import IRSelect, IRSlice from cube.graph.function.scatter import IRSelectScatter from cube.graph.function.repeat import IRRepeat @@ -70,6 +70,26 @@ def Zeros(signature, raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") return IRZeros(signature, size, 'zeros', ir_dtype) +def Ones(signature, + inputs: Tuple[ List[int], Optional[Any], Optional[Any], 'ErasedDevice', Optional[bool] ]): + # ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + + size, dtype, layout, _erased_device, pin_memory = inputs + + # TODO parameters to support, currently they are all None + assert layout is None + assert pin_memory is None + + ir_dtype : IRDType + if dtype is not None: + ir_dtype = DType2IRDType.map(dtype) + else: + ir_dtype = DType2IRDType.map(torch.get_default_dtype()) + + for dim, i in enumerate(size): + if not isinstance(dim, int) and not dim >= 0: + raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") + return IROnes(signature, size, 'ones', ir_dtype) def NewTensor(signature, inputs: Tuple[ list, Optional[Any], 'ErasedDevice', bool ]): @@ -100,8 +120,8 @@ def NewTensor(signature, # but since we have omitted the 'data', we must do type inferrence ourselves, # only in this way we get correct dtype e.g. ints or bools. shape = list(arr.shape) - signature = 'torch.zeros' - return IRZeros(signature, shape, 'tensor', ir_dtype=ir_dtype) + signature = 'torch.ones' + return IROnes(signature, shape, 'tensor', ir_dtype=ir_dtype) def ToTensor(signature, inputs: Tuple[ IRTensor, ... ]): diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 736ac8e1..182e2f1b 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -77,6 +77,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # creators __ttemplate('zeros'): function.Zeros, + __ttemplate('ones'): function.Ones, __ttemplate('tensor'): function.NewTensor, __ttemplate('to'): function.ToTensor, From 59de61d7a7c87ab5719efdc90ac0b7716f541b1a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Jun 2022 16:38:54 +0800 Subject: [PATCH 0906/1892] example database call --- cube/profiler/database.py | 65 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 cube/profiler/database.py diff --git a/cube/profiler/database.py b/cube/profiler/database.py new file mode 100644 index 00000000..c0668f80 --- /dev/null +++ b/cube/profiler/database.py @@ -0,0 +1,65 @@ +from typing import Callable, Tuple +import torch +import time + + +class CompProfiler: + + @staticmethod + def profile(func: Callable, shapes: Tuple[Tuple[int]], dtypes=None, warmup_sec: float=2, prof_times: int = 50, **kwargs): + """ + Profile a function + + @param func Callable: the callable function, e.g., torch.nn.functional.linear + @param shapes Tuple[Tuple[int]]: the shapes of each input tensor + @param dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 + + @return span float: the time in milliseconds for forward + backward time + """ + + # create data + dtypes = [torch.float32] * len(shapes) if dtypes is None else dtypes + tensors = tuple( + torch.rand(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=True) \ + for shape, dtype in zip(shapes, dtypes) + ) + outputs = func(*tensors, **kwargs) + outputs = (outputs,) if torch.is_tensor(outputs) else outputs + assert all(torch.is_tensor(otensor) for otensor in outputs), f"{func.__name__}: require all the outputs to be tensors" + grads = tuple(torch.zeros_like(otensor) for otensor in outputs) + + # warmup + tic = time.time() + while time.time() - tic < warmup_sec: + # forward + outputs = func(*tensors, **kwargs) + outputs = (outputs,) if torch.is_tensor(outputs) else outputs + # backward + torch.autograd.backward(outputs, grads) + + # profile forward + torch.cuda.synchronize() + tic = time.perf_counter() + for _ in range(prof_times): + # forward + outputs = func(*tensors, **kwargs) + outputs = (outputs,) if torch.is_tensor(outputs) else outputs + # backward + torch.autograd.backward(outputs, grads) + torch.cuda.synchronize() + toc = time.perf_counter() + span = (toc - tic) / prof_times * 1000 # in milliseconds + return span + + +if __name__ == '__main__': + + func = torch.nn.functional.linear + + shapes = ([2, 1024, 2304], [2, 2304]) + span = CompProfiler.profile(torch.nn.functional.linear, shapes) + print(f'span of {func.__name__}: shapes: {shapes}: {span} ms') + + shapes = ([8, 1024, 2304], [8, 2304]) + span = CompProfiler.profile(torch.nn.functional.linear, shapes) + print(f'span of {func.__name__}: shapes: {shapes}: {span} ms') From 49f9a6e87d50152411bf065d51c98086163ac4b7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Jun 2022 19:27:08 +0800 Subject: [PATCH 0907/1892] profile and save to database --- cube/profiler/database.py | 243 +++++++++++++++++++++++++++++++++++--- 1 file changed, 227 insertions(+), 16 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index c0668f80..695d8094 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -1,22 +1,42 @@ -from typing import Callable, Tuple +""" +Usage: + python -m cube.profiler.database --export ./profile.dat.json +""" + +from typing import Callable, Tuple, Union, Optional, Dict, NewType import torch import time +import os +import json + + +Shapes = NewType('Shapes', Tuple[Tuple[int]]) +DTypes = NewType('DTypes', Tuple[torch.dtype]) +ShapesDTypes = NewType('ShapesDTypes', Tuple[Shapes, DTypes]) +NameOrFunc = NewType('NameOrFunc', Union[str, Callable]) class CompProfiler: @staticmethod - def profile(func: Callable, shapes: Tuple[Tuple[int]], dtypes=None, warmup_sec: float=2, prof_times: int = 50, **kwargs): + def profile(func: Callable, shapes: Shapes, dtypes: DTypes, + warmup_sec: float = 2, prof_times: int = 50, backward = True, + **kwargs): """ Profile a function @param func Callable: the callable function, e.g., torch.nn.functional.linear @param shapes Tuple[Tuple[int]]: the shapes of each input tensor @param dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 + @param warmup_sec float: warmup seconds + @param prof_times int: profile times + @param backward bool: whether profile backward times. Default true. + @param kwargs Dict: other keyword argument for func call. - @return span float: the time in milliseconds for forward + backward time + @return span float: the time in milliseconds for forward (+backward) time """ - + assert len(shapes) == len(dtypes), \ + f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" # create data dtypes = [torch.float32] * len(shapes) if dtypes is None else dtypes tensors = tuple( @@ -25,7 +45,8 @@ def profile(func: Callable, shapes: Tuple[Tuple[int]], dtypes=None, warmup_sec: ) outputs = func(*tensors, **kwargs) outputs = (outputs,) if torch.is_tensor(outputs) else outputs - assert all(torch.is_tensor(otensor) for otensor in outputs), f"{func.__name__}: require all the outputs to be tensors" + assert all(torch.is_tensor(otensor) for otensor in outputs), \ + f"{func.__name__}: require all the outputs to be tensors" grads = tuple(torch.zeros_like(otensor) for otensor in outputs) # warmup @@ -33,9 +54,9 @@ def profile(func: Callable, shapes: Tuple[Tuple[int]], dtypes=None, warmup_sec: while time.time() - tic < warmup_sec: # forward outputs = func(*tensors, **kwargs) - outputs = (outputs,) if torch.is_tensor(outputs) else outputs # backward - torch.autograd.backward(outputs, grads) + if backward: + torch.autograd.backward(outputs, grads) # profile forward torch.cuda.synchronize() @@ -43,23 +64,213 @@ def profile(func: Callable, shapes: Tuple[Tuple[int]], dtypes=None, warmup_sec: for _ in range(prof_times): # forward outputs = func(*tensors, **kwargs) - outputs = (outputs,) if torch.is_tensor(outputs) else outputs # backward - torch.autograd.backward(outputs, grads) + if backward: + torch.autograd.backward(outputs, grads) torch.cuda.synchronize() toc = time.perf_counter() span = (toc - tic) / prof_times * 1000 # in milliseconds return span +class ProfileDataBase: + + def __init__(self, filename: Optional[str] = None) -> None: + """! + Create a database for profiling result + """ + + self._data: Dict[str, Dict[str, float]] = dict() + if filename is not None: + self.load(filename) + + def profile(self, func: Callable, shapes: Shapes, dtypes: DTypes, + backward=True, **kwargs): + """! + Profile the function and log into the database + + @param func Callable: the callable function, e.g., torch.nn.functional.linear + @param shapes Tuple[Tuple[int]]: the shapes of each input tensor + @param dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 + @param backward bool: whether profile backward times. Default true. + @param kwargs Dict: other keyword argument for func call. + """ + try: + assert callable(func), "func should be callable" + span = CompProfiler.profile(func, shapes, dtypes, backward=backward, **kwargs) + except Exception as e: + print(f'fail to profile {func.__name__}: reason: {str(e)}') + name = func.__name__ + key = self.serialize(shapes, dtypes) + self.log(name, key, span) + print(f'profiled {func.__name__} | shapes: {shapes} | dtypes: {dtypes} => span: {round(span, 2)} ms') + + def log(self, name: str, key: str, span: float): + """ + log the span of a function name with key + """ + assert isinstance(name, str) and isinstance(span, float) and isinstance(key, str) + if name not in self._data: + self._data[name] = dict() + self._data[name][key] = span + + def query(self, func: NameOrFunc, shapes: Shapes, dtypes: DTypes) -> float: + """! + Get the performance number of the function name and its key + + @param name str: function name + @param shapes Tuple[Tuple[int]]: the shape of each input tensor + @param dtypes Tuple[torch.dtype]: the dtype of each tensor + + @return span float: the performance number + """ + name = func if isinstance(func, str) else func.__name__ + key = self.serialize(shapes, dtypes) + return self._data[name][key] + + def exist_item(self, func: NameOrFunc, shapes: Shapes, dtypes: DTypes) -> bool: + """! + Check if the required data exists + + @param name Union[str, Callable]: function name + @param shapes Tuple[Tuple[int]]: the shape of each input tensor + @param dtypes Tuple[torch.dtype]: the dtype of each tensor + + @return exist bool: True if the item exists else False + """ + name = func if isinstance(func, str) else func.__name__ + if name not in self._data: + return False + key = self.serialize(self, shapes, dtypes) + if key not in self._data[key]: + return False + return True + + def exist_func(self, func: NameOrFunc) -> bool: + """! + Check if the required function exists + + @param name Union[str, Callable]: function name + + @return exist bool: True if the function exists else False + """ + name = func if isinstance(func, str) else func.__name__ + return name in self._data + + def shapes_and_dtypes(self, func: NameOrFunc) -> Tuple[ShapesDTypes]: + """ + Get recorded shapes and dtypes of the func. + + @param func UnShapesDTypesion[str, Callable]: function name + + @return shapes_and_dtypes Tuple[ShapesDTyptes] + """ + name = func if isinstance(func, str) else func.__name__ + rets = [] + for shapes_dtypes_str in self._data[name].keys(): + (shapes, dtypes) = self.deserialize(shapes_dtypes_str) + rets.append((shapes, dtypes)) + return tuple(rets) + + def serialize(self, shapes: Shapes, dtypes: DTypes) -> str: + """ + Serialize the shapes, dtypes and kwargs into a string + + e.g., + shapes: ((1024,), (1024,1024)) + dtypes: (torch.float32, torch.float32) + => (1024,)-(1024,1024)=torch.float32-torch.float32 + + @param shapes Tuple[Tuple[int]]: the shape of each tensor + @param dtypes Tuple[torch.dtype]: the dtype of each tensor + + @return key str: the serialized string + """ + shapes = '-'.join(str(tuple(shape)) for shape in shapes) + if dtypes is not None: + dtypes = '-'.join(str(dtype) for dtype in dtypes) + else: + dtypes = '-'.join([str(torch.float32)] * len(shapes)) + return shapes + '=' + dtypes + + def deserialize(self, key: str) -> ShapesDTypes: + """ + De-serialize the key string to shapes and dtypes + + e.g., (1024,)-(1024,1024)=torch.float32-torch.float32 + => shapes: ((1024,), (1024,1024)) + dtypes: (torch.float32, torch.float32) + + @param key str: the serialized string + @return shapes_and_dtypes ShapesDTypes: shapes and dtypes + """ + shapes, dtypes = key.split('=') + print(shapes) + shapes = tuple(eval(shape) for shape in shapes.split('-')) + dtypes = tuple(eval(dtype) for dtype in dtypes.split('-')) + return shapes, dtypes + + def dump(self, file: str, override=False): + """! + dump the profiled data into json format + + @param file str: the file name + @param override bool: True if the existed can be overrided else False + """ + if os.path.exists(file): + assert override, f"File {file} exists. Set override = True to force dump." + with open(file, 'w') as f: + json.dump(self._data, f) + + def load(self, file: str): + """! + load the profiled data into data base. The original existed one will be + overrided by the loaded data. + + @param file str: the file name + """ + with open(file, 'r') as f: + self._data = json.load(f) + + if __name__ == '__main__': - func = torch.nn.functional.linear + import argparse + parser = argparse.ArgumentParser(description='database') + parser.add_argument('--export', type=str, default='./profile.dat.json', + help='saved profiling database') + args = parser.parse_args() - shapes = ([2, 1024, 2304], [2, 2304]) - span = CompProfiler.profile(torch.nn.functional.linear, shapes) - print(f'span of {func.__name__}: shapes: {shapes}: {span} ms') + db = ProfileDataBase() + + # profile + dtype = torch.float32 + # func: [ + # [shapes, dtypes, kwargs], + # ] + funcs = { + torch.nn.functional.linear: [ + [([1024, 1, 2304], [2304, 2304]), (dtype, dtype), {}], + [([1024, 4, 2304], [2304, 2304]), (dtype, dtype), {}], + [([1024, 8, 2304], [2304, 2304]), (dtype, dtype), {}] + ], + + torch.nn.functional.gelu: [ + [((1024, 8, 2304),), (dtype,), {}] + ], + + torch.nn.functional.softmax: [ + [((1024, 8, 2304),), (dtype,), dict(dim=-1)] + ] + } + + for func, keys in funcs.items(): + for shapes, dtypes, kwargs in keys: + db.profile(func, shapes, dtypes, backward=True, **kwargs) + + db.dump(args.export, override=True) - shapes = ([8, 1024, 2304], [8, 2304]) - span = CompProfiler.profile(torch.nn.functional.linear, shapes) - print(f'span of {func.__name__}: shapes: {shapes}: {span} ms') + # db = ProfileDataBase(args.export) + # for shapes, dtypes in db.shapes_and_dtypes(torch.nn.functional.linear): + # span = db.query(torch.nn.functional.linear, shapes, dtypes) + # print(f'logged shapes: {shapes}, dtypes: {dtypes} => span: {span} ms') From 5f677e424b7b13262faaf4bfb12177dd3e2abbbd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 30 Jun 2022 19:28:35 +0800 Subject: [PATCH 0908/1892] robustness --- cube/profiler/database.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 695d8094..81120d3a 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -98,13 +98,13 @@ def profile(self, func: Callable, shapes: Shapes, dtypes: DTypes, try: assert callable(func), "func should be callable" span = CompProfiler.profile(func, shapes, dtypes, backward=backward, **kwargs) + name = func.__name__ + key = self.serialize(shapes, dtypes) + self.log(name, key, span) + print(f'profiled {func.__name__} | shapes: {shapes} | dtypes: {dtypes} => span: {round(span, 2)} ms') except Exception as e: print(f'fail to profile {func.__name__}: reason: {str(e)}') - name = func.__name__ - key = self.serialize(shapes, dtypes) - self.log(name, key, span) - print(f'profiled {func.__name__} | shapes: {shapes} | dtypes: {dtypes} => span: {round(span, 2)} ms') - + def log(self, name: str, key: str, span: float): """ log the span of a function name with key From fc92275429c6d8c32a2efc1c97a91e576b9c63b0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Jul 2022 10:50:45 +0800 Subject: [PATCH 0909/1892] fix self dependency bug --- cube/graph/graph.py | 2 +- cube/logics/translator.py | 33 ++++++++++++++++++--------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 2c97d5f9..b9b23ee4 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -176,7 +176,7 @@ def reset_dependency(self): consumer.add_predecessor(cidx, producer) # set mirror as control dependency if producer.mirror and isinstance(producer, IRFwOperation): - producer.add_successor(-1, producer) + producer.add_successor(-1, producer.mirror) producer.mirror.add_predecessor(-1, producer) def parameters(self): diff --git a/cube/logics/translator.py b/cube/logics/translator.py index 0b3c32fd..5c2db40d 100644 --- a/cube/logics/translator.py +++ b/cube/logics/translator.py @@ -18,23 +18,26 @@ def gen_logic_graph(outputs=None) -> IRGraph: Generate Training Logic Graph """ nodes = SchedulePool().nodes() - graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') - has_bp = any(n for n in graph.nodes() if isinstance(n, IRBpOperation)) + has_bp = any(n for n in nodes if isinstance(n, IRBpOperation)) if has_bp: - assert all(fnode.mirror in graph.nodes() for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)), \ + assert all(fnode.mirror in nodes for fnode in nodes if isinstance(fnode, IRFwOperation)), \ "Training requires all nodes have backward." - return graph - # remove backward nodes if no backward is called - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - for fnode in fnodes: - IRCell.make_pair(fnode, None) - for ftensor in graph.full_tensors(): - ftensor.requires_grad = False - #TODO: ad hoc fix on operators with multiple same input tensors - for node in graph.nodes(): - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor): - itensor._dirty_grad = True + else: + # remove backward nodes if no backward is called + fnodes = [node for node in nodes if isinstance(node, IRFwOperation)] + for fnode in fnodes: + IRCell.make_pair(fnode, None) + # remove node gradient + for node in nodes: + for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor): + itensor.parent.requires_grad = False + # ad hoc fix on operators with multiple same input tensors + itensor._dirty_grad = True + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor): + otensor.parent.requires_grad = False + graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') return graph @staticmethod From bea106dea0bf61148c7a469a48d3a7d3ccff76ab Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Jul 2022 10:51:23 +0800 Subject: [PATCH 0910/1892] torch.pad for v1.12 --- cube/graph/function/pad.py | 1 + cube/graph/parser/mapping.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/cube/graph/function/pad.py b/cube/graph/function/pad.py index 0c3147c0..0d46be5e 100644 --- a/cube/graph/function/pad.py +++ b/cube/graph/function/pad.py @@ -8,6 +8,7 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, **kwargs): # torch.nn.functional.pad(input, pad, mode='constant', value=0.0) # pad: List[int] + signature = 'torch.nn.functional.pad' assert len(inputs) == 1, "Expected only input, weight, bias as inputs" assert len(kwargs) == 3, "Expected 2 kwargs: mode, value" super().__init__(name, signature, 1, 1) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 182e2f1b..749453b4 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -118,6 +118,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('conv3d'): function.Conv3D, + __ttemplate('pad'): function.Pad, + __ttemplate('select'): function.Select, __ttemplate('slice'): function.Slice, From 26762fdc2ce9a22e1eb77a7a53a2b1f6111d1596 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 3 Jul 2022 10:25:52 +0800 Subject: [PATCH 0911/1892] change cat, stack to einops --- cube/graph/function/function.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index f107b91d..1c4a4054 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -12,7 +12,6 @@ from cube.graph.function.pad import IRPad from cube.graph.function.scripteinops import IRScriptEinOps from cube.graph.function.customops import IRCustomOps -from cube.graph.function.cat import IRCat, IRStack from cube.graph.function.creators import IROnes, IRToTensor, IRZeros from cube.graph.function.select import IRSelect, IRSlice from cube.graph.function.scatter import IRSelectScatter @@ -676,25 +675,41 @@ def Pad(signature, inputs): pad, mode, value = inputs[1:] return IRPad(signature, tensors, 'pad', pad=pad, mode=mode, value=value) -def Cat(signature, inputs): + +def Cat(signature, inputs: Tuple[List[IRTensor], int]): """ torch.cat(inputs: List[Tensor], dim: int) -> Tensor e.g. cat(tensor([2,3]), tensor([2,3])).shape == [4,3] """ - tensors : List[IRTensor] - dim : int tensors, dim = inputs - return IRCat(signature, tensors, 'cat', dim=dim) + iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] + dimlens = [t.shape[dim] for t in tensors] + for ashape, dimlen in zip(iannos, dimlens): + ashape[dim] = str(dimlen) + oannos = [copy.copy(iannos[-1])] + oannos[0][dim] = str(sum(dimlens)) + anno = OpAnno.create_op_str(iannos, oannos) + return IREinops(signature, [anno], tensors, 'cat', dim=dim) + def Stack(signature, inputs: Tuple[List[IRTensor], int]): """ torch.stack(inputs: List[Tensor], dim: int) -> Tensor + inputs: + tensors: List[Tensor]: all tensors need to have same size + dim: the new inserted dim + e.g. stack(tensor([2,3]), tensor([2,3])).shape == [2,2,3] """ tensors, dim = inputs - return IRCat(signature, tensors, 'stack', dim=dim) + iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] + oannos = [copy.copy(iannos[-1])] + oannos[0].insert(dim, str(len(tensors))) + anno = OpAnno.create_op_str(iannos, oannos) + return IREinops(signature, [anno], tensors, 'stack', dim=dim) + def Select(signature, inputs: Tuple[IRTensor, int, int]): """ From 7168140f4d080c9cc7544b8e08962fa99fef28d7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 3 Jul 2022 10:52:32 +0800 Subject: [PATCH 0912/1892] rename einops to dimops --- cube/algorithm/factory.py | 4 +- cube/algorithm/ops/{einops.py => dimops.py} | 16 ++--- cube/graph/function/__init__.py | 2 +- cube/graph/function/{einops.py => dimops.py} | 12 ++-- cube/graph/function/function.py | 62 ++++++++++---------- cube/graph/parser/register.py | 4 +- 6 files changed, 50 insertions(+), 50 deletions(-) rename cube/algorithm/ops/{einops.py => dimops.py} (93%) rename cube/graph/function/{einops.py => dimops.py} (99%) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 71545ec8..36501836 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -62,8 +62,8 @@ def _load_predefined_algos(self): import cube.algorithm.ops.dataloader as dataloader self.register(dataloader.IRDataOperation, dataloader.DPDataLoader, tag='data') - import cube.algorithm.ops.einops as einops - self.register(einops.IREinops, einops.DimSplitEinops, tag='dim') + import cube.algorithm.ops.dimops as dimops + self.register(dimops.IRDimops, dimops.DimSplitEinops, tag='dim') import cube.algorithm.ops.conv as conv self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') diff --git a/cube/algorithm/ops/einops.py b/cube/algorithm/ops/dimops.py similarity index 93% rename from cube/algorithm/ops/einops.py rename to cube/algorithm/ops/dimops.py index be4be240..a76b8e62 100644 --- a/cube/algorithm/ops/einops.py +++ b/cube/algorithm/ops/dimops.py @@ -1,7 +1,7 @@ from typing import List, Optional from cube.algorithm.generics import GenericDistAlgo -from cube.graph.function.einops import IREinops, DimAnno +from cube.graph.function.dimops import IRDimops, DimAnno from cube.ir.tensor import IRSubTensor @@ -16,9 +16,9 @@ class DimSplitEinops(GenericDistAlgo): For stay-reduce dimension, this dimension is not allowed to be splitted. """ - def __init__(self, node: IREinops): - if not isinstance(node, IREinops): - raise TypeError(f"Expect IREinops") + def __init__(self, node: IRDimops): + if not isinstance(node, IRDimops): + raise TypeError(f"Expect IRDimops") super().__init__(node) self._adim: str = None self._reduce: DimAnno.ReduceType = None @@ -34,7 +34,7 @@ def satisfy(self, idx: int, dim: int, num: int) -> bool: @return satisfy bool: true if can be partitioned, elsewise false. """ assert all(isinstance(cond, int) for cond in [idx, dim, num]), "expect int condition" - node: IREinops = self.node + node: IRDimops = self.node ninputs = len(node.inputs()) idx = idx if idx >= 0 else idx + ninputs @@ -53,9 +53,9 @@ def satisfy(self, idx: int, dim: int, num: int) -> bool: return False return True - def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IREinops]]: + def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: - node: IREinops = self.node + node: IRDimops = self.node satisfy = self.satisfy(idx, dim, num) print(f'partition {node.name}: {node.anno} | dim: {self._adim} reduce: {self._reduce.value}') if not satisfy: @@ -107,7 +107,7 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IREinops]]: updated_kwargs = dict() if self._adim in node.kwargs and isinstance(node.kwargs[self._adim], int): updated_kwargs[self._adim] = node.kwargs[self._adim] // num - sub_node: IREinops = node.new(inputs, outputs, **updated_kwargs) + sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) sub_node.infer_shape() sub_nodes.append(sub_node) return sub_nodes diff --git a/cube/graph/function/__init__.py b/cube/graph/function/__init__.py index 8dede842..fc28ba75 100644 --- a/cube/graph/function/__init__.py +++ b/cube/graph/function/__init__.py @@ -1,2 +1,2 @@ -from cube.graph.function.einops import IREinops +from cube.graph.function.dimops import IRDimops from cube.graph.function.function import * \ No newline at end of file diff --git a/cube/graph/function/einops.py b/cube/graph/function/dimops.py similarity index 99% rename from cube/graph/function/einops.py rename to cube/graph/function/dimops.py index 46cd498b..c9e5e28c 100644 --- a/cube/graph/function/einops.py +++ b/cube/graph/function/dimops.py @@ -469,7 +469,7 @@ def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], -class IREinops(IRFwOperation): +class IRDimops(IRFwOperation): """ Einstein-inspired notation operations """ @@ -573,7 +573,7 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): annos = self._annos_candidates updated_kwargs = copy.copy(self.kwargs) updated_kwargs.update(kwargs) - op = IREinops(self.signature, annos, inputs, self.name, **updated_kwargs) + op = IRDimops(self.signature, annos, inputs, self.name, **updated_kwargs) for idx, output in enumerate(outputs): op.set_output(idx, output) return op @@ -667,14 +667,14 @@ def algorithms(self, tag: Optional[str] = None): algos = list() if factory.exist(type(self)): algos += [template(self) for template in factory.algorithms(type(self))] - if factory.exist(IREinops): - algos += [template(self) for template in factory.algorithms(IREinops)] + if factory.exist(IRDimops): + algos += [template(self) for template in factory.algorithms(IRDimops)] return algos else: if factory.exist(type(self), tag): template = factory.algorithms(type(self), tag) return template(self) - if factory.exist(IREinops, tag): - template = factory.algorithms(IREinops, tag) + if factory.exist(IRDimops, tag): + template = factory.algorithms(IRDimops, tag) return template(self) return None diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 1c4a4054..bd5527d8 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -6,7 +6,7 @@ from cube.ir.cten import IRTensor from cube.ir.tensor import IRFullTensor -from cube.graph.function.einops import ShapeAnno, OpAnno, IREinops +from cube.graph.function.dimops import ShapeAnno, OpAnno, IRDimops from cube.graph.function.conv import IRConv2D from cube.graph.function.conv import IRConv3D from cube.graph.function.pad import IRPad @@ -24,7 +24,7 @@ def Identity(signature, inputs): signature = 'cube.runtime.function.identity' eshape = ShapeAnno.create_shape_str(inputs[0].shape) anno = OpAnno.create_op_str([eshape], [eshape]) - return IREinops(signature, [anno], inputs, 'identity') + return IRDimops(signature, [anno], inputs, 'identity') def Linear(signature, inputs): @@ -35,14 +35,14 @@ def Linear(signature, inputs): ] if inputs[2] is None: inputs = inputs[0:2] - return IREinops(signature, annos, inputs, 'linear') + return IRDimops(signature, annos, inputs, 'linear') def BatchLinear(signature, inputs): annos = [ 'b m k, b k n -> b m n' ] - return IREinops(signature, annos, inputs, 'bmm') + return IRDimops(signature, annos, inputs, 'bmm') def Zeros(signature, @@ -224,7 +224,7 @@ def Add(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, 'add', **kwargs) + return IRDimops(signature, annos, inputs, 'add', **kwargs) @@ -253,7 +253,7 @@ def Sub(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, 'sub', **kwargs) + return IRDimops(signature, annos, inputs, 'sub', **kwargs) def Mul(signature, inputs): @@ -269,7 +269,7 @@ def Mul(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, 'mul') + return IRDimops(signature, annos, inputs, 'mul') def Div(signature, inputs): @@ -286,7 +286,7 @@ def Div(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, 'div') + return IRDimops(signature, annos, inputs, 'div') def FloorDiv(signature, inputs): @@ -302,7 +302,7 @@ def FloorDiv(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, 'floordiv') + return IRDimops(signature, annos, inputs, 'floordiv') def Pow(signature, inputs): @@ -318,7 +318,7 @@ def Pow(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, 'pow') + return IRDimops(signature, annos, inputs, 'pow') # if both operands are scalars, returns bool. @@ -338,7 +338,7 @@ def comparison_einops(f, name, signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IREinops(signature, annos, inputs, name) + return IRDimops(signature, annos, inputs, name) def Neg(signature, inputs): @@ -359,7 +359,7 @@ def Neg(signature, inputs): return -arg annos = ['* -> *'] - return IREinops(signature, annos, inputs, 'neg', **kwargs) + return IRDimops(signature, annos, inputs, 'neg', **kwargs) def Sin(signature, inputs): annos = ['* -> *'] @@ -367,10 +367,10 @@ def Sin(signature, inputs): if len(inputs) == 2: # adapt for newest pytorch version approximate = inputs[1] - return IREinops(signature, annos, tensor, 'sin', + return IRDimops(signature, annos, tensor, 'sin', approximate=approximate) else: - return IREinops(signature, annos, tensor, 'sin') + return IRDimops(signature, annos, tensor, 'sin') def Cos(signature, inputs): @@ -379,10 +379,10 @@ def Cos(signature, inputs): if len(inputs) == 2: # adapt for newest pytorch version approximate = inputs[1] - return IREinops(signature, annos, tensor, 'cos', + return IRDimops(signature, annos, tensor, 'cos', approximate=approximate) else: - return IREinops(signature, annos, tensor, 'cos') + return IRDimops(signature, annos, tensor, 'cos') def GeLU(signature, inputs): @@ -392,17 +392,17 @@ def GeLU(signature, inputs): if len(inputs) == 2: # adapt for newest pytorch version approximate = inputs[1] - return IREinops(signature, annos, tensor, 'gelu', + return IRDimops(signature, annos, tensor, 'gelu', approximate=approximate) else: - return IREinops(signature, annos, tensor, 'gelu') + return IRDimops(signature, annos, tensor, 'gelu') def Softmax(signature, inputs): annos = ['* -> *'] tensor = inputs[0:1] dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] - return IREinops(signature, annos, tensor, 'softmax', + return IRDimops(signature, annos, tensor, 'softmax', dim=dim, _stacklevel=_stacklevel, dtype=dtype) @@ -412,7 +412,7 @@ def Dropout(signature, inputs): ] tensor = inputs[0:1] p, training, inplace = inputs[1], inputs[2], inputs[3] - return IREinops(signature, annos, tensor, 'dropout', + return IRDimops(signature, annos, tensor, 'dropout', p=p, training=training, inplace=inplace) @@ -424,7 +424,7 @@ def LayerNorm(signature, inputs): f'N *, ?, {normalized_shape[0]}, {normalized_shape[0]} -> N *', f'N *, ?, ?, ? -> N *' ] - return IREinops(signature, annos, [input, normalized_shape, weight, bias], + return IRDimops(signature, annos, [input, normalized_shape, weight, bias], 'layernorm', eps=eps) @@ -446,9 +446,9 @@ def Sum(signature, inputs): einput = [edim + '+' for edim in einput] anno = OpAnno.create_op_str([einput], [eoutput]) if dim is not None: - return IREinops(signature, [anno], [tensor], 'sum', dim=dim, keepdim=keepdim) + return IRDimops(signature, [anno], [tensor], 'sum', dim=dim, keepdim=keepdim) else: - return IREinops(signature, [anno], [tensor], 'sum') + return IRDimops(signature, [anno], [tensor], 'sum') def Transpose(signature, inputs): @@ -463,7 +463,7 @@ def Transpose(signature, inputs): edim_ou[dim0], edim_ou[dim1] = edim_ou[dim1], edim_ou[dim0] anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IREinops(signature, [anno], [input], 'transpose', + return IRDimops(signature, [anno], [input], 'transpose', dim0=dim0, dim1=dim1) @@ -579,7 +579,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s # bracket[subdim] = edim + '^' anno = OpAnno.create_op_str([in_anno], [ou_anno]) signature = 'torch.Tensor.view' - return IREinops(signature, [anno], [input], 'view', size=tuple(shape)) + return IRDimops(signature, [anno], [input], 'view', size=tuple(shape)) def Reshape(signature, inputs): @@ -600,7 +600,7 @@ def Reshape(signature, inputs): # torch.conv2d(input, weight, bias, stride, padding, dialation, groups) # https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html?highlight=torch%20conv2d#torch.nn.functional.conv2d # """ -# def adapt(anno: OpAnno, node: IREinops) -> OpAnno: +# def adapt(anno: OpAnno, node: IRDimops) -> OpAnno: # iH, iW = node.inputs(0).shape[2:4] # stride = node.kwargs['stride'] # padding = node.kwargs['padding'] @@ -620,7 +620,7 @@ def Reshape(signature, inputs): # if tensors[-1] is None: # tensors = inputs[0:2] # stride, padding, dilation, groups = inputs[3:] -# return IREinops(signature, annos, tensors, 'conv2d', +# return IRDimops(signature, annos, tensors, 'conv2d', # stride=stride, padding=padding, dilation=dilation, groups=groups) @@ -690,7 +690,7 @@ def Cat(signature, inputs: Tuple[List[IRTensor], int]): oannos = [copy.copy(iannos[-1])] oannos[0][dim] = str(sum(dimlens)) anno = OpAnno.create_op_str(iannos, oannos) - return IREinops(signature, [anno], tensors, 'cat', dim=dim) + return IRDimops(signature, [anno], tensors, 'cat', dim=dim) def Stack(signature, inputs: Tuple[List[IRTensor], int]): @@ -708,7 +708,7 @@ def Stack(signature, inputs: Tuple[List[IRTensor], int]): oannos = [copy.copy(iannos[-1])] oannos[0].insert(dim, str(len(tensors))) anno = OpAnno.create_op_str(iannos, oannos) - return IREinops(signature, [anno], tensors, 'stack', dim=dim) + return IRDimops(signature, [anno], tensors, 'stack', dim=dim) def Select(signature, inputs: Tuple[IRTensor, int, int]): @@ -760,7 +760,7 @@ def Embedding(signature, inputs: List): ] oshapes = [ishapes[0] + [ishapes[1][-1]]] anno = OpAnno.create_op_str(ishapes, oshapes) - return IREinops(signature, [anno], [itensor, weight], 'embedding', padding_idx=padding_idx, start=start, stop=stop) + return IRDimops(signature, [anno], [itensor, weight], 'embedding', padding_idx=padding_idx, start=start, stop=stop) def MultiRef(signature, inputs: List[IRFullTensor]): @@ -772,7 +772,7 @@ def MultiRef(signature, inputs: List[IRFullTensor]): assert isinstance(itensor, IRFullTensor), "require all inputs to be IRSubTensor" assert isinstance(times, int), "require int for second input" anno = '* -> ' + ', '.join('*' for _ in range(times)) - node = IREinops(signature, [anno], [itensor], 'multiref', times=times) + node = IRDimops(signature, [anno], [itensor], 'multiref', times=times) return node diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index e167242e..6db1fc25 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -6,7 +6,7 @@ import inspect import torch -from cube.graph.function.einops import IREinops +from cube.graph.function.dimops import IRDimops from cube.graph.parser.mapping import Sign2Op @@ -48,7 +48,7 @@ def udfop(signature: str, inputs: List[Any]): kwargs = dict() for name, val in zip(kwarg_names, kwarg_vals): kwargs[name] = val - return IREinops(signature, [anno], tensors, **kwargs, name=fsig) + return IRDimops(signature, [anno], tensors, **kwargs, name=fsig) print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') Sign2Op.register(fsig, udfop, code) From 91b45f08ad626d29b91a76409ca1442400959734 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 4 Jul 2022 10:15:08 +0800 Subject: [PATCH 0913/1892] ad hoc fix on gpt model of embed weight reuse --- examples/gsearch/gpt/model.py | 12 +- examples/gsearch/gpt/policy/spmd.py | 4 +- examples/gsearch/gpt/train.py | 6 +- examples/transformer/policy/naive.py | 11 -- examples/transformer/transformers.py | 248 --------------------------- tests/test_examples.sh | 21 +-- 6 files changed, 25 insertions(+), 277 deletions(-) delete mode 100644 examples/transformer/policy/naive.py delete mode 100644 examples/transformer/transformers.py diff --git a/examples/gsearch/gpt/model.py b/examples/gsearch/gpt/model.py index 817bf1ff..862c8766 100644 --- a/examples/gsearch/gpt/model.py +++ b/examples/gsearch/gpt/model.py @@ -39,7 +39,8 @@ def __init__(self): super().__init__() cfg = Config() - self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) + self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.embed_dim)) + # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) self.embed_dropout = torch.nn.Dropout() @@ -54,7 +55,11 @@ def __init__(self): def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): - embed = self.embed(input_ids) + # embed = self.embed(input_ids) + embed = torch.nn.functional.embedding( + input_ids, self.embedw, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False + ) pos_embed = self.position(position_ids) embed = embed + pos_embed embed = self.embed_dropout(embed) @@ -64,7 +69,8 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): enc = layer(enc) enc = self.final_layernorm(enc) - logits = torch.nn.functional.linear(enc, self.embed.weight) + # logits = torch.nn.functional.linear(enc, self.embed.weight) + logits = torch.nn.functional.linear(enc, self.embedw) # simplified loss = torch.sum(logits) return loss diff --git a/examples/gsearch/gpt/policy/spmd.py b/examples/gsearch/gpt/policy/spmd.py index 964958b1..c8cff30b 100644 --- a/examples/gsearch/gpt/policy/spmd.py +++ b/examples/gsearch/gpt/policy/spmd.py @@ -33,8 +33,8 @@ def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): graph.assign(sub_node, idx) return sub_nodes - # annotating code structure - multirefs = [node for node in graph.nodes() if isinstance(node, IRFwOperation) and node.name == 'multiref'] + # annotating code structure -- not consider multiref on embedding weight + multirefs = [node for node in graph.nodes() if isinstance(node, IRFwOperation) and node.name == 'multiref'][1:] for idx in range(0, len(multirefs), 2): multirefs[idx].comment = f'====> start of transformer {idx // 2}' diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py index d2fa68d3..8206def4 100644 --- a/examples/gsearch/gpt/train.py +++ b/examples/gsearch/gpt/train.py @@ -49,9 +49,6 @@ def train(): dataloader = GPTDataLoader(batch_size) optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - print_each_rank('model weight consumpition:') - memory_summary() - model = cube.SemanticModel(model, dataloader.shapes) @cube.compile(model, dataloader, PAS=PAS, override=True) def train_iter(model, dataloader): @@ -60,6 +57,9 @@ def train_iter(model, dataloader): loss.backward() model = model.get_gen_module() + print_each_rank('model weight consumpition:') + memory_summary() + CudaTimer(enable=False).warmup() iter_num = 64 for step in range(iter_num): diff --git a/examples/transformer/policy/naive.py b/examples/transformer/policy/naive.py deleted file mode 100644 index 250e4ae7..00000000 --- a/examples/transformer/policy/naive.py +++ /dev/null @@ -1,11 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.adapter.adapter import IRAdapter -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation - - -def PAS(graph: IRGraph, resource): - print(graph.extra_repr()) - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - return graph \ No newline at end of file diff --git a/examples/transformer/transformers.py b/examples/transformer/transformers.py deleted file mode 100644 index bfa82c5c..00000000 --- a/examples/transformer/transformers.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/transformer/transformers.py - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/transformer/transformers.py -""" - -import torch -from torch import nn -import cube - -from examples.transformer.policy.naive import PAS - -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary - - -@cube.graph.parser.register('L^ N E^, (3 h d^) E^ -> L^ N (h d^)') -def attnfc1(x: torch.Tensor, wqkv: torch.Tensor, h: int, - scale: float, dropout: float, training: bool): - """ - L: sequence length - N: batch size - E: embedding size - x: hidden state: [L, N, E] - wqkv: qkv weight: [3 * (num_head * dim_head), E] - dropout: float - h: int: number of heads - """ - num_head = h - L, N = x.shape[0], x.shape[1] - dim_head = wqkv.shape[0] // 3 // num_head - # L N E, (3 h d) E -> L N (3 h d) - qkv = torch.nn.functional.linear(x, wqkv, None) - # L N (3 h d) -> L N (h d), L N (h d), L N (h d) - q, k, v = qkv.chunk(3, dim=-1) - # L N (h d) -> L (N h) d - q = q.contiguous().view(L, (N * num_head), dim_head) - # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) - # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) - # L (N h) d -> (N h) L d - q = q.transpose(0, 1) - # L (N h) d -> (N h) L d - k = k.transpose(0, 1) - # L (N h) d -> (N h) L d - v = v.transpose(0, 1) - # (N h) L d, 1 -> (N h) L d - q = q * scale - # (N h) L d -> (N h) d L - k = k.transpose(-2, -1) - # (N h) L d, (N h) d L -> (N h) L L - attn = torch.bmm(q, k) - - # attention mask - # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N * num_head), L, L) - - # (N h) L L -> (N h) L L - attn = torch.nn.functional.softmax(attn, dim=-1) - # (N h) L L -> (N h) L L - if training: - attn = torch.nn.functional.dropout(attn, dropout, True, False) - # (N h) L L, (N h) L d -> (N h) L d - output = torch.bmm(attn, v) - # (N h) L d -> L (N h) d - output = output.transpose(0, 1).contiguous() - # L (N h) d -> L N (h d) - output = output.view(L, N, num_head * dim_head) - return output - - -class MultiHeadSelfAttention(nn.Module): - - def __init__(self, embed_dim, heads, dropout: float): - super().__init__() - self.embed_dim = embed_dim - self.num_head = heads - self.dim_head = embed_dim // heads - self.scale = self.dim_head ** -0.5 - - self.wqkv = torch.nn.Parameter(torch.empty( - 3 * embed_dim, embed_dim - )) - self.wout = torch.nn.Parameter(torch.empty( - embed_dim, embed_dim - )) - self.dropout = dropout - - def forward(self, x): - """ - x: [L, N, E]: seq_len, batch_size, embedding dimension - output: [L, N, E] - """ - # L N E, (3 h d) E -> L N (h d) - output = attnfc1(x, self.wqkv, self.num_head, - self.scale, self.dropout, self.training) - # L N (h d), E (h d) -> L N E - output = torch.nn.functional.linear(output, self.wout) - return output - - -class FFN(torch.nn.Module): - - def __init__(self, hidden_size: int): - super().__init__() - self.dense_h_to_4h = torch.nn.Linear( - hidden_size, 4 * hidden_size - ) - self.dense_4h_to_h = torch.nn.Linear( - 4 * hidden_size, hidden_size - ) - - def forward(self, hidden_states): - # [L, N, E] * [E, 4E] -> [L, N, 4E] - out = self.dense_h_to_4h(hidden_states) - # [L, N, 4E] -> [L, N, 4E] - out = torch.nn.functional.gelu(out) - # [L, N, 4E] * [4E, E] -> [L, N, E] - out = self.dense_4h_to_h(out) - return out - - -class TransformerLayer(torch.nn.Module): - - def __init__(self, hidden_size, head_num, dropout): - super().__init__() - # layer norm - self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - - self.attention = MultiHeadSelfAttention(hidden_size, head_num, dropout) - self.attn_dropout = torch.nn.Dropout(dropout) - - self.ffn_layernorm = torch.nn.LayerNorm(hidden_size, eps=0.00001) - self.ffn = FFN(hidden_size) - self.ffn_dropout = torch.nn.Dropout(dropout) - - def forward(self, hidden_states): - # Attention - in_attn_norm = self.input_layernorm(hidden_states) - attn_out = self.attention(in_attn_norm) - # residual - attn_out = self.attn_dropout(attn_out) - # TODO: enable residual - residual = attn_out # + hidden_states - # ffn - in_ffn_norm = self.ffn_layernorm(residual) - ffn_out = self.ffn(in_ffn_norm) - # residual - ffn_out = self.ffn_dropout(ffn_out) - # TODO: enable residual - ffn_out = ffn_out # + residual - return ffn_out - - -class Transformers(torch.nn.Module): - - def __init__(self, hidden_size, head_num): - super().__init__() - - self.transformer1 = TransformerLayer(hidden_size, head_num, 0.5) - self.transformer2 = TransformerLayer(hidden_size, head_num, 0.5) - self.transformer3 = TransformerLayer(hidden_size, head_num, 0.5) - self.transformer4 = TransformerLayer(hidden_size, head_num, 0.5) - - def forward(self, hidden_states): - - hidden_states = self.transformer1(hidden_states) - hidden_states = self.transformer2(hidden_states) - hidden_states = self.transformer3(hidden_states) - hidden_states = self.transformer4(hidden_states) - loss = torch.sum(hidden_states) - return loss - - -def train(): - L = 512 # seq len - N = 8 # batch size - # configs: [hidden size, num_head] - # E, num_head = [1536, 16] # 1.2B model - # E, num_head = [1920, 20] # 2.5B model - # E, num_head = [2304, 24] # 4.2B model - E, num_head = [3072, 32] # 8.7B model - - - model = Transformers( - hidden_size=E, head_num=num_head - ) - model = cube.SemanticModel( - model, input_shapes=([L, N, E],), - ) - - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([L, N, E],), - dtypes=(torch.float32,), - batch_dims=(1,) - ) - - @cube.compile(model, dataloader, PAS=PAS) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - model = model.get_gen_module() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - CudaTimer().warmup() - torch.distributed.barrier() - iter_num = 128 - for step in range(iter_num): - if step >= 40: - CudaTimer().start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - memory_summary() - -if __name__ == '__main__': - - cube.init() - train() \ No newline at end of file diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 02e456d7..83034387 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -50,15 +50,16 @@ OMP_NUM_THREADS=4 torchrun \ --nnodes=1 \ examples/gsearch/gpt/train.py --policy PASMegatronTP -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/gsearch/gpt/train.py --policy PASRoundRobin - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/gsearch/gpt/train.py --policy PAS1F1B +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# examples/gsearch/gpt/train.py --policy PASRoundRobin +# +# +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# examples/gsearch/gpt/train.py --policy PAS1F1B # test scientific model @@ -68,7 +69,7 @@ OMP_NUM_THREADS=4 torchrun \ --nnodes=1 \ examples/poisson/sci.py -OMP_NUM_THREADS=4 torchrun \ +SCIENTIFIC_COMPUTING=1 OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=1 \ --nnodes=1 \ examples/wrf/wrf2.py From 1c3c9d5ce009d7e67584986fcd42c3efcced01fb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 4 Jul 2022 12:59:01 +0800 Subject: [PATCH 0914/1892] add anchor function for graph navigation --- cube/compiler.py | 7 +++++-- cube/graph/function/function.py | 13 ++++++++++++- cube/graph/parser/mapping.py | 8 ++++++++ cube/graph/parser/parser.py | 2 -- cube/runtime/function/function.py | 6 ++++++ 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index f002a366..92b49d83 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -9,6 +9,7 @@ from cube.graph.gener.gen import IRAdapterGener from cube.graph.graph import IRGraph from cube.ir.operator import IRDataOperation +from cube.graph.function.anchor import IRGraphAnchor from cube.logics.pool import SchedulePool from cube.logics.translator import LogicTranslator @@ -159,9 +160,11 @@ def decorator(fn: Callable) -> Callable: if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") - # check assignment and order + # check assignment and remove anchor node for node in graph.nodes(): - if len(node.device) == 0: + if isinstance(node, IRGraphAnchor) or isinstance(node.mirror, IRGraphAnchor): + graph.detach(node) + elif len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") # generate adapter diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index bd5527d8..dbb7c35d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -5,7 +5,8 @@ import warnings from cube.ir.cten import IRTensor -from cube.ir.tensor import IRFullTensor +from cube.ir.operator import IRFwOperation +from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.graph.function.dimops import ShapeAnno, OpAnno, IRDimops from cube.graph.function.conv import IRConv2D from cube.graph.function.conv import IRConv3D @@ -16,6 +17,7 @@ from cube.graph.function.select import IRSelect, IRSlice from cube.graph.function.scatter import IRSelectScatter from cube.graph.function.repeat import IRRepeat +from cube.graph.function.anchor import IRGraphAnchor from cube.ir.dtype import IRDType from cube.graph.torch_dtype_mapping import DType2IRDType, TorchScalarTypeEnumMap @@ -776,6 +778,15 @@ def MultiRef(signature, inputs: List[IRFullTensor]): return node +def GraphAnchor(signature, inputs: List[IRSubTensor]): + """ + cube.runtime.function.anchor() -> None + """ + name: str = inputs[0] + node = IRGraphAnchor(signature, name) + return node + + def ScriptEinOps(signature, inputs): """ apply_for_scriptable_torch(recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str) -> torch.Tensor: diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 749453b4..0bebce0e 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -21,6 +21,8 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: """ Map the signature to GenericLogicalOp """ + if 'torch.' not in signature and 'cube.runtime.' not in signature: + signature = signature.split('.')[-1] if signature in Sign2Op.kOpMap: return partial(Sign2Op.kOpMap[signature], signature=signature) else: @@ -46,6 +48,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # tensor template __ttemplate = lambda name: f'torch.{name}' + # runtime template + __rtemplate = lambda name: f'cube.runtime.function.function.{name}' + # einops __einopsize = lambda name: f'einops._torch_specific.{name}' @@ -136,6 +141,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('stack'): function.Stack, + # runtime functions + __rtemplate('anchor'): function.GraphAnchor, + #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 7d9e3226..9ea0b216 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -228,8 +228,6 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: input_vals.append(val) # map to IR operator - if 'torch' not in fsig: # indicate a customized operator - fsig = fsig.split('.')[-1] ir_node = Sign2Op.map(fsig)(inputs=input_vals) # push output in the frame diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 38d02363..cf762f32 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -9,6 +9,12 @@ def identity(tensor: torch.Tensor) -> torch.Tensor: """ return tensor +def anchor(name: str): + """ + anchor operation for graph navigation + """ + return None + def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: """ From 6e751db5901aa7f9419e4d6b4dd1450f6fc56d28 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 4 Jul 2022 13:14:18 +0800 Subject: [PATCH 0915/1892] add ir anchor --- cube/graph/function/anchor.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 cube/graph/function/anchor.py diff --git a/cube/graph/function/anchor.py b/cube/graph/function/anchor.py new file mode 100644 index 00000000..649ec9ef --- /dev/null +++ b/cube/graph/function/anchor.py @@ -0,0 +1,24 @@ + +from cube.ir.operator import IRFwOperation +from cube.ir.tensor import IRSubTensor + + +class IRGraphAnchor(IRFwOperation): + """ + The anchor function for navigation inside the graph + """ + def __init__(self, signature: str, name: str): + super().__init__(name, signature, 0, 1) + self.kwargs['name'] = name + self.set_output(0, None) + + def infer_shape(self): + return True + + def __repr__(self) -> str: + sign = self.signature.split('.')[-1] + ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_param()] + dscp = (f"FwOp{self._id}(sign={sign}[{self.name}], " + f"inputs={ins}, " + f"outputs={self.outputs()})") + return dscp From 30cbb6798ae20a91af774747bc3f56e036a322ee Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Mon, 4 Jul 2022 08:44:27 +0000 Subject: [PATCH 0916/1892] Merged PR 1393: Fix ScalarTypes not treated as ints during parsing TorchScript ScalarType is neither 'torch.dtype' (Python enum) or 'ScalarType' (C++ enum) during TorchScript, but is 'int' which plays the role as the underlying type of 'ScalarType' Also add UTs to guard this behavior. --- cube/graph/function/function.py | 47 +++++++++++++++---------- tests/test_parser.py | 61 +++++++++++++++++++++++++++++++++ tests/test_prim_loop.py | 6 ++-- 3 files changed, 92 insertions(+), 22 deletions(-) create mode 100644 tests/test_parser.py diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index dbb7c35d..270ffeb2 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -48,23 +48,26 @@ def BatchLinear(signature, inputs): def Zeros(signature, - inputs: Tuple[ List[int], Optional[Any], Optional[Any], 'ErasedDevice', Optional[bool] ]): + inputs: Tuple[ List[int], Optional[int], Optional[Any], 'ErasedDevice', Optional[bool] ]): # zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor # # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. - size, dtype, layout, _erased_device, pin_memory = inputs + size, dtype_underlying, layout, _erased_device, pin_memory = inputs # TODO parameters to support, currently they are all None assert layout is None assert pin_memory is None - ir_dtype : IRDType - if dtype is not None: - ir_dtype = DType2IRDType.map(dtype) + if dtype_underlying is not None: + # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, + # which is the underlying type of PyTorch C++ enum 'ScalarType'. + dtype = TorchScalarTypeEnumMap.map(dtype_underlying) else: - ir_dtype = DType2IRDType.map(torch.get_default_dtype()) + dtype = torch.get_default_dtype() + + ir_dtype : IRDType = DType2IRDType.map(dtype) for dim, i in enumerate(size): if not isinstance(dim, int) and not dim >= 0: @@ -72,20 +75,23 @@ def Zeros(signature, return IRZeros(signature, size, 'zeros', ir_dtype) def Ones(signature, - inputs: Tuple[ List[int], Optional[Any], Optional[Any], 'ErasedDevice', Optional[bool] ]): + inputs: Tuple[ List[int], Optional[int], Optional[Any], 'ErasedDevice', Optional[bool] ]): # ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - size, dtype, layout, _erased_device, pin_memory = inputs + size, dtype_underlying, layout, _erased_device, pin_memory = inputs # TODO parameters to support, currently they are all None assert layout is None assert pin_memory is None - ir_dtype : IRDType - if dtype is not None: - ir_dtype = DType2IRDType.map(dtype) + if dtype_underlying is not None: + # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, + # which is the underlying type of PyTorch C++ enum 'ScalarType'. + dtype = TorchScalarTypeEnumMap.map(dtype_underlying) else: - ir_dtype = DType2IRDType.map(torch.get_default_dtype()) + dtype = torch.get_default_dtype() + + ir_dtype : IRDType = DType2IRDType.map(dtype) for dim, i in enumerate(size): if not isinstance(dim, int) and not dim >= 0: @@ -93,28 +99,31 @@ def Ones(signature, return IROnes(signature, size, 'ones', ir_dtype) def NewTensor(signature, - inputs: Tuple[ list, Optional[Any], 'ErasedDevice', bool ]): + inputs: Tuple[ list, Optional[int], 'ErasedDevice', bool ]): # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor # # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. - data, dtype, _erased_device, requires_grad = inputs + data, dtype_underlying, _erased_device, requires_grad = inputs # TODO parameters to support, currently they are all None assert requires_grad == False - ir_dtype : IRDType - if dtype is not None: - ir_dtype = DType2IRDType.map(dtype) + if dtype_underlying is not None: + # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, + # which is the underlying type of PyTorch C++ enum 'ScalarType'. + dtype = TorchScalarTypeEnumMap.map(dtype_underlying) else: - ir_dtype = DType2IRDType.map(torch.get_default_dtype()) + dtype = torch.get_default_dtype() + + ir_dtype : IRDType = DType2IRDType.map(dtype) # if 'data' is not: # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 # 2) non-ragged # ... then this call will throw. - arr = torch.tensor(data, dtype=dtype) + arr = torch.tensor(data, dtype=dtype_underlying) # TODO temporarily fake creation with Zeros # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 00000000..7ef1b874 --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,61 @@ +# run tests: +# pytest ./tests/test_parser.py + +import pytest +import torch +from cube.graph.function.creators import IROnes, IRZeros + +from cube.graph.parser.frame import Frame +from cube.graph.parser.parser import ScriptModuleParser +from cube.graph.torch_dtype_mapping import DType2IRDType +from cube.ir.dtype import IRDType +from cube.ir.tensor import IRFullTensor +from cube import ir + +@pytest.mark.parametrize( + "aten_op, ir_op_cls", + [("zeros", IRZeros), ("ones", IROnes)] +) +def test_optional_dtype_none(aten_op, ir_op_cls): + g = torch._C.parse_ir(f''' + graph(): + %d : int = prim::Constant[value=2]() + %shape : int[] = prim::ListConstruct(%d, %d, %d) + %none : NoneType = prim::Constant() + %z : Tensor = aten::{aten_op}(%shape, %none, %none, %none, %none) + return (%z) + ''') + frame = Frame() + frame.push_var() + frame.push_attr() + + for node in g.nodes(): + ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) + for node in ir_nodes: + if isinstance(node, ir_op_cls): + assert node.outputs(0).dtype == DType2IRDType.map(torch.get_default_dtype()) + +@pytest.mark.parametrize( + "aten_op, ir_op_cls", + [("zeros", IRZeros), ("ones", IROnes)] +) +def test_optional_dtype_underlying_int(aten_op, ir_op_cls): + # ScalarType(3) == torch.int32 + g = torch._C.parse_ir(f''' + graph(): + %d : int = prim::Constant[value=2]() + %shape : int[] = prim::ListConstruct(%d, %d, %d) + %none : NoneType = prim::Constant() + %scalarType : int = prim::Constant[value=3]() + %z : Tensor = aten::{aten_op}(%shape, %scalarType, %none, %none, %none) + return (%z) + ''') + frame = Frame() + frame.push_var() + frame.push_attr() + + for node in g.nodes(): + ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) + for node in ir_nodes: + if isinstance(node, ir_op_cls): + assert node.outputs(0).dtype == IRDType.int32 diff --git a/tests/test_prim_loop.py b/tests/test_prim_loop.py index 523bb9a0..005fcfb1 100644 --- a/tests/test_prim_loop.py +++ b/tests/test_prim_loop.py @@ -47,7 +47,7 @@ def test_simple_unroll_evaluation(): return (%z) ''') frame = Frame() - frame.push() + frame.push_var() frame.add_var("a", 0) for node in g.nodes(): @@ -71,7 +71,7 @@ def test_unroll_with_structural_info(): return (%z) ''') frame = Frame() - frame.push() + frame.push_var() t_a = IRFullTensor(shape=[2,3]) frame.add_var("a", t_a) @@ -131,7 +131,7 @@ def __init__(self) -> None: module = StubScriptModule() frame = Frame() - frame.push() + frame.push_var() t_a = IRFullTensor(shape=[2,3]) frame.add_var("a", t_a) From 9ff16cc13ce81589dcf4157fb4c0ad0ebedebf07 Mon Sep 17 00:00:00 2001 From: Zijian Ding Date: Wed, 6 Jul 2022 15:07:30 +0800 Subject: [PATCH 0917/1892] Add partition strategy for wrf2, modify partion algorithms to support the strategy. --- cube/algorithm/factory.py | 1 + cube/algorithm/ops/dimops.py | 104 +++++++++++++++++++++++- cube/algorithm/ops/pad.py | 12 ++- cube/graph/parser/mapping.py | 2 + examples/wrf/policy/onedim.py | 146 ++++++++++++++++++++++++++++++++++ 5 files changed, 260 insertions(+), 5 deletions(-) create mode 100644 examples/wrf/policy/onedim.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 36501836..1bda1b15 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -64,6 +64,7 @@ def _load_predefined_algos(self): import cube.algorithm.ops.dimops as dimops self.register(dimops.IRDimops, dimops.DimSplitEinops, tag='dim') + self.register(dimops.IRDimops, dimops.SimpleViewSplitEinops, tag='view_simp') import cube.algorithm.ops.conv as conv self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index a76b8e62..c6a92563 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -49,7 +49,7 @@ def satisfy(self, idx: int, dim: int, num: int) -> bool: dimlen = node.anno.getlen(self._adim) if self._reduce == DimAnno.ReduceType.Freeze: return False - if dimlen % num != 0: + if dimlen < num: return False return True @@ -111,3 +111,105 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: sub_node.infer_shape() sub_nodes.append(sub_node) return sub_nodes + + +class SimpleViewSplitEinops(GenericDistAlgo): + """ + split Einops at dimension level. + + The sum-reduce dimension and non-reduce dimension can be splitted. + + For sum-reduce dimension, the output keeps same shape but has partial-sum valmap result. + For non-reduce dimension, the output keeps same valmap but has partial output shape. + For stay-reduce dimension, this dimension is not allowed to be splitted. + """ + + def __init__(self, node: IRDimops): + if not isinstance(node, IRDimops): + raise TypeError(f"Expect IRDimops") + super().__init__(node) + self._adim: str = None + self._reduce: DimAnno.ReduceType = None + + def satisfy(self, idx: int, dimi: int, dimo: int, num: int) -> bool: + """ + Check whether the condition satisfies. + + @param idx int: input index + @param dimi int: input dimension + @param dimo int: corresponding output dimension + @param num int: chunks to partition the dimension + + @return satisfy bool: true if can be partitioned, elsewise false. + """ + # assert all(isinstance(cond, int) for cond in [idx, dim, num]), "expect int condition" + node: IRDimops = self.node + assert idx == 0, f"Index should be 0" + assert len(node.inputs()) == 1, f"Inputs size should be 1" + assert len(node.outputs()) == 1, f"Outputs size should be 1" + dimi = dimi if dimi >= 0 else dimi + node.inputs(0).ndims + dimo = dimo if dimo >= 0 else dimo + node.outputs(0).ndims + assert dimi < node.inputs(0).ndims, f"dimension out of boundary: {dimi} >= {node.inputs(0).ndims}" + assert dimo < node.outputs(0).ndims, f"dimension out of boundary" + # # due to implementation limits, we only partition the first annotated dimension + # # for inner-dimension cases. + self._adimi: str = node.anno.inputs(0).dims[dimi].identifiers[0] + self._adimo: str = node.anno.outputs(0).dims[dimo].identifiers[0] + dimlen = node.anno.getlen(self._adimi) + if dimlen < num: + return False + return True + + def instantiate(self, idx: int, dimi: int, dimo: int, num: int) -> Optional[List[IRDimops]]: + + node: IRDimops = self.node + satisfy = self.satisfy(idx, dimi, dimo, num) + if not satisfy: + return None + + ins, ous = list(), list() + for iidx, itensor in enumerate(node.inputs()): + if not isinstance(itensor, IRSubTensor): + assert 0, "should not happen" + shape_anno = node.anno.inputs(iidx) + split_dims = shape_anno.getdims(self._adimi) + assert len(split_dims) <= 1, f"find split dims ({self._adimi}) more than 1: {shape_anno}" + if len(split_dims) == 1: + dim = split_dims[0] + # split axis + # print('dimi =', dim) + ins.append(itensor.split_dim(dim, num)) + else: + assert 0, "should not happen" + + for oidx, otensor in enumerate(node.outputs()): + if not isinstance(otensor, IRSubTensor): + assert 0, f"should not happen" + shape_anno = node.anno.outputs(oidx) + split_dims = shape_anno.getdims(self._adimo) + assert len(split_dims) <= 1, f"find split dims ({self._adimo}) more than 1: {shape_anno}" + # split axis + if self._reduce != DimAnno.ReduceType.Dim: + assert len(split_dims) == 1, f"expect only one spatial dimension in output tensor but got {len(split_dims)}" + dim = split_dims[0] + # print('dimo =', dim) + ous.append(otensor.split_dim(dim, num)) + # split numerical dimension + else: + assert 0, f"not implemented" + + sub_nodes = list() + for nid in range(num): + inputs = [t[nid] for t in ins] + outputs = [t[nid] for t in ous] + updated_kwargs = dict() + if self._adimi in node.kwargs and isinstance(node.kwargs[self._adimi], int): + assert 0, "should not happen" + if self._adimo in node.kwargs and isinstance(node.kwargs[self._adimo], int): + assert 0, "should not happen" + assert len(outputs) == 1, f"outputs len should be one" + node.kwargs['size'] = outputs[0].shape + sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) + sub_node.infer_shape() + sub_nodes.append(sub_node) + return sub_nodes \ No newline at end of file diff --git a/cube/algorithm/ops/pad.py b/cube/algorithm/ops/pad.py index 61b74966..856a6487 100644 --- a/cube/algorithm/ops/pad.py +++ b/cube/algorithm/ops/pad.py @@ -49,15 +49,17 @@ def satisfy(self, dim: int, num: int): # split non-pad dim if dim < len(node.inputs(0).shape) - pad_dim_count: - return node.inputs(0).shape[dim] % num == 0 + return node.inputs(0).shape[dim] >= num + # return node.inputs(0).shape[dim] % num == 0 # split pad dim else: dim_in_pad = len(node.inputs(0).shape) - 1 - dim - return (node.inputs(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) % num == 0 + return (node.inputs(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) >= num + # return (node.inputs(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) % num == 0 def instantiate(self, dim: int, num: int): if not self.satisfy(dim, num): - return False + return None node: IRPad = self.node pad = node.kwargs['pad'] mode = node.kwargs['mode'] @@ -82,6 +84,7 @@ def instantiate(self, dim: int, num: int): global_padl = pad[dim_in_pad * 2] global_padr = pad[dim_in_pad * 2 + 1] chunk_size = (node.outputs(0).shape[dim] - global_padl - global_padr) // num + addone_num = (node.outputs(0).shape[dim] - global_padl - global_padr) % num start = 0 for cid in range(num): padl = global_padl if cid == 0 else 0 @@ -92,7 +95,8 @@ def instantiate(self, dim: int, num: int): cur_pad[dim_in_pad * 2 + 1] = padr pads.append(cur_pad) - stop = start + padl + padr + chunk_size + addone = int(cid < addone_num) + stop = start + padl + padr + chunk_size + addone slicers.append((max(0, start), min(node.outputs(0).shape[dim], stop))) start = stop diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 0bebce0e..5c97fb59 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -143,6 +143,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # runtime functions __rtemplate('anchor'): function.GraphAnchor, + + __rtemplate('identity'): function.Identity, #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, diff --git a/examples/wrf/policy/onedim.py b/examples/wrf/policy/onedim.py new file mode 100644 index 00000000..a0d0dde5 --- /dev/null +++ b/examples/wrf/policy/onedim.py @@ -0,0 +1,146 @@ +from cube.graph import IRGraph +from cube.graph.function import IRConv2D, IRConv3D +from cube.graph.function import IRDimops, IRPad +from cube.ir.cten import IRTensor, IRCell + + +def PAS(graph: IRGraph, resource): + for node in graph.nodes(): + if isinstance(node, IRConv3D): + sub_nodes = list() + algo = node.algorithms('halo') + Wnodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus // 2) + for Wnode in Wnodes: + algo = Wnode.algorithms('halo') + Hnodes = graph.partition(Wnode, algo, idx=0, dim=2, num=2) + sub_nodes += Hnodes + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + # sub_nodes = graph.replicate(node, times=resource.ngpus) + + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + print(graph.extra_repr()) + return graph + +global opSigns + +opSigns = [] + +def append_sign(sign: str): + global opSigns + if not sign in opSigns: + opSigns.append(sign) + +def PAS_ALL_TEST(graph: IRGraph, resource): + for node in graph.nodes(): + sign = node.signature.split('.')[-1] + append_sign(sign) + if isinstance(node, IRConv3D): + sub_nodes = list() + algo = node.algorithms('halo') + sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) + elif isinstance(node, IRDimops): + sign = node.signature.split('.')[-1] + if (sign == 'mul' or sign == 'add' or sign == 'sub' or sign == 'div') and (len(node.inputs(0).shape) == 5 or len(node.inputs(0).shape) == 3): + algo = node.algorithms('dim') + if len(node.inputs(0).shape) == 3: + sub_nodes = graph.partition(node, algo, idx=0, dim=1, num=resource.ngpus) + if sub_nodes == None: + sub_nodes = graph.replicate(node, times=resource.ngpus) + elif len(node.inputs(0).shape) == 5: + sub_nodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus) + if sub_nodes == None: + sub_nodes = graph.replicate(node, times=resource.ngpus) + elif sign == 'view': + print('partition view') + print(node) + algo = node.algorithms('view_simp') + sub_nodes = graph.partition(node, algo, idx=0, dimi=node.inputs(0).ndims-2, dimo=node.outputs(0).ndims-2, num=resource.ngpus) + print(sub_nodes) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + elif isinstance(node, IRPad): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, dim=node.inputs(0).ndims-2, num=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + print(graph.extra_repr()) + return graph + + +def PAS_ALL_X(graph: IRGraph, resource): + for node in graph.nodes(): + sign = node.signature.split('.')[-1] + append_sign(sign) + if isinstance(node, IRConv3D): + sub_nodes = list() + algo = node.algorithms('halo') + sub_nodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus) + elif isinstance(node, IRDimops): + if sign in ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat']: + ndims = node.inputs(0).ndims + algo = node.algorithms('dim') + if ndims == 3 or ndims == 5: + sub_nodes = graph.partition(node, algo, idx=0, dim=ndims-1, num=resource.ngpus) + if sub_nodes == None: + sub_nodes = graph.replicate(node, times=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + elif sign == 'view': + algo = node.algorithms('view_simp') + if node.inputs(0).ndims >= 3 and node.outputs(0).ndims >= 3: + sub_nodes = graph.partition(node, algo, idx=0, dimi=node.inputs(0).ndims-1, dimo=node.outputs(0).ndims-1, num=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + elif isinstance(node, IRPad): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, dim=node.inputs(0).ndims-1, num=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + print(graph.extra_repr()) + print(opSigns) + return graph + +def PAS_ALL_Y(graph: IRGraph, resource): + for node in graph.nodes(): + sign = node.signature.split('.')[-1] + append_sign(sign) + if isinstance(node, IRConv3D): + sub_nodes = list() + algo = node.algorithms('halo') + sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) + elif isinstance(node, IRDimops): + if sign in ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat']: + ndims = node.inputs(0).ndims + algo = node.algorithms('dim') + if ndims == 3 or ndims == 5: + sub_nodes = graph.partition(node, algo, idx=0, dim=ndims-2, num=resource.ngpus) + if sub_nodes == None: + sub_nodes = graph.replicate(node, times=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + elif sign == 'view': + algo = node.algorithms('view_simp') + if node.inputs(0).ndims >= 3 and node.outputs(0).ndims >= 3: + sub_nodes = graph.partition(node, algo, idx=0, dimi=node.inputs(0).ndims-2, dimo=node.outputs(0).ndims-2, num=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + elif isinstance(node, IRPad): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, dim=node.inputs(0).ndims-2, num=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + # print(graph.extra_repr()) + print(opSigns) + return graph From 3897b8195742dfa07ead7d7abdb565f7cf952c92 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Mon, 11 Jul 2022 07:24:35 +0000 Subject: [PATCH 0918/1892] Merged PR 1395: Add select_scatter runtime function PyTorch v1.11 add 'torch.select_scatter' function, which has the most concise and accurate semantics. But the torch-to-ONNX component in the same version doesn't handle this function well. Two other alternatives: in-place `x[i,:] = sub`, which is in-compatible with current dataflow; 'y = x.clone(); y[i,:] = sub` causes crash. So instead of these, for codegen to use, add a runtime function based on `torch.masked_scatter` which is well supported in ONNX opset=11. --- cube/graph/function/scatter.py | 1 + cube/runtime/function/function.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/cube/graph/function/scatter.py b/cube/graph/function/scatter.py index a4116dc7..43d503db 100644 --- a/cube/graph/function/scatter.py +++ b/cube/graph/function/scatter.py @@ -32,6 +32,7 @@ class IRSelectScatter(IRFwOperation): def __init__(self, signature: str, inputs:Tuple[IRTensor, IRTensor], name: str, dim:int, index:int): assert len(inputs) == 2 + signature = 'cube.runtime.function.select_scatter' super().__init__(name, signature, 2, 1) self.set_input(0, inputs[0]) self.set_input(1, inputs[1]) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index cf762f32..f4920b36 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -80,6 +80,21 @@ def embedding(input: torch.Tensor, weight: torch.Tensor, padding_idx: Optional[i return output +# 'torch.select_scatter' isn't supported by Torch2ONNX yet. +# Implement it with 'torch.masked_scatter' which is supported with ONNX opset=11. +def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): + # e.g. [..., 1, -1, 1, ...] + shape = [1] * input.ndim + shape[dim] = -1 + + d = input.shape[dim] + mask = torch.zeros([d], dtype=torch.bool, device=input.device) + mask[index] = True + mask = mask.reshape(shape) + + return torch.masked_scatter(input, mask, src) + + def einops(input: torch.Tensor, recipe_str, reduction_type: str): import pickle recipe = pickle.loads(recipe_str) From 739a2a780a24b2e3f9838507e5a3ca139e9ef98e Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 14 Jul 2022 19:48:32 +0800 Subject: [PATCH 0919/1892] Fix dtype/ir_type misuse --- cube/graph/function/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 270ffeb2..df6427cb 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -123,7 +123,7 @@ def NewTensor(signature, # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 # 2) non-ragged # ... then this call will throw. - arr = torch.tensor(data, dtype=dtype_underlying) + arr = torch.tensor(data, dtype=dtype) # TODO temporarily fake creation with Zeros # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', From dcd7341fca9907e2e751a7149c1e65e586f34bc9 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Mon, 18 Jul 2022 06:23:01 +0000 Subject: [PATCH 0920/1892] Merged PR 1396: Separate `input(int)` and `inputs()` APIs; Make inputs/outputs field tuple, not list - Syntactically separated `inputs(idx:Optional[int]) : Union[T, List[T]]` and `outputs(idx:Optional[int]) : Union[T, List[T]]` APIs into `input(idx:int) : T` and `inputs() : List[T]` (also `output/outputs`) to eliminate runtime overhead to simulate overloading. -- Mainly refactored for `IRCell` -- Refactored `IRAdapterPrim` and `OpAnno` for the sake of consistency. - Changed the field type of `IRCell._inputs/_outputs` to tuples, instead of list, to ensure immutability and eliminate copy cost. --- cube/algorithm/ops/conv.py | 82 +++++++++---------- cube/algorithm/ops/dimops.py | 30 +++---- cube/algorithm/ops/pad.py | 30 +++---- cube/codegen/codegen.py | 2 +- cube/codegen/frontend_mapping.py | 2 +- cube/graph/function/cat.py | 10 +-- cube/graph/function/conv.py | 32 ++++---- cube/graph/function/creators.py | 12 +-- cube/graph/function/customops.py | 14 ++-- cube/graph/function/dimops.py | 34 ++++---- cube/graph/function/function.py | 6 +- cube/graph/function/pad.py | 8 +- cube/graph/function/repeat.py | 4 +- cube/graph/function/scatter.py | 6 +- cube/graph/function/scripteinops.py | 6 +- cube/graph/function/select.py | 8 +- cube/graph/graph.py | 4 +- cube/graph/parser/parser.py | 6 +- cube/ir/adapter/prim.py | 42 +++++----- cube/ir/cten.py | 117 ++++++++++++++++------------ cube/ir/operator.py | 12 +-- cube/logics/model.py | 4 +- examples/wrf/policy/onedim.py | 26 +++---- tests/test_parser.py | 4 +- tests/test_prim_loop.py | 4 +- 25 files changed, 257 insertions(+), 248 deletions(-) diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index 2bec034c..e782b76c 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -53,13 +53,13 @@ def satisfy(self, idx: int, dim: int, num: int): groups = node.kwargs['groups'] # split N: if (idx, dim) == (0, 0): - return node.inputs(0).shape[0] % num == 0 + return node.input(0).shape[0] % num == 0 # split oC if (idx, dim) == (1, 0): - return node.inputs(1).shape[0] % num == 0 + return node.input(1).shape[0] % num == 0 # split iC if (idx, dim) == (0, 1) or (idx, dim) == (1, 1): - return groups == 1 and node.inputs(1).shape[0] % 0 == num + return groups == 1 and node.input(1).shape[0] % 0 == num def instantiate(self, idx: int, dim: int, num: int): if not self.satisfy(idx, dim, num): @@ -69,28 +69,28 @@ def instantiate(self, idx: int, dim: int, num: int): outputs = list() # split N if (idx, dim) == (0, 0): - inputs = node.inputs(0).split_dim(dim, num) - weights = [node.inputs(1)] * num - bias = [node.inputs(2)] * num - outputs = node.outputs(0).split_dim(dim, num) + inputs = node.input(0).split_dim(dim, num) + weights = [node.input(1)] * num + bias = [node.input(2)] * num + outputs = node.output(0).split_dim(dim, num) # split oC if (idx, dim) == (1, 0): - inputs = [node.inputs(0)] * num - weights = node.inputs(1).split_dim(dim, num) - if node.inputs(2) is None: + inputs = [node.input(0)] * num + weights = node.input(1).split_dim(dim, num) + if node.input(2) is None: bias = [None] * num else: - bias = node.inputs(2).split_dim(dim, num) - outputs = node.outputs(0).split_dim(dim=1, num=num) + bias = node.input(2).split_dim(dim, num) + outputs = node.output(0).split_dim(dim=1, num=num) # split iC if (idx, dim) == (0, 1) or (idx, dim) == (1, 1): - inputs = node.inputs(0).split_dim(dim, num) - weights = node.inputs(1).split_dim(dim, num) - if node.inputs(2) is None: + inputs = node.input(0).split_dim(dim, num) + weights = node.input(1).split_dim(dim, num) + if node.input(2) is None: bias = [None] * num else: - bias = node.inputs(2).split_val(num) - outputs = node.outputs(0).split_val(num) + bias = node.input(2).split_val(num) + outputs = node.output(0).split_val(num) subnodes = list() for i, w, b, o in zip(inputs, weights, bias, outputs): subnodes.append(node.new([i, w, b], [o])) @@ -112,7 +112,7 @@ def __init__(self, node: IRConv2D): def satisfy(self, idx: int, dim: int, num: int): assert all(isinstance(t, int) for t in [idx, dim, num]), "idx, dim and num should be integer" node: IRConv2D = self.node - oH, oW = node.outputs(0).shape[2:] + oH, oW = node.output(0).shape[2:] stride = node.kwargs['stride'] dilation = node.kwargs['dilation'] if dim not in [2, 3]: @@ -133,9 +133,9 @@ def instantiate(self, idx: int, dim: int, num: int): if not self.satisfy(idx, dim, num): return None node: IRConv2D = self.node - H, W = node.inputs(0).shape[2:] - dH, dW = node.inputs(1).shape[2:] - oH, oW = node.outputs(0).shape[2:] + H, W = node.input(0).shape[2:] + dH, dW = node.input(1).shape[2:] + oH, oW = node.output(0).shape[2:] groups = node.kwargs['groups'] stride = node.kwargs['stride'] padding = node.kwargs['padding'] @@ -158,13 +158,13 @@ def instantiate(self, idx: int, dim: int, num: int): start = stop - dilation[0] * (dH - 1) # start = 0 if cid == 0 else 1023 # stop = 1025 if cid == 0 else H - inputs = _split_axis_custom(node.inputs(0), dim=dim, chunks=tuple(indmap)) + inputs = _split_axis_custom(node.input(0), dim=dim, chunks=tuple(indmap)) # weight - weights = [node.inputs(1)] * num + weights = [node.input(1)] * num # bias - bias = [node.inputs(2)] * num + bias = [node.input(2)] * num # outputs - outputs = node.outputs(0).split_dim(dim, num) + outputs = node.output(0).split_dim(dim, num) # split W if (idx, dim) == (0, 3): # input and padding @@ -181,13 +181,13 @@ def instantiate(self, idx: int, dim: int, num: int): stop = start + chunkH - padb indmap.append((max(0, start), min(H, stop))) start = stop - dilation[0] * (dH - 1) - inputs = _split_axis_custom(node.inputs(0), dim=dim, chunks=tuple(indmap)) + inputs = _split_axis_custom(node.input(0), dim=dim, chunks=tuple(indmap)) # weight - weights = [node.inputs(1)] * num + weights = [node.input(1)] * num # bias - bias = [node.inputs(2)] * num + bias = [node.input(2)] * num # outputs - outputs = node.outputs(0).split_dim(dim, num) + outputs = node.output(0).split_dim(dim, num) sub_nodes = list() for i, w, b, pad, o in zip(inputs, weights, bias, pads, outputs): conv = IRConv2D(node.signature, [i, w, b], node.name, @@ -213,7 +213,7 @@ def __init__(self, node: IRConv3D): def satisfy(self, idx: int, dim: int, num: int): assert all(isinstance(t, int) for t in [idx, dim, num]), "idx, dim and num should be integer" node: IRConv3D = self.node - oD, oH, oW = node.outputs(0).shape[2:] + oD, oH, oW = node.output(0).shape[2:] stride = node.kwargs['stride'] dilation = node.kwargs['dilation'] if dim not in [2, 3]: @@ -236,9 +236,9 @@ def instantiate(self, idx: int, dim: int, num: int): if not self.satisfy(idx, dim, num): return None node: IRConv3D = self.node - D, H, W = node.inputs(0).shape[2:] - dD, dH, dW = node.inputs(1).shape[2:] - oD, oH, oW = node.outputs(0).shape[2:] + D, H, W = node.input(0).shape[2:] + dD, dH, dW = node.input(1).shape[2:] + oD, oH, oW = node.output(0).shape[2:] groups = node.kwargs['groups'] stride = node.kwargs['stride'] padding = node.kwargs['padding'] @@ -264,13 +264,13 @@ def instantiate(self, idx: int, dim: int, num: int): start = stop - dilation[0] * (dH - 1) # start = 0 if cid == 0 else 1023 # stop = 1025 if cid == 0 else H - inputs = _split_axis_custom(node.inputs(0), dim=dim+1, chunks=indmap) + inputs = _split_axis_custom(node.input(0), dim=dim+1, chunks=indmap) # weight - weights = [node.inputs(1)] * num + weights = [node.input(1)] * num # bias - bias = [node.inputs(2)] * num + bias = [node.input(2)] * num # outputs - outputs = node.outputs(0).split_dim(dim+1, num) + outputs = node.output(0).split_dim(dim+1, num) # split W if (idx, dim) == (0, 3): # input and padding @@ -290,13 +290,13 @@ def instantiate(self, idx: int, dim: int, num: int): stop = start + chunkH - padb + addone indmap.append((max(0, start), min(W, stop))) start = stop - dilation[0] * (dW - 1) - inputs = _split_axis_custom(node.inputs(0), dim=dim+1, chunks=indmap) + inputs = _split_axis_custom(node.input(0), dim=dim+1, chunks=indmap) # weight - weights = [node.inputs(1)] * num + weights = [node.input(1)] * num # bias - bias = [node.inputs(2)] * num + bias = [node.input(2)] * num # outputs - outputs = node.outputs(0).split_dim(dim+1, num) + outputs = node.output(0).split_dim(dim+1, num) sub_nodes = list() for i, w, b, pad, o in zip(inputs, weights, bias, pads, outputs): conv = IRConv3D(node.signature, [i, w, b], node.name, diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index c6a92563..49500565 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -39,13 +39,13 @@ def satisfy(self, idx: int, dim: int, num: int) -> bool: ninputs = len(node.inputs()) idx = idx if idx >= 0 else idx + ninputs assert idx < ninputs, f"index out of boundary: {idx} >= {ninputs}" - assert isinstance(node.inputs(idx), IRSubTensor), f"partitioning on a non-tensor input" - dim = dim if dim >= 0 else dim + node.inputs(idx).ndims - assert dim < node.inputs(idx).ndims, f"dimension output of boundary: {dim} >= {node.inputs(idx).ndims}" + assert isinstance(node.input(idx), IRSubTensor), f"partitioning on a non-tensor input" + dim = dim if dim >= 0 else dim + node.input(idx).ndims + assert dim < node.input(idx).ndims, f"dimension output of boundary: {dim} >= {node.input(idx).ndims}" # due to implementation limits, we only partition the first annotated dimension # for inner-dimension cases. - self._adim: str = node.anno.inputs(idx).dims[dim].identifiers[0] - self._reduce: DimAnno.ReduceType = node.anno.inputs(idx).dims[dim].reduces[0] + self._adim: str = node.anno.input(idx).dims[dim].identifiers[0] + self._reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] dimlen = node.anno.getlen(self._adim) if self._reduce == DimAnno.ReduceType.Freeze: return False @@ -66,7 +66,7 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: if not isinstance(itensor, IRSubTensor): ins.append([itensor] * num) continue - shape_anno = node.anno.inputs(iidx) + shape_anno = node.anno.input(iidx) split_dims = shape_anno.getdims(self._adim) assert len(split_dims) <= 1, f"find split dims ({self._adim}) more than 1: {shape_anno}" if len(split_dims) == 1: @@ -87,7 +87,7 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: if not isinstance(otensor, IRSubTensor): ous.append([otensor] * num) continue - shape_anno = node.anno.outputs(oidx) + shape_anno = node.anno.output(oidx) split_dims = shape_anno.getdims(self._adim) assert len(split_dims) <= 1, f"find split dims ({self._adim}) more than 1: {shape_anno}" # split axis @@ -147,14 +147,14 @@ def satisfy(self, idx: int, dimi: int, dimo: int, num: int) -> bool: assert idx == 0, f"Index should be 0" assert len(node.inputs()) == 1, f"Inputs size should be 1" assert len(node.outputs()) == 1, f"Outputs size should be 1" - dimi = dimi if dimi >= 0 else dimi + node.inputs(0).ndims - dimo = dimo if dimo >= 0 else dimo + node.outputs(0).ndims - assert dimi < node.inputs(0).ndims, f"dimension out of boundary: {dimi} >= {node.inputs(0).ndims}" - assert dimo < node.outputs(0).ndims, f"dimension out of boundary" + dimi = dimi if dimi >= 0 else dimi + node.input(0).ndims + dimo = dimo if dimo >= 0 else dimo + node.output(0).ndims + assert dimi < node.input(0).ndims, f"dimension out of boundary: {dimi} >= {node.input(0).ndims}" + assert dimo < node.output(0).ndims, f"dimension out of boundary" # # due to implementation limits, we only partition the first annotated dimension # # for inner-dimension cases. - self._adimi: str = node.anno.inputs(0).dims[dimi].identifiers[0] - self._adimo: str = node.anno.outputs(0).dims[dimo].identifiers[0] + self._adimi: str = node.anno.input(0).dims[dimi].identifiers[0] + self._adimo: str = node.anno.output(0).dims[dimo].identifiers[0] dimlen = node.anno.getlen(self._adimi) if dimlen < num: return False @@ -171,7 +171,7 @@ def instantiate(self, idx: int, dimi: int, dimo: int, num: int) -> Optional[List for iidx, itensor in enumerate(node.inputs()): if not isinstance(itensor, IRSubTensor): assert 0, "should not happen" - shape_anno = node.anno.inputs(iidx) + shape_anno = node.anno.input(iidx) split_dims = shape_anno.getdims(self._adimi) assert len(split_dims) <= 1, f"find split dims ({self._adimi}) more than 1: {shape_anno}" if len(split_dims) == 1: @@ -185,7 +185,7 @@ def instantiate(self, idx: int, dimi: int, dimo: int, num: int) -> Optional[List for oidx, otensor in enumerate(node.outputs()): if not isinstance(otensor, IRSubTensor): assert 0, f"should not happen" - shape_anno = node.anno.outputs(oidx) + shape_anno = node.anno.output(oidx) split_dims = shape_anno.getdims(self._adimo) assert len(split_dims) <= 1, f"find split dims ({self._adimo}) more than 1: {shape_anno}" # split axis diff --git a/cube/algorithm/ops/pad.py b/cube/algorithm/ops/pad.py index 856a6487..0358ae98 100644 --- a/cube/algorithm/ops/pad.py +++ b/cube/algorithm/ops/pad.py @@ -48,14 +48,14 @@ def satisfy(self, dim: int, num: int): pad_dim_count = len(pad) / 2 # split non-pad dim - if dim < len(node.inputs(0).shape) - pad_dim_count: - return node.inputs(0).shape[dim] >= num - # return node.inputs(0).shape[dim] % num == 0 + if dim < len(node.input(0).shape) - pad_dim_count: + return node.input(0).shape[dim] >= num + # return node.input(0).shape[dim] % num == 0 # split pad dim else: - dim_in_pad = len(node.inputs(0).shape) - 1 - dim - return (node.inputs(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) >= num - # return (node.inputs(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) % num == 0 + dim_in_pad = len(node.input(0).shape) - 1 - dim + return (node.input(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) >= num + # return (node.input(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) % num == 0 def instantiate(self, dim: int, num: int): if not self.satisfy(dim, num): @@ -71,20 +71,20 @@ def instantiate(self, dim: int, num: int): subnodes = list() # split non-pad dim - if dim < len(node.inputs(0).shape) - pad_dim_count: - inputs = node.inputs(0).split_dim(dim, num) - outputs = node.outputs(0).split_dim(dim, num) + if dim < len(node.input(0).shape) - pad_dim_count: + inputs = node.input(0).split_dim(dim, num) + outputs = node.output(0).split_dim(dim, num) for i, o in zip(inputs, outputs): subnodes.append(node.new([i], [o])) else: # split pad dim - inputs = node.inputs(0).split_dim(dim, num) + inputs = node.input(0).split_dim(dim, num) slicers = list() pads = list() - dim_in_pad = len(node.inputs(0).shape) - 1 - dim + dim_in_pad = len(node.input(0).shape) - 1 - dim global_padl = pad[dim_in_pad * 2] global_padr = pad[dim_in_pad * 2 + 1] - chunk_size = (node.outputs(0).shape[dim] - global_padl - global_padr) // num - addone_num = (node.outputs(0).shape[dim] - global_padl - global_padr) % num + chunk_size = (node.output(0).shape[dim] - global_padl - global_padr) // num + addone_num = (node.output(0).shape[dim] - global_padl - global_padr) % num start = 0 for cid in range(num): padl = global_padl if cid == 0 else 0 @@ -97,10 +97,10 @@ def instantiate(self, dim: int, num: int): addone = int(cid < addone_num) stop = start + padl + padr + chunk_size + addone - slicers.append((max(0, start), min(node.outputs(0).shape[dim], stop))) + slicers.append((max(0, start), min(node.output(0).shape[dim], stop))) start = stop - outputs = _split_axis_custom(node.outputs(0), dim, tuple(slicers)) + outputs = _split_axis_custom(node.output(0), dim, tuple(slicers)) for i, o, p in zip(inputs, outputs, pads): subnodes.append(node.new([i], [o], pad=p)) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 955dcecb..3d7884b0 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -432,7 +432,7 @@ def later_ref(tensor, node) -> bool: else: finputs = ref_node.mirror.inputs() foutputs = ref_node.mirror.outputs() - grad_in = [t.grad for t in foutputs] + grad_in = tuple(t.grad for t in foutputs) if tensor in finputs + foutputs + grad_in: return True else: diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index 8a6ae7d2..e2f930c8 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -44,7 +44,7 @@ def emit_slice(node, arg_vars:list, kw_pairs:dict) -> str: but at the frontend such an invocation must be rewritten as 'x[:, l:h:s, :, :]' depending on the 'input's rank and the 'dim' value. """ - out_tensors : list = node.outputs() + out_tensors : tuple = node.outputs() assert len(out_tensors) == 1 out_tensor : IRTensor = out_tensors[0] diff --git a/cube/graph/function/cat.py b/cube/graph/function/cat.py index 374e314a..72b97799 100644 --- a/cube/graph/function/cat.py +++ b/cube/graph/function/cat.py @@ -1,6 +1,6 @@ from copy import copy import itertools -from typing import List +from typing import List, Tuple from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor @@ -27,7 +27,7 @@ def infer_shape(self) -> bool: # validation # TODO how about zero inputs? - tensors : List[IRTensor] = self.inputs(None) # None for all inputs + tensors : Tuple[IRTensor, ...] = self.inputs() # None for all inputs # Shape without the dim-th component s0 : list = None @@ -48,7 +48,7 @@ def infer_shape(self) -> bool: sumLen : int = sum(t.shape[dim] for t in tensors) s0.insert(dim, sumLen) - self.outputs(0).shape = s0 + self.output(0).shape = s0 return True @@ -64,7 +64,7 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, dim: int): def infer_shape(self) -> bool: dim = self.kwargs['dim'] - tensors : List[IRTensor] = self.inputs(None) # None for all inputs + tensors : Tuple[IRTensor, ...] = self.inputs() # None for all inputs # `stack` requires all input tensors to have the same shape if len(set(t.shape for t in tensors)) != 1: @@ -72,6 +72,6 @@ def infer_shape(self) -> bool: shp : list = tensors[0].shape.copy() shp.insert(dim, len(tensors)) - self.outputs(0).shape = shp + self.output(0).shape = shp return True diff --git a/cube/graph/function/conv.py b/cube/graph/function/conv.py index f115f8fa..8b2357d1 100644 --- a/cube/graph/function/conv.py +++ b/cube/graph/function/conv.py @@ -20,20 +20,20 @@ def infer_shape(self) -> bool: """ Output shape inference given the input shapes """ - if len(self.inputs(0).shape) == 0 or len(self.inputs(1).shape) == 0: + if len(self.input(0).shape) == 0 or len(self.input(1).shape) == 0: return False - N = self.inputs(0).shape[0] - iH, iW = self.inputs(0).shape[2:4] - oC = self.inputs(1).shape[0] + N = self.input(0).shape[0] + iH, iW = self.input(0).shape[2:4] + oC = self.input(1).shape[0] stride = self.kwargs['stride'] padding = self.kwargs['padding'] dilation = self.kwargs['dilation'] - dH = self.inputs(1).shape[2] - dW = self.inputs(1).shape[3] + dH = self.input(1).shape[2] + dW = self.input(1).shape[3] oH = (iH + padding[0] + padding[1] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 oW = (iW + padding[2] + padding[3] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 shape = [N, oC, oH, oW] - self.outputs(0).shape = shape + self.output(0).shape = shape return True def new(self, inputs: List, outputs: List): @@ -70,26 +70,26 @@ def infer_shape(self) -> bool: """ Output shape inference given the input shapes """ - if len(self.inputs(0).shape) == 0 or len(self.inputs(1).shape) == 0: + if len(self.input(0).shape) == 0 or len(self.input(1).shape) == 0: return False - N = self.inputs(0).shape[0] - iC = self.inputs(0).shape[1] - iT, iH, iW = self.inputs(0).shape[2:5] + N = self.input(0).shape[0] + iC = self.input(0).shape[1] + iT, iH, iW = self.input(0).shape[2:5] - oC = self.inputs(1).shape[0] + oC = self.input(1).shape[0] stride = self.kwargs['stride'] padding = self.kwargs['padding'] dilation = self.kwargs['dilation'] - dT = self.inputs(1).shape[2] - dH = self.inputs(1).shape[3] - dW = self.inputs(1).shape[4] + dT = self.input(1).shape[2] + dH = self.input(1).shape[3] + dW = self.input(1).shape[4] oT = (iT + 2 * padding[0] - dilation[0] * (dT - 1) - 1) // stride[0] + 1 oH = (iH + 2 * padding[1] - dilation[1] * (dH - 1) - 1) // stride[1] + 1 oW = (iW + 2 * padding[2] - dilation[2] * (dW - 1) - 1) // stride[2] + 1 shape = [N, oC, oT, oH, oW] - self.outputs(0).shape = shape + self.output(0).shape = shape return True def new(self, inputs: List, outputs: List): diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index c822a0bd..e42467e9 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -16,14 +16,14 @@ def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType # Customize output's dtype only after 'super().__init__' and 'self.set_input', # otherwise it gets overwritten. - self.outputs(0).dtype = ir_dtype + self.output(0).dtype = ir_dtype # The positional argument to specify the shape is actually called 'size'. self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) def infer_shape(self) -> bool: shape : list = copy(self.kwargs["size"]) - self.outputs(0).shape = shape + self.output(0).shape = shape return True class IROnes(IRFwOperation): @@ -37,14 +37,14 @@ def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType # Customize output's dtype only after 'super().__init__' and 'self.set_input', # otherwise it gets overwritten. - self.outputs(0).dtype = ir_dtype + self.output(0).dtype = ir_dtype # The positional argument to specify the shape is actually called 'size'. self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) def infer_shape(self) -> bool: shape : list = copy(self.kwargs["size"]) - self.outputs(0).shape = shape + self.output(0).shape = shape return True #class IRNewTensor(IRFwOperation): @@ -68,12 +68,12 @@ def __init__(self, signature: str, inputs, name:str, ir_dtype:IRDType): # Customize output's dtype only after 'super().__init__' and 'self.set_input', # otherwise it gets overwritten. - self.outputs(0).dtype = ir_dtype + self.output(0).dtype = ir_dtype self.kwargs.update({"dtype": ir_dtype}) def infer_shape(self) -> bool: - self.outputs(0).shape = self.inputs(0).shape + self.output(0).shape = self.input(0).shape return True diff --git a/cube/graph/function/customops.py b/cube/graph/function/customops.py index c8572d22..4ccd206c 100644 --- a/cube/graph/function/customops.py +++ b/cube/graph/function/customops.py @@ -40,19 +40,19 @@ def infer_shape(self) -> bool: Output shape inference given the input shapes """ if self.signature.endswith('strip_2_borders'): - if len(self.inputs(0).shape) == 0: + if len(self.input(0).shape) == 0: return False - shape = self.inputs(0).shape + shape = self.input(0).shape shape[0] = shape[0]-2 - self.outputs(0).shape = shape + self.output(0).shape = shape return True elif self.signature.endswith('update_diag_'): - shape = self.inputs(0).shape - self.outputs(0).shape = shape + shape = self.input(0).shape + self.output(0).shape = shape return True elif self.signature.endswith('update_geopotential_'): - shape = self.inputs(0).shape - self.outputs(0).shape = shape + shape = self.input(0).shape + self.output(0).shape = shape return True else: raise RuntimeError(f'IRCustomOps::infer_shape unknown signature: {self.signature}') diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index c9e5e28c..c205dd99 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -314,20 +314,12 @@ def identifiers(self) -> Set[str]: """ return tuple(self._identifiers.keys()) - def inputs(self, index: Optional[int] = None) -> Union[ShapeAnno, Tuple[ShapeAnno]]: - """! - Get shape annotation of index-th input. - If index is None, will return all shape annotations + def input(self, index:int) -> ShapeAnno: + assert index < len(self._inputs), "index out of boundary" + return self._inputs[index] - @param index Optional[int]: the index of input. - - @return shape_annos Union[ShapeAnno, Tuple[ShapeAnno]]: the shape annotation - """ - assert index is None or index < len(self._inputs), "index out of boundary" - if index is None: - return self._inputs - else: - return self._inputs[index] + def inputs(self) -> Tuple[ShapeAnno, ...]: + return self._inputs def set_input(self, index: int, shape_anno: Union[str, ShapeAnno]): """ @@ -339,12 +331,12 @@ def set_input(self, index: int, shape_anno: Union[str, ShapeAnno]): inputs[index] = shape_anno if isinstance(shape_anno, ShapeAnno) else ShapeAnno(shape_anno) self._inputs = tuple(inputs) - def outputs(self, index: Optional[int] = None) -> Union[ShapeAnno, Tuple[ShapeAnno]]: - assert index is None or index < len(self._outputs), "index out of boundary" - if index is None: - return self._outputs - else: - return self._outputs[index] + def output(self, index:int) -> ShapeAnno: + assert index < len(self._outputs), "index out of boundary" + return self._outputs[index] + + def outputs(self) -> Tuple[ShapeAnno, ...]: + return self._outputs def set_output(self, index: int, shape_anno: Union[str, ShapeAnno]): """ @@ -610,7 +602,7 @@ def align(self, inputs: List[IRTensor], op_anno: OpAnno, **kwargs) -> bool: expand_dims = [] if ndims > 0: expand_dims = list(DimAnno(candicates[dim] + reduce) for dim in range(ndims)) - shape_anno = list(op_anno.inputs(idx).dims[:pos]) + expand_dims + list(op_anno.inputs(idx).dims[pos+1:]) + shape_anno = list(op_anno.input(idx).dims[:pos]) + expand_dims + list(op_anno.input(idx).dims[pos+1:]) shape_anno = ShapeAnno(tuple(shape_anno)) op_anno.set_input(idx, shape_anno) # * should appear in inputs @@ -620,7 +612,7 @@ def align(self, inputs: List[IRTensor], op_anno: OpAnno, **kwargs) -> bool: names = [dim_anno.name for dim_anno in shape_anno.dims] if '*' in names: pos = names.index('*') - shape_anno = list(op_anno.outputs(idx).dims[:pos]) + expand_dims + list(op_anno.outputs(idx).dims[pos+1:]) + shape_anno = list(op_anno.output(idx).dims[:pos]) + expand_dims + list(op_anno.output(idx).dims[pos+1:]) shape_anno = ShapeAnno(tuple(shape_anno)) op_anno.set_output(idx, shape_anno) op_anno.reset_identifiers() diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index df6427cb..21fee65a 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -612,12 +612,12 @@ def Reshape(signature, inputs): # https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html?highlight=torch%20conv2d#torch.nn.functional.conv2d # """ # def adapt(anno: OpAnno, node: IRDimops) -> OpAnno: -# iH, iW = node.inputs(0).shape[2:4] +# iH, iW = node.input(0).shape[2:4] # stride = node.kwargs['stride'] # padding = node.kwargs['padding'] # dilation = node.kwargs['dilation'] -# dH = node.inputs(1).shape[2] -# dW = node.inputs(1).shape[3] +# dH = node.input(1).shape[2] +# dW = node.input(1).shape[3] # oH = (iH + 2 * padding[0] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 # oW = (iW + 2 * padding[1] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 # anno.outputs[0][2] = DimAnno([str(oH)]) diff --git a/cube/graph/function/pad.py b/cube/graph/function/pad.py index 0d46be5e..ee50cb8a 100644 --- a/cube/graph/function/pad.py +++ b/cube/graph/function/pad.py @@ -20,20 +20,20 @@ def infer_shape(self) -> bool: """ Output shape inference given the input shapes """ - if len(self.inputs(0).shape) == 0: + if len(self.input(0).shape) == 0: return False - N = self.inputs(0).shape[0] + N = self.input(0).shape[0] pad = self.kwargs['pad'] mode = self.kwargs['mode'] value = self.kwargs['value'] assert len(pad) % 2 == 0, "IRPad::infer_shape len(pad) % 2 == 0" - shape = self.inputs(0).shape + shape = self.input(0).shape for pad_idx, pad_size in enumerate(pad): shape[-1 - (pad_idx // 2)] += pad_size - self.outputs(0).shape = shape + self.output(0).shape = shape return True def new(self, inputs: List, outputs: List, pad = None): diff --git a/cube/graph/function/repeat.py b/cube/graph/function/repeat.py index 0d2e650f..ef0dbf18 100644 --- a/cube/graph/function/repeat.py +++ b/cube/graph/function/repeat.py @@ -19,7 +19,7 @@ def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, repeats:Li self.kwargs.update({"repeats": repeats}) def infer_shape(self) -> bool: - shp_self : List[int] = self.inputs(0).shape + shp_self : List[int] = self.input(0).shape if len(shp_self) == 0: return False @@ -35,6 +35,6 @@ def infer_shape(self) -> bool: shp = [d1 * d2 for d1, d2 in itertools.zip_longest(s1, s2, fillvalue=1)] shp.reverse() - self.outputs(0).shape = shp + self.output(0).shape = shp return True diff --git a/cube/graph/function/scatter.py b/cube/graph/function/scatter.py index 43d503db..14c531b9 100644 --- a/cube/graph/function/scatter.py +++ b/cube/graph/function/scatter.py @@ -39,11 +39,11 @@ def __init__(self, signature: str, inputs:Tuple[IRTensor, IRTensor], name: str, self.kwargs.update({"dim": dim, "index": index}) def infer_shape(self) -> bool: - shp_self : List[int] = self.inputs(0).shape + shp_self : List[int] = self.input(0).shape if len(shp_self) == 0: return False - shp_input = self.inputs(1).shape + shp_input = self.input(1).shape if len(shp_input) == 0: print("The 0-length input shape is ambiguous, may be uninferrable or just of a 0-d tensor") @@ -55,6 +55,6 @@ def infer_shape(self) -> bool: raise RuntimeError(f"self shape {shp_self} and input shape {shp_input} with dim={dim} mismatch") s2 = copy(shp_self) - self.outputs(0).shape = s2 + self.output(0).shape = s2 return True diff --git a/cube/graph/function/scripteinops.py b/cube/graph/function/scripteinops.py index fc078fef..29b017ac 100644 --- a/cube/graph/function/scripteinops.py +++ b/cube/graph/function/scripteinops.py @@ -24,7 +24,7 @@ def infer_shape(self) -> bool: """ Output shape inference given the input shapes """ - if len(self.inputs(0).shape) == 0: + if len(self.input(0).shape) == 0: return False recipe_str = self.kwargs['recipe_str'] @@ -32,9 +32,9 @@ def infer_shape(self) -> bool: recipe = pickle.loads(recipe_str) reduction_type = self.kwargs['reduction_type'] - tmp_tensor = torch.zeros(self.inputs(0).shape) + tmp_tensor = torch.zeros(self.input(0).shape) tmp_output = _apply_recipe(recipe, tmp_tensor, reduction_type) - self.outputs(0).shape = list(tmp_output.shape) + self.output(0).shape = list(tmp_output.shape) return True def new(self, inputs: List, outputs: List): diff --git a/cube/graph/function/select.py b/cube/graph/function/select.py index d21096e2..ba5d8b8d 100644 --- a/cube/graph/function/select.py +++ b/cube/graph/function/select.py @@ -16,7 +16,7 @@ def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, dim:int, i self.kwargs.update({"dim": dim, "index": index}) def infer_shape(self) -> bool: - s : List[int] = self.inputs(0).shape + s : List[int] = self.input(0).shape if len(s) == 0: return False @@ -24,7 +24,7 @@ def infer_shape(self) -> bool: s2 = copy(s) s2.pop(dim) - self.outputs(0).shape = s2 + self.output(0).shape = s2 return True @@ -43,7 +43,7 @@ def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, self.kwargs.update({"dim": dim, "start": start, "end": end, "step": step}) def infer_shape(self) -> bool: - s : List[int] = self.inputs(0).shape + s : List[int] = self.input(0).shape if len(s) == 0: return False @@ -70,7 +70,7 @@ def clip(offset): sliced_dim_len = len(range(start, end, step)) s2 = s.copy() s2[dim] = sliced_dim_len - self.outputs(0).shape = s2 + self.output(0).shape = s2 return True diff --git a/cube/graph/graph.py b/cube/graph/graph.py index b9b23ee4..7a3ecbe7 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -380,7 +380,7 @@ def get_inputs(nodes: List[IRCell]): """ all_outputs = list() for node in nodes: - all_outputs += node.outputs() + all_outputs += list(node.outputs()) inputs = list() for cell in nodes: for input in cell.inputs(): @@ -405,7 +405,7 @@ def get_outputs(nodes: List[IRCell]): """ all_inputs = list() for node in nodes: - all_inputs += node.inputs() + all_inputs += list(node.inputs()) outputs = list() for node in nodes: for idx, output in enumerate(node.outputs()): diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 9ea0b216..a3dcc862 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -238,10 +238,10 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: for output in node.outputs(): if isinstance(output.type(), torch._C.TupleType): tuplen = len(output.type().elements()) - ir_output = [ir_node.outputs(idx) for idx in range(cnt, cnt+tuplen)] + ir_output = [ir_node.output(idx) for idx in range(cnt, cnt+tuplen)] cnt += tuplen else: - ir_output = ir_node.outputs(cnt) + ir_output = ir_node.output(cnt) cnt += 1 frame.add_var(output.debugName(), ir_output) @@ -312,7 +312,7 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: # handle outputs for index, output in enumerate(outputs): - frame.add_var(output.debugName(), ir_node.outputs(index)) + frame.add_var(output.debugName(), ir_node.output(index)) return [ir_node] diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index f85861a2..b43835cf 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -20,19 +20,17 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor], **kwar self.kwargs[arg] = val self.signature = None - def inputs(self, idx: Optional[int] = None): - assert idx is None or isinstance(idx, int), "expected idx to be None or int" - if idx is None: - return copy.copy(self._inputs) - else: - return self._inputs[idx] - - def outputs(self, idx: Optional[int] = None): - assert idx is None or isinstance(idx, int), "expected idx to be None or int" - if idx is None: - return copy.copy(self._outputs) - else: - return self._outputs[idx] + def input(self, idx:int): + return self._inputs[idx] + + def inputs(self): + return copy.copy(self._inputs) + + def output(self, idx:int): + return self._outputs[idx] + + def outputs(self): + return copy.copy(self._outputs) def dispatch(self, devid: int): if devid not in self.device: @@ -100,7 +98,7 @@ def __init__(self, itensor: IRSubTensor): self.signature = 'cube.runtime.adapter.identity' def __repr__(self): - dscp = f"{self.outputs(0)} = identity({self.inputs(0)})" + dscp = f"{self.output(0)} = identity({self.input(0)})" return dscp @@ -117,7 +115,7 @@ def __init__(self, self.signature = f"cube.runtime.adapter.select" def __repr__(self): - dscp = f"{self.outputs(0)} = select({self.inputs(0)}, indmap={self.kwargs['indmap']}, valmap={self.kwargs['valmap']})" + dscp = f"{self.output(0)} = select({self.input(0)}, indmap={self.kwargs['indmap']}, valmap={self.kwargs['valmap']})" return dscp @@ -131,7 +129,7 @@ def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, dim: int) self.signature = 'cube.runtime.adapter.smerge' def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs(0)} = concat({self.inputs()}, dim={self.kwargs['dim']})" + return f"dev{self.device}: {self.output(0)} = concat({self.inputs()}, dim={self.kwargs['dim']})" # numerical primitive @@ -143,7 +141,7 @@ def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): self.signature = 'cube.runtime.adapter.vmerge' def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs(0)} = add({self.inputs()})" + return f"dev{self.device}: {self.output(0)} = add({self.inputs()})" # communication primitive @@ -156,7 +154,7 @@ def __init__(self, tensor, dst: int): self.signature = 'cube.runtime.adapter.send' def __repr__(self) -> str: - return f"{self.inputs(0)} = send({self.inputs(0)}, dst={self.kwargs['dst']}" + return f"{self.input(0)} = send({self.input(0)}, dst={self.kwargs['dst']}" class RecvPrim(CommPrim): @@ -169,7 +167,7 @@ def __init__(self, tensor: IRSubTensor, src: int): self.signature = 'cube.runtime.adapter.recv' def __repr__(self) -> str: - return f"{self.outputs(0)} = recv(shape={self.kwargs['shape']}, dtype={self.kwargs['dtype']}, src={self.kwargs['src']}" + return f"{self.output(0)} = recv(shape={self.kwargs['shape']}, dtype={self.kwargs['dtype']}, src={self.kwargs['src']}" class MovePrim(CommPrim): @@ -182,13 +180,13 @@ def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor): def dispatch(self, devid: int) -> Union[SendPrim, RecvPrim]: if devid == self.kwargs['src']: - return SendPrim(self.inputs(0), self.kwargs['dst']) + return SendPrim(self.input(0), self.kwargs['dst']) if devid == self.kwargs['dst']: - return RecvPrim(self.outputs(0), self.kwargs['src']) + return RecvPrim(self.output(0), self.kwargs['src']) return None def __repr__(self): - dscp = f"move({self.inputs(0)}, src={self.kwargs['src']}, dst={self.kwargs['dst']})" + dscp = f"move({self.input(0)}, src={self.kwargs['src']}, dst={self.kwargs['dst']})" return dscp diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 764bf93f..e4adf4b1 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -14,7 +14,7 @@ """ -from typing import List, Tuple, Union, Optional, Any +from typing import Iterable, List, Tuple, Union, Optional, Any import copy from cube.ir.unique import IDGenerator @@ -54,12 +54,12 @@ def __init__(self, self._device = list() # source tensors - self._inputs: List[Any] = [None] * input_length + self._inputs: Tuple[Optional[IRTensor], ...] = (None,) * input_length # destination tensors - self._outputs: List[IRTensor] = [None] * output_length + self._outputs: Tuple[Optional[IRTensor], ...] = (None,) * output_length if init_outputs: - self._outputs: List[IRTensor] = [IRTensor() for _ in range(output_length)] + self._outputs = tuple(IRTensor() for _ in range(output_length)) for tensor in self._outputs: tensor.cell = self @@ -130,28 +130,31 @@ def on_device(self, device_id: int): raise TypeError(f"Expected device id to be int but got {type(device_id)}") return device_id in self.device - def inputs(self, index: Optional[int] = None) -> Union[List[Any], Any]: + def input(self, index:int): + # type: (int) -> Optional[IRTensor] """ - Get input tensor at input index + Get the input tensor at input index Args: - index (int or None): - index of the inputs, None will return the nodes - for all the inputs + index (int): + index of the inputs Returns: - values: Union[List[Any], Any] + values: Optional[IRTensor] """ - if isinstance(index, int): - if index >= len(self._inputs): - raise RuntimeError( - f"Get the input out of range ({index} >= {len(self._inputs)}" - ) - return self._inputs[index] - elif index is None: - return copy.copy(self._inputs) - else: - raise TypeError("Expected index to be None or int") + return self._inputs[index] + + def inputs(self): + # type: () -> Tuple[Optional[IRTensor], ...] + """ + Get all input tensors + + Returns: + values: Tuple[Optional[IRTensor], ...] + """ + + # self._inputs is a tuple and is immutable. + return self._inputs def predecessors(self, index: Optional[int] = None) -> List: """ @@ -175,28 +178,31 @@ def predecessors(self, index: Optional[int] = None) -> List: else: raise TypeError("Expected index to be None or int") - def outputs(self, index: Optional[int] = None) -> Union[List[Any], Any]: + def output(self, index:int): + # type: (int) -> Optional[IRTensor] """ - Get output tensor at output index + Get the output tensor at output index Args: - index (int or None): - index of the outputs, None will return the nodes - for all the outputs + index (int): + index of the outputs Returns: - values: Union[List[Any], Any] + values: Optional[IRTensor] """ - if isinstance(index, int): - if index >= len(self._outputs): - raise RuntimeError( - f"Get the output out of range ({index} >= {len(self._outputs)}" - ) - return self._outputs[index] - elif index is None: - return copy.copy(self._outputs) - else: - raise TypeError("Expected index to be None or int") + return self._outputs[index] + + def outputs(self): + # type: () -> Tuple[Optional[IRTensor], ...] + """ + Get all output tensors + + Returns: + values: Tuple[Optional[IRTensor], ...] + """ + + # self._outputs is a tuple and is immutable. + return self._outputs def successors(self, index: Optional[int] = None) -> List: """ @@ -221,19 +227,21 @@ def successors(self, index: Optional[int] = None) -> List: else: raise TypeError("Expected index to be None or int") - def set_input(self, input_index: int, val: Any): + def set_input(self, input_index: int, val): + # type: (int, Optional[IRTensor]) -> Optional[IRTensor] """ Set the node inputs[input_index] with the tensor Args: - val: Union[IRTensor, Any] + val: Optional[IRTensor] Return: the set tensor """ - if input_index >= len(self.inputs()): + c = len(self._inputs) + if input_index >= c or input_index < -c: raise RuntimeError( - f"Set the input out of range ({input_index} >= {len(self._inputs)})" + f"Set the input out of range ({input_index} >= {c} or {input_index} < {-c})" ) if isinstance(val, IRTensor): # copy the val @@ -247,27 +255,36 @@ def set_input(self, input_index: int, val: Any): if isinstance(output, IRTensor): output.dtype = self._dtype val.dtype = self._dtype - self._inputs[input_index] = val + + l = list(self._inputs) + l[input_index] = val + self._inputs = tuple(l) + return val - def set_output(self, output_index: int, val: Any): + def set_output(self, output_index: int, val): + # type: (int, Optional[IRTensor]) -> Optional[IRTensor] """ Set the node inputs[output_index] with the tensor Args: - val: Union[IRTensor, Any] + val: Optional[IRTensor] IRTensor or any deterministic value (int, bool, str, etc) """ - if output_index >= len(self.outputs()): + c = len(self._outputs) + if output_index >= c or output_index < -c: raise RuntimeError( - f"Set the input out of range ({output_index} >= {len(self._inputs)})" + f"Set the input out of range ({output_index} >= {c} or {output_index} < {-c})" ) if isinstance(val, IRTensor): val = copy.copy(val) val.cell = self # set output value dtype val.dtype = self._dtype - self._outputs[output_index] = val + + l = list(self._outputs) + l[output_index] = val + self._outputs = tuple(l) return val def add_predecessor(self, input_index: int, cell): @@ -298,7 +315,7 @@ def clear_predecessor(self): def add_successor(self, output_index: int, cell): """ Set self node the output index node. - `node` will take the self.outputs(index) as the input + `node` will take the self.output(index) as the input To add control dependency, use `output_index=-1` """ @@ -326,6 +343,7 @@ def make_empty(self): @staticmethod def get_inputs(cells): + # type: (Iterable[IRCell]) -> list[IRCell] """ Get all the input tensors the is not generated by nodes @@ -336,7 +354,7 @@ def get_inputs(cells): """ all_outputs = list() for cell in cells: - all_outputs += cell.outputs() + all_outputs += list(cell.outputs()) inputs = list() for cell in cells: for input in cell.inputs(): @@ -348,6 +366,7 @@ def get_inputs(cells): @staticmethod def get_outputs(cells): + # type: (Iterable[IRCell]) -> list[IRCell] """ Get all the input tensors the is not generated by nodes @@ -356,7 +375,7 @@ def get_outputs(cells): """ all_inputs = list() for node in cells: - all_inputs += node.inputs() + all_inputs += list(node.inputs()) outputs = list() for node in cells: for output in node.outputs(): diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 24834c4f..a7bce616 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -124,10 +124,10 @@ def replicate(self): cpy = copy.copy(self) cpy._device = list() # reset input and output - cpy._inputs = [None] * len(self.inputs()) + cpy._inputs = (None,) * len(self.inputs()) for idx, input in enumerate(self.inputs()): cpy.set_input(idx, input) - cpy._outputs = [None] * len(self.outputs()) + cpy._outputs = (None,) * len(self.outputs()) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None @@ -213,10 +213,10 @@ def replicate(self): cpy._device = list() cpy._id = IDGenerator().gen_cell_id() # reset input and output - cpy._inputs = [None] * len(self.inputs()) + cpy._inputs = (None,) * len(self.inputs()) for idx, input in enumerate(self.inputs()): cpy.set_input(idx, input) - cpy._outputs = [None] * len(self.outputs()) + cpy._outputs = (None,) * len(self.outputs()) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None @@ -248,10 +248,10 @@ def replicate(self): cpy._device = list() cpy._id = IDGenerator().gen_cell_id() # reset input and output - cpy._inputs = [None] * len(self.inputs()) + cpy._inputs = (None,) * len(self.inputs()) for idx, input in enumerate(self.inputs()): cpy.set_input(idx, input) - cpy._outputs = [None] * len(self.outputs()) + cpy._outputs = (None,) * len(self.outputs()) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None diff --git a/cube/logics/model.py b/cube/logics/model.py index 2eb8c059..d0b29f71 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Tuple import copy from cube.graph.graph import IRGraph @@ -15,7 +15,7 @@ def forward(graph: IRGraph, *args) -> IRGraph: if not isinstance(graph, IRGraph): raise TypeError("Requires IRGraph for forward") # align graph with input tensors - itensors: List[IRSubTensor] = graph.inputs() + itensors: Tuple[IRSubTensor, ...] = graph.inputs() for idx, (itensor, arg) in enumerate(zip(itensors, args)): graph.set_input(idx, arg) for producer in copy.copy(itensor.parent.producers): diff --git a/examples/wrf/policy/onedim.py b/examples/wrf/policy/onedim.py index a0d0dde5..7c45d8e6 100644 --- a/examples/wrf/policy/onedim.py +++ b/examples/wrf/policy/onedim.py @@ -42,13 +42,13 @@ def PAS_ALL_TEST(graph: IRGraph, resource): sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) elif isinstance(node, IRDimops): sign = node.signature.split('.')[-1] - if (sign == 'mul' or sign == 'add' or sign == 'sub' or sign == 'div') and (len(node.inputs(0).shape) == 5 or len(node.inputs(0).shape) == 3): + if (sign == 'mul' or sign == 'add' or sign == 'sub' or sign == 'div') and (len(node.input(0).shape) == 5 or len(node.input(0).shape) == 3): algo = node.algorithms('dim') - if len(node.inputs(0).shape) == 3: + if len(node.input(0).shape) == 3: sub_nodes = graph.partition(node, algo, idx=0, dim=1, num=resource.ngpus) if sub_nodes == None: sub_nodes = graph.replicate(node, times=resource.ngpus) - elif len(node.inputs(0).shape) == 5: + elif len(node.input(0).shape) == 5: sub_nodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus) if sub_nodes == None: sub_nodes = graph.replicate(node, times=resource.ngpus) @@ -56,13 +56,13 @@ def PAS_ALL_TEST(graph: IRGraph, resource): print('partition view') print(node) algo = node.algorithms('view_simp') - sub_nodes = graph.partition(node, algo, idx=0, dimi=node.inputs(0).ndims-2, dimo=node.outputs(0).ndims-2, num=resource.ngpus) + sub_nodes = graph.partition(node, algo, idx=0, dimi=node.input(0).ndims-2, dimo=node.output(0).ndims-2, num=resource.ngpus) print(sub_nodes) else: sub_nodes = graph.replicate(node, times=resource.ngpus) elif isinstance(node, IRPad): algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.inputs(0).ndims-2, num=resource.ngpus) + sub_nodes = graph.partition(node, algo, dim=node.input(0).ndims-2, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): @@ -81,7 +81,7 @@ def PAS_ALL_X(graph: IRGraph, resource): sub_nodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus) elif isinstance(node, IRDimops): if sign in ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat']: - ndims = node.inputs(0).ndims + ndims = node.input(0).ndims algo = node.algorithms('dim') if ndims == 3 or ndims == 5: sub_nodes = graph.partition(node, algo, idx=0, dim=ndims-1, num=resource.ngpus) @@ -91,15 +91,15 @@ def PAS_ALL_X(graph: IRGraph, resource): sub_nodes = graph.replicate(node, times=resource.ngpus) elif sign == 'view': algo = node.algorithms('view_simp') - if node.inputs(0).ndims >= 3 and node.outputs(0).ndims >= 3: - sub_nodes = graph.partition(node, algo, idx=0, dimi=node.inputs(0).ndims-1, dimo=node.outputs(0).ndims-1, num=resource.ngpus) + if node.input(0).ndims >= 3 and node.output(0).ndims >= 3: + sub_nodes = graph.partition(node, algo, idx=0, dimi=node.input(0).ndims-1, dimo=node.output(0).ndims-1, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) elif isinstance(node, IRPad): algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.inputs(0).ndims-1, num=resource.ngpus) + sub_nodes = graph.partition(node, algo, dim=node.input(0).ndims-1, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): @@ -118,7 +118,7 @@ def PAS_ALL_Y(graph: IRGraph, resource): sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) elif isinstance(node, IRDimops): if sign in ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat']: - ndims = node.inputs(0).ndims + ndims = node.input(0).ndims algo = node.algorithms('dim') if ndims == 3 or ndims == 5: sub_nodes = graph.partition(node, algo, idx=0, dim=ndims-2, num=resource.ngpus) @@ -128,15 +128,15 @@ def PAS_ALL_Y(graph: IRGraph, resource): sub_nodes = graph.replicate(node, times=resource.ngpus) elif sign == 'view': algo = node.algorithms('view_simp') - if node.inputs(0).ndims >= 3 and node.outputs(0).ndims >= 3: - sub_nodes = graph.partition(node, algo, idx=0, dimi=node.inputs(0).ndims-2, dimo=node.outputs(0).ndims-2, num=resource.ngpus) + if node.input(0).ndims >= 3 and node.output(0).ndims >= 3: + sub_nodes = graph.partition(node, algo, idx=0, dimi=node.input(0).ndims-2, dimo=node.output(0).ndims-2, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) elif isinstance(node, IRPad): algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.inputs(0).ndims-2, num=resource.ngpus) + sub_nodes = graph.partition(node, algo, dim=node.input(0).ndims-2, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): diff --git a/tests/test_parser.py b/tests/test_parser.py index 7ef1b874..8923261d 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -33,7 +33,7 @@ def test_optional_dtype_none(aten_op, ir_op_cls): ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) for node in ir_nodes: if isinstance(node, ir_op_cls): - assert node.outputs(0).dtype == DType2IRDType.map(torch.get_default_dtype()) + assert node.output(0).dtype == DType2IRDType.map(torch.get_default_dtype()) @pytest.mark.parametrize( "aten_op, ir_op_cls", @@ -58,4 +58,4 @@ def test_optional_dtype_underlying_int(aten_op, ir_op_cls): ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) for node in ir_nodes: if isinstance(node, ir_op_cls): - assert node.outputs(0).dtype == IRDType.int32 + assert node.output(0).dtype == IRDType.int32 diff --git a/tests/test_prim_loop.py b/tests/test_prim_loop.py index 005fcfb1..575d7342 100644 --- a/tests/test_prim_loop.py +++ b/tests/test_prim_loop.py @@ -90,7 +90,7 @@ def test_unroll_with_structural_info(): assert in_val_1 == p assert in_val_2 == p - out_val = ir_node.outputs(0) + out_val = ir_node.output(0) assert out_val.shape == [ 2**(i+2) , 3] p = out_val @@ -150,7 +150,7 @@ def __init__(self) -> None: assert in_val_1 == p assert in_val_2 == p - out_val = ir_node.outputs(0) + out_val = ir_node.output(0) assert out_val.shape == [ 2**(i+2) , 3] p = out_val \ No newline at end of file From ebd474f61cb0115b67f7ed153de2ca7ba49c5c42 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 19 Jul 2022 11:06:16 +0800 Subject: [PATCH 0921/1892] refine grouping --- cube/execplan/planpass/grouping.py | 130 ++++++----------------------- 1 file changed, 25 insertions(+), 105 deletions(-) diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index d3cbca65..f974f49d 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -1,38 +1,13 @@ """ Operation grouping """ -import os -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Tuple from cube.execplan import ExecutionPlan from cube.execplan.planpass.planpass import PlanPass from cube.ir.adapter import IRAdapter -from cube.ir.operator import IRBpOperation, IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.cten import IRCell - -SCIENTIFIC_COMPUTING = 'SCIENTIFIC_COMPUTING' -_use_new_grouping_algo:Optional[bool] = None - -def _set_use_new_grouping_algo(use_new_grouping_algo:Optional[bool]) -> None: - """ - Set the internal flag whether to use a new grouping algorithm which is faster for grouping forward-only graphs, - especially for workloads from scientific-computing domains. - - Parameters: - - use_new_grouping_algo (bool): - 'True' to force the use of the new grouping algorithm. - 'False' to force the use of the old grouping algorithm. - 'None' to use the new grouping algorithm if the environment variable 'SCIENTIFIC_COMPUTING' exists. - """ - assert use_new_grouping_algo is None or isinstance(use_new_grouping_algo, bool) - global _use_new_grouping_algo - _use_new_grouping_algo = use_new_grouping_algo - -def _get_use_new_grouping_algo() -> bool: - if _use_new_grouping_algo is None: - return SCIENTIFIC_COMPUTING in os.environ - else: - return _use_new_grouping_algo class Grouping(PlanPass): @@ -71,92 +46,37 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: Returns: Tuple: (fgroups, bgroups) """ + def is_forward_node(fnode): + if isinstance(fnode, IRFwOperation): + return True + if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.forward: + return True + return False + fgroups, bgroups = dict(), dict() for devid in execplan.devices(): fgroups[devid], bgroups[devid] = list(), list() - fpieces, bpieces = list(), list() seq = execplan.seq(devid) - fnodes = [] - - def is_forward_node(fnode): - if isinstance(fnode, IRFwOperation): - return True - if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.forward: - return True - return False + fnodes = [node for node in seq if is_forward_node(node)] + have_backward = all(fnode.mirror in seq for fnode in fnodes) + fpieces = [] for fnode in seq: if is_forward_node(fnode): - fnodes.append(fnode) - have_backward = all(fnode.mirror in seq for fnode in fnodes) - # training - if have_backward: - bnodes = [fnode.mirror for fnode in fnodes] - for fnode, bnode in zip(fnodes + [-1], bnodes + [-1]): - fconsecutive = Grouping.consecutive(seq, fpieces, fnode) - bconsecutive = Grouping.consecutive(seq, bpieces, bnode) - if fconsecutive and bconsecutive: - fpieces.append(fnode) - bpieces.insert(0, bnode) - else: - if len(fpieces) != 0: - fgroups[devid].append(fpieces) - bgroups[devid].append(bpieces) - fpieces, bpieces = [fnode], [bnode] - # inference - else: - if _get_use_new_grouping_algo(): - - for fnode in seq: - if is_forward_node(fnode): - fpieces.append(fnode) - else: - if len(fpieces) != 0: - fgroups[devid].append(fpieces) - bgroups[devid].append(None) - - # If the fnode is not a "forward node", e.g. it's DataOp node, don't add it into the group. - fpieces = [] - # 'bpieces' is never filled or returned in the inference mode - + fpieces.append(fnode) + else: if len(fpieces) != 0: fgroups[devid].append(fpieces) - bgroups[devid].append(None) - - else: # Not using new algo - - for fnode in fnodes + [-1]: - fconsecutive = Grouping.consecutive(seq, fpieces, fnode) - if fconsecutive: - fpieces.append(fnode) - bpieces.append(None) - else: - if len(fpieces) != 0: - fgroups[devid].append(fpieces) - bgroups[devid].append(None) - fpieces, bpieces = [fnode], [None] - - return fgroups, bgroups + fpieces = [] - @staticmethod - def consecutive(seq: List[IRCell], pieces: List[IRCell], node: IRCell): - """ - Check whether the piecies with new node - is consecutive in the sequence. + if len(fpieces) != 0: + fgroups[devid].append(fpieces) - Assume all the node in pieces will apear in seq. - If node not in the sequence, will return False. - """ - if len(pieces) == 0: - return True - if node not in seq: - return False - idx = seq.index(node) - pidx = [seq.index(pnode) for pnode in pieces] - # check whether pieces is consecutive - if max(pidx) - min(pidx) != len(pidx) - 1: - return False - # check whether new node adding new node is consecutive - if idx != max(pidx) + 1 and idx != min(pidx) - 1: - return False - return True + for pieces in fgroups[devid]: + if have_backward: + bpieces = [fnode.mirror for fnode in pieces[::-1] if fnode.mirror is not None] + bgroups[devid].append(bpieces) + else: + bgroups[devid].append(None) + + return fgroups, bgroups From 29ddce8d720aa4350ef1de6435ee9870d9d2f3b3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Jul 2022 10:03:41 +0800 Subject: [PATCH 0922/1892] add recompute mechanism --- cube/codegen/codegen.py | 372 ++++++++++++++++++++++++--------------- cube/graph/graph.py | 4 +- cube/runtime/executor.py | 112 ++++++------ 3 files changed, 289 insertions(+), 199 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 3d7884b0..dc2ddc29 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,7 +1,8 @@ """ Generate Pytorch code given the model DAG and the transformation config """ -from typing import Dict, List, Any, Tuple, Union +from sys import prefix +from typing import Dict, List, Any, Tuple, Union, Optional import torch import copy from cube.graph.parser.mapping import Sign2Op @@ -20,7 +21,6 @@ from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock -from cube.codegen.register import VarManager from cube.codegen.frontend_mapping import Sign2EmitRule @@ -41,31 +41,79 @@ def dtype_map(self, dtype: IRDType) -> str: def node_naming(self, node: IRCell) -> str: return f"{node.name}{node._id}" - def tensor_naming(self, tensor: Any) -> str: + def tensor_naming(self, tensor: Any, prefix_attr: Optional[str] = None) -> str: """ - Return the var name (unique for different variable) + Return the var name. + For tensor, return the {prefix}{tensor.name}_{tensor.tid} + For non-tensor, return its string + + @param tensor Any: any value + @attr_prefix Optional[str]: prefix for a attributed tensor + + @return str """ if isinstance(tensor, IRTensor): tensor_name = tensor.name if '.' in tensor_name: tensor_name = tensor_name.split('.')[0] - name = '_'.join([tensor_name, str(tensor._id)]) + name = '_'.join([tensor_name, str(tensor.tid)]) + if prefix_attr is not None and tensor.is_param(): + name = prefix_attr + name else: name = str(tensor) return name - def tuple_naming(self, tensors: List[Any]) -> str: - tensors = [self.tensor_naming(t) for t in tensors] - tensors = '(' + ', '.join(tensors + ['']) + ')' - return tensors + def tuple_naming(self, tensors: List[Any], + skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: + """ + Return the tupled tensor name. + + @param tensors List[Any]: list of any value + @param skip_attr bool: whether to skip graph attribute in the tensors + @param prefix_attr bool: whether to add a prefix for graph attribute - def return_naming(self, tensors: List[Any]) -> str: - tensors = [self.tensor_naming(t) for t in tensors] - if len(tensors) == 0: - tensors = '_' - else: - tensors = ', '.join(tensors) - return tensors + @return name str: the tupled tensor name + """ + names = [] + for t in tensors: + if isinstance(t, IRTensor) and skip_attr and t.is_param(): + continue + names.append(self.tensor_naming(t, prefix_attr)) + name = '(' + ', '.join(names + ['']) + ')' + return name + + def return_naming(self, tensors: List[Any], + skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: + """ + Return the tensors in return format, i.e. tupled name without brackets. + + @param tensors List[Any]: list of any value + @param skip_attr bool: whether to skip graph attribute in the tensors + @param prefix_attr bool: whether to add a prefix for graph attribute + + @return name str: the tupled tensor name + """ + names = [] + for t in tensors: + if isinstance(t, IRTensor) and skip_attr and t.is_param(): + continue + names.append(self.tensor_naming(t, prefix_attr)) + names = '_' if len(names) == 0 else ', '.join(names) + return names + + def kwargs_naming(self, **kwargs) -> str: + """ + Return the kwarg naming, connected by ', ' + + @param kwargs Dict[str, Any]: kwargs + + @return name str + """ + names = [] + for name, val in kwargs.items(): + names.append(f'{name}={val}') + name = ', '.join(names) + return name class AutogradAdapterCodeGen(CodeGen): @@ -136,11 +184,12 @@ def __init__(self, execplan: ExecutionPlan): # model full code self.init_code: List[str] = [ '\n\n########## Generated Model Code ###########', - 'import torch', 'import cube', '', ''] + 'import torch', 'import torch.utils.checkpoint as ckpt', + 'import cube', '', ''] # customized op code for _, op_impl in Sign2Op.kOpCodeDef.items(): self.init_code.append(op_impl) - self.init_code += ['', ''] + self.init_code += [''] # module init code self.declare_region: List[str] = list() # module forward code @@ -207,23 +256,22 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # parse graph body for node in self.execplan.seq(device): if isinstance(node, IRSegment): - # skip backward ir graph - if not node.forward: - continue - self.emit_segment_call(node) + if not node.forward: continue # skip backward segment + codes = self.emit_segment_code(node) elif isinstance(node, IRFwOperation): - self.emit_op_call(node) + codes = self.emit_op_code(node) elif isinstance(node, IRAdapter): - self.emit_adapter_call(node) + codes = self.emit_adapter_code(node) elif isinstance(node, IRWeightReducer): self.emit_reducer_init(node) - self.emit_reducer_call(node) + codes = self.emit_reducer_call(node) elif isinstance(node, IRBpOperation): continue elif isinstance(node, IRDataOperation): continue else: raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") + self.forward_region += codes # emit node tensor declaration self.emit_node_declare(node) # emit node code @@ -276,7 +324,7 @@ def emit_node_declare(self, node: IRCell): """ sign = 'torch.nn.Parameter(torch.empty({shape}, dtype={dtype}))' for input in node.inputs(): - name = self.tensor_naming(input) + name = self.tensor_naming(input, prefix_attr='self.') if isinstance(input, IRTensor): if input.is_param() and not self.symbols.exist(name): self.symbols.create(name) @@ -287,30 +335,94 @@ def emit_node_declare(self, node: IRCell): if not hasattr(self._ref_module, name[5:]): raise NotImplementedError("member attribute is not added") for output in node.outputs(): - self.symbols.create(self.tensor_naming(output)) + self.symbols.create(self.tensor_naming(output, prefix_attr='self.')) return - def emit_segment_call(self, graph: IRGraph): + def emit_segment_code(self, node: IRSegment) -> List[str]: """ Emit IRSegment code + + Nodes in the segment will group into recompute region """ - for node in graph.nodes(): + + codes = [] + + def emit_nodes(nodes: List[IRCell]) -> List[str]: + node_codes = [] + for node in nodes: + if isinstance(node, IRFwOperation): + code = self.emit_op_code(node) + elif isinstance(node, IRAdapter): + code = self.emit_adapter_code(node) + else: + raise RuntimeError(f"unexpected type {type(node)} in IRSegment") + node_codes += code + return node_codes + + def emit_rc_nodes(nodes: List[IRCell]) -> List[str]: + node_codes = [] + subseg = self.execplan.graph.segment(nodes) + inputs = [self.tensor_naming(t) for t in subseg.inputs() if not t.is_param()] + inputs_tup = ', '.join(inputs) + outputs = [self.tensor_naming(t) for t in subseg.outputs()] + outputs = ', '.join(outputs) + with FunctionBlock('recompute', inputs, False) as fb: + for ncode in emit_nodes(nodes): + fb.insert_body(ncode) + fb.insert_body(f'return {outputs}') + node_codes += [''] + fb.code + [''] + node_codes.append( + f'{outputs} = ckpt.checkpoint(recompute, {inputs_tup})' + ) + return node_codes + + # to recompute region + curr_recompute_gid = None + curr_nodes = [] + for node in node.nodes(): if isinstance(node, IRFwOperation): - self.emit_op_call(node) + if node.recompute != curr_recompute_gid: + if len(curr_nodes) != 0: + if curr_recompute_gid is None: + codes += emit_nodes(curr_nodes) + else: + codes += emit_rc_nodes(curr_nodes) + curr_recompute_gid = node.recompute + curr_nodes = [node] + else: + curr_nodes.append(node) elif isinstance(node, IRAdapter): - self.emit_adapter_call(node) + # strategy 1: recompute close before adapter communication + if curr_recompute_gid is None: + curr_nodes.append(node) + else: + if len(curr_nodes) != 0: + if curr_recompute_gid is None: + codes += emit_nodes() + else: + codes += emit_rc_nodes() + else: + curr_recompute_gid = None + curr_nodes = [node] + + if len(curr_nodes) != 0: + if curr_recompute_gid is None: + codes += emit_nodes(curr_nodes) else: - raise RuntimeError(f"unexpected type {type(node)} in forward graph:\n{graph.extra_repr()}") + codes += emit_rc_nodes(curr_nodes) + + return codes - def emit_op_call(self, node: IRFwOperation): + def emit_op_code(self, node: IRFwOperation) -> List[str]: """ Emit op forward code """ + codes = [] # insert comment if node.comment is not None: - self.forward_region.append(f'# {node.comment}') + codes.append(f'# {node.comment}') signature = node.signature - inputs = [self.tensor_naming(t) for t in node.inputs()] + inputs = [self.tensor_naming(t, prefix_attr='self.') for t in node.inputs()] kwargs = {} for key in node.kwargs: val = node.kwargs[key] @@ -327,26 +439,26 @@ def emit_op_call(self, node: IRFwOperation): outputs = [self.tensor_naming(t) for t in node.outputs()] outputs = ', '.join(outputs) code = f'{outputs} = {body}' - self.forward_region.append(code) + codes.append(code) + return codes - def emit_adapter_call(self, node: IRAdapter): + def emit_adapter_code(self, node: IRAdapter) -> List[str]: """ Emit adapter call """ + codes = [] assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" prims = [node] if node.differentiable and node.custom else [prim for prim in node.prims] for prim in prims: if len(prim.inputs()) == 1: - itensors = self.tensor_naming(prim.inputs()[0]) + itensors = self.tensor_naming(prim.inputs()[0], prefix_attr='self.') else: - itensors = self.tuple_naming(prim.inputs()) - kwargs = list() - for name, val in prim.kwargs.items(): - kwargs.append(f'{name}={val}') - kwargs = ', '.join(kwargs) + itensors = self.tuple_naming(prim.inputs(), prefix_attr='self.') + kwargs = self.kwargs_naming(**prim.kwargs) outputs = self.return_naming(prim.outputs()) code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' - self.forward_region.append(code) + codes.append(code) + return codes def emit_reducer_init(self, node: IRWeightReducer): # reducer init interface @@ -359,7 +471,7 @@ def emit_reducer_init(self, node: IRWeightReducer): self.declare_region.append('') init_code = reducer_init.format(reducer=reducer_name, ranks=node.device) self.declare_region.append(init_code) - weights = [self.tensor_naming(t) for t in weights] + weights = [self.tensor_naming(t, prefix_attr='self.') for t in weights] for weight in weights: add_param_code = add_param.format(reducer=reducer_name, weight=weight) self.declare_region.append(add_param_code) @@ -368,20 +480,8 @@ def emit_reducer_init(self, node: IRWeightReducer): def emit_reducer_call(self, node: IRWeightReducer): reducer_name = f'self.wreducer{node._id}' - call_code = f'{reducer_name}.allreduce()' - self.forward_region.append(call_code) - - def tensor_naming(self, tensor: Any): - """ - Generate tensor name. - - Will add prefix 'self.' for parameters - """ - name = super().tensor_naming(tensor) - if isinstance(tensor, IRSubTensor): - if tensor.is_param(): - name = 'self.' + name - return name + code = f'{reducer_name}.allreduce()' + return [code] def clear(self): """ @@ -406,39 +506,46 @@ def __init__(self, execplan: ExecutionPlan): 'import torch', 'import cube', ''] # module member name self.symbols = SymbolTable() - self.vars = VarManager() def gen(self, device: int, outfile=None, attach=False) -> str: """ Generate scheduling code based on the given sus """ gencode = copy.copy(self.init_code) - self.vars = VarManager() device_nodes = self.execplan.seq(device) - def later_ref(tensor, node) -> bool: + def removable(node) -> bool: """ - check whether the output tensor of the node need to be later used. + Get removable tensor lists that will not be used after the execution of the node + + @param node IRCell """ + inputs = [t for t in node.inputs() if isinstance(t, IRSubTensor) and not t.is_param()] idx = device_nodes.index(node) - if tensor in self.execplan.graph.outputs(): - return True - for ref_node in device_nodes[idx+1:]: - if isinstance(ref_node, IRSegment): - if ref_node.forward: - if tensor in ref_node.inputs(): - return True + remove = [] + for itensor in inputs: + free = True + for ref_node in device_nodes[idx+1:]: + if isinstance(ref_node, IRSegment): + if ref_node.forward: + if itensor in ref_node.inputs(): + free = False + break + else: + input_tensors = ref_node.mirror.inputs() + output_tensors = ref_node.mirror.outputs() + output_grads = tuple(t.grad for t in output_tensors) + if itensor in input_tensors + output_tensors + output_grads: + free = False + break else: - finputs = ref_node.mirror.inputs() - foutputs = ref_node.mirror.outputs() - grad_in = tuple(t.grad for t in foutputs) - if tensor in finputs + foutputs + grad_in: - return True - else: - if tensor in ref_node.inputs(): - return True - return False + if itensor in ref_node.inputs(): + free = False + break + if free: + remove.append(itensor) + return remove with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: @@ -455,11 +562,11 @@ def later_ref(tensor, node) -> bool: code = self.emit_node(node, name=name) fb.insert_body(code) # free unused tensor - for tensor in node.inputs() + node.outputs(): - if isinstance(tensor, IRSubTensor) and not tensor.is_param(): - refcnt = later_ref(tensor, node) - if refcnt == 0: - self.vars.free(tensor) + removable_tensors = removable(node) + if len(removable_tensors) > 0: + nones = ', '.join(['None'] * len(removable_tensors)) + code = f'{self.return_naming(removable_tensors)} = {nones}' + fb.insert_body(code) # return code outputs = self.return_naming(self.execplan.graph.outputs()) code = f'return {outputs}' @@ -489,45 +596,42 @@ def emit_node(self, node: IRCell, name: str) -> str: """ Emit node / subgraph code """ - fsign = 'cube.runtime.executor.fexecute({model}, *{inputs}, requires_grad={req_grad})' - bsign = 'cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' + fsign = '{outputs} = cube.runtime.executor.fexecute({model}, *{inputs}, requires_grad={req_grad})' + bsign = '{input_grads} = cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' - inputs = [self.tensor_naming(t) for t in node.inputs() if not t.is_param()] - outputs = [self.tensor_naming(t) for t in node.outputs()] + inputs = self.tuple_naming(node.inputs(), skip_attr=True, prefix_attr='model.') + outputs = self.return_naming(node.outputs(), skip_attr=True, prefix_attr='model.') req_grad = any(t.requires_grad for t in node.outputs() if isinstance(t, IRTensor)) - inputs = self.tuple_naming(inputs) - outputs = self.return_naming(outputs) if isinstance(node, IRSegment): # emit forward if node.forward: - recompute = any(isinstance(n.recompute, int) for n in node.nodes()) - if recompute: - raise NotImplementedError("recompute mechanism is not supported") - body = fsign.format(model=f'model.{name}', inputs=inputs, req_grad=req_grad) - code = f'{outputs} = {body}' + code = fsign.format( + outputs = outputs, + model = f'model.{name}', + inputs = inputs, + req_grad = req_grad + ) # emit backward else: - finputs = [t for t in node.mirror.inputs() if t.requires_grad] - foutputs = node.mirror.outputs() - inputs = [t.grad for t in foutputs] - for idx, itensor in enumerate(inputs): - if isinstance(itensor, float): - assert itensor == 1.0, "Loss gradient should be 1.0" - inputs[idx] = None - outputs = [t.grad for t in finputs] - # remove weight gradient in outputs - for input in finputs: - if input.is_param(): - outputs.remove(input.grad) - finputs = self.tuple_naming(finputs) - foutputs = self.tuple_naming(foutputs) - inputs = self.tuple_naming(inputs) - outputs = self.return_naming(outputs) - body = bsign.format( - input_tensors=finputs, output_tensors=foutputs, output_grads=inputs + input_tensors = [t for t in node.mirror.inputs() if \ + isinstance(t, IRSubTensor) and \ + t.requires_grad and \ + not t.is_param() + ] + output_tensors = [t for t in node.mirror.outputs() if isinstance(t, IRSubTensor)] + input_grads = [t.grad for t in input_tensors] + output_grads = [t.grad for t in output_tensors] + for idx, tensor in enumerate(output_grads): + if isinstance(tensor, float): + assert tensor == 1.0, "Loss gradient should be 1.0" + output_grads[idx] = None + code = bsign.format( + input_grads = self.return_naming(input_grads), + input_tensors = self.tuple_naming(input_tensors, skip_attr=True, prefix_attr='model.'), + output_tensors = self.tuple_naming(output_tensors, skip_attr=True, prefix_attr='model.'), + output_grads = self.tuple_naming(output_grads, skip_attr=True, prefix_attr='model.') ) - code = f'{outputs} = {body}' elif isinstance(node, IRDataOperation): if len(node.inputs()) != 0: @@ -537,37 +641,21 @@ def emit_node(self, node: IRCell, name: str) -> str: code = f'{outputs} = next(dataloader)' elif isinstance(node, IRAdapter): - body = fsign.format(model=f'model.{name}', inputs=inputs, req_grad=req_grad) - code = f'{outputs} = {body}' + code = fsign.format( + outputs = outputs, + model = f'model.{name}', + inputs = inputs, + req_grad = req_grad + ) elif isinstance(node, IRWeightReducer): - body = fsign.format(model=f'model.{name}', inputs='()', req_grad=req_grad) - code = f'{outputs} = {body}' + code = fsign.format( + outputs = outputs, + model=f'model.{name}', + inputs='()', + req_grad=req_grad + ) else: raise RuntimeError(f"Unspported node type: {type(node)}") return code - - def tuple_naming(self, tensors: List[Any]) -> str: - tensors = [self.tensor_naming(t) for t in tensors] - tensors = '(' + ', '.join(tensors + ['']) + ')' - return tensors - - def return_naming(self, tensors: List[Any]) -> str: - tensors = [self.tensor_naming(t) for t in tensors] - if len(tensors) == 0: - tensors = '_' - else: - tensors = ', '.join(tensors) - return tensors - - def tensor_naming(self, tensor: Union[IRSubTensor, Any]): - """ - Generate tensor name. - - Will add prefix 'model.' for parameters - """ - if isinstance(tensor, IRSubTensor) and tensor.is_param(): - return 'model.' + self.vars.allocate(tensor) - else: - return self.vars.allocate(tensor) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 7a3ecbe7..a35343fa 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -484,7 +484,7 @@ def from_logic_graph(nodes: List[IRCell], graph = IRGraph(nodes, inputs, outputs, module_name) return graph - ##### Partition Primitives ##### + ##### Transformation Primitives ##### def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> List[IRCell]: """ @@ -564,6 +564,8 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], self.attach(fnode, findex + idx) if isinstance(node.comment, str): fnode.comment = node.comment + if isinstance(node, IRFwOperation): + fnode.recompute = node.recompute # update backward if isinstance(node.mirror, IRBpOperation): bindex = self.detach(node.mirror) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 07b5d184..fb384508 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -19,52 +19,9 @@ def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) return outputs -def backward(input_tensors : List[torch.Tensor], - output_tensors: List[torch.Tensor], - output_tensor_grads: List[torch.Tensor]): - """ - Backward Procedure. - - input_tensors: List[torch.Tensor]: - tensors that their gradient need to be computed, including parameters. - Correspoinding forward input tensors. - - output_tensors: - tensors that start for gradient backward computation. - Corresponding to forward output tensors. - - output_tensor_grads: - gradient tensors corresponding to output_tensors. - - Returns: - gradient in order of non-parameter tensors in input_tensors. - (Note parameter tnesors already have gradient accumulated at .grad attribute) - """ - if len(input_tensors) == 0: - return None - grads = list() - in_grads = torch.autograd.grad( - outputs = output_tensors, - inputs = input_tensors, - grad_outputs = output_tensor_grads, - allow_unused=True - ) - for tensor, grad in zip(input_tensors, in_grads): - if isinstance(tensor, torch.nn.Parameter): - if tensor.grad is not None: - tensor.grad += grad - else: - tensor.grad = grad - else: - grads.append(grad) - if len(grads) == 0: return None - elif len(grads) == 1: return grads[0] - else: return tuple(grads) - - -# def backward(input_tensors: List[torch.Tensor], +# def backward(input_tensors : List[torch.Tensor], # output_tensors: List[torch.Tensor], -# output_tensor_grads: List[torch.Tensor]) -> Tuple[torch.Tensor]: +# output_tensor_grads: List[torch.Tensor]): # """ # Backward Procedure. # @@ -83,23 +40,66 @@ def backward(input_tensors : List[torch.Tensor], # gradient in order of non-parameter tensors in input_tensors. # (Note parameter tnesors already have gradient accumulated at .grad attribute) # """ -# if len(output_tensors) == 0: +# if len(input_tensors) == 0: # return None -# inputs = list() -# for input_ in input_tensors: -# if torch.is_tensor(input_) and not isinstance(input_, torch.nn.Parameter): -# if input_.requires_grad: -# input_.retain_grad() -# inputs.append(input_) -# torch.autograd.backward( -# output_tensors, -# grad_tensors=output_tensor_grads, +# grads = list() +# in_grads = torch.autograd.grad( +# outputs = output_tensors, +# inputs = input_tensors, +# grad_outputs = output_tensor_grads, +# allow_unused=True # ) -# grads = tuple(input_.grad for input_ in inputs) +# for tensor, grad in zip(input_tensors, in_grads): +# if isinstance(tensor, torch.nn.Parameter): +# if tensor.grad is not None: +# tensor.grad += grad +# else: +# tensor.grad = grad +# else: +# grads.append(grad) # if len(grads) == 0: return None # elif len(grads) == 1: return grads[0] # else: return tuple(grads) + +def backward(input_tensors: List[torch.Tensor], + output_tensors: List[torch.Tensor], + output_tensor_grads: List[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Backward Procedure. + + input_tensors: List[torch.Tensor]: + tensors that their gradient need to be computed, including parameters. + Correspoinding forward input tensors. + + output_tensors: + tensors that start for gradient backward computation. + Corresponding to forward output tensors. + + output_tensor_grads: + gradient tensors corresponding to output_tensors. + + Returns: + gradient in order of non-parameter tensors in input_tensors. + (Note parameter tnesors already have gradient accumulated at .grad attribute) + """ + if len(output_tensors) == 0: + return None + inputs = list() + for input_ in input_tensors: + if torch.is_tensor(input_) and not isinstance(input_, torch.nn.Parameter): + if input_.requires_grad: + input_.retain_grad() + inputs.append(input_) + torch.autograd.backward( + output_tensors, + grad_tensors=output_tensor_grads, + ) + grads = tuple(input_.grad for input_ in inputs) + if len(grads) == 0: return None + elif len(grads) == 1: return grads[0] + else: return tuple(grads) + ### =================== Experimental Feature ======================= # import queue From 2f8ecb64e0e4aa237a574dac743c1f780223270a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Jul 2022 10:19:34 +0800 Subject: [PATCH 0923/1892] replicate for recompute --- cube/codegen/codegen.py | 5 ++--- cube/ir/operator.py | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index dc2ddc29..e6c94541 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,8 +1,7 @@ """ Generate Pytorch code given the model DAG and the transformation config """ -from sys import prefix -from typing import Dict, List, Any, Tuple, Union, Optional +from typing import Dict, List, Any, Tuple, Optional import torch import copy from cube.graph.parser.mapping import Sign2Op @@ -14,7 +13,7 @@ from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from cube.ir.adapter import IRWeightReducer, IRAdapter from cube.ir.adapter.prim import CollectivePrim, IRAdapterPrim -from cube.graph.graph import IRGraph, IRSegment +from cube.graph.graph import IRSegment from cube.graph.schedule import IRScheduleStrategy from cube.execplan import ExecutionPlan diff --git a/cube/ir/operator.py b/cube/ir/operator.py index a7bce616..00e82332 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -131,6 +131,7 @@ def replicate(self): for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None + cpy.recompute = self.recompute cpy.clear_predecessor() cpy.clear_successor() return cpy From 47a618bb191ba35166e9ab2d68292a9f87a4ddc5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Jul 2022 13:56:09 +0800 Subject: [PATCH 0924/1892] loading content support --- cube/codegen/codegen.py | 21 +++++++++++++++------ cube/compiler.py | 9 +++++---- cube/graph/parser/frame.py | 18 +++++++++++++++++- cube/graph/parser/parser.py | 2 ++ cube/runtime/module.py | 29 +++++++++++++++++++++++++---- 5 files changed, 64 insertions(+), 15 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index e6c94541..ab34ab74 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -322,14 +322,23 @@ def emit_node_declare(self, node: IRCell): Emit tensor declaration code """ sign = 'torch.nn.Parameter(torch.empty({shape}, dtype={dtype}))' - for input in node.inputs(): - name = self.tensor_naming(input, prefix_attr='self.') - if isinstance(input, IRTensor): - if input.is_param() and not self.symbols.exist(name): + map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" + for itensor in node.inputs(): + name = self.tensor_naming(itensor, prefix_attr='self.') + if isinstance(itensor, IRSubTensor): + if itensor.is_param() and not self.symbols.exist(name): self.symbols.create(name) - code = f'{name} = {sign.format(shape=tuple(input.shape), dtype=self.dtype_map(input.dtype))}' + code = f'{name} = {sign.format(shape=tuple(itensor.shape), dtype=self.dtype_map(itensor.dtype))}' self.declare_region.append(code) - if isinstance(input, str): + tid = itensor.parent.tid + slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) + val_chunks = itensor.valmap[1] + code = map_sign.format( + attr=self.tensor_naming(itensor), tid=tid, + slicers=str(slicers), val_chunks=val_chunks + ) + self.declare_region.append(code) + if isinstance(itensor, str): if name.startswith('self.'): if not hasattr(self._ref_module, name[5:]): raise NotImplementedError("member attribute is not added") diff --git a/cube/compiler.py b/cube/compiler.py index 92b49d83..26996ff1 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -39,15 +39,16 @@ def __init__(self, model: torch.nn.Module, input_shapes): def get_graph(self): return self.ir_graph - def load_module(self, filename: str): + def load_module(self, filename: str, load_content=True): import importlib.util spec = importlib.util.spec_from_file_location("GenModel", filename) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) self._loaded_module = module.GenModel().cuda() - self._loaded_module.init_param() - # sync parameters before start training - self._loaded_module.sync_params() + if load_content: + print_each_rank("> loading parameter content...") + # TODO: make hardcode ./fullmodel.pt programmable + self._loaded_module.load_attr_content('./fullmodel.pt') def get_gen_module(self): return self._loaded_module diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index 9350ef4b..40ec379b 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -1,5 +1,6 @@ from collections import OrderedDict -from typing import List, Any +from typing import List, Any, Dict +import torch class Frame: @@ -13,6 +14,7 @@ def __init__(self): self._var_stack: List[str] = list() # module attributes self._attributes: List[dict[str, Any]] = list() + self._attr_vals: Dict[int, Any] = dict() # tensor tid to real value mapping def push_var(self, inherit_from_top=False): """ @@ -126,6 +128,20 @@ def has_attr(self, name: str) -> bool: """ return name in self._attributes[-1] + def add_attr_content(self, tid: int, val: torch.Tensor): + """ + Add module attribute content + """ + if torch.is_tensor(val): + val = val.cpu() + self._attr_vals[tid] = val + + def save_attr_content(self, save_file: str = 'fullmodel.pt'): + """ + Save attribute content into file. + """ + torch.save(self._attr_vals, save_file) + def push_param(self, var_name): """ push var name to the method stack diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index a3dcc862..28c50c4f 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -95,6 +95,7 @@ def parse_module(module, frame.pop_var() frame.pop_attr() + frame.save_attr_content() return input_val, all_ir_nodes, output_val @staticmethod @@ -415,6 +416,7 @@ def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: warnings.warn('Detected non-parameter tensor as graph attribute. Regard them as parameters') ir_tensor.as_param() frame.add_attr(label, ir_tensor) + frame.add_attr_content(ir_tensor.tid, tensor) frame.add_var(var_name, ir_tensor) # symbolic attributes elif dtype in ['bool', 'int', 'float']: diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 46d69a36..792627a6 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict, Tuple import torch from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer @@ -13,6 +13,7 @@ class CubeModule(torch.nn.Module): def __init__(self): super().__init__() self._reducers = list() + self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() def add_reducer(self, reducer: Reducer): if not isinstance(reducer, Reducer): @@ -23,9 +24,29 @@ def sync_params(self): for reducer in self._reducers: reducer.sync() - def init_param(self): - for param in self.parameters(): - torch.nn.init.uniform_(param) + def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: int): + """ + Add an attribute map. + The mapping includes current attribute name (str) to logical tensor id, + and the mapping of logical tensor id including spatial (slice) and val chunks + + @param attr str: attribute name of this moudle + @param tid int: full tensor id + @param slicers Tuple[slice]: indexing from full tensor + @param val_chunks int: the number of value chunks. + """ + assert hasattr(self, attr), f"{attr} is not in the module" + self._fullmap[attr] = (tid, slicers, val_chunks) + + def load_attr_content(self, filename: str): + with torch.no_grad(): + full = torch.load(filename) + for attr in self._fullmap.keys(): + tensor: torch.Tensor = getattr(self, attr) + tid, slicers, nchunks = self._fullmap[attr] + content = full[tid][slicers] / nchunks + tensor.copy_(content) + # print(f'attr {attr}:\n{getattr(self, attr)}') def init_group(self, ranks: List[int]): if not all([isinstance(rank, int) for rank in ranks]): From 34313ddd9bbc09c728ae4d6bb28d3e3e1486c7eb Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Fri, 22 Jul 2022 12:23:49 +0000 Subject: [PATCH 0925/1892] Merged PR 1397: Accelerate lifetime calculation for scheduling code in codegen TODO this calculation of lifetime is now only applied to sheduling code, but it's ok to apply it to model/segment code too. --- cube/codegen/codegen.py | 151 +++++++++++++++++++++++++++++----------- 1 file changed, 112 insertions(+), 39 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index ab34ab74..ea0e0207 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -1,7 +1,8 @@ """ Generate Pytorch code given the model DAG and the transformation config """ -from typing import Dict, List, Any, Tuple, Optional +import itertools +from typing import Dict, Generator, Iterable, List, Any, Optional, Set, Tuple, Union import torch import copy from cube.graph.parser.mapping import Sign2Op @@ -23,6 +24,104 @@ from cube.codegen.frontend_mapping import Sign2EmitRule +# TODO this could be applied in all codegen: Segments, Adapters, Schedules... +# TODO but Schedules don't have standalone inputs, while Segments and Adapters have and +# those inputs may need to be released too. +def calc_tenvars_lifetime( + nodes:Iterable[IRCell], subgraph_outputs:Iterable[IRTensor] + ) -> Dict[IRTensor, int]: + """ + Calculate the lifetime of tensor variables ahead-of-time. + So that during schedule the GC on those variables can take place in time. + + E.g. at what timings may a tensor variable O_i be discarded (i.e. no longer referred)? + ``` + ..., O_i, O_j, ... = f(I_1, ..., I_M) + # Case 1, immediately, because it's never used + O_i = None + # Case 2, after some invocation, because it's no longer referred + ... = g(..., O_j, ...) + O_j = None + ``` + + Returns: `Dict[IRTensor, int]` + + For each kv-pair `(t, i)` it indicates the last reference of tensor `t` + is at the `i`-th (0-based) node's inputs, + i.e. the variable for tensor `t` could be released *AFTER* the `i`-th statement + in codegen. + + If a tensor is not included in the dict, it means that tensor is never referenced. + + TODO If an input of the subgraph is never used, its corresponding `i` is `-1`. + + REMARK: + + We cannot `detele O_j` because it may be a variable alias and the tensor object + behind is still active (e.g. `runtime.multiref`). + So we just set it to `None`, decrement the reference count + and let Python (the codegen target) decide the deletion. + """ + + lifetime : Dict[IRTensor, int] = dict() + + def is_temp_tensor(v): + return isinstance(v, IRSubTensor) and not v.is_param() + + #lifetime.update((tsin, -1) for tsin in subgraph_inputs if is_temp_tensor(tsin)) + + for i, node in enumerate(nodes): + # aggressively mark all outputs for immediate deletion, namely 'i' + # TODO should work fine even for IRBpOperation + lifetime.update((tout, i) for tout in node.outputs() if is_temp_tensor(tout)) + + # "fast-forward" all inputs to the current statement, namely 'i' + inputs : Iterable[IRTensor] + + if isinstance(node, IRSegment): + if node.forward: + inputs = node.inputs() + else: + # NOTE + # An 'IRBpOperation' does not explicitly record all tensors that are + # inputs-and-outputs-to-its-correspondeding-autograd.grad-call after codegen, + # + # E.g. + # IRBpOperation bp_op{ pair=fw_op } has: + # ``` + # len(bp_op.inputs())==len(fw_op.outputs()) + # len(bp_op.outputs())==len(fw_op.inputs()) + # ```` + # + # while a call to 'torch.autograd.grad' in Python is like: + # ``` + # grad_inputs : Tuple[torch.Tensor, ...] = torch.autograd.grad(outputs, inputs, grad_outputs) + # len(grad_inputs) == len(inputs) + # len(grad_outputs) == len(outputs) + # ``` + # in other words, if we simply treat `autograd.grad` as an `g_op:IRCell`, it has: + # ``` + # len(g_op.inputs()) == len(fw_op.outputs())*2 + len(fw_op.inputs()) + # len(g_op.outputs()) == len(fw_op.inputs()) + # ``` + # + fw_inputs : tuple = node.mirror.inputs() + fw_outputs : tuple = node.mirror.outputs() + grad_inputs : Generator = (t.grad for t in fw_outputs) + + inputs = itertools.chain(fw_inputs, fw_outputs, grad_inputs) + + else: + inputs = node.inputs() + + lifetime.update((tin, i) for tin in inputs if is_temp_tensor(tin)) + + i += 1 + lifetime.update((tsout, i) for tsout in subgraph_outputs if is_temp_tensor(tsout)) + + return lifetime + + class CodeGen: """ Generate code for the model @@ -523,37 +622,10 @@ def gen(self, device: int, outfile=None, attach=False) -> str: device_nodes = self.execplan.seq(device) - def removable(node) -> bool: - """ - Get removable tensor lists that will not be used after the execution of the node - - @param node IRCell - """ - inputs = [t for t in node.inputs() if isinstance(t, IRSubTensor) and not t.is_param()] - idx = device_nodes.index(node) - remove = [] - for itensor in inputs: - free = True - for ref_node in device_nodes[idx+1:]: - if isinstance(ref_node, IRSegment): - if ref_node.forward: - if itensor in ref_node.inputs(): - free = False - break - else: - input_tensors = ref_node.mirror.inputs() - output_tensors = ref_node.mirror.outputs() - output_grads = tuple(t.grad for t in output_tensors) - if itensor in input_tensors + output_tensors + output_grads: - free = False - break - else: - if itensor in ref_node.inputs(): - free = False - break - if free: - remove.append(itensor) - return remove + lifetime : Dict[IRTensor, int] = calc_tenvars_lifetime(device_nodes, self.execplan.graph.outputs()) + lifetime_by_line_id : Dict[int, List[IRTensor]] = dict() + for tensor, line_id in lifetime.items(): + lifetime_by_line_id.setdefault(line_id, []).append(tensor) with FunctionBlock(func_name='_train_step', args=['model', 'dataloader']) as fb: @@ -565,16 +637,17 @@ def removable(node) -> bool: code = self.emit_schedule_plan(self.execplan.graph.schedule_plan, device) fb.insert_body(code) else: - for node in device_nodes: + for i, node in enumerate(device_nodes): name = self.node_naming(node) code = self.emit_node(node, name=name) fb.insert_body(code) - # free unused tensor - removable_tensors = removable(node) - if len(removable_tensors) > 0: - nones = ', '.join(['None'] * len(removable_tensors)) - code = f'{self.return_naming(removable_tensors)} = {nones}' - fb.insert_body(code) + + # decrement reference counts for output tensors that are no longer used + tensors : Optional[List[IRTensor]] = lifetime_by_line_id.get(i, None) + if tensors is not None: # not necessarily to have one after each line + tnames : Generator = (self.tensor_naming(t) for t in tensors) + fb.insert_body(', '.join(tnames) + ' = ' + ', '.join(['None'] * len(tensors))) + # return code outputs = self.return_naming(self.execplan.graph.outputs()) code = f'return {outputs}' From 84e802be74629404a2199a7cfdaece0aab28968d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 25 Jul 2022 16:53:43 +0800 Subject: [PATCH 0926/1892] add training mode to the gen code init section --- cube/codegen/codegen.py | 7 +++++++ cube/execplan/execplan.py | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index ea0e0207..b8f065c7 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -389,6 +389,11 @@ def gen(self, device: int, outfile=None, attach=False) -> str: with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.declare_region) + # switch to training or inference mode + if self.execplan.inference: + ib.insert_body('self.eval()') + else: + ib.insert_body('self.train()') cb.insert_body('') cb.insert_body(ib.code) for idx, node in enumerate(gen_nodes): @@ -403,6 +408,8 @@ def gen(self, device: int, outfile=None, attach=False) -> str: fb.insert_body(return_code) cb.insert_body('') cb.insert_body(fb.code) + + gencode += cb.code gencode += [''] diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 6335e9cc..75ecce40 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -14,6 +14,7 @@ def __init__(self, graph: IRGraph): assert isinstance(graph, IRGraph), "Expected an IRGraph" self._graph = graph self._seq: Dict[int, List[IRCell]] = dict() + self._inference_only = not any(isinstance(n, IRBpOperation) for n in graph.nodes()) # execution sequence for each device for node in graph.nodes(): @@ -49,6 +50,10 @@ def __init__(self, graph: IRGraph): def graph(self) -> IRGraph: return self._graph + @property + def inference(self) -> bool: + return self._inference_only + def devices(self) -> List[int]: """ Get device set From e229e208139ceb0293af4ba934762f164efff5b5 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 25 Jul 2022 23:38:37 +0800 Subject: [PATCH 0927/1892] branch init: add palm code --- examples/nlp/palm/palm.py | 257 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 257 insertions(+) create mode 100644 examples/nlp/palm/palm.py diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py new file mode 100644 index 00000000..c720cf37 --- /dev/null +++ b/examples/nlp/palm/palm.py @@ -0,0 +1,257 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/linears.py --policy PASMegatron +""" + +import torch +import torch.nn.functional as F +from torch import nn, einsum + +from math import log2, floor + +from einops import rearrange, repeat + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + +import examples.mlp.policy.spmd as spmd +import examples.mlp.policy.mpmd as mpmd + +import argparse + +def exists(val): + return val is not None + +# parser = argparse.ArgumentParser(description='comm primitive') +# parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +# args = parser.parse_args() +# +# cube.init() +# +# # set up policy +# PAS = None +# policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +# if args.policy in spmd.__dict__: +# PAS = spmd.__dict__[args.policy] +# print_each_rank(f'using policy from spmd.{args.policy}') +# elif args.policy in mpmd.__dict__: +# PAS = mpmd.__dict__[args.policy] +# print_each_rank(f'using policy from mpmd.{args.policy}') +# else: +# raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + + +# =================== Semantic Model Description ==================== + +class MLP(nn.Module): + def __init__(self, dim, mult=1): + super().__init__() + self.linear1 = nn.Linear(dim, dim * mult) + self.linear2 = nn.Linear(dim * mult, dim) + self.linear3 = nn.Linear(dim, dim * mult) + self.linear4 = nn.Linear(dim * mult, dim) + # self.linear5 = nn.Linear(dim, dim * mult) + # self.linear6 = nn.Linear(dim * mult, dim) + # self.linear7 = nn.Linear(dim, dim * mult) + # self.linear8 = nn.Linear(dim * mult, dim) + + def forward(self, data): + output = self.linear1(data) + output = self.linear2(output) + output = self.linear3(output) + output = self.linear4(output) + # output = self.linear5(output) + # output = self.linear6(output) + # output = self.linear7(output) + # output = self.linear8(output) + loss = torch.sum(output) + return loss + +# normalization + +class RMSNorm(nn.Module): + def __init__(self, dim, eps = 1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim = -1, keepdim = True) * self.scale + return x / norm.clamp(min = self.eps) * self.g + +# AliBi + +class AlibiPositionalBias(nn.Module): + def __init__(self, heads, **kwargs): + super().__init__() + self.heads = heads + slopes = torch.Tensor(self._get_slopes(heads)) + slopes = rearrange(slopes, 'h -> h 1 1') + self.register_buffer('slopes', slopes, persistent = False) + self.register_buffer('bias', None, persistent = False) + + def get_bias(self, i, j, device): + i_arange = torch.arange(i, device = device) + j_arange = torch.arange(j, device = device) + bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1')) + return bias + + @staticmethod + def _get_slopes(heads): + def get_slopes_power_of_2(n): + start = (2**(-2**-(log2(n)-3))) + ratio = start + return [start*ratio**i for i in range(n)] + + if log2(heads).is_integer(): + return get_slopes_power_of_2(heads) + + closest_power_of_2 = 2 ** floor(log2(heads)) + return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2] + + def forward(self, qk_sim): + h, i, j, device = *qk_sim.shape[-3:], qk_sim.device + + if exists(self.bias) and self.bias.shape[-1] >= j: + return self.bias[..., :i, :j] + + bias = self.get_bias(i, j, device) + bias = bias * self.slopes + + num_heads_unalibied = h - bias.shape[0] + bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) + self.register_buffer('bias', bias, persistent=False) + + return bias + +class SwiGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + +class PaLMLayer(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): + super().__init__() + + self.dim, self.dim_head, self.heads, self.scale = dim, dim_head, heads, dim_head**-0.5 + + self.alibi_pos_biases = AlibiPositionalBias(heads=self.heads) + + self.norm = RMSNorm(dim) + self.qkv_proj = nn.Linear(dim, dim + dim_head, bias=False) + self.attn_out = nn.Linear(dim, dim, bias=False) + + self.ff = nn.Sequential( + nn.Linear(dim, 2 * ff_mult * dim, bias=False), + SwiGLU(), + nn.Linear(ff_mult * dim, dim, bias=False) + ) + + self.register_buffer("mask", None, persistent=False) + + def get_mask(self, n, device): + if self.mask is not None and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), 1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def forward(self, x): + n, device = x.shape[1], x.device + + # pre layernorm + x = self.norm(x) + + ff_out = self.ff(x) + + q, kv = self.qkv_proj(x).split((self.dim, self.dim_head), dim=-1) + + q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) + + # scale + q = q * self.scale + + # similarity + sim = einsum('b h i d, b j d -> b h i j', q, kv) + + sim = sim + self.alibi_pos_biases(sim) + + # causal mask + + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + out = einsum("b h i j, b j d -> b h i d", attn, kv) + + # merge heads + + out = rearrange(out, "b h n d -> b n (h d)") + + merge_heads = self.attn_out(out) + ff_out + return merge_heads + +def train(): + bs, n, d = 8, 1024, 512 + + model = PaLMLayer(d) + + x = torch.randn(bs, n, d) + + y = model(x) + +# def train(): +# batch_size = 256 +# dim = 8192 +# +# model = MLP(dim=dim) +# model = cube.SemanticModel( +# model, input_shapes=([batch_size, dim],), +# ) +# +# dataloader = cube.runtime.syndata.SynDataLoader( +# shapes=([batch_size, dim],), +# dtypes=(torch.float32,), +# batch_dims=(0,) +# ) +# +# @cube.compile(model, dataloader, PAS=PAS) +# def train_iter(model, dataloader): +# data = next(dataloader) +# loss = model(data) +# loss.backward() +# model = model.get_gen_module() +# +# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) +# +# CudaTimer(enable=False).warmup() +# if torch.distributed.is_initialized(): +# torch.distributed.barrier() +# iter_num = 64 +# warmup = 20 +# for step in range(iter_num): +# if step >= warmup: +# CudaTimer(enable=True).start('e2e') +# train_iter(model, dataloader) +# optimizer.step() +# optimizer.zero_grad() +# if step >= warmup: +# CudaTimer().stop('e2e') +# if (step + 1) % 20 == 0: +# print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) +# +# print_each_rank('e2e time (ms) per iteration: {} ms'.format( +# CudaTimer().duration(iter_num-warmup, field_name='e2e'))) +# CudaTimer().print_all(times=iter_num-warmup) + + +train() From d2daa9e67b4a864bc1dc7536b8d1b1edc91fc943 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Thu, 28 Jul 2022 09:24:09 +0000 Subject: [PATCH 0928/1892] Merged PR 1400: Refine IRCell inputs/outputs API - Make underlying list and cache the public properties immutable tuples - Encapsulate `reset_inputs/outputs` APIs --- cube/graph/graph.py | 4 ++-- cube/ir/cten.py | 46 ++++++++++++++++++++++++++++++--------------- cube/ir/operator.py | 12 ++++++------ 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index a35343fa..66cd79c1 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -380,7 +380,7 @@ def get_inputs(nodes: List[IRCell]): """ all_outputs = list() for node in nodes: - all_outputs += list(node.outputs()) + all_outputs.extend(node.outputs()) inputs = list() for cell in nodes: for input in cell.inputs(): @@ -405,7 +405,7 @@ def get_outputs(nodes: List[IRCell]): """ all_inputs = list() for node in nodes: - all_inputs += list(node.inputs()) + all_inputs.extend(node.inputs()) outputs = list() for node in nodes: for idx, output in enumerate(node.outputs()): diff --git a/cube/ir/cten.py b/cube/ir/cten.py index e4adf4b1..7345c635 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -14,6 +14,7 @@ """ +from functools import lru_cache from typing import Iterable, List, Tuple, Union, Optional, Any import copy @@ -54,12 +55,12 @@ def __init__(self, self._device = list() # source tensors - self._inputs: Tuple[Optional[IRTensor], ...] = (None,) * input_length + self._inputs: List[Optional[IRTensor]] = [None,] * input_length # destination tensors - self._outputs: Tuple[Optional[IRTensor], ...] = (None,) * output_length + self._outputs: List[Optional[IRTensor]] = [None,] * output_length if init_outputs: - self._outputs = tuple(IRTensor() for _ in range(output_length)) + self._outputs = [IRTensor() for _ in range(output_length)] for tensor in self._outputs: tensor.cell = self @@ -144,6 +145,8 @@ def input(self, index:int): """ return self._inputs[index] + # 'maxsize=None' set no limit on cache growth, but it's ok since we have no args + @lru_cache(maxsize=None) def inputs(self): # type: () -> Tuple[Optional[IRTensor], ...] """ @@ -153,8 +156,7 @@ def inputs(self): values: Tuple[Optional[IRTensor], ...] """ - # self._inputs is a tuple and is immutable. - return self._inputs + return tuple(self._inputs) def predecessors(self, index: Optional[int] = None) -> List: """ @@ -192,6 +194,8 @@ def output(self, index:int): """ return self._outputs[index] + # 'maxsize=None' set no limit on cache growth, but it's ok since we have no args + @lru_cache(maxsize=None) def outputs(self): # type: () -> Tuple[Optional[IRTensor], ...] """ @@ -201,8 +205,7 @@ def outputs(self): values: Tuple[Optional[IRTensor], ...] """ - # self._outputs is a tuple and is immutable. - return self._outputs + return tuple(self._outputs) def successors(self, index: Optional[int] = None) -> List: """ @@ -227,6 +230,13 @@ def successors(self, index: Optional[int] = None) -> List: else: raise TypeError("Expected index to be None or int") + def reset_inputs(self, length:int) -> None: + """ + Resize the inputs list to the new length and reset all input items to None. + """ + self._inputs = [None] * length + self.inputs.cache_clear() + def set_input(self, input_index: int, val): # type: (int, Optional[IRTensor]) -> Optional[IRTensor] """ @@ -256,12 +266,18 @@ def set_input(self, input_index: int, val): output.dtype = self._dtype val.dtype = self._dtype - l = list(self._inputs) - l[input_index] = val - self._inputs = tuple(l) + self._inputs[input_index] = val + self.inputs.cache_clear() return val + def reset_outputs(self, length:int) -> None: + """ + Resize the outputs list to the new length and reset all output items to None. + """ + self._outputs = [None] * length + self.outputs.cache_clear() + def set_output(self, output_index: int, val): # type: (int, Optional[IRTensor]) -> Optional[IRTensor] """ @@ -282,9 +298,9 @@ def set_output(self, output_index: int, val): # set output value dtype val.dtype = self._dtype - l = list(self._outputs) - l[output_index] = val - self._outputs = tuple(l) + self._outputs[output_index] = val + self.outputs.cache_clear() + return val def add_predecessor(self, input_index: int, cell): @@ -354,7 +370,7 @@ def get_inputs(cells): """ all_outputs = list() for cell in cells: - all_outputs += list(cell.outputs()) + all_outputs.extend(cell.outputs()) inputs = list() for cell in cells: for input in cell.inputs(): @@ -375,7 +391,7 @@ def get_outputs(cells): """ all_inputs = list() for node in cells: - all_inputs += list(node.inputs()) + all_inputs.extend(node.inputs()) outputs = list() for node in cells: for output in node.outputs(): diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 00e82332..d19ab835 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -124,10 +124,10 @@ def replicate(self): cpy = copy.copy(self) cpy._device = list() # reset input and output - cpy._inputs = (None,) * len(self.inputs()) + cpy.reset_inputs(len(self.inputs())) for idx, input in enumerate(self.inputs()): cpy.set_input(idx, input) - cpy._outputs = (None,) * len(self.outputs()) + cpy.reset_outputs(len(self.outputs())) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None @@ -214,10 +214,10 @@ def replicate(self): cpy._device = list() cpy._id = IDGenerator().gen_cell_id() # reset input and output - cpy._inputs = (None,) * len(self.inputs()) + cpy.reset_inputs(len(self.inputs())) for idx, input in enumerate(self.inputs()): cpy.set_input(idx, input) - cpy._outputs = (None,) * len(self.outputs()) + cpy.reset_outputs(len(self.outputs())) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None @@ -249,10 +249,10 @@ def replicate(self): cpy._device = list() cpy._id = IDGenerator().gen_cell_id() # reset input and output - cpy._inputs = (None,) * len(self.inputs()) + cpy.reset_inputs(len(self.inputs())) for idx, input in enumerate(self.inputs()): cpy.set_input(idx, input) - cpy._outputs = (None,) * len(self.outputs()) + cpy.reset_outputs(len(self.outputs())) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None From dcc2bf1853e4cab5d8ac67b650e8c34c585ad0a0 Mon Sep 17 00:00:00 2001 From: Zijian Ding Date: Mon, 1 Aug 2022 16:20:12 +0800 Subject: [PATCH 0929/1892] Add splitting method for creators, select and scatter, update wrf2 onedim splitting policy. --- cube/algorithm/factory.py | 12 +++ cube/algorithm/ops/creators.py | 130 +++++++++++++++++++++++++++++++ cube/algorithm/ops/dimops.py | 7 +- cube/algorithm/ops/scatter.py | 55 +++++++++++++ cube/algorithm/ops/select.py | 103 ++++++++++++++++++++++++ cube/codegen/frontend_mapping.py | 49 ++++++++++++ cube/graph/function/creators.py | 57 ++++++++++++-- cube/graph/function/function.py | 30 ++++++- cube/graph/function/scatter.py | 6 ++ cube/graph/function/select.py | 14 ++++ cube/graph/parser/mapping.py | 1 + examples/mlp/linears.py | 36 ++++++--- examples/wrf/policy/onedim.py | 52 ++++++++++--- 13 files changed, 520 insertions(+), 32 deletions(-) create mode 100644 cube/algorithm/ops/creators.py create mode 100644 cube/algorithm/ops/scatter.py create mode 100644 cube/algorithm/ops/select.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 1bda1b15..2ccfa44e 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -73,6 +73,18 @@ def _load_predefined_algos(self): import cube.algorithm.ops.pad as pad self.register(pad.IRPad, pad.DimSplitPad, tag='dim') + + import cube.algorithm.ops.select as select + self.register(select.IRSelect, select.DimSplitSelect, tag='dim') + self.register(select.IRSlice, select.DimSplitSlice, tag='dim') + + import cube.algorithm.ops.scatter as scatter + self.register(scatter.IRSelectScatter, scatter.DimSplitScatter, tag='dim') + + import cube.algorithm.ops.creators as creators + self.register(creators.IRToTensor, creators.DimSplitTo, tag='dim') + self.register(creators.IROnes, creators.DimSplitOnes, tag='dim') + self.register(creators.IRRand, creators.DimSplitRand, tag='dim') # import cube.algorithm.ops.elementwise as elew # self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') # self.register(elew.Add, elew.AddDimParallel, tag='dim') diff --git a/cube/algorithm/ops/creators.py b/cube/algorithm/ops/creators.py new file mode 100644 index 00000000..e397e1df --- /dev/null +++ b/cube/algorithm/ops/creators.py @@ -0,0 +1,130 @@ +from typing import List, Tuple, Optional + +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.function.creators import IRToTensor, IROnes, IRRand +from cube.ir.tensor import IRSubTensor + + +class DimSplitTo(GenericDistAlgo): + """ + split Pad at dimension level + + """ + def __init__(self, node: IRToTensor): + if not isinstance(node, IRToTensor): + raise TypeError(f"Expect IRToTensor") + super().__init__(node) + + def satisfy(self, dim: int, num: int): + """ + config = dict(idx=int, dim=int, num=num) + + """ + assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" + node: IRToTensor = self.node + + assert dim < len(node.input(0).shape), "Split dimension should be smaller than tensor dimension" + + # split non-pad dim + return node.input(0).shape[dim] >= num + + def instantiate(self, dim: int, num: int) -> Optional[List[IRToTensor]]: + + node: IRToTensor = self.node + satisfy = self.satisfy(dim, num) + if not satisfy: + return None + + ins, ous = list(), list() + for iidx, itensor in enumerate(node.inputs()): + assert isinstance(itensor, IRSubTensor), "Input of select shoud be IRSubTensor" + ins.append(itensor.split_dim(dim, num)) + + odim = dim + + for oidx, otensor in enumerate(node.outputs()): + assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" + ous.append(otensor.split_dim(odim, num)) + + sub_nodes = list() + for nid in range(num): + inputs = [t[nid] for t in ins] + outputs = [t[nid] for t in ous] + sub_nodes.append(node.new(inputs, outputs)) + return sub_nodes + + +class DimSplitOnes(GenericDistAlgo): + def __init__(self, node: IROnes): + if not isinstance(node, IROnes): + raise TypeError(f"Expect IROnes") + super().__init__(node) + + def satisfy(self, dim: int, num: int): + """ + config = dict(idx=int, dim=int, num=num) + + """ + assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" + node: IROnes = self.node + + assert dim < len(node.outputs(0).shape), "Split dimension should be smaller than tensor dimension" + + # split non-pad dim + return node.outputs(0).shape[dim] >= num + + def instantiate(self, dim: int, num: int) -> Optional[List[IROnes]]: + + node: IROnes = self.node + satisfy = self.satisfy(dim, num) + if not satisfy: + return None + + ous = list() + for oidx, otensor in enumerate(node.outputs()): + assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" + ous.append(otensor.split_dim(dim, num)) + + sub_nodes = list() + for nid in range(num): + outputs = [t[nid] for t in ous] + sub_nodes.append(node.new(outputs)) + return sub_nodes + +class DimSplitRand(GenericDistAlgo): + def __init__(self, node: IRRand): + if not isinstance(node, IRRand): + raise TypeError(f"Expect IRRand") + super().__init__(node) + + def satisfy(self, dim: int, num: int): + """ + config = dict(idx=int, dim=int, num=num) + + """ + assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" + node: IRRand = self.node + + assert dim < len(node.outputs(0).shape), "Split dimension should be smaller than tensor dimension" + + # split non-pad dim + return node.outputs(0).shape[dim] >= num + + def instantiate(self, dim: int, num: int) -> Optional[List[IRRand]]: + + node: IRRand = self.node + satisfy = self.satisfy(dim, num) + if not satisfy: + return None + + ous = list() + for oidx, otensor in enumerate(node.outputs()): + assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" + ous.append(otensor.split_dim(dim, num)) + + sub_nodes = list() + for nid in range(num): + outputs = [t[nid] for t in ous] + sub_nodes.append(node.new(outputs)) + return sub_nodes \ No newline at end of file diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 49500565..1198b2f5 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -106,6 +106,7 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: outputs = [t[nid] for t in ous] updated_kwargs = dict() if self._adim in node.kwargs and isinstance(node.kwargs[self._adim], int): + assert 0, "Should not happen" updated_kwargs[self._adim] = node.kwargs[self._adim] // num sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) sub_node.infer_shape() @@ -153,8 +154,10 @@ def satisfy(self, idx: int, dimi: int, dimo: int, num: int) -> bool: assert dimo < node.output(0).ndims, f"dimension out of boundary" # # due to implementation limits, we only partition the first annotated dimension # # for inner-dimension cases. - self._adimi: str = node.anno.input(0).dims[dimi].identifiers[0] - self._adimo: str = node.anno.output(0).dims[dimo].identifiers[0] + idi = 1 if dimi == 0 else 0 + ido = 1 if dimo == 0 else 0 + self._adimi: str = node.anno.input(0).dims[dimi].identifiers[idi] + self._adimo: str = node.anno.output(0).dims[dimo].identifiers[ido] dimlen = node.anno.getlen(self._adimi) if dimlen < num: return False diff --git a/cube/algorithm/ops/scatter.py b/cube/algorithm/ops/scatter.py new file mode 100644 index 00000000..6b8c467e --- /dev/null +++ b/cube/algorithm/ops/scatter.py @@ -0,0 +1,55 @@ +from typing import List, Tuple, Optional +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.function.scatter import IRSelectScatter +from cube.ir.tensor import IRSubTensor + + +class DimSplitScatter(GenericDistAlgo): + """ + split Pad at dimension level + + """ + def __init__(self, node: IRSelectScatter): + if not isinstance(node, IRSelectScatter): + raise TypeError(f"Expect IRSelectScatter") + super().__init__(node) + + def satisfy(self, diml: int, dimr: int, num: int): + """ + config = dict(idx=int, dim=int, num=num) + + """ + assert all(isinstance(t, int) for t in [diml, dimr, num]), "dim and num should be integer" + node: IRSelectScatter = self.node + + assert diml != node.kwargs['dim'], "Split dimension should not be equal to scatter dimension" + assert diml < len(node.input(0).shape), "Split dimension should be smaller than tensor dimension" + assert dimr < len(node.output(0).shape), "Split dimension should be smaller than tensor dimension" + assert node.input(0).shape[diml] == node.input(1).shape[dimr], "Two split dimension should at least have equal size" + + return node.input(0).shape[diml] >= num + + def instantiate(self, diml: int, dimr: int, num: int) -> Optional[List[IRSelectScatter]]: + + node: IRSelectScatter = self.node + satisfy = self.satisfy(diml, dimr, num) + if not satisfy: + return None + + assert len(node.inputs()) == 2, "Select_scatter do not has two inputs" + assert len(node.outputs()) == 1, "Select_scatter do not has one outputs" + + ins, ous = list(), list() + ins.append(node.input(0).split_dim(diml, num)) + ins.append(node.input(1).split_dim(dimr, num)) + + ous.append(node.output(0).split_dim(diml, num)) + + sub_nodes = list() + for nid in range(num): + inputs = tuple([t[nid] for t in ins]) + outputs = [t[nid] for t in ous] + sub_nodes.append(node.new(inputs, outputs)) + return sub_nodes + \ No newline at end of file diff --git a/cube/algorithm/ops/select.py b/cube/algorithm/ops/select.py new file mode 100644 index 00000000..7cf128cc --- /dev/null +++ b/cube/algorithm/ops/select.py @@ -0,0 +1,103 @@ +from typing import List, Tuple, Optional + +from cube.algorithm.generics import GenericDistAlgo + +from cube.graph.function.select import IRSelect, IRSlice +from cube.ir.tensor import IRSubTensor + + +class DimSplitSelect(GenericDistAlgo): + """ + split Pad at dimension level + + """ + def __init__(self, node: IRSelect): + if not isinstance(node, IRSelect): + raise TypeError(f"Expect IRSelect") + super().__init__(node) + + def satisfy(self, dim: int, num: int): + """ + config = dict(idx=int, dim=int, num=num) + + """ + assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" + node: IRSelect = self.node + + assert dim != node.kwargs['dim'], "Split dimension should not be equal to select dimension" + assert dim < len(node.input(0).shape), "Split dimension should be smaller than tensor dimension" + + # split non-pad dim + return node.input(0).shape[dim] >= num + + def instantiate(self, dim: int, num: int) -> Optional[List[IRSelect]]: + + node: IRSelect = self.node + satisfy = self.satisfy(dim, num) + if not satisfy: + return None + + ins, ous = list(), list() + for iidx, itensor in enumerate(node.inputs()): + assert isinstance(itensor, IRSubTensor), "Input of select shoud be IRSubTensor" + ins.append(itensor.split_dim(dim, num)) + + odim = dim - int(node.kwargs['dim'] < dim) + + for oidx, otensor in enumerate(node.outputs()): + assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" + ous.append(otensor.split_dim(odim, num)) + + sub_nodes = list() + for nid in range(num): + inputs = [t[nid] for t in ins] + outputs = [t[nid] for t in ous] + sub_nodes.append(node.new(inputs, outputs)) + return sub_nodes + + +class DimSplitSlice(GenericDistAlgo): + """ + split Pad at dimension level + + """ + def __init__(self, node: IRSlice): + if not isinstance(node, IRSlice): + raise TypeError(f"Expect IRSlice") + super().__init__(node) + + def satisfy(self, dim: int, num: int): + assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" + node: IRSlice = self.node + + if dim == node.kwargs['dim']: + return None + assert dim < len(node.input(0).shape), "Split dimension should be smaller than tensor dimension" + + # split non-pad dim + return node.input(0).shape[dim] >= num + + def instantiate(self, dim: int, num: int) -> Optional[List[IRSlice]]: + + node: IRSlice = self.node + print(dim, node.kwargs['dim']) + satisfy = self.satisfy(dim, num) + if not satisfy: + return None + + ins, ous = list(), list() + for iidx, itensor in enumerate(node.inputs()): + assert isinstance(itensor, IRSubTensor), "Input of select shoud be IRSubTensor" + ins.append(itensor.split_dim(dim, num)) + + for oidx, otensor in enumerate(node.outputs()): + assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" + ous.append(otensor.split_dim(dim, num)) + + sub_nodes = list() + for nid in range(num): + inputs = [t[nid] for t in ins] + outputs = [t[nid] for t in ous] + sub_nodes.append(node.new(inputs, outputs)) + return sub_nodes + \ No newline at end of file diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index e2f930c8..83649e21 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -105,6 +105,53 @@ def emit_ones(node, arg_vars:list, kw_pairs:dict) -> str: assert len(arg_vars) == 0 return _common_rule_join_all(node, arg_vars, kw_pairs) +def emit_rand(node, arg_vars:list, kw_pairs:dict) -> str: + """ + rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + """ + kw_pairs = kw_pairs.copy() + if 'dtype' in kw_pairs: + ir_dtype : IRDType = kw_pairs['dtype'] + if ir_dtype is not None: + kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) + + # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. + assert 'device' not in kw_pairs + kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. + + assert len(arg_vars) == 0 + return _common_rule_join_all(node, arg_vars, kw_pairs) + + +def emit_new_tensor(node, arg_vars:list, kw_pairs:dict) -> str: + """ + rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + """ + kw_pairs = kw_pairs.copy() + if 'dtype' in kw_pairs: + ir_dtype : IRDType = kw_pairs['dtype'] + if ir_dtype is not None: + kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) + + # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. + assert 'device' not in kw_pairs + kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. + + assert len(arg_vars) == 0 + assert 'data' in kw_pairs + assert 'shape' in kw_pairs + data_str = str(kw_pairs['data']) + _ = kw_pairs.pop('data') + _ = kw_pairs.pop('shape') + + kw_assigns = list() + for key, val in kw_pairs.items(): + assert key != 'data' + code = f'{key}={val}' + kw_assigns.append(code) + args = data_str + ', ' + ', '.join(kw_assigns) + return f'{node.signature}({args})' + # Basically to convert internal 'IRDType' to frontend 'torch.dtype' def emit_to(node, arg_vars:list, kw_pairs:dict) -> str: kw_pairs = kw_pairs.copy() @@ -144,6 +191,8 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: 'torch.zeros': emit_zeros, 'torch.ones': emit_ones, 'torch.Tensor.to': emit_to, + 'torch.rand': emit_rand, + 'torch.tensor': emit_new_tensor } diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index e42467e9..aa08f3ea 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -5,6 +5,8 @@ from cube.ir.operator import IRFwOperation from cube.ir.cten import IRTensor +import numpy as np + class IRZeros(IRFwOperation): def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): @@ -46,12 +48,51 @@ def infer_shape(self) -> bool: shape : list = copy(self.kwargs["size"]) self.output(0).shape = shape return True + + def new(self, outputs: List[IRTensor]): + op = IROnes(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IROnes::new infer_shape failed" + return op + +class IRRand(IRFwOperation): + def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): + + # The shape information must be statically known integer values + assert all(isinstance(dim, int) for dim in shape) + assert isinstance(ir_dtype, IRDType) + + super().__init__(name, signature, input_length=0, output_length=1) -#class IRNewTensor(IRFwOperation): -# def __init__(self, signature: str, data, name:str): -# pass -# def infer_shape(self) -> bool: -# pass + # Customize output's dtype only after 'super().__init__' and 'self.set_input', + # otherwise it gets overwritten. + self.output(0).dtype = ir_dtype + + # The positional argument to specify the shape is actually called 'size'. + self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) + + def infer_shape(self) -> bool: + shape : list = copy(self.kwargs["size"]) + self.output(0).shape = shape + return True + + def new(self, outputs: List[IRTensor]): + op = IRRand(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRRand::new infer_shape failed" + return op + +class IRNewTensor(IRFwOperation): + def __init__(self, signature: str, data: list, name: str, ir_dtype: IRDType): + super().__init__(name, signature, input_length=0, output_length=1) + self.output(0).dtype = ir_dtype + self.kwargs.update({'data': data, 'shape': np.array(data).shape, 'dtype': ir_dtype}) + + def infer_shape(self) -> bool: + shape : list = copy(self.kwargs['shape']) + self.output(0).shape = shape + return True + # `aten::to` has several overloading, which one should be dispatched is determined by the argument types @@ -75,5 +116,11 @@ def __init__(self, signature: str, inputs, name:str, ir_dtype:IRDType): def infer_shape(self) -> bool: self.output(0).shape = self.input(0).shape return True + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + op = IRToTensor(self.signature, inputs, self.name, self.kwargs['dtype']) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRToTensor::new infer_shape failed" + return op diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 21fee65a..f22ddb7f 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -13,7 +13,7 @@ from cube.graph.function.pad import IRPad from cube.graph.function.scripteinops import IRScriptEinOps from cube.graph.function.customops import IRCustomOps -from cube.graph.function.creators import IROnes, IRToTensor, IRZeros +from cube.graph.function.creators import IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor from cube.graph.function.select import IRSelect, IRSlice from cube.graph.function.scatter import IRSelectScatter from cube.graph.function.repeat import IRRepeat @@ -98,6 +98,30 @@ def Ones(signature, raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") return IROnes(signature, size, 'ones', ir_dtype) +def Rand(signature, + inputs: Tuple[ List[int], Optional[int], Optional[Any], 'ErasedDevice', Optional[bool] ]): + # ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + + size, dtype_underlying, layout, _erased_device, pin_memory = inputs + + # TODO parameters to support, currently they are all None + assert layout is None + assert pin_memory is None + + if dtype_underlying is not None: + # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, + # which is the underlying type of PyTorch C++ enum 'ScalarType'. + dtype = TorchScalarTypeEnumMap.map(dtype_underlying) + else: + dtype = torch.get_default_dtype() + + ir_dtype : IRDType = DType2IRDType.map(dtype) + + for dim, i in enumerate(size): + if not isinstance(dim, int) and not dim >= 0: + raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") + return IRRand(signature, size, 'rand', ir_dtype) + def NewTensor(signature, inputs: Tuple[ list, Optional[int], 'ErasedDevice', bool ]): # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor @@ -129,9 +153,7 @@ def NewTensor(signature, # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', # but since we have omitted the 'data', we must do type inferrence ourselves, # only in this way we get correct dtype e.g. ints or bools. - shape = list(arr.shape) - signature = 'torch.ones' - return IROnes(signature, shape, 'tensor', ir_dtype=ir_dtype) + return IRNewTensor(signature, data, 'tensor', ir_dtype=ir_dtype) def ToTensor(signature, inputs: Tuple[ IRTensor, ... ]): diff --git a/cube/graph/function/scatter.py b/cube/graph/function/scatter.py index 14c531b9..dcb2fe97 100644 --- a/cube/graph/function/scatter.py +++ b/cube/graph/function/scatter.py @@ -58,3 +58,9 @@ def infer_shape(self) -> bool: self.output(0).shape = s2 return True + def new(self, inputs:Tuple[IRTensor, IRTensor], outputs: List[IRTensor]): + op = IRSelectScatter(self.signature, inputs, self.name, self.kwargs['dim'], self.kwargs['index']) + assert len(outputs) == 1, "Select_scatter: too many outputs" + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRSelect::new infer_shape failed" + return op diff --git a/cube/graph/function/select.py b/cube/graph/function/select.py index ba5d8b8d..1f8739d4 100644 --- a/cube/graph/function/select.py +++ b/cube/graph/function/select.py @@ -28,6 +28,12 @@ def infer_shape(self) -> bool: return True + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + op = IRSelect(self.signature, inputs, self.name, self.kwargs['dim'], self.kwargs['index']) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRSelect::new infer_shape failed" + return op class IRSlice(IRFwOperation): """ @@ -73,6 +79,14 @@ def clip(offset): self.output(0).shape = s2 return True + + def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): + assert len(inputs) == 1, "Slice: number of inputs not equal to 1" + op = IRSlice(self.signature, inputs, self.name, self.kwargs['dim'], self.kwargs['start'], self.kwargs['end'], self.kwargs['step']) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRSlice::new infer_shape failed" + return op # torch.gather(input:Tensor, dim:int, index:LongTensor, *, sparse_grad=False, out=None) -> Tensor diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 5c97fb59..b78089c8 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -85,6 +85,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ones'): function.Ones, __ttemplate('tensor'): function.NewTensor, __ttemplate('to'): function.ToTensor, + __ttemplate('rand'): function.Rand, __ttemplate('add') : function.Add, diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index b7084296..eecb2ed2 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -47,26 +47,42 @@ def __init__(self, dim, mult=1): self.linear2 = nn.Linear(dim * mult, dim) self.linear3 = nn.Linear(dim, dim * mult) self.linear4 = nn.Linear(dim * mult, dim) - # self.linear5 = nn.Linear(dim, dim * mult) - # self.linear6 = nn.Linear(dim * mult, dim) - # self.linear7 = nn.Linear(dim, dim * mult) - # self.linear8 = nn.Linear(dim * mult, dim) + self.linear5 = nn.Linear(dim, dim * mult) + self.linear6 = nn.Linear(dim * mult, dim) + self.linear7 = nn.Linear(dim, dim * mult) + self.linear8 = nn.Linear(dim * mult, dim) + self.linear9 = nn.Linear(dim, dim * mult) + self.linear10 = nn.Linear(dim * mult, dim) + self.linear11 = nn.Linear(dim, dim * mult) + self.linear12 = nn.Linear(dim * mult, dim) + self.linear13 = nn.Linear(dim, dim * mult) + self.linear14 = nn.Linear(dim * mult, dim) + self.linear15 = nn.Linear(dim, dim * mult) + self.linear16 = nn.Linear(dim * mult, dim) def forward(self, data): output = self.linear1(data) output = self.linear2(output) output = self.linear3(output) output = self.linear4(output) - # output = self.linear5(output) - # output = self.linear6(output) - # output = self.linear7(output) - # output = self.linear8(output) + output = self.linear5(output) + output = self.linear6(output) + output = self.linear7(output) + output = self.linear8(output) + output = self.linear9(output) + output = self.linear10(output) + output = self.linear11(output) + output = self.linear12(output) + output = self.linear13(output) + output = self.linear14(output) + output = self.linear15(output) + output = self.linear16(output) loss = torch.sum(output) return loss def train(): - batch_size = 256 + batch_size = 128 dim = 8192 model = MLP(dim=dim) @@ -92,7 +108,7 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() if torch.distributed.is_initialized(): torch.distributed.barrier() - iter_num = 64 + iter_num = 500 warmup = 20 for step in range(iter_num): if step >= warmup: diff --git a/examples/wrf/policy/onedim.py b/examples/wrf/policy/onedim.py index 7c45d8e6..02198695 100644 --- a/examples/wrf/policy/onedim.py +++ b/examples/wrf/policy/onedim.py @@ -2,6 +2,7 @@ from cube.graph.function import IRConv2D, IRConv3D from cube.graph.function import IRDimops, IRPad from cube.ir.cten import IRTensor, IRCell +from cube.graph.function import IRSelect, IRSelectScatter, IRSlice, IRToTensor, IROnes, IRRand def PAS(graph: IRGraph, resource): @@ -72,18 +73,20 @@ def PAS_ALL_TEST(graph: IRGraph, resource): def PAS_ALL_X(graph: IRGraph, resource): + elewise_sign = ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat', 'stack', 'sum', 'sin', 'gt'] + # elewise_sign = ['mul', 'div', 'add', 'sub'] for node in graph.nodes(): sign = node.signature.split('.')[-1] - append_sign(sign) if isinstance(node, IRConv3D): sub_nodes = list() algo = node.algorithms('halo') sub_nodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus) elif isinstance(node, IRDimops): - if sign in ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat']: + if sign in elewise_sign: ndims = node.input(0).ndims algo = node.algorithms('dim') - if ndims == 3 or ndims == 5: + append_sign(ndims) + if ndims == 3 or ndims == 5 or ndims == 2 or ndims == 4: sub_nodes = graph.partition(node, algo, idx=0, dim=ndims-1, num=resource.ngpus) if sub_nodes == None: sub_nodes = graph.replicate(node, times=resource.ngpus) @@ -91,15 +94,25 @@ def PAS_ALL_X(graph: IRGraph, resource): sub_nodes = graph.replicate(node, times=resource.ngpus) elif sign == 'view': algo = node.algorithms('view_simp') - if node.input(0).ndims >= 3 and node.output(0).ndims >= 3: + if node.input(0).ndims >= 2 and node.output(0).ndims >= 3: sub_nodes = graph.partition(node, algo, idx=0, dimi=node.input(0).ndims-1, dimo=node.output(0).ndims-1, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) - elif isinstance(node, IRPad): + # FIXME: Check 'circular' padding, should not be splitted easily + elif isinstance(node, IRSelect) or isinstance(node, IRPad) or isinstance(node, IRSlice) or isinstance(node, IRToTensor): algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, dim=node.input(0).ndims-1, num=resource.ngpus) + elif isinstance(node, IRSelectScatter): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, diml=node.input(0).ndims-1, dimr=node.input(1).ndims-1, num=resource.ngpus) + elif isinstance(node, IROnes) and node.output(0).ndims >= 3: + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-1, num=resource.ngpus) + # elif isinstance(node, IRRand) and node.output(0).ndims >= 3: + # algo = node.algorithms('dim') + # sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-1, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): @@ -109,18 +122,21 @@ def PAS_ALL_X(graph: IRGraph, resource): return graph def PAS_ALL_Y(graph: IRGraph, resource): + elewise_sign = ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat', 'stack', 'sum', 'sin', 'gt'] + # elewise_sign = ['mul', 'div', 'add', 'sub'] for node in graph.nodes(): sign = node.signature.split('.')[-1] - append_sign(sign) if isinstance(node, IRConv3D): sub_nodes = list() algo = node.algorithms('halo') sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) + assert sub_nodes != None elif isinstance(node, IRDimops): - if sign in ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat']: + if sign in elewise_sign: ndims = node.input(0).ndims algo = node.algorithms('dim') - if ndims == 3 or ndims == 5: + append_sign(ndims) + if ndims == 3 or ndims == 5 or ndims == 2 or ndims == 4: sub_nodes = graph.partition(node, algo, idx=0, dim=ndims-2, num=resource.ngpus) if sub_nodes == None: sub_nodes = graph.replicate(node, times=resource.ngpus) @@ -128,19 +144,33 @@ def PAS_ALL_Y(graph: IRGraph, resource): sub_nodes = graph.replicate(node, times=resource.ngpus) elif sign == 'view': algo = node.algorithms('view_simp') - if node.input(0).ndims >= 3 and node.output(0).ndims >= 3: + if node.input(0).ndims >= 2 and node.output(0).ndims >= 3: + print(node.input(0).shape, node.output(0).shape) sub_nodes = graph.partition(node, algo, idx=0, dimi=node.input(0).ndims-2, dimo=node.output(0).ndims-2, num=resource.ngpus) + assert sub_nodes != None else: sub_nodes = graph.replicate(node, times=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) - elif isinstance(node, IRPad): + elif isinstance(node, IRSelect) or isinstance(node, IRPad) or isinstance(node, IRSlice) or isinstance(node, IRToTensor): algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, dim=node.input(0).ndims-2, num=resource.ngpus) + assert sub_nodes != None + elif isinstance(node, IRSelectScatter): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, diml=node.input(0).ndims-2, dimr=node.input(1).ndims-2, num=resource.ngpus) + assert sub_nodes != None + elif isinstance(node, IROnes) and node.output(0).ndims >= 3: + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-2, num=resource.ngpus) + assert sub_nodes != None + # elif isinstance(node, IRRand) and node.output(0).ndims >= 3: + # algo = node.algorithms('dim') + # sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-1, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) - # print(graph.extra_repr()) + print(graph.extra_repr()) print(opSigns) return graph From b3b80cbe57ce73aa510104e1920f55b423f059fc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Aug 2022 11:10:17 +0800 Subject: [PATCH 0930/1892] adapter for graph output --- cube/algorithm/ops/dimops.py | 1 - cube/execplan/execplan.py | 6 ----- cube/graph/gener/gen.py | 18 ++++++++++++- examples/mlp/linears.py | 49 ++++++++++-------------------------- 4 files changed, 30 insertions(+), 44 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 1198b2f5..7a6c259c 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -106,7 +106,6 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: outputs = [t[nid] for t in ous] updated_kwargs = dict() if self._adim in node.kwargs and isinstance(node.kwargs[self._adim], int): - assert 0, "Should not happen" updated_kwargs[self._adim] = node.kwargs[self._adim] // num sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) sub_node.infer_shape() diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 75ecce40..e60de05b 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -40,12 +40,6 @@ def __init__(self, graph: IRGraph): nodes.remove(fnode.mirror) self.at(devid)[bidx] = fnode_dev.mirror - # TODO: adapter support for return consistency - for output in graph.outputs(): - for devid in self.devices(): - ptensors = [pt for pt in output.parent.ptensors if pt == output and devid in pt.device] - assert len(ptensors) >= 1, f"Missing full graph output tensor {output} in device {devid}" - @property def graph(self) -> IRGraph: return self._graph diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 7c6d502e..09566a31 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -3,6 +3,7 @@ import copy from cube.graph.graph import IRGraph +from cube.graph.function import Identity from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap from cube.ir.adapter import IRAdapter, IRWeightReducer @@ -31,6 +32,19 @@ def gen(graph: IRGraph) -> IRGraph: Returns: graph (IRGraph) """ + # insert identity operator for graph output + devs = set() + for node in graph.nodes(): + devs.update(node.device) + outputs = [otensor for otensor in graph.outputs() if isinstance(otensor, IRSubTensor)] + all_identities = [] + for otensor in outputs: + identity = Identity('', [otensor]) + graph.attach(identity, len(graph.nodes())) + identites = graph.replicate(identity, times=len(devs)) + all_identities += identites + for devid, identity in zip(devs, identites): + graph.assign(identity, devid) # update the gradient before generate adapter for node in graph.nodes(): if isinstance(node, IRBpOperation): @@ -39,7 +53,9 @@ def gen(graph: IRGraph) -> IRGraph: graph.attach(node, idx) graph = IRAdapterGener.gen_activation(graph) graph = IRAdapterGener.gen_weight(graph) - # TODO: generate adapter for graph outputs + # remove inserted identity + for identity in all_identities: + graph.detach(identity) return graph @staticmethod diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index eecb2ed2..8bc454a8 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -41,43 +41,20 @@ # =================== Semantic Model Description ==================== class MLP(nn.Module): - def __init__(self, dim, mult=1): + def __init__(self, dim, mult=1, nlayers=4): super().__init__() - self.linear1 = nn.Linear(dim, dim * mult) - self.linear2 = nn.Linear(dim * mult, dim) - self.linear3 = nn.Linear(dim, dim * mult) - self.linear4 = nn.Linear(dim * mult, dim) - self.linear5 = nn.Linear(dim, dim * mult) - self.linear6 = nn.Linear(dim * mult, dim) - self.linear7 = nn.Linear(dim, dim * mult) - self.linear8 = nn.Linear(dim * mult, dim) - self.linear9 = nn.Linear(dim, dim * mult) - self.linear10 = nn.Linear(dim * mult, dim) - self.linear11 = nn.Linear(dim, dim * mult) - self.linear12 = nn.Linear(dim * mult, dim) - self.linear13 = nn.Linear(dim, dim * mult) - self.linear14 = nn.Linear(dim * mult, dim) - self.linear15 = nn.Linear(dim, dim * mult) - self.linear16 = nn.Linear(dim * mult, dim) + self.layers = torch.nn.ModuleList([]) + for lid in range(nlayers): + if lid % 2 == 0: + self.layers.append(nn.Linear(dim, dim * mult)) + else: + self.layers.append(nn.Linear(dim * mult, dim)) def forward(self, data): - output = self.linear1(data) - output = self.linear2(output) - output = self.linear3(output) - output = self.linear4(output) - output = self.linear5(output) - output = self.linear6(output) - output = self.linear7(output) - output = self.linear8(output) - output = self.linear9(output) - output = self.linear10(output) - output = self.linear11(output) - output = self.linear12(output) - output = self.linear13(output) - output = self.linear14(output) - output = self.linear15(output) - output = self.linear16(output) - loss = torch.sum(output) + x = data + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) return loss @@ -108,8 +85,8 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() if torch.distributed.is_initialized(): torch.distributed.barrier() - iter_num = 500 - warmup = 20 + iter_num = 32 + warmup = 8 for step in range(iter_num): if step >= warmup: CudaTimer(enable=True).start('e2e') From 85ba23c16044df3b6a966164967b1e2ec345c433 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Aug 2022 11:13:18 +0800 Subject: [PATCH 0931/1892] database for profiling forward, backward, memory --- cube/profiler/database.py | 82 ++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 32 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 81120d3a..f1cbfc6a 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -20,8 +20,8 @@ class CompProfiler: @staticmethod def profile(func: Callable, shapes: Shapes, dtypes: DTypes, - warmup_sec: float = 2, prof_times: int = 50, backward = True, - **kwargs): + warmup_sec: float = 2, prof_times: int = 50, + **kwargs) -> Tuple[float, float, int]: """ Profile a function @@ -30,10 +30,11 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, @param dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 @param warmup_sec float: warmup seconds @param prof_times int: profile times - @param backward bool: whether profile backward times. Default true. @param kwargs Dict: other keyword argument for func call. - @return span float: the time in milliseconds for forward (+backward) time + @return fw_span float: the time in milliseconds for forward time + @return bw_span float: the time in milliseconds for backward time + @return memory int: the peak memory in bytes after forward """ assert len(shapes) == len(dtypes), \ f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" @@ -49,28 +50,46 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, f"{func.__name__}: require all the outputs to be tensors" grads = tuple(torch.zeros_like(otensor) for otensor in outputs) - # warmup - tic = time.time() - while time.time() - tic < warmup_sec: - # forward + def run_step(func, tensors, kwargs, backward: bool): outputs = func(*tensors, **kwargs) - # backward if backward: torch.autograd.backward(outputs, grads) - - # profile forward + return outputs + + # warmup + tic = time.time() + while time.time() - tic < warmup_sec: + run_step(func, tensors, kwargs, backward=True) + + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + mtic = torch.cuda.max_memory_allocated() # in bytes + outs = run_step(func, tensors, kwargs, backward=False) + mtoc = torch.cuda.max_memory_allocated() # in bytes + memory = mtoc - mtic + + # profile forward only torch.cuda.synchronize() tic = time.perf_counter() for _ in range(prof_times): - # forward - outputs = func(*tensors, **kwargs) - # backward - if backward: - torch.autograd.backward(outputs, grads) + with torch.no_grad(): + run_step(func, tensors, kwargs, backward=False) + torch.cuda.synchronize() + toc = time.perf_counter() + fw_span = (toc - tic) / prof_times * 1000 # in milliseconds + + # profile forward + backward + torch.cuda.synchronize() + tic = time.perf_counter() + for _ in range(prof_times): + run_step(func, tensors, kwargs, backward=True) torch.cuda.synchronize() toc = time.perf_counter() - span = (toc - tic) / prof_times * 1000 # in milliseconds - return span + fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds + bw_span = fwbw_span - fw_span + + return fw_span, bw_span, memory class ProfileDataBase: @@ -84,8 +103,7 @@ def __init__(self, filename: Optional[str] = None) -> None: if filename is not None: self.load(filename) - def profile(self, func: Callable, shapes: Shapes, dtypes: DTypes, - backward=True, **kwargs): + def profile(self, func: Callable, shapes: Shapes, dtypes: DTypes, **kwargs): """! Profile the function and log into the database @@ -97,22 +115,22 @@ def profile(self, func: Callable, shapes: Shapes, dtypes: DTypes, """ try: assert callable(func), "func should be callable" - span = CompProfiler.profile(func, shapes, dtypes, backward=backward, **kwargs) + fw_span, bw_span, memory = CompProfiler.profile(func, shapes, dtypes, **kwargs) name = func.__name__ key = self.serialize(shapes, dtypes) - self.log(name, key, span) - print(f'profiled {func.__name__} | shapes: {shapes} | dtypes: {dtypes} => span: {round(span, 2)} ms') + self.log(name, key, fw_span, bw_span, memory) + print(f'profiled {func.__name__} | shapes: {shapes} | dtypes: {dtypes} => fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} | mem: {memory}') except Exception as e: print(f'fail to profile {func.__name__}: reason: {str(e)}') - def log(self, name: str, key: str, span: float): + def log(self, name: str, key: str, fw_span: float, bw_span: float, memory: float): """ log the span of a function name with key """ - assert isinstance(name, str) and isinstance(span, float) and isinstance(key, str) + assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = span + self._data[name][key] = (fw_span, bw_span, memory) def query(self, func: NameOrFunc, shapes: Shapes, dtypes: DTypes) -> float: """! @@ -122,7 +140,7 @@ def query(self, func: NameOrFunc, shapes: Shapes, dtypes: DTypes) -> float: @param shapes Tuple[Tuple[int]]: the shape of each input tensor @param dtypes Tuple[torch.dtype]: the dtype of each tensor - @return span float: the performance number + @return (fw_span, bw_span, mem) (float, float, int): the performance number """ name = func if isinstance(func, str) else func.__name__ key = self.serialize(shapes, dtypes) @@ -249,16 +267,16 @@ def load(self, file: str): # [shapes, dtypes, kwargs], # ] funcs = { + torch.nn.functional.gelu: [ + [((1024, 8, 2304),), (dtype,), {}] + ], + torch.nn.functional.linear: [ [([1024, 1, 2304], [2304, 2304]), (dtype, dtype), {}], [([1024, 4, 2304], [2304, 2304]), (dtype, dtype), {}], [([1024, 8, 2304], [2304, 2304]), (dtype, dtype), {}] ], - torch.nn.functional.gelu: [ - [((1024, 8, 2304),), (dtype,), {}] - ], - torch.nn.functional.softmax: [ [((1024, 8, 2304),), (dtype,), dict(dim=-1)] ] @@ -266,7 +284,7 @@ def load(self, file: str): for func, keys in funcs.items(): for shapes, dtypes, kwargs in keys: - db.profile(func, shapes, dtypes, backward=True, **kwargs) + db.profile(func, shapes, dtypes, **kwargs) db.dump(args.export, override=True) From 6ed2a3b117e2d805de943f31f9a007407386e5bb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Aug 2022 19:27:30 +0800 Subject: [PATCH 0932/1892] fix register bug --- cube/algorithm/ops/dimops.py | 86 +++++++++++++++++------------------ cube/graph/function/dimops.py | 2 +- cube/graph/parser/register.py | 5 +- 3 files changed, 47 insertions(+), 46 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 7a6c259c..5c6ea673 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Any from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.dimops import IRDimops, DimAnno @@ -6,14 +6,28 @@ class DimSplitEinops(GenericDistAlgo): - """ - split Einops at dimension level. + """! + Split Dimops at tensor dimension. - The sum-reduce dimension and non-reduce dimension can be splitted. + Note: for dimensions of multiple identitifers, only the first identifier + can be partitioned. - For sum-reduce dimension, the output keeps same shape but has partial-sum valmap result. - For non-reduce dimension, the output keeps same valmap but has partial output shape. - For stay-reduce dimension, this dimension is not allowed to be splitted. + Rules for identifier split: + * Sum-reduce identifier ('+'): + * For inputs/outputs that have the identifier, will be partitioned on its diemension uniformly.. + * For inputs that don't have the identifier, will be replicated + * For outputs that don't have the identifier, will be partitioned on its value uniformly. + + * Spatial identifier (''): + * For inputs/outputs that have the identifier, will be partitioned on its diemnsion uniformly. + * For inputs/outputs that don't have the identifier, will be replicated + + * Frozen identifier ('^'): + * Cannot be partitioned. + + Non-tensor will always be replicated. + + Note this rule will not correctly apply for some operators like linear: xw + b """ def __init__(self, node: IRDimops): @@ -42,8 +56,7 @@ def satisfy(self, idx: int, dim: int, num: int) -> bool: assert isinstance(node.input(idx), IRSubTensor), f"partitioning on a non-tensor input" dim = dim if dim >= 0 else dim + node.input(idx).ndims assert dim < node.input(idx).ndims, f"dimension output of boundary: {dim} >= {node.input(idx).ndims}" - # due to implementation limits, we only partition the first annotated dimension - # for inner-dimension cases. + # we only partition the first annotated dimension for inner-dimension cases. self._adim: str = node.anno.input(idx).dims[dim].identifiers[0] self._reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] dimlen = node.anno.getlen(self._adim) @@ -61,44 +74,30 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: if not satisfy: return None + def transform(tensor: Any, split_dims: List[int], is_input: bool): + # rule: non-tensor will always be replicated + if not isinstance(tensor, IRSubTensor): + return [tensor] * num + assert len(split_dims) <= 1, "find split dims ({self._adim}) more than 1" + # rule: spatial identifier ('') + if self._reduce == DimAnno.ReduceType.Dim: + return tensor.replicate(num) if len(split_dims) == 0 else tensor.split_dim(split_dims[0], num) + # rule: reduce-sum identifier ('+') + if self._reduce == DimAnno.ReduceType.Sum: + if len(split_dims) == 0: + return tensor.replicate(num) if is_input else tensor.split_val(num) + else: + return tensor.split_dim(split_dims[0], num) + raise RuntimeError(f"no matching reduce type for transform: {self._reduce}") + ins, ous = list(), list() for iidx, itensor in enumerate(node.inputs()): - if not isinstance(itensor, IRSubTensor): - ins.append([itensor] * num) - continue - shape_anno = node.anno.input(iidx) - split_dims = shape_anno.getdims(self._adim) - assert len(split_dims) <= 1, f"find split dims ({self._adim}) more than 1: {shape_anno}" - if len(split_dims) == 1: - dim = split_dims[0] - # split axis - ins.append(itensor.split_dim(dim, num)) - else: - # replicate if no split dimension of this tensor - # ins.append([itensor] * num) - # ad-hoc FIXME: for linear function Ax+b of splitting reduction dimension, b should - # be splitted by value dimension. - if self._reduce == DimAnno.ReduceType.Sum: - ins.append(itensor.split_val(num)) - else: - ins.append(itensor.replicate(num)) + split_dims = node.anno.input(iidx).getdims(self._adim) + ins.append(transform(itensor, split_dims, is_input=True)) for oidx, otensor in enumerate(node.outputs()): - if not isinstance(otensor, IRSubTensor): - ous.append([otensor] * num) - continue - shape_anno = node.anno.output(oidx) - split_dims = shape_anno.getdims(self._adim) - assert len(split_dims) <= 1, f"find split dims ({self._adim}) more than 1: {shape_anno}" - # split axis - if self._reduce == DimAnno.ReduceType.Dim: - assert len(split_dims) == 1, f"expect only one spatial dimension in output tensor but got {len(split_dims)}" - dim = split_dims[0] - ous.append(otensor.split_dim(dim, num)) - # split numerical dimension - else: - assert len(split_dims) == 0, f"expect no numerical dimension in output tensor but got {len(split_dims)}" - ous.append(otensor.split_val(num)) + split_dims = node.anno.output(oidx).getdims(self._adim) + ous.append(transform(otensor, split_dims, is_input=False)) sub_nodes = list() for nid in range(num): @@ -110,6 +109,7 @@ def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) sub_node.infer_shape() sub_nodes.append(sub_node) + return sub_nodes diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index c205dd99..696d002b 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -494,7 +494,7 @@ def __init__(self, signature: str, annos: Tuple[str], raise RuntimeError( f"no matching anno for given annos." f"op: {signature}\n" - f"inputs: {inputs}\n" + f"inputs: {tuple(t.shape for t in inputs)}\n" f"annos: {annos}\n" ) diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 6db1fc25..43c89188 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -31,7 +31,8 @@ def funcname(x: torch.Tensor, b: int = 4): def decorator(fn: Callable): if not callable(fn): raise TypeError("Expected a function") - fsig = fn.__name__ if name is None else name + fsig = fn.__name__ + op_name = name if name is not None else fsig args = inspect.signature(fn) arg_names = list(args.parameters.keys()) arg_kind = [args.parameters[name].annotation for name in arg_names] @@ -48,7 +49,7 @@ def udfop(signature: str, inputs: List[Any]): kwargs = dict() for name, val in zip(kwarg_names, kwarg_vals): kwargs[name] = val - return IRDimops(signature, [anno], tensors, **kwargs, name=fsig) + return IRDimops(signature, [anno], tensors, **kwargs, name=op_name) print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') Sign2Op.register(fsig, udfop, code) From c65ada2d6353d2cdeb10ac67a54075b515ec1b14 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 2 Aug 2022 19:27:56 +0800 Subject: [PATCH 0933/1892] gpt model --- examples/nlp/blocks/attention.py | 17 +++++++---- examples/nlp/blocks/encoder.py | 6 ++-- examples/nlp/blocks/mlp.py | 8 +++-- examples/nlp/gpt/policy/naive.py | 8 ----- examples/nlp/gpt/policy/spmd.py | 51 ++++++++++++++++++++++++++++++++ examples/nlp/gpt/train.py | 24 +++++++-------- 6 files changed, 82 insertions(+), 32 deletions(-) delete mode 100644 examples/nlp/gpt/policy/naive.py create mode 100644 examples/nlp/gpt/policy/spmd.py diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index bf9d9471..7661c6a4 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -1,19 +1,24 @@ +from typing import Optional + import torch import cube +import warnings + -@cube.graph.parser.register('L^ N E^, H+ E^, H+, H+ E^, H+, H+ E^, H+, E^ H+, E^ -> L^ N E^') +# @cube.graph.parser.register('L^ N E^, (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), E^ (h+ d^), E^ -> L^ N E^', name='self_attention') +@cube.graph.parser.register('L^ N E^, (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), E^ (h+ d^) -> L^ N E^', name='self_attention') def self_attention(query: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, out_bias: torch.Tensor, + out_proj: torch.Tensor, out_bias: None, h: int, scale: float, dropout_p: float, mask: bool = True): num_head = h L, N = query.size(0), query.size(1) dim_head = q_proj.size(0) // num_head - q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d @@ -45,7 +50,7 @@ def self_attention(query: torch.Tensor, return output -@cube.graph.parser.register('L^ N E^, L^ N E^, H+ E^, H+, H+ E^, H+, H+ E^, H+, E^ H+, E^ -> L^ N E^') +@cube.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d), E^ -> L^ N E^', name='cross_attention') def cross_attention(query: torch.Tensor, key: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, @@ -108,7 +113,9 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None # Out self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + # self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + self.out_bias = None + warnings.warn('self attention dense bias is skipped for correctness.') def forward(self, query): return self_attention( diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 1e082fd9..83645f61 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -1,7 +1,6 @@ import torch from examples.nlp.blocks.attention import MultiHeadSelfAttention from examples.nlp.blocks.mlp import MLP -import warnings class EncoderLayer(torch.nn.Module): @@ -17,18 +16,17 @@ def __init__(self, embed_dim: int, num_heads: int, self.dropout = torch.nn.Dropout(p=dropout) self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - warnings.warn('residual is disabled in encoder block') def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = self.self_attn_layer_norm(x) x = self.self_attn(x) x = self.dropout(x) - # x = x + residual + x = x + residual residual = x x = self.final_layer_norm(x) x = self.mlp(x) x = self.dropout(x) - # x = x + residual + x = x + residual return x diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py index 961f7214..7ff2de9d 100644 --- a/examples/nlp/blocks/mlp.py +++ b/examples/nlp/blocks/mlp.py @@ -1,11 +1,13 @@ import torch import cube +import warnings -@cube.graph.parser.register('L^ N E^, H+ E^, H+, E H+, E -> L^ N E') +# @cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+, E^ -> L^ N E^', name='feedforward') +@cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj1_bias: torch.Tensor, - proj2: torch.Tensor, proj2_bias: torch.Tensor, + proj2: torch.Tensor, proj2_bias: None, #torch.Tensor, dropout: float) -> torch.Tensor: x = torch.nn.functional.linear(x, proj1, proj1_bias) x = torch.nn.functional.gelu(x) @@ -22,7 +24,9 @@ def __init__(self, embed_dim: int, hidden_dim: int, dropout: float, bias=True): self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) + self.proj2_bias = None # torch.nn.Parameter(torch.empty((embed_dim,))) self.dropout = dropout + warnings.warn('feedforward output bias is skipped for correctness') def forward(self, x: torch.Tensor): x = feedforward(x, diff --git a/examples/nlp/gpt/policy/naive.py b/examples/nlp/gpt/policy/naive.py deleted file mode 100644 index 10de3596..00000000 --- a/examples/nlp/gpt/policy/naive.py +++ /dev/null @@ -1,8 +0,0 @@ -from cube.graph import IRGraph - -def PAS(graph: IRGraph, resource): - # print(graph.extra_repr()) - for node in graph.nodes(): - graph.assign(node, 0) - # print(graph.extra_repr()) - return graph \ No newline at end of file diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py new file mode 100644 index 00000000..d8b32e45 --- /dev/null +++ b/examples/nlp/gpt/policy/spmd.py @@ -0,0 +1,51 @@ +from cube.graph import IRGraph +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation + + +def PASSingle(graph: IRGraph, resource): + assert resource.ngpus == 1 + # print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + return graph + + +def PASMegatron(graph: IRGraph, resource): + tp_size = resource.ngpus + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + if isinstance(comment, str): + for sub_node in sub_nodes: + sub_node.comment = comment + assert all(isinstance(n, IRFwOperation) for n in sub_nodes), f"Fail to partition node {node}" + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return sub_nodes + + # annotating code structure -- not consider multiref on embedding weight + multirefs = [node for node in fnodes if isinstance(node, IRFwOperation) and node.name == 'multiref'][1:] + for idx in range(0, len(multirefs), 2): + multirefs[idx].comment = f'====> start of transformer {idx // 2}' + + # attention + attns = [node for node in fnodes if node.name == 'self_attention'] + for attn in attns: + tensor_parallelism(attn, idx=1, dim=0, num=tp_size) + + # feedforward + ffns = [node for node in fnodes if node.name == 'feedforward'] + for ffn in ffns: + tensor_parallelism(ffn, idx=1, dim=0, num=tp_size) + + # replicate other nodes + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: + rnodes = graph.replicate(node, times=tp_size) + for idx, rnode in enumerate(rnodes): + graph.assign(rnode, idx) + + return graph diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 6c0000fb..48af650e 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -17,7 +17,7 @@ from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary, model_summary -from examples.nlp.gpt.policy.naive import PAS +from examples.nlp.gpt.policy.spmd import PASMegatron as PAS def train(): @@ -28,9 +28,6 @@ def train(): dataloader = GPTDataLoader(batch_size) optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - print_each_rank('model weight consumpition:') - memory_summary() - model = cube.SemanticModel(model, dataloader.shapes) @cube.compile(model, dataloader, PAS=PAS, override=True) def train_iter(model, dataloader): @@ -39,32 +36,33 @@ def train_iter(model, dataloader): loss.backward() model = model.get_gen_module() + torch.distributed.barrier() + print_each_rank('model weight consumpition:', rank_only=0) + memory_summary() + CudaTimer(enable=False).warmup() - iter_num = 64 + iter_num = 40 + warmup = 8 for step in range(iter_num): - # if step == 0: # model_summary(model, next(dataloader)) - if step >= 20: + if step >= warmup: CudaTimer(enable=True).start('e2e') - - # training train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() - - if step >= 20: + if step >= warmup: CudaTimer().stop('e2e') if step == 0: print_each_rank('passed first iteration') - if (step + 1) % 10 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) memory_summary() From 5c530bf2e6f40b53871be8e4577faf32804286fd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 3 Aug 2022 15:56:26 +0800 Subject: [PATCH 0934/1892] add gpt model and policy --- examples/nlp/blocks/attention.py | 3 +- examples/nlp/blocks/mlp.py | 3 +- examples/nlp/gpt/model.py | 47 +++++++++---- examples/nlp/gpt/policy/mpmd.py | 110 +++++++++++++++++++++++++++++++ examples/nlp/gpt/train.py | 4 +- 5 files changed, 150 insertions(+), 17 deletions(-) create mode 100644 examples/nlp/gpt/policy/mpmd.py diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 7661c6a4..a1d94a34 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -115,7 +115,8 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) # self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None self.out_bias = None - warnings.warn('self attention dense bias is skipped for correctness.') + if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + warnings.warn('self attention dense bias is skipped for correctness.') def forward(self, query): return self_attention( diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py index 7ff2de9d..faa98ec2 100644 --- a/examples/nlp/blocks/mlp.py +++ b/examples/nlp/blocks/mlp.py @@ -26,7 +26,8 @@ def __init__(self, embed_dim: int, hidden_dim: int, dropout: float, bias=True): self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) self.proj2_bias = None # torch.nn.Parameter(torch.empty((embed_dim,))) self.dropout = dropout - warnings.warn('feedforward output bias is skipped for correctness') + if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + warnings.warn('feedforward output bias is skipped for correctness') def forward(self, x: torch.Tensor): x = feedforward(x, diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 73e24739..980f7dab 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -1,31 +1,45 @@ -import torch +import torch from examples.nlp.blocks.encoder import EncoderLayer - import cube class Config: - num_embeddings = 50304 - seqlen = 512 + num_embeddings = 50432 + seqlen = 1024 - # 1.7B model - embed_dim = 2304 + # 340 M model + embed_dim = 1024 layers = 8 # 24 - attention_heads = 24 + attention_heads = 16 + + # 1.3 B model + # embed_dim = 2048 + # layers = 24 + # attention_heads = 32 - # 3.6B model - # embed_dim = 3072 + # 2.6 B model + # embed_dim = 2560 # layers = 32 # attention_heads = 32 - # 7.5B model + # 6.7 B model # embed_dim = 4096 # layers = 32 + # attention_heads = 32 + + # 15 B model + # embed_dim = 5120 + # layers = 48 # attention_heads = 36 + # 39 B model + # embed_dim = 8192 + # layers = 48 + # attention_heads = 64 + attn_hidden_dim = embed_dim ffn_hidden_dim = embed_dim * 4 dropout = 0.0 @@ -39,7 +53,8 @@ def __init__(self): super().__init__() cfg = Config() - self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) + # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) + self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.embed_dim)) self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) self.embed_dropout = torch.nn.Dropout() @@ -54,17 +69,23 @@ def __init__(self): def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): - embed = self.embed(input_ids) + # embed = self.embed(input_ids) + embed = torch.nn.functional.embedding( + input_ids, self.embedw, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False + ) pos_embed = self.position(position_ids) embed = embed + pos_embed embed = self.embed_dropout(embed) enc = embed.transpose(0, 1) for layer in self.layers: + cube.runtime.function.anchor('transformer start') enc = layer(enc) enc = self.final_layernorm(enc) - logits = torch.nn.functional.linear(enc, self.embed.weight) + # logits = torch.nn.functional.linear(enc, self.embed.weight) + logits = torch.nn.functional.linear(enc, self.embedw) # simplified loss = torch.sum(logits) return loss diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py new file mode 100644 index 00000000..02e79ead --- /dev/null +++ b/examples/nlp/gpt/policy/mpmd.py @@ -0,0 +1,110 @@ +from typing import List, Tuple +import numpy as np + +from cube.graph import IRGraph +from cube.graph.function.anchor import IRGraphAnchor +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.schedule.sched1f1b import IRSchedule1F1B + + +def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + + e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: + ( + ( (0,1,2), (3,4,5) ), + ( (0,3), (2,5), (3,6) ), + ) + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + + +def PASRoundRobin(graph: IRGraph, resource): + """ + roundrobin scheduling + """ + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # group to transformer layers + transformers: List[List[IRFwOperation]] = [] + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + start = idx if lid != 0 else 0 + end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) + transformers.append(fnodes[start:end]) + for lid in range(len(transformers) - 1): + if transformers[lid][-1].name == 'multiref': + node = transformers[lid].pop() + transformers[lid+1].insert(0, node) + + for lid, transformer in enumerate(transformers): + stage_id = lid % resource.ngpus + print(f'assigning {lid} transformer to stage {stage_id}') + for node in transformer: + graph.assign(node, stage_id) + + for node in graph.nodes(): + if len(node.device) == 0: + graph.assign(node, 0) + + # print(graph.extra_repr()) + return graph + + +def PAS1F1B(graph: IRGraph, resource): + """ + 1F1B scheduling + """ + num_stage = resource.ngpus + num_microbatch = resource.ngpus + _, stage_mesh = _create_mesh(resource.ngpus, (num_stage, 1)) + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # group to transformer layers + transformers: List[List[IRFwOperation]] = [] + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + start = idx if lid != 0 else 0 + end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) + transformers.append(fnodes[start:end]) + for lid in range(len(transformers) - 1): + if transformers[lid][-1].name == 'multiref': + node = transformers[lid].pop() + transformers[lid+1].insert(0, node) + + # staging + nlayer_per_stage = (len(transformers) // resource.ngpus) + for lid, fnodes in enumerate(transformers): + stage_id = min(lid // nlayer_per_stage, num_stage-1) + print(f'assigning {lid}-th transformer layter to stage {stage_id}') + for fnode in fnodes: + graph.assign(fnode, stage_id) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + graph.assign(node, 0) + + schedule = IRSchedule1F1B(num_microbatch, stage_mesh, recompute=False) + graph.schedule_plan = schedule + return graph diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 48af650e..dc9b7e78 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -2,7 +2,7 @@ example: OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ examples/nlp/gpt/train.py """ @@ -17,7 +17,7 @@ from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary, model_summary -from examples.nlp.gpt.policy.spmd import PASMegatron as PAS +from examples.nlp.gpt.policy.mpmd import PASRoundRobin as PAS def train(): From 03ba45415a9c2e99ee657b8c12363c0d754c67b4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 3 Aug 2022 15:57:18 +0800 Subject: [PATCH 0935/1892] add schedule --- cube/graph/graph.py | 173 ++++++++++++++++++++++---------------------- 1 file changed, 88 insertions(+), 85 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 66cd79c1..a115c570 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -446,7 +446,7 @@ def from_logic_graph(nodes: List[IRCell], if isinstance(ftensor, IRFullTensor): producers[ftensor] = node for ftensor, cnodes in consumers.items(): - if len(cnodes) == 1: continue + if len(cnodes) == 1 or ftensor.is_param(): continue itensors = [ftensor.like() for _ in range(len(cnodes))] for itensor, consumer in zip(itensors, cnodes): while ftensor in consumer.inputs(): @@ -458,6 +458,7 @@ def from_logic_graph(nodes: List[IRCell], multiref.set_output(idx, itensor) multiref.infer_shape() idx = nodes.index(producers[ftensor]) + 1 if ftensor in producers else 0 + # idx = nodes.index(cnodes[0]) nodes.insert(idx, multiref) # instantiate graph inputs / outputs @@ -698,6 +699,89 @@ def happen_before(self, node1: IRCell, node2: IRCell, skip=None) -> bool: return True return False + def depends(self, pre_node: IRCell, post_node: IRCell) -> bool: + """! + Check whether pre_node has dataflow dependency on post_node: + pre_node -> post_node + + @param pre_node: the happen before node + @param post_node: the happen after node + + @return ret bool: True if post_node depends on pre_node on dataflow, otherwise False. + """ + itensors = [t for t in post_node.inputs() if isinstance(t, IRSubTensor)] + for otensor in pre_node.outputs(): + if not isinstance(otensor, IRSubTensor): continue + for itensor in itensors: + if otensor.overlap(itensor): + return True + return False + + def schedule(self, node1: IRCell, action: str, node2: IRCell) -> bool: + """! + Schedule node1 and node2 based on the action + + The node2 will keep unchanged in the sequence and schedule will perform + on node1. + + @param node1 IRCell + @param node2 IRCell + @param action str: + 'after': fixed node2 and schedule node1 after node2 in the sequence. + 'before': fixed node2 and schedule node1 before node2 in the sequence. + + @return success bool: True if the scheduling success otherwise False. + """ + idx1 = self._nodes.index(node1) + idx2 = self._nodes.index(node2) + # node2 -> node1 + if action == 'after': + if idx2 < idx1: + return True + for idx in range(idx1+1, idx2+1): + if self.depends(node1, self._nodes[idx]): + return False + self.detach(node1) + self.attach(node1, idx2) + return True + # node1 -> node2 + if action == 'before': + if idx1 < idx2: + return True + for idx in range(idx2, idx1): + if self.depends(self._nodes[idx], node1): + return False + self.detach(node1) + self.attach(node2, idx2) + return True + raise KeyError(f"Unknown scheduling action {action}") + + @staticmethod + def legal_schedule(seq: List[IRCell], integrity_check=False): + """ + Check whether seq satisfies topological order. + + @note: this functionality is not enabled due to predecessor and succesor + functionality. + + @param seq List[IRCell]: the nodes in scheudled order + @param integrity_check bool: + If true, performs additional integrity check that requires + all the nodes in predecessor and successor of a node should + appear in the sequence. + + @return valid bool: True for satisfying topo order, otherwise False. + """ + for index, node in enumerate(seq): + for pre in node.predecessors(): + if pre in seq: + pre_idx = seq.index(pre) + if pre_idx >= index: + return False + elif integrity_check: + return False + return True + def add_schedule(self, nodes: List[IRCell]) -> bool: """ Add node happen before dependencies according to nodes list order @@ -716,6 +800,9 @@ def add_schedule(self, nodes: List[IRCell]) -> bool: post.add_predecessor(input_index=-1, cell=prev) return True + + # ================= Other optimizations ================== + def recompute(self, nodes: List[IRFwOperation]) -> bool: """! Recompute a set of nodes. The forward nodes will be assigned with a unique @@ -731,90 +818,6 @@ def recompute(self, nodes: List[IRFwOperation]) -> bool: fnode.recompute = recompute_group_id return True - def set_order(self, seq: List[IRCell]): - """ - Set a topological order for IRGraph, which requires seq: - - 1). The set of nodes in seq must be same with this IRGraph - 2). Staisfies topological order - - Returns: - True if set succesfully, False not. - """ - for node in seq: - if node not in self.nodes(): - return False - if len(seq) != len(self.nodes()): - return False - if not IRGraph.check_legal_order(seq, integrity_check=True): - return False - self._nodes = seq - return True - - def partial_set_order(self, seq: List[IRCell], eager=True): - """ - Set a partial topological order for IRGrah. - The remaining nodes will be automatically inserted to - make the full legal sequence. - - In most of the cases, `eager=True` has better performance. - - Args: - seq: partial scheduling sequence - eager (default True): - if True, the remaining nodes are inserted once it is ready - if Flase, the remaining nodes are inserted only when it is needed. - - Returns: - True if set succesfully, False not. - """ - seq = copy.copy(seq) - for node in seq: - if node not in self.nodes(): - return False - if not IRGraph.check_legal_order(seq, integrity_check=False): - return False - remain: List[IRCell] = [node for node in self.nodes() if node not in seq] - for node in remain: - if eager: - pre_indices = [seq.index(pre) for pre in node.predecessors()] - if len(pre_indices) == 0: - index = 0 - else: - index = max(pre_indices) + 1 - else: - suc_indices = [seq.index[suc] for suc in node.successors()] - index = min(suc_indices) - seq.insert(index, node) - self._nodes = seq - return True - - @staticmethod - def check_legal_order(seq: List[IRCell], integrity_check=False): - """ - Check whether seq satisfies topological order. - - Args: - seq: List of IRCell - integrity_check: - If true, performs additional integrity check that requires - all the SUs in predecessor and successor of a SU should - appear in the sequence. - - Returns: - Boolean: True for satisfying topo order, otherwise False. - """ - #TODO: check no new operators are created (including replicate) - for index, node in enumerate(seq): - for pre in node.predecessors(): - if pre in seq: - pre_idx = seq.index(pre) - if pre_idx >= index: - return False - elif integrity_check: - return False - return True - def __repr__(self): dscp = f"Graph{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" return dscp From ab19409db3d84ee6a6d2ad25a17ec1e448794ad4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 3 Aug 2022 20:43:13 +0800 Subject: [PATCH 0936/1892] fix empty adapter --- cube/ir/adapter/adapter.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index 8455a242..154565b0 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -110,8 +110,8 @@ def dispatch(self, devid: int, for_mirror=True): outputs.append(otensor) # insert identity prims if len(prims) == 0: - assert len(inputs) == len(outputs) and all(itensor in outputs for itensor in inputs), \ - "input/output tensor not match for empty prims" + assert all(otensor in inputs for otensor in outputs), \ + "output tensor not apear in input tensors for empty prims" for itensor in inputs: prims.append(IdentityPrim(itensor)) # dispatch @@ -127,6 +127,26 @@ def dispatch(self, devid: int, for_mirror=True): IRCell.make_pair(fadapter, badapter) return fadapter + @staticmethod + def merge(adapters: List): + """! + Merge adapters to one adapter + """ + adapters : List[IRAdapter] = adapters + assert all(isinstance(n, IRAdapter) for n in adapters) + # TODO: check recompute consistency + itensors = [] + otensors = [] + prims = [] + for adapter in adapters: + itensors += adapter.inputs() + otensors += adapter.outputs() + prims += adapter.prims + adapter = IRAdapter(itensors, otensors) + adapter.prims = prims + return adapter + + def __repr__(self): return f'Adapter-{self._id}{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' From 6d84d75377cf50b9c2863f1338105c5860ac9d60 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 10:12:28 +0800 Subject: [PATCH 0937/1892] schedule to be generalize --- cube/graph/schedule/sched1f1b.py | 98 +++++++++---------------- cube/graph/schedule/strategy.py | 118 +++++++++++++++++++++---------- 2 files changed, 115 insertions(+), 101 deletions(-) diff --git a/cube/graph/schedule/sched1f1b.py b/cube/graph/schedule/sched1f1b.py index b7825d0d..d3cb8804 100644 --- a/cube/graph/schedule/sched1f1b.py +++ b/cube/graph/schedule/sched1f1b.py @@ -1,7 +1,8 @@ -from typing import Dict, Tuple -from cube.ir.adapter.adapter import IRAdapter +from typing import Dict, Tuple, Optional + from cube.ir.cten import IRCell +from cube.ir.adapter.adapter import IRAdapter from cube.graph.graph import IRGraph, IRSegment from cube.graph.schedule import IRScheduleStrategy @@ -11,82 +12,51 @@ class IRSchedule1F1B(IRScheduleStrategy): """ 1F1B Scheduling - This requires a micro-batch can be grouped into continguous segments - which are placed on distinct device groups (refered as a stage): + This treats model as a linear graph which can be + grouped into continous stages. [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] [Recv-Backward] Backward-Segment [Send-Backward] """ - def __init__(self, num_microbatch: int, devmesh: Tuple[Tuple[int]], recompute=False): - super().__init__(num_microbatch, devmesh) + def __init__(self, graph, nmicros: int, devmesh: Tuple[Tuple[int]]): + super().__init__(graph, nmicros, devmesh) self.signature = 'cube.runtime.schedule.Schedule1F1B.run' # forward body - self.segment = dict() + self.segment: Dict[int, IRSegment] = dict() # forward send - self.sfadapter = dict() + self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() # forward recv - self.rfadapter = dict() + self.rfadapter: Dict[int, Optional[IRAdapter]] = dict() # backard send - self.sbadapter = dict() + self.sbadapter: Dict[int, Optional[IRAdapter]] = dict() # backward recv - self.rbadapter = dict() + self.rbadapter: Dict[int, Optional[IRAdapter]] = dict() # num_stage - self.num_stages = len(devmesh) + self.num_stages: int = len(devmesh) # stage id - self.stage_id = dict() + self.stage_id: Dict[int, int] = dict() # recompute - self.recompute = recompute + self.recompute = False - def apply(self, graph: IRGraph) -> IRGraph: - graph = IRSchedule1F1B.segmentation(graph, self.devmesh) - for stage_id, devices in enumerate(self.devmesh): + def apply(self) -> IRGraph: + self.segmentation() + for gid, devices in enumerate(self.devmesh): for devid in devices: - nodes = [n for n in graph.nodes() if devid in n.device] + # forward recv + self.rfadapter[devid] = None if gid == 0 else self.cross_groups[gid-1] # forward body - fsegments = [seg for seg in nodes if isinstance(seg, IRSegment) and seg.forward] - assert len(fsegments) == 1, "find more than one segment." - fsegment = fsegments[0] - self.segment[devid] = fsegment - fidx = nodes.index(fsegment) - bidx = nodes.index(fsegment.mirror) - # adapters - adapters = [adapter for adapter in nodes if isinstance(adapter, IRAdapter)] - # forward sends - forward_sends = [n for n in adapters if n.forward and nodes.index(n) > fidx] - if stage_id == self.num_stages - 1: - assert len(forward_sends) == 0, f"stage: {stage_id}: last stage should not send forward outputs" - self.sfadapter[devid] = None - else: - assert len(forward_sends) == 1, f"stage: {stage_id}: last stage should not send forward outputs" - self.sfadapter[devid] = forward_sends[0] - # forward recvs - forward_recvs = [n for n in adapters if n.forward and nodes.index(n) < fidx] - if stage_id == 0: - assert len(forward_recvs) == 0, f"stage: {stage_id}: first stage should not recv inputs" - self.rfadapter[devid] = None - else: - assert len(forward_recvs) == 1, f"stage: {stage_id}: non-first stage should recv 1 inputs" - self.rfadapter[devid] = forward_recvs[0] - # backward sends - backward_sends = [n for n in adapters if not n.forward and nodes.index(n) > bidx] - if stage_id == 0: - assert len(backward_sends) == 0, f"stage: {stage_id}: first stage should not send back gradient" - self.sbadapter[devid] = None - else: - assert len(backward_sends) == 1, f"stage: {stage_id}: non-first stage should not send back gradient" - self.sbadapter[devid] = backward_sends[0] - # backward recvs - backward_recvs = [n for n in adapters if not n.forward and nodes.index(n) < bidx] - if stage_id == self.num_stages - 1: - assert len(backward_recvs) == 0, f"stage: {stage_id}: last stage should not recv gradient" - self.rbadapter[devid] = None - else: - assert len(backward_recvs) == 1, f"stage: {stage_id}: non-last stage should recv 1 gradient" - self.rbadapter[devid] = backward_recvs[0] + self.segment[devid] = self.inner_groups[gid] + # forward send + if gid == len(self.devmesh)-1: assert self.cross_groups[gid] is None + self.sfadapter[devid] = self.cross_groups[gid] + # backward recv + self.rbadapter[devid] = None if gid == len(self.devmesh)-1 else self.sfadapter[devid].mirror + # backward send + self.sbadapter[devid] = None if gid == 0 else self.rfadapter[devid].mirror # stage id - self.stage_id[devid] = stage_id - return graph + self.stage_id[devid] = gid + return self.graph def kwargs(self, devid: int) -> Dict[str, IRCell]: """ @@ -101,15 +71,17 @@ def kwargs(self, devid: int) -> Dict[str, IRCell]: dataloader = 'dataloader', stage_id = self.stage_id[devid], num_stages = self.num_stages, - num_microbatch = self.num_microbatch, + num_microbatch = self.nmicros, recompute = self.recompute ) def __repr__(self) -> str: dscp = '' for mesh in self.devmesh: - dscp += (f"1F1B-Schedule-stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" - f" segment = {self.segment[mesh[0]]}\n" + devid = mesh[0] + segment = self.segment[devid].to_str(skip_attr=True) if self.segment[mesh[0]] else None + dscp += (f"1F1B Schedule: Stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" + f" segment = {segment}\n" f" send-fw = {self.sfadapter[mesh[0]]}\n" f" recv-fw = {self.rfadapter[mesh[0]]}\n" f" send-bw = {self.sbadapter[mesh[0]]}\n" diff --git a/cube/graph/schedule/strategy.py b/cube/graph/schedule/strategy.py index 5c254e44..852fc889 100644 --- a/cube/graph/schedule/strategy.py +++ b/cube/graph/schedule/strategy.py @@ -1,5 +1,5 @@ -from typing import Tuple, Dict, Any -from cube.graph.graph import IRGraph +from typing import Tuple, Dict, Any, List +from cube.graph.graph import IRGraph, IRSegment from cube.ir.adapter.adapter import IRAdapter from cube.ir.cten import IRCell from cube.ir.operator import IRFwOperation @@ -7,9 +7,12 @@ class IRScheduleStrategy: - def __init__(self, num_microbatch: int, devmesh: Tuple[Tuple[int]]) -> None: - self.num_microbatch = num_microbatch - self.devmesh = devmesh + def __init__(self, graph: IRGraph, nmicros: int, devmesh: Tuple[Tuple[int]]) -> None: + self.graph : IRGraph = graph + self.nmicros : int = nmicros + self.devmesh: Tuple[Tuple[int]] = devmesh + self.inner_groups: List[IRSegment] = [None] * len(devmesh) + self.cross_groups: List[IRAdapter] = [None] * len(devmesh) self.signature: str = '' def apply(self, graph: IRGraph) -> IRGraph: @@ -18,38 +21,77 @@ def apply(self, graph: IRGraph) -> IRGraph: def kwargs(self, device: int) -> Dict[str, Any]: raise NotImplementedError - @staticmethod - def segmentation(graph: IRGraph, devmesh: Tuple[Tuple[int]]) -> IRGraph: - """ - Utilities for grouping operators into segments with device mesh + def segmentation(self): + """! + Group operators into segments corresponding to devmesh. + + A greedy grouping is applied for each group given the device mesh. + The non-differentiable adapters need to be moved at the boundary + of device mesh, as the cross group communication. """ - stages = [[] for _ in range(len(devmesh))] - for node in graph.nodes(): - for meshid, devices in enumerate(devmesh): - if set(node.device).issubset(set(devices)): - stages[meshid].append(node) - break + def differientiable(node: IRCell) -> bool: + return isinstance(node, IRFwOperation) or \ + (isinstance(node, IRAdapter) and node.forward and node.differentiable) + + devmesh = self.devmesh + inner_groups: List[List[IRCell]] = [[] for _ in range(len(devmesh))] + cross_groups: List[List[IRAdapter]] = [[] for _ in range(len(devmesh))] + sid = 0 + for node in self.graph.nodes(): + if not (isinstance(node, (IRFwOperation, IRAdapter))): + continue + devs = set(node.device) + if differientiable(node): + while sid < len(devmesh) and not devs.issubset(devmesh[sid]): + sid += 1 + assert sid < len(devmesh), f"invalid stategy with graph placement" + inner_groups[sid].append(node) + else: + if not (isinstance(node, IRAdapter) and node.forward): + continue + assert not devs.issubset(devmesh[sid]), f"find a non-differentiable adapter in devmesh: {devmesh[sid]}" + cross_mesh = devmesh[sid] + devmesh[sid+1] if sid < len(devmesh) - 1 else devmesh[sid] + assert devs.issubset(set(cross_mesh)) + cross_groups[sid].append(node) + + # move non-differentiable adapter to the boundary of groups + for igroup, cgroup in zip(inner_groups, cross_groups): + if len(igroup) == 0: + print('warning: find a group with no operator') + continue + last_node: IRCell = igroup[-1] + for fadapter in cgroup[::-1]: + success = self.graph.schedule(fadapter, 'after', last_node) + if fadapter.mirror is not None and last_node.mirror is not None: + success = self.graph.schedule( + fadapter.mirror, 'before', last_node.mirror + ) + if not success: + raise RuntimeError("Fail to schedule non-differentiable adapter to group boundaries") + # grouping - for stage in stages: - fconsecutive, bconsecutive = [], [] - for node in stage: - if isinstance(node, IRFwOperation) or (isinstance(node, IRAdapter) and node.forward): - fconsecutive.append(node) - if node.mirror: - bconsecutive.append(node.mirror) - else: - assert len(fconsecutive) == len(bconsecutive) or len(bconsecutive) == 0, 'mismatch number of forward and backward operators.' - if len(fconsecutive) != 0: - fsegment = graph.group(fconsecutive) - if len(bconsecutive) != 0: - bsegment = graph.group(bconsecutive[::-1]) - IRCell.make_pair(fsegment, bsegment) - fconsecutive, bconsecutive = [], [] - return graph - - @staticmethod - def merging(graph: IRGraph) -> IRGraph: - """ - merge the adapters into one - """ - pass + for gid in range(len(devmesh)): + # group computation groups + igroup = inner_groups[gid] + if len(igroup) != 0: + fsegment = self.graph.group(igroup) + bnodes = [n.mirror for n in igroup[::-1] if n.mirror is not None] + if len(bnodes) != 0: + bsegment = self.graph.group(bnodes) + IRCell.make_pair(fsegment, bsegment) + self.inner_groups[gid] = fsegment + else: + self.inner_groups[gid] = None + # merge cross communication adapters + cgroup = cross_groups[gid] + if len(cgroup) == 1: + self.cross_groups[gid] = cgroup[0] + elif len(cgroup) > 1: + fadapter = IRAdapter.merge(cgroup) + bnodes = [n.mirror for n in igroup[::-1] if n.mirror is not None] + if len(bnodes) != 0: + badapter = IRAdapter.merge(bnodes) + IRCell.make_pair(fadapter, badapter) + self.cross_groups[gid] = fadapter + else: + self.cross_groups[gid] = None From 0627d92757c405dca0eaee33b4b586c9096933af Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 10:17:01 +0800 Subject: [PATCH 0938/1892] fix .eval() bug --- cube/execplan/execplan.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index e60de05b..d9173d2f 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -14,7 +14,11 @@ def __init__(self, graph: IRGraph): assert isinstance(graph, IRGraph), "Expected an IRGraph" self._graph = graph self._seq: Dict[int, List[IRCell]] = dict() - self._inference_only = not any(isinstance(n, IRBpOperation) for n in graph.nodes()) + self._inference_only = not any( + isinstance(n, IRBpOperation) or \ + (isinstance(n, IRAdapter) and not n.forward) or \ + (isinstance(n, IRSegment) and not n.forward) for n in graph.nodes() + ) # execution sequence for each device for node in graph.nodes(): From ff8c0a4ec057c0d0dfd0a5ab9b4ea99ec8a377f9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 10:18:43 +0800 Subject: [PATCH 0939/1892] schedule plan --- cube/codegen/codegen.py | 4 +- cube/compiler.py | 11 ++-- cube/graph/graph.py | 138 +++++++++++++++++++--------------------- 3 files changed, 73 insertions(+), 80 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index b8f065c7..952eaaa0 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -640,8 +640,8 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # body code if len(device_nodes) == 0: fb.insert_body('pass') - elif self.execplan.graph.schedule_plan: - code = self.emit_schedule_plan(self.execplan.graph.schedule_plan, device) + elif self.execplan.graph.sched: + code = self.emit_schedule_plan(self.execplan.graph.sched, device) fb.insert_body(code) else: for i, node in enumerate(device_nodes): diff --git a/cube/compiler.py b/cube/compiler.py index 26996ff1..1e1e16d5 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -171,9 +171,12 @@ def decorator(fn: Callable) -> Callable: # generate adapter graph = IRAdapterGener.gen(graph) - if graph.schedule_plan: - graph = graph.schedule_plan.apply(graph) - print(graph.schedule_plan) + if graph.sched is not None: + start = time.time() + graph.sched.apply() + span = time.time() - start + print('> planpass on applying schedule strategy: {:.2f} s'.format(span)) + print(graph.sched) # to execution plan execplan = ExecutionPlan(graph) @@ -185,7 +188,7 @@ def decorator(fn: Callable) -> Callable: print('> planpass on diff-fusion operations: {:.2f} s'.format(span)) # plan pass for computation grouping - if not graph.schedule_plan: + if not graph.sched: start = time.time() execplan = Grouping.apply(execplan) span = time.time() - start diff --git a/cube/graph/graph.py b/cube/graph/graph.py index a115c570..3caa3ef2 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,7 +7,7 @@ will be inserted at scheduling time. """ -from typing import Any, Union, Tuple, List, Optional, Dict +from typing import Union, Tuple, List, Optional, Dict import copy from cube.graph.function.function import MultiRef @@ -84,9 +84,14 @@ def dispatch(self, devid: int, for_mirror=True) -> Optional[IRCell]: IRCell.make_pair(fseg, bseg) return fseg - def __repr__(self): + def to_str(self, skip_attr: bool = False) -> str: name = ('f' if self.forward else 'b') + 'Segment' - return f'{name}{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' + inputs = tuple(t for t in self.inputs() if not (t.is_param() and skip_attr)) + outputs = tuple(t for t in self.outputs() if not (t.is_param() and skip_attr)) + return f'{name}{self._id}-{self.device}(inputs={inputs}, outputs={outputs})' + + def __repr__(self): + return self.to_str() def extra_repr(self) -> str: dscp = repr(self) @@ -111,7 +116,7 @@ def __init__(self, self._parameters = list() self._full_tensors: Dict[int, IRFullTensor] = dict() - self._schedule_strategy = None + self._sched = None # the schedule strategy if inputs is None: inputs = IRGraph.get_inputs(nodes) @@ -148,14 +153,6 @@ def __init__(self, self.reset_dependency() - @property - def schedule_plan(self) -> Optional[Any]: - return self._schedule_strategy - - @schedule_plan.setter - def schedule_plan(self, val: Optional[Any]): - self._schedule_strategy = val - def reset_dependency(self): """ Reset the node dataflow dependency @@ -227,13 +224,17 @@ def __call__(self, *args): return self.forward(*args) def segment(self, nodes: List[IRCell]) -> IRSegment: - """ + """! Create a segment (sub-graph) with part of the nodes. + Nodes are allowed to be on different devices. + The grouped segement will not add into graph.nodes(). - Return: - IRSegment + @param nodes List[IRCell]: the subset nodes of this graph + + @return segment IRSegment: the grouped segment. """ inputs, outputs = [], [] + itdevs, otdevs = dict(), dict() for node in nodes: assert not isinstance(node, IRSegment), 'A segment cannot be in other segments' # update inputs @@ -242,9 +243,11 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: producers = [p for p in itensor.parent.producers if set(p.device).issubset(set(node.device))] # no producer means a weight or cross device-group if len(producers) == 0 or any(p not in nodes for p in producers): - # FIXME: itensor should also consider device difference - if itensor not in inputs: + if itensor not in itdevs: + itdevs[itensor] = [] + if itensor.device not in itdevs[itensor]: inputs.append(itensor) + itdevs[itensor].append(itensor.device) # update outputs otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] for otensor in otensors: @@ -253,18 +256,25 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: consumers = [c for c in otensor.parent.consumers if set(c.device).issubset(set(node.device))] # no consumer usually means the loss or cross device-group if len(consumers) == 0 or any(c not in nodes for c in consumers): - # FIXME: otensor should also consider device difference - if otensor not in outputs: + if otensor not in otdevs: + otdevs[otensor] = [] + if otensor.device not in otdevs[otensor]: outputs.append(otensor) + otdevs[otensor].append(otensor.device) segment = IRSegment(nodes, inputs, outputs) return segment def group(self, nodes: List[IRCell]) -> IRSegment: - """ - Group consecutive nodes into IRSegment. + """! + Group consecutive nodes into IRSegment. the grouped segment will + replace the nodes in the graph. - Currently this interface will break the dependency, + Note: Currently this interface will break the dependency, it can only be used after user policy + + @param nodes List[IRCell]: the consecutive node subset of this graph + + @return segment IRSegment: the grouped segment """ allnodes = self.nodes() indices = [allnodes.index(n) for n in nodes] @@ -592,59 +602,23 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], fnode.mirror.device = node.device return fnodes - def merge(self, nodes: List[IRCell], target_node: IRCell): - """ - Merge consecutive nodes in the graph to the target_node. - Note corresponding mirror nodes (if have) will also be merged. + def replace(self, old_nodes: List[IRCell], new_nodes: List[IRCell]): + """! + Replace nodes with node. + + Note we don't check semantic correctness for the replacement. - We don't check computation equivalence between nodes and target_node. + @param old_nodes List[IRCell]: nodes to be replaced + @param new_nodes List[IRCell]: nodes to replace in - Merge requires nodes are consecutive in the graph sequence. + @return True """ - if not isinstance(target_node, IRCell): - raise TypeError("Expected target node to be IRCell") - if target_node in self.nodes(): - raise ValueError("Target node is already in the graph") - for node in nodes: - if node not in self.nodes(): - raise KeyError(f"node {node} is not in the graph") - indices = [self.nodes().index(node) for node in nodes] - # consecutive - if max(indices) - min(indices) != len(indices) - 1: - return False - index = min(indices) - # update forward - for node in nodes: - self.detach(node) - self.attach(target_node, index) - # update backward - if all([isinstance(node.mirror, IRCell) for node in nodes]): - bidx = len(self.nodes()) - for node in nodes: - idx = self.detach(node.mirror) - bidx = min(idx, bidx) - if target_node.mirror is None: - if not isinstance(target_node, IRFwOperation): - raise RuntimeError("target node is not FwOp and doens't have mirror node") - target_node.gen_backward() - self.attach(target_node.mirror, bidx) - elif all([isinstance(node.mirror, None) for node in nodes]): - pass - else: - raise ValueError("nodes should have nothing-or-all mirror nodes") - # update weights - updated = set() - for node in nodes + [target_node]: - for input in node.inputs(): - if not isinstance(input, IRSubTensor): - continue - for fnode in input.parent.consumers: - bnode = fnode.mirror - if isinstance(bnode, IRBpOperation) and fnode._id not in updated: - idx = self.detach(bnode) - bnode.update() - self.attach(bnode, idx) - updated.add(fnode._id) + idx = len(self._nodes) + for old_node in old_nodes: + oidx = self.detach(old_node) + idx = min(oidx, idx) + for new_node in new_nodes[::-1]: + self.attach(new_node, idx) return True ## Spatial Primitives ## @@ -752,10 +726,26 @@ def schedule(self, node1: IRCell, action: str, node2: IRCell) -> bool: if self.depends(self._nodes[idx], node1): return False self.detach(node1) - self.attach(node2, idx2) + self.attach(node1, idx2) return True raise KeyError(f"Unknown scheduling action {action}") - + + @property + def sched(self): + """! + Return schedule plan for the execution. + """ + return self._sched + + @sched.setter + def sched(self, strategy): + """! + Set schedule plan for the execution. + + @param strategy IRScheduleStrategy: the schedule strategy instance + """ + self._sched = strategy + @staticmethod def legal_schedule(seq: List[IRCell], integrity_check=False): """ From 232654c262991692dce249f00fc540f00410d284 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 10:19:36 +0800 Subject: [PATCH 0940/1892] add gpt model of Megatron TP+PP --- examples/nlp/gpt/model.py | 2 +- examples/nlp/gpt/policy/mpmd.py | 103 +++++++++++++++++++++++++------- examples/nlp/gpt/train.py | 2 +- 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 980f7dab..5eade94b 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -12,7 +12,7 @@ class Config: # 340 M model embed_dim = 1024 - layers = 8 # 24 + layers = 4 # 24 attention_heads = 16 # 1.3 B model diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index 02e79ead..71400eb7 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -3,6 +3,7 @@ from cube.graph import IRGraph from cube.graph.function.anchor import IRGraphAnchor +from cube.ir.cten import IRCell from cube.ir.operator import IRDataOperation, IRFwOperation from cube.graph.schedule.sched1f1b import IRSchedule1F1B @@ -35,12 +36,7 @@ def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: return tuple(outputs) -def PASRoundRobin(graph: IRGraph, resource): - """ - roundrobin scheduling - """ - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - +def _group_to_transformers(fnodes) -> List[List[IRCell]]: # group to transformer layers transformers: List[List[IRFwOperation]] = [] anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] @@ -54,6 +50,36 @@ def PASRoundRobin(graph: IRGraph, resource): if transformers[lid][-1].name == 'multiref': node = transformers[lid].pop() transformers[lid+1].insert(0, node) + return transformers + +# ========================= parallelisms ================================= + +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# ========================= parallelisms ================================= + +def PASRoundRobin(graph: IRGraph, resource): + """ + roundrobin scheduling + """ + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # group to transformer layers + transformers = _group_to_transformers(fnodes) for lid, transformer in enumerate(transformers): stage_id = lid % resource.ngpus @@ -64,8 +90,7 @@ def PASRoundRobin(graph: IRGraph, resource): for node in graph.nodes(): if len(node.device) == 0: graph.assign(node, 0) - - # print(graph.extra_repr()) + return graph @@ -74,24 +99,13 @@ def PAS1F1B(graph: IRGraph, resource): 1F1B scheduling """ num_stage = resource.ngpus - num_microbatch = resource.ngpus + num_microbatch = resource.ngpus * 8 _, stage_mesh = _create_mesh(resource.ngpus, (num_stage, 1)) fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] # group to transformer layers - transformers: List[List[IRFwOperation]] = [] - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - fnodes[idx-1].comment = f'===> start of transformer layer {lid}' - start = idx if lid != 0 else 0 - end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) - transformers.append(fnodes[start:end]) - for lid in range(len(transformers) - 1): - if transformers[lid][-1].name == 'multiref': - node = transformers[lid].pop() - transformers[lid+1].insert(0, node) + transformers = _group_to_transformers(fnodes) # staging nlayer_per_stage = (len(transformers) // resource.ngpus) @@ -105,6 +119,49 @@ def PAS1F1B(graph: IRGraph, resource): if isinstance(node, IRDataOperation): graph.assign(node, 0) - schedule = IRSchedule1F1B(num_microbatch, stage_mesh, recompute=False) - graph.schedule_plan = schedule + strategy = IRSchedule1F1B(graph, num_microbatch, stage_mesh) + graph.sched = strategy + return graph + + +def PASMegatron(graph: IRGraph, resource): + """ + 1F1B scheduling + """ + dp_size = 1 + tp_size = 2 + pp_size = resource.ngpus // (dp_size * tp_size) + num_microbatch = resource.ngpus + + # device mesh + dp_groups, pp_groups, tp_groups = \ + _create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) + print(f'dp groups: {dp_groups}') + print(f'pp groups: {pp_groups}') + print(f'tp groups: {tp_groups}') + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # group to transformer layers + transformers = _group_to_transformers(fnodes) + + # staging + nlayer_per_stage = (len(transformers) // pp_size) + for lid, fnodes in enumerate(transformers): + sid = min(lid // nlayer_per_stage, pp_size-1) + print(f'assigning {lid}-th transformer layer to stage {sid}: {tp_groups[sid]}') + for fnode in fnodes: + if fnode.name == 'self_attention': + _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) + elif fnode.name == 'feedforward': + _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) + else: + _replica(graph, fnode, tp_groups[sid]) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + _replica(graph, node, tp_groups[0]) + + strategy = IRSchedule1F1B(graph, num_microbatch, tp_groups) + graph.sched = strategy return graph diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index dc9b7e78..307bef12 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -17,7 +17,7 @@ from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary, model_summary -from examples.nlp.gpt.policy.mpmd import PASRoundRobin as PAS +from examples.nlp.gpt.policy.mpmd import PASMegatron as PAS def train(): From 5214604a048314faaab31454c26d4664e4091b9b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 11:43:15 +0800 Subject: [PATCH 0941/1892] fix dtype infer bug --- cube/compiler.py | 11 +++++++---- cube/graph/parser/parser.py | 5 +++-- cube/ir/cten.py | 36 ++++++++++++++++++++++++++++-------- cube/ir/tensor.py | 2 +- 4 files changed, 39 insertions(+), 15 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 1e1e16d5..33a9b6d6 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -30,10 +30,13 @@ def __init__(self, model: torch.nn.Module, input_shapes): """ Create semantic model based on AI Scientist description. """ - from cube.graph import parser - self.ir_graph = parser.convert_model( - model, input_shapes=input_shapes - ) + dist = torch.distributed.is_initialized() + if (not dist) or (dist and torch.distributed.get_rank() == 0): + self.ir_graph = parser.convert_model( + model, input_shapes=input_shapes + ) + else: + self.ir_graph = None self._loaded_module = None def get_graph(self): diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 28c50c4f..f2aa1f8b 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -5,6 +5,7 @@ from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor +import cube.ir as ir from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import Sign2Op, DType2IRDType @@ -51,11 +52,11 @@ def parse_module(module, raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") # handle graph input -- Assuming all the inputs are tensors - kDefaultType = DType2IRDType.map(torch.get_default_dtype()) + # kDefaultType = DType2IRDType.map(torch.get_default_dtype()) for idx, input in enumerate(inputs): if isinstance(input.type(), torch._C.TensorType): shape = None if input_shapes is None else input_shapes[idx] - dtype = kDefaultType + dtype = ir.IRDType.unknown # kDefaultType val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.debugName()) else: raise NotImplementedError("Graph inputs only accepts Tensor") diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 7345c635..47755ef2 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -25,6 +25,30 @@ __all__ = ['IRCell', 'IRDType', 'IRTensor'] +class DTypeInferRule: + """ + According to promotion doc: + https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc + + complex > floating > integral > boolean + """ + @staticmethod + def infer(node, dtypes: List[IRDType]) -> IRDType: + dtypes = [dtype for dtype in dtypes if dtype != IRDType.unknown] + if IRDType.float32 in dtypes and IRDType.float16 in dtypes: + raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") + # in priority: fp32 > fp16 > bool > int64 > int16 > + priority = [ + IRDType.float64, IRDType.float32, IRDType.float16, + IRDType.int64, IRDType.int32, IRDType.int16, IRDType.int8, + IRDType.boolean + ] + for dtype in priority: + if dtype in dtypes: + return dtype + return IRDType.unknown + + class IRCell: r""" IRCell serves as a general node for different purpose @@ -51,7 +75,7 @@ def __init__(self, self.name: str = name self.signature = signature - self._dtype = IRDType.unknown + self._dtype = IRDType.unknown # output tensor dtype self._device = list() # source tensors @@ -258,15 +282,11 @@ def set_input(self, input_index: int, val): val = copy.copy(val) # set tensor dst val.cell = self - # set input value dtype - if self._dtype == IRDType.unknown: - self._dtype = val.dtype - for output in self.outputs(): - if isinstance(output, IRTensor): - output.dtype = self._dtype - val.dtype = self._dtype + # update dtype + self._dtype = DTypeInferRule.infer(self, [self._dtype, val.dtype]) self._inputs[input_index] = val + self.inputs.cache_clear() return val diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 02e72232..946f9337 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -250,7 +250,7 @@ class IRFullTensor(IRTensor): the sequentail execution order by its graph. """ - def __init__(self, shape=None, name=None, requires_grad=True, dtype=irdtype.float32): + def __init__(self, shape=None, name=None, requires_grad=True, dtype=irdtype.IRDType.unknown): super().__init__(shape, name, dtype) From 5ae45a957ac78c69de3d0b18b8010dc290a32362 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 11:43:40 +0800 Subject: [PATCH 0942/1892] support fp16 --- examples/nlp/gpt/policy/spmd.py | 2 +- examples/nlp/gpt/train.py | 28 +++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index d8b32e45..ced5f2b3 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -11,7 +11,7 @@ def PASSingle(graph: IRGraph, resource): return graph -def PASMegatron(graph: IRGraph, resource): +def PASMegatronTP(graph: IRGraph, resource): tp_size = resource.ngpus fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 307bef12..95549e55 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/nlp/gpt/train.py + examples/nlp/gpt/train.py --policy PASMegatronTP """ @@ -18,6 +18,31 @@ from cube.profiler.memory import memory_summary, model_summary from examples.nlp.gpt.policy.mpmd import PASMegatron as PAS +import examples.nlp.gpt.policy.spmd as spmd +import examples.nlp.gpt.policy.mpmd as mpmd + +import argparse + +parser = argparse.ArgumentParser(description='GPT Train') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') +parser.add_argument('--fp16', action='store_true', default=False, + help='use fp16 for the training') +args = parser.parse_args() + +cube.init() + +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +policies = [policy for policy in policies if policy.startswith('PAS')] +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + def train(): @@ -25,6 +50,7 @@ def train(): batch_size = 1 model = GPT() + model = model if not args.fp16 else model.half() dataloader = GPTDataLoader(batch_size) optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) From 8468324fe7a8c9cfda2a2b2012818c0de8c1bd11 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 13:49:37 +0800 Subject: [PATCH 0943/1892] fix dtype infer --- cube/ir/cten.py | 29 ------------------------- cube/ir/tensor.py | 17 +++++++++++++-- cube/logics/model.py | 51 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 64 insertions(+), 33 deletions(-) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 47755ef2..125ed853 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -25,30 +25,6 @@ __all__ = ['IRCell', 'IRDType', 'IRTensor'] -class DTypeInferRule: - """ - According to promotion doc: - https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc - - complex > floating > integral > boolean - """ - @staticmethod - def infer(node, dtypes: List[IRDType]) -> IRDType: - dtypes = [dtype for dtype in dtypes if dtype != IRDType.unknown] - if IRDType.float32 in dtypes and IRDType.float16 in dtypes: - raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") - # in priority: fp32 > fp16 > bool > int64 > int16 > - priority = [ - IRDType.float64, IRDType.float32, IRDType.float16, - IRDType.int64, IRDType.int32, IRDType.int16, IRDType.int8, - IRDType.boolean - ] - for dtype in priority: - if dtype in dtypes: - return dtype - return IRDType.unknown - - class IRCell: r""" IRCell serves as a general node for different purpose @@ -75,7 +51,6 @@ def __init__(self, self.name: str = name self.signature = signature - self._dtype = IRDType.unknown # output tensor dtype self._device = list() # source tensors @@ -282,8 +257,6 @@ def set_input(self, input_index: int, val): val = copy.copy(val) # set tensor dst val.cell = self - # update dtype - self._dtype = DTypeInferRule.infer(self, [self._dtype, val.dtype]) self._inputs[input_index] = val @@ -315,8 +288,6 @@ def set_output(self, output_index: int, val): if isinstance(val, IRTensor): val = copy.copy(val) val.cell = self - # set output value dtype - val.dtype = self._dtype self._outputs[output_index] = val self.outputs.cache_clear() diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 946f9337..0aad1958 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -21,7 +21,7 @@ from typing import List, Optional, Union, Tuple, NewType, Dict from cube.ir.cten import IRCell, IRTensor -import cube.ir.dtype as irdtype +from cube.ir.dtype import IRDType StartEnd = NewType('[start:end)', Tuple[int, int]) IdxChunk = NewType('(index, chunks)', Tuple[int, int]) @@ -250,7 +250,7 @@ class IRFullTensor(IRTensor): the sequentail execution order by its graph. """ - def __init__(self, shape=None, name=None, requires_grad=True, dtype=irdtype.IRDType.unknown): + def __init__(self, shape=None, name=None, requires_grad=True, dtype=IRDType.unknown): super().__init__(shape, name, dtype) @@ -466,6 +466,7 @@ def __init__(self, ftensor: IRFullTensor, """ indmap, valmap = IndexMap(indmap), ValueMap(valmap) assert isinstance(ftensor, IRFullTensor), "Expcted ftensor to be IRFullTensor" + assert 'dtype' not in kwargs, "IRSubTensor is not allowed to initialize with a dtype" super().__init__(shape=indmap.shape, name=ftensor.name, **kwargs) for attr in IRFullTensor._meta: setattr(self, attr, getattr(ftensor, attr)) @@ -515,6 +516,18 @@ def valmap(self) -> IdxChunk: def ndims(self) -> int: return len(self.shape) + @property + def dtype(self) -> IRDType: + return self.parent.dtype + + @dtype.setter + def dtype(self, val: IRDType): + if self.parent.dtype == IRDType.unknown: + self.parent.dtype = val + else: + assert self.parent.dtype == val, \ + f"dtype mis-matched with previous setting: {val} != {self.parent.dtype}" + def splitdims(self) -> Tuple[int]: """! Get partitioned dimensions diff --git a/cube/logics/model.py b/cube/logics/model.py index d0b29f71..bb9db939 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -1,11 +1,36 @@ -from typing import Any, List, Tuple +from typing import Tuple, List import copy from cube.graph.graph import IRGraph +from cube.ir.dtype import IRDType from cube.ir.tensor import IRSubTensor from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor + +class DTypeInferRule: + """ + According to promotion doc: + https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc + + complex > floating > integral > boolean + """ + @staticmethod + def infer(node, dtypes: List[IRDType]) -> IRDType: + dtypes = [dtype for dtype in dtypes if dtype != IRDType.unknown] + if IRDType.unknown in dtypes: + raise RuntimeError(f"Find an unkown dtype") + if IRDType.float32 in dtypes and IRDType.float16 in dtypes: + raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") + # in priority: fp32 > fp16 > bool > int64 > int16 > + priority = [ + IRDType.float64, IRDType.float32, IRDType.float16, + IRDType.int64, IRDType.int32, IRDType.int16, IRDType.int8, + IRDType.boolean + ] + for dtype in priority: + if dtype in dtypes: + return dtype + return IRDType.unknown def forward(graph: IRGraph, *args) -> IRGraph: @@ -14,6 +39,7 @@ def forward(graph: IRGraph, *args) -> IRGraph: """ if not isinstance(graph, IRGraph): raise TypeError("Requires IRGraph for forward") + # align graph with input tensors itensors: Tuple[IRSubTensor, ...] = graph.inputs() for idx, (itensor, arg) in enumerate(zip(itensors, args)): @@ -33,6 +59,27 @@ def forward(graph: IRGraph, *args) -> IRGraph: while itensor in graph.outputs(): oidx = graph.outputs().index(itensor) graph.set_output(oidx, arg) + for itensor in itensors: + del graph._full_tensors[itensor.parent.tid] + + # dtype inference + for node in graph.nodes(): + itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] + # setup gradient + for itensor in itensors: + if itensor.parent.grad is not None: + itensor.parent.dtype = itensor.dtype + if len(itensors) == 0: + continue + odtype = DTypeInferRule.infer(node, [t.dtype for t in itensors]) + assert odtype != IRDType.unknown, f"{node} : {[t.dtype for t in itensors]}" + otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + for tensor in otensors: + tensor.dtype = odtype + # setup graidient + if tensor.parent.grad is not None: + tensor.parent.grad.dtype = odtype + # generate backward reverse is only to make op id looks consecutive for fnode in [n for n in graph.nodes() if isinstance(n, IRFwOperation)][::-1]: fnode.gen_backward() From e18feffa4eebfcaa5aa0fb0db8f6a03d63c007da Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 13:49:52 +0800 Subject: [PATCH 0944/1892] update check scripts --- tests/test_examples.sh | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 83034387..67bd2aac 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -38,28 +38,22 @@ OMP_NUM_THREADS=4 torchrun \ examples/mlp/linears.py --policy PASOptimal -# test GSearch +# test GPT model OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ - examples/gsearch/gpt/train.py --policy PASReplica + examples/nlp/gpt/train.py --policy PASMegatronTP --fp16 OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/gsearch/gpt/train.py --policy PASMegatronTP + examples/nlp/gpt/train.py --policy PASRoundRobin --fp16 -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# examples/gsearch/gpt/train.py --policy PASRoundRobin -# -# -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# examples/gsearch/gpt/train.py --policy PAS1F1B +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/nlp/gpt/train.py --policy PASMegatron --fp16 # test scientific model @@ -69,7 +63,7 @@ OMP_NUM_THREADS=4 torchrun \ --nnodes=1 \ examples/poisson/sci.py -SCIENTIFIC_COMPUTING=1 OMP_NUM_THREADS=4 torchrun \ +OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=1 \ --nnodes=1 \ examples/wrf/wrf2.py From 0b5722638111d0586cd5bd9084ec689b141bd8bf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 14:33:48 +0800 Subject: [PATCH 0945/1892] add bias --- examples/nlp/blocks/attention.py | 48 +++++++++++++++----------------- examples/nlp/blocks/mlp.py | 18 ++++-------- examples/nlp/gpt/model.py | 18 ++++++------ 3 files changed, 37 insertions(+), 47 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index a1d94a34..7a953781 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -1,18 +1,13 @@ -from typing import Optional - import torch import cube -import warnings - -# @cube.graph.parser.register('L^ N E^, (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), E^ (h+ d^), E^ -> L^ N E^', name='self_attention') @cube.graph.parser.register('L^ N E^, (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), E^ (h+ d^) -> L^ N E^', name='self_attention') def self_attention(query: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, out_bias: None, + out_proj: torch.Tensor, h: int, scale: float, dropout_p: float, mask: bool = True): num_head = h L, N = query.size(0), query.size(1) @@ -46,16 +41,16 @@ def self_attention(query: torch.Tensor, output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + output = torch.nn.functional.linear(output, out_proj) # L N (h d), E E -> L N E return output -@cube.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d), E^ -> L^ N E^', name='cross_attention') +@cube.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') def cross_attention(query: torch.Tensor, key: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, out_bias: torch.Tensor, + out_proj: torch.Tensor, h: int, scale: float, dropout_p: float, mask=True): num_head = h L, N = query.size(0), query.size(1) @@ -89,13 +84,13 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E + output = torch.nn.functional.linear(output, out_proj, None) # L N (h d), E E -> L N E return output class MultiHeadSelfAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): super().__init__() self.inner_dim = inner_dim self.num_heads = num_heads @@ -104,34 +99,33 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa self.dropout_p = dropout # Q self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) # K self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) # V self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) # Out self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) - # self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None - self.out_bias = None - if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: - warnings.warn('self attention dense bias is skipped for correctness.') + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) def forward(self, query): - return self_attention( + attn = self_attention( query, self.q_proj, self.q_bias, self.k_proj, self.k_bias, self.v_proj, self.v_bias, - self.out_proj, self.out_bias, + self.out_proj, self.num_heads, self.scaling, self.dropout_p, mask=True ) + attn = attn + self.out_bias + return attn class MultiHeadCrossAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): super().__init__() self.inner_dim = inner_dim self.num_heads = num_heads @@ -140,19 +134,19 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa self.dropout_p = dropout # Q self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) # K self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) # V self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None + self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) # Out self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) def forward(self, query: torch.Tensor, key: torch.Tensor): - return cross_attention( + attn = cross_attention( query, key, self.q_proj, self.q_bias, self.k_proj, self.k_bias, @@ -160,3 +154,5 @@ def forward(self, query: torch.Tensor, key: torch.Tensor): self.out_proj, self.out_bias, self.num_heads, self.scaling, self.dropout_p, mask=True ) + attn = attn + self.out_bias + return attn diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py index faa98ec2..c9162364 100644 --- a/examples/nlp/blocks/mlp.py +++ b/examples/nlp/blocks/mlp.py @@ -1,37 +1,31 @@ import torch import cube -import warnings -# @cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+, E^ -> L^ N E^', name='feedforward') @cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj1_bias: torch.Tensor, - proj2: torch.Tensor, proj2_bias: None, #torch.Tensor, + proj2: torch.Tensor, dropout: float) -> torch.Tensor: x = torch.nn.functional.linear(x, proj1, proj1_bias) x = torch.nn.functional.gelu(x) x = torch.nn.functional.dropout(x, dropout, True, False) - x = torch.nn.functional.linear(x, proj2, proj2_bias) + x = torch.nn.functional.linear(x, proj2, None) return x class MLP(torch.nn.Module): - def __init__(self, embed_dim: int, hidden_dim: int, dropout: float, bias=True): + def __init__(self, embed_dim: int, hidden_dim: int, dropout: float): super().__init__() self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) - self.proj2_bias = None # torch.nn.Parameter(torch.empty((embed_dim,))) self.dropout = dropout - if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: - warnings.warn('feedforward output bias is skipped for correctness') def forward(self, x: torch.Tensor): - x = feedforward(x, - self.proj1, self.proj1_bias, - self.proj2, self.proj2_bias, - self.dropout) + x = feedforward(x, self.proj1, self.proj1_bias, + self.proj2, self.dropout) + x = x + self.proj2_bias return x diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 5eade94b..b38f6a84 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -11,9 +11,9 @@ class Config: seqlen = 1024 # 340 M model - embed_dim = 1024 - layers = 4 # 24 - attention_heads = 16 + # embed_dim = 1024 + # layers = 24 + # attention_heads = 16 # 1.3 B model # embed_dim = 2048 @@ -21,9 +21,9 @@ class Config: # attention_heads = 32 # 2.6 B model - # embed_dim = 2560 - # layers = 32 - # attention_heads = 32 + embed_dim = 2560 + layers = 32 + attention_heads = 32 # 6.7 B model # embed_dim = 4096 @@ -42,9 +42,9 @@ class Config: attn_hidden_dim = embed_dim ffn_hidden_dim = embed_dim * 4 - dropout = 0.0 - attn_dropout = 0.0 - activation_dropout = 0.0 + dropout = 0.2 + attn_dropout = 0.2 + activation_dropout = 0.2 class GPT(torch.nn.Module): From 71c343c18cd8fdb8f6744f5c3286ac5518b0bca1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 5 Aug 2022 14:05:16 +0800 Subject: [PATCH 0946/1892] add wide-resnet --- examples/vision/resnet/model.py | 281 +++++++++++++++++++++++++++ examples/vision/resnet/model_alpa.py | 152 +++++++++++++++ examples/vision/resnet/train.py | 76 ++++++++ 3 files changed, 509 insertions(+) create mode 100644 examples/vision/resnet/model.py create mode 100644 examples/vision/resnet/model_alpa.py create mode 100644 examples/vision/resnet/train.py diff --git a/examples/vision/resnet/model.py b/examples/vision/resnet/model.py new file mode 100644 index 00000000..09753c73 --- /dev/null +++ b/examples/vision/resnet/model.py @@ -0,0 +1,281 @@ +from typing import List, Optional, Callable +import torch +import torch.nn as nn +import cube + + +class Config: + + width_factor = 1 # for scaling default 1 + inplanes = 160 # 64 + # setting for wide-resnet 50 + layers : List[int] = [3, 4, 6, 3] + + # setting for wide-resnet 101 + layers : List[int] = [3, 4, 23, 3] + + width_per_group = 128 * width_factor + # conv2d: + # in_channel: 128 | out_channel: 128 | stride: 1 | groups: 1 | dilation: 1 + # in_channel: 256 | out_channel: 256 | stride: 2 | groups: 1 | dilation: 1 + # in_channel: 512 | out_channel: 512 | stride: 1 | groups: 1 | dilation: 1 + # in_channel: 1024 | out_channel: 1024 | stride: 2 | groups: 1 | dilation: 1 + # conv2d inputs: + # torch.Size([1, 128, 128, 128]) + # torch.Size([1, 256, 64, 64]) + # torch.Size([1, 512, 32, 32]) + # torch.Size([1, 1024, 16, 16]) + + # input + img_size = 224 + num_classes = 1024 # 1000 + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + # print(f'conv2d: in_channel: {in_planes} | out_channel: {out_planes} | stride: {stride} | groups: {groups} | dilation: {dilation}') + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + print(f'adding conv2d channel: {width}, stride: {stride}, padding: {dilation}') + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # print(f'conv2d input shape: {out.size()}') + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + print(identity.size(), out.size()) + out += identity + out = self.relu(out) + + return out + + +class WideResNet(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.layers = Config.layers + self.num_classes = 1000 + self._norm_layer = nn.BatchNorm2d + self.block = Bottleneck + self.inplanes = 64 + self.dilation = 1 + self.replace_stride_with_dilation = [False, False, False] + self.groups = 1 + self.base_width = Config.width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = self._norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(64, self.layers[0]) + self.layer2 = self._make_layer(128, self.layers[1], stride=2, dilate=self.replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(256, self.layers[2], stride=2, dilate=self.replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(512, self.layers[3], stride=2, dilate=self.replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * self.block.expansion, self.num_classes) + self.loss_func = nn.CrossEntropyLoss() + + def _make_layer(self, planes: int, blocks: int, stride: int = 1, dilate = False): + block = Bottleneck + + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + return torch.nn.ModuleList(layers) + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + for layer in self.layer1: + x = layer(x) + for layer in self.layer2: + x = layer(x) + for layer in self.layer3: + x = layer(x) + for layer in self.layer4: + x = layer(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + loss = self.loss_func(x, target) + return loss + + +class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + + self.bs = batch_size + self.img_size = Config.img_size + self.num_classes = Config.num_classes + super().__init__( + shapes=([batch_size, 3, self.img_size, self.img_size,], + [batch_size], + ), + dtypes=(torch.float, torch.int), + batch_dims=(0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + img = torch.rand( + *(self.bs, 3, self.img_size, self.img_size), + dtype=torch.float, + device=torch.cuda.current_device() + ) + labels = torch.randint( + 0, self.num_classes, + size=(self.bs,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + return (img, labels) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] \ No newline at end of file diff --git a/examples/vision/resnet/model_alpa.py b/examples/vision/resnet/model_alpa.py new file mode 100644 index 00000000..f8938d94 --- /dev/null +++ b/examples/vision/resnet/model_alpa.py @@ -0,0 +1,152 @@ +import torch +import torch.nn as nn +import cube + + +class Config: + + stages = { + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3] + } + + num_layers = 50 + width_factor = 2 + num_filters = 160 + + img_size = 224 + num_classes = 1024 + + +class Bottleneck(nn.Module): + + def __init__(self, in_channels: int, out_channels: int, width_factor: int, stride: int) -> None: + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self.norm1 = nn.BatchNorm2d(out_channels) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_channels, out_channels * width_factor, kernel_size=3, stride=stride, padding=1, bias=False) + self.norm2 = nn.BatchNorm2d(out_channels * width_factor) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(out_channels * width_factor, out_channels * 4, kernel_size=1, bias=False) + self.norm3 = nn.BatchNorm2d(out_channels * 4) + + # down sample + self.downsample = None if in_channels == out_channels * 4 else torch.nn.ModuleList([ + nn.Conv2d(in_channels, out_channels * 4, 1, stride, bias=False), + nn.BatchNorm2d(out_channels * 4) + ]) + + self.act3 = nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor): + + residual = x + + y = self.conv1(x) + y = self.norm1(y) + y = self.act1(y) + + y = self.conv2(y) + y = self.norm2(y) + y = self.act2(y) + + y = self.conv3(y) + y = self.norm3(y) + + if self.downsample is not None: + for layer in self.downsample: + residual = layer(residual) + + # print(residual.size(), y.size()) + y = self.act3(residual + y) + return y + + +class WideResNet(nn.Module): + + def __init__(self): + super().__init__() + config = Config() + + # preprocess + self.conv1 = nn.Conv2d(3, config.num_filters, kernel_size=7, stride=2, padding=3, bias=False) + self.norm1 = nn.BatchNorm2d(config.num_filters) + self.act1 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 'padding=SAME' + + self.layers = torch.nn.ModuleList([]) + + stages = config.stages[config.num_layers] + for i, block_size in enumerate(stages): + channel = config.num_filters * (2 ** i) + for j in range(block_size): + if i == 0 and j == 0: + in_channels = channel + elif i > 0 and j == 0: + in_channels = channel // 2 * 4 + else: + in_channels = channel * 4 + stride = 2 if i > 0 and j == 0 else 1 + print(f'add in_channel: {in_channels} | out_channel: {channel * 4}') + block = Bottleneck( + in_channels, channel, config.width_factor, stride + ) + self.layers.append(block) + + # postprocess + self.fc = nn.Linear(channel * 4, config.num_classes, bias=False) + self.criterion = nn.CrossEntropyLoss() + + def forward(self, img: torch.Tensor, label: torch.Tensor): + x = self.conv1(img) + x = self.norm1(x) + x = self.act1(x) + x = self.maxpool(x) + print(x.size()) + + for block in self.layers: + x = block(x) + + # N C H W -> N C + x = torch.mean(x, dim=(2,3)) + x = self.fc(x) + loss = self.criterion(x, label) + return loss + + +class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + + self.bs = batch_size + self.img_size = Config.img_size + self.num_classes = Config.num_classes + super().__init__( + shapes=([batch_size, 3, self.img_size, self.img_size,], + [batch_size], + ), + dtypes=(torch.float, torch.int), + batch_dims=(0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + img = torch.rand( + *(self.bs, 3, self.img_size, self.img_size), + dtype=torch.float, + device=torch.cuda.current_device() + ) + labels = torch.randint( + 0, self.num_classes, + size=(self.bs,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + return (img, labels) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] diff --git a/examples/vision/resnet/train.py b/examples/vision/resnet/train.py new file mode 100644 index 00000000..20ef08c7 --- /dev/null +++ b/examples/vision/resnet/train.py @@ -0,0 +1,76 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/vision/resnet/train.py +""" + +import torch +from examples.vision.resnet.model_alpa import WideResNet, ImageDataLoader + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary, model_summary + + + +def train(): + + batch_size = 32 + nmicros = 1536 // batch_size + + + model = WideResNet() + model = model.cuda() + + cnt = 0 + for param in model.parameters(): + cnt += param.nelement() + print(f'param#: {cnt / 1e6} M') + + dataloader = ImageDataLoader(batch_size) + optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) + + print_each_rank('model weight consumpition:') + memory_summary() + + def train_iter(model, dataloader): + imgs, labels = next(dataloader) + loss = model(imgs, labels) + loss.backward() + + CudaTimer(enable=False).warmup() + iter_num = 10 + for step in range(iter_num): + + # if step == 0: + # model_summary(model, next(dataloader)) + + if step >= 4: + CudaTimer(enable=True).start('e2e') + + # training + for _ in range(nmicros): + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + if step >= 4: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('passed first iteration') + + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-4, field_name='e2e'))) + memory_summary() + +if __name__ == '__main__': + + cube.init() + train() From ad835c856d2e69d6a7f9851ca354aa31aa009c5a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Aug 2022 23:16:12 -0700 Subject: [PATCH 0947/1892] clean up output --- examples/vision/resnet/model_alpa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/vision/resnet/model_alpa.py b/examples/vision/resnet/model_alpa.py index f8938d94..362d6432 100644 --- a/examples/vision/resnet/model_alpa.py +++ b/examples/vision/resnet/model_alpa.py @@ -103,7 +103,7 @@ def forward(self, img: torch.Tensor, label: torch.Tensor): x = self.norm1(x) x = self.act1(x) x = self.maxpool(x) - print(x.size()) + # print(x.size()) for block in self.layers: x = block(x) From ec17b4c4c283b4a2a46e623411902779b2f9d116 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Fri, 5 Aug 2022 10:41:16 +0000 Subject: [PATCH 0948/1892] Merged PR 1406: Release zero-use tensor variables in codegen - Added releasing to Segment, including the `checkpoint` parts - Fixed a potential error that releases incorrect variables in Schedule. TODO - added this to Adapter, although Adapter functions are usually very short and hold few tensor variables. --- cube/codegen/codegen.py | 534 ++++++++++++++++++++++++++++++---------- requirements.txt | 1 + tests/test_codegen.py | 185 ++++++++++++++ 3 files changed, 587 insertions(+), 133 deletions(-) create mode 100644 tests/test_codegen.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 952eaaa0..97ef03bd 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -5,6 +5,7 @@ from typing import Dict, Generator, Iterable, List, Any, Optional, Set, Tuple, Union import torch import copy +from more_itertools import split_when from cube.graph.parser.mapping import Sign2Op from cube.ir.cten import IRCell, IRTensor @@ -23,12 +24,37 @@ from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock from cube.codegen.frontend_mapping import Sign2EmitRule +def get_backward_callsite_io_tensors(bp_segment:IRSegment): + """ + Returns: + ``` + (input_tensors, output_tensors, output_grads, input_grads) + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~ + #inputs to 'backward' outputs of 'backward' + ``` + """ + assert isinstance(bp_segment, IRSegment) and not bp_segment.forward + + input_tensors = [t for t in bp_segment.mirror.inputs() if \ + isinstance(t, IRSubTensor) and \ + t.requires_grad and \ + not t.is_param() + ] + output_tensors = [t for t in bp_segment.mirror.outputs() if isinstance(t, IRSubTensor)] + input_grads = [t.grad for t in input_tensors] + + # WARNING !!! + # non-tensor gradients like scalar '1.0f' are removed in 'bpSeg.inputs()' + # so the items of 'bpSeg.inputs()' are generally disaligned with 'output_grads' here. + output_grads = [t.grad for t in output_tensors] -# TODO this could be applied in all codegen: Segments, Adapters, Schedules... -# TODO but Schedules don't have standalone inputs, while Segments and Adapters have and -# those inputs may need to be released too. + return input_tensors, output_tensors, output_grads, input_grads + +# TODO this could be applied in Adapters def calc_tenvars_lifetime( - nodes:Iterable[IRCell], subgraph_outputs:Iterable[IRTensor] + nodes:Iterable[IRCell], + subgraph_outputs:Iterable[IRTensor], + subgraph_inputs:Iterable[IRTensor] = [] ) -> Dict[IRTensor, int]: """ Calculate the lifetime of tensor variables ahead-of-time. @@ -48,19 +74,14 @@ def calc_tenvars_lifetime( For each kv-pair `(t, i)` it indicates the last reference of tensor `t` is at the `i`-th (0-based) node's inputs, - i.e. the variable for tensor `t` could be released *AFTER* the `i`-th statement + i.e. the variable for tensor `t` could be released *BEFORE* the `i`-th statement in codegen. - If a tensor is not included in the dict, it means that tensor is never referenced. - - TODO If an input of the subgraph is never used, its corresponding `i` is `-1`. - - REMARK: - - We cannot `detele O_j` because it may be a variable alias and the tensor object - behind is still active (e.g. `runtime.multiref`). - So we just set it to `None`, decrement the reference count - and let Python (the codegen target) decide the deletion. + If an input of the subgraph is never used, its corresponding `i` is `0` -- this will + lead to an immediate release at the beginning of a function. + Tensors that exist till the end of the subgraph will have lifetime greater than the + size of that subgraph. Generally we don't need to manually release those tensors, + since they are automatically released when the generated function returns. """ lifetime : Dict[IRTensor, int] = dict() @@ -68,56 +89,70 @@ def calc_tenvars_lifetime( def is_temp_tensor(v): return isinstance(v, IRSubTensor) and not v.is_param() - #lifetime.update((tsin, -1) for tsin in subgraph_inputs if is_temp_tensor(tsin)) + lifetime.update((tsin, 0) for tsin in subgraph_inputs if is_temp_tensor(tsin)) for i, node in enumerate(nodes): - # aggressively mark all outputs for immediate deletion, namely 'i' - # TODO should work fine even for IRBpOperation - lifetime.update((tout, i) for tout in node.outputs() if is_temp_tensor(tout)) - # "fast-forward" all inputs to the current statement, namely 'i' + outputs : Iterable[IRTensor] inputs : Iterable[IRTensor] if isinstance(node, IRSegment): if node.forward: + outputs = node.outputs() inputs = node.inputs() else: # NOTE - # An 'IRBpOperation' does not explicitly record all tensors that are - # inputs-and-outputs-to-its-correspondeding-autograd.grad-call after codegen, - # - # E.g. - # IRBpOperation bp_op{ pair=fw_op } has: - # ``` - # len(bp_op.inputs())==len(fw_op.outputs()) - # len(bp_op.outputs())==len(fw_op.inputs()) - # ```` + # An backward 'IRSegment' does not explicitly record all tensors that are + # inputs-and-outputs-to-its-correspondeding-autograd.grad-call after codegen. # - # while a call to 'torch.autograd.grad' in Python is like: + # Where a call to 'torch.autograd.grad' in Python is like: # ``` # grad_inputs : Tuple[torch.Tensor, ...] = torch.autograd.grad(outputs, inputs, grad_outputs) # len(grad_inputs) == len(inputs) # len(grad_outputs) == len(outputs) # ``` - # in other words, if we simply treat `autograd.grad` as an `g_op:IRCell`, it has: - # ``` - # len(g_op.inputs()) == len(fw_op.outputs())*2 + len(fw_op.inputs()) - # len(g_op.outputs()) == len(fw_op.inputs()) - # ``` # - fw_inputs : tuple = node.mirror.inputs() - fw_outputs : tuple = node.mirror.outputs() - grad_inputs : Generator = (t.grad for t in fw_outputs) - - inputs = itertools.chain(fw_inputs, fw_outputs, grad_inputs) + # But a backward 'IRSegment' itself only records _extra_ information to take + # gradients for inputs to a forward 'IRSegment': + # + # - Inputs of the backward 'IRSegment' are + # gradient tensors for outputs of the corresponding forward 'IRSegment' + # + # WARNING: non-tensor gradients like scalar '1.0f' are removed, + # so the items of 'bpSeg.inputs()' are generally disaligned with 'fw_outputs()' + # + # - Outputs of the backward 'IRSegment' are + # gradient tensors for both explicit and implicit inputs of the forward 'IRSeg' + # + # P.S. the implicit inputs of the forward 'IRSeg' are like 'nn.Parameter's + # which are model fields and accessed by e.g. 'self.weights'. + # Generally, by viewing a gradient tensor of some input, we cannot distinguish + # whether the corresponding input is explicit or implicit. + + fw_inputs, fw_outputs, output_grads, input_grads = \ + get_backward_callsite_io_tensors(node) + + outputs = input_grads + inputs = list(itertools.chain(fw_inputs, fw_outputs, output_grads)) else: + outputs = node.outputs() inputs = node.inputs() - lifetime.update((tin, i) for tin in inputs if is_temp_tensor(tin)) + # aggressively mark all outputs for immediate deletion, + # namely *before* 'i+1'-th statement, in case it's never used. + lifetime.update((tout, i+1) for tout in outputs if is_temp_tensor(tout)) + + # "fast-forward" all inputs to the current statement, namely before 'i+1'-th node. + lifetime.update((tin, i+1) for tin in inputs if is_temp_tensor(tin)) - i += 1 - lifetime.update((tsout, i) for tsout in subgraph_outputs if is_temp_tensor(tsout)) + # end of 'for' + + # Here (i+1) is always greater than 'len(nodes)' + # Generally we don't manually release those tensors since the enclosing function is about to + # return, all local variables are automatically released. + # But we do need to update the lifetime of all outputs, to avoid early releasing. + lifetime.update((tsout, i+1) for tsout in subgraph_outputs if is_temp_tensor(tsout)) return lifetime @@ -213,6 +248,9 @@ def kwargs_naming(self, **kwargs) -> str: name = ', '.join(names) return name + def emit_tensors_release(self, tensors:Iterable[IRTensor]) -> str: + tnames : Generator = (self.tensor_naming(t) for t in tensors) + return 'del ' + ', '.join(tnames) class AutogradAdapterCodeGen(CodeGen): """ @@ -275,6 +313,49 @@ def name(adapter: IRAdapter) -> str: class ModelCodeGen(CodeGen): """ Generate model code + + `ModelCodeGen` traverses all IR nodes and categorizes their intermediately generated + codes into different parts, + then reorders and concatenates these parts into the final code for PyTorch to run. + + These parts are progressively stored into fields of `ModelCodeGen` + + - `init_code : List[str]` + Statements like `import torch` + + - `model_init_statements : List[str]` + Statements of the `__init__` constructor of the final `nn.Module` in codegen, + + E.g. (lines are split into `List[str]`) + ```python + self.init_group(ranks=[0, 1, 2, 3]) + self.weight_63 = torch.nn.Parameter(torch.empty((2048, 8192), dtype=torch.float32)) + self.add_full_map('weight_63', 3, (slice(0, 2048, None), slice(0, 8192, None)), 1) + ``` + + including: + -- initialization of model weights, which are class fields; + + - `model_methods_bodies : List[List[str]]` + Definitions of the Python code for forward computations like Segments or Adapters + + Note that codes within this field haven't been organized into valid Python methods, + namely without signatures and return statements, both of which will be extracted + from corresponding IRSegment/IRAdapter in later processes. + E.g. + ``` + [ + # intermediate codes for 'segment123(self, tensor_2222)' + [ + 'tensor_3333 = torch.view(tensor_2222, [1,2,3,4])' + ] + + # intermediate codes for 'adapter456(self, tensor_4444)' + [ + 'tensor_5555 = cube.runtime.adapter.all_reduce(tensor_4444, ranks=[0,1,2,3])' + ] + ] + ``` """ def __init__(self, execplan: ExecutionPlan): @@ -289,10 +370,9 @@ def __init__(self, execplan: ExecutionPlan): self.init_code.append(op_impl) self.init_code += [''] # module init code - self.declare_region: List[str] = list() - # module forward code - self.forward_region_units: List[List[str]] = list() - self.forward_region: List[str] = list() + self.model_init_statements: List[str] = list() + # module method bodies for forward computations, e.g. Segments, Adapters. + self.model_methods_bodies: List[List[str]] = list() # module member name self.symbols = SymbolTable() # ref module to check shared variables @@ -304,6 +384,9 @@ def init_comm_groups(self): Creating communication group requires all the devices enter the same call. + + The fields storing intermediate codes that are populated by this method: + - `model_init_statements` """ graph = self.execplan.graph sign = 'self.init_group(ranks={ranks})' @@ -327,11 +410,11 @@ def init_comm_groups(self): if ranks not in comm_groups: comm_groups.append(ranks) # create communication group - self.declare_region.append('# communication groups') + self.model_init_statements.append('# communication groups') for ranks in comm_groups: code = sign.format(ranks=list(ranks)) - self.declare_region.append(code) - self.declare_region.append(' ') + self.model_init_statements.append(code) + self.model_init_statements.append(' ') def gen(self, device: int, outfile=None, attach=False) -> str: """ @@ -357,7 +440,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: if not node.forward: continue # skip backward segment codes = self.emit_segment_code(node) elif isinstance(node, IRFwOperation): - codes = self.emit_op_code(node) + raise RuntimeError(f"Unexcepted global-level op call: {node}") elif isinstance(node, IRAdapter): codes = self.emit_adapter_code(node) elif isinstance(node, IRWeightReducer): @@ -369,13 +452,16 @@ def gen(self, device: int, outfile=None, attach=False) -> str: continue else: raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") - self.forward_region += codes - # emit node tensor declaration - self.emit_node_declare(node) + + # emit node tensor declaration into `__init__` + # typically it's about the `nn.Parameter` + self.emit_node_tensors_declare(node) + # emit node code - self.forward_region_units.append(self.forward_region) - self.forward_region = list() + # codes : List[str] + self.model_methods_bodies.append(codes) gen_nodes.append(node) + args = list() for t in node.inputs(): if isinstance(t, IRSubTensor): @@ -388,7 +474,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # generate full code with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: with FunctionBlock(func_name='__init__', args=['self']) as ib: - ib.insert_body(self.declare_region) + ib.insert_body(self.model_init_statements) # switch to training or inference mode if self.execplan.inference: ib.insert_body('self.eval()') @@ -399,7 +485,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: for idx, node in enumerate(gen_nodes): name = self.node_naming(node) input_args = ['self'] + node_args[idx] - forward_code = self.forward_region_units[idx] + forward_code = self.model_methods_bodies[idx] with FunctionBlock(func_name=name, args=input_args) as fb: fb.insert_body(forward_code) # generate output @@ -423,10 +509,17 @@ def gen(self, device: int, outfile=None, attach=False) -> str: self.clear() return code - def emit_node_declare(self, node: IRCell): + def emit_node_tensors_declare(self, node: IRCell): """ Emit tensor declaration code + + The fields storing intermediate codes that are populated by this method: + - `model_init_statements` + + This method also populates `self.symbols : SymbolTable` to record + the names of the variables for the tensors ever encountered. """ + sign = 'torch.nn.Parameter(torch.empty({shape}, dtype={dtype}))' map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" for itensor in node.inputs(): @@ -435,7 +528,7 @@ def emit_node_declare(self, node: IRCell): if itensor.is_param() and not self.symbols.exist(name): self.symbols.create(name) code = f'{name} = {sign.format(shape=tuple(itensor.shape), dtype=self.dtype_map(itensor.dtype))}' - self.declare_region.append(code) + self.model_init_statements.append(code) tid = itensor.parent.tid slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) val_chunks = itensor.valmap[1] @@ -443,7 +536,7 @@ def emit_node_declare(self, node: IRCell): attr=self.tensor_naming(itensor), tid=tid, slicers=str(slicers), val_chunks=val_chunks ) - self.declare_region.append(code) + self.model_init_statements.append(code) if isinstance(itensor, str): if name.startswith('self.'): if not hasattr(self._ref_module, name[5:]): @@ -452,84 +545,245 @@ def emit_node_declare(self, node: IRCell): self.symbols.create(self.tensor_naming(output, prefix_attr='self.')) return - def emit_segment_code(self, node: IRSegment) -> List[str]: + def emit_segment_code(self, segment: IRSegment) -> List[str]: """ - Emit IRSegment code + Emit IRSegment code. + + The resultant `List[str]` will be lines of the statements of the final + Python method for the targeted Segment. + The resultant lines will not include the signature and the return statement + of the generated Python method. These lines will be put into `model_methods_bodies` + and the missing Python-syntactic parts will be injected later on. + + e.g. + ``` + [ + # no method signature + 'tensor_3333 = torch.view(tensor_2222, [1,2,3,4])', + 'tensor_2222 = None', # if in dataflow there is no more reference + 'tensor_4444 = torch.sum(tensor_3333)', + 'def recompute(...):', + ' return ...', + 'tensor_5555 = torch.utils.checkpoint(recompute, tensor_4444)', + 'tensor_4444 = None', # if in dataflow there is no more reference + # no return statement + ] + ``` Nodes in the segment will group into recompute region + + The fields storing intermediate codes that are populated by this method: + - NONE """ codes = [] - def emit_nodes(nodes: List[IRCell]) -> List[str]: + def emit_nodes_invocations(i_nodes: List[Tuple[int, IRCell]], + lifetime_by_line_id: Dict[int, List[IRTensor]]) -> List[str]: + """ + Emit code to invoke operations and adapter, + e.g. (the lines are split into `List[str]`) + + ``` + tensor_2222 = torch.view(tensor_1111, size=[3,6,9]) + tensor_1111 = None # if no more reference + tensor_3333 = cube.runtime.adapter.allgather_reducescatter(tensor_2222, dim=1, rank=[0,1]) + tensor_2222 = None # if no more reference + ``` + + The fields storing intermediate codes that are populated by this method: + - NONE + """ node_codes = [] - for node in nodes: + for i, node in i_nodes: + + # NOTE + # If a tensor is still referenced in any later recomputing group, its lifetime is + # definitely greater than the current sequence of statements here. + # Therefore we get chance to extend the lifetime of tensors like that, + # and properly release them after the call to 'torch.utils.checkpoint'. + # + tensors_to_del : Optional[List[IRTensor]] = lifetime_by_line_id.get(i, None) + if tensors_to_del is not None: + node_codes.append(self.emit_tensors_release(tensors_to_del)) + if isinstance(node, IRFwOperation): code = self.emit_op_code(node) + node_codes += code elif isinstance(node, IRAdapter): code = self.emit_adapter_code(node) + node_codes += code else: raise RuntimeError(f"unexpected type {type(node)} in IRSegment") - node_codes += code + return node_codes - def emit_rc_nodes(nodes: List[IRCell]) -> List[str]: + # returns: (code_lines, group_inputs, group_outputs) + def emit_rc_nodes(i_nodes: List[Tuple[int, IRCell]], lifetime_by_line_id: dict) \ + -> Tuple[List[str], List[IRTensor], List[IRTensor]]: + """ + Emit code to define a Python function for ReComputing and invoke it + e.g. (the lines are split into `List[str]`) + + ``` + def recompute(tensor_2222): + tensor_3333 = torch.view(tensor_2222, size=[3,6,9]) + tensor_2222 = None # no more reference + return tensor_3333 + # in the beginning we have `import torch.utils.checkpoint as ckpt` + tensor_4444 = ckpt.checkpoint(recompute, tensor_1111) + ``` + + REMARK: + - In the example above, 'tensor_2222' can be released within the RC subgraph, which also means that + the variable for this tensor can also be released within the enclosing graph, after the 'checkpoint' call. + - The generated RC subgraph will have no "free variables". + All involved tensors that are defined outside of the RC group are made explicit inputs; + All tensors, that are defined within the RC group and are referenced after RC subgraph ends, are made explicit outputs; + And if a within-RC-group tensors are not used anymore, it's not returned. + + The fields storing intermediate codes that are populated by this method: + - NONE + """ + assert len(i_nodes) > 0 node_codes = [] + + nodes : List[IRCell] = [node for i, node in i_nodes] + subseg = self.execplan.graph.segment(nodes) - inputs = [self.tensor_naming(t) for t in subseg.inputs() if not t.is_param()] - inputs_tup = ', '.join(inputs) - outputs = [self.tensor_naming(t) for t in subseg.outputs()] - outputs = ', '.join(outputs) - with FunctionBlock('recompute', inputs, False) as fb: - for ncode in emit_nodes(nodes): + + inputs = [t for t in subseg.inputs() if not t.is_param()] + input_names = [self.tensor_naming(t) for t in inputs] + input_names_tuple = ', '.join(input_names) + outputs = [t for t in subseg.outputs()] + output_names = [self.tensor_naming(t) for t in outputs] + output_names_tuple = ', '.join(output_names) + + # 'graph.segment(nodes)' ensures that if a tensor is no longer used (in RC group or in later code), + # it's not included in 'outputs'. + # And we will not generate 'return' statement for it, since it will cause the error + # that the variable is not defined (because it has been 'del'-ed). + + with FunctionBlock('recompute', input_names, False) as fb: + # The nodes to recompute share the same space of line_ids (or "node ids") with non-recomputable nodes. + # e.g. those ids in subgraphs are not 0-based, and incremented after the preceding non-rc nodes and so on. + # + # So within the recomputing subgraph, tensors can be released if they are no longer used + # i.e. not returned by the 'def recompute(...)' + # since 'execplan.graph.segment(nodes)' will make all "free variables" as explicit inputs/outputs + # to that subgraph. + for ncode in emit_nodes_invocations(i_nodes, lifetime_by_line_id): fb.insert_body(ncode) - fb.insert_body(f'return {outputs}') + fb.insert_body(f'return {output_names_tuple}') node_codes += [''] + fb.code + [''] node_codes.append( - f'{outputs} = ckpt.checkpoint(recompute, {inputs_tup})' + f'{output_names_tuple} = ckpt.checkpoint(recompute, {input_names_tuple})' ) - return node_codes - # to recompute region - curr_recompute_gid = None - curr_nodes = [] - for node in node.nodes(): - if isinstance(node, IRFwOperation): - if node.recompute != curr_recompute_gid: - if len(curr_nodes) != 0: - if curr_recompute_gid is None: - codes += emit_nodes(curr_nodes) - else: - codes += emit_rc_nodes(curr_nodes) - curr_recompute_gid = node.recompute - curr_nodes = [node] - else: - curr_nodes.append(node) - elif isinstance(node, IRAdapter): - # strategy 1: recompute close before adapter communication - if curr_recompute_gid is None: - curr_nodes.append(node) - else: - if len(curr_nodes) != 0: - if curr_recompute_gid is None: - codes += emit_nodes() - else: - codes += emit_rc_nodes() - else: - curr_recompute_gid = None - curr_nodes = [node] - - if len(curr_nodes) != 0: - if curr_recompute_gid is None: - codes += emit_nodes(curr_nodes) + return node_codes, inputs, outputs + + def get_equiv_recompute_gid(node:Union[IRFwOperation, IRAdapter]) -> Optional[int]: + if isinstance(node, IRAdapter): + # IRAdapter is equivalent to be non-recomputable. And it always terminates the + # nodes sequence of any recomputing group before it. + return None + elif isinstance(node, IRFwOperation): + return node.recompute else: - codes += emit_rc_nodes(curr_nodes) + raise ValueError(f'Unexcepted node type {type(node)}') + + def should_start_new_recompute_group(i_prev, i_cur) -> bool: + # i_prev, i_cur: Tuple[int, Union[IRFwOp,IRAdapter]] + prev_gid = get_equiv_recompute_gid(i_prev[1]) + cur_gid = get_equiv_recompute_gid(i_cur[1]) + return cur_gid != prev_gid + + nodes : List[IRCell] = segment.nodes() + + # After calculating the recompute groups, for each group, its input tensors' lifetime + # should be extend to at least beyond the lifetime of that group. + lifetime : Dict[IRTensor, int] = calc_tenvars_lifetime(nodes, segment.outputs(), segment.inputs()) + lifetime_by_line_id : Dict[int, List[IRTensor]] = dict() + for tensor, line_id in lifetime.items(): + lifetime_by_line_id.setdefault(line_id, []).append(tensor) + + # more_itertools.split_when # type: (Iterable[T], (T,T)->bool) -> Iterator[List[T]] + recompute_groups : List[List[Tuple[int, IRCell]]] \ + = list(split_when(enumerate(nodes), should_start_new_recompute_group)) + + for rc_group in recompute_groups: + # all FwOps/Adapters in a group have the same (equivalent) group id, + # check that of the head item, and 'rc_group' will not be empty here. + gid : Optional[int] = get_equiv_recompute_gid(rc_group[0][1]) + if gid is None: + codes += emit_nodes_invocations(rc_group, lifetime_by_line_id) + else: + assert len(rc_group) > 0 + + # Step 1: when entering a RC group: + # + # We insert tensor releasing statement *before* emitting each node. + # But here we are entering the scope of a RC group i.e. 'def recompute(...)'. + # Any releasing before the first node of the RC group, + # should be done before and outside of the RC group. + rc_first_line_id, _rc_first_node = rc_group[0] + # ... and to avoid emitting again, 'pop' the lifetime record. + # Specify the default collection since there might not be any. + rel_tensors_before_rc : Optional[list] = lifetime_by_line_id.pop(rc_first_line_id, None) + if rel_tensors_before_rc is not None: + codes.append(self.emit_tensors_release(rel_tensors_before_rc)) + + # Step 2 + rc_codes, rc_inputs, rc_outputs = emit_rc_nodes(rc_group, lifetime_by_line_id) + codes += rc_codes + + # Step 3: when exiting a RC group: + # + # `emit_rc_nodes` will not emit 'del`-statement for output tensors of the last + # node in the RC group, since those tensors will be immediately released + # as soon as 'recompute(...)' returns. + # We need to remove those tensors from the linearized lifetime + # (namely those with lifetime 'rc_next_line_id') + # and do not release them before the next node after the RC group. + rc_last_line_id, _rc_last_node = rc_group[-1] + rc_next_line_id = rc_last_line_id + 1 + lifetime_by_line_id.pop(rc_next_line_id, None) # specify a default to avoid KeyError + + # Step 4: after exiting a RC group: + # + # We need to release some argument tensors to the 'def recompute(...)' if they are + # no longer used. + # NOTE those tensors may have resulted in some 'del'-statements within the RC + # subfunction. But we need to release them again in the enclosing function, + # after the call to 'torch.checkpoint(recompute, *input_tensors)'. + + # Only release an RC input if: + # - its lifetime does not exceed the lifetime of the RC group; + # - not the case that the function returns after 'checkpoint' the RC subgraph. + if rc_next_line_id != len(nodes): + inputs_to_rel = [rcin for rcin in rc_inputs if lifetime[rcin] <= rc_next_line_id] + if len(inputs_to_rel) > 0: + del_stmt = self.emit_tensors_release(inputs_to_rel) + codes.append(del_stmt) + + # any resultant tensors *defined within the RC group and not used after the group* + # will not be returned from the generate 'def recompute(...)', + # so here we have no resultant tensors (namely 'rc_outputs') to release. return codes def emit_op_code(self, node: IRFwOperation) -> List[str]: """ - Emit op forward code + Emit the statement to call the op in the forward code + (e.g. in Segments, Adapter or CodeGen.Main) + + The result will look like (the lines are split into `List[str]`) + ``` + tensor_3333 = torch.view(tensor_2222, [1,2,3,4,5]) + ``` + + The fields storing intermediate codes that are populated by this method: + - NONE """ codes = [] # insert comment @@ -558,11 +812,19 @@ def emit_op_code(self, node: IRFwOperation) -> List[str]: def emit_adapter_code(self, node: IRAdapter) -> List[str]: """ - Emit adapter call + Emit the statment of the adapter call + + The resultant `List[str]` will be lines of the statements of the final + Python method for the targeted Segment, + without the method signature and the return statement. + + The fields storing intermediate codes that are populated by this method: + - NONE """ codes = [] assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" prims = [node] if node.differentiable and node.custom else [prim for prim in node.prims] + for prim in prims: if len(prim.inputs()) == 1: itensors = self.tensor_naming(prim.inputs()[0], prefix_attr='self.') @@ -574,7 +836,13 @@ def emit_adapter_code(self, node: IRAdapter) -> List[str]: codes.append(code) return codes - def emit_reducer_init(self, node: IRWeightReducer): + def emit_reducer_init(self, node: IRWeightReducer) -> None: + """ + Emit code to initialize involved reducer objects in `__init__`. + + The fields storing intermediate codes that are populated by this method: + - `model_init_statements` + """ # reducer init interface reducer_init = '{reducer} = cube.runtime.adapter.Reducer(ranks={ranks})' reducer_add = 'self.add_reducer({reducer})' @@ -582,17 +850,23 @@ def emit_reducer_init(self, node: IRWeightReducer): # create reducer in declare region weights = node.inputs() reducer_name = f'self.wreducer{node._id}' - self.declare_region.append('') + self.model_init_statements.append('') init_code = reducer_init.format(reducer=reducer_name, ranks=node.device) - self.declare_region.append(init_code) + self.model_init_statements.append(init_code) weights = [self.tensor_naming(t, prefix_attr='self.') for t in weights] for weight in weights: add_param_code = add_param.format(reducer=reducer_name, weight=weight) - self.declare_region.append(add_param_code) + self.model_init_statements.append(add_param_code) add_code = reducer_add.format(reducer=reducer_name) - self.declare_region.append(add_code) + self.model_init_statements.append(add_code) def emit_reducer_call(self, node: IRWeightReducer): + """ + Emit the statment to invoke a reducer object. + + The fields storing intermediate codes that are populated by this method: + - NONE + """ reducer_name = f'self.wreducer{node._id}' code = f'{reducer_name}.allreduce()' return [code] @@ -602,10 +876,9 @@ def clear(self): Clear buffer that used for generating code """ # module init code - self.declare_region: List[str] = list() + self.model_init_statements: List[str] = list() # module forward code - self.forward_region_units: List[List[str]] = list() - self.forward_region: List[str] = list() + self.model_methods_bodies: List[List[str]] = list() # module member name self.symbols = SymbolTable() @@ -645,15 +918,15 @@ def gen(self, device: int, outfile=None, attach=False) -> str: fb.insert_body(code) else: for i, node in enumerate(device_nodes): + # Decrement reference counts for output tensors that are no longer used + # Tensors here need to release *before* the i-th statement. + tensors : Optional[List[IRTensor]] = lifetime_by_line_id.get(i, None) + if tensors is not None: # not necessarily to have one after each line + fb.insert_body(self.emit_tensors_release(tensors)) + name = self.node_naming(node) code = self.emit_node(node, name=name) fb.insert_body(code) - - # decrement reference counts for output tensors that are no longer used - tensors : Optional[List[IRTensor]] = lifetime_by_line_id.get(i, None) - if tensors is not None: # not necessarily to have one after each line - tnames : Generator = (self.tensor_naming(t) for t in tensors) - fb.insert_body(', '.join(tnames) + ' = ' + ', '.join(['None'] * len(tensors))) # return code outputs = self.return_naming(self.execplan.graph.outputs()) @@ -702,14 +975,9 @@ def emit_node(self, node: IRCell, name: str) -> str: ) # emit backward else: - input_tensors = [t for t in node.mirror.inputs() if \ - isinstance(t, IRSubTensor) and \ - t.requires_grad and \ - not t.is_param() - ] - output_tensors = [t for t in node.mirror.outputs() if isinstance(t, IRSubTensor)] - input_grads = [t.grad for t in input_tensors] - output_grads = [t.grad for t in output_tensors] + input_tensors, output_tensors, output_grads, input_grads = \ + get_backward_callsite_io_tensors(node) + for idx, tensor in enumerate(output_grads): if isinstance(tensor, float): assert tensor == 1.0, "Loss gradient should be 1.0" diff --git a/requirements.txt b/requirements.txt index 7da97869..ac1c8955 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ einops matplotlib pytest +more-itertools --find-links https://download.pytorch.org/whl/torch_stable.html torch==1.11.0+cu113 \ No newline at end of file diff --git a/tests/test_codegen.py b/tests/test_codegen.py new file mode 100644 index 00000000..a525a2b8 --- /dev/null +++ b/tests/test_codegen.py @@ -0,0 +1,185 @@ +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +import pytest +from cube.codegen.codegen import ModelCodeGen +from cube.execplan.execplan import ExecutionPlan + +from cube.graph.graph import IRGraph, IRSegment +from cube.ir.cten import IRCell, IRTensor +from cube.ir.operator import IRFwOperation +from cube.ir.tensor import IRFullTensor, IRSubTensor + +# Override tensor naming to omit TensorID since the ID assignment is too hard to predict. +class FakeModelCodeGen(ModelCodeGen): + def tensor_naming(self, tensor: Any, prefix_attr: Optional[str] = None) -> str: + if isinstance(tensor, IRTensor): + name = tensor.name.replace(".", "_") + if prefix_attr is not None and tensor.is_param(): + name = prefix_attr + name + return name + else: + return super().tensor_naming(tensor, prefix_attr) + +def make_nodes(args_list: list, input_vars:List[str], output_vars:List[str]) \ + -> Tuple[List[IRFwOperation], List[IRTensor], List[IRTensor]]: + """ + Each element of `args_list` is in a form of: + + (RCGID:Optional[int], OutputNames:str|List[str], Signature:str, OpArg...:OpArgType...) + + If any `OpArg` is string, it's automatically mapped to a `IRTensor` with the same name. + + E.g. + ``` + [ + ("sum_res", "sum_fn", "a", IRFullTensor(name="b")) + (1, ["prod_res"], "prod_fn", "sum", "a") + ] + ``` + ... results in + ``` + sum_res = sum_fn(a, b) + def recompute(sum_res, a): + prod_res = prod_fn(sum_res, a) + return prod_res + prod_res = checkpoint(recompute, sum_res, a) + ``` + + REMARK: + `signature:str` will affect how the call is dumped, see also `cube/codegen/frontend_mapping.py`. + Generally if it's not a 'torch.some_fn' operator, it's dumped as-it-is. + """ + var_tensor_map = dict() + + def _convert(output_names:Union[str, List[str]], signature:str, op_args, rc_gid:Optional[int]): + if type(output_names) is str: + output_names = [output_names] + + op_kwargs = {} + if type(op_args[-1]) is dict: + op_args, op_kwargs = op_args[:-1], op_args[-1] + + mapped_inputs = [var_tensor_map.setdefault(arg, IRFullTensor(name=arg).tosub()) if type(arg) is str else arg for arg in op_args] + mapped_outputs = [var_tensor_map.setdefault(oname, IRFullTensor(name=oname).tosub()) for oname in output_names] + + op = IRFwOperation("not_matter_name", signature, len(mapped_inputs), len(output_names)) + for i, input in enumerate(mapped_inputs): + op.set_input(i, input) + for i, output in enumerate(mapped_outputs): + op.set_output(i, output) + op.kwargs.update(op_kwargs) + + # All devices are the same + op.device = 0 + + op.recompute = rc_gid + + return op + + def convert(args): + rc_gid = None + if type(args[0]) is int: + rc_gid, args = args[0], args[1:] + return _convert(args[0], args[1], args[2:], rc_gid) + + nodes = [convert(args) for args in args_list] + inputs = [var_tensor_map[n] for n in input_vars] + outputs = [var_tensor_map[n] for n in output_vars] + return nodes, inputs, outputs + +def gen(node_defs, invars, outvars): + nodes, inputs, outputs = make_nodes(node_defs, invars, outvars) + # REMARK + # Do not directly create a 'IRSegment' from 'nodes', instead, re-retrieve the IRSegment + # using 'graph.segment(nodes)' + # Because we rely on proper dataflow analysis when segmentation, which requires all nodes + # are properly registered/'attach'-ed into the graph. + graph = IRGraph(nodes, inputs, outputs, "module_name_not_matter") + segment = graph.segment(nodes) + assert list(segment.inputs()) == inputs + assert list(segment.outputs()) == outputs + + codegen = FakeModelCodeGen(ExecutionPlan(graph)) + code : list = codegen.emit_segment_code(segment) + return str.join("\n", code) + + +def test_codegen_segment_recompute__simple(): + code = gen([ + ("c", "add", "a", "b"), + ("d", "add", "a", "c"), + ], invars=["a","b"], outvars=["d"]) + assert code == """\ +c = add(a, b) +del b +d = add(a, c)""" + + +def test_codegen_segment_recompute_rc__simple(): + code = gen([ + ("c", "add", "a", "b"), + (1, "d", "add", "a", "c"), + ], invars=["a","b"], outvars=["d"]) + + assert code == """\ +c = add(a, b) +del b + +def recompute(a, c): + d = add(a, c) + return d + +d = ckpt.checkpoint(recompute, a, c)""" + + +def test_codegen_segment_recompute_rc__del_args(): + code = gen([ + ("c", "add", "a", "b"), + (1, "d", "add", "a", "c"), + ("e", "sub", "d", "a"), + ], invars=["a","b"], outvars=["e"]) + + assert code == """\ +c = add(a, b) +del b + +def recompute(a, c): + d = add(a, c) + return d + +d = ckpt.checkpoint(recompute, a, c) +del c +e = sub(d, a)""" + + +def test_codegen_segment_recompute_rc__multi_rc(): + code = gen([ + ("c", "add", "a", "b"), + (1, "d", "add", "a", "c"), + (1, "e", "sub", "d", "a"), + ("f", "mul", "a", "d"), + (2, "g", "div", "f", "e"), + (2, "h", "pow", "f", "g"), + ], invars=["a","b"], outvars=["g","h"]) + + assert code == """\ +c = add(a, b) +del b + +def recompute(a, c): + d = add(a, c) + del c + e = sub(d, a) + return d, e + +d, e = ckpt.checkpoint(recompute, a, c) +del c +f = mul(a, d) +del a, d + +def recompute(f, e): + g = div(f, e) + del e + h = pow(f, g) + return g, h + +g, h = ckpt.checkpoint(recompute, f, e)""" # e,f will be auto rel-ed when it returns From 954569e1dc5f92a76d91627d3acf1c1f9972f6c0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 5 Aug 2022 20:44:45 +0800 Subject: [PATCH 0949/1892] fix multi output tensor parse under rc region --- cube/graph/graph.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 3caa3ef2..3708410f 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -245,22 +245,22 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: if len(producers) == 0 or any(p not in nodes for p in producers): if itensor not in itdevs: itdevs[itensor] = [] - if itensor.device not in itdevs[itensor]: + devs = set(itensor.device) + if devs not in itdevs[itensor]: inputs.append(itensor) - itdevs[itensor].append(itensor.device) + itdevs[itensor].append(devs) # update outputs otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] for otensor in otensors: - if otensor in self.outputs(): - outputs.append(otensor) consumers = [c for c in otensor.parent.consumers if set(c.device).issubset(set(node.device))] # no consumer usually means the loss or cross device-group - if len(consumers) == 0 or any(c not in nodes for c in consumers): + if otensor in self.outputs() or len(consumers) == 0 or any(c not in nodes for c in consumers): + devs = set(otensor.device) if otensor not in otdevs: otdevs[otensor] = [] - if otensor.device not in otdevs[otensor]: + if devs not in otdevs[otensor]: outputs.append(otensor) - otdevs[otensor].append(otensor.device) + otdevs[otensor].append(devs) segment = IRSegment(nodes, inputs, outputs) return segment @@ -338,8 +338,8 @@ def attach(self, node: IRCell, index, reset_dependency=False): if isinstance(itensor, IRSubTensor) and itensor not in itensors: itensors.append(itensor) for itensor in itensors: - if itensor.parent._id not in self._full_tensors: - self._full_tensors[itensor.parent._id] = itensor.parent + if itensor.parent.tid not in self._full_tensors: + self._full_tensors[itensor.parent.tid] = itensor.parent idx = 0 for consumer in itensor.parent.consumers: if self.nodes().index(consumer) < index: From 67b0d7f296bf3e5b1e509d6e9558a97f1b2a8e4f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Aug 2022 14:46:56 +0800 Subject: [PATCH 0950/1892] functionality support: allow an operator has different input but shares a same full tensor: changeed producer and consumer; allow partitioned operators to be placed on same device: local adapter fusion and multiref insertation --- cube/graph/function/function.py | 6 +- cube/graph/gener/gen.py | 219 +++++++++++++++++++++++++++++++- cube/graph/gener/layout.py | 8 +- cube/graph/graph.py | 14 ++ cube/ir/tensor.py | 44 ++++--- 5 files changed, 261 insertions(+), 30 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index f22ddb7f..340e4eef 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -22,7 +22,7 @@ from cube.graph.torch_dtype_mapping import DType2IRDType, TorchScalarTypeEnumMap -def Identity(signature, inputs): +def Identity(signature, inputs: List[IRTensor]): signature = 'cube.runtime.function.identity' eshape = ShapeAnno.create_shape_str(inputs[0].shape) anno = OpAnno.create_op_str([eshape], [eshape]) @@ -796,13 +796,13 @@ def Embedding(signature, inputs: List): return IRDimops(signature, [anno], [itensor, weight], 'embedding', padding_idx=padding_idx, start=start, stop=stop) -def MultiRef(signature, inputs: List[IRFullTensor]): +def MultiRef(signature, inputs: List[IRTensor]): """ cube.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] """ signature = 'cube.runtime.function.multiref' itensor, times = inputs - assert isinstance(itensor, IRFullTensor), "require all inputs to be IRSubTensor" + assert isinstance(itensor, IRTensor), "require all inputs to be IRSubTensor" assert isinstance(times, int), "require int for second input" anno = '* -> ' + ', '.join('*' for _ in range(times)) node = IRDimops(signature, [anno], [itensor], 'multiref', times=times) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 09566a31..53804256 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,15 +1,16 @@ +import itertools from typing import Dict, List, Optional, Tuple -import warnings import copy from cube.graph.graph import IRGraph -from cube.graph.function import Identity +from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap -from cube.ir.adapter import IRAdapter, IRWeightReducer from cube.ir.operator import IRBpOperation, IRFwOperation from cube.ir.adapter.prim import IRAdapterPrim +from cube.ir.adapter import IRAdapter, IRWeightReducer +from cube.graph.function.function import Add, Cat, Identity, MultiRef from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim from cube.graph.gener.layout import GridLayout @@ -130,6 +131,11 @@ def gen_activation(graph: IRGraph) -> IRGraph: ftensor.consumers[0].device == ftensor.producers[0].device: continue + # optimization: local fusion on producer + if graph.train: + ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) + IRAdapterGener.local_consumer_multiref(graph, ftensor) + ptensors, ctensors = ftensor.ptensors, ftensor.ctensors pdevs = tuple(ptensor.device[0] for ptensor in ptensors) cdevs = tuple(ctensor.device[0] for ctensor in ctensors) @@ -177,6 +183,8 @@ def gen_activation(graph: IRGraph) -> IRGraph: if badapter is not None: bidx = min(graph.nodes().index(consumer) for consumer in ftensor.grad.consumers) graph._nodes.insert(bidx, badapter) + + # print(graph.extra_repr()) return graph @staticmethod @@ -369,3 +377,208 @@ def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRA f"Remain Tensor:\n\t{remain}" ) return prims + + @staticmethod + def local_producer_fusion(graph: IRGraph, ftensor: IRFullTensor) -> IRFullTensor: + """! + Fuse the producer tensors using concat and add. + This will add a new full tensor by chaging from: + producer --(ftensor)--> consumer + to: + producer --(ftensor)--> fused nodes --(new ftensor)--> consumer + + @param tensors List[IRSubTensor]: tensors to be fused in local device + + @return new_ftensor IRFullTensor: the new full tensor. + If cannot fuse, the original ftensor. + """ + + def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTensor: + parent = tensor.parent.like() if share is None else share + return parent.select(tensor.indmap, tensor.valmap) + + # collect device tensors + devtensors: Dict[int, List[IRSubTensor]] = dict() + # devid: old tensor -> [nodes,] + fuse_tensors: Dict[int, Dict[IRSubTensor, List[IRSubTensor]]] = dict() + tensor_map: Dict[int, Dict[IRSubTensor, IRSubTensor]] = dict() + + for tensor in ftensor.ptensors: + assert len(tensor.device) == 1 + devid = tensor.device[0] + if devid not in devtensors: + devtensors[devid] = [] + fuse_tensors[devid] = dict() + tensor_map[devid] = dict() + devtensors[devid].append(tensor) + fuse_tensors[devid][tensor] = [tensor] + tensor_map[devid][tensor] = tensor + + nodes: List[IRCell] = [] + for devid, tensors in devtensors.items(): + if len(tensors) == 1: + continue + + # repeatly search for combinable tensors + while True: + can_merge = False + out = None + node = None + for t1, t2 in itertools.combinations(tensors, 2): + catdim = t1.catdim(t2) + if catdim is not None: + t1, t2 = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] + out = t1.concat(t2, dim=catdim) + node = Cat( + 'torch.cat', + ([tensor_map[devid][t1], tensor_map[devid][t2]], catdim) + ) + can_merge = True + break + elif t1.accumable(t2): + out = t1.accum(t2) + node = Add( + 'torch.add', + [tensor_map[devid][t1], tensor_map[devid][t2]] + ) + can_merge = True + break + # each time when creats a merge node, the output will be + # updated with a new full tensor. The corresponding input + # will be set according to the previous node output + if can_merge: + tensor_map[devid][out] = like(out) + node.set_output(0, tensor_map[devid][out]) # update output to a new full tensor + tensors.remove(t1) + tensors.remove(t2) + tensors.append(out) + nodes.append(node) + node.device = devid + fuse_tensors[devid][out] = fuse_tensors[devid][t1] + fuse_tensors[devid][t2] + del fuse_tensors[devid][t1] + del fuse_tensors[devid][t2] + else: + break + + if len(nodes) == 0: return ftensor + + new_ftensor = ftensor.like() + + # update consumer + min_idx = len(graph.nodes()) + assert len(ftensor.ctensors) == len(ftensor.consumers) + for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + fidx = graph.detach(consumer) + consumer.set_input( + consumer.inputs().index(ctensor), + new_ftensor.select(ctensor.indmap, ctensor.valmap) + ) + graph.attach(consumer, fidx) + min_idx = min(fidx, min_idx) + + # insert new producer + for devid, tensors in fuse_tensors.items(): + for ptensor in tensors: + new_tensor = like(ptensor, share=new_ftensor) + if len(tensors[ptensor]) == 1: + node = Identity('', [ptensor]) + node.device = devid + node.set_output(0, new_tensor) + nodes.append(node) + else: + for node in nodes: + if node.output(0) == tensor_map[devid][ptensor]: + node.set_output(0, new_tensor) + + for node in nodes[::-1]: + # print(node) + assert node not in graph.nodes() + assert len(node.outputs()) == 1 + graph.attach(node, min_idx) + + # insert and update backward node + if graph.train: + # update backward node + for consumer in new_ftensor.consumers: + assert isinstance(consumer.mirror, IRBpOperation) + bidx = graph.detach(consumer.mirror) + consumer.mirror.update() + graph.attach(consumer.mirror, bidx) + # insert backward node + bnodes = [node.gen_backward() for node in nodes] + bidx = min(graph.nodes().index(producer.mirror) for producer in ftensor.producers) + for bnode in bnodes: + bnode.device = bnode.mirror.device + graph.attach(bnode, bidx) + + return new_ftensor + + @staticmethod + def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): + """! + If a device have a same sub-tensor to be consumed multiple times, + then create a multiref forward node for it to make + each sub-tensor to be consumed only once in each device. + + This is to adapt with pytorch autograd function. + + producer -> consumers[0,1] + + producer -> multiref -> consumer[0] + |-----> consumer[1] + + @param graph IRGraph + @param ftensor IRFullTensor: the forward full tensor + """ + # collect to consumer tensors of each device + devtensors: Dict[int, Dict[IRSubTensor, List[IRCell]]] = dict() + for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + assert len(ctensor.device) == 1 + devid = ctensor.device[0] + if devid not in devtensors: + devtensors[devid] = dict() + if ctensor not in devtensors[devid]: + devtensors[devid][ctensor] = [] + devtensors[devid][ctensor].append(consumer) + + # add multiref forward node + multirefs: Dict[MultiRef, List[IRFwOperation]] = dict() + for devid in devtensors: + for ctensor in devtensors[devid]: + consumers = devtensors[devid][ctensor] + if len(consumers) == 1: + continue + multiref = MultiRef(None, [ctensor, len(consumers)]) + multiref.device = devid + ftensors = [ctensor.parent.like() for _ in range(len(consumers))] + itensors = [ft.select(ctensor.indmap, ctensor.valmap) for ft in ftensors] + for idx, itensor in enumerate(itensors): + multiref.set_output(idx, itensor) + + # update consumer + min_fidx = len(graph.nodes()) + for itensor, consumer in zip(itensors, consumers): + fidx = graph.detach(consumer) + idx = consumer.inputs().index(ctensor) + consumer.set_input(idx, itensor) + graph.attach(consumer, fidx) + min_fidx = min(fidx, min_fidx) + + # insert forward multiref + graph.attach(multiref, min_fidx) + multirefs[multiref] = consumers + + # insert / update backward + if graph.train: + for multiref, consumers in multirefs.items(): + # update consumer backward + for consumer in consumers: + assert isinstance(consumer.mirror, IRBpOperation) + bidx = graph.detach(consumer.mirror) + consumer.mirror.update() + graph.attach(consumer.mirror, bidx) + # insert backward + bnode = multiref.gen_backward() + bnode.device = multiref.device + bidx = max(graph.nodes().index(consumer.mirror) for consumer in consumers) + graph.attach(bnode, bidx+1) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 822d0458..10e127d6 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -326,10 +326,10 @@ def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): for subtensor in subtensors: tid = id(subtensor) # set up replica - if subtensor._id not in replicas: - replicas[subtensor._id] = [] - _tindex[tid] = [len(replicas[subtensor._id])] - replicas[subtensor._id].append(subtensor) + if subtensor.tid not in replicas: + replicas[subtensor.tid] = [] + _tindex[tid] = [len(replicas[subtensor.tid])] + replicas[subtensor.tid].append(subtensor) # setup value _tindex[tid].append(subtensor.valmap[0]) vchunks.add(subtensor.valmap[1]) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 3708410f..555fb2a0 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -115,6 +115,11 @@ def __init__(self, self._nodes: List[IRCell] = list() self._parameters = list() self._full_tensors: Dict[int, IRFullTensor] = dict() + self._train: bool = any( + isinstance(node, IRBpOperation) or + (isinstance(node, IRSegment) and node.forward) or + (isinstance(node, IRAdapter) and node.forward) for node in nodes + ) self._sched = None # the schedule strategy @@ -153,6 +158,15 @@ def __init__(self, self.reset_dependency() + @property + def train(self) -> bool: + """! + Train flag. + + @return train bool: True if backward is required, otherwise False (inference only). + """ + return self._train + def reset_dependency(self): """ Reset the node dataflow dependency diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 0aad1958..a21311ee 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -289,32 +289,32 @@ def like(self): return tensor @property - def producers(self) -> List[IRCell]: + def producers(self) -> Tuple[IRCell]: """ Producer IRCell list """ - return self._producers + return tuple(self._producers) @property - def ptensors(self): + def ptensors(self) -> Tuple[IRTensor]: """ Produced IRSubTensor list correspongding to producer IRCell """ - return self._ptensors + return tuple(self._ptensors) @property - def consumers(self) -> List[IRCell]: + def consumers(self) -> Tuple[IRCell]: """ Consumer IRCell list """ - return self._consumers + return tuple(self._consumers) @property - def ctensors(self): + def ctensors(self) -> Tuple[IRTensor]: """ Consumed IRSubTensor list correspongding to consumer IRCell """ - return self._ctensors + return tuple(self._ctensors) def add_producer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): @@ -335,26 +335,30 @@ def add_consumer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): assert tensor in cell.inputs(), f"tensor {tensor} not in node: {cell} inputs" if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): raise TypeError("Expect an IRCell and an IRTensor") - assert cell not in self._consumers, f"{cell} already exists as consumer" + if cell in self._consumers: + for idx, consumer in enumerate(self._consumers): + if cell == consumer: + assert self._ctensors[idx] != tensor, f"double add a same consumer-tensor pair: {cell}" self._consumers.insert(idx, cell) self._ctensors.insert(idx, tensor) - for t in self.ctensors: + for t in self._ctensors: t._dirty_grad = True def rm_producer(self, cell: IRCell) -> int: - if cell not in self.producers: + if cell not in self._producers: raise KeyError(f"Cell {cell} not found in producer") - idx = self.producers.index(cell) - self.producers.pop(idx) - self.ptensors.pop(idx) + while cell in self._producers: + idx = self._producers.index(cell) + self._producers.pop(idx) + self._ptensors.pop(idx) return idx def rm_consumer(self, cell: IRCell) -> int: - if cell not in self.consumers: + if cell not in self._consumers: raise KeyError(f"Cell {cell} not found in producer") - idx = self.consumers.index(cell) - self.consumers.pop(idx) - self.ctensors.pop(idx) + idx = self._consumers.index(cell) + self._consumers.pop(idx) + self._ctensors.pop(idx) return idx def clear_producer_consumer(self) -> int: @@ -377,7 +381,7 @@ def grad(self, val: Optional[Union[IRTensor, float]]): self._requires_grad = False if val is None else True if isinstance(val, IRFullTensor): assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." - for tensor in self.ctensors + self.ptensors: + for tensor in self._ctensors + self._ptensors: tensor._dirty_grad = True @property @@ -391,7 +395,7 @@ def requires_grad(self, val: bool): self.grad = IRFullTensor(self.shape, 'g' + self.name, False).as_grad() elif not val and self.grad is not None: self.grad = None - for tensor in self.ctensors + self.ptensors: + for tensor in self._ctensors + self._ptensors: tensor._dirty_grad = True def as_param(self): From 8bc0f40fc4f1272c2a126e49a4fd830f8857ad2e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Aug 2022 14:47:25 +0800 Subject: [PATCH 0951/1892] add coshard policy --- examples/nlp/gpt/model.py | 12 ++-- examples/nlp/gpt/policy/spmd.py | 99 +++++++++++++++++++++++++++------ examples/nlp/gpt/train.py | 2 +- 3 files changed, 88 insertions(+), 25 deletions(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index b38f6a84..c9b4480c 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -11,9 +11,9 @@ class Config: seqlen = 1024 # 340 M model - # embed_dim = 1024 - # layers = 24 - # attention_heads = 16 + embed_dim = 1024 + layers = 8 # 24 + attention_heads = 16 # 1.3 B model # embed_dim = 2048 @@ -21,9 +21,9 @@ class Config: # attention_heads = 32 # 2.6 B model - embed_dim = 2560 - layers = 32 - attention_heads = 32 + # embed_dim = 2560 + # layers = 32 + # attention_heads = 32 # 6.7 B model # embed_dim = 4096 diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index ced5f2b3..ac0428c0 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -1,7 +1,45 @@ +from typing import List + from cube.graph import IRGraph +from cube.graph.function.anchor import IRGraphAnchor from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +# ========================= parallelisms ================================= + +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], + idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# coshard +def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, + idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) + assert sub_nodes is not None + for devid in devs: + for coid in range(colocate): + sub_node = sub_nodes[devid * colocate + coid] + graph.assign(sub_node, devid) + return sub_nodes + + +# ========================= parallelisms ================================= + + def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 # print(graph.extra_repr()) @@ -13,39 +51,64 @@ def PASSingle(graph: IRGraph, resource): def PASMegatronTP(graph: IRGraph, resource): tp_size = resource.ngpus + tp_devs = list(range(tp_size)) fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - if isinstance(comment, str): - for sub_node in sub_nodes: - sub_node.comment = comment - assert all(isinstance(n, IRFwOperation) for n in sub_nodes), f"Fail to partition node {node}" - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return sub_nodes + # annotating code structure -- not consider multiref on embedding weight + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + # why -1: multiref + fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + + # attention + attns = [node for node in fnodes if node.name == 'self_attention'] + for attn in attns: + _tp(graph, attn, tp_devs, idx=1, dim=0) + + # feedforward + ffns = [node for node in fnodes if node.name == 'feedforward'] + for ffn in ffns: + _tp(graph, ffn, tp_devs, idx=1, dim=0) + + # replicate other nodes + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: + _replica(graph, node, tp_devs) + + return graph + + +def PASMeshShard(graph: IRGraph, resource): + + # print(graph.extra_repr()) + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] # annotating code structure -- not consider multiref on embedding weight - multirefs = [node for node in fnodes if isinstance(node, IRFwOperation) and node.name == 'multiref'][1:] - for idx in range(0, len(multirefs), 2): - multirefs[idx].comment = f'====> start of transformer {idx // 2}' + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + # why -1: multiref + fnodes[idx-1].comment = f'===> start of transformer layer {lid}' # attention attns = [node for node in fnodes if node.name == 'self_attention'] for attn in attns: - tensor_parallelism(attn, idx=1, dim=0, num=tp_size) + # _tp(graph, attn, tp_devs, idx=1, dim=0) + _coshard(graph, attn, tp_devs, colocate=2, idx=1, dim=0) # feedforward ffns = [node for node in fnodes if node.name == 'feedforward'] for ffn in ffns: - tensor_parallelism(ffn, idx=1, dim=0, num=tp_size) + # _tp(graph, ffn, tp_devs, idx=1, dim=0) + _coshard(graph, ffn, tp_devs, colocate=4, idx=1, dim=0) # replicate other nodes for node in graph.nodes(): if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: - rnodes = graph.replicate(node, times=tp_size) - for idx, rnode in enumerate(rnodes): - graph.assign(rnode, idx) + _replica(graph, node, tp_devs) + # print(graph.extra_repr()) return graph diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 95549e55..2233699c 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMegatronTP + examples/nlp/gpt/train.py --policy PASMeshShard --fp16 """ From e41b52fd861c7a18029f70000bd8d13111a85441 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Aug 2022 15:15:32 +0800 Subject: [PATCH 0952/1892] bring back distnn --- handcraft/gpt3/train.py | 4 +- handcraft/mbart/train.py | 2 +- handcraft/module/distnn.py | 278 +++++++++++++++++++++++++++++++++++++ handcraft/swin/train.py | 2 +- 4 files changed, 282 insertions(+), 4 deletions(-) create mode 100644 handcraft/module/distnn.py diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py index 895d8368..12aedf1c 100644 --- a/handcraft/gpt3/train.py +++ b/handcraft/gpt3/train.py @@ -31,7 +31,7 @@ from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer -from cube.runtime.adapter.distnn import AllReduceIdentity, IdentityAllreduce, AllGatherSplit +from handcraft.module.distnn import AllReduceIdentity, IdentityAllreduce, AllGatherSplit from cube.profiler import CudaTimer from cube.profiler.memory import memory_summary @@ -486,7 +486,7 @@ class Pooler(torch.nn.Module): def __init__(self): super().__init__() - self.dense = troch.nn.Linear(config.hidden_size, config.hidden_size) + self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size) def forward(self, hidden_states, sequence_index=0): pooled = hidden_states[:, sequence_index, :] diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py index 58139b0f..ea9fddfc 100644 --- a/handcraft/mbart/train.py +++ b/handcraft/mbart/train.py @@ -19,7 +19,7 @@ import cube from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer -from cube.runtime.adapter.distnn import ReduceBroadcast, AllReduceIdentity, IdentityAllreduce +from handcraft.module.distnn import ReduceBroadcast, AllReduceIdentity, IdentityAllreduce from cube.profiler import CudaTimer from cube.profiler.memory import memory_summary diff --git a/handcraft/module/distnn.py b/handcraft/module/distnn.py new file mode 100644 index 00000000..bd0878c4 --- /dev/null +++ b/handcraft/module/distnn.py @@ -0,0 +1,278 @@ +from typing import List +import torch + +from cube.profiler.timer import CudaTimer +from cube.runtime.device import DeviceGroup + + +class SendRecv(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dst: int, group): + CudaTimer().start(field_name='comm') + ctx._tsize = input_.size() + ctx._tdtype = input_.dtype + ctx._src = dst + if not input_.is_contiguous(): + input_ = input_.contiguous() + sendop = torch.distributed.P2POp( + torch.distributed.isend, input_, dst + ) + reqs = torch.distributed.batch_isend_irecv([sendop]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, _grad: torch.Tensor): + CudaTimer().start(field_name='comm') + size = ctx._tsize + dtype = ctx._tdtype + src = ctx._src + grad = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) + recvop = torch.distributed.P2POp( + torch.distributed.irecv, grad, src + ) + reqs = torch.distributed.batch_isend_irecv([recvop]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return grad, None, None + + +class RecvSend(torch.autograd.Function): + + @staticmethod + def forward(ctx, size, dtype, src: int, ranks: List[int]): + CudaTimer().start(field_name='comm') + ctx._tsize = size + ctx._tdtype = dtype + ctx._dst = src + input_ = torch.empty( + size, dtype=dtype, device=torch.cuda.current_device(), + requires_grad=True) + recvop = torch.distributed.P2POp( + torch.distributed.irecv, input_, src + ) + reqs = torch.distributed.batch_isend_irecv([recvop]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad: torch.Tensor): + CudaTimer().start(field_name='comm') + dst = ctx._dst + if not grad.is_contiguous(): + grad = grad.contiguous() + sendop = torch.distributed.P2POp( + torch.distributed.isend, grad, dst + ) + reqs = torch.distributed.batch_isend_irecv([sendop]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return None, None, None, None + + +class AllReduceIdentity(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, group): + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + CudaTimer().start(field_name='comm') + torch.distributed.all_reduce(input_, group=group) + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class IdentityAllreduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, group): + ctx._group = group + return input_ + + @staticmethod + def backward(ctx, grad_output): + world_size = torch.distributed.get_world_size(ctx._group) + if world_size == 1: + return grad_output, None + CudaTimer().start(field_name='comm') + torch.distributed.all_reduce(grad_output, group=ctx._group) + CudaTimer().stop(field_name='comm') + return grad_output, None + + +class ReduceScatterAllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dim: int, group): + ctx._group = group + ctx._dim = dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_, + CudaTimer().start(field_name='comm') + input_tensors = input_.chunk(world_size, dim) + rank = torch.distributed.get_rank(group) + input_ = torch.empty_like(input_tensors[rank], requires_grad=True) + torch.distributed.reduce_scatter( + input_, input_tensors, group=group + ) + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output): + group = ctx._group + dim = ctx._dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output + CudaTimer().start(field_name='comm') + rank = torch.distributed.get_rank(group) + tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] + tensor_list[rank] = grad_output + torch.distributed.all_gather(tensor_list, grad_output, group=group) + grad = torch.cat(tensor_list, dim=dim).contiguous() + CudaTimer().stop(field_name='comm') + return grad, None, None + + +class AllGatherSplit(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dim: int, group): + ctx._group = group + ctx._dim = dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + CudaTimer().start(field_name='comm') + rank = torch.distributed.get_rank(group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + output = torch.cat(tensor_list, dim=dim).contiguous() + CudaTimer().stop(field_name='comm') + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + group = ctx._group + dim = ctx._dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output + CudaTimer().start(field_name='comm') + input_list = grad_output.chunk(world_size, dim=dim) + rank = torch.distributed.get_rank(group) + grad = input_list[rank].contiguous() + CudaTimer().stop(field_name='comm') + return grad, None, None + + +class SplitAllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dim: int, group): + ctx._group = group + ctx._dim = dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + CudaTimer().start(field_name='comm') + input_list = input_.chunk(world_size, dim=dim) + rank = torch.distributed.get_rank(group) + input_ = input_list[rank].contiguous() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + group = ctx._group + dim = ctx._dim + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output + CudaTimer().start(field_name='comm') + rank = torch.distributed.get_rank(group) + tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] + tensor_list[rank] = grad_output + torch.distributed.all_gather(tensor_list, grad_output, group=group) + grad = torch.cat(tensor_list, dim=dim).contiguous() + CudaTimer().stop(field_name='comm') + return grad, None, None + + +class ReduceBroadcast(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, dst: int, group): + ctx._dst = dst + ctx._group = group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + CudaTimer().start(field_name='comm') + torch.distributed.reduce(input_, dst, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output): + src = ctx._dst + group = ctx._group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output, None, None + CudaTimer().start(field_name='comm') + torch.distributed.broadcast(grad_output, src, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return grad_output, None, None + + +class BroadcastReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_: torch.Tensor, src: int, group): + ctx._src = src + ctx._group = group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + CudaTimer().start(field_name='comm') + torch.distributed.broadcast(input_, src, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return input_ + + @staticmethod + def backward(ctx, grad_output): + dst = ctx._src + group = ctx._group + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return grad_output, None, None + CudaTimer().start(field_name='comm') + if not grad_output.is_contiguous(): + grad_output = grad_output.contiguous() + torch.distributed.reduce(grad_output, dst, group=group) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm') + return grad_output, None, None diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py index 83ffb5f7..d7f70506 100644 --- a/handcraft/swin/train.py +++ b/handcraft/swin/train.py @@ -19,7 +19,7 @@ from cube.profiler.memory import memory_summary, model_summary from cube.runtime.adapter.reducer import Reducer from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.distnn import IdentityAllreduce, AllReduceIdentity, AllGatherSplit +from handcraft.module.distnn import IdentityAllreduce, AllReduceIdentity, AllGatherSplit from handcraft.module.schedule import schedule_1f1b from handcraft.module.stage import PipeStage, layer_division from handcraft.swin.utils import create_position_bias, create_position_index, trunc_normal_, window_partition, window_reverse, DropPath From 647ceaf7c81f4d4dd2171a86f6cd64eda93b77cb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Aug 2022 18:31:50 +0800 Subject: [PATCH 0953/1892] recompute for additional generated node --- cube/graph/gener/gen.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 53804256..94aef8ec 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -387,6 +387,10 @@ def local_producer_fusion(graph: IRGraph, ftensor: IRFullTensor) -> IRFullTensor to: producer --(ftensor)--> fused nodes --(new ftensor)--> consumer + Recompute policy: if all the producers are recomputed in a same + recompute group, then the additional generated cat/add are also + apllied with same recompute region. Otherwise no recompute. + @param tensors List[IRSubTensor]: tensors to be fused in local device @return new_ftensor IRFullTensor: the new full tensor. @@ -414,7 +418,7 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens fuse_tensors[devid][tensor] = [tensor] tensor_map[devid][tensor] = tensor - nodes: List[IRCell] = [] + nodes: List[IRFwOperation] = [] for devid, tensors in devtensors.items(): if len(tensors) == 1: continue @@ -462,6 +466,12 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens if len(nodes) == 0: return ftensor + # recompute + rcid = set(producer.recompute for producer in ftensor.producers) + rcid = list(rcid)[0] if len(rcid) == 1 else None + for node in nodes: + node.recompute = rcid + new_ftensor = ftensor.like() # update consumer From 613d4f1ccd019da3d21b6d7e1d005fa823fd2a4a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Aug 2022 18:32:12 +0800 Subject: [PATCH 0954/1892] add coshard test --- examples/nlp/gpt/policy/spmd.py | 1 + tests/test_examples.sh | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index ac0428c0..7729874b 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -30,6 +30,7 @@ def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) assert sub_nodes is not None + graph.recompute(sub_nodes) for devid in devs: for coid in range(colocate): sub_node = sub_nodes[devid * colocate + coid] diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 67bd2aac..0d0022d9 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -55,6 +55,11 @@ OMP_NUM_THREADS=4 torchrun \ --nnodes=1 \ examples/nlp/gpt/train.py --policy PASMegatron --fp16 +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/nlp/gpt/train.py --policy PASMeshShard --fp16 + # test scientific model From 42ef0172f7b26f63a72d4b782df3ec451de2e548 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Aug 2022 19:31:05 +0800 Subject: [PATCH 0955/1892] fix full tensor delete --- cube/graph/graph.py | 5 ++++- cube/logics/model.py | 2 -- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 555fb2a0..95462681 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -329,6 +329,9 @@ def detach(self, node: IRCell, reset_dependency=False) -> int: otensors.append(otensor) for otensor in otensors: otensor.parent.rm_producer(node) + ftensor = otensor.parent + if len(ftensor.producers) == 0 and len(ftensor.consumers) == 0: + del self._full_tensors[otensor.parent.tid] if reset_dependency: self.reset_dependency() return index @@ -368,7 +371,7 @@ def attach(self, node: IRCell, index, reset_dependency=False): otensors.append(otensor) for otensor in otensors: if otensor.parent._id not in self._full_tensors: - self._full_tensors[itensor.parent._id] = itensor.parent + self._full_tensors[itensor.parent.tid] = itensor.parent idx = 0 for producer in otensor.parent.producers: if self.nodes().index(producer) < index: diff --git a/cube/logics/model.py b/cube/logics/model.py index bb9db939..10696476 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -59,8 +59,6 @@ def forward(graph: IRGraph, *args) -> IRGraph: while itensor in graph.outputs(): oidx = graph.outputs().index(itensor) graph.set_output(oidx, arg) - for itensor in itensors: - del graph._full_tensors[itensor.parent.tid] # dtype inference for node in graph.nodes(): From 0b46d30f93fc17d83b58e7bf06a6de99a2f27286 Mon Sep 17 00:00:00 2001 From: Zijian Ding Date: Mon, 8 Aug 2022 23:18:31 -0700 Subject: [PATCH 0956/1892] Add wrf2 ondim testcase, change wrf2 to pass policy through command line. --- examples/wrf/wrf2.py | 20 +++++++++++++++++++- tests/test_examples.sh | 7 ++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py index 56766f08..044d81c2 100644 --- a/examples/wrf/wrf2.py +++ b/examples/wrf/wrf2.py @@ -2,12 +2,30 @@ import torch.nn.functional as F from cube.runtime.syndata import SciLoopVariables +from cube.profiler.timer import CudaTimer, print_each_rank from examples.wrf.policy.naive import PAS +import examples.wrf.policy.onedim as onedim torch.set_default_tensor_type(torch.DoubleTensor) import cube +import argparse + +parser = argparse.ArgumentParser(description='') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') +args = parser.parse_args() + + +cube.init() +chosenPAS = PAS +policies = list(onedim.__dict__.keys()) +policies = [policy for policy in policies if policy.startswith('PAS')] +if args.policy in onedim.__dict__: + chosenPAS = onedim.__dict__[args.policy] + print_each_rank(f'using policy from onedim.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") class WRF(torch.nn.Module): def __init__(self, dt, ntau, nz, ny, nx, dz, dy, dx, device): @@ -455,7 +473,7 @@ def solve_tridiagonal_(self, varloader = SciLoopVariables(variables=[U, V, W, O, Theta, phi1, mu1], constants=[]) model = cube.SemanticModel(wrf, input_shapes=tuple(varloader.shapes)) - @cube.compile(model=model, dataloader=varloader, PAS=PAS) + @cube.compile(model=model, dataloader=varloader, PAS=chosenPAS) def train_iter(model, dataloader): U, V, W, O, Theta, phi1, mu1 = next(dataloader) U, V, W, O, Theta, phi1, mu1 = model(U, V, W, O, Theta, phi1, mu1) diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 0d0022d9..71ef5575 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -71,4 +71,9 @@ OMP_NUM_THREADS=4 torchrun \ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=1 \ --nnodes=1 \ - examples/wrf/wrf2.py + examples/wrf/wrf2.py --policy PAS + +OMP_NUM_THREADS=1 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/wrf/wrf2.py --policy PAS_ALL_Y From b027162456a636f72e3d2529d3bd46b5c354ef34 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 Aug 2022 14:29:17 +0800 Subject: [PATCH 0957/1892] feature support and fix parser bug: fix bug: parser will not consider Tensor? type (Optional[Tensor]) fix bug: graph attach adding full tensor bug (Zijian) add feature: tensor can be attribute of graph: 1) parameter (trainable), 2) buffer (un-trainable) --- cube/codegen/codegen.py | 30 ++++++++++------- cube/codegen/register.py | 57 --------------------------------- cube/graph/function/anchor.py | 2 +- cube/graph/function/function.py | 53 +++++++++++++++++++++++++++++- cube/graph/gener/gen.py | 5 ++- cube/graph/graph.py | 24 +++++++------- cube/graph/parser/mapping.py | 8 ++++- cube/graph/parser/parser.py | 10 +++--- cube/graph/parser/register.py | 26 +++++++++++---- cube/ir/cten.py | 54 +++++++++++++++++++++++++------ cube/ir/operator.py | 2 +- cube/ir/tensor.py | 16 ++++++--- 12 files changed, 179 insertions(+), 108 deletions(-) delete mode 100644 cube/codegen/register.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 97ef03bd..bf5849fd 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -38,7 +38,7 @@ def get_backward_callsite_io_tensors(bp_segment:IRSegment): input_tensors = [t for t in bp_segment.mirror.inputs() if \ isinstance(t, IRSubTensor) and \ t.requires_grad and \ - not t.is_param() + not t.is_attr() ] output_tensors = [t for t in bp_segment.mirror.outputs() if isinstance(t, IRSubTensor)] input_grads = [t.grad for t in input_tensors] @@ -87,7 +87,7 @@ def calc_tenvars_lifetime( lifetime : Dict[IRTensor, int] = dict() def is_temp_tensor(v): - return isinstance(v, IRSubTensor) and not v.is_param() + return isinstance(v, IRSubTensor) and not v.is_attr() lifetime.update((tsin, 0) for tsin in subgraph_inputs if is_temp_tensor(tsin)) @@ -190,7 +190,7 @@ def tensor_naming(self, tensor: Any, prefix_attr: Optional[str] = None) -> str: if '.' in tensor_name: tensor_name = tensor_name.split('.')[0] name = '_'.join([tensor_name, str(tensor.tid)]) - if prefix_attr is not None and tensor.is_param(): + if prefix_attr is not None and tensor.is_attr(): name = prefix_attr + name else: name = str(tensor) @@ -209,7 +209,7 @@ def tuple_naming(self, tensors: List[Any], """ names = [] for t in tensors: - if isinstance(t, IRTensor) and skip_attr and t.is_param(): + if isinstance(t, IRTensor) and skip_attr and t.is_attr(): continue names.append(self.tensor_naming(t, prefix_attr)) name = '(' + ', '.join(names + ['']) + ')' @@ -228,7 +228,7 @@ def return_naming(self, tensors: List[Any], """ names = [] for t in tensors: - if isinstance(t, IRTensor) and skip_attr and t.is_param(): + if isinstance(t, IRTensor) and skip_attr and t.is_attr(): continue names.append(self.tensor_naming(t, prefix_attr)) names = '_' if len(names) == 0 else ', '.join(names) @@ -252,6 +252,7 @@ def emit_tensors_release(self, tensors:Iterable[IRTensor]) -> str: tnames : Generator = (self.tensor_naming(t) for t in tensors) return 'del ' + ', '.join(tnames) + class AutogradAdapterCodeGen(CodeGen): """ Generate autograd adapter code (PyTorch) @@ -363,6 +364,7 @@ def __init__(self, execplan: ExecutionPlan): # model full code self.init_code: List[str] = [ '\n\n########## Generated Model Code ###########', + 'from typing import *', 'import torch', 'import torch.utils.checkpoint as ckpt', 'import cube', '', ''] # customized op code @@ -465,7 +467,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: args = list() for t in node.inputs(): if isinstance(t, IRSubTensor): - if not t.is_param(): + if not t.is_attr(): args.append(self.tensor_naming(t)) else: args.append(self.tensor_naming(t)) @@ -519,15 +521,20 @@ def emit_node_tensors_declare(self, node: IRCell): This method also populates `self.symbols : SymbolTable` to record the names of the variables for the tensors ever encountered. """ - - sign = 'torch.nn.Parameter(torch.empty({shape}, dtype={dtype}))' + psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" + bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}))" map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" for itensor in node.inputs(): name = self.tensor_naming(itensor, prefix_attr='self.') if isinstance(itensor, IRSubTensor): - if itensor.is_param() and not self.symbols.exist(name): + if itensor.is_attr() and not self.symbols.exist(name): self.symbols.create(name) - code = f'{name} = {sign.format(shape=tuple(itensor.shape), dtype=self.dtype_map(itensor.dtype))}' + sign = psign if itensor.is_param() else bsign + code = sign.format( + name=self.tensor_naming(itensor), + shape=tuple(itensor.shape), + dtype=self.dtype_map(itensor.dtype) + ) self.model_init_statements.append(code) tid = itensor.parent.tid slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) @@ -537,6 +544,7 @@ def emit_node_tensors_declare(self, node: IRCell): slicers=str(slicers), val_chunks=val_chunks ) self.model_init_statements.append(code) + self.model_init_statements.append('') if isinstance(itensor, str): if name.startswith('self.'): if not hasattr(self._ref_module, name[5:]): @@ -652,7 +660,7 @@ def recompute(tensor_2222): subseg = self.execplan.graph.segment(nodes) - inputs = [t for t in subseg.inputs() if not t.is_param()] + inputs = [t for t in subseg.inputs() if not t.is_attr()] input_names = [self.tensor_naming(t) for t in inputs] input_names_tuple = ', '.join(input_names) outputs = [t for t in subseg.outputs()] diff --git a/cube/codegen/register.py b/cube/codegen/register.py deleted file mode 100644 index 0f38a371..00000000 --- a/cube/codegen/register.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any, Dict, List, Union - -from cube.ir.cten import IRTensor - - -class VarManager: - """ - Tensor naming reuse engine for saving memory - """ - - def __init__(self): - # the unique id - self.nid = 0 - self.slots: List[int] = list() - # original tensor id -> new tensor id mapping - self.tmap: Dict[int, int] = dict() - - def free(self, tensor: Union[IRTensor, Any]): - """ - Free a tensor - """ - if isinstance(tensor, IRTensor): - assert tensor._id in self.tmap, f"Double free on tensor {tensor}" - reg = self.tmap[tensor._id] - del self.tmap[tensor._id] - self.slots.append(reg) - - def allocate(self, tensor: Union[IRTensor, Any]) -> str: - """ - Allocate a tensor name for the tensor. - New tensors will be allocated by available - unique ids freed by other tensor. - Existing teensor will get the allocated name. - """ - if isinstance(tensor, IRTensor): - ttype = 'g' if tensor.is_grad() else 't' - # param is graph attribute, don't need allocation - if tensor.is_param(): - return f'{tensor.name}_{tensor._id}' - if tensor._id in self.tmap: - # fetch the original one - reg = self.tmap[tensor._id] - else: - # allocate a new one - if len(self.slots) == 0: - reg = self.nid - self.nid += 1 - else: - reg = self.slots.pop(-1) - self.tmap[tensor._id] = reg - # reg = tensor._id # => enable this for debug - return f'{ttype}{reg}' - else: - return str(tensor) - - - diff --git a/cube/graph/function/anchor.py b/cube/graph/function/anchor.py index 649ec9ef..18c82803 100644 --- a/cube/graph/function/anchor.py +++ b/cube/graph/function/anchor.py @@ -17,7 +17,7 @@ def infer_shape(self): def __repr__(self) -> str: sign = self.signature.split('.')[-1] - ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_param()] + ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_attr()] dscp = (f"FwOp{self._id}(sign={sign}[{self.name}], " f"inputs={ins}, " f"outputs={self.outputs()})") diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 340e4eef..6b8c3443 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Tuple, Dict +from typing import Any, List, Optional, Tuple, Dict, Union import string import copy import torch @@ -796,6 +796,57 @@ def Embedding(signature, inputs: List): return IRDimops(signature, [anno], [itensor, weight], 'embedding', padding_idx=padding_idx, start=start, stop=stop) +def Flatten(signature, inputs: List): + tensor: IRTensor = inputs[0] + start_dim, end_dim = inputs[1:] + end_dim = len(tensor.shape) + end_dim if end_dim < 0 else end_dim + ishape = ShapeAnno.create_shape_str(tensor.shape) + for dim in range(start_dim, end_dim+1): + ishape[dim] += '^' + oshape = ishape[:start_dim] + oshape.append(ishape[start_dim:end_dim+1]) + anno = OpAnno.create_op_str([ishape], [oshape]) + return IRDimops(signature, [anno], [tensor], 'flatten', start_dim=start_dim, end_dim=end_dim) + + +def Roll(signature, inputs: Tuple[IRTensor, Union[int, Tuple[int]], Union[int, Tuple[int]]]): + tensor = inputs[0] + shifts, dims = inputs[1:] + # TODO: enable partition + ishape = ShapeAnno.create_shape_str(tensor.shape, reduction='^') + anno = OpAnno.create_op_str([ishape], [ishape]) + return IRDimops(signature, [anno], [tensor], 'roll', shifts=shifts, dims=dims) + + +def AdaptiveAvgPool1d(signature, inputs: Tuple[IRTensor, Tuple[int]]): + tensor = inputs[0] + out_size = inputs[1] + ishape = ShapeAnno.create_shape_str(tensor.shape) + ishape[-1] += '^' + oshape = ishape[:-1] + [str(size) for size in out_size] + anno = OpAnno.create_op_str([ishape], [oshape]) + return IRDimops(signature, [anno], [tensor], 'adaptive_avg_pool1d', output_size=out_size) + + +def CrossEntropy(signature, inputs): + # FIXME: reduction is by default 'mean', in this way it cannot be partitioned + # no N dimension. + tensor, target, weight = inputs[0:3] + assert weight is None, "weight not supported for cross entropy" + size_average, ignore_index, reduce, reduction, label_smoothing = inputs[3:] + annos = [ + 'C^, N -> 1', + 'N+ C, N+ -> 1', + 'N+ C *, N+ * -> 1' + ] + return IRDimops( + signature, annos, [tensor, target], 'cross_entropy', + weight=weight, size_average=size_average, ignore_index=ignore_index, + reduce=reduce, reduction=reduction, label_smoothing=label_smoothing + ) + + + def MultiRef(signature, inputs: List[IRTensor]): """ cube.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 94aef8ec..2b52a58a 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -126,8 +126,11 @@ def gen_activation(graph: IRGraph) -> IRGraph: # no consumer usually mean loss if len(ftensor.consumers) == 0: continue + # graph attribute: buffer + if len(ftensor.producers) == 0: + continue # no require for communication - if len(ftensor.consumers) == 1 and len(ftensor.producers) == 0 and \ + if len(ftensor.consumers) == 1 and len(ftensor.producers) == 1 and \ ftensor.consumers[0].device == ftensor.producers[0].device: continue diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 95462681..48d8426c 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -86,8 +86,8 @@ def dispatch(self, devid: int, for_mirror=True) -> Optional[IRCell]: def to_str(self, skip_attr: bool = False) -> str: name = ('f' if self.forward else 'b') + 'Segment' - inputs = tuple(t for t in self.inputs() if not (t.is_param() and skip_attr)) - outputs = tuple(t for t in self.outputs() if not (t.is_param() and skip_attr)) + inputs = tuple(t for t in self.inputs() if not (t.is_attr() and skip_attr)) + outputs = tuple(t for t in self.outputs() if not (t.is_attr() and skip_attr)) return f'{name}{self._id}-{self.device}(inputs={inputs}, outputs={outputs})' def __repr__(self): @@ -113,7 +113,7 @@ def __init__(self, module_name: str): self._nodes: List[IRCell] = list() - self._parameters = list() + self._attributes = list() self._full_tensors: Dict[int, IRFullTensor] = dict() self._train: bool = any( isinstance(node, IRBpOperation) or @@ -140,14 +140,14 @@ def __init__(self, for idx, tensor in enumerate(outputs): self.set_output(idx, tensor) - # set parameters and full tensors + # set parameters / buffers and full tensors for node in nodes: for tensor in node.inputs() + node.outputs(): if isinstance(tensor, IRSubTensor): pid = tensor.parent._id self._full_tensors[pid] = tensor.parent - if tensor.is_param(): - self._parameters.append(input) + if tensor.is_attr(): + self._attributes.append(tensor) for ftensor in self._full_tensors.values(): ftensor.clear_producer_consumer() @@ -190,13 +190,13 @@ def reset_dependency(self): producer.add_successor(-1, producer.mirror) producer.mirror.add_predecessor(-1, producer) - def parameters(self): + def attributes(self) -> Tuple[IRSubTensor]: """ Return parameter list """ - return copy.copy(self._parameters) + return tuple(self._attributes) - def full_tensors(self): + def full_tensors(self) -> List[IRSubTensor]: """ Return full tensor list """ @@ -370,8 +370,8 @@ def attach(self, node: IRCell, index, reset_dependency=False): if isinstance(otensor, IRSubTensor) and otensor not in otensors: otensors.append(otensor) for otensor in otensors: - if otensor.parent._id not in self._full_tensors: - self._full_tensors[itensor.parent.tid] = itensor.parent + if otensor.parent.tid not in self._full_tensors: + self._full_tensors[otensor.parent.tid] = otensor.parent idx = 0 for producer in otensor.parent.producers: if self.nodes().index(producer) < index: @@ -473,7 +473,7 @@ def from_logic_graph(nodes: List[IRCell], if isinstance(ftensor, IRFullTensor): producers[ftensor] = node for ftensor, cnodes in consumers.items(): - if len(cnodes) == 1 or ftensor.is_param(): continue + if len(cnodes) == 1 or ftensor.is_attr(): continue itensors = [ftensor.like() for _ in range(len(cnodes))] for itensor, consumer in zip(itensors, cnodes): while ftensor in consumer.inputs(): diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index b78089c8..0dac58ff 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -76,7 +76,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('embedding'): function.Embedding, - # __ftemplate('layer_norm'): function.LayerNorm, + __ftemplate('cross_entropy'): function.CrossEntropy, # torch aten @@ -142,6 +142,12 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('stack'): function.Stack, + __ttemplate('flatten'): function.Flatten, + + __ttemplate('roll'): function.Roll, + + __ttemplate('adaptive_avg_pool1d'): function.AdaptiveAvgPool1d, + # runtime functions __rtemplate('anchor'): function.GraphAnchor, diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index f2aa1f8b..5e605154 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -9,8 +9,6 @@ from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import Sign2Op, DType2IRDType -import warnings - _refmodule = torch.nn.Module() @@ -399,6 +397,11 @@ def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: var_name = node.outputsAt(0).debugName() dtype = node.outputsAt(0).type().str() + if dtype == 'Tensor?': + tensor = getattr(module, label) + if torch.is_tensor(tensor): + dtype = 'Tensor' + # this usually means weight (nn.Parameter in torch) if dtype == 'Tensor': tensor = getattr(module, label) @@ -414,8 +417,7 @@ def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: if isinstance(tensor, torch.nn.Parameter): ir_tensor.as_param() else: - warnings.warn('Detected non-parameter tensor as graph attribute. Regard them as parameters') - ir_tensor.as_param() + ir_tensor.as_buffer() frame.add_attr(label, ir_tensor) frame.add_attr_content(ir_tensor.tid, tensor) frame.add_var(var_name, ir_tensor) diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 43c89188..4e51af77 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -6,7 +6,7 @@ import inspect import torch -from cube.graph.function.dimops import IRDimops +from cube.graph.function.dimops import IRDimops, OpAnno from cube.graph.parser.mapping import Sign2Op @@ -27,6 +27,9 @@ def register(anno: str, name: Optional[str] = None): @cube.register('a (b c) -> (a b) c') def funcname(x: torch.Tensor, b: int = 4): xxx + + Note: for Optional[torch.Tensor] type, user should annotate the + dimension when the input is not None. """ def decorator(fn: Callable): if not callable(fn): @@ -35,21 +38,32 @@ def decorator(fn: Callable): op_name = name if name is not None else fsig args = inspect.signature(fn) arg_names = list(args.parameters.keys()) - arg_kind = [args.parameters[name].annotation for name in arg_names] - kwarg_names = [name for (name, kind) in zip(arg_names, arg_kind) if kind != torch.Tensor] - nkwargs = len(kwarg_names) - ninputs = len(arg_names) - len(kwarg_names) + arg_kinds = [args.parameters[name].annotation for name in arg_names] + allow_types = (torch.Tensor, Optional[torch.Tensor]) + for ninputs, kind in enumerate(arg_kinds): + if kind in allow_types: + ninputs += 1 + continue + assert not any(k in allow_types for k in arg_kinds[ninputs:]), \ + f"Type of {allow_types} should be consecutive in parameter order." + break + nkwargs = len(arg_names) - ninputs + kwarg_names = [name for name in arg_names[ninputs:]] # get customized op code code = inspect.getsource(fn) code = code[code.index('def'):] def udfop(signature: str, inputs: List[Any]): + manno = OpAnno(anno) tensors = inputs[:ninputs] + for idx in range(ninputs): + if arg_kinds[idx] == Optional[torch.Tensor] and tensors[idx] is None: + manno.set_input(idx, '?') kwarg_vals = inputs[ninputs:] kwargs = dict() for name, val in zip(kwarg_names, kwarg_vals): kwargs[name] = val - return IRDimops(signature, [anno], tensors, **kwargs, name=op_name) + return IRDimops(signature, [repr(manno)], tensors, **kwargs, name=op_name) print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') Sign2Op.register(fsig, udfop, code) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 125ed853..aa938450 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -434,7 +434,7 @@ class IRTensor: and will be translated to None in code generation. """ - _meta = ['name', '_is_param', '_is_grad', '_requires_grad', '_dtype'] + _meta = ['name', '_is_attr', '_is_grad', '_requires_grad', '_dtype'] def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): @@ -446,7 +446,7 @@ def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): self._cell: Optional[IRCell] = None self._dtype: IRDType = dtype - self._is_param: bool = False + self._is_attr: bool = False self._is_grad: bool = False # tensor gradient @@ -500,30 +500,66 @@ def device(self, val: Union[int, List[int]]): "tensor placement is not allowed to set manually" ) + def is_attr(self) -> bool: + """! + Check if the tensor is graph attribute. + + @return is_attr boolean: True if is graph attribute (buffer or parameter) + """ + return self._is_attr + + def is_param(self) -> bool: + """! + Check if the tensor is parameter + + @return is_param boolean: True if is parameter. + """ + return self._is_attr and self._requires_grad + + def is_buffer(self) -> bool: + """! + Check if the tensor is buffer. + + @return is_buffer boolean: True if is buffer. + """ + return self._is_attr and not self._requires_grad + + def is_grad(self) -> bool: + """! + Check if the tensor is gradient + + @return is_grad boolean: True if is gradient + """ + return self._is_grad + def as_param(self): """ Set the tensor as trainable parameter """ assert self._grad is not None, "missing grad tensor" self._requires_grad = True + self._is_attr = True self._is_grad = False - self._is_param = True return self - def is_param(self): + def as_buffer(self): """ - Check if the tensor is parameter + Set the tensor as un-trainable buffer """ - return self._is_param + self._requires_grad = False + self._is_attr = True + self._is_grad = False + return self def as_grad(self): + """ + Set the tensor as gradient + """ self._is_param = False + self._is_attr = False self._is_grad = True return self - def is_grad(self): - return self._is_grad - @property def requires_grad(self) -> bool: return self._requires_grad diff --git a/cube/ir/operator.py b/cube/ir/operator.py index d19ab835..8c7beb57 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -154,7 +154,7 @@ def gen_backward(self) -> IRCell: def __repr__(self) -> str: sign = self.signature.split('.')[-1] - ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_param()] + ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_attr()] dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " f"inputs={ins}, " f"outputs={self.outputs()})") diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index a21311ee..885ec143 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -403,12 +403,20 @@ def as_param(self): Set the tensor as trainable parameter """ self.requires_grad = True - self._is_param = True + self._is_attr = True + self._is_grad = False + + def as_buffer(self): + """ + Set the tensor as un-trainable buffer + """ + self.requires_grad = False + self._is_attr = True self._is_grad = False def as_grad(self): self.requires_grad = False - self._is_param = False + self._is_attr = False self._is_grad = True return self @@ -825,7 +833,7 @@ def common(self, other) -> Optional[IRTensor]: def __repr__(self) -> str: anno = 't' - if self.is_param(): + if self.is_attr(): anno = 'w' if self.is_grad(): anno = 'g' @@ -835,7 +843,7 @@ def __repr__(self) -> str: def extra_repr(self) -> str: anno = 't' - if self.is_param(): + if self.is_attr(): anno = 'w' if self.is_grad(): anno = 'g' From de6480e87341f1600a65060c950c4d97796dbed0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 Aug 2022 14:57:39 +0800 Subject: [PATCH 0958/1892] swin single device parse --- examples/vision/swin/baseline.py | 781 +++++++++++++++++++++ examples/vision/swin/blocks/__init__.py | 0 examples/vision/swin/blocks/attention.py | 145 ++++ examples/vision/swin/blocks/mlp.py | 32 + examples/vision/swin/blocks/patch.py | 102 +++ examples/vision/swin/blocks/transformer.py | 187 +++++ examples/vision/swin/blocks/utils.py | 14 + examples/vision/swin/model.py | 572 ++------------- examples/vision/swin/policy/spmd.py | 14 + examples/vision/swin/train.py | 82 +-- 10 files changed, 1371 insertions(+), 558 deletions(-) create mode 100644 examples/vision/swin/baseline.py create mode 100644 examples/vision/swin/blocks/__init__.py create mode 100644 examples/vision/swin/blocks/attention.py create mode 100644 examples/vision/swin/blocks/mlp.py create mode 100644 examples/vision/swin/blocks/patch.py create mode 100644 examples/vision/swin/blocks/transformer.py create mode 100644 examples/vision/swin/blocks/utils.py create mode 100644 examples/vision/swin/policy/spmd.py diff --git a/examples/vision/swin/baseline.py b/examples/vision/swin/baseline.py new file mode 100644 index 00000000..661ee224 --- /dev/null +++ b/examples/vision/swin/baseline.py @@ -0,0 +1,781 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + --nnodes=1 \ + examples/vision/swin/baseline.py +""" + + +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from cube.profiler.memory import memory_summary, model_summary +from cube.profiler.timer import CudaTimer, print_each_rank + +import cube + + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +class DropPath(torch.nn.Module): + + def __init__(self, drop_prob: float): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + if self.drop_prob == 0. or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.criterion = nn.CrossEntropyLoss() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def forward(self, x, labels: torch.Tensor): + x = self.forward_features(x) + x = self.head(x) + loss = self.criterion(x, labels) + return loss + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int, img_size: int, num_classes: int): + + self.bs = batch_size + self.img_size = img_size + self.num_classes = num_classes + super().__init__( + shapes=([batch_size, 3, img_size, img_size,], + [batch_size], + ), + dtypes=(torch.float, torch.int), + batch_dims=(0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + img = torch.rand( + *(self.bs, 3, self.img_size, self.img_size), + dtype=torch.float, + device=torch.cuda.current_device() + ) + labels = torch.randint( + 0, self.num_classes, + size=(self.bs,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + return (img, labels) + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + +class Config: + + # swin-large 201M + embed_dim = 192 + depths = [2, 2, 18, 2] + num_heads = [6, 12, 24, 48] + + # swin-huge: 2.5B + # embed_dim = 512 + # depths = [2, 2, 42, 2] + # num_heads = [16, 32, 64, 128] + + mlp_ratio = 4 + qkv_bias = True + qk_scale = None + + drop_path_rate = 0.2 + drop_rate = 0.2 + + + # 224 x 224 + img_size = 224 + window_size = 7 + + # 640 x 640 + img_size = 640 + window_size = 40 + + # 1536 x 1536 + # img_size = 1536 + # window_size = 48 + + num_classes = 1000 + + +def train(): + + batch_size = 1 + + cfg = Config() + model = SwinTransformer(img_size=cfg.img_size, + patch_size=4, + in_chans=3, + num_classes=cfg.num_classes, + embed_dim=cfg.embed_dim, + depths=cfg.depths, + num_heads=cfg.num_heads, + window_size=cfg.window_size, + mlp_ratio=cfg.mlp_ratio, + qkv_bias=cfg.qkv_bias, + qk_scale=cfg.qk_scale, + drop_rate=cfg.drop_rate, + drop_path_rate=cfg.drop_path_rate, + ape=False, + patch_norm=True, + use_checkpoint=False) + + model = model.cuda() + dataloader = ImageDataLoader(batch_size, cfg.img_size, cfg.num_classes) + optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) + + print_each_rank('model weight consumpition:') + memory_summary() + + def train_iter(model, dataloader): + imgs, labels = next(dataloader) + loss = model(imgs, labels) + loss.backward() + + CudaTimer(enable=False).warmup() + iter_num = 10 + for step in range(iter_num): + + if step == 0: + model_summary(model, next(dataloader)) + + if step >= 4: + CudaTimer(enable=True).start('e2e') + + # training + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + if step >= 4: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('passed first iteration') + + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-4, field_name='e2e'))) + memory_summary() + +if __name__ == '__main__': + + cube.init() + train() diff --git a/examples/vision/swin/blocks/__init__.py b/examples/vision/swin/blocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/vision/swin/blocks/attention.py b/examples/vision/swin/blocks/attention.py new file mode 100644 index 00000000..2948daa6 --- /dev/null +++ b/examples/vision/swin/blocks/attention.py @@ -0,0 +1,145 @@ +from typing import Optional +import torch +import cube + +from examples.vision.swin.blocks.utils import trunc_normal_ + + +# REMARK: as default attention has qkv project weight of (3 head dim_head) C, +# this cannot partition on head dimension +# as the head dimension is a secondary hidden dimension in (3 head dim_head). +# To make partition work (correctness guarantee), the dimension is swapped as (head dim_head 3) +@cube.graph.parser.register('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), a^ b^, c^ d^, C^ (h+ dh^), B N^ N^ -> B N^ C^') +def window_attn(x: torch.Tensor, + qkv_w: torch.Tensor, qkv_bias: torch.Tensor, + relative_position_index: torch.Tensor, + relative_position_bias_table: torch.Tensor, + dense_w: torch.Tensor, + mask: Optional[torch.Tensor], + attn_drop: float, + h: int, scale: float, wh: int, ww: int,): + """ + @param h int: number of head + @param wh int: window size of height + @param ww int: window size of width + """ + B_, N, C = x.shape + # B N (h+ dh 3) + qkv = torch.nn.functional.linear(x, qkv_w, qkv_bias) + # 3 B h N C//h + qkv = qkv.reshape(B_, N, h, C // h, 3).permute(4, 0, 2, 1, 3) + # 3 B h N C//h + # qkv = qkv.reshape(B_, N, 3, h, C // h).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * scale + attn = (q @ k.transpose(-2, -1)) + + # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias_table[ + relative_position_index.view(-1) + ].view(wh * ww, wh * ww, -1) + # nH, Wh*Ww, Wh*Ww + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, h, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, h, N, N) + attn = torch.nn.functional.softmax(attn, dim=-1) + else: + attn = torch.nn.functional.softmax(attn, dim=-1) + + attn = torch.nn.functional.dropout(attn, attn_drop, True, False) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = torch.nn.functional.linear(x, dense_w) + return x + + +class WindowAttention(torch.nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.wh = window_size[0] + self.ww = window_size[1] + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = torch.nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # wh * ww, wh * ww + self.register_buffer('relative_position_index', relative_position_index) + + self.attn_drop = attn_drop + self.proj_drop = proj_drop + # qkv + self.qkv_w = torch.nn.Parameter(torch.empty(dim * 3, dim)) + self.qkv_b = torch.nn.Parameter(torch.empty(dim * 3)) + + # out + self.out_w = torch.nn.Parameter(torch.empty(dim, dim)) + self.out_b = torch.nn.Parameter(torch.empty(dim)) + + trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + x = window_attn( + x, self.qkv_w, self.qkv_b, + self.relative_position_index, + self.relative_position_bias_table, + self.out_w, mask, + self.attn_drop, self.num_heads, + self.scale, self.wh, self.ww + ) + x = x + self.out_b + x = torch.nn.functional.dropout(x, self.proj_drop, True, False) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops \ No newline at end of file diff --git a/examples/vision/swin/blocks/mlp.py b/examples/vision/swin/blocks/mlp.py new file mode 100644 index 00000000..4cbbdb2c --- /dev/null +++ b/examples/vision/swin/blocks/mlp.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn +import cube + + +@cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') +def feedforward(x: torch.Tensor, + proj1: torch.Tensor, proj1_bias: torch.Tensor, + proj2: torch.Tensor, dropout: float) -> torch.Tensor: + x = torch.nn.functional.linear(x, proj1, proj1_bias) + x = torch.nn.functional.gelu(x) + x = torch.nn.functional.dropout(x, dropout, True, False) + x = torch.nn.functional.linear(x, proj2, None) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1_w = torch.nn.Parameter(torch.empty(hidden_features, in_features)) + self.fc1_b = torch.nn.Parameter(torch.empty(hidden_features)) + self.fc2_w = torch.nn.Parameter(torch.empty(out_features, hidden_features)) + self.fc2_b = torch.nn.Parameter(torch.empty(out_features)) + self.drop = drop + + def forward(self, x): + x = feedforward(x, self.fc1_w, self.fc1_b, self.fc2_w, self.drop) + x = x + self.fc2_b + x = torch.nn.functional.dropout(x, self.drop, True, False) + return x diff --git a/examples/vision/swin/blocks/patch.py b/examples/vision/swin/blocks/patch.py new file mode 100644 index 00000000..a8fa3a1c --- /dev/null +++ b/examples/vision/swin/blocks/patch.py @@ -0,0 +1,102 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +import cube + + +@cube.graph.parser.register('B (2 h^ 2 w^) C^ -> B (h w) (4 C)') +def patch_merge(x: torch.Tensor, h: int, w: int): + B, L, C = x.shape + H = 2 * h + W = 2 * w + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + x = x.view(B, H, W, C) + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution: Tuple[int, int], dim: int, norm_layer=nn.LayerNorm): + super().__init__() + self.H = input_resolution[0] + self.W = input_resolution[1] + # self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B (H W) C + """ + x = patch_merge(x, self.H // 2, self.W // 2) + x = self.norm(x) + x = self.reduction(x) + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/examples/vision/swin/blocks/transformer.py b/examples/vision/swin/blocks/transformer.py new file mode 100644 index 00000000..7bda0406 --- /dev/null +++ b/examples/vision/swin/blocks/transformer.py @@ -0,0 +1,187 @@ +from typing import Tuple +import torch +import torch.nn as nn + +from examples.vision.swin.blocks.attention import WindowAttention +from examples.vision.swin.blocks.mlp import Mlp + +import cube + + +@cube.graph.parser.register('* -> *') +def drop_path(x: torch.Tensor, drop_prob: float, training: bool): + if drop_prob <= 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +@cube.graph.parser.register('B (nh ws) (nw ws) C -> (B nh nw) ws ws C') +def window_partition(x: torch.Tensor, ws: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + window_size = ws + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +@cube.graph.parser.register('(B nh nw) ws ws C -> B (nh ws) (nw ws) C') +def window_reverse(windows: torch.Tensor, ws: int, nh: int, nw: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = ws + B = int(windows.shape[0] / (nh * nw)) + x = windows.view(B, nh, nw, window_size, window_size, -1) + H = nh * ws + W = nw * ws + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim: int, input_resolution: Tuple[int, int], num_heads: int, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.H = input_resolution[0] + self.W = input_resolution[1] + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = drop_path + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + # self.attn_mask = attn_mask + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H = self.H + W = self.W + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H // self.window_size, W // self.window_size) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + drop_path(x, self.drop_path, self.training) + x = x + drop_path(self.mlp(self.norm2(x)), self.drop_path, self.training) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops diff --git a/examples/vision/swin/blocks/utils.py b/examples/vision/swin/blocks/utils.py new file mode 100644 index 00000000..acf35e91 --- /dev/null +++ b/examples/vision/swin/blocks/utils.py @@ -0,0 +1,14 @@ +import torch +import math + + +def trunc_normal_(tensor: torch.Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.): + with torch.no_grad(): + l = (1. + math.erf((a - mean) / std / math.sqrt(2.))) / 2. + u = (1. + math.erf((b - mean) / std / math.sqrt(2.))) / 2. + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor \ No newline at end of file diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index ee6edb3f..3703a7b1 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -1,394 +1,47 @@ -# -------------------------------------------------------- -# Swin Transformer -# Copyright (c) 2021 Microsoft -# Licensed under The MIT License [see LICENSE for details] -# Written by Ze Liu - -# The file is merged with source code from timm -# -------------------------------------------------------- -import math -import warnings import torch import torch.nn as nn -import torch.utils.checkpoint as checkpoint - -import cube - - - -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2) - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -class DropPath(torch.nn.Module): - - def __init__(self, drop_prob: float): - super().__init__() - self.drop_prob = drop_prob - - def forward(self, x): - if self.drop_prob == 0. or not self.training: - return x - keep_prob = 1 - self.drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C +from examples.vision.swin.blocks.utils import trunc_normal_ +from examples.vision.swin.blocks.transformer import SwinTransformerBlock +from examples.vision.swin.blocks.patch import PatchEmbed, PatchMerging - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) +import cube - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x +class Config: - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + # swin-large 201M + embed_dim = 192 + depths = [2, 2, 18, 2] + num_heads = [6, 12, 24, 48] - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops + # swin-huge: 2.5B + # embed_dim = 512 + # depths = [2, 2, 42, 2] + # num_heads = [16, 32, 64, 128] + mlp_ratio = 4 -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) + drop_path_rate = 0.2 + drop_rate = 0.2 + attn_drop_rate = 0.0 + + # dataloader - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) + # 224 x 224 + img_size = 224 + window_size = 7 - return x + # 640 x 640 + img_size = 640 + window_size = 40 - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" + # 1536 x 1536 + # img_size = 1536 + # window_size = 48 - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops + num_classes = 1000 class BasicLayer(nn.Module): @@ -412,13 +65,12 @@ class BasicLayer(nn.Module): def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + drop_path=0., norm_layer=nn.LayerNorm, downsample=None): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth - self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ @@ -440,10 +92,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, def forward(self, x): for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) + x = blk(x) if self.downsample is not None: x = self.downsample(x) return x @@ -460,133 +109,56 @@ def flops(self): return flops -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ +class SwinTransformer(nn.Module): - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + def __init__(self): super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x + cfg = Config() - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - """ + self.num_classes = cfg.num_classes + self.num_layers = len(cfg.depths) + self.embed_dim = cfg.embed_dim + self.num_features = int(cfg.embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = 4. - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, **kwargs): - super().__init__() - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio + self.patch_size = 4 # split image into non-overlapping patches self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches + img_size=cfg.img_size, + patch_size=self.patch_size, + in_chans=3, embed_dim=cfg.embed_dim, + norm_layer=nn.LayerNorm + ) patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=cfg.drop_rate) # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, sum(cfg.depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + layer = BasicLayer(dim=int(cfg.embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint) + depth=cfg.depths[i_layer], + num_heads=cfg.num_heads[i_layer], + window_size=cfg.window_size, + mlp_ratio=cfg.mlp_ratio, + qkv_bias=True, qk_scale=None, + drop=cfg.drop_rate, attn_drop=cfg.attn_drop_rate, + drop_path=dpr[sum(cfg.depths[:i_layer]):sum(cfg.depths[:i_layer + 1])], + norm_layer=nn.LayerNorm, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) self.layers.append(layer) - self.norm = norm_layer(self.num_features) + self.norm = nn.LayerNorm(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, cfg.num_classes) if cfg.num_classes > 0 else nn.Identity() self.criterion = nn.CrossEntropyLoss() self.apply(self._init_weights) @@ -623,7 +195,16 @@ def forward_features(self, x): return x def forward(self, x, labels: torch.Tensor): - x = self.forward_features(x) + x = self.patch_embed(x) + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + x = self.head(x) loss = self.criterion(x, labels) return loss @@ -638,9 +219,12 @@ def flops(self): return flops +# =========================== Data Loader ======================= + + class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): - def __init__(self, batch_size: int, img_size: int, num_classes: int): + def __init__(self, batch_size: int, img_size: int, num_classes: int, dtype=torch.float32): self.bs = batch_size self.img_size = img_size @@ -649,15 +233,15 @@ def __init__(self, batch_size: int, img_size: int, num_classes: int): shapes=([batch_size, 3, img_size, img_size,], [batch_size], ), - dtypes=(torch.float, torch.int), + dtypes=(dtype, torch.int), batch_dims=(0, 0) ) - self.samples = [self.random_sample()] + self.samples = [self.random_sample(dtype)] - def random_sample(self): + def random_sample(self, dtype: torch.dtype): img = torch.rand( *(self.bs, 3, self.img_size, self.img_size), - dtype=torch.float, + dtype=dtype, device=torch.cuda.current_device() ) labels = torch.randint( @@ -672,4 +256,4 @@ def __iter__(self): return self def __next__(self): - return self.samples[0] \ No newline at end of file + return self.samples[0] diff --git a/examples/vision/swin/policy/spmd.py b/examples/vision/swin/policy/spmd.py new file mode 100644 index 00000000..fe8228a4 --- /dev/null +++ b/examples/vision/swin/policy/spmd.py @@ -0,0 +1,14 @@ +from typing import List + +from cube.graph import IRGraph +from cube.graph.function.anchor import IRGraphAnchor +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation + + +def PASSingle(graph: IRGraph, resource): + assert resource.ngpus == 1 + # print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + return graph \ No newline at end of file diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index ae6e350e..634acfcf 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -8,46 +8,15 @@ """ import torch -from examples.vision.swin.model import SwinTransformer, ImageDataLoader +from examples.vision.swin.model import Config, SwinTransformer, ImageDataLoader import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary, model_summary +import examples.vision.swin.policy.spmd as spmd -class Config: - - # swin-large 201M - embed_dim = 192 - depths = [2, 2, 18, 2] - num_heads = [6, 12, 24, 48] - - # swin-huge: 2.5B - # embed_dim = 512 - # depths = [2, 2, 42, 2] - # num_heads = [16, 32, 64, 128] - - mlp_ratio = 4 - qkv_bias = True - qk_scale = None - - drop_path_rate = 0.2 - drop_rate = 0.2 - - - # 224 x 224 - img_size = 224 - window_size = 7 - - # 640 x 640 - img_size = 640 - window_size = 40 - - # 1536 x 1536 - # img_size = 1536 - # window_size = 48 - - num_classes = 1000 +PAS = spmd.PASSingle def train(): @@ -55,43 +24,27 @@ def train(): batch_size = 1 cfg = Config() - model = SwinTransformer(img_size=cfg.img_size, - patch_size=4, - in_chans=3, - num_classes=cfg.num_classes, - embed_dim=cfg.embed_dim, - depths=cfg.depths, - num_heads=cfg.num_heads, - window_size=cfg.window_size, - mlp_ratio=cfg.mlp_ratio, - qkv_bias=cfg.qkv_bias, - qk_scale=cfg.qk_scale, - drop_rate=cfg.drop_rate, - drop_path_rate=cfg.drop_path_rate, - ape=False, - patch_norm=True, - use_checkpoint=False) - - model = model.cuda() + model = SwinTransformer() dataloader = ImageDataLoader(batch_size, cfg.img_size, cfg.num_classes) optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) - print_each_rank('model weight consumpition:') - memory_summary() - + model = cube.SemanticModel(model, dataloader.shapes) + @cube.compile(model, dataloader, PAS=PAS, override=True) def train_iter(model, dataloader): imgs, labels = next(dataloader) loss = model(imgs, labels) loss.backward() + model = model.get_gen_module() + + torch.distributed.barrier() + print_each_rank('model weight consumpition:') + memory_summary() CudaTimer(enable=False).warmup() - iter_num = 10 + iter_num, warmup = 10, 2 for step in range(iter_num): - if step == 0: - model_summary(model, next(dataloader)) - - if step >= 4: + if step >= warmup: CudaTimer(enable=True).start('e2e') # training @@ -99,19 +52,20 @@ def train_iter(model, dataloader): optimizer.step() optimizer.zero_grad() - if step >= 4: + if step >= warmup: CudaTimer().stop('e2e') if step == 0: print_each_rank('passed first iteration') - - if (step + 1) % 2 == 0: + if (step + 1) % 4 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-4, field_name='e2e'))) + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) memory_summary() + if __name__ == '__main__': cube.init() From ccc5785ed87d954db4a29829aca1fc604335c870 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 Aug 2022 15:47:22 +0800 Subject: [PATCH 0959/1892] swin transfomer correctness check passed --- examples/vision/swin/baseline.py | 36 ++++++++++--------------- examples/vision/swin/model.py | 46 +++++++++----------------------- examples/vision/swin/train.py | 14 +++++++--- 3 files changed, 38 insertions(+), 58 deletions(-) diff --git a/examples/vision/swin/baseline.py b/examples/vision/swin/baseline.py index 661ee224..557fec3e 100644 --- a/examples/vision/swin/baseline.py +++ b/examples/vision/swin/baseline.py @@ -19,7 +19,6 @@ import cube - def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -151,7 +150,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 @@ -592,24 +591,12 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.criterion = nn.CrossEntropyLoss() - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} + torch.manual_seed(0) + for param in self.parameters(): + if len(param.size()) > 1: + trunc_normal_(param, std=.02) + else: + nn.init.constant_(param, 0) def forward_features(self, x): x = self.patch_embed(x) @@ -658,6 +645,7 @@ def __init__(self, batch_size: int, img_size: int, num_classes: int): self.samples = [self.random_sample()] def random_sample(self): + torch.manual_seed(0) img = torch.rand( *(self.bs, 3, self.img_size, self.img_size), dtype=torch.float, @@ -742,17 +730,21 @@ def train(): print_each_rank('model weight consumpition:') memory_summary() + print(list(model.parameters())[0].shape) + print(list(model.parameters())[0]) + def train_iter(model, dataloader): imgs, labels = next(dataloader) loss = model(imgs, labels) loss.backward() + print(loss) CudaTimer(enable=False).warmup() iter_num = 10 for step in range(iter_num): - if step == 0: - model_summary(model, next(dataloader)) + # if step == 0: + # model_summary(model, next(dataloader)) if step >= 4: CudaTimer(enable=True).start('e2e') diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index 3703a7b1..00fdbfc5 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -1,4 +1,3 @@ - import torch import torch.nn as nn @@ -161,38 +160,18 @@ def __init__(self): self.head = nn.Linear(self.num_features, cfg.num_classes) if cfg.num_classes > 0 else nn.Identity() self.criterion = nn.CrossEntropyLoss() - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def forward_features(self, x): - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - return x + torch.manual_seed(0) + for param in self.parameters(): + if len(param.size()) > 1: + trunc_normal_(param, std=.02) + else: + nn.init.constant_(param, 0) + # this is to match for the correctness with baseline + for basic_layer in self.layers: + for block in basic_layer.blocks: + with torch.no_grad(): + w: torch.Tensor = block.attn.qkv_w.view(3, -1, block.attn.qkv_w.size(-1)) + block.attn.qkv_w.copy_(w.permute(1,0,2).reshape(-1, w.size(-1))) def forward(self, x, labels: torch.Tensor): x = self.patch_embed(x) @@ -239,6 +218,7 @@ def __init__(self, batch_size: int, img_size: int, num_classes: int, dtype=torch self.samples = [self.random_sample(dtype)] def random_sample(self, dtype: torch.dtype): + torch.manual_seed(0) img = torch.rand( *(self.bs, 3, self.img_size, self.img_size), dtype=dtype, diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 634acfcf..d82335fc 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -18,6 +18,7 @@ PAS = spmd.PASSingle +torch.random.manual_seed(0) def train(): @@ -26,7 +27,6 @@ def train(): cfg = Config() model = SwinTransformer() dataloader = ImageDataLoader(batch_size, cfg.img_size, cfg.num_classes) - optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) model = cube.SemanticModel(model, dataloader.shapes) @cube.compile(model, dataloader, PAS=PAS, override=True) @@ -34,11 +34,18 @@ def train_iter(model, dataloader): imgs, labels = next(dataloader) loss = model(imgs, labels) loss.backward() - model = model.get_gen_module() + return loss + model: torch.nn.Module = model.get_gen_module() + + optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) torch.distributed.barrier() print_each_rank('model weight consumpition:') memory_summary() + nparams = 0 + for param in model.parameters(): + nparams += param.nelement() + print_each_rank(f'model parameter: {nparams}') CudaTimer(enable=False).warmup() iter_num, warmup = 10, 2 @@ -48,7 +55,8 @@ def train_iter(model, dataloader): CudaTimer(enable=True).start('e2e') # training - train_iter(model, dataloader) + loss = train_iter(model, dataloader) + print(loss) optimizer.step() optimizer.zero_grad() From 7f4f29b55249b0d9f3f3d31c3f89a7ec7685917c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 Aug 2022 17:08:29 +0800 Subject: [PATCH 0960/1892] fix swin window attention parallel bug --- examples/vision/swin/blocks/attention.py | 39 +++++++------- examples/vision/swin/model.py | 27 +++++++--- examples/vision/swin/policy/mpmd.py | 0 examples/vision/swin/policy/spmd.py | 69 +++++++++++++++++++++++- examples/vision/swin/train.py | 38 ++++++++++--- 5 files changed, 138 insertions(+), 35 deletions(-) create mode 100644 examples/vision/swin/policy/mpmd.py diff --git a/examples/vision/swin/blocks/attention.py b/examples/vision/swin/blocks/attention.py index 2948daa6..c2d48ccc 100644 --- a/examples/vision/swin/blocks/attention.py +++ b/examples/vision/swin/blocks/attention.py @@ -2,14 +2,12 @@ import torch import cube -from examples.vision.swin.blocks.utils import trunc_normal_ - # REMARK: as default attention has qkv project weight of (3 head dim_head) C, # this cannot partition on head dimension # as the head dimension is a secondary hidden dimension in (3 head dim_head). # To make partition work (correctness guarantee), the dimension is swapped as (head dim_head 3) -@cube.graph.parser.register('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), a^ b^, c^ d^, C^ (h+ dh^), B N^ N^ -> B N^ C^') +@cube.graph.parser.register('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), B N^ N^ -> B N^ C^') def window_attn(x: torch.Tensor, qkv_w: torch.Tensor, qkv_bias: torch.Tensor, relative_position_index: torch.Tensor, @@ -24,24 +22,27 @@ def window_attn(x: torch.Tensor, @param ww int: window size of width """ B_, N, C = x.shape + dh = qkv_w.size(0) // 3 // h # B N (h+ dh 3) qkv = torch.nn.functional.linear(x, qkv_w, qkv_bias) - # 3 B h N C//h - qkv = qkv.reshape(B_, N, h, C // h, 3).permute(4, 0, 2, 1, 3) - # 3 B h N C//h - # qkv = qkv.reshape(B_, N, 3, h, C // h).permute(2, 0, 3, 1, 4) + # 3 B h N dh + qkv = qkv.reshape(B_, N, h, dh, 3).permute(4, 0, 2, 1, 3) + # 3 B h N dh + # qkv = qkv.reshape(B_, N, 3, h, dh).permute(2, 0, 3, 1, 4) + # B h N dh q, k, v = qkv[0], qkv[1], qkv[2] + # B h N dh q = q * scale + # B h N dh @ B h dh N -> B h N N attn = (q @ k.transpose(-2, -1)) - - # Wh*Ww,Wh*Ww,nH + # (wh ww) (wh * ww) h relative_position_bias = relative_position_bias_table[ relative_position_index.view(-1) ].view(wh * ww, wh * ww, -1) - # nH, Wh*Ww, Wh*Ww + # h (wh ww) (wh ww) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + # attn: B h N N attn = attn + relative_position_bias.unsqueeze(0) - if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, h, N, N) + mask.unsqueeze(1).unsqueeze(0) @@ -49,10 +50,10 @@ def window_attn(x: torch.Tensor, attn = torch.nn.functional.softmax(attn, dim=-1) else: attn = torch.nn.functional.softmax(attn, dim=-1) - + # attn: B h N N attn = torch.nn.functional.dropout(attn, attn_drop, True, False) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # B h N N @ B h N dh -> B h N dh -> B N h dh -> B N h * dh + x = (attn @ v).transpose(1, 2).reshape(B_, N, h * dh) x = torch.nn.functional.linear(x, dense_w) return x @@ -82,7 +83,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias - self.relative_position_bias_table = torch.nn.Parameter( + self.rp_bias_table = torch.nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window @@ -96,7 +97,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # wh * ww, wh * ww - self.register_buffer('relative_position_index', relative_position_index) + self.register_buffer('rp_index', relative_position_index) self.attn_drop = attn_drop self.proj_drop = proj_drop @@ -108,8 +109,6 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.out_w = torch.nn.Parameter(torch.empty(dim, dim)) self.out_b = torch.nn.Parameter(torch.empty(dim)) - trunc_normal_(self.relative_position_bias_table, std=.02) - def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]): """ Args: @@ -118,8 +117,8 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]): """ x = window_attn( x, self.qkv_w, self.qkv_b, - self.relative_position_index, - self.relative_position_bias_table, + self.rp_index, + self.rp_bias_table, self.out_w, mask, self.attn_drop, self.num_heads, self.scale, self.wh, self.ww diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index 00fdbfc5..7521f448 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -11,17 +11,31 @@ class Config: # swin-large 201M - embed_dim = 192 - depths = [2, 2, 18, 2] - num_heads = [6, 12, 24, 48] + # embed_dim = 192 + # depths = [2, 2, 18, 2] + # num_heads = [6, 12, 24, 48] # swin-huge: 2.5B # embed_dim = 512 # depths = [2, 2, 42, 2] # num_heads = [16, 32, 64, 128] - mlp_ratio = 4 + # 355M + embed_dim = 256 + depths = [2, 2, 18, 2] + num_heads = [8, 16, 32, 64] + # 1.8B + # embed_dim = 512 + # depths = [2, 2, 26, 2] + # num_heads = [16, 32, 64, 128] + + # 6.6B + # embed_dim = 768 + # depths = [2, 2, 42, 2] + # num_heads = [24, 48, 96, 192] + + mlp_ratio = 4 drop_path_rate = 0.2 drop_rate = 0.2 attn_drop_rate = 0.0 @@ -29,8 +43,8 @@ class Config: # dataloader # 224 x 224 - img_size = 224 - window_size = 7 + # img_size = 224 + # window_size = 7 # 640 x 640 img_size = 640 @@ -91,6 +105,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, def forward(self, x): for blk in self.blocks: + cube.runtime.function.anchor('transformer block start') x = blk(x) if self.downsample is not None: x = self.downsample(x) diff --git a/examples/vision/swin/policy/mpmd.py b/examples/vision/swin/policy/mpmd.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/vision/swin/policy/spmd.py b/examples/vision/swin/policy/spmd.py index fe8228a4..f62cde10 100644 --- a/examples/vision/swin/policy/spmd.py +++ b/examples/vision/swin/policy/spmd.py @@ -5,10 +5,77 @@ from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +# ========================= parallelisms ================================= + +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], + idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# coshard +def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, + idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) + assert sub_nodes is not None + graph.recompute(sub_nodes) + for devid in devs: + for coid in range(colocate): + sub_node = sub_nodes[devid * colocate + coid] + graph.assign(sub_node, devid) + return sub_nodes + + +# ========================= parallelisms ================================= + + + def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 # print(graph.extra_repr()) for node in graph.nodes(): if not isinstance(node, IRBpOperation): graph.assign(node, 0) - return graph \ No newline at end of file + return graph + + +def PASMegatronTP(graph: IRGraph, resource): + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # annotating code structure -- not consider multiref on embedding weight + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + # why -1: multiref + fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + + # attention + attns = [node for node in fnodes if node.name == 'window_attn'] + for attn in attns: + _tp(graph, attn, tp_devs, idx=1, dim=0) + + # feedforward + ffns = [node for node in fnodes if node.name == 'feedforward'] + for ffn in ffns: + _tp(graph, ffn, tp_devs, idx=1, dim=0) + + # replicate other nodes + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: + _replica(graph, node, tp_devs) + + return graph diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index d82335fc..20873078 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -2,9 +2,9 @@ example: OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ + --nproc_per_node=4 \ --nnodes=1 \ - examples/vision/swin/train.py + examples/vision/swin/train.py --policy PASMegatronTP --fp16 """ import torch @@ -15,10 +15,30 @@ from cube.profiler.memory import memory_summary, model_summary import examples.vision.swin.policy.spmd as spmd +import examples.vision.swin.policy.mpmd as mpmd -PAS = spmd.PASSingle +import argparse + +parser = argparse.ArgumentParser(description='GPT Train') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') +parser.add_argument('--fp16', action='store_true', default=False, + help='use fp16 for the training') +args = parser.parse_args() +cube.init() + + +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +policies = [policy for policy in policies if policy.startswith('PAS')] +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") -torch.random.manual_seed(0) def train(): @@ -26,7 +46,10 @@ def train(): cfg = Config() model = SwinTransformer() - dataloader = ImageDataLoader(batch_size, cfg.img_size, cfg.num_classes) + model = model.half() if args.fp16 else model + + dtype = torch.float16 if args.fp16 else torch.float32 + dataloader = ImageDataLoader(batch_size, cfg.img_size, cfg.num_classes, dtype=dtype) model = cube.SemanticModel(model, dataloader.shapes) @cube.compile(model, dataloader, PAS=PAS, override=True) @@ -34,7 +57,7 @@ def train_iter(model, dataloader): imgs, labels = next(dataloader) loss = model(imgs, labels) loss.backward() - return loss + # return loss model: torch.nn.Module = model.get_gen_module() optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) @@ -56,7 +79,7 @@ def train_iter(model, dataloader): # training loss = train_iter(model, dataloader) - print(loss) + # print(loss) optimizer.step() optimizer.zero_grad() @@ -76,5 +99,4 @@ def train_iter(model, dataloader): if __name__ == '__main__': - cube.init() train() From b6f3ad3fd860d8cb06f080f72dbe9dd12c6c11aa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 Aug 2022 17:15:58 +0800 Subject: [PATCH 0961/1892] coshard for swin --- examples/vision/swin/policy/spmd.py | 35 +++++++++++++++++++++++++++++ examples/vision/swin/train.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/examples/vision/swin/policy/spmd.py b/examples/vision/swin/policy/spmd.py index f62cde10..39d31eed 100644 --- a/examples/vision/swin/policy/spmd.py +++ b/examples/vision/swin/policy/spmd.py @@ -79,3 +79,38 @@ def PASMegatronTP(graph: IRGraph, resource): _replica(graph, node, tp_devs) return graph + + +def PASMeshShard(graph: IRGraph, resource): + + # print(graph.extra_repr()) + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # annotating code structure -- not consider multiref on embedding weight + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + # why -1: multiref + fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + + # attention + attns = [node for node in fnodes if node.name == 'window_attn'] + for attn in attns: + # _tp(graph, attn, tp_devs, idx=1, dim=0) + _coshard(graph, attn, tp_devs, colocate=2, idx=1, dim=0) + + # feedforward + ffns = [node for node in fnodes if node.name == 'feedforward'] + for ffn in ffns: + # _tp(graph, ffn, tp_devs, idx=1, dim=0) + _coshard(graph, ffn, tp_devs, colocate=4, idx=1, dim=0) + + # replicate other nodes + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: + _replica(graph, node, tp_devs) + + # print(graph.extra_repr()) + return graph diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 20873078..ef5443a5 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/vision/swin/train.py --policy PASMegatronTP --fp16 + examples/vision/swin/train.py --policy PASMeshShard --fp16 """ import torch From a23ad32baee877ee6ebaa6041d4ac0127cef104f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 Aug 2022 17:48:32 +0800 Subject: [PATCH 0962/1892] support mpmd --- examples/vision/swin/model.py | 21 ++-- examples/vision/swin/policy/mpmd.py | 167 ++++++++++++++++++++++++++++ examples/vision/swin/train.py | 4 +- 3 files changed, 176 insertions(+), 16 deletions(-) diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index 7521f448..dd0fe36a 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -188,7 +188,7 @@ def __init__(self): w: torch.Tensor = block.attn.qkv_w.view(3, -1, block.attn.qkv_w.size(-1)) block.attn.qkv_w.copy_(w.permute(1,0,2).reshape(-1, w.size(-1))) - def forward(self, x, labels: torch.Tensor): + def forward(self, x): # , labels: torch.Tensor): x = self.patch_embed(x) x = self.pos_drop(x) @@ -200,7 +200,8 @@ def forward(self, x, labels: torch.Tensor): x = torch.flatten(x, 1) x = self.head(x) - loss = self.criterion(x, labels) + # loss = self.criterion(x, labels) + loss = torch.sum(x) return loss def flops(self): @@ -224,11 +225,9 @@ def __init__(self, batch_size: int, img_size: int, num_classes: int, dtype=torch self.img_size = img_size self.num_classes = num_classes super().__init__( - shapes=([batch_size, 3, img_size, img_size,], - [batch_size], - ), - dtypes=(dtype, torch.int), - batch_dims=(0, 0) + shapes=([batch_size, 3, img_size, img_size,],), + dtypes=(dtype,), + batch_dims=(0,) ) self.samples = [self.random_sample(dtype)] @@ -239,13 +238,7 @@ def random_sample(self, dtype: torch.dtype): dtype=dtype, device=torch.cuda.current_device() ) - labels = torch.randint( - 0, self.num_classes, - size=(self.bs,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - return (img, labels) + return img def __iter__(self): return self diff --git a/examples/vision/swin/policy/mpmd.py b/examples/vision/swin/policy/mpmd.py index e69de29b..52e41c14 100644 --- a/examples/vision/swin/policy/mpmd.py +++ b/examples/vision/swin/policy/mpmd.py @@ -0,0 +1,167 @@ +from typing import List, Tuple +import numpy as np + +from cube.graph import IRGraph +from cube.graph.function.anchor import IRGraphAnchor +from cube.ir.cten import IRCell +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.schedule.sched1f1b import IRSchedule1F1B + + +def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + + e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: + ( + ( (0,1,2), (3,4,5) ), + ( (0,3), (2,5), (3,6) ), + ) + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + + +def _group_to_transformers(fnodes) -> List[List[IRCell]]: + # group to transformer layers + transformers: List[List[IRFwOperation]] = [] + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + start = idx if lid != 0 else 0 + end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) + transformers.append(fnodes[start:end]) + for lid in range(len(transformers) - 1): + if transformers[lid][-1].name == 'multiref': + node = transformers[lid].pop() + transformers[lid+1].insert(0, node) + return transformers + +# ========================= parallelisms ================================= + +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# ========================= parallelisms ================================= + +def PASRoundRobin(graph: IRGraph, resource): + """ + roundrobin scheduling + """ + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # group to transformer layers + transformers = _group_to_transformers(fnodes) + + for lid, transformer in enumerate(transformers): + stage_id = lid % resource.ngpus + print(f'assigning {lid} transformer to stage {stage_id}') + for node in transformer: + graph.assign(node, stage_id) + + for node in graph.nodes(): + if len(node.device) == 0: + _replica(graph, node, list(range(resource.ngpus))) + + return graph + + +def PAS1F1B(graph: IRGraph, resource): + """ + 1F1B scheduling + """ + num_stage = resource.ngpus + num_microbatch = resource.ngpus * 8 + _, stage_mesh = _create_mesh(resource.ngpus, (num_stage, 1)) + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # group to transformer layers + transformers = _group_to_transformers(fnodes) + + # staging + nlayer_per_stage = (len(transformers) // resource.ngpus) + for lid, fnodes in enumerate(transformers): + stage_id = min(lid // nlayer_per_stage, num_stage-1) + print(f'assigning {lid}-th transformer layter to stage {stage_id}') + for fnode in fnodes: + graph.assign(fnode, stage_id) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + _replica(graph, node, list(range(resource.ngpus))) + + strategy = IRSchedule1F1B(graph, num_microbatch, stage_mesh) + graph.sched = strategy + return graph + + +def PASMegatron(graph: IRGraph, resource): + """ + 1F1B scheduling + """ + dp_size = 1 + tp_size = 2 + pp_size = resource.ngpus // (dp_size * tp_size) + num_microbatch = resource.ngpus + + # device mesh + dp_groups, pp_groups, tp_groups = \ + _create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) + print(f'dp groups: {dp_groups}') + print(f'pp groups: {pp_groups}') + print(f'tp groups: {tp_groups}') + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # group to transformer layers + transformers = _group_to_transformers(fnodes) + + # staging + nlayer_per_stage = (len(transformers) // pp_size) + for lid, fnodes in enumerate(transformers): + sid = min(lid // nlayer_per_stage, pp_size-1) + print(f'assigning {lid}-th transformer layer to stage {sid}: {tp_groups[sid]}') + for fnode in fnodes: + if fnode.name == 'window_attn': + _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) + elif fnode.name == 'feedforward': + _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) + else: + _replica(graph, fnode, tp_groups[sid]) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + _replica(graph, node, list(range(resource.ngpus))) + + strategy = IRSchedule1F1B(graph, num_microbatch, tp_groups) + graph.sched = strategy + return graph diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index ef5443a5..4a993133 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -54,8 +54,8 @@ def train(): model = cube.SemanticModel(model, dataloader.shapes) @cube.compile(model, dataloader, PAS=PAS, override=True) def train_iter(model, dataloader): - imgs, labels = next(dataloader) - loss = model(imgs, labels) + imgs = next(dataloader) + loss = model(imgs) loss.backward() # return loss model: torch.nn.Module = model.get_gen_module() From 6d06b1e7c82070c04cccc39c46927a932a85a420 Mon Sep 17 00:00:00 2001 From: Zijian Ding Date: Thu, 11 Aug 2022 22:22:03 -0700 Subject: [PATCH 0963/1892] small update to onedim policy --- cube/algorithm/factory.py | 1 + cube/algorithm/ops/creators.py | 39 ++++++++++++++++++++++++++++++++- cube/graph/function/creators.py | 6 +++++ examples/wrf/policy/onedim.py | 21 +++++------------- 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 2ccfa44e..afd0142a 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -83,6 +83,7 @@ def _load_predefined_algos(self): import cube.algorithm.ops.creators as creators self.register(creators.IRToTensor, creators.DimSplitTo, tag='dim') + self.register(creators.IRZeros, creators.DimSplitZeros, tag='dim') self.register(creators.IROnes, creators.DimSplitOnes, tag='dim') self.register(creators.IRRand, creators.DimSplitRand, tag='dim') # import cube.algorithm.ops.elementwise as elew diff --git a/cube/algorithm/ops/creators.py b/cube/algorithm/ops/creators.py index e397e1df..7119c3fc 100644 --- a/cube/algorithm/ops/creators.py +++ b/cube/algorithm/ops/creators.py @@ -2,7 +2,7 @@ from cube.algorithm.generics import GenericDistAlgo -from cube.graph.function.creators import IRToTensor, IROnes, IRRand +from cube.graph.function.creators import IRToTensor, IROnes, IRRand, IRZeros from cube.ir.tensor import IRSubTensor @@ -53,6 +53,43 @@ def instantiate(self, dim: int, num: int) -> Optional[List[IRToTensor]]: outputs = [t[nid] for t in ous] sub_nodes.append(node.new(inputs, outputs)) return sub_nodes + +class DimSplitZeros(GenericDistAlgo): + def __init__(self, node: IRZeros): + if not isinstance(node, IRZeros): + raise TypeError(f"Expect IRZeros") + super().__init__(node) + + def satisfy(self, dim: int, num: int): + """ + config = dict(idx=int, dim=int, num=num) + + """ + assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" + node: IRZeros = self.node + + assert dim < len(node.output(0).shape), "Split dimension should be smaller than tensor dimension" + + # split non-pad dim + return node.output(0).shape[dim] >= num + + def instantiate(self, dim: int, num: int) -> Optional[List[IRZeros]]: + + node: IRZeros = self.node + satisfy = self.satisfy(dim, num) + if not satisfy: + return None + + ous = list() + for oidx, otensor in enumerate(node.outputs()): + assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" + ous.append(otensor.split_dim(dim, num)) + + sub_nodes = list() + for nid in range(num): + outputs = [t[nid] for t in ous] + sub_nodes.append(node.new(outputs)) + return sub_nodes class DimSplitOnes(GenericDistAlgo): diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index aa08f3ea..c9f43e0e 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -28,6 +28,12 @@ def infer_shape(self) -> bool: self.output(0).shape = shape return True + def new(self, outputs: List[IRTensor]): + op = IROnes(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRZeros::new infer_shape failed" + return op + class IROnes(IRFwOperation): def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): diff --git a/examples/wrf/policy/onedim.py b/examples/wrf/policy/onedim.py index 02198695..0d0ea297 100644 --- a/examples/wrf/policy/onedim.py +++ b/examples/wrf/policy/onedim.py @@ -2,25 +2,11 @@ from cube.graph.function import IRConv2D, IRConv3D from cube.graph.function import IRDimops, IRPad from cube.ir.cten import IRTensor, IRCell -from cube.graph.function import IRSelect, IRSelectScatter, IRSlice, IRToTensor, IROnes, IRRand - +from cube.graph.function import IRSelect, IRSelectScatter, IRSlice, IRToTensor, IROnes, IRRand, IRZeros def PAS(graph: IRGraph, resource): for node in graph.nodes(): - if isinstance(node, IRConv3D): - sub_nodes = list() - algo = node.algorithms('halo') - Wnodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus // 2) - for Wnode in Wnodes: - algo = Wnode.algorithms('halo') - Hnodes = graph.partition(Wnode, algo, idx=0, dim=2, num=2) - sub_nodes += Hnodes - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - # sub_nodes = graph.replicate(node, times=resource.ngpus) - - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) + graph.assign(node, 0) print(graph.extra_repr()) return graph @@ -164,6 +150,9 @@ def PAS_ALL_Y(graph: IRGraph, resource): algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-2, num=resource.ngpus) assert sub_nodes != None + elif isinstance(node, IRZeros) and node.output(0).ndims >= 3: + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-2, num=resource.ngpus) # elif isinstance(node, IRRand) and node.output(0).ndims >= 3: # algo = node.algorithms('dim') # sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-1, num=resource.ngpus) From 765ad37b929a5a54b5a02d4f097d244c01bda75f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Aug 2022 13:32:35 +0800 Subject: [PATCH 0964/1892] enhanced dimop transformation expression: special transformation rule with modifiers to kwargs can be supported for expression --- cube/algorithm/ops/dimops.py | 142 ++++++++++++++++++++++---------- cube/graph/function/dimops.py | 109 ++++++++++++++++++++++-- cube/graph/function/function.py | 73 +++++++++++----- cube/ir/operator.py | 8 ++ 4 files changed, 263 insertions(+), 69 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 5c6ea673..36fa4d41 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -1,7 +1,7 @@ -from typing import List, Optional, Any +from typing import List, Optional, Any, Dict from cube.algorithm.generics import GenericDistAlgo -from cube.graph.function.dimops import IRDimops, DimAnno +from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule from cube.ir.tensor import IRSubTensor @@ -12,7 +12,7 @@ class DimSplitEinops(GenericDistAlgo): Note: for dimensions of multiple identitifers, only the first identifier can be partitioned. - Rules for identifier split: + Default rule for identifier split: * Sum-reduce identifier ('+'): * For inputs/outputs that have the identifier, will be partitioned on its diemension uniformly.. * For inputs that don't have the identifier, will be replicated @@ -24,18 +24,22 @@ class DimSplitEinops(GenericDistAlgo): * Frozen identifier ('^'): * Cannot be partitioned. + + If the identifier appears as the same name in argument name, the + argument will also be uniformly partitioned. Non-tensor will always be replicated. - Note this rule will not correctly apply for some operators like linear: xw + b + Note the default rule isn't always expressive for all possible partition algorithms. + E.g., linear xw + b to partition on reduction dimension, + whitch requires b to be value split but actually according to the default rule, will be replicated. + Therefore we require special rules for such cases. """ def __init__(self, node: IRDimops): if not isinstance(node, IRDimops): raise TypeError(f"Expect IRDimops") super().__init__(node) - self._adim: str = None - self._reduce: DimAnno.ReduceType = None def satisfy(self, idx: int, dim: int, num: int) -> bool: """ @@ -50,68 +54,122 @@ def satisfy(self, idx: int, dim: int, num: int) -> bool: assert all(isinstance(cond, int) for cond in [idx, dim, num]), "expect int condition" node: IRDimops = self.node + assert isinstance(node.input(idx), IRSubTensor), f"partitioning on a non-tensor input" ninputs = len(node.inputs()) idx = idx if idx >= 0 else idx + ninputs assert idx < ninputs, f"index out of boundary: {idx} >= {ninputs}" - assert isinstance(node.input(idx), IRSubTensor), f"partitioning on a non-tensor input" dim = dim if dim >= 0 else dim + node.input(idx).ndims assert dim < node.input(idx).ndims, f"dimension output of boundary: {dim} >= {node.input(idx).ndims}" - # we only partition the first annotated dimension for inner-dimension cases. - self._adim: str = node.anno.input(idx).dims[dim].identifiers[0] - self._reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] - dimlen = node.anno.getlen(self._adim) - if self._reduce == DimAnno.ReduceType.Freeze: - return False - if dimlen < num: + + # we only partition the first non-1 annotated dimension for hidden-dimension cases. + for adim in node.anno.input(idx).dims[dim].identifiers: + if adim == '1^': continue + break + dimlen = node.anno.getlen(adim) + # check node special rules first + for rule in node.transform_rules: + if rule.input(idx) == DimopSplit.D(dim): + return dimlen >= num + # otherwise check for default rules + reduce = node.anno.input(idx).dims[dim].reduces[0] + if reduce == DimAnno.ReduceType.Freeze: return False - return True + return dimlen >= num def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: node: IRDimops = self.node satisfy = self.satisfy(idx, dim, num) - print(f'partition {node.name}: {node.anno} | dim: {self._adim} reduce: {self._reduce.value}') + for adim in node.anno.input(idx).dims[dim].identifiers: + if adim == '1^': continue + break + reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] + print(f'try split {node.name}: {node.anno} | dim: {adim} reduce: {reduce}') if not satisfy: + print(f'Failed!') return None - def transform(tensor: Any, split_dims: List[int], is_input: bool): - # rule: non-tensor will always be replicated + rule: TransformRule = self.infer(idx, dim, num) + + # transform + def transform(tensor: Any, split: DimopSplit) -> List[Any]: if not isinstance(tensor, IRSubTensor): return [tensor] * num - assert len(split_dims) <= 1, "find split dims ({self._adim}) more than 1" - # rule: spatial identifier ('') - if self._reduce == DimAnno.ReduceType.Dim: - return tensor.replicate(num) if len(split_dims) == 0 else tensor.split_dim(split_dims[0], num) - # rule: reduce-sum identifier ('+') - if self._reduce == DimAnno.ReduceType.Sum: - if len(split_dims) == 0: - return tensor.replicate(num) if is_input else tensor.split_val(num) - else: - return tensor.split_dim(split_dims[0], num) - raise RuntimeError(f"no matching reduce type for transform: {self._reduce}") - - ins, ous = list(), list() - for iidx, itensor in enumerate(node.inputs()): - split_dims = node.anno.input(iidx).getdims(self._adim) - ins.append(transform(itensor, split_dims, is_input=True)) + if split.isD(): + return tensor.split_dim(split.dim, num) + if split.isR(): + return tensor.replicate(num) + if split.isV(): + return tensor.split_val(num) + assert False, f"got unknown split: {split}" - for oidx, otensor in enumerate(node.outputs()): - split_dims = node.anno.output(oidx).getdims(self._adim) - ous.append(transform(otensor, split_dims, is_input=False)) + ins = list() + for split, itensor in zip(rule.inputs(), node.inputs()): + ins.append(transform(itensor, split)) + ous = list() + for split, otensor in zip(rule.outputs(), node.outputs()): + ous.append(transform(otensor, split)) + kwargs = rule.modifier()(node.kwargs, idx, dim, num) sub_nodes = list() for nid in range(num): inputs = [t[nid] for t in ins] outputs = [t[nid] for t in ous] - updated_kwargs = dict() - if self._adim in node.kwargs and isinstance(node.kwargs[self._adim], int): - updated_kwargs[self._adim] = node.kwargs[self._adim] // num - sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) + sub_node: IRDimops = node.new(inputs, outputs, **kwargs) sub_node.infer_shape() sub_nodes.append(sub_node) return sub_nodes + def infer(self, idx: int, dim: int, num: int) -> Optional[TransformRule]: + """ + Given the partition choice on `dim` dimension of idx-th input, + return the partitioning of the output tensor. + + @param idx int: the input index + @param dim int: the dimension to partition + + @return rule TransformRule: the transformation rule + """ + node: IRDimops = self.node + adim: str = node.anno.input(idx).dims[dim].identifiers[0] + reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] + if reduce == DimAnno.ReduceType.Freeze: + return None + # check node special rules first + for r in node.transform_rules: + if r.input(idx) == DimopSplit.D(dim): + return r + # otherwise use default rule + itransform, otransform = [], [] + # input + for idx, idim in enumerate(node.anno.inputs()): + dims = idim.getdims(adim) + assert len(dims) <= 1, "Cannot split on multple same tensors" + if len(dims) == 1: + itransform.append(DimopSplit.D(dims[0])) + else: + itransform.append(DimopSplit.R()) + # output + for idx, odim in enumerate(node.anno.outputs()): + dims = odim.getdims(adim) + if len(dims) == 1: + otransform.append(DimopSplit.D(dims[0])) + else: + otransform.append( + DimopSplit.R() if reduce == DimAnno.ReduceType.Dim else DimopSplit.V() + ) + # modifier + def modify(kwargs: Dict, idx: int, dim: int, num: int): + updated_kwargs = dict(**kwargs) + if adim in updated_kwargs: + assert updated_kwargs[adim] % num == 0, \ + f"cannot set kwargs: {adim}: {updated_kwargs[adim]} % num ({num}) != 0" + updated_kwargs[adim] = updated_kwargs[adim] // num + return updated_kwargs + + return TransformRule(itransform, otransform, modify) + class SimpleViewSplitEinops(GenericDistAlgo): """ @@ -210,7 +268,7 @@ def instantiate(self, idx: int, dimi: int, dimo: int, num: int) -> Optional[List if self._adimo in node.kwargs and isinstance(node.kwargs[self._adimo], int): assert 0, "should not happen" assert len(outputs) == 1, f"outputs len should be one" - node.kwargs['size'] = outputs[0].shape + updated_kwargs['size'] = outputs[0].shape sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) sub_node.infer_shape() sub_nodes.append(sub_node) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 696d002b..5ad2502c 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -62,10 +62,9 @@ """ -from typing import Dict, Iterable, List, Union, Optional, Set, Tuple, Optional +from typing import Callable, Dict, Iterable, List, Union, Set, Tuple, Optional import enum import re -import copy import string from cube.ir.cten import IRTensor @@ -460,13 +459,103 @@ def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], return ', '.join(in_annos) + ' -> ' + ', '.join(ou_annos) +class DimopSplit: + """ + Partition status of a tensor + """ + def __init__(self, dim: Optional[int] = None, r = False, v = False) -> None: + self.dim = dim + self.rep = r + self.val = v + + def isR(self) -> bool: + return self.rep + + def isD(self) -> bool: + return self.dim is not None + + def isV(self) -> bool: + return self.val + + def __eq__(self, other): + if not isinstance(other, DimopSplit): + return False + if other.isR() and self.isR(): + return True + if other.isD() and self.isD() and other.dim == self.dim: + return True + if other.isV() and self.isV(): + return True + return False + + def __hash__(self) -> int: + if self.isV(): + return -1 + elif self.isR(): + return -2 + else: + return self.dim + + def __repr__(self) -> str: + if self.isD(): + return f'D({self.dim})' + if self.isR(): + return f'R' + if self.isV(): + return f'V' + return 'Unknown-DimopSplit' + + @staticmethod + def R(): + return DimopSplit(r=True) + + @staticmethod + def V(): + return DimopSplit(v=True) + + @staticmethod + def D(dim: int): + return DimopSplit(dim=dim) + + +class TransformRule: + """ + Partition rule + """ + def __init__(self, irules: Tuple[DimopSplit], orules: Tuple[DimopSplit], kwarg_modifier: Optional[Callable] = None) -> None: + self._inputs = tuple(irules) + self._outputs = tuple(orules) + modifier = kwarg_modifier if kwarg_modifier is not None else lambda x : x + self._modifier = (modifier,) + + def inputs(self) -> Tuple[DimopSplit]: + return self._inputs + + def input(self, idx: int) -> DimopSplit: + return self._inputs[idx] + + def outputs(self) -> Tuple[DimopSplit]: + return self._outputs + + def output(self, idx: int) -> DimopSplit: + return self._outputs[idx] + + def modifier(self) -> Optional[Callable]: + return self._modifier[0] + + def __repr__(self) -> str: + inputs = ', '.join(repr(split) for split in self._inputs) + outputs = ', '.join(repr(split) for split in self._outputs) + return f'{inputs} -> {outputs}' + class IRDimops(IRFwOperation): """ Einstein-inspired notation operations """ def __init__(self, signature: str, annos: Tuple[str], - inputs: List[IRTensor], name: str, **kwargs): + inputs: List[IRTensor], name: str, + transform_rules: Optional[Tuple[TransformRule]] = None, **kwargs): """! Create a IRDimops @@ -474,6 +563,7 @@ def __init__(self, signature: str, annos: Tuple[str], @param annos List[str]: annotation candidates @param inputs List[IRTensor]: input tensor list @param name str: the name of the operator + @param transform_rules: the special rules to partition the operator. Default None. @param kwargs: the kwarg non-tensor parameters """ assert all(isinstance(anno, str) for anno in annos), "Expect annos to be List[str]" @@ -481,6 +571,7 @@ def __init__(self, signature: str, annos: Tuple[str], self._anno: OpAnno = None self._iannos: List[ShapeAnno] = None self._oannos: List[ShapeAnno] = None + self._trans_rules: Tuple[TransformRule] = tuple(transform_rules) if transform_rules is not None else () for anno in self._annos_candidates: anno = OpAnno(anno) @@ -494,8 +585,9 @@ def __init__(self, signature: str, annos: Tuple[str], raise RuntimeError( f"no matching anno for given annos." f"op: {signature}\n" - f"inputs: {tuple(t.shape for t in inputs)}\n" + f"inputs: {tuple(t.shape if isinstance(t, IRTensor) else t for t in inputs)}\n" f"annos: {annos}\n" + f"kwargs: {kwargs}\n" ) n_outputs = len(self._oannos) @@ -510,6 +602,10 @@ def __init__(self, signature: str, annos: Tuple[str], def anno(self) -> OpAnno: return self._anno + @property + def transform_rules(self) -> Tuple[TransformRule]: + return self._trans_rules + def ianno(self, index: int) -> Tuple[DimAnno]: """! Get index-th input tensor shape annotation @@ -563,9 +659,8 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): @return op IRDimop: the new constructed operator """ annos = self._annos_candidates - updated_kwargs = copy.copy(self.kwargs) - updated_kwargs.update(kwargs) - op = IRDimops(self.signature, annos, inputs, self.name, **updated_kwargs) + rules = self._trans_rules + op = IRDimops(self.signature, annos, inputs, self.name, rules, **kwargs) for idx, output in enumerate(outputs): op.set_output(idx, output) return op diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 6b8c3443..da870756 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -7,7 +7,7 @@ from cube.ir.cten import IRTensor from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor, IRSubTensor -from cube.graph.function.dimops import ShapeAnno, OpAnno, IRDimops +from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRConv2D from cube.graph.function.conv import IRConv3D from cube.graph.function.pad import IRPad @@ -593,26 +593,57 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s bracket[subdim] = str(shape_map[edim]) # find out the axis that can be partitioned - spatial_in = set() - spatial_ou = set() - for in_bracket in in_anno: - for edim in in_bracket: - if edim != '1': - spatial_in.add(edim) - break - for ou_bracket in ou_anno: - for edim in ou_bracket: - spatial_ou.add(edim) - spatial = spatial_in.intersection(spatial_ou) - + ispatial = set() + ifirst = [] + for bracket in in_anno: + for hdim in range(len(bracket)): + if bracket[hdim] == '1': + continue + ispatial.add(bracket[hdim]) + ifirst.append(bracket[hdim]) + break + ospatial = set() + ofirst = [] + for bracket in ou_anno: + for hdim in range(len(bracket)): + if bracket[hdim] == '1': + continue + ospatial.add(bracket[hdim]) + ofirst.append(bracket[hdim]) + break + spatial = ispatial.intersection(ospatial) + + # set dimension cannot be partitioned for bracket in in_anno + ou_anno: - for subdim, edim in enumerate(bracket): - if edim not in spatial: - bracket[subdim] = str(shape_map[edim]) - # bracket[subdim] = edim + '^' + for hdim in range(len(bracket)): + if bracket[hdim] not in spatial: + bracket[hdim] = str(shape_map[bracket[hdim]]) + + # TODO: strange behaviour if every identitifer creates own + # modifier, seems all previous modifiers will be overrided by + # the last one. + def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: + kwargs = dict(**kwargs) + ofirst = [bracket[0] for bracket in ou_anno] + identifier = in_anno[idx][0] + oidx = ofirst.index(identifier) + size = list(kwargs['size']) + size[oidx] = size[oidx] // num + kwargs['size'] = tuple(size) + return kwargs + + # special rules: to change output size argument + rules = [] + for identifier in spatial: + iidx = ifirst.index(identifier) + oidx = ofirst.index(identifier) + rules.append( + TransformRule([DimopSplit.D(iidx)], [DimopSplit.D(oidx)], view_modifier) + ) + anno = OpAnno.create_op_str([in_anno], [ou_anno]) signature = 'torch.Tensor.view' - return IRDimops(signature, [anno], [input], 'view', size=tuple(shape)) + return IRDimops(signature, [anno], [input], 'view', rules, size=tuple(shape)) def Reshape(signature, inputs): @@ -812,8 +843,10 @@ def Flatten(signature, inputs: List): def Roll(signature, inputs: Tuple[IRTensor, Union[int, Tuple[int]], Union[int, Tuple[int]]]): tensor = inputs[0] shifts, dims = inputs[1:] - # TODO: enable partition - ishape = ShapeAnno.create_shape_str(tensor.shape, reduction='^') + ishape = ShapeAnno.create_shape_str(tensor.shape) + for dim in range(len(ishape)): + if dims is None or dim in dims: + ishape[dim] += '^' anno = OpAnno.create_op_str([ishape], [ishape]) return IRDimops(signature, [anno], [tensor], 'roll', shifts=shifts, dims=dims) diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 8c7beb57..86aa5f8d 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -160,6 +160,14 @@ def __repr__(self) -> str: f"outputs={self.outputs()})") return dscp + def extra_repr(self) -> str: + sign = self.signature.split('.')[-1] + ins = [t for t in self.inputs()] + dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " + f"inputs={ins}, " + f"outputs={self.outputs()})") + return dscp + class IRBpOperation(IRCell): """ From 80eb995e711549ff97de09821ca21f58a350f13b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Aug 2022 13:46:58 +0800 Subject: [PATCH 0965/1892] swin data parallelism --- examples/vision/swin/blocks/attention.py | 2 +- examples/vision/swin/blocks/mlp.py | 2 +- examples/vision/swin/blocks/patch.py | 21 +++++++--- examples/vision/swin/model.py | 29 ++++--------- examples/vision/swin/policy/spmd.py | 52 +++++++++++++++++++++++- examples/vision/swin/train.py | 2 +- tests/test_examples.sh | 18 ++++++++ 7 files changed, 96 insertions(+), 30 deletions(-) diff --git a/examples/vision/swin/blocks/attention.py b/examples/vision/swin/blocks/attention.py index c2d48ccc..a0be2db3 100644 --- a/examples/vision/swin/blocks/attention.py +++ b/examples/vision/swin/blocks/attention.py @@ -7,7 +7,7 @@ # this cannot partition on head dimension # as the head dimension is a secondary hidden dimension in (3 head dim_head). # To make partition work (correctness guarantee), the dimension is swapped as (head dim_head 3) -@cube.graph.parser.register('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), B N^ N^ -> B N^ C^') +@cube.graph.parser.register('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), nw N^ N^ -> B N^ C^') def window_attn(x: torch.Tensor, qkv_w: torch.Tensor, qkv_bias: torch.Tensor, relative_position_index: torch.Tensor, diff --git a/examples/vision/swin/blocks/mlp.py b/examples/vision/swin/blocks/mlp.py index 4cbbdb2c..1873cc53 100644 --- a/examples/vision/swin/blocks/mlp.py +++ b/examples/vision/swin/blocks/mlp.py @@ -3,7 +3,7 @@ import cube -@cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') +@cube.graph.parser.register('B HW^ E^, H+ E^, H+, E^ H+ -> B HW^ E^', name='feedforward') def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj1_bias: torch.Tensor, proj2: torch.Tensor, dropout: float) -> torch.Tensor: diff --git a/examples/vision/swin/blocks/patch.py b/examples/vision/swin/blocks/patch.py index a8fa3a1c..3b48677a 100644 --- a/examples/vision/swin/blocks/patch.py +++ b/examples/vision/swin/blocks/patch.py @@ -22,6 +22,13 @@ def patch_merge(x: torch.Tensor, h: int, w: int): x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C return x +@cube.graph.parser.register('B ic+ (ps^ w^) (ps^ h^), oc ic+ k^ k^, oc -> B oc w^ h^') +def patch(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, ps: int): + """ + @param ps int: patch size + """ + return torch.conv2d(x, w, b, stride=ps) + class PatchMerging(nn.Module): r""" Patch Merging Layer. @@ -72,8 +79,7 @@ class PatchEmbed(nn.Module): def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + patches_resolution = [img_size[0] // patch_size, img_size[1] // patch_size] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution @@ -82,21 +88,26 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la self.in_chans = in_chans self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + # patch_size = (patch_size, patch_size) + # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.conv_w = nn.Parameter(torch.empty(embed_dim, in_chans, self.patch_size, self.patch_size)) + self.conv_b = nn.Parameter(torch.empty(embed_dim)) + if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + # x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + x = patch(x, self.conv_w, self.conv_b, self.patch_size).flatten(2).transpose(1, 2) if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size * self.patch_size) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index dd0fe36a..7381fc56 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -10,6 +10,11 @@ class Config: + # POC test case + embed_dim = 192 + depths = [2, 2, 2, 2] + num_heads = [8, 16, 32, 64] + # swin-large 201M # embed_dim = 192 # depths = [2, 2, 18, 2] @@ -21,9 +26,9 @@ class Config: # num_heads = [16, 32, 64, 128] # 355M - embed_dim = 256 - depths = [2, 2, 18, 2] - num_heads = [8, 16, 32, 64] + # embed_dim = 256 + # depths = [2, 2, 18, 2] + # num_heads = [8, 16, 32, 64] # 1.8B # embed_dim = 512 @@ -217,7 +222,7 @@ def flops(self): # =========================== Data Loader ======================= -class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): +class ImageDataLoader(cube.runtime.syndata.SynDataLoader): def __init__(self, batch_size: int, img_size: int, num_classes: int, dtype=torch.float32): @@ -229,19 +234,3 @@ def __init__(self, batch_size: int, img_size: int, num_classes: int, dtype=torch dtypes=(dtype,), batch_dims=(0,) ) - self.samples = [self.random_sample(dtype)] - - def random_sample(self, dtype: torch.dtype): - torch.manual_seed(0) - img = torch.rand( - *(self.bs, 3, self.img_size, self.img_size), - dtype=dtype, - device=torch.cuda.current_device() - ) - return img - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] diff --git a/examples/vision/swin/policy/spmd.py b/examples/vision/swin/policy/spmd.py index 39d31eed..815d43ca 100644 --- a/examples/vision/swin/policy/spmd.py +++ b/examples/vision/swin/policy/spmd.py @@ -1,8 +1,10 @@ -from typing import List +from typing import Dict, List from cube.graph import IRGraph from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.dimops import DimopSplit, TransformRule from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.tensor import IRFullTensor, IRSubTensor # ========================= parallelisms ================================= @@ -37,7 +39,6 @@ def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int graph.assign(sub_node, devid) return sub_nodes - # ========================= parallelisms ================================= @@ -51,6 +52,53 @@ def PASSingle(graph: IRGraph, resource): return graph +def PASData(graph: IRGraph, resource): + dp_size = resource.ngpus + dp_devs = list(range(dp_size)) + + ftensors: Dict[IRFullTensor, DimopSplit] = dict() # ftensor: producer partition index + + dataloaders = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] + for dataloader in dataloaders: + algo = dataloader.algorithms('data') + subnodes = graph.partition(dataloader, algo, num=dp_size) + for idx, sub_node in enumerate(subnodes): + graph.assign(sub_node, idx) + for oidx, output in enumerate(dataloader.outputs()): + if not isinstance(output, IRSubTensor): + continue + if output.parent not in ftensors: + bdim = dataloader.get_batch_dims()[oidx] + ftensors[output.parent] = DimopSplit.D(bdim) + + + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + for node in fnodes: + if isinstance(node, IRGraphAnchor): + continue + partitioned = False + for iidx, itensor in enumerate(node.inputs()): + if not isinstance(itensor, IRSubTensor): + continue + if itensor.parent in ftensors: + dim = ftensors[itensor.parent] + assert dim.isD(), f"on partitioning node: {node}:\nexpected input to be partitioned on dimensions but found {dim}" + rule: TransformRule = node.algorithms('dim').infer(idx=iidx, dim=dim.dim, num=len(dp_devs)) + # print(rule) + assert rule is not None, f"fail to infer node: {node}, idx={iidx}" + for odim, output in zip(rule.outputs(), node.outputs()): + ftensors[output.parent] = odim + # print(f'==> setting next dim: {odim}') + _tp(graph, node, dp_devs, idx=iidx, dim=dim.dim) + partitioned = True + break + if not partitioned: + print(f'warning: cannot partition of node using dim propagation, use replica instead: {node}') + _replica(graph, node, dp_devs) + + return graph + + def PASMegatronTP(graph: IRGraph, resource): tp_size = resource.ngpus tp_devs = list(range(tp_size)) diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 4a993133..240618fd 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -42,7 +42,7 @@ def train(): - batch_size = 1 + batch_size = 4 cfg = Config() model = SwinTransformer() diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 71ef5575..84845983 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -61,6 +61,24 @@ OMP_NUM_THREADS=4 torchrun \ examples/nlp/gpt/train.py --policy PASMeshShard --fp16 +# test Swin model + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/vision/swin/train.py --policy PASData --fp16 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/vision/swin/train.py --policy PASMegatronTP --fp16 + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/vision/swin/train.py --policy PASMegatron --fp16 + + # test scientific model OMP_NUM_THREADS=4 torchrun \ From c0934e09496bdcfef94373fa85256b6d5f6e72de Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Aug 2022 13:47:32 +0800 Subject: [PATCH 0966/1892] fix optimizer bug --- examples/nlp/gpt/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 2233699c..c7403e7f 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -52,7 +52,6 @@ def train(): model = GPT() model = model if not args.fp16 else model.half() dataloader = GPTDataLoader(batch_size) - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) model = cube.SemanticModel(model, dataloader.shapes) @cube.compile(model, dataloader, PAS=PAS, override=True) @@ -62,6 +61,8 @@ def train_iter(model, dataloader): loss.backward() model = model.get_gen_module() + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + torch.distributed.barrier() print_each_rank('model weight consumpition:', rank_only=0) memory_summary() From a7135c7a74b5f4967eecbac6b6b64c46f3c5d390 Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 15 Aug 2022 15:01:13 +0800 Subject: [PATCH 0967/1892] add nnfusion.jit support in codegen, usage: USE_NNFUSION=1 ... python ... --- cube/codegen/codegen.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index bf5849fd..f8197989 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -24,6 +24,8 @@ from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock from cube.codegen.frontend_mapping import Sign2EmitRule +import os + def get_backward_callsite_io_tensors(bp_segment:IRSegment): """ Returns: @@ -367,6 +369,11 @@ def __init__(self, execplan: ExecutionPlan): 'from typing import *', 'import torch', 'import torch.utils.checkpoint as ckpt', 'import cube', '', ''] + + use_nnfusion = os.environ.get('USE_NNFUSION') + if use_nnfusion: + self.init_code.extend(['import nnfusion', '']) + # customized op code for _, op_impl in Sign2Op.kOpCodeDef.items(): self.init_code.append(op_impl) @@ -473,6 +480,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: args.append(self.tensor_naming(t)) node_args.append(args) + use_nnfusion = os.environ.get('USE_NNFUSION') # generate full code with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: with FunctionBlock(func_name='__init__', args=['self']) as ib: @@ -488,6 +496,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: name = self.node_naming(node) input_args = ['self'] + node_args[idx] forward_code = self.model_methods_bodies[idx] + with FunctionBlock(func_name=name, args=input_args) as fb: fb.insert_body(forward_code) # generate output @@ -495,6 +504,8 @@ def gen(self, device: int, outfile=None, attach=False) -> str: return_code = f"return {', '.join(outputs)}" fb.insert_body(return_code) cb.insert_body('') + if use_nnfusion and name.startswith('segment'): + cb.insert_body('@nnfusion.jit') cb.insert_body(fb.code) From 72b19620f9bb3e1b07c07b2dea3f3527600e1351 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Aug 2022 18:19:04 +0800 Subject: [PATCH 0968/1892] update synthetic dataloader --- cube/runtime/syndata.py | 64 ++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 42 deletions(-) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index be91c80b..15a3c27a 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -2,7 +2,7 @@ Synthetic Data Loader """ -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import torch @@ -121,7 +121,7 @@ class SynDataLoader(CubeDataLoader): for given shapes, dtypes. """ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, - batch_dims: Tuple[int] = None, length: int = 1280): + batch_dims: Tuple[int] = None): """ shapes Tuple[Tuple[int]]: The shape for each data @@ -129,8 +129,6 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, The dtype for each data (Default None: use torch.float32) batch_dims Tuple[int]: The batch dimension of each data (Default None: dimension 0 is the batch dim) - length int: - Total number of sample batches. (Default 1280) """ if batch_dims is None: batch_dims = tuple([0] * len(shapes)) @@ -138,50 +136,32 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, dtypes = tuple([torch.float] * len(shapes)) super().__init__(shapes, dtypes, batch_dims) - self.length = length - self.pos = 0 - - self._buffer_num = None - self.datas: torch.Tensor = list() - self.set_data_buffer() + self.buffer: Union[torch.Tensor, Tuple[torch.Tensor]] = None + self.set_random_sample() def __iter__(self): - self.pos = 0 return self - def set_data_buffer(self, buffer_num = 4): + def __next__(self): + return self.buffer + + def set_random_sample(self): torch.manual_seed(0) - self.datas = list() - self._buffer_num = buffer_num - for _ in range(self._buffer_num): - datas = list() - for shape, dtype in zip(self.shapes, self.dtypes): - data = torch.randn(shape, dtype=dtype).cuda() - datas.append(data) - self.datas.append(datas) + datas = [] + for shape, dtype in zip(self.shapes, self.dtypes): + datas.append( + torch.rand( + shape, dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False) + ) + if len(datas) == 0: + self.buffer = None + else: + datas = tuple(datas) if len(datas) > 1 else datas[0] + self.buffer = datas def set_batch_size(self, batch_size: int): super().set_batch_size(batch_size) - self.set_data_buffer() - - def __next__(self): - self.pos += 1 - if self.pos == self.length: - raise StopIteration - datas = self.datas[self.pos % self._buffer_num] - if len(datas) == 1: return datas[0] - else: return tuple(datas) - + self.set_random_sample() -class SynTextDataLoader(SynDataLoader): - - def set_data_buffer(self, buffer_num=4, text_num=50257): - torch.manual_seed(0) - self.datas = list() - self._buffer_num = buffer_num - for _ in range(self._buffer_num): - datas = list() - for shape in self.shapes: - data = torch.randint(0, text_num, shape, dtype=torch.long).cuda() - datas.append(data) - self.datas.append(datas) From 8c3063ca383a44c226a8838bd3ae360daeec8c66 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Aug 2022 18:19:18 +0800 Subject: [PATCH 0969/1892] change pytorch version check --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ac1c8955..c5ffda85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ pytest more-itertools --find-links https://download.pytorch.org/whl/torch_stable.html -torch==1.11.0+cu113 \ No newline at end of file +torch>=1.11.0+cu113 \ No newline at end of file From 2edcd93fba370993ffc28f166760df9b3b6d8f05 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Thu, 25 Aug 2022 18:30:58 +0800 Subject: [PATCH 0970/1892] palm init: single device runnable --- cube/graph/function/function.py | 53 ++++-- cube/graph/parser/mapping.py | 6 +- cube/runtime/function/function.py | 4 +- cube/runtime/syndata.py | 6 +- examples/nlp/palm/palm.py | 295 +++++++++++++++++------------- 5 files changed, 216 insertions(+), 148 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 21fee65a..6d8f0768 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -47,7 +47,7 @@ def BatchLinear(signature, inputs): return IRDimops(signature, annos, inputs, 'bmm') -def Zeros(signature, +def Zeros(signature, inputs: Tuple[ List[int], Optional[int], Optional[Any], 'ErasedDevice', Optional[bool] ]): # zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor # @@ -74,7 +74,7 @@ def Zeros(signature, raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") return IRZeros(signature, size, 'zeros', ir_dtype) -def Ones(signature, +def Ones(signature, inputs: Tuple[ List[int], Optional[int], Optional[Any], 'ErasedDevice', Optional[bool] ]): # ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -98,7 +98,7 @@ def Ones(signature, raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") return IROnes(signature, size, 'ones', ir_dtype) -def NewTensor(signature, +def NewTensor(signature, inputs: Tuple[ list, Optional[int], 'ErasedDevice', bool ]): # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor # @@ -162,7 +162,7 @@ def ToTensor(signature, dtype : torch.dtype = TorchScalarTypeEnumMap.map(dtype_underlying) ir_dtype : IRDType = DType2IRDType.map(dtype) - + signature = 'torch.Tensor.to' return IRToTensor(signature, [tensor], 'to', ir_dtype=ir_dtype) @@ -249,7 +249,7 @@ def Sub(signature, inputs): inputs = inputs[0:2] else: raise RuntimeError("The number of inputs must be 2 or 3") - + lhs, rhs = inputs if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): @@ -409,6 +409,13 @@ def GeLU(signature, inputs): return IRDimops(signature, annos, tensor, 'gelu') +def SiLU(signature, inputs): + annos = ['* -> *'] + signature = 'torch.nn.functional.silu' + tensor = inputs[0:1] + return IRDimops(signature, annos, tensor, 'silu') + + def Softmax(signature, inputs): annos = ['* -> *'] tensor = inputs[0:1] @@ -461,6 +468,28 @@ def Sum(signature, inputs): else: return IRDimops(signature, [anno], [tensor], 'sum') +def Mean(signature, inputs): + tensor = inputs[0] + dim = inputs[1] + einput = ShapeAnno.create_shape_str(tensor.shape) + eoutput = copy.copy(einput) + if dim is not None: + keepdim = inputs[2] + sort_dim = list(dim) + sort_dim.sort() + for dimidx in sort_dim[::-1]: + eoutput.pop(dimidx) + einput[dimidx] = einput[dimidx] + '+' + else: + eoutput = ['1'] + # every dimension is reduced + einput = [edim + '+' for edim in einput] + anno = OpAnno.create_op_str([einput], [eoutput]) + if dim is not None: + return IRDimops(signature, [anno], [tensor], 'mean', dim=dim, keepdim=keepdim) + else: + return IRDimops(signature, [anno], [tensor], 'mean') + def Transpose(signature, inputs): """ @@ -589,8 +618,8 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s bracket[subdim] = str(shape_map[edim]) # bracket[subdim] = edim + '^' anno = OpAnno.create_op_str([in_anno], [ou_anno]) - signature = 'torch.Tensor.view' - return IRDimops(signature, [anno], [input], 'view', size=tuple(shape)) + signature = 'torch.reshape' + return IRDimops(signature, [anno], [input], 'view', shape=tuple(shape)) def Reshape(signature, inputs): @@ -598,10 +627,10 @@ def Reshape(signature, inputs): torch.reshape(Tensor self, int[] shape) -> Tensor """ - warnings.warn(""" - 'torch.reshape' is currently dispatched to 'torch.Tensor.view', - but 'reshape' has keyword parameter 'shape' while 'view' has 'size'. - ArgumentMissing error may be raised during codegen.""") + # warnings.warn(""" + # 'torch.reshape' is currently dispatched to 'torch.Tensor.view', + # but 'reshape' has keyword parameter 'shape' while 'view' has 'size'. + # ArgumentMissing error may be raised during codegen.""") return View(signature, inputs) @@ -831,4 +860,4 @@ def CustomOps(signature, inputs): else: import warnings - warnings.warn(f"ERROR Unknown custom op, signature {signature}") \ No newline at end of file + warnings.warn(f"ERROR Unknown custom op, signature {signature}") diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 5c97fb59..93124d42 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -70,6 +70,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('gelu') : function.GeLU, __ttemplate('gelu') : function.GeLU, + __ftemplate('silu') : function.SiLU, + __ttemplate('silu') : function.SiLU, + __ftemplate('_pad'): function.Pad, __ftemplate('layer_norm'): function.LayerNorm, @@ -112,6 +115,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('bmm') : function.BatchLinear, __ttemplate('sum') : function.Sum, + __ttemplate('mean') : function.Mean, __ttemplate('transpose') : function.Transpose, @@ -143,7 +147,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # runtime functions __rtemplate('anchor'): function.GraphAnchor, - + __rtemplate('identity'): function.Identity, #einops diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index f4920b36..1abb4ccb 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -11,7 +11,7 @@ def identity(tensor: torch.Tensor) -> torch.Tensor: def anchor(name: str): """ - anchor operation for graph navigation + anchor operation for graph navigation """ return None @@ -138,4 +138,4 @@ def update_geopotential_(phi: torch.Tensor, zs: torch.Tensor, P: torch.Tensor, P return phi def strip_2_borders(w: torch.Tensor): - return w[1:-1] \ No newline at end of file + return w[1:-1] diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index be91c80b..21440b69 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -156,7 +156,11 @@ def set_data_buffer(self, buffer_num = 4): for _ in range(self._buffer_num): datas = list() for shape, dtype in zip(self.shapes, self.dtypes): - data = torch.randn(shape, dtype=dtype).cuda() + # TODO + if dtype == torch.int32: + data = torch.randint(0, 20000, shape, dtype=dtype).cuda() + else: + data = torch.randn(shape, dtype=dtype).cuda() datas.append(data) self.datas.append(datas) diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index c720cf37..e6dce76c 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -15,105 +15,78 @@ from einops import rearrange, repeat +from cube.graph import IRGraph + import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank +from cube.ir.operator import IRDataOperation, IRFwOperation import examples.mlp.policy.spmd as spmd import examples.mlp.policy.mpmd as mpmd import argparse -def exists(val): - return val is not None - -# parser = argparse.ArgumentParser(description='comm primitive') -# parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -# args = parser.parse_args() -# -# cube.init() -# -# # set up policy -# PAS = None -# policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -# if args.policy in spmd.__dict__: -# PAS = spmd.__dict__[args.policy] -# print_each_rank(f'using policy from spmd.{args.policy}') -# elif args.policy in mpmd.__dict__: -# PAS = mpmd.__dict__[args.policy] -# print_each_rank(f'using policy from mpmd.{args.policy}') -# else: -# raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - +cube.init() # =================== Semantic Model Description ==================== -class MLP(nn.Module): - def __init__(self, dim, mult=1): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult) - self.linear2 = nn.Linear(dim * mult, dim) - self.linear3 = nn.Linear(dim, dim * mult) - self.linear4 = nn.Linear(dim * mult, dim) - # self.linear5 = nn.Linear(dim, dim * mult) - # self.linear6 = nn.Linear(dim * mult, dim) - # self.linear7 = nn.Linear(dim, dim * mult) - # self.linear8 = nn.Linear(dim * mult, dim) - - def forward(self, data): - output = self.linear1(data) - output = self.linear2(output) - output = self.linear3(output) - output = self.linear4(output) - # output = self.linear5(output) - # output = self.linear6(output) - # output = self.linear7(output) - # output = self.linear8(output) - loss = torch.sum(output) - return loss - # normalization + class RMSNorm(nn.Module): - def __init__(self, dim, eps = 1e-8): + + def __init__(self, dim, eps=1e-8): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): - norm = torch.norm(x, dim = -1, keepdim = True) * self.scale - return x / norm.clamp(min = self.eps) * self.g + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +def exists(val): + return val is not None + # AliBi + class AlibiPositionalBias(nn.Module): + def __init__(self, heads, **kwargs): super().__init__() self.heads = heads slopes = torch.Tensor(self._get_slopes(heads)) slopes = rearrange(slopes, 'h -> h 1 1') - self.register_buffer('slopes', slopes, persistent = False) - self.register_buffer('bias', None, persistent = False) + self.register_buffer('slopes', slopes, persistent=False) + self.register_buffer('bias', None, persistent=False) def get_bias(self, i, j, device): - i_arange = torch.arange(i, device = device) - j_arange = torch.arange(j, device = device) - bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1')) + i_arange = torch.arange(i, device=device) + j_arange = torch.arange(j, device=device) + bias = -torch.abs( + rearrange(j_arange, 'j -> 1 1 j') - + rearrange(i_arange, 'i -> 1 i 1')) return bias @staticmethod def _get_slopes(heads): + def get_slopes_power_of_2(n): - start = (2**(-2**-(log2(n)-3))) + start = (2**(-2**-(log2(n) - 3))) ratio = start - return [start*ratio**i for i in range(n)] + return [start * ratio**i for i in range(n)] if log2(heads).is_integer(): return get_slopes_power_of_2(heads) - closest_power_of_2 = 2 ** floor(log2(heads)) - return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2] + closest_power_of_2 = 2**floor(log2(heads)) + return get_slopes_power_of_2( + closest_power_of_2) + get_slopes_power_of_2( + 2 * closest_power_of_2)[0::2][:heads - closest_power_of_2] def forward(self, qk_sim): h, i, j, device = *qk_sim.shape[-3:], qk_sim.device @@ -130,28 +103,38 @@ def forward(self, qk_sim): return bias + class SwiGLU(nn.Module): + + def __init__(self, dim): + super().__init__() + + self.dim = dim + def forward(self, x): - x, gate = x.chunk(2, dim=-1) - return F.silu(gate) * x + # x, gate = x.chunk(2, dim=-1) + u, gate = x[:, :, 0:self.dim // 2], x[:, :, self.dim // 2:] + return F.silu(gate) * u + class PaLMLayer(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): super().__init__() self.dim, self.dim_head, self.heads, self.scale = dim, dim_head, heads, dim_head**-0.5 - self.alibi_pos_biases = AlibiPositionalBias(heads=self.heads) + # TODO + # self.alibi_pos_biases = AlibiPositionalBias(heads=self.heads) + # self.norm = RMSNorm(dim) + self.norm = torch.nn.LayerNorm(self.dim) - self.norm = RMSNorm(dim) self.qkv_proj = nn.Linear(dim, dim + dim_head, bias=False) self.attn_out = nn.Linear(dim, dim, bias=False) - self.ff = nn.Sequential( - nn.Linear(dim, 2 * ff_mult * dim, bias=False), - SwiGLU(), - nn.Linear(ff_mult * dim, dim, bias=False) - ) + self.ff = nn.Sequential(nn.Linear(dim, 2 * ff_mult * dim, bias=False), + SwiGLU(2 * ff_mult * dim), + nn.Linear(ff_mult * dim, dim, bias=False)) self.register_buffer("mask", None, persistent=False) @@ -159,99 +142,147 @@ def get_mask(self, n, device): if self.mask is not None and self.mask.shape[-1] >= n: return self.mask[:n, :n] - mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), 1) + mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), + 1) self.register_buffer("mask", mask, persistent=False) return mask - def forward(self, x): - n, device = x.shape[1], x.device + def forward(self, in_x): + bs, n, device = in_x.shape[0], in_x.shape[1], in_x.device # pre layernorm - x = self.norm(x) + x = self.norm(in_x) ff_out = self.ff(x) - q, kv = self.qkv_proj(x).split((self.dim, self.dim_head), dim=-1) + # q, kv = self.qkv_proj(x).split((self.dim, self.dim_head), dim=-1) + proj = self.qkv_proj(x) + q, kv = proj[:, :, 0:self.dim], proj[:, :, self.dim:] - q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) + # q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) + q = q.view(bs, n, self.heads, self.dim_head).transpose(1, 2) # scale q = q * self.scale # similarity - sim = einsum('b h i d, b j d -> b h i j', q, kv) + # sim = einsum('b h i d, b j d -> b h i j', q, kv) + q = q.reshape(bs, self.heads * n, self.dim_head) + trans_kv = kv.transpose(1, 2) + sim = torch.bmm(q, trans_kv).view(bs, self.heads, n, n) - sim = sim + self.alibi_pos_biases(sim) + # TODO + # sim = sim + self.alibi_pos_biases(sim) + # TODO # causal mask - causal_mask = self.get_mask(n, device) - sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + # causal_mask = self.get_mask(n, device) + # sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) # attention - attn = sim.softmax(dim=-1) - out = einsum("b h i j, b j d -> b h i d", attn, kv) + # attn = sim.softmax(dim=-1) + attn = torch.nn.functional.softmax(sim, dim=-1) + + # out = einsum("b h i j, b j d -> b h i d", attn, kv) + attn = attn.view(bs, self.heads * n, n) + out = torch.bmm(attn, kv).view(bs, self.heads, n, self.dim_head) # merge heads - out = rearrange(out, "b h n d -> b n (h d)") + # out = rearrange(out, "b h n d -> b n (h d)") + out = torch.transpose(out, 1, 2).reshape(bs, n, self.dim) merge_heads = self.attn_out(out) + ff_out - return merge_heads + return in_x + merge_heads + + +class PaLM(nn.Module): + + def __init__(self, + dim, + num_tokens, + depth, + dim_head=64, + heads=8, + ff_mult=4): + super().__init__() + + self.net = nn.Sequential( + nn.Embedding(num_tokens, dim), + *[PaLMLayer(dim, dim_head, heads, ff_mult) for _ in range(depth)], + # TODO: RMSNorm(dim), + torch.nn.LayerNorm(dim), + nn.Linear(dim, num_tokens, bias=False), + ) + + self.net[-1].weight = self.net[0].weight + nn.init.normal_(self.net[0].weight, std=0.02) + + def forward(self, x): + return self.net(x).mean() + + +def PASSingle(graph: IRGraph, resource): + assert resource.ngpus == 1 + + for node in graph.nodes(): + if isinstance(node, (IRDataOperation, IRFwOperation)): + graph.assign(node, 0) + + return graph + def train(): - bs, n, d = 8, 1024, 512 - - model = PaLMLayer(d) - - x = torch.randn(bs, n, d) - - y = model(x) - -# def train(): -# batch_size = 256 -# dim = 8192 -# -# model = MLP(dim=dim) -# model = cube.SemanticModel( -# model, input_shapes=([batch_size, dim],), -# ) -# -# dataloader = cube.runtime.syndata.SynDataLoader( -# shapes=([batch_size, dim],), -# dtypes=(torch.float32,), -# batch_dims=(0,) -# ) -# -# @cube.compile(model, dataloader, PAS=PAS) -# def train_iter(model, dataloader): -# data = next(dataloader) -# loss = model(data) -# loss.backward() -# model = model.get_gen_module() -# -# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) -# -# CudaTimer(enable=False).warmup() -# if torch.distributed.is_initialized(): -# torch.distributed.barrier() -# iter_num = 64 -# warmup = 20 -# for step in range(iter_num): -# if step >= warmup: -# CudaTimer(enable=True).start('e2e') -# train_iter(model, dataloader) -# optimizer.step() -# optimizer.zero_grad() -# if step >= warmup: -# CudaTimer().stop('e2e') -# if (step + 1) % 20 == 0: -# print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) -# -# print_each_rank('e2e time (ms) per iteration: {} ms'.format( -# CudaTimer().duration(iter_num-warmup, field_name='e2e'))) -# CudaTimer().print_all(times=iter_num-warmup) + bs, n, dim = 8, 1024, 512 + num_tokens, depth, heads, dim_head = 20000, 1, 8, 64 + + model = PaLM(dim, num_tokens, depth, heads=heads, dim_head=dim_head) + + # for debug + # tokens = torch.randint(0, num_tokens, (bs, n)) + # print(model(tokens)) + # return + + model = cube.SemanticModel( + model, + input_shapes=([bs, n], ), + ) + + dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, n], ), + dtypes=(torch.int32, ), + batch_dims=(0, )) + + @cube.compile(model, dataloader, PAS=PASSingle) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + + model = model.get_gen_module() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer(enable=False).warmup() + if torch.distributed.is_initialized(): + torch.distributed.barrier() + iter_num = 64 + warmup = 20 + for step in range(iter_num): + if step >= warmup: + CudaTimer(enable=True).start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= warmup: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num - warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num - warmup) train() From 92688ae65a29f459e43c751dcf524a509d14844e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 26 Aug 2022 19:48:03 +0800 Subject: [PATCH 0971/1892] add dimop tutorial --- tutorial.md | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 tutorial.md diff --git a/tutorial.md b/tutorial.md new file mode 100644 index 00000000..d4d410d2 --- /dev/null +++ b/tutorial.md @@ -0,0 +1,107 @@ +# Dimop Tutorial + +## Dimop: Dimension-annotated Operator + +### Annotation for Shape Inference and Transformation + +SuperScaler uses annotated dimension to represent an operator (Dimop). +The goal of annotation is for 1). shape inference and 2) transformation plan. + +To annotate an operator, following example shows the annotation matrix multiplication. An operator has inputs and outputs. The inputs can be tensors or non-tensors, while outputs are usually tensors. + +```py +# annotation: m^ kd+, kd+ n -> m^ n +def operator(x: torch.Tensor, w: torch.Tensor, h: float) -> torch.Tensor: + out = torch.matmul(x, w) + return out +``` + +To separate inputs and outputs of an operator, `'->'` is a separation keyword where its left part are inputs and right part are outputs. Each tensor representation is further separated by `','`. + +Each tensor in inputs and outputs is reperented by **{identifiers}{reduction}** on each dimension, like `'m^ kd+'`, `'kd+ n'`, `'m^ n'`, where `m`, `kd` and `n` are identitifiers, `'^'` and `'+'` are reductions. + +If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimensions, the first dimension is `m` and the second dimension is `kd`. Dimensions need to be separated by space `' '`. + +* Identifiers + + Identifiers are served for the shape inference. Same identifier of different tensors indicates they have same length at their dimension. + + An `identifier` must be one of: + + 1) symbolic annotation that must match with the criterion of python `str.isidentifier`. + + 2) numeric string that must match with python str.isdecimal. This indicates the shape is the same value. numeric string will always have '^' reduction type' + + Special identifier: + + 1) `'*'`: this special identifier indicates the dimension is dynamic, which will automatically get expanded given the shape + + 2) `'?'`: this special identifier indicates the value is not a tensor, which will be ignored + + To infer the output shape, the identifiers in output tensors must be 1) appear in inputs, 2) numeric string. + +* Reductions + + Reductions are served for transformation plans. The reduction can be one of {`''`, `'+'`, `'^'`}: + + * `''` (empty) indicates this dimension can be partitioned, and each output should have this dimension. + + * `'+'` indicates this dimension can be partitioned. Each output doesn't have this and need to do sum-reduction. + + * `'^'` means this dimension cannot be partitioned. + +### Advance + +* Hidden dimension + + Sometimes user need to reshape the tensor by splitting a dimension into multiple dimensions. For example, a tensor of (1024, 8) size needs to be reshaped into the shape of (8, 128, 8): + + ```py + # annotation: (h t) k -> h t k + def reshape(tensor: torch.Tensor, h : int = 8) -> torch.Tensor: + out = tensor.reshape(h, tensor.size(0) // h, tensor.size(-1)) + return out + ``` + + This can be represented by annotating a dimension using brackets `()` that contains multple identifiers (and their reductions), like `'(h t) k'` here for the input tensor. To help system infer the number of `h` and `t` in the annotation, the function requires to put in a same-named argument `h` or `t` (`h=8` here in example). + + +## Register Python Functions as Operators + +To register a customized "matmul" operator in the runtime, user can simply define the python function and add an decorator on the function with its annotations: + +```py +@cube.graph.parser.register('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom') +def operator(x: torch.Tensor, w: torch.Tensor, h: float) -> torch.Tensor: + out = torch.matmul(x, w) + out = out.view(h, out.size(0) // h, out.size(1)) + return out + + +class Model(troch.nn.Module): + + ... + + def forward(x, w): + ... + out = operator(x, w) # simply use it + ... +``` + +During policy decsion, user can see the operator and its name is 'matmul_custom'. To partition the operator, user can get algorithm of tag `'dim'` and partition the annotated dimension, e.g., `kd+` and `n` of the above example. + +```py +def PAS(graph: IRGraph, resource): + for node in graph.nodes(): + if node.name == 'matmul_custom': + algo = node.algorithms('dim') + # partition kd+ + config = dict(idx=0, dim=1, num=resource.ngpus) + subnodes = graph.partition(node, algo, **config) + ... + ... + ... + return graph +``` + +Note: we require user to add type annotation of output and input in the function, to help system understand each identifier number. The non-tensor inputs should be listed at the last and don't need to be represented into annotation. \ No newline at end of file From 0abde99b2961eda590c920d1d89ca307d2c4bb7b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 26 Aug 2022 20:00:01 +0800 Subject: [PATCH 0972/1892] fix typos --- tutorial.md | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tutorial.md b/tutorial.md index d4d410d2..d71ba92e 100644 --- a/tutorial.md +++ b/tutorial.md @@ -4,10 +4,10 @@ ### Annotation for Shape Inference and Transformation -SuperScaler uses annotated dimension to represent an operator (Dimop). +SuperScaler uses annotation to represent an operator (Dimop). The goal of annotation is for 1). shape inference and 2) transformation plan. -To annotate an operator, following example shows the annotation matrix multiplication. An operator has inputs and outputs. The inputs can be tensors or non-tensors, while outputs are usually tensors. +To annotate an operator, following example shows the annotation of matrix multiplication. An operator has inputs and outputs. The inputs can be tensors or non-tensors, while outputs are usually tensors. ```py # annotation: m^ kd+, kd+ n -> m^ n @@ -16,37 +16,39 @@ def operator(x: torch.Tensor, w: torch.Tensor, h: float) -> torch.Tensor: return out ``` -To separate inputs and outputs of an operator, `'->'` is a separation keyword where its left part are inputs and right part are outputs. Each tensor representation is further separated by `','`. +To separate inputs and outputs of an operator, `'->'` is a separation keyword where its left part are inputs and right part are outputs. Inside inputs and outputs region, annotation of each tensor is further separated by `','`. -Each tensor in inputs and outputs is reperented by **{identifiers}{reduction}** on each dimension, like `'m^ kd+'`, `'kd+ n'`, `'m^ n'`, where `m`, `kd` and `n` are identitifiers, `'^'` and `'+'` are reductions. +Every dimension of a tensor is annotated by a template of **{identifiers}{reduction}**, like `'m^ kd+'`, `'kd+ n'`, `'m^ n'`, where `m`, `kd` and `n` are identitifiers, `'^'` and `'+'` are reductions. If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimensions, the first dimension is `m` and the second dimension is `kd`. Dimensions need to be separated by space `' '`. * Identifiers - Identifiers are served for the shape inference. Same identifier of different tensors indicates they have same length at their dimension. + Identifiers are served for shape inference. Same identifier of different tensors indicates they have same length of their dimension. An `identifier` must be one of: 1) symbolic annotation that must match with the criterion of python `str.isidentifier`. - 2) numeric string that must match with python str.isdecimal. This indicates the shape is the same value. numeric string will always have '^' reduction type' + 2) numeric string that must match with python str.isdecimal. This indicates the shape is the same value. Numeric string will always have '^' reduction type Special identifier: - 1) `'*'`: this special identifier indicates the dimension is dynamic, which will automatically get expanded given the shape + 1) `'*'`: this special identifier indicates the dimension is dynamic, which will automatically get expanded given the shape. If there are multiple `*` for different tensors, then they must have same shape for the expanded dimensions, + + e.g., `'* t -> a * t'` can be expanded into `'b c t -> a b c t'` 2) `'?'`: this special identifier indicates the value is not a tensor, which will be ignored - To infer the output shape, the identifiers in output tensors must be 1) appear in inputs, 2) numeric string. + To infer the output shape, the identifiers in output tensors must be 1) appear in inputs or 2) numeric string. * Reductions Reductions are served for transformation plans. The reduction can be one of {`''`, `'+'`, `'^'`}: - * `''` (empty) indicates this dimension can be partitioned, and each output should have this dimension. + * `''` (empty) indicates this dimension can be spatially partitioned, and each output that have this identifier will also be spatially partitioned. - * `'+'` indicates this dimension can be partitioned. Each output doesn't have this and need to do sum-reduction. + * `'+'` indicates this dimension can be spatially partitioned. And each output that doesn't have this identifier will be numerically partitioned (sum-reduction required). * `'^'` means this dimension cannot be partitioned. @@ -63,12 +65,12 @@ If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimens return out ``` - This can be represented by annotating a dimension using brackets `()` that contains multple identifiers (and their reductions), like `'(h t) k'` here for the input tensor. To help system infer the number of `h` and `t` in the annotation, the function requires to put in a same-named argument `h` or `t` (`h=8` here in example). + This can be represented by annotating a dimension using brackets `()`. The bracket contains multple identifiers (and their reductions), like `'(h t)'` here for the first dimension of the input tensor. To help system infer the number of `h` and `t` in the annotation, the function requires to put in a same-named argument `h` or `t` (`h=8` here in example). ## Register Python Functions as Operators -To register a customized "matmul" operator in the runtime, user can simply define the python function and add an decorator on the function with its annotations: +To register a customized "matmul" operator in the runtime, user can simply define a python function and add an decorator on the function with its annotations: ```py @cube.graph.parser.register('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom') From e30538480ef22f495a53a069d2ca3986faf7f656 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Wed, 31 Aug 2022 16:41:15 +0800 Subject: [PATCH 0973/1892] data parallel runnable --- examples/nlp/palm/palm.py | 139 +++++++++++++++++++++----------------- 1 file changed, 78 insertions(+), 61 deletions(-) diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index e6dce76c..d5846b78 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -104,17 +104,44 @@ def forward(self, qk_sim): return bias -class SwiGLU(nn.Module): - - def __init__(self, dim): - super().__init__() - - self.dim = dim - - def forward(self, x): - # x, gate = x.chunk(2, dim=-1) - u, gate = x[:, :, 0:self.dim // 2], x[:, :, self.dim // 2:] - return F.silu(gate) * u +@cube.graph.parser.register('N L^ E^, E^ F^, E^ E^ -> N L^ E^', + name='multi_head_attention') +def multi_head_attention(x: torch.Tensor, qkv_proj: torch.Tensor, + out_proj: torch.Tensor, heads: int, scale: float): + ''' + x: [bs, len, dim] + qkv_proj: [dim, dim + dim_head] + out_proj: [dim, dim] + ''' + bs, n, dim = x.size() + dim_head = dim // heads + + q, kv = torch.matmul(x, qkv_proj).split((dim, dim_head), dim=-1) + q = q.view(bs, n, heads, dim_head).transpose(1, 2) * scale + q = q.reshape(bs, heads * n, dim_head) + trans_kv = kv.transpose(1, 2) + sim = torch.bmm(q, trans_kv).view(bs, heads, n, n) + attn = torch.nn.functional.softmax(sim, dim=-1) + attn = attn.view(bs, heads * n, n) + out = torch.bmm(attn, kv).view(bs, heads, n, dim_head) + out = torch.transpose(out, 1, 2).reshape(bs, n, dim) + out = torch.matmul(out, out_proj) + return out + + +@cube.graph.parser.register('N L^ E^, E^ F^, G^ H^ -> N L^ H^', + name='feedforward') +def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): + ''' + x: [bs, len, dim] + proj1: [dim, 2 * ff_mult * dim] + proj2: [ff_mult * dim, dim] + ''' + x = torch.matmul(x, proj1) + x, gate = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(gate) * x + x = torch.matmul(x, proj2) + return x class PaLMLayer(nn.Module): @@ -129,14 +156,13 @@ def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): # self.norm = RMSNorm(dim) self.norm = torch.nn.LayerNorm(self.dim) - self.qkv_proj = nn.Linear(dim, dim + dim_head, bias=False) - self.attn_out = nn.Linear(dim, dim, bias=False) + self.qkv_proj = torch.nn.Parameter(torch.randn(dim, dim + dim_head)) + self.attn_out_proj = torch.nn.Parameter(torch.randn(dim, dim)) - self.ff = nn.Sequential(nn.Linear(dim, 2 * ff_mult * dim, bias=False), - SwiGLU(2 * ff_mult * dim), - nn.Linear(ff_mult * dim, dim, bias=False)) + self.ff_proj1 = torch.nn.Parameter(torch.randn(dim, 2 * ff_mult * dim)) + self.ff_proj2 = torch.nn.Parameter(torch.randn(ff_mult * dim, dim)) - self.register_buffer("mask", None, persistent=False) + # self.register_buffer("mask", None, persistent=False) def get_mask(self, n, device): if self.mask is not None and self.mask.shape[-1] >= n: @@ -153,49 +179,12 @@ def forward(self, in_x): # pre layernorm x = self.norm(in_x) - ff_out = self.ff(x) - - # q, kv = self.qkv_proj(x).split((self.dim, self.dim_head), dim=-1) - proj = self.qkv_proj(x) - q, kv = proj[:, :, 0:self.dim], proj[:, :, self.dim:] - - # q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) - q = q.view(bs, n, self.heads, self.dim_head).transpose(1, 2) - - # scale - q = q * self.scale - - # similarity - # sim = einsum('b h i d, b j d -> b h i j', q, kv) - q = q.reshape(bs, self.heads * n, self.dim_head) - trans_kv = kv.transpose(1, 2) - sim = torch.bmm(q, trans_kv).view(bs, self.heads, n, n) - - # TODO - # sim = sim + self.alibi_pos_biases(sim) - - # TODO - # causal mask + attn_out = multi_head_attention(x, self.qkv_proj, self.attn_out_proj, + self.heads, self.scale) - # causal_mask = self.get_mask(n, device) - # sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + ff_out = feedforward(x, self.ff_proj1, self.ff_proj2) - # attention - - # attn = sim.softmax(dim=-1) - attn = torch.nn.functional.softmax(sim, dim=-1) - - # out = einsum("b h i j, b j d -> b h i d", attn, kv) - attn = attn.view(bs, self.heads * n, n) - out = torch.bmm(attn, kv).view(bs, self.heads, n, self.dim_head) - - # merge heads - - # out = rearrange(out, "b h n d -> b n (h d)") - out = torch.transpose(out, 1, 2).reshape(bs, n, self.dim) - - merge_heads = self.attn_out(out) + ff_out - return in_x + merge_heads + return in_x + attn_out + ff_out class PaLM(nn.Module): @@ -212,7 +201,6 @@ def __init__(self, self.net = nn.Sequential( nn.Embedding(num_tokens, dim), *[PaLMLayer(dim, dim_head, heads, ff_mult) for _ in range(depth)], - # TODO: RMSNorm(dim), torch.nn.LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False), ) @@ -234,6 +222,34 @@ def PASSingle(graph: IRGraph, resource): return graph +def PASData(graph: IRGraph, resource): + """ + Data Parallel + """ + assert resource.ngpus == 2 + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + batch_dim = node.get_batch_dims()[0] + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + sign = node.signature.split('.')[-1] + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, + algo, + idx=0, + dim=batch_dim, + num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph + + def train(): bs, n, dim = 8, 1024, 512 num_tokens, depth, heads, dim_head = 20000, 1, 8, 64 @@ -251,10 +267,11 @@ def train(): ) dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, n], ), - dtypes=(torch.int32, ), + dtypes=(torch.float32, ), batch_dims=(0, )) - @cube.compile(model, dataloader, PAS=PASSingle) + # @cube.compile(model, dataloader, PAS=PASSingle) + @cube.compile(model, dataloader, PAS=PASData) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) From f409471d6e9aa87ba6b2fe67c492ae6c016c0b71 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 1 Sep 2022 20:45:18 +0800 Subject: [PATCH 0974/1892] initialize hierarchical graph (new segment) --- cube/graph/function/anchor.py | 18 +- cube/graph/function/dimops.py | 10 +- cube/graph/graph.py | 1002 ++++++++++++++++++++------------- cube/graph/segment.py | 238 ++++++++ cube/ir/dtype.py | 31 + cube/ir/operator.py | 53 +- cube/ir/tensor.py | 25 +- cube/logics/model.py | 46 +- 8 files changed, 941 insertions(+), 482 deletions(-) create mode 100644 cube/graph/segment.py diff --git a/cube/graph/function/anchor.py b/cube/graph/function/anchor.py index 18c82803..2fec2200 100644 --- a/cube/graph/function/anchor.py +++ b/cube/graph/function/anchor.py @@ -1,24 +1,26 @@ from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRSubTensor class IRGraphAnchor(IRFwOperation): """ - The anchor function for navigation inside the graph + The anchor function serves for + 1) navigation inside the graph + 2) staging boundary inside the graph + + This operator will eventually be removed from graph, + user doesn't need to manipulate it. """ def __init__(self, signature: str, name: str): super().__init__(name, signature, 0, 1) self.kwargs['name'] = name self.set_output(0, None) + def infer_dtype(self): + return + def infer_shape(self): return True def __repr__(self) -> str: - sign = self.signature.split('.')[-1] - ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_attr()] - dscp = (f"FwOp{self._id}(sign={sign}[{self.name}], " - f"inputs={ins}, " - f"outputs={self.outputs()})") - return dscp + return f"AnchorOp-{self.cid}(name={self.name})" diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 5ad2502c..d0b148ed 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -68,8 +68,10 @@ import string from cube.ir.cten import IRTensor +from cube.ir.dtype import DTypeInferRule from cube.ir.operator import IRFwOperation from cube.algorithm.factory import DistAlgorithmFactory +from cube.ir.tensor import IRSubTensor _kSpecialIdentifiers = ('*', '?') @@ -630,10 +632,12 @@ def oanno(self, index: int) -> Tuple[DimAnno]: def infer_shape(self) -> bool: """ - Shape inference using the matched annotation + Shape and dtype inference using the matched annotation and tensor. @return sucess: True if successfully inferred shape """ + idtypes = [t.dtype for t in self._inputs if isinstance(t, IRTensor)] + odtype = DTypeInferRule.infer(self, idtypes) for oidx, otensor in enumerate(self.outputs()): shape_anno = self.oanno(oidx) shape = [] @@ -643,6 +647,10 @@ def infer_shape(self) -> bool: accum *= self.anno.getlen(identifier) shape.append(accum) otensor.shape = shape + # set output shape + if isinstance(otensor, IRSubTensor): + otensor.parent.dtype = odtype + otensor.dtype = odtype # print(f'=> sign: {self.signature} anno: {self.anno}\n' # f'=> inputs: {self.inputs()}\n' # f'=> outputs: {self.outputs()}') diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 48d8426c..7aea3deb 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,100 +7,80 @@ will be inserted at scheduling time. """ -from typing import Union, Tuple, List, Optional, Dict -import copy -from cube.graph.function.function import MultiRef +from contextlib import contextmanager +from typing import Union, Tuple, List, Optional, Dict, Set + +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.function import Identity, MultiRef +from cube.graph.segment import IRSegment from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.ir.adapter import IRAdapter -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor, StartEnd from cube.algorithm.generics import GenericDistAlgo -class IRSegment(IRCell): - """ - A distributed sub-graph representing a piece of workload in parent IRGraph - """ +class GraphIndex: - def __init__(self, nodes: List[IRCell], inputs: List[IRSubTensor], outputs: List[IRSubTensor]): - super().__init__('segment', '', len(inputs), len(outputs), init_outputs=False) + def __init__(self, gidx: int, sidx: Optional[int]): + # inner-graph index + assert isinstance(gidx, int) + self.gidx = gidx + # inner-segment index + assert sidx is None or isinstance(sidx, int) + self.sidx: Optional[int] = sidx - self._nodes = nodes - self._idevice = [t.device for t in inputs] - self._odevice = [t.device for t in outputs] + def __hash__(self) -> int: + return hash((self.gidx, self.sidx)) - for idx, val in enumerate(inputs): - self.set_input(idx, val) - for idx, val in enumerate(outputs): - self.set_output(idx, val) - # setup device - device = set() - for node in nodes: - device.update(node.device) - self.device = list(device) - # setup whether forward - fnodes = any(isinstance(n, IRFwOperation) for n in nodes) - bnodes = any(isinstance(n, IRBpOperation) for n in nodes) - assert not (fnodes and bnodes), "An IRSegment cannot have both forward nodes and backward nodes" - self._forward = fnodes + def __eq__(self, other: object) -> bool: + assert isinstance(other, GraphIndex), "Cannot compare with non-GraphIndex object" + return self.gidx == other.gidx and self.sidx == other.gidx + + def __lt__(self, other: object) -> bool: + assert isinstance(other, GraphIndex), "Cannot compare with non-GraphIndex object" + if self.gidx < other.gidx: + return True + if self.gidx > other.gidx: + return False + if isinstance(self.sidx, int) and isinstance(other.sidx, int): + return self.sidx < other.sidx + if self.sidx is None and isinstance(other.sidx, int): + return True + return False + + def __le__(self, other: object) -> bool: + return self < other or self == other - @property - def forward(self) -> bool: - return self._forward + def __gt__(self, other: object) -> bool: + return not self <= other - def nodes(self, idx: Optional[int] = None) -> Union[IRCell, List[IRCell]]: - if isinstance(idx, int): - return self._nodes[idx] + def __ge__(self, other: object) -> bool: + return not self < other + + def __sub__(self, offset: int): + assert isinstance(offset, int) + if self.sidx is None: + return GraphIndex(self.gidx - offset, self.sidx) else: - return copy.copy(self._nodes) + return GraphIndex(self.gidx, self.sidx - offset) - def dispatch(self, devid: int, for_mirror=True) -> Optional[IRCell]: - """ - Instantiate from distributed representation to a - device-specific sub-graph. - - The mirror will also be dispatched if it is not None. + def __add__(self, offset: int): + assert isinstance(offset, int) + if self.sidx is None: + return GraphIndex(self.gidx + offset, self.sidx) + else: + return GraphIndex(self.gidx, self.sidx + offset) - Return the dispatched segment - """ - if devid not in self.device: - return None - if len(self.device) == 1 and self.device == [devid]: - return self - itensors = [t for t, device in zip(self.inputs(), self._idevice) if devid in device] - otensors = [t for t, device in zip(self.outputs(), self._odevice) if devid in device] - nodes = [n for n in self.nodes() if devid in n.device] - for idx, adapter in enumerate(nodes): - if isinstance(adapter, IRAdapter): - nodes[idx] = adapter.dispatch(devid) - fseg = IRSegment(nodes, itensors, otensors) - fseg._id = self._id - # dispatch for mirror - if for_mirror and isinstance(self.mirror, IRSegment): - bseg = self.mirror.dispatch(devid, for_mirror=False) - IRCell.make_pair(fseg, bseg) - return fseg - - def to_str(self, skip_attr: bool = False) -> str: - name = ('f' if self.forward else 'b') + 'Segment' - inputs = tuple(t for t in self.inputs() if not (t.is_attr() and skip_attr)) - outputs = tuple(t for t in self.outputs() if not (t.is_attr() and skip_attr)) - return f'{name}{self._id}-{self.device}(inputs={inputs}, outputs={outputs})' - - def __repr__(self): - return self.to_str() + def tuple(self) -> Tuple[int, Optional[int]]: + return (self.gidx, self.sidx) - def extra_repr(self) -> str: - dscp = repr(self) - for node in self.nodes(): - dscp += '\n\t' + repr(node) - return dscp -class IRGraph(IRCell): +class IRGraph(IRSegment): """ IR Graph. The hyperGraph for representing distributed graph. @@ -112,49 +92,35 @@ def __init__(self, outputs: Optional[List[IRTensor]], module_name: str): - self._nodes: List[IRCell] = list() - self._attributes = list() - self._full_tensors: Dict[int, IRFullTensor] = dict() - self._train: bool = any( - isinstance(node, IRBpOperation) or - (isinstance(node, IRSegment) and node.forward) or - (isinstance(node, IRAdapter) and node.forward) for node in nodes - ) - - self._sched = None # the schedule strategy - if inputs is None: inputs = IRGraph.get_inputs(nodes) if outputs is None: outputs = IRGraph.get_outputs(nodes) + super().__init__([], inputs, outputs, module_name) + + self._attributes = set() + self._full_tensors: Set[IRFullTensor] = set() - super().__init__( - name=module_name, - signature=module_name, - input_length=len(inputs), - output_length=len(outputs) + self._train: bool = any( + isinstance(node, IRBpOperation) or + (isinstance(node, IRSegment) and node.forward) or + (isinstance(node, IRAdapter) and node.forward) for node in nodes ) - for idx, tensor in enumerate(inputs): - self.set_input(idx, tensor) - for idx, tensor in enumerate(outputs): - self.set_output(idx, tensor) + self._sched = None # the schedule strategy # set parameters / buffers and full tensors for node in nodes: for tensor in node.inputs() + node.outputs(): if isinstance(tensor, IRSubTensor): - pid = tensor.parent._id - self._full_tensors[pid] = tensor.parent + tensor.parent.clear_producer_consumer() + self._full_tensors.add(tensor.parent) if tensor.is_attr(): - self._attributes.append(tensor) - - for ftensor in self._full_tensors.values(): - ftensor.clear_producer_consumer() + self._attributes.add(tensor) # insert node from nodes for idx, node in enumerate(nodes): - self.attach(node, idx) + self.insert(node, idx) self.reset_dependency() @@ -177,7 +143,7 @@ def reset_dependency(self): node.clear_predecessor() node.clear_successor() # TODO: adapter dependency not set - for ftensor in self._full_tensors.values(): + for ftensor in self._full_tensors: for ptensor, producer in zip(ftensor.ptensors, ftensor.producers): for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): if ptensor.overlap(ctensor): @@ -196,26 +162,11 @@ def attributes(self) -> Tuple[IRSubTensor]: """ return tuple(self._attributes) - def full_tensors(self) -> List[IRSubTensor]: + def full_tensors(self) -> List[IRFullTensor]: """ Return full tensor list """ - return list(self._full_tensors.values()) - - def nodes(self, index: Optional[int] = None) -> Union[IRCell, List[IRCell]]: - """ - Get node at position index - """ - if isinstance(index, int): - if index >= len(self._nodes): - raise RuntimeError( - f"Get node out of range ({index} >= {len(self._nodes)})" - ) - return self._nodes[index] - elif index is None: - return copy.copy(self._nodes) - else: - raise TypeError("Expected index to be None or int") + return list(self._full_tensors) def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: """ @@ -237,11 +188,291 @@ def __call__(self, *args): """ return self.forward(*args) + + # ====================== Graph Accessment ========================= + + def flatten(self) -> List[IRCell]: + """ + Get all the single nodes by opening the segment. + + @return List[] + """ + nodes = [] + for node in self._nodes: + if isinstance(node, IRSegment): + nodes += node._nodes + else: + nodes.append(node) + return nodes + + def index(self, node: IRCell) -> GraphIndex: + """ + Get node index in the graph. + + @param node IRCell: the queried node + + @return index Tuple[int, Optional[int]]: (GraphIndex, SegmentIndex) + + """ + if node in self._nodes: + return GraphIndex(self._nodes.index(node), None) + for idx, check_node in enumerate(self._nodes): + if isinstance(check_node, IRSegment): + if check_node.exist(node): + return GraphIndex(idx, check_node.index(node)) + raise KeyError(f"The queried node: {node} not in the graph.") + + def flatten_index(self, node: IRCell) -> int: + """ + Get node index of all the flatten nodes + + @param node IRCell: the queried node, cannot be IRSegment + + @return index int: the index. + """ + idx = 0 + for check_node in self._nodes: + if isinstance(check_node, IRSegment): + if node in check_node._nodes: + return idx + check_node.index(node) + else: + idx += len(check_node.nnodes) + if check_node == node: + return idx + raise KeyError(f"Node {node} not exist in graph") + + def node(self, index: Union[int, GraphIndex]) -> IRCell: + """ + Get node given the index + + @param index Tuple[Optional[int], int]: the queired index of + (SegmentIndex, Index) + + @return node IRCell: the quried node. + """ + index = GraphIndex(index, None) if isinstance(index, int) else index + assert isinstance(index, GraphIndex) + assert len(index) == 2 and isinstance(index[0], int) + node = self._nodes[index.gidx] + if index.sidx is not None: + assert isinstance(node, IRSegment), "Expected IRSegment" + node = node.index(index.sidx) + return node + + # ========================= Graph Manipulation ======================== + + def remove(self, node: IRCell) -> GraphIndex: + """ + Detach (remove) a node from current graph. + TODO: dataflow dependency update. + + * Producer/consumer relationship: + + All the used input and output tensors inside the node + are removed from consumed and produced tensor list. + + @param node IRCell: the removed node. + + @return index Tuple[int, Optional[int]]: index of the detached node in the graph + """ + index = self.index(node) + + # remove node + if index.sidx is None: + self._nodes.pop(index.gidx) + else: + segment = self._nodes[index.gidx] + assert isinstance(segment, IRSegment), "Internal Error: Removing at a wrong index" + segment.remove(node) + + # update consumer and producer for non-adapter nodes + rm_nodes = node.nodes() if isinstance(node, IRSegment) else [node] + for node in rm_nodes: + # adapter doesn't need to consider producer and consumer + if isinstance(node, IRAdapter): + continue + # update consumer + itensors: List[IRSubTensor] = [] + for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor) and itensor not in itensors: + itensors.append(itensor) + for itensor in itensors: + itensor.parent.rm_consumer(node) + # update producer + otensors: List[IRSubTensor] = [] + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor) and otensor not in otensors: + otensors.append(otensor) + for otensor in otensors: + otensor.parent.rm_producer(node) + ftensor = otensor.parent + if len(ftensor.producers) == 0 and len(ftensor.consumers) == 0: + del self._full_tensors[otensor.parent] + return index + + def insert(self, node: IRCell, index: Union[int, GraphIndex]): + """ + Insert a node into current graph at node index. + TODO: dataflow dependency update. + + * Producer/consumer relationship: + + For the node except IRAdapter, all its input and output tensors + will be recorded in consumed and produced tensor list. + + IRAdapter node will not record the consumer and producer. + + @param node IRCell: the inserted node + @param index Union[int, Tuple[int, Optional[int]]]: the inserted index + """ + index = GraphIndex(index, None) if isinstance(index, int) else index + assert isinstance(index, GraphIndex) + + # update producer and consumer + in_nodes = node.nodes() if isinstance(node, IRSegment) else [node] + for node in in_nodes: + if isinstance(node, IRAdapter): continue + # update consumer + itensors: List[IRSubTensor] = [] + for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor) and itensor not in itensors: + itensors.append(itensor) + for itensor in itensors: + self._full_tensors.add(itensor.parent) + idx = 0 + for consumer in itensor.parent.consumers: + if self.index(consumer) < index: + idx += 1 + else: + break + itensor.parent.add_consumer(node, itensor, idx) + # update producer + otensors: List[IRSubTensor] = [] + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor) and otensor not in otensors: + otensors.append(otensor) + for otensor in otensors: + self._full_tensors.add(otensor.parent) + idx = 0 + for producer in otensor.parent.producers: + if self.index(producer) < index: + idx += 1 + else: + break + otensor.parent.add_producer(node, otensor, idx) + + # insert node + if index.sidx is None: + self._nodes.insert(index.gidx, node) + else: + segment = self._nodes[index.gidx] + assert isinstance(segment, IRSegment), "Expected to be a segment" + segment.insert(node, index.sidx) + + return + + def exist(self, node: IRCell) -> bool: + """ + Check if the node is in the graph + + @param node IRCell: the queried node + @return existence bool: True if exist otherwise False + """ + if node in self._nodes: + return True + for segment in self._nodes: + if not isinstance(segment, IRSegment): + continue + if segment.exist(node): + return True + return False + + @contextmanager + def update(self, node): + """ + Update a node. + TODO: update operator dependency + + e.g., + with graph.modify(node) as node: + node.set_input(0, tensor) + + @param node IRCell: the node that must in the graph + @return node IRCell: the modify node + """ + index = self.remove(node) + yield node + self.insert(node, index) + + def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: + """ + Replace one node by multiple nodes + + TODO: update dataflow dependency + + @param node IRCell: the replaced node + @param new_nodes List[IRCell]: the nodes to be inserted. + + @return index int: the replaced node index + """ + index = self.remove(node) + for new_node in new_nodes[::-1]: + self.insert(new_node, index) + return index + + def group(self, fnodes: List[IRCell]) -> IRSegment: + """! + Group consecutive forward nodes into IRSegment. + TODO: update operator dependency + + The corresponding backward nodes will also be grouped. + + @param nodes List[IRCell]: the consecutive node subset of this graph + + @return segment IRSegment: the grouped segment + """ + assert any(not isinstance(node, (IRBpOperation, IRSegment)) for node in fnodes), \ + "grouped nodes cannot be backward operation or segment" + + findices = [self.index(fnode) for fnode in fnodes] + + # get backward nodes + bnodes = [fnode.mirror for fnode in fnodes[::-1] if fnode.mirror is not None] + bindices = [self.index(bnode) for bnode in bnodes] + + assert all(idx.sidx is None for idx in findices), \ + "Grouping operators that are already in segment is not allowed" + assert all(idx.sidx is None for idx in bindices), \ + "Internal Error: backward operators found in segments" + findices = tuple(idx.gidx for idx in findices) + bindices = tuple(idx.gidx for idx in bindices) + + minfidx, maxfidx = min(findices), max(findices) + assert maxfidx - minfidx + 1 == len(fnodes), \ + "Forward nodes are not consecutive" + + if len(bnodes) > 0: + minbidx, maxbidx = min(bindices), max(bindices) + assert maxbidx - minbidx + 1 == len(bnodes), \ + f"Internal Error: backward nodes are not consecutive. maxbidx: {maxbidx}, minbidx: {minbidx}" + + fsegment = self.segment(fnodes) + bsegment = self.segment(bnodes) if len(bnodes) > 0 else None + IRCell.make_pair(fsegment, bsegment) + + # replace backward + if len(bnodes) > 0: + self._nodes = self._nodes[:minbidx] + [bsegment] + self._nodes[maxbidx+1:] + # replace forward + self._nodes = self._nodes[:minfidx] + [fsegment] + self._nodes[maxfidx+1:] + + return fsegment + + # ========================== Graph Creation ======================== + def segment(self, nodes: List[IRCell]) -> IRSegment: """! - Create a segment (sub-graph) with part of the nodes. - Nodes are allowed to be on different devices. - The grouped segement will not add into graph.nodes(). + Create a segment with part of the nodes. @param nodes List[IRCell]: the subset nodes of this graph @@ -278,185 +509,40 @@ def segment(self, nodes: List[IRCell]) -> IRSegment: segment = IRSegment(nodes, inputs, outputs) return segment - def group(self, nodes: List[IRCell]) -> IRSegment: - """! - Group consecutive nodes into IRSegment. the grouped segment will - replace the nodes in the graph. - - Note: Currently this interface will break the dependency, - it can only be used after user policy - - @param nodes List[IRCell]: the consecutive node subset of this graph - - @return segment IRSegment: the grouped segment - """ - allnodes = self.nodes() - indices = [allnodes.index(n) for n in nodes] - minidx, maxidx = min(indices), max(indices) - assert maxidx - minidx + 1 == len(nodes), "nodes are not consecutive" - segment = self.segment(nodes) - self._nodes = allnodes[:minidx] + [segment] + allnodes[maxidx+1:] - # FIXME: set segment dependnecy - return segment - - def detach(self, node: IRCell, reset_dependency=False) -> int: - """ - Detach (remove) a node from current graph. - - All the used input and output tensors inside the node - are removed from consumed and produced tensor list. - - Return: - index (int): index of the detached node in the graph - """ - if node not in self.nodes(): - raise KeyError(f"node {node} is not in graph.") - index = self._nodes.index(node) - self._nodes.pop(index) - if isinstance(node, IRAdapter): - return index - # update consumer - itensors: List[IRSubTensor] = [] - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor) and itensor not in itensors: - itensors.append(itensor) - for itensor in itensors: - itensor.parent.rm_consumer(node) - # update producer - otensors: List[IRSubTensor] = [] - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor) and otensor not in otensors: - otensors.append(otensor) - for otensor in otensors: - otensor.parent.rm_producer(node) - ftensor = otensor.parent - if len(ftensor.producers) == 0 and len(ftensor.consumers) == 0: - del self._full_tensors[otensor.parent.tid] - if reset_dependency: - self.reset_dependency() - return index - - def attach(self, node: IRCell, index, reset_dependency=False): - """ - Attach (insert) a node into current graph at node index. - - All the used input and output tensors inside the node are - recorded in consumed and produced tensor list. Adapter node - will not record the consumer and producer. - """ - if node in self.nodes(): - raise KeyError(f"node {node} is already in graph.") - self._nodes.insert(index, node) - if isinstance(node, IRAdapter): - return - # update consumer - itensors: List[IRSubTensor] = [] - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor) and itensor not in itensors: - itensors.append(itensor) - for itensor in itensors: - if itensor.parent.tid not in self._full_tensors: - self._full_tensors[itensor.parent.tid] = itensor.parent - idx = 0 - for consumer in itensor.parent.consumers: - if self.nodes().index(consumer) < index: - idx += 1 - else: - break - itensor.parent.add_consumer(node, itensor, idx) - # update producer - otensors: List[IRSubTensor] = [] - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor) and otensor not in otensors: - otensors.append(otensor) - for otensor in otensors: - if otensor.parent.tid not in self._full_tensors: - self._full_tensors[otensor.parent.tid] = otensor.parent - idx = 0 - for producer in otensor.parent.producers: - if self.nodes().index(producer) < index: - idx += 1 - else: - break - otensor.parent.add_producer(node, otensor, idx) - if reset_dependency: - self.reset_dependency() - return - - def flatten(self) -> List[IRCell]: - """ - Flattent the graph by expanding nodes - """ - nodes = [] - for node in self.nodes(): - if isinstance(node, IRSegment): - nodes += node.nodes() - else: - nodes.append(node) - return nodes - - @staticmethod - def get_inputs(nodes: List[IRCell]): - """ - Get all the input tensors the is not generated by nodes - - Inputs - - Returns: - List[IRTensor] - """ - all_outputs = list() - for node in nodes: - all_outputs.extend(node.outputs()) - inputs = list() - for cell in nodes: - for input in cell.inputs(): - if isinstance(input, IRTensor): - if input not in all_outputs: - if input not in inputs: - inputs.append(input) - return inputs - - @staticmethod - def get_outputs(nodes: List[IRCell]): - """ - Get all the output tensors the is not used by nodes - - Args: - This will also consider the successor forward nodes. - If it is required by other outside forward nodes, - put in the outputs list - - Returns: - List[IRTensor] - """ - all_inputs = list() - for node in nodes: - all_inputs.extend(node.inputs()) - outputs = list() - for node in nodes: - for idx, output in enumerate(node.outputs()): - # not consumed tensor - if isinstance(output, IRSubTensor): - if output not in all_inputs: - if output not in outputs: - outputs.append(output) - continue - # consumed by other nodes - succs = node.successors(idx) - fsuccs = [ - fnode for fnode in succs if isinstance(fnode, IRFwOperation) - ] - for fsucc in fsuccs: - if fsucc not in nodes: - if output not in outputs: - outputs.append(output) - return outputs - @staticmethod def from_logic_graph(nodes: List[IRCell], inputs: List[IRFullTensor], outputs: List[IRFullTensor], module_name: str): + """ + Generate IRGraph from logical graph (IRFullTensor) + + Multiref will be inserted: + + e.g., original graph: + ``` + t = producer(xx) + ... + xx = consumer1(t) + ... + xx = consumer2(t) + ... + xx = consumer3(t) + ... + ``` + will be changed into: + ``` + t = producer(xx) + ... + t1, t2 = multiref(t) + xx = consumer1(t1) + ... + t3, t4 = multiref(t2) + xx = consumer2(t3) + ... + xx = consumer3(t4) + ... + ``` + """ # handle multi-consumed tensor consumers: Dict[IRFullTensor, List[IRCell]] = dict() producers: Dict[IRFullTensor, IRCell] = dict() @@ -464,8 +550,8 @@ def from_logic_graph(nodes: List[IRCell], ftensors = set() for ftensor in node.inputs(): # remove redundant tensors within an operator - if isinstance(ftensor, IRFullTensor) and ftensor._id not in ftensors: - ftensors.add(ftensor._id) + if isinstance(ftensor, IRFullTensor) and ftensor.tid not in ftensors: + ftensors.add(ftensor.tid) if ftensor not in consumers: consumers[ftensor] = [] consumers[ftensor].append(node) @@ -474,20 +560,28 @@ def from_logic_graph(nodes: List[IRCell], producers[ftensor] = node for ftensor, cnodes in consumers.items(): if len(cnodes) == 1 or ftensor.is_attr(): continue - itensors = [ftensor.like() for _ in range(len(cnodes))] - for itensor, consumer in zip(itensors, cnodes): + reftensor = ftensor + ctensor = ftensor + while len(cnodes) > 0: + consumer = cnodes.pop(0) + if len(cnodes) > 0: + itensors = [ftensor.like() for _ in range(2)] + multiref = MultiRef(None, [reftensor, 2]) + for idx, itensor in enumerate(itensors): + multiref.set_output(idx, itensor) + multiref.infer_shape() + # insert multiref right before the consumor + idx = nodes.index(consumer) + nodes.insert(idx, multiref) + ctensor, reftensor = itensors + else: + # the last consumer doesn't need multiref + ctensor = reftensor + # update consumer while ftensor in consumer.inputs(): idx = consumer.inputs().index(ftensor) - consumer.set_input(idx, itensor) - # create and insert multiref operation - multiref = MultiRef(None, [ftensor, len(cnodes)]) - for idx, itensor in enumerate(itensors): - multiref.set_output(idx, itensor) - multiref.infer_shape() - idx = nodes.index(producers[ftensor]) + 1 if ftensor in producers else 0 - # idx = nodes.index(cnodes[0]) - nodes.insert(idx, multiref) - + consumer.set_input(idx, ctensor) + # instantiate graph inputs / outputs for idx, tensor in enumerate(inputs): if isinstance(tensor, IRFullTensor): @@ -533,20 +627,26 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis if node not in self.nodes(): raise RuntimeError(f"Op {node} not exsits") - fidx = self.detach(node) fnodes = [node.replicate() for _ in range(times)] + # insert forward - for idx, fnode in enumerate(fnodes): - self.attach(fnode, fidx + idx) + self.replace(node, fnodes) + for fnode in fnodes: + if isinstance(node, IRFwOperation): + fnode.recompute = node.recompute + if isinstance(node.comment, str): + fnode.comment = node.comment + fnode.device = node.device + # insert backward if isinstance(node.mirror, IRBpOperation): - bidx = self.detach(node.mirror) + bnode: IRBpOperation = node.mirror for fnode in fnodes: fnode.gen_backward() - bnodes = [fnode.mirror for fnode in fnodes][::-1] - for idx, bnode in enumerate(bnodes): - self.attach(bnode, bidx + idx) - #TODO: dependency set + bnodes = [fnode.mirror for fnode in fnodes[::-1]] + self.replace(bnode, bnodes) + for bnode in bnodes: + bnode.device = node.device return fnodes def partition(self, node: Union[IRFwOperation, IRDataOperation], @@ -577,71 +677,42 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], """ assert isinstance(algo, GenericDistAlgo) and node == algo.node, \ "The partition algorithm is not initialized for this node" - if node not in self.nodes(): - raise RuntimeError(f"Not Exist: {node}") - if not (isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation)): - raise ValueError("Only allow op to be forward op or data op.") + assert isinstance(node, (IRFwOperation, IRDataOperation)), \ + f"Only allow op to be forward op or data op, but got: {node}" # get partitioned sub-nodes fnodes = algo.instantiate(**config) - if fnodes is None: return fnodes + if fnodes is None: + return None # update forward - findex = self.detach(node) - for idx, fnode in enumerate(fnodes): - self.attach(fnode, findex + idx) - if isinstance(node.comment, str): - fnode.comment = node.comment + self.replace(node, fnodes) + for fnode in fnodes: if isinstance(node, IRFwOperation): fnode.recompute = node.recompute + if isinstance(node.comment, str): + fnode.comment = node.comment + fnode.device = node.device # update backward if isinstance(node.mirror, IRBpOperation): - bindex = self.detach(node.mirror) - bnodes = [fnode.gen_backward() for fnode in fnodes][::-1] - for idx, bnode in enumerate(bnodes): - self.attach(bnode, bindex + idx) - if isinstance(node.mirror.comment, str): - bnode.comment = node.mirror.comment + bnodes = [fnode.gen_backward() for fnode in fnodes[::-1]] + self.replace(node.mirror, bnodes) + for bnode in bnodes: + bnode.device = node.device # update gradient updated = set() for itensor in [t for t in node.inputs() if isinstance(t, IRSubTensor)]: for fnode in itensor.parent.consumers: - bnode = fnode.mirror - if isinstance(bnode, IRBpOperation) and fnode._id not in updated: - idx = self.detach(bnode) - bnode.update() - self.attach(bnode, idx) - updated.add(fnode._id) - # update device - for fnode in fnodes: - fnode.device = node.device - if isinstance(fnode.mirror, IRCell): - fnode.mirror.device = node.device + bnode: IRBpOperation = fnode.mirror + if isinstance(bnode, IRBpOperation) and fnode.cid not in updated: + with self.update(bnode): + bnode.update() + updated.add(fnode.cid) return fnodes - def replace(self, old_nodes: List[IRCell], new_nodes: List[IRCell]): - """! - Replace nodes with node. - - Note we don't check semantic correctness for the replacement. - - @param old_nodes List[IRCell]: nodes to be replaced - @param new_nodes List[IRCell]: nodes to replace in - - @return True - """ - idx = len(self._nodes) - for old_node in old_nodes: - oidx = self.detach(old_node) - idx = min(oidx, idx) - for new_node in new_nodes[::-1]: - self.attach(new_node, idx) - return True - ## Spatial Primitives ## - def assign(self, node: Union[IRFwOperation, IRBpOperation], - ranks: Union[int, Tuple[int]]): + def assign(self, node: Union[IRFwOperation, IRDataOperation], device: int) -> bool: """ Assign an operator (subgraph) to (multiple) rank(s). @@ -656,16 +727,12 @@ def assign(self, node: Union[IRFwOperation, IRBpOperation], @return sucess bool: always true """ - assert isinstance(node, (IRFwOperation, IRDataOperation)), f"Only forward and data operation can be assigned to device, but got {node}" - assert node in self._nodes, f"{node} is not in the graph" - ranks = (ranks,) if isinstance(ranks, int) else ranks - assert all([isinstance(rank, int) for rank in ranks]), "Expected rank to be int" - nodes = [node] if len(ranks) == 1 else self.replicate(node, times=len(ranks)) - for node, rank in zip(nodes, ranks): - node.device = rank - if isinstance(node.mirror, IRBpOperation): - bnode = node.mirror - bnode.device = rank + assert isinstance(node, (IRFwOperation, IRDataOperation)), \ + f"Only forward and data operation can be assigned to device, but got {node}" + assert self.exist(node), f"{node} is not in the graph" + node.device = device + if node.mirror is not None: + node.mirror.device = device return True ## Schedule Policy Primitives ## @@ -677,6 +744,7 @@ def happen_before(self, node1: IRCell, node2: IRCell, skip=None) -> bool: Returns: Boolean """ + raise NotImplementedError("dependency is not supported yet") skip = list() if skip is None else skip if node1 in skip: return False @@ -732,8 +800,8 @@ def schedule(self, node1: IRCell, action: str, node2: IRCell) -> bool: for idx in range(idx1+1, idx2+1): if self.depends(node1, self._nodes[idx]): return False - self.detach(node1) - self.attach(node1, idx2) + self.remove(node1) + self.insert(node1, idx2) return True # node1 -> node2 if action == 'before': @@ -742,8 +810,8 @@ def schedule(self, node1: IRCell, action: str, node2: IRCell) -> bool: for idx in range(idx2, idx1): if self.depends(self._nodes[idx], node1): return False - self.detach(node1) - self.attach(node1, idx2) + self.remove(node1) + self.insert(node1, idx2) return True raise KeyError(f"Unknown scheduling action {action}") @@ -807,10 +875,135 @@ def add_schedule(self, nodes: List[IRCell]) -> bool: post.add_predecessor(input_index=-1, cell=prev) return True + # ================= staging primitives ================== + + def staging(self, nodes: Tuple[IRFwOperation]): + """! + Group forward / dataloader operators into sequential stages. + The corresponding backward operators will also be grouped into stages + Cross-stage dataflow will be limited to neighbor stages. + This should be called before any operator partition. + + The transformation and temporal scheduling can only be applied within each stage. + For example, after staging, user cannot schedule a (transformed) node + from one stage to another stage. + + The stage is a concept that is only about logical separation of nodes, + it doesn't have additional constraints for device assignment. + + Changes will be made: + + 1). Identity creation: + If a non-attribute tensor is produced / consumed not in + neighbor stages, + e.g., + stage 1: t1 = producer() + stage 2: ... + stage 3: xx = consume(t1) + stage 4: ... + stage 5: xx = consume(t1) + then Identity nodes will be created for every device in stage2: + stage 1: t1 = producer() + stage 2: t2 = identity(t1) + stage 3: xx = consume(t2) + stage 4: t3 = identity(t2) + stage 5: xx = consume(t3) + + 2). REMOVED: Multiref Modification: + If a non-attribute tensor has multiref node to different devmeshes, + e.g., + stage 1: t1, t2 = multiref(t) + stage 2: xx = consume(t1) + stage 3: ... + stage 4: xx = consume(t2) + then the multiref will be transfered into identity operator: + stage 1: t1 = multiref(t) + stage 2: xx = consume(t1) + t2 = identity(t1) + stage 3: t3 = identity(t2) + stage 4: xx = consume(t3) + + @param starts Tuple[int]: the start index of each stage + @return None + """ + assert all(isinstance(node, (IRFwOperation, IRDataOperation)) for node in nodes), \ + f"Find node is not IRFwOperation or IRDataOperation: {node}" + assert all(node in self._nodes for node in nodes), \ + f"Exist node is not in graph nodes" + starts = tuple(self._nodes.index(node) for node in nodes) + assert len(starts) > 0 + starts = (0,) + starts if starts[0] != 0 else starts + + last_fidx = 0 + for idx, node in enumerate(self._nodes): + if not isinstance(node, IRBpOperation): + last_fidx = idx + + fstages = [] + for sid in range(len(starts)): + begin = starts[sid] + if sid == len(starts) - 1: + end = last_fidx + 1 + else: + end = starts[sid+1] + if begin == end: + continue + assert begin < end + fstages.append(self._nodes[begin:end]) + + # grouping into index + fsegs: List[IRSegment] = [] + for sid in range(len(fstages)): + fsegs.append(self.group(fstages[sid])) + + fidxs: List[int] = [self._nodes.index(seg) for seg in fsegs] + bidxs: List[Optional[int]] = [self._nodes.index(seg.mirror) if seg.mirror is not None else None for seg in fsegs] + + def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: + identity = Identity('', [tensor]) + identity.infer_shape() + identity.set_output(0, identity.output(0).tosub()) + # insert forward + self.insert(identity, GraphIndex(fidxs[sid], 0)) + # insert backward + if self.train: + bnode = identity.gen_backward() + self.insert(bnode, GraphIndex(bidxs[sid], -1)) + return identity + + # create identity op for cross-stage dataflow + # the gradient flow of neighbor stages is automatically guaranteed + for ftensor in self.full_tensors(): + if ftensor.is_grad() or ftensor.is_attr(): continue + assert len(ftensor.producers) <= 1, \ + "The staging interface should be called before any operator partition." + if len(ftensor.consumers) == 0: continue + producer = ftensor.producers[0] + # TODO: robustness + psid = fidxs.index(self.index(producer).gidx) + ptensor = ftensor.ptensors[0] + out = ptensor + curr_sid = psid + for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + assert ctensor == ptensor, "The staging interface should be called before any operator partition." + # TODO: robustness + csid = fidxs.index(self.index(consumer).gidx) + if curr_sid == csid: continue + for sid in range(curr_sid + 1, csid): + identity = insert_identity(out, sid) + out = identity.output(0) + # update consumer and its backward + with self.update(consumer) as consumer: + tidx = consumer.inputs().index(ptensor) + consumer.set_input(tidx, out) + if self.train: + with self.update(consumer.mirror) as bnode: + bnode.update() + curr_sid = csid # ================= Other optimizations ================== - def recompute(self, nodes: List[IRFwOperation]) -> bool: + def recompute(self, nodes: Union[List[IRFwOperation], IRSegment]) -> bool: """! Recompute a set of nodes. The forward nodes will be assigned with a unique recompute group id. A forward not can not be recomputed in different recompute groups. @@ -819,34 +1012,43 @@ def recompute(self, nodes: List[IRFwOperation]) -> bool: @return success boolean: always success """ - assert all(isinstance(fnode, IRFwOperation) for fnode in nodes), "require forward operations" - recompute_group_id = IDGenerator().gen_cell_id() - for fnode in nodes: - fnode.recompute = recompute_group_id + assert all(isinstance(nodes, IRFwOperation)) or isinstance(nodes, IRSegment), \ + "Require forward nodes or a single segment" + + recompute_group_id: int = IDGenerator().gen_cell_id() + + if isinstance(nodes, IRSegment): + assert nodes.forward, "Can only apply recompute on segment node" + for fnode in nodes.node(): + fnode.recompute = recompute_group_id + else: + indices = [self.index(node) for node in nodes] + if all(idx[1] is None for idx in indices): + assert all(idx[0] == indices[0][0] for idx in indices), \ + f"Cross-stage recompute is not allowed yet." + elif all(idx[1] is not None for idx in indices): + assert all(idx[0] == indices[0][0] for idx in indices), \ + f"Cross-stage recompute is not allowed yet." + else: + assert False, f"Cross-stage recompute is not allowed yet." + for fnode in nodes: + fnode.recompute = recompute_group_id + return True - def __repr__(self): + def __repr__(self) -> str: dscp = f"Graph{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" return dscp - def extra_repr(self): + def extra_repr(self) -> str: dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" # inputs dscp += f"Inputs: {self.inputs()}\n" - # nodes for node in self._nodes: - # succ_node_ids = [node._id for node in node.successors()] - # succ_node_ids = [None] * len(node.outputs()) - # for out_idx in range(len(node.outputs())): - # node_list = [snode._id for snode in node.successors(out_idx)] - # succ_node_ids[out_idx] = node_list - # dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}" dscp += f"\n{node}" + if isinstance(node, IRSegment): + for subnode in node.nodes(): + dscp += f"\n\t{subnode}" # outputs dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" return dscp - - def module_repr(self): - return repr(self) - - diff --git a/cube/graph/segment.py b/cube/graph/segment.py new file mode 100644 index 00000000..0acef0e0 --- /dev/null +++ b/cube/graph/segment.py @@ -0,0 +1,238 @@ +from typing import Union, Tuple, List, Optional, Dict +import copy + +from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.cten import IRTensor, IRCell +from cube.ir.operator import IRFwOperation, IRBpOperation +from cube.ir.adapter import IRAdapter + + + +class IRSegment(IRCell): + """ + A distributed sub-graph representing a piece of workload in parent IRGraph + """ + + def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRSubTensor], name='segment'): + super().__init__(name, '', len(inputs), len(outputs), init_outputs=False) + + self._nodes: List[IRCell] = nodes + self._idevice = [t.device for t in inputs] + self._odevice = [t.device for t in outputs] + + for idx, val in enumerate(inputs): + self.set_input(idx, val) + for idx, val in enumerate(outputs): + self.set_output(idx, val) + + self._have_forward = any(isinstance(n, IRFwOperation) for n in nodes) + self._have_backward = any(isinstance(n, IRBpOperation) for n in nodes) + + @property + def forward(self) -> bool: + return self._have_forward + + # ========================= Basic Graph access ======================= + + @property + def device(self) -> List[int]: + devices = set() + for node in self._nodes: + devices.update(node.device) + devices = list(devices) + devices.sort() + return devices + + @property + def nnodes(self) -> int: + """ + Get total node number + + @return number int: the number of nodes + """ + return len(self._nodes) + + def nodes(self, idx: Optional[int] = None) -> Union[IRCell, List[IRCell]]: + """ + Get all the nodes. + + @return nodes List[IRCell]: all the nodes + """ + if isinstance(idx, int): + return self._nodes[idx] + else: + return copy.copy(self._nodes) + + def node(self, index: int) -> IRCell: + """ + Get node at position index + + @param index int: the node index + + @return node IRCell: the node. + """ + return self._nodes[index] + + def index(self, node: IRCell) -> int: + """ + Get node index. + + @param node IRCell: the queried node + + @return index int: the index + """ + return self._nodes.index(node) + + # ====================== Basic Graph manipulations ====================== + + def insert(self, node: IRCell, index: int): + """ + Insert a node at index. + + TODO: update input and output + + @param node IRCell: the inserted node + @param index int: the index + + """ + assert node not in self._nodes, f"duplicated insertation of node: {node}" + self._nodes.insert(index, node) + + def remove(self, node: IRCell) -> int: + """ + Remove a node at index + + # TODO: update input and output + + @param node IRCell: the removed node + + @return index int: the removed index + """ + assert node in self._nodes, f"The removed node doesn't exist" + index = self._nodes.index(node) + self._nodes.pop(index) + return index + + def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: + """ + Replace one node by multiple nodes + + # TODO: update input and output + + @param node IRCell: the replaced node + @param new_nodes List[IRCell]: the nodes to be inserted. + + @return index int: the replaced node index + """ + idx = self.remove(node) + self._nodes = self._nodes[:idx] + list(new_nodes) + self._nodes[idx:] + return idx + + def exist(self, node: IRCell) -> bool: + """ + Check if the node is in this graph + + @param node IRCell: the queried node + + @return exsit bool: True if exist otherwise False + """ + return node in self._nodes + + # ====================== Graph Generations ============================ + + @staticmethod + def get_inputs(nodes: List[IRCell]): + """ + Get all the input tensors that are required by nodes. + + @param nodes List[IRCell]: the nodes + + @return inputs List[IRTensor]: the input tensors + """ + all_outputs = list() + for node in nodes: + all_outputs.extend(node.outputs()) + inputs = list() + for node in nodes: + for input in node.inputs(): + if isinstance(input, IRTensor): + if input not in all_outputs: + if input not in inputs: + inputs.append(input) + return inputs + + @staticmethod + def get_outputs(nodes: List[IRCell]): + """ + Get tensors that are produced but not consumed by nodes + + As long as the tensor is consumed in by the nodes, it will + not be in the output. A tensor will not appear as output if it + is double-consumed both outside and inside the nodes. + + @param nodes List[IRCell]: the nodes + + @return outputs List[IRTensor]: the output tensors + """ + all_inputs = list() + for node in nodes: + all_inputs.extend(node.inputs()) + outputs = list() + for node in nodes: + for output in node.outputs(): + # not consumed tensor + if isinstance(output, IRTensor): + if output not in all_inputs: + if output not in outputs: + outputs.append(output) + continue + return outputs + + + ###### ============ Transformation Primitives ============ ####### + + + def dispatch(self, devid: int, for_mirror=True) -> Optional[IRCell]: + """ + Instantiate from distributed representation to a + device-specific sub-graph. + + The mirror will also be dispatched if it is not None. + + Return the dispatched segment + """ + if devid not in self.device: + return None + if len(self.device) == 1 and self.device == [devid]: + return self + itensors = [t for t, device in zip(self.inputs(), self._idevice) if devid in device] + otensors = [t for t, device in zip(self.outputs(), self._odevice) if devid in device] + nodes = [n for n in self.nodes() if devid in n.device] + for idx, adapter in enumerate(nodes): + if isinstance(adapter, IRAdapter): + nodes[idx] = adapter.dispatch(devid) + fseg = IRSegment(nodes, itensors, otensors) + fseg._id = self._id + # dispatch for mirror + if for_mirror and isinstance(self.mirror, IRSegment): + bseg = self.mirror.dispatch(devid, for_mirror=False) + IRCell.make_pair(fseg, bseg) + return fseg + + + # ========================== Graph Visualize ================================ + + def to_str(self, skip_attr: bool = False) -> str: + name = ('f' if self.forward else 'b') + 'Segment' + inputs = tuple(t for t in self.inputs() if not (t.is_attr() and skip_attr)) + outputs = tuple(t for t in self.outputs() if not (t.is_attr() and skip_attr)) + return f'{name}{self._id}-{self.device}(inputs={inputs}, outputs={outputs})' + + def __repr__(self): + return self.to_str() + + def extra_repr(self) -> str: + dscp = repr(self) + for node in self.nodes(): + dscp += '\n\t' + repr(node) + return dscp diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index 2bf0408a..7a81638c 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -1,5 +1,7 @@ +from typing import List from enum import Enum + class IRDType(Enum): float64 = 'float64' float16 = 'float16' @@ -13,6 +15,35 @@ class IRDType(Enum): unknown = 'unknown' +class DTypeInferRule: + """ + Infer the output shape according to given input shapes. + This will follow the dtype promotion rule, which is same with PyTorch. + + Reference: + https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc + + complex > floating > integral > boolean + """ + @staticmethod + def infer(node, dtypes: List[IRDType]) -> IRDType: + dtypes = [dtype for dtype in dtypes if dtype != IRDType.unknown] + if IRDType.unknown in dtypes: + raise RuntimeError(f"Find an unkown dtype") + if IRDType.float32 in dtypes and IRDType.float16 in dtypes: + raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") + # in priority: fp32 > fp16 > bool > int64 > int16 > + priority = [ + IRDType.float64, IRDType.float32, IRDType.float16, + IRDType.int64, IRDType.int32, IRDType.int16, IRDType.int8, + IRDType.boolean + ] + for dtype in priority: + if dtype in dtypes: + return dtype + return IRDType.unknown + + float64 = IRDType.float64 float16 = IRDType.float16 float32 = IRDType.float32 diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 86aa5f8d..9bd1c82d 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -1,39 +1,11 @@ -from typing import Any, Optional, Tuple, Union, List +from typing import Optional, Tuple import copy from cube.ir.cten import IRCell, IRTensor from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory from cube.ir.unique import IDGenerator - - -class IRBaseOp(IRCell): - - def __init__(self, name: str, signature: str, - inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): - super().__init__(name, signature, len(inputs), len(outputs), init_outputs=False) - self.kwargs = kwargs - assert all(isinstance(t, IRTensor) for t in inputs), "expect all inputs to be IRTensors" - assert all(isinstance(t, IRTensor) for t in outputs), "expect all outputs to be IRTensors" - for idx, itensor in enumerate(inputs): - self.set_input(idx, itensor) - for idx, otensor in enumerate(outputs): - self.set_output(idx, otensor) - - def infer_shape(self) -> bool: - """ - Infer output value shape - """ - raise NotImplementedError - - def replicate(self): - """ - Replicate the Operation - """ - node = type(self)( - self.name, self.signature, self.inputs(), self.outputs(), **self.kwargs) - node._id = self._id - return node +from cube.ir.dtype import IRDType, DTypeInferRule class IRFwOperation(IRCell): @@ -63,6 +35,20 @@ def __init__(self, for idx, output in enumerate(outputs): self.set_output(idx, output) + def infer_dtype(self): + """ + Infer output value dtype. + + By default will follow the same dtype promotion rule with PyTorch. + """ + itensors = [t for t in self.inputs() if isinstance(t, IRTensor)] + assert len(itensors) > 0, "Missing input tensors, need to customize the infer rule" + odtype = DTypeInferRule.infer(self, [t.dtype for t in itensors]) + assert odtype != IRDType.unknown, f"{self} : {[t.dtype for t in itensors]}" + otensors = [t for t in self.outputs() if isinstance(t, IRTensor)] + for tensor in otensors: + tensor.dtype = odtype + def infer_shape(self): """ Infer output value shape @@ -154,7 +140,7 @@ def gen_backward(self) -> IRCell: def __repr__(self) -> str: sign = self.signature.split('.')[-1] - ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_attr()] + ins = [t for t in self.inputs() if isinstance(t, IRTensor) and not t.is_attr()] dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " f"inputs={ins}, " f"outputs={self.outputs()})") @@ -200,9 +186,8 @@ def update(self): wrapped with IRGraph detach and attach: ``` - idx = graph.detach(node) - node.update() - graph.attach(node, idx) + with graph.update(node): + node.update() ``` """ fnode: IRFwOperation = self.mirror diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 885ec143..a9d9359e 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -359,6 +359,8 @@ def rm_consumer(self, cell: IRCell) -> int: idx = self._consumers.index(cell) self._consumers.pop(idx) self._ctensors.pop(idx) + for t in self._ctensors: + t._dirty_grad = True return idx def clear_producer_consumer(self) -> int: @@ -392,12 +394,31 @@ def requires_grad(self): def requires_grad(self, val: bool): self._requires_grad = val if val and self.grad is None: - self.grad = IRFullTensor(self.shape, 'g' + self.name, False).as_grad() + self.grad = IRFullTensor(self.shape, 'g' + self.name, False, dtype=self.dtype).as_grad() elif not val and self.grad is not None: self.grad = None for tensor in self._ctensors + self._ptensors: tensor._dirty_grad = True + @property + def dtype(self) -> IRDType: + """ + Tensor data type + """ + return self._dtype + + @dtype.setter + def dtype(self, val: IRDType): + """ + Set data type. + It's gradient data type will also be set. + """ + if not isinstance(val, IRDType): + raise TypeError(f"Expected IRDType but got {val}") + self._dtype = val + if isinstance(self.grad, IRTensor): + self.grad.dtype = val + def as_param(self): """ Set the tensor as trainable parameter @@ -847,5 +868,5 @@ def extra_repr(self) -> str: anno = 'w' if self.is_grad(): anno = 'g' - dscp = f'{anno}{self._id}(id={self._id}, shape={self.shape}, dev={self.device}, ind=[{self._indmap}], val={self._valmap})' + dscp = f'{anno}{self._id}(pid={self.parent.tid}, shape={self.shape}, dev={self.device}, ind=[{self._indmap}], val={self._valmap})' return dscp diff --git a/cube/logics/model.py b/cube/logics/model.py index 10696476..a6dc1a9d 100644 --- a/cube/logics/model.py +++ b/cube/logics/model.py @@ -2,37 +2,11 @@ import copy from cube.graph.graph import IRGraph -from cube.ir.dtype import IRDType +from cube.ir.dtype import IRDType, DTypeInferRule from cube.ir.tensor import IRSubTensor from cube.ir.operator import IRFwOperation -class DTypeInferRule: - """ - According to promotion doc: - https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc - - complex > floating > integral > boolean - """ - @staticmethod - def infer(node, dtypes: List[IRDType]) -> IRDType: - dtypes = [dtype for dtype in dtypes if dtype != IRDType.unknown] - if IRDType.unknown in dtypes: - raise RuntimeError(f"Find an unkown dtype") - if IRDType.float32 in dtypes and IRDType.float16 in dtypes: - raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") - # in priority: fp32 > fp16 > bool > int64 > int16 > - priority = [ - IRDType.float64, IRDType.float32, IRDType.float16, - IRDType.int64, IRDType.int32, IRDType.int16, IRDType.int8, - IRDType.boolean - ] - for dtype in priority: - if dtype in dtypes: - return dtype - return IRDType.unknown - - def forward(graph: IRGraph, *args) -> IRGraph: """ Forward the IRGraph, replacing all the intermediate tensors @@ -45,17 +19,15 @@ def forward(graph: IRGraph, *args) -> IRGraph: for idx, (itensor, arg) in enumerate(zip(itensors, args)): graph.set_input(idx, arg) for producer in copy.copy(itensor.parent.producers): - pidx = graph.detach(producer) - while itensor in producer.outputs(): - oidx = producer.outputs().index(itensor) - producer.set_output(oidx, arg) - graph.attach(producer, pidx) + with graph.update(producer): + while itensor in producer.outputs(): + oidx = producer.outputs().index(itensor) + producer.set_output(oidx, arg) for consumer in copy.copy(itensor.parent.consumers): - cidx = graph.detach(consumer) - while itensor in consumer.inputs(): - iidx = consumer.inputs().index(itensor) - consumer.set_input(iidx, arg) - graph.attach(consumer, cidx) + with graph.update(consumer): + while itensor in consumer.inputs(): + iidx = consumer.inputs().index(itensor) + consumer.set_input(iidx, arg) while itensor in graph.outputs(): oidx = graph.outputs().index(itensor) graph.set_output(oidx, arg) From 480eb4748eec53f74c29cc23a57df34dbba6eb0c Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 1 Sep 2022 21:02:57 +0800 Subject: [PATCH 0975/1892] single device mode fix --- cube/profiler/memory.py | 8 +++++++- examples/nlp/gpt/infer.py | 0 examples/wrf/policy/h_halo.py | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 examples/nlp/gpt/infer.py create mode 100644 examples/wrf/policy/h_halo.py diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index e9cb8a13..7b82ca93 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -3,7 +3,13 @@ from cube.profiler.timer import print_each_rank def memory_summary(): - rank = torch.distributed.get_rank() + import os + single_device_mode = os.environ.get('SINGLE_DEV_MODE') + if single_device_mode: + rank = 0 + else: + rank = torch.distributed.get_rank() + # memory measurement mem = torch.cuda.max_memory_allocated() # mem = torch.cuda.max_memory_reserved() diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/wrf/policy/h_halo.py b/examples/wrf/policy/h_halo.py new file mode 100644 index 00000000..a615ce11 --- /dev/null +++ b/examples/wrf/policy/h_halo.py @@ -0,0 +1,18 @@ +from cube.graph import IRGraph +from cube.graph.function import IRConv2D, IRConv3D + +def PAS(graph: IRGraph, resource): + for node in graph.nodes(): +# graph.assign(node, 0) + if isinstance(node, IRConv3D): + sub_nodes = list() + algo = node.algorithms('halo') + sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + # sub_nodes = graph.replicate(node, times=resource.ngpus) + + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + print(graph.extra_repr()) + return graph From e9116ece793a20cb2bd2231bf803ae999d16c094 Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 1 Sep 2022 21:06:55 +0800 Subject: [PATCH 0976/1892] add PyTorch example: regressive generation of GPT inference --- examples/nlp/blocks/attention.py | 93 ++++++++++++++++++++++++++ examples/nlp/blocks/encoder.py | 39 ++++++++++- examples/nlp/gpt/infer.py | 110 +++++++++++++++++++++++++++++++ examples/nlp/gpt/model.py | 100 +++++++++++++++++++++++++++- 4 files changed, 338 insertions(+), 4 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 7a953781..3f62396a 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -88,6 +88,62 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, return output +from typing import Optional, Tuple + +# @cube.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') +def one_attention(hidden_states: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor]], + q_proj: torch.Tensor, q_bias: torch.Tensor, + k_proj: torch.Tensor, k_bias: torch.Tensor, + v_proj: torch.Tensor, v_bias: torch.Tensor, + out_proj: torch.Tensor, out_bias: torch.Tensor, + h: int, scale: float, dropout_p: float, mask=True): + num_head = h + L, N = hidden_states.size(0), hidden_states.size(1) + dim_head = q_proj.size(0) // num_head + + q = torch.nn.functional.linear(hidden_states, q_proj, q_bias) # L N E, (h d) E -> L N (h d) + k = torch.nn.functional.linear(hidden_states, k_proj, k_bias) # L N E, (h d) E -> L N (h d) + v = torch.nn.functional.linear(hidden_states, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-3) + v = torch.cat((past_value, v), dim=-3) + + q_N = hidden_states.size(1) + + k_L = k.size(0) + v_L = v.size(0) + + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(k_L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(v_L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # # attention mask + # if mask: # (N h) L L -> (N h) L L + # attn = attn.view(N, num_head, L, L) + # ones = torch.ones((N, L, L), device=attn.device) + # mask = torch.tril(ones) + # mask = mask.view(N, 1, L, L) + # mask = (mask < 0.5) + # attn = attn.masked_fill_(mask, -10000.0) + # attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + output = torch.nn.functional.linear(output, out_proj, None) # L N (h d), E E -> L N E + return output + class MultiHeadSelfAttention(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): @@ -156,3 +212,40 @@ def forward(self, query: torch.Tensor, key: torch.Tensor): ) attn = attn + self.out_bias return attn + + +class MultiHeadOneAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): + super().__init__() + self.inner_dim = inner_dim + self.num_heads = num_heads + self.head_dim = inner_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # Q + self.q_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) + self.q_bias = torch.nn.Parameter(torch.rand(inner_dim)) + # K + self.k_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) + self.k_bias = torch.nn.Parameter(torch.rand(inner_dim)) + # V + self.v_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) + self.v_bias = torch.nn.Parameter(torch.rand(inner_dim)) + # Out + self.out_proj = torch.nn.Parameter(torch.rand(embed_dim, inner_dim)) + self.out_bias = torch.nn.Parameter(torch.rand(embed_dim)) + + from typing import Optional, Tuple + + def forward(self, query: torch.Tensor, layer_past: Optional[Tuple[torch.Tensor]] = None): + attn = one_attention( + query, layer_past, + self.q_proj, self.q_bias, + self.k_proj, self.k_bias, + self.v_proj, self.v_bias, + self.out_proj, self.out_bias, + self.num_heads, self.scaling, self.dropout_p, mask=True + ) + attn = attn + self.out_bias + return attn \ No newline at end of file diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 83645f61..5582bd36 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -1,5 +1,5 @@ import torch -from examples.nlp.blocks.attention import MultiHeadSelfAttention +from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention from examples.nlp.blocks.mlp import MLP @@ -30,3 +30,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.dropout(x) x = x + residual return x + + +class EncoderInferLayer(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, + attn_hidden_dim: int, ffn_hidden_dim: int, seqlen: int = -1, + dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): + super().__init__() + self.self_attn_partial = MultiHeadOneAttention( + embed_dim, num_heads, attn_hidden_dim, atten_dropout + ) + self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) + self.dropout = torch.nn.Dropout(p=dropout) + self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + # id-embed + pos-embed + tmp_batch_size = 1 + self.past_embed_key = torch.nn.Parameter(torch.rand(seqlen, tmp_batch_size, embed_dim)) + self.past_embed_value = torch.nn.Parameter(torch.rand(seqlen, tmp_batch_size, embed_dim)) + self.past_embed = tuple([self.past_embed_key, self.past_embed_value]) + print(f'self.past_embed.type = {type(self.past_embed)}') + + # def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn_partial(x, self.past_embed) + x = self.dropout(x) + x = x + residual + + residual = x + x = self.final_layer_norm(x) + x = self.mlp(x) + x = self.dropout(x) + x = x + residual + return x \ No newline at end of file diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py index e69de29b..30557002 100644 --- a/examples/nlp/gpt/infer.py +++ b/examples/nlp/gpt/infer.py @@ -0,0 +1,110 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/nlp/gpt/train.py --policy PASMeshShard --fp16 +""" + + +import torch + +from examples.nlp.gpt.model import GPTInfer, GPTInferDataLoader +from examples.nlp.gpt.model import GPTDataLoader + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary, model_summary + +from examples.nlp.gpt.policy.mpmd import PASMegatron as PAS +import examples.nlp.gpt.policy.spmd as spmd +import examples.nlp.gpt.policy.mpmd as mpmd + +import argparse + +parser = argparse.ArgumentParser(description='GPT Train') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') +parser.add_argument('--fp16', action='store_true', default=False, + help='use fp16 for the training') +parser.add_argument('--local_rank', type=int, default=0) +args = parser.parse_args() + +cube.init() + +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +policies = [policy for policy in policies if policy.startswith('PAS')] +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + +def inter(): + + batch_size = 1 + + model = GPTInfer() + # model = model if not args.fp16 else model.half() + model.eval() + dataloader = GPTInferDataLoader(batch_size) + + for i in range(10): + input_ids, position_ids = next(dataloader) + print(f'input_ids = {input_ids} [{input_ids.size()}], position_ids = {position_ids} [{position_ids.size()}]') + output = model(input_ids, position_ids) + print(f'output = {output}') + + + # model = cube.SemanticModel(model, dataloader.shapes) + # @cube.compile(model, dataloader, PAS=PAS, override=True) + # def train_iter(model, dataloader): + # input_ids, position_ids = next(dataloader) + # loss = model(input_ids, position_ids) + # # loss.backward() + # model = model.get_gen_module() + # + # # optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + # + # import os + # single_device_mode = os.environ.get('SINGLE_DEV_MODE') + # print(f"single_dev_mode = {single_device_mode}") + # if not single_device_mode: + # torch.distributed.barrier() + # print_each_rank('model weight consumpition:', rank_only=0) + # memory_summary() + # + # CudaTimer(enable=False).warmup() + # iter_num = 4 + # warmup = 2 + # for step in range(iter_num): + # # if step == 0: + # # model_summary(model, next(dataloader)) + # + # if step >= warmup: + # CudaTimer(enable=True).start('e2e') + # train_iter(model, dataloader) + # # optimizer.step() + # # optimizer.zero_grad() + # if step >= warmup: + # CudaTimer().stop('e2e') + # + # if step == 0: + # print_each_rank('passed first iteration') + # if (step + 1) % 10 == 0: + # print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + # + # print_each_rank('e2e time (ms) per iteration: {} ms'.format( + # CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + # CudaTimer().print_all(times=iter_num-warmup) + # memory_summary() + + +if __name__ == '__main__': + + cube.init() + inter() \ No newline at end of file diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index c9b4480c..296598d5 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -1,7 +1,7 @@ import torch -from examples.nlp.blocks.encoder import EncoderLayer +from examples.nlp.blocks.encoder import EncoderLayer, EncoderInferLayer import cube @@ -10,11 +10,21 @@ class Config: num_embeddings = 50432 seqlen = 1024 - # 340 M model + # toy model embed_dim = 1024 - layers = 8 # 24 + layers = 8 # 96 attention_heads = 16 + # # 1 layer of 175B model + # embed_dim = 12288 + # layers = 1 # 96 + # attention_heads = 96 + # + # # 350 M model (Medium)* + # embed_dim = 1024 + # layers = 24 + # attention_heads = 16 + # 1.3 B model # embed_dim = 2048 # layers = 24 @@ -40,6 +50,11 @@ class Config: # layers = 48 # attention_heads = 64 + # 175 B model* + # embed_dim = 12288 + # layers = 96 + # attention_heads = 96 + attn_hidden_dim = embed_dim ffn_hidden_dim = embed_dim * 4 dropout = 0.2 @@ -91,6 +106,51 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): return loss +class GPTInfer(torch.nn.Module): + + def __init__(self): + super().__init__() + cfg = Config() + + # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) + self.embedw = torch.nn.Parameter(torch.rand(cfg.num_embeddings, cfg.embed_dim) / 128) + self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) + self.embed_dropout = torch.nn.Dropout() + + self.layers = torch.nn.ModuleList( + [EncoderInferLayer( + cfg.embed_dim, cfg.attention_heads, + cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.seqlen, + cfg.dropout, cfg.attn_dropout, cfg.activation_dropout + ) for _ in range(cfg.layers)] + ) + self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) + + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): + + # embed = self.embed(input_ids) + embed = torch.nn.functional.embedding( + input_ids, self.embedw, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False + ) + pos_embed = self.position(position_ids) + embed = embed + pos_embed + embed = self.embed_dropout(embed) + enc = embed.transpose(0, 1) + + for layer in self.layers: + cube.runtime.function.anchor('transformer start') + enc = layer(enc) + enc = self.final_layernorm(enc) + + # logits = torch.nn.functional.linear(enc, self.embed.weight) + logits = torch.nn.functional.linear(enc, self.embedw) + # simplified + loss = torch.sum(logits) + return loss + + class GPTDataLoader(cube.runtime.syndata.CubeDataLoader): def __init__(self, batch_size: int): @@ -120,5 +180,39 @@ def random_sample(self): def __iter__(self): return self + def __next__(self): + return self.samples[0] + +class GPTInferDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + + self.bs = batch_size + self.cfg = Config() + super().__init__( + shapes=([batch_size, 1], + [batch_size, 1], + ), + dtypes=(torch.int64, torch.int64), + batch_dims=(0, 0) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + input_ids = torch.randint( + 0, self.cfg.num_embeddings, + size=(self.bs, 1), + dtype=torch.int64, + # device=torch.cuda.current_device() + ) + position_ids = torch.arange( + 0, 1, dtype=torch.int64, + # device=torch.cuda.current_device() + ).repeat(self.bs) + return (input_ids, position_ids) + + def __iter__(self): + return self + def __next__(self): return self.samples[0] \ No newline at end of file From cd8b2d5b227054321c6a07358b53dc54371354eb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 2 Sep 2022 19:13:14 +0800 Subject: [PATCH 0977/1892] adapter hierarchical generation --- cube/graph/gener/gen.py | 418 ++++++++++++---------------------------- cube/graph/graph.py | 140 +++++++------- cube/graph/segment.py | 34 +++- 3 files changed, 221 insertions(+), 371 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 2b52a58a..5b6aea1b 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,18 +1,18 @@ import itertools from typing import Dict, List, Optional, Tuple -import copy + +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.gener.concurrent import ConcurrentGener from cube.graph.graph import IRGraph +from cube.graph.segment import IRSegment from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.operator import IRBpOperation, IRFwOperation -from cube.ir.adapter.prim import IRAdapterPrim from cube.ir.adapter import IRAdapter, IRWeightReducer from cube.graph.function.function import Add, Cat, Identity, MultiRef -from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim -from cube.graph.gener.layout import GridLayout class IRAdapterGener: @@ -21,17 +21,10 @@ class IRAdapterGener: def gen(graph: IRGraph) -> IRGraph: """ Generate tensor adapter for both activations and weights + Note weight reducers are always append to the last. - Args: - graph: IRGraph. - eager (Boolean): - if True, - each adapter will be inserted right after it's ready to execute. - if False (i.e., lazy), - each adatper will be inserted right before the tensor needs it. - Note weight reducers are always append to last. - Returns: - graph (IRGraph) + @param graph IRGraph: the graph without adapter + @return graph IRGraph: the graph with adapter inserted """ # insert identity operator for graph output devs = set() @@ -41,7 +34,8 @@ def gen(graph: IRGraph) -> IRGraph: all_identities = [] for otensor in outputs: identity = Identity('', [otensor]) - graph.attach(identity, len(graph.nodes())) + identity.set_output(0, identity.output(0).tosub()) + graph.insert(identity, len(graph.nodes())) identites = graph.replicate(identity, times=len(devs)) all_identities += identites for devid, identity in zip(devs, identites): @@ -49,16 +43,34 @@ def gen(graph: IRGraph) -> IRGraph: # update the gradient before generate adapter for node in graph.nodes(): if isinstance(node, IRBpOperation): - idx = graph.detach(node) - node.update() - graph.attach(node, idx) + with graph.update(node): + node.update() + # generate adapters for activation graph = IRAdapterGener.gen_activation(graph) + # generate weight reducer graph = IRAdapterGener.gen_weight(graph) # remove inserted identity for identity in all_identities: - graph.detach(identity) + graph.remove(identity) + # remove anchor node + IRAdapterGener.remove_anchor(graph) + print(graph.extra_repr()) return graph + @staticmethod + def remove_anchor(graph: IRSegment): + for node in graph.nodes(): + if isinstance(node, IRGraphAnchor): + graph.remove(node) + if node.mirror is not None: + graph.remove(node.mirror) + if isinstance(node, IRSegment): + for anchor in node.nodes(): + if isinstance(anchor, IRGraphAnchor): + graph.remove(anchor) + if anchor.mirror is not None: + graph.remove(anchor.mirror) + @staticmethod def gen_weight(graph: IRGraph) -> IRGraph: # step 1: get weight and gradient @@ -66,7 +78,7 @@ def gen_weight(graph: IRGraph) -> IRGraph: # grads : Dict[weight_id: int, Dict[device: int, List[grad: IRSubTensor]]] grads = dict() weights = dict() - for fnode in graph.nodes(): + for fnode in graph.flatten(): if not isinstance(fnode, IRFwOperation): continue devid = fnode.device[0] @@ -106,11 +118,11 @@ def gen_weight(graph: IRGraph) -> IRGraph: weights = reducers[ranks] opt_op = IRWeightReducer(weights) opt_op.device = list(ranks) - graph._nodes.append(opt_op) + graph.insert(opt_op, graph.nnodes) return graph @staticmethod - def gen_activation(graph: IRGraph) -> IRGraph: + def gen_activation(graph: IRSegment) -> IRSegment: """! Generate adapter for activation tensors. The forward/backward adapter is inserted before the first consumers of its full tensor. @@ -119,267 +131,77 @@ def gen_activation(graph: IRGraph) -> IRGraph: @return graph IRGraph: the (inplace) modified graph with activation adapters. """ + segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment)] + + def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: + # e.g., loss or parameter/buffer + if len(ptensors) == 0 or len(ctensors) == 0: + return True + # direct connection + if len(ptensors) == 1 and len(ctensors) == 1 and \ + set(ptensors[0].device) == set(ctensors[0].device): + return True + return False + + def filter(nodes: List[IRCell], tensors: List[IRSubTensor]) -> Tuple[IRCell, IRSubTensor]: + assert len(nodes) == len(tensors) + filter_nodes, filter_tensors = [], [] + for node, tensor in zip(nodes, tensors): + if node in graph.nodes(): + filter_nodes.append(node) + filter_tensors.append(tensor) + return filter_nodes, filter_tensors + + # generate adapter for inter-segments + # FIXME: assume producers and consumers can run in parallel for ftensor in graph.full_tensors(): # backward will gen in forward if ftensor.is_param() or ftensor.is_grad(): continue - # no consumer usually mean loss - if len(ftensor.consumers) == 0: - continue - # graph attribute: buffer - if len(ftensor.producers) == 0: - continue - # no require for communication - if len(ftensor.consumers) == 1 and len(ftensor.producers) == 1 and \ - ftensor.consumers[0].device == ftensor.producers[0].device: - continue - # optimization: local fusion on producer - if graph.train: + # optimization: local fusion / multiref on producer / consumer + if isinstance(graph, IRGraph) and graph.train: ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) IRAdapterGener.local_consumer_multiref(graph, ftensor) - ptensors, ctensors = ftensor.ptensors, ftensor.ctensors - pdevs = tuple(ptensor.device[0] for ptensor in ptensors) - cdevs = tuple(ctensor.device[0] for ctensor in ctensors) - - fadapter = None - # Case 1: sharing device (in-shard) - inshard = set(pdevs) == set(cdevs) and len(ptensors) == len(ctensors) and \ - len(pdevs) == len(ptensors) - if inshard and len(pdevs) > 1: - try: - fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) - except Exception as e: - fadapter = None - print( - f"full tensor: {ftensor} cannot use grid generation.\n" - f"Reason: {str(e)}\n" - f"Switch to general P2P communication." - ) - - # Case 2: sperating device (cross-shard) - if len(set(pdevs).intersection(cdevs)) == 0: - pass - - # Case 3: General cases - # warnings.warn('The adapter is generated using - if fadapter is None: - fadapter = IRAdapterGener.gen_general(ftensor) - - badapter: Optional[IRAdapter] = fadapter.mirror + # producers can be operators and graph inputs + producers, ptensors = filter(ftensor.producers, ftensor.ptensors) + for itensor in graph.inputs(): + if isinstance(itensor, IRSubTensor): + if itensor.parent == ftensor: + ptensors.append(itensor) + # consumers can be operators and graph outputs + consumers, ctensors = filter(ftensor.consumers, ftensor.ctensors) + for otensor in graph.outputs(): + if isinstance(otensor, IRSubTensor): + if otensor.parent == ftensor: + ctensors.append(otensor) + + if skip(ptensors, ctensors): continue - if (badapter is not None and len(fadapter.prims) == 0 and len(badapter.prims) == 0) or \ - (badapter is None and len(fadapter.prims) == 0): + fadapter = ConcurrentGener.gen(ptensors, ctensors) + if fadapter is None: continue - # set differentiable for autograd generation - if inshard and badapter is not None: - fadapter.differentiable = True - badapter.differentiable = True - # insert forward adapter - fidx = min([graph.nodes().index(consumer) for consumer in ftensor.consumers]) - graph._nodes.insert(fidx, fadapter) - - # insert backward - if badapter is not None: - bidx = min(graph.nodes().index(consumer) for consumer in ftensor.grad.consumers) - graph._nodes.insert(bidx, badapter) - - # print(graph.extra_repr()) + # fidx = max(graph.index(prod).gidx for prod in producers) + fidx = min(graph.index(cons) for cons in consumers) + graph.insert(fadapter, fidx) + + # insert backward adapter + if fadapter.mirror is not None: + bsegment = graph if isinstance(graph, IRGraph) else graph.mirror + # bidx = max(graph.index(cons.mirror) for cons in consumers if cons.mirror is not None) + bidx = min(bsegment.index(prod.mirror) for prod in producers if prod.mirror is not None) + bsegment.insert(fadapter.mirror, bidx) + + # generate adapter for each segment + for segment in segments: + IRAdapterGener.gen_activation(segment) + + print(graph.extra_repr()) return graph - @staticmethod - def gen_in_shard(ftensor: IRFullTensor, allow_reorder=False) -> Optional[IRAdapter]: - """ - Generate communication for sharing devices (SPMD-like) - - @param ftensor: IRFullTensor - @param ptensors: List[IRSubTensor]: produced subtensors - @param ctensors: List[IRSubTensor]: consumed subtensors - - @return adapter Optional[IRAdapter]: generated adapter. - """ - # producer grid layout - ilayout = GridLayout.togrid(ftensor, ftensor.ptensors) - # reorder ctensors to match with ptensors - devs = [ptensor.device for ptensor in ilayout.mat.flatten()] - ctensors = [None] * len(devs) - for ctensor in ftensor.ctensors: - idx = devs.index(ctensor.device) - ctensors[idx] = ctensor - assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" - # consumer grid layout - olayout = GridLayout.togrid(ftensor, ctensors) - # find path - paths, fprims = ilayout.path(olayout) - - # re-assign the operator if miss-ordered - names, from_dev, to_dev = [], [], [] - for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): - assert len(itensor.device) == 1 and len(otensor.device) == 1, \ - "Expect tensor only has one device. Report this as a bug" - if itensor.device != otensor.device: - inode, onode = itensor.cell, otensor.cell - names.append(f'{onode.name}{onode.cid}') - from_dev.append(onode.device[0]) - to_dev.append(inode.device[0]) - if allow_reorder: - onode.device = inode.device - if onode.mirror is not None: - onode.mirror.device = inode.device - else: - raise RuntimeError("device mismatch. Try to enable reorder") - if len(names) > 0: - print(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') - - fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) - fadapter.prims = fprims - - # generate backward - grad: IRFullTensor = ftensor.grad - bprims = [] - if grad is not None and (len(grad.ptensors) != 0 or len(grad.ctensors) != 0): - # reorder ptensors to match with forward - ptensors = [None] * len(devs) - for ptensor in grad.ptensors: - idx = devs.index(ptensor.device) - assert ptensors[idx] is None, "same device of different tensors" - ptensors[idx] = ptensor - ilayout = GridLayout.togrid(grad, ptensors) - olayout = GridLayout.togrid(grad, grad.ctensors) - paths, bprims = ilayout.path(olayout) - # check the device order - for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): - assert len(itensor.device) == len(otensor.device), "backward device not match" - badapter = IRAdapter(grad.ptensors, grad.ctensors) - badapter.prims = bprims - IRAdapter.make_pair(fadapter, badapter) - - return fadapter - - @staticmethod - def gen_cross_shard(ftensor: IRFullTensor, ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> Optional[IRAdapter]: - pass - - @staticmethod - def gen_general(ftensor: IRFullTensor) -> IRAdapter: - fprims = [] - for ctensor in ftensor.ctensors: - fprims += IRAdapterGener.gen_subtensor(ctensor, ftensor.ptensors) - fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) - fadapter.prims = fprims - if ftensor.grad is not None: - bprims = [] - for cgrad in ftensor.grad.ctensors: - bprims += IRAdapterGener.gen_subtensor(cgrad, ftensor.grad.ptensors) - badapter = IRAdapter(ftensor.grad.ptensors, ftensor.grad.ctensors) - badapter.prims = bprims - IRAdapter.make_pair(fadapter, badapter) - return fadapter - - @staticmethod - def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRAdapterPrim]: - """ - Generate communiction primitives for ctensor - - @param ctensor IRSubTensor: the consumed tensor as destination - @param ptensors List[IRSubTensor]: the produced tensors as source - - @return prims List[IRAdapterPrim]: the primitives for adapter - """ - # category to local tensor and remote tensor - local = [t for t in ptensors if t.device == ctensor.device] - remote = [t for t in ptensors if t.device != ctensor.device] - prims = [] - - # ==== select ==== # - intersections = [] - # check local - for itensor in local+remote: - if itensor.device == ctensor.device and itensor == ctensor: - return [] - common: Optional[IRSubTensor] = itensor.common(ctensor) - if common is None: - continue - common.cell = itensor.cell - intersections.append(common) - # create select primitive - if common != itensor: - indmap = [] - for dim in range(itensor.ndims): - (s1, e1), (s2, e2) = itensor.indmap[dim], common.indmap[dim] - start = s2 - s1 - end = start + e2 - s2 - indmap.append((start, end)) - indmap = IndexMap(tuple(indmap)) - assert itensor.valmap == common.valmap, "Value map not same" - valmap = ValueMap((0, 1)) - select_prim = SelectPrim(itensor, indmap, valmap, common) - prims.append(select_prim) - if itensor.device == ctensor.device and common == ctensor: - return [select_prim] - # TODO: check union == subtensor - if common == ctensor: - break - - # print(intersections) - # ====== move ===== # - tmoved = [] - for tensor in intersections: - assert len(tensor.device) == 1 and len(ctensor.device) == 1, "Expected only one device." - mtensor = tensor - if tensor.device != ctensor.device: - mtensor = copy.copy(tensor) - mtensor.cell = ctensor.cell - prims.append(MovePrim(tensor, mtensor)) - tmoved.append(mtensor) - - # ===== merge ===== # - remain_tensors: List[IRSubTensor] = copy.copy(tmoved) - if ctensor in remain_tensors: - return prims - out = None - while out != ctensor: - out, merged = None, False - for idx1 in range(len(remain_tensors) - 1): - for idx2 in range(idx1+1, len(remain_tensors)): - t1, t2 = remain_tensors[idx1], remain_tensors[idx2] - catdim = t1.catdim(t2) - if catdim is not None: - tensors = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] - out = tensors[0].concat(tensors[1], dim=catdim) - out.cell = ctensor.cell - prims.append(MergeDimPrim(tensors, out, catdim)) - merged = True - break - # reduction - if t1.accumable(t2): - out = t1.accum(t2) - out.cell = ctensor.cell - prims.append(SumPrim([t1, t2], out)) - merged = True - break - if merged: - remain_tensors.remove(t1) - remain_tensors.remove(t2) - remain_tensors.append(out) - break - if out is None: - ptensors = '\n\t'.join(t.extra_repr() for t in ptensors) - remain = '\n\t'.join(t.extra_repr() for t in remain_tensors) - print(remain_tensors[0].extra_repr()) - print(remain_tensors[1].extra_repr()) - print('cadim:', remain_tensors[0].catdim(remain_tensors[1])) - raise RuntimeError( - f"Fail to build adapter.\n" - f"FullTensor:{ctensor.parent}\n" - f"Producers:\n\t{ptensors}\n" - f"SubTensor:\n\t{ctensor.extra_repr()}\n" - f"Remain Tensor:\n\t{remain}" - ) - return prims @staticmethod def local_producer_fusion(graph: IRGraph, ftensor: IRFullTensor) -> IRFullTensor: @@ -478,16 +300,15 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens new_ftensor = ftensor.like() # update consumer - min_idx = len(graph.nodes()) assert len(ftensor.ctensors) == len(ftensor.consumers) for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): - fidx = graph.detach(consumer) - consumer.set_input( - consumer.inputs().index(ctensor), - new_ftensor.select(ctensor.indmap, ctensor.valmap) - ) - graph.attach(consumer, fidx) - min_idx = min(fidx, min_idx) + # TODO: the change can happend inside segment + with graph.update(consumer) as consumer: + consumer.set_input( + consumer.inputs().index(ctensor), + new_ftensor.select(ctensor.indmap, ctensor.valmap) + ) + min_idx = min(graph.nodes().index(consumer) for consumer in ftensor.consumers) # insert new producer for devid, tensors in fuse_tensors.items(): @@ -503,20 +324,21 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens if node.output(0) == tensor_map[devid][ptensor]: node.set_output(0, new_tensor) + fsid = max(graph.stage_id(prod) for prod in ftensor.producers) for node in nodes[::-1]: # print(node) assert node not in graph.nodes() assert len(node.outputs()) == 1 - graph.attach(node, min_idx) + graph.attach(node, min_idx, stage_idx=fsid) # insert and update backward node if graph.train: # update backward node for consumer in new_ftensor.consumers: assert isinstance(consumer.mirror, IRBpOperation) - bidx = graph.detach(consumer.mirror) - consumer.mirror.update() - graph.attach(consumer.mirror, bidx) + bnode = consumer.mirror + with graph.update(bnode) as bnode: + bnode.update() # insert backward node bnodes = [node.gen_backward() for node in nodes] bidx = min(graph.nodes().index(producer.mirror) for producer in ftensor.producers) @@ -553,6 +375,17 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): if ctensor not in devtensors[devid]: devtensors[devid][ctensor] = [] devtensors[devid][ctensor].append(consumer) + + # restrict each device has same subtensor + nl = '\n' + for devid in devtensors: + assert len(devtensors[devid]) <= 1, ( + "Detect that a full tensor is partitioned differently on a device.\n" + "To achieve this, need manually add multiref operator in model description.\n" + f"Full Tensor: {ftensor}\n" + f"Producers:\n{nl.join(repr(node) for node in ftensor.producers)}\n" + f"Consumers:\n{nl.join(repr(node) for node in ftensor.consumers)}" + ) # add multiref forward node multirefs: Dict[MultiRef, List[IRFwOperation]] = dict() @@ -562,6 +395,7 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): if len(consumers) == 1: continue multiref = MultiRef(None, [ctensor, len(consumers)]) + multiref.infer_shape() multiref.device = devid ftensors = [ctensor.parent.like() for _ in range(len(consumers))] itensors = [ft.select(ctensor.indmap, ctensor.valmap) for ft in ftensors] @@ -571,12 +405,11 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): # update consumer min_fidx = len(graph.nodes()) for itensor, consumer in zip(itensors, consumers): - fidx = graph.detach(consumer) - idx = consumer.inputs().index(ctensor) - consumer.set_input(idx, itensor) - graph.attach(consumer, fidx) - min_fidx = min(fidx, min_fidx) - + with graph.update(consumer) as consumer: + idx = consumer.inputs().index(ctensor) + consumer.set_input(idx, itensor) + min_fidx = min(graph.nodes().index(consumer) for consumer in consumers) + # insert forward multiref graph.attach(multiref, min_fidx) multirefs[multiref] = consumers @@ -587,11 +420,12 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): # update consumer backward for consumer in consumers: assert isinstance(consumer.mirror, IRBpOperation) - bidx = graph.detach(consumer.mirror) - consumer.mirror.update() - graph.attach(consumer.mirror, bidx) + bnode: IRBpOperation = consumer.mirror + with graph.update(bnode) as bnode: + bnode.update() # insert backward bnode = multiref.gen_backward() bnode.device = multiref.device bidx = max(graph.nodes().index(consumer.mirror) for consumer in consumers) - graph.attach(bnode, bidx+1) + bsid = graph.stage_id(graph.node(bidx)) + graph.attach(bnode, bidx+1, stage_idx=bsid) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 7aea3deb..5d3b9e9e 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -10,15 +10,14 @@ from contextlib import contextmanager from typing import Union, Tuple, List, Optional, Dict, Set -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.function import Identity, MultiRef -from cube.graph.segment import IRSegment - from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.ir.adapter import IRAdapter -from cube.ir.tensor import IRFullTensor, IRSubTensor, StartEnd +from cube.ir.tensor import IRFullTensor, IRSubTensor + +from cube.graph.function.function import Identity, MultiRef +from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo @@ -79,11 +78,11 @@ def tuple(self) -> Tuple[int, Optional[int]]: return (self.gidx, self.sidx) - class IRGraph(IRSegment): """ - IR Graph. The hyperGraph for representing distributed - graph. + IRGraph. + + IRGraph is used for reprensting a distributed training iteration. """ def __init__(self, @@ -98,25 +97,19 @@ def __init__(self, outputs = IRGraph.get_outputs(nodes) super().__init__([], inputs, outputs, module_name) - self._attributes = set() - self._full_tensors: Set[IRFullTensor] = set() - - self._train: bool = any( - isinstance(node, IRBpOperation) or - (isinstance(node, IRSegment) and node.forward) or - (isinstance(node, IRAdapter) and node.forward) for node in nodes - ) + # atrribute tensors + self._attributes: Set[IRSubTensor] = set() self._sched = None # the schedule strategy - # set parameters / buffers and full tensors + # set parameters / buffers for node in nodes: - for tensor in node.inputs() + node.outputs(): - if isinstance(tensor, IRSubTensor): - tensor.parent.clear_producer_consumer() - self._full_tensors.add(tensor.parent) - if tensor.is_attr(): - self._attributes.add(tensor) + tensors = node.inputs() + node.outputs() + tensors = [t for t in tensors if isinstance(t, IRSubTensor)] + for t in tensors: + t.parent.clear_producer_consumer() + if t.is_attr(): + self._attributes.add(t) # insert node from nodes for idx, node in enumerate(nodes): @@ -131,7 +124,7 @@ def train(self) -> bool: @return train bool: True if backward is required, otherwise False (inference only). """ - return self._train + return self._have_forward and self._have_backward def reset_dependency(self): """ @@ -162,12 +155,6 @@ def attributes(self) -> Tuple[IRSubTensor]: """ return tuple(self._attributes) - def full_tensors(self) -> List[IRFullTensor]: - """ - Return full tensor list - """ - return list(self._full_tensors) - def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: """ forward will divide the graph into Actions according to @@ -188,7 +175,6 @@ def __call__(self, *args): """ return self.forward(*args) - # ====================== Graph Accessment ========================= def flatten(self) -> List[IRCell]: @@ -252,7 +238,6 @@ def node(self, index: Union[int, GraphIndex]) -> IRCell: """ index = GraphIndex(index, None) if isinstance(index, int) else index assert isinstance(index, GraphIndex) - assert len(index) == 2 and isinstance(index[0], int) node = self._nodes[index.gidx] if index.sidx is not None: assert isinstance(node, IRSegment), "Expected IRSegment" @@ -431,8 +416,8 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: @return segment IRSegment: the grouped segment """ - assert any(not isinstance(node, (IRBpOperation, IRSegment)) for node in fnodes), \ - "grouped nodes cannot be backward operation or segment" + assert any(not isinstance(node, (IRBpOperation, IRSegment, IRDataOperation)) for node in fnodes), \ + "grouped nodes cannot be backward operation, segment or data operation" findices = [self.index(fnode) for fnode in fnodes] @@ -716,23 +701,27 @@ def assign(self, node: Union[IRFwOperation, IRDataOperation], device: int) -> bo """ Assign an operator (subgraph) to (multiple) rank(s). - If `ranks` has multiple integer, then the operator will be replicated - `len(ranks)` times and assigned to given device correspondingly. + Corresponding backward operators (if have) will also be + assigned to the same device. - Corresponding backward operators (if have) will also be replicated - and assigned to the same device with it's forward operator - - @param node Union[IRFwOperation, IRBpOperation]: operator - @param ranks Tuple[int, Tuple[int]]: assigned ranks + @param node Union[IRFwOperation, IRBpOperation, IRSegment]: operator + @param device int: assigned device id @return sucess bool: always true """ - assert isinstance(node, (IRFwOperation, IRDataOperation)), \ - f"Only forward and data operation can be assigned to device, but got {node}" assert self.exist(node), f"{node} is not in the graph" - node.device = device - if node.mirror is not None: - node.mirror.device = device + if isinstance(node, IRSegment): + assert node.forward, "Only forward segment is allowed to assign devices" + for subnode in node.nodes(): + subnode.device = device + if subnode.mirror is not None: + subnode.mirror.device = device + else: + assert isinstance(node, (IRFwOperation, IRDataOperation)), \ + "Only forward operators and dataloader operators are allowed to assign devices" + node.device = device + if node.mirror is not None: + node.mirror.device = device return True ## Schedule Policy Primitives ## @@ -926,7 +915,7 @@ def staging(self, nodes: Tuple[IRFwOperation]): @param starts Tuple[int]: the start index of each stage @return None """ - assert all(isinstance(node, (IRFwOperation, IRDataOperation)) for node in nodes), \ + assert all(isinstance(node, IRFwOperation) for node in nodes), \ f"Find node is not IRFwOperation or IRDataOperation: {node}" assert all(node in self._nodes for node in nodes), \ f"Exist node is not in graph nodes" @@ -939,36 +928,41 @@ def staging(self, nodes: Tuple[IRFwOperation]): if not isinstance(node, IRBpOperation): last_fidx = idx - fstages = [] + fstages: List[List[IRCell]] = [] + bstages: List[List[IRCell]] = [] for sid in range(len(starts)): begin = starts[sid] - if sid == len(starts) - 1: - end = last_fidx + 1 - else: - end = starts[sid+1] - if begin == end: - continue + end = starts[sid+1] if sid != len(starts) - 1 else last_fidx + 1 + while isinstance(self.node(begin), IRDataOperation): + begin += 1 + while isinstance(self.node(end), IRDataOperation): + end -= 1 + if begin == end: continue assert begin < end - fstages.append(self._nodes[begin:end]) - - # grouping into index - fsegs: List[IRSegment] = [] - for sid in range(len(fstages)): - fsegs.append(self.group(fstages[sid])) - - fidxs: List[int] = [self._nodes.index(seg) for seg in fsegs] - bidxs: List[Optional[int]] = [self._nodes.index(seg.mirror) if seg.mirror is not None else None for seg in fsegs] + fnodes = self._nodes[begin:end] + bnodes = [fnode.mirror for fnode in fnodes[::-1] if fnode.mirror is not None] + fstages.append(fnodes) + bstages = [bnodes] + bstages + + def get_sid(fnode: IRCell) -> Optional[int]: + for idx, fnodes in enumerate(fstages): + if fnode in fnodes: + return idx + return None def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: identity = Identity('', [tensor]) identity.infer_shape() identity.set_output(0, identity.output(0).tosub()) # insert forward - self.insert(identity, GraphIndex(fidxs[sid], 0)) + self.insert(identity, self.index(fstages[sid][0])) + fstages[sid].insert(0, identity) + # insert backward if self.train: bnode = identity.gen_backward() - self.insert(bnode, GraphIndex(bidxs[sid], -1)) + self.insert(bnode, self.index(bstages[sid][-1]) + 1) + bstages[sid].append(bnode) return identity # create identity op for cross-stage dataflow @@ -978,16 +972,15 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: assert len(ftensor.producers) <= 1, \ "The staging interface should be called before any operator partition." if len(ftensor.consumers) == 0: continue - producer = ftensor.producers[0] - # TODO: robustness - psid = fidxs.index(self.index(producer).gidx) - ptensor = ftensor.ptensors[0] + producer, ptensor = ftensor.producers[0], ftensor.ptensors[0] + psid = get_sid(producer) + # outside of stages, not consider + if psid is None: continue out = ptensor curr_sid = psid for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): assert ctensor == ptensor, "The staging interface should be called before any operator partition." - # TODO: robustness - csid = fidxs.index(self.index(consumer).gidx) + csid = get_sid(consumer) if curr_sid == csid: continue for sid in range(curr_sid + 1, csid): identity = insert_identity(out, sid) @@ -1000,6 +993,11 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: with self.update(consumer.mirror) as bnode: bnode.update() curr_sid = csid + + # grouping into segment + for sid in range(len(fstages)): + self.group(fstages[sid]) + # ================= Other optimizations ================== diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 0acef0e0..5cc3ab60 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1,4 +1,4 @@ -from typing import Union, Tuple, List, Optional, Dict +from typing import Union, List, Optional, Set import copy from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -11,6 +11,9 @@ class IRSegment(IRCell): """ A distributed sub-graph representing a piece of workload in parent IRGraph + + Once the segment is generated, its input and output will be fixed. + Inserting and removing nodes that could change input/output are not allowed. """ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRSubTensor], name='segment'): @@ -20,10 +23,19 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR self._idevice = [t.device for t in inputs] self._odevice = [t.device for t in outputs] - for idx, val in enumerate(inputs): - self.set_input(idx, val) - for idx, val in enumerate(outputs): - self.set_output(idx, val) + self._inputs = list(inputs) + self._outputs = list(outputs) + # for idx, val in enumerate(inputs): + # self.set_input(idx, val) + # for idx, val in enumerate(outputs): + # self.set_output(idx, val) + + # full tensors + self._full_tensors: Set[IRFullTensor] = set() + for node in nodes: + for tensor in node.inputs() + node.outputs(): + if isinstance(tensor, IRSubTensor): + self._full_tensors.add(tensor.parent) self._have_forward = any(isinstance(n, IRFwOperation) for n in nodes) self._have_backward = any(isinstance(n, IRBpOperation) for n in nodes) @@ -32,6 +44,12 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR def forward(self) -> bool: return self._have_forward + def full_tensors(self) -> List[IRFullTensor]: + """ + Return full tensor list + """ + return list(self._full_tensors) + # ========================= Basic Graph access ======================= @property @@ -89,7 +107,7 @@ def insert(self, node: IRCell, index: int): """ Insert a node at index. - TODO: update input and output + TODO: check input and output @param node IRCell: the inserted node @param index int: the index @@ -102,7 +120,7 @@ def remove(self, node: IRCell) -> int: """ Remove a node at index - # TODO: update input and output + # TODO: check input and output @param node IRCell: the removed node @@ -117,7 +135,7 @@ def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: """ Replace one node by multiple nodes - # TODO: update input and output + # TODO: check input and output @param node IRCell: the replaced node @param new_nodes List[IRCell]: the nodes to be inserted. From 9d38492c4435f9e9327c1d102a77a8c9a243824c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 2 Sep 2022 19:13:40 +0800 Subject: [PATCH 0978/1892] making segment as island (detaching) --- cube/runtime/executor.py | 132 ++++++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 49 deletions(-) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index fb384508..f0e4d97a 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -2,21 +2,93 @@ Executor for runtime """ -from typing import Tuple, Any, Callable, List +from typing import Tuple, Any, Callable, List, Dict import torch -def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): - """ - forward the sub-graph. - """ - if not requires_grad: - with torch.no_grad(): + +class Executor: + + _detach: Dict[torch.Tensor, torch.Tensor] = dict() + + @staticmethod + def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): + """ + forward the sub-graph. + """ + if not requires_grad: + with torch.no_grad(): + outputs = subgraph(*input_tensors) + else: + # everytime forward a segment, detach the tensor from previous graph + for itensor in input_tensors: + if torch.is_tensor(itensor) and itensor.requires_grad: + assert itensor not in Executor._detach + Executor._detach[itensor] = itensor.detach().requires_grad_() + input_tensors = tuple( + Executor._detach[t] if t in Executor._detach else t for t in input_tensors + ) outputs = subgraph(*input_tensors) - else: - outputs = subgraph(*input_tensors) - # print('forwarding... ') - return outputs + # print('forwarding... ') + return outputs + + @staticmethod + def aexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): + """ + execute adapter + """ + if not requires_grad: + with torch.no_grad(): + outputs = subgraph(*input_tensors) + else: + outputs = subgraph(*input_tensors) + # print('forwarding... ') + return outputs + + @staticmethod + def backward(input_tensors: List[torch.Tensor], + output_tensors: List[torch.Tensor], + output_tensor_grads: List[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Backward Procedure. + + input_tensors: List[torch.Tensor]: + tensors that their gradient need to be computed, including parameters. + Correspoinding forward input tensors. + + output_tensors: + tensors that start for gradient backward computation. + Corresponding to forward output tensors. + + output_tensor_grads: + gradient tensors corresponding to output_tensors. + + Returns: + gradient in order of non-parameter tensors in input_tensors. + (Note parameter tnesors already have gradient accumulated at .grad attribute) + """ + if len(output_tensors) == 0: + return None + inputs = list() + # everytime backward a input tensor, remove it from _detach + input_tensors = [Executor._detach.pop(t) if t in Executor._detach else t for t in input_tensors] + for input_ in input_tensors: + if torch.is_tensor(input_) and not isinstance(input_, torch.nn.Parameter): + if input_.requires_grad: + input_.retain_grad() + inputs.append(input_) + torch.autograd.backward( + output_tensors, + grad_tensors=output_tensor_grads, + ) + grads = tuple(input_.grad for input_ in inputs) + if len(grads) == 0: return None + elif len(grads) == 1: return grads[0] + else: return tuple(grads) + + +fexecute = Executor.fexecute +backward = Executor.backward # def backward(input_tensors : List[torch.Tensor], @@ -62,44 +134,6 @@ def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) # else: return tuple(grads) -def backward(input_tensors: List[torch.Tensor], - output_tensors: List[torch.Tensor], - output_tensor_grads: List[torch.Tensor]) -> Tuple[torch.Tensor]: - """ - Backward Procedure. - - input_tensors: List[torch.Tensor]: - tensors that their gradient need to be computed, including parameters. - Correspoinding forward input tensors. - - output_tensors: - tensors that start for gradient backward computation. - Corresponding to forward output tensors. - - output_tensor_grads: - gradient tensors corresponding to output_tensors. - - Returns: - gradient in order of non-parameter tensors in input_tensors. - (Note parameter tnesors already have gradient accumulated at .grad attribute) - """ - if len(output_tensors) == 0: - return None - inputs = list() - for input_ in input_tensors: - if torch.is_tensor(input_) and not isinstance(input_, torch.nn.Parameter): - if input_.requires_grad: - input_.retain_grad() - inputs.append(input_) - torch.autograd.backward( - output_tensors, - grad_tensors=output_tensor_grads, - ) - grads = tuple(input_.grad for input_ in inputs) - if len(grads) == 0: return None - elif len(grads) == 1: return grads[0] - else: return tuple(grads) - ### =================== Experimental Feature ======================= # import queue From e405800bfbfa545bc63b9a5c314d564ab677cf3d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 2 Sep 2022 19:14:26 +0800 Subject: [PATCH 0979/1892] adapter gen for concurrent producers and consumers --- cube/graph/gener/concurrent.py | 248 +++++++++++++++++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100644 cube/graph/gener/concurrent.py diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py new file mode 100644 index 00000000..8625db46 --- /dev/null +++ b/cube/graph/gener/concurrent.py @@ -0,0 +1,248 @@ +""" +Concurrent producer / consumer Adapter Generator +""" +from typing import List, Optional +import copy + +from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from cube.ir.adapter.prim import IRAdapterPrim +from cube.ir.adapter import IRAdapter +from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim + +from cube.graph.gener.layout import GridLayout + + +class ConcurrentGener: + + @staticmethod + def gen(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> Optional[IRAdapter]: + """ + Generate forward adapter and backward adapter + + @param ptensors List[IRSubTensor]: forward producer tensors + @param ctensors List[IRSubTensor]: forward consumer tensors + + @return fadapter Optional[IRAdapter]: forward adapter + None indicate no adapter required. + """ + pdevs = tuple(t.device[0] for t in ptensors) + cdevs = tuple(t.device[0] for t in ctensors) + + fadapter: IRAdapter = None + + # case 1: sharing device (in-shard) + inshard = (set(pdevs) == set(cdevs)) and (len(ptensors) == len(ctensors)) and (len(pdevs) == len(ptensors)) + if inshard and len(pdevs) > 1: + try: + fadapter = ConcurrentGener.gen_in_shard(ptensors, ctensors, allow_reorder=True) + except Exception as e: + fadapter = None + print( + f"full tensor: {ptensors[0].parent} cannot use grid generation.\n" + f"Reason: {str(e)}\n" + f"Switch to general P2P communication." + ) + + # Case 2: sperating device (cross-shard) + if len(set(pdevs).intersection(cdevs)) == 0: + pass + + # Case 3: General cases + # warnings.warn('The adapter is generated using P2P communication') + if fadapter is None: + fadapter = ConcurrentGener.gen_general(ptensors, ctensors) + + return fadapter + + @staticmethod + def gen_in_shard(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor], allow_reorder=False): + ftensor = ptensors[0].parent + # producer grid layout + ilayout = GridLayout.togrid(ftensor, ptensors) + # reorder ctensors to match with ptensors + devs = [ptensor.device for ptensor in ilayout.mat.flatten()] + ctensors = [None] * len(devs) + for ctensor in ctensors: + idx = devs.index(ctensor.device) + ctensors[idx] = ctensor + assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" + # consumer grid layout + olayout = GridLayout.togrid(ftensor, ctensors) + # find path + paths, fprims = ilayout.path(olayout) + + # re-assign the operator if miss-ordered + names, from_dev, to_dev = [], [], [] + for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): + assert len(itensor.device) == 1 and len(otensor.device) == 1, \ + "Expect tensor only has one device. Report this as a bug" + if itensor.device != otensor.device: + inode, onode = itensor.cell, otensor.cell + names.append(f'{onode.name}{onode.cid}') + from_dev.append(onode.device[0]) + to_dev.append(inode.device[0]) + if allow_reorder: + onode.device = inode.device + if onode.mirror is not None: + onode.mirror.device = inode.device + else: + raise RuntimeError("device mismatch. Try to enable reorder") + if len(names) > 0: + print(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') + + fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) + fadapter.prims = fprims + + # generate backward + grad: IRFullTensor = ftensor.grad + b_ptensors = [ctensor.grad for ctensor in ctensors] + b_ctensors = [ptensor.grad for ptensor in ptensors] + bprims = [] + if grad is not None and (len(b_ptensors) != 0 or len(b_ctensors) != 0): + # reorder ptensors to match with forward + ptensors = [None] * len(devs) + for b_ptensor in b_ptensors: + idx = devs.index(b_ptensor.device) + assert ptensors[idx] is None, "same device of different tensors" + ptensors[idx] = b_ptensor + ilayout = GridLayout.togrid(grad, b_ptensors) + olayout = GridLayout.togrid(grad, b_ctensors) + paths, bprims = ilayout.path(olayout) + # check the device order + for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): + assert len(itensor.device) == len(otensor.device), "backward device not match" + badapter = IRAdapter(b_ptensors, b_ctensors) + badapter.prims = bprims + IRAdapter.make_pair(fadapter, badapter) + + return fadapter + + @staticmethod + def gen_cross_shard(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> IRAdapter: + pass + + @staticmethod + def gen_general(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> IRAdapter: + """ + A general way to generate adapter. + FIXME: Assuming consumers at different devices can happen at the same time. + This will block the pipeline parallelism description. + + @param ftensor IRFullTensor + @return adapter IRAdapter + """ + fprims = [] + for ctensor in ctensors: + fprims += ConcurrentGener.gen_subtensor(ctensor, ptensors) + fadapter = IRAdapter(ptensors,ctensors) + fadapter.prims = fprims + # backward + b_ptensors = [ctensor.grad for ctensor in ctensors if ctensor.grad is not None] + b_ctensors = [ptensor.grad for ptensor in ptensors if ptensor.grad is not None] + bprims = [] + for cgrad in b_ctensors: + bprims += ConcurrentGener.gen_subtensor(cgrad, b_ptensors) + badapter = IRAdapter(b_ptensors, b_ctensors) + badapter.prims = bprims + IRAdapter.make_pair(fadapter, badapter) + return fadapter + + @staticmethod + def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRAdapterPrim]: + """ + Generate communiction primitives for ctensor + + @param ctensor IRSubTensor: the consumed tensor as destination + @param ptensors List[IRSubTensor]: the produced tensors as source + + @return prims List[IRAdapterPrim]: the primitives for adapter + """ + # category to local tensor and remote tensor + local = [t for t in ptensors if t.device == ctensor.device] + remote = [t for t in ptensors if t.device != ctensor.device] + prims = [] + + # ==== select ==== # + intersections = [] + # check local + for itensor in local+remote: + if itensor.device == ctensor.device and itensor == ctensor: + return [] + common: Optional[IRSubTensor] = itensor.common(ctensor) + if common is None: + continue + common.cell = itensor.cell + intersections.append(common) + # create select primitive + if common != itensor: + indmap = [] + for dim in range(itensor.ndims): + (s1, e1), (s2, e2) = itensor.indmap[dim], common.indmap[dim] + start = s2 - s1 + end = start + e2 - s2 + indmap.append((start, end)) + indmap = IndexMap(tuple(indmap)) + assert itensor.valmap == common.valmap, "Value map not same" + valmap = ValueMap((0, 1)) + select_prim = SelectPrim(itensor, indmap, valmap, common) + prims.append(select_prim) + if itensor.device == ctensor.device and common == ctensor: + return [select_prim] + # TODO: check union == subtensor + if common == ctensor: + break + + # print(intersections) + # ====== move ===== # + tmoved = [] + for tensor in intersections: + assert len(tensor.device) == 1 and len(ctensor.device) == 1, "Expected only one device." + mtensor = tensor + if tensor.device != ctensor.device: + mtensor = copy.copy(tensor) + mtensor.cell = ctensor.cell + prims.append(MovePrim(tensor, mtensor)) + tmoved.append(mtensor) + + # ===== merge ===== # + remain_tensors: List[IRSubTensor] = copy.copy(tmoved) + if ctensor in remain_tensors: + return prims + out = None + while out != ctensor: + out, merged = None, False + for idx1 in range(len(remain_tensors) - 1): + for idx2 in range(idx1+1, len(remain_tensors)): + t1, t2 = remain_tensors[idx1], remain_tensors[idx2] + catdim = t1.catdim(t2) + if catdim is not None: + tensors = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] + out = tensors[0].concat(tensors[1], dim=catdim) + out.cell = ctensor.cell + prims.append(MergeDimPrim(tensors, out, catdim)) + merged = True + break + # reduction + if t1.accumable(t2): + out = t1.accum(t2) + out.cell = ctensor.cell + prims.append(SumPrim([t1, t2], out)) + merged = True + break + if merged: + remain_tensors.remove(t1) + remain_tensors.remove(t2) + remain_tensors.append(out) + break + if out is None: + ptensors = '\n\t'.join(t.extra_repr() for t in ptensors) + remain = '\n\t'.join(t.extra_repr() for t in remain_tensors) + raise RuntimeError( + f"Fail to build adapter.\n" + f"FullTensor:{ctensor.parent}\n" + f"Produced Tensors:\n\t{ptensors}\n" + f"Consumed Tensors:\n\t{ctensor.extra_repr()}\n" + f"Consumer:\n\t{ctensor.cell}\n" + f"Remain Tensor:\n\t{remain}" + ) + return prims From df29685cb5de0a3ee723581d27bfad702c804efc Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Sun, 4 Sep 2022 12:49:44 +0800 Subject: [PATCH 0980/1892] evoformer runnable --- examples/alphafold2/evoformer.py | 295 +++++++++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 examples/alphafold2/evoformer.py diff --git a/examples/alphafold2/evoformer.py b/examples/alphafold2/evoformer.py new file mode 100644 index 00000000..0c802c0c --- /dev/null +++ b/examples/alphafold2/evoformer.py @@ -0,0 +1,295 @@ +import torch +import math +""" +[bs, s, r, cm] -> [bs, s, r, cm] + +used as column-wise gated self-attention +""" + + +def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias, + head: int, c: int, scale: float): + bs, s, r, cm = x.size() + + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + + gate = gate.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + q = q.reshape(bs, s, r, head, c).transpose(2, + 3).reshape(bs * s * head, r, c) + k = k.reshape(bs, s, r, head, c).transpose(2, + 3).reshape(bs * s * head, r, + c).transpose(1, 2) + v = v.reshape(bs, s, r, head, c).transpose(2, + 3).reshape(bs * s * head, r, c) + + sim = torch.bmm(q, k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + + if isinstance(bias, torch.Tensor): + sim = sim.reshape(bs, s, head, r, r) + bias + sim = sim.reshape(bs * s * head, r, r) + + attend = torch.bmm(sim, v) * gate + + out = attend.reshape(bs, s, head, r, c).transpose(2, + 3).reshape(bs, s, r, cm) + out = torch.matmul(out, out_proj) + return out + + +""" +([bs, s, r, cm], [bs, r, r, cz]) -> [bs, s, r, cm] +""" + + +def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, + pair_repr: torch.Tensor, + gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias_proj: torch.Tensor, head: int, c: int, + scale: float): + bs, s, r, cm = msa_repr.size() + + bias = torch.matmul(pair_repr, + bias_proj).permute(0, 3, 1, + 2).reshape(bs, 1, head, r, r) + + return MSAAttention(msa_repr, gate_proj, qkv_proj, out_proj, bias, head, c, + scale) + + +""" +[bs, s, r, cm] -> [bs, s, r, cm] +""" + + +def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + return torch.matmul( + torch.nn.functional.relu(torch.matmul(msa_repr, proj1)), proj2) + + +""" +[bs, s, r, cm] -> [r, r, cz] +""" + + +def OuterProductMean(msa_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor, out_proj: torch.Tensor): + bs, s, r, cm = msa_repr.size() + c = proj1.size(-1) + + a = torch.matmul(msa_repr, proj1).transpose(-2, -3) + b = torch.matmul(msa_repr, proj2).transpose(-2, -3) + + outer = torch.einsum('...bac,...dae->...bdce', a, + b).reshape(bs, r, r, c * c) + outer = torch.matmul(outer, out_proj) + return outer + + +def TriangleMultiplication( + pair_repr: torch.Tensor, tri_mul_norm1_weight: torch.Tensor, + tri_mul_norm1_bias: torch.Tensor, tri_mul_norm2_weight: torch.Tensor, + tri_mul_norm2_bias: torch.Tensor, tri_mul_proj1: torch.Tensor, + tri_mul_proj2: torch.Tensor, tri_mul_proj3: torch.Tensor, + tri_mul_proj4: torch.Tensor, tri_mul_proj5: torch.Tensor, + tri_mul_proj6: torch.Tensor, cz: int, out_going: False): + pair_repr = torch.nn.functional.layer_norm(pair_repr, (cz, ), + tri_mul_norm1_weight, + tri_mul_norm1_bias) + a = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj1)) + a = a * torch.matmul(pair_repr, tri_mul_proj2) + b = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj3)) + b = b * torch.matmul(pair_repr, tri_mul_proj4) + + if out_going: + a = a.permute(0, 3, 1, 2) + b = b.permute(0, 3, 2, 1) + else: + a = a.permute(0, 3, 2, 1) + b = b.permute(0, 3, 1, 2) + + p = torch.matmul(a, b).permute(0, 2, 3, 1) + p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, + tri_mul_norm2_bias) + p = torch.matmul(p, tri_mul_proj5) + g = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj6)) + return p * g + + +def TriangleAttentionNode(pair_repr: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias_proj: torch.Tensor, head: int, c: int, + scale: float): + bias = torch.matmul(pair_repr, bias_proj).permute(0, 3, 1, 2) + + return MSAAttention(pair_repr, gate_proj, qkv_proj, out_proj, bias, head, + c, scale) + + +def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + return torch.matmul( + torch.nn.functional.relu(torch.matmul(pair_repr, proj1)), proj2) + + +""" +a simplified version for evoformer in alphafold2 + - dropout layers are omitted + - masks are omitted +""" + + +class Evoformer(torch.nn.Module): + + def __init__(self, s: int, cm: int, cz: int, c: int, head: int): + super().__init__() + + self.s, self.cm, self.cz, self.c, self.head = s, cm, cz, c, head + self.scale = 1.0 / math.sqrt(c) + + # MSA row-wise gated self-attention with pair bias + self.row_norm_m = torch.nn.LayerNorm(cm) + self.row_norm_z = torch.nn.LayerNorm(cz) + self.row_gate_proj = torch.nn.Parameter(torch.randn(cm, head * c)) + self.row_qkv_proj = torch.nn.Parameter(torch.randn(cm, 3 * head * c)) + self.row_out_proj = torch.nn.Parameter(torch.randn(head * c, cm)) + self.row_bias_proj = torch.nn.Parameter(torch.randn(cz, head)) + + # MSA column-wise gated self-attention + self.col_norm = torch.nn.LayerNorm(cm) + self.col_gate_proj = torch.nn.Parameter(torch.randn(cm, head * c)) + self.col_qkv_proj = torch.nn.Parameter(torch.randn(cm, 3 * head * c)) + self.col_out_proj = torch.nn.Parameter(torch.randn(head * c, cm)) + + # MSA transition + self.msa_transition_norm = torch.nn.LayerNorm(cm) + self.msa_transition_proj1 = torch.nn.Parameter(torch.randn(cm, 4 * cm)) + self.msa_transition_proj2 = torch.nn.Parameter(torch.randn(4 * cm, cm)) + + # Outer product mean + self.outer_norm = torch.nn.LayerNorm(cm) + self.outer_proj1 = torch.nn.Parameter(torch.randn(cm, c)) + self.outer_proj2 = torch.nn.Parameter(torch.randn(cm, c)) + self.outer_out_proj = torch.nn.Parameter(torch.randn(c * c, cz)) + + # Triangular multiplicative update using outgoing edges + self.tri_mul_out_norm1_weight = torch.nn.Parameter(torch.empty(cz)) + self.tri_mul_out_norm1_bias = torch.nn.Parameter(torch.empty(cz)) + self.tri_mul_out_norm2_weight = torch.nn.Parameter(torch.empty(128)) + self.tri_mul_out_norm2_bias = torch.nn.Parameter(torch.empty(128)) + self.tri_mul_out_proj1 = torch.nn.Parameter(torch.randn(cz, 128)) + self.tri_mul_out_proj2 = torch.nn.Parameter(torch.randn(cz, 128)) + self.tri_mul_out_proj3 = torch.nn.Parameter(torch.randn(cz, 128)) + self.tri_mul_out_proj4 = torch.nn.Parameter(torch.randn(cz, 128)) + self.tri_mul_out_proj5 = torch.nn.Parameter(torch.randn(128, cz)) + self.tri_mul_out_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) + + # Triangular multiplicative update using incoming edges + self.tri_mul_in_norm1_weight = torch.nn.Parameter(torch.empty(cz)) + self.tri_mul_in_norm1_bias = torch.nn.Parameter(torch.empty(cz)) + self.tri_mul_in_norm2_weight = torch.nn.Parameter(torch.empty(128)) + self.tri_mul_in_norm2_bias = torch.nn.Parameter(torch.empty(128)) + self.tri_mul_in_proj1 = torch.nn.Parameter(torch.randn(cz, 128)) + self.tri_mul_in_proj2 = torch.nn.Parameter(torch.randn(cz, 128)) + self.tri_mul_in_proj3 = torch.nn.Parameter(torch.randn(cz, 128)) + self.tri_mul_in_proj4 = torch.nn.Parameter(torch.randn(cz, 128)) + self.tri_mul_in_proj5 = torch.nn.Parameter(torch.randn(128, cz)) + self.tri_mul_in_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) + + # Triangular gated self-attention around starting node + self.tri_att_start_norm = torch.nn.LayerNorm(cz) + self.tri_att_start_gate_proj = torch.nn.Parameter( + torch.randn(cz, 4 * c)) + self.tri_att_start_qkv_proj = torch.nn.Parameter( + torch.randn(cz, 3 * 4 * c)) + self.tri_att_start_out_proj = torch.nn.Parameter(torch.randn( + 4 * c, cz)) + self.tri_att_start_bias_proj = torch.nn.Parameter(torch.randn(cz, 4)) + + # Triangular gated self-attention around ending node + self.tri_att_end_norm = torch.nn.LayerNorm(cz) + self.tri_att_end_gate_proj = torch.nn.Parameter(torch.randn(cz, 4 * c)) + self.tri_att_end_qkv_proj = torch.nn.Parameter( + torch.randn(cz, 3 * 4 * c)) + self.tri_att_end_out_proj = torch.nn.Parameter(torch.randn(4 * c, cz)) + self.tri_att_end_bias_proj = torch.nn.Parameter(torch.randn(cz, 4)) + + # Transition in the pair stack + self.pair_transition_norm = torch.nn.LayerNorm(cz) + self.pair_transition_proj1 = torch.nn.Parameter(torch.randn( + cz, 4 * cz)) + self.pair_transition_proj2 = torch.nn.Parameter(torch.randn( + 4 * cz, cz)) + + def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): + + msa_repr = msa_repr + MSARowAttentionWithPairBias( + self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, + self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, + self.head, self.c, self.scale) + + msa_repr = msa_repr.transpose(-3, -2) + msa_repr = msa_repr + MSAAttention( + self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, + self.col_out_proj, None, self.head, self.c, self.scale) + msa_repr = msa_repr.transpose(-3, -2) + + msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), + self.msa_transition_proj1, + self.msa_transition_proj2) + + pair_repr = pair_repr + OuterProductMean( + self.outer_norm(msa_repr), self.outer_proj1, self.outer_proj2, + self.outer_out_proj) + + pair_repr = pair_repr + TriangleMultiplication( + pair_repr, self.tri_mul_out_norm1_weight, + self.tri_mul_out_norm1_bias, self.tri_mul_out_norm2_weight, + self.tri_mul_out_norm2_bias, self.tri_mul_out_proj1, + self.tri_mul_out_proj2, self.tri_mul_out_proj3, + self.tri_mul_out_proj4, self.tri_mul_out_proj5, + self.tri_mul_out_proj6, self.cz, True) + + pair_repr = pair_repr + TriangleMultiplication( + pair_repr, self.tri_mul_in_norm1_weight, + self.tri_mul_in_norm1_bias, self.tri_mul_in_norm2_weight, + self.tri_mul_in_norm2_bias, self.tri_mul_in_proj1, + self.tri_mul_in_proj2, self.tri_mul_in_proj3, + self.tri_mul_in_proj4, self.tri_mul_in_proj5, + self.tri_mul_in_proj6, self.cz, True) + + pair_repr = pair_repr + TriangleAttentionNode( + self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, + self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, + self.tri_att_start_bias_proj, 4, self.c, self.scale) + + pair_repr = pair_repr.transpose(-3, -2) + pair_repr = pair_repr + TriangleAttentionNode( + self.tri_att_end_norm(pair_repr), self.tri_att_end_gate_proj, + self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, + self.tri_att_end_bias_proj, 4, self.c, self.scale) + pair_repr = pair_repr.transpose(-3, -2) + + pair_repr = pair_repr + PairTransition( + self.pair_transition_norm(pair_repr), self.pair_transition_proj1, + self.pair_transition_proj2) + + return msa_repr, pair_repr + + +def test(): + bs, s, r, cm, cz, c, heads = 1, 128, 256, 256, 128, 32, 8 + model = Evoformer(s, cm, cz, c, heads) + + msa = torch.randn(bs, s, r, cm) + pair = torch.randn(bs, r, r, cz) + + new_msa, new_pair = model(msa, pair) + + +test() From 023329c12654fcf589e0ac6a4ca9a759604f9266 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Sun, 4 Sep 2022 13:31:53 +0800 Subject: [PATCH 0981/1892] refine code --- examples/alphafold2/evoformer.py | 117 +++++++++++++++++++------------ 1 file changed, 71 insertions(+), 46 deletions(-) diff --git a/examples/alphafold2/evoformer.py b/examples/alphafold2/evoformer.py index 0c802c0c..704695c6 100644 --- a/examples/alphafold2/evoformer.py +++ b/examples/alphafold2/evoformer.py @@ -97,7 +97,7 @@ def TriangleMultiplication( tri_mul_norm2_bias: torch.Tensor, tri_mul_proj1: torch.Tensor, tri_mul_proj2: torch.Tensor, tri_mul_proj3: torch.Tensor, tri_mul_proj4: torch.Tensor, tri_mul_proj5: torch.Tensor, - tri_mul_proj6: torch.Tensor, cz: int, out_going: False): + tri_mul_proj6: torch.Tensor, cz: int, out_going: bool): pair_repr = torch.nn.functional.layer_norm(pair_repr, (cz, ), tri_mul_norm1_weight, tri_mul_norm1_bias) @@ -146,30 +146,42 @@ def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, class Evoformer(torch.nn.Module): - def __init__(self, s: int, cm: int, cz: int, c: int, head: int): + def __init__(self, + s: int, + cm: int, + cz: int, + c=32, + msa_head=8, + pair_head=4, + c_tri_mult=128, + ff_mult=4): super().__init__() - self.s, self.cm, self.cz, self.c, self.head = s, cm, cz, c, head + self.s, self.cm, self.cz, self.c, self.msa_head, self.pair_head, self.c_tri_mult, self.ff_mult = s, cm, cz, c, msa_head, pair_head, c_tri_mult, ff_mult self.scale = 1.0 / math.sqrt(c) # MSA row-wise gated self-attention with pair bias self.row_norm_m = torch.nn.LayerNorm(cm) self.row_norm_z = torch.nn.LayerNorm(cz) - self.row_gate_proj = torch.nn.Parameter(torch.randn(cm, head * c)) - self.row_qkv_proj = torch.nn.Parameter(torch.randn(cm, 3 * head * c)) - self.row_out_proj = torch.nn.Parameter(torch.randn(head * c, cm)) - self.row_bias_proj = torch.nn.Parameter(torch.randn(cz, head)) + self.row_gate_proj = torch.nn.Parameter(torch.randn(cm, msa_head * c)) + self.row_qkv_proj = torch.nn.Parameter( + torch.randn(cm, 3 * msa_head * c)) + self.row_out_proj = torch.nn.Parameter(torch.randn(msa_head * c, cm)) + self.row_bias_proj = torch.nn.Parameter(torch.randn(cz, msa_head)) # MSA column-wise gated self-attention self.col_norm = torch.nn.LayerNorm(cm) - self.col_gate_proj = torch.nn.Parameter(torch.randn(cm, head * c)) - self.col_qkv_proj = torch.nn.Parameter(torch.randn(cm, 3 * head * c)) - self.col_out_proj = torch.nn.Parameter(torch.randn(head * c, cm)) + self.col_gate_proj = torch.nn.Parameter(torch.randn(cm, msa_head * c)) + self.col_qkv_proj = torch.nn.Parameter( + torch.randn(cm, 3 * msa_head * c)) + self.col_out_proj = torch.nn.Parameter(torch.randn(msa_head * c, cm)) # MSA transition self.msa_transition_norm = torch.nn.LayerNorm(cm) - self.msa_transition_proj1 = torch.nn.Parameter(torch.randn(cm, 4 * cm)) - self.msa_transition_proj2 = torch.nn.Parameter(torch.randn(4 * cm, cm)) + self.msa_transition_proj1 = torch.nn.Parameter( + torch.randn(cm, ff_mult * cm)) + self.msa_transition_proj2 = torch.nn.Parameter( + torch.randn(ff_mult * cm, cm)) # Outer product mean self.outer_norm = torch.nn.LayerNorm(cm) @@ -180,63 +192,76 @@ def __init__(self, s: int, cm: int, cz: int, c: int, head: int): # Triangular multiplicative update using outgoing edges self.tri_mul_out_norm1_weight = torch.nn.Parameter(torch.empty(cz)) self.tri_mul_out_norm1_bias = torch.nn.Parameter(torch.empty(cz)) - self.tri_mul_out_norm2_weight = torch.nn.Parameter(torch.empty(128)) - self.tri_mul_out_norm2_bias = torch.nn.Parameter(torch.empty(128)) - self.tri_mul_out_proj1 = torch.nn.Parameter(torch.randn(cz, 128)) - self.tri_mul_out_proj2 = torch.nn.Parameter(torch.randn(cz, 128)) - self.tri_mul_out_proj3 = torch.nn.Parameter(torch.randn(cz, 128)) - self.tri_mul_out_proj4 = torch.nn.Parameter(torch.randn(cz, 128)) - self.tri_mul_out_proj5 = torch.nn.Parameter(torch.randn(128, cz)) + self.tri_mul_out_norm2_weight = torch.nn.Parameter( + torch.empty(c_tri_mult)) + self.tri_mul_out_norm2_bias = torch.nn.Parameter( + torch.empty(c_tri_mult)) + self.tri_mul_out_proj1 = torch.nn.Parameter(torch.randn( + cz, c_tri_mult)) + self.tri_mul_out_proj2 = torch.nn.Parameter(torch.randn( + cz, c_tri_mult)) + self.tri_mul_out_proj3 = torch.nn.Parameter(torch.randn( + cz, c_tri_mult)) + self.tri_mul_out_proj4 = torch.nn.Parameter(torch.randn( + cz, c_tri_mult)) + self.tri_mul_out_proj5 = torch.nn.Parameter(torch.randn( + c_tri_mult, cz)) self.tri_mul_out_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) # Triangular multiplicative update using incoming edges self.tri_mul_in_norm1_weight = torch.nn.Parameter(torch.empty(cz)) self.tri_mul_in_norm1_bias = torch.nn.Parameter(torch.empty(cz)) - self.tri_mul_in_norm2_weight = torch.nn.Parameter(torch.empty(128)) - self.tri_mul_in_norm2_bias = torch.nn.Parameter(torch.empty(128)) - self.tri_mul_in_proj1 = torch.nn.Parameter(torch.randn(cz, 128)) - self.tri_mul_in_proj2 = torch.nn.Parameter(torch.randn(cz, 128)) - self.tri_mul_in_proj3 = torch.nn.Parameter(torch.randn(cz, 128)) - self.tri_mul_in_proj4 = torch.nn.Parameter(torch.randn(cz, 128)) - self.tri_mul_in_proj5 = torch.nn.Parameter(torch.randn(128, cz)) + self.tri_mul_in_norm2_weight = torch.nn.Parameter( + torch.empty(c_tri_mult)) + self.tri_mul_in_norm2_bias = torch.nn.Parameter( + torch.empty(c_tri_mult)) + self.tri_mul_in_proj1 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) + self.tri_mul_in_proj2 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) + self.tri_mul_in_proj3 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) + self.tri_mul_in_proj4 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) + self.tri_mul_in_proj5 = torch.nn.Parameter(torch.randn(c_tri_mult, cz)) self.tri_mul_in_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) # Triangular gated self-attention around starting node self.tri_att_start_norm = torch.nn.LayerNorm(cz) self.tri_att_start_gate_proj = torch.nn.Parameter( - torch.randn(cz, 4 * c)) + torch.randn(cz, pair_head * c)) self.tri_att_start_qkv_proj = torch.nn.Parameter( - torch.randn(cz, 3 * 4 * c)) - self.tri_att_start_out_proj = torch.nn.Parameter(torch.randn( - 4 * c, cz)) - self.tri_att_start_bias_proj = torch.nn.Parameter(torch.randn(cz, 4)) + torch.randn(cz, 3 * pair_head * c)) + self.tri_att_start_out_proj = torch.nn.Parameter( + torch.randn(pair_head * c, cz)) + self.tri_att_start_bias_proj = torch.nn.Parameter( + torch.randn(cz, pair_head)) # Triangular gated self-attention around ending node self.tri_att_end_norm = torch.nn.LayerNorm(cz) - self.tri_att_end_gate_proj = torch.nn.Parameter(torch.randn(cz, 4 * c)) + self.tri_att_end_gate_proj = torch.nn.Parameter( + torch.randn(cz, pair_head * c)) self.tri_att_end_qkv_proj = torch.nn.Parameter( - torch.randn(cz, 3 * 4 * c)) - self.tri_att_end_out_proj = torch.nn.Parameter(torch.randn(4 * c, cz)) - self.tri_att_end_bias_proj = torch.nn.Parameter(torch.randn(cz, 4)) + torch.randn(cz, 3 * pair_head * c)) + self.tri_att_end_out_proj = torch.nn.Parameter( + torch.randn(pair_head * c, cz)) + self.tri_att_end_bias_proj = torch.nn.Parameter( + torch.randn(cz, pair_head)) # Transition in the pair stack self.pair_transition_norm = torch.nn.LayerNorm(cz) - self.pair_transition_proj1 = torch.nn.Parameter(torch.randn( - cz, 4 * cz)) - self.pair_transition_proj2 = torch.nn.Parameter(torch.randn( - 4 * cz, cz)) + self.pair_transition_proj1 = torch.nn.Parameter( + torch.randn(cz, ff_mult * cz)) + self.pair_transition_proj2 = torch.nn.Parameter( + torch.randn(ff_mult * cz, cz)) def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): msa_repr = msa_repr + MSARowAttentionWithPairBias( self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, - self.head, self.c, self.scale) + self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) msa_repr = msa_repr + MSAAttention( self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - self.col_out_proj, None, self.head, self.c, self.scale) + self.col_out_proj, None, self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), @@ -261,18 +286,18 @@ def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): self.tri_mul_in_norm2_bias, self.tri_mul_in_proj1, self.tri_mul_in_proj2, self.tri_mul_in_proj3, self.tri_mul_in_proj4, self.tri_mul_in_proj5, - self.tri_mul_in_proj6, self.cz, True) + self.tri_mul_in_proj6, self.cz, False) pair_repr = pair_repr + TriangleAttentionNode( self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, - self.tri_att_start_bias_proj, 4, self.c, self.scale) + self.tri_att_start_bias_proj, self.pair_head, self.c, self.scale) pair_repr = pair_repr.transpose(-3, -2) pair_repr = pair_repr + TriangleAttentionNode( self.tri_att_end_norm(pair_repr), self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, - self.tri_att_end_bias_proj, 4, self.c, self.scale) + self.tri_att_end_bias_proj, self.pair_head, self.c, self.scale) pair_repr = pair_repr.transpose(-3, -2) pair_repr = pair_repr + PairTransition( @@ -283,8 +308,8 @@ def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): def test(): - bs, s, r, cm, cz, c, heads = 1, 128, 256, 256, 128, 32, 8 - model = Evoformer(s, cm, cz, c, heads) + bs, s, r, cm, cz = 1, 128, 256, 256, 128 + model = Evoformer(s, cm, cz) msa = torch.randn(bs, s, r, cm) pair = torch.randn(bs, r, r, cz) From 7590e5785380036e78adeca3dbe1a7b03886bb0a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Sep 2022 18:03:50 +0800 Subject: [PATCH 0982/1892] graph abstraction on hierarhical segment --- cube/__init__.py | 1 - cube/compiler.py | 55 +--- cube/graph/graph.py | 535 +++++++------------------------ cube/graph/parser/__init__.py | 2 +- cube/graph/parser/converter.py | 9 - cube/graph/segment.py | 562 +++++++++++++++++++++++++++++---- cube/ir/cten.py | 8 +- cube/ir/operator.py | 28 +- cube/ir/tensor.py | 134 +------- cube/logics/__init__.py | 0 cube/logics/dataloader.py | 30 -- cube/logics/model.py | 56 ---- cube/logics/pool.py | 50 --- cube/logics/translator.py | 100 ------ 14 files changed, 623 insertions(+), 947 deletions(-) delete mode 100644 cube/logics/__init__.py delete mode 100644 cube/logics/dataloader.py delete mode 100644 cube/logics/model.py delete mode 100644 cube/logics/pool.py delete mode 100644 cube/logics/translator.py diff --git a/cube/__init__.py b/cube/__init__.py index 534f909f..24ef255c 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,4 +1,3 @@ -from cube import logics from cube import runtime from cube.compiler import SemanticModel, compile diff --git a/cube/compiler.py b/cube/compiler.py index 33a9b6d6..2efa4b5a 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -5,15 +5,11 @@ import cube -from cube.graph import parser from cube.graph.gener.gen import IRAdapterGener from cube.graph.graph import IRGraph from cube.ir.operator import IRDataOperation from cube.graph.function.anchor import IRGraphAnchor -from cube.logics.pool import SchedulePool -from cube.logics.translator import LogicTranslator - from cube.execplan import ExecutionPlan from cube.execplan.planpass.fusion import DiffFusion from cube.execplan.planpass.grouping import Grouping @@ -23,47 +19,7 @@ from cube.profiler.timer import print_each_rank from cube.runtime.syndata import CubeDataLoader, SciLoopVariables - -class SemanticModel: - - def __init__(self, model: torch.nn.Module, input_shapes): - """ - Create semantic model based on AI Scientist description. - """ - dist = torch.distributed.is_initialized() - if (not dist) or (dist and torch.distributed.get_rank() == 0): - self.ir_graph = parser.convert_model( - model, input_shapes=input_shapes - ) - else: - self.ir_graph = None - self._loaded_module = None - - def get_graph(self): - return self.ir_graph - - def load_module(self, filename: str, load_content=True): - import importlib.util - spec = importlib.util.spec_from_file_location("GenModel", filename) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self._loaded_module = module.GenModel().cuda() - if load_content: - print_each_rank("> loading parameter content...") - # TODO: make hardcode ./fullmodel.pt programmable - self._loaded_module.load_attr_content('./fullmodel.pt') - - def get_gen_module(self): - return self._loaded_module - - def clear_module(self): - self._loaded_module = None - - def __call__(self, *args): - if self._loaded_module: - return self._loaded_module(*args) - else: - return self.ir_graph(*args) +from cube.program import Program, SemanticDataLoader, SemanticModel def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, @@ -105,7 +61,7 @@ def train_step(model, dataloader): PAS = (PAS,) model_graph = model.get_graph() - ir_dataloader = parser.convert_dataloader(dataloader) + ir_dataloader = SemanticDataLoader(dataloader) if torch.distributed.is_initialized(): # multiple device @@ -142,7 +98,6 @@ def decorator(fn: Callable) -> Callable: compile_start = time.time() - SchedulePool().clear() resource = cube.runtime.resource.EnvResource() # logic translator @@ -151,7 +106,7 @@ def decorator(fn: Callable) -> Callable: outputs = [] elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): outputs = [outputs] - graph = LogicTranslator.gen_logic_graph(outputs=outputs) + graph = Program().get_graph() if len(PAS) == 1: graph = PAS[0](graph, resource) @@ -167,8 +122,8 @@ def decorator(fn: Callable) -> Callable: # check assignment and remove anchor node for node in graph.nodes(): if isinstance(node, IRGraphAnchor) or isinstance(node.mirror, IRGraphAnchor): - graph.detach(node) - elif len(node.device) == 0: + continue + if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") # generate adapter diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5d3b9e9e..5c90507b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,14 +7,13 @@ will be inserted at scheduling time. """ -from contextlib import contextmanager -from typing import Union, Tuple, List, Optional, Dict, Set +from typing import Union, Tuple, List, Optional, Dict from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.ir.adapter import IRAdapter from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.dtype import IRDType, DTypeInferRule from cube.graph.function.function import Identity, MultiRef from cube.graph.segment import IRSegment @@ -22,62 +21,6 @@ from cube.algorithm.generics import GenericDistAlgo -class GraphIndex: - - def __init__(self, gidx: int, sidx: Optional[int]): - # inner-graph index - assert isinstance(gidx, int) - self.gidx = gidx - # inner-segment index - assert sidx is None or isinstance(sidx, int) - self.sidx: Optional[int] = sidx - - def __hash__(self) -> int: - return hash((self.gidx, self.sidx)) - - def __eq__(self, other: object) -> bool: - assert isinstance(other, GraphIndex), "Cannot compare with non-GraphIndex object" - return self.gidx == other.gidx and self.sidx == other.gidx - - def __lt__(self, other: object) -> bool: - assert isinstance(other, GraphIndex), "Cannot compare with non-GraphIndex object" - if self.gidx < other.gidx: - return True - if self.gidx > other.gidx: - return False - if isinstance(self.sidx, int) and isinstance(other.sidx, int): - return self.sidx < other.sidx - if self.sidx is None and isinstance(other.sidx, int): - return True - return False - - def __le__(self, other: object) -> bool: - return self < other or self == other - - def __gt__(self, other: object) -> bool: - return not self <= other - - def __ge__(self, other: object) -> bool: - return not self < other - - def __sub__(self, offset: int): - assert isinstance(offset, int) - if self.sidx is None: - return GraphIndex(self.gidx - offset, self.sidx) - else: - return GraphIndex(self.gidx, self.sidx - offset) - - def __add__(self, offset: int): - assert isinstance(offset, int) - if self.sidx is None: - return GraphIndex(self.gidx + offset, self.sidx) - else: - return GraphIndex(self.gidx, self.sidx + offset) - - def tuple(self) -> Tuple[int, Optional[int]]: - return (self.gidx, self.sidx) - - class IRGraph(IRSegment): """ IRGraph. @@ -85,37 +28,13 @@ class IRGraph(IRSegment): IRGraph is used for reprensting a distributed training iteration. """ - def __init__(self, - nodes: List[IRCell], - inputs: Optional[List[IRTensor]], - outputs: Optional[List[IRTensor]], + def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRTensor], module_name: str): - if inputs is None: - inputs = IRGraph.get_inputs(nodes) - if outputs is None: - outputs = IRGraph.get_outputs(nodes) - super().__init__([], inputs, outputs, module_name) - - # atrribute tensors - self._attributes: Set[IRSubTensor] = set() + super().__init__(nodes, inputs, outputs, module_name) self._sched = None # the schedule strategy - # set parameters / buffers - for node in nodes: - tensors = node.inputs() + node.outputs() - tensors = [t for t in tensors if isinstance(t, IRSubTensor)] - for t in tensors: - t.parent.clear_producer_consumer() - if t.is_attr(): - self._attributes.add(t) - - # insert node from nodes - for idx, node in enumerate(nodes): - self.insert(node, idx) - - self.reset_dependency() @property def train(self) -> bool: @@ -126,284 +45,98 @@ def train(self) -> bool: """ return self._have_forward and self._have_backward - def reset_dependency(self): - """ - Reset the node dataflow dependency - - Note all the predefined control dependencies will be removed. - """ - for node in self._nodes: - node.clear_predecessor() - node.clear_successor() - # TODO: adapter dependency not set - for ftensor in self._full_tensors: - for ptensor, producer in zip(ftensor.ptensors, ftensor.producers): - for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): - if ptensor.overlap(ctensor): - pidx = producer.outputs().index(ptensor) - cidx = consumer.inputs().index(ctensor) - producer.add_successor(pidx, consumer) - consumer.add_predecessor(cidx, producer) - # set mirror as control dependency - if producer.mirror and isinstance(producer, IRFwOperation): - producer.add_successor(-1, producer.mirror) - producer.mirror.add_predecessor(-1, producer) - - def attributes(self) -> Tuple[IRSubTensor]: - """ - Return parameter list - """ - return tuple(self._attributes) - - def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: - """ - forward will divide the graph into Actions according to - node device assignment - - Currently each forward call will result in a new flow - even if the input is same - - Returns: - IRTensors - """ - from cube.logics.translator import LogicTranslator - return LogicTranslator.forward(self, *args) + # ================ Deep Learning Interfalce ====================== def __call__(self, *args): """ Register forward action """ return self.forward(*args) - - # ====================== Graph Accessment ========================= - def flatten(self) -> List[IRCell]: - """ - Get all the single nodes by opening the segment. - - @return List[] - """ - nodes = [] - for node in self._nodes: - if isinstance(node, IRSegment): - nodes += node._nodes - else: - nodes.append(node) - return nodes - - def index(self, node: IRCell) -> GraphIndex: - """ - Get node index in the graph. - - @param node IRCell: the queried node - - @return index Tuple[int, Optional[int]]: (GraphIndex, SegmentIndex) - - """ - if node in self._nodes: - return GraphIndex(self._nodes.index(node), None) - for idx, check_node in enumerate(self._nodes): - if isinstance(check_node, IRSegment): - if check_node.exist(node): - return GraphIndex(idx, check_node.index(node)) - raise KeyError(f"The queried node: {node} not in the graph.") - - def flatten_index(self, node: IRCell) -> int: - """ - Get node index of all the flatten nodes - - @param node IRCell: the queried node, cannot be IRSegment - - @return index int: the index. - """ - idx = 0 - for check_node in self._nodes: - if isinstance(check_node, IRSegment): - if node in check_node._nodes: - return idx + check_node.index(node) - else: - idx += len(check_node.nnodes) - if check_node == node: - return idx - raise KeyError(f"Node {node} not exist in graph") - - def node(self, index: Union[int, GraphIndex]) -> IRCell: - """ - Get node given the index - - @param index Tuple[Optional[int], int]: the queired index of - (SegmentIndex, Index) - - @return node IRCell: the quried node. - """ - index = GraphIndex(index, None) if isinstance(index, int) else index - assert isinstance(index, GraphIndex) - node = self._nodes[index.gidx] - if index.sidx is not None: - assert isinstance(node, IRSegment), "Expected IRSegment" - node = node.index(index.sidx) - return node - - # ========================= Graph Manipulation ======================== - - def remove(self, node: IRCell) -> GraphIndex: - """ - Detach (remove) a node from current graph. - TODO: dataflow dependency update. - - * Producer/consumer relationship: - - All the used input and output tensors inside the node - are removed from consumed and produced tensor list. - - @param node IRCell: the removed node. - - @return index Tuple[int, Optional[int]]: index of the detached node in the graph + def forward(self, *args: Tuple[IRSubTensor]) -> Union[IRTensor, Tuple[IRTensor]]: """ - index = self.index(node) - - # remove node - if index.sidx is None: - self._nodes.pop(index.gidx) - else: - segment = self._nodes[index.gidx] - assert isinstance(segment, IRSegment), "Internal Error: Removing at a wrong index" - segment.remove(node) - - # update consumer and producer for non-adapter nodes - rm_nodes = node.nodes() if isinstance(node, IRSegment) else [node] - for node in rm_nodes: - # adapter doesn't need to consider producer and consumer - if isinstance(node, IRAdapter): - continue - # update consumer - itensors: List[IRSubTensor] = [] - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor) and itensor not in itensors: - itensors.append(itensor) - for itensor in itensors: - itensor.parent.rm_consumer(node) - # update producer - otensors: List[IRSubTensor] = [] - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor) and otensor not in otensors: - otensors.append(otensor) - for otensor in otensors: - otensor.parent.rm_producer(node) - ftensor = otensor.parent - if len(ftensor.producers) == 0 and len(ftensor.consumers) == 0: - del self._full_tensors[otensor.parent] - return index - - def insert(self, node: IRCell, index: Union[int, GraphIndex]): - """ - Insert a node into current graph at node index. - TODO: dataflow dependency update. + forward will divide the graph into Actions according to + node device assignment - * Producer/consumer relationship: + Currently each forward call will result in a new flow + even if the input is same - For the node except IRAdapter, all its input and output tensors - will be recorded in consumed and produced tensor list. + @param args Tuple[Any] + + @return outputs Union[IRSubTensor, Tuple[IRSubTensor]] + """ + # align graph with input tensors + itensors: Tuple[IRSubTensor, ...] = self.inputs() + assert len(args) == len(itensors) + for idx, (itensor, arg) in enumerate(zip(itensors, args)): + self.set_input(idx, arg) + for producer in self.producers(itensor.parent): + with self.update(producer): + while itensor in producer.outputs(): + oidx = producer.outputs().index(itensor) + producer.set_output(oidx, arg) + for consumer in self.consumers(itensor.parent): + with self.update(consumer): + while itensor in consumer.inputs(): + iidx = consumer.inputs().index(itensor) + consumer.set_input(iidx, arg) + while itensor in self.outputs(): + oidx = self.outputs().index(itensor) + self.set_output(oidx, arg) + while itensor in self.inputs(): + iidx = self.inputs().index(itensor) + self.set_input(iidx, arg) - IRAdapter node will not record the consumer and producer. - - @param node IRCell: the inserted node - @param index Union[int, Tuple[int, Optional[int]]]: the inserted index - """ - index = GraphIndex(index, None) if isinstance(index, int) else index - assert isinstance(index, GraphIndex) - - # update producer and consumer - in_nodes = node.nodes() if isinstance(node, IRSegment) else [node] - for node in in_nodes: - if isinstance(node, IRAdapter): continue - # update consumer - itensors: List[IRSubTensor] = [] - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor) and itensor not in itensors: - itensors.append(itensor) + # dtype inference + for node in self._nodes: + itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] + # setup gradient for itensor in itensors: - self._full_tensors.add(itensor.parent) - idx = 0 - for consumer in itensor.parent.consumers: - if self.index(consumer) < index: - idx += 1 - else: - break - itensor.parent.add_consumer(node, itensor, idx) - # update producer - otensors: List[IRSubTensor] = [] - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor) and otensor not in otensors: - otensors.append(otensor) - for otensor in otensors: - self._full_tensors.add(otensor.parent) - idx = 0 - for producer in otensor.parent.producers: - if self.index(producer) < index: - idx += 1 - else: - break - otensor.parent.add_producer(node, otensor, idx) - - # insert node - if index.sidx is None: - self._nodes.insert(index.gidx, node) + if itensor.parent.grad is not None: + itensor.parent.dtype = itensor.dtype + if len(itensors) == 0: continue + odtype = DTypeInferRule.infer(node, [t.dtype for t in itensors]) + assert odtype != IRDType.unknown, f"{node} : {[t.dtype for t in itensors]}" + otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + for tensor in otensors: + tensor.dtype = odtype + # setup graidient + if tensor.parent.grad is not None: + tensor.parent.grad.dtype = odtype + + from cube.program import Program + Program().add_nodes(self.nodes()) + + # return + if len(self.outputs()) == 1: + return self.output(0) else: - segment = self._nodes[index.gidx] - assert isinstance(segment, IRSegment), "Expected to be a segment" - segment.insert(node, index.sidx) + return self.outputs() - return - - def exist(self, node: IRCell) -> bool: + def backward(self, loss: IRSubTensor): """ - Check if the node is in the graph + Backward the graph from the entry tensor of loss. - @param node IRCell: the queried node - @return existence bool: True if exist otherwise False - """ - if node in self._nodes: - return True - for segment in self._nodes: - if not isinstance(segment, IRSegment): - continue - if segment.exist(node): - return True - return False + @param loss IRSubTensor: the loss tensor, must be in the output + of current graph. The loss shape should be (1,) - @contextmanager - def update(self, node): + @return self IRGraph: None """ - Update a node. - TODO: update operator dependency + assert loss in self.outputs() and tuple(loss.shape) == (1,), \ + f"backward should be in graph outputs and the loss is of shape [1,] (got {loss.shape})" + from cube.program import Program + loss.parent.grad = 1.0 + for fnode in self.nodes()[::-1]: + assert not isinstance(fnode, IRSegment), "Internal Error: Segment should not appear for now" + if isinstance(fnode, IRFwOperation): + bnode: IRBpOperation = self.bwop(fnode) + Program().add_node(bnode) + # set program graph mirror to self + Program().mirror_as_self() + return self - e.g., - with graph.modify(node) as node: - node.set_input(0, tensor) - - @param node IRCell: the node that must in the graph - @return node IRCell: the modify node - """ - index = self.remove(node) - yield node - self.insert(node, index) - def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: - """ - Replace one node by multiple nodes - - TODO: update dataflow dependency - - @param node IRCell: the replaced node - @param new_nodes List[IRCell]: the nodes to be inserted. - - @return index int: the replaced node index - """ - index = self.remove(node) - for new_node in new_nodes[::-1]: - self.insert(new_node, index) - return index + # ========================= Graph Manipulation ======================== def group(self, fnodes: List[IRCell]) -> IRSegment: """! @@ -416,84 +149,49 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: @return segment IRSegment: the grouped segment """ - assert any(not isinstance(node, (IRBpOperation, IRSegment, IRDataOperation)) for node in fnodes), \ + assert any(not isinstance(node, (IRBpOperation, IRDataOperation)) for node in fnodes), \ "grouped nodes cannot be backward operation, segment or data operation" - findices = [self.index(fnode) for fnode in fnodes] + fgraphs = [self.segment(fnode) for fnode in fnodes] + assert len(set(fgraphs)) == 1, "Cross-segment grouping is not allowed yet." # get backward nodes bnodes = [fnode.mirror for fnode in fnodes[::-1] if fnode.mirror is not None] - bindices = [self.index(bnode) for bnode in bnodes] + + fgraph: IRSegment = fgraphs[0] + bgraph: IRSegment = fgraph.mirror - assert all(idx.sidx is None for idx in findices), \ - "Grouping operators that are already in segment is not allowed" - assert all(idx.sidx is None for idx in bindices), \ - "Internal Error: backward operators found in segments" - findices = tuple(idx.gidx for idx in findices) - bindices = tuple(idx.gidx for idx in bindices) + findices: Tuple[int] = tuple(fgraph.index(fnode)[0] for fnode in fnodes) + bindices: Tuple[int] = tuple(bgraph.index(bnode)[0] for bnode in bnodes) minfidx, maxfidx = min(findices), max(findices) assert maxfidx - minfidx + 1 == len(fnodes), \ "Forward nodes are not consecutive" - + if len(bnodes) > 0: minbidx, maxbidx = min(bindices), max(bindices) assert maxbidx - minbidx + 1 == len(bnodes), \ f"Internal Error: backward nodes are not consecutive. maxbidx: {maxbidx}, minbidx: {minbidx}" - fsegment = self.segment(fnodes) - bsegment = self.segment(bnodes) if len(bnodes) > 0 else None + fsegment = fgraph.create_segment(fnodes) + bsegment = bgraph.create_segment(bnodes) if len(bnodes) > 0 else None IRCell.make_pair(fsegment, bsegment) + # replace forward + for fnode in fnodes: + fidx = fgraph.remove(fnode) + fgraph.insert(fsegment, fidx) + # replace backward if len(bnodes) > 0: - self._nodes = self._nodes[:minbidx] + [bsegment] + self._nodes[maxbidx+1:] - # replace forward - self._nodes = self._nodes[:minfidx] + [fsegment] + self._nodes[maxfidx+1:] + for bnode in bnodes: + bidx = bgraph.remove(bnode) + bgraph.insert(bsegment, bidx) return fsegment # ========================== Graph Creation ======================== - def segment(self, nodes: List[IRCell]) -> IRSegment: - """! - Create a segment with part of the nodes. - - @param nodes List[IRCell]: the subset nodes of this graph - - @return segment IRSegment: the grouped segment. - """ - inputs, outputs = [], [] - itdevs, otdevs = dict(), dict() - for node in nodes: - assert not isinstance(node, IRSegment), 'A segment cannot be in other segments' - # update inputs - itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] - for itensor in itensors: - producers = [p for p in itensor.parent.producers if set(p.device).issubset(set(node.device))] - # no producer means a weight or cross device-group - if len(producers) == 0 or any(p not in nodes for p in producers): - if itensor not in itdevs: - itdevs[itensor] = [] - devs = set(itensor.device) - if devs not in itdevs[itensor]: - inputs.append(itensor) - itdevs[itensor].append(devs) - # update outputs - otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] - for otensor in otensors: - consumers = [c for c in otensor.parent.consumers if set(c.device).issubset(set(node.device))] - # no consumer usually means the loss or cross device-group - if otensor in self.outputs() or len(consumers) == 0 or any(c not in nodes for c in consumers): - devs = set(otensor.device) - if otensor not in otdevs: - otdevs[otensor] = [] - if devs not in otdevs[otensor]: - outputs.append(otensor) - otdevs[otensor].append(devs) - segment = IRSegment(nodes, inputs, outputs) - return segment - @staticmethod def from_logic_graph(nodes: List[IRCell], inputs: List[IRFullTensor], outputs: List[IRFullTensor], @@ -609,29 +307,25 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") - if node not in self.nodes(): - raise RuntimeError(f"Op {node} not exsits") - + fsegment: IRSegment = self.segment(node) + # replicate fnodes = [node.replicate() for _ in range(times)] - # insert forward - self.replace(node, fnodes) for fnode in fnodes: if isinstance(node, IRFwOperation): fnode.recompute = node.recompute if isinstance(node.comment, str): fnode.comment = node.comment fnode.device = node.device - + fsegment.replace(node, fnodes) # insert backward + bsegment: IRSegment = fsegment.mirror if isinstance(node.mirror, IRBpOperation): bnode: IRBpOperation = node.mirror - for fnode in fnodes: - fnode.gen_backward() - bnodes = [fnode.mirror for fnode in fnodes[::-1]] - self.replace(bnode, bnodes) + bnodes = tuple(self.bwop(fnode) for fnode in fnodes[::-1]) for bnode in bnodes: bnode.device = node.device + bsegment.replace(bnode, bnodes) return fnodes def partition(self, node: Union[IRFwOperation, IRDataOperation], @@ -664,34 +358,33 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], "The partition algorithm is not initialized for this node" assert isinstance(node, (IRFwOperation, IRDataOperation)), \ f"Only allow op to be forward op or data op, but got: {node}" - + + fsegment: IRSegment = self.segment(node) # get partitioned sub-nodes fnodes = algo.instantiate(**config) - if fnodes is None: - return None - + assert fnodes is not None, f"Fail to partition node: {node} use algothim and config: {config}" # update forward - self.replace(node, fnodes) for fnode in fnodes: if isinstance(node, IRFwOperation): fnode.recompute = node.recompute if isinstance(node.comment, str): fnode.comment = node.comment fnode.device = node.device + fsegment.replace(node, fnodes) # update backward + bsegment: IRSegment = fsegment.mirror if isinstance(node.mirror, IRBpOperation): - bnodes = [fnode.gen_backward() for fnode in fnodes[::-1]] - self.replace(node.mirror, bnodes) + bnodes = tuple(self.bwop(fnode) for fnode in fnodes[::-1]) + bsegment.replace(node.mirror, bnodes) for bnode in bnodes: bnode.device = node.device # update gradient updated = set() for itensor in [t for t in node.inputs() if isinstance(t, IRSubTensor)]: - for fnode in itensor.parent.consumers: + for fnode in fsegment.consumers(itensor.parent): bnode: IRBpOperation = fnode.mirror if isinstance(bnode, IRBpOperation) and fnode.cid not in updated: - with self.update(bnode): - bnode.update() + self.update_bwop(bnode) updated.add(fnode.cid) return fnodes @@ -969,16 +662,16 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: # the gradient flow of neighbor stages is automatically guaranteed for ftensor in self.full_tensors(): if ftensor.is_grad() or ftensor.is_attr(): continue - assert len(ftensor.producers) <= 1, \ + assert len(self.producers(ftensor)) <= 1, \ "The staging interface should be called before any operator partition." - if len(ftensor.consumers) == 0: continue - producer, ptensor = ftensor.producers[0], ftensor.ptensors[0] + if len(self.consumers(ftensor)) == 0: continue + producer, ptensor = self.producers(ftensor), self.consumers(ftensor) psid = get_sid(producer) # outside of stages, not consider if psid is None: continue out = ptensor curr_sid = psid - for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): assert ctensor == ptensor, "The staging interface should be called before any operator partition." csid = get_sid(consumer) if curr_sid == csid: continue diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index cd7b3a25..5d2c539f 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,3 +1,3 @@ from cube.graph.parser.parser import ScriptModuleParser -from cube.graph.parser.converter import convert_model, convert_dataloader +from cube.graph.parser.converter import convert_model from cube.graph.parser.register import register \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index bde06b93..37f034de 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -3,7 +3,6 @@ from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser from cube.graph import IRGraph -from cube.logics.dataloader import IRDataLoader import torch @@ -25,11 +24,3 @@ def convert_model(model: torch.nn.Module, graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) return graph - -def convert_dataloader(dataloader) -> IRDataLoader: - """ - convert pytorch dataloader into IRDataLoader - """ - from cube.graph.parser.mapping import DType2IRDType - dataloader = IRDataLoader(dataloader, dtype_map=DType2IRDType) - return dataloader diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 5cc3ab60..28ccd02f 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1,5 +1,5 @@ -from typing import Union, List, Optional, Set -import copy +from contextlib import contextmanager +from typing import Dict, Union, List, Optional, Set, Tuple from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.cten import IRTensor, IRCell @@ -7,6 +7,60 @@ from cube.ir.adapter import IRAdapter +class CellPosition: + + def __init__(self, indices: Tuple[int]): + assert all(isinstance(idx, int) for idx in indices) and len(indices) > 0 + self.indices = tuple(indices) + + def __hash__(self) -> int: + return hash(self.indices) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, CellPosition), "Cannot compare with non-GraphIndex object" + return self.indices == other.indices + + def __lt__(self, other: object) -> bool: + assert isinstance(other, CellPosition), "Cannot compare with non-GraphIndex object" + if len(self.indices) < len(other.indices): + return True + if len(self.indices) > len(other.indices): + return False + for lidx, ridx in zip(self.indices, other.indices): + if lidx >= ridx: + return False + return True + + def __le__(self, other: object) -> bool: + return self < other or self == other + + def __gt__(self, other: object) -> bool: + return not self <= other + + def __ge__(self, other: object) -> bool: + return not self < other + + def __sub__(self, offset: int): + assert isinstance(offset, int) + indices = list(self.indices) + indices[-1] -= offset + return CellPosition(indices) + + def __add__(self, offset: int): + assert isinstance(offset, int) + indices = list(self.indices) + indices[-1] += offset + return CellPosition(indices) + + def __getitem__(self, idx: int) -> int: + return self.indices[idx] + + def __len__(self) -> int: + return len(self.indices) + + def __repr__(self) -> str: + return repr(self.indices) + class IRSegment(IRCell): """ @@ -19,24 +73,31 @@ class IRSegment(IRCell): def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRSubTensor], name='segment'): super().__init__(name, '', len(inputs), len(outputs), init_outputs=False) - self._nodes: List[IRCell] = nodes + self._nodes: List[IRCell] = [] self._idevice = [t.device for t in inputs] self._odevice = [t.device for t in outputs] - self._inputs = list(inputs) - self._outputs = list(outputs) - # for idx, val in enumerate(inputs): - # self.set_input(idx, val) - # for idx, val in enumerate(outputs): - # self.set_output(idx, val) + for idx, val in enumerate(inputs): + self.set_input(idx, val) + for idx, val in enumerate(outputs): + self.set_output(idx, val) + + # full-tensor / sub-tensor mapping + self._ftensors: Set[IRFullTensor] = set() + self._producers: Dict[IRFullTensor, List[IRCell]] = dict() + self._consumers: Dict[IRFullTensor, List[IRCell]] = dict() + self._ptensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() + self._ctensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() + + # attributes + self._attributes: Set[IRFullTensor] = set() - # full tensors - self._full_tensors: Set[IRFullTensor] = set() for node in nodes: - for tensor in node.inputs() + node.outputs(): - if isinstance(tensor, IRSubTensor): - self._full_tensors.add(tensor.parent) + self.insert(node, self.nnodes) + # self.reset_dependency() + + # FIXME: update when manipulating self._have_forward = any(isinstance(n, IRFwOperation) for n in nodes) self._have_backward = any(isinstance(n, IRBpOperation) for n in nodes) @@ -44,11 +105,48 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR def forward(self) -> bool: return self._have_forward - def full_tensors(self) -> List[IRFullTensor]: + def full_tensors(self) -> Tuple[IRFullTensor]: + """ + Get all full tensors of this graph. + Note the full tensor inside the node will not be returned. + + @return ftensors List[IRFullTensor] + """ + return tuple(self._ftensors) + + def attributes(self) -> Tuple[IRFullTensor]: + """ + Get al full tensor attributes of this graph + Note the full tensor inside the node will not be returned. + + @return ftensors List[IRFullTensor] + """ + return Tuple(self._attributes) + + def reset_dependency(self): """ - Return full tensor list + Reset the node dataflow dependency + + FIXME + + Note all the predefined control dependencies will be removed. """ - return list(self._full_tensors) + for node in self._nodes: + node.clear_predecessor() + node.clear_successor() + # TODO: adapter dependency not set + for ftensor in self._ftensors: + for ptensor, producer in zip(ftensor.ptensors, ftensor.producers): + for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + if ptensor.overlap(ctensor): + pidx = producer.outputs().index(ptensor) + cidx = consumer.inputs().index(ctensor) + producer.add_successor(pidx, consumer) + consumer.add_predecessor(cidx, producer) + # set mirror as control dependency + if producer.mirror and isinstance(producer, IRFwOperation): + producer.add_successor(-1, producer.mirror) + producer.mirror.add_predecessor(-1, producer) # ========================= Basic Graph access ======================= @@ -70,28 +168,41 @@ def nnodes(self) -> int: """ return len(self._nodes) - def nodes(self, idx: Optional[int] = None) -> Union[IRCell, List[IRCell]]: + def nodes(self, flatten = False) -> Tuple[IRCell]: """ Get all the nodes. + @param flatten bool: Flat the segment to get all the nested cells + @return nodes List[IRCell]: all the nodes """ - if isinstance(idx, int): - return self._nodes[idx] - else: - return copy.copy(self._nodes) + if not flatten: + return tuple(self._nodes) + nodes = [] + for node in self._nodes: + if not isinstance(node, IRSegment): + nodes.append(node) + else: + nodes += list(node.nodes(flatten)) + return tuple(nodes) - def node(self, index: int) -> IRCell: + def node(self, index: Union[int, CellPosition]) -> IRCell: """ Get node at position index - @param index int: the node index + @param index Union[int, CellPosition]: the node index @return node IRCell: the node. """ - return self._nodes[index] - - def index(self, node: IRCell) -> int: + pos = CellPosition((index,)) if isinstance(index, int) else index + assert isinstance(pos, CellPosition), "Expect index to be int or CellPosition" + node = self + for idx in pos.indices: + assert isinstance(node, IRSegment), "idx applies on a non-segment node" + node = node._nodes[idx] + return node + + def index(self, node: IRCell) -> CellPosition: """ Get node index. @@ -99,24 +210,228 @@ def index(self, node: IRCell) -> int: @return index int: the index """ - return self._nodes.index(node) + if node in self._nodes: + return CellPosition((self._nodes.index(node),)) + for idx, segment in enumerate(self._nodes): + if isinstance(segment, IRSegment): + if segment.exist(node): + index = segment.index(node) + return CellPosition((idx,) + index.indices) + raise KeyError(f"The queried node: {node} not in the graph") + + def segment(self, node: IRCell) -> IRCell: + """ + Get the lowest segment that constains the node + + @param node IRCell: the queried node + + @return segment IRSegment + """ + assert isinstance(node, IRCell) + index = self.index(node) + if len(index) == 1: + return self + else: + return self.node(CellPosition(index.indices[:-1])) + + def producers(self, ftensor: IRFullTensor) -> Tuple[IRCell]: + """ + Get producers of ftensor in execution order in this graph + + @param ftensor IRFullTensor: the queried full tensor. + + @return subtensors Tuple[IRSubTensor]: the producers. + """ + assert ftensor in self._producers, f"{ftensor} is not in the graph" + return tuple(self._producers[ftensor]) + + def consumers(self, ftensor: IRFullTensor) -> Tuple[IRCell]: + """ + Get consumers of ftensor in execution order in this graph + + @param ftensor IRFullTensor: the queried full tensor. + + @return subtensors Tuple[IRCell]: theconsumers + """ + assert ftensor in self._consumers, f"{ftensor} is not in the graph" + return tuple(self._consumers[ftensor]) + + def ptensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: + """ + Get consumed sub-tensors of ftensor in execution order in this graph + + @param ftensor IRFullTensor: the queried full tensor. + + @return subtensors Tuple[IRSubTensor]: the consumed subtensors. + """ + assert ftensor in self._ptensors, f"{ftensor} is not in the graph" + return tuple(self._ptensors[ftensor]) + + def ctensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: + """ + Get consumed sub-tensors of ftensor in execution order in this graph + + @param ftensor IRFullTensor: the queried full tensor. + + @return subtensors Tuple[IRSubTensor]: the consumed subtensors. + """ + assert ftensor in self._ctensors, f"{ftensor} is not in the graph" + return tuple(self._ctensors[ftensor]) + + def grad(self, tensor: IRSubTensor) -> IRSubTensor: + """ + Get gradient of the tensor. + + @param tensor IRSubTensor: IRSubTensor: the queried tensor + + @return gradient IRSubTensor: the gradient + """ + segment: IRSegment = self.segment(tensor.cell) + assert isinstance(tensor, IRSubTensor), "Only tensor has gradient" + fgrad = tensor.parent.grad + # None means no gradient requirement, flaot means its the loss + if fgrad is None or isinstance(fgrad, float): + return fgrad + ftensor = tensor.parent + # this tensor is consumed + if tensor in tensor.cell.inputs(): + consumers = [] + for ctensor, consumer in zip(segment.ctensors(ftensor), segment.consumers(ftensor)): + assert not (ctensor != tensor and ctensor.overlap(tensor)), "parital overlap is not supported for gradient" + if ctensor == tensor and consumer not in consumers: + consumers.append(consumer) + # segment.debug_print_tensor_map() + valmap = (consumers.index(tensor.cell), len(consumers)) + grad = ftensor.grad.select( + indmap = tensor.indmap, + valmap = valmap + ) + # this tensor is produced + elif tensor in tensor.cell.outputs(): + grad = ftensor.grad.select( + indmap = tensor.indmap, + valmap = (0, 1), + ) + return grad + + def debug_print_tensor_map(self): + for ftensor in self._ftensors: + print(f'Full Tensor: {ftensor}') + print(f'Producers:') + for producer in self._producers[ftensor]: + print(f'\t{producer}') + print(f'Consumers:') + for producer in self._consumers[ftensor]: + print(f'\t{producer}') + + def bwop(self, fwop: IRFwOperation) -> IRBpOperation: + """ + Create dummy backward operator for given forward operator + """ + assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" + fsegment: IRSegment = self.segment(fwop) + igrads = [fsegment.grad(t) if isinstance(t, IRSubTensor) else None for t in fwop.inputs()] + ograds = [fsegment.grad(t) if isinstance(t, IRSubTensor) else None for t in fwop.outputs()] + bwop = IRBpOperation(ograds, igrads) + IRCell.make_pair(fwop, bwop) + return bwop + + def update_bwop(self, bwop: IRBpOperation) -> IRBpOperation: + """ + Update backward operator. + + This is neccessary when fwop is partitioned and reference count is changed. + + @param bwop IRBpOperation: the backward operation. + It can be at any hierarchy of this segemtn + + @return bwop IRBpOperation: the updated operation (inplace) + """ + bsegment: IRSegment = self.segment(bwop) + fsegment = bsegment.mirror + with bsegment.update(bwop): + fwop: IRFwOperation = bwop.mirror + igrads = [fsegment.grad(t) if isinstance(t, IRSubTensor) else None for t in fwop.inputs()] + for idx, igrad in enumerate(igrads): + bwop.set_output(idx, igrad) + ograds = [fsegment.grad(t) if isinstance(t, IRSubTensor) else None for t in fwop.outputs()] + for idx, ograd in enumerate(ograds): + bwop.set_input(idx, ograd) + return bwop # ====================== Basic Graph manipulations ====================== - def insert(self, node: IRCell, index: int): + def _add_ftensor(self, ftensor: IRFullTensor): + """ + Add a full tensor in segment if the segment doesn't have the tensor. + """ + assert isinstance(ftensor, IRFullTensor) + if ftensor not in self._ftensors: + self._ftensors.add(ftensor) + self._producers[ftensor] = [] + self._consumers[ftensor] = [] + self._ptensors[ftensor] = [] + self._ctensors[ftensor] = [] + if ftensor.is_attr(): + self._attributes.add(ftensor) + + def _remove_ftensor(self, ftensor: IRFullTensor): + """ + Remove a full tensor in segment + """ + assert isinstance(ftensor, IRFullTensor) + if ftensor in self._ftensors: + self._ftensors.remove(ftensor) + del self._producers[ftensor] + del self._consumers[ftensor] + del self._ptensors[ftensor] + del self._ctensors[ftensor] + if ftensor.is_attr() and ftensor in self._attributes: + self._attributes.remove(ftensor) + + def insert(self, node: IRCell, index: Union[int, CellPosition]): """ Insert a node at index. - TODO: check input and output + TODO: dataflow dependency update + TODO: input / output check @param node IRCell: the inserted node @param index int: the index """ - assert node not in self._nodes, f"duplicated insertation of node: {node}" - self._nodes.insert(index, node) + pos = CellPosition((index,)) if isinstance(index, int) else index + assert isinstance(pos, CellPosition), "Expect index to be int or CellPosition" + + if len(pos) == 1: + index = pos[0] + # insert node + self._nodes.insert(index, node) + # update producer and consumer + if isinstance(node, IRAdapter): return + # consumer + itensors = set(t for t in node.inputs() if isinstance(t, IRSubTensor)) + for itensor in itensors: + ftensor = itensor.parent + self._add_ftensor(ftensor) + idx = len([c for c in self._consumers[ftensor] if self._nodes.index(c) < index]) + self._consumers[ftensor].insert(idx, node) + self._ctensors[ftensor].insert(idx, itensor) + # producer + otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) + for otensor in otensors: + ftensor = otensor.parent + self._add_ftensor(ftensor) + idx = len([c for c in self._producers[ftensor] if self._nodes.index(c) < index]) + self._producers[ftensor].insert(idx, node) + self._ptensors[ftensor].insert(idx, otensor) + else: + segment = self._nodes[pos[0]] + assert isinstance(segment, IRSegment), "Expected IRSegment" + pos = CellPosition(pos.indices[1:]) + segment.insert(node, pos) - def remove(self, node: IRCell) -> int: + def remove(self, node: IRCell, _pos: CellPosition = None) -> CellPosition: """ Remove a node at index @@ -124,12 +439,41 @@ def remove(self, node: IRCell) -> int: @param node IRCell: the removed node - @return index int: the removed index + @return index CellPosition: the removed index """ - assert node in self._nodes, f"The removed node doesn't exist" - index = self._nodes.index(node) - self._nodes.pop(index) - return index + pos = self.index(node) if _pos is None else _pos + assert self.node(pos) == node, "posititon doesn't not match with node" + + if len(pos.indices) == 1: + index = pos[0] + # remove + self._nodes.pop(index) + # update producer and consumer + if isinstance(node, IRAdapter): return pos + # consumer + itensors = set(t for t in node.inputs() if isinstance(t, IRSubTensor)) + for itensor in itensors: + ftensor = itensor.parent + idx = self._consumers[ftensor].index(node) + self._consumers[ftensor].pop(idx) + self._ctensors[ftensor].pop(idx) + if len(self._consumers[ftensor]) == 0 and len(self._producers[ftensor]) == 0: + self._remove_ftensor(ftensor) + # producer + otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) + for otensor in otensors: + ftensor = otensor.parent + idx = self._producers[ftensor].index(node) + self._producers[ftensor].pop(idx) + self._ptensors[ftensor].pop(idx) + if len(self._consumers[ftensor]) == 0 and len(self._producers[ftensor]) == 0: + self._remove_ftensor(ftensor) + else: + segment = self._nodes[pos[0]] + assert isinstance(segment, IRSegment) + segment.remove(node, _pos=CellPosition(pos.indices[1:])) + + return pos def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: """ @@ -143,9 +487,27 @@ def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: @return index int: the replaced node index """ idx = self.remove(node) - self._nodes = self._nodes[:idx] + list(new_nodes) + self._nodes[idx:] + for new_node in new_nodes[::-1]: + self.insert(new_node, idx) return idx + @contextmanager + def update(self, node): + """ + Update a node. + TODO: update operator dependency + + e.g., + with graph.modify(node) as node: + node.set_input(0, tensor) + + @param node IRCell: the node that must in the graph + @return node IRCell: the modify node + """ + index = self.remove(node) + yield node + self.insert(node, index) + def exist(self, node: IRCell) -> bool: """ Check if the node is in this graph @@ -154,7 +516,13 @@ def exist(self, node: IRCell) -> bool: @return exsit bool: True if exist otherwise False """ - return node in self._nodes + if node in self._nodes: + return True + for segment in self._nodes: + if not isinstance(segment, IRSegment): continue + if segment.exist(node): + return True + return False # ====================== Graph Generations ============================ @@ -206,51 +574,107 @@ def get_outputs(nodes: List[IRCell]): continue return outputs + def create_segment(self, nodes: List[IRCell]) -> IRCell: + """! + Create a segment with part of the nodes. + This only return the created segment wihout modifying the graph. + + @param nodes List[IRCell]: the subset nodes of this graph + + @return segment IRSegment: the grouped segment. + """ + segments: List[IRSegment] = [self.segment(node) for node in nodes] + assert len(set(segments)) == 1, "Cross segment hierarchy grouping is not allowed" + segment = segments[0] + + inputs, outputs = set(), set() + for node in nodes: + # update inputs + itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] + for itensor in itensors: + ftensor = itensor.parent + if itensor.is_attr(): continue + # from segment inputs + if any(t.overlap(itensor) for t in segment.inputs() if isinstance(t, IRSubTensor)): + inputs.add(itensor) + continue + # from outside producers + for ptensor, producer in zip(segment.ptensors(ftensor), segment.producers(ftensor)): + if ptensor.overlap(itensor) and producer not in nodes: + inputs.add(itensor) + continue + # update outputs + otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + for otensor in otensors: + ftensor = otensor.parent + if otensor.is_attr(): continue + # from segment outputs + if any(t.overlap(otensor) for t in segment.outputs() if isinstance(t, IRSubTensor)): + outputs.add(otensor) + continue + # for outside consumers + for ctensor, consumer in zip(segment.ctensors(ftensor), segment.consumers(ftensor)): + if ctensor.overlap(otensor) and consumer not in nodes: + outputs.add(otensor) + continue + segment = IRSegment(nodes, tuple(inputs), tuple(outputs)) + return segment + ###### ============ Transformation Primitives ============ ####### - def dispatch(self, devid: int, for_mirror=True) -> Optional[IRCell]: + def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: """ - Instantiate from distributed representation to a - device-specific sub-graph. - - The mirror will also be dispatched if it is not None. + Instantiate the segement to a specific device. + + @param devid int: the target device - Return the dispatched segment + @return segment IRSegment: the instantiated segment """ if devid not in self.device: return None if len(self.device) == 1 and self.device == [devid]: return self - itensors = [t for t, device in zip(self.inputs(), self._idevice) if devid in device] - otensors = [t for t, device in zip(self.outputs(), self._odevice) if devid in device] - nodes = [n for n in self.nodes() if devid in n.device] - for idx, adapter in enumerate(nodes): - if isinstance(adapter, IRAdapter): - nodes[idx] = adapter.dispatch(devid) - fseg = IRSegment(nodes, itensors, otensors) - fseg._id = self._id - # dispatch for mirror - if for_mirror and isinstance(self.mirror, IRSegment): - bseg = self.mirror.dispatch(devid, for_mirror=False) - IRCell.make_pair(fseg, bseg) - return fseg + inputs, outputs, nodes = [], [], [] + for node in self._nodes: + if devid in node.device: + if isinstance(node, IRAdapter): + nodes.append(node.dispatch(devid)) + elif isinstance(node, IRSegment): + nodes.append(node.dispatch(devid)) + else: + assert len(node.device) == 1 + nodes.append(node) + for itensor in node.inputs(): + if itensor in self._inputs: + inputs.append(itensor) + for otensor in node.outputs(): + if otensor in self._outputs: + otensor.append(otensor) + outputs.append(otensor) + segment = IRSegment(nodes, inputs, outputs, self.name) + if mirror and segment.mirror is not None: + msegment = segment.mirror.dispatch(devid, mirror=False) + IRCell.make_pair(segment, msegment) + return segment # ========================== Graph Visualize ================================ - def to_str(self, skip_attr: bool = False) -> str: - name = ('f' if self.forward else 'b') + 'Segment' - inputs = tuple(t for t in self.inputs() if not (t.is_attr() and skip_attr)) - outputs = tuple(t for t in self.outputs() if not (t.is_attr() and skip_attr)) - return f'{name}{self._id}-{self.device}(inputs={inputs}, outputs={outputs})' - def __repr__(self): - return self.to_str() + dscp = f"Graph{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" + return dscp def extra_repr(self) -> str: - dscp = repr(self) - for node in self.nodes(): - dscp += '\n\t' + repr(node) + dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" + # inputs + dscp += f"Inputs: {self.inputs()}\n" + for node in self._nodes: + dscp += f"\n{node}" + if isinstance(node, IRSegment): + for subnode in node.nodes(): + dscp += f"\n\t{subnode}" + # outputs + dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" return dscp diff --git a/cube/ir/cten.py b/cube/ir/cten.py index aa938450..d7b8d984 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -605,12 +605,14 @@ def nelement(self) -> int: cnt *= num return cnt - def backward(self): + def backward(self) -> IRCell: """ Autograd backward on the tensor + + @return graph IRGraph: the forward + backward graph """ - from cube.logics.translator import LogicTranslator - return LogicTranslator.backward(self) + return self.cell.backward(self) + def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device})' diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 9bd1c82d..bec66a4a 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Any import copy from cube.ir.cten import IRCell, IRTensor @@ -122,22 +122,6 @@ def replicate(self): cpy.clear_successor() return cpy - def gen_backward(self) -> IRCell: - """! - Generate backward operator for this forward operator. - - Note by calling this API, this forward operator must be - attached into any of one IRGraph, or will lead to reference - count 0 error on gradient calcaultion. - - return: IRBpOperation - """ - if self.mirror is not None: - raise RuntimeError( - "Backward Op already generated. Use self.mirror.update() instead.") - bnode = IRBpOperation(self) - return bnode - def __repr__(self) -> str: sign = self.signature.split('.')[-1] ins = [t for t in self.inputs() if isinstance(t, IRTensor) and not t.is_attr()] @@ -160,22 +144,16 @@ class IRBpOperation(IRCell): Backward operation """ - def __init__(self, fwop: IRFwOperation): + def __init__(self, ograds: Tuple[Any], igrads: Tuple[Any]): """ Create dummy backward node for forward inputs and forward outputs @param fwop IRFwOperation: forward operator """ - assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" - finputs, foutputs = fwop.inputs(), fwop.outputs() super().__init__( 'backward', 'torch.autograd.grad', - len(foutputs), len(finputs), init_outputs=False + len(ograds), len(igrads), init_outputs=False ) - # pair forward op and backward op - IRCell.make_pair(self, fwop) - # set inputs and outputs - self.update() def update(self): """ diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index a9d9359e..013d9812 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -254,14 +254,6 @@ def __init__(self, shape=None, name=None, requires_grad=True, dtype=IRDType.unkn super().__init__(shape, name, dtype) - # producer cell and produced sub tensor - self._producers: List[IRCell] = list() - self._ptensors : List[IRSubTensor] = list() - - # consumer cell and consumed sub tensor - self._consumers: List[IRCell] = list() - self._ctensors : List[IRSubTensor] = list() - # record all created sub_tensors self._segments : Dict[(ValueMap, IndexMap), int] = dict() @@ -288,87 +280,6 @@ def like(self): tensor = IRFullTensor(self.shape, self.name, self.requires_grad, self.dtype) return tensor - @property - def producers(self) -> Tuple[IRCell]: - """ - Producer IRCell list - """ - return tuple(self._producers) - - @property - def ptensors(self) -> Tuple[IRTensor]: - """ - Produced IRSubTensor list correspongding to producer IRCell - """ - return tuple(self._ptensors) - - @property - def consumers(self) -> Tuple[IRCell]: - """ - Consumer IRCell list - """ - return tuple(self._consumers) - - @property - def ctensors(self) -> Tuple[IRTensor]: - """ - Consumed IRSubTensor list correspongding to consumer IRCell - """ - return tuple(self._ctensors) - - def add_producer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): - if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): - raise TypeError("Expect an IRCell and an IRTensor") - assert cell not in self._producers, f"{cell} already exists as producer" - self._producers.insert(idx, cell) - self._ptensors.insert(idx, tensor) - - def add_consumer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): - """! - Add the tensor and its operator into consumer list. - The tensor should be in cell.inputs() - - @param cell IRCell: node to be consumer - @param tensor IRTensor: tensor to be consumed tensors - @param idx int: the index to be inserted - """ - assert tensor in cell.inputs(), f"tensor {tensor} not in node: {cell} inputs" - if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): - raise TypeError("Expect an IRCell and an IRTensor") - if cell in self._consumers: - for idx, consumer in enumerate(self._consumers): - if cell == consumer: - assert self._ctensors[idx] != tensor, f"double add a same consumer-tensor pair: {cell}" - self._consumers.insert(idx, cell) - self._ctensors.insert(idx, tensor) - for t in self._ctensors: - t._dirty_grad = True - - def rm_producer(self, cell: IRCell) -> int: - if cell not in self._producers: - raise KeyError(f"Cell {cell} not found in producer") - while cell in self._producers: - idx = self._producers.index(cell) - self._producers.pop(idx) - self._ptensors.pop(idx) - return idx - - def rm_consumer(self, cell: IRCell) -> int: - if cell not in self._consumers: - raise KeyError(f"Cell {cell} not found in producer") - idx = self._consumers.index(cell) - self._consumers.pop(idx) - self._ctensors.pop(idx) - for t in self._ctensors: - t._dirty_grad = True - return idx - - def clear_producer_consumer(self) -> int: - self._producers = [] - self._ptensors = [] - self._consumers = [] - self._ctensors = [] - @property def grad(self) -> Optional[Union[IRTensor, float]]: return self._grad @@ -383,8 +294,6 @@ def grad(self, val: Optional[Union[IRTensor, float]]): self._requires_grad = False if val is None else True if isinstance(val, IRFullTensor): assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." - for tensor in self._ctensors + self._ptensors: - tensor._dirty_grad = True @property def requires_grad(self): @@ -397,8 +306,6 @@ def requires_grad(self, val: bool): self.grad = IRFullTensor(self.shape, 'g' + self.name, False, dtype=self.dtype).as_grad() elif not val and self.grad is not None: self.grad = None - for tensor in self._ctensors + self._ptensors: - tensor._dirty_grad = True @property def dtype(self) -> IRDType: @@ -510,8 +417,6 @@ def __init__(self, ftensor: IRFullTensor, self._indmap: IndexMap = indmap # val map self._valmap: ValueMap = valmap - # grad flag - self._dirty_grad = True def __eq__(self, other) -> bool: if isinstance(other, IRSubTensor): @@ -669,8 +574,7 @@ def __copy__(self): tensor._cell = None return tensor - @property - def grad(self) -> Optional[Union[IRTensor, float]]: + def grad(self, graph: IRCell) -> Optional[IRTensor]: """ Get gradient of this tensor. @@ -683,41 +587,7 @@ def grad(self) -> Optional[Union[IRTensor, float]]: The gradient will be lazy updated when its IRFullTensor gets new consumed / produced tensors """ - if not self._dirty_grad: - return self._grad - - assert isinstance(self.cell, IRCell), "No cell attached to this tensor." - full_grad = self.parent.grad - if full_grad is None or isinstance(full_grad, float): - self._grad = full_grad - # this tensor is consumed - elif self in self.cell.inputs(): - # for backard, we assume in final distributed graph, - # each tensor can be represented as nested - consumers = [] - for ctensor, consumer in zip(self.parent.ctensors, self.parent.consumers): - if ctensor == self and consumer.cid not in consumers: - consumers.append(consumer.cid) - valmap = (consumers.index(self.cell.cid), len(consumers)) - grad = full_grad.select( - indmap = self.indmap, - valmap = valmap, - ) - self._grad = grad - self._dirty_grad = False - return grad - # this tensor is produced - elif self in self.cell.outputs(): - grad = full_grad.select( - indmap = self.indmap, - valmap = (0, 1), - ) - self._grad = grad - else: - raise RuntimeError("Visit gradient of a tensor that is potentially generated by IRAdapter") - self._dirty_grad = False - self._requires_grad = False if full_grad is None else True - return self._grad + return graph.grad(self) @property def requires_grad(self) -> bool: diff --git a/cube/logics/__init__.py b/cube/logics/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/logics/dataloader.py b/cube/logics/dataloader.py deleted file mode 100644 index 4ad72e1d..00000000 --- a/cube/logics/dataloader.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Tuple -from cube.runtime.syndata import CubeDataLoader - - -class IRDataLoader: - - def __init__(self, dataloader: CubeDataLoader, dtype_map): - if not isinstance(dataloader, CubeDataLoader): - raise TypeError("Expected data loader derived from CubeDataLoader") - self.dataloader: CubeDataLoader = iter(dataloader) - self.dtypes = [dtype_map.map(dtype) for dtype in dataloader.dtypes] - self.shapes = [list(shape) for shape in dataloader.shapes] - - def get_batch_dims(self) -> Tuple[int]: - return tuple(self.dataloader.batch_dims) - - def get_batch_size(self) -> int: - return self.dataloader.get_batch_size() - - def set_batch_size(self, bs: int): - self.dataloader.set_batch_size(bs) - return - - def __iter__(self): - return self - - def __next__(self): - from cube.logics.translator import LogicTranslator - datas = LogicTranslator.load_data(self) - return datas diff --git a/cube/logics/model.py b/cube/logics/model.py deleted file mode 100644 index a6dc1a9d..00000000 --- a/cube/logics/model.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Tuple, List -import copy - -from cube.graph.graph import IRGraph -from cube.ir.dtype import IRDType, DTypeInferRule -from cube.ir.tensor import IRSubTensor -from cube.ir.operator import IRFwOperation - - -def forward(graph: IRGraph, *args) -> IRGraph: - """ - Forward the IRGraph, replacing all the intermediate tensors - """ - if not isinstance(graph, IRGraph): - raise TypeError("Requires IRGraph for forward") - - # align graph with input tensors - itensors: Tuple[IRSubTensor, ...] = graph.inputs() - for idx, (itensor, arg) in enumerate(zip(itensors, args)): - graph.set_input(idx, arg) - for producer in copy.copy(itensor.parent.producers): - with graph.update(producer): - while itensor in producer.outputs(): - oidx = producer.outputs().index(itensor) - producer.set_output(oidx, arg) - for consumer in copy.copy(itensor.parent.consumers): - with graph.update(consumer): - while itensor in consumer.inputs(): - iidx = consumer.inputs().index(itensor) - consumer.set_input(iidx, arg) - while itensor in graph.outputs(): - oidx = graph.outputs().index(itensor) - graph.set_output(oidx, arg) - - # dtype inference - for node in graph.nodes(): - itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] - # setup gradient - for itensor in itensors: - if itensor.parent.grad is not None: - itensor.parent.dtype = itensor.dtype - if len(itensors) == 0: - continue - odtype = DTypeInferRule.infer(node, [t.dtype for t in itensors]) - assert odtype != IRDType.unknown, f"{node} : {[t.dtype for t in itensors]}" - otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] - for tensor in otensors: - tensor.dtype = odtype - # setup graidient - if tensor.parent.grad is not None: - tensor.parent.grad.dtype = odtype - - # generate backward reverse is only to make op id looks consecutive - for fnode in [n for n in graph.nodes() if isinstance(n, IRFwOperation)][::-1]: - fnode.gen_backward() - return graph diff --git a/cube/logics/pool.py b/cube/logics/pool.py deleted file mode 100644 index fd9f2045..00000000 --- a/cube/logics/pool.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import List, Any -import copy - - -class SchedulePool: - - class __SchedulePool: - - def __init__(self): - - self._nodes = list() - self._tapes = dict() - - instance = None - - def __init__(self): - if not SchedulePool.instance: - SchedulePool.instance = SchedulePool.__SchedulePool() - - def __getattr__(self, name): - return getattr(self.instance, name) - - def add_node(self, node): - self.instance._nodes.append(node) - - def nodes(self) -> List: - return copy.copy(self.instance._nodes) - - def tape(self, tensor, trace: Any): - """ - Record the trace generated to this tensor - """ - self.instance._tapes[tensor._id] = trace - - def get_tape(self, tensor): - """ - Get the trace given the tensor - """ - if tensor._id not in self.instance._tapes: - return None - else: - return self.instance._tapes[tensor._id] - - def clear(self): - self.instance._nodes = list() - self.instance._tapes = dict() - - def __repr__(self): - dscp = '\n'.join([repr(node) for node in self._nodes]) - return dscp diff --git a/cube/logics/translator.py b/cube/logics/translator.py deleted file mode 100644 index 5c2db40d..00000000 --- a/cube/logics/translator.py +++ /dev/null @@ -1,100 +0,0 @@ -from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor - -from cube.graph.graph import IRGraph - -from cube.logics.dataloader import IRDataLoader -from cube.logics import model -from cube.logics.pool import SchedulePool - - - -class LogicTranslator: - - @staticmethod - def gen_logic_graph(outputs=None) -> IRGraph: - """ - Generate Training Logic Graph - """ - nodes = SchedulePool().nodes() - has_bp = any(n for n in nodes if isinstance(n, IRBpOperation)) - if has_bp: - assert all(fnode.mirror in nodes for fnode in nodes if isinstance(fnode, IRFwOperation)), \ - "Training requires all nodes have backward." - else: - # remove backward nodes if no backward is called - fnodes = [node for node in nodes if isinstance(node, IRFwOperation)] - for fnode in fnodes: - IRCell.make_pair(fnode, None) - # remove node gradient - for node in nodes: - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor): - itensor.parent.requires_grad = False - # ad hoc fix on operators with multiple same input tensors - itensor._dirty_grad = True - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor): - otensor.parent.requires_grad = False - graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') - return graph - - @staticmethod - def load_data(dataloader: IRDataLoader): - """ - Translator Action: Load data from data loaderw - """ - if not isinstance(dataloader, IRDataLoader): - raise TypeError("Expected IRDataLoader") - outputs = list() - for dtype, shape in zip(dataloader.dtypes, dataloader.shapes): - data = IRFullTensor( - shape, 'data', requires_grad=False, dtype=dtype - ).tosub() - outputs.append(data) - - data_op = IRDataOperation( - data_num=len(outputs), batch_dims=dataloader.get_batch_dims(), - ) - for idx, output in enumerate(outputs): - data_op.set_output(idx, output) - - SchedulePool().add_node(data_op) - if len(outputs) == 0: return - elif len(outputs) == 1: return outputs[0] - else: return tuple(outputs) - - @staticmethod - def forward(graph, *args): - """ - Translator Action: forward an IRGraph - """ - fgraph = model.forward(graph, *args) - for node in fgraph.nodes(): - SchedulePool().add_node(node) - for output in fgraph.outputs(): - SchedulePool().tape(output, fgraph.nodes()) - outputs = fgraph.outputs() - if len(outputs) == 1: return outputs[0] - elif len(outputs) == 0: return None - else: return outputs - - @staticmethod - def backward(loss: IRSubTensor): - """ - Translator Action: backward a tensor - """ - trace = SchedulePool().get_tape(loss) - if trace is None: - raise RuntimeError("No forward detected") - if loss.nelement() != 1: - raise RuntimeError("backward can only perform on the scaler tensor") - # loss tensor grad should be 1.0 - loss.parent.grad = 1.0 - for node in trace[::-1]: - SchedulePool().add_node(node.mirror) - - @staticmethod - def update(optimizer): - raise NotImplementedError From 5d73d0c200c2bc83d043c3701fa3b259f7596474 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Sep 2022 20:22:15 +0800 Subject: [PATCH 0983/1892] hierarchical adapter gener --- cube/graph/gener/gen.py | 241 +++++++++++++++++++--------------------- cube/graph/segment.py | 8 +- cube/ir/operator.py | 4 + cube/ir/tensor.py | 15 --- 4 files changed, 123 insertions(+), 145 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 5b6aea1b..fc65aa4c 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,5 +1,6 @@ import itertools -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Set +import copy from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener @@ -15,6 +16,16 @@ from cube.graph.function.function import Add, Cat, Identity, MultiRef +def to_device(tensor: IRSubTensor, device: int) -> IRFwOperation: + """ + This is used for changing tensor device + """ + fwop = IRFwOperation('dummy', 'dummpy', 1, 0) + fwop.set_input(0, tensor) + fwop.device = device + return fwop.input(0) + + class IRAdapterGener: @staticmethod @@ -26,32 +37,14 @@ def gen(graph: IRGraph) -> IRGraph: @param graph IRGraph: the graph without adapter @return graph IRGraph: the graph with adapter inserted """ - # insert identity operator for graph output - devs = set() - for node in graph.nodes(): - devs.update(node.device) - outputs = [otensor for otensor in graph.outputs() if isinstance(otensor, IRSubTensor)] - all_identities = [] - for otensor in outputs: - identity = Identity('', [otensor]) - identity.set_output(0, identity.output(0).tosub()) - graph.insert(identity, len(graph.nodes())) - identites = graph.replicate(identity, times=len(devs)) - all_identities += identites - for devid, identity in zip(devs, identites): - graph.assign(identity, devid) # update the gradient before generate adapter for node in graph.nodes(): if isinstance(node, IRBpOperation): - with graph.update(node): - node.update() + graph.update_bwop(node) # generate adapters for activation graph = IRAdapterGener.gen_activation(graph) # generate weight reducer graph = IRAdapterGener.gen_weight(graph) - # remove inserted identity - for identity in all_identities: - graph.remove(identity) # remove anchor node IRAdapterGener.remove_anchor(graph) print(graph.extra_repr()) @@ -59,65 +52,46 @@ def gen(graph: IRGraph) -> IRGraph: @staticmethod def remove_anchor(graph: IRSegment): - for node in graph.nodes(): - if isinstance(node, IRGraphAnchor): - graph.remove(node) - if node.mirror is not None: - graph.remove(node.mirror) - if isinstance(node, IRSegment): - for anchor in node.nodes(): - if isinstance(anchor, IRGraphAnchor): - graph.remove(anchor) - if anchor.mirror is not None: - graph.remove(anchor.mirror) + for anchor in graph.nodes(): + if isinstance(anchor, IRGraphAnchor): + graph.remove(anchor) + if anchor.mirror is not None: + graph.mirror.remove(anchor.mirror) + if isinstance(anchor, IRSegment): + IRAdapterGener.remove_anchor(anchor) @staticmethod def gen_weight(graph: IRGraph) -> IRGraph: - # step 1: get weight and gradient - # weights: Dict[weight_id: int, IRSubTensor] - # grads : Dict[weight_id: int, Dict[device: int, List[grad: IRSubTensor]]] - grads = dict() weights = dict() - for fnode in graph.flatten(): - if not isinstance(fnode, IRFwOperation): - continue - devid = fnode.device[0] + for fnode in graph.nodes(flatten=True): + if not isinstance(fnode, IRFwOperation): continue + assert len(fnode.device) == 1 for wtensor in fnode.inputs(): if isinstance(wtensor, IRSubTensor) and wtensor.is_param(): - grad: Optional[IRSubTensor] = wtensor.grad - if grad is None: continue - # nothing to sync - if grad.valmap == (0, 1): - continue - if wtensor._id not in grads: - grads[wtensor._id] = dict() - weights[wtensor._id] = wtensor - if devid not in grads[wtensor._id]: - grads[wtensor._id][devid] = list() - if grad in grads[wtensor._id][devid]: - raise RuntimeError( - "Find two same gradient (not expected). " - "This is usually due to replicated node assigned to same device. " - f"\nCheck node:\n\t{fnode}" - ) - grads[wtensor._id][devid].append(grad) - # step 2: generate reducers. - # reducers: tuple(ranks): List[weight] + if wtensor.grad is None: continue + if wtensor.parent not in weights: + weights[wtensor.parent] = dict() + if wtensor not in weights[wtensor.parent]: + weights[wtensor.parent][wtensor] = set() + weights[wtensor.parent][wtensor].add(wtensor.device[0]) + reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() - for wid in grads: - ranks = list(grads[wid].keys()) - ranks.sort() - ranks = tuple(ranks) # ranks are used for group - if len(ranks) == 1: - continue - if ranks not in reducers: - reducers[ranks] = list() - reducers[ranks].append(weights[wid]) + for ftensor, subtensors in weights.items(): + # TODO: check no overlapping (not same) weights on a device + for subw in subtensors: + if len(subtensors[subw]) == 1: + continue + devices = list(subtensors[subw]) + devices.sort() + devices = tuple(devices) + if devices not in reducers: + reducers[devices] = [] + reducers[devices].append(subw) # generate reducer for each rank - for ranks in reducers: - weights = reducers[ranks] + for devices in reducers: + weights = reducers[devices] opt_op = IRWeightReducer(weights) - opt_op.device = list(ranks) + opt_op.device = list(devices) graph.insert(opt_op, graph.nnodes) return graph @@ -131,8 +105,6 @@ def gen_activation(graph: IRSegment) -> IRSegment: @return graph IRGraph: the (inplace) modified graph with activation adapters. """ - segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment)] - def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # e.g., loss or parameter/buffer if len(ptensors) == 0 or len(ctensors) == 0: @@ -142,15 +114,8 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: set(ptensors[0].device) == set(ctensors[0].device): return True return False - - def filter(nodes: List[IRCell], tensors: List[IRSubTensor]) -> Tuple[IRCell, IRSubTensor]: - assert len(nodes) == len(tensors) - filter_nodes, filter_tensors = [], [] - for node, tensor in zip(nodes, tensors): - if node in graph.nodes(): - filter_nodes.append(node) - filter_tensors.append(tensor) - return filter_nodes, filter_tensors + + devices = graph.device # generate adapter for inter-segments # FIXME: assume producers and consumers can run in parallel @@ -160,22 +125,48 @@ def filter(nodes: List[IRCell], tensors: List[IRSubTensor]) -> Tuple[IRCell, IRS continue # optimization: local fusion / multiref on producer / consumer - if isinstance(graph, IRGraph) and graph.train: - ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) - IRAdapterGener.local_consumer_multiref(graph, ftensor) + # ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) + # IRAdapterGener.local_consumer_multiref(graph, ftensor) # producers can be operators and graph inputs - producers, ptensors = filter(ftensor.producers, ftensor.ptensors) - for itensor in graph.inputs(): - if isinstance(itensor, IRSubTensor): - if itensor.parent == ftensor: - ptensors.append(itensor) + fproducers, bproducers, ptensors = [], [], [] + # operators + for ptensor, producer in zip(graph.ptensors(ftensor), graph.producers(ftensor)): + for devid in producer.device: + ptensors.append(to_device(ptensor, devid)) + fproducers.append(graph.index(producer)[0]) + if ptensor.requires_grad: + bproducers.append(graph.mirror.index(producer.mirror)[0]) + # graph inputs + for ptensor in graph.inputs(): + if isinstance(ptensor, IRSubTensor) and ptensor.parent == ftensor: + # TODO: mapping back forawrd / backward + ptensor = ftensor.select(ptensor.indmap, (0, 1)) + for devid in devices: + ptensors.append(to_device(ptensor, devid)) + fproducers.append(0) + if ptensor.requires_grad: + bproducers.append(graph.mirror.nnodes) + # consumers can be operators and graph outputs - consumers, ctensors = filter(ftensor.consumers, ftensor.ctensors) - for otensor in graph.outputs(): - if isinstance(otensor, IRSubTensor): - if otensor.parent == ftensor: - ctensors.append(otensor) + fconsumers, bconsumers, ctensors = [], [], [] + # operators + for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): + for devid in consumer.device: + ctensors.append(to_device(ctensor, devid)) + fconsumers.append(graph.index(consumer)[0]) + if ctensor.requires_grad: + bconsumers.append(graph.mirror.index(consumer.mirror)[0]) + # graph outputs + for ctensor in graph.outputs(): + if isinstance(ctensor, IRSubTensor) and ctensor.parent == ftensor: + # TODO: mapping back forward / backward + ctensor = ftensor.select(ctensor.indmap, (0, 1)) + for devid in devices: + ctensors.append(to_device(ctensor, devid)) + fconsumers.append(graph.nnodes) + if ctensor.requires_grad: + bconsumers.append(0) if skip(ptensors, ctensors): continue @@ -184,18 +175,18 @@ def filter(nodes: List[IRCell], tensors: List[IRSubTensor]) -> Tuple[IRCell, IRS continue # insert forward adapter - # fidx = max(graph.index(prod).gidx for prod in producers) - fidx = min(graph.index(cons) for cons in consumers) - graph.insert(fadapter, fidx) + # graph.insert(fadapter, max(producers)) + graph.insert(fadapter, min(fconsumers)) # insert backward adapter - if fadapter.mirror is not None: - bsegment = graph if isinstance(graph, IRGraph) else graph.mirror - # bidx = max(graph.index(cons.mirror) for cons in consumers if cons.mirror is not None) - bidx = min(bsegment.index(prod.mirror) for prod in producers if prod.mirror is not None) - bsegment.insert(fadapter.mirror, bidx) + if len(bproducers) > 0: + assert isinstance(fadapter.mirror, IRAdapter) + assert isinstance(graph.mirror, IRSegment) + bidx = max(bproducers) + graph.mirror.insert(fadapter.mirror, bidx) # generate adapter for each segment + segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment)] for segment in segments: IRAdapterGener.gen_activation(segment) @@ -232,7 +223,7 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens fuse_tensors: Dict[int, Dict[IRSubTensor, List[IRSubTensor]]] = dict() tensor_map: Dict[int, Dict[IRSubTensor, IRSubTensor]] = dict() - for tensor in ftensor.ptensors: + for tensor in graph.ptensors(ftensor): assert len(tensor.device) == 1 devid = tensor.device[0] if devid not in devtensors: @@ -292,7 +283,7 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens if len(nodes) == 0: return ftensor # recompute - rcid = set(producer.recompute for producer in ftensor.producers) + rcid = set(producer.recompute for producer in graph.producers(ftensor)) rcid = list(rcid)[0] if len(rcid) == 1 else None for node in nodes: node.recompute = rcid @@ -300,15 +291,14 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens new_ftensor = ftensor.like() # update consumer - assert len(ftensor.ctensors) == len(ftensor.consumers) - for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): - # TODO: the change can happend inside segment + assert len(graph.ctensors(ftensor)) == len(graph.consumers(ftensor)) + for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): with graph.update(consumer) as consumer: consumer.set_input( consumer.inputs().index(ctensor), new_ftensor.select(ctensor.indmap, ctensor.valmap) ) - min_idx = min(graph.nodes().index(consumer) for consumer in ftensor.consumers) + min_idx = min(graph.index(consumer) for consumer in graph.consumers(ftensor)) # insert new producer for devid, tensors in fuse_tensors.items(): @@ -324,27 +314,24 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens if node.output(0) == tensor_map[devid][ptensor]: node.set_output(0, new_tensor) - fsid = max(graph.stage_id(prod) for prod in ftensor.producers) for node in nodes[::-1]: - # print(node) assert node not in graph.nodes() assert len(node.outputs()) == 1 - graph.attach(node, min_idx, stage_idx=fsid) + graph.insert(node, min_idx) # insert and update backward node - if graph.train: - # update backward node - for consumer in new_ftensor.consumers: - assert isinstance(consumer.mirror, IRBpOperation) - bnode = consumer.mirror - with graph.update(bnode) as bnode: - bnode.update() - # insert backward node - bnodes = [node.gen_backward() for node in nodes] - bidx = min(graph.nodes().index(producer.mirror) for producer in ftensor.producers) - for bnode in bnodes: - bnode.device = bnode.mirror.device - graph.attach(bnode, bidx) + bgraph: IRSegment = graph.mirror + # update backward node + for consumer in graph.consumers(new_ftensor): + assert isinstance(consumer.mirror, IRBpOperation) + bnode = consumer.mirror + bgraph.update_bwop(bnode) + # insert backward node + bnodes = [graph.bwop(node) for node in nodes] + bidx = min(bgraph.index(producer.mirror) for producer in bgraph.producers(ftensor)) + for bnode in bnodes: + bnode.device = bnode.mirror.device + bgraph.insert(bnode, bidx) return new_ftensor @@ -367,7 +354,7 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): """ # collect to consumer tensors of each device devtensors: Dict[int, Dict[IRSubTensor, List[IRCell]]] = dict() - for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): assert len(ctensor.device) == 1 devid = ctensor.device[0] if devid not in devtensors: @@ -383,8 +370,8 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): "Detect that a full tensor is partitioned differently on a device.\n" "To achieve this, need manually add multiref operator in model description.\n" f"Full Tensor: {ftensor}\n" - f"Producers:\n{nl.join(repr(node) for node in ftensor.producers)}\n" - f"Consumers:\n{nl.join(repr(node) for node in ftensor.consumers)}" + f"Producers:\n{nl.join(repr(node) for node in graph.producers(ftensor))}\n" + f"Consumers:\n{nl.join(repr(node) for node in graph.consumers(ftensor))}" ) # add multiref forward node diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 28ccd02f..6e2b1a13 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -291,6 +291,7 @@ def grad(self, tensor: IRSubTensor) -> IRSubTensor: fgrad = tensor.parent.grad # None means no gradient requirement, flaot means its the loss if fgrad is None or isinstance(fgrad, float): + tensor.grad = fgrad return fgrad ftensor = tensor.parent # this tensor is consumed @@ -312,6 +313,7 @@ def grad(self, tensor: IRSubTensor) -> IRSubTensor: indmap = tensor.indmap, valmap = (0, 1), ) + tensor.grad = grad return grad def debug_print_tensor_map(self): @@ -608,6 +610,9 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: for otensor in otensors: ftensor = otensor.parent if otensor.is_attr(): continue + # loss doesn't have consumers + if len(segment.consumers(ftensor)) == 0: + outputs.add(otensor) # from segment outputs if any(t.overlap(otensor) for t in segment.outputs() if isinstance(t, IRSubTensor)): outputs.add(otensor) @@ -621,9 +626,6 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: return segment - ###### ============ Transformation Primitives ============ ####### - - def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: """ Instantiate the segement to a specific device. diff --git a/cube/ir/operator.py b/cube/ir/operator.py index bec66a4a..79e9a3ba 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -154,6 +154,10 @@ def __init__(self, ograds: Tuple[Any], igrads: Tuple[Any]): 'backward', 'torch.autograd.grad', len(ograds), len(igrads), init_outputs=False ) + for idx, ograd in enumerate(ograds): + self.set_input(idx, ograd) + for idx, igrad in enumerate(igrads): + self.set_output(idx, igrad) def update(self): """ diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 013d9812..75f1dcae 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -574,21 +574,6 @@ def __copy__(self): tensor._cell = None return tensor - def grad(self, graph: IRCell) -> Optional[IRTensor]: - """ - Get gradient of this tensor. - - Gradient can be: - - None: the tensor doesn't require gradient - - 1.0: the tensor is loss tensor (scalar) - - IRSubTensor: the tensor requires gradient and is not the loss tensor (scalar) - - Gradient cannot be set and can only be inferred by its IRFullTensor. - The gradient will be lazy updated when its IRFullTensor gets - new consumed / produced tensors - """ - return graph.grad(self) - @property def requires_grad(self) -> bool: return self.parent._requires_grad From fcd3f2c6a33134bd7b98c01968e0e74918ddb455 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Sep 2022 20:22:41 +0800 Subject: [PATCH 0984/1892] hierarchical code gen --- cube/codegen/codegen.py | 60 ++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index f8197989..67ba35ce 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -409,7 +409,7 @@ def init_comm_groups(self): if ranks not in comm_groups: comm_groups.append(ranks) # collect groups from p2p fusion - adapters = [n for n in graph.flatten() if isinstance(n, IRAdapter)] + adapters = [n for n in graph.nodes(flatten=True) if isinstance(n, IRAdapter)] for adapter in adapters: for prim in adapter.prims: if isinstance(prim, CollectivePrim): @@ -535,33 +535,37 @@ def emit_node_tensors_declare(self, node: IRCell): psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}))" map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" - for itensor in node.inputs(): - name = self.tensor_naming(itensor, prefix_attr='self.') - if isinstance(itensor, IRSubTensor): - if itensor.is_attr() and not self.symbols.exist(name): - self.symbols.create(name) - sign = psign if itensor.is_param() else bsign - code = sign.format( - name=self.tensor_naming(itensor), - shape=tuple(itensor.shape), - dtype=self.dtype_map(itensor.dtype) - ) - self.model_init_statements.append(code) - tid = itensor.parent.tid - slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) - val_chunks = itensor.valmap[1] - code = map_sign.format( - attr=self.tensor_naming(itensor), tid=tid, - slicers=str(slicers), val_chunks=val_chunks - ) - self.model_init_statements.append(code) - self.model_init_statements.append('') - if isinstance(itensor, str): - if name.startswith('self.'): - if not hasattr(self._ref_module, name[5:]): - raise NotImplementedError("member attribute is not added") - for output in node.outputs(): - self.symbols.create(self.tensor_naming(output, prefix_attr='self.')) + if not isinstance(node, IRSegment): + for itensor in node.inputs(): + name = self.tensor_naming(itensor, prefix_attr='self.') + if isinstance(itensor, IRSubTensor): + if itensor.is_attr() and not self.symbols.exist(name): + self.symbols.create(name) + sign = psign if itensor.is_param() else bsign + code = sign.format( + name=self.tensor_naming(itensor), + shape=tuple(itensor.shape), + dtype=self.dtype_map(itensor.dtype) + ) + self.model_init_statements.append(code) + tid = itensor.parent.tid + slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) + val_chunks = itensor.valmap[1] + code = map_sign.format( + attr=self.tensor_naming(itensor), tid=tid, + slicers=str(slicers), val_chunks=val_chunks + ) + self.model_init_statements.append(code) + self.model_init_statements.append('') + if isinstance(itensor, str): + if name.startswith('self.'): + if not hasattr(self._ref_module, name[5:]): + raise NotImplementedError("member attribute is not added") + for output in node.outputs(): + self.symbols.create(self.tensor_naming(output, prefix_attr='self.')) + else: + for sub_node in node.nodes(): + self.emit_node_tensors_declare(sub_node) return def emit_segment_code(self, segment: IRSegment) -> List[str]: From 287670ffd72cbb08067fbc21eee3ea0e3cf35854 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 5 Sep 2022 10:48:42 +0800 Subject: [PATCH 0985/1892] staging and grouping fix --- cube/graph/graph.py | 48 +++++++--------------- cube/graph/segment.py | 95 +++++++++++++++++++++++++++++++++++-------- cube/ir/cten.py | 2 + cube/ir/operator.py | 22 ---------- cube/ir/tensor.py | 17 ++++++++ 5 files changed, 113 insertions(+), 71 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5c90507b..d637d7e3 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -129,7 +129,7 @@ def backward(self, loss: IRSubTensor): for fnode in self.nodes()[::-1]: assert not isinstance(fnode, IRSegment), "Internal Error: Segment should not appear for now" if isinstance(fnode, IRFwOperation): - bnode: IRBpOperation = self.bwop(fnode) + bnode: IRBpOperation = self.create_bwop(fnode) Program().add_node(bnode) # set program graph mirror to self Program().mirror_as_self() @@ -187,6 +187,8 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: for bnode in bnodes: bidx = bgraph.remove(bnode) bgraph.insert(bsegment, bidx) + # setup gradient + self.update_bwop(bsegment) return fsegment @@ -322,7 +324,7 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis bsegment: IRSegment = fsegment.mirror if isinstance(node.mirror, IRBpOperation): bnode: IRBpOperation = node.mirror - bnodes = tuple(self.bwop(fnode) for fnode in fnodes[::-1]) + bnodes = tuple(self.create_bwop(fnode) for fnode in fnodes[::-1]) for bnode in bnodes: bnode.device = node.device bsegment.replace(bnode, bnodes) @@ -374,7 +376,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], # update backward bsegment: IRSegment = fsegment.mirror if isinstance(node.mirror, IRBpOperation): - bnodes = tuple(self.bwop(fnode) for fnode in fnodes[::-1]) + bnodes = tuple(self.create_bwop(fnode) for fnode in fnodes[::-1]) bsegment.replace(node.mirror, bnodes) for bnode in bnodes: bnode.device = node.device @@ -648,14 +650,13 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: identity.infer_shape() identity.set_output(0, identity.output(0).tosub()) # insert forward - self.insert(identity, self.index(fstages[sid][0])) + fidx = self.index(fstages[sid][0]) + if tensor.requires_grad: + self.finsert(identity, fidx) + bstages[sid].append(identity.mirror) + else: + self.insert(identity, fidx) fstages[sid].insert(0, identity) - - # insert backward - if self.train: - bnode = identity.gen_backward() - self.insert(bnode, self.index(bstages[sid][-1]) + 1) - bstages[sid].append(bnode) return identity # create identity op for cross-stage dataflow @@ -665,7 +666,7 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: assert len(self.producers(ftensor)) <= 1, \ "The staging interface should be called before any operator partition." if len(self.consumers(ftensor)) == 0: continue - producer, ptensor = self.producers(ftensor), self.consumers(ftensor) + producer, ptensor = self.producers(ftensor)[0], self.ptensors(ftensor)[0] psid = get_sid(producer) # outside of stages, not consider if psid is None: continue @@ -678,15 +679,13 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: for sid in range(curr_sid + 1, csid): identity = insert_identity(out, sid) out = identity.output(0) - # update consumer and its backward + # update consumer with self.update(consumer) as consumer: tidx = consumer.inputs().index(ptensor) consumer.set_input(tidx, out) - if self.train: - with self.update(consumer.mirror) as bnode: - bnode.update() curr_sid = csid - + # update all its backward operators + self.update_ftensor_bw(ftensor.grad) # grouping into segment for sid in range(len(fstages)): self.group(fstages[sid]) @@ -726,20 +725,3 @@ def recompute(self, nodes: Union[List[IRFwOperation], IRSegment]) -> bool: fnode.recompute = recompute_group_id return True - - def __repr__(self) -> str: - dscp = f"Graph{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" - return dscp - - def extra_repr(self) -> str: - dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" - # inputs - dscp += f"Inputs: {self.inputs()}\n" - for node in self._nodes: - dscp += f"\n{node}" - if isinstance(node, IRSegment): - for subnode in node.nodes(): - dscp += f"\n\t{subnode}" - # outputs - dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" - return dscp diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 6e2b1a13..2bcbb025 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -101,10 +101,12 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR self._have_forward = any(isinstance(n, IRFwOperation) for n in nodes) self._have_backward = any(isinstance(n, IRBpOperation) for n in nodes) - @property - def forward(self) -> bool: + def isfw(self) -> bool: return self._have_forward + def isbw(self) -> bool: + return self._have_backward + def full_tensors(self) -> Tuple[IRFullTensor]: """ Get all full tensors of this graph. @@ -301,7 +303,6 @@ def grad(self, tensor: IRSubTensor) -> IRSubTensor: assert not (ctensor != tensor and ctensor.overlap(tensor)), "parital overlap is not supported for gradient" if ctensor == tensor and consumer not in consumers: consumers.append(consumer) - # segment.debug_print_tensor_map() valmap = (consumers.index(tensor.cell), len(consumers)) grad = ftensor.grad.select( indmap = tensor.indmap, @@ -316,8 +317,9 @@ def grad(self, tensor: IRSubTensor) -> IRSubTensor: tensor.grad = grad return grad - def debug_print_tensor_map(self): - for ftensor in self._ftensors: + def debug_print_tensor_map(self, ftensor: Optional[IRFullTensor] = None): + ftensors = [ftensor] if ftensor is not None else self._ftensors + for ftensor in ftensors: print(f'Full Tensor: {ftensor}') print(f'Producers:') for producer in self._producers[ftensor]: @@ -326,41 +328,65 @@ def debug_print_tensor_map(self): for producer in self._consumers[ftensor]: print(f'\t{producer}') - def bwop(self, fwop: IRFwOperation) -> IRBpOperation: + def create_bwop(self, fwop: IRFwOperation) -> IRBpOperation: """ Create dummy backward operator for given forward operator """ assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" fsegment: IRSegment = self.segment(fwop) - igrads = [fsegment.grad(t) if isinstance(t, IRSubTensor) else None for t in fwop.inputs()] - ograds = [fsegment.grad(t) if isinstance(t, IRSubTensor) else None for t in fwop.outputs()] + igrads = [fsegment.grad(t) if t.requires_grad else None for t in fwop.inputs() if isinstance(t, IRSubTensor)] + ograds = [fsegment.grad(t) if t.requires_grad else None for t in fwop.outputs() if isinstance(t, IRSubTensor)] bwop = IRBpOperation(ograds, igrads) IRCell.make_pair(fwop, bwop) return bwop - def update_bwop(self, bwop: IRBpOperation) -> IRBpOperation: + def update_bwop(self, bwop: IRCell) -> IRBpOperation: """ - Update backward operator. + Update backward operator or a backward segment. This is neccessary when fwop is partitioned and reference count is changed. - @param bwop IRBpOperation: the backward operation. + @param bwop IRBpOperation or IRSegment: the backward operation. It can be at any hierarchy of this segemtn @return bwop IRBpOperation: the updated operation (inplace) """ + assert isinstance(bwop, (IRBpOperation, IRSegment)) + if isinstance(bwop, IRSegment): + assert bwop.isbw() and (not bwop.isfw()) bsegment: IRSegment = self.segment(bwop) fsegment = bsegment.mirror with bsegment.update(bwop): - fwop: IRFwOperation = bwop.mirror - igrads = [fsegment.grad(t) if isinstance(t, IRSubTensor) else None for t in fwop.inputs()] + fwop: Union[IRFwOperation, IRSegment] = bwop.mirror + igrads = [fsegment.grad(t) if t.requires_grad else None for t in fwop.inputs() if isinstance(t, IRSubTensor)] for idx, igrad in enumerate(igrads): bwop.set_output(idx, igrad) - ograds = [fsegment.grad(t) if isinstance(t, IRSubTensor) else None for t in fwop.outputs()] + ograds = [fsegment.grad(t) if t.requires_grad else None for t in fwop.outputs() if isinstance(t, IRSubTensor)] + # Ad-hoc fix: remove float that could be caused by loss for segment + if isinstance(bwop, IRSegment): + ograds = [grad for grad in ograds if isinstance(grad, IRSubTensor)] for idx, ograd in enumerate(ograds): bwop.set_input(idx, ograd) return bwop + def update_ftensor_bw(self, ftensor: IRFullTensor): + """ + Update all backward operators for a full tensor. + + @param ftensor IRFullTensor: the full tensor. If the full + tensor is not a gradient, will update backward operators + of ftensor.grad + + @return None + """ + fgrad = ftensor.grad if not ftensor.is_grad() else ftensor + if fgrad is None: + return + for producer in self.producers(fgrad): + self.update_bwop(producer) + for consumer in self.consumers(fgrad): + self.update_bwop(consumer) + # ====================== Basic Graph manipulations ====================== def _add_ftensor(self, ftensor: IRFullTensor): @@ -496,7 +522,9 @@ def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: @contextmanager def update(self, node): """ - Update a node. + Update a node. Note the related change in backward operator + will not be automatically updated. + TODO: update operator dependency e.g., @@ -526,6 +554,40 @@ def exist(self, node: IRCell) -> bool: return True return False + def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwOperation: + """ + Insert a forward node and create its backward. + The created backward operator will be happen right before + the backward of fwop's previous forward node + + This requires the segment has its backward segment + + @param fwop IRFwOperation: forward node + @param index Union[int, CellPosition]: inserted position + + @return node IRFwOperation: the node itself + """ + assert isinstance(fwop, IRFwOperation), "Only allow insert an IRFwOperation" + pos = CellPosition((index,)) if isinstance(index, int) else index + assert isinstance(pos, CellPosition), "Expect index to be int or CellPosition" + + index = pos.indices[-1] + fsegment = self if len(pos) == 1 else self.node(CellPosition(pos.indices[1:])) + fsegment.insert(fwop, index) + # create backward + bwop = fsegment.create_bwop(fwop) + # insert backward + assert fsegment.mirror is not None, "Missing backward segment" + bsegment: IRSegment = fsegment.mirror + bidx = 0 + for idx in range(index - 1, -1, -1): + prev_fnode = fsegment.node(idx) + if prev_fnode.mirror is not None: + bidx = bsegment.index(prev_fnode.mirror) + break + bsegment.insert(bwop, bidx) + return fwop + # ====================== Graph Generations ============================ @staticmethod @@ -665,7 +727,8 @@ def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: # ========================== Graph Visualize ================================ def __repr__(self): - dscp = f"Graph{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" + fw = 'f' if self.isfw() else 'b' + dscp = f"{fw}Graph{self.cid}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" return dscp def extra_repr(self) -> str: diff --git a/cube/ir/cten.py b/cube/ir/cten.py index d7b8d984..b52c783e 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -477,6 +477,8 @@ def dtype(self, val: IRDType): if not isinstance(val, IRDType): raise TypeError(f"Expected IRDType but got {val}") self._dtype = val + if isinstance(self._grad, IRTensor): + self._dtype = val @property def cell(self) -> Optional[IRCell]: diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 79e9a3ba..f88cec8c 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -159,28 +159,6 @@ def __init__(self, ograds: Tuple[Any], igrads: Tuple[Any]): for idx, igrad in enumerate(igrads): self.set_output(idx, igrad) - def update(self): - """ - Update this backward operator. - This is neccessary when op is partitioned and reference count is changed. - - Note in order to update produced and consumed tensor list, this call should be - wrapped with IRGraph detach and attach: - - ``` - with graph.update(node): - node.update() - ``` - """ - fnode: IRFwOperation = self.mirror - assert isinstance(fnode, IRFwOperation), "Cannot find corresponding IRFwOperation" - for idx, itensor in enumerate(fnode.inputs()): - grad = itensor.grad if isinstance(itensor, IRSubTensor) else None - self.set_output(idx, grad) - for idx, otensor in enumerate(fnode.outputs()): - grad = otensor.grad if isinstance(otensor, IRSubTensor) else None - self.set_input(idx, grad) - def replicate(self): """ Replicate the backward op diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 75f1dcae..d027fcf9 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -578,6 +578,23 @@ def __copy__(self): def requires_grad(self) -> bool: return self.parent._requires_grad + @property + def grad(self) -> bool: + return self._grad + + @grad.setter + def grad(self, val: Optional[IRTensor]): + if isinstance(val, (IRSubTensor, float)): + assert self.requires_grad + if isinstance(val, IRSubTensor): + val.shape == self.shape + self._grad = val + elif val is None: + assert not self.requires_grad + self._grad = None + else: + raise ValueError(f"Expected grad to be None or IRSubTensor but got: {val}") + # partition primitives def select(self, From b2409e276f3a962a67c9704aacc35865ffe1cf2f Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 5 Sep 2022 16:24:06 +0800 Subject: [PATCH 0986/1892] update PyTorch example: regressive generation of GPT inference --- examples/nlp/blocks/attention.py | 20 ++++++++++---------- examples/nlp/blocks/encoder.py | 4 +--- examples/nlp/gpt/infer.py | 9 +++++++-- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 3f62396a..525b3fb4 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -90,13 +90,14 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, from typing import Optional, Tuple -# @cube.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') +@cube.graph.parser.register('1 N E^, L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') def one_attention(hidden_states: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor]], + past_embed_key: torch.Tensor, + past_embed_value: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, out_bias: torch.Tensor, + out_proj: torch.Tensor, #out_bias: torch.Tensor, h: int, scale: float, dropout_p: float, mask=True): num_head = h L, N = hidden_states.size(0), hidden_states.size(1) @@ -106,10 +107,9 @@ def one_attention(hidden_states: torch.Tensor, k = torch.nn.functional.linear(hidden_states, k_proj, k_bias) # L N E, (h d) E -> L N (h d) v = torch.nn.functional.linear(hidden_states, v_proj, v_bias) # L N E, (h d) E -> L N (h d) - if layer_past is not None: - past_key, past_value = layer_past - k = torch.cat((past_key, k), dim=-3) - v = torch.cat((past_value, v), dim=-3) + if past_embed_key is not None and past_embed_value is not None: + k = torch.cat((past_embed_key, k), dim=-3) + v = torch.cat((past_embed_value, v), dim=-3) q_N = hidden_states.size(1) @@ -238,13 +238,13 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa from typing import Optional, Tuple - def forward(self, query: torch.Tensor, layer_past: Optional[Tuple[torch.Tensor]] = None): + def forward(self, query: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor): attn = one_attention( - query, layer_past, + query, past_embed_key, past_embed_value, self.q_proj, self.q_bias, self.k_proj, self.k_bias, self.v_proj, self.v_bias, - self.out_proj, self.out_bias, + self.out_proj, #self.out_bias, self.num_heads, self.scaling, self.dropout_p, mask=True ) attn = attn + self.out_bias diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 5582bd36..6ce484da 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -50,14 +50,12 @@ def __init__(self, embed_dim: int, num_heads: int, tmp_batch_size = 1 self.past_embed_key = torch.nn.Parameter(torch.rand(seqlen, tmp_batch_size, embed_dim)) self.past_embed_value = torch.nn.Parameter(torch.rand(seqlen, tmp_batch_size, embed_dim)) - self.past_embed = tuple([self.past_embed_key, self.past_embed_value]) - print(f'self.past_embed.type = {type(self.past_embed)}') # def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = self.self_attn_layer_norm(x) - x = self.self_attn_partial(x, self.past_embed) + x = self.self_attn_partial(x, self.past_embed_key, self.past_embed_value) x = self.dropout(x) x = x + residual diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py index 30557002..6794bf37 100644 --- a/examples/nlp/gpt/infer.py +++ b/examples/nlp/gpt/infer.py @@ -4,7 +4,9 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMeshShard --fp16 + examples/nlp/gpt/infer.py --policy PASMeshShard --fp16 + +PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/nlp/gpt/infer.py --policy PASSingle --fp16 """ @@ -45,14 +47,17 @@ raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") def inter(): + print(f'torch.cuda.is_available() = {torch.cuda.is_available()}') batch_size = 1 model = GPTInfer() - # model = model if not args.fp16 else model.half() + model = model if not args.fp16 else model.half() + model = model.cuda() model.eval() dataloader = GPTInferDataLoader(batch_size) + output = None for i in range(10): input_ids, position_ids = next(dataloader) print(f'input_ids = {input_ids} [{input_ids.size()}], position_ids = {position_ids} [{position_ids.size()}]') From 1ac2f3be123fb8e9fd65a983833fb89021a68697 Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 5 Sep 2022 16:44:57 +0800 Subject: [PATCH 0987/1892] update PyTorch example: regressive generation of GPT inference, support graph capture --- examples/nlp/blocks/attention.py | 4 +- examples/nlp/gpt/infer.py | 63 +++++++++----------------------- 2 files changed, 19 insertions(+), 48 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 525b3fb4..42b83843 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -90,7 +90,7 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, from typing import Optional, Tuple -@cube.graph.parser.register('1 N E^, L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') +@cube.graph.parser.register('l N E^, L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> l N E^', name='one_attention') def one_attention(hidden_states: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor, @@ -98,7 +98,7 @@ def one_attention(hidden_states: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, out_proj: torch.Tensor, #out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, mask=True): + h: int, scale: float, dropout_p: float, mask: bool = True): num_head = h L, N = hidden_states.size(0), hidden_states.size(1) dim_head = q_proj.size(0) // num_head diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py index 6794bf37..0eb0b67d 100644 --- a/examples/nlp/gpt/infer.py +++ b/examples/nlp/gpt/infer.py @@ -57,56 +57,27 @@ def inter(): model.eval() dataloader = GPTInferDataLoader(batch_size) - output = None - for i in range(10): + ################## SuperScaler run + model = cube.SemanticModel(model, dataloader.shapes) + @cube.compile(model, dataloader, PAS=PAS, override=True) + def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) - print(f'input_ids = {input_ids} [{input_ids.size()}], position_ids = {position_ids} [{position_ids.size()}]') - output = model(input_ids, position_ids) - print(f'output = {output}') + loss = model(input_ids, position_ids) + return loss + model = model.get_gen_module() + iter_num = 2 + for step in range(iter_num): + output = train_iter(model, dataloader) + print(f'output = {output}') - # model = cube.SemanticModel(model, dataloader.shapes) - # @cube.compile(model, dataloader, PAS=PAS, override=True) - # def train_iter(model, dataloader): + ################## PyTorch run + # output = None + # for i in range(10): # input_ids, position_ids = next(dataloader) - # loss = model(input_ids, position_ids) - # # loss.backward() - # model = model.get_gen_module() - # - # # optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - # - # import os - # single_device_mode = os.environ.get('SINGLE_DEV_MODE') - # print(f"single_dev_mode = {single_device_mode}") - # if not single_device_mode: - # torch.distributed.barrier() - # print_each_rank('model weight consumpition:', rank_only=0) - # memory_summary() - # - # CudaTimer(enable=False).warmup() - # iter_num = 4 - # warmup = 2 - # for step in range(iter_num): - # # if step == 0: - # # model_summary(model, next(dataloader)) - # - # if step >= warmup: - # CudaTimer(enable=True).start('e2e') - # train_iter(model, dataloader) - # # optimizer.step() - # # optimizer.zero_grad() - # if step >= warmup: - # CudaTimer().stop('e2e') - # - # if step == 0: - # print_each_rank('passed first iteration') - # if (step + 1) % 10 == 0: - # print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - # - # print_each_rank('e2e time (ms) per iteration: {} ms'.format( - # CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - # CudaTimer().print_all(times=iter_num-warmup) - # memory_summary() + # print(f'input_ids = {input_ids} [{input_ids.size()}], position_ids = {position_ids} [{position_ids.size()}]') + # output = model(input_ids, position_ids) + # print(f'output = {output}') if __name__ == '__main__': From 8d3a5e2f25c5eb5262767faa424f3e440eaba64b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 5 Sep 2022 17:12:03 +0800 Subject: [PATCH 0988/1892] fix core bugs --- cube/graph/gener/concurrent.py | 88 +++++++++++----------- cube/graph/gener/gen.py | 120 +++++++++++++++--------------- cube/graph/graph.py | 3 +- cube/graph/segment.py | 130 +++++++++++++++++++++++++-------- 4 files changed, 208 insertions(+), 133 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 8625db46..a227401b 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -15,33 +15,37 @@ class ConcurrentGener: @staticmethod - def gen(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> Optional[IRAdapter]: + def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor]) -> Optional[IRAdapter]: """ Generate forward adapter and backward adapter - @param ptensors List[IRSubTensor]: forward producer tensors - @param ctensors List[IRSubTensor]: forward consumer tensors + @param fptensors List[IRSubTensor]: forward producer tensors + @param fctensors List[IRSubTensor]: forward consumer tensors + @param bptensors List[IRSubTensor]: backward producer tensors + @param bctensors List[IRSubTensor]: backward consumer tensors @return fadapter Optional[IRAdapter]: forward adapter None indicate no adapter required. """ - pdevs = tuple(t.device[0] for t in ptensors) - cdevs = tuple(t.device[0] for t in ctensors) + pdevs = tuple(t.device[0] for t in fptensors) + cdevs = tuple(t.device[0] for t in fctensors) fadapter: IRAdapter = None # case 1: sharing device (in-shard) - inshard = (set(pdevs) == set(cdevs)) and (len(ptensors) == len(ctensors)) and (len(pdevs) == len(ptensors)) + inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) if inshard and len(pdevs) > 1: - try: - fadapter = ConcurrentGener.gen_in_shard(ptensors, ctensors, allow_reorder=True) - except Exception as e: - fadapter = None - print( - f"full tensor: {ptensors[0].parent} cannot use grid generation.\n" - f"Reason: {str(e)}\n" - f"Switch to general P2P communication." - ) + # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + try: + fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + except Exception as e: + fadapter = None + print( + f"full tensor: {fptensors[0].parent} cannot use grid generation.\n" + f"Reason: {str(e)}\n" + f"Switch to general P2P communication." + ) # Case 2: sperating device (cross-shard) if len(set(pdevs).intersection(cdevs)) == 0: @@ -50,19 +54,21 @@ def gen(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> Optional[IR # Case 3: General cases # warnings.warn('The adapter is generated using P2P communication') if fadapter is None: - fadapter = ConcurrentGener.gen_general(ptensors, ctensors) + fadapter = ConcurrentGener.gen_general(fptensors, fctensors, bptensors, bctensors) return fadapter @staticmethod - def gen_in_shard(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor], allow_reorder=False): - ftensor = ptensors[0].parent + def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], + allow_reorder=False): + ftensor = fptensors[0].parent # producer grid layout - ilayout = GridLayout.togrid(ftensor, ptensors) + ilayout = GridLayout.togrid(ftensor, fptensors) # reorder ctensors to match with ptensors devs = [ptensor.device for ptensor in ilayout.mat.flatten()] ctensors = [None] * len(devs) - for ctensor in ctensors: + for ctensor in fctensors: idx = devs.index(ctensor.device) ctensors[idx] = ctensor assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" @@ -90,28 +96,26 @@ def gen_in_shard(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor], allow if len(names) > 0: print(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') - fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) + fadapter = IRAdapter(fptensors, fctensors) fadapter.prims = fprims # generate backward grad: IRFullTensor = ftensor.grad - b_ptensors = [ctensor.grad for ctensor in ctensors] - b_ctensors = [ptensor.grad for ptensor in ptensors] bprims = [] - if grad is not None and (len(b_ptensors) != 0 or len(b_ctensors) != 0): + if grad is not None and (len(bptensors) != 0 or len(bctensors) != 0): # reorder ptensors to match with forward ptensors = [None] * len(devs) - for b_ptensor in b_ptensors: - idx = devs.index(b_ptensor.device) + for bptensor in bptensors: + idx = devs.index(bptensor.device) assert ptensors[idx] is None, "same device of different tensors" - ptensors[idx] = b_ptensor - ilayout = GridLayout.togrid(grad, b_ptensors) - olayout = GridLayout.togrid(grad, b_ctensors) + ptensors[idx] = bptensor + ilayout = GridLayout.togrid(grad, ptensors) + olayout = GridLayout.togrid(grad, bctensors) paths, bprims = ilayout.path(olayout) # check the device order for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): assert len(itensor.device) == len(otensor.device), "backward device not match" - badapter = IRAdapter(b_ptensors, b_ctensors) + badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims IRAdapter.make_pair(fadapter, badapter) @@ -122,7 +126,8 @@ def gen_cross_shard(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> pass @staticmethod - def gen_general(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> IRAdapter: + def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor]) -> IRAdapter: """ A general way to generate adapter. FIXME: Assuming consumers at different devices can happen at the same time. @@ -132,19 +137,18 @@ def gen_general(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> IRA @return adapter IRAdapter """ fprims = [] - for ctensor in ctensors: - fprims += ConcurrentGener.gen_subtensor(ctensor, ptensors) - fadapter = IRAdapter(ptensors,ctensors) + for ctensor in fctensors: + fprims += ConcurrentGener.gen_subtensor(ctensor, fptensors) + fadapter = IRAdapter(fptensors,fctensors) fadapter.prims = fprims # backward - b_ptensors = [ctensor.grad for ctensor in ctensors if ctensor.grad is not None] - b_ctensors = [ptensor.grad for ptensor in ptensors if ptensor.grad is not None] - bprims = [] - for cgrad in b_ctensors: - bprims += ConcurrentGener.gen_subtensor(cgrad, b_ptensors) - badapter = IRAdapter(b_ptensors, b_ctensors) - badapter.prims = bprims - IRAdapter.make_pair(fadapter, badapter) + if len(bptensors) > 0 and len(bctensors) > 0: + bprims = [] + for cgrad in bctensors: + bprims += ConcurrentGener.gen_subtensor(cgrad, bptensors) + badapter = IRAdapter(bptensors, bctensors) + badapter.prims = bprims + IRAdapter.make_pair(fadapter, badapter) return fadapter @staticmethod diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index fc65aa4c..09d83923 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,5 +1,5 @@ import itertools -from typing import Dict, List, Optional, Tuple, Set +from typing import Dict, List, Optional, Tuple import copy from cube.graph.function.anchor import IRGraphAnchor @@ -8,7 +8,7 @@ from cube.graph.graph import IRGraph from cube.graph.segment import IRSegment from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap from cube.ir.operator import IRBpOperation, IRFwOperation @@ -23,7 +23,11 @@ def to_device(tensor: IRSubTensor, device: int) -> IRFwOperation: fwop = IRFwOperation('dummy', 'dummpy', 1, 0) fwop.set_input(0, tensor) fwop.device = device - return fwop.input(0) + otensor = fwop.input(0) + otensor.grad = copy.copy(tensor.grad) + if isinstance(otensor.grad, IRSubTensor): + otensor.grad.cell = fwop + return otensor class IRAdapterGener: @@ -37,17 +41,24 @@ def gen(graph: IRGraph) -> IRGraph: @param graph IRGraph: the graph without adapter @return graph IRGraph: the graph with adapter inserted """ + # remove anchor node + graph = IRAdapterGener.remove_anchor(graph) # update the gradient before generate adapter - for node in graph.nodes(): - if isinstance(node, IRBpOperation): - graph.update_bwop(node) + graph = IRAdapterGener.update_grad(graph) # generate adapters for activation graph = IRAdapterGener.gen_activation(graph) # generate weight reducer graph = IRAdapterGener.gen_weight(graph) - # remove anchor node - IRAdapterGener.remove_anchor(graph) - print(graph.extra_repr()) + return graph + + @staticmethod + def update_grad(graph: IRSegment): + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): + graph.update_ftensor_bw(ftensor) + for node in graph.nodes(): + if isinstance(node, IRSegment) and node.isbw(): + IRAdapterGener.update_grad(node) return graph @staticmethod @@ -57,8 +68,9 @@ def remove_anchor(graph: IRSegment): graph.remove(anchor) if anchor.mirror is not None: graph.mirror.remove(anchor.mirror) - if isinstance(anchor, IRSegment): + elif isinstance(anchor, IRSegment): IRAdapterGener.remove_anchor(anchor) + return graph @staticmethod def gen_weight(graph: IRGraph) -> IRGraph: @@ -129,68 +141,58 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # IRAdapterGener.local_consumer_multiref(graph, ftensor) # producers can be operators and graph inputs - fproducers, bproducers, ptensors = [], [], [] - # operators - for ptensor, producer in zip(graph.ptensors(ftensor), graph.producers(ftensor)): - for devid in producer.device: - ptensors.append(to_device(ptensor, devid)) - fproducers.append(graph.index(producer)[0]) - if ptensor.requires_grad: - bproducers.append(graph.mirror.index(producer.mirror)[0]) - # graph inputs - for ptensor in graph.inputs(): - if isinstance(ptensor, IRSubTensor) and ptensor.parent == ftensor: - # TODO: mapping back forawrd / backward - ptensor = ftensor.select(ptensor.indmap, (0, 1)) - for devid in devices: - ptensors.append(to_device(ptensor, devid)) - fproducers.append(0) - if ptensor.requires_grad: - bproducers.append(graph.mirror.nnodes) - - # consumers can be operators and graph outputs - fconsumers, bconsumers, ctensors = [], [], [] - # operators - for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): - for devid in consumer.device: - ctensors.append(to_device(ctensor, devid)) - fconsumers.append(graph.index(consumer)[0]) - if ctensor.requires_grad: - bconsumers.append(graph.mirror.index(consumer.mirror)[0]) - # graph outputs - for ctensor in graph.outputs(): - if isinstance(ctensor, IRSubTensor) and ctensor.parent == ftensor: - # TODO: mapping back forward / backward - ctensor = ftensor.select(ctensor.indmap, (0, 1)) - for devid in devices: - ctensors.append(to_device(ctensor, devid)) - fconsumers.append(graph.nnodes) - if ctensor.requires_grad: - bconsumers.append(0) - - if skip(ptensors, ctensors): continue + fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) + assert all(len(ptensor.device) == 1 for ptensor in fptensors), "Not support for multi-device" + fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) + assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" + + bproducers, bptensors = [], [] + bconsumers, bctensors = [], [] + if isinstance(ftensor.grad, IRFullTensor): + bproducers, bptensors = graph.producers(ftensor.grad), graph.ptensors(ftensor.grad) + assert all(len(ptensor.device) == 1 for ptensor in bptensors), "Not support for multi-device" + bconsumers, bctensors = graph.consumers(ftensor.grad), graph.ctensors(ftensor.grad) + assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" + + if skip(fptensors, fctensors) and skip(bptensors, bctensors): continue - fadapter = ConcurrentGener.gen(ptensors, ctensors) + + # print((f"generating for {ftensor}:\n" + # f"fptensor device: {[t.device for t in fptensors]}\n" + # f"fctensor device: {[t.device for t in fctensors]}\n" + # f"bptensor device: {[t.device for t in bptensors]}\n" + # f"bctensor device: {[t.device for t in bctensors]}\n" + # )) + fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors) if fadapter is None: continue + badapter: Optional[IRAdapter] = fadapter.mirror + + if (badapter is not None and len(fadapter.prims) == 0 and len(badapter.prims) == 0) or \ + (badapter is None and len(fadapter.prims) == 0): + continue + # insert forward adapter - # graph.insert(fadapter, max(producers)) - graph.insert(fadapter, min(fconsumers)) + # graph.insert(fadapter, max(producers) + 1) + graph.insert(fadapter, min(graph.index(c) for c in fconsumers)) # insert backward adapter - if len(bproducers) > 0: - assert isinstance(fadapter.mirror, IRAdapter) + if badapter is not None: + assert isinstance(badapter, IRAdapter) assert isinstance(graph.mirror, IRSegment) - bidx = max(bproducers) - graph.mirror.insert(fadapter.mirror, bidx) + bproducers = [ + graph.mirror.index(consumer.mirror) + 1 for \ + consumer in graph.consumers(ftensor) + ] + bidx = max(bproducers) if len(bproducers) > 0 else 0 + graph.mirror.insert(badapter, bidx) # generate adapter for each segment - segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment)] + segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] for segment in segments: IRAdapterGener.gen_activation(segment) - print(graph.extra_repr()) return graph diff --git a/cube/graph/graph.py b/cube/graph/graph.py index d637d7e3..70317097 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -323,11 +323,10 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis # insert backward bsegment: IRSegment = fsegment.mirror if isinstance(node.mirror, IRBpOperation): - bnode: IRBpOperation = node.mirror bnodes = tuple(self.create_bwop(fnode) for fnode in fnodes[::-1]) for bnode in bnodes: bnode.device = node.device - bsegment.replace(bnode, bnodes) + bsegment.replace(node.mirror, bnodes) return fnodes def partition(self, node: Union[IRFwOperation, IRDataOperation], diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 2bcbb025..86914561 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -206,14 +206,17 @@ def node(self, index: Union[int, CellPosition]) -> IRCell: def index(self, node: IRCell) -> CellPosition: """ - Get node index. + Get node index. The dispatched node (e.g., IRAdapter, IRSegment) + will return the index to its un-dispatched node @param node IRCell: the queried node @return index int: the index """ - if node in self._nodes: - return CellPosition((self._nodes.index(node),)) + assert isinstance(node, IRCell) + cids = tuple(node.cid for node in self._nodes) + if node.cid in cids: + return CellPosition((cids.index(node.cid),)) for idx, segment in enumerate(self._nodes): if isinstance(segment, IRSegment): if segment.exist(node): @@ -229,7 +232,7 @@ def segment(self, node: IRCell) -> IRCell: @return segment IRSegment """ - assert isinstance(node, IRCell) + assert isinstance(node, IRCell), f"Expected IRCell, but got {node}" index = self.index(node) if len(index) == 1: return self @@ -280,7 +283,7 @@ def ctensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: assert ftensor in self._ctensors, f"{ftensor} is not in the graph" return tuple(self._ctensors[ftensor]) - def grad(self, tensor: IRSubTensor) -> IRSubTensor: + def grad(self, tensor: IRSubTensor, no_partial_overlap=False) -> IRSubTensor: """ Get gradient of the tensor. @@ -298,12 +301,17 @@ def grad(self, tensor: IRSubTensor) -> IRSubTensor: ftensor = tensor.parent # this tensor is consumed if tensor in tensor.cell.inputs(): - consumers = [] + consumer_cids = [] for ctensor, consumer in zip(segment.ctensors(ftensor), segment.consumers(ftensor)): - assert not (ctensor != tensor and ctensor.overlap(tensor)), "parital overlap is not supported for gradient" - if ctensor == tensor and consumer not in consumers: - consumers.append(consumer) - valmap = (consumers.index(tensor.cell), len(consumers)) + if no_partial_overlap: + assert not (ctensor != tensor and ctensor.overlap(tensor)), ( + f"parital overlapping is not supported for gradient\n" + f"{self.debug_tensor_map_str(ctensor.parent)}" + ) + if ctensor == tensor and consumer.cid not in consumer_cids: + consumer_cids.append(consumer.cid) + + valmap = (consumer_cids.index(tensor.cell.cid), len(consumer_cids)) grad = ftensor.grad.select( indmap = tensor.indmap, valmap = valmap @@ -317,16 +325,18 @@ def grad(self, tensor: IRSubTensor) -> IRSubTensor: tensor.grad = grad return grad - def debug_print_tensor_map(self, ftensor: Optional[IRFullTensor] = None): + def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: + dscp : str = '' ftensors = [ftensor] if ftensor is not None else self._ftensors for ftensor in ftensors: - print(f'Full Tensor: {ftensor}') - print(f'Producers:') + dscp += f'====\nFull Tensor: {ftensor}\n' + dscp += f'Producers:\n' for producer in self._producers[ftensor]: - print(f'\t{producer}') - print(f'Consumers:') - for producer in self._consumers[ftensor]: - print(f'\t{producer}') + dscp += f'\t{producer}\n' + dscp += f'Consumers:\n' + for consumer in self._consumers[ftensor]: + dscp += f'\t{consumer}\n' + return dscp def create_bwop(self, fwop: IRFwOperation) -> IRBpOperation: """ @@ -334,8 +344,10 @@ def create_bwop(self, fwop: IRFwOperation) -> IRBpOperation: """ assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" fsegment: IRSegment = self.segment(fwop) - igrads = [fsegment.grad(t) if t.requires_grad else None for t in fwop.inputs() if isinstance(t, IRSubTensor)] - ograds = [fsegment.grad(t) if t.requires_grad else None for t in fwop.outputs() if isinstance(t, IRSubTensor)] + fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] + fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] + igrads = [fsegment.grad(t) if t.requires_grad else None for t in fins] + ograds = [fsegment.grad(t) if t.requires_grad else None for t in fous] bwop = IRBpOperation(ograds, igrads) IRCell.make_pair(fwop, bwop) return bwop @@ -358,10 +370,12 @@ def update_bwop(self, bwop: IRCell) -> IRBpOperation: fsegment = bsegment.mirror with bsegment.update(bwop): fwop: Union[IRFwOperation, IRSegment] = bwop.mirror - igrads = [fsegment.grad(t) if t.requires_grad else None for t in fwop.inputs() if isinstance(t, IRSubTensor)] + fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] + fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] + igrads = [fsegment.grad(t) if t.requires_grad else None for t in fins] + ograds = [fsegment.grad(t) if t.requires_grad else None for t in fous] for idx, igrad in enumerate(igrads): bwop.set_output(idx, igrad) - ograds = [fsegment.grad(t) if t.requires_grad else None for t in fwop.outputs() if isinstance(t, IRSubTensor)] # Ad-hoc fix: remove float that could be caused by loss for segment if isinstance(bwop, IRSegment): ograds = [grad for grad in ograds if isinstance(grad, IRSubTensor)] @@ -652,26 +666,74 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: segment = segments[0] inputs, outputs = set(), set() + + # go through adapters + adapter_ins: Dict[IRSubTensor, Set[int]] = dict() + adapter_ous: Dict[IRSubTensor, Set[int]] = dict() + for adapter in nodes: + if not isinstance(adapter, IRAdapter): + continue + for itensor in adapter.inputs(): + if not isinstance(itensor, IRSubTensor): continue + if itensor not in adapter_ins: + adapter_ins[itensor] = set() + adapter_ins[itensor].update(itensor.device) + # producers can from out side node + producers = [] + for ptensor, prod in zip(segment.ptensors(itensor.parent), segment.producers(itensor.parent)): + if ptensor == itensor and set(itensor.device).issubset(set(prod.device)): + producers.append(prod) + if not any(p in nodes for p in producers): + inputs.add(itensor) + for otensor in adapter.outputs(): + if not isinstance(otensor, IRSubTensor): continue + if otensor not in adapter_ous: + adapter_ous[otensor] = set() + adapter_ous[otensor].update(otensor.device) + consumers = [] + for ctensor, cons in zip(segment.ctensors(otensor.parent), segment.consumers(otensor.parent)): + if ctensor == otensor and set(otensor.device).issubset(set(cons.device)): + consumers.append(cons) + if not any(c in nodes for c in consumers): + outputs.add(otensor) + + # go through non-adapter nodes for node in nodes: + if isinstance(node, IRAdapter): + assert node.differentiable, \ + "Non-differentiable IRAdapter is not allowed to be grouped" + continue # update inputs itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] for itensor in itensors: ftensor = itensor.parent if itensor.is_attr(): continue + # from inside adapters + if itensor in adapter_ous: + if len(node.device) > 0 and set(itensor.device).issubset(adapter_ous[itensor]): + continue # from segment inputs if any(t.overlap(itensor) for t in segment.inputs() if isinstance(t, IRSubTensor)): inputs.add(itensor) continue # from outside producers - for ptensor, producer in zip(segment.ptensors(ftensor), segment.producers(ftensor)): - if ptensor.overlap(itensor) and producer not in nodes: - inputs.add(itensor) - continue + producers, ptensors = segment.producers(ftensor), segment.ptensors(ftensor) + producers = [p for p, t in zip(producers, ptensors) if t == itensor] + if len(itensor.device) > 0: + producers = [p for p in producers if set(itensor.device).issubset(set(p.device))] + # from graph inputs or outside adapter (no producer) + if len(producers) == 0 or any(p not in nodes for p in producers): + inputs.add(itensor) + continue # update outputs otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] for otensor in otensors: ftensor = otensor.parent if otensor.is_attr(): continue + # from inside adapters + if otensor in adapter_ins: + if len(node.device) > 0 and set(otensor.device).issubset(adapter_ins[otensor]): + continue # loss doesn't have consumers if len(segment.consumers(ftensor)) == 0: outputs.add(otensor) @@ -680,10 +742,14 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: outputs.add(otensor) continue # for outside consumers - for ctensor, consumer in zip(segment.ctensors(ftensor), segment.consumers(ftensor)): - if ctensor.overlap(otensor) and consumer not in nodes: - outputs.add(otensor) - continue + consumers, ctensors = segment.consumers(ftensor), segment.ctensors(ftensor) + consumers = [c for c, t in zip(consumers, ctensors) if t == otensor] + if len(otensor.device) > 0: + consumers = [c for c in consumers if set(otensor.device).issubset(set(c.device))] + # for adapter (no consumer) + if len(consumers) == 0 or any(c not in nodes for c in consumers): + outputs.add(otensor) + continue segment = IRSegment(nodes, tuple(inputs), tuple(outputs)) return segment @@ -728,7 +794,11 @@ def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: def __repr__(self): fw = 'f' if self.isfw() else 'b' - dscp = f"{fw}Graph{self.cid}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" + inputs = tuple(t for t in self.inputs() if isinstance(t, IRTensor) and not t.is_param()) + if self.isfw(): + dscp = f"{fw}Graph{self.cid}-{self.device}(inputs={inputs}, outputs={self.outputs()})" + else: + dscp = f"{fw}Graph{self.cid}-{self.device}(fGraph{self.mirror.cid}, inputs={inputs}, outputs={self.outputs()})" return dscp def extra_repr(self) -> str: From 031cb554553f3a12052a6af95befa2a25751af8b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 5 Sep 2022 17:12:33 +0800 Subject: [PATCH 0989/1892] enable codegen --- cube/codegen/codegen.py | 8 +- cube/execplan/execplan.py | 2 +- cube/execplan/planpass/fusion.py | 2 +- cube/execplan/planpass/grouping.py | 6 +- cube/program.py | 139 +++++++++++++++++++++++++++++ 5 files changed, 148 insertions(+), 9 deletions(-) create mode 100644 cube/program.py diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 67ba35ce..619deeb7 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -35,7 +35,7 @@ def get_backward_callsite_io_tensors(bp_segment:IRSegment): #inputs to 'backward' outputs of 'backward' ``` """ - assert isinstance(bp_segment, IRSegment) and not bp_segment.forward + assert isinstance(bp_segment, IRSegment) and not bp_segment.isfw() input_tensors = [t for t in bp_segment.mirror.inputs() if \ isinstance(t, IRSubTensor) and \ @@ -99,7 +99,7 @@ def is_temp_tensor(v): inputs : Iterable[IRTensor] if isinstance(node, IRSegment): - if node.forward: + if node.isfw(): outputs = node.outputs() inputs = node.inputs() else: @@ -446,7 +446,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # parse graph body for node in self.execplan.seq(device): if isinstance(node, IRSegment): - if not node.forward: continue # skip backward segment + if not node.isfw(): continue # skip backward segment codes = self.emit_segment_code(node) elif isinstance(node, IRFwOperation): raise RuntimeError(f"Unexcepted global-level op call: {node}") @@ -989,7 +989,7 @@ def emit_node(self, node: IRCell, name: str) -> str: if isinstance(node, IRSegment): # emit forward - if node.forward: + if node.isfw(): code = fsign.format( outputs = outputs, model = f'model.{name}', diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index d9173d2f..86ff6bcc 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -17,7 +17,7 @@ def __init__(self, graph: IRGraph): self._inference_only = not any( isinstance(n, IRBpOperation) or \ (isinstance(n, IRAdapter) and not n.forward) or \ - (isinstance(n, IRSegment) and not n.forward) for n in graph.nodes() + (isinstance(n, IRSegment) and not n.isfw()) for n in graph.nodes() ) # execution sequence for each device diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 9a09b055..ca6e0d61 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -29,7 +29,7 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: if node.forward: ret = DiffFusion.nnfuse(node) cnt = cnt+1 if ret else cnt - if isinstance(node, IRSegment) and node.forward: + if isinstance(node, IRSegment) and node.isfw(): for fnode in node.nodes(): if isinstance(fnode, IRAdapter): if node.forward: diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index f974f49d..4b89965b 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -20,11 +20,11 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: fgroups, bgroups = Grouping.group(execplan) for devid in execplan.devices(): for fpieces, bpieces in zip(fgroups[devid], bgroups[devid]): - fsubgraph = graph.segment(fpieces) + fsubgraph = graph.create_segment(fpieces) if bpieces is not None: - bsubgraph = graph.segment(bpieces) + bsubgraph = graph.create_segment(bpieces) IRCell.make_pair(fsubgraph, bsubgraph) - subgraphs = [fsubgraph] if bpieces is None else [fsubgraph, bsubgraph] + subgraphs = [fsubgraph] if bpieces is None else [fsubgraph, fsubgraph.mirror] for subgraph in subgraphs: # update execution plan: replace the nodes with the subgraph pieces = subgraph.nodes() diff --git a/cube/program.py b/cube/program.py new file mode 100644 index 00000000..5a602a10 --- /dev/null +++ b/cube/program.py @@ -0,0 +1,139 @@ +from typing import List, Tuple +from cube.graph.torch_dtype_mapping import DType2IRDType + +from cube.ir.cten import IRCell, IRTensor +from cube.ir.tensor import IRFullTensor +from cube.ir.operator import IRDataOperation + +from cube.graph import IRGraph +from cube.graph import parser + +from cube.runtime.syndata import CubeDataLoader +from cube.profiler.timer import print_each_rank + +import torch + + +class Program: + + class __Program: + + def __init__(self): + + self._graph = IRGraph([], [], [], 'program') + + instance = None + + def __init__(self): + if not Program.instance: + Program.instance = Program.__Program() + + def __getattr__(self, name): + return getattr(self.instance, name) + + def add_node(self, node: IRCell): + self.instance._graph.insert(node, self.instance._graph.nnodes) + + def add_nodes(self, nodes: List[IRCell]): + for node in nodes: + self.add_node(node) + + def get_graph(self): + return self.instance._graph + + def mirror_as_self(self): + """ + Set mirror as self. This is called when a backward is triggered. + """ + IRCell.make_pair(self.instance._graph, self.instance._graph) + + def clear(self): + self.instance._graph = IRGraph([], [], [], 'program') + + def __repr__(self): + return repr(self.instance._graph) + + +class SemanticDataLoader: + + def __init__(self, dataloader: CubeDataLoader): + if not isinstance(dataloader, CubeDataLoader): + raise TypeError("Expected data loader derived from CubeDataLoader") + self.dataloader: CubeDataLoader = iter(dataloader) + dtype_map = DType2IRDType + self.dtypes = [dtype_map.map(dtype) for dtype in dataloader.dtypes] + self.shapes = [list(shape) for shape in dataloader.shapes] + + def get_batch_dims(self) -> Tuple[int]: + return tuple(self.dataloader.batch_dims) + + def get_batch_size(self) -> int: + return self.dataloader.get_batch_size() + + def set_batch_size(self, bs: int): + self.dataloader.set_batch_size(bs) + return + + def __iter__(self): + return self + + def __next__(self): + outputs = list() + for dtype, shape in zip(self.dtypes, self.shapes): + data = IRFullTensor( + shape, 'data', requires_grad=False, dtype=dtype + ).tosub() + outputs.append(data) + + data_op = IRDataOperation( + data_num=len(outputs), batch_dims=self.get_batch_dims(), + ) + for idx, output in enumerate(outputs): + data_op.set_output(idx, output) + + Program().add_node(data_op) + if len(outputs) == 0: return + elif len(outputs) == 1: return outputs[0] + else: return tuple(outputs) + + +class SemanticModel: + + def __init__(self, model: torch.nn.Module, input_shapes): + """ + Create semantic model based on AI Scientist description. + """ + dist = torch.distributed.is_initialized() + if (not dist) or (dist and torch.distributed.get_rank() == 0): + self.ir_graph = parser.convert_model( + model, input_shapes=input_shapes + ) + else: + self.ir_graph = None + self._loaded_module = None + + def get_graph(self): + return self.ir_graph + + def load_module(self, filename: str, load_content=True): + import importlib.util + spec = importlib.util.spec_from_file_location("GenModel", filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._loaded_module = module.GenModel().cuda() + if load_content: + print_each_rank("> loading parameter content...") + # TODO: make hardcode ./fullmodel.pt programmable + self._loaded_module.load_attr_content('./fullmodel.pt') + + def get_gen_module(self): + return self._loaded_module + + def clear_module(self): + self._loaded_module = None + + def __call__(self, *args): + if self._loaded_module: + return self._loaded_module(*args) + else: + return self.ir_graph(*args) \ No newline at end of file From 8cd9763a3232eaf06fd0763d0263dece882ca3ab Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 5 Sep 2022 19:16:11 +0800 Subject: [PATCH 0990/1892] refine for merge --- cube/graph/function/function.py | 8 ++++---- examples/nlp/palm/palm.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 9179d55a..df208b0c 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -680,10 +680,10 @@ def Reshape(signature, inputs): torch.reshape(Tensor self, int[] shape) -> Tensor """ - # warnings.warn(""" - # 'torch.reshape' is currently dispatched to 'torch.Tensor.view', - # but 'reshape' has keyword parameter 'shape' while 'view' has 'size'. - # ArgumentMissing error may be raised during codegen.""") + warnings.warn(""" + 'torch.reshape' is currently dispatched to 'torch.Tensor.view', + but 'reshape' has keyword parameter 'shape' while 'view' has 'size'. + ArgumentMissing error may be raised during codegen.""") return View(signature, inputs) diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index 298a00e3..1643677e 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -1,10 +1,6 @@ """ -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/linears.py --policy PASMegatron +2 way branch: + OMP_NUM_THREADS=2 torchrun --nproc_per_node=2 --nnodes=1 palm.py """ import torch @@ -269,7 +265,11 @@ def PASBranch(graph: IRGraph, resource): if node.name == 'embedding' or node.name == 'linear': # data parallel algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=0, dim=batch_dim, num=resource.ngpus) + sub_nodes = graph.partition(node, + algo, + idx=0, + dim=batch_dim, + num=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add' or node.name == 'mean': From 8cbe46fee28680be0b6f245f57bc42b096d3c244 Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 5 Sep 2022 20:18:17 +0800 Subject: [PATCH 0991/1892] update PyTorch example: regressive generation of GPT inference, fix TP of PASMegatronInferTP --- examples/nlp/blocks/attention.py | 2 +- examples/nlp/gpt/infer.py | 37 ++++++++++++++++++++++++++++---- examples/nlp/gpt/policy/spmd.py | 30 ++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 42b83843..ff626831 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -90,7 +90,7 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, from typing import Optional, Tuple -@cube.graph.parser.register('l N E^, L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> l N E^', name='one_attention') +@cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> l N E^', name='one_attention') def one_attention(hidden_states: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor, diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py index 0eb0b67d..10d07657 100644 --- a/examples/nlp/gpt/infer.py +++ b/examples/nlp/gpt/infer.py @@ -7,6 +7,8 @@ examples/nlp/gpt/infer.py --policy PASMeshShard --fp16 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/nlp/gpt/infer.py --policy PASSingle --fp16 + +PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 examples/nlp/gpt/infer.py --policy PASMegatronInferTP --fp16 """ @@ -53,7 +55,7 @@ def inter(): model = GPTInfer() model = model if not args.fp16 else model.half() - model = model.cuda() + # model = model.cuda() #only for PyTorch run model.eval() dataloader = GPTInferDataLoader(batch_size) @@ -66,10 +68,37 @@ def train_iter(model, dataloader): return loss model = model.get_gen_module() - iter_num = 2 + torch.distributed.barrier() + print_each_rank('model weight consumpition:', rank_only=0) + memory_summary() + + CudaTimer(enable=False).warmup() + iter_num = 4 + warmup = 2 for step in range(iter_num): - output = train_iter(model, dataloader) - print(f'output = {output}') + # if step == 0: + # model_summary(model, next(dataloader)) + + if step >= warmup: + CudaTimer(enable=True).start('e2e') + train_iter(model, dataloader) + if step >= warmup: + CudaTimer().stop('e2e') + + if step == 0: + print_each_rank('passed first iteration') + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num - warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num - warmup) + memory_summary() + + # iter_num = 2 + # for step in range(iter_num): + # output = train_iter(model, dataloader) + # print(f'output = {output}') ################## PyTorch run # output = None diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index 7729874b..95987994 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -80,6 +80,36 @@ def PASMegatronTP(graph: IRGraph, resource): return graph +def PASMegatronInferTP(graph: IRGraph, resource): + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # annotating code structure -- not consider multiref on embedding weight + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + # why -1: multiref + fnodes[idx - 1].comment = f'===> start of transformer layer {lid}' + + # attention + attns = [node for node in fnodes if node.name == 'one_attention'] + for attn in attns: + _tp(graph, attn, tp_devs, idx=3, dim=0) + + # feedforward + ffns = [node for node in fnodes if node.name == 'feedforward'] + for ffn in ffns: + _tp(graph, ffn, tp_devs, idx=1, dim=0) + + # replicate other nodes + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: + _replica(graph, node, tp_devs) + + return graph + + def PASMeshShard(graph: IRGraph, resource): # print(graph.extra_repr()) From 2db45d6ca1f255d90479ea1826dd49283c519196 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 5 Sep 2022 20:42:04 +0800 Subject: [PATCH 0992/1892] fix backward on consecutive segments --- cube/codegen/codegen.py | 7 ++++--- cube/runtime/adapter/collectives.py | 8 ++++---- cube/runtime/adapter/nn.py | 1 - cube/runtime/adapter/transform.py | 16 +--------------- cube/runtime/executor.py | 21 ++++++++++++++++++++- 5 files changed, 29 insertions(+), 24 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 619deeb7..ed3115e7 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -42,7 +42,7 @@ def get_backward_callsite_io_tensors(bp_segment:IRSegment): t.requires_grad and \ not t.is_attr() ] - output_tensors = [t for t in bp_segment.mirror.outputs() if isinstance(t, IRSubTensor)] + output_tensors = [t for t in bp_segment.mirror.outputs() if isinstance(t, IRSubTensor) and t.grad is not None] input_grads = [t.grad for t in input_tensors] # WARNING !!! @@ -981,6 +981,7 @@ def emit_node(self, node: IRCell, name: str) -> str: Emit node / subgraph code """ fsign = '{outputs} = cube.runtime.executor.fexecute({model}, *{inputs}, requires_grad={req_grad})' + asign = '{outputs} = cube.runtime.executor.aexecute({model}, *{inputs}, requires_grad={req_grad})' bsign = '{input_grads} = cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' inputs = self.tuple_naming(node.inputs(), skip_attr=True, prefix_attr='model.') @@ -1020,7 +1021,7 @@ def emit_node(self, node: IRCell, name: str) -> str: code = f'{outputs} = next(dataloader)' elif isinstance(node, IRAdapter): - code = fsign.format( + code = asign.format( outputs = outputs, model = f'model.{name}', inputs = inputs, @@ -1028,7 +1029,7 @@ def emit_node(self, node: IRCell, name: str) -> str: ) elif isinstance(node, IRWeightReducer): - code = fsign.format( + code = asign.format( outputs = outputs, model=f'model.{name}', inputs='()', diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 00e2114f..e514f475 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -99,7 +99,7 @@ def all_reduce(itensor: torch.Tensor, CudaTimer().start(field_name='comm') if not itensor.is_contiguous(): itensor = itensor.contiguous() - itensor = itensor.detach().requires_grad_() + itensor = itensor.detach() group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(itensor, group=group) torch.cuda.synchronize() @@ -120,7 +120,7 @@ def all_gather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tens torch.distributed.all_gather(tensor_list, itensor, group=group) torch.cuda.synchronize() # concat - otensor = torch.concat(tuple(tensor_list), dim=dim).requires_grad_() + otensor = torch.concat(tuple(tensor_list), dim=dim) CudaTimer().stop(field_name='comm') return otensor @@ -135,7 +135,7 @@ def reduce_scatter(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch. if not tensor.is_contiguous(): itensors[idx] = tensor.contiguous() group = DeviceGroup().get_group(ranks) - otensor = torch.empty_like(itensors[0], requires_grad=True) + otensor = torch.empty_like(itensors[0], requires_grad=False) torch.distributed.reduce_scatter(otensor, itensors, group=group) torch.cuda.synchronize() CudaTimer().stop(field_name='comm') @@ -155,7 +155,7 @@ def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) - group = DeviceGroup().get_group(ranks) torch.distributed.all_to_all(otensors, itensors, group=group) torch.cuda.synchronize() - otensor = torch.concat(tuple(otensors), dim=idim).requires_grad_() + otensor = torch.concat(tuple(otensors), dim=idim) CudaTimer().stop(field_name='comm') return otensor diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index bc4f99c0..6dab1c6b 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -2,7 +2,6 @@ import torch from cube.profiler.timer import CudaTimer -from cube.runtime.adapter.collectives import all_reduce from cube.runtime.device import DeviceGroup diff --git a/cube/runtime/adapter/transform.py b/cube/runtime/adapter/transform.py index 016d3d4a..d60a9b18 100644 --- a/cube/runtime/adapter/transform.py +++ b/cube/runtime/adapter/transform.py @@ -10,11 +10,6 @@ def identity(tensor: torch.Tensor): """ identity """ - require_grad = tensor.requires_grad - with torch.no_grad(): - tensor = tensor.detach() - if require_grad: - tensor = tensor.requires_grad_() return tensor @@ -23,14 +18,11 @@ def select(tensor: torch.Tensor, """ Select a part of tensor spatially and numerically. """ - require_grad = tensor.requires_grad with torch.no_grad(): sub_tensor = tensor[indmap] if valmap != 1: sub_tensor = sub_tensor / valmap sub_tensor = sub_tensor.detach() - if require_grad: - sub_tensor = sub_tensor.requires_grad_() return sub_tensor @@ -43,11 +35,8 @@ def smerge(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: tensors: a list of torch tensor dim: the dimension to concatenate. """ - require_grad = any(t.requires_grad for t in tensors) with torch.no_grad(): - out = torch.concat(tuple(tensors), dim).requires_grad_() - if require_grad: - out = out.requires_grad_() + out = torch.concat(tuple(tensors), dim) return out @@ -59,11 +48,8 @@ def vmerge(tensors: List[torch.Tensor]) -> torch.Tensor: Args: tensors: a list of torch tensor """ - require_grad = any(t.requires_grad for t in tensors) with torch.no_grad(): out = tensors[0] for tensor in tensors[1:]: out = out + tensor - if require_grad: - out = out.requires_grad_() return out diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index f0e4d97a..8d7331de 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -6,6 +6,13 @@ import torch +def debug_id(tensors, msg: str, rank: int): + if torch.distributed.get_rank() == rank: + if torch.is_tensor(tensors): + print(f'[{torch.distributed.get_rank()}] {msg}: [{id(tensors)}]') + else: + print(f'[{torch.distributed.get_rank()}] {msg}: {[id(t) for t in tensors]}') + class Executor: @@ -21,6 +28,7 @@ def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) outputs = subgraph(*input_tensors) else: # everytime forward a segment, detach the tensor from previous graph + # debug_id(input_tensors, 'outside fexecute args', 0) for itensor in input_tensors: if torch.is_tensor(itensor) and itensor.requires_grad: assert itensor not in Executor._detach @@ -28,7 +36,9 @@ def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) input_tensors = tuple( Executor._detach[t] if t in Executor._detach else t for t in input_tensors ) + # debug_id(input_tensors, 'inside fexecute args', 0) outputs = subgraph(*input_tensors) + # debug_id(outputs, 'fexecute result', 0) # print('forwarding... ') return outputs @@ -42,7 +52,7 @@ def aexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) outputs = subgraph(*input_tensors) else: outputs = subgraph(*input_tensors) - # print('forwarding... ') + outputs = outputs.requires_grad_() if torch.is_tensor(outputs) else (t.requires_grad_() for t in outputs) return outputs @staticmethod @@ -71,7 +81,10 @@ def backward(input_tensors: List[torch.Tensor], return None inputs = list() # everytime backward a input tensor, remove it from _detach + # debug_id(input_tensors, 'outside grad of input', 0) input_tensors = [Executor._detach.pop(t) if t in Executor._detach else t for t in input_tensors] + # debug_id(input_tensors, 'inside grad of input', 0) + # debug_id(output_tensors, 'grad of output', 0) for input_ in input_tensors: if torch.is_tensor(input_) and not isinstance(input_, torch.nn.Parameter): if input_.requires_grad: @@ -82,12 +95,18 @@ def backward(input_tensors: List[torch.Tensor], grad_tensors=output_tensor_grads, ) grads = tuple(input_.grad for input_ in inputs) + assert all(grad is not None for grad in grads), "RuntimeError: got gradient None" if len(grads) == 0: return None elif len(grads) == 1: return grads[0] else: return tuple(grads) + @staticmethod + def clear(): + Executor._detach = dict() + fexecute = Executor.fexecute +aexecute = Executor.aexecute backward = Executor.backward From a37a59ba3b62d1261ede961ea7a0dcf13771cf06 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 6 Sep 2022 14:10:50 +0800 Subject: [PATCH 0993/1892] save work, not runnable --- examples/nlp/palm/palm.py | 124 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 119 insertions(+), 5 deletions(-) diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index 1643677e..55468221 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -140,6 +140,22 @@ def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): return x +@cube.graph.parser.register('N L^ E^, E^ F^ -> N L^ F^', name='feedforward1') +def feedforward1(x: torch.Tensor, proj: torch.Tensor): + return torch.nn.functional.silu(torch.matmul(x, proj)) + + +@cube.graph.parser.register('N L^ E^, E^ F^ -> N L^ F^', name='feedforward2') +def feedforward2(x: torch.Tensor, proj: torch.Tensor): + return torch.matmul(x, proj) + + +@cube.graph.parser.register('N L^ E^, N L^ E^, E^ F -> N L^ F', + name='feedforward3') +def feedforward3(x: torch.Tensor, y: torch.Tensor, proj: torch.Tensor): + return torch.matmul(x * y, proj) + + class PaLMLayer(nn.Module): def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): @@ -183,6 +199,52 @@ def forward(self, in_x): return in_x + attn_out + ff_out +class PaLMLayerV2(nn.Module): + + def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): + super().__init__() + + self.dim, self.dim_head, self.heads, self.scale = dim, dim_head, heads, dim_head**-0.5 + + # TODO + # self.alibi_pos_biases = AlibiPositionalBias(heads=self.heads) + # self.norm = RMSNorm(dim) + self.norm = torch.nn.LayerNorm(self.dim) + + self.qkv_proj = torch.nn.Parameter(torch.randn(dim, dim + dim_head)) + self.attn_out_proj = torch.nn.Parameter(torch.randn(dim, dim)) + + self.ff_proj1 = torch.nn.Parameter(torch.randn(dim, ff_mult * dim)) + self.ff_proj2 = torch.nn.Parameter(torch.randn(dim, ff_mult * dim)) + self.ff_proj3 = torch.nn.Parameter(torch.randn(ff_mult * dim, dim)) + + # self.register_buffer("mask", None, persistent=False) + + def get_mask(self, n, device): + if self.mask is not None and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), + 1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def forward(self, in_x): + bs, n, device = in_x.shape[0], in_x.shape[1], in_x.device + + # pre layernorm + x = self.norm(in_x) + + attn_out = multi_head_attention(x, self.qkv_proj, self.attn_out_proj, + self.heads, self.scale) + + ff1 = feedforward1(x, self.ff_proj1) + ff2 = feedforward2(x, self.ff_proj2) + ff_out = feedforward3(ff1, ff2, self.ff_proj3) + + return in_x + attn_out + ff_out + + class PaLM(nn.Module): def __init__(self, @@ -196,7 +258,11 @@ def __init__(self, self.net = nn.Sequential( nn.Embedding(num_tokens, dim), - *[PaLMLayer(dim, dim_head, heads, ff_mult) for _ in range(depth)], + # *[PaLMLayer(dim, dim_head, heads, ff_mult) for _ in range(depth)], + *[ + PaLMLayerV2(dim, dim_head, heads, ff_mult) + for _ in range(depth) + ], torch.nn.LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False), ) @@ -261,7 +327,6 @@ def PASBranch(graph: IRGraph, resource): for node in graph.nodes(): if isinstance(node, IRFwOperation): - print(node) if node.name == 'embedding' or node.name == 'linear': # data parallel algo = node.algorithms('dim') @@ -287,9 +352,57 @@ def PASBranch(graph: IRGraph, resource): return graph +def PASBranch3(graph: IRGraph, resource): + ''' + 3 way branch + ''' + assert resource.ngpus == 3 + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + batch_dim = node.get_batch_dims()[0] + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if node.name == 'embedding' or node.name == 'linear': + # data parallel + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, + algo, + idx=0, + dim=batch_dim, + num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add' or node.name == 'mean': + # replicate + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + elif node.name == 'feedforward1': + graph.assign(node, 0) + elif node.name == 'feedforward2': + graph.assign(node, 1) + elif node.name == 'feedforward3': + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=2, dim=1, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + elif node.name == 'multi_head_attention': + graph.assign(node, 2) + else: + assert False, node.name + + return graph + + def train(): - bs, n, dim = 8, 1024, 512 - num_tokens, depth, heads, dim_head = 20000, 1, 8, 64 + bs, n, dim = 3, 2048, 4096 + num_tokens, depth, heads, dim_head = 20000, 1, 16, 256 model = PaLM(dim, num_tokens, depth, heads=heads, dim_head=dim_head) @@ -309,7 +422,8 @@ def train(): # @cube.compile(model, dataloader, PAS=PASSingle) # @cube.compile(model, dataloader, PAS=PASData) - @cube.compile(model, dataloader, PAS=PASBranch) + # @cube.compile(model, dataloader, PAS=PASBranch) + @cube.compile(model, dataloader, PAS=PASBranch3) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) From a5812512ab03c16b406f0e65c2e5697e359c2a77 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 6 Sep 2022 16:21:11 +0800 Subject: [PATCH 0994/1892] fix recompute and local fusion bug --- cube/codegen/codegen.py | 2 +- cube/graph/gener/concurrent.py | 24 +++--- cube/graph/gener/gen.py | 144 +++++++++++++++++++-------------- cube/graph/graph.py | 25 +++--- cube/graph/segment.py | 16 ++-- 5 files changed, 118 insertions(+), 93 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index ed3115e7..f89abcec 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -673,7 +673,7 @@ def recompute(tensor_2222): nodes : List[IRCell] = [node for i, node in i_nodes] - subseg = self.execplan.graph.segment(nodes) + subseg = self.execplan.graph.create_segment(nodes) inputs = [t for t in subseg.inputs() if not t.is_attr()] input_names = [self.tensor_naming(t) for t in inputs] diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index a227401b..04a07c0f 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -36,16 +36,16 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # case 1: sharing device (in-shard) inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) if inshard and len(pdevs) > 1: - # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) - try: - fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) - except Exception as e: - fadapter = None - print( - f"full tensor: {fptensors[0].parent} cannot use grid generation.\n" - f"Reason: {str(e)}\n" - f"Switch to general P2P communication." - ) + fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + # try: + # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + # except Exception as e: + # fadapter = None + # print( + # f"full tensor: {fptensors[0].parent} cannot use grid generation.\n" + # f"Reason: {str(e)}\n" + # f"Switch to general P2P communication." + # ) # Case 2: sperating device (cross-shard) if len(set(pdevs).intersection(cdevs)) == 0: @@ -56,6 +56,10 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], if fadapter is None: fadapter = ConcurrentGener.gen_general(fptensors, fctensors, bptensors, bctensors) + if set(pdevs) == set(cdevs) and fadapter.mirror is not None: + fadapter.differentiable = True + fadapter.mirror.differentiable = True + return fadapter @staticmethod diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 09d83923..65abb1c5 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -30,6 +30,38 @@ def to_device(tensor: IRSubTensor, device: int) -> IRFwOperation: return otensor +def create_dummy(segment: IRSegment) -> List[IRFwOperation]: + """ + Create dummy operators that + 1) produce segment input tensors + 2) consume segment output tensors + + @param segment IRSegment: the target segment + + @return nodes List[IRCell]: the generated operation + """ + devices = segment.device + fwops = [] + for devid in devices: + for tensor in segment.inputs(): + if not isinstance(tensor, IRSubTensor): continue + assert tensor.valmap == ValueMap((0, 1)) + fwop = IRFwOperation('segment input', '', 0, 1) + fwop.set_output(0, tensor) + fwop.device = devid + segment.insert(fwop, 0) + fwops.append(fwop) + for tensor in segment.outputs(): + if not isinstance(tensor, IRSubTensor): continue + assert tensor.valmap == ValueMap((0, 1)) + fwop = IRFwOperation('segment output', '', 1, 0) + fwop.set_intput(0, tensor) + fwop.device = devid + segment.insert(fwop, -1) + fwops.append(fwop) + return fwops + + class IRAdapterGener: @staticmethod @@ -49,6 +81,7 @@ def gen(graph: IRGraph) -> IRGraph: graph = IRAdapterGener.gen_activation(graph) # generate weight reducer graph = IRAdapterGener.gen_weight(graph) + # print(graph.extra_repr()) return graph @staticmethod @@ -122,12 +155,14 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: if len(ptensors) == 0 or len(ctensors) == 0: return True # direct connection - if len(ptensors) == 1 and len(ctensors) == 1 and \ - set(ptensors[0].device) == set(ctensors[0].device): - return True - return False - - devices = graph.device + for ctensor in ctensors: + if not any(t == ctensor and set(ctensor.device).issubset(set(t.device)) for t in ptensors): + return False + return True + + fdummies = create_dummy(graph) + bgraph: Optional[IRSegment] = graph.mirror + bdummies = create_dummy(bgraph) if isinstance(bgraph, IRSegment) else [] # generate adapter for inter-segments # FIXME: assume producers and consumers can run in parallel @@ -137,8 +172,11 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: continue # optimization: local fusion / multiref on producer / consumer - # ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) - # IRAdapterGener.local_consumer_multiref(graph, ftensor) + ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) + IRAdapterGener.local_consumer_multiref(graph, ftensor) + + # print(graph.debug_tensor_map_str(ftensor)) + # print(graph.debug_tensor_map_str(ftensor.grad)) # producers can be operators and graph inputs fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) @@ -150,19 +188,17 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bconsumers, bctensors = [], [] if isinstance(ftensor.grad, IRFullTensor): bproducers, bptensors = graph.producers(ftensor.grad), graph.ptensors(ftensor.grad) - assert all(len(ptensor.device) == 1 for ptensor in bptensors), "Not support for multi-device" + assert all(len(ptensor.device) == 1 for ptensor in bptensors), ( + f"Not support for multi-device:\n" + f"{[ptensor.device for ptensor in bptensors]}" + f"{[ptensor.cell for ptensor in bptensors]}" + ) bconsumers, bctensors = graph.consumers(ftensor.grad), graph.ctensors(ftensor.grad) assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" - if skip(fptensors, fctensors) and skip(bptensors, bctensors): continue - + if skip(fptensors, fctensors) and skip(bptensors, bctensors): + continue - # print((f"generating for {ftensor}:\n" - # f"fptensor device: {[t.device for t in fptensors]}\n" - # f"fctensor device: {[t.device for t in fctensors]}\n" - # f"bptensor device: {[t.device for t in bptensors]}\n" - # f"bctensor device: {[t.device for t in bctensors]}\n" - # )) fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors) if fadapter is None: continue @@ -180,13 +216,19 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # insert backward adapter if badapter is not None: assert isinstance(badapter, IRAdapter) - assert isinstance(graph.mirror, IRSegment) + assert isinstance(bgraph, IRSegment) bproducers = [ - graph.mirror.index(consumer.mirror) + 1 for \ + bgraph.index(consumer.mirror) + 1 for \ consumer in graph.consumers(ftensor) ] bidx = max(bproducers) if len(bproducers) > 0 else 0 - graph.mirror.insert(badapter, bidx) + bgraph.insert(badapter, bidx) + + # remove dummy op + for dummy_op in fdummies: + graph.remove(dummy_op) + for dummy_op in bdummies: + bgraph.remove(dummy_op) # generate adapter for each segment segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] @@ -195,9 +237,8 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: return graph - @staticmethod - def local_producer_fusion(graph: IRGraph, ftensor: IRFullTensor) -> IRFullTensor: + def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTensor: """! Fuse the producer tensors using concat and add. This will add a new full tensor by chaging from: @@ -292,7 +333,8 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens new_ftensor = ftensor.like() - # update consumer + # update consumer + min_idx = min(graph.index(consumer) for consumer in graph.consumers(ftensor)) assert len(graph.ctensors(ftensor)) == len(graph.consumers(ftensor)) for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): with graph.update(consumer) as consumer: @@ -300,7 +342,6 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens consumer.inputs().index(ctensor), new_ftensor.select(ctensor.indmap, ctensor.valmap) ) - min_idx = min(graph.index(consumer) for consumer in graph.consumers(ftensor)) # insert new producer for devid, tensors in fuse_tensors.items(): @@ -319,26 +360,20 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens for node in nodes[::-1]: assert node not in graph.nodes() assert len(node.outputs()) == 1 - graph.insert(node, min_idx) - - # insert and update backward node - bgraph: IRSegment = graph.mirror - # update backward node - for consumer in graph.consumers(new_ftensor): - assert isinstance(consumer.mirror, IRBpOperation) - bnode = consumer.mirror - bgraph.update_bwop(bnode) - # insert backward node - bnodes = [graph.bwop(node) for node in nodes] - bidx = min(bgraph.index(producer.mirror) for producer in bgraph.producers(ftensor)) - for bnode in bnodes: - bnode.device = bnode.mirror.device - bgraph.insert(bnode, bidx) + if graph.mirror is not None: + graph.finsert(node, min_idx) + else: + graph.insert(node, min_idx) + + # update backward + if isinstance(ftensor.grad, IRFullTensor): + graph.update_ftensor_bw(new_ftensor) + graph.update_ftensor_bw(ftensor) return new_ftensor @staticmethod - def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): + def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): """! If a device have a same sub-tensor to be consumed multiple times, then create a multiref forward node for it to make @@ -392,29 +427,20 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): multiref.set_output(idx, itensor) # update consumer - min_fidx = len(graph.nodes()) + min_fidx = graph.nnodes for itensor, consumer in zip(itensors, consumers): with graph.update(consumer) as consumer: idx = consumer.inputs().index(ctensor) consumer.set_input(idx, itensor) - min_fidx = min(graph.nodes().index(consumer) for consumer in consumers) - + # insert forward multiref - graph.attach(multiref, min_fidx) + min_fidx = min(graph.index(consumer) for consumer in consumers) + if graph.mirror is not None: + graph.finsert(multiref, min_fidx) + else: + graph.insert(multiref, min_fidx) multirefs[multiref] = consumers + + if isinstance(ftensor.grad, IRFullTensor): + graph.update_ftensor_bw(ftensor) - # insert / update backward - if graph.train: - for multiref, consumers in multirefs.items(): - # update consumer backward - for consumer in consumers: - assert isinstance(consumer.mirror, IRBpOperation) - bnode: IRBpOperation = consumer.mirror - with graph.update(bnode) as bnode: - bnode.update() - # insert backward - bnode = multiref.gen_backward() - bnode.device = multiref.device - bidx = max(graph.nodes().index(consumer.mirror) for consumer in consumers) - bsid = graph.stage_id(graph.node(bidx)) - graph.attach(bnode, bidx+1, stage_idx=bsid) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 70317097..23e50bd6 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -692,7 +692,7 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: # ================= Other optimizations ================== - def recompute(self, nodes: Union[List[IRFwOperation], IRSegment]) -> bool: + def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: """! Recompute a set of nodes. The forward nodes will be assigned with a unique recompute group id. A forward not can not be recomputed in different recompute groups. @@ -701,25 +701,18 @@ def recompute(self, nodes: Union[List[IRFwOperation], IRSegment]) -> bool: @return success boolean: always success """ - assert all(isinstance(nodes, IRFwOperation)) or isinstance(nodes, IRSegment), \ + assert all(isinstance(node, IRFwOperation) for node in nodes) or isinstance(nodes, IRSegment), \ "Require forward nodes or a single segment" - recompute_group_id: int = IDGenerator().gen_cell_id() - if isinstance(nodes, IRSegment): - assert nodes.forward, "Can only apply recompute on segment node" - for fnode in nodes.node(): - fnode.recompute = recompute_group_id + assert nodes.isfw() and (not nodes.isbw()), "Only forward IRSegment can recompute" + return self.recompute(nodes.nodes()) + else: - indices = [self.index(node) for node in nodes] - if all(idx[1] is None for idx in indices): - assert all(idx[0] == indices[0][0] for idx in indices), \ - f"Cross-stage recompute is not allowed yet." - elif all(idx[1] is not None for idx in indices): - assert all(idx[0] == indices[0][0] for idx in indices), \ - f"Cross-stage recompute is not allowed yet." - else: - assert False, f"Cross-stage recompute is not allowed yet." + segments = [self.segment(node) for node in nodes] + assert all(segment == segments[0] for segment in segments), \ + "Cross-segment recompute is not allowed yet" + recompute_group_id: int = IDGenerator().gen_cell_id() for fnode in nodes: fnode.recompute = recompute_group_id diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 86914561..ace8ceb8 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -214,9 +214,8 @@ def index(self, node: IRCell) -> CellPosition: @return index int: the index """ assert isinstance(node, IRCell) - cids = tuple(node.cid for node in self._nodes) - if node.cid in cids: - return CellPosition((cids.index(node.cid),)) + if node in self._nodes: + return CellPosition((self._nodes.index(node),)) for idx, segment in enumerate(self._nodes): if isinstance(segment, IRSegment): if segment.exist(node): @@ -484,7 +483,8 @@ def remove(self, node: IRCell, _pos: CellPosition = None) -> CellPosition: @return index CellPosition: the removed index """ pos = self.index(node) if _pos is None else _pos - assert self.node(pos) == node, "posititon doesn't not match with node" + assert self.node(pos) == node, \ + f"posititon doesn't not match with node:\n\t{node}\nGot:\n\t{self.node(pos)}" if len(pos.indices) == 1: index = pos[0] @@ -590,6 +590,7 @@ def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwO fsegment.insert(fwop, index) # create backward bwop = fsegment.create_bwop(fwop) + bwop.device = fwop.device # insert backward assert fsegment.mirror is not None, "Missing backward segment" bsegment: IRSegment = fsegment.mirror @@ -661,9 +662,10 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: @return segment IRSegment: the grouped segment. """ - segments: List[IRSegment] = [self.segment(node) for node in nodes] - assert len(set(segments)) == 1, "Cross segment hierarchy grouping is not allowed" - segment = segments[0] + segment = self + # segments: List[IRSegment] = [self.segment(node) for node in nodes] + # assert len(set(segments)) == 1, "Cross segment hierarchy grouping is not allowed" + # segment = segments[0] inputs, outputs = set(), set() From baa6fb5e9c49b02cf8f84869f2125b388d598c6e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Sep 2022 09:44:49 +0800 Subject: [PATCH 0995/1892] enhance segment as subgraph for training --- cube/codegen/codegen.py | 6 ++++-- cube/execplan/planpass/grouping.py | 7 ++++++ cube/graph/gener/concurrent.py | 2 -- cube/graph/gener/gen.py | 34 ++++++++++++++++++++++++++++++ cube/runtime/executor.py | 34 +++++++++++++++--------------- 5 files changed, 62 insertions(+), 21 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index f89abcec..16158e32 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -980,9 +980,9 @@ def emit_node(self, node: IRCell, name: str) -> str: """ Emit node / subgraph code """ - fsign = '{outputs} = cube.runtime.executor.fexecute({model}, *{inputs}, requires_grad={req_grad})' + fsign = '{outputs} = cube.runtime.executor.fexecute({name}, {model}, *{inputs}, requires_grad={req_grad})' asign = '{outputs} = cube.runtime.executor.aexecute({model}, *{inputs}, requires_grad={req_grad})' - bsign = '{input_grads} = cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' + bsign = '{input_grads} = cube.runtime.executor.backward({name}, {input_tensors}, {output_tensors}, {output_grads})' inputs = self.tuple_naming(node.inputs(), skip_attr=True, prefix_attr='model.') outputs = self.return_naming(node.outputs(), skip_attr=True, prefix_attr='model.') @@ -993,6 +993,7 @@ def emit_node(self, node: IRCell, name: str) -> str: if node.isfw(): code = fsign.format( outputs = outputs, + name = f"'{name}'", model = f'model.{name}', inputs = inputs, req_grad = req_grad @@ -1007,6 +1008,7 @@ def emit_node(self, node: IRCell, name: str) -> str: assert tensor == 1.0, "Loss gradient should be 1.0" output_grads[idx] = None code = bsign.format( + name = f"'{self.node_naming(node.mirror)}'", input_grads = self.return_naming(input_grads), input_tensors = self.tuple_naming(input_tensors, skip_attr=True, prefix_attr='model.'), output_tensors = self.tuple_naming(output_tensors, skip_attr=True, prefix_attr='model.'), diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 4b89965b..16164755 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -6,11 +6,13 @@ from cube.execplan import ExecutionPlan from cube.execplan.planpass.planpass import PlanPass from cube.ir.adapter import IRAdapter +from cube.ir.adapter.prim import IdentityPrim from cube.ir.operator import IRFwOperation from cube.ir.cten import IRCell class Grouping(PlanPass): + @staticmethod def apply(execplan: ExecutionPlan) -> ExecutionPlan: """ @@ -32,6 +34,11 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: execplan.at(devid).insert(idx, subgraph) for node in pieces: execplan.at(devid).remove(node) + # remove identity adapter + for adapter in execplan.seq(devid): + if isinstance(adapter, IRAdapter): + if all(isinstance(prim, IdentityPrim) for prim in adapter.prims): + execplan.at(devid).remove(adapter) return execplan @staticmethod diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 04a07c0f..7b40fb2a 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -134,8 +134,6 @@ def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], bptensors: List[IRSubTensor], bctensors: List[IRSubTensor]) -> IRAdapter: """ A general way to generate adapter. - FIXME: Assuming consumers at different devices can happen at the same time. - This will block the pipeline parallelism description. @param ftensor IRFullTensor @return adapter IRAdapter diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 65abb1c5..030c12ef 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -81,6 +81,8 @@ def gen(graph: IRGraph) -> IRGraph: graph = IRAdapterGener.gen_activation(graph) # generate weight reducer graph = IRAdapterGener.gen_weight(graph) + # fuse consecutive non-differentiable adapters into one + graph = IRAdapterGener.fusion(graph) # print(graph.extra_repr()) return graph @@ -444,3 +446,35 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): if isinstance(ftensor.grad, IRFullTensor): graph.update_ftensor_bw(ftensor) + @staticmethod + def fusion(graph: IRSegment) -> IRSegment: + """ + Fuse consecutive adapters into one + """ + fadapters, badapters = [], [] + for adapter in graph.nodes(): + if isinstance(adapter, IRAdapter) and adapter.forward and not adapter.differentiable: + fadapters.append(adapter) + if adapter.mirror is not None: + badapters.insert(0, adapter.mirror) + else: + if len(fadapters) > 1: + # insert fused fadapter + fused_fadapter = IRAdapter.merge(fadapters) + for adapter in fadapters: + idx = graph.remove(adapter) + graph.insert(fused_fadapter, idx) + # insert fused badapter + fused_badapter = IRAdapter.merge(badapters) if len(badapters) > 0 else None + for adapter in badapters: + idx = graph.remove(adapter) + if fused_badapter is not None: + graph.insert(fused_badapter, idx) + IRCell.make_pair(fused_fadapter, fused_badapter) + fadapters, badapters = [], [] + + for segment in graph.nodes(): + if isinstance(segment, IRSegment) and segment.isfw(): + IRAdapterGener.fusion(segment) + + return graph \ No newline at end of file diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 8d7331de..f731df81 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -16,10 +16,10 @@ def debug_id(tensors, msg: str, rank: int): class Executor: - _detach: Dict[torch.Tensor, torch.Tensor] = dict() + _detach: Dict[str, Dict[torch.Tensor, torch.Tensor]] = dict() @staticmethod - def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): + def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): """ forward the sub-graph. """ @@ -29,10 +29,12 @@ def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) else: # everytime forward a segment, detach the tensor from previous graph # debug_id(input_tensors, 'outside fexecute args', 0) + assert name not in Executor._detach + Executor._detach[name] = dict() for itensor in input_tensors: if torch.is_tensor(itensor) and itensor.requires_grad: - assert itensor not in Executor._detach - Executor._detach[itensor] = itensor.detach().requires_grad_() + if itensor not in Executor._detach[name]: + Executor._detach[itensor] = itensor.detach().requires_grad_() input_tensors = tuple( Executor._detach[t] if t in Executor._detach else t for t in input_tensors ) @@ -56,7 +58,8 @@ def aexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) return outputs @staticmethod - def backward(input_tensors: List[torch.Tensor], + def backward(name: str, + input_tensors: List[torch.Tensor], output_tensors: List[torch.Tensor], output_tensor_grads: List[torch.Tensor]) -> Tuple[torch.Tensor]: """ @@ -79,23 +82,20 @@ def backward(input_tensors: List[torch.Tensor], """ if len(output_tensors) == 0: return None - inputs = list() - # everytime backward a input tensor, remove it from _detach - # debug_id(input_tensors, 'outside grad of input', 0) - input_tensors = [Executor._detach.pop(t) if t in Executor._detach else t for t in input_tensors] - # debug_id(input_tensors, 'inside grad of input', 0) - # debug_id(output_tensors, 'grad of output', 0) - for input_ in input_tensors: - if torch.is_tensor(input_) and not isinstance(input_, torch.nn.Parameter): - if input_.requires_grad: - input_.retain_grad() - inputs.append(input_) + + assert name in Executor._detach, f"forward graph: {name} not run before" + input_tensors = [t for t in input_tensors if torch.is_tensor(t) and not isinstance(t, torch.nn.Parameter)] + input_tensors = [t for t in input_tensors if t.requires_grad] + input_tensors = [Executor._detach[t] if t in Executor._detach else t for t in input_tensors] + for t in input_tensors: + t.retain_grad() torch.autograd.backward( output_tensors, grad_tensors=output_tensor_grads, ) - grads = tuple(input_.grad for input_ in inputs) + grads = tuple(t.grad for t in input_tensors) assert all(grad is not None for grad in grads), "RuntimeError: got gradient None" + del Executor._detach[name] if len(grads) == 0: return None elif len(grads) == 1: return grads[0] else: return tuple(grads) From 624fa12f6a47cba4bb99573961c6cd0cf3592855 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Sep 2022 10:57:15 +0800 Subject: [PATCH 0996/1892] fix executor --- cube/graph/graph.py | 38 ++++++++++++++++++++++++++++++++++++++ cube/runtime/executor.py | 11 ++++++++--- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 23e50bd6..93c7938b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -717,3 +717,41 @@ def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: fnode.recompute = recompute_group_id return True + + # =================== Helpers ==================== + + def auto_multiref(self): + """ + Automatically partition and schedule multiref node. + This requires to call after all transformation and + scheduling. + + The policy is to partition and assign multiref + in the same way of its input producer + """ + for node in self.nodes(flatten=True): + if node.name == 'multiref': + multirefs = [] + segment: IRSegment = self.segment(node) + ftensor = node.input(0).parent + ptensors = segment.ptensors(ftensor) + for ptensor in ptensors: + assert len(ptensor.device) > 0, \ + "Auto Multiref requires its producer nodes assigned to devices" + for devid in ptensor.device: + outputs = [] + for output in node.outputs(): + outputs.append(output.parent.select(ptensor.indmap, ptensor.valmap)) + multiref = MultiRef('', [ptensor, len(outputs)]) + for idx, otensor in enumerate(outputs): + multiref.set_output(idx, otensor) + multiref.device = devid + multirefs.append(multiref) + fidx = self.remove(node) + if node.mirror is not None: + self.remove(node.mirror) + for multiref in multirefs[::-1]: + if node.mirror is not None: + self.finsert(multiref, fidx) + else: + self.insert(multiref, fidx) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index f731df81..91b56d28 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -34,9 +34,9 @@ def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires for itensor in input_tensors: if torch.is_tensor(itensor) and itensor.requires_grad: if itensor not in Executor._detach[name]: - Executor._detach[itensor] = itensor.detach().requires_grad_() + Executor._detach[name][itensor] = itensor.detach().requires_grad_() input_tensors = tuple( - Executor._detach[t] if t in Executor._detach else t for t in input_tensors + Executor._detach[name][t] if t in Executor._detach[name] else t for t in input_tensors ) # debug_id(input_tensors, 'inside fexecute args', 0) outputs = subgraph(*input_tensors) @@ -86,7 +86,7 @@ def backward(name: str, assert name in Executor._detach, f"forward graph: {name} not run before" input_tensors = [t for t in input_tensors if torch.is_tensor(t) and not isinstance(t, torch.nn.Parameter)] input_tensors = [t for t in input_tensors if t.requires_grad] - input_tensors = [Executor._detach[t] if t in Executor._detach else t for t in input_tensors] + input_tensors = [Executor._detach[name][t] if t in Executor._detach[name] else t for t in input_tensors] for t in input_tensors: t.retain_grad() torch.autograd.backward( @@ -104,6 +104,11 @@ def backward(name: str, def clear(): Executor._detach = dict() + @staticmethod + def check_clear(): + assert len(Executor._detach) == 0, \ + f"Find remain not consumed sub-graph: {tuple(Executor._detach.keys())}" + fexecute = Executor.fexecute aexecute = Executor.aexecute From 3e19cb314df5e329378306adbdf787f0be594966 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Sep 2022 15:30:05 +0800 Subject: [PATCH 0997/1892] reset dependency and visualization fix --- cube/execplan/execplan.py | 60 +++++++++++++++------------ cube/graph/graph.py | 85 ++++++++++++++++++++++++++++++++------- cube/graph/segment.py | 15 ++++--- 3 files changed, 114 insertions(+), 46 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 86ff6bcc..67ea9112 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -99,21 +99,19 @@ def set(self, devid: int, seq: List[IRCell]): raise TypeError("Expected a list of Cell") self._seq[devid] = seq - def analyze(self, + def visualize(self, map2time: Optional[Callable] = None, map2mem: Optional[Callable] = None, map2name: Optional[Callable] = None, - outfile = None): + outfile: Optional[str] = None): """ - Draw the execution timeline. + Visualize the graph - Args: - span (List[int]): - length equal to schedule unit num. - Each element stands for the time span for corresponding Cell - - outfile: - the output file name + @param map2time Optional[Callable]: node to time (int) map. + @param map2mem Optional[Callable]: node to memory consumption (int) map + @param map2name Optional[Callable]: node to name (str) map + @param outfile Optional[str]: the output file name. + If given, will save the visualized execution plan in file. """ ndevice = len(self.devices()) # timeline [ [ (start_time, end_time), ... ], ... ] @@ -124,10 +122,11 @@ def analyze(self, if map2time is None: def map2time(node): - if isinstance(node, IRGraph): + if isinstance(node, IRSegment): span = 0 for node in node.nodes(): span += map2time(node) + return span if isinstance(node, IRFwOperation): return 1 if isinstance(node, IRBpOperation): @@ -138,7 +137,7 @@ def map2time(node): if map2mem is None: def map2mem(node): - if isinstance(node, IRGraph): + if isinstance(node, IRSegment): peak_mem = 0 curr_mem = 0 for node in node.nodes(): @@ -152,20 +151,21 @@ def map2mem(node): if map2name is None: def map2name(node): - if isinstance(node, IRGraph): - if all([isinstance(n, IRFwOperation) for n in node.nodes()]): - return f'f{node._id}' - if all([isinstance(n, IRBpOperation) for n in node.nodes()]): - if node.mirror is not None: - return f'b{node.mirror._id}' + if isinstance(node, IRSegment): + if node.isfw(): + return f'f{node.cid}' + elif node.isbw(): + return f'b{node.mirror.cid}' if isinstance(node, IRFwOperation): - return f'f{node._id}' + return f'f{node.cid}' if isinstance(node, IRBpOperation): - return f'b{node.mirror._id}' - return str(node._id) + return f'b{node.mirror.cid}' + if isinstance(node, IRAdapter): + return f'a{node.cid}' + return f'?{node.cid}' def map2color(node): - if isinstance(node, IRGraph): + if isinstance(node, IRSegment): return map2color(node.nodes(0)) if isinstance(node, IRFwOperation): return '#4472C4' # excel blue @@ -174,13 +174,12 @@ def map2color(node): if isinstance(node, IRAdapter): return '#70AD47' # excel green + self.graph.reset_dependency() for node in self.graph.nodes(): span, mem = map2time(node), map2mem(node) + # calculate time + start_times = [] for device in node.device: - # memory - device_mem[device] += mem - if device_peak_mem[device] < device_mem[device]: - device_peak_mem[device] = device_mem[device] # tight execution if no dependency if len(device_timeline[device]) == 0: start_time = 1 @@ -196,8 +195,17 @@ def map2color(node): if other_node in node.predecessors(): start_time = max(start_time, end_time) break + start_times.append(start_time) + + start_time = max(start_times) + for device in node.device: + # time device_timeline[device].append((start_time, start_time + span)) device_nodes[device].append(node) + # memory + device_mem[device] += mem + if device_peak_mem[device] < device_mem[device]: + device_peak_mem[device] = device_mem[device] max_time = max( [tline[-1][1] for tline in device_timeline if len(tline) != 0] diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 93c7938b..d4efb15d 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -12,7 +12,7 @@ from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap from cube.ir.dtype import IRDType, DTypeInferRule from cube.graph.function.function import Identity, MultiRef @@ -267,6 +267,35 @@ def from_logic_graph(nodes: List[IRCell], idx = consumer.inputs().index(ftensor) consumer.set_input(idx, ctensor) + # another version to generate multiref: one for all + # for node in nodes: + # ftensors = set() + # for ftensor in node.inputs(): + # # remove redundant tensors within an operator + # if isinstance(ftensor, IRFullTensor) and ftensor._id not in ftensors: + # ftensors.add(ftensor._id) + # if ftensor not in consumers: + # consumers[ftensor] = [] + # consumers[ftensor].append(node) + # for ftensor in node.outputs(): + # if isinstance(ftensor, IRFullTensor): + # producers[ftensor] = node + # for ftensor, cnodes in consumers.items(): + # if len(cnodes) == 1 or ftensor.is_attr(): continue + # itensors = [ftensor.like() for _ in range(len(cnodes))] + # for itensor, consumer in zip(itensors, cnodes): + # while ftensor in consumer.inputs(): + # idx = consumer.inputs().index(ftensor) + # consumer.set_input(idx, itensor) + # # create and insert multiref operation + # multiref = MultiRef(None, [ftensor, len(cnodes)]) + # for idx, itensor in enumerate(itensors): + # multiref.set_output(idx, itensor) + # multiref.infer_shape() + # idx = nodes.index(producers[ftensor]) + 1 if ftensor in producers else 0 + # # idx = nodes.index(cnodes[0]) + # nodes.insert(idx, multiref) + # instantiate graph inputs / outputs for idx, tensor in enumerate(inputs): if isinstance(tensor, IRFullTensor): @@ -731,22 +760,50 @@ def auto_multiref(self): """ for node in self.nodes(flatten=True): if node.name == 'multiref': - multirefs = [] + if len(node.device) != 0: continue segment: IRSegment = self.segment(node) ftensor = node.input(0).parent ptensors = segment.ptensors(ftensor) - for ptensor in ptensors: - assert len(ptensor.device) > 0, \ - "Auto Multiref requires its producer nodes assigned to devices" - for devid in ptensor.device: - outputs = [] - for output in node.outputs(): - outputs.append(output.parent.select(ptensor.indmap, ptensor.valmap)) - multiref = MultiRef('', [ptensor, len(outputs)]) - for idx, otensor in enumerate(outputs): - multiref.set_output(idx, otensor) - multiref.device = devid - multirefs.append(multiref) + + multirefs = [] + + # use downstream consumers + devtensors: Dict[int, List[IRSubTensor]] = dict() + for tensor in node.outputs(): + for ctensor in segment.ctensors(tensor.parent): + for devid in ctensor.device: + if devid not in devtensors: + devtensors[devid] = [] + devtensors[devid].append(ctensor) + devids = list(devtensors.keys()) + ctensors = [ts[0] for ts in devtensors.values()] + for devid, ctensor in zip(devids, ctensors): + itensor = node.input(0).parent.select(ctensor.indmap, ctensor.valmap) + otensors = [] + for otensor in node.outputs(): + otensors.append(otensor.parent.select(ctensor.indmap, ctensor.valmap)) + multiref = MultiRef('', [itensor, len(otensors)]) + for idx, otensor in enumerate(otensors): + multiref.set_output(idx, otensor) + multiref.device = devid + multirefs.append(multiref) + + # if no downstream consumers, use upstream producers + if len(multirefs) == 0: + for ptensor in ptensors: + assert len(ptensor.device) > 0, \ + "Auto Multiref requires its producer nodes assigned to devices" + for devid in ptensor.device: + outputs = [] + for output in node.outputs(): + outputs.append(output.parent.select(ptensor.indmap, ptensor.valmap)) + multiref = MultiRef('', [ptensor, len(outputs)]) + for idx, otensor in enumerate(outputs): + multiref.set_output(idx, otensor) + multiref.device = devid + multirefs.append(multiref) + + # replace into graph fidx = self.remove(node) if node.mirror is not None: self.remove(node.mirror) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index ace8ceb8..848d4251 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -128,27 +128,30 @@ def attributes(self) -> Tuple[IRFullTensor]: def reset_dependency(self): """ Reset the node dataflow dependency - - FIXME - + Note all the predefined control dependencies will be removed. + TODO: adapter dependency is not set """ for node in self._nodes: node.clear_predecessor() node.clear_successor() # TODO: adapter dependency not set for ftensor in self._ftensors: - for ptensor, producer in zip(ftensor.ptensors, ftensor.producers): - for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + for ptensor, producer in zip(self.ptensors(ftensor), self.producers(ftensor)): + for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): if ptensor.overlap(ctensor): pidx = producer.outputs().index(ptensor) cidx = consumer.inputs().index(ctensor) producer.add_successor(pidx, consumer) consumer.add_predecessor(cidx, producer) # set mirror as control dependency - if producer.mirror and isinstance(producer, IRFwOperation): + if producer.mirror is not None and isinstance(producer, IRFwOperation): producer.add_successor(-1, producer.mirror) producer.mirror.add_predecessor(-1, producer) + # sub segments + for segment in self._nodes: + if isinstance(segment, IRSegment): + segment.reset_dependency() # ========================= Basic Graph access ======================= From 8888417ae1e1eb0acda964e4e12f0c2cf16c8987 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Sep 2022 19:09:47 +0800 Subject: [PATCH 0998/1892] fix bug on adapter for graph input and output --- cube/compiler.py | 12 ++++++++++-- cube/graph/gener/concurrent.py | 20 +++++++++---------- cube/graph/gener/gen.py | 36 ++++++++++++++++++++++------------ cube/ir/tensor.py | 2 +- cube/program.py | 23 ++++++++++++++++++++-- 5 files changed, 66 insertions(+), 27 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 2efa4b5a..efef2e31 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -100,14 +100,17 @@ def decorator(fn: Callable) -> Callable: resource = cube.runtime.resource.EnvResource() - # logic translator + # run once to get model structure and tensor shape outputs = fn(model_graph, ir_dataloader) if outputs is None: outputs = [] elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): outputs = [outputs] - graph = Program().get_graph() + # setup program output + Program().set_output(outputs) + # run policy + graph = Program().get_graph() if len(PAS) == 1: graph = PAS[0](graph, resource) elif len(PAS) == 3: @@ -127,7 +130,10 @@ def decorator(fn: Callable) -> Callable: raise RuntimeError(f"Node {node} device is not set") # generate adapter + start = time.time() graph = IRAdapterGener.gen(graph) + span = time.time() - start + print('> finish generating adapters: {:.2f} s'.format(span)) if graph.sched is not None: start = time.time() @@ -145,6 +151,8 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on diff-fusion operations: {:.2f} s'.format(span)) + # execplan.visualize(outfile='plan.png') + # plan pass for computation grouping if not graph.sched: start = time.time() diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 7b40fb2a..73152ea2 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -36,16 +36,16 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # case 1: sharing device (in-shard) inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) if inshard and len(pdevs) > 1: - fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) - # try: - # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) - # except Exception as e: - # fadapter = None - # print( - # f"full tensor: {fptensors[0].parent} cannot use grid generation.\n" - # f"Reason: {str(e)}\n" - # f"Switch to general P2P communication." - # ) + # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + try: + fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + except Exception as e: + fadapter = None + print( + f"full tensor: {fptensors[0].parent} cannot use grid generation.\n" + f"Reason: {str(e)}\n" + f"Switch to general P2P communication." + ) # Case 2: sperating device (cross-shard) if len(set(pdevs).intersection(cdevs)) == 0: diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 030c12ef..1016d684 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -8,7 +8,7 @@ from cube.graph.graph import IRGraph from cube.graph.segment import IRSegment from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap +from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.operator import IRBpOperation, IRFwOperation @@ -30,6 +30,21 @@ def to_device(tensor: IRSubTensor, device: int) -> IRFwOperation: return otensor +class DummyInputOuput(IRFwOperation): + + def __init__(self, tensor: IRSubTensor, device: int, is_input=False, is_output=False): + super().__init__('dummy', '', + 1 if is_input else 0, + 1 if is_output else 0 + ) + assert (is_input and not is_output) or (is_output and not is_input) + if is_input: + self.set_input(0, tensor) + if is_output: + self.set_output(0, tensor) + self.device = device + + def create_dummy(segment: IRSegment) -> List[IRFwOperation]: """ Create dummy operators that @@ -45,19 +60,15 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: for devid in devices: for tensor in segment.inputs(): if not isinstance(tensor, IRSubTensor): continue - assert tensor.valmap == ValueMap((0, 1)) - fwop = IRFwOperation('segment input', '', 0, 1) - fwop.set_output(0, tensor) - fwop.device = devid + assert tensor.valmap == (0, 1) + fwop = DummyInputOuput(tensor, devid, is_output=True) segment.insert(fwop, 0) fwops.append(fwop) for tensor in segment.outputs(): if not isinstance(tensor, IRSubTensor): continue - assert tensor.valmap == ValueMap((0, 1)) - fwop = IRFwOperation('segment output', '', 1, 0) - fwop.set_intput(0, tensor) - fwop.device = devid - segment.insert(fwop, -1) + assert tensor.valmap == (0, 1) + fwop = DummyInputOuput(tensor, devid, is_input=True) + segment.insert(fwop, segment.nnodes) fwops.append(fwop) return fwops @@ -166,6 +177,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bgraph: Optional[IRSegment] = graph.mirror bdummies = create_dummy(bgraph) if isinstance(bgraph, IRSegment) else [] + skip_grads = [t.parent for t in graph.inputs() + graph.outputs() if isinstance(t, IRSubTensor)] # generate adapter for inter-segments # FIXME: assume producers and consumers can run in parallel for ftensor in graph.full_tensors(): @@ -188,7 +200,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bproducers, bptensors = [], [] bconsumers, bctensors = [], [] - if isinstance(ftensor.grad, IRFullTensor): + if (ftensor not in skip_grads) and isinstance(ftensor.grad, IRFullTensor): bproducers, bptensors = graph.producers(ftensor.grad), graph.ptensors(ftensor.grad) assert all(len(ptensor.device) == 1 for ptensor in bptensors), ( f"Not support for multi-device:\n" @@ -443,7 +455,7 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): graph.insert(multiref, min_fidx) multirefs[multiref] = consumers - if isinstance(ftensor.grad, IRFullTensor): + if len(multirefs) > 0 and isinstance(ftensor.grad, IRFullTensor): graph.update_ftensor_bw(ftensor) @staticmethod diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index d027fcf9..43ae40d1 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -161,7 +161,7 @@ def __init__(self, weight: IdxChunk): weight = weight.weight assert len(weight) == 2 and all(isinstance(i, int) for i in weight), \ "expected weight to be (idx, nchunks)" - self._weight = weight + self._weight = tuple(weight) @property def weight(self) -> IdxChunk: diff --git a/cube/program.py b/cube/program.py index 5a602a10..0d88d512 100644 --- a/cube/program.py +++ b/cube/program.py @@ -2,8 +2,8 @@ from cube.graph.torch_dtype_mapping import DType2IRDType from cube.ir.cten import IRCell, IRTensor -from cube.ir.tensor import IRFullTensor -from cube.ir.operator import IRDataOperation +from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.operator import IRBpOperation, IRDataOperation from cube.graph import IRGraph from cube.graph import parser @@ -39,8 +39,27 @@ def add_nodes(self, nodes: List[IRCell]): self.add_node(node) def get_graph(self): + has_bp = any(isinstance(node, IRBpOperation) for node in self.instance._graph.nodes()) + if not has_bp: + for ftensor in self.instance._graph.full_tensors(): + ftensor.requires_grad = False + for node in self.instance._graph.nodes(flatten=True): + for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor): + itensor.grad = None + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor): + otensor.grad = None return self.instance._graph + def set_output(self, outputs: List[IRTensor]): + for otensor in outputs: + if not isinstance(otensor, IRTensor): + raise NotImplementedError("Not support for non-tensor graph output") + self.instance._graph.reset_outputs(len(outputs)) + for idx, otensor in enumerate(outputs): + self.instance._graph.set_output(idx, otensor) + def mirror_as_self(self): """ Set mirror as self. This is called when a backward is triggered. From 3e584de71c30cf9e656d77469815aa0e129103e3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Sep 2022 19:10:08 +0800 Subject: [PATCH 0999/1892] switch to no bias case --- examples/mlp/linears.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 8bc454a8..05eb5f51 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -46,9 +46,9 @@ def __init__(self, dim, mult=1, nlayers=4): self.layers = torch.nn.ModuleList([]) for lid in range(nlayers): if lid % 2 == 0: - self.layers.append(nn.Linear(dim, dim * mult)) + self.layers.append(nn.Linear(dim, dim * mult, bias=False)) else: - self.layers.append(nn.Linear(dim * mult, dim)) + self.layers.append(nn.Linear(dim * mult, dim, bias=False)) def forward(self, data): x = data From 69aede962b953eada37f0627f65e193724d2d22a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Sep 2022 11:21:52 +0000 Subject: [PATCH 1000/1892] Merged PR 1415: Hierarchical graph and hierarchical generation of adapter An IRGraph can be organized with hierarchy, with graph.group() call. The nodes of a sub-graph (IRSegment) inside the graph can also get partitioned, assigned and re-ordered (within sub-graph). Changes made to enable this: 1) IRSegment is re-designed and now can be viewed as a (distributed) operator of its parent graph 2) An IRFullTensor can be appeared at any graph hierarchy. The IRFullTensor-IRSubTensor mapping information is moved from FullTensor to IRSegment. 3) Communication can be recursively generated for each graph hierarchy. 3) New executor runtime to treat each segement as an independent graph. TODO in future: 1) The scheduling interface for graph nodes. 2) Codegen for hierarchical graph --- cube/__init__.py | 1 - cube/codegen/codegen.py | 83 +-- cube/compiler.py | 65 +- cube/execplan/execplan.py | 62 +- cube/execplan/planpass/fusion.py | 2 +- cube/execplan/planpass/grouping.py | 13 +- cube/graph/function/anchor.py | 18 +- cube/graph/function/dimops.py | 10 +- cube/graph/gener/concurrent.py | 254 +++++++ cube/graph/gener/gen.py | 643 ++++++++---------- cube/graph/graph.py | 996 +++++++++++++--------------- cube/graph/parser/__init__.py | 2 +- cube/graph/parser/converter.py | 9 - cube/graph/segment.py | 820 +++++++++++++++++++++++ cube/ir/cten.py | 10 +- cube/ir/dtype.py | 31 + cube/ir/operator.py | 101 +-- cube/ir/tensor.py | 185 ++---- cube/logics/__init__.py | 0 cube/logics/dataloader.py | 30 - cube/logics/model.py | 84 --- cube/logics/pool.py | 50 -- cube/logics/translator.py | 100 --- cube/program.py | 158 +++++ cube/runtime/adapter/collectives.py | 8 +- cube/runtime/adapter/nn.py | 1 - cube/runtime/adapter/transform.py | 16 +- cube/runtime/executor.py | 156 +++-- examples/mlp/linears.py | 4 +- 29 files changed, 2319 insertions(+), 1593 deletions(-) create mode 100644 cube/graph/gener/concurrent.py create mode 100644 cube/graph/segment.py delete mode 100644 cube/logics/__init__.py delete mode 100644 cube/logics/dataloader.py delete mode 100644 cube/logics/model.py delete mode 100644 cube/logics/pool.py delete mode 100644 cube/logics/translator.py create mode 100644 cube/program.py diff --git a/cube/__init__.py b/cube/__init__.py index 534f909f..24ef255c 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,4 +1,3 @@ -from cube import logics from cube import runtime from cube.compiler import SemanticModel, compile diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index f8197989..16158e32 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -35,14 +35,14 @@ def get_backward_callsite_io_tensors(bp_segment:IRSegment): #inputs to 'backward' outputs of 'backward' ``` """ - assert isinstance(bp_segment, IRSegment) and not bp_segment.forward + assert isinstance(bp_segment, IRSegment) and not bp_segment.isfw() input_tensors = [t for t in bp_segment.mirror.inputs() if \ isinstance(t, IRSubTensor) and \ t.requires_grad and \ not t.is_attr() ] - output_tensors = [t for t in bp_segment.mirror.outputs() if isinstance(t, IRSubTensor)] + output_tensors = [t for t in bp_segment.mirror.outputs() if isinstance(t, IRSubTensor) and t.grad is not None] input_grads = [t.grad for t in input_tensors] # WARNING !!! @@ -99,7 +99,7 @@ def is_temp_tensor(v): inputs : Iterable[IRTensor] if isinstance(node, IRSegment): - if node.forward: + if node.isfw(): outputs = node.outputs() inputs = node.inputs() else: @@ -409,7 +409,7 @@ def init_comm_groups(self): if ranks not in comm_groups: comm_groups.append(ranks) # collect groups from p2p fusion - adapters = [n for n in graph.flatten() if isinstance(n, IRAdapter)] + adapters = [n for n in graph.nodes(flatten=True) if isinstance(n, IRAdapter)] for adapter in adapters: for prim in adapter.prims: if isinstance(prim, CollectivePrim): @@ -446,7 +446,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: # parse graph body for node in self.execplan.seq(device): if isinstance(node, IRSegment): - if not node.forward: continue # skip backward segment + if not node.isfw(): continue # skip backward segment codes = self.emit_segment_code(node) elif isinstance(node, IRFwOperation): raise RuntimeError(f"Unexcepted global-level op call: {node}") @@ -535,33 +535,37 @@ def emit_node_tensors_declare(self, node: IRCell): psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}))" map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" - for itensor in node.inputs(): - name = self.tensor_naming(itensor, prefix_attr='self.') - if isinstance(itensor, IRSubTensor): - if itensor.is_attr() and not self.symbols.exist(name): - self.symbols.create(name) - sign = psign if itensor.is_param() else bsign - code = sign.format( - name=self.tensor_naming(itensor), - shape=tuple(itensor.shape), - dtype=self.dtype_map(itensor.dtype) - ) - self.model_init_statements.append(code) - tid = itensor.parent.tid - slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) - val_chunks = itensor.valmap[1] - code = map_sign.format( - attr=self.tensor_naming(itensor), tid=tid, - slicers=str(slicers), val_chunks=val_chunks - ) - self.model_init_statements.append(code) - self.model_init_statements.append('') - if isinstance(itensor, str): - if name.startswith('self.'): - if not hasattr(self._ref_module, name[5:]): - raise NotImplementedError("member attribute is not added") - for output in node.outputs(): - self.symbols.create(self.tensor_naming(output, prefix_attr='self.')) + if not isinstance(node, IRSegment): + for itensor in node.inputs(): + name = self.tensor_naming(itensor, prefix_attr='self.') + if isinstance(itensor, IRSubTensor): + if itensor.is_attr() and not self.symbols.exist(name): + self.symbols.create(name) + sign = psign if itensor.is_param() else bsign + code = sign.format( + name=self.tensor_naming(itensor), + shape=tuple(itensor.shape), + dtype=self.dtype_map(itensor.dtype) + ) + self.model_init_statements.append(code) + tid = itensor.parent.tid + slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) + val_chunks = itensor.valmap[1] + code = map_sign.format( + attr=self.tensor_naming(itensor), tid=tid, + slicers=str(slicers), val_chunks=val_chunks + ) + self.model_init_statements.append(code) + self.model_init_statements.append('') + if isinstance(itensor, str): + if name.startswith('self.'): + if not hasattr(self._ref_module, name[5:]): + raise NotImplementedError("member attribute is not added") + for output in node.outputs(): + self.symbols.create(self.tensor_naming(output, prefix_attr='self.')) + else: + for sub_node in node.nodes(): + self.emit_node_tensors_declare(sub_node) return def emit_segment_code(self, segment: IRSegment) -> List[str]: @@ -669,7 +673,7 @@ def recompute(tensor_2222): nodes : List[IRCell] = [node for i, node in i_nodes] - subseg = self.execplan.graph.segment(nodes) + subseg = self.execplan.graph.create_segment(nodes) inputs = [t for t in subseg.inputs() if not t.is_attr()] input_names = [self.tensor_naming(t) for t in inputs] @@ -976,8 +980,9 @@ def emit_node(self, node: IRCell, name: str) -> str: """ Emit node / subgraph code """ - fsign = '{outputs} = cube.runtime.executor.fexecute({model}, *{inputs}, requires_grad={req_grad})' - bsign = '{input_grads} = cube.runtime.executor.backward({input_tensors}, {output_tensors}, {output_grads})' + fsign = '{outputs} = cube.runtime.executor.fexecute({name}, {model}, *{inputs}, requires_grad={req_grad})' + asign = '{outputs} = cube.runtime.executor.aexecute({model}, *{inputs}, requires_grad={req_grad})' + bsign = '{input_grads} = cube.runtime.executor.backward({name}, {input_tensors}, {output_tensors}, {output_grads})' inputs = self.tuple_naming(node.inputs(), skip_attr=True, prefix_attr='model.') outputs = self.return_naming(node.outputs(), skip_attr=True, prefix_attr='model.') @@ -985,9 +990,10 @@ def emit_node(self, node: IRCell, name: str) -> str: if isinstance(node, IRSegment): # emit forward - if node.forward: + if node.isfw(): code = fsign.format( outputs = outputs, + name = f"'{name}'", model = f'model.{name}', inputs = inputs, req_grad = req_grad @@ -1002,6 +1008,7 @@ def emit_node(self, node: IRCell, name: str) -> str: assert tensor == 1.0, "Loss gradient should be 1.0" output_grads[idx] = None code = bsign.format( + name = f"'{self.node_naming(node.mirror)}'", input_grads = self.return_naming(input_grads), input_tensors = self.tuple_naming(input_tensors, skip_attr=True, prefix_attr='model.'), output_tensors = self.tuple_naming(output_tensors, skip_attr=True, prefix_attr='model.'), @@ -1016,7 +1023,7 @@ def emit_node(self, node: IRCell, name: str) -> str: code = f'{outputs} = next(dataloader)' elif isinstance(node, IRAdapter): - code = fsign.format( + code = asign.format( outputs = outputs, model = f'model.{name}', inputs = inputs, @@ -1024,7 +1031,7 @@ def emit_node(self, node: IRCell, name: str) -> str: ) elif isinstance(node, IRWeightReducer): - code = fsign.format( + code = asign.format( outputs = outputs, model=f'model.{name}', inputs='()', diff --git a/cube/compiler.py b/cube/compiler.py index 33a9b6d6..efef2e31 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -5,15 +5,11 @@ import cube -from cube.graph import parser from cube.graph.gener.gen import IRAdapterGener from cube.graph.graph import IRGraph from cube.ir.operator import IRDataOperation from cube.graph.function.anchor import IRGraphAnchor -from cube.logics.pool import SchedulePool -from cube.logics.translator import LogicTranslator - from cube.execplan import ExecutionPlan from cube.execplan.planpass.fusion import DiffFusion from cube.execplan.planpass.grouping import Grouping @@ -23,47 +19,7 @@ from cube.profiler.timer import print_each_rank from cube.runtime.syndata import CubeDataLoader, SciLoopVariables - -class SemanticModel: - - def __init__(self, model: torch.nn.Module, input_shapes): - """ - Create semantic model based on AI Scientist description. - """ - dist = torch.distributed.is_initialized() - if (not dist) or (dist and torch.distributed.get_rank() == 0): - self.ir_graph = parser.convert_model( - model, input_shapes=input_shapes - ) - else: - self.ir_graph = None - self._loaded_module = None - - def get_graph(self): - return self.ir_graph - - def load_module(self, filename: str, load_content=True): - import importlib.util - spec = importlib.util.spec_from_file_location("GenModel", filename) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self._loaded_module = module.GenModel().cuda() - if load_content: - print_each_rank("> loading parameter content...") - # TODO: make hardcode ./fullmodel.pt programmable - self._loaded_module.load_attr_content('./fullmodel.pt') - - def get_gen_module(self): - return self._loaded_module - - def clear_module(self): - self._loaded_module = None - - def __call__(self, *args): - if self._loaded_module: - return self._loaded_module(*args) - else: - return self.ir_graph(*args) +from cube.program import Program, SemanticDataLoader, SemanticModel def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, @@ -105,7 +61,7 @@ def train_step(model, dataloader): PAS = (PAS,) model_graph = model.get_graph() - ir_dataloader = parser.convert_dataloader(dataloader) + ir_dataloader = SemanticDataLoader(dataloader) if torch.distributed.is_initialized(): # multiple device @@ -142,17 +98,19 @@ def decorator(fn: Callable) -> Callable: compile_start = time.time() - SchedulePool().clear() resource = cube.runtime.resource.EnvResource() - # logic translator + # run once to get model structure and tensor shape outputs = fn(model_graph, ir_dataloader) if outputs is None: outputs = [] elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): outputs = [outputs] - graph = LogicTranslator.gen_logic_graph(outputs=outputs) + # setup program output + Program().set_output(outputs) + # run policy + graph = Program().get_graph() if len(PAS) == 1: graph = PAS[0](graph, resource) elif len(PAS) == 3: @@ -167,12 +125,15 @@ def decorator(fn: Callable) -> Callable: # check assignment and remove anchor node for node in graph.nodes(): if isinstance(node, IRGraphAnchor) or isinstance(node.mirror, IRGraphAnchor): - graph.detach(node) - elif len(node.device) == 0: + continue + if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") # generate adapter + start = time.time() graph = IRAdapterGener.gen(graph) + span = time.time() - start + print('> finish generating adapters: {:.2f} s'.format(span)) if graph.sched is not None: start = time.time() @@ -190,6 +151,8 @@ def decorator(fn: Callable) -> Callable: span = time.time() - start print('> planpass on diff-fusion operations: {:.2f} s'.format(span)) + # execplan.visualize(outfile='plan.png') + # plan pass for computation grouping if not graph.sched: start = time.time() diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index d9173d2f..67ea9112 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -17,7 +17,7 @@ def __init__(self, graph: IRGraph): self._inference_only = not any( isinstance(n, IRBpOperation) or \ (isinstance(n, IRAdapter) and not n.forward) or \ - (isinstance(n, IRSegment) and not n.forward) for n in graph.nodes() + (isinstance(n, IRSegment) and not n.isfw()) for n in graph.nodes() ) # execution sequence for each device @@ -99,21 +99,19 @@ def set(self, devid: int, seq: List[IRCell]): raise TypeError("Expected a list of Cell") self._seq[devid] = seq - def analyze(self, + def visualize(self, map2time: Optional[Callable] = None, map2mem: Optional[Callable] = None, map2name: Optional[Callable] = None, - outfile = None): + outfile: Optional[str] = None): """ - Draw the execution timeline. + Visualize the graph - Args: - span (List[int]): - length equal to schedule unit num. - Each element stands for the time span for corresponding Cell - - outfile: - the output file name + @param map2time Optional[Callable]: node to time (int) map. + @param map2mem Optional[Callable]: node to memory consumption (int) map + @param map2name Optional[Callable]: node to name (str) map + @param outfile Optional[str]: the output file name. + If given, will save the visualized execution plan in file. """ ndevice = len(self.devices()) # timeline [ [ (start_time, end_time), ... ], ... ] @@ -124,10 +122,11 @@ def analyze(self, if map2time is None: def map2time(node): - if isinstance(node, IRGraph): + if isinstance(node, IRSegment): span = 0 for node in node.nodes(): span += map2time(node) + return span if isinstance(node, IRFwOperation): return 1 if isinstance(node, IRBpOperation): @@ -138,7 +137,7 @@ def map2time(node): if map2mem is None: def map2mem(node): - if isinstance(node, IRGraph): + if isinstance(node, IRSegment): peak_mem = 0 curr_mem = 0 for node in node.nodes(): @@ -152,20 +151,21 @@ def map2mem(node): if map2name is None: def map2name(node): - if isinstance(node, IRGraph): - if all([isinstance(n, IRFwOperation) for n in node.nodes()]): - return f'f{node._id}' - if all([isinstance(n, IRBpOperation) for n in node.nodes()]): - if node.mirror is not None: - return f'b{node.mirror._id}' + if isinstance(node, IRSegment): + if node.isfw(): + return f'f{node.cid}' + elif node.isbw(): + return f'b{node.mirror.cid}' if isinstance(node, IRFwOperation): - return f'f{node._id}' + return f'f{node.cid}' if isinstance(node, IRBpOperation): - return f'b{node.mirror._id}' - return str(node._id) + return f'b{node.mirror.cid}' + if isinstance(node, IRAdapter): + return f'a{node.cid}' + return f'?{node.cid}' def map2color(node): - if isinstance(node, IRGraph): + if isinstance(node, IRSegment): return map2color(node.nodes(0)) if isinstance(node, IRFwOperation): return '#4472C4' # excel blue @@ -174,13 +174,12 @@ def map2color(node): if isinstance(node, IRAdapter): return '#70AD47' # excel green + self.graph.reset_dependency() for node in self.graph.nodes(): span, mem = map2time(node), map2mem(node) + # calculate time + start_times = [] for device in node.device: - # memory - device_mem[device] += mem - if device_peak_mem[device] < device_mem[device]: - device_peak_mem[device] = device_mem[device] # tight execution if no dependency if len(device_timeline[device]) == 0: start_time = 1 @@ -196,8 +195,17 @@ def map2color(node): if other_node in node.predecessors(): start_time = max(start_time, end_time) break + start_times.append(start_time) + + start_time = max(start_times) + for device in node.device: + # time device_timeline[device].append((start_time, start_time + span)) device_nodes[device].append(node) + # memory + device_mem[device] += mem + if device_peak_mem[device] < device_mem[device]: + device_peak_mem[device] = device_mem[device] max_time = max( [tline[-1][1] for tline in device_timeline if len(tline) != 0] diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 9a09b055..ca6e0d61 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -29,7 +29,7 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: if node.forward: ret = DiffFusion.nnfuse(node) cnt = cnt+1 if ret else cnt - if isinstance(node, IRSegment) and node.forward: + if isinstance(node, IRSegment) and node.isfw(): for fnode in node.nodes(): if isinstance(fnode, IRAdapter): if node.forward: diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index f974f49d..16164755 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -6,11 +6,13 @@ from cube.execplan import ExecutionPlan from cube.execplan.planpass.planpass import PlanPass from cube.ir.adapter import IRAdapter +from cube.ir.adapter.prim import IdentityPrim from cube.ir.operator import IRFwOperation from cube.ir.cten import IRCell class Grouping(PlanPass): + @staticmethod def apply(execplan: ExecutionPlan) -> ExecutionPlan: """ @@ -20,11 +22,11 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: fgroups, bgroups = Grouping.group(execplan) for devid in execplan.devices(): for fpieces, bpieces in zip(fgroups[devid], bgroups[devid]): - fsubgraph = graph.segment(fpieces) + fsubgraph = graph.create_segment(fpieces) if bpieces is not None: - bsubgraph = graph.segment(bpieces) + bsubgraph = graph.create_segment(bpieces) IRCell.make_pair(fsubgraph, bsubgraph) - subgraphs = [fsubgraph] if bpieces is None else [fsubgraph, bsubgraph] + subgraphs = [fsubgraph] if bpieces is None else [fsubgraph, fsubgraph.mirror] for subgraph in subgraphs: # update execution plan: replace the nodes with the subgraph pieces = subgraph.nodes() @@ -32,6 +34,11 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: execplan.at(devid).insert(idx, subgraph) for node in pieces: execplan.at(devid).remove(node) + # remove identity adapter + for adapter in execplan.seq(devid): + if isinstance(adapter, IRAdapter): + if all(isinstance(prim, IdentityPrim) for prim in adapter.prims): + execplan.at(devid).remove(adapter) return execplan @staticmethod diff --git a/cube/graph/function/anchor.py b/cube/graph/function/anchor.py index 18c82803..2fec2200 100644 --- a/cube/graph/function/anchor.py +++ b/cube/graph/function/anchor.py @@ -1,24 +1,26 @@ from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRSubTensor class IRGraphAnchor(IRFwOperation): """ - The anchor function for navigation inside the graph + The anchor function serves for + 1) navigation inside the graph + 2) staging boundary inside the graph + + This operator will eventually be removed from graph, + user doesn't need to manipulate it. """ def __init__(self, signature: str, name: str): super().__init__(name, signature, 0, 1) self.kwargs['name'] = name self.set_output(0, None) + def infer_dtype(self): + return + def infer_shape(self): return True def __repr__(self) -> str: - sign = self.signature.split('.')[-1] - ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_attr()] - dscp = (f"FwOp{self._id}(sign={sign}[{self.name}], " - f"inputs={ins}, " - f"outputs={self.outputs()})") - return dscp + return f"AnchorOp-{self.cid}(name={self.name})" diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 5ad2502c..d0b148ed 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -68,8 +68,10 @@ import string from cube.ir.cten import IRTensor +from cube.ir.dtype import DTypeInferRule from cube.ir.operator import IRFwOperation from cube.algorithm.factory import DistAlgorithmFactory +from cube.ir.tensor import IRSubTensor _kSpecialIdentifiers = ('*', '?') @@ -630,10 +632,12 @@ def oanno(self, index: int) -> Tuple[DimAnno]: def infer_shape(self) -> bool: """ - Shape inference using the matched annotation + Shape and dtype inference using the matched annotation and tensor. @return sucess: True if successfully inferred shape """ + idtypes = [t.dtype for t in self._inputs if isinstance(t, IRTensor)] + odtype = DTypeInferRule.infer(self, idtypes) for oidx, otensor in enumerate(self.outputs()): shape_anno = self.oanno(oidx) shape = [] @@ -643,6 +647,10 @@ def infer_shape(self) -> bool: accum *= self.anno.getlen(identifier) shape.append(accum) otensor.shape = shape + # set output shape + if isinstance(otensor, IRSubTensor): + otensor.parent.dtype = odtype + otensor.dtype = odtype # print(f'=> sign: {self.signature} anno: {self.anno}\n' # f'=> inputs: {self.inputs()}\n' # f'=> outputs: {self.outputs()}') diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py new file mode 100644 index 00000000..73152ea2 --- /dev/null +++ b/cube/graph/gener/concurrent.py @@ -0,0 +1,254 @@ +""" +Concurrent producer / consumer Adapter Generator +""" +from typing import List, Optional +import copy + +from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from cube.ir.adapter.prim import IRAdapterPrim +from cube.ir.adapter import IRAdapter +from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim + +from cube.graph.gener.layout import GridLayout + + +class ConcurrentGener: + + @staticmethod + def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor]) -> Optional[IRAdapter]: + """ + Generate forward adapter and backward adapter + + @param fptensors List[IRSubTensor]: forward producer tensors + @param fctensors List[IRSubTensor]: forward consumer tensors + @param bptensors List[IRSubTensor]: backward producer tensors + @param bctensors List[IRSubTensor]: backward consumer tensors + + @return fadapter Optional[IRAdapter]: forward adapter + None indicate no adapter required. + """ + pdevs = tuple(t.device[0] for t in fptensors) + cdevs = tuple(t.device[0] for t in fctensors) + + fadapter: IRAdapter = None + + # case 1: sharing device (in-shard) + inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) + if inshard and len(pdevs) > 1: + # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + try: + fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + except Exception as e: + fadapter = None + print( + f"full tensor: {fptensors[0].parent} cannot use grid generation.\n" + f"Reason: {str(e)}\n" + f"Switch to general P2P communication." + ) + + # Case 2: sperating device (cross-shard) + if len(set(pdevs).intersection(cdevs)) == 0: + pass + + # Case 3: General cases + # warnings.warn('The adapter is generated using P2P communication') + if fadapter is None: + fadapter = ConcurrentGener.gen_general(fptensors, fctensors, bptensors, bctensors) + + if set(pdevs) == set(cdevs) and fadapter.mirror is not None: + fadapter.differentiable = True + fadapter.mirror.differentiable = True + + return fadapter + + @staticmethod + def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], + allow_reorder=False): + ftensor = fptensors[0].parent + # producer grid layout + ilayout = GridLayout.togrid(ftensor, fptensors) + # reorder ctensors to match with ptensors + devs = [ptensor.device for ptensor in ilayout.mat.flatten()] + ctensors = [None] * len(devs) + for ctensor in fctensors: + idx = devs.index(ctensor.device) + ctensors[idx] = ctensor + assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" + # consumer grid layout + olayout = GridLayout.togrid(ftensor, ctensors) + # find path + paths, fprims = ilayout.path(olayout) + + # re-assign the operator if miss-ordered + names, from_dev, to_dev = [], [], [] + for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): + assert len(itensor.device) == 1 and len(otensor.device) == 1, \ + "Expect tensor only has one device. Report this as a bug" + if itensor.device != otensor.device: + inode, onode = itensor.cell, otensor.cell + names.append(f'{onode.name}{onode.cid}') + from_dev.append(onode.device[0]) + to_dev.append(inode.device[0]) + if allow_reorder: + onode.device = inode.device + if onode.mirror is not None: + onode.mirror.device = inode.device + else: + raise RuntimeError("device mismatch. Try to enable reorder") + if len(names) > 0: + print(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') + + fadapter = IRAdapter(fptensors, fctensors) + fadapter.prims = fprims + + # generate backward + grad: IRFullTensor = ftensor.grad + bprims = [] + if grad is not None and (len(bptensors) != 0 or len(bctensors) != 0): + # reorder ptensors to match with forward + ptensors = [None] * len(devs) + for bptensor in bptensors: + idx = devs.index(bptensor.device) + assert ptensors[idx] is None, "same device of different tensors" + ptensors[idx] = bptensor + ilayout = GridLayout.togrid(grad, ptensors) + olayout = GridLayout.togrid(grad, bctensors) + paths, bprims = ilayout.path(olayout) + # check the device order + for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): + assert len(itensor.device) == len(otensor.device), "backward device not match" + badapter = IRAdapter(bptensors, bctensors) + badapter.prims = bprims + IRAdapter.make_pair(fadapter, badapter) + + return fadapter + + @staticmethod + def gen_cross_shard(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> IRAdapter: + pass + + @staticmethod + def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor]) -> IRAdapter: + """ + A general way to generate adapter. + + @param ftensor IRFullTensor + @return adapter IRAdapter + """ + fprims = [] + for ctensor in fctensors: + fprims += ConcurrentGener.gen_subtensor(ctensor, fptensors) + fadapter = IRAdapter(fptensors,fctensors) + fadapter.prims = fprims + # backward + if len(bptensors) > 0 and len(bctensors) > 0: + bprims = [] + for cgrad in bctensors: + bprims += ConcurrentGener.gen_subtensor(cgrad, bptensors) + badapter = IRAdapter(bptensors, bctensors) + badapter.prims = bprims + IRAdapter.make_pair(fadapter, badapter) + return fadapter + + @staticmethod + def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRAdapterPrim]: + """ + Generate communiction primitives for ctensor + + @param ctensor IRSubTensor: the consumed tensor as destination + @param ptensors List[IRSubTensor]: the produced tensors as source + + @return prims List[IRAdapterPrim]: the primitives for adapter + """ + # category to local tensor and remote tensor + local = [t for t in ptensors if t.device == ctensor.device] + remote = [t for t in ptensors if t.device != ctensor.device] + prims = [] + + # ==== select ==== # + intersections = [] + # check local + for itensor in local+remote: + if itensor.device == ctensor.device and itensor == ctensor: + return [] + common: Optional[IRSubTensor] = itensor.common(ctensor) + if common is None: + continue + common.cell = itensor.cell + intersections.append(common) + # create select primitive + if common != itensor: + indmap = [] + for dim in range(itensor.ndims): + (s1, e1), (s2, e2) = itensor.indmap[dim], common.indmap[dim] + start = s2 - s1 + end = start + e2 - s2 + indmap.append((start, end)) + indmap = IndexMap(tuple(indmap)) + assert itensor.valmap == common.valmap, "Value map not same" + valmap = ValueMap((0, 1)) + select_prim = SelectPrim(itensor, indmap, valmap, common) + prims.append(select_prim) + if itensor.device == ctensor.device and common == ctensor: + return [select_prim] + # TODO: check union == subtensor + if common == ctensor: + break + + # print(intersections) + # ====== move ===== # + tmoved = [] + for tensor in intersections: + assert len(tensor.device) == 1 and len(ctensor.device) == 1, "Expected only one device." + mtensor = tensor + if tensor.device != ctensor.device: + mtensor = copy.copy(tensor) + mtensor.cell = ctensor.cell + prims.append(MovePrim(tensor, mtensor)) + tmoved.append(mtensor) + + # ===== merge ===== # + remain_tensors: List[IRSubTensor] = copy.copy(tmoved) + if ctensor in remain_tensors: + return prims + out = None + while out != ctensor: + out, merged = None, False + for idx1 in range(len(remain_tensors) - 1): + for idx2 in range(idx1+1, len(remain_tensors)): + t1, t2 = remain_tensors[idx1], remain_tensors[idx2] + catdim = t1.catdim(t2) + if catdim is not None: + tensors = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] + out = tensors[0].concat(tensors[1], dim=catdim) + out.cell = ctensor.cell + prims.append(MergeDimPrim(tensors, out, catdim)) + merged = True + break + # reduction + if t1.accumable(t2): + out = t1.accum(t2) + out.cell = ctensor.cell + prims.append(SumPrim([t1, t2], out)) + merged = True + break + if merged: + remain_tensors.remove(t1) + remain_tensors.remove(t2) + remain_tensors.append(out) + break + if out is None: + ptensors = '\n\t'.join(t.extra_repr() for t in ptensors) + remain = '\n\t'.join(t.extra_repr() for t in remain_tensors) + raise RuntimeError( + f"Fail to build adapter.\n" + f"FullTensor:{ctensor.parent}\n" + f"Produced Tensors:\n\t{ptensors}\n" + f"Consumed Tensors:\n\t{ctensor.extra_repr()}\n" + f"Consumer:\n\t{ctensor.cell}\n" + f"Remain Tensor:\n\t{remain}" + ) + return prims diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 2b52a58a..1016d684 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -2,17 +2,75 @@ from typing import Dict, List, Optional, Tuple import copy +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.gener.concurrent import ConcurrentGener + from cube.graph.graph import IRGraph +from cube.graph.segment import IRSegment from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.operator import IRBpOperation, IRFwOperation -from cube.ir.adapter.prim import IRAdapterPrim from cube.ir.adapter import IRAdapter, IRWeightReducer from cube.graph.function.function import Add, Cat, Identity, MultiRef -from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim -from cube.graph.gener.layout import GridLayout + + +def to_device(tensor: IRSubTensor, device: int) -> IRFwOperation: + """ + This is used for changing tensor device + """ + fwop = IRFwOperation('dummy', 'dummpy', 1, 0) + fwop.set_input(0, tensor) + fwop.device = device + otensor = fwop.input(0) + otensor.grad = copy.copy(tensor.grad) + if isinstance(otensor.grad, IRSubTensor): + otensor.grad.cell = fwop + return otensor + + +class DummyInputOuput(IRFwOperation): + + def __init__(self, tensor: IRSubTensor, device: int, is_input=False, is_output=False): + super().__init__('dummy', '', + 1 if is_input else 0, + 1 if is_output else 0 + ) + assert (is_input and not is_output) or (is_output and not is_input) + if is_input: + self.set_input(0, tensor) + if is_output: + self.set_output(0, tensor) + self.device = device + + +def create_dummy(segment: IRSegment) -> List[IRFwOperation]: + """ + Create dummy operators that + 1) produce segment input tensors + 2) consume segment output tensors + + @param segment IRSegment: the target segment + + @return nodes List[IRCell]: the generated operation + """ + devices = segment.device + fwops = [] + for devid in devices: + for tensor in segment.inputs(): + if not isinstance(tensor, IRSubTensor): continue + assert tensor.valmap == (0, 1) + fwop = DummyInputOuput(tensor, devid, is_output=True) + segment.insert(fwop, 0) + fwops.append(fwop) + for tensor in segment.outputs(): + if not isinstance(tensor, IRSubTensor): continue + assert tensor.valmap == (0, 1) + fwop = DummyInputOuput(tensor, devid, is_input=True) + segment.insert(fwop, segment.nnodes) + fwops.append(fwop) + return fwops class IRAdapterGener: @@ -21,96 +79,82 @@ class IRAdapterGener: def gen(graph: IRGraph) -> IRGraph: """ Generate tensor adapter for both activations and weights + Note weight reducers are always append to the last. - Args: - graph: IRGraph. - eager (Boolean): - if True, - each adapter will be inserted right after it's ready to execute. - if False (i.e., lazy), - each adatper will be inserted right before the tensor needs it. - Note weight reducers are always append to last. - Returns: - graph (IRGraph) + @param graph IRGraph: the graph without adapter + @return graph IRGraph: the graph with adapter inserted """ - # insert identity operator for graph output - devs = set() - for node in graph.nodes(): - devs.update(node.device) - outputs = [otensor for otensor in graph.outputs() if isinstance(otensor, IRSubTensor)] - all_identities = [] - for otensor in outputs: - identity = Identity('', [otensor]) - graph.attach(identity, len(graph.nodes())) - identites = graph.replicate(identity, times=len(devs)) - all_identities += identites - for devid, identity in zip(devs, identites): - graph.assign(identity, devid) + # remove anchor node + graph = IRAdapterGener.remove_anchor(graph) # update the gradient before generate adapter - for node in graph.nodes(): - if isinstance(node, IRBpOperation): - idx = graph.detach(node) - node.update() - graph.attach(node, idx) + graph = IRAdapterGener.update_grad(graph) + # generate adapters for activation graph = IRAdapterGener.gen_activation(graph) + # generate weight reducer graph = IRAdapterGener.gen_weight(graph) - # remove inserted identity - for identity in all_identities: - graph.detach(identity) + # fuse consecutive non-differentiable adapters into one + graph = IRAdapterGener.fusion(graph) + # print(graph.extra_repr()) + return graph + + @staticmethod + def update_grad(graph: IRSegment): + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): + graph.update_ftensor_bw(ftensor) + for node in graph.nodes(): + if isinstance(node, IRSegment) and node.isbw(): + IRAdapterGener.update_grad(node) + return graph + + @staticmethod + def remove_anchor(graph: IRSegment): + for anchor in graph.nodes(): + if isinstance(anchor, IRGraphAnchor): + graph.remove(anchor) + if anchor.mirror is not None: + graph.mirror.remove(anchor.mirror) + elif isinstance(anchor, IRSegment): + IRAdapterGener.remove_anchor(anchor) return graph @staticmethod def gen_weight(graph: IRGraph) -> IRGraph: - # step 1: get weight and gradient - # weights: Dict[weight_id: int, IRSubTensor] - # grads : Dict[weight_id: int, Dict[device: int, List[grad: IRSubTensor]]] - grads = dict() weights = dict() - for fnode in graph.nodes(): - if not isinstance(fnode, IRFwOperation): - continue - devid = fnode.device[0] + for fnode in graph.nodes(flatten=True): + if not isinstance(fnode, IRFwOperation): continue + assert len(fnode.device) == 1 for wtensor in fnode.inputs(): if isinstance(wtensor, IRSubTensor) and wtensor.is_param(): - grad: Optional[IRSubTensor] = wtensor.grad - if grad is None: continue - # nothing to sync - if grad.valmap == (0, 1): - continue - if wtensor._id not in grads: - grads[wtensor._id] = dict() - weights[wtensor._id] = wtensor - if devid not in grads[wtensor._id]: - grads[wtensor._id][devid] = list() - if grad in grads[wtensor._id][devid]: - raise RuntimeError( - "Find two same gradient (not expected). " - "This is usually due to replicated node assigned to same device. " - f"\nCheck node:\n\t{fnode}" - ) - grads[wtensor._id][devid].append(grad) - # step 2: generate reducers. - # reducers: tuple(ranks): List[weight] + if wtensor.grad is None: continue + if wtensor.parent not in weights: + weights[wtensor.parent] = dict() + if wtensor not in weights[wtensor.parent]: + weights[wtensor.parent][wtensor] = set() + weights[wtensor.parent][wtensor].add(wtensor.device[0]) + reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() - for wid in grads: - ranks = list(grads[wid].keys()) - ranks.sort() - ranks = tuple(ranks) # ranks are used for group - if len(ranks) == 1: - continue - if ranks not in reducers: - reducers[ranks] = list() - reducers[ranks].append(weights[wid]) + for ftensor, subtensors in weights.items(): + # TODO: check no overlapping (not same) weights on a device + for subw in subtensors: + if len(subtensors[subw]) == 1: + continue + devices = list(subtensors[subw]) + devices.sort() + devices = tuple(devices) + if devices not in reducers: + reducers[devices] = [] + reducers[devices].append(subw) # generate reducer for each rank - for ranks in reducers: - weights = reducers[ranks] + for devices in reducers: + weights = reducers[devices] opt_op = IRWeightReducer(weights) - opt_op.device = list(ranks) - graph._nodes.append(opt_op) + opt_op.device = list(devices) + graph.insert(opt_op, graph.nnodes) return graph @staticmethod - def gen_activation(graph: IRGraph) -> IRGraph: + def gen_activation(graph: IRSegment) -> IRSegment: """! Generate adapter for activation tensors. The forward/backward adapter is inserted before the first consumers of its full tensor. @@ -119,270 +163,96 @@ def gen_activation(graph: IRGraph) -> IRGraph: @return graph IRGraph: the (inplace) modified graph with activation adapters. """ + def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: + # e.g., loss or parameter/buffer + if len(ptensors) == 0 or len(ctensors) == 0: + return True + # direct connection + for ctensor in ctensors: + if not any(t == ctensor and set(ctensor.device).issubset(set(t.device)) for t in ptensors): + return False + return True + + fdummies = create_dummy(graph) + bgraph: Optional[IRSegment] = graph.mirror + bdummies = create_dummy(bgraph) if isinstance(bgraph, IRSegment) else [] + + skip_grads = [t.parent for t in graph.inputs() + graph.outputs() if isinstance(t, IRSubTensor)] + # generate adapter for inter-segments + # FIXME: assume producers and consumers can run in parallel for ftensor in graph.full_tensors(): # backward will gen in forward if ftensor.is_param() or ftensor.is_grad(): continue - # no consumer usually mean loss - if len(ftensor.consumers) == 0: - continue - # graph attribute: buffer - if len(ftensor.producers) == 0: - continue - # no require for communication - if len(ftensor.consumers) == 1 and len(ftensor.producers) == 1 and \ - ftensor.consumers[0].device == ftensor.producers[0].device: + + # optimization: local fusion / multiref on producer / consumer + ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) + IRAdapterGener.local_consumer_multiref(graph, ftensor) + + # print(graph.debug_tensor_map_str(ftensor)) + # print(graph.debug_tensor_map_str(ftensor.grad)) + + # producers can be operators and graph inputs + fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) + assert all(len(ptensor.device) == 1 for ptensor in fptensors), "Not support for multi-device" + fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) + assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" + + bproducers, bptensors = [], [] + bconsumers, bctensors = [], [] + if (ftensor not in skip_grads) and isinstance(ftensor.grad, IRFullTensor): + bproducers, bptensors = graph.producers(ftensor.grad), graph.ptensors(ftensor.grad) + assert all(len(ptensor.device) == 1 for ptensor in bptensors), ( + f"Not support for multi-device:\n" + f"{[ptensor.device for ptensor in bptensors]}" + f"{[ptensor.cell for ptensor in bptensors]}" + ) + bconsumers, bctensors = graph.consumers(ftensor.grad), graph.ctensors(ftensor.grad) + assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" + + if skip(fptensors, fctensors) and skip(bptensors, bctensors): continue - # optimization: local fusion on producer - if graph.train: - ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) - IRAdapterGener.local_consumer_multiref(graph, ftensor) - - ptensors, ctensors = ftensor.ptensors, ftensor.ctensors - pdevs = tuple(ptensor.device[0] for ptensor in ptensors) - cdevs = tuple(ctensor.device[0] for ctensor in ctensors) - - fadapter = None - # Case 1: sharing device (in-shard) - inshard = set(pdevs) == set(cdevs) and len(ptensors) == len(ctensors) and \ - len(pdevs) == len(ptensors) - if inshard and len(pdevs) > 1: - try: - fadapter = IRAdapterGener.gen_in_shard(ftensor, allow_reorder=True) - except Exception as e: - fadapter = None - print( - f"full tensor: {ftensor} cannot use grid generation.\n" - f"Reason: {str(e)}\n" - f"Switch to general P2P communication." - ) - - # Case 2: sperating device (cross-shard) - if len(set(pdevs).intersection(cdevs)) == 0: - pass - - # Case 3: General cases - # warnings.warn('The adapter is generated using + fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors) if fadapter is None: - fadapter = IRAdapterGener.gen_general(ftensor) + continue badapter: Optional[IRAdapter] = fadapter.mirror - + if (badapter is not None and len(fadapter.prims) == 0 and len(badapter.prims) == 0) or \ (badapter is None and len(fadapter.prims) == 0): continue - # set differentiable for autograd generation - if inshard and badapter is not None: - fadapter.differentiable = True - badapter.differentiable = True - # insert forward adapter - fidx = min([graph.nodes().index(consumer) for consumer in ftensor.consumers]) - graph._nodes.insert(fidx, fadapter) + # graph.insert(fadapter, max(producers) + 1) + graph.insert(fadapter, min(graph.index(c) for c in fconsumers)) - # insert backward + # insert backward adapter if badapter is not None: - bidx = min(graph.nodes().index(consumer) for consumer in ftensor.grad.consumers) - graph._nodes.insert(bidx, badapter) + assert isinstance(badapter, IRAdapter) + assert isinstance(bgraph, IRSegment) + bproducers = [ + bgraph.index(consumer.mirror) + 1 for \ + consumer in graph.consumers(ftensor) + ] + bidx = max(bproducers) if len(bproducers) > 0 else 0 + bgraph.insert(badapter, bidx) + + # remove dummy op + for dummy_op in fdummies: + graph.remove(dummy_op) + for dummy_op in bdummies: + bgraph.remove(dummy_op) + + # generate adapter for each segment + segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] + for segment in segments: + IRAdapterGener.gen_activation(segment) - # print(graph.extra_repr()) return graph @staticmethod - def gen_in_shard(ftensor: IRFullTensor, allow_reorder=False) -> Optional[IRAdapter]: - """ - Generate communication for sharing devices (SPMD-like) - - @param ftensor: IRFullTensor - @param ptensors: List[IRSubTensor]: produced subtensors - @param ctensors: List[IRSubTensor]: consumed subtensors - - @return adapter Optional[IRAdapter]: generated adapter. - """ - # producer grid layout - ilayout = GridLayout.togrid(ftensor, ftensor.ptensors) - # reorder ctensors to match with ptensors - devs = [ptensor.device for ptensor in ilayout.mat.flatten()] - ctensors = [None] * len(devs) - for ctensor in ftensor.ctensors: - idx = devs.index(ctensor.device) - ctensors[idx] = ctensor - assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" - # consumer grid layout - olayout = GridLayout.togrid(ftensor, ctensors) - # find path - paths, fprims = ilayout.path(olayout) - - # re-assign the operator if miss-ordered - names, from_dev, to_dev = [], [], [] - for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): - assert len(itensor.device) == 1 and len(otensor.device) == 1, \ - "Expect tensor only has one device. Report this as a bug" - if itensor.device != otensor.device: - inode, onode = itensor.cell, otensor.cell - names.append(f'{onode.name}{onode.cid}') - from_dev.append(onode.device[0]) - to_dev.append(inode.device[0]) - if allow_reorder: - onode.device = inode.device - if onode.mirror is not None: - onode.mirror.device = inode.device - else: - raise RuntimeError("device mismatch. Try to enable reorder") - if len(names) > 0: - print(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') - - fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) - fadapter.prims = fprims - - # generate backward - grad: IRFullTensor = ftensor.grad - bprims = [] - if grad is not None and (len(grad.ptensors) != 0 or len(grad.ctensors) != 0): - # reorder ptensors to match with forward - ptensors = [None] * len(devs) - for ptensor in grad.ptensors: - idx = devs.index(ptensor.device) - assert ptensors[idx] is None, "same device of different tensors" - ptensors[idx] = ptensor - ilayout = GridLayout.togrid(grad, ptensors) - olayout = GridLayout.togrid(grad, grad.ctensors) - paths, bprims = ilayout.path(olayout) - # check the device order - for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): - assert len(itensor.device) == len(otensor.device), "backward device not match" - badapter = IRAdapter(grad.ptensors, grad.ctensors) - badapter.prims = bprims - IRAdapter.make_pair(fadapter, badapter) - - return fadapter - - @staticmethod - def gen_cross_shard(ftensor: IRFullTensor, ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> Optional[IRAdapter]: - pass - - @staticmethod - def gen_general(ftensor: IRFullTensor) -> IRAdapter: - fprims = [] - for ctensor in ftensor.ctensors: - fprims += IRAdapterGener.gen_subtensor(ctensor, ftensor.ptensors) - fadapter = IRAdapter(ftensor.ptensors, ftensor.ctensors) - fadapter.prims = fprims - if ftensor.grad is not None: - bprims = [] - for cgrad in ftensor.grad.ctensors: - bprims += IRAdapterGener.gen_subtensor(cgrad, ftensor.grad.ptensors) - badapter = IRAdapter(ftensor.grad.ptensors, ftensor.grad.ctensors) - badapter.prims = bprims - IRAdapter.make_pair(fadapter, badapter) - return fadapter - - @staticmethod - def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRAdapterPrim]: - """ - Generate communiction primitives for ctensor - - @param ctensor IRSubTensor: the consumed tensor as destination - @param ptensors List[IRSubTensor]: the produced tensors as source - - @return prims List[IRAdapterPrim]: the primitives for adapter - """ - # category to local tensor and remote tensor - local = [t for t in ptensors if t.device == ctensor.device] - remote = [t for t in ptensors if t.device != ctensor.device] - prims = [] - - # ==== select ==== # - intersections = [] - # check local - for itensor in local+remote: - if itensor.device == ctensor.device and itensor == ctensor: - return [] - common: Optional[IRSubTensor] = itensor.common(ctensor) - if common is None: - continue - common.cell = itensor.cell - intersections.append(common) - # create select primitive - if common != itensor: - indmap = [] - for dim in range(itensor.ndims): - (s1, e1), (s2, e2) = itensor.indmap[dim], common.indmap[dim] - start = s2 - s1 - end = start + e2 - s2 - indmap.append((start, end)) - indmap = IndexMap(tuple(indmap)) - assert itensor.valmap == common.valmap, "Value map not same" - valmap = ValueMap((0, 1)) - select_prim = SelectPrim(itensor, indmap, valmap, common) - prims.append(select_prim) - if itensor.device == ctensor.device and common == ctensor: - return [select_prim] - # TODO: check union == subtensor - if common == ctensor: - break - - # print(intersections) - # ====== move ===== # - tmoved = [] - for tensor in intersections: - assert len(tensor.device) == 1 and len(ctensor.device) == 1, "Expected only one device." - mtensor = tensor - if tensor.device != ctensor.device: - mtensor = copy.copy(tensor) - mtensor.cell = ctensor.cell - prims.append(MovePrim(tensor, mtensor)) - tmoved.append(mtensor) - - # ===== merge ===== # - remain_tensors: List[IRSubTensor] = copy.copy(tmoved) - if ctensor in remain_tensors: - return prims - out = None - while out != ctensor: - out, merged = None, False - for idx1 in range(len(remain_tensors) - 1): - for idx2 in range(idx1+1, len(remain_tensors)): - t1, t2 = remain_tensors[idx1], remain_tensors[idx2] - catdim = t1.catdim(t2) - if catdim is not None: - tensors = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] - out = tensors[0].concat(tensors[1], dim=catdim) - out.cell = ctensor.cell - prims.append(MergeDimPrim(tensors, out, catdim)) - merged = True - break - # reduction - if t1.accumable(t2): - out = t1.accum(t2) - out.cell = ctensor.cell - prims.append(SumPrim([t1, t2], out)) - merged = True - break - if merged: - remain_tensors.remove(t1) - remain_tensors.remove(t2) - remain_tensors.append(out) - break - if out is None: - ptensors = '\n\t'.join(t.extra_repr() for t in ptensors) - remain = '\n\t'.join(t.extra_repr() for t in remain_tensors) - print(remain_tensors[0].extra_repr()) - print(remain_tensors[1].extra_repr()) - print('cadim:', remain_tensors[0].catdim(remain_tensors[1])) - raise RuntimeError( - f"Fail to build adapter.\n" - f"FullTensor:{ctensor.parent}\n" - f"Producers:\n\t{ptensors}\n" - f"SubTensor:\n\t{ctensor.extra_repr()}\n" - f"Remain Tensor:\n\t{remain}" - ) - return prims - - @staticmethod - def local_producer_fusion(graph: IRGraph, ftensor: IRFullTensor) -> IRFullTensor: + def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTensor: """! Fuse the producer tensors using concat and add. This will add a new full tensor by chaging from: @@ -410,7 +280,7 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens fuse_tensors: Dict[int, Dict[IRSubTensor, List[IRSubTensor]]] = dict() tensor_map: Dict[int, Dict[IRSubTensor, IRSubTensor]] = dict() - for tensor in ftensor.ptensors: + for tensor in graph.ptensors(ftensor): assert len(tensor.device) == 1 devid = tensor.device[0] if devid not in devtensors: @@ -470,24 +340,22 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens if len(nodes) == 0: return ftensor # recompute - rcid = set(producer.recompute for producer in ftensor.producers) + rcid = set(producer.recompute for producer in graph.producers(ftensor)) rcid = list(rcid)[0] if len(rcid) == 1 else None for node in nodes: node.recompute = rcid new_ftensor = ftensor.like() - # update consumer - min_idx = len(graph.nodes()) - assert len(ftensor.ctensors) == len(ftensor.consumers) - for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): - fidx = graph.detach(consumer) - consumer.set_input( - consumer.inputs().index(ctensor), - new_ftensor.select(ctensor.indmap, ctensor.valmap) - ) - graph.attach(consumer, fidx) - min_idx = min(fidx, min_idx) + # update consumer + min_idx = min(graph.index(consumer) for consumer in graph.consumers(ftensor)) + assert len(graph.ctensors(ftensor)) == len(graph.consumers(ftensor)) + for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): + with graph.update(consumer) as consumer: + consumer.set_input( + consumer.inputs().index(ctensor), + new_ftensor.select(ctensor.indmap, ctensor.valmap) + ) # insert new producer for devid, tensors in fuse_tensors.items(): @@ -504,30 +372,22 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens node.set_output(0, new_tensor) for node in nodes[::-1]: - # print(node) assert node not in graph.nodes() assert len(node.outputs()) == 1 - graph.attach(node, min_idx) - - # insert and update backward node - if graph.train: - # update backward node - for consumer in new_ftensor.consumers: - assert isinstance(consumer.mirror, IRBpOperation) - bidx = graph.detach(consumer.mirror) - consumer.mirror.update() - graph.attach(consumer.mirror, bidx) - # insert backward node - bnodes = [node.gen_backward() for node in nodes] - bidx = min(graph.nodes().index(producer.mirror) for producer in ftensor.producers) - for bnode in bnodes: - bnode.device = bnode.mirror.device - graph.attach(bnode, bidx) + if graph.mirror is not None: + graph.finsert(node, min_idx) + else: + graph.insert(node, min_idx) + + # update backward + if isinstance(ftensor.grad, IRFullTensor): + graph.update_ftensor_bw(new_ftensor) + graph.update_ftensor_bw(ftensor) return new_ftensor @staticmethod - def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): + def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): """! If a device have a same sub-tensor to be consumed multiple times, then create a multiref forward node for it to make @@ -545,7 +405,7 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): """ # collect to consumer tensors of each device devtensors: Dict[int, Dict[IRSubTensor, List[IRCell]]] = dict() - for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): + for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): assert len(ctensor.device) == 1 devid = ctensor.device[0] if devid not in devtensors: @@ -553,6 +413,17 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): if ctensor not in devtensors[devid]: devtensors[devid][ctensor] = [] devtensors[devid][ctensor].append(consumer) + + # restrict each device has same subtensor + nl = '\n' + for devid in devtensors: + assert len(devtensors[devid]) <= 1, ( + "Detect that a full tensor is partitioned differently on a device.\n" + "To achieve this, need manually add multiref operator in model description.\n" + f"Full Tensor: {ftensor}\n" + f"Producers:\n{nl.join(repr(node) for node in graph.producers(ftensor))}\n" + f"Consumers:\n{nl.join(repr(node) for node in graph.consumers(ftensor))}" + ) # add multiref forward node multirefs: Dict[MultiRef, List[IRFwOperation]] = dict() @@ -562,6 +433,7 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): if len(consumers) == 1: continue multiref = MultiRef(None, [ctensor, len(consumers)]) + multiref.infer_shape() multiref.device = devid ftensors = [ctensor.parent.like() for _ in range(len(consumers))] itensors = [ft.select(ctensor.indmap, ctensor.valmap) for ft in ftensors] @@ -569,29 +441,52 @@ def local_consumer_multiref(graph: IRGraph, ftensor: IRFullTensor): multiref.set_output(idx, itensor) # update consumer - min_fidx = len(graph.nodes()) + min_fidx = graph.nnodes for itensor, consumer in zip(itensors, consumers): - fidx = graph.detach(consumer) - idx = consumer.inputs().index(ctensor) - consumer.set_input(idx, itensor) - graph.attach(consumer, fidx) - min_fidx = min(fidx, min_fidx) - + with graph.update(consumer) as consumer: + idx = consumer.inputs().index(ctensor) + consumer.set_input(idx, itensor) + # insert forward multiref - graph.attach(multiref, min_fidx) + min_fidx = min(graph.index(consumer) for consumer in consumers) + if graph.mirror is not None: + graph.finsert(multiref, min_fidx) + else: + graph.insert(multiref, min_fidx) multirefs[multiref] = consumers + + if len(multirefs) > 0 and isinstance(ftensor.grad, IRFullTensor): + graph.update_ftensor_bw(ftensor) - # insert / update backward - if graph.train: - for multiref, consumers in multirefs.items(): - # update consumer backward - for consumer in consumers: - assert isinstance(consumer.mirror, IRBpOperation) - bidx = graph.detach(consumer.mirror) - consumer.mirror.update() - graph.attach(consumer.mirror, bidx) - # insert backward - bnode = multiref.gen_backward() - bnode.device = multiref.device - bidx = max(graph.nodes().index(consumer.mirror) for consumer in consumers) - graph.attach(bnode, bidx+1) + @staticmethod + def fusion(graph: IRSegment) -> IRSegment: + """ + Fuse consecutive adapters into one + """ + fadapters, badapters = [], [] + for adapter in graph.nodes(): + if isinstance(adapter, IRAdapter) and adapter.forward and not adapter.differentiable: + fadapters.append(adapter) + if adapter.mirror is not None: + badapters.insert(0, adapter.mirror) + else: + if len(fadapters) > 1: + # insert fused fadapter + fused_fadapter = IRAdapter.merge(fadapters) + for adapter in fadapters: + idx = graph.remove(adapter) + graph.insert(fused_fadapter, idx) + # insert fused badapter + fused_badapter = IRAdapter.merge(badapters) if len(badapters) > 0 else None + for adapter in badapters: + idx = graph.remove(adapter) + if fused_badapter is not None: + graph.insert(fused_badapter, idx) + IRCell.make_pair(fused_fadapter, fused_badapter) + fadapters, badapters = [], [] + + for segment in graph.nodes(): + if isinstance(segment, IRSegment) and segment.isfw(): + IRAdapterGener.fusion(segment) + + return graph \ No newline at end of file diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 48d8426c..d4efb15d 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -8,155 +8,33 @@ """ from typing import Union, Tuple, List, Optional, Dict -import copy -from cube.graph.function.function import MultiRef from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.ir.adapter import IRAdapter -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from cube.ir.dtype import IRDType, DTypeInferRule + +from cube.graph.function.function import Identity, MultiRef +from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo -class IRSegment(IRCell): - """ - A distributed sub-graph representing a piece of workload in parent IRGraph +class IRGraph(IRSegment): """ + IRGraph. - def __init__(self, nodes: List[IRCell], inputs: List[IRSubTensor], outputs: List[IRSubTensor]): - super().__init__('segment', '', len(inputs), len(outputs), init_outputs=False) - - self._nodes = nodes - self._idevice = [t.device for t in inputs] - self._odevice = [t.device for t in outputs] - - for idx, val in enumerate(inputs): - self.set_input(idx, val) - for idx, val in enumerate(outputs): - self.set_output(idx, val) - # setup device - device = set() - for node in nodes: - device.update(node.device) - self.device = list(device) - # setup whether forward - fnodes = any(isinstance(n, IRFwOperation) for n in nodes) - bnodes = any(isinstance(n, IRBpOperation) for n in nodes) - assert not (fnodes and bnodes), "An IRSegment cannot have both forward nodes and backward nodes" - self._forward = fnodes - - @property - def forward(self) -> bool: - return self._forward - - def nodes(self, idx: Optional[int] = None) -> Union[IRCell, List[IRCell]]: - if isinstance(idx, int): - return self._nodes[idx] - else: - return copy.copy(self._nodes) - - def dispatch(self, devid: int, for_mirror=True) -> Optional[IRCell]: - """ - Instantiate from distributed representation to a - device-specific sub-graph. - - The mirror will also be dispatched if it is not None. - - Return the dispatched segment - """ - if devid not in self.device: - return None - if len(self.device) == 1 and self.device == [devid]: - return self - itensors = [t for t, device in zip(self.inputs(), self._idevice) if devid in device] - otensors = [t for t, device in zip(self.outputs(), self._odevice) if devid in device] - nodes = [n for n in self.nodes() if devid in n.device] - for idx, adapter in enumerate(nodes): - if isinstance(adapter, IRAdapter): - nodes[idx] = adapter.dispatch(devid) - fseg = IRSegment(nodes, itensors, otensors) - fseg._id = self._id - # dispatch for mirror - if for_mirror and isinstance(self.mirror, IRSegment): - bseg = self.mirror.dispatch(devid, for_mirror=False) - IRCell.make_pair(fseg, bseg) - return fseg - - def to_str(self, skip_attr: bool = False) -> str: - name = ('f' if self.forward else 'b') + 'Segment' - inputs = tuple(t for t in self.inputs() if not (t.is_attr() and skip_attr)) - outputs = tuple(t for t in self.outputs() if not (t.is_attr() and skip_attr)) - return f'{name}{self._id}-{self.device}(inputs={inputs}, outputs={outputs})' - - def __repr__(self): - return self.to_str() - - def extra_repr(self) -> str: - dscp = repr(self) - for node in self.nodes(): - dscp += '\n\t' + repr(node) - return dscp - - -class IRGraph(IRCell): - """ - IR Graph. The hyperGraph for representing distributed - graph. + IRGraph is used for reprensting a distributed training iteration. """ - def __init__(self, - nodes: List[IRCell], - inputs: Optional[List[IRTensor]], - outputs: Optional[List[IRTensor]], + def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRTensor], module_name: str): - self._nodes: List[IRCell] = list() - self._attributes = list() - self._full_tensors: Dict[int, IRFullTensor] = dict() - self._train: bool = any( - isinstance(node, IRBpOperation) or - (isinstance(node, IRSegment) and node.forward) or - (isinstance(node, IRAdapter) and node.forward) for node in nodes - ) + super().__init__(nodes, inputs, outputs, module_name) self._sched = None # the schedule strategy - if inputs is None: - inputs = IRGraph.get_inputs(nodes) - if outputs is None: - outputs = IRGraph.get_outputs(nodes) - - super().__init__( - name=module_name, - signature=module_name, - input_length=len(inputs), - output_length=len(outputs) - ) - - for idx, tensor in enumerate(inputs): - self.set_input(idx, tensor) - for idx, tensor in enumerate(outputs): - self.set_output(idx, tensor) - - # set parameters / buffers and full tensors - for node in nodes: - for tensor in node.inputs() + node.outputs(): - if isinstance(tensor, IRSubTensor): - pid = tensor.parent._id - self._full_tensors[pid] = tensor.parent - if tensor.is_attr(): - self._attributes.append(tensor) - - for ftensor in self._full_tensors.values(): - ftensor.clear_producer_consumer() - - # insert node from nodes - for idx, node in enumerate(nodes): - self.attach(node, idx) - - self.reset_dependency() @property def train(self) -> bool: @@ -165,59 +43,17 @@ def train(self) -> bool: @return train bool: True if backward is required, otherwise False (inference only). """ - return self._train + return self._have_forward and self._have_backward - def reset_dependency(self): - """ - Reset the node dataflow dependency - - Note all the predefined control dependencies will be removed. - """ - for node in self._nodes: - node.clear_predecessor() - node.clear_successor() - # TODO: adapter dependency not set - for ftensor in self._full_tensors.values(): - for ptensor, producer in zip(ftensor.ptensors, ftensor.producers): - for ctensor, consumer in zip(ftensor.ctensors, ftensor.consumers): - if ptensor.overlap(ctensor): - pidx = producer.outputs().index(ptensor) - cidx = consumer.inputs().index(ctensor) - producer.add_successor(pidx, consumer) - consumer.add_predecessor(cidx, producer) - # set mirror as control dependency - if producer.mirror and isinstance(producer, IRFwOperation): - producer.add_successor(-1, producer.mirror) - producer.mirror.add_predecessor(-1, producer) - - def attributes(self) -> Tuple[IRSubTensor]: - """ - Return parameter list - """ - return tuple(self._attributes) - - def full_tensors(self) -> List[IRSubTensor]: - """ - Return full tensor list - """ - return list(self._full_tensors.values()) + # ================ Deep Learning Interfalce ====================== - def nodes(self, index: Optional[int] = None) -> Union[IRCell, List[IRCell]]: + def __call__(self, *args): """ - Get node at position index + Register forward action """ - if isinstance(index, int): - if index >= len(self._nodes): - raise RuntimeError( - f"Get node out of range ({index} >= {len(self._nodes)})" - ) - return self._nodes[index] - elif index is None: - return copy.copy(self._nodes) - else: - raise TypeError("Expected index to be None or int") - - def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: + return self.forward(*args) + + def forward(self, *args: Tuple[IRSubTensor]) -> Union[IRTensor, Tuple[IRTensor]]: """ forward will divide the graph into Actions according to node device assignment @@ -225,238 +61,173 @@ def forward(self, *args) -> Union[IRTensor, Tuple[IRTensor]]: Currently each forward call will result in a new flow even if the input is same - Returns: - IRTensors - """ - from cube.logics.translator import LogicTranslator - return LogicTranslator.forward(self, *args) + @param args Tuple[Any] + + @return outputs Union[IRSubTensor, Tuple[IRSubTensor]] + """ + # align graph with input tensors + itensors: Tuple[IRSubTensor, ...] = self.inputs() + assert len(args) == len(itensors) + for idx, (itensor, arg) in enumerate(zip(itensors, args)): + self.set_input(idx, arg) + for producer in self.producers(itensor.parent): + with self.update(producer): + while itensor in producer.outputs(): + oidx = producer.outputs().index(itensor) + producer.set_output(oidx, arg) + for consumer in self.consumers(itensor.parent): + with self.update(consumer): + while itensor in consumer.inputs(): + iidx = consumer.inputs().index(itensor) + consumer.set_input(iidx, arg) + while itensor in self.outputs(): + oidx = self.outputs().index(itensor) + self.set_output(oidx, arg) + while itensor in self.inputs(): + iidx = self.inputs().index(itensor) + self.set_input(iidx, arg) + + # dtype inference + for node in self._nodes: + itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] + # setup gradient + for itensor in itensors: + if itensor.parent.grad is not None: + itensor.parent.dtype = itensor.dtype + if len(itensors) == 0: continue + odtype = DTypeInferRule.infer(node, [t.dtype for t in itensors]) + assert odtype != IRDType.unknown, f"{node} : {[t.dtype for t in itensors]}" + otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + for tensor in otensors: + tensor.dtype = odtype + # setup graidient + if tensor.parent.grad is not None: + tensor.parent.grad.dtype = odtype + + from cube.program import Program + Program().add_nodes(self.nodes()) + + # return + if len(self.outputs()) == 1: + return self.output(0) + else: + return self.outputs() - def __call__(self, *args): + def backward(self, loss: IRSubTensor): """ - Register forward action + Backward the graph from the entry tensor of loss. + + @param loss IRSubTensor: the loss tensor, must be in the output + of current graph. The loss shape should be (1,) + + @return self IRGraph: None """ - return self.forward(*args) + assert loss in self.outputs() and tuple(loss.shape) == (1,), \ + f"backward should be in graph outputs and the loss is of shape [1,] (got {loss.shape})" + from cube.program import Program + loss.parent.grad = 1.0 + for fnode in self.nodes()[::-1]: + assert not isinstance(fnode, IRSegment), "Internal Error: Segment should not appear for now" + if isinstance(fnode, IRFwOperation): + bnode: IRBpOperation = self.create_bwop(fnode) + Program().add_node(bnode) + # set program graph mirror to self + Program().mirror_as_self() + return self - def segment(self, nodes: List[IRCell]) -> IRSegment: - """! - Create a segment (sub-graph) with part of the nodes. - Nodes are allowed to be on different devices. - The grouped segement will not add into graph.nodes(). - @param nodes List[IRCell]: the subset nodes of this graph + # ========================= Graph Manipulation ======================== - @return segment IRSegment: the grouped segment. - """ - inputs, outputs = [], [] - itdevs, otdevs = dict(), dict() - for node in nodes: - assert not isinstance(node, IRSegment), 'A segment cannot be in other segments' - # update inputs - itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] - for itensor in itensors: - producers = [p for p in itensor.parent.producers if set(p.device).issubset(set(node.device))] - # no producer means a weight or cross device-group - if len(producers) == 0 or any(p not in nodes for p in producers): - if itensor not in itdevs: - itdevs[itensor] = [] - devs = set(itensor.device) - if devs not in itdevs[itensor]: - inputs.append(itensor) - itdevs[itensor].append(devs) - # update outputs - otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] - for otensor in otensors: - consumers = [c for c in otensor.parent.consumers if set(c.device).issubset(set(node.device))] - # no consumer usually means the loss or cross device-group - if otensor in self.outputs() or len(consumers) == 0 or any(c not in nodes for c in consumers): - devs = set(otensor.device) - if otensor not in otdevs: - otdevs[otensor] = [] - if devs not in otdevs[otensor]: - outputs.append(otensor) - otdevs[otensor].append(devs) - segment = IRSegment(nodes, inputs, outputs) - return segment - - def group(self, nodes: List[IRCell]) -> IRSegment: + def group(self, fnodes: List[IRCell]) -> IRSegment: """! - Group consecutive nodes into IRSegment. the grouped segment will - replace the nodes in the graph. - - Note: Currently this interface will break the dependency, - it can only be used after user policy + Group consecutive forward nodes into IRSegment. + TODO: update operator dependency + + The corresponding backward nodes will also be grouped. @param nodes List[IRCell]: the consecutive node subset of this graph @return segment IRSegment: the grouped segment """ - allnodes = self.nodes() - indices = [allnodes.index(n) for n in nodes] - minidx, maxidx = min(indices), max(indices) - assert maxidx - minidx + 1 == len(nodes), "nodes are not consecutive" - segment = self.segment(nodes) - self._nodes = allnodes[:minidx] + [segment] + allnodes[maxidx+1:] - # FIXME: set segment dependnecy - return segment - - def detach(self, node: IRCell, reset_dependency=False) -> int: - """ - Detach (remove) a node from current graph. - - All the used input and output tensors inside the node - are removed from consumed and produced tensor list. - - Return: - index (int): index of the detached node in the graph - """ - if node not in self.nodes(): - raise KeyError(f"node {node} is not in graph.") - index = self._nodes.index(node) - self._nodes.pop(index) - if isinstance(node, IRAdapter): - return index - # update consumer - itensors: List[IRSubTensor] = [] - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor) and itensor not in itensors: - itensors.append(itensor) - for itensor in itensors: - itensor.parent.rm_consumer(node) - # update producer - otensors: List[IRSubTensor] = [] - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor) and otensor not in otensors: - otensors.append(otensor) - for otensor in otensors: - otensor.parent.rm_producer(node) - ftensor = otensor.parent - if len(ftensor.producers) == 0 and len(ftensor.consumers) == 0: - del self._full_tensors[otensor.parent.tid] - if reset_dependency: - self.reset_dependency() - return index - - def attach(self, node: IRCell, index, reset_dependency=False): - """ - Attach (insert) a node into current graph at node index. + assert any(not isinstance(node, (IRBpOperation, IRDataOperation)) for node in fnodes), \ + "grouped nodes cannot be backward operation, segment or data operation" + + fgraphs = [self.segment(fnode) for fnode in fnodes] + assert len(set(fgraphs)) == 1, "Cross-segment grouping is not allowed yet." + + # get backward nodes + bnodes = [fnode.mirror for fnode in fnodes[::-1] if fnode.mirror is not None] + + fgraph: IRSegment = fgraphs[0] + bgraph: IRSegment = fgraph.mirror - All the used input and output tensors inside the node are - recorded in consumed and produced tensor list. Adapter node - will not record the consumer and producer. - """ - if node in self.nodes(): - raise KeyError(f"node {node} is already in graph.") - self._nodes.insert(index, node) - if isinstance(node, IRAdapter): - return - # update consumer - itensors: List[IRSubTensor] = [] - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor) and itensor not in itensors: - itensors.append(itensor) - for itensor in itensors: - if itensor.parent.tid not in self._full_tensors: - self._full_tensors[itensor.parent.tid] = itensor.parent - idx = 0 - for consumer in itensor.parent.consumers: - if self.nodes().index(consumer) < index: - idx += 1 - else: - break - itensor.parent.add_consumer(node, itensor, idx) - # update producer - otensors: List[IRSubTensor] = [] - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor) and otensor not in otensors: - otensors.append(otensor) - for otensor in otensors: - if otensor.parent.tid not in self._full_tensors: - self._full_tensors[otensor.parent.tid] = otensor.parent - idx = 0 - for producer in otensor.parent.producers: - if self.nodes().index(producer) < index: - idx += 1 - else: - break - otensor.parent.add_producer(node, otensor, idx) - if reset_dependency: - self.reset_dependency() - return + findices: Tuple[int] = tuple(fgraph.index(fnode)[0] for fnode in fnodes) + bindices: Tuple[int] = tuple(bgraph.index(bnode)[0] for bnode in bnodes) - def flatten(self) -> List[IRCell]: - """ - Flattent the graph by expanding nodes - """ - nodes = [] - for node in self.nodes(): - if isinstance(node, IRSegment): - nodes += node.nodes() - else: - nodes.append(node) - return nodes + minfidx, maxfidx = min(findices), max(findices) + assert maxfidx - minfidx + 1 == len(fnodes), \ + "Forward nodes are not consecutive" - @staticmethod - def get_inputs(nodes: List[IRCell]): - """ - Get all the input tensors the is not generated by nodes + if len(bnodes) > 0: + minbidx, maxbidx = min(bindices), max(bindices) + assert maxbidx - minbidx + 1 == len(bnodes), \ + f"Internal Error: backward nodes are not consecutive. maxbidx: {maxbidx}, minbidx: {minbidx}" - Inputs + fsegment = fgraph.create_segment(fnodes) + bsegment = bgraph.create_segment(bnodes) if len(bnodes) > 0 else None + IRCell.make_pair(fsegment, bsegment) - Returns: - List[IRTensor] - """ - all_outputs = list() - for node in nodes: - all_outputs.extend(node.outputs()) - inputs = list() - for cell in nodes: - for input in cell.inputs(): - if isinstance(input, IRTensor): - if input not in all_outputs: - if input not in inputs: - inputs.append(input) - return inputs + # replace forward + for fnode in fnodes: + fidx = fgraph.remove(fnode) + fgraph.insert(fsegment, fidx) - @staticmethod - def get_outputs(nodes: List[IRCell]): - """ - Get all the output tensors the is not used by nodes + # replace backward + if len(bnodes) > 0: + for bnode in bnodes: + bidx = bgraph.remove(bnode) + bgraph.insert(bsegment, bidx) + # setup gradient + self.update_bwop(bsegment) - Args: - This will also consider the successor forward nodes. - If it is required by other outside forward nodes, - put in the outputs list + return fsegment - Returns: - List[IRTensor] - """ - all_inputs = list() - for node in nodes: - all_inputs.extend(node.inputs()) - outputs = list() - for node in nodes: - for idx, output in enumerate(node.outputs()): - # not consumed tensor - if isinstance(output, IRSubTensor): - if output not in all_inputs: - if output not in outputs: - outputs.append(output) - continue - # consumed by other nodes - succs = node.successors(idx) - fsuccs = [ - fnode for fnode in succs if isinstance(fnode, IRFwOperation) - ] - for fsucc in fsuccs: - if fsucc not in nodes: - if output not in outputs: - outputs.append(output) - return outputs + # ========================== Graph Creation ======================== @staticmethod def from_logic_graph(nodes: List[IRCell], inputs: List[IRFullTensor], outputs: List[IRFullTensor], module_name: str): + """ + Generate IRGraph from logical graph (IRFullTensor) + + Multiref will be inserted: + + e.g., original graph: + ``` + t = producer(xx) + ... + xx = consumer1(t) + ... + xx = consumer2(t) + ... + xx = consumer3(t) + ... + ``` + will be changed into: + ``` + t = producer(xx) + ... + t1, t2 = multiref(t) + xx = consumer1(t1) + ... + t3, t4 = multiref(t2) + xx = consumer2(t3) + ... + xx = consumer3(t4) + ... + ``` + """ # handle multi-consumed tensor consumers: Dict[IRFullTensor, List[IRCell]] = dict() producers: Dict[IRFullTensor, IRCell] = dict() @@ -464,8 +235,8 @@ def from_logic_graph(nodes: List[IRCell], ftensors = set() for ftensor in node.inputs(): # remove redundant tensors within an operator - if isinstance(ftensor, IRFullTensor) and ftensor._id not in ftensors: - ftensors.add(ftensor._id) + if isinstance(ftensor, IRFullTensor) and ftensor.tid not in ftensors: + ftensors.add(ftensor.tid) if ftensor not in consumers: consumers[ftensor] = [] consumers[ftensor].append(node) @@ -474,20 +245,57 @@ def from_logic_graph(nodes: List[IRCell], producers[ftensor] = node for ftensor, cnodes in consumers.items(): if len(cnodes) == 1 or ftensor.is_attr(): continue - itensors = [ftensor.like() for _ in range(len(cnodes))] - for itensor, consumer in zip(itensors, cnodes): + reftensor = ftensor + ctensor = ftensor + while len(cnodes) > 0: + consumer = cnodes.pop(0) + if len(cnodes) > 0: + itensors = [ftensor.like() for _ in range(2)] + multiref = MultiRef(None, [reftensor, 2]) + for idx, itensor in enumerate(itensors): + multiref.set_output(idx, itensor) + multiref.infer_shape() + # insert multiref right before the consumor + idx = nodes.index(consumer) + nodes.insert(idx, multiref) + ctensor, reftensor = itensors + else: + # the last consumer doesn't need multiref + ctensor = reftensor + # update consumer while ftensor in consumer.inputs(): idx = consumer.inputs().index(ftensor) - consumer.set_input(idx, itensor) - # create and insert multiref operation - multiref = MultiRef(None, [ftensor, len(cnodes)]) - for idx, itensor in enumerate(itensors): - multiref.set_output(idx, itensor) - multiref.infer_shape() - idx = nodes.index(producers[ftensor]) + 1 if ftensor in producers else 0 - # idx = nodes.index(cnodes[0]) - nodes.insert(idx, multiref) - + consumer.set_input(idx, ctensor) + + # another version to generate multiref: one for all + # for node in nodes: + # ftensors = set() + # for ftensor in node.inputs(): + # # remove redundant tensors within an operator + # if isinstance(ftensor, IRFullTensor) and ftensor._id not in ftensors: + # ftensors.add(ftensor._id) + # if ftensor not in consumers: + # consumers[ftensor] = [] + # consumers[ftensor].append(node) + # for ftensor in node.outputs(): + # if isinstance(ftensor, IRFullTensor): + # producers[ftensor] = node + # for ftensor, cnodes in consumers.items(): + # if len(cnodes) == 1 or ftensor.is_attr(): continue + # itensors = [ftensor.like() for _ in range(len(cnodes))] + # for itensor, consumer in zip(itensors, cnodes): + # while ftensor in consumer.inputs(): + # idx = consumer.inputs().index(ftensor) + # consumer.set_input(idx, itensor) + # # create and insert multiref operation + # multiref = MultiRef(None, [ftensor, len(cnodes)]) + # for idx, itensor in enumerate(itensors): + # multiref.set_output(idx, itensor) + # multiref.infer_shape() + # idx = nodes.index(producers[ftensor]) + 1 if ftensor in producers else 0 + # # idx = nodes.index(cnodes[0]) + # nodes.insert(idx, multiref) + # instantiate graph inputs / outputs for idx, tensor in enumerate(inputs): if isinstance(tensor, IRFullTensor): @@ -530,23 +338,24 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") - if node not in self.nodes(): - raise RuntimeError(f"Op {node} not exsits") - - fidx = self.detach(node) + fsegment: IRSegment = self.segment(node) + # replicate fnodes = [node.replicate() for _ in range(times)] # insert forward - for idx, fnode in enumerate(fnodes): - self.attach(fnode, fidx + idx) + for fnode in fnodes: + if isinstance(node, IRFwOperation): + fnode.recompute = node.recompute + if isinstance(node.comment, str): + fnode.comment = node.comment + fnode.device = node.device + fsegment.replace(node, fnodes) # insert backward + bsegment: IRSegment = fsegment.mirror if isinstance(node.mirror, IRBpOperation): - bidx = self.detach(node.mirror) - for fnode in fnodes: - fnode.gen_backward() - bnodes = [fnode.mirror for fnode in fnodes][::-1] - for idx, bnode in enumerate(bnodes): - self.attach(bnode, bidx + idx) - #TODO: dependency set + bnodes = tuple(self.create_bwop(fnode) for fnode in fnodes[::-1]) + for bnode in bnodes: + bnode.device = node.device + bsegment.replace(node.mirror, bnodes) return fnodes def partition(self, node: Union[IRFwOperation, IRDataOperation], @@ -577,95 +386,65 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], """ assert isinstance(algo, GenericDistAlgo) and node == algo.node, \ "The partition algorithm is not initialized for this node" - if node not in self.nodes(): - raise RuntimeError(f"Not Exist: {node}") - if not (isinstance(node, IRFwOperation) or isinstance(node, IRDataOperation)): - raise ValueError("Only allow op to be forward op or data op.") - + assert isinstance(node, (IRFwOperation, IRDataOperation)), \ + f"Only allow op to be forward op or data op, but got: {node}" + + fsegment: IRSegment = self.segment(node) # get partitioned sub-nodes fnodes = algo.instantiate(**config) - if fnodes is None: return fnodes - + assert fnodes is not None, f"Fail to partition node: {node} use algothim and config: {config}" # update forward - findex = self.detach(node) - for idx, fnode in enumerate(fnodes): - self.attach(fnode, findex + idx) - if isinstance(node.comment, str): - fnode.comment = node.comment + for fnode in fnodes: if isinstance(node, IRFwOperation): fnode.recompute = node.recompute + if isinstance(node.comment, str): + fnode.comment = node.comment + fnode.device = node.device + fsegment.replace(node, fnodes) # update backward + bsegment: IRSegment = fsegment.mirror if isinstance(node.mirror, IRBpOperation): - bindex = self.detach(node.mirror) - bnodes = [fnode.gen_backward() for fnode in fnodes][::-1] - for idx, bnode in enumerate(bnodes): - self.attach(bnode, bindex + idx) - if isinstance(node.mirror.comment, str): - bnode.comment = node.mirror.comment + bnodes = tuple(self.create_bwop(fnode) for fnode in fnodes[::-1]) + bsegment.replace(node.mirror, bnodes) + for bnode in bnodes: + bnode.device = node.device # update gradient updated = set() for itensor in [t for t in node.inputs() if isinstance(t, IRSubTensor)]: - for fnode in itensor.parent.consumers: - bnode = fnode.mirror - if isinstance(bnode, IRBpOperation) and fnode._id not in updated: - idx = self.detach(bnode) - bnode.update() - self.attach(bnode, idx) - updated.add(fnode._id) - # update device - for fnode in fnodes: - fnode.device = node.device - if isinstance(fnode.mirror, IRCell): - fnode.mirror.device = node.device + for fnode in fsegment.consumers(itensor.parent): + bnode: IRBpOperation = fnode.mirror + if isinstance(bnode, IRBpOperation) and fnode.cid not in updated: + self.update_bwop(bnode) + updated.add(fnode.cid) return fnodes - def replace(self, old_nodes: List[IRCell], new_nodes: List[IRCell]): - """! - Replace nodes with node. - - Note we don't check semantic correctness for the replacement. - - @param old_nodes List[IRCell]: nodes to be replaced - @param new_nodes List[IRCell]: nodes to replace in - - @return True - """ - idx = len(self._nodes) - for old_node in old_nodes: - oidx = self.detach(old_node) - idx = min(oidx, idx) - for new_node in new_nodes[::-1]: - self.attach(new_node, idx) - return True - ## Spatial Primitives ## - def assign(self, node: Union[IRFwOperation, IRBpOperation], - ranks: Union[int, Tuple[int]]): + def assign(self, node: Union[IRFwOperation, IRDataOperation], device: int) -> bool: """ Assign an operator (subgraph) to (multiple) rank(s). - If `ranks` has multiple integer, then the operator will be replicated - `len(ranks)` times and assigned to given device correspondingly. - - Corresponding backward operators (if have) will also be replicated - and assigned to the same device with it's forward operator + Corresponding backward operators (if have) will also be + assigned to the same device. - @param node Union[IRFwOperation, IRBpOperation]: operator - @param ranks Tuple[int, Tuple[int]]: assigned ranks + @param node Union[IRFwOperation, IRBpOperation, IRSegment]: operator + @param device int: assigned device id @return sucess bool: always true """ - assert isinstance(node, (IRFwOperation, IRDataOperation)), f"Only forward and data operation can be assigned to device, but got {node}" - assert node in self._nodes, f"{node} is not in the graph" - ranks = (ranks,) if isinstance(ranks, int) else ranks - assert all([isinstance(rank, int) for rank in ranks]), "Expected rank to be int" - nodes = [node] if len(ranks) == 1 else self.replicate(node, times=len(ranks)) - for node, rank in zip(nodes, ranks): - node.device = rank - if isinstance(node.mirror, IRBpOperation): - bnode = node.mirror - bnode.device = rank + assert self.exist(node), f"{node} is not in the graph" + if isinstance(node, IRSegment): + assert node.forward, "Only forward segment is allowed to assign devices" + for subnode in node.nodes(): + subnode.device = device + if subnode.mirror is not None: + subnode.mirror.device = device + else: + assert isinstance(node, (IRFwOperation, IRDataOperation)), \ + "Only forward operators and dataloader operators are allowed to assign devices" + node.device = device + if node.mirror is not None: + node.mirror.device = device return True ## Schedule Policy Primitives ## @@ -677,6 +456,7 @@ def happen_before(self, node1: IRCell, node2: IRCell, skip=None) -> bool: Returns: Boolean """ + raise NotImplementedError("dependency is not supported yet") skip = list() if skip is None else skip if node1 in skip: return False @@ -732,8 +512,8 @@ def schedule(self, node1: IRCell, action: str, node2: IRCell) -> bool: for idx in range(idx1+1, idx2+1): if self.depends(node1, self._nodes[idx]): return False - self.detach(node1) - self.attach(node1, idx2) + self.remove(node1) + self.insert(node1, idx2) return True # node1 -> node2 if action == 'before': @@ -742,8 +522,8 @@ def schedule(self, node1: IRCell, action: str, node2: IRCell) -> bool: for idx in range(idx2, idx1): if self.depends(self._nodes[idx], node1): return False - self.detach(node1) - self.attach(node1, idx2) + self.remove(node1) + self.insert(node1, idx2) return True raise KeyError(f"Unknown scheduling action {action}") @@ -807,10 +587,141 @@ def add_schedule(self, nodes: List[IRCell]) -> bool: post.add_predecessor(input_index=-1, cell=prev) return True + # ================= staging primitives ================== + + def staging(self, nodes: Tuple[IRFwOperation]): + """! + Group forward / dataloader operators into sequential stages. + The corresponding backward operators will also be grouped into stages + Cross-stage dataflow will be limited to neighbor stages. + This should be called before any operator partition. + + The transformation and temporal scheduling can only be applied within each stage. + For example, after staging, user cannot schedule a (transformed) node + from one stage to another stage. + + The stage is a concept that is only about logical separation of nodes, + it doesn't have additional constraints for device assignment. + + Changes will be made: + + 1). Identity creation: + If a non-attribute tensor is produced / consumed not in + neighbor stages, + e.g., + stage 1: t1 = producer() + stage 2: ... + stage 3: xx = consume(t1) + stage 4: ... + stage 5: xx = consume(t1) + then Identity nodes will be created for every device in stage2: + stage 1: t1 = producer() + stage 2: t2 = identity(t1) + stage 3: xx = consume(t2) + stage 4: t3 = identity(t2) + stage 5: xx = consume(t3) + + 2). REMOVED: Multiref Modification: + If a non-attribute tensor has multiref node to different devmeshes, + e.g., + stage 1: t1, t2 = multiref(t) + stage 2: xx = consume(t1) + stage 3: ... + stage 4: xx = consume(t2) + then the multiref will be transfered into identity operator: + stage 1: t1 = multiref(t) + stage 2: xx = consume(t1) + t2 = identity(t1) + stage 3: t3 = identity(t2) + stage 4: xx = consume(t3) + + @param starts Tuple[int]: the start index of each stage + @return None + """ + assert all(isinstance(node, IRFwOperation) for node in nodes), \ + f"Find node is not IRFwOperation or IRDataOperation: {node}" + assert all(node in self._nodes for node in nodes), \ + f"Exist node is not in graph nodes" + starts = tuple(self._nodes.index(node) for node in nodes) + assert len(starts) > 0 + starts = (0,) + starts if starts[0] != 0 else starts + + last_fidx = 0 + for idx, node in enumerate(self._nodes): + if not isinstance(node, IRBpOperation): + last_fidx = idx + + fstages: List[List[IRCell]] = [] + bstages: List[List[IRCell]] = [] + for sid in range(len(starts)): + begin = starts[sid] + end = starts[sid+1] if sid != len(starts) - 1 else last_fidx + 1 + while isinstance(self.node(begin), IRDataOperation): + begin += 1 + while isinstance(self.node(end), IRDataOperation): + end -= 1 + if begin == end: continue + assert begin < end + fnodes = self._nodes[begin:end] + bnodes = [fnode.mirror for fnode in fnodes[::-1] if fnode.mirror is not None] + fstages.append(fnodes) + bstages = [bnodes] + bstages + + def get_sid(fnode: IRCell) -> Optional[int]: + for idx, fnodes in enumerate(fstages): + if fnode in fnodes: + return idx + return None + + def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: + identity = Identity('', [tensor]) + identity.infer_shape() + identity.set_output(0, identity.output(0).tosub()) + # insert forward + fidx = self.index(fstages[sid][0]) + if tensor.requires_grad: + self.finsert(identity, fidx) + bstages[sid].append(identity.mirror) + else: + self.insert(identity, fidx) + fstages[sid].insert(0, identity) + return identity + + # create identity op for cross-stage dataflow + # the gradient flow of neighbor stages is automatically guaranteed + for ftensor in self.full_tensors(): + if ftensor.is_grad() or ftensor.is_attr(): continue + assert len(self.producers(ftensor)) <= 1, \ + "The staging interface should be called before any operator partition." + if len(self.consumers(ftensor)) == 0: continue + producer, ptensor = self.producers(ftensor)[0], self.ptensors(ftensor)[0] + psid = get_sid(producer) + # outside of stages, not consider + if psid is None: continue + out = ptensor + curr_sid = psid + for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): + assert ctensor == ptensor, "The staging interface should be called before any operator partition." + csid = get_sid(consumer) + if curr_sid == csid: continue + for sid in range(curr_sid + 1, csid): + identity = insert_identity(out, sid) + out = identity.output(0) + # update consumer + with self.update(consumer) as consumer: + tidx = consumer.inputs().index(ptensor) + consumer.set_input(tidx, out) + curr_sid = csid + # update all its backward operators + self.update_ftensor_bw(ftensor.grad) + # grouping into segment + for sid in range(len(fstages)): + self.group(fstages[sid]) + # ================= Other optimizations ================== - def recompute(self, nodes: List[IRFwOperation]) -> bool: + def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: """! Recompute a set of nodes. The forward nodes will be assigned with a unique recompute group id. A forward not can not be recomputed in different recompute groups. @@ -819,34 +730,85 @@ def recompute(self, nodes: List[IRFwOperation]) -> bool: @return success boolean: always success """ - assert all(isinstance(fnode, IRFwOperation) for fnode in nodes), "require forward operations" - recompute_group_id = IDGenerator().gen_cell_id() - for fnode in nodes: - fnode.recompute = recompute_group_id - return True - - def __repr__(self): - dscp = f"Graph{self._id}-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})" - return dscp - - def extra_repr(self): - dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" - # inputs - dscp += f"Inputs: {self.inputs()}\n" - # nodes - for node in self._nodes: - # succ_node_ids = [node._id for node in node.successors()] - # succ_node_ids = [None] * len(node.outputs()) - # for out_idx in range(len(node.outputs())): - # node_list = [snode._id for snode in node.successors(out_idx)] - # succ_node_ids[out_idx] = node_list - # dscp += f"\n{node._id}: {node} -> node id {succ_node_ids}" - dscp += f"\n{node}" - # outputs - dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" - return dscp - - def module_repr(self): - return repr(self) + assert all(isinstance(node, IRFwOperation) for node in nodes) or isinstance(nodes, IRSegment), \ + "Require forward nodes or a single segment" + if isinstance(nodes, IRSegment): + assert nodes.isfw() and (not nodes.isbw()), "Only forward IRSegment can recompute" + return self.recompute(nodes.nodes()) + + else: + segments = [self.segment(node) for node in nodes] + assert all(segment == segments[0] for segment in segments), \ + "Cross-segment recompute is not allowed yet" + recompute_group_id: int = IDGenerator().gen_cell_id() + for fnode in nodes: + fnode.recompute = recompute_group_id + + return True + # =================== Helpers ==================== + + def auto_multiref(self): + """ + Automatically partition and schedule multiref node. + This requires to call after all transformation and + scheduling. + + The policy is to partition and assign multiref + in the same way of its input producer + """ + for node in self.nodes(flatten=True): + if node.name == 'multiref': + if len(node.device) != 0: continue + segment: IRSegment = self.segment(node) + ftensor = node.input(0).parent + ptensors = segment.ptensors(ftensor) + + multirefs = [] + + # use downstream consumers + devtensors: Dict[int, List[IRSubTensor]] = dict() + for tensor in node.outputs(): + for ctensor in segment.ctensors(tensor.parent): + for devid in ctensor.device: + if devid not in devtensors: + devtensors[devid] = [] + devtensors[devid].append(ctensor) + devids = list(devtensors.keys()) + ctensors = [ts[0] for ts in devtensors.values()] + for devid, ctensor in zip(devids, ctensors): + itensor = node.input(0).parent.select(ctensor.indmap, ctensor.valmap) + otensors = [] + for otensor in node.outputs(): + otensors.append(otensor.parent.select(ctensor.indmap, ctensor.valmap)) + multiref = MultiRef('', [itensor, len(otensors)]) + for idx, otensor in enumerate(otensors): + multiref.set_output(idx, otensor) + multiref.device = devid + multirefs.append(multiref) + + # if no downstream consumers, use upstream producers + if len(multirefs) == 0: + for ptensor in ptensors: + assert len(ptensor.device) > 0, \ + "Auto Multiref requires its producer nodes assigned to devices" + for devid in ptensor.device: + outputs = [] + for output in node.outputs(): + outputs.append(output.parent.select(ptensor.indmap, ptensor.valmap)) + multiref = MultiRef('', [ptensor, len(outputs)]) + for idx, otensor in enumerate(outputs): + multiref.set_output(idx, otensor) + multiref.device = devid + multirefs.append(multiref) + + # replace into graph + fidx = self.remove(node) + if node.mirror is not None: + self.remove(node.mirror) + for multiref in multirefs[::-1]: + if node.mirror is not None: + self.finsert(multiref, fidx) + else: + self.insert(multiref, fidx) diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index cd7b3a25..5d2c539f 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,3 +1,3 @@ from cube.graph.parser.parser import ScriptModuleParser -from cube.graph.parser.converter import convert_model, convert_dataloader +from cube.graph.parser.converter import convert_model from cube.graph.parser.register import register \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index bde06b93..37f034de 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -3,7 +3,6 @@ from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser from cube.graph import IRGraph -from cube.logics.dataloader import IRDataLoader import torch @@ -25,11 +24,3 @@ def convert_model(model: torch.nn.Module, graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) return graph - -def convert_dataloader(dataloader) -> IRDataLoader: - """ - convert pytorch dataloader into IRDataLoader - """ - from cube.graph.parser.mapping import DType2IRDType - dataloader = IRDataLoader(dataloader, dtype_map=DType2IRDType) - return dataloader diff --git a/cube/graph/segment.py b/cube/graph/segment.py new file mode 100644 index 00000000..848d4251 --- /dev/null +++ b/cube/graph/segment.py @@ -0,0 +1,820 @@ +from contextlib import contextmanager +from typing import Dict, Union, List, Optional, Set, Tuple + +from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.cten import IRTensor, IRCell +from cube.ir.operator import IRFwOperation, IRBpOperation +from cube.ir.adapter import IRAdapter + + +class CellPosition: + + def __init__(self, indices: Tuple[int]): + assert all(isinstance(idx, int) for idx in indices) and len(indices) > 0 + self.indices = tuple(indices) + + def __hash__(self) -> int: + return hash(self.indices) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, CellPosition), "Cannot compare with non-GraphIndex object" + return self.indices == other.indices + + def __lt__(self, other: object) -> bool: + assert isinstance(other, CellPosition), "Cannot compare with non-GraphIndex object" + if len(self.indices) < len(other.indices): + return True + if len(self.indices) > len(other.indices): + return False + for lidx, ridx in zip(self.indices, other.indices): + if lidx >= ridx: + return False + return True + + def __le__(self, other: object) -> bool: + return self < other or self == other + + def __gt__(self, other: object) -> bool: + return not self <= other + + def __ge__(self, other: object) -> bool: + return not self < other + + def __sub__(self, offset: int): + assert isinstance(offset, int) + indices = list(self.indices) + indices[-1] -= offset + return CellPosition(indices) + + def __add__(self, offset: int): + assert isinstance(offset, int) + indices = list(self.indices) + indices[-1] += offset + return CellPosition(indices) + + def __getitem__(self, idx: int) -> int: + return self.indices[idx] + + def __len__(self) -> int: + return len(self.indices) + + def __repr__(self) -> str: + return repr(self.indices) + + +class IRSegment(IRCell): + """ + A distributed sub-graph representing a piece of workload in parent IRGraph + + Once the segment is generated, its input and output will be fixed. + Inserting and removing nodes that could change input/output are not allowed. + """ + + def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRSubTensor], name='segment'): + super().__init__(name, '', len(inputs), len(outputs), init_outputs=False) + + self._nodes: List[IRCell] = [] + self._idevice = [t.device for t in inputs] + self._odevice = [t.device for t in outputs] + + for idx, val in enumerate(inputs): + self.set_input(idx, val) + for idx, val in enumerate(outputs): + self.set_output(idx, val) + + # full-tensor / sub-tensor mapping + self._ftensors: Set[IRFullTensor] = set() + self._producers: Dict[IRFullTensor, List[IRCell]] = dict() + self._consumers: Dict[IRFullTensor, List[IRCell]] = dict() + self._ptensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() + self._ctensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() + + # attributes + self._attributes: Set[IRFullTensor] = set() + + for node in nodes: + self.insert(node, self.nnodes) + + # self.reset_dependency() + + # FIXME: update when manipulating + self._have_forward = any(isinstance(n, IRFwOperation) for n in nodes) + self._have_backward = any(isinstance(n, IRBpOperation) for n in nodes) + + def isfw(self) -> bool: + return self._have_forward + + def isbw(self) -> bool: + return self._have_backward + + def full_tensors(self) -> Tuple[IRFullTensor]: + """ + Get all full tensors of this graph. + Note the full tensor inside the node will not be returned. + + @return ftensors List[IRFullTensor] + """ + return tuple(self._ftensors) + + def attributes(self) -> Tuple[IRFullTensor]: + """ + Get al full tensor attributes of this graph + Note the full tensor inside the node will not be returned. + + @return ftensors List[IRFullTensor] + """ + return Tuple(self._attributes) + + def reset_dependency(self): + """ + Reset the node dataflow dependency + + Note all the predefined control dependencies will be removed. + TODO: adapter dependency is not set + """ + for node in self._nodes: + node.clear_predecessor() + node.clear_successor() + # TODO: adapter dependency not set + for ftensor in self._ftensors: + for ptensor, producer in zip(self.ptensors(ftensor), self.producers(ftensor)): + for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): + if ptensor.overlap(ctensor): + pidx = producer.outputs().index(ptensor) + cidx = consumer.inputs().index(ctensor) + producer.add_successor(pidx, consumer) + consumer.add_predecessor(cidx, producer) + # set mirror as control dependency + if producer.mirror is not None and isinstance(producer, IRFwOperation): + producer.add_successor(-1, producer.mirror) + producer.mirror.add_predecessor(-1, producer) + # sub segments + for segment in self._nodes: + if isinstance(segment, IRSegment): + segment.reset_dependency() + + # ========================= Basic Graph access ======================= + + @property + def device(self) -> List[int]: + devices = set() + for node in self._nodes: + devices.update(node.device) + devices = list(devices) + devices.sort() + return devices + + @property + def nnodes(self) -> int: + """ + Get total node number + + @return number int: the number of nodes + """ + return len(self._nodes) + + def nodes(self, flatten = False) -> Tuple[IRCell]: + """ + Get all the nodes. + + @param flatten bool: Flat the segment to get all the nested cells + + @return nodes List[IRCell]: all the nodes + """ + if not flatten: + return tuple(self._nodes) + nodes = [] + for node in self._nodes: + if not isinstance(node, IRSegment): + nodes.append(node) + else: + nodes += list(node.nodes(flatten)) + return tuple(nodes) + + def node(self, index: Union[int, CellPosition]) -> IRCell: + """ + Get node at position index + + @param index Union[int, CellPosition]: the node index + + @return node IRCell: the node. + """ + pos = CellPosition((index,)) if isinstance(index, int) else index + assert isinstance(pos, CellPosition), "Expect index to be int or CellPosition" + node = self + for idx in pos.indices: + assert isinstance(node, IRSegment), "idx applies on a non-segment node" + node = node._nodes[idx] + return node + + def index(self, node: IRCell) -> CellPosition: + """ + Get node index. The dispatched node (e.g., IRAdapter, IRSegment) + will return the index to its un-dispatched node + + @param node IRCell: the queried node + + @return index int: the index + """ + assert isinstance(node, IRCell) + if node in self._nodes: + return CellPosition((self._nodes.index(node),)) + for idx, segment in enumerate(self._nodes): + if isinstance(segment, IRSegment): + if segment.exist(node): + index = segment.index(node) + return CellPosition((idx,) + index.indices) + raise KeyError(f"The queried node: {node} not in the graph") + + def segment(self, node: IRCell) -> IRCell: + """ + Get the lowest segment that constains the node + + @param node IRCell: the queried node + + @return segment IRSegment + """ + assert isinstance(node, IRCell), f"Expected IRCell, but got {node}" + index = self.index(node) + if len(index) == 1: + return self + else: + return self.node(CellPosition(index.indices[:-1])) + + def producers(self, ftensor: IRFullTensor) -> Tuple[IRCell]: + """ + Get producers of ftensor in execution order in this graph + + @param ftensor IRFullTensor: the queried full tensor. + + @return subtensors Tuple[IRSubTensor]: the producers. + """ + assert ftensor in self._producers, f"{ftensor} is not in the graph" + return tuple(self._producers[ftensor]) + + def consumers(self, ftensor: IRFullTensor) -> Tuple[IRCell]: + """ + Get consumers of ftensor in execution order in this graph + + @param ftensor IRFullTensor: the queried full tensor. + + @return subtensors Tuple[IRCell]: theconsumers + """ + assert ftensor in self._consumers, f"{ftensor} is not in the graph" + return tuple(self._consumers[ftensor]) + + def ptensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: + """ + Get consumed sub-tensors of ftensor in execution order in this graph + + @param ftensor IRFullTensor: the queried full tensor. + + @return subtensors Tuple[IRSubTensor]: the consumed subtensors. + """ + assert ftensor in self._ptensors, f"{ftensor} is not in the graph" + return tuple(self._ptensors[ftensor]) + + def ctensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: + """ + Get consumed sub-tensors of ftensor in execution order in this graph + + @param ftensor IRFullTensor: the queried full tensor. + + @return subtensors Tuple[IRSubTensor]: the consumed subtensors. + """ + assert ftensor in self._ctensors, f"{ftensor} is not in the graph" + return tuple(self._ctensors[ftensor]) + + def grad(self, tensor: IRSubTensor, no_partial_overlap=False) -> IRSubTensor: + """ + Get gradient of the tensor. + + @param tensor IRSubTensor: IRSubTensor: the queried tensor + + @return gradient IRSubTensor: the gradient + """ + segment: IRSegment = self.segment(tensor.cell) + assert isinstance(tensor, IRSubTensor), "Only tensor has gradient" + fgrad = tensor.parent.grad + # None means no gradient requirement, flaot means its the loss + if fgrad is None or isinstance(fgrad, float): + tensor.grad = fgrad + return fgrad + ftensor = tensor.parent + # this tensor is consumed + if tensor in tensor.cell.inputs(): + consumer_cids = [] + for ctensor, consumer in zip(segment.ctensors(ftensor), segment.consumers(ftensor)): + if no_partial_overlap: + assert not (ctensor != tensor and ctensor.overlap(tensor)), ( + f"parital overlapping is not supported for gradient\n" + f"{self.debug_tensor_map_str(ctensor.parent)}" + ) + if ctensor == tensor and consumer.cid not in consumer_cids: + consumer_cids.append(consumer.cid) + + valmap = (consumer_cids.index(tensor.cell.cid), len(consumer_cids)) + grad = ftensor.grad.select( + indmap = tensor.indmap, + valmap = valmap + ) + # this tensor is produced + elif tensor in tensor.cell.outputs(): + grad = ftensor.grad.select( + indmap = tensor.indmap, + valmap = (0, 1), + ) + tensor.grad = grad + return grad + + def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: + dscp : str = '' + ftensors = [ftensor] if ftensor is not None else self._ftensors + for ftensor in ftensors: + dscp += f'====\nFull Tensor: {ftensor}\n' + dscp += f'Producers:\n' + for producer in self._producers[ftensor]: + dscp += f'\t{producer}\n' + dscp += f'Consumers:\n' + for consumer in self._consumers[ftensor]: + dscp += f'\t{consumer}\n' + return dscp + + def create_bwop(self, fwop: IRFwOperation) -> IRBpOperation: + """ + Create dummy backward operator for given forward operator + """ + assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" + fsegment: IRSegment = self.segment(fwop) + fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] + fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] + igrads = [fsegment.grad(t) if t.requires_grad else None for t in fins] + ograds = [fsegment.grad(t) if t.requires_grad else None for t in fous] + bwop = IRBpOperation(ograds, igrads) + IRCell.make_pair(fwop, bwop) + return bwop + + def update_bwop(self, bwop: IRCell) -> IRBpOperation: + """ + Update backward operator or a backward segment. + + This is neccessary when fwop is partitioned and reference count is changed. + + @param bwop IRBpOperation or IRSegment: the backward operation. + It can be at any hierarchy of this segemtn + + @return bwop IRBpOperation: the updated operation (inplace) + """ + assert isinstance(bwop, (IRBpOperation, IRSegment)) + if isinstance(bwop, IRSegment): + assert bwop.isbw() and (not bwop.isfw()) + bsegment: IRSegment = self.segment(bwop) + fsegment = bsegment.mirror + with bsegment.update(bwop): + fwop: Union[IRFwOperation, IRSegment] = bwop.mirror + fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] + fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] + igrads = [fsegment.grad(t) if t.requires_grad else None for t in fins] + ograds = [fsegment.grad(t) if t.requires_grad else None for t in fous] + for idx, igrad in enumerate(igrads): + bwop.set_output(idx, igrad) + # Ad-hoc fix: remove float that could be caused by loss for segment + if isinstance(bwop, IRSegment): + ograds = [grad for grad in ograds if isinstance(grad, IRSubTensor)] + for idx, ograd in enumerate(ograds): + bwop.set_input(idx, ograd) + return bwop + + def update_ftensor_bw(self, ftensor: IRFullTensor): + """ + Update all backward operators for a full tensor. + + @param ftensor IRFullTensor: the full tensor. If the full + tensor is not a gradient, will update backward operators + of ftensor.grad + + @return None + """ + fgrad = ftensor.grad if not ftensor.is_grad() else ftensor + if fgrad is None: + return + for producer in self.producers(fgrad): + self.update_bwop(producer) + for consumer in self.consumers(fgrad): + self.update_bwop(consumer) + + # ====================== Basic Graph manipulations ====================== + + def _add_ftensor(self, ftensor: IRFullTensor): + """ + Add a full tensor in segment if the segment doesn't have the tensor. + """ + assert isinstance(ftensor, IRFullTensor) + if ftensor not in self._ftensors: + self._ftensors.add(ftensor) + self._producers[ftensor] = [] + self._consumers[ftensor] = [] + self._ptensors[ftensor] = [] + self._ctensors[ftensor] = [] + if ftensor.is_attr(): + self._attributes.add(ftensor) + + def _remove_ftensor(self, ftensor: IRFullTensor): + """ + Remove a full tensor in segment + """ + assert isinstance(ftensor, IRFullTensor) + if ftensor in self._ftensors: + self._ftensors.remove(ftensor) + del self._producers[ftensor] + del self._consumers[ftensor] + del self._ptensors[ftensor] + del self._ctensors[ftensor] + if ftensor.is_attr() and ftensor in self._attributes: + self._attributes.remove(ftensor) + + def insert(self, node: IRCell, index: Union[int, CellPosition]): + """ + Insert a node at index. + + TODO: dataflow dependency update + TODO: input / output check + + @param node IRCell: the inserted node + @param index int: the index + + """ + pos = CellPosition((index,)) if isinstance(index, int) else index + assert isinstance(pos, CellPosition), "Expect index to be int or CellPosition" + + if len(pos) == 1: + index = pos[0] + # insert node + self._nodes.insert(index, node) + # update producer and consumer + if isinstance(node, IRAdapter): return + # consumer + itensors = set(t for t in node.inputs() if isinstance(t, IRSubTensor)) + for itensor in itensors: + ftensor = itensor.parent + self._add_ftensor(ftensor) + idx = len([c for c in self._consumers[ftensor] if self._nodes.index(c) < index]) + self._consumers[ftensor].insert(idx, node) + self._ctensors[ftensor].insert(idx, itensor) + # producer + otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) + for otensor in otensors: + ftensor = otensor.parent + self._add_ftensor(ftensor) + idx = len([c for c in self._producers[ftensor] if self._nodes.index(c) < index]) + self._producers[ftensor].insert(idx, node) + self._ptensors[ftensor].insert(idx, otensor) + else: + segment = self._nodes[pos[0]] + assert isinstance(segment, IRSegment), "Expected IRSegment" + pos = CellPosition(pos.indices[1:]) + segment.insert(node, pos) + + def remove(self, node: IRCell, _pos: CellPosition = None) -> CellPosition: + """ + Remove a node at index + + # TODO: check input and output + + @param node IRCell: the removed node + + @return index CellPosition: the removed index + """ + pos = self.index(node) if _pos is None else _pos + assert self.node(pos) == node, \ + f"posititon doesn't not match with node:\n\t{node}\nGot:\n\t{self.node(pos)}" + + if len(pos.indices) == 1: + index = pos[0] + # remove + self._nodes.pop(index) + # update producer and consumer + if isinstance(node, IRAdapter): return pos + # consumer + itensors = set(t for t in node.inputs() if isinstance(t, IRSubTensor)) + for itensor in itensors: + ftensor = itensor.parent + idx = self._consumers[ftensor].index(node) + self._consumers[ftensor].pop(idx) + self._ctensors[ftensor].pop(idx) + if len(self._consumers[ftensor]) == 0 and len(self._producers[ftensor]) == 0: + self._remove_ftensor(ftensor) + # producer + otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) + for otensor in otensors: + ftensor = otensor.parent + idx = self._producers[ftensor].index(node) + self._producers[ftensor].pop(idx) + self._ptensors[ftensor].pop(idx) + if len(self._consumers[ftensor]) == 0 and len(self._producers[ftensor]) == 0: + self._remove_ftensor(ftensor) + else: + segment = self._nodes[pos[0]] + assert isinstance(segment, IRSegment) + segment.remove(node, _pos=CellPosition(pos.indices[1:])) + + return pos + + def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: + """ + Replace one node by multiple nodes + + # TODO: check input and output + + @param node IRCell: the replaced node + @param new_nodes List[IRCell]: the nodes to be inserted. + + @return index int: the replaced node index + """ + idx = self.remove(node) + for new_node in new_nodes[::-1]: + self.insert(new_node, idx) + return idx + + @contextmanager + def update(self, node): + """ + Update a node. Note the related change in backward operator + will not be automatically updated. + + TODO: update operator dependency + + e.g., + with graph.modify(node) as node: + node.set_input(0, tensor) + + @param node IRCell: the node that must in the graph + @return node IRCell: the modify node + """ + index = self.remove(node) + yield node + self.insert(node, index) + + def exist(self, node: IRCell) -> bool: + """ + Check if the node is in this graph + + @param node IRCell: the queried node + + @return exsit bool: True if exist otherwise False + """ + if node in self._nodes: + return True + for segment in self._nodes: + if not isinstance(segment, IRSegment): continue + if segment.exist(node): + return True + return False + + def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwOperation: + """ + Insert a forward node and create its backward. + The created backward operator will be happen right before + the backward of fwop's previous forward node + + This requires the segment has its backward segment + + @param fwop IRFwOperation: forward node + @param index Union[int, CellPosition]: inserted position + + @return node IRFwOperation: the node itself + """ + assert isinstance(fwop, IRFwOperation), "Only allow insert an IRFwOperation" + pos = CellPosition((index,)) if isinstance(index, int) else index + assert isinstance(pos, CellPosition), "Expect index to be int or CellPosition" + + index = pos.indices[-1] + fsegment = self if len(pos) == 1 else self.node(CellPosition(pos.indices[1:])) + fsegment.insert(fwop, index) + # create backward + bwop = fsegment.create_bwop(fwop) + bwop.device = fwop.device + # insert backward + assert fsegment.mirror is not None, "Missing backward segment" + bsegment: IRSegment = fsegment.mirror + bidx = 0 + for idx in range(index - 1, -1, -1): + prev_fnode = fsegment.node(idx) + if prev_fnode.mirror is not None: + bidx = bsegment.index(prev_fnode.mirror) + break + bsegment.insert(bwop, bidx) + return fwop + + # ====================== Graph Generations ============================ + + @staticmethod + def get_inputs(nodes: List[IRCell]): + """ + Get all the input tensors that are required by nodes. + + @param nodes List[IRCell]: the nodes + + @return inputs List[IRTensor]: the input tensors + """ + all_outputs = list() + for node in nodes: + all_outputs.extend(node.outputs()) + inputs = list() + for node in nodes: + for input in node.inputs(): + if isinstance(input, IRTensor): + if input not in all_outputs: + if input not in inputs: + inputs.append(input) + return inputs + + @staticmethod + def get_outputs(nodes: List[IRCell]): + """ + Get tensors that are produced but not consumed by nodes + + As long as the tensor is consumed in by the nodes, it will + not be in the output. A tensor will not appear as output if it + is double-consumed both outside and inside the nodes. + + @param nodes List[IRCell]: the nodes + + @return outputs List[IRTensor]: the output tensors + """ + all_inputs = list() + for node in nodes: + all_inputs.extend(node.inputs()) + outputs = list() + for node in nodes: + for output in node.outputs(): + # not consumed tensor + if isinstance(output, IRTensor): + if output not in all_inputs: + if output not in outputs: + outputs.append(output) + continue + return outputs + + def create_segment(self, nodes: List[IRCell]) -> IRCell: + """! + Create a segment with part of the nodes. + This only return the created segment wihout modifying the graph. + + @param nodes List[IRCell]: the subset nodes of this graph + + @return segment IRSegment: the grouped segment. + """ + segment = self + # segments: List[IRSegment] = [self.segment(node) for node in nodes] + # assert len(set(segments)) == 1, "Cross segment hierarchy grouping is not allowed" + # segment = segments[0] + + inputs, outputs = set(), set() + + # go through adapters + adapter_ins: Dict[IRSubTensor, Set[int]] = dict() + adapter_ous: Dict[IRSubTensor, Set[int]] = dict() + for adapter in nodes: + if not isinstance(adapter, IRAdapter): + continue + for itensor in adapter.inputs(): + if not isinstance(itensor, IRSubTensor): continue + if itensor not in adapter_ins: + adapter_ins[itensor] = set() + adapter_ins[itensor].update(itensor.device) + # producers can from out side node + producers = [] + for ptensor, prod in zip(segment.ptensors(itensor.parent), segment.producers(itensor.parent)): + if ptensor == itensor and set(itensor.device).issubset(set(prod.device)): + producers.append(prod) + if not any(p in nodes for p in producers): + inputs.add(itensor) + for otensor in adapter.outputs(): + if not isinstance(otensor, IRSubTensor): continue + if otensor not in adapter_ous: + adapter_ous[otensor] = set() + adapter_ous[otensor].update(otensor.device) + consumers = [] + for ctensor, cons in zip(segment.ctensors(otensor.parent), segment.consumers(otensor.parent)): + if ctensor == otensor and set(otensor.device).issubset(set(cons.device)): + consumers.append(cons) + if not any(c in nodes for c in consumers): + outputs.add(otensor) + + # go through non-adapter nodes + for node in nodes: + if isinstance(node, IRAdapter): + assert node.differentiable, \ + "Non-differentiable IRAdapter is not allowed to be grouped" + continue + # update inputs + itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] + for itensor in itensors: + ftensor = itensor.parent + if itensor.is_attr(): continue + # from inside adapters + if itensor in adapter_ous: + if len(node.device) > 0 and set(itensor.device).issubset(adapter_ous[itensor]): + continue + # from segment inputs + if any(t.overlap(itensor) for t in segment.inputs() if isinstance(t, IRSubTensor)): + inputs.add(itensor) + continue + # from outside producers + producers, ptensors = segment.producers(ftensor), segment.ptensors(ftensor) + producers = [p for p, t in zip(producers, ptensors) if t == itensor] + if len(itensor.device) > 0: + producers = [p for p in producers if set(itensor.device).issubset(set(p.device))] + # from graph inputs or outside adapter (no producer) + if len(producers) == 0 or any(p not in nodes for p in producers): + inputs.add(itensor) + continue + # update outputs + otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + for otensor in otensors: + ftensor = otensor.parent + if otensor.is_attr(): continue + # from inside adapters + if otensor in adapter_ins: + if len(node.device) > 0 and set(otensor.device).issubset(adapter_ins[otensor]): + continue + # loss doesn't have consumers + if len(segment.consumers(ftensor)) == 0: + outputs.add(otensor) + # from segment outputs + if any(t.overlap(otensor) for t in segment.outputs() if isinstance(t, IRSubTensor)): + outputs.add(otensor) + continue + # for outside consumers + consumers, ctensors = segment.consumers(ftensor), segment.ctensors(ftensor) + consumers = [c for c, t in zip(consumers, ctensors) if t == otensor] + if len(otensor.device) > 0: + consumers = [c for c in consumers if set(otensor.device).issubset(set(c.device))] + # for adapter (no consumer) + if len(consumers) == 0 or any(c not in nodes for c in consumers): + outputs.add(otensor) + continue + segment = IRSegment(nodes, tuple(inputs), tuple(outputs)) + return segment + + + def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: + """ + Instantiate the segement to a specific device. + + @param devid int: the target device + + @return segment IRSegment: the instantiated segment + """ + if devid not in self.device: + return None + if len(self.device) == 1 and self.device == [devid]: + return self + inputs, outputs, nodes = [], [], [] + for node in self._nodes: + if devid in node.device: + if isinstance(node, IRAdapter): + nodes.append(node.dispatch(devid)) + elif isinstance(node, IRSegment): + nodes.append(node.dispatch(devid)) + else: + assert len(node.device) == 1 + nodes.append(node) + for itensor in node.inputs(): + if itensor in self._inputs: + inputs.append(itensor) + for otensor in node.outputs(): + if otensor in self._outputs: + otensor.append(otensor) + outputs.append(otensor) + segment = IRSegment(nodes, inputs, outputs, self.name) + if mirror and segment.mirror is not None: + msegment = segment.mirror.dispatch(devid, mirror=False) + IRCell.make_pair(segment, msegment) + return segment + + + # ========================== Graph Visualize ================================ + + def __repr__(self): + fw = 'f' if self.isfw() else 'b' + inputs = tuple(t for t in self.inputs() if isinstance(t, IRTensor) and not t.is_param()) + if self.isfw(): + dscp = f"{fw}Graph{self.cid}-{self.device}(inputs={inputs}, outputs={self.outputs()})" + else: + dscp = f"{fw}Graph{self.cid}-{self.device}(fGraph{self.mirror.cid}, inputs={inputs}, outputs={self.outputs()})" + return dscp + + def extra_repr(self) -> str: + dscp = f"\n{self.name}:\n{'=' * len(self.name)}\n" + # inputs + dscp += f"Inputs: {self.inputs()}\n" + for node in self._nodes: + dscp += f"\n{node}" + if isinstance(node, IRSegment): + for subnode in node.nodes(): + dscp += f"\n\t{subnode}" + # outputs + dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" + return dscp diff --git a/cube/ir/cten.py b/cube/ir/cten.py index aa938450..b52c783e 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -477,6 +477,8 @@ def dtype(self, val: IRDType): if not isinstance(val, IRDType): raise TypeError(f"Expected IRDType but got {val}") self._dtype = val + if isinstance(self._grad, IRTensor): + self._dtype = val @property def cell(self) -> Optional[IRCell]: @@ -605,12 +607,14 @@ def nelement(self) -> int: cnt *= num return cnt - def backward(self): + def backward(self) -> IRCell: """ Autograd backward on the tensor + + @return graph IRGraph: the forward + backward graph """ - from cube.logics.translator import LogicTranslator - return LogicTranslator.backward(self) + return self.cell.backward(self) + def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device})' diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index 2bf0408a..7a81638c 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -1,5 +1,7 @@ +from typing import List from enum import Enum + class IRDType(Enum): float64 = 'float64' float16 = 'float16' @@ -13,6 +15,35 @@ class IRDType(Enum): unknown = 'unknown' +class DTypeInferRule: + """ + Infer the output shape according to given input shapes. + This will follow the dtype promotion rule, which is same with PyTorch. + + Reference: + https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc + + complex > floating > integral > boolean + """ + @staticmethod + def infer(node, dtypes: List[IRDType]) -> IRDType: + dtypes = [dtype for dtype in dtypes if dtype != IRDType.unknown] + if IRDType.unknown in dtypes: + raise RuntimeError(f"Find an unkown dtype") + if IRDType.float32 in dtypes and IRDType.float16 in dtypes: + raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") + # in priority: fp32 > fp16 > bool > int64 > int16 > + priority = [ + IRDType.float64, IRDType.float32, IRDType.float16, + IRDType.int64, IRDType.int32, IRDType.int16, IRDType.int8, + IRDType.boolean + ] + for dtype in priority: + if dtype in dtypes: + return dtype + return IRDType.unknown + + float64 = IRDType.float64 float16 = IRDType.float16 float32 = IRDType.float32 diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 86aa5f8d..f88cec8c 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -1,39 +1,11 @@ -from typing import Any, Optional, Tuple, Union, List +from typing import Optional, Tuple, Any import copy from cube.ir.cten import IRCell, IRTensor from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.algorithm.factory import DistAlgorithmFactory from cube.ir.unique import IDGenerator - - -class IRBaseOp(IRCell): - - def __init__(self, name: str, signature: str, - inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): - super().__init__(name, signature, len(inputs), len(outputs), init_outputs=False) - self.kwargs = kwargs - assert all(isinstance(t, IRTensor) for t in inputs), "expect all inputs to be IRTensors" - assert all(isinstance(t, IRTensor) for t in outputs), "expect all outputs to be IRTensors" - for idx, itensor in enumerate(inputs): - self.set_input(idx, itensor) - for idx, otensor in enumerate(outputs): - self.set_output(idx, otensor) - - def infer_shape(self) -> bool: - """ - Infer output value shape - """ - raise NotImplementedError - - def replicate(self): - """ - Replicate the Operation - """ - node = type(self)( - self.name, self.signature, self.inputs(), self.outputs(), **self.kwargs) - node._id = self._id - return node +from cube.ir.dtype import IRDType, DTypeInferRule class IRFwOperation(IRCell): @@ -63,6 +35,20 @@ def __init__(self, for idx, output in enumerate(outputs): self.set_output(idx, output) + def infer_dtype(self): + """ + Infer output value dtype. + + By default will follow the same dtype promotion rule with PyTorch. + """ + itensors = [t for t in self.inputs() if isinstance(t, IRTensor)] + assert len(itensors) > 0, "Missing input tensors, need to customize the infer rule" + odtype = DTypeInferRule.infer(self, [t.dtype for t in itensors]) + assert odtype != IRDType.unknown, f"{self} : {[t.dtype for t in itensors]}" + otensors = [t for t in self.outputs() if isinstance(t, IRTensor)] + for tensor in otensors: + tensor.dtype = odtype + def infer_shape(self): """ Infer output value shape @@ -136,25 +122,9 @@ def replicate(self): cpy.clear_successor() return cpy - def gen_backward(self) -> IRCell: - """! - Generate backward operator for this forward operator. - - Note by calling this API, this forward operator must be - attached into any of one IRGraph, or will lead to reference - count 0 error on gradient calcaultion. - - return: IRBpOperation - """ - if self.mirror is not None: - raise RuntimeError( - "Backward Op already generated. Use self.mirror.update() instead.") - bnode = IRBpOperation(self) - return bnode - def __repr__(self) -> str: sign = self.signature.split('.')[-1] - ins = [t for t in self.inputs() if isinstance(t, IRSubTensor) and not t.is_attr()] + ins = [t for t in self.inputs() if isinstance(t, IRTensor) and not t.is_attr()] dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " f"inputs={ins}, " f"outputs={self.outputs()})") @@ -174,45 +144,20 @@ class IRBpOperation(IRCell): Backward operation """ - def __init__(self, fwop: IRFwOperation): + def __init__(self, ograds: Tuple[Any], igrads: Tuple[Any]): """ Create dummy backward node for forward inputs and forward outputs @param fwop IRFwOperation: forward operator """ - assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" - finputs, foutputs = fwop.inputs(), fwop.outputs() super().__init__( 'backward', 'torch.autograd.grad', - len(foutputs), len(finputs), init_outputs=False + len(ograds), len(igrads), init_outputs=False ) - # pair forward op and backward op - IRCell.make_pair(self, fwop) - # set inputs and outputs - self.update() - - def update(self): - """ - Update this backward operator. - This is neccessary when op is partitioned and reference count is changed. - - Note in order to update produced and consumed tensor list, this call should be - wrapped with IRGraph detach and attach: - - ``` - idx = graph.detach(node) - node.update() - graph.attach(node, idx) - ``` - """ - fnode: IRFwOperation = self.mirror - assert isinstance(fnode, IRFwOperation), "Cannot find corresponding IRFwOperation" - for idx, itensor in enumerate(fnode.inputs()): - grad = itensor.grad if isinstance(itensor, IRSubTensor) else None - self.set_output(idx, grad) - for idx, otensor in enumerate(fnode.outputs()): - grad = otensor.grad if isinstance(otensor, IRSubTensor) else None - self.set_input(idx, grad) + for idx, ograd in enumerate(ograds): + self.set_input(idx, ograd) + for idx, igrad in enumerate(igrads): + self.set_output(idx, igrad) def replicate(self): """ diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 885ec143..43ae40d1 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -161,7 +161,7 @@ def __init__(self, weight: IdxChunk): weight = weight.weight assert len(weight) == 2 and all(isinstance(i, int) for i in weight), \ "expected weight to be (idx, nchunks)" - self._weight = weight + self._weight = tuple(weight) @property def weight(self) -> IdxChunk: @@ -254,14 +254,6 @@ def __init__(self, shape=None, name=None, requires_grad=True, dtype=IRDType.unkn super().__init__(shape, name, dtype) - # producer cell and produced sub tensor - self._producers: List[IRCell] = list() - self._ptensors : List[IRSubTensor] = list() - - # consumer cell and consumed sub tensor - self._consumers: List[IRCell] = list() - self._ctensors : List[IRSubTensor] = list() - # record all created sub_tensors self._segments : Dict[(ValueMap, IndexMap), int] = dict() @@ -288,85 +280,6 @@ def like(self): tensor = IRFullTensor(self.shape, self.name, self.requires_grad, self.dtype) return tensor - @property - def producers(self) -> Tuple[IRCell]: - """ - Producer IRCell list - """ - return tuple(self._producers) - - @property - def ptensors(self) -> Tuple[IRTensor]: - """ - Produced IRSubTensor list correspongding to producer IRCell - """ - return tuple(self._ptensors) - - @property - def consumers(self) -> Tuple[IRCell]: - """ - Consumer IRCell list - """ - return tuple(self._consumers) - - @property - def ctensors(self) -> Tuple[IRTensor]: - """ - Consumed IRSubTensor list correspongding to consumer IRCell - """ - return tuple(self._ctensors) - - def add_producer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): - if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): - raise TypeError("Expect an IRCell and an IRTensor") - assert cell not in self._producers, f"{cell} already exists as producer" - self._producers.insert(idx, cell) - self._ptensors.insert(idx, tensor) - - def add_consumer(self, cell: IRCell, tensor: IRTensor, idx: int = 0): - """! - Add the tensor and its operator into consumer list. - The tensor should be in cell.inputs() - - @param cell IRCell: node to be consumer - @param tensor IRTensor: tensor to be consumed tensors - @param idx int: the index to be inserted - """ - assert tensor in cell.inputs(), f"tensor {tensor} not in node: {cell} inputs" - if not isinstance(cell, IRCell) or not isinstance(tensor, IRTensor): - raise TypeError("Expect an IRCell and an IRTensor") - if cell in self._consumers: - for idx, consumer in enumerate(self._consumers): - if cell == consumer: - assert self._ctensors[idx] != tensor, f"double add a same consumer-tensor pair: {cell}" - self._consumers.insert(idx, cell) - self._ctensors.insert(idx, tensor) - for t in self._ctensors: - t._dirty_grad = True - - def rm_producer(self, cell: IRCell) -> int: - if cell not in self._producers: - raise KeyError(f"Cell {cell} not found in producer") - while cell in self._producers: - idx = self._producers.index(cell) - self._producers.pop(idx) - self._ptensors.pop(idx) - return idx - - def rm_consumer(self, cell: IRCell) -> int: - if cell not in self._consumers: - raise KeyError(f"Cell {cell} not found in producer") - idx = self._consumers.index(cell) - self._consumers.pop(idx) - self._ctensors.pop(idx) - return idx - - def clear_producer_consumer(self) -> int: - self._producers = [] - self._ptensors = [] - self._consumers = [] - self._ctensors = [] - @property def grad(self) -> Optional[Union[IRTensor, float]]: return self._grad @@ -381,8 +294,6 @@ def grad(self, val: Optional[Union[IRTensor, float]]): self._requires_grad = False if val is None else True if isinstance(val, IRFullTensor): assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." - for tensor in self._ctensors + self._ptensors: - tensor._dirty_grad = True @property def requires_grad(self): @@ -392,11 +303,28 @@ def requires_grad(self): def requires_grad(self, val: bool): self._requires_grad = val if val and self.grad is None: - self.grad = IRFullTensor(self.shape, 'g' + self.name, False).as_grad() + self.grad = IRFullTensor(self.shape, 'g' + self.name, False, dtype=self.dtype).as_grad() elif not val and self.grad is not None: self.grad = None - for tensor in self._ctensors + self._ptensors: - tensor._dirty_grad = True + + @property + def dtype(self) -> IRDType: + """ + Tensor data type + """ + return self._dtype + + @dtype.setter + def dtype(self, val: IRDType): + """ + Set data type. + It's gradient data type will also be set. + """ + if not isinstance(val, IRDType): + raise TypeError(f"Expected IRDType but got {val}") + self._dtype = val + if isinstance(self.grad, IRTensor): + self.grad.dtype = val def as_param(self): """ @@ -489,8 +417,6 @@ def __init__(self, ftensor: IRFullTensor, self._indmap: IndexMap = indmap # val map self._valmap: ValueMap = valmap - # grad flag - self._dirty_grad = True def __eq__(self, other) -> bool: if isinstance(other, IRSubTensor): @@ -648,60 +574,27 @@ def __copy__(self): tensor._cell = None return tensor - @property - def grad(self) -> Optional[Union[IRTensor, float]]: - """ - Get gradient of this tensor. - - Gradient can be: - - None: the tensor doesn't require gradient - - 1.0: the tensor is loss tensor (scalar) - - IRSubTensor: the tensor requires gradient and is not the loss tensor (scalar) - - Gradient cannot be set and can only be inferred by its IRFullTensor. - The gradient will be lazy updated when its IRFullTensor gets - new consumed / produced tensors - """ - if not self._dirty_grad: - return self._grad - - assert isinstance(self.cell, IRCell), "No cell attached to this tensor." - full_grad = self.parent.grad - if full_grad is None or isinstance(full_grad, float): - self._grad = full_grad - # this tensor is consumed - elif self in self.cell.inputs(): - # for backard, we assume in final distributed graph, - # each tensor can be represented as nested - consumers = [] - for ctensor, consumer in zip(self.parent.ctensors, self.parent.consumers): - if ctensor == self and consumer.cid not in consumers: - consumers.append(consumer.cid) - valmap = (consumers.index(self.cell.cid), len(consumers)) - grad = full_grad.select( - indmap = self.indmap, - valmap = valmap, - ) - self._grad = grad - self._dirty_grad = False - return grad - # this tensor is produced - elif self in self.cell.outputs(): - grad = full_grad.select( - indmap = self.indmap, - valmap = (0, 1), - ) - self._grad = grad - else: - raise RuntimeError("Visit gradient of a tensor that is potentially generated by IRAdapter") - self._dirty_grad = False - self._requires_grad = False if full_grad is None else True - return self._grad - @property def requires_grad(self) -> bool: return self.parent._requires_grad + @property + def grad(self) -> bool: + return self._grad + + @grad.setter + def grad(self, val: Optional[IRTensor]): + if isinstance(val, (IRSubTensor, float)): + assert self.requires_grad + if isinstance(val, IRSubTensor): + val.shape == self.shape + self._grad = val + elif val is None: + assert not self.requires_grad + self._grad = None + else: + raise ValueError(f"Expected grad to be None or IRSubTensor but got: {val}") + # partition primitives def select(self, @@ -847,5 +740,5 @@ def extra_repr(self) -> str: anno = 'w' if self.is_grad(): anno = 'g' - dscp = f'{anno}{self._id}(id={self._id}, shape={self.shape}, dev={self.device}, ind=[{self._indmap}], val={self._valmap})' + dscp = f'{anno}{self._id}(pid={self.parent.tid}, shape={self.shape}, dev={self.device}, ind=[{self._indmap}], val={self._valmap})' return dscp diff --git a/cube/logics/__init__.py b/cube/logics/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/logics/dataloader.py b/cube/logics/dataloader.py deleted file mode 100644 index 4ad72e1d..00000000 --- a/cube/logics/dataloader.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Tuple -from cube.runtime.syndata import CubeDataLoader - - -class IRDataLoader: - - def __init__(self, dataloader: CubeDataLoader, dtype_map): - if not isinstance(dataloader, CubeDataLoader): - raise TypeError("Expected data loader derived from CubeDataLoader") - self.dataloader: CubeDataLoader = iter(dataloader) - self.dtypes = [dtype_map.map(dtype) for dtype in dataloader.dtypes] - self.shapes = [list(shape) for shape in dataloader.shapes] - - def get_batch_dims(self) -> Tuple[int]: - return tuple(self.dataloader.batch_dims) - - def get_batch_size(self) -> int: - return self.dataloader.get_batch_size() - - def set_batch_size(self, bs: int): - self.dataloader.set_batch_size(bs) - return - - def __iter__(self): - return self - - def __next__(self): - from cube.logics.translator import LogicTranslator - datas = LogicTranslator.load_data(self) - return datas diff --git a/cube/logics/model.py b/cube/logics/model.py deleted file mode 100644 index 10696476..00000000 --- a/cube/logics/model.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Tuple, List -import copy - -from cube.graph.graph import IRGraph -from cube.ir.dtype import IRDType -from cube.ir.tensor import IRSubTensor -from cube.ir.operator import IRFwOperation - - -class DTypeInferRule: - """ - According to promotion doc: - https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc - - complex > floating > integral > boolean - """ - @staticmethod - def infer(node, dtypes: List[IRDType]) -> IRDType: - dtypes = [dtype for dtype in dtypes if dtype != IRDType.unknown] - if IRDType.unknown in dtypes: - raise RuntimeError(f"Find an unkown dtype") - if IRDType.float32 in dtypes and IRDType.float16 in dtypes: - raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") - # in priority: fp32 > fp16 > bool > int64 > int16 > - priority = [ - IRDType.float64, IRDType.float32, IRDType.float16, - IRDType.int64, IRDType.int32, IRDType.int16, IRDType.int8, - IRDType.boolean - ] - for dtype in priority: - if dtype in dtypes: - return dtype - return IRDType.unknown - - -def forward(graph: IRGraph, *args) -> IRGraph: - """ - Forward the IRGraph, replacing all the intermediate tensors - """ - if not isinstance(graph, IRGraph): - raise TypeError("Requires IRGraph for forward") - - # align graph with input tensors - itensors: Tuple[IRSubTensor, ...] = graph.inputs() - for idx, (itensor, arg) in enumerate(zip(itensors, args)): - graph.set_input(idx, arg) - for producer in copy.copy(itensor.parent.producers): - pidx = graph.detach(producer) - while itensor in producer.outputs(): - oidx = producer.outputs().index(itensor) - producer.set_output(oidx, arg) - graph.attach(producer, pidx) - for consumer in copy.copy(itensor.parent.consumers): - cidx = graph.detach(consumer) - while itensor in consumer.inputs(): - iidx = consumer.inputs().index(itensor) - consumer.set_input(iidx, arg) - graph.attach(consumer, cidx) - while itensor in graph.outputs(): - oidx = graph.outputs().index(itensor) - graph.set_output(oidx, arg) - - # dtype inference - for node in graph.nodes(): - itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] - # setup gradient - for itensor in itensors: - if itensor.parent.grad is not None: - itensor.parent.dtype = itensor.dtype - if len(itensors) == 0: - continue - odtype = DTypeInferRule.infer(node, [t.dtype for t in itensors]) - assert odtype != IRDType.unknown, f"{node} : {[t.dtype for t in itensors]}" - otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] - for tensor in otensors: - tensor.dtype = odtype - # setup graidient - if tensor.parent.grad is not None: - tensor.parent.grad.dtype = odtype - - # generate backward reverse is only to make op id looks consecutive - for fnode in [n for n in graph.nodes() if isinstance(n, IRFwOperation)][::-1]: - fnode.gen_backward() - return graph diff --git a/cube/logics/pool.py b/cube/logics/pool.py deleted file mode 100644 index fd9f2045..00000000 --- a/cube/logics/pool.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import List, Any -import copy - - -class SchedulePool: - - class __SchedulePool: - - def __init__(self): - - self._nodes = list() - self._tapes = dict() - - instance = None - - def __init__(self): - if not SchedulePool.instance: - SchedulePool.instance = SchedulePool.__SchedulePool() - - def __getattr__(self, name): - return getattr(self.instance, name) - - def add_node(self, node): - self.instance._nodes.append(node) - - def nodes(self) -> List: - return copy.copy(self.instance._nodes) - - def tape(self, tensor, trace: Any): - """ - Record the trace generated to this tensor - """ - self.instance._tapes[tensor._id] = trace - - def get_tape(self, tensor): - """ - Get the trace given the tensor - """ - if tensor._id not in self.instance._tapes: - return None - else: - return self.instance._tapes[tensor._id] - - def clear(self): - self.instance._nodes = list() - self.instance._tapes = dict() - - def __repr__(self): - dscp = '\n'.join([repr(node) for node in self._nodes]) - return dscp diff --git a/cube/logics/translator.py b/cube/logics/translator.py deleted file mode 100644 index 5c2db40d..00000000 --- a/cube/logics/translator.py +++ /dev/null @@ -1,100 +0,0 @@ -from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor - -from cube.graph.graph import IRGraph - -from cube.logics.dataloader import IRDataLoader -from cube.logics import model -from cube.logics.pool import SchedulePool - - - -class LogicTranslator: - - @staticmethod - def gen_logic_graph(outputs=None) -> IRGraph: - """ - Generate Training Logic Graph - """ - nodes = SchedulePool().nodes() - has_bp = any(n for n in nodes if isinstance(n, IRBpOperation)) - if has_bp: - assert all(fnode.mirror in nodes for fnode in nodes if isinstance(fnode, IRFwOperation)), \ - "Training requires all nodes have backward." - else: - # remove backward nodes if no backward is called - fnodes = [node for node in nodes if isinstance(node, IRFwOperation)] - for fnode in fnodes: - IRCell.make_pair(fnode, None) - # remove node gradient - for node in nodes: - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor): - itensor.parent.requires_grad = False - # ad hoc fix on operators with multiple same input tensors - itensor._dirty_grad = True - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor): - otensor.parent.requires_grad = False - graph = IRGraph(nodes, inputs=[], outputs=outputs, module_name='LogicGraph') - return graph - - @staticmethod - def load_data(dataloader: IRDataLoader): - """ - Translator Action: Load data from data loaderw - """ - if not isinstance(dataloader, IRDataLoader): - raise TypeError("Expected IRDataLoader") - outputs = list() - for dtype, shape in zip(dataloader.dtypes, dataloader.shapes): - data = IRFullTensor( - shape, 'data', requires_grad=False, dtype=dtype - ).tosub() - outputs.append(data) - - data_op = IRDataOperation( - data_num=len(outputs), batch_dims=dataloader.get_batch_dims(), - ) - for idx, output in enumerate(outputs): - data_op.set_output(idx, output) - - SchedulePool().add_node(data_op) - if len(outputs) == 0: return - elif len(outputs) == 1: return outputs[0] - else: return tuple(outputs) - - @staticmethod - def forward(graph, *args): - """ - Translator Action: forward an IRGraph - """ - fgraph = model.forward(graph, *args) - for node in fgraph.nodes(): - SchedulePool().add_node(node) - for output in fgraph.outputs(): - SchedulePool().tape(output, fgraph.nodes()) - outputs = fgraph.outputs() - if len(outputs) == 1: return outputs[0] - elif len(outputs) == 0: return None - else: return outputs - - @staticmethod - def backward(loss: IRSubTensor): - """ - Translator Action: backward a tensor - """ - trace = SchedulePool().get_tape(loss) - if trace is None: - raise RuntimeError("No forward detected") - if loss.nelement() != 1: - raise RuntimeError("backward can only perform on the scaler tensor") - # loss tensor grad should be 1.0 - loss.parent.grad = 1.0 - for node in trace[::-1]: - SchedulePool().add_node(node.mirror) - - @staticmethod - def update(optimizer): - raise NotImplementedError diff --git a/cube/program.py b/cube/program.py new file mode 100644 index 00000000..0d88d512 --- /dev/null +++ b/cube/program.py @@ -0,0 +1,158 @@ +from typing import List, Tuple +from cube.graph.torch_dtype_mapping import DType2IRDType + +from cube.ir.cten import IRCell, IRTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.operator import IRBpOperation, IRDataOperation + +from cube.graph import IRGraph +from cube.graph import parser + +from cube.runtime.syndata import CubeDataLoader +from cube.profiler.timer import print_each_rank + +import torch + + +class Program: + + class __Program: + + def __init__(self): + + self._graph = IRGraph([], [], [], 'program') + + instance = None + + def __init__(self): + if not Program.instance: + Program.instance = Program.__Program() + + def __getattr__(self, name): + return getattr(self.instance, name) + + def add_node(self, node: IRCell): + self.instance._graph.insert(node, self.instance._graph.nnodes) + + def add_nodes(self, nodes: List[IRCell]): + for node in nodes: + self.add_node(node) + + def get_graph(self): + has_bp = any(isinstance(node, IRBpOperation) for node in self.instance._graph.nodes()) + if not has_bp: + for ftensor in self.instance._graph.full_tensors(): + ftensor.requires_grad = False + for node in self.instance._graph.nodes(flatten=True): + for itensor in node.inputs(): + if isinstance(itensor, IRSubTensor): + itensor.grad = None + for otensor in node.outputs(): + if isinstance(otensor, IRSubTensor): + otensor.grad = None + return self.instance._graph + + def set_output(self, outputs: List[IRTensor]): + for otensor in outputs: + if not isinstance(otensor, IRTensor): + raise NotImplementedError("Not support for non-tensor graph output") + self.instance._graph.reset_outputs(len(outputs)) + for idx, otensor in enumerate(outputs): + self.instance._graph.set_output(idx, otensor) + + def mirror_as_self(self): + """ + Set mirror as self. This is called when a backward is triggered. + """ + IRCell.make_pair(self.instance._graph, self.instance._graph) + + def clear(self): + self.instance._graph = IRGraph([], [], [], 'program') + + def __repr__(self): + return repr(self.instance._graph) + + +class SemanticDataLoader: + + def __init__(self, dataloader: CubeDataLoader): + if not isinstance(dataloader, CubeDataLoader): + raise TypeError("Expected data loader derived from CubeDataLoader") + self.dataloader: CubeDataLoader = iter(dataloader) + dtype_map = DType2IRDType + self.dtypes = [dtype_map.map(dtype) for dtype in dataloader.dtypes] + self.shapes = [list(shape) for shape in dataloader.shapes] + + def get_batch_dims(self) -> Tuple[int]: + return tuple(self.dataloader.batch_dims) + + def get_batch_size(self) -> int: + return self.dataloader.get_batch_size() + + def set_batch_size(self, bs: int): + self.dataloader.set_batch_size(bs) + return + + def __iter__(self): + return self + + def __next__(self): + outputs = list() + for dtype, shape in zip(self.dtypes, self.shapes): + data = IRFullTensor( + shape, 'data', requires_grad=False, dtype=dtype + ).tosub() + outputs.append(data) + + data_op = IRDataOperation( + data_num=len(outputs), batch_dims=self.get_batch_dims(), + ) + for idx, output in enumerate(outputs): + data_op.set_output(idx, output) + + Program().add_node(data_op) + if len(outputs) == 0: return + elif len(outputs) == 1: return outputs[0] + else: return tuple(outputs) + + +class SemanticModel: + + def __init__(self, model: torch.nn.Module, input_shapes): + """ + Create semantic model based on AI Scientist description. + """ + dist = torch.distributed.is_initialized() + if (not dist) or (dist and torch.distributed.get_rank() == 0): + self.ir_graph = parser.convert_model( + model, input_shapes=input_shapes + ) + else: + self.ir_graph = None + self._loaded_module = None + + def get_graph(self): + return self.ir_graph + + def load_module(self, filename: str, load_content=True): + import importlib.util + spec = importlib.util.spec_from_file_location("GenModel", filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._loaded_module = module.GenModel().cuda() + if load_content: + print_each_rank("> loading parameter content...") + # TODO: make hardcode ./fullmodel.pt programmable + self._loaded_module.load_attr_content('./fullmodel.pt') + + def get_gen_module(self): + return self._loaded_module + + def clear_module(self): + self._loaded_module = None + + def __call__(self, *args): + if self._loaded_module: + return self._loaded_module(*args) + else: + return self.ir_graph(*args) \ No newline at end of file diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 00e2114f..e514f475 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -99,7 +99,7 @@ def all_reduce(itensor: torch.Tensor, CudaTimer().start(field_name='comm') if not itensor.is_contiguous(): itensor = itensor.contiguous() - itensor = itensor.detach().requires_grad_() + itensor = itensor.detach() group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(itensor, group=group) torch.cuda.synchronize() @@ -120,7 +120,7 @@ def all_gather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tens torch.distributed.all_gather(tensor_list, itensor, group=group) torch.cuda.synchronize() # concat - otensor = torch.concat(tuple(tensor_list), dim=dim).requires_grad_() + otensor = torch.concat(tuple(tensor_list), dim=dim) CudaTimer().stop(field_name='comm') return otensor @@ -135,7 +135,7 @@ def reduce_scatter(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch. if not tensor.is_contiguous(): itensors[idx] = tensor.contiguous() group = DeviceGroup().get_group(ranks) - otensor = torch.empty_like(itensors[0], requires_grad=True) + otensor = torch.empty_like(itensors[0], requires_grad=False) torch.distributed.reduce_scatter(otensor, itensors, group=group) torch.cuda.synchronize() CudaTimer().stop(field_name='comm') @@ -155,7 +155,7 @@ def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) - group = DeviceGroup().get_group(ranks) torch.distributed.all_to_all(otensors, itensors, group=group) torch.cuda.synchronize() - otensor = torch.concat(tuple(otensors), dim=idim).requires_grad_() + otensor = torch.concat(tuple(otensors), dim=idim) CudaTimer().stop(field_name='comm') return otensor diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index bc4f99c0..6dab1c6b 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -2,7 +2,6 @@ import torch from cube.profiler.timer import CudaTimer -from cube.runtime.adapter.collectives import all_reduce from cube.runtime.device import DeviceGroup diff --git a/cube/runtime/adapter/transform.py b/cube/runtime/adapter/transform.py index 016d3d4a..d60a9b18 100644 --- a/cube/runtime/adapter/transform.py +++ b/cube/runtime/adapter/transform.py @@ -10,11 +10,6 @@ def identity(tensor: torch.Tensor): """ identity """ - require_grad = tensor.requires_grad - with torch.no_grad(): - tensor = tensor.detach() - if require_grad: - tensor = tensor.requires_grad_() return tensor @@ -23,14 +18,11 @@ def select(tensor: torch.Tensor, """ Select a part of tensor spatially and numerically. """ - require_grad = tensor.requires_grad with torch.no_grad(): sub_tensor = tensor[indmap] if valmap != 1: sub_tensor = sub_tensor / valmap sub_tensor = sub_tensor.detach() - if require_grad: - sub_tensor = sub_tensor.requires_grad_() return sub_tensor @@ -43,11 +35,8 @@ def smerge(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: tensors: a list of torch tensor dim: the dimension to concatenate. """ - require_grad = any(t.requires_grad for t in tensors) with torch.no_grad(): - out = torch.concat(tuple(tensors), dim).requires_grad_() - if require_grad: - out = out.requires_grad_() + out = torch.concat(tuple(tensors), dim) return out @@ -59,11 +48,8 @@ def vmerge(tensors: List[torch.Tensor]) -> torch.Tensor: Args: tensors: a list of torch tensor """ - require_grad = any(t.requires_grad for t in tensors) with torch.no_grad(): out = tensors[0] for tensor in tensors[1:]: out = out + tensor - if require_grad: - out = out.requires_grad_() return out diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index fb384508..91b56d28 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -2,21 +2,117 @@ Executor for runtime """ -from typing import Tuple, Any, Callable, List +from typing import Tuple, Any, Callable, List, Dict import torch -def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): - """ - forward the sub-graph. - """ - if not requires_grad: - with torch.no_grad(): +def debug_id(tensors, msg: str, rank: int): + if torch.distributed.get_rank() == rank: + if torch.is_tensor(tensors): + print(f'[{torch.distributed.get_rank()}] {msg}: [{id(tensors)}]') + else: + print(f'[{torch.distributed.get_rank()}] {msg}: {[id(t) for t in tensors]}') + + +class Executor: + + _detach: Dict[str, Dict[torch.Tensor, torch.Tensor]] = dict() + + @staticmethod + def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): + """ + forward the sub-graph. + """ + if not requires_grad: + with torch.no_grad(): + outputs = subgraph(*input_tensors) + else: + # everytime forward a segment, detach the tensor from previous graph + # debug_id(input_tensors, 'outside fexecute args', 0) + assert name not in Executor._detach + Executor._detach[name] = dict() + for itensor in input_tensors: + if torch.is_tensor(itensor) and itensor.requires_grad: + if itensor not in Executor._detach[name]: + Executor._detach[name][itensor] = itensor.detach().requires_grad_() + input_tensors = tuple( + Executor._detach[name][t] if t in Executor._detach[name] else t for t in input_tensors + ) + # debug_id(input_tensors, 'inside fexecute args', 0) outputs = subgraph(*input_tensors) - else: - outputs = subgraph(*input_tensors) - # print('forwarding... ') - return outputs + # debug_id(outputs, 'fexecute result', 0) + # print('forwarding... ') + return outputs + + @staticmethod + def aexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): + """ + execute adapter + """ + if not requires_grad: + with torch.no_grad(): + outputs = subgraph(*input_tensors) + else: + outputs = subgraph(*input_tensors) + outputs = outputs.requires_grad_() if torch.is_tensor(outputs) else (t.requires_grad_() for t in outputs) + return outputs + + @staticmethod + def backward(name: str, + input_tensors: List[torch.Tensor], + output_tensors: List[torch.Tensor], + output_tensor_grads: List[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Backward Procedure. + + input_tensors: List[torch.Tensor]: + tensors that their gradient need to be computed, including parameters. + Correspoinding forward input tensors. + + output_tensors: + tensors that start for gradient backward computation. + Corresponding to forward output tensors. + + output_tensor_grads: + gradient tensors corresponding to output_tensors. + + Returns: + gradient in order of non-parameter tensors in input_tensors. + (Note parameter tnesors already have gradient accumulated at .grad attribute) + """ + if len(output_tensors) == 0: + return None + + assert name in Executor._detach, f"forward graph: {name} not run before" + input_tensors = [t for t in input_tensors if torch.is_tensor(t) and not isinstance(t, torch.nn.Parameter)] + input_tensors = [t for t in input_tensors if t.requires_grad] + input_tensors = [Executor._detach[name][t] if t in Executor._detach[name] else t for t in input_tensors] + for t in input_tensors: + t.retain_grad() + torch.autograd.backward( + output_tensors, + grad_tensors=output_tensor_grads, + ) + grads = tuple(t.grad for t in input_tensors) + assert all(grad is not None for grad in grads), "RuntimeError: got gradient None" + del Executor._detach[name] + if len(grads) == 0: return None + elif len(grads) == 1: return grads[0] + else: return tuple(grads) + + @staticmethod + def clear(): + Executor._detach = dict() + + @staticmethod + def check_clear(): + assert len(Executor._detach) == 0, \ + f"Find remain not consumed sub-graph: {tuple(Executor._detach.keys())}" + + +fexecute = Executor.fexecute +aexecute = Executor.aexecute +backward = Executor.backward # def backward(input_tensors : List[torch.Tensor], @@ -62,44 +158,6 @@ def fexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) # else: return tuple(grads) -def backward(input_tensors: List[torch.Tensor], - output_tensors: List[torch.Tensor], - output_tensor_grads: List[torch.Tensor]) -> Tuple[torch.Tensor]: - """ - Backward Procedure. - - input_tensors: List[torch.Tensor]: - tensors that their gradient need to be computed, including parameters. - Correspoinding forward input tensors. - - output_tensors: - tensors that start for gradient backward computation. - Corresponding to forward output tensors. - - output_tensor_grads: - gradient tensors corresponding to output_tensors. - - Returns: - gradient in order of non-parameter tensors in input_tensors. - (Note parameter tnesors already have gradient accumulated at .grad attribute) - """ - if len(output_tensors) == 0: - return None - inputs = list() - for input_ in input_tensors: - if torch.is_tensor(input_) and not isinstance(input_, torch.nn.Parameter): - if input_.requires_grad: - input_.retain_grad() - inputs.append(input_) - torch.autograd.backward( - output_tensors, - grad_tensors=output_tensor_grads, - ) - grads = tuple(input_.grad for input_ in inputs) - if len(grads) == 0: return None - elif len(grads) == 1: return grads[0] - else: return tuple(grads) - ### =================== Experimental Feature ======================= # import queue diff --git a/examples/mlp/linears.py b/examples/mlp/linears.py index 8bc454a8..05eb5f51 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/linears.py @@ -46,9 +46,9 @@ def __init__(self, dim, mult=1, nlayers=4): self.layers = torch.nn.ModuleList([]) for lid in range(nlayers): if lid % 2 == 0: - self.layers.append(nn.Linear(dim, dim * mult)) + self.layers.append(nn.Linear(dim, dim * mult, bias=False)) else: - self.layers.append(nn.Linear(dim * mult, dim)) + self.layers.append(nn.Linear(dim * mult, dim, bias=False)) def forward(self, data): x = data From c18007d3b35c64b47cba42042dedc3784af3a7af Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 8 Sep 2022 14:18:30 +0800 Subject: [PATCH 1001/1892] fix schedule policy --- cube/graph/gener/gen.py | 15 ++++- cube/graph/graph.py | 9 +-- cube/graph/schedule/sched1f1b.py | 59 ++++++++++++------ cube/graph/schedule/strategy.py | 104 ++++++++----------------------- 4 files changed, 82 insertions(+), 105 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 1016d684..0ea7adda 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -201,13 +201,13 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bproducers, bptensors = [], [] bconsumers, bctensors = [], [] if (ftensor not in skip_grads) and isinstance(ftensor.grad, IRFullTensor): - bproducers, bptensors = graph.producers(ftensor.grad), graph.ptensors(ftensor.grad) + bproducers, bptensors = bgraph.producers(ftensor.grad), bgraph.ptensors(ftensor.grad) assert all(len(ptensor.device) == 1 for ptensor in bptensors), ( f"Not support for multi-device:\n" f"{[ptensor.device for ptensor in bptensors]}" f"{[ptensor.cell for ptensor in bptensors]}" ) - bconsumers, bctensors = graph.consumers(ftensor.grad), graph.ctensors(ftensor.grad) + bconsumers, bctensors = bgraph.consumers(ftensor.grad), bgraph.ctensors(ftensor.grad) assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" if skip(fptensors, fctensors) and skip(bptensors, bctensors): @@ -217,6 +217,17 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: if fadapter is None: continue + if not isinstance(graph, IRGraph): + if not fadapter.differentiable: + raise NotImplementedError( + "Require adapter to be differentiable for nested IRAdapter." + "Condition to be differentiable: prodcuers have same device set with consumers" + f"Failed FullTensor: {ftensor}" + f"{graph.debug_tensor_map_str(ftensor)}" + f"Failed FullTensor.grad: {ftensor.grad}" + f"{bgraph.debug_tensor_map_str(ftensor.grad) if ftensor.grad is not None else None}" + ) + badapter: Optional[IRAdapter] = fadapter.mirror if (badapter is not None and len(fadapter.prims) == 0 and len(badapter.prims) == 0) or \ diff --git a/cube/graph/graph.py b/cube/graph/graph.py index d4efb15d..b22475e1 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -434,11 +434,9 @@ def assign(self, node: Union[IRFwOperation, IRDataOperation], device: int) -> bo """ assert self.exist(node), f"{node} is not in the graph" if isinstance(node, IRSegment): - assert node.forward, "Only forward segment is allowed to assign devices" + assert node.isfw(), "Only forward segment is allowed to assign devices" for subnode in node.nodes(): - subnode.device = device - if subnode.mirror is not None: - subnode.mirror.device = device + self.assign(subnode, device) else: assert isinstance(node, (IRFwOperation, IRDataOperation)), \ "Only forward operators and dataloader operators are allowed to assign devices" @@ -534,8 +532,7 @@ def sched(self): """ return self._sched - @sched.setter - def sched(self, strategy): + def predef_sched(self, strategy): """! Set schedule plan for the execution. diff --git a/cube/graph/schedule/sched1f1b.py b/cube/graph/schedule/sched1f1b.py index d3cb8804..37dc6abc 100644 --- a/cube/graph/schedule/sched1f1b.py +++ b/cube/graph/schedule/sched1f1b.py @@ -1,5 +1,5 @@ -from typing import Dict, Tuple, Optional +from typing import Dict, Optional, List from cube.ir.cten import IRCell from cube.ir.adapter.adapter import IRAdapter @@ -19,11 +19,11 @@ class IRSchedule1F1B(IRScheduleStrategy): [Recv-Backward] Backward-Segment [Send-Backward] """ - def __init__(self, graph, nmicros: int, devmesh: Tuple[Tuple[int]]): - super().__init__(graph, nmicros, devmesh) + def __init__(self, graph, nmicros: int): + super().__init__(graph, nmicros) self.signature = 'cube.runtime.schedule.Schedule1F1B.run' # forward body - self.segment: Dict[int, IRSegment] = dict() + self.fsegments: Dict[int, IRSegment] = dict() # forward send self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() # forward recv @@ -33,29 +33,48 @@ def __init__(self, graph, nmicros: int, devmesh: Tuple[Tuple[int]]): # backward recv self.rbadapter: Dict[int, Optional[IRAdapter]] = dict() # num_stage - self.num_stages: int = len(devmesh) + self.num_stages: int = -1 # stage id self.stage_id: Dict[int, int] = dict() # recompute self.recompute = False + def apply(self) -> IRGraph: - self.segmentation() - for gid, devices in enumerate(self.devmesh): - for devid in devices: - # forward recv - self.rfadapter[devid] = None if gid == 0 else self.cross_groups[gid-1] + self.mesh() + # each forward has corresponding backward + assert all(fseg.mirror in self.segments for fseg in self.segments if fseg.isfw()), \ + "Require backward of each forward stage" + # stage doesn't share devices + fsegments: List[IRSegment] = [fseg for fseg in self.segments if fseg.isfw()] + self.num_stages = len(fsegments) + for sid, fseg in enumerate(fsegments): + for devid in fseg.device: # forward body - self.segment[devid] = self.inner_groups[gid] - # forward send - if gid == len(self.devmesh)-1: assert self.cross_groups[gid] is None - self.sfadapter[devid] = self.cross_groups[gid] - # backward recv - self.rbadapter[devid] = None if gid == len(self.devmesh)-1 else self.sfadapter[devid].mirror - # backward send - self.sbadapter[devid] = None if gid == 0 else self.rfadapter[devid].mirror + assert devid not in self.fsegments, "One device cannot have multiple forward stages" + self.fsegments[devid] = fseg + # forward recv / backward send + assert len(self.recvers[fseg]) <= 1, "Corss-stage adapter can only be one" + if sid == 0: + assert len(self.recvers[fseg]) == 0, "Expect no forward send at first stage" + assert len(self.senders[fseg.mirror]) == 0, "Expect no backward send at first stage" + else: + assert len(self.recvers[fseg]) == 1, "Expect one forward recv at non-first stage" + assert len(self.senders[fseg.mirror]) == 1, "Expect one backward send at non-first stage" + self.rfadapter[devid] = None if sid == 0 else self.recvers[fseg][0] + self.sbadapter[devid] = None if sid == 0 else self.senders[fseg.mirror][0] + # forward send / backward recv + if sid == self.num_stages - 1: + assert len(self.senders[fseg]) == 0, "Expect no forward send at last stage" + assert len(self.recvers[fseg.mirror]) == 0, "Expect no backward recv at last stage" + else: + assert len(self.senders[fseg]) == 1, "Expect no forward send at last stage" + assert len(self.recvers[fseg.mirror]) == 1, "Expect no forward send at last stage" + self.sfadapter[devid] = None if sid == self.num_stages - 1 else self.senders[fseg][0] + self.rbadapter[devid] = None if sid == self.num_stages - 1 else self.recvers[fseg.mirror][0] # stage id - self.stage_id[devid] = gid + self.stage_id[devid] = sid + return self.graph def kwargs(self, devid: int) -> Dict[str, IRCell]: @@ -63,7 +82,7 @@ def kwargs(self, devid: int) -> Dict[str, IRCell]: return kwargs for runtime caller """ return dict( - segment = self.segment[devid], + segment = self.fsegments[devid], sfadapter = self.sfadapter[devid], rfadapter = self.rfadapter[devid], sbadapter = self.sbadapter[devid], diff --git a/cube/graph/schedule/strategy.py b/cube/graph/schedule/strategy.py index 852fc889..ab660f27 100644 --- a/cube/graph/schedule/strategy.py +++ b/cube/graph/schedule/strategy.py @@ -1,18 +1,25 @@ from typing import Tuple, Dict, Any, List from cube.graph.graph import IRGraph, IRSegment -from cube.ir.adapter.adapter import IRAdapter +from cube.graph.function import IRGraphAnchor +from cube.ir.adapter.adapter import IRAdapter, IRWeightReducer from cube.ir.cten import IRCell -from cube.ir.operator import IRFwOperation class IRScheduleStrategy: - def __init__(self, graph: IRGraph, nmicros: int, devmesh: Tuple[Tuple[int]]) -> None: + def __init__(self, graph: IRGraph, nmicros: int) -> None: self.graph : IRGraph = graph self.nmicros : int = nmicros - self.devmesh: Tuple[Tuple[int]] = devmesh - self.inner_groups: List[IRSegment] = [None] * len(devmesh) - self.cross_groups: List[IRAdapter] = [None] * len(devmesh) + self.devmesh: List[Tuple[int]] = [] + # preprocess before segments + self.pre_process: List[IRCell] = [] + self.segments: List[IRSegment] = [] + # the recver adapters for this segment + self.recvers: Dict[IRSegment, List[IRAdapter]] = dict() + # the sender adapters for this segment + self.senders: Dict[IRSegment, List[IRAdapter]] = dict() + # postprocess after segments + self.post_process: List[IRCell] = [] self.signature: str = '' def apply(self, graph: IRGraph) -> IRGraph: @@ -21,77 +28,20 @@ def apply(self, graph: IRGraph) -> IRGraph: def kwargs(self, device: int) -> Dict[str, Any]: raise NotImplementedError - def segmentation(self): + def mesh(self) -> List[List[int]]: """! - Group operators into segments corresponding to devmesh. - - A greedy grouping is applied for each group given the device mesh. - The non-differentiable adapters need to be moved at the boundary - of device mesh, as the cross group communication. + Group operators into segments corresponding to graph stage """ - def differientiable(node: IRCell) -> bool: - return isinstance(node, IRFwOperation) or \ - (isinstance(node, IRAdapter) and node.forward and node.differentiable) + for segment in self.graph.nodes(): + if isinstance(segment, IRSegment): + self.segments.append(segment) + self.recvers[segment] = [] + self.senders[segment] = [] - devmesh = self.devmesh - inner_groups: List[List[IRCell]] = [[] for _ in range(len(devmesh))] - cross_groups: List[List[IRAdapter]] = [[] for _ in range(len(devmesh))] - sid = 0 - for node in self.graph.nodes(): - if not (isinstance(node, (IRFwOperation, IRAdapter))): - continue - devs = set(node.device) - if differientiable(node): - while sid < len(devmesh) and not devs.issubset(devmesh[sid]): - sid += 1 - assert sid < len(devmesh), f"invalid stategy with graph placement" - inner_groups[sid].append(node) - else: - if not (isinstance(node, IRAdapter) and node.forward): - continue - assert not devs.issubset(devmesh[sid]), f"find a non-differentiable adapter in devmesh: {devmesh[sid]}" - cross_mesh = devmesh[sid] + devmesh[sid+1] if sid < len(devmesh) - 1 else devmesh[sid] - assert devs.issubset(set(cross_mesh)) - cross_groups[sid].append(node) - - # move non-differentiable adapter to the boundary of groups - for igroup, cgroup in zip(inner_groups, cross_groups): - if len(igroup) == 0: - print('warning: find a group with no operator') - continue - last_node: IRCell = igroup[-1] - for fadapter in cgroup[::-1]: - success = self.graph.schedule(fadapter, 'after', last_node) - if fadapter.mirror is not None and last_node.mirror is not None: - success = self.graph.schedule( - fadapter.mirror, 'before', last_node.mirror - ) - if not success: - raise RuntimeError("Fail to schedule non-differentiable adapter to group boundaries") - - # grouping - for gid in range(len(devmesh)): - # group computation groups - igroup = inner_groups[gid] - if len(igroup) != 0: - fsegment = self.graph.group(igroup) - bnodes = [n.mirror for n in igroup[::-1] if n.mirror is not None] - if len(bnodes) != 0: - bsegment = self.graph.group(bnodes) - IRCell.make_pair(fsegment, bsegment) - self.inner_groups[gid] = fsegment - else: - self.inner_groups[gid] = None - # merge cross communication adapters - cgroup = cross_groups[gid] - if len(cgroup) == 1: - self.cross_groups[gid] = cgroup[0] - elif len(cgroup) > 1: - fadapter = IRAdapter.merge(cgroup) - bnodes = [n.mirror for n in igroup[::-1] if n.mirror is not None] - if len(bnodes) != 0: - badapter = IRAdapter.merge(bnodes) - IRCell.make_pair(fadapter, badapter) - self.cross_groups[gid] = fadapter - else: - self.cross_groups[gid] = None + for adapter in self.graph.nodes(): + if isinstance(adapter, IRAdapter): + for segment in self.segments: + if self.graph.depends(adapter, segment): + self.recvers[segment].append(adapter) + elif self.graph.depends(segment, adapter): + self.senders[segment].append(adapter) From d8809f92b0094c1ab1cbe0bd64b16563b7827832 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 8 Sep 2022 17:15:07 +0800 Subject: [PATCH 1002/1892] fix hierarchical policy --- cube/execplan/execplan.py | 1 + cube/execplan/planpass/fusion.py | 38 +++++++------ cube/graph/gener/gen.py | 93 +++++++++++++++++++------------- cube/graph/graph.py | 4 +- cube/graph/segment.py | 25 ++++----- cube/ir/cten.py | 2 +- cube/ir/tensor.py | 31 ++++++++--- examples/nlp/gpt/policy/mpmd.py | 49 ++++++++++------- 8 files changed, 146 insertions(+), 97 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 67ea9112..d8ddb2c2 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -43,6 +43,7 @@ def __init__(self, graph: IRGraph): bidx = self.at(devid).index(fnode.mirror) nodes.remove(fnode.mirror) self.at(devid)[bidx] = fnode_dev.mirror + assert fnode_dev.mirror is not None, f"Find None:\n{fnode_dev}" @property def graph(self) -> IRGraph: diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index ca6e0d61..74881138 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -25,25 +25,31 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: cnt = 0 for devid in execplan.devices(): for node in execplan.seq(devid): - if isinstance(node, IRAdapter): - if node.forward: - ret = DiffFusion.nnfuse(node) - cnt = cnt+1 if ret else cnt - if isinstance(node, IRSegment) and node.isfw(): - for fnode in node.nodes(): - if isinstance(fnode, IRAdapter): - if node.forward: - ret = DiffFusion.nnfuse(fnode) - if not ret: - raise NotImplementedError( - f"adapter within IRSegment cannot fuse to differientiable adapter" - f"\nforward: {fnode.extra_repr()}" - f"\nbackward: {fnode.mirror.extra_repr()}" - ) - cnt = cnt + 1 + if isinstance(node, IRAdapter) and node.forward: + ret = DiffFusion.nnfuse(node) + cnt = cnt+1 if ret else cnt + elif isinstance(node, IRSegment) and node.isfw(): + cnt += DiffFusion._apply(node) print(f'successfully generate {cnt} differentiable adapters') return execplan + @staticmethod + def _apply(segment: IRSegment) -> int: + cnt = 0 + for node in segment.nodes(): + if isinstance(node, IRAdapter) and node.forward: + ret = DiffFusion.nnfuse(node) + if not ret and not node.differentiable: + raise NotImplementedError( + f"Adapter within IRSegment cannot fuse to differientiable adapter" + f"\nforward: {node.extra_repr()}" + f"\nbackward: {node.mirror.extra_repr()}" + ) + cnt = cnt + 1 + elif isinstance(node, IRSegment) and node.isfw(): + cnt += DiffFusion._apply(node) + return cnt + @staticmethod def nnfuse(fadapter: IRAdapter) -> bool: """ diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 0ea7adda..35f79f32 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,39 +1,25 @@ import itertools from typing import Dict, List, Optional, Tuple -import copy from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener - from cube.graph.graph import IRGraph from cube.graph.segment import IRSegment + from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor - -from cube.ir.operator import IRBpOperation, IRFwOperation +from cube.ir.operator import IRFwOperation from cube.ir.adapter import IRAdapter, IRWeightReducer from cube.graph.function.function import Add, Cat, Identity, MultiRef -def to_device(tensor: IRSubTensor, device: int) -> IRFwOperation: - """ - This is used for changing tensor device - """ - fwop = IRFwOperation('dummy', 'dummpy', 1, 0) - fwop.set_input(0, tensor) - fwop.device = device - otensor = fwop.input(0) - otensor.grad = copy.copy(tensor.grad) - if isinstance(otensor.grad, IRSubTensor): - otensor.grad.cell = fwop - return otensor - - class DummyInputOuput(IRFwOperation): - def __init__(self, tensor: IRSubTensor, device: int, is_input=False, is_output=False): - super().__init__('dummy', '', + def __init__(self, tensor: IRSubTensor, device: int, + is_input=False, is_output=False, + name='dummy'): + super().__init__(name, name, 1 if is_input else 0, 1 if is_output else 0 ) @@ -60,19 +46,48 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: for devid in devices: for tensor in segment.inputs(): if not isinstance(tensor, IRSubTensor): continue - assert tensor.valmap == (0, 1) + assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" fwop = DummyInputOuput(tensor, devid, is_output=True) segment.insert(fwop, 0) fwops.append(fwop) for tensor in segment.outputs(): if not isinstance(tensor, IRSubTensor): continue - assert tensor.valmap == (0, 1) + assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" fwop = DummyInputOuput(tensor, devid, is_input=True) segment.insert(fwop, segment.nnodes) fwops.append(fwop) return fwops +def expand_devices(tensors: List[IRSubTensor], + producer: bool = False, consumer: bool = False) -> List[IRSubTensor]: + """ + Scatter a tensor if it is on multiple devices. It produces a tensor list where + each tensor is attached to one device, with tensor itself is replicated. + + @param tensors List[IRSubTensor]: each tensor can be on multiple devices. + @param producer bool: if the tensor is producer role + @param consumer bool: if the tensor is consumer role + + @return dtensors List[IRSubTensor]: each tensor is on one device + """ + dtensors = [] + for tensor in tensors: + if len(tensor.device) == 1: + dtensors.append(tensor) + continue + for devid in tensor.device: + if producer: + fwop = DummyInputOuput(tensor, devid, is_output=True, name=tensor.cell.name) + dtensors.append(fwop.output(0)) + elif consumer: + fwop = DummyInputOuput(tensor, devid, is_input=True, name=tensor.cell.name) + dtensors.append(fwop.input(0)) + else: + raise ValueError("At least one of producer or consumer") + return dtensors + + class IRAdapterGener: @staticmethod @@ -194,20 +209,24 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # producers can be operators and graph inputs fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) + fptensors = expand_devices(fptensors, producer=True) assert all(len(ptensor.device) == 1 for ptensor in fptensors), "Not support for multi-device" fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) + fctensors = expand_devices(fctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" bproducers, bptensors = [], [] bconsumers, bctensors = [], [] if (ftensor not in skip_grads) and isinstance(ftensor.grad, IRFullTensor): bproducers, bptensors = bgraph.producers(ftensor.grad), bgraph.ptensors(ftensor.grad) + bptensors = expand_devices(bptensors, producer=True) assert all(len(ptensor.device) == 1 for ptensor in bptensors), ( f"Not support for multi-device:\n" f"{[ptensor.device for ptensor in bptensors]}" f"{[ptensor.cell for ptensor in bptensors]}" ) bconsumers, bctensors = bgraph.consumers(ftensor.grad), bgraph.ctensors(ftensor.grad) + bctensors = expand_devices(bctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" if skip(fptensors, fctensors) and skip(bptensors, bctensors): @@ -292,15 +311,14 @@ def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTens tensor_map: Dict[int, Dict[IRSubTensor, IRSubTensor]] = dict() for tensor in graph.ptensors(ftensor): - assert len(tensor.device) == 1 - devid = tensor.device[0] - if devid not in devtensors: - devtensors[devid] = [] - fuse_tensors[devid] = dict() - tensor_map[devid] = dict() - devtensors[devid].append(tensor) - fuse_tensors[devid][tensor] = [tensor] - tensor_map[devid][tensor] = tensor + for devid in tensor.device: + if devid not in devtensors: + devtensors[devid] = [] + fuse_tensors[devid] = dict() + tensor_map[devid] = dict() + devtensors[devid].append(tensor) + fuse_tensors[devid][tensor] = [tensor] + tensor_map[devid][tensor] = tensor nodes: List[IRFwOperation] = [] for devid, tensors in devtensors.items(): @@ -417,13 +435,12 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): # collect to consumer tensors of each device devtensors: Dict[int, Dict[IRSubTensor, List[IRCell]]] = dict() for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): - assert len(ctensor.device) == 1 - devid = ctensor.device[0] - if devid not in devtensors: - devtensors[devid] = dict() - if ctensor not in devtensors[devid]: - devtensors[devid][ctensor] = [] - devtensors[devid][ctensor].append(consumer) + for devid in ctensor.device: + if devid not in devtensors: + devtensors[devid] = dict() + if ctensor not in devtensors[devid]: + devtensors[devid][ctensor] = [] + devtensors[devid][ctensor].append(consumer) # restrict each device has same subtensor nl = '\n' diff --git a/cube/graph/graph.py b/cube/graph/graph.py index b22475e1..8eec99b3 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -174,9 +174,6 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: f"Internal Error: backward nodes are not consecutive. maxbidx: {maxbidx}, minbidx: {minbidx}" fsegment = fgraph.create_segment(fnodes) - bsegment = bgraph.create_segment(bnodes) if len(bnodes) > 0 else None - IRCell.make_pair(fsegment, bsegment) - # replace forward for fnode in fnodes: fidx = fgraph.remove(fnode) @@ -184,6 +181,7 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: # replace backward if len(bnodes) > 0: + bsegment = fgraph.create_bwop(fsegment) if len(bnodes) > 0 else None for bnode in bnodes: bidx = bgraph.remove(bnode) bgraph.insert(bsegment, bidx) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 848d4251..d8032ef0 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -74,8 +74,6 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR super().__init__(name, '', len(inputs), len(outputs), init_outputs=False) self._nodes: List[IRCell] = [] - self._idevice = [t.device for t in inputs] - self._odevice = [t.device for t in outputs] for idx, val in enumerate(inputs): self.set_input(idx, val) @@ -340,17 +338,21 @@ def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: dscp += f'\t{consumer}\n' return dscp - def create_bwop(self, fwop: IRFwOperation) -> IRBpOperation: + def create_bwop(self, fwop: IRFwOperation) -> Union[IRBpOperation, IRBpOperation]: """ Create dummy backward operator for given forward operator """ - assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" + assert isinstance(fwop, (IRFwOperation, IRSegment)), "Expected IRFwOperation" fsegment: IRSegment = self.segment(fwop) fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] igrads = [fsegment.grad(t) if t.requires_grad else None for t in fins] ograds = [fsegment.grad(t) if t.requires_grad else None for t in fous] - bwop = IRBpOperation(ograds, igrads) + if isinstance(fwop, IRFwOperation): + bwop = IRBpOperation(ograds, igrads) + else: + bnodes = [fnode.mirror for fnode in fwop.nodes() if fnode.mirror is not None] + bwop = IRSegment(bnodes, ograds, igrads) IRCell.make_pair(fwop, bwop) return bwop @@ -758,7 +760,6 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: segment = IRSegment(nodes, tuple(inputs), tuple(outputs)) return segment - def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: """ Instantiate the segement to a specific device. @@ -782,15 +783,15 @@ def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: assert len(node.device) == 1 nodes.append(node) for itensor in node.inputs(): - if itensor in self._inputs: + if itensor in self._inputs and itensor not in inputs: inputs.append(itensor) for otensor in node.outputs(): - if otensor in self._outputs: - otensor.append(otensor) - outputs.append(otensor) + if otensor in self._outputs and otensor not in outputs: + outputs.append(otensor) segment = IRSegment(nodes, inputs, outputs, self.name) - if mirror and segment.mirror is not None: - msegment = segment.mirror.dispatch(devid, mirror=False) + segment._id = self.cid + if mirror and self.mirror is not None: + msegment = self.mirror.dispatch(devid, mirror=False) IRCell.make_pair(segment, msegment) return segment diff --git a/cube/ir/cten.py b/cube/ir/cten.py index b52c783e..4ffba98b 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -506,7 +506,7 @@ def is_attr(self) -> bool: """! Check if the tensor is graph attribute. - @return is_attr boolean: True if is graph attribute (buffer or parameter) + @return is_attr boolean: True if is graph attribute (buffer or parameter or gradient of parameter) """ return self._is_attr diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 43ae40d1..a9e6793e 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -16,6 +16,12 @@ 2). for (FwOp) output tensors, gradient SubTensor is: indmap = output.indmap; val is always (0/1) + +Tensor can be graph attributes. In deep learning, these graph attribute tensors +can be + 1) parameters (require gradient), + 2) buffers (not require gradient) + 3) gradient of parameters """ from typing import List, Optional, Union, Tuple, NewType, Dict @@ -289,11 +295,14 @@ def grad(self, val: Optional[Union[IRTensor, float]]): """ int indicates the tensor is the loss tensor. """ - assert isinstance(val, (IRFullTensor, float)) or val is None, f"grad can only be IRFullTensor or None, but got {val}" + if self._requires_grad: + assert isinstance(val, (IRFullTensor, float)) + if isinstance(val, IRFullTensor): + assert val.shape == self.shape + assert val.is_attr() == self.is_attr() + else: + assert val is None self._grad = val - self._requires_grad = False if val is None else True - if isinstance(val, IRFullTensor): - assert val.shape == self.shape, f"IRFullTensor gradient shape mismatch." @property def requires_grad(self): @@ -303,9 +312,13 @@ def requires_grad(self): def requires_grad(self, val: bool): self._requires_grad = val if val and self.grad is None: - self.grad = IRFullTensor(self.shape, 'g' + self.name, False, dtype=self.dtype).as_grad() + grad = IRFullTensor( + self.shape, 'g' + self.name, + requires_grad=False, dtype=self.dtype + ).as_grad(self.is_attr()) + self._grad = grad elif not val and self.grad is not None: - self.grad = None + self._grad = None @property def dtype(self) -> IRDType: @@ -333,6 +346,8 @@ def as_param(self): self.requires_grad = True self._is_attr = True self._is_grad = False + if isinstance(self.grad, IRFullTensor): + self.grad._is_attr = True def as_buffer(self): """ @@ -342,9 +357,9 @@ def as_buffer(self): self._is_attr = True self._is_grad = False - def as_grad(self): + def as_grad(self, of_attr: bool = False): + self._attr = True if of_attr else False self.requires_grad = False - self._is_attr = False self._is_grad = True return self diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index 71400eb7..7ce693d5 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -2,6 +2,7 @@ import numpy as np from cube.graph import IRGraph +from cube.graph.segment import IRSegment from cube.graph.function.anchor import IRGraphAnchor from cube.ir.cten import IRCell from cube.ir.operator import IRDataOperation, IRFwOperation @@ -98,29 +99,33 @@ def PAS1F1B(graph: IRGraph, resource): """ 1F1B scheduling """ - num_stage = resource.ngpus + num_stages = resource.ngpus num_microbatch = resource.ngpus * 8 - _, stage_mesh = _create_mesh(resource.ngpus, (num_stage, 1)) - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - # group to transformer layers + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] transformers = _group_to_transformers(fnodes) # staging + fstages = [[] for _ in range(num_stages)] nlayer_per_stage = (len(transformers) // resource.ngpus) for lid, fnodes in enumerate(transformers): - stage_id = min(lid // nlayer_per_stage, num_stage-1) - print(f'assigning {lid}-th transformer layter to stage {stage_id}') - for fnode in fnodes: - graph.assign(fnode, stage_id) + stage_id = min(lid // nlayer_per_stage, num_stages - 1) + fstages[stage_id] += fnodes + graph.staging(tuple(stages[0] for stages in fstages)) + + # stage to device + fsegments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] + assert len(fsegments) == num_stages + for devid, segment in enumerate(fsegments): + graph.assign(segment, devid) for node in graph.nodes(): if isinstance(node, IRDataOperation): graph.assign(node, 0) - strategy = IRSchedule1F1B(graph, num_microbatch, stage_mesh) - graph.sched = strategy + strategy = IRSchedule1F1B(graph, num_microbatch) + graph.predef_sched(strategy) return graph @@ -140,17 +145,23 @@ def PASMegatron(graph: IRGraph, resource): print(f'pp groups: {pp_groups}') print(f'tp groups: {tp_groups}') - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - # group to transformer layers + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] transformers = _group_to_transformers(fnodes) - # staging - nlayer_per_stage = (len(transformers) // pp_size) + # inter-staging: set each stage operators + fstages = [[] for _ in range(pp_size)] + nlayer_per_stage = (len(transformers) // resource.ngpus) for lid, fnodes in enumerate(transformers): - sid = min(lid // nlayer_per_stage, pp_size-1) - print(f'assigning {lid}-th transformer layer to stage {sid}: {tp_groups[sid]}') - for fnode in fnodes: + stage_id = min(lid // nlayer_per_stage, pp_size - 1) + fstages[stage_id] += fnodes + graph.staging(tuple(stages[0] for stages in fstages)) + + # intra-stage: tp and dp parallelism on device group + fsegments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] + assert len(fsegments) == pp_size + for sid, segment in enumerate(fsegments): + for fnode in segment.nodes(): if fnode.name == 'self_attention': _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) elif fnode.name == 'feedforward': @@ -162,6 +173,6 @@ def PASMegatron(graph: IRGraph, resource): if isinstance(node, IRDataOperation): _replica(graph, node, tp_groups[0]) - strategy = IRSchedule1F1B(graph, num_microbatch, tp_groups) - graph.sched = strategy + strategy = IRSchedule1F1B(graph, num_microbatch) + graph.predef_sched(strategy) return graph From 637ce8ab9ecb54c04084effe5fce1a486e13f644 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 8 Sep 2022 17:31:41 +0800 Subject: [PATCH 1003/1892] ad hoc fix embed to match same computation cost --- cube/graph/function/function.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index df208b0c..9d2293e4 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -853,7 +853,19 @@ def Embedding(signature, inputs: List): ] oshapes = [ishapes[0] + [ishapes[1][-1]]] anno = OpAnno.create_op_str(ishapes, oshapes) - return IRDimops(signature, [anno], [itensor, weight], 'embedding', padding_idx=padding_idx, start=start, stop=stop) + + def embed_modifer(kwargs: Dict, idx, dim, num): + import warnings + warnings.warn('FIXME: The semantic is error when split embedding, but the computation cost is same.') + kwargs = dict(**kwargs) + kwargs['stop'] = kwargs['stop'] // num + return kwargs + rules = [TransformRule( + [DimopSplit.R(), DimopSplit.D(0)], [DimopSplit.V()], embed_modifer + )] + + return IRDimops(signature, [anno], [itensor, weight], 'embedding', rules, + padding_idx=padding_idx, start=start, stop=stop) def Flatten(signature, inputs: List): From 2ef2d0e207faa5a8e09c80322cf07b178c4e31a9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 8 Sep 2022 18:47:53 +0800 Subject: [PATCH 1004/1892] add compiler option to skip content load --- cube/compiler.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index efef2e31..9992350c 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -23,7 +23,8 @@ def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, - PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, override = True): + PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, + override = True, load_content = True) -> Callable: """ AI Scientist calls like: @@ -45,10 +46,16 @@ def train_step(model, dataloader): ... - Args: - model: AI Scientist specified SemanticModel - dataloader: dataloader used for training - policy: tuple of transformation policy and scheduling policy + @param model SemanticModel: AI Scientist specified SemanticModel + @param dataloader CubDataLoader: dataloader used for training + @param policy Callable: policy to transform and schedule graph + @param override bool: If true, the generated code will override exsisting + files (if they are already existed.), otherwise, use the already existed + generated code, i.e., the policy won't take effect. Default true. + @param load_content bool: If true, will load parameter from exsiting saved models. + Otherwise, will initial model parameters with empty tensor. + + @return sched_fn Callable: the scheduling function loaded from generated code. """ if not isinstance(model, SemanticModel): raise TypeError("Expect Semantic Model") @@ -89,7 +96,7 @@ def decorator(fn: Callable) -> Callable: print('warning: dataloader batch size stay as default.') # load module code print_each_rank(f'loading existed module from {filename} ...') - model.load_module(filename) + model.load_module(filename, load_content=load_content) # load schedule code print_each_rank(f'loading existed schedule from {filename} ...') return _load_tschedule_fn(filename) @@ -210,7 +217,7 @@ def decorator(fn: Callable) -> Callable: # load module filename = filename.format(myrank) print_each_rank(f'loading generated module from {filename} ...') - model.load_module(filename) + model.load_module(filename, load_content=load_content) # load temporal schedule print_each_rank(f'loading generated schedule from {filename} ...') return _load_tschedule_fn(filename) From 8476fddd74dcf47e108da88441986dec1fa11bee Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 9 Sep 2022 15:21:25 +0800 Subject: [PATCH 1005/1892] standard function creation for dimops --- cube/graph/function/cat.py | 77 --------- cube/graph/function/dimops.py | 17 +- cube/graph/function/function.py | 287 ++++++++++++++++++-------------- cube/graph/parser/mapping.py | 21 +-- cube/graph/parser/parser.py | 3 +- cube/graph/parser/register.py | 2 +- 6 files changed, 182 insertions(+), 225 deletions(-) delete mode 100644 cube/graph/function/cat.py diff --git a/cube/graph/function/cat.py b/cube/graph/function/cat.py deleted file mode 100644 index 72b97799..00000000 --- a/cube/graph/function/cat.py +++ /dev/null @@ -1,77 +0,0 @@ -from copy import copy -import itertools -from typing import List, Tuple - -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor - -class IRCat(IRFwOperation): - def __init__(self, signature: str, inputs: List[IRTensor], name: str, - **kwargs): - # torch.cat(inputs:List[Tensor], dim:int) -> Tensor - # REMARK: the input to 'cat' is a tensor list, so 'inputs' parameter directly reflects the singleton list containing that list, - # so the meaning of param 'inputs' is sligtly different from other IRXXXOp. - assert len(inputs) > 0, "TODO handle zero inputs" - assert len(kwargs) == 1, "Expected 1 kwargs: dim" - - super().__init__(name, signature, len(inputs), 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update(kwargs) - - def infer_shape(self) -> bool: - """ - Output shape inference given the input shapes - """ - dim = self.kwargs['dim'] - - # validation - # TODO how about zero inputs? - tensors : Tuple[IRTensor, ...] = self.inputs() # None for all inputs - - # Shape without the dim-th component - s0 : list = None - for i, tensor in enumerate(tensors): - s : list = copy(tensor.shape) # avoid mutating the original shape - - if len(s) == 0: - # Any shape unknown - return False - - s.pop(dim) - if i == 0: - s0 = s - else: - if s != s0: - # Inconsistent input shape - return False - - sumLen : int = sum(t.shape[dim] for t in tensors) - s0.insert(dim, sumLen) - self.output(0).shape = s0 - return True - - -class IRStack(IRFwOperation): - def __init__(self, signature: str, inputs: List[IRTensor], name: str, dim: int): - # torch.stack(inputs:List[Tensor], dim:int) -> Tensor - assert len(inputs) > 0 - - super().__init__(name, signature, len(inputs), 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update({"dim": dim}) - - def infer_shape(self) -> bool: - dim = self.kwargs['dim'] - tensors : Tuple[IRTensor, ...] = self.inputs() # None for all inputs - - # `stack` requires all input tensors to have the same shape - if len(set(t.shape for t in tensors)) != 1: - return False - - shp : list = tensors[0].shape.copy() - shp.insert(dim, len(tensors)) - self.output(0).shape = shp - return True - diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index d0b148ed..419bce16 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -555,9 +555,11 @@ class IRDimops(IRFwOperation): """ Einstein-inspired notation operations """ - def __init__(self, signature: str, annos: Tuple[str], - inputs: List[IRTensor], name: str, - transform_rules: Optional[Tuple[TransformRule]] = None, **kwargs): + def __init__(self, create_fn: Callable, name: str, + signature: str, annos: Tuple[str], + inputs: List[IRTensor], + transform_rules: Optional[Tuple[TransformRule]] = None, + **kwargs): """! Create a IRDimops @@ -574,6 +576,7 @@ def __init__(self, signature: str, annos: Tuple[str], self._iannos: List[ShapeAnno] = None self._oannos: List[ShapeAnno] = None self._trans_rules: Tuple[TransformRule] = tuple(transform_rules) if transform_rules is not None else () + self._create_fn: Tuple[Callable] = (create_fn,) for anno in self._annos_candidates: anno = OpAnno(anno) @@ -666,9 +669,11 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): @return op IRDimop: the new constructed operator """ - annos = self._annos_candidates - rules = self._trans_rules - op = IRDimops(self.signature, annos, inputs, self.name, rules, **kwargs) + inputs = inputs + [kwargs[key] for key in kwargs.keys()] + op = self._create_fn[0](self.signature, inputs) + # annos = self._annos_candidates + # rules = self._trans_rules + # op = IRDimops(self.signature, annos, inputs, self.name, rules, **kwargs) for idx, output in enumerate(outputs): op.set_output(idx, output) return op diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 9d2293e4..d5ae8a47 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1,18 +1,17 @@ -from typing import Any, List, Optional, Tuple, Dict, Union +from typing import Any, Callable, List, Optional, Tuple, Dict, Union import string import copy import torch import warnings +import operator from cube.ir.cten import IRTensor -from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRSubTensor from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRConv2D from cube.graph.function.conv import IRConv3D from cube.graph.function.pad import IRPad from cube.graph.function.scripteinops import IRScriptEinOps -from cube.graph.function.customops import IRCustomOps from cube.graph.function.creators import IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor from cube.graph.function.select import IRSelect, IRSlice from cube.graph.function.scatter import IRSelectScatter @@ -26,25 +25,28 @@ def Identity(signature, inputs: List[IRTensor]): signature = 'cube.runtime.function.identity' eshape = ShapeAnno.create_shape_str(inputs[0].shape) anno = OpAnno.create_op_str([eshape], [eshape]) - return IRDimops(signature, [anno], inputs, 'identity') + return IRDimops(Identity, 'identity', signature, [anno], inputs) def Linear(signature, inputs): + assert len(inputs) == 3 signature = 'torch.nn.functional.linear' - annos = [ - 'b * k+, n k+ -> b * n', # no bias - 'b * k+, n k+, n -> b * n' # have bias - ] if inputs[2] is None: - inputs = inputs[0:2] - return IRDimops(signature, annos, inputs, 'linear') + annos = ['b * k+, n k+ -> b * n'] + return IRDimops(Linear, 'linear', signature, annos, inputs[:2], bias=None) + else: + annos = ['b * k+, n k+, n -> b * n'] + rules = [TransformRule( + [DimopSplit.D(-1), DimopSplit.D(1), DimopSplit.V()], [DimopSplit.V()] + )] + return IRDimops(Linear, 'linear', signature, annos, inputs, rules) def BatchLinear(signature, inputs): annos = [ - 'b m k, b k n -> b m n' + 'b m k+, b k+ n -> b m n' ] - return IRDimops(signature, annos, inputs, 'bmm') + return IRDimops(BatchLinear, 'bmm', signature, annos, inputs) def Zeros(signature, @@ -257,8 +259,7 @@ def Add(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(signature, annos, inputs, 'add', **kwargs) - + return IRDimops(Add, 'add', signature, annos, inputs, **kwargs) def Sub(signature, inputs): @@ -286,7 +287,7 @@ def Sub(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(signature, annos, inputs, 'sub', **kwargs) + return IRDimops(Sub, 'sub', signature, annos, inputs, **kwargs) def Mul(signature, inputs): @@ -302,7 +303,7 @@ def Mul(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(signature, annos, inputs, 'mul') + return IRDimops(Mul, 'mul', signature, annos, inputs) def Div(signature, inputs): @@ -319,7 +320,7 @@ def Div(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(signature, annos, inputs, 'div') + return IRDimops(Div, 'div', signature, annos, inputs) def FloorDiv(signature, inputs): @@ -335,7 +336,7 @@ def FloorDiv(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(signature, annos, inputs, 'floordiv') + return IRDimops(FloorDiv, 'floordiv', signature, annos, inputs) def Pow(signature, inputs): @@ -351,48 +352,20 @@ def Pow(signature, inputs): if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): lshape, rshape, oshape = _handle_broadcast(lhs, rhs) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(signature, annos, inputs, 'pow') - - -# if both operands are scalars, returns bool. -# if one operand is a tensor, returns a broadcasted tensor with dtype being bool. -def comparison_einops(f, name, signature, inputs): - # f : (Scalar, Scalar) -> bool - assert len(inputs) == 2 - lhs, rhs = inputs - - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): - return f(lhs, rhs) - - annos = [ - '*, ? -> *', - '?, * -> *', - ] - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): - lshape, rshape, oshape = _handle_broadcast(lhs, rhs) - annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(signature, annos, inputs, name) + return IRDimops(Pow, 'pow', signature, annos, inputs) def Neg(signature, inputs): - if len(inputs) == 1: - kwargs = {} - elif len(inputs) == 2: - # adapt for newest pytorch version - approximate = inputs[1] - kwargs = {'approximate': approximate} - - inputs = inputs[0:1] - else: - raise RuntimeError("The number of inputs must be 1 or 2") + assert len(inputs) == 1 or len(inputs) == 2 + kwargs = {} if len(inputs) == 1 else {'approximate': inputs[1]} + tensors = inputs[0:1] - arg, = inputs - if isinstance(arg, (int, float)): + if isinstance(tensors[0], (int, float)): assert not('approximate' in kwargs) - return -arg + return -tensors[0] annos = ['* -> *'] - return IRDimops(signature, annos, inputs, 'neg', **kwargs) + return IRDimops(Neg, 'neg', signature, annos, inputs, **kwargs) def Sin(signature, inputs): annos = ['* -> *'] @@ -400,10 +373,10 @@ def Sin(signature, inputs): if len(inputs) == 2: # adapt for newest pytorch version approximate = inputs[1] - return IRDimops(signature, annos, tensor, 'sin', + return IRDimops(Sin, 'sin', signature, annos, tensor, approximate=approximate) else: - return IRDimops(signature, annos, tensor, 'sin') + return IRDimops(Sin, 'sin', signature, annos, tensor) def Cos(signature, inputs): @@ -412,10 +385,10 @@ def Cos(signature, inputs): if len(inputs) == 2: # adapt for newest pytorch version approximate = inputs[1] - return IRDimops(signature, annos, tensor, 'cos', + return IRDimops(Cos, 'cos', signature, annos, tensor, approximate=approximate) else: - return IRDimops(signature, annos, tensor, 'cos') + return IRDimops(Cos, 'cos', signature, annos, tensor) def GeLU(signature, inputs): @@ -425,34 +398,35 @@ def GeLU(signature, inputs): if len(inputs) == 2: # adapt for newest pytorch version approximate = inputs[1] - return IRDimops(signature, annos, tensor, 'gelu', + return IRDimops(GeLU, 'gelu', signature, annos, tensor, approximate=approximate) else: - return IRDimops(signature, annos, tensor, 'gelu') + return IRDimops(GeLU, 'gelu', signature, annos, tensor) def SiLU(signature, inputs): + assert len(inputs) == 1 annos = ['* -> *'] signature = 'torch.nn.functional.silu' tensor = inputs[0:1] - return IRDimops(signature, annos, tensor, 'silu') + return IRDimops(SiLU, 'silu', signature, annos, tensor) def Softmax(signature, inputs): + assert len(inputs) == 4 annos = ['* -> *'] tensor = inputs[0:1] dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] - return IRDimops(signature, annos, tensor, 'softmax', + return IRDimops(Softmax, 'softmax', signature, annos, tensor, dim=dim, _stacklevel=_stacklevel, dtype=dtype) def Dropout(signature, inputs): - annos = [ - '* -> *' - ] + assert len(inputs) == 4 + annos = ['* -> *'] tensor = inputs[0:1] p, training, inplace = inputs[1], inputs[2], inputs[3] - return IRDimops(signature, annos, tensor, 'dropout', + return IRDimops(Dropout, 'dropout', signature, annos, tensor, p=p, training=training, inplace=inplace) @@ -464,31 +438,47 @@ def LayerNorm(signature, inputs): f'N *, ?, {normalized_shape[0]}, {normalized_shape[0]} -> N *', f'N *, ?, ?, ? -> N *' ] - return IRDimops(signature, annos, [input, normalized_shape, weight, bias], - 'layernorm', eps=eps) + return IRDimops(LayerNorm, 'layernorm', signature, annos, [input, normalized_shape, weight, bias], + eps=eps) def Sum(signature, inputs): + """ + torch.sum(input, *, dtype=None) -> Tensor + torch.sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + """ + assert len(inputs) == 2 or len(inputs) == 4, f"{inputs}" tensor = inputs[0] - dim = inputs[1] einput = ShapeAnno.create_shape_str(tensor.shape) eoutput = copy.copy(einput) - if dim is not None: - keepdim = inputs[2] - sort_dim = list(dim) - sort_dim.sort() - for dimidx in sort_dim[::-1]: - eoutput.pop(dimidx) - einput[dimidx] = einput[dimidx] + '+' - else: + if len(inputs) == 2: + dtype = inputs[1] + assert dtype is None, "Currently Sum only support dtype=None" + # torch.sum(input) + inputs = [tensor] eoutput = ['1'] - # every dimension is reduced + # every dimension can be reduced einput = [edim + '+' for edim in einput] - anno = OpAnno.create_op_str([einput], [eoutput]) - if dim is not None: - return IRDimops(signature, [anno], [tensor], 'sum', dim=dim, keepdim=keepdim) + anno = OpAnno.create_op_str([einput], [eoutput]) + return IRDimops(Sum, 'sum', signature, [anno], [tensor], dtype=dtype) else: - return IRDimops(signature, [anno], [tensor], 'sum') + # torch.sum(input, dim, keepdim, *, dtype) + dim, keepdim, dtype = inputs[1:4] + assert dtype is None, "Currently Sum only support dtype=None" + assert isinstance(dim, list), f"Expect dim to be list but got: {dim}" + for dimidx in dim: + einput[dimidx] += '+' + if keepdim: + for dimidx in dim: + eoutput[dimidx] = '1' + else: + sort_dim = list(dim) + sort_dim.sort() + for dimidx in sort_dim[::-1]: + eoutput.pop(dimidx) + anno = OpAnno.create_op_str([einput], [eoutput]) + return IRDimops(Sum, 'sum', signature, [anno], [tensor], dim=dim, keepdim=keepdim, dtype=dtype) + def Mean(signature, inputs): tensor = inputs[0] @@ -508,9 +498,9 @@ def Mean(signature, inputs): einput = [edim + '+' for edim in einput] anno = OpAnno.create_op_str([einput], [eoutput]) if dim is not None: - return IRDimops(signature, [anno], [tensor], 'mean', dim=dim, keepdim=keepdim) + return IRDimops(Mean, 'mean', signature, [anno], [tensor], dim=dim, keepdim=keepdim) else: - return IRDimops(signature, [anno], [tensor], 'mean') + return IRDimops(Mean, 'mean', signature, [anno], [tensor]) def Transpose(signature, inputs): @@ -525,7 +515,7 @@ def Transpose(signature, inputs): edim_ou[dim0], edim_ou[dim1] = edim_ou[dim1], edim_ou[dim0] anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(signature, [anno], [input], 'transpose', + return IRDimops(Transpose, 'transpose', signature, [anno], [input], dim0=dim0, dim1=dim1) @@ -672,7 +662,7 @@ def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: anno = OpAnno.create_op_str([in_anno], [ou_anno]) signature = 'torch.Tensor.view' - return IRDimops(signature, [anno], [input], 'view', rules, size=tuple(shape)) + return IRDimops(View, 'view', signature, [anno], [input], rules, size=tuple(shape)) def Reshape(signature, inputs): @@ -772,10 +762,16 @@ def Pad(signature, inputs): def Cat(signature, inputs: Tuple[List[IRTensor], int]): """ torch.cat(inputs: List[Tensor], dim: int) -> Tensor + torch.cat(tensor1: Tensor, tensor2: Tensor, ..., dim: int) e.g. cat(tensor([2,3]), tensor([2,3])).shape == [4,3] """ - tensors, dim = inputs + assert len(inputs) >= 2 + if len(inputs) == 2: + tensors, dim = inputs + else: + tensors, dim = inputs[:-1], inputs[-1] + assert all(isinstance(tensor, IRTensor) for tensor in tensors) iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] dimlens = [t.shape[dim] for t in tensors] for ashape, dimlen in zip(iannos, dimlens): @@ -783,12 +779,13 @@ def Cat(signature, inputs: Tuple[List[IRTensor], int]): oannos = [copy.copy(iannos[-1])] oannos[0][dim] = str(sum(dimlens)) anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(signature, [anno], tensors, 'cat', dim=dim) + return IRDimops(Cat, 'cat', signature, [anno], tensors, dim=dim) def Stack(signature, inputs: Tuple[List[IRTensor], int]): """ torch.stack(inputs: List[Tensor], dim: int) -> Tensor + torch.stack(tensor1: Tensor, tensor2: Tensor, ..., dim: int) -> Tensor inputs: tensors: List[Tensor]: all tensors need to have same size @@ -796,12 +793,17 @@ def Stack(signature, inputs: Tuple[List[IRTensor], int]): e.g. stack(tensor([2,3]), tensor([2,3])).shape == [2,2,3] """ - tensors, dim = inputs + assert len(inputs) >= 2 + if len(inputs) == 2: + tensors, dim = inputs + else: + tensors, dim = inputs[:-1], inputs[-1] + assert all(isinstance(tensor, IRTensor) for tensor in tensors) iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] oannos = [copy.copy(iannos[-1])] oannos[0].insert(dim, str(len(tensors))) anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(signature, [anno], tensors, 'stack', dim=dim) + return IRDimops(Stack, 'stack', signature, [anno], tensors, dim=dim) def Select(signature, inputs: Tuple[IRTensor, int, int]): @@ -845,7 +847,10 @@ def Embedding(signature, inputs: List): signature = 'cube.runtime.function.embedding' itensor, weight = inputs[:2] padding_idx = inputs[3] - start, stop = 0, weight.shape[0] + if isinstance(weight, IRSubTensor): + start, stop = weight.indmap[0] + else: + start, stop = 0, weight.shape[0] letters = iter(string.ascii_lowercase) ishapes = [ ShapeAnno.create_shape_str(itensor.shape, iterator=letters), @@ -854,17 +859,17 @@ def Embedding(signature, inputs: List): oshapes = [ishapes[0] + [ishapes[1][-1]]] anno = OpAnno.create_op_str(ishapes, oshapes) - def embed_modifer(kwargs: Dict, idx, dim, num): - import warnings - warnings.warn('FIXME: The semantic is error when split embedding, but the computation cost is same.') - kwargs = dict(**kwargs) - kwargs['stop'] = kwargs['stop'] // num - return kwargs - rules = [TransformRule( - [DimopSplit.R(), DimopSplit.D(0)], [DimopSplit.V()], embed_modifer - )] - - return IRDimops(signature, [anno], [itensor, weight], 'embedding', rules, + # def embed_modifer(kwargs: Dict, idx, dim, num): + # import warnings + # warnings.warn('FIXME: The semantic is error when split embedding, but the computation cost is same.') + # kwargs = dict(**kwargs) + # kwargs['stop'] = kwargs['stop'] // num + # return kwargs + # rules = [TransformRule( + # [DimopSplit.R(), DimopSplit.D(0)], [DimopSplit.V()], embed_modifer + # )] + + return IRDimops(Embedding, 'embedding', signature, [anno], [itensor, weight], padding_idx=padding_idx, start=start, stop=stop) @@ -878,7 +883,7 @@ def Flatten(signature, inputs: List): oshape = ishape[:start_dim] oshape.append(ishape[start_dim:end_dim+1]) anno = OpAnno.create_op_str([ishape], [oshape]) - return IRDimops(signature, [anno], [tensor], 'flatten', start_dim=start_dim, end_dim=end_dim) + return IRDimops(Flatten, 'flatten', signature, [anno], [tensor], start_dim=start_dim, end_dim=end_dim) def Roll(signature, inputs: Tuple[IRTensor, Union[int, Tuple[int]], Union[int, Tuple[int]]]): @@ -889,7 +894,7 @@ def Roll(signature, inputs: Tuple[IRTensor, Union[int, Tuple[int]], Union[int, T if dims is None or dim in dims: ishape[dim] += '^' anno = OpAnno.create_op_str([ishape], [ishape]) - return IRDimops(signature, [anno], [tensor], 'roll', shifts=shifts, dims=dims) + return IRDimops(Roll, 'roll', signature, [anno], [tensor], shifts=shifts, dims=dims) def AdaptiveAvgPool1d(signature, inputs: Tuple[IRTensor, Tuple[int]]): @@ -899,7 +904,7 @@ def AdaptiveAvgPool1d(signature, inputs: Tuple[IRTensor, Tuple[int]]): ishape[-1] += '^' oshape = ishape[:-1] + [str(size) for size in out_size] anno = OpAnno.create_op_str([ishape], [oshape]) - return IRDimops(signature, [anno], [tensor], 'adaptive_avg_pool1d', output_size=out_size) + return IRDimops(AdaptiveAvgPool1d, 'adaptive_avg_pool1d', signature, [anno], [tensor], output_size=out_size) def CrossEntropy(signature, inputs): @@ -914,7 +919,8 @@ def CrossEntropy(signature, inputs): 'N+ C *, N+ * -> 1' ] return IRDimops( - signature, annos, [tensor, target], 'cross_entropy', + CrossEntropy, 'cross_entropy', + signature, annos, [tensor, target], weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce, reduction=reduction, label_smoothing=label_smoothing ) @@ -930,7 +936,7 @@ def MultiRef(signature, inputs: List[IRTensor]): assert isinstance(itensor, IRTensor), "require all inputs to be IRSubTensor" assert isinstance(times, int), "require int for second input" anno = '* -> ' + ', '.join('*' for _ in range(times)) - node = IRDimops(signature, [anno], [itensor], 'multiref', times=times) + node = IRDimops(MultiRef, 'multiref', signature, [anno], [itensor], times=times) return node @@ -959,23 +965,52 @@ def ScriptEinOps(signature, inputs): return IRScriptEinOps(signature, tensors, 'scripteinops', recipe_str=recipe_str, reduction_type=reduction_type) -def CustomOps(signature, inputs): - if signature == 'examples.custom_ops.strip_2_borders': - tensors = inputs[0:1] - print(f'CustomOps:tensors[0] = {tensors[0]}') - return IRCustomOps(signature, tensors, 'custom_ops') - elif signature == 'examples.custom_ops.update_diag_': - tensors = inputs[0:10] - # dt = inputs[9] - dz = inputs[10] - return IRCustomOps(signature, tensors, 'custom_ops', dz=dz) - elif signature == 'examples.custom_ops.update_geopotential_': - tensors = inputs[0:5] - g = inputs[5] - CPD = inputs[6] - nz = inputs[7] - return IRCustomOps(signature, tensors, 'custom_ops', g=g, CPD=CPD, nz=nz) +def _comparison(creator: Callable, f: Callable, name: str, signature: str, inputs): + """ + if both operands are scalars, returns bool. + if one operand is a tensor, returns a broadcasted tensor with dtype being bool. + + @param creator Callable: the outside creation function + @param f Callable: (Scalar, Scalar) -> bools + """ + assert len(inputs) == 2 + lhs, rhs = inputs + + if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): + return f(lhs, rhs) - else: - import warnings - warnings.warn(f"ERROR Unknown custom op, signature {signature}") + annos = [ + '*, ? -> *', + '?, * -> *', + ] + if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): + lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(creator, name, signature, annos, inputs) + + +def CompareGT(signature, inputs): + """ + torch.gt(input, other, *, out=None) -> Tensor + """ + return _comparison(CompareGT, operator.gt, 'gt', signature, inputs) + + +def CompareLT(signature, inputs): + """ + torch.lt(input, other, *, out=None) -> Tensor + """ + return _comparison(CompareLT, operator.lt, 'lt', signature, inputs) + + +def CompareGE(signature, inputs): + """ + torch.ge(input, other, *, out=None) -> Tensor + """ + return _comparison(CompareGE, operator.ge, 'ge', signature, inputs) + +def CompareLE(signature, inputs): + """ + torch.gt(input, other, *, out=None) -> Tensor + """ + return _comparison(CompareLE, operator.le, 'le', signature, inputs) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 1e90f6e6..35a6257a 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -2,17 +2,12 @@ Mapping of Signature -> IROperator """ -from typing import Any, Callable, Dict, Union -import torch - -import operator +from typing import Callable, Dict, Union from functools import partial import cube.graph.function as function from cube.ir.operator import IRFwOperation -# TODO this is a backwards-compatible alias -from cube.graph.torch_dtype_mapping import DType2IRDType class Sign2Op: @@ -102,10 +97,10 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('neg'): function.Neg, - __ttemplate('gt'): partial(function.comparison_einops, operator.gt, 'gt'), - __ttemplate('lt'): partial(function.comparison_einops, operator.lt, 'lt'), - __ttemplate('ge'): partial(function.comparison_einops, operator.ge, 'ge'), - __ttemplate('le'): partial(function.comparison_einops, operator.le, 'le'), + __ttemplate('gt'): function.CompareGT, + __ttemplate('lt'): function.CompareLT, + __ttemplate('ge'): function.CompareGE, + __ttemplate('le'): function.CompareLE, __ttemplate('pow'): function.Pow, @@ -157,13 +152,11 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __rtemplate('identity'): function.Identity, + __rtemplate('multiref'): function.MultiRef, + #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, - #custom ops - __customops('strip_2_borders'): function.CustomOps, - __customops('update_diag_'): function.CustomOps, - __customops('update_geopotential_'): function.CustomOps, } # customized operator code: signature -> code diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 5e605154..95627ec2 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -7,7 +7,8 @@ from cube.ir.tensor import IRFullTensor import cube.ir as ir from cube.graph.parser.frame import Frame -from cube.graph.parser.mapping import Sign2Op, DType2IRDType +from cube.graph.parser.mapping import Sign2Op +from cube.graph.torch_dtype_mapping import DType2IRDType _refmodule = torch.nn.Module() diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 4e51af77..10dcf9ec 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -63,7 +63,7 @@ def udfop(signature: str, inputs: List[Any]): kwargs = dict() for name, val in zip(kwarg_names, kwarg_vals): kwargs[name] = val - return IRDimops(signature, [repr(manno)], tensors, **kwargs, name=op_name) + return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, **kwargs) print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') Sign2Op.register(fsig, udfop, code) From db0437036b9ea8b110412383865779b379a060d4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 9 Sep 2022 18:30:30 +0800 Subject: [PATCH 1006/1892] add gpt megatron benchmark --- benchmark/megatron/benchmark_gpt.sh | 40 +++++++ benchmark/megatron/pretrain_gpt_synthetic.py | 109 +++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100755 benchmark/megatron/benchmark_gpt.sh create mode 100644 benchmark/megatron/pretrain_gpt_synthetic.py diff --git a/benchmark/megatron/benchmark_gpt.sh b/benchmark/megatron/benchmark_gpt.sh new file mode 100755 index 00000000..1265d0ab --- /dev/null +++ b/benchmark/megatron/benchmark_gpt.sh @@ -0,0 +1,40 @@ + +# get megatron +# git clone https://github.com/NVIDIA/Megatron-LM.git + +cp pretrain_gpt_synthetic.py ./Megatron-LM/ + +GPT_ARGS="--num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --micro-batch-size 8 \ + --global-batch-size 8 \ + --lr 0.00015 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --lr-warmup-fraction .01 \ + --fp16 \ + --fp16-lm-cross-entropy \ + --no-masked-softmax-fusion \ + --no-bias-gelu-fusion \ + --no-bias-dropout-fusion" + +DISTRIBUTED_ARGS="--nproc_per_node 8 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr localhost \ + --master_port 6000" + + +cd Megatron-LM + +OMP_NUM_THREADS=4 python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt_synthetic.py $GPT_ARGS \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --DDP-impl torch + +cd .. \ No newline at end of file diff --git a/benchmark/megatron/pretrain_gpt_synthetic.py b/benchmark/megatron/pretrain_gpt_synthetic.py new file mode 100644 index 00000000..76716e5a --- /dev/null +++ b/benchmark/megatron/pretrain_gpt_synthetic.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pretrain GPT""" + +import torch +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron import mpu +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import GPTModel, ModelType +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + + vocab_size = 50257 + after = vocab_size + multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size + while after % multiple != 0: + after += 1 + args.padded_vocab_size = after + + print_rank_0('building GPT model ...') + model = GPTModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + vocab_size = 50257 + tokens = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size + labels = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size + loss_mask = (torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()) < 0.5).float() + attention_mask = (torch.rand((args.micro_batch_size, 1, args.seq_length, args.seq_length), requires_grad=False, device=torch.cuda.current_device()) < 0.5) + position_ids = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * args.seq_length + + return tokens, labels, loss_mask, attention_mask, position_ids + +def loss_func(loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. + timers('batch-generator').start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + return None, None, None + + +if __name__ == "__main__": + + + + pretrain(train_valid_test_datasets_provider, model_provider, + ModelType.encoder_or_decoder, + forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) + + mem = torch.cuda.max_memory_allocated() + for rank in range(torch.distributed.get_world_size()): + if rank == torch.distributed.get_rank(): + print(f'rank[{rank}]: memory consumption: {round(mem / 1024 / 1024 / 1024 * 100) / 100} GBs') + torch.distributed.barrier() \ No newline at end of file From a0e21dbaff22ac1921f8a68060cef4b42d94fddb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 9 Sep 2022 19:40:28 +0800 Subject: [PATCH 1007/1892] fix un-train bug --- benchmark/megatron/benchmark_gpt.sh | 13 ++++++++----- benchmark/megatron/pretrain_gpt_synthetic.py | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/benchmark/megatron/benchmark_gpt.sh b/benchmark/megatron/benchmark_gpt.sh index 1265d0ab..b786e40e 100755 --- a/benchmark/megatron/benchmark_gpt.sh +++ b/benchmark/megatron/benchmark_gpt.sh @@ -4,25 +4,28 @@ cp pretrain_gpt_synthetic.py ./Megatron-LM/ +GPUS=8 + GPT_ARGS="--num-layers 24 \ --hidden-size 1024 \ --num-attention-heads 16 \ --seq-length 1024 \ --max-position-embeddings 1024 \ - --micro-batch-size 8 \ - --global-batch-size 8 \ + --micro-batch-size 1 \ + --global-batch-size 1 \ --lr 0.00015 \ - --train-iters 500000 \ + --train-iters 200 \ --lr-decay-iters 320000 \ --lr-decay-style cosine \ --lr-warmup-fraction .01 \ --fp16 \ --fp16-lm-cross-entropy \ + --no-query-key-layer-scaling \ --no-masked-softmax-fusion \ --no-bias-gelu-fusion \ --no-bias-dropout-fusion" -DISTRIBUTED_ARGS="--nproc_per_node 8 \ +DISTRIBUTED_ARGS="--nproc_per_node $GPUS \ --nnodes 1 \ --node_rank 0 \ --master_addr localhost \ @@ -33,7 +36,7 @@ cd Megatron-LM OMP_NUM_THREADS=4 python -m torch.distributed.launch $DISTRIBUTED_ARGS \ pretrain_gpt_synthetic.py $GPT_ARGS \ - --tensor-model-parallel-size 8 \ + --tensor-model-parallel-size ${GPUS}\ --pipeline-model-parallel-size 1 \ --DDP-impl torch diff --git a/benchmark/megatron/pretrain_gpt_synthetic.py b/benchmark/megatron/pretrain_gpt_synthetic.py index 76716e5a..7494738c 100644 --- a/benchmark/megatron/pretrain_gpt_synthetic.py +++ b/benchmark/megatron/pretrain_gpt_synthetic.py @@ -91,7 +91,7 @@ def forward_step(data_iterator, model): def train_valid_test_datasets_provider(train_val_test_num_samples): """Build train, valid, and test datasets.""" - return None, None, None + return [1]*10000, None, None if __name__ == "__main__": @@ -106,4 +106,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): for rank in range(torch.distributed.get_world_size()): if rank == torch.distributed.get_rank(): print(f'rank[{rank}]: memory consumption: {round(mem / 1024 / 1024 / 1024 * 100) / 100} GBs') - torch.distributed.barrier() \ No newline at end of file + torch.distributed.barrier() From d42db34204ba85429cad18a8dfae74a81dd7ae2d Mon Sep 17 00:00:00 2001 From: lynex Date: Mon, 12 Sep 2022 16:15:48 +0800 Subject: [PATCH 1008/1892] update PyTorch example: regressive generation of GPT inference, fix TP of PASMegatronInferTP with embeding split and disable dropout for inference --- examples/nlp/blocks/attention.py | 57 +++++++++++++------------------- examples/nlp/blocks/mlp.py | 7 ++-- examples/nlp/gpt/model.py | 2 ++ examples/nlp/gpt/policy/spmd.py | 20 +++++++++++ 4 files changed, 49 insertions(+), 37 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index ff626831..1867a431 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -88,8 +88,6 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, return output -from typing import Optional, Tuple - @cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> l N E^', name='one_attention') def one_attention(hidden_states: torch.Tensor, past_embed_key: torch.Tensor, @@ -98,14 +96,14 @@ def one_attention(hidden_states: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, out_proj: torch.Tensor, #out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = True): + h: int, scale: float, dropout_p: float, is_training: bool = True, mask: bool = True): num_head = h - L, N = hidden_states.size(0), hidden_states.size(1) + l, N = hidden_states.size(0), hidden_states.size(1) dim_head = q_proj.size(0) // num_head - q = torch.nn.functional.linear(hidden_states, q_proj, q_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(hidden_states, k_proj, k_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(hidden_states, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + q = torch.nn.functional.linear(hidden_states, q_proj, q_bias) # l N E, (h d) E -> l N (h d) + k = torch.nn.functional.linear(hidden_states, k_proj, k_bias) # l N E, (h d) E -> l N (h d) + v = torch.nn.functional.linear(hidden_states, v_proj, v_bias) # l N E, (h d) E -> l N (h d) if past_embed_key is not None and past_embed_value is not None: k = torch.cat((past_embed_key, k), dim=-3) @@ -116,34 +114,26 @@ def one_attention(hidden_states: torch.Tensor, k_L = k.size(0) v_L = v.size(0) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(k_L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(v_L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d + q = q.contiguous().view(l, (N * num_head), dim_head) # l N (h d) -> L (N h) d + k = k.contiguous().view(k_L, (N * num_head), dim_head) # (L+l) N (h d) -> (L+l) (N h) d + v = v.contiguous().view(v_L, (N * num_head), dim_head) # (L+l) N (h d) -> (L+l) (N h) d + q = q.transpose(0, 1) # l (N h) d -> (N h) l d + k = k.transpose(0, 1) # (L+l) (N h) d -> (N h) (L+l) d + v = v.transpose(0, 1) # (L+l) (N h) d -> (N h) (L+l) d q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # # attention mask - # if mask: # (N h) L L -> (N h) L L - # attn = attn.view(N, num_head, L, L) - # ones = torch.ones((N, L, L), device=attn.device) - # mask = torch.tril(ones) - # mask = mask.view(N, 1, L, L) - # mask = (mask < 0.5) - # attn = attn.masked_fill_(mask, -10000.0) - # attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj, None) # L N (h d), E E -> L N E + k = k.transpose(1, 2) # (N h) (L+l) d -> (N h) d (L+l) + attn = torch.bmm(q, k) # (N h) l d, (N h) d (L+l) -> (N h) l (L+l) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) l (L+l) -> (N h) l (L+l) + #no dropout in inference attn + torch.nn.functional.dropout(attn, dropout_p, is_training, False) # (N h) l (L+l) -> (N h) l (L+l) + output = torch.bmm(attn, v) # (N h) l (L+l), (N h) (L+l) d -> (N h) l d + output = output.transpose(0, 1).contiguous() # (N h) l d -> l (N h) d + output = output.view(l, N, num_head * dim_head) # l (N h) d -> l N (h d) + output = torch.nn.functional.linear(output, out_proj, None) # l N (h d), E E -> l N E return output + class MultiHeadSelfAttention(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): @@ -236,7 +226,6 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa self.out_proj = torch.nn.Parameter(torch.rand(embed_dim, inner_dim)) self.out_bias = torch.nn.Parameter(torch.rand(embed_dim)) - from typing import Optional, Tuple def forward(self, query: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor): attn = one_attention( @@ -245,7 +234,7 @@ def forward(self, query: torch.Tensor, past_embed_key: torch.Tensor, past_embed_ self.k_proj, self.k_bias, self.v_proj, self.v_bias, self.out_proj, #self.out_bias, - self.num_heads, self.scaling, self.dropout_p, mask=True + self.num_heads, self.scaling, self.dropout_p, self.training, mask=True ) attn = attn + self.out_bias return attn \ No newline at end of file diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py index c9162364..767a54e6 100644 --- a/examples/nlp/blocks/mlp.py +++ b/examples/nlp/blocks/mlp.py @@ -6,10 +6,11 @@ def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj1_bias: torch.Tensor, proj2: torch.Tensor, - dropout: float) -> torch.Tensor: + dropout: float, + is_training: bool = True) -> torch.Tensor: x = torch.nn.functional.linear(x, proj1, proj1_bias) x = torch.nn.functional.gelu(x) - x = torch.nn.functional.dropout(x, dropout, True, False) + x = torch.nn.functional.dropout(x, dropout, is_training, False) x = torch.nn.functional.linear(x, proj2, None) return x @@ -26,6 +27,6 @@ def __init__(self, embed_dim: int, hidden_dim: int, dropout: float): def forward(self, x: torch.Tensor): x = feedforward(x, self.proj1, self.proj1_bias, - self.proj2, self.dropout) + self.proj2, self.dropout, self.training) x = x + self.proj2_bias return x diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 296598d5..b1d0b2fb 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -130,6 +130,7 @@ def __init__(self): def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # embed = self.embed(input_ids) + cube.runtime.function.anchor('first_embed') embed = torch.nn.functional.embedding( input_ids, self.embedw, padding_idx=None, max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False @@ -145,6 +146,7 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): enc = self.final_layernorm(enc) # logits = torch.nn.functional.linear(enc, self.embed.weight) + cube.runtime.function.anchor('last_embed') logits = torch.nn.functional.linear(enc, self.embedw) # simplified loss = torch.sum(logits) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index 95987994..c71e8639 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -102,6 +102,26 @@ def PASMegatronInferTP(graph: IRGraph, resource): for ffn in ffns: _tp(graph, ffn, tp_devs, idx=1, dim=0) + # first embedding linear + first_emb_anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor) and node.name == 'first_embed'] + print(f'last_emd_anchors = {first_emb_anchors}') + indices = [fnodes.index(anchor) for anchor in first_emb_anchors] + for lid, idx in enumerate(indices): + print(f'fnodes[idx+1].name = {fnodes[idx+1].name}') + print(f'fnodes[idx+1] = {fnodes[idx + 1]}') + first_emb_node = fnodes[idx+1] + _tp(graph, first_emb_node, tp_devs, idx=1, dim=0) + + # last embedding linear + last_emb_anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor) and node.name == 'last_embed'] + print(f'last_emd_anchors = {last_emb_anchors}') + indices = [fnodes.index(anchor) for anchor in last_emb_anchors] + for lid, idx in enumerate(indices): + print(f'fnodes[idx+1].name = {fnodes[idx+1].name}') + print(f'fnodes[idx+1] = {fnodes[idx + 1]}') + last_emb_node = fnodes[idx+1] + _tp(graph, last_emb_node, tp_devs, idx=1, dim=0) + # replicate other nodes for node in graph.nodes(): if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: From 253e93dc948249c7b36bc59e84d6ced45c3158e0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 09:30:17 +0800 Subject: [PATCH 1009/1892] fix train dataloader bug --- examples/nlp/gpt/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index b1d0b2fb..b3bc2d2f 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -176,7 +176,7 @@ def random_sample(self): ) position_ids = torch.arange( 0, self.cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() - ).repeat(self.bs) + ).repeat(self.bs).view(self.bs, -1) return (input_ids, position_ids) def __iter__(self): From 5f9aa2aae727a086d35aac25b8f47371c8cb5a42 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Tue, 13 Sep 2022 06:03:46 +0000 Subject: [PATCH 1010/1892] Merged PR 1419: Add profiler introduction Add profiler introduction for snakeviz and alternative-to-cProfile viztracer --- README.md | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7be5f602..c4b054e9 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ OMP_NUM_THREADS=4 torchrun \ ## Profile -### Use cProfile +### Use cProfile + snakeviz Due to the multi-process architecture of `torch.distributed.launch`, instead of directly using the command-line interface of cProfile, we need to exactly specify the scope to profile, like: @@ -79,3 +79,30 @@ pr.dump_stats('cube_%d.prof' % torch.distributed.get_rank()) # or use TID/PID, i ``` After the modification, run the Python file using the same command line with `torchrun` as usual. + +After dumping the profiling data, we can use `snakeviz` to visualize it: + +```shell +pip install snakeviz +snakeviz cube_RANK_0.prof +``` + +### Use viztracer + +An alternative to cProfile + snakeviz is to use the profiler `viztracer`, +as well as its builtin visualization. + +`viztracer` is aware of the multi-process architecture of `torchrun` and it offers a command-line +interface and offers a very detailed profiling log. + +P.S. However, too detailed to be effectively used to profile huge DAG like the 23k~ nodes unrolled +WRF model. + +`viztracer` can be used like: + +```shell +pip install viztracer +viztracer --log_multiprocess torchrun --nproc_per_node=4 --nnodes=1 examples/mlp/linears.py +``` + +For more configurations please check `viztracer -h`. \ No newline at end of file From f0136ef52212b5f95e7dc229acb7c44e66dfc359 Mon Sep 17 00:00:00 2001 From: Xu Cao Date: Tue, 13 Sep 2022 06:12:04 +0000 Subject: [PATCH 1011/1892] Merged PR 1420: Fix typo in README fix typo --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index c4b054e9..ac2dc1d0 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ for i in range(N): # our code ends prof.disabled() -pr.dump_stats('cube_%d.prof' % torch.distributed.get_rank()) # or use TID/PID, if to profile multi-thread/-process program. +prof.dump_stats('cube_RANK%d.prof' % torch.distributed.get_rank()) # or use TID/PID, if to profile multi-thread/-process program. ``` After the modification, run the Python file using the same command line with `torchrun` as usual. @@ -84,7 +84,7 @@ After dumping the profiling data, we can use `snakeviz` to visualize it: ```shell pip install snakeviz -snakeviz cube_RANK_0.prof +snakeviz cube_RANK0.prof ``` ### Use viztracer @@ -93,10 +93,10 @@ An alternative to cProfile + snakeviz is to use the profiler `viztracer`, as well as its builtin visualization. `viztracer` is aware of the multi-process architecture of `torchrun` and it offers a command-line -interface and offers a very detailed profiling log. +interface and offers a very detailed profiling log, including the sequence, timing and durations. -P.S. However, too detailed to be effectively used to profile huge DAG like the 23k~ nodes unrolled -WRF model. +> P.S. However, too detailed to be effectively used to profile huge DAG like the 23k~ nodes unrolled +> WRF model, it would output very big log file and be very slow to render. `viztracer` can be used like: From dda04662a926b1156b1c5d341f5c135172b086f7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 14:39:27 +0800 Subject: [PATCH 1012/1892] fix weight reducer for replica --- cube/graph/gener/gen.py | 79 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 7 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 35f79f32..c64ffa07 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,5 +1,5 @@ import itertools -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Dict from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener @@ -135,18 +135,83 @@ def remove_anchor(graph: IRSegment): @staticmethod def gen_weight(graph: IRGraph) -> IRGraph: - weights = dict() + """ + Generate gradient accumulation + + Only suuport cases that: + + 1) each sub-tensor weight is consumed by different node cids (no replica) + 2) If the sub-tensor weight is consumed by same replicated node: + The consumers can be grouped by node cids and satisfy: + 1. same number of nodes per cid group + 2. same device set or no-overlapping device set per cid group + """ + # collect subtensor and consumer + fweights: Dict[IRFullTensor, List[IRSubTensor]] = dict() + fgrads: Dict[IRFullTensor, List[IRSubTensor]] = dict() + consumers: Dict[IRFullTensor, List[IRFwOperation]] = dict() for fnode in graph.nodes(flatten=True): if not isinstance(fnode, IRFwOperation): continue assert len(fnode.device) == 1 for wtensor in fnode.inputs(): if isinstance(wtensor, IRSubTensor) and wtensor.is_param(): if wtensor.grad is None: continue - if wtensor.parent not in weights: - weights[wtensor.parent] = dict() - if wtensor not in weights[wtensor.parent]: - weights[wtensor.parent][wtensor] = set() - weights[wtensor.parent][wtensor].add(wtensor.device[0]) + fweight = wtensor.parent + if fweight not in fweights: + fweights[fweight] = [] + fgrads[fweight] = [] + consumers[fweight] = [] + fweights[fweight].append(wtensor) + fgrads[fweight].append(wtensor.grad) + consumers[fweight].append(fnode) + + # bucketing + weights: Dict[IRFullTensor, Dict[IRSubTensor, List[int]]] = dict() + for fweight in fweights.keys(): + cids = set(fnode.cid for fnode in consumers[fweight]) + nl = '\n' + # case 1: no replica + if len(cids) == len(consumers[fweight]): + weights[fweight] = dict() + for wtensor, consumer in zip(fweights[fweight], consumers[fweight]): + if wtensor not in weights[fweight]: + weights[fweight][wtensor] = set() + weights[fweight][wtensor].add(consumer.device[0]) + # case 2: replica but has same number of replicas and same/no-overlapping devices + else: + cid_fnodes = {cid : [n for n in consumers[fweight] if n.cid == cid] for cid in cids} + cid_nnodes = [len(ns) for ns in cid_fnodes.values()] + # same replica# for each cid + assert all(cid_nnodes[0] == ns for ns in cid_nnodes), ( + f"If one of the weight consumers are replicated, " + f"other same-weight consumers should also replicated in same way." + f"FullTensor Weight: {fweight}\n" + f"Consumers:\n{nl.join([repr(node) for node in consumers[fweight]])}" + ) + cid_devs = {cid: set(n.device[0] for n in consumers[fweight]) for cid in cids} + # case 2.1: same device sharing + first = list(cid_devs.keys())[0] + if all(cid_devs[first] == devs for devs in cid_devs.values()): + #TODO: need to be more robust + continue + # case 2.2: no-overlapping device sharing + all_devs = set() + for devs in cid_devs.values(): + all_devs.update(devs) + if sum(len(devs) for devs in cid_devs.values()) == len(all_devs): + raise NotImplementedError( + f"Weight is consumed by multiple different operators.\n" + f"Replicating different operators on no-overlapping device group is not supported yet.\n" + f"FullTensor Weight: {fweight}\n" + f"Consumers:\n{nl.join([repr(node) for node in consumers[fweight]])}" + ) + else: + raise NotImplementedError( + f"Weight is consumed by multiple different operators.\n" + f"Replicating different operators on partial-overlapping device group is not supported yet.\n" + f"FullTensor Weight: {fweight}\n" + f"Consumers:\n{nl.join([repr(node) for node in consumers[fweight]])}" + ) reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() for ftensor, subtensors in weights.items(): From 9c3da56bda40b580e68897fa7f406feb513539fd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 14:39:58 +0800 Subject: [PATCH 1013/1892] align with megatron strategy --- examples/nlp/gpt/policy/spmd.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index c71e8639..f5f1f4ab 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -72,6 +72,15 @@ def PASMegatronTP(graph: IRGraph, resource): for ffn in ffns: _tp(graph, ffn, tp_devs, idx=1, dim=0) + # replicate embed + embeds = [node for node in fnodes if node.name == 'embedding'] + for embed in embeds: + _tp(graph, embed, tp_devs, idx=1, dim=0) + + # replicate last linear + linears = [node for node in fnodes if node.name == 'linear'] + _tp(graph, linears[-1], tp_devs, idx=1, dim=0) + # replicate other nodes for node in graph.nodes(): if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: From 0a470ac99f7160db021202ac83a8a2c71fe6a048 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 15:50:37 +0800 Subject: [PATCH 1014/1892] align with megatron: partition loss computation --- examples/nlp/gpt/policy/spmd.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index f5f1f4ab..f5e942bb 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -72,15 +72,20 @@ def PASMegatronTP(graph: IRGraph, resource): for ffn in ffns: _tp(graph, ffn, tp_devs, idx=1, dim=0) - # replicate embed + # partition embed embeds = [node for node in fnodes if node.name == 'embedding'] for embed in embeds: _tp(graph, embed, tp_devs, idx=1, dim=0) - # replicate last linear + # partition last linear linears = [node for node in fnodes if node.name == 'linear'] _tp(graph, linears[-1], tp_devs, idx=1, dim=0) + # partition loss + sums = [node for node in fnodes if node.name == 'sum'] + assert len(sums) == 1 + _tp(graph, sums[0], tp_devs, idx=0, dim=2) + # replicate other nodes for node in graph.nodes(): if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: From e9677fb78274bdf856d2e2e959fd9f9c2597776b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 16:18:33 +0800 Subject: [PATCH 1015/1892] remove cuda timer and synchronize --- cube/runtime/adapter/collectives.py | 29 ++++++++++++++--------------- cube/runtime/adapter/nn.py | 19 ++++++++----------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index e514f475..eb028ac2 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Optional import torch from cube.runtime.device import DeviceGroup @@ -96,14 +96,13 @@ def all_reduce(itensor: torch.Tensor, """ Allreduce """ - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') if not itensor.is_contiguous(): itensor = itensor.contiguous() itensor = itensor.detach() group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(itensor, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + # CudaTimer().stop(field_name='comm') return itensor @@ -111,17 +110,16 @@ def all_gather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tens """ Allgather """ - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') if not itensor.is_contiguous(): itensor = itensor.contiguous() group = DeviceGroup().get_group(ranks) tensor_list = [torch.empty_like(itensor) for _ in ranks] tensor_list[torch.distributed.get_rank(group)] = itensor.data torch.distributed.all_gather(tensor_list, itensor, group=group) - torch.cuda.synchronize() # concat otensor = torch.concat(tuple(tensor_list), dim=dim) - CudaTimer().stop(field_name='comm') + # CudaTimer().stop(field_name='comm') return otensor @@ -129,7 +127,7 @@ def reduce_scatter(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch. """ ReduceScatter """ - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') itensors = list(itensor.chunk(len(ranks), dim)) for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): @@ -137,8 +135,8 @@ def reduce_scatter(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch. group = DeviceGroup().get_group(ranks) otensor = torch.empty_like(itensors[0], requires_grad=False) torch.distributed.reduce_scatter(otensor, itensors, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + # torch.cuda.synchronize() + # CudaTimer().stop(field_name='comm') return otensor @@ -146,7 +144,7 @@ def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) - """ All-to-all """ - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') itensors = list(itensor.chunk(len(ranks), dim=odim)) for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): @@ -154,9 +152,9 @@ def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) - otensors = [torch.empty_like(t) for t in itensors] group = DeviceGroup().get_group(ranks) torch.distributed.all_to_all(otensors, itensors, group=group) - torch.cuda.synchronize() + # torch.cuda.synchronize() otensor = torch.concat(tuple(otensors), dim=idim) - CudaTimer().stop(field_name='comm') + # CudaTimer().stop(field_name='comm') return otensor @@ -184,7 +182,7 @@ def broadcast(input_tensors: List[torch.Tensor], """ Broadcast. ranks[0] is the root """ - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') assert len(input_tensors) == 1 or len(input_tensors) == 0 if len(input_tensors) == 1: tensor: torch.Tensor = input_tensors[0] @@ -198,7 +196,7 @@ def broadcast(input_tensors: List[torch.Tensor], tensor = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype) group = DeviceGroup().get_group(ranks) torch.distributed.broadcast(tensor, ranks[0], group=group) - CudaTimer().stop(field_name='comm') + # CudaTimer().stop(field_name='comm') return tensor @@ -279,3 +277,4 @@ def scatter(input_tensors: List[torch.Tensor], torch.cuda.synchronize() CudaTimer().stop(field_name='comm') return output + \ No newline at end of file diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index 6dab1c6b..8c39e5f3 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -6,33 +6,31 @@ def _allreduce(itensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') if not itensor.is_contiguous(): itensor = itensor.contiguous() group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(itensor, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + # CudaTimer().stop(field_name='comm') return itensor def _allgather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') if not itensor.is_contiguous(): itensor = itensor.contiguous() group = DeviceGroup().get_group(ranks) tensor_list = [torch.empty_like(itensor) for _ in ranks] tensor_list[torch.distributed.get_rank(group)] = itensor.data torch.distributed.all_gather(tensor_list, itensor, group=group) - torch.cuda.synchronize() # concat otensor = torch.concat(tuple(tensor_list), dim=dim).requires_grad_() - CudaTimer().stop(field_name='comm') + # CudaTimer().stop(field_name='comm') return otensor def _reducescatter(itensor: torch.Tensor, dim:int, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') itensors = list(itensor.chunk(len(ranks), dim)) for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): @@ -40,14 +38,13 @@ def _reducescatter(itensor: torch.Tensor, dim:int, ranks: Tuple[int]) -> torch.T group = DeviceGroup().get_group(ranks) otensor = torch.empty_like(itensors[0]) torch.distributed.reduce_scatter(otensor, itensors, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + # CudaTimer().stop(field_name='comm') return otensor def _alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm') + # CudaTimer().start(field_name='comm') itensors = list(itensor.chunk(len(ranks), dim=odim)) for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): @@ -57,7 +54,7 @@ def _alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.distributed.all_to_all(otensors, itensors, group=group) torch.cuda.synchronize() otensor = torch.concat(tuple(otensors), dim=idim) - CudaTimer().stop(field_name='comm') + # CudaTimer().stop(field_name='comm') return otensor From 83fd3abf67f4c402b9d0a591e032affe54b15456 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 16:19:00 +0800 Subject: [PATCH 1016/1892] using manual timer --- examples/nlp/gpt/train.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index c7403e7f..18a40669 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -2,13 +2,14 @@ example: OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ + --nproc_per_node=8 \ --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMeshShard --fp16 + examples/nlp/gpt/train.py --policy PASMegatronTP --fp16 """ import torch +import time from examples.nlp.gpt.model import GPT from examples.nlp.gpt.model import GPTDataLoader @@ -71,25 +72,30 @@ def train_iter(model, dataloader): iter_num = 40 warmup = 8 for step in range(iter_num): - # if step == 0: - # model_summary(model, next(dataloader)) + if step == warmup: + torch.cuda.synchronize() + start = time.time() + # CudaTimer(enable=True).start('e2e') - if step >= warmup: - CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() - if step >= warmup: - CudaTimer().stop('e2e') if step == 0: print_each_rank('passed first iteration') if (step + 1) % 10 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) + torch.cuda.synchronize() + stop = time.time() + span = (stop - start) / (iter_num - warmup) * 1000 # ms + print_each_rank(f'span time: {span} ms') + + # CudaTimer(enable=True).stop('e2e') + # print_each_rank('e2e time (ms) per iteration: {} ms'.format( + # CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + # CudaTimer().print_all(times=iter_num-warmup) + memory_summary() From 1d49dbb36058387211a42697b71181c05570d45a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 16:20:39 +0800 Subject: [PATCH 1017/1892] add flag to align performance --- benchmark/megatron/benchmark_gpt.sh | 13 ++++++++++++- benchmark/megatron/pretrain_gpt_synthetic.py | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/benchmark/megatron/benchmark_gpt.sh b/benchmark/megatron/benchmark_gpt.sh index b786e40e..9580a124 100755 --- a/benchmark/megatron/benchmark_gpt.sh +++ b/benchmark/megatron/benchmark_gpt.sh @@ -23,7 +23,10 @@ GPT_ARGS="--num-layers 24 \ --no-query-key-layer-scaling \ --no-masked-softmax-fusion \ --no-bias-gelu-fusion \ - --no-bias-dropout-fusion" + --no-bias-dropout-fusion \ + --no-async-tensor-model-parallel-allreduce \ + --no-gradient-accumulation-fusion \ + --num-workers 0" DISTRIBUTED_ARGS="--nproc_per_node $GPUS \ --nnodes 1 \ @@ -40,4 +43,12 @@ OMP_NUM_THREADS=4 python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --pipeline-model-parallel-size 1 \ --DDP-impl torch +# OMP_NUM_THREADS=4 python -m torch.distributed.launch \ +# --nproc_per_node 1 \ +# --nnodes 1 \ +# --node_rank 0 \ +# --master_addr localhost \ +# --master_port 6000 \ +# pretrain_gpt_synthetic.py -h + cd .. \ No newline at end of file diff --git a/benchmark/megatron/pretrain_gpt_synthetic.py b/benchmark/megatron/pretrain_gpt_synthetic.py index 7494738c..1a6040cb 100644 --- a/benchmark/megatron/pretrain_gpt_synthetic.py +++ b/benchmark/megatron/pretrain_gpt_synthetic.py @@ -55,7 +55,7 @@ def get_batch(data_iterator): vocab_size = 50257 tokens = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size labels = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size - loss_mask = (torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()) < 0.5).float() + loss_mask = torch.ones(tokens.size(), dtype=torch.float, device=torch.cuda.current_device()) attention_mask = (torch.rand((args.micro_batch_size, 1, args.seq_length, args.seq_length), requires_grad=False, device=torch.cuda.current_device()) < 0.5) position_ids = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * args.seq_length From 22e3a3369f7fd94746c7ceee6406971656ff5a0f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 16:50:09 +0800 Subject: [PATCH 1018/1892] more efficient implementation for attention block --- examples/nlp/blocks/attention.py | 33 +++++++++++--------------------- examples/nlp/gpt/train.py | 17 +++++----------- 2 files changed, 16 insertions(+), 34 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 1867a431..b4d3ebda 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -2,20 +2,18 @@ import cube -@cube.graph.parser.register('L^ N E^, (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), (h+ d^) E^, (h+ d^), E^ (h+ d^) -> L^ N E^', name='self_attention') -def self_attention(query: torch.Tensor, - q_proj: torch.Tensor, q_bias: torch.Tensor, - k_proj: torch.Tensor, k_bias: torch.Tensor, - v_proj: torch.Tensor, v_bias: torch.Tensor, +@cube.graph.parser.register('L^ N E^, (h+ d^ 3) E^, (h+ d^ 3), E^ (h+ d^) -> L^ N E^', name='self_attention') +def self_attention(query: torch.Tensor, + qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, h: int, scale: float, dropout_p: float, mask: bool = True): num_head = h L, N = query.size(0), query.size(1) - dim_head = q_proj.size(0) // num_head + dim_head = qkv_proj.size(0) // num_head // 3 - q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) + qkv = torch.nn.functional.linear(query, qkv_proj, qkv_bias) # L N E, (h d 3) E -> L N (h d 3) + qkv = qkv.view(L, N, num_head * dim_head, 3) # L N (h d 3) -> L N (h d) 3 + q, k, v = qkv.chunk(3, dim=-1) # L N (3 h d) -> L N (h d), L N (h d), L N (h d) q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d @@ -143,25 +141,16 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa self.head_dim = inner_dim // num_heads self.scaling = self.head_dim ** -0.5 self.dropout_p = dropout - # Q - self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) - # K - self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) - # V - self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) + # QKV [(h d 3), E] + self.qkv_proj = torch.nn.Parameter(torch.empty(3 * inner_dim, embed_dim)) + self.qkv_bias = torch.nn.Parameter(torch.empty(3 * inner_dim)) # Out self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) def forward(self, query): attn = self_attention( - query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, + query, self.qkv_proj, self.qkv_bias, self.out_proj, self.num_heads, self.scaling, self.dropout_p, mask=True ) diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 18a40669..232bcc66 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -73,9 +73,7 @@ def train_iter(model, dataloader): warmup = 8 for step in range(iter_num): if step == warmup: - torch.cuda.synchronize() - start = time.time() - # CudaTimer(enable=True).start('e2e') + CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) optimizer.step() @@ -86,15 +84,10 @@ def train_iter(model, dataloader): if (step + 1) % 10 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - torch.cuda.synchronize() - stop = time.time() - span = (stop - start) / (iter_num - warmup) * 1000 # ms - print_each_rank(f'span time: {span} ms') - - # CudaTimer(enable=True).stop('e2e') - # print_each_rank('e2e time (ms) per iteration: {} ms'.format( - # CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - # CudaTimer().print_all(times=iter_num-warmup) + CudaTimer().stop('e2e') + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) memory_summary() From 1b9bec19e1f2a6ef6eece3f2334ee6b86d772e2c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 17:36:18 +0800 Subject: [PATCH 1019/1892] more efficient implementation for baddbmm --- examples/nlp/blocks/attention.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index b4d3ebda..25123811 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -17,12 +17,24 @@ def self_attention(query: torch.Tensor, q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # ======== replace the semantic into more efficient implementation ============ + # q = q.transpose(0, 1) # L (N h) d -> (N h) L d + # k = k.transpose(0, 1) # L (N h) d -> (N h) L d + # q = q * scale # (N h) L d, 1 -> (N h) L d + # k = k.transpose(1, 2) # (N h) L d -> (N h) d L + # attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + + # preallocating input tensor: (N h) L L + matmul_input_buffer = torch.empty([N * h, L, L], dtype=query.dtype, device=query.device) + # L (N h) d, L (N h) d -> (N h) L L + attn = torch.baddbmm( + matmul_input_buffer, + q.transpose(0, 1), # (N h) L d + k.transpose(0, 1).transpose(1, 2), # (N h) d L + beta=0.0, alpha=scale + ) + # ======== replace the semantic into more efficient implementation ============ # attention mask if mask: # (N h) L L -> (N h) L L @@ -36,6 +48,7 @@ def self_attention(query: torch.Tensor, attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L + v = v.transpose(0, 1) # L (N h) d -> (N h) L d output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) From f9b29b898d37ca10e19e5c4b7170d2ce5f1a0d24 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 13 Sep 2022 18:23:27 +0800 Subject: [PATCH 1020/1892] attention mask to be same with original one --- benchmark/megatron/pretrain_gpt_synthetic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/benchmark/megatron/pretrain_gpt_synthetic.py b/benchmark/megatron/pretrain_gpt_synthetic.py index 1a6040cb..48345eff 100644 --- a/benchmark/megatron/pretrain_gpt_synthetic.py +++ b/benchmark/megatron/pretrain_gpt_synthetic.py @@ -56,7 +56,10 @@ def get_batch(data_iterator): tokens = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size labels = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size loss_mask = torch.ones(tokens.size(), dtype=torch.float, device=torch.cuda.current_device()) - attention_mask = (torch.rand((args.micro_batch_size, 1, args.seq_length, args.seq_length), requires_grad=False, device=torch.cuda.current_device()) < 0.5) + attention_mask = torch.tril(torch.ones( + (args.micro_batch_size, args.seq_length, args.seq_length), device=torch.cuda.current_device() + )).view(args.micro_batch_size, 1, args.seq_length, args.seq_length) + attention_mask = (attention_mask < 0.5) position_ids = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * args.seq_length return tokens, labels, loss_mask, attention_mask, position_ids From 59e4813c3ad01915f37bab15c95bff28b77d158e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 14 Sep 2022 13:53:17 +0800 Subject: [PATCH 1021/1892] fix mean op --- cube/graph/function/function.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index d5ae8a47..a758b0d2 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -481,8 +481,11 @@ def Sum(signature, inputs): def Mean(signature, inputs): - tensor = inputs[0] - dim = inputs[1] + if len(inputs) >= 2: + tensor, dim = inputs[:2] + elif len(inputs) == 1: + tensor = inputs[0] + dim = None einput = ShapeAnno.create_shape_str(tensor.shape) eoutput = copy.copy(einput) if dim is not None: From 0db6adac7772287ea839bb9692814031779c4c83 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Wed, 14 Sep 2022 04:44:12 -0700 Subject: [PATCH 1022/1892] save work --- examples/nlp/palm/palm.py | 75 ++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index 55468221..06c0a48b 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -3,13 +3,15 @@ OMP_NUM_THREADS=2 torchrun --nproc_per_node=2 --nnodes=1 palm.py """ +from typing import List + import torch import torch.nn.functional as F from torch import nn, einsum from math import log2, floor -from einops import rearrange, repeat +# from einops import rearrange, repeat from cube.graph import IRGraph @@ -140,17 +142,17 @@ def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): return x -@cube.graph.parser.register('N L^ E^, E^ F^ -> N L^ F^', name='feedforward1') +@cube.graph.parser.register('N L^ E+, E+ F -> N L^ F', name='feedforward1') def feedforward1(x: torch.Tensor, proj: torch.Tensor): return torch.nn.functional.silu(torch.matmul(x, proj)) -@cube.graph.parser.register('N L^ E^, E^ F^ -> N L^ F^', name='feedforward2') +@cube.graph.parser.register('N L^ E+, E+ F -> N L^ F', name='feedforward2') def feedforward2(x: torch.Tensor, proj: torch.Tensor): return torch.matmul(x, proj) -@cube.graph.parser.register('N L^ E^, N L^ E^, E^ F -> N L^ F', +@cube.graph.parser.register('N L^ E+, N L^ E+, E+ F -> N L^ F', name='feedforward3') def feedforward3(x: torch.Tensor, y: torch.Tensor, proj: torch.Tensor): return torch.matmul(x * y, proj) @@ -288,7 +290,7 @@ def PASData(graph: IRGraph, resource): ''' 2 way Data Parallel ''' - assert resource.ngpus == 2 + # assert resource.ngpus == 2 for node in graph.nodes(): if isinstance(node, IRDataOperation): @@ -311,11 +313,9 @@ def PASData(graph: IRGraph, resource): return graph -def PASBranch(graph: IRGraph, resource): - ''' - 2 way brach - ''' - assert resource.ngpus == 2 +def PASMegatron(graph: IRGraph, resource): + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) for node in graph.nodes(): if isinstance(node, IRDataOperation): @@ -325,10 +325,28 @@ def PASBranch(graph: IRGraph, resource): graph.assign(sub_node, idx) batch_dim = node.get_batch_dims()[0] + def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for dev_id, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, dev_id) + return sub_nodes + for node in graph.nodes(): if isinstance(node, IRFwOperation): - if node.name == 'embedding' or node.name == 'linear': - # data parallel + if node.name == 'embedding': + _tp(graph, node, tp_devs, idx=1, dim=0) + elif node.name == "linear": + _tp(graph, node, tp_devs, idx=1, dim=0) + elif node.name == 'multi_head_attention': + # TODO: data parallel current algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, @@ -337,21 +355,18 @@ def PASBranch(graph: IRGraph, resource): num=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) - elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add' or node.name == 'mean': - # replicate - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif node.name == 'feedforward': - graph.assign(node, 0) - elif node.name == 'multi_head_attention': - graph.assign(node, 1) + elif node.name == 'feedforward1': + _tp(graph, node, tp_devs, idx=1, dim=1) + elif node.name == 'feedforward2': + _tp(graph, node, tp_devs, idx=1, dim=1) + elif node.name == 'feedforward3': + _tp(graph, node, tp_devs, idx=2, dim=0) + elif node.name == 'mean': + _tp(graph, node, tp_devs, idx=0, dim=2) else: - assert False, node.name - + _replica(graph, node, tp_devs) return graph - def PASBranch3(graph: IRGraph, resource): ''' 3 way branch @@ -366,8 +381,10 @@ def PASBranch3(graph: IRGraph, resource): graph.assign(sub_node, idx) batch_dim = node.get_batch_dims()[0] + fnodes = [] for node in graph.nodes(): if isinstance(node, IRFwOperation): + fnodes.append(node) if node.name == 'embedding' or node.name == 'linear': # data parallel algo = node.algorithms('dim') @@ -389,7 +406,7 @@ def PASBranch3(graph: IRGraph, resource): graph.assign(node, 1) elif node.name == 'feedforward3': algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=2, dim=1, num=2) + sub_nodes = graph.partition(node, algo, idx=2, dim=0, num=2) graph.assign(sub_nodes[0], 0) graph.assign(sub_nodes[1], 1) elif node.name == 'multi_head_attention': @@ -397,6 +414,7 @@ def PASBranch3(graph: IRGraph, resource): else: assert False, node.name + # graph.recompute(fnodes) return graph @@ -421,9 +439,10 @@ def train(): batch_dims=(0, )) # @cube.compile(model, dataloader, PAS=PASSingle) - # @cube.compile(model, dataloader, PAS=PASData) # @cube.compile(model, dataloader, PAS=PASBranch) - @cube.compile(model, dataloader, PAS=PASBranch3) + # @cube.compile(model, dataloader, PAS=PASData) + # @cube.compile(model, dataloader, PAS=PASBranch3) + @cube.compile(model, dataloader, PAS=PASMegatron) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) @@ -454,4 +473,4 @@ def train_iter(model, dataloader): CudaTimer().print_all(times=iter_num - warmup) -train() +train() \ No newline at end of file From 07ea2eee4b1cc59a316dd083dcf997b52024b7e1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 15 Sep 2022 10:38:57 +0800 Subject: [PATCH 1023/1892] update timer to have predefined flag --- cube/profiler/timer.py | 108 ++++++++++++++++++---------- cube/runtime/adapter/collectives.py | 42 ++++++----- cube/runtime/adapter/nn.py | 38 +++++----- cube/runtime/adapter/reducer.py | 6 +- 4 files changed, 108 insertions(+), 86 deletions(-) diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index cd90da65..616be085 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -1,3 +1,4 @@ +from typing import Optional import time import sys @@ -30,88 +31,118 @@ def print_each_rank(msg, rank_only=None, outfile=''): class CudaTimer: r""" - Singleton Timer + Singleton Cuda Timer + + Note that frequently using timer may decrease the performance. + + The runtime predefines the timer on each communication primitive. + By default, the timer on communications are disabled for higher performance. + For users who want to analyze communication overhead, turn on the timer + by using `CudaTimer(enable=True, predefined=True)`. + + There are two switches to allow user to control the timer behaviour + + * enable: + the overall controller to turn on/off the all profiling. + * predefined: + the controller to turn on/off the predefined timer (mostly are communications) """ class __CudaTimer: - def __init__(self, **kwargs): + def __init__(self, enable = True, predefined = False): self.start_t = None self.stop_t = None self.field = dict() self.field_data = dict() - self.enabled = True - if 'enable' in kwargs: - self.enabled = kwargs['enable'] + self.enabled = enable + self.predefined = predefined instance = None - def __init__(self, enable = None): - if not CudaTimer.instance: - kwargs = dict() + def __init__(self, enable: Optional[bool] = None, predefined: Optional[bool] = None): + # not have instance + if not self.instance: + enable = enable if enable is not None else True + predefined = predefined if predefined is not None else False + CudaTimer.instance = CudaTimer.__CudaTimer(enable, predefined) + # have instance + else: if enable is not None: - kwargs = dict(enable=enable) - CudaTimer.instance = CudaTimer.__CudaTimer(**kwargs) - elif enable is not None: - CudaTimer.instance.enabled = enable - + self.instance.enabled = enable + if predefined is not None: + self.instance.predefined = predefined - def start(self, field_name='default'): + def start(self, field_name='default', predefined: bool = False): """ Start recording time on the the field Note `start` and `stop` on the same field can be called nestly """ - if not CudaTimer.instance.enabled: + if (not self.instance.enabled) or (predefined and not self.instance.predefined): return torch.cuda.synchronize() - if field_name not in CudaTimer.instance.field: - CudaTimer.instance.field[field_name] = list() - CudaTimer.instance.field_data[field_name] = 0 - CudaTimer.instance.field[field_name].append(time.time()) + start_time = time.time() + if field_name not in self.instance.field: + self.instance.field[field_name] = list() + self.instance.field_data[field_name] = 0 + self.instance.field[field_name].append(start_time) - def stop(self, field_name='default'): + def stop(self, field_name='default', predefined: bool = False): """ Return the time span from last `start` on the smae field name to now Returns: float (ms) """ - if not CudaTimer.instance.enabled: + if (not self.instance.enabled) or (predefined and not self.instance.predefined): return - if field_name not in CudaTimer.instance.field: + if field_name not in self.instance.field: raise RuntimeError("Missing start on the field") torch.cuda.synchronize() stop_time = time.time() - start_time = CudaTimer.instance.field[field_name].pop(-1) + start_time = self.instance.field[field_name].pop(-1) span = stop_time - start_time # in seconds - CudaTimer.instance.field_data[field_name] += span + self.instance.field_data[field_name] += span return span - def duration(self, times, field_name='default'): - if field_name not in CudaTimer.instance.field: + def duration(self, times: int, field_name: str = 'default') -> float: + """ + Get dthe total span (wall clock) of a field name. The span is divided by times. + + @param times int: division factor + @param filed_name str: the field name + + @return span float: wall clock in milliseconds. + """ + if field_name not in self.instance.field: raise RuntimeError(f"Missing start on the field {field_name}") - if len(CudaTimer.instance.field[field_name]) != 0: + if len(self.instance.field[field_name]) != 0: raise RuntimeError(f"timer for field {field_name} not stopped") - return CudaTimer.instance.field_data[field_name] / times * 1000 # in ms + return self.instance.field_data[field_name] / times * 1000 # in ms def __getattr__(self, name): return getattr(self.instance, name) def clear(self): - CudaTimer.instance = CudaTimer.__CudaTimer() + self.instance = CudaTimer.__CudaTimer() + + def print_all(self, times: int, rank_only: Optional[int] = None): + """ + Print the total span of each recorded field divided by `times` + + Note this should be called by each process - def print_all(self, times): + @param times int: division factor + @param rank_only Optional[int]: select only one rank for print + + @return None + """ msg = list() - comm_span = 0 - for field_name in CudaTimer.instance.field_data: + for field_name in self.instance.field_data: span = self.duration(times, field_name) - if 'send' in field_name or 'recv' in field_name: - comm_span += span msg.append('{} : {:.2f} ms'.format(field_name, span)) - # msg.append('{} : {:.2f} ms'.format('communication', comm_span)) msg = ' | '.join(msg) - - print_each_rank(msg) + print_each_rank(msg, rank_only) def warmup(self, seconds=1.0): """ @@ -123,9 +154,10 @@ def warmup(self, seconds=1.0): # warm up 1s if torch.distributed.is_initialized(): torch.distributed.barrier() + torch.cuda.synchronize() start = time.time() while time.time() - start < seconds: - out = torch.matmul(data1, data2) + _ = torch.matmul(data1, data2) # if torch.distributed.is_initialized(): # torch.distributed.all_reduce(out) torch.cuda.synchronize() diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index eb028ac2..02030eca 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -15,7 +15,7 @@ def send(tensor: torch.Tensor, dst: int): tensor_devices (List[List[int]]): tensor sent devices """ # print(f'{torch.distributed.get_rank()}: sending...') - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) send_ops = list() if not tensor.is_contiguous(): @@ -28,13 +28,13 @@ def send(tensor: torch.Tensor, dst: int): for req in reqs: req.wait() torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return tensor def recv(tensors: List[torch.Tensor], shape: List[int], dtype: torch.dtype, src: int): # print(f'{torch.distributed.get_rank()}: recving...') - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) ## synthetic ## # for shape in shapes: # recv_tensors.append( @@ -53,7 +53,7 @@ def recv(tensors: List[torch.Tensor], shape: List[int], dtype: torch.dtype, src: for req in reqs: req.wait() torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return tensor @@ -62,7 +62,7 @@ def sendrecv(input_tensors: List[torch.Tensor], output_dtypes: List[torch.dtype], send_ranks: List[int], recv_ranks: List[int]) -> List[torch.Tensor]: - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) # print('sending and recving...') ops = list() outputs = list() @@ -87,7 +87,7 @@ def sendrecv(input_tensors: List[torch.Tensor], for req in reqs: req.wait() torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return outputs @@ -96,13 +96,13 @@ def all_reduce(itensor: torch.Tensor, """ Allreduce """ - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) if not itensor.is_contiguous(): itensor = itensor.contiguous() itensor = itensor.detach() group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(itensor, group=group) - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return itensor @@ -110,7 +110,7 @@ def all_gather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tens """ Allgather """ - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) if not itensor.is_contiguous(): itensor = itensor.contiguous() group = DeviceGroup().get_group(ranks) @@ -119,7 +119,7 @@ def all_gather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tens torch.distributed.all_gather(tensor_list, itensor, group=group) # concat otensor = torch.concat(tuple(tensor_list), dim=dim) - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return otensor @@ -127,7 +127,7 @@ def reduce_scatter(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch. """ ReduceScatter """ - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) itensors = list(itensor.chunk(len(ranks), dim)) for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): @@ -135,8 +135,7 @@ def reduce_scatter(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch. group = DeviceGroup().get_group(ranks) otensor = torch.empty_like(itensors[0], requires_grad=False) torch.distributed.reduce_scatter(otensor, itensors, group=group) - # torch.cuda.synchronize() - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return otensor @@ -144,7 +143,7 @@ def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) - """ All-to-all """ - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) itensors = list(itensor.chunk(len(ranks), dim=odim)) for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): @@ -152,9 +151,8 @@ def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) - otensors = [torch.empty_like(t) for t in itensors] group = DeviceGroup().get_group(ranks) torch.distributed.all_to_all(otensors, itensors, group=group) - # torch.cuda.synchronize() otensor = torch.concat(tuple(otensors), dim=idim) - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return otensor @@ -182,7 +180,7 @@ def broadcast(input_tensors: List[torch.Tensor], """ Broadcast. ranks[0] is the root """ - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) assert len(input_tensors) == 1 or len(input_tensors) == 0 if len(input_tensors) == 1: tensor: torch.Tensor = input_tensors[0] @@ -196,7 +194,7 @@ def broadcast(input_tensors: List[torch.Tensor], tensor = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype) group = DeviceGroup().get_group(ranks) torch.distributed.broadcast(tensor, ranks[0], group=group) - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return tensor @@ -207,7 +205,7 @@ def gather(input_tensors: List[torch.Tensor], """ Gather. ranks[0] is the root """ - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) assert len(input_tensors) == 1 input_tensor = input_tensors[0] dst = ranks[0] @@ -233,7 +231,7 @@ def gather(input_tensors: List[torch.Tensor], for req in reqs: req.wait() torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return tensor_list @@ -241,7 +239,7 @@ def scatter(input_tensors: List[torch.Tensor], output_shapes: List[List[int]], output_dtypes: List[torch.dtype], ranks: List[int]) -> List[torch.Tensor]: - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) output = None src = ranks[0] if DeviceGroup().rank == src: @@ -275,6 +273,6 @@ def scatter(input_tensors: List[torch.Tensor], for req in reqs: req.wait() torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return output \ No newline at end of file diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index 8c39e5f3..09b975a2 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -6,17 +6,17 @@ def _allreduce(itensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) if not itensor.is_contiguous(): itensor = itensor.contiguous() group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(itensor, group=group) - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return itensor def _allgather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) if not itensor.is_contiguous(): itensor = itensor.contiguous() group = DeviceGroup().get_group(ranks) @@ -25,12 +25,12 @@ def _allgather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tens torch.distributed.all_gather(tensor_list, itensor, group=group) # concat otensor = torch.concat(tuple(tensor_list), dim=dim).requires_grad_() - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return otensor def _reducescatter(itensor: torch.Tensor, dim:int, ranks: Tuple[int]) -> torch.Tensor: - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) itensors = list(itensor.chunk(len(ranks), dim)) for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): @@ -38,13 +38,12 @@ def _reducescatter(itensor: torch.Tensor, dim:int, ranks: Tuple[int]) -> torch.T group = DeviceGroup().get_group(ranks) otensor = torch.empty_like(itensors[0]) torch.distributed.reduce_scatter(otensor, itensors, group=group) - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return otensor def _alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: - - # CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) itensors = list(itensor.chunk(len(ranks), dim=odim)) for idx, tensor in enumerate(itensors): if not tensor.is_contiguous(): @@ -52,9 +51,8 @@ def _alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> otensors = [torch.empty_like(t) for t in itensors] group = DeviceGroup().get_group(ranks) torch.distributed.all_to_all(otensors, itensors, group=group) - torch.cuda.synchronize() otensor = torch.concat(tuple(otensors), dim=idim) - # CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return otensor @@ -234,10 +232,9 @@ def forward(ctx, input_: torch.Tensor, dst: int, ranks: List[int]): world_size = torch.distributed.get_world_size(group) if world_size == 1: return input_ - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) torch.distributed.reduce(input_, dst, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return input_ @staticmethod @@ -247,10 +244,9 @@ def backward(ctx, grad_output): world_size = torch.distributed.get_world_size(group) if world_size == 1: return grad_output, None, None - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) torch.distributed.broadcast(grad_output, src, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return grad_output, None, None @@ -264,10 +260,9 @@ def forward(ctx, input_: torch.Tensor, src: int, ranks: List[int]): world_size = torch.distributed.get_world_size(group) if world_size == 1: return input_ - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) torch.distributed.broadcast(input_, src, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return input_ @staticmethod @@ -277,10 +272,9 @@ def backward(ctx, grad_output): world_size = torch.distributed.get_world_size(group) if world_size == 1: return grad_output, None, None - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) if not grad_output.is_contiguous(): grad_output = grad_output.contiguous() torch.distributed.reduce(grad_output, dst, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) return grad_output, None, None diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index b59138da..d622e879 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -31,11 +31,9 @@ def allreduce(self): if tp not in buckets: buckets[tp] = [] buckets[tp].append(param) - # TODO: figure out why Megatron needs this? - # param.main_grad = param.grad # for each bucket, do all-reduce for tp in buckets: - CudaTimer().start(field_name='comm') + CudaTimer().start(field_name='comm', predefined=True) bucket = buckets[tp] grads = [param.grad.data for param in bucket] coalesced = self._flatten_dense_tensors(grads) @@ -44,7 +42,7 @@ def allreduce(self): all_synced = self._unflatten_dense_tensors(coalesced, grads) for grad, synced in zip(grads, all_synced): grad.copy_(synced) - CudaTimer().stop(field_name='comm') + CudaTimer().stop(field_name='comm', predefined=True) def sync(self): """ From 9a084fdc73db9efd521e2d4e002e9ba62bc44f37 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Wed, 14 Sep 2022 19:51:38 -0700 Subject: [PATCH 1024/1892] reorg code structure --- examples/nlp/palm/palm.py | 155 +------------------------------ examples/nlp/palm/policy/mpmd.py | 49 ++++++++++ examples/nlp/palm/policy/spmd.py | 94 +++++++++++++++++++ 3 files changed, 147 insertions(+), 151 deletions(-) create mode 100644 examples/nlp/palm/policy/mpmd.py create mode 100644 examples/nlp/palm/policy/spmd.py diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index 211c3726..e155f0eb 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -1,8 +1,3 @@ -""" -2 way branch: - OMP_NUM_THREADS=2 torchrun --nproc_per_node=2 --nnodes=1 palm.py -""" - from typing import List import torch @@ -20,8 +15,8 @@ from cube.profiler.timer import print_each_rank from cube.ir.operator import IRDataOperation, IRFwOperation -import examples.mlp.policy.spmd as spmd -import examples.mlp.policy.mpmd as mpmd +import examples.nlp.palm.policy.spmd as spmd +import examples.nlp.palm.policy.mpmd as mpmd import argparse @@ -276,150 +271,8 @@ def forward(self, x): return self.net(x).mean() -def PASSingle(graph: IRGraph, resource): - assert resource.ngpus == 1 - - for node in graph.nodes(): - if isinstance(node, (IRDataOperation, IRFwOperation)): - graph.assign(node, 0) - - return graph - - -def PASData(graph: IRGraph, resource): - ''' - 2 way Data Parallel - ''' - # assert resource.ngpus == 2 - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] - - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=0, - dim=batch_dim, - num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return graph - - -def PASMegatron(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] - - def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for dev_id, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, dev_id) - return sub_nodes - - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - if node.name == 'embedding': - _tp(graph, node, tp_devs, idx=1, dim=0) - elif node.name == "linear": - _tp(graph, node, tp_devs, idx=1, dim=0) - elif node.name == 'multi_head_attention': - # TODO: data parallel current - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=0, - dim=batch_dim, - num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif node.name == 'feedforward1': - _tp(graph, node, tp_devs, idx=1, dim=1) - elif node.name == 'feedforward2': - _tp(graph, node, tp_devs, idx=1, dim=1) - elif node.name == 'feedforward3': - _tp(graph, node, tp_devs, idx=2, dim=0) - elif node.name == 'mean': - _tp(graph, node, tp_devs, idx=0, dim=2) - else: - _replica(graph, node, tp_devs) - return graph - -def PASBranch3(graph: IRGraph, resource): - ''' - 3 way branch - ''' - assert resource.ngpus == 3 - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] - - fnodes = [] - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - fnodes.append(node) - if node.name == 'embedding' or node.name == 'linear': - # data parallel - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=0, - dim=batch_dim, - num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add' or node.name == 'mean': - # replicate - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif node.name == 'feedforward1': - graph.assign(node, 0) - elif node.name == 'feedforward2': - graph.assign(node, 1) - elif node.name == 'feedforward3': - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=2, dim=0, num=2) - graph.assign(sub_nodes[0], 0) - graph.assign(sub_nodes[1], 1) - elif node.name == 'multi_head_attention': - graph.assign(node, 2) - else: - assert False, node.name - - # graph.recompute(fnodes) - return graph - - def train(): - bs, n, dim = 8, 2048, 4096 + bs, n, dim = 4, 2048, 4096 num_tokens, depth, heads, dim_head = 20000, 1, 16, 256 model = PaLM(dim, num_tokens, depth, heads=heads, dim_head=dim_head) @@ -442,7 +295,7 @@ def train(): # @cube.compile(model, dataloader, PAS=PASBranch) # @cube.compile(model, dataloader, PAS=PASData) # @cube.compile(model, dataloader, PAS=PASBranch3) - @cube.compile(model, dataloader, PAS=PASMegatron) + @cube.compile(model, dataloader, PAS=spmd.PASMegatron) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) diff --git a/examples/nlp/palm/policy/mpmd.py b/examples/nlp/palm/policy/mpmd.py new file mode 100644 index 00000000..7c1d3f90 --- /dev/null +++ b/examples/nlp/palm/policy/mpmd.py @@ -0,0 +1,49 @@ +from cube.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation + +def PASBranch3(graph: IRGraph, resource): + ''' + 3 way branch + ''' + assert resource.ngpus == 3 + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + batch_dim = node.get_batch_dims()[0] + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if node.name == 'embedding' or node.name == 'linear': + # data parallel + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, + algo, + idx=0, + dim=batch_dim, + num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add' or node.name == 'mean': + # replicate + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + elif node.name == 'feedforward1': + graph.assign(node, 0) + elif node.name == 'feedforward2': + graph.assign(node, 1) + elif node.name == 'feedforward3': + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=2, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + elif node.name == 'multi_head_attention': + graph.assign(node, 2) + else: + assert False, node.name + + return graph \ No newline at end of file diff --git a/examples/nlp/palm/policy/spmd.py b/examples/nlp/palm/policy/spmd.py new file mode 100644 index 00000000..257aa435 --- /dev/null +++ b/examples/nlp/palm/policy/spmd.py @@ -0,0 +1,94 @@ +from typing import List +from cube.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation + +def PASSingle(graph: IRGraph, resource): + assert resource.ngpus == 1 + + for node in graph.nodes(): + if isinstance(node, (IRDataOperation, IRFwOperation)): + graph.assign(node, 0) + + return graph + + +def PASData(graph: IRGraph, resource): + ''' + 2 way Data Parallel + ''' + # assert resource.ngpus == 2 + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + batch_dim = node.get_batch_dims()[0] + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, + algo, + idx=0, + dim=batch_dim, + num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph + + +def PASMegatron(graph: IRGraph, resource): + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + batch_dim = node.get_batch_dims()[0] + + def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for dev_id, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, dev_id) + return sub_nodes + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if node.name == 'embedding': + _tp(graph, node, tp_devs, idx=1, dim=0) + elif node.name == "linear": + _tp(graph, node, tp_devs, idx=1, dim=0) + elif node.name == 'multi_head_attention': + # TODO: data parallel current + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, + algo, + idx=0, + dim=batch_dim, + num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + elif node.name == 'feedforward1': + _tp(graph, node, tp_devs, idx=1, dim=1) + elif node.name == 'feedforward2': + _tp(graph, node, tp_devs, idx=1, dim=1) + elif node.name == 'feedforward3': + _tp(graph, node, tp_devs, idx=2, dim=0) + elif node.name == 'mean': + _tp(graph, node, tp_devs, idx=0, dim=2) + else: + _replica(graph, node, tp_devs) + return graph \ No newline at end of file From 45780d0fd74970d00dab52510d3d418cad0ed598 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Wed, 14 Sep 2022 22:55:01 -0700 Subject: [PATCH 1025/1892] add test code for 5 way branch parallelism --- examples/nlp/palm/palm.py | 5 +-- examples/nlp/palm/policy/mpmd.py | 55 ++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index e155f0eb..d74fc4e5 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -272,7 +272,7 @@ def forward(self, x): def train(): - bs, n, dim = 4, 2048, 4096 + bs, n, dim = 5, 2048, 4096 num_tokens, depth, heads, dim_head = 20000, 1, 16, 256 model = PaLM(dim, num_tokens, depth, heads=heads, dim_head=dim_head) @@ -295,7 +295,8 @@ def train(): # @cube.compile(model, dataloader, PAS=PASBranch) # @cube.compile(model, dataloader, PAS=PASData) # @cube.compile(model, dataloader, PAS=PASBranch3) - @cube.compile(model, dataloader, PAS=spmd.PASMegatron) + # @cube.compile(model, dataloader, PAS=spmd.PASMegatron) + @cube.compile(model, dataloader, PAS=mpmd.PASBranch5) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) diff --git a/examples/nlp/palm/policy/mpmd.py b/examples/nlp/palm/policy/mpmd.py index 7c1d3f90..ec85054e 100644 --- a/examples/nlp/palm/policy/mpmd.py +++ b/examples/nlp/palm/policy/mpmd.py @@ -46,4 +46,59 @@ def PASBranch3(graph: IRGraph, resource): else: assert False, node.name + return graph + +def PASBranch5(graph: IRGraph, resource): + ''' + 5 way branch + ''' + assert resource.ngpus == 5 + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + batch_dim = node.get_batch_dims()[0] + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if node.name == 'embedding' or node.name == 'linear': + # data parallel + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, + algo, + idx=0, + dim=batch_dim, + num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add' or node.name == 'mean': + # replicate + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + elif node.name == 'feedforward1': + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=1, dim=1, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + elif node.name == 'feedforward2': + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=1, dim=1, num=2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + elif node.name == 'feedforward3': + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=2, dim=0, num=4) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + graph.assign(sub_nodes[2], 2) + graph.assign(sub_nodes[3], 3) + elif node.name == 'multi_head_attention': + graph.assign(node, 4) + else: + assert False, node.name + return graph \ No newline at end of file From 94349c4d8e73857ece0c8e0118d62f55befd1575 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Sep 2022 11:29:55 +0800 Subject: [PATCH 1026/1892] move batch size reset into gen code --- cube/codegen/codegen.py | 20 ++++++++++++++++++++ cube/compiler.py | 38 +++++++++++++++++++------------------- cube/program.py | 3 ++- cube/runtime/module.py | 10 +++++++++- cube/runtime/syndata.py | 3 ++- 5 files changed, 52 insertions(+), 22 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 16158e32..ca7f1baa 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -3,6 +3,7 @@ """ import itertools from typing import Dict, Generator, Iterable, List, Any, Optional, Set, Tuple, Union +import warnings import torch import copy from more_itertools import split_when @@ -386,6 +387,8 @@ def __init__(self, execplan: ExecutionPlan): self.symbols = SymbolTable() # ref module to check shared variables self._ref_module = torch.nn.Module() + # batch size + self.batch_size = None def init_comm_groups(self): """ @@ -458,6 +461,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: elif isinstance(node, IRBpOperation): continue elif isinstance(node, IRDataOperation): + self.emit_batchsize_code(node) continue else: raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") @@ -894,6 +898,20 @@ def emit_reducer_call(self, node: IRWeightReducer): code = f'{reducer_name}.allreduce()' return [code] + def emit_batchsize_code(self, node: IRDataOperation): + """ + Emit batch size declare + """ + signature = 'self.set_batch_size({bs})' + bs = [t.shape[dim] for t, dim in zip(node.outputs(), node.get_batch_dims())] + bs = set(bs) + if len(bs) > 1: + warnings.warn(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') + bs = list(bs)[0] if len(bs) == 1 else None + assert self.batch_size is None or self.batch_size == bs, f"Not match for batch size: {self.batch_size} != {bs}" + self.model_init_statements.append(signature.format(bs=bs)) + self.batch_size = bs + def clear(self): """ Clear buffer that used for generating code @@ -904,6 +922,8 @@ def clear(self): self.model_methods_bodies: List[List[str]] = list() # module member name self.symbols = SymbolTable() + # batch size + self.batch_size = None class ScheduleCodeGen(CodeGen): diff --git a/cube/compiler.py b/cube/compiler.py index 9992350c..347d57c9 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -188,18 +188,6 @@ def decorator(fn: Callable) -> Callable: outfile = fname, attach=True ) - - # setup batch size - if not isinstance(dataloader, SciLoopVariables): - all_batch_size = set() - dnodes = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] - for dnode in dnodes: - bs = [out.shape[dim] for out, dim in zip(dnode.outputs(), dnode.get_batch_dims())] - all_batch_size.update(bs) - if len(all_batch_size) != 1: - raise NotImplementedError(f"Heterogenous batch size {bs} is not supported") - batch_size = torch.tensor(list(all_batch_size), dtype=torch.int).cuda() - compile_end = time.time() compile_time = compile_end - compile_start print('> compile time: {:.2f} seconds'.format(compile_time)) @@ -207,17 +195,29 @@ def decorator(fn: Callable) -> Callable: if torch.distributed.is_initialized(): torch.distributed.barrier() - # reset dataloader - if torch.distributed.is_initialized(): - torch.distributed.broadcast(batch_size, src=0) - batch_size = batch_size.item() - print_each_rank(f'reseting dataloader batch size to {batch_size}') - dataloader.set_batch_size(batch_size) - # load module filename = filename.format(myrank) print_each_rank(f'loading generated module from {filename} ...') model.load_module(filename, load_content=load_content) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + # set dataloder batch size (serialize output) + bs = model.get_gen_module().get_batch_size() + if torch.distributed.is_initialized(): + for rank in range(torch.distributed.get_world_size()): + if rank == torch.distributed.get_rank(): + if bs is not None and dataloader is not None: + dataloader.set_batch_size(bs) + torch.distributed.barrier() + else: + if bs is not None and dataloader is not None: + dataloader.set_batch_size(bs) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + # load temporal schedule print_each_rank(f'loading generated schedule from {filename} ...') return _load_tschedule_fn(filename) diff --git a/cube/program.py b/cube/program.py index 0d88d512..3e82372b 100644 --- a/cube/program.py +++ b/cube/program.py @@ -9,6 +9,7 @@ from cube.graph import parser from cube.runtime.syndata import CubeDataLoader +from cube.runtime.module import CubeModule from cube.profiler.timer import print_each_rank import torch @@ -129,7 +130,7 @@ def __init__(self, model: torch.nn.Module, input_shapes): ) else: self.ir_graph = None - self._loaded_module = None + self._loaded_module: CubeModule = None def get_graph(self): return self.ir_graph diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 792627a6..852c79f4 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Optional import torch from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer @@ -14,6 +14,7 @@ def __init__(self): super().__init__() self._reducers = list() self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() + self._batch_size: Optional[int] = None def add_reducer(self, reducer: Reducer): if not isinstance(reducer, Reducer): @@ -38,6 +39,13 @@ def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: i assert hasattr(self, attr), f"{attr} is not in the module" self._fullmap[attr] = (tid, slicers, val_chunks) + def set_batch_size(self, bs: Optional[int]): + assert (bs is None) or (isinstance(bs, int) and bs > 0) + self._batch_size = bs + + def get_batch_size(self) -> Optional[int]: + return self._batch_size + def load_attr_content(self, filename: str): with torch.no_grad(): full = torch.load(filename) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 4f5ace18..e7059cd8 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -46,7 +46,8 @@ def set_batch_size(self, batch_size: int): self.batch_size = batch_size for shape, dim in zip(self.shapes, self.batch_dims): shape[dim] = batch_size - print(f'> data loader output shape change to: {self.shapes}') + rank = 0 if not torch.distributed.is_initialized() else torch.distributed.get_rank() + print(f'rank [{rank}]: > set batch size to {batch_size}. dataloader outputs change to: {self.shapes}') class SciLoopVariables(CubeDataLoader): From 8c32558eea23693004156fa7b8f1b3555d7ef682 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 16 Sep 2022 14:20:55 +0800 Subject: [PATCH 1027/1892] fix embedding partition bug --- cube/graph/function/function.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index a758b0d2..98423d5c 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -854,25 +854,8 @@ def Embedding(signature, inputs: List): start, stop = weight.indmap[0] else: start, stop = 0, weight.shape[0] - letters = iter(string.ascii_lowercase) - ishapes = [ - ShapeAnno.create_shape_str(itensor.shape, iterator=letters), - ShapeAnno.create_shape_str(weight.shape, iterator=letters) - ] - oshapes = [ishapes[0] + [ishapes[1][-1]]] - anno = OpAnno.create_op_str(ishapes, oshapes) - - # def embed_modifer(kwargs: Dict, idx, dim, num): - # import warnings - # warnings.warn('FIXME: The semantic is error when split embedding, but the computation cost is same.') - # kwargs = dict(**kwargs) - # kwargs['stop'] = kwargs['stop'] // num - # return kwargs - # rules = [TransformRule( - # [DimopSplit.R(), DimopSplit.D(0)], [DimopSplit.V()], embed_modifer - # )] - - return IRDimops(Embedding, 'embedding', signature, [anno], [itensor, weight], + annos = ['*, n+ e -> * e'] + return IRDimops(Embedding, 'embedding', signature, annos, [itensor, weight], padding_idx=padding_idx, start=start, stop=stop) From f8d0c5953076f959d9f62dac5ecba2d1dc2e7d95 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Thu, 15 Sep 2022 23:30:09 -0700 Subject: [PATCH 1028/1892] updt policy --- examples/nlp/palm/policy/mpmd.py | 51 +++++++++++++++++++------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/examples/nlp/palm/policy/mpmd.py b/examples/nlp/palm/policy/mpmd.py index ec85054e..5fc2eb8b 100644 --- a/examples/nlp/palm/policy/mpmd.py +++ b/examples/nlp/palm/policy/mpmd.py @@ -1,3 +1,4 @@ +from typing import List from cube.graph import IRGraph from cube.ir.operator import IRDataOperation, IRFwOperation @@ -48,37 +49,45 @@ def PASBranch3(graph: IRGraph, resource): return graph + +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + def PASBranch5(graph: IRGraph, resource): ''' 5 way branch ''' assert resource.ngpus == 5 + devs = list(range(resource.ngpus)) + for node in graph.nodes(): if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] + _replica(graph, node, devs) for node in graph.nodes(): if isinstance(node, IRFwOperation): - if node.name == 'embedding' or node.name == 'linear': - # data parallel - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=0, - dim=batch_dim, - num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add' or node.name == 'mean': - # replicate - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) + if node.name == 'embedding': + _tp(graph, node, devs, idx=1, dim=0) + elif node.name == 'linear': + _tp(graph, node, devs, idx=1, dim=0) + elif node.name == 'mean': + _tp(graph, node, devs, idx=0, dim=2) + elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add': + _replica(graph, node, devs) elif node.name == 'feedforward1': algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, idx=1, dim=1, num=2) @@ -101,4 +110,4 @@ def PASBranch5(graph: IRGraph, resource): else: assert False, node.name - return graph \ No newline at end of file + return graph From ed3691c9815abcf63aad4ab5b16607dcd4c6bbfc Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Fri, 16 Sep 2022 01:07:52 -0700 Subject: [PATCH 1029/1892] add zhiqi implementation --- examples/nlp/palm/policy/mpmd.py | 39 +++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/examples/nlp/palm/policy/mpmd.py b/examples/nlp/palm/policy/mpmd.py index 5fc2eb8b..9ee2d1f6 100644 --- a/examples/nlp/palm/policy/mpmd.py +++ b/examples/nlp/palm/policy/mpmd.py @@ -1,6 +1,7 @@ from typing import List from cube.graph import IRGraph from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.ir.tensor import IRSubTensor, IRFullTensor def PASBranch3(graph: IRGraph, resource): ''' @@ -66,6 +67,34 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): return sub_nodes +def convert_add_to_valmap(graph: IRGraph, add_node: IRFwOperation): + """ + Remove add node by replacing with tensor valmap + """ + + segment: IRSegment = graph.segment(add_node) + assert add_node.name == 'add' + nchunks = 0 + for itensor in add_node.inputs(): + assert isinstance(itensor, IRSubTensor) + nchunks += len(segment.producers(itensor.parent)) + ftensor: IRFullTensor = add_node.output(0).parent + vid = 0 + for itensor in add_node.inputs(): + parent = itensor.parent + for ptensor, producer in zip(segment.ptensors(parent), segment.producers(parent)): + idx = producer.outputs().index(ptensor) + new_ptensor = ftensor.select(ptensor.indmap, (vid, nchunks)) + with segment.update(producer): + producer.set_output(idx, new_ptensor) + segment.update_bwop(producer.mirror) + vid += 1 + + segment.remove(add_node) + assert add_node.mirror is not None + segment.remove(add_node.mirror) + + def PASBranch5(graph: IRGraph, resource): ''' 5 way branch @@ -86,7 +115,7 @@ def PASBranch5(graph: IRGraph, resource): _tp(graph, node, devs, idx=1, dim=0) elif node.name == 'mean': _tp(graph, node, devs, idx=0, dim=2) - elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add': + elif node.name == 'layernorm' or node.name == 'multiref': _replica(graph, node, devs) elif node.name == 'feedforward1': algo = node.algorithms('dim') @@ -107,7 +136,15 @@ def PASBranch5(graph: IRGraph, resource): graph.assign(sub_nodes[3], 3) elif node.name == 'multi_head_attention': graph.assign(node, 4) + elif node.name == 'add': + continue else: assert False, node.name + # adjust add + adds = [node for node in graph.nodes() if node.name == 'add'] + assert len(adds) == 2 + graph.assign(adds[0], 4) + convert_add_to_valmap(graph, adds[1]) + return graph From 04ca6ab34cb4087b0a7307e8f598d713d9e7648e Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Fri, 16 Sep 2022 02:49:25 -0700 Subject: [PATCH 1030/1892] fix tp bug --- examples/nlp/palm/policy/spmd.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/nlp/palm/policy/spmd.py b/examples/nlp/palm/policy/spmd.py index 257aa435..ba20e0a6 100644 --- a/examples/nlp/palm/policy/spmd.py +++ b/examples/nlp/palm/policy/spmd.py @@ -43,14 +43,6 @@ def PASMegatron(graph: IRGraph, resource): tp_size = resource.ngpus tp_devs = list(range(tp_size)) - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] - def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) @@ -64,6 +56,11 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): for dev_id, sub_node in zip(devs, sub_nodes): graph.assign(sub_node, dev_id) return sub_nodes + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + _replica(graph, node, tp_devs) + batch_dim = node.get_batch_dims()[0] for node in graph.nodes(): if isinstance(node, IRFwOperation): @@ -91,4 +88,4 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): _tp(graph, node, tp_devs, idx=0, dim=2) else: _replica(graph, node, tp_devs) - return graph \ No newline at end of file + return graph From 02a99a1e7009a52c2d8e5e8e6d6407bf426a5436 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 20 Sep 2022 03:33:38 +0000 Subject: [PATCH 1031/1892] Merged PR 1421: Gradient inference and valmap representation Changes made: 1) Gradient adjustment only happens inside graph.partition and graph.replicate. 2) Valmap in initial representation is changed by exponential increasement on nchunks. --- cube/graph/function/function.py | 12 + cube/graph/gener/gen.py | 343 ++++++++++++++------------- cube/graph/graph.py | 374 ++++++++++++++++-------------- cube/graph/parser/mapping.py | 2 + cube/graph/segment.py | 347 +++++++++++++++++++-------- cube/ir/cten.py | 23 +- cube/ir/tensor.py | 45 ++-- cube/program.py | 13 +- cube/runtime/function/function.py | 10 +- 9 files changed, 694 insertions(+), 475 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 98423d5c..93516cf1 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -762,6 +762,18 @@ def Pad(signature, inputs): return IRPad(signature, tensors, 'pad', pad=pad, mode=mode, value=value) +def Accum(signature, inputs: Tuple[IRTensor]): + """ + tensor = cube.runtime.function.accum(tensors) + """ + assert all(isinstance(t, IRTensor) for t in inputs) + signature = 'cube.runtime.function.accum' + iannos = [ShapeAnno.create_shape_str(t.shape) for t in inputs] + oannos = [copy.copy(iannos[0])] + anno = OpAnno.create_op_str(iannos, oannos) + return IRDimops(Cat, 'accum', signature, [anno], inputs) + + def Cat(signature, inputs: Tuple[List[IRTensor], int]): """ torch.cat(inputs: List[Tensor], dim: int) -> Tensor diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index c64ffa07..12ab0111 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,5 +1,5 @@ -import itertools from typing import Dict, List, Optional, Tuple, Dict +import numpy as np from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener @@ -11,7 +11,10 @@ from cube.ir.operator import IRFwOperation from cube.ir.adapter import IRAdapter, IRWeightReducer -from cube.graph.function.function import Add, Cat, Identity, MultiRef +from cube.graph.function.function import Accum, Cat, MultiRef + + +DeviceID = int class DummyInputOuput(IRFwOperation): @@ -101,8 +104,6 @@ def gen(graph: IRGraph) -> IRGraph: """ # remove anchor node graph = IRAdapterGener.remove_anchor(graph) - # update the gradient before generate adapter - graph = IRAdapterGener.update_grad(graph) # generate adapters for activation graph = IRAdapterGener.gen_activation(graph) # generate weight reducer @@ -112,15 +113,6 @@ def gen(graph: IRGraph) -> IRGraph: # print(graph.extra_repr()) return graph - @staticmethod - def update_grad(graph: IRSegment): - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): - graph.update_ftensor_bw(ftensor) - for node in graph.nodes(): - if isinstance(node, IRSegment) and node.isbw(): - IRAdapterGener.update_grad(node) - return graph @staticmethod def remove_anchor(graph: IRSegment): @@ -270,7 +262,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: IRAdapterGener.local_consumer_multiref(graph, ftensor) # print(graph.debug_tensor_map_str(ftensor)) - # print(graph.debug_tensor_map_str(ftensor.grad)) + # print(graph.mirror.debug_tensor_map_str(ftensor.grad)) # producers can be operators and graph inputs fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) @@ -364,119 +356,129 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens @return new_ftensor IRFullTensor: the new full tensor. If cannot fuse, the original ftensor. """ + if not ftensor.requires_grad: return ftensor - def like(tensor: IRSubTensor, share: Optional[IRFullTensor] = None) -> IRSubTensor: - parent = tensor.parent.like() if share is None else share - return parent.select(tensor.indmap, tensor.valmap) + devtensors: Dict[DeviceID, List[IRSubTensor]] = dict() + devops: Dict[DeviceID, List[IRCell]] = dict() - # collect device tensors - devtensors: Dict[int, List[IRSubTensor]] = dict() - # devid: old tensor -> [nodes,] - fuse_tensors: Dict[int, Dict[IRSubTensor, List[IRSubTensor]]] = dict() - tensor_map: Dict[int, Dict[IRSubTensor, IRSubTensor]] = dict() - - for tensor in graph.ptensors(ftensor): - for devid in tensor.device: + # collect producers for each device + for ptensor, producer in zip(graph.ptensors(ftensor), graph.producers(ftensor)): + for devid in ptensor.device: if devid not in devtensors: - devtensors[devid] = [] - fuse_tensors[devid] = dict() - tensor_map[devid] = dict() - devtensors[devid].append(tensor) - fuse_tensors[devid][tensor] = [tensor] - tensor_map[devid][tensor] = tensor - - nodes: List[IRFwOperation] = [] - for devid, tensors in devtensors.items(): - if len(tensors) == 1: - continue - - # repeatly search for combinable tensors - while True: - can_merge = False - out = None - node = None - for t1, t2 in itertools.combinations(tensors, 2): - catdim = t1.catdim(t2) - if catdim is not None: - t1, t2 = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] - out = t1.concat(t2, dim=catdim) - node = Cat( - 'torch.cat', - ([tensor_map[devid][t1], tensor_map[devid][t2]], catdim) - ) - can_merge = True - break - elif t1.accumable(t2): - out = t1.accum(t2) - node = Add( - 'torch.add', - [tensor_map[devid][t1], tensor_map[devid][t2]] - ) - can_merge = True - break - # each time when creats a merge node, the output will be - # updated with a new full tensor. The corresponding input - # will be set according to the previous node output - if can_merge: - tensor_map[devid][out] = like(out) - node.set_output(0, tensor_map[devid][out]) # update output to a new full tensor - tensors.remove(t1) - tensors.remove(t2) - tensors.append(out) - nodes.append(node) - node.device = devid - fuse_tensors[devid][out] = fuse_tensors[devid][t1] + fuse_tensors[devid][t2] - del fuse_tensors[devid][t1] - del fuse_tensors[devid][t2] - else: - break - - if len(nodes) == 0: return ftensor + devtensors[devid], devops[devid] = [], [] + devtensors[devid].append(ptensor) + devops[devid].append(producer) - # recompute - rcid = set(producer.recompute for producer in graph.producers(ftensor)) - rcid = list(rcid)[0] if len(rcid) == 1 else None - for node in nodes: - node.recompute = rcid + require_fusion = any(len(set(ts)) > 1 for ts in devtensors.values()) + if not require_fusion: return ftensor new_ftensor = ftensor.like() # update consumer - min_idx = min(graph.index(consumer) for consumer in graph.consumers(ftensor)) - assert len(graph.ctensors(ftensor)) == len(graph.consumers(ftensor)) for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): + itensor = new_ftensor.select(ctensor.indmap, ctensor.valmap) + igrad = new_ftensor.grad.select(ctensor.grad.indmap, ctensor.grad.valmap) with graph.update(consumer) as consumer: - consumer.set_input( - consumer.inputs().index(ctensor), - new_ftensor.select(ctensor.indmap, ctensor.valmap) - ) + idx = consumer.inputs().index(ctensor) + consumer.set_input(idx, itensor) + with graph.mirror.update(consumer.mirror) as bconsumer: + idx = bconsumer.outputs().index(ctensor.grad) + bconsumer.set_output(idx, igrad) - # insert new producer - for devid, tensors in fuse_tensors.items(): - for ptensor in tensors: - new_tensor = like(ptensor, share=new_ftensor) - if len(tensors[ptensor]) == 1: - node = Identity('', [ptensor]) - node.device = devid - node.set_output(0, new_tensor) - nodes.append(node) - else: - for node in nodes: - if node.output(0) == tensor_map[devid][ptensor]: - node.set_output(0, new_tensor) - - for node in nodes[::-1]: - assert node not in graph.nodes() - assert len(node.outputs()) == 1 - if graph.mirror is not None: - graph.finsert(node, min_idx) - else: - graph.insert(node, min_idx) + for devid in devtensors: + indmaps = [t.indmap for t in devtensors[devid]] + valmaps = [t.valmap for t in devtensors[devid]] + split_dim = len(set(indmaps)) > 1 + split_val = len(set(valmaps)) > 1 + assert not (split_dim and split_val), ( + f"Not support for simutaneously partitioning tensor dimension and tensor value.\n" + f"{graph.debug_tensor_map_str(ftensor)}" + ) - # update backward - if isinstance(ftensor.grad, IRFullTensor): - graph.update_ftensor_bw(new_ftensor) - graph.update_ftensor_bw(ftensor) + node = None + + # split dimension case + if split_dim: + catdim: int = None + for dim in range(len(ftensor.shape)): + dim_maps = [ind[dim] for ind in indmaps] + if set(len(dim_maps)) != 1: + assert catdim is None, ( + f"Not support for multi-dim partitioning on local producers.\n" + f"{graph.debug_tensor_map_str(ftensor)}" + ) + catdim = dim + assert catdim is not None + start_idx = np.array([ind[catdim][0] for ind in indmaps]) + indices = np.argsort(start_idx) + ptensors = [devtensors[devid][idx] for idx in indices] + try: + otensor = ptensors[0] + for t in ptensors[1:]: + otensor = otensor.concat(t, dim=catdim) + except Exception as e: + raise RuntimeError( + f"Device {devid}: Fail to concat local produced tensors on dimension: {catdim}\n" + f"Users can try to adjust node ordering to meet with concat order.\n" + f"{graph.debug_tensor_map_str(ftensor)}" + ) + # set concat input / output + node = Cat('torch.cat', (ptensors, catdim)) + node.set_output(0, new_ftensor.select(otensor.indmap, otensor.valmap)) + # set gradient + for idx, ptensor in enumerate(ptensors): + node.input(idx).grad = ftensor.grad.select(ptensor.indmap, (0,1)) + node.output(0).grad = new_ftensor.grad.select(otensor.indmap, (0,1)) + + # split value case + if split_val: + # reverse to meet with add order + ptensors = devtensors[devid] + try: + nchunks = [t.valmap[1] for t in ptensors] + if len(set(nchunks)) == 1: + otensor = ptensors[0].accum(ptensors[1:]) + else: + # the add order is to adapt with ordering valmap ordering: (3/4) (2/4) (0/2) + ptensors = ptensors[::-1] + otensor = ptensors[0] + for t in ptensors[1:]: + otensor = otensor.accum(t) + except Exception as e: + raise RuntimeError( + f"Device {devid}: Fail to accum local produced tensors\n" + f"Users can try to adjust node ordering to meet with accum order\n" + f"{graph.debug_tensor_map_str(ftensor)}" + ) + # set accum input / output + node = Accum('cube.runtime.accum', ptensors) + node.set_output(0, new_ftensor.select(otensor.indmap, otensor.valmap)) + # set gradient + for idx, ptensor in enumerate(ptensors): + node.input(idx).grad = ftensor.grad.select(ptensor.indmap, (0,1)) + node.output(0).grad = new_ftensor.grad.select(otensor.indmap, (0,1)) + + # no need for fusion, change the producer output to new tensor + if node is None: + for ptensor, producer in zip(devtensors[devid], devops[devid]): + otensor = new_ftensor.select(ptensor.indmap, ptensor.valmap) + ograd = new_ftensor.grad.select(otensor.grad.indmap, otensor.grad.valmap) + with graph.update(producer): + idx = producer.outputs().index(ptensor) + producer.set_input(idx, otensor) + producer.input(idx).grad = ograd + with graph.mirror.update(producer.mirror) as bproducer: + idx = bproducer.inputs().index(otensor.grad) + bproducer.set_input(idx, ograd) + else: + node.device = devid + # set recompute + rcid = set(producer.recompute for producer in devops[devid]) + rcid = list(rcid)[0] if len(rcid) == 1 else None + node.recompute = rcid + # insert + max_fid = max(graph.index(producer) for producer in devops[devid]) + graph.finsert(node, max_fid + 1) return new_ftensor @@ -496,60 +498,77 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): @param graph IRGraph @param ftensor IRFullTensor: the forward full tensor + + @return None """ - # collect to consumer tensors of each device - devtensors: Dict[int, Dict[IRSubTensor, List[IRCell]]] = dict() + if not ftensor.requires_grad: return + + devtensors : Dict[DeviceID, List[IRSubTensor]] = dict() + devops : Dict[DeviceID, List[IRCell]] = dict() + + # collect consumer of each device for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): for devid in ctensor.device: if devid not in devtensors: - devtensors[devid] = dict() - if ctensor not in devtensors[devid]: - devtensors[devid][ctensor] = [] - devtensors[devid][ctensor].append(consumer) - - # restrict each device has same subtensor - nl = '\n' - for devid in devtensors: - assert len(devtensors[devid]) <= 1, ( - "Detect that a full tensor is partitioned differently on a device.\n" - "To achieve this, need manually add multiref operator in model description.\n" - f"Full Tensor: {ftensor}\n" - f"Producers:\n{nl.join(repr(node) for node in graph.producers(ftensor))}\n" - f"Consumers:\n{nl.join(repr(node) for node in graph.consumers(ftensor))}" - ) + devtensors[devid], devops[devid] = [], [] + assert len(devtensors[devid]) == 0 or devtensors[devid][0] == ctensor, ( + f"Detect that a full tensor is partitioned differently on a device.\n" + f"To achieve this, need manually add multiref operator in model description.\n" + f"{graph.debug_tensor_map_str(ftensor)}" + ) + devtensors[devid].append(ctensor) + devops[devid].append(consumer) + + require_multiref = any(len(ops) > 1 for ops in devops.values()) + if not require_multiref: return - # add multiref forward node - multirefs: Dict[MultiRef, List[IRFwOperation]] = dict() for devid in devtensors: - for ctensor in devtensors[devid]: - consumers = devtensors[devid][ctensor] - if len(consumers) == 1: - continue - multiref = MultiRef(None, [ctensor, len(consumers)]) - multiref.infer_shape() - multiref.device = devid - ftensors = [ctensor.parent.like() for _ in range(len(consumers))] - itensors = [ft.select(ctensor.indmap, ctensor.valmap) for ft in ftensors] - for idx, itensor in enumerate(itensors): - multiref.set_output(idx, itensor) - - # update consumer - min_fidx = graph.nnodes - for itensor, consumer in zip(itensors, consumers): - with graph.update(consumer) as consumer: - idx = consumer.inputs().index(ctensor) - consumer.set_input(idx, itensor) - - # insert forward multiref - min_fidx = min(graph.index(consumer) for consumer in consumers) - if graph.mirror is not None: - graph.finsert(multiref, min_fidx) + grads: List[IRSubTensor] = [t.grad for t in devtensors[devid]] + try: + nchunks = [grad.valmap[1] for grad in grads] + if len(set(nchunks)) == 1: + accum_grad = grads[0].accum(grads[1:]) else: - graph.insert(multiref, min_fidx) - multirefs[multiref] = consumers - - if len(multirefs) > 0 and isinstance(ftensor.grad, IRFullTensor): - graph.update_ftensor_bw(ftensor) + # the add order is to adapt with ordering valmap ordering: (3/4) (2/4) (0/2) + accum_grad = grads[0] + for grad in grads[1:]: + accum_grad = accum_grad.accum(grad) + except Exception as e: + raise RuntimeError( + f"Device {devid}: Fail to accumulate local gradient: {ftensor.grad}\n" + f"Error information: {str(e)}\n" + f"Users can try:\n" + f" 1) Replicate all operators whose inputs have multi-consumed tensors\n" + f" 2) Partition all operators whose inputs have multi-consumed tensors\n" + f" 3) Mannually add cube.runtime.multiref in model description to divide replicated and partitioned groups\n" + f"{graph.debug_tensor_map_str(ftensor)}" + f"{graph.mirror.debug_tensor_map_str(ftensor.grad)}" + ) + + multiref = MultiRef(None, [devtensors[devid][0], len(grads)]) + # set input gradient + multiref.input(0).grad = accum_grad + # set output and its gradient + for idx, ctensor in enumerate(devtensors[devid]): + new_ftensor = ctensor.parent.like() + otensor = new_ftensor.select(ctensor.indmap, (0,1)) + multiref.set_output(idx, otensor) + multiref.output(idx).grad = new_ftensor.grad.select(ctensor.indmap, (0,1)) + # set corresponding consumer input and its backward + consumer = devops[devid][idx] + with graph.update(consumer): + while ctensor in consumer.inputs(): + fidx = consumer.inputs().index(ctensor) + consumer.set_input(fidx, otensor) + consumer.input(fidx).grad = new_ftensor.grad.select(ctensor.indmap, (0,1)) + with graph.mirror.update(consumer.mirror) as bconsumer: + while ctensor.grad in bconsumer.outputs(): + bidx = bconsumer.outputs().index(ctensor.grad) + bconsumer.set_output(bidx, new_ftensor.grad.select(ctensor.indmap, (0,1))) + # insert multiref + multiref.device = devid + min_fidx = min(graph.index(consumer) for consumer in devops[devid]) + graph.finsert(multiref, min_fidx) @staticmethod def fusion(graph: IRSegment) -> IRSegment: diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 8eec99b3..2e4ca693 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -122,17 +122,39 @@ def backward(self, loss: IRSubTensor): @return self IRGraph: None """ - assert loss in self.outputs() and tuple(loss.shape) == (1,), \ - f"backward should be in graph outputs and the loss is of shape [1,] (got {loss.shape})" - from cube.program import Program + # set mirror as self + self._mirror = self + + # infer gradient requirement + for node in self.nodes(): + itensors = [t for t in node.inputs() if isinstance(t, IRTensor)] + require_grad = any(t.requires_grad for t in itensors) + for otensor in node.outputs(): + if not isinstance(otensor, IRTensor): continue + if isinstance(otensor, IRSubTensor): + otensor.parent.requires_grad = require_grad + else: + otensor.requires_grad = require_grad + + # set loss gradient + assert tuple(loss.shape) == (1,), f"the loss should be of shape [1,] (got {loss.shape})" loss.parent.grad = 1.0 + + # infer gradient + for ftensor in self._ftensors: + self.infer_grad(ftensor) + # create backward node for fnode in self.nodes()[::-1]: assert not isinstance(fnode, IRSegment), "Internal Error: Segment should not appear for now" - if isinstance(fnode, IRFwOperation): - bnode: IRBpOperation = self.create_bwop(fnode) - Program().add_node(bnode) - # set program graph mirror to self - Program().mirror_as_self() + if not isinstance(fnode, IRFwOperation): continue + tensors = [t for t in fnode.inputs() + fnode.outputs() if isinstance(t, IRSubTensor)] + grads = [t.grad for t in tensors] + # no backward op generated for fnode + if all(grad is None for grad in grads): continue + # create backward op and insert to graph + bwop = self.create_bwop(fnode) + self.insert(bwop, self.nnodes) + return self @@ -141,6 +163,7 @@ def backward(self, loss: IRSubTensor): def group(self, fnodes: List[IRCell]) -> IRSegment: """! Group consecutive forward nodes into IRSegment. + Note the fnodes should not apply any transformation. TODO: update operator dependency The corresponding backward nodes will also be grouped. @@ -173,22 +196,26 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: assert maxbidx - minbidx + 1 == len(bnodes), \ f"Internal Error: backward nodes are not consecutive. maxbidx: {maxbidx}, minbidx: {minbidx}" - fsegment = fgraph.create_segment(fnodes) - # replace forward + # remove fnodes and insert fsegment + fsegment: IRSegment = fgraph.create_segment(fnodes) for fnode in fnodes: fidx = fgraph.remove(fnode) fgraph.insert(fsegment, fidx) - # replace backward + # reset fsegment gradient + for itensor in fsegment.inputs(): + fgraph.infer_grad(itensor.parent) + + # update backward if len(bnodes) > 0: - bsegment = fgraph.create_bwop(fsegment) if len(bnodes) > 0 else None + # remove backward nodes for bnode in bnodes: bidx = bgraph.remove(bnode) + # create new backward node + bnodes = [fsegment.create_bwop(fnode) for fnode in fnodes[::-1]] + # create and insert backward segment + bsegment = fgraph.create_bwop(fsegment) bgraph.insert(bsegment, bidx) - # setup gradient - self.update_bwop(bsegment) - - return fsegment # ========================== Graph Creation ======================== @@ -199,101 +226,13 @@ def from_logic_graph(nodes: List[IRCell], """ Generate IRGraph from logical graph (IRFullTensor) - Multiref will be inserted: - - e.g., original graph: - ``` - t = producer(xx) - ... - xx = consumer1(t) - ... - xx = consumer2(t) - ... - xx = consumer3(t) - ... - ``` - will be changed into: - ``` - t = producer(xx) - ... - t1, t2 = multiref(t) - xx = consumer1(t1) - ... - t3, t4 = multiref(t2) - xx = consumer2(t3) - ... - xx = consumer3(t4) - ... - ``` - """ - # handle multi-consumed tensor - consumers: Dict[IRFullTensor, List[IRCell]] = dict() - producers: Dict[IRFullTensor, IRCell] = dict() - for node in nodes: - ftensors = set() - for ftensor in node.inputs(): - # remove redundant tensors within an operator - if isinstance(ftensor, IRFullTensor) and ftensor.tid not in ftensors: - ftensors.add(ftensor.tid) - if ftensor not in consumers: - consumers[ftensor] = [] - consumers[ftensor].append(node) - for ftensor in node.outputs(): - if isinstance(ftensor, IRFullTensor): - producers[ftensor] = node - for ftensor, cnodes in consumers.items(): - if len(cnodes) == 1 or ftensor.is_attr(): continue - reftensor = ftensor - ctensor = ftensor - while len(cnodes) > 0: - consumer = cnodes.pop(0) - if len(cnodes) > 0: - itensors = [ftensor.like() for _ in range(2)] - multiref = MultiRef(None, [reftensor, 2]) - for idx, itensor in enumerate(itensors): - multiref.set_output(idx, itensor) - multiref.infer_shape() - # insert multiref right before the consumor - idx = nodes.index(consumer) - nodes.insert(idx, multiref) - ctensor, reftensor = itensors - else: - # the last consumer doesn't need multiref - ctensor = reftensor - # update consumer - while ftensor in consumer.inputs(): - idx = consumer.inputs().index(ftensor) - consumer.set_input(idx, ctensor) - - # another version to generate multiref: one for all - # for node in nodes: - # ftensors = set() - # for ftensor in node.inputs(): - # # remove redundant tensors within an operator - # if isinstance(ftensor, IRFullTensor) and ftensor._id not in ftensors: - # ftensors.add(ftensor._id) - # if ftensor not in consumers: - # consumers[ftensor] = [] - # consumers[ftensor].append(node) - # for ftensor in node.outputs(): - # if isinstance(ftensor, IRFullTensor): - # producers[ftensor] = node - # for ftensor, cnodes in consumers.items(): - # if len(cnodes) == 1 or ftensor.is_attr(): continue - # itensors = [ftensor.like() for _ in range(len(cnodes))] - # for itensor, consumer in zip(itensors, cnodes): - # while ftensor in consumer.inputs(): - # idx = consumer.inputs().index(ftensor) - # consumer.set_input(idx, itensor) - # # create and insert multiref operation - # multiref = MultiRef(None, [ftensor, len(cnodes)]) - # for idx, itensor in enumerate(itensors): - # multiref.set_output(idx, itensor) - # multiref.infer_shape() - # idx = nodes.index(producers[ftensor]) + 1 if ftensor in producers else 0 - # # idx = nodes.index(cnodes[0]) - # nodes.insert(idx, multiref) + @param nodes: nodes of the graph + @param inputs List[IRFullTensor]: graph inputs + @param outputs List[IRFullTensor]: graph outputs + @param module_name str: graph name + @return graph IRGraph + """ # instantiate graph inputs / outputs for idx, tensor in enumerate(inputs): if isinstance(tensor, IRFullTensor): @@ -307,7 +246,6 @@ def from_logic_graph(nodes: List[IRCell], # instantiate to subtensor for node in nodes: for idx, ftensor in enumerate(node.inputs()): - ftensors = set() if isinstance(ftensor, IRFullTensor): subtensor = ftensor.tosub() node.set_input(idx, subtensor) @@ -329,7 +267,9 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis The backward of the forward operation will automatically be replicated. - @param: node: Union[IRFwOperation, IRDataOperation] + @param node Union[IRFwOperation, IRDataOperation] + + @return ops List[IRCell]: the replicated operators """ if not isinstance(node, (IRFwOperation, IRDataOperation)): raise TypeError("Expected op to be forward op or data op") @@ -339,6 +279,11 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis fsegment: IRSegment = self.segment(node) # replicate fnodes = [node.replicate() for _ in range(times)] + # set gradient + for fnode in fnodes: + for rtensor, itensor in zip(fnode.inputs(), node.inputs()): + if isinstance(rtensor, IRSubTensor): + rtensor.grad = itensor.grad # insert forward for fnode in fnodes: if isinstance(node, IRFwOperation): @@ -349,15 +294,16 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis fsegment.replace(node, fnodes) # insert backward bsegment: IRSegment = fsegment.mirror - if isinstance(node.mirror, IRBpOperation): - bnodes = tuple(self.create_bwop(fnode) for fnode in fnodes[::-1]) - for bnode in bnodes: - bnode.device = node.device + if isinstance(node.mirror, IRCell): + bnodes = [node.mirror.replicate() for _ in range(times)] + for bnode, fnode in zip(bnodes, fnodes[::-1]): + IRCell.make_pair(fnode, bnode) + bnode.device = fnode.device bsegment.replace(node.mirror, bnodes) return fnodes def partition(self, node: Union[IRFwOperation, IRDataOperation], - algo: GenericDistAlgo, **config) -> Optional[List[IRCell]]: + algo: GenericDistAlgo, **config) -> List[IRCell]: """ Partition Primitive: - partition: partition a forward or data operation using algorithms. @@ -380,18 +326,58 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], @param algo GenericDistAlgo: the partition algorithm related to the node @param config Dict[str, Any]: the algorithm configuration, e.g., partition number - @return Optional[IRCell]: partitioned sub-nodes or None (fail to partition) + @return ops List[IRCell]: partitioned sub-nodes """ assert isinstance(algo, GenericDistAlgo) and node == algo.node, \ "The partition algorithm is not initialized for this node" assert isinstance(node, (IRFwOperation, IRDataOperation)), \ f"Only allow op to be forward op or data op, but got: {node}" - - fsegment: IRSegment = self.segment(node) + # get partitioned sub-nodes fnodes = algo.instantiate(**config) - assert fnodes is not None, f"Fail to partition node: {node} use algothim and config: {config}" - # update forward + assert fnodes is not None, f"Fail to partition node: {node} use algorithm and config: {config}" + + # set gradient + valmaps: Dict[IRFullTensor, ValueMap] = dict() + for t in node.inputs(): + if isinstance(t, IRSubTensor) and t.requires_grad: + valmaps[t.parent] = ValueMap(t.grad.valmap) + # set up consumers + ctensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() + consumers: Dict[IRFullTensor, List[IRCell]] = dict() + for fnode in fnodes: + for itensor in fnode.inputs(): + if not isinstance(itensor, IRSubTensor): continue + if not itensor.requires_grad: continue + if itensor.parent not in ctensors: + ctensors[itensor.parent] = [] + consumers[itensor.parent] = [] + ctensors[itensor.parent].append(itensor) + consumers[itensor.parent].append(fnode) + # set up gradient + for fnode in fnodes: + for itensor in fnode.inputs(): + if not isinstance(itensor, IRSubTensor): continue + if not itensor.requires_grad: continue + ftensor = itensor.parent + # the [::-1] only makes the valuemap to grow with execution order + cs = [c for c, t in zip(consumers[ftensor], ctensors[ftensor]) if t == itensor][::-1] + valmap = valmaps[itensor.parent].map((cs.index(fnode), len(cs))) + grad = ftensor.grad.select(itensor.indmap, valmap) + itensor.grad = grad + for otensor in fnode.outputs(): + if not isinstance(otensor, IRSubTensor): continue + if not otensor.requires_grad: + grad = None + else: + if isinstance(otensor.parent.grad, float): + grad = otensor.parent.grad + else: + grad = otensor.parent.grad.select(otensor.indmap, (0,1)) + otensor.grad = grad + + # insert forward node + fsegment: IRSegment = self.segment(node) for fnode in fnodes: if isinstance(node, IRFwOperation): fnode.recompute = node.recompute @@ -399,21 +385,17 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], fnode.comment = node.comment fnode.device = node.device fsegment.replace(node, fnodes) - # update backward - bsegment: IRSegment = fsegment.mirror - if isinstance(node.mirror, IRBpOperation): - bnodes = tuple(self.create_bwop(fnode) for fnode in fnodes[::-1]) - bsegment.replace(node.mirror, bnodes) + + # insert backward node + if isinstance(node.mirror, IRCell): + bnodes = [fsegment.create_bwop(fnode) for fnode in fnodes[::-1]] + assert isinstance(node.mirror, IRBpOperation) + assert len(bnodes) == len(fnodes) for bnode in bnodes: bnode.device = node.device - # update gradient - updated = set() - for itensor in [t for t in node.inputs() if isinstance(t, IRSubTensor)]: - for fnode in fsegment.consumers(itensor.parent): - bnode: IRBpOperation = fnode.mirror - if isinstance(bnode, IRBpOperation) and fnode.cid not in updated: - self.update_bwop(bnode) - updated.add(fnode.cid) + bsegment: IRSegment = fsegment.mirror + bsegment.replace(node.mirror, bnodes) + return fnodes ## Spatial Primitives ## @@ -598,6 +580,9 @@ def staging(self, nodes: Tuple[IRFwOperation]): The stage is a concept that is only about logical separation of nodes, it doesn't have additional constraints for device assignment. + This will keep each tensor to be only consumed once in + semantic representation. + Changes will be made: 1). Identity creation: @@ -607,30 +592,20 @@ def staging(self, nodes: Tuple[IRFwOperation]): stage 1: t1 = producer() stage 2: ... stage 3: xx = consume(t1) + xx = consume(t1) stage 4: ... stage 5: xx = consume(t1) then Identity nodes will be created for every device in stage2: stage 1: t1 = producer() stage 2: t2 = identity(t1) - stage 3: xx = consume(t2) - stage 4: t3 = identity(t2) - stage 5: xx = consume(t3) - - 2). REMOVED: Multiref Modification: - If a non-attribute tensor has multiref node to different devmeshes, - e.g., - stage 1: t1, t2 = multiref(t) - stage 2: xx = consume(t1) - stage 3: ... - stage 4: xx = consume(t2) - then the multiref will be transfered into identity operator: - stage 1: t1 = multiref(t) - stage 2: xx = consume(t1) - t2 = identity(t1) stage 3: t3 = identity(t2) - stage 4: xx = consume(t3) + xx = consume(t3) + xx = consume(t3) + stage 4: t4 = identity(t3) + stage 5: t5 = identity(t4) + xx = consume(t5) - @param starts Tuple[int]: the start index of each stage + @param starts Tuple[IRFwOperations]: the start node of each stage @return None """ assert all(isinstance(node, IRFwOperation) for node in nodes), \ @@ -669,46 +644,87 @@ def get_sid(fnode: IRCell) -> Optional[int]: return None def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: - identity = Identity('', [tensor]) - identity.infer_shape() - identity.set_output(0, identity.output(0).tosub()) - # insert forward - fidx = self.index(fstages[sid][0]) + fwop = Identity('', [tensor]) + fwop.infer_shape() + fwop.set_output(0, fwop.output(0).tosub()) if tensor.requires_grad: - self.finsert(identity, fidx) - bstages[sid].append(identity.mirror) + fwop.output(0).parent.requires_grad = True + # set input grad + igrad = tensor.parent.grad.select(tensor.indmap, tensor.valmap) + fwop.input(0).grad = igrad + # set output grad + otensor = fwop.output(0).parent + ograd = otensor.grad.select(tensor.indmap, (0,1)) + fwop.output(0).grad = ograd + # insert identity + fidx = self.index(fstages[sid][0]) + self.finsert(fwop, fidx) else: - self.insert(identity, fidx) - fstages[sid].insert(0, identity) - return identity + self.insert(fwop, fidx) + # update stage op group + fstages[sid].insert(0, fwop) + if isinstance(fwop.mirror, IRCell): + bstages[sid].append(fwop.mirror) + return fwop # create identity op for cross-stage dataflow - # the gradient flow of neighbor stages is automatically guaranteed for ftensor in self.full_tensors(): if ftensor.is_grad() or ftensor.is_attr(): continue + if len(self.consumers(ftensor)) == 0: continue + assert len(self.producers(ftensor)) <= 1, \ "The staging interface should be called before any operator partition." - if len(self.consumers(ftensor)) == 0: continue + ctensors = self.ctensors(ftensor) + if len(self.ctensors(ftensor)) > 0: + assert all(ctensor == ctensors[0] for ctensor in ctensors), ( + "The staging interface should be called before any operator partition." + ) + producer, ptensor = self.producers(ftensor)[0], self.ptensors(ftensor)[0] psid = get_sid(producer) # outside of stages, not consider if psid is None: continue + + # group consumers into stages + consumers = self.consumers(ftensor) + csids = [get_sid(consumer) for consumer in consumers] + buckets = [[] for _ in range(len(fstages))] + for idx, csid in enumerate(csids): + buckets[csid].append(consumers[idx]) + + # go through each stage to generate identity operators out = ptensor - curr_sid = psid - for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): - assert ctensor == ptensor, "The staging interface should be called before any operator partition." - csid = get_sid(consumer) - if curr_sid == csid: continue - for sid in range(curr_sid + 1, csid): - identity = insert_identity(out, sid) - out = identity.output(0) - # update consumer - with self.update(consumer) as consumer: - tidx = consumer.inputs().index(ptensor) - consumer.set_input(tidx, out) - curr_sid = csid - # update all its backward operators - self.update_ftensor_bw(ftensor.grad) + end_sid = max(csids) + 1 + for sid in range(psid + 1, end_sid): + # insert identity + op = insert_identity(out, sid) + out = op.output(0) + # calculate gradient + curr_valmap = ValueMap((0, 1)) + nconsumers = len(buckets[sid]) + fgrad = ftensor.grad + for cidx, consumer in enumerate(buckets[sid]): + if fgrad is None: + grad = None + elif isinstance(fgrad, float): + assert fgrad == 1.0, "Detect a backward tensor, but gradient can only be 1.0" + grad = fgrad + else: + valmap = curr_valmap.map((0, 2)) if cidx != nconsumers - 1 else curr_valmap + grad = fgrad.select(ptensor.indmap, valmap) + curr_valmap = curr_valmap.map((1, 2)) if cidx != nconsumers - 1 else curr_valmap + # update forward consumer + idx = consumer.inputs().index(ptensor) + ptensor = consumer.input(idx) + with self.update(consumer) as consumer: + consumer.set_input(idx, out) + consumer.input(idx).grad = grad + # update backward + if isinstance(consumer.mirror, IRCell): + with self.update(consumer.mirror) as bconsumer: + idx = bconsumer.outputs().index(ptensor.grad) + bconsumer.set_output(idx,grad ) + # grouping into segment for sid in range(len(fstages)): self.group(fstages[sid]) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 35a6257a..b6a2cafb 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -154,6 +154,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __rtemplate('multiref'): function.MultiRef, + __rtemplate('accum'): function.Accum, + #einops __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, diff --git a/cube/graph/segment.py b/cube/graph/segment.py index d8032ef0..d95a2842 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1,11 +1,13 @@ from contextlib import contextmanager from typing import Dict, Union, List, Optional, Set, Tuple -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap from cube.ir.cten import IRTensor, IRCell -from cube.ir.operator import IRFwOperation, IRBpOperation +from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation from cube.ir.adapter import IRAdapter +from cube.graph.function.function import MultiRef + class CellPosition: @@ -283,47 +285,54 @@ def ctensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: assert ftensor in self._ctensors, f"{ftensor} is not in the graph" return tuple(self._ctensors[ftensor]) - def grad(self, tensor: IRSubTensor, no_partial_overlap=False) -> IRSubTensor: - """ - Get gradient of the tensor. - - @param tensor IRSubTensor: IRSubTensor: the queried tensor - - @return gradient IRSubTensor: the gradient - """ - segment: IRSegment = self.segment(tensor.cell) - assert isinstance(tensor, IRSubTensor), "Only tensor has gradient" - fgrad = tensor.parent.grad - # None means no gradient requirement, flaot means its the loss - if fgrad is None or isinstance(fgrad, float): - tensor.grad = fgrad - return fgrad - ftensor = tensor.parent - # this tensor is consumed - if tensor in tensor.cell.inputs(): - consumer_cids = [] - for ctensor, consumer in zip(segment.ctensors(ftensor), segment.consumers(ftensor)): - if no_partial_overlap: - assert not (ctensor != tensor and ctensor.overlap(tensor)), ( - f"parital overlapping is not supported for gradient\n" - f"{self.debug_tensor_map_str(ctensor.parent)}" - ) - if ctensor == tensor and consumer.cid not in consumer_cids: - consumer_cids.append(consumer.cid) - - valmap = (consumer_cids.index(tensor.cell.cid), len(consumer_cids)) - grad = ftensor.grad.select( - indmap = tensor.indmap, - valmap = valmap - ) - # this tensor is produced - elif tensor in tensor.cell.outputs(): - grad = ftensor.grad.select( - indmap = tensor.indmap, - valmap = (0, 1), + def infer_grad(self, ftensor: IRFullTensor) -> None: + """ + Set gradient on sub-tensors of a fulltensor + + Note this can only be called when no operator transformation is + applied for this graph. + + @param ftensor IRFullTensor: the full tensor. + + @return None: gradient are set to producer/consumer tensor's .grad + """ + fgrad = ftensor.grad + # set for producer + assert len(self.producers(ftensor)) <= 1, ( + f"grad can only be set when no transformation is applied but got:\n" + f"{self.debug_tensor_map_str(ftensor)}" + ) + for ptensor, producer in zip(self.ptensors(ftensor), self.producers(ftensor)): + idx = producer.outputs().index(ptensor) + if fgrad is None: + grad = None + elif isinstance(fgrad, float): + assert fgrad == 1.0, "Detect a backward tensor, but gradient can only be 1.0" + grad = fgrad + else: + grad = fgrad.select(ptensor.indmap, (0, 1)) + producer.output(idx).grad = grad + # set for consumers + ctensors = self.ctensors(ftensor) + if len(ctensors) > 0: + assert all(ctensor == ctensors[0] for ctensor in ctensors), ( + f"grad can only be set when no transformation is applied but got:\n" + f"{self.debug_tensor_map_str(ftensor)}" ) - tensor.grad = grad - return grad + curr_valmap = ValueMap((0, 1)) + nconsumers = len(self.consumers(ftensor)) + for cidx, (ctensor, consumer) in enumerate(zip(self.ctensors(ftensor), self.consumers(ftensor))): + idx = consumer.inputs().index(ctensor) + if fgrad is None: + grad = None + elif isinstance(fgrad, float): + assert fgrad == 1.0, "Detect a backward tensor, but gradient can only be 1.0" + grad = fgrad + else: + valmap = curr_valmap.map((0, 2)) if cidx != nconsumers - 1 else curr_valmap + grad = fgrad.select(ctensor.indmap, valmap) + curr_valmap = curr_valmap.map((1, 2)) if cidx != nconsumers - 1 else curr_valmap + consumer.input(idx).grad = grad def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: dscp : str = '' @@ -343,11 +352,10 @@ def create_bwop(self, fwop: IRFwOperation) -> Union[IRBpOperation, IRBpOperation Create dummy backward operator for given forward operator """ assert isinstance(fwop, (IRFwOperation, IRSegment)), "Expected IRFwOperation" - fsegment: IRSegment = self.segment(fwop) fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] - igrads = [fsegment.grad(t) if t.requires_grad else None for t in fins] - ograds = [fsegment.grad(t) if t.requires_grad else None for t in fous] + igrads = [t.grad if t.requires_grad else None for t in fins] + ograds = [t.grad if t.requires_grad else None for t in fous] if isinstance(fwop, IRFwOperation): bwop = IRBpOperation(ograds, igrads) else: @@ -355,55 +363,6 @@ def create_bwop(self, fwop: IRFwOperation) -> Union[IRBpOperation, IRBpOperation bwop = IRSegment(bnodes, ograds, igrads) IRCell.make_pair(fwop, bwop) return bwop - - def update_bwop(self, bwop: IRCell) -> IRBpOperation: - """ - Update backward operator or a backward segment. - - This is neccessary when fwop is partitioned and reference count is changed. - - @param bwop IRBpOperation or IRSegment: the backward operation. - It can be at any hierarchy of this segemtn - - @return bwop IRBpOperation: the updated operation (inplace) - """ - assert isinstance(bwop, (IRBpOperation, IRSegment)) - if isinstance(bwop, IRSegment): - assert bwop.isbw() and (not bwop.isfw()) - bsegment: IRSegment = self.segment(bwop) - fsegment = bsegment.mirror - with bsegment.update(bwop): - fwop: Union[IRFwOperation, IRSegment] = bwop.mirror - fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] - fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] - igrads = [fsegment.grad(t) if t.requires_grad else None for t in fins] - ograds = [fsegment.grad(t) if t.requires_grad else None for t in fous] - for idx, igrad in enumerate(igrads): - bwop.set_output(idx, igrad) - # Ad-hoc fix: remove float that could be caused by loss for segment - if isinstance(bwop, IRSegment): - ograds = [grad for grad in ograds if isinstance(grad, IRSubTensor)] - for idx, ograd in enumerate(ograds): - bwop.set_input(idx, ograd) - return bwop - - def update_ftensor_bw(self, ftensor: IRFullTensor): - """ - Update all backward operators for a full tensor. - - @param ftensor IRFullTensor: the full tensor. If the full - tensor is not a gradient, will update backward operators - of ftensor.grad - - @return None - """ - fgrad = ftensor.grad if not ftensor.is_grad() else ftensor - if fgrad is None: - return - for producer in self.producers(fgrad): - self.update_bwop(producer) - for consumer in self.consumers(fgrad): - self.update_bwop(consumer) # ====================== Basic Graph manipulations ====================== @@ -573,6 +532,27 @@ def exist(self, node: IRCell) -> bool: return True return False + def select(self, name: Optional[str] = None, ntype: Optional[IRCell] = None) -> List[IRCell]: + """ + Select all the nodes (including nodes in sub-segment) that + satisfy the condition. + + @param name str: the node name + @param ntype Type: the node type + + @return nodes List[IRCell]: the nodes that have the name. + """ + nodes = [] + for node in self.nodes(flatten=True): + if name is not None: + if node.name != name: + continue + if ntype is not None: + if not isinstance(node, ntype): + continue + nodes.append(node) + return nodes + def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwOperation: """ Insert a forward node and create its backward. @@ -608,6 +588,185 @@ def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwO bsegment.insert(bwop, bidx) return fwop + # ===================== Advance Graph manipulations ================== + + def multiref(self, tensor: IRSubTensor, node_groups: List[List[IRFwOperation]]) -> IRFwOperation: + """ + Add multiref to separate nodes into different tensor alias. + Each other consumer that is not in the node_groups will be set as a group. + + @param tensor IRSubTensor: tensor. + @param node_groups List[List[IRFwOperation]]: operators that have tensor has input + """ + assert tensor.parent in self._ftensors + # add remaining consumers + node_groups = tuple(node_groups) + for consumer in self.consumers(tensor.parent): + if not any(consumer in nodes for nodes in node_groups): + node_groups = node_groups + ([consumer],) + # create new full tensors + ftensors = [tensor.parent.like() for _ in node_groups] + otensors = [ft.select(tensor.indmap, tensor.valmap) for ft in ftensors] + # update consumer + insert_idx = CellPosition((self.nnodes,)) + for fidx, nodes in enumerate(node_groups): + for node in nodes: + assert tensor in node.inputs() + idx = node.inputs().index(tensor) + with self.update(node): + node.set_input(idx, multiref.output(fidx)) + insert_idx = min(insert_idx, self.index(node)) + # create multiref + multiref = MultiRef('cube.runtime.function.multiref', [tensor, len(node_groups)]) + for idx, otensor in enumerate(otensors): + multiref.set_output(idx, otensor) + if len(tensor.device) > 0: + multiref.device = tensor.device + # set backward + if tensor.requires_grad: + # add multiref + multiref.input(0).grad = tensor.parent.grad.select(tensor.indmap, (0,1)) + for idx, output in enumerate(multiref.outputs()): + output.grad = ftensors[idx].grad.select(tensor.indmap, (0,1)) + self.finsert(multiref, insert_idx) + # update forward gradient + for ftensor in ftensors: + self.infer_grad(ftensor) + # update backward operator + for nodes in node_groups + ([multiref,]): + for fnode in nodes: + bidx = self.remove(fnode.mirror) + bnode = self.create_bwop(fnode) + self.insert(bidx, bnode) + else: + self.insert(multiref, insert_idx) + return multiref + + + def single_consume(self, one_for_all: bool = True): + """ + Transform graph to make each non-attribute tensor has up to + one consumer. Multiref nodes will be inserted. The API is useful + for cases like inference, where different consumers are partitioned + with different tensor dimensions. + + This should be called before any graph transformation. + + e.g., original graph: + + t = producer(xx) + ... + xx = consumer1(t) + ... + xx = consumer2(t) + ... + xx = consumer3(t) + ... + + If one_for_all is True, will be: + + t = producer(xx) + t1, t2, t3 = multiref(t) + ... + xx = consumer1(t1) + ... + xx = consumer2(t2) + ... + xx = consumer3(t3) + + Otherwise: + + t = producer(xx) + ... + t1, t2 = multiref(t) + xx = consumer1(t1) + ... + t3, t4 = multiref(t2) + xx = consumer2(t3) + ... + xx = consumer3(t4) + + + @param one_for_all bool: If True, + one single multiref node will be created for each fulltensor. Otherwise, + if a fulltensor has K consumers, then K-1 multiref nodes will be created. + + @return None + """ + consumers: Dict[IRFullTensor, List[IRCell]] = dict() + producers: Dict[IRFullTensor, IRCell] = dict() + if not one_for_all: + for node in self.nodes(): + ftensors = set() + for ftensor in node.inputs(): + # remove redundant tensors within an operator + if isinstance(ftensor, IRFullTensor) and ftensor.tid not in ftensors: + ftensors.add(ftensor.tid) + if ftensor not in consumers: + consumers[ftensor] = [] + consumers[ftensor].append(node) + for ftensor in node.outputs(): + if isinstance(ftensor, IRFullTensor): + producers[ftensor] = node + for ftensor, cnodes in consumers.items(): + if len(cnodes) == 1 or ftensor.is_attr(): continue + reftensor = ftensor + ctensor = ftensor + while len(cnodes) > 0: + consumer = cnodes.pop(0) + if len(cnodes) > 0: + itensors = [ftensor.like() for _ in range(2)] + multiref = MultiRef(None, [reftensor, 2]) + for idx, itensor in enumerate(itensors): + multiref.set_output(idx, itensor) + multiref.infer_shape() + # insert multiref right before the consumor + idx = self.index(consumer) + # require backward + if any(itensor.requires_grad for itensor in node.inputs()): + self.finsert(multiref, idx) + else: + self.insert(multiref, idx) + ctensor, reftensor = itensors + else: + # the last consumer doesn't need multiref + ctensor = reftensor + # update consumer + while ftensor in consumer.inputs(): + idx = consumer.inputs().index(ftensor) + consumer.set_input(idx, ctensor) + else: + for node in self.nodes(): + ftensors = set() + for ftensor in node.inputs(): + # remove redundant tensors within an operator + if isinstance(ftensor, IRFullTensor) and ftensor._id not in ftensors: + ftensors.add(ftensor._id) + if ftensor not in consumers: + consumers[ftensor] = [] + consumers[ftensor].append(node) + for ftensor in node.outputs(): + if isinstance(ftensor, IRFullTensor): + producers[ftensor] = node + for ftensor, cnodes in consumers.items(): + if len(cnodes) == 1 or ftensor.is_attr(): continue + itensors = [ftensor.like() for _ in range(len(cnodes))] + for itensor, consumer in zip(itensors, cnodes): + while ftensor in consumer.inputs(): + idx = consumer.inputs().index(ftensor) + consumer.set_input(idx, itensor) + # create and insert multiref operation + multiref = MultiRef(None, [ftensor, len(cnodes)]) + for idx, itensor in enumerate(itensors): + multiref.set_output(idx, itensor) + multiref.infer_shape() + idx = self.index(producers[ftensor]) + 1 if ftensor in producers else 0 + # idx = nodes.index(cnodes[0]) + if any(itensor.requires_grad for itensor in node.inputs()): + self.finsert(multiref, idx) + else: + self.insert(multiref, idx) + # ====================== Graph Generations ============================ @staticmethod diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 4ffba98b..82060b6d 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -51,7 +51,7 @@ def __init__(self, self.name: str = name self.signature = signature - self._device = list() + self._device: Tuple[int] = () # source tensors self._inputs: List[Optional[IRTensor]] = [None,] * input_length @@ -83,8 +83,8 @@ def cid(self) -> int: return self._id @property - def device(self): - return copy.copy(self._device) + def device(self) -> Tuple[int]: + return self._device @device.setter def device(self, device_id: Union[int, List[int]]): @@ -92,10 +92,10 @@ def device(self, device_id: Union[int, List[int]]): Set the operation device. """ if isinstance(device_id, int): - device_id = [device_id] + device_id = (device_id,) if not all([isinstance(devid, int) for devid in device_id]): raise KeyError("Require device Union[int, List[int]]") - self._device = copy.copy(list(device_id)) + self._device = tuple(device_id) @property def mirror(self): @@ -450,7 +450,7 @@ def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): self._is_grad: bool = False # tensor gradient - self._requires_grad: bool = True + self._requires_grad: bool = False self._grad: Optional[Union[IRTensor, float]] = None @property @@ -607,14 +607,17 @@ def nelement(self) -> int: cnt *= num return cnt - def backward(self) -> IRCell: + def backward(self) -> None: """ Autograd backward on the tensor - @return graph IRGraph: the forward + backward graph + The backward will apply on the program graph + + @return None """ - return self.cell.backward(self) - + from cube.program import Program + graph = Program().get_graph() + return graph.backward(self) def __repr__(self): dscp = f'Tensor(id={self._id}, shape={self.shape}, device={self.device})' diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index a9e6793e..c110f850 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -26,7 +26,7 @@ from typing import List, Optional, Union, Tuple, NewType, Dict -from cube.ir.cten import IRCell, IRTensor +from cube.ir.cten import IRTensor from cube.ir.dtype import IRDType StartEnd = NewType('[start:end)', Tuple[int, int]) @@ -256,7 +256,7 @@ class IRFullTensor(IRTensor): the sequentail execution order by its graph. """ - def __init__(self, shape=None, name=None, requires_grad=True, dtype=IRDType.unknown): + def __init__(self, shape=None, name=None, requires_grad=False, dtype=IRDType.unknown): super().__init__(shape, name, dtype) @@ -301,7 +301,7 @@ def grad(self, val: Optional[Union[IRTensor, float]]): assert val.shape == self.shape assert val.is_attr() == self.is_attr() else: - assert val is None + assert val is None, "The FullTensor doesn't require grad but is assigned with a grad." self._grad = val @property @@ -309,15 +309,17 @@ def requires_grad(self): return self._requires_grad @requires_grad.setter - def requires_grad(self, val: bool): - self._requires_grad = val - if val and self.grad is None: - grad = IRFullTensor( - self.shape, 'g' + self.name, - requires_grad=False, dtype=self.dtype - ).as_grad(self.is_attr()) - self._grad = grad - elif not val and self.grad is not None: + def requires_grad(self, req_grad: bool): + if req_grad: + self._requires_grad = True + if self._grad is None: + grad = IRFullTensor( + self.shape, 'g' + self.name, + requires_grad=False, dtype=self.dtype + ).as_grad(self.is_attr()) + self._grad = grad + else: + self._requires_grad = False self._grad = None @property @@ -402,7 +404,11 @@ def tosub(self): return sub_tensor def __repr__(self): - dscp = f'FullTensor(id={self._id}, shape={self.shape}, device={self.device})' + dscp = f'FullTensor(id={self._id}, shape={self.shape})' + return dscp + + def extra_repr(self) -> str: + dscp = f'FullTensor(id={self._id}, shape={self.shape}, req_grad={self.requires_grad}, is_param={self.is_param()}, is_buff={self.is_buffer()}, is_grad={self.is_grad()})' return dscp @@ -554,7 +560,15 @@ def accumable(self, tensors: Union[IRTensor, List[IRTensor]]) -> bool: return False if any(t.valmap[1] != self.valmap[1] for t in tensors): return False - return self.valmap[1] % (len(tensors) + 1) == 0 + if self.valmap[1] % (len(tensors) + 1) != 0: + return False + # consecutive + cids = tuple(t.valmap[0] for t in [self] + tensors) + if len(set(cids)) != len(cids) or max(cids) - min(cids) + 1 != len(cids): + return False + if min(cids) % len(cids) != 0: + return False + return True def accum(self, tensors: Union[IRTensor, List[IRTensor]]) -> IRTensor: """! @@ -564,11 +578,10 @@ def accum(self, tensors: Union[IRTensor, List[IRTensor]]) -> IRTensor: @param: tensors Union[IRTensor, List[IRTensor]] @return tensor IRSubTensor: accumulated tensor """ + # print(f'try accuming: {self.extra_repr()} and {tensors.extra_repr()}') tensors: List[IRSubTensor] = [tensors,] if isinstance(tensors, IRSubTensor) else tensors assert self.accumable(tensors), "Not accumable" nreduce = len(tensors) + 1 - assert self.valmap[1] % nreduce == 0 - # TODO: make accum more robust cid = min(t.valmap[0] for t in [self] + tensors) // nreduce valmap = (cid, self.valmap[1] // nreduce) indmap = self.indmap diff --git a/cube/program.py b/cube/program.py index 3e82372b..1ca0f409 100644 --- a/cube/program.py +++ b/cube/program.py @@ -39,18 +39,7 @@ def add_nodes(self, nodes: List[IRCell]): for node in nodes: self.add_node(node) - def get_graph(self): - has_bp = any(isinstance(node, IRBpOperation) for node in self.instance._graph.nodes()) - if not has_bp: - for ftensor in self.instance._graph.full_tensors(): - ftensor.requires_grad = False - for node in self.instance._graph.nodes(flatten=True): - for itensor in node.inputs(): - if isinstance(itensor, IRSubTensor): - itensor.grad = None - for otensor in node.outputs(): - if isinstance(otensor, IRSubTensor): - otensor.grad = None + def get_graph(self) -> IRGraph: return self.instance._graph def set_output(self, outputs: List[IRTensor]): diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 1abb4ccb..b80e9ae3 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -20,8 +20,14 @@ def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: """ identity forward. Create multiple same tensor. """ - assert times > 1, "multiref can only be used for num of tensor >= 2" - return tuple([tensor] * times) + return tensor if times == 1 else tuple([tensor] * times) + + +def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: + """ + accumulate tensors in to one tensor + """ + return torch.sum(torch.stack(tensors, dim=0), dim=0) def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], From cb86055b71424caa84969fd9d76a9c4d59f79051 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Mon, 19 Sep 2022 23:55:32 -0700 Subject: [PATCH 1032/1892] save work, current time 373 ms --- examples/nlp/palm/palm.py | 6 +- examples/nlp/palm/policy/mpmd.py | 132 +++++++++++++++---------------- 2 files changed, 69 insertions(+), 69 deletions(-) diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index d74fc4e5..505b155c 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -227,8 +227,9 @@ def get_mask(self, n, device): return mask def forward(self, in_x): - bs, n, device = in_x.shape[0], in_x.shape[1], in_x.device + in_x = cube.runtime.function.identity(in_x) + residual = in_x # pre layernorm x = self.norm(in_x) @@ -239,8 +240,7 @@ def forward(self, in_x): ff2 = feedforward2(x, self.ff_proj2) ff_out = feedforward3(ff1, ff2, self.ff_proj3) - return in_x + attn_out + ff_out - + return attn_out + ff_out + residual class PaLM(nn.Module): diff --git a/examples/nlp/palm/policy/mpmd.py b/examples/nlp/palm/policy/mpmd.py index 9ee2d1f6..d90c8522 100644 --- a/examples/nlp/palm/policy/mpmd.py +++ b/examples/nlp/palm/policy/mpmd.py @@ -71,28 +71,43 @@ def convert_add_to_valmap(graph: IRGraph, add_node: IRFwOperation): """ Remove add node by replacing with tensor valmap """ - - segment: IRSegment = graph.segment(add_node) assert add_node.name == 'add' - nchunks = 0 - for itensor in add_node.inputs(): - assert isinstance(itensor, IRSubTensor) - nchunks += len(segment.producers(itensor.parent)) - ftensor: IRFullTensor = add_node.output(0).parent - vid = 0 + ptensors, producers = [], [] for itensor in add_node.inputs(): - parent = itensor.parent - for ptensor, producer in zip(segment.ptensors(parent), segment.producers(parent)): - idx = producer.outputs().index(ptensor) - new_ptensor = ftensor.select(ptensor.indmap, (vid, nchunks)) - with segment.update(producer): - producer.set_output(idx, new_ptensor) - segment.update_bwop(producer.mirror) - vid += 1 - - segment.remove(add_node) - assert add_node.mirror is not None - segment.remove(add_node.mirror) + iptensors = graph.ptensors(itensor.parent) + assert len(set(t.valmap for t in iptensors)) == len(iptensors) + ptensors += iptensors + producers += graph.producers(itensor.parent) + ftensor = add_node.output(0).parent + for idx, (ptensor, producer) in enumerate(zip(ptensors, producers)): + fidx = producer.outputs().index(ptensor) + bidx = producer.mirror.inputs().index(ptensor.grad) + ptensor = ftensor.select(ptensor.indmap, (idx, len(producers))) + ptensor.grad = ftensor.grad.select(ptensor.indmap, (0,1)) + with graph.update(producer): + producer.set_output(fidx, ptensor) + with graph.mirror.update(producer.mirror) as bnode: + bnode.set_input(bidx, ptensor.grad) + graph.remove(add_node) + graph.mirror.remove(add_node.mirror) + + +def flatten_branch_grad(graph: IRGraph, ftensor: IRFullTensor): + """ + Flatten valmap for different branches. + """ + assert ftensor.requires_grad + ctensors = graph.ctensors(ftensor) + consumers = graph.consumers(ftensor) + # same tinput ensor + assert all(ctensor == ctensors[0] for ctensor in ctensors) + # different gradient (no replicate) + assert len(set(ctensor.grad.valmap for ctensor in ctensors)) == len(ctensors) + for idx, (consumer, ctensor) in enumerate(zip(consumers, ctensors)): + with graph.mirror.update(consumer.mirror) as bnode: + tidx = bnode.outputs().index(ctensor.grad) + ctensor.grad = ftensor.grad.select(ctensor.indmap, (idx, len(ctensors))) + bnode.set_output(tidx, ctensor.grad) def PASBranch5(graph: IRGraph, resource): @@ -100,51 +115,36 @@ def PASBranch5(graph: IRGraph, resource): 5 way branch ''' assert resource.ngpus == 5 - devs = list(range(resource.ngpus)) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - _replica(graph, node, devs) - - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - if node.name == 'embedding': - _tp(graph, node, devs, idx=1, dim=0) - elif node.name == 'linear': - _tp(graph, node, devs, idx=1, dim=0) - elif node.name == 'mean': - _tp(graph, node, devs, idx=0, dim=2) - elif node.name == 'layernorm' or node.name == 'multiref': - _replica(graph, node, devs) - elif node.name == 'feedforward1': - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=1, dim=1, num=2) - graph.assign(sub_nodes[0], 0) - graph.assign(sub_nodes[1], 1) - elif node.name == 'feedforward2': - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=1, dim=1, num=2) - graph.assign(sub_nodes[0], 2) - graph.assign(sub_nodes[1], 3) - elif node.name == 'feedforward3': - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=2, dim=0, num=4) - graph.assign(sub_nodes[0], 0) - graph.assign(sub_nodes[1], 1) - graph.assign(sub_nodes[2], 2) - graph.assign(sub_nodes[3], 3) - elif node.name == 'multi_head_attention': - graph.assign(node, 4) - elif node.name == 'add': - continue - else: - assert False, node.name - - # adjust add - adds = [node for node in graph.nodes() if node.name == 'add'] + for node in graph.select(ntype=IRDataOperation): + _replica(graph, node, devs) + for node in graph.select(name='embedding'): + _tp(graph, node, devs, idx=1, dim=0) + for node in graph.select(name='linear'): + _tp(graph, node, devs, idx=1, dim=0) + for node in graph.select(name='mean'): + _tp(graph, node, devs, idx=0, dim=2) + for node in graph.select(name='layernorm'): + _replica(graph, node, devs) + for node in graph.select(name='feedforward1'): + _tp(graph, node, [0, 1], idx=1, dim=1) + for node in graph.select(name='feedforward2'): + _tp(graph, node, [2, 3], idx=1, dim=1) + for node in graph.select(name='feedforward3'): + _tp(graph, node, [0, 1, 2, 3], idx=2, dim=0) + for node in graph.select(name='multi_head_attention'): + graph.assign(node, 4) + for node in graph.select(name='identity'): + _replica(graph, node, devs) + adds = tuple(graph.select(name='add')) assert len(adds) == 2 - graph.assign(adds[0], 4) - convert_add_to_valmap(graph, adds[1]) - - return graph + # graph.assign(adds[0], 4) + convert_add_to_valmap(graph, adds[0]) + _replica(graph, adds[1], devs) + # convert_add_to_valmap(graph, adds[1]) + for node in graph.select('feedforward1'): + ftensor = node.input(0).parent + break + flatten_branch_grad(graph, ftensor) + print(graph.extra_repr()) + return graph \ No newline at end of file From 656ac2b710842d1029f8d12eee5d2836864ba51b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 21 Sep 2022 21:09:12 +0800 Subject: [PATCH 1033/1892] fix multiref all-partition adapter generation --- cube/graph/gener/gen.py | 62 ++++++++++++++++++++--------- cube/graph/gener/utils.py | 84 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 18 deletions(-) create mode 100644 cube/graph/gener/utils.py diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 12ab0111..c369c970 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -3,6 +3,7 @@ from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener +import cube.graph.gener.utils as utils from cube.graph.graph import IRGraph from cube.graph.segment import IRSegment @@ -36,7 +37,9 @@ def __init__(self, tensor: IRSubTensor, device: int, def create_dummy(segment: IRSegment) -> List[IRFwOperation]: """ - Create dummy operators that + Create dummy operators segment inputs and outputs. + The backward operator is also inserted. + 1) produce segment input tensors 2) consume segment output tensors @@ -46,19 +49,40 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: """ devices = segment.device fwops = [] - for devid in devices: - for tensor in segment.inputs(): - if not isinstance(tensor, IRSubTensor): continue - assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = DummyInputOuput(tensor, devid, is_output=True) - segment.insert(fwop, 0) - fwops.append(fwop) - for tensor in segment.outputs(): - if not isinstance(tensor, IRSubTensor): continue - assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = DummyInputOuput(tensor, devid, is_input=True) - segment.insert(fwop, segment.nnodes) - fwops.append(fwop) + + # create inputs + for tensor in segment.inputs(): + if not isinstance(tensor, IRSubTensor): continue + assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" + fwop = DummyInputOuput(tensor, 0, is_output=True) + for devid in devices: + fop = fwop.replicate() + fop.device = devid + if tensor.requires_grad: + fop.output(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) + segment.finsert(fop, 0) + else: + segment.insert(fop, 0) + fwops.append(fop) + + # create outputs + for tensor in segment.outputs(): + if not isinstance(tensor, IRSubTensor): continue + assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" + fwop = DummyInputOuput(tensor, 0, is_input=True) + for devid in devices: + fop = fwop.replicate() + fop.device = devid + if tensor.requires_grad and segment.mirror != segment: + if isinstance(tensor.grad, float): + fop.input(0).grad = tensor.grad + else: + fop.input(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) + segment.finsert(fop, segment.nnodes) + else: + segment.insert(fop, segment.nnodes) + fwops.append(fop) + return fwops @@ -246,10 +270,9 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: return True fdummies = create_dummy(graph) + bdummies = [fwop.mirror for fwop in fdummies if fwop.mirror is not None] bgraph: Optional[IRSegment] = graph.mirror - bdummies = create_dummy(bgraph) if isinstance(bgraph, IRSegment) else [] - skip_grads = [t.parent for t in graph.inputs() + graph.outputs() if isinstance(t, IRSubTensor)] # generate adapter for inter-segments # FIXME: assume producers and consumers can run in parallel for ftensor in graph.full_tensors(): @@ -257,6 +280,9 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: if ftensor.is_param() or ftensor.is_grad(): continue + # flatten gradient + utils.flatten_grad(graph, ftensor) + # optimization: local fusion / multiref on producer / consumer ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) IRAdapterGener.local_consumer_multiref(graph, ftensor) @@ -274,7 +300,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bproducers, bptensors = [], [] bconsumers, bctensors = [], [] - if (ftensor not in skip_grads) and isinstance(ftensor.grad, IRFullTensor): + if isinstance(ftensor.grad, IRFullTensor): bproducers, bptensors = bgraph.producers(ftensor.grad), bgraph.ptensors(ftensor.grad) bptensors = expand_devices(bptensors, producer=True) assert all(len(ptensor.device) == 1 for ptensor in bptensors), ( @@ -523,7 +549,7 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): if not require_multiref: return for devid in devtensors: - grads: List[IRSubTensor] = [t.grad for t in devtensors[devid]] + grads: List[IRSubTensor] = [t.grad for t in devtensors[devid]][::-1] try: nchunks = [grad.valmap[1] for grad in grads] if len(set(nchunks)) == 1: diff --git a/cube/graph/gener/utils.py b/cube/graph/gener/utils.py new file mode 100644 index 00000000..1850c22f --- /dev/null +++ b/cube/graph/gener/utils.py @@ -0,0 +1,84 @@ +""" +Utilities for gradient modification +""" +from typing import Dict, List +from cube.graph import IRGraph +from cube.graph.segment import IRSegment +from cube.ir.operator import IRFwOperation +from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap + + + +def convert_add_to_valmap(graph: IRGraph, add_node: IRFwOperation): + """ + Remove add node by replacing with tensor valmap + + @param graph IRGraph: the program + @param add_node IRFwOperation: the add forward operation + """ + assert add_node.name == 'add' + ptensors, producers = [], [] + for itensor in add_node.inputs(): + iptensors = graph.ptensors(itensor.parent) + assert len(set(t.valmap for t in iptensors)) == len(iptensors) + ptensors += iptensors + producers += graph.producers(itensor.parent) + ftensor = add_node.output(0).parent + for idx, (ptensor, producer) in enumerate(zip(ptensors, producers)): + fidx = producer.outputs().index(ptensor) + bidx = producer.mirror.inputs().index(ptensor.grad) + ptensor = ftensor.select(ptensor.indmap, (idx, len(producers))) + ptensor.grad = ftensor.grad.select(ptensor.indmap, (0,1)) + with graph.update(producer): + producer.set_output(fidx, ptensor) + with graph.mirror.update(producer.mirror) as bnode: + bnode.set_input(bidx, ptensor.grad) + graph.remove(add_node) + graph.mirror.remove(add_node.mirror) + + +def flatten_grad(graph: IRSegment, ftensor: IRFullTensor): + """ + Reset gradient for consumers that are different (no replica) + Gradient valuemap will be flatten iter-devices, e.g.,(0,3), (1,3), (2,3) + Gradient valuemap will be exponent intra-devices, e.g., (0,2), (2,4), (3,4) + + @param graph IRGraph: the graph + @param ftensor IRFullTensor: the fulltensor + + @return None: this is an inplacement update. + """ + if not isinstance(ftensor.grad, IRFullTensor): return + + grads = [t.grad for t in graph.ctensors(ftensor)] + # require each consumer is a different operator (no replica) + if len(set(grads)) != len(grads): return + + # group consumers by same tensor and same device + devtensors : Dict[IRSubTensor, Dict[int, List[IRFwOperation]]] = dict() + for ctensor in graph.ctensors(ftensor): + devtensors[ctensor] = dict() + for ctensor in graph.ctensors(ftensor): + if len(ctensor.device) > 1: return + devtensors[ctensor][ctensor.device[0]] = [] + for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): + devid = ctensor.device[0] + devtensors[ctensor][devid].append(consumer) + + # setup gradient + for ctensor in devtensors: + nchunks = len(devtensors[ctensor]) + for vid, consumers in enumerate(devtensors[ctensor].values()): + curr_valmap = ValueMap((vid, nchunks)) + for cidx, consumer in enumerate(consumers): + valmap = curr_valmap.map((0, 2)) if cidx != len(consumers) - 1 else curr_valmap + grad = ftensor.grad.select(ctensor.indmap, valmap) + # update consumer and its mirror node + fidx = consumer.inputs().index(ctensor) + assert consumer.mirror is not None, consumer + bidx = consumer.mirror.outputs().index(consumer.input(fidx).grad) + consumer.input(fidx).grad = grad + with graph.mirror.update(consumer.mirror) as bnode: + bnode.set_output(bidx, grad) + # update current valmap + curr_valmap = curr_valmap.map((1, 2)) if cidx != len(consumers) - 1 else curr_valmap From 0e41730b6060ee6931c53f95c94e6efcb66a9738 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 21 Sep 2022 21:09:41 +0800 Subject: [PATCH 1034/1892] fix typo --- tests/test_rvd_prim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rvd_prim.py b/tests/test_rvd_prim.py index 0ecf66f6..9442c8cb 100644 --- a/tests/test_rvd_prim.py +++ b/tests/test_rvd_prim.py @@ -89,7 +89,7 @@ def prim_bw(prim: Callable, bandwidth: Callable, ranks, size, warmup=100, profil msg_size = len(ranks) * tensor.nelement() * 4 // 1024 // 1024 # MB algo_bw, bus_bw = bandwidth(tensor, ranks, span) print_each_rank( - '{} msg {} : MBwall-time(ms) algo-bw(GB/s) bus-bw(GB/s) {:.2f} {:.2f} {:.2f}'.format( + '{} msg {} MB | wall-time(ms) algo-bw(GB/s) bus-bw(GB/s) {:.2f} {:.2f} {:.2f}'.format( prim.__name__, msg_size, span*1000, algo_bw, bus_bw ), rank_only=0 ) From 8d32ab7c88af1701ecaf859ee5af9cbe8a63ac1c Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Wed, 21 Sep 2022 22:25:18 +0800 Subject: [PATCH 1035/1892] pull main and save work --- examples/alphafold2/evoformer.py | 50 +++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/examples/alphafold2/evoformer.py b/examples/alphafold2/evoformer.py index 704695c6..42ac0d95 100644 --- a/examples/alphafold2/evoformer.py +++ b/examples/alphafold2/evoformer.py @@ -1,5 +1,11 @@ import torch import math + +from cube.profiler import CudaTimer + +from torch.utils.checkpoint import checkpoint + + """ [bs, s, r, cm] -> [bs, s, r, cm] @@ -253,13 +259,20 @@ def __init__(self, def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): - msa_repr = msa_repr + MSARowAttentionWithPairBias( + # msa_repr = msa_repr + MSARowAttentionWithPairBias( + # self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, + # self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, + # self.msa_head, self.c, self.scale) + msa_repr = msa_repr + checkpoint(MSARowAttentionWithPairBias, self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) - msa_repr = msa_repr + MSAAttention( + # msa_repr = msa_repr + MSAAttention( + # self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, + # self.col_out_proj, None, self.msa_head, self.c, self.scale) + msa_repr = msa_repr + checkpoint(MSAAttention, self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, self.col_out_proj, None, self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) @@ -306,15 +319,36 @@ def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): return msa_repr, pair_repr +def train_iter(model, msa, pair): + bs = msa.size() + new_msa, new_pair = model(msa, pair) + loss = torch.sum(new_msa) + torch.sum(new_pair) + loss.backward() -def test(): +def test(dev): bs, s, r, cm, cz = 1, 128, 256, 256, 128 - model = Evoformer(s, cm, cz) + model = Evoformer(s, cm, cz).to(dev) - msa = torch.randn(bs, s, r, cm) - pair = torch.randn(bs, r, r, cz) + msa = torch.randn(bs, s, r, cm).to(dev) + pair = torch.randn(bs, r, r, cz).to(dev) - new_msa, new_pair = model(msa, pair) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + warm_up = 20 + iter_num = 64 + CudaTimer(enable=False).warmup() + + for i in range(iter_num): + if i >= warm_up: + CudaTimer(enable=True).start('e2e') + train_iter(model, msa, pair) + optimizer.step() + optimizer.zero_grad() + if i >= warm_up: + CudaTimer().stop('e2e') + + print(CudaTimer().duration(iter_num - warm_up, field_name='e2e'), 'ms') + print(torch.cuda.memory_summary(dev)) -test() +test(torch.device('cuda:0')) From 53f10fa6324a6ca5b6b1ac943ebb8028573f211f Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 21 Sep 2022 19:47:11 -0700 Subject: [PATCH 1036/1892] add multi2ref func --- examples/nlp/palm/palm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py index 505b155c..a75c370d 100644 --- a/examples/nlp/palm/palm.py +++ b/examples/nlp/palm/palm.py @@ -152,6 +152,10 @@ def feedforward2(x: torch.Tensor, proj: torch.Tensor): def feedforward3(x: torch.Tensor, y: torch.Tensor, proj: torch.Tensor): return torch.matmul(x * y, proj) +@cube.graph.parser.register('* -> *, *', name='multi2ref') +def multi2ref(x: torch.Tensor): + return (x, x) + class PaLMLayer(nn.Module): From 2bd8075e7163ccb0deb9174d7b2e25c6b0664bf8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 22 Sep 2022 11:17:47 +0800 Subject: [PATCH 1037/1892] remove useless component --- cube/search/__init__.py | 0 cube/search/iterator.py | 76 ---- cube/search/sampler.py | 368 ------------------- cube/tetris/composer.py | 772 ---------------------------------------- cube/tetris/solver.py | 367 ------------------- 5 files changed, 1583 deletions(-) delete mode 100644 cube/search/__init__.py delete mode 100644 cube/search/iterator.py delete mode 100644 cube/search/sampler.py delete mode 100644 cube/tetris/composer.py delete mode 100644 cube/tetris/solver.py diff --git a/cube/search/__init__.py b/cube/search/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/search/iterator.py b/cube/search/iterator.py deleted file mode 100644 index 1b698a00..00000000 --- a/cube/search/iterator.py +++ /dev/null @@ -1,76 +0,0 @@ -from itertools import combinations -from typing import Any, List - - -def comb_iter(candidates: List, pick_num: int): - """ - combination pickers - """ - return combinations(candidates, pick_num) - - -def otho_iter(slots: List[List[Any]]): - """ - othogonal pickers - - item for each slot can be randomly selected - """ - if len(slots) == 0: - yield [] - return - slot = slots[0] - if len(slots) == 1: - for item in slot: - yield [item] - else: - slots = slots[1:] - for item in slot: - for res in otho_iter(slots): - yield [item] + res - return - - -def factorization(K: int, num=1): - """ - Decompose K into `depth` numbers that - a1 * a2 * ... * a_depth = K - ($\prod\limits_{i=1}^depth a_i = K$) - - Yield: - List[int] - """ - if num == 1: - yield [K] - else: - for i in range(1, K+1): - if K % i == 0: - for res in factorization(K // i, num-1): - yield [i] + res - -def diff_balls_diff_boxes(nballs: int, nboxes: int, remain = None, placement = None): - balls_per_box = nballs // nboxes - if placement is None and remain is None: - # placement[ball_id] = box_id - placement = [] - # remain slots: remain_slots[box_id] = int - remain = [balls_per_box] * nboxes - if len(placement) == nballs: - yield placement - for box_id, remain_balls in enumerate(remain): - if remain_balls > 0: - placement.append(box_id) - remain[box_id] -= 1 - for seq in diff_balls_diff_boxes(nballs, nboxes, remain, placement): - yield seq - remain[box_id] += 1 - placement = placement[:-1] - - - -if __name__ == '__main__': - - # for seq in otho_iter([[1,2,3], [4,5], [6,7,8]]): - # print(seq) - - for seq in factorization(8, 2): - print(seq) \ No newline at end of file diff --git a/cube/search/sampler.py b/cube/search/sampler.py deleted file mode 100644 index 41e32b4e..00000000 --- a/cube/search/sampler.py +++ /dev/null @@ -1,368 +0,0 @@ -""" -Micro-batch sampler for scheduling search -""" -from typing import Callable, Dict, List, Tuple -from cube.graph.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRBpOperation -from cube.ir.cten import IRCell -from cube.execplan import ExecutionPlan - -from multiprocessing import Pool -import numpy as np -import time, copy, math - - -class Estimator: - """ - A node cost is represented as (mem_weight, mem_activation, exec_time) - """ - @staticmethod - def taging(graph: IRGraph): - for node in graph.nodes(): - # tag: (mem_weight, mem_activation, span) - if isinstance(node, IRFwOperation): - node.cost = (0, 1, 1) - elif isinstance(node, IRBpOperation): - node.cost = (0, -1, 2) - else: - node.cost = (0, 0, 0) - - @staticmethod - def map2mem(node: IRCell): - if node.cost is not None: - mem_w, mem_a, span = node.cost - else: - mem_w, mem_a, span = 0, 0, 0 - return mem_w + mem_a - - @staticmethod - def map2time(node: IRCell): - if node.cost is not None: - mem_w, mem_a, span = node.cost - else: - mem_w, mem_a, span = 0, 0, 0 - return span - - -class Sampler: - """ - Schedule sampler - """ - @staticmethod - def sample(micro_seqs: List[List[IRCell]], n_microbatch: int, n_stage: int, n_device: int, - ssampler: Callable, tsampler: Callable, wlimits: int, alimits: int): - assert len(micro_seqs) == n_microbatch - for seq in micro_seqs: - assert len(seq) // 2 == n_stage - graph = IRGraph([], [], [], 'search') - flatten_nodes = list() - for seq in micro_seqs: - flatten_nodes += seq - graph = IRGraph(flatten_nodes, [], [], 'search') - # graph._nodes = flatten_nodes - for sidx, placements in enumerate(ssampler(n_microbatch, n_stage, n_device)): - print('seraching placement:\n', placements) - # assign to device - for mid in range(n_microbatch): - for devid, fnode in zip(placements[mid], micro_seqs[mid]): - graph.assign(fnode, devid) - - # pruning: add dependecies for micro-batches with same device assignment - # this pruning guarantees the optimal - # graph.reset_dependency() - # same_microbatch = dict() - # for mid, placement in enumerate(placements): - # placement = tuple(placement) - # if placement not in same_microbatch: - # same_microbatch[placement] = list() - # same_microbatch[placement].append(mid) - # for placement, mids in same_microbatch.items(): - # if len(mids) > 1: - # print(f'find {mids} microbatch same, add dependency') - # for sid in range(len(placement)): - # # add forward dependency - # graph.add_schedule([micro_seqs[mid][sid] for mid in mids]) - # # add backward dependency - # graph.add_schedule([micro_seqs[mid][sid+len(placement)] for mid in mids]) - - # pruning - graph.reset_dependency() - forders = [[] for _ in range(n_device)] - # n_device x n_stage - borders = [[[] for _ in range(n_stage)] for _ in range(n_device)] - for sid in range(n_stage): - for mid in range(min(n_microbatch, alimits)): - devid = placements[mid][sid] - forders[devid].append((mid, sid)) - borders[devid][n_stage - 1 - sid].append(mid) - for devid, order in enumerate(forders): - fseq = list() - for mid, sid in order: - fseq.append(micro_seqs[mid][sid]) - graph.add_schedule(fseq) - for devid, order in enumerate(borders): - bseq = list() - for sid in range(n_stage): - bseq += [micro_seqs[mid][n_stage-1-sid] for mid in order[sid]] - graph.add_schedule(bseq) - - # search - for seqs in tsampler(graph.nodes()): - print(f'searching {len(seqs)} sequences under {sidx}-th placement') - yield seqs - - -class TemporalSampler: - """ - Temporal sampler takes nodes (List[IRCell]) as input - """ - - @staticmethod - def btemporal(nodes: List[IRCell], bs=1): - seqs = list() - for seq in TemporalSampler.temporal(nodes): - seqs.append(seq) - if len(seqs) % bs == 0: - yield seqs - seqs = list() - if len(seqs) > 0: - yield seqs - - @staticmethod - def temporal(nodes: List[IRCell], seq = None): - if seq is None: - seq = list() - if len(nodes) == 0: - yield seq - # initial entry - entry_nodes = TemporalSampler.ready_emit_set(remain=nodes, seq=seq) - if len(entry_nodes) == 0: - return None - for node in entry_nodes: - seq = seq + [node] - nid = nodes.index(node) - sub_nodes = nodes[:nid] + nodes[nid+1:] - for res in TemporalSampler.temporal(sub_nodes, seq): - if res is None: - continue - yield res - seq = seq[:-1] - - @staticmethod - def ready_emit_set(remain: List[IRCell], seq: List[IRCell]): - """ - Get ready-to-emit node list from remain node set - """ - ready = list() - for node in remain: - satisfy = True - for pre in node.predecessors(): - if pre not in seq: - satisfy = False - break - if satisfy: - if len(seq) > 0 and len(seq[-1].device) != 0 and len(node.device) != 0: - # pruning #1: filter out equal sequences - if seq[-1] not in node.predecessors(): - if node.device[0] < seq[-1].device[0]: - continue - ready.append(node) - return ready - - -class SpatialSampler: - """ - Spatial sampler takes (n_microbatch, n_stage, n_device) as input - """ - - @staticmethod - def full(n_microbatch: int, n_stage: int, n_device: int, placement = None): - # each device pick n_microbatch * n_stage // n_device blocks - per_device_nblocks = n_microbatch * n_stage // n_device - # placement each stage placement - placement = placement if placement is not None else [] - - if len(placement) == n_microbatch * n_stage: - bucket_min = [n_microbatch * n_stage] * n_device - for nid, devid in enumerate(placement): - bucket_min[devid] = min(bucket_min[devid], nid) - check = [bucket_min[idx + 1] - bucket_min[idx] for idx in range(n_device - 1)] - if min(check) < 0: - yield None - else: - yield placement - else: - # require strict increasing array [min(bucket) for bucket in buckets] - # bucket_min = list(range(n_microbatch * n_stage, n_microbatch * n_stage + n_device + 1)) - bucket_cnt = [0] * n_device - for nid, devid in enumerate(placement): - # bucket_min[devid] = min(nid, bucket_min[devid]) if bucket_min[devid] is not None else nid - bucket_cnt[devid] += 1 - for devid in range(n_device): - if bucket_cnt[devid] < per_device_nblocks: - placement = placement + [devid] - for seq in SpatialSampler.full(n_microbatch, n_stage, n_device, placement): - if seq is None: - continue - yield seq - placement = placement[:-1] - - @staticmethod - def same(n_microbatch: int, n_stage: int, n_device: int, wlimits: int): - """ - Same spatial placement for each micro-batch - """ - placements = [] - for _ in range(n_microbatch): - placement = [sid % n_device for sid in range(n_stage)] - placements.append(placement) - yield placements - - @staticmethod - def othogonal(n_microbatch: int, n_stage: int, n_device: int, - wlimits: int, status = None, placements = None): - """ - Find othogonal plans given weight_limits - - Yield: - List[microbatch][stage] = device (int) - """ - # each element denotes number of block assigned - status = np.zeros((n_device, n_stage), dtype=int) if status is None else status - placements = [] if placements is None else placements - # repeat to reduce space - if len(placements) == wlimits: - for idx in range(n_microbatch - wlimits): - placements = placements + [copy.copy(placements[idx % wlimits])] - yield placements - # find othogonal placements - elif len(placements) == 0: - # fix the first one due to symmetric device - placements = placements + [[sid % n_device for sid in range(n_stage)]] - for sid in range(n_stage): - status[sid % n_device][sid] += 1 - for seqs in SpatialSampler.othogonal(n_microbatch, n_stage, n_device, - wlimits, status, placements): - yield seqs - else: - for placement in SpatialSampler.microbatch_othogonal(np.copy(status)): - placements = placements + [placement] - for sid, devid in enumerate(placement): - status[devid][sid] += 1 - for seqs in SpatialSampler.othogonal(n_microbatch, n_stage, n_device, - wlimits, status, placements): - yield seqs - for sid, devid in enumerate(placement): - status[devid][sid] -= 1 - placements = placements[:-1] - - @staticmethod - def microbatch_othogonal(status: np.ndarray, placement = None): - """ - status: - 2D array [n_device, n_stage], each element represents - how many stage blocks are assigned. - """ - n_device, n_stage = status.shape - assert n_stage == 4 - placement = [] if placement is None else placement - if len(placement) == n_stage: - # print(placement) - # input('>>>out') - yield placement - else: - sid = len(placement) - allocation = np.sum(status, axis=1) - min_alloc = np.min(allocation) - collision = status[:,sid] - valid = list() - for devid, coll in enumerate(collision): - if coll != 0 or allocation[devid] != min_alloc: - continue - valid.append(devid) - for devid in valid: - placement = placement + [devid] - status[devid][sid] += 1 - for seq in SpatialSampler.microbatch_othogonal(status, placement): - yield seq - status[devid][sid] -= 1 - placement = placement[:-1] - - - @staticmethod - def microbatch_placement(n_stage: int, n_device: int, - wlimits: int, placement = None, wstatus = None): - """ - Find microbatch placement - Yield: - List[stage] = device[int] - """ - placement = [] if placement is None else placement - wstatus = [0] * n_device if wstatus is None else wstatus - if len(placement) == n_stage: - yield placement - else: - for devid in range(n_device): - if wstatus[devid] == wlimits: - continue - placement = placement + [devid] - wstatus[devid] += 1 - for seq in SpatialSampler.microbatch_placement(n_stage, n_device, wlimits, placement, wstatus): - yield seq - wstatus[devid] -= 1 - placement = placement[:-1] - - -class Searcher: - - pool = Pool(processes=32) - - @staticmethod - def search(seqs: List[List[IRCell]], bucket: Dict, n_worker: int = 1) -> Dict[int, Tuple[int, List]]: - pool = Pool(processes=32) - # memory (int) -> (time, seq) - tic = time.time() - per_worker_seqs = int(math.ceil(len(seqs) / n_worker)) - worker_buckets = list() - for wid in range(n_worker): - start = wid * per_worker_seqs - stop = (wid + 1) * per_worker_seqs - worker_seqs = seqs[start:stop] - worker_buckets.append(pool.apply_async(Searcher._run, (worker_seqs,))) - worker_buckets: List[Dict] = map(lambda buck: buck.get(), worker_buckets) - # merge results - for worker_bucket in worker_buckets: - for mem, (span, seq) in worker_bucket.items(): - if mem in bucket and bucket[mem][0] <= span: - continue - print(f'find better plan at mem budget {mem}: span: {span}') - bucket[mem] = (span, seq) - toc = time.time() - throughput = round(len(seqs) / (toc - tic), 2) - print(f'searched {len(seqs)} sequences... throughput: {throughput} seqs/s') - pool.close() - pool.join() - - @staticmethod - def _run(seqs: List[List[IRCell]]) -> Dict[int, Tuple[int, List]]: - """ - Worker run - """ - bucket = dict() - graph = IRGraph([], [], [], 'search') - for seq in seqs: - graph._nodes = seq - execplan = ExecutionPlan(graph) - span, mem = execplan.analyze(map2time=Estimator.map2time, map2mem=Estimator.map2mem) - if mem not in bucket: - bucket[mem] = (span, copy.copy(seq)) - elif bucket[mem][0] > span: - bucket[mem] = (span, copy.copy(seq)) - return bucket - - -if __name__ == '__main__': - - for idx, placement in enumerate(SpatialSampler.othogonal(n_microbatch=4, n_stage=4, n_device=4, wlimits=2)): - print(placement) - print(f'total {idx+1} placements') diff --git a/cube/tetris/composer.py b/cube/tetris/composer.py deleted file mode 100644 index 9537876b..00000000 --- a/cube/tetris/composer.py +++ /dev/null @@ -1,772 +0,0 @@ -""" -The Tetris. - -Abstraction layer for microb-batch execution plan merge. -""" - -from typing import Callable, Dict, List, Tuple, Optional, Union -import numpy as np -from enum import Enum -import time - - -class Block: - """ - Execution block for a MicroPlan - """ - - class BType(Enum): - FW = 'forward' - BW = 'backward' - - def __init__(self, mid: int, btype: BType, mem: int = 1): - self.mid: int = mid - self.btype = btype - self.memory = abs(mem) if btype == Block.BType.FW else 0-abs(mem) - # dependency track - self.before: List[Block] = list() - self.after: List[Block] = list() - - @staticmethod - def add_dependency(before, after): - if not (isinstance(before, Block) and isinstance(after, Block)): - raise ValueError("Expected before and after to be Block") - if after not in before.after: - before.after.append(after) - if before not in after.before: - after.before.append(before) - - def __repr__(self): - return f'f{self.mid}' if self.btype == Block.BType.FW else f'b{self.mid}' - - -class PlanBase: - - def __init__(self, ndevs: int): - # (device, step) -> block - self.blocks: Dict[Tuple[int, int], Block] = dict() - # block id -> ((device,), step) - self.positions: Dict[int, Tuple[Tuple[int], int]] = dict() - self.plan = np.zeros((ndevs, ndevs * 2), dtype=int) - - @property - def ndevs(self): - return self.plan.shape[0] - - @property - def nsteps(self): - return self.plan.shape[1] - - def memory(self, devid: Optional[int] = None) -> Union[List[int], int]: - """ - Get memory of the this plan - """ - if isinstance(devid, int): - memory = 0 - peak_mem = memory - for step, have_block in enumerate(self.plan[devid]): - have_block = have_block != 0 - if have_block: - memory += self.block(devid, step).memory - peak_mem = max(peak_mem, memory) - return peak_mem - else: - dev_peak_mem = [] - for devid in range(self.ndevs): - dev_peak_mem.append(self.memory(devid)) - return dev_peak_mem - - def block(self, dev: int, step: int): - """ - Get block given a position - """ - if (dev, step) not in self.blocks: - return None - return self.blocks[(dev, step)] - - def position(self, block: Block) -> Optional[Tuple[Tuple[int], int]]: - """ - Get (dev, step) position given a block. - If block not in this plan, return None - """ - if id(block) in self.positions: - return self.positions[id(block)] - else: - return None - - def squeeze(self): - """ - remove redundant steps - """ - execflag = np.sum(self.plan, axis=0) - for idx in range(self.nsteps): - if execflag[-idx-1] != 0: - break - self.plan = self.plan[:, :-idx] if idx > 0 else self.plan - - def __repr__(self): - namelen = 2 - dscp = '' - for dev in range(self.ndevs): - for step in range(self.nsteps): - block = self.block(dev, step) - if block is None: - dscp += '-' * namelen + ' ' - else: - # TODO: 2 replace to namelen - dscp += '{: <2}'.format(repr(block)) + ' ' - dscp += '\n' - return dscp - - def visualize(self, outfile=None): - import matplotlib.pyplot as plt - from matplotlib.patches import Rectangle - from matplotlib.ticker import AutoMinorLocator - plt.close('all') - fig, ax = plt.subplots(figsize=(4 * self.nsteps // self.ndevs, 4)) - renderer = fig.canvas.get_renderer() - - # xaxis - ax.set_xlim((0, self.nsteps)) - plt.xticks( - ticks=np.arange(0.5, self.nsteps+0.5, 1.0, dtype=float), - labels=np.arange(1, self.nsteps+1, 1, dtype=int) - ) - minor_locator = AutoMinorLocator(2) - plt.gca().xaxis.set_minor_locator(minor_locator) - ax.xaxis.grid(which='minor', linestyle='--') - # yaxis - ax.set_ylim((0.5, self.ndevs+0.5)) - plt.yticks(np.arange(1, self.ndevs+1, 1, dtype=int)) - ax.invert_yaxis() - - fontsize = [40] - txts = list() - def draw_block(block: Block, position: Tuple[Tuple[int], int], fontsize): - color = '#4472C4' if block.btype == Block.BType.FW else '#ED7D31' - devs, step = position - for dev in devs: - rec = Rectangle((step, dev+0.5), 1, 1, color=color, ec='black', lw=1.5) - ax.add_artist(rec) - rx, ry = rec.get_xy() - cx = rx + rec.get_width() / 2.0 - cy = ry + rec.get_height() / 2.0 - anno = str(block.mid) - txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') - rbox = rec.get_window_extent(renderer) - for fs in range(fontsize[0], 1, -2): - txt.set_fontsize(fs) - tbox = txt.get_window_extent(renderer) - if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: - break - fontsize[0] = min(fontsize[0], fs) - txts.append(txt) - - for dev in range(self.ndevs): - for step in range(self.nsteps): - block = self.block(dev, step) - if block is not None: - draw_block(block, self.position(block), fontsize) - # set fontsize to same - fontsize = fontsize[0] - for txt in txts: - txt.set_fontsize(fontsize) - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(fontsize) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(fontsize) - plt.xlabel('Time Step', fontsize=fontsize) - plt.ylabel('Device', fontsize=fontsize) - plt.tight_layout() - if outfile: - plt.savefig(outfile) - else: - plt.show() - - -class MicroPlan(PlanBase): - - def __init__(self, mid: int, ndevs: int): - """ - Create an empty microbatch execution plan - - mid: microbatch id - ndevs: number of devices - """ - super().__init__(ndevs) - self.mid = mid - - def expand_to(self, nsteps: int): - if self.nsteps < nsteps: - extend = nsteps - self.nsteps - self.plan = np.pad(self.plan, ((0,0),(0,extend))) - - def add_block(self, pos: Tuple[Union[int, Tuple[int]], int], btype: Block.BType) -> Block: - """ - Add a execution block. - pos: [dev(s), step] - """ - devs, step = pos - if isinstance(devs, int): - devs = (devs,) - else: - devs = tuple(devs) - if max(devs) >= self.ndevs: - raise ValueError("device out of scope") - if step >= self.nsteps: - self.expand_to(step + 1) - if not all([self.plan[dev, step] == 0 for dev in devs]): - raise ValueError(f"Postition {pos} already has blocks") - block = Block(self.mid, btype) - for dev in devs: - self.plan[dev, step] += 1 - self.blocks[(dev, step)] = block - self.positions[id(block)] = (devs, step) - return block - - def add_dependency(self, blocks: List[Block]): - """ - Add dependent blocks: - block[0] < block[1] < block[2] < ... - """ - for idx in range(len(blocks) - 1): - Block.add_dependency(blocks[idx], blocks[idx+1]) - - def copy(self): - """ - copy a micro plan - """ - micro = MicroPlan(self.mid, self.ndevs) - micro.plan = np.array(self.plan, copy=True) - micro.blocks.update(self.blocks) - micro.positions.update(self.positions) - return micro - - def select(self, begin_step: int, stop_step: int): - """ - select micro plans of [begin, stop) - """ - micro = MicroPlan(self.mid, self.ndevs) - micro.expand_to(stop_step) - all_blocks = [] - for block in self.blocks.values(): - if block not in all_blocks: - all_blocks.append(block) - for block in all_blocks: - devs, step = self.position(block) - if begin_step <= step and step < stop_step: - for dev in devs: - micro.plan[dev, step] += 1 - micro.blocks[(dev, step)] = block - micro.positions[id(block)] = (devs, step) - return micro - - def shift(self, block: Block, inplace=True): - """ - The primitive: shift a block by pushing one step later - """ - micro = self if inplace else self.copy() - # check block in this plan - if block not in micro.blocks.values(): - raise ValueError("Block not in this micro plan") - devs, step = micro.position(block) - # shift later blocks - for after_block in block.after: - if step + 1 == micro.position(after_block)[1]: - micro.shift(after_block, inplace=True) - for dev in devs: - next_block = self.block(dev, step+1) - if next_block is not None: - micro.shift(next_block, inplace=True) - # shift this one - for dev in devs: - micro.plan[dev, step] = 0 - if step + 1 >= micro.nsteps: - micro.expand_to(micro.nsteps + 1) - micro.plan[dev, step+1] = 1 - # update blocks and positions - del micro.blocks[(dev, step)] - micro.blocks[(dev, step+1)] = block - micro.positions[id(block)] = (devs, step+1) - return micro - - def unshift(self, block: Block, inplace=True): - """ - reverse shift, for search only - """ - micro = self if inplace else self.copy() - devs, step = micro.position(block) - if step == 0: - raise ValueError("unshift a block with step = 0") - # shift back - for dev in devs: - micro.plan[dev, step] = 0 - micro.plan[dev, step-1] = 1 - del micro.blocks[(dev, step)] - micro.blocks[(dev, step-1)] = 1 - micro.positions[id(block)] = (devs, step-1) - # shift back shifted blocks - for after_block in block.after: - if step + 1 == micro.position(after_block)[1]: - micro.unshift(after_block, inplace=True) - # TODO: how can I know the independent blocks got shifted? - return micro - - -class SchedulePlan(PlanBase): - - def __init__(self, micros: List[MicroPlan]): - ndevs = [micro.ndevs for micro in micros] - if len(set(ndevs)) != 1: - raise ValueError(f"Device number not same: {ndevs}") - ndevs = ndevs[0] - - super().__init__(ndevs) - self.micros = micros - - # get schedule plans - max_steps = max(micro.nsteps for micro in micros) - for micro in micros: - micro.expand_to(max_steps) - plans = tuple(micro.plan for micro in micros) - schedule = np.sum(np.stack(plans, axis=-1), axis=-1) - if len(np.where(schedule > 1)[0]) > 0: - raise ValueError("micro plans are not composable") - self.plan = schedule - self.squeeze() - - # set blocks and positions - for micro in micros: - self.blocks.update(micro.blocks) - self.positions.update(micro.positions) - - @staticmethod - def composable(micros: List[MicroPlan], mem_constraints: Optional[List[int]] = None) -> bool: - # check execution conflicts - max_steps = max(micro.nsteps for micro in micros) - for micro in micros: - micro.expand_to(max_steps) - plans = tuple(micro.plan for micro in micros) - schedule = np.sum(np.stack(plans, axis=-1), axis=-1) - devids = np.where(schedule > 1)[0] - if len(devids) != 0: - return False - # check memory conflicts - if mem_constraints is not None: - mems = SchedulePlan(micros).memory() - for mem, bound in zip(mems, mem_constraints): - if mem > bound: - return False - return True - - @staticmethod - def conflict(micros: List[MicroPlan], step: int, memory_constraints: List[int]) -> Dict[int, List[Tuple[MicroPlan, Block]]]: - """ - Get conflict blocks at `step`. - Return the conflicted (MicroPlan, Block) grouped by device id - - This assumes micros are composable for steps < step - """ - ndevs = micros[0].ndevs - # find memory conflicted blocks - mem_conflicts: Dict[int, List[Block]] = dict() - curr_memory = [] - for devid in range(ndevs): - mem = 0 - for t in range(step): - for micro in micros: - block = micro.block(devid, t) - if block is not None: - mem += block.memory - curr_memory.append(mem) - for devid in range(ndevs): - for micro in micros: - block = micro.block(devid, step) - if block is not None: - if curr_memory[devid] + block.memory > memory_constraints[devid]: - for dev in micro.position(block)[0]: - if dev not in mem_conflicts: - mem_conflicts[dev] = [] - if block not in mem_conflicts[dev]: - mem_conflicts[dev].append(block) - - # find execution conflicted blocks - exe_conflicts: Dict[int, List[Block]] = dict() - max_steps = max(micro.nsteps for micro in micros) - for micro in micros: - micro.expand_to(max_steps) - plans = tuple(micro.plan[:,step] for micro in micros) - schedule = np.sum(np.stack(plans, axis=-1), axis=-1) - devids = np.where(schedule > 1)[0] - for dev in devids: - exe_conflicts[dev] = [] - for micro in micros: - cblock = micro.block(dev, step) - if cblock is not None: - exe_conflicts[dev].append(cblock) - - # consistent device set - devices = set(list(exe_conflicts.keys()) + list(mem_conflicts.keys())) - for devid in devices: - if devid not in mem_conflicts: - mem_conflicts[devid] = [] - if devid not in exe_conflicts: - exe_conflicts[devid] = [] - # print(exe_conflicts, mem_conflicts) - return exe_conflicts, mem_conflicts - - -class Composer: - - @staticmethod - def premise(fn, ndevs: int, nmicros: int): - micros = fn(ndevs, nmicros) - return micros - - @staticmethod - def bfs_schedule(micros: List[MicroPlan], - mem_constraints: Union[int, Tuple[int]], - mem_opt=True, prune_symmetric=True): - mem_constraints = [mem_constraints] * micros[0].ndevs if isinstance(mem_constraints, int) else mem_constraints - total_status = 1 - micros.sort(key=lambda m: m.mid) - block_hash = Composer.construct_hash(micros) if prune_symmetric else None - step = 0 - opt_step = sum(micro.nsteps for micro in micros) # initial - prev: List[List[MicroPlan]] = [micros] - next: List[List[MicroPlan]] = [] - schedules: List[SchedulePlan] = [] - while step < opt_step: - print(f'solving step {step}, candidates {len(prev)}...') - for micros in prev: - # get and solve conflicts - exe_conflicts, mem_conflicts = SchedulePlan.conflict(micros, step, mem_constraints) - if len(exe_conflicts) == 0 and len(mem_conflicts) == 0: - next.append(micros) - continue - # input(f'conflicts: dev: {list(conflicts.keys())}, mids: {[[conf[0].mid for conf in c] for c in conflicts.values()]} | >>>') - for shifts in Composer.iter_shifts(micros, exe_conflicts, mem_conflicts, block_hash=block_hash): - # print(f"step {step}: {shifts}") - shifted_micros = [micro.copy() for micro in micros] - for cblock in shifts: - cmid = cblock.mid - cmicro = shifted_micros[cmid] - cmicro.shift(cblock, inplace=True) - # print(f"solved results: ") - # for micro in shifted_micros: - # print(f'microbatch #{micro.mid}:') - # print(micro) - if SchedulePlan.composable(shifted_micros, mem_constraints): - schedule = SchedulePlan(shifted_micros) - schedules.append(schedule) - if schedule.nsteps < opt_step: - print(f'find fewer steps: {schedule.nsteps}') - opt_step = min(opt_step, schedule.nsteps) - else: - # pruning technique: discard plans that exceed opt_step - discard = False - for m in shifted_micros: - if m.nsteps > opt_step: - discard = True - break - if not discard: - next.append(shifted_micros) - total_status += len(next) - prev, next = next, [] - step += 1 - total_status += len(schedules) - schedules = [schedule for schedule in schedules if schedule.nsteps == opt_step] - if mem_opt: - schedules = [SchedulePlan(Composer.memory_opt(schedule.micros)) for schedule in schedules] - print(f'searched {total_status} status.') - return schedules - - @staticmethod - def same_plans(micros: List[MicroPlan], start_step: int = 0) -> bool: - Composer.to_same_step(micros) - plans = [micro.plan[:,start_step:] for micro in micros] - plan = plans[0] - for other in plans[1:]: - if not np.array_equal(plan, other): - return False - return True - - @staticmethod - def construct_hash(micros: List[MicroPlan]) -> Callable: - """ - construct a hashing function to map "same" blocks into a same integer. - - The "same" blocks refer to the same-position blocks of same micro plans. - """ - # group same micro plans - same_plans: List[List[MicroPlan]] = [[]] - for micro in micros: - for smicros in same_plans: - if Composer.same_plans(smicros + [micro]): - smicros.append(micro) - break - else: - same_plans.append([micro]) - print(f'detecting {len(same_plans)} same-microplan groups: {[[plan.mid for plan in smicros] for smicros in same_plans]}') - # for each micro plan group, group same hash functions - gid = 0 - block2gid: Dict[int, int] = dict() - for smicros in same_plans: - positions: Dict[Tuple[Tuple[int], int], List[Block]] = dict() - for micro in smicros: - for pos, block in micro.blocks.items(): - if pos not in positions: - positions[pos] = [] - positions[pos].append(block) - for blocks in positions.values(): - for block in blocks: - block2gid[id(block)] = gid - gid += 1 - def blockhash(block: Block) -> int: - return block2gid[id(block)] - return blockhash - - @staticmethod - def to_same_step(micros: List[MicroPlan]): - """ - extend micros to same step - """ - nsteps = max(micro.nsteps for micro in micros) - for micro in micros: - micro.expand_to(nsteps) - return micros - - @staticmethod - def iter_shifts(micros: List[MicroPlan], - exe_conflicts: Dict[int, List[Block]], - mem_conflicts: Dict[int, List[Block]], - block_hash: Optional[Callable] = None) -> List[Block]: - """ - Enumerate shifted blocks to resolve conflicts on step `step`. - """ - devs = tuple(exe_conflicts.keys()) - prev_keep: List[Dict[int, Block]] = [{devid: None for devid in devs}] - next_keep: List[Dict[int, Block]] = [] - - keep_candidates: Dict[int, List[Block]] = {devid: [] for devid in devs} - for devid in devs: - for cblock in exe_conflicts[devid]: - if cblock not in mem_conflicts[devid]: - keep_candidates[devid].append(cblock) - - for dev in devs: - for keeps in prev_keep: - cblocks = [c for c in keep_candidates[dev]] - cmicros = [micros[c.mid] for c in cblocks] - if keeps[dev] is None: - # get candidate by pruning the symetric block - candidates = [] - if block_hash is not None: - gids = [block_hash(cblock) for cblock in cblocks] - for gid in set(gids): - gblocks = [cblock for cblock, cgid in zip(cblocks, gids) if cgid == gid] - gmids = [gblock.mid for gblock in gblocks] - idx = gmids.index(min(gmids)) - candidates.append(gblocks[idx]) - else: - candidates = cblocks - if len(candidates) == 0: - next_keep.append(keeps) - else: - for kblock in candidates: - idx = cblocks.index(kblock) - kmicro = cmicros[idx] - empty = True - for kdev in kmicro.position(kblock)[0]: - if kdev in keeps and keeps[kdev] is not None: - empty = False - break - if empty: - new_keeps = {devid: blk for devid, blk in keeps.items()} - for kdev in kmicro.position(kblock)[0]: - if kdev in new_keeps: - new_keeps[kdev] = kblock - next_keep.append(new_keeps) - else: - next_keep.append(keeps) - prev_keep, next_keep = next_keep, [] - - for keeps in prev_keep: - shifts = [] - for devid in devs: - kblock = keeps[devid] - for cblock in exe_conflicts[devid] + mem_conflicts[devid]: - if kblock != cblock and cblock not in shifts: - shifts.append(cblock) - yield shifts - - @staticmethod - def memory_opt(micros: List[MicroPlan]): - """ - optimize memory given a schedule plan. - The micros are composable. - """ - nsteps = max(micro.nsteps for micro in micros) - for step in range(nsteps-1, -1, -1): - micros = Composer.memory_opt_step(micros, step) - return micros - - @staticmethod - def memory_opt_step(micros: List[MicroPlan], step: int): - splan = sum(micro.plan for micro in micros) - free_steps = [np.where(splan[dev,:] == 0)[0] for dev in range(micros[0].ndevs)] - for micro in micros: - devs = np.where(micro.plan[:,step] > 0)[0] - fblocks = [] - # find forward blocks - for dev in devs: - block = micro.block(dev, step) - if block.btype != Block.BType.FW: - continue - if block not in fblocks: - fblocks.append(block) - # find non-critical forward blocks - for block in fblocks: - maxstep = min(micro.position(nblock)[1] for nblock in block.after) - 1 - if maxstep == step: # no room for shift => critical - continue - # find maximal shift distance - maxshift = None - for t in range(maxstep, step, -1): - if all([t in free_steps[dev] for dev in micro.position(block)[0]]): - maxshift = t - step - break - # apply shift by `distance` times - if maxshift is not None: - for _ in range(maxshift): - micro.shift(block, inplace=True) - return micros - - -if __name__ == '__main__': - - def uniform_staging(ndevs: int, nmicros=4) -> List[MicroPlan]: - """ - f b - f b - f b - f b - """ - micros = [] - for mid in range(nmicros): - micro = MicroPlan(mid, ndevs) - fblocks = [micro.add_block((sid, sid), Block.BType.FW) for sid in range(ndevs)] - bblocks = [micro.add_block((ndevs-1-sid, sid+ndevs), Block.BType.BW) for sid in range(ndevs)] - blocks = fblocks + bblocks - micro.add_dependency(blocks) - micros.append(micro) - return micros - - def mbart_staging(ndevs: int, nmicros=4) -> List[MicroPlan]: - """ - f f f b b b - f f f b b b - f f f b b b - f f f b b b - """ - micros = [] - for mid in range(nmicros): - micro = MicroPlan(mid, ndevs) - fblocks = [] - bblocks = [] - for step in range(ndevs+2): - if step in [0, ndevs // 2+1]: - fblock = micro.add_block((tuple(range(ndevs)), step), Block.BType.FW) - bblock = micro.add_block((tuple(range(ndevs)), (ndevs+2)*2-1-step), Block.BType.BW) - else: - dev = step - 1 if step < ndevs//2+1 else step - 2 - fblock = micro.add_block((dev, step), Block.BType.FW) - bblock = micro.add_block((dev, (ndevs+2)*2-1-step), Block.BType.BW) - fblocks.append(fblock) - bblocks.append(bblock) - micro.add_dependency(fblocks+bblocks[::-1]) - micros.append(micro) - return micros - - def chimera_staging(ndevs: int, nmicros: int) -> List[MicroPlan]: - """ - f b f b - f b f b - f b f b - f b f b - """ - micros = [] - assert nmicros % 2 == 0, "require microbatch# can be divided by 2." - for mid in range(nmicros // 2): # V shape - micro = MicroPlan(mid, ndevs) - fblocks = [micro.add_block((sid, sid), Block.BType.FW) for sid in range(ndevs)] - bblocks = [micro.add_block((ndevs-1-sid, sid+ndevs), Block.BType.BW) for sid in range(ndevs)] - blocks = fblocks + bblocks - micro.add_dependency(blocks) - micros.append(micro) - for mid in range(nmicros // 2): # ^ shape - micro = MicroPlan(mid + nmicros // 2, ndevs) - fblocks = [micro.add_block((ndevs-1-sid, sid), Block.BType.FW) for sid in range(ndevs)] - bblocks = [micro.add_block((sid, sid+ndevs), Block.BType.BW) for sid in range(ndevs)] - blocks = fblocks + bblocks - micro.add_dependency(blocks) - micros.append(micro) - return micros - - def compose_1F1B(ndevs, nmicros): - # premise - micros = uniform_staging(ndevs, nmicros) - print('premise micros:') - for micro in micros: - print(micro) - # shift - for mid, micro in enumerate(micros): - block = micro.block(0, 0) - for _ in range(2 * mid): - micro.shift(block) - print('shifted micros:') - for micro in micros: - print(micro) - assert SchedulePlan.composable(micros) - schedule = SchedulePlan(micros) - print(f'schedule (step={schedule.nsteps}):') - print(schedule) - return schedule - - def search(ndevs, nmicros, mem_constraints: int, visualize=False): - # premise - # micros = Composer.premise(uniform_staging, ndevs, nmicros) - # micros = Composer.premise(chimera_staging, ndevs, nmicros) - micros = Composer.premise(mbart_staging, ndevs, nmicros) - print('============== Premise ================') - for idx, micro in enumerate(micros): - print(f'microbatch #{idx}:') - print(micro) - if visualize: - micro.visualize(f'planlog/micro{idx}.png') - print('============== Premise ================') - - # search shift - tic = time.time() - schedules = Composer.bfs_schedule(micros, mem_constraints=mem_constraints, mem_opt=False, prune_symmetric=True) - toc = time.time() - print('search done. time {:.2f}s'.format(toc - tic)) - - steps = set(schedule.nsteps for schedule in schedules) - assert len(steps) == 1, f"got un-consistent step set: {steps}" - nsteps = list(steps)[0] - print(f'find {len(schedules)} step-optimal plans (step={nsteps})') - print(f'one solution:\n{schedules[0]}\n{schedules[0].memory()}') - # for idx, schedule in enumerate(schedules): - # print(f'Schedule #{idx+1}:') - # print(schedule) - # if visualize: - # schedule.visualize(f'planlog/plan{idx+1}.png') - - - ndevs = 4 - nmicros = 4 - - # schedule = compose_1F1B(ndevs, nmicros) - # schedule.visualize('out.png') - search(ndevs, nmicros, mem_constraints=10, visualize=False) diff --git a/cube/tetris/solver.py b/cube/tetris/solver.py deleted file mode 100644 index 07f06d1b..00000000 --- a/cube/tetris/solver.py +++ /dev/null @@ -1,367 +0,0 @@ -""" -A solver based solution for scheduling plan -""" - -from typing import List, Optional, Tuple -from enum import Enum - -from z3 import * -import time -import copy - - - -gsolver = Solver() - - -class Block: - - class BType(Enum): - FW = 'forward' - BW = 'backward' - - def __init__(self, mid: int, btype: BType, name: str, mem=1): - global _uid - global gsolver - self.name = name - self.mid = mid - self.step = Int(name) - self.memory = mem if btype == Block.BType.FW else 0-mem - gsolver.add(self.step >= 1) - self.btype = btype - - @staticmethod - def add_dependency(blk1, blk2): - """ - add dependency: blk1 -> blk2 - """ - global gsolver - gsolver.add(blk1.step < blk2.step) - - def __repr__(self): - return f'f{self.mid}' if self.btype == Block.BType.FW else f'b{self.mid}' - - -class SchedulePlan: - - def __init__(self, ndevs: int) -> None: - - self._blocks: List[List[Block]] = [[] for _ in range(ndevs)] - self.ndevs = ndevs - self._nsteps = None - self._mem = None - self._solution: Optional[z3.z3.ModelRef] = None - - @property - def nblocks(self) -> int: - return sum(len(blks) for blks in self._blocks) - - @property - def nsteps(self) -> int: - if self._solution is None: - return -1 - return self._solution.eval(self._nsteps).as_long() - - @property - def mem(self) -> int: - if self._mem is None: - return -1 - return self._solution.eval(self._mem).as_long() - - def blocks(self, devid: Optional[int] = None) -> List[Block]: - if isinstance(devid, int): - return copy.copy(self._blocks[devid]) - else: - allblocks = [] - for blks in self._blocks: - allblocks += blks - return allblocks - - def position(self, block: Block) -> Tuple[int, int]: - """ - get block position (device, time) after the search - """ - # device - for devid in range(self.ndevs): - if block in self.blocks(devid): - break - else: - assert False, 'block not in schedule plan' - # time step - step = None - if self._solution is not None: - step = self._solution[block.step] - return (devid, step) - - def add_block(self, block: Block, devices: Tuple[int]): - global gsolver - devices = (devices,) if isinstance(devices, int) else devices - for device in devices: - for blk in self._blocks[device]: - gsolver.add(blk.step != block.step) - self._blocks[device].append(block) - # set plan step variable - if self._nsteps is None: - self._nsteps = block.step - else: - self._nsteps = If(block.step > self._nsteps, block.step, self._nsteps) - - def set_memory(self): - nblocks = max(len(blks) for blks in self._blocks) - # mems = [IntVector(f'memdev{devid}', nblocks) for devid in range(self.ndevs)] - peaks = [] - for devid in range(self.ndevs): - peak = 0 - curr = 0 - for step in range(0, nblocks): - mem = 0 - for block in self.blocks(devid): - mem = If(block.step == step, block.memory, mem) - curr = mem + curr - peak = If(curr > peak, curr, peak) - peaks.append(peak) - # global peak - globalpeak = peaks[0] - for devid in range(1, self.ndevs): - globalpeak = If(peaks[devid] > globalpeak, peaks[devid], globalpeak) - self._mem = globalpeak - return globalpeak - - def set_solution(self, solution: z3.z3.ModelRef): - self._solution = solution - - def solve(self, decrease = True): - global gsolver - tic = time.time() - min_step = max(len(blks) for blks in self._blocks) - max_step = self.nblocks - opt_step = max_step if decrease else min_step - while True: - assert min_step <= opt_step and opt_step <= max_step, "out of step boundary. consider this as a bug." - gsolver.push() - gsolver.add(self._nsteps == opt_step) - if gsolver.check() == sat: - print(f'find scheduling plan in {opt_step} steps') - solution = gsolver.model() - self.set_solution(solution) - gsolver.pop() - if not decrease: break - else: - print(f'fail to find solution for {opt_step} steps') - gsolver.pop() - if decrease: - opt_step += 1 - break - opt_step = opt_step - 1 if decrease else opt_step + 1 - toc = time.time() - print('search time: {:.2f} seconds. find optimal step: {}'.format(toc-tic, opt_step)) - print('solution:') - print(self) - - # search for optimal memory - tic = time.time() - min_mem = max(min(blk.memory for blk in blks if blk.btype == Block.BType.FW) for blks in self._blocks) - max_mem = max(sum(blk.memory for blk in blks if blk.btype == Block.BType.FW) for blks in self._blocks) - opt_mem = max_mem if decrease else min_mem - self.set_memory() - gsolver.push() - gsolver.add(self._nsteps == opt_step) - while True: - assert min_mem <= opt_mem and opt_mem <= max_mem, "out of memory boundary. consider this as a bug" - gsolver.push() - gsolver.add(self._mem == opt_mem) - if gsolver.check() == sat: - print(f'find scheduling plan in {opt_mem} memory') - solution = gsolver.model() - self.set_solution(solution) - gsolver.pop() - if not decrease: break - else: - print(f'fail to find solution for memory {opt_mem}') - gsolver.pop() - if decrease: - opt_mem += 1 - break - opt_mem = opt_mem - 1 if decrease else opt_mem + 1 - gsolver.pop() - toc = time.time() - print('search memory time: {:.2f} seconds. opt-memory: {}'.format(toc-tic, opt_mem)) - print('solution:') - print(self) - - tic = time.time() - self.iter_space(opt_step, opt_mem) - toc = time.time() - print('iterate all plans: {:.2f} seconds.'.format(toc-tic)) - - def solve_mconstraints(self, memory: int, decrease=True): - global gsolver - tic = time.time() - min_step = max(len(blks) for blks in self._blocks) - max_step = self.nblocks - opt_step = max_step if decrease else min_step - - self.set_memory() - - # memory constraints - gsolver.push() - gsolver.add(self._mem <= memory) - # find optimal step - while True: - assert min_step <= opt_step and opt_step <= max_step, "out of step boundary. consider this as a bug." - gsolver.push() - gsolver.add(self._nsteps == opt_step) - if gsolver.check() == sat: - print(f'find scheduling plan in {opt_step} steps') - solution = gsolver.model() - self.set_solution(solution) - gsolver.pop() - if not decrease: break - else: - print(f'fail to find solution for {opt_step} steps') - gsolver.pop() - if decrease: - opt_step += 1 - break - opt_step = opt_step - 1 if decrease else opt_step + 1 - toc = time.time() - print('search time: {:.2f} seconds. find optimal step: {}'.format(toc-tic, opt_step)) - print('solution:') - print(self) - - tic = time.time() - self.iter_space(opt_step) - toc = time.time() - print('iterate all plans: {:.2f} seconds.'.format(toc-tic)) - - - def iter_space(self, nsteps: int, memory: int = None): - """ - iterate all solutions find by solver - """ - global gsolver - gsolver.push() - gsolver.add(self._nsteps == nsteps) - if memory is not None: - gsolver.add(self._mem == memory) - models = [] - while gsolver.check() == sat: - model = gsolver.model() - models.append(model) - block = [] - for d in model: - assert not d.arity() > 0, 'uniterpreted functions found' - c = d() - block.append(c != model[d]) - gsolver.add(Or(block)) - if len(models) % 100 == 0: - print(f'find {len(models)} solutions..') - gsolver.pop() - print(f'find {len(models)} possible models') - - - def __repr__(self) -> str: - if self._solution is None: - return 'Unsolved Schedule Plan.' - namelen = 2 - dscp = '' - for devid in range(self.ndevs): - blocks = self.blocks(devid) - steps = [self.position(blk)[1] for blk in blocks] - for step in range(1, self.nsteps+1): - if step not in steps: - dscp += '-' * namelen + ' ' - else: - idx = steps.index(step) - dscp += '{: <2}'.format(repr(blocks[idx])) + ' ' - dscp += '\n' - return dscp - - -if __name__ == '__main__': - - def uniform_staging(ndevs: int, nmicros) -> SchedulePlan: - """ - f b - f b - f b - f b - """ - sched = SchedulePlan(ndevs) - for mid in range(nmicros): - fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}') for devid in range(ndevs)] - bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}') for devid in range(ndevs)][::-1] - blocks = fblocks + bblocks - for idx in range(ndevs * 2 - 1): - Block.add_dependency(blocks[idx], blocks[idx+1]) - for devid in range(ndevs): - sched.add_block(fblocks[devid], devid) - sched.add_block(bblocks[ndevs-1-devid], devid) - return sched - - def chimera_staging(ndevs: int, nmicros: int) -> SchedulePlan: - """ - f b f b - f b f b - f b f b - f b f b - """ - sched = SchedulePlan(ndevs) - assert nmicros % 2 == 0, "require microbatch# can be devided by 2" - for mid in range(nmicros // 2): # V shape - fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}', mem=1) for devid in range(ndevs)] - bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}', mem=1) for devid in range(ndevs-1,-1,-1)] - blocks = fblocks + bblocks - for idx in range(ndevs * 2 - 1): - Block.add_dependency(blocks[idx], blocks[idx+1]) - for devid in range(ndevs): - sched.add_block(fblocks[devid], devid) - sched.add_block(bblocks[ndevs-1-devid], devid) - for mid in range(nmicros // 2): # ^ shape - mid = mid + nmicros // 2 - fblocks = [Block(mid, Block.BType.FW, f'f{mid}d{devid}', mem=1) for devid in range(ndevs-1,-1,-1)] - bblocks = [Block(mid, Block.BType.BW, f'b{mid}d{devid}', mem=1) for devid in range(ndevs)] - blocks = fblocks + bblocks - for idx in range(ndevs * 2 - 1): - Block.add_dependency(blocks[idx], blocks[idx+1]) - for devid in range(ndevs): - sched.add_block(fblocks[ndevs-1-devid], devid) - sched.add_block(bblocks[devid], devid) - return sched - - def mbart_staging(ndevs: int, nmicros: int) -> SchedulePlan: - """ - f f f b b b - f f f b b b - f f f b b b - f f f b b b - """ - sched = SchedulePlan(ndevs) - for mid in range(nmicros): - fblocks = [] - bblocks = [] - for step in range(ndevs+2): - if step in [0, ndevs // 2 + 1]: - fdevid = bdevid = tuple(range(ndevs)) - fblock = Block(mid, Block.BType.FW, f'fe{step}{mid}devall', mem=1) - bblock = Block(mid, Block.BType.BW, f'be{step}{mid}devall', mem=1) - else: - fdevid = bdevid = step - 1 if step < ndevs // 2 + 1 else step - 2 - fblock = Block(mid, Block.BType.FW, f'f{mid}dev{fdevid}', mem=1) - bblock = Block(mid, Block.BType.BW, f'b{mid}dev{bdevid}', mem=1) - fblocks.append(fblock) - bblocks.append(bblock) - sched.add_block(fblock, fdevid) - sched.add_block(bblock, bdevid) - blocks = fblocks + bblocks[::-1] - for idx in range((ndevs + 2) * 2 - 1): - Block.add_dependency(blocks[idx], blocks[idx+1]) - return sched - - ndevs = 8 - nmicros = 8 - - sched = uniform_staging(ndevs, nmicros) - # sched = chimera_staging(ndevs, nmicros) - # sched = mbart_staging(ndevs, nmicros) # ndev=4, nmicro=4 => solution: step=30 - sched.solve_mconstraints(memory=ndevs, decrease=True) From c3c15b9a89ebfc49f2d001740963b90b543fc729 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 23 Sep 2022 02:38:39 -0700 Subject: [PATCH 1038/1892] save work. not runnable --- .../{evoformer.py => alphafold2.py} | 163 ++++++++++++++---- examples/alphafold2/policy/spmd.py | 25 +++ 2 files changed, 157 insertions(+), 31 deletions(-) rename examples/alphafold2/{evoformer.py => alphafold2.py} (67%) create mode 100644 examples/alphafold2/policy/spmd.py diff --git a/examples/alphafold2/evoformer.py b/examples/alphafold2/alphafold2.py similarity index 67% rename from examples/alphafold2/evoformer.py rename to examples/alphafold2/alphafold2.py index 42ac0d95..ba3971bc 100644 --- a/examples/alphafold2/evoformer.py +++ b/examples/alphafold2/alphafold2.py @@ -1,10 +1,16 @@ import torch import math +import cube from cube.profiler import CudaTimer - +from cube.profiler.timer import print_each_rank +from torch import nn from torch.utils.checkpoint import checkpoint +import examples.alphafold2.policy.spmd as spmd + +cube.init() + """ [bs, s, r, cm] -> [bs, s, r, cm] @@ -12,9 +18,39 @@ used as column-wise gated self-attention """ - +@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R M', name='MSAAttention') def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + head: int, c: int, scale: float): + bs, s, r, cm = x.size() + + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + + gate = gate.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + q = q.reshape(bs, s, r, head, c).transpose(2, + 3).reshape(bs * s * head, r, c) + k = k.reshape(bs, s, r, head, c).transpose(2, + 3).reshape(bs * s * head, r, + c).transpose(1, 2) + v = v.reshape(bs, s, r, head, c).transpose(2, + 3).reshape(bs * s * head, r, c) + + sim = torch.bmm(q, k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + + attend = torch.bmm(sim, v) * gate + + out = attend.reshape(bs, s, head, r, c).transpose(2, + 3).reshape(bs, s, r, cm) + out = torch.matmul(out, out_proj) + return out + + +@cube.graph.parser.register('N S R M, M E, M F, E M, N 1 8 R R -> N S R M', name='MSAAttentionWithBias') +def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float): bs, s, r, cm = x.size() @@ -34,9 +70,8 @@ def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, sim = torch.bmm(q, k) * scale sim = torch.nn.functional.softmax(sim, dim=-1) - if isinstance(bias, torch.Tensor): - sim = sim.reshape(bs, s, head, r, r) + bias - sim = sim.reshape(bs * s * head, r, r) + sim = sim.reshape(bs, s, head, r, r) + bias + sim = sim.reshape(bs * s * head, r, r) attend = torch.bmm(sim, v) * gate @@ -50,7 +85,8 @@ def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, ([bs, s, r, cm], [bs, r, r, cz]) -> [bs, s, r, cm] """ - +# note: code not reused constrained by cube's interface +@cube.graph.parser.register('N S R M, N R R Z, M E, M F, E M, Z H -> N S R M', name='MSARowAttentionWithPairBias') def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, pair_repr: torch.Tensor, gate_proj: torch.Tensor, @@ -63,15 +99,14 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, bias_proj).permute(0, 3, 1, 2).reshape(bs, 1, head, r, r) - return MSAAttention(msa_repr, gate_proj, qkv_proj, out_proj, bias, head, c, - scale) + return MSAAttentionWithBias(msa_repr, gate_proj, qkv_proj, out_proj, bias, head, c, scale) """ [bs, s, r, cm] -> [bs, s, r, cm] """ - +@cube.graph.parser.register('N S R M, M E, E M -> N S R M', name='MSATransition') def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): return torch.matmul( @@ -79,10 +114,11 @@ def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, """ -[bs, s, r, cm] -> [r, r, cz] +[bs, s, r, cm] -> [bs, r, r, cz] """ +@cube.graph.parser.register('N S R M, M C, M C, F Z -> N R R Z', name='OuterProductMean') def OuterProductMean(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor, out_proj: torch.Tensor): bs, s, r, cm = msa_repr.size() @@ -97,6 +133,7 @@ def OuterProductMean(msa_repr: torch.Tensor, proj1: torch.Tensor, return outer +@cube.graph.parser.register('N R R Z, Z, Z, E, E, Z E, Z E, Z E, Z E, E Z, Z Z -> N R R Z', name='TriangleMultiplication') def TriangleMultiplication( pair_repr: torch.Tensor, tri_mul_norm1_weight: torch.Tensor, tri_mul_norm1_bias: torch.Tensor, tri_mul_norm2_weight: torch.Tensor, @@ -127,16 +164,18 @@ def TriangleMultiplication( return p * g +@cube.graph.parser.register('N R R Z, Z E, Z F, E Z, Z G -> N R R Z', name='TriangleAttentionNode') def TriangleAttentionNode(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias_proj: torch.Tensor, head: int, c: int, scale: float): bias = torch.matmul(pair_repr, bias_proj).permute(0, 3, 1, 2) - return MSAAttention(pair_repr, gate_proj, qkv_proj, out_proj, bias, head, + return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, head, c, scale) +@cube.graph.parser.register('N R R Z, Z E, E Z -> N R R Z', name='PairTransition') def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): return torch.matmul( @@ -257,7 +296,8 @@ def __init__(self, self.pair_transition_proj2 = torch.nn.Parameter( torch.randn(ff_mult * cz, cz)) - def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): + def forward(self, state): + msa_repr, pair_repr = state # msa_repr = msa_repr + MSARowAttentionWithPairBias( # self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, @@ -269,21 +309,27 @@ def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) - # msa_repr = msa_repr + MSAAttention( - # self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - # self.col_out_proj, None, self.msa_head, self.c, self.scale) - msa_repr = msa_repr + checkpoint(MSAAttention, + msa_repr = msa_repr + MSAAttention( self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - self.col_out_proj, None, self.msa_head, self.c, self.scale) + self.col_out_proj, self.msa_head, self.c, self.scale) + # msa_repr = msa_repr + checkpoint(MSAAttention, + # self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, + # self.col_out_proj, self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), self.msa_transition_proj1, self.msa_transition_proj2) + # msa_repr = msa_repr + checkpoint(MSATransition, self.msa_transition_norm(msa_repr), + # self.msa_transition_proj1, + # self.msa_transition_proj2) pair_repr = pair_repr + OuterProductMean( self.outer_norm(msa_repr), self.outer_proj1, self.outer_proj2, self.outer_out_proj) + # pair_repr = pair_repr + checkpoint(OuterProductMean, + # self.outer_norm(msa_repr), self.outer_proj1, self.outer_proj2, + # self.outer_out_proj) pair_repr = pair_repr + TriangleMultiplication( pair_repr, self.tri_mul_out_norm1_weight, @@ -292,6 +338,13 @@ def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): self.tri_mul_out_proj2, self.tri_mul_out_proj3, self.tri_mul_out_proj4, self.tri_mul_out_proj5, self.tri_mul_out_proj6, self.cz, True) + # pair_repr = pair_repr + checkpoint(TriangleMultiplication, + # pair_repr, self.tri_mul_out_norm1_weight, + # self.tri_mul_out_norm1_bias, self.tri_mul_out_norm2_weight, + # self.tri_mul_out_norm2_bias, self.tri_mul_out_proj1, + # self.tri_mul_out_proj2, self.tri_mul_out_proj3, + # self.tri_mul_out_proj4, self.tri_mul_out_proj5, + # self.tri_mul_out_proj6, self.cz, True) pair_repr = pair_repr + TriangleMultiplication( pair_repr, self.tri_mul_in_norm1_weight, @@ -300,43 +353,87 @@ def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor): self.tri_mul_in_proj2, self.tri_mul_in_proj3, self.tri_mul_in_proj4, self.tri_mul_in_proj5, self.tri_mul_in_proj6, self.cz, False) + # pair_repr = pair_repr + checkpoint(TriangleMultiplication, + # pair_repr, self.tri_mul_in_norm1_weight, + # self.tri_mul_in_norm1_bias, self.tri_mul_in_norm2_weight, + # self.tri_mul_in_norm2_bias, self.tri_mul_in_proj1, + # self.tri_mul_in_proj2, self.tri_mul_in_proj3, + # self.tri_mul_in_proj4, self.tri_mul_in_proj5, + # self.tri_mul_in_proj6, self.cz, False) pair_repr = pair_repr + TriangleAttentionNode( self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, self.tri_att_start_bias_proj, self.pair_head, self.c, self.scale) + # pair_repr = pair_repr + checkpoint(TriangleAttentionNode, + # self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, + # self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, + # self.tri_att_start_bias_proj, self.pair_head, self.c, self.scale) pair_repr = pair_repr.transpose(-3, -2) pair_repr = pair_repr + TriangleAttentionNode( self.tri_att_end_norm(pair_repr), self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, self.tri_att_end_bias_proj, self.pair_head, self.c, self.scale) + # pair_repr = pair_repr + checkpoint(TriangleAttentionNode, + # self.tri_att_end_norm(pair_repr), self.tri_att_end_gate_proj, + # self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, + # self.tri_att_end_bias_proj, self.pair_head, self.c, self.scale) pair_repr = pair_repr.transpose(-3, -2) pair_repr = pair_repr + PairTransition( self.pair_transition_norm(pair_repr), self.pair_transition_proj1, self.pair_transition_proj2) + # pair_repr = pair_repr + checkpoint(PairTransition, + # self.pair_transition_norm(pair_repr), self.pair_transition_proj1, + # self.pair_transition_proj2) + + return (msa_repr, pair_repr) + +class AlphaFold2(nn.Module): + def __init__(self, s: int, cm: int, cz: int, evo_num: int): + super().__init__() + + self.net = nn.Sequential( + *[Evoformer(s, cm, cz) for _ in range(evo_num)], + ) + + + def forward(self, msa, pair): + new_msa, new_pair = self.net((msa, pair)) + loss = torch.sum(new_msa) + torch.sum(new_pair) + return loss + + +def test(): + bs, s, r, cm, cz = 4, 128, 256, 256, 128 + + model = AlphaFold2(s, cm, cz, 1) - return msa_repr, pair_repr + model = cube.SemanticModel( + model, + input_shapes=([bs, s, r, cm], [bs, r, r, cz], ), + ) -def train_iter(model, msa, pair): - bs = msa.size() - new_msa, new_pair = model(msa, pair) - loss = torch.sum(new_msa) + torch.sum(new_pair) - loss.backward() + dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], [bs, r, r, cz],), + dtypes=(torch.float32, ), + batch_dims=(0, )) -def test(dev): - bs, s, r, cm, cz = 1, 128, 256, 256, 128 - model = Evoformer(s, cm, cz).to(dev) + @cube.compile(model, dataloader, PAS=spmd.PASData) + def train_iter(model, dataloader): + msa_repr, pair_repr = next(dataloader) + loss = model(msa_repr, pair_repr) + loss.backward() - msa = torch.randn(bs, s, r, cm).to(dev) - pair = torch.randn(bs, r, r, cz).to(dev) + model = model.get_gen_module() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) warm_up = 20 iter_num = 64 CudaTimer(enable=False).warmup() + if torch.distributed.is_initialized(): + torch.distributed.barrier() for i in range(iter_num): if i >= warm_up: @@ -346,9 +443,13 @@ def test(dev): optimizer.zero_grad() if i >= warm_up: CudaTimer().stop('e2e') + if i > 0 and (i + 1) % 20 == 0: + # print_each_rank(f'iter [{i + 1}/{iter_num}]', rank_only=0) + print(f'iter [{i + 1}/{iter_num}]') print(CudaTimer().duration(iter_num - warm_up, field_name='e2e'), 'ms') - print(torch.cuda.memory_summary(dev)) + # print(torch.cuda.memory_summary(dev)) + print(torch.cuda.max_memory_allocated() / 1024 / 1024, ' MB') -test(torch.device('cuda:0')) +test() diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py new file mode 100644 index 00000000..87d0243d --- /dev/null +++ b/examples/alphafold2/policy/spmd.py @@ -0,0 +1,25 @@ +from typing import List +from cube.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation + +def PASData(graph: IRGraph, resource): + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + batch_dim = node.get_batch_dims()[0] + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, + algo, + idx=0, + dim=batch_dim, + num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph From b4e4b27faf9cfe120e6200e6fa8fe0e46c12e4fd Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 23 Sep 2022 06:40:10 -0700 Subject: [PATCH 1039/1892] data parallel runnable --- examples/alphafold2/alphafold2.py | 40 ++++++++++++++---------------- examples/alphafold2/policy/spmd.py | 6 +++++ 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index ba3971bc..183de046 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -296,17 +296,16 @@ def __init__(self, self.pair_transition_proj2 = torch.nn.Parameter( torch.randn(ff_mult * cz, cz)) - def forward(self, state): - msa_repr, pair_repr = state + def forward(self, msa_repr, pair_repr): - # msa_repr = msa_repr + MSARowAttentionWithPairBias( - # self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, - # self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, - # self.msa_head, self.c, self.scale) - msa_repr = msa_repr + checkpoint(MSARowAttentionWithPairBias, + msa_repr = msa_repr + MSARowAttentionWithPairBias( self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, self.msa_head, self.c, self.scale) + # msa_repr = msa_repr + checkpoint(MSARowAttentionWithPairBias, + # self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, + # self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, + # self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) msa_repr = msa_repr + MSAAttention( @@ -393,15 +392,14 @@ def forward(self, state): class AlphaFold2(nn.Module): def __init__(self, s: int, cm: int, cz: int, evo_num: int): super().__init__() - - self.net = nn.Sequential( - *[Evoformer(s, cm, cz) for _ in range(evo_num)], - ) + self.evo_num = evo_num + # self.evoformers: List[torch.nn.Module] = [Evoformer(s, cm, cz) for _ in range(evo_num)] + self.evoformer = Evoformer(s, cm, cz) def forward(self, msa, pair): - new_msa, new_pair = self.net((msa, pair)) - loss = torch.sum(new_msa) + torch.sum(new_pair) + new_msa, new_pair = self.evoformer(msa, pair) + loss = torch.sum(new_msa) * torch.sum(new_pair) return loss @@ -416,8 +414,8 @@ def test(): ) dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], [bs, r, r, cz],), - dtypes=(torch.float32, ), - batch_dims=(0, )) + dtypes=(torch.float32, torch.float32, ), + batch_dims=(0, 0, )) @cube.compile(model, dataloader, PAS=spmd.PASData) def train_iter(model, dataloader): @@ -438,18 +436,18 @@ def train_iter(model, dataloader): for i in range(iter_num): if i >= warm_up: CudaTimer(enable=True).start('e2e') - train_iter(model, msa, pair) + train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() if i >= warm_up: CudaTimer().stop('e2e') if i > 0 and (i + 1) % 20 == 0: - # print_each_rank(f'iter [{i + 1}/{iter_num}]', rank_only=0) - print(f'iter [{i + 1}/{iter_num}]') + print_each_rank(f'iter [{i + 1}/{iter_num}]', rank_only=0) - print(CudaTimer().duration(iter_num - warm_up, field_name='e2e'), 'ms') - # print(torch.cuda.memory_summary(dev)) - print(torch.cuda.max_memory_allocated() / 1024 / 1024, ' MB') + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num - warm_up, field_name='e2e'))) + CudaTimer().print_all(times=iter_num - warm_up) + print_each_rank(torch.cuda.max_memory_allocated() / 1024 / 1024) test() diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 87d0243d..0f233551 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -3,6 +3,7 @@ from cube.ir.operator import IRDataOperation, IRFwOperation def PASData(graph: IRGraph, resource): + devs = list(range(resource.ngpus)) for node in graph.nodes(): if isinstance(node, IRDataOperation): @@ -14,6 +15,11 @@ def PASData(graph: IRGraph, resource): for node in graph.nodes(): if isinstance(node, IRFwOperation): + if node.name == 'mul': + sub_nodes = graph.replicate(node, times=resource.ngpus) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + continue algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, From cada67c9f3343136ec23ac0683594589424f6575 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 26 Sep 2022 04:55:52 -0700 Subject: [PATCH 1040/1892] refine code --- examples/alphafold2/alphafold2.py | 41 +----------------------------- examples/alphafold2/policy/spmd.py | 17 +++++++++++++ 2 files changed, 18 insertions(+), 40 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 183de046..43685008 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -302,33 +302,20 @@ def forward(self, msa_repr, pair_repr): self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, self.msa_head, self.c, self.scale) - # msa_repr = msa_repr + checkpoint(MSARowAttentionWithPairBias, - # self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, - # self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, - # self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) msa_repr = msa_repr + MSAAttention( self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, self.col_out_proj, self.msa_head, self.c, self.scale) - # msa_repr = msa_repr + checkpoint(MSAAttention, - # self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - # self.col_out_proj, self.msa_head, self.c, self.scale) msa_repr = msa_repr.transpose(-3, -2) msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), self.msa_transition_proj1, self.msa_transition_proj2) - # msa_repr = msa_repr + checkpoint(MSATransition, self.msa_transition_norm(msa_repr), - # self.msa_transition_proj1, - # self.msa_transition_proj2) pair_repr = pair_repr + OuterProductMean( self.outer_norm(msa_repr), self.outer_proj1, self.outer_proj2, self.outer_out_proj) - # pair_repr = pair_repr + checkpoint(OuterProductMean, - # self.outer_norm(msa_repr), self.outer_proj1, self.outer_proj2, - # self.outer_out_proj) pair_repr = pair_repr + TriangleMultiplication( pair_repr, self.tri_mul_out_norm1_weight, @@ -337,13 +324,6 @@ def forward(self, msa_repr, pair_repr): self.tri_mul_out_proj2, self.tri_mul_out_proj3, self.tri_mul_out_proj4, self.tri_mul_out_proj5, self.tri_mul_out_proj6, self.cz, True) - # pair_repr = pair_repr + checkpoint(TriangleMultiplication, - # pair_repr, self.tri_mul_out_norm1_weight, - # self.tri_mul_out_norm1_bias, self.tri_mul_out_norm2_weight, - # self.tri_mul_out_norm2_bias, self.tri_mul_out_proj1, - # self.tri_mul_out_proj2, self.tri_mul_out_proj3, - # self.tri_mul_out_proj4, self.tri_mul_out_proj5, - # self.tri_mul_out_proj6, self.cz, True) pair_repr = pair_repr + TriangleMultiplication( pair_repr, self.tri_mul_in_norm1_weight, @@ -352,40 +332,22 @@ def forward(self, msa_repr, pair_repr): self.tri_mul_in_proj2, self.tri_mul_in_proj3, self.tri_mul_in_proj4, self.tri_mul_in_proj5, self.tri_mul_in_proj6, self.cz, False) - # pair_repr = pair_repr + checkpoint(TriangleMultiplication, - # pair_repr, self.tri_mul_in_norm1_weight, - # self.tri_mul_in_norm1_bias, self.tri_mul_in_norm2_weight, - # self.tri_mul_in_norm2_bias, self.tri_mul_in_proj1, - # self.tri_mul_in_proj2, self.tri_mul_in_proj3, - # self.tri_mul_in_proj4, self.tri_mul_in_proj5, - # self.tri_mul_in_proj6, self.cz, False) pair_repr = pair_repr + TriangleAttentionNode( self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, self.tri_att_start_bias_proj, self.pair_head, self.c, self.scale) - # pair_repr = pair_repr + checkpoint(TriangleAttentionNode, - # self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, - # self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, - # self.tri_att_start_bias_proj, self.pair_head, self.c, self.scale) pair_repr = pair_repr.transpose(-3, -2) pair_repr = pair_repr + TriangleAttentionNode( self.tri_att_end_norm(pair_repr), self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, self.tri_att_end_bias_proj, self.pair_head, self.c, self.scale) - # pair_repr = pair_repr + checkpoint(TriangleAttentionNode, - # self.tri_att_end_norm(pair_repr), self.tri_att_end_gate_proj, - # self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, - # self.tri_att_end_bias_proj, self.pair_head, self.c, self.scale) pair_repr = pair_repr.transpose(-3, -2) pair_repr = pair_repr + PairTransition( self.pair_transition_norm(pair_repr), self.pair_transition_proj1, self.pair_transition_proj2) - # pair_repr = pair_repr + checkpoint(PairTransition, - # self.pair_transition_norm(pair_repr), self.pair_transition_proj1, - # self.pair_transition_proj2) return (msa_repr, pair_repr) @@ -393,7 +355,6 @@ class AlphaFold2(nn.Module): def __init__(self, s: int, cm: int, cz: int, evo_num: int): super().__init__() self.evo_num = evo_num - # self.evoformers: List[torch.nn.Module] = [Evoformer(s, cm, cz) for _ in range(evo_num)] self.evoformer = Evoformer(s, cm, cz) @@ -447,7 +408,7 @@ def train_iter(model, dataloader): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num - warm_up, field_name='e2e'))) CudaTimer().print_all(times=iter_num - warm_up) - print_each_rank(torch.cuda.max_memory_allocated() / 1024 / 1024) + print_each_rank('memory consumption: {} MB'.format(int(torch.cuda.max_memory_allocated() / 1024 / 1024))) test() diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 0f233551..30f3c7f6 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -2,6 +2,21 @@ from cube.graph import IRGraph from cube.ir.operator import IRDataOperation, IRFwOperation +recompute_info = { + 'MSAAttention': True, + 'MSAAttentionWithBias': True, + 'MSARowAttentionWithPairBias': True, + 'MSATransition': True, + 'OuterProductMean': True, + 'TriangleMultiplication': True, + 'TriangleAttentionNode': True, + 'PairTransition': True, + 'add': False, + 'sum': False, + 'layernorm': False, + 'transpose': False, +} + def PASData(graph: IRGraph, resource): devs = list(range(resource.ngpus)) @@ -28,4 +43,6 @@ def PASData(graph: IRGraph, resource): num=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) + if node.name in recompute_info and recompute_info[node.name] == True: + graph.recompute(sub_nodes) return graph From 577159f87102843a0d6a11410b8b03e78e22b81f Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 27 Sep 2022 23:14:08 -0700 Subject: [PATCH 1041/1892] add naive pytorch module profiler --- examples/nlp/palm/module_profiler.py | 56 ++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 examples/nlp/palm/module_profiler.py diff --git a/examples/nlp/palm/module_profiler.py b/examples/nlp/palm/module_profiler.py new file mode 100644 index 00000000..ab942023 --- /dev/null +++ b/examples/nlp/palm/module_profiler.py @@ -0,0 +1,56 @@ +import torch +from cube.profiler import CudaTimer + +bs, n, dim, heads, dim_head = 10, 2048, 4096, 16, 256 +scale = 0.125 + +dev = torch.device('cuda:0') + +def multi_head_attention(x: torch.Tensor, qkv_proj: torch.Tensor, + out_proj: torch.Tensor): + + q, kv = torch.matmul(x, qkv_proj).split((dim, dim_head), dim=-1) + q = q.view(bs, n, heads, dim_head).transpose(1, 2) + q = q.reshape(bs, heads * n, dim_head) + trans_kv = kv.transpose(1, 2) + sim = torch.bmm(q, trans_kv).view(bs, heads, n, n) + attn = torch.nn.functional.softmax(sim, dim=-1) + attn = attn.view(bs, heads * n, n) + out = torch.bmm(attn, kv).view(bs, heads, n, dim_head) + out = torch.transpose(out, 1, 2).reshape(bs, n, dim) + out = torch.matmul(out, out_proj) + return out + +def ffn(x: torch.Tensor, xx: torch.Tensor, y: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor): + return torch.matmul(x, w1), torch.matmul(xx * y, w2) + +x = torch.randn(bs, n, dim).to(dev) +xx = torch.randn(bs, n, dim).to(dev) +y = torch.randn(bs, n, dim).to(dev) +qkv_proj = torch.randn(dim, dim+dim_head).to(dev) +q_proj = torch.randn(dim, dim).to(dev) +kv_proj = torch.randn(dim, dim_head).to(dev) +out_proj = torch.randn(dim, dim).to(dev) +w1 = torch.randn(dim, 2 * dim).to(dev) +w2 = torch.randn(dim, dim).to(dev) +score = torch.randn([bs * heads * n, n], requires_grad=True).to(dev) + +CudaTimer(enable=False).warmup() + +iter_num = 64 +warmup = 20 + +for step in range(iter_num): + softmax_score = torch.nn.functional.softmax(score, dim=-1) + if step >= warmup: + CudaTimer(enable=True).start('e2e') + # out = multi_head_attention(x, qkv_proj, out_proj) + # out = ffn(x, xx, y, w1, w2) + out = torch.autograd.grad(outputs=softmax_score, inputs=score, grad_outputs=softmax_score) + if step >= warmup: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print(f'iter [{step + 1}/{iter_num}]') + +print('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num - warmup, field_name='e2e'))) \ No newline at end of file From a7df9800cb2e30e5dd7062527a336ac63e54d4e4 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 28 Sep 2022 06:04:32 -0700 Subject: [PATCH 1042/1892] save tem work, not runnable --- examples/alphafold2/alphafold2.py | 29 ++++++++++++++++++++++------- examples/alphafold2/policy/spmd.py | 27 +++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 43685008..c382565b 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -80,7 +80,6 @@ def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, out = torch.matmul(out, out_proj) return out - """ ([bs, s, r, cm], [bs, r, r, cz]) -> [bs, s, r, cm] """ @@ -101,6 +100,13 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, return MSAAttentionWithBias(msa_repr, gate_proj, qkv_proj, out_proj, bias, head, c, scale) +@cube.graph.parser.register('N S R M, M E, M F, E M -> N S R M', name='MSAColAttention') +def MSAColAttention(msa_repr: torch.Tensor, + gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + head: int, c: int, + scale: float): + return MSAAttention(msa_repr.permute(0, 2, 1, 3), gate_proj, qkv_proj, out_proj, head, c, scale).permute(0, 2, 1, 3) """ [bs, s, r, cm] -> [bs, s, r, cm] @@ -302,12 +308,17 @@ def forward(self, msa_repr, pair_repr): self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, self.msa_head, self.c, self.scale) - - msa_repr = msa_repr.transpose(-3, -2) - msa_repr = msa_repr + MSAAttention( - self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - self.col_out_proj, self.msa_head, self.c, self.scale) - msa_repr = msa_repr.transpose(-3, -2) + + # return (msa_repr, pair_repr) + + # msa_repr = msa_repr.transpose(-3, -2) + # msa_repr = msa_repr + MSAAttention( + # self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, + # self.col_out_proj, self.msa_head, self.c, self.scale) + # msa_repr = msa_repr.transpose(-3, -2) + msa_repr = msa_repr + MSAColAttention(self.col_norm(msa_repr), self.col_gate_proj, + self.col_qkv_proj, self.col_out_proj, self.msa_head, self.c, self.scale) + # return (msa_repr, pair_repr) msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), self.msa_transition_proj1, @@ -369,6 +380,10 @@ def test(): model = AlphaFold2(s, cm, cz, 1) + # msa_repr, pair_repr = torch.randn(bs, s, r, cm), torch.randn(bs, r, r, cz) + # x = model(msa_repr, pair_repr) + # return + model = cube.SemanticModel( model, input_shapes=([bs, s, r, cm], [bs, r, r, cz], ), diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 30f3c7f6..acdf0e52 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -6,6 +6,7 @@ 'MSAAttention': True, 'MSAAttentionWithBias': True, 'MSARowAttentionWithPairBias': True, + 'MSAColAttention': True, 'MSATransition': True, 'OuterProductMean': True, 'TriangleMultiplication': True, @@ -46,3 +47,29 @@ def PASData(graph: IRGraph, resource): if node.name in recompute_info and recompute_info[node.name] == True: graph.recompute(sub_nodes) return graph + +def PASMegatron(graph: IRGraph, resource): + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + + def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for dev_id, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, dev_id) + return sub_nodes + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + batch_dim = node.get_batch_dims()[0] \ No newline at end of file From 91201d83a83ab0f9f41b58384d7bab28022fd07e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Sep 2022 13:01:12 +0800 Subject: [PATCH 1043/1892] fix grouping on no-backward operators --- cube/execplan/planpass/grouping.py | 36 +++++++++--------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 16164755..c9eff511 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -23,10 +23,10 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: for devid in execplan.devices(): for fpieces, bpieces in zip(fgroups[devid], bgroups[devid]): fsubgraph = graph.create_segment(fpieces) - if bpieces is not None: + if len(bpieces) > 0: bsubgraph = graph.create_segment(bpieces) IRCell.make_pair(fsubgraph, bsubgraph) - subgraphs = [fsubgraph] if bpieces is None else [fsubgraph, fsubgraph.mirror] + subgraphs = [fsubgraph] if len(bpieces) == 0 else [fsubgraph, fsubgraph.mirror] for subgraph in subgraphs: # update execution plan: replace the nodes with the subgraph pieces = subgraph.nodes() @@ -53,7 +53,7 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: Returns: Tuple: (fgroups, bgroups) """ - def is_forward_node(fnode): + def differentiable(fnode): if isinstance(fnode, IRFwOperation): return True if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.forward: @@ -63,27 +63,13 @@ def is_forward_node(fnode): fgroups, bgroups = dict(), dict() for devid in execplan.devices(): fgroups[devid], bgroups[devid] = list(), list() - seq = execplan.seq(devid) - fnodes = [node for node in seq if is_forward_node(node)] - have_backward = all(fnode.mirror in seq for fnode in fnodes) - fpieces = [] - - for fnode in seq: - if is_forward_node(fnode): - fpieces.append(fnode) - else: - if len(fpieces) != 0: - fgroups[devid].append(fpieces) - fpieces = [] - - if len(fpieces) != 0: - fgroups[devid].append(fpieces) - - for pieces in fgroups[devid]: - if have_backward: - bpieces = [fnode.mirror for fnode in pieces[::-1] if fnode.mirror is not None] - bgroups[devid].append(bpieces) - else: - bgroups[devid].append(None) + nodes = execplan.seq(devid) + break_idx = [idx for idx, node in enumerate(nodes) if not differentiable(node)] + for start, end in zip([-1] + break_idx, break_idx + [len(nodes)]): + if start+1 == end: continue + fpieces = nodes[start+1:end] + bpieces = [node.mirror for node in fpieces[::-1] if node.mirror is not None] + fgroups[devid].append(nodes[start+1:end]) + bgroups[devid].append(bpieces) return fgroups, bgroups From 726dc088c7e115ad91d8a2a7dd38d22b8b5d48b6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Sep 2022 19:30:45 +0800 Subject: [PATCH 1044/1892] loss gradient to IRTensor --- cube/codegen/codegen.py | 5 ++-- cube/execplan/planpass/fusion.py | 14 +++++------ cube/graph/gener/gen.py | 15 ++++++----- cube/graph/graph.py | 8 ++---- cube/graph/segment.py | 15 ++++------- cube/ir/tensor.py | 43 ++++++++++++++++++++++++-------- 6 files changed, 57 insertions(+), 43 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index ca7f1baa..915f33a7 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -134,6 +134,8 @@ def is_temp_tensor(v): fw_inputs, fw_outputs, output_grads, input_grads = \ get_backward_callsite_io_tensors(node) + # remove loss gradient + output_grads = [t for t in output_grads if not t.is_loss()] outputs = input_grads inputs = list(itertools.chain(fw_inputs, fw_outputs, output_grads)) @@ -1024,8 +1026,7 @@ def emit_node(self, node: IRCell, name: str) -> str: get_backward_callsite_io_tensors(node) for idx, tensor in enumerate(output_grads): - if isinstance(tensor, float): - assert tensor == 1.0, "Loss gradient should be 1.0" + if isinstance(tensor, IRSubTensor) and tensor.is_loss(): output_grads[idx] = None code = bsign.format( name = f"'{self.node_naming(node.mirror)}'", diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 74881138..f745b645 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -39,13 +39,13 @@ def _apply(segment: IRSegment) -> int: for node in segment.nodes(): if isinstance(node, IRAdapter) and node.forward: ret = DiffFusion.nnfuse(node) - if not ret and not node.differentiable: - raise NotImplementedError( - f"Adapter within IRSegment cannot fuse to differientiable adapter" - f"\nforward: {node.extra_repr()}" - f"\nbackward: {node.mirror.extra_repr()}" - ) - cnt = cnt + 1 + # if not ret and not node.differentiable: + # raise NotImplementedError( + # f"Adapter within IRSegment cannot fuse to differientiable adapter" + # f"\nforward: {node.extra_repr()}" + # f"\nbackward: {node.mirror.extra_repr()}" + # ) + cnt = cnt + 1 if ret else cnt elif isinstance(node, IRSegment) and node.isfw(): cnt += DiffFusion._apply(node) return cnt diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index c369c970..0afbd62c 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -47,11 +47,12 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: @return nodes List[IRCell]: the generated operation """ - devices = segment.device + # devices = segment.device fwops = [] # create inputs for tensor in segment.inputs(): + devices = [consumer.device for consumer in segment.consumers(tensor.parent)][::-1] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" fwop = DummyInputOuput(tensor, 0, is_output=True) @@ -67,6 +68,7 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: # create outputs for tensor in segment.outputs(): + devices = [producer.device for producer in segment.producers(tensor.parent)] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" fwop = DummyInputOuput(tensor, 0, is_input=True) @@ -74,10 +76,7 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: fop = fwop.replicate() fop.device = devid if tensor.requires_grad and segment.mirror != segment: - if isinstance(tensor.grad, float): - fop.input(0).grad = tensor.grad - else: - fop.input(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) + fop.input(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) segment.finsert(fop, segment.nnodes) else: segment.insert(fop, segment.nnodes) @@ -320,10 +319,10 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: continue if not isinstance(graph, IRGraph): - if not fadapter.differentiable: + if not (fadapter.differentiable or fadapter.mirror is None): raise NotImplementedError( - "Require adapter to be differentiable for nested IRAdapter." - "Condition to be differentiable: prodcuers have same device set with consumers" + "Require adapter to be differentiable for nested IRAdapter.\n" + "Condition to be differentiable: prodcuers have same device set with consumers\n" f"Failed FullTensor: {ftensor}" f"{graph.debug_tensor_map_str(ftensor)}" f"Failed FullTensor.grad: {ftensor.grad}" diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 2e4ca693..77f4f481 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -137,8 +137,7 @@ def backward(self, loss: IRSubTensor): otensor.requires_grad = require_grad # set loss gradient - assert tuple(loss.shape) == (1,), f"the loss should be of shape [1,] (got {loss.shape})" - loss.parent.grad = 1.0 + loss.parent.to_loss() # infer gradient for ftensor in self._ftensors: @@ -370,10 +369,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], if not otensor.requires_grad: grad = None else: - if isinstance(otensor.parent.grad, float): - grad = otensor.parent.grad - else: - grad = otensor.parent.grad.select(otensor.indmap, (0,1)) + grad = otensor.parent.grad.select(otensor.indmap, (0,1)) otensor.grad = grad # insert forward node diff --git a/cube/graph/segment.py b/cube/graph/segment.py index d95a2842..91de1fea 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -306,9 +306,6 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: idx = producer.outputs().index(ptensor) if fgrad is None: grad = None - elif isinstance(fgrad, float): - assert fgrad == 1.0, "Detect a backward tensor, but gradient can only be 1.0" - grad = fgrad else: grad = fgrad.select(ptensor.indmap, (0, 1)) producer.output(idx).grad = grad @@ -325,9 +322,6 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: idx = consumer.inputs().index(ctensor) if fgrad is None: grad = None - elif isinstance(fgrad, float): - assert fgrad == 1.0, "Detect a backward tensor, but gradient can only be 1.0" - grad = fgrad else: valmap = curr_valmap.map((0, 2)) if cidx != nconsumers - 1 else curr_valmap grad = fgrad.select(ctensor.indmap, valmap) @@ -532,18 +526,19 @@ def exist(self, node: IRCell) -> bool: return True return False - def select(self, name: Optional[str] = None, ntype: Optional[IRCell] = None) -> List[IRCell]: + def select(self, name: Optional[str] = None, ntype: Optional[IRCell] = None, flatten: bool = True) -> List[IRCell]: """ Select all the nodes (including nodes in sub-segment) that satisfy the condition. - @param name str: the node name - @param ntype Type: the node type + @param name Optional[str]: the node name + @param ntype Optional[Type]: the node type + @param flatten bool: whether to flatten the segment to nodes. (Default True) @return nodes List[IRCell]: the nodes that have the name. """ nodes = [] - for node in self.nodes(flatten=True): + for node in self.nodes(flatten=flatten): if name is not None: if node.name != name: continue diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index c110f850..3ab88072 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -264,6 +264,7 @@ def __init__(self, shape=None, name=None, requires_grad=False, dtype=IRDType.unk self._segments : Dict[(ValueMap, IndexMap), int] = dict() self.requires_grad = requires_grad + self._is_loss = False def __hash__(self) -> int: return self._id @@ -287,23 +288,39 @@ def like(self): return tensor @property - def grad(self) -> Optional[Union[IRTensor, float]]: + def grad(self) -> Optional[IRTensor]: return self._grad @grad.setter - def grad(self, val: Optional[Union[IRTensor, float]]): + def grad(self, val: Optional[IRTensor]): """ int indicates the tensor is the loss tensor. """ if self._requires_grad: - assert isinstance(val, (IRFullTensor, float)) - if isinstance(val, IRFullTensor): - assert val.shape == self.shape - assert val.is_attr() == self.is_attr() + assert isinstance(val, IRFullTensor) + assert val.shape == self.shape + assert val.is_attr() == self.is_attr() else: assert val is None, "The FullTensor doesn't require grad but is assigned with a grad." self._grad = val + def is_loss(self) -> bool: + """ + Check whether this tensor is a loss tensor + + @return loss bool: True if the tensor is loss + """ + return self._is_loss + + def to_loss(self): + """ + Set this tensor is loss tensor. The tensor shape must be [1,] + """ + assert tuple(self.shape) == (1,), f"Loss tensor can only have shape [1,] but got {self.shape}" + assert self.requires_grad, f"The tensor doesn't require gradient. Cannot backward" + self._is_loss = True + self.grad._is_loss = True + @property def requires_grad(self): return self._requires_grad @@ -612,10 +629,8 @@ def grad(self) -> bool: @grad.setter def grad(self, val: Optional[IRTensor]): - if isinstance(val, (IRSubTensor, float)): - assert self.requires_grad - if isinstance(val, IRSubTensor): - val.shape == self.shape + if isinstance(val, IRSubTensor): + assert self.requires_grad and val.shape == self.shape self._grad = val elif val is None: assert not self.requires_grad @@ -623,6 +638,14 @@ def grad(self, val: Optional[IRTensor]): else: raise ValueError(f"Expected grad to be None or IRSubTensor but got: {val}") + def is_loss(self) -> bool: + """ + Check whether this tensor is loss tensor. + + @return loss bool: True if the tensor is a loss tensor. + """ + return self.parent.is_loss() + # partition primitives def select(self, From 17d4007c228e074d8fa67ccad1d1350a006256d3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Sep 2022 19:31:43 +0800 Subject: [PATCH 1045/1892] fix embedding bug --- cube/graph/function/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 93516cf1..e170c707 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -861,7 +861,7 @@ def Embedding(signature, inputs: List): """ signature = 'cube.runtime.function.embedding' itensor, weight = inputs[:2] - padding_idx = inputs[3] + padding_idx = inputs[2] if isinstance(weight, IRSubTensor): start, stop = weight.indmap[0] else: From 0e30658077c5aa3c1be0e9de694c5b0c585a453c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 29 Sep 2022 19:47:35 +0800 Subject: [PATCH 1046/1892] full tp/dp/pp megatron policy --- .gitignore | 7 +++- examples/nlp/gpt/policy/mpmd.py | 59 +++++++++++++++++++++------------ 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 6988bde4..951faa42 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,9 @@ __pycache__ *.egg-info .vs/ -.vscode/ \ No newline at end of file +.vscode/ + +benchmark/megatron/Megatron-LM + +gencode*.py +fullmodel.pt diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index 7ce693d5..ec4c1783 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -43,7 +43,7 @@ def _group_to_transformers(fnodes) -> List[List[IRCell]]: anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] indices = [fnodes.index(anchor) for anchor in anchors] for lid, idx in enumerate(indices): - fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + fnodes[idx+1].comment = f'===> start of transformer layer {lid}' start = idx if lid != 0 else 0 end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) transformers.append(fnodes[start:end]) @@ -133,10 +133,10 @@ def PASMegatron(graph: IRGraph, resource): """ 1F1B scheduling """ - dp_size = 1 + dp_size = 2 tp_size = 2 pp_size = resource.ngpus // (dp_size * tp_size) - num_microbatch = resource.ngpus + num_microbatch = pp_size * 2 # device mesh dp_groups, pp_groups, tp_groups = \ @@ -145,33 +145,50 @@ def PASMegatron(graph: IRGraph, resource): print(f'pp groups: {pp_groups}') print(f'tp groups: {tp_groups}') + def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: + return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] + # group to transformer layers - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - transformers = _group_to_transformers(fnodes) + transformers = _group_to_transformers(graph.select(ntype=IRFwOperation)) - # inter-staging: set each stage operators + # group to stage: set each stage operators fstages = [[] for _ in range(pp_size)] - nlayer_per_stage = (len(transformers) // resource.ngpus) + nlayer_per_stage = (len(transformers) // pp_size) for lid, fnodes in enumerate(transformers): stage_id = min(lid // nlayer_per_stage, pp_size - 1) fstages[stage_id] += fnodes graph.staging(tuple(stages[0] for stages in fstages)) - # intra-stage: tp and dp parallelism on device group - fsegments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] - assert len(fsegments) == pp_size - for sid, segment in enumerate(fsegments): - for fnode in segment.nodes(): - if fnode.name == 'self_attention': - _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) - elif fnode.name == 'feedforward': - _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) - else: - _replica(graph, fnode, tp_groups[sid]) + dataloader = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[0] + + # partition dataloader + dls = _replica(graph, dataloader, [0]*dp_size) # graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) + for dp_idx, dl in enumerate(dls): + # only stage 0 needs dataloader + devices = [get_device(dp_idx, 0, tp_idx) for tp_idx in range(tp_size)] + _replica(graph, dl, devices) - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - _replica(graph, node, tp_groups[0]) + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + assert len(fstages) > 0 + for pp_idx, fstage in enumerate(fstages): + for fnode in fstage.nodes(): + if len(fnode.inputs()) == 0: continue # anchor + # tensor parallel -- FIXME: current restriction needs replica happen before partition + if fnode.name == 'self_attention' or fnode.name == 'feedforward': + fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + elif fnode.name == 'embedding': + fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + elif fnode.name == 'linear': # the last embeding linear + fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + else: + fnodes = _replica(graph, fnode, [0]*tp_size) + # data parallel + for tp_idx, fnode in enumerate(fnodes): + dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] + print(dp_devices) + batch_dim = fnode.input(0).shape.index(bs) + _tp(graph, fnode, devs=dp_devices, idx=0, dim=batch_dim, num=dp_size) strategy = IRSchedule1F1B(graph, num_microbatch) graph.predef_sched(strategy) From 43d15847c7bd4504e097b705b0387bebe16d4c6c Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Thu, 29 Sep 2022 05:11:48 -0700 Subject: [PATCH 1047/1892] runnable, tensor parallelism --- examples/alphafold2/alphafold2.py | 240 +++++++++++++++++++---------- examples/alphafold2/policy/spmd.py | 52 +++++-- 2 files changed, 199 insertions(+), 93 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index c382565b..4803ff65 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -10,18 +10,18 @@ import examples.alphafold2.policy.spmd as spmd cube.init() - - """ [bs, s, r, cm] -> [bs, s, r, cm] used as column-wise gated self-attention """ -@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R M', name='MSAAttention') + +@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R M', + name='MSAAttention') def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - head: int, c: int, scale: float): + qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, + c: int, scale: float): bs, s, r, cm = x.size() gate = torch.sigmoid(torch.matmul(x, gate_proj)) @@ -48,10 +48,11 @@ def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, return out -@cube.graph.parser.register('N S R M, M E, M F, E M, N 1 8 R R -> N S R M', name='MSAAttentionWithBias') +@cube.graph.parser.register('N S R M, M E, M F, E M, N 1 8 R R -> N S R M', + name='MSAAttentionWithBias') def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias: torch.Tensor, - head: int, c: int, scale: float): + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias: torch.Tensor, head: int, c: int, scale: float): bs, s, r, cm = x.size() gate = torch.sigmoid(torch.matmul(x, gate_proj)) @@ -80,12 +81,15 @@ def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, out = torch.matmul(out, out_proj) return out + """ ([bs, s, r, cm], [bs, r, r, cz]) -> [bs, s, r, cm] """ + # note: code not reused constrained by cube's interface -@cube.graph.parser.register('N S R M, N R R Z, M E, M F, E M, Z H -> N S R M', name='MSARowAttentionWithPairBias') +@cube.graph.parser.register('N S R M, N R R Z, M E, M F, E M, Z H -> N S R M', + name='MSARowAttentionWithPairBias') def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, pair_repr: torch.Tensor, gate_proj: torch.Tensor, @@ -98,21 +102,26 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, bias_proj).permute(0, 3, 1, 2).reshape(bs, 1, head, r, r) - return MSAAttentionWithBias(msa_repr, gate_proj, qkv_proj, out_proj, bias, head, c, scale) + return MSAAttentionWithBias(msa_repr, gate_proj, qkv_proj, out_proj, bias, + head, c, scale) + + +@cube.graph.parser.register('N S R M, M E, M F, E M -> N S R M', + name='MSAColAttention') +def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, + c: int, scale: float): + return MSAAttention(msa_repr.permute(0, 2, 1, 3), gate_proj, qkv_proj, + out_proj, head, c, scale).permute(0, 2, 1, 3) -@cube.graph.parser.register('N S R M, M E, M F, E M -> N S R M', name='MSAColAttention') -def MSAColAttention(msa_repr: torch.Tensor, - gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - head: int, c: int, - scale: float): - return MSAAttention(msa_repr.permute(0, 2, 1, 3), gate_proj, qkv_proj, out_proj, head, c, scale).permute(0, 2, 1, 3) """ [bs, s, r, cm] -> [bs, s, r, cm] """ -@cube.graph.parser.register('N S R M, M E, E M -> N S R M', name='MSATransition') + +@cube.graph.parser.register('N S R M, M E, E M -> N S R M', + name='MSATransition') def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): return torch.matmul( @@ -124,43 +133,47 @@ def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, """ -@cube.graph.parser.register('N S R M, M C, M C, F Z -> N R R Z', name='OuterProductMean') -def OuterProductMean(msa_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor, out_proj: torch.Tensor): +@cube.graph.parser.register('N S R M, N S T M, M C, M C, F Z -> N R T Z', + name='OuterProductMean') +def OuterProductMean(msa_repr: torch.Tensor, dummy_msa_repr: torch.Tensor, + proj1: torch.Tensor, proj2: torch.Tensor, + out_proj: torch.Tensor): bs, s, r, cm = msa_repr.size() + t = dummy_msa_repr.size(2) c = proj1.size(-1) a = torch.matmul(msa_repr, proj1).transpose(-2, -3) - b = torch.matmul(msa_repr, proj2).transpose(-2, -3) + b = torch.matmul(dummy_msa_repr, proj2).transpose(-2, -3) outer = torch.einsum('...bac,...dae->...bdce', a, - b).reshape(bs, r, r, c * c) + b).reshape(bs, r, t, c * c) outer = torch.matmul(outer, out_proj) return outer -@cube.graph.parser.register('N R R Z, Z, Z, E, E, Z E, Z E, Z E, Z E, E Z, Z Z -> N R R Z', name='TriangleMultiplication') -def TriangleMultiplication( - pair_repr: torch.Tensor, tri_mul_norm1_weight: torch.Tensor, - tri_mul_norm1_bias: torch.Tensor, tri_mul_norm2_weight: torch.Tensor, - tri_mul_norm2_bias: torch.Tensor, tri_mul_proj1: torch.Tensor, - tri_mul_proj2: torch.Tensor, tri_mul_proj3: torch.Tensor, - tri_mul_proj4: torch.Tensor, tri_mul_proj5: torch.Tensor, - tri_mul_proj6: torch.Tensor, cz: int, out_going: bool): +@cube.graph.parser.register( + 'N S R Z, N T R Z, Z, Z, E, E, Z E, Z E, Z E, Z E, E Z, Z Z -> N S T Z', + name='TriangleMultiplicationOut') +def TriangleMultiplicationOut( + pair_repr: torch.Tensor, dummy_pair_repr: torch.Tensor, + tri_mul_norm1_weight: torch.Tensor, tri_mul_norm1_bias: torch.Tensor, + tri_mul_norm2_weight: torch.Tensor, tri_mul_norm2_bias: torch.Tensor, + tri_mul_proj1: torch.Tensor, tri_mul_proj2: torch.Tensor, + tri_mul_proj3: torch.Tensor, tri_mul_proj4: torch.Tensor, + tri_mul_proj5: torch.Tensor, tri_mul_proj6: torch.Tensor, cz: int): pair_repr = torch.nn.functional.layer_norm(pair_repr, (cz, ), tri_mul_norm1_weight, tri_mul_norm1_bias) + dummy_pair_repr = torch.nn.functional.layer_norm(dummy_pair_repr, (cz, ), + tri_mul_norm1_weight, + tri_mul_norm1_bias) a = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj1)) a = a * torch.matmul(pair_repr, tri_mul_proj2) - b = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj3)) - b = b * torch.matmul(pair_repr, tri_mul_proj4) + b = torch.sigmoid(torch.matmul(dummy_pair_repr, tri_mul_proj3)) + b = b * torch.matmul(dummy_pair_repr, tri_mul_proj4) - if out_going: - a = a.permute(0, 3, 1, 2) - b = b.permute(0, 3, 2, 1) - else: - a = a.permute(0, 3, 2, 1) - b = b.permute(0, 3, 1, 2) + a = a.permute(0, 3, 1, 2) + b = b.permute(0, 3, 2, 1) p = torch.matmul(a, b).permute(0, 2, 3, 1) p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, @@ -170,24 +183,75 @@ def TriangleMultiplication( return p * g -@cube.graph.parser.register('N R R Z, Z E, Z F, E Z, Z G -> N R R Z', name='TriangleAttentionNode') -def TriangleAttentionNode(pair_repr: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias_proj: torch.Tensor, head: int, c: int, - scale: float): - bias = torch.matmul(pair_repr, bias_proj).permute(0, 3, 1, 2) +@cube.graph.parser.register( + 'N R S Z, N R T Z, Z, Z, E, E, Z E, Z E, Z E, Z E, E Z, Z Z -> N S T Z', + name='TriangleMultiplicationIn') +def TriangleMultiplicationIn( + pair_repr: torch.Tensor, dummy_pair_repr: torch.Tensor, + tri_mul_norm1_weight: torch.Tensor, tri_mul_norm1_bias: torch.Tensor, + tri_mul_norm2_weight: torch.Tensor, tri_mul_norm2_bias: torch.Tensor, + tri_mul_proj1: torch.Tensor, tri_mul_proj2: torch.Tensor, + tri_mul_proj3: torch.Tensor, tri_mul_proj4: torch.Tensor, + tri_mul_proj5: torch.Tensor, tri_mul_proj6: torch.Tensor, cz: int): + pair_repr = torch.nn.functional.layer_norm(pair_repr, (cz, ), + tri_mul_norm1_weight, + tri_mul_norm1_bias) + dummy_pair_repr = torch.nn.functional.layer_norm(dummy_pair_repr, (cz, ), + tri_mul_norm1_weight, + tri_mul_norm1_bias) + a = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj1)) + a = a * torch.matmul(pair_repr, tri_mul_proj2) + b = torch.sigmoid(torch.matmul(dummy_pair_repr, tri_mul_proj3)) + b = b * torch.matmul(dummy_pair_repr, tri_mul_proj4) + + a = a.permute(0, 3, 2, 1) + b = b.permute(0, 3, 1, 2) + + p = torch.matmul(a, b).permute(0, 2, 3, 1) + p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, + tri_mul_norm2_bias) + p = torch.matmul(p, tri_mul_proj5) + g = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj6)).transpose(1, 2) + return p * g + + +@cube.graph.parser.register('N S R Z, Z E, Z F, E Z, Z G -> N S R Z', + name='TriangleAttentionNodeStart') +def TriangleAttentionNodeStart(pair_repr: torch.Tensor, + gate_proj: torch.Tensor, qkv_proj: torch.Tensor, + out_proj: torch.Tensor, bias_proj: torch.Tensor, + head: int, c: int, scale: float): + bias = torch.matmul(pair_repr, bias_proj).permute(0, 1, 3, 2).unsqueeze(3) + + return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, + head, c, scale) + - return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, head, - c, scale) +@cube.graph.parser.register('N S R Z, Z E, Z F, E Z, Z G -> N S R Z', + name='TriangleAttentionNodeEnd') +def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias_proj: torch.Tensor, head: int, c: int, + scale: float): + pair_repr = pair_repr.permute(0, 2, 1, 3) + out = TriangleAttentionNodeStart(pair_repr, gate_proj, qkv_proj, out_proj, + bias_proj, head, c, scale) + return out.permute(0, 2, 1, 3) -@cube.graph.parser.register('N R R Z, Z E, E Z -> N R R Z', name='PairTransition') +@cube.graph.parser.register('N R T Z, Z E, E Z -> N R T Z', + name='PairTransition') def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): return torch.matmul( torch.nn.functional.relu(torch.matmul(pair_repr, proj1)), proj2) +@cube.graph.parser.register('* -> *, *', name='multi2ref') +def multi2ref(x: torch.Tensor): + return (x, x) + + """ a simplified version for evoformer in alphafold2 - dropout layers are omitted @@ -308,53 +372,48 @@ def forward(self, msa_repr, pair_repr): self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, self.msa_head, self.c, self.scale) - - # return (msa_repr, pair_repr) - - # msa_repr = msa_repr.transpose(-3, -2) - # msa_repr = msa_repr + MSAAttention( - # self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - # self.col_out_proj, self.msa_head, self.c, self.scale) - # msa_repr = msa_repr.transpose(-3, -2) - msa_repr = msa_repr + MSAColAttention(self.col_norm(msa_repr), self.col_gate_proj, - self.col_qkv_proj, self.col_out_proj, self.msa_head, self.c, self.scale) - # return (msa_repr, pair_repr) + + msa_repr = msa_repr + MSAColAttention( + self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, + self.col_out_proj, self.msa_head, self.c, self.scale) msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), self.msa_transition_proj1, self.msa_transition_proj2) + pair_msa_repr, dummy_pair_msa_repr = multi2ref( + self.outer_norm(msa_repr)) pair_repr = pair_repr + OuterProductMean( - self.outer_norm(msa_repr), self.outer_proj1, self.outer_proj2, - self.outer_out_proj) + pair_msa_repr, dummy_pair_msa_repr, self.outer_proj1, + self.outer_proj2, self.outer_out_proj) - pair_repr = pair_repr + TriangleMultiplication( - pair_repr, self.tri_mul_out_norm1_weight, + out_pair_repr, out_dummy_pair_repr = multi2ref(pair_repr) + pair_repr = pair_repr + TriangleMultiplicationOut( + out_pair_repr, out_dummy_pair_repr, self.tri_mul_out_norm1_weight, self.tri_mul_out_norm1_bias, self.tri_mul_out_norm2_weight, self.tri_mul_out_norm2_bias, self.tri_mul_out_proj1, self.tri_mul_out_proj2, self.tri_mul_out_proj3, self.tri_mul_out_proj4, self.tri_mul_out_proj5, - self.tri_mul_out_proj6, self.cz, True) + self.tri_mul_out_proj6, self.cz) - pair_repr = pair_repr + TriangleMultiplication( - pair_repr, self.tri_mul_in_norm1_weight, + in_pair_repr, in_dummy_pair_repr = multi2ref(pair_repr) + pair_repr = pair_repr + TriangleMultiplicationIn( + in_pair_repr, in_dummy_pair_repr, self.tri_mul_in_norm1_weight, self.tri_mul_in_norm1_bias, self.tri_mul_in_norm2_weight, self.tri_mul_in_norm2_bias, self.tri_mul_in_proj1, self.tri_mul_in_proj2, self.tri_mul_in_proj3, self.tri_mul_in_proj4, self.tri_mul_in_proj5, - self.tri_mul_in_proj6, self.cz, False) + self.tri_mul_in_proj6, self.cz) - pair_repr = pair_repr + TriangleAttentionNode( + pair_repr = pair_repr + TriangleAttentionNodeStart( self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, self.tri_att_start_bias_proj, self.pair_head, self.c, self.scale) - pair_repr = pair_repr.transpose(-3, -2) - pair_repr = pair_repr + TriangleAttentionNode( + pair_repr = pair_repr + TriangleAttentionNodeEnd( self.tri_att_end_norm(pair_repr), self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, self.tri_att_end_bias_proj, self.pair_head, self.c, self.scale) - pair_repr = pair_repr.transpose(-3, -2) pair_repr = pair_repr + PairTransition( self.pair_transition_norm(pair_repr), self.pair_transition_proj1, @@ -362,13 +421,14 @@ def forward(self, msa_repr, pair_repr): return (msa_repr, pair_repr) + class AlphaFold2(nn.Module): + def __init__(self, s: int, cm: int, cz: int, evo_num: int): super().__init__() self.evo_num = evo_num self.evoformer = Evoformer(s, cm, cz) - def forward(self, msa, pair): new_msa, new_pair = self.evoformer(msa, pair) loss = torch.sum(new_msa) * torch.sum(new_pair) @@ -376,7 +436,7 @@ def forward(self, msa, pair): def test(): - bs, s, r, cm, cz = 4, 128, 256, 256, 128 + bs, s, r, cm, cz = 1, 128, 384, 256, 128 model = AlphaFold2(s, cm, cz, 1) @@ -385,15 +445,28 @@ def test(): # return model = cube.SemanticModel( - model, - input_shapes=([bs, s, r, cm], [bs, r, r, cz], ), - ) - - dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], [bs, r, r, cz],), - dtypes=(torch.float32, torch.float32, ), - batch_dims=(0, 0, )) - - @cube.compile(model, dataloader, PAS=spmd.PASData) + model, + input_shapes=( + [bs, s, r, cm], + [bs, r, r, cz], + ), + ) + + dataloader = cube.runtime.syndata.SynDataLoader(shapes=( + [bs, s, r, cm], + [bs, r, r, cz], + ), + dtypes=( + torch.float32, + torch.float32, + ), + batch_dims=( + 0, + 0, + )) + + # @cube.compile(model, dataloader, PAS=spmd.PASData) + @cube.compile(model, dataloader, PAS=spmd.PASMegatron, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) @@ -423,7 +496,8 @@ def train_iter(model, dataloader): print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num - warm_up, field_name='e2e'))) CudaTimer().print_all(times=iter_num - warm_up) - print_each_rank('memory consumption: {} MB'.format(int(torch.cuda.max_memory_allocated() / 1024 / 1024))) + print_each_rank('memory consumption: {} MB'.format( + int(torch.cuda.max_memory_allocated() / 1024 / 1024))) test() diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index acdf0e52..b6fba9d4 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -1,4 +1,6 @@ from typing import List + +from numpy import TooHardError from cube.graph import IRGraph from cube.ir.operator import IRDataOperation, IRFwOperation @@ -9,8 +11,10 @@ 'MSAColAttention': True, 'MSATransition': True, 'OuterProductMean': True, - 'TriangleMultiplication': True, - 'TriangleAttentionNode': True, + 'TriangleMultiplicationOut': True, + 'TriangleMultiplicationIn': True, + 'TriangleAttentionNodeStart': True, + 'TriangleAttentionNodeEnd': True, 'PairTransition': True, 'add': False, 'sum': False, @@ -18,6 +22,7 @@ 'transpose': False, } + def PASData(graph: IRGraph, resource): devs = list(range(resource.ngpus)) @@ -44,17 +49,24 @@ def PASData(graph: IRGraph, resource): num=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) - if node.name in recompute_info and recompute_info[node.name] == True: + if node.name in recompute_info and recompute_info[ + node.name] == True: graph.recompute(sub_nodes) return graph + def PASMegatron(graph: IRGraph, resource): tp_size = resource.ngpus tp_devs = list(range(tp_size)) - def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): + def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, + dim: int): algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + sub_nodes = graph.partition(node, + algo, + idx=idx, + dim=dim, + num=len(devs)) assert sub_nodes is not None for devid, sub_node in zip(devs, sub_nodes): graph.assign(sub_node, devid) @@ -68,8 +80,28 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): for node in graph.nodes(): if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] \ No newline at end of file + _replica(graph, node, tp_devs) + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if node.name == 'MSARowAttentionWithPairBias': + _tp(graph, node, tp_devs, 0, 1) + elif node.name == 'MSAColAttention': + _tp(graph, node, tp_devs, 0, 2) + elif node.name == 'MSATransition': + _tp(graph, node, tp_devs, 0, 1) + elif node.name == 'OuterProductMean': + _tp(graph, node, tp_devs, 0, 2) + elif node.name == 'TriangleMultiplicationOut': + _tp(graph, node, tp_devs, 0, 1) + elif node.name == 'TriangleMultiplicationIn': + _tp(graph, node, tp_devs, 0, 2) + elif node.name == 'TriangleAttentionNodeStart': + _tp(graph, node, tp_devs, 0, 1) + elif node.name == 'TriangleAttentionNodeEnd': + _tp(graph, node, tp_devs, 0, 2) + elif node.name == 'PairTransition': + _tp(graph, node, tp_devs, 0, 1) + else: + _replica(graph, node, tp_devs) + return graph From c1e7a7d1d06498fdf2de2b51660c2c4e724857c8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 30 Sep 2022 00:15:22 +0800 Subject: [PATCH 1048/1892] fix inference bug when partitioning weight --- cube/compiler.py | 2 +- cube/graph/segment.py | 2 +- cube/ir/cten.py | 4 ++-- cube/ir/tensor.py | 3 ++- cube/program.py | 11 +++++++++++ 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 347d57c9..74cb4879 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -88,7 +88,6 @@ def _load_tschedule_fn(filename) -> Callable: def decorator(fn: Callable) -> Callable: filename = 'gencode{}.py' - batch_size = torch.tensor([-1], dtype=torch.int).cuda() if not override and os.path.exists(filename.format(myrank)): filename = filename.format(myrank) @@ -109,6 +108,7 @@ def decorator(fn: Callable) -> Callable: # run once to get model structure and tensor shape outputs = fn(model_graph, ir_dataloader) + Program().finalize() if outputs is None: outputs = [] elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 91de1fea..0f0b0ae6 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -123,7 +123,7 @@ def attributes(self) -> Tuple[IRFullTensor]: @return ftensors List[IRFullTensor] """ - return Tuple(self._attributes) + return tuple(self._attributes) def reset_dependency(self): """ diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 82060b6d..d99a29bd 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -516,7 +516,7 @@ def is_param(self) -> bool: @return is_param boolean: True if is parameter. """ - return self._is_attr and self._requires_grad + return self._is_attr and self.requires_grad def is_buffer(self) -> bool: """! @@ -524,7 +524,7 @@ def is_buffer(self) -> bool: @return is_buffer boolean: True if is buffer. """ - return self._is_attr and not self._requires_grad + return self._is_attr and not self.requires_grad def is_grad(self) -> bool: """! diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 3ab88072..20633768 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -621,7 +621,8 @@ def __copy__(self): @property def requires_grad(self) -> bool: - return self.parent._requires_grad + self._requires_grad = self.parent.requires_grad + return self.parent.requires_grad @property def grad(self) -> bool: diff --git a/cube/program.py b/cube/program.py index 1ca0f409..3c691f36 100644 --- a/cube/program.py +++ b/cube/program.py @@ -49,6 +49,17 @@ def set_output(self, outputs: List[IRTensor]): self.instance._graph.reset_outputs(len(outputs)) for idx, otensor in enumerate(outputs): self.instance._graph.set_output(idx, otensor) + + def finalize(self): + """ + Close the recording of program. + If the program doesn't do backward, set all tensors with requires_grad=False. + """ + graph = self.get_graph() + if not any(isinstance(node, IRBpOperation) for node in graph.nodes()): + for ftensor in graph.full_tensors(): + ftensor.requires_grad = False + def mirror_as_self(self): """ From a20d5b15945332792c66f57a52bcc0c700e74a6e Mon Sep 17 00:00:00 2001 From: lynex Date: Fri, 30 Sep 2022 13:10:02 +0800 Subject: [PATCH 1049/1892] update PyTorch example: regressive generation of GPT inference, update performance with merged QKV tensor --- examples/nlp/blocks/attention.py | 77 ++++++++++++++++++++------------ examples/nlp/blocks/encoder.py | 3 +- examples/nlp/gpt/model.py | 9 ++-- 3 files changed, 56 insertions(+), 33 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 25123811..eba5d666 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -99,22 +99,27 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, return output -@cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> l N E^', name='one_attention') +# @cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> l N E^', name='one_attention') +@cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d 3) E^, (h+ d 3), E^ (h+ d) -> l N E^', name='one_attention') def one_attention(hidden_states: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor, - q_proj: torch.Tensor, q_bias: torch.Tensor, - k_proj: torch.Tensor, k_bias: torch.Tensor, - v_proj: torch.Tensor, v_bias: torch.Tensor, + # q_proj: torch.Tensor, q_bias: torch.Tensor, + # k_proj: torch.Tensor, k_bias: torch.Tensor, + # v_proj: torch.Tensor, v_bias: torch.Tensor, + qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, #out_bias: torch.Tensor, h: int, scale: float, dropout_p: float, is_training: bool = True, mask: bool = True): num_head = h l, N = hidden_states.size(0), hidden_states.size(1) - dim_head = q_proj.size(0) // num_head + # dim_head = q_proj.size(0) // num_head + dim_head = qkv_proj.size(0) // num_head // 3 - q = torch.nn.functional.linear(hidden_states, q_proj, q_bias) # l N E, (h d) E -> l N (h d) - k = torch.nn.functional.linear(hidden_states, k_proj, k_bias) # l N E, (h d) E -> l N (h d) - v = torch.nn.functional.linear(hidden_states, v_proj, v_bias) # l N E, (h d) E -> l N (h d) + # q = torch.nn.functional.linear(hidden_states, q_proj, q_bias) # l N E, (h d) E -> l N (h d) + # k = torch.nn.functional.linear(hidden_states, k_proj, k_bias) # l N E, (h d) E -> l N (h d) + # v = torch.nn.functional.linear(hidden_states, v_proj, v_bias) # l N E, (h d) E -> l N (h d) + qkv = torch.nn.functional.linear(hidden_states, qkv_proj, qkv_bias) # l N E, (h d 3) E -> l N (h d) 3 + q, k, v = qkv.chunk(3, dim=-1) if past_embed_key is not None and past_embed_value is not None: k = torch.cat((past_embed_key, k), dim=-3) @@ -125,19 +130,31 @@ def one_attention(hidden_states: torch.Tensor, k_L = k.size(0) v_L = v.size(0) - q = q.contiguous().view(l, (N * num_head), dim_head) # l N (h d) -> L (N h) d + q = q.contiguous().view(l, (N * num_head), dim_head) # l N (h d) -> l (N h) d k = k.contiguous().view(k_L, (N * num_head), dim_head) # (L+l) N (h d) -> (L+l) (N h) d v = v.contiguous().view(v_L, (N * num_head), dim_head) # (L+l) N (h d) -> (L+l) (N h) d - q = q.transpose(0, 1) # l (N h) d -> (N h) l d - k = k.transpose(0, 1) # (L+l) (N h) d -> (N h) (L+l) d - v = v.transpose(0, 1) # (L+l) (N h) d -> (N h) (L+l) d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) (L+l) d -> (N h) d (L+l) - attn = torch.bmm(q, k) # (N h) l d, (N h) d (L+l) -> (N h) l (L+l) + + # q = q.transpose(0, 1) # l (N h) d -> (N h) l d + # k = k.transpose(0, 1) # (L+l) (N h) d -> (N h) (L+l) d + # v = v.transpose(0, 1) # (L+l) (N h) d -> (N h) (L+l) d + # q = q * scale # (N h) L d, 1 -> (N h) L d + # k = k.transpose(1, 2) # (N h) (L+l) d -> (N h) d (L+l) + # attn = torch.bmm(q, k) # (N h) l d, (N h) d (L+l) -> (N h) l (L+l) + + # preallocating input tensor: (N h) L L + matmul_input_buffer = torch.empty([N * h, l, k_L], dtype=hidden_states.dtype, device=hidden_states.device) + # L (N h) d, L (N h) d -> (N h) L L + attn = torch.baddbmm( + matmul_input_buffer, + q.transpose(0, 1), # (N h) l d + k.transpose(0, 1).transpose(1, 2), # (N h) d (L+l) + beta=0.0, alpha=scale + ) attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) l (L+l) -> (N h) l (L+l) #no dropout in inference attn - torch.nn.functional.dropout(attn, dropout_p, is_training, False) # (N h) l (L+l) -> (N h) l (L+l) + attn = torch.nn.functional.dropout(attn, dropout_p, is_training, False) # (N h) l (L+l) -> (N h) l (L+l) + v = v.transpose(0, 1) output = torch.bmm(attn, v) # (N h) l (L+l), (N h) (L+l) d -> (N h) l d output = output.transpose(0, 1).contiguous() # (N h) l d -> l (N h) d output = output.view(l, N, num_head * dim_head) # l (N h) d -> l N (h d) @@ -215,15 +232,18 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa self.head_dim = inner_dim // num_heads self.scaling = self.head_dim ** -0.5 self.dropout_p = dropout - # Q - self.q_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.rand(inner_dim)) - # K - self.k_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.rand(inner_dim)) - # V - self.v_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.rand(inner_dim)) + # # Q + # self.q_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) + # self.q_bias = torch.nn.Parameter(torch.rand(inner_dim)) + # # K + # self.k_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) + # self.k_bias = torch.nn.Parameter(torch.rand(inner_dim)) + # # V + # self.v_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) + # self.v_bias = torch.nn.Parameter(torch.rand(inner_dim)) + # QKV + self.qkv_proj = torch.nn.Parameter(torch.rand(3 * inner_dim, embed_dim)) + self.qkv_bias = torch.nn.Parameter(torch.rand(3 * inner_dim)) # Out self.out_proj = torch.nn.Parameter(torch.rand(embed_dim, inner_dim)) self.out_bias = torch.nn.Parameter(torch.rand(embed_dim)) @@ -232,9 +252,10 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa def forward(self, query: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor): attn = one_attention( query, past_embed_key, past_embed_value, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, + # self.q_proj, self.q_bias, + # self.k_proj, self.k_bias, + # self.v_proj, self.v_bias, + self.qkv_proj, self.qkv_bias, self.out_proj, #self.out_bias, self.num_heads, self.scaling, self.dropout_p, self.training, mask=True ) diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 6ce484da..1fa1b697 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -36,6 +36,7 @@ class EncoderInferLayer(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, attn_hidden_dim: int, ffn_hidden_dim: int, seqlen: int = -1, + batch_size: int = 1, dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): super().__init__() self.self_attn_partial = MultiHeadOneAttention( @@ -47,7 +48,7 @@ def __init__(self, embed_dim: int, num_heads: int, self.final_layer_norm = torch.nn.LayerNorm(embed_dim) # id-embed + pos-embed - tmp_batch_size = 1 + tmp_batch_size = batch_size self.past_embed_key = torch.nn.Parameter(torch.rand(seqlen, tmp_batch_size, embed_dim)) self.past_embed_value = torch.nn.Parameter(torch.rand(seqlen, tmp_batch_size, embed_dim)) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index b3bc2d2f..a8df651d 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -108,7 +108,7 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): class GPTInfer(torch.nn.Module): - def __init__(self): + def __init__(self, batch_size: int = 1): super().__init__() cfg = Config() @@ -121,6 +121,7 @@ def __init__(self): [EncoderInferLayer( cfg.embed_dim, cfg.attention_heads, cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.seqlen, + batch_size, cfg.dropout, cfg.attn_dropout, cfg.activation_dropout ) for _ in range(cfg.layers)] ) @@ -205,12 +206,12 @@ def random_sample(self): 0, self.cfg.num_embeddings, size=(self.bs, 1), dtype=torch.int64, - # device=torch.cuda.current_device() + device=torch.cuda.current_device() ) position_ids = torch.arange( 0, 1, dtype=torch.int64, - # device=torch.cuda.current_device() - ).repeat(self.bs) + device=torch.cuda.current_device() + ).repeat(self.bs).view(self.bs, -1) return (input_ids, position_ids) def __iter__(self): From 16ef67b504e301432fb1640164aa74b7c4420349 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Fri, 30 Sep 2022 01:26:45 -0700 Subject: [PATCH 1050/1892] DAP runnbale, need further optimization --- examples/alphafold2/alphafold2.py | 10 +++-- examples/alphafold2/policy/spmd.py | 59 +++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 4803ff65..2c100ca6 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -1,3 +1,4 @@ +from audioop import mul import torch import math import cube @@ -380,6 +381,7 @@ def forward(self, msa_repr, pair_repr): msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), self.msa_transition_proj1, self.msa_transition_proj2) + succ_msa_repr, msa_repr = multi2ref(msa_repr) pair_msa_repr, dummy_pair_msa_repr = multi2ref( self.outer_norm(msa_repr)) @@ -388,7 +390,7 @@ def forward(self, msa_repr, pair_repr): self.outer_proj2, self.outer_out_proj) out_pair_repr, out_dummy_pair_repr = multi2ref(pair_repr) - pair_repr = pair_repr + TriangleMultiplicationOut( + pair_repr = out_pair_repr + TriangleMultiplicationOut( out_pair_repr, out_dummy_pair_repr, self.tri_mul_out_norm1_weight, self.tri_mul_out_norm1_bias, self.tri_mul_out_norm2_weight, self.tri_mul_out_norm2_bias, self.tri_mul_out_proj1, @@ -397,7 +399,7 @@ def forward(self, msa_repr, pair_repr): self.tri_mul_out_proj6, self.cz) in_pair_repr, in_dummy_pair_repr = multi2ref(pair_repr) - pair_repr = pair_repr + TriangleMultiplicationIn( + pair_repr = in_pair_repr + TriangleMultiplicationIn( in_pair_repr, in_dummy_pair_repr, self.tri_mul_in_norm1_weight, self.tri_mul_in_norm1_bias, self.tri_mul_in_norm2_weight, self.tri_mul_in_norm2_bias, self.tri_mul_in_proj1, @@ -419,7 +421,7 @@ def forward(self, msa_repr, pair_repr): self.pair_transition_norm(pair_repr), self.pair_transition_proj1, self.pair_transition_proj2) - return (msa_repr, pair_repr) + return (succ_msa_repr, pair_repr) class AlphaFold2(nn.Module): @@ -436,7 +438,7 @@ def forward(self, msa, pair): def test(): - bs, s, r, cm, cz = 1, 128, 384, 256, 128 + bs, s, r, cm, cz = 1, 128, 256, 256, 128 model = AlphaFold2(s, cm, cz, 1) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index b6fba9d4..f3866f30 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -78,30 +78,79 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): graph.assign(sub_node, dev_id) return sub_nodes + pred_name = '' for node in graph.nodes(): if isinstance(node, IRDataOperation): + # _tp(graph, node, tp_devs, 0, 1) _replica(graph, node, tp_devs) - - for node in graph.nodes(): - if isinstance(node, IRFwOperation): + elif isinstance(node, IRFwOperation): if node.name == 'MSARowAttentionWithPairBias': _tp(graph, node, tp_devs, 0, 1) + pred_name = node.name elif node.name == 'MSAColAttention': _tp(graph, node, tp_devs, 0, 2) + pred_name = node.name elif node.name == 'MSATransition': - _tp(graph, node, tp_devs, 0, 1) + _tp(graph, node, tp_devs, 0, 2) + pred_name = node.name elif node.name == 'OuterProductMean': _tp(graph, node, tp_devs, 0, 2) + pred_name = node.name elif node.name == 'TriangleMultiplicationOut': _tp(graph, node, tp_devs, 0, 1) + pred_name = node.name elif node.name == 'TriangleMultiplicationIn': _tp(graph, node, tp_devs, 0, 2) + pred_name = node.name elif node.name == 'TriangleAttentionNodeStart': _tp(graph, node, tp_devs, 0, 1) + pred_name = node.name elif node.name == 'TriangleAttentionNodeEnd': _tp(graph, node, tp_devs, 0, 2) + pred_name = node.name elif node.name == 'PairTransition': _tp(graph, node, tp_devs, 0, 1) + pred_name = node.name else: - _replica(graph, node, tp_devs) + if node.name == 'add': + if pred_name == 'PairTransition': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'TriangleAttentionNodeEnd': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'TriangleAttentionNodeStart': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'TriangleMultiplicationIn': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'TriangleMultiplicationOut': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'OuterProductMean': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'MSATransition': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'MSAColAttention': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'MSARowAttentionWithPairBias': + _tp(graph, node, tp_devs, 0, 1) + else: + assert False + elif node.name == 'layernorm': + if pred_name == 'TriangleAttentionNodeEnd': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'TriangleAttentionNodeStart': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'TriangleMultiplicationIn': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'MSATransition': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'MSAColAttention': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'MSARowAttentionWithPairBias': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == '': + _tp(graph, node, tp_devs, 0, 1) + else: + assert False + else: + print('replica node:', node.name) + _replica(graph, node, tp_devs) return graph From 41cae9a3e209b120c430af6277a8e2eea032bb66 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Fri, 30 Sep 2022 02:59:14 -0700 Subject: [PATCH 1051/1892] add util --- examples/alphafold2/alphafold2.py | 4 +++- examples/alphafold2/policy/spmd.py | 10 +++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 2c100ca6..d892af90 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -438,7 +438,8 @@ def forward(self, msa, pair): def test(): - bs, s, r, cm, cz = 1, 128, 256, 256, 128 + # bs, s, r, cm, cz = 1, 128, 256, 256, 128 + bs, s, r, cm, cz = 1, 512, 384, 256, 128 model = AlphaFold2(s, cm, cz, 1) @@ -467,6 +468,7 @@ def test(): 0, )) + # @cube.compile(model, dataloader, PAS=spmd.PASSingle) # @cube.compile(model, dataloader, PAS=spmd.PASData) @cube.compile(model, dataloader, PAS=spmd.PASMegatron, override=True) def train_iter(model, dataloader): diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index f3866f30..ad7080b0 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -2,7 +2,7 @@ from numpy import TooHardError from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation recompute_info = { 'MSAAttention': True, @@ -23,6 +23,14 @@ } +def PASSingle(graph: IRGraph, resource): + assert resource.ngpus == 1 + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + return graph + + def PASData(graph: IRGraph, resource): devs = list(range(resource.ngpus)) From 6d8def18cff0a7686f96f0fc02109205ff448412 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Sat, 1 Oct 2022 02:33:09 -0700 Subject: [PATCH 1052/1892] runnbale, collect fp16 numbers --- examples/alphafold2/alphafold2.py | 195 ++++++++++++++++------------- examples/alphafold2/policy/spmd.py | 16 ++- 2 files changed, 123 insertions(+), 88 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index d892af90..5d70700c 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -129,22 +129,30 @@ def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, torch.nn.functional.relu(torch.matmul(msa_repr, proj1)), proj2) +@cube.graph.parser.register('N S R M, M C -> N S R C', name='OPMLeftProj') +def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): + return torch.matmul(msa_repr, proj) + + +@cube.graph.parser.register('N S R M, M C -> N S R C', name='OPMRightProj') +def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): + return torch.matmul(msa_repr, proj) + + """ [bs, s, r, cm] -> [bs, r, r, cz] """ -@cube.graph.parser.register('N S R M, N S T M, M C, M C, F Z -> N R T Z', +@cube.graph.parser.register('N S R M, N S T M, F Z -> N R T Z', name='OuterProductMean') -def OuterProductMean(msa_repr: torch.Tensor, dummy_msa_repr: torch.Tensor, - proj1: torch.Tensor, proj2: torch.Tensor, +def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, out_proj: torch.Tensor): - bs, s, r, cm = msa_repr.size() - t = dummy_msa_repr.size(2) - c = proj1.size(-1) + bs, s, r, c = left_act.size() + t = right_act.size(2) - a = torch.matmul(msa_repr, proj1).transpose(-2, -3) - b = torch.matmul(dummy_msa_repr, proj2).transpose(-2, -3) + a = left_act.transpose(-2, -3) + b = right_act.transpose(-2, -3) outer = torch.einsum('...bac,...dae->...bdce', a, b).reshape(bs, r, t, c * c) @@ -152,27 +160,35 @@ def OuterProductMean(msa_repr: torch.Tensor, dummy_msa_repr: torch.Tensor, return outer -@cube.graph.parser.register( - 'N S R Z, N T R Z, Z, Z, E, E, Z E, Z E, Z E, Z E, E Z, Z Z -> N S T Z', - name='TriangleMultiplicationOut') -def TriangleMultiplicationOut( - pair_repr: torch.Tensor, dummy_pair_repr: torch.Tensor, - tri_mul_norm1_weight: torch.Tensor, tri_mul_norm1_bias: torch.Tensor, - tri_mul_norm2_weight: torch.Tensor, tri_mul_norm2_bias: torch.Tensor, - tri_mul_proj1: torch.Tensor, tri_mul_proj2: torch.Tensor, - tri_mul_proj3: torch.Tensor, tri_mul_proj4: torch.Tensor, - tri_mul_proj5: torch.Tensor, tri_mul_proj6: torch.Tensor, cz: int): - pair_repr = torch.nn.functional.layer_norm(pair_repr, (cz, ), - tri_mul_norm1_weight, - tri_mul_norm1_bias) - dummy_pair_repr = torch.nn.functional.layer_norm(dummy_pair_repr, (cz, ), - tri_mul_norm1_weight, - tri_mul_norm1_bias) - a = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj1)) - a = a * torch.matmul(pair_repr, tri_mul_proj2) - b = torch.sigmoid(torch.matmul(dummy_pair_repr, tri_mul_proj3)) - b = b * torch.matmul(dummy_pair_repr, tri_mul_proj4) +@cube.graph.parser.register('N S R Z, Z E, Z E -> N S R E', name='TMOLeftProj') +def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + a = torch.sigmoid(torch.matmul(pair_repr, proj1)) + a = a * torch.matmul(pair_repr, proj2) + return a + + +@cube.graph.parser.register('N S R Z, Z E, Z E -> N S R E', + name='TMORightProj') +def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + a = torch.sigmoid(torch.matmul(pair_repr, proj1)) + a = a * torch.matmul(pair_repr, proj2) + return a + + +@cube.graph.parser.register('N S T Z, Z Z -> N S T Z', name='TMOGate') +def TMOGate(pair_repr: torch.Tensor, proj: torch.Tensor): + return torch.sigmoid(torch.matmul(pair_repr, proj)) + +@cube.graph.parser.register('N S R E, N T R E, N S T Z, E, E, E Z -> N S T Z', + name='TriangleMultiplicationOut') +def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, + g: torch.Tensor, + tri_mul_norm2_weight: torch.Tensor, + tri_mul_norm2_bias: torch.Tensor, + tri_mul_proj5: torch.Tensor, cz: int): a = a.permute(0, 3, 1, 2) b = b.permute(0, 3, 2, 1) @@ -180,31 +196,37 @@ def TriangleMultiplicationOut( p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, tri_mul_norm2_bias) p = torch.matmul(p, tri_mul_proj5) - g = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj6)) return p * g -@cube.graph.parser.register( - 'N R S Z, N R T Z, Z, Z, E, E, Z E, Z E, Z E, Z E, E Z, Z Z -> N S T Z', - name='TriangleMultiplicationIn') -def TriangleMultiplicationIn( - pair_repr: torch.Tensor, dummy_pair_repr: torch.Tensor, - tri_mul_norm1_weight: torch.Tensor, tri_mul_norm1_bias: torch.Tensor, - tri_mul_norm2_weight: torch.Tensor, tri_mul_norm2_bias: torch.Tensor, - tri_mul_proj1: torch.Tensor, tri_mul_proj2: torch.Tensor, - tri_mul_proj3: torch.Tensor, tri_mul_proj4: torch.Tensor, - tri_mul_proj5: torch.Tensor, tri_mul_proj6: torch.Tensor, cz: int): - pair_repr = torch.nn.functional.layer_norm(pair_repr, (cz, ), - tri_mul_norm1_weight, - tri_mul_norm1_bias) - dummy_pair_repr = torch.nn.functional.layer_norm(dummy_pair_repr, (cz, ), - tri_mul_norm1_weight, - tri_mul_norm1_bias) - a = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj1)) - a = a * torch.matmul(pair_repr, tri_mul_proj2) - b = torch.sigmoid(torch.matmul(dummy_pair_repr, tri_mul_proj3)) - b = b * torch.matmul(dummy_pair_repr, tri_mul_proj4) +@cube.graph.parser.register('N R S Z, Z E, Z E -> N R S E', name='TMILeftProj') +def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + a = torch.sigmoid(torch.matmul(pair_repr, proj1)) + a = a * torch.matmul(pair_repr, proj2) + return a + + +@cube.graph.parser.register('N R T Z, Z E, Z E -> N R T E', + name='TMIRightProj') +def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + a = torch.sigmoid(torch.matmul(pair_repr, proj1)) + a = a * torch.matmul(pair_repr, proj2) + return a + +@cube.graph.parser.register('N S T Z, Z Z -> N S T Z', name='TMIGate') +def TMIGate(pair_repr: torch.Tensor, proj: torch.Tensor): + return torch.sigmoid(torch.matmul(pair_repr, proj)) + + +@cube.graph.parser.register('N R S E, N R T E, N T S Z, E, E, E Z -> N T S Z', + name='TriangleMultiplicationIn') +def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, + tri_mul_norm2_weight: torch.Tensor, + tri_mul_norm2_bias: torch.Tensor, + tri_mul_proj5: torch.Tensor, cz: int): a = a.permute(0, 3, 2, 1) b = b.permute(0, 3, 1, 2) @@ -212,8 +234,7 @@ def TriangleMultiplicationIn( p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, tri_mul_norm2_bias) p = torch.matmul(p, tri_mul_proj5) - g = torch.sigmoid(torch.matmul(pair_repr, tri_mul_proj6)).transpose(1, 2) - return p * g + return p.permute(0, 2, 1, 3) * g @cube.graph.parser.register('N S R Z, Z E, Z F, E Z, Z G -> N S R Z', @@ -306,8 +327,7 @@ def __init__(self, self.outer_out_proj = torch.nn.Parameter(torch.randn(c * c, cz)) # Triangular multiplicative update using outgoing edges - self.tri_mul_out_norm1_weight = torch.nn.Parameter(torch.empty(cz)) - self.tri_mul_out_norm1_bias = torch.nn.Parameter(torch.empty(cz)) + self.tri_mul_out_norm1 = torch.nn.LayerNorm(cz) self.tri_mul_out_norm2_weight = torch.nn.Parameter( torch.empty(c_tri_mult)) self.tri_mul_out_norm2_bias = torch.nn.Parameter( @@ -325,8 +345,7 @@ def __init__(self, self.tri_mul_out_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) # Triangular multiplicative update using incoming edges - self.tri_mul_in_norm1_weight = torch.nn.Parameter(torch.empty(cz)) - self.tri_mul_in_norm1_bias = torch.nn.Parameter(torch.empty(cz)) + self.tri_mul_in_norm1 = torch.nn.LayerNorm(cz) self.tri_mul_in_norm2_weight = torch.nn.Parameter( torch.empty(c_tri_mult)) self.tri_mul_in_norm2_bias = torch.nn.Parameter( @@ -383,29 +402,33 @@ def forward(self, msa_repr, pair_repr): self.msa_transition_proj2) succ_msa_repr, msa_repr = multi2ref(msa_repr) - pair_msa_repr, dummy_pair_msa_repr = multi2ref( - self.outer_norm(msa_repr)) - pair_repr = pair_repr + OuterProductMean( - pair_msa_repr, dummy_pair_msa_repr, self.outer_proj1, - self.outer_proj2, self.outer_out_proj) - - out_pair_repr, out_dummy_pair_repr = multi2ref(pair_repr) - pair_repr = out_pair_repr + TriangleMultiplicationOut( - out_pair_repr, out_dummy_pair_repr, self.tri_mul_out_norm1_weight, - self.tri_mul_out_norm1_bias, self.tri_mul_out_norm2_weight, - self.tri_mul_out_norm2_bias, self.tri_mul_out_proj1, - self.tri_mul_out_proj2, self.tri_mul_out_proj3, - self.tri_mul_out_proj4, self.tri_mul_out_proj5, - self.tri_mul_out_proj6, self.cz) - - in_pair_repr, in_dummy_pair_repr = multi2ref(pair_repr) - pair_repr = in_pair_repr + TriangleMultiplicationIn( - in_pair_repr, in_dummy_pair_repr, self.tri_mul_in_norm1_weight, - self.tri_mul_in_norm1_bias, self.tri_mul_in_norm2_weight, - self.tri_mul_in_norm2_bias, self.tri_mul_in_proj1, - self.tri_mul_in_proj2, self.tri_mul_in_proj3, - self.tri_mul_in_proj4, self.tri_mul_in_proj5, - self.tri_mul_in_proj6, self.cz) + msa_repr = self.outer_norm(msa_repr) + opm_left, opm_right = OPMLeftProj(msa_repr, + self.outer_proj1), OPMRightProj( + msa_repr, self.outer_proj2) + pair_repr = pair_repr + OuterProductMean(opm_left, opm_right, + self.outer_out_proj) + + pair_repr = self.tri_mul_out_norm1(pair_repr) + tmo_left, tmo_right = TMOLeftProj( + pair_repr, self.tri_mul_out_proj1, + self.tri_mul_out_proj2), TMORightProj(pair_repr, + self.tri_mul_out_proj3, + self.tri_mul_out_proj4) + tmo_g = TMOGate(pair_repr, self.tri_mul_out_proj6) + pair_repr = pair_repr + TriangleMultiplicationOut( + tmo_left, tmo_right, tmo_g, self.tri_mul_out_norm2_weight, + self.tri_mul_out_norm2_bias, self.tri_mul_out_proj5, self.cz) + + pair_repr = self.tri_mul_in_norm1(pair_repr) + tmi_left = TMILeftProj(pair_repr, self.tri_mul_in_proj1, + self.tri_mul_in_proj2) + tmi_right = TMIRightProj(pair_repr, self.tri_mul_in_proj3, + self.tri_mul_in_proj4) + tmi_gate = TMIGate(pair_repr, self.tri_mul_in_proj6) + pair_repr = pair_repr + TriangleMultiplicationIn( + tmi_left, tmi_right, tmi_gate, self.tri_mul_in_norm2_weight, + self.tri_mul_in_norm2_bias, self.tri_mul_in_proj5, self.cz) pair_repr = pair_repr + TriangleAttentionNodeStart( self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, @@ -438,10 +461,12 @@ def forward(self, msa, pair): def test(): - # bs, s, r, cm, cz = 1, 128, 256, 256, 128 - bs, s, r, cm, cz = 1, 512, 384, 256, 128 + bs, s, r, cm, cz = 1, 128, 256, 256, 128 + # bs, s, r, cm, cz = 1, 512, 384, 256, 128 + + dtype = torch.float16 - model = AlphaFold2(s, cm, cz, 1) + model = AlphaFold2(s, cm, cz, 1).to(dtype) # msa_repr, pair_repr = torch.randn(bs, s, r, cm), torch.randn(bs, r, r, cz) # x = model(msa_repr, pair_repr) @@ -460,8 +485,8 @@ def test(): [bs, r, r, cz], ), dtypes=( - torch.float32, - torch.float32, + dtype, + dtype, ), batch_dims=( 0, @@ -469,8 +494,8 @@ def test(): )) # @cube.compile(model, dataloader, PAS=spmd.PASSingle) - # @cube.compile(model, dataloader, PAS=spmd.PASData) - @cube.compile(model, dataloader, PAS=spmd.PASMegatron, override=True) + # @cube.compile(model, dataloader, PAS=spmd.PASMegatron, override=True) + @cube.compile(model, dataloader, PAS=spmd.PASData) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index ad7080b0..c0c6401e 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -101,13 +101,19 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): elif node.name == 'MSATransition': _tp(graph, node, tp_devs, 0, 2) pred_name = node.name + elif node.name == 'OPMLeftProj' or node.name == 'OPMRightProj': + _tp(graph, node, tp_devs, 0, 2) + pred_name = node.name elif node.name == 'OuterProductMean': _tp(graph, node, tp_devs, 0, 2) pred_name = node.name - elif node.name == 'TriangleMultiplicationOut': + elif node.name == 'TMOLeftProj' or node.name == 'TMORightProj' or node.name == 'TMOGate' or node.name == 'TriangleMultiplicationOut': _tp(graph, node, tp_devs, 0, 1) pred_name = node.name - elif node.name == 'TriangleMultiplicationIn': + elif node.name in { + 'TMILeftProj', 'TMIRightProj', 'TMIGate', + 'TriangleMultiplicationIn' + }: _tp(graph, node, tp_devs, 0, 2) pred_name = node.name elif node.name == 'TriangleAttentionNodeStart': @@ -148,8 +154,12 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): _tp(graph, node, tp_devs, 0, 2) elif pred_name == 'TriangleMultiplicationIn': _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'TriangleMultiplicationOut': + _tp(graph, node, tp_devs, 0, 2) elif pred_name == 'MSATransition': _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'OuterProductMean': + _tp(graph, node, tp_devs, 0, 1) elif pred_name == 'MSAColAttention': _tp(graph, node, tp_devs, 0, 2) elif pred_name == 'MSARowAttentionWithPairBias': @@ -157,7 +167,7 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): elif pred_name == '': _tp(graph, node, tp_devs, 0, 1) else: - assert False + assert False, pred_name else: print('replica node:', node.name) _replica(graph, node, tp_devs) From 74d023c205221e29cb6d3aa2eeee57df24770443 Mon Sep 17 00:00:00 2001 From: yizhu1 Date: Sun, 2 Oct 2022 18:19:48 -0700 Subject: [PATCH 1053/1892] refine code: add recompute --- examples/alphafold2/alphafold2.py | 10 +- examples/alphafold2/policy/spmd.py | 170 +++++++++++++++-------------- 2 files changed, 94 insertions(+), 86 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 5d70700c..1218f096 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -461,8 +461,8 @@ def forward(self, msa, pair): def test(): - bs, s, r, cm, cz = 1, 128, 256, 256, 128 - # bs, s, r, cm, cz = 1, 512, 384, 256, 128 + # bs, s, r, cm, cz = 1, 128, 256, 256, 128 + bs, s, r, cm, cz = 1, 512, 384, 256, 128 dtype = torch.float16 @@ -493,9 +493,9 @@ def test(): 0, )) - # @cube.compile(model, dataloader, PAS=spmd.PASSingle) - # @cube.compile(model, dataloader, PAS=spmd.PASMegatron, override=True) - @cube.compile(model, dataloader, PAS=spmd.PASData) + # @cube.compile(model, dataloader, PAS=spmd.PASData) + # @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) + @cube.compile(model, dataloader, PAS=spmd.PASSingle) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index c0c6401e..a1436729 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -5,13 +5,19 @@ from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation recompute_info = { - 'MSAAttention': True, - 'MSAAttentionWithBias': True, 'MSARowAttentionWithPairBias': True, 'MSAColAttention': True, 'MSATransition': True, + 'OPMLeftProj': False, + 'OPMRightProj': False, 'OuterProductMean': True, + 'TMOLeftProj': True, + 'TMORightProj': True, + 'TMOGate': True, 'TriangleMultiplicationOut': True, + 'TMILeftProj': True, + 'TMIRightProj': True, + 'TMIGate': True, 'TriangleMultiplicationIn': True, 'TriangleAttentionNodeStart': True, 'TriangleAttentionNodeEnd': True, @@ -19,7 +25,7 @@ 'add': False, 'sum': False, 'layernorm': False, - 'transpose': False, + 'multi2ref': False, } @@ -28,6 +34,9 @@ def PASSingle(graph: IRGraph, resource): for node in graph.nodes(): if not isinstance(node, IRBpOperation): graph.assign(node, 0) + if node.name in recompute_info and recompute_info[ + node.name] == True: + graph.recompute([node]) return graph @@ -63,7 +72,7 @@ def PASData(graph: IRGraph, resource): return graph -def PASMegatron(graph: IRGraph, resource): +def PASDAP(graph: IRGraph, resource): tp_size = resource.ngpus tp_devs = list(range(tp_size)) @@ -92,83 +101,82 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): # _tp(graph, node, tp_devs, 0, 1) _replica(graph, node, tp_devs) elif isinstance(node, IRFwOperation): - if node.name == 'MSARowAttentionWithPairBias': - _tp(graph, node, tp_devs, 0, 1) - pred_name = node.name - elif node.name == 'MSAColAttention': - _tp(graph, node, tp_devs, 0, 2) - pred_name = node.name - elif node.name == 'MSATransition': - _tp(graph, node, tp_devs, 0, 2) - pred_name = node.name - elif node.name == 'OPMLeftProj' or node.name == 'OPMRightProj': - _tp(graph, node, tp_devs, 0, 2) - pred_name = node.name - elif node.name == 'OuterProductMean': - _tp(graph, node, tp_devs, 0, 2) - pred_name = node.name - elif node.name == 'TMOLeftProj' or node.name == 'TMORightProj' or node.name == 'TMOGate' or node.name == 'TriangleMultiplicationOut': - _tp(graph, node, tp_devs, 0, 1) - pred_name = node.name - elif node.name in { - 'TMILeftProj', 'TMIRightProj', 'TMIGate', - 'TriangleMultiplicationIn' - }: - _tp(graph, node, tp_devs, 0, 2) - pred_name = node.name - elif node.name == 'TriangleAttentionNodeStart': - _tp(graph, node, tp_devs, 0, 1) - pred_name = node.name - elif node.name == 'TriangleAttentionNodeEnd': - _tp(graph, node, tp_devs, 0, 2) - pred_name = node.name - elif node.name == 'PairTransition': - _tp(graph, node, tp_devs, 0, 1) - pred_name = node.name + if node.name == 'add': + if pred_name == 'PairTransition': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'TriangleAttentionNodeEnd': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'TriangleAttentionNodeStart': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'TriangleMultiplicationIn': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'TriangleMultiplicationOut': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'OuterProductMean': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'MSATransition': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'MSAColAttention': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'MSARowAttentionWithPairBias': + _tp(graph, node, tp_devs, 0, 1) + else: + assert False, pred_name + elif node.name == 'layernorm': + if pred_name == 'TriangleAttentionNodeEnd': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'TriangleAttentionNodeStart': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'TriangleMultiplicationIn': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'TriangleMultiplicationOut': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'MSATransition': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'OuterProductMean': + _tp(graph, node, tp_devs, 0, 1) + elif pred_name == 'MSAColAttention': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == 'MSARowAttentionWithPairBias': + _tp(graph, node, tp_devs, 0, 2) + elif pred_name == '': + _tp(graph, node, tp_devs, 0, 1) + else: + assert False, pred_name + elif node.name in {'sum', 'mul', 'multi2ref'}: + _replica(graph, node, tp_devs) else: - if node.name == 'add': - if pred_name == 'PairTransition': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'TriangleAttentionNodeEnd': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'TriangleAttentionNodeStart': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'TriangleMultiplicationIn': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'TriangleMultiplicationOut': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'OuterProductMean': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'MSATransition': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'MSAColAttention': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'MSARowAttentionWithPairBias': - _tp(graph, node, tp_devs, 0, 1) - else: - assert False - elif node.name == 'layernorm': - if pred_name == 'TriangleAttentionNodeEnd': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'TriangleAttentionNodeStart': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'TriangleMultiplicationIn': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'TriangleMultiplicationOut': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'MSATransition': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'OuterProductMean': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'MSAColAttention': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'MSARowAttentionWithPairBias': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == '': - _tp(graph, node, tp_devs, 0, 1) - else: - assert False, pred_name + pred_name = node.name + if node.name == 'MSARowAttentionWithPairBias': + sub_nodes = _tp(graph, node, tp_devs, 0, 1) + elif node.name == 'MSAColAttention': + sub_nodes = _tp(graph, node, tp_devs, 0, 2) + elif node.name == 'MSATransition': + sub_nodes = _tp(graph, node, tp_devs, 0, 2) + elif node.name in { + 'OPMLeftProj', 'OPMRightProj', 'OuterProductMean' + }: + sub_nodes = _tp(graph, node, tp_devs, 0, 2) + elif node.name in { + 'TMOLeftProj', 'TMORightProj', 'TMOGate', + 'TriangleMultiplicationOut' + }: + sub_nodes = _tp(graph, node, tp_devs, 0, 1) + elif node.name in { + 'TMILeftProj', 'TMIRightProj', 'TMIGate', + 'TriangleMultiplicationIn' + }: + sub_nodes = _tp(graph, node, tp_devs, 0, 2) + elif node.name == 'TriangleAttentionNodeStart': + sub_nodes = _tp(graph, node, tp_devs, 0, 1) + elif node.name == 'TriangleAttentionNodeEnd': + sub_nodes = _tp(graph, node, tp_devs, 0, 2) + elif node.name == 'PairTransition': + sub_nodes = _tp(graph, node, tp_devs, 0, 1) else: - print('replica node:', node.name) - _replica(graph, node, tp_devs) + assert False, node.name + + if node.name in recompute_info and recompute_info[ + node.name] == True: + graph.recompute(sub_nodes) return graph From 8a51dce1b8b20e8c1a4bc59c00fde3d87a3f3ce6 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 10 Oct 2022 16:33:23 +0800 Subject: [PATCH 1054/1892] save work, not fully runnable --- examples/alphafold2/alphafold2.py | 184 +++++++++++++++++++++-------- examples/alphafold2/policy/spmd.py | 12 ++ 2 files changed, 146 insertions(+), 50 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 1218f096..8eeb2535 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -1,16 +1,31 @@ from audioop import mul import torch +import torch.utils.checkpoint as ckpt import math import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank from torch import nn -from torch.utils.checkpoint import checkpoint import examples.alphafold2.policy.spmd as spmd cube.init() + + +@cube.graph.parser.register('TODO', name='calc_qkvg') +def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, + bs: int, s: int, r: int, head: int, c: int): + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + + gate = gate.reshape(bs, s, r, head, c).transpose(2, 3) + q = q.reshape(bs, s, r, head, c).transpose(2, 3) + k = k.reshape(bs, s, r, head, c).transpose(2, 3) + v = v.reshape(bs, s, r, head, c).transpose(2, 3) + return q, k, v, gate + + """ [bs, s, r, cm] -> [bs, s, r, cm] @@ -20,66 +35,118 @@ @cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R M', name='MSAAttention') +@torch.jit.ignore def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, c: int, scale: float): bs, s, r, cm = x.size() - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - - gate = gate.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - q = q.reshape(bs, s, r, head, c).transpose(2, - 3).reshape(bs * s * head, r, c) - k = k.reshape(bs, s, r, head, c).transpose(2, - 3).reshape(bs * s * head, r, - c).transpose(1, 2) - v = v.reshape(bs, s, r, head, c).transpose(2, - 3).reshape(bs * s * head, r, c) - - sim = torch.bmm(q, k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - attend = torch.bmm(sim, v) * gate - - out = attend.reshape(bs, s, head, r, c).transpose(2, - 3).reshape(bs, s, r, cm) - out = torch.matmul(out, out_proj) + q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, + gate_proj, bs, s, r, + head, c) + + chunk_size = 1 + + assert s % chunk_size == 0 + out_chunks = [] + + def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Tensor, start: int): + cur_q = q[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_k = k[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c).transpose(1, 2) + cur_v = v[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + sim = torch.bmm(cur_q, cur_k) * 0.125 + sim = torch.nn.functional.softmax(sim, dim=-1) + attend = torch.bmm(sim, cur_v) * cur_gate + attend = attend.reshape(bs, chunk_size, head, r, c).transpose(2, 3).reshape(bs, chunk_size, r, cm) + return attend + + for start in range(0, s, chunk_size): + # attend = ckpt.checkpoint(attention, q, k, v, gate, start) + attend = attention(q, k, v, gate, start) + out_chunks.append(attend) + + out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) return out @cube.graph.parser.register('N S R M, M E, M F, E M, N 1 8 R R -> N S R M', name='MSAAttentionWithBias') +@torch.jit.ignore def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float): bs, s, r, cm = x.size() - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - - gate = gate.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - q = q.reshape(bs, s, r, head, c).transpose(2, - 3).reshape(bs * s * head, r, c) - k = k.reshape(bs, s, r, head, c).transpose(2, - 3).reshape(bs * s * head, r, - c).transpose(1, 2) - v = v.reshape(bs, s, r, head, c).transpose(2, - 3).reshape(bs * s * head, r, c) - - sim = torch.bmm(q, k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - - sim = sim.reshape(bs, s, head, r, r) + bias - sim = sim.reshape(bs * s * head, r, r) - - attend = torch.bmm(sim, v) * gate - - out = attend.reshape(bs, s, head, r, c).transpose(2, - 3).reshape(bs, s, r, cm) - out = torch.matmul(out, out_proj) + chunk_size = 1 + + if chunk_size == -1: + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + + gate = gate.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + q = q.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + k = k.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, + c).transpose(1, 2) + v = v.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + + sim = torch.bmm(q, k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + + sim = sim.reshape(bs, s, head, r, r) + bias + sim = sim.reshape(bs * s * head, r, r) + + attend = torch.bmm(sim, v) * gate + + out = attend.reshape(bs, s, head, r, + c).transpose(2, 3).reshape(bs, s, r, cm) + out = torch.matmul(out, out_proj) + else: + q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, + gate_proj, bs, s, r, + head, c) + + assert s % chunk_size == 0 + out_chunks = [] + def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Tensor, + bias: torch.Tensor, start: int): + cur_q = q[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_k = k[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c).transpose(1, 2) + cur_v = v[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + + sim = torch.bmm(cur_q, cur_k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + sim = sim.reshape(bs, chunk_size, head, r, r) + bias + sim = sim.reshape(bs * chunk_size * head, r, r) + + attend = torch.bmm(sim, cur_v) * cur_gate + attend = attend.reshape(bs, chunk_size, head, r, c).transpose(2, 3).reshape(bs, chunk_size, r, cm) + return attend + + for start in range(0, s, chunk_size): + # attend = ckpt.checkpoint(attention_bias, q, k, v, gate, + # bias, + # start) + attend = attention_bias(q, k, v, gate, + bias, + start) + + out_chunks.append(attend) + out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) return out @@ -138,6 +205,11 @@ def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): return torch.matmul(msa_repr, proj) +@cube.graph.parser.register('TODO', name='opm') +def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int, bs: int, chunk_size: int, t: int, c: int): + lhs_slice = lhs[:, start:start+chunk_size, :, :] + out = torch.einsum('...bac,...dae->...bdce', lhs_slice, rhs).reshape(bs, chunk_size, t, c * c) + return out """ [bs, s, r, cm] -> [bs, r, r, cz] @@ -154,9 +226,17 @@ def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, a = left_act.transpose(-2, -3) b = right_act.transpose(-2, -3) - outer = torch.einsum('...bac,...dae->...bdce', a, + chunk_size = 1 + + if chunk_size == -1: + outer = torch.einsum('...bac,...dae->...bdce', a, b).reshape(bs, r, t, c * c) - outer = torch.matmul(outer, out_proj) + outer = torch.matmul(outer, out_proj) + else: + out_chunks = [] + for start in range(0, r, chunk_size): + out_chunks.append(opm(a, b, start, bs, chunk_size, t, c)) + outer = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) return outer @@ -243,7 +323,7 @@ def TriangleAttentionNodeStart(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias_proj: torch.Tensor, head: int, c: int, scale: float): - bias = torch.matmul(pair_repr, bias_proj).permute(0, 1, 3, 2).unsqueeze(3) + bias = torch.matmul(pair_repr, bias_proj).permute(0, 3, 1, 2).unsqueeze(1) return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, head, c, scale) @@ -462,7 +542,11 @@ def forward(self, msa, pair): def test(): # bs, s, r, cm, cz = 1, 128, 256, 256, 128 - bs, s, r, cm, cz = 1, 512, 384, 256, 128 + # bs, s, r, cm, cz = 1, 512, 384, 256, 128 + # bs, s, r, cm, cz = 1, 1024, 256, 256, 128 + # bs, s, r, cm, cz = 1, 5120, 384, 256, 128 + bs, s, r, cm, cz = 1, 512, 2048, 256, 128 + # bs, s, r, cm, cz = 1, 128, 2048, 256, 128 dtype = torch.float16 @@ -495,7 +579,7 @@ def test(): # @cube.compile(model, dataloader, PAS=spmd.PASData) # @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) - @cube.compile(model, dataloader, PAS=spmd.PASSingle) + @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index a1436729..c3eab632 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -28,6 +28,18 @@ 'multi2ref': False, } +# coshard +def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, + idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) + assert sub_nodes is not None + graph.recompute(sub_nodes) + for devid in devs: + for coid in range(colocate): + sub_node = sub_nodes[devid * colocate + coid] + graph.assign(sub_node, devid) + return sub_nodes def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 From 47d4ca6e9558c6b74e622298ae57f4ea227c04c1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 10 Oct 2022 16:54:24 +0800 Subject: [PATCH 1055/1892] support torch.jit.ignore for customized function --- cube/graph/parser/parser.py | 45 +++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 95627ec2..fff5a392 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -693,8 +693,49 @@ def parse_prim_list_unpack_node(node, module, frame: Frame) -> List[None]: @staticmethod def parse_prim_python_op_node(node, module, frame): - raise NotImplementedError("Cannot support torch.jit.ignore") - print(dir(node)) + """ + parse node like: + %64 : Tensor = ^OuterProductMean()(%opm_left.1, %opm_right.1, %outer_out_proj) + """ + # get inputs + input_vals = list() + for input in node.inputs(): + var_name = input.debugName() + val = frame.get_var(var_name) + input_vals.append(val) + + fsig: str = str(node.pyname()) + + # map to IR operator + ir_node = Sign2Op.map(fsig)(inputs=input_vals) + + # push output in the frame + # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) + # : >>> dir(a) + # : >>> a.elements() # [TensorType, TensorType] + cnt = 0 + for output in node.outputs(): + if isinstance(output.type(), torch._C.TupleType): + tuplen = len(output.type().elements()) + ir_output = [ir_node.output(idx) for idx in range(cnt, cnt+tuplen)] + cnt += tuplen + else: + ir_output = ir_node.output(cnt) + cnt += 1 + frame.add_var(output.debugName(), ir_output) + + if cnt != len(ir_node.outputs()): + raise RuntimeError( + f"Parse fail: {fsig} has {cnt} outputs != pre-defined {len(ir_node.outputs())}" + ) + + # print(input_vals) + # print(node.pyname()) + # print(dir(node)) + # print(tuple(node.inputs())) + # print(tuple(node.outputs())) + # raise NotImplementedError("Cannot support torch.jit.ignore") + return [ir_node] @staticmethod def parse_value_erased_node(node, module, frame, erased_vals: List[Any]): From 284f0cb88e18aedf01677c9b681d2b60b9d906dd Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 10 Oct 2022 20:37:09 +0800 Subject: [PATCH 1056/1892] save work, runnable for several large case, need further debug --- examples/alphafold2/alphafold2.py | 80 ++++++++++++++++++------------ examples/alphafold2/policy/spmd.py | 16 ++++-- 2 files changed, 58 insertions(+), 38 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 8eeb2535..d4a272e8 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -41,17 +41,18 @@ def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, c: int, scale: float): bs, s, r, cm = x.size() + q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, bs, s, + r, head, c) - q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, - gate_proj, bs, s, r, - head, c) - - chunk_size = 1 + import math + chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) + print(chunk_size) assert s % chunk_size == 0 out_chunks = [] - def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Tensor, start: int): + def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + gate: torch.Tensor, start: int): cur_q = q[:, start:start + chunk_size, :, :, :].reshape( bs * chunk_size * head, r, c) cur_k = k[:, start:start + chunk_size, :, :, :].reshape( @@ -63,12 +64,14 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Ten sim = torch.bmm(cur_q, cur_k) * 0.125 sim = torch.nn.functional.softmax(sim, dim=-1) attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, c).transpose(2, 3).reshape(bs, chunk_size, r, cm) + attend = attend.reshape(bs, chunk_size, head, r, + c).transpose(2, + 3).reshape(bs, chunk_size, r, cm) return attend for start in range(0, s, chunk_size): - # attend = ckpt.checkpoint(attention, q, k, v, gate, start) - attend = attention(q, k, v, gate, start) + attend = ckpt.checkpoint(attention, q, k, v, gate, start) + # attend = attention(q, k, v, gate, start) out_chunks.append(attend) out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) @@ -83,7 +86,8 @@ def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float): bs, s, r, cm = x.size() - chunk_size = 1 + import math + chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) if chunk_size == -1: gate = torch.sigmoid(torch.matmul(x, gate_proj)) @@ -111,14 +115,14 @@ def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, c).transpose(2, 3).reshape(bs, s, r, cm) out = torch.matmul(out, out_proj) else: - q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, - gate_proj, bs, s, r, - head, c) + q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, bs, + s, r, head, c) assert s % chunk_size == 0 out_chunks = [] - def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Tensor, - bias: torch.Tensor, start: int): + + def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + gate: torch.Tensor, bias: torch.Tensor, start: int): cur_q = q[:, start:start + chunk_size, :, :, :].reshape( bs * chunk_size * head, r, c) cur_k = k[:, start:start + chunk_size, :, :, :].reshape( @@ -134,16 +138,16 @@ def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torc sim = sim.reshape(bs * chunk_size * head, r, r) attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, c).transpose(2, 3).reshape(bs, chunk_size, r, cm) + attend = attend.reshape(bs, chunk_size, head, r, c).transpose( + 2, 3).reshape(bs, chunk_size, r, cm) return attend for start in range(0, s, chunk_size): - # attend = ckpt.checkpoint(attention_bias, q, k, v, gate, + attend = ckpt.checkpoint(attention_bias, q, k, v, gate, bias, + start) + # attend = attention_bias(q, k, v, gate, # bias, # start) - attend = attention_bias(q, k, v, gate, - bias, - start) out_chunks.append(attend) out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) @@ -205,11 +209,6 @@ def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): return torch.matmul(msa_repr, proj) -@cube.graph.parser.register('TODO', name='opm') -def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int, bs: int, chunk_size: int, t: int, c: int): - lhs_slice = lhs[:, start:start+chunk_size, :, :] - out = torch.einsum('...bac,...dae->...bdce', lhs_slice, rhs).reshape(bs, chunk_size, t, c * c) - return out """ [bs, s, r, cm] -> [bs, r, r, cz] @@ -218,6 +217,7 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int, bs: int, chunk_size: i @cube.graph.parser.register('N S R M, N S T M, F Z -> N R T Z', name='OuterProductMean') +@torch.jit.ignore def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, out_proj: torch.Tensor): bs, s, r, c = left_act.size() @@ -226,17 +226,27 @@ def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, a = left_act.transpose(-2, -3) b = right_act.transpose(-2, -3) - chunk_size = 1 + import math + chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) if chunk_size == -1: outer = torch.einsum('...bac,...dae->...bdce', a, - b).reshape(bs, r, t, c * c) + b).reshape(bs, r, t, c * c) outer = torch.matmul(outer, out_proj) else: out_chunks = [] + + def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): + lhs_slice = lhs[:, start:start + chunk_size, :, :] + out = torch.einsum('...bac,...dae->...bdce', lhs_slice, + rhs).reshape(bs, chunk_size, t, c * c) + out = torch.matmul(out, out_proj) + return out + for start in range(0, r, chunk_size): - out_chunks.append(opm(a, b, start, bs, chunk_size, t, c)) - outer = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) + # out_chunks.append(opm(a, b, start)) + out_chunks.append(ckpt.checkpoint(opm, a, b, start)) + outer = torch.cat(out_chunks, dim=1) return outer @@ -533,9 +543,11 @@ def __init__(self, s: int, cm: int, cz: int, evo_num: int): super().__init__() self.evo_num = evo_num self.evoformer = Evoformer(s, cm, cz) + self.evoformer2 = Evoformer(s, cm, cz) def forward(self, msa, pair): new_msa, new_pair = self.evoformer(msa, pair) + new_msa, new_pair = self.evoformer2(new_msa, new_pair) loss = torch.sum(new_msa) * torch.sum(new_pair) return loss @@ -545,8 +557,10 @@ def test(): # bs, s, r, cm, cz = 1, 512, 384, 256, 128 # bs, s, r, cm, cz = 1, 1024, 256, 256, 128 # bs, s, r, cm, cz = 1, 5120, 384, 256, 128 - bs, s, r, cm, cz = 1, 512, 2048, 256, 128 - # bs, s, r, cm, cz = 1, 128, 2048, 256, 128 + # bs, s, r, cm, cz = 1, 512, 2048, 256, 128 + # bs, s, r, cm, cz = 1, 128, 1024, 256, 128 + # bs, s, r, cm, cz = 1, 128, 768, 256, 128 + bs, s, r, cm, cz = 1, 256, 768, 256, 128 dtype = torch.float16 @@ -589,8 +603,8 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - warm_up = 20 - iter_num = 64 + warm_up = 1 + iter_num = 2 CudaTimer(enable=False).warmup() if torch.distributed.is_initialized(): torch.distributed.barrier() diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index c3eab632..23c356f5 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -8,8 +8,8 @@ 'MSARowAttentionWithPairBias': True, 'MSAColAttention': True, 'MSATransition': True, - 'OPMLeftProj': False, - 'OPMRightProj': False, + 'OPMLeftProj': True, + 'OPMRightProj': True, 'OuterProductMean': True, 'TMOLeftProj': True, 'TMORightProj': True, @@ -28,11 +28,16 @@ 'multi2ref': False, } + # coshard -def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, - idx: int, dim: int): +def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], + colocate: int, idx: int, dim: int): algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) + sub_nodes = graph.partition(node, + algo, + idx=idx, + dim=dim, + num=colocate * len(devs)) assert sub_nodes is not None graph.recompute(sub_nodes) for devid in devs: @@ -41,6 +46,7 @@ def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int graph.assign(sub_node, devid) return sub_nodes + def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 for node in graph.nodes(): From d2ab85b6ffdfe3715d5fa4c7470abb7c56501bb2 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 11 Oct 2022 10:26:07 +0800 Subject: [PATCH 1057/1892] updt config --- examples/alphafold2/alphafold2.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index d4a272e8..e4151d05 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -46,7 +46,6 @@ def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, import math chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) - print(chunk_size) assert s % chunk_size == 0 out_chunks = [] @@ -543,24 +542,31 @@ def __init__(self, s: int, cm: int, cz: int, evo_num: int): super().__init__() self.evo_num = evo_num self.evoformer = Evoformer(s, cm, cz) - self.evoformer2 = Evoformer(s, cm, cz) def forward(self, msa, pair): new_msa, new_pair = self.evoformer(msa, pair) - new_msa, new_pair = self.evoformer2(new_msa, new_pair) loss = torch.sum(new_msa) * torch.sum(new_pair) return loss def test(): + # Training + # initial training: evoformer # bs, s, r, cm, cz = 1, 128, 256, 256, 128 + # first fine-tuning: evoformer + # bs, s, r, cm, cz = 1, 512, 256, 256, 128 + # second fine-tuning: evoformer # bs, s, r, cm, cz = 1, 512, 384, 256, 128 + # initial training: extra sequence # bs, s, r, cm, cz = 1, 1024, 256, 256, 128 + # second fine-tuning: extra sequence + # bs, s, r, cm, cz = 1, 1024, 384, 256, 128 + # OOM on RTX 2080 Ti # bs, s, r, cm, cz = 1, 5120, 384, 256, 128 - # bs, s, r, cm, cz = 1, 512, 2048, 256, 128 - # bs, s, r, cm, cz = 1, 128, 1024, 256, 128 - # bs, s, r, cm, cz = 1, 128, 768, 256, 128 - bs, s, r, cm, cz = 1, 256, 768, 256, 128 + + # Inference + # T1044: 2048 -> 2180 + # bs, s, r, cm, cz = 1, 128, 2048, 256, 128 dtype = torch.float16 From 26103c45b922b738353506009186a281b4cd0f76 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 11 Oct 2022 00:08:36 -0700 Subject: [PATCH 1058/1892] fix bug, runnable for multiple device --- examples/alphafold2/alphafold2.py | 43 +++++++++++++++++++----------- examples/alphafold2/policy/spmd.py | 6 +++-- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index e4151d05..7f2a1c91 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -226,7 +226,7 @@ def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, b = right_act.transpose(-2, -3) import math - chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) + chunk_size = min(r, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) if chunk_size == -1: outer = torch.einsum('...bac,...dae->...bdce', a, @@ -325,28 +325,37 @@ def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, p = torch.matmul(p, tri_mul_proj5) return p.permute(0, 2, 1, 3) * g +@cube.graph.parser.register('N S R C, C D -> N S R D', name='TANSBias') +def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): + return torch.matmul(pair_repr, bias_proj) + -@cube.graph.parser.register('N S R Z, Z E, Z F, E Z, Z G -> N S R Z', +@cube.graph.parser.register('N S R Z, Z E, Z F, E Z, N T R G -> N S R Z', name='TriangleAttentionNodeStart') def TriangleAttentionNodeStart(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, - out_proj: torch.Tensor, bias_proj: torch.Tensor, + out_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float): - bias = torch.matmul(pair_repr, bias_proj).permute(0, 3, 1, 2).unsqueeze(1) + bias = bias.permute(0, 3, 1, 2).unsqueeze(1) return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, head, c, scale) +@cube.graph.parser.register('N S R C, C D -> N S R D', name='TANEBias') +def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): + return torch.matmul(pair_repr, bias_proj) + -@cube.graph.parser.register('N S R Z, Z E, Z F, E Z, Z G -> N S R Z', +@cube.graph.parser.register('N R S Z, Z E, Z F, E Z, N R T G -> N R S Z', name='TriangleAttentionNodeEnd') def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias_proj: torch.Tensor, head: int, c: int, + bias: torch.Tensor, head: int, c: int, scale: float): pair_repr = pair_repr.permute(0, 2, 1, 3) + bias = bias.permute(0, 2, 1, 3) out = TriangleAttentionNodeStart(pair_repr, gate_proj, qkv_proj, out_proj, - bias_proj, head, c, scale) + bias, head, c, scale) return out.permute(0, 2, 1, 3) @@ -519,15 +528,19 @@ def forward(self, msa_repr, pair_repr): tmi_left, tmi_right, tmi_gate, self.tri_mul_in_norm2_weight, self.tri_mul_in_norm2_bias, self.tri_mul_in_proj5, self.cz) + pair_repr = self.tri_att_start_norm(pair_repr) + bias = TANSBias(pair_repr, self.tri_att_start_bias_proj) pair_repr = pair_repr + TriangleAttentionNodeStart( - self.tri_att_start_norm(pair_repr), self.tri_att_start_gate_proj, + pair_repr, self.tri_att_start_gate_proj, self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, - self.tri_att_start_bias_proj, self.pair_head, self.c, self.scale) + bias, self.pair_head, self.c, self.scale) + pair_repr = self.tri_att_end_norm(pair_repr) + bias = TANEBias(pair_repr, self.tri_att_end_bias_proj) pair_repr = pair_repr + TriangleAttentionNodeEnd( - self.tri_att_end_norm(pair_repr), self.tri_att_end_gate_proj, + pair_repr, self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, - self.tri_att_end_bias_proj, self.pair_head, self.c, self.scale) + bias, self.pair_head, self.c, self.scale) pair_repr = pair_repr + PairTransition( self.pair_transition_norm(pair_repr), self.pair_transition_proj1, @@ -561,8 +574,8 @@ def test(): # bs, s, r, cm, cz = 1, 1024, 256, 256, 128 # second fine-tuning: extra sequence # bs, s, r, cm, cz = 1, 1024, 384, 256, 128 - # OOM on RTX 2080 Ti - # bs, s, r, cm, cz = 1, 5120, 384, 256, 128 + # OOM on RTX 2080 Ti & V100 + bs, s, r, cm, cz = 1, 5120, 384, 256, 128 # Inference # T1044: 2048 -> 2180 @@ -598,8 +611,8 @@ def test(): )) # @cube.compile(model, dataloader, PAS=spmd.PASData) - # @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) - @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) + # @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) + @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 23c356f5..d9c6a128 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -19,7 +19,9 @@ 'TMIRightProj': True, 'TMIGate': True, 'TriangleMultiplicationIn': True, + 'TANSBias': True, 'TriangleAttentionNodeStart': True, + 'TANEBias': True, 'TriangleAttentionNodeEnd': True, 'PairTransition': True, 'add': False, @@ -185,9 +187,9 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): 'TriangleMultiplicationIn' }: sub_nodes = _tp(graph, node, tp_devs, 0, 2) - elif node.name == 'TriangleAttentionNodeStart': + elif node.name in {'TANSBias', 'TriangleAttentionNodeStart'}: sub_nodes = _tp(graph, node, tp_devs, 0, 1) - elif node.name == 'TriangleAttentionNodeEnd': + elif node.name in {'TANEBias', 'TriangleAttentionNodeEnd'}: sub_nodes = _tp(graph, node, tp_devs, 0, 2) elif node.name == 'PairTransition': sub_nodes = _tp(graph, node, tp_devs, 0, 1) From 3fa4cfa85916ffacd3295af6c35f0cc81d48064f Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 11 Oct 2022 02:39:44 -0700 Subject: [PATCH 1059/1892] save work --- examples/alphafold2/alphafold2.py | 15 ++++++++++----- examples/alphafold2/policy/spmd.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 7f2a1c91..11df2f48 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -486,8 +486,10 @@ def __init__(self, def forward(self, msa_repr, pair_repr): + pair_repr, dummy_pair_repr = multi2ref(pair_repr) + msa_repr = msa_repr + MSARowAttentionWithPairBias( - self.row_norm_m(msa_repr), pair_repr, self.row_gate_proj, + self.row_norm_m(msa_repr), dummy_pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, self.msa_head, self.c, self.scale) @@ -554,15 +556,18 @@ class AlphaFold2(nn.Module): def __init__(self, s: int, cm: int, cz: int, evo_num: int): super().__init__() self.evo_num = evo_num - self.evoformer = Evoformer(s, cm, cz) + self.evoformers = torch.nn.ModuleList([Evoformer(s, cm, cz) for _ in range(evo_num)]) def forward(self, msa, pair): - new_msa, new_pair = self.evoformer(msa, pair) - loss = torch.sum(new_msa) * torch.sum(new_pair) + for evoformer in self.evoformers: + msa, pair = evoformer(msa, pair) + loss = torch.sum(msa) * torch.sum(pair) return loss def test(): + evo_num = 2 + # Training # initial training: evoformer # bs, s, r, cm, cz = 1, 128, 256, 256, 128 @@ -583,7 +588,7 @@ def test(): dtype = torch.float16 - model = AlphaFold2(s, cm, cz, 1).to(dtype) + model = AlphaFold2(s, cm, cz, evo_num).to(dtype) # msa_repr, pair_repr = torch.randn(bs, s, r, cm), torch.randn(bs, r, r, cz) # x = model(msa_repr, pair_repr) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index d9c6a128..3b651610 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -159,7 +159,7 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): _tp(graph, node, tp_devs, 0, 2) elif pred_name == 'MSARowAttentionWithPairBias': _tp(graph, node, tp_devs, 0, 2) - elif pred_name == '': + elif pred_name == '' or pred_name == 'PairTransition': _tp(graph, node, tp_devs, 0, 1) else: assert False, pred_name From 5b454fa2d561a6276f9e059b2aa54bbc853ad483 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 11 Oct 2022 04:40:24 -0700 Subject: [PATCH 1060/1892] fix bugs, align communication primitives, 6 all2all_all2all & 6 allgather_reducescatter each evoformer block --- examples/alphafold2/policy/spmd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 3b651610..49813e88 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -164,7 +164,12 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): else: assert False, pred_name elif node.name in {'sum', 'mul', 'multi2ref'}: - _replica(graph, node, tp_devs) + if node.name == 'multi2ref' and pred_name == 'PairTransition': + _tp(graph, node, tp_devs, 0, 1) + elif node.name == 'multi2ref' and pred_name == 'MSATransition': + _tp(graph, node, tp_devs, 0, 2) + else: + _replica(graph, node, tp_devs) else: pred_name = node.name if node.name == 'MSARowAttentionWithPairBias': From 09ec93f462ca45d470792bb10d00920a920f4af2 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Wed, 12 Oct 2022 21:51:34 +0800 Subject: [PATCH 1061/1892] recompute runnable for single machine, memory incorrect currently --- examples/alphafold2/alphafold2.py | 128 ++++++++++++++++++----------- examples/alphafold2/policy/spmd.py | 16 +++- 2 files changed, 95 insertions(+), 49 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 11df2f48..28f60d46 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -41,39 +41,61 @@ def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, c: int, scale: float): bs, s, r, cm = x.size() - q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, bs, s, - r, head, c) - import math - chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) - - assert s % chunk_size == 0 - out_chunks = [] - - def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - gate: torch.Tensor, start: int): - cur_q = q[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_k = k[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c).transpose(1, 2) - cur_v = v[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - sim = torch.bmm(cur_q, cur_k) * 0.125 + # chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) + chunk_size = -1 + + if chunk_size == -1: + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + + gate = gate.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + q = q.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + k = k.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, + c).transpose(1, 2) + v = v.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + + sim = torch.bmm(q, k) * scale sim = torch.nn.functional.softmax(sim, dim=-1) - attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, - c).transpose(2, - 3).reshape(bs, chunk_size, r, cm) - return attend - - for start in range(0, s, chunk_size): - attend = ckpt.checkpoint(attention, q, k, v, gate, start) - # attend = attention(q, k, v, gate, start) - out_chunks.append(attend) - - out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) + + attend = torch.bmm(sim, v) * gate + + out = attend.reshape(bs, s, head, r, + c).transpose(2, 3).reshape(bs, s, r, cm) + out = torch.matmul(out, out_proj) + else: + q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, bs, + s, r, head, c) + assert s % chunk_size == 0 + out_chunks = [] + + def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + gate: torch.Tensor, start: int): + cur_q = q[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_k = k[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c).transpose(1, 2) + cur_v = v[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + sim = torch.bmm(cur_q, cur_k) * 0.125 + sim = torch.nn.functional.softmax(sim, dim=-1) + attend = torch.bmm(sim, cur_v) * cur_gate + attend = attend.reshape(bs, chunk_size, head, r, c).transpose( + 2, 3).reshape(bs, chunk_size, r, cm) + return attend + + for start in range(0, s, chunk_size): + attend = ckpt.checkpoint(attention, q, k, v, gate, start) + # attend = attention(q, k, v, gate, start) + out_chunks.append(attend) + + out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) return out @@ -86,7 +108,8 @@ def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, bs, s, r, cm = x.size() import math - chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) + # chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) + chunk_size = -1 if chunk_size == -1: gate = torch.sigmoid(torch.matmul(x, gate_proj)) @@ -226,7 +249,8 @@ def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, b = right_act.transpose(-2, -3) import math - chunk_size = min(r, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) + # chunk_size = min(r, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) + chunk_size = -1 if chunk_size == -1: outer = torch.einsum('...bac,...dae->...bdce', a, @@ -325,6 +349,7 @@ def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, p = torch.matmul(p, tri_mul_proj5) return p.permute(0, 2, 1, 3) * g + @cube.graph.parser.register('N S R C, C D -> N S R D', name='TANSBias') def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): return torch.matmul(pair_repr, bias_proj) @@ -341,6 +366,7 @@ def TriangleAttentionNodeStart(pair_repr: torch.Tensor, return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, head, c, scale) + @cube.graph.parser.register('N S R C, C D -> N S R D', name='TANEBias') def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): return torch.matmul(pair_repr, bias_proj) @@ -488,20 +514,24 @@ def forward(self, msa_repr, pair_repr): pair_repr, dummy_pair_repr = multi2ref(pair_repr) + cube.runtime.function.anchor('MSARow') msa_repr = msa_repr + MSARowAttentionWithPairBias( self.row_norm_m(msa_repr), dummy_pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, self.msa_head, self.c, self.scale) + cube.runtime.function.anchor('MSACol') msa_repr = msa_repr + MSAColAttention( self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, self.col_out_proj, self.msa_head, self.c, self.scale) + cube.runtime.function.anchor('MSATrans') msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), self.msa_transition_proj1, self.msa_transition_proj2) succ_msa_repr, msa_repr = multi2ref(msa_repr) + cube.runtime.function.anchor('OPM') msa_repr = self.outer_norm(msa_repr) opm_left, opm_right = OPMLeftProj(msa_repr, self.outer_proj1), OPMRightProj( @@ -509,6 +539,7 @@ def forward(self, msa_repr, pair_repr): pair_repr = pair_repr + OuterProductMean(opm_left, opm_right, self.outer_out_proj) + cube.runtime.function.anchor('TMO') pair_repr = self.tri_mul_out_norm1(pair_repr) tmo_left, tmo_right = TMOLeftProj( pair_repr, self.tri_mul_out_proj1, @@ -520,6 +551,7 @@ def forward(self, msa_repr, pair_repr): tmo_left, tmo_right, tmo_g, self.tri_mul_out_norm2_weight, self.tri_mul_out_norm2_bias, self.tri_mul_out_proj5, self.cz) + cube.runtime.function.anchor('TMI') pair_repr = self.tri_mul_in_norm1(pair_repr) tmi_left = TMILeftProj(pair_repr, self.tri_mul_in_proj1, self.tri_mul_in_proj2) @@ -530,20 +562,23 @@ def forward(self, msa_repr, pair_repr): tmi_left, tmi_right, tmi_gate, self.tri_mul_in_norm2_weight, self.tri_mul_in_norm2_bias, self.tri_mul_in_proj5, self.cz) + cube.runtime.function.anchor('TANS') pair_repr = self.tri_att_start_norm(pair_repr) bias = TANSBias(pair_repr, self.tri_att_start_bias_proj) pair_repr = pair_repr + TriangleAttentionNodeStart( pair_repr, self.tri_att_start_gate_proj, - self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, - bias, self.pair_head, self.c, self.scale) + self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, bias, + self.pair_head, self.c, self.scale) + cube.runtime.function.anchor('TANE') pair_repr = self.tri_att_end_norm(pair_repr) bias = TANEBias(pair_repr, self.tri_att_end_bias_proj) pair_repr = pair_repr + TriangleAttentionNodeEnd( - pair_repr, self.tri_att_end_gate_proj, - self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, - bias, self.pair_head, self.c, self.scale) + pair_repr, self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, + self.tri_att_end_out_proj, bias, self.pair_head, self.c, + self.scale) + cube.runtime.function.anchor('PairTrans') pair_repr = pair_repr + PairTransition( self.pair_transition_norm(pair_repr), self.pair_transition_proj1, self.pair_transition_proj2) @@ -556,7 +591,8 @@ class AlphaFold2(nn.Module): def __init__(self, s: int, cm: int, cz: int, evo_num: int): super().__init__() self.evo_num = evo_num - self.evoformers = torch.nn.ModuleList([Evoformer(s, cm, cz) for _ in range(evo_num)]) + self.evoformers = torch.nn.ModuleList( + [Evoformer(s, cm, cz) for _ in range(evo_num)]) def forward(self, msa, pair): for evoformer in self.evoformers: @@ -566,11 +602,11 @@ def forward(self, msa, pair): def test(): - evo_num = 2 + evo_num = 1 # Training # initial training: evoformer - # bs, s, r, cm, cz = 1, 128, 256, 256, 128 + bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning: evoformer # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning: evoformer @@ -580,7 +616,7 @@ def test(): # second fine-tuning: extra sequence # bs, s, r, cm, cz = 1, 1024, 384, 256, 128 # OOM on RTX 2080 Ti & V100 - bs, s, r, cm, cz = 1, 5120, 384, 256, 128 + # bs, s, r, cm, cz = 1, 5120, 384, 256, 128 # Inference # T1044: 2048 -> 2180 @@ -616,8 +652,8 @@ def test(): )) # @cube.compile(model, dataloader, PAS=spmd.PASData) - # @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) - @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) + # @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) + @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) @@ -627,8 +663,8 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - warm_up = 1 - iter_num = 2 + warm_up = 2 + iter_num = 4 CudaTimer(enable=False).warmup() if torch.distributed.is_initialized(): torch.distributed.barrier() diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 49813e88..f9fdb5f0 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -3,6 +3,7 @@ from numpy import TooHardError from cube.graph import IRGraph from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation +from cube.graph.function.anchor import IRGraphAnchor recompute_info = { 'MSARowAttentionWithPairBias': True, @@ -51,12 +52,21 @@ def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 + fnodes = graph.nodes() + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] for node in graph.nodes(): if not isinstance(node, IRBpOperation): graph.assign(node, 0) - if node.name in recompute_info and recompute_info[ - node.name] == True: - graph.recompute([node]) + for i in range(len(indices) - 1): + # hack: first layernorm should not be recomputed + if i == 0: + u = fnodes[indices[i] + 2:indices[i + 1]] + else: + u = fnodes[indices[i] + 1:indices[i + 1]] + v = [item.name for item in u] + print(v) + graph.recompute(u) return graph From 4a8fcc26402dbd308cda3a50296c5b1cff8dd704 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 12 Oct 2022 22:12:47 +0800 Subject: [PATCH 1062/1892] fix bug: recompute group of multiref --- cube/graph/gener/gen.py | 52 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 0afbd62c..551522a9 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,5 +1,6 @@ from typing import Dict, List, Optional, Tuple, Dict import numpy as np +import itertools from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener @@ -593,6 +594,8 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): # insert multiref multiref.device = devid min_fidx = min(graph.index(consumer) for consumer in devops[devid]) + # set recompute id + multiref.recompute = graph.node(min_fidx).recompute graph.finsert(multiref, min_fidx) @staticmethod @@ -626,4 +629,51 @@ def fusion(graph: IRSegment) -> IRSegment: if isinstance(segment, IRSegment) and segment.isfw(): IRAdapterGener.fusion(segment) - return graph \ No newline at end of file + return graph + + @staticmethod + def tensor_merge(tensors: List[IRSubTensor], target: Optional[IRSubTensor] = None) -> List[Tuple[str, List, IRSubTensor]]: + """ + Merge sub-tensors into one tensor or stop right after gets target tensor. + + Merge primtiives: + "sum: output = sum(inputs)" + "cat: output = cat(inputs, dim: int) + + @param tensors List[IRSubTensor]: list of tensors + @param target Optional[IRSubTensor]: the target tensor (default None). + + @return primitives List[Tuple[str, List, IRSubTensor]]: + List primitives of in forms of (op, inputs, outputs) + """ + prims = [] + tensors = [t for t in tensors] + while len(tensors) > 1: + out = None + for t1, t2 in itertools.combinations(tensors, 2): + # try concat + catdim = t1.catdim(t2) + if catdim is not None: + tensors = [t1, t2] if t1.indmap[catdim][0] < t2.indmap[catdim][0] else [t2, t1] + out = tensors[0].concat(tensors[1], dim=catdim) + prims.append(('cat', tensors + [catdim], out)) + break + # try summation + if t1.accumable(t2): + out = t1.accum(t2) + prims.append(('sum', [t1, t2], out)) + break + if out is not None: + tensors.remove(t1) + tensors.remove(t2) + tensors.append(out) + if target is not None and out == target: break + else: + remain = '\n\t'.join(t.extra_repr() for t in tensors) + sprims = '\n\t'.join(repr(p) for p in prims) + raise RuntimeError( + f"Fail to merge tensors into one tensor or cannot match with target.\n" + f"Remain Tensor:\n\t{remain}\n" + f"Existing primitives:\n\t{sprims}\n" + ) + return prims From 76a278a1c04c6e095911d19691e645069dece129 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Wed, 12 Oct 2022 22:38:37 +0800 Subject: [PATCH 1063/1892] merge master, nearly correct & runnable for initial setting on 2080 Ti --- examples/alphafold2/alphafold2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 28f60d46..a8ee1e81 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -602,7 +602,7 @@ def forward(self, msa, pair): def test(): - evo_num = 1 + evo_num = 48 # Training # initial training: evoformer From 8834f69761dcf0e2a0ef7a8985e6e7cd0b1522e3 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Thu, 13 Oct 2022 15:25:55 +0800 Subject: [PATCH 1064/1892] save work for changing experiment platform --- examples/alphafold2/alphafold2.py | 37 ++++++++++++++++++++---------- examples/alphafold2/policy/spmd.py | 35 ++++++++++++++++++---------- 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index a8ee1e81..9be4f3ec 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -137,8 +137,9 @@ def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, c).transpose(2, 3).reshape(bs, s, r, cm) out = torch.matmul(out, out_proj) else: - q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, bs, - s, r, head, c) + # q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, bs, + # s, r, head, c) + q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, c) assert s % chunk_size == 0 out_chunks = [] @@ -511,10 +512,9 @@ def __init__(self, torch.randn(ff_mult * cz, cz)) def forward(self, msa_repr, pair_repr): + cube.runtime.function.anchor('MSARow') pair_repr, dummy_pair_repr = multi2ref(pair_repr) - - cube.runtime.function.anchor('MSARow') msa_repr = msa_repr + MSARowAttentionWithPairBias( self.row_norm_m(msa_repr), dummy_pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, @@ -583,7 +583,7 @@ def forward(self, msa_repr, pair_repr): self.pair_transition_norm(pair_repr), self.pair_transition_proj1, self.pair_transition_proj2) - return (succ_msa_repr, pair_repr) + return succ_msa_repr, pair_repr class AlphaFold2(nn.Module): @@ -591,31 +591,44 @@ class AlphaFold2(nn.Module): def __init__(self, s: int, cm: int, cz: int, evo_num: int): super().__init__() self.evo_num = evo_num + # add norm to work with PyTorch's recompute mechanism + self.msa_norm = torch.nn.LayerNorm(cm) + self.pair_norm = torch.nn.LayerNorm(cz) self.evoformers = torch.nn.ModuleList( [Evoformer(s, cm, cz) for _ in range(evo_num)]) def forward(self, msa, pair): + msa = self.msa_norm(msa) + pair = self.pair_norm(pair) + + cube.runtime.function.anchor('Evoformer Stack Start') for evoformer in self.evoformers: + cube.runtime.function.anchor('One Layer Evoformer Start') msa, pair = evoformer(msa, pair) + cube.runtime.function.anchor('One Layer Evoformer End') + cube.runtime.function.anchor('Evoformer Stack End') loss = torch.sum(msa) * torch.sum(pair) return loss def test(): - evo_num = 48 + evo_num = 2 - # Training + # Training evo_num = 48 # initial training: evoformer - bs, s, r, cm, cz = 1, 128, 256, 256, 128 + # bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning: evoformer # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning: evoformer - # bs, s, r, cm, cz = 1, 512, 384, 256, 128 - # initial training: extra sequence + bs, s, r, cm, cz = 1, 512, 384, 256, 128 + + # Extra sequence evo_num = 4 + # initial training # bs, s, r, cm, cz = 1, 1024, 256, 256, 128 - # second fine-tuning: extra sequence + # first fine-tuning + # bs, s, r, cm, cz = 1, 1024, 512, 256, 128 + # second fine-tuning # bs, s, r, cm, cz = 1, 1024, 384, 256, 128 - # OOM on RTX 2080 Ti & V100 # bs, s, r, cm, cz = 1, 5120, 384, 256, 128 # Inference diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index f9fdb5f0..7956333d 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -52,21 +52,32 @@ def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 - fnodes = graph.nodes() - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] + for node in graph.nodes(): if not isinstance(node, IRBpOperation): graph.assign(node, 0) - for i in range(len(indices) - 1): - # hack: first layernorm should not be recomputed - if i == 0: - u = fnodes[indices[i] + 2:indices[i + 1]] - else: - u = fnodes[indices[i] + 1:indices[i + 1]] - v = [item.name for item in u] - print(v) - graph.recompute(u) + + fnodes = graph.nodes() + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + + indices = [ + fnodes.index(anchor) for anchor in anchors + if anchor.name == 'One Layer Evoformer Start' + or anchor.name == 'One Layer Evoformer End' + ] + assert len(indices) % 2 == 0 + for i in range(len(indices) // 2): + lhs = indices[2 * i] + rhs = indices[2 * i + 1] + # graph.recompute(fnodes[lhs + 1:rhs]) + sub_indices = [] + for j in range(lhs + 1, rhs): + if isinstance(fnodes[j], IRGraphAnchor): + sub_indices.append(j) + sub_indices.append(rhs) + for j in range(len(sub_indices) - 1): + graph.recompute(fnodes[sub_indices[j] + 1:sub_indices[j + 1]]) + return graph From 350e722546c9f4c8ed2baecd5b7d4145de379228 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 13 Oct 2022 10:20:40 +0000 Subject: [PATCH 1065/1892] collect single machine default recompute strategy data --- examples/alphafold2/alphafold2.py | 6 +++--- examples/alphafold2/policy/spmd.py | 20 ++++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 9be4f3ec..c50c6d16 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -612,15 +612,15 @@ def forward(self, msa, pair): def test(): - evo_num = 2 + evo_num = 48 # Training evo_num = 48 # initial training: evoformer - # bs, s, r, cm, cz = 1, 128, 256, 256, 128 + bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning: evoformer # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning: evoformer - bs, s, r, cm, cz = 1, 512, 384, 256, 128 + # bs, s, r, cm, cz = 1, 512, 384, 256, 128 # Extra sequence evo_num = 4 # initial training diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 7956333d..d9987b37 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -69,14 +69,18 @@ def PASSingle(graph: IRGraph, resource): for i in range(len(indices) // 2): lhs = indices[2 * i] rhs = indices[2 * i + 1] - # graph.recompute(fnodes[lhs + 1:rhs]) - sub_indices = [] - for j in range(lhs + 1, rhs): - if isinstance(fnodes[j], IRGraphAnchor): - sub_indices.append(j) - sub_indices.append(rhs) - for j in range(len(sub_indices) - 1): - graph.recompute(fnodes[sub_indices[j] + 1:sub_indices[j + 1]]) + + # deepmind's default recompute strategy + graph.recompute(fnodes[lhs + 1:rhs]) + + # another strategy + # sub_indices = [] + # for j in range(lhs + 1, rhs): + # if isinstance(fnodes[j], IRGraphAnchor): + # sub_indices.append(j) + # sub_indices.append(rhs) + # for j in range(len(sub_indices) - 1): + # graph.recompute(fnodes[sub_indices[j] + 1:sub_indices[j + 1]]) return graph From 738df56e81eeab64f04b6181db969ecb71e56edf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Oct 2022 21:27:43 +0800 Subject: [PATCH 1066/1892] allow adapter to recompute --- cube/graph/gener/gen.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 551522a9..a6618fba 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -250,12 +250,14 @@ def gen_weight(graph: IRGraph) -> IRGraph: return graph @staticmethod - def gen_activation(graph: IRSegment) -> IRSegment: + def gen_activation(graph: IRSegment, allow_recompute: bool = True) -> IRSegment: """! Generate adapter for activation tensors. The forward/backward adapter is inserted before the first consumers of its full tensor. @param graph IRGraph: the graph the requires for adapter. + @param allow_recompute bool: Allow adapter recomputes. If this enables, all adapters will be + set to the same recompute group with its consumed node. @return graph IRGraph: the (inplace) modified graph with activation adapters. """ @@ -338,6 +340,10 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # insert forward adapter # graph.insert(fadapter, max(producers) + 1) + fidx = min(graph.index(c) for c in fconsumers) + # setup recompute + if fadapter.differentiable and allow_recompute: + fadapter.recompute = graph.node(fidx).recompute graph.insert(fadapter, min(graph.index(c) for c in fconsumers)) # insert backward adapter From 701782a042a80940ff647e0ec1a6a71acbaaa651 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 13 Oct 2022 15:00:29 +0000 Subject: [PATCH 1067/1892] save work --- examples/alphafold2/alphafold2.py | 10 +- examples/alphafold2/policy/spmd.py | 186 ++++++++++++----------------- 2 files changed, 84 insertions(+), 112 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index c50c6d16..6968fe0e 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -612,15 +612,15 @@ def forward(self, msa, pair): def test(): - evo_num = 48 + evo_num = 16 # Training evo_num = 48 # initial training: evoformer - bs, s, r, cm, cz = 1, 128, 256, 256, 128 + # bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning: evoformer # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning: evoformer - # bs, s, r, cm, cz = 1, 512, 384, 256, 128 + bs, s, r, cm, cz = 1, 512, 384, 256, 128 # Extra sequence evo_num = 4 # initial training @@ -665,8 +665,8 @@ def test(): )) # @cube.compile(model, dataloader, PAS=spmd.PASData) - # @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) - @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) + # @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) + @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index d9987b37..03af0a59 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -31,8 +31,32 @@ 'multi2ref': False, } +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for dev_id, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, dev_id) + return sub_nodes + +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, + dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, + algo, + idx=idx, + dim=dim, + num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +def _tps(graph: IRGraph, nodes: List[IRFwOperation], devs: List[int], idx: int, + dim: int): + sub_nodes = [] + for node in nodes: + sub_nodes = sub_nodes + _tp(graph, node, devs, idx, dim) + return sub_nodes -# coshard def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, idx: int, dim: int): algo = node.algorithms('dim') @@ -121,112 +145,60 @@ def PASDAP(graph: IRGraph, resource): tp_size = resource.ngpus tp_devs = list(range(tp_size)) - def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, - dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=idx, - dim=dim, - num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes + fnodes = graph.nodes() + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for dev_id, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, dev_id) - return sub_nodes + indices = [ + fnodes.index(anchor) for anchor in anchors + if anchor.name == 'One Layer Evoformer Start' + or anchor.name == 'One Layer Evoformer End' + ] + assert len(indices) % 2 == 0 - pred_name = '' - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - # _tp(graph, node, tp_devs, 0, 1) - _replica(graph, node, tp_devs) - elif isinstance(node, IRFwOperation): - if node.name == 'add': - if pred_name == 'PairTransition': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'TriangleAttentionNodeEnd': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'TriangleAttentionNodeStart': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'TriangleMultiplicationIn': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'TriangleMultiplicationOut': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'OuterProductMean': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'MSATransition': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'MSAColAttention': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'MSARowAttentionWithPairBias': - _tp(graph, node, tp_devs, 0, 1) - else: - assert False, pred_name - elif node.name == 'layernorm': - if pred_name == 'TriangleAttentionNodeEnd': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'TriangleAttentionNodeStart': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'TriangleMultiplicationIn': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'TriangleMultiplicationOut': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'MSATransition': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'OuterProductMean': - _tp(graph, node, tp_devs, 0, 1) - elif pred_name == 'MSAColAttention': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == 'MSARowAttentionWithPairBias': - _tp(graph, node, tp_devs, 0, 2) - elif pred_name == '' or pred_name == 'PairTransition': - _tp(graph, node, tp_devs, 0, 1) - else: - assert False, pred_name - elif node.name in {'sum', 'mul', 'multi2ref'}: - if node.name == 'multi2ref' and pred_name == 'PairTransition': - _tp(graph, node, tp_devs, 0, 1) - elif node.name == 'multi2ref' and pred_name == 'MSATransition': - _tp(graph, node, tp_devs, 0, 2) - else: - _replica(graph, node, tp_devs) + for i in range(indices[0]): + if isinstance(fnodes[i], IRDataOperation) or isinstance(fnodes[i], IRFwOperation): + _replica(graph, fnodes[i], tp_devs) + + for i in range(len(indices) // 2): + lhs, rhs = indices[2 * i], indices[2 * i + 1] + sub_indices = [] + for j in range(lhs+1, rhs): + if isinstance(fnodes[j], IRGraphAnchor): + sub_indices.append(j) + sub_indices.append(rhs) + for j in range(len(sub_indices) - 1): + sub_l, sub_r = sub_indices[j], sub_indices[j + 1] + names = [] + for k in range(sub_l+1, sub_r): + names.append(fnodes[k].name) + names = set(names) + nodes = fnodes[sub_l+1:sub_r] + graph.recompute(nodes) + + if 'MSARowAttentionWithPairBias' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) + elif 'MSAColAttention' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'MSATransition' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'OuterProductMean' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'TriangleMultiplicationOut' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) + elif 'TriangleMultiplicationIn' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'TriangleAttentionNodeStart' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) + elif 'TriangleAttentionNodeEnd' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'PairTransition' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) else: - pred_name = node.name - if node.name == 'MSARowAttentionWithPairBias': - sub_nodes = _tp(graph, node, tp_devs, 0, 1) - elif node.name == 'MSAColAttention': - sub_nodes = _tp(graph, node, tp_devs, 0, 2) - elif node.name == 'MSATransition': - sub_nodes = _tp(graph, node, tp_devs, 0, 2) - elif node.name in { - 'OPMLeftProj', 'OPMRightProj', 'OuterProductMean' - }: - sub_nodes = _tp(graph, node, tp_devs, 0, 2) - elif node.name in { - 'TMOLeftProj', 'TMORightProj', 'TMOGate', - 'TriangleMultiplicationOut' - }: - sub_nodes = _tp(graph, node, tp_devs, 0, 1) - elif node.name in { - 'TMILeftProj', 'TMIRightProj', 'TMIGate', - 'TriangleMultiplicationIn' - }: - sub_nodes = _tp(graph, node, tp_devs, 0, 2) - elif node.name in {'TANSBias', 'TriangleAttentionNodeStart'}: - sub_nodes = _tp(graph, node, tp_devs, 0, 1) - elif node.name in {'TANEBias', 'TriangleAttentionNodeEnd'}: - sub_nodes = _tp(graph, node, tp_devs, 0, 2) - elif node.name == 'PairTransition': - sub_nodes = _tp(graph, node, tp_devs, 0, 1) - else: - assert False, node.name - - if node.name in recompute_info and recompute_info[ - node.name] == True: - graph.recompute(sub_nodes) - return graph + assert False, names + + + for i in range(indices[-1] + 1, len(fnodes)): + if isinstance(fnodes[i], IRDataOperation) or isinstance(fnodes[i], IRFwOperation): + _replica(graph, fnodes[i], tp_devs) + + return graph \ No newline at end of file From 62d606c5d9bf4a44ffaf26f92c0c5a43d20a6b9b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Oct 2022 23:57:30 +0800 Subject: [PATCH 1068/1892] fix adapter recompute --- cube/codegen/codegen.py | 24 +++++++++++------------- cube/graph/gener/gen.py | 2 +- cube/ir/adapter/adapter.py | 1 + 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 915f33a7..024035a3 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -27,6 +27,10 @@ import os +USE_NNFUSION = os.environ.get('USE_NNFUSION') +USE_JIT = os.environ.get('USE_JIT') + + def get_backward_callsite_io_tensors(bp_segment:IRSegment): """ Returns: @@ -373,12 +377,12 @@ def __init__(self, execplan: ExecutionPlan): 'import torch', 'import torch.utils.checkpoint as ckpt', 'import cube', '', ''] - use_nnfusion = os.environ.get('USE_NNFUSION') - if use_nnfusion: + if USE_NNFUSION: self.init_code.extend(['import nnfusion', '']) # customized op code for _, op_impl in Sign2Op.kOpCodeDef.items(): + # self.init_code.append('@torch.jit.script') self.init_code.append(op_impl) self.init_code += [''] # module init code @@ -485,8 +489,6 @@ def gen(self, device: int, outfile=None, attach=False) -> str: else: args.append(self.tensor_naming(t)) node_args.append(args) - - use_nnfusion = os.environ.get('USE_NNFUSION') # generate full code with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: with FunctionBlock(func_name='__init__', args=['self']) as ib: @@ -510,8 +512,10 @@ def gen(self, device: int, outfile=None, attach=False) -> str: return_code = f"return {', '.join(outputs)}" fb.insert_body(return_code) cb.insert_body('') - if use_nnfusion and name.startswith('segment'): + if USE_NNFUSION and name.startswith('segment'): cb.insert_body('@nnfusion.jit') + if USE_JIT and name.startswith('segment'): + cb.insert_body('@torch.jit.script_method') cb.insert_body(fb.code) @@ -712,14 +716,8 @@ def recompute(tensor_2222): return node_codes, inputs, outputs def get_equiv_recompute_gid(node:Union[IRFwOperation, IRAdapter]) -> Optional[int]: - if isinstance(node, IRAdapter): - # IRAdapter is equivalent to be non-recomputable. And it always terminates the - # nodes sequence of any recomputing group before it. - return None - elif isinstance(node, IRFwOperation): - return node.recompute - else: - raise ValueError(f'Unexcepted node type {type(node)}') + assert isinstance(node, (IRAdapter, IRFwOperation)), f"Invalid type: {type(node)}" + return node.recompute def should_start_new_recompute_group(i_prev, i_cur) -> bool: # i_prev, i_cur: Tuple[int, Union[IRFwOp,IRAdapter]] diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index a6618fba..cdaf7ede 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -344,7 +344,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # setup recompute if fadapter.differentiable and allow_recompute: fadapter.recompute = graph.node(fidx).recompute - graph.insert(fadapter, min(graph.index(c) for c in fconsumers)) + graph.insert(fadapter, fidx) # insert backward adapter if badapter is not None: diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index 154565b0..97d0a146 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -121,6 +121,7 @@ def dispatch(self, devid: int, for_mirror=True): fadapter._id = self._id fadapter.differentiable = self.differentiable fadapter.custom = self.custom + fadapter.recompute = self.recompute # dispatch for mirror if for_mirror and isinstance(self.mirror, IRAdapter): badapter = self.mirror.dispatch(devid, for_mirror=False) From 8b338ed741d84bb2df5fabaf1fae8da5ab622dcc Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 14 Oct 2022 05:52:10 +0000 Subject: [PATCH 1069/1892] refine code, collect numbers under training settings --- examples/alphafold2/alphafold2.py | 10 ++--- examples/alphafold2/policy/spmd.py | 61 +++++++++--------------------- 2 files changed, 22 insertions(+), 49 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 6968fe0e..c50c6d16 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -612,15 +612,15 @@ def forward(self, msa, pair): def test(): - evo_num = 16 + evo_num = 48 # Training evo_num = 48 # initial training: evoformer - # bs, s, r, cm, cz = 1, 128, 256, 256, 128 + bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning: evoformer # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning: evoformer - bs, s, r, cm, cz = 1, 512, 384, 256, 128 + # bs, s, r, cm, cz = 1, 512, 384, 256, 128 # Extra sequence evo_num = 4 # initial training @@ -665,8 +665,8 @@ def test(): )) # @cube.compile(model, dataloader, PAS=spmd.PASData) - # @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) - @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) + # @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) + @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 03af0a59..9a836192 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -5,31 +5,6 @@ from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation from cube.graph.function.anchor import IRGraphAnchor -recompute_info = { - 'MSARowAttentionWithPairBias': True, - 'MSAColAttention': True, - 'MSATransition': True, - 'OPMLeftProj': True, - 'OPMRightProj': True, - 'OuterProductMean': True, - 'TMOLeftProj': True, - 'TMORightProj': True, - 'TMOGate': True, - 'TriangleMultiplicationOut': True, - 'TMILeftProj': True, - 'TMIRightProj': True, - 'TMIGate': True, - 'TriangleMultiplicationIn': True, - 'TANSBias': True, - 'TriangleAttentionNodeStart': True, - 'TANEBias': True, - 'TriangleAttentionNodeEnd': True, - 'PairTransition': True, - 'add': False, - 'sum': False, - 'layernorm': False, - 'multi2ref': False, -} def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): sub_nodes = graph.replicate(node, times=len(devs)) @@ -37,26 +12,25 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): graph.assign(sub_node, dev_id) return sub_nodes + def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=idx, - dim=dim, - num=len(devs)) + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) assert sub_nodes is not None for devid, sub_node in zip(devs, sub_nodes): graph.assign(sub_node, devid) return sub_nodes + def _tps(graph: IRGraph, nodes: List[IRFwOperation], devs: List[int], idx: int, - dim: int): + dim: int): sub_nodes = [] for node in nodes: sub_nodes = sub_nodes + _tp(graph, node, devs, idx, dim) return sub_nodes + def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, idx: int, dim: int): algo = node.algorithms('dim') @@ -135,9 +109,6 @@ def PASData(graph: IRGraph, resource): num=resource.ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) - if node.name in recompute_info and recompute_info[ - node.name] == True: - graph.recompute(sub_nodes) return graph @@ -156,24 +127,26 @@ def PASDAP(graph: IRGraph, resource): assert len(indices) % 2 == 0 for i in range(indices[0]): - if isinstance(fnodes[i], IRDataOperation) or isinstance(fnodes[i], IRFwOperation): + if isinstance(fnodes[i], IRDataOperation) or isinstance( + fnodes[i], IRFwOperation): _replica(graph, fnodes[i], tp_devs) - + for i in range(len(indices) // 2): lhs, rhs = indices[2 * i], indices[2 * i + 1] sub_indices = [] - for j in range(lhs+1, rhs): + for j in range(lhs + 1, rhs): if isinstance(fnodes[j], IRGraphAnchor): sub_indices.append(j) sub_indices.append(rhs) + graph.recompute(fnodes[lhs:rhs]) for j in range(len(sub_indices) - 1): sub_l, sub_r = sub_indices[j], sub_indices[j + 1] names = [] - for k in range(sub_l+1, sub_r): + for k in range(sub_l + 1, sub_r): names.append(fnodes[k].name) names = set(names) - nodes = fnodes[sub_l+1:sub_r] - graph.recompute(nodes) + nodes = fnodes[sub_l + 1:sub_r] + # graph.recompute(nodes) if 'MSARowAttentionWithPairBias' in names: sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) @@ -196,9 +169,9 @@ def PASDAP(graph: IRGraph, resource): else: assert False, names - for i in range(indices[-1] + 1, len(fnodes)): - if isinstance(fnodes[i], IRDataOperation) or isinstance(fnodes[i], IRFwOperation): + if isinstance(fnodes[i], IRDataOperation) or isinstance( + fnodes[i], IRFwOperation): _replica(graph, fnodes[i], tp_devs) - - return graph \ No newline at end of file + + return graph From 5ea2c217dc8633c44d38c17e0f0fed5848a1c047 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 14 Oct 2022 09:24:31 +0000 Subject: [PATCH 1070/1892] refine code structure, collect inference time --- examples/alphafold2/alphafold2.py | 189 ++++++++++++++++------------- examples/alphafold2/policy/spmd.py | 5 +- 2 files changed, 111 insertions(+), 83 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index c50c6d16..0cd753d0 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -38,13 +38,9 @@ def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, @torch.jit.ignore def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, - c: int, scale: float): + c: int, scale: float, chunk_size: int): bs, s, r, cm = x.size() - import math - # chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) - chunk_size = -1 - if chunk_size == -1: gate = torch.sigmoid(torch.matmul(x, gate_proj)) q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) @@ -68,8 +64,7 @@ def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, c).transpose(2, 3).reshape(bs, s, r, cm) out = torch.matmul(out, out_proj) else: - q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, bs, - s, r, head, c) + q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, c) assert s % chunk_size == 0 out_chunks = [] @@ -91,8 +86,7 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return attend for start in range(0, s, chunk_size): - attend = ckpt.checkpoint(attention, q, k, v, gate, start) - # attend = attention(q, k, v, gate, start) + attend = attention(q, k, v, gate, start) out_chunks.append(attend) out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) @@ -104,13 +98,10 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @torch.jit.ignore def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias: torch.Tensor, head: int, c: int, scale: float): + bias: torch.Tensor, head: int, c: int, scale: float, + chunk_size: int): bs, s, r, cm = x.size() - import math - # chunk_size = min(s, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) - chunk_size = -1 - if chunk_size == -1: gate = torch.sigmoid(torch.matmul(x, gate_proj)) q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) @@ -137,8 +128,6 @@ def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, c).transpose(2, 3).reshape(bs, s, r, cm) out = torch.matmul(out, out_proj) else: - # q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, bs, - # s, r, head, c) q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, c) assert s % chunk_size == 0 @@ -166,11 +155,7 @@ def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return attend for start in range(0, s, chunk_size): - attend = ckpt.checkpoint(attention_bias, q, k, v, gate, bias, - start) - # attend = attention_bias(q, k, v, gate, - # bias, - # start) + attend = attention_bias(q, k, v, gate, bias, start) out_chunks.append(attend) out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) @@ -190,7 +175,7 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias_proj: torch.Tensor, head: int, c: int, - scale: float): + scale: float, chunk_size: int): bs, s, r, cm = msa_repr.size() bias = torch.matmul(pair_repr, @@ -198,16 +183,17 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, 2).reshape(bs, 1, head, r, r) return MSAAttentionWithBias(msa_repr, gate_proj, qkv_proj, out_proj, bias, - head, c, scale) + head, c, scale, chunk_size) @cube.graph.parser.register('N S R M, M E, M F, E M -> N S R M', name='MSAColAttention') def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, - c: int, scale: float): - return MSAAttention(msa_repr.permute(0, 2, 1, 3), gate_proj, qkv_proj, - out_proj, head, c, scale).permute(0, 2, 1, 3) + c: int, scale: float, chunk_size: int): + return MSAAttention(msa_repr.permute(0, 2, 1, + 3), gate_proj, qkv_proj, out_proj, + head, c, scale, chunk_size).permute(0, 2, 1, 3) """ @@ -242,17 +228,13 @@ def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): name='OuterProductMean') @torch.jit.ignore def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, - out_proj: torch.Tensor): + out_proj: torch.Tensor, chunk_size: int): bs, s, r, c = left_act.size() t = right_act.size(2) a = left_act.transpose(-2, -3) b = right_act.transpose(-2, -3) - import math - # chunk_size = min(r, max(1, 2**int(math.log2(2048 * 2048 / r / r)))) - chunk_size = -1 - if chunk_size == -1: outer = torch.einsum('...bac,...dae->...bdce', a, b).reshape(bs, r, t, c * c) @@ -268,8 +250,7 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): return out for start in range(0, r, chunk_size): - # out_chunks.append(opm(a, b, start)) - out_chunks.append(ckpt.checkpoint(opm, a, b, start)) + out_chunks.append(opm(a, b, start)) outer = torch.cat(out_chunks, dim=1) return outer @@ -361,11 +342,12 @@ def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): def TriangleAttentionNodeStart(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias: torch.Tensor, - head: int, c: int, scale: float): + head: int, c: int, scale: float, + chunk_size: int): bias = bias.permute(0, 3, 1, 2).unsqueeze(1) return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, - head, c, scale) + head, c, scale, chunk_size) @cube.graph.parser.register('N S R C, C D -> N S R D', name='TANEBias') @@ -378,11 +360,11 @@ def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, - scale: float): + scale: float, chunk_size: int): pair_repr = pair_repr.permute(0, 2, 1, 3) bias = bias.permute(0, 2, 1, 3) out = TriangleAttentionNodeStart(pair_repr, gate_proj, qkv_proj, out_proj, - bias, head, c, scale) + bias, head, c, scale, chunk_size) return out.permute(0, 2, 1, 3) @@ -412,6 +394,7 @@ def __init__(self, s: int, cm: int, cz: int, + use_chunk=False, c=32, msa_head=8, pair_head=4, @@ -419,9 +402,18 @@ def __init__(self, ff_mult=4): super().__init__() - self.s, self.cm, self.cz, self.c, self.msa_head, self.pair_head, self.c_tri_mult, self.ff_mult = s, cm, cz, c, msa_head, pair_head, c_tri_mult, ff_mult + self.s, self.cm, self.cz, self.c = s, cm, cz, c + self.msa_head, self.pair_head = msa_head, pair_head + self.c_tri_mult, self.ff_mult = c_tri_mult, ff_mult self.scale = 1.0 / math.sqrt(c) + if use_chunk: + self.msa_row_chunk, self.msa_col_chunk = 4, 256 + self.opm_chunk, self.tans_chunk, self.tane_chunk = 4, 4, 4 + else: + self.msa_row_chunk, self.msa_col_chunk = -1, -1 + self.opm_chunk, self.tans_chunk, self.tane_chunk = -1, -1, -1 + # MSA row-wise gated self-attention with pair bias self.row_norm_m = torch.nn.LayerNorm(cm) self.row_norm_z = torch.nn.LayerNorm(cz) @@ -518,12 +510,13 @@ def forward(self, msa_repr, pair_repr): msa_repr = msa_repr + MSARowAttentionWithPairBias( self.row_norm_m(msa_repr), dummy_pair_repr, self.row_gate_proj, self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, - self.msa_head, self.c, self.scale) + self.msa_head, self.c, self.scale, self.msa_row_chunk) cube.runtime.function.anchor('MSACol') msa_repr = msa_repr + MSAColAttention( self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - self.col_out_proj, self.msa_head, self.c, self.scale) + self.col_out_proj, self.msa_head, self.c, self.scale, + self.msa_col_chunk) cube.runtime.function.anchor('MSATrans') msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), @@ -536,8 +529,8 @@ def forward(self, msa_repr, pair_repr): opm_left, opm_right = OPMLeftProj(msa_repr, self.outer_proj1), OPMRightProj( msa_repr, self.outer_proj2) - pair_repr = pair_repr + OuterProductMean(opm_left, opm_right, - self.outer_out_proj) + pair_repr = pair_repr + OuterProductMean( + opm_left, opm_right, self.outer_out_proj, self.opm_chunk) cube.runtime.function.anchor('TMO') pair_repr = self.tri_mul_out_norm1(pair_repr) @@ -568,7 +561,7 @@ def forward(self, msa_repr, pair_repr): pair_repr = pair_repr + TriangleAttentionNodeStart( pair_repr, self.tri_att_start_gate_proj, self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, bias, - self.pair_head, self.c, self.scale) + self.pair_head, self.c, self.scale, self.tans_chunk) cube.runtime.function.anchor('TANE') pair_repr = self.tri_att_end_norm(pair_repr) @@ -576,7 +569,7 @@ def forward(self, msa_repr, pair_repr): pair_repr = pair_repr + TriangleAttentionNodeEnd( pair_repr, self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, self.tri_att_end_out_proj, bias, self.pair_head, self.c, - self.scale) + self.scale, self.tane_chunk) cube.runtime.function.anchor('PairTrans') pair_repr = pair_repr + PairTransition( @@ -588,14 +581,20 @@ def forward(self, msa_repr, pair_repr): class AlphaFold2(nn.Module): - def __init__(self, s: int, cm: int, cz: int, evo_num: int): + def __init__(self, + s: int, + cm: int, + cz: int, + evo_num: int, + use_chunk=False): super().__init__() self.evo_num = evo_num # add norm to work with PyTorch's recompute mechanism self.msa_norm = torch.nn.LayerNorm(cm) self.pair_norm = torch.nn.LayerNorm(cz) - self.evoformers = torch.nn.ModuleList( - [Evoformer(s, cm, cz) for _ in range(evo_num)]) + self.evoformers = torch.nn.ModuleList([ + Evoformer(s, cm, cz, use_chunk=use_chunk) for _ in range(evo_num) + ]) def forward(self, msa, pair): msa = self.msa_norm(msa) @@ -611,58 +610,85 @@ def forward(self, msa, pair): return loss -def test(): +def test_inference(): + evo_num = 48 + + # T1044: 2048 -> 2180 + bs, s, r, cm, cz = 1, 128, 2048, 256, 128 + + dtype = torch.float32 + + model = AlphaFold2(s, cm, cz, evo_num, use_chunk=True).to(dtype) + model.eval() + + model = cube.SemanticModel(model, + input_shapes=([bs, s, r, cm], [bs, r, r, cz])) + + dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], + [bs, r, r, cz]), + dtypes=(dtype, dtype), + batch_dims=(0, 0)) + + # @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) + @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) + def train_iter(model, dataloader): + msa_repr, pair_repr = next(dataloader) + loss = model(msa_repr, pair_repr) + return loss + + model = model.get_gen_module() + + warm_up = 2 + iter_num = 4 + CudaTimer(enable=False).warmup() + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + for i in range(iter_num): + if i >= warm_up: + CudaTimer(enable=True).start('e2e') + train_iter(model, dataloader) + if i >= warm_up: + CudaTimer().stop('e2e') + if i > 0 and (i + 1) % 20 == 0: + print_each_rank(f'iter [{i + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num - warm_up, field_name='e2e'))) + CudaTimer().print_all(times=iter_num - warm_up) + print_each_rank('memory consumption: {} MB'.format( + int(torch.cuda.max_memory_allocated() / 1024 / 1024))) + + +def test_train(): evo_num = 48 # Training evo_num = 48 # initial training: evoformer - bs, s, r, cm, cz = 1, 128, 256, 256, 128 + # bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning: evoformer # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning: evoformer - # bs, s, r, cm, cz = 1, 512, 384, 256, 128 + bs, s, r, cm, cz = 1, 512, 384, 256, 128 # Extra sequence evo_num = 4 # initial training # bs, s, r, cm, cz = 1, 1024, 256, 256, 128 - # first fine-tuning - # bs, s, r, cm, cz = 1, 1024, 512, 256, 128 # second fine-tuning # bs, s, r, cm, cz = 1, 1024, 384, 256, 128 # bs, s, r, cm, cz = 1, 5120, 384, 256, 128 - # Inference - # T1044: 2048 -> 2180 - # bs, s, r, cm, cz = 1, 128, 2048, 256, 128 - dtype = torch.float16 model = AlphaFold2(s, cm, cz, evo_num).to(dtype) - # msa_repr, pair_repr = torch.randn(bs, s, r, cm), torch.randn(bs, r, r, cz) - # x = model(msa_repr, pair_repr) - # return - - model = cube.SemanticModel( - model, - input_shapes=( - [bs, s, r, cm], - [bs, r, r, cz], - ), - ) - - dataloader = cube.runtime.syndata.SynDataLoader(shapes=( - [bs, s, r, cm], - [bs, r, r, cz], - ), - dtypes=( - dtype, - dtype, - ), - batch_dims=( - 0, - 0, - )) + model = cube.SemanticModel(model, + input_shapes=([bs, s, r, cm], [bs, r, r, cz])) + + dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], + [bs, r, r, cz]), + dtypes=(dtype, dtype), + batch_dims=(0, 0)) # @cube.compile(model, dataloader, PAS=spmd.PASData) # @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) @@ -671,6 +697,7 @@ def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) loss.backward() + return loss model = model.get_gen_module() @@ -700,4 +727,4 @@ def train_iter(model, dataloader): int(torch.cuda.max_memory_allocated() / 1024 / 1024))) -test() +test_inference() diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 9a836192..921374d7 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -69,7 +69,7 @@ def PASSingle(graph: IRGraph, resource): rhs = indices[2 * i + 1] # deepmind's default recompute strategy - graph.recompute(fnodes[lhs + 1:rhs]) + # graph.recompute(fnodes[lhs + 1:rhs]) # another strategy # sub_indices = [] @@ -138,7 +138,7 @@ def PASDAP(graph: IRGraph, resource): if isinstance(fnodes[j], IRGraphAnchor): sub_indices.append(j) sub_indices.append(rhs) - graph.recompute(fnodes[lhs:rhs]) + # graph.recompute(fnodes[lhs:rhs]) for j in range(len(sub_indices) - 1): sub_l, sub_r = sub_indices[j], sub_indices[j + 1] names = [] @@ -146,6 +146,7 @@ def PASDAP(graph: IRGraph, resource): names.append(fnodes[k].name) names = set(names) nodes = fnodes[sub_l + 1:sub_r] + # DO NOT USE THIS # graph.recompute(nodes) if 'MSARowAttentionWithPairBias' in names: From 5666b596cd8a7c1ad85cc5a26eb77a124ce31d82 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 14 Oct 2022 13:21:26 +0000 Subject: [PATCH 1071/1892] not runnable, save work --- examples/alphafold2/alphafold2.py | 77 ++++++++++++++++++++++++------- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 0cd753d0..859312ae 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -196,6 +196,33 @@ def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, head, c, scale, chunk_size).permute(0, 2, 1, 3) +@cube.graph.parser.register('N S R M, M M, M E, M E, M M, M M -> N S R M', + name='MSAColGlobalAttention') +def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, + v_proj: torch.Tensor, gate_proj: torch.Tensor, out_proj: torch.Tensor, + head: int, + c: int, scale: float): + msa_repr = msa_repr.transpose(-2, -3) + + q = torch.sum(msa_repr, dim=-2) + q = torch.matmul(q, q_proj) * scale + q = q.view(q.shape[:-1] + (head, -1)) + + k, v = torch.matmul(msa_repr, k_proj), torch.matmul(msa_repr, v_proj) + + a = torch.matmul(q, k.transpose(-1, -2)) + a = torch.nn.functional.softmax(a, dim=-1) + o = torch.matmul(a, v) + + g = torch.nn.functional.sigmoid(torch.matmul(msa_repr, gate_proj)) + g = g.view(g.shape[:-1] + (head, -1)) + + o = o.unsqueeze(-3) * g + o = o.reshape(o.shape[:-2] + (-1,)) + + return torch.matmul(o, out_proj).transpose(-2, -3) + + """ [bs, s, r, cm] -> [bs, s, r, cm] """ @@ -395,6 +422,7 @@ def __init__(self, cm: int, cz: int, use_chunk=False, + is_extra=False, c=32, msa_head=8, pair_head=4, @@ -407,6 +435,8 @@ def __init__(self, self.c_tri_mult, self.ff_mult = c_tri_mult, ff_mult self.scale = 1.0 / math.sqrt(c) + self.is_extra = is_extra + if use_chunk: self.msa_row_chunk, self.msa_col_chunk = 4, 256 self.opm_chunk, self.tans_chunk, self.tane_chunk = 4, 4, 4 @@ -425,10 +455,14 @@ def __init__(self, # MSA column-wise gated self-attention self.col_norm = torch.nn.LayerNorm(cm) - self.col_gate_proj = torch.nn.Parameter(torch.randn(cm, msa_head * c)) + self.col_gate_proj = torch.nn.Parameter(torch.randn(cm, cm)) + # TODO: fix me + self.col_q_proj = torch.nn.Parameter(torch.randn(cm, cm)) + self.col_k_proj = torch.nn.Parameter(torch.randn(cm, 8)) + self.col_v_proj = torch.nn.Parameter(torch.randn(cm, 8)) self.col_qkv_proj = torch.nn.Parameter( torch.randn(cm, 3 * msa_head * c)) - self.col_out_proj = torch.nn.Parameter(torch.randn(msa_head * c, cm)) + self.col_out_proj = torch.nn.Parameter(torch.randn(cm, cm)) # MSA transition self.msa_transition_norm = torch.nn.LayerNorm(cm) @@ -513,10 +547,16 @@ def forward(self, msa_repr, pair_repr): self.msa_head, self.c, self.scale, self.msa_row_chunk) cube.runtime.function.anchor('MSACol') - msa_repr = msa_repr + MSAColAttention( - self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - self.col_out_proj, self.msa_head, self.c, self.scale, - self.msa_col_chunk) + if self.is_extra: + msa_repr = msa_repr + MSAColGlobalAttention( + self.col_norm(msa_repr), self.col_q_proj, self.col_k_proj, self.col_v_proj, self.col_gate_proj, + self.col_out_proj, self.msa_head, self.c, self.scale + ) + else: + msa_repr = msa_repr + MSAColAttention( + self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, + self.col_out_proj, self.msa_head, self.c, self.scale, + self.msa_col_chunk) cube.runtime.function.anchor('MSATrans') msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), @@ -586,14 +626,15 @@ def __init__(self, cm: int, cz: int, evo_num: int, - use_chunk=False): + use_chunk=False, + is_extra=False): super().__init__() self.evo_num = evo_num # add norm to work with PyTorch's recompute mechanism self.msa_norm = torch.nn.LayerNorm(cm) self.pair_norm = torch.nn.LayerNorm(cz) self.evoformers = torch.nn.ModuleList([ - Evoformer(s, cm, cz, use_chunk=use_chunk) for _ in range(evo_num) + Evoformer(s, cm, cz, use_chunk=use_chunk, is_extra=is_extra) for _ in range(evo_num) ]) def forward(self, msa, pair): @@ -661,7 +702,7 @@ def train_iter(model, dataloader): def test_train(): - evo_num = 48 + evo_num = 1 # Training evo_num = 48 # initial training: evoformer @@ -669,18 +710,22 @@ def test_train(): # first fine-tuning: evoformer # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning: evoformer - bs, s, r, cm, cz = 1, 512, 384, 256, 128 + # bs, s, r, cm, cz = 1, 512, 384, 256, 128 # Extra sequence evo_num = 4 # initial training - # bs, s, r, cm, cz = 1, 1024, 256, 256, 128 + bs, s, r, cm, cz = 1, 1024, 256, 64, 128 # second fine-tuning - # bs, s, r, cm, cz = 1, 1024, 384, 256, 128 - # bs, s, r, cm, cz = 1, 5120, 384, 256, 128 + # bs, s, r, cm, cz = 1, 1024, 384, 64, 128 + # bs, s, r, cm, cz = 1, 5120, 384, 64, 128 + + dtype = torch.float32 - dtype = torch.float16 + model = AlphaFold2(s, cm, cz, evo_num, is_extra=True).to(dtype) - model = AlphaFold2(s, cm, cz, evo_num).to(dtype) + msa, pair = torch.randn(bs, s, r, cm), torch.randn(bs, r, r, cz) + loss = model(msa, pair) + return loss model = cube.SemanticModel(model, input_shapes=([bs, s, r, cm], [bs, r, r, cz])) @@ -727,4 +772,4 @@ def train_iter(model, dataloader): int(torch.cuda.max_memory_allocated() / 1024 / 1024))) -test_inference() +test_train() From 855b47da0e076d88eb7a8e4a8f6150e383c9c544 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 14 Oct 2022 14:59:41 +0000 Subject: [PATCH 1072/1892] fix bug, partial runnable for 1 evoformer --- examples/alphafold2/alphafold2.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 859312ae..a44f813e 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -101,6 +101,8 @@ def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float, chunk_size: int): bs, s, r, cm = x.size() + assert cm % head == 0 + c = cm // head if chunk_size == -1: gate = torch.sigmoid(torch.matmul(x, gate_proj)) @@ -214,7 +216,7 @@ def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: a = torch.nn.functional.softmax(a, dim=-1) o = torch.matmul(a, v) - g = torch.nn.functional.sigmoid(torch.matmul(msa_repr, gate_proj)) + g = torch.sigmoid(torch.matmul(msa_repr, gate_proj)) g = g.view(g.shape[:-1] + (head, -1)) o = o.unsqueeze(-3) * g @@ -447,10 +449,10 @@ def __init__(self, # MSA row-wise gated self-attention with pair bias self.row_norm_m = torch.nn.LayerNorm(cm) self.row_norm_z = torch.nn.LayerNorm(cz) - self.row_gate_proj = torch.nn.Parameter(torch.randn(cm, msa_head * c)) + self.row_gate_proj = torch.nn.Parameter(torch.randn(cm, cm)) self.row_qkv_proj = torch.nn.Parameter( - torch.randn(cm, 3 * msa_head * c)) - self.row_out_proj = torch.nn.Parameter(torch.randn(msa_head * c, cm)) + torch.randn(cm, 3 * cm)) + self.row_out_proj = torch.nn.Parameter(torch.randn(cm, cm)) self.row_bias_proj = torch.nn.Parameter(torch.randn(cz, msa_head)) # MSA column-wise gated self-attention @@ -723,9 +725,9 @@ def test_train(): model = AlphaFold2(s, cm, cz, evo_num, is_extra=True).to(dtype) - msa, pair = torch.randn(bs, s, r, cm), torch.randn(bs, r, r, cz) - loss = model(msa, pair) - return loss + # msa, pair = torch.randn(bs, s, r, cm), torch.randn(bs, r, r, cz) + # loss = model(msa, pair) + # return loss model = cube.SemanticModel(model, input_shapes=([bs, s, r, cm], [bs, r, r, cz])) @@ -742,7 +744,6 @@ def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) loss.backward() - return loss model = model.get_gen_module() From dbca4bc9efadc62424c3bd3baf1af4ddef4a53b4 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 16 Oct 2022 11:10:59 +0000 Subject: [PATCH 1073/1892] refine code structure --- examples/alphafold2/alphafold2.py | 757 ++--------------------------- examples/alphafold2/model.py | 260 ++++++++++ examples/alphafold2/module.py | 422 ++++++++++++++++ examples/alphafold2/policy/spmd.py | 100 +++- 4 files changed, 824 insertions(+), 715 deletions(-) create mode 100644 examples/alphafold2/model.py create mode 100644 examples/alphafold2/module.py diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index a44f813e..060a607a 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -1,668 +1,29 @@ -from audioop import mul import torch -import torch.utils.checkpoint as ckpt import math import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from torch import nn +from examples.alphafold2.model import * import examples.alphafold2.policy.spmd as spmd cube.init() -@cube.graph.parser.register('TODO', name='calc_qkvg') -def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, - bs: int, s: int, r: int, head: int, c: int): - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) +def run(size_config, other_config, policy): + bs, s, r, cm, cz = size_config + dtype, evo_num, use_chunk, is_train, is_extra = other_config - gate = gate.reshape(bs, s, r, head, c).transpose(2, 3) - q = q.reshape(bs, s, r, head, c).transpose(2, 3) - k = k.reshape(bs, s, r, head, c).transpose(2, 3) - v = v.reshape(bs, s, r, head, c).transpose(2, 3) - return q, k, v, gate - - -""" -[bs, s, r, cm] -> [bs, s, r, cm] - -used as column-wise gated self-attention -""" - - -@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R M', - name='MSAAttention') -@torch.jit.ignore -def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, - c: int, scale: float, chunk_size: int): - bs, s, r, cm = x.size() - - if chunk_size == -1: - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - - gate = gate.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - q = q.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - k = k.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, - c).transpose(1, 2) - v = v.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - - sim = torch.bmm(q, k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - - attend = torch.bmm(sim, v) * gate - - out = attend.reshape(bs, s, head, r, - c).transpose(2, 3).reshape(bs, s, r, cm) - out = torch.matmul(out, out_proj) - else: - q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, c) - assert s % chunk_size == 0 - out_chunks = [] - - def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - gate: torch.Tensor, start: int): - cur_q = q[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_k = k[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c).transpose(1, 2) - cur_v = v[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - sim = torch.bmm(cur_q, cur_k) * 0.125 - sim = torch.nn.functional.softmax(sim, dim=-1) - attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, c).transpose( - 2, 3).reshape(bs, chunk_size, r, cm) - return attend - - for start in range(0, s, chunk_size): - attend = attention(q, k, v, gate, start) - out_chunks.append(attend) - - out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) - return out - - -@cube.graph.parser.register('N S R M, M E, M F, E M, N 1 8 R R -> N S R M', - name='MSAAttentionWithBias') -@torch.jit.ignore -def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias: torch.Tensor, head: int, c: int, scale: float, - chunk_size: int): - bs, s, r, cm = x.size() - assert cm % head == 0 - c = cm // head - - if chunk_size == -1: - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - - gate = gate.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - q = q.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - k = k.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, - c).transpose(1, 2) - v = v.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - - sim = torch.bmm(q, k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - - sim = sim.reshape(bs, s, head, r, r) + bias - sim = sim.reshape(bs * s * head, r, r) - - attend = torch.bmm(sim, v) * gate - - out = attend.reshape(bs, s, head, r, - c).transpose(2, 3).reshape(bs, s, r, cm) - out = torch.matmul(out, out_proj) - else: - q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, c) - - assert s % chunk_size == 0 - out_chunks = [] - - def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - gate: torch.Tensor, bias: torch.Tensor, start: int): - cur_q = q[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_k = k[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c).transpose(1, 2) - cur_v = v[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - - sim = torch.bmm(cur_q, cur_k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - sim = sim.reshape(bs, chunk_size, head, r, r) + bias - sim = sim.reshape(bs * chunk_size * head, r, r) - - attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, c).transpose( - 2, 3).reshape(bs, chunk_size, r, cm) - return attend - - for start in range(0, s, chunk_size): - attend = attention_bias(q, k, v, gate, bias, start) - - out_chunks.append(attend) - out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) - return out - - -""" -([bs, s, r, cm], [bs, r, r, cz]) -> [bs, s, r, cm] -""" - - -# note: code not reused constrained by cube's interface -@cube.graph.parser.register('N S R M, N R R Z, M E, M F, E M, Z H -> N S R M', - name='MSARowAttentionWithPairBias') -def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, - pair_repr: torch.Tensor, - gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias_proj: torch.Tensor, head: int, c: int, - scale: float, chunk_size: int): - bs, s, r, cm = msa_repr.size() - - bias = torch.matmul(pair_repr, - bias_proj).permute(0, 3, 1, - 2).reshape(bs, 1, head, r, r) - - return MSAAttentionWithBias(msa_repr, gate_proj, qkv_proj, out_proj, bias, - head, c, scale, chunk_size) - - -@cube.graph.parser.register('N S R M, M E, M F, E M -> N S R M', - name='MSAColAttention') -def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, - c: int, scale: float, chunk_size: int): - return MSAAttention(msa_repr.permute(0, 2, 1, - 3), gate_proj, qkv_proj, out_proj, - head, c, scale, chunk_size).permute(0, 2, 1, 3) - - -@cube.graph.parser.register('N S R M, M M, M E, M E, M M, M M -> N S R M', - name='MSAColGlobalAttention') -def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, - v_proj: torch.Tensor, gate_proj: torch.Tensor, out_proj: torch.Tensor, - head: int, - c: int, scale: float): - msa_repr = msa_repr.transpose(-2, -3) - - q = torch.sum(msa_repr, dim=-2) - q = torch.matmul(q, q_proj) * scale - q = q.view(q.shape[:-1] + (head, -1)) - - k, v = torch.matmul(msa_repr, k_proj), torch.matmul(msa_repr, v_proj) - - a = torch.matmul(q, k.transpose(-1, -2)) - a = torch.nn.functional.softmax(a, dim=-1) - o = torch.matmul(a, v) - - g = torch.sigmoid(torch.matmul(msa_repr, gate_proj)) - g = g.view(g.shape[:-1] + (head, -1)) - - o = o.unsqueeze(-3) * g - o = o.reshape(o.shape[:-2] + (-1,)) - - return torch.matmul(o, out_proj).transpose(-2, -3) - - -""" -[bs, s, r, cm] -> [bs, s, r, cm] -""" - - -@cube.graph.parser.register('N S R M, M E, E M -> N S R M', - name='MSATransition') -def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - return torch.matmul( - torch.nn.functional.relu(torch.matmul(msa_repr, proj1)), proj2) - - -@cube.graph.parser.register('N S R M, M C -> N S R C', name='OPMLeftProj') -def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): - return torch.matmul(msa_repr, proj) - - -@cube.graph.parser.register('N S R M, M C -> N S R C', name='OPMRightProj') -def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): - return torch.matmul(msa_repr, proj) - - -""" -[bs, s, r, cm] -> [bs, r, r, cz] -""" - - -@cube.graph.parser.register('N S R M, N S T M, F Z -> N R T Z', - name='OuterProductMean') -@torch.jit.ignore -def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, - out_proj: torch.Tensor, chunk_size: int): - bs, s, r, c = left_act.size() - t = right_act.size(2) - - a = left_act.transpose(-2, -3) - b = right_act.transpose(-2, -3) - - if chunk_size == -1: - outer = torch.einsum('...bac,...dae->...bdce', a, - b).reshape(bs, r, t, c * c) - outer = torch.matmul(outer, out_proj) - else: - out_chunks = [] - - def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): - lhs_slice = lhs[:, start:start + chunk_size, :, :] - out = torch.einsum('...bac,...dae->...bdce', lhs_slice, - rhs).reshape(bs, chunk_size, t, c * c) - out = torch.matmul(out, out_proj) - return out - - for start in range(0, r, chunk_size): - out_chunks.append(opm(a, b, start)) - outer = torch.cat(out_chunks, dim=1) - return outer - - -@cube.graph.parser.register('N S R Z, Z E, Z E -> N S R E', name='TMOLeftProj') -def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - a = a * torch.matmul(pair_repr, proj2) - return a - - -@cube.graph.parser.register('N S R Z, Z E, Z E -> N S R E', - name='TMORightProj') -def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - a = a * torch.matmul(pair_repr, proj2) - return a - - -@cube.graph.parser.register('N S T Z, Z Z -> N S T Z', name='TMOGate') -def TMOGate(pair_repr: torch.Tensor, proj: torch.Tensor): - return torch.sigmoid(torch.matmul(pair_repr, proj)) - - -@cube.graph.parser.register('N S R E, N T R E, N S T Z, E, E, E Z -> N S T Z', - name='TriangleMultiplicationOut') -def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, - g: torch.Tensor, - tri_mul_norm2_weight: torch.Tensor, - tri_mul_norm2_bias: torch.Tensor, - tri_mul_proj5: torch.Tensor, cz: int): - a = a.permute(0, 3, 1, 2) - b = b.permute(0, 3, 2, 1) - - p = torch.matmul(a, b).permute(0, 2, 3, 1) - p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, - tri_mul_norm2_bias) - p = torch.matmul(p, tri_mul_proj5) - return p * g - - -@cube.graph.parser.register('N R S Z, Z E, Z E -> N R S E', name='TMILeftProj') -def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - a = a * torch.matmul(pair_repr, proj2) - return a - - -@cube.graph.parser.register('N R T Z, Z E, Z E -> N R T E', - name='TMIRightProj') -def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - a = a * torch.matmul(pair_repr, proj2) - return a - - -@cube.graph.parser.register('N S T Z, Z Z -> N S T Z', name='TMIGate') -def TMIGate(pair_repr: torch.Tensor, proj: torch.Tensor): - return torch.sigmoid(torch.matmul(pair_repr, proj)) - - -@cube.graph.parser.register('N R S E, N R T E, N T S Z, E, E, E Z -> N T S Z', - name='TriangleMultiplicationIn') -def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, - tri_mul_norm2_weight: torch.Tensor, - tri_mul_norm2_bias: torch.Tensor, - tri_mul_proj5: torch.Tensor, cz: int): - a = a.permute(0, 3, 2, 1) - b = b.permute(0, 3, 1, 2) - - p = torch.matmul(a, b).permute(0, 2, 3, 1) - p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, - tri_mul_norm2_bias) - p = torch.matmul(p, tri_mul_proj5) - return p.permute(0, 2, 1, 3) * g - - -@cube.graph.parser.register('N S R C, C D -> N S R D', name='TANSBias') -def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): - return torch.matmul(pair_repr, bias_proj) - - -@cube.graph.parser.register('N S R Z, Z E, Z F, E Z, N T R G -> N S R Z', - name='TriangleAttentionNodeStart') -def TriangleAttentionNodeStart(pair_repr: torch.Tensor, - gate_proj: torch.Tensor, qkv_proj: torch.Tensor, - out_proj: torch.Tensor, bias: torch.Tensor, - head: int, c: int, scale: float, - chunk_size: int): - bias = bias.permute(0, 3, 1, 2).unsqueeze(1) - - return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, - head, c, scale, chunk_size) - - -@cube.graph.parser.register('N S R C, C D -> N S R D', name='TANEBias') -def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): - return torch.matmul(pair_repr, bias_proj) - - -@cube.graph.parser.register('N R S Z, Z E, Z F, E Z, N R T G -> N R S Z', - name='TriangleAttentionNodeEnd') -def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias: torch.Tensor, head: int, c: int, - scale: float, chunk_size: int): - pair_repr = pair_repr.permute(0, 2, 1, 3) - bias = bias.permute(0, 2, 1, 3) - out = TriangleAttentionNodeStart(pair_repr, gate_proj, qkv_proj, out_proj, - bias, head, c, scale, chunk_size) - return out.permute(0, 2, 1, 3) - - -@cube.graph.parser.register('N R T Z, Z E, E Z -> N R T Z', - name='PairTransition') -def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - return torch.matmul( - torch.nn.functional.relu(torch.matmul(pair_repr, proj1)), proj2) - - -@cube.graph.parser.register('* -> *, *', name='multi2ref') -def multi2ref(x: torch.Tensor): - return (x, x) - - -""" -a simplified version for evoformer in alphafold2 - - dropout layers are omitted - - masks are omitted -""" - - -class Evoformer(torch.nn.Module): - - def __init__(self, - s: int, - cm: int, - cz: int, - use_chunk=False, - is_extra=False, - c=32, - msa_head=8, - pair_head=4, - c_tri_mult=128, - ff_mult=4): - super().__init__() - - self.s, self.cm, self.cz, self.c = s, cm, cz, c - self.msa_head, self.pair_head = msa_head, pair_head - self.c_tri_mult, self.ff_mult = c_tri_mult, ff_mult - self.scale = 1.0 / math.sqrt(c) - - self.is_extra = is_extra - - if use_chunk: - self.msa_row_chunk, self.msa_col_chunk = 4, 256 - self.opm_chunk, self.tans_chunk, self.tane_chunk = 4, 4, 4 - else: - self.msa_row_chunk, self.msa_col_chunk = -1, -1 - self.opm_chunk, self.tans_chunk, self.tane_chunk = -1, -1, -1 - - # MSA row-wise gated self-attention with pair bias - self.row_norm_m = torch.nn.LayerNorm(cm) - self.row_norm_z = torch.nn.LayerNorm(cz) - self.row_gate_proj = torch.nn.Parameter(torch.randn(cm, cm)) - self.row_qkv_proj = torch.nn.Parameter( - torch.randn(cm, 3 * cm)) - self.row_out_proj = torch.nn.Parameter(torch.randn(cm, cm)) - self.row_bias_proj = torch.nn.Parameter(torch.randn(cz, msa_head)) - - # MSA column-wise gated self-attention - self.col_norm = torch.nn.LayerNorm(cm) - self.col_gate_proj = torch.nn.Parameter(torch.randn(cm, cm)) - # TODO: fix me - self.col_q_proj = torch.nn.Parameter(torch.randn(cm, cm)) - self.col_k_proj = torch.nn.Parameter(torch.randn(cm, 8)) - self.col_v_proj = torch.nn.Parameter(torch.randn(cm, 8)) - self.col_qkv_proj = torch.nn.Parameter( - torch.randn(cm, 3 * msa_head * c)) - self.col_out_proj = torch.nn.Parameter(torch.randn(cm, cm)) - - # MSA transition - self.msa_transition_norm = torch.nn.LayerNorm(cm) - self.msa_transition_proj1 = torch.nn.Parameter( - torch.randn(cm, ff_mult * cm)) - self.msa_transition_proj2 = torch.nn.Parameter( - torch.randn(ff_mult * cm, cm)) - - # Outer product mean - self.outer_norm = torch.nn.LayerNorm(cm) - self.outer_proj1 = torch.nn.Parameter(torch.randn(cm, c)) - self.outer_proj2 = torch.nn.Parameter(torch.randn(cm, c)) - self.outer_out_proj = torch.nn.Parameter(torch.randn(c * c, cz)) - - # Triangular multiplicative update using outgoing edges - self.tri_mul_out_norm1 = torch.nn.LayerNorm(cz) - self.tri_mul_out_norm2_weight = torch.nn.Parameter( - torch.empty(c_tri_mult)) - self.tri_mul_out_norm2_bias = torch.nn.Parameter( - torch.empty(c_tri_mult)) - self.tri_mul_out_proj1 = torch.nn.Parameter(torch.randn( - cz, c_tri_mult)) - self.tri_mul_out_proj2 = torch.nn.Parameter(torch.randn( - cz, c_tri_mult)) - self.tri_mul_out_proj3 = torch.nn.Parameter(torch.randn( - cz, c_tri_mult)) - self.tri_mul_out_proj4 = torch.nn.Parameter(torch.randn( - cz, c_tri_mult)) - self.tri_mul_out_proj5 = torch.nn.Parameter(torch.randn( - c_tri_mult, cz)) - self.tri_mul_out_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) - - # Triangular multiplicative update using incoming edges - self.tri_mul_in_norm1 = torch.nn.LayerNorm(cz) - self.tri_mul_in_norm2_weight = torch.nn.Parameter( - torch.empty(c_tri_mult)) - self.tri_mul_in_norm2_bias = torch.nn.Parameter( - torch.empty(c_tri_mult)) - self.tri_mul_in_proj1 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) - self.tri_mul_in_proj2 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) - self.tri_mul_in_proj3 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) - self.tri_mul_in_proj4 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) - self.tri_mul_in_proj5 = torch.nn.Parameter(torch.randn(c_tri_mult, cz)) - self.tri_mul_in_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) - - # Triangular gated self-attention around starting node - self.tri_att_start_norm = torch.nn.LayerNorm(cz) - self.tri_att_start_gate_proj = torch.nn.Parameter( - torch.randn(cz, pair_head * c)) - self.tri_att_start_qkv_proj = torch.nn.Parameter( - torch.randn(cz, 3 * pair_head * c)) - self.tri_att_start_out_proj = torch.nn.Parameter( - torch.randn(pair_head * c, cz)) - self.tri_att_start_bias_proj = torch.nn.Parameter( - torch.randn(cz, pair_head)) - - # Triangular gated self-attention around ending node - self.tri_att_end_norm = torch.nn.LayerNorm(cz) - self.tri_att_end_gate_proj = torch.nn.Parameter( - torch.randn(cz, pair_head * c)) - self.tri_att_end_qkv_proj = torch.nn.Parameter( - torch.randn(cz, 3 * pair_head * c)) - self.tri_att_end_out_proj = torch.nn.Parameter( - torch.randn(pair_head * c, cz)) - self.tri_att_end_bias_proj = torch.nn.Parameter( - torch.randn(cz, pair_head)) - - # Transition in the pair stack - self.pair_transition_norm = torch.nn.LayerNorm(cz) - self.pair_transition_proj1 = torch.nn.Parameter( - torch.randn(cz, ff_mult * cz)) - self.pair_transition_proj2 = torch.nn.Parameter( - torch.randn(ff_mult * cz, cz)) - - def forward(self, msa_repr, pair_repr): - cube.runtime.function.anchor('MSARow') - - pair_repr, dummy_pair_repr = multi2ref(pair_repr) - msa_repr = msa_repr + MSARowAttentionWithPairBias( - self.row_norm_m(msa_repr), dummy_pair_repr, self.row_gate_proj, - self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, - self.msa_head, self.c, self.scale, self.msa_row_chunk) - - cube.runtime.function.anchor('MSACol') - if self.is_extra: - msa_repr = msa_repr + MSAColGlobalAttention( - self.col_norm(msa_repr), self.col_q_proj, self.col_k_proj, self.col_v_proj, self.col_gate_proj, - self.col_out_proj, self.msa_head, self.c, self.scale - ) - else: - msa_repr = msa_repr + MSAColAttention( - self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, - self.col_out_proj, self.msa_head, self.c, self.scale, - self.msa_col_chunk) - - cube.runtime.function.anchor('MSATrans') - msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), - self.msa_transition_proj1, - self.msa_transition_proj2) - succ_msa_repr, msa_repr = multi2ref(msa_repr) - - cube.runtime.function.anchor('OPM') - msa_repr = self.outer_norm(msa_repr) - opm_left, opm_right = OPMLeftProj(msa_repr, - self.outer_proj1), OPMRightProj( - msa_repr, self.outer_proj2) - pair_repr = pair_repr + OuterProductMean( - opm_left, opm_right, self.outer_out_proj, self.opm_chunk) - - cube.runtime.function.anchor('TMO') - pair_repr = self.tri_mul_out_norm1(pair_repr) - tmo_left, tmo_right = TMOLeftProj( - pair_repr, self.tri_mul_out_proj1, - self.tri_mul_out_proj2), TMORightProj(pair_repr, - self.tri_mul_out_proj3, - self.tri_mul_out_proj4) - tmo_g = TMOGate(pair_repr, self.tri_mul_out_proj6) - pair_repr = pair_repr + TriangleMultiplicationOut( - tmo_left, tmo_right, tmo_g, self.tri_mul_out_norm2_weight, - self.tri_mul_out_norm2_bias, self.tri_mul_out_proj5, self.cz) - - cube.runtime.function.anchor('TMI') - pair_repr = self.tri_mul_in_norm1(pair_repr) - tmi_left = TMILeftProj(pair_repr, self.tri_mul_in_proj1, - self.tri_mul_in_proj2) - tmi_right = TMIRightProj(pair_repr, self.tri_mul_in_proj3, - self.tri_mul_in_proj4) - tmi_gate = TMIGate(pair_repr, self.tri_mul_in_proj6) - pair_repr = pair_repr + TriangleMultiplicationIn( - tmi_left, tmi_right, tmi_gate, self.tri_mul_in_norm2_weight, - self.tri_mul_in_norm2_bias, self.tri_mul_in_proj5, self.cz) - - cube.runtime.function.anchor('TANS') - pair_repr = self.tri_att_start_norm(pair_repr) - bias = TANSBias(pair_repr, self.tri_att_start_bias_proj) - pair_repr = pair_repr + TriangleAttentionNodeStart( - pair_repr, self.tri_att_start_gate_proj, - self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, bias, - self.pair_head, self.c, self.scale, self.tans_chunk) - - cube.runtime.function.anchor('TANE') - pair_repr = self.tri_att_end_norm(pair_repr) - bias = TANEBias(pair_repr, self.tri_att_end_bias_proj) - pair_repr = pair_repr + TriangleAttentionNodeEnd( - pair_repr, self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, - self.tri_att_end_out_proj, bias, self.pair_head, self.c, - self.scale, self.tane_chunk) - - cube.runtime.function.anchor('PairTrans') - pair_repr = pair_repr + PairTransition( - self.pair_transition_norm(pair_repr), self.pair_transition_proj1, - self.pair_transition_proj2) - - return succ_msa_repr, pair_repr - - -class AlphaFold2(nn.Module): - - def __init__(self, - s: int, - cm: int, - cz: int, - evo_num: int, - use_chunk=False, - is_extra=False): - super().__init__() - self.evo_num = evo_num - # add norm to work with PyTorch's recompute mechanism - self.msa_norm = torch.nn.LayerNorm(cm) - self.pair_norm = torch.nn.LayerNorm(cz) - self.evoformers = torch.nn.ModuleList([ - Evoformer(s, cm, cz, use_chunk=use_chunk, is_extra=is_extra) for _ in range(evo_num) - ]) - - def forward(self, msa, pair): - msa = self.msa_norm(msa) - pair = self.pair_norm(pair) - - cube.runtime.function.anchor('Evoformer Stack Start') - for evoformer in self.evoformers: - cube.runtime.function.anchor('One Layer Evoformer Start') - msa, pair = evoformer(msa, pair) - cube.runtime.function.anchor('One Layer Evoformer End') - cube.runtime.function.anchor('Evoformer Stack End') - loss = torch.sum(msa) * torch.sum(pair) - return loss - - -def test_inference(): - evo_num = 48 - - # T1044: 2048 -> 2180 - bs, s, r, cm, cz = 1, 128, 2048, 256, 128 - - dtype = torch.float32 - - model = AlphaFold2(s, cm, cz, evo_num, use_chunk=True).to(dtype) - model.eval() + model = AlphaFold2(s, + cm, + cz, + evo_num, + use_chunk=use_chunk, + is_extra=is_extra, + is_train=is_train).to(dtype) + if not is_train: + model.eval() model = cube.SemanticModel(model, input_shapes=([bs, s, r, cm], [bs, r, r, cz])) @@ -672,15 +33,20 @@ def test_inference(): dtypes=(dtype, dtype), batch_dims=(0, 0)) - # @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) - @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) + @cube.compile(model, dataloader, PAS=policy, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) - return loss + if is_train: + loss.backward() + else: + return loss model = model.get_gen_module() + if is_train: + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + warm_up = 2 iter_num = 4 CudaTimer(enable=False).warmup() @@ -691,6 +57,9 @@ def train_iter(model, dataloader): if i >= warm_up: CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) + if is_train: + optimizer.step() + optimizer.zero_grad() if i >= warm_up: CudaTimer().stop('e2e') if i > 0 and (i + 1) % 20 == 0: @@ -703,74 +72,36 @@ def train_iter(model, dataloader): int(torch.cuda.max_memory_allocated() / 1024 / 1024))) -def test_train(): - evo_num = 1 - - # Training evo_num = 48 - # initial training: evoformer +def test_main(): + # Training && Evoformer Stack + # initial training # bs, s, r, cm, cz = 1, 128, 256, 256, 128 - # first fine-tuning: evoformer + # first fine-tuning # bs, s, r, cm, cz = 1, 512, 256, 256, 128 - # second fine-tuning: evoformer + # second fine-tuning # bs, s, r, cm, cz = 1, 512, 384, 256, 128 - # Extra sequence evo_num = 4 + # dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False + # policy = spmd.PASDAP + + # Training && Extra Sequence # initial training - bs, s, r, cm, cz = 1, 1024, 256, 64, 128 + # bs, s, r, cm, cz = 1, 1024, 256, 64, 128 # second fine-tuning # bs, s, r, cm, cz = 1, 1024, 384, 64, 128 # bs, s, r, cm, cz = 1, 5120, 384, 64, 128 - dtype = torch.float32 - - model = AlphaFold2(s, cm, cz, evo_num, is_extra=True).to(dtype) - - # msa, pair = torch.randn(bs, s, r, cm), torch.randn(bs, r, r, cz) - # loss = model(msa, pair) - # return loss - - model = cube.SemanticModel(model, - input_shapes=([bs, s, r, cm], [bs, r, r, cz])) - - dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], - [bs, r, r, cz]), - dtypes=(dtype, dtype), - batch_dims=(0, 0)) - - # @cube.compile(model, dataloader, PAS=spmd.PASData) - # @cube.compile(model, dataloader, PAS=spmd.PASDAP, override=True) - @cube.compile(model, dataloader, PAS=spmd.PASSingle, override=True) - def train_iter(model, dataloader): - msa_repr, pair_repr = next(dataloader) - loss = model(msa_repr, pair_repr) - loss.backward() - - model = model.get_gen_module() + # dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 4, True, True, True + # policy = spmd.PASExtraSingle - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - warm_up = 2 - iter_num = 4 - CudaTimer(enable=False).warmup() - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - for i in range(iter_num): - if i >= warm_up: - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if i >= warm_up: - CudaTimer().stop('e2e') - if i > 0 and (i + 1) % 20 == 0: - print_each_rank(f'iter [{i + 1}/{iter_num}]', rank_only=0) + # Inference + bs, s, r, cm, cz = 1, 128, 2048, 256, 128 + dtype, evo_num, use_chunk, is_train, is_extra = torch.float32, 48, True, False, False + policy = spmd.PASSingleInference + policy = spmd.PASDAPInference - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num - warm_up, field_name='e2e'))) - CudaTimer().print_all(times=iter_num - warm_up) - print_each_rank('memory consumption: {} MB'.format( - int(torch.cuda.max_memory_allocated() / 1024 / 1024))) + run((bs, s, r, cm, cz), (dtype, evo_num, use_chunk, is_train, is_extra), + policy) -test_train() +test_main() diff --git a/examples/alphafold2/model.py b/examples/alphafold2/model.py new file mode 100644 index 00000000..30d68981 --- /dev/null +++ b/examples/alphafold2/model.py @@ -0,0 +1,260 @@ +import cube +import torch +import math +from torch import nn + +from examples.alphafold2.module import * +""" +a simplified version for evoformer in alphafold2 + - dropout layers are omitted + - masks are omitted +""" + + +class Evoformer(torch.nn.Module): + + def __init__(self, + s: int, + cm: int, + cz: int, + use_chunk=False, + is_extra=False, + is_train=True, + c=32, + msa_head=8, + pair_head=4, + c_tri_mult=128, + ff_mult=4): + super().__init__() + + self.s, self.cm, self.cz, self.c = s, cm, cz, c + self.msa_head, self.pair_head = msa_head, pair_head + self.c_tri_mult, self.ff_mult = c_tri_mult, ff_mult + self.scale = 1.0 / math.sqrt(c) + + self.is_extra = is_extra + self.is_train = is_train + + if use_chunk: + if is_extra: + self.msa_row_chunk, self.msa_col_chunk = 4, -1 + self.opm_chunk, self.tans_chunk, self.tane_chunk = -1, -1, -1 + else: + self.msa_row_chunk, self.msa_col_chunk = 4, 256 + self.opm_chunk, self.tans_chunk, self.tane_chunk = 4, 4, 4 + else: + self.msa_row_chunk, self.msa_col_chunk = -1, -1 + self.opm_chunk, self.tans_chunk, self.tane_chunk = -1, -1, -1 + + # MSA row-wise gated self-attention with pair bias + self.row_norm_m = torch.nn.LayerNorm(cm) + self.row_norm_z = torch.nn.LayerNorm(cz) + self.row_gate_proj = torch.nn.Parameter(torch.randn(cm, cm)) + self.row_qkv_proj = torch.nn.Parameter(torch.randn(cm, 3 * cm)) + self.row_out_proj = torch.nn.Parameter(torch.randn(cm, cm)) + self.row_bias_proj = torch.nn.Parameter(torch.randn(cz, msa_head)) + + # MSA column-wise gated self-attention + self.col_norm = torch.nn.LayerNorm(cm) + self.col_gate_proj = torch.nn.Parameter(torch.randn(cm, cm)) + # TODO: fix me + self.col_q_proj = torch.nn.Parameter(torch.randn(cm, cm)) + self.col_k_proj = torch.nn.Parameter(torch.randn(cm, 8)) + self.col_v_proj = torch.nn.Parameter(torch.randn(cm, 8)) + self.col_qkv_proj = torch.nn.Parameter( + torch.randn(cm, 3 * msa_head * c)) + self.col_out_proj = torch.nn.Parameter(torch.randn(cm, cm)) + + # MSA transition + self.msa_transition_norm = torch.nn.LayerNorm(cm) + self.msa_transition_proj1 = torch.nn.Parameter( + torch.randn(cm, ff_mult * cm)) + self.msa_transition_proj2 = torch.nn.Parameter( + torch.randn(ff_mult * cm, cm)) + + # Outer product mean + self.outer_norm = torch.nn.LayerNorm(cm) + self.outer_proj1 = torch.nn.Parameter(torch.randn(cm, c)) + self.outer_proj2 = torch.nn.Parameter(torch.randn(cm, c)) + self.outer_out_proj = torch.nn.Parameter(torch.randn(c * c, cz)) + + # Triangular multiplicative update using outgoing edges + self.tri_mul_out_norm1 = torch.nn.LayerNorm(cz) + self.tri_mul_out_norm2_weight = torch.nn.Parameter( + torch.empty(c_tri_mult)) + self.tri_mul_out_norm2_bias = torch.nn.Parameter( + torch.empty(c_tri_mult)) + self.tri_mul_out_proj1 = torch.nn.Parameter(torch.randn( + cz, c_tri_mult)) + self.tri_mul_out_proj2 = torch.nn.Parameter(torch.randn( + cz, c_tri_mult)) + self.tri_mul_out_proj3 = torch.nn.Parameter(torch.randn( + cz, c_tri_mult)) + self.tri_mul_out_proj4 = torch.nn.Parameter(torch.randn( + cz, c_tri_mult)) + self.tri_mul_out_proj5 = torch.nn.Parameter(torch.randn( + c_tri_mult, cz)) + self.tri_mul_out_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) + + # Triangular multiplicative update using incoming edges + self.tri_mul_in_norm1 = torch.nn.LayerNorm(cz) + self.tri_mul_in_norm2_weight = torch.nn.Parameter( + torch.empty(c_tri_mult)) + self.tri_mul_in_norm2_bias = torch.nn.Parameter( + torch.empty(c_tri_mult)) + self.tri_mul_in_proj1 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) + self.tri_mul_in_proj2 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) + self.tri_mul_in_proj3 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) + self.tri_mul_in_proj4 = torch.nn.Parameter(torch.randn(cz, c_tri_mult)) + self.tri_mul_in_proj5 = torch.nn.Parameter(torch.randn(c_tri_mult, cz)) + self.tri_mul_in_proj6 = torch.nn.Parameter(torch.randn(cz, cz)) + + # Triangular gated self-attention around starting node + self.tri_att_start_norm = torch.nn.LayerNorm(cz) + self.tri_att_start_gate_proj = torch.nn.Parameter( + torch.randn(cz, pair_head * c)) + self.tri_att_start_qkv_proj = torch.nn.Parameter( + torch.randn(cz, 3 * pair_head * c)) + self.tri_att_start_out_proj = torch.nn.Parameter( + torch.randn(pair_head * c, cz)) + self.tri_att_start_bias_proj = torch.nn.Parameter( + torch.randn(cz, pair_head)) + + # Triangular gated self-attention around ending node + self.tri_att_end_norm = torch.nn.LayerNorm(cz) + self.tri_att_end_gate_proj = torch.nn.Parameter( + torch.randn(cz, pair_head * c)) + self.tri_att_end_qkv_proj = torch.nn.Parameter( + torch.randn(cz, 3 * pair_head * c)) + self.tri_att_end_out_proj = torch.nn.Parameter( + torch.randn(pair_head * c, cz)) + self.tri_att_end_bias_proj = torch.nn.Parameter( + torch.randn(cz, pair_head)) + + # Transition in the pair stack + self.pair_transition_norm = torch.nn.LayerNorm(cz) + self.pair_transition_proj1 = torch.nn.Parameter( + torch.randn(cz, ff_mult * cz)) + self.pair_transition_proj2 = torch.nn.Parameter( + torch.randn(ff_mult * cz, cz)) + + def forward(self, msa_repr, pair_repr): + cube.runtime.function.anchor('MSARow') + + pair_repr, dummy_pair_repr = multi2ref(pair_repr) + msa_repr = msa_repr + MSARowAttentionWithPairBias( + self.row_norm_m(msa_repr), dummy_pair_repr, self.row_gate_proj, + self.row_qkv_proj, self.row_out_proj, self.row_bias_proj, + self.msa_head, self.c, self.scale, self.msa_row_chunk, + self.is_train) + + cube.runtime.function.anchor('MSACol') + if self.is_extra: + msa_repr = msa_repr + MSAColGlobalAttention( + self.col_norm(msa_repr), self.col_q_proj, self.col_k_proj, + self.col_v_proj, self.col_gate_proj, self.col_out_proj, + self.msa_head, self.c, self.scale) + else: + msa_repr = msa_repr + MSAColAttention( + self.col_norm(msa_repr), self.col_gate_proj, self.col_qkv_proj, + self.col_out_proj, self.msa_head, self.c, self.scale, + self.msa_col_chunk, self.is_train) + + cube.runtime.function.anchor('MSATrans') + msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), + self.msa_transition_proj1, + self.msa_transition_proj2) + succ_msa_repr, msa_repr = multi2ref(msa_repr) + + cube.runtime.function.anchor('OPM') + msa_repr = self.outer_norm(msa_repr) + opm_left, opm_right = OPMLeftProj(msa_repr, + self.outer_proj1), OPMRightProj( + msa_repr, self.outer_proj2) + pair_repr = pair_repr + OuterProductMean( + opm_left, opm_right, self.outer_out_proj, self.opm_chunk, + self.is_train) + + cube.runtime.function.anchor('TMO') + pair_repr = self.tri_mul_out_norm1(pair_repr) + tmo_left, tmo_right = TMOLeftProj( + pair_repr, self.tri_mul_out_proj1, + self.tri_mul_out_proj2), TMORightProj(pair_repr, + self.tri_mul_out_proj3, + self.tri_mul_out_proj4) + tmo_g = TMOGate(pair_repr, self.tri_mul_out_proj6) + pair_repr = pair_repr + TriangleMultiplicationOut( + tmo_left, tmo_right, tmo_g, self.tri_mul_out_norm2_weight, + self.tri_mul_out_norm2_bias, self.tri_mul_out_proj5, self.cz) + + cube.runtime.function.anchor('TMI') + pair_repr = self.tri_mul_in_norm1(pair_repr) + tmi_left = TMILeftProj(pair_repr, self.tri_mul_in_proj1, + self.tri_mul_in_proj2) + tmi_right = TMIRightProj(pair_repr, self.tri_mul_in_proj3, + self.tri_mul_in_proj4) + tmi_gate = TMIGate(pair_repr, self.tri_mul_in_proj6) + pair_repr = pair_repr + TriangleMultiplicationIn( + tmi_left, tmi_right, tmi_gate, self.tri_mul_in_norm2_weight, + self.tri_mul_in_norm2_bias, self.tri_mul_in_proj5, self.cz) + + cube.runtime.function.anchor('TANS') + pair_repr = self.tri_att_start_norm(pair_repr) + bias = TANSBias(pair_repr, self.tri_att_start_bias_proj) + pair_repr = pair_repr + TriangleAttentionNodeStart( + pair_repr, self.tri_att_start_gate_proj, + self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, bias, + self.pair_head, self.c, self.scale, self.tans_chunk, self.is_train) + + cube.runtime.function.anchor('TANE') + pair_repr = self.tri_att_end_norm(pair_repr) + bias = TANEBias(pair_repr, self.tri_att_end_bias_proj) + pair_repr = pair_repr + TriangleAttentionNodeEnd( + pair_repr, self.tri_att_end_gate_proj, self.tri_att_end_qkv_proj, + self.tri_att_end_out_proj, bias, self.pair_head, self.c, + self.scale, self.tane_chunk, self.is_train) + + cube.runtime.function.anchor('PairTrans') + pair_repr = pair_repr + PairTransition( + self.pair_transition_norm(pair_repr), self.pair_transition_proj1, + self.pair_transition_proj2) + + return succ_msa_repr, pair_repr + + +class AlphaFold2(nn.Module): + + def __init__(self, + s: int, + cm: int, + cz: int, + evo_num: int, + use_chunk=False, + is_extra=False, + is_train=True): + super().__init__() + self.evo_num = evo_num + # add norm to work with PyTorch's recompute mechanism + self.msa_norm = torch.nn.LayerNorm(cm) + self.pair_norm = torch.nn.LayerNorm(cz) + self.evoformers = torch.nn.ModuleList([ + Evoformer(s, + cm, + cz, + use_chunk=use_chunk, + is_extra=is_extra, + is_train=is_train) for _ in range(evo_num) + ]) + + def forward(self, msa, pair): + msa = self.msa_norm(msa) + pair = self.pair_norm(pair) + + cube.runtime.function.anchor('Evoformer Stack Start') + for evoformer in self.evoformers: + cube.runtime.function.anchor('One Layer Evoformer Start') + msa, pair = evoformer(msa, pair) + cube.runtime.function.anchor('One Layer Evoformer End') + cube.runtime.function.anchor('Evoformer Stack End') + loss = torch.sum(msa) * torch.sum(pair) + return loss diff --git a/examples/alphafold2/module.py b/examples/alphafold2/module.py new file mode 100644 index 00000000..79af32b6 --- /dev/null +++ b/examples/alphafold2/module.py @@ -0,0 +1,422 @@ +import cube +import torch +import torch.utils.checkpoint as ckpt + + +@cube.graph.parser.register('TODO', name='calc_qkvg') +def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, + bs: int, s: int, r: int, head: int, c: int): + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + + gate = gate.reshape(bs, s, r, head, c).transpose(2, 3) + q = q.reshape(bs, s, r, head, c).transpose(2, 3) + k = k.reshape(bs, s, r, head, c).transpose(2, 3) + v = v.reshape(bs, s, r, head, c).transpose(2, 3) + return q, k, v, gate + + +""" +[bs, s, r, cm] -> [bs, s, r, cm] + +used as column-wise gated self-attention +""" + + +@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R M', + name='MSAAttention') +@torch.jit.ignore +def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, + c: int, scale: float, chunk_size: int, is_train: bool): + bs, s, r, cm = x.size() + + if chunk_size == -1: + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + + gate = gate.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + q = q.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + k = k.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, + c).transpose(1, 2) + v = v.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + + sim = torch.bmm(q, k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + + attend = torch.bmm(sim, v) * gate + + out = attend.reshape(bs, s, head, r, + c).transpose(2, 3).reshape(bs, s, r, cm) + out = torch.matmul(out, out_proj) + else: + if is_train: + q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, + bs, s, r, head, c) + else: + q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, + c) + assert s % chunk_size == 0 + out_chunks = [] + + def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + gate: torch.Tensor, start: int): + cur_q = q[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_k = k[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c).transpose(1, 2) + cur_v = v[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + sim = torch.bmm(cur_q, cur_k) * 0.125 + sim = torch.nn.functional.softmax(sim, dim=-1) + attend = torch.bmm(sim, cur_v) * cur_gate + attend = attend.reshape(bs, chunk_size, head, r, c).transpose( + 2, 3).reshape(bs, chunk_size, r, cm) + return attend + + for start in range(0, s, chunk_size): + if is_train: + attend = ckpt.checkpoint(attention, q, k, v, gate, start) + else: + attend = attention(q, k, v, gate, start) + out_chunks.append(attend) + + out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) + return out + + +@cube.graph.parser.register('N S R M, M E, M F, E M, N 1 8 R R -> N S R M', + name='MSAAttentionWithBias') +@torch.jit.ignore +def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias: torch.Tensor, head: int, c: int, scale: float, + chunk_size: int, is_train: bool): + bs, s, r, cm = x.size() + assert cm % head == 0 + c = cm // head + + if chunk_size == -1: + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + + gate = gate.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + q = q.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + k = k.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, + c).transpose(1, 2) + v = v.reshape(bs, s, r, head, + c).transpose(2, 3).reshape(bs * s * head, r, c) + + sim = torch.bmm(q, k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + + sim = sim.reshape(bs, s, head, r, r) + bias + sim = sim.reshape(bs * s * head, r, r) + + attend = torch.bmm(sim, v) * gate + + out = attend.reshape(bs, s, head, r, + c).transpose(2, 3).reshape(bs, s, r, cm) + out = torch.matmul(out, out_proj) + else: + if is_train: + q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, + bs, s, r, head, c) + else: + q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, + c) + + assert s % chunk_size == 0 + out_chunks = [] + + def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + gate: torch.Tensor, bias: torch.Tensor, start: int): + cur_q = q[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_k = k[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c).transpose(1, 2) + cur_v = v[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + + sim = torch.bmm(cur_q, cur_k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + sim = sim.reshape(bs, chunk_size, head, r, r) + bias + sim = sim.reshape(bs * chunk_size * head, r, r) + + attend = torch.bmm(sim, cur_v) * cur_gate + attend = attend.reshape(bs, chunk_size, head, r, c).transpose( + 2, 3).reshape(bs, chunk_size, r, cm) + return attend + + for start in range(0, s, chunk_size): + if is_train: + attend = ckpt.checkpoint(attention_bias, q, k, v, gate, bias, + start) + else: + attend = attention_bias(q, k, v, gate, bias, start) + + out_chunks.append(attend) + out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) + return out + + +""" +([bs, s, r, cm], [bs, r, r, cz]) -> [bs, s, r, cm] +""" + + +# note: code not reused constrained by cube's interface +@cube.graph.parser.register('N S R M, N R R Z, M E, M F, E M, Z H -> N S R M', + name='MSARowAttentionWithPairBias') +def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, + pair_repr: torch.Tensor, + gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias_proj: torch.Tensor, head: int, c: int, + scale: float, chunk_size: int, is_train: bool): + bs, s, r, cm = msa_repr.size() + + bias = torch.matmul(pair_repr, + bias_proj).permute(0, 3, 1, + 2).reshape(bs, 1, head, r, r) + + return MSAAttentionWithBias(msa_repr, gate_proj, qkv_proj, out_proj, bias, + head, c, scale, chunk_size, is_train) + + +@cube.graph.parser.register('N S R M, M E, M F, E M -> N S R M', + name='MSAColAttention') +def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, + c: int, scale: float, chunk_size: int, is_train: bool): + return MSAAttention(msa_repr.permute(0, 2, 1, 3), gate_proj, qkv_proj, + out_proj, head, c, scale, chunk_size, + is_train).permute(0, 2, 1, 3) + + +@cube.graph.parser.register('N S R M, M M, M E, M E, M M, M M -> N S R M', + name='MSAColGlobalAttention') +def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, + k_proj: torch.Tensor, v_proj: torch.Tensor, + gate_proj: torch.Tensor, out_proj: torch.Tensor, + head: int, c: int, scale: float): + msa_repr = msa_repr.transpose(-2, -3) + + q = torch.sum(msa_repr, dim=-2) + q = torch.matmul(q, q_proj) * scale + q = q.view(q.shape[:-1] + (head, -1)) + + k, v = torch.matmul(msa_repr, k_proj), torch.matmul(msa_repr, v_proj) + + a = torch.matmul(q, k.transpose(-1, -2)) + a = torch.nn.functional.softmax(a, dim=-1) + o = torch.matmul(a, v) + + g = torch.sigmoid(torch.matmul(msa_repr, gate_proj)) + g = g.view(g.shape[:-1] + (head, -1)) + + o = o.unsqueeze(-3) * g + o = o.reshape(o.shape[:-2] + (-1, )) + + return torch.matmul(o, out_proj).transpose(-2, -3) + + +""" +[bs, s, r, cm] -> [bs, s, r, cm] +""" + + +@cube.graph.parser.register('N S R M, M E, E M -> N S R M', + name='MSATransition') +def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + return torch.matmul( + torch.nn.functional.relu(torch.matmul(msa_repr, proj1)), proj2) + + +@cube.graph.parser.register('N S R M, M C -> N S R C', name='OPMLeftProj') +def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): + return torch.matmul(msa_repr, proj) + + +@cube.graph.parser.register('N S R M, M C -> N S R C', name='OPMRightProj') +def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): + return torch.matmul(msa_repr, proj) + + +""" +[bs, s, r, cm] -> [bs, r, r, cz] +""" + + +@cube.graph.parser.register('N S R M, N S T M, F Z -> N R T Z', + name='OuterProductMean') +@torch.jit.ignore +def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, + out_proj: torch.Tensor, chunk_size: int, is_train: bool): + bs, s, r, c = left_act.size() + t = right_act.size(2) + + a = left_act.transpose(-2, -3) + b = right_act.transpose(-2, -3) + + if chunk_size == -1: + outer = torch.einsum('...bac,...dae->...bdce', a, + b).reshape(bs, r, t, c * c) + outer = torch.matmul(outer, out_proj) + else: + out_chunks = [] + + def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): + lhs_slice = lhs[:, start:start + chunk_size, :, :] + out = torch.einsum('...bac,...dae->...bdce', lhs_slice, + rhs).reshape(bs, chunk_size, t, c * c) + out = torch.matmul(out, out_proj) + return out + + for start in range(0, r, chunk_size): + if is_train: + ret = ckpt.checkpoint(opm, a, b, start) + else: + ret = opm(a, b, start) + out_chunks.append(ret) + outer = torch.cat(out_chunks, dim=1) + return outer + + +@cube.graph.parser.register('N S R Z, Z E, Z E -> N S R E', name='TMOLeftProj') +def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + a = torch.sigmoid(torch.matmul(pair_repr, proj1)) + a = a * torch.matmul(pair_repr, proj2) + return a + + +@cube.graph.parser.register('N S R Z, Z E, Z E -> N S R E', + name='TMORightProj') +def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + a = torch.sigmoid(torch.matmul(pair_repr, proj1)) + a = a * torch.matmul(pair_repr, proj2) + return a + + +@cube.graph.parser.register('N S T Z, Z Z -> N S T Z', name='TMOGate') +def TMOGate(pair_repr: torch.Tensor, proj: torch.Tensor): + return torch.sigmoid(torch.matmul(pair_repr, proj)) + + +@cube.graph.parser.register('N S R E, N T R E, N S T Z, E, E, E Z -> N S T Z', + name='TriangleMultiplicationOut') +def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, + g: torch.Tensor, + tri_mul_norm2_weight: torch.Tensor, + tri_mul_norm2_bias: torch.Tensor, + tri_mul_proj5: torch.Tensor, cz: int): + a = a.permute(0, 3, 1, 2) + b = b.permute(0, 3, 2, 1) + + p = torch.matmul(a, b).permute(0, 2, 3, 1) + p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, + tri_mul_norm2_bias) + p = torch.matmul(p, tri_mul_proj5) + return p * g + + +@cube.graph.parser.register('N R S Z, Z E, Z E -> N R S E', name='TMILeftProj') +def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + a = torch.sigmoid(torch.matmul(pair_repr, proj1)) + a = a * torch.matmul(pair_repr, proj2) + return a + + +@cube.graph.parser.register('N R T Z, Z E, Z E -> N R T E', + name='TMIRightProj') +def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + a = torch.sigmoid(torch.matmul(pair_repr, proj1)) + a = a * torch.matmul(pair_repr, proj2) + return a + + +@cube.graph.parser.register('N S T Z, Z Z -> N S T Z', name='TMIGate') +def TMIGate(pair_repr: torch.Tensor, proj: torch.Tensor): + return torch.sigmoid(torch.matmul(pair_repr, proj)) + + +@cube.graph.parser.register('N R S E, N R T E, N T S Z, E, E, E Z -> N T S Z', + name='TriangleMultiplicationIn') +def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, + tri_mul_norm2_weight: torch.Tensor, + tri_mul_norm2_bias: torch.Tensor, + tri_mul_proj5: torch.Tensor, cz: int): + a = a.permute(0, 3, 2, 1) + b = b.permute(0, 3, 1, 2) + + p = torch.matmul(a, b).permute(0, 2, 3, 1) + p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, + tri_mul_norm2_bias) + p = torch.matmul(p, tri_mul_proj5) + return p.permute(0, 2, 1, 3) * g + + +@cube.graph.parser.register('N S R C, C D -> N S R D', name='TANSBias') +def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): + return torch.matmul(pair_repr, bias_proj) + + +@cube.graph.parser.register('N S R Z, Z E, Z F, E Z, N T R G -> N S R Z', + name='TriangleAttentionNodeStart') +def TriangleAttentionNodeStart(pair_repr: torch.Tensor, + gate_proj: torch.Tensor, qkv_proj: torch.Tensor, + out_proj: torch.Tensor, bias: torch.Tensor, + head: int, c: int, scale: float, + chunk_size: int, is_train: bool): + bias = bias.permute(0, 3, 1, 2).unsqueeze(1) + + return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, + head, c, scale, chunk_size, is_train) + + +@cube.graph.parser.register('N S R C, C D -> N S R D', name='TANEBias') +def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): + return torch.matmul(pair_repr, bias_proj) + + +@cube.graph.parser.register('N R S Z, Z E, Z F, E Z, N R T G -> N R S Z', + name='TriangleAttentionNodeEnd') +def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias: torch.Tensor, head: int, c: int, + scale: float, chunk_size: int, is_train: bool): + pair_repr = pair_repr.permute(0, 2, 1, 3) + bias = bias.permute(0, 2, 1, 3) + out = TriangleAttentionNodeStart(pair_repr, gate_proj, qkv_proj, out_proj, + bias, head, c, scale, chunk_size, + is_train) + return out.permute(0, 2, 1, 3) + + +@cube.graph.parser.register('N R T Z, Z E, E Z -> N R T Z', + name='PairTransition') +def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, + proj2: torch.Tensor): + return torch.matmul( + torch.nn.functional.relu(torch.matmul(pair_repr, proj1)), proj2) + + +@cube.graph.parser.register('* -> *, *', name='multi2ref') +def multi2ref(x: torch.Tensor): + return (x, x) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 921374d7..24d07f1c 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -48,6 +48,16 @@ def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], return sub_nodes +def PASSingleInference(graph: IRGraph, resource): + assert resource.ngpus == 1 + + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + + return graph + + def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 @@ -69,7 +79,7 @@ def PASSingle(graph: IRGraph, resource): rhs = indices[2 * i + 1] # deepmind's default recompute strategy - # graph.recompute(fnodes[lhs + 1:rhs]) + graph.recompute(fnodes[lhs + 1:rhs]) # another strategy # sub_indices = [] @@ -83,6 +93,29 @@ def PASSingle(graph: IRGraph, resource): return graph +def PASExtraSingle(graph: IRGraph, resource): + assert resource.ngpus == 1 + + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + + fnodes = graph.nodes() + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + + indices = [ + fnodes.index(anchor) for anchor in anchors + if anchor.name == 'MSACol' or anchor.name == 'One Layer Evoformer End' + ] + assert len(indices) % 2 == 0 + for i in range(len(indices) // 2): + lhs = indices[2 * i] + rhs = indices[2 * i + 1] + + graph.recompute(fnodes[lhs + 1:rhs]) + return graph + + def PASData(graph: IRGraph, resource): devs = list(range(resource.ngpus)) @@ -138,7 +171,7 @@ def PASDAP(graph: IRGraph, resource): if isinstance(fnodes[j], IRGraphAnchor): sub_indices.append(j) sub_indices.append(rhs) - # graph.recompute(fnodes[lhs:rhs]) + graph.recompute(fnodes[lhs:rhs]) for j in range(len(sub_indices) - 1): sub_l, sub_r = sub_indices[j], sub_indices[j + 1] names = [] @@ -176,3 +209,66 @@ def PASDAP(graph: IRGraph, resource): _replica(graph, fnodes[i], tp_devs) return graph + + +def PASDAPInference(graph: IRGraph, resource): + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + + fnodes = graph.nodes() + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + + indices = [ + fnodes.index(anchor) for anchor in anchors + if anchor.name == 'One Layer Evoformer Start' + or anchor.name == 'One Layer Evoformer End' + ] + assert len(indices) % 2 == 0 + + for i in range(indices[0]): + if isinstance(fnodes[i], IRDataOperation) or isinstance( + fnodes[i], IRFwOperation): + _replica(graph, fnodes[i], tp_devs) + + for i in range(len(indices) // 2): + lhs, rhs = indices[2 * i], indices[2 * i + 1] + sub_indices = [] + for j in range(lhs + 1, rhs): + if isinstance(fnodes[j], IRGraphAnchor): + sub_indices.append(j) + sub_indices.append(rhs) + for j in range(len(sub_indices) - 1): + sub_l, sub_r = sub_indices[j], sub_indices[j + 1] + names = [] + for k in range(sub_l + 1, sub_r): + names.append(fnodes[k].name) + names = set(names) + nodes = fnodes[sub_l + 1:sub_r] + + if 'MSARowAttentionWithPairBias' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) + elif 'MSAColAttention' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'MSATransition' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'OuterProductMean' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'TriangleMultiplicationOut' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) + elif 'TriangleMultiplicationIn' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'TriangleAttentionNodeStart' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) + elif 'TriangleAttentionNodeEnd' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) + elif 'PairTransition' in names: + sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) + else: + assert False, names + + for i in range(indices[-1] + 1, len(fnodes)): + if isinstance(fnodes[i], IRDataOperation) or isinstance( + fnodes[i], IRFwOperation): + _replica(graph, fnodes[i], tp_devs) + + return graph From 48662b6e5b304256a8e39265460e78853a28b370 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 17 Oct 2022 09:31:33 +0000 Subject: [PATCH 1074/1892] add basic nums --- examples/alphafold2/README.md | 61 +++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 examples/alphafold2/README.md diff --git a/examples/alphafold2/README.md b/examples/alphafold2/README.md new file mode 100644 index 00000000..c936e4ab --- /dev/null +++ b/examples/alphafold2/README.md @@ -0,0 +1,61 @@ +# Introduction + +Benchmark different schedule plans of Alphafold2 based on MagicCube. + +# Results + +## Training + +### Evoformer Stack + +**s, r = 128, 256** + +| device num | policy | peak mem (MB) | activation mem (MB) | time (ms) | +|:-----------|:-------|:--------------|:--------------------|:----------| +| 1 | small | 8119 | 7462 | 3656.61 | +| 1 | large | 4414 | 2070 | 3635.38 | +| 2 | small | 4351 | 4014 | 2539.56 | +| 2 | large | 2531 | 1318 | 2506.10 | + +**s, r = 512, 256** + +| device num | policy | peak mem (MB) | activation mem (MB) | time (ms) | +|:-----------|:-------|:--------------|:--------------------|:----------| +| 1 | small | 18952 | 14471 | 7949.96 | +| 1 | large | 10729 | 4423 | 7914.68 | +| 2 | small | 9839 | 7567 | 4839.22 | +| 2 | large | 5744 | 2543 | 4793.78 | + +**s, r = 512, 384** + +| device num | policy | peak mem (MB) | activation mem (MB) | time (ms) | +|:-----------|:-------|:--------------|:--------------------|:----------| +| 1 | small | OOM | OOM | OOM | +| 1 | large | 17810 | 7104 | 17063.41 | +| 2 | small | 16230 | 12847 | 9659.66 | +| 2 | large | 9416 | 3870 | 9629.48 | + +### Extra Msa Stack + +**device num = 1** + +| Config | peak mem (MB) | activation mem (MB) | time (ms) | +|:-----------------|:--------------|:--------------------|:----------| +| s, r = 1024, 256 | 3236 | 1166 | 2306 | +| s, r = 1024, 384 | 6976 | 1805 | 3749.43 | +| s, r = 5120, 384 | 16168 | 8210 | 58393.83 | + +## Inference + +### T1044 + +**s, r = 128, 2048** + +| device num | policy | peak mem (MB) | time (ms) | +|:-----------|:------------|:--------------|:----------| +| 1 | direct | OOM | OOM | +| 1 | chunk | 23374 | 339742.02 | +| 2 | DAP | OOM | OOM | +| 2 | DAP + chunk | 13006 | 192577.34 | +| 4 | DAP | OOM | OOM | +| 4 | DAP + chunk | 9358 | 101993 | \ No newline at end of file From c039a0c1c2ded89928a467faed3278e7018ee253 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 18 Oct 2022 21:30:18 +0800 Subject: [PATCH 1075/1892] save work --- examples/alphafold2/README.md | 58 +++++++++++++++++++++++++++++++++++ examples/alphafold2/module.py | 11 +++++++ 2 files changed, 69 insertions(+) diff --git a/examples/alphafold2/README.md b/examples/alphafold2/README.md index c936e4ab..b83d8fa5 100644 --- a/examples/alphafold2/README.md +++ b/examples/alphafold2/README.md @@ -2,14 +2,64 @@ Benchmark different schedule plans of Alphafold2 based on MagicCube. +# Model + +## Structure + +TODO + +## Challenge + +TODO + +## Problem Formulation + +TODO: try out del tensor in functions that to be recomputed -> offload problems to jit tensor compilers + +strategy: detect memory constrained parts then coshard them + +large enough size of input shapes already utilize accelerators + +should include coshard into the dp formulation + +## Memory Consumption + +notation +- $s$: multiple sequence alignment (MSA) number +- $r$: residue number +- $c_{m}$: hidden dimension of MSA representation +- $c_{z}$: hidden dimension of pair representation +- $h$: head number. Different modules may differ + +activation +- one Evoformer's output: $s \cdot r \cdot c_{m} + r^{2} \cdot c_{z}$ +- Modules' outputs inside a Evoformer block: $3 \cdot s \cdot r \cdot c_{m} + 6 \cdot s \cdot r^{2} \cdot c_{z}$ + +peak memory +- MSA Row Attention with Bias: $h \cdot s \cdot r^2$, where $h=8$ +- MSA Col Attention: $h \cdot s^2 \cdot r$, where $h=8$ +- MSA Transition: $4 \cdot s \cdot r \cdot c_{m}$ +- Outer Product Mean: $r^2 \cdot c^2$, where $c=32$ +- Triangular Multiplicative Update using Outgoing Edges: $r^2 \cdot c$, where $c=128$ +- Triangular Multiplicative Update using Ingoing Edges: $r^2 \cdot c$, where $c=128$ +- Triangular Gated Self-Attention around Starting Node: $h \cdot r^3$, where $h=4$ +- Triangular Gated Self-Attention around Ending Node: $h \cdot r^3$, where $h=4$ +- Pair Transition: $4 \cdot s \cdot r^2 \cdot c_{z}$ + # Results ## Training +Computation in float16 + ### Evoformer Stack +$48$ Evoformers in total + **s, r = 128, 256** +1 Evoformer output size: $16 + 16 = 32$MB + | device num | policy | peak mem (MB) | activation mem (MB) | time (ms) | |:-----------|:-------|:--------------|:--------------------|:----------| | 1 | small | 8119 | 7462 | 3656.61 | @@ -19,6 +69,8 @@ Benchmark different schedule plans of Alphafold2 based on MagicCube. **s, r = 512, 256** +1 Evoformer output size: $64 + 16 = 80$MB + | device num | policy | peak mem (MB) | activation mem (MB) | time (ms) | |:-----------|:-------|:--------------|:--------------------|:----------| | 1 | small | 18952 | 14471 | 7949.96 | @@ -28,6 +80,8 @@ Benchmark different schedule plans of Alphafold2 based on MagicCube. **s, r = 512, 384** +1 Evoformer output size: $96 + 36 = 132$MB + | device num | policy | peak mem (MB) | activation mem (MB) | time (ms) | |:-----------|:-------|:--------------|:--------------------|:----------| | 1 | small | OOM | OOM | OOM | @@ -37,6 +91,8 @@ Benchmark different schedule plans of Alphafold2 based on MagicCube. ### Extra Msa Stack +$4$ Extra-Evoformer, $c_{m} = 64$ and $c_{z} = 128$ + **device num = 1** | Config | peak mem (MB) | activation mem (MB) | time (ms) | @@ -47,6 +103,8 @@ Benchmark different schedule plans of Alphafold2 based on MagicCube. ## Inference +Computation in float32 + ### T1044 **s, r = 128, 2048** diff --git a/examples/alphafold2/module.py b/examples/alphafold2/module.py index 79af32b6..80af9bf2 100644 --- a/examples/alphafold2/module.py +++ b/examples/alphafold2/module.py @@ -211,22 +211,33 @@ def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor, gate_proj: torch.Tensor, out_proj: torch.Tensor, head: int, c: int, scale: float): + # [N R S M] msa_repr = msa_repr.transpose(-2, -3) + # [N R M] q = torch.sum(msa_repr, dim=-2) + # [N R M] q = torch.matmul(q, q_proj) * scale + # [N R H E] q = q.view(q.shape[:-1] + (head, -1)) + # [N R S E] k, v = torch.matmul(msa_repr, k_proj), torch.matmul(msa_repr, v_proj) + # [N R H S] a = torch.matmul(q, k.transpose(-1, -2)) a = torch.nn.functional.softmax(a, dim=-1) + # [N R H E] o = torch.matmul(a, v) + # [N R S M] g = torch.sigmoid(torch.matmul(msa_repr, gate_proj)) + # [N R S H E] g = g.view(g.shape[:-1] + (head, -1)) + # [N R 1 H E] o = o.unsqueeze(-3) * g + # [N R S M] o = o.reshape(o.shape[:-2] + (-1, )) return torch.matmul(o, out_proj).transpose(-2, -3) From e7908d35ff15a99e7892121d3f1209d817e21fa9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 19 Oct 2022 18:26:29 +0800 Subject: [PATCH 1076/1892] init dijkstra plan for rvd and rvd+ --- cube/graph/gener/layout.py | 528 ++++++++++++++++++++++++++++++++++++- cube/ir/adapter/prim.py | 215 +++++++++++++-- cube/ir/cten.py | 23 +- 3 files changed, 720 insertions(+), 46 deletions(-) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 10e127d6..10b8ad17 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -1,16 +1,23 @@ -from typing import Dict, List, Tuple +from typing import Callable, Dict, List, Tuple, Optional import copy import numpy as np +from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.tensor import IndexMap, ValueMap +from cube.ir.adapter.prim import IRAdapterPrim from cube.ir.adapter.prim import AllGatherPrim # d2r from cube.ir.adapter.prim import AllToAllPrim # d2d from cube.ir.adapter.prim import AllReducePrim # v2r from cube.ir.adapter.prim import ReduceScatterPrim # v2d from cube.ir.adapter.prim import ChunkPrim # r2d +from cube.ir.adapter.prim import MovePrim # p2p +from cube.ir.adapter.prim import BroadcastPrim +from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim +from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim + class GridLayout: """ @@ -52,15 +59,32 @@ def vec(self) -> Tuple[int]: def ndims(self): return len(self._mats.shape) + @property + def ndevs(self): + return len(self.subtensors) + @property def mat(self): return self._mats - # ====== primitives ===== # + def tensor(self, r: int, v: int, d: List[int]) -> IRSubTensor: + """ + Get subtenor indexed by RVD position. + """ + assert r <= self.R and v <= self.V and len(d) == len(self.D), "out of scope" + indices = [r, v] + list(d) + return self._mats[tuple(indices)] + + def __repr__(self): + dscp = f'T{self.ftensor._id}' + return dscp + + # ====== inshard transformation primitives ===== # def d2r(self, dim: int, chunks: int): """ - dimension to replica: allgather + RVD Primitive: dimension to replica + collective: allgather """ layout = list(self.vec) assert layout[2+dim] % chunks == 0, f"not dividable dim: {layout[2+dim]} // {chunks}" @@ -80,7 +104,8 @@ def d2r(self, dim: int, chunks: int): def d2d(self, from_dim: int, to_dim: int, chunks: int): """ - dimension to dimension: all-to-all + RVD Primitive: dimension to dimension + collective: all-to-all """ layout = list(self.vec) assert layout[2+from_dim] % chunks == 0, f"not dividable dim: {layout[2+from_dim]} // {chunks}" @@ -100,7 +125,8 @@ def d2d(self, from_dim: int, to_dim: int, chunks: int): def v2r(self, chunks: int): """ - value to replica: all-reduce + RVD Prmitive: value to replica + collective: all-reduce """ layout = list(self.vec) assert layout[1] % chunks == 0, f"not dividable value chunks: {layout[1]} // {chunks}" @@ -120,7 +146,8 @@ def v2r(self, chunks: int): def v2d(self, dim: int, chunks: int): """ - value to dimension: reduce-scatter + RVD Primitive: value to dimension + collective: reduce-scatter """ layout = list(self.vec) assert layout[1] % chunks == 0, f"not dividable value chunks: {layout[0]} // {chunks}" @@ -140,7 +167,8 @@ def v2d(self, dim: int, chunks: int): def r2d(self, dim: int, chunks: int): """ - replica to dimension: split + RVD Primitive: replica to dimension + collective: split """ layout = list(self.vec) assert layout[0] % chunks == 0, f"not dividable replica: {layout[0]} // {chunks}" @@ -161,6 +189,110 @@ def r2d(self, dim: int, chunks: int): # prims.append(ChunkPrim(itensor, otensor, dim, ranks)) return glayout, prims + def incr(self, chunks: int, devices: List[int]): + """ + RVD+ Prmitive: increase replica + collective: broadcast + """ + layout = list(self.vec) + layout[0] = layout[0] * chunks + glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + # set device + imat = GridLayout.dims2last(self.mat, [0]).flatten() + omat = GridLayout.dims2last(glayout.mat, [0]).reshape(-1, chunks) + prims = [] + for src, dsts in zip(imat, omat): + prims.append(BroadcastPrim(src, [src] + list(dsts))) + return glayout, prims + + + def decr(self, chunks: int, devices: List[int]): + """ + RVD+ Prmitive: decrease replica + collective: move + """ + layout = list(self.vec) + assert layout[0] % chunks == 0, f"not divisible replica: {layout[0]} // {chunks}" + layout[0] = layout[0] // chunks + glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + # set device + imat = GridLayout.dims2last(self.mat, [0]).reshape(-1, chunks) + omat = GridLayout.dims2last(glayout.mat, [0]).flatten() + prims = [] + for srcs, dst in zip(imat, omat): + prims.append(MovePrim(srcs[0], dst)) + return glayout, prims + + + def incd(self, chunks: int, dim: int, devices: List[int]): + """ + RVD+ Prmitive: increase dimension + collective: rdscatter + """ + layout = list(self.vec) + layout[2+dim] = layout[2+dim] * chunks + glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + # TODO: set device + imat = GridLayout.dims2last(glayout.mat, [2+dim]).flatten() + omat = GridLayout.dims2last(glayout.mat, [2+dim]).reshape(-1, chunks) + prims = [] + for src, dsts in zip(imat, omat): + prims.append(RDScatterPrim(src, dsts, dim=dim)) + return glayout, prims + + + def decd(self, chunks: int, dim: int, devices: List[int]): + """ + RVD+ Prmitive: increase dimension + collective: rdgather + """ + layout = list(self.vec) + assert layout[2+dim] % chunks == 0, f"not divisible dim: {self.D[dim]} % {chunks} != 0" + layout[2+dim] = layout[2+dim] // chunks + glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + # set device + imat = GridLayout.dims2last(self.mat, [2+dim]).reshape(-1, chunks) + omat = GridLayout.dims2last(glayout.mat, [2+dim]).flatten() + prims = [] + for srcs, dst in zip(imat, omat): + prims.append(RDGatherPrim(srcs, dst, dim=dim)) + return glayout, prims + + + def incv(self, chunks: int, devices: List[int]): + """ + RVD+ Primitive: increase value partition + collective: rvscatter + """ + layout = list(self.vec) + layout[1] = layout[1] * chunks + glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + # TODO: set device + imat = GridLayout.dims2last(glayout.mat, [1]).flatten() + omat = GridLayout.dims2last(glayout.mat, [1]).reshape(-1, chunks) + prims = [] + for src, dsts in zip(imat, omat): + prims.append(RVScatterPrim(src, dsts)) + return glayout, prims + + def decv(self, chunks: int, devices: List[int]): + """ + RVD+ Primitive: decrease value partition + collective: rvgather + """ + layout = list(self.vec) + assert layout[1] % chunks == 0, f"not divisable value split: {self.V} % {chunks} != 0" + layout[1] = layout[1] * chunks + glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + # TODO: set device + imat = GridLayout.dims2last(self.mat, [1]).reshape(-1, chunks) + omat = GridLayout.dims2last(glayout.mat, [1]).flatten() + prims = [] + for srcs, dst in zip(imat, omat): + prims.append(RVGatherPrim(srcs, dst)) + return glayout, prims + + # ================ solution ============= # def path(self, dst) -> Tuple: @@ -238,10 +370,6 @@ def step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLa break return paths, comm_prims - def __repr__(self): - dscp = f'T{self.ftensor._id}' - return dscp - def print_dev_tensors(self): """ print each device hold tensors. @@ -273,10 +401,25 @@ def transpose(mat: np.ndarray, dim0: int, dim1: int): return np.transpose(mat, axes) @staticmethod - def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int]): + def dims2last(mat: np.ndarray, dims: List[int]) -> np.ndarray: + """ + Permute a matrix by putting dimensions to the last. + """ + axes = list(range(len(mat.shape))) + for dim in dims: + axes.remove(dim) + axes += list(dims) + return np.transpose(mat, axes) + + @staticmethod + def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int], devices: Optional[Tuple[int]] = None): """ partition a ftensor using grid layout of """ + def dummy_assign(tensor: IRSubTensor, devid: int): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = devid + mats = np.empty([r, v] + dims, dtype=IRSubTensor) all_subtensors = [] @@ -303,6 +446,13 @@ def iter_idx(dims: List[int]) -> Tuple[int]: subtensors = [copy.copy(subtensor) for _ in range(r)] all_subtensors += subtensors mats[(slice(None),)+indices] = np.array(subtensors, dtype=IRSubTensor) + + # devices + if devices is not None: + assert len(devices) == len(all_subtensors), f"devices number {len(devices)} not match with RVD number {len(all_subtensors)}" + for tensor, devid in zip(all_subtensors, devices): + dummy_assign(tensor, devid) + return GridLayout(ftensor, all_subtensors, mats) @staticmethod @@ -370,3 +520,357 @@ def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): mats[tuple(idx)] = subtensor assert not (mats == None).any(), "at least one entry not set" return GridLayout(ftensor, subtensors, mats) + + +TShape = Tuple[int, ...] +TRVD = Tuple[int, ...] + + +class PathFinder: + """ + Pathfinder for generating communication plans for GridLayout + """ + + # intra-shard: cached nodes. paths[shape][i][j] = List[int] of indices from (src -> dst] + _cached_intra_nodes: Dict[Tuple[TShape, int], Tuple[TRVD]] = {} + _cached_intra_edges: Dict[Tuple[TShape, int], np.ndarray] = {} + _cached_intra_paths: Dict[Tuple[TShape, int], Dict[TRVD, List[List[int]]]] = {} + + # inter-shard: cached nodes. paths[(shape1, shape2)][i][j] = List[int] + _cached_inter_nodes: Dict[Tuple[TShape, int, int], Tuple[Tuple[TRVD]]] = {} + _cached_inter_edges: Dict[Tuple[TShape, int, int], Tuple[np.ndarray]] = {} + _cached_inter_paths: Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]] = {} + + @staticmethod + def intra_shard(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: + """ + Get primitive path of transforming ilayout into olayout. + ilayout has the same device set with olayout + + @param ftensor IRFullTensor: The fulltensor + @param ilayout GridLayout: input tensor layout + @param olayout GridLayout: output tensor layout + @param cost_fn Optional[Callable]: cost function of each primitive. + Default (None) will use transmission volume as metrics + + @return layouts List[GridLayout]: each transformation. + @return prims List[IRAdapterPrim]: the primitives to perform transformation. + """ + cost_fn = PathFinder.default_cost_fn if cost_fn is None else cost_fn + shape = tuple(ftensor.shape) + key = (shape, ilayout.ndevs) + src = (ilayout.R, ilayout.V) + tuple(ilayout.D) + dst = (olayout.R, olayout.V) + tuple(olayout.D) + if src == dst: return [], [] + + # get paths using dijkstra algorithm or cached + if key in PathFinder._cached_intra_paths and src in PathFinder._cached_intra_paths[key]: + paths = PathFinder._cached_intra_paths[key][src] + else: + # initialize the graph if not cached + if key not in PathFinder._cached_intra_nodes: + nodes, edges = PathFinder.init_intra_graph(ftensor, ilayout.ndevs, cost_fn) + PathFinder._cached_intra_nodes[key] = nodes + PathFinder._cached_intra_edges[key] = edges + PathFinder._cached_intra_paths[key] = {} + nodes = PathFinder._cached_intra_nodes[key] + edges = PathFinder._cached_intra_edges[key] + # build and initialize cost table + cost = np.full((len(nodes),), np.inf) + cost[nodes.index(src)] = 0 + # setup unvisited and visited set + unvisited = set(range(len(nodes))) + visited = set() + paths = [[] for _ in range(len(nodes))] + paths[nodes.index(src)] = [nodes.index(src)] + # dijkstra body + while len(unvisited) > 0: + min_cost, visit = np.inf, None + for idx in unvisited: + if cost[idx] < min_cost: + min_cost = idx + visit = idx + if visit is None: break + for neighbor in np.where(edges[visit] != np.inf)[0]: + if neighbor in visited: continue + new_cost = cost[visit] + edges[visit, neighbor] + if cost[neighbor] == np.inf or new_cost < cost[neighbor]: + cost[neighbor] = new_cost + paths[neighbor] = paths[visit] + [neighbor] + cost[neighbor] = min(cost[neighbor], cost[visit] + edges[visit, neighbor]) + unvisited.remove(visit) + visited.add(visit) + PathFinder._cached_intra_paths[key][src] = paths + + # print for debug + for idx, path in enumerate(paths): + print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") + + # get layout + nodes = PathFinder._cached_intra_nodes[key] + path = paths[nodes.index(dst)] + assert len(path) > 0, f"Un-reachable src RVD ({src}) -> dst RVD ({dst})" + + layouts = [ilayout] + all_prims = [] + curr_rvd = src + for hop in path[1:]: + hop_rvd = nodes[hop] + inc_dim, dec_dim = None, None + for dim, (ipnum, opnum) in enumerate(zip(curr_rvd, hop_rvd)): + if ipnum > opnum: + assert dec_dim is None + dec_dim = dim + continue + if opnum > ipnum: + assert inc_dim is None + inc_dim = dim + continue + nchunks = curr_rvd[dec_dim] // hop_rvd[dec_dim] + layout, prims = PathFinder.intra_step(layouts[-1], dec_dim, inc_dim, nchunks) + layouts.append(layout) + all_prims += prims + curr_rvd = hop_rvd + return layouts, all_prims + + + @staticmethod + def inter_shard(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: + """ + Get primitives for transforming ilayout into olayout. ilayout has the different device set + to olayout. And number of device of ilayout and olayout must be divisable by each other. + + @param ftensor IRFullTensor: The fulltensor + @param ilayout GridLayout: input tensor layout + @param olayout GridLayout: output tensor layout + @param cost_fn Optional[Callable]: cost function of each primitive. + Default (None) will use transmission volume as metrics + + @return layouts List[GridLayout]: each transformation. + @return prims List[IRAdapterPrim]: the primitives to perform transformation. + """ + cost_fn = PathFinder.default_cost_fn if cost_fn is None else cost_fn + shape = tuple(ftensor.shape) + key = (shape, ilayout.ndevs, olayout.ndevs) + + src = ('p',) + (ilayout.R, ilayout.V) + tuple(ilayout.D) + dst = ('c',) + (olayout.R, olayout.V) + tuple(olayout.D) + + if key in PathFinder._cached_inter_nodes and src in PathFinder._cached_inter_paths[key]: + paths = PathFinder._cached_inter_paths[key][src] + else: + if key in PathFinder._cached_inter_nodes: + nodes = PathFinder._cached_inter_nodes[key] + edges = PathFinder._cached_inter_edges[key] + else: + nodes, edges = PathFinder.init_inter_graph(ftensor, ilayout.ndevs, olayout.ndevs, cost_fn) + PathFinder._cached_inter_nodes[key] = nodes + PathFinder._cached_inter_edges[key] = edges + PathFinder._cached_inter_paths[key] = {} + # build cost + cost = np.full((len(nodes),), np.inf) + cost[nodes.index(src)] = 0 + # setup unvisited and visited set + unvisited = set(range(len(nodes))) + visited = set() + paths = [[] for _ in range(len(nodes))] + paths[nodes.index(src)] = [nodes.index(src)] + # dijkstra body + while len(unvisited) > 0: + min_cost, visit = np.inf, None + for idx in unvisited: + if cost[idx] < min_cost: + min_cost = idx + visit = idx + if visit is None: break + for neighbor in np.where(edges[visit] != np.inf)[0]: + if neighbor in visited: continue + new_cost = cost[visit] + edges[visit, neighbor] + if cost[neighbor] == np.inf or new_cost < cost[neighbor]: + cost[neighbor] = new_cost + paths[neighbor] = paths[visit] + [neighbor] + cost[neighbor] = min(cost[neighbor], cost[visit] + edges[visit, neighbor]) + unvisited.remove(visit) + visited.add(visit) + PathFinder._cached_inter_paths[key][src] = paths + + # print for debug + for idx, path in enumerate(paths): + print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") + + + @staticmethod + def intra_step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> Tuple[GridLayout, List[IRAdapterPrim]]: + if dec_idx >= 2 and inc_idx == 0: # d2r + return ilayout.d2r(dec_idx-2, chunks) + if dec_idx >= 2 and inc_idx >= 2: # d2d + return ilayout.d2d(dec_idx-2, inc_idx-2, chunks) + if dec_idx == 1 and inc_idx == 0: # v2r + return ilayout.v2r(chunks) + if dec_idx == 1 and inc_idx >= 2: # v2d + return ilayout.v2d(inc_idx-2, chunks) + if dec_idx == 0 and inc_idx >= 2: # r2d + return ilayout.r2d(inc_idx-2, chunks) + raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") + + @staticmethod + def inter_step(ilayout: GridLayout, dec_idx: Optional[int], inc_idx: Optional[int], chunks: int): + assert dec_idx is None or inc_idx is None + if isinstance(inc_idx, int): + if inc_idx == 0: + return ilayout.incr(chunks, []) + if inc_idx == 1: + return ilayout.incv(chunks, []) + if inc_idx > 1: + return ilayout.incd(chunks, inc_idx-2, []) + raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") + else: + if dec_idx == 0: + return ilayout.decr(chunks, []) + if dec_idx == 1: + return ilayout.decv(chunks, []) + if dec_idx > 1: + return ilayout.decd(chunks, dec_idx-2, []) + raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") + + + @staticmethod + def init_intra_graph(ftensor: IRFullTensor, ndevs: int, cost_fn: Optional[Callable]) -> Tuple[List[TRVD], np.ndarray]: + """ + Initialize the graph of RVD status graph. + + @param ftensor IRFullTensor: the full tensor + @param ndevs int: total device number + + @return nodes Tuple[TRVD] + @return edges np.ndarray: edges among nodes + """ + nodes = tuple(PathFinder.get_inshard_space(ftensor, ndevs)) + edges = np.full((len(nodes), len(nodes)), np.inf) + # initialize the cost + for i in range(len(nodes)): + for j in range(len(nodes)): + if i == j: continue + isrc, idst = nodes[i], nodes[j] + inc_dim, dec_dim = [], [] + for dim, (pnum_src, pnum_dst) in enumerate(zip(isrc, idst)): + if pnum_src > pnum_dst: + dec_dim.append(dim) + elif pnum_src < pnum_dst: + inc_dim.append(dim) + if len(inc_dim) != 1 or len(dec_dim) != 1: + continue # not direct + inc_dim, dec_dim = inc_dim[0], dec_dim[0] + if idst[inc_dim] % isrc[inc_dim] != 0 or isrc[dec_dim] % idst[dec_dim] != 0: + continue # not direct + if inc_dim == 1: + continue # not consider increasing value partition + nchunks = isrc[dec_dim] // idst[dec_dim] + isrc_layout = GridLayout.grid(ftensor, isrc[0], isrc[1], list(isrc[2:])) + _, prims = PathFinder.intra_step(isrc_layout, dec_dim, inc_dim, nchunks) + edges[i, j] = cost_fn(prims[0]) + return nodes, edges + + @staticmethod + def init_inter_graph(ftensor: IRFullTensor, idevs: int, odevs: int, cost_fn: Callable) -> Tuple[List[TRVD], np.ndarray]: + """ + Initialize the graph of RVD status graph. + + An additional positition tage is append to at the first element of each node, i.e., + For source (producer) layout: ('p', 2,1,1,2) means + For dest (consumer) layout: ('c', 2,1,1,2) means + + @param ftensor IRFullTensor: the full tensor + @param idevs int: total device number of source tensor + + @return nodes Tuple[TRVD] + @return edges np.ndarray: edges among nodes + """ + shape = tuple(ftensor.shape) + if (shape, idevs) in PathFinder._cached_intra_nodes: + src_nodes = PathFinder._cached_intra_nodes[(shape, idevs)] + src_edges = PathFinder._cached_intra_edges[(shape, idevs)] + else: + src_nodes, src_edges = PathFinder.init_intra_graph(ftensor, idevs, cost_fn) + PathFinder._cached_intra_nodes[(shape, idevs)] = src_nodes + PathFinder._cached_intra_edges[(shape, idevs)] = src_edges + PathFinder._cached_intra_paths[(shape, idevs)] = {} + if (shape, odevs) in PathFinder._cached_inter_edges: + dst_nodes = PathFinder._cached_intra_nodes[(shape, odevs)] + dst_edges = PathFinder._cached_intra_edges[(shape, odevs)] + else: + dst_nodes, dst_edges = PathFinder.init_intra_graph(ftensor, odevs, cost_fn) + PathFinder._cached_intra_nodes[(shape, odevs)] = dst_nodes + PathFinder._cached_intra_edges[(shape, odevs)] = dst_edges + PathFinder._cached_intra_paths[(shape, odevs)] = {} + nodes = tuple(('p',) + n for n in src_nodes ) + tuple(('c',) + n for n in dst_nodes) + for node in nodes: + print(node) + edges = np.full((len(nodes), len(nodes)), np.inf) + edges[:len(src_nodes), :len(src_nodes)] = src_edges + edges[len(src_nodes):,len(src_nodes):] = dst_edges + # NVLink: 300GBps Inter-node: 100Gbps + comm_factor = 24 + for i in range(len(src_nodes)): + for j in range(len(dst_nodes)): + src, dst = src_nodes[i], dst_nodes[j] + diff_dim = [] + for dim, (pnum_src, pnum_dst) in enumerate(zip(src, dst)): + if pnum_src != pnum_dst: + diff_dim.append(dim) + diff_dim = [0] if len(diff_dim) == 0 else diff_dim + if len(diff_dim) != 1: + continue # not direct + diff_dim = diff_dim[0] + if (src[diff_dim] % dst[diff_dim] != 0) and (dst[diff_dim] % src[diff_dim] != 0): + continue # not divisible -> not direct + nchunks = src[diff_dim] // dst[diff_dim] if src[diff_dim] > dst[diff_dim] else dst[diff_dim] // src[diff_dim] + # set for [i, len(src_nodes) + j] + src_layout = GridLayout.grid(ftensor, src[0], src[1], list(src[2:])) + dec_dim = diff_dim if src[diff_dim] > dst[diff_dim] else None + inc_dim = diff_dim if dec_dim is None else None + _, prims = PathFinder.inter_step(src_layout, dec_dim, inc_dim, nchunks) + edges[i, len(src_nodes) + j] = cost_fn(prims[0]) * comm_factor + # set for [len(src_nodes) + j, i] + dst_layout = GridLayout.grid(ftensor, dst[0], dst[1], list(dst[2:])) + dec_dim, inc_dim = inc_dim, dec_dim + _, prims = PathFinder.inter_step(dst_layout, dec_dim, inc_dim, nchunks) + # NVLink: 300GBps Inter-node: 100Gbps + edges[len(src_nodes) + j, i] = cost_fn(prims[0]) * comm_factor + return nodes, edges + + # utility function + @staticmethod + def get_inshard_space(ftensor: IRSubTensor, ndevs: int) -> List[Tuple[int, ...]]: + """ + Get all possible space that can be transformed from layout. + + This space is pruned by limiting partition number of each RVD dimension + in the range of [min(ilayout[dim], olayout[dim]), max(ilayout[dim], olayout[dim])] + + @param ftensor IRFullTensor + @param ilayout GridLayout: input layout + @param olayout GridLayout: output layout + + @return layouts List[GridLayout]: + """ + all_layouts: List[int] = [] + + def factors(ndevs: int, length: int): + if length == 1: yield [ndevs] + else: + for i in range(1, ndevs + 1): + if ndevs % i == 0: + for res in factors(ndevs // i, length - 1): + yield [i] + res + + for rvd in factors(ndevs, 2+len(ftensor.shape)): + for dimlen, pnum in zip(ftensor.shape, rvd[2:]): + if dimlen % pnum != 0: + continue + all_layouts.append(tuple(rvd)) + return all_layouts + + @staticmethod + def default_cost_fn(prim: IRAdapterPrim) -> int: + return prim.volume() + 1 # 1 is hop penalty diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index b43835cf..8ab6cc8a 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -37,6 +37,15 @@ def dispatch(self, devid: int): return None return self + def volume(self) -> int: + """ + Communication volume of the primitive. The total elements + transferred in the network. + + @return nele int: the number of elements go through network + """ + raise NotImplementedError("The communication cost is not implemented") + @property def device(self) -> List[int]: return copy.copy(self._device) @@ -101,6 +110,9 @@ def __repr__(self): dscp = f"{self.output(0)} = identity({self.input(0)})" return dscp + def volume(self) -> int: + return 0 + class SelectPrim(SpatialPrim): @@ -118,6 +130,9 @@ def __repr__(self): dscp = f"{self.output(0)} = select({self.input(0)}, indmap={self.kwargs['indmap']}, valmap={self.kwargs['valmap']})" return dscp + def volume(self) -> int: + return 0 + class MergeDimPrim(SpatialPrim): """ @@ -131,6 +146,9 @@ def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, dim: int) def __repr__(self) -> str: return f"dev{self.device}: {self.output(0)} = concat({self.inputs()}, dim={self.kwargs['dim']})" + def volume(self) -> int: + return 0 + # numerical primitive class SumPrim(ValuePrim): @@ -140,9 +158,13 @@ def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): super().__init__(itensors, [otensor]) self.signature = 'cube.runtime.adapter.vmerge' + def volume(self) -> int: + return 0 + def __repr__(self) -> str: return f"dev{self.device}: {self.output(0)} = add({self.inputs()})" + # communication primitive class SendPrim(CommPrim): @@ -153,8 +175,11 @@ def __init__(self, tensor, dst: int): super().__init__([tensor], [tensor], dst=dst) self.signature = 'cube.runtime.adapter.send' + def volume(self) -> int: + return self.input(0).nelement() + def __repr__(self) -> str: - return f"{self.input(0)} = send({self.input(0)}, dst={self.kwargs['dst']}" + return f"{self.input(0)} = send[{self.device}]({self.input(0)}, dst={self.kwargs['dst']}" class RecvPrim(CommPrim): @@ -166,8 +191,11 @@ def __init__(self, tensor: IRSubTensor, src: int): shape=tensor.shape, dtype='torch.'+tensor.dtype.value, src=src) self.signature = 'cube.runtime.adapter.recv' + def volume(self) -> int: + return self.input(0).nelement() + def __repr__(self) -> str: - return f"{self.output(0)} = recv(shape={self.kwargs['shape']}, dtype={self.kwargs['dtype']}, src={self.kwargs['src']}" + return f"{self.output(0)} = recv[{self.device}](shape={self.kwargs['shape']}, dtype={self.kwargs['dtype']}, src={self.kwargs['src']}" class MovePrim(CommPrim): @@ -175,8 +203,9 @@ class MovePrim(CommPrim): P2P send/recv, non-differentiable """ def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor): - assert itensor.device != otensor.device, "no movement detected." - super().__init__([itensor], [otensor], src=itensor.device[0], dst=otensor.device[0]) + src: int = itensor.device[0] if len(itensor.device) > 0 else None + dst: int = otensor.device[0] if len(otensor.device) > 0 else None + super().__init__([itensor], [otensor], src=src, dst=dst) def dispatch(self, devid: int) -> Union[SendPrim, RecvPrim]: if devid == self.kwargs['src']: @@ -185,11 +214,103 @@ def dispatch(self, devid: int) -> Union[SendPrim, RecvPrim]: return RecvPrim(self.output(0), self.kwargs['src']) return None + def volume(self) -> int: + return self.input(0).nelement() + def __repr__(self): - dscp = f"move({self.input(0)}, src={self.kwargs['src']}, dst={self.kwargs['dst']})" + dscp = f"move[{self.device}]({self.input(0)}, src={self.kwargs['src']}, dst={self.kwargs['dst']})" return dscp +class RDScatterPrim(CommPrim): + """ + P2P Cross-device dimension scatter, non-differentiable. + + Tensor[Tile0, Tile1]: device 0 -> Tensor[Tile0]: device0, Tensor[Tile1]: device1 + """ + def __init__(self, itensor: IRSubTensor, otensors: List[IRSubTensor], dim: int): + """ + @param itensors List[IRSubTensor]: one tensor at device of `src`. + @param otensors List[IRSubTensor]: each ran hosts one tenor partitioned by dim. + @param dim int: the dimension that itensor will be partitioned + """ + src: int = itensor.device[0] if len(itensor.device) > 0 else None + dst: List[int] = [t.device[0] if len(t.device) > 0 else None for t in otensors] + super().__init__([itensor], otensors, dim=dim, src=src, dst=dst) + self.signature = 'cube.runtime.adapter.rdscatter' + + def volume(self) -> int: + return self.input(0).nelement() + + def __repr__(self) -> str: + return f"{self.outputs()} = rdscatter[{self.device}]({self.inputs()}, dim={self.kwargs['dim']})" + + +class RVScatterPrim(CommPrim): + """ + P2P Cross-device dimension scatter, non-differentiable. + + Tensor[Tile0, Tile1]: device 0 -> Tensor[Tile0]: device0, Tensor[Tile1]: device1 + """ + def __init__(self, itensor: IRSubTensor, otensors: List[IRSubTensor]): + """ + @param itensors List[IRSubTensor]: one tensor at device of `src`. + @param otensors List[IRSubTensor]: each ran hosts one tenor partitioned by dim. + @param dim int: the dimension that itensor will be partitioned + """ + src: int = itensor.device[0] if len(itensor.device) > 0 else None + dst: List[int] = [t.device[0] if len(t.device) > 0 else None for t in otensors] + super().__init__([itensor], otensors, src=src, dst=dst) + self.signature = 'cube.runtime.adapter.rvscatter' + + def volume(self) -> int: + return self.input(0).nelement() * len(self.outputs()) + + def __repr__(self) -> str: + return f"{self.outputs()} = rvscatter[{self.device}]({self.inputs()})" + + +class RDGatherPrim(CommPrim): + """ + Gather tensors from remote devices to a local device. + The local device doesn't have any tensor + """ + def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, dim: int): + src = [t.device[0] if len(t.device) > 0 else None for t in itensors] + dst = otensor.device[0] if len(otensor.device) > 0 else None + super().__init__(itensors, [otensor], src=src, dst=dst, dim=dim) + self.signature = 'cube.runtime.adapter.rdgather' + + def volume(self) -> int: + return self.output(0).nelement() + + def __repr__(self) -> str: + return ( + f"rdgather[{self.device}](" + f"{self.inputs()}, dim={self.kwargs['dim']}, " + f"src={self.kwargs['src']}, dst={self.kwargs['dst']})" + ) + +class RVGatherPrim(CommPrim): + """ + Gather tensors from remote devices and sum in the local device. + The local device doesn't have any tensor + """ + def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): + src = [t.device[0] if len(t.device) > 0 else None for t in itensors] + dst = otensor.device[0] if len(otensor.device) > 0 else None + super().__init__(itensors, [otensor], src=src, dst=dst) + self.signature = 'cube.runtime.adapter.rvgather' + + def volume(self) -> int: + return self.output(0).nelement() * len(self.inputs()) + + def __repr__(self) -> str: + src = self.kwargs['src'] + dst = self.kwargs['dst'] + return f"{self.outputs()} = rvgather[{self.device}]({self.inputs()}, src={src}, dst={dst})" + + class CollectivePrim(CommPrim): """ Collective primitive, non-differentiable @@ -221,8 +342,15 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k super().__init__(itensors, otensors, **kwargs) self.signature = 'cube.runtime.adapter.all_reduce' + def volume(self) -> int: + """ + Use ring-allreduce communication cost + """ + ndevs = len(self.inputs()) + return 2 * (ndevs - 1) * self.input(0).nelement() // ndevs + def __repr__(self) -> str: - return f'dev{self.device}: {self.outputs()} = all_reduce({self.inputs()}' + return f'{self.outputs()} = all_reduce[{self.device}]({self.inputs()})' class AllGatherPrim(CollectivePrim): @@ -233,8 +361,15 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim super().__init__(itensors, otensors, dim=dim, **kwargs) self.signature = 'cube.runtime.adapter.all_gather' + def volume(self) -> int: + """ + Use ring-based communication cost + """ + ndevs = len(self.inputs()) + return (ndevs - 1) * self.input(0).nelement() + def __repr__(self) -> str: - return f'dev{self.device}: {self.outputs()} = all_gather({self.inputs()})' + return f'{self.outputs()} = all_gather[{self.device}]({self.inputs()})' class ReduceScatterPrim(CollectivePrim): @@ -245,24 +380,48 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim super().__init__(itensors, otensors, dim=dim, **kwargs) self.signature = 'cube.runtime.adapter.reduce_scatter' + def volume(self) -> int: + """ + Use ring-based communication cost + """ + ndevs = len(self.inputs()) + return (ndevs - 1) * self.input(0).nelement() // ndevs + def __repr__(self) -> str: - return f'dev{self.device}: {self.outputs()} = reduce_scatter({self.inputs()})' + return f'{self.outputs()} = reduce_scatter[{self.device}]({self.inputs()})' class BroadcastPrim(CollectivePrim): """ non-differential reduce-scatter """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], src: int, **kwargs): - super().__init__(itensors, otensors, src=src, **kwargs) + def __init__(self, itensor: IRSubTensor, otensors: List[IRSubTensor], **kwargs): + src: int = itensor.device[0] if len(itensor.device) > 0 else None + super().__init__([itensor], otensors, src=src, **kwargs) + self.signature = 'cube.runtime.adapter.broadcast' + + def volume(self) -> int: + ndevs = len(self.outputs()) + return self.input(0).nelement() * (ndevs-1) + + def __repr__(self) -> str: + return f"{self.outputs()} = broadcast[{self.device}]({self.inputs()}, src={self.kwargs['src']})" class ReducePrim(CollectivePrim): """ non-differential reduce prim """ - def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dst: int, **kwargs): - super().__init__(itensors, otensors, dst=dst, **kwargs) + def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, **kwargs): + super().__init__(itensors, [otensor], dst=otensor.device[0], **kwargs) + self.signature = 'cube.runtime.adapter.reduce' + + def volume(self) -> int: + ndevs = len(self.inputs()) + return self.input(0).nelement() * ndevs + + def __repr__(self) -> str: + return f"{self.outputs()} = reduce[{self.device}]({self.inputs()}, dst={self.kwargs['dst']})" class AllToAllPrim(CollectivePrim): @@ -278,8 +437,12 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idi super().__init__(itensors, otensors, idim=idim, odim=odim, **kwargs) self.signature = 'cube.runtime.adapter.all_to_all' + def volume(self) -> int: + ndevs = len(self.inputs()) + return self.input(0).nelement() * (ndevs - 1) // ndevs + def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs()} = all_to_all({self.inputs()}, idim={self.kwargs['idm']}, odim={self.kwargs['odim']})" + return f"{self.outputs()} = all_to_all[{self.device}]({self.inputs()}, idim={self.kwargs['idim']}, odim={self.kwargs['odim']})" class ChunkPrim(CollectivePrim): @@ -290,8 +453,26 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim super().__init__(itensors, otensors, dim=dim, **kwargs) self.signature = 'cube.runtime.adapter.chunk' + def volume(self) -> int: + return 0 + + def __repr__(self) -> str: + return f"{self.outputs()} = split[{self.device}]({self.inputs()}, dim={self.kwargs['dim']})" + + +class VChunkPrim(CollectivePrim): + """ + split value in n chunks and take idx-th chunk + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + super().__init__(itensors, otensors, **kwargs) + self.signature = 'cube.runtime.adapter.vchunk' + + def volume(self) -> int: + return 0 + def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs()} = split({self.inputs()}, dim={self.kwargs['dim']})" + return f"{self.outputs()} = vsplit[{self.device}]({self.inputs()})" class AllReduceIdentityPrim(AllReducePrim): @@ -304,7 +485,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k self.signature = 'cube.runtime.adapter.nn.allreduce_identity' def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs()} = nn.allreduce_identity({self.inputs()})" + return f"{self.outputs()} = allreduce_identity[{self.device}]({self.inputs()})" class IdentityAllreducePrim(AllReducePrim): @@ -317,7 +498,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k self.signature = 'cube.runtime.adapter.nn.identity_allreduce' def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs()} = nn.identity_allreduce({self.inputs()})" + return f"{self.outputs()} = identity_allreduce[{self.device}]({self.inputs()})" class AllReduceAllReducePrim(AllReducePrim): @@ -330,7 +511,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k self.signature = 'cube.runtime.adapter.nn.allreduce_allreduce' def __repr__(self) -> str: - return f"dev{self.device}: {self.outputs} = nn.allreduce_allreduce({self.inputs()}" + return f"{self.outputs} = nn.allreduce_allreduce[{self.device}]({self.inputs()}" class ReduceScatterAllGatherPrim(ReduceScatterPrim): diff --git a/cube/ir/cten.py b/cube/ir/cten.py index d99a29bd..f46dc154 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -404,26 +404,15 @@ def comment(self, info: str): assert isinstance(info, str), "comment only allowed to be string" self._comment = info - def __repr__(self): + def __repr__(self) -> str: """ Cell string presentation """ - inputs = list() - for tensor in self.inputs(): - if isinstance(tensor, IRTensor): - inputs.append(f't{tensor._id}-dev{tensor.device}') - else: - inputs.append(tensor) - - outputs = list() - for tensor in self.outputs(): - if isinstance(tensor, IRTensor): - outputs.append(f't{tensor._id}-dev{tensor.device}') - else: - outputs.append(tensor) - dcsp = f'Cell-{self._id}({self.signature}, device={self.device})'\ - f'({inputs}) -> {outputs}' - return dcsp + ins = [t for t in self.inputs() if isinstance(t, IRTensor)] + dscp = (f"Cell{self._id}-{self.device}(sign={self.signature}, " + f"inputs={ins}, " + f"outputs={self.outputs()})") + return dscp class IRTensor: From 249794a28d1f80fa9acaa89c50c8084dacce6770 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 20 Oct 2022 04:57:42 +0000 Subject: [PATCH 1077/1892] fix typo --- examples/alphafold2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/alphafold2/README.md b/examples/alphafold2/README.md index b83d8fa5..8f6e2264 100644 --- a/examples/alphafold2/README.md +++ b/examples/alphafold2/README.md @@ -44,7 +44,7 @@ peak memory - Triangular Multiplicative Update using Ingoing Edges: $r^2 \cdot c$, where $c=128$ - Triangular Gated Self-Attention around Starting Node: $h \cdot r^3$, where $h=4$ - Triangular Gated Self-Attention around Ending Node: $h \cdot r^3$, where $h=4$ -- Pair Transition: $4 \cdot s \cdot r^2 \cdot c_{z}$ +- Pair Transition: $4 \cdot r^2 \cdot c_{z}$ # Results From 333183ce03b46d11575533263991281fabd6f5b9 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 21 Oct 2022 04:11:45 +0000 Subject: [PATCH 1078/1892] refine doc --- examples/alphafold2/README.md | 99 +++++++++++++++++++---- examples/alphafold2/images/evoformer.png | Bin 0 -> 159728 bytes 2 files changed, 84 insertions(+), 15 deletions(-) create mode 100644 examples/alphafold2/images/evoformer.png diff --git a/examples/alphafold2/README.md b/examples/alphafold2/README.md index 8f6e2264..693dd934 100644 --- a/examples/alphafold2/README.md +++ b/examples/alphafold2/README.md @@ -6,23 +6,20 @@ Benchmark different schedule plans of Alphafold2 based on MagicCube. ## Structure -TODO +An evoformer block is composed of 9 sub-modules. -## Challenge - -TODO - -## Problem Formulation - -TODO: try out del tensor in functions that to be recomputed -> offload problems to jit tensor compilers - -strategy: detect memory constrained parts then coshard them - -large enough size of input shapes already utilize accelerators +- Row-wise gated self-attention with pair bias & Column-wise gated self-attention -> customized attention module +- MSA transition -> feed forward network +- Outer product mean +- Triangle update using outgoing edges & Triangle update using incoming edges +- Triangle self-attention around starting nodes & Triangle self-attention around ending node -> customized attention module +- Pair transition -> feed forward network -should include coshard into the dp formulation +

+ +

-## Memory Consumption +## Memory Estimation notation - $s$: multiple sequence alignment (MSA) number @@ -46,7 +43,79 @@ peak memory - Triangular Gated Self-Attention around Ending Node: $h \cdot r^3$, where $h=4$ - Pair Transition: $4 \cdot r^2 \cdot c_{z}$ -# Results +parameter +- less than 1M + +## Challenge + +The core problem is: the evoformer consumes large amount of memory and we need to find the minimal execution time under the accelerator's memory constraint. + +According to the estimation above, we find that the memory distribution of evoformer is different from the classical transformer. Using GPT as an example, batch size is 1 in both blocks. + +| Model | # Parameter | # Activation | # Output | +|:-------------------------|:------------|:-------------|:---------| +| Evoformer (Alphafold2) | < 1 M | 5120 M | 66 M | +| Transformer (GPT-3 6.7 B)| 192 M | 512 M | 8 M | + +Assume the data type is float32 in the following analysis. + +If recompute (checkpoint) is not used, the whole memory usage of $n$ evoformer blocks is around $10 \cdot n$ GB. + +## Problem Formulation + +TODO: try out del tensor in functions that to be recomputed -> offload problems to jit tensor compilers + +strategy: detect memory constrained parts then coshard them + +large enough size of input shapes already utilize accelerators + +should include coshard into the dp formulation + +# Experiment + +## Usage + +**Two steps** + +1. change var values in *alphafold2.py* +2. run commands (add *torchrun* to **PATH**) + +**Training** + +1. Evoformer Stack + - shape config + - bs, s, r, cm, cz = 1, 128, 256, 256, 128 + - bs, s, r, cm, cz = 1, 512, 256, 256, 128 + - bs, s, r, cm, cz = 1, 512, 384, 256, 128 + - other config: dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 4, False, True, False + - policy + - spmd.PASSingle + - spmd.PASDAP + +2. Extra Msa Stack + - shape config + - bs, s, r, cm, cz = 1, 1024, 256, 64, 128 + - bs, s, r, cm, cz = 1, 1024, 384, 64, 128 + - bs, s, r, cm, cz = 1, 5120, 384, 64, 128 + - other config: dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 4, True, True, True + - policy + - spmd.PASExtraSingle + +**Inference** + +- shape config: bs, s, r, cm, cz = 1, 128, 2048, 256, 128 +- other config: dtype, evo_num, use_chunk, is_train, is_extra = torch.float32, 48, True, False, False +- policy + - spmd.PASSingleInference + - spmd.PASDAPInference + +**Command** + +```bash +torchrun --nproc_per_node=X alphafold2.py +``` + +X = 1 when *single* in policy. X = 2, 4, 8 when *DAP* in policy. ## Training diff --git a/examples/alphafold2/images/evoformer.png b/examples/alphafold2/images/evoformer.png new file mode 100644 index 0000000000000000000000000000000000000000..c3f30249b54932b7bc12be52d538a45f9b1b6593 GIT binary patch literal 159728 zcmdq}g;!MV_dbrJfOK~YDBa!N;2=n+lyplEAYD=dg3{d}A`OEyh@^B6Idpe@59;$g zulHyD{(|4VmgAb4zV^QFW5i34JO(N$DjXafhN6OuIvgA_JsjL)Ipin6C#KrI z1;AgAT-D{J;3@{mH-R?@mXa?d;oz!b(QZr-f%hm*3c9Xva1>z=e;&oI9a_M_-RCRH zNNRW*?aU+UXy{TsaVn75+2hrkbl=tV7_Vt+^4Z-ZY-%EaN7Kh|iuz_dUSsI<0$pRt zr&TbypM%=~$Nb3?u>a7pf3WL5I=9Clwa9*&84}I^vzSc<0>c0ADvN|7{6ANHWeUan zzZZq~mJoxU{y&$M{{Q2T**#4=z^ucp{GQ_Q;$^<4nE6m5C1XN&7^pl@DL2UM!`KC* zwwi|36e6fUc6ku#@{P>zW$yK8$ocrjYYvRVz#mL-afjkybW7%u0g;cJJ9`m;1+D&t zAuO`DoG^!!dlSvMrv;C1BDlfM zkB#$t=^DKc#e2NB`#tW61e*0VDBVy-Ddb{$%cQn46O8p&Y{5GL?pg%gk{va#5$*<7 z8SISA&(XEdxo869p5I0ns0%*T47eh$z7~TB4&0frZgpw5K}MxSqQ9qA-LJn~g+Y)t^N`kp49{dI;WQXICxV zXBv8(l56i01tW7bigSA7M~wuVvp_KgWDVYrf6E`=bwMG0e-jZ`qd=Xfo*ITH27Oug zbGRrWXS6w0uD>4%df3>lP^v|%S-R`VUdLoAAbDSSVED5#1-bDtt<#{p$J=2ufi^7} zOMyp{LsbQSE&}-E5!`(9;t{{omZ3KClatgiElUwGyWN3AI)Vj#Ce|$$t2otrws;BZ ztSsp}Xd!Kb4D}ytGWvx0&b!h`4(?6Ww2*zZ`q_cv=Ty9coK!*FZOm(UEDF1`ORT92 z^6f(O#9f$!I1503Nqd%OMyTETIMjLk=fIhJdTkK3Zfjb74M}veJD_)8?7#UChlj%%G{?PJKhpq+;kF=y@`=A zzCGR&HWe~x$z-^n-}b5*@a@K(D_Nll^{OI@L6TI13Ky$y5}e zUSee1F^bmBy~?p%#7WuPil(P#ml;-0TZ9G@xXwJS1R?g|idNaJDV2%6#iP~+T&P-> zxfU$h;+?gRO|PO4A@XnK+ij?v;Fof-fAe;M$YJbn4S+Ruz7&9xfEs>&|7QJ5#dh8M z0`hSv9^Agr$|ZqEmzE`2P=RzR|2Ot*0uSuuk<6nq$y%^l1MZ=FsL)`zG%`(WAvdKl z!^>fF-0RDNb%n|i8hL-Lj%P(VaZc@26vNl&_`^jS!-8}YBIQ$ZQV&37ws1L&qSggO zvMV{6J_!j$pkr=TT^OjWw{VokE52gNss9280ou1f&5)#rlyqlM!zM=VX0vH#xTQ|) zy;OW7lGuUb2rkHW5n|cU5^(?56Diw4}`ng_O)YSvV0pS zb;ESKZC;56A)Hy29d_#|KQM-G2^S5EbCNP0qZC5*FSzPnLG_OG4N`g<;>oH>Iz_Kw zUO(|U(SmdBeJeQUj#OpX7~W9wy+AWkP!xTLo=EJ)=@+f$f%_Vwg2_Qt`9t0@p<7op z2$ucUD6uog1urs=r$4hdzGA_?{IU>?N*_!y#>9)eVH=yx>USw8K#h#Gt}Ovh!tClc z9;@M`J?+Ro9-&xeiKlkRvGVwt3yvhvtGMHDT6K!ZDL!c509i?ha! zjl&RPuKd^`6@;5MZ0)>_bO8T+1)O;w70>f%49V9y=Kvw*fCd=N??1# zZqgJF&Tkt}M0JCI#E_c9>gkV|+5cjxuQa;-Hg<+Yv%fcdVXR%(hxpTY=75V>!e8PQ zvI5i{#y%u}*+^tgR!{0M*DLVqt_Qy9akDDGW^TQP&ay1;E{2~m2WzcpnO-}1lx2<} zK7VD4!)_^W=VLD7RwjW-U!-9pU~c>Fll-1?LZ+l2f0eBYf%$MrR{sbF(&~5bDIk6K zSV}|2YzWA!$2$^z`46|M@AiAbjC2+nkUvD6wbL#=A1ZpXp!bfcZ~P@abBxjU;7S8e zz%{w0);C!rc>*4inb+6M&IOmN&e%p7=q2NUB!VN78DbY@3{tH4I=sc9IB`*;I3f5gK?E`)5f~ zIrX%#M?S}mt)@Neb-fPo_?WF_!>bW9{nGJ=*c|wjBw(Ma-PX$zUy*JG%}k~j@g@Gq z2KG_?*sx&U_4>+{Q&tpYs3^YqfZX-pUE( zr+6zwF&44izLq~%;XkL{O&9`7ATh*>u*3}x+fIENBo}o;AZI$8M07!Zf-*GNWGeBA zXV_{qt8FU|TNUut(r!pmj(-bJala|7);5li{lW6aHjzT-Fx)g;8^OzlSk3tV@pv+jzgw zJ7wdP=i*-d#CU(WvqYg%jau=AC2`}PB^CN$ZV2B7oF|xhQ|b@%z*pp?1R;Z+3Bn9E z+KRQ6HP}C)k$ezR*ast%DR8j;L#%UHU}6(IKi^%nVs@fo7)xM2r9@foe(jgm>rYa@D0M7+q~ zS%!t?m~&w)uUQ3XPQJg}Laz_3l70J9FtKAUPSHj2sYCsVOvjV-s*?}XD_m!r*o|*S zC9ajJ(l}JEeo(^+^Qw=zqNG3l*xXT2(3drF8;UX+NWZ9h`b(T%^Ks&5mvKA&kf$Q^ zQ|Ozr5?Kp2xgNWz5eKf}KUYw{cl8m(^D3jDi(B)h-g+b2biZoL{$VPTOuFh62wAim zo#;B6%^q5=K7YMBVjWk!>)y=abHiR&Cn*;_g}D-V(tL{*Q7%8Gpy!73(|uWDxFe}3 zRee#u6$d1-WdD<41x|BWNavKMo)I?9rJ;(Am9NhysbwhiZSi>_`^kzqsKImCd?&ai zryqRQy(A)J#Z}XF_63X3p35eSnFp09uh0_k;$|x2sqbl(?jks%yu5QOejSP+T=Z%z zC`$KYZN8vDpV2OmCR7y8AaQIF;e$Tg*=7BZi zXB(a)e$*H@CJKwxA>I>X{Sz%R@kFlB;CDV#f;n_9bR!Zh^42$*`il=?Qnq`%F;wrm zfeA_%3VD&rXEK;3n8IQ5xcfax&`7R)wGInmRF`R$T|EsQc3qbe!XcVzZ}IRhK2Yj$H2Lj_wA@%yAO(B2py5b%g)a0vo1$f zBrQu2$Fdz?+TBd-GCo_T|{z78}pLlKZY<5Gj9wje5ql4m&wWE41Uv5j$#VVUj zjT$k0&{475=$2qVZWRz;avIthjwpwBnUq{4*>+z2@#?mm7B;{%=8iwK*ay1|t4Mq4 zrg(7TTQnpvacpVz_L5`g+uf5#M4TmYo}lq6mCMIjwr?=TiuF2nOQYyzP6TC^of%O) zqkT58Of6bR6^d(veAi@=_g+v#5m;S5+ri_$QwD`(zbcfvl`V8WT$cTHlpZysIR1s5 z7o!dy-sw$e<<&+UTn)alfw-N`$H!ZaA$TXTYwKg;OIi1yS?dt#UB&8a()j$#h?wSaTYkx z=wZe^BO#}iW*TFzp>y*a&wM zabZM!6Ug|djz5BShF|%H-}mnBi0skzw%3^l6~+wt`;vcLt5XJJKJ-6ZzTevz%8V`7 zuWSGH3(sSJL000swx2OQO{63z9tgQPYWxLPhu9neuErPDB&^moCAwWnYw!0!SFenB zA75+Uhx>MC3I`F5Fj%XUm(1gz-=0#(!Ev&uWXe}~M?>2zMbCMP-+{7!mbIcvT9lC3 z^{v0aKLT+C zd<(~m8pAER7rz!T&j{964#VbQaUcx)S|fq8flE!VBc1#ZP@!xj%cnS4n9+i4)NlYMYEycAcVf2i7y{x!&s=p8my4>uRT6k2+EDG;|k<9oEA&xkj|1&-6 zPbs*&reUM|G#>HuPPH|(ce?&F9fOd^>!~dr3o^kO($6T9pTP}*D!Wozw{kQx&=qVd zTODt2yNBX%7Ofz68xsKqtb zBX8sn%+40aB^Qa3sB*LOSf6&HSrCE_JmjyxT|qvz)QzB$W{de?$gtfmnyv#B3ef2N zE9!PALW>8HR<>nGnL-xkg3zmccbP4bwPA2Q8aR{Wd`@+0o?4XWs%LT6sq=2~`A=8N zsz2#JEh8u(kfl>)3LDKwOG{68zP)n0C zcJjkV-O^Cd@ZdG-KgZs}!*m-!j}~H`X$Y&1c}ETk*EVZt_l1*;UC(YtCQAH7?dt&a zgKx9Axwib}2d1numlp#A_~~-<1~D-;KQ46@io#MP;a_}oxH!A9Q4o*}Q*ql^x!EpF z(Pd*C+e$2|TCFu*dJ3EOy=biL{FY6j)}$k6-LAbB2cJG~oWlCcaW#p%Boz!x-k9;823O2)`y1G}tWN!h zEZM>^bJtC4W=$@P{3EgUC$chfr356?yUQxU?eTlMC;mZs(lzmeIt?!QrsXrR0QrDS zr>tYuD_8F9e#v~aLlPr+S81{Ur^k;gO$|6HzIi?A=((pWxY>JW9?icLedlaMv_flw zAs~L-6Hlp|tvvr&e=T29i~{EoCB~>G8H!~{594@A^nvFyBdcns_CjA+-hS%Vn0?hN zlX|ew6u~DRlMWQc9AXeMVX()!FBIZJ8tPsaZAyYCQLf%bEpJ&*R$F+`bEd_ZJ@H0< ze`QhSOd%11^f#8Oaw_Fw$Qf@Jed`?NU)DJo-ki-k_NH*=-Z`c%y)JQH>y3xMy*}?b zTzs1A+LCpXsSU&YC+x>K_$U$pd>5VKn=}9q~hSVwP z3yy(f?#F}C5|3~|MRkNG+Bx^Z((^I`ZC0^<&s_v0o3C%P+O4VF-6XuDNv&ralrw=@ zO5om*oGuc(&=A?J}M|5M|C3|%(tNVfW*1bSitSkT{Urf6kuOSV@*h_ z@Jxiq&kfRd!-B^O!Ek|lX70L3MQ>#}wSXGt_jSo*)!z*y*giy!8_WEiNcuR!9+kAtgbGt{oqC@&#G}` z_&kV`UFNww?}st{pf9m}*jT=E(Ao}okG%VDK+Qs!?|RR5b~RZ=+#Ob+VScotrpuRx zeej*A$z!+8nvun1^%n+9MG~ovCe4^W*{JJcTRuqSc7-`6VzR@)@L>uZDZ2pijWX+r zF#+6aHS>*8Jx{@ei~i|C8bfW6=h;uIaS<9C=BNOtcKV%LWk#phNN!)fb#6sY#A5;} z?Io(wKXYtwrb*y?Sq`R20??vTYulT|oSay6a(LKwe|Ii^sq%1}t9V9-?E^phnfzDM z>CLF?i^#e^#C3*OqOI-{o|JU*LT@@nkYCGS(oYp9WH0ZunsC{ZJ5jhYKM&GNptPS0 zEPMxGJ#uqU0|Bm7eqVTZ>3+(b7;H`LgdK*Q#D{d1j_=R-WNA5t=8N`bT+^gUGpZQt zRN*nPF&FjN;oWz9>!WeIBb_6XWV&Q2nj})jQuc1v&2NTn{f8TSQFz_LxG=;tD#j3g z+)yH*5Jy}4!%il@2-cFm`9xbM)t25s5ce`)vb(#~x##5ZlGS~n6Mn+XEy(ZIaSd5K zOhJUrs|m%e@r+}j0Ng(QX)RvE9_cy~4SS9kLI@fb5Ai{J$rw+HqAbuhX}%cyRAhF7 zRNVSo@ARUy_dLAgkX!=-lNEPeXMudTFB@_t=9!Gc?eUkWmjaR^j#W$wB8d(!(X4l) zzHfe*YZb5G&P>LfCgtgN#Oxb!YEffHZ45(9!j(+_f{>;{#S)ldFKv74IlR7ug|*qA zdR{3p_gR+^<)qSV!hQB|3?S@($C4$w2g1tM@fO~ z91#+i1ZP7Lybv+nA_=ET>hFIKYrC`p0s;h_ml2+z6Ti#!znkVzeeLKNyFQS*@*|1) z=#z}h1RgvvAu=G=L>9VC%~0ID7f{tL!;l*?e1D1OlPwn$Y&`9%9)wPD&H)6>qdLdM zkLvN%$^ymXHe(fr`bsLe4=D2K6*@iaKBkGQROCac=5Ri10AKEUhed}FOZpfE0asU$ zQaW*K^qIPj+(I}x=q3G&&^*k|1NpE+okkNhzVfuty*2;9Af58;Zd2=6e`*y0$)Wob znM6=Ys8TXvzR>%wcHZqJk#Z`VJ6z-Y>k4?6w=)rzN9K;h{l2hx{rf4NzL!OpRxIxG z!Kcb}OO@)I*?1f7v$SD^K(c#h4w^~z1z&A^d^}D#FZFa`#VRvxv}xuSz9Uj$$o3bD zS;DeU!uY4h&jR%lFkkSCBrUjAsu;<>)F#el4@D56!2=B+i$_P~CrFi#Qrj0}5;r+v zObDt%^S?tbC3}QHZpu%i6cL6^h1GZ0Ax7 zefN|Zv{5wicY-}bMijSi8ejifchlj>sk$?ef2Nv&y@ec#hwVh#akPChWQbuuK`CH! z5gVH9RGg}_=+84uR~|P0M&AN_KoW!-aT?1ki^D_3NoC9L*gbsu&7^I~|n;>ik&1t-t8 zC~}8(M(F)@E-c{{QW76VWSwh!yPN(O`+m}&6na!%xSN^MjRjNKA3k`tVz9m{f$Eru zQHD+?JxnzjwD26Y=34#5x5kQT?GMShJceuS^??tJ#iGJ{kASk4bq_z|gM1X}!A0Eu zB%YP`1MgzVJzi!wWPQHL1NN&hHqje@g-$XM0XCA0JmGQ08Hc| z_35W+ruU^IcM|F0D3SE`Cwlbkca^CEf66{Gg_Y^w0_9mh+Kg z0Hy%ucgp)Aumhhb0wvOS1samcG~!!=yWK5ug8sVSZ|8|M>G?m9-rZ?PrgJbeVNAVG z4ON6y2IHNGrIQ3Lqny+!s$oQR9--V#F3mKFe0?@XOxcvK6qP5)l(mg(+AZk88q5A~|3Hf;Tq-yfx%WXlEx=X~DxY--n4LzruC1`g{8JL)%e%aEI zkE1!)4g^_BlMK9n1yrn>&mp6;r)3eHpwk!dOGC%bvvIit)Kp+AN2@tmpE# zz8O2E`?-fNZDw8w8vOR2^L!8zR$c>q`rY}k>^1^xuSegUb4s?fYi**RZ|07W-E$parxM6p89_v0J;=9})I zi<29=K`-(il>JiA6S1Exu6c-;t>5RFAp&kckdht6vHZJ`%^nd!H98-ml(Iw~HH`+{ zP%cLfF{T}e#Lg}@!l6R01;OL=&Go)GM)aaidlON^euq$MHgC`6qZNglo;TlZ=UQOV z36vYo+tXnO^nK!@Jp-o{XZ^irw93fex0{Hd<^BsBj-nl5y?5eWmf|A52iQYv zWPVfAn44o~4E)ytCk+;eZTDS<28W^{uf&p_uf3+0A|}7j4L>glMm*0I1O3i8cbrm= zs9g#vEtA1d-1H)%qHDP_G9Picp$OCmIG8VNBBhGoFW}&)GhuvKEeaK5BRViOrZ*o; zp_oi&x}!>QiQYt@F{A05D%EE1V8_z^YTg!4E58c-#zWRs2Of0hQ>qKy#q{0yoFNgs zE3LMn){a@^gbn*NydXTmW;^{QQO^7h$I|qv)u98ski`VF@NADmuQT{TI6n>u1k#(z zU&+(+wAP!L<24)vd`Uy~oF?*f3L1Uz#Rl)-9qx$^hmxZ9_wLo39@HO%tj6D4CYel_H?oNyqicCSE8eu=OQJ zJHVgFW=J}eDV`jQi{smxf<=$5mY6Q)OG?6J9_ObBaCO?V-C24n>faLRAK!I8#aA+Z z#C;bNVUNtPj=0p8acE_9MBk~6ROxMLSs2JSGQ=?HOTgBWpc)0LJkQe4Kra&6F1Z_d zD5*Gz$7g3B*U3Bg%RX^ic88!5tqCCF=83_q#^F;!1|DcEa{}4#wuo`~a@Egn@ zJP54bBk%J))&a%vh6qugv)*zU8Jr%c$7(FV^n+8H!vr>*?ZW(9uRkyGrzg9!+MiRy ziP$Ndo10g)XmY|$1Oga7HUKj>fX-6i9_%l)G|egfQSA4txo^6V+lO_A(R}5wFf6Jc zz&#ebAMUZdu@MHoJW>=t(XX|OxWC=^Kf?JRJ!rPkkdY-WkJj=VB6DP;+7}m%#iV5L z#egxNX?%AynX>o}Fzky3JlhZaz50Jz2|G8D9UoyL~+|9*degm3^hX-IG*!Dx8kC-9i6}Z}$pe(*5A4|zv_nS5S zHAh#PZuJ{QSE9|){EkxXO28B}7U=ZYU60+iCzAZ?iI_ByjUnK8{XaW=f2?f^difGt zDT(R%fXO$9c|AY|b-)tw(ZVjP;V07b3kybiH2-^juhv|%7hX(E%&0+VY-}uGR(wJr z0Qh$N@2{yM30awLB{Bt^bN;Wng+npO=msnjnYBR<)HJA3z)Xv*)}-UfW|&unK?621 zF~k4SV>*W4ap9@g(duYgH^le)K@WfyVox(v*#sg#52spH{>buwwXhokb~N(@F{)^I zaN=uz$4{REdveB_nug)HQ*=yTeB`6u>WZ#pMn*{1c6Gy_!6# zrH(MHWc@$J`SWTw8AmOh*AT5$uGhyZBa;od;j1JtzYEbT4{d1?1sp>2f5orxIp0$} zCd0(U^gLV+c`(t#6pxI@q#c1#cM52p8@kSb(7Nv^911wHG)H;#NOIKe`7Ay$aWMO% zo#T(4+t#yU{wOjpv0LeNBQNgQAVYZnxR{rkS$Zg4So^__501iM(ke?i3Gx9F*(GYX z{&=6vPWS*zZcs2p&#;!-C*MNhy65q;dwWWyIvw`g{;ziV%r)48r9!!qs{HQ_mm(;_ z9GBZx>W-MDCZ5stvm5ziChuwwTQ1`y+Sfe#)Xyqc^5K}km|1IVEBQx%T+W4vIYSFF z;&z2pYMT#-?KlbfAtYAY@n%gsq^xiM)-p}#8vA%ML%Ij&`ej)gsX0qEGNbaBA^pYzPTwAbtp|bznsP zvRFeQ?&+hOTQff{kjh{rhg64=a9mBhdRIhJ+*Ngcsi|=jK~))=Ka;#S!QDVk&lkz`}mufh7~#%32~cnk|x(fF8Z zjb~*rcAP-mfRH|A--FfG^P^JPzd3AvZPet*qs}Id6TFnB4l7G+lH)mWEff(+t`Q zn`42}?zH_7|4B(Z{RWq{yoMMIGQM6QHh&9M@FG=7{7W*cK3$zmCE8?0_WW>*aXLqz zS{F(5^-gEr^M%pQiCoKl0YNK^yUfQ<_api{A}-1^6Ot-)k(zjkMpRW@a)t5N&c}Dm zmlG$f4q*;Tc=Jk=uQZ_EY*ylS?Ca%%Q3;lJNwq{vVuus7bOYiS{%zx__a^Dri^M#; zDJNH8Rvwz#k0jiS@Q#bVANK>=X(5?EoduZ@KPOfcB0mk+bn;dec&|CjrYw%P;Cu0N zT(?XCEbeu^hX-a8>;4c;G(5fur^KgwJ`ShRM`xQq8Cr?*Hv^`>l5oFKaHXISc25M- z^dT7qr>iZK{azb4d#X7^P-eCwnBGXSclE`Ad9`0(p&E;L-e`%X+O_Q>Rpx&QM1G&r z`o>iFaTV_)`bx;&bQIQTE%Ev!NF#)gtjce4(R1ZoQ7qL|@Ad0A9wMA)Z`qF?r@ zCl0=xj0&~N2+Dl>O)(`D9gu5dyP7sl{Pr35rA&@e)o>KY!x#(B0Oi`-rhBvr zjO)F7qZKSy0z~d*pU(P_43lxTWeU=}}8~rw(KT#Eq`MLHnr8$&DiM^d> z^7vx$o>+$VFEY*^K|7FusWI%W5csN-KLAIK;{g_X+k%kto!a+X@7d~fl6924XIvIJ zuEQsHYbhojS`&wIres*!j3B3S36rl_!iF^Q10xjetL!5VPjPt7ubelDyV!D69mOxO zT(o!%MT1Kjf*aqi2rBXNV`N&e)e$q-8GV#O!*6Q7zd%e#z6!o<7|_K3Yqws)wSiRa z&&LA3SKsEOWWfRdY;P(= zqIp_VK}Hc4v;M1ny|#76PgPF|=qB@C8WMS?`7mTkR*m=bQ;j%xiX}{HjRm@Gta1CP z=$|hlb4M7qhLLG2g~M0e98=xdZQ2~oL~0T=4~l;n`se6mq!+8+x43l?pd>u6!ybu^ zdBArryD5%Mq|>P?vo*qP&)W$qg4xfIri}f@_dcEG&6iB57~YMF`^~o_;RMfr);Yt( zFBukmmyYh8Bd$mEO|;Zf>Zg>&uFQ0I-4Pih!;Fw{#PlI+AHa#rkE^am5-kc!8N!oX zFumsz{u4tt=n2M$brHH3pP&6J#wgGMZz9!HS69!OdI9Mt^@*@xzc#IEnJ4)sHn)=r zmG(_0k5`MW`vU5{>Z2nP=S*Ze;{9K2AHRw3boDjWE0yIuy{dQ4q-y&93J?cdpO0Al zelfM8V?GQ$vrChxQe1w|rXQoS4w;!FOX4p;*1kU5nI27(2b!fdLYk`)U?^xV&7@ty za+|eKZ5cafstaKK%iT)WP!#mT3O%64tmfh;N5CodZnKprE~a1LVxQM})_dI(_!W~k zng;5KLT?s_fNQiP4x(Wz-M)S?{koS{bG5fi;&jY(8dlXfl?xaB@{AR@(1buYZ4x7o zs=&u}o_`&blTz!vB4<6AYtj+wf`)Sodw>aQ0Iie!US^8<>N!-Et>1c((6Pxh>^0I+ zWmB~Q0ia^^kU;(N%LRar%ab$N32rng2j}HPM{1GVogdCwpQQK=%{JbLf_WT-+|H;+ zZzvl({qx;Zm(DZvlW!O(MSU!BUt%+NJ2lb$O{qvl`55x}b&($*S(*QJkb0KHBOp#5 zh&W+C4)e2M6`gRs=C&i_AiKq2StGjJ*Y|x*NAdNAw_jR9HyF=m#eFX$z`=g_nPNye z?vwVReho%8fdPlAgjgZO*E#a`7mFL#m=ATVG@h_Kg6bWOep`MafBV zK3Y|@9$bkPzl+lkhsgt^Iv%4hQ@r&g=;~;#-zCaI1bhXWu@pD|YI;$YaD;P(bIWVm zn$5mi40-H$Bi1Ls9AhZ&NS2XXR8QWXbr6vRA30`h4?`QcY5~6F56&pg=jdO<-s6S(;Uw33T)?Paf8~pW=M&ID^oHsPrEz$i>UQJ>rzGO-e|e5-0%-}-v7oKB9sV_N0B(FQ_OuQc zk$Yi*+9Kl8^I}X@3|F`XfLegs@rlq+8dmwECs>d;?!Wxnr}u9g#^4RiXl2%aC(TE6 z`tY?>v`JS^6(zeFvbr-J`u>y^T+-LYbAAbPq{8!;=h4sx1XYpNgjnnlF3c|~I;+p? zxJkGaM1vS)!DSq?`w0i zqZo~0aY!h{){^M;lx!VqNaXX=Sb($CU9>yZg)P~b4L)mpybgT)z{mN<%=B#ghGF4c4dk*!}Z(zR~;n!-|9G`CZR8=mCAIfDrC* z#h(2XZvP@2@&=$O(*DIG)}?c#O>bR=tUNKRcr|233Lnz_j?3cHL!iC&q9#r)h$bV^^F4cdjQtTu(@}v|4PaDA9Wc>)pURIB&bY}6{Bw+)uI}o!&2UGqP z^%@S(Q5VCf=`n8K9H~13-{>?Qmq4rP$QreQZ;U??R)qSkxZTD7XDQMDEcLt{1t^7l zDq?w;zKpFNn`K6gsck7%!RyT_1jwVs(Y zbcm*S0up`Wxl^r`4z!)^leRY|(yh9y-bVj+>|ZGXfAo?+bS#~?(BM~=RHSrNKWY|$ z-1F53&;-MG<0^fsI_$Qy%P6#GmuE=&L7DRFO5pw;HEV#?UN+pm=`a&t_BDR^CbB#k z{kEVV2LJD#3Xz@D_D{6I|>lzTof^M8kLQrd09-4NpWPLlBuYK-{lG=4{BChYY&nIHO)R z_!*n`8WJ&9k6)OS>RiTLluqK?O?e~VWs3D{GrZBXYZ@8$0JC@95Wf#EWM3eobd6ye zbSbsJIC>5@I~S|GIH|PN+M55v=v5W_4RR8jwtZshINojOP=_lo;0jh)%2S3 zFP6~Q0*D*_^Eb6&!2!WUso$zvKQI}s*(IMd$G0#+-OpJ>3H;M&d{VbpEgV_bI`4P* zQF5vS+>eQ6`xYTfUP;cE#rEr81iO14cs z=gui1N&03np7M~CckKGrg(ov_e&}QI+Pd|85yPu;x#F@*+Xx&u*1qScaX&*?HbRW@C0kty5ty+?I)Crbmdt2i=&Ca*6ugE6MVxrKlqSIrx2Ckke2Rs?&WmGH1 z`&c<0^8LHSWr=KSR@Obq6)1A~_+dozXKe$sxBqN}eXS6tB4#(|3M7%Psp5G?kfkCX z&$OR?ICV?+N_$L;r~f*-vhn?n@Ou>L0i=lg#+Nt1@nwfS`Gi}p{| z+VmDD3PhzE$t`1>cI)r>ps~iC7BXia)tQ-*FRQ@C z0|*-{Vxvr^|NRJGt@fY!GEGkX|CFt_0@MZRVjdW}fREZ+>)MMk93t@&%f7f-O!{e* zjq0MlwVQL-fqb5*GhacJUL+COt)H{Gw#y!hdL5YJ#G=2+Zm{U1aOh?c!iv+2Gn*EU zvKP}rC+{PX-Sy5e6jFUXWe)`Tv{Ys$M~AHUyxuu^h%(k7BdEK&{o^?XyN2zIexrNo z#D09`o0C4BH}To~&DRg6{SD3D6;of(X&V~3?P-S5F^u&2q8Bg zrAHw!y|WBx@cxSLdI@#P0qjDtYcherZ-AIr|8p6|@;3!Y$w(0JJ-5#Y-R9}mus?d8 z(?txfzUY`eH(c@8$*cQx? zh|Hi-G>S5f#@kkXZ`7wf=}im@Z$78GQCl54V>ps$eu9^y6XX+W7$VR7=$z~H5i5;zJd@4Wk=3>8#>~c$QQPrP~?}P%VE`L_<%)fjl zt{xXeEap(m)8{#CQN8@^8S}E+w68eVYDJJ%de1*11mqGb?E8t;x~DF0_e6TOoE*tT zaM`P6X^=Jh05za}!h#$$<(3EycD%f~L{rk-jhuV_8T!-SxJ%JPt8@F^zDeGszrW+8 zZy-3Nd`KV-=yd>u|Fgb6oNxV8^`5Z><2fw)K6^27w#_H!{P+XGnpKS`}KVH(JFl%g=@ z#k|qe7>zBqe`tLCSJ6nZf0=}W9Tt7~?Rhn!RWPP~C=*ENE2K5;3Fsl;m!Y{EN8`03 zp?EZV{mF1Q>h&mkyi`RO-c^uRK6V|}mV}${YQ~o~`mxReZMQI1AP>4^GKuzgh${ON zpASYR?3C?SwmylG4i#}bT;CP`b0u_F_&@JA{%3_LN|1odav6_TjboQwv-nDoEA4iU zdIT;5Re@5D&2$ND%2~jD_XL?+ki@p%>``bUQ-F|YT2ich_T2(SbJ|tOsn2Suy_Yh- z5zj+>wR=F1UTmID@hr+7o|^jlx164Ksd{ahc3$KII%P}*0r6ZU7{#ybdw0T0)rI0d zaXwiPv^QEa@85Y9koI%!L9*W~2iEZ~G-ARGyzmp@rv4dy!b~?u8trouiT=E>lz?z zH#AuCS4>}dEj!&+Gjc>Y#Lsu1x@$6$pr%BbpSy zTCFq#OWATZF#+cC_uOuK-VbKA7Jw%F#}g_3k9WWAabb2T*TvHlOShw77V&BY6a1um z>Kof`Cc&7eHY=IhLJ{MWN)a2qwh|-#tKjuYNy~$n7+dFZ)1O5j!~~++pcQ(Y_a%IE zyED=o113fY(3`l*+S6|&%{{{&z`;J_Z1LHLtdC6CE-TWmirgLf`b!6pOm*Ew8Qu4R zMEy-MzW-9JHVltxaz82*!@D2v{US1b_>o=+M3?u=FT%D&Lx@RCt4JeKa&03tGdIH| zrmNQ5=LoJMa>Hc`-jN`N?8h-jh5cg8w!BD38Qn--DV__VJ)NV^+wOk6Q~-D^wzYWAsZ(>O9nK0 zV5NW7NCY7DpXMiG7aA;H>LDLhaJz=mjU0>zp5w@VN9f*&o8romH&5h|hLok*_D!Lv zHlZRcT}X_Juq=de5NQ>~hn;*k4ypLV8`()j%Ix!vCaS0(g;Gb)NQP}@B8JbnDW?_~ z(A4u-qNh=*!-pf6Z!^-~(J>~_*tUEMb7PjBJkE0p%v?mejMSH(r3T!pab@5(+++QDxl@RFOS4qdU!&-kK4fMhAs=id0F8d= zY5uJL>Czq$Be8gbRUI0voqC<6@|a-f*gd4Oxx6DG-(HOaLa{^6oh#SBkh0J77*Fq` zET;U(M?7Cbs)RwGhI$;XQr?UFHJ4`8hM|OCBRs`Q{utxEV^*9-q2GK2bLHtU-nO27 z_o8apWhThK7TpZq5+eAZmCm&2zzIC$A(ix$Y|D)}*aEM|Ol`k_Ayo@a*k0>Im#q@n z`3GVy>a$?rRMZ;j>38zN_>|EKt|NL4>7;K61&(JQ2l1f6DwkXR+OtRRJ5b;YxzeyR z-cZPKr@ih%MWJs#^>{duXE<89%xFV3Q3}NA5^p8-LWyZ>nJgSi>dfe3X`MmQ>`eLsTH9`R?EY!a3lZPzR1E*@cUn<2^L$~g`GG}6 zbCtRR(1FjGy@i^z_UY5v0Z)8eoE$>HD0n%qd&WIoyOB1$iGp*51Z|?-`MtnHTJ92U zo;Go503n9k|yO89}?Z#4I5a$VRkeaNELl+;IIqQQCz~W zoqT~&Q-MskkFU>ru>_6YVX}uV1u^@gbr|_K;f(GMe&6r(GnlV} z*M4uiVRWoL2n+bI~=MgUO zA5`wAdncqwP$t;s)IQlMlW8f?)(AY9d10C}1bBA-zyAnendSj^(z@DFu&Bg3f&FGq zV4J?asfm_>K>^tE#v$PEKoLsZZ^10&w?{3Ht`Okzd8%g~izKY6#Pj*>OTFjNt`ky) z?9rUpg&3Y@JZdBcVgkQ-vAzKZ_C>~}ZG6O+1!|l;iEV2{YnF+pDcQ`UU4l>K;P4Y>e_EeGQuCfT`6Z=4~9IEjqi6=wPpM@yY@u1dQpv~*}x%8rm z05uq$UZ?H>id9a$Lh;J?21}R7-5JYyc86ObtI2^?v={`IJ1M<+?KxdC|WWI5^(6 z=JA%1*BA!8s&t@9<&iD|P~c#V`cvqnv40nEbR>@eZ+$2eD@iG2 zp~a`)Qr56H1pk~$9XMLbqLuff{(fqC>1L!IWj{7nJ{`e*czp}Oai80>8ba$gUJ_Tf zj39#UddFE-`~jcI-W}b5)-l&<9l@ZKK8){$sGaT%so-w4*|47TPdwqR!T>gXeeyIz zERt{Hxvk2*_B85?UafDS!FBTIOUq{z^{ziIAT7>5y7Y;lyg()k3Om6n6k1o??7p$b zNcr~9Rkr&USOuyrh$t`h{PmFlP6v#2x3%7e;LTtGH4`|T|HVO8TYy@>%OVu=KzRN( zDEbCi^uW;qL}2rOs>U|?8oy#OZea2=;|?b0mn>AVyVQx>O>NZ%FFoU+{&F$ zjl@blDJVymgNC8ed$P(`s*o0O^R!W_-fk?E<|m=-EE9b%ax19wg3iFV(!Z3vb|oX{ z{oLu#mCvMZ!=0^G19nn)MJ3#M3!9~iu6*~b5p1Q$i?>t#t=%Lf;{IT(X>XDx?Lfod zfhLLe2i5{XbBMIUrCq=)vD;%69h6d4uZ~SIDLfpgqj$sV!H@{ zVq=d4eVS;*#_ zcBEMW56l5iDZJAw$4r1D6?$95!;pLRO!|pQFC0FRh8+=fV}pKo_GoLIxfy(k%Wl;C zx(Z+U1Zza~e+-((iiq!5?-V#B_rAA0#4zxyj=RRU%jLbYtKN$0AR$GgZ+p;S4A%lP z#aJPR&l(uQyQ)^g9d?iJVb6FJasuT_Wj~;<LVEFP9728Aw0;alwaE zyx(A4?G(7!#%T3f2}3h*$7;&!nZX8o=R`!ul13$*y0wOap$WOe_12Iu2Squz&FRMN zdv6N)?`~~0c9|k@^N|ICTIuAx@(s$f20FrNU%+S`C^!V!PJ!njyS)pfubtHi-8I(5 z@z3-%RrKA)zKP2hh|uM#NeRPHPc_9TqnI1pQ7A(9xqq|$o`)m{KA2_(###RmIC{1) zyb9x1HqpBAQtf`$w$7UvE8|wbmN{JK2P@b7wX(3sUgDE!VB=x*8CtrOQo000qzCC%xi$x}_D6PRXIWbEw}v>OS>;&&yx42hIw2UGa$v`CO+Eadml)l0MqZHiqj~*_3KA z-$nYb(xs;cT`?lVSH;rlg0Jo%<9@Gt%SR?HRbJlbc`tMLy_$sOH^TED2z|0x_Crnv zJXy4ef<9xzP1ycSc)x#E)?Z3xVq&51h71RZi|yAEJ)Lp$`{^6R%4L&MnT-oxKU%OD zt@dCekNJMY;R?3em-=?`aw&|iNBRbO;8l6nINo==jK6_3oVGHN%^py;CsIgb^c3qD6e?XhM6DfMtu<&Yd+WmtdfFm1ko84Li||ICJM1A?RE^o* z>!Aog)(7}3R@2^+HWi+{11l8ZZD^N*BHAtgqjH-^30)qDn0CeH=v3tKjgc6CeS2dt zsp|TS<&AB?Je-vyk5qSADgzvKD zBcVCI-M~z}tFKvKiagm7pq^DbtPulcDVnXVt&5x6{BZ6w&UpD?F9u}bDoc7(kQWm& za#c26EZFNV^24ebNM)1Fp$vIT!9u!YbNyeIpfyK-8L+27bi1gOk8S){u+6Bon@Fpa z`~rx3HqIJ8MLw74v{Q2pZVS|isvR(iHs1fvFVazW`c+>Fmq-&fDf^4%o{z6b^5PC%0ZVY{ z0b_)HdO?nc)kjD;ST#cB9P7zS7z`GLp%2JbbIZ%jbeCQXSPZ38esYdMp)AC7#`78; zl0puvgj3VgrKa7tt){Apii(N`@3V2nvkuTnH4$D)9zq5TrJ9H?We=GKQl*-RFFEe8 zkqnV5z3|r@#b+rP?*i$+Kb6w%8}Vx^aGvaqcN1Y6($I$lTSIw9WIF^sVlQrM2{h8w zCA*oZba$){7jVhoeY{bOHuHJK4~xZRZnK_O>%JFO5t2NarOR0?LdmXWaLv-r5MQEx z{60Z9{PWzZmtp(w%Ws0ICM`>G5rZ@Kco;ArcW1q?DMTX`)+MGirUK<) zLFc6BFM}Kc<=o&msijrl(3* z^y&QFidC@7F4&57$;=DthFt)@XjI7>6w++s95<65M7?!AJDm$}#GMR)@3l|g)mIU% zP(m8K$R3r*ZsOT1WwzsXfWI!U2nK1bGu^A(LoM7!($#S68nK%C_)m%$56dThBK*o3 z19@x+xG|CiMRbl|N6yArd>Ry_Z){*bwanrSbq!7K7giL$T%m-vl@ZLml8z&k7Ub@H zu;wMX&4l*zaKq>R_oFKH4Z>ORYx1_5#Mm6cm)*`cNkE=cVhZwngwe0m_@B7mE)DV> z!uxdW%|dyHqbEo6FzxR3jjeo2Q0;yG0RXe2pU} zCT4DkTYBNA_1?HMf(H7sdLG55@vMr3T{6e^gxJTtDu)w=*6^Pd1s#EtD!gf zquN!DqlFpgoN6l(Tb=$ffw1-iRvWmee;t{p-C@6Y?(dfqwczcGk8uGW%ux^Qg|wd> zo`Kaqk-fmj6vF$$?~IQ}ELAeIP5#c}`|}d}Z$UaDD~Nd_o_diQOj(&x*=DJ95i*QI zVN#DqfRH)qyrxmhF#h0L?WQ)nU4xTEWg6IUbsnX$5oAS*G^#q6KTY22NH(QWz5fn7 zPl~j3y?W1cl~Bw}A=XCdW@BK(QRD(iQIADL%ih)}C3Uh%;l)nE;E*C%iy_nTuZ|mF zy*AgP4l%Vl9AIK+-30#zxE$NsVJ(2PtJ37<8RObtJVNJv!SIyOySr*x%JaC+UWpH` zme^VTlC9T7=L^X)e^YksvXpa8-Q*P50Dp*}((Bk**(bRz>Az^23{#xpW~^3v z;kYEFVt!zJc-yXMuhdra31Bzu)N%hiZ~u2&u9~&TA9B1s&0%eqD(Gmj*q2JDp7%U@ zSrYJw!}-hyKiSSU2ArwwVeG&K>=pXF2Q%gOMOe>A;j7-X?xZXCxzs}r65=`iU7sk% z4+?+2t7u0%^gip}2lL#Y^>)5Legw-1!eMh?ZC+6%l?u`j2F66wz!oalhyK>(`6j zyMAjszL0`5z*|U}m-(d~hSSctlQK|DupNIg)!>yweU8^YVxvywc7I*v=&N(`) z`SDp+RqV#aGY+I&jcj9ZRV>SlzFOE`u83AwzaUwTBsbDw%5L4r+1v6#B~QsRpev$^ zs*ch*;8A7l|Ei8iJm=M*;yp~kU&Wh!Zf3b=1rbh)@LB&P(J9n$TS1O_SSlT^eV-1o z{|Nqc)Oxm{*3%rdp3q41?`jF%<eKDy_6p@eDbo2a5QtB^f_b4h=cq9gv+4uFM*X653AJ z-o^9K)YNhC*?y1bO&^8UWgkK^;ECm+dQ${n%{6iZzjL!0lTs%G9_ zy)_RP1bV4RD=Kx*eU?6udAa*cP}b0v|Kzb9rqJvA4?AxV8ShT}vq?LY3@vvAg&S!? zT)ba%TCk{z$CR=i2OqThL5@u{>@LT`4R3J#V7VX=N|W)ZBgbVSr#p87m3i6$$dMSq zi~Or-f1|E%14s^rTFxGSDs`gvC78rgS6owVvS+bDLlmd>{k4T}Ah%tXq60~b&HpUU z#|NK5pm|+@!&0DKhGEhX)@)vN^#60!wsjvVTmq&qs=;1bCyL%dF~^A zmM&O}rh?tB))x|tNbpvnK45|mMt^AoOVN@PEm9?DDF%nz(bu-Mq&=B67m|&U{rGQ1 z19KnZId&Bgrzf!$tyuY$U};qmfsn%`_rHX6mSl|MK*M)1-JvidLC}qH^eIK8fId~a zgp6#?X1UzS%02i)dEm^Si|(u_BN=c*h^tjfQR*p6LPa z5+0jyPTxDoVhXWm+PIYWX)T%YJ9BDpMsv&GUQ{Rq>D-eK1DGgm#c%i5t5*b**tSSR z__197y97qDpIWih*m}Ulh+PIt@jK+zJW(V?+3^5}Nkt>@m&c-(;P+ZYWucB(@KXeO zlL(rMjAfu7Da(9K%O0p+2v`n3egimEBZU&-k~=zWMpC**hhP6{X9f8rL)O#xG!@s0 z=Cy)8dR~E=MjFY6(Y@lNZdBI9i(Hdfy3^szNN~gtm$m{(kKQoUG}>a$YWSa;E54X3 zFohx2yJ0>Q0Go zS^WQoTW~=EYJCfw&U9svN8a*>;5uUSkMT5DnnOxFTAo`u&iK?}r=cJ03pHXk=EF<9 z+XeMRH?v%N$JX1gxS-qhxPfCkZIX0L_Et3@BKD62Bm9EE|MgeKCp@A_jsji({Dv>B zaar!c>Er&7$kz^lgnE~~Oz}Iy?#EA%)j-rVxU$?yhKVa z7Z=I@fjZi;!AsnZEY@nQ#!JMJdLEn#^1y`wuk+&aGAoipJ@1ZZksrb@C+@$Ktq{;F zlkmru_x%Q3gq~$9#fft3*Ooi&T*#=W3Opon0g!vocfD#w2% zR;ZmWAX3_$x8VIlHsry*H9;syi0v3krU$_H<$Dh8*zwe@*svb5=5Gkb1 z9_ZO{iVPaqJvkK!Z};atL;u za|;9BGK*82@y3E{*g_}+*21>=Itt%H-S}e0iJ6<00G~*B_8}P>4zr&^95t7j>*}O1 z%zXpHh07u`T}P;mX9BZRUmsC+aO>i5<$Q}G5+irGKa(?{dFau6;|CrCTCKo1E*ypm zo$%qM&N({6y(14;t+LS_wcIF9Z=7l?s81DgvPrPyc6Z`(3>r=;2Yl|!Wk}i%c-?^n z@mD`!b=jf+zo4qcLLy2fmqU~Q?GK_{1kwh=#7EW^2?r!gb93|+p(;M7dOado0MXzV z*g>_L5Ui{jS|I9x@svMcsid76xVyctzCo%1uizm5ts}S>9K2tK_L`;&Jcq>Le zIs3WrcEan5KsL_DHyNKafR?ZPY>VRUeI21Y1yG_sycIXZ>!gDIyrJ~=LrFC)hQ*5$ zGyK+yR~9+Ii}+;&|BO3lo8j89%3aRjiEWvDwxIoK-*(d&3DVG3tOge9c)*eT1NneE zsnxV$`l7f}aZxd{7$PPi&wVj3n@q{aG_8&VqxSlH~ zpEqzXMKb&P4BeF)|8Zqc&jC-5F_Z;)*)A}!3%bE_bB`~s<{#rWfy9WmI{+e#(BL5* z&U^WqcZg}zuJ7h*f(8vG>T&V;lKZLD{;RU|#*1%ndB;n>-CQ|NZ5b*xCE{U3p!c<} z#wy`CcFpp-w7uq}Eqty7({3c|TN#tZsGy{v;{4i;+TZ-R!l-8^bqGJK(E2Ax1*tqB z>^vK-Y(M3dA^TM=YYC)je6VKryvz6y@W?$ZCyg)2GeBlI%B2{H65oa4W|4h7j0=kG zYwz3_`2*Y&NdqYZQLDhf4o`1zfT7(Z_&tYJm9>KqQjSDFNL#0AcYK^^L;;;Sy%)3u zeRAb#^XOR_Y&kRgd+9yLqrbQd6G94+2UT@eUpWpIG>;eNeHapE${F3BbH2k@?>6Z% z+(Ik}OS-C1jF7EdjSEzxRYs5p$tK11MT2xIA0%tl9(EPV46B>vtW%DVnM4}zvNQ;` z#TG8YkDuE*TK3v?p4iuJr1NmfTlV6{c@UtOoET}?VdfCWw|{9(t0Ks)a@R0AIHZ7< zw?e$OFI9ESFs&PHpq9r+Z6-~ktNhTP6E;x9@}(wfVSedJ{&UT4u@p?Zj${@~Wp0al zFwOp~Bp2&gZ2_1`H|3p<%+Xxv~u)dzM13%)B$`WbdS&fHgXa(H`1_LpbteB@98 z_&bn6j(`GI$95W-?jS6bGc@M2;WgWDA}Ul`?!lg09Owrz=qU>4V<#OPP6;PtSP9;r z@{^Z_NSJ0z*EFBcu~eOJ(FBQAILWDaq1Z2BJXf@O)L7~ftfSv=_Tg?J_eI_wJsf;o z_&v;^yCE}<@-s;|rqELzzU2BbSks^2&8)Mmjn9H?sEx4Fol+p3R5BrWOwq+RHw&_|V(+DyMR zB=Ux58c%iRxn|%0=^}i)50Dpzmsd}FcmxClC`8E)o-okK4ih|JaT~#V5E5JixjY)U z3QJ2;_NT>pMqB-8P7^9Orsn(iv_=y33jB2&d#^vl4+z*5cpTEDf^m6jbqZx#{IJpn zUc?bAYue3!4~mf`Cpy_?qiFXgf|80_P0KryTgfD?^Yn2&36 zTXfStLunWO-h=qBhRMhmZ;Fp!Yu_W?gse!Ug;7dEoBKQoy9CxdW$HO|hU~~So+kU4 zq6wv&r!30lTwjqWU$<_oMGiwa;JLDs$H?*8zVd8%tNw>r|C7Vs?pp0MHH zIm7ZQB^H=HyFc)sS?|6|sRMh3DhOxUGM-k2YpNZnnwMAIVNvwf98^D!vLePL=n@ zD%P-SxHH>zOfJ^@=VBrJzJ8?#TC+> zCsVGiKj{d)cb?zV(kkj7rC}ZD=btp6*0O)6&yDQyn>Z2n@^c)?`}t`KIwVONfg;Kv zMivAh(EayCf+W!_GG6M)3f#z+?lNRI-b`O2PgcgUuB_azNr7%XCzF1*zDTEm9Ap-6 z3=t`{+`9E#vEbS$vuVp#q_Rs_in_B=NjjP>?Kb;m&2g5@eA>msOMK2&KlYj^1+iY- z6H*?Uu1AhOV|HhD@=KwUp|u${$87=qQKyu-2P3V<*`L1^m2SaWfY>#<_GR0QyjRm@M6}8P59LtYF`cipSC)g6}G`nY?K1?~ju8>rQ|RM)}VLqYs1I zcV#wt6tY*%-kJ6J=Z3vOVVh95LssBE$yNw>A|o6&+00rxWh3f%C680wW06q(}{|qL`o>2*u&HF={kTsQE2AH#l>P-`gSJ{ z%WKt2McG*A=S+@;@-C)OS{*u#M~`n;+?K8@m&Gd`YhZxSz3XnUSM)7=ZL8^l|A6qQ zq7hR2QqHejOtY;N?^Wj`AtP~d%pB^27r_U(dau*zR@r5Tw?**{!}@SQHU&1y4>ahN zk9S(cFqLlx;wb$9JRJ7bkVH?BKFR9MT4K!j;`S^_gMi}|cRjn;?HE)0_XJOJL5`4LQ-bJ zB-|1Cuy))7$`JnnwEX6i7i;pk`N0p2Sj^d+-LCw8TKXq+W3%WroPpbNrV||J@MpDy zIJrPu6xjB0{0L0!i%t>qHgH>TdVT=|KEOt}I|d|sBi{#%d(-+C0t6#`e+-Eic84RT zuiiGuqgFaPSBnm?r|;a6i+b$k3>zn&6u3Zc6tB)UiEYm`aLlfsUj!li-n5Q=t?`fm z{;~d{vfuWi1aYg}*wUO>{mg}A(ktbm%~aPeW6BOZaYA25eLAH8}e_PqN&R=X}w+k+UB#?#JRYPJ0GZ(~N` zs_s+VK$z)7-Do~&7OPNyFd9!GHo9sIpv*pQ>V!GJyAZgY{*26>YGFICW^zuaL_bM) z&+wl|3+>iK@d5%y8yyhjWSz;Z3kzpb+4n{GNs2qn#X}6&avJmis@h9tm`XhlR8J{V zA!;mBC6l=<7<`stHrBN-mK`*w_+2?iIl#&HLzF2c9#UlebX|=1sI3fsb zP|&et(#{%iu}Hv)+vwpdf0`vzZ`=U0Gd3@S{k7SJy|=9!IdV<_U$_Qn(AiJn`-)GH;ja2;D`IC>lP@o{GD1f{jITRv$_#Ey#bn~ADEx`2Aep`&;#paK znz5slq4JIlzoG~D&7#$E9=Ewiet#X?k6$a{iQ;jXL)Tw9L^XEilJ?s@`JLX>^?iO*=*7y7jYt7N%HW=Mu^BY$0ehZxdU01pbn+2} z*v#)Rl!)*+NUoX^UoXi1t2KGzwccu}R)$XJ>nK+nyUOONF;!*+=XIS{`Mpu&@TV16gC=z?k`x-{|ngA86LF&b>+)v7=w^9ONbN5Al%-9M_^n|KXVbTqo* zc9Ho`P#w=o3^KT#>;0si&`@mLzQQ@6X~O#NVmJDJ`y|Ki?rDpq?b<{HI!7s(24L&- zqd6z5yvSjgZ5$ z|2_c>ehHey2dThMIYw6k&~vbW-IFW;nzQ)rA}JrTOB^VP^}tsPrxR55{2(Dl5hoP& z`Wl#PBYP@NnL>@^y#jAdfrV*uHMK^vsIT*bE5C;1+c({617;!{p)8c8iK31QQ@2=v zYEy14^}cD>H?A(5=~{oinm6*k4k(8o%eIRWgrATBbXX9LuhxUydCp;1pvuYUcxUgE zk~Urv_o~>n0<>A>n!1&1ELUr=>gG>%*0+Z5yU?D+kYv^Iw1noT`L6%Cpagm5LUu>l zUhnIYo78kZ7AU}8lgzz0pfdD#8~D_%LB#bh%|}Jer4OlKOC=xX1?(CMNR@UQcgMjt zj!li2{c9TuSXR&Lb?GkB$|Tdpl3|7x2kps3-t>*0;(Ud%v_E4jN~O_>4m&TL@E%jR zKO7uO3SF5;b?i?7&F-$(8b^q^dzhx^ZnJAPbB=DK=(6?hsQ0#ihV>VTe5KIaLj)hK zPhjLl51&aE{s?B+Xs_w$_7pqgB1b)qbrMA&6_?%fpeD6nK7$bH7LR#Rm?d7NagEM@^^Sp7-t za4nCf|8BdQvB^kiZt?N2`#2)rrRf_r0#H{8l(HKeoNGr+t0)!B<4J#UMQF>2kC3)UqzZZ_?~3 zhEl(@Gn`zV=IjtnkyCNdc`JDCIN!ZA&zLdiusBOHokZ>fdzR{5X^Fxs5B2+e$&Oza zuTYJJ00^MiqC!(zEjP2L>~kQ)++CdzI>Vhby-3JVbnt`BoLhCC{wVTQB zlkHR_8G|85=C77h;gw~4>*h$&?%*|?zi0uii#}K(^nJth?&wA3ZBENTNXkX_h0;DS z&T#+PkM2;goTc|A_7}ga);U?%-1RUPeI|Te zq+i#?)#Wt~YoVcYaq7hmKPJO=DU8nTR< z$j2Cevl?fV_ILUICH=|BvDW$2$mlFE0^Q}oRCLW`9d63RS~x}WYM(7@am418+V zDybgOMU@)*ujSG2*JC7=PBJ>^&ciJz|^ji3;ZOj7~nYx{Cv=Wpy-%LuIuLUcXR!%JtYLlii>JEn%Hp(17=_Eie!eA1~+M6#8yE<9(N|!nz|;pcX|9 z_~Hc0UEbp;`PuJ=mZo9XGm?K2M>}0T92P!{;ag&thY%khpT&4v$n~{ak)FUE`S<7q zbnH#4i-8{@>L3(MPepu}-bz!NB@`N~l)?|c?OIn9tr1Vw3BW-RGIB%Wz}%6ye^CoL zJ#qW$a|#6=MZ&UIKi=DLaJN)l+(T34h0`BDr|=1hu*f{UOEs8V!TQs-%!oc<8pPg} z^W}HG$&}{Z1Reg}*l2rjoDLEPzhptupFE;IpSn5jEs}-vxnyq>(7gQYh z$l|L6GW@EZzpcMeYHdu>x>m^d!v;(HCjp3E?~FLzoEWeCDg1mc_&`#TDXt42t2?3< zGV7V3K6N9{&$M*2WY4wv$f=LrDn%46(8(o0;MMn#2v`f|HP7i9>M{2q;Mp-&q&!BXjRtMTc$H4Rw(k0Fi-bl!1x6jrxDI8tqA zQf@m-%N~b~B2e^&K%?(~Z;=2K${T35L2$_e3POk*?%w3BUg9Rk-21 zq{fcU^r`|HB`shrm9_&~nTEeP%NbnYm7+?av45O@ot3M!M$*3o#21%5*7M^B+7;0< zNkoq9%l@A8*jjdNt4_-Z^vH?VaT0eud1mRp9Qo?mjyg(-7GABqyTt-1owL57Ole(2 z?SeIP9|v4_JPBTShV!Xd-p?3&!xM8};vg*xF2*mUD&DG|{x?|-6v`K#-$@nqUYT6w z-5Oqx8+VpH^O{#c{p>{qlZN80HKDN`sVisKKfxpYThI@LAK>gM9;D6ht#rlV*e~}) ziyBnJ(mZ4qds4*=YcGL4`BJ=!m=hF1{kI2PMM!nGBH~2X@WTF?sG%W9y^-_q*M1zM zz^?f#ga+k-PyWz?WWGI=#nYhJK;$e;#dfA%9N34V)KBI(m?nTX_!Z>lnUTQfietJH zSbb&nH7y5%1ZbnZo?VV@=7-KJa@r|uYe2|d<(Tu~h_7Oc*bL?yeE5j_d{JPaq|;&Q zJ;o4&Rc^B0_0xw8>q;_C&(5Q(976odv0-K~4B0`UDt?dj?eynqdpP_pvm-=Q34Tlc~za%4I!+w7Q^O)KbE* zQ0nxPgp;cOc12l7ph7Y+9?N!MD)eNS9k=hpF0%d`jM;$HmEd2CycO>VE*lx&ZYfi&y9Mu)Ta`evp3h9Vg@j zqRB4Ug4F_?>eQOpCM|q{!Q1bzmm;MP;k)zgIOd_}z+2vAp(FfTib&m|aN8&@SofGd z>q_SCrvBuoX4FzPcP7C0h1^!+c2BcUr3}OxTWpPz4?xEYGzaTf5k48%24V z0gj0X1H$itS?@Ql=rQB3+^sJsXCTgtqpL(KEm*WDEhl#3;Yf)?gHVyptbd@L%MyCp z$k*v_>W(#D!crc61qrvm@HO4ikQ&n<8K6*d*{Q(xEL-kjp35# z4ZsXN58utb5SL>+Ug6T3C}+JdkQeTyIz%u3*?1M&)1GO2PNNlorT&H?ZgllSt;?E9 zWiF&XA~40Q5O8xPRQj)@ikB7m*6N4~4j(1nOcTMCcQ0tDD6Tv& z(E7#Tzla5mTvn15sn$Y3F(RDg4!Ut4_^5(R5%{Oc9IcP+%{8O34Wn$|&u+oCem)d- zJCwV^9i+D!$s;Y-;mtJ;eEI#od@zFvb&zV#B=&i}RtYL!Qv@j1%eVv=!n?;G08^hE z4znW3Ti<+P_!CPxye~Wr6}!=iA8j>60ADl20ol;v=R)Gh2boQw-+!)BwhmA#mB7wi zsarV2o*n`o?${ZIUU*d4?VT15Cw2t%cwxXo=kmrNrkpJm?9rh<#YNDZ)R4a1>VKI5PDUI;4_#Hi&IhZTef`u)|E9=J;5sM7+Hrl!v|vvUp#te zy=|M*KVA-GXyTNw_ERPYgtyISbmDeO>+frYu^qfJ9*PUC_Lv}j(CH*IiT`w<($>3Y z%AG_ryf?h}3vlK~w;tWB4hga5Pdv2S6O#BWM#f-Yb64#XM`i;|gsex;m9o#kBWmX# z(ktt_u_c)N@k*75VVSeN#yEm49Y`eu0smLNO7R0X&5Bfp8XSUGWvmWG0H@-?PlU;$ zjM%q1S#MI%ih2YAawY`jj2h{8ce@Tun+u48w5hQy`^Ka3Vz<8|w*jQbmN@UL2T%Bd zvbr*K%l#q3{6+@;z#pke!-<`Xnug{R5#6MDbmQ=27gx$;cZa*ovGTT1&1`k~F|C*a zdJL|6I@q(o2BeH;!3*!|>ZRVXDSVf6@q(OH5k%3tcA~V9rxa;S1i0jxdLG~kB8jc=%qr|C#XblqnOm3fGo#x2jC3Y z6YShENO@4)uJXNU8XhBXo$+0%k==!!5Y|P^bi|s94vs<@R&M7uQHJXEG0}f@suVDv zGRp=|G;vDlieaD@-F!SNq4WB{v4wr?`1GOO?#aH(2n(X+Aar-UAOGh2o0tY8g0gpl zvi%i2b~ISpZJGGz-aCkq*;%}?TOFntc&+D5qI~x_a;RC=yXerKf3%^%606};FH^Lhu6bgkMw#&xGj?&Q?k#lH_Jy~%Aa zT5J$9e?lCIHn^wsw69Tlu z>DvDD4cuSp<59o0_qLs%n27t{Shg}9=o7Yr#siSr2G-TpdE6-ghM;mQ_!RfiffIH6 z2Ogsjs>_eVB<$!xu`B>sEjQ`zn~!*n8`L~RN@+q7SJnG+d1_3O6Yrp3W}jKOZ#NPu z-F+ru{ee(jb`uAm=5)I`L?QRnZHN+0cJoOT*@ZB{S9=M?i)cMO$Nhb+%1-`MucsPH zyDj=vc)L0Sq*)cn%JgZwo)J<}p~%GT4Z0tsBQT!?t94Kd4U2gJ0unGpZ2NIuyqICg z!tC5GwQY+xCRbHBB0q!H%NJ%B#{>lfwoIbqh=fk^rxSFSmC0y zY9DeUMIU@NSwm!0JK16-US$Tbn;SXklu?Tk^e9%JX1k1OaR{Kht;4w;Pkv&8)7k^S)TQ7waMCASYS#{0^*dl zS$rD7Z?Y|K?CoQEb9E}%QJCdQpYL^FG*An(qS_s$-0%{R|CmPw4fg>dY}Vs#@hiRc8exOk_Eb>ZKeuTnXJ(V98(8udQlP zmwb#En2*OyscKGM=ou+Qf}I^yLTpKOKi`-Snzu0Wm2a^l#yF?^tbc65N7eoCIrvxIm8ZmiirtpGl<=S?H75OxZS7O# z3rk#K_?M5nw1E{Ki?An>1UTL2sAXXTqhEh8S5L8z$*9H8Wq&c)Ln8;YnLc@*ACK3& zmQ4}IYdvL9P7zkSiJowEEeu|1m0z#4K2#hxVtpiE5ltyeOb)az_#C z%O{xwQLR!JCL8yJ)D^rPrJ7Ibccr%&@)1kx7CPK$Dk5pvwLT4Y6SCf3Qmt5`lcn~< zw|eM|ba1kZJ6ogd3!cwXDM2?K3Dy*Oq`ot6z|K0n-)x+?A9Wb&e3U$_5Fbjn?Q^#W zZ7eFJ$B~jQ$XA%WlZ(HYDY9MuURkkS|u2u5U<-9@lV_wfl0(PrkzI-B)oZU z6K(*S8^F2_y&j(U$qMNbbxbn$p^Jgy4QXM>J0&Hw`-9=3mQkm^~Pky3S*mA)}>2){zhO z83s%ncPCmg!=pYizMA3{`dpl^sn2Vzg7%#A>VX=|sKV^IdPHtC^T#WKZvaz4&#CMg z+Y|-EmAQL(vc<+O&Kqu@1#aD|vmV}X2jJX?6uEpil%4R7s7Jgmm)RnGZs!1JEcM@0 z5*WE31w#CQ!W%JFV?THi9R=7+e6}-ssE!Iil?!WC3UwaVw1X`E&$TdxHp82mn#vm_ zo4hU;#)^FPrEu5}ed##iy*_C%o9cqq>$ zc5`xWPv5rgS`!{wuEY7gD zX|!{vEC#5?3wQ8XI44}h*CC<&dQrNZiWgyo9IQdkxrWupPe9fgd3A$z^hJ1DpEI4k ziYT$C)m``}1$DAex&+$uNi$sb&m&*C4FpMSXKper_ciBT3L)(J0PR(X6;S-WZ^_WZ zVO#sRtF5#H;+QL_lll^7AFUsO$C*fmRj=)wqCktnCG##vL~3L89T{gX)? znuKg2(Cwcl@`(rFQe<*bGAdl88iNrf(RbvDFV1(G-sH3O27{M{h0+Fqj6NuJWhRjH zq>4&gec1r5THsqp`2Mb90GN`|U5G$Mii!Eh!b`Y2=8`4W0}@LM%fF;!c3kk?-8G?? zY)x(V%TZYQhl%m~W{$9&udCzTg_5dVAF-YGx?9ewd2I05B3$zPfCoy+ zcUQHu&E`_=p;Egu{Cbj9Sw1+SFX;sA*Bpnm!}Op)^^5I={h9|*?~S{R$o!EHZy&Q= zM3v<|FsQ#`t$Cu0<&~kSVa1BO#uwEycuR#}B%VKX$H6;X?MbRR%H_ zkyxdsO2BXxWoLZqeSey&SEucmjp0%RfEuM^F49NM+FwfCMV$TSyI`9G+bpc*Zt%hLXCZ3*GuBk1Itr za;+x5NV|z{ZxdaMeFHhUVwHBI2LPUKqfYxHN~O~;_UukiXnjnQ^hqO=Vq1WuhP@Y5 zpYiy{^o9C$8a8LOsMrmvR3LE&Qf6EXU@Sv)q)eC=F%mW5 z!%P(w?g>Ea>2G#0N=eJ9#ks2(!czZ*@ z{U}$SJUjncn|Bc$ddY*DW^t1}Jp|#`zb{mOV|ysO)pLR9l#Tj21#G3nhIn0cMa!U^ zmGdqZW*yMT_TH$_kg@r4ozsVwd(AZshB}ne@&-sE>sZJ1#}=7pAB{|q0k7NE^T_D) znon(jK4?~wsjZOnoNeI6G)2a7ceTeGdWAb0%}SboXK))1SG6a@e6%Q70EIVmPNB`y zJ9YNUPQdW>adOy<`kOzcGpkb>DAo!#$z$?`yw3Ke-44JY(~+h4-Ecr|cf0yEZ6FrJ z4Q$dlSj$IsEul|TpMTi5n`JE^k*}bLe!haO8R8-4AMGp|$0rMlfHHPF%bn|yh+!_t zGLX)qFe+LYXfW=9E7=WTUEqiT-b%&E$JI48-wTOa-;w(xXWqFuvb?(f?5}^ZujQ?T zsMbT^w>|5=S65OUIa_)aqQ$*rh|K_;SsQHR@J!@tF3f&sjQ~(8u%$5BTT2`MCr5i>~Qv?6aVq`8|`S7ESbbE zh~1v*hf$0{XRP;Z{ zHFO_=@ZfWcIY^yS>P3f6ubwM=gQak0-%5VXeYH!+nZ&ZetuLB7%!T*%MHcQ#6=1W# zwR}k%++NHd{&2}~5Dq!Aa8-Oc7WAnvWbRYsP1Kuw9yy~hi`L1#8>`DWf6ijz52mOr zR0UWTAni@S(3pTLZUyvvjHux?YfiwATVh4^+^+5X%DL%a8k_?%yO(NMZDZG)4m}1$ z)7i?c^-B=mFNRMSl}YJ>W1wG^4j>j!^#Vkk=aryFAW6++6(w+{Ym7$yTjw5N;>V## zl~?4qOHt)NudCQ2(3H`zuH(G))|}9wDuG?P3cC{x@Uv?x%{F?bY2@3TZ0pzDEdUH% z#dqj?QJZGIH@RHf0 zTOO9orJm@wi7oqn7V=9=4Pa$-qo3tKj=jIT8>K5D;oIiig`G#S|9tU~ndc~Y5#4YTo8=Ux;h6uB5ds07GMDGww_C74x+Agv9tQn-qjz@iMtV3Lgp>#v9WdE$ef|I!|#a2Uk{U`Id#W_-1e7V7dk zZEBgbK}ywp$Q{6-sg~{!vsLv9!x;01(6L5FjdDvLvmv*+=^&(*LB=hNM;3kr3@cBp z{C%H#cSa4K4V3BG^;l1mC?BmTbS}O-H9bM z^rMiQaVarne06B!n^LIz0=6TnRrv8Bp6AUnlFM2?f@#@;af% zE2jifkx63Ckwrd)QB6WJ1RqI^VN~3!wl}LTG{ONojZkja0c616&0(k@!|?B(FHr(;umFwCH~9;_m3DU;zU{2oYtVbD4Pf z!@}FTxsD*3s?SFa$RkkWyuxuxrN35biNOhvYCSv&oHp&=Z6HK;y~qc_RAbq(7p;3t z)i0eIxzvHcylT5C-k@p3Msq1XC`2c<$?hcX&{0||?4R(*2rxnLdo5Rj@m%tg=Q<`u zE5J`+yw158u-yu@AFR285mUGTaQO-b-}TozTCU+uri%cBvJl(tt3=ix!4bCDRj)`( z3k_vGQU-2T)oFh9!}c$&b{^}uZxmoz_emP6ebN3&{M7vfTR33|Z%9`|GEz{Tce7Xt zTJHAIsOUj;{9H@5l9s?L;j76e47$^iF)PLgB_F6)@nD&Zj&hH&tAKJwJ^IQ;l4VCY?g*L+^j#I-#imcvM+d< z^1y4~E4s7yMj@q+H&s;0!*eWpj;_AuRJGg_CyYYtvgME{_)n>y~0lR49gPEs?G&DjcoM zlWJD%A@OfR@nXpcJ0J0AKTKQ{wBy@GCr+N z_PC(4L3?+NKDX_TYAWi;y1~}_u!-Dy-b%AEo{w8W6Z~rT$V8mMw~=T-=nNbr{l8!Oz>_1wON%$H+3Q z?fsKmv*PiS{i^ZM%zMP8EAJn{M8mt_luFD^K+wesyCE|iL#@%QL0cWiW`=BO-o5kJ zc17c%h5Mq0Nn@KYUy+Ah7*}kD*}KE)xx)j(_0>?oho&GQ6~pPZ+7z=LYBhBnMxAWR z)o-`{QyPl%Wx^uKPaBt)D4QC*IQL1pYHzrUv_nOIC`gu0v(+il zSe*HLfXSKmvHZh|EgUd}`oFKdqt^h*vJtRWVRih+?|$U+-t@<3fXa(LY86oU*NYQi z*_Z=$)d?p97n@sN4qD19-Z~0YwwW-$CV|e<$a!1MVjt^m@19NBtm+$lzd73=saxL; zx+a65IJt-`5s#x&H-%W=C3rXY3-2mtzCPsCtITpdJhhT21QE)prpUuM>S53Eem-xV zC7TbPw#CZ`pMCrE3d*{~o-qQr9V!0o%>8ik*2P#`GH*^|2mM2*vxbv4zAAby- zRfZVY19oRPf@oYZ8n5Wp%77=1{rKYIB`Ym0tr0*W3k*bD=G-df@eu%`))^bN$9`Iz z0EQ3Uf)%$Zcx_$jvMT3eNoey6bjT7vX# zVee>R2C&_3RpudYR>rOHTjd$onOQr`jgZdHTwgKB7+W%rWB)3&KK#E9+e9XPu0&Xz_EU;x{r`d1AWMpl5)T z=G=D2#~$@4Nm11vChq&MExtGnBrji&U){vzG{5vTKI)4|%5*Wi^~br@K~z&Kznd=- z6z;H~G%=c%jB|4vgS1h6e}T$lx%$4fog}tu{t~S?#o@YgW`t<2 z{o!UqFO@^U!mzKjQ^!>2v(@x3D-0`X-9=X#JYM*2c~0REZrDh=*{4etzp1BD)39B) zSg1&v^IOB~<2;kia#|?BG(DxQsB3ZR4uLx(4FfwO!+{J05Igm2*6q@n`u=JTz7~KT zx%C$Q5d{*b`%}s8$8L0gp3>Qt2WdG%wUTz?QR2AUAS_9DI_A^Yj|7KzaTvVIviG+e zoB$pkZPqv_|8XOKBG7CzKrLDLVyC?0!4*Y;DRrBOiMjZ7l%kHNIfGvEpB#E7-Nf|j zwc54&aJ{$e9Ip>CBUXJ55(h5P0gV6M#jxDr9fN=2V^e

|-8%u%ST#=Qs)V$zNqv zzB7Nsy^DR;=%2Lws~gj*4b!;KB%O_dG~R%cund&U*_z1cqVQpKIpBk!hoIAwQS-<~ zSHp%Go!y_)&AZ{*o=_6Tk+}xfi z;=G~jd~$f`v%JMF1pe6L$CI-)+>~^z;RB1Js)ninr(DFxef3645AdmJlAFN$LsW(X$ayh)A&S<=D#*n_HbP$ z1r(!Qy!-OXp@l98Lx9z!-Ji8GBBL@sTht2=pGHhWj!8{xM)n+uB8WxuE6g8nTT{lj zT+h=nd{P^It@6?Yq%(-QhKQnue)Z@$Ze*A6r{0LptTD)!dWrjH-sq3yD7*mnbyx={ z$I44KTiwpAuS!a|3X1xGXAnMI_v2OmKumL1_}avBd69j(lj%E|EAl(PEwVcbs}BZA zAkJc^pN!Yf=FJ*wYzZhf6WUF~DUM=^*GEb`M=sW2l7`?9(LOwQsIn6$m`rj!Z!z~;F>8Ikn*rBdebqqJ|H7n!8Fa+# zPZ_vak7At`6gFtE&HLS5J``bYG=~Jo2Q|Q|Kn;COj4cMJTdg`gUe%<564ju}i;>=) z)Yl1bsU91})xsj{R0?5qONd@3b~_L1i0Ft5z<+a_m!y8I^6TDAxjcVyi-+UY-Tvw|bZzaj6PdKr7y_BI+E*iQ-x<0Dia^%tjo){>Vo5_$l zqH`B|fLQhGI@~gdPtBXBH!-~m;&8`JO-*BVY#yx=1V|}4l+7#Ai@!AscOCAy;==ZO zI}q~9kWYejg{HcAI;+&I?Wxf9C#kxLG>rG8bH17oW~dLBh^0nn?gS4~pG?{>QvDFJ zemcHzlx7A?4sw!PgNK>YHKbpes`TVX=Xz7}8F7J2_@{r~Sip}*kz?ttXC>^a2ApiT zzb&C+6?=&7oQ)`(ZEj(ECgnmalK2ISo;8KasECc%ok0)RX-E`G`DD-!Js7j8y^*ZZ z@~Z;*puYh(u(w1^fWy%^8;?@?lAQXd6K#Xc%m5-R0sp`%<=S57X-yt#qCT zc62E})oyWeHT$vVNXK&XSm@9MDrQf`cvew8YOzVy~SI*25d=(4y!sC!d)FisBh)3fg*ZBbV< zb^#5hwO2I-smV@1@U&Bvw{p+sf56@?L_(!6h^RRqAo@zxX!gl9V?+yCr^ayqh`otJ zFTkU>n+NBApBQm@Bdn(TqkyXm08kNu*i!`3&f}TWN2N6xhm05-2}_cP$Rtdj=|{>snoq!7j{j3Kceg4p50A|Qtrpfu;+*2428oc*HAWM3XA)f< z#JsO=Xz|OhS3{i@VLTRj@9<>goIN_Y>61n4EMC5#CaX1bepLR|lNI5+0;*e4_ZN1e z3i0DMF9QW*l1>&*&~Bhz-pB_%T?=oiVhlWh_#I!I zxQfB_r8?Sgdr`z1ZCXcsFcv79q|4c(zWB++#2Jp{g<)$sL!^xKdZ(wK|K*Ya7T8dM z;|F_&2hb1B0~Eh^MRRd7NtE4s4lhO8pJ;!w%^i6|ZX$8&>J)3sfTkzQTt|{~nGw+- zVR399Z>|7V?Kx5PUE|)~4N~R9TE+b_kRxHgVH~L^?7|&SpiIs_PPNU`TLEI@-B=oRb@QhAo@LC7K~*7d#9KUx#~~ z9*w?V|J8Rr-c=q0AGcR1`i`tchH#-d;C!>%(Mck2KK&PE ziO}H)R|&&2c$}|=E#!U#(-5B`UT`SN(VLLuLpP1FrtzNt$`N_%s7dGJQWEr%B~=H6 z3`INY$fW``HIY&CeP7Y=iFw#vXyIw9Ws#oxaQIufdcEN8wqxk(at-SY)FNmJIR9ox zCo^BNOrR z_J_^4lrD-8L&Q1!K!PRvkZcQV&WVfq!u+0El92T#+~YTK+6c znhVbhODu{gA$cA6dIfaPV})gk*tu$FxN|bSnh;}KSHV{%4LeIL_#7FYP8AlUy?uzj z1Se|{By#R*9S1j@-Oklfq#g*(t4h5PA&{u!mvKZ#0uJUTEY^8Vl2tA0t{V+5YFocn z=oLh!-VgKlJJU zVjBo@R*AKVl+t=rjfY-FTWM)&4M2RlKsgOBOT^>zY$g#Kt9@ROqYZZ&sFqex>W^9e z0#Plq$kRZiyM`^)d~jo2&OTR3tmU1Gli-_?-fG*ZvpO&CUmPWgUV^l|zRBZr!Dv92~CArG5{#8~o7IAdr&eTYp8Ct)V{FZi@ zhCPM}WxU`>RMu6;!fGRBBgL+--m8~&Sa^cW>tHL@)0~py8>nf`|D>>PMjnwdihz$lxmVT^&>t+a1rE3IX$v< zuWY`S{iT9=Xcw^lL#i7I0!p8Pc2fzDet=HWV0n~93RJpzaE{8xUCai@*K*p~kJ@Xh zAL0EXFpW8}9rjgbh^9N1S=S-kg9}GK1#Q~kTz3#N&Sq)Seb4#X5W{Rxfb`oeazsof zlj@Ic;e|{7sbQ{%RsEpMd3tfBRE~&a1Wpg5Brl=0xpO%VZv*_f>c|ZS7GE-GB6m1B zwy09`=@2H6L@5K8cJp3;Y3*PCp48DX4?#b9pBi}mq6p?Pi7gH-$sIrt!rO%?z^}6- z^}H-JnYLZ27OU@My;UIeGWl;O3IwwBIl%sNZ=7l{X?F9F$|^2f{}!&HDG#oPFoHD6 z_%oTnZPgg-1U8XX69uf=~ zS`@BKv;8E3~ju0JmBpJd?T;RzKJ!=tjqvkk2l@IN(Rg3w~+Lk++ zJ3l7w-TsYqFj7cWuxKTF^+C=@tdjnRlJkRvJ7jh6Qj*!Ty^y?u8U$RRT{6?CH1u|>qSz17HgM$jeHF%w zZl_loay|)O<wuryNDy|A4`OoZtU5o?@1dlC>Q@p+=~Nuo;zqG8Uo)d0 z-%l%~@q>TZ5$UM`7n^-U%U+}8R^-l1U_?XhO4OUoPAPs}!-JT{L@*g-H0F8iTH~RX z3bHgekCeOvs{?j6)gHSuC$oyIEc*GduOnL6e!($u=&Ev$FuQ9aY>@nU_k4Bg1nmP2 zjhrahX1sAoBR|TjpG?dq$5ez@pk%zL61el8QkVZQs`lMWw2EPsO!w>D@Lzd4#?&z^ zWxR7lyp-FH9DD9&Zuly2($(5msO-M^rLS&(MC(bxRG$8lllPhsw)~D&V#9^a)CqeH zTyvt5dPen2+%2yt_yr(9F}vArv&ww`+0|WUzt7ZcJCqk=xq&v`S+=_}S9@&`6LFj8 z%!Mn^G{5q-5ZaYOoz56Y62YcVPSfHhEhowfo>Th5XSLDfc70-njNDf>f=;ZNP8`2B zl6d#4YhBpy9E{XF(^sow`n7l0?pHH1lxNQ&8`p08^7zB_ zdr|sG*)sT8ut5z}`o+Q3a8FNaoo2a)r{x5zbKH-uu@1;hB~*G&%iZI)Qplacf2xmC zlFK%n%8nDuL*J^}{|Kt=oIL)!2+-mhL-uX#y8O+EudlLn zXzv!|qV*MCN*1q2H>&))dYyY>Zx;vjqjJfhotLzEt@FnVjkllfwwOFi71IlHPS>c& z#rXBF*|wMO9&{g)yYf7W$KAT5-Z<=Ej`*-@6)fBBVKskSjQ($}p!F#-65|>{QTpav z9+hH7&z>cl#IG6W9kWo`rL@QL`3!*G%p}pUElshFTCP`_eHU7$si80^8^a2=ERSvL zLx$$Fuf84sXXh-X%C&Obl__XHJxrWmp58qxUM%Vi=-*Ff9pVw@%&GK@RfleRKEWQe zJVvbG0{QrlMaa;OEumN9)g>kstbw|5_ceB{e5;!)%V$n=f~}HHp_MXI&MHAgc$29N*ZM^z+y)Mn)P@j6Z zSuhZoB{N+A|0Y!)TD=M=S9`KPst%+}{a zymx*Z1WqO&vU7_GEh1RBXS_P$6O;$|I%-$9he1eV0cLK|LMk6qMbnbTI@1!Lkx3a1Y(pH*q`7m(a#?YV`pRwjWXi6sX@Ra0ruyx zQ@orG;KN8e?1|$~24#;_#{RP7{N(;NN>X3#N#OmtUaLR(V{uK7WzBo3PG)3iLV2V3 z8K_YyDNm{4x3`slPJU@sjnvC+$VQ(_yv|AJcp9kZd=~rN@tbw~qL^OO$(`KnXsiqs ztE2tTt|{HJSsg6CRyWNR=qvW`^dPyeh$q+D8I^(e z+D-C{e_x8_1{mFgN^h`*5!2_!7j>{Pr#T7#8MEk4YkSi;yvna2?3c%w*J>sgCI z1lGORO3Lqj?uPS}%Gb}e)gl)o6JHk~YAST9y-kw`=7iwPxgXD3P9%Y!iKr_01e&U- ztG+3L{$iqtCe@4lMgyRkNsVSfiI{D% z>+(VwxgdO#qDBz&pfz@|!0sg{KzU-jcM9s@BOJo4SCt;bG}ohW{gxrcPt+VDH8JAT z9fzCdY+$|6{m8JK6g8pJluD^Gvh;l);JSl5*z*6o^8MLNUXWa0c)5P{HV2J(-AD)Z z-TfuP`{s=G;a=zf{~VAwAqBKj6tw=mF8_Dmw@2CBqPSgpCvCDOtpzITxFdhr2}1Q8 zV@Xfut&NbO)mMI=0aI>T^Z~h}giXKBo-F_)0@Dll;cj!jAG2Vs&v(Wi%W^lG2beLb zRo#Tezt>@5>66lEH=&LJeh+)T>*JMb*9gUzvr&uQ&CoQEr>8u1)>(iZ9#CorbD|RMqiA`zw}t8l&bgSx7-L02jzjk#)X`bnQ}!oz>b2v@XUU1ZGWBEkX`7F zI4b5-O9nl}9d$(#0h4>luA|}HE1~i<5DAAqojRSXG_c|#1q`QtW{dl+Y;8pVaGnA6 zF1q&MU_dB_Z|8ol2u6oo;ZWiHf@c@Y&gMUkH3Y`#uE!MTtyT-%4*krm8~aaVTTr>! zM;v6R3YVJ8uXIxJOrf5ccFwAL51*+4hBjt9Qh|dxQs=@ba%+f`*vQ;Y$y!-3bAX!v zy;Z=~xg%I5Ke%?*8LZ{4bW9u#mYGfe}R zs0&RA{M5pRk|KaArTUE;L5{9b<4c2wWJ~FSD5&%qeJI!Y&kI364un&8t25wp=o1#| z^f4_f%lR=H*a`l0|v7n4rO za7cEtmlnFmkWlOPa9Y`r_Z}RCEZt!I^lFWWt#1n&rb~%brut7EKQTZ#4k|ssJv*2N zA!$)HnEYL1od{N^x!xkR7iw@u^a13}jk>XN!-h40Ec}Fd=`L(RnH1m@HrIJ z>>j~X>xZdD8@(Phmn+_!8!zVFoE?n!7kMr~{^_S*pUwTewHOlRSf5b!^v^8MXB|pJpKdwj$8syWPp5z1SssJ0X!=4sp`H9HjZt*^vYLB6&w7w{T zcC?5h+(;wWz?kC@fw^$g@U!Q8lX*@A3)PD%*pd2*My+Dy=WKbi+Q;wYeeqebLh^d= zd{DvLqf3~yvv*kR#c^FDOdy_b9^@MeIfP%b-i(ph1S1VER#WtMTMs_glNZ4Iwa}=2 z&w5be9j}Z<9@HRGR11>C)4dVx9>0%#Zl#V4_-FyOxje;xj_;{7%9f{sX_$HHuv!Lk68Kq5ExnCT+GMZemM3@i zmAXVYRJ!gW6WP00;D0hRj|>&EE)XIP7^KTT^P6Nc)j8W1#W@W>lih;;F6?kt`sy#? z=m4Sjl_b$o1K(G;$Q?JgRr`-Tbx2(%j&n4^f*A4rK~N#>QSc9+h>@(TkG!_uhUr*y zh1F9xo1eXr6a&H49e?)lZ5Ij^ebK1UzzY;T?lk#8{heHXOpvtvdea3 zPq-!yt+)Jb#%zWCw}ycRxoxbugdN8JUe@PL%RxJd|K$*_>-)JmvMBVDZitaTd@O z@iMSVSnhb!(^x|q8@CWdWT4jQa~XDb>(H5DgY3UmxPx$erVt{WcR$(>CT=u6=geFG+>{C{Y!||VAwEFm25=!XCzfPwZ^g4g&Ljv{p znyp)b6n{%c_4xh8p!HZ!pn5mOro;9%*^J0+?9bHj*BU>)V>Jq;ii}<#sg8d1e(z3x z*ISo&-%wRY1=fn}GREHA7W+MOR3JE*o=i|F0YwZueN zJSjAs_}Rfzx4r}qsqGKl`P_MO9@29Kf!+sJWk~zW+*z7xmq0eZk_2OlXZEa+OTqxGg9XWK8lH#ymP{aP2tChdejcV! z1FC4lv1nOz^W2uRw0vc82VpDsMct^~+b};1*`|2S_FkOHT!Gbna}9X-!#JVwr0$)jJh=(&H9$1Bb4;p+$vFb3mR7~Sla zNn@ZWeGB+S9D?Tl;=;UoD%-g2od4lrLF933PbL+_g?=X8%U3wwx0?>#W01Di`@9Tm z`I=aH?1W`q9jJhsd{_IN#7J!kD4Zuto}D}rKCUiv4-zgApgyVImHSe(dpx^P(4nXB zNd|n_pQg=AOf1{^KlS0gw7MZofV`jV)%9GIbKU^(Ak%2`_3*#=ZGGZaqF>Db7;;7b zB~Aerj`}U0oOV+s^#6)5T}vE4C&sZaO7*`I)Rckb(utAAB(~Fp74tVUxWBQ6

f7 zy)`=ZkN3vy#ZPLp+ZV{Qh-j}lcDF?gZB9Jt# zZOyigEog)Wvpi4cHZ{rC2WDfRy|k(TQ)n=t_FuM0A>w;h;?MC7HEbh2(v>+5+9jIA zmKfF6{wdbHCaUDc<_%kYRAc+6LQKFn8}?4rHV=06Pf`LfL?N(hLYt3vBEyjU6e#u9CDh6!I<$|P*{z+C|*dCymNJ zMT3L9=BUvKMVRpZ739Q>d{BzP9G+bg9iMFYzpiIlzXg^6`)h0YxW}+i@2DM@LG&mg zY~F+6^DBY=CJ3llH3;mrfTNYT%P8@Ibe`P2H}4W0zp~gB94FROBY67k8B$R&=)oTu zX|*=cqo;ATTiKz4g=J!5a+2QYe-+gam4=x?I^Sz~LFlo(?n%4dY)oVRjFBMRcjCUS zr_rssgg#vS$;U7iz-5;ge65&;HDWd2?Y8r#>|!}=!H}fWYmmjAgX13g=MCCDvufC= zyR%UYPvD(~(Z!eWFXHWeLe$M>6Z-hFD4AS8B(U{APET$!zF?u?#&hfmUv1+pF*d~6 z5yaUQqS7KU$aj@&KP0WAjd-uxy<;V+LSwjb>_mHow;x77VS%@|Osh=l*h_t?tPzU( zZuHD{u9Y66k`ZsU2ewMKP)&W~xyRJ@`}c>1W1}Ao0RCllVz~H(WNxVs4RGUgbaSHg zv^iS--chS=V73qK8lq-Avg*7mnV6!kUoVb|gDHwg3P@sK9M>&vxbvppZDI~vxZX=*)5mwcr9}=QvoYc| z*J%FY=JJsu>*k`X3U#RZ^7}Cpz3d2`%L3Ie(-u@S$*k^gU&F|R@~}nD_ktfLN#4P& zPLF4l%?B}q=h|dor_YLMs|OB!x(~hm7IMEXRRx@kGHC(RqnDi4|53&@*9rk>h(7mV zA|pB=!cj~Wl{xoOq9V8p5IsU#0GCK#O+mU5cyYBXA!fgSSe=CW9J=T`xeqC8xBhuD5Iz3G~Jlx;s47zVeMQvdLpK(PwsHxe^MxrlYe8wGRC8vIU4^vuyLJCIu z*)5IP7(Fq`j*|8EM{qmP%r^r|UZBi76%7u#|IQE}5?CfztbI`VI+e(=$}C>%XTr(N znEup^twt3RLj^jDRet6xEF#()Gx|1wYNQVsvD!eS^tXddOj@tsD;qlPN2AqAAyRA@Rtti3U95XR-8CjwQu+=q=dlQll zY!8U0z1kBl_I0lEYa;SDB8{UMUDNsZ8eLa>FQ8N!_O|K&m1lvbdO&#=J|to!+XL7{ zi#7cHTkoiIJ91vSvHI^*lAQ)PJrD=83c8OG+r%DEAV_}~bkqpMZLe68<>WILh5au2 znlJ|qu5xj#^&h#+cJKcleP&`fG7!IDb>>BT2gTd$sUQSMqo0E$2>m*Gu>P;7(=&IIIk{$vFQr^C!+ys1JiSy_M_W#dBx zD85D>M<7F|^!^VCV?WieRROREByzd2VgDjw=#va$sfJTmet+RvT&}|r1ESE8v#61O>n}?in|u(2rI+e1i%R-j;J4x0LKVL^ z3F6YP7)B7ro@7V&7Xs#A$LrSXBo-OjmGtxc$jSKWm)>KdHW!s9fly#w6`Ws#4b|C; z8(!$0xf=Iw_QFI?Sdlxfjxn z$my6~k?y!3d%vrBXEMyvSorKrJC3U zIvkcm|9NGqrom{xZ+q2g>A-SFVfn#8k3#!dva(dq@%=IY1S#uWxict%4Vi7m_0PF7 zj7!bUjp0&w$O&XtAUZoW)HJNW<052=azY<%O`cim?rONc`Y_-s^EVOc9yUhqO^ay) zWfB#R>RoZW^K!sPxi+uCXen_pRnJ1lqzSK7)oVy+f*u%ueegn6xZRerf=V!x-}Z$f zP@@BxVJJ};*Avi zYL8@Qr5%7mCB+wtQK>8S30@h9R2JsBE){=MZO^bMjSWl@*Z7(q!{5PuQAyDbB1O=G z4s^B#T;sotpGW1Gz9>sSh35oWOfHw6zzDZm=Mp=0(@=hd!zDWsz4(Bt`m@0x-G5Ai zUFGdG4;F)37Q30UZ&bX`X=u0(DNKJ@D*A}K12X9+t};@<%KeaNXLpyKR_Hwoi-7%% z%0q^Io+8Gq?7HeFpf|_~xF83d$1RU1@X!Dk_1A!tl_tQAzqaS~1XOv_A#ThEhz@RA ze&*~d7M3@qx~g1)=^amCk%|~#^H!h&^IU5!3-Mu)adC_$^%qHz`HL&jfv?6v4Z8+0 z6C?KwpNe?A7ogpWzz7PObf60BtFRMJru;OdSuF~l%5-^g;fj=9d3W17_UF^P6FIkHT{&$b1^mWB`;+TT{`f`ls?CphC=EuwFoI3_^mrHzU=5 zds2A3GMOS2Wed%@CDbIPUDe+ZFFTWr57frbpU}(Ps5W1<>slK_38D}ZG>3H7SL*`d zm2*d8mBD?B(!b?e&&IA%xAflUAFtz^$rFEq!Z4rq#1kWyN%t%xM4L4q+L0MdILs!W}GyWzzHC)ijC~pY!Qv6s4Pf!3=q)WZsA@pq^ zR@ar?#Gy|FTRI09y%`I*@nsK@rY;XnCa^m~YdC28*~0|ZY|L(Vq0`D6H)!%6Ykh1u z!4rON;26`Sx>rADt)=lv6yZhHN_a{o5-0YRis`=~_e-1@!p!rg;~J*(j7Ky$=RnVl z0m0lPBQ=}xyfA1FA1{}7k%};MBeIp5X~qPwu^tv2F9@~IJ7(@p1%(2#E<+$WqrBCY z!*-w#y8(#zNKQ4`D0&GYr^UY-{db5p*IsVs>iZ~OwVoHDAl~J_-D&C&re>qSx0g_R z9hVVgTt%C_Ssh#nhJn#FIYw@?nVbx4sN3x`DJfzpcP_J0a=vMEzMo#KYkXeg7F?<` zggGj12ml8{#L9I6v{dW7AVx~B8m10-CHt%nxDCh8Qb0I8O*BLvv_SnYRd->yqC3t1 z{v%x2JXe=v&S}}}dB@{%1JwE7Q+yBUoSQb395S5tq6v7CBVgzZt%!zbl;6K^N~p2B zMZG)w)r=4=tWktW4V;TcRJ)z3#E{QZcnl$BMe4BF;7@Hjt_%Ea!(`JyG-cAaz;Ybg z&`F_;429qFYTadf-Qn;%DVS0YVv=hGLONd=Wo(=w;rcbFL};TpeR?66UlkF?Hi9lf zzN%uOR+OV_-e&{b!ufXGvQuK9Ny=t+L}m%ZZoC zQ%R$A`lMOgs3%}+I5XledgxIu2T+H>p%91!vN7hr*_aEAA$fQ7Y%qu$K?y=nAsJKp)Vhv6 zV_29%bS9zX6_x4Q$9?O#T$K8 zQ=8MzE>FE{KfroL&*KOpR;zYnB&5`dfmp_#!F zeOn8Lky=H|v%7adMnlSlhcJeC%MV&-{?hS0{jECh#1&{T^3d8u%^Orlip^daqLlQky!bE$8za2B zGXCbBf-4xn4DouaHa?_hrGi+@wA>-|Hw4mYzVB``#NRh$^`r=5ukgPoPW@#NVK@l8 zBT%e>K5dgMO#>w>bJ;|Dbf$_HH5iNE;7X(hsyFxOsikYby$Q(1!}VWd1&g(PSoy89 zk~cwTeG0)KK}G4F%TepBe1yAb*=klHR5I){yd0M-(Do0F3NUp;JzucZvc z-5y$wsFzbyDBTQOZ$wpWz%??B#i+6BBn7BrF9BuKh4gdHL_niwu!RGtc%L~V+Epp- z^8(kamY1HBE5Ga&l&tyRIqD>XaPso?=$ghQkxNS8V%W5>>GKhZ?`%cXxL1f0JB6g^rBri(3ltbP z!~4^|6A-5P{I7bwO`HeXJwS4Jtj=EV+NTL$kmKl{ROEe8lX zsm#dtitGTE&E1UuOQ7G`0I`9INc-H#s7)lYkP3CC z>@6d$sK*Ym@A(cBDVbb64QE5pvffTkq}!29L!Qg@W?$%A-NkmuZ=xUT;V(a671mZ0h;U$>{LDH{nURaL)SFp5rhI|^b*5)+BBv;_ z*1z&K^+LJmDFYHUKx-8J#Ob0!~ zt6U0Jj!<7sDdc^rDnNto1V0umb#))2S&Xj|U_J(dDG)fMu4AB${u*o%YkM`|y@w9G zRu^z)Bns2)VNTDuMv%i&$Mf5K^_Ac--kJV$(N1BPetg6Bg{YOxEsnTr4KoP+vA)`W z@$tj-TQ$sOtp`)(3m+=5I`0N_{Wi9QTvlEob-x13$Nn$;;D^hON3+|D{X?Uh&d*cr z%r{%DH!ibw8WqzTceX$3zsw1aL$1us{mLkha1R{bMy651yIQQ>a1MBhykcgBIA4^P zM-^p2gBxQ)(j5%Sb+auy9&Xr`(zs&xe%A~osoG2wU;?HjNq@kp^WZhpQ_8QO1qh=B z_M#EENw&ju-pNBa;O#mcTMvc9ua7TsMU1)5$VKp+RhhGWHdmkv{$sYM@AeW>ZBmr4 zQG9L&=UNe6Zp4@ueGem638PYvtkUHnArGK6HeBeNKar8ItQIwCy_-b)7uF9YSptTk z!NUay(5FdLKtU?`xdnD4=0ZIa`zP9rdfrUqvOQhg2Qm^WYe;oFNkdykvH0xSg2&IJ zK#SEwS_gwEURE&d)zpOT5jg@YXXnjW$`#g7v0Cp=ir>-1mBav5DrhgBa|jb>2ve~W z`d?3#c(3G$3ZXuoMk{8SsW2ua!k z;AWD98v-8wWIU@fm{Eej_{CB7vErSh!@da`YKq55j3D}YQ{fiNXSrEfZ6E?1{vUU5 z85VW7y^lXCf{2tL9TL(h-4fCb0s_)V#}EQiN+YF`(kUQ~#Lx}W9n!-vgeWZ?zdh;` z=X^cqy!yZS&vkiSzF^|B_u6Z(xbJ(ZQR|8&@j1kObluQO7IITg4O!A<^fTW5{`Ip) zW9KU>A-BzTho*D7s8E$JsUh%f(1ErCq)CaWH^ME@jmjF!hKSmS#8&Db{WW$&OO zn;2sRll_BQ#AuTwO7(8{F?}{i*xOwMwym%%@EHs#DKt81rfe8fNsGfw-hSslalHb~ z8~dd7flABsIXuGUY;f))>@k;+3lqH;r&-|rXI!mK0~kegeA0zeeSjTQmfYM<2$o7_ zHx{oI@_dPqVVQQGHb28z8O@^TqCc8V#?}2Om^1>jrOcQ|t(ikBw~70eeA4x_@4Kz! zZRRB5AdRTw{8+eyP~C?V3Sg?3Q1BfCc;GoWFIE3;?nTLWZ{Aos?n@TndK$40@>R>R zyfN%(m0S=w?!KA+J^9?cML*ymZA}YBWsfgDZ-u6#zRMVf8h_fqe&pF@!ieFx z7^XN!|1}T(Os%XP7}(RRm@t+Q$|7bI`_H-VsQ5$r3%t)pY z-&B5>xHv^II&-1Bfa`j{_O4+TjI%5C7BhxSta4&<>d|UK6Zd1b&n%a-qEWtB$DdhF zz2kwOqCrA$!M`XoW_pxCsjyokNoC2EurkpbgB@iWr)9oIJj)|_uggwo&+f)!W$v68 zA|*(vwR7gDYDI|HGut2*y|%X)(;l7F-DB2O#BR5N_(rGTQ9HT~0tV(r6p7$^!s!FD zLnH$e=X?qqAM83dC>4ToVQAOC9uR=yO9%Bu&W#jmq&73~;Rp$du1InOgj~>tgV;1A zDJkiUXhr>d;#zgvc0>}%v_22Sw!ruamle3Wl}bW0N!EE>`vNcE7E49LB)K`+2p+Yl znvWU)Idg=FH(W-2?M|J!PLuiqYA-*W`S1;^9&bQkpfS)T-7UV^iPQrIC97;<`B@}E zux@hnnLxd`AQjs*STlYK`bOH_lkYaiP}tie06F1zSsN1C$jXwAc9Tbx~|qa!OZl&z;NC}LaC$6!vHUVjQi#JJVzC6QfC*2c$|*B z?+HgGQr5Y=>Ly#i5$vA!C62J8^X&4GBT0sTxsmd3$&S|dP7u)-%Bs0tU6q0+11ywM zXn_q}n~18$BhB0(y%fle*b9nCj&5m|!u6Y##bi&nLZ4qcNsQTlO}dYs8m$QbR(1y) z)717qS!e_z01E2nq!Z$B`HQn-M9OM^&I-?PHN8wmr1JGA>m}RgqVUT8_To^sw#e!A zjER-DcEr-PMSZ*oc4Sk2tLc&&6yxo$KYlf(CjPZ_+fdzA049y&HJj7@zCX%po`CLg zGYc8{$PK%&rvd)lR4Kb&sFx8tO42G)RlV0z=rB(DblJSmDe)keMEipoLB*CB`Zp>Z z>jKYs5nMHAs@Vv4U^*cP6uQPAHyiMKh}@#lp%QfpXkIK|nKiu``*P$nkz80ijKQPt zV-D-!?)2g>;{R@-8EdQqC8i;k+G7+8*?WNiYe((~y>1OI(|W5NSWS7NLhp zHp?0bL7E0Yk$Sb8-+ele9D2OB3I{5lFlgx;6y2Z9^Knf2$Ww{w$s>&9(2g@-k=x^fd~Y12D)uDKxz^gW0>qY0Z0sO zgL2xlgcl33roOg`0jAEu@U=|hqsNM~bN)ljNtWkC#gityxYy!gnO*#3dkKD9O;6~_ z_u5YeN5pf^Z@}Y^Ng|#ETTmEgCi|1Di+u~(GNoG*)Zcn6b>cCa1z&yyo*t(cUC|2} zu+nM^zz{DZrYc>CB_8HCA60EX(b1y87-J&jPV%@~f*m3Ggr4XAMC!}aPWZAE>=aW0 zXyNMWqsKS;EI75If!OH)i9ZfP1hkt9{2JZx$)_wi-EgbWprQ6m(6UnZ*!N*ONq%%R zUTa&M#$n>x_#b`Dj1v!-jXj%tP>MsJ!(Oi5Kp7-|Svnqu;A~R9B=gATyd$Wr-kDGw z+BnCH(M1?3`KkQ6J4BUCVa6uYsy2D3$=BCxAeq>0YrX@RW|<>QLfH3$3@mXxJq(m* zVty<%&PzUw4Y;Y*Zp`KLOl}Wh$b082*&0@hS{1kH_BXID(HmehA9=0otN^(pim|ud zRbxMT@3>TDk--Oo_zJUptc7 zRW%FzBoEYyCQ8_=6cz@vl&;4(s!4uRC(?*3zW+S@(T%)~n_wyq%R)IIlaZPTP!&tT z(VB9msw^?t(#l=x)LIXz-QS>H<)vt6R;czJqFwaRxd9fHrltpdFpgIStCE1yO+O+H z06nDKw@ileN4f!;bNI9j`U+!gy~n7v6V!kP7Qe#+j|6j3r>~~ql2MU^ow)t-cf0b3 z2ffD*Nk|GFlBnV#V>dZ1lq9$Bg&$$W?t8z0RzRsY&ATP?-@ZoHc9_A0n6@cVzQ=iJ>i@W@8D* zvve(V9%@$!Lzw0$_^$B*hbqZJ-yX|hYrS?ewpe&&GfFEmaA~=aDuJ`_ngz9KSEDF5 zL8U8`Ss~x#iryf-a9nEFDX`+xH@5Mv{OkTNuV*ieP8T{lOb{wO?eCg6<|-}qBtMRb z?ku{(i?TkLlS|qtp0htvrg(3tc#^2_9Urz$=Ye!m<_IZc1ZkqmObPqf(&PuB&HvXQ z1CG!$&{1hZ1@p61Bilbb{`kJWK*HJ(&->#d2L332x^h#^iEAGasR5KzZiZRK7(4=F zNgj(D{>92}1!B;&P49mq_CEP_+|OjAiD+6&hCAJ zex7UBVcNzIs+YLPA~S&g2*=_5!kvB9wq~Y315Ge6Lp`!y`lOe=i9rT_2#*r$iFdG^ zg>vdv+q0HLO_Uf?f_yrfeUHPuCuXdPuGWp3j&PjHJFw(C<&@S0*MftRivSkSK%{(1 zt`1abd5&jzVRc;mO_UzosV|2lCc6co)R2|x+%yiG36>aUZMJ6+saH@Q`c=B$8u zwt0y`g9Jz=mQR7END{BzYeo&aW1BMb(3&wr<=1p`@9iz$3}1Ovr_w>WrT8AYmwQRn zSM>T$tYD~ugbs-*FD~|T>UuZgC0Qciab7CIhCiW%%XEF@agPtjD+e^3l?a@rFytIb9@7ban&Kj%~C#5du1&H^U zo=PH$xwhkncMcW4=;0}qAp#_Wv8ZGs$#$ji^upcU-HpJuAmN|2Zd+K|+S)a9AMNJ^ z5fodx_3oTw)|K4|QqYW!0%gD+jB<;CWVOgIIe|`oQz9_iNr_Ld3sAAXa-=CjR=;OI zx19dKSG-l_bt;4iIG5xox2CbxbAV>IovwcCWXnsNzy2MSEyN0^40O8ks{#Cl)aI$_ zc*SRaue&|VRjYONrO&@`+vG=IQy(rihg)qwsU;VaJ8W7J!YFvp-nJe+3~e&U z^~w#G=xPPfoSyWB?ELL_n-O8SPeHrot`tj3q(_5LreT*sf{5oed(tOYsf+8U;r`>J z?@b*Z5T&9Ochv$QVYau>tuF#7jiR1?#pe#P_?su}L90WJM8g@Zjq79N8eQz-?XVls zoMnTOzF)iS6c&Z-a9g9K;Yw3iP?Sgm)aK}5(0zTAtz3|+9nc*2qFVtZ^KauD&_y(v z^q;P@=^e~a&5%Hj_(P67E#V9$$3~XI7bY?vVgRo}+0&pH`ecoqoP0Q24$p&C9%n;s zNfZkUOQ*@FpPHJg_;?YwmcaB z!;8{cF8ENzS|3jd&|@L2_-afHhK<7Z^Y!^;Jl_GLGZRdy+eO*m-BpiFw4SbD?1^J# zzC4_(bK|t1<@Y$nYcnYfnR*EW0eM^(;1k03BP-+R-&C3yUwzEQ8BJ$|B26$5AlUG} zy$nNQ$wV~cHX%EiWg?mflh^sc1brOcYaUapAJFb8%6WyOdyfnJ|sL}bHYSY?WuVu5(efvCF=-K`{&T8SLysG8F@-1$0wm!7SN1b)q6 zH(5ZHknGjKSctdcZ~4LPe|EJVP`C0Snr3(Qo6TJRj~}znoIhxe z7^b>P)95o2ehS5EO)WNm{fOCqvbp8+^#Q^2%D3F_m`zsRbC?(I?BT?_hP1xBl9Mrd zrz;`0m!m$Md2Gp(f3lM#YwqqXWW)xpw;Axy_cy*zo{c`FGE3NF_b!-~zKd1`UjnEP z-xqo@M7L~`JF9FzQ#YZ4xpT3T86c)-^fj6p9n}zJQwJ@Qv+Xu)NO9euM47VW+Xzy@8zrQ1wVw1^OoQ#|)Y}OlxIbNV0uP3yFi>F-S z^#gTqwKaKaIve!Gc7029E9KdQLlngK*M_eFZm%>l7efb!;hBeGA`ll~=#1_pDC71h zBv!!tqzfMI*-~9$+lDud+Nj)h`a?CSJp+~0xAw6S!8I4bywS;^Ew#2I2B|}4^b|V6 z=bmAK-lGrCEyU|^YFzJ0#}A7m^f8@zPiRbLM1@~dOad&?2HES6CT_O#KKSgQPqsdP zk4UrUI+M~f7rmzp7&?)S$#)wVRfW=3=U_*hFQz2KNTS@>JK8^d&Um!0S`k4{LDK2z z6L>}fS{@sbo5Phq`?3&ChU#ed$C-^pcp-tCf_eI1jQxk%t^vBBVkOK|mbhR!*8}tQ<08q9 zD_7{c`b`<=*KWFXsO^#8iuBBI6J^z0|6tO6W#{Nk2$3f7b*K22HJ;9R$Je*-UymSr z1GX`#k&kxRKF$mc-BmeL{|4s-p z5QO?m29|p7Q>~hZCDIWjpW@jUgyP+F%FRO0bfPFk-ZuFhZD*Y6n9x6m%vRgt%%t&Qu1Iq}z}raS)KD4_hwJ z;ekAaPDa>@^I2;EI%o~wVbTooC{imYlb>lN zx;)igUUscY8Sx0LoaaaH$oQgQ6COnuHC=fBb46Q$g8Lh`CxXa)3-o(h(#8E>^Vm!zd{@L( zlfpqb?~EGnKMB!;na;Ps6r&EnBV<{iV;jDUf`ZbnNC)@2I9}b2=CObl`cH|s z%*~0c1peNTFPvJ5>u(YT$O~eMTEin&EvRv3YyKc}s%*I^y<2$HM~r2T1~(Za7T&Q8 z|2l2E29-55kwvQgG}Oogfl_p4dj71>ebEFCz8vdZ{2yHM?z>kzvD|h$dun3RJDsP+ zAas9yY$Ppn~?YtPZ4w@6@36noxse%U))RxSrlz(_WD0Q8KQIVG)jEVAv4oUWL|XOh z=yp{+WhpY(Wd=U_F5TM%d=n>PU7?lsGSJgV!UxjmQ!l|py)G#)yY)st270gkVVS8) zn|NXLhwH!h0;^3u=nk3BjGP$BD*|g_?o;-CEE_NSzOO8HajEh1<=t{w|IoJKNYrQ( zk5O>3)LP_ z1ddZML#XfGP?5Aqe>dtGgE#hg*dx?QF3cTS1~vN8BLKVfn&yL_xJ!%@DG#pl=?|=7Mh+&oOV?`uprWc`S!a$FLhDHMhA|q2FP{{suj*<79 z73-+390U;nfp>P>k(ZJB6@~OZz-Q4>ou8IcYz2wzr!H_(b^p&vJq7^9cd}9ktXEQs zx!)Hi1$440=ZVnal;q496mEns)v_EtzaR-``ye@i>bzsxcP=U=ApUiGG0UH;d9z@>g{|NZ8%Fu-t~T!6|KDtq;QC%vGI zmsT~Z)0(_dYnAM|U2a@@SQ1v?mkG&-!Qqwy8wtbbEU|c4t&%{tYrI1c@SikQE%L&T zg=iG*J&8GPx^m@ZKx|&wb>)~EE99in(* zRi!Kt&|8zd@#G5tH67?Q&IN%m=s30xLf1E_`7|9K10^*8RN z$t-!@UXA5%0s%?nAtQZ#=zd#XreGcqC}vhj8bM)IS=&V^6in-GFDEK_wa`fCy3_?= zeR8&i^yC^A(t$zb9EGcz0YQ&HCZ|ikzg6t$a*F@tRHZ%WcyE!bg4G(~^4h19#ggBL z&meRMRHyQs1lI@qi90q9I(G#yNC!`5&3B1FCW`mp%v%BQd$Qvw7uv{!A@>& z!AVJ^SY-U^Alj|mdI8kPKg%L}eM$`X?>=bZvOjy;=zAi~ln_zb9$jXM)&xi+Ok;l+H z4?0y${}cvinUu^A7@phqx0hy z{3}NO5x;@F#vDvXshOuL&PE+lboW-VVG~;vd40-RI-4xG@C?-dRwgbY%6W!zwD6w zr?|>p5845Gh(YL{Lb3%5J?-3SEZ$c!z|5b6h`Y0TpHsW@NlX^Da9(=56RpyJZi6^P ziA5GY4l5ODaX04oVZy<+nWXwMFyWB zx)x|n&m~&cBLaNRPrCIYLb~w@ccj`A00?g>H_b;q@+uN8qE%wR>HwjpTiIJ3jMdxk zKq;$}06cKV?_VAv@)$vX5HwYGb`c7~V1^TD9k&Bu>~eW~zrPTR%3Y9b{)mGA3Wt&Z zK`Ga2b+dajNCOcK-IGr?ToMMpd|ll}pk*Wh8y8vcIzt37%zih#HJIVenJKB~HJUH< zKGM)Hqz(6ECM$~CYnZ!`Y#jfdD-(ab74~4hM%lC{p}oUnlm*7Dp|Vmerri5$d{#_e z?>*rr0Xy*ErSd`UzOOqxjq5|0(VfrlQuFt$9RO}r>D7YVei>O+ZUeIE5X?4DQ`NdE zqA?qy2?u_BXX4OFnQEic3Pzow3J!pu(6kmq7uCo$y|iDqh?tI&{P^DZdpS8>6zgic zy%(PyL#K;a5s5{kJQ8KpZdKQR%>e^j6aZuOPAK}rxSn|6k`fNmY4lHpoeQfr=C z)^Zm?T@UHmWa8SOjGL|ja9=Ei|9Uw7Om51Zd+C1OQAl<+x8=}-VehY)BQy z#`6QcWAqMl5Su!^6Svi-rgPJ8$DpRmXOIrDsR~(Y10W4tz(L^?mN?vOZ)m#$C5K`z zY%Tbap7O#@48%6>?FJ}0KRWTSd$+M*_tC2OAw{*_1grA%3HHwKm0@cU<_pYbkd zR{VqOCKlUYHvIDtEqMVeu=$cg{UJSLF--!C$;}a%#mD_h3uF z)8rZEu%XFAe?L|kqqm+~mN}Dc!=L3md~}x!MX@a43=IFNj>Agt?&;YpRy}o9OcY!h zZ`h;QpUEb07)#W{JFxam_&j!?ndwGfn2*AxX8(%HNGHd_9!>WfW2!^oA52T8Ok@DFu0(J+ou`7CUWO{r6=(Sxn zrhYGtsT;jd37XCil7urlcy!o63pAbt=r;v}S?b21jG$BRUN*~-$Yl{(A91rSBxFxK znQ-Y0XEl`Hu-$`ZfD%Fxx+j;cy>xbal^6i5;;nnook^(mYHuAmWg3@FiXx-?kb>mx zu|4MH$-Fl?`haQTE#Z6d|2ZczD<@5okv@+MaDjtJ=2mG=EVErMHK66$G)&OE^zQlF z3qf*G!~_9*GgV6u2Ja8$C#yM3NlP6$w2hmNz;ze^l|cyI@Tr<4A_lu3wWPVj6xZGL zK28^1wRFv`C3p1j$0bew#*~a&%Gv)YGFOMCWyYY@s zm`z|>XA&4#Se*b9jf;bGck9#Q)>ae{G`b&71`!R}lSrmq62@H(<9BF3rWyF#BYQ6F zjl}6F8!cggP^`^K0O-^DGi5#4yie31T@&C>MWr$HN$q;q2RKAq;6Tt3-};NDV-X*k zN<^3R4@C{PJfxz=9nZ6lAbW<4(YuBk*t2-%*RxJ%4=rB_laTiu%rg=Gfz*L4pGXa& zRO}3DEfGh}bIsx8X?A1`vGfw_c|x)U0*w8Rs{&gw0LPTA;bZw|AdlBwrZ0!;Rush3 zUUK<%NBHz8MCK$FPy+-t?2agI=Y&q}i1pjN=(8qg!n(mBVE>;&$uTDh6aBhpN{wR; z0M68Xk|(>f)B!FrB14J=GUR&Gg7wy|Og2)icm>8{-&3i*at(D61SCDITPobV?f0mq z{NaZSzVRjgR0hJsn>aHnA@92Yo*9Rjn9)IwijBR6qZ4;EkYB9bgL;4vG9R)hll*6( zF1lR#DEu$bOme&{hl<&aV!Itbm0YJ!@3mSozVOeg*(>OPWd{;f4{`^>SKwzjR1L!% z-JNQdUq1v%)?`4*?f@Ov26Pz{N^NBntg1nbh~gX4#}Aph=07ImtVcm8!^p`+gIA45 z%U>otI_wYRycBfCvyJxIOi01|pcKMrMzm}4$P<5TP4z=?8em}p*ZaHH5i63)6L-8s zatWsVy_xhHNE0iWy)9N8dS6aS{xM$VV*o<&JS;v_=o%row`;DTJ&kG?8!A9aAjD&)h(S!Qmh$s9xe z&c)hFx!?_(CzjbbEl^YndO={Xs?bE)z7+z^>>fuwZNo$a>|2UV=Goq*NB=2f_CAQt z!4oBj*)$xlfpU*YR;;@CHUasoNT5-g$2YluEC&#Z2oQArJ5tdK^o>+2tsnikA-^<1 zlw#zp)dBodsTZAIIr+8k`u>-3KqstMjid|TF~hgRDw_*{Qgi}nJ0h-&z3Gg)sg!z2 zm-v<`Sz8U&03M_f*pa&`_j6Y>Kv7P$po3mIR>j$x+tJM}o|F{bJR#g93B4@<|68Qk<(Y<$ z_(LhUzcEXPV`RwQ!(_f7bcO+XPbRt>Ec6^ZNYs*XSu#!yE=|~4y3@}}h=UL>Iwl$5 zREKn8!#^}X{`TrAS7c`-4=oc;SG>E?>2!lTTSTQC;H~_zqV4iRUi)@*Rg#F?U4fRPZ|z_#ZnA#smv(qQfCN%K?TFK6=teJ7if|{`58C;OS4>; zg}zX)8Vg`HA6Qy4014ti5Jz?!I?O`Vh<6sCbUivWovTqIMWNICXtJZ%+e~~+qjO)lub-U z*1aD6U4G~>57ywnP$D){wRv4BMFT1=cu6b$?Mw5Rs-2H8BlO6RHVK7DfD)r4s>;fQ zq;xFFXL+9X7sVnlIN{TbsssUM!|2ITa&UmljlwS@Hdr&C*0TUqAgs21H=e!sw4F{O|9y&4Tm0#DQG!c0D zb^sVz&mQ^cP(wH=S&j$GO1ZqC%*X%38S#oC^1T{-QcKfRpe(wn<9(lkDQa)3n%x@l zED%sdG(^m%4AZw^BR%!$_DJJZ5*(2r>>(kz4Hojj?!BMaMzeQAb0G#+l20#GM0uTX* zDuAz@j|3SJ_l*Gf-ml==-J*m$weB^4*Ws(@ z3#xDk78eO_p0vRBxlc?PXw&$y)cWLpGH;F6C2_srQWSY~kcq@0*4j)lo-`IwRLVuL zDPasII(GZ;V;YVO8#J%cG9kKKDdBTMl5Z=CT9EN&Q*tvc&+p3(9RSe$4gjS-MFM%m z3!upNp=#{~(>v~wjrX8wOH(njL0?tFu4%#DYWnJi@zmJR9i}H9eRB@bLP>7>lC@Xz z?<#g}iiZW-#^stC8s(#jc~Bi%xe%Tr3!J#Ddjj3?u@UIL^2uqQi@=a$g;**x-1dJo z1CX#yK%2Z<4z{4)7sNexG_BF_KM0X6kmF2FM0LXF6gbPQpIJT<#J7Xz>$fdP2*6_K zpCcw7q?D0Kv&~0z4V%m|>Ko4Ejq8uATG$1p7%%B?OpNvaQAtOi7-0}K13Q6C)}v*n zCZRG-e%Tf@V-Ws!;nxK!3UQ| zWy>!Zu5ulR9<77Je z=z3NZ&X6!yJtQHoM!F?Ruti8<`)@Fr|F%u~6R=UyneE+_85*3CBAf~)?TQ3f{gdet zbMijC(BRW*qlGc$BbJZ4Bzkugu*2ZFj_KZ_Yns+(3_2M8LH(_6GyW!Q$FGsXQU9fp z4*CRAW&p28E%X*ZCMrp=6<3`6xDxw?t#%I3qDB3)~-4qP<*y!tpzXtOxR7j0F z1-_nngjTsfU6AsTMujdyO+Z|JO{>t_5dC`c_-s9F*r>YA_bBb)?alZL(;29T4X-8kFO3=KjZeI1LpsZUd6(hYhu zz~7_)gRb)T4yivE$nYsJww!sI(9t1nBV(xSDAm?YMPMJ*Sv0)qsJgy=TyCOM#*>&wovd;ZGAM2{6;OYX4nOfqA z*OVghDDcW*em!&dORJ-)a~s6?d+hwx^qM@~g1Qbddb;&@d=TJ+fKFF+53@3*F3ibp z*kRT-3}ZuaWHkMo5x13?u16&!)j#M>fB6Dgd4Nk(iS{+f&HeS;hZeJV4G7rE9Dg-T zBvkdAPIvE396jX_;TL!c@k@G5{!ihp5Fam0Gwp&! zT~9iSq}iuQE|z>()A+vT)=%$KDW!fy^-!p4E?9%I;kDA^5I<-2NBuA!z zrY;y=X4evV<7Nao{qB#k1^KT<-}Rv9%nBJT=5-+LWTg1K3!!q5Ht*WJly)b&LbG5_ zl1=E%Igy|}<0~y&3y5)woNjALaBq`}*Tj~Zft~z~zoBveYSrRhnTYFgrb*8p|N7-T z_styxvu1MFrhTb-`_jPV6KwhTEd3m`V^_JKzzbp7}X8|_8!FQDB!hCB}sADNJizkbneofV4P2Jt= z@zu{ROvjaY5Eo|uVmw}0SHcdUv99YJVI5Vm_$OnP9gdna1x@&CQ`@RznyD#txuh9& z7BkeNV>F*&ZGK8q#^ch;8KKs=kY5uqrlFATy3nuc@sm?n?kL#ug_Zy-%LrAl9V=ew zT1q2~K4N)Sr7t?kZdg}TKVpXMo@-zzAA^Ohlq*!_;{c872WtxeNT}Ym){p?W z(Lb^K$aCThPRQpsF$T&wL6LevUOBo;U3m1YY{Jt_?0#-EF8w5AdkewAZ{oUidiEAY zLUfu?TapeJ4FX_Cj9yz!1Txj+tz21UC2f8X*HdXAv42I)GjcJ1*siLtV9nA?w-QhH z8A(Zfj~Vuao#lsZ<0Hr8-8@(q!~-@e~z|4x5$t^u7d6_CU-o1 zMh6pkz6q_@_;i+&6s!I2Hqf;F6GQ*EzXn|V*4PcVbUS_vbzf%+-Cifqx zChRgri@5L2P%`KjvnKLT(z1WhJNamOJ8pc>d)e80(2X8^>euf4PYSzfgGa{-tZWrr z=qh!M@Yi)x&cZ&s`W~^28NYQS)(rGF&Sl*1T9iEvnJpV!x=OouaSYq$#AiFblUP9+ z;d^Cb+eVf2~T{Rpmsh`r~zzZ5^o@TiiD$!GYQ+JcO(zKg|$P;jOnLcxsNQiSaC zY%gl}%ZP9O)bwC>__nA3(Q}y^zF!j7_xSz_^Jpu9O3u(W#;0?L^O)UG{PjR7DSSM^ z3Q{WH4->wVxQIvKtCm*@^E9jZZpK;l`$7rNIZ(^DY7Wt!>`({WWopHbJBXa1sp7wzXy7!Gdzdr96q`Qf;* z=&~_@6N09VO;!cv=&nz?Bo~P+(Z!ApL*L+~R5eg4w}_&+WFf?FH%I4X(>2{0jP>AK zJtT|q6$>$Goo|lrDUf4)$&eoEClM$(SGMY&-to|VWv=jZCHX_Uokqi=nlcPA&egS>e=Y>}Hw z`6}Kaxc;7AtHTTx;!$?Y;*QGTQ1-cHjVAT81}KGl`%)UGGb-@^ZKk{G%dEp!JD8s} zXJo1M6v#5lx$m3DoYxoVvk`N)U0dBut%BYi)R&&BnNSz{Oth7bd+)0b9-i7(^%MhG>-}~;uQ5FDOHjypqB3l_E?0k z+7K2>^s_(jE>UZUlZ5O|cq7KMVwM+tF9Q#!=bUP2J%`_0F4*%Q9NTxO(6_N%e9q0O ze-+ue{sfJ~XU!I$aYkBnvhE-$GuHY=qrSIG z%M!J>ZcoRlMH$dSnLZ?s(x>E$Rv_Dpvu&NJ_GK=ixYt?xB%Xx;w@(9zQzQPI+Pa^C zbZb9j+k8+jE8SYmPf&?l1y^U<PvpZ!Re@Jg&()vtWL8^O4OnnfS=89i^zyU5_-ak&nIn z+}0zM$MTgih8SfJW}Jugk2MZsp6f%s6xbL;Od^k*2Q44urBJ${xTY=QaXZ0g3~H`Sk96A;@g*OBFj?5ef={ev|m8#x|N z+Z~T)ee+Zn50y9U4dbJv{``ttvdYBpaXqU5HJjV;hdmUutQ_gV)TZ{$YZ(cG zZmha4WWh||PZmcjTm}>dE3MBPcNZL%wge#EZx?(nB|Tg}>DERb513H3@E=k(s7vGh z<|xKG`3aOc?wxfopPhX*FMrxrjFC_)Jho(474BZ_;B$KI^3eA(xBeqBt8*WWnBfcL8Vz1fqPXS^4(T)Y{f-E@&3{_7hvv4Kh_ z^?krceM1gvbH_mZiSpxXDgu7w7b0*vB;e|PR3e>}n$o|JLc_s-yW`U2*Ir*-<_N$W zs)7jgq`;b}^-!_(o9Qd!{BW7Ezjy?q+<&=DkQpyS-~)SZ1YBSRwCsT>Q)Y}2@N?Y@ zh=&^`oik03d1-(CxN^dOc`VZ)ZD_YiknAMq0%P@U@Ky4Tt1SxAO}$+#!o{-^eev*ZPC5sO{GB zJ2bi+6nNo6$0V4!!*CDR*Db86jk)@BHW}5g1@uHF4}WI9MxOo_s7)I39NB7HDLo9} zPvp_ti)Cs=5B-pKe&g4Eea865a-htdoJTdOrAJ0}nD|k-Q8sEKZA(pl$j5;^kohZtqjqeY0Y8#?O z5Z2hG?z3_EXAdbg9mXhG1tZ&*1#ECixA-ZAY zhQZ*L%jjzU%zp=>=27{}7DfOD^sg!j%8bX!8J}6nidETbqtrEps2ESMMsLKM=W*~^ z7YqzSN=D5XwA?ozO06A2J=WzC2P*fXcIFznhy;7^9>z@OdJi(XpN(=MVrhQvhkVk% z>|hrVI}3Y=qChObE>%MaKDxo@JgJ`Mg^o+focoP~Z;@3{*+JnJ#7AsSn@xgc=JA_> zy_IbsEr`{A^YiyId^7$>(q5WU`=nZrg<3~GgWU_rbNj}40!z$!l9@2HR`BN2qj~{4 z6~m|vLJB?o7z;(`uiNHc!?p?fYdV$2CA-$(Pm!-2)_>Y{6*vvfVfO-#obE1HIXQsi zP_f?c;M3*5p+&|6jEixyRl6v2o$<~{;xlno$ z(sBJ6ot4#?HTRcFviJ4VzYbYG>0j|utCPeovgQ!-DQ5D^6VUbhGKNh+IQcHye2qb6 znQ{24!@@;T{nDRr{5$FV`;BAJb!N4-(W)PLw#CldKWff)7Q5&{(css*JpawQC}O6S zlHFr+Er*s4(25usoGv@H7R_MCXea+#lkb05LO`|q*P4)k#8_Ty(-xT+Kd-Mp>;+;t z)tk&E{J-qz%isA5xEtz`YeRj%<41q#mmo;ay?@$%Hz?iVxA&ic>u`sWWXw-SIU;tP9D zNKnYcp80fflbA8gqvG)zbhs)UdM`*(bY{Y$Ii9V{=~ud(qGA2mMzDG@asPCOzf4LI zh<-b=kvVUy`YY#1pGYWiI1W&t1kJV}xEA0?k^Kp?Fs|jH{B&hQf(EY=`r35I^G~*# zZz$-!{U1S%UE@*2X&{rffp5@-S2!873bw#ksuu{t83kBXMo}n~@K#{@CRcQBoAX|L zwRaiFeObJwiXuhX(H!Nvd&&UkAJg$0?Kv7vdY@@;U~vr9UHv0PH+-1dCJ?4WnfZ0h z@q#u8a9Iv_9BsEr0?LWUcHLKN4YWRo0!HtzGpz8t1uQgFb`RJ>jRkw$Mt@B*b`EE$ zdp*MP-6L{6X?MWYuJODvL+IB~5S&q{9O#m!(l+qIbjo*|u2{jegz|A{rvUfR(~_50 z6qN7oel^v{5_r`y_klV|+TZrOFfP+8Jb6HKslLs$tQbsoN62_8W_)`N^4NzgdMhpD zubs65EwNu|)!(N*=6f$Xc4@xp70XAFc_Q$*z?pt0wi^rhYup~g*6Xex-?m`7^A2K06yHMKJmY+yoURTS+!pA5v0ar~ zKh7&yl*|gyaQZta^g#L#DYG5|1j=P`Ai_x^f{~A9);{jzLWZ_Dt0q<@$2oMrlBWlK zRe~cg8|w>81ZV#x(E)10(^Y8Dr}SqJI&jsXPZ{Vf!_tEnu8nS41pe^(=&DiKjjjf_zq0cC_+)_@?K) zi0FTuFCPIIda*rmc&6?hGWm#>i%Y0t>*t(rub&k8DMWPKXu3-tR+8-alu!`i5#{7} zUm31NcjwR0kyXZ z|5OuAq{k_u`Nh^!__Fq53fQ>r_!Xx3_d2!)%KSBChR{4DxEuC%pJ+g}j=XKWuoh@K< z8Ql7yf}h8YxzO>o)R9nu)I305`>J>KEFX}b{z|HKeLd$;i4GC4D3HhGO&R~3*%W|s zM{eU;Q#DLqPdg9`*yHTV(T83geCm-tOdyWkg!bdYL>}|UV7A}?wb&xC3hZ+&#cHmv zpA`7@FtX3;-RcS;7S-8=1rkf1W*Y1i4#{u5OvZ3)EtiZ6-#jFu&*hUst-QH@lUw>y zj=9YCC0bBraX8t#SDCWPMREeLr`ha{E|0`hYRD2@NJ4zL{2b-Tq69-)%%5aTn2E1< zwsCrFlCUHRVP7MOeM_O|Pi9!yT(&m?NX|bUZ%AR#?+pAI0_Dvq{}&O82nEE8EZuue ztDI6Bx_>+pVv%U>?%JB)(#I#LWDM&QY;)%sK@x-sy9UUfwAjiN6}9S_HrOH9~cB@K2ZcW*liee|D(;x?8G}jXu#x zQQl}jK@^l(UI-o)3k3}^D8Ln+8JF>&*-wSp38}yKpRWf4t;>>-c{5wSN}he(wm7Gq?VK^dU?g@px>r%{~j5 z+q)t!*Dwg|FPo7u#E;j=ZD2Eh^>mkiM51r<@eP&vJ^v87X_>(~ML9E|3!O_06%K#c zMsNIeBpPbH{tV;4e*h!0*F724Cw&J7hXc$IgZr@Evb!v#d^Z~!H(%O&zc*<);R2ml z0T{(qZFw>H>>U5|`LCR8`kj#vM+PgniQbG1;(N}}NNr9B5c{KL=>?0f2^o<*^(2>x zfn5;V8x5A&=V!g8MR^NbPC4V%IsQeotn0|qUd0u5?fj4ZT(dwfHNZcLi4 zl7v`n9W#nt8}Wp+m_jUEfK@aoTHc?mf0E{8L?qxFM$7zoo2pi>MD1X%u|+fIS21_I z$q32?iT@rpFhLL~7yaeuA?%4}lb{#Rxf1pSgKe(u0gJGtC+v$s00HvtB6&?qg~-?9Iae#}T`?322^1 zs?AYb^FF*d-ZyPBob0Yfspcf7v-(l^V25{}hcqn2yi`W`C2ExDOnEGdPE_}zMF^7W zvNVJxtzu>Qa;`&0c$$+{_sZ0%wX)>sTr+#nwXaEMR7{~~OsGqdM1=t`!mmPF2U05O zufp4}7>h{_fwJ1ELu9pnkd*Lv4worKNcx*d20J~JyM;Qs?6CBMIaVIe^&yXM-dXS< z;m)$*@x^qf2bDYVj(hub6$9^scnuwFZzHbl?W{(cJb*j_KOnwdkh)GcEJ|>*(;NUo`$7d2bz6<<_+itDu5_AX~bmrP*{Np>#+KNO!}gQ@Xoh6Cw!G zNVn9cLAtxUyS^Je=Q+=F-sg?+zQ6yzF&t;W(S7fA&o$$k*P3h2!ll*39iudPIH^Jl z6%Uwee#|{SB;En8Y4mLiX7d2eDY`XMB80vO%$`2{hZH_M{5Oa}0a$es&>r<{wzg8l z{9W=kZceG)Cz;IK@6Syti%PmOq5Igxx_sW=;{HrifO+_f@(O-&j&JF7KT*||Q_u;C9edV9vj<6J zqDK%|>0hW<1)gYPc}qJp0MnU)J78I%OhC^B&jz&+2ZO3`x0c*j^(>>Z!05$Y>?iT) z+`KP^;z!s+VSG(11WU5?DY#)P6if3?Hziw55LrcF>P>y@Uy}Z3j_1$ZJT?a?<0WM> zuo6DCApo0oQPZBO_<;#>N#nhrCFP$m-uxh)4#@^azYOxiBY*QkHo>r+p#IYP<3fee z0SvSK7XGM@ItTyRmzhJFDhq5Prw4H2S@E;XSwgp)H;fjeCifWZ1}{eNNp??1jI_Ce zv|~v05MEb!O5q#b>DH~L zJ9XwC7Cv6A4-sDp*iN%;4Kc4g)&Na?D6O=}{LnWziPjih&=J60>fBIO&wi3JI} z3kQsey1sw?WdQQNh1`OX*DgolAYu7uU(@~OjrlO+cN;~blZBiaYY zP|=)TA*ivt`qw{g^vIUo*QmC_`#eY3I(qW7m~!9m1J3-6Iv-)+jJwN1Bk8mE;Ky*? z)h}(HakFo?RitEFt5&lJO;?8pz%}h_R$_7612+qnx<$x{|5TCxB=VNocx8b@tp95- z|3AQlz{pUQYkj3EJ+n`VKuMPJfBiseZJ?Xq=YJrce~ZHa7V|9S|} zzIvC~DvNTJ;hFHb?^BNIwx>&%?%vV_PqkY^+&J9d@$XK1?j&#;J*ct8|K!;cuOl@7 zv9$d?I*xBBGv!ipAMZf<8n)yIXXMB=x)5#;43~I@=7d}CfpQ_4<8@ZTA!U4_^ABIZ z*)rO_+ZZYh7fgj|^&u%pUv=Nar>+m{bHMsgTfafx=@$HVoSgHY@Jed-Irc>Mv+qVD zKRK&^yw8i8Ju-^?Quy@Rpj>Fdc5qCUl7ri1xaGW(7s$EZJ88Q?(WalV+146d?wm#-E(}P_?AiK! zL?Ie(*sQN^WVrRwG=|y5SEj+dk&baKj6Ujc$W4bhOW$WtHWw6R72uT35UA1Ycu%{l zgfRh-ka=r?Kaq5{KjxcWyDDKqfGe1Ax-YjE)@B&i#1Wf7SI6EOcMb@}(SQ7;fu)s0HdYk(kRZL6Tj zk;Bq4VFHmv6zKbwm9vxSRXF_;Kayc${{go%roVJFRzaLbyWLVryv2`FCkrE6eMYOb z@cX#)PN_G%PTyZU_<$Tg6Po&-D2!S}A+Pj~|TPbj+Po7JB_4OXY`$ z1$Hmqt`)x0+BGUGkjHszqvzwTTJtgGIj4<*Crpr&*bPooo$>Zo?oUTYBF^=1l1Zvs zT1x1Iue1}(&#IBOPsZJh4t`m*?dXlcRrY+(N;j9ilvYA`V!3qYIb<9Q_<);04%U6! zS2$gWg?~Ek7ZdRR21Uhu1V+%RHxK9j!6Lm|iuorNsnCMJODDrh)|@ab=n6M9IQ8@~ z{$C3P%18eyEB=$Ew+AM|A7i&u3=2A!Bk}EcAgqP9Yy2}0RJDA;`W?{xSsP}0oN#VX zkaaAxAe~#I35b5F#x~?n;7Ky|uV)K){j*M=*v;L{r|t0(zPzETjoFQ>khPzPLCM;# zXV>1VKYNZR`05R2#Dj;Aw8@Ck5GH);MbfqRzAImN?5)vshiwN)id94(xul*MSJ&8( zU8!%Wu8madNxwqydW9gLKLkyeKXymH<)s`>>OWBXvXH>zKoHlP;{OoEqs`C+<=OBD zPobIyEV$JS8#WAu3HLWQ4M-)92hOzE{EvFwyTkFTZMJSx34&OY;3_{_|M6Vl(ZHWX zV#p%_wABbs&sX^fOq8#Y5<&O>yZL1vE9JT^SPu(!tBCGYPHoZ9+POfZuv_)0mIUR- zLT4|(@0?dN&WCTZ{{V8+X*Ks9q1@)1ow;fZA*LtfU%ukQfA$rfy@*WfuH)eC$tt{A zwJ$GTxG9(QuZ#yHoX!MGr8aS(Id_k{Z~h>5bRmEZ=o;K$p#NUnuayV%z~BGI=--0~ zul+WHQ{Y1f^bH!k0NU4aFV3d8x2y2KUl8~&3A4ccZ$0G zl>d3HOrZPUn(=|Z_rU-@r2lK_@B#^ftME?$vq!d2(uy2BOA!WbWgeT3Z9gBxWC6CaKGGNfArMvU(<*4 zew{R4J&0Yv%I=@as2{s!PwzxXU@yf|EiInr7p59W(w4k%bA@`%E&Y#=qv2rPzq|Hx zKz7YRuOMFE?{^+pmn32^S4$#jdggfm#hBQJ{LL!q`|JtX+Zl>~S~tpnBJn!clP4D* zMj??)O2NRVjmN5&9QN%xN3l7L$?q&cf36)(80H7;mOz3Ku;yZt4rn^Cr^I8tmsn(xhan+%#GQxZu&cay)0??Hts?^+>auqShwH|x;a$#ayKt) z5^lUP&qWXIT`|a@-eF$|AJxd2HD*>tnf1lVBPrgRS8Sbg>ho~+87|JHP`-5vD!5|M z21d=XY*fOHa9O}LU!uYx^Q}fbxkRjiLrp4P@Dk! zG)+C#3>Q0PE?7^f1~z)fs9u@+t|!sO@#vKQ7DT+fI&&4y;eN|EdX?G~@u4Ya4npH_ zqn{>qW_=~h&n`Eq*bqyZYd;>+NKmN@9eUxoR~-;+EwkTv_g)c9t~Ga{yp*P+_8@r; ze}(TU+f#jr4jVt#Rgb7K))h*Ndvfs{0&DHL*7~-$)s#CXCH-{!73%0ovMl1?6ov%K zf_}f1+FkpI-H~_xFe^J0G^v{CD6Bl5ft(TSmsGRJSPXY9jiqr*OYL{oVs-Ri&TJ5E z7Sub`QCuuR=9E;V_RRA6DS

7Odl_H0-HAh72t)Wa34U*!TF_@#95Vc573Z8v&|? zZ0ZFwMmgbkBznH={p%I(TX~8ouTam4JY!^)=tE_VPI4>%9<`W zF~es_Q?Lea11^#I(fS_kLb8AL>A_YyqQCtegJ!+JOD-#DY&P!x zgh8`$jbDJIyx8QWMr|KNh9S7VUB`q^E8B=$OrvCvObNqZ@iZgonX$iOCN?L7Po5~n zDPQ!mvfL{C=W+G276cD?k?WR^Y0T7r^V2V+pCI15LD8OItfJ5-zMVreRtLkp1HwF= z3jSVw8KZMToVs4wG7s)7sQV|n!+$!6zpC8!vYORJqhaq7ulYV$*zhpk`T{1;<9dkO zuh{^o&)q_c2a!~(c*L`|E@Hgn(-{v6oxlg5H5zZUhV$6J>L6l$9+%g6QlbkU;Rf-# zM5cSBHL||Qqkc6UZD<+dd|mV{!BgnL?mcZe$6C6t>JQ4e(X zh(Kyy-ZQUK_i#hgi}9wf`uY5O4i%E=e()gy?uVJ+{IOv8k(Gtp51{m*ZE@}RRO%QWZ zi9R8pev(cs-8%*jNtw9pN0^n-@xIpGT!Ho3v@bMDLb32EpEf%kiTMyTh9Zp-lJ2fH zPOE>4cz)`x5qyWWR>arb)ZsB-M(mfj96KBZ&PuA*4RoPpCnA%BHIQ0QNAc+FQG(Q* zL+DmuakfZ2&fa;cEvs#2VZuiu#@dbH<9ndft5?^p=J}JhQCEdxCWNYmS(IZ8r1V#B z{ua#_;3Iy8R_g?T&uVuc)HCXV?VfMkN-z>kJ$_)QW2rj|4SjuSisY(XLsxM{v^w`z zU;Eu7rLW1cewcj@3F9@{kTKzf!4t|vzW0WC#dw{`*9=x$G07eclC&?nqbh|T(O$HV zqG7o-j)RhFpKGjG5P~XDBqp7=dCoeszkFz>D5bFfppz*HIRO8(AGhd(f2;e0OG~V| zBif-VJ2f%z=k*b_+)ScGraZ~PQ1L~t^BnWE&BN6_B2zrJHMB6;Y({2f?2iy67u0s^ z)i)(3^o{Ix?F$LXu9kwx24D3JP>;^S0u?o_=d9zHd(lT9UhQhXLuY%6an&(fQf-4pB2*MWvB4bOmsS_dTN|dmo)deX z>fteIo9t72MtHCX{a&1{ZfbkTmbk`PQtn#$RAV&+Nf_oiQ-#k&lR%@*a9uZHgGy5u zW_Y=bk+%r87BkAwd+O4as)(#D(~gc4t!udV`L67h1uX)_=vni1iI7udiiqG5lzaxnWh%v(>Z(O*hvKZb^70I!{G;ut#H8<>9g{=wF<>UOxbHiYg zt376^?{i=Of;W(R*(d>cG3G7D2I)PZA#qCO^*~%Z;eOwsWJ!^V-FLrT{r_$+C!E*3#x9Q3Jeu#tes4CuFUw3plsEd z;!|v>ep=%NI$>gU&?5=yWpUx> zaz=Et7k0?waY5+8ZNcuuuhLAy%u67-UX6VQZ&zau4t%Wi2pSvKD_MNv+&gA&Xi7WO zd)b4;Y^=CfHbF+A@@QUGfoG0Xx5Q*9P+0d-Jv_pg#~YBwCL1BxZWjquwxk79xi6xI#Gb>iJYJ)_>61$Y04r^J&NS zn2mvC&fu&q!DgF9`#m-)hZod*+frqR0>Ox^N)OgGMK!$MhsgisHzk6BXBVe%D>U%{Y5zlxu;koZ|ULgF=n~qKER0I$tgqqjHJ!nX%iaZe-Oh=~=2e;7OcBo_V{q z2}<)J-=?V$W85n?c~+4+uZ?lF?JkKf`@8@r`YK%Wv#>7v(gl-lGoamo!hLCx+&FQuV0OPXtF9mg zSDqJCJp9|jk)ydN{#v=itl4Eu=9GOBH@a9$RIwR~g-}n7$!g!sdJn>M)eZ7n(|op6 zwbF;&`@6j4`hhXldjm8v#=F`fL`qUkqp4ht?prtw`KgRgRFS8}R-uvz5`B0&p8|0( znW(qFX)jdwso}L*5;Gy0eE=CWgInKx&dsyhxCFVI!h{?PHE%YW>Gs~}(cik=VvcZ8 zC#c*-JDT#f@orp`$Bk|AEax(`)X$KcUO&f zOFamF5FK`N<{{@TAC_X)l!=-dKaQmDJRx&BH^tx|UFUdVDRfIw??Xg{HxsLtbHJG6 zHOGL~&sXZ<-=|r7n}w{YYsH-Y=XAmb-Z(FOS9Z5f#=8_1JB-Xw>&$pWzIsOG5z|>Q}+1(gEF7}smwd7eNnPfm*iwRcK}4^<$Q-mrN8D$L$L%K6>m;z z$@RnT%K}n34gdo;`o5C$A1n0He=#n{3HbU0wftnjSQ75h7U6+fKo5e%sFFpzEP3{b z7Qed$?DN#P8|;W?Y!GBQJE@cf5>%z1a|~)8frf$s^WbPC*WF*=zys-StUva@0NzJ` ztSJxXsP!zcq0_R&;y*!k0$Ow(;OpGi@jTWQR1!NdaoloF$*ACkN?v-#&wcA}*=izha6w-6LQL(%xMBQ`~ z|K9U`qR@we{Q$m&jcavtPrDok&FRz@pN#sh^oaCU-ed(L7;8yz4)0ZuVB>AyF+X}J zQiRtZ*r}8nLAksJgc8XHRC8^5nZE-_vb0|Sa*CAsWA}%s9;-sdNqJytM9nJ%K?)A5 zQG6Ax8xp9Y?ft1}VZ07$fXZ$NRn@^UQ`b2kdCuv41YG(LJ!ELr8{JQ$KU|&sLu3`b zW4G9${d-ayh&IhHk5&kpWyakVLmMy8ar+HB^1WSY1lC{B@5z?IOn|@|Hm8&FH?QBu z7^q8fZ!Cb@rya97W!dgW0i}j%%gDVDcK$Zy77YHXWX{!8Ir#(#BrJq3cg6z1Gex}_ zzilX-C!0Rd2P^8?$w+)+Shxx3lnQ&)A(;!n|wT_%_EbI0tV!k z8P0Bb*WzF|hO4`!rN`6QR`Y@)?B2|Q7;2kEZbU-8RzFB(`xK~(!*r}cO-#?!^@qg zuJSISV5P9u>BUMFqW@qev;VM?mU3vpXHyQ5)0g8n?M3Hw@R9Bm^qeLk3~|>bkEtuC z^mq)yQ`XLvydvVXAkh$xnhVY z8WdmR1j0cGpuvGJl<;T<#zXU7>6ijtnvvIJvu-FFVqdi;7IO?^Gt9ttm(&VS|K<*{ zDJ&)b^oI+oe3s@h2ly@_l{+U8#I5@b5O%K9nTAy~d!&{hJrqxf8y?TA*pW(4hQP_; zauEq0eWv(PpH+WZM#E8yhyJvT%*7*yDf4w;10zil@73 zlz+k}c1z`hF$54$t7?dX(7c1M?@lX$e)m9U?Eg~WbpX(SLG)z2U-%6F()yBpQ00Zv zR;#&7^N>&36J}y|3{@wbrJdonr>=SWANE{ej*@QKKe1m$8Qy#31Y_P-uIMYVrc`(~ z{5t?Ty5_AoU2LL^I$eITsc+~wyP1Oq-A>}i-foBSZ=JsDb#`O&9j{i(9U#!ev@~!f zx;0}#Bb>cD*gR3Yoso+J?!1M7OPjYi=1lPJFCr5KjYz5tE;KETU1NHkvor^Zm#pm` z6^uUqJ2U{X{{tFW(cM)qa4m6aG``IWLEOOMx=IJb0MF4`sY+R4oA0wTbvIYZ2ku}f&n*sm{U{j8ii4F?le931U=|=n{TE`0a0$3r+ zOXH`!j)Wj@<&u&_x&DkV)n{5K_=bJOKW3AOy2`x!Pql78%0I7|4kb2lmM$*H)m6R~ zb9{C?_6`qFUC|_WdG8SG7RiuNg#Y{x`SY=rWyf!z1@*n4l)ORJ;xSgv^|N5>OtOV(K5 zc0&Nz-)>ogd`8G4kZciM)#@^2;ElGrM46OB*|wNdNrcNv_@iFgY{P%dGA1v>YSpG0 zU3Nv)`LLkE)AEP@Qi%ihLOj=O%ARa|ggAZTp6Oe{Oj zIjCAI2whY|BrztDnFYh7LZxQ`|0%HQ6GIv|MiLnnRiFb}g>Ol|ZoIqc&j}~uj}rYB z#ps@IHy;I@ydeLjZ@^v|vXi5%u{(Xho-;{t1tRur{nZ=v0gpSk#A95ZPF%MH| z$JB>iX8f_+ z9?&N6E~XrIdK?P?fW#GiK#BP!cy&VCc}?6A=5l!)cKH@V5a;5|e67+xW&}D8X1C3j z3i7iy#W%97NfL_7W=~=2Zr>-3zz2Im%s4CzYFC{-k*J2F@wNu{2%iIHO>`5L%Zz0mTMZNj6d z7Q3zm7&^yU?{wrRj2=t;PNlB6O8cN?v4hd5z&!o8sNDZF&3jRipQ5a-)vuAvnS7Y7 zbMwD>N&+^Fq*R_fUppj-0N82u35~&Q#rbyFYsCV*8sJ9uTN*bjvTL$4msZuBQR`Ch zXaeF7+zG-2+0Ck}DmP%ytQHPY@{f+&(+7gzmkaQ}q@Pt_c`o|(vPO@F#f*H1+asTo-69zOiW*(Z=uLb``=D-}UE!D8^#kTA^U2^Q-MDnsmfP3IkX3x7 z?a7Y9$WkZ|8lm~c75C`%P2x*$19-gfDK~%pSrPk!+q@G|CW=h?+EKGb7;I0mmWt`_ zm#?vBGpp=lJ#TwGFy8Apec+G^_$#>DaCeSvG6Q;Bk-S3Cfl#%5V8b1g_cJU@KU@#g z)`${`g@t;eJ3g1rMxnSbW=N>I+v*qX-;XfB9nc8uE|;i5QnhZEAng%AB_FUNOP?Av z0zKM+_If{nZkEB>GHHtkOYKlXUT0yF#11q(TIs*`LFLbVh-Ukzeb{4iE;_iT{pfd* zmVCwt&>pUuU+Em=3+H8|_6@x?u$gR`4Ej|*y*MWFR#Jr z#m~xfXkS$1b38xSy!}{#4;&OQ!&HaA)dKsmu;DQ47&2O%N}q?hSWDw0QHW=ubSL+j zfXg1o(>Qufr)m3b#=)CG(^*=|5F&OYDb?q#?Br^=jmpl*(9kb0vM(Ga4i)xdw5GMU z64cP+pse`a_gRh7*8Fj`={kQi;@cqhCBE)O{i)AVTiOXMuJo(YDn5^+Q>!v zwKRy5A?6_0wZb*$>%x%T_p!gR_7EWYnhz0xBz9h{AZ>4}1~>ewvy}Zl#@(#a0gTeV zNDAMyVMoJMF3YBOD(P3<9w>Rjav8~x(Y2R8HkaMQBwIpHqVBC;9L?KFFXok$Mq6)> z_85Pso*k?`2owIz{X`BF ztrYlgGo{Ym^B79{PPdB~Pr^SolMhdXcFelL7CQ(v>uLyn!yUq+2h$CeNsFhZkm#mR z7A%auXK$}D%yf{~(!CA#yYwFSMC-EzJ zR$B@I>uXZt8h0~cZ5R7D30%hX?$lF$E}C2A*nYwBbVv1+J_QiVfZ7D+vvZ~H`9(;V zBkdA2Y`-9GWvBKGY90EKw1IvAo%D%hyIA8*Mqp=8PWy|zfEES(uz&~W3w%uhEKRzPQ((q8&>k9S#Mdp3y zDPEC4E?EM|dCPM5pivcb9t)-1Iotpq`{yerr;B`*Nxhf4EAJSf)pn2KiBico$`}?O zT!thC7KxY5OUC8vJXc#4G|o3}AbFAqLQk)wq~iTPQkRXbWG<~m{yJgEJNBE_^9Z1o zJ;d@Hm}_9QIZSmwv<0|fgMfr21oRdR)FZ*fqUBS87CUNNw=DX%y$GhnRHkiEfktu? zLv#|#{wZp5N?@1T2~C7{6)yo`Y=z^cGk&A;PS=c)u2>6qHfk^hCUD{$oAc_Jh|#ll z;c9IT&U&@(uoDXw;N-urOr*)OUXJ3Qu?Q%-&dAqj*g;=vKpFgf=Y&s@($i~-TSc=7 zv(>&5`N3`#nPnR%p+;u5AN?0srmSHEN*{f03@Ts38N%bWzgzC)20)+6g55O76AA{T zB{jNs0H9XuapR!^?I~L7e=-i}k%<-E;8mn6WJymm@fW7vUf1Rz0~pO3JS~dV?Uwxn zK(#W-pAji~=Mh$V#hEdytkcF;K4}f~$HX|lHJfO2Go-bjZEoy7ES|g~Q^p7~UU@h5 z4XPF3@D=s$!|tS^)+q^UpX;g13H|=_eUMaN^l?yVu)vwNeD%3OH||V5WzoZSC|@+( zfuW(L!rpRdhkRU0Q2-Ff;a~5Gi5HoH_o)0vl0eA2#7aS4oP6VW-P5lCfBDYQk>e2@ z+yXF09cc8c155)Mud*hstE&?wlcpt9Q7=B%vTk1VZ9@qr!W5xRRZ%ZFXQ~Rwe|T?T zR)aEs*;gD-V^&_J4f|N`=!WN#|MlDG0J4oTk12!eQxO*nw^MWXu6R>RcGpVq=%X-= zHMN3ZAnz(&PtcM`j`=o_lJ|5W$~lUo(F|H;~jC5pt&R$rBkRJDm-=#>RKER0-or}IzghMTbh(p5|X zjr&B`8~4mKb;%Y-z3jPYnw@X)M}=6{)Yft<6{-jNB4O#AZcikDb)S5dNs}hA);ajG z=+*W#7=+0{YoT&o%%OA8x){`k70irjIXX{50#s?J3}ybJC<2C~0WTKE)kxl_R+Sv- z{X}XpRR>c7KPq!9T0+}s)&T9yXA)N6d7b?kdB^_>Pq$aK7ZF(^6wyHP^~ZYdGZv5S zrkRX1A~C{?D)V_{V!LG2wqUF-yq8I*ln^gYpi2%TE=YQNnlj1tzMFB&n$_3@XE4u& zgQA8`-*^M3fbk>LHmkH3iI$mT(($3wE-nBVxwDYy_p$@S#)%<wX4S}5*u2IeSuV3 zoll$Nb9y}ejJ3PC-sXZ3%GvfecS0BK{);IceD|&~PHpj8sc@!mCIp|?gex|$@>2p% zrg^(IRG^9f74%MFS;RHj{%A?6v)p(`%0fB3Ck6m& zM$15~%FLE8`v?R5;W73zC@4!h9*M!OC9hd1do-k0eTLC&j^WQPj&WG!Sq`Md6Hnjb zzZsZQP5;!>^%-9 z3Kdy~B9YGmL+(?iX8vp|UU)mwgsD5lBHT^PB9 zO0%q32UM$zOU7dtqQdSGYBU%bV^fmpP@Vfxw!wg*&iyFP!+*uI2l1Of`OCb~O{&X@ z$Uz3oE8Zu4t8?ZXAF5J&5RDfuf|bwfO_XDbhTse>(_R8q9dd;ed81ZiJ-rQ2&mC0w zfbYN~i$8nWttd|Z4_W>7gL;<2GhPq>KO_c#%L;lJFeVF_2ZaJeuMmhwu^dLSn?vAp z#&2Y7Rhcp4^HFj--V)-0DD+h1Y(FyPtd^un92w9vS28Yc{M*qti{*DF&3?o`jI z+UQ7i%LW2fZ@ir=ylpnIa$>YHR9UTuF>>O6{h&-0`#i0-tu35&jkm4X1k#+$Zdr8# zUR+!Rt_67l!NbPXxXN;tg6l_S)W6E-USJQR#XYs#tN(IvOq6=kw7eU?&bLV1^o~ud zqt@7UJ#W7)x7t@obd7DcHBJoz5UB~u*qzveNOfOI_xe^u zgk^_Ij%5HSGGp62TPjT9oASPJP$kU&HW-bqC;c;8eO8_p5Ypeob}vXkqEySg;oc84?V-P#Vwnf(A6H`NZca`sc) zGTf%dp+haH*wp*9Aa(gLRi|3~8(Ca-jGT^>I2HL{Kjg)T6B}CC9O(8s$$;}`q5JCt zr0Uf+k3WC@e2-!bdJ;U#caw#mCMclesmqC*S&31VG?8buK8cS($J3HC-diyG(`w(Y z`t$^oQ{McdI5X3L3|)Ei8rd{}h<}+O37PDJ^eo;eB&~NKVCXuV990m~Jlmott(E2A zK|A(*#w85%MqM;+=qkB=-(@AtXviG>&bTx9C)A2#m152W$LF9wlk9Alf3+bvu3shc zb?3R2#m8m2Z&fZho#IzC6qcdV$GC{k{FS|-HSYS1WBl**08#Uyn%m@RX&O|WpW3cF zNg_8ejNp<|be$zW>qR_qj^HFFCo`5tNDhvxaYPF}5FX4hYb+lLgf37np7F91T%@9(mt2Atmwh1*+@Ct3F<0JX1GzM=qmZm);x;ZNh2 zi69@dQ{HrT-|ZRt%jyW)yvLPDJr!@f^j z$lhPIT}K$w-y6Tr9U&8_l`f4VzAWCj{TlSi<81K!Cf>jbrGqC>wW3TY#X#udw;yzl zDum95|KMbn(s7YW%YB{xUO=r5L+RP%N5h`x=V>jAMJ-eiSl#rh&s)4O99=+1slJHc zG#HJV#~9h9#PB&iU5~d%;B@x-t-uumHb4{x!pFZx;op%3h&2Y%TEDfk1Nw^OFWvi7 z83BO-M|cP&%YR_J<9m!J4keP4h@(tK%8?ZT3P zPx5`{2{|Og{xG2WwOk#}J(BTg&mupacGH~()JY#uC$xc~@|}Z}lxH7urC{cfJ8YCp zlq&ENZP$5YmiIN2*zA_ye#s%F$~q*ulJ+Cr+*uJToe(*Y4ITr}*!d?p0I$zPDGIQ= zgOq>0xt3gD526p6gw*fo7~ywsX^YqP=R^w@aVdP?m01u)zt>eb)j-k)E+p8`Wu%oghSTW>*bJ;M?U6`XVNwuVi}K7fbe@ea!C@5!FIYeeB$3hZ(Dv3^pt{xFNSzKx;c2~yEH1}cbP4$-MJ zUdVfbNu4*9?kd~Mq31@We4U~ekY=4qct~@IqMo#97~a9zaVNyVVPKhy>gKY+3=b$n zPXbh;J&Bt47aRcoQSD#x3e;P z&+A3K`}BP8SRMjfS}6S3nmvTw9BGlCbpzPTN8xEaOS5~i!STWFgk)Q?@w;ED$c>l3 z0u?m6J^BU^!MjN5*J@d%Mp59BC}v7wQu`m4KTSNE6Gnhp*uv5|iooqn6P=F<9hX;N z&;2dcM@~EAmp8cM+A3Za&o7y3;bvrS= zn_9{9%W8P!?N!)l86VZr-e{rZpZ|07()dWRxkHKkb2*J4@z#ZKu~7yayuaVbkFgIE zwBbxM4>a*e`Xg%#E?GBip6#R6m>_eD3w(trKG@ zkrr=s|2F&rxP-7aXT4GKRa`~~aGR)r_0j(e&IQ1x4`8Tn|KhcOIica-P6)xcV|~nq zysyZvQY>d|{5{Ks$oJLkEt%GU>JoLhg zewij6v^0Q5;O2D31WP``y#MEqTtOt+L&3WmNI;-IEHY=wf@h7J7GuqDW9r@ z*1$>sJk>7=??U?}9NI@yME1Sv-zD(dAHVjUHer3n^1-V;Ka`y^XPY2`S(7sa+W46L zi_QT_52_mXwHH1$$}@J&uX}miex*y)#Nn#muKKmQrIrY1HK;er+NvXuewwhzk{e40 zG{t^6kyX+47sPPZJ&M!f2@U*lB2@~`l+-W+`V5&}^4*ktee(p;%+J1euvb-v(2HnS z129wM9XiqxT&pOn@G~misk}tF4r$>D-tTKl(ELh+xzi=tFjy{vnf26{>7|3%9G zb4DkB)JcwQtTZ-@Gx+YC7@#re;~?z5nRI@TxI~Idsk$;3ByH$Rd;htNG$C+vLj)7z zCH}Y=r|NgG4x0> z^xgV1;69OsO6r)cPcz4{a9!{>wz<^mqVGDC6pQ(niXPZ{9 zxIgPYL@DDxd#4wZ&WjkyDS$80#w`SV|1j8+g)%3ya+O7$8T`bj`HO~hPI;qtu4UQo zTqMG#3bOqsPhh*qwt$&Q3_lJ;kW8?~4XNNO8W-nd%E`G}>xV{V{JPTyB`fmi!99qJ zD%7i;&{vIs6|BMof$V>6F$PNl>%-=mMu*{gytyi+c$dhD{T?QRLw3lCAY|agx_KtO z@g>gd%MTIxLaC-FozuXn-C-J(rwmte72c0$6`uQg+=(|e$qYngZCRLuraoTwj6Owt z390Q)s!1z=&4J{VaQ;;A3F?dEHoTiUb0YJoTh=g%KC=T4t>?oU1g&*JL3n$hZoTVwc6@~`@`jEkpV6;DkLyg4kBpJl8j+eAw>&l z>Q*W+B+5xj;g*fx24cFG7fI*waCsyk0c11mE0VC!YRAM($)hp&^6{|!d0~Cthl~U& zbpIMxxLUFv#PE1FApwhm8?sC)pV___USl_9F#xr*D-KXQufjP678BT$KsM}5U|K9d|)r)X5D%n*|b@* zQwAHZA;pNWHa7WED>q1Ld)yHLuC;^wBLxBhp^g+T$GEMJPeymFw*vg%4ysy$YM@>z6 zUlkHU03@=2Us{fzUWGT!Jtq5qAQ-tp==}w#hLKBI9Z3LHv+3agJ|AHD(O~@1Q60L{ z9scwZ(PG2zoV1qDdaPs-gB_B|FHxlg-LA?KK^ohHez~d`WsRrpB}SzUt{2XVfJIu;y`ysW^0L`#&b3=5+ENHJrI5J-ZzECzV4|UGuMumD7=P&_87Irb50MfD7 zmByLbGVy4ODOmlQ_*$3~b1g3AJd;y^y&+Da=b2q!k`?dkY2~7|tLc(n;}>2IS^25~ zQL8@RXqtb4Mlev)c9E3{%tJL02M$TofsDmCUL{pdJaC07D@LVIjvDZrg1|QL{U)pa zAY*`P7#|grKL~u5d-?@^EuIe|h-|U@F=%HiWj^`Q&Ugb{zp$qNSiu8fC)(z)l-kxx zHBlmIN9{4cHV*t36g403eR4Rut{@l0t8A$ScTrX{j*A{R8gQ8@P2ePUrK(w$7fLCj za+OLw4f|2Duo9VE(0JFb(%QB@DW@&%h_T~ey@H=U_)@?y`%N9=X(;T%xct4!7x6*p zJ~fOen8v`LOAS$SpQCFpZ!9?7cU@K%uC|DqR3+23AptOyu&l)tZ7@X8M2C+R2+Uw8 zj>hu|MO@dJ*1$Q!zYWNWP;-2%Q;2 z)TO}r9a(q)4Z}(VbwGc-f(TCjE6o607e!;D%#`_O((nFGDsf!WZ$6Y5<(}qiy~yf^ zqCMH8;Hx=SdlXHqDp@=Q-w3PZOa-egBX;^n}9j@=?{0!3Giq2_!kb;+mciNnh zcr?!np@og!J**zVeVNOOg4IJ4=_+WG42p9B9?l9NRQFCx>EcxVoO`Z^5D@Xlzc?O% z^$}i$zfb%w|6*O?Bmsr?MAP{u#{cVf=wQUJlv&f6hO%`h0jYn5c|rRpy2y1~YQqqw z|B@iU3utJl?F%2gDGM~lZ`$R{6 zW2P*oY_iNJ@}wrdfa<0>kYE@qpboATuwEXcXszq|%xF@P=GQ@rcoP*3xC*Ln zfSSVxR9Les2}H}S`NF#g@tza5|Lr1v`3mYRK%L0|joXO+i@asLm%m^@7R7H`OjPKe zbF_BfV5SUM_oI2@q5D)Z&cv$z6~HUR;`AW$!uHr>CxWm6 z;dFdPXuWb;`dM>=aE1hv7jhZr6E2+lwEp18$8z)dSo9VB4)-VjY56(r@kuhG;mN~< zMv*jHY1Q{t^;3AnHM+t3NQ;`Jc0fF{x*Qbx%~xcOLSNM{hh8X0il9{*5?T1Xh1|JU zRUtmhEzbel$SgLHIAeati`?BptLsD^yBHG<1&l3!AK+0<0A;%dp7kfm_U`Wcs|SS^}`d>}FgHjh_4 z1Ti=XM1T3ia>I5TfyW2cfyBzKouvn78pZTfc)_iVhK)Y);avpdklQjY(LobOgEszG z)a_3X>s@pfsAZ-DQCkv2ndd9|hQ|BlmRf))2e+UnjV8e~nVC_cqj6%n-?H~{!DHBV zD%fr4Crc6TB$3NPPDz(n9T4#LfwO2J!j>r1G1FQ~N2tE%V9Vfgr#|KO<0N7Qn(w2U zny&b&UXj8a$rV0C$cra|C>Q`n(dOPQ00d@nS)$zic^BXS{K-?C(OKYnH5~6F${reEN7JL6VkoTmJpf zt^6^Ku33E4&q<6HtH#%bMMj4wUU-GMD#Rc7$0kW6 zfo6$D<45ROZCS6M#=NnhA0_zd`ZmD`zl)WR52_LrxpKH)l>B&tpsdDS_2q5pQ~YjY z?X!=Dx|$N90MQSG+-|E?pK(ziu3N9|nsGXn3Y67;wNN2Ki~TfKG;nA~XT)0C@!XkpC^tFrGoldz4yd`f5je{p zu9h|biLPFgU!=EBV(=KI3`8a_hMH;-EAz|cyk{T2>x(RO@fNTAB(AH4ATyA}YZ-mp z2HMnq$&PQ>Yv%@6F=aCy=CF1YZPo5{Xw!eVB$>o1HNpC%^u?2!3esF~Gd|A3&tiWe z>oP6~<-_wu9YcL(uZX44pY2JZX>F)0CR7zn{QYsA#DG`wF?r!OBSg3S2d~@M` z_SxsW`<(Aw@Adwn*Ch*KtvTm!j(gnW9{2DqW7d2)#YdH-=Qh>&BAjRBAs$0Tq#(Gu ziV(8bO^`9X`G1#S4FY&@Y~tRGr0gJ}(cd%+d${8d84F$9C=qhh&`8(I>J0r(mD@vh zwEn>PhKE{zEM>xP26d!uoipydD`PypMMK|vCft?qJehQ2VlHo|(;0mzM`0Ka<-MNU z4OR0moaSF@W^Qia7t&cE)VC_57n&vBxxe)*1_Y^OP7;HL z<4zZs@ciaxF!|Q#Y5EezEJvjM#7p|mVGz-rJwuQq)19~02rM5Goz}( z8a<-Mlj7|}Gqn5)FB<%GqGS;79&>yHuUiPBOa~CMx|1AD1tfhk@?J&!c1h_-g#IJb z2$Yler6ryHGW(73vUyMbny~ZYsDr6D$_6c-4tkZa_ z9Sg4R%s}sev?`~O%kmKxaHNFjKz{R%J054VVJA$1vsDyOfWgBe!s(8F?Uf94TKW7o z!;vGZbAO6TCmj1Z>M;>#CEdmiv(MU$aA3|ZC zG38-aAv+3tdUhK8PXB^6wSw_Ovp`>;xD4+s+rB4)J?QzzZBOOwSBDS|#C~uG>F_Gv zA4WI|M9$k(AH>c7TPU59*@ybuOLmdKyrO>=<}iY7O}rYE%)t0T)iO|$)7;97IX?FQP(22J_N9EBTO|LVUxY?*=v~=jM!*Tw)(hU*(T@K!nyC_;6b?F))_e%wQtChm zycauRyVu28gmRX^FlfQ9j%Ohj3xJwQGLHoNa^Yw~oN)s=`cWUt5@0aY(aElpq=a-_ zjxT`k&on}yW1-H^%6n!m_O-)1Q1)O96XaPnnu?t)G~6v4-d-&Yw&4(1|GS2)YJ z%Nxb~?IJYzY*9KY8N(?xoznVtBn;!AkfoJmwijVcA%mkttISgmam@OmgNX?m?0Ldc za3yi&BK+BO<1k-oG;la}F+ujj9FMa~Om2umta!J2l)>btg5x}#FhT;yGhI92vg>me zU(c@y5rEVFVE(JL+bdVAzGXd4R&K65v(G`)lm6~o+*fkCuu@O9--5oabUH5Lt7@;02YIyhiwsVw$uupU7g*jOvDvs6zfZ{&yd{F%{tGwk^duB)hKV#}4Rd zvOnoQffI6=fNK%P4Ng-aM09-5k&!TsGb#Od+*Fkm>Za@dl@_*hg-ugmiI#=$w&A9A z&R$wHBn=^VW8;89C}vSfU5wiby7V<_r^wgcx0|;xbbxDQdFLOgR+j_(K4a(`5Il41 zL3aK(OOCP*RCYyre<$bQFhv7G&n*3YVjX~aplMJ6|Gy#%YF?0BZy-Cay(7lDux(i@ zJ-0%VVKwr8SYwTDmykrFl*1e6??A;v$Xexk(OOihE1bR6YwQ@7`t>Rc*Ozbtt7;>{ zmXjp~OS7lmf$E8J_La)Lec)5(U4`*Ud7hWa!C3r{#A zu%+X1j!8p+9NM4lIw6`GxQJoNi;8>Vm8hfyUbHmee+I}37*7G|b0lEuJ}vd3s^K4i zEU$wVn2*t#R5URtFgHlvb1NyKKC~?%#9Jy9Y z`NKr^a^>`)04ei+PTEyDLHdeprKS>ikV)vG0QTtz&*XUR$43^I-s~#djYNa86{O;q z!&>yIXm?8~OoBZMyGF1mO}N?m%y)cY28^{XmUOYabcuk4kS|8KMnn*P=y~ z@HhDVjWEOk0UlCe#o}=2D!BL&q{UNMApmr5|Cts4FV?F!8V(gu3?|w`FUc*VSIJ?< zug!ZosmFW2fGGI5-T8TbcjP$Z@v{HVhS*PF&+;$8kxb}ECAar5?AgIWjPreY@@aQ! zIk0Okx7blfJi03<@SB3J$G5NFgZoOL5W^iq!1>TT$yB``A)DuKUs{s+eDnvuiNAlc zkL5$#_f`)35f2YBo6`)rcgxcX)zHfeQhK(XReqJDO<;!Tra>s)(wSYe=-9m|Wa9p| z^nTAZCJwxho-G7w#|71YoAwX<5=%`+iun>1HkyW=?V0!z-@{*;KvV65iOnYHjW(?d zO?ywW4Nl%zhH2tK3A{vboc&@t>cCHNlza6*Gv)t_i-*-25OE(C&~Ml=6?209g%jGl zu^^!_uWo3BJ!mT;A0T4Y4lamQU59{Soo{FA41?1sj*6z7E}O)WE~r_CVm7~$I8oG+ zO2FX{2Z9N4KwIz1a9zk?$~GrDIH9-UjN`+xZ|=k>DrATiSVcYTOu#a7FjW)69iL27 z2$-L&qalSANe4M zhv-Vb2G|Ja4{)b9l?zVMBz`kWrI5mA1o?PP7HNboXO8REXj>7D?rpYoDi;sRb~ty;}NE^wqnJRYz;$#fsU=qbHdc z?q3B0eRBi8DHwYoCo!Km{6Hl|w@AIJu52g_fxO=_*6>|uk&h()ng=nY`bSBoCmtvL z;YcB08H492VSAFLhtNs#2VrjO^bHStyR(3-?KuUVT+yKpxnN%H?WJ$Z#<)Wkd!9S} zI{M1acNYeGjPyYAF&XebBp)BZ-2w~{IftIh>H!`F!PdL0xjKaz)y)1wQ$X7 zrm+2{x6#j*Qc$ECaiWMnfl@7^VYc`s{5}XsegNyZkja9gu~7_8i&}=hBD2bG7`T)E=NfPI72%DHRC+u_H0vTMD@{Luva0|dhpT8Vac$q5VLNEa z_ZYYT=uKVO6EvaAEa*KF4A418b7dgrVYQ|3SBuGXTayIwGduf6r zLz?D!0@8+7#p~yz+E#S4xSc0ftU#r?4rRAPuhiqj-L3i4oFVE(H}$iARNzY!-zFtB z>eVFYW4|Hh5hevLx-%NnHK^6y16ir|zUx>}Gx=vR<=^48CF<~7pNx%f{XcYRC$?8{ z-)`bJW^GXll2-#DZu?rw*)MePS=6Ldh=!zt_49wO-K_*Qp-n@|Wn&)Q$Nr4{W=1MK z8nQG6R!+YLBdt&6=`4ORY43Mmta1i{#6Cg4e=R^8ipULBEjQd34jo_7fo?B6rX2 zjfl3z>8cM73hOVTo*>JfJARlCYL$i+nO?C74qXXk!tYBAW(Go&*TmBpGKW^#;tNOJ~YE*jzyJ+;O zAYFA=Hps5lE1o*EQr;&+Yep`$t)K^+xe5QuR`D4}8{Vw-kStEeL<4Za;(TkRC^Oy_ z%R|gS?Me&jLj_)*ydlM_qg6X@KJun)Hy5{yzBl}hKDDBCkW|-k!9*LE^6d>uJ_E5i zs9~=}w&RtGC+FX%^&r7J3~m3HIyIltL~M5-Y8Y_op!pE@)rAdoppMks9#LaClpy;N zK=523`RjL5*MCT$>Vr$cCm3}Fh>eYiQBVz~erG--^;~;W@!cJvf zZ8J_fwx04su8_E|Dqm|j_3GXZ?pT6tRQ`L(i!!sJ#h+pE-3!Cu)|LOpLq~$uei{(s z{KcQ3%IN1Xb9~JT(|G7sr!_^3qz>9Dv!4GQ%pB5R-ZlifwvGo})x?H& zA6+OX7WeQ={m3F%f^w&OkO&-cuiPDc}Pmt(>>Vr z)Ff|NQF)ADGQr^uQ}b8WQ+3d8qVZwvMNpavQ?Me#84WLcy3ZS24Cwop^A95C<0MW9 zdX;H^0^)NJGsasrEd0&05+Ouny-lq**y$$Ip4W%`sn zlz>uEs$k<(ln=VGtPCjEub(Jj9T+S8BP2WS+O>J) zwcc?jhY3BT9qvq}pPk#T@2mNJTezmt+*;8^nUcC9RzAt`rJ4(6Uo~OCic@^}W$F z&wLoTs}Tc2=)N>~m6L>R#X=+`LGFP4?dofnU-JZW$nx!6gY?jL zbfoF6&|Mcj|2M?O;kN8S6^7d-`IC-IG%WK~_X{&Q7x^4$Do{>vg8z@Tp7hE$!QqXv zuBuk^)Rsp*Kcp07t^POxvM=W(GlSNi@+!9xo0ZGon_ z4#F&t_iD{sEj`RJHs>F$(8+ba56<{@L5vqLbAfL#66fKhW3GoK0A2 z+v(8hduKM5rKPyE^U#*peW_{R(&1!2LVeCE$ARGC~V<{HryBVN55T$`L z0_=jl?AO>tkZ2U=VIQlhgCI*CTjn7~9XhJh_RF5ybS=iJMD6{k2{KOcID&3pI76PW zhwJ0^oK(m`JHP+^m=gARJPmuo9`Ise)l(Zx06-vKkVyU{ju$CniX&j2Hjb3M*bm=e zH{QR0_-iXz&ij?&bb2`6>JRKf>Di8-myzsqyI}eV- z7?Jaj-4ACf?$!6M1}4_Otf19s!ZvaKC%EbU18ypsf%h%#ghLvz3=avC&&#^+{r1rp zu!UA~Rd~U8V=j%_!nw#h)@CzydF@R`M1Y74$@Un>M;Z*oQkaSpi_~kRVr5*<@Bgr= zxm9&m+>=JH)i(_|tBI;=4B?+bU>cYIxd(zG4* zD4C;?9r&K?Z5vwbtEv^~=8x^K^kALc{0w+EQa2=K)NKh&EE)rf%?ikWS1$t(@U!_I zLmv_Z=6^#I1fWULNyQ6Cs!^p#jfN~&NFAYd6keTIbgd9hWnwrH`sS{_P;0f1 zsT*VGSVT=J^#XNhXK_<~JY(a@608Kq^==>+2?6rNVK{Wm8aY)1e=hP`u%A6x+jouQ zWhE0@IX|z@?+UUz>PH*jVZl>(I?-z0XR*A)|8n{KDjL$fZ5^GnOGARDDzQ<_SY%>+ zUu-8HOTqtdYIzCzCsmnQ1diBhG~MN2G7?DR82CU|EW;{H;10;4(HVY0ZqNZyK$59}6v zyHCo%<(A8J`(iu}p-nfb;VgPh3z2Bu zw$h$SJu?BlMdCAiC-RzU!&fYYf;*HG6=o6yt0Pxx-HLCY!l<@a1bgK>blVasVs;a@ zD|-mGDwc5LN-v=hil|k=1aMp)bz=^X5!X2^E_-X@x-8xCV-@@?i##CmV)5*iJYM){A$vhL zpI@Ey1Q;r|y2kzt(+>BM;8s}w9;P~ECU3Sd9eISg|C|{FG*4r-7{AjJ6>MmHe&x+% z4pGVXO+{ZlwgtC8NcSWux4iV}d#_f~4Bt^P<*%q-n5HL~{OeGb0M~WpbnFP{b*g+s zVgk%X65A3K5Kc6D?CYK{kC?bT0l59UW2M zoUz&Gq4vYWZwUS6!6R_)VkovMe;8P?czccLr*pa(*zYuGzaubH69p=`@Sv4IZf z2P>RtOZOTA)Ra+lC&;-jl7bdzzP;P_^6B{c^xBI$8V0}L(MFR-yGN_i_rwyX>sk9H zUo--IftcG;O0c7l{rmfi;?~i!8?1)7BT7NNKlfWN?4iS9jpQs!xkW4cpBHGrak@lY zeDMt&v($D_e)qLU3%ww56qQFb#Cm1omUs;uJUZS=hAwO2aRqFo9w~?pIRE5x*__7D zZ0Je5xGE4E*MDU+g<3Hv#D|^#;qrlWM)Er82nlE<%7#fixUdQ>O0ur?e-iF2YO4Rl zAFK4z0^ak61H?$m&%|GQC*Cxzx5;X*dhYYI6kamsrZ0b?OUs(%uMy;Rm?Wr(FMA`M z3RlqIY+$Omah_}Pe&_wE*g>xngSi(zhH2#)o$M50jV)M=#hn02n{~Nr99#(4P*BTh z#*(PrX(lURwtSOP<=@X?<=TdshZgs0UKGe4|MiH0s*Fs*ip<~4q!0&E%Rv-+E+?23 zZ@ho!(+gU2jG~cxHCqo|!3g&|^$MSb+I8(2$l_1Ne(A>{b>;HtjyLP@8_ypBG_{GwXND`DubGuoc>#j$M#4u_G{fI3SahUQT-^n1$Fv?LQvH^#SmmJe`u^RZJR<3gY#bO}AKtoj zs7q1wt_!A633@R2NSAq~KR?!A3_@BdroIgMd97o&e~9OCZr9#rhxbYPL|a1JRJg(c z((55()z*g^^W&sRXs(e{{cogp-%Reu^2PT2?kq7DtN*b?%lfV=R0nTrXP>)@nJm)m zh#^hmz0Nr|+1S*sMDt*pB&a?he`l_;BMFlPSnnSx*Yuy1i}fuUE;D=Jz}+OScrHtMy@QpD>sk(hw9 zjQ!L~LjTZ8fv6F;2i0?%H|pL8Tr=x4I(s>fwIPgvNntpz5_A|>Oh4uIim;L-0uK8O z=F=8;a9RFR^vQLl5PXRUYGM_EoO`ohSql%Mez1jZpLm~9oDx>vpWDwhJaP3T!Q~`x zv+rorFE@ijEoSFAM1 z?c^&;Pzf+lv25(7hR4J6!{L<-k!Qv9z|v$j1_p zdfldS24T+pMvnz#z^s7vN-f?NW3^ihgg6;my52HIL6!j%j^G1;F1tT@-MWv)B5T*1 z?eL~=sUwq~YMb0&3>F5;e~xV1Umkl?tS%>m$T z{lHgLI%Saw1+X^g+W^-KzEp)wmY-ZcG~Wuan`)?d7kLXGm3(aj?T3-)xy#}H?9cv* zt|f_Ysly5hFnd6@VC|K${hl@l7UJ;*TCalIRA0qM&SK+M4r>#GyhHhaG39>uWWM0g zAaunjtk+J@e(7IW!xu^`w6MATVcG4jbBnAj_5#dp~-5Lj8 z0Pd%N5Kr~f?GD*(pq%eikzW!xv0t44=tAo&Kha@I_zE( z4&>&!SIseG(eb%m)^aE2gqU^y7rT1reZJJbz+Z_05vLH)f;sCt{jiK!RchbCEa68j< zKtc6Rz?s3Rc+4uM8nX0|XLJxNCKw`mU}}B;=kk2njbq)eMPLgNH(G^uI4lDJI*UDI!e!G^~!J*aC9^{`%WUjTG?TC$k;E^|s~ITRBH18MLZWbbo@Uw2CxQ2)L?wUpiO*UoR+wnrT)|@p>JdWNQE(wY``)^`famz8OJDnRdiz?h% z;A+NqHACJnbjIK{FV31NP0M6^?`4gxrZR}-di*)QWP+>sa_Laag(dy{+)45y1~ny- ziIcll z+2~yDEI2FT)B`g4?`Ey=HghA}ZLeK1REeCp5|a(&euT$VP2qh~PR7K0<&nrDBH&0 zV1#iu68c-;(^UxA0_{oUAQ8?s(v9zos5^TUXR$KCRm2a@SS>_cFa#(1g4Q?ZEpjiC zm?US7evx(ePH=*NJ&BX99|aSl8Tf)7Xd4qFx;4`-BH1@mDgS)3pWFM|o(SQ4FSbuCBxUa*Za z+PzNm2tz~qU3{aSe7H`GsCqFbJItEdhbrvJ>a?c@Z2t>-f@vOgG{TIZV@UR)rs++x znLLE^*|u^<*{1RFTcem^IE&dt7wHPmog@Zm`Uf_9Y6Is053C7i z%p}jO=hMUfMJHb-H%v$n%m3!NzUnR**|-O$-u$P7eZ=%qwL2{|Y4fzESp&m$`kmwf zJk>aAxGPGC*U3y(R5=e1DG$3sqYdg=}+RM=OgS_`7^0Gv~RBZaPll7lOnAg0@ zK!bS*_v!0Zx|oJqiT$GtYNLa0bsY(P*?ZA$I@Rwa<|hW$z3bNvUyEV#X7FA@`EAG z4d)7*UycZRP>vN%yMsigjEdL@rOa|{;bmGLd z%jJ+E3R@GSFo*cw8{hn(YHc1vSs}H6PjdoRANE%@dHcC>UohX-=95Ktog{VBpzL*5ek9x*|)blT7`ZtaccisOSC5+PIFLMlOHr=N+7M zOl|_S8(v~qU@B68z12pp$usX+rFfh?TDZv32_=j@{cwTc4?!egspW)8^2N9mYtQ*eu@oj?$+N^*8#8UbOqUCBy36 zpQ;_Yazuc|C;X0foPgVPJ0jr+&>5=4gFaqN2hmcc2M-!+%TG7gzx^y1CTz8WFZo1G z#w0?I%x|4CLb?E~arxPTX|q5aCPD1KgzwYKSL8;@4{cAv+CUOr7LCLI2j(ULHYUXO zzx%X_!B`ChM`PFzMvq2py^muDZt?0?wPpnNVAMcF&0J+Fu%y=b{U{VnnMBMF^}(L` zW<<=o;y-&G3e_rrQmFI8ghU(L~x3{ zzLO<8kKuVCtFbN`rE9pyWD%|To~);OH1eowfMlcj^i2?oX7XKjgx^d3ghg)OBgnm@Z3Gw@ZP~)6+vb<^EN1hmOdM1@M>K(Di zN#qjFosGyX9ved?0>-0!dAgz6cQ zf%xikTs1^APaYl<4Tt9L$pdS1dMRFa-iRmc#Z|drCXJd0>AIbl7SYmWa!f%CdOL7{ z`*iz1Xb}u^#UXjETeQux83%q}JHs@X6>@LApAJvfja~W665*v^;eLM4$1(`TzqJ*^ z%~>FGe)uT#o8eJ*$EifFpLmFy1_Wq@lPjt0Mkw==InYQON5o>v@1M&iqQ(6h;3IwJ zb)^7vO$ba7;!m}|u#4v!NBz`1l#Z-RZYc$7nMahtJRlIVBA7SLjc|S8*6KE25MpduV92zSYUDT zOmp!)fkug%0~{9AH03JR%Wv|nFHfG=1*5(AIs1rkB8Qno;J9UTYCBju-K!BSp8Jl7 zlNl6@s&!ba4=A=~IaaA&S$$scIHtTA?i-K`A&yl2Wvcp_hpvm{AdP7r$yg<*g78N^ zIljOfDq*c6-K^+#FAO`{qm%TjM{L$aI&rq#Z*0GL`mksStEe8l2bVEHYK04f2sHA+#0dB`nVV9 z1_r~;;uKZ9{8_BsQH|#5r0_@Q2*5Guc;_NdmA5PT+>FpV#_}5%A204%4?H}$T=TTE zs}nioSJ%z^l)*aDXb_`H$hP$tpsD&B6HAH~tnyUy4P}bTzpjCH#0mMoL@Q=vuz(wkmx{jPKLCS2!1x&#_K5ca z*n0cg(MwTqX#De2SLXM1Zh4~xNXJ-HOR1*4OWCR2%2=HBEqe@BI2~7C*QTL8O16}D zJ#VNrVuRwmx%q;23(1!Qr?@{fy)Vx#FIv~1**+L1EM-7pty&2dot;E0n(zv-qdL!k z6Vx(7->}yhr$E}X zK2MVq^^tsvDS&k4<5vA4<3vY9aP7Q(S)xf}bku5B_TW*x(h~aI8_SRF6pX$;ZuAEvuVpbXfU)gFo@i~>Blo6cGdvM(tL*cEG(s8 z5g6s0lE7l!B73fXS+X66_2cZxfzgzo>4-5)- z&o7v@!#SZ=v|6{?{PoZ+DsryK#563%r-!SaNrS)`Rv4B@-cl(Q$(?}1!`tgmQk~nn z#s%cUF=Wnn8GrS}av96ZUmc$yF`dZh(m;E84`c8rW&6Pdz;?t`Oy-_u8Ogvsh4ek7 z234}WUKks}mhfPp;1TUIO9GAzHCSdK+lepYi`HJbIQh1nQvhvR{w!48Rlq62tzO zIEwBR=~@|sNg>yL) zB{`|x3nT(9Q^gOSx@FM0x!z!E#oV*Po_d#|m*fC{>Y3aIhpyMsr46lHWtZSfL^U%y zoGVa&A)gBpkJ3&*G_P?BS|;2m&?gd0$yTKZ6?emBm>?ytBx&QED;7G`4tw8d;x83N zJ{KF{?A;%u=PY*ktF|LE^=;U~AkB9rwAz=aoAtp&$%oSP2(`8JxfYSxLyoq~reDaX zYHO#_5Yd5n-~Zm}?xp|T@~jm!V&a_OvmnXv;unCVhNY<7I~!x*!_5y#V02qnGvOlJ zboO`dUGB=T9#pI26+6wO7{)pH#Z6a4ldX|imd>r_ObZ)T5igLg1X~e_YACI zuqGMgzM{|LXK0*DVR|3g^bJ)}V*>wYIWQ7W8mG+Z6t?@VEqxyHu?e|ni$9Ih)CgH# zp8eY@!jrbsw%}oFmn^(vXr4m?8pQsoW+dc&bo8*5oPoS^sZ#+{L|{LpO-MFIZPMwX zufXp(bwkr8k;|;vaIYn;i_K^3kYWbsyT78O3L)=>B(=e8x9yiVd|6b#wD^N!)OmEw zCzV4|zHdj7A%@D4+(SqW^Lr@lPmI4iS}ozh#x#FfT6I5l~KEVJdrLN{hBF_GItZTu%)dMa*lY z0_a8Y7?~f>Ke^{l46^+i11t66*j0Zura~}Hj&*;NYpezwr~LOUcQMa%dpowD*RyZC z^2gI>|I&|g^B*<_mixa-`OFTGDhhnG3>|yFFqun>76tKXtlgDviN)h+tIJB+|HkL& z><42f5BX@fAeQ?KDksr3t0lQg@S~m?p9YBbbB2`J+utFb^V_t=8p}=1{L9M z3QqFzwABrdIp%c(hWRt7+wuGD(Ao(2TA(oCKMVdWJ|`RNP<6>hQo3O_RBwhS)7W$k za+d5*Swl7?Ycd!}62_9HgD%Dy_XC4c9+xp@lz~AKUI7V5nqEa0Q%e2nlRI}IaK;o6 z!2buK`A70&7Yt{0V%p2RFy=p(>>WnV?pQVhnsrUhH#r89DzP%A6F$dZqsv=+9Ejiy zOrlL+tfIjT-^svXBXG8m)V7Ke7%zIHNTlyv^p{O;5fUVkzC2|#QZR^7fgYUrf4bS# zUPai*r5(oO03|{QT7F%X>Eq$3KD-tg)6HMn_!)iAsn#%k)e(sT9X1bY8U<0DSq z#eFk}iu_nFPMP|+^5QWOa8O#ad|mPF2khZmxTDBML~j~{8d2{rbljz|`GCo&I^Wh; z?z0_PU5xTv!hUQAv{H;}&<^D^AaSHN*S3LKqK$onU;wGmUj?d?j~zu*TK z5p9m&NYuRBlzP>R(NqKD?FDRj7rv9InF3P?B%x#2>$J}4z}~6vq}LSh-8)B09=*E2^{eP_>T4g7DR+6|7DB4eAA|3FSNqFX zFGl(9b;eTI(LszJ&b=yMMnqxtzwfwtDgq`uvB`m33H#5j6y(NYF9^y zf#y`#pi%^}^w*mSJHU(G{EUM*G)mTT27$3wq`c0O$r;7NNsd!mYfQpHef(XOU7c3; z-uvM>PPO4)xuAoynI~EesBmi8jPG*R6_qDtcI6wwreY$vn_j2B0Vt{(-yu*E8aF>(>JWwyc|HsI=l zp3k!q(qjq2a}?Z_$Tz00o%#9iB<`?xyx1~#T~GXyA_uf+JZJigMt%u~lOwB2Esx7$ zrtMZ`*eC3TRo-s{-boSF8SybTI%9D5hy-l zU9J%s9}~erU9ZpHBEwVG|x%CJ`CVcbDT^A8&R);jI) zri%Z_MOFjyMLTDI1ft$dGY{XSQXmp(N?XFu$xeDHQ*TT1!g#gaMYuOy7`nkw>M&=# z3YP5I4PnqX%M+~yj?=4@b_EBhmnss5e8`RvNoVD2#DvJ%5u-r%<_5Y|5br9|z>O83 z78vtVu0QcIr=9Kk>ua`RrG-8d6T*9B0)uTsV0R&_nczZ%?y$3KPrY<@kNH{Cw#^i{iuVL z*e&EW2i;1Rcb8pnpWW@qZbr&P*=c+U(PdpEzEGu@CxWY1{>v zBm00is8@ExmLNSV!;Ye2^|hhO<|KBM8jHW@q8;=GC;UNp$p>IM`mG6;xg zZF(u({|4nG>SYc$9Vk#5-_*5JuCSH|ZLd8m-@@eQLOt3VF$NCJ<&L)w8nm2Pw1?dEZ^^M(9b00LYVXT` zlD3fxp7K9i+jPCTeh^g=pG>a2i@|jzgISxt_2^*^4D#Nam!q|*(skSk$uIhIEM#Qh z-rJ17vF3EdA)#90kHi~ox$Cm!!NjC<*=!jHI#1Mlrgco|A_ zEv3s%Y$qUa2Vbj9B>Nx7`2>gHAFkkuWZ>^-Z|&+{lza3brr}gw?9!YZdbG~xu=jku zu1#WhITmH1V0vXr%8uv^2y>^ER&_6Ba29)yuVHK?#}{HN?)y%SU`KWR!J?l-N|0f^ z>Z=o>WLnTan{+7^4xx~JnGjd$xYTC}tKY+YhWSR#GQ7V~noDlMi-rb~V$}u$TMFA- zi~WCTo6Iy#wZoVV8Vcn#=-IYTp9ui%#USx&T0}BsP3-Jg74YPlh^ks_et7TrT->r`2L|XU z#Ct~{DKS~>9CzwQtOgE8-$(852?Fk|?Co^|?|j-fQ+uZG=RA>%D-c=gU?FZ*;&kj* zj5@o>2~Cj`6Y<=OVD!FObUdHrSRqfv9 ziU^M)Aaf{A`rwSH5tl15c2rm*gz8Jln#9 zm0zJ{g9$a^bBKy{Gbjt3;=l3O`(58%qGqb?@%7xk#KT$2>!i$P;Ode=7CK)TT@ym2e97bunAb8tYclqSZ^k0`)jAfHrrNDF2w?Avj!h>ziCCyf`cW!y6K+e>(`$GvXd7{jH_ zTnxrg=GAWK1Sz%0G<5&d?wyj;$1_PiVjFe=hLHZj25?5^ns4F-T$k;(NLy)gbDkvl zni~4Cih!oeIa5OSQW7x0_)$09eaKcGrAd15{^?@YjoHWj#(qIQ@xwqW1{c|}eIYPL zvv@Jt&=I6~gm&goI1%uJNP@tnyX z^PfitZS2Yqz&PaVbQ7M;N|v|zdr- zSJmG>Wc3sAiO!q!LXJM2W{RHhWW+L8wu*We&wqLiv<6Hi*QN?(z!`=ShTw2WnJ5#% z;${8gwVr!NH=FLPjSXb%8Wq-6J;9TyZmyK!$Rrts;A{>I!p#Ue40o%)*W4y4T{k==9mLgaNe#Q#x27sj96KiT-{N?NNj{hV# zdmUEgVXH&clOq1x-+0E|_dH4bSzwu+(Ma8fWb+*c_T1JsV9%X;-LzIdpdG!IFc~O^`>*I2y~c&>D4=}3;W|1TW2va+ z!c@b+r#8k_0W^CF0HG6l74)g9{IM5w_kv2U_FK)E_Ah;UI789`HuGI1Sq7-)@?t5-mBGgnw*5-C*;VZ0-NKTCr8iYIom)@I(_SM zARwFmM(2+|4W(8VNG)fRR#5hPCBc`Fv8#@>)UVR2MpjjK9`;_TdGImx(*&4PGoA=9?1(o}d5&ODh^db{!Ce7K z${8it=-)@lsvycML+S0V2YHmm0m;GGe*cX5o<1d6Wr6Ba3rf69ShLVf_s}#^MWq{a zfs^&ZPrq+JWeZ4!e?VPv1`D?rs#H_L7atFobPj`skazuCv|)rd(`&>XFA+)PaA=l4 zl+!KmW=-w?m)7^ptqhD?I(=cHjbjO4(ylqV{tvADM-LbnoDB0j(VRGP!5JC2zv#^v zj##zeJ!oaY(pNzQY8Z{mQSkL3V~*`e7goBJ7ADWXYSO~#8XPi-x;Sa#|EcYzGj1lQ zw17t#)Bt@Jk7h98qKkZN)MV{R5jC2!5<-oUMCpyWPr`q4rvw-a3tY>+nTMx$>wo`q zv!}PexS6{H=tPCS4k4cKjo7DeR#X-0noy9eE;FhF_vW1w! z1-c?guf3CT1CqSeO{*13{25~iMyOi8;q0UiOxasc4oGkQWBo#}j&zaM4~U=4GFLS3 z?$D93GpW2;<HFQZE6yRcIn?er>Iz_n7rr7N z6&Nu?OmD$gQAOhsO~z&|SN;cotvHcXCBA5UlV!V3)QQw!nAp5N{N*Q@f1nU^@81@w%InIUisCfVF!DH(HxT37ZAu zdqfg;f<1 zl>xRSFrTPQ2gvl(LILYNIHRlZZh~!330S*9`+!r}Sf;5rx0}A~9=0=1Y<`sFW#sai z1}tW;fezzOum9U?aIIUT;f62CoX3seV{N1H*p5{U-G`E@>zHut%mLmwqJ)65f?_HtJ#@h^FUz!--*qfY5g%jNC_k~9KN z3jmrwDOh#ISb%)%ZC)dKeAN_CvqCgxdw`<)64@f5HI?2>;Aslj>0XgoEme{DIhOJC zz>4fX6IOcxqXk5QC=m-TlAWZ`al2>zAwFZ?`UD$i$ywP2#!p*p{>nFrm-? z{+yP|BTRSWe`0Z}0Q$mY3K{D?42-w5JIzs{9qnW z6oJ>cbrLDHg}fvh!>|6hs={dAKDA+>&JJqJTYHPJF6x1ir5^Q{y0><3H6I_Z#J>0V zduQvLZiLO-K)WQt^0)9;Cm{jF{B|Hg6jRQFo>|_}t|#1o6?5Hb+nS*J&6Q)3^zNM5 z)?Qd(OWpa43<}ZPcFY%GLZJ)#Lrp|_)OHS-+lAjQYyK(QWsUxLWltxaf-SYL2Bf5& zyQCdJ`I@)>taiU6hY33QR?#tY1Yu$9nVUQVNUy9Rq9&@>SA3%|86rS%-ye(CT1Xtw z&{TH&=XBDLkO0=J6r*Ls$p2dS&VgCiJv_4tJxj{TG?x2|Y-9JbToc$Q258O{ObjoB zduc9WCEn`*ngEr`{{Y~#s!;}}!Mz-yv2j}*#R6(5cXsd6M+-L6WGu(1kL7WTa{GV7iO#4852Aq%Bu0gz!7LytrKfgqzcJ#}&2{3TIoL zAS2PhEbGMAd%Rz0X>EereTsLUl1>aCdZua18h9U>)AOeWS630TcT0LJsbJ-&QslJZ zBq0eT?iu2%g#@Rbcn`{{FpnRelITWLTheCJE%*mD4q}r&m!Af9+4&nDzfJ)8SZ@~C zbp-3m0Q2XUl`7{%B{LxwyA_A&Ohc5^BmWgkX!(W|O3zSIc-w-Dc!aaxnDIT@5cu@v zpYO##7E(XZmVAipfoU_fo*YM1EXz{yX#@_u#x>E~sV#F!_qP|Xv^NO0C>1!WKX&d@ zGx?N93`o%z@g*s<5UJkaj@KWO_2n9E`w*{1i*r?sALrZO!h1u|#6KOngOMWJERA4; zbc`+}p$0;eiEe}1wy0>8#}&c^s!d6)TZ!7pOd6b!cTFrGvx9ayvd_uc4L?}mLM^{n zYz_okY2CHuGWiy%bYgxa$UyKg&2&=|UG;akZGJQxLT$+>(=;75Rup$P3w0J*rIHK9~*3bN6y3^3xM6v$u?sMC|SI*vIL}gjndd&T7@^Tfi;H2J~>POTgV*9^o zLI_T21AFCQ2-n;OI6rF6a>{2jhVR?|E#!EOe)Obu=RA8;mbJ99xR5S+Irlv~fUcI@ zlF4-KlGrMdj@S({2clzcdCE>?KgYh)s@!`XykFdw3{h&QXQd+mymHpp>z$7YNrk-r zS^R9@6}x=Dy5i7Sh>%~H8Q-YM3XJ;&0b=kvFg|jP*oN;$wa5RTNOLVPHGbiu z#9iLdb^=qaJ<_^cMb-AuazTOU?!qEcqMRF4%geY$UQTs{(;VMl71&dWTkf*;|Ib(T%wN2mEGJrFY)_t=EGXeW_UkRPy8sGz zFozWYQtca(7MH{^xmR*<|@VM z*h;CcEh zsj4XZ2|!`(<;@~T4LLLT38q$Zy7qBR8ShWjx{fk@isf=E%r`85Tfdj* z@)es#QtvV_2G;;*<0k#@Laf6ApIqcsv#7!UxylgY5u(Y>2PQj(0Y~~#7cMrGsk9AWQF#;r>h66i8IG7nF15YRWW3&3V|+3Us)WWgFE4M z#K!9-7;-nz3;S+KI?DUhl#d?D>L8L`^PRl%$y(Xqf>(1IPQ71(*{VC0d`baz#%{1F zwK1vkqaHm)Lvy}s+N4-7TGp~{V*^aV4{R>({TQv1o<68-`GJ4#hRSR2aaW_<)WLb# z)`U;cJOAzxy;|E7Nz#?AT{yojM(u0V&(eoXF;Y~nGSjwqS8Zig>x1`~>&`N(8b!^f zQvqh-mbaxRrw0#ThffS21=8DQ-h9`S>T{R)jncEmYe4issIL-RN6A~l7xdJ;|JPmF z$&AAwdwtH1?zciZ%h0dsXcopFp(-8aEWl?yaq^f!5j$VHUwvT&9cH+r*5bs{!FlxWs5WY=8S4l6f$hob$)_EMK8vID z;sHiEh>tMPbHs{SB9~9*X-B)0jqk-rd#BWVM|YKc>fV(9P}Np+;mht^xI+L8Cw z?4I`;ZZnx^e{GXQ^iyDpQj0VD=juOZ=9)jB{%NX z-yT3l(Rii7KLaW$3<8+<*d|dSUWwnXCS2t#RS5{2MCiEwBI}Q%W7w?|@+YE6?6Wxb zo{OuiqJAgJ*jM>)e8 z$chn8aUvJ>KqDV-2`Q|uwcj}J$E32jIVOgz5U`DBLPR_u8$^zRV# zGQ7tv*dCqUy3MX4UWlJUU9~ z%8m!_m9wH)X4@Kz{Vw5p)`0gL+_c`sfswXUB06;V;nu`RmI{vtLNT0W(j7f9TL{@KE6GlI&g)dj1klI ze~yL!))(C>7e5^5RhIAwA64N@cGpv5@t1LBi90``rm6O~V(sWZUzr(mHFQUItACN| z3nXcaIBv53oBGX}R=ZM@Nf~+28DX8{FgtAeZ$-S4ROie~TKJ-FKmbsssrLP3;zQCn z@nH%=TVMVS{f=1yC+al{A5Saf<|e-Pax$!^dADn-XnB#*lM%q(tM>K1Xe#XUqnaji z#EGreL@N1>7Hw{dOZp`ygRdMEn%rwQ4y||>Z?)0G)Zz#XOaZ|IFblwVKpE@rAk~7Xkp9oyqGX z`t^WXR5H!Y2M?JN-;Vt6Efw!x$ZZb37??BveDC__;@4b7^i|Q44V7t}@pS!!AI=V( ze9XmdVa8iQmPJx)5D&D(g^{zj$MD;eui_%ESjdWH#n~8#RV1&Es{^qTf7ywP)rv$C z@x2(pIqDAJMQ$w-LLNk49j_4*9(zr8RisYf{E?~kp#JS^!0Q%36H=gLo|LC|#6Dpu z8*Iy`cW{~UkYW4gab?;`*JdKscsc7mZZE0O?=X4%~V3?t~uS7`@WUq#;C6_%Qly!{C*qZCR`) zDQOUmj-FEScOa@(`4((A(krD+0>E^1Frc3v>XddS;Peo?PzgoxGE}JF^HSAF!i!|; zv7QCz_ePaZ;s_6Sn^vfmN1YZ{5lo8V-!5S7RXl(x|EzSYj-+OE*R{O0+dE*>)A99P za^H~U!>v@)@h8F08rR^G(o*wjcG56h?DqGPtVQkQ(ndw9T?&3myteNBBeaWiH1YZ! zyY?eLH%g2~)>s`}lbt$AHsv$N4Uh8^eJ#^o#INf{!^ag+qqmT+&T6{Tt|Cr^o(=iJ zT+}{1UM*u5U(o8OnX=P6~s%kHWu!#iR< z#v=u?f}G+2aC9Ym_7)@WG7%Zdw$BiFZQIO`daEf|@jMT4e`x{=xCJ=i$zT=Xrt%a{8gKvz;|+B;YvnFRVyW|HKVR#_nO}?|8w& z*kp3PxkK!3y+<-gE?hIcQ9RrUp1;JEuRbm=YZbjzn-y}+{vAJn$CoxledF&o|0_}* z7<1~`S8DrQ{J;dDsa`FJKC=&0D&u8F@>fZ_el8`)0Qj?f8oDFJ`9q`(AuB2kyHVwY3AbY&-r8FVq!ycZ*xW=gJuA^Tkf5YR+2> z7OnF6r@x@n*9xoSUo8E@NWz#7PHh@bX&8W}4&M~|*P4%C3RwxyX=PKadhcy&#!8;x zQrC*eJOJaE=VrVm$os?}+>xfmhZJ?%VMtKV2Jg1{c}$@NfF0JkUr+PM$1C7y?N7;R z0+Ff|iI9W(eQ58B>`qXv|HWoes~6514>+18DnlVsZcO&BJ8~!{8(4gK;RUNHjhnsp zBmh#E4ZL`^;t&%hg8AkC#3X*SE0QL=$R-gLSmIj7rw{pEAALPT>-Zkl@{K9MdS`5N z#ZH$^M|X--!1G<@FbJR)5fmNLh#PZpnzbJ^%#8<1L&OFdOcHcqgajc~K>H{GTfLjs z>hgy0N0v6~1&3aK$|~~Bc)zusrAOw2_0wcV_WPE5l%P~XhaI>kvcSLFaw4R$lZZmq zhCEG2HJ<<~shF;e>k1o%q|(oI3pj?os&6^^3%$*80opiq;QF zyK7r_CTjn+6UB^rQ9GnKN~RMW7Z=!?;u1|g&G$15Eb1;lQF=T&#f;p~6fE2CR-W5- zQ`nlMIS1h=t5Q(;VNL=XDaln;gC{%;mrc9Gh06?ZCq$u-8%a!cn_a*6oGFR$szu=U zlkVI|f*t}8H2G=?Q6^x7i3?%n5^l^{2Wj{=#%3+-V}t`4%PS8()}c-c`>vM+ej8d< z*`MoLzD1R|4&32BS#ry?Wm;s8N>A2m9;CU8&Sbe+prUN+MtJus!~6%a z%!~Zsi((dD+&EsX1KK1vK^)I+@@f1og{Eg#OBnGYd}5$}3T92-?~WO&*Y?h&sO&wL zrXDHU$f5!J+>X2GCHoigcAs${F-br+KT=${eNV;LvRrKtG8OFhSEDn*J!GFqdKq={ zAIY6G5pan2f@)r8?vca)?yTE$yy7ZBW(`Y6IM^B!ewltH$8WiShY^0W-Sz6>b&EoB zO-*0>C0Du8EHlcnsx;>h|J3s10Wxkz`+ds$KEGA*FbpDf$Db?gHBwKVyd}s@QD3tIQlsDG_ULuYln0q0_TF;3 zW=!KtY!1U1Z#)=r019PaUD$6p)!tShS}m%!2JJ0<};vsuN4W4N=ujqD~dI4P0{N72w)W~aWtMOIF~Dh~ z!%UuVO8@Frd)3 zLC&|`IX8hH9m+dOBIzsboVdOuc%^I~ICJ+n=PE)J5o@;m~MuYa0W~Xx3#}ICEp~MpN5$yjjV0jou8*Nxqaj_ zZTv|e`R^3-Dy8H=UvEJGw5OwB>z3*kz9+YQq zg9fzJ-hmo6o+B-QGn9y1xaFju)0;rUr+kP-3p&zSA>FT;a9H+YxshC%!2v6Bk)=`+ ziQb1!Rxna4L%y|hRoQ5vH2e1>-tbDv#?eu<1!|P(b68XOV~D)JYu;zwDQMntew1Dc zx#_F<;>Rq=YAr1&r|WU!7fYKSnpE`tiF8x7dxsiXW)$tcQhNd_(pwiSahcn{I;=FL~ z7<&D~D2O*PQ3!v?>gcqNVE=E`V1^;ahsS(M6SNNkr|ctoVe}5v$UMReHxrAvZ;N(= z(-t%igNFh>W-PB2zHcP9XtX}E7dcRfZ2A1Xc98CAW+XuP!53N`>tIVIfingwOeNRP z4~i0OK`Yl}&}`n>$gyvh-kQS**g|5$zW9*%-PwKlHe=QM<4#s2PaJj$5d|pt`Ek(? zlxM1&49Uh?NXLTxmJSxk^0pFkT(0gG+!uVFL-4$ZVp_g%9kv0s&j!8T&#tDtp_l`R zGZE3)1Jg$)Zng24yXKePb$Feo>8#cv2q7M)k-C3AR;bkLq`Yz31RaEIYu!W6dkUiw zZrfhM1koWp?f3NBcOVp!x~{8m1}Zq{IWNuDgN*SuoTaw`A`;Tiwlei1)&A=4+955X z$M)~Hs;!5W*p|-v9aU|sNs&LltHxlbdUpk5TzbMjRk$V;toB!kCuSog=tae0D{m|Y zL=7J#ji0Bf;iD10m5FS{U;NvFSne868&!z$GZpW3gw5d4kgf7BiLS%VLz$zRILp_63HUJ0bqWHZFFT z0F=EC)5h(6RhHu`OAPe~b5f(>3aCJ4F}3q}N1UrJUg%zZnSRVFB-P1{1b+G6NttX3 zhCP3X#@RqQ1wn}T!w6~gb+(3LxP@%Is@-d`?HA^q)0Ixr?EpbRcwd5lK6apF!U6n9dhgSaY?PqZj3qIG{A)-l4( z4X|-5iQc8*7}WgORimKlVg`9=qaby}pe1Kcnta}kK&oP0(hCv1;Cc6ui2V#}dx=e? zO{(r%t)?NRwzrBis}^o05+8aZ{QpwH0W7wpHf=}}{9zv7GDO;)4A#~5#PHkn>_&malhiuTVioPag z!N?YUhzsG?;PceG*40lfddz!{T8e*0eKjr8!ocQ*TxjCMscXSS4kszoV5&R~CzvwlHyRZ97!>0%yi^lY*yKfmjs>UOyrUX!0L z^LB_Dv}d2)ogg#}pGH6HoN0$=cOgTVq#^Ie?jXq+43>K7SvRj)FOc`M#a{|rAmt!9 zUHUuYbfQ_kizH%W&^v?LRDSeVeiH>z=?*!KjgYcM{q3%s%^7-IfZcSix9EXHlHGXo zZ}6@!#CS}XD}32qh0yV1AEGJxT2PrkiP~F|n}ET^aFtXkzX9*pAGxXJAl%I&OhQI~!FQW7a|o5c-Un$Q{A6wph1J`5+^NiSkif&AA;wkj{;Rd9Qc z+WEq2Z9UzxQ|k{vANmk|?**K*olVgv4bI6ip__nolKl&|#xl)2vRLd5%7Y})e=ENj zuwGmey}RafJcu-x)sF?e0??9a~>YxP&~G;+NW3Sz;N9YurFrxBih+K*1VrxioHiXi%!h7%EDu$cqml(ka^ZIIUyw> zI15t+{D+P6&MM4$e|@)8&MyB9o*@HANY~MM8aK9DZ}<>4quFm%UTZvd`v*2{NU-QK zo5b&lX5c7!7wLM^GZB#w4MkEG@xaf{dvtHbG|~4&Jl*es&gkFjgga9nKKg;ZR3`KU!wxKvgu6PermTC4jhTojPmR^tB{&NRuG4>)uji7C~rcD(;_I7(; zcEx4dDKVK+w$>9D9KKse#8 z=5ng+^P0U35lFO+ix}bU$bu>z&lb*`;RQVu^j=>)o?IUo?m=lo{3{*`xZ^?DV-?QN zWqmjKTp7Qv;4_Ek;QbVc>OO4yKWOq)@+zALY_k#RP|YedI3jwsV)ou*W(!h+h+7GK z1>;jQ$G=KmCsTkKvr1c=#?CB3W>*+wiR-*9t7@68D7Zvou6TaN`U~S+X?<{v=}*o1iT&InfKWOaY;36kx5BaLHT$-zHO31h z){p4JO|HO{jxKuhJHiHJ?~Oom0E~rW(>hTzFEhZ_Xed~Yr${m9_k_8!m-^KWgw`WTOVx`43-NYa&3zV1&ZZ{CmS2$Mr z9gpzA!|Tp-YZ=rl*I$OGG@HlU{M`N#LSF(QZgTzNkl@rTzN3TQur{*j;0mx;CRwGr z9%g<=G>76(>B&1SYFU+$(8FF_0H}me&&TLm&jPB?~ z6WDr>gwMKOO!!v8`LB|LVJA?pp4?TeYMkhZ3Sd)Mp&KV;Sk^tW)&vttGk|n^v`)7* zE55RMc&RMdCI~@{M{j5}UC# z;iJpefS1;qWx3Fq80<8|mvb$;fO9D(8mu|F07e2tt8xYxu~#nXLM4K<)Hv7Cw`sgsISYh_P)zT0zASpdrH+T+iyyl z1&|&;;k1~Rb??nrjpYd|-3-A@;;SN%?i zMf-<-L};gYV|w)}^Iu(W@ZJVJ2u@Ej{Zv`@W#_67=E=pr$BK@vaKD)!y0K*_^QD-B z;CEhG$6PYo_#dSYISvqHe93*+zU4QbGilw-t{DMPIRU$H6uPN$koJvsxVRCLy>Fc7 z$)=(%=~l8*$(`9%N)47@5KMS}Y2wv%&ga@4qH=yQ6y^cMe+MQ-I-3X<1_-gMf3rUJ z4EvY~fh4*tBg~d!RX-S5WJBm#Yj`g-;4EoYp7l)pr32|U2X+YER0e{2Nm3Yz5KIyl znjDk)*=uV`5?{*T_*puk<~ms>WALnEuF;v)nVBR2jCwiP=3UA0+d2*riv}ds(1IB| zKuD(2u3ZXtz4!*0#}iCf@{oImWYJ~Rc>_9m58m6YOpw~{Fu1fXe9&cA+?lkelTl78 zIbHSLT08w~gqG zYXZh=T+fxrGgRfdmzXX!2!-9fB1>^2HYCJIGeL7c-Qo`>(C`n z;5^8_BmdIaQxXSfHGb> zi7npwcvscKwddta<}fjUxLcm%YO&uW;~FmD&13Npfb9Ub*pOm`wGk}ew7+?zYdzE; z#cg3qrv8^G!{SM?n}M(-A+qh?cgo&jR$&0n@K&P} zT;|P_6X!*g&pbnH-NVwm_ik`JUqNLHm!oe;$gpx1xk2AN4!ioBbSxgAC!TqJMw8TP zj7=MP9B?tO@;Q6Zzi5(${PPs&rV{LkXPPOY(G>2@@&i>Fa5%sDg`yRl%uoki=|~@M z@iMx*NhY-RSE!H1c>X(+iM!2ftsMs4u^*Rtxq;tf9z7vd04%ysT<@88A`5842Gv(A1GAxbXRpK_|kaKw%+b5!hvfvXLLEKXv8BRjk!sjT;rcqYuZ$8MnbZCwu z^wg4(k#hMOq>&2WEaD1 zHEW=+EpP@=M5myvyPo(q(8tCnnd@5Ism-~I+Th85fDt`U-UzFoh>^tD=zAH*jU zwt_S(*GHb(8fWn_$(LSgviY>DfY8p_VQiO;$^FR|+F@DM80(hIW0Llx``qp1cCQfQ2t=%hw7Ho~I##jY#bX^j zreDJsw*o)lqJZxY580|d6=*x>Oo&GzMy$LGD5~oyqRKzl7c>C?3VL+UHLe!KqU~^E zZ}=}CZv;$^mwxPTcR>&EHVJBFJb=4TZNIi(SC(6}w3ZkVpoxm8hfvs}ih+5+|Fd&3 zi24zYXOO|~oU#So&%IwvmoGKZ)5ZD*dAfAI%ThuD`2U`IKlWZ}Y-FUK?a#?TxIcIp zcnSTk;_RyKb_>?(`#E{)?*}^V>ptrGS#96pYDKJzCN9g>JE}po?LT;yqJ@b@Xv(fh zM7VsDN#v&S(HBc-ICLdJ4k%59*BXazJ!#f_uiMJPlN@AP%NV#yA3SLax8cq``$(zz z4*L<6nGL#2Fo(|of-grOW4xQ;TpFKJjZAjx+K*7E3sh`!0&J0kmOPk!9zJ;%Qes)8 z25W}lJS5zRcBd8^Kc^|^VXQi8`j#ddNYtUC_uikEDt2uoePD1DhcV5{ISt(oAPzW< zI+;Q;++$g#EY+dY`AmTaqAt+>IJa+@)~MU#*K66DhiNkQB1?y9ao9(ZT?2@A!>vCQ z8Rt(zc$)mmd9GYs7uPKl>_x1@k?wTvA6STC0s$O+p5@m=p0_$^UG-Z#o=1EeVW{qp zcAt*WQ^N>T!4A(o0>nICNJK8_T_>cUjRfkRn$jOtZX(-!pg`M2AgJ^t(0)1vhrUj#;oS3-0& z9&r~6;td}?_urq$!YHC9`NGC%!tl`4Vt1aXDC|3gUOZXCzHNlgX}pB2R>GeP2McG_ z)0i~e>1L#21kFi7FzU&k)lQ(sc<(GJT0w}fMLcZX&bZgBY^XKpKpwn+F&W&GZp-#6 z;`#3bLBU}9<8C=Pr>2v*M8&(GQrbMi^c8)G{oLm<*pzK$nmS9cn>LVlf}H+=?Siqu zPi;s!K5$ZZap2n;9;Zn(H_Y17+8($$~YUjCZ<#%WEPRH>2ek25!5_ zth5iaOU+SKl0tg0W`bv&{EqeB-LnLcTsuq@HGKW&*VddW?7Z;sBOVg6- z*a-MWBxe{;jVIGYlj>lsH=c^4;ebVE+MmUzijRs7CWR7opaTEWEx23&CelQy7ZbB9 zFC@r)oo6yMYfvH4bO@VBQCwl8I^SP>5= zD%+vFvRy+Y-QG*&>D;&=MIq3UFr+-OS%8B$0%=P)-mCcRV^8a+B?Da|t!nr6a~0cU(ACTkvy5HwIT0$d@=&C!Bs9ASGLRWevLQ1 zKgCb#E%m#Ej(wK@{F!T5#>d=!b5@+fkmqy=l)=7Pi>c>bn6sG#lNC&Eb*G{9n|S1X zeB5qrsjpHEIB_jaO4?Z{b{2E}NqY+${%G=f6P*M})amiLV0Vc6VYYtd+_^gXAz13& zv6snaH!yfH*VBBq;A@@j29;WzAsF*~zqEw5o{YK;CtPROyT#EFmEK~ykj1=-f8W2_ z&mZ|{v2gJ3`)AsO&fa42cBpFgU0F5J6=cd^F46A{W6|xKl zZ_b`p=;%`lm(KV9#GZ-6{-eJV`fl0q_6DV+4C-iMw5)&fD#oDVGru2~9UdZ}vZ+>~NH9P#3brhPX8+7Q`@tL>PjQVz$ zB#_F5K3E5mRu9194QyU$ybUl#s+d~K6d|9Cv3Z0k04!waFd-yF`M2^g-8B8Qoc$ly z(0|>U0mP87i8W<50NY2{hWk%Ry&tY)$p3+PCPKrqbaoMD(5ErrrRKuKtRH1Sap(|a zFXFgDd>nB>&++LCeUO*R?%#(syj%C7u5gogG1|qVkrDvmHuJ)|^fwT3 zh|lCWe_$+IV6KHr1j@AcHC${r*wPaEg+5nSq)V0X4MSr&!K`|{iou3P0us#KiXtkT z!@(S1#P)Z7ew>@+k3>-THp05qkBa5uT!_WPs#pPwPh&l0#+?nT&-sawPCuRF# zr%H#A2XAiz@ZLhX5 zb@)@hf%uo0$U=v2zA9T4Y-+t50z2Ic7>~SVOanP!<|nF!Ja)6p@jQ_VTD@`3y5lXl zisnuiJ?8=6j!YF7IUVdMZZ6uqJG|9yTE=@1lb2u?#6)Zn%+?aNk9pHX55et-%#rZH zGRy*!Z(;FpIE`4akss27*U9c9m&C^Xs_4TDM0JAlt>V=vQ%TOl9a$3llU5)e^)Sz+NQSD3Rdoqn`4b1pi&%Vt9bIz(oRcL@Fd%sZ>jdF{KL5M?dqs(>F>lDVA&DUR^VO zvk%PaC>_OBl6LmW$Z^t4)-E4?!6{A)>e@&udFNRQ9n(Nw2^$>Gn1Kut6BfbIiMMP6| z8~Q%P*-?4%-*U!sTO!NS z#p|AjZ~m-awvTj|D(N786#DwmCftXH;`HE=(8l za&`NVD3X{dyB{~USD8`s&`Ab*h*=FQ{BTzT_McV9<0}gD636T z>cFd-?iyt8tM$KYmaJeV^j~n5@sk|}plZ@MYyh&v5D6#dca?1A<&%xaB>moB&=-K>r zjC{WM{lero4IE!=N|4!PRz)#B<)H^i1FC^ zCj>u!X;o$DYH)@5KGBzyq`e^w8ql=wsFyDlOeY6=SDu1~xNeB!9l1z&WUQg$N{9lX z(=`*qmQWsH@;_I9W<3*^cKITBe`aXmJb%9-gzo%~DU-n$U&siHl5=mBE1>{AZg9Qo z&h^??v}Sz#g4iAhLI5`{x)eN2QWHm;(w9eLZ#gHfb5&6ivyYgS&(PuZ-F=4C{~+p81AaELy9()O5cUM@?x1eP_eP763cf4B?%INs_J##EW} z*%&z1FQkmO;M#$1J%}ZqvL_%(IllcYIAurITY=h& z5Ghh6Q3ytG;W@(SCp_nFQ~6c;6>jW%MN+m33dmj)kpJta>n<_d9&^lX+=ZE+?n)ac zIF%0#ssv%w8zaHTV+P2-JqW)}sFTFA5Vu?HZLTLfimV7g9phH(w6_g=Qz*42b2XIb z$GIGJhQTT#vIQ6PQPbihLGJ?RcOnUVoTzMg^*haKTpl(6*0-=`28mIwOm_cz%u-Od zXu$=rq5<*6xs|GJMr31&-{;6r(JxIk7suLY7hQh6>17^xWLT{weqVjhA2N@*xpUtZ zVp{OUUI#rcI`U8#%6gJg$F|*{df%5((*iu9s_fDwqf`?4 zD?KGg+||1mkDI8%?jQhu)A@&*+iXVmi|bR+%4u%pA2{Ci;}#F4o-Oh^4zXNwlp-Us zAB9&~a<28U$W&obi$K0NBfOQ7g01Vmo%9Z&C+c(E3%qH_srXp|R(~|%bP)O7|Gv`9 zefko#^|MDNli~ziTO8j71fz4$uVYtxrwl+L_4V}~bXCLarSsM^SKnn57VqIS(FybFEye)c}+?Du=m_a8rgz;(0MTyu>%=NRKc zce_1Ajkl?PU$_J5rBUgl{OZbar-`l5`S#7?v&P0-gbEs3nrBg9xE!*4dV>cBgZgGc z-E7JkE4YrtPWX*9BndfzW1?u?swSWJ@RsDm+gTp=AbzJb|Be`g&?>7X&L9Rev?o3! zjqTdG1_&mmJnJuGuc?)4e{dF`);z;Q5p)6N9IQvXjXRlR$y+$_=k+bZj7eEeBBzXI z8OqUETAO_>WUJNaYTAND{VJ24BAFgm2H4|}K)0jH7I*Bzsh2i5^HKqCc4ducw9e717a zlkPd|h)mCYdk^_o)@KpmbvB8p=HFM34!7ZBytEkwsEj=}d4^RfH z>1JwfkSY{upoLS=%=e}(yC?w+&fmzcE_@W%^hm&>_^@?`Q(b#swGi2UZ zVD3p7i}i#Q4Di1I(BuI!1rQ*Q{0n~aOb!{}!?0a&seQ_BVinRreF&Oihu_bW&a}H# zm11hZ1j$W|iHH2-UDHh(<38i0qa5>_x1d<$9|=)&sg+wxImBMjj}LGd-dkdkJe_qt)TQzULXFUEq6)Gc8Nl6!W)0nyvyF7Vy+ZDWUFcI6cG{ILZunqy?kefj#H#0_Z z287QN*ruJs=prImb7^QPbw`~cCI?g*_H7IUA53G_(IX5vM=d9@j&l3E7JhhO?C!x= zYea&-y?AvE#?Na-nv_u=L4!l^&eIKfd2}PXXa$L%%f_T9EFwigEv?dC1R9?qkE9fe{x2*x3xp3Js+M65n3vifyA*3bf7Hk~D)N1s z&PcCop=eCaOSBsSbks?|3Tti%yh~`r8Ke7-D37Km_PC%qFF(|bD-*nb>lG9{DrQc# zV{d+W-z(7GFGl{SK-^}Wc>u?NJOCj(V2d%)M(@8oNl$+H8+ceEt^_P&R;|E3;fuxk zC|OKJgJJZ9o~B`X8mCiisPYj1nqur5peowxb7U(-W0iBN6IUn@x#K(Z^1U>m4gJ6; znprIgH6i?x&Dn1{PnG?Ca7pdW_XiP2E)%Beo+nP7QT<=QDa;6d9hC)}yM3{5Cq_&f zF)0V7yWA0$d-{P=&t8PJ-;wEoKlHH^l(?e^|4L3@oMXsG5QON|8f9dZI0+j}yUPn9 z@ZZ;XBjcGRQ;C7nj{Ok_arS{yrcl3T_}uuKh@HVPmsCL+I*;!&G>dM;W-Z6A$sLBW zHR*dU1V4X3!O8QNMRpQ1ajnA;v2H{kR|C(tY$~A!-_km7LZ%Y~_d4zQF8Fk(_Ro8bpz~0+Opo{y#ZzL@9g;O*Gtsq$UuL-jH65k1m#JxC|cC70frl)2l>C;ajDZH zYg7KqvaUBYf5U<2WbD{@>C<`hil7G(LaAuqCdvY9qy66N3Y~UO9$Hp?0MH`j*&8`m zNDY|w;)Is^*0a|8Ay*Q@m`5Fm1XMDOiR{#3fL;w|0Xy$TLO#o{s@VETq@nJt5`XG4Gn@+f#4yN}=1AOwnD%%q8tbd2 zP2_70fu_w0L_GcNdW3-m9htoZIw5A7zDqbXr6l(vAqpD)M%c8^aoJ}^DeRHPt!pmI zd@(}TWjQ)}+|pj`I&(j2=&iDnqa8ud;&A0VB8#2n4}A8mouHfQxYI6y)sydHol_(w zNzTG~vRwL>d4H1GGc&kCcBwTk5$Grf>+qRP;LSnUMY{AGIZLr0#AXVvAuMBGtiaFy zqguq!)UUR0E3um3)WcXUne(rfzAA|xx7egCfxahHMq}8O?*T}iZZmDO zn!{T|VhHLvG18|nG9kH%45w!KN#fc-E@%SzY*Y4GG~{H$>VW$BSLOk4Bh4XWKK~6m`D6o ztnS+=mMI($KPANppH^NdJPQ-_1yhF#BG%k|+EdX^#D>@CIE2ZKt&(tj%lavOuW9s1xTKVI8_+$N_ z{Z9 zQ*&j47?FHk%+&O%0sIr4TykFY?Z98NYuL+RE;7y5pk(e=Y*;u-?7!CEb}ddhxO2rn zY!W_tTBM?O5_V7nvJ@J?y-n;q3!W|Jqy4dPIMgX!JYbVwh1VHn+CLWkBAu4ALai6% zbLm>soE@X zw3-=0jupGcTF6Ba9g`MAz1@yuu>`RIpIHiVI`_tOC0;~MU%l4-iWy82lJ5E}9 z{*HW9UAu`t{4T6o<((8oKe70Ft`m2Y9^cnX2{IYFfylkq&ra6+Gj8#}jG{yclil9Y zLxi~sf>~=K499(JZ92#+(?@nTcbw-nC6eiREZX#t4*-Tu;EjF63a<@k;A*fvKgZqP zpM+1_GVW#1PuAJHoc}hCYbjJtB*?`=3wc$jAeWx}^!(#;fP=|emJ&S-5-R_vb)#)@ zM$pm?J#MV%sy5R*cW0$e&C+|v9=6L9R4e}7(+FWwl$6_V|M5=w;Ak)sfAEARoZtoNRfD`5 zV;4q<#Z;bN`(ahhQHy7NZ160C8Xpc3mt|Mlnf&5(P=Oj(31%n(yDj}Vt_7}KD>b@& z16%!%DTdt%z9KS-GgJRo|6si%IuWnQMRb7$R@}b zxN4RTZW66#18elAGR?Rhdn$9i&q;}6BCC(0xn}3iB=!pz5+~|WFHbgo(^M_n_He7Y zsA1(B@N_Tozv4d=z+0jo5E;zDXyC0vNx~fv9b91zobtS{mm~p+6+YO`K9GUaYQ}2P zc>9cK+`26*nB3LTkrNkX>;0|L9KMhl?jz=)j)q3My|23gYty&*GZm#EyB^!&@L z*L?By2jQ0}h3_~|prAQF_0g2a_rer}4%vq|ti&jg?7Z~e@%Un9B){Yl-e2iPWcg;| z%7c%RGk7HhNJ$4JZrTkFIM!h*dSm@M-#FLt`qa3we(*Yrj_^EWi_f@Y9cMH7`J{4Z z2!5_=JKSABYshA)Hb1OXJpQ#6f{7QiXZ^D3h0R}S)lxuVASE$;5l+C~g)e9@f5*pQ zulsO#o+)Ct7dPG7TJqwD<|km{rXAU2<){x~NoyY!d<%2TuO~?R11s_Z89q}g9vcdA zm*mC|S0>vSRu%6B41%3q{o8y(+d(7Ih`{eY>3>mV$)9}z6WqPhubrbd6%>KzvKd3I zt*yG3UNLYhyaUh@k6Z|Pl!3aFna<|kf;Y`7+ihbtF3jP&F75QY%f?C=tgxg=y?E?7 zS?G&eFMLw#-S8fQ<3s0^EUe^?(aE6K`CQ+q=w+HqkWPMKcx?wa?UTHHT&sj1mGy}< z%}zsKU^v0@@mb4!J|x~^v$>fY&WYHud(N9X_w4}fj9f4_9Z8mTTFEZA)(e(LLdR{s z-c&KKSWdU*P0cO3H69!A#svtc!MF2g?Xd^~5Y7h*+kgRaO|}`^oCRpJ?2UFp9szFn z*N|c*ou4_ZC+Y<@>lLFdp)Xd&uUW5~N*D@_*drg72v>64%5fPo`m^oMx|03k?Dx{- zUF{1Uwz(u;=gojOHyoAeatf%(vP0RZ7=v!kC=V69|KzDDnTI_Xwq*KF!EOSs;rUG3 z%6Z6Hv+T-m;CauMy^o%I9S2Bc_mtRRdw?5T@hEAg5!>Gn!>u2kw$~-~aP-NK^HW;n z!C`MyBa)b}ao&u*im?Bx`nswYz2bCU;#!T|V|0(gKMSUE@$TfaLiXLm;t)$0Q@Bu` z&FH4@5^h8wS^=S#0z#6{vhZSwkK0@+W6Uquc#2Hl_-1r4hdGGiVwl2X#TviP5q)gC z5utg%n{)NUFG11Tv_bbtiEuR!LKzrA;&h5$)njpM8q`X&-pu~<#zYs5N$*pB>TvUr zaCcN<2DwhH5i8Zwt*~u(f%4l*)!U?hh2^h)pWAN`ZTF3?CdwvrR(#28`Gp@aKGQ>WCpQ?d6O2eiG2gxOLX;oEQP^wz$Aa=oHO!SoGdt$8DD>iBuc%R+b5Yv#*aao$|C z&)IV}^#jTAWHw^m(tY`#9>b-efmZJKN1)XCf-PtJU4ztMY`1OXN@=dn;lY&$?tPb^ zXOgAUVq|CTaky}!qsE4dJ0gl!6O4Sowf!7CIWl2Z;_=nHgc8itYl=x|o=ws+Y=I3? za(Vh~zEO>X=A`!->!}MJZJ0P*a4hY4yR9pHT4%i^xoblyUSVxW*MWfZG`k>VVmYAZ zW@?n|9ttCa;0dz#y8#7F9SCwA5L@1~GzxMUo^&xuhF5C(L*B9zl-vAwF#)a_Mlslf z*|b4W>v4l!K3Cdz$a&sg##O{km+)xn40--^ zf9S^;Kq+kBMl^s%n?2rAyQr&GmMC$1lk>|o=`Flf-Q)ygu6a~NvThxCS!6!d4V;}B z`;?{mg>Le09K5>~FT3V8j@RSbl{Xjr3(oY&nW?T6rSyk>Z+8GWL^8 zg8Pw>8Xbj`l$Vk!*dtQn(kRV)3$*%_Dqjx1=3oqbYXs*29XF*AP@tTHiv! zk6!7eHF6M@jF{@tj*yKn=g3i9RU>Z>2wDy`!`NkNT65zMXSMU*xt{KTUUV=qc<-uw z1!T`ZI3sAZ%js9r;`>)3tgF~#Kb-=g3WN@i!h2tJ(9?9R%ozv%-;+t+W~%gwG>87w zmo*D5$Gs2RyaI>E+>v2rqpU*eN1o@8G~KbIO&lHQr;)Rwgu>OkuAiQ<<*6d#vqyGM zxv_9|Oq@Lk3wrAxCWAJrfnh!AjMG?{{T`YD=PuPHH^w1P7mM{~QQ6a_cdR}0>6Qs(CY0|;=N?R_vdw-o-d6|ww;6dBpqTf3UsfsQ+IE z61|8v^2Z_sY*!*FA6yxY)xMeDMM*L&51GXj{CspFcrTnCgXL$dPw7q)u!_+!d3hkL zxzaqBx*kzJ?3Ht&N1du!d2<#c#Q!BX2vC^HTXLgXop}b3Mgde=lAJD0FLqKWx7^cHh%|gj2Db5)kBD4Id-D&%C4YNA zD{y#cz~L1Fppgqur;tHq1p32Za{pI5sRP0>P+bGLmIOL(qR=SbU_P*s0GzFE$i6Pz zd=Sy}ppH}_UtME$Q@FA_$Ddh-ua7_P+nxV)Bps@~TT`OGByA`5QIHDF$>%_jMqXXwS>i9|4!WKB^O?(M6=^ zY+>x*92^twKz;CcDgP1d2L}A4k)KKe7K}4I-8Q;tjRlVz{RQufhEbqg!^^NRW;(u@ zTpAS^;jC!jnIH%}8M1HN`XL)FTO@jGZLuFEt{HPi-ma)z7uDTnX$ho~(A5eqd;>a~ zhS^bA^(uGl^@SiVT&qr;gWH=!VR;^MbzK~uRw}d`K~KawqnJyVS?7J=$GG}IXD%`p zOKZ?P&ZiIJlAHDU)9BY#Zhp==QlKhEyW_M72Ll=iXwx}?d+`>uobDH$1uVKyEEX*d zLK6zl0;rXout?j^t1ax>lEqoj`bXIkLZ;$@NQ*WiUnQEg?0`;%S#-dX*YpJk^wt)v^!3s#QmqqMX`x-F zXyrW7ezaa}h}R|TED3xnW-MTV4v=BrN<$8SJ+=_|!2`gCtOqd` z3sxs6c^@|cZ(I^-!$q{f4iA%52K7j}cIgFNAS$ono;xg8PO}gQ+71(C7sEL3G#`OM zG8!O_NbE}d#ST+>d2L!1@q*@lcEYGX<$L*LYsw@Bu4R1Pi?OX{X9K3Ma0xcz2>tv* zGTo6-k8`NM`zZfw z#q_EMP{iw%dwXd}dZMNSUC6f2I#+}A3s1H1E4lFXNv>BoPzJLv(WaP?KaXn-J&Ri% z`AgTu8Hx3tl zpzNx;>Lk2UW!ht~uk!{A?02t5+3xJKSDM0J`Sz5X(O8olzu45bCr8>~Flgjpd+HSn zjhcZcSX!pP@FFNd?yr0=aOGeADLHbw5xF2P64VEphVy3Ap*cYp2Mw*Yk5A_F?Vv6) zWZ$PU&Ev{CTs(x9hqq#iwiM4bcAIT+U*xg0tUF~GFSfyi-}Ngt**1)qHG>|(!@pJ= zz5jK=BY#eTFV^mQ)2HhXCXGIMcO;}DlnPNyiPKzPx^45$N2JU)&-W;=oU*~qo)Aq` z-eAOA^F2Is7+AInUb}t&h7bnnn8z2(*h)T_RN^W#a*$WGD;>*6+)8MLgP;>tBipP& z{l9?4wv~V*S3Gtkr(}*nqdU?hzrL_gJ=eN=N@>0i98`Ssm7X{?AO+V*3Ed*=9_t-_ z#`Uwc^Hi^qn)5a2MqHE3g80J~cqf`s!v!p^#va;^t;aOaJmP<@NqKiK8H)DCGwWLp z@F64aSlpxo!Y1J&DtLp}GAW0W%Rj2O5$C4->-lx zdZXQIMPU&Ky?#6(D*4TEORre)AjwX%Tl$U5&`TXxd=(uwi>s}7j5~@?=qd{ooUh%I zL#Mq3u)TWl(`CnzMp&r#zT_d?JAK$M2UOQMS<)h{qtXh}zt*;(;eGdR8?9#)e3`APqd_)!C#DY1Qq=eT-e zjiCTEf+S>9CkHTY2*u$^yX-0qh;?QO&8p;Rxi-FBbAjj<=JEZlDx3HSwjuL90Y9Yp zSW_5AboKcdL;-tsZu5E8ZP}d;l|{)-;*8jy_%~hQOIwn;OBgW z2Ab5_4ZIUXbJ_8`3`W+Yd&5LWTJQ#1nj3Qo6sSG%Jve(DCqKY_;|#1Y$71=&Dh)HKK%VcE-)|^@KMPp@KnMd7UR-GF5ZDg7zy3m#k z+c}cr8l4E*Imri*F}yY{^*%vE`ekap`)cyLyc~j`d>8+)hnY6YEQaAUE-QS=yX>>` z@^;4mA)NN7mPXd!N?m&Cm_LfVIoSBJW;1LPyr&dGXzcCCgW)@O+Yk_AggR)n6^hzS zG$l4hA5j}sR+TgNnvkU2h2L=UlnB@H9LasdMIoO4hIM@{!Zzgv%+98F#2XE!WSIzP zVVH4OP0i!Eym1QgxU{k}RZuKE1QYKG@AOo$fk*D#V_E_8c{X-#Z?S4jQxqI-y{0*A zOr*1u(+5mqv0^7B(Us_crGy}v2GHs&HK3QiTpCticuFGU$r_T~t%4()=em&2j-M5! zrvXPHZghOa*`NE8r`7#0jxRannH^}m2Gd}&7SR}6QIQGRND5d6D)Do>Pb8n1@J$HOb)Fyq`?g3&`#Nti(ue&*FTjD{rb7E8{B9c z3WWzxT|G`Esr(;y!dri6l4lER&nR10A%Brv8FN(x$|F2Ijy0gT zuE9p6^nsx@=gO{I0igt1>Nl78krf|Fw=V{x;?xdcy{VX3d%y0fL8?SG7P4W92rSgf zejbJDJ+0|z6Y>wQwXPsu-H3mM1h2gat#w54(-T<}_OzAKZ#oej&Sbvaqg5!L2! zA@`)*tT$ay^B`ygf?B1db@v=8urHF4a$rMZWgg;>=RSHK>Yl$*9Rp>fk*-dPiW7#a zuTUCj!L?yhNv3@=Q3lBBcmEUgk2xYEn8UCOnz~TT(xKCN$w?j_ls2_=+y{Aw-bcdJ zYn$HeKa@!YBE?vK>h|eyFKz|LO*MYI?T~fV;hHn)`MG3z7x$2^sK4F%Q{}&J@Q}TT zA?fA9?e`zt6*iRpk5j|hel!?q>rDEm-&k{Rc_9p7kM%}5>|QJ=uCwvi8Zf3x{0l#S zN>cCqYML3=UHZ)9`a$Y!#dN9l?DR3T(LpdI`p)ZiVvZB55)J1YEBm%T#18R(gTMlH zUORMR=_&aLIGJ*hPT#6Ems!oJP@4aSl`BV|u(lZhdXE|qk^6S9ivAa5NzAjrw43}n zL|T$Lcxsy zA%~Lr$%AjGo?r9M@jklVWM!h&7W&WLRpZr{gO?n2as%4SB zubRU;+`fpkAKdk=Ob^Z!HIVgIt<9K5>do#R)Nk)7Lm{xA+naB->1-`G(~TNX$=siI zY$IKq-?#q1;qfd~gyQ_hhEIGzZBWhG*OEM&GF19Q9)H#YL4|oHVoo&Tiqf;%AaBd? zdU~A!RGs+~&J-&0pF>5jeuNg1(_mEEQW&4^hTfJ;=_-uhs+D(>xcF1-wa1Q-PV8)k zrc8NZ+A%+9uB+A8sAg8H^E{H=rGyAEa&w2I`iaxYT-1s5g%Ib&PAT<4@)l8_~t%zat~ z%o!iouh1{TGpy;dd?;IdaJ##Lu12%ey6F6Qb^24eOCbn?XwARc*&w1s$F~LW#{JXJ zPG~A1F^GqLbJL3Wa|KN9GB?(%>|9#tV3;q0)@l*NFde`Be1`METEZ(Yhp&#)gSr7W zpCfn#_|oOK(pcTsUnd_bX=qd7pE>t+Wi}0STaJ)sWjmWh!Y>WHlgT-Fj&9O`#d!p7 z=qo?aXcOfDlfrqKk67LwmqfVN(U_4qCM`7aVZ~X8$84wa7o{A)GQuWIj z&!F~ttpnM}!*%cZg3c!HnKBA18lrC_TJ<&QmZCmw- zztUfY#%JVD`SA7gtIV=IF-U*+E6CXiPIvReIu*%c@>T+Cu?Atyn;U{%;|hB47yd!A z3}Jeu4Wv#DJd;E#%I3P=N@zE^{k~_dO#5ev9|(RRQBFNiM` zvpv;Z;Z&<}$GMKO>^Q)I8T=|mZ|d}U$ov;(gKEQ0-u+by9wG6k1t!z59MJMk?erZp zw;B7ODXP@xVNCafyWy?0tV>$h$EoeTI|7QsmK#qDh!uQ>=`QdwyAt$FKU6}eHc)Jx zodblmo5cQ5)L$OoG)~Ttt)RliZkHil)y2Nls7=B0-kv+5Ji~lna`G;nEw*!8G_!^< zbe|2pE3EeebNf|>iiE#oGn%b)>@9G6hQRTt891h&&XyIeTf^GA?@eP=kU zlox;Cr_s5U+d|2)O#{2x*9xvquB4Z^uV7N9_p${ypG6jlZ$6S`%@p2Ev=?b1u5=!1 z$=GIHn#<5uQyY5luHQIuq?DpuYDsi$9+qZ=d&s0PveTmv_Hn^Ys@y*s=I&ZSujAM< zjb7m&yRo`@wHcd%{QWPzL`9JH`MgII`l*C?w=cn~q~3VXMQEc2>58ib9D38Sf!sxA zV0;E=T7S;(XCtW7N6!L%N$5GE$0;HgBsJx4>oUcD z%n&n4%h(|-H^|N1N~P=;o>-5C?w8XwzxUzYP8da;MDX^H`1#M*HVCah%i>nbk_)s{ zwzzZR<)8tj3feTpbdhb z^8-8hV*uEQg)oBdzDn& zwU9;mJCiT(b7~>+e&CX}TD<4?fbLmb*6=xayyhY^4J>7)4e+Kz%jX$Z^Ng52GatY{ zE_{oe(p8&e{1mKY49TV0{N8(uz~XMjVunaf#)bW>>`mL)$eWoOEXApC zdp|hBcM|Mhmd(npMHtU8i;~mDgEYR@7}OQc;0sya&OX);s$Dc?$fp*$FRq|4&R&pmhtk+40(yy>+}+w;7R6^9pHNJV&}|Nh=xpd_$L zn{Dcr?aCwjC!s>q9qGlMNs$J5lw4TlTj*k)cykyx@8gJJvnA?3)NR%9gNsQUPK`Ry zOr9lM6U}b+AJfw%6{y@0^wcW#@{y?T1FQr<5&cIyi=Rih@m^P4p=p>)dMH0ZuoEI2R%Gc@+JSgc2O>EbHuj=u&C#C5}+ zp<+ldp!=yM$V9`hGe-Q&X}HF`QbAAa@uDv1%Yxiw`On;xo*cdZe>)I>S@InabuHoV z8JZxuQ0)S$q_w4XxPS~~-gy)S5z0TuhlV<0kzw~W(8;8O)5vaqW%Vm9b!b&w*F0XB z0e`X@DxH4JJ!#XD67a#AQJ5I2IMO}uW?XTyO6;WXo)0g9?QHqJ~7rEEb+ z>iJCwwwja4ugmYG2A6gc`W9sU6U>#BTb+!%RP@W4t-*G+{Rk{}`tOLC@WGj1j%)74 z8_AvAo+R782jbbF*!W-kmminufJIcn9{u*rH>3m~MCFOQ3(ynak@{-ic4fh5=h^;< zSjY`&kLlDa5`@6-ys8`3uq-Ty;&>ER%i_uh#eV+s({l~E8Ug7L7`h=xiAoF*tyqRX z>AT_nolI$Vl?PP0S5nEzPiOMhSVwJ#U#3+c@<)bEB~1nqcTDnv005!3`%FU4mBbwZ znSn<%sT>w+BS>VR6L7{~IZ)_-pIu-(fw=O<8rD42JE4ReLi}ROMd{9~CX)$Crb|&q z4&y~)A2W+#aPm6)o$IJ8mVGD7SEuNAG`nH;ZEIk*{{I;B&%k5>-`nf3TNw{s2&4Gh z-4PhDPQ?II9IjLnMXA!RaWAgLh_>4S(2ND3BHroT!N~*r$qRTxTHnaVF8!Nnw+T+G zXA-0R=y)Y)>RUfVY8s0DN1dJDWe=Igl&N_g?@f~wRl`1A95fx3EAWNVk$pKOj*=A` z2%O4KHC1JV(Nf9_`XmumN*LajAz-p>;Hf32&=dt|7^MR#o!_;oq9uxa9d_y0e|pfa zqU=US8a<~J|3X2VBm+MA=TD$O?HQ5vQecb2t*aCt%{q1w9rk@?##{e4g=P zi`?zhALOvixr3Ux$o}0kop5})=VUCnaqo%3^b7~l6y4;gSk41g7ch5Y(~XED(s09{K7uugK45fCC7Y{ugA*bE{L?1E2&~Y z**k{yl;9MOvy8cugL#fq7-m*9W+06SgYvlU3(3+Fh-n|7p}_Fk{5qvj$^21!Q*~(R zsnKcZ`dS@~{Og@X(62>2Xy5|?ya?{)Bl-D5$O(6oH{jw?zs3lW*G}e%B$AX*ByARQ=IDnl$#}F%%BiDG zu?rUQ?-BugTGx9;6c|kGEc78F{oB753zzUkkxlE#igEO6em8vD?Wtm z06H~QZ|C~`&9uEfD?bTnI*pDtR)Q{dr|-ix+m!aeNrHV z7Ja!?^I^n}{u!$?&q5P=zYTs^0nB7pGkef#$%LQe*^Qmi6uYHXb?0_XZpVS$wSzS! zr1liry=%p4_|m2?r5Q!FG}vpfU*kTLK6yst33aQr;3iz{^_NajT^zu_60z&8;w(2s zr5!RT`gMBGC%ef=O^$wz)qw0^zAW0Rlw!AT+2@5R`EKZ$KW>nSfGwwaeBe1q*#9yE zDyYk=vDGNhE|B=AFzxRz%;U^g7I(U0u||Q4b}^H^W;Fb_9unqt$v@hq&j^!x5y!a} z*W7vLpLwnQG)1+I&igg?GOdgeTA$Y90XPXFYkDugvqLw^oj|navJ8B|2T6X8cbCWc zf`(M2lA-V}E(Rx2wV9+g(k7DYop%-}-Rdsv%+nG4srO@A=Vt)VY|ksO_wUf2y(?*b z;IDGbCh**61V%gI2huG%nt@AorWA{OX^CiPtem}OiOfSsk_n8*Ov?dhc(?{)7miTD zA{p{1@*$TBpDEPFUZV87LWAS7oQ@qk*oB^fxrL(cvZdH3a-eHKh+jV5@)1jzEz%4u z&Hc=c@k!<5r7+&_X>30=qgvwM64@CqNn7F%KIl|7%jng@+gh{}=K-CVrLY97{>QDQ z;7v?2X$0!MRnTdkB3t_-3F0qo53G=K!El`5M;@6(z@n89@E-Y__VrI7_GiCT3RHlV z1<+l0i2oxs#KI+l;Q6B+kz(hJ!y+YZhfsri(?ldC@n{c;>MDG5Q`a$Xql z&^uo{M6=w}#8r3XHN4Ox-0dxYP5V`jIf`9_X*gl;fy*M~?)~S5 zhyQjxH9ELCoUz*CRzp2tJ|AU!_RutO=IM!pED8^Kf7H^ZV>9@{*m(} z%-$CO&FF6fwE!^#;B@;VgPGQRklZEd;(R32^}^AMa(R?M!8nm`R@$d|_7|a>lSXz4 z2igLz)m&PVU+zYgF@T_QXFV{DoFreEPRpYo%jv83_OY!5cqFkFMn9}0jwo9z>|{S4 zU(Hf>CPATcU%0-(&>~gGhKm4BG+aGMF*~yKm(Jv05+(9_^7x9FO9D>Bl zyKC0<@{A+*IrPij)G_t+7uZm-J&bFjCL}M1jdWS%#|$EjmpknSYC1wBmB#jMNjE|n z(b{3*sJV+7E3OV#4|((fN5F^JM*dWs#~+TaF!K2;tX%$yu#^>N^N z5Cypqfzp2WrOEyh?#7=mxk3-U( z2e;8ymGn)HP+_bG*y^i9TwO@2{1~$E(cs`ByXz`MtU0Ovx#-0i{*POg+LDGQV199* zg4J{+<6^s{#pI5J#hiCu&H4Y8jlx(Ehm_EQu>8g`BOB3niwS9V!stL#5?BK-7M z@8GArt1LZ8OUFhfd95mw_s8<5FsW$`Jk|()o=Air8j@nA*plwc z>EzA6BS1XmvIkw-N8f9o1mVMHNe|}1F6E;xRRrb+or=9p>%Qq59QRDJst@q{QL5;% zv6q=gXUh9|ZDqIG@5l(Uy{%)hL~Mn_xgwA?in(uur+jP%#HT7JO?k*kqn2h;-r`5g z6WVCtVVBH3JW?%gjqCL#n+?9d=)to82f8U>h;8A3kQKjIG5>v8{MI#tEadm~PL4;I z3Dm#L_riHvRA!DW_LfmFF}SxU!4SzT1J-&WFgwr5-8FPq*6kr}2lc-DmPwczUd$-d z23inuV-Z_xQ*wNYm4JQm^)KqkUQMUmP6OpTpd0)*1t>;VTjq{($Vik7y)3+4Jy~ZI z*U4T|-mKV3Ix6!mk?!;&lk4?-`&9*S_?bXU>Ooux8)6Fk6}ulgL@vMf*0rgC4H$Nw z7hY*Bm80Uic{I>S6Hu(cnxRy^h|Cba!iEkuM&K*HBx9oKSh5Q-BGz`V*ME6JY#T`| zsUc~|{N~{foV@s;Q5kGl=Z!i!fc5{8Quz;z^Z$C*`FO|?XzGlEL-ht|4c1WdQ-oVr0NrazQG%Z_w;T`jG^9 z_U>1)G-B}V``~&vYS*^^+!2`{53M!vltYocUUB}_!CPck4dw*>z(1GdZzA)5Bc{CM zAzx4|R4ydASG_70zAXfakRIqm?C!zfAK?2i7uhMQ$DACkD!Z5rzM4y!LK#3sUckc+I3uUrw?aZiF}hls zgn!{2urM2L6RotlCc#6F;zFGe=+7QRPajTRR{4s@pBi`%cP+UR)jfviViEsV+dygW zLZ;%KMye6<;tLa~r3aAiVc}n_KadE=fVvIkMgqy#dq3H$Wx%pD-LS z^;kO)G(_3V9j16&JBv_f)x?4lHGUJK#VgRXTOpiUe*117VN|-nSI|V%HJPg!i4Vo6 zwptwX{)4(=hbAfaw5w-H-Z~@icUw~XyR<3~Q6$M-Mv^emTwm?8#{aYp8moOqO?D7v z&Tx2p$nS`)8jH7%(Cnd4hs4I=Er{~Vg;S#5XK$}1Kp)ODnU2QOdg6^721S3LwCzx? zSk8jxevju|JJ!SuJ%gi?5ZNl-$q~5MzZFK?moMb~)ONKQJH26HU`tO21~-5ly$6rd zRjrS*V&>PF&<7p#7o9j4HD1r2OHs5J;V10^48r9Vs?cNmc9|pRnBEdoB_+9YlnYQ` zDe&*BAx2#=V7fnlqu0CxuAkd;U|6G|6HuIt9u>!TzzV*N3ew|$uhGK)m^0D;1ep#8 z6sV>xXR><*&k_ZtnI^D6c>Buf;9yCu@ z_@-jD%yIZAH33uT$ z7I_8-Et9lgqzRSPH}!&RQS`toOTICsqd6!Wg$BuBBmki_?%<9WaomhDSkC`1^YV~Jd>GmRE_ zeE%!Kg8NCcus`i3PIW+kA%T6U)(&*xe0N){poL17T$(?>D4sX{4V@C43mDHF5hoW zXB|+7_p>C0RylBy^RweGXc`;H4WO-$T+)P-@?zyR_~BSPpvm*A8te?F;!i_3;2d|N zsF0oKt@ne0yO~?7%`R>D-Y9P4luK8inRN~6J|<7=0G6Z{yq@6|=lWa@3vRzl*9}I+ z4gGzeLGLXO`F|!d#6o@bg#`T>ir%qRm`EO5&5%-FMz!~o%(`9(^BHKnTD3wH8`;Fd zqjp8-SDg_CfMNVKc|;zEf1TdTBxn3F7nJE4ul*j&=6lfl|GQ>7as#x^gH=4Ydo1jVm^cD+tbmQqI_1U*g|mK zo(@U(FZyO>2@`7z49H8EApw)d>QC-rl2ut@GR9_LAPy>UGYR6kp{EEFhDq_qO3uCu zJ)*pP{jT^2AroVv~17)7s1!hXCuEwfZ)UmjRUWLi)q4}Q+ZtI%QBJ9ktQecu( zPzKF`zkfQS1bg?5Z}R>XG~=8+ILp(rgc~yEy4`2wKA-Ml8g5Vo_WCrKVS=f=WpGd#};p@U=9J=KjUT&!yrDORBLDZhgFA+s- zV*dYhxwE5!DoOj}5?aw?TMXk^IcWC$lU~p@*7lS{lYxv+pzu`9w?1RFONHfM)dm0C zU^nL?4M^ifS7p4VD*fBbttarA!Eu3ori!YnRzr?s(()%8b0>*CatGN`!ZPYB^sDbQ z>prN*N}fy~m1l#PJ1dQ;VtQW@aZs~^(z|pLj>xm0ewMn{@aRj`;nwl`fljV!q#Bmy zzj>(%@AHVAd+4iPWlH4?k$Y57xl&i7x3XYaP*fZX?!^}2=CVvh1q+j(AFWFiTPnid zE55>xJLr5y%0!(FRNU}PtIYOggMT6L5Kvg5x6N3M43>(SXMNM9rV%4f@s0oNG-y&7 z!gl1m=&eAT>S5CbjWmi8&l8R3RkrIz>ICD;)8La8SX3o$9!kxNuM?0FJeGmL#ey{v z6xuIO2;Uhysz1@SB(uk0FwYHESqizXf@^iSd)0uz6Q&7l6Fp)^hn;C+H);{(WxKqq zJZQ~pOfzJapZFVl%x5GXP<9hN9GT1z0Urpw@+e?&EPOBKU7jQ)RG`ct4_!L|ZgKbY zGrkrljAI>&Denm_#p^B=V3P5-OyqyNk(<@Tkt#dEw77R4O+*eh4~4Y^Aq3=2@CV1K zN=eHJ!*BU?2iL`ORU_0sQFg(sU>zq65G#WHGnkzb^qr!qCf}csn{L7_o}jKZ+9{6% z+qQinGr}NQ{6TgZ)=Goj`6N^w;Up;{o>>Dw^sLLRw!03^|8PCSj^2Nxo-y77K_K$S z$qQg1;ZHr%u(HIj`4h>_5%fKt0_F=nAWIFt5X60~u2aL+ z^C%^Q6}qdSZA`%ivWm}y$Uo-jCQIiAr)N?{_ro@U!#S2Qr0i8|h9P3HS@S?uf`^bX9 z3A+M{J?Xv3E4&7%U-#C_21=fd4M-lcQIw!W5UTMB7Acmi(Lzq&krigNi{wdH2L3BVrb zkLl$9wIf>Q!js{+G5mkp`|@xo+xPFM%@SoRYhmn>Ws;O_vM-HniI8QG>|2PzWM9Uv z46@5oME2|@Lb7Gwx0&o^-;Lje@AFi@=ldSV`~Uko{4tJu?&~_8_vb$M>pWKo*E8;5 z;JU|Vc>$v!jg?$!KEqTp5)ypoHplGhoe$7$-({Qj=p!_1QRL&MkrTtR!TE&7Oy+&| zv=;b|v9&i8mR%_9$>(nornexP53R!V!oa{-hq|H$6+4<<>9JCV1@gqS4f@; z2u`=>y#7}ugPHIBWRm?}T8wpIoZ;z>tEZ)3C6K?PO%#md^-80?>`9KGD$($cIft~Lb00uG198Lp1a3S9n$r82|1}xE zRAY$FjAO|ssOta$`m}9htdu|3t%-rj%x_PD!}$nAP1;{5Z6>}4b~3+B0POjg&;5n| zOl@s5%4C)ogT`CFrkQ4YEmI2vMJq_9^|>Bp-5U7+m5~=*-9mhk`(6Yr`%_)rFXl#S z^amu}Ows~nPFLZRk zsl)V5FqA6GaQ#BYd&KIF@tgL9cNQ`BaZ#`XHzT*tit5Io*twMxce*p6NWtSKNd>RQ z+$YHUoA202=?UG1N9E@}fyPxi-2p`<^&fAyl&X8*KqyFfJWp71ce2Dre%9!HBVHg? zPO*ZaoPEU2JyRW{;qIOzZg#)c0R6+97Y8$W4LJU8>UA9F z(8&OS`V?@+j*{jur*|9#HneCM#re9e&owi3H-=pnv_6_-2&s1y^IXBWUvt2vnI0ik z@X5>tEKkoD8x?N%ur<0HWDPte(U@*kH|5w@t%)5}E4#h^)&a`-v;{XvVvx*^ALX{D ztezoKp=v<^3FP)qPg~r4Y{)f8lL`T~6vN5)8D@*!+JM$?(psecDb$`m(!fEML=EnpV(|){W4{U>p;8Ji(P^N`g>JhC5G2h_x?0%#mKE#5Dy_-SQ z^Xu8f$CdiULPJ?n^!R?VbGo4d;zRX}fYp+7p9tmB?kk;p@?frR@tNxwNrO+%Wwmc= zb7QU|A9SgSr0kqz&fK(IYZ|dTZxbNLwCOaXc**J|8G&+f0<(#~9~juA{NUu2aB`>( z4|^SwanP)%+1lfIOrV0vVS1y4{#+f|VgNUFUtGQEs+ahRSvG3vVpq{VuHfl1qWnA5 zZv1vv3_sOjLf8$!km_!N`Yc%N-&TcJF z;m9mH+lh}bUvNfrbNdVD*SnX8-)a+?MDbZ-=(2Wdp5kI=dss?u?@3(Md@HC^%8Gup~ERVa4ulKgb`8XN7rD;mi zm|rrHqUw`=!*6_OeBjAjHUl#VJzd=u3qG{5iKRc{dv3 zNL)^SC(wW}_NG#~C0T=T@6vw4_k!-OPN@5N+4_i5X8KM2ka@L5cw`vr=?53CB)JMQ z4c$=4DemjZ-X)7Gw{7`26@r$WKYiVijd<|&4f(+66;X>4N7)WKv~f;ED@kv&#}531 zY?q+&!iinbSjl&1vb#Ss`}qLHV3WjaD6{lO9opJ0a{c%beYU@wK0Wg2|JCyeiaiZY z`cZLCcAjN+g51gKoI054=$R+FR6p5&nEKnd-!sYH(=YzwQHk(MRMgIcm4{4fil(?W z+Gm>Ww7fEI3|m3AWGcUe14aDb9ufQ24r}?*GYT?+aQX%WunGvM>c4z6TIH%;(EQLI zBolurmfwQtcM0^L??%q3q!|nhF^k(z`tC_Se5)hWMSYyB$qm;&r3rfM6`4jOBqItu z9;qaA)^bdmKnukA@nI1ei50Mo&y|Em!TA)%Lwlav7izdiC_{(BZ&WMQ$1L1>s^h65 zQq}gbp#Sl%?__gl;?3%|b2UTa%i|gv?Sx4maMtto?)%Pq>*wy<7{_3;&I1^WY$-I> z%pXkb>|11PY*wq5=reT{!57Xzk#`T0qAwT}vE7-W8y+kt<{ZAk+|eccxsXH2)e`|b z*<0D{vE`LfyebJg!b-X%=#=S;lnN)k5t9xyt#F}7+s>_BKa`YUB<(Kjc}~yev64&K z+9qa{F(sA5xb`;SuPC>#VH9V-W#HlivYXwTLElxyTab-oId;@}pSwqCDL)*CH*R-Q zVP61)Dt@94&@}w|zXaQVdPsAf6DV=uzy=S8%b+Ot zmDcz>+ZRkDf*V;SvJ(>%ag8!ZwrC(PIVT%K@ue^+Q{AcN9>HKg$9Mf9)z2GSy*ZbY`uKtD!Gc(S5L((I0e2jZ3md@aO8)8hEM%U zDB&J>2g5f3~LQ${)5K$#mZV%t%667d)>%K_o&mfFPY>f!%)l77)LI(3@UxusvFfJY_Dg8O!x|E zY`g^nHo0o_L#zu;Mk&h;3q5(z?-iT9(#>#vMjKmag3riv{BX5fAw7$+^bt`em_i*) zHC)+{!*~j+7Py%;^o1$9D|xO`?U24Sb1Mz1UivYDyl2}%Cd*USSiYp%;4MnE84`QL z0RMGOV%@aenuqqS?Qw0D zE?`ja9eJ5_lcH^4DP^mQ!$}nHI ztgYDZSeyf75@R+hc^-;um&D|$ovVZw?_U+UmfVIOH!QY-p1u`)y&+#1 zB`980pQK&vkhNE{+TU)lEM!8Ngp~Nqj<|2#ZwojjZa+I`?YfcoeMD{o(}=l%{_!=d z2sf_3E!>j=MZ?<{FfCE+cd$D0;f&%yp6UHJd4}6t%mQfItjJ&r__{$rK1(KO6Wj2|wBW493HSHtt zduP;P4}x0NQMc^IpADmAV3hlcYdoE`GO(g7$qEOY()OuJ)0UXOjlce|kt9xBv^nqK z`I!7kEpF;-bJNa&!f*N(InczCMt$&nf&xPQy~_}iv*l=^Pmply;C!ZZ&rulZrv$ps zIp=rJCvzUya4NQs@gMlmpWAH_TvSyJMES;cl(q_#8`zGx4zJ}>B0upjMeWfq>KDi^ zI{^+_qkCK0t!G@7VmH22@CIP}vNhYG7cNkeLpZQpgBMB|$fRWIVVP#@Eg7EsI}UG+ zTVi;DBf*!Pi)6g{p~8&vI2~^hOB|#=qS4(FEj`jjSqYjcOe*ezw&9%l=;gT{o`DS_ zGBQlEq0zIzE8O+bHI{vR3;hMB0E|BMzSX6Z3!5QJGYG^(c7kU3)5ZevtUE`_;RtZz ziCFTkyvHRv@87>~3J*Vqn{Ltchb!B`Z7dfD3u!HJdp|cFwTqtk>22|;mY!=kEh-O2))H`Ri06VzHn|&!q~W7jWY5x)5{RMj!^P_@{!S_!P{lSd;c)@f zavG@C(?}Y*RJpFo>9b0U71ZsKYiI=iEcvdLa*`lI`{Kj)D&^PvI@kU z8?Gt#B*v+49)@3ey27b4I3S34xK)0qlvQ_vgX=-LcyGS8pO%aF1p@;PBiAR}*B{;8 zeGubx_Ju9FI^gchBoBlwkz}fD*7C}(7Mj|}esRgZL5j`sJc=;DuysMRNjI%rOtdX; zpHsP=rNE3je~^9i=jv`4S4|##k80fhw&&eGD-lDtm>}^Y>ppudU+2kfrLubOb%vYh zI{C@Vb7(i~klL4nPhP^IAAKt;EA^(bHFILjNsLCK6|JHiN4hRm#!#{FvJ~9#kM) zKCgZlhz`djcOhPgHgTd9S0+|UMAeG8+%+rMKZhosI7mlc0kxK>y-C7+@gFmtmZXIq zBeZcOH>5Vbik2I@Xfn+b`gtbxzU;gWh|H29aTP1JxOJMa*Q$}p?9>I4XajXT6}3MP zVYPtASFsDcWA{BhQ_5Z=C@G7q&7$~>;QZHDyt{IeR)!B z@#9XZJo~HqRsF?&?|wA1aWq%|Fb^96(-T&`xZz~biw*OA`5F8I6on^NCL2h$6=}BT z^J{52Bc& znmET~WYkI@%uCLUxLhJ2i}Qs@a|oWL6MIVOwoZp8P~QEbfPvlbl8fi4vn_RIzm!*N*DUhDzg~b!e>GSEG`HeE9GeWay#df|aY{p+J7%RGU zy}&qo_q)kO^HlXRaf!R`xIX*HdG#aJ`jxTlwkIFiC2>nPXGI6e1lZ=vImvlGDqNo` z8!GPp;|?1l5z^Q|n@@VRu}>Jq=-#@8Ne<{dOiMyXX^I+ubRct;gHW~t)}-{!PmGT< z)2Jw3&qZDfQgp~e#vAuJcr!bs1YhA|AlRuO~oHi<^&%m^_c$r@-|r3<$IUCK?TKwvqDnsA06TC+=_XD zKi?sopHWa@T8b22M31i4Ba^qR4k*=?cHmTFt0J3;VSYSfc7q3LdS=W_*R1*l%1b-* z9*TBRBhJq!!4Xjj^GT8YhvO4eOt6NbAy1>|I3%DyVUZqRl9pFiqF^5c&DyTkb{^}7 zV;{mQK-)V-vSp$$iazEsb_Qq{|S3`m-y@E*g=17rev;96`bjE zUw2g;vXOF8SFxwrS&~isgDj&MCZRNZdfIq$7e_}yYxd(H&Cw?P!3vU+@CgaGjwfOy zr0Q4e1v_2mm;#OV^>EW6h?5WlYBV)(lqEgT??ToP`-Z~I27zQH2lfF_Tj@>=be}Yy znIBA=wUVSbCiln&w`zge;?Yoz`(GM1Y%QLA?3JBgX1S?dAaKAw?W!tpgV#>F{7qs+ zxgo2oA`jn?bQ8uWR$l^z-*Zr&8vp2%`PP@562w7#`=NF2+y*sOdqxsw=le`qe-Y01 zdWodU_lpyKD4^4Irk`AaeKOjuahd*{V?bDrZ^;n9Bey@BYUyE}=E zD%F`R$O#Sys6N|JUua~$3<;KhUZmx?+O${-vA>u$Ju_5fHFaNXnh%<*Q54Zx(eU;e@i{0!QzY3nEqG*-Z4 zk%nT~HW7*vys>yQ_>PwC)IKt)hz)xOw>VU`@~u0~{|I0ArHIu0L9(OySz|a8tH;jT zO2BThdDovCfd3{ctTSD{r34MEy~%+hFoevPJPjODl!_t-CQLMJ{PLwscFU8pU0}W6 zIN(b@L`6$Wlz5l3k55coTp5tf<_x$N74mqXFk=s5@aEo?D&%Hc@r&PqO9ZKR24(?) z0cyWQuV`ktyy%Dzpl7--8!HP7OIoiDCDX>iPcUG+i}8Q$H@}KH*hD=ZsVLFQ;nK=W zxXMi~r2vk)-d4Ca)s`sQ6an$sONBjMGW-{%^IhrkWaV;^yt+5U+_yRoNs<7C08&v9 zplW@6e0=sMVCLW0{(T?Reao)=#sn55qf(d;F9$Jzq9eu3g|S+gk}9pJbM7Cg0L%=m z2sF)Xuxd-=n3$Of1TluV4M62EC>d9{_E}GLrY}Ev5BKa(mdNnp_cskd{=^TwTx^}- z7yw%w>*+I4kEkv6j!b2SD!277#Q}fb!!`jduS?1=Z@jQ1f32S3jlp2d$soS zNqiLZw=!FUy=mA1bn$ob00`N>&Tmj7)0g)!eeRrQ&i#gD6yALB4iSVn0N^c|!gjRk z;|H%(sXZ1re$wsL>BiZvEM?^HxCWd8uNQD&Nn6uCr*zE4Rf;*+ONEq`xi z=mOjS&N9%duH{;SIPy?>O9h<293hSblO0bim*wQ}B#F5qDgbk|kMNVa2VyYMS=RJ7 z?U)^=+b&Pt2?~rgf`G@Gt)3AqHdU7*=_BB|_Y@v??F`t{l-A$qk6(@;LDvxy Date: Sat, 22 Oct 2022 13:02:05 +0800 Subject: [PATCH 1079/1892] add notes --- examples/alphafold2/README.md | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/examples/alphafold2/README.md b/examples/alphafold2/README.md index 693dd934..52425345 100644 --- a/examples/alphafold2/README.md +++ b/examples/alphafold2/README.md @@ -1,6 +1,6 @@ # Introduction -Benchmark different schedule plans of Alphafold2 based on MagicCube. +Benchmark and analysis different schedule plans of Alphafold2 based on MagicCube. # Model @@ -57,19 +57,33 @@ According to the estimation above, we find that the memory distribution of evofo | Evoformer (Alphafold2) | < 1 M | 5120 M | 66 M | | Transformer (GPT-3 6.7 B)| 192 M | 512 M | 8 M | -Assume the data type is float32 in the following analysis. +Assume the data type is float16 in the following analysis. -If recompute (checkpoint) is not used, the whole memory usage of $n$ evoformer blocks is around $10 \cdot n$ GB. +The memory usage of $n$ evoformer blocks is around $5 \cdot n$ GB without checkpoint (recompute). Since there are 48 evoformers in Alphafold2, checkpoint is inevitable during training. According to deepmind's nature paper, they store each evoformer's output tensors (*msa_repr* and *pair_repr*) and recompute all of the activations inside the evoformer when backward. The memory usage of this recompute policy is $2 * (48 * 66 + 5120) / 1024 \approx 16$ GB, which can be fit in accelerators like TPU, V100 and A100. -## Problem Formulation +However, this checkpoint policy cannot resolve all problems. -TODO: try out del tensor in functions that to be recomputed -> offload problems to jit tensor compilers +1. If the device's memory is less than 16 GB, can we execute the model successfully and efficiently? In other words, given a random device, can we find the optimal checkpoint plan to minimize the latency? +2. In the *Extra MSA Stack*, $s$ can be a large number (1024 and 5120). As a result, the attention matrix in *Row-wise gated self-attention with pair bias* is very large, $2 * 8 * 5120 * 384 * 384 / 1024 / 1024 = 11520$ MB, which means activations are the bottle neck now. +3. In inference, the setting is different from training. For example, the length of the protein (residue number) can be very large (around 2048). Activations in many sub-modules are extremely large and far beyond the device's memory capacity. For example, the attention matrix in the *Row-wise gated self-attention with pair bias* is about $4 * 4 * 2048^{3} / 1024^3 = 128$ GB (in inference float32 is used). -strategy: detect memory constrained parts then coshard them +## Possible Solution -large enough size of input shapes already utilize accelerators +To solve this problem, current dynamic programming formulation need to be updated. -should include coshard into the dp formulation +1. Instead of the activation memory size, we need to maintain the *peak memory*: the sum of preserved tensors and maximum intermediate variables. +2. Different from the previous binary choice (recompute or not), there is a much larger space indeed, a list of tuples $(inter\_mem, preserved\_mem, time)$. + - k pass recompute policy: reduce peak memory, increase execution time + - coshard / chunk: split computation with extremly large output size into acceptable ones + +$f(i, max(p, r), q + s) = min (f(i, max(p, r), q + s), f(i-1, p, q) + t(r, s))$ + +TODO + +- try out del tensor in functions that to be recomputed -> offload problems to jit tensor compilers +- strategy: detect memory constrained parts then coshard them +- large enough size of input shapes amay lready utilize accelerators +- should include coshard into the dp formulation # Experiment From dd0317cb9b2a46a5014e0358330a8da3c4dfd3b9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 22 Oct 2022 18:25:51 +0800 Subject: [PATCH 1080/1892] device mapping for cross-shard layout --- cube/graph/gener/layout.py | 327 +++++++++++++++++++++++++------------ 1 file changed, 226 insertions(+), 101 deletions(-) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 10b8ad17..6c557936 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -1,6 +1,7 @@ from typing import Callable, Dict, List, Tuple, Optional import copy import numpy as np +from regex import R from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -19,6 +20,10 @@ from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim +TShape = Tuple[int, ...] +TRVD = Tuple[int, ...] + + class GridLayout: """ This class assumes a full-tensor can only be @@ -189,7 +194,7 @@ def r2d(self, dim: int, chunks: int): # prims.append(ChunkPrim(itensor, otensor, dim, ranks)) return glayout, prims - def incr(self, chunks: int, devices: List[int]): + def incr(self, chunks: int, devices: Optional[np.ndarray] = None): """ RVD+ Prmitive: increase replica collective: broadcast @@ -198,6 +203,12 @@ def incr(self, chunks: int, devices: List[int]): layout[0] = layout[0] * chunks glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) # set device + if devices is not None: + assert devices.size == len(self.subtensors) * chunks + for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims imat = GridLayout.dims2last(self.mat, [0]).flatten() omat = GridLayout.dims2last(glayout.mat, [0]).reshape(-1, chunks) prims = [] @@ -205,8 +216,7 @@ def incr(self, chunks: int, devices: List[int]): prims.append(BroadcastPrim(src, [src] + list(dsts))) return glayout, prims - - def decr(self, chunks: int, devices: List[int]): + def decr(self, chunks: int, devices: Optional[np.ndarray] = None): """ RVD+ Prmitive: decrease replica collective: move @@ -216,6 +226,12 @@ def decr(self, chunks: int, devices: List[int]): layout[0] = layout[0] // chunks glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) # set device + if devices is not None: + assert devices.size == len(self.subtensors) // chunks + for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims imat = GridLayout.dims2last(self.mat, [0]).reshape(-1, chunks) omat = GridLayout.dims2last(glayout.mat, [0]).flatten() prims = [] @@ -223,8 +239,7 @@ def decr(self, chunks: int, devices: List[int]): prims.append(MovePrim(srcs[0], dst)) return glayout, prims - - def incd(self, chunks: int, dim: int, devices: List[int]): + def incd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None): """ RVD+ Prmitive: increase dimension collective: rdscatter @@ -232,16 +247,21 @@ def incd(self, chunks: int, dim: int, devices: List[int]): layout = list(self.vec) layout[2+dim] = layout[2+dim] * chunks glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) - # TODO: set device - imat = GridLayout.dims2last(glayout.mat, [2+dim]).flatten() + # set device + if devices is not None: + assert devices.size == len(self.subtensors) * chunks + for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims + imat = GridLayout.dims2last(self.mat, [2+dim]).flatten() omat = GridLayout.dims2last(glayout.mat, [2+dim]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): prims.append(RDScatterPrim(src, dsts, dim=dim)) return glayout, prims - - def decd(self, chunks: int, dim: int, devices: List[int]): + def decd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None): """ RVD+ Prmitive: increase dimension collective: rdgather @@ -251,6 +271,12 @@ def decd(self, chunks: int, dim: int, devices: List[int]): layout[2+dim] = layout[2+dim] // chunks glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) # set device + if devices is not None: + assert devices.size == len(self.subtensors) // chunks + for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims imat = GridLayout.dims2last(self.mat, [2+dim]).reshape(-1, chunks) omat = GridLayout.dims2last(glayout.mat, [2+dim]).flatten() prims = [] @@ -258,8 +284,7 @@ def decd(self, chunks: int, dim: int, devices: List[int]): prims.append(RDGatherPrim(srcs, dst, dim=dim)) return glayout, prims - - def incv(self, chunks: int, devices: List[int]): + def incv(self, chunks: int, devices: Optional[np.ndarray] = None): """ RVD+ Primitive: increase value partition collective: rvscatter @@ -267,15 +292,21 @@ def incv(self, chunks: int, devices: List[int]): layout = list(self.vec) layout[1] = layout[1] * chunks glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) - # TODO: set device - imat = GridLayout.dims2last(glayout.mat, [1]).flatten() + # set device + if devices is not None: + assert devices.size == len(self.subtensors) * chunks + for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims + imat = GridLayout.dims2last(self.mat, [1]).flatten() omat = GridLayout.dims2last(glayout.mat, [1]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): prims.append(RVScatterPrim(src, dsts)) return glayout, prims - def decv(self, chunks: int, devices: List[int]): + def decv(self, chunks: int, devices: Optional[np.ndarray] = None): """ RVD+ Primitive: decrease value partition collective: rvgather @@ -284,7 +315,13 @@ def decv(self, chunks: int, devices: List[int]): assert layout[1] % chunks == 0, f"not divisable value split: {self.V} % {chunks} != 0" layout[1] = layout[1] * chunks glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) - # TODO: set device + # set device + if devices is not None: + assert devices.size == len(self.subtensors) // chunks + for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims imat = GridLayout.dims2last(self.mat, [1]).reshape(-1, chunks) omat = GridLayout.dims2last(glayout.mat, [1]).flatten() prims = [] @@ -411,6 +448,35 @@ def dims2last(mat: np.ndarray, dims: List[int]) -> np.ndarray: axes += list(dims) return np.transpose(mat, axes) + @staticmethod + def dims2orig(mat: np.ndarray, last_dims: List[int]) -> np.ndarray: + axes = list(range(len(mat.shape))) + for dim in last_dims: + axes.remove(dim) + axes += list(last_dims) + axes = np.argsort(np.array(axes)) + return np.transpose(mat, axes) + + @staticmethod + def changed_dims(src: TRVD, dst: TRVD) -> Tuple[List[int], List[int]]: + """ + Get changed dimensions + + @param src Tuple[int]: the source RVD layout + @param dst Tuple[int]: the destination RVD layout + + @return inc_dims Tuple[int]: the dimensions that need to increase + @return dec_dims Tuple[int]: the dimensions that need to decrease + """ + assert len(src) == len(dst) + inc_dims, dec_dims = [], [] + for dim, (slen, dlen) in enumerate(zip(src, dst)): + if slen < dlen: + inc_dims.append(dim) + elif slen > dlen: + dec_dims.append(dim) + return inc_dims, dec_dims + @staticmethod def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int], devices: Optional[Tuple[int]] = None): """ @@ -450,8 +516,8 @@ def iter_idx(dims: List[int]) -> Tuple[int]: # devices if devices is not None: assert len(devices) == len(all_subtensors), f"devices number {len(devices)} not match with RVD number {len(all_subtensors)}" - for tensor, devid in zip(all_subtensors, devices): - dummy_assign(tensor, devid) + for tensor, devid in zip(mats.flatten(), devices): + dummy_assign(tensor, int(devid)) return GridLayout(ftensor, all_subtensors, mats) @@ -522,10 +588,6 @@ def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): return GridLayout(ftensor, subtensors, mats) -TShape = Tuple[int, ...] -TRVD = Tuple[int, ...] - - class PathFinder: """ Pathfinder for generating communication plans for GridLayout @@ -542,7 +604,7 @@ class PathFinder: _cached_inter_paths: Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]] = {} @staticmethod - def intra_shard(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: + def intra_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: """ Get primitive path of transforming ilayout into olayout. ilayout has the same device set with olayout @@ -616,18 +678,8 @@ def intra_shard(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, curr_rvd = src for hop in path[1:]: hop_rvd = nodes[hop] - inc_dim, dec_dim = None, None - for dim, (ipnum, opnum) in enumerate(zip(curr_rvd, hop_rvd)): - if ipnum > opnum: - assert dec_dim is None - dec_dim = dim - continue - if opnum > ipnum: - assert inc_dim is None - inc_dim = dim - continue - nchunks = curr_rvd[dec_dim] // hop_rvd[dec_dim] - layout, prims = PathFinder.intra_step(layouts[-1], dec_dim, inc_dim, nchunks) + ret, layout, prims = PathFinder.intra_transform(ftensor, curr_rvd, hop_rvd, layouts[-1]) + assert ret, "Internal Error." layouts.append(layout) all_prims += prims curr_rvd = hop_rvd @@ -635,7 +687,7 @@ def intra_shard(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, @staticmethod - def inter_shard(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: + def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: """ Get primitives for transforming ilayout into olayout. ilayout has the different device set to olayout. And number of device of ilayout and olayout must be divisable by each other. @@ -695,44 +747,148 @@ def inter_shard(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, PathFinder._cached_inter_paths[key][src] = paths # print for debug - for idx, path in enumerate(paths): - print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") + # for idx, path in enumerate(paths): + # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") + path = paths[nodes.index(dst)] + assert len(path) > 0, f"Un-reachable src RVD ({src}) -> dst RVD ({dst})" + + # setup consumer begining devices + cpaths = tuple(idx for idx in path if nodes[idx][0] == 'c') + cdevs = np.array([t.device[0] for t in olayout.mat.flatten()]).reshape(dst[1:]) + # print('result device map:', list(cdevs.flatten())) + for hop in cpaths[:-1][::-1]: + hop_rvd = nodes[hop][1:] + cdevs = PathFinder.intra_devmap(dst[1:], hop_rvd, cdevs) + # print('calculated consumer device map: ', list(cdevs.flatten())) + # setup primitives for communication + side, layouts, all_prims = 'p', [ilayout], [] + curr_rvd = src[1:] + for hop in path[1:]: + use_inter_step = side != nodes[hop][0] + hop_rvd = nodes[hop][1:] + if not use_inter_step: + ret, layout, prims = PathFinder.intra_transform(ftensor, curr_rvd, hop_rvd, layouts[-1]) + assert ret, "Internal Error" + else: + ret, layout, prims = PathFinder.inter_transform(ftensor, curr_rvd, hop_rvd, layouts[-1], cdevs) + layouts.append(layout) + all_prims += prims + curr_rvd = hop_rvd + side = nodes[hop][0] + return layouts, all_prims @staticmethod - def intra_step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> Tuple[GridLayout, List[IRAdapterPrim]]: + def intra_transform(ftensor: IRFullTensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout: Optional[GridLayout] = None) -> Tuple[GridLayout, List[IRAdapterPrim]]: + """ + Get output layout and transform primitives from a source rvd layout to dst_rvd layout, + + @param ftensor IRFullTensor + @param src_rvd Tuple[int] + @param dst_rvd Tuple[int] + @param ilayout Optional[GridLayout] + + @return ret bool: True if there is a primitive performed + @return layout Optonal[GridLayout]: the RVD layout if ilayout is not None + @return prims Optional[List[IRAdapterPrim]]: the prmitives in transformation + """ + if ilayout is not None: + assert src_rvd == tuple(ilayout.vec) + inc_dims, dec_dims = GridLayout.changed_dims(src_rvd, dst_rvd) + if len(inc_dims) != 1 or len(dec_dims) != 1: + return False, None, None + inc_idx, dec_idx = inc_dims[0], dec_dims[0] + if src_rvd[dec_idx] % dst_rvd[dec_idx] != 0: + return False, None, None + if inc_idx == 1: + return False, None, None + src = ilayout if ilayout is not None else GridLayout.grid(ftensor, src_rvd[0], src_rvd[1], list(src_rvd[2:])) + chunks = src_rvd[dec_idx] // dst_rvd[dec_idx] if dec_idx >= 2 and inc_idx == 0: # d2r - return ilayout.d2r(dec_idx-2, chunks) - if dec_idx >= 2 and inc_idx >= 2: # d2d - return ilayout.d2d(dec_idx-2, inc_idx-2, chunks) - if dec_idx == 1 and inc_idx == 0: # v2r - return ilayout.v2r(chunks) - if dec_idx == 1 and inc_idx >= 2: # v2d - return ilayout.v2d(inc_idx-2, chunks) - if dec_idx == 0 and inc_idx >= 2: # r2d - return ilayout.r2d(inc_idx-2, chunks) - raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") + olayout, prims = src.d2r(dec_idx-2, chunks) + elif dec_idx >= 2 and inc_idx >= 2: # d2d + olayout, prims = src.d2d(dec_idx-2, inc_idx-2, chunks) + elif dec_idx == 1 and inc_idx == 0: # v2r + olayout, prims = src.v2r(chunks) + elif dec_idx == 1 and inc_idx >= 2: # v2d + olayout, prims = src.v2d(inc_idx-2, chunks) + elif dec_idx == 0 and inc_idx >= 2: # r2d + olayout, prims = src.r2d(inc_idx-2, chunks) + else: + raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") + return True, (olayout if ilayout is not None else None), prims @staticmethod - def inter_step(ilayout: GridLayout, dec_idx: Optional[int], inc_idx: Optional[int], chunks: int): - assert dec_idx is None or inc_idx is None + def intra_devmap(src_rvd: TRVD, dst_rvd: TRVD, src_devs: np.ndarray): + """ + Infer device from source rvd to destination rvd + """ + assert tuple(src_rvd) == tuple(src_devs.shape) + # get changed dimensions + inc_idx, dec_idx = GridLayout.changed_dims(src_rvd, dst_rvd) + assert len(inc_idx) == 1 and len(dec_idx) == 1 + inc_idx, dec_idx = inc_idx[0], dec_idx[0] + assert src_rvd[dec_idx] % dst_rvd[dec_idx] == 0 + chunks = src_rvd[dec_idx] // dst_rvd[dec_idx] + # reshape array to match devices + dst_devs = np.full(dst_rvd, -1, dtype=int) + src_devs = GridLayout.dims2last(src_devs, [inc_idx, dec_idx]).reshape(-1, chunks) + dst_devs = GridLayout.dims2last(dst_devs, [dec_idx, inc_idx]) + dshape = dst_devs.shape + # set up device + dst_devs = dst_devs.reshape(-1, chunks) + for rid, devs in enumerate(src_devs): + dst_devs[rid] = devs + dst_devs = dst_devs.reshape(dshape) + # permute to original shape + dst_devs = GridLayout.dims2orig(dst_devs, [dec_idx, inc_idx]) + return dst_devs + + @staticmethod + def inter_transform(ftensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout: Optional[GridLayout] = None, dst_devs: Optional[np.array] = None): + """ + Get output layout and transform primitives from a source rvd layout to dst_rvd layout, + + @param ftensor IRFullTensor + @param src_rvd Tuple[int] + @param dst_rvd Tuple[int] + @param ilayout Optional[GridLayout] + + @return ret bool: True if there is a primitive performed + @return layout Optonal[GridLayout]: the RVD layout if ilayout is not None + @return prims Optional[List[IRAdapterPrim]]: the prmitives in transformation + """ + inc_dims, dec_dims = GridLayout.changed_dims(src_rvd, dst_rvd) + if not ((len(inc_dims) == 1 and len(dec_dims) == 0) or (len(inc_dims) == 0 and len(dec_dims) == 1)): + return False, None, None + inc_idx = inc_dims[0] if len(inc_dims) == 1 else None + dec_idx = dec_dims[0] if len(dec_dims) == 1 else None + src = ilayout if ilayout is not None else GridLayout.grid(ftensor, src_rvd[0], src_rvd[1], list(src_rvd[2:])) if isinstance(inc_idx, int): + if not (dst_rvd[inc_idx] % src_rvd[inc_idx] == 0): + return False, None, None + chunks = dst_rvd[inc_idx] // src_rvd[inc_idx] if inc_idx == 0: - return ilayout.incr(chunks, []) - if inc_idx == 1: - return ilayout.incv(chunks, []) - if inc_idx > 1: - return ilayout.incd(chunks, inc_idx-2, []) - raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") + olayout, prims = src.incr(chunks, dst_devs) + elif inc_idx == 1: + olayout, prims = src.incv(chunks, dst_devs) + elif inc_idx > 1: + olayout, prims = src.incd(chunks, inc_idx-2, dst_devs) + else: + raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") else: + if not (src_rvd[dec_idx] % dst_rvd[dec_idx] == 0): + return False, None, None + chunks = src_rvd[dec_idx] // dst_rvd[dec_idx] if dec_idx == 0: - return ilayout.decr(chunks, []) - if dec_idx == 1: - return ilayout.decv(chunks, []) - if dec_idx > 1: - return ilayout.decd(chunks, dec_idx-2, []) - raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") - + olayout, prims = src.decr(chunks, dst_devs) + elif dec_idx == 1: + olayout, prims = src.decv(chunks, dst_devs) + elif dec_idx > 1: + olayout, prims = src.decd(chunks, dec_idx-2, dst_devs) + else: + raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") + return True, (olayout if ilayout is not None else None), prims @staticmethod def init_intra_graph(ftensor: IRFullTensor, ndevs: int, cost_fn: Optional[Callable]) -> Tuple[List[TRVD], np.ndarray]: @@ -751,23 +907,9 @@ def init_intra_graph(ftensor: IRFullTensor, ndevs: int, cost_fn: Optional[Callab for i in range(len(nodes)): for j in range(len(nodes)): if i == j: continue - isrc, idst = nodes[i], nodes[j] - inc_dim, dec_dim = [], [] - for dim, (pnum_src, pnum_dst) in enumerate(zip(isrc, idst)): - if pnum_src > pnum_dst: - dec_dim.append(dim) - elif pnum_src < pnum_dst: - inc_dim.append(dim) - if len(inc_dim) != 1 or len(dec_dim) != 1: - continue # not direct - inc_dim, dec_dim = inc_dim[0], dec_dim[0] - if idst[inc_dim] % isrc[inc_dim] != 0 or isrc[dec_dim] % idst[dec_dim] != 0: - continue # not direct - if inc_dim == 1: - continue # not consider increasing value partition - nchunks = isrc[dec_dim] // idst[dec_dim] - isrc_layout = GridLayout.grid(ftensor, isrc[0], isrc[1], list(isrc[2:])) - _, prims = PathFinder.intra_step(isrc_layout, dec_dim, inc_dim, nchunks) + src, dst = nodes[i], nodes[j] + ret, _, prims = PathFinder.intra_transform(ftensor, src, dst) + if not ret: continue edges[i, j] = cost_fn(prims[0]) return nodes, edges @@ -804,8 +946,6 @@ def init_inter_graph(ftensor: IRFullTensor, idevs: int, odevs: int, cost_fn: Cal PathFinder._cached_intra_edges[(shape, odevs)] = dst_edges PathFinder._cached_intra_paths[(shape, odevs)] = {} nodes = tuple(('p',) + n for n in src_nodes ) + tuple(('c',) + n for n in dst_nodes) - for node in nodes: - print(node) edges = np.full((len(nodes), len(nodes)), np.inf) edges[:len(src_nodes), :len(src_nodes)] = src_edges edges[len(src_nodes):,len(src_nodes):] = dst_edges @@ -814,28 +954,13 @@ def init_inter_graph(ftensor: IRFullTensor, idevs: int, odevs: int, cost_fn: Cal for i in range(len(src_nodes)): for j in range(len(dst_nodes)): src, dst = src_nodes[i], dst_nodes[j] - diff_dim = [] - for dim, (pnum_src, pnum_dst) in enumerate(zip(src, dst)): - if pnum_src != pnum_dst: - diff_dim.append(dim) - diff_dim = [0] if len(diff_dim) == 0 else diff_dim - if len(diff_dim) != 1: - continue # not direct - diff_dim = diff_dim[0] - if (src[diff_dim] % dst[diff_dim] != 0) and (dst[diff_dim] % src[diff_dim] != 0): - continue # not divisible -> not direct - nchunks = src[diff_dim] // dst[diff_dim] if src[diff_dim] > dst[diff_dim] else dst[diff_dim] // src[diff_dim] # set for [i, len(src_nodes) + j] - src_layout = GridLayout.grid(ftensor, src[0], src[1], list(src[2:])) - dec_dim = diff_dim if src[diff_dim] > dst[diff_dim] else None - inc_dim = diff_dim if dec_dim is None else None - _, prims = PathFinder.inter_step(src_layout, dec_dim, inc_dim, nchunks) + ret, _, prims = PathFinder.inter_transform(ftensor, src, dst) + if not ret: continue edges[i, len(src_nodes) + j] = cost_fn(prims[0]) * comm_factor # set for [len(src_nodes) + j, i] - dst_layout = GridLayout.grid(ftensor, dst[0], dst[1], list(dst[2:])) - dec_dim, inc_dim = inc_dim, dec_dim - _, prims = PathFinder.inter_step(dst_layout, dec_dim, inc_dim, nchunks) - # NVLink: 300GBps Inter-node: 100Gbps + ret, _, prims = PathFinder.inter_transform(ftensor, dst, src) + assert ret edges[len(src_nodes) + j, i] = cost_fn(prims[0]) * comm_factor return nodes, edges From 33dc2eb0bedbe2467b93193d8e668dcd7e12aeae Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Oct 2022 14:23:51 +0800 Subject: [PATCH 1081/1892] add inter-mesh path solution --- cube/graph/gener/concurrent.py | 50 ++++++- cube/graph/gener/layout.py | 12 +- cube/ir/adapter/prim.py | 204 +++++++++++++--------------- cube/runtime/adapter/collectives.py | 152 ++++++++++++++++++--- cube/runtime/adapter/nn.py | 9 +- 5 files changed, 291 insertions(+), 136 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 73152ea2..a2e54186 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -9,7 +9,7 @@ from cube.ir.adapter import IRAdapter from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim -from cube.graph.gener.layout import GridLayout +from cube.graph.gener.layout import GridLayout, PathFinder class ConcurrentGener: @@ -42,14 +42,23 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], except Exception as e: fadapter = None print( - f"full tensor: {fptensors[0].parent} cannot use grid generation.\n" + f"full tensor: {fptensors[0].parent} cannot use intra-transformation generation.\n" f"Reason: {str(e)}\n" f"Switch to general P2P communication." ) # Case 2: sperating device (cross-shard) if len(set(pdevs).intersection(cdevs)) == 0: - pass + # fadapter = ConcurrentGener.gen_cross_shard(fptensors, fctensors, bptensors, bctensors) + try: + fadapter = ConcurrentGener.gen_cross_shard(fptensors, fctensors, bptensors, bctensors) + except Exception as e: + fadapter = None + print( + f"full tensor: {fptensors[0].parent} cannot use inter-transformation generation.\n" + f"Reason: {str(e)}\n" + f"Switch to general P2P communication." + ) # Case 3: General cases # warnings.warn('The adapter is generated using P2P communication') @@ -126,8 +135,37 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], return fadapter @staticmethod - def gen_cross_shard(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> IRAdapter: - pass + def gen_cross_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor],) -> IRAdapter: + """ + This assumes ptensors and ctensors can be represented by RVD layout. + + pdevices: devices of ptensors + cdevices: devices of ctensors + + @param fptensors List[IRSubTensor]: produced tensors + @param fctensors List[IRSubTensor]: consumed tensors + @param bptensors List[IRSubTensor]: produced tensors + @param bctensors List[IRSubTensor]: consumed tensors + + @return fadapter IRAdapter + """ + ftensor = fptensors[0].parent + ilayout = GridLayout.togrid(ftensor, fptensors) + olayout = GridLayout.togrid(ftensor, fctensors) + fpaths, fprims = PathFinder.inter_path(ftensor, ilayout, olayout) + fadapter = IRAdapter(fptensors, fctensors) + fadapter.prims = fprims + + grad: IRFullTensor = ftensor.grad + if grad is not None and (len(bptensors) != 0 or len(bctensors) != 0): + ilayout = GridLayout.togrid(grad, bptensors) + olayout = GridLayout.togrid(grad, bctensors) + bpaths, bprims = PathFinder.inter_path(grad, ilayout, olayout) + badapter = IRAdapter(bptensors, bctensors) + badapter.prims = bprims + IRAdapter.make_pair(fadapter, badapter) + return fadapter @staticmethod def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], @@ -207,7 +245,7 @@ def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRA if tensor.device != ctensor.device: mtensor = copy.copy(tensor) mtensor.cell = ctensor.cell - prims.append(MovePrim(tensor, mtensor)) + prims.append(MovePrim([tensor], [mtensor])) tmoved.append(mtensor) # ===== merge ===== # diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 6c557936..f87a1c8a 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -213,7 +213,7 @@ def incr(self, chunks: int, devices: Optional[np.ndarray] = None): omat = GridLayout.dims2last(glayout.mat, [0]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): - prims.append(BroadcastPrim(src, [src] + list(dsts))) + prims.append(BroadcastPrim([src], [src] + list(dsts))) return glayout, prims def decr(self, chunks: int, devices: Optional[np.ndarray] = None): @@ -236,7 +236,7 @@ def decr(self, chunks: int, devices: Optional[np.ndarray] = None): omat = GridLayout.dims2last(glayout.mat, [0]).flatten() prims = [] for srcs, dst in zip(imat, omat): - prims.append(MovePrim(srcs[0], dst)) + prims.append(MovePrim([srcs[0]], [dst])) return glayout, prims def incd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None): @@ -258,7 +258,7 @@ def incd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None): omat = GridLayout.dims2last(glayout.mat, [2+dim]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): - prims.append(RDScatterPrim(src, dsts, dim=dim)) + prims.append(RDScatterPrim([src], dsts, dim=dim)) return glayout, prims def decd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None): @@ -281,7 +281,7 @@ def decd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None): omat = GridLayout.dims2last(glayout.mat, [2+dim]).flatten() prims = [] for srcs, dst in zip(imat, omat): - prims.append(RDGatherPrim(srcs, dst, dim=dim)) + prims.append(RDGatherPrim(srcs, [dst], dim=dim)) return glayout, prims def incv(self, chunks: int, devices: Optional[np.ndarray] = None): @@ -303,7 +303,7 @@ def incv(self, chunks: int, devices: Optional[np.ndarray] = None): omat = GridLayout.dims2last(glayout.mat, [1]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): - prims.append(RVScatterPrim(src, dsts)) + prims.append(RVScatterPrim([src], dsts)) return glayout, prims def decv(self, chunks: int, devices: Optional[np.ndarray] = None): @@ -326,7 +326,7 @@ def decv(self, chunks: int, devices: Optional[np.ndarray] = None): omat = GridLayout.dims2last(glayout.mat, [1]).flatten() prims = [] for srcs, dst in zip(imat, omat): - prims.append(RVGatherPrim(srcs, dst)) + prims.append(RVGatherPrim(srcs, [dst])) return glayout, prims diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index 8ab6cc8a..d70b6f3a 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -2,7 +2,7 @@ The primitive used for IRAdapter """ -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple import copy from cube.ir.tensor import IRSubTensor, IndexMap, ValueMap @@ -88,11 +88,18 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k devices += t.device self.device = list(set(devices)) - def dispatch(self, devid: int): + def dispatch(self, devid: int) -> Optional[IRAdapterPrim]: """ dispatch to a given device """ - raise NotImplementedError + if devid not in self.device: + return None + assert devid in self.device, f"device {devid} not applied for this comm primitive" + itensors = [itensor for itensor in self.inputs() if devid in itensor.device] + otensors = [otensor for otensor in self.outputs() if devid in otensor.device] + prim = type(self)(itensors, otensors, **self.kwargs) + prim.signature = self.signature + return prim def __repr__(self) -> str: dscp = f'{self.outputs()} = {self.signature}({self.inputs()})' @@ -167,107 +174,99 @@ def __repr__(self) -> str: # communication primitive -class SendPrim(CommPrim): - """ - P2P send prim - """ - def __init__(self, tensor, dst: int): - super().__init__([tensor], [tensor], dst=dst) - self.signature = 'cube.runtime.adapter.send' - - def volume(self) -> int: - return self.input(0).nelement() - - def __repr__(self) -> str: - return f"{self.input(0)} = send[{self.device}]({self.input(0)}, dst={self.kwargs['dst']}" - - -class RecvPrim(CommPrim): - """ - P2P recv prim - """ - def __init__(self, tensor: IRSubTensor, src: int): - super().__init__([], [tensor], - shape=tensor.shape, dtype='torch.'+tensor.dtype.value, src=src) - self.signature = 'cube.runtime.adapter.recv' - - def volume(self) -> int: - return self.input(0).nelement() - - def __repr__(self) -> str: - return f"{self.output(0)} = recv[{self.device}](shape={self.kwargs['shape']}, dtype={self.kwargs['dtype']}, src={self.kwargs['src']}" - - class MovePrim(CommPrim): """ P2P send/recv, non-differentiable """ - def __init__(self, itensor: IRSubTensor, otensor: IRSubTensor): - src: int = itensor.device[0] if len(itensor.device) > 0 else None - dst: int = otensor.device[0] if len(otensor.device) > 0 else None - super().__init__([itensor], [otensor], src=src, dst=dst) - - def dispatch(self, devid: int) -> Union[SendPrim, RecvPrim]: - if devid == self.kwargs['src']: - return SendPrim(self.input(0), self.kwargs['dst']) - if devid == self.kwargs['dst']: - return RecvPrim(self.output(0), self.kwargs['src']) - return None + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + if len(kwargs) == 0: + assert len(itensors) == 1 and len(otensors) == 1 + kwargs['shape'] = itensors[0].shape + kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None + kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None + shape, dtype, src, dst = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dst'] + super().__init__(itensors, otensors, shape=shape, dtype=dtype, src=src, dst=dst) + self.signature = 'cube.runtime.adapter.move' def volume(self) -> int: return self.input(0).nelement() def __repr__(self): - dscp = f"move[{self.device}]({self.input(0)}, src={self.kwargs['src']}, dst={self.kwargs['dst']})" + dscp = f"{self.outputs()} = move{self.device}({self.inputs()}, src={self.kwargs['src']}, dst={self.kwargs['dst']})" return dscp +class CollectivePrim(CommPrim): + """ + Collective primitive, non-differentiable + """ + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + super().__init__(itensors, otensors, **kwargs) + if 'ranks' not in self.kwargs: + self.kwargs['ranks'] = self.device + + class RDScatterPrim(CommPrim): """ P2P Cross-device dimension scatter, non-differentiable. Tensor[Tile0, Tile1]: device 0 -> Tensor[Tile0]: device0, Tensor[Tile1]: device1 """ - def __init__(self, itensor: IRSubTensor, otensors: List[IRSubTensor], dim: int): + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): """ @param itensors List[IRSubTensor]: one tensor at device of `src`. @param otensors List[IRSubTensor]: each ran hosts one tenor partitioned by dim. @param dim int: the dimension that itensor will be partitioned """ - src: int = itensor.device[0] if len(itensor.device) > 0 else None - dst: List[int] = [t.device[0] if len(t.device) > 0 else None for t in otensors] - super().__init__([itensor], otensors, dim=dim, src=src, dst=dst) + if len(kwargs) == 0: + assert len(itensors) == 1 + kwargs['shape'] = tuple(itensors[0].shape) + kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None + kwargs['dsts'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) + shape, dtype, src, dsts = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dsts'] + super().__init__(itensors, otensors, shape=shape, dtype=dtype, dim=dim, src=src, dsts=dsts) self.signature = 'cube.runtime.adapter.rdscatter' def volume(self) -> int: return self.input(0).nelement() def __repr__(self) -> str: - return f"{self.outputs()} = rdscatter[{self.device}]({self.inputs()}, dim={self.kwargs['dim']})" + inputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.inputs()) + outputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.outputs()) + return f"{outputs} = rdscatter{self.device}({inputs}, dim={self.kwargs['dim']})" -class RVScatterPrim(CommPrim): +class RVScatterPrim(CollectivePrim): """ P2P Cross-device dimension scatter, non-differentiable. Tensor[Tile0, Tile1]: device 0 -> Tensor[Tile0]: device0, Tensor[Tile1]: device1 """ - def __init__(self, itensor: IRSubTensor, otensors: List[IRSubTensor]): + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): """ @param itensors List[IRSubTensor]: one tensor at device of `src`. @param otensors List[IRSubTensor]: each ran hosts one tenor partitioned by dim. @param dim int: the dimension that itensor will be partitioned """ - src: int = itensor.device[0] if len(itensor.device) > 0 else None - dst: List[int] = [t.device[0] if len(t.device) > 0 else None for t in otensors] - super().__init__([itensor], otensors, src=src, dst=dst) + if len(kwargs) == 0: + assert len(itensors) == 1 + kwargs['shape'] = tuple(itensors[0].shape) + kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None + kwargs['dsts'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) + shape, dtype, src, dsts = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dsts'] + super().__init__(itensors, otensors, shape=shape, dtype=dtype, src=src, dst=dsts) self.signature = 'cube.runtime.adapter.rvscatter' def volume(self) -> int: return self.input(0).nelement() * len(self.outputs()) def __repr__(self) -> str: - return f"{self.outputs()} = rvscatter[{self.device}]({self.inputs()})" + inputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.inputs()) + outputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.outputs()) + return f"{outputs} = rvscatter{self.device}({inputs})" class RDGatherPrim(CommPrim): @@ -275,63 +274,71 @@ class RDGatherPrim(CommPrim): Gather tensors from remote devices to a local device. The local device doesn't have any tensor """ - def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, dim: int): - src = [t.device[0] if len(t.device) > 0 else None for t in itensors] - dst = otensor.device[0] if len(otensor.device) > 0 else None - super().__init__(itensors, [otensor], src=src, dst=dst, dim=dim) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): + if len(kwargs) == 0: + assert len(otensors) == 1 + kwargs['shape'] = tuple(itensors[0].shape) # the input tensor shape + kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['srcs'] = tuple(itensor.device[0] if len(itensor.device) > 0 else None for itensor in itensors) + kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None + shape, dtype, srcs, dst = kwargs['shape'], kwargs['dtype'], kwargs['srcs'], kwargs['dst'] + super().__init__(itensors, otensors, shape=shape, dtype=dtype, srcs=srcs, dst=dst, dim=dim) self.signature = 'cube.runtime.adapter.rdgather' def volume(self) -> int: return self.output(0).nelement() def __repr__(self) -> str: - return ( - f"rdgather[{self.device}](" - f"{self.inputs()}, dim={self.kwargs['dim']}, " - f"src={self.kwargs['src']}, dst={self.kwargs['dst']})" - ) + inputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.inputs()) + outputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.outputs()) + return f"{outputs} = rdgather{self.device}({inputs}, dim={self.kwargs['dim']})" + -class RVGatherPrim(CommPrim): +class RVGatherPrim(CollectivePrim): """ Gather tensors from remote devices and sum in the local device. The local device doesn't have any tensor """ - def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): - src = [t.device[0] if len(t.device) > 0 else None for t in itensors] - dst = otensor.device[0] if len(otensor.device) > 0 else None - super().__init__(itensors, [otensor], src=src, dst=dst) + def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + if len(kwargs) == 0: + assert len(otensors) == 1 + kwargs['shape'] = tuple(itensors[0].shape) + kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['srcs'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) + kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None + shape, dtype, srcs, dst = kwargs['shape'], kwargs['dtype'], kwargs['srcs'], kwargs['dst'] + super().__init__(itensors, otensors, shape=shape, dtype=dtype, srcs=srcs, dst=dst) self.signature = 'cube.runtime.adapter.rvgather' def volume(self) -> int: return self.output(0).nelement() * len(self.inputs()) def __repr__(self) -> str: - src = self.kwargs['src'] - dst = self.kwargs['dst'] - return f"{self.outputs()} = rvgather[{self.device}]({self.inputs()}, src={src}, dst={dst})" + inputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.inputs()) + outputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.outputs()) + return f"{outputs} = rvgather{self.device}({inputs})" -class CollectivePrim(CommPrim): +class BroadcastPrim(CollectivePrim): """ - Collective primitive, non-differentiable + non-differential reduce-scatter """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): + if len(kwargs) == 0: + assert len(itensors) == 1 + kwargs['shape'] = tuple(itensors[0].shape) + kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None super().__init__(itensors, otensors, **kwargs) - if 'ranks' not in self.kwargs: - self.kwargs['ranks'] = self.device + self.signature = 'cube.runtime.adapter.broadcast' + + def volume(self) -> int: + ndevs = len(self.outputs()) + return self.input(0).nelement() * (ndevs-1) + + def __repr__(self) -> str: + return f"{self.outputs()} = broadcast{self.device}({self.inputs()}, src={self.kwargs['src']})" - def dispatch(self, devid: int) -> Optional[CommPrim]: - """ - dispatch to a given device - """ - if devid not in self.device: - return None - assert devid in self.device, f"device {devid} not applied for this comm primitive" - itensors = [itensor for itensor in self.inputs() if devid in itensor.device] - otensors = [otensor for otensor in self.outputs() if devid in otensor.device] - prim = type(self)(itensors, otensors, **self.kwargs) - prim.signature = self.signature - return prim class AllReducePrim(CollectivePrim): @@ -391,23 +398,6 @@ def __repr__(self) -> str: return f'{self.outputs()} = reduce_scatter[{self.device}]({self.inputs()})' -class BroadcastPrim(CollectivePrim): - """ - non-differential reduce-scatter - """ - def __init__(self, itensor: IRSubTensor, otensors: List[IRSubTensor], **kwargs): - src: int = itensor.device[0] if len(itensor.device) > 0 else None - super().__init__([itensor], otensors, src=src, **kwargs) - self.signature = 'cube.runtime.adapter.broadcast' - - def volume(self) -> int: - ndevs = len(self.outputs()) - return self.input(0).nelement() * (ndevs-1) - - def __repr__(self) -> str: - return f"{self.outputs()} = broadcast[{self.device}]({self.inputs()}, src={self.kwargs['src']})" - - class ReducePrim(CollectivePrim): """ non-differential reduce prim diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 02030eca..4f8de194 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -57,6 +57,16 @@ def recv(tensors: List[torch.Tensor], shape: List[int], dtype: torch.dtype, src: return tensor +def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int): + rank = torch.distributed.get_rank() + if rank == src: + assert torch.is_tensor(tensor) + return send(tensor, dst) + else: + assert rank == dst + return recv(None, shape, dtype, src) + + def sendrecv(input_tensors: List[torch.Tensor], output_shapes: List[List[int]], output_dtypes: List[torch.dtype], @@ -173,27 +183,137 @@ def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: return otensor -def broadcast(input_tensors: List[torch.Tensor], - output_shapes: List[List[int]], - output_dtypes: List[torch.dtype], - ranks: List[int]) -> List[torch.Tensor]: +def rdscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, + dim: int, src: int, dsts: Tuple[int]): """ - Broadcast. ranks[0] is the root + RDScatter: split itensor at rank `src` along dim into `len(dsts)` chunks, + and then send each chunk to `dst` devices. """ CudaTimer().start(field_name='comm', predefined=True) - assert len(input_tensors) == 1 or len(input_tensors) == 0 - if len(input_tensors) == 1: - tensor: torch.Tensor = input_tensors[0] - if not tensor.is_contiguous(): - tensor = tensor.contiguous() + rank = torch.distributed.get_rank() + if rank == src: + with torch.no_grad(): + otensors = itensor.chunk(len(dsts), dim) + send_ops = [] + for dst, otensor in zip(dsts, otensors): + if not otensor.is_contiguous(): + otensor = otensor.contiguous() + send_op = torch.distributed.P2POp( + torch.distributed.isend, otensor, dst + ) + send_ops.append(send_op) + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm', predefined=True) else: - assert len(output_shapes) == 1 - assert len(output_dtypes) == 1 - shape = output_shapes[0] - dtype = output_dtypes[0] - tensor = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype) + assert rank in dsts + shape = list(shape) + shape[dim] = shape[dim] // len(dsts) + otensor = torch.empty( + shape, requires_grad=True, dtype=dtype, + device=torch.cuda.current_device() + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, otensor, src + ) + reqs = torch.distributed.batch_isend_irecv([recv_op]) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm', predefined=True) + return otensor + + +def rvscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, + dim: int, src: int, dsts: Tuple[int]): + """ + src: global rank + """ + CudaTimer().start(field_name='comm', predefined=True) + group = DeviceGroup().get_group((src,) + dsts) + rank = torch.distributed.get_rank() + tensor: torch.Tensor = itensor / len(dsts) if src == rank else \ + torch.empty(shape, dtype=dtype, requires_grad = True) + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + torch.distributed.broadcast(tensor, src, group=group) + CudaTimer().stop(field_name='comm', predefined=True) + return tensor + + +def rdgather(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, + dim: int, srcs: Tuple[int], dst: int): + """ + @param srcs Tuple[int]: global rank of each source device + @param dst int: global rank of destination device + """ + CudaTimer().start(field_name='comm', predefined=True) + rank = torch.distributed.get_rank() + if rank == dst: + recv_ops = [] + recv_tensors = [] + for src in srcs: + tensor = torch.empty( + shape, dtype=dtype, + device=torch.cuda.current_device() + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, src + ) + recv_ops.append(recv_op) + recv_tensors.append(tensor) + reqs = torch.distributed.batch_isend_irecv(recv_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + with torch.no_grad(): + otensor = torch.cat(tuple(recv_tensors), dim=dim) + otensor = otensor.requires_grad_() + CudaTimer().stop(field_name='comm', predefined=True) + return otensor + else: + assert rank in srcs + tensor = itensor.contiguous() if not itensor.is_contiguous() else itensor + send_ops = [torch.distributed.P2POp(torch.distributed.isend, tensor, dst)] + reqs = torch.distributed.batch_isend_irecv(send_ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm', predefined=True) + return itensor + + +def rvgather(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, + srcs: Tuple[int], dst: int): + """ + @param srcs Tuple[int]: global rank of each source device + @param dst int: global rank of destination device + """ + CudaTimer().start(field_name='comm', predefined=True) + rank = torch.distributed.get_rank() + group = DeviceGroup().get_group(srcs + (dst,)) + tensor = torch.zeros(shape, dtype=dtype, requires_grad=True) if rank == dst else itensor + torch.distributed.reduce(tensor, dst, group=group) + CudaTimer().stop(field_name='comm', predefined=True) + return tensor + + +def broadcast(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, src: int, ranks: List[int]) -> torch.Tensor: + """ + Broadcast + @param src: the global rank that holds tensor for broadcasting + """ + CudaTimer().start(field_name='comm', predefined=True) + rank = torch.distributed.get_rank() group = DeviceGroup().get_group(ranks) - torch.distributed.broadcast(tensor, ranks[0], group=group) + if rank == src: + tensor = itensor.contiguous() if not itensor.is_contiguous() else itensor + else: + assert rank in ranks + tensor = torch.empty(shape, + device=torch.cuda.current_device(), requires_grad=True, dtype=dtype) + torch.distributed.broadcast(tensor, src, group=group) CudaTimer().stop(field_name='comm', predefined=True) return tensor diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index 09b975a2..919d75f4 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -78,6 +78,7 @@ def backward(ctx, grad_output): return grad_output, None +@torch.jit.ignore def allreduce_identity(tensor: torch.Tensor, ranks: List[int]): return AllReduceIdentity.apply(tensor, ranks) @@ -95,7 +96,7 @@ def backward(ctx, grad: torch.Tensor): grad = _allreduce(grad, ranks) return grad, None - +@torch.jit.ignore def identity_allreduce(tensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: return IdentityAllreduce.apply(tensor, ranks) @@ -115,6 +116,7 @@ def backward(ctx, grad: torch.Tensor): return grad, None +@torch.jit.ignore def allreduce_allreduce(tensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: return AllReduceAllReduce.apply(tensor, ranks) @@ -135,6 +137,7 @@ def backward(ctx, grad: torch.Tensor): return grad, None, None +@torch.jit.ignore def reducescatter_allgather(tensor: torch.Tensor, dim: int, ranks: List[int]): return ReduceScatterAllGather.apply(tensor, dim, ranks) @@ -155,6 +158,7 @@ def backward(ctx, grad: torch.Tensor): return grad, None, None +@torch.jit.ignore def allgather_reducescatter(tensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: return AllGatherReduceScatter.apply(tensor, dim, ranks) @@ -174,6 +178,7 @@ def backward(ctx, grad: torch.Tensor): return _chunk(grad, dim, ranks), None, None +@torch.jit.ignore def allgather_split(tensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: return AllGatherSplit.apply(tensor, dim, ranks) @@ -197,6 +202,7 @@ def backward(ctx, grad: torch.Tensor): return grad, None, None +@torch.jit.ignore def split_allgather(tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: return SplitAllGather.apply(tensor, dim, ranks) @@ -218,6 +224,7 @@ def backward(ctx, grad: torch.Tensor): return grad, None, None, None +@torch.jit.ignore def alltoall_alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: return AllToAllAllToAll.apply(itensor, idim, odim, ranks) From 714a1978ae79ee83d3607c5fde23baba2140ed24 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Oct 2022 15:21:10 +0800 Subject: [PATCH 1082/1892] support multi-node code generation (only for deterministic policy) --- cube/compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 74cb4879..e1cf6006 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -73,9 +73,10 @@ def train_step(model, dataloader): if torch.distributed.is_initialized(): # multiple device myrank = torch.distributed.get_rank() + local_rank = cube.runtime.device.DeviceGroup().local_rank else: # single device - myrank = 0 + myrank = local_rank = 0 def _load_tschedule_fn(filename) -> Callable: import importlib.util @@ -100,7 +101,7 @@ def decorator(fn: Callable) -> Callable: print_each_rank(f'loading existed schedule from {filename} ...') return _load_tschedule_fn(filename) - if myrank == 0: + if local_rank == 0: compile_start = time.time() From 4fe1bf36c33059aec80b94cfcbc2fb88e41ab6b1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Oct 2022 15:43:20 +0800 Subject: [PATCH 1083/1892] fix bugs on equal-device transmission --- cube/graph/gener/layout.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index f87a1c8a..090d6da2 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -213,7 +213,10 @@ def incr(self, chunks: int, devices: Optional[np.ndarray] = None): omat = GridLayout.dims2last(glayout.mat, [0]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): - prims.append(BroadcastPrim([src], [src] + list(dsts))) + if chunks == 1: + prims.append(MovePrim([src], dsts)) + else: + prims.append(BroadcastPrim([src], [src] + list(dsts))) return glayout, prims def decr(self, chunks: int, devices: Optional[np.ndarray] = None): @@ -859,6 +862,8 @@ def inter_transform(ftensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout: Optional[Gri @return prims Optional[List[IRAdapterPrim]]: the prmitives in transformation """ inc_dims, dec_dims = GridLayout.changed_dims(src_rvd, dst_rvd) + if len(inc_dims) == 0 and len(dec_dims) == 0: + inc_dims = [0] if not ((len(inc_dims) == 1 and len(dec_dims) == 0) or (len(inc_dims) == 0 and len(dec_dims) == 1)): return False, None, None inc_idx = inc_dims[0] if len(inc_dims) == 1 else None From 6a63abe84fe305a1c520391944c795c1fd800824 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Oct 2022 01:58:50 -0700 Subject: [PATCH 1084/1892] fix mutli-node generation code --- cube/program.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cube/program.py b/cube/program.py index 3c691f36..2acefa7d 100644 --- a/cube/program.py +++ b/cube/program.py @@ -10,6 +10,7 @@ from cube.runtime.syndata import CubeDataLoader from cube.runtime.module import CubeModule +from cube.runtime.device import DeviceGroup from cube.profiler.timer import print_each_rank import torch @@ -123,8 +124,10 @@ def __init__(self, model: torch.nn.Module, input_shapes): """ Create semantic model based on AI Scientist description. """ - dist = torch.distributed.is_initialized() - if (not dist) or (dist and torch.distributed.get_rank() == 0): + local_rank = 0 + if torch.distributed.is_initialized(): + local_rank = DeviceGroup().local_rank + if local_rank == 0: self.ir_graph = parser.convert_model( model, input_shapes=input_shapes ) From 4fcb2afd37d8d2e988aed1149ef67048c2453542 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Oct 2022 22:03:02 +0800 Subject: [PATCH 1085/1892] allow partition on value dimension of dimops --- cube/algorithm/ops/dimops.py | 87 +++++++++++++++++++++-------------- cube/graph/function/dimops.py | 6 ++- cube/graph/parser/register.py | 6 +-- 3 files changed, 61 insertions(+), 38 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 36fa4d41..030d2a9b 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Any, Dict +from typing import List, Optional, Any, Dict, Union from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule @@ -41,53 +41,66 @@ def __init__(self, node: IRDimops): raise TypeError(f"Expect IRDimops") super().__init__(node) - def satisfy(self, idx: int, dim: int, num: int) -> bool: + def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: """ Check whether the condition satisfies. @param idx int: input index - @param dim int: input dimension + @param dim Union[int, str]: input dimension or 'v', ie., partition at value dimension @param num int: chunks to partition the dimension @return satisfy bool: true if can be partitioned, elsewise false. """ - assert all(isinstance(cond, int) for cond in [idx, dim, num]), "expect int condition" + assert all(isinstance(cond, int) for cond in [idx, num]), "expect int condition" + assert isinstance(dim, int) or dim == 'v', f"expect dim to be int or 'v'" node: IRDimops = self.node assert isinstance(node.input(idx), IRSubTensor), f"partitioning on a non-tensor input" ninputs = len(node.inputs()) idx = idx if idx >= 0 else idx + ninputs assert idx < ninputs, f"index out of boundary: {idx} >= {ninputs}" - dim = dim if dim >= 0 else dim + node.input(idx).ndims - assert dim < node.input(idx).ndims, f"dimension output of boundary: {dim} >= {node.input(idx).ndims}" + + if isinstance(dim, int): + dim = dim if dim >= 0 else dim + node.input(idx).ndims + assert dim < node.input(idx).ndims, f"dimension output of boundary: {dim} >= {node.input(idx).ndims}" - # we only partition the first non-1 annotated dimension for hidden-dimension cases. - for adim in node.anno.input(idx).dims[dim].identifiers: - if adim == '1^': continue - break - dimlen = node.anno.getlen(adim) - # check node special rules first - for rule in node.transform_rules: - if rule.input(idx) == DimopSplit.D(dim): - return dimlen >= num - # otherwise check for default rules - reduce = node.anno.input(idx).dims[dim].reduces[0] - if reduce == DimAnno.ReduceType.Freeze: + # try split at tensor spatial dimension + if isinstance(dim, int): + for adim in node.anno.input(idx).dims[dim].identifiers: + if adim == '1^': continue + break + dimlen = node.anno.getlen(adim) + # first check node special rules first + for rule in node.transform_rules: + if rule.input(idx) == DimopSplit.D(dim): + return dimlen >= num + # then check default rules + reduce = node.anno.input(idx).dims[dim].reduces[0] + if reduce == DimAnno.ReduceType.Freeze: + return False + return dimlen >= num + else: + for rule in node.transform_rules: + if rule.input(idx).isV(): + return True return False - return dimlen >= num + - def instantiate(self, idx: int, dim: int, num: int) -> Optional[List[IRDimops]]: + def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List[IRDimops]]: node: IRDimops = self.node satisfy = self.satisfy(idx, dim, num) - for adim in node.anno.input(idx).dims[dim].identifiers: - if adim == '1^': continue - break - reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] - print(f'try split {node.name}: {node.anno} | dim: {adim} reduce: {reduce}') - if not satisfy: - print(f'Failed!') - return None + + if isinstance(dim, int): + for adim in node.anno.input(idx).dims[dim].identifiers: + if adim == '1^': continue + break + reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] + else: + adim, reduce = 'Value', None + color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' + print(f"try split {node.name}: {node.anno} | dim: {adim} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") + if not satisfy: return None rule: TransformRule = self.infer(idx, dim, num) @@ -121,7 +134,7 @@ def transform(tensor: Any, split: DimopSplit) -> List[Any]: return sub_nodes - def infer(self, idx: int, dim: int, num: int) -> Optional[TransformRule]: + def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformRule]: """ Given the partition choice on `dim` dimension of idx-th input, return the partitioning of the output tensor. @@ -132,15 +145,21 @@ def infer(self, idx: int, dim: int, num: int) -> Optional[TransformRule]: @return rule TransformRule: the transformation rule """ node: IRDimops = self.node + assert isinstance(dim, int) or dim == 'v', f"expect dim to be int or 'v'" + # check node special rules first + for r in node.transform_rules: + if isinstance(dim, int): + if r.input(idx) == DimopSplit.D(dim): + return r + else: + if r.input(idx).isV(): + return r + # otherwise use default rule + assert isinstance(dim, int), f"Error: expect dim to be int for default rules" adim: str = node.anno.input(idx).dims[dim].identifiers[0] reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] if reduce == DimAnno.ReduceType.Freeze: return None - # check node special rules first - for r in node.transform_rules: - if r.input(idx) == DimopSplit.D(dim): - return r - # otherwise use default rule itransform, otransform = [], [] # input for idx, idim in enumerate(node.anno.inputs()): diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 419bce16..407bac08 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -527,7 +527,7 @@ class TransformRule: def __init__(self, irules: Tuple[DimopSplit], orules: Tuple[DimopSplit], kwarg_modifier: Optional[Callable] = None) -> None: self._inputs = tuple(irules) self._outputs = tuple(orules) - modifier = kwarg_modifier if kwarg_modifier is not None else lambda x : x + modifier = kwarg_modifier if kwarg_modifier is not None else TransformRule.default_modifier self._modifier = (modifier,) def inputs(self) -> Tuple[DimopSplit]: @@ -550,6 +550,10 @@ def __repr__(self) -> str: outputs = ', '.join(repr(split) for split in self._outputs) return f'{inputs} -> {outputs}' + @staticmethod + def default_modifier(kwargs: Dict, idx: int, dim: Union[int, str], num: int) -> Dict: + return kwargs + class IRDimops(IRFwOperation): """ diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 10dcf9ec..bba41f34 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -6,12 +6,12 @@ import inspect import torch -from cube.graph.function.dimops import IRDimops, OpAnno +from cube.graph.function.dimops import IRDimops, OpAnno, TransformRule from cube.graph.parser.mapping import Sign2Op -def register(anno: str, name: Optional[str] = None): +def register(anno: str, name: Optional[str] = None, rules: Optional[List[TransformRule]] = None): """ Register a function with einop annotations. @@ -63,7 +63,7 @@ def udfop(signature: str, inputs: List[Any]): kwargs = dict() for name, val in zip(kwarg_names, kwarg_vals): kwargs[name] = val - return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, **kwargs) + return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, transform_rules=rules, **kwargs) print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') Sign2Op.register(fsig, udfop, code) From 34ce0bd8efe58947f764967a87ae676f6a022e9b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 25 Oct 2022 12:48:32 +0800 Subject: [PATCH 1086/1892] fix bugs on inter-mesh generation and add support for value partition general generation --- cube/graph/gener/concurrent.py | 7 +++++-- cube/graph/gener/layout.py | 13 +++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index a2e54186..54729f58 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -226,8 +226,11 @@ def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRA end = start + e2 - s2 indmap.append((start, end)) indmap = IndexMap(tuple(indmap)) - assert itensor.valmap == common.valmap, "Value map not same" - valmap = ValueMap((0, 1)) + if itensor.valmap == common.valmap: + valmap = ValueMap((0, 1)) + else: + assert itensor.valmap == (0, 1) + valmap = common.valmap select_prim = SelectPrim(itensor, indmap, valmap, common) prims.append(select_prim) if itensor.device == ctensor.device and common == ctensor: diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 090d6da2..e9e8ea56 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -18,6 +18,7 @@ from cube.ir.adapter.prim import BroadcastPrim from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim +from cube.runtime.device import DeviceGroup TShape = Tuple[int, ...] @@ -754,15 +755,19 @@ def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") path = paths[nodes.index(dst)] + # print(f"Find path: {' -> '.join(str(nodes[i]) for i in path)}") assert len(path) > 0, f"Un-reachable src RVD ({src}) -> dst RVD ({dst})" # setup consumer begining devices cpaths = tuple(idx for idx in path if nodes[idx][0] == 'c') - cdevs = np.array([t.device[0] for t in olayout.mat.flatten()]).reshape(dst[1:]) + curr_devs = np.array([t.device[0] for t in olayout.mat.flatten()]).reshape(dst[1:]) + curr_node = dst[1:] # print('result device map:', list(cdevs.flatten())) for hop in cpaths[:-1][::-1]: hop_rvd = nodes[hop][1:] - cdevs = PathFinder.intra_devmap(dst[1:], hop_rvd, cdevs) + curr_devs = PathFinder.intra_devmap(curr_node, hop_rvd, curr_devs) + curr_node = hop_rvd + consumer_entry_devs = curr_devs # print('calculated consumer device map: ', list(cdevs.flatten())) # setup primitives for communication side, layouts, all_prims = 'p', [ilayout], [] @@ -774,7 +779,7 @@ def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, ret, layout, prims = PathFinder.intra_transform(ftensor, curr_rvd, hop_rvd, layouts[-1]) assert ret, "Internal Error" else: - ret, layout, prims = PathFinder.inter_transform(ftensor, curr_rvd, hop_rvd, layouts[-1], cdevs) + ret, layout, prims = PathFinder.inter_transform(ftensor, curr_rvd, hop_rvd, layouts[-1], consumer_entry_devs) layouts.append(layout) all_prims += prims curr_rvd = hop_rvd @@ -826,7 +831,7 @@ def intra_devmap(src_rvd: TRVD, dst_rvd: TRVD, src_devs: np.ndarray): """ Infer device from source rvd to destination rvd """ - assert tuple(src_rvd) == tuple(src_devs.shape) + assert tuple(src_rvd) == tuple(src_devs.shape), f"RVD mis-matches with device shape, {src_rvd} != {src_devs.shape}" # get changed dimensions inc_idx, dec_idx = GridLayout.changed_dims(src_rvd, dst_rvd) assert len(inc_idx) == 1 and len(dec_idx) == 1 From e4d44a996b4b85c8d94094e76d8dfbca54ded5e4 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 25 Oct 2022 13:17:42 +0800 Subject: [PATCH 1087/1892] refine doc --- examples/alphafold2/README.md | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/alphafold2/README.md b/examples/alphafold2/README.md index 52425345..f6857751 100644 --- a/examples/alphafold2/README.md +++ b/examples/alphafold2/README.md @@ -76,14 +76,22 @@ To solve this problem, current dynamic programming formulation need to be update - k pass recompute policy: reduce peak memory, increase execution time - coshard / chunk: split computation with extremly large output size into acceptable ones -$f(i, max(p, r), q + s) = min (f(i, max(p, r), q + s), f(i-1, p, q) + t(r, s))$ +**dynamic formulation** + +$f(i, max(p, r), q + s) = min (f(i, max(p, r), q + s), f(i-1, p, q) + t(i, r, s))$ + +- $f(i, p, q)$: the minimal execution time from 1st to i-th operator when maximum temporary tensor size = $p$ and the sum size of checkpointed tensor = $q$ +- $t(i, r, s)$: the minimal time of plans that schedule i-th operator when max temporary size = $r$ and checkpointed size = $s$. The space spanned by different checkpoint policies and chunk sizes is described in $t$. +- the optimal value in the end: ${min}_{p+q offload problems to jit tensor compilers -- strategy: detect memory constrained parts then coshard them -- large enough size of input shapes amay lready utilize accelerators -- should include coshard into the dp formulation +- given a computation graph and a memory constraint, generate the most efficient execution plan + - can we leverage the jit compiler or the *del* operand in PyTorch + - better options: *nnfusion*, *xla*? +- chunk (coshard) + - how to choose the chunk size: large enough size of input shapes may already utilize accelerators + - a heuristic strategy: detect the most memory-intensive operator then coshard it # Experiment From 02f8865171b7b71a2d1e40795959cf73107ba85f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 24 Oct 2022 23:32:46 -0700 Subject: [PATCH 1088/1892] allow empty device --- cube/execplan/execplan.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index d8ddb2c2..c6453222 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -67,7 +67,8 @@ def seq(self, devid: int) -> List[IRCell]: Note changing the list content will not change the execution plan. """ - assert devid in self._seq, f"device id {devid} not exists" + if devid not in self._seq: + return [] return copy.copy(self._seq[devid]) def at(self, devid: int) -> List[IRCell]: @@ -76,7 +77,8 @@ def at(self, devid: int) -> List[IRCell]: Note changing the list content will change the execution plan. """ - assert devid in self._seq, f"device id {devid} not exists" + if devid not in self._seq: + return [] return self._seq[devid] def flatten(self, devid: int) -> List[IRCell]: From 494c2ee4dccf6463affae7cdebfb14c43e345b65 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 25 Oct 2022 00:55:26 -0700 Subject: [PATCH 1089/1892] add greedy workload balance algorithm for general P2P generation --- cube/graph/gener/concurrent.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 54729f58..50054912 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -1,8 +1,9 @@ """ Concurrent producer / consumer Adapter Generator """ -from typing import List, Optional +from typing import List, Optional, Dict import copy +import numpy as np from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap from cube.ir.adapter.prim import IRAdapterPrim @@ -177,22 +178,25 @@ def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], @return adapter IRAdapter """ fprims = [] + fpdevs = set(t.device[0] for t in fptensors) + fcomm_workload = {t.device[0]: 0 for t in fptensors} for ctensor in fctensors: - fprims += ConcurrentGener.gen_subtensor(ctensor, fptensors) + fprims += ConcurrentGener.gen_subtensor(ctensor, fptensors, fcomm_workload) fadapter = IRAdapter(fptensors,fctensors) fadapter.prims = fprims # backward if len(bptensors) > 0 and len(bctensors) > 0: bprims = [] + bcomm_workload = {t.device[0]: 0 for t in bptensors} for cgrad in bctensors: - bprims += ConcurrentGener.gen_subtensor(cgrad, bptensors) + bprims += ConcurrentGener.gen_subtensor(cgrad, bptensors, bcomm_workload) badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims IRAdapter.make_pair(fadapter, badapter) return fadapter @staticmethod - def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRAdapterPrim]: + def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor], workload: Dict[int, int]) -> List[IRAdapterPrim]: """ Generate communiction primitives for ctensor @@ -203,11 +207,20 @@ def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRA """ # category to local tensor and remote tensor local = [t for t in ptensors if t.device == ctensor.device] - remote = [t for t in ptensors if t.device != ctensor.device] + # reorder remote devices: higher priority to use tensor with lower communication workload + devices = np.array([devid for devid in workload.keys()], dtype=int) + volume = np.array([workload[devid] for devid in workload.keys()]) + indices = np.argsort(volume) + sorted_devices = devices[list(indices)] + remote: List[IRSubTensor] = [] + for devid in sorted_devices: + if devid == ctensor.device[0]: continue + remote += [t for t in ptensors if t.device[0] == devid] + prims = [] # ==== select ==== # - intersections = [] + intersections: List[IRSubTensor] = [] # check local for itensor in local+remote: if itensor.device == ctensor.device and itensor == ctensor: @@ -249,6 +262,7 @@ def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor]) -> List[IRA mtensor = copy.copy(tensor) mtensor.cell = ctensor.cell prims.append(MovePrim([tensor], [mtensor])) + workload[tensor.device[0]] += tensor.nelement() tmoved.append(mtensor) # ===== merge ===== # From 77d777472433dc5bd7ab9c0a65ad7c3cbf659e69 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 25 Oct 2022 03:18:41 -0700 Subject: [PATCH 1090/1892] switch move to synchornized call due to strange bugs --- cube/runtime/adapter/collectives.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 4f8de194..bfda0434 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -57,14 +57,34 @@ def recv(tensors: List[torch.Tensor], shape: List[int], dtype: torch.dtype, src: return tensor +# def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int): +# rank = torch.distributed.get_rank() +# if rank == src: +# assert torch.is_tensor(tensor) +# return send(tensor, dst) +# else: +# assert rank == dst +# return recv(None, shape, dtype, src) + + def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int): + """ + Move a tensor from source device to destination device. + """ + CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() if rank == src: assert torch.is_tensor(tensor) - return send(tensor, dst) + torch.distributed.send(tensor, dst) else: assert rank == dst - return recv(None, shape, dtype, src) + tensor = torch.empty(shape, dtype=dtype, + device=torch.cuda.current_device(), requires_grad=True + ) + torch.distributed.recv(tensor, src) + CudaTimer().stop(field_name='comm', predefined=True) + return tensor + def sendrecv(input_tensors: List[torch.Tensor], From d851e4f429740bf1828a251e6396e7d6bcbdd519 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 25 Oct 2022 04:18:40 -0700 Subject: [PATCH 1091/1892] change rdscatter to synchronize send recv --- cube/runtime/adapter/collectives.py | 62 +++++++---------------------- 1 file changed, 15 insertions(+), 47 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index bfda0434..0a7e0cf5 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -74,6 +74,7 @@ def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() if rank == src: + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor assert torch.is_tensor(tensor) torch.distributed.send(tensor, dst) else: @@ -214,19 +215,10 @@ def rdscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, if rank == src: with torch.no_grad(): otensors = itensor.chunk(len(dsts), dim) - send_ops = [] - for dst, otensor in zip(dsts, otensors): - if not otensor.is_contiguous(): - otensor = otensor.contiguous() - send_op = torch.distributed.P2POp( - torch.distributed.isend, otensor, dst - ) - send_ops.append(send_op) - reqs = torch.distributed.batch_isend_irecv(send_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) + for dst, otensor in zip(dsts, otensors): + otensor = otensor.contiguous() if not otensor.is_contiguous() else otensor + torch.distributed.send(otensor, dst) + otensor = itensor else: assert rank in dsts shape = list(shape) @@ -235,15 +227,9 @@ def rdscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, shape, requires_grad=True, dtype=dtype, device=torch.cuda.current_device() ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, otensor, src - ) - reqs = torch.distributed.batch_isend_irecv([recv_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) - return otensor + torch.distributed.recv(otensor, src) + CudaTimer().stop(field_name='comm', predefined=True) + return otensor def rvscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, @@ -271,37 +257,19 @@ def rdgather(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() if rank == dst: - recv_ops = [] recv_tensors = [] for src in srcs: - tensor = torch.empty( - shape, dtype=dtype, - device=torch.cuda.current_device() - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, src - ) - recv_ops.append(recv_op) + tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device()) + torch.distributed.recv(tensor, src) recv_tensors.append(tensor) - reqs = torch.distributed.batch_isend_irecv(recv_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - with torch.no_grad(): - otensor = torch.cat(tuple(recv_tensors), dim=dim) + otensor = torch.cat(tuple(recv_tensors), dim=dim) otensor = otensor.requires_grad_() - CudaTimer().stop(field_name='comm', predefined=True) - return otensor else: assert rank in srcs - tensor = itensor.contiguous() if not itensor.is_contiguous() else itensor - send_ops = [torch.distributed.P2POp(torch.distributed.isend, tensor, dst)] - reqs = torch.distributed.batch_isend_irecv(send_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) - return itensor + otensor = itensor.contiguous() if not itensor.is_contiguous() else itensor + torch.distributed.send(otensor, dst) + CudaTimer().stop(field_name='comm', predefined=True) + return otensor def rvgather(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, From 05095feadd04b0065e4a34a0113a68897f5f8850 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 26 Oct 2022 12:55:55 +0800 Subject: [PATCH 1092/1892] fix bug on cache --- cube/graph/gener/layout.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index e9e8ea56..5e2a26f5 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -713,6 +713,7 @@ def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, dst = ('c',) + (olayout.R, olayout.V) + tuple(olayout.D) if key in PathFinder._cached_inter_nodes and src in PathFinder._cached_inter_paths[key]: + nodes = PathFinder._cached_inter_nodes[key] paths = PathFinder._cached_inter_paths[key][src] else: if key in PathFinder._cached_inter_nodes: From 797516b7b8242ed08e78bc2ef5c0bf6668171db4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 26 Oct 2022 13:46:18 +0800 Subject: [PATCH 1093/1892] sorted field name for print --- cube/profiler/timer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index 616be085..4569e7e9 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -138,7 +138,9 @@ def print_all(self, times: int, rank_only: Optional[int] = None): @return None """ msg = list() - for field_name in self.instance.field_data: + names = list(self.instance.field_data.keys()) + names.sort() + for field_name in names: span = self.duration(times, field_name) msg.append('{} : {:.2f} ms'.format(field_name, span)) msg = ' | '.join(msg) From aa85bae4032bca23c37c2941f049c07dca6f2a9a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 26 Oct 2022 13:46:51 +0800 Subject: [PATCH 1094/1892] fix 1f1b schedule bug --- cube/runtime/schedule/sched1f1b.py | 16 ++++++++-------- cube/runtime/schedule/strategy.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py index 555208a8..0f570a82 100644 --- a/cube/runtime/schedule/sched1f1b.py +++ b/cube/runtime/schedule/sched1f1b.py @@ -25,7 +25,7 @@ def run(segment: Callable, # forward body for _ in range(num_warmup_microbatches): # recv forward # print(f'rank[{torch.distributed.get_rank()}]: line26: recving forward') - inputs = Schedule1F1B.adapter_step(rfadapter) + inputs = Schedule1F1B.adapter_step(rfadapter, True) inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs # forward Schedule1F1B.push_tail('inputs', inputs) @@ -39,11 +39,11 @@ def run(segment: Callable, # forward body Schedule1F1B.push_tail('outputs', outputs) # send forward # print(f'rank[{torch.distributed.get_rank()}]: line40 send forward') - Schedule1F1B.adapter_step(sfadapter, *outputs) + Schedule1F1B.adapter_step(sfadapter, True, *outputs) if num_warmup_remaining > 0: # print(f'rank[{torch.distributed.get_rank()}]: line44 recv forward') - inputs = Schedule1F1B.adapter_step(rfadapter) + inputs = Schedule1F1B.adapter_step(rfadapter, True) inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs # steady @@ -61,7 +61,7 @@ def run(segment: Callable, # forward body # send forward recv backward # print(f'rank[{torch.distributed.get_rank()}]: line62 send forward recv backward') - grads = Schedule1F1B.exchange(sfadapter, rbadapter, stage_id, *outputs) + grads = Schedule1F1B.exchange(sfadapter, rbadapter, stage_id, (True, False), *outputs) grads = (None,) if len(grads) == 0 else grads # backward @@ -75,19 +75,19 @@ def run(segment: Callable, # forward body # send backward recv forward if i != num_warmup_remaining - 1: # print(f'rank[{torch.distributed.get_rank()}]: line77 send backward recv forward') - inputs = Schedule1F1B.exchange(sbadapter, rfadapter, stage_id, *input_grads) + inputs = Schedule1F1B.exchange(sbadapter, rfadapter, stage_id, (False, True), *input_grads) inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs else: # send backward # print(f'rank[{torch.distributed.get_rank()}]: line82 send backward') - Schedule1F1B.adapter_step(sbadapter, *input_grads) + Schedule1F1B.adapter_step(sbadapter, False, *input_grads) # cooldown for i in range(num_warmup_microbatches): inputs, outputs = Schedule1F1B.pop_head('inputs'), Schedule1F1B.pop_head('outputs') # recv backward # print(f'rank[{torch.distributed.get_rank()}]: line89 recv backward') - grads = Schedule1F1B.adapter_step(rbadapter) + grads = Schedule1F1B.adapter_step(rbadapter, False) grads = (None,) if len(grads) == 0 else grads # backward if recompute: @@ -97,7 +97,7 @@ def run(segment: Callable, # forward body input_grads = Schedule1F1B.backward_step(inputs, outputs, grads) # send backward # print(f'rank[{torch.distributed.get_rank()}]: line99 send backward') - Schedule1F1B.adapter_step(sbadapter, *input_grads) + Schedule1F1B.adapter_step(sbadapter, False, *input_grads) Schedule1F1B.assert_empty() # print(f'rank[{torch.distributed.get_rank()}]: ok here') diff --git a/cube/runtime/schedule/strategy.py b/cube/runtime/schedule/strategy.py index 3b78c89d..b4953817 100644 --- a/cube/runtime/schedule/strategy.py +++ b/cube/runtime/schedule/strategy.py @@ -49,7 +49,7 @@ def dataloader_step(dataloader: Iterable): return data @staticmethod - def adapter_step(adapter: Callable, *args): + def adapter_step(adapter: Callable, require_grad : bool = True, *args): """ adapter pass """ @@ -59,20 +59,22 @@ def adapter_step(adapter: Callable, *args): CudaTimer().stop('adapter') if not isinstance(outputs, tuple): outputs = (outputs,) + if require_grad: + outputs = tuple(t.requires_grad_() if torch.is_tensor(t) else t for t in outputs) return outputs @staticmethod - def exchange(sadapter: Callable, radapter: Callable, stage_id: int, *args): + def exchange(sadapter: Callable, radapter: Callable, stage_id: int, require_grads: bool, *args): """ send adapter and recv adapter """ # TODO: optimize with batch operators if stage_id % 2 == 0: - ScheduleABC.adapter_step(sadapter, *args) - outs = ScheduleABC.adapter_step(radapter) + ScheduleABC.adapter_step(sadapter, require_grads[0], *args) + outs = ScheduleABC.adapter_step(radapter, require_grads[1]) else: - outs = ScheduleABC.adapter_step(radapter) - ScheduleABC.adapter_step(sadapter, *args) + outs = ScheduleABC.adapter_step(radapter, require_grads[1]) + ScheduleABC.adapter_step(sadapter, require_grads[0], *args) return outs @staticmethod From d166e520421bd4871cbbfa485e649b5abc5af76f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 26 Oct 2022 15:54:33 +0800 Subject: [PATCH 1095/1892] fix layout space bug --- cube/graph/gener/layout.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 5e2a26f5..a5088eec 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -506,7 +506,7 @@ def iter_idx(dims: List[int]) -> Tuple[int]: indmap = [] shape = [] for dim, (nchunk, index) in enumerate(zip(dims, indices[1:])): - assert ftensor.shape[dim] % nchunk == 0, f"not dividable for {nchunk} chunks over dim {dim}" + assert ftensor.shape[dim] % nchunk == 0, f"not dividable for {nchunk} chunks over dim {dim}. ftensor shape: {ftensor.shape}" csize = ftensor.shape[dim] // nchunk start = csize * index indmap.append((start, start+csize)) @@ -1001,10 +1001,13 @@ def factors(ndevs: int, length: int): yield [i] + res for rvd in factors(ndevs, 2+len(ftensor.shape)): + skip = False for dimlen, pnum in zip(ftensor.shape, rvd[2:]): if dimlen % pnum != 0: - continue - all_layouts.append(tuple(rvd)) + skip = True + break + if not skip: + all_layouts.append(tuple(rvd)) return all_layouts @staticmethod From 1967d01b71d2b889af7bf45f144b84a94d5dc4bc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 26 Oct 2022 16:56:23 +0800 Subject: [PATCH 1096/1892] fix split val bug --- cube/ir/tensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 20633768..c5b2f322 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -734,10 +734,9 @@ def split_val(self, num: int) -> List[IRTensor]: indmap.append((0, nele)) sub_tensors = list() for idx in range(num): - valmap = self._valmap.map((idx, num)) sub_tensor = self.select( indmap=tuple(indmap), - valmap=valmap, + valmap=(idx, num), ) sub_tensors.append(sub_tensor) return sub_tensors From 0445cb722451e4a56db839b8c968ce159957c2c4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 26 Oct 2022 22:54:26 +0800 Subject: [PATCH 1097/1892] fix bugs in weight reducer generation --- cube/graph/gener/gen.py | 78 +++++++++++++++++------------------------ cube/graph/segment.py | 2 +- cube/ir/operator.py | 5 +-- 3 files changed, 37 insertions(+), 48 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index cdaf7ede..8d83be5a 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Dict +from typing import Dict, List, Optional, Tuple import numpy as np import itertools @@ -56,7 +56,7 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: devices = [consumer.device for consumer in segment.consumers(tensor.parent)][::-1] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = DummyInputOuput(tensor, 0, is_output=True) + fwop = DummyInputOuput(tensor, 0, is_output=True, name=f'segment{segment.cid}_input') for devid in devices: fop = fwop.replicate() fop.device = devid @@ -72,7 +72,7 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: devices = [producer.device for producer in segment.producers(tensor.parent)] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = DummyInputOuput(tensor, 0, is_input=True) + fwop = DummyInputOuput(tensor, 0, is_input=True, name=f'segment{segment.cid}_output') for devid in devices: fop = fwop.replicate() fop.device = devid @@ -181,57 +181,45 @@ def gen_weight(graph: IRGraph) -> IRGraph: fgrads[fweight].append(wtensor.grad) consumers[fweight].append(fnode) - # bucketing + nl = '\n' weights: Dict[IRFullTensor, Dict[IRSubTensor, List[int]]] = dict() for fweight in fweights.keys(): - cids = set(fnode.cid for fnode in consumers[fweight]) - nl = '\n' - # case 1: no replica - if len(cids) == len(consumers[fweight]): - weights[fweight] = dict() - for wtensor, consumer in zip(fweights[fweight], consumers[fweight]): - if wtensor not in weights[fweight]: - weights[fweight][wtensor] = set() - weights[fweight][wtensor].add(consumer.device[0]) - # case 2: replica but has same number of replicas and same/no-overlapping devices - else: - cid_fnodes = {cid : [n for n in consumers[fweight] if n.cid == cid] for cid in cids} - cid_nnodes = [len(ns) for ns in cid_fnodes.values()] - # same replica# for each cid - assert all(cid_nnodes[0] == ns for ns in cid_nnodes), ( + weights[fweight] = {} + weight_grads: Dict[IRSubTensor, Dict[IRSubTensor, List[IRFwOperation]]] = {} + for weight, grad, consumer in zip(fweights[fweight], fgrads[fweight], consumers[fweight]): + if weight not in weight_grads: + weight_grads[weight] = {} + if grad not in weight_grads[weight]: + weight_grads[weight][grad] = [] + weight_grads[weight][grad].append(consumer) + + # TODO: check sub_weight is no-overlapping + + # assert all(sw.valmap[1] == len(weight_grads) for sw in weight_grads.keys()) + for sub_weight in weight_grads: + diff_grads = weight_grads[sub_weight] + diff_grads_len = [len(diff_grads[grads]) for grads in diff_grads] + assert all(n == diff_grads_len[0] for n in diff_grads_len), ( f"If one of the weight consumers are replicated, " f"other same-weight consumers should also replicated in same way." f"FullTensor Weight: {fweight}\n" f"Consumers:\n{nl.join([repr(node) for node in consumers[fweight]])}" ) - cid_devs = {cid: set(n.device[0] for n in consumers[fweight]) for cid in cids} - # case 2.1: same device sharing - first = list(cid_devs.keys())[0] - if all(cid_devs[first] == devs for devs in cid_devs.values()): - #TODO: need to be more robust - continue - # case 2.2: no-overlapping device sharing - all_devs = set() - for devs in cid_devs.values(): - all_devs.update(devs) - if sum(len(devs) for devs in cid_devs.values()) == len(all_devs): - raise NotImplementedError( - f"Weight is consumed by multiple different operators.\n" - f"Replicating different operators on no-overlapping device group is not supported yet.\n" - f"FullTensor Weight: {fweight}\n" - f"Consumers:\n{nl.join([repr(node) for node in consumers[fweight]])}" - ) - else: - raise NotImplementedError( - f"Weight is consumed by multiple different operators.\n" - f"Replicating different operators on partial-overlapping device group is not supported yet.\n" - f"FullTensor Weight: {fweight}\n" - f"Consumers:\n{nl.join([repr(node) for node in consumers[fweight]])}" - ) + # get devices + devices = [] + for sub_grad in diff_grads: + sub_grad_devices = [node.device[0] for node in diff_grads[sub_grad]] + sub_grad_devices.sort() + devices.append(sub_grad_devices) + devices = np.array(devices, dtype=int).transpose((1, 0)) + for group_devices in devices: + group_devices = set(int(devid) for devid in group_devices) + group_devices = list(group_devices) + group_devices.sort() + weights[fweight][sub_weight] = group_devices reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() - for ftensor, subtensors in weights.items(): - # TODO: check no overlapping (not same) weights on a device + for subtensors in weights.values(): for subw in subtensors: if len(subtensors[subw]) == 1: continue diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 0f0b0ae6..97f19a3b 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -353,7 +353,7 @@ def create_bwop(self, fwop: IRFwOperation) -> Union[IRBpOperation, IRBpOperation if isinstance(fwop, IRFwOperation): bwop = IRBpOperation(ograds, igrads) else: - bnodes = [fnode.mirror for fnode in fwop.nodes() if fnode.mirror is not None] + bnodes = [fnode.mirror for fnode in fwop.nodes() if fnode.mirror is not None][::-1] bwop = IRSegment(bnodes, ograds, igrads) IRCell.make_pair(fwop, bwop) return bwop diff --git a/cube/ir/operator.py b/cube/ir/operator.py index f88cec8c..2fcbd0fe 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -109,6 +109,7 @@ def replicate(self): """ cpy = copy.copy(self) cpy._device = list() + cpy._id = self.cid # reset input and output cpy.reset_inputs(len(self.inputs())) for idx, input in enumerate(self.inputs()): @@ -165,7 +166,7 @@ def replicate(self): """ cpy = copy.copy(self) cpy._device = list() - cpy._id = IDGenerator().gen_cell_id() + cpy._id = self.cid # reset input and output cpy.reset_inputs(len(self.inputs())) for idx, input in enumerate(self.inputs()): @@ -200,7 +201,7 @@ def replicate(self): """ cpy = copy.copy(self) cpy._device = list() - cpy._id = IDGenerator().gen_cell_id() + cpy._id = self.cid # reset input and output cpy.reset_inputs(len(self.inputs())) for idx, input in enumerate(self.inputs()): From 51286feae29eb15525961f0dbeb737363fdac123 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 26 Oct 2022 22:55:28 +0800 Subject: [PATCH 1098/1892] add piper space, performance aligns with Megatron --- examples/nlp/gpt/policy/mpmd.py | 93 +++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 4 deletions(-) diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index ec4c1783..d8fb09ee 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -100,7 +100,7 @@ def PAS1F1B(graph: IRGraph, resource): 1F1B scheduling """ num_stages = resource.ngpus - num_microbatch = resource.ngpus * 8 + num_microbatch = 16 # group to transformer layers fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] @@ -133,10 +133,10 @@ def PASMegatron(graph: IRGraph, resource): """ 1F1B scheduling """ - dp_size = 2 + dp_size = 1 tp_size = 2 pp_size = resource.ngpus // (dp_size * tp_size) - num_microbatch = pp_size * 2 + num_microbatch = 16 # device mesh dp_groups, pp_groups, tp_groups = \ @@ -181,12 +181,97 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) elif fnode.name == 'linear': # the last embeding linear fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + elif fnode.name == 'sum': + fnodes = _tp(graph, fnode, [0]*tp_size, idx=0, dim=2, num=tp_size) else: fnodes = _replica(graph, fnode, [0]*tp_size) # data parallel for tp_idx, fnode in enumerate(fnodes): dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] - print(dp_devices) + batch_dim = fnode.input(0).shape.index(bs) + _tp(graph, fnode, devs=dp_devices, idx=0, dim=batch_dim, num=dp_size) + + strategy = IRSchedule1F1B(graph, num_microbatch) + graph.predef_sched(strategy) + print(graph.extra_repr()) + return graph + + +def PASPiperSpace(graph: IRGraph, resource): + + # ================= Policy hyper parameter =================== + num_microbatch = 16 + # num_stages = 3 + # sub_meshes = [(1,4), (1,4), (1,8)] # (dp_size, tp_size) + # stage_layers = [(0,6), (6,12), (12,24)] # (start, end) + assert resource.ngpus == 8 + num_stages = 3 + sub_meshes = [(1,2), (1,2), (1,4)] # (dp_size, tp_size) + stage_layers = [(0,6), (6,12), (12,24)] # (start, end) + # ============================================================ + + # checking + transformers = _group_to_transformers(graph.select(ntype=IRFwOperation)) + assert len(stage_layers) == num_stages, f"Expect {num_stages} pipeline stages but got {len(stage_layers)} stage layer assignment." + nlayers = 0 + for sid, (start, end) in enumerate(stage_layers): + prev_end = stage_layers[sid-1][1] if sid > 0 else 0 + assert start == prev_end, f"Layers are not contiguous" + nlayers += end-start + assert nlayers == len(transformers), f"Total layer number {nlayers} != model layers {len(transformers)}" + # check gpus allocation + device_allocation = [] + devices = 0 + for mesh in sub_meshes: + dp_size, tp_size = mesh + assert dp_size >= 1 and tp_size >= 1 + stage_ngpus = dp_size * tp_size + device_allocation.append(stage_ngpus) + devices += stage_ngpus + assert devices <= resource.ngpus, f"Total GPUs in policy ({devices}) > resource capacity ({resource.ngpus})" + + # pipeline staging + fstages = [[] for _ in range(num_stages)] + for sid, (start, end) in enumerate(stage_layers): + for lid in range(start, end): + fstages[sid] += transformers[lid] + graph.staging(tuple(stages[0] for stages in fstages)) + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + + # setup data loader + dataloader = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[0] + + + # sub mesh of (dp_size, tp_size) + for sid, fstage in enumerate(fstages): + dp_size, tp_size = sub_meshes[sid] + devices = np.arange(dp_size * tp_size, dtype=int) + sum(device_allocation[:sid]) + devices = devices.reshape((dp_size, tp_size)) + # setup dataloader + if sid == 0: + dls = _replica(graph, dataloader, [0] * tp_size) + for tp_idx, dl in enumerate(dls): + dp_devices = list(int(devid) for devid in devices[:,tp_idx].flatten()) + dp_dls = graph.partition(dl, dl.algorithms('data'), num=dp_size) + for devid, dp_dl in zip(dp_devices, dp_dls): + graph.assign(dp_dl, devid) + for fnode in fstage.nodes(): + if len(fnode.inputs()) == 0: continue # anchor + # tensor parallel -- FIXME: current restriction needs replica happen before partition + if fnode.name == 'self_attention' or fnode.name == 'feedforward': + fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + elif fnode.name == 'embedding': + fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + elif fnode.name == 'linear': # the last embeding linear + fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + elif fnode.name == 'sum': + fnodes = _tp(graph, fnode, [0]*tp_size, idx=0, dim=2, num=tp_size) + else: + fnodes = _replica(graph, fnode, [0]*tp_size) + # data parallel + for tp_idx, fnode in enumerate(fnodes): + dp_devices = list(int(devid) for devid in devices[:,tp_idx].flatten()) batch_dim = fnode.input(0).shape.index(bs) _tp(graph, fnode, devs=dp_devices, idx=0, dim=batch_dim, num=dp_size) From 15b80b19b2d1a19fa70d944893bcc2f9bfc92f6c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 27 Oct 2022 13:30:23 +0800 Subject: [PATCH 1099/1892] code generation for distributed filesystem --- cube/compiler.py | 20 ++++++-------------- cube/runtime/device.py | 6 ++++-- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index e1cf6006..ea4c3b93 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -17,6 +17,7 @@ from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen from cube.profiler.timer import print_each_rank +from cube.runtime.device import DeviceGroup from cube.runtime.syndata import CubeDataLoader, SciLoopVariables from cube.program import Program, SemanticDataLoader, SemanticModel @@ -70,13 +71,7 @@ def train_step(model, dataloader): model_graph = model.get_graph() ir_dataloader = SemanticDataLoader(dataloader) - if torch.distributed.is_initialized(): - # multiple device - myrank = torch.distributed.get_rank() - local_rank = cube.runtime.device.DeviceGroup().local_rank - else: - # single device - myrank = local_rank = 0 + myrank = DeviceGroup().rank def _load_tschedule_fn(filename) -> Callable: import importlib.util @@ -101,7 +96,7 @@ def decorator(fn: Callable) -> Callable: print_each_rank(f'loading existed schedule from {filename} ...') return _load_tschedule_fn(filename) - if local_rank == 0: + if DeviceGroup().local_rank == 0: compile_start = time.time() @@ -171,15 +166,12 @@ def decorator(fn: Callable) -> Callable: # execplan.graph.reset_dependency() # execplan.analyze(outfile='execplan.png') - if torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size() - else: - world_size = 1 - + local_world_size = DeviceGroup().local_world_size # code generation mgener = ModelCodeGen(execplan) sgener = ScheduleCodeGen(execplan) - for rank in range(world_size): + for local_rank in range(local_world_size): + rank = DeviceGroup().node_rank * local_world_size + local_rank fname = filename.format(rank) # generate spatial module code mgener.gen(rank, outfile=fname, attach=False) diff --git a/cube/runtime/device.py b/cube/runtime/device.py index 6b6cf157..9ed7a968 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -18,8 +18,9 @@ def __init__(self): print(f"DeviceGroup init using single device mode...") self.rank = 0 self.world_size = 1 + self.local_world_size = 1 self.local_rank = 0 - self.node_id = 0 + self.node_rank = 0 self.groups = dict() torch.cuda.set_device(0) else: @@ -29,8 +30,9 @@ def __init__(self): self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() # assume each node has the same device number + self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) self.local_rank = int(os.environ.get('LOCAL_RANK')) - self.node_id = self.rank // torch.cuda.device_count() + self.node_rank = int(os.environ.get('GROUP_RANK')) self.groups = dict() torch.cuda.set_device(self.local_rank) From dd26a59d9e9a73b1251eedbe0586273b11c4b557 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 27 Oct 2022 13:40:32 +0800 Subject: [PATCH 1100/1892] clean useless package --- cube/graph/gener/layout.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index a5088eec..f52d1520 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -1,11 +1,10 @@ from typing import Callable, Dict, List, Tuple, Optional import copy import numpy as np -from regex import R -from cube.ir.cten import IRCell +from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor -from cube.ir.tensor import IndexMap, ValueMap +from cube.ir.tensor import ValueMap from cube.ir.adapter.prim import IRAdapterPrim from cube.ir.adapter.prim import AllGatherPrim # d2r @@ -18,7 +17,6 @@ from cube.ir.adapter.prim import BroadcastPrim from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim -from cube.runtime.device import DeviceGroup TShape = Tuple[int, ...] From 291f275b902f099ec110d08e0409f339028de6c4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 27 Oct 2022 12:46:31 +0000 Subject: [PATCH 1101/1892] benchmark gpt script update --- benchmark/megatron/benchmark_gpt.sh | 65 ++++++++++++++++++----------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/benchmark/megatron/benchmark_gpt.sh b/benchmark/megatron/benchmark_gpt.sh index 9580a124..7e973e1c 100755 --- a/benchmark/megatron/benchmark_gpt.sh +++ b/benchmark/megatron/benchmark_gpt.sh @@ -1,20 +1,26 @@ - -# get megatron +# setup megatron # git clone https://github.com/NVIDIA/Megatron-LM.git +# pip install regex + +# setup apex +# git clone https://github.com/NVIDIA/apex +# cd apex +# pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . +# cd .. cp pretrain_gpt_synthetic.py ./Megatron-LM/ -GPUS=8 +NODE_GPUS=8 +PP=4 +TP=4 -GPT_ARGS="--num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --seq-length 1024 \ - --max-position-embeddings 1024 \ - --micro-batch-size 1 \ - --global-batch-size 1 \ +GPT_ARGS="--num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ --lr 0.00015 \ - --train-iters 200 \ + --train-iters 10 \ --lr-decay-iters 320000 \ --lr-decay-style cosine \ --lr-warmup-fraction .01 \ @@ -26,29 +32,38 @@ GPT_ARGS="--num-layers 24 \ --no-bias-dropout-fusion \ --no-async-tensor-model-parallel-allreduce \ --no-gradient-accumulation-fusion \ + --checkpoint-activations \ + --log-interval 1 \ --num-workers 0" -DISTRIBUTED_ARGS="--nproc_per_node $GPUS \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr localhost \ - --master_port 6000" +# --checkpoint-activations + +SINGLE_NODE="--nproc_per_node $NODE_GPUS \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr localhost \ + --master_port 6000" + +MULTI_NODE="--nproc_per_node $NODE_GPUS \ + --nnodes 2 \ + --node_rank ${NODE_RANK} \ + --master_addr worker-0 \ + --master_port 6012" cd Megatron-LM -OMP_NUM_THREADS=4 python -m torch.distributed.launch $DISTRIBUTED_ARGS \ +OMP_NUM_THREADS=4 python -m torch.distributed.launch $MULTI_NODE \ pretrain_gpt_synthetic.py $GPT_ARGS \ - --tensor-model-parallel-size ${GPUS}\ - --pipeline-model-parallel-size 1 \ - --DDP-impl torch + --global-batch-size 128 \ + --micro-batch-size 4 \ + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --DDP-impl local + # OMP_NUM_THREADS=4 python -m torch.distributed.launch \ -# --nproc_per_node 1 \ -# --nnodes 1 \ -# --node_rank 0 \ -# --master_addr localhost \ -# --master_port 6000 \ +# --nproc_per_node 1 --master_addr localhost --master_port 6112 \ # pretrain_gpt_synthetic.py -h cd .. \ No newline at end of file From eed69c53e903312be80d82b48f9265de1de870bc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 27 Oct 2022 12:49:00 +0000 Subject: [PATCH 1102/1892] fix recompute in pipline stages --- cube/algorithm/ops/dimops.py | 2 +- cube/codegen/codegen.py | 3 +-- cube/graph/graph.py | 7 +++++++ cube/program.py | 4 +--- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 030d2a9b..d9a77c27 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -99,7 +99,7 @@ def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List else: adim, reduce = 'Value', None color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' - print(f"try split {node.name}: {node.anno} | dim: {adim} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") + print(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") if not satisfy: return None rule: TransformRule = self.infer(idx, dim, num) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 024035a3..99e0a778 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -682,8 +682,7 @@ def recompute(tensor_2222): node_codes = [] nodes : List[IRCell] = [node for i, node in i_nodes] - - subseg = self.execplan.graph.create_segment(nodes) + subseg = segment.create_segment(nodes) inputs = [t for t in subseg.inputs() if not t.is_attr()] input_names = [self.tensor_naming(t) for t in inputs] diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 77f4f481..ed1cc3ea 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -8,6 +8,7 @@ """ from typing import Union, Tuple, List, Optional, Dict +from cube.graph.function.anchor import IRGraphAnchor from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator @@ -750,6 +751,12 @@ def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: "Cross-segment recompute is not allowed yet" recompute_group_id: int = IDGenerator().gen_cell_id() for fnode in nodes: + if isinstance(fnode, IRGraphAnchor): + continue + # pytorch limitation + if all(not t.requires_grad for t in fnode.inputs() if isinstance(t, IRSubTensor) and (not t.is_attr())): + print(f"skipping recompute node: {fnode}\n\tbecause all its input tensors doesn't require grad.") + continue fnode.recompute = recompute_group_id return True diff --git a/cube/program.py b/cube/program.py index 2acefa7d..e9a8a982 100644 --- a/cube/program.py +++ b/cube/program.py @@ -124,9 +124,7 @@ def __init__(self, model: torch.nn.Module, input_shapes): """ Create semantic model based on AI Scientist description. """ - local_rank = 0 - if torch.distributed.is_initialized(): - local_rank = DeviceGroup().local_rank + local_rank = DeviceGroup().local_rank if local_rank == 0: self.ir_graph = parser.convert_model( model, input_shapes=input_shapes From 4b8ae67b40d4b1c569ab86721f1f249ea87208f7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 28 Oct 2022 12:33:38 +0800 Subject: [PATCH 1103/1892] no need for input shapes of semantic model. Remove parser save of parameters if load_content is False --- cube/compiler.py | 8 +++--- cube/graph/parser/converter.py | 4 ++- cube/graph/parser/parser.py | 5 +++- cube/program.py | 51 +++++++++++++++++++++++++--------- 4 files changed, 49 insertions(+), 19 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index ea4c3b93..96f0c52e 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -18,7 +18,7 @@ from cube.profiler.timer import print_each_rank from cube.runtime.device import DeviceGroup -from cube.runtime.syndata import CubeDataLoader, SciLoopVariables +from cube.runtime.syndata import CubeDataLoader from cube.program import Program, SemanticDataLoader, SemanticModel @@ -68,7 +68,7 @@ def train_step(model, dataloader): if callable(PAS): PAS = (PAS,) - model_graph = model.get_graph() + model.save_content = load_content ir_dataloader = SemanticDataLoader(dataloader) myrank = DeviceGroup().rank @@ -103,7 +103,7 @@ def decorator(fn: Callable) -> Callable: resource = cube.runtime.resource.EnvResource() # run once to get model structure and tensor shape - outputs = fn(model_graph, ir_dataloader) + outputs = fn(model, ir_dataloader) Program().finalize() if outputs is None: outputs = [] @@ -191,7 +191,7 @@ def decorator(fn: Callable) -> Callable: # load module filename = filename.format(myrank) print_each_rank(f'loading generated module from {filename} ...') - model.load_module(filename, load_content=load_content) + model.load_module(filename) if torch.distributed.is_initialized(): torch.distributed.barrier() diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 37f034de..229b23dd 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -7,7 +7,8 @@ import torch def convert_model(model: torch.nn.Module, - input_shapes: Optional[ List[List[int],] ] = None) -> IRGraph: + input_shapes: Optional[ List[List[int],] ] = None, + save_content: bool = True) -> IRGraph: """ Convert toch.nn.Module based model into IRGraph """ @@ -17,6 +18,7 @@ def convert_model(model: torch.nn.Module, print(ex) raise RuntimeError("Cannot convert module into torchscript moudle.") module_name = smodule.original_name + ScriptModuleParser.save_content = save_content inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) for input in inputs: if isinstance(input, IRFullTensor): diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index fff5a392..ab9f2b04 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -34,6 +34,8 @@ class ScriptNodeKind(enum.Enum): class ScriptModuleParser: + save_content: bool = True + @staticmethod def parse_module(module, input_shapes: Optional[ Tuple[List[int],] ] = None, @@ -95,7 +97,8 @@ def parse_module(module, frame.pop_var() frame.pop_attr() - frame.save_attr_content() + if ScriptModuleParser.save_content: + frame.save_attr_content() return input_val, all_ir_nodes, output_val @staticmethod diff --git a/cube/program.py b/cube/program.py index e9a8a982..a920f49e 100644 --- a/cube/program.py +++ b/cube/program.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Optional from cube.graph.torch_dtype_mapping import DType2IRDType from cube.ir.cten import IRCell, IRTensor @@ -120,41 +120,66 @@ def __next__(self): class SemanticModel: - def __init__(self, model: torch.nn.Module, input_shapes): + def __init__(self, model: Optional[torch.nn.Module], input_shapes=None): """ Create semantic model based on AI Scientist description. + + @param model Optional[torch.nn.Module]: Model description. Each device of local_rank == 0 needs to provide. + @param input_shapes Any: to compatable with previous interface. No more need. """ - local_rank = DeviceGroup().local_rank - if local_rank == 0: - self.ir_graph = parser.convert_model( - model, input_shapes=input_shapes - ) - else: - self.ir_graph = None + if DeviceGroup().local_rank == 0: + assert isinstance(model, torch.nn.Module), f"device of local_rank == 0 must provide model" + self.model = model + self.input_shapes = None + self.ir_graph = None self._loaded_module: CubeModule = None + self._save_content = True + + @property + def save_content(self) -> bool: + return self._save_content + + @save_content.setter + def save_content(self, val: bool): + self._save_content = val def get_graph(self): return self.ir_graph - def load_module(self, filename: str, load_content=True): + def load_module(self, filename: str): import importlib.util spec = importlib.util.spec_from_file_location("GenModel", filename) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) self._loaded_module = module.GenModel().cuda() - if load_content: + if self.save_content: print_each_rank("> loading parameter content...") # TODO: make hardcode ./fullmodel.pt programmable self._loaded_module.load_attr_content('./fullmodel.pt') - def get_gen_module(self): + def get_gen_module(self) -> Optional[torch.nn.Module]: return self._loaded_module def clear_module(self): self._loaded_module = None def __call__(self, *args): + """ + Forward the model. + This will trigger torch.jit.script to parse the model. + """ if self._loaded_module: return self._loaded_module(*args) else: - return self.ir_graph(*args) \ No newline at end of file + assert all(isinstance(t, IRSubTensor) for t in args), f"Only support tensors as model inputs" + input_shapes = [tuple(t.shape) for t in args] + if DeviceGroup().local_rank == 0: + if self.ir_graph is None: + self.ir_graph = parser.convert_model( + self.model, input_shapes=input_shapes, save_content=self.save_content + ) + self.input_shapes = input_shapes + else: + assert tuple(self.input_shapes) == tuple(input_shapes), \ + f"Multiple forwarding of a same model, which require input shapes to be same." + return self.ir_graph(*args) From 49419f2d9e1c0fedce245c98bbeb4f9902e18570 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 31 Oct 2022 10:51:37 +0800 Subject: [PATCH 1104/1892] layer_norm to match node semantics --- cube/graph/function/dimops.py | 2 +- cube/graph/function/function.py | 46 +++++++++++++++++++++++++------ cube/runtime/function/function.py | 9 ++++++ 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 407bac08..7204c444 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -276,8 +276,8 @@ def create_shape_str(shape: Tuple[int], reduction: str = '', iterator: Optional[ e.g., ['a+', 'b+', 'c+'] @param shape List[int]: tensor shape + @param reduction (str): reduction type must be in '', '+' or '^' @param iterator Optional[Iterable]: identity iterators. If None, use string.ascii_lowercase - @param reduce (str): reduction type must be in '', '+' or '^' @return strs List[str]: each element in strs represents a dimension """ diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index e170c707..922bfbda 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -431,15 +431,43 @@ def Dropout(signature, inputs): def LayerNorm(signature, inputs): - input, normalized_shape, weight, bias, eps = inputs - if len(normalized_shape) != 1: - raise NotImplementedError("Only support normalized_shape to be int") - annos = [ - f'N *, ?, {normalized_shape[0]}, {normalized_shape[0]} -> N *', - f'N *, ?, ?, ? -> N *' - ] - return IRDimops(LayerNorm, 'layernorm', signature, annos, [input, normalized_shape, weight, bias], - eps=eps) + """ + torch.nn.functional.layer_norm(input, normliazed_shape, weight=None, bias=None, eps) + cube.runtime.function.layer_norm(input, weight, bias, normliazed_shape, eps) + """ + if 'torch.' in signature: + tensor, normalized_shape, weight, bias, eps = inputs + assert isinstance(normalized_shape, list), f"normalized_shape for layer_norm can only be List[int]" + else: + tensor, weight, bias, normalized_shape, eps = inputs + letters = iter(string.ascii_lowercase) + einput = ShapeAnno.create_shape_str(tensor.shape, iterator=letters) + eoutput = copy.copy(einput) + ndims = len(tensor.shape) + for dim in range(len(normalized_shape)): + einput[ndims-1-dim] += '^' + eoutput[ndims-1-dim] += '^' + assert not (bias is None is weight is not None), f"Not support for None of weight and parameter of bias" + einputs, inputs = [einput], [tensor] + kwargs = {} + if weight is not None: + eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) + einputs.append(eweight) + inputs.append(weight) + else: + kwargs['weight'] = weight + if bias is not None: + ebias = ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) + einputs.append(ebias) + inputs.append(bias) + else: + kwargs['bias'] = bias + anno = OpAnno.create_op_str(einputs, [eoutput]) + kwargs['normalized_shape'] = normalized_shape + kwargs['eps'] = eps + signature = 'cube.runtime.function.layer_norm' + print(anno) + return IRDimops(LayerNorm, 'layernorm', signature, [anno], inputs, **kwargs) def Sum(signature, inputs): diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index b80e9ae3..4267d8ce 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -86,6 +86,15 @@ def embedding(input: torch.Tensor, weight: torch.Tensor, padding_idx: Optional[i return output +def layer_norm(input: torch.Tensor, + weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], + normalized_shape: List[int], eps: float = 1e-05) -> torch.Tensor: + """ + LayerNorm + """ + return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps) + + # 'torch.select_scatter' isn't supported by Torch2ONNX yet. # Implement it with 'torch.masked_scatter' which is supported with ONNX opset=11. def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): From 20fc37dfb1bd49dc547121d984fc0b840c50f75d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 31 Oct 2022 10:53:51 +0800 Subject: [PATCH 1105/1892] remove print --- cube/graph/function/function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 922bfbda..aa052be4 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -466,7 +466,6 @@ def LayerNorm(signature, inputs): kwargs['normalized_shape'] = normalized_shape kwargs['eps'] = eps signature = 'cube.runtime.function.layer_norm' - print(anno) return IRDimops(LayerNorm, 'layernorm', signature, [anno], inputs, **kwargs) From 485c9024045c5928c04624c3aecb5efd4b0a0e68 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 31 Oct 2022 14:09:05 +0800 Subject: [PATCH 1106/1892] profiler database and refine code --- cube/graph/function/function.py | 20 ++- cube/graph/parser/mapping.py | 95 +++++++++++- cube/graph/parser/parser.py | 3 +- cube/graph/parser/register.py | 11 +- cube/graph/torch_dtype_mapping.py | 77 ---------- cube/profiler/__init__.py | 2 +- cube/profiler/database.py | 244 +++++++++++++++--------------- cube/program.py | 2 +- 8 files changed, 244 insertions(+), 210 deletions(-) delete mode 100644 cube/graph/torch_dtype_mapping.py diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index aa052be4..d872127a 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -7,6 +7,7 @@ from cube.ir.cten import IRTensor from cube.ir.tensor import IRSubTensor +from cube.ir.dtype import IRDType from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRConv2D from cube.graph.function.conv import IRConv3D @@ -17,8 +18,9 @@ from cube.graph.function.scatter import IRSelectScatter from cube.graph.function.repeat import IRRepeat from cube.graph.function.anchor import IRGraphAnchor -from cube.ir.dtype import IRDType -from cube.graph.torch_dtype_mapping import DType2IRDType, TorchScalarTypeEnumMap + + +ErasedDevice = 'str' def Identity(signature, inputs: List[IRTensor]): @@ -50,12 +52,14 @@ def BatchLinear(signature, inputs): def Zeros(signature, - inputs: Tuple[ List[int], Optional[int], Optional[Any], 'ErasedDevice', Optional[bool] ]): + inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): # zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor # # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. + from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap + size, dtype_underlying, layout, _erased_device, pin_memory = inputs # TODO parameters to support, currently they are all None @@ -77,7 +81,7 @@ def Zeros(signature, return IRZeros(signature, size, 'zeros', ir_dtype) def Ones(signature, - inputs: Tuple[ List[int], Optional[int], Optional[Any], 'ErasedDevice', Optional[bool] ]): + inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): # ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor size, dtype_underlying, layout, _erased_device, pin_memory = inputs @@ -85,6 +89,7 @@ def Ones(signature, # TODO parameters to support, currently they are all None assert layout is None assert pin_memory is None + from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap if dtype_underlying is not None: # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, @@ -101,7 +106,7 @@ def Ones(signature, return IROnes(signature, size, 'ones', ir_dtype) def Rand(signature, - inputs: Tuple[ List[int], Optional[int], Optional[Any], 'ErasedDevice', Optional[bool] ]): + inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): # ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor size, dtype_underlying, layout, _erased_device, pin_memory = inputs @@ -109,6 +114,7 @@ def Rand(signature, # TODO parameters to support, currently they are all None assert layout is None assert pin_memory is None + from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap if dtype_underlying is not None: # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, @@ -125,7 +131,7 @@ def Rand(signature, return IRRand(signature, size, 'rand', ir_dtype) def NewTensor(signature, - inputs: Tuple[ list, Optional[int], 'ErasedDevice', bool ]): + inputs: Tuple[ list, Optional[int], ErasedDevice, bool ]): # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor # # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of @@ -135,6 +141,7 @@ def NewTensor(signature, # TODO parameters to support, currently they are all None assert requires_grad == False + from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap if dtype_underlying is not None: # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, @@ -184,6 +191,7 @@ def ToTensor(signature, opt_memory_format : Optional[int] tensor, dtype_underlying, non_blocking, copy, opt_memory_format = inputs + from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap dtype : torch.dtype = TorchScalarTypeEnumMap.map(dtype_underlying) ir_dtype : IRDType = DType2IRDType.map(dtype) diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index b6a2cafb..b080b4b4 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -1,11 +1,16 @@ """ Mapping of - Signature -> IROperator + signature -> IROperator + torch.dtype -> cube.ir.IRDType + cube.ir.IRDType -> torch.dtype """ +import torch + from typing import Callable, Dict, Union from functools import partial import cube.graph.function as function +import cube.ir as ir from cube.ir.operator import IRFwOperation @@ -163,3 +168,91 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # customized operator code: signature -> code kOpCodeDef: Dict[str, str] = {} + + +class DType2IRDType: + + @staticmethod + def map(dtype: torch.dtype): + """ + Map the torch dtype to IRDType + """ + return DType2IRDType.kDtypeMap[dtype] + + kDtypeMap = { + torch.double: ir.float64, + torch.float64: ir.float64, + torch.float32: ir.float32, + torch.float : ir.float32, + torch.float16: ir.float16, + torch.half : ir.float16, + torch.uint8 : ir.uint8, + torch.int8 : ir.int8, + torch.int16 : ir.int16, + torch.short : ir.int16, + torch.int32 : ir.int32, + torch.int : ir.int32, + torch.int64 : ir.int64, + torch.long : ir.int64, + torch.bool : ir.boolean + } + + +class IRDType2TorchDType: + + @staticmethod + def map(ir_dtype: ir.IRDType): + """ + Map the IRDtype to torch dtype + """ + return IRDType2TorchDType.kDtypeMap[ir_dtype] + + kDtypeMap = {val: key for key, val in DType2IRDType.kDtypeMap.items()} + + +# see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h +# +# ScalarType enum is totally a PyTorch-internal object. Neither itself nor its underlying ints +# are accessible from its Python frontend. +class TorchScalarTypeEnumMap: + + @staticmethod + def map(underlying: int) -> torch.dtype: + + assert isinstance(underlying, int), """ + This function is to convert an underlying 'int' for a Torch-internal 'at::ScalarType' enum + to its corresponding Python-frontend 'torch.dtype' enum. + """ + + dtype = TorchScalarTypeEnumMap._fields[underlying] + + assert dtype is not None, f""" + Referenced to an unsupported ScalarType with underlying int being {underlying} + """ + + return dtype + + # Less used dtypes are masked out because PyTorch keeps **exposing and hiding** them recently + # from a view of Python frontend. + _fields = [ + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.half, + torch.float32, + torch.float64, + None, #torch.complex32, # complexHalf + None, #torch.complex64, # complexFloat + None, #torch.complex128, # complexDouble + torch.bool, + None, #torch.qint8, + None, #torch.quint8, + None, #torch.qint32, + None, #torch.bfloat16, + None, #torch.quint4x2, + None, #torch.quint2x4, + ] + + assert len(_fields) == 18, "Do not remove any item, mask it out with None" diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index ab9f2b04..0bc04e61 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -7,8 +7,7 @@ from cube.ir.tensor import IRFullTensor import cube.ir as ir from cube.graph.parser.frame import Frame -from cube.graph.parser.mapping import Sign2Op -from cube.graph.torch_dtype_mapping import DType2IRDType +from cube.graph.parser.mapping import Sign2Op, DType2IRDType _refmodule = torch.nn.Module() diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index bba41f34..6be4e316 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -6,12 +6,11 @@ import inspect import torch -from cube.graph.function.dimops import IRDimops, OpAnno, TransformRule - +from cube.graph.function.dimops import IRDimops, OpAnno from cube.graph.parser.mapping import Sign2Op -def register(anno: str, name: Optional[str] = None, rules: Optional[List[TransformRule]] = None): +def register(anno: str, name: Optional[str] = None, rules: Optional[List] = None): """ Register a function with einop annotations. @@ -30,6 +29,12 @@ def funcname(x: torch.Tensor, b: int = 4): Note: for Optional[torch.Tensor] type, user should annotate the dimension when the input is not None. + + @param anno str: operator annotation + @param name str: operator name + @param rules Optional[List[TransformRule]]: additional transformation rules. + + @return fn Callable: the runtime function """ def decorator(fn: Callable): if not callable(fn): diff --git a/cube/graph/torch_dtype_mapping.py b/cube/graph/torch_dtype_mapping.py deleted file mode 100644 index 0787c913..00000000 --- a/cube/graph/torch_dtype_mapping.py +++ /dev/null @@ -1,77 +0,0 @@ -from cube import ir -import torch - -class DType2IRDType: - - @staticmethod - def map(dtype: torch.dtype): - """ - Map the torch dtype to IRDType - """ - return DType2IRDType.kDtypeMap[dtype] - - kDtypeMap = { - torch.double: ir.float64, - torch.float64: ir.float64, - torch.float32: ir.float32, - torch.float : ir.float32, - torch.float16: ir.float16, - torch.half : ir.float16, - torch.uint8 : ir.uint8, - torch.int8 : ir.int8, - torch.int16 : ir.int16, - torch.short : ir.int16, - torch.int32 : ir.int32, - torch.int : ir.int32, - torch.int64 : ir.int64, - torch.long : ir.int64, - torch.bool : ir.boolean - } - - -# see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h -# -# ScalarType enum is totally a PyTorch-internal object. Neither itself nor its underlying ints -# are accessible from its Python frontend. -class TorchScalarTypeEnumMap: - - @staticmethod - def map(underlying: int) -> torch.dtype: - - assert isinstance(underlying, int), """ - This function is to convert an underlying 'int' for a Torch-internal 'at::ScalarType' enum - to its corresponding Python-frontend 'torch.dtype' enum. - """ - - dtype = TorchScalarTypeEnumMap._fields[underlying] - - assert dtype is not None, f""" - Referenced to an unsupported ScalarType with underlying int being {underlying} - """ - - return dtype - - # Less used dtypes are masked out because PyTorch keeps **exposing and hiding** them recently - # from a view of Python frontend. - _fields = [ - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.half, - torch.float32, - torch.float64, - None, #torch.complex32, # complexHalf - None, #torch.complex64, # complexFloat - None, #torch.complex128, # complexDouble - torch.bool, - None, #torch.qint8, - None, #torch.quint8, - None, #torch.qint32, - None, #torch.bfloat16, - None, #torch.quint4x2, - None, #torch.quint2x4, - ] - - assert len(_fields) == 18, "Do not remove any item, mask it out with None" \ No newline at end of file diff --git a/cube/profiler/__init__.py b/cube/profiler/__init__.py index e349da1d..6bf47044 100644 --- a/cube/profiler/__init__.py +++ b/cube/profiler/__init__.py @@ -1,2 +1,2 @@ from cube.profiler.timer import CudaTimer -from cube.profiler.estimator import Estimator \ No newline at end of file +from cube.profiler.database import ProfileDataBase diff --git a/cube/profiler/database.py b/cube/profiler/database.py index f1cbfc6a..85630166 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -2,18 +2,22 @@ Usage: python -m cube.profiler.database --export ./profile.dat.json """ - -from typing import Callable, Tuple, Union, Optional, Dict, NewType +from typing import Callable, Tuple, Union, Optional, Dict, NewType, List import torch import time import os import json +import cube +from cube.ir.cten import IRTensor +from cube.ir.operator import IRFwOperation +from cube.graph.parser.mapping import Sign2Op, IRDType2TorchDType + Shapes = NewType('Shapes', Tuple[Tuple[int]]) DTypes = NewType('DTypes', Tuple[torch.dtype]) ShapesDTypes = NewType('ShapesDTypes', Tuple[Shapes, DTypes]) -NameOrFunc = NewType('NameOrFunc', Union[str, Callable]) +NameOrFunc = Union[str, Callable] class CompProfiler: @@ -99,31 +103,59 @@ def __init__(self, filename: Optional[str] = None) -> None: Create a database for profiling result """ - self._data: Dict[str, Dict[str, float]] = dict() + self._data: Dict[str, Dict[str, Tuple[float, float, int]]] = dict() if filename is not None: self.load(filename) - def profile(self, func: Callable, shapes: Shapes, dtypes: DTypes, **kwargs): - """! - Profile the function and log into the database - - @param func Callable: the callable function, e.g., torch.nn.functional.linear - @param shapes Tuple[Tuple[int]]: the shapes of each input tensor - @param dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 - @param backward bool: whether profile backward times. Default true. - @param kwargs Dict: other keyword argument for func call. + @staticmethod + def get_func(node: IRFwOperation) -> Tuple[Callable, Shapes, DTypes, Dict]: """ - try: - assert callable(func), "func should be callable" - fw_span, bw_span, memory = CompProfiler.profile(func, shapes, dtypes, **kwargs) - name = func.__name__ - key = self.serialize(shapes, dtypes) - self.log(name, key, fw_span, bw_span, memory) - print(f'profiled {func.__name__} | shapes: {shapes} | dtypes: {dtypes} => fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} | mem: {memory}') - except Exception as e: - print(f'fail to profile {func.__name__}: reason: {str(e)}') - - def log(self, name: str, key: str, fw_span: float, bw_span: float, memory: float): + Get function call and its arguments from a cude IRGraph node + """ + assert isinstance(node, IRFwOperation), f"Only support profiling forward operation but got {type(node)}" + if node.signature in Sign2Op.kOpCodeDef: + code_impl: str = Sign2Op.kOpCodeDef[node.signature] + local = {} + exec(code_impl, globals(), local) + fn = list(local.values())[0] + else: + fn = eval(node.signature) + shapes, dtypes = [], [] + for t in node.inputs(): + assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) + return fn, shapes, dtypes, node.kwargs + + def profile(self, node: IRFwOperation, device: Optional[int] = None): + """ + Profile a forward node in IRGraph on a specific device (default current device) + + @param node IRFwOperation: node of IRGraph + @param device int: the node + + @return fw_span float: forward span in milliseconds + @return bw_span float: backward span in milliseconds + @return mem int: peak memory consumpiton of forward + backward procedure. + """ + fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) + + if isinstance(device, int): + orig_device = torch.cuda.current_device() + torch.cuda.set_device(device) + + # run profiling + fw_span, bw_span, memory = CompProfiler.profile(fn, shapes, dtypes, **kwargs) + # log to database + key = self._serialize(node) + self.insert(node.signature, key, fw_span, bw_span, memory) + print(f'profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} => fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} | mem: {memory}') + + if isinstance(device, int): + torch.cuda.set_device(orig_device) + return fw_span, bw_span, memory + + def insert(self, name: str, key: str, fw_span: float, bw_span: float, memory: float): """ log the span of a function name with key """ @@ -132,86 +164,94 @@ def log(self, name: str, key: str, fw_span: float, bw_span: float, memory: float self._data[name] = dict() self._data[name][key] = (fw_span, bw_span, memory) - def query(self, func: NameOrFunc, shapes: Shapes, dtypes: DTypes) -> float: - """! - Get the performance number of the function name and its key - - @param name str: function name - @param shapes Tuple[Tuple[int]]: the shape of each input tensor - @param dtypes Tuple[torch.dtype]: the dtype of each tensor - - @return (fw_span, bw_span, mem) (float, float, int): the performance number + def exist(self, node: IRFwOperation) -> bool: """ - name = func if isinstance(func, str) else func.__name__ - key = self.serialize(shapes, dtypes) - return self._data[name][key] - - def exist_item(self, func: NameOrFunc, shapes: Shapes, dtypes: DTypes) -> bool: - """! - Check if the required data exists + Check if the node has the performance recorded in the database - @param name Union[str, Callable]: function name - @param shapes Tuple[Tuple[int]]: the shape of each input tensor - @param dtypes Tuple[torch.dtype]: the dtype of each tensor + @param node IRFwOperation: forward operation - @return exist bool: True if the item exists else False + @return exist bool: True if the performance is recorded, else False """ - name = func if isinstance(func, str) else func.__name__ - if name not in self._data: + key = self._serialize(node) + if node.signature not in self._data: return False - key = self.serialize(self, shapes, dtypes) - if key not in self._data[key]: + if key not in self._data[node.signature]: return False return True - def exist_func(self, func: NameOrFunc) -> bool: + def query(self, node: IRFwOperation) -> float: """! - Check if the required function exists + Get the performance number of a node in IRGraph - @param name Union[str, Callable]: function name + @param node IRFwOperation: node in IRGraph - @return exist bool: True if the function exists else False + @return fw_span float: forward span in milliseconds + @return bw_span float: backward span in milliseconds + @return mem int: peak memory consumpiton of forward + backward procedure. """ - name = func if isinstance(func, str) else func.__name__ - return name in self._data - - def shapes_and_dtypes(self, func: NameOrFunc) -> Tuple[ShapesDTypes]: + key = self._serialize(node) + if node.signature not in self._data: + return None + if key not in self._data[node.signature]: + return None + return self._data[node.signature][key] + + def query_func(self, signature, shapes, dtypes): """ - Get recorded shapes and dtypes of the func. - - @param func UnShapesDTypesion[str, Callable]: function name + Get performance number of given name (signature), shapes and dtypes + + @param signature str: function signature + @param shapes Tuple[Tuple[int]]: the shape of each input tensor + @param dtypes Tuple[torch.dtype]: the dtype of each tensor - @return shapes_and_dtypes Tuple[ShapesDTyptes] + @return fw_span float: forward span in milliseconds + @return bw_span float: backward span in milliseconds + @return mem int: peak memory consumpiton of forward + backward procedure. """ - name = func if isinstance(func, str) else func.__name__ - rets = [] - for shapes_dtypes_str in self._data[name].keys(): - (shapes, dtypes) = self.deserialize(shapes_dtypes_str) - rets.append((shapes, dtypes)) - return tuple(rets) - - def serialize(self, shapes: Shapes, dtypes: DTypes) -> str: + key = self._serialize(shapes, dtypes) + if signature not in self._data: + return None + if key not in self._data[signature]: + return None + return self._data[signature][key] + + def query_args(self, signature: str) -> Tuple[List[Shapes], List[DTypes]]: + """ + Get the recorded shapes and dtypes of + """ + item_shapes, item_dtypes = [], [] + if signature not in self._data: + return item_shapes, item_dtypes + for shapes_dtypes_str in self._data[torch.signature].keys(): + shapes, dtypes = self._deserialize(shapes_dtypes_str) + item_shapes.append(shapes) + item_dtypes.append(dtypes) + return item_shapes, item_dtypes + + def _serialize(self, node: IRFwOperation) -> str: """ Serialize the shapes, dtypes and kwargs into a string e.g., shapes: ((1024,), (1024,1024)) dtypes: (torch.float32, torch.float32) - => (1024,)-(1024,1024)=torch.float32-torch.float32 + => (1024,)-(1024,1024) : torch.float32-torch.float32 @param shapes Tuple[Tuple[int]]: the shape of each tensor @param dtypes Tuple[torch.dtype]: the dtype of each tensor @return key str: the serialized string """ + shapes, dtypes = [], [] + for t in node.inputs(): + assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) shapes = '-'.join(str(tuple(shape)) for shape in shapes) - if dtypes is not None: - dtypes = '-'.join(str(dtype) for dtype in dtypes) - else: - dtypes = '-'.join([str(torch.float32)] * len(shapes)) - return shapes + '=' + dtypes + dtypes = '-'.join(str(dtype) for dtype in dtypes) + return shapes + ' : ' + dtypes - def deserialize(self, key: str) -> ShapesDTypes: + def _deserialize(self, key: str) -> ShapesDTypes: """ De-serialize the key string to shapes and dtypes @@ -222,8 +262,7 @@ def deserialize(self, key: str) -> ShapesDTypes: @param key str: the serialized string @return shapes_and_dtypes ShapesDTypes: shapes and dtypes """ - shapes, dtypes = key.split('=') - print(shapes) + shapes, dtypes = key.split(' : ') shapes = tuple(eval(shape) for shape in shapes.split('-')) dtypes = tuple(eval(dtype) for dtype in dtypes.split('-')) return shapes, dtypes @@ -250,45 +289,12 @@ def load(self, file: str): with open(file, 'r') as f: self._data = json.load(f) - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser(description='database') - parser.add_argument('--export', type=str, default='./profile.dat.json', - help='saved profiling database') - args = parser.parse_args() - - db = ProfileDataBase() - - # profile - dtype = torch.float32 - # func: [ - # [shapes, dtypes, kwargs], - # ] - funcs = { - torch.nn.functional.gelu: [ - [((1024, 8, 2304),), (dtype,), {}] - ], - - torch.nn.functional.linear: [ - [([1024, 1, 2304], [2304, 2304]), (dtype, dtype), {}], - [([1024, 4, 2304], [2304, 2304]), (dtype, dtype), {}], - [([1024, 8, 2304], [2304, 2304]), (dtype, dtype), {}] - ], - - torch.nn.functional.softmax: [ - [((1024, 8, 2304),), (dtype,), dict(dim=-1)] - ] - } - - for func, keys in funcs.items(): - for shapes, dtypes, kwargs in keys: - db.profile(func, shapes, dtypes, **kwargs) - - db.dump(args.export, override=True) - - # db = ProfileDataBase(args.export) - # for shapes, dtypes in db.shapes_and_dtypes(torch.nn.functional.linear): - # span = db.query(torch.nn.functional.linear, shapes, dtypes) - # print(f'logged shapes: {shapes}, dtypes: {dtypes} => span: {span} ms') + def __repr__(self) -> str: + data = [] + for signature in self._data: + for key in self._data[signature]: + shapes, dtypes = self._deserialize(key) + fw_span, bw_span, mem = self._data[signature][key] + data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, fw span: {fw_span} ms, bw span: {bw_span} ms, mem {mem} bytes') + data = '\n'.join(data) + return data diff --git a/cube/program.py b/cube/program.py index a920f49e..b03b0dd0 100644 --- a/cube/program.py +++ b/cube/program.py @@ -1,5 +1,4 @@ from typing import List, Tuple, Optional -from cube.graph.torch_dtype_mapping import DType2IRDType from cube.ir.cten import IRCell, IRTensor from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -7,6 +6,7 @@ from cube.graph import IRGraph from cube.graph import parser +from cube.graph.parser.mapping import DType2IRDType from cube.runtime.syndata import CubeDataLoader from cube.runtime.module import CubeModule From cf737deba53add52e28186b2cd2bcc76ce0ad714 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 31 Oct 2022 16:04:05 +0800 Subject: [PATCH 1107/1892] add inference memory --- cube/profiler/database.py | 78 ++++++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 85630166..43f0dce6 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -38,7 +38,8 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, @return fw_span float: the time in milliseconds for forward time @return bw_span float: the time in milliseconds for backward time - @return memory int: the peak memory in bytes after forward + @return infer_memory int: the peak memory in bytes after inference of the function + @return train_memory int: the peak memory in bytes after forward with autograd enabled """ assert len(shapes) == len(dtypes), \ f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" @@ -60,10 +61,15 @@ def run_step(func, tensors, kwargs, backward: bool): torch.autograd.backward(outputs, grads) return outputs - # warmup - tic = time.time() - while time.time() - tic < warmup_sec: - run_step(func, tensors, kwargs, backward=True) + # profile inference peak memory + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + mtic = torch.cuda.max_memory_allocated() # in bytes + with torch.no_grad(): + run_step(func, tensors, kwargs, backward=False) + mtoc = torch.cuda.max_memory_allocated() # in bytes + infer_memory = mtoc - mtic torch.cuda.synchronize() torch.cuda.empty_cache() @@ -71,7 +77,12 @@ def run_step(func, tensors, kwargs, backward: bool): mtic = torch.cuda.max_memory_allocated() # in bytes outs = run_step(func, tensors, kwargs, backward=False) mtoc = torch.cuda.max_memory_allocated() # in bytes - memory = mtoc - mtic + train_memory = mtoc - mtic + + # warmup + tic = time.time() + while time.time() - tic < warmup_sec: + run_step(func, tensors, kwargs, backward=True) # profile forward only torch.cuda.synchronize() @@ -93,7 +104,7 @@ def run_step(func, tensors, kwargs, backward: bool): fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds bw_span = fwbw_span - fw_span - return fw_span, bw_span, memory + return fw_span, bw_span, infer_memory, train_memory class ProfileDataBase: @@ -132,11 +143,12 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): Profile a forward node in IRGraph on a specific device (default current device) @param node IRFwOperation: node of IRGraph - @param device int: the node + @param device int: the device that the node will execute on - @return fw_span float: forward span in milliseconds - @return bw_span float: backward span in milliseconds - @return mem int: peak memory consumpiton of forward + backward procedure. + @return fw_span float: the forward span time in milliseconds + @return bw_span float: the backward span time in milliseconds + @return infer_memory int: the peak memory in bytes after inference of the function + @return train_memory int: the peak memory in bytes after forward with autograd enabled """ fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) @@ -145,24 +157,36 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): torch.cuda.set_device(device) # run profiling - fw_span, bw_span, memory = CompProfiler.profile(fn, shapes, dtypes, **kwargs) + fw_span, bw_span, infer_memory, train_memory = \ + CompProfiler.profile(fn, shapes, dtypes, **kwargs) # log to database key = self._serialize(node) - self.insert(node.signature, key, fw_span, bw_span, memory) - print(f'profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} => fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} | mem: {memory}') + self.insert(node.signature, key, fw_span, bw_span, infer_memory, train_memory) + print( + f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " + f"=> fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} | " + f"infer mem: {infer_memory} | train mem: {train_memory}") if isinstance(device, int): torch.cuda.set_device(orig_device) - return fw_span, bw_span, memory + return fw_span, bw_span, infer_memory, train_memory - def insert(self, name: str, key: str, fw_span: float, bw_span: float, memory: float): + def insert(self, name: str, key: str, fw_span: float, bw_span: float, + infer_memory: int, train_memory: int): """ - log the span of a function name with key + log the span of a function name with key + + @param name str: the function signature + @param key str: the encoded shapes and dtypes of node inputs + @param fw_span float: the forward span time in milliseconds + @param bw_span float: the backward span time in milliseconds + @param infer_memory int: the peak memory in bytes after inference of the function + @param train_memory int: the peak memory in bytes after forward with autograd enabled """ assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = (fw_span, bw_span, memory) + self._data[name][key] = (fw_span, bw_span, infer_memory, train_memory) def exist(self, node: IRFwOperation) -> bool: """ @@ -179,15 +203,16 @@ def exist(self, node: IRFwOperation) -> bool: return False return True - def query(self, node: IRFwOperation) -> float: + def query(self, node: IRFwOperation) -> Tuple[float, float, int, int]: """! Get the performance number of a node in IRGraph @param node IRFwOperation: node in IRGraph - @return fw_span float: forward span in milliseconds - @return bw_span float: backward span in milliseconds - @return mem int: peak memory consumpiton of forward + backward procedure. + @return fw_span float: the forward span time in milliseconds + @return bw_span float: the backward span time in milliseconds + @return infer_memory int: the peak memory in bytes after inference of the function + @return train_memory int: the peak memory in bytes after forward with autograd enabled """ key = self._serialize(node) if node.signature not in self._data: @@ -196,7 +221,7 @@ def query(self, node: IRFwOperation) -> float: return None return self._data[node.signature][key] - def query_func(self, signature, shapes, dtypes): + def query_func(self, signature, shapes, dtypes) -> Tuple[float, float, int, int]: """ Get performance number of given name (signature), shapes and dtypes @@ -204,9 +229,10 @@ def query_func(self, signature, shapes, dtypes): @param shapes Tuple[Tuple[int]]: the shape of each input tensor @param dtypes Tuple[torch.dtype]: the dtype of each tensor - @return fw_span float: forward span in milliseconds - @return bw_span float: backward span in milliseconds - @return mem int: peak memory consumpiton of forward + backward procedure. + @return fw_span float: the forward span time in milliseconds + @return bw_span float: the backward span time in milliseconds + @return infer_memory int: the peak memory in bytes after inference of the function + @return train_memory int: the peak memory in bytes after forward with autograd enabled """ key = self._serialize(shapes, dtypes) if signature not in self._data: From af76c258431a7df89a61f6ba0aa2e0a1c01a5217 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 31 Oct 2022 17:02:47 +0800 Subject: [PATCH 1108/1892] save work --- cube/graph/function/__init__.py | 2 +- cube/profiler/database.py | 78 +++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/__init__.py b/cube/graph/function/__init__.py index fc28ba75..b0791047 100644 --- a/cube/graph/function/__init__.py +++ b/cube/graph/function/__init__.py @@ -1,2 +1,2 @@ -from cube.graph.function.dimops import IRDimops +from cube.graph.function.dimops import IRDimops, DimAnno from cube.graph.function.function import * \ No newline at end of file diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 85630166..2849293b 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -7,11 +7,15 @@ import time import os import json +from collections import deque import cube from cube.ir.cten import IRTensor from cube.ir.operator import IRFwOperation from cube.graph.parser.mapping import Sign2Op, IRDType2TorchDType +from cube.ir.tensor import IRSubTensor +from cube.graph.function import DimAnno +import copy Shapes = NewType('Shapes', Tuple[Tuple[int]]) @@ -298,3 +302,77 @@ def __repr__(self) -> str: data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, fw span: {fw_span} ms, bw span: {bw_span} ms, mem {mem} bytes') data = '\n'.join(data) return data + + +def collect_split_info(node: IRFwOperation): + # TODO(yizhu1): workaround + split_batch_ops = {} + + anno = node.anno + + split_info = {} + + for idx_shape, shape_anno in enumerate(anno.inputs()): + if not isinstance(node.inputs()[idx_shape], IRSubTensor): + continue + for idx_dim, dim_anno in enumerate(shape_anno.dims): + for idx_id, identifier in enumerate(dim_anno.identifiers): + if dim_anno.reduces[idx_id] == DimAnno.ReduceType.Freeze: + continue + if identifier in split_info: + split_info[identifier].append((idx_shape, idx_dim, idx_id)) + else: + split_info[identifier] = [(idx_shape, idx_dim, idx_id)] + + if node.signature in split_batch_ops: + for key, val in split_info.items(): + if (0, 0, 0) in val: + return {key: val} + assert False + else: + return split_info + +def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: + split_info = collect_split_info(node) + + def gen_hash(node: IRFwOperation) -> str: + ret = node.signature + for it in node.inputs(): + ret = ret + '-' + str(it.shape) + return ret + + dq = deque() + visited = set() + dq.append((node, ngpus)) + visited.add(gen_hash(node)) + + gen_nodes = [] + + while dq: + cur_node, cur_ngpus = dq.popleft() + gen_nodes.append(cur_node) + + for key, val in split_info.items(): + idx_1st, dim_1st, _ = val[0] + dim_size = cur_node.inputs()[idx_1st].shape[dim_1st] + + # TODO(yizhu1): only consider powers of 2 currently + split_deg = 2 + while split_deg <= dim_size and split_deg <= cur_ngpus: + if dim_size % split_deg != 0: + break + + new_node = cur_node.algorithms('dim').instantiate(idx=idx_1st, dim=dim_1st, num=split_deg)[0] + new_ngpus = cur_ngpus // split_deg + + cur_key = gen_hash(new_node) + + split_deg = split_deg * 2 + + if cur_key in visited: + continue + + dq.append((new_node, new_ngpus)) + visited.add(cur_key) + + return gen_nodes From f3a6dd282c8e420c68567588005be6fc5bc6a23d Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 31 Oct 2022 17:08:39 +0800 Subject: [PATCH 1109/1892] fix bug --- cube/profiler/database.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index f866eb3e..1be2caf5 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -324,8 +324,8 @@ def __repr__(self) -> str: for signature in self._data: for key in self._data[signature]: shapes, dtypes = self._deserialize(key) - fw_span, bw_span, mem = self._data[signature][key] - data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, fw span: {fw_span} ms, bw span: {bw_span} ms, mem {mem} bytes') + fw_span, bw_span, infer_mem, train_mem = self._data[signature][key] + data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, fw span: {fw_span} ms, bw span: {bw_span} ms, infer mem {infer_mem} bytes, train mem {train_mem} bytes') data = '\n'.join(data) return data From 653458b0cf1376bec8002a68ab58bac2e5ba3c83 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 31 Oct 2022 18:33:28 +0800 Subject: [PATCH 1110/1892] refine code --- cube/algorithm/ops/dimops.py | 75 ++++++++++++++++++++++++++++++- cube/graph/function/__init__.py | 2 +- cube/profiler/database.py | 78 --------------------------------- 3 files changed, 75 insertions(+), 80 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index d9a77c27..19c8c38a 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -3,6 +3,8 @@ from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule from cube.ir.tensor import IRSubTensor +from cube.ir.operator import IRFwOperation +from collections import deque class DimSplitEinops(GenericDistAlgo): @@ -291,4 +293,75 @@ def instantiate(self, idx: int, dimi: int, dimo: int, num: int) -> Optional[List sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) sub_node.infer_shape() sub_nodes.append(sub_node) - return sub_nodes \ No newline at end of file + return sub_nodes + +def collect_split_info(node: IRFwOperation): + # TODO(yizhu1): workaround + split_batch_ops = {} + + anno = node.anno + + split_info = {} + + for idx_shape, shape_anno in enumerate(anno.inputs()): + if not isinstance(node.inputs()[idx_shape], IRSubTensor): + continue + for idx_dim, dim_anno in enumerate(shape_anno.dims): + for idx_id, identifier in enumerate(dim_anno.identifiers): + if dim_anno.reduces[idx_id] == DimAnno.ReduceType.Freeze: + continue + if identifier not in split_info: + split_info[identifier] = (idx_shape, idx_dim, idx_id) + + if node.signature in split_batch_ops: + for key, val in split_info.items(): + if val == (0, 0, 0): + return {key: val} + assert False + else: + return split_info + +def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: + split_info = collect_split_info(node) + + def gen_hash(node: IRFwOperation) -> str: + ret = node.signature + for it in node.inputs(): + ret = ret + '-' + str(it.shape) + return ret + + dq = deque() + visited = set() + dq.append((node, ngpus)) + visited.add(gen_hash(node)) + + gen_nodes = [] + + while dq: + cur_node, cur_ngpus = dq.popleft() + gen_nodes.append(cur_node) + + for key, val in split_info.items(): + idx_1st, dim_1st, _ = val + dim_size = cur_node.inputs()[idx_1st].shape[dim_1st] + + # TODO(yizhu1): only consider powers of 2 currently + split_deg = 2 + while split_deg <= dim_size and split_deg <= cur_ngpus: + if dim_size % split_deg != 0: + break + + new_node = cur_node.algorithms('dim').instantiate(idx=idx_1st, dim=dim_1st, num=split_deg)[0] + new_ngpus = cur_ngpus // split_deg + + cur_key = gen_hash(new_node) + + split_deg = split_deg * 2 + + if cur_key in visited: + continue + + dq.append((new_node, new_ngpus)) + visited.add(cur_key) + + return gen_nodes \ No newline at end of file diff --git a/cube/graph/function/__init__.py b/cube/graph/function/__init__.py index b0791047..fc28ba75 100644 --- a/cube/graph/function/__init__.py +++ b/cube/graph/function/__init__.py @@ -1,2 +1,2 @@ -from cube.graph.function.dimops import IRDimops, DimAnno +from cube.graph.function.dimops import IRDimops from cube.graph.function.function import * \ No newline at end of file diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 1be2caf5..a03e9da0 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -7,15 +7,11 @@ import time import os import json -from collections import deque import cube from cube.ir.cten import IRTensor from cube.ir.operator import IRFwOperation from cube.graph.parser.mapping import Sign2Op, IRDType2TorchDType -from cube.ir.tensor import IRSubTensor -from cube.graph.function import DimAnno -import copy Shapes = NewType('Shapes', Tuple[Tuple[int]]) @@ -328,77 +324,3 @@ def __repr__(self) -> str: data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, fw span: {fw_span} ms, bw span: {bw_span} ms, infer mem {infer_mem} bytes, train mem {train_mem} bytes') data = '\n'.join(data) return data - - -def collect_split_info(node: IRFwOperation): - # TODO(yizhu1): workaround - split_batch_ops = {} - - anno = node.anno - - split_info = {} - - for idx_shape, shape_anno in enumerate(anno.inputs()): - if not isinstance(node.inputs()[idx_shape], IRSubTensor): - continue - for idx_dim, dim_anno in enumerate(shape_anno.dims): - for idx_id, identifier in enumerate(dim_anno.identifiers): - if dim_anno.reduces[idx_id] == DimAnno.ReduceType.Freeze: - continue - if identifier in split_info: - split_info[identifier].append((idx_shape, idx_dim, idx_id)) - else: - split_info[identifier] = [(idx_shape, idx_dim, idx_id)] - - if node.signature in split_batch_ops: - for key, val in split_info.items(): - if (0, 0, 0) in val: - return {key: val} - assert False - else: - return split_info - -def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: - split_info = collect_split_info(node) - - def gen_hash(node: IRFwOperation) -> str: - ret = node.signature - for it in node.inputs(): - ret = ret + '-' + str(it.shape) - return ret - - dq = deque() - visited = set() - dq.append((node, ngpus)) - visited.add(gen_hash(node)) - - gen_nodes = [] - - while dq: - cur_node, cur_ngpus = dq.popleft() - gen_nodes.append(cur_node) - - for key, val in split_info.items(): - idx_1st, dim_1st, _ = val[0] - dim_size = cur_node.inputs()[idx_1st].shape[dim_1st] - - # TODO(yizhu1): only consider powers of 2 currently - split_deg = 2 - while split_deg <= dim_size and split_deg <= cur_ngpus: - if dim_size % split_deg != 0: - break - - new_node = cur_node.algorithms('dim').instantiate(idx=idx_1st, dim=dim_1st, num=split_deg)[0] - new_ngpus = cur_ngpus // split_deg - - cur_key = gen_hash(new_node) - - split_deg = split_deg * 2 - - if cur_key in visited: - continue - - dq.append((new_node, new_ngpus)) - visited.add(cur_key) - - return gen_nodes From 34ac4ab67365034eb52ae80c0c0a3f25caa926c5 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 31 Oct 2022 19:42:50 +0800 Subject: [PATCH 1111/1892] type typo --- cube/profiler/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index a03e9da0..5bfa00b0 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -25,7 +25,7 @@ class CompProfiler: @staticmethod def profile(func: Callable, shapes: Shapes, dtypes: DTypes, warmup_sec: float = 2, prof_times: int = 50, - **kwargs) -> Tuple[float, float, int]: + **kwargs) -> Tuple[float, float, int, int]: """ Profile a function From c01ec8bcf0af1d33473b9a7094b0408a5a64df2d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 1 Nov 2022 10:05:10 +0800 Subject: [PATCH 1112/1892] fix profiling with arguments starting with self. --- cube/profiler/database.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 5bfa00b0..6760359e 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -2,7 +2,7 @@ Usage: python -m cube.profiler.database --export ./profile.dat.json """ -from typing import Callable, Tuple, Union, Optional, Dict, NewType, List +from typing import Callable, Tuple, Union, Optional, Dict, NewType, List, Any import torch import time import os @@ -20,6 +20,10 @@ NameOrFunc = Union[str, Callable] +_train_module_ref: torch.nn.Module = torch.nn.Module().train() +_eval_module_ref: torch.nn.Module = torch.nn.Module().eval() + + class CompProfiler: @staticmethod @@ -49,7 +53,18 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, torch.rand(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=True) \ for shape, dtype in zip(shapes, dtypes) ) - outputs = func(*tensors, **kwargs) + # repalce kwargs starting with 'sekf.xxx' + train_kwargs, eval_kwargs = {}, {} + for name, value in kwargs.items(): + if isinstance(value, str) and value.startswith('self.'): + train_val = getattr(_train_module_ref, value[5:]) + eval_val = getattr(_eval_module_ref, value[5:]) + else: + train_val = eval_val = value + train_kwargs[name] = train_val + eval_kwargs[name] = eval_val + # run one sample + outputs = func(*tensors, **train_kwargs) outputs = (outputs,) if torch.is_tensor(outputs) else outputs assert all(torch.is_tensor(otensor) for otensor in outputs), \ f"{func.__name__}: require all the outputs to be tensors" @@ -67,7 +82,7 @@ def run_step(func, tensors, kwargs, backward: bool): torch.cuda.reset_peak_memory_stats() mtic = torch.cuda.max_memory_allocated() # in bytes with torch.no_grad(): - run_step(func, tensors, kwargs, backward=False) + run_step(func, tensors, eval_kwargs, backward=False) mtoc = torch.cuda.max_memory_allocated() # in bytes infer_memory = mtoc - mtic @@ -75,21 +90,21 @@ def run_step(func, tensors, kwargs, backward: bool): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() mtic = torch.cuda.max_memory_allocated() # in bytes - outs = run_step(func, tensors, kwargs, backward=False) + outs = run_step(func, tensors, train_kwargs, backward=False) mtoc = torch.cuda.max_memory_allocated() # in bytes train_memory = mtoc - mtic # warmup tic = time.time() while time.time() - tic < warmup_sec: - run_step(func, tensors, kwargs, backward=True) + run_step(func, tensors, train_kwargs, backward=True) # profile forward only torch.cuda.synchronize() tic = time.perf_counter() for _ in range(prof_times): with torch.no_grad(): - run_step(func, tensors, kwargs, backward=False) + run_step(func, tensors, eval_kwargs, backward=False) torch.cuda.synchronize() toc = time.perf_counter() fw_span = (toc - tic) / prof_times * 1000 # in milliseconds @@ -98,7 +113,7 @@ def run_step(func, tensors, kwargs, backward: bool): torch.cuda.synchronize() tic = time.perf_counter() for _ in range(prof_times): - run_step(func, tensors, kwargs, backward=True) + run_step(func, tensors, train_kwargs, backward=True) torch.cuda.synchronize() toc = time.perf_counter() fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds From 8cdc5643f31fbed5413997a2714f090e8db2a7e0 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 1 Nov 2022 10:06:34 +0800 Subject: [PATCH 1113/1892] save work for merge main --- cube/profiler/database.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 5bfa00b0..6c951f4c 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -45,9 +45,12 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" # create data dtypes = [torch.float32] * len(shapes) if dtypes is None else dtypes + def gen_torch_tensors(shape, dtype): + constructor = torch.zeros if dtype == torch.int64 else torch.rand + requires_grad = False if dtype == torch.int64 else True + return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) tensors = tuple( - torch.rand(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=True) \ - for shape, dtype in zip(shapes, dtypes) + gen_torch_tensors(shape, dtype) for shape, dtype in zip(shapes, dtypes) ) outputs = func(*tensors, **kwargs) outputs = (outputs,) if torch.is_tensor(outputs) else outputs @@ -152,6 +155,9 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): """ fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) + if self.exist(node): + return self.query(node) + if isinstance(device, int): orig_device = torch.cuda.current_device() torch.cuda.set_device(device) @@ -164,7 +170,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): self.insert(node.signature, key, fw_span, bw_span, infer_memory, train_memory) print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " - f"=> fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} | " + f"=> fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} ms | " f"infer mem: {infer_memory} | train mem: {train_memory}") if isinstance(device, int): From a7d00c18761ee92ec211b6f7af227c2cde3b6ea8 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 1 Nov 2022 14:09:35 +0800 Subject: [PATCH 1114/1892] typo --- cube/profiler/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 4b059a53..e73d7ec7 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -56,7 +56,7 @@ def gen_torch_tensors(shape, dtype): tensors = tuple( gen_torch_tensors(shape, dtype) for shape, dtype in zip(shapes, dtypes) ) - # repalce kwargs starting with 'sekf.xxx' + # repalce kwargs starting with 'self.xxx' train_kwargs, eval_kwargs = {}, {} for name, value in kwargs.items(): if isinstance(value, str) and value.startswith('self.'): From 0c7737fd73367cdba2d902aff9939e0a3b113a4e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 1 Nov 2022 19:10:02 +0800 Subject: [PATCH 1115/1892] Megatron policy for swin transformer --- examples/vision/swin/blocks/attention.py | 14 ++++ examples/vision/swin/policy/mpmd.py | 94 ++++++++++++++++++------ examples/vision/swin/train.py | 17 ++++- 3 files changed, 100 insertions(+), 25 deletions(-) diff --git a/examples/vision/swin/blocks/attention.py b/examples/vision/swin/blocks/attention.py index a0be2db3..c38c6987 100644 --- a/examples/vision/swin/blocks/attention.py +++ b/examples/vision/swin/blocks/attention.py @@ -58,6 +58,20 @@ def window_attn(x: torch.Tensor, return x +def init_relative_position_index(window_size: int) -> torch.Tensor: + coords_h = torch.arange(window_size) + coords_w = torch.arange(window_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size - 1 + relative_coords[:, :, 0] *= 2 * window_size - 1 + relative_position_index = relative_coords.sum(-1) # wh * ww, wh * ww + return relative_position_index + + class WindowAttention(torch.nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. diff --git a/examples/vision/swin/policy/mpmd.py b/examples/vision/swin/policy/mpmd.py index 52e41c14..c1bfdfe2 100644 --- a/examples/vision/swin/policy/mpmd.py +++ b/examples/vision/swin/policy/mpmd.py @@ -5,6 +5,7 @@ from cube.graph.function.anchor import IRGraphAnchor from cube.ir.cten import IRCell from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.segment import IRSegment from cube.graph.schedule.sched1f1b import IRSchedule1F1B @@ -42,7 +43,7 @@ def _group_to_transformers(fnodes) -> List[List[IRCell]]: anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] indices = [fnodes.index(anchor) for anchor in anchors] for lid, idx in enumerate(indices): - fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + fnodes[idx+1].comment = f'===> start of transformer layer {lid}' start = idx if lid != 0 else 0 end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) transformers.append(fnodes[start:end]) @@ -70,6 +71,16 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): graph.assign(sub_node, devid) return sub_nodes +def _coshard(graph: IRGraph, node: IRFwOperation, devid: int, **configs): + algo = node.algorithms('dim') + if node.recompute is not None: + graph.recompute([node]) + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + for sub_node in sub_nodes: + graph.assign(sub_node, devid) + return sub_nodes + # ========================= parallelisms ================================= def PASRoundRobin(graph: IRGraph, resource): @@ -126,12 +137,15 @@ def PAS1F1B(graph: IRGraph, resource): def PASMegatron(graph: IRGraph, resource): """ - 1F1B scheduling + Megatron policy with Data, Tensor, Pipeline Parallelism. """ dp_size = 1 tp_size = 2 pp_size = resource.ngpus // (dp_size * tp_size) - num_microbatch = resource.ngpus + # note coshard will only apply to first 4 tranformer blocks + coshard = 2 + recompute: bool = False + num_microbatch = 8 # device mesh dp_groups, pp_groups, tp_groups = \ @@ -140,28 +154,66 @@ def PASMegatron(graph: IRGraph, resource): print(f'pp groups: {pp_groups}') print(f'tp groups: {tp_groups}') - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: + return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] # group to transformer layers - transformers = _group_to_transformers(fnodes) + transformers = _group_to_transformers(graph.select(ntype=IRFwOperation)) + if recompute: + for transformer in transformers: + graph.recompute(transformer) - # staging + # group to stage: set each stage operators + fstages = [[] for _ in range(pp_size)] nlayer_per_stage = (len(transformers) // pp_size) for lid, fnodes in enumerate(transformers): - sid = min(lid // nlayer_per_stage, pp_size-1) - print(f'assigning {lid}-th transformer layer to stage {sid}: {tp_groups[sid]}') - for fnode in fnodes: - if fnode.name == 'window_attn': - _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) - elif fnode.name == 'feedforward': - _tp(graph, fnode, tp_groups[sid], idx=1, dim=0, num=tp_size) - else: - _replica(graph, fnode, tp_groups[sid]) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - _replica(graph, node, list(range(resource.ngpus))) + stage_id = min(lid // nlayer_per_stage, pp_size - 1) + fstages[stage_id] += fnodes + graph.staging(tuple(stages[0] for stages in fstages)) - strategy = IRSchedule1F1B(graph, num_microbatch, tp_groups) - graph.sched = strategy + dataloader = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[0] + + # partition dataloader + dls = _replica(graph, dataloader, [0]*dp_size) # graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) + for dp_idx, dl in enumerate(dls): + # only stage 0 needs dataloader + devices = [get_device(dp_idx, 0, tp_idx) for tp_idx in range(tp_size)] + _replica(graph, dl, devices) + + tid = 0 + + # staging + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + assert len(fstages) == pp_size + nlayer_per_stage = (len(transformers) // pp_size) + for pp_idx, fstage in enumerate(fstages): + for fnode in fstage.nodes(): + subnodes = [fnode] + if len(fnode.inputs()) == 0: continue # anchor + # tensor parallel -- FIXME: current restriction needs replica happen before partition + if fnode.name == 'window_attn' or fnode.name == 'feedforward': + subnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + elif fnode.name == 'linear': # the last embeding linear + subnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + else: + subnodes = _replica(graph, fnode, [0]*tp_size) + # data parallel + pnodes = [] + for tp_idx, subnode in enumerate(subnodes): + dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] + batch_dim = 0 if bs not in subnode.input(0).shape else subnode.input(0).shape.index(bs) + nodes = _tp(graph, subnode, devs=dp_devices, idx=0, dim=batch_dim, num=dp_size) + pnodes += nodes + subnodes = pnodes + # coshard + if fnode.name in ['window_attn', 'feedforward']: + if coshard > 1 and tid < 4: + for subnode in subnodes: + devid = subnode.device[0] + _coshard(graph, subnode, devid, idx=1, dim=0, num=coshard) + tid = tid + 1 if fnode.name == 'window_attn' else tid + + strategy = IRSchedule1F1B(graph, num_microbatch) + graph.predef_sched(strategy) return graph diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 240618fd..212dd76a 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -4,10 +4,12 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/vision/swin/train.py --policy PASMeshShard --fp16 + examples/vision/swin/train.py --policy PASMegatron --fp16 """ +import math import torch +from examples.vision.swin.blocks.attention import init_relative_position_index from examples.vision.swin.model import Config, SwinTransformer, ImageDataLoader import cube @@ -42,7 +44,8 @@ def train(): - batch_size = 4 + batch_size = 1 + load_content: bool = False cfg = Config() model = SwinTransformer() @@ -51,8 +54,8 @@ def train(): dtype = torch.float16 if args.fp16 else torch.float32 dataloader = ImageDataLoader(batch_size, cfg.img_size, cfg.num_classes, dtype=dtype) - model = cube.SemanticModel(model, dataloader.shapes) - @cube.compile(model, dataloader, PAS=PAS, override=True) + model = cube.SemanticModel(model) + @cube.compile(model, dataloader, PAS=PAS, override=True, load_content=load_content) def train_iter(model, dataloader): imgs = next(dataloader) loss = model(imgs) @@ -60,6 +63,12 @@ def train_iter(model, dataloader): # return loss model: torch.nn.Module = model.get_gen_module() + if not load_content: + for name, buffer in model.named_buffers(): + if 'rp_index' in name: + window_size = int(math.sqrt(buffer.size(0))) + buffer.copy_(init_relative_position_index(window_size).cuda()) + optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) torch.distributed.barrier() From 4b34fbe259a365c1466b081432408f89b6fb6603 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 1 Nov 2022 22:20:00 +0800 Subject: [PATCH 1116/1892] coshard optimization for faster accumulation --- cube/graph/gener/gen.py | 33 +++++++++++++++++++++++-------- cube/runtime/function/function.py | 5 ++++- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 8d83be5a..701d9269 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -416,6 +416,9 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens ) node = None + # get recomput group + rcid = set(producer.recompute for producer in devops[devid]) + rcid = list(rcid)[0] if len(rcid) == 1 else None # split dimension case if split_dim: @@ -470,12 +473,29 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens f"Users can try to adjust node ordering to meet with accum order\n" f"{graph.debug_tensor_map_str(ftensor)}" ) - # set accum input / output - node = Accum('cube.runtime.accum', ptensors) + + # === Optimization: quick accumulation to early release tensor + lhs, rhs = ptensors[0], None + for ptensor in ptensors[1:]: + rhs = ptensor + output = ftensor.like().select(ptensors[0].indmap, (0,1)) + node = Accum('cube.runtime.accum', [lhs, rhs]) + node.set_output(0, output) + node.device = devid + node.recompute = rcid + graph.insert(node, graph.index(ptensor.cell) + 1) + lhs = output + # remove last node for adaptation + graph.remove(node) + + # === Orignal way to at alst release tensor + # node = Accum('cube.runtime.accum', ptensors) + # # set gradient + # for idx, ptensor in enumerate(ptensors): + # node.input(idx).grad = ftensor.grad.select(ptensor.indmap, (0,1)) + + # set output node.set_output(0, new_ftensor.select(otensor.indmap, otensor.valmap)) - # set gradient - for idx, ptensor in enumerate(ptensors): - node.input(idx).grad = ftensor.grad.select(ptensor.indmap, (0,1)) node.output(0).grad = new_ftensor.grad.select(otensor.indmap, (0,1)) # no need for fusion, change the producer output to new tensor @@ -492,9 +512,6 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens bproducer.set_input(idx, ograd) else: node.device = devid - # set recompute - rcid = set(producer.recompute for producer in devops[devid]) - rcid = list(rcid)[0] if len(rcid) == 1 else None node.recompute = rcid # insert max_fid = max(graph.index(producer) for producer in devops[devid]) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 4267d8ce..28933b28 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -27,7 +27,10 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: """ accumulate tensors in to one tensor """ - return torch.sum(torch.stack(tensors, dim=0), dim=0) + if len(tensors) == 2: + return tensors[0] + tensors[1] + else: + return torch.sum(torch.stack(tensors, dim=0), dim=0) def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], From ed307ceefb53597535c34485a6aabddf379fd5fd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 2 Nov 2022 13:23:56 +0800 Subject: [PATCH 1117/1892] gather flags with CompileFlag --- cube/algorithm/ops/dimops.py | 10 +++++++--- cube/codegen/codegen.py | 11 ++++------- cube/flags.py | 19 +++++++++++++++++++ cube/runtime/device.py | 6 +++--- 4 files changed, 33 insertions(+), 13 deletions(-) create mode 100644 cube/flags.py diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 19c8c38a..dc412f84 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -6,6 +6,8 @@ from cube.ir.operator import IRFwOperation from collections import deque +from cube.flags import CompileFlag + class DimSplitEinops(GenericDistAlgo): """! @@ -100,10 +102,12 @@ def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] else: adim, reduce = 'Value', None - color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' - print(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") - if not satisfy: return None + + if CompileFlag.log_transform: + color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' + print(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") + if not satisfy: return None rule: TransformRule = self.infer(idx, dim, num) # transform diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 99e0a778..eabb5a74 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -25,10 +25,7 @@ from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock from cube.codegen.frontend_mapping import Sign2EmitRule -import os - -USE_NNFUSION = os.environ.get('USE_NNFUSION') -USE_JIT = os.environ.get('USE_JIT') +from cube.flags import CompileFlag def get_backward_callsite_io_tensors(bp_segment:IRSegment): @@ -377,7 +374,7 @@ def __init__(self, execplan: ExecutionPlan): 'import torch', 'import torch.utils.checkpoint as ckpt', 'import cube', '', ''] - if USE_NNFUSION: + if CompileFlag.use_nnfusion: self.init_code.extend(['import nnfusion', '']) # customized op code @@ -512,9 +509,9 @@ def gen(self, device: int, outfile=None, attach=False) -> str: return_code = f"return {', '.join(outputs)}" fb.insert_body(return_code) cb.insert_body('') - if USE_NNFUSION and name.startswith('segment'): + if CompileFlag.use_nnfusion and name.startswith('segment'): cb.insert_body('@nnfusion.jit') - if USE_JIT and name.startswith('segment'): + if CompileFlag.use_jit and name.startswith('segment'): cb.insert_body('@torch.jit.script_method') cb.insert_body(fb.code) diff --git a/cube/flags.py b/cube/flags.py new file mode 100644 index 00000000..86ed1776 --- /dev/null +++ b/cube/flags.py @@ -0,0 +1,19 @@ +""" +Environment flags for compiling options +""" + +import os + + +class CompileFlag: + + # ============== runtime ==================== + dev_mode = os.environ.get('SINGLE_DEV_MODE') # allow to use python xx.py + + # ============= loggings =================== + log_transform = os.environ.get('LOG_TRANSFORM') + + # ============ code generation =============== + use_nnfusion = os.environ.get('USE_NNFUSION') + use_jit = os.environ.get('USE_JIT') + diff --git a/cube/runtime/device.py b/cube/runtime/device.py index 9ed7a968..6d67bcaf 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -6,15 +6,15 @@ import torch import os +from cube.flags import CompileFlag + class DeviceGroup: class __DeviceGroup: def __init__(self): - single_device_mode = os.environ.get('SINGLE_DEV_MODE') - print(f'single_device_mode = {single_device_mode}') - if single_device_mode: + if CompileFlag.dev_mode: print(f"DeviceGroup init using single device mode...") self.rank = 0 self.world_size = 1 From 944ea5b478d8791bd77739f558924370e1d22b39 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 2 Nov 2022 13:24:24 +0800 Subject: [PATCH 1118/1892] support grad accum with nstages = 1 --- cube/runtime/schedule/sched1f1b.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py index 0f570a82..776d99de 100644 --- a/cube/runtime/schedule/sched1f1b.py +++ b/cube/runtime/schedule/sched1f1b.py @@ -18,6 +18,14 @@ def run(segment: Callable, # forward body num_microbatch: int, recompute=False): + # special case: num_stages == 1: use gradient accum + if num_stages == 1: + for _ in range(num_microbatch): + inputs = Schedule1F1B.dataloader_step(dataloader) + outputs = Schedule1F1B.forward_step(segment, *inputs) + input_grads = Schedule1F1B.backward_step(inputs, outputs, (None,)) + return + num_warmup_microbatches = num_stages - 1 - stage_id num_warmup_remaining = num_microbatch - num_warmup_microbatches From 047ff779539598052bedcd8db6e247a2af0c5ebc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 2 Nov 2022 13:24:57 +0800 Subject: [PATCH 1119/1892] fix coshard bug --- examples/vision/swin/policy/mpmd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/vision/swin/policy/mpmd.py b/examples/vision/swin/policy/mpmd.py index c1bfdfe2..74850138 100644 --- a/examples/vision/swin/policy/mpmd.py +++ b/examples/vision/swin/policy/mpmd.py @@ -73,7 +73,7 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): def _coshard(graph: IRGraph, node: IRFwOperation, devid: int, **configs): algo = node.algorithms('dim') - if node.recompute is not None: + if node.recompute is None: graph.recompute([node]) sub_nodes = graph.partition(node, algo, **configs) assert sub_nodes is not None From 853b985828237b3718e273e760d24760bafd931e Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Wed, 2 Nov 2022 17:52:02 +0800 Subject: [PATCH 1120/1892] save profile ret --- cube/compiler.py | 3 +- cube/profiler/memory.py | 2 +- examples/nlp/blocks/attention.py | 32 ++++----- examples/nlp/gpt/model.py | 4 +- tests/gpt_profile.md | 7 ++ tests/test_profile_gpt.py | 111 +++++++++++++++++++++++++++++++ 6 files changed, 139 insertions(+), 20 deletions(-) create mode 100644 tests/gpt_profile.md create mode 100644 tests/test_profile_gpt.py diff --git a/cube/compiler.py b/cube/compiler.py index 96f0c52e..d9587a2e 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -91,7 +91,8 @@ def decorator(fn: Callable) -> Callable: print('warning: dataloader batch size stay as default.') # load module code print_each_rank(f'loading existed module from {filename} ...') - model.load_module(filename, load_content=load_content) + # model.load_module(filename, load_content=load_content) + model.load_module(filename) # load schedule code print_each_rank(f'loading existed schedule from {filename} ...') return _load_tschedule_fn(filename) diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index 7b82ca93..957b32e2 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -14,7 +14,7 @@ def memory_summary(): mem = torch.cuda.max_memory_allocated() # mem = torch.cuda.max_memory_reserved() print_each_rank( - '{:.2f}GB memory consumption'.format(mem / 1024 / 1024 / 1024), + '{:.2f} GB memory consumption'.format(mem / 1024 / 1024 / 1024), ) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index eba5d666..0c3ac0b6 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -19,22 +19,22 @@ def self_attention(query: torch.Tensor, v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d # ======== replace the semantic into more efficient implementation ============ - # q = q.transpose(0, 1) # L (N h) d -> (N h) L d - # k = k.transpose(0, 1) # L (N h) d -> (N h) L d - # q = q * scale # (N h) L d, 1 -> (N h) L d - # k = k.transpose(1, 2) # (N h) L d -> (N h) d L - # attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + q = q.transpose(0, 1) # L (N h) d -> (N h) L d + k = k.transpose(0, 1) # L (N h) d -> (N h) L d + q = q * scale # (N h) L d, 1 -> (N h) L d + k = k.transpose(1, 2) # (N h) L d -> (N h) d L + attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - # preallocating input tensor: (N h) L L - matmul_input_buffer = torch.empty([N * h, L, L], dtype=query.dtype, device=query.device) - # L (N h) d, L (N h) d -> (N h) L L - attn = torch.baddbmm( - matmul_input_buffer, - q.transpose(0, 1), # (N h) L d - k.transpose(0, 1).transpose(1, 2), # (N h) d L - beta=0.0, alpha=scale - ) - # ======== replace the semantic into more efficient implementation ============ + # # preallocating input tensor: (N h) L L + # matmul_input_buffer = torch.empty([N * h, L, L], dtype=query.dtype, device=query.device) + # # L (N h) d, L (N h) d -> (N h) L L + # attn = torch.baddbmm( + # matmul_input_buffer, + # q.transpose(0, 1), # (N h) L d + # k.transpose(0, 1).transpose(1, 2), # (N h) d L + # beta=0.0, alpha=scale + # ) + # # ======== replace the semantic into more efficient implementation ============ # attention mask if mask: # (N h) L L -> (N h) L L @@ -182,7 +182,7 @@ def forward(self, query): attn = self_attention( query, self.qkv_proj, self.qkv_bias, self.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=True + self.num_heads, self.scaling, self.dropout_p, mask=False ) attn = attn + self.out_bias return attn diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index a8df651d..b689d4de 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -12,7 +12,7 @@ class Config: # toy model embed_dim = 1024 - layers = 8 # 96 + layers = 4 # 96 attention_heads = 16 # # 1 layer of 175B model @@ -37,7 +37,7 @@ class Config: # 6.7 B model # embed_dim = 4096 - # layers = 32 + # layers = 1 # attention_heads = 32 # 15 B model diff --git a/tests/gpt_profile.md b/tests/gpt_profile.md new file mode 100644 index 00000000..8d8ea620 --- /dev/null +++ b/tests/gpt_profile.md @@ -0,0 +1,7 @@ +| layer | end2end | param | activation | e2e - 4 * p | activation2 | +|:------|:--------|:------|:-----------|:------------|:------------| +| 1 | 1.59 | 0.24 | 0.54 | 0.63 | 0.47 | +| 2 | 1.98 | 0.29 | 0.86 | 0.82 | 0.73 | +| 4 | 2.78 | 0.38 | 1.51 | 1.26 | 1.24 | +| 8 | 4.37 | 0.57 | 2.79 | 2.09 | 2.26 | +| 16 | 7.55 | 0.95 | 5.39 | 3.75 | 4.30 | \ No newline at end of file diff --git a/tests/test_profile_gpt.py b/tests/test_profile_gpt.py new file mode 100644 index 00000000..65c0da67 --- /dev/null +++ b/tests/test_profile_gpt.py @@ -0,0 +1,111 @@ +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=1 \ + examples/nlp/gpt/train.py --fp16 +""" + + +import torch +import time + +from examples.nlp.gpt.model import GPT +from examples.nlp.gpt.model import GPTDataLoader + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary, model_summary + +from examples.nlp.gpt.policy.mpmd import PASMegatron as PAS +import examples.nlp.gpt.policy.spmd as spmd +import examples.nlp.gpt.policy.mpmd as mpmd + +import argparse + +from cube.ir.operator import IRFwOperation, IRBpOperation +from cube.profiler.database import ProfileDataBase +from cube.algorithm.ops.dimops import gen_partitions +from cube.graph.function.anchor import IRGraphAnchor + +parser = argparse.ArgumentParser(description='GPT Train') +parser.add_argument('--fp16', action='store_true', default=False, + help='use fp16 for the training') +args = parser.parse_args() + +cube.init() + + +def train(): + batch_size = 1 + + model = GPT() + model = model if not args.fp16 else model.half() + dataloader = GPTDataLoader(batch_size) + + model = cube.SemanticModel(model, dataloader.shapes) + + def profile(graph, resource): + db = ProfileDataBase() + mem_sum = 0 + for node in graph.select(ntype=IRFwOperation): + if isinstance(node, IRGraphAnchor): + continue + partition_nodes = gen_partitions(node, 1) + for partition_node in partition_nodes: + fw_span, bw_span, infer_mem, train_mem = db.profile(partition_node) + mem_sum = mem_sum + train_mem + db.dump('db.json', override=True) + print('estimated train mem: ', mem_sum / 1024 / 1024 / 1024) + + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + return graph + + @cube.compile(model, dataloader, PAS=profile, override=True) + def train_iter(model, dataloader): + input_ids, position_ids = next(dataloader) + loss = model(input_ids, position_ids) + loss.backward() + model = model.get_gen_module() + + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + print_each_rank('model weight consumpition:', rank_only=0) + memory_summary() + + # CudaTimer(enable=False).warmup() + iter_num = 4 + warmup = 2 + for step in range(iter_num): + if step == warmup: + CudaTimer(enable=True).start('e2e') + + train_iter(model, dataloader) + memory_summary() + optimizer.step() + memory_summary() + optimizer.zero_grad() + memory_summary() + + if step == 0: + print_each_rank('passed first iteration') + if (step + 1) % 10 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + CudaTimer().stop('e2e') + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) + + memory_summary() + + +if __name__ == '__main__': + + cube.init() + train() \ No newline at end of file From 155bf379a4d695b9b87cf6ee3a4851df25958bc9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Nov 2022 12:59:04 +0800 Subject: [PATCH 1121/1892] add dgx1 reorder gpu script --- scripts/dgx1_reorder_gpu.py | 109 ++++++++++++++++++++++++++++++++++++ scripts/sync.sh | 22 ++++++++ 2 files changed, 131 insertions(+) create mode 100644 scripts/dgx1_reorder_gpu.py create mode 100755 scripts/sync.sh diff --git a/scripts/dgx1_reorder_gpu.py b/scripts/dgx1_reorder_gpu.py new file mode 100644 index 00000000..9a1dae12 --- /dev/null +++ b/scripts/dgx1_reorder_gpu.py @@ -0,0 +1,109 @@ +""" +Reorder GPU index by finding DGX-1 topology Find dgx topology + +┌───────────┐ +1 = 0 = 4 = 5 +‖ x | | x ‖ +2 = 3 = 7 = 6 +└───────────┘ + +""" +from typing import List +import subprocess +import numpy as np + +_kConnType = { + "NV1": 1, + "NV2": 2, + "NODE": 3, + "X": -1, +} + +_kConnTypeStr = {val: key for key, val in _kConnType.items()} + + + +def get_topology(): + cmds = [ + 'nvidia-smi', + 'topo', + '-m', + ] + + proc = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = proc.communicate() + outputs = stdout.decode('utf-8').split('\n') + + outputs = [out for out in outputs if out.startswith('GPU')] + ngpus = len(outputs) + print(f'Detected GPU number: {ngpus}') + + topology = np.empty((ngpus, ngpus), dtype=int) + for src, output in enumerate(outputs): + connections = output.split('\t')[1:1+ngpus] + for dst, link in enumerate(connections): + link = link.replace(" ", "") + assert link in _kConnType, f"Find link not in DGX-1 topology: {link}" + topology[src, dst] = _kConnType[link] + return topology + + +def topology_repr(topology: np.ndarray, reorder: List[int]): + reorder = list(reorder) + ngpus = topology.shape[0] + reorder_topo = np.empty((ngpus, ngpus), dtype=object) + for src in range(ngpus): + for dst in range(ngpus): + link = _kConnTypeStr[topology[src, dst]] + reorder_topo[reorder.index(src), reorder.index(dst)] = link + maxlen = max(len(key) for key in _kConnType) + dscp = '' + for gidx, line in enumerate(reorder_topo): + dscp += f'GPU{gidx}: '+ ' '.join(link.ljust(maxlen) for link in line) + '\n' + return dscp + + +def reorder(topology: np.ndarray) -> np.ndarray: + """ + Reorder GPU according to DGX-1 topology + + ┌───────────┐ + 1 = 0 = 4 = 5 + ‖ x | | x ‖ + 2 = 3 = 7 = 6 + └───────────┘ + """ + ngpus = topology.shape[0] + # find NV2 ring + ring = [0] + while len(ring) < ngpus: + nv2s = np.where(topology[ring[-1]] == _kConnType['NV2'])[0] + find_next = False + for gid in nv2s: + if gid not in ring: + ring.append(gid) + find_next = True + break + assert find_next + ring = np.array(ring, dtype=int) + print(f'Get ring: {ring}') + # find fc + for idx, src in enumerate(ring): + dst = ring[(src + 3) % len(ring)] + if topology[src, dst] == _kConnType['NV1']: + break + ring = np.roll(ring, 0-idx) + return ring + + +if __name__ == '__main__': + topology = get_topology() + print('original topology:') + print(topology_repr(topology, list(range(topology.shape[0])))) + reorder = reorder(topology) + print('reorder topology:') + print(topology_repr(topology, reorder)) + print( + f"Command need to be added into environment:\n" + f"export CUDA_VISIBLE_DEVICES={','.join(str(gid) for gid in reorder)}" + ) diff --git a/scripts/sync.sh b/scripts/sync.sh new file mode 100755 index 00000000..9a2a365f --- /dev/null +++ b/scripts/sync.sh @@ -0,0 +1,22 @@ +# ============= ITP Variables ============ +# NODE_RANK +# MASTER_IP +# MASTER_PORT +# ============= ITP Variables ============ + +node_num=$1 +folder=$2 + +host=worker + +if [ ${node_num} == 4 ] +then + scp -r /workspace/MagicCube/$folder $host-1:/workspace/MagicCube/ + scp -r /workspace/MagicCube/$folder $host-2:/workspace/MagicCube/ + scp -r /workspace/MagicCube/$folder $host-3:/workspace/MagicCube/ +fi + +if [ ${node_num} == 2 ] +then + scp -r /workspace/MagicCube/$folder $host-1:/workspace/MagicCube/ +fi \ No newline at end of file From f405735e5f9f7bb980d5c5dada0029b7e54ccca7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Nov 2022 05:29:09 +0000 Subject: [PATCH 1122/1892] fix reorder bug --- scripts/dgx1_reorder_gpu.py | 14 ++++++++++++-- scripts/sync-itp.sh | 23 ----------------------- scripts/sync-singularity.sh | 25 ------------------------- 3 files changed, 12 insertions(+), 50 deletions(-) delete mode 100755 scripts/sync-itp.sh delete mode 100755 scripts/sync-singularity.sh diff --git a/scripts/dgx1_reorder_gpu.py b/scripts/dgx1_reorder_gpu.py index 9a1dae12..aa312587 100644 --- a/scripts/dgx1_reorder_gpu.py +++ b/scripts/dgx1_reorder_gpu.py @@ -89,9 +89,19 @@ def reorder(topology: np.ndarray) -> np.ndarray: print(f'Get ring: {ring}') # find fc for idx, src in enumerate(ring): - dst = ring[(src + 3) % len(ring)] - if topology[src, dst] == _kConnType['NV1']: + is_fc = True + pairs = [ + (src, ring[(idx + 3) % len(ring)]), + (src, ring[(idx + 2) % len(ring)]), + (ring[(idx+1) % len(ring)], ring[(idx+3) % len(ring)]) + ] + for src, dst in pairs: + if topology[src, dst] != _kConnType['NV1']: + is_fc = False + break + if is_fc: break + assert is_fc, f"Cannot find FC group." ring = np.roll(ring, 0-idx) return ring diff --git a/scripts/sync-itp.sh b/scripts/sync-itp.sh deleted file mode 100755 index 588a2a99..00000000 --- a/scripts/sync-itp.sh +++ /dev/null @@ -1,23 +0,0 @@ -# ============= ITP Variables ============ -# NODE_RANK -# MASTER_IP -# MASTER_PORT -# ============= ITP Variables ============ - -node_num=$1 - -if [ ${node_num} == 4 ] -then - scp -r /workspace/MagicCube/handcraft worker-1:/workspace/MagicCube/ - scp -r /workspace/MagicCube/cube worker-1:/workspace/MagicCube/ - scp -r /workspace/MagicCube/handcraft worker-2:/workspace/MagicCube/ - scp -r /workspace/MagicCube/cube worker-2:/workspace/MagicCube/ - scp -r /workspace/MagicCube/handcraft worker-3:/workspace/MagicCube/ - scp -r /workspace/MagicCube/cube worker-3:/workspace/MagicCube/ -fi - -if [ ${node_num} == 2 ] -then - scp -r /workspace/MagicCube/handcraft worker-1:/workspace/MagicCube/ - scp -r /workspace/MagicCube/cube worker-1:/workspace/MagicCube/ -fi \ No newline at end of file diff --git a/scripts/sync-singularity.sh b/scripts/sync-singularity.sh deleted file mode 100755 index f8d4cf5d..00000000 --- a/scripts/sync-singularity.sh +++ /dev/null @@ -1,25 +0,0 @@ - -# ============= Singularity Variables ============ -# NODE_RANK -# MASTER_ADDR -# MASTER_PORT -# ============= Singularity Variables ============ - -node_num=$1 - -if [ ${node_num} == 4 ] -then - scp -r /workspace/MagicCube/handcraft node-1:/workspace/MagicCube/ - scp -r /workspace/MagicCube/cube node-1:/workspace/MagicCube/ - scp -r /workspace/MagicCube/handcraft node-2:/workspace/MagicCube/ - scp -r /workspace/MagicCube/cube node-2:/workspace/MagicCube/ - scp -r /workspace/MagicCube/handcraft node-3:/workspace/MagicCube/ - scp -r /workspace/MagicCube/cube node-3:/workspace/MagicCube/ -fi - -if [ ${node_num} == 2 ] -then - scp -r /workspace/MagicCube/handcraft node-1:/workspace/MagicCube/ - scp -r /workspace/MagicCube/cube node-1:/workspace/MagicCube/ -fi - From 3e801cbdb7b3f1431a0085628f81e60c53390546 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Thu, 3 Nov 2022 14:55:51 +0800 Subject: [PATCH 1123/1892] refine table --- tests/gpt_profile.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/gpt_profile.md b/tests/gpt_profile.md index 8d8ea620..8a6aad15 100644 --- a/tests/gpt_profile.md +++ b/tests/gpt_profile.md @@ -1,7 +1,7 @@ -| layer | end2end | param | activation | e2e - 4 * p | activation2 | -|:------|:--------|:------|:-----------|:------------|:------------| -| 1 | 1.59 | 0.24 | 0.54 | 0.63 | 0.47 | -| 2 | 1.98 | 0.29 | 0.86 | 0.82 | 0.73 | -| 4 | 2.78 | 0.38 | 1.51 | 1.26 | 1.24 | -| 8 | 4.37 | 0.57 | 2.79 | 2.09 | 2.26 | -| 16 | 7.55 | 0.95 | 5.39 | 3.75 | 4.30 | \ No newline at end of file +| layer | end2end | param | e2e - 4 * p | activation | activation2 | +|:------|:--------|:------|:------------|:-----------|:------------| +| 1 | 1.59 | 0.24 | 0.63 | 0.54 | 0.47 | +| 2 | 1.98 | 0.29 | 0.82 | 0.86 | 0.73 | +| 4 | 2.78 | 0.38 | 1.26 | 1.51 | 1.24 | +| 8 | 4.37 | 0.57 | 2.09 | 2.79 | 2.26 | +| 16 | 7.55 | 0.95 | 3.75 | 5.39 | 4.30 | \ No newline at end of file From e307a519af4cd5cac1b889bc11cb92245c2e7760 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Nov 2022 19:47:21 +0800 Subject: [PATCH 1124/1892] setup with deepspeed model --- benchmark/deepspeed/benchmark_gpt.sh | 136 ++++++++++++++++++++ benchmark/deepspeed/gpt_bench.py | 186 +++++++++++++++++++++++++++ 2 files changed, 322 insertions(+) create mode 100755 benchmark/deepspeed/benchmark_gpt.sh create mode 100644 benchmark/deepspeed/gpt_bench.py diff --git a/benchmark/deepspeed/benchmark_gpt.sh b/benchmark/deepspeed/benchmark_gpt.sh new file mode 100755 index 00000000..a73e45ae --- /dev/null +++ b/benchmark/deepspeed/benchmark_gpt.sh @@ -0,0 +1,136 @@ +#!/bin/bash +# run at MagicCube/ +# ./benchmark/deepspeed/benchmark_gpt.sh + +# get commit ID: +# git rev-parse --short HEAD + +# installation +# pip install deepspeed==0.7.4 +# git clone https://github.com/microsoft/Megatron-DeepSpeed +# git checkout 54f1cb7 + +# note DeepSpeed can do: +# 1) PP > 1 with constraints of Zero-Stage=1 +# 2) TP > 1 with constraints of Zero-Stage < 3 + +cp benchmark/deepspeed/pretrain_gpt_synthetic.py \ + benchmark/deepspeed/Megatron-DeepSpeed/ + +Nnodes=1 +TP=2 +PP=2 + +# Model arch +Layers=12 +Hidden=2048 +Heads=32 +Seqlen=2048 + +# batch size +Gbs=8 +Mbs=1 +Accum=$(( ${Gbs} / ( ${Nnodes} * 8 / ${TP} / ${PP} * ${Mbs} ) )) +echo "Accumulated steps: ${Accum}" + +# zero stage config +Zero=1 +OFFLOAD_DEVICE="none" +CPU_OPTIM=" " +#OFFLOAD_DEVICE="cpu" +#CPU_OPTIM=" --cpu-optimizer" + +cd benchmark/deepspeed/Megatron-DeepSpeed + +DS_CONFIG=ds_config.json + +cat < $DS_CONFIG +{ + "train_batch_size" : $Gbs, + "train_micro_batch_size_per_gpu": $Mbs, + "steps_per_print": 1, + "gradient_accumulation_steps": ${Accum}, + "zero_optimization": { + "stage": $Zero, + "stage3_max_live_parameters": 3e9, + "stage3_max_reuse_distance": 3e9, + "stage3_param_persistence_threshold": 1e5, + "stage3_prefetch_bucket_size": 5e7, + "contiguous_gradients": true, + "overlap_comm": true, + "reduce_bucket_size": 90000000, + "sub_group_size": 1e9, + "offload_optimizer": { + "device": "$OFFLOAD_DEVICE", + "buffer_count": 4, + "pipeline_read": false, + "pipeline_write": false, + "pin_memory": true + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "initial_scale_power" : 15, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "wall_clock_breakdown": true, + "zero_allow_untested_optimizer": false, + "aio": { + "block_size": 1048576, + "queue_depth": 16, + "single_submit": false, + "overlap_events": true, + "thread_count": 2 + } +} +EOT + +# export NCCL_DEBUG=warn + +ds_args=" " +ds_args=" --deepspeed ${ds_args}" +# ds_args=" --no-pipeline-parallel ${ds_args}" +ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}" +ds_args=" --zero-stage=$Zero ${ds_args}" +ds_args=" --deepspeed-activation-checkpointing ${ds_args}" + + +GPT_ARGS="--num-layers $Layers \ + --hidden-size $Hidden \ + --num-attention-heads $Heads \ + --seq-length $Seqlen \ + --loss-scale 15 \ + --max-position-embeddings $Seqlen \ + --train-iters 3 \ + --lr 6.0e-5 \ + --min-lr 6.0e-6 \ + --lr-decay-style cosine \ + --fp16 \ + --fp16-lm-cross-entropy \ + --no-query-key-layer-scaling \ + --no-masked-softmax-fusion \ + --no-bias-gelu-fusion \ + --no-bias-dropout-fusion \ + --checkpoint-activations \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --init-method-std 0.006 \ + --log-interval 1 \ + --num-workers 0" + +# deepspeed --force_multi --num_nodes +deepspeed --num_nodes=$Nnodes --num_gpus 8 \ + --master_addr localhost --master_port 6144 \ + pretrain_gpt_synthetic.py \ + $GPT_ARGS $CPU_OPTIM $ds_args \ + --global-batch-size $Gbs \ + --micro-batch-size $Mbs \ + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP + +cd ../../.. diff --git a/benchmark/deepspeed/gpt_bench.py b/benchmark/deepspeed/gpt_bench.py new file mode 100644 index 00000000..9980f7a4 --- /dev/null +++ b/benchmark/deepspeed/gpt_bench.py @@ -0,0 +1,186 @@ +""" +Following + +https://github.com/microsoft/DeepSpeedExamples/blob/master/HelloDeepSpeed/train_bert_ds.py + +Config file: +https://www.deepspeed.ai/docs/config-json/ + +deepspeed --num_nodes 1 --num_gpus 8 \ + benchmark/deepspeed/gpt_bench.py \ + --fp16 --mbs 1 --gbs 4 \ + --zero 2 \ + --layers 24 --heads 32 --hidden 2048 --seqlen 2048 + +""" + +from typing import List, Tuple +import torch +import time +import numpy as np +import os +import logging + +from examples.nlp.gpt.model import GPT, Config +from examples.nlp.gpt.model import GPTDataLoader + +import argparse +import deepspeed + +logging.getLogger().setLevel(logging.WARN) + + +parser = argparse.ArgumentParser(description='GPT Train') + +parser.add_argument('--fp16', action='store_true', default=False, + help='use fp16 for the training') +parser.add_argument('--mbs', type=int, default=1, + help='micro-batch size') +parser.add_argument('--gbs', type=int, default=256, + help='global batch size') +parser.add_argument('--zero', type=int, required=True, + help='zero stage, 2 or 3') +parser.add_argument('--layers', type=int, required=True) +parser.add_argument('--heads', type=int, required=True) +parser.add_argument('--seqlen', type=int, required=True) +parser.add_argument('--hidden', type=int, required=True) + +parser.add_argument('--local_rank', type=int) +args = parser.parse_args() + +print(args) +torch.cuda.set_device(args.local_rank) + +ds_zero3_config = { + "train_micro_batch_size_per_gpu": args.mbs, + "gradient_accumulation_steps": args.gbs // args.mbs, + "zero_optimization": { + "stage": 3, + "offload_param": { # Zero-3 + "device": "cpu" + }, + "offload_optimizer": { # Zero-2 + "device": "cpu" + }, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "betas": [0.9, 0.95] + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "wall_clock_breakdown": True, + "steps_per_print": 1, +} + + +ds_zero2_config = { + "train_micro_batch_size_per_gpu": args.mbs, + "gradient_accumulation_steps": args.gbs // args.mbs, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { # Zero-2 + "device": "cpu" + }, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "mp_size": 2, + "activation_checkpointing": { + "partition_activations": True, + "cpu_checkpointing": True, + "contiguous_memory_optimization": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "betas": [0.9, 0.95] + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "wall_clock_breakdown": True, + "steps_per_print": 1, +} + +assert args.zero in [2, 3], f"Zero stage can only be 2 or 3" +zero_config = ds_zero2_config if args.zero == 2 else ds_zero3_config + +def log_dist(message: str, ranks: List[int] = None) -> None: + my_rank = int(os.environ.get("RANK", "0")) + if my_rank in ranks: + print(f"rank [{my_rank}] {message}") + + +def train(): + + batch_size = args.mbs + Config.seqlen = args.seqlen + Config.layers = args.layers + Config.embed_dim = args.hidden + Config.attention_heads = args.heads + + model = GPT() + model = model if not args.fp16 else model.half() + + nparams = 0 + param: torch.Tensor + for param in model.parameters(): + nparams += param.nelement() + log_dist(f'parameter before zero: {nparams}', [0]) + + model, _, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=zero_config) + model.train() + log_dist("DeepSpeed engine created", ranks=[0]) + + nparams = 0 + param: torch.Tensor + for param in model.parameters(): + nparams += param.nelement() + log_dist(f'parameter after zero: {nparams}', [0]) + + dataloader = GPTDataLoader(batch_size) + + + iter_num = 3 + warmup = 1 + for step in range(iter_num): + if step == warmup: + torch.cuda.synchronize() + tic = time.time() + + data = next(dataloader) + loss = model(*data) + model.backward(loss) + model.step() + + if step == 0: + log_dist('passed first iteration', ranks=[0]) + if (step + 1) % 2 == 0: + log_dist(f'iter [{step + 1}/{iter_num}]', ranks=[0]) + torch.cuda.synchronize() + toc = time.time() + log_dist(f"iteration time: {(toc-tic) / (iter_num - warmup) * 1000} ms", ranks=[0]) + log_dist(f"Max allocated memory: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024} GB", [0]) + +if __name__ == '__main__': + train() \ No newline at end of file From f9f35f83b18dbb224ad04efcbaa685f0451cc383 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Nov 2022 19:53:07 +0800 Subject: [PATCH 1125/1892] add benchmark script of deepspeed --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 951faa42..715393f3 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__ .vscode/ benchmark/megatron/Megatron-LM +benchmark/deepspeed/Megatron-DeepSpeed gencode*.py fullmodel.pt From 4234390dd30575bec08c4a1facf2ff9384310ee3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 3 Nov 2022 19:54:19 +0800 Subject: [PATCH 1126/1892] deepspeed bench --- benchmark/deepspeed/pretrain_gpt_synthetic.py | 222 ++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 benchmark/deepspeed/pretrain_gpt_synthetic.py diff --git a/benchmark/deepspeed/pretrain_gpt_synthetic.py b/benchmark/deepspeed/pretrain_gpt_synthetic.py new file mode 100644 index 00000000..32de2eb2 --- /dev/null +++ b/benchmark/deepspeed/pretrain_gpt_synthetic.py @@ -0,0 +1,222 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pretrain GPT""" + +import torch +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron import mpu +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import GPTModel, GPTModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +import os +import subprocess + +from torch import nn +import torch.nn.functional as F + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + see_memory_usage(f"Before Building Model", force=True) + + args = get_args() + vocab_size = 50257 + after = vocab_size + multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size + while after % multiple != 0: + after += 1 + args.padded_vocab_size = after + + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): + if args.deepspeed and not args.no_pipeline_parallel: + print_rank_0('building GPT model using DeepSpeed ...') + model = GPTModelPipe( + num_tokentypes=0, + parallel_output=True + ) + # This is a hack to give us a reference to get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + # Predompute the attention mask and store it in args. This avoids having to + # pipeline it as an activation during training. The mask is constant, and thus + # we can reuse it. + attention_mask = torch.tril(torch.ones( + (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view( + 1, 1, args.seq_length, args.seq_length) + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + # Attention mask must be bool. + args.attn_mask = attention_mask.to(torch.bool) + else: + model = GPTModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + return_moe_loss=False + ) + + see_memory_usage(f"After Building Model", force=True) + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + vocab_size = 50257 + tokens = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size + labels = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size + loss_mask = torch.ones(tokens.size(), dtype=torch.float, device=torch.cuda.current_device()) + attention_mask = torch.tril(torch.ones( + (args.micro_batch_size, args.seq_length, args.seq_length), device=torch.cuda.current_device() + )).view(args.micro_batch_size, 1, args.seq_length, args.seq_length) + attention_mask = (attention_mask < 0.5) + position_ids = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * args.seq_length + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def get_batch_pipe(data): + """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" + # args = get_args() + # tokenizer = get_tokenizer() + + # # Items and their type. + # keys = ['text'] + # datatype = torch.int64 + # + # # Broadcast data. + # data_b = mpu.broadcast_data(keys, data, datatype) + # + # # Unpack. + # tokens_ = data_b['text'].long() + # labels = tokens_[:, 1:].contiguous() + # tokens = tokens_[:, :-1].contiguous() + # + # # Get the masks and postition ids. + # attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + # tokens, + # tokenizer.eod, + # args.reset_position_ids, + # args.reset_attention_mask, + # args.eod_mask_loss) + # if args.curriculum_learning and args.curriculum_seqlen < tokens.size()[1]: + # # seqlen-based curriculum learning + # # tokens, position_ids, labels, loss_mask have size [batch size, seqlen] + # tokens = tokens[:, :args.curriculum_seqlen].contiguous() + # position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + # if labels is not None: + # labels = labels[:, :args.curriculum_seqlen].contiguous() + # loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()\ + tokens, labels, loss_mask, attention_mask, position_ids = get_batch(None) + + return (tokens, position_ids, attention_mask), (labels, loss_mask) + + + +def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. + timers('batch-generator').start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + moe_loss = 0 + mos_loss = 0 + return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + return [1]*10000, None, None + + +def command_exists(cmd): + result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + return result.wait() == 0 + + +def git_ds_info(): + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****') + + +if __name__ == "__main__": + git_ds_info() + pretrain(train_valid_test_datasets_provider, model_provider, forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) + mem = torch.cuda.max_memory_allocated() + for rank in range(torch.distributed.get_world_size()): + if rank == torch.distributed.get_rank(): + print(f'rank[{rank}]: memory consumption: {round(mem / 1024 / 1024 / 1024 * 100) / 100} GBs') + torch.distributed.barrier() \ No newline at end of file From 671170254d113db5fe0bd74ef9e002bc32476d8a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 4 Nov 2022 04:57:07 +0000 Subject: [PATCH 1127/1892] add flag to sleep workers to avoid nccl time out --- cube/compiler.py | 4 ++++ cube/flags.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/cube/compiler.py b/cube/compiler.py index 96f0c52e..8dace4a3 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -21,6 +21,7 @@ from cube.runtime.syndata import CubeDataLoader from cube.program import Program, SemanticDataLoader, SemanticModel +from cube.flags import CompileFlag def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, @@ -186,6 +187,9 @@ def decorator(fn: Callable) -> Callable: print('> compile time: {:.2f} seconds'.format(compile_time)) if torch.distributed.is_initialized(): + if DeviceGroup().local_rank != 0 and CompileFlag.worker_sleep > 0: + print(f'rank [{DeviceGroup().rank}] starts sleeping {CompileFlag.worker_sleep} seconds...') + time.sleep(CompileFlag.worker_sleep) torch.distributed.barrier() # load module diff --git a/cube/flags.py b/cube/flags.py index 86ed1776..2b2c3ff4 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -13,6 +13,11 @@ class CompileFlag: # ============= loggings =================== log_transform = os.environ.get('LOG_TRANSFORM') + + # ================ compiling ======================== + # worker sleep in seconds + worker_sleep = int(os.environ.get('WORKER_SLEEP')) if os.environ.get('WORKER_SLEEP') is not None else 0 + # ============ code generation =============== use_nnfusion = os.environ.get('USE_NNFUSION') use_jit = os.environ.get('USE_JIT') From b98b0574fbd4a0c7223f6a7707f137a440068c13 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 4 Nov 2022 04:57:44 +0000 Subject: [PATCH 1128/1892] add 8 node sync --- scripts/sync.sh | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/scripts/sync.sh b/scripts/sync.sh index 9a2a365f..928958d5 100755 --- a/scripts/sync.sh +++ b/scripts/sync.sh @@ -9,6 +9,17 @@ folder=$2 host=worker +if [ ${node_num} == 8 ] +then + scp -r /workspace/MagicCube/$folder $host-1:/workspace/MagicCube/ + scp -r /workspace/MagicCube/$folder $host-2:/workspace/MagicCube/ + scp -r /workspace/MagicCube/$folder $host-3:/workspace/MagicCube/ + scp -r /workspace/MagicCube/$folder $host-4:/workspace/MagicCube/ + scp -r /workspace/MagicCube/$folder $host-5:/workspace/MagicCube/ + scp -r /workspace/MagicCube/$folder $host-6:/workspace/MagicCube/ + scp -r /workspace/MagicCube/$folder $host-7:/workspace/MagicCube/ +fi + if [ ${node_num} == 4 ] then scp -r /workspace/MagicCube/$folder $host-1:/workspace/MagicCube/ @@ -19,4 +30,11 @@ fi if [ ${node_num} == 2 ] then scp -r /workspace/MagicCube/$folder $host-1:/workspace/MagicCube/ -fi \ No newline at end of file +fi + + +# rm -f notify.py +# wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py +# python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ +# --msg "Test Results Swin Coshard | 32 GPU" \ +# --file logs/e2e-swin-32gpu-coshard-${NODE_RANK}.txt \ No newline at end of file From 0d42fa67bba50a672994909b6929db36187b7895 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Fri, 4 Nov 2022 21:41:30 +0800 Subject: [PATCH 1129/1892] save work --- cube/profiler/database.py | 21 ++++++++++++++++----- examples/nlp/blocks/attention.py | 32 ++++++++++++++++---------------- examples/nlp/gpt/model.py | 4 ++-- tests/gpt_profile.md | 14 +++++++------- tests/test_profile_gpt.py | 6 +++--- 5 files changed, 44 insertions(+), 33 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index e73d7ec7..040584f4 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -89,13 +89,24 @@ def run_step(func, tensors, kwargs, backward: bool): mtoc = torch.cuda.max_memory_allocated() # in bytes infer_memory = mtoc - mtic + train_memory = 0 + # ref torch/utils/checkpoint.py/_checkpoint_without_reentrant + def pack_hook(x): + nonlocal train_memory + byte_size = 1 + for dim in list(x.size()): + byte_size = byte_size * dim + byte_size *= x.element_size() + train_memory= train_memory + byte_size + return x + + def unpack_hook(x): + return x + torch.cuda.synchronize() torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - mtic = torch.cuda.max_memory_allocated() # in bytes - outs = run_step(func, tensors, train_kwargs, backward=False) - mtoc = torch.cuda.max_memory_allocated() # in bytes - train_memory = mtoc - mtic + with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + outs = run_step(func, tensors, train_kwargs, backward=False) # warmup tic = time.time() diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 0c3ac0b6..eba5d666 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -19,22 +19,22 @@ def self_attention(query: torch.Tensor, v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d # ======== replace the semantic into more efficient implementation ============ - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L + # q = q.transpose(0, 1) # L (N h) d -> (N h) L d + # k = k.transpose(0, 1) # L (N h) d -> (N h) L d + # q = q * scale # (N h) L d, 1 -> (N h) L d + # k = k.transpose(1, 2) # (N h) L d -> (N h) d L + # attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - # # preallocating input tensor: (N h) L L - # matmul_input_buffer = torch.empty([N * h, L, L], dtype=query.dtype, device=query.device) - # # L (N h) d, L (N h) d -> (N h) L L - # attn = torch.baddbmm( - # matmul_input_buffer, - # q.transpose(0, 1), # (N h) L d - # k.transpose(0, 1).transpose(1, 2), # (N h) d L - # beta=0.0, alpha=scale - # ) - # # ======== replace the semantic into more efficient implementation ============ + # preallocating input tensor: (N h) L L + matmul_input_buffer = torch.empty([N * h, L, L], dtype=query.dtype, device=query.device) + # L (N h) d, L (N h) d -> (N h) L L + attn = torch.baddbmm( + matmul_input_buffer, + q.transpose(0, 1), # (N h) L d + k.transpose(0, 1).transpose(1, 2), # (N h) d L + beta=0.0, alpha=scale + ) + # ======== replace the semantic into more efficient implementation ============ # attention mask if mask: # (N h) L L -> (N h) L L @@ -182,7 +182,7 @@ def forward(self, query): attn = self_attention( query, self.qkv_proj, self.qkv_bias, self.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=False + self.num_heads, self.scaling, self.dropout_p, mask=True ) attn = attn + self.out_bias return attn diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index b689d4de..a8df651d 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -12,7 +12,7 @@ class Config: # toy model embed_dim = 1024 - layers = 4 # 96 + layers = 8 # 96 attention_heads = 16 # # 1 layer of 175B model @@ -37,7 +37,7 @@ class Config: # 6.7 B model # embed_dim = 4096 - # layers = 1 + # layers = 32 # attention_heads = 32 # 15 B model diff --git a/tests/gpt_profile.md b/tests/gpt_profile.md index 8a6aad15..db4c0763 100644 --- a/tests/gpt_profile.md +++ b/tests/gpt_profile.md @@ -1,7 +1,7 @@ -| layer | end2end | param | e2e - 4 * p | activation | activation2 | -|:------|:--------|:------|:------------|:-----------|:------------| -| 1 | 1.59 | 0.24 | 0.63 | 0.54 | 0.47 | -| 2 | 1.98 | 0.29 | 0.82 | 0.86 | 0.73 | -| 4 | 2.78 | 0.38 | 1.26 | 1.51 | 1.24 | -| 8 | 4.37 | 0.57 | 2.09 | 2.79 | 2.26 | -| 16 | 7.55 | 0.95 | 3.75 | 5.39 | 4.30 | \ No newline at end of file +| layer | end2end | activation | param | e2e - 3 * p - activation | +|:------|:--------|:-----------|:------|:-------------------------| +| 1 | 1.59 | 0.47 | 0.24 | 0.40 | +| 2 | 1.98 | 0.73 | 0.29 | 0.38 | +| 4 | 2.78 | 1.24 | 0.38 | 0.40 | +| 8 | 4.37 | 2.26 | 0.57 | 0.40 | +| 16 | 7.55 | 4.30 | 0.95 | 0.40 | \ No newline at end of file diff --git a/tests/test_profile_gpt.py b/tests/test_profile_gpt.py index 65c0da67..c5448103 100644 --- a/tests/test_profile_gpt.py +++ b/tests/test_profile_gpt.py @@ -86,11 +86,11 @@ def train_iter(model, dataloader): CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) - memory_summary() + # memory_summary() optimizer.step() - memory_summary() + # memory_summary() optimizer.zero_grad() - memory_summary() + # memory_summary() if step == 0: print_each_rank('passed first iteration') From 6a66217b9cd9562c68dd2e8ed8ebda1e0bc7abd0 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Fri, 4 Nov 2022 22:57:35 +0800 Subject: [PATCH 1130/1892] refine for pull request --- cube/compiler.py | 1 - tests/{gpt_profile.md => gpt_memory_profile.md} | 2 ++ tests/test_profile_gpt.py | 7 +------ 3 files changed, 3 insertions(+), 7 deletions(-) rename tests/{gpt_profile.md => gpt_memory_profile.md} (91%) diff --git a/cube/compiler.py b/cube/compiler.py index d9587a2e..a4ea1a44 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -91,7 +91,6 @@ def decorator(fn: Callable) -> Callable: print('warning: dataloader batch size stay as default.') # load module code print_each_rank(f'loading existed module from {filename} ...') - # model.load_module(filename, load_content=load_content) model.load_module(filename) # load schedule code print_each_rank(f'loading existed schedule from {filename} ...') diff --git a/tests/gpt_profile.md b/tests/gpt_memory_profile.md similarity index 91% rename from tests/gpt_profile.md rename to tests/gpt_memory_profile.md index db4c0763..c828c185 100644 --- a/tests/gpt_profile.md +++ b/tests/gpt_memory_profile.md @@ -1,3 +1,5 @@ +# GPT-3 toy model memory profiling result + | layer | end2end | activation | param | e2e - 3 * p - activation | |:------|:--------|:-----------|:------|:-------------------------| | 1 | 1.59 | 0.47 | 0.24 | 0.40 | diff --git a/tests/test_profile_gpt.py b/tests/test_profile_gpt.py index c5448103..458aac52 100644 --- a/tests/test_profile_gpt.py +++ b/tests/test_profile_gpt.py @@ -1,10 +1,5 @@ """ -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=1 \ - examples/nlp/gpt/train.py --fp16 +torchrun --nproc_per_node=1 test/test_profile_gpt.py """ From ffc9d14f2dba037d02d44d537ec6f7b9aaec2230 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Sun, 6 Nov 2022 15:40:12 +0800 Subject: [PATCH 1131/1892] update profile interfaces --- cube/ir/cten.py | 5 ++++- cube/ir/dtype.py | 14 ++++++++++++++ cube/profiler/database.py | 39 ++++++++++++++++++++++++++------------- tests/test_profile_gpt.py | 3 ++- 4 files changed, 46 insertions(+), 15 deletions(-) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index f46dc154..ed4d92b2 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -19,7 +19,7 @@ import copy from cube.ir.unique import IDGenerator -from cube.ir.dtype import IRDType +from cube.ir.dtype import IRDType, dtype2byte_size __all__ = ['IRCell', 'IRDType', 'IRTensor'] @@ -596,6 +596,9 @@ def nelement(self) -> int: cnt *= num return cnt + def byte_size(self) -> int: + return self.nelement() * dtype2byte_size(self.dtype) + def backward(self) -> None: """ Autograd backward on the tensor diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index 7a81638c..89b8f6dd 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -15,6 +15,20 @@ class IRDType(Enum): unknown = 'unknown' +def dtype2byte_size(dtype: IRDType) -> int: + return { + IRDType.float64: 8, + IRDType.float32: 4, + IRDType.float16: 2, + IRDType.int64: 8, + IRDType.int32: 4, + IRDType.int16: 2, + IRDType.int8: 1, + IRDType.uint8: 1, + IRDType.boolean: 1, + }.get(dtype, 0) + + class DTypeInferRule: """ Infer the output shape according to given input shapes. diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 040584f4..cb32915a 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -96,8 +96,8 @@ def pack_hook(x): byte_size = 1 for dim in list(x.size()): byte_size = byte_size * dim - byte_size *= x.element_size() - train_memory= train_memory + byte_size + byte_size = byte_size * x.element_size() + train_memory = train_memory + byte_size return x def unpack_hook(x): @@ -188,28 +188,37 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): orig_device = torch.cuda.current_device() torch.cuda.set_device(device) + input_byte_size, param_byte_size = 0, 0 + for t in node.inputs(): + if t.is_param(): + param_byte_size = param_byte_size + t.byte_size() + else: + input_byte_size = input_byte_size + t.byte_size() + # run profiling fw_span, bw_span, infer_memory, train_memory = \ CompProfiler.profile(fn, shapes, dtypes, **kwargs) # log to database key = self._serialize(node) - self.insert(node.signature, key, fw_span, bw_span, infer_memory, train_memory) + self.insert(node.signature, key, input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory) print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " - f"=> fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} ms | " - f"infer mem: {infer_memory} | train mem: {train_memory}") + f"=> in mem {input_byte_size} | param mem: {param_byte_size} | fw: {round(fw_span, 2)} ms | " + f"bw: {round(bw_span, 2)} ms | infer mem: {infer_memory} | train mem: {train_memory}") if isinstance(device, int): torch.cuda.set_device(orig_device) - return fw_span, bw_span, infer_memory, train_memory + return input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory - def insert(self, name: str, key: str, fw_span: float, bw_span: float, - infer_memory: int, train_memory: int): + def insert(self, name: str, key: str, input_byte_size: int, param_byte_size: int, + fw_span: float, bw_span: float, infer_memory: int, train_memory: int): """ log the span of a function name with key @param name str: the function signature @param key str: the encoded shapes and dtypes of node inputs + @param input_byte_size int: byte size of input tensors + @param param_byte_size int: byte size of param tensors @param fw_span float: the forward span time in milliseconds @param bw_span float: the backward span time in milliseconds @param infer_memory int: the peak memory in bytes after inference of the function @@ -218,7 +227,7 @@ def insert(self, name: str, key: str, fw_span: float, bw_span: float, assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = (fw_span, bw_span, infer_memory, train_memory) + self._data[name][key] = (input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory) def exist(self, node: IRFwOperation) -> bool: """ @@ -235,12 +244,14 @@ def exist(self, node: IRFwOperation) -> bool: return False return True - def query(self, node: IRFwOperation) -> Tuple[float, float, int, int]: + def query(self, node: IRFwOperation) -> Tuple[int, int, float, float, int, int]: """! Get the performance number of a node in IRGraph @param node IRFwOperation: node in IRGraph + @return input_byte_size int: byte size of input tensors + @return param_byte_size int: byte size of param tensors @return fw_span float: the forward span time in milliseconds @return bw_span float: the backward span time in milliseconds @return infer_memory int: the peak memory in bytes after inference of the function @@ -253,7 +264,7 @@ def query(self, node: IRFwOperation) -> Tuple[float, float, int, int]: return None return self._data[node.signature][key] - def query_func(self, signature, shapes, dtypes) -> Tuple[float, float, int, int]: + def query_func(self, signature, shapes, dtypes) -> Tuple[int, int, float, float, int, int]: """ Get performance number of given name (signature), shapes and dtypes @@ -261,6 +272,8 @@ def query_func(self, signature, shapes, dtypes) -> Tuple[float, float, int, int] @param shapes Tuple[Tuple[int]]: the shape of each input tensor @param dtypes Tuple[torch.dtype]: the dtype of each tensor + @return input_byte_size int: byte size of input tensors + @return param_byte_size int: byte size of param tensors @return fw_span float: the forward span time in milliseconds @return bw_span float: the backward span time in milliseconds @return infer_memory int: the peak memory in bytes after inference of the function @@ -352,7 +365,7 @@ def __repr__(self) -> str: for signature in self._data: for key in self._data[signature]: shapes, dtypes = self._deserialize(key) - fw_span, bw_span, infer_mem, train_mem = self._data[signature][key] - data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, fw span: {fw_span} ms, bw span: {bw_span} ms, infer mem {infer_mem} bytes, train mem {train_mem} bytes') + input_byte_size, param_byte_size, fw_span, bw_span, infer_mem, train_mem = self._data[signature][key] + data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, in mem {input_byte_size} bytes, param mem {param_byte_size} bytes, fw span: {fw_span} ms, bw span: {bw_span} ms, infer mem {infer_mem} bytes, train mem {train_mem} bytes') data = '\n'.join(data) return data diff --git a/tests/test_profile_gpt.py b/tests/test_profile_gpt.py index 458aac52..eb9a391f 100644 --- a/tests/test_profile_gpt.py +++ b/tests/test_profile_gpt.py @@ -49,7 +49,8 @@ def profile(graph, resource): continue partition_nodes = gen_partitions(node, 1) for partition_node in partition_nodes: - fw_span, bw_span, infer_mem, train_mem = db.profile(partition_node) + in_mem, param_mem, fw_span, bw_span, infer_mem, train_mem = db.profile(partition_node) + print(node.signature, in_mem, param_mem) mem_sum = mem_sum + train_mem db.dump('db.json', override=True) print('estimated train mem: ', mem_sum / 1024 / 1024 / 1024) From ebd3425fc61b0940a6127d0024e33fc5d48e682c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Nov 2022 18:42:54 +0800 Subject: [PATCH 1132/1892] add clone operator and tanh operator --- cube/graph/function/function.py | 23 +++++++++++++++++++++++ cube/graph/parser/mapping.py | 3 +++ 2 files changed, 26 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index d872127a..45966f19 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -243,6 +243,18 @@ def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: return lhs_shape, rhs_shape, out_shape +def Clone(signature, inputs): + """ + torch.clone(input, *, memory_format=torch.preserve_format) + """ + assert len(inputs) == 2, f"inputs: {inputs}" + tensor, memory_format = inputs + annos = ['* -> *'] + tensor = inputs[0] + assert memory_format is None, f"Not supported for a specific memory format" + return IRDimops(Clone, 'clone', signature, annos, [tensor]) + + def Add(signature, inputs): if len(inputs) == 2: kwargs = {} @@ -375,6 +387,7 @@ def Neg(signature, inputs): annos = ['* -> *'] return IRDimops(Neg, 'neg', signature, annos, inputs, **kwargs) + def Sin(signature, inputs): annos = ['* -> *'] tensor = inputs[0:1] @@ -399,6 +412,16 @@ def Cos(signature, inputs): return IRDimops(Cos, 'cos', signature, annos, tensor) +def Tanh(signature, inputs): + """ + torch.tanh(input, *, out=None) + """ + assert len(inputs) == 1, f"inputs: {inputs}" + annos = ['* -> *'] + tensor = inputs[0:1] + return IRDimops(Tanh, 'tanh', signature, annos, tensor) + + def GeLU(signature, inputs): annos = ['* -> *'] signature = 'torch.nn.functional.gelu' diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index b080b4b4..efc2116c 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -89,6 +89,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('tensor'): function.NewTensor, __ttemplate('to'): function.ToTensor, __ttemplate('rand'): function.Rand, + __ttemplate('clone'): function.Clone, __ttemplate('add') : function.Add, @@ -113,6 +114,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('cos'): function.Cos, + __ttemplate('tanh'): function.Tanh, + __ttemplate('bmm') : function.BatchLinear, __ttemplate('sum') : function.Sum, From 3a9b1ed4631329cd5634ae9f9c2c8debb5175efa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Nov 2022 18:43:51 +0800 Subject: [PATCH 1133/1892] init mbart cube training --- examples/nlp/mbart/model.py | 150 +++++++++++++++++++++++------- examples/nlp/mbart/policy/mpmd.py | 60 ++++++++++++ examples/nlp/mbart/train.py | 63 +++++++++++-- 3 files changed, 229 insertions(+), 44 deletions(-) create mode 100644 examples/nlp/mbart/policy/mpmd.py diff --git a/examples/nlp/mbart/model.py b/examples/nlp/mbart/model.py index 57faafdb..1dedc966 100644 --- a/examples/nlp/mbart/model.py +++ b/examples/nlp/mbart/model.py @@ -9,31 +9,63 @@ class Config: + TBD = None # to be decided # source and target - max_source_positions = 1024 - max_target_positions = 1024 + num_embeddings = 2500 + hidden = 1024 + heads = 16 + layers = 4 + seqlen = 2048 - num_embeddings = 250027 + max_source_positions = None + max_target_positions = None - encoder_embed_dim = 1024 - encoder_ffn_embed_dim = 4 * 1024 - encoder_layers = 12 - encoder_attention_heads = 16 + encoder_embed_dim = TBD + encoder_ffn_embed_dim = TBD + encoder_layers = TBD + encoder_attention_heads = TBD - decoder_embed_dim = 1024 - decoder_ffn_embed_dim = 4 * 1024 - decoder_layers = 12 - decoder_attention_heads = 16 + decoder_embed_dim = TBD + decoder_ffn_embed_dim = TBD + decoder_layers = TBD + decoder_attention_heads = TBD - attention_dropout = 0.0 - dropout = 0.1 - activation_dropout = 0.0 + attention_dropout = TBD + dropout = TBD + activation_dropout = TBD - pad_token_id = 1 - eos_token_id = 2 + pad_token_id = TBD + eos_token_id = TBD # classification task - num_classes = 3 + num_classes = TBD + + def __init__(self) -> None: + + Config.max_source_positions = Config.seqlen + Config.max_target_positions = Config.seqlen + + Config.encoder_embed_dim = Config.hidden + Config.encoder_ffn_embed_dim = 4 * Config.hidden + Config.encoder_layers = Config.layers + Config.encoder_attention_heads = Config.heads + + Config.decoder_embed_dim = Config.hidden + Config.decoder_ffn_embed_dim = 4 * Config.hidden + Config.decoder_layers = Config.layers + Config.decoder_attention_heads = Config.heads + + Config.attention_dropout = 0.1 + Config.dropout = 0.1 + Config.activation_dropout = 0.1 + + Config.pad_token_id = 1 + Config.eos_token_id = 2 + + Config.num_classes = 3 + + def __repr__(self) -> str: + return f'Config(num_embeddings={Config.num_embeddings}, hidden={Config.hidden}, heads={Config.heads}, layers={Config.layers}, seqlen={Config.seqlen})' class PositionalEmbedding(torch.nn.Embedding): @@ -67,7 +99,8 @@ def __init__( self.out_proj = torch.nn.Linear(inner_dim, num_classes) self.loss_fct = torch.nn.CrossEntropyLoss() - def forward(self, dec: torch.Tensor, labels): + # def forward(self, dec: torch.Tensor, labels): + def forward(self, dec: torch.Tensor): # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] dec = dec[:,-1,:] sentence_represent = dec @@ -76,7 +109,8 @@ def forward(self, dec: torch.Tensor, labels): hidden_states = torch.tanh(hidden_states) hidden_states = self.dropout(hidden_states) logits = self.out_proj(hidden_states) - loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) + # loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) + loss = logits.sum() return loss @@ -85,73 +119,119 @@ class MBartForSentenceClassification(torch.nn.Module): def __init__(self): super().__init__() cfg = Config() + print("Model Arch:", cfg) # embedding - self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.encoder_embed_dim) - + self.vocab = torch.nn.Parameter(torch.empty( + cfg.num_embeddings, cfg.encoder_embed_dim)) # encoder embedding - self.encoder_position = PositionalEmbedding(cfg.max_source_positions, cfg.encoder_embed_dim) + self.embed_offset = 2 + self.encoder_position = torch.nn.Parameter(torch.empty( + cfg.max_source_positions, cfg.encoder_embed_dim)) self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) # encoder layers self.encoders = torch.nn.ModuleList( [EncoderLayer( - cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.decoder_ffn_embed_dim, + cfg.encoder_embed_dim, cfg.encoder_attention_heads, + cfg.encoder_embed_dim, cfg.encoder_ffn_embed_dim, cfg.dropout, cfg.attention_dropout, cfg.activation_dropout ) for _ in range(cfg.decoder_layers)] ) self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) # decoder embedding - self.decoder_position = PositionalEmbedding(cfg.max_target_positions, cfg.decoder_embed_dim) + self.decoder_position = torch.nn.Parameter(torch.empty( + cfg.max_target_positions, cfg.decoder_embed_dim)) self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) # decoder layers self.decoders = torch.nn.ModuleList( [DecoderLayer( - cfg.decoder_embed_dim, cfg.decoder_attention_heads, cfg.decoder_ffn_embed_dim, + cfg.decoder_embed_dim, cfg.decoder_attention_heads, + cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim, cfg.dropout, cfg.attention_dropout, cfg.activation_dropout ) for _ in range(cfg.decoder_layers)] ) - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) self.head = MBartClassificationHead(cfg.decoder_embed_dim, 1024, cfg.num_classes, 0.0) - def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor, labels: torch.Tensor): - + def forward(self, input_ids: torch.Tensor): + """ + The forward is only for benchmark performance, + the original input of input_ids, decoder_input_ids and labels are + simplied by using only ine input_ids. + + The loss computation is also simplified by using sum. + """ + decoder_input_ids = torch.clone(input_ids) # encoder embedding - enc_emb = self.embed(input_ids) + cube.runtime.function.anchor('encoder embedding') + enc_emb = torch.nn.functional.embedding(input_ids, self.vocab) enc_emb = enc_emb * self.embed_scale_encoder - enc_emb = enc_emb + self.encoder_position(input_ids.size(1)) + enc_emb = enc_emb + self.encoder_position enc_emb = self.layernorm_embedding_encoder(enc_emb) - enc_emb = torch.nn.functional.dropout(enc_emb, p=0.0) + enc_emb = torch.nn.functional.dropout(enc_emb, p=0.1) enc = enc_emb.transpose(0, 1) # encoder layers for layer in self.encoders: + cube.runtime.function.anchor('encoder layer') enc = layer(enc) enc = self.layer_norm_encoder(enc) # decoder embedding - dec_emb = self.embed(decoder_input_ids) + cube.runtime.function.anchor('decoder embedding') + dec_emb = torch.nn.functional.embedding(decoder_input_ids, self.vocab) dec_emb = dec_emb * self.embed_scale_decoder - dec_emb = dec_emb + self.decoder_position(decoder_input_ids.size(1)) + dec_emb = dec_emb + self.decoder_position dec_emb = self.layernorm_embedding_decoder(dec_emb) - dec_emb = torch.nn.functional.dropout(dec_emb, p=0.0) + dec_emb = torch.nn.functional.dropout(dec_emb, p=0.1) dec = dec_emb.transpose(0, 1) # decoder layers for layer in self.decoders: + cube.runtime.function.anchor('decoder layer') dec = layer(dec, enc) dec = self.layer_norm_decoder(dec) dec = dec.transpose(0, 1) # head - loss = self.head(dec, labels) + # loss = self.head(dec, labels) + loss = self.head(dec) return loss +class MBartSyntheticDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int): + + self.bs = batch_size + self.cfg = Config() + super().__init__( + shapes=([batch_size, self.cfg.max_source_positions,],), + dtypes=(torch.int64,), + batch_dims=(0,) + ) + self.samples = [self.random_sample()] + + def random_sample(self): + input_ids = torch.randint( + 0, self.cfg.num_embeddings, + size=(self.bs, self.cfg.max_source_positions), + dtype=torch.int64, device=torch.cuda.current_device() + ) + return input_ids + + def __iter__(self): + return self + + def __next__(self): + return self.samples[0] + + class MBartDataLoader(cube.runtime.syndata.CubeDataLoader): def __init__(self, batch_size: int): diff --git a/examples/nlp/mbart/policy/mpmd.py b/examples/nlp/mbart/policy/mpmd.py new file mode 100644 index 00000000..94bd1e18 --- /dev/null +++ b/examples/nlp/mbart/policy/mpmd.py @@ -0,0 +1,60 @@ +from typing import List + +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.anchor import IRGraphAnchor +from cube.ir.cten import IRCell + + +def _group_to_blocks(fnodes) -> List[List[IRCell]]: + """ + Grouping to [ + [Encoder Embed], + [Encoder Layer], [Encoder Layer], ..., + [Decoder Embed], + [Decoder Layer], [Decoder Layer], ... + ] + """ + blocks = [] + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + # encoder embedding + fnodes[indices[0] + 1].comment = f'==> start of encoder embedding' + assert anchors[0].name == 'encoder embedding' + blocks.append(fnodes[0:indices[1]]) + indices.pop(0) + anchors.pop(0) + # encoder layers + lid = 0 + while anchors[0].name == 'encoder layer': + start, end = indices[0], indices[1] + fnodes[start + 1].comment = f'==> start of encoder layer {lid}' + blocks.append(fnodes[start:end]) + indices.pop(0) + anchors.pop(0) + lid += 1 + # decoder embedding + assert anchors[0].name == 'decoder embedding' + blocks.append(fnodes[indices[0]:indices[1]]) + indices.pop(0) + anchors.pop(0) + # decoder layers + lid = 0 + while len(indices) != 0: + assert anchors[0].name == 'decoder layer' + start, end = indices[0], indices[1] if len(indices) > 1 else len(fnodes) + fnodes[start + 1].comment = f'==> start of decoder layer {lid}' + blocks.append(fnodes[indices[0]:end]) + indices.pop(0) + anchors.pop(0) + lid += 1 + return blocks + + +def PASSingle(graph: IRGraph, resource): + assert resource.ngpus == 1 + _ = _group_to_blocks(graph.select(ntype=IRFwOperation)) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index f8eadeeb..64dcf008 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -4,42 +4,87 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=1 \ --nnodes=1 \ - examples/nlp/mbart/train.py + examples/nlp/mbart/train.py --policy PASSingle """ import torch -from examples.nlp.mbart.model import MBartForSentenceClassification -from examples.nlp.mbart.model import MBartDataLoader +from examples.nlp.mbart.model import MBartForSentenceClassification, Config +from examples.nlp.mbart.model import MBartSyntheticDataLoader import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary, model_summary +import examples.nlp.mbart.policy.mpmd as mpmd + +import argparse + +parser = argparse.ArgumentParser(description='GPT Train') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') +parser.add_argument('--fp16', action='store_true', default=False, + help='use fp16 for the training') +# training +parser.add_argument('--gbs', type=int, default=1, help='global batch size') +parser.add_argument('--mbs', type=int, default=1, help='micro batch size') +# arch +parser.add_argument('--vocab', type=int, default=2500, + help='used vocabulary size') +parser.add_argument('--layers', type=int, default=4, + help='layer number of each encoder and decoder') +parser.add_argument('--heads', type=int, default=16, + help='head number') +parser.add_argument('--hidden', type=int, default=2048, + help='head number') +parser.add_argument('--seqlen', type=int, default=1024, + help='sequence length') + +args = parser.parse_args() + +cube.init() +print(args) + +PAS = None +policies = list(mpmd.__dict__.keys()) +policies = [policy for policy in policies if policy.startswith('PAS')] +if args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + def train(): - batch_size = 1 + batch_size = args.mbs + Config.num_embeddings = args.vocab + Config.layers = args.layers + Config.hidden = args.hidden + Config.heads = args.heads + Config.seqlen = args.seqlen model = MBartForSentenceClassification().cuda() - dataloader = MBartDataLoader(batch_size) + dataloader = MBartSyntheticDataLoader(batch_size) optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) print_each_rank('model weight consumpition:') memory_summary() + model = cube.SemanticModel(model) + @cube.compile(model, dataloader, PAS=PAS, override=True, load_content=False) def train_iter(model, dataloader): - input_ids, decoder_input_ids, labels = next(dataloader) - loss = model(input_ids, decoder_input_ids, labels) + input_ids = next(dataloader) + loss = model(input_ids) loss.backward() + model = model.get_gen_module() CudaTimer(enable=False).warmup() iter_num = 64 for step in range(iter_num): - if step == 0: - model_summary(model, next(dataloader)) + # if step == 0: + # model_summary(model, next(dataloader)) if step >= 20: CudaTimer(enable=True).start('e2e') From cd384b3b41bf7c72902668f87b2e93cd0acd8f00 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Nov 2022 18:46:32 +0800 Subject: [PATCH 1134/1892] fix cross attention --- examples/nlp/blocks/attention.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index eba5d666..f704dac4 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -62,7 +62,7 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask=True): + h: int, scale: float, dropout_p: float, mask: bool = True): num_head = h L, N = query.size(0), query.size(1) dim_head = q_proj.size(0) // num_head @@ -84,10 +84,10 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, if mask: # (N h) L L -> (N h) L L attn = attn.view(N, num_head, L, L) ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) + amask = torch.tril(ones) + amask = amask.view(N, 1, L, L) + amask = (amask < 0.5) + attn = attn.masked_fill_(amask, -10000.0) attn = attn.view((N * num_head), L, L) attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L @@ -216,7 +216,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor): self.q_proj, self.q_bias, self.k_proj, self.k_bias, self.v_proj, self.v_bias, - self.out_proj, self.out_bias, + self.out_proj, self.num_heads, self.scaling, self.dropout_p, mask=True ) attn = attn + self.out_bias From 634a43e31a3213b49c9ef764c31778444c393923 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 7 Nov 2022 19:18:20 +0800 Subject: [PATCH 1135/1892] fix bugs on adapter for require grad --- cube/runtime/adapter/collectives.py | 2 +- cube/runtime/executor.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 0a7e0cf5..75c80fd2 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -80,7 +80,7 @@ def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, else: assert rank == dst tensor = torch.empty(shape, dtype=dtype, - device=torch.cuda.current_device(), requires_grad=True + device=torch.cuda.current_device() ) torch.distributed.recv(tensor, src) CudaTimer().stop(field_name='comm', predefined=True) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 91b56d28..ee4604a5 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -54,7 +54,11 @@ def aexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) outputs = subgraph(*input_tensors) else: outputs = subgraph(*input_tensors) - outputs = outputs.requires_grad_() if torch.is_tensor(outputs) else (t.requires_grad_() for t in outputs) + allow_grad_dtypes = (torch.float32, torch.float16) + if torch.is_tensor(outputs) and outputs.dtype in allow_grad_dtypes: + outputs = outputs.requires_grad_() + else: + outputs = (t.requires_grad_() if t.dtype in allow_grad_dtypes else t for t in outputs) return outputs @staticmethod From 448e9bdb7b96c3bddfc51bb342d9fd5a8de5f296 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 8 Nov 2022 10:16:10 +0800 Subject: [PATCH 1136/1892] update comment --- cube/profiler/database.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index cb32915a..b8b9d1a9 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -174,6 +174,8 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @param node IRFwOperation: node of IRGraph @param device int: the device that the node will execute on + @param input_byte_size int: byte size of input tensors + @param param_byte_size int: byte size of param tensors @return fw_span float: the forward span time in milliseconds @return bw_span float: the backward span time in milliseconds @return infer_memory int: the peak memory in bytes after inference of the function From 3693b1ed52f15243a3f92a33f83eb6dde6a30a63 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Nov 2022 10:44:39 +0800 Subject: [PATCH 1137/1892] fix 1f1b schedule --- cube/codegen/codegen.py | 9 ++++- cube/graph/graph.py | 59 +++++++++++++++++++++++++++++- cube/graph/schedule/sched1f1b.py | 58 +++++++++++++++++++++++++++++ cube/graph/schedule/strategy.py | 10 +++-- cube/graph/segment.py | 37 ++++++++++++++----- cube/runtime/schedule/sched1f1b.py | 9 ++++- cube/runtime/schedule/strategy.py | 6 ++- 7 files changed, 171 insertions(+), 17 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index eabb5a74..18a3dca4 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -986,7 +986,14 @@ def emit_schedule_plan(self, schedplan: IRScheduleStrategy, devid: int): kwargs: Dict[str, Any] = schedplan.kwargs(devid) strkwargs = dict() for kwarg, val in kwargs.items(): - name = str(val) if not isinstance(val, IRCell) else 'model.'+self.node_naming(val) + if isinstance(val, IRCell): + name = 'model.' + self.node_naming(val) + elif isinstance(val, (tuple, list)): + brackets = ')' if len(val) != 1 else ',)' + name = '(' + ', '.join('model.' + self.node_naming(n) \ + if isinstance(n, IRCell) else str(n) for n in val) + brackets + else: + name = str(val) strkwargs[kwarg] = name code = ', '.join(f'{kwarg}={name}' for kwarg, name in strkwargs.items()) code = f'{signature}({code})' diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ed1cc3ea..a3f7b508 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,7 +7,7 @@ will be inserted at scheduling time. """ -from typing import Union, Tuple, List, Optional, Dict +from typing import Sequence, Set, Union, Tuple, List, Optional, Dict from cube.graph.function.anchor import IRGraphAnchor from cube.ir.cten import IRTensor, IRCell @@ -22,6 +22,9 @@ from cube.algorithm.generics import GenericDistAlgo +FOp = Union[IRFwOperation, IRDataOperation] + + class IRGraph(IRSegment): """ IRGraph. @@ -422,8 +425,61 @@ def assign(self, node: Union[IRFwOperation, IRDataOperation], device: int) -> bo node.mirror.device = device return True + def reside(self, tensor: IRSubTensor, devices: Union[int, List[int]]): + """ + Allocate an attribute tensor to devices. + """ + assert tensor.is_attr(), f"Only support to set devices for graph attribute tensors" + raise NotImplementedError("Not supported yet") + ## Schedule Policy Primitives ## + def sequential(self, nodes: Sequence[Union[FOp, Set[FOp]]]): + """ + Scheduling Primitive: sequentially execute a list of nodes, + or a list of concurrent nodes. + + Note there should be no dependency from a later node (set) to a previous node (set). + + Note in current implementation we don't check correctness + + Currently only support node (set) from a same device. + + @param nodes Sequence[Set[FOp]]: a sequence of operators or + a sequence of concurrent operators. Note there should be no + """ + assert len(nodes) > 0 + concurrent_groups = [[node] if isinstance(node, IRCell) else node for node in nodes] + segment: IRSegment = self.segment(concurrent_groups[0][0]) + idx = segment.index(nodes[0]) + for group in concurrent_groups[1:]: + for node in group: + assert segment.exist(node, flatten=False), "All nodes should in a same segment" + # TODO: should check every node to see if they can be gathered based on that node + segment.reorder(node, idx) + + def concurrent(self, nodes: Set[Union[FOp, Sequence[FOp]]]): + """ + Scheduling Primitive: concurrently execut a list of nodes, + or a list of sequential nodes. + + Note there should be no dependency from a node (set) to another node (set). + + Currently only suuport node (set) from different devices. + + @param nodes Set[Sequence[Fop]]: a set of operators or + a set of sequential operators. + """ + assert len(nodes) > 0 + seq_groups = [[node] if isinstance(node, IRCell) else node for node in nodes] + segment: IRSegment = self.segment(seq_groups[0][0]) + idx = segment.index(nodes[0]) + for group in seq_groups[1:]: + for node in group: + assert segment.exist(node, flatten=False), "All nodes should in a same segment" + # TODO: should check every node to see if they can be gathered based on that node + segment.reorder(node, idx) + def happen_before(self, node1: IRCell, node2: IRCell, skip=None) -> bool: """ Check node1 -> (happen before) node2 @@ -657,6 +713,7 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: fidx = self.index(fstages[sid][0]) self.finsert(fwop, fidx) else: + fidx = self.index(fstages[sid][0]) self.insert(fwop, fidx) # update stage op group fstages[sid].insert(0, fwop) diff --git a/cube/graph/schedule/sched1f1b.py b/cube/graph/schedule/sched1f1b.py index 37dc6abc..c4a42e14 100644 --- a/cube/graph/schedule/sched1f1b.py +++ b/cube/graph/schedule/sched1f1b.py @@ -1,13 +1,60 @@ from typing import Dict, Optional, List +import numpy as np from cube.ir.cten import IRCell from cube.ir.adapter.adapter import IRAdapter +from cube.ir.adapter.adapter import IRWeightReducer from cube.graph.graph import IRGraph, IRSegment from cube.graph.schedule import IRScheduleStrategy +def reorder_inputs_outputs(node: IRCell, also_mirror: bool = True): + """ + Inplacement reorder forward node inputs and outputs by tensor ID. + + The order of inputs/outputs in backward can also be reordered correspondingly. + """ + assert isinstance(node, (IRCell, IRSegment)) + inputs_tid = np.array([t.tid for t in node.inputs()]) + inputs_idx = np.argsort(inputs_tid) + inputs = [node.input(idx) for idx in inputs_idx] + outputs_tid = np.array([t.tid for t in node.outputs()]) + outputs_idx = np.argsort(outputs_tid) + outputs = [node.output(idx) for idx in outputs_idx] + node._inputs = inputs + node._outputs = outputs + bnode: IRCell = node.mirror + if also_mirror and isinstance(bnode, IRCell): + if isinstance(bnode, IRSegment): + assert len(bnode.inputs()) == len(node.outputs()), f"fnode:\n{node}\nbnode:\n{bnode}" + bnode._inputs = [bnode.input(idx) for idx in outputs_idx] + assert len(bnode.outputs()) == len(node.inputs()), f"fnode:\n{node}\nbnode:\n{bnode}" + bnode._outputs = [bnode.output(idx) for idx in inputs_idx] + else: + # setup input + ftids = [t.tid for t in node.outputs()] + grads = [t.grad for t in node.outputs()] + actvs = [] + for t in bnode.inputs(): + assert t in grads, f"backward gradient is not required by its forward node " + actvs.append(ftids[grads.index(t)]) + inputs_idx = np.argsort(np.array(actvs)) + inputs = [bnode.input(idx) for idx in inputs_idx] + bnode._inputs = inputs + # setup outputs + ftids = [t.tid for t in node.inputs()] + grads = [t.grad for t in node.outputs()] + actvs = [] + for t in bnode.outputs(): + assert t in grads, f"backward gradient is not required by its forward" + actvs.append(ftids[grads.index(t)]) + outputs_idx = np.argsort(np.array(actvs)) + outputs = [bnode.output(idx) for idx in outputs_idx] + bnode._outputs = outputs + + class IRSchedule1F1B(IRScheduleStrategy): """ 1F1B Scheduling @@ -36,12 +83,20 @@ def __init__(self, graph, nmicros: int): self.num_stages: int = -1 # stage id self.stage_id: Dict[int, int] = dict() + # reducers + self.dev_reducers: Dict[int, List[IRWeightReducer]] = dict() # recompute self.recompute = False def apply(self) -> IRGraph: self.mesh() + # reorder input and output by tensor id + for node in self.graph.nodes(): + if isinstance(node, IRSegment) and node.isfw(): + reorder_inputs_outputs(node) + elif isinstance(node, IRAdapter) and node.forward: + reorder_inputs_outputs(node) # each forward has corresponding backward assert all(fseg.mirror in self.segments for fseg in self.segments if fseg.isfw()), \ "Require backward of each forward stage" @@ -72,6 +127,8 @@ def apply(self) -> IRGraph: assert len(self.recvers[fseg.mirror]) == 1, "Expect no forward send at last stage" self.sfadapter[devid] = None if sid == self.num_stages - 1 else self.senders[fseg][0] self.rbadapter[devid] = None if sid == self.num_stages - 1 else self.recvers[fseg.mirror][0] + # weight reducer + self.dev_reducers[devid] = [reducer for reducer in self.reducers if devid in reducer.device] # stage id self.stage_id[devid] = sid @@ -91,6 +148,7 @@ def kwargs(self, devid: int) -> Dict[str, IRCell]: stage_id = self.stage_id[devid], num_stages = self.num_stages, num_microbatch = self.nmicros, + reducers = self.dev_reducers[devid], recompute = self.recompute ) diff --git a/cube/graph/schedule/strategy.py b/cube/graph/schedule/strategy.py index ab660f27..94fc049c 100644 --- a/cube/graph/schedule/strategy.py +++ b/cube/graph/schedule/strategy.py @@ -1,6 +1,5 @@ from typing import Tuple, Dict, Any, List from cube.graph.graph import IRGraph, IRSegment -from cube.graph.function import IRGraphAnchor from cube.ir.adapter.adapter import IRAdapter, IRWeightReducer from cube.ir.cten import IRCell @@ -18,8 +17,8 @@ def __init__(self, graph: IRGraph, nmicros: int) -> None: self.recvers: Dict[IRSegment, List[IRAdapter]] = dict() # the sender adapters for this segment self.senders: Dict[IRSegment, List[IRAdapter]] = dict() - # postprocess after segments - self.post_process: List[IRCell] = [] + # postprocess of weight reducers + self.reducers: List[IRWeightReducer] = [] self.signature: str = '' def apply(self, graph: IRGraph) -> IRGraph: @@ -30,7 +29,8 @@ def kwargs(self, device: int) -> Dict[str, Any]: def mesh(self) -> List[List[int]]: """! - Group operators into segments corresponding to graph stage + Group operators into segments corresponding to graph stage. + Reorder adapter output to match with segment input order """ for segment in self.graph.nodes(): if isinstance(segment, IRSegment): @@ -45,3 +45,5 @@ def mesh(self) -> List[List[int]]: self.recvers[segment].append(adapter) elif self.graph.depends(segment, adapter): self.senders[segment].append(adapter) + if isinstance(adapter, IRWeightReducer): + self.reducers.append(adapter) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 97f19a3b..a473a2f3 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -430,13 +430,14 @@ def insert(self, node: IRCell, index: Union[int, CellPosition]): pos = CellPosition(pos.indices[1:]) segment.insert(node, pos) - def remove(self, node: IRCell, _pos: CellPosition = None) -> CellPosition: + def remove(self, node: IRCell, _pos: Union[int, CellPosition] = None) -> CellPosition: """ Remove a node at index # TODO: check input and output @param node IRCell: the removed node + @param _pos Optional[Union[int, CellPosition]: help to save cost if provide node position. @return index CellPosition: the removed index """ @@ -491,6 +492,20 @@ def replace(self, node: IRCell, new_nodes: List[IRCell]) -> int: self.insert(new_node, idx) return idx + def reorder(self, node: IRCell, index: int): + """ + Reorder an existing node to the index. + + @param node IRCell: the node in this segment, not considering inner segments. + @param index int: the index is under the view of nodes ordering before this call. + + @return None + """ + prev_index = self._nodes.index(node) + self.remove(node, prev_index) + index = index if prev_index >= index else index - 1 + self.insert(index, node) + @contextmanager def update(self, node): """ @@ -510,7 +525,7 @@ def update(self, node): yield node self.insert(node, index) - def exist(self, node: IRCell) -> bool: + def exist(self, node: IRCell, flatten: bool = True) -> bool: """ Check if the node is in this graph @@ -520,10 +535,11 @@ def exist(self, node: IRCell) -> bool: """ if node in self._nodes: return True - for segment in self._nodes: - if not isinstance(segment, IRSegment): continue - if segment.exist(node): - return True + if flatten: + for segment in self._nodes: + if not isinstance(segment, IRSegment): continue + if segment.exist(node, flatten): + return True return False def select(self, name: Optional[str] = None, ntype: Optional[IRCell] = None, flatten: bool = True) -> List[IRCell]: @@ -895,13 +911,16 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: if otensor in adapter_ins: if len(node.device) > 0 and set(otensor.device).issubset(adapter_ins[otensor]): continue - # loss doesn't have consumers - if len(segment.consumers(ftensor)) == 0: - outputs.add(otensor) # from segment outputs if any(t.overlap(otensor) for t in segment.outputs() if isinstance(t, IRSubTensor)): outputs.add(otensor) continue + # loss doesn't have consumers + if len(segment.consumers(ftensor)) == 0: + # TODO: loss judgement should be more robust + if ftensor.nelement() == 1: + outputs.add(otensor) + continue # for outside consumers consumers, ctensors = segment.consumers(ftensor), segment.ctensors(ftensor) consumers = [c for c, t in zip(consumers, ctensors) if t == otensor] diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py index 776d99de..8e056703 100644 --- a/cube/runtime/schedule/sched1f1b.py +++ b/cube/runtime/schedule/sched1f1b.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable +from typing import Callable, Iterable, List import torch from cube.runtime.schedule.strategy import ScheduleABC @@ -16,6 +16,7 @@ def run(segment: Callable, # forward body stage_id: int, num_stages: int, num_microbatch: int, + reducers: List[Callable], # weight reducers recompute=False): # special case: num_stages == 1: use gradient accum @@ -24,6 +25,8 @@ def run(segment: Callable, # forward body inputs = Schedule1F1B.dataloader_step(dataloader) outputs = Schedule1F1B.forward_step(segment, *inputs) input_grads = Schedule1F1B.backward_step(inputs, outputs, (None,)) + for reducer in reducers: + reducer() return num_warmup_microbatches = num_stages - 1 - stage_id @@ -107,5 +110,9 @@ def run(segment: Callable, # forward body # print(f'rank[{torch.distributed.get_rank()}]: line99 send backward') Schedule1F1B.adapter_step(sbadapter, False, *input_grads) + # allreduce gradient + for reducer in reducers: + reducer() + Schedule1F1B.assert_empty() # print(f'rank[{torch.distributed.get_rank()}]: ok here') diff --git a/cube/runtime/schedule/strategy.py b/cube/runtime/schedule/strategy.py index b4953817..2d4fdaec 100644 --- a/cube/runtime/schedule/strategy.py +++ b/cube/runtime/schedule/strategy.py @@ -31,6 +31,8 @@ def backward_step(itensors: List[torch.Tensor], if torch.is_tensor(tensor) and tensor.requires_grad: tensor.retain_grad() CudaTimer().start("backward") + otensors = [t for t in otensors if t.requires_grad] + assert len(otensors) == len(otensor_grads), f"output tensor mismatches with gradient number" torch.autograd.backward(otensors, grad_tensors=otensor_grads) CudaTimer().stop("backward") itensor_grads = [] @@ -54,13 +56,15 @@ def adapter_step(adapter: Callable, require_grad : bool = True, *args): adapter pass """ if adapter is None: return () + args = tuple(t for t in args if torch.is_tensor(t)) CudaTimer().start('adapter') outputs = adapter(*args) CudaTimer().stop('adapter') if not isinstance(outputs, tuple): outputs = (outputs,) if require_grad: - outputs = tuple(t.requires_grad_() if torch.is_tensor(t) else t for t in outputs) + grad_dtypes = (torch.float16, torch.float32) + outputs = tuple(t.requires_grad_() if torch.is_tensor(t) and t.dtype in grad_dtypes else t for t in outputs) return outputs @staticmethod From 1e98d71d0ab1733b8e225be6f719b17961458cf2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Nov 2022 14:27:45 +0800 Subject: [PATCH 1138/1892] parser supports for multiple output operators like chunk; fix 1f1b scheduling --- cube/graph/function/function.py | 16 ++++++++ cube/graph/gener/gen.py | 67 ++++++++++++++++---------------- cube/graph/parser/mapping.py | 2 + cube/graph/parser/parser.py | 15 +++++-- cube/graph/schedule/sched1f1b.py | 60 +++++----------------------- 5 files changed, 72 insertions(+), 88 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 45966f19..3aad8f82 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -879,6 +879,22 @@ def Stack(signature, inputs: Tuple[List[IRTensor], int]): return IRDimops(Stack, 'stack', signature, [anno], tensors, dim=dim) +def Chunk(signature, inputs: Tuple[IRTensor, int, int]): + """ + torch.chunk(input, chunks, dim=0) + """ + assert len(inputs) == 3 + tensor, chunks, dim = inputs + assert tensor.shape[dim] % chunks == 0 + iannos = [ShapeAnno.create_shape_str(tensor.shape)] + oannos = [copy.copy(iannos[0]) for _ in range(chunks)] + iannos[0][dim] = str(tensor.shape[dim]) + for oanno in oannos: + oanno[dim] = str(tensor.shape[dim] // chunks) + anno = OpAnno.create_op_str(iannos, oannos) + return IRDimops(Chunk, 'chunk', signature, [anno], [tensor], chunks=chunks, dim=dim) + + def Select(signature, inputs: Tuple[IRTensor, int, int]): """ torch.select(self:Tensor, dim:int, index:int) -> Tensor diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 701d9269..245dec6e 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -36,15 +36,14 @@ def __init__(self, tensor: IRSubTensor, device: int, self.device = device -def create_dummy(segment: IRSegment) -> List[IRFwOperation]: +def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) -> List[IRFwOperation]: """ Create dummy operators segment inputs and outputs. The backward operator is also inserted. - 1) produce segment input tensors - 2) consume segment output tensors - @param segment IRSegment: the target segment + @param inputs bool: True for creating dummy operators to produce segement's inputs + @param outputs bool: True for creating dummpy operators to consume segment's outputs @return nodes List[IRCell]: the generated operation """ @@ -52,36 +51,38 @@ def create_dummy(segment: IRSegment) -> List[IRFwOperation]: fwops = [] # create inputs - for tensor in segment.inputs(): - devices = [consumer.device for consumer in segment.consumers(tensor.parent)][::-1] - if not isinstance(tensor, IRSubTensor): continue - assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = DummyInputOuput(tensor, 0, is_output=True, name=f'segment{segment.cid}_input') - for devid in devices: - fop = fwop.replicate() - fop.device = devid - if tensor.requires_grad: - fop.output(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) - segment.finsert(fop, 0) - else: - segment.insert(fop, 0) - fwops.append(fop) + if inputs: + for tensor in segment.inputs(): + devices = [consumer.device for consumer in segment.consumers(tensor.parent)][::-1] + if not isinstance(tensor, IRSubTensor): continue + assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" + fwop = DummyInputOuput(tensor, 0, is_output=True, name=f'segment{segment.cid}_input') + for devid in devices: + fop = fwop.replicate() + fop.device = devid + if tensor.requires_grad: + fop.output(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) + segment.finsert(fop, 0) + else: + segment.insert(fop, 0) + fwops.append(fop) # create outputs - for tensor in segment.outputs(): - devices = [producer.device for producer in segment.producers(tensor.parent)] - if not isinstance(tensor, IRSubTensor): continue - assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = DummyInputOuput(tensor, 0, is_input=True, name=f'segment{segment.cid}_output') - for devid in devices: - fop = fwop.replicate() - fop.device = devid - if tensor.requires_grad and segment.mirror != segment: - fop.input(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) - segment.finsert(fop, segment.nnodes) - else: - segment.insert(fop, segment.nnodes) - fwops.append(fop) + if outputs: + for tensor in segment.outputs(): + devices = [producer.device for producer in segment.producers(tensor.parent)] + if not isinstance(tensor, IRSubTensor): continue + assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" + fwop = DummyInputOuput(tensor, 0, is_input=True, name=f'segment{segment.cid}_output') + for devid in devices: + fop = fwop.replicate() + fop.device = devid + if tensor.requires_grad and segment.mirror != segment: + fop.input(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) + segment.finsert(fop, segment.nnodes) + else: + segment.insert(fop, segment.nnodes) + fwops.append(fop) return fwops @@ -259,7 +260,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: return False return True - fdummies = create_dummy(graph) + fdummies = create_dummy(graph, inputs=True, outputs=True) bdummies = [fwop.mirror for fwop in fdummies if fwop.mirror is not None] bgraph: Optional[IRSegment] = graph.mirror diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index efc2116c..a041c507 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -149,6 +149,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('stack'): function.Stack, + __ttemplate('chunk'): function.Chunk, + __ttemplate('flatten'): function.Flatten, __ttemplate('roll'): function.Roll, diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 0bc04e61..660d2364 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -309,13 +309,20 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: ir_node = result if len(ir_node.outputs()) != len(outputs): - raise RuntimeError( - f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" + assert len(outputs) == 1, ( + f"Farse Fail: torchscript has different output number of IR node: {len(outputs)} != {len(ir_node.outputs())}\n" + f"This can only be happend to have pre-defined output number of 1" ) + node_outputs = (ir_node.outputs(),) + # raise RuntimeError( + # f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" + # ) + else: + node_outputs = ir_node.outputs() # handle outputs - for index, output in enumerate(outputs): - frame.add_var(output.debugName(), ir_node.output(index)) + for output, node_output in zip(outputs, node_outputs): + frame.add_var(output.debugName(), node_output) return [ir_node] diff --git a/cube/graph/schedule/sched1f1b.py b/cube/graph/schedule/sched1f1b.py index c4a42e14..e51f442b 100644 --- a/cube/graph/schedule/sched1f1b.py +++ b/cube/graph/schedule/sched1f1b.py @@ -1,6 +1,6 @@ from typing import Dict, Optional, List -import numpy as np +import warnings from cube.ir.cten import IRCell from cube.ir.adapter.adapter import IRAdapter @@ -10,51 +10,6 @@ from cube.graph.schedule import IRScheduleStrategy -def reorder_inputs_outputs(node: IRCell, also_mirror: bool = True): - """ - Inplacement reorder forward node inputs and outputs by tensor ID. - - The order of inputs/outputs in backward can also be reordered correspondingly. - """ - assert isinstance(node, (IRCell, IRSegment)) - inputs_tid = np.array([t.tid for t in node.inputs()]) - inputs_idx = np.argsort(inputs_tid) - inputs = [node.input(idx) for idx in inputs_idx] - outputs_tid = np.array([t.tid for t in node.outputs()]) - outputs_idx = np.argsort(outputs_tid) - outputs = [node.output(idx) for idx in outputs_idx] - node._inputs = inputs - node._outputs = outputs - bnode: IRCell = node.mirror - if also_mirror and isinstance(bnode, IRCell): - if isinstance(bnode, IRSegment): - assert len(bnode.inputs()) == len(node.outputs()), f"fnode:\n{node}\nbnode:\n{bnode}" - bnode._inputs = [bnode.input(idx) for idx in outputs_idx] - assert len(bnode.outputs()) == len(node.inputs()), f"fnode:\n{node}\nbnode:\n{bnode}" - bnode._outputs = [bnode.output(idx) for idx in inputs_idx] - else: - # setup input - ftids = [t.tid for t in node.outputs()] - grads = [t.grad for t in node.outputs()] - actvs = [] - for t in bnode.inputs(): - assert t in grads, f"backward gradient is not required by its forward node " - actvs.append(ftids[grads.index(t)]) - inputs_idx = np.argsort(np.array(actvs)) - inputs = [bnode.input(idx) for idx in inputs_idx] - bnode._inputs = inputs - # setup outputs - ftids = [t.tid for t in node.inputs()] - grads = [t.grad for t in node.outputs()] - actvs = [] - for t in bnode.outputs(): - assert t in grads, f"backward gradient is not required by its forward" - actvs.append(ftids[grads.index(t)]) - outputs_idx = np.argsort(np.array(actvs)) - outputs = [bnode.output(idx) for idx in outputs_idx] - bnode._outputs = outputs - - class IRSchedule1F1B(IRScheduleStrategy): """ 1F1B Scheduling @@ -91,12 +46,15 @@ def __init__(self, graph, nmicros: int): def apply(self) -> IRGraph: self.mesh() - # reorder input and output by tensor id for node in self.graph.nodes(): - if isinstance(node, IRSegment) and node.isfw(): - reorder_inputs_outputs(node) - elif isinstance(node, IRAdapter) and node.forward: - reorder_inputs_outputs(node) + if isinstance(node, IRAdapter) and node.forward: + if len(set(node.outputs())) > 1 or len(set(node.inputs())) > 1: + warnings.warn( + "Detected one adapter has more than one input/output in stage transmission, " + "which is not safe for current scheduling implementation due to potential " + "mis-ordering of arguments. Better to use torch.cat and torch.chunk to " + "merge multiple tensors into one and unpack it at next stage." + ) # each forward has corresponding backward assert all(fseg.mirror in self.segments for fseg in self.segments if fseg.isfw()), \ "Require backward of each forward stage" From ed7e7991161b2c35913366655049b7cc528a634c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Nov 2022 14:44:12 +0800 Subject: [PATCH 1139/1892] change torch.select by using dimops --- cube/graph/function/function.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 3aad8f82..24ddf0a9 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -900,7 +900,13 @@ def Select(signature, inputs: Tuple[IRTensor, int, int]): torch.select(self:Tensor, dim:int, index:int) -> Tensor """ tensor, dim, index = inputs - return IRSelect(signature, [tensor], 'select', dim, index) + ianno = ShapeAnno.create_shape_str(tensor.shape) + oanno = copy.copy(ianno) + ianno[dim] += '^' + oanno[dim] = '1' + anno = OpAnno.create_op_str([ianno], [oanno]) + return IRDimops(Select, 'select', signature, [anno], [tensor], dim=dim, index=index) + # return IRSelect(signature, [tensor], 'select', dim, index) def Slice(signature, inputs: Tuple[IRTensor, int, Optional[int], Optional[int], int]): """ From 29cfe4f3a98e493a2c08513b0933bb76de7b8a74 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Nov 2022 14:44:47 +0800 Subject: [PATCH 1140/1892] add mbart megatron policy example --- examples/nlp/mbart/model.py | 37 +++++- examples/nlp/mbart/policy/mpmd.py | 195 +++++++++++++++++++++++++++++- examples/nlp/mbart/train.py | 31 +++-- 3 files changed, 242 insertions(+), 21 deletions(-) diff --git a/examples/nlp/mbart/model.py b/examples/nlp/mbart/model.py index 1dedc966..3c2ee51a 100644 --- a/examples/nlp/mbart/model.py +++ b/examples/nlp/mbart/model.py @@ -7,6 +7,11 @@ import cube +@cube.graph.parser.register('* -> *, *', name='multi2ref') +def multi2ref(tensor: torch.Tensor): + return tensor, tensor + + class Config: TBD = None # to be decided @@ -102,7 +107,8 @@ def __init__( # def forward(self, dec: torch.Tensor, labels): def forward(self, dec: torch.Tensor): # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] - dec = dec[:,-1,:] + dec = torch.select(dec, dim=1, index=-1) + # dec = dec[:,-1,:] sentence_represent = dec hidden_states = self.dropout(sentence_represent) hidden_states = self.dense(hidden_states) @@ -116,9 +122,10 @@ def forward(self, dec: torch.Tensor): class MBartForSentenceClassification(torch.nn.Module): - def __init__(self): + def __init__(self, batch_size: int): super().__init__() cfg = Config() + self.vocab_size = cfg.num_embeddings print("Model Arch:", cfg) # embedding self.vocab = torch.nn.Parameter(torch.empty( @@ -155,9 +162,16 @@ def __init__(self): ) for _ in range(cfg.decoder_layers)] ) self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) - self.head = MBartClassificationHead(cfg.decoder_embed_dim, 1024, cfg.num_classes, 0.0) - + + # FIXME: cube now is not safe for multiple + # tensor transmissions between stages. + decoder_input_ids = torch.randint( + 0, self.vocab_size, (batch_size, cfg.seqlen), dtype=torch.int64, device=torch.device('cpu'), + ) + self.register_buffer('decoder_input_ids', decoder_input_ids) + + def forward(self, input_ids: torch.Tensor): """ The forward is only for benchmark performance, @@ -166,7 +180,7 @@ def forward(self, input_ids: torch.Tensor): The loss computation is also simplified by using sum. """ - decoder_input_ids = torch.clone(input_ids) + # decoder_input_ids = torch.clone(input_ids) # encoder embedding cube.runtime.function.anchor('encoder embedding') enc_emb = torch.nn.functional.embedding(input_ids, self.vocab) @@ -184,17 +198,28 @@ def forward(self, input_ids: torch.Tensor): # decoder embedding cube.runtime.function.anchor('decoder embedding') - dec_emb = torch.nn.functional.embedding(decoder_input_ids, self.vocab) + dec_emb = torch.nn.functional.embedding(self.decoder_input_ids, self.vocab) dec_emb = dec_emb * self.embed_scale_decoder dec_emb = dec_emb + self.decoder_position dec_emb = self.layernorm_embedding_decoder(dec_emb) dec_emb = torch.nn.functional.dropout(dec_emb, p=0.1) dec = dec_emb.transpose(0, 1) + # FIXME: need to cat and chunk because cube now is not safe + # for multiple tensor transformation between stages. + encdec = torch.cat((enc, dec), dim=-1) + # decoder layers for layer in self.decoders: cube.runtime.function.anchor('decoder layer') + enc, dec = torch.chunk(encdec, 2, dim=-1) + + enc, next_enc = multi2ref(enc) + dec = layer(dec, enc) + encdec = torch.cat((next_enc, dec), dim=-1) + + enc, dec = torch.chunk(encdec, 2, dim=-1) dec = self.layer_norm_decoder(dec) dec = dec.transpose(0, 1) diff --git a/examples/nlp/mbart/policy/mpmd.py b/examples/nlp/mbart/policy/mpmd.py index 94bd1e18..5651425a 100644 --- a/examples/nlp/mbart/policy/mpmd.py +++ b/examples/nlp/mbart/policy/mpmd.py @@ -1,10 +1,41 @@ -from typing import List +from typing import List, Tuple +import numpy as np from cube.graph import IRGraph from cube.ir.operator import IRFwOperation, IRDataOperation from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.segment import IRSegment from cube.ir.cten import IRCell - +from cube.graph.schedule.sched1f1b import IRSchedule1F1B + + +def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + + e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: + ( + ( (0,1,2), (3,4,5) ), + ( (0,3), (2,5), (3,6) ), + ) + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + def _group_to_blocks(fnodes) -> List[List[IRCell]]: """ @@ -51,6 +82,31 @@ def _group_to_blocks(fnodes) -> List[List[IRCell]]: return blocks +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): + if len(devs) == 1: + graph.assign(node, devs[0]) + sub_nodes = [node] + else: + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + if len(devs) == 1: + graph.assign(node, devs[0]) + sub_nodes = [node] + else: + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 _ = _group_to_blocks(graph.select(ntype=IRFwOperation)) @@ -58,3 +114,138 @@ def PASSingle(graph: IRGraph, resource): graph.assign(node, 0) return graph + +def PAS1F1B(graph: IRGraph, resource): + + num_stages = resource.ngpus + num_microbatch = 4 + recompute: bool = True + + blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) + enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] + dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] + if recompute: + for block in blocks: + graph.recompute(block) + + # staging + fstages = [[] for _ in range(num_stages)] + nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // num_stages + for lid, fnodes in enumerate(enc_layers + dec_layers): + if lid == 0: + fstages[0] += enc_emb + elif lid == len(enc_layers): + fstages[num_stages // 2] += dec_emb + stage_id = min(lid // nlayer_per_stage, num_stages - 1) + fstages[stage_id] += fnodes + graph.staging(tuple(stage[0] for stage in fstages)) + + dataloader = graph.select(ntype=IRDataOperation)[0] + _replica(graph, dataloader, [0, num_stages // 2]) + + fsegments = [seg for seg in graph.select(ntype=IRSegment, flatten=False) if seg.isfw()] + assert len(fsegments) == num_stages, f"Not match: {len(fsegments)} != {num_stages}" + for devid, segment in enumerate(fsegments): + graph.assign(segment, devid) + + strategy = IRSchedule1F1B(graph, num_microbatch) + graph.predef_sched(strategy) + + return graph + + +def PASMegatronTP(graph: IRGraph, resource): + + tp_size = resource.ngpus + recompute: bool = True + devs = list(range(tp_size)) + + blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) + if recompute: + for block in blocks: + graph.recompute(block) + + for node in graph.select(ntype=IRFwOperation): + if node.name == 'embedding': + _tp(graph, node, devs, idx=1, dim=0) + elif node.name == 'self_attention' or node.name == 'feedforward': + _tp(graph, node, devs, idx=1, dim=0) + elif node.name == 'cross_attention': + _tp(graph, node, devs, idx=2, dim=0) + else: + _replica(graph, node, devs) + + dataloader = graph.select(ntype=IRDataOperation)[0] + _replica(graph, dataloader, devs) + + return graph + + +def PASMegatron(graph: IRGraph, resource): + + dp_size = 2 + tp_size = 2 + pp_size = resource.ngpus // (dp_size * tp_size) + recompute: bool = True + num_microbatch = 16 + + # device mesh + dp_groups, pp_groups, tp_groups = \ + _create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) + print(f'dp groups: {dp_groups}') + print(f'pp groups: {pp_groups}') + print(f'tp groups: {tp_groups}') + + def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: + return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] + + blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) + enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] + dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] + if recompute: + for block in blocks: + graph.recompute(block) + + # pipelien stage + fstages = [[] for _ in range(pp_size)] + nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // pp_size + for lid, fnodes in enumerate(enc_layers + dec_layers): + if lid == 0: + fstages[0] += enc_emb + elif lid == len(enc_layers): + fstages[pp_size // 2] += dec_emb + stage_id = min(lid // nlayer_per_stage, pp_size - 1) + fstages[stage_id] += fnodes + graph.staging(tuple(stage[0] for stage in fstages)) + + # partition dataloader + dataloader = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[0] + dls = _replica(graph, dataloader, [0]*dp_size) # graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) + for dp_idx, dl in enumerate(dls): + devices = [get_device(dp_idx, 0, tp_idx) for tp_idx in range(tp_size)] + _replica(graph, dl, devices) + + # tp-dp partition + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + assert len(fstages) == pp_size + for pp_idx, fstage in enumerate(fstages): + for node in fstage.nodes(): + if len(node.inputs()) == 0: continue # anchor + if node.name == 'embedding': + nodes = _tp(graph, node, [0]*tp_size, idx=1, dim=0) + elif node.name == 'self_attention' or node.name == 'feedforward': + nodes = _tp(graph, node, [0]*tp_size, idx=1, dim=0) + elif node.name == 'cross_attention': + nodes = _tp(graph, node, [0]*tp_size, idx=2, dim=0) + else: + nodes = _replica(graph, node, [0]*tp_size) + # data parallel + for tp_idx, node in enumerate(nodes): + dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] + batch_dim = node.input(0).shape.index(bs) + _tp(graph, node, dp_devices, idx=0, dim=batch_dim) + + strategy = IRSchedule1F1B(graph, num_microbatch) + graph.predef_sched(strategy) + return graph diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index 64dcf008..99f71d2e 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -26,9 +26,9 @@ help='use fp16 for the training') # training parser.add_argument('--gbs', type=int, default=1, help='global batch size') -parser.add_argument('--mbs', type=int, default=1, help='micro batch size') +parser.add_argument('--mbs', type=int, default=2, help='micro batch size') # arch -parser.add_argument('--vocab', type=int, default=2500, +parser.add_argument('--vocab', type=int, default=256, help='used vocabulary size') parser.add_argument('--layers', type=int, default=4, help='layer number of each encoder and decoder') @@ -64,7 +64,7 @@ def train(): Config.heads = args.heads Config.seqlen = args.seqlen - model = MBartForSentenceClassification().cuda() + model = MBartForSentenceClassification(batch_size).cuda() dataloader = MBartSyntheticDataLoader(batch_size) optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) @@ -79,14 +79,20 @@ def train_iter(model, dataloader): loss.backward() model = model.get_gen_module() + for name, buffer in model.named_buffers(): + torch.manual_seed(0) + if name.startswith('decoder_input_ids'): + inputs = torch.randint( + 0, args.vocab, buffer.size(), + dtype=torch.int64, device=torch.cuda.current_device(), + ) + buffer.copy_(inputs) + CudaTimer(enable=False).warmup() - iter_num = 64 + iter_num, warmup = 5, 2 for step in range(iter_num): - # if step == 0: - # model_summary(model, next(dataloader)) - - if step >= 20: + if step == warmup: CudaTimer(enable=True).start('e2e') # training @@ -94,17 +100,16 @@ def train_iter(model, dataloader): optimizer.step() optimizer.zero_grad() - if step >= 20: - CudaTimer().stop('e2e') - if step == 0: print_each_rank('passed first iteration') - if (step + 1) % 10 == 0: + if (step + 1) % 2 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + CudaTimer().stop('e2e') print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) memory_summary() From 8f0ab177d1e56df0c1ce2afc3d603616b2b7ae4d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Nov 2022 15:04:54 +0800 Subject: [PATCH 1141/1892] fix select shape infer --- cube/graph/function/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 24ddf0a9..e9c1c658 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -903,7 +903,7 @@ def Select(signature, inputs: Tuple[IRTensor, int, int]): ianno = ShapeAnno.create_shape_str(tensor.shape) oanno = copy.copy(ianno) ianno[dim] += '^' - oanno[dim] = '1' + oanno.pop(dim) anno = OpAnno.create_op_str([ianno], [oanno]) return IRDimops(Select, 'select', signature, [anno], [tensor], dim=dim, index=index) # return IRSelect(signature, [tensor], 'select', dim, index) From eb9f1db3eb5fc3566058bdec91ec35a7c0ec724c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Nov 2022 14:41:45 +0800 Subject: [PATCH 1142/1892] add new scheduling plan. update compiler --- cube/compiler.py | 16 +-- cube/execplan/planpass/grouping.py | 3 + cube/flags.py | 1 + cube/graph/schedule/sched1f1b.py | 4 +- cube/graph/schedule/schedmix.py | 190 +++++++++++++++++++++++++ cube/graph/schedule/strategy.py | 1 + cube/ir/cten.py | 6 +- cube/runtime/schedule/__init__.py | 3 +- cube/runtime/schedule/sched1f1b.py | 6 +- cube/runtime/schedule/schedmix.py | 213 +++++++++++++++++++++++++++++ cube/runtime/schedule/strategy.py | 6 +- 11 files changed, 427 insertions(+), 22 deletions(-) create mode 100644 cube/graph/schedule/schedmix.py create mode 100644 cube/runtime/schedule/schedmix.py diff --git a/cube/compiler.py b/cube/compiler.py index 85c1d153..d38c69f9 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -66,8 +66,6 @@ def train_step(model, dataloader): dataloader = cube.runtime.syndata.SynDataLoader(shapes=(),dtypes=()) if not isinstance(dataloader, CubeDataLoader): raise TypeError("Expect dataloader derived from CubeDataLoader") - if callable(PAS): - PAS = (PAS,) model.save_content = load_content ir_dataloader = SemanticDataLoader(dataloader) @@ -115,19 +113,14 @@ def decorator(fn: Callable) -> Callable: # run policy graph = Program().get_graph() - if len(PAS) == 1: - graph = PAS[0](graph, resource) - elif len(PAS) == 3: - P, A, S = PAS - graph = P(graph, resource) - graph = A(graph, resource) - graph = S(graph, resource) + assert callable(PAS), f"Policy PAS is not callable" + graph = PAS(graph, resource) if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") # check assignment and remove anchor node - for node in graph.nodes(): + for node in graph.nodes(flatten=True): if isinstance(node, IRGraphAnchor) or isinstance(node.mirror, IRGraphAnchor): continue if len(node.device) == 0: @@ -142,9 +135,10 @@ def decorator(fn: Callable) -> Callable: if graph.sched is not None: start = time.time() graph.sched.apply() + if CompileFlag.log_schedule: + print(graph.sched) span = time.time() - start print('> planpass on applying schedule strategy: {:.2f} s'.format(span)) - print(graph.sched) # to execution plan execplan = ExecutionPlan(graph) diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index c9eff511..002fc3e3 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -17,6 +17,9 @@ class Grouping(PlanPass): def apply(execplan: ExecutionPlan) -> ExecutionPlan: """ Group contiguous differentiable operators segments + + Note non-differentiable IRAdapter with all identity operators will be + removed from execution plan. """ graph = execplan.graph fgroups, bgroups = Grouping.group(execplan) diff --git a/cube/flags.py b/cube/flags.py index 2b2c3ff4..87c36f2c 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -12,6 +12,7 @@ class CompileFlag: # ============= loggings =================== log_transform = os.environ.get('LOG_TRANSFORM') + log_schedule = os.environ.get('LOG_SCHEDULE') # ================ compiling ======================== diff --git a/cube/graph/schedule/sched1f1b.py b/cube/graph/schedule/sched1f1b.py index e51f442b..095df847 100644 --- a/cube/graph/schedule/sched1f1b.py +++ b/cube/graph/schedule/sched1f1b.py @@ -114,9 +114,9 @@ def __repr__(self) -> str: dscp = '' for mesh in self.devmesh: devid = mesh[0] - segment = self.segment[devid].to_str(skip_attr=True) if self.segment[mesh[0]] else None + # segment = self.segments[devid].to_str(skip_attr=True) if self.segment[mesh[0]] else None dscp += (f"1F1B Schedule: Stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" - f" segment = {segment}\n" + f" segment = {self.segments[devid]}\n" f" send-fw = {self.sfadapter[mesh[0]]}\n" f" recv-fw = {self.rfadapter[mesh[0]]}\n" f" send-bw = {self.sbadapter[mesh[0]]}\n" diff --git a/cube/graph/schedule/schedmix.py b/cube/graph/schedule/schedmix.py new file mode 100644 index 00000000..66aee0f4 --- /dev/null +++ b/cube/graph/schedule/schedmix.py @@ -0,0 +1,190 @@ + +from typing import Dict, Optional, List +import warnings + +from cube.ir.cten import IRCell +from cube.ir.adapter.adapter import IRAdapter +from cube.ir.adapter.adapter import IRWeightReducer + +from cube.graph.graph import IRGraph, IRSegment +from cube.graph.schedule import IRScheduleStrategy +from cube.ir.adapter.prim import IdentityPrim + + +class IRScheduleMix(IRScheduleStrategy): + """ + 1F1B Scheduling + + This treats model as a linear graph which can be + grouped into continous stages. + + [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] + [Recv-Backward] Backward-Segment [Send-Backward] + """ + + def __init__(self, graph, nmicros: int): + super().__init__(graph, nmicros) + self.signature = 'cube.runtime.schedule.ScheduleMix.run' + # forward body + self.encoder_barriers: Dict[int, IRSegment] = dict() + self.decoder_barriers: Dict[int, IRSegment] = dict() + self.fsegments: Dict[int, IRSegment] = dict() + # body forward recv adapter + self.rfadapter: Dict[int, Optional[IRAdapter]] = dict() + # body forward send adapter + self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() + # body backward recv adapter + self.rbadapter: Dict[int, Optional[IRAdapter]] = dict() + # body backward send adapter + self.sbadapter: Dict[int, Optional[IRAdapter]] = dict() + # encoder barrier backward prepare adapter + self.enc_badapter: Dict[int, IRAdapter] = dict() + # decoder barrier forward input prepare adapter + self.dec_fadapter: Dict[int, IRAdapter] = dict() + # decoder barrier backward input prepare adapter + self.dec_badapter: Dict[int, IRAdapter] = dict() + # num_stage + self.num_stages: int = -1 + # stage id + self.stage_id: Dict[int, int] = dict() + # reducers + self.dev_reducers: Dict[int, List[IRWeightReducer]] = dict() + # recompute + self.recompute = False + + + def apply(self) -> IRGraph: + self.mesh() + # each forward adapter has only one input and one output for each device + for node in self.graph.nodes(): + if isinstance(node, IRAdapter) and node.forward: + if len(set(node.outputs())) > 1 or len(set(node.inputs())) > 1: + warnings.warn( + "Detected one adapter has more than one input/output in stage transmission, " + "which is not safe for current scheduling implementation due to potential " + "mis-ordering of arguments. Better to use torch.cat and torch.chunk to " + "merge multiple tensors into one and unpack it at next stage." + ) + # each forward has corresponding backward + assert all(fseg.mirror in self.segments for fseg in self.segments if fseg.isfw()), \ + "Require backward of each forward stage" + + fsegments: List[IRSegment] = [fseg for fseg in self.segments if fseg.isfw()] + self.num_stages = len(fsegments) - 2 + + shard_enc_sid, shard_dec_sid = (0, self.num_stages // 2) + print(f'> shard encoder stage id: {shard_enc_sid} | shard decoder stage id: {shard_dec_sid} | num stages: {self.num_stages}') + + shard_enc, shard_dec = fsegments[0], fsegments[shard_dec_sid + 1] + assert len(shard_enc.device) == len(shard_dec.device) and len(shard_enc.device) >= 4, ( + f"This scheduling can only be applied to number of devices >= 4" + ) + pipe_stages = [seg for lid, seg in enumerate(fsegments) if lid not in (shard_enc_sid, shard_dec_sid + 1)] + + # setup shard encoder embedding + assert len(self.recvers[shard_enc.mirror]) == 1 + for devid in shard_enc.device: + self.encoder_barriers[devid] = shard_enc + self.enc_badapter[devid] = self.recvers[shard_enc.mirror][0] + # setup shard decoder embedding + assert len(self.recvers[shard_dec]) == 1 + assert len(self.recvers[shard_dec.mirror]) == 1 + for devid in shard_dec.device: + self.decoder_barriers[devid] = shard_dec + self.dec_fadapter[devid] = self.recvers[shard_dec][0] + self.dec_badapter[devid] = self.recvers[shard_dec.mirror][0] + # pipeline stages + for sid, stage in enumerate(pipe_stages): + assert len(stage.device) == 1 + devid = stage.device[0] + # forward body + assert devid not in self.fsegments, f"Pipeline stage cannot be overlapped" + self.fsegments[devid] = stage + # forward recv + if sid in (shard_enc_sid, shard_dec_sid): + for adapter in self.recvers[stage]: + assert all(isinstance(prim, IdentityPrim) for prim in adapter.prims), ( + f"stage {sid} got unexpected forward recv adapters: {self.recvers[stage]}" + ) + self.rfadapter[devid] = None + else: + assert len(self.recvers[stage]) == 1 + self.rfadapter[devid] = self.recvers[stage][0] + # forward send + if sid == shard_dec_sid - 1: # decoder recv broadcast + assert len(self.senders[stage]) == 1 + self.sfadapter[devid] = None + elif sid == self.num_stages - 1: + assert len(self.senders[stage]) == 0 + self.sfadapter[devid] = None + else: + assert len(self.senders[stage]) == 1 + self.sfadapter[devid] = self.senders[stage][0] + # backward recv + if sid in (shard_dec_sid - 1, self.num_stages - 1): + for adapter in self.recvers[stage.mirror]: + assert all(isinstance(prim, IdentityPrim) for prim in adapter.prims), ( + f"stage {sid} got unexpected backward recv adapters: {self.recvers[stage]}" + ) + self.rbadapter[devid] = None + else: + assert len(self.recvers[stage.mirror]) == 1, \ + f"stage {sid} got unexpected backward recv adapters: {self.recvers[stage.mirror]}" + self.rbadapter[devid] = self.recvers[stage.mirror][0] + # backward send: + if sid == shard_dec_sid: # decoder broadcast + assert len(self.senders[stage.mirror]) == 1 + self.sbadapter[devid] = None + elif sid == shard_enc_sid: # encoder broadcast + assert len(self.senders[stage.mirror]) == 1 + self.sbadapter[devid] = None + else: + self.sbadapter[devid] = self.senders[stage.mirror][0] + + # weight reducer + self.dev_reducers[devid] = [reducer for reducer in self.reducers if devid in reducer.device] + # stage id + self.stage_id[devid] = sid + + return self.graph + + def kwargs(self, devid: int) -> Dict[str, IRCell]: + """ + return kwargs for runtime caller + """ + return dict( + encoder_barrier = self.encoder_barriers[devid], + decoder_barrier = self.decoder_barriers[devid], + segment = self.fsegments[devid], + sfadapter = self.sfadapter[devid], + rfadapter = self.rfadapter[devid], + sbadapter = self.sbadapter[devid], + rbadapter = self.rbadapter[devid], + enc_badapter = self.enc_badapter[devid], + dec_fadapter = self.dec_fadapter[devid], + dec_badapter = self.dec_badapter[devid], + dataloader = 'dataloader', + stage_id = self.stage_id[devid], + num_stages = self.num_stages, + num_microbatch = self.nmicros, + reducers = self.dev_reducers[devid], + recompute = self.recompute + ) + + def __repr__(self) -> str: + dscp = '' + devices = self.devmesh[0] + for devid in devices: + dscp += (f"Interplaced Schedule: Stage[{self.stage_id[devid]}](dev {devid})(\n" + f" encoder_barrier = {self.encoder_barriers[devid]}\n" + f" decoder_barrier = {self.decoder_barriers[devid]}\n" + f" segment = {self.fsegments[devid]}\n" + f" send-fw = {self.sfadapter[devid]}\n" + f" recv-fw = {self.rfadapter[devid]}\n" + f" send-bw = {self.sbadapter[devid]}\n" + f" recv-bw = {self.rbadapter[devid]}\n" + f" enc_badapter = {self.enc_badapter[devid]}\n" + f" dec_fadapter = {self.dec_fadapter[devid]}\n" + f" dec_badapter = {self.dec_badapter[devid]}\n" + f")\n") + return dscp diff --git a/cube/graph/schedule/strategy.py b/cube/graph/schedule/strategy.py index 94fc049c..323b4e08 100644 --- a/cube/graph/schedule/strategy.py +++ b/cube/graph/schedule/strategy.py @@ -35,6 +35,7 @@ def mesh(self) -> List[List[int]]: for segment in self.graph.nodes(): if isinstance(segment, IRSegment): self.segments.append(segment) + self.devmesh.append(segment.device) self.recvers[segment] = [] self.senders[segment] = [] diff --git a/cube/ir/cten.py b/cube/ir/cten.py index ed4d92b2..6ffd6cf0 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -479,11 +479,11 @@ def cell(self, val: Optional[IRCell]): self._cell = val @property - def device(self) -> List[int]: + def device(self) -> Tuple[int]: if self._cell: - return self._cell.device + return tuple(self._cell.device) else: - return [] + return () @device.setter def device(self, val: Union[int, List[int]]): diff --git a/cube/runtime/schedule/__init__.py b/cube/runtime/schedule/__init__.py index b2db67e5..b962ccab 100644 --- a/cube/runtime/schedule/__init__.py +++ b/cube/runtime/schedule/__init__.py @@ -1 +1,2 @@ -from cube.runtime.schedule.sched1f1b import Schedule1F1B \ No newline at end of file +from cube.runtime.schedule.sched1f1b import Schedule1F1B +from cube.runtime.schedule.schedmix import ScheduleMix \ No newline at end of file diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py index 8e056703..c8d4bdde 100644 --- a/cube/runtime/schedule/sched1f1b.py +++ b/cube/runtime/schedule/sched1f1b.py @@ -37,7 +37,7 @@ def run(segment: Callable, # forward body # recv forward # print(f'rank[{torch.distributed.get_rank()}]: line26: recving forward') inputs = Schedule1F1B.adapter_step(rfadapter, True) - inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs + inputs = Schedule1F1B.dataloader_step(dataloader) if inputs == (None,) else inputs # forward Schedule1F1B.push_tail('inputs', inputs) if recompute: @@ -55,7 +55,7 @@ def run(segment: Callable, # forward body if num_warmup_remaining > 0: # print(f'rank[{torch.distributed.get_rank()}]: line44 recv forward') inputs = Schedule1F1B.adapter_step(rfadapter, True) - inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs + inputs = Schedule1F1B.dataloader_step(dataloader) if inputs == (None,) else inputs # steady for i in range(num_warmup_remaining): @@ -87,7 +87,7 @@ def run(segment: Callable, # forward body if i != num_warmup_remaining - 1: # print(f'rank[{torch.distributed.get_rank()}]: line77 send backward recv forward') inputs = Schedule1F1B.exchange(sbadapter, rfadapter, stage_id, (False, True), *input_grads) - inputs = Schedule1F1B.dataloader_step(dataloader) if len(inputs) == 0 else inputs + inputs = Schedule1F1B.dataloader_step(dataloader) if inputs == (None,) else inputs else: # send backward # print(f'rank[{torch.distributed.get_rank()}]: line82 send backward') diff --git a/cube/runtime/schedule/schedmix.py b/cube/runtime/schedule/schedmix.py new file mode 100644 index 00000000..5ea6463f --- /dev/null +++ b/cube/runtime/schedule/schedmix.py @@ -0,0 +1,213 @@ +""" +Schedule Plan designed for Interplaced Pipeline +""" + +from typing import Callable, Iterable, List, Optional +import torch + +from cube.runtime.schedule.strategy import ScheduleABC + +def debug_msg(msg: str, ranks): + myrank = torch.distributed.get_rank() + if myrank in ranks: + print(f'rank [{myrank}]: {msg}') + + +class ScheduleMix(ScheduleABC): + """ + Emb -> Encoder -> Demb -> Decoder + + All communication will start at begining of each step and + finish at the end of step. No communication will happen cross + step, i.e., send from the previous step and recv at the next step. + """ + @staticmethod + def run(encoder_barrier: Callable, + decoder_barrier: Callable, + segment: Callable, + rfadapter: Optional[Callable], # segment adapter + sfadapter: Optional[Callable], # segment adapter + rbadapter: Optional[Callable], # segment adapter + sbadapter: Optional[Callable], # segment adapter + enc_badapter: Optional[Callable], # sharding encoder gradient input prepare adapter + dec_fadapter: Optional[Callable], # sharding decoder input prepare adapter + dec_badapter: Optional[Callable], # sharding decoder gradient input prepare adapter + dataloader: Iterable, + stage_id: int, + num_stages: int, + num_microbatch: int, + reducers: List[Callable], + recompute: bool = False): + + assert num_stages >= 4, f"Only support for stage number >= 4." + + enc_emb_stage = 0 + dec_emb_stage = num_stages // 2 + + fw_ofst = -(stage_id // 2) + bw_ofst = -(num_stages - 1 - (stage_id // 2)) + + # sharding encoder embed inputs / outputs + shard_enc_inputs, shard_enc_outputs = (None,), (None,) + shard_enc_input_grads, shard_enc_output_grads = (None,), (None,) + # sharding decoder embed inputs / outputs + shard_dec_inputs, shard_dec_outputs = (None,), (None,) + shard_dec_input_grads, shard_dec_output_grads = (None,), (None,) + # segement inputs / outputs + segment_inputs, segment_outputs = (None,), (None,) + segment_input_grads, segment_output_grads = (None,), (None,) + + for step in range(num_microbatch + num_stages - 1): + fmid, bmid = step + fw_ofst, step + bw_ofst + encoder_fw_mid = step + decoder_fw_mid = step - num_stages // 2 // 2 + encoder_bw_mid = step + 1 - num_stages // 2 * 2 + decoder_bw_mid = step + 1 - int(num_stages // 2 * 1.5) + do_forward = 0 <= fmid and fmid < num_microbatch + do_backward = 0 <= bmid and bmid < num_microbatch + + # step1: sharding encoder forward + if 0 <= encoder_fw_mid and encoder_fw_mid < num_microbatch: + data = ScheduleMix.dataloader_step(dataloader) + shard_enc_outputs = ScheduleMix.forward_step(encoder_barrier, *data) + ScheduleMix.push_tail('shard_enc_inputs', data) + ScheduleMix.push_tail('shard_enc_outputs', shard_enc_outputs) + shard_enc_outputs = tuple(t.detach().requires_grad_() for t in shard_enc_outputs) + + # step2: sharding decoder forward + if 0 <= decoder_fw_mid and decoder_fw_mid < num_microbatch: + if stage_id == dec_emb_stage - 1: + shard_dec_inputs = tuple(t.detach().requires_grad_() for t in segment_outputs) + ScheduleMix.adapter_step(dec_fadapter, True, *shard_dec_inputs) + else: + shard_dec_inputs = ScheduleMix.adapter_step(dec_fadapter, True) + shard_dec_outputs = ScheduleMix.forward_step(decoder_barrier, *shard_dec_inputs) + ScheduleMix.push_tail('shard_dec_inputs', shard_dec_inputs) + ScheduleMix.push_tail('shard_dec_outputs', shard_dec_outputs) + shard_dec_outputs = tuple(t.detach().requires_grad_() for t in shard_dec_outputs) + + # step3: forward then backward + if stage_id % 2 == 0: + + # After barrier communication: send backward recv forward =========> + if segment_input_grads != (None,): + ScheduleMix.adapter_step(sbadapter, False, *segment_input_grads) + segment_input_grads = (None,) + if do_forward: + if stage_id == enc_emb_stage: + segment_inputs = shard_enc_outputs + elif stage_id == dec_emb_stage: + segment_inputs = shard_dec_outputs + else: + segment_inputs = ScheduleMix.adapter_step(rfadapter, True) + # <=============================================================== + + segment_outputs = (None,) + if do_forward: + ScheduleMix.push_tail('segment_inputs', segment_inputs) + if recompute: + with torch.no_grad(): + segment_outputs = ScheduleMix.forward_step(segment, *segment_inputs) + ScheduleMix.push_tail('segment_outputs', None) + else: + segment_outputs = ScheduleMix.forward_step(segment, *segment_inputs) + ScheduleMix.push_tail('segment_outputs', segment_outputs) + + # recompute + if recompute: + inputs = ScheduleMix.pop_head('segment_inputs', inputs) + ScheduleMix.pop_head('segment_outputs', outputs) + outputs = ScheduleMix.forward_step(segment, *inputs) + ScheduleMix.push_head('segment_inputs', inputs) + ScheduleMix.push_head('segment_outputs', outputs) + + # Inter barrier communication: recv backward send forward ======> + if do_backward: + segment_output_grads = ScheduleMix.adapter_step(rbadapter, False) + if segment_outputs != (None,): + ScheduleMix.adapter_step(sfadapter, True, *segment_outputs) + # <=============================================================== + + segment_input_grads = (None,) + if do_backward: + inputs = ScheduleMix.pop_head('segment_inputs') + outputs = ScheduleMix.pop_head('segment_outputs') + segment_input_grads = ScheduleMix.backward_step(inputs, outputs, segment_output_grads) + + # step3: backward then forward + if stage_id % 2 == 1: + + # After barrier communication: recv backward send forward =========> + if do_backward: + if stage_id == dec_emb_stage - 1: + segment_output_grads = shard_dec_input_grads + else: + segment_output_grads = ScheduleMix.adapter_step(rbadapter, False) + if segment_outputs != (None,): + segment_input_grads = ScheduleMix.adapter_step(sfadapter, True, *segment_outputs) + # <=============================================================== + + segment_input_grads = (None,) + if do_backward: + inputs = ScheduleMix.pop_head('segment_inputs') + outputs = ScheduleMix.pop_head('segment_outputs') + segment_input_grads = ScheduleMix.backward_step(inputs, outputs, segment_output_grads) + + # Inter barrier communication: send backward recv forward ========> + if segment_input_grads != (None,): + ScheduleMix.adapter_step(sbadapter, False, *segment_input_grads) + if do_forward: + segment_inputs = ScheduleMix.adapter_step(rfadapter, True) + # <=============================================================== + + segment_outputs = (None,) + if do_forward: + ScheduleMix.push_tail('segment_inputs', segment_inputs) + if recompute: + with torch.no_grad(): + segment_outputs = ScheduleMix.forward_step(segment, *segment_inputs) + ScheduleMix.push_tail('segment_outputs', None) + else: + segment_outputs = ScheduleMix.forward_step(segment, *segment_inputs) + ScheduleMix.push_tail('segment_outputs', segment_outputs) + + # recompute + if recompute: + inputs = ScheduleMix.pop_head('segment_inputs', inputs) + ScheduleMix.pop_head('segment_outputs', outputs) + outputs = ScheduleMix.forward_step(segment, *inputs) + ScheduleMix.push_head('segment_inputs', inputs) + ScheduleMix.push_head('segment_outputs', outputs) + + # step 4: sharding decoder backward + if 0 <= decoder_bw_mid and decoder_bw_mid < num_microbatch: + if stage_id == dec_emb_stage: + assert segment_input_grads != (None,) + shard_dec_output_grads = segment_input_grads + ScheduleMix.adapter_step(dec_badapter, False, *shard_dec_output_grads) + else: + shard_dec_output_grads = ScheduleMix.adapter_step(dec_badapter, False) + + inputs = ScheduleMix.pop_head('shard_dec_inputs') + outputs = ScheduleMix.pop_head('shard_dec_outputs') + shard_dec_input_grads = ScheduleMix.backward_step( + inputs, outputs, shard_dec_output_grads) + + # step 5: sharding encoder backward + if 0 <= encoder_bw_mid and encoder_bw_mid < num_microbatch: + if stage_id == enc_emb_stage: + assert segment_input_grads != (None,) + shard_enc_output_grads = segment_input_grads + ScheduleMix.adapter_step(enc_badapter, False, *shard_enc_output_grads) + else: + shard_enc_output_grads = ScheduleMix.adapter_step(enc_badapter, False) + + inputs = ScheduleMix.pop_head('shard_enc_inputs') + outputs = ScheduleMix.pop_head('shard_enc_outputs') + shard_enc_input_grads = ScheduleMix.backward_step( + inputs, outputs, shard_enc_output_grads) + + for reducer in reducers: + reducer() + + ScheduleMix.assert_empty() diff --git a/cube/runtime/schedule/strategy.py b/cube/runtime/schedule/strategy.py index 2d4fdaec..55b3556f 100644 --- a/cube/runtime/schedule/strategy.py +++ b/cube/runtime/schedule/strategy.py @@ -53,9 +53,11 @@ def dataloader_step(dataloader: Iterable): @staticmethod def adapter_step(adapter: Callable, require_grad : bool = True, *args): """ - adapter pass + Adapter pass. + If the adapter is None, will return (None,) """ - if adapter is None: return () + if adapter is None: return (None,) + # if adapter is None: return () args = tuple(t for t in args if torch.is_tensor(t)) CudaTimer().start('adapter') outputs = adapter(*args) From cc0670f01bbbd528996d83147024f1abf49bee2c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Nov 2022 14:43:32 +0800 Subject: [PATCH 1143/1892] mbart use schedmix policy --- examples/nlp/mbart/policy/mpmd.py | 53 +++++++++++++++++++++++++++++++ examples/nlp/mbart/train.py | 33 ++++++++++++++++--- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/examples/nlp/mbart/policy/mpmd.py b/examples/nlp/mbart/policy/mpmd.py index 5651425a..f5f5b502 100644 --- a/examples/nlp/mbart/policy/mpmd.py +++ b/examples/nlp/mbart/policy/mpmd.py @@ -7,6 +7,7 @@ from cube.graph.segment import IRSegment from cube.ir.cten import IRCell from cube.graph.schedule.sched1f1b import IRSchedule1F1B +from cube.graph.schedule.schedmix import IRScheduleMix def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: @@ -249,3 +250,55 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: strategy = IRSchedule1F1B(graph, num_microbatch) graph.predef_sched(strategy) return graph + + +def PASMixPipe(graph: IRGraph, resource): + + pp_size = resource.ngpus + + blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) + enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] + dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] + + num_microbatch = 4 + + # pipelien stage + embed_sid = [0, pp_size // 2 + 1] + fstages = [[] for _ in range(pp_size)] + nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // pp_size + for lid, fnodes in enumerate(enc_layers + dec_layers): + stage_id = min(lid // nlayer_per_stage, pp_size - 1) + fstages[stage_id] += fnodes + fstages.insert(embed_sid[0], enc_emb) + fstages.insert(embed_sid[1], dec_emb) + graph.staging(tuple(stage[0] for stage in fstages)) + + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + assert len(fstages) == pp_size + 2 + + # fully shard enmbedding + enc_emb, dec_emb = fstages[embed_sid[0]], fstages[embed_sid[1]] + tp_device = list(range(resource.ngpus)) + for node in enc_emb.nodes() + dec_emb.nodes(): + # skip anchor nodes + if isinstance(node, IRGraphAnchor): continue + # shard embedding layer to all devices + if node.name == 'embedding': + _tp(graph, node, tp_device, idx=1, dim=0) + else: + _replica(graph, node, tp_device) + + dataloader = graph.select(ntype=IRDataOperation)[0] + _replica(graph, dataloader, tp_device) + + # pipeline stage to devices + pipe_stages = [stage for sid, stage in enumerate(fstages) if sid not in embed_sid] + assert len(pipe_stages) == pp_size + for sid, stage in enumerate(pipe_stages): + print(stage) + graph.assign(stage, sid) + + strategy = IRScheduleMix(graph, num_microbatch) + graph.predef_sched(strategy) + + return graph diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index 99f71d2e..09a0f25e 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -2,9 +2,9 @@ example: OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ + --nproc_per_node=8 \ --nnodes=1 \ - examples/nlp/mbart/train.py --policy PASSingle + examples/nlp/mbart/train.py --policy PASMegatron """ @@ -19,6 +19,7 @@ import examples.nlp.mbart.policy.mpmd as mpmd import argparse +import math parser = argparse.ArgumentParser(description='GPT Train') parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') @@ -28,7 +29,7 @@ parser.add_argument('--gbs', type=int, default=1, help='global batch size') parser.add_argument('--mbs', type=int, default=2, help='micro batch size') # arch -parser.add_argument('--vocab', type=int, default=256, +parser.add_argument('--vocab', type=int, default=2500, help='used vocabulary size') parser.add_argument('--layers', type=int, default=4, help='layer number of each encoder and decoder') @@ -54,6 +55,20 @@ raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") +def trunc_normal_(tensor: torch.Tensor, mean=0., std=1., a=-2., b=2.): + def norm_cdf(x): + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + # tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + def train(): @@ -64,9 +79,11 @@ def train(): Config.heads = args.heads Config.seqlen = args.seqlen - model = MBartForSentenceClassification(batch_size).cuda() + if cube.runtime.device.DeviceGroup().local_rank == 0: + model = MBartForSentenceClassification(batch_size).cuda() + else: + model = None dataloader = MBartSyntheticDataLoader(batch_size) - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) print_each_rank('model weight consumpition:') memory_summary() @@ -79,6 +96,8 @@ def train_iter(model, dataloader): loss.backward() model = model.get_gen_module() + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + for name, buffer in model.named_buffers(): torch.manual_seed(0) if name.startswith('decoder_input_ids'): @@ -87,6 +106,10 @@ def train_iter(model, dataloader): dtype=torch.int64, device=torch.cuda.current_device(), ) buffer.copy_(inputs) + + torch.manual_seed(0) + for param in model.parameters(): + trunc_normal_(param) CudaTimer(enable=False).warmup() iter_num, warmup = 5, 2 From a44fd29da8ad4616b6372569022f9b798d0dc971 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Nov 2022 15:40:12 +0800 Subject: [PATCH 1144/1892] support broadcast primitive in general communication --- cube/graph/gener/concurrent.py | 51 ++++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 50054912..bfe44598 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -1,7 +1,7 @@ """ Concurrent producer / consumer Adapter Generator """ -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Tuple import copy import numpy as np @@ -9,6 +9,7 @@ from cube.ir.adapter.prim import IRAdapterPrim from cube.ir.adapter import IRAdapter from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim +from cube.ir.adapter.prim import BroadcastPrim from cube.graph.gener.layout import GridLayout, PathFinder @@ -180,21 +181,61 @@ def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], fprims = [] fpdevs = set(t.device[0] for t in fptensors) fcomm_workload = {t.device[0]: 0 for t in fptensors} - for ctensor in fctensors: - fprims += ConcurrentGener.gen_subtensor(ctensor, fptensors, fcomm_workload) + # first try collectives + ret, prims = ConcurrentGener.gen_subtensor_coll(fctensors, fptensors, fcomm_workload) + if ret: + fprims += prims + # otherwise use general p2p send recv + else: + for ctensor in fctensors: + fprims += ConcurrentGener.gen_subtensor(ctensor, fptensors, fcomm_workload) fadapter = IRAdapter(fptensors,fctensors) fadapter.prims = fprims # backward if len(bptensors) > 0 and len(bctensors) > 0: bprims = [] bcomm_workload = {t.device[0]: 0 for t in bptensors} - for cgrad in bctensors: - bprims += ConcurrentGener.gen_subtensor(cgrad, bptensors, bcomm_workload) + # first try collectives + ret, prims = ConcurrentGener.gen_subtensor_coll(bctensors, bptensors, bcomm_workload) + if ret: + bprims += prims + # otherwise use general p2p send recv + else: + for cgrad in bctensors: + bprims += ConcurrentGener.gen_subtensor(cgrad, bptensors, bcomm_workload) badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims IRAdapter.make_pair(fadapter, badapter) return fadapter + @staticmethod + def gen_subtensor_coll(ctensors: List[IRSubTensor], ptensors: List[IRSubTensor], workload: Dict[int, int]) -> Tuple[bool, List[IRAdapterPrim]]: + """ + Generate communication primitives for a tensor using collectives of + broadcast, [reduce, gather and scatter]. => [...] Not supported yet. + + @param ctensors List[IRSubTensor]: the consumed tensors as destination. + @param ptensors List[IRSubTensor]: the produced tensors as source + + @return success bool: whether succeed in generate collective + @return prims List[IRAdapterPrim]: the primitives for adapter + """ + ret = False + prims = [] + # broadcast + if len(ptensors) == 1 and \ + len(set(ctensor.device[0] for ctensor in ctensors)) > 2 and \ + all(ptensors[0] == ctensor for ctensor in ctensors): + dev_ctensors = [] + cdevs = set() + for ctensor in ctensors: + if ctensor.device[0] not in cdevs: + cdevs.add(ctensor.device[0]) + dev_ctensors.append(ctensor) + prims.append(BroadcastPrim(ptensors, dev_ctensors)) + ret = True + return ret, prims + @staticmethod def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor], workload: Dict[int, int]) -> List[IRAdapterPrim]: """ From c0d8eb95c99816cbac8f897048be745e175f6f29 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Nov 2022 18:19:01 +0800 Subject: [PATCH 1145/1892] fix schedule with recompute --- cube/runtime/schedule/schedmix.py | 12 ++++++------ cube/runtime/schedule/strategy.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cube/runtime/schedule/schedmix.py b/cube/runtime/schedule/schedmix.py index 5ea6463f..425c5041 100644 --- a/cube/runtime/schedule/schedmix.py +++ b/cube/runtime/schedule/schedmix.py @@ -114,9 +114,9 @@ def run(encoder_barrier: Callable, ScheduleMix.push_tail('segment_outputs', segment_outputs) # recompute - if recompute: - inputs = ScheduleMix.pop_head('segment_inputs', inputs) - ScheduleMix.pop_head('segment_outputs', outputs) + if recompute and do_backward: + inputs = ScheduleMix.pop_head('segment_inputs') + ScheduleMix.pop_head('segment_outputs') outputs = ScheduleMix.forward_step(segment, *inputs) ScheduleMix.push_head('segment_inputs', inputs) ScheduleMix.push_head('segment_outputs', outputs) @@ -172,9 +172,9 @@ def run(encoder_barrier: Callable, ScheduleMix.push_tail('segment_outputs', segment_outputs) # recompute - if recompute: - inputs = ScheduleMix.pop_head('segment_inputs', inputs) - ScheduleMix.pop_head('segment_outputs', outputs) + if recompute and (0 <= bmid + 1 and bmid + 1 < num_microbatch): + inputs = ScheduleMix.pop_head('segment_inputs') + ScheduleMix.pop_head('segment_outputs') outputs = ScheduleMix.forward_step(segment, *inputs) ScheduleMix.push_head('segment_inputs', inputs) ScheduleMix.push_head('segment_outputs', outputs) diff --git a/cube/runtime/schedule/strategy.py b/cube/runtime/schedule/strategy.py index 55b3556f..3e57169f 100644 --- a/cube/runtime/schedule/strategy.py +++ b/cube/runtime/schedule/strategy.py @@ -98,7 +98,7 @@ def push_head(name: str, val: Any): @staticmethod def pop_head(name: str): assert name in ScheduleABC.status, f"{name} is empty" - out = ScheduleABC.status[name].pop(-1) + out = ScheduleABC.status[name].pop(0) if len(ScheduleABC.status[name]) == 0: del ScheduleABC.status[name] return out @@ -106,7 +106,7 @@ def pop_head(name: str): @staticmethod def pop_tail(name: str): assert name in ScheduleABC.status, f"{name} is empty" - out = ScheduleABC.status[name].pop(0) + out = ScheduleABC.status[name].pop(-1) if len(ScheduleABC.status[name]) == 0: del ScheduleABC.status return out From c59670a6f8d36e6b1b7e034c8f4fa136ea045336 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Nov 2022 21:21:47 +0800 Subject: [PATCH 1146/1892] allow multi broadcast generation --- cube/graph/gener/concurrent.py | 40 +++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index bfe44598..dfd1bf51 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -222,17 +222,37 @@ def gen_subtensor_coll(ctensors: List[IRSubTensor], ptensors: List[IRSubTensor], """ ret = False prims = [] - # broadcast - if len(ptensors) == 1 and \ - len(set(ctensor.device[0] for ctensor in ctensors)) > 2 and \ - all(ptensors[0] == ctensor for ctensor in ctensors): - dev_ctensors = [] - cdevs = set() + fuse_broadcast = True + # check broadcast + if len(ptensors) >= len(ctensors) or len(ptensors) == 0: + fuse_broadcast = False + else: + for ptensor in ptensors: + if not all(ptensor == ctensor for ctensor in ctensors): + fuse_broadcast = False + break + # fuse to broadcast + if fuse_broadcast: + cdev_tensors, pdev_tensors = dict(), dict() + for ptensor in ptensors: + pdev_tensors.setdefault(ptensor.device[0], []).append(ptensor) for ctensor in ctensors: - if ctensor.device[0] not in cdevs: - cdevs.add(ctensor.device[0]) - dev_ctensors.append(ctensor) - prims.append(BroadcastPrim(ptensors, dev_ctensors)) + # not consider self-transmission + if ctensor.device[0] in pdev_tensors: continue + cdev_tensors.setdefault(ctensor.device[0], []).append(ctensor) + if len(cdev_tensors) // len(pdev_tensors) <= 1: # can simply use send recv + return False, [] + pdevs = list(pdev_tensors.keys()) + cdevs = list(cdev_tensors.keys()) + broadcast_ndevs = len(cdevs) // len(pdevs) + start = 0 + for idx, pdev in enumerate(pdevs): + addone = 1 if idx < (len(cdevs) % len(pdevs)) else 0 + end = start + broadcast_ndevs + addone + pdev_ctensors = [cdev_tensors[devid][0] for devid in cdevs[start:end]] + pdev_ctensors += [pdev_tensors[pdev][0]] + prims.append(BroadcastPrim([pdev_tensors[pdev][0]], pdev_ctensors)) + start = end ret = True return ret, prims From bea39afc4dd8db473bb01f1f75b3f00261fee703 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Nov 2022 21:22:10 +0800 Subject: [PATCH 1147/1892] enable tp applied to schedmix --- cube/graph/schedule/schedmix.py | 99 ++++++++++++++++----------------- 1 file changed, 49 insertions(+), 50 deletions(-) diff --git a/cube/graph/schedule/schedmix.py b/cube/graph/schedule/schedmix.py index 66aee0f4..6267b5a8 100644 --- a/cube/graph/schedule/schedmix.py +++ b/cube/graph/schedule/schedmix.py @@ -95,56 +95,55 @@ def apply(self) -> IRGraph: self.dec_badapter[devid] = self.recvers[shard_dec.mirror][0] # pipeline stages for sid, stage in enumerate(pipe_stages): - assert len(stage.device) == 1 - devid = stage.device[0] - # forward body - assert devid not in self.fsegments, f"Pipeline stage cannot be overlapped" - self.fsegments[devid] = stage - # forward recv - if sid in (shard_enc_sid, shard_dec_sid): - for adapter in self.recvers[stage]: - assert all(isinstance(prim, IdentityPrim) for prim in adapter.prims), ( - f"stage {sid} got unexpected forward recv adapters: {self.recvers[stage]}" - ) - self.rfadapter[devid] = None - else: - assert len(self.recvers[stage]) == 1 - self.rfadapter[devid] = self.recvers[stage][0] - # forward send - if sid == shard_dec_sid - 1: # decoder recv broadcast - assert len(self.senders[stage]) == 1 - self.sfadapter[devid] = None - elif sid == self.num_stages - 1: - assert len(self.senders[stage]) == 0 - self.sfadapter[devid] = None - else: - assert len(self.senders[stage]) == 1 - self.sfadapter[devid] = self.senders[stage][0] - # backward recv - if sid in (shard_dec_sid - 1, self.num_stages - 1): - for adapter in self.recvers[stage.mirror]: - assert all(isinstance(prim, IdentityPrim) for prim in adapter.prims), ( - f"stage {sid} got unexpected backward recv adapters: {self.recvers[stage]}" - ) - self.rbadapter[devid] = None - else: - assert len(self.recvers[stage.mirror]) == 1, \ - f"stage {sid} got unexpected backward recv adapters: {self.recvers[stage.mirror]}" - self.rbadapter[devid] = self.recvers[stage.mirror][0] - # backward send: - if sid == shard_dec_sid: # decoder broadcast - assert len(self.senders[stage.mirror]) == 1 - self.sbadapter[devid] = None - elif sid == shard_enc_sid: # encoder broadcast - assert len(self.senders[stage.mirror]) == 1 - self.sbadapter[devid] = None - else: - self.sbadapter[devid] = self.senders[stage.mirror][0] - - # weight reducer - self.dev_reducers[devid] = [reducer for reducer in self.reducers if devid in reducer.device] - # stage id - self.stage_id[devid] = sid + for devid in stage.device: + assert devid not in self.fsegments, f"Pipeline stage cannot be overlapped" + # forward body + self.fsegments[devid] = stage + # forward recv + if sid in (shard_enc_sid, shard_dec_sid): + for adapter in self.recvers[stage]: + assert all(isinstance(prim, IdentityPrim) for prim in adapter.prims), ( + f"stage {sid} got unexpected forward recv adapters: {self.recvers[stage]}" + ) + self.rfadapter[devid] = None + else: + assert len(self.recvers[stage]) == 1 + self.rfadapter[devid] = self.recvers[stage][0] + # forward send + if sid == shard_dec_sid - 1: # decoder recv broadcast + assert len(self.senders[stage]) == 1 + self.sfadapter[devid] = None + elif sid == self.num_stages - 1: + assert len(self.senders[stage]) == 0 + self.sfadapter[devid] = None + else: + assert len(self.senders[stage]) == 1 + self.sfadapter[devid] = self.senders[stage][0] + # backward recv + if sid in (shard_dec_sid - 1, self.num_stages - 1): + for adapter in self.recvers[stage.mirror]: + assert all(isinstance(prim, IdentityPrim) for prim in adapter.prims), ( + f"stage {sid} got unexpected backward recv adapters: {self.recvers[stage]}" + ) + self.rbadapter[devid] = None + else: + assert len(self.recvers[stage.mirror]) == 1, \ + f"stage {sid} got unexpected backward recv adapters: {self.recvers[stage.mirror]}" + self.rbadapter[devid] = self.recvers[stage.mirror][0] + # backward send: + if sid == shard_dec_sid: # decoder broadcast + assert len(self.senders[stage.mirror]) == 1 + self.sbadapter[devid] = None + elif sid == shard_enc_sid: # encoder broadcast + assert len(self.senders[stage.mirror]) == 1 + self.sbadapter[devid] = None + else: + self.sbadapter[devid] = self.senders[stage.mirror][0] + + # weight reducer + self.dev_reducers[devid] = [reducer for reducer in self.reducers if devid in reducer.device] + # stage id + self.stage_id[devid] = sid return self.graph From 971f11836f344938aeed843872daa67ba4553e15 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Nov 2022 21:23:03 +0800 Subject: [PATCH 1148/1892] add tp support for schedmix --- examples/nlp/mbart/policy/mpmd.py | 14 +++++++++++--- examples/nlp/mbart/train.py | 4 ++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/nlp/mbart/policy/mpmd.py b/examples/nlp/mbart/policy/mpmd.py index f5f5b502..48299a8b 100644 --- a/examples/nlp/mbart/policy/mpmd.py +++ b/examples/nlp/mbart/policy/mpmd.py @@ -254,7 +254,8 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: def PASMixPipe(graph: IRGraph, resource): - pp_size = resource.ngpus + tp_size = 2 + pp_size = resource.ngpus // tp_size blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] @@ -295,8 +296,15 @@ def PASMixPipe(graph: IRGraph, resource): pipe_stages = [stage for sid, stage in enumerate(fstages) if sid not in embed_sid] assert len(pipe_stages) == pp_size for sid, stage in enumerate(pipe_stages): - print(stage) - graph.assign(stage, sid) + tp_devs = [idx for idx in range(tp_size * sid, tp_size * sid + tp_size)] + for node in stage.nodes(): + if len(node.inputs()) == 0: continue # anchor + if node.name == 'self_attention' or node.name == 'feedforward': + _tp(graph, node, tp_devs, idx=1, dim=0) + elif node.name == 'cross_attention': + _tp(graph, node, tp_devs, idx=2, dim=0) + else: + _replica(graph, node, tp_devs) strategy = IRScheduleMix(graph, num_microbatch) graph.predef_sched(strategy) diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index 09a0f25e..873ce256 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=1 \ - examples/nlp/mbart/train.py --policy PASMegatron + examples/nlp/mbart/train.py --policy PASMixPipe """ @@ -31,7 +31,7 @@ # arch parser.add_argument('--vocab', type=int, default=2500, help='used vocabulary size') -parser.add_argument('--layers', type=int, default=4, +parser.add_argument('--layers', type=int, default=8, help='layer number of each encoder and decoder') parser.add_argument('--heads', type=int, default=16, help='head number') From 0004082809ab82d1888b33ba6332f9fff1be3454 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 10 Nov 2022 20:22:58 +0800 Subject: [PATCH 1149/1892] add amp support --- cube/flags.py | 12 +++++++--- cube/runtime/executor.py | 48 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/cube/flags.py b/cube/flags.py index 87c36f2c..64cf8e4d 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -7,9 +7,6 @@ class CompileFlag: - # ============== runtime ==================== - dev_mode = os.environ.get('SINGLE_DEV_MODE') # allow to use python xx.py - # ============= loggings =================== log_transform = os.environ.get('LOG_TRANSFORM') log_schedule = os.environ.get('LOG_SCHEDULE') @@ -23,3 +20,12 @@ class CompileFlag: use_nnfusion = os.environ.get('USE_NNFUSION') use_jit = os.environ.get('USE_JIT') + + # ============== runtime ==================== + dev_mode = os.environ.get('SINGLE_DEV_MODE') # allow to use python xx.py + + # use automate mixture precision training, where weights, gradients + # and optimizer status are kept in its original type (can be float32), + # but the forward will be converted to float16 with torch.autocast API. + use_amp = os.environ.get('USE_AMP') + diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index ee4604a5..6a36a6b0 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -4,6 +4,20 @@ from typing import Tuple, Any, Callable, List, Dict import torch +import warnings + +from cube.flags import CompileFlag + + +if CompileFlag.use_amp: + warnings.warn( + "Detected auto mixed precision (AMP) is enabled. It's an " + "experimental feature that is only for benchmark. " + "torch.cdua.amp.GradScalerr is not enabled for loss " + "and optimizer, which may lead to gradient loss. The tensors " + "and dtypes arguments in adapter will be automatically converted to " + "torch.float16, if they are in float32 precision or torch.float32 dtype." + ) def debug_id(tensors, msg: str, rank: int): @@ -14,10 +28,27 @@ def debug_id(tensors, msg: str, rank: int): print(f'[{torch.distributed.get_rank()}] {msg}: {[id(t) for t in tensors]}') +def convert_fp32_to_fp16(t: Any): + """ + A tensor with float32 will be converted to float16. + A dtype of torch.float32 will be returned as torch.float16 + """ + if isinstance(t, torch.dtype) and t == torch.float32: + t = torch.float16 + elif torch.is_tensor(t) and t.dtype == torch.float32: + with torch.no_grad(): + t = t.half() + return t + + class Executor: _detach: Dict[str, Dict[torch.Tensor, torch.Tensor]] = dict() + # auto mixture precision loss scaler. $ TODO: support it. + _scaler = torch.cuda.amp.GradScaler(enabled=CompileFlag.use_amp) + + @staticmethod def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): """ @@ -25,7 +56,11 @@ def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires """ if not requires_grad: with torch.no_grad(): - outputs = subgraph(*input_tensors) + if CompileFlag.use_amp: + with torch.autocast('cuda', torch.float16): + outputs = subgraph(*input_tensors) + else: + outputs = subgraph(*input_tensors) else: # everytime forward a segment, detach the tensor from previous graph # debug_id(input_tensors, 'outside fexecute args', 0) @@ -38,9 +73,11 @@ def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires input_tensors = tuple( Executor._detach[name][t] if t in Executor._detach[name] else t for t in input_tensors ) - # debug_id(input_tensors, 'inside fexecute args', 0) - outputs = subgraph(*input_tensors) - # debug_id(outputs, 'fexecute result', 0) + if CompileFlag.use_amp: + with torch.autocast('cuda', torch.float16): + outputs = subgraph(*input_tensors) + else: + outputs = subgraph(*input_tensors) # print('forwarding... ') return outputs @@ -49,6 +86,9 @@ def aexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) """ execute adapter """ + if CompileFlag.use_amp: + input_tensors = tuple(convert_fp32_to_fp16(t) for t in input_tensors) + if not requires_grad: with torch.no_grad(): outputs = subgraph(*input_tensors) From d50d69cbd0fc05cfa81ce25a4b5297c9dfca3848 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 10 Nov 2022 21:28:06 +0800 Subject: [PATCH 1150/1892] fix amp bugs --- cube/flags.py | 6 +++--- cube/runtime/schedule/strategy.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/cube/flags.py b/cube/flags.py index 64cf8e4d..f11de63f 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -25,7 +25,7 @@ class CompileFlag: dev_mode = os.environ.get('SINGLE_DEV_MODE') # allow to use python xx.py # use automate mixture precision training, where weights, gradients - # and optimizer status are kept in its original type (can be float32), - # but the forward will be converted to float16 with torch.autocast API. - use_amp = os.environ.get('USE_AMP') + # and optimizer status are kept in its original data type (can be float32), + # but some of the forward operators will be converted to float16. + use_amp = True if os.environ.get('USE_AMP') else False diff --git a/cube/runtime/schedule/strategy.py b/cube/runtime/schedule/strategy.py index 3e57169f..d151565d 100644 --- a/cube/runtime/schedule/strategy.py +++ b/cube/runtime/schedule/strategy.py @@ -3,6 +3,21 @@ from cube.profiler.timer import CudaTimer +from cube.flags import CompileFlag + + +def convert_fp32_to_fp16(t: Any): + """ + A tensor with float32 will be converted to float16. + A dtype of torch.float32 will be returned as torch.float16 + """ + if isinstance(t, torch.dtype) and t == torch.float32: + t = torch.float16 + elif torch.is_tensor(t) and t.dtype == torch.float32: + with torch.no_grad(): + t = t.half() + return t + class ScheduleABC: @@ -14,7 +29,8 @@ def forward_step(segment: Callable, *args, **kwargs): forward pass """ CudaTimer().start('forward') - outputs = segment(*args, **kwargs) + with torch.autocast('cuda', torch.float16, enabled=CompileFlag.use_amp): + outputs = segment(*args, **kwargs) CudaTimer().stop('forward') if not isinstance(outputs, tuple): outputs = (outputs,) @@ -59,6 +75,8 @@ def adapter_step(adapter: Callable, require_grad : bool = True, *args): if adapter is None: return (None,) # if adapter is None: return () args = tuple(t for t in args if torch.is_tensor(t)) + if CompileFlag.use_amp: + args = tuple(convert_fp32_to_fp16(t) for t in args) CudaTimer().start('adapter') outputs = adapter(*args) CudaTimer().stop('adapter') From 60efb0de34b05693b0495ed1d6ce2e30f51249c8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 10 Nov 2022 21:29:14 +0800 Subject: [PATCH 1151/1892] update nlp cases --- examples/nlp/gpt/policy/mpmd.py | 3 +-- examples/nlp/gpt/policy/spmd.py | 26 ++++++++++++++++++++++++-- examples/nlp/gpt/train.py | 13 ++++++------- examples/nlp/mbart/train.py | 6 ++---- 4 files changed, 33 insertions(+), 15 deletions(-) diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index d8fb09ee..b9927cd1 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -174,7 +174,6 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: for pp_idx, fstage in enumerate(fstages): for fnode in fstage.nodes(): if len(fnode.inputs()) == 0: continue # anchor - # tensor parallel -- FIXME: current restriction needs replica happen before partition if fnode.name == 'self_attention' or fnode.name == 'feedforward': fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) elif fnode.name == 'embedding': @@ -193,7 +192,7 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: strategy = IRSchedule1F1B(graph, num_microbatch) graph.predef_sched(strategy) - print(graph.extra_repr()) + # print(graph.extra_repr()) return graph diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index f5e942bb..6ac1bc34 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -9,8 +9,8 @@ # tensor parallelism def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], - idx: int, dim: int): - algo = node.algorithms('dim') + idx: int, dim: int, tag='dim'): + algo = node.algorithms(tag) sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) assert sub_nodes is not None for devid, sub_node in zip(devs, sub_nodes): @@ -50,6 +50,28 @@ def PASSingle(graph: IRGraph, resource): return graph +def PASDP(graph: IRGraph, resource): + dp_size = resource.ngpus + dp_devs = list(range(dp_size)) + + dataloader = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[0] + + # partition dataloader + dls = graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) + for devid, dl in enumerate(dls): + graph.assign(dl, devid) + + # partition forward operators + for node in graph.select(ntype=IRFwOperation): + if len(node.inputs()) == 0: continue + #FIXME: a workaround to find batch dimension + batch_dim = node.input(0).shape.index(bs) + _tp(graph, node, dp_devs, idx=0, dim=batch_dim) + + return graph + + def PASMegatronTP(graph: IRGraph, resource): tp_size = resource.ngpus tp_devs = list(range(tp_size)) diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 232bcc66..60cd768a 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMegatronTP --fp16 + examples/nlp/gpt/train.py --policy PASMegatron --fp16 """ @@ -48,14 +48,14 @@ def train(): - batch_size = 1 + batch_size = 2 model = GPT() model = model if not args.fp16 else model.half() dataloader = GPTDataLoader(batch_size) - model = cube.SemanticModel(model, dataloader.shapes) - @cube.compile(model, dataloader, PAS=PAS, override=True) + model = cube.SemanticModel(model) + @cube.compile(model, dataloader, PAS=PAS, override=True, load_content=True) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) @@ -69,8 +69,7 @@ def train_iter(model, dataloader): memory_summary() CudaTimer(enable=False).warmup() - iter_num = 40 - warmup = 8 + iter_num, warmup = 5, 2 for step in range(iter_num): if step == warmup: CudaTimer(enable=True).start('e2e') @@ -81,7 +80,7 @@ def train_iter(model, dataloader): if step == 0: print_each_rank('passed first iteration') - if (step + 1) % 10 == 0: + if (step + 1) % 2 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) CudaTimer().stop('e2e') diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index 873ce256..04ee0e3c 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -80,14 +80,12 @@ def train(): Config.seqlen = args.seqlen if cube.runtime.device.DeviceGroup().local_rank == 0: - model = MBartForSentenceClassification(batch_size).cuda() + model = MBartForSentenceClassification(batch_size) + model = model.half() if args.fp16 else model else: model = None dataloader = MBartSyntheticDataLoader(batch_size) - print_each_rank('model weight consumpition:') - memory_summary() - model = cube.SemanticModel(model) @cube.compile(model, dataloader, PAS=PAS, override=True, load_content=False) def train_iter(model, dataloader): From 7bc8443e1686f3b573a8c7663b622753196f2d73 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 10 Nov 2022 21:29:49 +0800 Subject: [PATCH 1152/1892] fix amp bugs --- cube/codegen/codegen.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index 18a3dca4..ac71953d 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -250,6 +250,9 @@ def kwargs_naming(self, **kwargs) -> str: """ names = [] for name, val in kwargs.items(): + # TODO: Ad-hoc patch for amp + if CompileFlag.use_amp and val == 'torch.float32': + val = 'torch.float16' names.append(f'{name}={val}') name = ', '.join(names) return name From 2b03d359ffc34a1ae48bc54349919155161156bd Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Fri, 11 Nov 2022 14:42:36 +0800 Subject: [PATCH 1153/1892] fix memory profile bugs --- cube/profiler/database.py | 39 +++++++++++++++++---- examples/alphafold2/README.md | 2 +- examples/alphafold2/alphafold2.py | 43 ++++++++++++++++++----- examples/alphafold2/module.py | 58 +++++++++++++++++-------------- 4 files changed, 98 insertions(+), 44 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index b8b9d1a9..b9e06d04 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -90,14 +90,16 @@ def run_step(func, tensors, kwargs, backward: bool): infer_memory = mtoc - mtic train_memory = 0 + used_tensor = set() # ref torch/utils/checkpoint.py/_checkpoint_without_reentrant def pack_hook(x): - nonlocal train_memory - byte_size = 1 - for dim in list(x.size()): - byte_size = byte_size * dim - byte_size = byte_size * x.element_size() - train_memory = train_memory + byte_size + nonlocal train_memory, used_tensor + if x.storage().data_ptr() not in used_tensor: + used_tensor.add(x.storage().data_ptr()) + byte_size = x.element_size() + for dim in list(x.size()): + byte_size = byte_size * dim + train_memory = train_memory + byte_size return x def unpack_hook(x): @@ -106,7 +108,7 @@ def unpack_hook(x): torch.cuda.synchronize() torch.cuda.empty_cache() with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): - outs = run_step(func, tensors, train_kwargs, backward=False) + outs = run_step(func, tensors, train_kwargs, backward=True) # warmup tic = time.time() @@ -153,8 +155,31 @@ def get_func(node: IRFwOperation) -> Tuple[Callable, Shapes, DTypes, Dict]: Get function call and its arguments from a cude IRGraph node """ assert isinstance(node, IRFwOperation), f"Only support profiling forward operation but got {type(node)}" + + def get_dep_names(sign: str): + ret = [] + code_impl = Sign2Op.kOpCodeDef[sign] + for code_line in code_impl.split('\n'): + idx = code_line.find('# call: ') + if idx != -1: + dep_name = code_line[idx + 8:] + assert dep_name in Sign2Op.kOpCodeDef, dep_name + ret = ret + get_dep_names(dep_name) + ret.append(dep_name) + return ret + if node.signature in Sign2Op.kOpCodeDef: + dep_code_impl = '' + for dep_name in get_dep_names(node.signature): + dep_code_impl = dep_code_impl + Sign2Op.kOpCodeDef[dep_name] code_impl: str = Sign2Op.kOpCodeDef[node.signature] + def_end = code_impl.find(':\n') + assert def_end >= 0 + prev_code_lines = code_impl[:def_end+2] + succ_code_lines = code_impl[def_end+2:] + for line in dep_code_impl.split('\n'): + prev_code_lines = prev_code_lines + ' ' + line + '\n' + code_impl = prev_code_lines + succ_code_lines local = {} exec(code_impl, globals(), local) fn = list(local.values())[0] diff --git a/examples/alphafold2/README.md b/examples/alphafold2/README.md index f6857751..bb9b4cc2 100644 --- a/examples/alphafold2/README.md +++ b/examples/alphafold2/README.md @@ -109,7 +109,7 @@ TODO - bs, s, r, cm, cz = 1, 128, 256, 256, 128 - bs, s, r, cm, cz = 1, 512, 256, 256, 128 - bs, s, r, cm, cz = 1, 512, 384, 256, 128 - - other config: dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 4, False, True, False + - other config: dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False - policy - spmd.PASSingle - spmd.PASDAP diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 060a607a..26ebabd7 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -8,8 +8,10 @@ from examples.alphafold2.model import * import examples.alphafold2.policy.spmd as spmd -cube.init() - +from cube.ir.operator import IRFwOperation, IRBpOperation +from cube.profiler.database import ProfileDataBase +from cube.algorithm.ops.dimops import gen_partitions +from cube.graph.function.anchor import IRGraphAnchor def run(size_config, other_config, policy): bs, s, r, cm, cz = size_config @@ -72,16 +74,37 @@ def train_iter(model, dataloader): int(torch.cuda.max_memory_allocated() / 1024 / 1024))) +def profile(graph, resource): + db = ProfileDataBase() + mem_sum = 0 + for node in graph.select(ntype=IRFwOperation): + if isinstance(node, IRGraphAnchor): + continue + partition_nodes = gen_partitions(node, 1) + for partition_node in partition_nodes: + in_mem, param_mem, fw_span, bw_span, infer_mem, train_mem = db.profile(partition_node) + mem_sum = mem_sum + train_mem + print(node.signature, train_mem) + db.dump('db.json', override=True) + print('estimated train mem: ', mem_sum / 1024 / 1024 / 1024) + + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + + return graph + def test_main(): # Training && Evoformer Stack # initial training - # bs, s, r, cm, cz = 1, 128, 256, 256, 128 + bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning # bs, s, r, cm, cz = 1, 512, 384, 256, 128 - # dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False + dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 3, False, True, False + policy = profile # policy = spmd.PASDAP # Training && Extra Sequence @@ -95,13 +118,15 @@ def test_main(): # policy = spmd.PASExtraSingle # Inference - bs, s, r, cm, cz = 1, 128, 2048, 256, 128 - dtype, evo_num, use_chunk, is_train, is_extra = torch.float32, 48, True, False, False - policy = spmd.PASSingleInference - policy = spmd.PASDAPInference + # bs, s, r, cm, cz = 1, 128, 2048, 256, 128 + # dtype, evo_num, use_chunk, is_train, is_extra = torch.float32, 48, True, False, False + # policy = spmd.PASSingleInference + # policy = spmd.PASDAPInference run((bs, s, r, cm, cz), (dtype, evo_num, use_chunk, is_train, is_extra), policy) -test_main() +if __name__ == '__main__': + cube.init() + test_main() diff --git a/examples/alphafold2/module.py b/examples/alphafold2/module.py index 80af9bf2..09de7d1f 100644 --- a/examples/alphafold2/module.py +++ b/examples/alphafold2/module.py @@ -3,7 +3,7 @@ import torch.utils.checkpoint as ckpt -@cube.graph.parser.register('TODO', name='calc_qkvg') +@cube.graph.parser.register('*, *, * -> *, *, *, *', name='calc_qkvg') def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, bs: int, s: int, r: int, head: int, c: int): gate = torch.sigmoid(torch.matmul(x, gate_proj)) @@ -23,7 +23,7 @@ def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, """ -@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R M', +@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R^ M^', name='MSAAttention') @torch.jit.ignore def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, @@ -91,7 +91,7 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return out -@cube.graph.parser.register('N S R M, M E, M F, E M, N 1 8 R R -> N S R M', +@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^, N 1^ 8^ R^ R^ -> N S R^ M^', name='MSAAttentionWithBias') @torch.jit.ignore def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, @@ -177,7 +177,7 @@ def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # note: code not reused constrained by cube's interface -@cube.graph.parser.register('N S R M, N R R Z, M E, M F, E M, Z H -> N S R M', +@cube.graph.parser.register('N S R^ M^, N R^ R^ Z^, M^ E^, M^ F^, E^ M^, Z^ H^ -> N S R^ M^', name='MSARowAttentionWithPairBias') def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, pair_repr: torch.Tensor, @@ -185,6 +185,7 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias_proj: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): + # call: MSAAttentionWithBias bs, s, r, cm = msa_repr.size() bias = torch.matmul(pair_repr, @@ -195,17 +196,18 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, head, c, scale, chunk_size, is_train) -@cube.graph.parser.register('N S R M, M E, M F, E M -> N S R M', +@cube.graph.parser.register('N S^ R M^, M^ E^, M^ F^, E^ M^ -> N S^ R M^', name='MSAColAttention') def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): + # call: MSAAttention return MSAAttention(msa_repr.permute(0, 2, 1, 3), gate_proj, qkv_proj, out_proj, head, c, scale, chunk_size, is_train).permute(0, 2, 1, 3) -@cube.graph.parser.register('N S R M, M M, M E, M E, M M, M M -> N S R M', +@cube.graph.parser.register('N S^ R^ M^, M^ M^, M^ E^, M^ E^, M^ M^, M^ M^ -> N S^ R^ M^', name='MSAColGlobalAttention') def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor, @@ -248,7 +250,7 @@ def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, """ -@cube.graph.parser.register('N S R M, M E, E M -> N S R M', +@cube.graph.parser.register('N S R M^, M^ E^, E^ M^ -> N S R M^', name='MSATransition') def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -256,12 +258,12 @@ def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, torch.nn.functional.relu(torch.matmul(msa_repr, proj1)), proj2) -@cube.graph.parser.register('N S R M, M C -> N S R C', name='OPMLeftProj') +@cube.graph.parser.register('N S R M^, M^ C^ -> N S R C^', name='OPMLeftProj') def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): return torch.matmul(msa_repr, proj) -@cube.graph.parser.register('N S R M, M C -> N S R C', name='OPMRightProj') +@cube.graph.parser.register('N S R M^, M^ C^ -> N S R C^', name='OPMRightProj') def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): return torch.matmul(msa_repr, proj) @@ -271,7 +273,7 @@ def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): """ -@cube.graph.parser.register('N S R M, N S T M, F Z -> N R T Z', +@cube.graph.parser.register('N S^ R M^, N S^ T^ M^, F^ Z^ -> N R^ T Z^', name='OuterProductMean') @torch.jit.ignore def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, @@ -306,29 +308,29 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): return outer -@cube.graph.parser.register('N S R Z, Z E, Z E -> N S R E', name='TMOLeftProj') +@cube.graph.parser.register('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMOLeftProj') def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - a = a * torch.matmul(pair_repr, proj2) - return a + b = a * torch.matmul(pair_repr, proj2) + return b -@cube.graph.parser.register('N S R Z, Z E, Z E -> N S R E', +@cube.graph.parser.register('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMORightProj') def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - a = a * torch.matmul(pair_repr, proj2) - return a + b = a * torch.matmul(pair_repr, proj2) + return b -@cube.graph.parser.register('N S T Z, Z Z -> N S T Z', name='TMOGate') +@cube.graph.parser.register('N S T^ Z^, Z^ Z^ -> N S T^ Z^', name='TMOGate') def TMOGate(pair_repr: torch.Tensor, proj: torch.Tensor): return torch.sigmoid(torch.matmul(pair_repr, proj)) -@cube.graph.parser.register('N S R E, N T R E, N S T Z, E, E, E Z -> N S T Z', +@cube.graph.parser.register('N S R^ E^, N T^ R^ E^, N S T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='TriangleMultiplicationOut') def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, @@ -345,7 +347,7 @@ def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, return p * g -@cube.graph.parser.register('N R S Z, Z E, Z E -> N R S E', name='TMILeftProj') +@cube.graph.parser.register('N R^ S Z^, Z^ E^, Z^ E^ -> N R^ S E^', name='TMILeftProj') def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): a = torch.sigmoid(torch.matmul(pair_repr, proj1)) @@ -353,7 +355,7 @@ def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return a -@cube.graph.parser.register('N R T Z, Z E, Z E -> N R T E', +@cube.graph.parser.register('N R^ T Z^, Z^ E^, Z^ E^ -> N R^ T E^', name='TMIRightProj') def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -362,12 +364,12 @@ def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return a -@cube.graph.parser.register('N S T Z, Z Z -> N S T Z', name='TMIGate') +@cube.graph.parser.register('N S^ T Z^, Z^ Z^ -> N S^ T Z^', name='TMIGate') def TMIGate(pair_repr: torch.Tensor, proj: torch.Tensor): return torch.sigmoid(torch.matmul(pair_repr, proj)) -@cube.graph.parser.register('N R S E, N R T E, N T S Z, E, E, E Z -> N T S Z', +@cube.graph.parser.register('N R^ S E^, N R^ T^ E^, N T^ S Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='TriangleMultiplicationIn') def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, tri_mul_norm2_weight: torch.Tensor, @@ -383,35 +385,37 @@ def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, return p.permute(0, 2, 1, 3) * g -@cube.graph.parser.register('N S R C, C D -> N S R D', name='TANSBias') +@cube.graph.parser.register('N S R^ C^, C^ D^ -> N S R^ D^', name='TANSBias') def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): return torch.matmul(pair_repr, bias_proj) -@cube.graph.parser.register('N S R Z, Z E, Z F, E Z, N T R G -> N S R Z', +@cube.graph.parser.register('N S R^ Z^, Z^ E^, Z^ F^, E^ Z^, N T^ R^ G^ -> N S R^ Z^', name='TriangleAttentionNodeStart') def TriangleAttentionNodeStart(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): + # call: MSAAttentionWithBias bias = bias.permute(0, 3, 1, 2).unsqueeze(1) return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, head, c, scale, chunk_size, is_train) -@cube.graph.parser.register('N S R C, C D -> N S R D', name='TANEBias') +@cube.graph.parser.register('N S^ R C^, C^ D^ -> N S^ R D^', name='TANEBias') def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): return torch.matmul(pair_repr, bias_proj) -@cube.graph.parser.register('N R S Z, Z E, Z F, E Z, N R T G -> N R S Z', +@cube.graph.parser.register('N R^ S Z^, Z^ E^, Z^ F^, E^ Z^, N R^ T^ G^ -> N R^ S Z^', name='TriangleAttentionNodeEnd') def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): + # call: TriangleAttentionNodeStart pair_repr = pair_repr.permute(0, 2, 1, 3) bias = bias.permute(0, 2, 1, 3) out = TriangleAttentionNodeStart(pair_repr, gate_proj, qkv_proj, out_proj, @@ -420,7 +424,7 @@ def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, return out.permute(0, 2, 1, 3) -@cube.graph.parser.register('N R T Z, Z E, E Z -> N R T Z', +@cube.graph.parser.register('N R T^ Z^, Z^ E^, E^ Z^ -> N R T^ Z^', name='PairTransition') def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): From 2510c2165c18ecbc49f1f56f4e620360c369aa55 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Fri, 11 Nov 2022 14:53:19 +0800 Subject: [PATCH 1154/1892] refine for merge --- examples/alphafold2/alphafold2.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 26ebabd7..cf239d2a 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -74,26 +74,6 @@ def train_iter(model, dataloader): int(torch.cuda.max_memory_allocated() / 1024 / 1024))) -def profile(graph, resource): - db = ProfileDataBase() - mem_sum = 0 - for node in graph.select(ntype=IRFwOperation): - if isinstance(node, IRGraphAnchor): - continue - partition_nodes = gen_partitions(node, 1) - for partition_node in partition_nodes: - in_mem, param_mem, fw_span, bw_span, infer_mem, train_mem = db.profile(partition_node) - mem_sum = mem_sum + train_mem - print(node.signature, train_mem) - db.dump('db.json', override=True) - print('estimated train mem: ', mem_sum / 1024 / 1024 / 1024) - - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - - return graph - def test_main(): # Training && Evoformer Stack # initial training @@ -103,9 +83,8 @@ def test_main(): # second fine-tuning # bs, s, r, cm, cz = 1, 512, 384, 256, 128 - dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 3, False, True, False - policy = profile - # policy = spmd.PASDAP + dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False + policy = spmd.PASDAP # Training && Extra Sequence # initial training From 4017196a96212f327c88ea87e5c535ce70483d3a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 15 Nov 2022 20:46:40 +0800 Subject: [PATCH 1155/1892] allow different policies for communication generation --- cube/flags.py | 3 +++ cube/graph/gener/concurrent.py | 21 +++++++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/cube/flags.py b/cube/flags.py index f11de63f..ea4f8728 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -15,6 +15,9 @@ class CompileFlag: # ================ compiling ======================== # worker sleep in seconds worker_sleep = int(os.environ.get('WORKER_SLEEP')) if os.environ.get('WORKER_SLEEP') is not None else 0 + disable_intra_rvd = os.environ.get('DISABLE_INTRA_RVD') + disable_inter_rvd = os.environ.get('DISABLE_INTRA_RVD') + disable_comm_fusion = os.environ.get('DISABLE_COMM_FUSION') # ============ code generation =============== use_nnfusion = os.environ.get('USE_NNFUSION') diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index dfd1bf51..d6174634 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -12,6 +12,16 @@ from cube.ir.adapter.prim import BroadcastPrim from cube.graph.gener.layout import GridLayout, PathFinder +from cube.flags import CompileFlag + +import warnings + +if CompileFlag.disable_intra_rvd: + warnings.warn('Detected disabling intra-RVD collective generation, which may have big impact on performance.') +if CompileFlag.disable_inter_rvd: + warnings.warn('Detected disabling inter-RVD collective generation, which may have big impact on performance.') +if CompileFlag.disable_inter_rvd: + warnings.warn('Detected disabling general communication fusion, which may have big impact on performance in certain cases.') class ConcurrentGener: @@ -37,7 +47,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # case 1: sharing device (in-shard) inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) - if inshard and len(pdevs) > 1: + if (not CompileFlag.disable_intra_rvd) and inshard and len(pdevs) > 1: # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) try: fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) @@ -50,7 +60,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], ) # Case 2: sperating device (cross-shard) - if len(set(pdevs).intersection(cdevs)) == 0: + if (not CompileFlag.disable_inter_rvd) and len(set(pdevs).intersection(cdevs)) == 0: # fadapter = ConcurrentGener.gen_cross_shard(fptensors, fctensors, bptensors, bctensors) try: fadapter = ConcurrentGener.gen_cross_shard(fptensors, fctensors, bptensors, bctensors) @@ -182,7 +192,9 @@ def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], fpdevs = set(t.device[0] for t in fptensors) fcomm_workload = {t.device[0]: 0 for t in fptensors} # first try collectives - ret, prims = ConcurrentGener.gen_subtensor_coll(fctensors, fptensors, fcomm_workload) + ret = False + if not CompileFlag.disable_comm_fusion: + ret, prims = ConcurrentGener.gen_subtensor_coll(fctensors, fptensors, fcomm_workload) if ret: fprims += prims # otherwise use general p2p send recv @@ -196,7 +208,8 @@ def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], bprims = [] bcomm_workload = {t.device[0]: 0 for t in bptensors} # first try collectives - ret, prims = ConcurrentGener.gen_subtensor_coll(bctensors, bptensors, bcomm_workload) + if not CompileFlag.disable_comm_fusion: + ret, prims = ConcurrentGener.gen_subtensor_coll(bctensors, bptensors, bcomm_workload) if ret: bprims += prims # otherwise use general p2p send recv From 66c4093ba3e71cf1b3bbb33ab2e4bcb51f793c30 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Wed, 16 Nov 2022 05:33:49 +0000 Subject: [PATCH 1156/1892] change gpt config interfaces --- examples/nlp/gpt/model.py | 89 +++++++++++++++------------------------ 1 file changed, 33 insertions(+), 56 deletions(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index a8df651d..5525d5c9 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -3,70 +3,47 @@ from examples.nlp.blocks.encoder import EncoderLayer, EncoderInferLayer import cube +from dataclasses import dataclass +@dataclass class Config: - - num_embeddings = 50432 - seqlen = 1024 - - # toy model - embed_dim = 1024 - layers = 8 # 96 - attention_heads = 16 - - # # 1 layer of 175B model - # embed_dim = 12288 - # layers = 1 # 96 - # attention_heads = 96 - # - # # 350 M model (Medium)* - # embed_dim = 1024 - # layers = 24 - # attention_heads = 16 - - # 1.3 B model - # embed_dim = 2048 - # layers = 24 - # attention_heads = 32 - - # 2.6 B model - # embed_dim = 2560 - # layers = 32 - # attention_heads = 32 - - # 6.7 B model - # embed_dim = 4096 - # layers = 32 - # attention_heads = 32 - - # 15 B model - # embed_dim = 5120 - # layers = 48 - # attention_heads = 36 - - # 39 B model - # embed_dim = 8192 - # layers = 48 - # attention_heads = 64 - - # 175 B model* - # embed_dim = 12288 - # layers = 96 - # attention_heads = 96 - - attn_hidden_dim = embed_dim - ffn_hidden_dim = embed_dim * 4 - dropout = 0.2 - attn_dropout = 0.2 - activation_dropout = 0.2 + embed_dim: int = 1024 + layers: int = 8 + attention_heads: int = 16 + attn_hidden_dim: int = 1024 + ffn_hidden_dim: int = 4096 + num_embeddings: int = 50432 + seqlen: int = 1024 + dropout: float = 0.2 + attn_dropout: float = 0.2 + activation_dropout: float = 0.2 + + +def build_gpt_config(name: str) -> Config: + if name == '350M': + embed_dim, layers, attention_heads = 1024, 24, 16 + elif name == '1.3B': + embed_dim, layers, attention_heads = 2048, 24, 32 + elif name == '2.6B': + embed_dim, layers, attention_heads = 2560, 32, 32 + elif name == '6.7B': + embed_dim, layers, attention_heads = 4096, 32, 32 + elif name == '13B': + embed_dim, layers, attention_heads = 5120, 48, 40 + elif name == '39B': + embed_dim, layers, attention_heads = 8192, 48, 64 + elif name == '175B': + embed_dim, layers, attention_heads = 12288, 96, 96 + else: + assert False, f'unrecognized name: {name}' + return Config(embed_dim, layers, attention_heads, embed_dim, 4 * embed_dim) class GPT(torch.nn.Module): - def __init__(self): + def __init__(self, cfg=Config()): super().__init__() - cfg = Config() # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.embed_dim)) From c1ca4f767de9b96473ffa12062244e18ca0da462 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 16 Nov 2022 16:05:28 +0800 Subject: [PATCH 1157/1892] fix flag bugs --- cube/flags.py | 2 +- cube/graph/gener/concurrent.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/flags.py b/cube/flags.py index ea4f8728..c925fa19 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -16,7 +16,7 @@ class CompileFlag: # worker sleep in seconds worker_sleep = int(os.environ.get('WORKER_SLEEP')) if os.environ.get('WORKER_SLEEP') is not None else 0 disable_intra_rvd = os.environ.get('DISABLE_INTRA_RVD') - disable_inter_rvd = os.environ.get('DISABLE_INTRA_RVD') + disable_inter_rvd = os.environ.get('DISABLE_INTER_RVD') disable_comm_fusion = os.environ.get('DISABLE_COMM_FUSION') # ============ code generation =============== diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index d6174634..fda1f7eb 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -20,7 +20,7 @@ warnings.warn('Detected disabling intra-RVD collective generation, which may have big impact on performance.') if CompileFlag.disable_inter_rvd: warnings.warn('Detected disabling inter-RVD collective generation, which may have big impact on performance.') -if CompileFlag.disable_inter_rvd: +if CompileFlag.disable_comm_fusion: warnings.warn('Detected disabling general communication fusion, which may have big impact on performance in certain cases.') From 4c424b94d44b91aec7e873e6fa5f110a8ad24f89 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 16 Nov 2022 16:21:27 +0800 Subject: [PATCH 1158/1892] fix dataloader --- examples/nlp/gpt/model.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 5525d5c9..915e572d 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -85,10 +85,8 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): class GPTInfer(torch.nn.Module): - def __init__(self, batch_size: int = 1): + def __init__(self, batch_size: int = 1, cfg: Config = Config()): super().__init__() - cfg = Config() - # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) self.embedw = torch.nn.Parameter(torch.rand(cfg.num_embeddings, cfg.embed_dim) / 128) self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) @@ -133,10 +131,10 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): class GPTDataLoader(cube.runtime.syndata.CubeDataLoader): - def __init__(self, batch_size: int): + def __init__(self, batch_size: int, cfg: Config = Config()): self.bs = batch_size - self.cfg = Config() + self.cfg = cfg super().__init__( shapes=([batch_size, self.cfg.seqlen], [batch_size, self.cfg.seqlen], @@ -163,12 +161,13 @@ def __iter__(self): def __next__(self): return self.samples[0] + class GPTInferDataLoader(cube.runtime.syndata.CubeDataLoader): - def __init__(self, batch_size: int): + def __init__(self, batch_size: int, cfg: Config = Config()): self.bs = batch_size - self.cfg = Config() + self.cfg = cfg super().__init__( shapes=([batch_size, 1], [batch_size, 1], From 69cce9ece5862a5d23ae0d03140bbfff6e6f64d0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 21 Nov 2022 10:13:35 +0800 Subject: [PATCH 1159/1892] add timer to breakdown compiler --- cube/compiler.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index d38c69f9..a571a154 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -102,6 +102,7 @@ def decorator(fn: Callable) -> Callable: resource = cube.runtime.resource.EnvResource() # run once to get model structure and tensor shape + start = time.time() outputs = fn(model, ir_dataloader) Program().finalize() if outputs is None: @@ -110,11 +111,16 @@ def decorator(fn: Callable) -> Callable: outputs = [outputs] # setup program output Program().set_output(outputs) + span = time.time() - start + print('> finish parsing iteration: {:.2f} s'.format(span)) # run policy + start = time.time() graph = Program().get_graph() assert callable(PAS), f"Policy PAS is not callable" graph = PAS(graph, resource) + span = time.time() - start + print('> finish policy expression: {:.2f} s'.format(span)) if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") @@ -138,16 +144,19 @@ def decorator(fn: Callable) -> Callable: if CompileFlag.log_schedule: print(graph.sched) span = time.time() - start - print('> planpass on applying schedule strategy: {:.2f} s'.format(span)) + print('> finish planpass on applying schedule strategy: {:.2f} s'.format(span)) # to execution plan + start = time.time() execplan = ExecutionPlan(graph) + span = time.time() - start + print('> finish lowering to execution plan: {:.2f} s'.format(span)) # plan pass for communication optimization start = time.time() execplan = DiffFusion.apply(execplan) span = time.time() - start - print('> planpass on diff-fusion operations: {:.2f} s'.format(span)) + print('> finish planpass on diff-fusion operations: {:.2f} s'.format(span)) # execplan.visualize(outfile='plan.png') @@ -156,11 +165,12 @@ def decorator(fn: Callable) -> Callable: start = time.time() execplan = Grouping.apply(execplan) span = time.time() - start - print('> planpass on grouping operations: {:.2f} s'.format(span)) + print('> finish planpass on grouping operations: {:.2f} s'.format(span)) # execplan.graph.reset_dependency() # execplan.analyze(outfile='execplan.png') + start = time.time() local_world_size = DeviceGroup().local_world_size # code generation mgener = ModelCodeGen(execplan) @@ -176,6 +186,9 @@ def decorator(fn: Callable) -> Callable: outfile = fname, attach=True ) + span = time.time() - start + print('> finish generating code: {:.2f} seconds'.format(span)) + compile_end = time.time() compile_time = compile_end - compile_start print('> compile time: {:.2f} seconds'.format(compile_time)) From f52be1f6a1ac705f2b44a76000eb80154b8bc7a5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 21 Nov 2022 14:38:28 +0800 Subject: [PATCH 1160/1892] increase batch size fit for data parallel --- examples/nlp/gpt/policy/spmd.py | 27 +++++++++++++++++++++++++-- examples/nlp/gpt/train.py | 2 +- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index 6ac1bc34..25d51c37 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -81,8 +81,7 @@ def PASMegatronTP(graph: IRGraph, resource): anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] indices = [fnodes.index(anchor) for anchor in anchors] for lid, idx in enumerate(indices): - # why -1: multiref - fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + fnodes[idx+1].comment = f'===> start of transformer layer {lid}' # attention attns = [node for node in fnodes if node.name == 'self_attention'] @@ -108,6 +107,30 @@ def PASMegatronTP(graph: IRGraph, resource): assert len(sums) == 1 _tp(graph, sums[0], tp_devs, idx=0, dim=2) + # partition add + adds = [node for node in fnodes if node.name == 'add'] + for add in adds: + # subnodes = _replica(graph, add, [0] * 2) + # for idx, sub_node in enumerate(subnodes): + # _tp(graph, sub_node, [0,1] if idx == 0 else [2,3], idx=0, dim=1) + # _tp(graph, add, tp_devs, idx=0, dim=1) + subnodes = _tp(graph, add, [0] * 2, idx=0, dim=1) + for idx, sub_node in enumerate(subnodes): + _replica(graph, sub_node, [0,1] if idx == 0 else [2,3]) + + # partition layernorm + lns = [node for node in fnodes if node.name == 'layernorm'] + assert len(lns) > 0 + for ln in lns: + # _tp(graph, ln, tp_devs, idx=0, dim=1) + # subnodes = _replica(graph, ln, [0] * 2) + # for idx, sub_node in enumerate(subnodes): + # _tp(graph, sub_node, [0,1] if idx == 0 else [2,3], idx=0, dim=1) + subnodes = _tp(graph, ln, [0] * 2, idx=0, dim=1) + for idx, sub_node in enumerate(subnodes): + _replica(graph, sub_node, [0,1] if idx == 0 else [2,3]) + + # replicate other nodes for node in graph.nodes(): if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 60cd768a..22bf7bdf 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -48,7 +48,7 @@ def train(): - batch_size = 2 + batch_size = 4 model = GPT() model = model if not args.fp16 else model.half() From 7fa3bd55e308eae8e904e195675839f142ad5b6f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 21 Nov 2022 14:40:21 +0800 Subject: [PATCH 1161/1892] layernorm and add use replica --- examples/nlp/gpt/policy/spmd.py | 42 ++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index 25d51c37..9838de62 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -108,27 +108,27 @@ def PASMegatronTP(graph: IRGraph, resource): _tp(graph, sums[0], tp_devs, idx=0, dim=2) # partition add - adds = [node for node in fnodes if node.name == 'add'] - for add in adds: - # subnodes = _replica(graph, add, [0] * 2) - # for idx, sub_node in enumerate(subnodes): - # _tp(graph, sub_node, [0,1] if idx == 0 else [2,3], idx=0, dim=1) - # _tp(graph, add, tp_devs, idx=0, dim=1) - subnodes = _tp(graph, add, [0] * 2, idx=0, dim=1) - for idx, sub_node in enumerate(subnodes): - _replica(graph, sub_node, [0,1] if idx == 0 else [2,3]) - - # partition layernorm - lns = [node for node in fnodes if node.name == 'layernorm'] - assert len(lns) > 0 - for ln in lns: - # _tp(graph, ln, tp_devs, idx=0, dim=1) - # subnodes = _replica(graph, ln, [0] * 2) - # for idx, sub_node in enumerate(subnodes): - # _tp(graph, sub_node, [0,1] if idx == 0 else [2,3], idx=0, dim=1) - subnodes = _tp(graph, ln, [0] * 2, idx=0, dim=1) - for idx, sub_node in enumerate(subnodes): - _replica(graph, sub_node, [0,1] if idx == 0 else [2,3]) + # adds = [node for node in fnodes if node.name == 'add'] + # for add in adds: + # # subnodes = _replica(graph, add, [0] * 2) + # # for idx, sub_node in enumerate(subnodes): + # # _tp(graph, sub_node, [0,1] if idx == 0 else [2,3], idx=0, dim=1) + # # _tp(graph, add, tp_devs, idx=0, dim=1) + # subnodes = _tp(graph, add, [0] * 2, idx=0, dim=1) + # for idx, sub_node in enumerate(subnodes): + # _replica(graph, sub_node, [0,1] if idx == 0 else [2,3]) + # + # # partition layernorm + # lns = [node for node in fnodes if node.name == 'layernorm'] + # assert len(lns) > 0 + # for ln in lns: + # # _tp(graph, ln, tp_devs, idx=0, dim=1) + # # subnodes = _replica(graph, ln, [0] * 2) + # # for idx, sub_node in enumerate(subnodes): + # # _tp(graph, sub_node, [0,1] if idx == 0 else [2,3], idx=0, dim=1) + # subnodes = _tp(graph, ln, [0] * 2, idx=0, dim=1) + # for idx, sub_node in enumerate(subnodes): + # _replica(graph, sub_node, [0,1] if idx == 0 else [2,3]) # replicate other nodes From 611cf196e2ed0b73f9e5039565578d4000db8240 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 21 Nov 2022 16:24:59 +0800 Subject: [PATCH 1162/1892] optimize node insert and remove --- cube/graph/gener/gen.py | 2 ++ cube/graph/segment.py | 45 +++++++++++++++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 245dec6e..3a3311c3 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -129,6 +129,8 @@ def gen(graph: IRGraph) -> IRGraph: """ # remove anchor node graph = IRAdapterGener.remove_anchor(graph) + # reorder + graph._reorder_producer_consumer() # generate adapters for activation graph = IRAdapterGener.gen_activation(graph) # generate weight reducer diff --git a/cube/graph/segment.py b/cube/graph/segment.py index a473a2f3..020d030d 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -388,6 +388,35 @@ def _remove_ftensor(self, ftensor: IRFullTensor): if ftensor.is_attr() and ftensor in self._attributes: self._attributes.remove(ftensor) + def _reorder_producer_consumer(self): + """ + Re-order producers and consumers for each full tensor to match + with the ordering of nodes. + + Note sub-segment will also be reordered. + """ + # clear up + self._ftensors, self._attributes = set(), set() + self._producers, self._ptensors = dict(), dict() + self._consumers, self._ctensors = dict(), dict() + # set producer and consumer + for node in self._nodes: + if isinstance(node, IRAdapter): continue + itensors = set(t for t in node.inputs() if isinstance(t, IRSubTensor)) + for itensor in itensors: + ftensor = itensor.parent + self._add_ftensor(ftensor) + self._consumers[ftensor].append(node) + self._ctensors[ftensor].append(itensor) + otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) + for otensor in otensors: + ftensor = otensor.parent + self._add_ftensor(ftensor) + self._producers[ftensor].append(node) + self._ptensors[ftensor].append(otensor) + if isinstance(node, IRSegment): + node._reorder_producer_consumer() + def insert(self, node: IRCell, index: Union[int, CellPosition]): """ Insert a node at index. @@ -413,17 +442,21 @@ def insert(self, node: IRCell, index: Union[int, CellPosition]): for itensor in itensors: ftensor = itensor.parent self._add_ftensor(ftensor) - idx = len([c for c in self._consumers[ftensor] if self._nodes.index(c) < index]) - self._consumers[ftensor].insert(idx, node) - self._ctensors[ftensor].insert(idx, itensor) + # idx = len([c for c in self._consumers[ftensor] if self._nodes.index(c) < index]) + # self._consumers[ftensor].insert(idx, node) + # self._ctensors[ftensor].insert(idx, itensor) + self._consumers[ftensor].append(node) + self._ctensors[ftensor].append(itensor) # producer otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) for otensor in otensors: ftensor = otensor.parent self._add_ftensor(ftensor) - idx = len([c for c in self._producers[ftensor] if self._nodes.index(c) < index]) - self._producers[ftensor].insert(idx, node) - self._ptensors[ftensor].insert(idx, otensor) + # idx = len([c for c in self._producers[ftensor] if self._nodes.index(c) < index]) + # self._producers[ftensor].insert(idx, node) + # self._ptensors[ftensor].insert(idx, otensor) + self._producers[ftensor].append(node) + self._ptensors[ftensor].append(otensor) else: segment = self._nodes[pos[0]] assert isinstance(segment, IRSegment), "Expected IRSegment" From 363f523ee473708e469df3731835f5f08c09e1c3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 21 Nov 2022 18:52:06 +0800 Subject: [PATCH 1163/1892] fix bug on multiref --- cube/graph/gener/concurrent.py | 6 +++++- cube/graph/gener/gen.py | 21 ++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index fda1f7eb..473484f4 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -103,11 +103,15 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], paths, fprims = ilayout.path(olayout) # re-assign the operator if miss-ordered + res_layout: GridLayout = paths[-1] names, from_dev, to_dev = [], [], [] - for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): + for itensor, otensor in zip(res_layout.mat.flatten(), olayout.mat.flatten()): assert len(itensor.device) == 1 and len(otensor.device) == 1, \ "Expect tensor only has one device. Report this as a bug" if itensor.device != otensor.device: + # TODO: need to be robust: multiref to a node type + if otensor.cell.name == 'multiref': + raise RuntimeError("auto-inserted multiref cannot be re-ordered") inode, onode = itensor.cell, otensor.cell names.append(f'{onode.name}{onode.cid}') from_dev.append(onode.device[0]) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 3a3311c3..bc0f0a21 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -129,8 +129,6 @@ def gen(graph: IRGraph) -> IRGraph: """ # remove anchor node graph = IRAdapterGener.remove_anchor(graph) - # reorder - graph._reorder_producer_consumer() # generate adapters for activation graph = IRAdapterGener.gen_activation(graph) # generate weight reducer @@ -266,19 +264,28 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bdummies = [fwop.mirror for fwop in fdummies if fwop.mirror is not None] bgraph: Optional[IRSegment] = graph.mirror - # generate adapter for inter-segments - # FIXME: assume producers and consumers can run in parallel + # reorder producers and consumers + graph._reorder_producer_consumer() + + # local producer fusion and local consumer multiref + ftensors = [] for ftensor in graph.full_tensors(): # backward will gen in forward if ftensor.is_param() or ftensor.is_grad(): continue - - # flatten gradient + # flatten gradient utils.flatten_grad(graph, ftensor) - # optimization: local fusion / multiref on producer / consumer ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) IRAdapterGener.local_consumer_multiref(graph, ftensor) + ftensors.append(ftensor) + + # reorder again since inserted multiref could be mis-ordered + graph._reorder_producer_consumer() + + # generate adapter for inter-segments + # FIXME: assume producers and consumers can run in parallel + for ftensor in ftensors: # print(graph.debug_tensor_map_str(ftensor)) # print(graph.mirror.debug_tensor_map_str(ftensor.grad)) From f68604905706f1da3603ef6012266e9b60c0d506 Mon Sep 17 00:00:00 2001 From: Rongwei Lu Date: Tue, 22 Nov 2022 05:01:27 +0000 Subject: [PATCH 1164/1892] Updated memory.py --- cube/profiler/memory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index 957b32e2..cf6ffcad 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -16,6 +16,8 @@ def memory_summary(): print_each_rank( '{:.2f} GB memory consumption'.format(mem / 1024 / 1024 / 1024), ) + + return mem def model_summary(model: torch.nn.Module, inputs: List[Any], do_eval=False, max_depth=6): From 0ef6d1b115f2c4b92e0bcc479cc94e98a57815f8 Mon Sep 17 00:00:00 2001 From: Rongwei Lu Date: Tue, 22 Nov 2022 05:03:18 +0000 Subject: [PATCH 1165/1892] add gpt 760M and fix 1.3B --- examples/nlp/gpt/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 915e572d..e87df949 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -23,8 +23,10 @@ class Config: def build_gpt_config(name: str) -> Config: if name == '350M': embed_dim, layers, attention_heads = 1024, 24, 16 + elif name == '760M': + embed_dim, layers, attention_heads = 1536, 24, 16 elif name == '1.3B': - embed_dim, layers, attention_heads = 2048, 24, 32 + embed_dim, layers, attention_heads = 2048, 24, 24 elif name == '2.6B': embed_dim, layers, attention_heads = 2560, 32, 32 elif name == '6.7B': From 81da0dbca5657e8539e595b109815a0a9490f9fe Mon Sep 17 00:00:00 2001 From: Rongwei Lu Date: Tue, 22 Nov 2022 06:28:08 +0000 Subject: [PATCH 1166/1892] the heads of GPT-3 1.3B changed into 32 --- examples/nlp/gpt/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index e87df949..396e7c55 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -26,7 +26,7 @@ def build_gpt_config(name: str) -> Config: elif name == '760M': embed_dim, layers, attention_heads = 1536, 24, 16 elif name == '1.3B': - embed_dim, layers, attention_heads = 2048, 24, 24 + embed_dim, layers, attention_heads = 2048, 24, 32 elif name == '2.6B': embed_dim, layers, attention_heads = 2560, 32, 32 elif name == '6.7B': From 2656cc4399056a2ce37f0fbd3b84e147b038aa96 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 24 Nov 2022 22:34:40 +0800 Subject: [PATCH 1167/1892] fix timer bug on clear --- cube/profiler/timer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index 4569e7e9..d2ad9192 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -124,7 +124,9 @@ def __getattr__(self, name): return getattr(self.instance, name) def clear(self): - self.instance = CudaTimer.__CudaTimer() + CudaTimer.instance = CudaTimer.__CudaTimer( + enable=self.enabled, predefined=self.predefined + ) def print_all(self, times: int, rank_only: Optional[int] = None): """ From be96ce03f89a038498a4552ec53b9a11a911615d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Nov 2022 17:24:21 +0800 Subject: [PATCH 1168/1892] include profiler --- cube/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cube/__init__.py b/cube/__init__.py index 24ef255c..191a1afd 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,4 +1,5 @@ from cube import runtime +from cube import profiler from cube.compiler import SemanticModel, compile From 705bd4a12abfa19aecc1b781eb8f335f111397bf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Nov 2022 17:25:04 +0800 Subject: [PATCH 1169/1892] add matmul --- cube/graph/function/function.py | 15 +++++++++++++++ cube/graph/parser/mapping.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index e9c1c658..9aa3d642 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -51,6 +51,21 @@ def BatchLinear(signature, inputs): return IRDimops(BatchLinear, 'bmm', signature, annos, inputs) +def Matmul(signature, inputs: Tuple[IRTensor, IRTensor]): + assert len(inputs) == 2 + annos = [ + 'm k+, k+ n -> m n', + 'k+, k+ n -> n', + 'm k+, k+ -> m', + '* m k+, k+ n -> * m n', + '* m k+, * k+ n -> * m n' # TODO: broadcast + ] + lhs, rhs = inputs + if len(lhs.shape) > 2 and len(rhs.shape) > 2: + assert tuple(lhs.shape[:-2]) == tuple(rhs.shape[:-2]), "broadcast of matmul (bmm) is not supported" + return IRDimops(Matmul, 'matmul', signature, annos, inputs) + + def Zeros(signature, inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): # zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index a041c507..a35d4e35 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -63,6 +63,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('linear') : function.Linear, + __ttemplate('matmul'): function.Matmul, + __ftemplate('softmax') : function.Softmax, __ftemplate('dropout') : function.Dropout, From 729f9704e228a4f1506248309a5aa2d2c3c1fc7d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Nov 2022 17:25:36 +0800 Subject: [PATCH 1170/1892] add openfold with tensor parallelism version --- examples/openfold/blocks/__init__.py | 0 examples/openfold/blocks/attention.py | 353 ++++++++++++++++++++++++++ examples/openfold/blocks/embedder.py | 230 +++++++++++++++++ examples/openfold/blocks/evoformer.py | 120 +++++++++ examples/openfold/blocks/opm.py | 68 +++++ examples/openfold/blocks/tmu.py | 79 ++++++ examples/openfold/model.py | 142 +++++++++++ examples/openfold/policy/mpmd.py | 114 +++++++++ examples/openfold/train.py | 89 +++++++ 9 files changed, 1195 insertions(+) create mode 100644 examples/openfold/blocks/__init__.py create mode 100644 examples/openfold/blocks/attention.py create mode 100644 examples/openfold/blocks/embedder.py create mode 100644 examples/openfold/blocks/evoformer.py create mode 100644 examples/openfold/blocks/opm.py create mode 100644 examples/openfold/blocks/tmu.py create mode 100644 examples/openfold/model.py create mode 100644 examples/openfold/policy/mpmd.py create mode 100644 examples/openfold/train.py diff --git a/examples/openfold/blocks/__init__.py b/examples/openfold/blocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/openfold/blocks/attention.py b/examples/openfold/blocks/attention.py new file mode 100644 index 00000000..8d581fdf --- /dev/null +++ b/examples/openfold/blocks/attention.py @@ -0,0 +1,353 @@ +""" +Attention Module for MSA Attention and Pair Attention in Evoformer +""" + +import cube +import torch +import torch.utils.checkpoint as ckpt + + +@cube.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S R^ M^', name='msa_attn') +@torch.jit.ignore +def msa_attn(x: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, + c: int, scale: float, chunk_size: int, is_train: bool): + # cube.profiler.CudaTimer().start('msa_attn') + bs, s, r, cm = x.size() + + if chunk_size == -1: + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + gate = gate.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) + q = q.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) + k = k.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c).transpose(1, 2) + v = v.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) + sim = torch.bmm(q, k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + attend = torch.bmm(sim, v) * gate + out = attend.reshape(bs, s, head, r, c).transpose(2, 3).reshape(bs, s, r, head * c) + out = torch.matmul(out, out_proj) + else: + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + gate = gate.reshape(bs, s, r, head, c).transpose(2, 3) + q = q.reshape(bs, s, r, head, c).transpose(2, 3) + k = k.reshape(bs, s, r, head, c).transpose(2, 3) + v = v.reshape(bs, s, r, head, c).transpose(2, 3) + assert s % chunk_size == 0 + out_chunks = [] + + def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + gate: torch.Tensor, start: int): + cur_q = q[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_k = k[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c).transpose(1, 2) + cur_v = v[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + sim = torch.bmm(cur_q, cur_k) * 0.125 + sim = torch.nn.functional.softmax(sim, dim=-1) + attend = torch.bmm(sim, cur_v) * cur_gate + attend = attend.reshape(bs, chunk_size, head, r, c).transpose( + 2, 3).reshape(bs, chunk_size, r, head * c) + return attend + + for start in range(0, s, chunk_size): + attend = ckpt.checkpoint(attention, q, k, v, gate, start) + # attend = attention(q, k, v, gate, start) + out_chunks.append(attend) + + out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) + # cube.profiler.CudaTimer().stop('msa_attn') + return out + + +@cube.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, N 1 head+ R^ R^ -> N S R^ M^', name='msa_attn_bias') +@torch.jit.ignore +def msa_attn_bias(x: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias: torch.Tensor, head: int, c: int, scale: float, + chunk_size: int, is_train: bool): + # cube.profiler.CudaTimer().start('msa_attn_bias') + bs, s, r, cm = x.size() + assert gate_proj.size(1) % head == 0 + c = gate_proj.size(1) // head + + if chunk_size == -1: + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) # N S R (head dim) + gate = gate.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) # (N S head) r dim + q = q.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) # (N S head) r dim + k = k.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c).transpose(1, 2) # (N S head) dim r + v = v.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) + sim = torch.bmm(q, k) * scale # (N S head) r r + sim = torch.nn.functional.softmax(sim, dim=-1) # (N S head) r r + sim = sim.reshape(bs, s, head, r, r) + bias # N S head r r, N S 1 r r + sim = sim.reshape(bs * s * head, r, r) # (N S head) r r + attend = torch.bmm(sim, v) * gate # (N S head) r dim + out = attend.reshape(bs, s, head, r, c).transpose(2, 3).reshape(bs, s, r, head * c) + out = torch.matmul(out, out_proj) + else: + gate = torch.sigmoid(torch.matmul(x, gate_proj)) + q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) + gate = gate.reshape(bs, s, r, head, c).transpose(2, 3) + q = q.reshape(bs, s, r, head, c).transpose(2, 3) + k = k.reshape(bs, s, r, head, c).transpose(2, 3) + v = v.reshape(bs, s, r, head, c).transpose(2, 3) + assert s % chunk_size == 0 + out_chunks = [] + + def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + gate: torch.Tensor, bias: torch.Tensor, start: int): + cur_q = q[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_k = k[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c).transpose(1, 2) + cur_v = v[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( + bs * chunk_size * head, r, c) + + sim = torch.bmm(cur_q, cur_k) * scale + sim = torch.nn.functional.softmax(sim, dim=-1) + sim = sim.reshape(bs, chunk_size, head, r, r) + bias + sim = sim.reshape(bs * chunk_size * head, r, r) + + attend = torch.bmm(sim, cur_v) * cur_gate + attend = attend.reshape(bs, chunk_size, head, r, c).transpose( + 2, 3).reshape(bs, chunk_size, r, cm) + return attend + + for start in range(0, s, chunk_size): + if is_train: + attend = ckpt.checkpoint(attention_bias, q, k, v, gate, bias, start) + else: + attend = attention_bias(q, k, v, gate, bias, start) + out_chunks.append(attend) + out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) + # cube.profiler.CudaTimer().stop('msa_attn_bias') + return out + + +# note: code not reused constrained by cube's interface +@cube.graph.parser.register('N S R^ M^, N R^ R^ Z^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, Z^ head+ -> N S R^ M^', name='row_attn') +def row_attn(msa_repr: torch.Tensor, pair_repr: torch.Tensor, + gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, + bias_proj: torch.Tensor, head: int, c: int, + scale: float, chunk_size: int, is_train: bool): + # call: MSAAttentionWithBias + bs, s, r, cm = msa_repr.size() + # N R R Z, Z h -> N R R h -> N h S R -> N 1 h S R + bias = torch.matmul(pair_repr, bias_proj).permute(0, 3, 1, 2).reshape(bs, 1, head, r, r) + + return msa_attn_bias(msa_repr, gate_proj, qkv_proj, out_proj, bias, + head, c, scale, chunk_size, is_train) + + +@cube.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S R^ M^', name='col_attn') +def col_attn(msa_repr: torch.Tensor, gate_proj: torch.Tensor, + qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, + c: int, scale: float, chunk_size: int, is_train: bool): + # call: MSAAttention + msa_repr = msa_repr.permute(0, 2, 1, 3) + out = msa_attn( + msa_repr, gate_proj, qkv_proj, out_proj, + head, c, scale, chunk_size, is_train) + out = out.permute(0, 2, 1, 3) + return out + + +@cube.graph.parser.register('N S^ R^ M^, M^ (head+ dim^), M^ E^, M^ E^, M^ (head+ dim^), (head+ dim) M^ -> N S^ R^ M^', name='MSAColGlobalAttention') +def global_attn(msa_repr: torch.Tensor, q_proj: torch.Tensor, + k_proj: torch.Tensor, v_proj: torch.Tensor, + gate_proj: torch.Tensor, out_proj: torch.Tensor, + head: int, c: int, scale: float): + # [N R S M] + msa_repr = msa_repr.transpose(-2, -3) + + # [N R M] + q = torch.sum(msa_repr, dim=-2) + # [N R M] + q = torch.matmul(q, q_proj) * scale + # [N R H E] + q = q.view(q.shape[:-1] + (head, -1)) + + # [N R S E] + k, v = torch.matmul(msa_repr, k_proj), torch.matmul(msa_repr, v_proj) + + # N R H E, N R E S -> N R H S + a = torch.matmul(q, k.transpose(-1, -2)) + a = torch.nn.functional.softmax(a, dim=-1) + # [N R H E] + o = torch.matmul(a, v) + + # [N R S M] + g = torch.sigmoid(torch.matmul(msa_repr, gate_proj)) + # [N R S H E] + g = g.view(g.shape[:-1] + (head, -1)) + + # [N R 1 H E] + o = o.unsqueeze(-3) * g + # [N R S M] + o = o.reshape(o.shape[:-2] + (-1, )) + + return torch.matmul(o, out_proj).transpose(-2, -3) + + +@cube.graph.parser.register('N S R M^, M^ E+, E+ M^ -> N S R M^', name='feedforward') +@torch.jit.ignore +def feedforward(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): + """ + MSA transition + """ + # cube.profiler.CudaTimer().start('ffn') + x = torch.matmul(msa_repr, proj1) + x = torch.nn.functional.relu(x) + x = torch.matmul(x, proj2) + # cube.profiler.CudaTimer().stop('ffn') + return x + + +@cube.graph.parser.register('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, Z^ head+ -> N S R^ Z^', name='tri_attn_start') +def tri_attn_start(pair_repr: torch.Tensor, + gate: torch.Tensor, qkv: torch.Tensor, + out: torch.Tensor, bias: torch.Tensor, + head: int, c: int, scale: float, + chunk_size: int, is_train: bool): + bias = torch.matmul(pair_repr, bias).permute(0, 3, 1, 2).unsqueeze(1) + out = msa_attn_bias(pair_repr, gate, qkv, out, bias, + head, c, scale, chunk_size, is_train) + return out + + +@cube.graph.parser.register('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, Z head+ -> N S R^ Z^', name='tri_attn_end') +def tri_attn_end(pair_repr: torch.Tensor, + gate: torch.Tensor, qkv: torch.Tensor, + out: torch.Tensor, bias: torch.Tensor, + head: int, c: int, scale: float, chunk_size: int, is_train: bool): + bias = torch.matmul(pair_repr, bias).permute(0, 3, 2, 1).unsqueeze(1) + pair_repr = pair_repr.permute(0, 2, 1, 3) + out = msa_attn_bias(pair_repr, gate, qkv, out, bias, + head, c, scale, chunk_size, is_train) + return out.permute(0, 2, 1, 3) + + +class MSARowAttention(torch.nn.Module): + """ + MSA Row Attention with Pair Bias + """ + + def __init__(self, hidden: int, heads: int, z: int, scale: float, chunk_size: int = -1): + super().__init__() + assert hidden % heads == 0 + self.heads = heads + self.dhead = hidden // heads + self.chunk_size = chunk_size + self.scale = scale + self.bias = torch.nn.Parameter(torch.empty(z, heads)) + self.gate = torch.nn.Parameter(torch.empty(hidden, hidden)) + self.qkv = torch.nn.Parameter(torch.empty(hidden, hidden * 3)) + self.out = torch.nn.Parameter(torch.empty(hidden, hidden)) + + def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor) -> torch.Tensor: + """ + msa_repr: [N S R M] + pair_repr: [N R R Z] + """ + out = row_attn( + msa_repr, pair_repr, self.gate, self.qkv, self.out, self.bias, + self.heads, self.dhead, self.scale, self.chunk_size, self.training + ) + return out + + +class MSAColAttention(torch.nn.Module): + """ + MSA Coloumn Attention (no bias) + """ + def __init__(self, hidden: int, heads: int, scale: float, chunk_size: int = -1) -> None: + super().__init__() + assert hidden % heads == 0 + self.heads = heads + self.dhead = hidden // heads + self.chunk_size = chunk_size + self.scale = scale + self.gate = torch.nn.Parameter(torch.empty(hidden, hidden)) + self.qkv = torch.nn.Parameter(torch.empty(hidden, hidden * 3)) + self.out = torch.nn.Parameter(torch.empty(hidden, hidden)) + + def forward(self, msa_repr: torch.Tensor) -> torch.Tensor: + """ + msa_repr: [N S R M] + """ + out = col_attn( + msa_repr, self.gate, self.qkv, self.out, + self.heads, self.dhead, self.scale,self.chunk_size, self.training + ) + return out + + +class Transition(torch.nn.Module): + """ + Feedforward for msa_repr and pair_repr + """ + def __init__(self, hidden: int, ff_mult: int = 4) -> None: + super().__init__() + self.proj1 = torch.nn.Parameter(torch.empty(hidden, ff_mult * hidden)) + self.proj2 = torch.nn.Parameter(torch.empty(ff_mult * hidden, hidden)) + + def forward(self, msa_repr: torch.Tensor) -> torch.Tensor: + """ + msa_repr: [N S R M] + """ + return feedforward(msa_repr, self.proj1, self.proj2) + + +class TriangleAttentionNodeStart(torch.nn.Module): + + def __init__(self, cz: int, pair_head: int, c: int, scale: float, chunk_size=-1) -> None: + super().__init__() + self.heads = pair_head + self.c = c + self.scale = scale + self.chunk_size = chunk_size + self.layer_norm = torch.nn.LayerNorm(cz) + self.gate = torch.nn.Parameter(torch.empty(cz, pair_head * c)) + self.qkv = torch.nn.Parameter(torch.empty(cz, 3 * pair_head * c)) + self.out = torch.nn.Parameter(torch.empty(pair_head * c, cz)) + self.bias = torch.nn.Parameter(torch.empty(cz, pair_head)) + + def forward(self, pair_repr: torch.Tensor): + pair_repr = self.layer_norm(pair_repr) + pair_repr = tri_attn_start( + pair_repr, self.gate, self.qkv, self.out, self.bias, + self.heads, self.c, self.scale, self.chunk_size, self.training + ) + return pair_repr + + +class TriangleAttentionNodeEnd(torch.nn.Module): + + def __init__(self, cz: int, pair_head: int, c: int, scale: float, chunk_size=-1) -> None: + super().__init__() + self.heads = pair_head + self.c = c + self.scale = scale + self.chunk_size = chunk_size + self.layer_norm = torch.nn.LayerNorm(cz) + self.gate = torch.nn.Parameter(torch.empty(cz, pair_head * c)) + self.qkv = torch.nn.Parameter(torch.empty(cz, 3 * pair_head * c)) + self.out = torch.nn.Parameter(torch.empty(pair_head * c, cz)) + self.bias = torch.nn.Parameter(torch.empty(cz, pair_head)) + + def forward(self, pair_repr: torch.Tensor): + pair_repr = self.layer_norm(pair_repr) + pair_repr = tri_attn_end( + pair_repr, self.gate, self.qkv, self.out, self.bias, + self.heads, self.c, self.scale, self.chunk_size, self.training + ) + return pair_repr + diff --git a/examples/openfold/blocks/embedder.py b/examples/openfold/blocks/embedder.py new file mode 100644 index 00000000..cba55626 --- /dev/null +++ b/examples/openfold/blocks/embedder.py @@ -0,0 +1,230 @@ +import torch +import torch.nn as nn + +from typing import Tuple, Optional + +import cube + + + +@cube.graph.parser.register('N res, cz nobins, cz -> N res res cz', name='relpos') +def input_embedder_pair_emb(ri: torch.Tensor, + tf_emb_i: torch.Tensor, tf_emb_j: torch.Tensor, + w_relpos: torch.Tensor, b_relpos: torch.Tensor, + relpos_k) -> torch.Tensor: + + ri = ri.type(tf_emb_i.dtype) + d = ri[..., None] - ri[..., None, :] + boundaries = torch.arange( + start=-relpos_k, end=relpos_k + 1, device=torch.cuda.current_device() + ) + reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),)) + d = d[..., None] - reshaped_bins + d = torch.abs(d) + d = torch.argmin(d, dim=-1) + d = nn.functional.one_hot(d, num_classes=len(boundaries)).float() + d = d.to(ri.dtype) + pair_emb = torch.nn.functional.linear(d, w_relpos, b_relpos) + + pair_emb = pair_emb + tf_emb_i[..., None, :] + pair_emb = pair_emb + tf_emb_j[..., None, :, :] + + return pair_emb + + +@cube.graph.parser.register('N res tfdim^, cm tfdim^, cm -> N nclust^, res, cm') +def input_embedder_tf_m(tf: torch.Tensor, w_tf_m: torch.Tensor, b_tf_m: torch.Tensor, nclust: int) -> torch.Tensor: + tf_m = torch.nn.linear(tf, w_tf_m, b_tf_m) + tf_m = tf_m.unsqueeze(-3).expand(((-1,) * len(tf.shape[:-2]) + (nclust, -1, -1))) + return tf_m + + +class InputEmbedder(nn.Module): + """ + Embeds a subset of the input features. + + Implements Algorithms 3 (InputEmbedder) and 4 (relpos). + """ + + def __init__(self, tf_dim: int, msa_dim: int, c_z: int, c_m: int, relpos_k: int): + """ + Args: + tf_dim: + Final dimension of the target features + msa_dim: + Final dimension of the MSA features + c_z: + Pair embedding dimension + c_m: + MSA embedding dimension + relpos_k: + Window size used in relative positional encoding + """ + super().__init__() + + self.tf_dim = tf_dim + self.msa_dim = msa_dim + + self.c_z = c_z + self.c_m = c_m + + self.linear_tf_z_i = nn.Linear(tf_dim, c_z) + self.linear_tf_z_j = nn.Linear(tf_dim, c_z) + # self.linear_tf_m = nn.Linear(tf_dim, c_m) + self.w_tf_m = torch.nn.Parameter(torch.empty((c_m, tf_dim))) + self.b_tf_m = torch.nn.Parameter(torch.empty((c_m))) + self.linear_msa_m = nn.Linear(msa_dim, c_m) + self.w_tf_m + + # RPE stuff + self.relpos_k = relpos_k + self.no_bins = 2 * relpos_k + 1 + self.w_linear_relpos = torch.nn.Parameter(torch.empty((c_z, self.no_bins))) + self.b_linear_relpos = torch.nn.Parameter(torch.empty((c_z,))) + + def forward(self, tf: torch.Tensor, ri: torch.Tensor, msa: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + tf: + "target_feat" features of shape [*, N_res, tf_dim] + ri: + "residue_index" features of shape [*, N_res] + msa: + "msa_feat" features of shape [*, N_clust, N_res, msa_dim] + Returns: + msa_emb: + [*, N_clust, N_res, C_m] MSA embedding + pair_emb: + [*, N_res, N_res, C_z] pair embedding + + """ + # [*, N_res, c_z] + tf_emb_i = self.linear_tf_z_i(tf) + tf_emb_j = self.linear_tf_z_j(tf) + + # [*, N_res, N_res, c_z] + pair_emb = input_embedder_pair_emb( + ri, tf_emb_i, tf_emb_j, + self.w_linear_relpos, self.b_linear_relpos + ) + # pair_emb = relpos(ri.type(tf_emb_i.dtype)) + # pair_emb = pair_emb + tf_emb_i[..., None, :] + # pair_emb = pair_emb + tf_emb_j[..., None, :, :] + + # [*, N_clust, N_res, c_m] + tf_m = input_embedder_tf_m(tf, self.w_tf_m, self.b_tf_m) + msa_emb = self.linear_msa_m(msa) + tf_m + + return msa_emb, pair_emb + + + +@cube.graph.parser.register() +def sum_d(x: torch.Tensor, bins: torch.Tensor, inf: float) -> torch.Tensor: + squared_bins = bins ** 2 + upper = torch.cat( + [squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1 + ) + d = torch.sum( + (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True + ) + d = ((d > squared_bins) * (d < upper)).type(x.dtype) + return d + + +class RecyclingEmbedder(nn.Module): + """ + Embeds the output of an iteration of the model for recycling. + + Implements Algorithm 32. + """ + def __init__(self, c_m: int, c_z: int, + min_bin: float, max_bin: float, no_bins: int, + inf: float = 1e8): + """ + Args: + c_m: + MSA channel dimension + c_z: + Pair embedding channel dimension + min_bin: + Smallest distogram bin (Angstroms) + max_bin: + Largest distogram bin (Angstroms) + no_bins: + Number of distogram bins + """ + super().__init__() + + self.c_m = c_m + self.c_z = c_z + self.min_bin = min_bin + self.max_bin = max_bin + self.no_bins = no_bins + self.inf = inf + + self.linear = nn.Linear(self.no_bins, self.c_z) + self.layer_norm_m = nn.LayerNorm(self.c_m) + self.layer_norm_z = nn.LayerNorm(self.c_z) + + bins = torch.linspace(self.min_bin, self.max_bin, self.no_bins, requires_grad=False) + self.register_buffer('bins', bins) + + def forward(self, m: torch.Tensor, z: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + m: + First row of the MSA embedding. [*, N_res, C_m] + z: + [*, N_res, N_res, C_z] pair embedding + x: + [*, N_res, 3] predicted C_beta coordinates + Returns: + m: + [*, N_res, C_m] MSA embedding update + z: + [*, N_res, N_res, C_z] pair embedding update + """ + m = self.layer_norm_m(m) + z = self.layer_norm_z(z) + d = sum_d(x, self.bins, self.inf) + d = self.linear(d) + z = z + d + return m, z + + +class TemplateAngleEmbedder(nn.Module): + """ + Embeds the "template_angle_feat" feature. + + Implements Algorithm 2, line 7. + """ + def __init__(self, c_in: int, c_out: int): + """ + Args: + c_in: + Final dimension of "template_angle_feat" + c_out: + Output channel dimension + """ + super(TemplateAngleEmbedder, self).__init__() + + self.c_out = c_out + self.c_in = c_in + + self.linear_1 = nn.Linear(self.c_in, self.c_out, init="relu") + self.relu = nn.ReLU() + self.linear_2 = nn.Linear(self.c_out, self.c_out, init="relu") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: [*, N_templ, N_res, c_in] "template_angle_feat" features + Returns: + x: [*, N_templ, N_res, C_out] embedding + """ + x = self.linear_1(x) + x = self.relu(x) + x = self.linear_2(x) + return x + diff --git a/examples/openfold/blocks/evoformer.py b/examples/openfold/blocks/evoformer.py new file mode 100644 index 00000000..59e0f878 --- /dev/null +++ b/examples/openfold/blocks/evoformer.py @@ -0,0 +1,120 @@ +import torch +from examples.openfold.blocks.attention import MSARowAttention, MSAColAttention, Transition, TriangleAttentionNodeStart, TriangleAttentionNodeEnd +from examples.openfold.blocks.tmu import TriangleMultiplicativeUpdate +from examples.openfold.blocks.opm import OuterProducterMean + +import math +import cube + + +@cube.graph.parser.register('* -> *, *', name='multi2ref') +def multi2ref(x: torch.Tensor): + return (x, x) + + +class Evoformer(torch.nn.Module): + """ + Simulate execution of evoformer in alphafold. + + The mask and dropout is ommited for simplicity. + """ + + def __init__(self, s: int, cm: int, cz: int, + use_chunk=False, is_train=True, + c=32, msa_head=8, pair_head=4, + c_tri_mult=128, ff_mult=4): + super().__init__() + + self.s, self.cm, self.cz, self.c = s, cm, cz, c + self.msa_head, self.pair_head = msa_head, pair_head + self.c_tri_mult, self.ff_mult = c_tri_mult, ff_mult + self.scale = 1.0 / math.sqrt(c) + + self.is_train = is_train + + self.msa_row_chunk = 4 if use_chunk else -1 + self.msa_col_chunk = -1 + self.opm_chunk = self.tans_chunk = self.tane_chunk = -1 + + # MSA row-wise gated self-attention with pair bias + self.row_norm_m = torch.nn.LayerNorm(cm) + self.row_norm_z = torch.nn.LayerNorm(cz) + self.row_attn = MSARowAttention(cm, msa_head, cz, self.scale, self.msa_row_chunk) + + # MSA column-wise gated self-attention + self.col_norm = torch.nn.LayerNorm(cm) + self.col_attn = MSAColAttention(cm, msa_head, self.scale, self.msa_col_chunk) + + # MSA transition + self.msa_transition_norm = torch.nn.LayerNorm(cm) + self.msa_transition = Transition(cm, ff_mult) + + # Outer product mean + self.outer_norm = torch.nn.LayerNorm(cm) + self.outer_prod_mean = OuterProducterMean(cm, c, cz, self.opm_chunk) + + # Triangular multiplicative update using outgoing edges + self.tmo = TriangleMultiplicativeUpdate(cz, c_tri_mult, outgoing=True) + + # Triangular multiplicative update using incoming edges + self.tmi = TriangleMultiplicativeUpdate(cz, c_tri_mult, outgoing=False) + + # Triangular gated self-attention around starting node + self.tri_attn_node_start = TriangleAttentionNodeStart(cz, pair_head, c, self.scale, self.tans_chunk) + + # Triangular gated self-attention around ending node + self.tri_attn_node_end = TriangleAttentionNodeEnd(cz, pair_head, c, self.scale, self.tane_chunk) + + # Transition in the pair stack + self.pair_transition_norm = torch.nn.LayerNorm(cz) + self.pair_transition = Transition(cz, ff_mult) + + def forward(self, msa_repr, pair_repr): + + pair_repr, dummy_pair_repr = multi2ref(pair_repr) + + # msa row attention + residual = msa_repr + msa_repr = self.row_norm_m(msa_repr) + dummy_pair_repr = self.row_norm_z(dummy_pair_repr) + msa_repr = residual + self.row_attn(msa_repr, dummy_pair_repr) + + # msa column attention + residual = msa_repr + msa_repr = self.col_norm(msa_repr) + msa_repr = residual + self.col_attn(msa_repr) + + # msa transition + residual = msa_repr + msa_repr = self.msa_transition_norm(msa_repr) + msa_repr = self.msa_transition(msa_repr) + msa_repr = residual + msa_repr + + succ_msa_repr, msa_repr = multi2ref(msa_repr) + + # out product mean + msa_repr = self.outer_norm(msa_repr) + pair_repr = pair_repr + self.outer_prod_mean(msa_repr) + + # triangle multiplicative out-going edges + pair_repr = self.tmo(pair_repr) + # triangle multiplicative in-going edges + pair_repr = self.tmi(pair_repr) + + # pair attention start + residual = pair_repr + pair_repr = self.tri_attn_node_start(pair_repr) + pair_repr = residual + pair_repr + + # pair attention end + residual = pair_repr + pair_repr = self.tri_attn_node_end(pair_repr) + pair_repr = residual + pair_repr + + # pair transition + residual = pair_repr + pair_repr = self.pair_transition_norm(pair_repr) + pair_repr = self.pair_transition(pair_repr) + pair_repr = residual + pair_repr + + return succ_msa_repr, pair_repr diff --git a/examples/openfold/blocks/opm.py b/examples/openfold/blocks/opm.py new file mode 100644 index 00000000..01a15b77 --- /dev/null +++ b/examples/openfold/blocks/opm.py @@ -0,0 +1,68 @@ +""" +Outer Product Mean module for Evoformer +""" + +import cube +import torch +import torch.utils.checkpoint as ckpt + + +@cube.graph.parser.register('N S+ R^ M^, M^ c^, M^ c^, (c^ c^) cz^ -> N R^ R^ cz^', name='outer_prod_mean') +@torch.jit.ignore +def outer_prod_mean(msa_repr: torch.Tensor, left_proj: torch.Tensor, right_proj: torch.Tensor, + out_proj: torch.Tensor, chunk_size: int, training: bool): + # cube.profiler.CudaTimer().start('opm') + # N S R M, M c -> N S R c + opm_left = torch.matmul(msa_repr, left_proj) + # N S T M, M c -> N S T c + opm_right = torch.matmul(msa_repr, right_proj) + bs, s, r, c = opm_left.size() + t = opm_right.size(2) + + # N S R M -> N R S M + a = opm_left.transpose(-2, -3) + # N S T M -> N T S M + b = opm_right.transpose(-2, -3) + + if chunk_size == -1: + # N R S M, N T S M -> N R T M M -> N R T (M M) + outer = torch.einsum('...bac,...dae->...bdce', a, + b).reshape(bs, r, t, c * c) + # N R T (M M), (M M) Z -> N R T Z + outer = torch.matmul(outer, out_proj) + else: + out_chunks = [] + + def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): + lhs_slice = lhs[:, start:start + chunk_size, :, :] + out = torch.einsum('...bac,...dae->...bdce', lhs_slice, + rhs).reshape(bs, chunk_size, t, c * c) + out = torch.matmul(out, out_proj) + return out + + for start in range(0, r, chunk_size): + ret = ckpt.checkpoint(opm, a, b, start) + ret = opm(a, b, start) + out_chunks.append(ret) + outer = torch.cat(out_chunks, dim=1) + # cube.profiler.CudaTimer().stop('opm') + return outer + + +class OuterProducterMean(torch.nn.Module): + + def __init__(self, cm: int, c: int, cz: int, chunk_size: int) -> None: + super().__init__() + self.left = torch.nn.Parameter(torch.empty(cm, c)) + self.right = torch.nn.Parameter(torch.empty(cm, c)) + self.out = torch.nn.Parameter(torch.empty(c * c, cz)) + self.chunk_size = chunk_size + + def forward(self, msa_repr: torch.Tensor): + """ + msa_repr: [N S R M] + """ + return outer_prod_mean( + msa_repr, self.left, self.right, self.out, + self.chunk_size, self.training + ) diff --git a/examples/openfold/blocks/tmu.py b/examples/openfold/blocks/tmu.py new file mode 100644 index 00000000..79a97ffa --- /dev/null +++ b/examples/openfold/blocks/tmu.py @@ -0,0 +1,79 @@ +import cube +import torch +import torch.utils.checkpoint as ckpt + + +@cube.graph.parser.register('N S R Z, Z E, Z E, Z E, Z E, Z Z, E, E, E Z -> N S R Z') +@torch.jit.ignore +def tmu(pair_repr: torch.Tensor, + left1: torch.Tensor, left2: torch.Tensor, + right1: torch.Tensor, right2: torch.Tensor, + gate: torch.Tensor, + norm_weight: torch.Tensor, norm_bias: torch.Tensor, + out: torch.Tensor, outgoing: bool) -> torch.Tensor: + # cube.profiler.CudaTimer().start('tmu') + # Note S == R + # left projection: N S R Z^, Z^ E, Z^ E -> N S R E + left = torch.matmul(pair_repr, left1) + left = torch.sigmoid(left) + left = left * torch.matmul(pair_repr, left2) + # right projection: N S R Z^, Z^ E, Z^ E -> N S R E + right = torch.matmul(pair_repr, right1) + right = torch.sigmoid(right) + right = right * torch.matmul(pair_repr, right2) + if outgoing: + # N S R E -> N E S R + left = left.permute(0, 3, 1, 2) + # N S R E -> N E R S + right = right.permute(0, 3, 2, 1) + else: + # N S R E -> N E R S + left = left.permute(0, 3, 2, 1) + # N S R E -> N E S R + right = right.permute(0, 3, 1, 2) + # N E S R+, N E R+ S -> N E S S -> N S S E (for out) + # N E R S, N E S R -> N E R R -> N R R E (for in) + p = torch.matmul(left, right).permute(0, 2, 3, 1) + e = p.size(3) + # N S S E^ -> N S S E^ + p = torch.nn.functional.layer_norm(p, (e,), norm_weight, norm_bias) + # N S S E+, E+ Z -> N S S Z + p = torch.matmul(p, out) + if not outgoing: + p = p.permute(0, 2, 1, 3) + # gate: N S R Z+, Z+ Z -> N S R Z + g = torch.matmul(pair_repr, gate) + g = torch.sigmoid(g) + # N S S Z, N S R Z -> N S R Z (broadcast R == S == 0) + p = p * g + # cube.profiler.CudaTimer().stop('tmu') + return p + + +class TriangleMultiplicativeUpdate(torch.nn.Module): + + def __init__(self, cz: int, mult: int, outgoing: bool) -> None: + super().__init__() + self.layer_norm = torch.nn.LayerNorm((cz,)) + + self.left1 = torch.nn.Parameter(torch.empty(cz, mult)) + self.left2 = torch.nn.Parameter(torch.empty(cz, mult)) + self.right1 = torch.nn.Parameter(torch.empty(cz, mult)) + self.right2 = torch.nn.Parameter(torch.empty(cz, mult)) + self.normw = torch.nn.Parameter(torch.empty(mult)) + self.normb = torch.nn.Parameter(torch.empty(mult)) + self.out = torch.nn.Parameter(torch.empty(mult, cz)) + self.gate = torch.nn.Parameter(torch.empty(cz, cz)) + self.outgoing = outgoing + + def forward(self, pair_repr: torch.Tensor): + """ + pair_repr: [N S R Z] + """ + residual = pair_repr + pair_repr = self.layer_norm(pair_repr) + pair_repr = tmu(pair_repr, + self.left1, self.left2, self.right1, self.right2, + self.gate, self.normw, self.normb, self.out, self.outgoing) + pair_repr = residual + pair_repr + return pair_repr diff --git a/examples/openfold/model.py b/examples/openfold/model.py new file mode 100644 index 00000000..7f09b6d1 --- /dev/null +++ b/examples/openfold/model.py @@ -0,0 +1,142 @@ +""" +Alphafold 2, using implementation similar with OpenFold. +""" + +import torch +import torch.nn as nn + +# from examples.openfold.blocks.embedder import InputEmbedder, RecyclingEmbedder, TemplateAngleEmbedder +from examples.openfold.blocks.evoformer import Evoformer + +from dataclasses import dataclass + +import cube + + +@dataclass +class Config: + + # input_embedder + input_embedder_cm = 256 + input_embedder_cz = 128 + input_embedder_msa_dim = 49 + input_embedder_relpos_k = 32 + input_embedder_tf_dim = 22 + + # recycling embedder + recycling_embedder_cm = 256 + recycling_embedder_cz = 128 + recycling_embedder_inf = 1000000000.0 + recycling_embedder_maxbin = 20.75 + recycling_embedder_minbin = 3.25 + recycling_embedder_nobins = 15 + + # templates + template_angle_embedder_cin = 57 + template_angle_embedder_cout = 256 + template_pair_embedder_cin = 88 + template_pair_embedder_cout = 64 + template_pair_stack_hidden_tri_att = 16 + template_pair_stack_hidden_tri_mul = 64 + template_pair_stack_ct = 64 + template_pair_stack_dp = 0.25 + template_pair_stack_inf = 100000000.0 + template_pair_stack_noblocks = 2 + template_pair_stack_noheads = 4 + template_pair_stack_pair_transition_n = 2 + template_pointwise_attention_hidden = 16 + template_pointwise_attention_ct = 64 + template_pointwise_attention_cz = 128 + template_pointwise_inf = 1000000000.0 + template_pointwise_noheads = 4 + + # extra msa + extra_msa_embedder_cin = 25 + extra_msa_embedder_cout = 64 + extra_msa_stack_hidden_att = 8 + extra_msa_stack_hidden_mul = 128 + extra_msa_stack_opm = 32 + extra_msa_stack_pair_att = 32 + extra_msa_stack_cm = 64 + extra_msa_stack_cz = 128 + extra_msa_stack_eps = 1e-8 + extra_msa_stack_inf = 1000000000.0 + extra_msa_dp = 0.15 + extra_msa_stack_noblocks = 4 + extra_msa_stack_no_heads_msa = 8 + extra_msa_stack_no_heads_pair = 4 + extra_msa_stack_pair_dp = 0.25 + extra_msa_stack_transition_n = 4 + + # evoformer + evoformer_s = 128 + evoformer_r = 256 + evoformer_cm = 256 + evoformer_cz = 128 + evoformer_use_chunk = False + evoformer_is_extra = False + evoformer_nlayers = 48 + + # batch size + bs = 1 + + + +class AlphaFold(nn.Module): + + + def __init__(self, cfg: Config = Config()) -> None: + super().__init__() + + # self.input_embedder = InputEmbedder( + # cfg.input_embedder_tf_dim, cfg.input_embedder_msa_dim, + # cfg.input_embedder_cz, cfg.input_embedder_cm, + # cfg.input_embedder_relpos_k + # ) + # self.recycling_embedder = RecyclingEmbedder( + # cfg.recycling_embedder_cm, cfg.recycling_embedder_cz, + # cfg.recycling_embedder_minbin, cfg.recycling_embedder_maxbin, + # cfg.recycling_embedder_nobins, cfg.recycling_embedder_inf + # ) + + # template config + # self.template_angle_embedder = TemplateAngleEmbedder( + # cfg.template_angle_embedder_cin, + # cfg.template_angle_embedder_cout + # ) + # self.template_pair_embedder = nn.Linear( + # cfg.template_pair_embedder_cin, + # cfg.template_pair_embedder_cout, + # ) + self.template_pair_stack = None # TemplatePairStack() + self.template_pointwise_att = None # TemplatePointwiseAttention() + + # extra msa + # self.extra_msa_embedder = nn.Linear( + # cfg.extra_msa_embedder_cin, cfg.extra_msa_embedder_cout + # ) + + self.extra_msa_stack = None # ExtraMSAStack() + + # evoformer + # self.evoformer = EvoformerStack() + self.msa_norm = torch.nn.LayerNorm(cfg.evoformer_cm) + self.pair_norm = torch.nn.LayerNorm(cfg.evoformer_cz) + self.evoformers = nn.ModuleList( + [Evoformer( + cfg.evoformer_s, cfg.evoformer_cm, cfg.evoformer_cz, + cfg.evoformer_use_chunk + ) for _ in range(cfg.evoformer_nlayers)] + ) + + self.structure_module = None # StructureModule() + self.aux_heads = None # AuxiliaryHeads() + + def forward(self, msa, pair): + msa = self.msa_norm(msa) + pair = self.pair_norm(pair) + for evoformer in self.evoformers: + cube.runtime.function.anchor('Evoformer Start') + msa, pair = evoformer(msa, pair) + loss = torch.sum(msa) * torch.sum(pair) + return loss diff --git a/examples/openfold/policy/mpmd.py b/examples/openfold/policy/mpmd.py new file mode 100644 index 00000000..6e11691c --- /dev/null +++ b/examples/openfold/policy/mpmd.py @@ -0,0 +1,114 @@ +from typing import List + +from cube.graph import IRGraph +from cube.ir.cten import IRCell +from cube.graph.function.anchor import IRGraphAnchor +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation + + +def _group_to_evoformers(fnodes) -> List[List[IRCell]]: + # group to evoformer layers + evoformers: List[List[IRFwOperation]] = [] + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + fnodes[idx+1].comment = f'===> start of transformer layer {lid}' + start = idx if lid != 0 else 0 + end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) + evoformers.append(fnodes[start:end]) + for lid in range(len(evoformers) - 1): + if evoformers[lid][-1].name == 'multiref': + node = evoformers[lid].pop() + evoformers[lid+1].insert(0, node) + return evoformers + +# ========================= parallelisms ================================= + +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], + idx: int, dim: int, tag='dim'): + algo = node.algorithms(tag) + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# coshard +def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, + idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) + assert sub_nodes is not None + graph.recompute(sub_nodes) + for devid in devs: + for coid in range(colocate): + sub_node = sub_nodes[devid * colocate + coid] + graph.assign(sub_node, devid) + return sub_nodes + + +def PASSingle(graph: IRGraph, resource): + assert resource.ngpus == 1 + # print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + return graph + + +def PASDP(graph: IRGraph, resource): + dp_size = resource.ngpus + dp_devs = list(range(dp_size)) + + dataloader = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[0] + + # partition dataloader + dls = graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) + for devid, dl in enumerate(dls): + graph.assign(dl, devid) + + # partition forward operators + for node in graph.select(ntype=IRFwOperation): + if len(node.inputs()) == 0: continue + #FIXME: a workaround to find batch dimension + batch_dim = node.input(0).shape.index(bs) + _tp(graph, node, dp_devs, idx=0, dim=batch_dim) + + return graph + + +def PASTP(graph: IRGraph, resource): + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + + # grouping + evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) + for layer in evoformers: + graph.recompute(layer) + + for node in graph.select(ntype=(IRDataOperation, IRFwOperation)): + if isinstance(node, IRGraphAnchor): continue + if node.name == 'row_attn': + _tp(graph, node, tp_devs, idx=2, dim=1) + elif node.name == 'col_attn': + _tp(graph, node, tp_devs, idx=1, dim=1) + elif node.name == 'feedforward': + _tp(graph, node, tp_devs, idx=1, dim=1) + elif node.name == 'tri_attn_start': + _tp(graph, node, tp_devs, idx=1, dim=1) + elif node.name == 'tri_attn_end': + _tp(graph, node, tp_devs, idx=1, dim=1) + elif node.name == 'outer_prod_mean': + _tp(graph, node, tp_devs, idx=0, dim=1) + else: + _replica(graph, node, tp_devs) + return graph \ No newline at end of file diff --git a/examples/openfold/train.py b/examples/openfold/train.py new file mode 100644 index 00000000..e8cf4413 --- /dev/null +++ b/examples/openfold/train.py @@ -0,0 +1,89 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/openfold/train.py --fp16 +""" + + +import torch +from examples.openfold.model import AlphaFold, Config + +import cube +from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.memory import memory_summary +from examples.openfold.policy.mpmd import * + +import argparse + +cube.init() + +parser = argparse.ArgumentParser(description='AlphaFold Train') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') +parser.add_argument('--fp16', action='store_true', default=False, + help='use fp16 for the training') +args = parser.parse_args() + + +def nparams(model) -> int: + cnt = 0 + for param in model.parameters(): + cnt += param.nelement() + return cnt + + +def train(): + + cfg = Config() + model = AlphaFold(cfg) + if args.fp16: + model = model.half() + + dtype = torch.float16 if args.fp16 else torch.float32 + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([cfg.bs, cfg.evoformer_s, cfg.evoformer_r, cfg.evoformer_cm], + [cfg.bs, cfg.evoformer_r, cfg.evoformer_r, cfg.evoformer_cz]), + dtypes=(dtype, dtype), + batch_dims=(0, 0) + ) + + print_each_rank(f'before partitioned model parameter: {nparams(model)}') + + model = cube.SemanticModel(model) + @cube.compile(model, dataloader, PAS=PASTP, override=True, load_content=True) + def train_iter(model, dataloader): + input_ids, position_ids = next(dataloader) + loss = model(input_ids, position_ids) + loss.backward() + model = model.get_gen_module() + + optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + print_each_rank(f'after partitioned model parameter: {nparams(model)}') + + torch.distributed.barrier() + print_each_rank('model weight consumpition:', rank_only=0) + memory_summary() + + CudaTimer(enable=False).warmup() + iter_num, warmup = 5, 2 + for step in range(iter_num): + if step == warmup: + CudaTimer(enable=True, predefined=True).start('e2e') + + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + + if step == 0: + print_each_rank('passed first iteration') + if (step + 1) % 2 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + CudaTimer().stop('e2e') + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) + + memory_summary() + +train() \ No newline at end of file From 27d459d9d42c58f5e132016426916dfa89351b39 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 30 Nov 2022 20:47:47 +0800 Subject: [PATCH 1171/1892] update script --- examples/openfold/blocks/attention.py | 2 +- examples/openfold/blocks/evoformer.py | 5 +- examples/openfold/blocks/tmu.py | 98 ++++++++++++++++++++++-- examples/openfold/blocks/utils.py | 7 ++ examples/openfold/model.py | 105 +++++++++++++------------- examples/openfold/policy/mpmd.py | 20 ++--- examples/openfold/train.py | 27 ++++++- 7 files changed, 181 insertions(+), 83 deletions(-) create mode 100644 examples/openfold/blocks/utils.py diff --git a/examples/openfold/blocks/attention.py b/examples/openfold/blocks/attention.py index 8d581fdf..184ca8cc 100644 --- a/examples/openfold/blocks/attention.py +++ b/examples/openfold/blocks/attention.py @@ -160,7 +160,7 @@ def col_attn(msa_repr: torch.Tensor, gate_proj: torch.Tensor, return out -@cube.graph.parser.register('N S^ R^ M^, M^ (head+ dim^), M^ E^, M^ E^, M^ (head+ dim^), (head+ dim) M^ -> N S^ R^ M^', name='MSAColGlobalAttention') +# @cube.graph.parser.register('N S^ R^ M^, M^ (head+ dim^), M^ E^, M^ E^, M^ (head+ dim^), (head+ dim) M^ -> N S^ R^ M^', name='MSAColGlobalAttention') def global_attn(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor, gate_proj: torch.Tensor, out_proj: torch.Tensor, diff --git a/examples/openfold/blocks/evoformer.py b/examples/openfold/blocks/evoformer.py index 59e0f878..f2574ca5 100644 --- a/examples/openfold/blocks/evoformer.py +++ b/examples/openfold/blocks/evoformer.py @@ -2,15 +2,12 @@ from examples.openfold.blocks.attention import MSARowAttention, MSAColAttention, Transition, TriangleAttentionNodeStart, TriangleAttentionNodeEnd from examples.openfold.blocks.tmu import TriangleMultiplicativeUpdate from examples.openfold.blocks.opm import OuterProducterMean +from examples.openfold.blocks.utils import multi2ref import math import cube -@cube.graph.parser.register('* -> *, *', name='multi2ref') -def multi2ref(x: torch.Tensor): - return (x, x) - class Evoformer(torch.nn.Module): """ diff --git a/examples/openfold/blocks/tmu.py b/examples/openfold/blocks/tmu.py index 79a97ffa..b2daa4e2 100644 --- a/examples/openfold/blocks/tmu.py +++ b/examples/openfold/blocks/tmu.py @@ -1,9 +1,9 @@ import cube import torch -import torch.utils.checkpoint as ckpt +from examples.openfold.blocks.utils import multi2ref -@cube.graph.parser.register('N S R Z, Z E, Z E, Z E, Z E, Z Z, E, E, E Z -> N S R Z') +# @cube.graph.parser.register('N S R Z, Z E, Z E, Z E, Z E, Z Z, E, E, E Z -> N S R Z') @torch.jit.ignore def tmu(pair_repr: torch.Tensor, left1: torch.Tensor, left2: torch.Tensor, @@ -32,7 +32,7 @@ def tmu(pair_repr: torch.Tensor, # N S R E -> N E S R right = right.permute(0, 3, 1, 2) # N E S R+, N E R+ S -> N E S S -> N S S E (for out) - # N E R S, N E S R -> N E R R -> N R R E (for in) + # N E R S+, N E S+ R -> N E R R -> N R R E (for in) p = torch.matmul(left, right).permute(0, 2, 3, 1) e = p.size(3) # N S S E^ -> N S S E^ @@ -50,6 +50,74 @@ def tmu(pair_repr: torch.Tensor, return p +@cube.graph.parser.register('N S^ R+ Z^, Z^ E, Z^ E, Z^ E, Z^ E -> N S S E', name='tmi_projection') +@torch.jit.ignore +def tmi_projection(pair_repr: torch.Tensor, + left1: torch.Tensor, left2: torch.Tensor, + right1: torch.Tensor, right2: torch.Tensor) -> torch.Tensor: + # left projection: N S R Z^, Z^ E, Z^ E -> N S R E + left = torch.matmul(pair_repr, left1) + left = torch.sigmoid(left) + left = left * torch.matmul(pair_repr, left2) + # right projection: N S R Z^, Z^ E, Z^ E -> N S R E + right = torch.matmul(pair_repr, right1) + right = torch.sigmoid(right) + right = right * torch.matmul(pair_repr, right2) + # N S R E -> N E S R + left = left.permute(0, 3, 1, 2) + # N S R E -> N E R S + right = right.permute(0, 3, 2, 1) + # N E S R+, N E R+ S -> N E S S -> N S S E + p = torch.matmul(left, right).permute(0, 2, 3, 1) + return p + + + +@cube.graph.parser.register('N S R Z^, N R S E, E Z^, Z^ Z^ -> N S R Z^') +@torch.jit.ignore +def tmi_gating(pair_repr: torch.Tensor, p: torch.Tensor, out: torch.Tensor, gate: torch.Tensor): + # N S R Z+, Z+ Z -> N S R Z + g = torch.matmul(pair_repr, gate) + g = torch.sigmoid(g) + # N R S E+, E+ Z -> N R S Z -> N S R Z + p = torch.matmul(p, out).permute(0, 2, 1, 3) + p = p * g + return p + + +@cube.graph.parser.register('N S+ R^ Z^, Z^ E, Z^ E, Z^ E, Z^ E -> N R R E', name='tmo_projection') +def tmo_projection(pair_repr: torch.Tensor, + left1: torch.Tensor, left2: torch.Tensor, + right1: torch.Tensor, right2: torch.Tensor) -> torch.Tensor: + # left projection: N S R Z^, Z^ E, Z^ E -> N S R E + left = torch.matmul(pair_repr, left1) + left = torch.sigmoid(left) + left = left * torch.matmul(pair_repr, left2) + # right projection: N S R Z^, Z^ E, Z^ E -> N S R E + right = torch.matmul(pair_repr, right1) + right = torch.sigmoid(right) + right = right * torch.matmul(pair_repr, right2) + # N S R E -> N E R S + left = left.permute(0, 3, 2, 1) + # N S R E -> N E S R + right = right.permute(0, 3, 1, 2) + # N E R S+, N E S+ R -> N E R R -> N R R E + p = torch.matmul(left, right).permute(0, 2, 3, 1) + return p + + +@cube.graph.parser.register('N S R Z^, N S R E^, E^ Z^, Z^ Z^ -> N S R Z^') +def tmo_gating(pair_repr: torch.Tensor, p: torch.Tensor, out: torch.Tensor, gate: torch.Tensor): + # N S R Z+, Z+ Z -> N S R Z + g = torch.matmul(pair_repr, gate) + g = torch.sigmoid(g) + # N S R E+, E+ Z -> N S R Z + p = torch.matmul(p, out) + p = p * g + return p + + + class TriangleMultiplicativeUpdate(torch.nn.Module): def __init__(self, cz: int, mult: int, outgoing: bool) -> None: @@ -60,8 +128,10 @@ def __init__(self, cz: int, mult: int, outgoing: bool) -> None: self.left2 = torch.nn.Parameter(torch.empty(cz, mult)) self.right1 = torch.nn.Parameter(torch.empty(cz, mult)) self.right2 = torch.nn.Parameter(torch.empty(cz, mult)) - self.normw = torch.nn.Parameter(torch.empty(mult)) - self.normb = torch.nn.Parameter(torch.empty(mult)) + + self.norm = torch.nn.LayerNorm(mult) + # self.normw = torch.nn.Parameter(torch.empty(mult)) + # self.normb = torch.nn.Parameter(torch.empty(mult)) self.out = torch.nn.Parameter(torch.empty(mult, cz)) self.gate = torch.nn.Parameter(torch.empty(cz, cz)) self.outgoing = outgoing @@ -72,8 +142,20 @@ def forward(self, pair_repr: torch.Tensor): """ residual = pair_repr pair_repr = self.layer_norm(pair_repr) - pair_repr = tmu(pair_repr, - self.left1, self.left2, self.right1, self.right2, - self.gate, self.normw, self.normb, self.out, self.outgoing) + # ====================== break for tp ======================= + pair_repr1, pair_repr2 = multi2ref(pair_repr) + if self.outgoing: + p = tmi_projection(pair_repr1, self.left1, self.left2, self.right1, self.right2) + else: + p = tmo_projection(pair_repr1, self.left1, self.left2, self.right1, self.right2) + p = self.norm(p) + if self.outgoing: + pair_repr = tmi_gating(pair_repr2, p, self.out, self.gate) + else: + pair_repr = tmo_gating(pair_repr2, p, self.out, self.gate) + # ======================= intergrate version ================== + # pair_repr = tmu(pair_repr, + # self.left1, self.left2, self.right1, self.right2, + # self.gate, self.normw, self.normb, self.out, self.outgoing) pair_repr = residual + pair_repr return pair_repr diff --git a/examples/openfold/blocks/utils.py b/examples/openfold/blocks/utils.py new file mode 100644 index 00000000..520a3c1d --- /dev/null +++ b/examples/openfold/blocks/utils.py @@ -0,0 +1,7 @@ +import cube +import torch + + +@cube.graph.parser.register('* -> *, *', name='multi2ref') +def multi2ref(x: torch.Tensor): + return (x, x) \ No newline at end of file diff --git a/examples/openfold/model.py b/examples/openfold/model.py index 7f09b6d1..50f77936 100644 --- a/examples/openfold/model.py +++ b/examples/openfold/model.py @@ -17,68 +17,69 @@ class Config: # input_embedder - input_embedder_cm = 256 - input_embedder_cz = 128 - input_embedder_msa_dim = 49 - input_embedder_relpos_k = 32 - input_embedder_tf_dim = 22 + # input_embedder_cm = 256 + # input_embedder_cz = 128 + # input_embedder_msa_dim = 49 + # input_embedder_relpos_k = 32 + # input_embedder_tf_dim = 22 # recycling embedder - recycling_embedder_cm = 256 - recycling_embedder_cz = 128 - recycling_embedder_inf = 1000000000.0 - recycling_embedder_maxbin = 20.75 - recycling_embedder_minbin = 3.25 - recycling_embedder_nobins = 15 + # recycling_embedder_cm = 256 + # recycling_embedder_cz = 128 + # recycling_embedder_inf = 1000000000.0 + # recycling_embedder_maxbin = 20.75 + # recycling_embedder_minbin = 3.25 + # recycling_embedder_nobins = 15 # templates - template_angle_embedder_cin = 57 - template_angle_embedder_cout = 256 - template_pair_embedder_cin = 88 - template_pair_embedder_cout = 64 - template_pair_stack_hidden_tri_att = 16 - template_pair_stack_hidden_tri_mul = 64 - template_pair_stack_ct = 64 - template_pair_stack_dp = 0.25 - template_pair_stack_inf = 100000000.0 - template_pair_stack_noblocks = 2 - template_pair_stack_noheads = 4 - template_pair_stack_pair_transition_n = 2 - template_pointwise_attention_hidden = 16 - template_pointwise_attention_ct = 64 - template_pointwise_attention_cz = 128 - template_pointwise_inf = 1000000000.0 - template_pointwise_noheads = 4 + # template_angle_embedder_cin = 57 + # template_angle_embedder_cout = 256 + # template_pair_embedder_cin = 88 + # template_pair_embedder_cout = 64 + # template_pair_stack_hidden_tri_att = 16 + # template_pair_stack_hidden_tri_mul = 64 + # template_pair_stack_ct = 64 + # template_pair_stack_dp = 0.25 + # template_pair_stack_inf = 100000000.0 + # template_pair_stack_noblocks = 2 + # template_pair_stack_noheads = 4 + # template_pair_stack_pair_transition_n = 2 + # template_pointwise_attention_hidden = 16 + # template_pointwise_attention_ct = 64 + # template_pointwise_attention_cz = 128 + # template_pointwise_inf = 1000000000.0 + # template_pointwise_noheads = 4 # extra msa - extra_msa_embedder_cin = 25 - extra_msa_embedder_cout = 64 - extra_msa_stack_hidden_att = 8 - extra_msa_stack_hidden_mul = 128 - extra_msa_stack_opm = 32 - extra_msa_stack_pair_att = 32 - extra_msa_stack_cm = 64 - extra_msa_stack_cz = 128 - extra_msa_stack_eps = 1e-8 - extra_msa_stack_inf = 1000000000.0 - extra_msa_dp = 0.15 - extra_msa_stack_noblocks = 4 - extra_msa_stack_no_heads_msa = 8 - extra_msa_stack_no_heads_pair = 4 - extra_msa_stack_pair_dp = 0.25 - extra_msa_stack_transition_n = 4 + # extra_msa_embedder_cin = 25 + # extra_msa_embedder_cout = 64 + # extra_msa_stack_hidden_att = 8 + # extra_msa_stack_hidden_mul = 128 + # extra_msa_stack_opm = 32 + # extra_msa_stack_pair_att = 32 + # extra_msa_stack_cm = 64 + # extra_msa_stack_cz = 128 + # extra_msa_stack_eps = 1e-8 + # extra_msa_stack_inf = 1000000000.0 + # extra_msa_dp = 0.15 + # extra_msa_stack_noblocks = 4 + # extra_msa_stack_no_heads_msa = 8 + # extra_msa_stack_no_heads_pair = 4 + # extra_msa_stack_pair_dp = 0.25 + # extra_msa_stack_transition_n = 4 # evoformer - evoformer_s = 128 - evoformer_r = 256 - evoformer_cm = 256 - evoformer_cz = 128 - evoformer_use_chunk = False - evoformer_is_extra = False - evoformer_nlayers = 48 + evoformer_s: int = 128 + evoformer_r: int = 256 + evoformer_cm: int = 256 + evoformer_cz: int = 128 + evoformer_c: int = 32 + evoformer_use_chunk: bool = False + evoformer_is_extra: bool = False + evoformer_nlayers: int = 4 # batch size - bs = 1 + bs: int = 1 diff --git a/examples/openfold/policy/mpmd.py b/examples/openfold/policy/mpmd.py index 6e11691c..c9ec7987 100644 --- a/examples/openfold/policy/mpmd.py +++ b/examples/openfold/policy/mpmd.py @@ -41,20 +41,6 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): graph.assign(sub_node, devid) return sub_nodes -# coshard -def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, - idx: int, dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) - assert sub_nodes is not None - graph.recompute(sub_nodes) - for devid in devs: - for coid in range(colocate): - sub_node = sub_nodes[devid * colocate + coid] - graph.assign(sub_node, devid) - return sub_nodes - - def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 # print(graph.extra_repr()) @@ -109,6 +95,12 @@ def PASTP(graph: IRGraph, resource): _tp(graph, node, tp_devs, idx=1, dim=1) elif node.name == 'outer_prod_mean': _tp(graph, node, tp_devs, idx=0, dim=1) + elif node.name == 'tmi_projection': + _tp(graph, node, tp_devs, idx=0, dim=2) + elif node.name == 'tmi_projection': + _tp(graph, node, tp_devs, idx=0, dim=1) + elif node.name == 'tmi_gating' or node.name == 'tmo_gating': + _tp(graph, node, tp_devs, idx=0, dim=1) else: _replica(graph, node, tp_devs) return graph \ No newline at end of file diff --git a/examples/openfold/train.py b/examples/openfold/train.py index e8cf4413..89bc2f13 100644 --- a/examples/openfold/train.py +++ b/examples/openfold/train.py @@ -18,10 +18,25 @@ cube.init() parser = argparse.ArgumentParser(description='AlphaFold Train') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') +parser.add_argument('--layers', type=int, default=48, + help='evoformer layer number') +parser.add_argument('--msa-hidden', type=int, default=256, + help='cm value') +parser.add_argument('--pair-hidden', type=int, default=128, + help='cz value') +parser.add_argument('--head-dim', type=int, default=32, + help='c value') +parser.add_argument('--mbs', type=int, default=1, + help='micro batch size') +parser.add_argument('--gbs', type=int, default=1, + help='global batch size') + args = parser.parse_args() +assert args.gbs % args.mbs == 0 +assert args.msa_hidden % args.head_dim == 0 +assert args.pair_hidden % args.head_dim == 0 def nparams(model) -> int: @@ -33,7 +48,11 @@ def nparams(model) -> int: def train(): - cfg = Config() + cfg = Config(evoformer_cm=args.msa_hidden, evoformer_cz=args.pair_hidden, + evoformer_c=args.head_dim, evoformer_nlayers=args.layers, + bs=args.mbs) + print_each_rank(cfg, rank_only=0) + model = AlphaFold(cfg) if args.fp16: model = model.half() @@ -69,8 +88,8 @@ def train_iter(model, dataloader): for step in range(iter_num): if step == warmup: CudaTimer(enable=True, predefined=True).start('e2e') - - train_iter(model, dataloader) + for _ in range(args.gbs // args.mbs): + train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() From 67540cbe41cda82b0f6beeee4aa9a99b986c7dae Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 1 Dec 2022 19:41:59 +0800 Subject: [PATCH 1172/1892] fix return value of cases that partition failed --- cube/algorithm/ops/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py index da10cf68..4914b381 100644 --- a/cube/algorithm/ops/dataloader.py +++ b/cube/algorithm/ops/dataloader.py @@ -34,7 +34,7 @@ def satisfy(self, num: int): def instantiate(self, num: int): if not self.satisfy(num): - return False + return None node: IRDataOperation = self.node dims: List[int] = node.get_batch_dims() From 6333d1a5644e5e86cbf8c19713cc206b743f419a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 1 Dec 2022 19:42:55 +0800 Subject: [PATCH 1173/1892] use DAP for SPMD policy --- examples/openfold/blocks/attention.py | 21 ++-- examples/openfold/blocks/evoformer.py | 21 ++-- examples/openfold/blocks/opm.py | 56 ++++++++- examples/openfold/blocks/tmu.py | 165 ++++++++------------------ examples/openfold/policy/mpmd.py | 125 +++++++++++++------ examples/openfold/train.py | 18 ++- 6 files changed, 227 insertions(+), 179 deletions(-) diff --git a/examples/openfold/blocks/attention.py b/examples/openfold/blocks/attention.py index 184ca8cc..7c94aafb 100644 --- a/examples/openfold/blocks/attention.py +++ b/examples/openfold/blocks/attention.py @@ -147,7 +147,7 @@ def row_attn(msa_repr: torch.Tensor, pair_repr: torch.Tensor, head, c, scale, chunk_size, is_train) -@cube.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S R^ M^', name='col_attn') +@cube.graph.parser.register('N S^ R M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S^ R M^', name='col_attn') def col_attn(msa_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): @@ -211,24 +211,26 @@ def feedforward(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor return x -@cube.graph.parser.register('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, Z^ head+ -> N S R^ Z^', name='tri_attn_start') +@cube.graph.parser.register('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N R^ R^ head+ -> N S R^ Z^', name='tri_attn_start') def tri_attn_start(pair_repr: torch.Tensor, gate: torch.Tensor, qkv: torch.Tensor, out: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): - bias = torch.matmul(pair_repr, bias).permute(0, 3, 1, 2).unsqueeze(1) + # bias = torch.matmul(pair_repr, bias).permute(0, 3, 1, 2).unsqueeze(1) + bias = bias.permute(0, 3, 1, 2).unsqueeze(1) out = msa_attn_bias(pair_repr, gate, qkv, out, bias, head, c, scale, chunk_size, is_train) return out -@cube.graph.parser.register('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, Z head+ -> N S R^ Z^', name='tri_attn_end') +@cube.graph.parser.register('N S^ R Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N S^ S^ head+ -> N S^ R Z^', name='tri_attn_end') def tri_attn_end(pair_repr: torch.Tensor, gate: torch.Tensor, qkv: torch.Tensor, out: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): - bias = torch.matmul(pair_repr, bias).permute(0, 3, 2, 1).unsqueeze(1) + # bias = torch.matmul(pair_repr, bias).permute(0, 3, 2, 1).unsqueeze(1) + bias = bias.permute(0, 3, 2, 1).unsqueeze(1) pair_repr = pair_repr.permute(0, 2, 1, 3) out = msa_attn_bias(pair_repr, gate, qkv, out, bias, head, c, scale, chunk_size, is_train) @@ -321,9 +323,13 @@ def __init__(self, cz: int, pair_head: int, c: int, scale: float, chunk_size=-1) self.bias = torch.nn.Parameter(torch.empty(cz, pair_head)) def forward(self, pair_repr: torch.Tensor): + """ + pair_repr: N R R cz + """ pair_repr = self.layer_norm(pair_repr) + bias = torch.matmul(pair_repr, self.bias) pair_repr = tri_attn_start( - pair_repr, self.gate, self.qkv, self.out, self.bias, + pair_repr, self.gate, self.qkv, self.out, bias, self.heads, self.c, self.scale, self.chunk_size, self.training ) return pair_repr @@ -345,8 +351,9 @@ def __init__(self, cz: int, pair_head: int, c: int, scale: float, chunk_size=-1) def forward(self, pair_repr: torch.Tensor): pair_repr = self.layer_norm(pair_repr) + bias = torch.matmul(pair_repr, self.bias) pair_repr = tri_attn_end( - pair_repr, self.gate, self.qkv, self.out, self.bias, + pair_repr, self.gate, self.qkv, self.out, bias, self.heads, self.c, self.scale, self.chunk_size, self.training ) return pair_repr diff --git a/examples/openfold/blocks/evoformer.py b/examples/openfold/blocks/evoformer.py index f2574ca5..0883543e 100644 --- a/examples/openfold/blocks/evoformer.py +++ b/examples/openfold/blocks/evoformer.py @@ -68,47 +68,46 @@ def __init__(self, s: int, cm: int, cz: int, def forward(self, msa_repr, pair_repr): + cube.runtime.function.anchor('MSARow') pair_repr, dummy_pair_repr = multi2ref(pair_repr) - - # msa row attention residual = msa_repr msa_repr = self.row_norm_m(msa_repr) dummy_pair_repr = self.row_norm_z(dummy_pair_repr) msa_repr = residual + self.row_attn(msa_repr, dummy_pair_repr) - # msa column attention + cube.runtime.function.anchor('MSACol') residual = msa_repr msa_repr = self.col_norm(msa_repr) msa_repr = residual + self.col_attn(msa_repr) - # msa transition + # cube.runtime.function.anchor('MSATrans') residual = msa_repr msa_repr = self.msa_transition_norm(msa_repr) msa_repr = self.msa_transition(msa_repr) msa_repr = residual + msa_repr - succ_msa_repr, msa_repr = multi2ref(msa_repr) - # out product mean + cube.runtime.function.anchor('OPM') msa_repr = self.outer_norm(msa_repr) pair_repr = pair_repr + self.outer_prod_mean(msa_repr) - # triangle multiplicative out-going edges + cube.runtime.function.anchor('TMO') pair_repr = self.tmo(pair_repr) - # triangle multiplicative in-going edges + + cube.runtime.function.anchor('TMI') pair_repr = self.tmi(pair_repr) - # pair attention start + cube.runtime.function.anchor('TANS') residual = pair_repr pair_repr = self.tri_attn_node_start(pair_repr) pair_repr = residual + pair_repr - # pair attention end + cube.runtime.function.anchor('TANE') residual = pair_repr pair_repr = self.tri_attn_node_end(pair_repr) pair_repr = residual + pair_repr - # pair transition + cube.runtime.function.anchor('PairTrans') residual = pair_repr pair_repr = self.pair_transition_norm(pair_repr) pair_repr = self.pair_transition(pair_repr) diff --git a/examples/openfold/blocks/opm.py b/examples/openfold/blocks/opm.py index 01a15b77..5ee13ad8 100644 --- a/examples/openfold/blocks/opm.py +++ b/examples/openfold/blocks/opm.py @@ -7,7 +7,7 @@ import torch.utils.checkpoint as ckpt -@cube.graph.parser.register('N S+ R^ M^, M^ c^, M^ c^, (c^ c^) cz^ -> N R^ R^ cz^', name='outer_prod_mean') +# @cube.graph.parser.register('N S+ R^ M^, M^ c^, M^ c^, (c^ c^) cz^ -> N R^ R^ cz^', name='outer_prod_mean') @torch.jit.ignore def outer_prod_mean(msa_repr: torch.Tensor, left_proj: torch.Tensor, right_proj: torch.Tensor, out_proj: torch.Tensor, chunk_size: int, training: bool): @@ -49,6 +49,48 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): return outer +@cube.graph.parser.register('N S R M+, M+ C -> N S R C', name='opm_projection') +def opm_projection(msa_repr: torch.Tensor, proj1: torch.Tensor): + x = torch.matmul(msa_repr, proj1) + return x + + +@cube.graph.parser.register('N S^ R C^, N S^ T^ C^, F^ Z^ -> N R T^ Z^') +@torch.jit.ignore +def opm(left: torch.Tensor, right: torch.Tensor, out_proj: torch.Tensor, + chunk_size: int, training: bool): + bs, s, r, c = left.size() + t = right.size(2) + # N S R C -> N R S C + a = left.transpose(-2, -3) + # N S T C -> N T S C + b = right.transpose(-2, -3) + + if chunk_size == -1: + # N R S M, N T S M -> N R T M M -> N R T (M M) + outer = torch.einsum('...bac,...dae->...bdce', a, + b).reshape(bs, r, t, c * c) + # N R T (M M), (M M) Z -> N R T Z + outer = torch.matmul(outer, out_proj) + else: + out_chunks = [] + + def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): + lhs_slice = lhs[:, start:start + chunk_size, :, :] + out = torch.einsum('...bac,...dae->...bdce', lhs_slice, + rhs).reshape(bs, chunk_size, t, c * c) + out = torch.matmul(out, out_proj) + return out + + for start in range(0, r, chunk_size): + ret = ckpt.checkpoint(opm, a, b, start) + ret = opm(a, b, start) + out_chunks.append(ret) + outer = torch.cat(out_chunks, dim=1) + # cube.profiler.CudaTimer().stop('opm') + return outer + + class OuterProducterMean(torch.nn.Module): def __init__(self, cm: int, c: int, cz: int, chunk_size: int) -> None: @@ -62,7 +104,11 @@ def forward(self, msa_repr: torch.Tensor): """ msa_repr: [N S R M] """ - return outer_prod_mean( - msa_repr, self.left, self.right, self.out, - self.chunk_size, self.training - ) + left = opm_projection(msa_repr, self.left) + right = opm_projection(msa_repr, self.right) + out = opm(left, right, self.out, self.chunk_size, self.training) + return out + # return outer_prod_mean( + # msa_repr, self.left, self.right, self.out, + # self.chunk_size, self.training + # ) diff --git a/examples/openfold/blocks/tmu.py b/examples/openfold/blocks/tmu.py index b2daa4e2..c5a91ea9 100644 --- a/examples/openfold/blocks/tmu.py +++ b/examples/openfold/blocks/tmu.py @@ -3,121 +3,61 @@ from examples.openfold.blocks.utils import multi2ref -# @cube.graph.parser.register('N S R Z, Z E, Z E, Z E, Z E, Z Z, E, E, E Z -> N S R Z') -@torch.jit.ignore -def tmu(pair_repr: torch.Tensor, - left1: torch.Tensor, left2: torch.Tensor, - right1: torch.Tensor, right2: torch.Tensor, - gate: torch.Tensor, - norm_weight: torch.Tensor, norm_bias: torch.Tensor, - out: torch.Tensor, outgoing: bool) -> torch.Tensor: - # cube.profiler.CudaTimer().start('tmu') - # Note S == R - # left projection: N S R Z^, Z^ E, Z^ E -> N S R E - left = torch.matmul(pair_repr, left1) - left = torch.sigmoid(left) - left = left * torch.matmul(pair_repr, left2) - # right projection: N S R Z^, Z^ E, Z^ E -> N S R E - right = torch.matmul(pair_repr, right1) - right = torch.sigmoid(right) - right = right * torch.matmul(pair_repr, right2) - if outgoing: - # N S R E -> N E S R - left = left.permute(0, 3, 1, 2) - # N S R E -> N E R S - right = right.permute(0, 3, 2, 1) - else: - # N S R E -> N E R S - left = left.permute(0, 3, 2, 1) - # N S R E -> N E S R - right = right.permute(0, 3, 1, 2) - # N E S R+, N E R+ S -> N E S S -> N S S E (for out) - # N E R S+, N E S+ R -> N E R R -> N R R E (for in) - p = torch.matmul(left, right).permute(0, 2, 3, 1) - e = p.size(3) - # N S S E^ -> N S S E^ - p = torch.nn.functional.layer_norm(p, (e,), norm_weight, norm_bias) - # N S S E+, E+ Z -> N S S Z - p = torch.matmul(p, out) - if not outgoing: - p = p.permute(0, 2, 1, 3) - # gate: N S R Z+, Z+ Z -> N S R Z - g = torch.matmul(pair_repr, gate) - g = torch.sigmoid(g) - # N S S Z, N S R Z -> N S R Z (broadcast R == S == 0) - p = p * g - # cube.profiler.CudaTimer().stop('tmu') - return p - - -@cube.graph.parser.register('N S^ R+ Z^, Z^ E, Z^ E, Z^ E, Z^ E -> N S S E', name='tmi_projection') -@torch.jit.ignore -def tmi_projection(pair_repr: torch.Tensor, +# @cube.graph.parser.register('N S R Z^, Z^ E, Z^ E -> N S R E') +# def tmu_projection(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): +# x = torch.matmul(pair_repr, proj1) +# x = torch.sigmoid(x) +# x = x * torch.matmul(pair_repr, proj2) +# +# +# @cube.graph.parser.register('N S R Z+, Z+ E-> N S R E') +# def tmu_gate(pair_repr: torch.Tensor, proj: torch.Tensor): +# return torch.sigmoid(torch.matmul(pair_repr, proj)) + + +@cube.graph.parser.register('N S R Z^, Z^ E^, Z^ E^, Z^ E, Z^ E^, Z^ Z^ -> N S R E, N S R E^, N S R Z^', name='tmu_projection') +def tmu_projection(pair_repr: torch.Tensor, left1: torch.Tensor, left2: torch.Tensor, - right1: torch.Tensor, right2: torch.Tensor) -> torch.Tensor: - # left projection: N S R Z^, Z^ E, Z^ E -> N S R E + right1: torch.Tensor, right2: torch.Tensor, + gate: torch.Tensor): + # left left = torch.matmul(pair_repr, left1) left = torch.sigmoid(left) left = left * torch.matmul(pair_repr, left2) - # right projection: N S R Z^, Z^ E, Z^ E -> N S R E + # right right = torch.matmul(pair_repr, right1) right = torch.sigmoid(right) right = right * torch.matmul(pair_repr, right2) - # N S R E -> N E S R - left = left.permute(0, 3, 1, 2) - # N S R E -> N E R S - right = right.permute(0, 3, 2, 1) - # N E S R+, N E R+ S -> N E S S -> N S S E - p = torch.matmul(left, right).permute(0, 2, 3, 1) - return p + # gate + gate = torch.sigmoid(torch.matmul(pair_repr, gate)) + return left, right, gate -@cube.graph.parser.register('N S R Z^, N R S E, E Z^, Z^ Z^ -> N S R Z^') -@torch.jit.ignore -def tmi_gating(pair_repr: torch.Tensor, p: torch.Tensor, out: torch.Tensor, gate: torch.Tensor): - # N S R Z+, Z+ Z -> N S R Z - g = torch.matmul(pair_repr, gate) - g = torch.sigmoid(g) - # N R S E+, E+ Z -> N R S Z -> N S R Z - p = torch.matmul(p, out).permute(0, 2, 1, 3) - p = p * g - return p - - -@cube.graph.parser.register('N S+ R^ Z^, Z^ E, Z^ E, Z^ E, Z^ E -> N R R E', name='tmo_projection') -def tmo_projection(pair_repr: torch.Tensor, - left1: torch.Tensor, left2: torch.Tensor, - right1: torch.Tensor, right2: torch.Tensor) -> torch.Tensor: - # left projection: N S R Z^, Z^ E, Z^ E -> N S R E - left = torch.matmul(pair_repr, left1) - left = torch.sigmoid(left) - left = left * torch.matmul(pair_repr, left2) - # right projection: N S R Z^, Z^ E, Z^ E -> N S R E - right = torch.matmul(pair_repr, right1) - right = torch.sigmoid(right) - right = right * torch.matmul(pair_repr, right2) - # N S R E -> N E R S - left = left.permute(0, 3, 2, 1) - # N S R E -> N E S R - right = right.permute(0, 3, 1, 2) - # N E R S+, N E S+ R -> N E R R -> N R R E - p = torch.matmul(left, right).permute(0, 2, 3, 1) +@cube.graph.parser.register('N S R^ E, N T^ R^ E^, N S^ T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='tmo') +def tmo(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, + norm_w: torch.Tensor, norm_b: torch.Tensor, out: torch.Tensor): + a = left.permute(0, 3, 1, 2) + b = right.permute(0, 3, 2, 1) + p = torch.matmul(a, b).permute(0, 2, 3, 1) + p = torch.nn.functional.layer_norm(p, (128, ), norm_w, norm_b) + p = torch.matmul(p, out) + p = p * gate return p -@cube.graph.parser.register('N S R Z^, N S R E^, E^ Z^, Z^ Z^ -> N S R Z^') -def tmo_gating(pair_repr: torch.Tensor, p: torch.Tensor, out: torch.Tensor, gate: torch.Tensor): - # N S R Z+, Z+ Z -> N S R Z - g = torch.matmul(pair_repr, gate) - g = torch.sigmoid(g) - # N S R E+, E+ Z -> N S R Z +@cube.graph.parser.register('N R^ S E, N R^ T^ E^, N T^ S^ Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='tmi') +def tmi(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, + norm_w: torch.Tensor, norm_b: torch.Tensor, out: torch.Tensor): + a = left.permute(0, 3, 2, 1) + b = right.permute(0, 3, 1, 2) + p = torch.matmul(a, b).permute(0, 2, 3, 1) + p = torch.nn.functional.layer_norm(p, (128, ), norm_w, norm_b) p = torch.matmul(p, out) - p = p * g + p = p.permute(0, 2, 1, 3) * gate return p - class TriangleMultiplicativeUpdate(torch.nn.Module): def __init__(self, cz: int, mult: int, outgoing: bool) -> None: @@ -129,9 +69,10 @@ def __init__(self, cz: int, mult: int, outgoing: bool) -> None: self.right1 = torch.nn.Parameter(torch.empty(cz, mult)) self.right2 = torch.nn.Parameter(torch.empty(cz, mult)) - self.norm = torch.nn.LayerNorm(mult) - # self.normw = torch.nn.Parameter(torch.empty(mult)) - # self.normb = torch.nn.Parameter(torch.empty(mult)) + # self.norm = torch.nn.LayerNorm(mult) + self.normw = torch.nn.Parameter(torch.empty(mult)) + self.normb = torch.nn.Parameter(torch.empty(mult)) + self.out = torch.nn.Parameter(torch.empty(mult, cz)) self.gate = torch.nn.Parameter(torch.empty(cz, cz)) self.outgoing = outgoing @@ -142,20 +83,16 @@ def forward(self, pair_repr: torch.Tensor): """ residual = pair_repr pair_repr = self.layer_norm(pair_repr) - # ====================== break for tp ======================= - pair_repr1, pair_repr2 = multi2ref(pair_repr) - if self.outgoing: - p = tmi_projection(pair_repr1, self.left1, self.left2, self.right1, self.right2) - else: - p = tmo_projection(pair_repr1, self.left1, self.left2, self.right1, self.right2) - p = self.norm(p) + + left, right, gate = tmu_projection(pair_repr, + self.left1, self.left2, + self.right1, self.right2, self.gate + ) + if self.outgoing: - pair_repr = tmi_gating(pair_repr2, p, self.out, self.gate) + pair_repr = tmo(left, right, gate, self.normw, self.normb, self.out) else: - pair_repr = tmo_gating(pair_repr2, p, self.out, self.gate) - # ======================= intergrate version ================== - # pair_repr = tmu(pair_repr, - # self.left1, self.left2, self.right1, self.right2, - # self.gate, self.normw, self.normb, self.out, self.outgoing) + pair_repr = tmi(left, right, gate, self.normw, self.normb, self.out) + pair_repr = residual + pair_repr return pair_repr diff --git a/examples/openfold/policy/mpmd.py b/examples/openfold/policy/mpmd.py index c9ec7987..c309e43a 100644 --- a/examples/openfold/policy/mpmd.py +++ b/examples/openfold/policy/mpmd.py @@ -5,42 +5,54 @@ from cube.graph.function.anchor import IRGraphAnchor from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +import more_itertools +import numpy as np + def _group_to_evoformers(fnodes) -> List[List[IRCell]]: # group to evoformer layers evoformers: List[List[IRFwOperation]] = [] - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor) and node.name == 'Evoformer Start'] indices = [fnodes.index(anchor) for anchor in anchors] for lid, idx in enumerate(indices): - fnodes[idx+1].comment = f'===> start of transformer layer {lid}' + # get first forward op + for fnode in fnodes[idx+1:]: + if not isinstance(fnode, IRGraphAnchor): break + fnode.comment = f'===> start of evoformer layer {lid}' start = idx if lid != 0 else 0 end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) evoformers.append(fnodes[start:end]) - for lid in range(len(evoformers) - 1): - if evoformers[lid][-1].name == 'multiref': - node = evoformers[lid].pop() - evoformers[lid+1].insert(0, node) + print(f'find {len(indices)} evoformer layers') return evoformers # ========================= parallelisms ================================= # tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], - idx: int, dim: int, tag='dim'): - algo = node.algorithms(tag) - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], tag='dim', **config): + if len(devs) == 1: + sub_nodes = [node] + else: + algo = node.algorithms(tag) + sub_nodes = graph.partition(node, algo, num=len(devs), **config) + assert sub_nodes is not None for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) + graph.assign(sub_node, int(devid)) return sub_nodes # replicate def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) + if len(devs) == 1: + sub_nodes = [node] + else: + sub_nodes = graph.replicate(node, times=len(devs)) for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) + graph.assign(sub_node, int(devid)) return sub_nodes + +# ========================= policies ================================= + + def PASSingle(graph: IRGraph, resource): assert resource.ngpus == 1 # print(graph.extra_repr()) @@ -72,35 +84,72 @@ def PASDP(graph: IRGraph, resource): return graph -def PASTP(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) +def PASDAP(graph: IRGraph, resource, tp: int, dp: int): + assert tp * dp == resource.ngpus + + devmesh = np.arange(resource.ngpus).reshape(dp, tp) + tp_devs = list(range(tp)) + + dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[dataloader.get_batch_dims()[0]] + print(f'> get batch size: {bs}') + dls: List[IRDataOperation] = _replica(graph, dataloader, tp_devs) + for tp_idx, dl in enumerate(dls): + dp_devs = devmesh[:,tp_idx] + _tp(graph, dl, dp_devs, 'data') # grouping evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) for layer in evoformers: graph.recompute(layer) - for node in graph.select(ntype=(IRDataOperation, IRFwOperation)): - if isinstance(node, IRGraphAnchor): continue - if node.name == 'row_attn': - _tp(graph, node, tp_devs, idx=2, dim=1) - elif node.name == 'col_attn': - _tp(graph, node, tp_devs, idx=1, dim=1) - elif node.name == 'feedforward': - _tp(graph, node, tp_devs, idx=1, dim=1) - elif node.name == 'tri_attn_start': - _tp(graph, node, tp_devs, idx=1, dim=1) - elif node.name == 'tri_attn_end': - _tp(graph, node, tp_devs, idx=1, dim=1) - elif node.name == 'outer_prod_mean': - _tp(graph, node, tp_devs, idx=0, dim=1) - elif node.name == 'tmi_projection': - _tp(graph, node, tp_devs, idx=0, dim=2) - elif node.name == 'tmi_projection': - _tp(graph, node, tp_devs, idx=0, dim=1) - elif node.name == 'tmi_gating' or node.name == 'tmo_gating': - _tp(graph, node, tp_devs, idx=0, dim=1) + fnodes = graph.select(ntype=IRFwOperation) + fnodes = [fnode for fnode in fnodes if fnode.name != 'Evoformer Start'] + + node_groups = more_itertools.split_at(fnodes, lambda n: isinstance(n, IRGraphAnchor)) + + for nodes in node_groups: + # tensor parallelism + names = set(n.name for n in nodes) + subnodes = [] + if len(names) == 1 or 'mul' in names: # for first layer norm operators + for node in nodes: + subnodes.append(_replica(graph, node, tp_devs)) + elif 'row_attn' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) + elif 'col_attn' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) + elif 'opm' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) + elif 'tmo' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) + elif 'tmi' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) + elif 'tri_attn_start' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) + elif 'tri_attn_end' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) + elif 'feedforward' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) else: - _replica(graph, node, tp_devs) + assert False, names + # data parallelism + for ns in subnodes: + for tp_idx, subnode in enumerate(ns): + dp_devs = devmesh[:,tp_idx] + if bs in subnode.input(0).shape: + dim = subnode.input(0).shape.index(bs) + _tp(graph, subnode, dp_devs, idx=0, dim=dim) + else: + print(f'replicate op on data parallel group: {node.name}') + _replica(graph, subnode, dp_devs) + return graph \ No newline at end of file diff --git a/examples/openfold/train.py b/examples/openfold/train.py index 89bc2f13..d7321d01 100644 --- a/examples/openfold/train.py +++ b/examples/openfold/train.py @@ -1,7 +1,7 @@ """ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ - examples/openfold/train.py --fp16 + examples/openfold/train.py --fp16 --tp 4 --dp 1 """ @@ -11,16 +11,18 @@ import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary -from examples.openfold.policy.mpmd import * +from examples.openfold.policy.mpmd import PASDAP import argparse +from functools import partial + cube.init() parser = argparse.ArgumentParser(description='AlphaFold Train') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') -parser.add_argument('--layers', type=int, default=48, +parser.add_argument('--layers', type=int, default=4, help='evoformer layer number') parser.add_argument('--msa-hidden', type=int, default=256, help='cm value') @@ -32,13 +34,21 @@ help='micro batch size') parser.add_argument('--gbs', type=int, default=1, help='global batch size') +parser.add_argument('--tp', type=int, default=1, + help='tensor parallelism size') +parser.add_argument('--dp', type=int, default=1, + help='data parallelism size') args = parser.parse_args() assert args.gbs % args.mbs == 0 +assert args.mbs % args.dp == 0 assert args.msa_hidden % args.head_dim == 0 assert args.pair_hidden % args.head_dim == 0 +PASDAP = partial(PASDAP, tp=args.tp, dp=args.dp) + + def nparams(model) -> int: cnt = 0 for param in model.parameters(): @@ -68,7 +78,7 @@ def train(): print_each_rank(f'before partitioned model parameter: {nparams(model)}') model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=PASTP, override=True, load_content=True) + @cube.compile(model, dataloader, PAS=PASDAP, override=True, load_content=True) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) From add5c79ef514e91c13a9b82df380eee56337c474 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 3 Dec 2022 16:09:12 +0800 Subject: [PATCH 1174/1892] reorder arguments and outputs to match for scheduling --- cube/graph/gener/gen.py | 10 +++++++++- cube/graph/segment.py | 23 ++++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index bc0f0a21..02d0fddf 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -629,9 +629,17 @@ def fusion(graph: IRSegment) -> IRSegment: if isinstance(adapter, IRAdapter) and adapter.forward and not adapter.differentiable: fadapters.append(adapter) if adapter.mirror is not None: - badapters.insert(0, adapter.mirror) + badapters.append(adapter.mirror) + # badapters.insert(0, adapter.mirror) else: if len(fadapters) > 1: + # reorder adapter to match output of segment. This is temporally + # necessary for pipeline scheduling with multiple output. + ftids = np.array([fadapter.input(0).parent.tid for fadapter in fadapters]) + indices = np.argsort(ftids) + fadapters = [fadapters[idx] for idx in indices] + if len(badapters) > 0: + badapters = [badapters[idx] for idx in indices] # insert fused fadapter fused_fadapter = IRAdapter.merge(fadapters) for adapter in fadapters: diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 020d030d..578cb304 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1,5 +1,6 @@ from contextlib import contextmanager from typing import Dict, Union, List, Optional, Set, Tuple +import numpy as np from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap from cube.ir.cten import IRTensor, IRCell @@ -963,7 +964,15 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: if len(consumers) == 0 or any(c not in nodes for c in consumers): outputs.add(otensor) continue - segment = IRSegment(nodes, tuple(inputs), tuple(outputs)) + + def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: + """Reorder by logical tensor id. Temporally necessary for pipeline scheduling""" + tensors = list(tensors) + tids = np.array([t.parent.tid for t in tensors]) + indices = np.argsort(tids) + return tuple(tensors[idx] for idx in indices) + + segment = IRSegment(nodes, order(inputs), order(outputs)) return segment def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: @@ -994,6 +1003,18 @@ def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: for otensor in node.outputs(): if otensor in self._outputs and otensor not in outputs: outputs.append(otensor) + + def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: + """Reorder by logical tensor id. Temporally necessary for pipeline scheduling""" + tensors = list(tensors) + print(tensors) + tids = np.array([t.parent.tid for t in tensors]) + indices = np.argsort(tids) + return tuple(tensors[idx] for idx in indices) + + if self.isfw(): + inputs, outputs = order(inputs), order(outputs) + segment = IRSegment(nodes, inputs, outputs, self.name) segment._id = self.cid if mirror and self.mirror is not None: From 8ebdfbc9d55abe8d8c8c97a7c1a42cb698a00c5c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 3 Dec 2022 16:10:21 +0800 Subject: [PATCH 1175/1892] add nf1b schedule --- cube/graph/schedule/schednf1b.py | 77 ++++++++++++++ cube/runtime/schedule/__init__.py | 3 +- cube/runtime/schedule/schednf1b.py | 161 +++++++++++++++++++++++++++++ 3 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 cube/graph/schedule/schednf1b.py create mode 100644 cube/runtime/schedule/schednf1b.py diff --git a/cube/graph/schedule/schednf1b.py b/cube/graph/schedule/schednf1b.py new file mode 100644 index 00000000..a900c406 --- /dev/null +++ b/cube/graph/schedule/schednf1b.py @@ -0,0 +1,77 @@ + +from typing import Dict, Optional, List +import warnings + +from cube.ir.cten import IRCell +from cube.ir.adapter.adapter import IRAdapter +from cube.ir.adapter.adapter import IRWeightReducer + +from cube.graph.graph import IRGraph, IRSegment +from cube.graph.schedule.sched1f1b import IRSchedule1F1B + + +class IRScheduleNF1B(IRSchedule1F1B): + """ + NF1B Scheduling + + This treats model as a linear graph which can be + grouped into continous stages. + + [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] + [Recv-Backward] Backward-Segment [Send-Backward] + """ + + def __init__(self, graph, nmicros: int, recycle: int): + super().__init__(graph, nmicros) + self.signature = 'cube.runtime.schedule.ScheduleNF1B.run' + # forward body + self.fsegments: Dict[int, IRSegment] = dict() + # forward send + self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() + # forward recv + self.rfadapter: Dict[int, Optional[IRAdapter]] = dict() + # backard send + self.sbadapter: Dict[int, Optional[IRAdapter]] = dict() + # backward recv + self.rbadapter: Dict[int, Optional[IRAdapter]] = dict() + # num_stage + self.num_stages: int = -1 + # stage id + self.stage_id: Dict[int, int] = dict() + # reducers + self.dev_reducers: Dict[int, List[IRWeightReducer]] = dict() + # recycle + self.recycle = recycle + + def kwargs(self, devid: int) -> Dict[str, IRCell]: + """ + return kwargs for runtime caller + """ + return dict( + segment = self.fsegments[devid], + sfadapter = self.sfadapter[devid], + rfadapter = self.rfadapter[devid], + sbadapter = self.sbadapter[devid], + rbadapter = self.rbadapter[devid], + dataloader = 'dataloader', + stage_id = self.stage_id[devid], + num_stages = self.num_stages, + num_microbatch = self.nmicros, + recycle = self.recycle, + reducers = self.dev_reducers[devid], + ) + + def __repr__(self) -> str: + dscp = '' + for mesh in self.devmesh: + devid = mesh[0] + # segment = self.segments[devid].to_str(skip_attr=True) if self.segment[mesh[0]] else None + dscp += (f"NF1B Schedule: Stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" + f" segment = {self.segments[devid]}\n" + f" send-fw = {self.sfadapter[mesh[0]]}\n" + f" recv-fw = {self.rfadapter[mesh[0]]}\n" + f" send-bw = {self.sbadapter[mesh[0]]}\n" + f" recv-bw = {self.rbadapter[mesh[0]]}\n" + f" recycle = {self.recycle}\n" + f")\n") + return dscp diff --git a/cube/runtime/schedule/__init__.py b/cube/runtime/schedule/__init__.py index b962ccab..f5b70816 100644 --- a/cube/runtime/schedule/__init__.py +++ b/cube/runtime/schedule/__init__.py @@ -1,2 +1,3 @@ from cube.runtime.schedule.sched1f1b import Schedule1F1B -from cube.runtime.schedule.schedmix import ScheduleMix \ No newline at end of file +from cube.runtime.schedule.schedmix import ScheduleMix +from cube.runtime.schedule.schednf1b import ScheduleNF1B \ No newline at end of file diff --git a/cube/runtime/schedule/schednf1b.py b/cube/runtime/schedule/schednf1b.py new file mode 100644 index 00000000..9d963839 --- /dev/null +++ b/cube/runtime/schedule/schednf1b.py @@ -0,0 +1,161 @@ +""" +Schedule Plan tailored for AlphaFold + +The scheduling follows forward-backward pattern. +In steady phase, each forward will perform a single forward at `recycle+1` +micro-batches, with one keeping activation while others no activation. + +""" +from typing import Callable, Iterable, List, Tuple +from functools import partial +import torch + +from cube.runtime.schedule.strategy import ScheduleABC + + +def first_stage_rfadapter(shapes: Tuple[List[int]], dtypes: Tuple[List[torch.dtype]], dataloader): + return next(dataloader) + +def last_stage_sfadapter(msa_repr: torch.Tensor, pair_repr: torch.Tensor): + pass + + +class ScheduleNF1B(ScheduleABC): + + @staticmethod + def run(segment: Callable, # forward body + rfadapter: Callable, # recv_forward adapter + sfadapter: Callable, # send_forward adapter + rbadapter: Callable, # recv_backward adapter + sbadapter: Callable, # send_backward adapter + dataloader: Iterable, + stage_id: int, + num_stages: int, + num_microbatch: int, + recycle: int, + reducers: List[Callable]): + + assert num_microbatch >= num_stages + + # special case: num_stages == 1: use gradient accum + if num_stages == 1: + for _ in range(num_microbatch): + inputs = ScheduleNF1B.dataloader_step(dataloader) + for _ in range(recycle): + # FIXME: a simulation as output will be loss + with torch.no_grad(): + _ = ScheduleNF1B.forward_step(segment, *inputs) + outputs = ScheduleNF1B.forward_step(segment, *inputs) + input_grads = ScheduleNF1B.backward_step(inputs, outputs, (None,)) + for reducer in reducers: + reducer() + return + + # =============================== recycle ==================================== + if stage_id == 0: + assert rfadapter is None + shapes, dtypes = [], [] + for data in ScheduleNF1B.dataloader_step(dataloader): + shapes.append(list(data.size())) + dtypes.append(data.dtype) + rfadapter = partial(first_stage_rfadapter, shapes=shapes, dtypes=dtypes, dataloader=dataloader) + # if stage_id == num_stages - 1: + # assert sfadapter is None + # sfadapter = last_stage_sfadapter + + for rid in range(recycle): + for mid in range(num_microbatch): + # recv forward + if stage_id == 0 and rid == 0: + inputs = ScheduleNF1B.dataloader_step(dataloader) + else: + inputs = ScheduleNF1B.adapter_step(rfadapter, require_grad=(rid == recycle-1)) + # forward + with torch.no_grad(): + outputs = ScheduleNF1B.forward_step(segment, *inputs) + # FIXME: a simulation + if stage_id == num_stages - 1: + outputs = ScheduleNF1B.dataloader_step(dataloader) + # send forward + ScheduleNF1B.adapter_step(sfadapter, False, *outputs) + # recv forward batches TODO: optmize with async + datas = [] + if stage_id == 0: + for mid in range(num_microbatch): + inputs = ScheduleNF1B.adapter_step(rfadapter, require_grad=False) + datas.append(inputs) + # ========================================================================== + + # 1F1B schedule + if stage_id == 0: rfadapter = None + if stage_id == num_stages - 1: sfadapter = None + num_warmup_microbatches = num_stages - 1 - stage_id + num_warmup_remaining = num_microbatch - num_warmup_microbatches + + # warmup + for _ in range(num_warmup_microbatches): + # recv forward + # print(f'rank[{torch.distributed.get_rank()}]: line26: recving forward') + inputs = ScheduleNF1B.adapter_step(rfadapter, True) + inputs = datas.pop(0) if inputs == (None,) else inputs + # forward + ScheduleNF1B.push_tail('inputs', inputs) + outputs = ScheduleNF1B.forward_step(segment, *inputs) + ScheduleNF1B.push_tail('outputs', outputs) + # send forward + # print(f'rank[{torch.distributed.get_rank()}]: line40 send forward') + ScheduleNF1B.adapter_step(sfadapter, True, *outputs) + + if num_warmup_remaining > 0: + # print(f'rank[{torch.distributed.get_rank()}]: line44 recv forward') + inputs = ScheduleNF1B.adapter_step(rfadapter, True) + inputs = datas.pop(0) if inputs == (None,) else inputs + + # steady + for i in range(num_warmup_remaining): + # forward + ScheduleNF1B.push_tail('inputs', inputs) + # print(f'rank[{torch.distributed.get_rank()}]: line 57 forward') + outputs = ScheduleNF1B.forward_step(segment, *inputs) + ScheduleNF1B.push_tail('outputs', outputs) + + # send forward recv backward + # print(f'rank[{torch.distributed.get_rank()}]: line62 send forward recv backward') + grads = ScheduleNF1B.exchange(sfadapter, rbadapter, stage_id, (True, False), *outputs) + grads = (None,) if len(grads) == 0 else grads + + # backward + inputs, outputs = ScheduleNF1B.pop_head('inputs'), ScheduleNF1B.pop_head('outputs') + # print(f'rank[{torch.distributed.get_rank()}]: line71 backward') + input_grads = ScheduleNF1B.backward_step(inputs, outputs, grads) + + # send backward recv forward + if i != num_warmup_remaining - 1: + # print(f'rank[{torch.distributed.get_rank()}]: line77 send backward recv forward') + inputs = ScheduleNF1B.exchange(sbadapter, rfadapter, stage_id, (False, True), *input_grads) + inputs = datas.pop(0) if inputs == (None,) else inputs + else: + # send backward + # print(f'rank[{torch.distributed.get_rank()}]: line82 send backward') + ScheduleNF1B.adapter_step(sbadapter, False, *input_grads) + + # cooldown + for i in range(num_warmup_microbatches): + inputs, outputs = ScheduleNF1B.pop_head('inputs'), ScheduleNF1B.pop_head('outputs') + # recv backward + # print(f'rank[{torch.distributed.get_rank()}]: line89 recv backward') + grads = ScheduleNF1B.adapter_step(rbadapter, False) + grads = (None,) if len(grads) == 0 else grads + # backward + # print(f'rank[{torch.distributed.get_rank()}]: line96 backward') + input_grads = ScheduleNF1B.backward_step(inputs, outputs, grads) + # send backward + # print(f'rank[{torch.distributed.get_rank()}]: line99 send backward') + ScheduleNF1B.adapter_step(sbadapter, False, *input_grads) + + # allreduce gradient + for reducer in reducers: + reducer() + + assert len(datas) == 0 + ScheduleNF1B.assert_empty() From ec24c4a0f60e21b966a6830fe7571d81e3ceb15a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 3 Dec 2022 16:14:32 +0800 Subject: [PATCH 1176/1892] support with nf1b pipeline parallelism --- examples/openfold/blocks/evoformer.py | 22 +++- examples/openfold/model.py | 18 ++- examples/openfold/policy/mpmd.py | 172 ++++++++++++++++++++++++-- examples/openfold/train.py | 20 +-- 4 files changed, 211 insertions(+), 21 deletions(-) diff --git a/examples/openfold/blocks/evoformer.py b/examples/openfold/blocks/evoformer.py index 0883543e..663a6d55 100644 --- a/examples/openfold/blocks/evoformer.py +++ b/examples/openfold/blocks/evoformer.py @@ -1,3 +1,4 @@ +from typing import Tuple import torch from examples.openfold.blocks.attention import MSARowAttention, MSAColAttention, Transition, TriangleAttentionNodeStart, TriangleAttentionNodeEnd from examples.openfold.blocks.tmu import TriangleMultiplicativeUpdate @@ -8,6 +9,22 @@ import cube +# @cube.graph.parser.register('N S^ R^ cm^, N R^ R^ cz^ -> N out^') +# @torch.jit.ignore +# def input_packing(msa: torch.Tensor, pair: torch.Tensor, out: int) -> torch.Tensor: +# buffer = torch.cat((torch.flatten(msa, start_dim=1), torch.flatten(pair, start_dim=1))) +# return buffer +# +# +# @cube.graph.parser.register('N out^ -> N S^ R^ cm^, N R^ R^ cz^', name='input_unflatten') +# @torch.jit.ignore +# def input_unpacking(buffer: torch.Tensor, +# S: int, R: int, cm: int, cz: int) -> Tuple[torch.Tensor, torch.Tensor]: +# msa_nele = S * R * cm +# msa = buffer[:,:msa_nele].reshape(buffer.size(0), S, R, cm) +# pair = buffer[:,msa_nele:].reshape(buffer.size(0), R, R, cz) +# return msa, pair + class Evoformer(torch.nn.Module): """ @@ -16,13 +33,14 @@ class Evoformer(torch.nn.Module): The mask and dropout is ommited for simplicity. """ - def __init__(self, s: int, cm: int, cz: int, + def __init__(self, s: int, r: int, cm: int, cz: int, use_chunk=False, is_train=True, c=32, msa_head=8, pair_head=4, c_tri_mult=128, ff_mult=4): super().__init__() - self.s, self.cm, self.cz, self.c = s, cm, cz, c + self.s, self.r, self.cm, self.cz, self.c = s, r, cm, cz, c + self.fout = self.s * self.r * self.cm + self.r * self.r * self.cz self.msa_head, self.pair_head = msa_head, pair_head self.c_tri_mult, self.ff_mult = c_tri_mult, ff_mult self.scale = 1.0 / math.sqrt(c) diff --git a/examples/openfold/model.py b/examples/openfold/model.py index 50f77936..f5c72822 100644 --- a/examples/openfold/model.py +++ b/examples/openfold/model.py @@ -1,12 +1,12 @@ """ Alphafold 2, using implementation similar with OpenFold. """ - import torch import torch.nn as nn # from examples.openfold.blocks.embedder import InputEmbedder, RecyclingEmbedder, TemplateAngleEmbedder from examples.openfold.blocks.evoformer import Evoformer +# from examples.openfold.blocks.evoformer import input_packing, input_unpacking from dataclasses import dataclass @@ -82,12 +82,12 @@ class Config: bs: int = 1 - class AlphaFold(nn.Module): def __init__(self, cfg: Config = Config()) -> None: super().__init__() + self.cfg = cfg # self.input_embedder = InputEmbedder( # cfg.input_embedder_tf_dim, cfg.input_embedder_msa_dim, @@ -120,12 +120,14 @@ def __init__(self, cfg: Config = Config()) -> None: self.extra_msa_stack = None # ExtraMSAStack() # evoformer - # self.evoformer = EvoformerStack() + self.s, self.r, self.cm, self.cz = cfg.evoformer_s, cfg.evoformer_r, cfg.evoformer_cm, cfg.evoformer_cz + self.fout = self.s * self.r * self.cm + self.r * self.r * self.cz + self.msa_norm = torch.nn.LayerNorm(cfg.evoformer_cm) self.pair_norm = torch.nn.LayerNorm(cfg.evoformer_cz) self.evoformers = nn.ModuleList( [Evoformer( - cfg.evoformer_s, cfg.evoformer_cm, cfg.evoformer_cz, + self.s, self.r, self.cm, self.cz, cfg.evoformer_use_chunk ) for _ in range(cfg.evoformer_nlayers)] ) @@ -134,10 +136,18 @@ def __init__(self, cfg: Config = Config()) -> None: self.aux_heads = None # AuxiliaryHeads() def forward(self, msa, pair): + """ + msa: [N S R cm] + pair: [N R R cz] + """ msa = self.msa_norm(msa) pair = self.pair_norm(pair) + # cube.runtime.function.anchor('PackingRegion') + # x = input_packing(msa, pair, self.fout) for evoformer in self.evoformers: cube.runtime.function.anchor('Evoformer Start') msa, pair = evoformer(msa, pair) + # x = evoformer(x) + # msa, pair = input_unpacking(x, self.s, self.r, self.cm, self.cz) loss = torch.sum(msa) * torch.sum(pair) return loss diff --git a/examples/openfold/policy/mpmd.py b/examples/openfold/policy/mpmd.py index c309e43a..55ab9d6c 100644 --- a/examples/openfold/policy/mpmd.py +++ b/examples/openfold/policy/mpmd.py @@ -3,7 +3,10 @@ from cube.graph import IRGraph from cube.ir.cten import IRCell from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.segment import IRSegment from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.graph.schedule.schednf1b import IRScheduleNF1B +from cube.graph.schedule.sched1f1b import IRSchedule1F1B import more_itertools import numpy as np @@ -84,12 +87,19 @@ def PASDP(graph: IRGraph, resource): return graph -def PASDAP(graph: IRGraph, resource, tp: int, dp: int): - assert tp * dp == resource.ngpus +def PASDAP(graph: IRGraph, resource, tp: int): + + assert resource.ngpus % tp == 0 + dp = resource.ngpus // tp devmesh = np.arange(resource.ngpus).reshape(dp, tp) tp_devs = list(range(tp)) + # grouping + evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) + for layer in evoformers: + graph.recompute(layer) + dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] bs = dataloader.output(0).shape[dataloader.get_batch_dims()[0]] print(f'> get batch size: {bs}') @@ -98,10 +108,6 @@ def PASDAP(graph: IRGraph, resource, tp: int, dp: int): dp_devs = devmesh[:,tp_idx] _tp(graph, dl, dp_devs, 'data') - # grouping - evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) - for layer in evoformers: - graph.recompute(layer) fnodes = graph.select(ntype=IRFwOperation) fnodes = [fnode for fnode in fnodes if fnode.name != 'Evoformer Start'] @@ -115,6 +121,9 @@ def PASDAP(graph: IRGraph, resource, tp: int, dp: int): if len(names) == 1 or 'mul' in names: # for first layer norm operators for node in nodes: subnodes.append(_replica(graph, node, tp_devs)) + # elif 'input_packing' in names: + # for node in nodes: + # subnodes.append(_replica(graph, node, tp_devs)) elif 'row_attn' in names: for node in nodes: subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) @@ -152,4 +161,153 @@ def PASDAP(graph: IRGraph, resource, tp: int, dp: int): print(f'replicate op on data parallel group: {node.name}') _replica(graph, subnode, dp_devs) - return graph \ No newline at end of file + return graph + + +def PASRoundRobin(graph: IRGraph, resource): + + pp_size = resource.ngpus + + # grouping + evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) + for layer in evoformers: + graph.recompute(layer) + + + fstages = [[] for _ in range(pp_size)] + nlayer_per_stage = len(evoformers) // pp_size + for lid, fnodes in enumerate(evoformers): + sid = min(lid // nlayer_per_stage, pp_size - 1) + fstages[sid] += fnodes + graph.staging(tuple(stages[0] for stages in fstages)) + + dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] + graph.assign(dataloader, 0) + + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + for sid, fstage in enumerate(fstages): + graph.assign(fstage, sid) + + return graph + + +def PASNF1B(graph: IRGraph, resource, mbs: int, gbs: int, recycle: int): + + assert gbs % mbs == 0 + nmbs = gbs // mbs + pp_size = resource.ngpus + + # grouping + evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) + assert len(evoformers) % pp_size == 0 + for layer in evoformers: + graph.recompute(layer) + + fstages = [[] for _ in range(pp_size)] + nlayer_per_stage = len(evoformers) // pp_size + for lid, fnodes in enumerate(evoformers): + sid = min(lid // nlayer_per_stage, pp_size - 1) + fstages[sid] += fnodes + graph.staging(tuple(stages[0] for stages in fstages)) + + dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] + graph.assign(dataloader, 0) + + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + for sid, fstage in enumerate(fstages): + graph.assign(fstage, sid) + + strategy = IRSchedule1F1B(graph, nmbs) + graph.predef_sched(strategy) + + return graph + + +def PASDAPPipe(graph: IRGraph, resource, mbs: int, gbs: int, tp: int, pp: int, recycle: int): + + assert gbs % mbs == 0 + assert resource.ngpus % (pp * tp) == 0 + dp = resource.ngpus // (pp * tp) + nmbs = gbs // mbs + + devmesh = np.arange(resource.ngpus, dtype=int).reshape(dp, pp, tp) + tp_devs = [0] * tp # dummy device, which will be reset at dp + + + # grouping + evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) + assert len(evoformers) % pp == 0 + for layer in evoformers: + graph.recompute(layer) + + fstages = [[] for _ in range(pp)] + nlayer_per_stage = len(evoformers) // pp + for lid, fnodes in enumerate(evoformers): + sid = min(lid // nlayer_per_stage, pp - 1) + fstages[sid] += fnodes + graph.staging(tuple(stages[0] for stages in fstages)) + + # setup dataloader + dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[dataloader.get_batch_dims()[0]] + print(f'> get batch size: {bs}') + dls: List[IRDataOperation] = _replica(graph, dataloader, tp_devs) + for tp_idx, dl in enumerate(dls): + dp_devs = devmesh[:, 0, tp_idx] + _tp(graph, dl, dp_devs, 'data') + + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + assert len(fstages) > 0 + for sid, fstage in enumerate(fstages): + fnodes = fstage.select(ntype=IRFwOperation) + fnodes = [fnode for fnode in fnodes if fnode.name != 'Evoformer Start'] + node_groups = more_itertools.split_at(fnodes, lambda n: isinstance(n, IRGraphAnchor)) + for nodes in node_groups: + # tensor parallelism + names = set(n.name for n in nodes) + subnodes = [] + if len(names) == 1 or 'mul' in names: # for first layer norm operators + for node in nodes: + subnodes.append(_replica(graph, node, tp_devs)) + elif 'row_attn' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) + elif 'col_attn' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) + elif 'opm' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) + elif 'tmo' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) + elif 'tmi' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) + elif 'tri_attn_start' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) + elif 'tri_attn_end' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) + elif 'feedforward' in names: + for node in nodes: + subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) + else: + assert False, names + # data parallelism + for ns in subnodes: + for tp_idx, subnode in enumerate(ns): + dp_devs = devmesh[:, sid, tp_idx] + if bs in subnode.input(0).shape: + dim = subnode.input(0).shape.index(bs) + _tp(graph, subnode, dp_devs, idx=0, dim=dim) + else: + print(f'replicate op on data parallel group: {node.name}') + _replica(graph, subnode, dp_devs) + + strategy = IRScheduleNF1B(graph, nmbs, recycle) + # strategy = IRSchedule1F1B(graph, nmbs) + graph.predef_sched(strategy) + + return graph diff --git a/examples/openfold/train.py b/examples/openfold/train.py index d7321d01..77f9bc5b 100644 --- a/examples/openfold/train.py +++ b/examples/openfold/train.py @@ -1,7 +1,7 @@ """ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ - examples/openfold/train.py --fp16 --tp 4 --dp 1 + examples/openfold/train.py --fp16 --tp 2 --pp 2 --gbs 4 --recycle 2 """ @@ -11,7 +11,7 @@ import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary -from examples.openfold.policy.mpmd import PASDAP +from examples.openfold.policy.mpmd import PASDAP, PASRoundRobin, PASNF1B, PASDAPPipe import argparse from functools import partial @@ -36,17 +36,22 @@ help='global batch size') parser.add_argument('--tp', type=int, default=1, help='tensor parallelism size') -parser.add_argument('--dp', type=int, default=1, +parser.add_argument('--pp', type=int, default=1, + help='data parallelism size') +parser.add_argument('--recycle', type=int, default=2, help='data parallelism size') args = parser.parse_args() +dp = cube.runtime.device.DeviceGroup().world_size // (args.tp * args.pp) assert args.gbs % args.mbs == 0 -assert args.mbs % args.dp == 0 +assert args.mbs % dp == 0 assert args.msa_hidden % args.head_dim == 0 assert args.pair_hidden % args.head_dim == 0 -PASDAP = partial(PASDAP, tp=args.tp, dp=args.dp) +# PASDAP = partial(PASDAP, tp=args.tp) +PASNF1B = partial(PASNF1B, mbs=args.mbs, gbs=args.gbs, recycle=1) +PASDAPPipe = partial(PASDAPPipe, mbs=args.mbs, gbs=args.gbs, tp=args.tp, pp=args.pp, recycle=args.recycle) def nparams(model) -> int: @@ -78,7 +83,7 @@ def train(): print_each_rank(f'before partitioned model parameter: {nparams(model)}') model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=PASDAP, override=True, load_content=True) + @cube.compile(model, dataloader, PAS=PASDAPPipe, override=True, load_content=True) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) @@ -98,8 +103,7 @@ def train_iter(model, dataloader): for step in range(iter_num): if step == warmup: CudaTimer(enable=True, predefined=True).start('e2e') - for _ in range(args.gbs // args.mbs): - train_iter(model, dataloader) + train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() From a930f1226b6b3fbfff813825900e936848abc59b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 3 Dec 2022 16:52:46 +0800 Subject: [PATCH 1177/1892] clear output --- cube/graph/segment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 578cb304..5ed0c8df 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1007,7 +1007,6 @@ def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: """Reorder by logical tensor id. Temporally necessary for pipeline scheduling""" tensors = list(tensors) - print(tensors) tids = np.array([t.parent.tid for t in tensors]) indices = np.argsort(tids) return tuple(tensors[idx] for idx in indices) From d769ae25f139f33c69b657bea82b9373ec20cb37 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 4 Dec 2022 20:44:11 +0800 Subject: [PATCH 1178/1892] new schedule nf1b --- cube/runtime/schedule/schednf1b.py | 137 ++++++++++++++++++++++++++++- 1 file changed, 134 insertions(+), 3 deletions(-) diff --git a/cube/runtime/schedule/schednf1b.py b/cube/runtime/schedule/schednf1b.py index 9d963839..23cb8d75 100644 --- a/cube/runtime/schedule/schednf1b.py +++ b/cube/runtime/schedule/schednf1b.py @@ -14,16 +14,18 @@ def first_stage_rfadapter(shapes: Tuple[List[int]], dtypes: Tuple[List[torch.dtype]], dataloader): - return next(dataloader) + outputs = next(dataloader) + outputs = tuple(t.clone() for t in outputs) + return outputs -def last_stage_sfadapter(msa_repr: torch.Tensor, pair_repr: torch.Tensor): +def last_stage_sfadapter(*msa_repr_and_pair_repr: torch.Tensor): pass class ScheduleNF1B(ScheduleABC): @staticmethod - def run(segment: Callable, # forward body + def _deprecate_run(segment: Callable, # forward body rfadapter: Callable, # recv_forward adapter sfadapter: Callable, # send_forward adapter rbadapter: Callable, # recv_backward adapter @@ -49,6 +51,7 @@ def run(segment: Callable, # forward body input_grads = ScheduleNF1B.backward_step(inputs, outputs, (None,)) for reducer in reducers: reducer() + print(f'> rank [{torch.distributed.get_rank()}]: {ScheduleNF1B._fw_cnt}') return # =============================== recycle ==================================== @@ -83,6 +86,7 @@ def run(segment: Callable, # forward body if stage_id == 0: for mid in range(num_microbatch): inputs = ScheduleNF1B.adapter_step(rfadapter, require_grad=False) + inputs = (t.cpu() for t in inputs) datas.append(inputs) # ========================================================================== @@ -98,6 +102,7 @@ def run(segment: Callable, # forward body # print(f'rank[{torch.distributed.get_rank()}]: line26: recving forward') inputs = ScheduleNF1B.adapter_step(rfadapter, True) inputs = datas.pop(0) if inputs == (None,) else inputs + inputs = tuple(t.cuda() for t in inputs) # forward ScheduleNF1B.push_tail('inputs', inputs) outputs = ScheduleNF1B.forward_step(segment, *inputs) @@ -134,6 +139,7 @@ def run(segment: Callable, # forward body # print(f'rank[{torch.distributed.get_rank()}]: line77 send backward recv forward') inputs = ScheduleNF1B.exchange(sbadapter, rfadapter, stage_id, (False, True), *input_grads) inputs = datas.pop(0) if inputs == (None,) else inputs + inputs = tuple(t.cuda() for t in inputs) else: # send backward # print(f'rank[{torch.distributed.get_rank()}]: line82 send backward') @@ -159,3 +165,128 @@ def run(segment: Callable, # forward body assert len(datas) == 0 ScheduleNF1B.assert_empty() + + + @staticmethod + def run(segment: Callable, # forward body + rfadapter: Callable, # recv_forward adapter + sfadapter: Callable, # send_forward adapter + rbadapter: Callable, # recv_backward adapter + sbadapter: Callable, # send_backward adapter + dataloader: Iterable, + stage_id: int, + num_stages: int, + num_microbatch: int, + recycle: int, + reducers: List[Callable]): + + assert num_microbatch >= num_stages + + # special case: num_stages == 1: use gradient accum + if num_stages == 1: + for _ in range(num_microbatch): + inputs = ScheduleNF1B.dataloader_step(dataloader) + for _ in range(recycle): + # FIXME: a simulation as output will be loss + with torch.no_grad(): + _ = ScheduleNF1B.forward_step(segment, *inputs) + outputs = ScheduleNF1B.forward_step(segment, *inputs) + input_grads = ScheduleNF1B.backward_step(inputs, outputs, (None,)) + for reducer in reducers: + reducer() + # print(f'> rank [{torch.distributed.get_rank()}]: {ScheduleNF1B._fw_cnt}') + return + + # setup dummpy adapter + if stage_id == 0: + assert rfadapter is None + shapes, dtypes = [], [] + for data in ScheduleNF1B.dataloader_step(dataloader): + shapes.append(list(data.size())) + dtypes.append(data.dtype) + rfadapter = partial(first_stage_rfadapter, shapes=shapes, dtypes=dtypes, dataloader=dataloader) + if stage_id == num_stages - 1: + assert sfadapter is None + sfadapter = last_stage_sfadapter + + # =============================== warmup ======================== + for rid in range(recycle): + # forward rid micro-batches + for t in range(rid+1): + inputs = ScheduleNF1B.adapter_step(rfadapter, False) + inputs = ScheduleNF1B.dataloader_step(dataloader) if inputs == (None,) else inputs + with torch.no_grad(): + outputs = ScheduleNF1B.forward_step(segment, *inputs) + ScheduleNF1B.adapter_step(sfadapter, False, *outputs) + + # print(f'> rank [{torch.distributed.get_rank()}]: OK here') + + # recv inputs + inputs = ScheduleNF1B.adapter_step(rfadapter, stage_id != 0) + + # steady pattern + for fmid in range(num_microbatch + num_stages - 1 - stage_id): + + # ======================= forward region ==================== + if fmid + 1 < num_microbatch: + with torch.no_grad(): + outputs = ScheduleNF1B.forward_step(segment, *inputs) + + # ================== send forward recv backward ================== + send_fw = fmid + 1 < num_microbatch + bmid = fmid - (num_stages - 1 - stage_id) + recv_bw = 0 <= bmid and bmid < num_microbatch + if send_fw and recv_bw: + grads = ScheduleNF1B.exchange(sfadapter, rbadapter, stage_id, (False, False), *outputs) + elif send_fw: + ScheduleNF1B.adapter_step(sfadapter, False, *outputs) + elif recv_bw: + grads = ScheduleNF1B.adapter_step(rbadapter, False) + else: + assert False, f"> rank [{torch.distributed.get_rank()}]: Fail at fmid: {fmid}" + + # ===================== backward region ================== + + # recycle inference + for idx in range(recycle - 1): + if fmid + 2 + idx < num_microbatch: + # recv forward + inputs = ScheduleNF1B.adapter_step(rfadapter, False) + # forward + with torch.no_grad(): + outputs = ScheduleNF1B.forward_step(segment, *inputs) + # send forward + ScheduleNF1B.adapter_step(sfadapter, False, *outputs) + + # train forward + if fmid < num_microbatch: + # recv forward + inputs = ScheduleNF1B.adapter_step(rfadapter, stage_id != 0) + # forward + ScheduleNF1B.push_tail('inputs', inputs) + outputs = ScheduleNF1B.forward_step(segment, *inputs) + ScheduleNF1B.push_tail('outputs', outputs) + # send forward + ScheduleNF1B.adapter_step(sfadapter, True, *outputs) + + # train backward + bmid = fmid - (num_stages - 1 - stage_id) + if 0 <= bmid and bmid < num_microbatch: + inputs, outputs = ScheduleNF1B.pop_head('inputs'), ScheduleNF1B.pop_head('outputs') + input_grads = ScheduleNF1B.backward_step(inputs, outputs, grads) + + # =============== send backward recv forward ===================== + send_bw = 0 <= bmid and bmid < num_microbatch + recv_fw = fmid + 2 < num_microbatch + if send_bw and recv_fw: + ScheduleNF1B.exchange(sbadapter, rfadapter, stage_id, (False, False), *input_grads) + elif send_bw: + ScheduleNF1B.adapter_step(sbadapter, False, *input_grads) + elif recv_fw: + inputs = ScheduleNF1B.adapter_step(rfadapter, False) + + for reducer in reducers: + reducer() + + ScheduleNF1B.assert_empty() + # print(f'> rank [{torch.distributed.get_rank()}]: {ScheduleNF1B._fw_cnt}') From c287ff7ba33c31ed6ae8e22f06231c653c76e5b5 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Tue, 6 Dec 2022 04:05:55 +0000 Subject: [PATCH 1179/1892] simple change --- examples/nlp/blocks/attention.py | 2 +- examples/nlp/gpt/model.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index f704dac4..0681e1c0 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -182,7 +182,7 @@ def forward(self, query): attn = self_attention( query, self.qkv_proj, self.qkv_bias, self.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=True + self.num_heads, self.scaling, self.dropout_p, mask=False ) attn = attn + self.out_bias return attn diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 396e7c55..9334e0f3 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -23,6 +23,8 @@ class Config: def build_gpt_config(name: str) -> Config: if name == '350M': embed_dim, layers, attention_heads = 1024, 24, 16 + elif name == 'test': + embed_dim, layers, attention_heads = 1024, 4, 16 elif name == '760M': embed_dim, layers, attention_heads = 1536, 24, 16 elif name == '1.3B': From b2a79f88c2736b995601f631e3e5589bda9d3b31 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 6 Dec 2022 13:46:56 +0800 Subject: [PATCH 1180/1892] allow initialize using other tools --- cube/runtime/device.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cube/runtime/device.py b/cube/runtime/device.py index 6d67bcaf..de8d7802 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -24,9 +24,8 @@ def __init__(self): self.groups = dict() torch.cuda.set_device(0) else: - torch.distributed.init_process_group( - backend='nccl', - ) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend='nccl') self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() # assume each node has the same device number From fff115d1c5ae4ab534798073590095b8046bdcbc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 6 Dec 2022 16:57:39 +0800 Subject: [PATCH 1181/1892] add tflops calculation --- examples/openfold/blocks/evoformer.py | 43 +++++++++++++++++++++++++++ examples/openfold/model.py | 10 +++++++ examples/openfold/train.py | 5 ++-- 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/examples/openfold/blocks/evoformer.py b/examples/openfold/blocks/evoformer.py index 663a6d55..0128fa1f 100644 --- a/examples/openfold/blocks/evoformer.py +++ b/examples/openfold/blocks/evoformer.py @@ -132,3 +132,46 @@ def forward(self, msa_repr, pair_repr): pair_repr = residual + pair_repr return succ_msa_repr, pair_repr + + def tflops(self, n_seq: int, n_res: int) -> float: + """ + Single sample tflops + """ + msa_size = n_seq * n_res * self.cm + pair_size = n_seq * n_res * self.cz + flops = 0 + + # msa layer norm + flops += 4 * (msa_size * 4) + # pair layer norm + flops += 2 * (pair_size * 4) + + # attention: gate + qkv + q@k (N S head r c, N S head c r) + k@v + dense + msa_attn = n_seq * n_res * self.cm * self.cm + \ + 3 * n_seq * n_res * self.cm * self.cm + \ + n_seq * (self.cm // self.c) * n_res * n_res * self.c + \ + n_seq * (self.cm // self.c) * n_res * n_res * self.c + \ + n_seq * n_res * self.cm * self.cm + + pair_attn = n_res * n_res * self.cz * self.cz + \ + 3 * n_res * n_res * self.cz * self.cz + \ + n_res * (self.cz // self.c) * n_res * n_res * self.c + \ + n_res * (self.cz // self.c) * n_res * n_res * self.c + \ + n_res * n_res * self.cz * self.cz + + # row and col end attention + flops += 2 * msa_attn + # tirangle start and triangle end + flops += 2 * pair_attn + # msa and pair transition flops + flops += 8 * n_seq * n_res * (self.cm ** 2) + \ + 8 * n_res * n_res * (self.cz ** 2) + # pair_repr tmi and tmo: projection + gate + 2 matmul + flops += 2 * (n_res * n_res * self.cz * self.c_tri_mult) + \ + n_res * n_res * self.cz * self.cz + \ + self.c_tri_mult * n_res * n_res * n_res + n_res * n_res * self.c_tri_mult * self.cz + # opm: left + right + opm + flops += 2 * n_seq * n_res * self.cm * self.cz + \ + n_res * n_res * n_seq * self.c * self.c + \ + n_res * n_res * self.c * self.c * self.cz + return flops / 1e12 diff --git a/examples/openfold/model.py b/examples/openfold/model.py index f5c72822..7dd90e92 100644 --- a/examples/openfold/model.py +++ b/examples/openfold/model.py @@ -151,3 +151,13 @@ def forward(self, msa, pair): # msa, pair = input_unpacking(x, self.s, self.r, self.cm, self.cz) loss = torch.sum(msa) * torch.sum(pair) return loss + + + def tflops(self) -> float: + """ + TFLOPs for one sample + """ + tflops = 0. + for layer in self.evoformers: + tflops += layer.tflops(self.s, self.r) + return tflops diff --git a/examples/openfold/train.py b/examples/openfold/train.py index 77f9bc5b..dec3ebbf 100644 --- a/examples/openfold/train.py +++ b/examples/openfold/train.py @@ -1,7 +1,7 @@ """ OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - examples/openfold/train.py --fp16 --tp 2 --pp 2 --gbs 4 --recycle 2 + --nproc_per_node=1 \ + examples/openfold/train.py --fp16 --layers 24 --gbs 1 --recycle 2 """ @@ -69,6 +69,7 @@ def train(): print_each_rank(cfg, rank_only=0) model = AlphaFold(cfg) + print_each_rank(f'iteration total TFLOPs: {model.tflops() * (args.recycle + 1 + 2)}') if args.fp16: model = model.half() From 301c9ba7062695e4f9c604834e24ed0c7dacbb01 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Wed, 7 Dec 2022 10:53:34 +0800 Subject: [PATCH 1182/1892] change reducer and gpt3 implementation --- cube/runtime/adapter/reducer.py | 4 +++- examples/nlp/blocks/attention.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index d622e879..ae649642 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -20,10 +20,12 @@ def __init__(self, ranks: List[int]): def add_param(self, param: torch.nn.Parameter): self._params.append(param) - def allreduce(self): + def allreduce(self, run=False): """ Reduce gradients across given group """ + if not run: + return buckets = {} for param in self._params: if param.requires_grad and param.grad is not None: diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index f704dac4..0681e1c0 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -182,7 +182,7 @@ def forward(self, query): attn = self_attention( query, self.qkv_proj, self.qkv_bias, self.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=True + self.num_heads, self.scaling, self.dropout_p, mask=False ) attn = attn + self.out_bias return attn From ca117ec33a9379f7fbc15ab444c38d419e7fe905 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 7 Dec 2022 13:21:10 +0800 Subject: [PATCH 1183/1892] allow model to scale with weights --- examples/openfold/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/openfold/model.py b/examples/openfold/model.py index 7dd90e92..06953b13 100644 --- a/examples/openfold/model.py +++ b/examples/openfold/model.py @@ -121,14 +121,15 @@ def __init__(self, cfg: Config = Config()) -> None: # evoformer self.s, self.r, self.cm, self.cz = cfg.evoformer_s, cfg.evoformer_r, cfg.evoformer_cm, cfg.evoformer_cz - self.fout = self.s * self.r * self.cm + self.r * self.r * self.cz - + self.c = self.cfg.evoformer_c + assert self.cm % self.c == 0 and self.cz % self.c == 0 + self.msa_norm = torch.nn.LayerNorm(cfg.evoformer_cm) self.pair_norm = torch.nn.LayerNorm(cfg.evoformer_cz) self.evoformers = nn.ModuleList( [Evoformer( self.s, self.r, self.cm, self.cz, - cfg.evoformer_use_chunk + c=self.c, msa_head=self.cm // self.c, pair_head=self.cz // self.c, ) for _ in range(cfg.evoformer_nlayers)] ) From 33be0426db969c070717da8cc2ef64c01995303b Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 7 Dec 2022 12:48:34 +0000 Subject: [PATCH 1184/1892] merge with autodist_3 --- cube/profiler/database.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index b9e06d04..748b4b3d 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -216,18 +216,23 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): torch.cuda.set_device(device) input_byte_size, param_byte_size = 0, 0 + Residual_input_byte_size, input_count = 0, 0 for t in node.inputs(): if t.is_param(): param_byte_size = param_byte_size + t.byte_size() else: + input_count += 1 input_byte_size = input_byte_size + t.byte_size() + if input_count == 1: + Residual_input_byte_size += t.byte_size() + # run profiling fw_span, bw_span, infer_memory, train_memory = \ CompProfiler.profile(fn, shapes, dtypes, **kwargs) # log to database key = self._serialize(node) - self.insert(node.signature, key, input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory) + self.insert(node.signature, key, input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size) print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " f"=> in mem {input_byte_size} | param mem: {param_byte_size} | fw: {round(fw_span, 2)} ms | " @@ -235,10 +240,11 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): if isinstance(device, int): torch.cuda.set_device(orig_device) - return input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory + return input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size def insert(self, name: str, key: str, input_byte_size: int, param_byte_size: int, - fw_span: float, bw_span: float, infer_memory: int, train_memory: int): + fw_span: float, bw_span: float, infer_memory: int, train_memory: int, + Residual_input_byte_size: int): """ log the span of a function name with key @@ -254,7 +260,7 @@ def insert(self, name: str, key: str, input_byte_size: int, param_byte_size: int assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = (input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory) + self._data[name][key] = (input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size) def exist(self, node: IRFwOperation) -> bool: """ From 8c74d5535da9543eee63309686b69d4d9abfbefa Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Mon, 12 Dec 2022 13:44:42 +0000 Subject: [PATCH 1185/1892] add GPTFine --- examples/nlp/blocks/attention.py | 112 ++++++++++++++++++++++++++++++- examples/nlp/blocks/encoder.py | 35 +++++++++- examples/nlp/gpt/model.py | 44 +++++++++++- tmp | 7 ++ 4 files changed, 193 insertions(+), 5 deletions(-) create mode 100644 tmp diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 0681e1c0..d1b01b79 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -55,6 +55,82 @@ def self_attention(query: torch.Tensor, output = torch.nn.functional.linear(output, out_proj) # L N (h d), E E -> L N E return output +# @cube.graph.parser.register('L^ N E^, (h d^ 3) E^, (h d^ 3) -> L^ N l^ 3', name='qkv_combined') +@cube.graph.parser.register('L^ N E^, (h d^ 3) E^, (h d^ 3) -> L^ N (h d^) 3', name='qkv_combined') +def qvk_combined(query: torch.Tensor, + qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, + #out_proj: torch.Tensor, + h: int, scale: float, dropout_p: float, mask: bool = True): + num_head = h + L, N = query.size(0), query.size(1) + dim_head = qkv_proj.size(0) // num_head // 3 + + qkv = torch.nn.functional.linear(query, qkv_proj, qkv_bias) # L N E, (h d 3) E -> L N (h d 3) + output = qkv.view(L, N, num_head * dim_head, 3) # L N (h d 3) -> L N (h d) 3 + # q, k, v = qkv.chunk(3, dim=-1) # L N (h d) 3 -> L N (h d), L N (h d), L N (h d) + # q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + # k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + # v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + + + return output + +@cube.graph.parser.register('L^ N (h d^) 3 -> L^ N (h d^)', name='attention_mask') +def attention_mask(qkv: torch.Tensor, + # out_proj: torch.Tensor, + h: int, scale: float, dropout_p: float, mask: bool = True): + + L, N = qkv.size(0), qkv.size(1) + num_head = h + dim_head = qkv.size(2) // num_head + + q, k, v = qkv.chunk(3, dim=-1) # L N (h d) 3 -> L N (h d), L N (h d), L N (h d) + q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d + + # #########split + + + # preallocating input tensor: (N h) L L + matmul_input_buffer = torch.empty([N * h, L, L], dtype=q.dtype, device=q.device) + # L (N h) d, L (N h) d -> (N h) L L + attn = torch.baddbmm( + matmul_input_buffer, + q.transpose(0, 1), # (N h) L d + k.transpose(0, 1).transpose(1, 2), # (N h) d L + beta=0.0, alpha=scale + ) + # ======== replace the semantic into more efficient implementation ============ + + # attention mask + if mask: # (N h) L L -> (N h) L L + attn = attn.view(N, num_head, L, L) + ones = torch.ones((N, L, L), device=attn.device) + amask = torch.tril(ones) + amask = amask.view(N, 1, L, L) + amask = (amask < 0.5) + attn = attn.masked_fill_(amask, -10000.0) + attn = attn.view((N * num_head), L, L) + + attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L + attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L + v = v.transpose(0, 1) # L (N h) d -> (N h) L d + output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d + output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d + output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) + return output + +@cube.graph.parser.register('L^ N (h+ d^), E^ (h+ d^) -> L^ N E^', name='attention_mask') +def lin(lin_input: torch.Tensor, + # qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, + out_proj: torch.Tensor, + h: int, scale: float, dropout_p: float, mask: bool = True): + ###########split + output = torch.nn.functional.linear(lin_input, out_proj) # L N (h d), E E -> L N E + # output = torch.nn.functional.linear(output, out_proj) + return output + @cube.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') def cross_attention(query: torch.Tensor, key: torch.Tensor, @@ -62,7 +138,7 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, v_proj: torch.Tensor, v_bias: torch.Tensor, out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = True): + h: int, scale: float, dropout_p: float, mask: bool = False): num_head = h L, N = query.size(0), query.size(1) dim_head = q_proj.size(0) // num_head @@ -161,6 +237,40 @@ def one_attention(hidden_states: torch.Tensor, output = torch.nn.functional.linear(output, out_proj, None) # l N (h d), E E -> l N E return output +class MultiHeadSelfAttentionLrw(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): + super().__init__() + self.inner_dim = inner_dim + self.num_heads = num_heads + self.head_dim = inner_dim // num_heads + self.scaling = self.head_dim ** -0.5 + self.dropout_p = dropout + # QKV [(h d 3), E] + self.qkv_proj = torch.nn.Parameter(torch.empty(3 * inner_dim, embed_dim)) + self.qkv_bias = torch.nn.Parameter(torch.empty(3 * inner_dim)) + # Out + self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) + self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) + + def forward(self, query): + + qkv = qvk_combined( + query, self.qkv_proj, self.qkv_bias, + #.out_proj, + self.num_heads, self.scaling, self.dropout_p, mask=True + ) + lin_input = attention_mask( + qkv, + self.num_heads, self.scaling, self.dropout_p, mask=True + ) + attn = lin( + lin_input, + self.out_proj, + self.num_heads, self.scaling, self.dropout_p, mask=True + ) + attn = attn + self.out_bias + return attn class MultiHeadSelfAttention(torch.nn.Module): diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 1fa1b697..d6fe946c 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -1,7 +1,40 @@ import torch -from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention +from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention, MultiHeadSelfAttentionLrw from examples.nlp.blocks.mlp import MLP +class EncoderLayerLrw(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, + attn_hidden_dim: int, ffn_hidden_dim: int, + dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): + super().__init__() + self.self_attn = MultiHeadSelfAttentionLrw( + embed_dim, num_heads, attn_hidden_dim, atten_dropout + ) + self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) + self.dropout = torch.nn.Dropout(p=dropout) + self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + residual = x + ### + x = self.self_attn_layer_norm(x) + ### + x = self.self_attn(x) + ##### + x = self.dropout(x) + x = x + residual + + residual = x + ##### + x = self.final_layer_norm(x) + x = self.mlp(x) + ### + x = self.dropout(x) + x = x + residual + return x class EncoderLayer(torch.nn.Module): diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 9334e0f3..b2608c2a 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -1,11 +1,9 @@ - import torch -from examples.nlp.blocks.encoder import EncoderLayer, EncoderInferLayer +from examples.nlp.blocks.encoder import EncoderLayer, EncoderLayerLrw, EncoderInferLayer import cube from dataclasses import dataclass - @dataclass class Config: embed_dim: int = 1024 @@ -43,6 +41,46 @@ def build_gpt_config(name: str) -> Config: assert False, f'unrecognized name: {name}' return Config(embed_dim, layers, attention_heads, embed_dim, 4 * embed_dim) +class GPTFineGrained(torch.nn.Module): + + def __init__(self, cfg=Config()): + super().__init__() + + # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) + self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.embed_dim)) + self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) + self.embed_dropout = torch.nn.Dropout() + + self.layers = torch.nn.ModuleList( + [EncoderLayerLrw( + cfg.embed_dim, cfg.attention_heads, + cfg.attn_hidden_dim, cfg.ffn_hidden_dim, + cfg.dropout, cfg.attn_dropout, cfg.activation_dropout + ) for _ in range(cfg.layers)] + ) + self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): + # embed = self.embed(input_ids) + embed = torch.nn.functional.embedding( + input_ids, self.embedw, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False + ) + pos_embed = self.position(position_ids) + embed = embed + pos_embed + embed = self.embed_dropout(embed) + enc = embed.transpose(0, 1) + + for layer in self.layers: + cube.runtime.function.anchor('transformer start') + enc = layer(enc) + enc = self.final_layernorm(enc) + + # logits = torch.nn.functional.linear(enc, self.embed.weight) + logits = torch.nn.functional.linear(enc, self.embedw) + # simplified + loss = torch.sum(logits) + return loss class GPT(torch.nn.Module): diff --git a/tmp b/tmp new file mode 100644 index 00000000..28e3d078 --- /dev/null +++ b/tmp @@ -0,0 +1,7 @@ +nohup: ignoring input +benchmarking 4 gpus... +benchmarking 4 gpus... +benchmarking 4 gpus... +benchmarking 4 gpus... +/opt/conda/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 8 leaked semaphore objects to clean up at shutdown + warnings.warn('resource_tracker: There appear to be %d ' From 00f2bd785d7b8f9489e103263763b4d67b063c0f Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Mon, 12 Dec 2022 14:20:07 +0000 Subject: [PATCH 1186/1892] refine code --- examples/nlp/blocks/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index d1b01b79..887fc592 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -185,7 +185,7 @@ def one_attention(hidden_states: torch.Tensor, # v_proj: torch.Tensor, v_bias: torch.Tensor, qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, #out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, is_training: bool = True, mask: bool = True): + h: int, scale: float, dropout_p: float, is_training: bool = True, mask: bool = False): num_head = h l, N = hidden_states.size(0), hidden_states.size(1) # dim_head = q_proj.size(0) // num_head From c1d96286241548a047511a8d1698e34f35b46ca8 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Mon, 12 Dec 2022 15:04:57 +0000 Subject: [PATCH 1187/1892] refine code --- examples/nlp/blocks/attention.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 887fc592..74e19014 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -6,7 +6,7 @@ def self_attention(query: torch.Tensor, qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = True): + h: int, scale: float, dropout_p: float, mask: bool = False): num_head = h L, N = query.size(0), query.size(1) dim_head = qkv_proj.size(0) // num_head // 3 @@ -60,7 +60,7 @@ def self_attention(query: torch.Tensor, def qvk_combined(query: torch.Tensor, qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, #out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = True): + h: int, scale: float, dropout_p: float, mask: bool = False): num_head = h L, N = query.size(0), query.size(1) dim_head = qkv_proj.size(0) // num_head // 3 @@ -78,7 +78,7 @@ def qvk_combined(query: torch.Tensor, @cube.graph.parser.register('L^ N (h d^) 3 -> L^ N (h d^)', name='attention_mask') def attention_mask(qkv: torch.Tensor, # out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = True): + h: int, scale: float, dropout_p: float, mask: bool = False): L, N = qkv.size(0), qkv.size(1) num_head = h @@ -125,7 +125,7 @@ def attention_mask(qkv: torch.Tensor, def lin(lin_input: torch.Tensor, # qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = True): + h: int, scale: float, dropout_p: float, mask: bool = False): ###########split output = torch.nn.functional.linear(lin_input, out_proj) # L N (h d), E E -> L N E # output = torch.nn.functional.linear(output, out_proj) @@ -258,16 +258,16 @@ def forward(self, query): qkv = qvk_combined( query, self.qkv_proj, self.qkv_bias, #.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=True + self.num_heads, self.scaling, self.dropout_p, mask=False ) lin_input = attention_mask( qkv, - self.num_heads, self.scaling, self.dropout_p, mask=True + self.num_heads, self.scaling, self.dropout_p, mask=False ) attn = lin( lin_input, self.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=True + self.num_heads, self.scaling, self.dropout_p, mask=False ) attn = attn + self.out_bias return attn From ba2da5473001a66207b600ef84e988d4ef75ec28 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 13 Dec 2022 11:15:58 +0800 Subject: [PATCH 1188/1892] save work --- cube/runtime/adapter/reducer.py | 41 ++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index ae649642..bdbe127b 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -10,12 +10,13 @@ class Reducer: - def __init__(self, ranks: List[int]): + def __init__(self, ranks: List[int], bucket_size=536870912): self._params: List[torch.nn.Parameter] = list() # note this need to be called for every device self.ranks = ranks self._group = DeviceGroup().get_group(ranks) + self.bucket_size = bucket_size def add_param(self, param: torch.nn.Parameter): self._params.append(param) @@ -27,24 +28,36 @@ def allreduce(self, run=False): if not run: return buckets = {} + tp2size = {} for param in self._params: if param.requires_grad and param.grad is not None: + cur_byte_size = param.nelement() * param.element_size() + assert cur_byte_size <= self.bucket_size + tp = param.data.type() - if tp not in buckets: - buckets[tp] = [] - buckets[tp].append(param) + if tp not in self.buckets: + buckets[tp] = [[param]] + tp2size[tp] = cur_byte_size + else: + if tp2size[tp] + cur_byte_size <= self.bucket_size: + tp2size[tp] = tp2size[tp] + cur_byte_size + buckets[tp][-1].append(param) + else: + tp2size[tp] = cur_byte_size + buckets[tp].append([param]) + # for each bucket, do all-reduce for tp in buckets: - CudaTimer().start(field_name='comm', predefined=True) - bucket = buckets[tp] - grads = [param.grad.data for param in bucket] - coalesced = self._flatten_dense_tensors(grads) - # coalesced /= len(self.ranks) - torch.distributed.all_reduce(coalesced, group=self._group) - all_synced = self._unflatten_dense_tensors(coalesced, grads) - for grad, synced in zip(grads, all_synced): - grad.copy_(synced) - CudaTimer().stop(field_name='comm', predefined=True) + for bucket in buckets[tp]: + CudaTimer().start(field_name='comm', predefined=True) + grads = [param.grad.data for param in bucket] + coalesced = self._flatten_dense_tensors(grads) + # coalesced /= len(self.ranks) + torch.distributed.all_reduce(coalesced, group=self._group) + all_synced = self._unflatten_dense_tensors(coalesced, grads) + for grad, synced in zip(grads, all_synced): + grad.copy_(synced) + CudaTimer().stop(field_name='comm', predefined=True) def sync(self): """ From 364441a95b56b8e950a6d56ac2481905e8a1d467 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Tue, 13 Dec 2022 15:45:30 +0800 Subject: [PATCH 1189/1892] fix bug --- cube/runtime/adapter/reducer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index bdbe127b..a1b9e9e9 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -35,7 +35,7 @@ def allreduce(self, run=False): assert cur_byte_size <= self.bucket_size tp = param.data.type() - if tp not in self.buckets: + if tp not in buckets: buckets[tp] = [[param]] tp2size[tp] = cur_byte_size else: From fa89ac5d0ca1d02ad0be37211d6c33fce5a5fe59 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 14 Dec 2022 12:20:12 +0000 Subject: [PATCH 1190/1892] refine code --- examples/nlp/blocks/attention.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 74e19014..f3a5c544 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -55,7 +55,6 @@ def self_attention(query: torch.Tensor, output = torch.nn.functional.linear(output, out_proj) # L N (h d), E E -> L N E return output -# @cube.graph.parser.register('L^ N E^, (h d^ 3) E^, (h d^ 3) -> L^ N l^ 3', name='qkv_combined') @cube.graph.parser.register('L^ N E^, (h d^ 3) E^, (h d^ 3) -> L^ N (h d^) 3', name='qkv_combined') def qvk_combined(query: torch.Tensor, qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, @@ -67,17 +66,11 @@ def qvk_combined(query: torch.Tensor, qkv = torch.nn.functional.linear(query, qkv_proj, qkv_bias) # L N E, (h d 3) E -> L N (h d 3) output = qkv.view(L, N, num_head * dim_head, 3) # L N (h d 3) -> L N (h d) 3 - # q, k, v = qkv.chunk(3, dim=-1) # L N (h d) 3 -> L N (h d), L N (h d), L N (h d) - # q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - # k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - # v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - return output @cube.graph.parser.register('L^ N (h d^) 3 -> L^ N (h d^)', name='attention_mask') def attention_mask(qkv: torch.Tensor, - # out_proj: torch.Tensor, h: int, scale: float, dropout_p: float, mask: bool = False): L, N = qkv.size(0), qkv.size(1) @@ -89,9 +82,6 @@ def attention_mask(qkv: torch.Tensor, k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - # #########split - - # preallocating input tensor: (N h) L L matmul_input_buffer = torch.empty([N * h, L, L], dtype=q.dtype, device=q.device) # L (N h) d, L (N h) d -> (N h) L L @@ -126,9 +116,8 @@ def lin(lin_input: torch.Tensor, # qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, h: int, scale: float, dropout_p: float, mask: bool = False): - ###########split + output = torch.nn.functional.linear(lin_input, out_proj) # L N (h d), E E -> L N E - # output = torch.nn.functional.linear(output, out_proj) return output From 458bae475737028477e5b507b634d613089a9469 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 14 Dec 2022 12:25:09 +0000 Subject: [PATCH 1191/1892] refine code --- examples/nlp/blocks/attention.py | 3 +-- examples/nlp/blocks/encoder.py | 5 ----- scripts/megatron.sh | 4 ++++ 3 files changed, 5 insertions(+), 7 deletions(-) create mode 100644 scripts/megatron.sh diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index f3a5c544..2e76f668 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -116,7 +116,7 @@ def lin(lin_input: torch.Tensor, # qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, h: int, scale: float, dropout_p: float, mask: bool = False): - + output = torch.nn.functional.linear(lin_input, out_proj) # L N (h d), E E -> L N E return output @@ -246,7 +246,6 @@ def forward(self, query): qkv = qvk_combined( query, self.qkv_proj, self.qkv_bias, - #.out_proj, self.num_heads, self.scaling, self.dropout_p, mask=False ) lin_input = attention_mask( diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index d6fe946c..403ddcde 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -19,19 +19,14 @@ def __init__(self, embed_dim: int, num_heads: int, def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x - ### x = self.self_attn_layer_norm(x) - ### x = self.self_attn(x) - ##### x = self.dropout(x) x = x + residual residual = x - ##### x = self.final_layer_norm(x) x = self.mlp(x) - ### x = self.dropout(x) x = x + residual return x diff --git a/scripts/megatron.sh b/scripts/megatron.sh new file mode 100644 index 00000000..bc623dee --- /dev/null +++ b/scripts/megatron.sh @@ -0,0 +1,4 @@ +#!/bin/bash --login + +torchrun --nproc_per_node=2 --nnodes=1 \ + examples/nlp/gpt/train.py --policy=PASMegatronWSRTP --lrw --fp16 | tee -a LogForMegatronRecompute.txt \ No newline at end of file From 643bc0bca972e27ad6fbc66d73520424cc5ec100 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 14 Dec 2022 12:49:27 +0000 Subject: [PATCH 1192/1892] refine code --- examples/nlp/gpt/policy/spmd.py | 91 +++++++++++++++++++++++++++++++++ examples/nlp/gpt/train.py | 11 ++-- 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index 9838de62..397fb8b7 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -222,3 +222,94 @@ def PASMeshShard(graph: IRGraph, resource): # print(graph.extra_repr()) return graph + +def PASMegatronWSRTP(graph: IRGraph, resource): + tp_size = resource.ngpus + tp_devs = list(range(tp_size)) + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + + # annotating code structure -- not consider multiref on embedding weight + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + # why -1: multiref + fnodes[idx-1].comment = f'===> start of transformer layer {lid}' + + # attention + + qkvs = [node for node in fnodes if node.name == 'qkv_combined'] + #graph.recompute(qkvs) + for qkv in qkvs: + _tp(graph, qkv, tp_devs, idx=1, dim=0) + + + attns = [node for node in fnodes if node.name == 'attention_mask'] + graph.recompute(attns) + for attn in attns: + # graph.recompute(attn) + _tp(graph, attn, tp_devs, idx=0, dim=2) + # attns = [node for node in fnodes if node.name == 'self_attention'] + # for attn in attns: + # _tp(graph, attn, tp_devs, idx=1, dim=0) + + lins = [node for node in fnodes if node.name == 'lin'] + # graph.recompute(lins) + for lin in lins: + _tp(graph, lin, tp_devs, idx=1, dim=0) + + # feedforward + ffns = [node for node in fnodes if node.name == 'feedforward'] + # graph.recompute(ffns) + for ffn in ffns: + _tp(graph, ffn, tp_devs, idx=1, dim=0) + + # partition embed + embeds = [node for node in fnodes if node.name == 'embedding'] + for embed in embeds: + _tp(graph, embed, tp_devs, idx=1, dim=0) + + # partition last linear + linears = [node for node in fnodes if node.name == 'linear'] + _tp(graph, linears[-1], tp_devs, idx=1, dim=0) + + # partition loss + sums = [node for node in fnodes if node.name == 'sum'] + assert len(sums) == 1 + _tp(graph, sums[0], tp_devs, idx=0, dim=2) + + def GenerateNodesForSP(nodes): + output=[] + count = 0 + for node in nodes: + if isinstance(node, (IRFwOperation)) and not isinstance(node, (IRGraphAnchor)): + # if len(node.device) == 0: + sign = node.signature.split('.')[-1] + cid = node.cid + if len(output) == 0: + if sign == 'layer_norm': + output.append(node) + elif sign == 'dropout': + count = 0 + output.append(node) + count += 1 + elif sign == 'add' and count == 1: + output.append(node) + count += 1 + elif sign == 'layer_norm' and count == 2: + output.append(node) + elif sign == 'add': + output.append(node) + return output + + for node in GenerateNodesForSP(graph.nodes()): + _tp(graph, node, tp_devs, idx=0, dim=0) + + # replicate other nodes + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: + _replica(graph, node, tp_devs) + print(node) + # if isinstance(node, (IRFwOperation)) and not isinstance(node, (IRGraphAnchor)): + # print(node.cid) + + return graph \ No newline at end of file diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 22bf7bdf..a1ca1af6 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -11,7 +11,7 @@ import torch import time -from examples.nlp.gpt.model import GPT +from examples.nlp.gpt.model import GPT, GPTFineGrained, build_gpt_config from examples.nlp.gpt.model import GPTDataLoader import cube @@ -28,6 +28,7 @@ parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') +parser.add_argument('--lrw', action='store_true',help='use lrw\'s model') args = parser.parse_args() cube.init() @@ -48,9 +49,11 @@ def train(): - batch_size = 4 - - model = GPT() + batch_size = 8 + if args.policy == 'PASMegatronWSRTP': + model = GPTFineGrained(build_gpt_config('760M')) + else: + model = GPT(build_gpt_config('test')) model = model if not args.fp16 else model.half() dataloader = GPTDataLoader(batch_size) From 5c213c9eb391880b20b78f13310a7e97082cc680 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 14 Dec 2022 12:55:47 +0000 Subject: [PATCH 1193/1892] refine code --- examples/nlp/gpt/model.py | 3 --- examples/nlp/gpt/policy/spmd.py | 16 ++-------------- examples/nlp/gpt/train.py | 6 +++--- 3 files changed, 5 insertions(+), 20 deletions(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index b2608c2a..48442eeb 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -46,7 +46,6 @@ class GPTFineGrained(torch.nn.Module): def __init__(self, cfg=Config()): super().__init__() - # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.embed_dim)) self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) self.embed_dropout = torch.nn.Dropout() @@ -61,7 +60,6 @@ def __init__(self, cfg=Config()): self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): - # embed = self.embed(input_ids) embed = torch.nn.functional.embedding( input_ids, self.embedw, padding_idx=None, max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False @@ -76,7 +74,6 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): enc = layer(enc) enc = self.final_layernorm(enc) - # logits = torch.nn.functional.linear(enc, self.embed.weight) logits = torch.nn.functional.linear(enc, self.embedw) # simplified loss = torch.sum(logits) diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index 397fb8b7..48a698e7 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -235,31 +235,22 @@ def PASMegatronWSRTP(graph: IRGraph, resource): # why -1: multiref fnodes[idx-1].comment = f'===> start of transformer layer {lid}' - # attention - qkvs = [node for node in fnodes if node.name == 'qkv_combined'] - #graph.recompute(qkvs) for qkv in qkvs: _tp(graph, qkv, tp_devs, idx=1, dim=0) - + # implement selective recompute attns = [node for node in fnodes if node.name == 'attention_mask'] graph.recompute(attns) for attn in attns: - # graph.recompute(attn) _tp(graph, attn, tp_devs, idx=0, dim=2) - # attns = [node for node in fnodes if node.name == 'self_attention'] - # for attn in attns: - # _tp(graph, attn, tp_devs, idx=1, dim=0) lins = [node for node in fnodes if node.name == 'lin'] - # graph.recompute(lins) for lin in lins: _tp(graph, lin, tp_devs, idx=1, dim=0) # feedforward ffns = [node for node in fnodes if node.name == 'feedforward'] - # graph.recompute(ffns) for ffn in ffns: _tp(graph, ffn, tp_devs, idx=1, dim=0) @@ -277,12 +268,12 @@ def PASMegatronWSRTP(graph: IRGraph, resource): assert len(sums) == 1 _tp(graph, sums[0], tp_devs, idx=0, dim=2) + # tp def GenerateNodesForSP(nodes): output=[] count = 0 for node in nodes: if isinstance(node, (IRFwOperation)) and not isinstance(node, (IRGraphAnchor)): - # if len(node.device) == 0: sign = node.signature.split('.')[-1] cid = node.cid if len(output) == 0: @@ -308,8 +299,5 @@ def GenerateNodesForSP(nodes): for node in graph.nodes(): if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: _replica(graph, node, tp_devs) - print(node) - # if isinstance(node, (IRFwOperation)) and not isinstance(node, (IRGraphAnchor)): - # print(node.cid) return graph \ No newline at end of file diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index a1ca1af6..a6ee360f 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -28,7 +28,6 @@ parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') -parser.add_argument('--lrw', action='store_true',help='use lrw\'s model') args = parser.parse_args() cube.init() @@ -50,10 +49,11 @@ def train(): batch_size = 8 + Config=build_gpt_config('760M') if args.policy == 'PASMegatronWSRTP': - model = GPTFineGrained(build_gpt_config('760M')) + model = GPTFineGrained(Config) else: - model = GPT(build_gpt_config('test')) + model = GPT(Config) model = model if not args.fp16 else model.half() dataloader = GPTDataLoader(batch_size) From a4b5c436b6e4bb0b1b29f997e56f17e96af9621a Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 14 Dec 2022 12:58:28 +0000 Subject: [PATCH 1194/1892] refine code --- examples/nlp/blocks/attention.py | 2 +- examples/nlp/blocks/encoder.py | 6 +++--- examples/nlp/gpt/model.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 2e76f668..54a75a60 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -226,7 +226,7 @@ def one_attention(hidden_states: torch.Tensor, output = torch.nn.functional.linear(output, out_proj, None) # l N (h d), E E -> l N E return output -class MultiHeadSelfAttentionLrw(torch.nn.Module): +class MultiHeadSelfAttentionFineGrained(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): super().__init__() diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 403ddcde..d933df98 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -1,14 +1,14 @@ import torch -from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention, MultiHeadSelfAttentionLrw +from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention, MultiHeadSelfAttentionFineGrained from examples.nlp.blocks.mlp import MLP -class EncoderLayerLrw(torch.nn.Module): +class EncoderLayerFineGrained(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, attn_hidden_dim: int, ffn_hidden_dim: int, dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): super().__init__() - self.self_attn = MultiHeadSelfAttentionLrw( + self.self_attn = MultiHeadSelfAttentionFineGrained( embed_dim, num_heads, attn_hidden_dim, atten_dropout ) self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 48442eeb..8a7c9b12 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -1,6 +1,6 @@ import torch -from examples.nlp.blocks.encoder import EncoderLayer, EncoderLayerLrw, EncoderInferLayer +from examples.nlp.blocks.encoder import EncoderLayer, EncoderLayerFineGrained, EncoderInferLayer import cube from dataclasses import dataclass @@ -51,7 +51,7 @@ def __init__(self, cfg=Config()): self.embed_dropout = torch.nn.Dropout() self.layers = torch.nn.ModuleList( - [EncoderLayerLrw( + [EncoderLayerFineGrained( cfg.embed_dim, cfg.attention_heads, cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.dropout, cfg.attn_dropout, cfg.activation_dropout From 6ca4ad79b5eb73cd4f46a0c6caaf74551794abff Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 14 Dec 2022 13:05:01 +0000 Subject: [PATCH 1195/1892] add comments --- cube/profiler/database.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 748b4b3d..3e97c6cd 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -216,6 +216,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): torch.cuda.set_device(device) input_byte_size, param_byte_size = 0, 0 + # add residual_input_mem for the continous recompute Residual_input_byte_size, input_count = 0, 0 for t in node.inputs(): if t.is_param(): From cf10262d37424fa8b03a38f43cfe535e67ba4975 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 14 Dec 2022 13:10:27 +0000 Subject: [PATCH 1196/1892] delete tmp --- tmp | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 tmp diff --git a/tmp b/tmp deleted file mode 100644 index 28e3d078..00000000 --- a/tmp +++ /dev/null @@ -1,7 +0,0 @@ -nohup: ignoring input -benchmarking 4 gpus... -benchmarking 4 gpus... -benchmarking 4 gpus... -benchmarking 4 gpus... -/opt/conda/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 8 leaked semaphore objects to clean up at shutdown - warnings.warn('resource_tracker: There appear to be %d ' From af69417150531b95d05c03bdf3aea779a4a3b8ef Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 14 Dec 2022 14:01:56 +0000 Subject: [PATCH 1197/1892] refine code --- cube/profiler/database.py | 17 +++++++++++------ examples/nlp/blocks/attention.py | 9 ++++----- examples/nlp/gpt/model.py | 2 -- scripts/megatron.sh | 2 -- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 3e97c6cd..29736342 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -192,7 +192,7 @@ def get_dep_names(sign: str): dtypes.append(IRDType2TorchDType.map(t.dtype)) return fn, shapes, dtypes, node.kwargs - def profile(self, node: IRFwOperation, device: Optional[int] = None): + def profile(self, node: IRFwOperation, device: Optional[int] = None, residual_mem: bool = False): """ Profile a forward node in IRGraph on a specific device (default current device) @@ -216,7 +216,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): torch.cuda.set_device(device) input_byte_size, param_byte_size = 0, 0 - # add residual_input_mem for the continous recompute + # add Residual_input_byte_size for the continous recompute Residual_input_byte_size, input_count = 0, 0 for t in node.inputs(): if t.is_param(): @@ -233,7 +233,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): CompProfiler.profile(fn, shapes, dtypes, **kwargs) # log to database key = self._serialize(node) - self.insert(node.signature, key, input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size) + self.insert(node.signature, key, input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size, residual_mem) print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " f"=> in mem {input_byte_size} | param mem: {param_byte_size} | fw: {round(fw_span, 2)} ms | " @@ -241,11 +241,13 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): if isinstance(device, int): torch.cuda.set_device(orig_device) - return input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size + if residual_mem: + return input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size + return input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory def insert(self, name: str, key: str, input_byte_size: int, param_byte_size: int, fw_span: float, bw_span: float, infer_memory: int, train_memory: int, - Residual_input_byte_size: int): + Residual_input_byte_size: int, residual_mem: bool = False): """ log the span of a function name with key @@ -261,7 +263,10 @@ def insert(self, name: str, key: str, input_byte_size: int, param_byte_size: int assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = (input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size) + if residual_mem: + self._data[name][key] = (input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory, Residual_input_byte_size) + else: + self._data[name][key] = (input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory) def exist(self, node: IRFwOperation) -> bool: """ diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 54a75a60..b4249e3c 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -112,10 +112,9 @@ def attention_mask(qkv: torch.Tensor, return output @cube.graph.parser.register('L^ N (h+ d^), E^ (h+ d^) -> L^ N E^', name='attention_mask') -def lin(lin_input: torch.Tensor, - # qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, - out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = False): +def attention_out_linear(lin_input: torch.Tensor, + out_proj: torch.Tensor, + h: int, scale: float, dropout_p: float, mask: bool = False): output = torch.nn.functional.linear(lin_input, out_proj) # L N (h d), E E -> L N E return output @@ -252,7 +251,7 @@ def forward(self, query): qkv, self.num_heads, self.scaling, self.dropout_p, mask=False ) - attn = lin( + attn = attention_out_linear( lin_input, self.out_proj, self.num_heads, self.scaling, self.dropout_p, mask=False diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 8a7c9b12..59341e64 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -21,8 +21,6 @@ class Config: def build_gpt_config(name: str) -> Config: if name == '350M': embed_dim, layers, attention_heads = 1024, 24, 16 - elif name == 'test': - embed_dim, layers, attention_heads = 1024, 4, 16 elif name == '760M': embed_dim, layers, attention_heads = 1536, 24, 16 elif name == '1.3B': diff --git a/scripts/megatron.sh b/scripts/megatron.sh index bc623dee..67bcd5b8 100644 --- a/scripts/megatron.sh +++ b/scripts/megatron.sh @@ -1,4 +1,2 @@ -#!/bin/bash --login - torchrun --nproc_per_node=2 --nnodes=1 \ examples/nlp/gpt/train.py --policy=PASMegatronWSRTP --lrw --fp16 | tee -a LogForMegatronRecompute.txt \ No newline at end of file From a4b4471365b8de7eaf1adfaf4e467de4aa7634c1 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Fri, 16 Dec 2022 12:35:58 +0000 Subject: [PATCH 1198/1892] refine allreduce --- cube/runtime/adapter/reducer.py | 4 +--- cube/runtime/schedule/sched1f1b.py | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index a1b9e9e9..de14eaca 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -21,12 +21,10 @@ def __init__(self, ranks: List[int], bucket_size=536870912): def add_param(self, param: torch.nn.Parameter): self._params.append(param) - def allreduce(self, run=False): + def allreduce(self): """ Reduce gradients across given group """ - if not run: - return buckets = {} tp2size = {} for param in self._params: diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py index c8d4bdde..1c75e04a 100644 --- a/cube/runtime/schedule/sched1f1b.py +++ b/cube/runtime/schedule/sched1f1b.py @@ -59,6 +59,8 @@ def run(segment: Callable, # forward body # steady for i in range(num_warmup_remaining): + if torch.distributed.get_rank() == 0: + print(torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) # forward Schedule1F1B.push_tail('inputs', inputs) if recompute: From 8c91fa9ea5856fd75f352acb2c5f26f896aea728 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 19 Dec 2022 09:22:07 +0800 Subject: [PATCH 1199/1892] clear compiler status before each compile --- cube/compiler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cube/compiler.py b/cube/compiler.py index a571a154..e15514d7 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -21,6 +21,7 @@ from cube.runtime.syndata import CubeDataLoader from cube.program import Program, SemanticDataLoader, SemanticModel +from cube.ir.unique import IDGenerator from cube.flags import CompileFlag @@ -59,6 +60,10 @@ def train_step(model, dataloader): @return sched_fn Callable: the scheduling function loaded from generated code. """ + # clean global status + Program().clear() + IDGenerator().clear() + if not isinstance(model, SemanticModel): raise TypeError("Expect Semantic Model") if dataloader is None: From 375391cb92045c1d0f3602c0506eedeacfa7882e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 19 Dec 2022 10:48:29 +0800 Subject: [PATCH 1200/1892] add default allreduce bucket size --- cube/codegen/codegen.py | 5 +-- cube/flags.py | 3 ++ cube/runtime/adapter/reducer.py | 58 +++++++++++++++++++++++++++------ 3 files changed, 54 insertions(+), 12 deletions(-) diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py index ac71953d..ab01ad23 100644 --- a/cube/codegen/codegen.py +++ b/cube/codegen/codegen.py @@ -869,15 +869,16 @@ def emit_reducer_init(self, node: IRWeightReducer) -> None: The fields storing intermediate codes that are populated by this method: - `model_init_statements` """ + max_nbytes = CompileFlag.max_reducer_bucket # reducer init interface - reducer_init = '{reducer} = cube.runtime.adapter.Reducer(ranks={ranks})' + reducer_init = '{reducer} = cube.runtime.adapter.Reducer(ranks={ranks}, max_bucket_size_bytes={max_nbytes})' reducer_add = 'self.add_reducer({reducer})' add_param = '{reducer}.add_param({weight})' # create reducer in declare region weights = node.inputs() reducer_name = f'self.wreducer{node._id}' self.model_init_statements.append('') - init_code = reducer_init.format(reducer=reducer_name, ranks=node.device) + init_code = reducer_init.format(reducer=reducer_name, ranks=node.device, max_nbytes=max_nbytes) self.model_init_statements.append(init_code) weights = [self.tensor_naming(t, prefix_attr='self.') for t in weights] for weight in weights: diff --git a/cube/flags.py b/cube/flags.py index c925fa19..14ae65fb 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -26,6 +26,9 @@ class CompileFlag: # ============== runtime ==================== dev_mode = os.environ.get('SINGLE_DEV_MODE') # allow to use python xx.py + + # maximal reducer weight bytes for one allreduce + max_reducer_bucket = int(os.environ.get('MAX_REDUCER_BUCKET', default=5e8)) # use automate mixture precision training, where weights, gradients # and optimizer status are kept in its original data type (can be float32), diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index d622e879..4357627b 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -4,18 +4,38 @@ from typing import List import torch +import warnings from cube.runtime.device import DeviceGroup from cube.profiler.timer import CudaTimer, print_each_rank + +def get_nbytes(dtype: torch.dtype) -> int: + try: + if dtype.is_floating_point(): + return torch.finfo(dtype).bits // 8 + else: + return torch.iinfo(dtype).bits // 8 + except Exception as e: + warnings.warn(f'Cannot figure out bytes of dtype: {dtype}, set default as 4.') + return 4 + + class Reducer: - def __init__(self, ranks: List[int]): + def __init__(self, ranks: List[int], max_bucket_size_bytes: int): + """ + Create a reducer to synchronize weights and its gradients + + @param ranks List[int]: global ranks the reducer works on + @param max_bucket_size_bytes int: max bytes for one allreduce call + """ self._params: List[torch.nn.Parameter] = list() # note this need to be called for every device self.ranks = ranks self._group = DeviceGroup().get_group(ranks) + self._max_bucket_size_bytes = max_bucket_size_bytes def add_param(self, param: torch.nn.Parameter): self._params.append(param) @@ -27,22 +47,40 @@ def allreduce(self): buckets = {} for param in self._params: if param.requires_grad and param.grad is not None: - tp = param.data.type() + tp = param.data.dtype if tp not in buckets: buckets[tp] = [] buckets[tp].append(param) + + def sync(grads, non_blocking: bool = False) -> None: + """ + inplacement synchronize gradients + + @param non_blocking bool: whether gradient copy is non-blocking + """ + coalesced = self._flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=self._group) + all_synced = self._unflatten_dense_tensors(coalesced, grad_groups) + for grad, synced in zip(grad_groups, all_synced): + grad.copy_(synced, non_blocking=non_blocking) + # for each bucket, do all-reduce + CudaTimer().start(field_name='comm', predefined=True) for tp in buckets: - CudaTimer().start(field_name='comm', predefined=True) bucket = buckets[tp] grads = [param.grad.data for param in bucket] - coalesced = self._flatten_dense_tensors(grads) - # coalesced /= len(self.ranks) - torch.distributed.all_reduce(coalesced, group=self._group) - all_synced = self._unflatten_dense_tensors(coalesced, grads) - for grad, synced in zip(grads, all_synced): - grad.copy_(synced) - CudaTimer().stop(field_name='comm', predefined=True) + nbytes, grad_groups = 0, [] + for grad in grads: + grad_groups.append(grad) + nbytes += grad.nelement() * grad.element_size() + if nbytes >= self._max_bucket_size_bytes: + # print(f'sync barrier: num gradients {len(grad_groups)}') + sync(grad_groups, non_blocking=True) + nbytes, grad_groups = 0, [] + if nbytes > 0: + sync(grad_groups, non_blocking=True) + torch.cuda.synchronize() + CudaTimer().stop(field_name='comm', predefined=True) def sync(self): """ From 5ee74efdb0e19c63e3d0b3e7f49272c415784ce6 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Mon, 19 Dec 2022 07:46:35 +0000 Subject: [PATCH 1201/1892] refine scripts --- scripts/megatron.sh | 2 +- scripts/pre_install.sh | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 scripts/pre_install.sh diff --git a/scripts/megatron.sh b/scripts/megatron.sh index 67bcd5b8..82466184 100644 --- a/scripts/megatron.sh +++ b/scripts/megatron.sh @@ -1,2 +1,2 @@ torchrun --nproc_per_node=2 --nnodes=1 \ - examples/nlp/gpt/train.py --policy=PASMegatronWSRTP --lrw --fp16 | tee -a LogForMegatronRecompute.txt \ No newline at end of file + examples/nlp/gpt/train.py --policy=PASMegatronWSRTP --fp16 | tee -a LogForMegatronRecompute.txt \ No newline at end of file diff --git a/scripts/pre_install.sh b/scripts/pre_install.sh new file mode 100644 index 00000000..cf7d1081 --- /dev/null +++ b/scripts/pre_install.sh @@ -0,0 +1,2 @@ +pip install -r requirements.txt +sudo /opt/conda/bin/python setup.py develop \ No newline at end of file From fefdcd5601cd6c67e7f71e79bab77f2ace8998fa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Dec 2022 06:55:31 +0000 Subject: [PATCH 1202/1892] Merged PR 1436: Multiref for a device with different partitioned tensors --- cube/graph/gener/gen.py | 47 +++++++++++++++++--- cube/graph/graph.py | 11 ++++- cube/graph/segment.py | 86 ++++++++++++++++++++---------------- tests/adapter/test_rvd.py | 91 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 190 insertions(+), 45 deletions(-) create mode 100644 tests/adapter/test_rvd.py diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 02d0fddf..82cab797 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -129,6 +129,8 @@ def gen(graph: IRGraph) -> IRGraph: """ # remove anchor node graph = IRAdapterGener.remove_anchor(graph) + # automatic transform multiref + graph = IRAdapterGener.autoref(graph) # generate adapters for activation graph = IRAdapterGener.gen_activation(graph) # generate weight reducer @@ -556,15 +558,13 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): # collect consumer of each device for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): for devid in ctensor.device: - if devid not in devtensors: - devtensors[devid], devops[devid] = [], [] - assert len(devtensors[devid]) == 0 or devtensors[devid][0] == ctensor, ( + devtensors.setdefault(devid, []).append(ctensor) + devops.setdefault(devid, []).append(consumer) + assert devtensors[devid][0] == ctensor, ( f"Detect that a full tensor is partitioned differently on a device.\n" - f"To achieve this, need manually add multiref operator in model description.\n" + f"To achieve this, need call graph.multiref before graph transformation.\n" f"{graph.debug_tensor_map_str(ftensor)}" ) - devtensors[devid].append(ctensor) - devops[devid].append(consumer) require_multiref = any(len(ops) > 1 for ops in devops.values()) if not require_multiref: return @@ -587,7 +587,7 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): f"Users can try:\n" f" 1) Replicate all operators whose inputs have multi-consumed tensors\n" f" 2) Partition all operators whose inputs have multi-consumed tensors\n" - f" 3) Mannually add cube.runtime.multiref in model description to divide replicated and partitioned groups\n" + f" 3) Call graph.multiref to divide tensors with different partition strategies\n" f"{graph.debug_tensor_map_str(ftensor)}" f"{graph.mirror.debug_tensor_map_str(ftensor.grad)}" ) @@ -619,6 +619,39 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): multiref.recompute = graph.node(min_fidx).recompute graph.finsert(multiref, min_fidx) + @staticmethod + def autoref(graph: IRSegment) -> IRGraph: + """ + Automatically transform inserted multiref. + Multiref is transformed to align with the output tensors on each device. + + @param graph IRGraph + + @return None + """ + for multiref in graph.select(name='multiref'): + ftensor: IRFullTensor = multiref.input(0).parent + for otensor in graph.ptensors(ftensor): + mr = MultiRef(None, [otensor, len(multiref.outputs())]) + for idx in range(len(multiref.outputs())): + output = multiref.output(idx).parent.select(otensor.indmap, otensor.valmap) + if otensor.requires_grad: + output.grad = multiref.output(idx).grad.parent.select(otensor.indmap, (0,1)) + mr.set_output(idx, output) + mr.device = otensor.device + mr.recompute = otensor.cell.recompute + if otensor.requires_grad: + graph.finsert(mr, graph.index(otensor.cell) + 1) + else: + graph.insert(mr, graph.index(otensor.cell) + 1) + # remove original multiref + graph.remove(multiref) + if multiref.mirror is not None: + graph.mirror.remove(multiref.mirror) + for segment in graph.select(ntype=IRSegment): + IRAdapterGener.autoref(segment) + return graph + @staticmethod def fusion(graph: IRSegment) -> IRSegment: """ diff --git a/cube/graph/graph.py b/cube/graph/graph.py index a3f7b508..a30ca113 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -8,7 +8,7 @@ """ from typing import Sequence, Set, Union, Tuple, List, Optional, Dict -from cube.graph.function.anchor import IRGraphAnchor +import warnings from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator @@ -17,6 +17,7 @@ from cube.ir.dtype import IRDType, DTypeInferRule from cube.graph.function.function import Identity, MultiRef +from cube.graph.function.anchor import IRGraphAnchor from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo @@ -278,6 +279,10 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis raise TypeError("Expected op to be forward op or data op") if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") + if node.name == 'multiref': + warnings.warn( + 'Detected partition a multiref node. This will be skipped as system will automatically handle it.') + return [node] fsegment: IRSegment = self.segment(node) # replicate @@ -335,6 +340,10 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], "The partition algorithm is not initialized for this node" assert isinstance(node, (IRFwOperation, IRDataOperation)), \ f"Only allow op to be forward op or data op, but got: {node}" + if node.name == 'multiref': + warnings.warn( + 'Detected partition a multiref node. This will be skipped as system will automatically handle it.') + return [node] # get partitioned sub-nodes fnodes = algo.instantiate(**config) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 5ed0c8df..d74dc3ea 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -344,7 +344,12 @@ def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: def create_bwop(self, fwop: IRFwOperation) -> Union[IRBpOperation, IRBpOperation]: """ - Create dummy backward operator for given forward operator + Create dummy backward operator for given forward operator. + This assumes input/output tensors of fwop have been set by correct gradient tensors. + + @param fwop IRFwOperation: forward operation + + @return bwop IRBpOperation: the created backward operation """ assert isinstance(fwop, (IRFwOperation, IRSegment)), "Expected IRFwOperation" fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] @@ -605,6 +610,7 @@ def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwO the backward of fwop's previous forward node This requires the segment has its backward segment + This assumes inputs/outputs tensors of fwop have been set with correct gradient @param fwop IRFwOperation: forward node @param index Union[int, CellPosition]: inserted position @@ -635,59 +641,65 @@ def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwO # ===================== Advance Graph manipulations ================== - def multiref(self, tensor: IRSubTensor, node_groups: List[List[IRFwOperation]]) -> IRFwOperation: + def multiref(self, ftensor: IRFullTensor, node_groups: List[List[IRFwOperation]]) -> IRFwOperation: """ - Add multiref to separate nodes into different tensor alias. - Each other consumer that is not in the node_groups will be set as a group. + Add multiref to separate forward nodes that consume a same tensor into different tensor alias. + This should be called before any graph transformation. + + Operators in a group can only be partitioned by a same tensor split strategy. + The created multiref operator will be partitioned automatically when generating + tensor adapters. @param tensor IRSubTensor: tensor. - @param node_groups List[List[IRFwOperation]]: operators that have tensor has input + @param node_groups List[List[IRFwOperation]]: + operators that take the tensor as input. + + @return multiref IRFwOperation: the inserted multiref operator. """ - assert tensor.parent in self._ftensors - # add remaining consumers - node_groups = tuple(node_groups) - for consumer in self.consumers(tensor.parent): - if not any(consumer in nodes for nodes in node_groups): - node_groups = node_groups + ([consumer],) + assert ftensor in self._ftensors, f"tensor: {ftensor} not in this graph." + # check no transformation + if len(self.consumers(ftensor)) <= 1: return + assert not ftensor.is_grad(), f"graph.multiref can only be applied on a non-gradient full tensor." + assert len(set(self.ctensors(ftensor))) == 1, \ + f"Detected happened graph transformation. This interfacee should be called before graph transformation." + # check completeness + consumers = set() + for nodes in node_groups: + consumers.update(nodes) + assert consumers == set(self.consumers(ftensor)), f"some consumer(s) are not in node_groups" # create new full tensors - ftensors = [tensor.parent.like() for _ in node_groups] - otensors = [ft.select(tensor.indmap, tensor.valmap) for ft in ftensors] - # update consumer - insert_idx = CellPosition((self.nnodes,)) - for fidx, nodes in enumerate(node_groups): - for node in nodes: - assert tensor in node.inputs() + tensor = self.ctensors(ftensor)[0] + ftensors: List[IRSubTensor] = [ftensor.like() for _ in node_groups] + otensors: List[IRSubTensor] = [ft.select(tensor.indmap, tensor.valmap) for ft in ftensors] + # update forward / backward consumer + for otensor, nodes in zip(otensors, node_groups): + for idx, node in enumerate(nodes): idx = node.inputs().index(tensor) + grad = node.input(idx).grad with self.update(node): - node.set_input(idx, multiref.output(fidx)) - insert_idx = min(insert_idx, self.index(node)) + node.set_input(idx, otensor) + if tensor.requires_grad: + node.input(idx).grad = otensor.parent.grad.select(otensor.indmap, (idx, len(nodes))) + with self.mirror.update(node.mirror) as bnode: + idx = bnode.outputs().index(grad) + bnode.set_output(idx, node.input(idx).grad) # create multiref multiref = MultiRef('cube.runtime.function.multiref', [tensor, len(node_groups)]) for idx, otensor in enumerate(otensors): multiref.set_output(idx, otensor) - if len(tensor.device) > 0: - multiref.device = tensor.device - # set backward + # setup gradient if tensor.requires_grad: - # add multiref - multiref.input(0).grad = tensor.parent.grad.select(tensor.indmap, (0,1)) + multiref.input(0).grad = tensor.parent.grad.select(tensor.indmap, (0, 1)) for idx, output in enumerate(multiref.outputs()): output.grad = ftensors[idx].grad.select(tensor.indmap, (0,1)) - self.finsert(multiref, insert_idx) - # update forward gradient - for ftensor in ftensors: - self.infer_grad(ftensor) - # update backward operator - for nodes in node_groups + ([multiref,]): - for fnode in nodes: - bidx = self.remove(fnode.mirror) - bnode = self.create_bwop(fnode) - self.insert(bidx, bnode) + # insert multiref + fidx = max(self.index(prod) for prod in self.producers(tensor.parent)) + 1 + if ftensor.requires_grad: + self.finsert(multiref, fidx) else: - self.insert(multiref, insert_idx) + self.insert(multiref, fidx) return multiref - def single_consume(self, one_for_all: bool = True): """ Transform graph to make each non-attribute tensor has up to diff --git a/tests/adapter/test_rvd.py b/tests/adapter/test_rvd.py new file mode 100644 index 00000000..ed9132fa --- /dev/null +++ b/tests/adapter/test_rvd.py @@ -0,0 +1,91 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + tests/adapter/test_rvd.py +""" +from typing import List +import cube +from cube.graph.graph import IRGraph +from cube.ir.operator import IRFwOperation, IRDataOperation +import torch + +cube.init() + + +class TestModel(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.empty(1024, 1024)) + + def forward(self, x): + x = self.param * x + residual = x + x = x * 2 + x = x + residual + x = torch.sum(x) + return x + + +def _tp(graph, node: IRFwOperation, idx, dim, devs: List[int]): + algo = node.algorithms('dim') + nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert nodes is not None + for devid, node in zip(devs, nodes): + graph.assign(node, devid) + return nodes + + +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + nodes = graph.replicate(node, times=len(devs)) + assert nodes is not None + for devid, node in zip(devs, nodes): + graph.assign(node, devid) + return nodes + + +def test_multiref_intra_rvd(): + + model = TestModel() + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([1024,1024],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) + + def policy(graph: IRGraph, resource): + print(graph.extra_repr()) + devs = list(range(resource.ngpus)) + + for ftensor in graph.full_tensors(): + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) + + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs) + + for node in graph.select(ntype=IRFwOperation): + if node.name == 'mul': + _tp(graph, node, idx=0, dim=0, devs=devs) + elif node.name == 'add': + _tp(graph, node, idx=0, dim=1, devs=devs) + else: + _replica(graph, node, devs) + print(graph.extra_repr()) + return graph + + model = cube.SemanticModel(model) + @cube.compile(model, dataloader, PAS=policy) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + for _ in range(4): + train_iter(model, dataloader) + + +if __name__ == '__main__': + + test_multiref_intra_rvd() From 6420085c720adf5de14d8f63e0e5b1729dbee18c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 28 Dec 2022 12:44:17 +0800 Subject: [PATCH 1203/1892] fix multiref bug with staging --- cube/compiler.py | 5 +++-- cube/graph/gener/gen.py | 5 +++-- cube/graph/segment.py | 31 +++++++++++++++++-------------- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index e15514d7..287bce27 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -132,8 +132,9 @@ def decorator(fn: Callable) -> Callable: # check assignment and remove anchor node for node in graph.nodes(flatten=True): - if isinstance(node, IRGraphAnchor) or isinstance(node.mirror, IRGraphAnchor): - continue + # skip graph anchor and multiref: they will be removed or replaced by system + if isinstance(node, IRGraphAnchor) or node.name == 'multiref': + graph.assign(node, 0) if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 82cab797..a75cbf94 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -629,7 +629,7 @@ def autoref(graph: IRSegment) -> IRGraph: @return None """ - for multiref in graph.select(name='multiref'): + for multiref in graph.select(name='multiref', flatten=False): ftensor: IRFullTensor = multiref.input(0).parent for otensor in graph.ptensors(ftensor): mr = MultiRef(None, [otensor, len(multiref.outputs())]) @@ -648,7 +648,8 @@ def autoref(graph: IRSegment) -> IRGraph: graph.remove(multiref) if multiref.mirror is not None: graph.mirror.remove(multiref.mirror) - for segment in graph.select(ntype=IRSegment): + for segment in graph.select(ntype=IRSegment, flatten=False): + if not segment.isfw(): continue IRAdapterGener.autoref(segment) return graph diff --git a/cube/graph/segment.py b/cube/graph/segment.py index d74dc3ea..aaee2168 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -630,7 +630,7 @@ def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwO # insert backward assert fsegment.mirror is not None, "Missing backward segment" bsegment: IRSegment = fsegment.mirror - bidx = 0 + bidx = CellPosition((bsegment.nnodes,)) for idx in range(index - 1, -1, -1): prev_fnode = fsegment.node(idx) if prev_fnode.mirror is not None: @@ -671,18 +671,6 @@ def multiref(self, ftensor: IRFullTensor, node_groups: List[List[IRFwOperation]] tensor = self.ctensors(ftensor)[0] ftensors: List[IRSubTensor] = [ftensor.like() for _ in node_groups] otensors: List[IRSubTensor] = [ft.select(tensor.indmap, tensor.valmap) for ft in ftensors] - # update forward / backward consumer - for otensor, nodes in zip(otensors, node_groups): - for idx, node in enumerate(nodes): - idx = node.inputs().index(tensor) - grad = node.input(idx).grad - with self.update(node): - node.set_input(idx, otensor) - if tensor.requires_grad: - node.input(idx).grad = otensor.parent.grad.select(otensor.indmap, (idx, len(nodes))) - with self.mirror.update(node.mirror) as bnode: - idx = bnode.outputs().index(grad) - bnode.set_output(idx, node.input(idx).grad) # create multiref multiref = MultiRef('cube.runtime.function.multiref', [tensor, len(node_groups)]) for idx, otensor in enumerate(otensors): @@ -693,11 +681,26 @@ def multiref(self, ftensor: IRFullTensor, node_groups: List[List[IRFwOperation]] for idx, output in enumerate(multiref.outputs()): output.grad = ftensors[idx].grad.select(tensor.indmap, (0,1)) # insert multiref - fidx = max(self.index(prod) for prod in self.producers(tensor.parent)) + 1 + if len(self.producers(ftensor)) == 0: + fidx = min(self.index(consumer) for consumer in self.consumers(ftensor)) + else: + fidx = max(self.index(prod) for prod in self.producers(ftensor)) + 1 if ftensor.requires_grad: self.finsert(multiref, fidx) else: self.insert(multiref, fidx) + # update forward / backward consumer + for otensor, nodes in zip(otensors, node_groups): + for idx, node in enumerate(nodes): + fidx = node.inputs().index(tensor) + grad = node.input(fidx).grad + with self.update(node): + node.set_input(fidx, otensor) + if tensor.requires_grad: + node.input(fidx).grad = otensor.parent.grad.select(otensor.indmap, (idx, len(nodes))) + with self.mirror.update(node.mirror) as bnode: + bidx = bnode.outputs().index(grad) + bnode.set_output(bidx, node.input(bidx).grad) return multiref def single_consume(self, one_for_all: bool = True): From d77a5de98d748f081c554566da97b1d02736e92a Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 28 Dec 2022 04:57:33 +0000 Subject: [PATCH 1204/1892] debug and fix multiref --- cube/runtime/adapter/reducer.py | 6 +++--- cube/runtime/schedule/sched1f1b.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 9942f436..8aedc87c 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -23,13 +23,13 @@ def get_nbytes(dtype: torch.dtype) -> int: class Reducer: - def __init__(self, ranks: List[int], bucket_size=536870912): + def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912): self._params: List[torch.nn.Parameter] = list() # note this need to be called for every device self.ranks = ranks self._group = DeviceGroup().get_group(ranks) - self.bucket_size = bucket_size + self.bucket_size = max_bucket_size_bytes def add_param(self, param: torch.nn.Parameter): self._params.append(param) @@ -66,7 +66,7 @@ def allreduce(self): torch.distributed.all_reduce(coalesced, group=self._group) all_synced = self._unflatten_dense_tensors(coalesced, grads) for grad, synced in zip(grads, all_synced): - grad.copy_(synced, non_blocking=non_blocking) + grad.copy_(synced, non_blocking=True) torch.cuda.synchronize() CudaTimer().stop(field_name='comm', predefined=True) diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py index 1c75e04a..ae67ab4c 100644 --- a/cube/runtime/schedule/sched1f1b.py +++ b/cube/runtime/schedule/sched1f1b.py @@ -22,6 +22,8 @@ def run(segment: Callable, # forward body # special case: num_stages == 1: use gradient accum if num_stages == 1: for _ in range(num_microbatch): + # if torch.distributed.get_rank() == 0: + # print(_) inputs = Schedule1F1B.dataloader_step(dataloader) outputs = Schedule1F1B.forward_step(segment, *inputs) input_grads = Schedule1F1B.backward_step(inputs, outputs, (None,)) From e75b9d7df7c4113eadc12b629bcb70514757f02c Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Wed, 28 Dec 2022 10:09:15 +0000 Subject: [PATCH 1205/1892] refine script --- scripts/pre_install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/pre_install.sh b/scripts/pre_install.sh index cf7d1081..08e65899 100644 --- a/scripts/pre_install.sh +++ b/scripts/pre_install.sh @@ -1,2 +1,2 @@ pip install -r requirements.txt -sudo /opt/conda/bin/python setup.py develop \ No newline at end of file +python setup.py develop --user \ No newline at end of file From 1882ccb10f0dabca54576f2f8086ab091c479393 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Fri, 30 Dec 2022 13:00:01 +0800 Subject: [PATCH 1206/1892] remove useless log --- cube/runtime/schedule/sched1f1b.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py index ae67ab4c..e2c9d298 100644 --- a/cube/runtime/schedule/sched1f1b.py +++ b/cube/runtime/schedule/sched1f1b.py @@ -61,8 +61,6 @@ def run(segment: Callable, # forward body # steady for i in range(num_warmup_remaining): - if torch.distributed.get_rank() == 0: - print(torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) # forward Schedule1F1B.push_tail('inputs', inputs) if recompute: From 52568658931fa3d7b64859fa4ac1db9ef8dfd692 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 3 Jan 2023 07:10:00 +0000 Subject: [PATCH 1207/1892] Merged PR 1437: intra-RVD search with device alignment Add device alignment in the search path; Add rvd fallback plan in the search path; --- cube/graph/gener/concurrent.py | 53 +-- cube/graph/gener/gen.py | 19 +- cube/graph/gener/layout.py | 656 +++++++++++++++++++------------- tests/adapter/test_intra_rvd.py | 197 ++++++++++ tests/adapter/test_rvd.py | 91 ----- 5 files changed, 614 insertions(+), 402 deletions(-) create mode 100644 tests/adapter/test_intra_rvd.py delete mode 100644 tests/adapter/test_rvd.py diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 473484f4..31dfc4ba 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -4,6 +4,7 @@ from typing import List, Optional, Dict, Tuple import copy import numpy as np +import sys from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap from cube.ir.adapter.prim import IRAdapterPrim @@ -53,10 +54,13 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) except Exception as e: fadapter = None + color, default = '\033[33m' , '\033[0m' print( - f"full tensor: {fptensors[0].parent} cannot use intra-transformation generation.\n" + f"{color}========== Fail to use intra-RVD ==========\n" + f"full tensor: {fptensors[0].parent}\n" f"Reason: {str(e)}\n" - f"Switch to general P2P communication." + f"Switch to general P2P communication.\n" + f"===========================================\n{default}", file=sys.stderr ) # Case 2: sperating device (cross-shard) @@ -66,12 +70,14 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], fadapter = ConcurrentGener.gen_cross_shard(fptensors, fctensors, bptensors, bctensors) except Exception as e: fadapter = None + color, default = '\033[33m' , '\033[0m' print( - f"full tensor: {fptensors[0].parent} cannot use inter-transformation generation.\n" + f"{color}========== Fail to use inter-RVD ==========\n" + f"full tensor: {fptensors[0].parent}\n" f"Reason: {str(e)}\n" - f"Switch to general P2P communication." + f"Switch to general P2P communication.\n" + f"===========================================\n{default}", file=sys.stderr ) - # Case 3: General cases # warnings.warn('The adapter is generated using P2P communication') if fadapter is None: @@ -100,30 +106,7 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # consumer grid layout olayout = GridLayout.togrid(ftensor, ctensors) # find path - paths, fprims = ilayout.path(olayout) - - # re-assign the operator if miss-ordered - res_layout: GridLayout = paths[-1] - names, from_dev, to_dev = [], [], [] - for itensor, otensor in zip(res_layout.mat.flatten(), olayout.mat.flatten()): - assert len(itensor.device) == 1 and len(otensor.device) == 1, \ - "Expect tensor only has one device. Report this as a bug" - if itensor.device != otensor.device: - # TODO: need to be robust: multiref to a node type - if otensor.cell.name == 'multiref': - raise RuntimeError("auto-inserted multiref cannot be re-ordered") - inode, onode = itensor.cell, otensor.cell - names.append(f'{onode.name}{onode.cid}') - from_dev.append(onode.device[0]) - to_dev.append(inode.device[0]) - if allow_reorder: - onode.device = inode.device - if onode.mirror is not None: - onode.mirror.device = inode.device - else: - raise RuntimeError("device mismatch. Try to enable reorder") - if len(names) > 0: - print(f'UserWarning: a better device placement is found and set for op {names}: {from_dev} -> {to_dev}') + paths, fprims = PathFinder.intra_path(ftensor, ilayout, olayout) fadapter = IRAdapter(fptensors, fctensors) fadapter.prims = fprims @@ -140,10 +123,16 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], ptensors[idx] = bptensor ilayout = GridLayout.togrid(grad, ptensors) olayout = GridLayout.togrid(grad, bctensors) - paths, bprims = ilayout.path(olayout) + # paths, bprims = ilayout.path(olayout) + paths, bprims = PathFinder.intra_path(grad, ilayout, olayout) # check the device order - for itensor, otensor in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): - assert len(itensor.device) == len(otensor.device), "backward device not match" + same_device = True + for t in paths[-1].mat.flatten(): + if not any(t == c and set(t.device) == set(c.device) for c in bctensors): + same_device = False + break + assert same_device, "backward device not match" + # generate backward adapter badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims IRAdapter.make_pair(fadapter, badapter) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index a75cbf94..4c98f08a 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -127,6 +127,8 @@ def gen(graph: IRGraph) -> IRGraph: @param graph IRGraph: the graph without adapter @return graph IRGraph: the graph with adapter inserted """ + # reorder producer and consumer ordering + graph._reorder_producer_consumer() # remove anchor node graph = IRAdapterGener.remove_anchor(graph) # automatic transform multiref @@ -265,9 +267,6 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: fdummies = create_dummy(graph, inputs=True, outputs=True) bdummies = [fwop.mirror for fwop in fdummies if fwop.mirror is not None] bgraph: Optional[IRSegment] = graph.mirror - - # reorder producers and consumers - graph._reorder_producer_consumer() # local producer fusion and local consumer multiref ftensors = [] @@ -631,6 +630,7 @@ def autoref(graph: IRSegment) -> IRGraph: """ for multiref in graph.select(name='multiref', flatten=False): ftensor: IRFullTensor = multiref.input(0).parent + multirefs = [] for otensor in graph.ptensors(ftensor): mr = MultiRef(None, [otensor, len(multiref.outputs())]) for idx in range(len(multiref.outputs())): @@ -640,14 +640,17 @@ def autoref(graph: IRSegment) -> IRGraph: mr.set_output(idx, output) mr.device = otensor.device mr.recompute = otensor.cell.recompute - if otensor.requires_grad: - graph.finsert(mr, graph.index(otensor.cell) + 1) - else: - graph.insert(mr, graph.index(otensor.cell) + 1) + multirefs.append(mr) # remove original multiref - graph.remove(multiref) + fidx = graph.remove(multiref) if multiref.mirror is not None: graph.mirror.remove(multiref.mirror) + # insert multirefs + for ofst, multiref in enumerate(multirefs): + if ftensor.requires_grad: + graph.finsert(multiref, fidx + ofst) + else: + graph.insert(multiref, fidx + ofst) for segment in graph.select(ntype=IRSegment, flatten=False): if not segment.isfw(): continue IRAdapterGener.autoref(segment) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index f52d1520..2e3fc768 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -1,6 +1,7 @@ from typing import Callable, Dict, List, Tuple, Optional import copy import numpy as np +import sys from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -83,332 +84,354 @@ def __repr__(self): dscp = f'T{self.ftensor._id}' return dscp - # ====== inshard transformation primitives ===== # + # ====== intra-RVD transition primitives ====== # - def d2r(self, dim: int, chunks: int): + def d2r(self, dim: int, chunks: int) -> Tuple: """ - RVD Primitive: dimension to replica - collective: allgather + intra-RVD primitive D->R: allgather + + @param dim int: tensor dimension + @param chunks int: the number of chunks to transfer + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - assert layout[2+dim] % chunks == 0, f"not dividable dim: {layout[2+dim]} // {chunks}" - layout[0] = layout[0] * chunks - layout[2+dim] = layout[2+dim] // chunks - glayout = GridLayout.grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) - # set device - imat = GridLayout.transpose(self.mat, 0, 2+dim) - omat = GridLayout.transpose(glayout.mat, 2+dim, 0) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor._cell = itensor._cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(AllGatherPrim(itensors, otensors, dim)) - return glayout, prims - - def d2d(self, from_dim: int, to_dim: int, chunks: int): + assert self.D[dim] % chunks == 0, f"not dividable dim: {self.D[dim]} // {chunks}" + rvd = list(self.vec) + rvd[0], rvd[2+dim] = rvd[0] * chunks, rvd[2+dim] // chunks + + ret: List[Tuple[GridLayout, List[IRAdapterPrim]]] = [] + # collect all possible layouts + olayouts: List[GridLayout] = [GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:])] + if self.R != 1: + olayout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) + olayout.inner_transpose(0, chunks) + olayouts.append(olayout) + # generate primitives for all possible cases + for olayout in olayouts: + imat = GridLayout.transpose(self.mat, 0, 2+dim) + omat = GridLayout.transpose(olayout.mat, 2+dim, 0) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor.cell = itensor.cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(AllGatherPrim(itensors, otensors, dim)) + ret.append((olayout, prims)) + return ret + + def d2d(self, from_dim: int, to_dim: int, chunks: int) -> Tuple: """ - RVD Primitive: dimension to dimension - collective: all-to-all + intra-RVD primitive D(...,i,..)->D(..,j,...): alltoall + + @param from_dim int: source tensor axis + @param to_dim int: destination tensor axis + @param chunks int: the number of chunks to transfer + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - assert layout[2+from_dim] % chunks == 0, f"not dividable dim: {layout[2+from_dim]} // {chunks}" - layout[2+from_dim] = layout[2+from_dim] // chunks - layout[2+to_dim] = layout[2+to_dim] * chunks - glayout = GridLayout.grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) - # set device - imat = GridLayout.transpose(self.mat, 2+to_dim, 2+from_dim) - omat = GridLayout.transpose(glayout.mat, 2+from_dim, 2+to_dim) + assert self.D[from_dim] % chunks == 0, f"not dividable dim: {self.D[from_dim]} // {chunks}" + rvd = list(self.vec) + rvd[2+from_dim], rvd[2+to_dim] = rvd[2+from_dim] // chunks, rvd[2+to_dim] * chunks + layout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) + # d2d has no ambiguity on device mapping + imat = GridLayout.transpose(self.mat, 2+to_dim, 2+from_dim) + omat = GridLayout.transpose(layout.mat, 2+from_dim, 2+to_dim) for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor._cell = itensor._cell + otensor.cell = itensor.cell prims = [] for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): prims.append(AllToAllPrim(itensors, otensors, from_dim, to_dim)) - return glayout, prims + return [(layout, prims)] - def v2r(self, chunks: int): + def v2r(self, chunks: int) -> Tuple: """ - RVD Prmitive: value to replica - collective: all-reduce + intra-RVD primitive V->R: allreduce + + @param dim int: tensor dimension + @param chunks int: the number of chunks to transfer + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - assert layout[1] % chunks == 0, f"not dividable value chunks: {layout[1]} // {chunks}" - layout[1] = layout[1] // chunks - layout[0] = layout[0] * chunks - glayout = GridLayout.grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) - # set device - imat = GridLayout.transpose(self.mat, 0, 1) - omat = GridLayout.transpose(glayout.mat, 1, 0) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor._cell = itensor._cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(AllReducePrim(itensors, otensors)) - return glayout, prims - - def v2d(self, dim: int, chunks: int): + assert self.V % chunks == 0, f"not dividable value chunks: {self.V} // {chunks}" + rvd = list(self.vec) + rvd[1], rvd[0] = rvd[1] // chunks, rvd[0] * chunks + + ret: List[Tuple[GridLayout, List[IRAdapterPrim]]] = [] + # collect all possible layouts + ilayouts: List[GridLayout] = [self] + if self.V != chunks: + ilayout = GridLayout.grid(self.ftensor, r=self.R, v=self.V, dims=self.D) + for t1, t2 in zip(self.mat.flatten(), ilayout.mat.flatten()): + t2.cell = t1.cell + ilayout.inner_transpose(1, chunks) + ilayouts.append(ilayout) + olayouts: List[GridLayout] = [GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:])] + if self.R != 1: + olayout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) + olayout.inner_transpose(0, chunks) + olayouts.append(olayout) + # generate primitives for all possible cases + for ilayout in ilayouts: + for olayout in olayouts: + imat = GridLayout.transpose(ilayout.mat, 0, 1) + omat = GridLayout.transpose(olayout.mat, 1, 0) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor.cell = itensor.cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(AllReducePrim(itensors, otensors)) + ret.append((olayout, prims)) + return ret + + def v2d(self, dim: int, chunks: int) -> Tuple: + """ + intra-RVD primitive V->D: reduce-scatter + + @param dim int: tensor dimension + @param chunks int: the number of chunks to transfer + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - RVD Primitive: value to dimension - collective: reduce-scatter + assert self.V % chunks == 0, f"not dividable value chunks: {self.V} // {chunks}" + rvd = list(self.vec) + rvd[1], rvd[2+dim] = rvd[1] // chunks, rvd[2+dim] * chunks + + ret: List[Tuple[GridLayout, List[IRAdapterPrim]]] = [] + # collect all possible layouts + ilayouts = [self] + if self.V != chunks: + ilayout = GridLayout.grid(self.ftensor, r=self.R, v=self.V, dims=self.D) + for t1, t2 in zip(self.mat.flatten(), ilayout.mat.flatten()): + t2.cell = t1.cell + ilayout.inner_transpose(1, chunks) + ilayouts.append(ilayout) + # generate primitives for all possible cases + for ilayout in ilayouts: + olayout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) + imat = GridLayout.transpose(self.mat, 2+dim, 1) + omat = GridLayout.transpose(olayout.mat, 1, 2+dim) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor.cell = itensor.cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(ReduceScatterPrim(itensors, otensors, dim)) + ret.append((olayout, prims)) + return ret + + def r2d(self, dim: int, chunks: int) -> Tuple: """ - layout = list(self.vec) - assert layout[1] % chunks == 0, f"not dividable value chunks: {layout[0]} // {chunks}" - layout[1] = layout[1] // chunks - layout[2+dim] = layout[2+dim] * chunks - glayout = GridLayout.grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) - # set device - imat = GridLayout.transpose(self.mat, 2+dim, 1) - omat = GridLayout.transpose(glayout.mat, 1, 2+dim) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor._cell = itensor._cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(ReduceScatterPrim(itensors, otensors, dim)) - return glayout, prims - - def r2d(self, dim: int, chunks: int): + intra-RVD primitive V->D: schunk + + @param dim int: tensor axis + @param chunks int: the number of chunks to transfer + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - RVD Primitive: replica to dimension - collective: split + assert self.R % chunks == 0, f"not dividable replica: {self.R} // {chunks}" + rvd = list(self.vec) + rvd[0], rvd[2+dim] = rvd[0] // chunks, rvd[2+dim] * chunks + + ret: List[Tuple[GridLayout, List[IRAdapterPrim]]] = [] + # collect all possible layouts + ilayouts = [self] + # print(f'r->d({dim})[{chunks}]: ilayout-self : {[(t, t.device) for t in self.mat.flatten()]}') + if self.R != chunks: + ilayout = GridLayout.grid(self.ftensor, r=self.R, v=self.V, dims=self.D) + for t1, t2 in zip(self.mat.flatten(), ilayout.mat.flatten()): + t2.cell = t1.cell + ilayout.inner_transpose(0, chunks) + ilayouts.append(ilayout) + # print(f'r->d({dim})[{chunks}]: ilayout-transformed: {[(t, t.device) for t in ilayout.mat.flatten()]}') + # generate primitives for all possible cases + for ilayout in ilayouts: + olayout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) + imat = GridLayout.transpose(ilayout.mat, 2+dim, 0) + omat = GridLayout.transpose(olayout.mat, 0, 2+dim) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor.cell = itensor.cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(ChunkPrim(itensors, otensors, dim)) + ret.append((olayout, prims)) + return ret + + def incr(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: """ - layout = list(self.vec) - assert layout[0] % chunks == 0, f"not dividable replica: {layout[0]} // {chunks}" - layout[0] = layout[0] // chunks - layout[2+dim] = layout[2+dim] * chunks - glayout = GridLayout.grid(self.ftensor, - r=layout[0], v=layout[1], dims=layout[2:]) - # set device - imat = GridLayout.transpose(self.mat, 2+dim, 0) - omat = GridLayout.transpose(glayout.mat, 0, 2+dim) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor._cell = itensor._cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(ChunkPrim(itensors, otensors, dim)) - # ranks = tuple(t.device[0] for t in itensors) - # for idx, (itensor, otensor) in enumerate(zip(itensors, otensors)): - # prims.append(ChunkPrim(itensor, otensor, dim, ranks)) - return glayout, prims + inter-RVD primitive +R: broadcast - def incr(self, chunks: int, devices: Optional[np.ndarray] = None): - """ - RVD+ Prmitive: increase replica - collective: broadcast + @param chunks int: the number of chunks to transfer + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - layout[0] = layout[0] * chunks - glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + rvd = list(self.vec) + rvd[0] = rvd[0] * chunks + olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) # set device if devices is not None: assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) tensor.cell.device = int(devid) # setup prims imat = GridLayout.dims2last(self.mat, [0]).flatten() - omat = GridLayout.dims2last(glayout.mat, [0]).reshape(-1, chunks) + omat = GridLayout.dims2last(olayout.mat, [0]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): if chunks == 1: prims.append(MovePrim([src], dsts)) else: prims.append(BroadcastPrim([src], [src] + list(dsts))) - return glayout, prims + return [(olayout, prims),] - def decr(self, chunks: int, devices: Optional[np.ndarray] = None): + def decr(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: """ - RVD+ Prmitive: decrease replica - collective: move + inter-RVD primitive -R: move + + @param chunks int: the number of chunks to transfer + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - assert layout[0] % chunks == 0, f"not divisible replica: {layout[0]} // {chunks}" - layout[0] = layout[0] // chunks - glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + assert self.R % chunks == 0, f"not divisible replica {self.R} // {chunks}" + rvd = list(self.vec) + rvd[0] = rvd[0] // chunks + olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) # set device if devices is not None: assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) tensor.cell.device = int(devid) # setup prims imat = GridLayout.dims2last(self.mat, [0]).reshape(-1, chunks) - omat = GridLayout.dims2last(glayout.mat, [0]).flatten() + omat = GridLayout.dims2last(olayout.mat, [0]).flatten() prims = [] for srcs, dst in zip(imat, omat): prims.append(MovePrim([srcs[0]], [dst])) - return glayout, prims + return [(olayout, prims),] - def incd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None): + def incd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None) -> Tuple: """ - RVD+ Prmitive: increase dimension - collective: rdscatter + inter-RVD primitive +D: RD-Scatter + + @param chunks int: the number of chunks to transfer + @param dim int: tensor axis + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - layout[2+dim] = layout[2+dim] * chunks - glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + rvd = list(self.vec) + rvd[2+dim] = rvd[2+dim] * chunks + olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) # set device if devices is not None: assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) tensor.cell.device = int(devid) # setup prims imat = GridLayout.dims2last(self.mat, [2+dim]).flatten() - omat = GridLayout.dims2last(glayout.mat, [2+dim]).reshape(-1, chunks) + omat = GridLayout.dims2last(olayout.mat, [2+dim]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): prims.append(RDScatterPrim([src], dsts, dim=dim)) - return glayout, prims + return olayout, prims - def decd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None): + def decd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None) -> Tuple: """ - RVD+ Prmitive: increase dimension - collective: rdgather + inter-RVD primitive +D: RD-Gather + + @param chunks int: the number of chunks to transfer + @param dim int: tensor axis + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - assert layout[2+dim] % chunks == 0, f"not divisible dim: {self.D[dim]} % {chunks} != 0" - layout[2+dim] = layout[2+dim] // chunks - glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + assert self.D[dim] % chunks == 0, f"not divisible dim: {self.D[dim]} % {chunks} != 0" + rvd = list(self.vec) + rvd[2+dim] = rvd[2+dim] // chunks + olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) # set device if devices is not None: assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) tensor.cell.device = int(devid) # setup prims imat = GridLayout.dims2last(self.mat, [2+dim]).reshape(-1, chunks) - omat = GridLayout.dims2last(glayout.mat, [2+dim]).flatten() + omat = GridLayout.dims2last(olayout.mat, [2+dim]).flatten() prims = [] for srcs, dst in zip(imat, omat): prims.append(RDGatherPrim(srcs, [dst], dim=dim)) - return glayout, prims + return [(olayout, prims),] - def incv(self, chunks: int, devices: Optional[np.ndarray] = None): + def incv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: """ - RVD+ Primitive: increase value partition - collective: rvscatter + inter-RVD primitive +V: RV-Scatter + + @param chunks int: the number of chunks to transfer + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - layout[1] = layout[1] * chunks - glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + rvd = list(self.vec) + rvd[1] = rvd[1] * chunks + olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) # set device if devices is not None: assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) tensor.cell.device = int(devid) # setup prims imat = GridLayout.dims2last(self.mat, [1]).flatten() - omat = GridLayout.dims2last(glayout.mat, [1]).reshape(-1, chunks) + omat = GridLayout.dims2last(olayout.mat, [1]).reshape(-1, chunks) prims = [] for src, dsts in zip(imat, omat): prims.append(RVScatterPrim([src], dsts)) - return glayout, prims + return [(olayout, prims),] - def decv(self, chunks: int, devices: Optional[np.ndarray] = None): + def decv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: """ - RVD+ Primitive: decrease value partition - collective: rvgather + inter-RVD primitive -V: RV-Gather + + @param chunks int: the number of chunks to transfer + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ - layout = list(self.vec) - assert layout[1] % chunks == 0, f"not divisable value split: {self.V} % {chunks} != 0" - layout[1] = layout[1] * chunks - glayout = GridLayout.grid(self.ftensor, layout[0], layout[1], layout[2:]) + assert self.V % chunks == 0, f"not divisable value split: {self.V} % {chunks} != 0" + rvd = list(self.vec) + rvd[1] = rvd[1] // chunks + olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) # set device if devices is not None: assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(glayout.mat.flatten(), devices.flatten()): + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) tensor.cell.device = int(devid) # setup prims imat = GridLayout.dims2last(self.mat, [1]).reshape(-1, chunks) - omat = GridLayout.dims2last(glayout.mat, [1]).flatten() + omat = GridLayout.dims2last(olayout.mat, [1]).flatten() prims = [] for srcs, dst in zip(imat, omat): prims.append(RVGatherPrim(srcs, [dst])) - return glayout, prims + return [(olayout, prims),] # ================ solution ============= # - def path(self, dst) -> Tuple: - """ - Find a path from self to destination GridLayout using - primitivies. This implementation uses search order of - R -> V -> S. - - Args: - dst: GridLayout - auto_replace: bool - If true, the consumer operator may be replaced - to match the device assignment. - - Return: - paths: List[GridLayout] - the search path from source GridLayout (self) - to destination GridLayout (self) - comm_prims: List[IRAdapterPrim] - communication primitives for translation - """ - def step(ilayout: GridLayout, dec_idx: int, inc_idx: int, chunks: int) -> GridLayout: - if dec_idx >= 2 and inc_idx == 0: # d2r - return ilayout.d2r(dec_idx-2, chunks) - if dec_idx >= 2 and inc_idx >= 2: # d2d - return ilayout.d2d(dec_idx-2, inc_idx-2, chunks) - if dec_idx == 1 and inc_idx == 0: # v2r - return ilayout.v2r(chunks) - if dec_idx == 1 and inc_idx >= 2: # v2d - return ilayout.v2d(inc_idx-2, chunks) - if dec_idx == 0 and inc_idx >= 2: # r2d - return ilayout.r2d(inc_idx-2, chunks) - raise RuntimeError("Cannot find primitive. Report as a bug") - - comm_prims = [] - paths: List[GridLayout] = [self] - dst: GridLayout = dst - while paths[-1].vec != dst.vec: - src: GridLayout = paths[-1] - inc_idx, dec_idx = None, None - for idx, (schunk, dchunk) in enumerate(zip(src.vec, dst.vec)): - if schunk != dchunk: - # print(f'src: {src.vec}, dst: {dst.vec}') - if schunk < dchunk: - inc_idx = idx # src should increase chunks on idx-dim - need_chunks = dchunk // schunk if dchunk % schunk == 0 else dchunk - for dec_idx in range(inc_idx+1, self.ndims): - # print(f'{dec_idx}/{self.ndims}') - if src.vec[dec_idx] > dst.vec[dec_idx]: - if src.vec[dec_idx] % dst.vec[dec_idx] != 0: - available_chunks = dst.vec[dec_idx] - else: - available_chunks = src.vec[dec_idx] // dst.vec[dec_idx] - chunks = min(available_chunks, need_chunks) - break - else: - raise RuntimeError("Cannot find feassible dimension. Report this as a bug.") - else: - dec_idx = idx - need_chunks = schunk // dchunk if schunk % dchunk == 0 else schunk - for inc_idx in range(dec_idx+1, self.ndims): - if src.vec[inc_idx] < dst.vec[inc_idx]: - if dst.vec[inc_idx] % src.vec[inc_idx] != 0: - available_chunks = dst.vec[inc_idx] - else: - available_chunks = dst.vec[inc_idx] // src.vec[inc_idx] - chunks = min(available_chunks, need_chunks) - break - else: - raise RuntimeError("Cannot find feassible dimension. Report this as a bug.") - # print(chunks, need_chunks) - olayout, oprims = step(src, dec_idx, inc_idx, chunks) - paths.append(olayout) - comm_prims += oprims - break - return paths, comm_prims - def print_dev_tensors(self): """ print each device hold tensors. @@ -426,6 +449,22 @@ def print_dev_tensors(self): for tensor in devices[dev]: print(f'\t{tensor.extra_repr()}') + def inner_transpose(self, dim: int, chunks: int): + """ + transpose ordering of tensor within a dimension. + """ + assert 0 <= dim and dim < len(self._mats.shape) + assert self.vec[dim] % chunks == 0 + ori_shape = list(self.vec) + new_shape = list(self.vec) + new_shape.insert(dim, self.vec[dim] // chunks) + new_shape[dim+1] = chunks + self._mats = self._mats.reshape(new_shape) + axes = list(range(len(new_shape))) + axes[dim], axes[dim+1] = axes[dim+1], axes[dim] + self._mats = self._mats.transpose(axes) + self._mats = self._mats.reshape(ori_shape) + @staticmethod def transpose(mat: np.ndarray, dim0: int, dim1: int): """ @@ -484,11 +523,12 @@ def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int], devices: Optio """ partition a ftensor using grid layout of """ + dims = tuple(dims) def dummy_assign(tensor: IRSubTensor, devid: int): tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) tensor.cell.device = devid - mats = np.empty([r, v] + dims, dtype=IRSubTensor) + mats = np.empty((r, v) + dims, dtype=IRSubTensor) all_subtensors = [] def iter_idx(dims: List[int]) -> Tuple[int]: @@ -499,7 +539,7 @@ def iter_idx(dims: List[int]) -> Tuple[int]: for indices in iter_idx(dims[1:]): yield (i,) + indices # generate tensor for each index - for indices in iter_idx([v,]+dims): + for indices in iter_idx((v,)+dims): valmap = ValueMap((indices[0], v)) indmap = [] shape = [] @@ -606,7 +646,9 @@ class PathFinder: _cached_inter_paths: Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]] = {} @staticmethod - def intra_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: + def intra_path(ftensor: IRFullTensor, + ilayout: GridLayout, olayout: GridLayout, + cost_fn: Optional[Callable] = None, allow_fallback: bool = True) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: """ Get primitive path of transforming ilayout into olayout. ilayout has the same device set with olayout @@ -616,6 +658,7 @@ def intra_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, @param olayout GridLayout: output tensor layout @param cost_fn Optional[Callable]: cost function of each primitive. Default (None) will use transmission volume as metrics + @param allow_fallback bool: allow to use a fixed backup plan to make sure correct device mapping. (default True) @return layouts List[GridLayout]: each transformation. @return prims List[IRAdapterPrim]: the primitives to perform transformation. @@ -625,7 +668,7 @@ def intra_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, key = (shape, ilayout.ndevs) src = (ilayout.R, ilayout.V) + tuple(ilayout.D) dst = (olayout.R, olayout.V) + tuple(olayout.D) - if src == dst: return [], [] + if src == dst: return [ilayout], [] # get paths using dijkstra algorithm or cached if key in PathFinder._cached_intra_paths and src in PathFinder._cached_intra_paths[key]: @@ -654,40 +697,64 @@ def intra_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, if cost[idx] < min_cost: min_cost = idx visit = idx - if visit is None: break + if visit is None: break # for remaining states that cannot reach for neighbor in np.where(edges[visit] != np.inf)[0]: - if neighbor in visited: continue new_cost = cost[visit] + edges[visit, neighbor] if cost[neighbor] == np.inf or new_cost < cost[neighbor]: cost[neighbor] = new_cost paths[neighbor] = paths[visit] + [neighbor] - cost[neighbor] = min(cost[neighbor], cost[visit] + edges[visit, neighbor]) unvisited.remove(visit) visited.add(visit) PathFinder._cached_intra_paths[key][src] = paths # print for debug - for idx, path in enumerate(paths): - print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") + # for idx, path in enumerate(paths): + # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") # get layout nodes = PathFinder._cached_intra_nodes[key] - path = paths[nodes.index(dst)] + path: List[int] = paths[nodes.index(dst)] + rvds: List[Tuple[int]] = [nodes[idx] for idx in path] assert len(path) > 0, f"Un-reachable src RVD ({src}) -> dst RVD ({dst})" + # print(f'path: {rvds}') + + # search for correct device mapping + success, layouts, all_prims = PathFinder.intra_dev_align(ftensor, rvds[1:], [ilayout], [], olayout) + if not success: + ptensors_str = 'ptensors: ' + for ptensor in ilayout.mat.flatten(): + ptensors_str += " " + repr(ptensor) + f' dev{ptensor.device}' + ctensors_str = 'ctensors: ' + for ctensor in olayout.mat.flatten(): + ctensors_str += " " + repr(ctensor) + f' dev{ctensor.device}' + error_msg = ( + f"Fail to align intra-RVD devices. {ftensor}\n" + # f"{ptensors_str}\n" + # f"{ctensors_str}\n" + f"Switch to a fixed plan: ilayout -> FullReplica -> olayout" + ) + color, default = '\033[33m' , '\033[0m' + print(color+error_msg+default, file=sys.stderr) + if allow_fallback: + # switch to a fixed plan ilayout -> R(n)V(1)D(1*) -> olayout + rlayout = GridLayout.grid(ftensor, r=ilayout.ndevs, v=1, dims=tuple(1 for _ in range(ilayout.ndims-2))) + for t1, t2 in zip(ilayout.mat.flatten(), rlayout.mat.flatten()): + t2.cell = t1.cell + # find left + left: List[int] = paths[nodes.index(tuple(rlayout.vec))] + left = [nodes[idx] for idx in left] + lsuccess, llayouts, lprims = PathFinder.intra_dev_align(ftensor, left[1:], [ilayout], [], rlayout) + assert lsuccess, f"Switch fail to generate left-half intra-RVD plans for all-replica" + # find right + rlayouts, rprims = PathFinder.intra_path(ftensor, rlayout, olayout, cost_fn, allow_fallback=False) + layouts = llayouts + rlayouts + all_prims = lprims + rprims + else: + # allow_fallback is False only for generating right-half intra-RVD + assert False, f"Switch fail to generate right-half intra-RVD plans from all-replica" - layouts = [ilayout] - all_prims = [] - curr_rvd = src - for hop in path[1:]: - hop_rvd = nodes[hop] - ret, layout, prims = PathFinder.intra_transform(ftensor, curr_rvd, hop_rvd, layouts[-1]) - assert ret, "Internal Error." - layouts.append(layout) - all_prims += prims - curr_rvd = hop_rvd return layouts, all_prims - @staticmethod def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: """ @@ -739,12 +806,10 @@ def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, visit = idx if visit is None: break for neighbor in np.where(edges[visit] != np.inf)[0]: - if neighbor in visited: continue new_cost = cost[visit] + edges[visit, neighbor] if cost[neighbor] == np.inf or new_cost < cost[neighbor]: cost[neighbor] = new_cost paths[neighbor] = paths[visit] + [neighbor] - cost[neighbor] = min(cost[neighbor], cost[visit] + edges[visit, neighbor]) unvisited.remove(visit) visited.add(visit) PathFinder._cached_inter_paths[key][src] = paths @@ -764,7 +829,7 @@ def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, # print('result device map:', list(cdevs.flatten())) for hop in cpaths[:-1][::-1]: hop_rvd = nodes[hop][1:] - curr_devs = PathFinder.intra_devmap(curr_node, hop_rvd, curr_devs) + curr_devs = PathFinder.inter_devmap(curr_node, hop_rvd, curr_devs) curr_node = hop_rvd consumer_entry_devs = curr_devs # print('calculated consumer device map: ', list(cdevs.flatten())) @@ -775,8 +840,9 @@ def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, use_inter_step = side != nodes[hop][0] hop_rvd = nodes[hop][1:] if not use_inter_step: - ret, layout, prims = PathFinder.intra_transform(ftensor, curr_rvd, hop_rvd, layouts[-1]) - assert ret, "Internal Error" + ret, layout_prims = PathFinder.intra_transition(ftensor, curr_rvd, hop_rvd, layouts[-1]) + assert ret, "Internal Error: intra-RVD transition failed" + layout, prims = layout_prims[0] # the first only is enough for inter-rvd else: ret, layout, prims = PathFinder.inter_transform(ftensor, curr_rvd, hop_rvd, layouts[-1], consumer_entry_devs) layouts.append(layout) @@ -786,7 +852,8 @@ def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, return layouts, all_prims @staticmethod - def intra_transform(ftensor: IRFullTensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout: Optional[GridLayout] = None) -> Tuple[GridLayout, List[IRAdapterPrim]]: + def intra_transition(ftensor: IRFullTensor, src_rvd: TRVD, dst_rvd: TRVD, + ilayout: Optional[GridLayout] = None) -> Tuple[bool, List[Tuple[GridLayout, List[IRAdapterPrim]]]]: """ Get output layout and transform primitives from a source rvd layout to dst_rvd layout, @@ -795,7 +862,7 @@ def intra_transform(ftensor: IRFullTensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout @param dst_rvd Tuple[int] @param ilayout Optional[GridLayout] - @return ret bool: True if there is a primitive performed + @return ret bool: True if trainsition is successful. Otherwise False. @return layout Optonal[GridLayout]: the RVD layout if ilayout is not None @return prims Optional[List[IRAdapterPrim]]: the prmitives in transformation """ @@ -803,30 +870,76 @@ def intra_transform(ftensor: IRFullTensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout assert src_rvd == tuple(ilayout.vec) inc_dims, dec_dims = GridLayout.changed_dims(src_rvd, dst_rvd) if len(inc_dims) != 1 or len(dec_dims) != 1: - return False, None, None + return False, [(None, [])] inc_idx, dec_idx = inc_dims[0], dec_dims[0] if src_rvd[dec_idx] % dst_rvd[dec_idx] != 0: - return False, None, None + return False, [(None, [])] if inc_idx == 1: - return False, None, None + return False, [(None, [])] src = ilayout if ilayout is not None else GridLayout.grid(ftensor, src_rvd[0], src_rvd[1], list(src_rvd[2:])) chunks = src_rvd[dec_idx] // dst_rvd[dec_idx] if dec_idx >= 2 and inc_idx == 0: # d2r - olayout, prims = src.d2r(dec_idx-2, chunks) + ret = src.d2r(dec_idx-2, chunks) elif dec_idx >= 2 and inc_idx >= 2: # d2d - olayout, prims = src.d2d(dec_idx-2, inc_idx-2, chunks) + ret = src.d2d(dec_idx-2, inc_idx-2, chunks) elif dec_idx == 1 and inc_idx == 0: # v2r - olayout, prims = src.v2r(chunks) + ret = src.v2r(chunks) elif dec_idx == 1 and inc_idx >= 2: # v2d - olayout, prims = src.v2d(inc_idx-2, chunks) + ret = src.v2d(inc_idx-2, chunks) elif dec_idx == 0 and inc_idx >= 2: # r2d - olayout, prims = src.r2d(inc_idx-2, chunks) + ret = src.r2d(inc_idx-2, chunks) else: raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") - return True, (olayout if ilayout is not None else None), prims + return True, ret + + @staticmethod + def intra_dev_align(ftensor: IRFullTensor, remain_states: List[Tuple[int]], + ilayouts: List[GridLayout], all_prims: List[IRAdapterPrim], + olayout: GridLayout) -> Tuple[bool, List[GridLayout], List[IRAdapterPrim]]: + """ + Align devices for intra-RVD + + @param ftensor IRFullTensor + @param remain_states List[TRVD]: RVD representations + @param ilayouts List[GridLayout]: searched layouts + @param all_prims List[IRAdapterPrim]: searched primitives + @param olayout GridLayout: target layout with correct device mapping + + @return success bool: True if found device, else False. + @return layouts List[GridLayout]: the searched layouts with device match + @return primitives List[IRAdapterPrim]: the correspoinding primitives + """ + ilayout = ilayouts[-1] + if len(remain_states) == 0: + # print(f'transformed tensors: {[(t, t.device) for t in ilayout.mat.flatten()]}') + # print(f'destination tensors: {[(t, t.device) for t in olayout.mat.flatten()]}') + # check device mapping + otensors: List[IRSubTensor] = olayout.mat.flatten().tolist() + for itensor in ilayout.mat.flatten(): + dev_match = False + for idx in range(len(otensors)): + otensor = otensors[idx] + if otensor == itensor and set(otensor.device) == set(itensor.device): + otensors.pop(idx) + dev_match = True + break + if not dev_match: return False, [], [] + return True, ilayouts, all_prims + else: + success, layout_prims = PathFinder.intra_transition( + ftensor, (ilayout.R, ilayout.V) + ilayout.D, remain_states[0], ilayout) + assert success, "Internal Error at intra-RVD transition" + for (hop_layout, hop_prims) in layout_prims: + # print(f'hop layout: {[(t, t.device) for t in hop_layout.mat.flatten()]}') + # print(f'dst layout: {[(t, t.device) for t in olayout.mat.flatten()]}') + ret, ret_layouts, ret_prims = PathFinder.intra_dev_align( + ftensor, remain_states[1:], ilayouts + [hop_layout], all_prims + hop_prims, olayout) + if ret: + return True, ret_layouts, ret_prims + return False, [], [] @staticmethod - def intra_devmap(src_rvd: TRVD, dst_rvd: TRVD, src_devs: np.ndarray): + def inter_devmap(src_rvd: TRVD, dst_rvd: TRVD, src_devs: np.ndarray): """ Infer device from source rvd to destination rvd """ @@ -878,11 +991,11 @@ def inter_transform(ftensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout: Optional[Gri return False, None, None chunks = dst_rvd[inc_idx] // src_rvd[inc_idx] if inc_idx == 0: - olayout, prims = src.incr(chunks, dst_devs) + olayout, prims = src.incr(chunks, dst_devs)[0] elif inc_idx == 1: - olayout, prims = src.incv(chunks, dst_devs) + olayout, prims = src.incv(chunks, dst_devs)[0] elif inc_idx > 1: - olayout, prims = src.incd(chunks, inc_idx-2, dst_devs) + olayout, prims = src.incd(chunks, inc_idx-2, dst_devs)[0] else: raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") else: @@ -890,11 +1003,11 @@ def inter_transform(ftensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout: Optional[Gri return False, None, None chunks = src_rvd[dec_idx] // dst_rvd[dec_idx] if dec_idx == 0: - olayout, prims = src.decr(chunks, dst_devs) + olayout, prims = src.decr(chunks, dst_devs)[0] elif dec_idx == 1: - olayout, prims = src.decv(chunks, dst_devs) + olayout, prims = src.decv(chunks, dst_devs)[0] elif dec_idx > 1: - olayout, prims = src.decd(chunks, dec_idx-2, dst_devs) + olayout, prims = src.decd(chunks, dec_idx-2, dst_devs)[0] else: raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") return True, (olayout if ilayout is not None else None), prims @@ -917,8 +1030,9 @@ def init_intra_graph(ftensor: IRFullTensor, ndevs: int, cost_fn: Optional[Callab for j in range(len(nodes)): if i == j: continue src, dst = nodes[i], nodes[j] - ret, _, prims = PathFinder.intra_transform(ftensor, src, dst) + ret, layout_and_prims = PathFinder.intra_transition(ftensor, src, dst) if not ret: continue + prims = layout_and_prims[0][1] edges[i, j] = cost_fn(prims[0]) return nodes, edges diff --git a/tests/adapter/test_intra_rvd.py b/tests/adapter/test_intra_rvd.py new file mode 100644 index 00000000..8ea0663f --- /dev/null +++ b/tests/adapter/test_intra_rvd.py @@ -0,0 +1,197 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + tests/adapter/test_intra_rvd.py +""" +from typing import List, Tuple +import cube +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.ir.tensor import IRFullTensor +from cube.graph.graph import IRGraph +from cube.graph.gener.layout import GridLayout +from cube.graph.function.dimops import IRDimops +from cube.algorithm.generics import GenericDistAlgo +import torch +import numpy as np + +cube.init() + + +class RVDSplit(GenericDistAlgo): + + def __init__(self, node: IRDimops): + super().__init__(node) + + def satisfy(self, in_rvd: Tuple[int], out_rvd: Tuple[int]): + return True + + def instantiate(self, in_rvd: Tuple[int], out_rvd: Tuple[int]) -> List[IRFwOperation]: + assert np.prod(np.array(in_rvd, dtype=int)) == np.prod(np.array(out_rvd, dtype=int)), \ + f"tensor number not match: {in_rvd}, {out_rvd}" + assert tuple(in_rvd)[2:] == tuple(out_rvd)[2:], f"input / output shape should be same" + + node: IRDimops = self.node + iftensor: IRFullTensor = node.input(0).parent + itensors = GridLayout.grid(iftensor, r=in_rvd[0], v=in_rvd[1], dims=in_rvd[2:]).mat.flatten() + oftensor: IRFullTensor = node.output(0).parent + otensors = GridLayout.grid(oftensor, r=out_rvd[0], v=out_rvd[1], dims=out_rvd[2:]).mat.flatten() + subnodes = [] + for itensor, otensor in zip(itensors, otensors): + subnode = node.new([itensor, 2], [otensor]) + subnodes.append(subnode) + return subnodes + + +class TestModel(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.empty(1024, 1024)) + + def forward(self, x): + x = torch.matmul(x, self.param) + # residual = x + x = x * 2 + x = x * 2 + x = x * 2 + x = x * 2 + x = x * 2 + x = x * 2 + x = x * 2 + # x = x + residual + x = torch.sum(x) + return x + + +def _ntp(graph, node: IRDimops, idx: int, dim: int, devs: List[int]): + algo = node.algorithms('dim') + nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert nodes is not None + for devid, node in zip(devs, nodes): + graph.assign(node, devid) + return nodes + + +def _tp(graph, node: IRFwOperation, in_rvd: Tuple[int], out_rvd: Tuple[int], devs: List[int]): + algo = RVDSplit(node) + nodes = graph.partition(node, algo, in_rvd=in_rvd, out_rvd=out_rvd) + assert nodes is not None + for devid, node in zip(devs, nodes): + graph.assign(node, devid) + return nodes + + +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + nodes = graph.replicate(node, times=len(devs)) + assert nodes is not None + for devid, node in zip(devs, nodes): + graph.assign(node, devid) + return nodes + + +def test_multiref_intra_rvd(): + + model = TestModel() + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([1024,1024],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) + + def policy(graph: IRGraph, resource): + print(graph.extra_repr()) + devs = list(range(resource.ngpus)) + + for ftensor in graph.full_tensors(): + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) + + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs) + + for node in graph.select(ntype=IRFwOperation): + if node.name == 'multiref': continue + if node.name == 'mul': + _ntp(graph, node, idx=0, dim=0, devs=devs) + elif node.name == 'add': + _ntp(graph, node, idx=0, dim=1, devs=devs) + else: + _replica(graph, node, devs) + print(graph.extra_repr()) + return graph + + model = cube.SemanticModel(model) + @cube.compile(model, dataloader, PAS=policy) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + for _ in range(4): + train_iter(model, dataloader) + + +def test_intra_rvd(): + + model = TestModel() + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([1024,1024],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) + + def policy(graph: IRGraph, resource): + assert resource.ngpus == 4 + print(graph.extra_repr()) + devs = list(range(resource.ngpus)) + + # for ftensor in graph.full_tensors(): + # if len(graph.consumers(ftensor)) > 1: + # graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) + + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs) + + for idx, node in enumerate(graph.select(name='mul')): + if idx == 0: # out: R(4)V(1)D(1,1) -> in: R(1)V(1)D(4,1): schunk + _tp(graph, node, in_rvd=(1,1,4,1), out_rvd=(1,1,4,1), devs=devs) + elif idx == 1: # out: R(1)V(1)D(4,1) -> in: R(1)V(1)D(1,4): all-to-all wil FAIL. expected!! + _tp(graph, node, in_rvd=(1,1,1,4), out_rvd=(1,1,1,4), devs=devs) + elif idx == 2: # out: R(1)V(1)D(1,4) -> in: R(1)V(1)D(2,2): schunk + _tp(graph, node, in_rvd=(1,1,2,2), out_rvd=(1,1,2,2), devs=devs) + elif idx == 3: # out: R(1)V(1)D(2,2) -> in: R(4)V(1)D(1,1): all-gather + all-gather + _tp(graph, node, in_rvd=(4,1,1,1), out_rvd=(1,4,1,1), devs=devs) + elif idx == 4: # out: R(1)V(4)D(1,1) -> in: R(1)V(1)D(4,1): reduce-scatter + _tp(graph, node, in_rvd=(1,1,4,1), out_rvd=(1,1,4,1), devs=devs) + elif idx == 5: # out: R(1)V(1)D(4,1) -> in R(4)V(1)D(1,1): all-gather + _tp(graph, node, in_rvd=(4,1,1,1), out_rvd=(1,4,1,1), devs=devs) + elif idx == 6: # out: R(1)V(4)D(1,1) -> in R(1)V(1)D(2,2): reduce-scatter + reduce-scatter + _tp(graph, node, in_rvd=(1,1,2,2), out_rvd=(1,1,2,2), devs=devs) + else: + assert False + + for node in graph.select(ntype=IRFwOperation): + if len(node.device) == 0: + _replica(graph, node, devs) + + print(graph.extra_repr()) + return graph + + model = cube.SemanticModel(model) + @cube.compile(model, dataloader, PAS=policy) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + model = model.get_gen_module() + + for _ in range(4): + train_iter(model, dataloader) + + + +if __name__ == '__main__': + + # test_multiref_intra_rvd() + test_intra_rvd() diff --git a/tests/adapter/test_rvd.py b/tests/adapter/test_rvd.py deleted file mode 100644 index ed9132fa..00000000 --- a/tests/adapter/test_rvd.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - tests/adapter/test_rvd.py -""" -from typing import List -import cube -from cube.graph.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation -import torch - -cube.init() - - -class TestModel(torch.nn.Module): - - def __init__(self) -> None: - super().__init__() - self.param = torch.nn.Parameter(torch.empty(1024, 1024)) - - def forward(self, x): - x = self.param * x - residual = x - x = x * 2 - x = x + residual - x = torch.sum(x) - return x - - -def _tp(graph, node: IRFwOperation, idx, dim, devs: List[int]): - algo = node.algorithms('dim') - nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert nodes is not None - for devid, node in zip(devs, nodes): - graph.assign(node, devid) - return nodes - - -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - nodes = graph.replicate(node, times=len(devs)) - assert nodes is not None - for devid, node in zip(devs, nodes): - graph.assign(node, devid) - return nodes - - -def test_multiref_intra_rvd(): - - model = TestModel() - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([1024,1024],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) - - def policy(graph: IRGraph, resource): - print(graph.extra_repr()) - devs = list(range(resource.ngpus)) - - for ftensor in graph.full_tensors(): - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) - - for dl in graph.select(ntype=IRDataOperation): - _replica(graph, dl, devs) - - for node in graph.select(ntype=IRFwOperation): - if node.name == 'mul': - _tp(graph, node, idx=0, dim=0, devs=devs) - elif node.name == 'add': - _tp(graph, node, idx=0, dim=1, devs=devs) - else: - _replica(graph, node, devs) - print(graph.extra_repr()) - return graph - - model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - model = model.get_gen_module() - - for _ in range(4): - train_iter(model, dataloader) - - -if __name__ == '__main__': - - test_multiref_intra_rvd() From f6d2bba2c9dae7794ff0dd0840c969f126856349 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Fri, 6 Jan 2023 17:33:39 +0800 Subject: [PATCH 1208/1892] update profiler interface --- cube/profiler/database.py | 68 +++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index b9e06d04..079baaf7 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -29,7 +29,7 @@ class CompProfiler: @staticmethod def profile(func: Callable, shapes: Shapes, dtypes: DTypes, warmup_sec: float = 2, prof_times: int = 50, - **kwargs) -> Tuple[float, float, int, int]: + **kwargs) -> Tuple[float, float, int, Tuple[int]]: """ Profile a function @@ -42,8 +42,8 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, @return fw_span float: the time in milliseconds for forward time @return bw_span float: the time in milliseconds for backward time - @return infer_memory int: the peak memory in bytes after inference of the function - @return train_memory int: the peak memory in bytes after forward with autograd enabled + @return infer_mem int: the peak memory in bytes after inference of the function + @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ assert len(shapes) == len(dtypes), \ f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" @@ -89,17 +89,17 @@ def run_step(func, tensors, kwargs, backward: bool): mtoc = torch.cuda.max_memory_allocated() # in bytes infer_memory = mtoc - mtic - train_memory = 0 + train_mem_info = [] used_tensor = set() # ref torch/utils/checkpoint.py/_checkpoint_without_reentrant def pack_hook(x): - nonlocal train_memory, used_tensor + nonlocal train_mem_info, used_tensor if x.storage().data_ptr() not in used_tensor: used_tensor.add(x.storage().data_ptr()) byte_size = x.element_size() for dim in list(x.size()): byte_size = byte_size * dim - train_memory = train_memory + byte_size + train_mem_info.append(byte_size) return x def unpack_hook(x): @@ -135,7 +135,7 @@ def unpack_hook(x): fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds bw_span = fwbw_span - fw_span - return fw_span, bw_span, infer_memory, train_memory + return fw_span, bw_span, infer_memory, tuple(train_mem_info) class ProfileDataBase: @@ -199,12 +199,12 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @param node IRFwOperation: node of IRGraph @param device int: the device that the node will execute on - @param input_byte_size int: byte size of input tensors - @param param_byte_size int: byte size of param tensors + @return in_mem_info Tuple[int]: byte sizes of input tensors + @return param_mem_info Tuple[int]: byte sizes of param tensors @return fw_span float: the forward span time in milliseconds @return bw_span float: the backward span time in milliseconds @return infer_memory int: the peak memory in bytes after inference of the function - @return train_memory int: the peak memory in bytes after forward with autograd enabled + @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) @@ -215,46 +215,46 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): orig_device = torch.cuda.current_device() torch.cuda.set_device(device) - input_byte_size, param_byte_size = 0, 0 + in_mem_info, param_mem_info = [], [] for t in node.inputs(): if t.is_param(): - param_byte_size = param_byte_size + t.byte_size() + param_mem_info.append(t.byte_size()) else: - input_byte_size = input_byte_size + t.byte_size() + in_mem_info.append(t.byte_size()) # run profiling - fw_span, bw_span, infer_memory, train_memory = \ + fw_span, bw_span, infer_memory, train_mem_info = \ CompProfiler.profile(fn, shapes, dtypes, **kwargs) # log to database key = self._serialize(node) - self.insert(node.signature, key, input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory) + self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info) print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " - f"=> in mem {input_byte_size} | param mem: {param_byte_size} | fw: {round(fw_span, 2)} ms | " - f"bw: {round(bw_span, 2)} ms | infer mem: {infer_memory} | train mem: {train_memory}") + f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | fw: {round(fw_span, 2)} ms | " + f"bw: {round(bw_span, 2)} ms | infer mem: {infer_memory} | train mem info: {train_mem_info}") if isinstance(device, int): torch.cuda.set_device(orig_device) - return input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory + return tuple(in_mem_info), tuple(param_mem_info), fw_span, bw_span, infer_memory, train_mem_info - def insert(self, name: str, key: str, input_byte_size: int, param_byte_size: int, - fw_span: float, bw_span: float, infer_memory: int, train_memory: int): + def insert(self, name: str, key: str, in_mem_info: Tuple[int], param_mem_info: Tuple[int], + fw_span: float, bw_span: float, infer_memory: int, train_mem_info: Tuple[int]): """ log the span of a function name with key @param name str: the function signature @param key str: the encoded shapes and dtypes of node inputs - @param input_byte_size int: byte size of input tensors - @param param_byte_size int: byte size of param tensors + @param in_mem_info Tuple[int]: byte sizes of input tensors + @param param_mem_info Tuple[int]: byte sizes of param tensors @param fw_span float: the forward span time in milliseconds @param bw_span float: the backward span time in milliseconds @param infer_memory int: the peak memory in bytes after inference of the function - @param train_memory int: the peak memory in bytes after forward with autograd enabled + @param train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = (input_byte_size, param_byte_size, fw_span, bw_span, infer_memory, train_memory) + self._data[name][key] = (in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info) def exist(self, node: IRFwOperation) -> bool: """ @@ -271,18 +271,18 @@ def exist(self, node: IRFwOperation) -> bool: return False return True - def query(self, node: IRFwOperation) -> Tuple[int, int, float, float, int, int]: + def query(self, node: IRFwOperation) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int]]: """! Get the performance number of a node in IRGraph @param node IRFwOperation: node in IRGraph - @return input_byte_size int: byte size of input tensors - @return param_byte_size int: byte size of param tensors + @return in_mem_info Tuple[int]: byte sizes of input tensors + @return param_mem_info Tuple[int]: byte sizes of param tensors @return fw_span float: the forward span time in milliseconds @return bw_span float: the backward span time in milliseconds @return infer_memory int: the peak memory in bytes after inference of the function - @return train_memory int: the peak memory in bytes after forward with autograd enabled + @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ key = self._serialize(node) if node.signature not in self._data: @@ -291,7 +291,7 @@ def query(self, node: IRFwOperation) -> Tuple[int, int, float, float, int, int]: return None return self._data[node.signature][key] - def query_func(self, signature, shapes, dtypes) -> Tuple[int, int, float, float, int, int]: + def query_func(self, signature, shapes, dtypes) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int]]: """ Get performance number of given name (signature), shapes and dtypes @@ -299,12 +299,12 @@ def query_func(self, signature, shapes, dtypes) -> Tuple[int, int, float, float, @param shapes Tuple[Tuple[int]]: the shape of each input tensor @param dtypes Tuple[torch.dtype]: the dtype of each tensor - @return input_byte_size int: byte size of input tensors - @return param_byte_size int: byte size of param tensors + @return in_mem_info Tuple[int]: byte sizes of input tensors + @return param_mem_info Tuple[int]: byte sizes of param tensors @return fw_span float: the forward span time in milliseconds @return bw_span float: the backward span time in milliseconds @return infer_memory int: the peak memory in bytes after inference of the function - @return train_memory int: the peak memory in bytes after forward with autograd enabled + @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ key = self._serialize(shapes, dtypes) if signature not in self._data: @@ -392,7 +392,7 @@ def __repr__(self) -> str: for signature in self._data: for key in self._data[signature]: shapes, dtypes = self._deserialize(key) - input_byte_size, param_byte_size, fw_span, bw_span, infer_mem, train_mem = self._data[signature][key] - data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, in mem {input_byte_size} bytes, param mem {param_byte_size} bytes, fw span: {fw_span} ms, bw span: {bw_span} ms, infer mem {infer_mem} bytes, train mem {train_mem} bytes') + in_mem_info, param_mem_info, fw_span, bw_span, infer_mem, train_mem = self._data[signature][key] + data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, in mem {in_mem_info} bytes, param mem {param_mem_info} bytes, fw span: {fw_span} ms, bw span: {bw_span} ms, infer mem {infer_mem} bytes, train mem {train_mem} bytes') data = '\n'.join(data) return data From 027e11791f87d7ec5ad56e5ae1133ca27d00dfc1 Mon Sep 17 00:00:00 2001 From: zyeric <619828575@qq.com> Date: Mon, 9 Jan 2023 13:54:07 +0800 Subject: [PATCH 1209/1892] updt gpt config --- examples/nlp/gpt/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 59341e64..1bf54a7c 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -11,7 +11,7 @@ class Config: attention_heads: int = 16 attn_hidden_dim: int = 1024 ffn_hidden_dim: int = 4096 - num_embeddings: int = 50432 + num_embeddings: int = 51200 seqlen: int = 1024 dropout: float = 0.2 attn_dropout: float = 0.2 @@ -29,8 +29,8 @@ def build_gpt_config(name: str) -> Config: embed_dim, layers, attention_heads = 2560, 32, 32 elif name == '6.7B': embed_dim, layers, attention_heads = 4096, 32, 32 - elif name == '13B': - embed_dim, layers, attention_heads = 5120, 48, 40 + elif name == '15B': + embed_dim, layers, attention_heads = 5120, 48, 32 elif name == '39B': embed_dim, layers, attention_heads = 8192, 48, 64 elif name == '175B': @@ -231,4 +231,4 @@ def __iter__(self): return self def __next__(self): - return self.samples[0] \ No newline at end of file + return self.samples[0] From 51b897403ec05be0afc5c9dbb7214dbdca24a3e2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 12 Jan 2023 09:42:56 +0800 Subject: [PATCH 1210/1892] fix full replica device assignment --- cube/graph/gener/concurrent.py | 8 ++++++++ cube/graph/gener/layout.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 31dfc4ba..78ea0fda 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -93,6 +93,14 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], allow_reorder=False): + """ + Generate forward and backward adapter for concurrent produced tensors and consumed tensors. + + @param fptensors List[IRSubTensor]: forward produced tensors + @param fctensors List[IRSubTensor]: forward consumed tensors + @param bptensors List[IRSubTensor]: backward produced tensors + @param bctensors List[IRSubTensor]: backward consumed tensors + """ ftensor = fptensors[0].parent # producer grid layout ilayout = GridLayout.togrid(ftensor, fptensors) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 2e3fc768..1fb3f3a1 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -729,8 +729,8 @@ def intra_path(ftensor: IRFullTensor, ctensors_str += " " + repr(ctensor) + f' dev{ctensor.device}' error_msg = ( f"Fail to align intra-RVD devices. {ftensor}\n" - # f"{ptensors_str}\n" - # f"{ctensors_str}\n" + f"{ptensors_str}\n" + f"{ctensors_str}\n" f"Switch to a fixed plan: ilayout -> FullReplica -> olayout" ) color, default = '\033[33m' , '\033[0m' @@ -738,8 +738,12 @@ def intra_path(ftensor: IRFullTensor, if allow_fallback: # switch to a fixed plan ilayout -> R(n)V(1)D(1*) -> olayout rlayout = GridLayout.grid(ftensor, r=ilayout.ndevs, v=1, dims=tuple(1 for _ in range(ilayout.ndims-2))) - for t1, t2 in zip(ilayout.mat.flatten(), rlayout.mat.flatten()): - t2.cell = t1.cell + # assign devices + itensors = ilayout.mat.flatten() + idevs = np.array([t.device[0] for t in itensors]) + itensors = [itensors[idx] for idx in np.argsort(idevs)] + for it, rt in zip(itensors, rlayout.mat.flatten()): + rt.cell = it.cell # find left left: List[int] = paths[nodes.index(tuple(rlayout.vec))] left = [nodes[idx] for idx in left] From 615637a06225aeb09c69986e4fdcf4878d4ea58c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 12 Jan 2023 10:42:45 +0800 Subject: [PATCH 1211/1892] align device with output layout for intra-rvd --- cube/graph/gener/layout.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 1fb3f3a1..05de7da9 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -739,11 +739,13 @@ def intra_path(ftensor: IRFullTensor, # switch to a fixed plan ilayout -> R(n)V(1)D(1*) -> olayout rlayout = GridLayout.grid(ftensor, r=ilayout.ndevs, v=1, dims=tuple(1 for _ in range(ilayout.ndims-2))) # assign devices - itensors = ilayout.mat.flatten() - idevs = np.array([t.device[0] for t in itensors]) - itensors = [itensors[idx] for idx in np.argsort(idevs)] - for it, rt in zip(itensors, rlayout.mat.flatten()): - rt.cell = it.cell + # itensors = ilayout.mat.flatten() + # idevs = np.array([t.device[0] for t in itensors]) + # itensors = [itensors[idx] for idx in np.argsort(idevs)] + # for it, rt in zip(itensors, rlayout.mat.flatten()): + # rt.cell = it.cell + for rt, ot in zip(rlayout.mat.flatten(), olayout.mat.flatten()): + rt.cell = ot.cell # find left left: List[int] = paths[nodes.index(tuple(rlayout.vec))] left = [nodes[idx] for idx in left] From ad0b826fd94ab0ce37281b8dadba68fd156fdf92 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 13 Jan 2023 11:33:43 +0800 Subject: [PATCH 1212/1892] auto placement for rvd --- cube/graph/gener/concurrent.py | 21 +++++++++++---- cube/graph/gener/gen.py | 25 +++--------------- cube/graph/gener/layout.py | 48 +++++++++++++++++++++++----------- cube/graph/gener/utils.py | 19 ++++++++++++++ 4 files changed, 72 insertions(+), 41 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 78ea0fda..3a865642 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -13,6 +13,7 @@ from cube.ir.adapter.prim import BroadcastPrim from cube.graph.gener.layout import GridLayout, PathFinder +from cube.graph.gener.utils import DummyInputOuput from cube.flags import CompileFlag import warnings @@ -51,7 +52,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], if (not CompileFlag.disable_intra_rvd) and inshard and len(pdevs) > 1: # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) try: - fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) except Exception as e: fadapter = None color, default = '\033[33m' , '\033[0m' @@ -92,7 +93,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], @staticmethod def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], - allow_reorder=False): + allow_reassign=False): """ Generate forward and backward adapter for concurrent produced tensors and consumed tensors. @@ -100,21 +101,31 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], @param fctensors List[IRSubTensor]: forward consumed tensors @param bptensors List[IRSubTensor]: backward produced tensors @param bctensors List[IRSubTensor]: backward consumed tensors + @param allow_reassign bool: Allow to change placement of forward consumer tensors to better align deivce placement """ + allow_reassign = allow_reassign and \ + all(not isinstance(t.cell, DummyInputOuput) for t in fptensors + fctensors + bptensors + bctensors) and \ + all(t.cell.name != 'multiref' for t in fctensors) + # assert allow_reassign ftensor = fptensors[0].parent # producer grid layout ilayout = GridLayout.togrid(ftensor, fptensors) - # reorder ctensors to match with ptensors devs = [ptensor.device for ptensor in ilayout.mat.flatten()] + # re-order ctensors to match with ptensors ctensors = [None] * len(devs) for ctensor in fctensors: idx = devs.index(ctensor.device) ctensors[idx] = ctensor assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" - # consumer grid layout olayout = GridLayout.togrid(ftensor, ctensors) # find path - paths, fprims = PathFinder.intra_path(ftensor, ilayout, olayout) + paths, fprims = PathFinder.intra_path(ftensor, ilayout, olayout, allow_misalign=allow_reassign) + # re-assign tensors + if allow_reassign: + for t, ot in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): + ot.cell.device = t.device + if len(bptensors) != 0: + ot.cell.mirror.device = t.device fadapter = IRAdapter(fptensors, fctensors) fadapter.prims = fprims diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 4c98f08a..b072e730 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -19,23 +19,6 @@ DeviceID = int -class DummyInputOuput(IRFwOperation): - - def __init__(self, tensor: IRSubTensor, device: int, - is_input=False, is_output=False, - name='dummy'): - super().__init__(name, name, - 1 if is_input else 0, - 1 if is_output else 0 - ) - assert (is_input and not is_output) or (is_output and not is_input) - if is_input: - self.set_input(0, tensor) - if is_output: - self.set_output(0, tensor) - self.device = device - - def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) -> List[IRFwOperation]: """ Create dummy operators segment inputs and outputs. @@ -56,7 +39,7 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) devices = [consumer.device for consumer in segment.consumers(tensor.parent)][::-1] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = DummyInputOuput(tensor, 0, is_output=True, name=f'segment{segment.cid}_input') + fwop = utils.DummyInputOuput(tensor, 0, is_output=True, name=f'segment{segment.cid}_input') for devid in devices: fop = fwop.replicate() fop.device = devid @@ -73,7 +56,7 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) devices = [producer.device for producer in segment.producers(tensor.parent)] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = DummyInputOuput(tensor, 0, is_input=True, name=f'segment{segment.cid}_output') + fwop = utils.DummyInputOuput(tensor, 0, is_input=True, name=f'segment{segment.cid}_output') for devid in devices: fop = fwop.replicate() fop.device = devid @@ -106,10 +89,10 @@ def expand_devices(tensors: List[IRSubTensor], continue for devid in tensor.device: if producer: - fwop = DummyInputOuput(tensor, devid, is_output=True, name=tensor.cell.name) + fwop = utils.DummyInputOuput(tensor, devid, is_output=True, name=tensor.cell.name) dtensors.append(fwop.output(0)) elif consumer: - fwop = DummyInputOuput(tensor, devid, is_input=True, name=tensor.cell.name) + fwop = utils.DummyInputOuput(tensor, devid, is_input=True, name=tensor.cell.name) dtensors.append(fwop.input(0)) else: raise ValueError("At least one of producer or consumer") diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 05de7da9..acaa73e9 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -24,6 +24,21 @@ TRVD = Tuple[int, ...] +def tensor_vd_repr(t: IRSubTensor) -> str: + """ + Tensor index-value partition representation + """ + assert isinstance(t, IRSubTensor), f"expect IRSubTensor" + identifier = 't' if not t.is_grad() else 'g' + dchunks, dpos = [], [] + for dim in range(t.ndims): + dchunks.append(t.parent.shape[dim] // t.shape[dim]) + dpos.append(t.indmap[dim][0] // t.shape[dim]) + indmap = ','.join(f'{idx}/{nchunks}' for idx, nchunks in zip(dpos, dchunks)) + dscp = f'{identifier}{t.tid}-{t.device}(p{t.parent.tid}, shape={t.shape}, D({indmap}), V({t.valmap[0]}/{t.valmap[1]})' + return dscp + + class GridLayout: """ This class assumes a full-tensor can only be @@ -648,7 +663,9 @@ class PathFinder: @staticmethod def intra_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, - cost_fn: Optional[Callable] = None, allow_fallback: bool = True) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: + cost_fn: Optional[Callable] = None, + allow_fallback: bool = True, + allow_misalign: bool = False) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: """ Get primitive path of transforming ilayout into olayout. ilayout has the same device set with olayout @@ -659,6 +676,7 @@ def intra_path(ftensor: IRFullTensor, @param cost_fn Optional[Callable]: cost function of each primitive. Default (None) will use transmission volume as metrics @param allow_fallback bool: allow to use a fixed backup plan to make sure correct device mapping. (default True) + @param allow_misalign bool: allow to have a different device mapping. (default False) @return layouts List[GridLayout]: each transformation. @return prims List[IRAdapterPrim]: the primitives to perform transformation. @@ -720,17 +738,22 @@ def intra_path(ftensor: IRFullTensor, # search for correct device mapping success, layouts, all_prims = PathFinder.intra_dev_align(ftensor, rvds[1:], [ilayout], [], olayout) - if not success: - ptensors_str = 'ptensors: ' - for ptensor in ilayout.mat.flatten(): - ptensors_str += " " + repr(ptensor) + f' dev{ptensor.device}' - ctensors_str = 'ctensors: ' - for ctensor in olayout.mat.flatten(): - ctensors_str += " " + repr(ctensor) + f' dev{ctensor.device}' + if not success and allow_misalign: + layouts, all_prims = [], [] + curr_rvd, curr_layout = rvds[0], ilayout + for hop_rvd in rvds[1:]: + ret, layout_prims = PathFinder.intra_transition(ftensor, curr_rvd, hop_rvd, curr_layout) + assert ret, "Internal Error: intra-transition failed" + layout, prims = layout_prims[0] + layouts.append(layout) + all_prims += prims + curr_rvd, curr_layout = hop_rvd, layout + + elif not success: error_msg = ( f"Fail to align intra-RVD devices. {ftensor}\n" - f"{ptensors_str}\n" - f"{ctensors_str}\n" + f"ptensors:\n\t" + "\n\t".join(tensor_vd_repr(ptensor) for ptensor in ilayout.mat.flatten()) + "\n" + f"ctensors:\n\t" + "\n\t".join(tensor_vd_repr(ctensor) for ctensor in olayout.mat.flatten()) + "\n" f"Switch to a fixed plan: ilayout -> FullReplica -> olayout" ) color, default = '\033[33m' , '\033[0m' @@ -739,11 +762,6 @@ def intra_path(ftensor: IRFullTensor, # switch to a fixed plan ilayout -> R(n)V(1)D(1*) -> olayout rlayout = GridLayout.grid(ftensor, r=ilayout.ndevs, v=1, dims=tuple(1 for _ in range(ilayout.ndims-2))) # assign devices - # itensors = ilayout.mat.flatten() - # idevs = np.array([t.device[0] for t in itensors]) - # itensors = [itensors[idx] for idx in np.argsort(idevs)] - # for it, rt in zip(itensors, rlayout.mat.flatten()): - # rt.cell = it.cell for rt, ot in zip(rlayout.mat.flatten(), olayout.mat.flatten()): rt.cell = ot.cell # find left diff --git a/cube/graph/gener/utils.py b/cube/graph/gener/utils.py index 1850c22f..05a7ba69 100644 --- a/cube/graph/gener/utils.py +++ b/cube/graph/gener/utils.py @@ -8,6 +8,25 @@ from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap +class DummyInputOuput(IRFwOperation): + + def __init__(self, tensor: IRSubTensor, device: int, + is_input=False, is_output=False, + name='dummy'): + super().__init__(name, name, + 1 if is_input else 0, + 1 if is_output else 0 + ) + assert (is_input and not is_output) or (is_output and not is_input) + if is_input: + self.set_input(0, tensor) + if is_output: + self.set_output(0, tensor) + self.device = device + + def __repr__(self) -> str: + return f'DummyInputOutput-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' + def convert_add_to_valmap(graph: IRGraph, add_node: IRFwOperation): """ From fb1543d3e3cbd416993796cad2120c7ae929743e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 13 Jan 2023 15:03:37 +0800 Subject: [PATCH 1213/1892] fix multiple re-assign --- cube/graph/gener/concurrent.py | 23 ++++++++++------------- cube/graph/gener/layout.py | 17 ++++++++++------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 3a865642..7e9dc9a9 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -50,7 +50,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # case 1: sharing device (in-shard) inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) if (not CompileFlag.disable_intra_rvd) and inshard and len(pdevs) > 1: - # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reorder=True) + fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) try: fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) except Exception as e: @@ -103,11 +103,15 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], @param bctensors List[IRSubTensor]: backward consumed tensors @param allow_reassign bool: Allow to change placement of forward consumer tensors to better align deivce placement """ + ftensor = fptensors[0].parent allow_reassign = allow_reassign and \ all(not isinstance(t.cell, DummyInputOuput) for t in fptensors + fctensors + bptensors + bctensors) and \ all(t.cell.name != 'multiref' for t in fctensors) - # assert allow_reassign - ftensor = fptensors[0].parent + # each consumer can only be re-assigned once + for t in fctensors[0].cell.inputs(): + if isinstance(t, IRSubTensor): + allow_reassign = allow_reassign and (t.parent == ftensor) + break # producer grid layout ilayout = GridLayout.togrid(ftensor, fptensors) devs = [ptensor.device for ptensor in ilayout.mat.flatten()] @@ -119,9 +123,9 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" olayout = GridLayout.togrid(ftensor, ctensors) # find path - paths, fprims = PathFinder.intra_path(ftensor, ilayout, olayout, allow_misalign=allow_reassign) + align, paths, fprims = PathFinder.intra_path(ftensor, ilayout, olayout, allow_misalign=allow_reassign) # re-assign tensors - if allow_reassign: + if (not align) and allow_reassign: for t, ot in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): ot.cell.device = t.device if len(bptensors) != 0: @@ -143,14 +147,7 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], ilayout = GridLayout.togrid(grad, ptensors) olayout = GridLayout.togrid(grad, bctensors) # paths, bprims = ilayout.path(olayout) - paths, bprims = PathFinder.intra_path(grad, ilayout, olayout) - # check the device order - same_device = True - for t in paths[-1].mat.flatten(): - if not any(t == c and set(t.device) == set(c.device) for c in bctensors): - same_device = False - break - assert same_device, "backward device not match" + _, paths, bprims = PathFinder.intra_path(grad, ilayout, olayout) # generate backward adapter badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index acaa73e9..873ff36b 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -665,7 +665,7 @@ def intra_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None, allow_fallback: bool = True, - allow_misalign: bool = False) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: + allow_misalign: bool = False) -> Tuple[bool, List[GridLayout], List[IRAdapterPrim]]: """ Get primitive path of transforming ilayout into olayout. ilayout has the same device set with olayout @@ -678,6 +678,7 @@ def intra_path(ftensor: IRFullTensor, @param allow_fallback bool: allow to use a fixed backup plan to make sure correct device mapping. (default True) @param allow_misalign bool: allow to have a different device mapping. (default False) + @return align bool: whether correctly align the device placement @return layouts List[GridLayout]: each transformation. @return prims List[IRAdapterPrim]: the primitives to perform transformation. """ @@ -686,7 +687,8 @@ def intra_path(ftensor: IRFullTensor, key = (shape, ilayout.ndevs) src = (ilayout.R, ilayout.V) + tuple(ilayout.D) dst = (olayout.R, olayout.V) + tuple(olayout.D) - if src == dst: return [ilayout], [] + # TODO: FIXME: may not align + if src == dst: return True, [ilayout], [] # get paths using dijkstra algorithm or cached if key in PathFinder._cached_intra_paths and src in PathFinder._cached_intra_paths[key]: @@ -737,8 +739,8 @@ def intra_path(ftensor: IRFullTensor, # print(f'path: {rvds}') # search for correct device mapping - success, layouts, all_prims = PathFinder.intra_dev_align(ftensor, rvds[1:], [ilayout], [], olayout) - if not success and allow_misalign: + align, layouts, all_prims = PathFinder.intra_dev_align(ftensor, rvds[1:], [ilayout], [], olayout) + if not align and allow_misalign: layouts, all_prims = [], [] curr_rvd, curr_layout = rvds[0], ilayout for hop_rvd in rvds[1:]: @@ -749,7 +751,7 @@ def intra_path(ftensor: IRFullTensor, all_prims += prims curr_rvd, curr_layout = hop_rvd, layout - elif not success: + elif not align: error_msg = ( f"Fail to align intra-RVD devices. {ftensor}\n" f"ptensors:\n\t" + "\n\t".join(tensor_vd_repr(ptensor) for ptensor in ilayout.mat.flatten()) + "\n" @@ -770,14 +772,15 @@ def intra_path(ftensor: IRFullTensor, lsuccess, llayouts, lprims = PathFinder.intra_dev_align(ftensor, left[1:], [ilayout], [], rlayout) assert lsuccess, f"Switch fail to generate left-half intra-RVD plans for all-replica" # find right - rlayouts, rprims = PathFinder.intra_path(ftensor, rlayout, olayout, cost_fn, allow_fallback=False) + _, rlayouts, rprims = PathFinder.intra_path(ftensor, rlayout, olayout, cost_fn, allow_fallback=False) layouts = llayouts + rlayouts all_prims = lprims + rprims else: # allow_fallback is False only for generating right-half intra-RVD assert False, f"Switch fail to generate right-half intra-RVD plans from all-replica" + align = True - return layouts, all_prims + return align, layouts, all_prims @staticmethod def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: From e82130617a67a3f05572731e4743bc865361dc9d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 13 Jan 2023 15:07:09 +0800 Subject: [PATCH 1214/1892] adapter repr --- cube/ir/adapter/adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index 97d0a146..6e593c59 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -152,9 +152,9 @@ def __repr__(self): return f'Adapter-{self._id}{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' def extra_repr(self) -> str: - dscp = f'Adapter-{self._id}[{self.device}](inputs={self.inputs()}, outputs={self.outputs()})\n' + dscp = f'Adapter-{self._id}{self.device}(\n\tinputs={self.inputs()},\n\toutputs={self.outputs()}\n):' for prim in self.prims: - dscp += repr(prim) + '\n' + dscp += '\n\t' + repr(prim) return dscp From 50e622ba7e88ecd33ab4bd09df692a139bae2e3d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 15 Jan 2023 23:22:32 +0800 Subject: [PATCH 1215/1892] fix intra-rvd mis-align bug --- cube/graph/gener/concurrent.py | 2 +- cube/graph/gener/layout.py | 52 ++++++++++++++++++++++------------ 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 7e9dc9a9..62e046ad 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -50,7 +50,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # case 1: sharing device (in-shard) inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) if (not CompileFlag.disable_intra_rvd) and inshard and len(pdevs) > 1: - fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) + # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) try: fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) except Exception as e: diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 873ff36b..9272c041 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -444,8 +444,29 @@ def decv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: prims.append(RVGatherPrim(srcs, [dst])) return [(olayout, prims),] + def align(self, layout) -> bool: + """ + Check whether the layout is same with self. + + The same means 1) sub-tenosrs are same 2) device are aligned - # ================ solution ============= # + @param layout GridLayout + + @return same bool: + """ + if not isinstance(layout, GridLayout): + return False + tensors: List[IRSubTensor] = list(self.mat.flatten()) + for t in layout.mat.flatten(): + dev_match = False + for idx in range(len(tensors)): + t2 = tensors[idx] + if t == t2 and set(t.device) == set(t2.device): + tensors.pop(idx) + dev_match = True + break + if not dev_match: return False + return True def print_dev_tensors(self): """ @@ -687,8 +708,16 @@ def intra_path(ftensor: IRFullTensor, key = (shape, ilayout.ndevs) src = (ilayout.R, ilayout.V) + tuple(ilayout.D) dst = (olayout.R, olayout.V) + tuple(olayout.D) - # TODO: FIXME: may not align - if src == dst: return True, [ilayout], [] + + # cases for same source and destination RVD + if src == dst: + if ilayout.align(olayout): + return True, [ilayout], [] + else: + if allow_misalign: + return False, [ilayout], [] + else: + assert False, "Same source and destination rvd but got mis-aligned devices" # get paths using dijkstra algorithm or cached if key in PathFinder._cached_intra_paths and src in PathFinder._cached_intra_paths[key]: @@ -938,27 +967,14 @@ def intra_dev_align(ftensor: IRFullTensor, remain_states: List[Tuple[int]], """ ilayout = ilayouts[-1] if len(remain_states) == 0: - # print(f'transformed tensors: {[(t, t.device) for t in ilayout.mat.flatten()]}') - # print(f'destination tensors: {[(t, t.device) for t in olayout.mat.flatten()]}') - # check device mapping - otensors: List[IRSubTensor] = olayout.mat.flatten().tolist() - for itensor in ilayout.mat.flatten(): - dev_match = False - for idx in range(len(otensors)): - otensor = otensors[idx] - if otensor == itensor and set(otensor.device) == set(itensor.device): - otensors.pop(idx) - dev_match = True - break - if not dev_match: return False, [], [] + if not ilayout.align(olayout): + return False, [], [] return True, ilayouts, all_prims else: success, layout_prims = PathFinder.intra_transition( ftensor, (ilayout.R, ilayout.V) + ilayout.D, remain_states[0], ilayout) assert success, "Internal Error at intra-RVD transition" for (hop_layout, hop_prims) in layout_prims: - # print(f'hop layout: {[(t, t.device) for t in hop_layout.mat.flatten()]}') - # print(f'dst layout: {[(t, t.device) for t in olayout.mat.flatten()]}') ret, ret_layouts, ret_prims = PathFinder.intra_dev_align( ftensor, remain_states[1:], ilayouts + [hop_layout], all_prims + hop_prims, olayout) if ret: From f50ac5a999a256337eac17b19cb9afb0caf6ea80 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 15 Jan 2023 23:27:47 +0800 Subject: [PATCH 1216/1892] update error log --- cube/graph/gener/concurrent.py | 2 +- cube/graph/gener/layout.py | 16 +--------------- cube/graph/gener/utils.py | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 62e046ad..38302211 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -58,7 +58,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], color, default = '\033[33m' , '\033[0m' print( f"{color}========== Fail to use intra-RVD ==========\n" - f"full tensor: {fptensors[0].parent}\n" + f"full tensor: {fptensors[0].parent} | is grad: {fptensors[0].parent.is_grad()}\n" f"Reason: {str(e)}\n" f"Switch to general P2P communication.\n" f"===========================================\n{default}", file=sys.stderr diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 9272c041..8eebfc2c 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -19,26 +19,12 @@ from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim +from cube.graph.gener.utils import tensor_vd_repr TShape = Tuple[int, ...] TRVD = Tuple[int, ...] -def tensor_vd_repr(t: IRSubTensor) -> str: - """ - Tensor index-value partition representation - """ - assert isinstance(t, IRSubTensor), f"expect IRSubTensor" - identifier = 't' if not t.is_grad() else 'g' - dchunks, dpos = [], [] - for dim in range(t.ndims): - dchunks.append(t.parent.shape[dim] // t.shape[dim]) - dpos.append(t.indmap[dim][0] // t.shape[dim]) - indmap = ','.join(f'{idx}/{nchunks}' for idx, nchunks in zip(dpos, dchunks)) - dscp = f'{identifier}{t.tid}-{t.device}(p{t.parent.tid}, shape={t.shape}, D({indmap}), V({t.valmap[0]}/{t.valmap[1]})' - return dscp - - class GridLayout: """ This class assumes a full-tensor can only be diff --git a/cube/graph/gener/utils.py b/cube/graph/gener/utils.py index 05a7ba69..cd11495b 100644 --- a/cube/graph/gener/utils.py +++ b/cube/graph/gener/utils.py @@ -28,6 +28,21 @@ def __repr__(self) -> str: return f'DummyInputOutput-{self.device}(inputs={self.inputs()}, outputs={self.outputs()})' +def tensor_vd_repr(t: IRSubTensor) -> str: + """ + Tensor index-value partition representation + """ + assert isinstance(t, IRSubTensor), f"expect IRSubTensor" + identifier = 't' if not t.is_grad() else 'g' + dchunks, dpos = [], [] + for dim in range(t.ndims): + dchunks.append(t.parent.shape[dim] // t.shape[dim]) + dpos.append(t.indmap[dim][0] // t.shape[dim]) + indmap = ','.join(f'{idx}/{nchunks}' for idx, nchunks in zip(dpos, dchunks)) + dscp = f'{identifier}{t.tid}-{t.device}(p{t.parent.tid}, shape={t.shape}, D({indmap}), V({t.valmap[0]}/{t.valmap[1]})' + return dscp + + def convert_add_to_valmap(graph: IRGraph, add_node: IRFwOperation): """ Remove add node by replacing with tensor valmap From d56af35539d9ac529499c11c0d6fe1a333f78fcd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 16 Jan 2023 11:47:02 +0800 Subject: [PATCH 1217/1892] add intra-rvd device mis-align log --- cube/graph/gener/layout.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py index 8eebfc2c..69e388e9 100644 --- a/cube/graph/gener/layout.py +++ b/cube/graph/gener/layout.py @@ -769,6 +769,7 @@ def intra_path(ftensor: IRFullTensor, elif not align: error_msg = ( f"Fail to align intra-RVD devices. {ftensor}\n" + f"Path: {' -> '.join(str(rvd) for rvd in rvds)}\n" f"ptensors:\n\t" + "\n\t".join(tensor_vd_repr(ptensor) for ptensor in ilayout.mat.flatten()) + "\n" f"ctensors:\n\t" + "\n\t".join(tensor_vd_repr(ctensor) for ctensor in olayout.mat.flatten()) + "\n" f"Switch to a fixed plan: ilayout -> FullReplica -> olayout" From f0a3c899a423e986125a8b30bbf2a43913edd282 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 Jan 2023 02:12:02 +0000 Subject: [PATCH 1218/1892] Merged PR 1442: Intra/Inter-RVD generation with device placement advisor Add device placement suggestion for intra-RVD; Code refactor; Full-case tests of intra-RVD and inter-RVD. --- cube/graph/gener/concurrent.py | 85 +-- cube/graph/gener/gen.py | 13 +- cube/graph/gener/layout.py | 1157 ------------------------------ cube/graph/gener/rvd/__init__.py | 0 cube/graph/gener/rvd/inter.py | 474 ++++++++++++ cube/graph/gener/rvd/intra.py | 596 +++++++++++++++ cube/graph/gener/rvd/layout.py | 500 +++++++++++++ tests/adapter/test_inter_rvd.py | 88 +++ tests/adapter/test_intra_rvd.py | 441 +++++++----- 9 files changed, 1973 insertions(+), 1381 deletions(-) delete mode 100644 cube/graph/gener/layout.py create mode 100644 cube/graph/gener/rvd/__init__.py create mode 100644 cube/graph/gener/rvd/inter.py create mode 100644 cube/graph/gener/rvd/intra.py create mode 100644 cube/graph/gener/rvd/layout.py create mode 100644 tests/adapter/test_inter_rvd.py diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index 38302211..efd504b2 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -1,7 +1,7 @@ """ Concurrent producer / consumer Adapter Generator """ -from typing import List, Optional, Dict, Tuple +from typing import List, Optional, Dict, Tuple, Callable import copy import numpy as np import sys @@ -12,12 +12,15 @@ from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim from cube.ir.adapter.prim import BroadcastPrim -from cube.graph.gener.layout import GridLayout, PathFinder +from cube.graph.gener.rvd.layout import RVDLayout +from cube.graph.gener.rvd.intra import IntraPathFinder +from cube.graph.gener.rvd.inter import InterPathFinder from cube.graph.gener.utils import DummyInputOuput from cube.flags import CompileFlag import warnings + if CompileFlag.disable_intra_rvd: warnings.warn('Detected disabling intra-RVD collective generation, which may have big impact on performance.') if CompileFlag.disable_inter_rvd: @@ -30,7 +33,8 @@ class ConcurrentGener: @staticmethod def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], - bptensors: List[IRSubTensor], bctensors: List[IRSubTensor]) -> Optional[IRAdapter]: + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], + cost_fn: Optional[Callable] = None) -> Optional[IRAdapter]: """ Generate forward adapter and backward adapter @@ -38,21 +42,21 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], @param fctensors List[IRSubTensor]: forward consumer tensors @param bptensors List[IRSubTensor]: backward producer tensors @param bctensors List[IRSubTensor]: backward consumer tensors + @param cost_fn Optional[Callable]: takes in an IRAdapterPrim and outputs a cost in float - @return fadapter Optional[IRAdapter]: forward adapter - None indicate no adapter required. + @return fadapter IRAdapter: forward adapter """ pdevs = tuple(t.device[0] for t in fptensors) cdevs = tuple(t.device[0] for t in fctensors) fadapter: IRAdapter = None - # case 1: sharing device (in-shard) + # case 1: sharing device (intra-rvd) inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) if (not CompileFlag.disable_intra_rvd) and inshard and len(pdevs) > 1: # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) try: - fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) + fadapter = ConcurrentGener.gen_intra_rvd(fptensors, fctensors, bptensors, bctensors, cost_fn) except Exception as e: fadapter = None color, default = '\033[33m' , '\033[0m' @@ -64,11 +68,11 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], f"===========================================\n{default}", file=sys.stderr ) - # Case 2: sperating device (cross-shard) + # Case 2: sperating device (inter-rvd) if (not CompileFlag.disable_inter_rvd) and len(set(pdevs).intersection(cdevs)) == 0: # fadapter = ConcurrentGener.gen_cross_shard(fptensors, fctensors, bptensors, bctensors) try: - fadapter = ConcurrentGener.gen_cross_shard(fptensors, fctensors, bptensors, bctensors) + fadapter = ConcurrentGener.gen_inter_rvd(fptensors, fctensors, bptensors, bctensors, cost_fn) except Exception as e: fadapter = None color, default = '\033[33m' , '\033[0m' @@ -79,6 +83,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], f"Switch to general P2P communication.\n" f"===========================================\n{default}", file=sys.stderr ) + # Case 3: General cases # warnings.warn('The adapter is generated using P2P communication') if fadapter is None: @@ -91,9 +96,9 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], return fadapter @staticmethod - def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], - bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], - allow_reassign=False): + def gen_intra_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], + cost_fn: Optional[Callable] = None) -> IRAdapter: """ Generate forward and backward adapter for concurrent produced tensors and consumed tensors. @@ -101,35 +106,23 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], @param fctensors List[IRSubTensor]: forward consumed tensors @param bptensors List[IRSubTensor]: backward produced tensors @param bctensors List[IRSubTensor]: backward consumed tensors - @param allow_reassign bool: Allow to change placement of forward consumer tensors to better align deivce placement + @param cost_fn Optional[Callable]: takes in an IRAdapterPrim and outputs a cost in float + + @return adapter IRAdapter: forward IRAdapter with backward (if has) in its .mirror attribute. """ ftensor = fptensors[0].parent - allow_reassign = allow_reassign and \ - all(not isinstance(t.cell, DummyInputOuput) for t in fptensors + fctensors + bptensors + bctensors) and \ - all(t.cell.name != 'multiref' for t in fctensors) - # each consumer can only be re-assigned once - for t in fctensors[0].cell.inputs(): - if isinstance(t, IRSubTensor): - allow_reassign = allow_reassign and (t.parent == ftensor) - break # producer grid layout - ilayout = GridLayout.togrid(ftensor, fptensors) + ilayout = RVDLayout.togrid(ftensor, fptensors) devs = [ptensor.device for ptensor in ilayout.mat.flatten()] - # re-order ctensors to match with ptensors + # re-order ctensors to match with placement of ptensors ctensors = [None] * len(devs) for ctensor in fctensors: idx = devs.index(ctensor.device) ctensors[idx] = ctensor assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" - olayout = GridLayout.togrid(ftensor, ctensors) - # find path - align, paths, fprims = PathFinder.intra_path(ftensor, ilayout, olayout, allow_misalign=allow_reassign) - # re-assign tensors - if (not align) and allow_reassign: - for t, ot in zip(paths[-1].mat.flatten(), olayout.mat.flatten()): - ot.cell.device = t.device - if len(bptensors) != 0: - ot.cell.mirror.device = t.device + olayout = RVDLayout.togrid(ftensor, ctensors) + # get forward primitives + fprims = IntraPathFinder.path(ilayout, olayout, cost_fn) fadapter = IRAdapter(fptensors, fctensors) fadapter.prims = fprims @@ -144,10 +137,10 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], idx = devs.index(bptensor.device) assert ptensors[idx] is None, "same device of different tensors" ptensors[idx] = bptensor - ilayout = GridLayout.togrid(grad, ptensors) - olayout = GridLayout.togrid(grad, bctensors) + ilayout = RVDLayout.togrid(grad, ptensors) + olayout = RVDLayout.togrid(grad, bctensors) # paths, bprims = ilayout.path(olayout) - _, paths, bprims = PathFinder.intra_path(grad, ilayout, olayout) + bprims = IntraPathFinder.path(ilayout, olayout, cost_fn) # generate backward adapter badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims @@ -156,33 +149,33 @@ def gen_in_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], return fadapter @staticmethod - def gen_cross_shard(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], - bptensors: List[IRSubTensor], bctensors: List[IRSubTensor],) -> IRAdapter: + def gen_inter_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], + cost_fn: Optional[Callable] = None) -> IRAdapter: """ + Generate communication adapters for inter-RVD scenarios. This assumes ptensors and ctensors can be represented by RVD layout. - - pdevices: devices of ptensors - cdevices: devices of ctensors @param fptensors List[IRSubTensor]: produced tensors @param fctensors List[IRSubTensor]: consumed tensors @param bptensors List[IRSubTensor]: produced tensors @param bctensors List[IRSubTensor]: consumed tensors + @param cost_fn Optional[Callable]: takes in an IRAdapterPrim and outputs a cost in float @return fadapter IRAdapter """ ftensor = fptensors[0].parent - ilayout = GridLayout.togrid(ftensor, fptensors) - olayout = GridLayout.togrid(ftensor, fctensors) - fpaths, fprims = PathFinder.inter_path(ftensor, ilayout, olayout) + ilayout = RVDLayout.togrid(ftensor, fptensors) + olayout = RVDLayout.togrid(ftensor, fctensors) + fprims = InterPathFinder.path(ilayout, olayout, cost_fn) fadapter = IRAdapter(fptensors, fctensors) fadapter.prims = fprims grad: IRFullTensor = ftensor.grad if grad is not None and (len(bptensors) != 0 or len(bctensors) != 0): - ilayout = GridLayout.togrid(grad, bptensors) - olayout = GridLayout.togrid(grad, bctensors) - bpaths, bprims = PathFinder.inter_path(grad, ilayout, olayout) + ilayout = RVDLayout.togrid(grad, bptensors) + olayout = RVDLayout.togrid(grad, bctensors) + bprims = InterPathFinder.path(ilayout, olayout, cost_fn) badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims IRAdapter.make_pair(fadapter, badapter) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index b072e730..daddb5ce 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Callable import numpy as np import itertools @@ -102,12 +102,15 @@ def expand_devices(tensors: List[IRSubTensor], class IRAdapterGener: @staticmethod - def gen(graph: IRGraph) -> IRGraph: + def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: """ Generate tensor adapter for both activations and weights Note weight reducers are always append to the last. @param graph IRGraph: the graph without adapter + @param cost_fn Optional[Callable]: takes an IRAdapterPrim and outputs a cost in float. + default to be None, which will use communication volume. + @return graph IRGraph: the graph with adapter inserted """ # reorder producer and consumer ordering @@ -117,7 +120,7 @@ def gen(graph: IRGraph) -> IRGraph: # automatic transform multiref graph = IRAdapterGener.autoref(graph) # generate adapters for activation - graph = IRAdapterGener.gen_activation(graph) + graph = IRAdapterGener.gen_activation(graph, cost_fn=cost_fn) # generate weight reducer graph = IRAdapterGener.gen_weight(graph) # fuse consecutive non-differentiable adapters into one @@ -226,7 +229,7 @@ def gen_weight(graph: IRGraph) -> IRGraph: return graph @staticmethod - def gen_activation(graph: IRSegment, allow_recompute: bool = True) -> IRSegment: + def gen_activation(graph: IRSegment, allow_recompute: bool = True, cost_fn: Optional[Callable] = None) -> IRSegment: """! Generate adapter for activation tensors. The forward/backward adapter is inserted before the first consumers of its full tensor. @@ -299,7 +302,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: if skip(fptensors, fctensors) and skip(bptensors, bctensors): continue - fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors) + fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) if fadapter is None: continue diff --git a/cube/graph/gener/layout.py b/cube/graph/gener/layout.py deleted file mode 100644 index 69e388e9..00000000 --- a/cube/graph/gener/layout.py +++ /dev/null @@ -1,1157 +0,0 @@ -from typing import Callable, Dict, List, Tuple, Optional -import copy -import numpy as np -import sys - -from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor -from cube.ir.tensor import ValueMap - -from cube.ir.adapter.prim import IRAdapterPrim -from cube.ir.adapter.prim import AllGatherPrim # d2r -from cube.ir.adapter.prim import AllToAllPrim # d2d -from cube.ir.adapter.prim import AllReducePrim # v2r -from cube.ir.adapter.prim import ReduceScatterPrim # v2d -from cube.ir.adapter.prim import ChunkPrim # r2d - -from cube.ir.adapter.prim import MovePrim # p2p -from cube.ir.adapter.prim import BroadcastPrim -from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim -from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim - -from cube.graph.gener.utils import tensor_vd_repr - -TShape = Tuple[int, ...] -TRVD = Tuple[int, ...] - - -class GridLayout: - """ - This class assumes a full-tensor can only be - uniformly partitioned / replicated on dimensions and values. - - A partition plan N-dim tensor layout can be represented as - : R (replica), V (value), dim_i (dimension) - """ - - def __init__(self, ftensor: IRFullTensor, subtensors: List[IRSubTensor], mats: np.ndarray): - """ - ftensor: N-dim FullTensor - subtensors: List[IRSubTensors] - mats: Array[IRSubTensor]: - (2+N)-dim matrix, with index respect to - """ - self.ftensor = ftensor - self.subtensors = subtensors - self._mats = mats - - @property - def R(self) -> int: - return self._mats.shape[0] - - @property - def V(self) -> int: - return self._mats.shape[1] - - @property - def D(self) -> Tuple[int]: - return tuple(self._mats.shape[2:]) - - @property - def vec(self) -> Tuple[int]: - return tuple(self._mats.shape) - - @property - def ndims(self): - return len(self._mats.shape) - - @property - def ndevs(self): - return len(self.subtensors) - - @property - def mat(self): - return self._mats - - def tensor(self, r: int, v: int, d: List[int]) -> IRSubTensor: - """ - Get subtenor indexed by RVD position. - """ - assert r <= self.R and v <= self.V and len(d) == len(self.D), "out of scope" - indices = [r, v] + list(d) - return self._mats[tuple(indices)] - - def __repr__(self): - dscp = f'T{self.ftensor._id}' - return dscp - - # ====== intra-RVD transition primitives ====== # - - def d2r(self, dim: int, chunks: int) -> Tuple: - """ - intra-RVD primitive D->R: allgather - - @param dim int: tensor dimension - @param chunks int: the number of chunks to transfer - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.D[dim] % chunks == 0, f"not dividable dim: {self.D[dim]} // {chunks}" - rvd = list(self.vec) - rvd[0], rvd[2+dim] = rvd[0] * chunks, rvd[2+dim] // chunks - - ret: List[Tuple[GridLayout, List[IRAdapterPrim]]] = [] - # collect all possible layouts - olayouts: List[GridLayout] = [GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:])] - if self.R != 1: - olayout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) - olayout.inner_transpose(0, chunks) - olayouts.append(olayout) - # generate primitives for all possible cases - for olayout in olayouts: - imat = GridLayout.transpose(self.mat, 0, 2+dim) - omat = GridLayout.transpose(olayout.mat, 2+dim, 0) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor.cell = itensor.cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(AllGatherPrim(itensors, otensors, dim)) - ret.append((olayout, prims)) - return ret - - def d2d(self, from_dim: int, to_dim: int, chunks: int) -> Tuple: - """ - intra-RVD primitive D(...,i,..)->D(..,j,...): alltoall - - @param from_dim int: source tensor axis - @param to_dim int: destination tensor axis - @param chunks int: the number of chunks to transfer - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.D[from_dim] % chunks == 0, f"not dividable dim: {self.D[from_dim]} // {chunks}" - rvd = list(self.vec) - rvd[2+from_dim], rvd[2+to_dim] = rvd[2+from_dim] // chunks, rvd[2+to_dim] * chunks - layout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) - # d2d has no ambiguity on device mapping - imat = GridLayout.transpose(self.mat, 2+to_dim, 2+from_dim) - omat = GridLayout.transpose(layout.mat, 2+from_dim, 2+to_dim) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor.cell = itensor.cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(AllToAllPrim(itensors, otensors, from_dim, to_dim)) - return [(layout, prims)] - - def v2r(self, chunks: int) -> Tuple: - """ - intra-RVD primitive V->R: allreduce - - @param dim int: tensor dimension - @param chunks int: the number of chunks to transfer - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.V % chunks == 0, f"not dividable value chunks: {self.V} // {chunks}" - rvd = list(self.vec) - rvd[1], rvd[0] = rvd[1] // chunks, rvd[0] * chunks - - ret: List[Tuple[GridLayout, List[IRAdapterPrim]]] = [] - # collect all possible layouts - ilayouts: List[GridLayout] = [self] - if self.V != chunks: - ilayout = GridLayout.grid(self.ftensor, r=self.R, v=self.V, dims=self.D) - for t1, t2 in zip(self.mat.flatten(), ilayout.mat.flatten()): - t2.cell = t1.cell - ilayout.inner_transpose(1, chunks) - ilayouts.append(ilayout) - olayouts: List[GridLayout] = [GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:])] - if self.R != 1: - olayout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) - olayout.inner_transpose(0, chunks) - olayouts.append(olayout) - # generate primitives for all possible cases - for ilayout in ilayouts: - for olayout in olayouts: - imat = GridLayout.transpose(ilayout.mat, 0, 1) - omat = GridLayout.transpose(olayout.mat, 1, 0) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor.cell = itensor.cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(AllReducePrim(itensors, otensors)) - ret.append((olayout, prims)) - return ret - - def v2d(self, dim: int, chunks: int) -> Tuple: - """ - intra-RVD primitive V->D: reduce-scatter - - @param dim int: tensor dimension - @param chunks int: the number of chunks to transfer - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.V % chunks == 0, f"not dividable value chunks: {self.V} // {chunks}" - rvd = list(self.vec) - rvd[1], rvd[2+dim] = rvd[1] // chunks, rvd[2+dim] * chunks - - ret: List[Tuple[GridLayout, List[IRAdapterPrim]]] = [] - # collect all possible layouts - ilayouts = [self] - if self.V != chunks: - ilayout = GridLayout.grid(self.ftensor, r=self.R, v=self.V, dims=self.D) - for t1, t2 in zip(self.mat.flatten(), ilayout.mat.flatten()): - t2.cell = t1.cell - ilayout.inner_transpose(1, chunks) - ilayouts.append(ilayout) - # generate primitives for all possible cases - for ilayout in ilayouts: - olayout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) - imat = GridLayout.transpose(self.mat, 2+dim, 1) - omat = GridLayout.transpose(olayout.mat, 1, 2+dim) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor.cell = itensor.cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(ReduceScatterPrim(itensors, otensors, dim)) - ret.append((olayout, prims)) - return ret - - def r2d(self, dim: int, chunks: int) -> Tuple: - """ - intra-RVD primitive V->D: schunk - - @param dim int: tensor axis - @param chunks int: the number of chunks to transfer - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.R % chunks == 0, f"not dividable replica: {self.R} // {chunks}" - rvd = list(self.vec) - rvd[0], rvd[2+dim] = rvd[0] // chunks, rvd[2+dim] * chunks - - ret: List[Tuple[GridLayout, List[IRAdapterPrim]]] = [] - # collect all possible layouts - ilayouts = [self] - # print(f'r->d({dim})[{chunks}]: ilayout-self : {[(t, t.device) for t in self.mat.flatten()]}') - if self.R != chunks: - ilayout = GridLayout.grid(self.ftensor, r=self.R, v=self.V, dims=self.D) - for t1, t2 in zip(self.mat.flatten(), ilayout.mat.flatten()): - t2.cell = t1.cell - ilayout.inner_transpose(0, chunks) - ilayouts.append(ilayout) - # print(f'r->d({dim})[{chunks}]: ilayout-transformed: {[(t, t.device) for t in ilayout.mat.flatten()]}') - # generate primitives for all possible cases - for ilayout in ilayouts: - olayout = GridLayout.grid(self.ftensor, r=rvd[0], v=rvd[1], dims=rvd[2:]) - imat = GridLayout.transpose(ilayout.mat, 2+dim, 0) - omat = GridLayout.transpose(olayout.mat, 0, 2+dim) - for itensor, otensor in zip(imat.flatten(), omat.flatten()): - otensor.cell = itensor.cell - prims = [] - for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(ChunkPrim(itensors, otensors, dim)) - ret.append((olayout, prims)) - return ret - - def incr(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive +R: broadcast - - @param chunks int: the number of chunks to transfer - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - rvd = list(self.vec) - rvd[0] = rvd[0] * chunks - olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = GridLayout.dims2last(self.mat, [0]).flatten() - omat = GridLayout.dims2last(olayout.mat, [0]).reshape(-1, chunks) - prims = [] - for src, dsts in zip(imat, omat): - if chunks == 1: - prims.append(MovePrim([src], dsts)) - else: - prims.append(BroadcastPrim([src], [src] + list(dsts))) - return [(olayout, prims),] - - def decr(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive -R: move - - @param chunks int: the number of chunks to transfer - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.R % chunks == 0, f"not divisible replica {self.R} // {chunks}" - rvd = list(self.vec) - rvd[0] = rvd[0] // chunks - olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = GridLayout.dims2last(self.mat, [0]).reshape(-1, chunks) - omat = GridLayout.dims2last(olayout.mat, [0]).flatten() - prims = [] - for srcs, dst in zip(imat, omat): - prims.append(MovePrim([srcs[0]], [dst])) - return [(olayout, prims),] - - def incd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive +D: RD-Scatter - - @param chunks int: the number of chunks to transfer - @param dim int: tensor axis - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - rvd = list(self.vec) - rvd[2+dim] = rvd[2+dim] * chunks - olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = GridLayout.dims2last(self.mat, [2+dim]).flatten() - omat = GridLayout.dims2last(olayout.mat, [2+dim]).reshape(-1, chunks) - prims = [] - for src, dsts in zip(imat, omat): - prims.append(RDScatterPrim([src], dsts, dim=dim)) - return olayout, prims - - def decd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive +D: RD-Gather - - @param chunks int: the number of chunks to transfer - @param dim int: tensor axis - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.D[dim] % chunks == 0, f"not divisible dim: {self.D[dim]} % {chunks} != 0" - rvd = list(self.vec) - rvd[2+dim] = rvd[2+dim] // chunks - olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = GridLayout.dims2last(self.mat, [2+dim]).reshape(-1, chunks) - omat = GridLayout.dims2last(olayout.mat, [2+dim]).flatten() - prims = [] - for srcs, dst in zip(imat, omat): - prims.append(RDGatherPrim(srcs, [dst], dim=dim)) - return [(olayout, prims),] - - def incv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive +V: RV-Scatter - - @param chunks int: the number of chunks to transfer - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - rvd = list(self.vec) - rvd[1] = rvd[1] * chunks - olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = GridLayout.dims2last(self.mat, [1]).flatten() - omat = GridLayout.dims2last(olayout.mat, [1]).reshape(-1, chunks) - prims = [] - for src, dsts in zip(imat, omat): - prims.append(RVScatterPrim([src], dsts)) - return [(olayout, prims),] - - def decv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive -V: RV-Gather - - @param chunks int: the number of chunks to transfer - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.V % chunks == 0, f"not divisable value split: {self.V} % {chunks} != 0" - rvd = list(self.vec) - rvd[1] = rvd[1] // chunks - olayout = GridLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = GridLayout.dims2last(self.mat, [1]).reshape(-1, chunks) - omat = GridLayout.dims2last(olayout.mat, [1]).flatten() - prims = [] - for srcs, dst in zip(imat, omat): - prims.append(RVGatherPrim(srcs, [dst])) - return [(olayout, prims),] - - def align(self, layout) -> bool: - """ - Check whether the layout is same with self. - - The same means 1) sub-tenosrs are same 2) device are aligned - - @param layout GridLayout - - @return same bool: - """ - if not isinstance(layout, GridLayout): - return False - tensors: List[IRSubTensor] = list(self.mat.flatten()) - for t in layout.mat.flatten(): - dev_match = False - for idx in range(len(tensors)): - t2 = tensors[idx] - if t == t2 and set(t.device) == set(t2.device): - tensors.pop(idx) - dev_match = True - break - if not dev_match: return False - return True - - def print_dev_tensors(self): - """ - print each device hold tensors. - """ - devices: Dict[int, List[IRSubTensor]] = dict() - for tensor in self.subtensors: - assert len(tensor.device) == 1, f"got tensor device: {tensor.device}" - if tensor.device[0] not in devices: - devices[tensor.device[0]] = [] - devices[tensor.device[0]].append(tensor) - devs = list(devices.keys()) - devs.sort() - for dev in devs: - print(f'dev{dev}:') - for tensor in devices[dev]: - print(f'\t{tensor.extra_repr()}') - - def inner_transpose(self, dim: int, chunks: int): - """ - transpose ordering of tensor within a dimension. - """ - assert 0 <= dim and dim < len(self._mats.shape) - assert self.vec[dim] % chunks == 0 - ori_shape = list(self.vec) - new_shape = list(self.vec) - new_shape.insert(dim, self.vec[dim] // chunks) - new_shape[dim+1] = chunks - self._mats = self._mats.reshape(new_shape) - axes = list(range(len(new_shape))) - axes[dim], axes[dim+1] = axes[dim+1], axes[dim] - self._mats = self._mats.transpose(axes) - self._mats = self._mats.reshape(ori_shape) - - @staticmethod - def transpose(mat: np.ndarray, dim0: int, dim1: int): - """ - put the dim0 and dim1 of the mat to the last two dims - """ - ndims = len(mat.shape) - axes = list(range(ndims)) - assert dim0 < ndims and dim1 < ndims, "dim0 or dim1 out of index" - axes.pop(max(dim0, dim1)) - axes.pop(min(dim0, dim1)) - axes += [dim0, dim1] - return np.transpose(mat, axes) - - @staticmethod - def dims2last(mat: np.ndarray, dims: List[int]) -> np.ndarray: - """ - Permute a matrix by putting dimensions to the last. - """ - axes = list(range(len(mat.shape))) - for dim in dims: - axes.remove(dim) - axes += list(dims) - return np.transpose(mat, axes) - - @staticmethod - def dims2orig(mat: np.ndarray, last_dims: List[int]) -> np.ndarray: - axes = list(range(len(mat.shape))) - for dim in last_dims: - axes.remove(dim) - axes += list(last_dims) - axes = np.argsort(np.array(axes)) - return np.transpose(mat, axes) - - @staticmethod - def changed_dims(src: TRVD, dst: TRVD) -> Tuple[List[int], List[int]]: - """ - Get changed dimensions - - @param src Tuple[int]: the source RVD layout - @param dst Tuple[int]: the destination RVD layout - - @return inc_dims Tuple[int]: the dimensions that need to increase - @return dec_dims Tuple[int]: the dimensions that need to decrease - """ - assert len(src) == len(dst) - inc_dims, dec_dims = [], [] - for dim, (slen, dlen) in enumerate(zip(src, dst)): - if slen < dlen: - inc_dims.append(dim) - elif slen > dlen: - dec_dims.append(dim) - return inc_dims, dec_dims - - @staticmethod - def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int], devices: Optional[Tuple[int]] = None): - """ - partition a ftensor using grid layout of - """ - dims = tuple(dims) - def dummy_assign(tensor: IRSubTensor, devid: int): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = devid - - mats = np.empty((r, v) + dims, dtype=IRSubTensor) - all_subtensors = [] - - def iter_idx(dims: List[int]) -> Tuple[int]: - if len(dims) == 0: - yield () - else: - for i in range(dims[0]): - for indices in iter_idx(dims[1:]): - yield (i,) + indices - # generate tensor for each index - for indices in iter_idx((v,)+dims): - valmap = ValueMap((indices[0], v)) - indmap = [] - shape = [] - for dim, (nchunk, index) in enumerate(zip(dims, indices[1:])): - assert ftensor.shape[dim] % nchunk == 0, f"not dividable for {nchunk} chunks over dim {dim}. ftensor shape: {ftensor.shape}" - csize = ftensor.shape[dim] // nchunk - start = csize * index - indmap.append((start, start+csize)) - shape.append(csize) - subtensor = ftensor.select(tuple(indmap), valmap) - # replicate - subtensors = [copy.copy(subtensor) for _ in range(r)] - all_subtensors += subtensors - mats[(slice(None),)+indices] = np.array(subtensors, dtype=IRSubTensor) - - # devices - if devices is not None: - assert len(devices) == len(all_subtensors), f"devices number {len(devices)} not match with RVD number {len(all_subtensors)}" - for tensor, devid in zip(mats.flatten(), devices): - dummy_assign(tensor, int(devid)) - - return GridLayout(ftensor, all_subtensors, mats) - - @staticmethod - def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): - """ - convert ftensor and subtensors into a GridLayout. - - If failed, raise error - """ - _replica: int = None - _value: int = None - _dims: List[int] = [None] * len(ftensor.shape) - _tindex: Dict[int, List[int]] = dict() - - ndims = len(ftensor.shape) - - replicas: Dict[int, List[IRSubTensor]] = dict() - vchunks: set = set() - dchunks: List[set] = [set() for _ in range(ndims)] - - for subtensor in subtensors: - tid = id(subtensor) - # set up replica - if subtensor.tid not in replicas: - replicas[subtensor.tid] = [] - _tindex[tid] = [len(replicas[subtensor.tid])] - replicas[subtensor.tid].append(subtensor) - # setup value - _tindex[tid].append(subtensor.valmap[0]) - vchunks.add(subtensor.valmap[1]) - # setup dimensions - for dim in range(ndims): - snele = subtensor.shape[dim] - start = subtensor.indmap[dim][0] - fnele = ftensor.shape[dim] - if fnele % snele != 0 or start % snele != 0: - raise RuntimeError( - f"dimension split error:\n" - f"Full Tensor: {ftensor}\n" - f"full nele: {fnele}, sub nele: {snele}, start: {start}" - ) - dchunks[dim].add(fnele // snele) - _tindex[tid].append(start // snele) - # replica (R) - nreplicas = set(len(ts) for ts in replicas.values()) - if len(nreplicas) != 1: - raise RuntimeError(f"different replicas: {nreplicas}") - _replica = list(nreplicas)[0] - # value (V) - nchunks = set(t.valmap[1] for t in subtensors) - if len(nchunks) != 1: - raise RuntimeError(f"different value split: {nchunks}") - _value = list(nchunks)[0] - # dimension (D) - for dim in range(ndims): - if len(dchunks[dim]) != 1: - raise RuntimeError(f"different dimension split: {dchunks[dim]}") - _dims[dim] = list(dchunks[dim])[0] - - # set matrix - mats = np.empty([_replica, _value] + _dims, dtype=IRSubTensor) - for subtensor in subtensors: - idx = tuple(_tindex[id(subtensor)]) - assert mats[idx] is None, f"repeating entry. mutiple same {subtensor}" - mats[tuple(idx)] = subtensor - assert not (mats == None).any(), "at least one entry not set" - return GridLayout(ftensor, subtensors, mats) - - -class PathFinder: - """ - Pathfinder for generating communication plans for GridLayout - """ - - # intra-shard: cached nodes. paths[shape][i][j] = List[int] of indices from (src -> dst] - _cached_intra_nodes: Dict[Tuple[TShape, int], Tuple[TRVD]] = {} - _cached_intra_edges: Dict[Tuple[TShape, int], np.ndarray] = {} - _cached_intra_paths: Dict[Tuple[TShape, int], Dict[TRVD, List[List[int]]]] = {} - - # inter-shard: cached nodes. paths[(shape1, shape2)][i][j] = List[int] - _cached_inter_nodes: Dict[Tuple[TShape, int, int], Tuple[Tuple[TRVD]]] = {} - _cached_inter_edges: Dict[Tuple[TShape, int, int], Tuple[np.ndarray]] = {} - _cached_inter_paths: Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]] = {} - - @staticmethod - def intra_path(ftensor: IRFullTensor, - ilayout: GridLayout, olayout: GridLayout, - cost_fn: Optional[Callable] = None, - allow_fallback: bool = True, - allow_misalign: bool = False) -> Tuple[bool, List[GridLayout], List[IRAdapterPrim]]: - """ - Get primitive path of transforming ilayout into olayout. - ilayout has the same device set with olayout - - @param ftensor IRFullTensor: The fulltensor - @param ilayout GridLayout: input tensor layout - @param olayout GridLayout: output tensor layout - @param cost_fn Optional[Callable]: cost function of each primitive. - Default (None) will use transmission volume as metrics - @param allow_fallback bool: allow to use a fixed backup plan to make sure correct device mapping. (default True) - @param allow_misalign bool: allow to have a different device mapping. (default False) - - @return align bool: whether correctly align the device placement - @return layouts List[GridLayout]: each transformation. - @return prims List[IRAdapterPrim]: the primitives to perform transformation. - """ - cost_fn = PathFinder.default_cost_fn if cost_fn is None else cost_fn - shape = tuple(ftensor.shape) - key = (shape, ilayout.ndevs) - src = (ilayout.R, ilayout.V) + tuple(ilayout.D) - dst = (olayout.R, olayout.V) + tuple(olayout.D) - - # cases for same source and destination RVD - if src == dst: - if ilayout.align(olayout): - return True, [ilayout], [] - else: - if allow_misalign: - return False, [ilayout], [] - else: - assert False, "Same source and destination rvd but got mis-aligned devices" - - # get paths using dijkstra algorithm or cached - if key in PathFinder._cached_intra_paths and src in PathFinder._cached_intra_paths[key]: - paths = PathFinder._cached_intra_paths[key][src] - else: - # initialize the graph if not cached - if key not in PathFinder._cached_intra_nodes: - nodes, edges = PathFinder.init_intra_graph(ftensor, ilayout.ndevs, cost_fn) - PathFinder._cached_intra_nodes[key] = nodes - PathFinder._cached_intra_edges[key] = edges - PathFinder._cached_intra_paths[key] = {} - nodes = PathFinder._cached_intra_nodes[key] - edges = PathFinder._cached_intra_edges[key] - # build and initialize cost table - cost = np.full((len(nodes),), np.inf) - cost[nodes.index(src)] = 0 - # setup unvisited and visited set - unvisited = set(range(len(nodes))) - visited = set() - paths = [[] for _ in range(len(nodes))] - paths[nodes.index(src)] = [nodes.index(src)] - # dijkstra body - while len(unvisited) > 0: - min_cost, visit = np.inf, None - for idx in unvisited: - if cost[idx] < min_cost: - min_cost = idx - visit = idx - if visit is None: break # for remaining states that cannot reach - for neighbor in np.where(edges[visit] != np.inf)[0]: - new_cost = cost[visit] + edges[visit, neighbor] - if cost[neighbor] == np.inf or new_cost < cost[neighbor]: - cost[neighbor] = new_cost - paths[neighbor] = paths[visit] + [neighbor] - unvisited.remove(visit) - visited.add(visit) - PathFinder._cached_intra_paths[key][src] = paths - - # print for debug - # for idx, path in enumerate(paths): - # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") - - # get layout - nodes = PathFinder._cached_intra_nodes[key] - path: List[int] = paths[nodes.index(dst)] - rvds: List[Tuple[int]] = [nodes[idx] for idx in path] - assert len(path) > 0, f"Un-reachable src RVD ({src}) -> dst RVD ({dst})" - # print(f'path: {rvds}') - - # search for correct device mapping - align, layouts, all_prims = PathFinder.intra_dev_align(ftensor, rvds[1:], [ilayout], [], olayout) - if not align and allow_misalign: - layouts, all_prims = [], [] - curr_rvd, curr_layout = rvds[0], ilayout - for hop_rvd in rvds[1:]: - ret, layout_prims = PathFinder.intra_transition(ftensor, curr_rvd, hop_rvd, curr_layout) - assert ret, "Internal Error: intra-transition failed" - layout, prims = layout_prims[0] - layouts.append(layout) - all_prims += prims - curr_rvd, curr_layout = hop_rvd, layout - - elif not align: - error_msg = ( - f"Fail to align intra-RVD devices. {ftensor}\n" - f"Path: {' -> '.join(str(rvd) for rvd in rvds)}\n" - f"ptensors:\n\t" + "\n\t".join(tensor_vd_repr(ptensor) for ptensor in ilayout.mat.flatten()) + "\n" - f"ctensors:\n\t" + "\n\t".join(tensor_vd_repr(ctensor) for ctensor in olayout.mat.flatten()) + "\n" - f"Switch to a fixed plan: ilayout -> FullReplica -> olayout" - ) - color, default = '\033[33m' , '\033[0m' - print(color+error_msg+default, file=sys.stderr) - if allow_fallback: - # switch to a fixed plan ilayout -> R(n)V(1)D(1*) -> olayout - rlayout = GridLayout.grid(ftensor, r=ilayout.ndevs, v=1, dims=tuple(1 for _ in range(ilayout.ndims-2))) - # assign devices - for rt, ot in zip(rlayout.mat.flatten(), olayout.mat.flatten()): - rt.cell = ot.cell - # find left - left: List[int] = paths[nodes.index(tuple(rlayout.vec))] - left = [nodes[idx] for idx in left] - lsuccess, llayouts, lprims = PathFinder.intra_dev_align(ftensor, left[1:], [ilayout], [], rlayout) - assert lsuccess, f"Switch fail to generate left-half intra-RVD plans for all-replica" - # find right - _, rlayouts, rprims = PathFinder.intra_path(ftensor, rlayout, olayout, cost_fn, allow_fallback=False) - layouts = llayouts + rlayouts - all_prims = lprims + rprims - else: - # allow_fallback is False only for generating right-half intra-RVD - assert False, f"Switch fail to generate right-half intra-RVD plans from all-replica" - align = True - - return align, layouts, all_prims - - @staticmethod - def inter_path(ftensor: IRFullTensor, ilayout: GridLayout, olayout: GridLayout, cost_fn: Optional[Callable] = None) -> Tuple[List[GridLayout], List[IRAdapterPrim]]: - """ - Get primitives for transforming ilayout into olayout. ilayout has the different device set - to olayout. And number of device of ilayout and olayout must be divisable by each other. - - @param ftensor IRFullTensor: The fulltensor - @param ilayout GridLayout: input tensor layout - @param olayout GridLayout: output tensor layout - @param cost_fn Optional[Callable]: cost function of each primitive. - Default (None) will use transmission volume as metrics - - @return layouts List[GridLayout]: each transformation. - @return prims List[IRAdapterPrim]: the primitives to perform transformation. - """ - cost_fn = PathFinder.default_cost_fn if cost_fn is None else cost_fn - shape = tuple(ftensor.shape) - key = (shape, ilayout.ndevs, olayout.ndevs) - - src = ('p',) + (ilayout.R, ilayout.V) + tuple(ilayout.D) - dst = ('c',) + (olayout.R, olayout.V) + tuple(olayout.D) - - if key in PathFinder._cached_inter_nodes and src in PathFinder._cached_inter_paths[key]: - nodes = PathFinder._cached_inter_nodes[key] - paths = PathFinder._cached_inter_paths[key][src] - else: - if key in PathFinder._cached_inter_nodes: - nodes = PathFinder._cached_inter_nodes[key] - edges = PathFinder._cached_inter_edges[key] - else: - nodes, edges = PathFinder.init_inter_graph(ftensor, ilayout.ndevs, olayout.ndevs, cost_fn) - PathFinder._cached_inter_nodes[key] = nodes - PathFinder._cached_inter_edges[key] = edges - PathFinder._cached_inter_paths[key] = {} - # build cost - cost = np.full((len(nodes),), np.inf) - cost[nodes.index(src)] = 0 - # setup unvisited and visited set - unvisited = set(range(len(nodes))) - visited = set() - paths = [[] for _ in range(len(nodes))] - paths[nodes.index(src)] = [nodes.index(src)] - # dijkstra body - while len(unvisited) > 0: - min_cost, visit = np.inf, None - for idx in unvisited: - if cost[idx] < min_cost: - min_cost = idx - visit = idx - if visit is None: break - for neighbor in np.where(edges[visit] != np.inf)[0]: - new_cost = cost[visit] + edges[visit, neighbor] - if cost[neighbor] == np.inf or new_cost < cost[neighbor]: - cost[neighbor] = new_cost - paths[neighbor] = paths[visit] + [neighbor] - unvisited.remove(visit) - visited.add(visit) - PathFinder._cached_inter_paths[key][src] = paths - - # print for debug - # for idx, path in enumerate(paths): - # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") - - path = paths[nodes.index(dst)] - # print(f"Find path: {' -> '.join(str(nodes[i]) for i in path)}") - assert len(path) > 0, f"Un-reachable src RVD ({src}) -> dst RVD ({dst})" - - # setup consumer begining devices - cpaths = tuple(idx for idx in path if nodes[idx][0] == 'c') - curr_devs = np.array([t.device[0] for t in olayout.mat.flatten()]).reshape(dst[1:]) - curr_node = dst[1:] - # print('result device map:', list(cdevs.flatten())) - for hop in cpaths[:-1][::-1]: - hop_rvd = nodes[hop][1:] - curr_devs = PathFinder.inter_devmap(curr_node, hop_rvd, curr_devs) - curr_node = hop_rvd - consumer_entry_devs = curr_devs - # print('calculated consumer device map: ', list(cdevs.flatten())) - # setup primitives for communication - side, layouts, all_prims = 'p', [ilayout], [] - curr_rvd = src[1:] - for hop in path[1:]: - use_inter_step = side != nodes[hop][0] - hop_rvd = nodes[hop][1:] - if not use_inter_step: - ret, layout_prims = PathFinder.intra_transition(ftensor, curr_rvd, hop_rvd, layouts[-1]) - assert ret, "Internal Error: intra-RVD transition failed" - layout, prims = layout_prims[0] # the first only is enough for inter-rvd - else: - ret, layout, prims = PathFinder.inter_transform(ftensor, curr_rvd, hop_rvd, layouts[-1], consumer_entry_devs) - layouts.append(layout) - all_prims += prims - curr_rvd = hop_rvd - side = nodes[hop][0] - return layouts, all_prims - - @staticmethod - def intra_transition(ftensor: IRFullTensor, src_rvd: TRVD, dst_rvd: TRVD, - ilayout: Optional[GridLayout] = None) -> Tuple[bool, List[Tuple[GridLayout, List[IRAdapterPrim]]]]: - """ - Get output layout and transform primitives from a source rvd layout to dst_rvd layout, - - @param ftensor IRFullTensor - @param src_rvd Tuple[int] - @param dst_rvd Tuple[int] - @param ilayout Optional[GridLayout] - - @return ret bool: True if trainsition is successful. Otherwise False. - @return layout Optonal[GridLayout]: the RVD layout if ilayout is not None - @return prims Optional[List[IRAdapterPrim]]: the prmitives in transformation - """ - if ilayout is not None: - assert src_rvd == tuple(ilayout.vec) - inc_dims, dec_dims = GridLayout.changed_dims(src_rvd, dst_rvd) - if len(inc_dims) != 1 or len(dec_dims) != 1: - return False, [(None, [])] - inc_idx, dec_idx = inc_dims[0], dec_dims[0] - if src_rvd[dec_idx] % dst_rvd[dec_idx] != 0: - return False, [(None, [])] - if inc_idx == 1: - return False, [(None, [])] - src = ilayout if ilayout is not None else GridLayout.grid(ftensor, src_rvd[0], src_rvd[1], list(src_rvd[2:])) - chunks = src_rvd[dec_idx] // dst_rvd[dec_idx] - if dec_idx >= 2 and inc_idx == 0: # d2r - ret = src.d2r(dec_idx-2, chunks) - elif dec_idx >= 2 and inc_idx >= 2: # d2d - ret = src.d2d(dec_idx-2, inc_idx-2, chunks) - elif dec_idx == 1 and inc_idx == 0: # v2r - ret = src.v2r(chunks) - elif dec_idx == 1 and inc_idx >= 2: # v2d - ret = src.v2d(inc_idx-2, chunks) - elif dec_idx == 0 and inc_idx >= 2: # r2d - ret = src.r2d(inc_idx-2, chunks) - else: - raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") - return True, ret - - @staticmethod - def intra_dev_align(ftensor: IRFullTensor, remain_states: List[Tuple[int]], - ilayouts: List[GridLayout], all_prims: List[IRAdapterPrim], - olayout: GridLayout) -> Tuple[bool, List[GridLayout], List[IRAdapterPrim]]: - """ - Align devices for intra-RVD - - @param ftensor IRFullTensor - @param remain_states List[TRVD]: RVD representations - @param ilayouts List[GridLayout]: searched layouts - @param all_prims List[IRAdapterPrim]: searched primitives - @param olayout GridLayout: target layout with correct device mapping - - @return success bool: True if found device, else False. - @return layouts List[GridLayout]: the searched layouts with device match - @return primitives List[IRAdapterPrim]: the correspoinding primitives - """ - ilayout = ilayouts[-1] - if len(remain_states) == 0: - if not ilayout.align(olayout): - return False, [], [] - return True, ilayouts, all_prims - else: - success, layout_prims = PathFinder.intra_transition( - ftensor, (ilayout.R, ilayout.V) + ilayout.D, remain_states[0], ilayout) - assert success, "Internal Error at intra-RVD transition" - for (hop_layout, hop_prims) in layout_prims: - ret, ret_layouts, ret_prims = PathFinder.intra_dev_align( - ftensor, remain_states[1:], ilayouts + [hop_layout], all_prims + hop_prims, olayout) - if ret: - return True, ret_layouts, ret_prims - return False, [], [] - - @staticmethod - def inter_devmap(src_rvd: TRVD, dst_rvd: TRVD, src_devs: np.ndarray): - """ - Infer device from source rvd to destination rvd - """ - assert tuple(src_rvd) == tuple(src_devs.shape), f"RVD mis-matches with device shape, {src_rvd} != {src_devs.shape}" - # get changed dimensions - inc_idx, dec_idx = GridLayout.changed_dims(src_rvd, dst_rvd) - assert len(inc_idx) == 1 and len(dec_idx) == 1 - inc_idx, dec_idx = inc_idx[0], dec_idx[0] - assert src_rvd[dec_idx] % dst_rvd[dec_idx] == 0 - chunks = src_rvd[dec_idx] // dst_rvd[dec_idx] - # reshape array to match devices - dst_devs = np.full(dst_rvd, -1, dtype=int) - src_devs = GridLayout.dims2last(src_devs, [inc_idx, dec_idx]).reshape(-1, chunks) - dst_devs = GridLayout.dims2last(dst_devs, [dec_idx, inc_idx]) - dshape = dst_devs.shape - # set up device - dst_devs = dst_devs.reshape(-1, chunks) - for rid, devs in enumerate(src_devs): - dst_devs[rid] = devs - dst_devs = dst_devs.reshape(dshape) - # permute to original shape - dst_devs = GridLayout.dims2orig(dst_devs, [dec_idx, inc_idx]) - return dst_devs - - @staticmethod - def inter_transform(ftensor, src_rvd: TRVD, dst_rvd: TRVD, ilayout: Optional[GridLayout] = None, dst_devs: Optional[np.array] = None): - """ - Get output layout and transform primitives from a source rvd layout to dst_rvd layout, - - @param ftensor IRFullTensor - @param src_rvd Tuple[int] - @param dst_rvd Tuple[int] - @param ilayout Optional[GridLayout] - - @return ret bool: True if there is a primitive performed - @return layout Optonal[GridLayout]: the RVD layout if ilayout is not None - @return prims Optional[List[IRAdapterPrim]]: the prmitives in transformation - """ - inc_dims, dec_dims = GridLayout.changed_dims(src_rvd, dst_rvd) - if len(inc_dims) == 0 and len(dec_dims) == 0: - inc_dims = [0] - if not ((len(inc_dims) == 1 and len(dec_dims) == 0) or (len(inc_dims) == 0 and len(dec_dims) == 1)): - return False, None, None - inc_idx = inc_dims[0] if len(inc_dims) == 1 else None - dec_idx = dec_dims[0] if len(dec_dims) == 1 else None - src = ilayout if ilayout is not None else GridLayout.grid(ftensor, src_rvd[0], src_rvd[1], list(src_rvd[2:])) - if isinstance(inc_idx, int): - if not (dst_rvd[inc_idx] % src_rvd[inc_idx] == 0): - return False, None, None - chunks = dst_rvd[inc_idx] // src_rvd[inc_idx] - if inc_idx == 0: - olayout, prims = src.incr(chunks, dst_devs)[0] - elif inc_idx == 1: - olayout, prims = src.incv(chunks, dst_devs)[0] - elif inc_idx > 1: - olayout, prims = src.incd(chunks, inc_idx-2, dst_devs)[0] - else: - raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") - else: - if not (src_rvd[dec_idx] % dst_rvd[dec_idx] == 0): - return False, None, None - chunks = src_rvd[dec_idx] // dst_rvd[dec_idx] - if dec_idx == 0: - olayout, prims = src.decr(chunks, dst_devs)[0] - elif dec_idx == 1: - olayout, prims = src.decv(chunks, dst_devs)[0] - elif dec_idx > 1: - olayout, prims = src.decd(chunks, dec_idx-2, dst_devs)[0] - else: - raise RuntimeError(f"Cannot find primitive. Report as a bug. dec-idx: {dec_idx}, inc-idx: {inc_idx}") - return True, (olayout if ilayout is not None else None), prims - - @staticmethod - def init_intra_graph(ftensor: IRFullTensor, ndevs: int, cost_fn: Optional[Callable]) -> Tuple[List[TRVD], np.ndarray]: - """ - Initialize the graph of RVD status graph. - - @param ftensor IRFullTensor: the full tensor - @param ndevs int: total device number - - @return nodes Tuple[TRVD] - @return edges np.ndarray: edges among nodes - """ - nodes = tuple(PathFinder.get_inshard_space(ftensor, ndevs)) - edges = np.full((len(nodes), len(nodes)), np.inf) - # initialize the cost - for i in range(len(nodes)): - for j in range(len(nodes)): - if i == j: continue - src, dst = nodes[i], nodes[j] - ret, layout_and_prims = PathFinder.intra_transition(ftensor, src, dst) - if not ret: continue - prims = layout_and_prims[0][1] - edges[i, j] = cost_fn(prims[0]) - return nodes, edges - - @staticmethod - def init_inter_graph(ftensor: IRFullTensor, idevs: int, odevs: int, cost_fn: Callable) -> Tuple[List[TRVD], np.ndarray]: - """ - Initialize the graph of RVD status graph. - - An additional positition tage is append to at the first element of each node, i.e., - For source (producer) layout: ('p', 2,1,1,2) means - For dest (consumer) layout: ('c', 2,1,1,2) means - - @param ftensor IRFullTensor: the full tensor - @param idevs int: total device number of source tensor - - @return nodes Tuple[TRVD] - @return edges np.ndarray: edges among nodes - """ - shape = tuple(ftensor.shape) - if (shape, idevs) in PathFinder._cached_intra_nodes: - src_nodes = PathFinder._cached_intra_nodes[(shape, idevs)] - src_edges = PathFinder._cached_intra_edges[(shape, idevs)] - else: - src_nodes, src_edges = PathFinder.init_intra_graph(ftensor, idevs, cost_fn) - PathFinder._cached_intra_nodes[(shape, idevs)] = src_nodes - PathFinder._cached_intra_edges[(shape, idevs)] = src_edges - PathFinder._cached_intra_paths[(shape, idevs)] = {} - if (shape, odevs) in PathFinder._cached_inter_edges: - dst_nodes = PathFinder._cached_intra_nodes[(shape, odevs)] - dst_edges = PathFinder._cached_intra_edges[(shape, odevs)] - else: - dst_nodes, dst_edges = PathFinder.init_intra_graph(ftensor, odevs, cost_fn) - PathFinder._cached_intra_nodes[(shape, odevs)] = dst_nodes - PathFinder._cached_intra_edges[(shape, odevs)] = dst_edges - PathFinder._cached_intra_paths[(shape, odevs)] = {} - nodes = tuple(('p',) + n for n in src_nodes ) + tuple(('c',) + n for n in dst_nodes) - edges = np.full((len(nodes), len(nodes)), np.inf) - edges[:len(src_nodes), :len(src_nodes)] = src_edges - edges[len(src_nodes):,len(src_nodes):] = dst_edges - # NVLink: 300GBps Inter-node: 100Gbps - comm_factor = 24 - for i in range(len(src_nodes)): - for j in range(len(dst_nodes)): - src, dst = src_nodes[i], dst_nodes[j] - # set for [i, len(src_nodes) + j] - ret, _, prims = PathFinder.inter_transform(ftensor, src, dst) - if not ret: continue - edges[i, len(src_nodes) + j] = cost_fn(prims[0]) * comm_factor - # set for [len(src_nodes) + j, i] - ret, _, prims = PathFinder.inter_transform(ftensor, dst, src) - assert ret - edges[len(src_nodes) + j, i] = cost_fn(prims[0]) * comm_factor - return nodes, edges - - # utility function - @staticmethod - def get_inshard_space(ftensor: IRSubTensor, ndevs: int) -> List[Tuple[int, ...]]: - """ - Get all possible space that can be transformed from layout. - - This space is pruned by limiting partition number of each RVD dimension - in the range of [min(ilayout[dim], olayout[dim]), max(ilayout[dim], olayout[dim])] - - @param ftensor IRFullTensor - @param ilayout GridLayout: input layout - @param olayout GridLayout: output layout - - @return layouts List[GridLayout]: - """ - all_layouts: List[int] = [] - - def factors(ndevs: int, length: int): - if length == 1: yield [ndevs] - else: - for i in range(1, ndevs + 1): - if ndevs % i == 0: - for res in factors(ndevs // i, length - 1): - yield [i] + res - - for rvd in factors(ndevs, 2+len(ftensor.shape)): - skip = False - for dimlen, pnum in zip(ftensor.shape, rvd[2:]): - if dimlen % pnum != 0: - skip = True - break - if not skip: - all_layouts.append(tuple(rvd)) - return all_layouts - - @staticmethod - def default_cost_fn(prim: IRAdapterPrim) -> int: - return prim.volume() + 1 # 1 is hop penalty diff --git a/cube/graph/gener/rvd/__init__.py b/cube/graph/gener/rvd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/graph/gener/rvd/inter.py b/cube/graph/gener/rvd/inter.py new file mode 100644 index 00000000..fb0ac4e2 --- /dev/null +++ b/cube/graph/gener/rvd/inter.py @@ -0,0 +1,474 @@ +from typing import Callable, Dict, List, Tuple, Optional, Set, Union +from functools import partial +import numpy as np +import sys +import copy + +from cube.ir.dtype import IRDType +from cube.ir.tensor import IRFullTensor + +from cube.ir.adapter.prim import IRAdapterPrim +from cube.ir.adapter.prim import MovePrim # p2p +from cube.ir.adapter.prim import BroadcastPrim +from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim +from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim + +from cube.graph.gener.rvd.layout import RVDLayout +from cube.graph.gener.rvd.intra import IntraPathFinder +from cube.graph.gener.utils import tensor_vd_repr + + +TShape = Tuple[int, ...] +TRVD = Tuple[int, ...] +InterRVD = Tuple[str, int,] # ('p', 2, 1, 1, ...) or ('c', 2, 1, 1, ...) + + +class InterTransition: + """ + Inter-RVD transition primitives + """ + + @staticmethod + def incr(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: + """ + Inter-RVD extended primitive: increase replica number + + @param rvd Tuple[int]: source rvd + @param chunks int: the number to multiply + + @return rvd Tuple[int]: transformed RVD + @return prim Callable: primitive class + """ + rvd = list(rvd) + rvd[0] = rvd[0] * chunks + return rvd, MovePrim if chunks == 1 else BroadcastPrim + + @staticmethod + def decr(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: + """ + Inter-RVD extended primitive: decrease replica number + + @param rvd Tuple[int]: source rvd + @param chunks int: the number to divide + + @return rvd Tuple[int]: transformed RVD + @return prim Callable: primitive class + """ + assert rvd[0] % chunks == 0, f"not divisible replica {rvd[0]} // {chunks}" + rvd = list(rvd) + rvd[0] = rvd[0] // chunks + return rvd, MovePrim + + @staticmethod + def incd(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: + """ + Inter-RVD extended primitive: increase tensor dimension partition + + @param rvd Tuple[int]: source rvd + @param dim int: the tensor axes to increase + @param chunks int: the number to multiply + + @return rvd Tuple[int]: transformed RVD + @return prim Callable: primitive class + """ + rvd = list(rvd) + rvd[2+dim] = rvd[2+dim] * chunks + return rvd, partial(RDScatterPrim, dim=dim) + + @staticmethod + def decd(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: + """ + Inter-RVD extended primitive: decrease tensor dimension partition + + @param rvd Tuple[int]: source rvd + @param dim int: the tensor axes to decrease + @param chunks int: the number to divide + + @return rvd Tuple[int]: transformed RVD + @return prim Callable: primitive class + """ + assert rvd[2+dim] % chunks == 0, f"not divisible dim: {rvd[2+dim]} % {chunks} != 0" + rvd = list(rvd) + rvd[2+dim] = rvd[2+dim] // chunks + return rvd, partial(RDGatherPrim, dim=dim) + + @staticmethod + def incv(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: + """ + Inter-RVD extended primitive: increase value partition + + @param rvd Tuple[int]: source rvd + @param chunks int: the number to multiply + + @return rvd Tuple[int]: transformed RVD + @return prim Callable: primitive class + """ + rvd = list(rvd) + rvd[1] *= 2 + return rvd, RVScatterPrim + + @staticmethod + def decv(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: + """ + Inter-RVD extended primitive: decrease value partition + + @param rvd Tuple[int]: source rvd + @param chunks int: the number to divide + + @return rvd Tuple[int]: transformed RVD + @return prim Callable: primitive class + """ + assert rvd[1] % chunks == 0, f"not divisable value split: {rvd[1]} % {chunks} != 0" + rvd = list(rvd) + rvd[1] = rvd[1] // chunks + return rvd, RVGatherPrim + + @staticmethod + def transitionable(src_rvd: TRVD, dst_rvd: TRVD) -> Optional[Callable]: + """ + Check wheter a primitive exists to transform src_rvd to dst_rvd + + @param src_rvd TRVD: source RVD + @param dst_rvd TRVD: destination RVD + + @return trans_fn Optional[Callable]: None indicates no primitive found. + """ + trans_fn = None + incd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 < d2] + decd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 > d2] + if len(incd) == 0 and len(decd) == 0: + decd = [0] + if len(incd) + len(decd) != 1: return trans_fn + if len(incd) == 1: + incd = incd[0] + if incd == 0: # incr + return InterTransition.incr + elif incd == 1: + return InterTransition.incv + else: + return partial(InterTransition.incd, dim=incd-2) + else: + decd = decd[0] + if decd == 0: # decr + return InterTransition.decr + elif decd == 1: + return InterTransition.decv + else: + return partial(InterTransition.decd, dim=decd-2) + + @staticmethod + def transition(src_layout: RVDLayout, dst_rvd: TRVD, placement: Optional[Tuple[int]] = None) -> Tuple[RVDLayout, List[IRAdapterPrim]]: + """ + Transfer from source RVD to destination RVD. + Get all possible device-placement choices for RVD + given the fixed device placement of RVD. + + @param src_layout RVDLayout: source ilayout + @param dst_rvd Tuple[int]: destination RVD + @param placement Tuple[int]: output layout device placement + + @return rets Tuple[GridLayout, List[IRAdapterPrim]]: + pairs of of output + """ + + src_rvd = src_layout.vec + ftensor = src_layout.ftensor + dst_layout: RVDLayout = RVDLayout.grid(ftensor, r=dst_rvd[0], v=dst_rvd[1], dims=dst_rvd[2:], devices=placement) + trans_fn = InterTransition.transitionable(src_rvd, dst_rvd) + assert trans_fn is not None, f"Cannot find primitive: {src_rvd} -> {dst_rvd}" + # get primitive + incd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 < d2] + decd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 > d2] + if len(incd) == 0 and len(decd) == 0: + decd = [0] + + if len(incd) == 1: + change_dim = incd[0] + chunks = dst_rvd[change_dim] // src_rvd[change_dim] + else: + change_dim = decd[0] + chunks = src_rvd[change_dim] // dst_rvd[change_dim] + _, primitive = trans_fn(src_rvd, chunks=chunks) + + imat = RVDLayout.dim2last(src_layout.mat, change_dim, src_rvd[change_dim]) + omat = RVDLayout.dim2last(dst_layout.mat, change_dim, dst_rvd[change_dim]) + + prims = [] + if len(incd) == 1: + for src, dsts in zip(imat.flatten(), omat.reshape(-1, chunks)): + dsts = dsts.tolist() + if primitive is BroadcastPrim: + dsts = [src] + dsts + prims.append(primitive([src], dsts)) + else: + for srcs, dst in zip(imat.reshape(-1, chunks), omat.flatten()): + srcs = srcs.tolist() + prims.append(primitive(srcs, [dst])) + return dst_layout, prims + + +class InterPathFinder: + """ + inter-RVD Path finder for generating communication plans for RVDLayout + """ + + _cached_inter_nodes: Dict[Tuple[TShape, int, int], Tuple[Tuple[InterRVD]]] = {} + _cached_inter_edges: Dict[Tuple[TShape, int, int], Tuple[np.ndarray]] = {} + _cached_inter_paths: Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]] = {} + + + @staticmethod + def path(ilayout: RVDLayout, olayout: RVDLayout, cost_fn: Optional[Callable] = None) -> List[IRAdapterPrim]: + """ + Get primitive path of transforming ilayout into olayout. + ilayout must locate on different device set of olayout + + @param ilayout RVDLayout: input tensor layout + @param olayout RVDLayout: output tensor layout + @param cost_fn Optional[Callable]: cost function of each primitive. + Default (None) will use transmission volume as metrics + + @return all_primitives List[IRAdapterPrims]: all primitives for communication path + """ + ftensor: IRFullTensor = ilayout.ftensor + cost_fn = InterPathFinder.default_cost_fn if cost_fn is None else cost_fn + + inter_rvds: List[InterRVD] = InterPathFinder.get_optimal_path( + ftensor, ilayout.vec, olayout.vec, cost_fn) + + all_prims = InterPathFinder.device_align(ilayout, olayout, inter_rvds) + return all_prims + + @staticmethod + def device_align(ilayout: RVDLayout, olayout: RVDLayout, + rvd_paths: Tuple[InterRVD]) -> Tuple[IRAdapterPrim]: + """ + Align devices for inter-RVD + + @param ilayouts List[RVDLayout]: searched layouts + @param olayout RVDLayout: target layout with correct device mapping + @param rvd_hops: Tuple[TRVD]: the hops from ilayout to olayout, which + contains ilayout and olayout at beginning and last, respectively. + + @return primitives List[IRAdapterPrim]: the correspoinding primitives + """ + # decode producer and consumer part + prvds, crvds = InterPathFinder.decode(rvd_paths) + + # get possible consumer deivce space: try with reversed path + cdev_space = IntraPathFinder.get_device_space( + olayout.ftensor, crvds[::-1], + tuple(t.device[0] for t in olayout.mat.flatten())) + + # setup producer primitives + producer_out_devs = None + pdev_space = IntraPathFinder.get_device_space( + ilayout.ftensor, prvds, + tuple(t.device[0] for t in ilayout.mat.flatten()) + ) + for pdevs in pdev_space: + producer_out_devs = pdevs + playout = RVDLayout.grid( + ilayout.ftensor, r=prvds[-1][0], v=prvds[-1][1], + dims=prvds[-1][2:], devices=pdevs + ) + align, pprims = IntraPathFinder.device_align(ilayout, playout, prvds) + assert align, "Internal Error: inter-rvd producer side device fails to align" + break # we only take the first one + assert producer_out_devs is not None, f"Can't find inter-rvd producer out device placement" + + # setup consumer primitives and entry device placement + consumer_entry_devs = None + for cdevs in cdev_space: + clayout = RVDLayout.grid( + olayout.ftensor, r=crvds[0][0], v=crvds[0][1], + dims=crvds[0][2:], devices=cdevs) + align, cprims = IntraPathFinder.device_align(clayout, olayout, crvds) + if align: + consumer_entry_devs = cdevs + break + assert consumer_entry_devs is not None, f"Can't find inter-rvd consumer entry device placement." + + # setup inter-primitive + _, iprims = InterTransition.transition(playout, crvds[0], consumer_entry_devs) + + # merge together + return pprims + iprims + cprims + + @staticmethod + def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, cost_fn: Optional[Callable] = None) -> List[InterRVD]: + """ + Get optimal RVD path from source RVD to destination RVD + + @param src_rvd Tuple[int]: source RVD + @param dst_rvd Tuple[int]: destination RVD + + @return path Tuple[InterRVD]: + The first one is src_rvd. The last one is dst_rvd. + Otherwise they are intermediate RVD status + """ + src_ndevs = np.prod(src_rvd) + src = ('p',) + src_rvd + dst_ndevs = np.prod(dst_rvd) + dst = ('c',) + dst_rvd + + key = (tuple(ftensor.shape), np.prod(src_rvd), np.prod(dst_rvd)) + + if key in InterPathFinder._cached_inter_nodes and src in InterPathFinder._cached_inter_paths[key]: + nodes = InterPathFinder._cached_inter_nodes[key] + paths = InterPathFinder._cached_inter_paths[key][src] + else: + if key in InterPathFinder._cached_inter_nodes: + nodes = InterPathFinder._cached_inter_nodes[key] + edges = InterPathFinder._cached_inter_edges[key] + else: + nodes, edges = InterPathFinder.init_graph(ftensor, src_ndevs, dst_ndevs, cost_fn) + InterPathFinder._cached_inter_nodes[key] = nodes + InterPathFinder._cached_inter_edges[key] = edges + InterPathFinder._cached_inter_paths[key] = {} + # build cost + cost = np.full((len(nodes),), np.inf) + cost[nodes.index(src)] = 0 + # setup unvisited and visited set + unvisited = set(range(len(nodes))) + visited = set() + paths = [[] for _ in range(len(nodes))] + paths[nodes.index(src)] = [nodes.index(src)] + # dijkstra body + while len(unvisited) > 0: + min_cost, visit = np.inf, None + for idx in unvisited: + if cost[idx] < min_cost: + min_cost = idx + visit = idx + if visit is None: break + for neighbor in np.where(edges[visit] != np.inf)[0]: + new_cost = cost[visit] + edges[visit, neighbor] + if cost[neighbor] == np.inf or new_cost < cost[neighbor]: + cost[neighbor] = new_cost + paths[neighbor] = paths[visit] + [neighbor] + unvisited.remove(visit) + visited.add(visit) + InterPathFinder._cached_inter_paths[key][src] = paths + + # print for debug + # for idx, path in enumerate(paths): + # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") + + path = paths[nodes.index(dst)] + assert len(path) > 0, f"Un-reachable src RVD {src} -> dst RVD {dst}" + inter_rvds = tuple(nodes[idx] for idx in path) + return inter_rvds + + @staticmethod + def init_graph(ftensor: IRFullTensor, src_ndevs: int, dst_ndevs: int, cost_fn: Callable) -> Tuple[List[TRVD], np.ndarray]: + """ + Initialize the graph of RVD status graph. + + An additional positition tage is append to at the first element of each node, i.e., + For source (producer) layout: ('p', 2,1,1,2) means + For dest (consumer) layout: ('c', 2,1,1,2) means + + @param ftensor IRFullTensor: the full tensor + @param idevs int: total device number of source tensor + + @return nodes Tuple[TRVD] + @return edges np.ndarray: edges among nodes + """ + shape = tuple(ftensor.shape) + + if (shape, src_ndevs) in IntraPathFinder._cached_intra_nodes: + src_nodes = IntraPathFinder._cached_intra_nodes[(shape, src_ndevs)] + src_edges = IntraPathFinder._cached_intra_edges[(shape, src_ndevs)] + else: + src_nodes, src_edges = IntraPathFinder.init_graph(ftensor, src_ndevs, cost_fn) + IntraPathFinder._cached_intra_nodes[(shape, src_ndevs)] = src_nodes + IntraPathFinder._cached_intra_edges[(shape, src_ndevs)] = src_edges + IntraPathFinder._cached_intra_paths[(shape, src_ndevs)] = {} + + if (shape, dst_ndevs) in InterPathFinder._cached_inter_edges: + dst_nodes = IntraPathFinder._cached_intra_nodes[(shape, dst_ndevs)] + dst_edges = IntraPathFinder._cached_intra_edges[(shape, dst_ndevs)] + else: + dst_nodes, dst_edges = IntraPathFinder.init_graph(ftensor, dst_ndevs, cost_fn) + IntraPathFinder._cached_intra_nodes[(shape, dst_ndevs)] = dst_nodes + IntraPathFinder._cached_intra_edges[(shape, dst_ndevs)] = dst_edges + IntraPathFinder._cached_intra_paths[(shape, dst_ndevs)] = {} + nodes = tuple(('p',) + n for n in src_nodes ) + tuple(('c',) + n for n in dst_nodes) + edges = np.full((len(nodes), len(nodes)), np.inf) + edges[:len(src_nodes), :len(src_nodes)] = src_edges + edges[len(src_nodes):,len(src_nodes):] = dst_edges + # NVLink: 300GBps Inter-node: 100Gbps + for i in range(len(src_nodes)): + for j in range(len(dst_nodes)): + src, dst = src_nodes[i], dst_nodes[j] + if InterTransition.transitionable(src, dst) is None: continue + cost = InterPathFinder.estimate_cost( + ftensor, (('p',) + src, ('c',) + dst), cost_fn) + # set for [i, len(src_nodes) + j] + edges[i, len(src_nodes) + j] = cost + # set for [len(src_nodes) + j, i] + edges[len(src_nodes) + j, i] = cost + return nodes, edges + + @staticmethod + def decode(inter_rvds: Tuple[InterRVD]) -> Tuple[Tuple[TRVD], Tuple[TRVD]]: + """ + Decode searched inter-rvd paths into intra-rvd representations (TRVD) + for producer and consumer side. + + @param inter_rvds Tuple[InterRVD] + + @return prvds Tuple[TRVD]: rvd paths of producer side + @return crvds Tuple[TRVD]: rvd paths of consumer side + """ + bps = [idx for idx in range(len(inter_rvds) - 1) if inter_rvds[idx][0] != inter_rvds[idx+1][0]] + assert len(bps) == 1, \ + f"Expect path to be producer intra-rvd -> inter -> consumer intra-rvd: {inter_rvds}" + bp = bps[0] + + prvds = tuple(rvd for rvd in inter_rvds[:bp+1]) + assert all(rvd[0] == 'p' for rvd in prvds) + prvds = tuple(rvd[1:] for rvd in prvds) + if len(prvds) == 1: + prvds = prvds * 2 + + crvds = tuple(rvd for rvd in inter_rvds if rvd[0] == 'c') + assert all(rvd[0] == 'c' for rvd in crvds) + crvds = tuple(rvd[1:] for rvd in crvds) + if len(crvds) == 1: + crvds = crvds * 2 + + return prvds, crvds + + @staticmethod + def estimate_cost(ftensor: IRFullTensor, rvd_paths: Tuple[InterRVD], + cost_fn: Optional[Callable] = None) -> float: + """ + Estimate transition cost + + @return cost float + """ + cost_fn = InterPathFinder.default_cost_fn if cost_fn is None else cost_fn + # decode producer and consumer part + prvds, crvds = InterPathFinder.decode(rvd_paths) + # producer cost + pcost = IntraPathFinder.estimate_cost(ftensor, prvds, cost_fn) + # consumer cost + ccost = IntraPathFinder.estimate_cost(ftensor, crvds, cost_fn) + # inter-cost + pndevs = np.prod(prvds[-1]) + cndevs = np.prod(crvds[0]) + playout = RVDLayout.grid( + ftensor, r=prvds[-1][0], v=prvds[-1][1], + dims=prvds[-1][2:], devices=list(range(pndevs))) + _, prims = InterTransition.transition(playout, crvds[0], list(range(pndevs, pndevs + cndevs))) + icost = cost_fn(prims[0]) + # gather all + # consider differnt linkbandwidth intra NVLink 300GB/s vs. inter-node 100Gbps + comm_factor = 24 + return pcost + ccost + icost * comm_factor + + @staticmethod + def default_cost_fn(prim: IRAdapterPrim) -> int: + return prim.volume() + 1 # 1 is hop penalty diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py new file mode 100644 index 00000000..f969fcf3 --- /dev/null +++ b/cube/graph/gener/rvd/intra.py @@ -0,0 +1,596 @@ +from typing import Callable, Dict, List, Tuple, Optional, Set +from functools import partial +import numpy as np +import sys +import copy + +from cube.ir.dtype import IRDType +from cube.ir.tensor import IRFullTensor + +from cube.ir.adapter.prim import IRAdapterPrim +from cube.ir.adapter.prim import AllGatherPrim # d2r +from cube.ir.adapter.prim import AllToAllPrim # d2d +from cube.ir.adapter.prim import AllReducePrim # v2r +from cube.ir.adapter.prim import ReduceScatterPrim # v2d +from cube.ir.adapter.prim import ChunkPrim # r2d +from cube.ir.adapter.prim import VChunkPrim # r2v + +from cube.graph.gener.rvd.layout import RVDLayout + +from cube.graph.gener.utils import tensor_vd_repr + + +TShape = Tuple[int, ...] +TRVD = Tuple[int, ...] + + +class IntraTransition: + """ + Intra-RVD transition primitives + """ + + @staticmethod + def d2r(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: + """ + intra-RVD primitive D->R: allgather + + @param rvd Tuple[int]: input RVD + @param dim int: tensor dimension + @param chunks int: the number of chunks to transfer + + @return rvd Tuple[int]: output RVD + @return prim Callable: IRAdapter primitive + """ + assert rvd[2+dim] % chunks == 0, f"not dividable dim: {rvd[2+dim]} // {chunks}" + rvd = list(rvd) + rvd[0], rvd[2+dim] = rvd[0] * chunks, rvd[2+dim] // chunks + return rvd, partial(AllGatherPrim, dim=dim) + + @staticmethod + def d2d(rvd: TRVD, from_dim: int, to_dim: int, chunks: int) -> Tuple[TRVD, Callable]: + """ + intra-RVD primitive D(...,i,..)->D(..,j,...): alltoall + + @param rvd Tuple[int]: input RVD + @param from_dim int: source tensor axis + @param to_dim int: destination tensor axis + @param chunks int: the number of chunks to transfer + + @return rvd Tuple[int]: output RVD + @return prim Callable: IRAdapter primitive + """ + assert rvd[2+from_dim] % chunks == 0, f"not dividable dim: {rvd[2+from_dim]} // {chunks}" + rvd = list(rvd) + rvd[2+from_dim], rvd[2+to_dim] = rvd[2+from_dim] // chunks, rvd[2+to_dim] * chunks + return rvd, partial(AllToAllPrim, idim=from_dim, odim=from_dim) + + @staticmethod + def v2r(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: + """ + intra-RVD primitive V->R: allreduce + + @param dim int: tensor dimension + @param chunks int: the number of chunks to transfer + + @return rvd Tuple[int]: output RVD + @return prim Callable: IRAdapter primitive + """ + assert rvd[1] % chunks == 0, f"not dividable value chunks: {rvd[1]} // {chunks}" + rvd = list(rvd) + rvd[1], rvd[0] = rvd[1] // chunks, rvd[0] * chunks + return rvd, AllReducePrim + + @staticmethod + def v2d(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: + """ + intra-RVD primitive V->D: reduce-scatter + + @param dim int: tensor dimension + @param chunks int: the number of chunks to transfer + + @return rvd Tuple[int]: output RVD + """ + assert rvd[1] % chunks == 0, f"not dividable value chunks: {rvd[1]} // {chunks}" + rvd = list(rvd) + rvd[1], rvd[2+dim] = rvd[1] // chunks, rvd[2+dim] * chunks + return rvd, partial(ReduceScatterPrim, dim=dim) + + @staticmethod + def r2d(rvd: TRVD, dim: int, chunks: int) -> Tuple: + """ + intra-RVD primitive V->D: schunk + + @param dim int: tensor axis + @param chunks int: the number of chunks to transfer + + @return rvd Tuple[int]: output RVD + @return prim Callable: IRAdapter primitive + """ + assert rvd[0] % chunks == 0, f"not dividable replica: {rvd[0]} // {chunks}" + rvd = list(rvd) + rvd[0], rvd[2+dim] = rvd[0] // chunks, rvd[2+dim] * chunks + return rvd, partial(ChunkPrim, dim=dim) + + @staticmethod + def r2v(rvd: TRVD, chunks: int) -> Tuple: + """ + intra-RVD primitive V->D: schunk + + @param dim int: tensor axis + @param chunks int: the number of chunks to transfer + + @return rvd Tuple[int]: output RVD + @return prim Callable: IRAdapter primitive + """ + assert rvd[0] % chunks == 0, f"not dividable replica: {rvd[0]} // {chunks}" + rvd = list(rvd) + rvd[0], rvd[1] = rvd[0] // chunks, rvd[1] * chunks + return rvd, VChunkPrim + + @staticmethod + def transitionable(src_rvd: TRVD, dst_rvd: TRVD) -> Optional[Callable]: + """ + Check wheter a primitive exists to transform src_rvd to dst_rvd + + @param src_rvd TRVD: source RVD + @param dst_rvd TRVD: destination RVD + + @return trans_fn Optional[Callable]: None indicates no primitive found. + """ + trans_fn = None + incd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 < d2] + decd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 > d2] + if len(incd) != 1 or len(decd) != 1: return None + incd, decd = incd[0], decd[0] + # TODO: optimize: enable following may miss best solution + # ========= prune graph to avoid device mis-alignment ======== # + # for d in range(min(incd, decd) + 1, max(incd, decd)): + # if src_rvd[d] != 1 or dst_rvd[d] != 1: return None + # ============================================================ # + # if incd == 1: return None + if decd >= 2 and incd == 0: # d2r + trans_fn = partial(IntraTransition.d2r, dim=decd-2) + elif decd >= 2 and incd >= 2: # d2d + trans_fn = partial(IntraTransition.d2d, from_dim=decd-2, to_dim=incd-2) + elif decd == 1 and incd == 0: # v2r + trans_fn = IntraTransition.v2r + elif decd == 1 and incd >= 2: # v2d + trans_fn = partial(IntraTransition.v2d, dim=incd-2) + elif decd == 0 and incd >= 2: # r2d + trans_fn = partial(IntraTransition.r2d, dim=incd-2) + elif decd == 0 and incd == 1: # r2v + trans_fn = IntraTransition.r2v + return trans_fn + + @staticmethod + def transition(src_layout: RVDLayout, dst_rvd: TRVD) -> List[Tuple[RVDLayout, List[IRAdapterPrim]]]: + """ + Transfer from source RVD to destination RVD. + Get all possible device-placement choices for RVD + given the fixed device placement of RVD. + + @param src_layout RVDLayout: source ilayout + @param dst_rvd Tuple[int]: destination RVD + + @return rets List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. + """ + src_rvd = src_layout.vec + if src_rvd == dst_rvd: return [(src_layout, [])] + trans_fn = IntraTransition.transitionable(src_rvd, dst_rvd) + assert trans_fn is not None, f"Cannot find primitive: {src_rvd} -> {dst_rvd}" + # get primitive + incd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 < d2][0] + decd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 > d2][0] + chunks = src_rvd[decd] // dst_rvd[decd] + _, primitive = trans_fn(src_rvd, chunks=chunks) + + # get device spaces + optional_dims = {0, 1} + devices = tuple(t.device[0] for t in src_layout.mat.flatten()) + + ilayouts: List[RVDLayout] = [src_layout] + olayouts: List[RVDLayout] = [RVDLayout.grid(src_layout.ftensor, r=dst_rvd[0], v=dst_rvd[1], dims=dst_rvd[2:])] + # setup ilayout choices + if decd in optional_dims: + ftensor = src_layout.ftensor + for k in range(2, src_rvd[decd]): + if src_rvd[decd] % k != 0: continue + ilayout = RVDLayout.grid( + ftensor, r=src_rvd[0], v=src_rvd[1], dims=src_rvd[2:], devices=devices) + ilayout.inner_transpose(decd, k) + ilayouts.append(ilayout) + + # get olayouts with device placement + rets = [] + for ilayout in ilayouts: + for olayout in olayouts: + if len(ilayouts) > 1: olayout = copy.copy(olayout) + if len(olayouts) > 1: ilayout = copy.copy(ilayout) + # print(f'transition: {ilayout}{tuple(t.device[0] for t in ilayout.mat.flatten())} -> {olayout}') + imat = RVDLayout.dim2last(ilayout.mat, decd, chunks) + omat = RVDLayout.dim2last(olayout.mat, incd, chunks) + for itensor, otensor in zip(imat.flatten(), omat.flatten()): + otensor.cell = itensor.cell + prims = [] + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(primitive(itensors, otensors)) + rets.append((olayout, prims)) + return rets + + +class IntraPathFinder: + """ + intra-RVD Path finder for generating communication plans for RVDLayout + """ + + # intra-shard: cached nodes. paths[shape][i][j] = List[int] of indices from (src -> dst] + _cached_intra_nodes: Dict[Tuple[TShape, int], Tuple[TRVD]] = {} + _cached_intra_edges: Dict[Tuple[TShape, int], np.ndarray] = {} + _cached_intra_paths: Dict[Tuple[TShape, int], Dict[TRVD, List[List[int]]]] = {} + + @staticmethod + def path(ilayout: RVDLayout, olayout: RVDLayout, + cost_fn: Optional[Callable] = None) -> List[IRAdapterPrim]: + """ + Get primitive path of transforming ilayout into olayout. + ilayout must have same device set with olayout + + @param ilayout RVDLayout: input tensor layout + @param olayout RVDLayout: output tensor layout + @param cost_fn Optional[Callable]: cost function of each primitive. + Default (None) will use transmission volume as metrics + + @return all_primitives List[IRAdapterPrims]: all primitives for communication path + """ + assert ilayout.ftensor == olayout.ftensor, f"ilayout and olayout should have a same full tensor" + ftensor = ilayout.ftensor + src, dst = tuple(ilayout.vec), tuple(olayout.vec) + rvds = IntraPathFinder.get_optimal_path(ftensor, src, dst) + + # search for correct device mapping + align, all_prims = IntraPathFinder.device_align(ilayout, olayout, rvds) + + if not align: + warn_msg = ( + f"Fail to align intra-RVD devices. {ftensor}\n" + f"Path: {' -> '.join(str(rvd) for rvd in rvds)}\n" + f"ptensors:\n\t" + "\n\t".join(tensor_vd_repr(ptensor) for ptensor in ilayout.mat.flatten()) + "\n" + f"ctensors:\n\t" + "\n\t".join(tensor_vd_repr(ctensor) for ctensor in olayout.mat.flatten()) + "\n" + f"Switch to a fixed plan: ilayout -> FullReplica -> olayout" + ) + color, default = '\033[33m' , '\033[0m' + print(color+warn_msg+default, file=sys.stderr) + all_prims = IntraPathFinder.backup_path(ilayout, olayout, cost_fn) + + return all_prims + + @staticmethod + def backup_path(ilayout: RVDLayout, olayout: RVDLayout, + cost_fn: Optional[Callable] = None) -> List[IRAdapterPrim]: + """ + Get primitive path of transforming ilayout into olayout. + ilayout has the same device set with olayout. + + The path generation searches for a default communication plan + by ilayout -> FullReplica -> olayout. + + @param ilayout RVDLayout: input tensor layout + @param olayout RVDLayout: output tensor layout + @param cost_fn Optional[Callable]: cost function of each primitive. + Default (None) will use transmission volume as metrics + + @return all_primitives List[IRAdapterPrims]: all primitives for communication path + """ + assert ilayout.ftensor == olayout.ftensor, f"ilayout and olayout should have a same full tensor" + ftensor = ilayout.ftensor + src, dst = tuple(ilayout.vec), tuple(olayout.vec) + # create all-replicate rvd + rlayout = RVDLayout.grid(ftensor, r=ilayout.ndevs, v=1, dims=tuple(1 for _ in range(ilayout.ndims-2))) + for rt, ot in zip(rlayout.mat.flatten(), olayout.mat.flatten()): + rt.cell = ot.cell + rep = tuple(rlayout.vec) + # search for left primitives + left: List[TRVD] = IntraPathFinder.get_optimal_path(ftensor, src, rep, cost_fn) + align, lprims = IntraPathFinder.device_align(ilayout, rlayout, left) + assert align, f"Fail to align devices of backup plan at left side: {src} -> {rep}" + # search + right: List[TRVD] = IntraPathFinder.get_optimal_path(ftensor, rep, dst, cost_fn) + align, rprims = IntraPathFinder.device_align(rlayout, olayout, right) + assert align, f"Fail to align devices of backup plan at right side: {rep} -> {dst}" + return lprims + rprims + + @staticmethod + def device_align(ilayout: RVDLayout, olayout: RVDLayout, + rvd_path: Tuple[TRVD], _all_prims: Optional[None] = None) -> Tuple[bool, List[IRAdapterPrim]]: + """ + Align devices for intra-RVD + + @param ilayouts RVDLayout: source layout + @param olayout RVDLayout: target layout with correct device mapping + @param rvd_hops: Tuple[TRVD]: the hops from ilayout to olayout + (not contains ilayout at beginning, but contains olayout at last) + + @return success bool: True if found device, else False. + @return primitives List[IRAdapterPrim]: the correspoinding primitives + """ + _all_prims = [] if _all_prims is None else _all_prims + assert ilayout.vec == rvd_path[0] and olayout.vec == rvd_path[-1] + if len(rvd_path) == 1: + if not ilayout.align(olayout): + return False, [] + return True, _all_prims + else: + layout_prims = IntraTransition.transition(ilayout, rvd_path[1]) + for (hop_layout, hop_prims) in layout_prims: + ret, ret_prims = IntraPathFinder.device_align( + hop_layout, olayout, rvd_path[1:], _all_prims + hop_prims) + if ret: + return True, ret_prims + return False, [] + + @staticmethod + def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, + cost_fn: Optional[Callable] = None) -> Tuple[TRVD]: + """ + Get optimal RVD path from source RVD to destination RVD + + @param src_rvd Tuple[int]: source RVD + @param dst_rvd Tuple[int]: destination RVD + + @return path Tuple[Tuple[int]]: + The first one is src_rvd. The last one is dst_rvd. + Otherwise they are intermediate RVD status + """ + src_rvd, dst_rvd = tuple(src_rvd), tuple(dst_rvd) + if src_rvd == dst_rvd: return [src_rvd, dst_rvd] + + cost_fn = IntraPathFinder.default_cost_fn if cost_fn is None else cost_fn + shape = tuple(ftensor.shape) + ndevs = np.prod(np.array(src_rvd, dtype=int)) + key = (shape, ndevs) + + # get paths using dijkstra algorithm or cached + if key in IntraPathFinder._cached_intra_paths and src_rvd in IntraPathFinder._cached_intra_paths[key]: + paths = IntraPathFinder._cached_intra_paths[key][src_rvd] + else: + # initialize the graph if not cached + if key not in IntraPathFinder._cached_intra_nodes: + nodes, edges = IntraPathFinder.init_graph(ftensor, ndevs, cost_fn) + IntraPathFinder._cached_intra_nodes[key] = nodes + IntraPathFinder._cached_intra_edges[key] = edges + IntraPathFinder._cached_intra_paths[key] = {} + nodes = IntraPathFinder._cached_intra_nodes[key] + edges = IntraPathFinder._cached_intra_edges[key] + # build and initialize cost table + cost = np.full((len(nodes),), np.inf) + cost[nodes.index(src_rvd)] = 0 + # setup unvisited and visited set + unvisited = set(range(len(nodes))) + visited = set() + paths = [[] for _ in range(len(nodes))] + paths[nodes.index(src_rvd)] = [nodes.index(src_rvd)] + # dijkstra body + while len(unvisited) > 0: + min_cost, visit = np.inf, None + for idx in unvisited: + if cost[idx] < min_cost: + min_cost = idx + visit = idx + if visit is None: break # for remaining states that cannot reach + for neighbor in np.where(edges[visit] != np.inf)[0]: + new_cost = cost[visit] + edges[visit, neighbor] + if cost[neighbor] == np.inf or new_cost < cost[neighbor]: + cost[neighbor] = new_cost + paths[neighbor] = paths[visit] + [neighbor] + unvisited.remove(visit) + visited.add(visit) + IntraPathFinder._cached_intra_paths[key][src_rvd] = paths + + # print for debug + # for idx, path in enumerate(paths): + # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") + + # get layout + nodes = IntraPathFinder._cached_intra_nodes[key] + path: List[int] = paths[nodes.index(dst_rvd)] + rvds: List[Tuple[int]] = [nodes[idx] for idx in path] + assert len(path) > 0, f"Un-reachable src RVD ({src_rvd}) -> dst RVD ({dst_rvd})" + # print(f'get optimal path from {src_rvd} -> {dst_rvd}: {rvds}') + return rvds + + @staticmethod + def get_backup_path(ftensor: IRFullTensor, src_rvd: TRVD, dst_rvd: TRVD, + cost_fn: Optional[Callable] = None) -> Tuple[TRVD]: + """ + Get backup path + """ + rep = (np.prod(np.array(src_rvd)), 1) + (1,) * len(ftensor.shape) + # search for left primitives + left: List[TRVD] = IntraPathFinder.get_optimal_path(ftensor, src_rvd, rep, cost_fn) + # search + right: List[TRVD] = IntraPathFinder.get_optimal_path(ftensor, rep, dst_rvd, cost_fn) + # omit right[0] as same with left[-1] + return left + right[1:] + + @staticmethod + def get_device_space(ftensor: IRFullTensor, rvd_paths: List[TRVD], placement: Tuple[int]) -> Set[Tuple[int]]: + """ + Get all possible device placement of the last RVD given the rvd transition paths. + + @param ftensor IRFullTensor + @param rvd_paths Tuple[TRVDS]: transition RVD paths from source to destination + @param placement Tuple[int]: device placement of the first RVD in rvd_paths + + @return placements Set[Tuple[int]]: all possible device placement + """ + init, hops = rvd_paths[0], rvd_paths[1:] + rvds: List[RVDLayout] = [RVDLayout.grid(ftensor, r=init[0], v=init[1], dims=init[2:], devices=placement)] + for hop in hops: + for _ in range(len(rvds)): + layout = rvds.pop(0) + rets = IntraTransition.transition(layout, hop) + for (olayout, _) in rets: + rvds.append(olayout) + devices: Set[Tuple[int]] = set() + for rvd in rvds: + assert rvd.vec == tuple(hops[-1]) + devices.add(tuple(t.device[0] for t in rvd.mat.flatten())) + return devices + + @staticmethod + def init_graph(ftensor: IRFullTensor, ndevs: int, cost_fn: Optional[Callable] = None) -> Tuple[List[TRVD], np.ndarray]: + """ + Initialize the graph of RVD status graph. + + @param ftensor IRFullTensor: the full tensor + @param ndevs int: total device number + + @return nodes Tuple[TRVD] + @return edges np.ndarray: edges among nodes + """ + cost_fn = IntraPathFinder.default_cost_fn if cost_fn is None else cost_fn + nodes = tuple(IntraPathFinder.get_rvd_space(ftensor, ndevs)) + edges = np.full((len(nodes), len(nodes)), np.inf) + # initialize the cost + for i in range(len(nodes)): + for j in range(len(nodes)): + if i == j: continue + src, dst = nodes[i], nodes[j] + if IntraTransition.transitionable(src, dst) is None: continue + cost = IntraPathFinder.estimate_cost(ftensor, [src, dst], cost_fn) + edges[i, j] = cost + return nodes, edges + + @staticmethod + def get_rvd_space(ftensor: IRFullTensor, ndevs: int) -> List[Tuple[int, ...]]: + """ + Get all possible RVD representations given ftensor. + + This space is pruned by limiting partition number of each RVD dimension + in the range of [min(ilayout[dim], olayout[dim]), max(ilayout[dim], olayout[dim])] + + @param ftensor IRFullTensor + @param ilayout GridLayout: input layout + @param olayout GridLayout: output layout + + @return layouts List[GridLayout]: + """ + all_layouts: List[int] = [] + + def factors(ndevs: int, length: int): + if length == 1: yield [ndevs] + else: + for i in range(1, ndevs + 1): + if ndevs % i == 0: + for res in factors(ndevs // i, length - 1): + yield [i] + res + + for rvd in factors(ndevs, 2+len(ftensor.shape)): + skip = False + for dimlen, pnum in zip(ftensor.shape, rvd[2:]): + if dimlen % pnum != 0: + skip = True + break + if not skip: + all_layouts.append(tuple(rvd)) + return all_layouts + + @staticmethod + def estimate_cost(ftensor: IRFullTensor, rvd_paths: List[Tuple[TRVD]], cost_fn: Optional[Callable] = None) -> float: + """ + Estimate transition cost + """ + cost_fn = IntraPathFinder.default_cost_fn if cost_fn is None else cost_fn + cost = 0.0 + if len(rvd_paths) == 0: return cost + if len(rvd_paths) == 2 and (rvd_paths[0] == rvd_paths[1]): return cost + src, hops = rvd_paths[0], rvd_paths[1:] + for hop in hops: + trans_fn = IntraTransition.transitionable(src, hop) + assert trans_fn is not None, "Fails to find primitive for estimating cost" + incd = [dim for dim, (d1, d2) in enumerate(zip(src, hop)) if d1 < d2][0] + decd = [dim for dim, (d1, d2) in enumerate(zip(src, hop)) if d1 > d2][0] + chunks = src[decd] // hop[decd] + _, primitive = trans_fn(src, chunks=chunks) + ilayout: RVDLayout = RVDLayout.grid(ftensor, r=src[0], v=src[1], dims=src[2:]) + olayout: RVDLayout = RVDLayout.grid(ftensor, r=hop[0], v=hop[1], dims=hop[2:]) + imat = RVDLayout.dim2last(ilayout.mat, decd, chunks) + omat = RVDLayout.dim2last(olayout.mat, incd, chunks) + prim = primitive(imat.reshape(-1, chunks)[0], omat.reshape(-1, chunks)[0]) + cost += cost_fn(prim) + src = hop + return cost + + @staticmethod + def default_cost_fn(prim: IRAdapterPrim) -> int: + return prim.volume() + 1 # 1 is hop penalty + + +class IntraAutoPlacer: + + @staticmethod + def auto_place(shape: TShape, + fw_src_rvd: TRVD, fw_dst_rvd: TRVD, + bw_src_rvd: Optional[TRVD], bw_dst_rvd: Optional[TRVD], + src_placement: List[int], + cost_fn: Optional[Callable] = None): + """ + Search for a good device placement for + source and destination RVD partition + + @param shape Tuple[int]: full tensor shape + @param fw_src_rvd Tuple[int]: forward producer RVD layout vector + @param fw_dst_rvd Tuple[int]: forward consumer RVD layout vector + @param bw_src_rvd Optional[Tuple[int]]: backward producer RVD layout vector + @param bw_dst_rvd Optional[Tuple[int]]: backward consumer RVD layout vector + + @return devices List[int]: device sequence for RVD tensors + @return cost float: Cost of communication plan + """ + src_placement = tuple(src_placement) + ftensor = IRFullTensor(shape, dtype=IRDType.float16) + cost_fn = IntraPathFinder.default_cost_fn if cost_fn is None else cost_fn + + # forward pass + fw_rvd_hops = IntraPathFinder.get_optimal_path( + ftensor, fw_src_rvd, fw_dst_rvd, cost_fn=cost_fn) + fw_consumer_devices: Set[Tuple[int]] = IntraPathFinder.get_device_space( + ftensor, fw_rvd_hops, src_placement) + + # backward pass + if (bw_src_rvd is not None) and (bw_dst_rvd is not None): + bw_rvd_hops = IntraPathFinder.get_optimal_path( + ftensor, bw_src_rvd, bw_dst_rvd, cost_fn=cost_fn) + devices = set() + for bw_producer_devs in fw_consumer_devices: + bw_consumer_devices = IntraPathFinder.get_device_space( + ftensor, bw_rvd_hops, bw_producer_devs + ) + # FIXME: this comparison on tuples some misses possible placement + # that can be actually aligned by using layout.align (false possitive). + if src_placement in bw_consumer_devices: + devices.add(bw_producer_devs) + break + else: + devices = fw_consumer_devices + + placement = None + # - if find, choose one + if len(devices) > 0: + placement = list(devices)[0] + # - if not find, change forward one while use backup plan for backward one + else: + placement = list(fw_consumer_devices)[0] + print(f"================ forward-backward mis-aligned! ============== \n" + f"fw device choices: {fw_consumer_devices} | hops: {'->'.join(str(rvd) for rvd in fw_rvd_hops)}\n" + f"bw hops: {'->'.join(str(rvd) for rvd in bw_rvd_hops)}\n" + f"using placement: {placement}\n" + f"=============================================================") + bw_rvd_hops = IntraPathFinder.get_backup_path(ftensor, bw_src_rvd, bw_dst_rvd, cost_fn) + + # estimate cost + cost = IntraPathFinder.estimate_cost(ftensor, fw_rvd_hops, cost_fn) + if (bw_src_rvd is not None) and (bw_dst_rvd is not None): + cost += IntraPathFinder.estimate_cost(ftensor, bw_rvd_hops, cost_fn) + return placement, cost diff --git a/cube/graph/gener/rvd/layout.py b/cube/graph/gener/rvd/layout.py new file mode 100644 index 00000000..62773ec0 --- /dev/null +++ b/cube/graph/gener/rvd/layout.py @@ -0,0 +1,500 @@ +from typing import Dict, List, Tuple, Optional +import copy +import numpy as np + +from cube.ir.cten import IRCell +from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import ValueMap + +from cube.ir.adapter.prim import MovePrim # p2p +from cube.ir.adapter.prim import BroadcastPrim +from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim +from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim + + +TShape = Tuple[int, ...] +TRVD = Tuple[int, ...] + + +class RVDLayout: + """ + This class assumes a full-tensor can only be + uniformly partitioned / replicated on dimensions and values. + + A partition plan N-dim tensor layout can be represented as + : R (replica), V (value), dim_i (dimension) + """ + + def __init__(self, ftensor: IRFullTensor, subtensors: List[IRSubTensor], mats: np.ndarray): + """ + ftensor: N-dim FullTensor + subtensors: List[IRSubTensors] + mats: Array[IRSubTensor]: + (2+N)-dim matrix, with index respect to + """ + self.ftensor = ftensor + self.subtensors = subtensors + self._mats = mats + + @property + def R(self) -> int: + return self._mats.shape[0] + + @property + def V(self) -> int: + return self._mats.shape[1] + + @property + def D(self) -> Tuple[int]: + return tuple(self._mats.shape[2:]) + + @property + def vec(self) -> Tuple[int]: + return tuple(self._mats.shape) + + @property + def ndims(self): + return len(self._mats.shape) + + @property + def ndevs(self): + return len(self.subtensors) + + @property + def mat(self): + return self._mats + + def tensor(self, r: int, v: int, d: List[int]) -> IRSubTensor: + """ + Get subtenor indexed by RVD position. + """ + assert r <= self.R and v <= self.V and len(d) == len(self.D), "out of scope" + indices = [r, v] + list(d) + return self._mats[tuple(indices)] + + def __repr__(self): + dscp = f'T{self.ftensor._id}' + return dscp + + def __copy__(self): + tensors = [] + for t in self.mat.flatten(): + tensor = copy.copy(t) + tensor.cell = t.cell + tensors.append(tensor) + mat = np.array(tensors).reshape(self.mat.shape) + return RVDLayout(self.ftensor, tensors, mat) + + # ====== inter-RVD transition primitives ====== # + + def incr(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: + """ + inter-RVD primitive +R: broadcast + + @param chunks int: the number of chunks to transfer + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. + """ + rvd = list(self.vec) + rvd[0] = rvd[0] * chunks + olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) + # set device + if devices is not None: + assert devices.size == len(self.subtensors) * chunks + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims + imat = RVDLayout.dims2last(self.mat, [0]).flatten() + omat = RVDLayout.dims2last(olayout.mat, [0]).reshape(-1, chunks) + prims = [] + for src, dsts in zip(imat, omat): + if chunks == 1: + prims.append(MovePrim([src], dsts)) + else: + prims.append(BroadcastPrim([src], [src] + list(dsts))) + return [(olayout, prims),] + + def decr(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: + """ + inter-RVD primitive -R: move + + @param chunks int: the number of chunks to transfer + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. + """ + assert self.R % chunks == 0, f"not divisible replica {self.R} // {chunks}" + rvd = list(self.vec) + rvd[0] = rvd[0] // chunks + olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) + # set device + if devices is not None: + assert devices.size == len(self.subtensors) // chunks + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims + imat = RVDLayout.dims2last(self.mat, [0]).reshape(-1, chunks) + omat = RVDLayout.dims2last(olayout.mat, [0]).flatten() + prims = [] + for srcs, dst in zip(imat, omat): + prims.append(MovePrim([srcs[0]], [dst])) + return [(olayout, prims),] + + def incd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None) -> Tuple: + """ + inter-RVD primitive +D: RD-Scatter + + @param chunks int: the number of chunks to transfer + @param dim int: tensor axis + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. + """ + rvd = list(self.vec) + rvd[2+dim] = rvd[2+dim] * chunks + olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) + # set device + if devices is not None: + assert devices.size == len(self.subtensors) * chunks + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims + imat = RVDLayout.dims2last(self.mat, [2+dim]).flatten() + omat = RVDLayout.dims2last(olayout.mat, [2+dim]).reshape(-1, chunks) + prims = [] + for src, dsts in zip(imat, omat): + prims.append(RDScatterPrim([src], dsts, dim=dim)) + return olayout, prims + + def decd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None) -> Tuple: + """ + inter-RVD primitive +D: RD-Gather + + @param chunks int: the number of chunks to transfer + @param dim int: tensor axis + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. + """ + assert self.D[dim] % chunks == 0, f"not divisible dim: {self.D[dim]} % {chunks} != 0" + rvd = list(self.vec) + rvd[2+dim] = rvd[2+dim] // chunks + olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) + # set device + if devices is not None: + assert devices.size == len(self.subtensors) // chunks + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims + imat = RVDLayout.dims2last(self.mat, [2+dim]).reshape(-1, chunks) + omat = RVDLayout.dims2last(olayout.mat, [2+dim]).flatten() + prims = [] + for srcs, dst in zip(imat, omat): + prims.append(RDGatherPrim(srcs, [dst], dim=dim)) + return [(olayout, prims),] + + def incv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: + """ + inter-RVD primitive +V: RV-Scatter + + @param chunks int: the number of chunks to transfer + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. + """ + rvd = list(self.vec) + rvd[1] = rvd[1] * chunks + olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) + # set device + if devices is not None: + assert devices.size == len(self.subtensors) * chunks + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims + imat = RVDLayout.dims2last(self.mat, [1]).flatten() + omat = RVDLayout.dims2last(olayout.mat, [1]).reshape(-1, chunks) + prims = [] + for src, dsts in zip(imat, omat): + prims.append(RVScatterPrim([src], dsts)) + return [(olayout, prims),] + + def decv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: + """ + inter-RVD primitive -V: RV-Gather + + @param chunks int: the number of chunks to transfer + @param devices numpy.ndarray: the desired output device + + @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. + """ + assert self.V % chunks == 0, f"not divisable value split: {self.V} % {chunks} != 0" + rvd = list(self.vec) + rvd[1] = rvd[1] // chunks + olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) + # set device + if devices is not None: + assert devices.size == len(self.subtensors) // chunks + for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = int(devid) + # setup prims + imat = RVDLayout.dims2last(self.mat, [1]).reshape(-1, chunks) + omat = RVDLayout.dims2last(olayout.mat, [1]).flatten() + prims = [] + for srcs, dst in zip(imat, omat): + prims.append(RVGatherPrim(srcs, [dst])) + return [(olayout, prims),] + + def align(self, layout) -> bool: + """ + Check whether the layout is same with self. + + The same means 1) sub-tenosrs are same 2) device are aligned + + @param layout RVDLayout + + @return same bool: + """ + if not isinstance(layout, RVDLayout): + return False + tensors: List[IRSubTensor] = list(self.mat.flatten()) + for t in layout.mat.flatten(): + dev_match = False + for idx in range(len(tensors)): + t2 = tensors[idx] + if t == t2 and set(t.device) == set(t2.device): + tensors.pop(idx) + dev_match = True + break + if not dev_match: return False + return True + + def inner_transpose(self, dim: int, chunks: int): + """ + transpose ordering of tensor within a dimension. + """ + assert 0 <= dim and dim < len(self._mats.shape) + assert self.vec[dim] % chunks == 0 + ori_shape = list(self.vec) + new_shape = list(self.vec) + new_shape.insert(dim, self.vec[dim] // chunks) + new_shape[dim+1] = chunks + self._mats = self._mats.reshape(new_shape) + axes = list(range(len(new_shape))) + axes[dim], axes[dim+1] = axes[dim+1], axes[dim] + self._mats = self._mats.transpose(axes) + self._mats = self._mats.reshape(ori_shape) + + @staticmethod + def dim2last(mat: np.ndarray, dim: int, chunk: int) -> np.ndarray: + shape = list(mat.shape) + assert shape[dim] % chunk == 0 + shape[dim] = shape[dim] // chunk + shape.insert(dim+1, chunk) + mat = mat.reshape(shape) + # move the axis to the last + mat = np.moveaxis(mat, dim+1, -1) + return mat + + @staticmethod + def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int], devices: Optional[Tuple[int]] = None): + """ + partition a ftensor using grid layout of + """ + dims = tuple(dims) + def dummy_assign(tensor: IRSubTensor, devid: int): + tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell.device = devid + + mats = np.empty((r, v) + dims, dtype=IRSubTensor) + all_subtensors = [] + + def iter_idx(dims: List[int]) -> Tuple[int]: + if len(dims) == 0: + yield () + else: + for i in range(dims[0]): + for indices in iter_idx(dims[1:]): + yield (i,) + indices + # generate tensor for each index + for indices in iter_idx((v,)+dims): + valmap = ValueMap((indices[0], v)) + indmap = [] + shape = [] + for dim, (nchunk, index) in enumerate(zip(dims, indices[1:])): + assert ftensor.shape[dim] % nchunk == 0, f"not dividable for {nchunk} chunks over dim {dim}. ftensor shape: {ftensor.shape}" + csize = ftensor.shape[dim] // nchunk + start = csize * index + indmap.append((start, start+csize)) + shape.append(csize) + subtensor = ftensor.select(tuple(indmap), valmap) + # replicate + subtensors = [copy.copy(subtensor) for _ in range(r)] + all_subtensors += subtensors + mats[(slice(None),)+indices] = np.array(subtensors, dtype=IRSubTensor) + + # devices + if devices is not None: + assert len(devices) == len(all_subtensors), f"devices number {len(devices)} not match with RVD number {len(all_subtensors)}" + for tensor, devid in zip(mats.flatten(), devices): + dummy_assign(tensor, int(devid)) + + return RVDLayout(ftensor, all_subtensors, mats) + + @staticmethod + def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): + """ + convert ftensor and subtensors into a RVDLayout. + + If failed, raise error + """ + _replica: int = None + _value: int = None + _dims: List[int] = [None] * len(ftensor.shape) + _tindex: Dict[int, List[int]] = dict() + + ndims = len(ftensor.shape) + + replicas: Dict[int, List[IRSubTensor]] = dict() + vchunks: set = set() + dchunks: List[set] = [set() for _ in range(ndims)] + + for subtensor in subtensors: + tid = id(subtensor) + # set up replica + if subtensor.tid not in replicas: + replicas[subtensor.tid] = [] + _tindex[tid] = [len(replicas[subtensor.tid])] + replicas[subtensor.tid].append(subtensor) + # setup value + _tindex[tid].append(subtensor.valmap[0]) + vchunks.add(subtensor.valmap[1]) + # setup dimensions + for dim in range(ndims): + snele = subtensor.shape[dim] + start = subtensor.indmap[dim][0] + fnele = ftensor.shape[dim] + if fnele % snele != 0 or start % snele != 0: + raise RuntimeError( + f"dimension split error:\n" + f"Full Tensor: {ftensor}\n" + f"full nele: {fnele}, sub nele: {snele}, start: {start}" + ) + dchunks[dim].add(fnele // snele) + _tindex[tid].append(start // snele) + # replica (R) + nreplicas = set(len(ts) for ts in replicas.values()) + if len(nreplicas) != 1: + raise RuntimeError(f"different replicas: {nreplicas}") + _replica = list(nreplicas)[0] + # value (V) + nchunks = set(t.valmap[1] for t in subtensors) + if len(nchunks) != 1: + raise RuntimeError(f"different value split: {nchunks}") + _value = list(nchunks)[0] + # dimension (D) + for dim in range(ndims): + if len(dchunks[dim]) != 1: + raise RuntimeError(f"different dimension split: {dchunks[dim]}") + _dims[dim] = list(dchunks[dim])[0] + + # set matrix + mats = np.empty([_replica, _value] + _dims, dtype=IRSubTensor) + for subtensor in subtensors: + idx = tuple(_tindex[id(subtensor)]) + assert mats[idx] is None, f"repeating entry. mutiple same {subtensor}" + mats[tuple(idx)] = subtensor + assert not (mats == None).any(), "at least one entry not set" + return RVDLayout(ftensor, subtensors, mats) + + +class RVDInspector: + + @staticmethod + def draw(prvd: RVDLayout, crvd: RVDLayout, outfile: str) -> None: + """ + Draw producer RVDLayout and consumer RVDLayout + """ + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + + max_dev = max( + max(t.device[0] for t in prvd.subtensors), max(t.device[0] for t in crvd.subtensors) + ) + min_dev = min( + min(t.device[0] for t in prvd.subtensors), min(t.device[0] for t in crvd.subtensors) + ) + devlen = max_dev - min_dev + plt.close('all') + plt.rcParams['figure.figsize'] = (4.0 * devlen, 7.0) + fig, ax = plt.subplots() + + fontsize = 30 + + ax.set_xlim((-0.5, devlen+0.5)) + ax.set_ylim((0, 5)) + + ptensors = prvd.mat.flatten().tolist() + ctensors = crvd.mat.flatten().tolist() + + recflen = 0.8 + def draw_subtensor(t: IRSubTensor, xy: Tuple[int], color: str): + assert len(t.shape) == 2, "Only able to draw 2-D tensor" + x, y = xy + # full tensor + rec = Rectangle(xy, recflen, recflen, color='white', ec='black', lw=2.0) + # sub tensor + subx_nchunks = t.parent.shape[1] // t.shape[1] + subw = recflen / subx_nchunks + subx = x + subw * (t.indmap[1][0] // t.shape[1]) + + suby_nchunks = t.parent.shape[0] // t.shape[0] + subh = recflen / suby_nchunks + suby = y + subh * (t.indmap[0][0] // t.shape[1]) + + # if t.valmap != (0, 1): + ax.text(x=x+recflen/2, y=y+recflen+recflen/2, s=f'val({t.valmap[0]}/{t.valmap[1]})', + fontsize=fontsize, ha='center', va='center', color='black') + + subrec = Rectangle((subx, suby), subw, subh, color=color, ec='black', lw=2.0) + ax.add_artist(rec) + ax.add_artist(subrec) + + for ptensor in ptensors: + x, y = ptensor.device[0]-min_dev-0.4, 3 + draw_subtensor(ptensor, (x, y), 'blue') + + ax.text(x=-1, y=3+recflen/2, s='Producer', + fontsize=fontsize, ha='center', va='center', color='black') + + for ctensor in ctensors: + x, y = ctensor.device[0]-min_dev-0.4, 0.5 + draw_subtensor(ctensor, (x, y), 'orange') + + ax.text(x=-1, y=0.5+recflen/2, s='Consumer', + fontsize=fontsize, ha='center', va='center', color='black') + + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize) + + ax.spines['bottom'].set_color('white') + ax.spines['top'].set_color('white') + ax.spines['left'].set_color('white') + ax.spines['right'].set_color('white') + + ax.get_yaxis().set_visible(False) + plt.savefig(outfile) + + \ No newline at end of file diff --git a/tests/adapter/test_inter_rvd.py b/tests/adapter/test_inter_rvd.py new file mode 100644 index 00000000..497a2930 --- /dev/null +++ b/tests/adapter/test_inter_rvd.py @@ -0,0 +1,88 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=1 \ + tests/adapter/test_inter_rvd.py +""" + +from typing import List, Tuple +import cube +from cube.ir.tensor import IRFullTensor +from cube.graph.gener.rvd.layout import RVDLayout, RVDInspector +from cube.graph.gener.rvd.inter import InterPathFinder +import numpy as np + +from cube.graph.gener.utils import tensor_vd_repr + + +cube.init() + + +def factors(k: int, num: int) -> List[Tuple[int]]: + """ + get all possible sequence k1 * k2 * .. k_{num} = k + """ + if num == 1: return [(k,)] + res = [] + for i in range(1, k): + if k % i != 0: continue + for sub_res in factors(k // i, num - 1): + res.append((i,) + sub_res) + return res + + +def test_one_f_case(): + + fshape = [128, 256, 512] + + src_r, src_v, src_d = 1,4,(1,1,2) + dst_r, dst_v, dst_d = 2,1,(2,1,2) + src_rvd = (src_r, src_v) + src_d + dst_rvd = (dst_r, dst_v) + dst_d + + pndevs = np.prod(src_rvd) + cndevs = np.prod(dst_rvd) + + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + + pdevs = list(range(pndevs)) + fp_rvd = RVDLayout.grid(ftensor, r=src_r, v=src_v, dims=src_d, devices=pdevs) + + cdevs = list(range(pndevs, pndevs + cndevs)) + fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=cdevs) + + rvds = InterPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) + print(f"optimal path: {' -> '.join(str(rvd) for rvd in rvds)}") + + fprims = InterPathFinder.path(fp_rvd, fc_rvd) + for prim in fprims: + print(prim) + + +def test_all_f_cases_fix_placement(): + + fshape = [128, 256, 512] + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + + pndevs = 4 + cndevs = 8 + + ndims = len(fshape) + 2 + for src_rvd in factors(pndevs, ndims): + for dst_rvd in factors(cndevs, ndims): + if src_rvd == dst_rvd or src_rvd[1] < dst_rvd[1]: continue + print(f'test generating | source rvd: {src_rvd}, destination rvd: {dst_rvd}') + pdevs = list(range(pndevs)) + fp_rvd = RVDLayout.grid(ftensor, r=src_rvd[0], v=src_rvd[1], dims=src_rvd[2:], devices=pdevs) + + cdevs = list(range(pndevs, pndevs + cndevs)) + fc_rvd = RVDLayout.grid(ftensor, r=dst_rvd[0], v=dst_rvd[1], dims=dst_rvd[2:],devices=cdevs) + + _ = InterPathFinder.path(fp_rvd, fc_rvd) + rvds = InterPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) + print(f"==> path: {'->'.join(str(rvd) for rvd in rvds)}") + + +if __name__ == '__main__': + + # test_one_f_case() + test_all_f_cases_fix_placement() \ No newline at end of file diff --git a/tests/adapter/test_intra_rvd.py b/tests/adapter/test_intra_rvd.py index 8ea0663f..8c517a17 100644 --- a/tests/adapter/test_intra_rvd.py +++ b/tests/adapter/test_intra_rvd.py @@ -1,197 +1,292 @@ """ OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ + --nproc_per_node=1 \ tests/adapter/test_intra_rvd.py """ + from typing import List, Tuple import cube -from cube.ir.operator import IRFwOperation, IRDataOperation from cube.ir.tensor import IRFullTensor -from cube.graph.graph import IRGraph -from cube.graph.gener.layout import GridLayout -from cube.graph.function.dimops import IRDimops -from cube.algorithm.generics import GenericDistAlgo -import torch +from cube.graph.gener.rvd.layout import RVDLayout, RVDInspector +from cube.graph.gener.rvd.intra import IntraPathFinder, IntraAutoPlacer, IntraTransition import numpy as np +from cube.graph.gener.utils import tensor_vd_repr + + cube.init() -class RVDSplit(GenericDistAlgo): +def factors(k: int, num: int) -> List[Tuple[int]]: + """ + get all possible sequence k1 * k2 * .. k_{num} = k + """ + if num == 1: return [(k,)] + res = [] + for i in range(1, k): + if k % i != 0: continue + for sub_res in factors(k // i, num - 1): + res.append((i,) + sub_res) + return res - def __init__(self, node: IRDimops): - super().__init__(node) - def satisfy(self, in_rvd: Tuple[int], out_rvd: Tuple[int]): - return True - - def instantiate(self, in_rvd: Tuple[int], out_rvd: Tuple[int]) -> List[IRFwOperation]: - assert np.prod(np.array(in_rvd, dtype=int)) == np.prod(np.array(out_rvd, dtype=int)), \ - f"tensor number not match: {in_rvd}, {out_rvd}" - assert tuple(in_rvd)[2:] == tuple(out_rvd)[2:], f"input / output shape should be same" - - node: IRDimops = self.node - iftensor: IRFullTensor = node.input(0).parent - itensors = GridLayout.grid(iftensor, r=in_rvd[0], v=in_rvd[1], dims=in_rvd[2:]).mat.flatten() - oftensor: IRFullTensor = node.output(0).parent - otensors = GridLayout.grid(oftensor, r=out_rvd[0], v=out_rvd[1], dims=out_rvd[2:]).mat.flatten() - subnodes = [] - for itensor, otensor in zip(itensors, otensors): - subnode = node.new([itensor, 2], [otensor]) - subnodes.append(subnode) - return subnodes +def test_intra_transition(): + + fshape = [256, 256] + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + + src = (1, 2, 1, 4) + dst = (1, 1, 1, 8) + devs = list(range(8)) + src_rvd = RVDLayout.grid(ftensor, r=src[0], v=src[1], dims=src[2:], devices=devs) + + rets = IntraTransition.transition(src, dst, src_rvd, True) + for idx, (layout, prims) in enumerate(rets): + RVDInspector.draw(src_rvd, layout, f'rvd-trans-{idx}.png') + + + +def test_transition_space(): + + fshape = [256, 256] + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + + src = (1, 2, 1, 4) + dst = (1, 1, 1, 8) + devs = list(range(8)) + + choices = IntraPathFinder.get_device_space(ftensor, [src, dst], src_placement=devs) + print('choices:', choices) + + reverse_choices = IntraPathFinder.get_device_space(ftensor, [src, dst], dst_placement=devs) + print('reverse_choices:', reverse_choices) + + # draw reverse output + for idx, choice in enumerate(choices): + src_rvd = RVDLayout.grid(ftensor, r=src[0], v=src[1], dims=src[2:], devices=devs) + dst_rvd = RVDLayout.grid(ftensor, r=dst[0], v=dst[1], dims=dst[2:], devices=choice) + RVDInspector.draw(src_rvd, dst_rvd, f'rvd-{idx}.png') -class TestModel(torch.nn.Module): - - def __init__(self) -> None: - super().__init__() - self.param = torch.nn.Parameter(torch.empty(1024, 1024)) - - def forward(self, x): - x = torch.matmul(x, self.param) - # residual = x - x = x * 2 - x = x * 2 - x = x * 2 - x = x * 2 - x = x * 2 - x = x * 2 - x = x * 2 - # x = x + residual - x = torch.sum(x) - return x - - -def _ntp(graph, node: IRDimops, idx: int, dim: int, devs: List[int]): - algo = node.algorithms('dim') - nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert nodes is not None - for devid, node in zip(devs, nodes): - graph.assign(node, devid) - return nodes - - -def _tp(graph, node: IRFwOperation, in_rvd: Tuple[int], out_rvd: Tuple[int], devs: List[int]): - algo = RVDSplit(node) - nodes = graph.partition(node, algo, in_rvd=in_rvd, out_rvd=out_rvd) - assert nodes is not None - for devid, node in zip(devs, nodes): - graph.assign(node, devid) - return nodes - - -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - nodes = graph.replicate(node, times=len(devs)) - assert nodes is not None - for devid, node in zip(devs, nodes): - graph.assign(node, devid) - return nodes + # draw reverse output + for idx, choice in enumerate(reverse_choices): + src_rvd = RVDLayout.grid(ftensor, r=src[0], v=src[1], dims=src[2:], devices=choice) + dst_rvd = RVDLayout.grid(ftensor, r=dst[0], v=dst[1], dims=dst[2:], devices=devs) + RVDInspector.draw(src_rvd, dst_rvd, f'rvd-reverse-{idx}.png') + + +def test_one_f_case(): + + fshape = [128, 256, 512] + + src_r, src_v, src_d = 1,4,(1,1,2) + dst_r, dst_v, dst_d = 2,1,(2,1,2) + src_rvd = (src_r, src_v) + src_d + dst_rvd = (dst_r, dst_v) + dst_d + ndevs = src_r * src_v * np.prod(np.array(src_d)) + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + + pdevs = list(range(ndevs)) + fp_rvd = RVDLayout.grid(ftensor, r=src_r, v=src_v, dims=src_d, devices=pdevs) + + cdevs = list(range(ndevs)) + fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=cdevs) + + rvds = IntraPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) + print(f"optimal path: {' -> '.join(str(rvd) for rvd in rvds)}") + + fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + for prim in fprims: + print(prim) + + +def test_all_f_cases_fix_placement(): + + fshape = [128, 256, 512] + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + + ndevs = 8 + ndims = len(fshape) + 2 + for src_rvd in factors(ndevs, ndims): + for dst_rvd in factors(ndevs, ndims): + if src_rvd == dst_rvd or src_rvd[1] < dst_rvd[1]: continue + print(f'test generating | source rvd: {src_rvd}, destination rvd: {dst_rvd}') + pdevs = list(range(ndevs)) + fp_rvd = RVDLayout.grid(ftensor, r=src_rvd[0], v=src_rvd[1], dims=src_rvd[2:], devices=pdevs) + fptensors = fp_rvd.subtensors + + cdevs = list(range(ndevs)) + fc_rvd = RVDLayout.grid(ftensor, r=dst_rvd[0], v=dst_rvd[1], dims=dst_rvd[2:],devices=cdevs) + fctensors = fc_rvd.subtensors + + fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + -def test_multiref_intra_rvd(): - - model = TestModel() - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([1024,1024],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) - - def policy(graph: IRGraph, resource): - print(graph.extra_repr()) - devs = list(range(resource.ngpus)) - - for ftensor in graph.full_tensors(): - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) - - for dl in graph.select(ntype=IRDataOperation): - _replica(graph, dl, devs) - - for node in graph.select(ntype=IRFwOperation): - if node.name == 'multiref': continue - if node.name == 'mul': - _ntp(graph, node, idx=0, dim=0, devs=devs) - elif node.name == 'add': - _ntp(graph, node, idx=0, dim=1, devs=devs) - else: - _replica(graph, node, devs) - print(graph.extra_repr()) - return graph - - model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - model = model.get_gen_module() - - for _ in range(4): - train_iter(model, dataloader) - - -def test_intra_rvd(): - - model = TestModel() - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([1024,1024],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) - - def policy(graph: IRGraph, resource): - assert resource.ngpus == 4 - print(graph.extra_repr()) - devs = list(range(resource.ngpus)) - - # for ftensor in graph.full_tensors(): - # if len(graph.consumers(ftensor)) > 1: - # graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) - - for dl in graph.select(ntype=IRDataOperation): - _replica(graph, dl, devs) - - for idx, node in enumerate(graph.select(name='mul')): - if idx == 0: # out: R(4)V(1)D(1,1) -> in: R(1)V(1)D(4,1): schunk - _tp(graph, node, in_rvd=(1,1,4,1), out_rvd=(1,1,4,1), devs=devs) - elif idx == 1: # out: R(1)V(1)D(4,1) -> in: R(1)V(1)D(1,4): all-to-all wil FAIL. expected!! - _tp(graph, node, in_rvd=(1,1,1,4), out_rvd=(1,1,1,4), devs=devs) - elif idx == 2: # out: R(1)V(1)D(1,4) -> in: R(1)V(1)D(2,2): schunk - _tp(graph, node, in_rvd=(1,1,2,2), out_rvd=(1,1,2,2), devs=devs) - elif idx == 3: # out: R(1)V(1)D(2,2) -> in: R(4)V(1)D(1,1): all-gather + all-gather - _tp(graph, node, in_rvd=(4,1,1,1), out_rvd=(1,4,1,1), devs=devs) - elif idx == 4: # out: R(1)V(4)D(1,1) -> in: R(1)V(1)D(4,1): reduce-scatter - _tp(graph, node, in_rvd=(1,1,4,1), out_rvd=(1,1,4,1), devs=devs) - elif idx == 5: # out: R(1)V(1)D(4,1) -> in R(4)V(1)D(1,1): all-gather - _tp(graph, node, in_rvd=(4,1,1,1), out_rvd=(1,4,1,1), devs=devs) - elif idx == 6: # out: R(1)V(4)D(1,1) -> in R(1)V(1)D(2,2): reduce-scatter + reduce-scatter - _tp(graph, node, in_rvd=(1,1,2,2), out_rvd=(1,1,2,2), devs=devs) - else: - assert False - - for node in graph.select(ntype=IRFwOperation): - if len(node.device) == 0: - _replica(graph, node, devs) - - print(graph.extra_repr()) - return graph +def test_all_f_cases_auto_placement(): + + fshape = [128, 256, 512] + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + + ndevs = 8 + ndims = len(fshape) + 2 + for src_rvd in factors(ndevs, ndims): + for dst_rvd in factors(ndevs, ndims): + if src_rvd == dst_rvd or src_rvd[1] < dst_rvd[1]: continue + print(f'test generating | source rvd: {src_rvd}, destination rvd: {dst_rvd}') + pdevs = list(range(ndevs)) + fp_rvd = RVDLayout.grid(ftensor, r=src_rvd[0], v=src_rvd[1], dims=src_rvd[2:], devices=pdevs) + + placement, cost = IntraAutoPlacer.auto_place( + ftensor.shape, + src_rvd, dst_rvd, None, None, + src_placement=pdevs + ) + fc_rvd = RVDLayout.grid(ftensor, r=dst_rvd[0], v=dst_rvd[1], dims=dst_rvd[2:],devices=placement) + + fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + print(f'cost: {cost}') + + +def test_one_fb_case(): + + fshape = [128, 256, 512] + + # forward + fsrc_r, fsrc_v, fsrc_d = 2,2,(1,1,2) + fdst_r, fdst_v, fdst_d = 2,1,(1,1,4) + bsrc_r, bsrc_v, bsrc_d = 1,2,(1,1,4) + bdst_r, bdst_v, bdst_d = 4,1,(1,1,2) + ndevs = fsrc_r * fsrc_v * np.prod(np.array(fsrc_d)) + + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=True) + btensor: IRFullTensor = ftensor.grad + + # forward producer / backward consumer + fpdevs = list(range(ndevs)) + fp_rvd = RVDLayout.grid(ftensor, r=fsrc_r, v=fsrc_v, dims=fsrc_d, devices=fpdevs) + # print('forward producer tensor:') + # for t in fp_rvd.mat.flatten(): + # print('\t'+tensor_vd_repr(t)) + bc_rvd = RVDLayout.grid(btensor, r=bdst_r, v=bdst_v, dims=bdst_d, devices=fpdevs) + + # forward consumer / backward producer + fcdevs, _ = IntraAutoPlacer.auto_place( + fshape, (fsrc_r, fsrc_v) + fsrc_d, (fdst_r, fdst_v) + fdst_d, + (bsrc_r, bsrc_v) + bsrc_d, (bdst_r, bdst_v) + bdst_d, fpdevs) - model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - model = model.get_gen_module() + fc_rvd = RVDLayout.grid(ftensor, r=fdst_r, v=fdst_v, dims=fdst_d, devices=fcdevs) + # print('forward consumer tensor:') + # for t in fc_rvd.mat.flatten(): + # print('\t'+tensor_vd_repr(t)) + bp_rvd = RVDLayout.grid(btensor, r=bsrc_r, v=bsrc_v, dims=bsrc_d, devices=fcdevs) - for _ in range(4): - train_iter(model, dataloader) + fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + bprims = IntraPathFinder.path(bp_rvd, bc_rvd) + print('forward prims:') + for prim in fprims: + print('\t', prim) + print('backward prims:') + for prim in bprims: + print('\t', prim) -if __name__ == '__main__': +def test_all_fb_cases_fix_placement(): - # test_multiref_intra_rvd() - test_intra_rvd() + fshape = [128, 256, 512] + ndevs = 8 + + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=True) + btensor: IRFullTensor = ftensor.grad + + ndims = len(fshape) + 2 + for fp_rvd in factors(ndevs, ndims): + + fdevs = list(range(ndevs)) + fp = RVDLayout.grid(ftensor, r=fp_rvd[0], v=fp_rvd[1], dims=fp_rvd[2:], devices=fdevs) + + for fc_rvd in factors(ndevs, ndims): + if fc_rvd[1] != 1: continue + fc = RVDLayout.grid(ftensor, r=fc_rvd[0], v=fc_rvd[1], dims=fc_rvd[2:], devices=fdevs) + + # case1: forward replica -> backward replica + bp_rvd = fc_rvd + bc_rvd = (fp_rvd[0] * fp_rvd[1], 1) + fp_rvd[2:] + print(f'test generating | fp rvd: {fp_rvd}, fc rvd: {fc_rvd}, bp rvd: {bp_rvd}, bc rvd: {bc_rvd}') + + bp = RVDLayout.grid(btensor, r=bp_rvd[0], v=bp_rvd[1], dims=bp_rvd[2:], devices=fdevs) + bc = RVDLayout.grid(btensor, r=bc_rvd[0], v=bc_rvd[1], dims=bc_rvd[2:], devices=fdevs) + + fprims = IntraPathFinder.path(fp, fc) + bprims = IntraPathFinder.path(bp, bc) + + # case2: forward replica -> backward accum + bp_rvd = (1, fc_rvd[0] * fc_rvd[1]) + fc_rvd[2:] + bc_rvd = (fp_rvd[0] * fp_rvd[1], 1) + fp_rvd[2:] + print(f'test generating | fp rvd: {fp_rvd}, fc rvd: {fc_rvd}, bp rvd: {bp_rvd}, bc rvd: {bc_rvd}') + + bp = RVDLayout.grid(btensor, r=bp_rvd[0], v=bp_rvd[1], dims=bp_rvd[2:], devices=fdevs) + bc = RVDLayout.grid(btensor, r=bc_rvd[0], v=bc_rvd[1], dims=bc_rvd[2:], devices=fdevs) + + fprims = IntraPathFinder.path(fp, fc) + bprims = IntraPathFinder.path(bp, bc) + + +def test_all_fb_cases_auto_placement(): + + fshape = [128, 256, 512] + ndevs = 8 + + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=True) + btensor: IRFullTensor = ftensor.grad + + ndims = len(fshape) + 2 + for fp_rvd in factors(ndevs, ndims): + + fdevs = list(range(ndevs)) + fp = RVDLayout.grid(ftensor, r=fp_rvd[0], v=fp_rvd[1], dims=fp_rvd[2:], devices=fdevs) + + for fc_rvd in factors(ndevs, ndims): + if fc_rvd[1] != 1: continue + + # case1: forward replica -> backward replica + bp_rvd = fc_rvd + bc_rvd = (fp_rvd[0] * fp_rvd[1], 1) + fp_rvd[2:] + print(f'test generating | fp rvd: {fp_rvd}, fc rvd: {fc_rvd}, bp rvd: {bp_rvd}, bc rvd: {bc_rvd}') + + placement, cost = IntraAutoPlacer.auto_place( + fshape, fp_rvd, fc_rvd, bp_rvd, bc_rvd, fdevs) + + fc = RVDLayout.grid(ftensor, r=fc_rvd[0], v=fc_rvd[1], dims=fc_rvd[2:], devices=placement) + bp = RVDLayout.grid(btensor, r=bp_rvd[0], v=bp_rvd[1], dims=bp_rvd[2:], devices=placement) + bc = RVDLayout.grid(btensor, r=bc_rvd[0], v=bc_rvd[1], dims=bc_rvd[2:], devices=fdevs) + + fprims = IntraPathFinder.path(fp, fc) + bprims = IntraPathFinder.path(bp, bc) + + # case2: forward replica -> backward accum + bp_rvd = (1, fc_rvd[0] * fc_rvd[1]) + fc_rvd[2:] + bc_rvd = (fp_rvd[0] * fp_rvd[1], 1) + fp_rvd[2:] + print(f'test generating | fp rvd: {fp_rvd}, fc rvd: {fc_rvd}, bp rvd: {bp_rvd}, bc rvd: {bc_rvd}') + + placement, cost = IntraAutoPlacer.auto_place( + fshape, fp_rvd, fc_rvd, bp_rvd, bc_rvd, fdevs) + + fc = RVDLayout.grid(ftensor, r=fc_rvd[0], v=fc_rvd[1], dims=fc_rvd[2:], devices=placement) + bp = RVDLayout.grid(btensor, r=bp_rvd[0], v=bp_rvd[1], dims=bp_rvd[2:], devices=placement) + bc = RVDLayout.grid(btensor, r=bc_rvd[0], v=bc_rvd[1], dims=bc_rvd[2:], devices=fdevs) + + fprims = IntraPathFinder.path(fp, fc) + bprims = IntraPathFinder.path(bp, bc) + + +if __name__ == '__main__': + # test_intra_transition() + # test_transition_space() + # test_one_f_case() + # test_all_f_cases_fix_placement() + # test_all_f_cases_auto_placement() + # test_one_fb_case() + # test_all_fb_cases_fix_placement() + test_all_fb_cases_auto_placement() \ No newline at end of file From dc6cdfa169689468d0996fdbc9aaec591dbb0dd1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 Jan 2023 10:33:56 +0800 Subject: [PATCH 1219/1892] refine docs --- cube/graph/gener/rvd/intra.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index f969fcf3..23e9b368 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -239,7 +239,7 @@ def path(ilayout: RVDLayout, olayout: RVDLayout, @param ilayout RVDLayout: input tensor layout @param olayout RVDLayout: output tensor layout @param cost_fn Optional[Callable]: cost function of each primitive. - Default (None) will use transmission volume as metrics + Default (None) will use communication volume as metrics @return all_primitives List[IRAdapterPrims]: all primitives for communication path """ @@ -308,8 +308,8 @@ def device_align(ilayout: RVDLayout, olayout: RVDLayout, @param ilayouts RVDLayout: source layout @param olayout RVDLayout: target layout with correct device mapping - @param rvd_hops: Tuple[TRVD]: the hops from ilayout to olayout - (not contains ilayout at beginning, but contains olayout at last) + @param rvd_hops: Tuple[TRVD]: the hops from ilayout to olayout, which + contains ilayout and olayout at beginning and last, respectively. @return success bool: True if found device, else False. @return primitives List[IRAdapterPrim]: the correspoinding primitives @@ -534,7 +534,7 @@ def auto_place(shape: TShape, fw_src_rvd: TRVD, fw_dst_rvd: TRVD, bw_src_rvd: Optional[TRVD], bw_dst_rvd: Optional[TRVD], src_placement: List[int], - cost_fn: Optional[Callable] = None): + cost_fn: Optional[Callable] = None) -> Tuple[Tuple[int], float]: """ Search for a good device placement for source and destination RVD partition @@ -544,8 +544,10 @@ def auto_place(shape: TShape, @param fw_dst_rvd Tuple[int]: forward consumer RVD layout vector @param bw_src_rvd Optional[Tuple[int]]: backward producer RVD layout vector @param bw_dst_rvd Optional[Tuple[int]]: backward consumer RVD layout vector + @param cost_fn Optional[Callable]: cost function of each primitive. + Default (None) will use communication volume as metrics - @return devices List[int]: device sequence for RVD tensors + @return devices Tuple[int]: device sequence for RVD tensors @return cost float: Cost of communication plan """ src_placement = tuple(src_placement) @@ -579,7 +581,7 @@ def auto_place(shape: TShape, # - if find, choose one if len(devices) > 0: placement = list(devices)[0] - # - if not find, change forward one while use backup plan for backward one + # - if not find, keep forward one as optimal and adopt backup plan for backward one else: placement = list(fw_consumer_devices)[0] print(f"================ forward-backward mis-aligned! ============== \n" From c2252417e34a5b390c75e68e2b01c06f862c857e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 Jan 2023 11:01:28 +0800 Subject: [PATCH 1220/1892] expose communication cost function --- cube/compiler.py | 9 ++++++--- cube/graph/gener/rvd/intra.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 287bce27..7d0a8b06 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -27,7 +27,7 @@ def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, - override = True, load_content = True) -> Callable: + comm_cost_fn: Optional[Callable] = None, override = True, load_content = True) -> Callable: """ AI Scientist calls like: @@ -51,7 +51,10 @@ def train_step(model, dataloader): @param model SemanticModel: AI Scientist specified SemanticModel @param dataloader CubDataLoader: dataloader used for training - @param policy Callable: policy to transform and schedule graph + @param PAS Callable: policy to transform and schedule graph + @param comm_cost_fn: Optional[Callable]: communication cost function, which + takes in an IRAdapterPrim, and outputs a cost in float. By default (None) use + communication volume. @param override bool: If true, the generated code will override exsisting files (if they are already existed.), otherwise, use the already existed generated code, i.e., the policy won't take effect. Default true. @@ -140,7 +143,7 @@ def decorator(fn: Callable) -> Callable: # generate adapter start = time.time() - graph = IRAdapterGener.gen(graph) + graph = IRAdapterGener.gen(graph, cost_fn=comm_cost_fn) span = time.time() - start print('> finish generating adapters: {:.2f} s'.format(span)) diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index 23e9b368..5e90f233 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -246,7 +246,7 @@ def path(ilayout: RVDLayout, olayout: RVDLayout, assert ilayout.ftensor == olayout.ftensor, f"ilayout and olayout should have a same full tensor" ftensor = ilayout.ftensor src, dst = tuple(ilayout.vec), tuple(olayout.vec) - rvds = IntraPathFinder.get_optimal_path(ftensor, src, dst) + rvds = IntraPathFinder.get_optimal_path(ftensor, src, dst, cost_fn) # search for correct device mapping align, all_prims = IntraPathFinder.device_align(ilayout, olayout, rvds) From 3c16ce07aba0eb94c81f0447b5649a275499ff38 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 Jan 2023 11:19:33 +0800 Subject: [PATCH 1221/1892] fix bugs in inspector --- cube/graph/gener/rvd/layout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/gener/rvd/layout.py b/cube/graph/gener/rvd/layout.py index 62773ec0..ec7c3d6b 100644 --- a/cube/graph/gener/rvd/layout.py +++ b/cube/graph/gener/rvd/layout.py @@ -462,7 +462,7 @@ def draw_subtensor(t: IRSubTensor, xy: Tuple[int], color: str): suby_nchunks = t.parent.shape[0] // t.shape[0] subh = recflen / suby_nchunks - suby = y + subh * (t.indmap[0][0] // t.shape[1]) + suby = y + subh * (t.indmap[0][0] // t.shape[0]) # if t.valmap != (0, 1): ax.text(x=x+recflen/2, y=y+recflen+recflen/2, s=f'val({t.valmap[0]}/{t.valmap[1]})', From c2085b84e5e6d0d2e4e7cbb35b222a599f6543ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 Jan 2023 17:26:33 +0800 Subject: [PATCH 1222/1892] fix bugs in staging for inference --- cube/graph/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index a30ca113..bc475f14 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -690,7 +690,7 @@ def staging(self, nodes: Tuple[IRFwOperation]): end = starts[sid+1] if sid != len(starts) - 1 else last_fidx + 1 while isinstance(self.node(begin), IRDataOperation): begin += 1 - while isinstance(self.node(end), IRDataOperation): + while end < len(self._nodes) and isinstance(self.node(end), IRDataOperation): end -= 1 if begin == end: continue assert begin < end From 8ebd429d3c9b6acb53a5a294c4ea54d17d3d0a77 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 Jan 2023 17:27:05 +0800 Subject: [PATCH 1223/1892] add predefined gpipe inference --- cube/graph/schedule/schedinfer.py | 93 +++++++++++++++++++++++++++++ cube/runtime/schedule/__init__.py | 3 +- cube/runtime/schedule/schedinfer.py | 25 ++++++++ 3 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 cube/graph/schedule/schedinfer.py create mode 100644 cube/runtime/schedule/schedinfer.py diff --git a/cube/graph/schedule/schedinfer.py b/cube/graph/schedule/schedinfer.py new file mode 100644 index 00000000..b0c2bf6a --- /dev/null +++ b/cube/graph/schedule/schedinfer.py @@ -0,0 +1,93 @@ + +from typing import Dict, Optional, List +import warnings + +from cube.ir.cten import IRCell +from cube.ir.adapter.adapter import IRAdapter + +from cube.graph.graph import IRGraph, IRSegment +from cube.graph.schedule import IRScheduleStrategy + + +class IRScheduleInfer(IRScheduleStrategy): + """ + 1F1B Scheduling + + This treats model as a linear graph which can be + grouped into continous stages. + + [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] + [Recv-Backward] Backward-Segment [Send-Backward] + """ + + def __init__(self, graph, nmicros: int): + super().__init__(graph, nmicros) + self.signature = 'cube.runtime.schedule.ScheduleInfer.run' + # forward body + self.fsegments: Dict[int, IRSegment] = dict() + # forward send + self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() + # forward recv + self.rfadapter: Dict[int, Optional[IRAdapter]] = dict() + # num_stage + self.num_stages: int = -1 + + def apply(self) -> IRGraph: + self.mesh() + for node in self.graph.nodes(): + if isinstance(node, IRAdapter) and node.forward: + if len(set(node.outputs())) > 1 or len(set(node.inputs())) > 1: + warnings.warn( + "Detected one adapter has more than one input/output in stage transmission, " + "which is not safe for current scheduling implementation due to potential " + "mis-ordering of arguments. Better to use torch.cat and torch.chunk to " + "merge multiple tensors into one and unpack it at next stage." + ) + # no backward + for seg in self.graph.select(ntype=IRSegment): + assert seg.isfw(), "Detected backward, which should not exist in inference" + # stage doesn't share devices + fsegments: List[IRSegment] = [fseg for fseg in self.segments if fseg.isfw()] + self.num_stages = len(fsegments) + for sid, fseg in enumerate(fsegments): + for devid in fseg.device: + # forward body + assert devid not in self.fsegments, "One device cannot have multiple forward stages" + self.fsegments[devid] = fseg + if sid == 0: + assert len(self.recvers[fseg]) == 0, "Expect no forward send at first stage" + else: + assert len(self.recvers[fseg]) == 1, "Expect one forward recv at non-first stage" + self.rfadapter[devid] = None if sid == 0 else self.recvers[fseg][0] + # forward send + if sid == self.num_stages - 1: + assert len(self.senders[fseg]) == 0, "Expect no forward send at last stage" + else: + assert len(self.senders[fseg]) == 1, "Expect no forward send at last stage" + self.sfadapter[devid] = None if sid == self.num_stages - 1 else self.senders[fseg][0] + + return self.graph + + def kwargs(self, devid: int) -> Dict[str, IRCell]: + """ + return kwargs for runtime caller + """ + return dict( + segment = self.fsegments[devid], + sfadapter = self.sfadapter[devid], + rfadapter = self.rfadapter[devid], + dataloader = 'dataloader', + num_microbatch = self.nmicros, + ) + + def __repr__(self) -> str: + dscp = '' + for mesh in self.devmesh: + devid = mesh[0] + # segment = self.segments[devid].to_str(skip_attr=True) if self.segment[mesh[0]] else None + dscp += (f"GPipe Infer Schedule: Stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" + f" segment = {self.segments[devid]}\n" + f" send-fw = {self.sfadapter[mesh[0]]}\n" + f" recv-fw = {self.rfadapter[mesh[0]]}\n" + f")\n") + return dscp diff --git a/cube/runtime/schedule/__init__.py b/cube/runtime/schedule/__init__.py index f5b70816..34b7d0c7 100644 --- a/cube/runtime/schedule/__init__.py +++ b/cube/runtime/schedule/__init__.py @@ -1,3 +1,4 @@ from cube.runtime.schedule.sched1f1b import Schedule1F1B from cube.runtime.schedule.schedmix import ScheduleMix -from cube.runtime.schedule.schednf1b import ScheduleNF1B \ No newline at end of file +from cube.runtime.schedule.schednf1b import ScheduleNF1B +from cube.runtime.schedule.schedinfer import ScheduleInfer \ No newline at end of file diff --git a/cube/runtime/schedule/schedinfer.py b/cube/runtime/schedule/schedinfer.py new file mode 100644 index 00000000..0ba9a498 --- /dev/null +++ b/cube/runtime/schedule/schedinfer.py @@ -0,0 +1,25 @@ +from typing import Callable, Iterable, List, Optional +import torch + +from cube.runtime.schedule.strategy import ScheduleABC + + +class ScheduleInfer(ScheduleABC): + + @staticmethod + def run(segment: Callable, # forward body + rfadapter: Optional[Callable], # recv forward adapter + sfadapter: Optional[Callable], # send forward adapter + dataloader: Iterable, + num_microbatch: int): + + for _ in range(num_microbatch): + # recv forward + inputs = ScheduleInfer.adapter_step(rfadapter, False) + inputs = ScheduleInfer.dataloader_step(dataloader) if inputs == (None,) else inputs + # forward + outputs = ScheduleInfer.forward_step(segment, *inputs) + # send forward + ScheduleInfer.adapter_step(sfadapter, True, *outputs) + + ScheduleInfer.assert_empty() From 588045347cb3944018a886acf3f09f68e30d53a5 Mon Sep 17 00:00:00 2001 From: Rongwei Lu Date: Wed, 1 Feb 2023 04:04:44 +0000 Subject: [PATCH 1224/1892] Updated requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index c5ffda85..9d5c2fed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ einops matplotlib pytest +setuptools==60.7.0 more-itertools --find-links https://download.pytorch.org/whl/torch_stable.html From a9c50b5af7a8dbea6e21d6d9c811e349533f52e9 Mon Sep 17 00:00:00 2001 From: lynex Date: Thu, 2 Feb 2023 16:31:13 +0800 Subject: [PATCH 1225/1892] add gpt infer with pipeline parallel schedule PAS1F --- examples/nlp/gpt/infer.py | 2 +- examples/nlp/gpt/policy/mpmd.py | 35 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py index 10d07657..4e0a487c 100644 --- a/examples/nlp/gpt/infer.py +++ b/examples/nlp/gpt/infer.py @@ -65,7 +65,7 @@ def inter(): def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) - return loss + # return loss model = model.get_gen_module() torch.distributed.barrier() diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index b9927cd1..6c67b2bf 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -7,6 +7,7 @@ from cube.ir.cten import IRCell from cube.ir.operator import IRDataOperation, IRFwOperation from cube.graph.schedule.sched1f1b import IRSchedule1F1B +from cube.graph.schedule.schedinfer import IRScheduleInfer def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: @@ -129,6 +130,40 @@ def PAS1F1B(graph: IRGraph, resource): return graph +def PAS1F(graph: IRGraph, resource): + """ + 1F1B scheduling + """ + num_stages = resource.ngpus + num_microbatch = 16 + + # group to transformer layers + fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] + transformers = _group_to_transformers(fnodes) + + # staging + fstages = [[] for _ in range(num_stages)] + nlayer_per_stage = (len(transformers) // resource.ngpus) + for lid, fnodes in enumerate(transformers): + stage_id = min(lid // nlayer_per_stage, num_stages - 1) + fstages[stage_id] += fnodes + graph.staging(tuple(stages[0] for stages in fstages)) + + # stage to device + fsegments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] + assert len(fsegments) == num_stages + for devid, segment in enumerate(fsegments): + graph.assign(segment, devid) + + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + graph.assign(node, 0) + + strategy = IRScheduleInfer(graph, num_microbatch) + graph.predef_sched(strategy) + return graph + + def PASMegatron(graph: IRGraph, resource): """ 1F1B scheduling From fcc0d1ef6179147fd232a152090fc9dfcb4989f4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Feb 2023 19:19:42 +0800 Subject: [PATCH 1226/1892] auto place interface --- cube/graph/gener/rvd/intra.py | 85 ++++++++++++++++++++++++++++++++--- 1 file changed, 79 insertions(+), 6 deletions(-) diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index 5e90f233..2ec42e73 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -3,9 +3,11 @@ import numpy as np import sys import copy +import warnings from cube.ir.dtype import IRDType -from cube.ir.tensor import IRFullTensor +from cube.ir.cten import IRCell +from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.adapter.prim import IRAdapterPrim from cube.ir.adapter.prim import AllGatherPrim # d2r @@ -15,6 +17,7 @@ from cube.ir.adapter.prim import ChunkPrim # r2d from cube.ir.adapter.prim import VChunkPrim # r2v +from cube.graph.segment import IRSegment from cube.graph.gener.rvd.layout import RVDLayout from cube.graph.gener.utils import tensor_vd_repr @@ -530,11 +533,81 @@ def default_cost_fn(prim: IRAdapterPrim) -> int: class IntraAutoPlacer: @staticmethod - def auto_place(shape: TShape, - fw_src_rvd: TRVD, fw_dst_rvd: TRVD, - bw_src_rvd: Optional[TRVD], bw_dst_rvd: Optional[TRVD], - src_placement: List[int], - cost_fn: Optional[Callable] = None) -> Tuple[Tuple[int], float]: + def auto_place(graph: IRSegment, ftensor: IRFullTensor, + producers: List[IRCell], consumers: List[IRCell], + cost_fn: Optional[Callable] = None) -> List[int]: + """ + Automatically find good device placement for consumers given the producer placement + The backward will also be considered. + + @param ftensor IRFullTensor + @param producers List[IRCell]: producers that must be assigned to devices + @param consumers List[IRCell]: consumers that are about to be assigned + + @return cost float: the cost after the placement. + """ + assert all(len(p.device) > 0 for p in producers), f"Expect all producers have been assigned to a device" + + devices = [p.device[0] for p in producers] + assert len(set(devices)) == len(producers),f"Expect each producer is on a different device" + + assert len(producers) == len(consumers), \ + f"Expect same number of producer and consumer, but got {len(producers)} producers and {len(consumers)} consumers" + + if any(len(consumer.device) > 0 for consumer in consumers): + warnings.warn('Detected at least one consumer has been assigned to a device, which will be overrided by a new device placement.') + + # reorder producer to match with device order + producers = sorted(producers, key=lambda n: n.device[0]) + + # get forward produced tensors + fptensors: List[IRSubTensor] = [] + fctensors: List[IRSubTensor] = [] + for producer in producers: + assert producer in graph.producers(ftensor), f"Producer {producer} doesn't generate ftensor: {ftensor}" + pidx = graph.producers(ftensor).index(producer) + fptensors.append(graph.ptensors(ftensor)[pidx]) + for consumer in consumers: + assert consumer in graph.consumers(ftensor), f"Consumer {producer} doesn't take ftensor: {ftensor}" + cidx = graph.consumers(ftensor).index(consumer) + fctensors.append(graph.ctensors(ftensor)[cidx]) + + # get backward producer and consumer tensors + bptensors, bctensors = None, None + if ftensor.grad is not None: + bptensors = [t.grad for t in fctensors] + bctensors = [t.grad for t in fptensors] + # get RVD representation + fw_src = RVDLayout.togrid(ftensor, fptensors) + fw_src_rvd = fw_src.vec + fw_dst = RVDLayout.togrid(ftensor, fctensors) + fw_dst_rvd = fw_dst.vec + bw_src_rvd, bw_dst_rvd = None, None + if ftensor.grad is not None: + bw_src_rvd = RVDLayout.togrid(ftensor.grad, bptensors).vec + bw_dst_rvd = RVDLayout.togrid(ftensor.grad, bctensors).vec + + # get placement advice + devices = [t.device[0] for t in fw_src.mat.flatten()] + placement, _ = IntraAutoPlacer.advice( + ftensor.shape, + fw_src_rvd, fw_dst_rvd, bw_src_rvd, bw_dst_rvd, + devices, cost_fn) + + # assign to device + ordered_placement = [None] * len(consumers) + for devid, t in zip(placement, fw_dst.mat.flatten()): + ordered_placement[consumers.index(t.cell)] = devid + assert all(devid is not None for devid in ordered_placement), f"Internal Error" + + return ordered_placement + + @staticmethod + def advice(shape: TShape, + fw_src_rvd: TRVD, fw_dst_rvd: TRVD, + bw_src_rvd: Optional[TRVD], bw_dst_rvd: Optional[TRVD], + src_placement: List[int], + cost_fn: Optional[Callable] = None) -> Tuple[Tuple[int], float]: """ Search for a good device placement for source and destination RVD partition From 9ce2020895867535e9b553a7d895dd41bc58e7ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Feb 2023 19:27:56 +0800 Subject: [PATCH 1227/1892] adapt with interface changes --- tests/adapter/test_intra_rvd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/adapter/test_intra_rvd.py b/tests/adapter/test_intra_rvd.py index 8c517a17..6286d2b4 100644 --- a/tests/adapter/test_intra_rvd.py +++ b/tests/adapter/test_intra_rvd.py @@ -233,7 +233,7 @@ def test_all_fb_cases_fix_placement(): bprims = IntraPathFinder.path(bp, bc) -def test_all_fb_cases_auto_placement(): +def test_all_fb_cases_advisor(): fshape = [128, 256, 512] ndevs = 8 @@ -255,7 +255,7 @@ def test_all_fb_cases_auto_placement(): bc_rvd = (fp_rvd[0] * fp_rvd[1], 1) + fp_rvd[2:] print(f'test generating | fp rvd: {fp_rvd}, fc rvd: {fc_rvd}, bp rvd: {bp_rvd}, bc rvd: {bc_rvd}') - placement, cost = IntraAutoPlacer.auto_place( + placement, cost = IntraAutoPlacer.advice( fshape, fp_rvd, fc_rvd, bp_rvd, bc_rvd, fdevs) fc = RVDLayout.grid(ftensor, r=fc_rvd[0], v=fc_rvd[1], dims=fc_rvd[2:], devices=placement) @@ -270,7 +270,7 @@ def test_all_fb_cases_auto_placement(): bc_rvd = (fp_rvd[0] * fp_rvd[1], 1) + fp_rvd[2:] print(f'test generating | fp rvd: {fp_rvd}, fc rvd: {fc_rvd}, bp rvd: {bp_rvd}, bc rvd: {bc_rvd}') - placement, cost = IntraAutoPlacer.auto_place( + placement, cost = IntraAutoPlacer.advice( fshape, fp_rvd, fc_rvd, bp_rvd, bc_rvd, fdevs) fc = RVDLayout.grid(ftensor, r=fc_rvd[0], v=fc_rvd[1], dims=fc_rvd[2:], devices=placement) @@ -289,4 +289,4 @@ def test_all_fb_cases_auto_placement(): # test_all_f_cases_auto_placement() # test_one_fb_case() # test_all_fb_cases_fix_placement() - test_all_fb_cases_auto_placement() \ No newline at end of file + test_all_fb_cases_advisor() \ No newline at end of file From 531a3c6be0659c32b679c268bedc505c99763d5f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Feb 2023 20:09:01 +0800 Subject: [PATCH 1228/1892] early exit for intra auto placement --- cube/graph/gener/rvd/intra.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index 2ec42e73..8152b263 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -17,6 +17,7 @@ from cube.ir.adapter.prim import ChunkPrim # r2d from cube.ir.adapter.prim import VChunkPrim # r2v +from cube.graph import IRGraph from cube.graph.segment import IRSegment from cube.graph.gener.rvd.layout import RVDLayout @@ -533,7 +534,7 @@ def default_cost_fn(prim: IRAdapterPrim) -> int: class IntraAutoPlacer: @staticmethod - def auto_place(graph: IRSegment, ftensor: IRFullTensor, + def auto_place(graph: IRGraph, ftensor: IRFullTensor, producers: List[IRCell], consumers: List[IRCell], cost_fn: Optional[Callable] = None) -> List[int]: """ @@ -544,8 +545,10 @@ def auto_place(graph: IRSegment, ftensor: IRFullTensor, @param producers List[IRCell]: producers that must be assigned to devices @param consumers List[IRCell]: consumers that are about to be assigned - @return cost float: the cost after the placement. + @return placement List[int]: the adviced placement + corresponding to each consumer in consumers. """ + assert not ftensor.is_param(), f"Cannot automatically assign device given weight tensor" assert all(len(p.device) > 0 for p in producers), f"Expect all producers have been assigned to a device" devices = [p.device[0] for p in producers] @@ -557,9 +560,12 @@ def auto_place(graph: IRSegment, ftensor: IRFullTensor, if any(len(consumer.device) > 0 for consumer in consumers): warnings.warn('Detected at least one consumer has been assigned to a device, which will be overrided by a new device placement.') + if len(producers) == 1: + graph.assign(consumers[0], producers[0].device) + return [producers[0].device] + # reorder producer to match with device order producers = sorted(producers, key=lambda n: n.device[0]) - # get forward produced tensors fptensors: List[IRSubTensor] = [] fctensors: List[IRSubTensor] = [] From ffea54e07c97cad594c26405fed5a7c1b2b890db Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 7 Feb 2023 07:55:14 +0000 Subject: [PATCH 1229/1892] Merged PR 1444: Support prim::setattr for torchscript parser (inference only) --- cube/codegen/frontend_mapping.py | 11 +++++++- cube/graph/function/pyfunc.py | 30 ++++++++++++++++++++++ cube/graph/gener/gen.py | 43 ++++++++++++++++++++++++++++++++ cube/graph/graph.py | 9 +++++++ cube/graph/parser/parser.py | 29 +++++++++++++++++++++ 5 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 cube/graph/function/pyfunc.py diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index 83649e21..c3ad7aa3 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -163,6 +163,13 @@ def emit_to(node, arg_vars:list, kw_pairs:dict) -> str: return _common_rule_join_all(node, arg_vars, kw_pairs) +def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + + assert arg_vars[1].startswith('self.') + member = f'"{arg_vars[1][5:]}"' + return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" + + class Sign2EmitRule: @staticmethod @@ -192,7 +199,9 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: 'torch.ones': emit_ones, 'torch.Tensor.to': emit_to, 'torch.rand': emit_rand, - 'torch.tensor': emit_new_tensor + 'torch.tensor': emit_new_tensor, + + 'setattr': emit_setattr, } diff --git a/cube/graph/function/pyfunc.py b/cube/graph/function/pyfunc.py new file mode 100644 index 00000000..5d6c4fdf --- /dev/null +++ b/cube/graph/function/pyfunc.py @@ -0,0 +1,30 @@ +from typing import List, Optional, Tuple, Any +import itertools + +from cube.ir.operator import IRFwOperation +from cube.ir.cten import IRTensor + + +class IRPyFunc(IRFwOperation): + """ + Python runtime function + """ + + def __init__(self, signature: str, + inputs: Tuple[Any], outputs: Tuple[Any], **kwargs): + name = signature.split('.')[-1] + super().__init__(name, signature, len(inputs), len(outputs)) + for idx, t in enumerate(inputs): + self.set_input(idx, t) + for idx, t in enumerate(outputs): + self.set_output(idx, t) + self.kwargs.update(**kwargs) + + def infer_shape(self) -> bool: + """ + Shape will not be inferred for python runtime + """ + return True + + + diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index daddb5ce..673bda86 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -7,6 +7,7 @@ import cube.graph.gener.utils as utils from cube.graph.graph import IRGraph from cube.graph.segment import IRSegment +from cube.graph.function.pyfunc import IRPyFunc from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -117,6 +118,8 @@ def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: graph._reorder_producer_consumer() # remove anchor node graph = IRAdapterGener.remove_anchor(graph) + # automatic replace pyfunc + graph = IRAdapterGener.auto_pyfunc(graph) # automatic transform multiref graph = IRAdapterGener.autoref(graph) # generate adapters for activation @@ -139,6 +142,46 @@ def remove_anchor(graph: IRSegment): elif isinstance(anchor, IRSegment): IRAdapterGener.remove_anchor(anchor) return graph + + @staticmethod + def auto_pyfunc(graph: IRSegment): + """ + Make pyfunc to be local + """ + for func in graph.select(ntype=IRPyFunc, flatten=False): + assert func.mirror is None, "PyFunc is only supported by inference" + assert all(not isinstance(t, IRSubTensor) for t in func.outputs()), \ + "PyFunc doesn't support tensor outputs" + # get devices it will lowered to + devices = set() + for t in func.inputs(): + if not isinstance(t, IRSubTensor): continue + producers = graph.producers(t.parent) + for p in producers: + devices.update(p.device) + pyfuncs = [] + # lower to each device + for devid in devices: + inputs = [] + for t in func.inputs(): + if isinstance(t, IRSubTensor): + if t.is_attr(): + tensors = set(tensor for tensor in graph.ctensors(t.parent) if devid in tensor.device and tensor.cell != func) + else: + tensors = set(tensor for tensor in graph.ptensors(t.parent) if devid in tensor.device) + assert len(tensors) == 1, \ + f"Find {len(tensors)} != 1 versions of tensor {t} on a same device." + t = list(tensors)[0] + inputs.append(t) + lower_func = IRPyFunc(func.signature, inputs, func.outputs(), **func.kwargs) + lower_func.device = devid + pyfuncs.append(lower_func) + position = graph.remove(func) + for pyfunc in pyfuncs: + graph.insert(pyfunc, position) + for segment in graph.select(ntype=IRSegment, flatten=False): + IRAdapterGener.auto_pyfunc(segment) + return graph @staticmethod def gen_weight(graph: IRGraph) -> IRGraph: diff --git a/cube/graph/graph.py b/cube/graph/graph.py index bc475f14..f07c9f96 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -18,6 +18,7 @@ from cube.graph.function.function import Identity, MultiRef from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.pyfunc import IRPyFunc from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo @@ -283,6 +284,10 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis warnings.warn( 'Detected partition a multiref node. This will be skipped as system will automatically handle it.') return [node] + if isinstance(node, IRPyFunc): + warnings.warn( + 'Detected partition a python runtime function. This will be skipped as system will automatically handle it') + return [node] fsegment: IRSegment = self.segment(node) # replicate @@ -344,6 +349,10 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], warnings.warn( 'Detected partition a multiref node. This will be skipped as system will automatically handle it.') return [node] + if isinstance(node, IRPyFunc): + warnings.warn( + 'Detected partition a python runtime function. This will be skipped as system will automatically handle it') + return [node] # get partitioned sub-nodes fnodes = algo.instantiate(**config) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 660d2364..778eaf77 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -4,6 +4,7 @@ from typing import Any, List, Tuple, Optional from cube.ir.operator import IRFwOperation +from cube.graph.function.pyfunc import IRPyFunc from cube.ir.tensor import IRFullTensor import cube.ir as ir from cube.graph.parser.frame import Frame @@ -29,6 +30,7 @@ class ScriptNodeKind(enum.Enum): PrimPythonOp = 10 PrimDevice = 11 # erased PrimLoop = 12 + PrimSetAttr = 13 class ScriptModuleParser: @@ -144,6 +146,8 @@ def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): def ntype(node: torch._C.Node): if node.kind() == 'prim::GetAttr': return ScriptNodeKind.PrimGetAttr + if node.kind() == 'prim::SetAttr': + return ScriptNodeKind.PrimSetAttr if node.kind() == 'prim::CallMethod': return ScriptNodeKind.PrimCallMethod if node.kind() == 'prim::CallFunction': # the op call @@ -186,6 +190,8 @@ def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation] return ScriptModuleParser.parse_prim_method_node(node, module, frame) if node_type == ScriptNodeKind.PrimGetAttr: return ScriptModuleParser.parse_prim_attr_node(node, module, frame) + if node_type == ScriptNodeKind.PrimSetAttr: + return ScriptModuleParser.parse_prim_setattr_node(node, module, frame) if node_type == ScriptNodeKind.PrimConstant: return ScriptModuleParser.parse_prim_constant_node(node, module, frame) if node_type == ScriptNodeKind.PrimListConstruct: @@ -452,6 +458,29 @@ def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: frame.add_var(var_name, val) return list() + @staticmethod + def parse_prim_setattr_node(node, module, frame) -> List[IRFwOperation]: + """ + = prim::SetAttr[name="past_k"](%self, %k.1) + """ + signature = 'setattr' + target = node.s('name') # past_k + module_name = node.inputsAt(0).debugName() + module = module if module_name == 'self' else frame.get_var(module_name) + + var = node.inputsAt(1).debugName() # %k.1 + dtype = node.inputsAt(1).type().str() # torch.Tensor + assert dtype == 'Tensor', "Only tensor can be set inside module" + var_tensor = frame.get_var(var) + # make sure of having same attribute name in graph + assert frame.has_attr(target), f"SetAttr currently only supports replace an existing tensor attribute" + target_tensor = frame.get_attr(target) # IRFullTensor + # target_name = f"{target_tensor.name}_{target_tensor.tid}" + func = IRPyFunc(signature, ('self', target_tensor, var_tensor), ()) + # setattr(module, target, var) -> This will have error + return [func] + + @staticmethod def parse_prim_constant_node(node, module, frame) -> List[None]: """ From 253aa207752cc1bc9708ad1d86552d6d0f29d83e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 7 Feb 2023 16:49:21 +0800 Subject: [PATCH 1230/1892] timer return 0 if the field is not recorded --- cube/profiler/timer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index d2ad9192..0c244680 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -1,6 +1,7 @@ from typing import Optional import time import sys +import warnings import torch @@ -115,7 +116,8 @@ def duration(self, times: int, field_name: str = 'default') -> float: @return span float: wall clock in milliseconds. """ if field_name not in self.instance.field: - raise RuntimeError(f"Missing start on the field {field_name}") + warnings.warn(f"CudaTimer: {field_name} doesn't record.") + return 0.0 if len(self.instance.field[field_name]) != 0: raise RuntimeError(f"timer for field {field_name} not stopped") return self.instance.field_data[field_name] / times * 1000 # in ms From 53ab1609b19a0576e812911bdc9b549a38da0349 Mon Sep 17 00:00:00 2001 From: Hongzhou Liu Date: Tue, 7 Feb 2023 09:26:10 +0000 Subject: [PATCH 1231/1892] Example for GPT token generation inference --- examples/nlp/blocks/attention.py | 18 +++++++++++------- examples/nlp/blocks/encoder.py | 8 ++++++-- examples/nlp/gpt/infer.py | 5 +++-- examples/nlp/gpt/policy/spmd.py | 5 +++++ 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index f704dac4..e628eba5 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -1,6 +1,11 @@ import torch import cube +@cube.graph.parser.register('* -> *') +@torch.jit.ignore +def func_print_shape(x: torch.Tensor, msg: str): + print(msg, x.size()) + return x @cube.graph.parser.register('L^ N E^, (h+ d^ 3) E^, (h+ d^ 3), E^ (h+ d^) -> L^ N E^', name='self_attention') def self_attention(query: torch.Tensor, @@ -99,8 +104,7 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, return output -# @cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> l N E^', name='one_attention') -@cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d 3) E^, (h+ d 3), E^ (h+ d) -> l N E^', name='one_attention') +@cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d 3) E^, (h+ d 3), E^ (h+ d) -> l N E^, L^ N (h+ d), L^ N (h+ d)', name='one_attention') def one_attention(hidden_states: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor, @@ -154,12 +158,12 @@ def one_attention(hidden_states: torch.Tensor, attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) l (L+l) -> (N h) l (L+l) #no dropout in inference attn attn = torch.nn.functional.dropout(attn, dropout_p, is_training, False) # (N h) l (L+l) -> (N h) l (L+l) - v = v.transpose(0, 1) - output = torch.bmm(attn, v) # (N h) l (L+l), (N h) (L+l) d -> (N h) l d + v_t = v.transpose(0, 1) + output = torch.bmm(attn, v_t) # (N h) l (L+l), (N h) (L+l) d -> (N h) l d output = output.transpose(0, 1).contiguous() # (N h) l d -> l (N h) d output = output.view(l, N, num_head * dim_head) # l (N h) d -> l N (h d) output = torch.nn.functional.linear(output, out_proj, None) # l N (h d), E E -> l N E - return output + return output, k.view(k_L, N, -1), v.view(v_L, N, -1) class MultiHeadSelfAttention(torch.nn.Module): @@ -250,7 +254,7 @@ def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: floa def forward(self, query: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor): - attn = one_attention( + attn, past_k, past_v = one_attention( query, past_embed_key, past_embed_value, # self.q_proj, self.q_bias, # self.k_proj, self.k_bias, @@ -260,4 +264,4 @@ def forward(self, query: torch.Tensor, past_embed_key: torch.Tensor, past_embed_ self.num_heads, self.scaling, self.dropout_p, self.training, mask=True ) attn = attn + self.out_bias - return attn \ No newline at end of file + return attn, past_k, past_v \ No newline at end of file diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 1fa1b697..5c906877 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -1,5 +1,5 @@ import torch -from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention +from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention, func_print_shape from examples.nlp.blocks.mlp import MLP @@ -56,7 +56,9 @@ def __init__(self, embed_dim: int, num_heads: int, def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = self.self_attn_layer_norm(x) - x = self.self_attn_partial(x, self.past_embed_key, self.past_embed_value) + x, past_k, past_v = self.self_attn_partial(x, self.past_embed_key, self.past_embed_value) + self.past_embed_key = past_k + self.past_embed_value = past_v x = self.dropout(x) x = x + residual @@ -65,4 +67,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.mlp(x) x = self.dropout(x) x = x + residual + # func_print_shape(self.past_embed_key, 'past_k: ') + # func_print_shape(self.past_embed_value, 'past_v: ') return x \ No newline at end of file diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py index 4e0a487c..81066599 100644 --- a/examples/nlp/gpt/infer.py +++ b/examples/nlp/gpt/infer.py @@ -16,6 +16,7 @@ from examples.nlp.gpt.model import GPTInfer, GPTInferDataLoader from examples.nlp.gpt.model import GPTDataLoader +from examples.nlp.gpt.model import build_gpt_config import cube from cube.profiler.timer import CudaTimer, print_each_rank @@ -51,9 +52,9 @@ def inter(): print(f'torch.cuda.is_available() = {torch.cuda.is_available()}') - batch_size = 1 + batch_size = 8 - model = GPTInfer() + model = GPTInfer(batch_size=batch_size, cfg=build_gpt_config('350M')) model = model if not args.fp16 else model.half() # model = model.cuda() #only for PyTorch run model.eval() diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index 9838de62..0f2a1710 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -160,6 +160,11 @@ def PASMegatronInferTP(graph: IRGraph, resource): ffns = [node for node in fnodes if node.name == 'feedforward'] for ffn in ffns: _tp(graph, ffn, tp_devs, idx=1, dim=0) + + # func_print_shape + prts = [node for node in fnodes if node.name == 'func_print_shape'] + for prt in prts: + _tp(graph, prt, tp_devs, idx=0, dim=2) # first embedding linear first_emb_anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor) and node.name == 'first_embed'] From d7f0a3158b26c8782ac36d591eba28fea9dd0d78 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 7 Feb 2023 18:39:26 +0800 Subject: [PATCH 1232/1892] fix synthetic dataloader --- cube/runtime/syndata.py | 26 ++++++++++++++++++-------- examples/nlp/gpt/model.py | 32 ++++++++------------------------ 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index e7059cd8..f3479d39 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -4,9 +4,7 @@ from typing import Any, List, Optional, Tuple, Union import torch - - -__all__ = ['CubeDataLoader', 'SynDataLoader'] +import warnings class CubeDataLoader: @@ -29,6 +27,9 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype], batch_d self.shapes = tuple([list(shape) for shape in shapes]) self.dtypes = dtypes self.batch_dims = (0,) * len(self.shapes) if batch_dims is None else batch_dims + bs = [shape[dim] for shape, dim in zip(self.shapes, self.batch_dims)] + assert len(set(bs)) == 1, f"Expect batch size same in each data shapes" + self.batch_size = bs[0] def get_batch_size(self) -> int: """ @@ -137,8 +138,8 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, dtypes = tuple([torch.float] * len(shapes)) super().__init__(shapes, dtypes, batch_dims) - self.buffer: Union[torch.Tensor, Tuple[torch.Tensor]] = None - self.set_random_sample() + datas = self.random_sample() + self.set_output(datas) def __iter__(self): return self @@ -146,7 +147,7 @@ def __iter__(self): def __next__(self): return self.buffer - def set_random_sample(self): + def random_sample(self) -> Tuple[torch.Tensor]: torch.manual_seed(0) datas = [] for shape, dtype in zip(self.shapes, self.dtypes): @@ -156,13 +157,22 @@ def set_random_sample(self): device=torch.cuda.current_device(), requires_grad=False) ) + datas if len(datas) == 0: self.buffer = None else: datas = tuple(datas) if len(datas) > 1 else datas[0] - self.buffer = datas + return tuple(datas) if len(datas) > 0 else datas + + def set_output(self, datas: Union[torch.Tensor, Tuple[torch.Tensor]]): + datas = (datas,) if torch.is_tensor(datas) else tuple(datas) + if len(datas) == 0: + self.buffer = None + else: + self.buffer = datas[0] if len(datas) == 1 else datas def set_batch_size(self, batch_size: int): super().set_batch_size(batch_size) - self.set_random_sample() + datas = self.random_sample() + self.set_output(datas) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 396e7c55..34e50ff5 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -131,11 +131,10 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): return loss -class GPTDataLoader(cube.runtime.syndata.CubeDataLoader): +class GPTDataLoader(cube.runtime.syndata.SynDataLoader): def __init__(self, batch_size: int, cfg: Config = Config()): - self.bs = batch_size self.cfg = cfg super().__init__( shapes=([batch_size, self.cfg.seqlen], @@ -144,31 +143,23 @@ def __init__(self, batch_size: int, cfg: Config = Config()): dtypes=(torch.int64, torch.int64), batch_dims=(0, 0) ) - self.samples = [self.random_sample()] def random_sample(self): input_ids = torch.randint( 0, self.cfg.num_embeddings, - size=(self.bs, self.cfg.seqlen), + size=(self.batch_size, self.cfg.seqlen), dtype=torch.int64, device=torch.cuda.current_device() ) position_ids = torch.arange( 0, self.cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() - ).repeat(self.bs).view(self.bs, -1) - return (input_ids, position_ids) + ).repeat(self.batch_size).view(self.batch_size, -1) + return input_ids, position_ids - def __iter__(self): - return self - def __next__(self): - return self.samples[0] - - -class GPTInferDataLoader(cube.runtime.syndata.CubeDataLoader): +class GPTInferDataLoader(cube.runtime.syndata.SynDataLoader): def __init__(self, batch_size: int, cfg: Config = Config()): - self.bs = batch_size self.cfg = cfg super().__init__( shapes=([batch_size, 1], @@ -177,23 +168,16 @@ def __init__(self, batch_size: int, cfg: Config = Config()): dtypes=(torch.int64, torch.int64), batch_dims=(0, 0) ) - self.samples = [self.random_sample()] def random_sample(self): input_ids = torch.randint( 0, self.cfg.num_embeddings, - size=(self.bs, 1), + size=(self.batch_size, 1), dtype=torch.int64, device=torch.cuda.current_device() ) position_ids = torch.arange( 0, 1, dtype=torch.int64, device=torch.cuda.current_device() - ).repeat(self.bs).view(self.bs, -1) - return (input_ids, position_ids) - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] \ No newline at end of file + ).repeat(self.batch_size).view(self.batch_size, -1) + return input_ids, position_ids From fd6df43fd6ffafd17768035bede8f0c2ca984229 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 7 Feb 2023 19:39:54 +0800 Subject: [PATCH 1233/1892] automatically assign device for IRPyFunc --- cube/compiler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cube/compiler.py b/cube/compiler.py index 7d0a8b06..297b7aa3 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -9,6 +9,7 @@ from cube.graph.graph import IRGraph from cube.ir.operator import IRDataOperation from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.pyfunc import IRPyFunc from cube.execplan import ExecutionPlan from cube.execplan.planpass.fusion import DiffFusion @@ -138,6 +139,8 @@ def decorator(fn: Callable) -> Callable: # skip graph anchor and multiref: they will be removed or replaced by system if isinstance(node, IRGraphAnchor) or node.name == 'multiref': graph.assign(node, 0) + if isinstance(node, IRPyFunc): + graph.assign(node, 0) if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") From d343db88acc9cb3e9bf5f815a1b3b1d9373ce9f4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 7 Feb 2023 20:06:18 +0800 Subject: [PATCH 1234/1892] fix dataloader --- cube/runtime/syndata.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index f3479d39..b869fada 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -157,12 +157,7 @@ def random_sample(self) -> Tuple[torch.Tensor]: device=torch.cuda.current_device(), requires_grad=False) ) - datas - if len(datas) == 0: - self.buffer = None - else: - datas = tuple(datas) if len(datas) > 1 else datas[0] - return tuple(datas) if len(datas) > 0 else datas + return tuple(datas) def set_output(self, datas: Union[torch.Tensor, Tuple[torch.Tensor]]): datas = (datas,) if torch.is_tensor(datas) else tuple(datas) From 36cf27b4a33f02ea3b46e6c17acf40a887e15595 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 8 Feb 2023 12:31:28 +0000 Subject: [PATCH 1235/1892] Merged PR 1447: Pipeline scheduling customization support - add step-based description interface to customize a scheduling plan with micro-batch - refine and extend code generation with detailed scheduling - add reuse IRCell to reuse code structure with different input and output --- cube/codegen/__init__.py | 3 + cube/codegen/codegen.py | 1069 ---------------------------- cube/codegen/emit.py | 222 ++++++ cube/codegen/lifecycle.py | 117 +++ cube/codegen/module/__init__.py | 0 cube/codegen/module/autograd.py | 68 ++ cube/codegen/module/module.py | 548 ++++++++++++++ cube/codegen/schedule/__init__.py | 0 cube/codegen/schedule/schedule.py | 181 +++++ cube/codegen/syntax/blocks.py | 11 +- cube/compiler.py | 13 +- cube/execplan/execplan.py | 460 ++++++++---- cube/execplan/planpass/fusion.py | 26 +- cube/execplan/planpass/grouping.py | 2 +- cube/flags.py | 2 + cube/graph/gener/gen.py | 2 +- cube/graph/graph.py | 12 +- cube/graph/schedule/predefined.py | 112 +++ cube/graph/schedule/schedplan.py | 267 +++++++ cube/graph/segment.py | 44 +- cube/ir/adapter/adapter.py | 29 +- cube/ir/cten.py | 34 +- cube/ir/operator.py | 3 + cube/ir/tensor.py | 8 +- cube/runtime/executor.py | 134 ++-- 25 files changed, 2017 insertions(+), 1350 deletions(-) delete mode 100644 cube/codegen/codegen.py create mode 100644 cube/codegen/emit.py create mode 100644 cube/codegen/lifecycle.py create mode 100644 cube/codegen/module/__init__.py create mode 100644 cube/codegen/module/autograd.py create mode 100644 cube/codegen/module/module.py create mode 100644 cube/codegen/schedule/__init__.py create mode 100644 cube/codegen/schedule/schedule.py create mode 100644 cube/graph/schedule/predefined.py create mode 100644 cube/graph/schedule/schedplan.py diff --git a/cube/codegen/__init__.py b/cube/codegen/__init__.py index e69de29b..b9af357b 100644 --- a/cube/codegen/__init__.py +++ b/cube/codegen/__init__.py @@ -0,0 +1,3 @@ + +from cube.codegen.module.module import ModuleCodeGen +from cube.codegen.schedule.schedule import ScheduleCodeGen \ No newline at end of file diff --git a/cube/codegen/codegen.py b/cube/codegen/codegen.py deleted file mode 100644 index ab01ad23..00000000 --- a/cube/codegen/codegen.py +++ /dev/null @@ -1,1069 +0,0 @@ -""" -Generate Pytorch code given the model DAG and the transformation config -""" -import itertools -from typing import Dict, Generator, Iterable, List, Any, Optional, Set, Tuple, Union -import warnings -import torch -import copy -from more_itertools import split_when -from cube.graph.parser.mapping import Sign2Op - -from cube.ir.cten import IRCell, IRTensor -from cube.ir.dtype import IRDType - -from cube.ir.tensor import IRSubTensor -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.ir.adapter import IRWeightReducer, IRAdapter -from cube.ir.adapter.prim import CollectivePrim, IRAdapterPrim -from cube.graph.graph import IRSegment -from cube.graph.schedule import IRScheduleStrategy - -from cube.execplan import ExecutionPlan - -from cube.codegen.syntax.symtable import SymbolTable -from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock -from cube.codegen.frontend_mapping import Sign2EmitRule - -from cube.flags import CompileFlag - - -def get_backward_callsite_io_tensors(bp_segment:IRSegment): - """ - Returns: - ``` - (input_tensors, output_tensors, output_grads, input_grads) - #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~ - #inputs to 'backward' outputs of 'backward' - ``` - """ - assert isinstance(bp_segment, IRSegment) and not bp_segment.isfw() - - input_tensors = [t for t in bp_segment.mirror.inputs() if \ - isinstance(t, IRSubTensor) and \ - t.requires_grad and \ - not t.is_attr() - ] - output_tensors = [t for t in bp_segment.mirror.outputs() if isinstance(t, IRSubTensor) and t.grad is not None] - input_grads = [t.grad for t in input_tensors] - - # WARNING !!! - # non-tensor gradients like scalar '1.0f' are removed in 'bpSeg.inputs()' - # so the items of 'bpSeg.inputs()' are generally disaligned with 'output_grads' here. - output_grads = [t.grad for t in output_tensors] - - return input_tensors, output_tensors, output_grads, input_grads - -# TODO this could be applied in Adapters -def calc_tenvars_lifetime( - nodes:Iterable[IRCell], - subgraph_outputs:Iterable[IRTensor], - subgraph_inputs:Iterable[IRTensor] = [] - ) -> Dict[IRTensor, int]: - """ - Calculate the lifetime of tensor variables ahead-of-time. - So that during schedule the GC on those variables can take place in time. - - E.g. at what timings may a tensor variable O_i be discarded (i.e. no longer referred)? - ``` - ..., O_i, O_j, ... = f(I_1, ..., I_M) - # Case 1, immediately, because it's never used - O_i = None - # Case 2, after some invocation, because it's no longer referred - ... = g(..., O_j, ...) - O_j = None - ``` - - Returns: `Dict[IRTensor, int]` - - For each kv-pair `(t, i)` it indicates the last reference of tensor `t` - is at the `i`-th (0-based) node's inputs, - i.e. the variable for tensor `t` could be released *BEFORE* the `i`-th statement - in codegen. - - If an input of the subgraph is never used, its corresponding `i` is `0` -- this will - lead to an immediate release at the beginning of a function. - Tensors that exist till the end of the subgraph will have lifetime greater than the - size of that subgraph. Generally we don't need to manually release those tensors, - since they are automatically released when the generated function returns. - """ - - lifetime : Dict[IRTensor, int] = dict() - - def is_temp_tensor(v): - return isinstance(v, IRSubTensor) and not v.is_attr() - - lifetime.update((tsin, 0) for tsin in subgraph_inputs if is_temp_tensor(tsin)) - - for i, node in enumerate(nodes): - - outputs : Iterable[IRTensor] - inputs : Iterable[IRTensor] - - if isinstance(node, IRSegment): - if node.isfw(): - outputs = node.outputs() - inputs = node.inputs() - else: - # NOTE - # An backward 'IRSegment' does not explicitly record all tensors that are - # inputs-and-outputs-to-its-correspondeding-autograd.grad-call after codegen. - # - # Where a call to 'torch.autograd.grad' in Python is like: - # ``` - # grad_inputs : Tuple[torch.Tensor, ...] = torch.autograd.grad(outputs, inputs, grad_outputs) - # len(grad_inputs) == len(inputs) - # len(grad_outputs) == len(outputs) - # ``` - # - # But a backward 'IRSegment' itself only records _extra_ information to take - # gradients for inputs to a forward 'IRSegment': - # - # - Inputs of the backward 'IRSegment' are - # gradient tensors for outputs of the corresponding forward 'IRSegment' - # - # WARNING: non-tensor gradients like scalar '1.0f' are removed, - # so the items of 'bpSeg.inputs()' are generally disaligned with 'fw_outputs()' - # - # - Outputs of the backward 'IRSegment' are - # gradient tensors for both explicit and implicit inputs of the forward 'IRSeg' - # - # P.S. the implicit inputs of the forward 'IRSeg' are like 'nn.Parameter's - # which are model fields and accessed by e.g. 'self.weights'. - # Generally, by viewing a gradient tensor of some input, we cannot distinguish - # whether the corresponding input is explicit or implicit. - - fw_inputs, fw_outputs, output_grads, input_grads = \ - get_backward_callsite_io_tensors(node) - # remove loss gradient - output_grads = [t for t in output_grads if not t.is_loss()] - - outputs = input_grads - inputs = list(itertools.chain(fw_inputs, fw_outputs, output_grads)) - - else: - outputs = node.outputs() - inputs = node.inputs() - - # aggressively mark all outputs for immediate deletion, - # namely *before* 'i+1'-th statement, in case it's never used. - lifetime.update((tout, i+1) for tout in outputs if is_temp_tensor(tout)) - - # "fast-forward" all inputs to the current statement, namely before 'i+1'-th node. - lifetime.update((tin, i+1) for tin in inputs if is_temp_tensor(tin)) - - # end of 'for' - - # Here (i+1) is always greater than 'len(nodes)' - # Generally we don't manually release those tensors since the enclosing function is about to - # return, all local variables are automatically released. - # But we do need to update the lifetime of all outputs, to avoid early releasing. - lifetime.update((tsout, i+1) for tsout in subgraph_outputs if is_temp_tensor(tsout)) - - return lifetime - - -class CodeGen: - """ - Generate code for the model - """ - def __init__(self, execplan: ExecutionPlan): - if not isinstance(execplan, ExecutionPlan): - raise TypeError("execplan should be ExecutionPlan") - self.execplan = execplan - - def dtype_map(self, dtype: IRDType) -> str: - if not isinstance(dtype, IRDType): - raise TypeError("Expected IRDType") - return 'torch.' + dtype.value - - def node_naming(self, node: IRCell) -> str: - return f"{node.name}{node._id}" - - def tensor_naming(self, tensor: Any, prefix_attr: Optional[str] = None) -> str: - """ - Return the var name. - For tensor, return the {prefix}{tensor.name}_{tensor.tid} - For non-tensor, return its string - - @param tensor Any: any value - @attr_prefix Optional[str]: prefix for a attributed tensor - - @return str - """ - if isinstance(tensor, IRTensor): - tensor_name = tensor.name - if '.' in tensor_name: - tensor_name = tensor_name.split('.')[0] - name = '_'.join([tensor_name, str(tensor.tid)]) - if prefix_attr is not None and tensor.is_attr(): - name = prefix_attr + name - else: - name = str(tensor) - return name - - def tuple_naming(self, tensors: List[Any], - skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: - """ - Return the tupled tensor name. - - @param tensors List[Any]: list of any value - @param skip_attr bool: whether to skip graph attribute in the tensors - @param prefix_attr bool: whether to add a prefix for graph attribute - - @return name str: the tupled tensor name - """ - names = [] - for t in tensors: - if isinstance(t, IRTensor) and skip_attr and t.is_attr(): - continue - names.append(self.tensor_naming(t, prefix_attr)) - name = '(' + ', '.join(names + ['']) + ')' - return name - - def return_naming(self, tensors: List[Any], - skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: - """ - Return the tensors in return format, i.e. tupled name without brackets. - - @param tensors List[Any]: list of any value - @param skip_attr bool: whether to skip graph attribute in the tensors - @param prefix_attr bool: whether to add a prefix for graph attribute - - @return name str: the tupled tensor name - """ - names = [] - for t in tensors: - if isinstance(t, IRTensor) and skip_attr and t.is_attr(): - continue - names.append(self.tensor_naming(t, prefix_attr)) - names = '_' if len(names) == 0 else ', '.join(names) - return names - - def kwargs_naming(self, **kwargs) -> str: - """ - Return the kwarg naming, connected by ', ' - - @param kwargs Dict[str, Any]: kwargs - - @return name str - """ - names = [] - for name, val in kwargs.items(): - # TODO: Ad-hoc patch for amp - if CompileFlag.use_amp and val == 'torch.float32': - val = 'torch.float16' - names.append(f'{name}={val}') - name = ', '.join(names) - return name - - def emit_tensors_release(self, tensors:Iterable[IRTensor]) -> str: - tnames : Generator = (self.tensor_naming(t) for t in tensors) - return 'del ' + ', '.join(tnames) - - -class AutogradAdapterCodeGen(CodeGen): - """ - Generate autograd adapter code (PyTorch) - """ - def __init__(self): - - self.fw_ins: List[IRSubTensor] = list() - self.fw_body: List[str] = list() - self.fw_ous: List[IRSubTensor] = list() - - self.bw_ins: List[IRSubTensor] = list() - self.bw_body: List[str] = list() - self.bw_ous: List[IRSubTensor] = list() - - def emit_prim(self, prim: IRAdapterPrim) -> str: - if len(prim.inputs()) == 1: - itensors = self.tensor_naming(prim.inputs()[0]) - else: - itensors = self.tuple_naming(prim.inputs()) - kwargs = list() - for name, val in prim.kwargs.items(): - kwargs.append(f'{name}={val}') - kwargs = ', '.join(kwargs) - outputs = self.return_naming(prim.outputs()) - code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' - return code - - def gen(self, fadapter: IRAdapter) -> List[str]: - assert fadapter.forward and fadapter.differentiable and fadapter.custom, "generate autograd for a non-differentiable adapter" - assert fadapter.mirror is not None - name = AutogradAdapterCodeGen.name(fadapter) - with ClassBlock(class_name=name, derived=['torch.autograd.Function']) as cb: - # forward - cb.insert_body('@staticmethod') - finputs = [self.tensor_naming(t) for t in fadapter.inputs()] - with FunctionBlock(func_name='forward', args=['ctx']+finputs) as fw: - for prim in fadapter.prims: - fw.insert_body(self.emit_prim(prim)) - outputs = self.return_naming(fadapter.outputs()) - fw.insert_body(f'return {outputs}') - cb.insert_body(fw.code) - # backward - cb.insert_body('@staticmethod') - badapter: IRAdapter = fadapter.mirror - binputs = [self.tensor_naming(t) for t in badapter.inputs()] - with FunctionBlock(func_name='backward', args=['ctx']+binputs) as bw: - for prim in badapter.prims: - bw.insert_body(self.emit_prim(prim)) - outputs = self.return_naming(badapter.outputs()) - bw.insert_body(f'return {outputs}') - cb.insert_body(bw.code) - return cb.code - - @staticmethod - def name(adapter: IRAdapter) -> str: - return f'Adapter{adapter.cid}' - - -class ModelCodeGen(CodeGen): - """ - Generate model code - - `ModelCodeGen` traverses all IR nodes and categorizes their intermediately generated - codes into different parts, - then reorders and concatenates these parts into the final code for PyTorch to run. - - These parts are progressively stored into fields of `ModelCodeGen` - - - `init_code : List[str]` - Statements like `import torch` - - - `model_init_statements : List[str]` - Statements of the `__init__` constructor of the final `nn.Module` in codegen, - - E.g. (lines are split into `List[str]`) - ```python - self.init_group(ranks=[0, 1, 2, 3]) - self.weight_63 = torch.nn.Parameter(torch.empty((2048, 8192), dtype=torch.float32)) - self.add_full_map('weight_63', 3, (slice(0, 2048, None), slice(0, 8192, None)), 1) - ``` - - including: - -- initialization of model weights, which are class fields; - - - `model_methods_bodies : List[List[str]]` - Definitions of the Python code for forward computations like Segments or Adapters - - Note that codes within this field haven't been organized into valid Python methods, - namely without signatures and return statements, both of which will be extracted - from corresponding IRSegment/IRAdapter in later processes. - E.g. - ``` - [ - # intermediate codes for 'segment123(self, tensor_2222)' - [ - 'tensor_3333 = torch.view(tensor_2222, [1,2,3,4])' - ] - - # intermediate codes for 'adapter456(self, tensor_4444)' - [ - 'tensor_5555 = cube.runtime.adapter.all_reduce(tensor_4444, ranks=[0,1,2,3])' - ] - ] - ``` - """ - - def __init__(self, execplan: ExecutionPlan): - super().__init__(execplan) - # model full code - self.init_code: List[str] = [ - '\n\n########## Generated Model Code ###########', - 'from typing import *', - 'import torch', 'import torch.utils.checkpoint as ckpt', - 'import cube', '', ''] - - if CompileFlag.use_nnfusion: - self.init_code.extend(['import nnfusion', '']) - - # customized op code - for _, op_impl in Sign2Op.kOpCodeDef.items(): - # self.init_code.append('@torch.jit.script') - self.init_code.append(op_impl) - self.init_code += [''] - # module init code - self.model_init_statements: List[str] = list() - # module method bodies for forward computations, e.g. Segments, Adapters. - self.model_methods_bodies: List[List[str]] = list() - # module member name - self.symbols = SymbolTable() - # ref module to check shared variables - self._ref_module = torch.nn.Module() - # batch size - self.batch_size = None - - def init_comm_groups(self): - """ - Get all communication groups. - - Creating communication group requires all the devices - enter the same call. - - The fields storing intermediate codes that are populated by this method: - - `model_init_statements` - """ - graph = self.execplan.graph - sign = 'self.init_group(ranks={ranks})' - # collect groups from weight reducer - comm_groups: Dict[Tuple[int]] = list() - for node in graph.nodes(): - if isinstance(node, IRWeightReducer): - ranks = list(node.device) - ranks.sort() - ranks = tuple(ranks) - if ranks not in comm_groups: - comm_groups.append(ranks) - # collect groups from p2p fusion - adapters = [n for n in graph.nodes(flatten=True) if isinstance(n, IRAdapter)] - for adapter in adapters: - for prim in adapter.prims: - if isinstance(prim, CollectivePrim): - ranks = list(prim.kwargs['ranks']) - ranks.sort() - ranks = tuple(ranks) - if ranks not in comm_groups: - comm_groups.append(ranks) - # create communication group - self.model_init_statements.append('# communication groups') - for ranks in comm_groups: - code = sign.format(ranks=list(ranks)) - self.model_init_statements.append(code) - self.model_init_statements.append(' ') - - def gen(self, device: int, outfile=None, attach=False) -> str: - """ - Generate model implementation code based on the given graph. - """ - gencode = copy.copy(self.init_code) - node_args: List[List[str]] = list() - gen_nodes: List[IRCell] = list() - - # init customized adapter - for seg in [seg for seg in self.execplan.seq(device) if isinstance(seg, IRSegment)]: - for adapter in [n for n in seg.nodes() if isinstance(n, IRAdapter)]: - if adapter.forward and adapter.differentiable and adapter.custom: - gencode += AutogradAdapterCodeGen().gen(adapter) + ['', ''] - adapter.signature = AutogradAdapterCodeGen.name(adapter) + '.apply' - - # initialize communication groups - self.init_comm_groups() - - # parse graph body - for node in self.execplan.seq(device): - if isinstance(node, IRSegment): - if not node.isfw(): continue # skip backward segment - codes = self.emit_segment_code(node) - elif isinstance(node, IRFwOperation): - raise RuntimeError(f"Unexcepted global-level op call: {node}") - elif isinstance(node, IRAdapter): - codes = self.emit_adapter_code(node) - elif isinstance(node, IRWeightReducer): - self.emit_reducer_init(node) - codes = self.emit_reducer_call(node) - elif isinstance(node, IRBpOperation): - continue - elif isinstance(node, IRDataOperation): - self.emit_batchsize_code(node) - continue - else: - raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") - - # emit node tensor declaration into `__init__` - # typically it's about the `nn.Parameter` - self.emit_node_tensors_declare(node) - - # emit node code - # codes : List[str] - self.model_methods_bodies.append(codes) - gen_nodes.append(node) - - args = list() - for t in node.inputs(): - if isinstance(t, IRSubTensor): - if not t.is_attr(): - args.append(self.tensor_naming(t)) - else: - args.append(self.tensor_naming(t)) - node_args.append(args) - # generate full code - with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: - with FunctionBlock(func_name='__init__', args=['self']) as ib: - ib.insert_body(self.model_init_statements) - # switch to training or inference mode - if self.execplan.inference: - ib.insert_body('self.eval()') - else: - ib.insert_body('self.train()') - cb.insert_body('') - cb.insert_body(ib.code) - for idx, node in enumerate(gen_nodes): - name = self.node_naming(node) - input_args = ['self'] + node_args[idx] - forward_code = self.model_methods_bodies[idx] - - with FunctionBlock(func_name=name, args=input_args) as fb: - fb.insert_body(forward_code) - # generate output - outputs = [self.tensor_naming(t) for t in node.outputs()] - return_code = f"return {', '.join(outputs)}" - fb.insert_body(return_code) - cb.insert_body('') - if CompileFlag.use_nnfusion and name.startswith('segment'): - cb.insert_body('@nnfusion.jit') - if CompileFlag.use_jit and name.startswith('segment'): - cb.insert_body('@torch.jit.script_method') - cb.insert_body(fb.code) - - - gencode += cb.code - gencode += [''] - - code = '\n'.join(gencode) - # write to file - if outfile: - with open(outfile, 'a' if attach else 'w') as f: - f.write(code) - - # clear used buffer - self.clear() - return code - - def emit_node_tensors_declare(self, node: IRCell): - """ - Emit tensor declaration code - - The fields storing intermediate codes that are populated by this method: - - `model_init_statements` - - This method also populates `self.symbols : SymbolTable` to record - the names of the variables for the tensors ever encountered. - """ - psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" - bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}))" - map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" - if not isinstance(node, IRSegment): - for itensor in node.inputs(): - name = self.tensor_naming(itensor, prefix_attr='self.') - if isinstance(itensor, IRSubTensor): - if itensor.is_attr() and not self.symbols.exist(name): - self.symbols.create(name) - sign = psign if itensor.is_param() else bsign - code = sign.format( - name=self.tensor_naming(itensor), - shape=tuple(itensor.shape), - dtype=self.dtype_map(itensor.dtype) - ) - self.model_init_statements.append(code) - tid = itensor.parent.tid - slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) - val_chunks = itensor.valmap[1] - code = map_sign.format( - attr=self.tensor_naming(itensor), tid=tid, - slicers=str(slicers), val_chunks=val_chunks - ) - self.model_init_statements.append(code) - self.model_init_statements.append('') - if isinstance(itensor, str): - if name.startswith('self.'): - if not hasattr(self._ref_module, name[5:]): - raise NotImplementedError("member attribute is not added") - for output in node.outputs(): - self.symbols.create(self.tensor_naming(output, prefix_attr='self.')) - else: - for sub_node in node.nodes(): - self.emit_node_tensors_declare(sub_node) - return - - def emit_segment_code(self, segment: IRSegment) -> List[str]: - """ - Emit IRSegment code. - - The resultant `List[str]` will be lines of the statements of the final - Python method for the targeted Segment. - The resultant lines will not include the signature and the return statement - of the generated Python method. These lines will be put into `model_methods_bodies` - and the missing Python-syntactic parts will be injected later on. - - e.g. - ``` - [ - # no method signature - 'tensor_3333 = torch.view(tensor_2222, [1,2,3,4])', - 'tensor_2222 = None', # if in dataflow there is no more reference - 'tensor_4444 = torch.sum(tensor_3333)', - 'def recompute(...):', - ' return ...', - 'tensor_5555 = torch.utils.checkpoint(recompute, tensor_4444)', - 'tensor_4444 = None', # if in dataflow there is no more reference - # no return statement - ] - ``` - - Nodes in the segment will group into recompute region - - The fields storing intermediate codes that are populated by this method: - - NONE - """ - - codes = [] - - def emit_nodes_invocations(i_nodes: List[Tuple[int, IRCell]], - lifetime_by_line_id: Dict[int, List[IRTensor]]) -> List[str]: - """ - Emit code to invoke operations and adapter, - e.g. (the lines are split into `List[str]`) - - ``` - tensor_2222 = torch.view(tensor_1111, size=[3,6,9]) - tensor_1111 = None # if no more reference - tensor_3333 = cube.runtime.adapter.allgather_reducescatter(tensor_2222, dim=1, rank=[0,1]) - tensor_2222 = None # if no more reference - ``` - - The fields storing intermediate codes that are populated by this method: - - NONE - """ - node_codes = [] - for i, node in i_nodes: - - # NOTE - # If a tensor is still referenced in any later recomputing group, its lifetime is - # definitely greater than the current sequence of statements here. - # Therefore we get chance to extend the lifetime of tensors like that, - # and properly release them after the call to 'torch.utils.checkpoint'. - # - tensors_to_del : Optional[List[IRTensor]] = lifetime_by_line_id.get(i, None) - if tensors_to_del is not None: - node_codes.append(self.emit_tensors_release(tensors_to_del)) - - if isinstance(node, IRFwOperation): - code = self.emit_op_code(node) - node_codes += code - elif isinstance(node, IRAdapter): - code = self.emit_adapter_code(node) - node_codes += code - else: - raise RuntimeError(f"unexpected type {type(node)} in IRSegment") - - return node_codes - - # returns: (code_lines, group_inputs, group_outputs) - def emit_rc_nodes(i_nodes: List[Tuple[int, IRCell]], lifetime_by_line_id: dict) \ - -> Tuple[List[str], List[IRTensor], List[IRTensor]]: - """ - Emit code to define a Python function for ReComputing and invoke it - e.g. (the lines are split into `List[str]`) - - ``` - def recompute(tensor_2222): - tensor_3333 = torch.view(tensor_2222, size=[3,6,9]) - tensor_2222 = None # no more reference - return tensor_3333 - # in the beginning we have `import torch.utils.checkpoint as ckpt` - tensor_4444 = ckpt.checkpoint(recompute, tensor_1111) - ``` - - REMARK: - - In the example above, 'tensor_2222' can be released within the RC subgraph, which also means that - the variable for this tensor can also be released within the enclosing graph, after the 'checkpoint' call. - - The generated RC subgraph will have no "free variables". - All involved tensors that are defined outside of the RC group are made explicit inputs; - All tensors, that are defined within the RC group and are referenced after RC subgraph ends, are made explicit outputs; - And if a within-RC-group tensors are not used anymore, it's not returned. - - The fields storing intermediate codes that are populated by this method: - - NONE - """ - assert len(i_nodes) > 0 - node_codes = [] - - nodes : List[IRCell] = [node for i, node in i_nodes] - subseg = segment.create_segment(nodes) - - inputs = [t for t in subseg.inputs() if not t.is_attr()] - input_names = [self.tensor_naming(t) for t in inputs] - input_names_tuple = ', '.join(input_names) - outputs = [t for t in subseg.outputs()] - output_names = [self.tensor_naming(t) for t in outputs] - output_names_tuple = ', '.join(output_names) - - # 'graph.segment(nodes)' ensures that if a tensor is no longer used (in RC group or in later code), - # it's not included in 'outputs'. - # And we will not generate 'return' statement for it, since it will cause the error - # that the variable is not defined (because it has been 'del'-ed). - - with FunctionBlock('recompute', input_names, False) as fb: - # The nodes to recompute share the same space of line_ids (or "node ids") with non-recomputable nodes. - # e.g. those ids in subgraphs are not 0-based, and incremented after the preceding non-rc nodes and so on. - # - # So within the recomputing subgraph, tensors can be released if they are no longer used - # i.e. not returned by the 'def recompute(...)' - # since 'execplan.graph.segment(nodes)' will make all "free variables" as explicit inputs/outputs - # to that subgraph. - for ncode in emit_nodes_invocations(i_nodes, lifetime_by_line_id): - fb.insert_body(ncode) - fb.insert_body(f'return {output_names_tuple}') - node_codes += [''] + fb.code + [''] - node_codes.append( - f'{output_names_tuple} = ckpt.checkpoint(recompute, {input_names_tuple})' - ) - - return node_codes, inputs, outputs - - def get_equiv_recompute_gid(node:Union[IRFwOperation, IRAdapter]) -> Optional[int]: - assert isinstance(node, (IRAdapter, IRFwOperation)), f"Invalid type: {type(node)}" - return node.recompute - - def should_start_new_recompute_group(i_prev, i_cur) -> bool: - # i_prev, i_cur: Tuple[int, Union[IRFwOp,IRAdapter]] - prev_gid = get_equiv_recompute_gid(i_prev[1]) - cur_gid = get_equiv_recompute_gid(i_cur[1]) - return cur_gid != prev_gid - - nodes : List[IRCell] = segment.nodes() - - # After calculating the recompute groups, for each group, its input tensors' lifetime - # should be extend to at least beyond the lifetime of that group. - lifetime : Dict[IRTensor, int] = calc_tenvars_lifetime(nodes, segment.outputs(), segment.inputs()) - lifetime_by_line_id : Dict[int, List[IRTensor]] = dict() - for tensor, line_id in lifetime.items(): - lifetime_by_line_id.setdefault(line_id, []).append(tensor) - - # more_itertools.split_when # type: (Iterable[T], (T,T)->bool) -> Iterator[List[T]] - recompute_groups : List[List[Tuple[int, IRCell]]] \ - = list(split_when(enumerate(nodes), should_start_new_recompute_group)) - - for rc_group in recompute_groups: - # all FwOps/Adapters in a group have the same (equivalent) group id, - # check that of the head item, and 'rc_group' will not be empty here. - gid : Optional[int] = get_equiv_recompute_gid(rc_group[0][1]) - if gid is None: - codes += emit_nodes_invocations(rc_group, lifetime_by_line_id) - else: - assert len(rc_group) > 0 - - # Step 1: when entering a RC group: - # - # We insert tensor releasing statement *before* emitting each node. - # But here we are entering the scope of a RC group i.e. 'def recompute(...)'. - # Any releasing before the first node of the RC group, - # should be done before and outside of the RC group. - rc_first_line_id, _rc_first_node = rc_group[0] - # ... and to avoid emitting again, 'pop' the lifetime record. - # Specify the default collection since there might not be any. - rel_tensors_before_rc : Optional[list] = lifetime_by_line_id.pop(rc_first_line_id, None) - if rel_tensors_before_rc is not None: - codes.append(self.emit_tensors_release(rel_tensors_before_rc)) - - # Step 2 - rc_codes, rc_inputs, rc_outputs = emit_rc_nodes(rc_group, lifetime_by_line_id) - codes += rc_codes - - # Step 3: when exiting a RC group: - # - # `emit_rc_nodes` will not emit 'del`-statement for output tensors of the last - # node in the RC group, since those tensors will be immediately released - # as soon as 'recompute(...)' returns. - # We need to remove those tensors from the linearized lifetime - # (namely those with lifetime 'rc_next_line_id') - # and do not release them before the next node after the RC group. - rc_last_line_id, _rc_last_node = rc_group[-1] - rc_next_line_id = rc_last_line_id + 1 - lifetime_by_line_id.pop(rc_next_line_id, None) # specify a default to avoid KeyError - - # Step 4: after exiting a RC group: - # - # We need to release some argument tensors to the 'def recompute(...)' if they are - # no longer used. - # NOTE those tensors may have resulted in some 'del'-statements within the RC - # subfunction. But we need to release them again in the enclosing function, - # after the call to 'torch.checkpoint(recompute, *input_tensors)'. - - # Only release an RC input if: - # - its lifetime does not exceed the lifetime of the RC group; - # - not the case that the function returns after 'checkpoint' the RC subgraph. - if rc_next_line_id != len(nodes): - inputs_to_rel = [rcin for rcin in rc_inputs if lifetime[rcin] <= rc_next_line_id] - if len(inputs_to_rel) > 0: - del_stmt = self.emit_tensors_release(inputs_to_rel) - codes.append(del_stmt) - - # any resultant tensors *defined within the RC group and not used after the group* - # will not be returned from the generate 'def recompute(...)', - # so here we have no resultant tensors (namely 'rc_outputs') to release. - - return codes - - def emit_op_code(self, node: IRFwOperation) -> List[str]: - """ - Emit the statement to call the op in the forward code - (e.g. in Segments, Adapter or CodeGen.Main) - - The result will look like (the lines are split into `List[str]`) - ``` - tensor_3333 = torch.view(tensor_2222, [1,2,3,4,5]) - ``` - - The fields storing intermediate codes that are populated by this method: - - NONE - """ - codes = [] - # insert comment - if node.comment is not None: - codes.append(f'# {node.comment}') - signature = node.signature - inputs = [self.tensor_naming(t, prefix_attr='self.') for t in node.inputs()] - kwargs = {} - for key in node.kwargs: - val = node.kwargs[key] - if isinstance(val, str) and 'self.' not in val: - val = '"' + val + '"' - kwargs[key] = val - - emit_rule = Sign2EmitRule.map(signature) - body = emit_rule(node, inputs, kwargs) - - if len(node.outputs()) == 0: - code = body - else: - outputs = [self.tensor_naming(t) for t in node.outputs()] - outputs = ', '.join(outputs) - code = f'{outputs} = {body}' - codes.append(code) - return codes - - def emit_adapter_code(self, node: IRAdapter) -> List[str]: - """ - Emit the statment of the adapter call - - The resultant `List[str]` will be lines of the statements of the final - Python method for the targeted Segment, - without the method signature and the return statement. - - The fields storing intermediate codes that are populated by this method: - - NONE - """ - codes = [] - assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" - prims = [node] if node.differentiable and node.custom else [prim for prim in node.prims] - - for prim in prims: - if len(prim.inputs()) == 1: - itensors = self.tensor_naming(prim.inputs()[0], prefix_attr='self.') - else: - itensors = self.tuple_naming(prim.inputs(), prefix_attr='self.') - kwargs = self.kwargs_naming(**prim.kwargs) - outputs = self.return_naming(prim.outputs()) - code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' - codes.append(code) - return codes - - def emit_reducer_init(self, node: IRWeightReducer) -> None: - """ - Emit code to initialize involved reducer objects in `__init__`. - - The fields storing intermediate codes that are populated by this method: - - `model_init_statements` - """ - max_nbytes = CompileFlag.max_reducer_bucket - # reducer init interface - reducer_init = '{reducer} = cube.runtime.adapter.Reducer(ranks={ranks}, max_bucket_size_bytes={max_nbytes})' - reducer_add = 'self.add_reducer({reducer})' - add_param = '{reducer}.add_param({weight})' - # create reducer in declare region - weights = node.inputs() - reducer_name = f'self.wreducer{node._id}' - self.model_init_statements.append('') - init_code = reducer_init.format(reducer=reducer_name, ranks=node.device, max_nbytes=max_nbytes) - self.model_init_statements.append(init_code) - weights = [self.tensor_naming(t, prefix_attr='self.') for t in weights] - for weight in weights: - add_param_code = add_param.format(reducer=reducer_name, weight=weight) - self.model_init_statements.append(add_param_code) - add_code = reducer_add.format(reducer=reducer_name) - self.model_init_statements.append(add_code) - - def emit_reducer_call(self, node: IRWeightReducer): - """ - Emit the statment to invoke a reducer object. - - The fields storing intermediate codes that are populated by this method: - - NONE - """ - reducer_name = f'self.wreducer{node._id}' - code = f'{reducer_name}.allreduce()' - return [code] - - def emit_batchsize_code(self, node: IRDataOperation): - """ - Emit batch size declare - """ - signature = 'self.set_batch_size({bs})' - bs = [t.shape[dim] for t, dim in zip(node.outputs(), node.get_batch_dims())] - bs = set(bs) - if len(bs) > 1: - warnings.warn(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') - bs = list(bs)[0] if len(bs) == 1 else None - assert self.batch_size is None or self.batch_size == bs, f"Not match for batch size: {self.batch_size} != {bs}" - self.model_init_statements.append(signature.format(bs=bs)) - self.batch_size = bs - - def clear(self): - """ - Clear buffer that used for generating code - """ - # module init code - self.model_init_statements: List[str] = list() - # module forward code - self.model_methods_bodies: List[List[str]] = list() - # module member name - self.symbols = SymbolTable() - # batch size - self.batch_size = None - - -class ScheduleCodeGen(CodeGen): - - def __init__(self, execplan: ExecutionPlan): - super().__init__(execplan) - # model full code - self.init_code: List[str] = [ - '\n\n########## Generated Schedule Code ###########', - 'import torch', 'import cube', ''] - # module member name - self.symbols = SymbolTable() - - def gen(self, device: int, outfile=None, attach=False) -> str: - """ - Generate scheduling code based on the given sus - """ - gencode = copy.copy(self.init_code) - - device_nodes = self.execplan.seq(device) - - lifetime : Dict[IRTensor, int] = calc_tenvars_lifetime(device_nodes, self.execplan.graph.outputs()) - lifetime_by_line_id : Dict[int, List[IRTensor]] = dict() - for tensor, line_id in lifetime.items(): - lifetime_by_line_id.setdefault(line_id, []).append(tensor) - - with FunctionBlock(func_name='_train_step', - args=['model', 'dataloader']) as fb: - fb.insert_body('_ = None') - # body code - if len(device_nodes) == 0: - fb.insert_body('pass') - elif self.execplan.graph.sched: - code = self.emit_schedule_plan(self.execplan.graph.sched, device) - fb.insert_body(code) - else: - for i, node in enumerate(device_nodes): - # Decrement reference counts for output tensors that are no longer used - # Tensors here need to release *before* the i-th statement. - tensors : Optional[List[IRTensor]] = lifetime_by_line_id.get(i, None) - if tensors is not None: # not necessarily to have one after each line - fb.insert_body(self.emit_tensors_release(tensors)) - - name = self.node_naming(node) - code = self.emit_node(node, name=name) - fb.insert_body(code) - - # return code - outputs = self.return_naming(self.execplan.graph.outputs()) - code = f'return {outputs}' - fb.insert_body(code) - gencode += fb.code - gencode += [''] - - code = '\n'.join(gencode) - # write to file - if outfile: - with open(outfile, 'a' if attach else 'w') as f: - f.write(code) - return code - - def emit_schedule_plan(self, schedplan: IRScheduleStrategy, devid: int): - signature = schedplan.signature - kwargs: Dict[str, Any] = schedplan.kwargs(devid) - strkwargs = dict() - for kwarg, val in kwargs.items(): - if isinstance(val, IRCell): - name = 'model.' + self.node_naming(val) - elif isinstance(val, (tuple, list)): - brackets = ')' if len(val) != 1 else ',)' - name = '(' + ', '.join('model.' + self.node_naming(n) \ - if isinstance(n, IRCell) else str(n) for n in val) + brackets - else: - name = str(val) - strkwargs[kwarg] = name - code = ', '.join(f'{kwarg}={name}' for kwarg, name in strkwargs.items()) - code = f'{signature}({code})' - return code - - def emit_node(self, node: IRCell, name: str) -> str: - """ - Emit node / subgraph code - """ - fsign = '{outputs} = cube.runtime.executor.fexecute({name}, {model}, *{inputs}, requires_grad={req_grad})' - asign = '{outputs} = cube.runtime.executor.aexecute({model}, *{inputs}, requires_grad={req_grad})' - bsign = '{input_grads} = cube.runtime.executor.backward({name}, {input_tensors}, {output_tensors}, {output_grads})' - - inputs = self.tuple_naming(node.inputs(), skip_attr=True, prefix_attr='model.') - outputs = self.return_naming(node.outputs(), skip_attr=True, prefix_attr='model.') - req_grad = any(t.requires_grad for t in node.outputs() if isinstance(t, IRTensor)) - - if isinstance(node, IRSegment): - # emit forward - if node.isfw(): - code = fsign.format( - outputs = outputs, - name = f"'{name}'", - model = f'model.{name}', - inputs = inputs, - req_grad = req_grad - ) - # emit backward - else: - input_tensors, output_tensors, output_grads, input_grads = \ - get_backward_callsite_io_tensors(node) - - for idx, tensor in enumerate(output_grads): - if isinstance(tensor, IRSubTensor) and tensor.is_loss(): - output_grads[idx] = None - code = bsign.format( - name = f"'{self.node_naming(node.mirror)}'", - input_grads = self.return_naming(input_grads), - input_tensors = self.tuple_naming(input_tensors, skip_attr=True, prefix_attr='model.'), - output_tensors = self.tuple_naming(output_tensors, skip_attr=True, prefix_attr='model.'), - output_grads = self.tuple_naming(output_grads, skip_attr=True, prefix_attr='model.') - ) - - elif isinstance(node, IRDataOperation): - if len(node.inputs()) != 0: - raise RuntimeError("Expect Dataloader node has no inputs") - outputs = [self.tensor_naming(output) for output in node.outputs()] - outputs = self.return_naming(outputs) - code = f'{outputs} = next(dataloader)' - - elif isinstance(node, IRAdapter): - code = asign.format( - outputs = outputs, - model = f'model.{name}', - inputs = inputs, - req_grad = req_grad - ) - - elif isinstance(node, IRWeightReducer): - code = asign.format( - outputs = outputs, - model=f'model.{name}', - inputs='()', - req_grad=req_grad - ) - - else: - raise RuntimeError(f"Unspported node type: {type(node)}") - return code diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py new file mode 100644 index 00000000..50139b3c --- /dev/null +++ b/cube/codegen/emit.py @@ -0,0 +1,222 @@ +from typing import Dict, Generator, Iterable, List, Any, Optional, Set, Tuple, Union +from more_itertools import split_when + + +from cube.ir.cten import IRCell, IRTensor +from cube.ir.dtype import IRDType +from cube.ir.tensor import IRSubTensor +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.adapter import IRWeightReducer, IRAdapter + +from cube.graph.graph import IRSegment +from cube.execplan.execplan import ExeReuseCell + +from cube.codegen.frontend_mapping import Sign2EmitRule + +from cube.flags import CompileFlag + + +class CodeEmission: + """ + Basic emission + """ + + @staticmethod + def dtype_map(dtype: IRDType) -> str: + if not isinstance(dtype, IRDType): + raise TypeError("Expected IRDType") + return 'torch.' + dtype.value + + @staticmethod + def node_name(node: IRCell) -> str: + return f"{node.name}{node.cid}" + + @staticmethod + def tensor_name(tensor: Any, prefix_attr: Optional[str] = None) -> str: + """ + Return the var name. + For tensor, return the {prefix}{tensor.name}_{tensor.tid} + For non-tensor, return its string + + @param tensor Any: any value + @attr_prefix Optional[str]: prefix for a attributed tensor + + @return str + """ + if isinstance(tensor, IRTensor): + tensor_name = tensor.name + if '.' in tensor_name: + tensor_name = tensor_name.split('.')[0] + name = '_'.join([tensor_name, str(tensor.tid)]) + if prefix_attr is not None and tensor.is_attr(): + name = prefix_attr + name + else: + name = str(tensor) + return name + + @staticmethod + def tuple_name(tensors: List[Any], + skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: + """ + Return the tupled tensor name. + + @param tensors List[Any]: list of any value + @param skip_attr bool: whether to skip graph attribute in the tensors + @param prefix_attr bool: whether to add a prefix for graph attribute + + @return name str: the tupled tensor name + """ + names = [] + for t in tensors: + if isinstance(t, IRTensor) and skip_attr and t.is_attr(): + continue + names.append(CodeEmission.tensor_name(t, prefix_attr)) + name = '(' + ', '.join(names + ['']) + ')' + return name + + @staticmethod + def return_name(tensors: List[Any], + skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: + names = [] + for t in tensors: + if isinstance(t, IRTensor) and skip_attr and t.is_attr(): + continue + names.append(CodeEmission.tensor_name(t, prefix_attr)) + names = '_' if len(names) == 0 else ', '.join(names) + return names + + @staticmethod + def kwargs_name(**kwargs) -> str: + names = [] + for name, val in kwargs.items(): + # TODO: Ad-hoc patch for amp + if CompileFlag.use_amp and val == 'torch.float32': + val = 'torch.float16' + names.append(f'{name}={val}') + name = ', '.join(names) + return name + + +class FuncEmission(CodeEmission): + + @staticmethod + def emit_dataloader(node: IRDataOperation) -> List[str]: + return ['next(dataloader)'] + + @staticmethod + def emit_fnode(node: IRFwOperation, prefix_attr: str = None) -> List[str]: + """ + Emit the statement to call the op in the forward code + (e.g. in Segments, Adapter or CodeGen.Main) + + The result will look like (the lines are split into `List[str]`) + ``` + tensor_3333 = torch.view(tensor_2222, [1,2,3,4,5]) + ``` + + The fields storing intermediate codes that are populated by this method: + - NONE + """ + assert isinstance(node, IRFwOperation) + codes = [] + # insert comment + if node.comment is not None: + codes.append(f'# {node.comment}') + signature = node.signature + inputs = [FuncEmission.tensor_name(t, prefix_attr=prefix_attr) for t in node.inputs()] + kwargs = {} + for key in node.kwargs: + val = node.kwargs[key] + if isinstance(val, str) and 'self.' not in val: + val = '"' + val + '"' + kwargs[key] = val + + emit_rule = Sign2EmitRule.map(signature) + body = emit_rule(node, inputs, kwargs) + + if len(node.outputs()) == 0: + code = body + else: + outputs = [FuncEmission.tensor_name(t) for t in node.outputs()] + outputs = ', '.join(outputs) + code = f'{outputs} = {body}' + codes.append(code) + return codes + + @staticmethod + def emit_adapter(node: IRAdapter, prefix_attr: Optional[str] = None) -> List[str]: + """ + Emit the statment of the adapter call + + The resultant `List[str]` will be lines of the statements of the final + Python method for the targeted Segment, + without the method signature and the return statement. + + The fields storing intermediate codes that are populated by this method: + - NONE + """ + codes = [] + assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" + prims = [node] if node.differentiable and node.custom else [prim for prim in node.prims] + + for prim in prims: + if len(prim.inputs()) == 1: + itensors = FuncEmission.tensor_name(prim.inputs()[0], prefix_attr=prefix_attr) + else: + itensors = FuncEmission.tuple_name(prim.inputs(), prefix_attr=prefix_attr) + kwargs = FuncEmission.kwargs_name(**prim.kwargs) + outputs = FuncEmission.return_name(prim.outputs()) + code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' + codes.append(code) + return codes + + @staticmethod + def emit_reducer(node: IRWeightReducer) -> List[str]: + """ + Emit the statment to invoke a reducer object. + + The fields storing intermediate codes that are populated by this method: + - NONE + """ + reducer_name = f'self.wreducer{node._id}' + code = f'{reducer_name}.allreduce()' + return [code] + + @staticmethod + def emit_release(tensors: Iterable[IRTensor]) -> str: + tnames : Generator = (FuncEmission.tensor_name(t) for t in tensors) + return 'del ' + ', '.join(tnames) + + @staticmethod + def get_backward_callsite_io_tensors(bw_cell: IRCell) -> Tuple: + """ + Get backward inputs and outputs + ``` + (input_tensors, output_tensors, output_grads, input_grads) + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~ + #inputs to 'backward' outputs of 'backward' + ``` + + @return input_tensors List[IRSubTensor]: forward input tensors (backward input) + @return output_tensors List[IRSubTensor]: forward output tensors (backward output) + @return output_grads List[IRSubTensor]: gradient of forward output tensors + (backward input) + @return input_grads List[IRSubTensor]: gradient of forward input tensors + (backward output) + """ + assert not bw_cell.isfw() + + input_tensors = [t for t in bw_cell.mirror.inputs() if \ + isinstance(t, IRSubTensor) and \ + t.requires_grad and \ + not t.is_attr() + ] + output_tensors = [t for t in bw_cell.mirror.outputs() if isinstance(t, IRSubTensor) and t.grad is not None] + input_grads = [t.grad for t in input_tensors] + + # WARNING !!! + # non-tensor gradients like scalar '1.0f' are removed in 'bpSeg.inputs()' + # so the items of 'bpSeg.inputs()' are generally disaligned with 'output_grads' here. + output_grads = [t.grad for t in output_tensors] + + return input_tensors, output_tensors, output_grads, input_grads diff --git a/cube/codegen/lifecycle.py b/cube/codegen/lifecycle.py new file mode 100644 index 00000000..a3c7833b --- /dev/null +++ b/cube/codegen/lifecycle.py @@ -0,0 +1,117 @@ +from typing import Iterable, Dict, List +import itertools + +from cube.ir.cten import IRCell, IRTensor +from cube.ir.tensor import IRSubTensor +from cube.graph.segment import IRSegment +from cube.execplan.execplan import ExeReuseCell + +from cube.codegen.emit import FuncEmission + + +class LifeCycle: + + def __init__(self, nodes: List[IRCell], graph_inputs: List[IRSubTensor], graph_outputs: List[IRSubTensor]): + + self.nodes: Dict[int] = {node: lid for lid, node in enumerate(nodes)} + # the last line id of consuming or producing a tensor + self.lifetime: Dict[IRSubTensor, int] = {} + # the tensors can be released given the finish of line id + self.release: Dict[int, List[IRSubTensor]] = {} + + is_activation = lambda t: isinstance(t, IRSubTensor) and not t.is_attr() + + self.lifetime.update((tsin, 0) for tsin in graph_inputs if is_activation(tsin)) + + for i, node in enumerate(nodes): + + outputs : Iterable[IRTensor] + inputs : Iterable[IRTensor] + + if isinstance(node, (IRSegment, ExeReuseCell)): + # forward segment + if node.isfw(): + outputs = node.outputs() + inputs = node.inputs() + # backward segment + else: + fw_inputs, fw_outputs, output_grads, input_grads = \ + FuncEmission.get_backward_callsite_io_tensors(node) + # remove loss gradient + output_grads = [t for t in output_grads if not t.is_loss()] + + outputs = input_grads + inputs = list(itertools.chain(fw_inputs, fw_outputs, output_grads)) + else: + outputs = node.outputs() + inputs = node.inputs() + + # aggressively mark all outputs for immediate deletion, + # namely *after* 'i'-th statement, in case it's never used. + self.lifetime.update((tout, i) for tout in outputs if is_activation(tout)) + + # "fast-forward" all inputs to the current statement, namely after 'i'-th node. + self.lifetime.update((tin, i) for tin in inputs if is_activation(tin)) + + + # Here (i+1) is always greater than 'len(nodes)' + # Generally we don't manually release those tensors since the enclosing function is about to + # return, all local variables are automatically released. + # But we do need to update the lifetime of all outputs, to avoid early releasing. + self.lifetime.update((tsout, i+1) for tsout in graph_outputs if is_activation(tsout)) + + for tensor, line_id in self.lifetime.items(): + self.release.setdefault(line_id, []).append(tensor) + + def release_tensors_after_line(self, line_id: int) -> List[IRSubTensor]: + """ + Get the releasable IRSubTensors after finish of executing of `line_id`. + + @param line_id int + + @return tensors List[IRSubTensors]: tensors that can be released. + """ + return self.release.get(line_id, []) + + def release_tensors_after_node(self, node: IRCell) -> List[IRSubTensor]: + """ + Get the releasable IRSubTensors after finish of executing of the node. + + @param line_id int + + @return tensors List[IRSubTensors]: tensors that can be released. + """ + assert node in self.nodes + line_id = self.nodes[node] + return self.release.get(line_id, []) + + def releasable_after_node(self, tensor: IRSubTensor, node: IRCell) -> bool: + """ + Check if the tensor is releasable after executing the node + + @param tensor IRSubTensor + @param node IRCell + + @return releasable bool + """ + assert node in self.nodes + assert tensor in self.lifetime[tensor] + line_id = self.nodes[node] + return self.lifetime[tensor] < line_id + + def releasable_after_line(self, tensor: IRSubTensor, line: int) -> bool: + """ + Check if the tensor is releasable after executing the node + + @param tensor IRSubTensor + @param line int + + @return releasable bool + """ + return self.lifetime[tensor] < line + + def get_line(self, node: IRCell) -> int: + """ + Get line id of the node + """ + return self.nodes[node] \ No newline at end of file diff --git a/cube/codegen/module/__init__.py b/cube/codegen/module/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/codegen/module/autograd.py b/cube/codegen/module/autograd.py new file mode 100644 index 00000000..dd17790c --- /dev/null +++ b/cube/codegen/module/autograd.py @@ -0,0 +1,68 @@ +from typing import List +from cube.codegen.emit import FuncEmission + +from cube.ir.tensor import IRSubTensor +from cube.ir.adapter import IRAdapter +from cube.ir.adapter.prim import IRAdapterPrim + +from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock + +from cube.codegen.emit import FuncEmission + + +class AutogradAdapterCodeGen(FuncEmission): + """ + Generate autograd adapter code (PyTorch) + """ + def __init__(self): + + self.fw_ins: List[IRSubTensor] = list() + self.fw_body: List[str] = list() + self.fw_ous: List[IRSubTensor] = list() + + self.bw_ins: List[IRSubTensor] = list() + self.bw_body: List[str] = list() + self.bw_ous: List[IRSubTensor] = list() + + def emit_prim(self, prim: IRAdapterPrim) -> str: + if len(prim.inputs()) == 1: + itensors = FuncEmission.tensor_name(prim.inputs()[0]) + else: + itensors = FuncEmission.tuple_name(prim.inputs()) + kwargs = list() + for name, val in prim.kwargs.items(): + kwargs.append(f'{name}={val}') + kwargs = ', '.join(kwargs) + outputs = FuncEmission.return_name(prim.outputs()) + code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' + return code + + def gen(self, fadapter: IRAdapter) -> List[str]: + assert fadapter.isfw() and fadapter.differentiable and fadapter.custom, "generate autograd for a non-differentiable adapter" + assert fadapter.mirror is not None + name = AutogradAdapterCodeGen.name(fadapter) + with ClassBlock(class_name=name, derived=['torch.autograd.Function']) as cb: + # forward + cb.insert_body('@staticmethod') + finputs = [FuncEmission.tensor_name(t) for t in fadapter.inputs()] + with FunctionBlock(func_name='forward', args=['ctx']+finputs) as fw: + for prim in fadapter.prims: + fw.insert_body(self.emit_prim(prim)) + outputs = FuncEmission.return_name(fadapter.outputs()) + fw.insert_body(f'return {outputs}') + cb.insert_body(fw.code) + # backward + cb.insert_body('@staticmethod') + badapter: IRAdapter = fadapter.mirror + binputs = [FuncEmission.tensor_name(t) for t in badapter.inputs()] + with FunctionBlock(func_name='backward', args=['ctx']+binputs) as bw: + for prim in badapter.prims: + bw.insert_body(self.emit_prim(prim)) + outputs = FuncEmission.return_name(badapter.outputs()) + bw.insert_body(f'return {outputs}') + cb.insert_body(bw.code) + return cb.code + + @staticmethod + def name(adapter: IRAdapter) -> str: + return f'Adapter{adapter.cid}' diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py new file mode 100644 index 00000000..4a09aac9 --- /dev/null +++ b/cube/codegen/module/module.py @@ -0,0 +1,548 @@ +from typing import Dict, List, Optional, Tuple +from more_itertools import split_when +import warnings +import copy +import torch + +from cube.ir.cten import IRCell +from cube.ir.tensor import IRSubTensor +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.adapter import IRWeightReducer, IRAdapter +from cube.ir.adapter.prim import CollectivePrim + +from cube.graph.graph import IRSegment +from cube.graph.parser.mapping import Sign2Op + +from cube.execplan import ExecutionPlan +from cube.execplan.execplan import ExeRepetend, ExeReuseCell + +from cube.codegen.syntax.symtable import SymbolTable +from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock + +from cube.codegen.emit import FuncEmission +from cube.codegen.module.autograd import AutogradAdapterCodeGen +from cube.codegen.lifecycle import LifeCycle + +from cube.flags import CompileFlag + + +class ModuleCodeGen(FuncEmission): + """ + Generate module code + + `ModuleCodeGen` traverses all IR nodes and categorizes their intermediately generated + codes into different parts, + then reorders and concatenates these parts into the final code for PyTorch to run. + + These parts are progressively stored into fields of `ModelCodeGen` + + - `init_code : List[str]` + Statements like `import torch` + + - `model_init_statements : List[str]` + Statements of the `__init__` constructor of the final `nn.Module` in codegen, + + E.g. (lines are split into `List[str]`) + ```python + self.init_group(ranks=[0, 1, 2, 3]) + self.weight_63 = torch.nn.Parameter(torch.empty((2048, 8192), dtype=torch.float32)) + self.add_full_map('weight_63', 3, (slice(0, 2048, None), slice(0, 8192, None)), 1) + ``` + + including: + -- initialization of model weights, which are class fields; + + - `model_methods_bodies : List[List[str]]` + Definitions of the Python code for forward computations like Segments or Adapters + + Note that codes within this field haven't been organized into valid Python methods, + namely without signatures and return statements, both of which will be extracted + from corresponding IRSegment/IRAdapter in later processes. + E.g. + ``` + [ + # intermediate codes for 'segment123(self, tensor_2222)' + [ + 'tensor_3333 = torch.view(tensor_2222, [1,2,3,4])' + ] + + # intermediate codes for 'adapter456(self, tensor_4444)' + [ + 'tensor_5555 = cube.runtime.adapter.all_reduce(tensor_4444, ranks=[0,1,2,3])' + ] + ] + ``` + """ + + def __init__(self, execplan: ExecutionPlan) -> None: + super().__init__() + self.execplan: ExecutionPlan = execplan + + self.init_code: List[str] = [ + '\n\n########## Generated Model Code ###########', + 'from typing import *', + 'import torch', 'import torch.utils.checkpoint as ckpt', + 'import cube', '', ''] + + if CompileFlag.use_nnfusion: + self.init_code.extend(['import nnfusion', '']) + + # customized op code + for _, op_impl in Sign2Op.kOpCodeDef.items(): + # self.init_code.append('@torch.jit.script') + self.init_code.append(op_impl) + self.init_code += [''] + # module init code + self.model_init_statements: List[str] = list() + # module method bodies for forward computations, e.g. Segments, Adapters. + self.model_methods_bodies: List[List[str]] = list() + # module member name + self.symbols = SymbolTable() + # ref module to check shared variables + self._ref_module = torch.nn.Module() + # batch size + self.batch_size = None + + def gen(self, device: int, outfile=None, attach=False) -> str: + """ + Generate model implementation code based on the given graph. + """ + gencode = copy.copy(self.init_code) + node_args: List[List[str]] = list() + gen_nodes: List[IRCell] = list() + + # init customized adapter + for seg in [seg for seg in self.execplan.seq(device) if isinstance(seg, IRSegment)]: + for adapter in [n for n in seg.nodes() if isinstance(n, IRAdapter)]: + if adapter.isfw() and adapter.differentiable and adapter.custom: + gencode += AutogradAdapterCodeGen().gen(adapter) + ['', ''] + adapter.signature = AutogradAdapterCodeGen.name(adapter) + '.apply' + + # initialize communication groups + self.init_comm_groups() + + # parse graph body + unrolled_seqs = [] + + for node in self.execplan.seq(device): + # unwrap from ExeReuseCell and ExeRepetend + if isinstance(node, ExeReuseCell): + node = node.cell + if isinstance(node, ExeRepetend): + for node in node.nodes(): + if isinstance(node, ExeReuseCell): + node = node.cell + unrolled_seqs.append(node) + else: + unrolled_seqs.append(node) + # we use ordered dict as ordered set + unrolled_seqs = tuple(dict.fromkeys(unrolled_seqs)) + # emit code + for node in unrolled_seqs: + if isinstance(node, IRSegment): + if not node.isfw(): continue # skip backward segment + codes = self.emit_segment(node) + elif isinstance(node, IRFwOperation): + raise RuntimeError(f"Unexcepted global-level op call: {node}") + elif isinstance(node, IRAdapter): + codes = self.emit_adapter(node, prefix_attr='self.') + elif isinstance(node, IRWeightReducer): + self.init_reducer(node) + codes = self.emit_reducer(node) + elif isinstance(node, IRBpOperation): + continue + elif isinstance(node, IRDataOperation): + self.init_batchsize(node) + continue + else: + raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") + + # emit node tensor declaration into `__init__` + # typically it's about the `nn.Parameter` + self.init_attributes(node) + + # emit node code + # codes : List[str] + self.model_methods_bodies.append(codes) + gen_nodes.append(node) + + args = list() + for t in node.inputs(): + if isinstance(t, IRSubTensor): + if not t.is_attr(): + args.append(ModuleCodeGen.tensor_name(t)) + else: + args.append(ModuleCodeGen.tensor_name(t)) + node_args.append(args) + + # generate full code + with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: + with FunctionBlock(func_name='__init__', args=['self']) as ib: + ib.insert_body(self.model_init_statements) + # switch to training or inference mode + if self.execplan.inference: + ib.insert_body('self.eval()') + else: + ib.insert_body('self.train()') + cb.insert_body('') + cb.insert_body(ib.code) + for idx, node in enumerate(gen_nodes): + name = ModuleCodeGen.node_name(node) + input_args = ['self'] + node_args[idx] + forward_code = self.model_methods_bodies[idx] + + with FunctionBlock(func_name=name, args=input_args) as fb: + fb.insert_body(forward_code) + # generate output + outputs = [ModuleCodeGen.tensor_name(t) for t in node.outputs()] + return_code = f"return {', '.join(outputs)}" + fb.insert_body(return_code) + cb.insert_body('') + if CompileFlag.use_nnfusion and name.startswith('segment'): + cb.insert_body('@nnfusion.jit') + if CompileFlag.use_jit and name.startswith('segment'): + cb.insert_body('@torch.jit.script_method') + cb.insert_body(fb.code) + + gencode += cb.code + gencode += [''] + + code = '\n'.join(gencode) + # write to file + if outfile: + with open(outfile, 'a' if attach else 'w') as f: + f.write(code) + + # clear used buffer + self.clear() + return code + + def init_comm_groups(self): + """ + Get all communication groups. + + Creating communication group requires all the devices + enter the same call. + + The fields storing intermediate codes that are populated by this method: + - `model_init_statements` + """ + graph = self.execplan.graph + sign = 'self.init_group(ranks={ranks})' + # collect groups from weight reducer + comm_groups: Dict[Tuple[int]] = list() + for node in graph.nodes(): + if isinstance(node, IRWeightReducer): + ranks = list(node.device) + ranks.sort() + ranks = tuple(ranks) + if ranks not in comm_groups: + comm_groups.append(ranks) + # collect groups from p2p fusion + adapters = [n for n in graph.nodes(flatten=True) if isinstance(n, IRAdapter)] + for adapter in adapters: + for prim in adapter.prims: + if isinstance(prim, CollectivePrim): + ranks = list(prim.kwargs['ranks']) + ranks.sort() + ranks = tuple(ranks) + if ranks not in comm_groups: + comm_groups.append(ranks) + # create communication group + self.model_init_statements.append('# communication groups') + for ranks in comm_groups: + code = sign.format(ranks=list(ranks)) + self.model_init_statements.append(code) + self.model_init_statements.append(' ') + + def init_comm_groups(self): + """ + Get all communication groups. + + Creating communication group requires all the devices + enter the same call. + + The fields storing intermediate codes that are populated by this method: + - `model_init_statements` + """ + graph = self.execplan.graph + sign = 'self.init_group(ranks={ranks})' + # collect groups from weight reducer + comm_groups: Dict[Tuple[int]] = list() + for node in graph.nodes(): + if isinstance(node, IRWeightReducer): + ranks = list(node.device) + ranks.sort() + ranks = tuple(ranks) + if ranks not in comm_groups: + comm_groups.append(ranks) + # collect groups from p2p fusion + adapters = [n for n in graph.nodes(flatten=True) if isinstance(n, IRAdapter)] + for adapter in adapters: + for prim in adapter.prims: + if isinstance(prim, CollectivePrim): + ranks = list(prim.kwargs['ranks']) + ranks.sort() + ranks = tuple(ranks) + if ranks not in comm_groups: + comm_groups.append(ranks) + # create communication group + self.model_init_statements.append('# communication groups') + for ranks in comm_groups: + code = sign.format(ranks=list(ranks)) + self.model_init_statements.append(code) + self.model_init_statements.append(' ') + + def init_attributes(self, node: IRCell): + """ + Emit tensor declaration code + + The fields storing intermediate codes that are populated by this method: + - `model_init_statements` + + This method also populates `self.symbols : SymbolTable` to record + the names of the variables for the tensors ever encountered. + """ + psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" + bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}))" + map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" + if not isinstance(node, IRSegment): + for itensor in node.inputs(): + name = ModuleCodeGen.tensor_name(itensor, prefix_attr='self.') + if isinstance(itensor, IRSubTensor): + if itensor.is_attr() and not self.symbols.exist(name): + self.symbols.create(name) + sign = psign if itensor.is_param() else bsign + code = sign.format( + name=ModuleCodeGen.tensor_name(itensor), + shape=tuple(itensor.shape), + dtype=self.dtype_map(itensor.dtype) + ) + self.model_init_statements.append(code) + tid = itensor.parent.tid + slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) + val_chunks = itensor.valmap[1] + code = map_sign.format( + attr=ModuleCodeGen.tensor_name(itensor), tid=tid, + slicers=str(slicers), val_chunks=val_chunks + ) + self.model_init_statements.append(code) + self.model_init_statements.append('') + if isinstance(itensor, str): + if name.startswith('self.'): + if not hasattr(self._ref_module, name[5:]): + raise NotImplementedError("member attribute is not added") + for output in node.outputs(): + self.symbols.create(ModuleCodeGen.tensor_name(output, prefix_attr='self.')) + else: + for sub_node in node.nodes(): + self.init_attributes(sub_node) + return + + def init_reducer(self, node: IRWeightReducer) -> None: + """ + Emit code to initialize involved reducer objects in `__init__`. + + The fields storing intermediate codes that are populated by this method: + - `model_init_statements` + """ + max_nbytes = CompileFlag.max_reducer_bucket + # reducer init interface + reducer_init = '{reducer} = cube.runtime.adapter.Reducer(ranks={ranks}, max_bucket_size_bytes={max_nbytes})' + reducer_add = 'self.add_reducer({reducer})' + add_param = '{reducer}.add_param({weight})' + # create reducer in declare region + weights = node.inputs() + reducer_name = f'self.wreducer{node._id}' + self.model_init_statements.append('') + init_code = reducer_init.format(reducer=reducer_name, ranks=node.device, max_nbytes=max_nbytes) + self.model_init_statements.append(init_code) + weights = [ModuleCodeGen.tensor_name(t, prefix_attr='self.') for t in weights] + for weight in weights: + add_param_code = add_param.format(reducer=reducer_name, weight=weight) + self.model_init_statements.append(add_param_code) + add_code = reducer_add.format(reducer=reducer_name) + self.model_init_statements.append(add_code) + + def init_batchsize(self, node: IRDataOperation): + """ + Emit batch size declare + """ + signature = 'self.set_batch_size({bs})' + bs = [t.shape[dim] for t, dim in zip(node.outputs(), node.get_batch_dims())] + bs = set(bs) + if len(bs) > 1: + warnings.warn(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') + bs = list(bs)[0] if len(bs) == 1 else None + assert self.batch_size is None or self.batch_size == bs, f"Not match for batch size: {self.batch_size} != {bs}" + self.model_init_statements.append(signature.format(bs=bs)) + self.model_init_statements.append('') + self.batch_size = bs + + @staticmethod + def emit_segment(segment: IRSegment) -> List[str]: + """ + Emit IRSegment code. + + The resultant `List[str]` will be lines of the statements of the final + Python method for the targeted Segment. + The resultant lines will not include the signature and the return statement + of the generated Python method. These lines will be put into `model_methods_bodies` + and the missing Python-syntactic parts will be injected later on. + + e.g. + ``` + [ + # no method signature + 'tensor_3333 = torch.view(tensor_2222, [1,2,3,4])', + 'tensor_2222 = None', # if in dataflow there is no more reference + 'tensor_4444 = torch.sum(tensor_3333)', + 'def recompute(...):', + ' return ...', + 'tensor_5555 = torch.utils.checkpoint(recompute, tensor_4444)', + 'tensor_4444 = None', # if in dataflow there is no more reference + # no return statement + ] + ``` + + Nodes in the segment will group into recompute region + + The fields storing intermediate codes that are populated by this method: + - NONE + """ + nodes : List[IRCell] = segment.nodes() + lifetime = LifeCycle(nodes, segment.inputs(), segment.outputs()) + rc_groups: List[List[IRCell]] = list( + split_when(nodes, lambda prev, curr: prev.recompute != curr.recompute)) + + codes: List[str] = [] + for rc_group in rc_groups: + assert len(rc_group) > 0 + gid: Optional[int] = rc_group[0].recompute + if gid is None: + codes += ModuleCodeGen._emit_nodes(rc_group, lifetime) + else: + # get recompute excution code + rc_segment = segment.create_segment(rc_group) + rc_codes = ModuleCodeGen._emit_recompute(rc_group, + rc_segment.inputs(), rc_segment.outputs(), lifetime) + codes += rc_codes + # release input tensors after exiting a RC group: + last_node = rc_group[-1] + line = lifetime.get_line(last_node) + if last_node != nodes[-1]: # skip if it is the last node + inputs_to_rel = [t for t in rc_segment.inputs() if lifetime.releasable_after_line(t, line)] + if len(inputs_to_rel) > 0: + del_stmt = ModuleCodeGen.emit_release(inputs_to_rel) + codes.append(del_stmt) + + return codes + + @staticmethod + def _emit_nodes(nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: + """ + Emit code to invoke operations and adapter, + e.g. (the lines are split into `List[str]`) + + ``` + tensor_2222 = torch.view(tensor_1111, size=[3,6,9]) + del tensor_1111 # if no more reference + tensor_3333 = cube.runtime.adapter.allgather_reducescatter(tensor_2222, dim=1, rank=[0,1]) + del tensor_2222 # if no more reference + ``` + + The fields storing intermediate codes that are populated by this method: + - NONE + """ + node_codes = [] + for node in nodes: + # execute + if isinstance(node, IRFwOperation): + code = ModuleCodeGen.emit_fnode(node, prefix_attr='self.') + node_codes += code + elif isinstance(node, IRAdapter): + code = ModuleCodeGen.emit_adapter(node) + node_codes += code + else: + raise RuntimeError(f"unexpected type {type(node)} in IRSegment") + # release + tensors_to_del = lifecycle.release_tensors_after_node(node) + if len(tensors_to_del) > 0: + node_codes.append(FuncEmission.emit_release(tensors_to_del)) + + return node_codes + + @staticmethod + def _emit_recompute(nodes: Tuple[IRCell], inputs: List[IRSubTensor], outputs: List[IRSubTensor], + lifecycle: LifeCycle) -> List[str]: + """ + Emit code to define a Python function for Recomputing and invoke it + e.g. (the lines are split into `List[str]`) + + ``` + def recompute(tensor_2222): + tensor_3333 = torch.view(tensor_2222, size=[3,6,9]) + tensor_2222 = None # no more reference + return tensor_3333 + # in the beginning we have `import torch.utils.checkpoint as ckpt` + tensor_4444 = ckpt.checkpoint(recompute, tensor_1111) + ``` + + REMARK: + - In the example above, 'tensor_2222' can be released within the RC subgraph, which also means that + the variable for this tensor can also be released within the enclosing graph, after the 'checkpoint' call. + - The generated RC subgraph will have no "free variables". + All involved tensors that are defined outside of the RC group are made explicit inputs; + All tensors, that are defined within the RC group and are referenced after RC subgraph ends, are made explicit outputs; + And if a within-RC-group tensors are not used anymore, it's not returned. + + The fields storing intermediate codes that are populated by this method: + - NONE + + @return codes List[str] + """ + assert len(nodes) > 0 + + inputs = [t for t in inputs if not t.is_attr()] + input_names = [FuncEmission.tensor_name(t) for t in inputs] + input_names_tuple = ', '.join(input_names) + output_names = [FuncEmission.tensor_name(t) for t in outputs] + output_names_tuple = ', '.join(output_names) + + # 'graph.segment(nodes)' ensures that if a tensor is no longer used (in RC group or in later code), + # it's not included in 'outputs'. + # And we will not generate 'return' statement for it, since it will cause the error + # that the variable is not defined (because it has been 'del'-ed). + + with FunctionBlock('recompute', input_names, False) as fb: + # The nodes to recompute share the same space of line_ids (or "node ids") with non-recomputable nodes. + # e.g. those ids in subgraphs are not 0-based, and incremented after the preceding non-rc nodes and so on. + # + # So within the recomputing subgraph, tensors can be released if they are no longer used + # i.e. not returned by the 'def recompute(...)' + # since 'execplan.graph.segment(nodes)' will make all "free variables" as explicit inputs/outputs + # to that subgraph. + + # for ncode in ModuleCodeGen._emit_nodes(nodes, lifecycle): + # fb.insert_body(ncode) + fb.insert_body(ModuleCodeGen._emit_nodes(nodes, lifecycle)) + fb.insert_body(f'return {output_names_tuple}') + codes = [''] + fb.code + [''] + codes.append( + f'{output_names_tuple} = ckpt.checkpoint(recompute, {input_names_tuple})' + ) + + return codes + + def clear(self): + """ + Clear buffer that used for generating code + """ + # module init code + self.model_init_statements: List[str] = list() + # module forward code + self.model_methods_bodies: List[List[str]] = list() + # module member name + self.symbols = SymbolTable() + # batch size + self.batch_size = None diff --git a/cube/codegen/schedule/__init__.py b/cube/codegen/schedule/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py new file mode 100644 index 00000000..e60f4f01 --- /dev/null +++ b/cube/codegen/schedule/schedule.py @@ -0,0 +1,181 @@ + +from typing import List, Dict, Any, Optional +import copy + +from cube.ir.cten import IRCell, IRTensor +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.ir.tensor import IRSubTensor +from cube.ir.adapter import IRWeightReducer, IRAdapter +from cube.graph.graph import IRSegment + +from cube.graph.schedule import IRScheduleStrategy + +from cube.execplan.execplan import ExecutionPlan, ExeRepetend, ExeReuseCell + +from cube.codegen.emit import FuncEmission +from cube.codegen.syntax.symtable import SymbolTable +from cube.codegen.lifecycle import LifeCycle +from cube.codegen.syntax.blocks import FunctionBlock, ForBlock + + +class ScheduleCodeGen(FuncEmission): + + def __init__(self, execplan: ExecutionPlan): + + self.execplan = execplan + # model full code + self.init_code: List[str] = [ + '\n\n########## Generated Schedule Code ###########', + 'import torch', 'import cube', ''] + # module member name + self.symbols = SymbolTable() + + def gen(self, device: int, outfile=None, attach=None) -> str: + """ + Generate scheduling code on device + """ + gencode = copy.copy(self.init_code) + device_nodes: List[IRCell] = self.execplan.seq(device) + assert all(not isinstance(n, IRFwOperation) for n in device_nodes), \ + "Expected all forward operators have been grouped into IRSegment" + + lifetime = LifeCycle(device_nodes, [], self.execplan.graph.outputs()) + + with FunctionBlock(func_name='_train_step', + args=['model', 'dataloader']) as fb: + fb.insert_body('_ = None') + # body code + if len(device_nodes) == 0: + fb.insert_body('pass') + # legacy hardcode strategy + elif isinstance(self.execplan.graph.sched, IRScheduleStrategy): + code = self.emit_legacy_schedplan(self.execplan.graph.sched, device) + fb.insert_body(code) + else: + for line, node in enumerate(device_nodes): + # execute + if isinstance(node, ExeRepetend): + codes = ScheduleCodeGen.emit_repetend(node) + else: + codes = ScheduleCodeGen.emit_node(node) + fb.insert_body(codes) + # release + tensors = lifetime.release_tensors_after_line(line) + if len(tensors) > 0 : # not necessarily to have one after each line + fb.insert_body(ScheduleCodeGen.emit_release(tensors)) + + # return code + outputs = ScheduleCodeGen.return_name(self.execplan.graph.outputs()) + code = f'return {outputs}' + fb.insert_body(code) + gencode += fb.code + gencode += [''] + + code = '\n'.join(gencode) + # write to file + if outfile: + with open(outfile, 'a' if attach else 'w') as f: + f.write(code) + return code + + @staticmethod + def emit_node(node: IRCell) -> List[str]: + """ + Emit node / subgraph code + """ + fsign = '{outputs} = cube.runtime.executor.fexecute({name}, {model}, *{inputs}, requires_grad={req_grad})' + asign = '{outputs} = cube.runtime.executor.aexecute({model}, *{inputs}, requires_grad={req_grad})' + bsign = '{input_grads} = cube.runtime.executor.backward({name}, {input_tensors}, {output_tensors}, {output_grads})' + + node_inputs, node_outputs = node.inputs(), node.outputs() + req_grad = any(t.requires_grad for t in node.outputs() if isinstance(t, IRTensor)) + + # handle for forward + inputs = ScheduleCodeGen.tuple_name(node_inputs, skip_attr=True, prefix_attr='model.') + outputs = ScheduleCodeGen.return_name(node_outputs, skip_attr=True, prefix_attr='model.') + + unwrap_node = node.cell if isinstance(node, ExeReuseCell) else node + name = ScheduleCodeGen.node_name(unwrap_node) + + if isinstance(unwrap_node, IRSegment): + # emit forward segment + if node.isfw(): + code = fsign.format( + outputs = outputs, + name = f"'{name}'", + model = f'model.{name}', + inputs = inputs, + req_grad = req_grad + ) + else: + # get gradient computation arguments + input_tensors, output_tensors, output_grads, input_grads = \ + ScheduleCodeGen.get_backward_callsite_io_tensors(node) + # special handle for loss + for idx, tensor in enumerate(output_grads): + if isinstance(tensor, IRSubTensor) and tensor.is_loss(): + output_grads[idx] = None + code = bsign.format( + name = f"'{ScheduleCodeGen.node_name(unwrap_node.mirror)}'", + input_grads = ScheduleCodeGen.return_name(input_grads), + input_tensors = ScheduleCodeGen.tuple_name(input_tensors, skip_attr=True, prefix_attr='model.'), + output_tensors = ScheduleCodeGen.tuple_name(output_tensors, skip_attr=True, prefix_attr='model.'), + output_grads = ScheduleCodeGen.tuple_name(output_grads, skip_attr=True, prefix_attr='model.') + ) + + elif isinstance(unwrap_node, IRDataOperation): + code = f'{outputs} = next(dataloader)' + + elif isinstance(unwrap_node, IRAdapter): + code = asign.format( + outputs = outputs, + model = f'model.{name}', + inputs = inputs, + req_grad = req_grad + ) + + elif isinstance(unwrap_node, IRWeightReducer): + code = asign.format( + outputs = outputs, + model=f'model.{name}', + inputs='()', + req_grad=req_grad + ) + + else: + raise RuntimeError(f"Unspported node type: {type(unwrap_node)}") + + return [code] + + @staticmethod + def emit_repetend(repetend: ExeRepetend) -> List[str]: + """ + Emit code for executing a repetend + """ + with ForBlock(var=None, iters=f'range({repetend.repeat})') as fb: + for node in repetend.nodes(): + ncode = ScheduleCodeGen.emit_node(node) + fb.insert_body(ncode) + return fb.code + + @staticmethod + def emit_legacy_schedplan(schedplan: IRScheduleStrategy, devid: int) -> List[str]: + """ + Lagecy code + """ + signature = schedplan.signature + kwargs: Dict[str, Any] = schedplan.kwargs(devid) + strkwargs = dict() + for kwarg, val in kwargs.items(): + if isinstance(val, IRCell): + name = 'model.' + ScheduleCodeGen.node_name(val) + elif isinstance(val, (tuple, list)): + brackets = ')' if len(val) != 1 else ',)' + name = '(' + ', '.join('model.' + ScheduleCodeGen.node_name(n) \ + if isinstance(n, IRCell) else str(n) for n in val) + brackets + else: + name = str(val) + strkwargs[kwarg] = name + code = ', '.join(f'{kwarg}={name}' for kwarg, name in strkwargs.items()) + code = f'{signature}({code})' + return [code] diff --git a/cube/codegen/syntax/blocks.py b/cube/codegen/syntax/blocks.py index bea6d974..d6198d10 100644 --- a/cube/codegen/syntax/blocks.py +++ b/cube/codegen/syntax/blocks.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional class Block: @@ -72,3 +72,12 @@ def __init__(self, class_name, derived=None): derived = f'({derived})' title = f'class {self.class_name}{derived}:' super().__init__(title) + + +class ForBlock(Block): + """ + Create a for-loop block with function definition + """ + def __init__(self, var: Optional[str], iters: str): + var = '_' if var is None else var + super().__init__(f'for {var} in {iters}:') diff --git a/cube/compiler.py b/cube/compiler.py index 297b7aa3..c022b915 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -9,13 +9,14 @@ from cube.graph.graph import IRGraph from cube.ir.operator import IRDataOperation from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.schedule.schedplan import SchedulePlan from cube.graph.function.pyfunc import IRPyFunc from cube.execplan import ExecutionPlan from cube.execplan.planpass.fusion import DiffFusion from cube.execplan.planpass.grouping import Grouping -from cube.codegen.codegen import ModelCodeGen, ScheduleCodeGen +from cube.codegen import ModuleCodeGen, ScheduleCodeGen from cube.profiler.timer import print_each_rank from cube.runtime.device import DeviceGroup @@ -153,6 +154,7 @@ def decorator(fn: Callable) -> Callable: if graph.sched is not None: start = time.time() graph.sched.apply() + # print(graph.sched)qq if CompileFlag.log_schedule: print(graph.sched) span = time.time() - start @@ -160,7 +162,12 @@ def decorator(fn: Callable) -> Callable: # to execution plan start = time.time() - execplan = ExecutionPlan(graph) + if isinstance(graph.sched, SchedulePlan): + execplan = ExecutionPlan.from_schedplan(graph.sched) + else: + execplan = ExecutionPlan.from_graph(graph) + if CompileFlag.visualize_plan: + execplan.visualize('plan.png') span = time.time() - start print('> finish lowering to execution plan: {:.2f} s'.format(span)) @@ -185,7 +192,7 @@ def decorator(fn: Callable) -> Callable: start = time.time() local_world_size = DeviceGroup().local_world_size # code generation - mgener = ModelCodeGen(execplan) + mgener = ModuleCodeGen(execplan) sgener = ScheduleCodeGen(execplan) for local_rank in range(local_world_size): rank = DeviceGroup().node_rank * local_world_size + local_rank diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index c6453222..aed05aa8 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -1,49 +1,240 @@ -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import copy import numpy as np +import sys from cube.ir.cten import IRCell -from cube.ir.adapter import IRAdapter -from cube.ir.operator import IRBpOperation, IRFwOperation +from cube.ir.tensor import IRSubTensor, IRFullTensor +from cube.ir.adapter import IRAdapter, IRWeightReducer +from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.graph.graph import IRGraph, IRSegment +from cube.graph.schedule.schedplan import SchedulePlan, Block, Repetend + + +class ExeReuseCell(IRCell): + """ + A cell that reuses a cell with new inputs and outputs + This is designed for code shrinking of repeatedly executing + same operator-sequences, e.g., different micro-batches + execute on a same code piece. + """ + + def __init__(self, cell: IRCell, + inputs: List[IRSubTensor], outputs: List[IRCell]): + assert len(inputs) == len(cell.inputs()) + assert len(outputs) == len(cell.outputs()), ( + f"output length mismatch: {cell}\n" + f"cell outputs: {cell.outputs()}\noutputs: {outputs}") + super().__init__(cell.name, cell.signature, + len(inputs), len(outputs), init_outputs=False) + for idx, t in enumerate(inputs): + self.set_input(idx, t) + for idx, t in enumerate(outputs): + self.set_output(idx, t) + self._cell: IRCell = cell + self._cached_dispatched: Dict[int, ExeReuseCell] = {} + + @property + def device(self) -> int: + return self._cell.device + + @property + def cell(self) -> IRCell: + return self._cell + + def isfw(self) -> bool: + return self._cell.isfw() + + def dispatch(self, devid: int, _mirror = True): + assert len(self.device) > 0 and devid in self.device, f"Cannot dispatch of ReuseCell {self} to device {devid}" + if devid in self._cached_dispatched: + return self._cached_dispatched[devid] + + inputs = [] + for t, cell_t in zip(self._inputs, self._cell.inputs()): + if isinstance(cell_t, IRSubTensor) and devid not in cell_t.device: + continue + inputs.append(t) + outputs = [] + for t, cell_t in zip(self._outputs, self._cell.outputs()): + if isinstance(cell_t, IRSubTensor) and devid not in cell_t.device: + continue + outputs.append(t) + reuse = ExeReuseCell(self._cell.dispatch(devid), inputs, outputs) + reuse._id = self._id + if _mirror and self.mirror is not None: + mreuse = self.mirror.dispatch(devid, _mirror=False) + IRCell.make_pair(reuse, mreuse) + self._cached_dispatched[devid] = reuse + return reuse + + def __repr__(self) -> str: + return f'ReuseCell-{self.device}(name={self._cell.name}{self._cell.cid}, inputs={self.inputs()}, outputs={self.outputs()})' + + +class ExeRepetend(IRCell): + """ + A cell that will be repeatedly executed for multiple times + on a sequence of nodes + """ + + def __init__(self, nodes: List[IRCell], repeat: int = 1): + super().__init__('repetend', 'None', 0, 0, init_outputs=False) + self._nodes: List[IRCell] = nodes + self._repeat = repeat + + @property + def repeat(self) -> int: + return self._repeat + + @property + def device(self) -> Tuple[int]: + device = set() + for node in self._nodes: + device.update(node.device) + return tuple(device) + + def nodes(self) -> Tuple[IRCell]: + return tuple(self._nodes) + + def isfw(self) -> bool: + return all(n.isfw() for n in self._nodes) + + def dispatch(self, devid: int) -> IRCell: + nodes = [] + for n in self._nodes: + if devid in n.device: + nodes.append(n.dispatch(devid)) + repetend = ExeRepetend(nodes, self.repeat) + repetend._id = self._id + + def add(self, node: IRCell): + """ + Append a node + """ + self._nodes.append(node) + + def pop(self, index: int) -> IRCell: + return self._nodes.pop(index) + + def remove(self, node: IRCell): + return self._nodes.remove(node) + + def __repr__(self) -> str: + dscp = f'Repetend{self.cid}-{self.device}(repeat={self.repeat}\n' + for n in self._nodes: + dscp += ' ' + str(n) + '\n' + dscp += ')' + return dscp class ExecutionPlan: + """ + Execution plan for runtime execution. + Each device will be assigned by its execution sequence + """ + + @staticmethod + def from_graph(graph: IRGraph): + """ + Create execution plan from IRGraph + """ + return ExecutionPlan(graph, graph.nodes()) + + @staticmethod + def from_schedplan(schedplan: SchedulePlan): + """ + Create execution plan from SchedulePlan + """ + micro_ftensors: Dict[int, Dict[IRFullTensor, IRFullTensor]] = {} + def get(tensor: IRSubTensor, micro_idx: int) -> IRSubTensor: + """Get a same-shape tensor for micro-batch index""" + if not isinstance(tensor, IRSubTensor): return tensor + if micro_idx == 0: return tensor + ftensor = micro_ftensors.setdefault(micro_idx, {}).setdefault(tensor.parent, tensor.parent.like()) + t = ftensor.select(tensor.indmap, tensor.valmap) + if tensor.grad is not None: + fgrad: IRFullTensor = ftensor.grad + micro_ftensors.setdefault(micro_idx, {}).setdefault(tensor.parent.grad, fgrad) + t.grad = fgrad.select(tensor.grad.indmap, tensor.grad.valmap) + return t + + micro_fcells: Dict[(int, IRCell), ExeReuseCell] = {} + def block2reuse(node: Block) -> ExeReuseCell: + if node.blk.isfw(): + key = (node.mid, node.blk) + if key in micro_fcells: + return micro_fcells[key] + inputs = [get(t, node.mid) for t in node.blk.inputs()] + outputs = [get(t, node.mid) for t in node.blk.outputs()] + cell = ExeReuseCell(node.blk, inputs, outputs) + if isinstance(node.blk.mirror, IRCell): + minputs = [get(t, node.mid) for t in node.blk.mirror.inputs()] + moutputs = [get(t, node.mid) for t in node.blk.mirror.outputs()] + mcell = ExeReuseCell(node.blk.mirror, minputs, moutputs) + IRCell.make_pair(cell, mcell) + micro_fcells[key] = cell + return cell + else: + mcell = block2reuse(Block(node.blk.mirror, node.mid)) + return mcell.mirror + + topo_seqs: List[IRCell] = [] + for block in schedplan.nodes(): + # convert repetends and blocks + if isinstance(block, Repetend): + nodes: List[ExeReuseCell] = [] + for node in block.nodes(): + if isinstance(node, Block): + node = block2reuse(node) + nodes.append(node) + block = ExeRepetend(nodes, repeat=block.span) + elif isinstance(block, Block): + block = block2reuse(block) + assert isinstance(block, IRCell) + topo_seqs.append(block) + return ExecutionPlan(schedplan.graph, topo_seqs) + + def __init__(self, graph: IRGraph, topo_seqs: List[IRCell]): - def __init__(self, graph: IRGraph): assert isinstance(graph, IRGraph), "Expected an IRGraph" self._graph = graph - self._seq: Dict[int, List[IRCell]] = dict() - self._inference_only = not any( - isinstance(n, IRBpOperation) or \ - (isinstance(n, IRAdapter) and not n.forward) or \ - (isinstance(n, IRSegment) and not n.isfw()) for n in graph.nodes() - ) + self._topo_seqs = topo_seqs + self._seq: Dict[int, List[IRCell]] = {} - # execution sequence for each device - for node in graph.nodes(): - if len(node.device) == 0: - raise RuntimeError(f"Node device not set: {node}") + for node in self._topo_seqs: + assert len(node.device) > 0, f"Node device not set: {node}" for device in node.device: - if device not in self._seq: - self._seq[device] = [] - self._seq[device].append(node) + self._seq.setdefault(device, []).append(node) - # adapter/segment dispatch - for devid in self.devices(): - nodes = [node for node in self.at(devid) if isinstance(node, (IRAdapter, IRSegment))] - while len(nodes) > 0: - # dispatch - fnode = nodes[0] - fidx = self.at(devid).index(fnode) - fnode_dev = fnode.dispatch(devid) - self.at(devid)[fidx] = fnode_dev - nodes.pop(0) - if fnode.mirror is not None: - bidx = self.at(devid).index(fnode.mirror) - nodes.remove(fnode.mirror) - self.at(devid)[bidx] = fnode_dev.mirror - assert fnode_dev.mirror is not None, f"Find None:\n{fnode_dev}" + # due to repetends, a same node could appear multiple times + # in the execution sequence. For this case, all of them + # will be replaced by a same dispatched one. + def cached_dispatch(node: IRCell, devid: int, + dispatched: Dict[IRCell, IRCell]) -> IRCell: + """Cached dispatch""" + if node.isfw() or isinstance(node, IRWeightReducer): + return dispatched.setdefault(node, node.dispatch(devid)) + fnode = node.mirror + assert isinstance(fnode, IRCell), "Expected forward node as mirror" + assert fnode.isfw() + # return dispatched[fnode].mirror + return dispatched.setdefault(fnode, fnode.dispatch(devid)).mirror + + # dispatch for a node that is executed on multiple devices + for devid, nodes in self._seq.items(): + dispatched : Dict[IRCell, IRCell] = {} + for idx in range(len(nodes)): + node = nodes[idx] + # print(f'handling {node}') + if len(node.device) == 1: continue # no need for dispatch + if isinstance(node, ExeRepetend): + rnodes = [cached_dispatch(n, devid, dispatched) \ + for n in node.nodes() if devid in n.device] + dnode = ExeRepetend(rnodes, node.repeat) + else: + dnode = cached_dispatch(node, devid, dispatched) + nodes[idx] = dnode @property def graph(self) -> IRGraph: @@ -51,7 +242,7 @@ def graph(self) -> IRGraph: @property def inference(self) -> bool: - return self._inference_only + return not self._graph.train def devices(self) -> List[int]: """ @@ -102,11 +293,10 @@ def set(self, devid: int, seq: List[IRCell]): raise TypeError("Expected a list of Cell") self._seq[devid] = seq - def visualize(self, - map2time: Optional[Callable] = None, - map2mem: Optional[Callable] = None, - map2name: Optional[Callable] = None, - outfile: Optional[str] = None): + def visualize(self, outfile: str, + map2time: Optional[Callable] = None, + map2mem: Optional[Callable] = None, + map2name: Optional[Callable] = None): """ Visualize the graph @@ -125,18 +315,9 @@ def visualize(self, if map2time is None: def map2time(node): - if isinstance(node, IRSegment): - span = 0 - for node in node.nodes(): - span += map2time(node) - return span - if isinstance(node, IRFwOperation): - return 1 - if isinstance(node, IRBpOperation): - return 2 - if isinstance(node, IRAdapter): - return 0.5 - return 0 + if isinstance(node, IRDataOperation): return 0 + if isinstance(node, IRAdapter): return 0.25 + return 1 if node.isfw() else 2 if map2mem is None: def map2mem(node): @@ -154,32 +335,39 @@ def map2mem(node): if map2name is None: def map2name(node): - if isinstance(node, IRSegment): - if node.isfw(): - return f'f{node.cid}' - elif node.isbw(): - return f'b{node.mirror.cid}' - if isinstance(node, IRFwOperation): - return f'f{node.cid}' - if isinstance(node, IRBpOperation): - return f'b{node.mirror.cid}' if isinstance(node, IRAdapter): - return f'a{node.cid}' - return f'?{node.cid}' + return '' + else: + return f'f{node.cid}' if node.isfw() else f'b{node.cid}' def map2color(node): - if isinstance(node, IRSegment): - return map2color(node.nodes(0)) - if isinstance(node, IRFwOperation): - return '#4472C4' # excel blue - if isinstance(node, IRBpOperation): - return '#ED7D31' # excel orange + node = node.cell if isinstance(node, ExeReuseCell) else node if isinstance(node, IRAdapter): return '#70AD47' # excel green + if node.isfw(): + return '#4472C4' # excel blue + else: + return '#ED7D31' # excel orange - self.graph.reset_dependency() - for node in self.graph.nodes(): - span, mem = map2time(node), map2mem(node) + + # analyze device timeline + + def depends(prev: IRCell, next: IRCell) -> bool: + for to in prev.outputs(): + if not isinstance(to, IRSubTensor): continue + for ti in next.inputs(): + if not isinstance(ti, IRSubTensor): continue + if to.overlap(ti): return True + return False + + device_timeline: List[Tuple[int, int]] = [list() for _ in range(ndevice)] + device_nodes: List[IRCell] = [list() for _ in range(ndevice)] + device_mem = [0] * ndevice + device_peak_mem = [0] * ndevice + + for node in self._topo_seqs: + unwrap_node = node.cell if isinstance(node, ExeReuseCell) else node + span, mem = map2time(unwrap_node), map2mem(unwrap_node) # calculate time start_times = [] for device in node.device: @@ -191,11 +379,10 @@ def map2color(node): # check dependency for devid, timeline in enumerate(device_timeline): dev_seq = device_nodes[devid] - if devid == device: - continue + if devid == device: continue for nid, (_, end_time) in enumerate(timeline[::-1]): other_node = dev_seq[::-1][nid] - if other_node in node.predecessors(): + if depends(other_node, node): start_time = max(start_time, end_time) break start_times.append(start_time) @@ -217,75 +404,74 @@ def map2color(node): # max_mem = sum(device_peak_mem) # draw the timeline - if outfile is not None: - import matplotlib.pyplot as plt - from matplotlib.patches import Rectangle - from matplotlib.ticker import AutoMinorLocator - plt.close('all') - plt.rcParams['figure.figsize'] = (4.0 * max_time // ndevice, 4.0) - fig, ax = plt.subplots() - renderer = fig.canvas.get_renderer() - - # xaxis - ax.set_xlim((1, max_time)) - plt.xticks( - ticks=np.arange(1.5, max_time+0.5, 1.0, dtype=float), - labels=np.arange(1, max_time, 1, dtype=int) - ) - minor_locator = AutoMinorLocator(2) - plt.gca().xaxis.set_minor_locator(minor_locator) - ax.xaxis.grid(which='minor', linestyle='--') - # yaxis - ax.set_ylim((0.5, len(self.devices())+0.5)) - plt.yticks(list(range(1, len(self.devices())+1, 1))) - ax.invert_yaxis() - - fontsize = 40 - txts = list() - for devid in range(ndevice): - timeline = device_timeline[devid] - nodes = device_nodes[devid] - for node, (start, end) in zip(nodes, timeline): - if end - start == 0: - continue - # draw - color = map2color(node) - rec = Rectangle((start, devid + 0.5), end-start, 1, - color=color, ec='black', lw=1.5) - ax.add_artist(rec) - rx, ry = rec.get_xy() - cx = rx + rec.get_width() / 2.0 - cy = ry + rec.get_height() / 2.0 - anno = map2name(node) - txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') - - rbox = rec.get_window_extent(renderer) - for fs in range(fontsize, 1, -2): - txt.set_fontsize(fs) - tbox = txt.get_window_extent(renderer) - if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: - break - fontsize = min(fontsize, fs) - txts.append(txt) + + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + from matplotlib.ticker import AutoMinorLocator + plt.close('all') + plt.rcParams['figure.figsize'] = (4.0 * max_time // ndevice, 4.0) + fig, ax = plt.subplots() + renderer = fig.canvas.get_renderer() + + # xaxis + ax.set_xlim((1, max_time)) + plt.xticks( + ticks=np.arange(1.5, max_time+0.5, 1.0, dtype=float), + labels=np.arange(1, max_time, 1, dtype=int) + ) + minor_locator = AutoMinorLocator(2) + plt.gca().xaxis.set_minor_locator(minor_locator) + ax.xaxis.grid(which='minor', linestyle='--') + # yaxis + ax.set_ylim((0.5, len(self.devices())+0.5)) + plt.yticks(list(range(1, len(self.devices())+1, 1))) + ax.invert_yaxis() + + fontsize = [40] + txts = list() + for devid in range(ndevice): + timeline = device_timeline[devid] + nodes = device_nodes[devid] + for node, (start, end) in zip(nodes, timeline): + unwrap_node = node.cell if isinstance(node, ExeReuseCell) else node + if end - start == 0: + continue + # draw + color = map2color(unwrap_node) + rec = Rectangle((start, devid + 0.5), end-start, 1, + color=color, ec='black', lw=1.5) + ax.add_artist(rec) + rx, ry = rec.get_xy() + cx = rx + rec.get_width() / 2.0 + cy = ry + rec.get_height() / 2.0 + anno = map2name(unwrap_node) + if anno == '': continue + txt = ax.text(x=cx, y=cy, s=anno, fontsize=40, ha='center', va='center', color='w') + rbox = rec.get_window_extent(renderer) + for fs in range(fontsize[0], 1, -2): + txt.set_fontsize(fs) + tbox = txt.get_window_extent(renderer) + if tbox.x0 > rbox.x0 and tbox.x1 < rbox.x1 and tbox.y0 > rbox.y0 and tbox.y1 < rbox.y1: + break + fontsize[0] = min(fontsize[0], fs) + txts.append(txt) - # set font size to same - for txt in txts: - txt.set_fontsize(fontsize) - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(fontsize) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(fontsize) - plt.xlabel('Time Step', fontsize=fontsize) - plt.ylabel('Device ID', fontsize=fontsize) - - # plt.grid() - plt.tight_layout() - plt.savefig(outfile) + # set font size to same + for txt in txts: + txt.set_fontsize(fontsize[0]) + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize[0]) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(fontsize[0]) + plt.xlabel('Time Step', fontsize=fontsize[0]) + plt.ylabel('Device ID', fontsize=fontsize[0]) + plt.tight_layout() + plt.savefig(outfile) return max_time, max_mem def __repr__(self): - dscp = f'Execution Plan ({self.graph.name}):\n' + dscp = f'Execution Plan ({self.graph.name}) (inference: {self.inference}):\n' for devid in self.devices(): dscp += f'====> Device {devid}:\n' for node in self._seq[devid]: diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index f745b645..fefcb17b 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -1,9 +1,10 @@ -from typing import List +from typing import List, Union, Set from cube.graph.graph import IRSegment from cube.ir.adapter import IRAdapter from cube.execplan import ExecutionPlan +from cube.execplan.execplan import ExeRepetend, ExeReuseCell from cube.execplan.planpass.planpass import PlanPass from cube.ir.adapter.prim import IRAdapterPrim @@ -24,20 +25,33 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: """ cnt = 0 for devid in execplan.devices(): + # fadapters: Set[IRAdapter] = set() + visited = set() for node in execplan.seq(devid): - if isinstance(node, IRAdapter) and node.forward: + if isinstance(node, ExeReuseCell): + node = node.cell + if node in visited: + continue + if isinstance(node, IRAdapter) and node.isfw(): ret = DiffFusion.nnfuse(node) cnt = cnt+1 if ret else cnt elif isinstance(node, IRSegment) and node.isfw(): - cnt += DiffFusion._apply(node) + for fadapter in node.select(ntype=IRAdapter): + ret = DiffFusion.nnfuse(fadapter) + cnt = cnt+1 if ret else cnt + elif isinstance(node, ExeRepetend) and node.isfw(): + for fadapter in [n for n in node.nodes() if isinstance(n, IRAdapter)]: + ret = DiffFusion.nnfuse(fadapter) + cnt = cnt+1 if ret else cnt + visited.add(node) print(f'successfully generate {cnt} differentiable adapters') return execplan @staticmethod - def _apply(segment: IRSegment) -> int: + def _apply(cell: Union[IRSegment, ExeRepetend]) -> int: cnt = 0 - for node in segment.nodes(): - if isinstance(node, IRAdapter) and node.forward: + for node in cell.nodes(): + if isinstance(node, IRAdapter) and node.isfw(): ret = DiffFusion.nnfuse(node) # if not ret and not node.differentiable: # raise NotImplementedError( diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 002fc3e3..755011dd 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -59,7 +59,7 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: def differentiable(fnode): if isinstance(fnode, IRFwOperation): return True - if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.forward: + if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.isfw(): return True return False diff --git a/cube/flags.py b/cube/flags.py index 14ae65fb..f3e0a46f 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -19,6 +19,8 @@ class CompileFlag: disable_inter_rvd = os.environ.get('DISABLE_INTER_RVD') disable_comm_fusion = os.environ.get('DISABLE_COMM_FUSION') + visualize_plan = bool(os.environ.get('VISUALIZE_PLAN')) + # ============ code generation =============== use_nnfusion = os.environ.get('USE_NNFUSION') use_jit = os.environ.get('USE_JIT') diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 673bda86..ce7b8ccd 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -692,7 +692,7 @@ def fusion(graph: IRSegment) -> IRSegment: """ fadapters, badapters = [], [] for adapter in graph.nodes(): - if isinstance(adapter, IRAdapter) and adapter.forward and not adapter.differentiable: + if isinstance(adapter, IRAdapter) and adapter.isfw() and not adapter.differentiable: fadapters.append(adapter) if adapter.mirror is not None: badapters.append(adapter.mirror) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index f07c9f96..85e6d15b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -41,7 +41,6 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR self._sched = None # the schedule strategy - @property def train(self) -> bool: """! @@ -49,7 +48,7 @@ def train(self) -> bool: @return train bool: True if backward is required, otherwise False (inference only). """ - return self._have_forward and self._have_backward + return any(not n.isfw() for n in reversed(self._nodes)) # ================ Deep Learning Interfalce ====================== @@ -591,6 +590,15 @@ def predef_sched(self, strategy): """ self._sched = strategy + def _bind_schedule(self, schedplan): + """ + Set schedule plan for the execution + + @param schedplan SchedulePlan + """ + assert self._sched is None, "The graph is already binded with one schedule plan." + self._sched = schedplan + @staticmethod def legal_schedule(seq: List[IRCell], integrity_check=False): """ diff --git a/cube/graph/schedule/predefined.py b/cube/graph/schedule/predefined.py new file mode 100644 index 00000000..d35e7f4b --- /dev/null +++ b/cube/graph/schedule/predefined.py @@ -0,0 +1,112 @@ +""" +Common scheduling descriptions +""" + +from typing import List + +from cube.graph.schedule.schedplan import SchedulePlan +from cube.graph.graph import IRGraph +from cube.graph.segment import IRSegment + + +class PredefinedSched: + + @staticmethod + def sched_1f1b(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: + """ + 1F1B scheduling. The graph should be staged into segments. + + An illustration of scheduling schema (the number is micro-batch index): + ``` + f0 f1 f2 | f3 b0 | b1 b2 b3 + f0 f1 f2 | b0 f3 | b1 b2 b3 + f0 f1 b0 | f2 b1 | f3 b2 b3 + f0 b0 f1 | b1 f2 | b2 f3 b3 + ``` + """ + segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + assert len(fsegs) == num_stages, f"Mismatch of forward segement number ({len(fsegs)}) with num_stages ({len(num_stages)})" + + # describe schedule + sched = SchedulePlan(graph, num_microbatches) + + wait_steps = [sid for sid in range(num_stages)] + bw_ofst = [num_stages - 1 - sid for sid in range(num_stages)] + total_steps = num_microbatches * 2 + (num_stages - 1) * 2 + + for step in range(total_steps): + for sid in range(num_stages): + ofst = wait_steps[sid] + if step < ofst: continue + fw_idx = (step - ofst) // 2 + # forward or backward segment + segment = fsegs[sid] if (step - ofst) % 2 == 0 else fsegs[sid].mirror + mb_idx = fw_idx if (step - ofst) % 2 == 0 else fw_idx - bw_ofst[sid] + # append for execution + if mb_idx < 0 or mb_idx >= num_microbatches: continue + sched.add_segment(segment, mb_idx, step) + sched.finish() + return sched + + @staticmethod + def sched_gpipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: + """ + GPipe scheduling. The graph should be staged into segments. + + An illustration of scheduling schema (the number is micro-batch index): + ``` + f0 f1 f2 f3 b0 b1 b2 b3 + f0 f1 f2 f3 b0 b1 b2 b3 + f0 f1 f2 f3 b0 b1 b2 b3 + f0 f1 f2 f3 b0 b1 b2 b3 + ``` + """ + segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + assert len(fsegs) == num_stages, "Mismatch of forward segement number with num_stages" + # describe schedule + sched = SchedulePlan(graph, num_microbatches) + + fwait_steps = [sid for sid in range(num_stages)] + bwait_steps = [num_stages - 1 - sid for sid in range(num_stages)] + + total_steps = num_microbatches * 2 + (num_stages - 1) * 2 + middle_step = total_steps // 2 + for step in range(total_steps): + for sid in range(num_stages): + segment = fsegs[sid] if step < middle_step else fsegs[sid].mirror + mb_idx = step - fwait_steps[sid] if step < middle_step else step - middle_step - bwait_steps[sid] + if mb_idx < 0 or mb_idx >= num_microbatches: continue + sched.add_segment(segment, mb_idx, step) + sched.finish() + return sched + + @staticmethod + def sched_infer_pipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: + """ + Inference pipeline scheduling. The graph should be staged into segments. + + An illustration of scheduling schema (the number is micro-batch index): + ``` + f0 f1 f2 f3 + f0 f1 f2 f3 + f0 f1 f2 f3 + f0 f1 f2 f3 + ``` + """ + fsegs: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + assert all(seg.isfw() for seg in fsegs), f"Detect backward. The predefined scheduling only applies for inference" + assert len(fsegs) == num_stages, "Mismatch of forward segement number with num_stages" + # describe schedule + sched = SchedulePlan(graph, num_microbatches) + fwait_steps = [sid for sid in range(num_stages)] + total_steps = num_microbatches + num_stages - 1 + for step in range(total_steps): + for sid in range(num_stages): + segment = fsegs[sid] + mb_idx = step - fwait_steps[sid] + if mb_idx < 0 or mb_idx >= num_microbatches: continue + sched.add_segment(segment, mb_idx, step) + sched.finish() + return sched diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py new file mode 100644 index 00000000..45cb7f81 --- /dev/null +++ b/cube/graph/schedule/schedplan.py @@ -0,0 +1,267 @@ +from typing import Dict, List, Union, Callable, Optional, Tuple, Set +import itertools + +from cube.ir.cten import IRCell +from cube.ir.adapter import IRAdapter +from cube.ir.adapter import IRWeightReducer +from cube.ir.operator import IRDataOperation + +from cube.graph.graph import IRGraph +from cube.graph.segment import IRSegment + + + +class Block: + """ + A block is a node in SchedulePlan, representing an IRCell + that is executed with input data of a given micro-batch index. + """ + + def __init__(self, cell: IRCell, micro_batch_id: int) -> None: + """ + """ + assert isinstance(cell, IRCell), f"Expected IRCell, but got {type(cell)}: {cell}" + self._block: IRCell = cell + self._micro_batch_id: int = micro_batch_id + + @property + def device(self) -> Tuple[int]: + return tuple(self._block.device) + + @property + def mid(self) -> int: + return self._micro_batch_id + + @property + def blk(self) -> IRCell: + return self._block + + def dispatch(self, devid: int): + return Block(self._block.dispatch(devid), self._micro_batch_id) + + def __repr__(self) -> str: + return f'Block({self._micro_batch_id})-{self.device} : {self._block}' + + +class PlanBase: + + def __init__(self): + self._step_devs: List[Set[int]] = [] + self._step_segments: List[List[Block]] = [] + # adapters executed after the segments on that step + self._step_adapters: List[List[Block]] = [] + + # topological sequence + self._seqs: List[IRCell] = [] + + @property + def nsteps(self) -> int: + return len(self._step_segments) + + def nodes(self) -> Tuple[Block]: + return tuple(self._seqs) + + def add_segment(self, seg: IRSegment, micro_batch_id: int, step: int) -> Block: + """ + Add a segment `seg` to be executed with `micro-batch-id` data at step `step`. + """ + self._extend_step(step) + assert all(devid not in self._step_devs[step] for devid in seg.device), \ + f"A step cannot execute multiple segments on a same device" + block = Block(seg, micro_batch_id) + self._step_segments[step].append(block) + self._step_devs[step].update(seg.device) + return block + + def _extend_step(self, step: int): + if len(self._step_segments) <= step: + nextend = step - len(self._step_segments) + 1 + self._step_segments += [[] for _ in range(nextend)] + self._step_devs += [set() for _ in range(nextend)] + self._step_adapters += [[] for _ in range(nextend)] + + def topo_sort(self): + self._seqs = [] + for step in range(self.nsteps): + self._seqs += self._step_segments[step] + self._seqs += self._step_adapters[step] + + +class Repetend(PlanBase): + """ + A repetend is a node in SchedulePlan, representing its nodes + will be repeatedly executed by `span` times witn growing + micro-batch index. + """ + + def __init__(self, span: int, + step_nodes: List[List[Block]], + step_adapters: List[List[IRAdapter]], + step_devs: List[Set[int]]): + """ + @param span int: the repeated execution time + """ + super().__init__() + self._span = span + self._step_segments = step_nodes + self._step_adapters = step_adapters + self._step_devs = step_devs + # adapters out of for loop + self._post_adapters: List[IRAdapter] = [] + + @property + def device(self) -> Tuple[int]: + device = set() + for devs in self._step_devs: + device.update(devs) + return tuple(device) + + @property + def span(self) -> int: + return self._span + + def nodes(self) -> Tuple[Block]: + return tuple(self._seqs) + + def __repr__(self): + dscp = f'Repetend-{self.device}(span={self._span}\n' + for blk in self._seqs: + dscp += ' ' + repr(blk) + '\n' + dscp += ')' + return dscp + + +class SchedulePlan(PlanBase): + + def __init__(self, graph: IRGraph, num_microbatches: int): + super().__init__() + self._graph: IRGraph = graph + + # adapter info + self._dataloaders : List[IRDataOperation] = [] + self._segments: List[IRSegment] = [] + self._adapters: List[IRAdapter] = [] + self._recvers: Dict[IRAdapter, IRSegment] = {} + self._senders: Dict[IRAdapter, IRSegment] = {} + self._reducers: List[IRWeightReducer] = [] + # execution sequence + self._device_seqs: Dict[int, Union[Repetend, IRSegment]] = {} + self._num_microbatches = num_microbatches + # bind to the graph + graph._bind_schedule(self) + + @property + def nmicros(self) -> int: + """ + Get number of micro-batches + """ + return self._num_microbatches + + @property + def device(self) -> Tuple[int]: + devs = set() + for node in self._seqs: + devs.update(node.device) + return tuple(devs) + + @property + def graph(self) -> IRGraph: + return self._graph + + def steady_repeat(self, from_step: int, to_step: int, repeat: int): + raise NotImplementedError("Not supported for steady representation") + + def finish(self) -> bool: + """ + Check whether the description contains full micro-batches + """ + pass + + def apply(self): + """ + Insert generated adapters in the emitted sequence. + This can be called by system only after generating adapters. + """ + # step 1: identify connected segements for each generated adapter + self._build_dependency() + # step 2: place adapters, dataloaders + self._place_adapters() + self._place_dataloader() + # step 3: generate topological sequence + self.topo_sort() + + def _build_dependency(self): + """ + Cluster operations and build dependency to identify the connected + segments for each adapter. + """ + # get all dataloaders + self._dataloaders = list(self._graph.select(ntype=IRDataOperation, flatten=False)) + # get all segment + segments: List[IRSegment] = self._graph.select(ntype=IRSegment, flatten=False) + self._segments = segments + # get all adapters + for adapter in self._graph.select(ntype=IRAdapter, flatten=False): + self._adapters.append(adapter) + for segment in segments: + if self._graph.depends(adapter, segment): + assert adapter not in self._recvers, \ + f"Detected more than one segments to recv data from a same adapter" + self._recvers[adapter] = segment + elif self._graph.depends(segment, adapter): + assert adapter not in self._senders, \ + f"Detected more than one segments to send data from a same adapter" + self._senders[adapter] = segment + # get all weight reducers + self._reducers = self._graph.select(ntype=IRWeightReducer, flatten=False) + + def _place_adapters(self, cost_fn: Optional[Callable] = None): + """ + Place adapters to make sure the communication happens + correctly and efficiently. + + @param cost_fn Optional[Callable]: takes a segment and outputs + the execution cost in float.By default (None), this assumes + each segement has the same execution cost of 1.0. + """ + cost_fn = lambda x: 1.0 if cost_fn is None else cost_fn + for adapter in self._adapters: + assert adapter in self._senders + sender: IRSegment = self._senders[adapter] + # find sender step and insert adapter + for step, blocks in enumerate(self._step_segments): + segments = [block.blk for block in blocks] + mids = [block.mid for block in blocks] + if sender in segments: + mid = mids[segments.index(sender)] + self._step_adapters[step].append(Block(adapter, mid)) + + def _place_dataloader(self): + """ + Place dataloaders together with segments + """ + for dl in self._dataloaders: + for step, blocks in enumerate(self._step_segments): + for block in blocks: + segment, mid = block.blk, block.mid + if self.graph.depends(dl, segment): + self._step_segments[step].insert(0, Block(dl, mid)) + break + + def topo_sort(self): + super().topo_sort() + for reducer in self._reducers: + self._seqs.append(reducer) + + def depends(self, prev: Block, next: Block) -> bool: + return prev.mid == next.mid and self._graph.depends(prev.blk, next.blk) + + def __repr__(self) -> str: + dscp = f"SchedulePlan:\n" + for step in range(self.nsteps): + dscp += f'\nStep {step}:\n' + for segment in self._step_segments[step]: + dscp += repr(segment) + '\n' + for adapter in self._step_adapters[step]: + dscp += repr(adapter) + '\n' + return dscp diff --git a/cube/graph/segment.py b/cube/graph/segment.py index aaee2168..1ad34131 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -96,17 +96,13 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR for node in nodes: self.insert(node, self.nnodes) - # self.reset_dependency() + self._dispatch_cached: Dict[int, IRSegment] = {} - # FIXME: update when manipulating - self._have_forward = any(isinstance(n, IRFwOperation) for n in nodes) - self._have_backward = any(isinstance(n, IRBpOperation) for n in nodes) + # self.reset_dependency() def isfw(self) -> bool: - return self._have_forward - - def isbw(self) -> bool: - return self._have_backward + return all(n.isfw() for n in self._nodes) + # return self._have_forward def full_tensors(self) -> Tuple[IRFullTensor]: """ @@ -990,7 +986,7 @@ def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: segment = IRSegment(nodes, order(inputs), order(outputs)) return segment - def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: + def dispatch(self, devid: int, _gen_mirror: bool = True) -> Optional[IRCell]: """ Instantiate the segement to a specific device. @@ -1002,22 +998,19 @@ def dispatch(self, devid: int, mirror=True) -> Optional[IRCell]: return None if len(self.device) == 1 and self.device == [devid]: return self - inputs, outputs, nodes = [], [], [] + if devid in self._dispatch_cached: + return self._dispatch_cached[devid] + # inputs, outputs, nodes = [], [], [] + inputs, outputs, nodes = self.inputs(), self.outputs(), [] for node in self._nodes: if devid in node.device: - if isinstance(node, IRAdapter): - nodes.append(node.dispatch(devid)) - elif isinstance(node, IRSegment): - nodes.append(node.dispatch(devid)) - else: - assert len(node.device) == 1 - nodes.append(node) - for itensor in node.inputs(): - if itensor in self._inputs and itensor not in inputs: - inputs.append(itensor) - for otensor in node.outputs(): - if otensor in self._outputs and otensor not in outputs: - outputs.append(otensor) + nodes.append(node.dispatch(devid)) + # for itensor in node.inputs(): + # if itensor in self._inputs and itensor not in inputs: + # inputs.append(itensor) + # for otensor in node.outputs(): + # if otensor in self._outputs and otensor not in outputs: + # outputs.append(otensor) def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: """Reorder by logical tensor id. Temporally necessary for pipeline scheduling""" @@ -1031,9 +1024,10 @@ def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: segment = IRSegment(nodes, inputs, outputs, self.name) segment._id = self.cid - if mirror and self.mirror is not None: - msegment = self.mirror.dispatch(devid, mirror=False) + if _gen_mirror and self.mirror is not None: + msegment = self.mirror.dispatch(devid, _gen_mirror=False) IRCell.make_pair(segment, msegment) + self._dispatch_cached[devid] = segment return segment diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index 6e593c59..8a57ed1b 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Dict import copy from cube.ir.adapter.prim import IRAdapterPrim, IdentityPrim @@ -39,6 +39,8 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): assert not (is_fw and is_bw), "An IRAdapter cannot serve for both forward and backward stage" self._forward = is_fw + self._cached_dispatch: Dict[int, IRAdapter] = {} + @property def prims(self) -> List[IRAdapterPrim]: return copy.copy(self._prims) @@ -59,6 +61,9 @@ def differentiable(self) -> bool: def differentiable(self, val: bool): self._differentiable = val + def isfw(self) -> bool: + return self._forward + @property def forward(self) -> bool: """ @@ -89,14 +94,17 @@ def recompute(self, group_id: Optional[int]): assert self._recompute == group_id, "The operator is set to recompute in another recompute group." self._recompute = group_id - def dispatch(self, devid: int, for_mirror=True): + def dispatch(self, devid: int, _mirror: bool = True): """ - Get Adapter for a specific rank + Instantiate the adapter to a specific rank. - Returns: - IRAdapter + @param devid int: device id + + @param adapter IRAdapter: the dispatched adapter """ assert isinstance(devid, int), f"Expect devid to be int but got {devid}" + if devid in self._cached_dispatch: + return self._cached_dispatch[devid] prims = [prim.dispatch(devid) for prim in self.prims] prims = [prim for prim in prims if prim is not None] # get inputs @@ -123,9 +131,10 @@ def dispatch(self, devid: int, for_mirror=True): fadapter.custom = self.custom fadapter.recompute = self.recompute # dispatch for mirror - if for_mirror and isinstance(self.mirror, IRAdapter): - badapter = self.mirror.dispatch(devid, for_mirror=False) + if _mirror and isinstance(self.mirror, IRAdapter): + badapter = self.mirror.dispatch(devid, _mirror=False) IRCell.make_pair(fadapter, badapter) + self._cached_dispatch[devid] = fadapter return fadapter @staticmethod @@ -168,6 +177,12 @@ def __init__(self, weights: List[IRSubTensor], name='reducer'): for idx, weight in enumerate(weights): self.set_input(idx, weight) + def isfw(self) -> bool: + return False + + def dispatch(self, device: int): + return self + def __repr__(self): dscp = f'WReducer{self._id}-{self.device}(inputs={self.inputs()})' return dscp diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 6ffd6cf0..0f1cd79e 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -68,7 +68,7 @@ def __init__(self, # source cells. [-1] for control dependency self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length+1)] - self._mirror = None + self._mirror: Optional[IRCell] = None # the comment for code generation self._comment: Optional[str] = None @@ -97,6 +97,26 @@ def device(self, device_id: Union[int, List[int]]): raise KeyError("Require device Union[int, List[int]]") self._device = tuple(device_id) + def dispatch(self, device: int): + """ + Instantiate this node to a specified device. Its mirror node will also + be dispatched and paired with this node. + + For single operators, the mirror node will be reserved. + For nodes that cover multiple devices, e.g., IRSegment and IRAdapter, + the mirror node will be removed and require additional `make_pair` elsewhere. + + @param device int: device id + @return dispatched_node IRCell: the node that only has one device placement. + """ + assert len(self.device) == 1, \ + f"Require dispatch implementation for node type: {type(self)}" + if isinstance(self.mirror, IRCell): + assert len(self.mirror.device) == 1, \ + f"IRCell got unexpected mirro node that has multiple device placement.\n{self.mirror}" + assert device in self.device, f"Fail to dispatch to device {device}. node: {self}" + return self + @property def mirror(self): """ @@ -119,16 +139,12 @@ def make_pair(cell1, cell2): elif cell2 is not None: raise TypeError("Expected cell2 to be IRCell or None") - def on_device(self, device_id: int): + def isfw(self) -> bool: """ - Check whether the operation is on device_id - - Returns: - Boolean + Return if the IRCell is executed fully in forward phase. + This needs to be overrided by derived classes """ - if not isinstance(device_id, int): - raise TypeError(f"Expected device id to be int but got {type(device_id)}") - return device_id in self.device + return True def input(self, index:int): # type: (int) -> Optional[IRTensor] diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 2fcbd0fe..d6602487 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -160,6 +160,9 @@ def __init__(self, ograds: Tuple[Any], igrads: Tuple[Any]): for idx, igrad in enumerate(igrads): self.set_output(idx, igrad) + def isfw(self) -> bool: + return False + def replicate(self): """ Replicate the backward op diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index c5b2f322..0054005a 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -285,6 +285,8 @@ def like(self): @return tensor IRFullTensor: the created tensor """ tensor = IRFullTensor(self.shape, self.name, self.requires_grad, self.dtype) + if self.is_loss(): + tensor.to_loss() return tensor @property @@ -314,12 +316,12 @@ def is_loss(self) -> bool: def to_loss(self): """ - Set this tensor is loss tensor. The tensor shape must be [1,] + Set this tensor as loss tensor. The tensor shape must be [1,] """ assert tuple(self.shape) == (1,), f"Loss tensor can only have shape [1,] but got {self.shape}" - assert self.requires_grad, f"The tensor doesn't require gradient. Cannot backward" self._is_loss = True - self.grad._is_loss = True + if isinstance(self.grad, IRFullTensor): + self.grad._is_loss = True @property def requires_grad(self): diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 6a36a6b0..5c5e0df1 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -1,6 +1,7 @@ r""" Executor for runtime """ +import atexit from typing import Tuple, Any, Callable, List, Dict import torch @@ -41,13 +42,17 @@ def convert_fp32_to_fp16(t: Any): return t -class Executor: +TensorPairs = List[Tuple[int, torch.Tensor]] + - _detach: Dict[str, Dict[torch.Tensor, torch.Tensor]] = dict() +class Executor: - # auto mixture precision loss scaler. $ TODO: support it. - _scaler = torch.cuda.amp.GradScaler(enabled=CompileFlag.use_amp) - + # We consider each segment as an isolated graph. By + # executing the forward of graph, the input tensors will be detached + # from previous graph and saved for backward. + # Each graph has its name, and multiple call for the graph will append + # (instant id -> detached) input tensor pairs for backward reference. + _detach: Dict[str, List[TensorPairs]] = dict() @staticmethod def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): @@ -56,29 +61,20 @@ def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires """ if not requires_grad: with torch.no_grad(): - if CompileFlag.use_amp: - with torch.autocast('cuda', torch.float16): - outputs = subgraph(*input_tensors) - else: - outputs = subgraph(*input_tensors) - else: - # everytime forward a segment, detach the tensor from previous graph - # debug_id(input_tensors, 'outside fexecute args', 0) - assert name not in Executor._detach - Executor._detach[name] = dict() - for itensor in input_tensors: - if torch.is_tensor(itensor) and itensor.requires_grad: - if itensor not in Executor._detach[name]: - Executor._detach[name][itensor] = itensor.detach().requires_grad_() - input_tensors = tuple( - Executor._detach[name][t] if t in Executor._detach[name] else t for t in input_tensors - ) - if CompileFlag.use_amp: - with torch.autocast('cuda', torch.float16): - outputs = subgraph(*input_tensors) - else: outputs = subgraph(*input_tensors) - # print('forwarding... ') + return outputs + + # everytime forward a segment, detach the tensor from previous graph + mapping: Dict[int, torch.Tensor] = dict() + for itensor in input_tensors: + if torch.is_tensor(itensor) and itensor.requires_grad: + mapping[id(itensor)] = itensor.detach().requires_grad_() + input_dtensors = tuple(mapping[id(t)] if id(t) in mapping else t for t in input_tensors) + + saved_pairs = [(id(itensor), dtensor) for itensor, dtensor in zip(input_tensors, input_dtensors)] + Executor._detach.setdefault(name, []).append(saved_pairs) + + outputs = subgraph(*input_dtensors) return outputs @staticmethod @@ -109,40 +105,45 @@ def backward(name: str, """ Backward Procedure. - input_tensors: List[torch.Tensor]: + @param input_tensors List[torch.Tensor] tensors that their gradient need to be computed, including parameters. Correspoinding forward input tensors. - output_tensors: + @param output_tensors List[torch.Tensor] tensors that start for gradient backward computation. Corresponding to forward output tensors. - output_tensor_grads: + @param output_tensor_grads List[torch.Tensor]: gradient tensors corresponding to output_tensors. - Returns: - gradient in order of non-parameter tensors in input_tensors. - (Note parameter tnesors already have gradient accumulated at .grad attribute) + @return gradients List[torch.Tensor]: + gradient tensors corresponding to input_tensors. """ - if len(output_tensors) == 0: - return None + if len(output_tensors) == 0: return None - assert name in Executor._detach, f"forward graph: {name} not run before" - input_tensors = [t for t in input_tensors if torch.is_tensor(t) and not isinstance(t, torch.nn.Parameter)] - input_tensors = [t for t in input_tensors if t.requires_grad] - input_tensors = [Executor._detach[name][t] if t in Executor._detach[name] else t for t in input_tensors] + saved_pairs = Executor._detach[name].pop(0) + tensor_ids: List[int] = [pair[0] for pair in saved_pairs] + dtensors: List[torch.Tensor] = [pair[1] for pair in saved_pairs] for t in input_tensors: - t.retain_grad() + if id(t) not in tensor_ids: + warnings.warn("input doesn't match. Make sure in scheduling that earlier forward perform earlier backward") + + input_tensors = [] + for t in dtensors: + if torch.is_tensor(t) and t.requires_grad: + t.retain_grad() + input_tensors.append(t) + torch.autograd.backward( output_tensors, grad_tensors=output_tensor_grads, ) grads = tuple(t.grad for t in input_tensors) assert all(grad is not None for grad in grads), "RuntimeError: got gradient None" - del Executor._detach[name] + if len(grads) == 0: return None elif len(grads) == 1: return grads[0] - else: return tuple(grads) + else: return grads @staticmethod def clear(): @@ -150,56 +151,17 @@ def clear(): @staticmethod def check_clear(): - assert len(Executor._detach) == 0, \ - f"Find remain not consumed sub-graph: {tuple(Executor._detach.keys())}" + for name, npairs in Executor._detach.items(): + assert len(npairs) == 0, \ + f"Fine remaining segment needs backward: {name}, remaining times: {len(npairs)}" fexecute = Executor.fexecute aexecute = Executor.aexecute backward = Executor.backward - -# def backward(input_tensors : List[torch.Tensor], -# output_tensors: List[torch.Tensor], -# output_tensor_grads: List[torch.Tensor]): -# """ -# Backward Procedure. -# -# input_tensors: List[torch.Tensor]: -# tensors that their gradient need to be computed, including parameters. -# Correspoinding forward input tensors. -# -# output_tensors: -# tensors that start for gradient backward computation. -# Corresponding to forward output tensors. -# -# output_tensor_grads: -# gradient tensors corresponding to output_tensors. -# -# Returns: -# gradient in order of non-parameter tensors in input_tensors. -# (Note parameter tnesors already have gradient accumulated at .grad attribute) -# """ -# if len(input_tensors) == 0: -# return None -# grads = list() -# in_grads = torch.autograd.grad( -# outputs = output_tensors, -# inputs = input_tensors, -# grad_outputs = output_tensor_grads, -# allow_unused=True -# ) -# for tensor, grad in zip(input_tensors, in_grads): -# if isinstance(tensor, torch.nn.Parameter): -# if tensor.grad is not None: -# tensor.grad += grad -# else: -# tensor.grad = grad -# else: -# grads.append(grad) -# if len(grads) == 0: return None -# elif len(grads) == 1: return grads[0] -# else: return tuple(grads) +# register checking for normal exit +atexit.register(Executor.check_clear) ### =================== Experimental Feature ======================= From df8f5a8ac9a71c418d38326c498aa6a00c0dfe51 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 9 Feb 2023 09:59:12 +0800 Subject: [PATCH 1236/1892] fix inter-rvd generation bug --- cube/graph/gener/rvd/inter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cube/graph/gener/rvd/inter.py b/cube/graph/gener/rvd/inter.py index fb0ac4e2..bc3eefcb 100644 --- a/cube/graph/gener/rvd/inter.py +++ b/cube/graph/gener/rvd/inter.py @@ -203,6 +203,8 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD, placement: Optional[Tuple[i else: for srcs, dst in zip(imat.reshape(-1, chunks), omat.flatten()): srcs = srcs.tolist() + if primitive is MovePrim: + srcs = [srcs[0]] prims.append(primitive(srcs, [dst])) return dst_layout, prims From 6606292568f04de8016972c76a9b65d16b4b19cd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 9 Feb 2023 13:32:20 +0800 Subject: [PATCH 1237/1892] fix all to all primitive bug --- cube/graph/gener/rvd/intra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index 8152b263..e26af723 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -66,7 +66,7 @@ def d2d(rvd: TRVD, from_dim: int, to_dim: int, chunks: int) -> Tuple[TRVD, Calla assert rvd[2+from_dim] % chunks == 0, f"not dividable dim: {rvd[2+from_dim]} // {chunks}" rvd = list(rvd) rvd[2+from_dim], rvd[2+to_dim] = rvd[2+from_dim] // chunks, rvd[2+to_dim] * chunks - return rvd, partial(AllToAllPrim, idim=from_dim, odim=from_dim) + return rvd, partial(AllToAllPrim, idim=from_dim, odim=to_dim) @staticmethod def v2r(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: From 05065076390bfccde9c249ba0e2c427283fc8248 Mon Sep 17 00:00:00 2001 From: tntnnlrw Date: Thu, 9 Feb 2023 01:17:31 -0800 Subject: [PATCH 1238/1892] fix the number bug in gpt/model.py --- examples/nlp/gpt/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 1bf54a7c..773a3082 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -30,7 +30,7 @@ def build_gpt_config(name: str) -> Config: elif name == '6.7B': embed_dim, layers, attention_heads = 4096, 32, 32 elif name == '15B': - embed_dim, layers, attention_heads = 5120, 48, 32 + embed_dim, layers, attention_heads = 5120, 48, 40 elif name == '39B': embed_dim, layers, attention_heads = 8192, 48, 64 elif name == '175B': From e150b7581c6a0595ed8fa5550960525020b29743 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 10 Feb 2023 19:25:22 +0800 Subject: [PATCH 1239/1892] refine code and docs on schedule plan --- cube/graph/schedule/schedplan.py | 328 ++++++++++++++++++++++--------- 1 file changed, 236 insertions(+), 92 deletions(-) diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py index 45cb7f81..faab3237 100644 --- a/cube/graph/schedule/schedplan.py +++ b/cube/graph/schedule/schedplan.py @@ -1,5 +1,4 @@ -from typing import Dict, List, Union, Callable, Optional, Tuple, Set -import itertools +from typing import Dict, List, Optional, Tuple, Set from cube.ir.cten import IRCell from cube.ir.adapter import IRAdapter @@ -24,6 +23,14 @@ def __init__(self, cell: IRCell, micro_batch_id: int) -> None: self._block: IRCell = cell self._micro_batch_id: int = micro_batch_id + def __eq__(self, other): + if isinstance(other, Block): + return other.blk == self.blk and other.mid == self.mid + return False + + def __hash__(self) -> int: + return hash((self._block, self._micro_batch_id)) + @property def device(self) -> Tuple[int]: return tuple(self._block.device) @@ -43,14 +50,59 @@ def __repr__(self) -> str: return f'Block({self._micro_batch_id})-{self.device} : {self._block}' +class ScheduleDependency: + + def __init__(self, graph: IRGraph) -> None: + # adapter info + self.graph: IRGraph = graph + self.dataloaders : List[IRDataOperation] = [] + self.segments: List[IRSegment] = [] + self.adapters: List[IRAdapter] = [] + self.recvers: Dict[IRAdapter, IRSegment] = {} + self.senders: Dict[IRAdapter, IRSegment] = {} + self.reducers: List[IRWeightReducer] = [] + + def build(self): + """ + Cluster operations and build dependency to identify the connected + segments for each adapter. + """ + # get all dataloaders + self.dataloaders = list(self.graph.select(ntype=IRDataOperation, flatten=False)) + # get all segment + segments: List[IRSegment] = self.graph.select(ntype=IRSegment, flatten=False) + self.segments = segments + # get all adapters + for adapter in self.graph.select(ntype=IRAdapter, flatten=False): + self.adapters.append(adapter) + for segment in segments: + if self.graph.depends(adapter, segment): + assert adapter not in self.recvers, \ + f"Detected more than one segments to recv data from a same adapter" + self.recvers[adapter] = segment + elif self.graph.depends(segment, adapter): + assert adapter not in self.senders, \ + f"Detected more than one segments to send data from a same adapter" + self.senders[adapter] = segment + # get all weight reducers + self.reducers = self.graph.select(ntype=IRWeightReducer, flatten=False) + + def depend(self, prev: Block, next: Block) -> bool: + return prev.mid == next.mid and self.graph.depends(prev.blk, next.blk) + + class PlanBase: - def __init__(self): + def __init__(self, graph: IRGraph, _dependency: Optional[ScheduleDependency] = None): + self._graph: IRGraph = graph self._step_devs: List[Set[int]] = [] self._step_segments: List[List[Block]] = [] - # adapters executed after the segments on that step + # adapters executed *after* the segments on that step self._step_adapters: List[List[Block]] = [] + self._dependency = _dependency if _dependency is not None \ + else ScheduleDependency(graph) + # topological sequence self._seqs: List[IRCell] = [] @@ -58,6 +110,10 @@ def __init__(self): def nsteps(self) -> int: return len(self._step_segments) + @property + def graph(self) -> IRGraph: + return self._graph + def nodes(self) -> Tuple[Block]: return tuple(self._seqs) @@ -66,6 +122,8 @@ def add_segment(self, seg: IRSegment, micro_batch_id: int, step: int) -> Block: Add a segment `seg` to be executed with `micro-batch-id` data at step `step`. """ self._extend_step(step) + if len(self._step_segments[step]) == 1 and isinstance(self._step_segments[0], PlanBase): + assert False, "Cannot add an IRSegment into a step that already has Repetend." assert all(devid not in self._step_devs[step] for devid in seg.device), \ f"A step cannot execute multiple segments on a same device" block = Block(seg, micro_batch_id) @@ -73,14 +131,52 @@ def add_segment(self, seg: IRSegment, micro_batch_id: int, step: int) -> Block: self._step_devs[step].update(seg.device) return block + def segments(self, step: int) -> Tuple[Block]: + """ + Get segment blocks at step + """ + assert step < self.nsteps + return tuple(self._step_segments[step]) + + def all_segments(self) -> Tuple[Block]: + """ + Get all segment blocks + """ + blocks = [] + for step in range(self.nsteps): + blocks += self._step_segments[step] + return tuple(blocks) + def _extend_step(self, step: int): + """ + Extend the maximize plan with `step`. + """ if len(self._step_segments) <= step: nextend = step - len(self._step_segments) + 1 self._step_segments += [[] for _ in range(nextend)] self._step_devs += [set() for _ in range(nextend)] self._step_adapters += [[] for _ in range(nextend)] + def _place_dataloader(self): + """ + Place dataloaders together with segments + """ + # FIXME: this may not work for multiple segments in a same + # micro-batch require for the data + for dl in self._dependency.dataloaders: + for step, blocks in enumerate(self._step_segments): + for block in blocks: + if isinstance(block, Block): + segment, mid = block.blk, block.mid + if self.graph.depends(dl, segment): + self._step_segments[step].insert(0, Block(dl, mid)) + break + def topo_sort(self): + """ + Sort the step-based execution plan and generates an execution sequence + followed topological order. + """ self._seqs = [] for step in range(self.nsteps): self._seqs += self._step_segments[step] @@ -94,20 +190,25 @@ class Repetend(PlanBase): micro-batch index. """ - def __init__(self, span: int, - step_nodes: List[List[Block]], - step_adapters: List[List[IRAdapter]], - step_devs: List[Set[int]]): + def __init__(self, graph: IRGraph, dependency: ScheduleDependency, + span: int, step_segments: List[List[Block]], ): """ + @param graph IRGraph + @param dependency: ScheduleDependency @param span int: the repeated execution time + @param step_segments List[List[Block]] """ - super().__init__() + super().__init__(graph, dependency) self._span = span - self._step_segments = step_nodes - self._step_adapters = step_adapters - self._step_devs = step_devs - # adapters out of for loop - self._post_adapters: List[IRAdapter] = [] + self._extend_step(len(step_segments)) + self._step_segments = step_segments + for step, blocks in enumerate(step_segments): + devices = set() + for block in blocks: + devices.update(block.device) + self._step_devs[step] = devices + # the adapters that will be performed outside the repetend + self._post_adapters: List[Block] = [] @property def device(self) -> Tuple[int]: @@ -122,30 +223,84 @@ def span(self) -> int: def nodes(self) -> Tuple[Block]: return tuple(self._seqs) + + def apply(self): + self._place_adapters() + self._place_dataloader() + self.topo_sort() + + def _place_adapters(self): + """ + Place adapters + """ + # step1: unrolling repetend for one step + cnts: Dict[IRSegment, int] = {} + for step in range(self.nsteps): + for blk in self.segments(step): + cnts.setdefault(blk.blk, 0) + cnts[blk.blk] += 1 + extended_blocks = [] + for step in range(self.nsteps): + for blk in self.segments(step): + extend_blk = Block(blk.blk, blk.mid + cnts[blk.blk]) + extended_blocks.append(extend_blk) + # step2: generate adapters for each step + all_blocks = self.all_segments() + for adapter, sender in self._dependency.senders.items(): + for step in range(self.nsteps): + for block in self.segments(step): + if block.blk != sender: continue + # sender adapter can be classified into three categories + # 1) its recver are in the same repetend + # 2) its recver are in neighbored repetend + # - we don't allow send and recver in un-neighbored repetend + # 3) its recver are outside the repetend + recver = self._dependency.recvers[adapter] + rblock = Block(recver, block.mid) + ablock = Block(adapter, block.mid) + # case 1) + if rblock in all_blocks: + self._step_adapters[step].append(ablock) + # case 2) + elif rblock in extended_blocks: + self._step_adapters[self.nsteps-1].append(Block(adapter, block.mid - cnts[blk.blk])) + self._post_adapters.append(ablock) + # case 3) + else: + self._post_adapters.append(ablock) + + def get_post_adapters(self) -> List[Block]: + return tuple(self._post_adapters) def __repr__(self): dscp = f'Repetend-{self.device}(span={self._span}\n' - for blk in self._seqs: - dscp += ' ' + repr(blk) + '\n' + for step, blks in enumerate(self._step_segments): + dscp += f'\n Substep {step}:\n' + for blk in blks: + dscp += ' ' + repr(blk) + '\n' dscp += ')' return dscp class SchedulePlan(PlanBase): + """ + A schedule plan leverages the fact no data dependency across different + micro-batches. The schedule plan takes a step-based description to describe + the scheduling of different micro-batch data. - def __init__(self, graph: IRGraph, num_microbatches: int): - super().__init__() - self._graph: IRGraph = graph + The step-based description describes every segment to be executed on which + micro-batch data and executed at which step. The dependency requires segments + inside one micro-batch should follow happen-before relationship: - # adapter info - self._dataloaders : List[IRDataOperation] = [] - self._segments: List[IRSegment] = [] - self._adapters: List[IRAdapter] = [] - self._recvers: Dict[IRAdapter, IRSegment] = {} - self._senders: Dict[IRAdapter, IRSegment] = {} - self._reducers: List[IRWeightReducer] = [] + If segment A depends on segment B, then step of segment A must be smaller + after segment B for a same micro-batch index. + + For each device, only up to one segment can be executed on a step. + """ + + def __init__(self, graph: IRGraph, num_microbatches: int): + super().__init__(graph) # execution sequence - self._device_seqs: Dict[int, Union[Repetend, IRSegment]] = {} self._num_microbatches = num_microbatches # bind to the graph graph._bind_schedule(self) @@ -168,10 +323,33 @@ def device(self) -> Tuple[int]: def graph(self) -> IRGraph: return self._graph - def steady_repeat(self, from_step: int, to_step: int, repeat: int): - raise NotImplementedError("Not supported for steady representation") - - def finish(self) -> bool: + def repeat(self, from_step: int, to_step: int, span: int) -> Repetend: + """ + Create a repetend where the nodes inside the step ranges will + be repeatedly executed by `span` time, with the increasing micro-batch + index. The microbatch index among same segment must be + consecutive. + + Note: calling this will shrink self.nsteps and the blocks begin from + to_step will be shifted to the front of total steps by `to_step - from_step + + @param from_step int: starting (included) step + @param to_step int: stopping (excluded) step + @param span int: repeat time, i.e., number of increasing micro-batch index + + @return repetend Repetend + """ + raise NotImplementedError("repeat is not supported.") + assert 0 < from_step and from_step < self.nsteps + assert 0 < to_step and to_step <= self.nsteps + segment_blocks: List[List[Block]] = self._step_segments[from_step:to_step] + repetend = Repetend(self._graph, self._dependency, span, segment_blocks) + self._step_segments = self._step_segments[:from_step] + [[repetend]] + self._step_segments[to_step:] + self._step_adapters = self._step_adapters[:from_step] + [[]] + self._step_adapters[to_step:] + self._step_devs = self._step_devs[:from_step] + [set(repetend.device)] + self._step_devs[to_step:] + return repetend + + def finish(self): """ Check whether the description contains full micro-batches """ @@ -179,83 +357,49 @@ def finish(self) -> bool: def apply(self): """ - Insert generated adapters in the emitted sequence. - This can be called by system only after generating adapters. + Insert generated adapters, dataloaders and reducers, and generat + an execution sequence in topological order. + This can only be called by system after adapter generation.. """ - # step 1: identify connected segements for each generated adapter - self._build_dependency() - # step 2: place adapters, dataloaders + # step 1: build dependency for scheduling + self._dependency.build() + # step 2: apply repetends + for blocks in self._step_segments: + if len(blocks) == 1 and isinstance(blocks[0], Repetend): + blocks[0].apply() + # step 3: apply this scheduling self._place_adapters() self._place_dataloader() - # step 3: generate topological sequence + # step 4: generate topological sequence self.topo_sort() - def _build_dependency(self): - """ - Cluster operations and build dependency to identify the connected - segments for each adapter. - """ - # get all dataloaders - self._dataloaders = list(self._graph.select(ntype=IRDataOperation, flatten=False)) - # get all segment - segments: List[IRSegment] = self._graph.select(ntype=IRSegment, flatten=False) - self._segments = segments - # get all adapters - for adapter in self._graph.select(ntype=IRAdapter, flatten=False): - self._adapters.append(adapter) - for segment in segments: - if self._graph.depends(adapter, segment): - assert adapter not in self._recvers, \ - f"Detected more than one segments to recv data from a same adapter" - self._recvers[adapter] = segment - elif self._graph.depends(segment, adapter): - assert adapter not in self._senders, \ - f"Detected more than one segments to send data from a same adapter" - self._senders[adapter] = segment - # get all weight reducers - self._reducers = self._graph.select(ntype=IRWeightReducer, flatten=False) - - def _place_adapters(self, cost_fn: Optional[Callable] = None): + def _place_adapters(self): """ Place adapters to make sure the communication happens correctly and efficiently. - - @param cost_fn Optional[Callable]: takes a segment and outputs - the execution cost in float.By default (None), this assumes - each segement has the same execution cost of 1.0. """ - cost_fn = lambda x: 1.0 if cost_fn is None else cost_fn - for adapter in self._adapters: - assert adapter in self._senders - sender: IRSegment = self._senders[adapter] + assert len(self._dependency.adapters) > 0 + for adapter in self._dependency.adapters: + sender: IRSegment = self._dependency.senders[adapter] + print(f'place sender: {sender}') # find sender step and insert adapter for step, blocks in enumerate(self._step_segments): - segments = [block.blk for block in blocks] - mids = [block.mid for block in blocks] - if sender in segments: - mid = mids[segments.index(sender)] - self._step_adapters[step].append(Block(adapter, mid)) - - def _place_dataloader(self): - """ - Place dataloaders together with segments - """ - for dl in self._dataloaders: - for step, blocks in enumerate(self._step_segments): - for block in blocks: - segment, mid = block.blk, block.mid - if self.graph.depends(dl, segment): - self._step_segments[step].insert(0, Block(dl, mid)) - break + if len(blocks) == 0: continue + if len(blocks) == 1 and isinstance(blocks[0], Repetend): + self._step_adapters[step] += list(blocks[0].get_post_adapters()) + else: + assert all(isinstance(blk, Block) for blk in blocks) + segments = [block.blk for block in blocks] + mids = [block.mid for block in blocks] + if sender in segments: + mid = mids[segments.index(sender)] + self._step_adapters[step].append(Block(adapter, mid)) def topo_sort(self): super().topo_sort() - for reducer in self._reducers: + for reducer in self._dependency.reducers: self._seqs.append(reducer) - def depends(self, prev: Block, next: Block) -> bool: - return prev.mid == next.mid and self._graph.depends(prev.blk, next.blk) - def __repr__(self) -> str: dscp = f"SchedulePlan:\n" for step in range(self.nsteps): From 45014ea2902e1193a51673f5ad8ee7e8d03e5a9f Mon Sep 17 00:00:00 2001 From: lynex Date: Fri, 10 Feb 2023 21:14:46 +0800 Subject: [PATCH 1240/1892] 1. fix inference according to Zhiqi's comments 2. init GPT MoE inference example --- examples/nlp/blocks/encoder.py | 6 +- examples/nlp/blocks/mlp_moe.py | 212 ++++++++++++++++++++++++++++++++ examples/nlp/gpt/infer.py | 16 ++- examples/nlp/gpt/model.py | 42 +++++-- examples/nlp/gpt/policy/spmd.py | 4 + 5 files changed, 261 insertions(+), 19 deletions(-) create mode 100644 examples/nlp/blocks/mlp_moe.py diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py index 5c906877..13c3e191 100644 --- a/examples/nlp/blocks/encoder.py +++ b/examples/nlp/blocks/encoder.py @@ -1,6 +1,7 @@ import torch from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention, func_print_shape from examples.nlp.blocks.mlp import MLP +from examples.nlp.blocks.mlp_moe import MoEMLP class EncoderLayer(torch.nn.Module): @@ -37,14 +38,15 @@ class EncoderInferLayer(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, attn_hidden_dim: int, ffn_hidden_dim: int, seqlen: int = -1, batch_size: int = 1, - dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): + dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0, + moe_size: int = 1): super().__init__() self.self_attn_partial = MultiHeadOneAttention( embed_dim, num_heads, attn_hidden_dim, atten_dropout ) self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) + self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) if moe_size == 1 else MoEMLP(embed_dim, ffn_hidden_dim, activation_dropout, moe_size) self.final_layer_norm = torch.nn.LayerNorm(embed_dim) # id-embed + pos-embed diff --git a/examples/nlp/blocks/mlp_moe.py b/examples/nlp/blocks/mlp_moe.py new file mode 100644 index 00000000..bc936cbf --- /dev/null +++ b/examples/nlp/blocks/mlp_moe.py @@ -0,0 +1,212 @@ +import torch +import cube +import torch.distributed +from typing import Tuple + +# (N L) emb; emb * exp_num -> (N L), 1(part_idx) +@cube.graph.parser.register('*, * -> *') +@torch.jit.ignore +def gating_func(x, gate_w) -> torch.Tensor: + # assert top_k == 1 + affinities = torch.matmul(x, gate_w) + # print(f'affinities = {affinities}') + dst_pid_list = torch.argmax(affinities, -1) + # print(f'dst_pid_list = {dst_pid_list}') + return dst_pid_list + +# split tokens into groups by target expert +@cube.graph.parser.register('* -> *') +@torch.jit.ignore +def split_tokens_by_eid(tokens, eids, expert_num): + print(f"tokens = {tokens}, shape {tokens.size()}") + print(f"eids = {eids}, shape {eids.size()}") + reshape_needed = list(tokens.size()) != list(eids.size()) + reshape_feat_dim = list(tokens.size())[-1] + print("##### reshape_feat_dim = " + str(reshape_feat_dim)) + if reshape_needed: + vid_part_extend = torch.unsqueeze(eids, 2).repeat(1, 1, reshape_feat_dim) + print("vid_part_extend = " + str(vid_part_extend)) + else: + vid_part_extend = eids + + token_lists = [] + for exp_id in range(0, expert_num): + print("exp_id = " + str(exp_id)) + mask = (vid_part_extend == exp_id) + print("mask = " + str(mask)) + parted_tokens = torch.masked_select(tokens, mask) + if reshape_needed: + parted_tokens = parted_tokens.reshape(-1, reshape_feat_dim) + print("parted_tokens = " + str(parted_tokens)) + token_lists.append(parted_tokens) + return token_lists + + +@cube.graph.parser.register('* -> *') +@torch.jit.ignore +def samesize_all_gather(tensor: torch.Tensor): + tensor_list = [torch.zeros_like(tensor) for _ in + range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensor_list, tensor) + return torch.stack(tensor_list) + + +@cube.graph.parser.register('* -> *') +@torch.jit.ignore +def nonvarsize_gather(tensor: torch.Tensor, dst): + tensor_list = [torch.zeros_like(tensor) for _ in + range(torch.distributed.get_world_size())] if torch.distributed.get_rank() == dst else None + torch.distributed.gather(tensor, tensor_list, dst) + + return torch.cat(tensor_list) if torch.distributed.get_rank() == dst else None + + +@cube.graph.parser.register('* -> *') +@torch.jit.ignore +def varsize_tensor_gather(tensor: torch.Tensor, dst): + tensor = tensor.contiguous() + # cuda_device = f'cuda:{torch.distributed.get_rank()}' + print(f'tensor.get_device() = {tensor.get_device()}') + size_tens = torch.tensor([tensor.shape[0]], dtype=tensor.dtype, device=f'cuda:{tensor.get_device()}') + print(f'size_tens.get_device() = {size_tens.get_device()}') + size_tens = samesize_all_gather(size_tens) + print(f"size_tens = {size_tens}, tensor.shape[1:] = {tensor.shape[1:]}") + + max_size = size_tens.max().int().item() + padded = torch.empty(max_size, *tensor.shape[1:], dtype=tensor.dtype, device=f'cuda:{tensor.get_device()}') + padded[:tensor.shape[0]] = tensor + + ga = nonvarsize_gather(padded, dst) + print(f" tensor = {tensor}; padded = {padded}; ga = {ga}") + + if torch.distributed.get_rank() != dst: # not this rank as dst + return [] + + slices = [] + for i, sz in enumerate(size_tens): + start_idx = i * max_size + end_idx = start_idx + sz.int().item() + print("start_idx = " + str(start_idx)) + print("end_idx = " + str(end_idx)) + + if end_idx > start_idx: + print("ga[start_idx:end_idx] = " + str(ga[start_idx:end_idx])) + slices.append(ga[start_idx:end_idx]) + # print("slices = " + str(slices)) + else: + slices.append(torch.empty((0, *tensor.shape[1:]), dtype=tensor.dtype, device=f'cuda:{tensor.get_device()}')) + # slices.append(torch.tensor([], dtype=tensor.dtype).resize(0, 3)) + return slices + + +@cube.graph.parser.register('* -> *') +@torch.jit.ignore +def all_to_all_token(input_list): + print(f'***** all_to_all_token.input_list = {input_list}') + data_type = input_list[0].dtype + print(data_type) + ret = [] + for i in range(len(input_list)): + gather_list = varsize_tensor_gather(input_list[i], i) # new replacement + if i == torch.distributed.get_rank(): #TODO check local_rank + ret = gather_list + print(f'***** all_to_all_token.output_list = {ret}') + return ret + + +# N * 1, N * emb -> M * 1, M * emd +@cube.graph.parser.register('*, * -> *') +@torch.jit.ignore +def send_to_experts(dst_pid_list, x, expert_num: int) -> Tuple[torch.Tensor]: + # send to remote and recv from remote + token_lists = split_tokens_by_eid(x, dst_pid_list, expert_num) + print(f'### token_lists = {token_lists}') + local_token_lists = all_to_all_token(token_lists) # exchange idx + print(f'### local_token_lists = {local_token_lists}') + return local_token_lists + + +# M * 1, M * emd -> N * 1, N * emb +@cube.graph.parser.register('*, * -> *') +@torch.jit.ignore +def recv_from_experts(dst_pid_list: torch.Tensor, new_local_token_lists: torch.Tensor, expert_num: int) -> Tuple[torch.Tensor]: + local_token_lists = all_to_all_token(new_local_token_lists) + print(f'### [return] local_token_lists = {local_token_lists}') + + vid_part_np = dst_pid_list.detach().flatten().cpu().tolist() #TODO vid_part_np = dst_pid_list.detach().cpu().numpy() + print("vid_part_np = " + str(vid_part_np)) + # part_count = {} + # for i in range(expert_num): + # part_count[i] = 0 + part_count = [0 for i in range(expert_num)] + print(f'part_count = {part_count}') + + embed_list = [] + for i in range(len(vid_part_np)): + pid = vid_part_np[i] + print(f'pid = {pid}') + offset = part_count[pid] + part_count[pid] += 1 + embed_list.append(local_token_lists[pid][offset]) + + # print("### embed_list = " + str(embed_list)) + embed = torch.stack(embed_list) + print("### final embed = " + str(embed)) + + return embed + + +# @cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward_moe') +@cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+, E^ K-> L^ N E^', name='feedforward_moe') +@torch.jit.ignore +def feedforward_moe(x: torch.Tensor, + proj1: torch.Tensor, proj1_bias: torch.Tensor, + proj2: torch.Tensor, + gate_w: torch.Tensor, + dropout: float, + is_training: bool = True, + expert_num: int = 1) -> torch.Tensor: + #gating + dst_pid_list = gating_func(x, gate_w) + #shuffle tokens + # src_pid_list, x_local + local_token_lists = send_to_experts(dst_pid_list, x, expert_num) + + new_local_token_lists = [] + for x_local in local_token_lists: + #local expert + with torch.no_grad(): + print(f'#### checking ####', x_local, proj1, proj1_bias) + x_local = torch.nn.functional.linear(x_local, proj1, proj1_bias) + x_local = torch.nn.functional.gelu(x_local) + #TODO FIXME x_local = torch.nn.functional.dropout(x_local, dropout, is_training, False) + x_local = torch.nn.functional.linear(x_local, proj2, None) + new_local_token_lists.append(x_local) + + #shuffle back tokens + print(f'### new_local_token_lists = {new_local_token_lists}') + x = recv_from_experts(dst_pid_list, new_local_token_lists, expert_num) + return x + + +class MoEMLP(torch.nn.Module): + def __init__(self, embed_dim: int, hidden_dim: int, dropout: float, expert_num: int = 1): + super().__init__() + # self.proj1 = torch.nn.Parameter(torch.ones((hidden_dim // expert_num, embed_dim))) # TODO fix me empty + # self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim // expert_num,))) + # self.proj2 = torch.nn.Parameter(torch.ones((embed_dim, hidden_dim // expert_num))) # TODO fix me empty + # self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) + # self.gate_w = torch.nn.Parameter(torch.rand((embed_dim, expert_num))) + self.proj1 = torch.nn.Parameter(torch.rand((hidden_dim, embed_dim))) # TODO fix me empty + self.proj1_bias = torch.nn.Parameter(torch.rand((hidden_dim,))) + self.proj2 = torch.nn.Parameter(torch.rand((embed_dim, hidden_dim))) # TODO fix me empty + self.proj2_bias = torch.nn.Parameter(torch.rand((embed_dim,))) + self.gate_w = torch.nn.Parameter(torch.rand((embed_dim, expert_num))) + self.dropout = dropout + self.expert_num = expert_num + + def forward(self, x: torch.Tensor): + x = feedforward_moe(x, self.proj1, self.proj1_bias, + self.proj2, self.gate_w, self.dropout, self.training, self.expert_num) + x = x + self.proj2_bias + return x \ No newline at end of file diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py index 81066599..5e5d49f1 100644 --- a/examples/nlp/gpt/infer.py +++ b/examples/nlp/gpt/infer.py @@ -9,6 +9,8 @@ PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/nlp/gpt/infer.py --policy PASSingle --fp16 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 examples/nlp/gpt/infer.py --policy PASMegatronInferTP --fp16 + +PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 examples/nlp/gpt/infer.py --policy PASDP --fp16 --moe_size 2 """ @@ -28,11 +30,13 @@ import argparse -parser = argparse.ArgumentParser(description='GPT Train') +parser = argparse.ArgumentParser(description='GPT Inference') parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, - help='use fp16 for the training') + help='use fp16 for the inference') parser.add_argument('--local_rank', type=int, default=0) +parser.add_argument('--moe_size', type=int, default=1, + help='number of experts, use MoE for the inference if moe_size > 1') args = parser.parse_args() cube.init() @@ -54,7 +58,9 @@ def inter(): batch_size = 8 - model = GPTInfer(batch_size=batch_size, cfg=build_gpt_config('350M')) + cfg = build_gpt_config('toy') + cfg.moe_size = args.moe_size + model = GPTInfer(batch_size=batch_size, cfg=cfg) model = model if not args.fp16 else model.half() # model = model.cuda() #only for PyTorch run model.eval() @@ -66,11 +72,11 @@ def inter(): def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) - # return loss + return loss model = model.get_gen_module() torch.distributed.barrier() - print_each_rank('model weight consumpition:', rank_only=0) + print_each_rank('model weight consumption:', rank_only=0) memory_summary() CudaTimer(enable=False).warmup() diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 34e50ff5..87ae0239 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -18,10 +18,13 @@ class Config: dropout: float = 0.2 attn_dropout: float = 0.2 activation_dropout: float = 0.2 + moe_size: int = 1 def build_gpt_config(name: str) -> Config: - if name == '350M': + if name == 'toy': + embed_dim, layers, attention_heads = 32, 4, 16 + elif name == '350M': embed_dim, layers, attention_heads = 1024, 24, 16 elif name == '760M': embed_dim, layers, attention_heads = 1536, 24, 16 @@ -90,18 +93,33 @@ class GPTInfer(torch.nn.Module): def __init__(self, batch_size: int = 1, cfg: Config = Config()): super().__init__() # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) - self.embedw = torch.nn.Parameter(torch.rand(cfg.num_embeddings, cfg.embed_dim) / 128) + self.embedw = torch.nn.Parameter(torch.rand(cfg.num_embeddings, cfg.embed_dim)) self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) self.embed_dropout = torch.nn.Dropout() - self.layers = torch.nn.ModuleList( - [EncoderInferLayer( - cfg.embed_dim, cfg.attention_heads, - cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.seqlen, - batch_size, - cfg.dropout, cfg.attn_dropout, cfg.activation_dropout - ) for _ in range(cfg.layers)] - ) + if cfg.moe_size == 1: + self.layers = torch.nn.ModuleList( + [EncoderInferLayer( + cfg.embed_dim, cfg.attention_heads, + cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.seqlen, + batch_size, + cfg.dropout, cfg.attn_dropout, cfg.activation_dropout + ) for _ in range(cfg.layers)] + ) + else: + assert cfg.moe_size > 1 + self.layers = torch.nn.ModuleList() + for layer_id in range(cfg.layers): + self.layers.append( + EncoderInferLayer( + cfg.embed_dim, cfg.attention_heads, + cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.seqlen, + batch_size, + cfg.dropout, cfg.attn_dropout, cfg.activation_dropout, + 1 if (layer_id % 2) == 0 else cfg.moe_size + ) + ) + self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) @@ -127,8 +145,8 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): cube.runtime.function.anchor('last_embed') logits = torch.nn.functional.linear(enc, self.embedw) # simplified - loss = torch.sum(logits) - return loss + # loss = torch.sum(logits) + return logits class GPTDataLoader(cube.runtime.syndata.SynDataLoader): diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index 0f2a1710..04def02c 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -2,6 +2,7 @@ from cube.graph import IRGraph from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.pyfunc import IRPyFunc from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation @@ -64,6 +65,9 @@ def PASDP(graph: IRGraph, resource): # partition forward operators for node in graph.select(ntype=IRFwOperation): + if isinstance(node, IRPyFunc): + graph.assign(node, 0) + continue if len(node.inputs()) == 0: continue #FIXME: a workaround to find batch dimension batch_dim = node.input(0).shape.index(bs) From 5518f3eb870951bb16820d5ca6a2191bbd8b19f4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Feb 2023 18:26:59 +0800 Subject: [PATCH 1241/1892] better interface for graph staging --- cube/graph/graph.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 85e6d15b..3f2e9940 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -647,8 +647,8 @@ def add_schedule(self, nodes: List[IRCell]) -> bool: def staging(self, nodes: Tuple[IRFwOperation]): """! - Group forward / dataloader operators into sequential stages. - The corresponding backward operators will also be grouped into stages + Group forward operators into sequential stages. + The corresponding backward operators (if have) will also be grouped into stages Cross-stage dataflow will be limited to neighbor stages. This should be called before any operator partition. @@ -656,12 +656,6 @@ def staging(self, nodes: Tuple[IRFwOperation]): For example, after staging, user cannot schedule a (transformed) node from one stage to another stage. - The stage is a concept that is only about logical separation of nodes, - it doesn't have additional constraints for device assignment. - - This will keep each tensor to be only consumed once in - semantic representation. - Changes will be made: 1). Identity creation: @@ -684,7 +678,7 @@ def staging(self, nodes: Tuple[IRFwOperation]): stage 5: t5 = identity(t4) xx = consume(t5) - @param starts Tuple[IRFwOperations]: the start node of each stage + @param nodes Tuple[IRFwOperations]: the start forward node of each stage. @return None """ assert all(isinstance(node, IRFwOperation) for node in nodes), \ @@ -693,7 +687,20 @@ def staging(self, nodes: Tuple[IRFwOperation]): f"Exist node is not in graph nodes" starts = tuple(self._nodes.index(node) for node in nodes) assert len(starts) > 0 - starts = (0,) + starts if starts[0] != 0 else starts + + # adjust the start of the first stage to involve beginning operators + for idx in range(starts[0]): + node = self.node(idx) + if isinstance(node, IRDataOperation): + continue + assert isinstance(node, IRFwOperation), \ + f"Expected nodes previous from the first stage are all IRFwOperation, but got {type(node)}" + if node.name == 'multiref' or isinstance(node, IRPyFunc): + pass + else: + warnings.warn(f'Detect a node: {node} that is previous from the first stage. Will be included inside the first stage') + starts[0] = idx + break last_fidx = 0 for idx, node in enumerate(self._nodes): @@ -705,13 +712,12 @@ def staging(self, nodes: Tuple[IRFwOperation]): for sid in range(len(starts)): begin = starts[sid] end = starts[sid+1] if sid != len(starts) - 1 else last_fidx + 1 - while isinstance(self.node(begin), IRDataOperation): - begin += 1 - while end < len(self._nodes) and isinstance(self.node(end), IRDataOperation): - end -= 1 - if begin == end: continue - assert begin < end + if begin >= end: + warnings.warn(f"Detected stage {sid} doesn't have operators: [begin({begin}): end({end})). Skipped") + continue fnodes = self._nodes[begin:end] + assert all(isinstance(node, IRFwOperation) for node in fnodes), \ + f"find at least one nodes are not of IRFwOperation in the stage {sid}. They should be moved to the front" bnodes = [fnode.mirror for fnode in fnodes[::-1] if fnode.mirror is not None] fstages.append(fnodes) bstages = [bnodes] + bstages From 4cca574917183ef2764ac5458c1bb705fe43b67f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 15 Feb 2023 20:35:48 +0800 Subject: [PATCH 1242/1892] fix a bug in intra auto placer --- cube/graph/gener/rvd/intra.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index e26af723..bfb85488 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -534,13 +534,14 @@ def default_cost_fn(prim: IRAdapterPrim) -> int: class IntraAutoPlacer: @staticmethod - def auto_place(graph: IRGraph, ftensor: IRFullTensor, + def auto_place(graph: IRSegment, ftensor: IRFullTensor, producers: List[IRCell], consumers: List[IRCell], cost_fn: Optional[Callable] = None) -> List[int]: """ Automatically find good device placement for consumers given the producer placement The backward will also be considered. + @param graph IRSegment @param ftensor IRFullTensor @param producers List[IRCell]: producers that must be assigned to devices @param consumers List[IRCell]: consumers that are about to be assigned @@ -561,8 +562,7 @@ def auto_place(graph: IRGraph, ftensor: IRFullTensor, warnings.warn('Detected at least one consumer has been assigned to a device, which will be overrided by a new device placement.') if len(producers) == 1: - graph.assign(consumers[0], producers[0].device) - return [producers[0].device] + return [producers[0].device[0]] # reorder producer to match with device order producers = sorted(producers, key=lambda n: n.device[0]) From c5edd1afff742a9bf19cf1a7af4445bb1fb56baa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Feb 2023 11:31:26 +0800 Subject: [PATCH 1243/1892] fix replicate same tensor --- cube/graph/graph.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 3f2e9940..1a1bf9fd 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -9,6 +9,7 @@ from typing import Sequence, Set, Union, Tuple, List, Optional, Dict import warnings +import copy from cube.ir.cten import IRTensor, IRCell from cube.ir.unique import IDGenerator @@ -295,7 +296,10 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis for fnode in fnodes: for rtensor, itensor in zip(fnode.inputs(), node.inputs()): if isinstance(rtensor, IRSubTensor): - rtensor.grad = itensor.grad + rtensor.grad = copy.copy(itensor.grad) + for rtensor, itensor in zip(fnode.outputs(), node.outputs()): + if isinstance(rtensor, IRSubTensor): + rtensor.grad = copy.copy(itensor.grad) # insert forward for fnode in fnodes: if isinstance(node, IRFwOperation): From a83b50e0eaf314d35da61f921ce3ec6e2265952e Mon Sep 17 00:00:00 2001 From: rwlu Date: Thu, 16 Feb 2023 20:02:56 -0800 Subject: [PATCH 1244/1892] change the program interface for autodist --- examples/alphafold2/alphafold2.py | 24 ++++++++++++++-- examples/alphafold2/model.py | 44 +++++++++++++++++++++++++++++- examples/alphafold2/policy/spmd.py | 2 +- 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index cf239d2a..1bab4725 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -13,6 +13,23 @@ from cube.algorithm.ops.dimops import gen_partitions from cube.graph.function.anchor import IRGraphAnchor + + +def build_alphafold_config(setting:int): + assert setting in [1, 2, 3], "setting should be in [1, 2, 3]." + # dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False + if setting == 1: + bs, s, r = 1, 128, 256 + elif setting == 2: + bs, s, r = 1, 512, 256 + elif setting == 3: + bs, s, r = 1, 512, 384 + else: + assert False, f"unrecognized setting {setting}" + + config = Config(bs, s, r) + return config + def run(size_config, other_config, policy): bs, s, r, cm, cz = size_config dtype, evo_num, use_chunk, is_train, is_extra = other_config @@ -77,13 +94,13 @@ def train_iter(model, dataloader): def test_main(): # Training && Evoformer Stack # initial training - bs, s, r, cm, cz = 1, 128, 256, 256, 128 + # bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning - # bs, s, r, cm, cz = 1, 512, 384, 256, 128 + bs, s, r, cm, cz = 1, 512, 384, 256, 128 - dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False + dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 4, False, True, False policy = spmd.PASDAP # Training && Extra Sequence @@ -109,3 +126,4 @@ def test_main(): if __name__ == '__main__': cube.init() test_main() + # build_alphafold_config(1) diff --git a/examples/alphafold2/model.py b/examples/alphafold2/model.py index 30d68981..a2be9bfe 100644 --- a/examples/alphafold2/model.py +++ b/examples/alphafold2/model.py @@ -4,12 +4,24 @@ from torch import nn from examples.alphafold2.module import * +from dataclasses import dataclass """ a simplified version for evoformer in alphafold2 - dropout layers are omitted - masks are omitted """ - +@dataclass +class Config: + bs: int = 1 + s: int = 128 + r: int = 256 + cm: int = 256 + cz: int = 128 + dtype = torch.float16 + evo_num: int = 4 + use_chunk: bool = False + is_train : bool = True + is_extra : bool = False class Evoformer(torch.nn.Module): @@ -258,3 +270,33 @@ def forward(self, msa, pair): cube.runtime.function.anchor('Evoformer Stack End') loss = torch.sum(msa) * torch.sum(pair) return loss + +class AlphaFoldlrw(nn.Module): + + def __init__(self, cfg=Config()): + super().__init__() + self.evo_num = cfg.evo_num + # add norm to work with PyTorch's recompute mechanism + self.msa_norm = torch.nn.LayerNorm(cfg.cm) + self.pair_norm = torch.nn.LayerNorm(cfg.cz) + self.evoformers = torch.nn.ModuleList([ + Evoformer(cfg.s, + cfg.cm, + cfg.cz, + use_chunk=cfg.use_chunk, + is_extra=cfg.is_extra, + is_train=cfg.is_train) for _ in range(cfg.evo_num) + ]) + + def forward(self, msa, pair): + msa = self.msa_norm(msa) + pair = self.pair_norm(pair) + + cube.runtime.function.anchor('Evoformer Stack Start') + for evoformer in self.evoformers: + cube.runtime.function.anchor('One Layer Evoformer Start') + msa, pair = evoformer(msa, pair) + cube.runtime.function.anchor('One Layer Evoformer End') + cube.runtime.function.anchor('Evoformer Stack End') + loss = torch.sum(msa) * torch.sum(pair) + return loss \ No newline at end of file diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 24d07f1c..9e734f7e 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -171,7 +171,7 @@ def PASDAP(graph: IRGraph, resource): if isinstance(fnodes[j], IRGraphAnchor): sub_indices.append(j) sub_indices.append(rhs) - graph.recompute(fnodes[lhs:rhs]) + # graph.recompute(fnodes[lhs:rhs]) for j in range(len(sub_indices) - 1): sub_l, sub_r = sub_indices[j], sub_indices[j + 1] names = [] From 7d7099b166cb0ca6472fb035b0fd13e4f7709fa3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Feb 2023 20:16:46 +0800 Subject: [PATCH 1245/1892] fix tensor.view transformation rules --- cube/graph/function/function.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 9aa3d642..87ab5514 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -688,24 +688,31 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s bracket[subdim] = str(shape_map[edim]) # find out the axis that can be partitioned - ispatial = set() - ifirst = [] + ispatial, ifirst = set(), [] for bracket in in_anno: + sdim = None for hdim in range(len(bracket)): - if bracket[hdim] == '1': - continue - ispatial.add(bracket[hdim]) - ifirst.append(bracket[hdim]) + if bracket[hdim] == '1': continue + sdim = bracket[hdim] break - ospatial = set() - ofirst = [] + if sdim is not None: + ispatial.add(sdim) + ifirst.append(sdim) + + ospatial, ofirst = set(), [] for bracket in ou_anno: + sdim = None for hdim in range(len(bracket)): - if bracket[hdim] == '1': - continue + if bracket[hdim] == '1': continue + sdim = bracket[hdim] ospatial.add(bracket[hdim]) ofirst.append(bracket[hdim]) break + if sdim is not None: + ospatial.add(sdim) + ofirst.append(sdim) + + # intersection for spatial partitioned dimensions spatial = ispatial.intersection(ospatial) # set dimension cannot be partitioned From d86e2b76b60484ae173472666c923c5ed4e0c259 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 19 Feb 2023 06:52:47 +0000 Subject: [PATCH 1246/1892] Merged PR 1448: Refine function code - Remove legacy customized operations - Change repleat, select, slice, select_scatter operations into IRDimops implementation - No more requirement for einops --- cube/algorithm/factory.py | 28 +---- cube/algorithm/ops/conv.py | 87 +++++++++++++++- cube/algorithm/ops/pad.py | 108 -------------------- cube/algorithm/ops/scatter.py | 55 ---------- cube/algorithm/ops/select.py | 103 ------------------- cube/graph/function/conv.py | 50 ++++++++- cube/graph/function/customops.py | 85 ---------------- cube/graph/function/function.py | 152 ++++++++++++++++++---------- cube/graph/function/pad.py | 53 ---------- cube/graph/function/repeat.py | 40 -------- cube/graph/function/scatter.py | 66 ------------ cube/graph/function/scripteinops.py | 53 ---------- cube/graph/function/select.py | 92 ----------------- cube/graph/parser/mapping.py | 8 -- cube/runtime/function/function.py | 45 -------- tests/parser/test_jit_ops.py | 66 ++++++++++++ 16 files changed, 298 insertions(+), 793 deletions(-) delete mode 100644 cube/algorithm/ops/pad.py delete mode 100644 cube/algorithm/ops/scatter.py delete mode 100644 cube/algorithm/ops/select.py delete mode 100644 cube/graph/function/customops.py delete mode 100644 cube/graph/function/pad.py delete mode 100644 cube/graph/function/repeat.py delete mode 100644 cube/graph/function/scatter.py delete mode 100644 cube/graph/function/scripteinops.py delete mode 100644 cube/graph/function/select.py create mode 100644 tests/parser/test_jit_ops.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index afd0142a..02b14e50 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -67,39 +67,13 @@ def _load_predefined_algos(self): self.register(dimops.IRDimops, dimops.SimpleViewSplitEinops, tag='view_simp') import cube.algorithm.ops.conv as conv + self.register(conv.IRPad, conv.DimSplitPad, tag='dim') self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') self.register(conv.IRConv2D, conv.HaloSplitConv2D, tag='halo') self.register(conv.IRConv3D, conv.HaloSplitConv3D, tag='halo') - - import cube.algorithm.ops.pad as pad - self.register(pad.IRPad, pad.DimSplitPad, tag='dim') - - import cube.algorithm.ops.select as select - self.register(select.IRSelect, select.DimSplitSelect, tag='dim') - self.register(select.IRSlice, select.DimSplitSlice, tag='dim') - - import cube.algorithm.ops.scatter as scatter - self.register(scatter.IRSelectScatter, scatter.DimSplitScatter, tag='dim') import cube.algorithm.ops.creators as creators self.register(creators.IRToTensor, creators.DimSplitTo, tag='dim') self.register(creators.IRZeros, creators.DimSplitZeros, tag='dim') self.register(creators.IROnes, creators.DimSplitOnes, tag='dim') self.register(creators.IRRand, creators.DimSplitRand, tag='dim') - # import cube.algorithm.ops.elementwise as elew - # self.register(elew.ElementWise, elew.ElementWiseDimParallel, tag='dim') - # self.register(elew.Add, elew.AddDimParallel, tag='dim') - - # import cube.algorithm.ops.layernorm as ln - # self.register(ln.LayerNorm, ln.LayerNormDimParallel, tag='dim') - - # import cube.algorithm.ops.activation as activation - # self.register(activation.Activation, activation.ActivationDimParallel, tag='dim') - # self.register(activation.Dropout, activation.DropoutDimParallel, tag='dim') - # self.register(activation.Softmax, activation.SoftmaxDimParallel, tag ='dim') - - # import cube.algorithm.ops.reduce as reduce - # self.register(reduce.Sum, reduce.SumDimParallel, tag='dim') - - # import cube.algorithm.ops.memory as mem - # self.register(mem.Transpose, mem.TransposeDimParallel, tag='dim') diff --git a/cube/algorithm/ops/conv.py b/cube/algorithm/ops/conv.py index e782b76c..d04f470b 100644 --- a/cube/algorithm/ops/conv.py +++ b/cube/algorithm/ops/conv.py @@ -4,8 +4,7 @@ from cube.ir.tensor import IRSubTensor from cube.algorithm.generics import GenericDistAlgo -from cube.graph.function.conv import IRConv2D -from cube.graph.function.conv import IRConv3D +from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D def _split_axis_custom(tensor: IRSubTensor, dim: int, chunks: List[Tuple[int, int]]): @@ -27,6 +26,90 @@ def _split_axis_custom(tensor: IRSubTensor, dim: int, chunks: List[Tuple[int, in return sub_tensors +class DimSplitPad(GenericDistAlgo): + """ + split Pad at dimension level + + """ + def __init__(self, node: IRPad): + if not isinstance(node, IRPad): + raise TypeError(f"Expect IRConv2D") + super().__init__(node) + + def satisfy(self, dim: int, num: int): + """ + config = dict(idx=int, dim=int, num=num) + + """ + assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" + node: IRPad = self.node + pad = node.kwargs['pad'] + mode = node.kwargs['mode'] + value = node.kwargs['value'] + assert len(pad) % 2 == 0 + pad_dim_count = len(pad) / 2 + + # split non-pad dim + if dim < len(node.input(0).shape) - pad_dim_count: + return node.input(0).shape[dim] >= num + # return node.input(0).shape[dim] % num == 0 + # split pad dim + else: + dim_in_pad = len(node.input(0).shape) - 1 - dim + return (node.input(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) >= num + # return (node.input(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) % num == 0 + + def instantiate(self, dim: int, num: int): + if not self.satisfy(dim, num): + return None + node: IRPad = self.node + pad = node.kwargs['pad'] + mode = node.kwargs['mode'] + value = node.kwargs['value'] + pad_dim_count = len(pad) / 2 + + inputs = list() + outputs = list() + subnodes = list() + + # split non-pad dim + if dim < len(node.input(0).shape) - pad_dim_count: + inputs = node.input(0).split_dim(dim, num) + outputs = node.output(0).split_dim(dim, num) + for i, o in zip(inputs, outputs): + subnodes.append(node.new([i], [o])) + else: # split pad dim + inputs = node.input(0).split_dim(dim, num) + slicers = list() + pads = list() + dim_in_pad = len(node.input(0).shape) - 1 - dim + global_padl = pad[dim_in_pad * 2] + global_padr = pad[dim_in_pad * 2 + 1] + chunk_size = (node.output(0).shape[dim] - global_padl - global_padr) // num + addone_num = (node.output(0).shape[dim] - global_padl - global_padr) % num + start = 0 + for cid in range(num): + padl = global_padl if cid == 0 else 0 + padr = global_padr if cid == num-1 else 0 + + cur_pad = pad.copy() + cur_pad[dim_in_pad * 2] = padl + cur_pad[dim_in_pad * 2 + 1] = padr + pads.append(cur_pad) + + addone = int(cid < addone_num) + stop = start + padl + padr + chunk_size + addone + slicers.append((max(0, start), min(node.output(0).shape[dim], stop))) + start = stop + + outputs = _split_axis_custom(node.output(0), dim, tuple(slicers)) + + for i, o, p in zip(inputs, outputs, pads): + subnodes.append(node.new([i], [o], pad=p)) + + return subnodes + + class DimSplitConv2D(GenericDistAlgo): """ split Conv2D at dimension level diff --git a/cube/algorithm/ops/pad.py b/cube/algorithm/ops/pad.py deleted file mode 100644 index 0358ae98..00000000 --- a/cube/algorithm/ops/pad.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import List, Tuple -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.function.pad import IRPad -from cube.ir.tensor import IRSubTensor - - -def _split_axis_custom(tensor: IRSubTensor, dim: int, chunks: List[Tuple[int, int]]): - """ - Split tensor along an axis with customized selection - """ - dim = len(tensor.shape) + dim if dim < 0 else dim - assert dim < len(tensor.shape), f"dim should within ndims ({dim} >= {tensor.ndims})" - chunk_num = len(chunks) - indmap = list() - for nele in tensor.shape: - indmap.append((0, nele)) - sub_tensors = list() - for cid in range(chunk_num): - indmap[dim] = chunks[cid] - sub_tensors.append(tensor.select( - indmap=tuple(indmap), valmap=(0,1) - )) - return sub_tensors - - -class DimSplitPad(GenericDistAlgo): - """ - split Pad at dimension level - - """ - def __init__(self, node: IRPad): - if not isinstance(node, IRPad): - raise TypeError(f"Expect IRConv2D") - super().__init__(node) - - def satisfy(self, dim: int, num: int): - """ - config = dict(idx=int, dim=int, num=num) - - """ - assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" - node: IRPad = self.node - pad = node.kwargs['pad'] - mode = node.kwargs['mode'] - value = node.kwargs['value'] - assert len(pad) % 2 == 0 - pad_dim_count = len(pad) / 2 - - # split non-pad dim - if dim < len(node.input(0).shape) - pad_dim_count: - return node.input(0).shape[dim] >= num - # return node.input(0).shape[dim] % num == 0 - # split pad dim - else: - dim_in_pad = len(node.input(0).shape) - 1 - dim - return (node.input(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) >= num - # return (node.input(0).shape[dim] + pad[dim_in_pad * 2] + pad[dim_in_pad * 2 + 1]) % num == 0 - - def instantiate(self, dim: int, num: int): - if not self.satisfy(dim, num): - return None - node: IRPad = self.node - pad = node.kwargs['pad'] - mode = node.kwargs['mode'] - value = node.kwargs['value'] - pad_dim_count = len(pad) / 2 - - inputs = list() - outputs = list() - subnodes = list() - - # split non-pad dim - if dim < len(node.input(0).shape) - pad_dim_count: - inputs = node.input(0).split_dim(dim, num) - outputs = node.output(0).split_dim(dim, num) - for i, o in zip(inputs, outputs): - subnodes.append(node.new([i], [o])) - else: # split pad dim - inputs = node.input(0).split_dim(dim, num) - slicers = list() - pads = list() - dim_in_pad = len(node.input(0).shape) - 1 - dim - global_padl = pad[dim_in_pad * 2] - global_padr = pad[dim_in_pad * 2 + 1] - chunk_size = (node.output(0).shape[dim] - global_padl - global_padr) // num - addone_num = (node.output(0).shape[dim] - global_padl - global_padr) % num - start = 0 - for cid in range(num): - padl = global_padl if cid == 0 else 0 - padr = global_padr if cid == num-1 else 0 - - cur_pad = pad.copy() - cur_pad[dim_in_pad * 2] = padl - cur_pad[dim_in_pad * 2 + 1] = padr - pads.append(cur_pad) - - addone = int(cid < addone_num) - stop = start + padl + padr + chunk_size + addone - slicers.append((max(0, start), min(node.output(0).shape[dim], stop))) - start = stop - - outputs = _split_axis_custom(node.output(0), dim, tuple(slicers)) - - for i, o, p in zip(inputs, outputs, pads): - subnodes.append(node.new([i], [o], pad=p)) - - return subnodes \ No newline at end of file diff --git a/cube/algorithm/ops/scatter.py b/cube/algorithm/ops/scatter.py deleted file mode 100644 index 6b8c467e..00000000 --- a/cube/algorithm/ops/scatter.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import List, Tuple, Optional -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.function.scatter import IRSelectScatter -from cube.ir.tensor import IRSubTensor - - -class DimSplitScatter(GenericDistAlgo): - """ - split Pad at dimension level - - """ - def __init__(self, node: IRSelectScatter): - if not isinstance(node, IRSelectScatter): - raise TypeError(f"Expect IRSelectScatter") - super().__init__(node) - - def satisfy(self, diml: int, dimr: int, num: int): - """ - config = dict(idx=int, dim=int, num=num) - - """ - assert all(isinstance(t, int) for t in [diml, dimr, num]), "dim and num should be integer" - node: IRSelectScatter = self.node - - assert diml != node.kwargs['dim'], "Split dimension should not be equal to scatter dimension" - assert diml < len(node.input(0).shape), "Split dimension should be smaller than tensor dimension" - assert dimr < len(node.output(0).shape), "Split dimension should be smaller than tensor dimension" - assert node.input(0).shape[diml] == node.input(1).shape[dimr], "Two split dimension should at least have equal size" - - return node.input(0).shape[diml] >= num - - def instantiate(self, diml: int, dimr: int, num: int) -> Optional[List[IRSelectScatter]]: - - node: IRSelectScatter = self.node - satisfy = self.satisfy(diml, dimr, num) - if not satisfy: - return None - - assert len(node.inputs()) == 2, "Select_scatter do not has two inputs" - assert len(node.outputs()) == 1, "Select_scatter do not has one outputs" - - ins, ous = list(), list() - ins.append(node.input(0).split_dim(diml, num)) - ins.append(node.input(1).split_dim(dimr, num)) - - ous.append(node.output(0).split_dim(diml, num)) - - sub_nodes = list() - for nid in range(num): - inputs = tuple([t[nid] for t in ins]) - outputs = [t[nid] for t in ous] - sub_nodes.append(node.new(inputs, outputs)) - return sub_nodes - \ No newline at end of file diff --git a/cube/algorithm/ops/select.py b/cube/algorithm/ops/select.py deleted file mode 100644 index 7cf128cc..00000000 --- a/cube/algorithm/ops/select.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import List, Tuple, Optional - -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.function.select import IRSelect, IRSlice -from cube.ir.tensor import IRSubTensor - - -class DimSplitSelect(GenericDistAlgo): - """ - split Pad at dimension level - - """ - def __init__(self, node: IRSelect): - if not isinstance(node, IRSelect): - raise TypeError(f"Expect IRSelect") - super().__init__(node) - - def satisfy(self, dim: int, num: int): - """ - config = dict(idx=int, dim=int, num=num) - - """ - assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" - node: IRSelect = self.node - - assert dim != node.kwargs['dim'], "Split dimension should not be equal to select dimension" - assert dim < len(node.input(0).shape), "Split dimension should be smaller than tensor dimension" - - # split non-pad dim - return node.input(0).shape[dim] >= num - - def instantiate(self, dim: int, num: int) -> Optional[List[IRSelect]]: - - node: IRSelect = self.node - satisfy = self.satisfy(dim, num) - if not satisfy: - return None - - ins, ous = list(), list() - for iidx, itensor in enumerate(node.inputs()): - assert isinstance(itensor, IRSubTensor), "Input of select shoud be IRSubTensor" - ins.append(itensor.split_dim(dim, num)) - - odim = dim - int(node.kwargs['dim'] < dim) - - for oidx, otensor in enumerate(node.outputs()): - assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" - ous.append(otensor.split_dim(odim, num)) - - sub_nodes = list() - for nid in range(num): - inputs = [t[nid] for t in ins] - outputs = [t[nid] for t in ous] - sub_nodes.append(node.new(inputs, outputs)) - return sub_nodes - - -class DimSplitSlice(GenericDistAlgo): - """ - split Pad at dimension level - - """ - def __init__(self, node: IRSlice): - if not isinstance(node, IRSlice): - raise TypeError(f"Expect IRSlice") - super().__init__(node) - - def satisfy(self, dim: int, num: int): - assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" - node: IRSlice = self.node - - if dim == node.kwargs['dim']: - return None - assert dim < len(node.input(0).shape), "Split dimension should be smaller than tensor dimension" - - # split non-pad dim - return node.input(0).shape[dim] >= num - - def instantiate(self, dim: int, num: int) -> Optional[List[IRSlice]]: - - node: IRSlice = self.node - print(dim, node.kwargs['dim']) - satisfy = self.satisfy(dim, num) - if not satisfy: - return None - - ins, ous = list(), list() - for iidx, itensor in enumerate(node.inputs()): - assert isinstance(itensor, IRSubTensor), "Input of select shoud be IRSubTensor" - ins.append(itensor.split_dim(dim, num)) - - for oidx, otensor in enumerate(node.outputs()): - assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" - ous.append(otensor.split_dim(dim, num)) - - sub_nodes = list() - for nid in range(num): - inputs = [t[nid] for t in ins] - outputs = [t[nid] for t in ous] - sub_nodes.append(node.new(inputs, outputs)) - return sub_nodes - \ No newline at end of file diff --git a/cube/graph/function/conv.py b/cube/graph/function/conv.py index 8b2357d1..27ed3dff 100644 --- a/cube/graph/function/conv.py +++ b/cube/graph/function/conv.py @@ -4,6 +4,53 @@ from cube.ir.cten import IRTensor +class IRPad(IRFwOperation): + def __init__(self, signature: str, inputs: List[IRTensor], name: str, + **kwargs): + # torch.nn.functional.pad(input, pad, mode='constant', value=0.0) + # pad: List[int] + signature = 'torch.nn.functional.pad' + assert len(inputs) == 1, "Expected only input, weight, bias as inputs" + assert len(kwargs) == 3, "Expected 2 kwargs: mode, value" + super().__init__(name, signature, 1, 1) + for idx, input in enumerate(inputs): + self.set_input(idx, input) + self.kwargs.update(kwargs) + + def infer_shape(self) -> bool: + """ + Output shape inference given the input shapes + """ + if len(self.input(0).shape) == 0: + return False + + pad = self.kwargs['pad'] + assert len(pad) % 2 == 0, "IRPad::infer_shape len(pad) % 2 == 0" + + shape = self.input(0).shape + for pad_idx, pad_size in enumerate(pad): + shape[-1 - (pad_idx // 2)] += pad_size + + self.output(0).shape = shape + return True + + def new(self, inputs: List, outputs: List, pad = None): + """ + construct a new operator sharing same kwargs with new inputs + and outputs + """ + if pad == None: + pad = self.kwargs['pad'] + mode = self.kwargs['mode'] + value = self.kwargs['value'] + op = IRPad(self.signature, inputs, self.name, + pad=pad, mode=mode, value=value) + assert len(outputs) == 1 + op.set_output(0, outputs[0]) + op.infer_shape() + return op + + class IRConv2D(IRFwOperation): def __init__(self, signature: str, inputs: List[IRTensor], name: str, @@ -53,7 +100,6 @@ def new(self, inputs: List, outputs: List): return op - class IRConv3D(IRFwOperation): def __init__(self, signature: str, inputs: List[IRTensor], name: str, @@ -106,4 +152,4 @@ def new(self, inputs: List, outputs: List): assert len(outputs) == 1 op.set_output(0, outputs[0]) op.infer_shape() - return op \ No newline at end of file + return op diff --git a/cube/graph/function/customops.py b/cube/graph/function/customops.py deleted file mode 100644 index 4ccd206c..00000000 --- a/cube/graph/function/customops.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import List - -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor - -class IRCustomOps(IRFwOperation): - def __init__(self, signature: str, inputs: List[IRTensor], name: str, - **kwargs): - # torch.nn.functional.pad(input, pad, mode='constant', value=0.0) - # pad: List[int] - if signature == 'examples.custom_ops.strip_2_borders': - signature = signature.replace('examples.custom_ops', 'cube.runtime.function')#'cube.runtime.function.strip_2_borders' - assert len(inputs) == 1, "Expected only input, weight, bias as inputs" - assert len(kwargs) == 0, "Expected 0 kwargs: " - super().__init__(name, signature, 1, 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - elif signature == 'examples.custom_ops.update_diag_': - signature = signature.replace('examples.custom_ops', 'cube.runtime.function') - assert len(inputs) == 10, "Expected only input, weight, bias as inputs" - assert len(kwargs) == 1, "Expected 0 kwargs: " - super().__init__(name, signature, len(inputs), 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update(kwargs) - elif signature == 'examples.custom_ops.update_geopotential_': - signature = signature.replace('examples.custom_ops', 'cube.runtime.function') - assert len(inputs) == 5, "Expected only input, weight, bias as inputs" - assert len(kwargs) == 3, "Expected 0 kwargs: " - super().__init__(name, signature, len(inputs), 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update(kwargs) - else: - raise RuntimeError(f'IRCustomOps::__init__ unknown signature: {self.signature}') - - - def infer_shape(self) -> bool: - """ - Output shape inference given the input shapes - """ - if self.signature.endswith('strip_2_borders'): - if len(self.input(0).shape) == 0: - return False - shape = self.input(0).shape - shape[0] = shape[0]-2 - self.output(0).shape = shape - return True - elif self.signature.endswith('update_diag_'): - shape = self.input(0).shape - self.output(0).shape = shape - return True - elif self.signature.endswith('update_geopotential_'): - shape = self.input(0).shape - self.output(0).shape = shape - return True - else: - raise RuntimeError(f'IRCustomOps::infer_shape unknown signature: {self.signature}') - - def new(self, inputs: List, outputs: List): - """ - construct a new operator sharing same kwargs with new inputs - and outputs - """ - if self.signature.endswith('strip_2_borders'): - op = IRCustomOps(self.signature, inputs, self.name,) - assert len(outputs) == 1 - op.set_output(0, outputs[0]) - op.infer_shape() - return op - elif self.signature.endswith('update_diag_'): - op = IRCustomOps(self.signature, inputs, self.name, self.kwargs) - assert len(outputs) == 1 - op.set_output(0, outputs[0]) - op.infer_shape() - return op - elif self.signature.endswith('update_geopotential_'): - op = IRCustomOps(self.signature, inputs, self.name, self.kwargs) - assert len(outputs) == 1 - op.set_output(0, outputs[0]) - op.infer_shape() - return op - else: - raise RuntimeError(f'IRCustomOps::new unknown signature: {self.signature}') - diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 87ab5514..20996851 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -9,14 +9,8 @@ from cube.ir.tensor import IRSubTensor from cube.ir.dtype import IRDType from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule -from cube.graph.function.conv import IRConv2D -from cube.graph.function.conv import IRConv3D -from cube.graph.function.pad import IRPad -from cube.graph.function.scripteinops import IRScriptEinOps +from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D from cube.graph.function.creators import IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor -from cube.graph.function.select import IRSelect, IRSlice -from cube.graph.function.scatter import IRSelectScatter -from cube.graph.function.repeat import IRRepeat from cube.graph.function.anchor import IRGraphAnchor @@ -761,6 +755,46 @@ def Reshape(signature, inputs): return View(signature, inputs) +# def Pad(signature, inputs): +# """ +# torch.nn.functional.pad(input: torch.Tensor, pad: List[int], mode='constant', value=0.0) +# """ +# signature = 'torch.nn.functional.pad' +# tensor, pad, mode, value = inputs +# ianno = ShapeAnno.create_shape_str(tensor.shape) +# oanno = [] +# ndims = len(pad) // 2 +# for dim in range(ndims): +# pad_left, pad_right = pad[2 * dim], pad[2 * dim + 1] +# if pad_left == 0 and pad_right == 0: +# oanno.insert(0, ianno[-1-dim]) +# else: +# ianno[-1-dim] = str(tensor.shape[-1-dim]) +# oanno.insert(0, str(tensor.shape[-1-dim] + pad_left + pad_right)) +# oanno = copy.copy(ianno[:len(tensor.shape) - ndims]) + oanno +# anno = OpAnno.create_op_str([ianno], [oanno]) +# return IRDimops(Pad, 'pad', signature, [anno], [tensor], pad=pad, mode=mode, value=value) + + +def Pad(signature, inputs): + """ + torch.nn.functional.pad(input, pad, mode='constant', value=0.0) + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad + :param signature: + :param inputs: + :return: + """ + # print("#Pad::inputs.len: {}".format(len(inputs))) + # idx = 0 + # for input in inputs: + # if idx >= 0: + # print("#Pad::input[{}]: {}".format(idx, input)) + # idx += 1 + tensors = inputs[0:1] + pad, mode, value = inputs[1:] + return IRPad(signature, tensors, 'pad', pad=pad, mode=mode, value=value) + + # def Conv2D(signature, inputs): # """ # torch.conv2d(input, weight, bias, stride, padding, dialation, groups) @@ -823,24 +857,6 @@ def Conv3D(signature, inputs): return IRConv3D(signature, tensors, 'conv3d', stride=stride, padding=padding, dilation=dilation, groups=groups) -def Pad(signature, inputs): - """ - torch.nn.functional.pad(input, pad, mode='constant', value=0.0) - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad - :param signature: - :param inputs: - :return: - """ - # print("#Pad::inputs.len: {}".format(len(inputs))) - # idx = 0 - # for input in inputs: - # if idx >= 0: - # print("#Pad::input[{}]: {}".format(idx, input)) - # idx += 1 - tensors = inputs[0:1] - pad, mode, value = inputs[1:] - return IRPad(signature, tensors, 'pad', pad=pad, mode=mode, value=value) - def Accum(signature, inputs: Tuple[IRTensor]): """ @@ -928,33 +944,77 @@ def Select(signature, inputs: Tuple[IRTensor, int, int]): oanno.pop(dim) anno = OpAnno.create_op_str([ianno], [oanno]) return IRDimops(Select, 'select', signature, [anno], [tensor], dim=dim, index=index) - # return IRSelect(signature, [tensor], 'select', dim, index) -def Slice(signature, inputs: Tuple[IRTensor, int, Optional[int], Optional[int], int]): + +def Slice(signature, inputs): """ aten::slice(input:Tensor, dim:int, start:Optional[int], end:Optional[int], step:int) -> Tensor """ + signature = 'torch.ops.aten.slice' tensor, dim, start, end, step = inputs - return IRSlice(signature, [tensor], 'slice', dim, start, end, step) + ianno = ShapeAnno.create_shape_str(tensor.shape) + oanno = copy.copy(ianno) + ianno[dim] = str(tensor.shape[dim]) + + def clip(ofst): + ofst = ofst + tensor.shape[dim] if ofst < 0 else ofst + return min(tensor.shape[dim], max(0, ofst)) -def SelectScatter(signature, inputs:Tuple[IRTensor, IRTensor, int, int]): + # set start and end to possitive itegers + start = 0 if start is None else start + end = tensor.shape[dim] if end is None else end + start, end = clip(start), clip(end) + + oanno[dim] = str(len(range(start, end, step))) + anno = OpAnno.create_op_str([ianno], [oanno]) + return IRDimops(Slice, 'slice', signature, [anno], [tensor], dim=dim, start=start, end=end, step=step) + + +def SelectScatter(signature, inputs: Tuple[IRTensor, IRTensor, int, int]): """ torch.select_scatter(self:Tensor, input:Tensor, dim:int, index:int) -> Tensor """ + # 'torch.select_scatter' isn't supported by Torch2ONNX yet. + signature = 'cube.runtime.function.select_scatter' self, input, dim, index = inputs - return IRSelectScatter(signature, [self, input], 'scatter_select', dim, index) - - -def Repeat(signature, inputs:Tuple[IRTensor, List[int]]): + # shape check + self_shape, input_shape = self.shape, input.shape + self_shape.pop(dim) + assert tuple(self_shape) == tuple(input_shape) + in1_anno = ShapeAnno.create_shape_str(self.shape) + in2_anno = in1_anno.copy() + in2_anno.pop(dim) + in1_anno[dim] = str(self.shape[dim]) + out_anno = in1_anno.copy() + anno = OpAnno.create_op_str([in1_anno, in2_anno], [out_anno]) + return IRDimops(SelectScatter, 'select_scatter', signature, + [anno], [self, input], dim=dim, index=index) + + +def Repeat(signature, inputs: Tuple[IRTensor, List[int]]): """ torch.repeat(tensor:Tensor, repeats: List[int]) -> Tensor """ + signature = 'torch.ops.aten.repeat' tensor, repeats = inputs - - assert signature == 'torch.repeat' # this is the API in TorchScript - signature = 'torch.Tensor.repeat' # this is the API in Python frontend and is not a Tensor member method - - return IRRepeat(signature, [tensor], 'repeat', repeats) + in_shape = tensor.shape + assert len(in_shape) <= len(repeats), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor" + expand = len(repeats) - len(tensor.shape) + in_shape += [1] * expand + ou_shape = [dimlen * repeat for dimlen, repeat in zip(in_shape, repeats)] + ianno, oanno = ShapeAnno.create_shape_str(in_shape), [] + for dim, dimlen in enumerate(ou_shape): + if dim < expand: + oanno.append(str(dimlen)) + else: + if repeats[dim] != 1: + ianno[dim] += '^' + dim_anno = [str(repeats[dim]), ianno[dim]] + else: + dim_anno = ianno[dim] + oanno.append(dim_anno) + anno = OpAnno.create_op_str([ianno[expand:]], [oanno]) + return IRDimops(Repeat, 'repeat', signature, [anno], [tensor], repeats=repeats) def Embedding(signature, inputs: List): @@ -1049,22 +1109,6 @@ def GraphAnchor(signature, inputs: List[IRSubTensor]): return node -def ScriptEinOps(signature, inputs): - """ - apply_for_scriptable_torch(recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str) -> torch.Tensor: - https://github.com/arogozhnikov/einops/blob/master/einops/_torch_specific.py - :param signature: - :param inputs: - :return: - """ - recipe = inputs[0] - tensors = inputs[1:2] - reduction_type = inputs[2] - import pickle - recipe_str = pickle.dumps(recipe) - return IRScriptEinOps(signature, tensors, 'scripteinops', recipe_str=recipe_str, reduction_type=reduction_type) - - def _comparison(creator: Callable, f: Callable, name: str, signature: str, inputs): """ if both operands are scalars, returns bool. diff --git a/cube/graph/function/pad.py b/cube/graph/function/pad.py deleted file mode 100644 index ee50cb8a..00000000 --- a/cube/graph/function/pad.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import List - -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor - -class IRPad(IRFwOperation): - def __init__(self, signature: str, inputs: List[IRTensor], name: str, - **kwargs): - # torch.nn.functional.pad(input, pad, mode='constant', value=0.0) - # pad: List[int] - signature = 'torch.nn.functional.pad' - assert len(inputs) == 1, "Expected only input, weight, bias as inputs" - assert len(kwargs) == 3, "Expected 2 kwargs: mode, value" - super().__init__(name, signature, 1, 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update(kwargs) - - def infer_shape(self) -> bool: - """ - Output shape inference given the input shapes - """ - if len(self.input(0).shape) == 0: - return False - - N = self.input(0).shape[0] - pad = self.kwargs['pad'] - mode = self.kwargs['mode'] - value = self.kwargs['value'] - assert len(pad) % 2 == 0, "IRPad::infer_shape len(pad) % 2 == 0" - - shape = self.input(0).shape - for pad_idx, pad_size in enumerate(pad): - shape[-1 - (pad_idx // 2)] += pad_size - - self.output(0).shape = shape - return True - - def new(self, inputs: List, outputs: List, pad = None): - """ - construct a new operator sharing same kwargs with new inputs - and outputs - """ - if pad == None: - pad = self.kwargs['pad'] - mode = self.kwargs['mode'] - value = self.kwargs['value'] - op = IRPad(self.signature, inputs, self.name, - pad=pad, mode=mode, value=value) - assert len(outputs) == 1 - op.set_output(0, outputs[0]) - op.infer_shape() - return op diff --git a/cube/graph/function/repeat.py b/cube/graph/function/repeat.py deleted file mode 100644 index ef0dbf18..00000000 --- a/cube/graph/function/repeat.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import List, Optional, Tuple -import itertools - -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor - -class IRRepeat(IRFwOperation): - """ - torch.repeat(tensor:Tensor, repeats: List[int]) -> Tensor - """ - - def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, repeats:List[int]): - assert len(inputs) == 1 - assert isinstance(repeats, list) - assert all(isinstance(r, int) for r in repeats) - - super().__init__(name, signature, 1, 1) - self.set_input(0, inputs[0]) - self.kwargs.update({"repeats": repeats}) - - def infer_shape(self) -> bool: - shp_self : List[int] = self.input(0).shape - if len(shp_self) == 0: - return False - - repeats : List[int] = self.kwargs["repeats"] - - # This API broadcasts the input tensor if the specified `repeats:list` is longer than the shape. - s1 = shp_self.copy() - s1.reverse() - s2 = repeats.copy() - s2.reverse() - - # Multiply from the end - shp = [d1 * d2 for d1, d2 in itertools.zip_longest(s1, s2, fillvalue=1)] - shp.reverse() - - self.output(0).shape = shp - return True - diff --git a/cube/graph/function/scatter.py b/cube/graph/function/scatter.py deleted file mode 100644 index dcb2fe97..00000000 --- a/cube/graph/function/scatter.py +++ /dev/null @@ -1,66 +0,0 @@ -from copy import copy -from typing import List, Optional, Tuple - -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor - -class IRSelectScatter(IRFwOperation): - """ - torch.select_scatter(self:Tensor, input:Tensor, dim:int, index:int) -> Tensor - - identical to: - ``` - x = self.copy() # Assume N-d tensor. - view = x.select(dim, index) # View and input are (N-1)-d tensors. - view.copy_(input) # See REMARK! - return x - ``` - - REMARK: - Unlike the `copy_` API in the identical code snippet above, - `select_scatter` (as well as other scatter family APIs) are NOT broadcastable, - namely it requires the `input` tensor to embed is an exactly (N-1)-dimensional tensor. - - But in-place Python code like - ``` - self[index] = input - ``` - involves broadcasting, so `input` can has any broadcastable shapes to `self.shape.pop(dim)`, - including being scalars. - """ - - def __init__(self, signature: str, inputs:Tuple[IRTensor, IRTensor], name: str, dim:int, index:int): - assert len(inputs) == 2 - - signature = 'cube.runtime.function.select_scatter' - super().__init__(name, signature, 2, 1) - self.set_input(0, inputs[0]) - self.set_input(1, inputs[1]) - self.kwargs.update({"dim": dim, "index": index}) - - def infer_shape(self) -> bool: - shp_self : List[int] = self.input(0).shape - if len(shp_self) == 0: - return False - - shp_input = self.input(1).shape - - if len(shp_input) == 0: - print("The 0-length input shape is ambiguous, may be uninferrable or just of a 0-d tensor") - elif len(shp_input) > 0: - dim: int = self.kwargs["dim"] - copy_shp = shp_self.copy() - copy_shp.pop(dim) - if copy_shp != shp_input: - raise RuntimeError(f"self shape {shp_self} and input shape {shp_input} with dim={dim} mismatch") - - s2 = copy(shp_self) - self.output(0).shape = s2 - return True - - def new(self, inputs:Tuple[IRTensor, IRTensor], outputs: List[IRTensor]): - op = IRSelectScatter(self.signature, inputs, self.name, self.kwargs['dim'], self.kwargs['index']) - assert len(outputs) == 1, "Select_scatter: too many outputs" - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IRSelect::new infer_shape failed" - return op diff --git a/cube/graph/function/scripteinops.py b/cube/graph/function/scripteinops.py deleted file mode 100644 index 29b017ac..00000000 --- a/cube/graph/function/scripteinops.py +++ /dev/null @@ -1,53 +0,0 @@ - -from typing import List - -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor - -from einops.einops import _apply_recipe - -import torch - -class IRScriptEinOps(IRFwOperation): - - def __init__(self, signature: str, inputs: List[IRTensor], name: str, - **kwargs): - signature = 'cube.runtime.function.einops' - assert len(inputs) == 1, "Expected only input" - assert len(kwargs) == 2, "Expected 2 kwargs: recipe_str, reduction_type" - super().__init__(name, signature, 1, 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update(kwargs) - - def infer_shape(self) -> bool: - """ - Output shape inference given the input shapes - """ - if len(self.input(0).shape) == 0: - return False - - recipe_str = self.kwargs['recipe_str'] - import pickle - recipe = pickle.loads(recipe_str) - - reduction_type = self.kwargs['reduction_type'] - tmp_tensor = torch.zeros(self.input(0).shape) - tmp_output = _apply_recipe(recipe, tmp_tensor, reduction_type) - self.output(0).shape = list(tmp_output.shape) - return True - - def new(self, inputs: List, outputs: List): - """ - construct a new operator sharing same kwargs with new inputs - and outputs - """ - recipe_str = self.kwargs['recipe_str'] - reduction_type = self.kwargs['reduction_type'] - op = IRScriptEinOps(self.signature, inputs, self.name, - recipe_str=recipe_str, reduction_type=reduction_type) - assert len(outputs) == 1 - op.set_output(0, outputs[0]) - op.infer_shape() - return op - diff --git a/cube/graph/function/select.py b/cube/graph/function/select.py deleted file mode 100644 index 1f8739d4..00000000 --- a/cube/graph/function/select.py +++ /dev/null @@ -1,92 +0,0 @@ -from copy import copy -from typing import List, Optional, Tuple - -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor - -class IRSelect(IRFwOperation): - """ - torch.select(input:Tensor, dim:int, index:int) -> Tensor - """ - def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, dim:int, index:int): - assert len(inputs) == 1 - - super().__init__(name, signature, 1, 1) - self.set_input(0, inputs[0]) - self.kwargs.update({"dim": dim, "index": index}) - - def infer_shape(self) -> bool: - s : List[int] = self.input(0).shape - if len(s) == 0: - return False - - dim = self.kwargs["dim"] - - s2 = copy(s) - s2.pop(dim) - self.output(0).shape = s2 - - return True - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - op = IRSelect(self.signature, inputs, self.name, self.kwargs['dim'], self.kwargs['index']) - assert len(outputs) == 1 - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IRSelect::new infer_shape failed" - return op - -class IRSlice(IRFwOperation): - """ - aten::slice(input:Tensor, dim:int=0, start:Optional[int]=None, end:Optional[int]=None, step:int=1) -> Tensor - """ - - def __init__(self, signature: str, inputs:Tuple[IRTensor], name: str, - dim:int, start:Optional[int], end:Optional[int], step:int): - assert len(inputs) == 1 - - super().__init__(name, signature, 1, 1) - self.set_input(0, inputs[0]) - self.kwargs.update({"dim": dim, "start": start, "end": end, "step": step}) - - def infer_shape(self) -> bool: - s : List[int] = self.input(0).shape - if len(s) == 0: - return False - - dim : int = self.kwargs["dim"] - start : Optional[int] = self.kwargs["start"] - end : Optional[int] = self.kwargs["end"] - step : int = self.kwargs["step"] - - if start is None: - start = 0 - if end is None: - end = 2 ** 64 - - dim_len = s[dim] - - def clip(offset): - if offset < 0: - offset += dim_len - return min(dim_len, max(0, offset)) - - start = clip(start) - end = clip(end) - - sliced_dim_len = len(range(start, end, step)) - s2 = s.copy() - s2[dim] = sliced_dim_len - self.output(0).shape = s2 - - return True - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - assert len(inputs) == 1, "Slice: number of inputs not equal to 1" - op = IRSlice(self.signature, inputs, self.name, self.kwargs['dim'], self.kwargs['start'], self.kwargs['end'], self.kwargs['step']) - assert len(outputs) == 1 - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IRSlice::new infer_shape failed" - return op - - -# torch.gather(input:Tensor, dim:int, index:LongTensor, *, sparse_grad=False, out=None) -> Tensor diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index a35d4e35..257b4dc3 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -51,11 +51,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # runtime template __rtemplate = lambda name: f'cube.runtime.function.function.{name}' - # einops - __einopsize = lambda name: f'einops._torch_specific.{name}' - - # custom ops - __customops = lambda name: f'examples.custom_ops.{name}' kOpMap = { @@ -168,9 +163,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __rtemplate('accum'): function.Accum, - #einops - __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, - } # customized operator code: signature -> code diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 28933b28..bd7c5ec3 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -112,48 +112,3 @@ def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): return torch.masked_scatter(input, mask, src) - -def einops(input: torch.Tensor, recipe_str, reduction_type: str): - import pickle - recipe = pickle.loads(recipe_str) - from einops.einops import _apply_recipe - output = _apply_recipe(recipe, input, reduction_type) - return output - -############### custom op ################# -#TODO move me -def update_diag_(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, - delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, - pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dt, dz): - import einops.einops - def pre_conv3d_reshape(X): - return einops.einops.rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) - def post_conv3d_reshape(X): - return einops.einops.rearrange(X, 'b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') - def delta_x(X): - return post_conv3d_reshape(torch.nn.functional.conv3d(pre_conv3d_reshape(X), delta_x_filter)) - def delta_y(X): - return post_conv3d_reshape(torch.nn.functional.conv3d(pre_conv3d_reshape(X), delta_y_filter)) - # update diagnostic variable w (nz + 1, ny, nx) - import warnings - warnings.warn("detaching w in update_diag_...") - w.detach_() #to prevent ERROR: A leaf Variable that requires grad is being used in an in-place operation. - for i in range(1, w.shape[0]): - w[i] = - ((delta_x(F[:i]) + delta_y(G[:i])) * dz).sum(dim=0) / deltaA / pi1 \ - - sigma[i] * (pi1 - pi0) / dt / pi1 - - return w - -def update_geopotential_(phi: torch.Tensor, zs: torch.Tensor, P: torch.Tensor, P_: torch.Tensor, theta: torch.Tensor, g, CPD, nz): - import warnings - warnings.warn("detaching phi in update_geopotential_...") - phi.detach_() - phi[-1] = g * zs - CPD * (P[-1] - P_[-1]) * theta[-1] - for i in range(1, nz): - tmp = phi[-i] - CPD * (P_[-i - 1] - P[-i]) * theta[-i] - phi[-1 - i] = tmp - CPD * (P[-1 - i] - P_[-1 - i]) * theta[-1 - i] - - return phi - -def strip_2_borders(w: torch.Tensor): - return w[1:-1] diff --git a/tests/parser/test_jit_ops.py b/tests/parser/test_jit_ops.py new file mode 100644 index 00000000..50456120 --- /dev/null +++ b/tests/parser/test_jit_ops.py @@ -0,0 +1,66 @@ +""" +torchrun --nproc_per_node=1 tests/parser/test_torch_ops.py +""" +import torch + +import cube +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops + + +class TestOpModule(torch.nn.Module): + + def __init__(self, shape=[256, 512]): + super().__init__() + self.param = torch.nn.Parameter(torch.empty(shape, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + # [256, 512], [256, 512] -> [256, 512] + x = x * self.param + # [256, 512] -> [512] + x1 = x.select(0, 6) + # [256, 512], [512] -> [256, 512] + x2 = x.select_scatter(x1, 0, 7) + # [256, 512] -> [512, 512] + x3 = x2.repeat(2, 1) + # [512, 512] -> [256, 512]: this will be parsed to 2 slice operations + x4 = x3[:256,:] + return x4 + + +def test_parse_ops(): + + cube.init() + + model = TestOpModule() + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([256, 512],), dtypes=(torch.float32,), batch_dims=(0,)) + + def policy(graph, resource): + assert resource.ngpus == 1 + for node in graph.nodes(): + if isinstance(node, IRDimops): + print(f'# {node.anno}') + print(node) + elif isinstance(node, (IRFwOperation, IRDataOperation)): + print(node) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + model = cube.SemanticModel(model) + + @cube.compile(model, dataloader, policy, load_content=False) + def eval_iter(model, dataloader): + data = next(dataloader) + out = model(data) + + model = model.get_gen_module() + + for _ in range(3): + eval_iter(model, dataloader) + + +if __name__ == '__main__': + test_parse_ops() + From a27470fef56a462d5ca730b1293e97a89ef1d03f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 20 Feb 2023 09:26:51 +0800 Subject: [PATCH 1247/1892] remove pytorch pip install requirements into prerequisite --- README.md | 66 +++++------------------------------------ cube/__init__.py | 11 +++++++ cube/compiler.py | 12 +++----- cube/profiler/README.md | 54 +++++++++++++++++++++++++++++++++ requirements.txt | 4 --- 5 files changed, 76 insertions(+), 71 deletions(-) create mode 100644 cube/profiler/README.md diff --git a/README.md b/README.md index ac2dc1d0..2c0e108d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # MagicCube -AI System Compiler to map a semantic (single-device) model into distributed execution using policies specified by System Expert. +AI System Compiler to map a semantic (single-device) model into distributed execution using policies specified by developers. ## Prerequisite @@ -11,6 +11,10 @@ AI System Compiler to map a semantic (single-device) model into distributed exec Install dependent packages ```shell pip install -r requirements.txt + +# require pytorch version >= 1.11 +pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch/ +# pip install torch==1.11.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html ``` ## Option 1: Quick Start without Installation @@ -20,7 +24,7 @@ pip install -r requirements.txt PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py + examples/mlp/linears.py --policy PASCol ``` [comment]: <> (UDA_VISIBLE_DEVICES=7 PYTHONPATH=.:$PYTHONPATH python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 ./examples/wrf/wrf2.py) @@ -38,7 +42,6 @@ PYTHONPATH=.:$PYTHONPATH SINGLE_DEV_MODE=1 python examples/mlp/linears.py * ### Install ```python -pip install -r requirements.txt python setup.py develop ``` @@ -49,60 +52,5 @@ python setup.py develop OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py -``` - -## Profile - -### Use cProfile + snakeviz - -Due to the multi-process architecture of `torch.distributed.launch`, instead of directly using -the command-line interface of cProfile, we need to exactly specify the scope to profile, like: - -```python -import cProfile -prof = cProfile.Profile() -prof.enable() - -# our code to profile goes here -@cube.compile(...) -def iter(dataloader): - x, y = next(dataloader) - z = model(x, y) - return z -for i in range(N): - iter(...) -# our code ends - -prof.disabled() -prof.dump_stats('cube_RANK%d.prof' % torch.distributed.get_rank()) # or use TID/PID, if to profile multi-thread/-process program. -``` - -After the modification, run the Python file using the same command line with `torchrun` as usual. - -After dumping the profiling data, we can use `snakeviz` to visualize it: - -```shell -pip install snakeviz -snakeviz cube_RANK0.prof + examples/mlp/linears.py --policy PASCol ``` - -### Use viztracer - -An alternative to cProfile + snakeviz is to use the profiler `viztracer`, -as well as its builtin visualization. - -`viztracer` is aware of the multi-process architecture of `torchrun` and it offers a command-line -interface and offers a very detailed profiling log, including the sequence, timing and durations. - -> P.S. However, too detailed to be effectively used to profile huge DAG like the 23k~ nodes unrolled -> WRF model, it would output very big log file and be very slow to render. - -`viztracer` can be used like: - -```shell -pip install viztracer -viztracer --log_multiprocess torchrun --nproc_per_node=4 --nnodes=1 examples/mlp/linears.py -``` - -For more configurations please check `viztracer -h`. \ No newline at end of file diff --git a/cube/__init__.py b/cube/__init__.py index 191a1afd..7c1e3188 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,14 +1,25 @@ +import warnings from cube import runtime from cube import profiler from cube.compiler import SemanticModel, compile +def _check_torch_version(): + import torch + torch_version = str(torch.__version__).split('+')[0] + torch_version = float('.'.join(torch_version.split('.')[:2])) + if torch_version < 1.11: + warnings.warn(f"Expected PyTorch version >= 1.11 but got {torch_version}") + + def init(): _ = runtime.device.DeviceGroup() _ = runtime.resource.EnvResource() +_check_torch_version() + # ================== Experimental Feature ======================= diff --git a/cube/compiler.py b/cube/compiler.py index c022b915..fa093d6d 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -33,15 +33,11 @@ def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, """ AI Scientist calls like: - @cube.compile(model, dataloader, policy=(trans_policy, schedule_policy)) + @cube.compile(model, dataloader, PAS=policy) def train_step(model, dataloader): - # do a 4-time gradient accumulation - for acc_step, (data, label) in enumerate(dataloader): - if acc_step < 4: - loss = model(data, label) - loss.backward() - else: - break + data = next(dataloader) + loss = model(data) + loss.backward() ... for epoch in range(100): diff --git a/cube/profiler/README.md b/cube/profiler/README.md new file mode 100644 index 00000000..6a02dc9a --- /dev/null +++ b/cube/profiler/README.md @@ -0,0 +1,54 @@ +## Profile + +### Use cProfile + snakeviz + +Due to the multi-process architecture of `torch.distributed.launch`, instead of directly using +the command-line interface of cProfile, we need to exactly specify the scope to profile, like: + +```python +import cProfile +prof = cProfile.Profile() +prof.enable() + +# our code to profile goes here +@cube.compile(...) +def iter(dataloader): + x, y = next(dataloader) + z = model(x, y) + return z +for i in range(N): + iter(...) +# our code ends + +prof.disabled() +prof.dump_stats('cube_RANK%d.prof' % torch.distributed.get_rank()) # or use TID/PID, if to profile multi-thread/-process program. +``` + +After the modification, run the Python file using the same command line with `torchrun` as usual. + +After dumping the profiling data, we can use `snakeviz` to visualize it: + +```shell +pip install snakeviz +snakeviz cube_RANK0.prof +``` + +### Use viztracer + +An alternative to cProfile + snakeviz is to use the profiler `viztracer`, +as well as its builtin visualization. + +`viztracer` is aware of the multi-process architecture of `torchrun` and it offers a command-line +interface and offers a very detailed profiling log, including the sequence, timing and durations. + +> P.S. However, too detailed to be effectively used to profile huge DAG like the 23k~ nodes unrolled +> WRF model, it would output very big log file and be very slow to render. + +`viztracer` can be used like: + +```shell +pip install viztracer +viztracer --log_multiprocess torchrun --nproc_per_node=4 --nnodes=1 examples/mlp/linears.py +``` + +For more configurations please check `viztracer -h`. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9d5c2fed..94cf5b68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,4 @@ -einops matplotlib pytest setuptools==60.7.0 more-itertools - ---find-links https://download.pytorch.org/whl/torch_stable.html -torch>=1.11.0+cu113 \ No newline at end of file From fffd48a189cc4b5d055f3676fab0c7eefe665fb7 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 20 Feb 2023 09:28:02 +0800 Subject: [PATCH 1248/1892] add test examples --- tests/test_examples.sh | 54 +++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 84845983..c69c616d 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -53,45 +53,55 @@ OMP_NUM_THREADS=4 torchrun \ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMegatron --fp16 + examples/nlp/gpt/train.py --policy PAS1F1B --fp16 OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMeshShard --fp16 - - -# test Swin model + examples/nlp/gpt/train.py --policy PASMegatron --fp16 OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/vision/swin/train.py --policy PASData --fp16 + examples/nlp/gpt/train.py --policy PASMeshShard --fp16 OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ + --nproc_per_node=2 \ --nnodes=1 \ - examples/vision/swin/train.py --policy PASMegatronTP --fp16 + examples/nlp/gpt/infer.py --policy PASDP --fp16 -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/vision/swin/train.py --policy PASMegatron --fp16 +# test Swin model -# test scientific model +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# examples/vision/swin/train.py --policy PASData --fp16 OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/poisson/sci.py + examples/vision/swin/train.py --policy PASMegatronTP --fp16 -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - examples/wrf/wrf2.py --policy PAS +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# examples/vision/swin/train.py --policy PASMegatron --fp16 -OMP_NUM_THREADS=1 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/wrf/wrf2.py --policy PAS_ALL_Y + +# test scientific model + +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# examples/poisson/sci.py +# +# OMP_NUM_THREADS=4 torchrun \ +# --nproc_per_node=1 \ +# --nnodes=1 \ +# examples/wrf/wrf2.py --policy PAS +# +# OMP_NUM_THREADS=1 torchrun \ +# --nproc_per_node=4 \ +# --nnodes=1 \ +# examples/wrf/wrf2.py --policy PAS_ALL_Y From 725e23b9f6cf5f25bdce10dc6d8965ea5dddfc3b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 20 Feb 2023 09:29:21 +0800 Subject: [PATCH 1249/1892] update doc --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2c0e108d..5f30c5bd 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ AI System Compiler to map a semantic (single-device) model into distributed exec > Install Python 3.7 in the development environment for widest compatibility. Install dependent packages -```shell +```sh pip install -r requirements.txt # require pytorch version >= 1.11 From 5211a3511cabb64486bfc633cf4872152ff7a592 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 20 Feb 2023 15:36:57 +0800 Subject: [PATCH 1250/1892] refine error hint for checkers --- cube/algorithm/generics.py | 3 +++ cube/graph/graph.py | 2 +- cube/ir/operator.py | 31 ++++++++++++++++++------------- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/cube/algorithm/generics.py b/cube/algorithm/generics.py index b5e52fd6..e537c165 100644 --- a/cube/algorithm/generics.py +++ b/cube/algorithm/generics.py @@ -36,3 +36,6 @@ def instantiate(self, **config) -> Optional[List[IRCell]]: @return sub_nodes Optional[List[IRCell]]: if sucess, the partitioned sub nodes, else None """ raise NotImplementedError + + def __repr__(self) -> str: + return f'TransAlgo(node{self._node.cid})' diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 1a1bf9fd..5c556df2 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -345,7 +345,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], @return ops List[IRCell]: partitioned sub-nodes """ assert isinstance(algo, GenericDistAlgo) and node == algo.node, \ - "The partition algorithm is not initialized for this node" + f"The partition algorithm ({algo}) is not initialized for this node" assert isinstance(node, (IRFwOperation, IRDataOperation)), \ f"Only allow op to be forward op or data op, but got: {node}" if node.name == 'multiref': diff --git a/cube/ir/operator.py b/cube/ir/operator.py index d6602487..4fabc785 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -1,9 +1,10 @@ -from typing import Optional, Tuple, Any +from typing import Optional, Tuple, Any, Union import copy from cube.ir.cten import IRCell, IRTensor -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.tensor import IRFullTensor from cube.algorithm.factory import DistAlgorithmFactory +from cube.algorithm.generics import GenericDistAlgo from cube.ir.unique import IDGenerator from cube.ir.dtype import IRDType, DTypeInferRule @@ -78,12 +79,15 @@ def recompute(self, group_id: Optional[int]): assert self._recompute == group_id, "The operator is set to recompute in another recompute group." self._recompute = group_id - def algorithms(self, tag: Optional[str] = None): + def algorithms(self, tag: Optional[str] = None) -> Union[Tuple[GenericDistAlgo], GenericDistAlgo]: """ get algorithm from algorithm factory - Args: - tag: str or None. If None, return all + @param tag Optional[str]: the queried tag (default None for all) + + @return algorithm(s) Union[Tuple[GenericDistAlgo], GenericDistAlgo]: + If None (default), return all possible algorithms. + Otherwise, return the specified one. """ factory = DistAlgorithmFactory() if tag is None: @@ -95,8 +99,7 @@ def algorithms(self, tag: Optional[str] = None): algos.append(template(self)) return algos else: - if not factory.exist(type(self), tag): - return None + assert factory.exist(type(self), tag), f"Node {self} doesn't have transformation algorithm tag: {tag}" template = factory.algorithms(type(self), tag) return template(self) @@ -226,12 +229,15 @@ def infer_shape(self): """ return True - def algorithms(self, tag: Optional[str] = None): + def algorithms(self, tag: Optional[str] = None) -> Union[Tuple[GenericDistAlgo], GenericDistAlgo]: """ - get algorithm from algorithm factory + Get algorithm from algorithm factory + + @param tag Optional[str]: the queried tag (default None for all) - Args: - tag: str or None. If None, return all + @return algorithm(s) Union[Tuple[GenericDistAlgo], GenericDistAlgo]: + If None (default), return all possible algorithms. + Otherwise, return the specified one. """ factory = DistAlgorithmFactory() if tag is None: @@ -243,8 +249,7 @@ def algorithms(self, tag: Optional[str] = None): algos.append(template(self)) return algos else: - if not factory.exist(type(self), tag): - return None + assert factory.exist(type(self), tag), f"Node {self} doesn't have transformation algorithm tag: {tag}" template = factory.algorithms(type(self), tag) return template(self) From 9852f92247ac36d28a319c2c6b44f81b0884ee33 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 21 Feb 2023 10:39:10 +0800 Subject: [PATCH 1251/1892] refine code --- cube/graph/gener/rvd/layout.py | 177 --------------------------------- 1 file changed, 177 deletions(-) diff --git a/cube/graph/gener/rvd/layout.py b/cube/graph/gener/rvd/layout.py index ec7c3d6b..f9b5e89e 100644 --- a/cube/graph/gener/rvd/layout.py +++ b/cube/graph/gener/rvd/layout.py @@ -6,11 +6,6 @@ from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.tensor import ValueMap -from cube.ir.adapter.prim import MovePrim # p2p -from cube.ir.adapter.prim import BroadcastPrim -from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim -from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim - TShape = Tuple[int, ...] TRVD = Tuple[int, ...] @@ -85,178 +80,6 @@ def __copy__(self): mat = np.array(tensors).reshape(self.mat.shape) return RVDLayout(self.ftensor, tensors, mat) - # ====== inter-RVD transition primitives ====== # - - def incr(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive +R: broadcast - - @param chunks int: the number of chunks to transfer - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - rvd = list(self.vec) - rvd[0] = rvd[0] * chunks - olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = RVDLayout.dims2last(self.mat, [0]).flatten() - omat = RVDLayout.dims2last(olayout.mat, [0]).reshape(-1, chunks) - prims = [] - for src, dsts in zip(imat, omat): - if chunks == 1: - prims.append(MovePrim([src], dsts)) - else: - prims.append(BroadcastPrim([src], [src] + list(dsts))) - return [(olayout, prims),] - - def decr(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive -R: move - - @param chunks int: the number of chunks to transfer - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.R % chunks == 0, f"not divisible replica {self.R} // {chunks}" - rvd = list(self.vec) - rvd[0] = rvd[0] // chunks - olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = RVDLayout.dims2last(self.mat, [0]).reshape(-1, chunks) - omat = RVDLayout.dims2last(olayout.mat, [0]).flatten() - prims = [] - for srcs, dst in zip(imat, omat): - prims.append(MovePrim([srcs[0]], [dst])) - return [(olayout, prims),] - - def incd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive +D: RD-Scatter - - @param chunks int: the number of chunks to transfer - @param dim int: tensor axis - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - rvd = list(self.vec) - rvd[2+dim] = rvd[2+dim] * chunks - olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = RVDLayout.dims2last(self.mat, [2+dim]).flatten() - omat = RVDLayout.dims2last(olayout.mat, [2+dim]).reshape(-1, chunks) - prims = [] - for src, dsts in zip(imat, omat): - prims.append(RDScatterPrim([src], dsts, dim=dim)) - return olayout, prims - - def decd(self, chunks: int, dim: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive +D: RD-Gather - - @param chunks int: the number of chunks to transfer - @param dim int: tensor axis - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.D[dim] % chunks == 0, f"not divisible dim: {self.D[dim]} % {chunks} != 0" - rvd = list(self.vec) - rvd[2+dim] = rvd[2+dim] // chunks - olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = RVDLayout.dims2last(self.mat, [2+dim]).reshape(-1, chunks) - omat = RVDLayout.dims2last(olayout.mat, [2+dim]).flatten() - prims = [] - for srcs, dst in zip(imat, omat): - prims.append(RDGatherPrim(srcs, [dst], dim=dim)) - return [(olayout, prims),] - - def incv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive +V: RV-Scatter - - @param chunks int: the number of chunks to transfer - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - rvd = list(self.vec) - rvd[1] = rvd[1] * chunks - olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) * chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = RVDLayout.dims2last(self.mat, [1]).flatten() - omat = RVDLayout.dims2last(olayout.mat, [1]).reshape(-1, chunks) - prims = [] - for src, dsts in zip(imat, omat): - prims.append(RVScatterPrim([src], dsts)) - return [(olayout, prims),] - - def decv(self, chunks: int, devices: Optional[np.ndarray] = None) -> Tuple: - """ - inter-RVD primitive -V: RV-Gather - - @param chunks int: the number of chunks to transfer - @param devices numpy.ndarray: the desired output device - - @return ret List[Tuple[RVDLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. - """ - assert self.V % chunks == 0, f"not divisable value split: {self.V} % {chunks} != 0" - rvd = list(self.vec) - rvd[1] = rvd[1] // chunks - olayout = RVDLayout.grid(self.ftensor, rvd[0], rvd[1], rvd[2:]) - # set device - if devices is not None: - assert devices.size == len(self.subtensors) // chunks - for tensor, devid in zip(olayout.mat.flatten(), devices.flatten()): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) - tensor.cell.device = int(devid) - # setup prims - imat = RVDLayout.dims2last(self.mat, [1]).reshape(-1, chunks) - omat = RVDLayout.dims2last(olayout.mat, [1]).flatten() - prims = [] - for srcs, dst in zip(imat, omat): - prims.append(RVGatherPrim(srcs, [dst])) - return [(olayout, prims),] - def align(self, layout) -> bool: """ Check whether the layout is same with self. From abcb78c9ae41397e9c951a8fc0ffec4706b98b37 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 21 Feb 2023 10:44:57 +0800 Subject: [PATCH 1252/1892] add nnfusion support --- cube/execplan/planpass/grouping.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index 755011dd..bb0de6d5 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -8,7 +8,10 @@ from cube.ir.adapter import IRAdapter from cube.ir.adapter.prim import IdentityPrim from cube.ir.operator import IRFwOperation +from cube.graph.function.pyfunc import IRPyFunc from cube.ir.cten import IRCell + +from cube.flags import CompileFlag class Grouping(PlanPass): @@ -57,6 +60,9 @@ def group(execplan) -> Tuple[Dict[int, List[List[IRCell]]],]: Tuple: (fgroups, bgroups) """ def differentiable(fnode): + # nnfusion special handle: break IRAdapter and IRPyFunc + if CompileFlag.use_nnfusion: + if isinstance(fnode, IRAdapter): return False if isinstance(fnode, IRFwOperation): return True if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.isfw(): From 9f9b334e73cc17ea537da9a54ff36a49a1c836b2 Mon Sep 17 00:00:00 2001 From: rwlu Date: Tue, 21 Feb 2023 06:36:49 -0800 Subject: [PATCH 1253/1892] refine code --- examples/alphafold2/alphafold2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 1bab4725..3dd72d5f 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -23,7 +23,7 @@ def build_alphafold_config(setting:int): elif setting == 2: bs, s, r = 1, 512, 256 elif setting == 3: - bs, s, r = 1, 512, 384 + bs, s, r = 1, 512, 512 else: assert False, f"unrecognized setting {setting}" @@ -98,7 +98,8 @@ def test_main(): # first fine-tuning # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning - bs, s, r, cm, cz = 1, 512, 384, 256, 128 + # bs, s, r, cm, cz = 1, 512, 384, 256, 128 + bs, s, r, cm, cz = 1, 512, 512, 256, 128 dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 4, False, True, False policy = spmd.PASDAP From c51e5c1ed1799a82266bc15d8d451339167e90e1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 22 Feb 2023 19:15:51 +0800 Subject: [PATCH 1254/1892] fix flag bug: allow to use 0/1 for on/off --- cube/flags.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/cube/flags.py b/cube/flags.py index f3e0a46f..bd8a54c5 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -5,35 +5,42 @@ import os +def _to_bool(s: str) -> bool: + val = os.environ.get(s, default=0) + return bool(int(val)) + + +def _to_int(s: str, default=0) -> int: + val = os.environ.get(s, default=default) + return int(val) + + class CompileFlag: # ============= loggings =================== - log_transform = os.environ.get('LOG_TRANSFORM') - log_schedule = os.environ.get('LOG_SCHEDULE') + log_transform = _to_bool('LOG_TRANSFORM') + log_schedule = _to_bool('LOG_SCHEDULE') - # ================ compiling ======================== # worker sleep in seconds - worker_sleep = int(os.environ.get('WORKER_SLEEP')) if os.environ.get('WORKER_SLEEP') is not None else 0 - disable_intra_rvd = os.environ.get('DISABLE_INTRA_RVD') - disable_inter_rvd = os.environ.get('DISABLE_INTER_RVD') - disable_comm_fusion = os.environ.get('DISABLE_COMM_FUSION') + worker_sleep = _to_int('WORKER_SLEEP') + disable_intra_rvd = _to_bool('DISABLE_INTRA_RVD') + disable_inter_rvd = _to_bool('DISABLE_INTER_RVD') + disable_comm_fusion = _to_bool('DISABLE_COMM_FUSION') - visualize_plan = bool(os.environ.get('VISUALIZE_PLAN')) + visualize_plan = _to_bool('VISUALIZE_PLAN') # ============ code generation =============== - use_nnfusion = os.environ.get('USE_NNFUSION') - use_jit = os.environ.get('USE_JIT') - + use_nnfusion = _to_bool('USE_NNFUSION') + use_jit = _to_bool('USE_JIT') # ============== runtime ==================== - dev_mode = os.environ.get('SINGLE_DEV_MODE') # allow to use python xx.py + dev_mode = _to_bool('SINGLE_DEV_MODE') # allow to use python xx.py # maximal reducer weight bytes for one allreduce - max_reducer_bucket = int(os.environ.get('MAX_REDUCER_BUCKET', default=5e8)) + max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=5e8) # use automate mixture precision training, where weights, gradients # and optimizer status are kept in its original data type (can be float32), # but some of the forward operators will be converted to float16. - use_amp = True if os.environ.get('USE_AMP') else False - + use_amp = _to_bool('USE_AMP') From dc175879b9579ba46f32394044c67ac578e6cf00 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 24 Feb 2023 09:38:04 +0800 Subject: [PATCH 1255/1892] remove useless code --- examples/custom_ops.py | 34 ------ examples/gsearch/blocks.py | 181 ---------------------------- examples/gsearch/gpt/model.py | 109 ----------------- examples/gsearch/gpt/policy/mpmd.py | 102 ---------------- examples/gsearch/gpt/policy/spmd.py | 100 --------------- examples/gsearch/gpt/train.py | 93 -------------- examples/inspector.py | 127 ------------------- examples/mlp/policy/mpmd.py | 7 +- examples/mlp/policy/search.py | 133 -------------------- examples/mlp/policy/spmd.py | 71 +++++++++-- examples/mlp/policy/st_search.py | 80 ------------ tests/test_grid.py | 19 --- 12 files changed, 63 insertions(+), 993 deletions(-) delete mode 100644 examples/custom_ops.py delete mode 100644 examples/gsearch/blocks.py delete mode 100644 examples/gsearch/gpt/model.py delete mode 100644 examples/gsearch/gpt/policy/mpmd.py delete mode 100644 examples/gsearch/gpt/policy/spmd.py delete mode 100644 examples/gsearch/gpt/train.py delete mode 100644 examples/inspector.py delete mode 100644 examples/mlp/policy/search.py delete mode 100644 examples/mlp/policy/st_search.py delete mode 100644 tests/test_grid.py diff --git a/examples/custom_ops.py b/examples/custom_ops.py deleted file mode 100644 index 151c15ef..00000000 --- a/examples/custom_ops.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -import einops - -############### custom op ################# -def update_diag_(w: torch.Tensor, F: torch.Tensor, G: torch.Tensor, - delta_x_filter:torch.Tensor, delta_y_filter: torch.Tensor, deltaA:torch.Tensor, - pi0:torch.Tensor, pi1:torch.Tensor, sigma:torch.Tensor, dt:torch.Tensor, dz:float): - #NOTE place holder - # def pre_conv3d_reshape(X): - # return einops.einops.rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) - # def post_conv3d_reshape(X): - # return einops.einops.rearrange(X, 'b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') - # def delta_x(X): - # return post_conv3d_reshape(F.conv3d(pre_conv3d_reshape(X), delta_x_filter)) - # def delta_y(X): - # return post_conv3d_reshape(F.conv3d(pre_conv3d_reshape(X), delta_y_filter)) - # # update diagnostic variable w (nz + 1, ny, nx) - # for i in range(1, w.shape[0]): - # w[i] = - ((delta_x(F[:i]) + delta_y(G[:i])) * dz).sum(dim=0) / deltaA / pi1 \ - # - sigma[i] * (pi1 - pi0) / dt / pi1 - - return w - -def update_geopotential_(phi: torch.Tensor, zs: torch.Tensor, P: torch.Tensor, P_: torch.Tensor, theta: torch.Tensor, - g:float, CPD:float, nz:int): - # NOTE place holder - # phi[-1] = g * zs - CPD * (P[-1] - P_[-1]) * theta[-1] - # for i in range(1, nz): - # tmp = phi[-i] - CPD * (P_[-i - 1] - P[-i]) * theta[-i] - # phi[-1 - i] = tmp - CPD * (P[-1 - i] - P_[-1 - i]) * theta[-1 - i] - return phi - -def strip_2_borders(w: torch.Tensor): - return w[1:-1] \ No newline at end of file diff --git a/examples/gsearch/blocks.py b/examples/gsearch/blocks.py deleted file mode 100644 index e112be60..00000000 --- a/examples/gsearch/blocks.py +++ /dev/null @@ -1,181 +0,0 @@ -import torch -import cube - - -@cube.graph.parser.register('L N E+, (h d) E+, (h d), (h d) E+, (h d), (h d) E+, (h d) -> N h L d, N h L d, N h L d', name='attn_qkv') -def attn_qkv(query: torch.Tensor, - q_proj: torch.Tensor, q_bias: torch.Tensor, - k_proj: torch.Tensor, k_bias: torch.Tensor, - v_proj: torch.Tensor, v_bias: torch.Tensor, - h: int, scale: float): - L, N = query.size(0), query.size(1) - d = q_proj.size(0) // h - - q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * h), d) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * h), d) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * h), d) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - q = q.view(N, h, L, d) - k = k.view(N, h, L, d) - v = v.view(N, h, L, d) - return q, k, v - - -@cube.graph.parser.register('N h L d+, N h L d+ -> N h L L', name='attn_score') -def attn_score(q: torch.Tensor, k: torch.Tensor, h: int, mask: bool = True): - N, num_head, L, d = q.size() - assert num_head == h - q = q.view(-1, L, d) - k = k.view(-1, L, d) - k = k.transpose(1, 2) - attn = torch.bmm(q, k) - attn = attn.view(N, h, L, L) - # attention mask - if mask: - ones = torch.ones((N, L, L), device=attn.device) - amask = torch.tril(ones) - amask = amask.view(N, 1, L, L) - amask = (amask < 0.5) - attn = attn.masked_fill_(amask, -10000.0) - return attn - - -@cube.graph.parser.register('N h L K^ -> N h L K^', name='attn_softmax') -def attn_softmax(attn: torch.Tensor): - N, h, L, L = attn.size() - attn = attn.view((N * h), L, L) - attn = torch.nn.functional.softmax(attn, dim=-1) - return attn.view(N, h, L, L) - - -@cube.graph.parser.register('N h L L -> N h L L', name='attn_dropout') -def attn_dropout(attn: torch.Tensor, dropout_p: float): - return torch.nn.functional.dropout(attn, dropout_p, True, False) - - -@cube.graph.parser.register('N h L K+, N h K+ d -> L N (h d)', name='attn_context') -def attn_context(attn: torch.Tensor, v: torch.Tensor): - N, h, L, d = v.size() - attn = attn.view((N * h), L, L) - v = v.view((N * h), L, d) - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, h * d) # (N h) L d -> L N (h d) - return output - - -@cube.graph.parser.register('L N hd+, E hd+, E -> L N E', name='attn_dense_out') -def attn_dense_out(context: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): - return torch.nn.functional.linear(context, weight, bias) - - -@cube.graph.parser.register('L N E+, inner E+, inner -> L N inner', name='mlp_linear1') -def mlp_linear1(x: torch.Tensor, proj: torch.Tensor, bias: torch.Tensor): - return torch.nn.functional.linear(x, proj, bias) - - -@cube.graph.parser.register('L N inner+, E inner+, E -> L N E', name='mlp_linear2') -def mlp_linear2(x: torch.Tensor, proj: torch.Tensor, bias: torch.Tensor): - return torch.nn.functional.linear(x, proj, bias) - - - -class MultiHeadSelfAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): - super().__init__() - self.inner_dim = inner_dim - self.num_heads = num_heads - self.head_dim = inner_dim // num_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # Q - self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None - # K - self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None - # V - self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None - - def forward(self, query: torch.Tensor): - # QKV - q, k, v = attn_qkv( - query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.num_heads, self.scaling - ) - # AttentionScore - attn = attn_score(q, k, self.num_heads, mask=True) - # softmax - attn = attn_softmax(attn) - # dropout - attn = attn_dropout(attn, self.dropout_p) # N h L L -> N h L L - # attn = torch.nn.functional.dropout(attn, self.dropout_p, True, False) # N h L L -> N h L L - # context - context = attn_context(attn, v) - # DenseOutput - # output = torch.nn.functional.linear(context, self.out_proj, self.out_bias) # L N (h d), E E -> L N E - output = attn_dense_out(context, self.out_proj, self.out_bias) # L N (h d), E E -> L N E - return output - - -class MLP(torch.nn.Module): - - def __init__(self, embed_dim: int, hidden_dim: int, dropout: float): - super().__init__() - self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) - self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) - self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) - self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) - self.dropout = dropout - - def forward(self, x: torch.Tensor): - # L N E, inner E -> L N inner - x = mlp_linear1(x, self.proj1, self.proj1_bias) - # L N inner -> L N inner - x = torch.nn.functional.gelu(x) - x = mlp_linear2(x, self.proj2, self.proj2_bias) - x = torch.nn.functional.dropout(x, self.dropout, True, False) - return x - - -class EncoderLayer(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, - attn_hidden_dim: int, ffn_hidden_dim: int, - dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): - super().__init__() - self.self_attn = MultiHeadSelfAttention( - embed_dim, num_heads, attn_hidden_dim, atten_dropout - ) - self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) - self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - # warnings.warn('residual is disabled in encoder block') - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - x = self.mlp(x) - x = x + residual - return x diff --git a/examples/gsearch/gpt/model.py b/examples/gsearch/gpt/model.py deleted file mode 100644 index 862c8766..00000000 --- a/examples/gsearch/gpt/model.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch - - -from examples.gsearch.blocks import EncoderLayer - -import cube - - -class Config: - - num_embeddings = 50304 - seqlen = 1024 - - # 1.7B model - embed_dim = 2304 - layers = 8 # 24 - attention_heads = 24 - - # 3.6B model - # embed_dim = 3072 - # layers = 32 - # attention_heads = 32 - - # 7.5B model - # embed_dim = 4096 - # layers = 32 - # attention_heads = 36 - - attn_hidden_dim = embed_dim - ffn_hidden_dim = embed_dim * 4 - dropout = 0.0 - attn_dropout = 0.0 - activation_dropout = 0.0 - - -class GPT(torch.nn.Module): - - def __init__(self): - super().__init__() - cfg = Config() - - self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.embed_dim)) - # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) - self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) - self.embed_dropout = torch.nn.Dropout() - - self.layers = torch.nn.ModuleList( - [EncoderLayer( - cfg.embed_dim, cfg.attention_heads, - cfg.attn_hidden_dim, cfg.ffn_hidden_dim, - cfg.dropout, cfg.attn_dropout, cfg.activation_dropout - ) for _ in range(cfg.layers)] - ) - self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) - - def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): - - # embed = self.embed(input_ids) - embed = torch.nn.functional.embedding( - input_ids, self.embedw, padding_idx=None, - max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False - ) - pos_embed = self.position(position_ids) - embed = embed + pos_embed - embed = self.embed_dropout(embed) - enc = embed.transpose(0, 1) - - for layer in self.layers: - enc = layer(enc) - enc = self.final_layernorm(enc) - - # logits = torch.nn.functional.linear(enc, self.embed.weight) - logits = torch.nn.functional.linear(enc, self.embedw) - # simplified - loss = torch.sum(logits) - return loss - - -class GPTDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int): - - self.bs = batch_size - self.cfg = Config() - super().__init__( - shapes=([batch_size, self.cfg.seqlen], - [batch_size, self.cfg.seqlen], - ), - dtypes=(torch.int64, torch.int64), - batch_dims=(0, 0) - ) - self.samples = [self.random_sample()] - - def random_sample(self): - input_ids = torch.randint( - 0, self.cfg.num_embeddings, - size=(self.bs, self.cfg.seqlen), - dtype=torch.int64, device=torch.cuda.current_device() - ) - position_ids = torch.arange( - 0, self.cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() - ).repeat(self.bs).view((self.bs, -1)) - return (input_ids, position_ids) - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] \ No newline at end of file diff --git a/examples/gsearch/gpt/policy/mpmd.py b/examples/gsearch/gpt/policy/mpmd.py deleted file mode 100644 index 0b4b08a6..00000000 --- a/examples/gsearch/gpt/policy/mpmd.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import List, Tuple -import numpy as np - -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.schedule.sched1f1b import IRSchedule1F1B - - -def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: - """ - Create hybrid (nested) groups given the each group number. - - The product of group_num should be same with total devices. - """ - group_num = np.array(group_num) - cnt = np.prod(group_num) - assert cnt == ngpus, 'total device not match' - grid = np.arange(cnt).reshape(tuple(group_num)) - dims = list(range(len(group_num))) - outputs = [] - for dim, num in enumerate(group_num): - remain = ngpus // num - order = tuple(dims[:dim] + dims[dim+1:] + [dim]) - grid_dim = np.transpose(grid, order).reshape((remain,num)) - grid_dim = grid_dim.tolist() - outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) - assert len(outputs) == len(group_num) - return tuple(outputs) - - -def PASRoundRobin(graph: IRGraph, resource): - """ - 1F1B scheduling - """ - - def org_transformer_layer(graph: IRGraph) -> List[List[IRFwOperation]]: - multiref_idx = [ - fidx for fidx, node in enumerate(graph.nodes()) if \ - isinstance(node, IRFwOperation) and node.name == 'multiref' - ] - assert len(multiref_idx) % 2 == 0, "un-recognized transormer structure" - transformers = [] - last_fidx = [fidx for fidx, node in enumerate(graph.nodes()) if isinstance(node, IRFwOperation)][-1] - for idx in range(0, len(multiref_idx), 2): - graph.nodes()[multiref_idx[idx]].comment = f'===> start of transformer {idx // 2}' - start = multiref_idx[idx] if idx != 0 else 0 - end = multiref_idx[idx+2] if idx+2 < len(multiref_idx) else last_fidx+1 - transformers.append(graph.nodes()[start:end]) - return transformers - - transformers = org_transformer_layer(graph) - for lid, fnodes in enumerate(transformers): - stage_id = lid % resource.ngpus - print(f'assigning {lid}-th transformer layter to stage {stage_id}') - for fnode in fnodes: - graph.assign(fnode, stage_id) - - for node in graph.nodes(): - if len(node.device) == 0: - graph.assign(node, 0) - - return graph - - -def PAS1F1B(graph: IRGraph, resource): - """ - 1F1B scheduling - """ - num_stage = resource.ngpus - num_microbatch = resource.ngpus - - _, stage_mesh = _create_mesh(resource.ngpus, (num_stage, 1)) - - def org_transformer_layer(graph: IRGraph) -> List[List[IRFwOperation]]: - multiref_idx = [ - fidx for fidx, node in enumerate(graph.nodes()) if \ - isinstance(node, IRFwOperation) and node.name == 'multiref' - ] - assert len(multiref_idx) % 2 == 0, "un-recognized transormer structure" - transformers = [] - last_fidx = [fidx for fidx, node in enumerate(graph.nodes()) if isinstance(node, IRFwOperation)][-1] - for idx in range(0, len(multiref_idx), 2): - graph.nodes()[multiref_idx[idx]].comment = f'===> start of transformer {idx // 2}' - start = multiref_idx[idx] if idx != 0 else 0 - end = multiref_idx[idx+2] if idx+2 < len(multiref_idx) else last_fidx+1 - transformers.append(graph.nodes()[start:end]) - return transformers - - transformers = org_transformer_layer(graph) - for lid, fnodes in enumerate(transformers): - stage_id = min(lid // (len(transformers) // resource.ngpus), num_stage-1) - print(f'assigning {lid}-th transformer layter to stage {stage_id}') - for fnode in fnodes: - graph.assign(fnode, stage_id) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - graph.assign(node, 0) - - schedule = IRSchedule1F1B(num_microbatch, stage_mesh, recompute=False) - graph.schedule_plan = schedule - return graph \ No newline at end of file diff --git a/examples/gsearch/gpt/policy/spmd.py b/examples/gsearch/gpt/policy/spmd.py deleted file mode 100644 index c8cff30b..00000000 --- a/examples/gsearch/gpt/policy/spmd.py +++ /dev/null @@ -1,100 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation - - -def PASReplica(graph: IRGraph, resource): - """ - Single device test - """ - assert resource.ngpus == 1 - print(graph.extra_repr()) - for node in graph.nodes(): - if isinstance(node, (IRDataOperation, IRFwOperation)): - graph.assign(node, 0) - # print(graph.extra_repr()) - return graph - - -def PASMegatronTP(graph: IRGraph, resource): - """ - Megatron tensor parallelism (attention) - """ - tp_size = resource.ngpus - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - def tensor_parallelism(node: IRFwOperation, comment: str = None, **configs): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - if isinstance(comment, str): - for sub_node in sub_nodes: - sub_node.comment = comment - assert all(isinstance(n, IRFwOperation) for n in sub_nodes), f"Fail to partition node {node}" - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return sub_nodes - - # annotating code structure -- not consider multiref on embedding weight - multirefs = [node for node in graph.nodes() if isinstance(node, IRFwOperation) and node.name == 'multiref'][1:] - for idx in range(0, len(multirefs), 2): - multirefs[idx].comment = f'====> start of transformer {idx // 2}' - - # ============ Attention =============== - qkvs = [node for node in fnodes if node.name == 'attn_qkv'] - for idx, qkv in enumerate(qkvs): - tensor_parallelism(qkv, idx=1, dim=0, num=tp_size) - - scores = [node for node in fnodes if node.name == 'attn_score'] - for score in scores: - tensor_parallelism(score, idx=0, dim=1, num=tp_size) - - softmaxs = [node for node in fnodes if node.name == 'attn_softmax'] - for softmax in softmaxs: - tensor_parallelism(softmax, idx=0, dim=1, num=tp_size) - - dropouts = [node for node in fnodes if node.name == 'attn_dropout'] - for dropout in dropouts: - tensor_parallelism(dropout, idx=0, dim=1, num=tp_size) - - contexts = [node for node in fnodes if node.name == 'attn_context'] - for context in contexts: - tensor_parallelism(context, idx=0, dim=1, num=tp_size) - - dense_outs = [node for node in fnodes if node.name == 'attn_dense_out'] - for dense in dense_outs: - tensor_parallelism(dense, idx=0, dim=2, num=tp_size) - - # ============= MLP =================== - linear1s = [node for node in fnodes if node.name == 'mlp_linear1'] - for mlp_linear1 in linear1s: - tensor_parallelism(mlp_linear1, idx=1, dim=0, num=tp_size) - - gelus = [node for node in fnodes if node.name == 'gelu'] - for gelu in gelus: - tensor_parallelism(gelu, idx=0, dim=2, num=tp_size) - - linear2s = [node for node in fnodes if node.name == 'mlp_linear2'] - for mlp_linear2 in linear2s: - tensor_parallelism(mlp_linear2, idx=0, dim=2, num=tp_size) - - # replicate others - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: - rnodes = graph.replicate(node, times=tp_size) - for idx, rnode in enumerate(rnodes): - graph.assign(rnode, idx) - # print(graph.extra_repr()) - return graph - - -def PASRecompute(graph: IRGraph, resource): - """ - Recompute parallelism test - """ - assert resource.ngpus == 1 - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - graph.recompute(fnodes) - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - diff --git a/examples/gsearch/gpt/train.py b/examples/gsearch/gpt/train.py deleted file mode 100644 index 8206def4..00000000 --- a/examples/gsearch/gpt/train.py +++ /dev/null @@ -1,93 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/gsearch/gpt/train.py --policy PASMegatronTP -""" - - -import torch - -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary, model_summary - -from examples.gsearch.gpt.model import GPT -from examples.gsearch.gpt.model import GPTDataLoader - -import examples.gsearch.gpt.policy.spmd as spmd -import examples.gsearch.gpt.policy.mpmd as mpmd - -import argparse -parser = argparse.ArgumentParser(description='comm primitive') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -args = parser.parse_args() - -cube.init() - -# set up policy -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - - - -def train(): - - batch_size = 1 - - model = GPT() - dataloader = GPTDataLoader(batch_size) - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - model = cube.SemanticModel(model, dataloader.shapes) - @cube.compile(model, dataloader, PAS=PAS, override=True) - def train_iter(model, dataloader): - input_ids, position_ids = next(dataloader) - loss = model(input_ids, position_ids) - loss.backward() - model = model.get_gen_module() - - print_each_rank('model weight consumpition:') - memory_summary() - - CudaTimer(enable=False).warmup() - iter_num = 64 - for step in range(iter_num): - - # if step == 0: - # model_summary(model, next(dataloader)) - - if step >= 20: - CudaTimer(enable=True).start('e2e') - - # training - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - - if step >= 20: - CudaTimer().stop('e2e') - - if step == 0: - print_each_rank('passed first iteration') - - if (step + 1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-40) - memory_summary() - - -train() \ No newline at end of file diff --git a/examples/inspector.py b/examples/inspector.py deleted file mode 100644 index 4c6ceeea..00000000 --- a/examples/inspector.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Directly loading generated file for training - -python -m torch.distributed.launch \ - --nproc_per_node=2 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/inspector.py - -OMP_NUM_THREADS=4 torchrun --standalone \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/inspector.py -""" -import torch -import argparse -import time - -import cube -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - -# gpt -# L, N, E = (512, 8, 3072) -# kBatchDims = (0, 0) -# kDataShapes = ([N, L], [N, L]) -# kDTypes = (torch.float, torch.long) - -# mlp -kBatchDims = (0,) -kDataShapes = ([8192, 8192],) -kDTypes = (torch.float,) - -# transformer -# kBatchDims = (1, ) -# kDataShapes = ([512, 4, 3072],) -# kDTypes = (torch.float,) - - -def load_module(filename: str): - import importlib.util - rank = torch.distributed.get_rank() - print(f'> [{rank}] loading generated spatial moduel from {filename}') - spec = importlib.util.spec_from_file_location("GenModel", filename) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - loaded_module = module.GenModel().cuda() - # sync parameters before start training - loaded_module.sync_params() - return loaded_module - - -def load_train_fn(filename: str): - import importlib.util - rank = torch.distributed.get_rank() - print(f'> [{rank}] loading generated schedule from {filename} ...') - spec = importlib.util.spec_from_file_location( - "_train_step", filename - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module._train_step - - -def train(args): - global kDataShapes - global kDTypes - global kBatchDims - dataloader = cube.runtime.syndata.SynDataLoader( - kDataShapes, kDTypes, kBatchDims - ) - - genfile = args.genfile.format(rank=torch.distributed.get_rank()) - model = load_module(genfile) - train_fn = load_train_fn(genfile) - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - iter_num = args.iter_num - - def train_iters(): - for step in range(iter_num): - if step >= 40: - CudaTimer(enable=True).start('e2e') - train_fn(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('passed 1 iteration') - if step >= 40: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - time.sleep(0.05) - - if args.profile: - with torch.profiler.profile() as prof: - train_iters() - prof.export_chrome_trace(f"trace{torch.distributed.get_rank()}.json") - else: - train_iters() - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-40, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-40) - memory_summary() - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='inspect') - parser.add_argument('--genfile', type=str, - default='gencode{rank}.py') - parser.add_argument('--iter-num', type=int, - default=128) - parser.add_argument('--profile', dest='profile', action='store_true', - help='use edge://tracing/ or chrome://tracing/ to open the file') - args = parser.parse_args() - - cube.init() - train(args) diff --git a/examples/mlp/policy/mpmd.py b/examples/mlp/policy/mpmd.py index ed754407..49aaa10a 100644 --- a/examples/mlp/policy/mpmd.py +++ b/examples/mlp/policy/mpmd.py @@ -4,7 +4,7 @@ from cube.graph.graph import IRGraph from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.schedule.sched1f1b import IRSchedule1F1B +from cube.graph.schedule.predefined import PredefinedSched def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: @@ -86,6 +86,7 @@ def PAS1F1B(graph: IRGraph, resource): for devid, rnode in zip(mesh, rnodes): graph.assign(rnode, devid) # setup schedule to 1F1B - schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) - graph.schedule_plan = schedule + # schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) + # graph.schedule_plan = schedule + schedule = PredefinedSched.sched_1f1b(graph, num_microbatch, num_stage) return graph \ No newline at end of file diff --git a/examples/mlp/policy/search.py b/examples/mlp/policy/search.py deleted file mode 100644 index ef1fa13c..00000000 --- a/examples/mlp/policy/search.py +++ /dev/null @@ -1,133 +0,0 @@ - -from typing import Dict, List -import time -from itertools import combinations - -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation -import cube.search.iterator as iterator -from cube.profiler.estimator import Estimator - -import numpy as np - - -def get_plan(graph: IRGraph, fnode: IRFwOperation, configs: List[Dict]) -> List[IRFwOperation]: - - all_nodes = [fnode] - for config in configs: - sub_nodes = list() - for node in all_nodes: - algo = node.algorithms('dim') - sub = graph.partition(node, algo, config) - if sub is None: - sub = graph.replicate(node, times=config['num']) - fnode.tag = ('rep', 'rep') - sub_nodes += sub - all_nodes = sub_nodes - fnode.tag = tuple('{}-{}'.format(config['name'], config['num']) for config in configs) - return all_nodes - - -def compositions(graph: IRGraph, fnode: IRFwOperation, nest: List[int]) -> List[IRFwOperation]: - """" - e.g., - fnode: linear - nest: [2, 4] - will get 9 partition strategies of 8-nodes - """ - all_configs = [ - dict(idx=0, dim=0, name='dat'), # data parallel - dict(idx=0, dim=1, name='row'), # row parallel - dict(idx=1, dim=0, name='col'), # col parallel - ] - config_iter = combinations(all_configs, len(nest)) - for configs in config_iter: - for config, ndev in zip(configs, nest): - config['num'] = ndev - nodes = get_plan(graph, fnode, configs) - yield nodes - graph.merge(nodes, fnode) - fnode.tag = None - - -def sequence(graph: IRGraph, fnodes: IRFwOperation, resource): - - nest_depth = 2 - nests = iterator.factorization(resource.ngpus, nest_depth) - - if len(fnodes) == 0: - yield list() - - for fnode in fnodes: - for nest in nests: - for seq in compositions(graph, fnode, nest): - for idx, node in enumerate(seq): - graph.assign(node, idx) - for remain in sequence(graph, fnodes[1:], resource): - yield seq + remain - - -def comm_estimate(graph: IRGraph, ndevice: int) -> int: - """ - Estimate communications - """ - estimator = Estimator(graph) - total_volume = 0 - for devid in range(ndevice): - total_volume += estimator.comm_volume(devid) - return total_volume - - -def PAS(graph: IRGraph, resource): - - # replicate data operation - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - - # replicate loss operation - fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] - loss = fnodes[-1] - sub_nodes = graph.replicate(loss, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - - # search for linear operations - start = time.time() - - fnodes = fnodes[:-1] # only search linears - seqs = list() - comms = list() - plans = list() - for idx, seq in enumerate(sequence(graph, fnodes, resource)): - print(f'searching index: {idx}...') - seqs.append(seq) - comm = comm_estimate(graph, resource.ngpus) - comms.append(comm) - plan = [node.tag for node in fnodes] - plans.append(plan) - print(f'comm volume param#: {comm}') - # for node in fnodes: - # print(node.tag) - # print(graph.extra_repr()) - print(f'==> grid search done on {idx+1} plans') - print(f'\n\n') - - comms = np.array(comms) - indices = np.argsort(comms) - - top_indices = indices[:10] - top_plan = [plans[idx] for idx in top_indices] - top_comm = [comms[idx] for idx in top_indices] - for top_idx, (idx, plan, comm) in enumerate(zip(top_indices, top_plan, top_comm)): - print(f'top {top_idx} (plan index {idx}):') - for lid, node in enumerate(plan): - print(f'linear{lid}: {node}') - print(f'===> comm param#: {comm}') - - end = time.time() - print('grid search time: {:.2f}'.format(end-start)) - - raise NotImplementedError diff --git a/examples/mlp/policy/spmd.py b/examples/mlp/policy/spmd.py index db541ad2..eb2b502d 100644 --- a/examples/mlp/policy/spmd.py +++ b/examples/mlp/policy/spmd.py @@ -1,5 +1,53 @@ +from typing import List from cube.graph import IRGraph +from cube.graph.segment import IRSegment from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.gener.rvd.intra import IntraAutoPlacer + + +# tensor parallelism with auto-placer +# This is an implementation example of SPMD auto placer usage +def _tp_autoplace(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): + + if len(devs) == 1: + graph.assign(node, devs[0]) + return [node] + + segment: IRSegment = graph.segment(node) + ftensor = node.input(configs['idx']).parent + + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + producers = segment.producers(ftensor) + if ftensor.is_param() or len(producers) != len(sub_nodes): + print(f"> skip auto placer due to condition not matched: " + f"nproducers: {len(producers)}, nconsumers: {len(sub_nodes)}, " f"producer name: {producers[0].name if len(producers) > 0 else None}") + devs = sorted(list(devs)) + for devid, node in zip(devs, sub_nodes): + graph.assign(node, devid) + else: + devices = IntraAutoPlacer.auto_place( + segment, ftensor, producers, sub_nodes) + for devid, subnode in zip(devices, sub_nodes): + graph.assign(subnode, devid) + return sub_nodes + +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes def PASSingle(graph: IRGraph, resource): @@ -59,20 +107,19 @@ def PASRow(graph: IRGraph, resource): """ Linear Column Parallel """ - linears = [node for node in graph.nodes() if node.name == 'linear'] - for idx, node in enumerate(linears): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=1, dim=1, num=resource.ngpus - ) + devs = list(range(resource.ngpus)) + + for dl in graph.select(ntype=IRDataOperation): + sub_nodes = graph.replicate(dl, resource.ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - sub_nodes = graph.replicate(node, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) + + for node in graph.select(ntype=IRFwOperation): + if node.name == 'linear': + _tp(graph, node, devs, idx=0, dim=1, num=len(devs)) + else: + _replica(graph, node, devs) + return graph diff --git a/examples/mlp/policy/st_search.py b/examples/mlp/policy/st_search.py deleted file mode 100644 index 3aa140cf..00000000 --- a/examples/mlp/policy/st_search.py +++ /dev/null @@ -1,80 +0,0 @@ -from functools import partial -from typing import List -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRCell -from cube.execplan import ExecutionPlan - -from cube.search.sampler import Estimator, Sampler, SpatialSampler, TemporalSampler, Searcher - - -class MicroBatchView: - - @staticmethod - def node2stage(node: IRCell, fnodes: List[IRCell], n_stage: int): - num_fnodes = len(fnodes) - idx = fnodes.index(node) - stage = min(idx // (num_fnodes // n_stage), n_stage - 1) - return stage - - @staticmethod - def split(graph: IRGraph, n_microbatch: int) -> List[IRCell]: - """ - Split graph into micro-batch view - """ - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - micro_seqs = [list() for _ in range(n_microbatch)] - for node in fnodes: - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, config=dict(idx=0, dim=0, num=n_microbatch)) - for mid, sub_node in enumerate(sub_nodes): - micro_seqs[mid].append(sub_node) - for mid in range(n_microbatch): - micro_seqs[mid] = micro_seqs[mid] + [n.mirror for n in micro_seqs[mid][::-1]] - return micro_seqs - - @staticmethod - def flatten(micro_seqs: List[List[IRCell]]): - flatten_nodes = list() - for seq in micro_seqs: - flatten_nodes += seq - return flatten_nodes - - -def PAS(graph: IRGraph, resource): - print(graph.extra_repr()) - - # n_microbatch, n_stage, n_device - M, S, D = 4, 4, 4 - - # memory limits - wlimits = 2 - alimits = 4 - - micro_seqs = MicroBatchView.split(graph, M) - assert len(micro_seqs) == M and len(micro_seqs[0]) // 2 == S - sgraph = IRGraph(MicroBatchView.flatten(micro_seqs), [], [], 'search') - Estimator.taging(sgraph) - - n_worker, seq_per_worker = 32, 512 - tsampler = partial(TemporalSampler.btemporal, bs=n_worker*seq_per_worker) - ssampler = partial(SpatialSampler.othogonal, wlimits=wlimits) - - bucket = dict() - cnt = 0 - for seqs in Sampler.sample(micro_seqs, M, S, D, ssampler, tsampler, wlimits, alimits): - Searcher.search(seqs, bucket, n_worker=n_worker) - for mem, (span, seq) in bucket.items(): - sgraph._nodes = seq - execplan = ExecutionPlan(sgraph) - execplan.analyze(map2time=Estimator.map2time, outfile=f'plan.mem{mem}.png') - cnt += len(seqs) - print(f'done search on {cnt} sequences') - assert False - - -if __name__ == '__main__': - for idx, placement in enumerate(Sampler.spatial(3, 3, 3)): - print(placement) - print(f'total {idx + 1} seqs') diff --git a/tests/test_grid.py b/tests/test_grid.py deleted file mode 100644 index fb0164e3..00000000 --- a/tests/test_grid.py +++ /dev/null @@ -1,19 +0,0 @@ -from cube.graph.gener.layout import GridLayout -from cube.ir.tensor import IRFullTensor - -def test_grid(): - - tensor = IRFullTensor(shape=[8192,8192], name='src') - - src = GridLayout.grid(tensor, r=2, v=2, dims=[1, 1]) - dst = GridLayout.grid(tensor, r=2, v=1, dims=[2, 1]) - - path, prims = src.path(dst) - for grid in path: - print(grid) - for prim in prims: - print(prim) - - -if __name__ == '__main__': - test_grid() From 508021041248daa01ac0ed336c76deafa5f0af41 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 25 Feb 2023 08:44:23 +0000 Subject: [PATCH 1256/1892] Merged PR 1455: Asynchronized communication support To enable it, specify ASYNC_COMM=1 --- cube/__init__.py | 31 -- cube/codegen/emit.py | 24 +- cube/flags.py | 1 + cube/ir/adapter/prim.py | 22 +- cube/runtime/adapter/collectives.py | 400 +++++++++------------- cube/runtime/executor.py | 117 +++---- cube/runtime/schedule/strategy.py | 49 ++- examples/mlp/infer.py | 104 ++++++ examples/mlp/policy/mpmd.py | 57 +-- examples/mlp/policy/spmd.py | 2 +- examples/mlp/{linears.py => train.py} | 8 +- tests/runtime/test_runtime_collectives.py | 231 +++++++++++++ tests/test_examples.sh | 19 +- 13 files changed, 657 insertions(+), 408 deletions(-) create mode 100644 examples/mlp/infer.py rename examples/mlp/{linears.py => train.py} (96%) create mode 100644 tests/runtime/test_runtime_collectives.py diff --git a/cube/__init__.py b/cube/__init__.py index 7c1e3188..d2c5bbd7 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -19,34 +19,3 @@ def init(): _check_torch_version() - - -# ================== Experimental Feature ======================= - -# import threading - -# _message_context = None - -# def handle_request(): -# manager = runtime.executor.MessageManager() -# while True: -# req = manager.pull() -# if isinstance(req, int): -# break -# req.wait() - -# def init_manager(): -# global _message_context -# _ = runtime.executor.MessageManager() -# _message_context = threading.Thread(target=handle_request) -# _message_context.start() - - -# def finish_manager(): -# """ -# Clear message manager -# """ -# global _message_context -# manager = runtime.executor.MessageManager() -# manager.push(-1) -# _message_context.join() diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 50139b3c..73047db7 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -1,15 +1,11 @@ -from typing import Dict, Generator, Iterable, List, Any, Optional, Set, Tuple, Union -from more_itertools import split_when - +from typing import Generator, Iterable, List, Any, Optional, Tuple from cube.ir.cten import IRCell, IRTensor from cube.ir.dtype import IRDType from cube.ir.tensor import IRSubTensor -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from cube.ir.operator import IRDataOperation, IRFwOperation from cube.ir.adapter import IRWeightReducer, IRAdapter - -from cube.graph.graph import IRSegment -from cube.execplan.execplan import ExeReuseCell +from cube.ir.adapter.prim import IRAdapterPrim from cube.codegen.frontend_mapping import Sign2EmitRule @@ -159,12 +155,24 @@ def emit_adapter(node: IRAdapter, prefix_attr: Optional[str] = None) -> List[str assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" prims = [node] if node.differentiable and node.custom else [prim for prim in node.prims] + # only adapter that is non-differentiable can be executed as async + async_op = CompileFlag.async_comm and (not node.differentiable) + for idx, prim in enumerate(prims): + if isinstance(prim, IRAdapterPrim) and prim.volume() == 0: + continue + break + #TODO: support more general cases: independent same-group primitives + async_op = False if len(prims[idx:]) != 1 else async_op + for prim in prims: if len(prim.inputs()) == 1: itensors = FuncEmission.tensor_name(prim.inputs()[0], prefix_attr=prefix_attr) else: itensors = FuncEmission.tuple_name(prim.inputs(), prefix_attr=prefix_attr) - kwargs = FuncEmission.kwargs_name(**prim.kwargs) + prim_kwargs = dict(prim.kwargs) + if async_op: + prim_kwargs['async_op'] = True + kwargs = FuncEmission.kwargs_name(**prim_kwargs) outputs = FuncEmission.return_name(prim.outputs()) code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' codes.append(code) diff --git a/cube/flags.py b/cube/flags.py index bd8a54c5..b8f4d511 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -36,6 +36,7 @@ class CompileFlag: # ============== runtime ==================== dev_mode = _to_bool('SINGLE_DEV_MODE') # allow to use python xx.py + async_comm = _to_bool('ASYNC_COMM') # maximal reducer weight bytes for one allreduce max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=5e8) diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index d70b6f3a..0dab9656 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -65,6 +65,9 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor], **kwar super().__init__(inputs, outputs, **kwargs) self.device = list(set(t.device[0] for t in inputs)) + def volume(self) -> int: + return 0 + # numerical abstract primitive class ValuePrim(IRAdapterPrim): @@ -75,6 +78,9 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): super().__init__(inputs, outputs) self.device = list(set(t.device[0] for t in inputs)) + def volume(self) -> int: + return 0 + # communication abstract primitive class CommPrim(IRAdapterPrim): @@ -117,9 +123,6 @@ def __repr__(self): dscp = f"{self.output(0)} = identity({self.input(0)})" return dscp - def volume(self) -> int: - return 0 - class SelectPrim(SpatialPrim): @@ -137,9 +140,6 @@ def __repr__(self): dscp = f"{self.output(0)} = select({self.input(0)}, indmap={self.kwargs['indmap']}, valmap={self.kwargs['valmap']})" return dscp - def volume(self) -> int: - return 0 - class MergeDimPrim(SpatialPrim): """ @@ -153,9 +153,6 @@ def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, dim: int) def __repr__(self) -> str: return f"dev{self.device}: {self.output(0)} = concat({self.inputs()}, dim={self.kwargs['dim']})" - def volume(self) -> int: - return 0 - # numerical primitive class SumPrim(ValuePrim): @@ -165,9 +162,6 @@ def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): super().__init__(itensors, [otensor]) self.signature = 'cube.runtime.adapter.vmerge' - def volume(self) -> int: - return 0 - def __repr__(self) -> str: return f"dev{self.device}: {self.output(0)} = add({self.inputs()})" @@ -190,7 +184,9 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k self.signature = 'cube.runtime.adapter.move' def volume(self) -> int: - return self.input(0).nelement() + if len(self._inputs) > 0: + return self.input(0).nelement() + return self.output(0).nelement() def __repr__(self): dscp = f"{self.outputs()} = move{self.device}({self.inputs()}, src={self.kwargs['src']}, dst={self.kwargs['dst']})" diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 75c80fd2..54e0a1d2 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -4,190 +4,129 @@ from cube.runtime.device import DeviceGroup from cube.profiler.timer import CudaTimer, print_each_rank - -def send(tensor: torch.Tensor, dst: int): - """ - send tensor to the remote devices. Each tensor can be - sent to multiple devices - - Args: - tensors (List[torch.Tensor]): list of tensor to send - tensor_devices (List[List[int]]): tensor sent devices - """ - # print(f'{torch.distributed.get_rank()}: sending...') - CudaTimer().start(field_name='comm', predefined=True) - - send_ops = list() - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, dst - ) - send_ops.append(send_op) - reqs = torch.distributed.batch_isend_irecv(send_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) - return tensor - - -def recv(tensors: List[torch.Tensor], shape: List[int], dtype: torch.dtype, src: int): - # print(f'{torch.distributed.get_rank()}: recving...') - CudaTimer().start(field_name='comm', predefined=True) - ## synthetic ## - # for shape in shapes: - # recv_tensors.append( - # torch.ones(tuple(shape), - # device=torch.cuda.current_device() - # )) - # - tensor = torch.empty( - shape, requires_grad=True, dtype=dtype, - device=torch.cuda.current_device() - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, src - ) - reqs = torch.distributed.batch_isend_irecv([recv_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) - return tensor - - -# def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int): -# rank = torch.distributed.get_rank() -# if rank == src: -# assert torch.is_tensor(tensor) -# return send(tensor, dst) -# else: -# assert rank == dst -# return recv(None, shape, dtype, src) +from cube.runtime.executor import AsyncCommHandler -def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int): +def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int, async_op=False): """ Move a tensor from source device to destination device. """ - CudaTimer().start(field_name='comm', predefined=True) + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() + work = None if rank == src: tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor assert torch.is_tensor(tensor) - torch.distributed.send(tensor, dst) + if async_op: + work = torch.distributed.isend(tensor, dst) + # NOTE: we don't add isend work item into handler + else: + torch.distributed.send(tensor, dst) else: assert rank == dst tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device() ) - torch.distributed.recv(tensor, src) - CudaTimer().stop(field_name='comm', predefined=True) + if async_op: + work = torch.distributed.irecv(tensor, src) + AsyncCommHandler().submit(tensor, [work]) + else: + torch.distributed.recv(tensor, src) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) return tensor - -def sendrecv(input_tensors: List[torch.Tensor], - output_shapes: List[List[int]], - output_dtypes: List[torch.dtype], - send_ranks: List[int], - recv_ranks: List[int]) -> List[torch.Tensor]: - CudaTimer().start(field_name='comm', predefined=True) - # print('sending and recving...') - ops = list() - outputs = list() - for tensor, rank in zip(input_tensors, send_ranks): - if not torch.is_tensor(tensor): - raise RuntimeError(f"Expected {tensor} to be tensor") - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, rank - ) - ops.append(send_op) - for shape, dtype, rank in zip(output_shapes, output_dtypes, recv_ranks): - tensor = torch.empty( - shape, dtype=dtype, - requires_grad=True, device=torch.cuda.current_device() - ) - outputs.append(tensor) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, rank - ) - ops.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) - return outputs - - -def all_reduce(itensor: torch.Tensor, - ranks: List[int]) -> torch.Tensor: +def all_reduce(tensor: torch.Tensor, + ranks: List[int], async_op=False) -> torch.Tensor: """ Allreduce """ - CudaTimer().start(field_name='comm', predefined=True) - if not itensor.is_contiguous(): - itensor = itensor.contiguous() - itensor = itensor.detach() + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + tensor = tensor.detach() group = DeviceGroup().get_group(ranks) - torch.distributed.all_reduce(itensor, group=group) - CudaTimer().stop(field_name='comm', predefined=True) - return itensor + if async_op: + work = torch.distributed.all_reduce(tensor, group=group, async_op=True) + AsyncCommHandler().submit(tensor, [work]) + else: + torch.distributed.all_reduce(tensor, group=group) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) + return tensor -def all_gather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: + +def all_gather(tensor: torch.Tensor, dim: int, + ranks: Tuple[int], async_op=False) -> torch.Tensor: """ Allgather """ - CudaTimer().start(field_name='comm', predefined=True) - if not itensor.is_contiguous(): - itensor = itensor.contiguous() + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor group = DeviceGroup().get_group(ranks) - tensor_list = [torch.empty_like(itensor) for _ in ranks] - tensor_list[torch.distributed.get_rank(group)] = itensor.data - torch.distributed.all_gather(tensor_list, itensor, group=group) - # concat - otensor = torch.concat(tuple(tensor_list), dim=dim) - CudaTimer().stop(field_name='comm', predefined=True) + tensor_list = [torch.empty_like(tensor) for _ in ranks] + tensor_list[torch.distributed.get_rank(group)] = tensor.data + work = torch.distributed.all_gather(tensor_list, tensor, group=group, async_op=async_op) + if work: + allgather_callback = lambda t: torch.concat(tuple(tensor_list), dim=dim) + AsyncCommHandler().submit(tensor, [work], allgather_callback) + otensor = tensor + else: + otensor = torch.concat(tuple(tensor_list), dim=dim) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) return otensor -def reduce_scatter(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: +def reduce_scatter(tensor: torch.Tensor, dim: int, + ranks: Tuple[int], async_op=False) -> torch.Tensor: """ ReduceScatter """ - CudaTimer().start(field_name='comm', predefined=True) - itensors = list(itensor.chunk(len(ranks), dim)) - for idx, tensor in enumerate(itensors): - if not tensor.is_contiguous(): - itensors[idx] = tensor.contiguous() + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) + itensors = list(tensor.chunk(len(ranks), dim)) + for idx, t in enumerate(itensors): + itensors[idx] = t.contiguous() if not t.is_contiguous() else t group = DeviceGroup().get_group(ranks) otensor = torch.empty_like(itensors[0], requires_grad=False) - torch.distributed.reduce_scatter(otensor, itensors, group=group) - CudaTimer().stop(field_name='comm', predefined=True) + work = torch.distributed.reduce_scatter(otensor, itensors, group=group, async_op=async_op) + if work: + AsyncCommHandler().submit(otensor, [work]) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) return otensor -def all_to_all(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: +def all_to_all(tensor: torch.Tensor, idim: int, odim: int, + ranks: Tuple[int], async_op=False) -> torch.Tensor: """ All-to-all """ - CudaTimer().start(field_name='comm', predefined=True) - itensors = list(itensor.chunk(len(ranks), dim=odim)) - for idx, tensor in enumerate(itensors): - if not tensor.is_contiguous(): - itensors[idx] = tensor.contiguous() + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) + itensors = list(tensor.chunk(len(ranks), dim=odim)) + for idx, itensor in enumerate(itensors): + itensors[idx] = itensor.contiguous() if not itensor.is_contiguous() else itensor otensors = [torch.empty_like(t) for t in itensors] group = DeviceGroup().get_group(ranks) - torch.distributed.all_to_all(otensors, itensors, group=group) - otensor = torch.concat(tuple(otensors), dim=idim) - CudaTimer().stop(field_name='comm', predefined=True) + work = torch.distributed.all_to_all(otensors, itensors, group=group, async_op=async_op) + if work: + all2all_callback = lambda t: torch.concat(tuple(otensors), dim=idim) + AsyncCommHandler().submit(tensor, [work], all2all_callback) + otensor = tensor + else: + otensor = torch.concat(tuple(otensors), dim=idim) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) return otensor -def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: +def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int], async_op=False) -> torch.Tensor: """ split dimension in n chunks and take idx-th chunk @@ -205,19 +144,23 @@ def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: def rdscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, - dim: int, src: int, dsts: Tuple[int]): + dim: int, src: int, dsts: Tuple[int], async_op=False): """ RDScatter: split itensor at rank `src` along dim into `len(dsts)` chunks, and then send each chunk to `dst` devices. """ - CudaTimer().start(field_name='comm', predefined=True) + if async_op: + CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() if rank == src: with torch.no_grad(): otensors = itensor.chunk(len(dsts), dim) for dst, otensor in zip(dsts, otensors): otensor = otensor.contiguous() if not otensor.is_contiguous() else otensor - torch.distributed.send(otensor, dst) + if async_op: + torch.distributed.isend(otensor, dst) + else: + torch.distributed.send(otensor, dst) otensor = itensor else: assert rank in dsts @@ -227,72 +170,101 @@ def rdscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, shape, requires_grad=True, dtype=dtype, device=torch.cuda.current_device() ) - torch.distributed.recv(otensor, src) - CudaTimer().stop(field_name='comm', predefined=True) + if async_op: + work = torch.distributed.irecv(otensor, src) + AsyncCommHandler().submit(otensor, [work]) + else: + torch.distributed.recv(otensor, src) + if async_op: + CudaTimer().stop(field_name='comm', predefined=True) return otensor def rvscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, - dim: int, src: int, dsts: Tuple[int]): + src: int, dsts: Tuple[int], async_op=False): """ src: global rank """ - CudaTimer().start(field_name='comm', predefined=True) + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) group = DeviceGroup().get_group((src,) + dsts) rank = torch.distributed.get_rank() tensor: torch.Tensor = itensor / len(dsts) if src == rank else \ torch.empty(shape, dtype=dtype, requires_grad = True) tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor - torch.distributed.broadcast(tensor, src, group=group) - CudaTimer().stop(field_name='comm', predefined=True) + work = torch.distributed.broadcast(tensor, src, group=group, async_op=async_op) + if work: + AsyncCommHandler().submit(tensor, [work]) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) return tensor def rdgather(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, - dim: int, srcs: Tuple[int], dst: int): + dim: int, srcs: Tuple[int], dst: int, async_op=False): """ @param srcs Tuple[int]: global rank of each source device @param dst int: global rank of destination device """ - CudaTimer().start(field_name='comm', predefined=True) + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() if rank == dst: - recv_tensors = [] + recv_tensors, works = [], [] for src in srcs: tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device()) - torch.distributed.recv(tensor, src) recv_tensors.append(tensor) - otensor = torch.cat(tuple(recv_tensors), dim=dim) - otensor = otensor.requires_grad_() + if async_op: + work = torch.distributed.irecv(tensor, src) + works.append(work) + else: + work = torch.distributed.recv(tensor, src) + + if async_op: + rdgather_callback = lambda t: torch.cat(tuple(recv_tensors), dim=dim) + AsyncCommHandler().submit(itensor, works, rdgather_callback) + otensor = itensor + else: + otensor = torch.cat(tuple(recv_tensors), dim=dim) + otensor = otensor.requires_grad_() else: assert rank in srcs otensor = itensor.contiguous() if not itensor.is_contiguous() else itensor - torch.distributed.send(otensor, dst) - CudaTimer().stop(field_name='comm', predefined=True) + if async_op: + torch.distributed.isend(otensor, dst) + else: + torch.distributed.send(otensor, dst) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) return otensor def rvgather(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, - srcs: Tuple[int], dst: int): + srcs: Tuple[int], dst: int, async_op=False): """ @param srcs Tuple[int]: global rank of each source device @param dst int: global rank of destination device """ - CudaTimer().start(field_name='comm', predefined=True) + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() group = DeviceGroup().get_group(srcs + (dst,)) tensor = torch.zeros(shape, dtype=dtype, requires_grad=True) if rank == dst else itensor - torch.distributed.reduce(tensor, dst, group=group) - CudaTimer().stop(field_name='comm', predefined=True) + work = torch.distributed.reduce(tensor, dst, group=group, async_op=async_op) + if work and rank == dst: + AsyncCommHandler().submit(tensor, [work]) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) return tensor -def broadcast(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, src: int, ranks: List[int]) -> torch.Tensor: +def broadcast(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, src: int, ranks: List[int], async_op=False) -> torch.Tensor: """ Broadcast @param src: the global rank that holds tensor for broadcasting """ - CudaTimer().start(field_name='comm', predefined=True) + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() group = DeviceGroup().get_group(ranks) if rank == src: @@ -301,86 +273,38 @@ def broadcast(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, src: assert rank in ranks tensor = torch.empty(shape, device=torch.cuda.current_device(), requires_grad=True, dtype=dtype) - torch.distributed.broadcast(tensor, src, group=group) - CudaTimer().stop(field_name='comm', predefined=True) + work = torch.distributed.broadcast(tensor, src, group=group, async_op=async_op) + if work and rank != src: + AsyncCommHandler().submit(tensor, [work]) + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) return tensor -def gather(input_tensors: List[torch.Tensor], - output_shapes: List[List[int]], - output_dtypes: List[torch.dtype], - ranks: List[int]) -> List[torch.Tensor]: +def exchange(tensor: torch.Tensor, ranks: List[int], async_op=False) -> torch.Tensor: """ - Gather. ranks[0] is the root + Exchange a same-shaped tensor between two ranks """ - CudaTimer().start(field_name='comm', predefined=True) - assert len(input_tensors) == 1 - input_tensor = input_tensors[0] - dst = ranks[0] - if DeviceGroup().rank == dst: - # recv - tensor_list = [input_tensor] + [torch.empty_like(input_tensor) for _ in range(len(ranks)-1)] - ops = list() - for rank, tensor in zip(ranks[1:], tensor_list[1:]): - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, rank - ) - ops.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - else: - # send - tensor_list = [] - send_op = torch.distributed.P2POp( - torch.distributed.isend, input_tensor, ranks[0] - ) - reqs = torch.distributed.batch_isend_irecv([send_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) - return tensor_list - - -def scatter(input_tensors: List[torch.Tensor], - output_shapes: List[List[int]], - output_dtypes: List[torch.dtype], - ranks: List[int]) -> List[torch.Tensor]: - CudaTimer().start(field_name='comm', predefined=True) - output = None - src = ranks[0] - if DeviceGroup().rank == src: - # send - ops = list() - for rank, tensor in zip(ranks, input_tensors): - if rank == src: - output = tensor - else: - if not tensor.is_contiguous(): - with torch.no_grad(): - tensor = tensor.contiguous() - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, rank - ) - ops.append(send_op) - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() + + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) + + assert len(ranks) == 2 + group = DeviceGroup().get_group(ranks) + myrank = torch.distributed.get_rank(group) + + tensor_list = [tensor, torch.empty_like(tensor)] if myrank == 0 \ + else [torch.empty_like(tensor), tensor.data] + + work = torch.distributed.all_gather(tensor_list, tensor, group=group, async_op=async_op) + if work: + exchange_callback = lambda t: tensor_list[(myrank + 1) % 2] + AsyncCommHandler().submit(tensor, [work], exchange_callback) + otensor = tensor else: - # recv - assert len(output_shapes) == 1 and len(output_dtypes) == 1 - output = torch.empty( - output_shapes[0], dtype=output_dtypes[0], - requires_grad=True, device=torch.cuda.current_device() - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, output, src - ) - reqs = torch.distributed.batch_isend_irecv([recv_op]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) - return output - \ No newline at end of file + otensor = tensor_list[(myrank + 1) % 2] + + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) + + return otensor diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 5c5e0df1..0247854c 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -3,23 +3,10 @@ """ import atexit -from typing import Tuple, Any, Callable, List, Dict +from typing import Tuple, Any, Callable, List, Dict, Optional import torch import warnings -from cube.flags import CompileFlag - - -if CompileFlag.use_amp: - warnings.warn( - "Detected auto mixed precision (AMP) is enabled. It's an " - "experimental feature that is only for benchmark. " - "torch.cdua.amp.GradScalerr is not enabled for loss " - "and optimizer, which may lead to gradient loss. The tensors " - "and dtypes arguments in adapter will be automatically converted to " - "torch.float16, if they are in float32 precision or torch.float32 dtype." - ) - def debug_id(tensors, msg: str, rank: int): if torch.distributed.get_rank() == rank: @@ -29,17 +16,51 @@ def debug_id(tensors, msg: str, rank: int): print(f'[{torch.distributed.get_rank()}] {msg}: {[id(t) for t in tensors]}') -def convert_fp32_to_fp16(t: Any): - """ - A tensor with float32 will be converted to float16. - A dtype of torch.float32 will be returned as torch.float16 - """ - if isinstance(t, torch.dtype) and t == torch.float32: - t = torch.float16 - elif torch.is_tensor(t) and t.dtype == torch.float32: - with torch.no_grad(): - t = t.half() - return t +class AsyncCommHandler: + + class __AsyncCommHandler: + def __init__(self): + self._works: Dict[int, List] = {} + self._callbacks: Dict[int, Callable] = {} + + instance = None + + def __init__(self) -> None: + if not AsyncCommHandler.instance: + AsyncCommHandler.instance = AsyncCommHandler.__AsyncCommHandler() + + def __getattr__(self, name): + return getattr(self.instance, name) + + def wait(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Wait until the finish of the communication + + @param tensor torch.Tensor + @return tensor torch.Tensor + """ + if id(tensor) not in self._works: + return tensor + works = self._works.pop(id(tensor)) + for work in works: + work.wait() + callback = self._callbacks.pop(id(tensor)) + if callback is not None: + tensor = callback(tensor) + return tensor + + def submit(self, tensor: torch.Tensor, works: List, callback: Optional[Callable] = None): + """ + Submit an async communication + """ + self._works[id(tensor)] = works + self._callbacks[id(tensor)] = callback + + def clear(self): + AsyncCommHandler.instance = AsyncCommHandler.__AsyncCommHandler() + + def check_clear(self): + assert len(self._works) == 0 and len(self._callbacks) == 0 TensorPairs = List[Tuple[int, torch.Tensor]] @@ -59,6 +80,8 @@ def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires """ forward the sub-graph. """ + input_tensors = Executor.sync_tensors(input_tensors) + if not requires_grad: with torch.no_grad(): outputs = subgraph(*input_tensors) @@ -82,9 +105,6 @@ def aexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) """ execute adapter """ - if CompileFlag.use_amp: - input_tensors = tuple(convert_fp32_to_fp16(t) for t in input_tensors) - if not requires_grad: with torch.no_grad(): outputs = subgraph(*input_tensors) @@ -119,6 +139,8 @@ def backward(name: str, @return gradients List[torch.Tensor]: gradient tensors corresponding to input_tensors. """ + output_tensor_grads = Executor.sync_tensors(output_tensor_grads) + if len(output_tensors) == 0: return None saved_pairs = Executor._detach[name].pop(0) @@ -145,6 +167,13 @@ def backward(name: str, elif len(grads) == 1: return grads[0] else: return grads + @staticmethod + def sync_tensors(tensors: List[Any]) -> List[Any]: + """ + Wait until the finish of synchornized tensors + """ + return [AsyncCommHandler().wait(t) if torch.is_tensor(t) else t for t in tensors] + @staticmethod def clear(): Executor._detach = dict() @@ -160,35 +189,7 @@ def check_clear(): aexecute = Executor.aexecute backward = Executor.backward + # register checking for normal exit atexit.register(Executor.check_clear) - - -### =================== Experimental Feature ======================= - -# import queue -# -# -# class MessageManager: -# """ -# message manager to make send as async calls. -# """ -# -# class __MessageManager: -# def __init__(self): -# self._reqs = queue.Queue(maxsize=128) -# -# instance = None -# -# def __init__(self): -# if not MessageManager.instance: -# MessageManager.instance = MessageManager.__MessageManager() -# -# def __getattr__(self, name): -# return getattr(self.instance, name) -# -# def push(self, req): -# self.instance._reqs.put(req, block=True, timeout=None) -# -# def pull(self): -# return self.instance._reqs.get(block=True, timeout=None) +atexit.register(AsyncCommHandler().check_clear) diff --git a/cube/runtime/schedule/strategy.py b/cube/runtime/schedule/strategy.py index d151565d..581c702c 100644 --- a/cube/runtime/schedule/strategy.py +++ b/cube/runtime/schedule/strategy.py @@ -1,22 +1,9 @@ from typing import Any, Callable, Dict, Iterable, List import torch -from cube.profiler.timer import CudaTimer - +from cube.runtime.executor import AsyncCommHandler from cube.flags import CompileFlag - - -def convert_fp32_to_fp16(t: Any): - """ - A tensor with float32 will be converted to float16. - A dtype of torch.float32 will be returned as torch.float16 - """ - if isinstance(t, torch.dtype) and t == torch.float32: - t = torch.float16 - elif torch.is_tensor(t) and t.dtype == torch.float32: - with torch.no_grad(): - t = t.half() - return t +from cube.profiler.timer import CudaTimer class ScheduleABC: @@ -28,10 +15,12 @@ def forward_step(segment: Callable, *args, **kwargs): """ forward pass """ - CudaTimer().start('forward') - with torch.autocast('cuda', torch.float16, enabled=CompileFlag.use_amp): - outputs = segment(*args, **kwargs) - CudaTimer().stop('forward') + args = ScheduleABC.sync_tensors(args) + if not CompileFlag.async_comm: + CudaTimer().start('forward') + outputs = segment(*args, **kwargs) + if not CompileFlag.async_comm: + CudaTimer().stop('forward') if not isinstance(outputs, tuple): outputs = (outputs,) return outputs @@ -43,14 +32,17 @@ def backward_step(itensors: List[torch.Tensor], """ backward pass """ + otensor_grads = ScheduleABC.sync_tensors(otensor_grads) for tensor in itensors: if torch.is_tensor(tensor) and tensor.requires_grad: tensor.retain_grad() - CudaTimer().start("backward") + if not CompileFlag.async_comm: + CudaTimer().start("backward") otensors = [t for t in otensors if t.requires_grad] assert len(otensors) == len(otensor_grads), f"output tensor mismatches with gradient number" torch.autograd.backward(otensors, grad_tensors=otensor_grads) - CudaTimer().stop("backward") + if not CompileFlag.async_comm: + CudaTimer().stop("backward") itensor_grads = [] for tensor in itensors: if torch.is_tensor(tensor) and tensor.requires_grad: @@ -75,11 +67,11 @@ def adapter_step(adapter: Callable, require_grad : bool = True, *args): if adapter is None: return (None,) # if adapter is None: return () args = tuple(t for t in args if torch.is_tensor(t)) - if CompileFlag.use_amp: - args = tuple(convert_fp32_to_fp16(t) for t in args) - CudaTimer().start('adapter') + if not CompileFlag.async_comm: + CudaTimer().start('adapter') outputs = adapter(*args) - CudaTimer().stop('adapter') + if not CompileFlag.async_comm: + CudaTimer().stop('adapter') if not isinstance(outputs, tuple): outputs = (outputs,) if require_grad: @@ -128,6 +120,13 @@ def pop_tail(name: str): if len(ScheduleABC.status[name]) == 0: del ScheduleABC.status return out + + @staticmethod + def sync_tensors(tensors: List[Any]) -> List[Any]: + """ + Wait until the finish of synchornized tensors + """ + return [AsyncCommHandler().wait(t) if torch.is_tensor(t) else t for t in tensors] @staticmethod def assert_empty(): diff --git a/examples/mlp/infer.py b/examples/mlp/infer.py new file mode 100644 index 00000000..6871a553 --- /dev/null +++ b/examples/mlp/infer.py @@ -0,0 +1,104 @@ +""" +example: + +ASYNC_COMM=1 OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/infer.py --policy PASMegatron +""" + +import torch +from torch import nn +import time + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + +import examples.mlp.policy.spmd as spmd +import examples.mlp.policy.mpmd as mpmd + +import argparse + +parser = argparse.ArgumentParser(description='MLP example') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +args = parser.parse_args() + +cube.init() + +# set up policy +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + + +# =================== Semantic Model Description ==================== + +class MLP(nn.Module): + def __init__(self, dim, mult=1, nlayers=4): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for lid in range(nlayers): + if lid % 2 == 0: + self.layers.append(nn.Linear(dim, dim * mult, bias=False)) + else: + self.layers.append(nn.Linear(dim * mult, dim, bias=False)) + + def forward(self, data): + x = data + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) + return loss + + +def infer(): + batch_size = 128 + dim = 4096 + + model = MLP(dim=dim) + model = cube.SemanticModel( + model, input_shapes=([batch_size, dim],), + ) + + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([batch_size, dim],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) + + @cube.compile(model, dataloader, PAS=PAS) + def infer_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + model = model.get_gen_module() + + CudaTimer(enable=False).warmup() + if torch.distributed.is_initialized(): + torch.distributed.barrier() + iter_num = 16 + warmup = 4 + for step in range(iter_num): + if step >= warmup: + CudaTimer(enable=True, predefined=True).start('e2e') + infer_iter(model, dataloader) + if step >= warmup: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + torch.distributed.barrier() + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) + + +infer() \ No newline at end of file diff --git a/examples/mlp/policy/mpmd.py b/examples/mlp/policy/mpmd.py index 49aaa10a..bf00f3ee 100644 --- a/examples/mlp/policy/mpmd.py +++ b/examples/mlp/policy/mpmd.py @@ -3,6 +3,7 @@ import numpy as np from cube.graph.graph import IRGraph +from cube.graph.segment import IRSegment from cube.ir.operator import IRDataOperation, IRFwOperation from cube.graph.schedule.predefined import PredefinedSched @@ -51,42 +52,52 @@ def PASRandom(graph, resource): return graph -def PAS1F1B(graph: IRGraph, resource): +def PASMegatron(graph: IRGraph, resource): # assert resource.ngpus == 8, "should apply on 8 gpus" num_stage = 4 num_tp = resource.ngpus // num_stage - num_microbatch = resource.ngpus + num_microbatch = resource.ngpus * 8 _, tp_mesh = _create_mesh(resource.ngpus, (num_stage, num_tp)) print(f'> pipeline-tensor parallel group: {tp_mesh}') assert len(tp_mesh) == num_stage - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - node2stage = lambda node: min(fnodes.index(node) // (len(fnodes) // num_stage), num_stage-1) + linears = graph.select('linear') + stage_start_nodes = linears[::len(linears) // num_stage] + stage_start_nodes = stage_start_nodes[:num_stage] + assert len(stage_start_nodes) == num_stage, f"{len(stage_start_nodes)} != {num_stage}" + graph.staging(stage_start_nodes) - for idx, node in enumerate(fnodes): + segments = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + assert len(fsegs) == num_stage + + for sid, segment in enumerate(fsegs): # get tensor parallel group - sid = node2stage(node) tp_group = tp_mesh[sid] - # partition - if node.name == 'linear': - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=num_tp) - else: - tp_nodes = graph.replicate(node, times=num_tp) - # assign - for devid, node in zip(tp_group, tp_nodes): - graph.assign(node, devid) + for idx, node in enumerate(segment.nodes()): + # partition + if node.name == 'linear': + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx % 2, num=num_tp) + else: + tp_nodes = graph.replicate(node, times=num_tp) + # assign + for devid, node in zip(tp_group, tp_nodes): + graph.assign(node, devid) - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - mesh = tp_mesh[0] - rnodes = graph.replicate(node, times=num_tp) - for devid, rnode in zip(mesh, rnodes): - graph.assign(rnode, devid) + for dl in graph.select(ntype=IRDataOperation): + mesh = tp_mesh[0] + dls = graph.replicate(dl, times=num_tp) + for devid, dl in zip(mesh, dls): + graph.assign(dl, devid) + # setup schedule to 1F1B # schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) # graph.schedule_plan = schedule - schedule = PredefinedSched.sched_1f1b(graph, num_microbatch, num_stage) - return graph \ No newline at end of file + if graph.train: + schedule = PredefinedSched.sched_1f1b(graph, num_microbatch, num_stage) + else: + schedule = PredefinedSched.sched_infer_pipe(graph, num_microbatch, num_stage) + return graph diff --git a/examples/mlp/policy/spmd.py b/examples/mlp/policy/spmd.py index eb2b502d..e3bae1a4 100644 --- a/examples/mlp/policy/spmd.py +++ b/examples/mlp/policy/spmd.py @@ -143,7 +143,7 @@ def PASHybrid(graph: IRGraph, resource): return graph -def PASMegatron(graph: IRGraph, resource): +def PASMegatronTP(graph: IRGraph, resource): """ Tensor + Data Parallelism """ diff --git a/examples/mlp/linears.py b/examples/mlp/train.py similarity index 96% rename from examples/mlp/linears.py rename to examples/mlp/train.py index 05eb5f51..2d58eee9 100644 --- a/examples/mlp/linears.py +++ b/examples/mlp/train.py @@ -19,7 +19,7 @@ import argparse -parser = argparse.ArgumentParser(description='comm primitive') +parser = argparse.ArgumentParser(description='MLP example') parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') args = parser.parse_args() @@ -60,7 +60,7 @@ def forward(self, data): def train(): batch_size = 128 - dim = 8192 + dim = 4096 model = MLP(dim=dim) model = cube.SemanticModel( @@ -85,8 +85,8 @@ def train_iter(model, dataloader): CudaTimer(enable=False).warmup() if torch.distributed.is_initialized(): torch.distributed.barrier() - iter_num = 32 - warmup = 8 + iter_num = 16 + warmup = 4 for step in range(iter_num): if step >= warmup: CudaTimer(enable=True).start('e2e') diff --git a/tests/runtime/test_runtime_collectives.py b/tests/runtime/test_runtime_collectives.py new file mode 100644 index 00000000..c90cd773 --- /dev/null +++ b/tests/runtime/test_runtime_collectives.py @@ -0,0 +1,231 @@ +""" +intra-primitives: + torchrun --nproc_per_node=2 tests/runtime/test_runtime_collectives.py + +inter-primitives: + torchrun --nproc_per_node=3 tests/runtime/test_runtime_collectives.py +""" + +from typing import List + +import cube +import torch + + + +cube.init() + +mydevice = torch.cuda.current_device() +myrank = torch.distributed.get_rank() +ndevices = torch.distributed.get_world_size() + + +def _get_tensor(shape: List[int], dtype: torch.dtype = torch.float32, rank=myrank) -> torch.Tensor: + global mydevice, myrank + tensor = torch.ones(shape, dtype=dtype, device=mydevice) + tensor = tensor * rank + return tensor + + +def test_runtime_move(): + assert ndevices == 2 + shape = [128, 256] + + # synchronize + tensor = _get_tensor(shape) + res = _get_tensor(shape, rank=0) + tensor = cube.runtime.adapter.move(tensor, shape, torch.float32, 0, 1) + + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + # asynchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.move(tensor, shape, torch.float32, 0, 1, async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass move') + + +def test_runtime_allreduce(): + assert ndevices == 2 + shape = [128, 256] + + # synchronize + tensor = _get_tensor(shape) + cube.runtime.adapter.all_reduce(tensor, [0, 1]) + res = _get_tensor(shape, rank=0) + _get_tensor(shape, rank=1) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + # asynchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.all_reduce(tensor, [0, 1], async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass allreduce') + + +def test_runtime_allgather(): + assert ndevices == 2 + shape = [128, 256] + + # synchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.all_gather(tensor, 0, [0, 1]) + res = torch.concat([_get_tensor(shape, rank=0), _get_tensor(shape, rank=1)], dim=0) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + # asynchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.all_gather(tensor, 0, [0, 1], async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass allgather') + + +def test_runtime_reduce_scatter(): + assert ndevices == 2 + shape = [128, 256] + + tensor = _get_tensor(shape) + res = _get_tensor(shape, rank=0) + _get_tensor(shape, rank=1) + res = res.chunk(2, dim=0)[myrank] + + # synchronize + tensor = cube.runtime.adapter.reduce_scatter(tensor, 0, [0, 1]) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + # asynchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.reduce_scatter(tensor, 0, [0, 1], async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass reduce scatter') + + +def test_runtime_all2all(): + assert ndevices == 2 + shape = [128, 256] + + tensor = _get_tensor(shape) + res = torch.concat([_get_tensor(shape, rank=0), _get_tensor(shape, rank=1)], dim=0) + res = res.chunk(2, dim=1)[myrank] + + # synchronize + tensor = cube.runtime.adapter.all_to_all(tensor, 0, 1, [0, 1]) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + # asynchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.all_to_all(tensor, 0, 1, [0, 1], async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass all2all') + + +def test_runtime_exchange(): + assert ndevices == 2 + shape = [128, 256] + + tensor = _get_tensor(shape) + res = _get_tensor(shape, rank=(myrank + 1) % 2) + + tensor = cube.runtime.adapter.exchange(tensor, [0, 1]) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.exchange(tensor, [0, 1], async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass exchange') + + +def test_runtime_rdscatter(): + assert ndevices == 3 + shape = [128, 256] + + tensor = _get_tensor(shape) + res = _get_tensor(shape, rank=0).chunk(ndevices-1, dim=0)[myrank-1] + + # synchronize + tensor = cube.runtime.adapter.rdscatter( + tensor, shape, torch.float32, dim=0, src=0, dsts=[1,2]) + if myrank > 0: + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + # synchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.rdscatter( + tensor, shape, torch.float32, dim=0, src=0, dsts=[1,2], async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + if myrank > 0: + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass rdscatter') + + +def test_runtime_rdgather(): + assert ndevices == 3 + shape = [128, 256] + + tensor = _get_tensor(shape) + res = torch.cat((_get_tensor(shape, rank=1), _get_tensor(shape, rank=2)), dim=0) + + # synchronize + tensor = cube.runtime.adapter.rdgather( + tensor, shape, torch.float32, dim=0, srcs=[1,2], dst=0) + if myrank == 0: + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + # asynchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.rdgather( + tensor, shape, torch.float32, dim=0, srcs=[1,2], dst=0, async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + if myrank == 0: + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass rdgather') + + +def test_runtime_broadcast(): + assert ndevices == 3 + shape = [128, 256] + + tensor = _get_tensor(shape) + res = _get_tensor(shape, rank=0) + + # synchronize + tensor = cube.runtime.adapter.broadcast( + tensor, shape, torch.float32, src=0, ranks=[0,1,2]) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + # asynchronize + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.broadcast( + tensor, shape, torch.float32, src=0, ranks=[0,1,2], async_op=True) + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" + + print(f'rank[{myrank}]: pass broadcast') + + +if __name__ == '__main__': + + if ndevices == 2: + test_runtime_move() + test_runtime_allreduce() + test_runtime_allgather() + test_runtime_reduce_scatter() + test_runtime_all2all() + test_runtime_exchange() + + if ndevices == 3: + test_runtime_rdscatter() + test_runtime_rdgather() + test_runtime_broadcast() diff --git a/tests/test_examples.sh b/tests/test_examples.sh index c69c616d..771e13e7 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -5,37 +5,42 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=1 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASSingle + examples/mlp/train.py --policy PASSingle OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASData + examples/mlp/train.py --policy PASData OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASCol + examples/mlp/train.py --policy PASCol OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASRow + examples/mlp/train.py --policy PASRow OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASHybrid + examples/mlp/train.py --policy PASHybrid OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASMegatron + examples/mlp/train.py --policy PASMegatronTP OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASOptimal + examples/mlp/train.py --policy PASOptimal + +ASYNC_COMM=1 OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/infer.py --policy PASMegatron # test GPT model From 17ee9c5264cb533f1054a0c17dfe4ba9f24b8dd6 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 25 Feb 2023 16:46:10 +0800 Subject: [PATCH 1257/1892] support recording time on a specfic stream --- cube/profiler/timer.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index 0c244680..d731bee4 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -73,33 +73,54 @@ def __init__(self, enable: Optional[bool] = None, predefined: Optional[bool] = N if predefined is not None: self.instance.predefined = predefined - def start(self, field_name='default', predefined: bool = False): + def start(self, field_name='default', predefined: bool = False, stream: Optional[torch.cuda.Stream] = None): """ Start recording time on the the field Note `start` and `stop` on the same field can be called nestly + + @param field_name str + @param is_predefined bool: whether the field is a predefined field + @param stream Optional[torch.cuda.Stream]: + if None (default), will synchronize all streams on the device before + recording time. Otherwise, only synchronize the specified stream. + + @return None """ if (not self.instance.enabled) or (predefined and not self.instance.predefined): return - torch.cuda.synchronize() + if stream is None: + torch.cuda.synchronize() + else: + stream.synchronize() + # torch.cuda.default_stream().synchronize() start_time = time.time() if field_name not in self.instance.field: self.instance.field[field_name] = list() self.instance.field_data[field_name] = 0 self.instance.field[field_name].append(start_time) - def stop(self, field_name='default', predefined: bool = False): + def stop(self, field_name='default', predefined: bool = False, stream: Optional[torch.cuda.Stream] = None) -> float: """ - Return the time span from last `start` on the smae field name to now + Record the time span from last `start` on the same field_name to now - Returns: - float (ms) + @param field_name str + @param is_predefined bool: whether the field is a predefined field + @param stream Optional[torch.cuda.Stream]: + if None (default), will synchronize all streams on the device before + recording time. Otherwise, only synchronize the specified stream. + + @return None """ if (not self.instance.enabled) or (predefined and not self.instance.predefined): return if field_name not in self.instance.field: raise RuntimeError("Missing start on the field") - torch.cuda.synchronize() + if stream is None: + torch.cuda.synchronize() + else: + stream.synchronize() + # torch.cuda.default_stream().synchronize() stop_time = time.time() start_time = self.instance.field[field_name].pop(-1) span = stop_time - start_time # in seconds From 001f99f507e35b6601a3b785c4c0628d9630ff63 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 25 Feb 2023 20:35:27 +0800 Subject: [PATCH 1258/1892] refine memory summary --- cube/profiler/memory.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index cf6ffcad..89719625 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -2,14 +2,9 @@ import torch from cube.profiler.timer import print_each_rank -def memory_summary(): - import os - single_device_mode = os.environ.get('SINGLE_DEV_MODE') - if single_device_mode: - rank = 0 - else: - rank = torch.distributed.get_rank() +def memory_summary(): + torch.cuda.synchronize() # memory measurement mem = torch.cuda.max_memory_allocated() # mem = torch.cuda.max_memory_reserved() From d09d2843328478b683bf19922e7a65be6213249f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 27 Feb 2023 18:05:10 +0800 Subject: [PATCH 1259/1892] staging optimize: put multiref to the next rank --- cube/graph/graph.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5c556df2..3c0cd597 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -692,6 +692,14 @@ def staging(self, nodes: Tuple[IRFwOperation]): starts = tuple(self._nodes.index(node) for node in nodes) assert len(starts) > 0 + # multiref (created by graph.auto_multiref) will be moved to the next stage (if possible) for optimization + for sid in range(len(starts)): + while starts[sid] > 0: + if self.node(starts[sid]-1).name == 'multiref': + starts[sid] -= 1 + continue + break + # adjust the start of the first stage to involve beginning operators for idx in range(starts[0]): node = self.node(idx) From 0499acc7814e44ec262d146c892321a87a22fd0c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 28 Feb 2023 13:26:42 +0800 Subject: [PATCH 1260/1892] add validation check --- cube/execplan/execplan.py | 20 ++--- cube/graph/schedule/schedplan.py | 139 ++++++++++++++++++++----------- 2 files changed, 101 insertions(+), 58 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index aed05aa8..2fcfa13c 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -161,22 +161,22 @@ def get(tensor: IRSubTensor, micro_idx: int) -> IRSubTensor: micro_fcells: Dict[(int, IRCell), ExeReuseCell] = {} def block2reuse(node: Block) -> ExeReuseCell: - if node.blk.isfw(): - key = (node.mid, node.blk) + if node.content.isfw(): + key = (node.mid, node.content) if key in micro_fcells: return micro_fcells[key] - inputs = [get(t, node.mid) for t in node.blk.inputs()] - outputs = [get(t, node.mid) for t in node.blk.outputs()] - cell = ExeReuseCell(node.blk, inputs, outputs) - if isinstance(node.blk.mirror, IRCell): - minputs = [get(t, node.mid) for t in node.blk.mirror.inputs()] - moutputs = [get(t, node.mid) for t in node.blk.mirror.outputs()] - mcell = ExeReuseCell(node.blk.mirror, minputs, moutputs) + inputs = [get(t, node.mid) for t in node.content.inputs()] + outputs = [get(t, node.mid) for t in node.content.outputs()] + cell = ExeReuseCell(node.content, inputs, outputs) + if isinstance(node.content.mirror, IRCell): + minputs = [get(t, node.mid) for t in node.content.mirror.inputs()] + moutputs = [get(t, node.mid) for t in node.content.mirror.outputs()] + mcell = ExeReuseCell(node.content.mirror, minputs, moutputs) IRCell.make_pair(cell, mcell) micro_fcells[key] = cell return cell else: - mcell = block2reuse(Block(node.blk.mirror, node.mid)) + mcell = block2reuse(Block(node.content.mirror, node.mid)) return mcell.mirror topo_seqs: List[IRCell] = [] diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py index faab3237..dea5e7ac 100644 --- a/cube/graph/schedule/schedplan.py +++ b/cube/graph/schedule/schedplan.py @@ -20,34 +20,34 @@ def __init__(self, cell: IRCell, micro_batch_id: int) -> None: """ """ assert isinstance(cell, IRCell), f"Expected IRCell, but got {type(cell)}: {cell}" - self._block: IRCell = cell + self._content: IRCell = cell self._micro_batch_id: int = micro_batch_id def __eq__(self, other): if isinstance(other, Block): - return other.blk == self.blk and other.mid == self.mid + return other.content == self.content and other.mid == self.mid return False def __hash__(self) -> int: - return hash((self._block, self._micro_batch_id)) + return hash((self._content, self._micro_batch_id)) @property def device(self) -> Tuple[int]: - return tuple(self._block.device) + return tuple(self._content.device) @property def mid(self) -> int: return self._micro_batch_id @property - def blk(self) -> IRCell: - return self._block + def content(self) -> IRCell: + return self._content def dispatch(self, devid: int): - return Block(self._block.dispatch(devid), self._micro_batch_id) + return Block(self._content.dispatch(devid), self._micro_batch_id) def __repr__(self) -> str: - return f'Block({self._micro_batch_id})-{self.device} : {self._block}' + return f"{self._content.cid}{'f' if self.content.isfw() else 'b'}{self._micro_batch_id}" class ScheduleDependency: @@ -88,17 +88,19 @@ def build(self): self.reducers = self.graph.select(ntype=IRWeightReducer, flatten=False) def depend(self, prev: Block, next: Block) -> bool: - return prev.mid == next.mid and self.graph.depends(prev.blk, next.blk) + return prev.mid == next.mid and self.graph.depends(prev.content, next.content) class PlanBase: def __init__(self, graph: IRGraph, _dependency: Optional[ScheduleDependency] = None): self._graph: IRGraph = graph - self._step_devs: List[Set[int]] = [] + self._segments: List[Block] = [] self._step_segments: List[List[Block]] = [] + self._step_devices: List[Set[int]] = [] # adapters executed *after* the segments on that step self._step_adapters: List[List[Block]] = [] + self._block_step: Dict[Block, int] = {} self._dependency = _dependency if _dependency is not None \ else ScheduleDependency(graph) @@ -114,6 +116,13 @@ def nsteps(self) -> int: def graph(self) -> IRGraph: return self._graph + @property + def device(self) -> Tuple[int]: + device = set() + for devs in self._step_devices: + device.update(devs) + return tuple(device) + def nodes(self) -> Tuple[Block]: return tuple(self._seqs) @@ -124,11 +133,13 @@ def add_segment(self, seg: IRSegment, micro_batch_id: int, step: int) -> Block: self._extend_step(step) if len(self._step_segments[step]) == 1 and isinstance(self._step_segments[0], PlanBase): assert False, "Cannot add an IRSegment into a step that already has Repetend." - assert all(devid not in self._step_devs[step] for devid in seg.device), \ - f"A step cannot execute multiple segments on a same device" + assert all(devid not in self._step_devices for devid in seg.device), \ + f"A device cannot execute multiple segments on a same step" block = Block(seg, micro_batch_id) self._step_segments[step].append(block) - self._step_devs[step].update(seg.device) + self._step_devices[step].update(seg.device) + self._block_step[block] = step + self._segments.append(block) return block def segments(self, step: int) -> Tuple[Block]: @@ -138,14 +149,16 @@ def segments(self, step: int) -> Tuple[Block]: assert step < self.nsteps return tuple(self._step_segments[step]) + def step(self, block: Block) -> int: + """Get the step of the block + """ + return self._block_step[block] + def all_segments(self) -> Tuple[Block]: """ Get all segment blocks """ - blocks = [] - for step in range(self.nsteps): - blocks += self._step_segments[step] - return tuple(blocks) + return tuple(self._segments) def _extend_step(self, step: int): """ @@ -154,7 +167,7 @@ def _extend_step(self, step: int): if len(self._step_segments) <= step: nextend = step - len(self._step_segments) + 1 self._step_segments += [[] for _ in range(nextend)] - self._step_devs += [set() for _ in range(nextend)] + self._step_devices += [set() for _ in range(nextend)] self._step_adapters += [[] for _ in range(nextend)] def _place_dataloader(self): @@ -167,7 +180,7 @@ def _place_dataloader(self): for step, blocks in enumerate(self._step_segments): for block in blocks: if isinstance(block, Block): - segment, mid = block.blk, block.mid + segment, mid = block.content, block.mid if self.graph.depends(dl, segment): self._step_segments[step].insert(0, Block(dl, mid)) break @@ -206,17 +219,10 @@ def __init__(self, graph: IRGraph, dependency: ScheduleDependency, devices = set() for block in blocks: devices.update(block.device) - self._step_devs[step] = devices + self._step_devices[step] = devices # the adapters that will be performed outside the repetend self._post_adapters: List[Block] = [] - @property - def device(self) -> Tuple[int]: - device = set() - for devs in self._step_devs: - device.update(devs) - return tuple(device) - @property def span(self) -> int: return self._span @@ -237,19 +243,19 @@ def _place_adapters(self): cnts: Dict[IRSegment, int] = {} for step in range(self.nsteps): for blk in self.segments(step): - cnts.setdefault(blk.blk, 0) - cnts[blk.blk] += 1 + cnts.setdefault(blk.content, 0) + cnts[blk.content] += 1 extended_blocks = [] for step in range(self.nsteps): for blk in self.segments(step): - extend_blk = Block(blk.blk, blk.mid + cnts[blk.blk]) + extend_blk = Block(blk.content, blk.mid + cnts[blk.content]) extended_blocks.append(extend_blk) # step2: generate adapters for each step all_blocks = self.all_segments() for adapter, sender in self._dependency.senders.items(): for step in range(self.nsteps): for block in self.segments(step): - if block.blk != sender: continue + if block.content != sender: continue # sender adapter can be classified into three categories # 1) its recver are in the same repetend # 2) its recver are in neighbored repetend @@ -263,7 +269,7 @@ def _place_adapters(self): self._step_adapters[step].append(ablock) # case 2) elif rblock in extended_blocks: - self._step_adapters[self.nsteps-1].append(Block(adapter, block.mid - cnts[blk.blk])) + self._step_adapters[self.nsteps-1].append(Block(adapter, block.mid - cnts[blk.content])) self._post_adapters.append(ablock) # case 3) else: @@ -311,13 +317,6 @@ def nmicros(self) -> int: Get number of micro-batches """ return self._num_microbatches - - @property - def device(self) -> Tuple[int]: - devs = set() - for node in self._seqs: - devs.update(node.device) - return tuple(devs) @property def graph(self) -> IRGraph: @@ -353,7 +352,7 @@ def finish(self): """ Check whether the description contains full micro-batches """ - pass + assert self.validate(), f"The schedule plan is not valid." def apply(self): """ @@ -373,6 +372,19 @@ def apply(self): # step 4: generate topological sequence self.topo_sort() + def validate(self) -> bool: + """ + Validate the plan to check if it satisfies data dependency + + @return valid bool + """ + for block1 in self._segments: + for block2 in self._segments: + if self._dependency.depend(block1, block2): + if self.step(block1) >= self.step(block2): + return False + return True + def _place_adapters(self): """ Place adapters to make sure the communication happens @@ -381,7 +393,6 @@ def _place_adapters(self): assert len(self._dependency.adapters) > 0 for adapter in self._dependency.adapters: sender: IRSegment = self._dependency.senders[adapter] - print(f'place sender: {sender}') # find sender step and insert adapter for step, blocks in enumerate(self._step_segments): if len(blocks) == 0: continue @@ -389,7 +400,7 @@ def _place_adapters(self): self._step_adapters[step] += list(blocks[0].get_post_adapters()) else: assert all(isinstance(blk, Block) for blk in blocks) - segments = [block.blk for block in blocks] + segments = [block.content for block in blocks] mids = [block.mid for block in blocks] if sender in segments: mid = mids[segments.index(sender)] @@ -402,10 +413,42 @@ def topo_sort(self): def __repr__(self) -> str: dscp = f"SchedulePlan:\n" - for step in range(self.nsteps): - dscp += f'\nStep {step}:\n' - for segment in self._step_segments[step]: - dscp += repr(segment) + '\n' - for adapter in self._step_adapters[step]: - dscp += repr(adapter) + '\n' + + sids: Dict[IRCell, int] = {} + for block in self._segments: + if block.content not in sids: + sids[block.content] = len(sids) + + for idx, (cell, sid) in enumerate(sids.items()): + dscp += f'{cell.name}{cell.cid:<3} = {sid}; ' + if (idx + 1) % 3 == 0: + dscp += '\n' + + dscp += '\nAnnotation: i(f/b)j = segment i on executing (forward/backward) microbatch j' + + for devid in sorted(self.device): + timeline = '\n' + for step in range(self.nsteps): + # segment + have_block = False + for block in self._step_segments[step]: + if devid in block.device: + have_block = True + break + if have_block: + blk_repr = f"{sids[block.content]}{'f' if block.content.isfw() else 'b'}{block.mid}" + timeline += f" {blk_repr}" + else: + timeline += f" ---" + # adapter + # have_block = False + # for block in self._step_adapters[step]: + # if devid in block.device: + # have_block = True + # break + # if have_block: + # timeline += ' {0: <5}'.format('adapt') + # else: + # timeline += ' {0: <5}'.format('') + dscp += timeline return dscp From 1154871a1fd80bef2c0abf44ef88918ee04b035c Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Tue, 28 Feb 2023 11:27:33 +0000 Subject: [PATCH 1261/1892] Merged PR 1456: Add basic torch.fx model graph capturing support USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/mlp/linearsfx.py --policy PASData --- cube/flags.py | 1 + cube/graph/function/function.py | 15 +- cube/graph/parser/__init__.py | 1 + cube/graph/parser/converter.py | 32 ++- cube/graph/parser/mappingfx.py | 175 ++++++++++++++ cube/graph/parser/parserfx.py | 393 ++++++++++++++++++++++++++++++++ examples/mlp/linearsfx.py | 119 ++++++++++ 7 files changed, 727 insertions(+), 9 deletions(-) create mode 100644 cube/graph/parser/mappingfx.py create mode 100644 cube/graph/parser/parserfx.py create mode 100644 examples/mlp/linearsfx.py diff --git a/cube/flags.py b/cube/flags.py index b8f4d511..d5f94e83 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -22,6 +22,7 @@ class CompileFlag: log_schedule = _to_bool('LOG_SCHEDULE') # ================ compiling ======================== + use_torchfx = _to_bool('USE_TORCHFX') # using torch.fx or torchscript as frontend to capture dataflow graph # worker sleep in seconds worker_sleep = _to_int('WORKER_SLEEP') disable_intra_rvd = _to_bool('DISABLE_INTRA_RVD') diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 20996851..fcdfc45a 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -462,7 +462,9 @@ def Softmax(signature, inputs): def Dropout(signature, inputs): - assert len(inputs) == 4 + assert len(inputs) <= 4, f'but the length is {len(inputs)}' + default_inputs = [None, 0.5, True, False] + inputs = inputs + default_inputs[len(inputs):] annos = ['* -> *'] tensor = inputs[0:1] p, training, inplace = inputs[1], inputs[2], inputs[3] @@ -514,11 +516,18 @@ def Sum(signature, inputs): torch.sum(input, *, dtype=None) -> Tensor torch.sum(input, dim, keepdim=False, *, dtype=None) -> Tensor """ - assert len(inputs) == 2 or len(inputs) == 4, f"{inputs}" + assert len(inputs) == 1 or len(inputs) == 2 or len(inputs) == 4, f"{inputs}" tensor = inputs[0] einput = ShapeAnno.create_shape_str(tensor.shape) eoutput = copy.copy(einput) - if len(inputs) == 2: + if len(inputs) == 1: + inputs = [tensor] + eoutput = ['1'] + # every dimension can be reduced + einput = [edim + '+' for edim in einput] + anno = OpAnno.create_op_str([einput], [eoutput]) + return IRDimops(Sum, 'sum', signature, [anno], [tensor]) + elif len(inputs) == 2: dtype = inputs[1] assert dtype is None, "Currently Sum only support dtype=None" # torch.sum(input) diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index 5d2c539f..b3811616 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,3 +1,4 @@ from cube.graph.parser.parser import ScriptModuleParser +from cube.graph.parser.parserfx import FxModuleParser, FxFuncOpTracer from cube.graph.parser.converter import convert_model from cube.graph.parser.register import register \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 229b23dd..5f3d8210 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -2,27 +2,47 @@ from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser +from cube.graph.parser import FxModuleParser, FxFuncOpTracer from cube.graph import IRGraph +from cube.flags import CompileFlag import torch +import torch.fx def convert_model(model: torch.nn.Module, input_shapes: Optional[ List[List[int],] ] = None, save_content: bool = True) -> IRGraph: """ - Convert toch.nn.Module based model into IRGraph + Convert torch.nn.Module based model into IRGraph """ try: - smodule = torch.jit.script(model) + if CompileFlag.use_torchfx: + # from torch.fx import symbolic_trace + # # Symbolic tracing frontend - captures the semantics of the module + tracer = FxFuncOpTracer() + traced_graph: torch.fx.Graph = tracer.trace(model) + smodule: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) + smodule.graph.print_tabular() + else: + smodule = torch.jit.script(model) + except Exception as ex: print(ex) - raise RuntimeError("Cannot convert module into torchscript moudle.") - module_name = smodule.original_name - ScriptModuleParser.save_content = save_content - inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) + raise RuntimeError("Cannot convert module into torchscript/torch.fx module.") + + if CompileFlag.use_torchfx: + FxModuleParser.save_content = save_content + inputs, nodes, outputs = FxModuleParser.parse(smodule, input_shapes) + module_name = model.__class__.__name__ + else: + ScriptModuleParser.save_content = save_content + inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) + module_name = smodule.original_name + for input in inputs: if isinstance(input, IRFullTensor): input.requires_grad = False + graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) return graph diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py new file mode 100644 index 00000000..ae94df09 --- /dev/null +++ b/cube/graph/parser/mappingfx.py @@ -0,0 +1,175 @@ +import torch + +from typing import Callable, Dict, Union +from functools import partial + +import cube.graph.function as function +import cube.ir as ir +from cube.ir.operator import IRFwOperation + +class SignFx2Op: + + @staticmethod + def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: + """ + Map the signature to GenericLogicalOp + """ + if 'torch.' not in signature and 'cube.runtime.' not in signature: + signature = signature.split('.')[-1] + if signature in SignFx2Op.kOpMap: + function = SignFx2Op.kOpMap[signature] + # signature = 'torch.sum' if signature == 'sum' else signature #TODO fixme + return partial(function, signature=signature) + else: + raise KeyError(f"{signature} is not supported yet") + # return partial(function.UnkownOperator, signature=signature) + + @staticmethod + def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]], code): + """ + Register an operator + """ + if not isinstance(signature, str): + raise TypeError(f"Expected signature to be str but got {type(signature)}") + if signature in SignFx2Op.kOpMap: + raise KeyError(f"function {signature} is already registered") + SignFx2Op.kOpMap[signature] = op + SignFx2Op.kOpCodeDef[signature] = code + + # functional templates + __ftemplate = lambda name: f'torch.nn.functional.{name}' + __fcntemplate = lambda name: f'torch._C._nn.{name}' + + # tensor template + __ttemplate = lambda name: f'torch.{name}' + + # runtime template + __rtemplate = lambda name: f'cube.runtime.function.function.{name}' + + # einops + __einopsize = lambda name: f'einops._torch_specific.{name}' + + # custom ops + __customops = lambda name: f'examples.custom_ops.{name}' + + kOpMap = { + __fcntemplate('linear'): function.Linear, + __ftemplate('dropout') : function.Dropout, + __ttemplate('sum'): function.Sum, + + # # torch nn functional + # + # __ftemplate('linear') : function.Linear, + # + # __ttemplate('matmul'): function.Matmul, + # + # __ftemplate('softmax') : function.Softmax, + # + # __ftemplate('dropout') : function.Dropout, + # + # __ftemplate('gelu') : function.GeLU, + # __ttemplate('gelu') : function.GeLU, + # + # __ftemplate('silu') : function.SiLU, + # __ttemplate('silu') : function.SiLU, + # + # __ftemplate('_pad'): function.Pad, + # + # __ftemplate('layer_norm'): function.LayerNorm, + # + # __ftemplate('embedding'): function.Embedding, + # + # __ftemplate('cross_entropy'): function.CrossEntropy, + # + # # torch aten + # + # # creators + # __ttemplate('zeros'): function.Zeros, + # __ttemplate('ones'): function.Ones, + # __ttemplate('tensor'): function.NewTensor, + # __ttemplate('to'): function.ToTensor, + # __ttemplate('rand'): function.Rand, + # __ttemplate('clone'): function.Clone, + # + # __ttemplate('add') : function.Add, + # + # __ttemplate('sub') : function.Sub, + # + # __ttemplate('mul') : function.Mul, + # + # __ttemplate('div') : function.Div, + # + # __ttemplate('floordiv') : function.FloorDiv, + # + # __ttemplate('neg'): function.Neg, + # + # __ttemplate('gt'): function.CompareGT, + # __ttemplate('lt'): function.CompareLT, + # __ttemplate('ge'): function.CompareGE, + # __ttemplate('le'): function.CompareLE, + # + # __ttemplate('pow'): function.Pow, + # + # __ttemplate('sin'): function.Sin, + # + # __ttemplate('cos'): function.Cos, + # + # __ttemplate('tanh'): function.Tanh, + # + # __ttemplate('bmm') : function.BatchLinear, + # + # __ttemplate('sum') : function.Sum, + # __ttemplate('mean') : function.Mean, + # + # __ttemplate('transpose') : function.Transpose, + # + # __ttemplate('view'): function.View, + # + # __ttemplate('reshape'): function.Reshape, + # + # __ttemplate('conv2d'): function.Conv2D, + # + # __ttemplate('conv3d'): function.Conv3D, + # + # __ttemplate('pad'): function.Pad, + # + # __ttemplate('select'): function.Select, + # + # __ttemplate('slice'): function.Slice, + # + # #pytorch1.11 + # __ttemplate('select_scatter'): function.SelectScatter, + # + # __ttemplate('repeat'): function.Repeat, + # + # #pytorch1.11 + # __ttemplate('linear'): function.Linear, + # + # __ttemplate('cat'): function.Cat, + # + # __ttemplate('stack'): function.Stack, + # + # __ttemplate('chunk'): function.Chunk, + # + # __ttemplate('flatten'): function.Flatten, + # + # __ttemplate('roll'): function.Roll, + # + # __ttemplate('adaptive_avg_pool1d'): function.AdaptiveAvgPool1d, + # + # # runtime functions + # __rtemplate('anchor'): function.GraphAnchor, + # + # __rtemplate('identity'): function.Identity, + # + # __rtemplate('multiref'): function.MultiRef, + # + # __rtemplate('accum'): function.Accum, + # + # #einops + # __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, + + } + + # customized operator code: signature -> code + kOpCodeDef: Dict[str, str] = {} \ No newline at end of file diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py new file mode 100644 index 00000000..fbc08c6d --- /dev/null +++ b/cube/graph/parser/parserfx.py @@ -0,0 +1,393 @@ +import torch +import enum +import re +from typing import Any, List, Tuple, Optional + +from cube.ir.operator import IRFwOperation +from cube.ir.tensor import IRFullTensor +import cube.ir as ir +from cube.graph.parser.frame import Frame +from cube.graph.parser.mapping import DType2IRDType +from cube.graph.parser.mappingfx import SignFx2Op + +import torch.fx + +class ErasedDevice: + pass + +class FxNodeKind(enum.Enum): + PrimGetAttr = 1 + PrimCallMethod = 2 + PrimCallFunction = 3 # -> the parser may end here + PrimConstant = 4 + AtenOp = 5 # -> the parser may end here + PrimIf = 6 # dynamic + PrimListConstruct = 7 + PrimListUnpack = 8 + PrimTupleUnpack = 9 + PrimPythonOp = 10 + PrimDevice = 11 # erased + PrimLoop = 12 + PrimCallModule = 13 + # for torch.fx + Placeholder = 14 + Output = 15 + + +class FxFuncOpTracer(torch.fx.Tracer): + def __init__(self, *args, customed_leaf_module=None, **kwargs): + super().__init__(*args, **kwargs) + self.customed_leaf_module = customed_leaf_module + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + if self.customed_leaf_module and isinstance(m, self.customed_leaf_module): + return True + # capture torch.nn.functional return + return m.__module__.startswith('torch.nn.functional') and not isinstance(m, torch.nn.Sequential) + + +class FxModuleParser: + save_content: bool = True + + @staticmethod + def shape_refine(shape: torch.Size) -> torch.Size: + """ + replacing scale shape [] to [1] + :param shape: + :return: + """ + # TODO update + return torch.Size([1]) if shape == torch.Size([]) else shape + + + @staticmethod + def parse(module: torch.fx.GraphModule, + input_shapes: Optional[Tuple[List[int],]] = None, + frame: Frame = None) \ + -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: + """ + The overall entry to parse a torch.fx graph module + """ + frame = frame if frame is not None else Frame() + frame.push_var() + frame.push_attr() + + inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] + print(f'inputs = {inputs}') + if input_shapes is not None and len(input_shapes) != len(inputs): + raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + + ## shape propagation + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + sample_inputs = [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] + from torch.fx.passes.shape_prop import ShapeProp + ShapeProp(module).propagate(*sample_inputs) + for node in module.graph.nodes: + print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) + + # handle graph input -- Assuming all the inputs are tensors + for idx, input in enumerate(inputs): + assert isinstance(input, torch.fx.Node) + shape = None if input_shapes is None else input_shapes[idx] + dtype = kDefaultType + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + frame.add_var(input.name, val, graph_arg=idx) + input_val = [frame.get_var(input.name) for input in inputs] + + # add activations to frame, including call_func output and final output + activation_nodes = [node for node in module.graph.nodes if (node.op == 'call_function' or node.op == 'output')] + for node in activation_nodes: + assert isinstance(node, torch.fx.Node) + shape = node.meta['tensor_meta'].shape + shape = FxModuleParser.shape_refine(shape) + dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) + val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) + frame.add_var(node.name, val) + + # handle nodes + all_ir_nodes: List[IRFwOperation] = list() + for node in module.graph.nodes: + ir_nodes = FxModuleParser.parse_node(node, module, frame) + all_ir_nodes += ir_nodes + + # handle outputs + output_nodes = [node.all_input_nodes for node in module.graph.nodes if node.op == 'output'] + print(f'outputs = {output_nodes}') + output_var_name = [output.name for output in [item for sublist in output_nodes for item in sublist]] + output_val = [frame.get_var(var_name) for var_name in output_var_name] + + # flatten output_val + outputs = list() + for val in output_val: + if isinstance(val, list): + outputs += val + else: + outputs.append(val) + output_val = outputs + + frame.pop_var() + frame.pop_attr() + if FxModuleParser.save_content: + frame.save_attr_content() + + return input_val, all_ir_nodes, output_val + + + @staticmethod + def ntype(node: torch.fx.Node): + if node.op == 'call_module': + return FxNodeKind.PrimCallModule + if node.op == 'call_function': + return FxNodeKind.PrimCallFunction + if node.op == 'get_attr': + return FxNodeKind.PrimGetAttr + if node.op == 'placeholder': + return FxNodeKind.Placeholder + if node.op == 'output': + return FxNodeKind.Output + if node.op == 'call_method': + return FxNodeKind.PrimCallMethod + # if node.kind() == 'prim::CallMethod': + # return FxNodeKind.PrimCallMethod + # if node.kind() == 'prim::CallFunction': # the op call + # return FxNodeKind.PrimCallFunction + # if node.kind() == 'prim::Constant': + # return FxNodeKind.PrimConstant + # if node.kind().startswith('aten::'): + # return FxNodeKind.AtenOp + # if node.kind() == 'prim::If': + # return FxNodeKind.PrimIf + # if node.kind() == 'prim::Loop': + # return FxNodeKind.PrimLoop + # if node.kind() == 'prim::ListConstruct': + # return FxNodeKind.PrimListConstruct + # if node.kind() == 'prim::TupleConstruct': + # return FxNodeKind.PrimListConstruct + # if node.kind() == 'prim::ListUnpack': + # return FxNodeKind.PrimListUnpack + # if node.kind() == 'prim::TupleUnpack': + # return FxNodeKind.PrimListUnpack + # if node.kind() == 'prim::PythonOp': + # return FxNodeKind.PrimPythonOp + # if node.kind() == 'prim::device': + # return FxNodeKind.PrimDevice + raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") + + @staticmethod + def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation]: + # print("### parse_node {}".format(node)) + """ + Parse the node and return the IRFwOperation nodes + """ + node_type = FxModuleParser.ntype(node) + try: + if node_type == FxNodeKind.Placeholder: + return [] + if node_type == FxNodeKind.Output: + return [] + + if node_type == FxNodeKind.PrimCallFunction: + return FxModuleParser.parse_prim_function_node(node, module, frame) + if node_type == FxNodeKind.PrimCallMethod: + return FxModuleParser.parse_prim_method_node(node, module, frame) + if node_type == FxNodeKind.PrimGetAttr: + return FxModuleParser.parse_prim_attr_node(node, module, frame) + if node_type == FxNodeKind.PrimCallModule: + return FxModuleParser.parse_prim_module(node, module, frame) + + # TODO bother assigning all ignored prim functions new NodeKinds? + if node_type == FxNodeKind.PrimDevice: + return FxModuleParser.parse_value_erased_node(node, module, frame, [ErasedDevice()]) + raise NotImplementedError(f"Un-supported node type {node_type}") + except Exception: + raise RuntimeError(f"\n\nParsing error at node:\n\t{node}\n") + + + @staticmethod + def parse_prim_module(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation]: + """ + :param node: + :param module: + :param frame: + :return: + """ + raise RuntimeError(f"parse_prim_module needs update") + + input_nodes = node.all_input_nodes + for input_node in input_nodes: + var_name = input_node.name + val = frame.get_var(var_name) + frame.push_param(var_name) + + # TODO skip self_module in torchscript + call_module = module + + label = node.name + node_target_stack = node.target.split('.') + + #TODO check leaf module, iterate if not + + leaf_module = module + for node_target_stack_iter in node_target_stack: + leaf_module = getattr(leaf_module, node_target_stack_iter) + + _, ir_nodes, outputs_val = FxModuleParser.parse_nn_module(node, leaf_module, frame=frame) + + # pop out the frame + frame.pop_param(times=len(input_nodes) - 1) + + # # handle outputs + # # TODO outputs vs output + # outputs = [node] + # # outputs = [output for output in node.outputs()] + # for output, val in zip(outputs, outputs_val): + # frame.add_var(output.name, val) + + return ir_nodes + + @staticmethod + def parse_nn_module(node: torch.fx.Node, method: torch.nn.Module, frame: Frame): + """ + Parse module method + """ + + input_var_name = [input_node.name for input_node in node.all_input_nodes] + input_val = [frame.get_var(var_name) for var_name in input_var_name] + + all_ir_nodes: List[IRFwOperation] = list() + + # handle graph output + + fsig = type(method).__name__ + # ir_node = SignFx2Op.map(fsig)(input=input_val) + func = SignFx2Op.map(fsig) + weights_names = None #TODO obtain parameter name list + if weights_names is not None: + for idx, weight_name in enumerate(weights_names): + #create FullTensor + weight = getattr(method, weight_name) + if weight is not None: + weight_fulltensor_name = node.name + "-" + weight_name + weight_fulltensor = IRFullTensor(weight.shape, weight_fulltensor_name, requires_grad=False) + # frame.add_attr(weight_fulltensor_name, weight_fulltensor) + # tmp_ones_tensor = torch.ones(weight.shape, dtype=torch.get_default_dtype()) + # frame.add_attr_content(weight_fulltensor.tid, tmp_ones_tensor) + frame.add_var(weight_fulltensor_name, weight_fulltensor) + input_val.append(weight_fulltensor) + else: + input_val.append(None) + + ir_node = func(inputs=input_val) + all_ir_nodes += [ir_node] + + outputs = [node] + output_var_name = [output.name for output in outputs] + output_val = [frame.get_var(var_name) for var_name in output_var_name] + + # frame.pop_var() + return input_val, all_ir_nodes, output_val + + @staticmethod + def fetch_attr(mod: torch.fx.GraphModule, target: str): + target_atoms = target.split('.') + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + @staticmethod + def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + """ + parse node like: + Tensor = prim::CallFunction(%5, %input.1, %3, %4) + %5 : Function = prim::Constant[name="linear"]() + %12 : (Tensor, Tensor) = prim::CallFunction(%5, %x1.1, %x2.1) + """ + # get signature + fsig = FxModuleParser._get_qualified_name(node.target) + print(f'parse_prim_function_node: {fsig}') + + # get inputs + input_nodes = [input_node for input_node in node.args] + input_vals = list() + for index, input_node in enumerate(input_nodes): + if isinstance(input_node, torch.fx.Node): + var_name = input_node.name + val = frame.get_var(var_name) + input_vals.append(val) + else: + input_vals.append(None) + + # map to IR operator + ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + + # TODO gracefully set output + output_name = node.name + output_val = frame.get_var(output_name) + ir_node.set_output(0, output_val) + + # # push output in the frame + # # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) + # # : >>> dir(a) + # # : >>> a.elements() # [TensorType, TensorType] + # cnt = 0 + # for output in node.outputs(): + # if isinstance(output.type(), torch._C.TupleType): + # tuplen = len(output.type().elements()) + # ir_output = [ir_node.output(idx) for idx in range(cnt, cnt + tuplen)] + # cnt += tuplen + # else: + # ir_output = ir_node.output(cnt) + # cnt += 1 + # frame.add_var(output.debugName(), ir_output) + # + # if cnt != len(ir_node.outputs()): + # raise RuntimeError( + # f"Parse fail: {fsig} has {cnt} outputs != pre-defined {len(ir_node.outputs())}" + # ) + + return [ir_node] + + @staticmethod + def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + assert node is not None + tensor_name = node.name + + tensor_shape = node.meta['tensor_meta'].shape + # tensor_dtype = node.meta['tensor_meta'].dtype + #TODO assume it is weight + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=kDefaultType) + ir_tensor.as_param() + frame.add_var(tensor_name, ir_tensor) + + return list() + + + from typing import Callable + # import python_stubs.buildins + @staticmethod + def _get_qualified_name(func: Callable[..., Any]) -> str: + # # things like getattr just appear in builtins + # if getattr(builtins, func.__name__, None) is func: + # return func.__name__ + name = func.__name__ + module = FxModuleParser._find_module_of_method(func) + module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + return f'{module}.{name}' + + # this is fixed on master, WAR for 1.5 + @staticmethod + def _find_module_of_method(orig_method: Callable[..., Any]) -> str: + name = orig_method.__name__ + module = orig_method.__module__ + if module is not None: + return module + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is orig_method: + return guess.__name__ + raise RuntimeError(f'cannot find module for {orig_method}') \ No newline at end of file diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py new file mode 100644 index 00000000..9d1de60d --- /dev/null +++ b/examples/mlp/linearsfx.py @@ -0,0 +1,119 @@ +""" +example: + +//torchscript based DAG capture +PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/mlp/linearsfx.py --policy PASData +//torch.fx based DAG capture +USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/mlp/linearsfx.py --policy PASData +""" + +import torch +from torch import nn + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + +import examples.mlp.policy.spmd as spmd +import examples.mlp.policy.mpmd as mpmd + +import argparse + +parser = argparse.ArgumentParser(description='comm primitive') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +parser.add_argument('--local_rank', type=int, default=0) +args = parser.parse_args() + +cube.init() + +# set up policy +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + + +# =================== Semantic Model Description ==================== + +class MLP(nn.Module): + def __init__(self, dim, mult=1, nlayers=4): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for lid in range(nlayers): + if lid % 2 == 0: + self.layers.append(nn.Linear(dim, dim * mult, bias=False)) + last_dim = dim * mult + else: + self.layers.append(nn.Linear(dim * mult, dim, bias=False)) + last_dim = dim + + # self.layer_norm = nn.LayerNorm(last_dim) #TODO CHECK torch.fx ignores LayerNorm + # self.p = 0.5 + self.drop_out = nn.Dropout() + # self.y = torch.nn.Parameter(torch.empty(128, last_dim)) + + def forward(self, data): + x = data + for layer in self.layers: + x = layer(x) + # x = self.layer_norm(x) + x = self.drop_out(x) + # x = torch.nn.functional.dropout(x, self.p) + # x = x * self.y + loss = torch.sum(x) + return loss + + +def train(): + batch_size = 128 + dim = 8192 + + model = MLP(dim=dim) + model = cube.SemanticModel( + model, input_shapes=([batch_size, dim],), + ) + + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([batch_size, dim],), + dtypes=(torch.float32,), + batch_dims=(0,) + ) + + @cube.compile(model, dataloader, PAS=PAS, load_content=False) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + + model = model.get_gen_module() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + CudaTimer(enable=False).warmup() + if torch.distributed.is_initialized(): + torch.distributed.barrier() + iter_num = 2 #32 + warmup = 0 #8 + for step in range(iter_num): + if step >= warmup: + CudaTimer(enable=True).start('e2e') + train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + if step >= warmup: + CudaTimer().stop('e2e') + if (step + 1) % 20 == 0: + print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num - warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num - warmup) + + +train() \ No newline at end of file From 55769ca3a39e9cb29ed7a3f33d1a255569f5257a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 28 Feb 2023 20:04:06 +0800 Subject: [PATCH 1262/1892] update schedule plan with block latency --- cube/execplan/execplan.py | 2 +- cube/graph/schedule/schedplan.py | 65 ++++++++++++++++++++------------ 2 files changed, 41 insertions(+), 26 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 2fcfa13c..5935b409 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -176,7 +176,7 @@ def block2reuse(node: Block) -> ExeReuseCell: micro_fcells[key] = cell return cell else: - mcell = block2reuse(Block(node.content.mirror, node.mid)) + mcell = block2reuse(Block(node.content.mirror, node.mid, node.span)) return mcell.mirror topo_seqs: List[IRCell] = [] diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py index dea5e7ac..115f8d73 100644 --- a/cube/graph/schedule/schedplan.py +++ b/cube/graph/schedule/schedplan.py @@ -16,12 +16,14 @@ class Block: that is executed with input data of a given micro-batch index. """ - def __init__(self, cell: IRCell, micro_batch_id: int) -> None: - """ + def __init__(self, cell: IRCell, micro_batch_id: int, span: int) -> None: + """Create an execution block with IRCell on microbatch index. The + block will take `span` steps to finish execution. """ assert isinstance(cell, IRCell), f"Expected IRCell, but got {type(cell)}: {cell}" self._content: IRCell = cell self._micro_batch_id: int = micro_batch_id + self._span = span def __eq__(self, other): if isinstance(other, Block): @@ -43,6 +45,10 @@ def mid(self) -> int: def content(self) -> IRCell: return self._content + @property + def span(self) -> int: + return self._span + def dispatch(self, devid: int): return Block(self._content.dispatch(devid), self._micro_batch_id) @@ -126,18 +132,19 @@ def device(self) -> Tuple[int]: def nodes(self) -> Tuple[Block]: return tuple(self._seqs) - def add_segment(self, seg: IRSegment, micro_batch_id: int, step: int) -> Block: + def add_segment(self, seg: IRSegment, micro_batch_id: int, step: int, span: Optional[int] = 1) -> Block: """ Add a segment `seg` to be executed with `micro-batch-id` data at step `step`. """ - self._extend_step(step) + self._extend_step(step + span - 1) if len(self._step_segments[step]) == 1 and isinstance(self._step_segments[0], PlanBase): assert False, "Cannot add an IRSegment into a step that already has Repetend." assert all(devid not in self._step_devices for devid in seg.device), \ f"A device cannot execute multiple segments on a same step" - block = Block(seg, micro_batch_id) - self._step_segments[step].append(block) - self._step_devices[step].update(seg.device) + block = Block(seg, micro_batch_id, span) + for t in range(span): + self._step_segments[step+t].append(block) + self._step_devices[step+t].update(seg.device) self._block_step[block] = step self._segments.append(block) return block @@ -147,7 +154,9 @@ def segments(self, step: int) -> Tuple[Block]: Get segment blocks at step """ assert step < self.nsteps - return tuple(self._step_segments[step]) + blocks = self._step_segments[step] + blocks = tuple(blk for blk in blocks if self.step(blk) == step) + return blocks def step(self, block: Block) -> int: """Get the step of the block @@ -177,13 +186,15 @@ def _place_dataloader(self): # FIXME: this may not work for multiple segments in a same # micro-batch require for the data for dl in self._dependency.dataloaders: - for step, blocks in enumerate(self._step_segments): + for step in range(self.nsteps): + blocks = self.segments(step) for block in blocks: - if isinstance(block, Block): - segment, mid = block.content, block.mid - if self.graph.depends(dl, segment): - self._step_segments[step].insert(0, Block(dl, mid)) - break + segment, mid = block.content, block.mid + if self.graph.depends(dl, segment): + dl_block = Block(dl, mid, 1) + self._step_segments[step+block.span-1].insert(0, dl_block) + self._block_step[dl_block] = step+block.span-1 + break def topo_sort(self): """ @@ -192,7 +203,7 @@ def topo_sort(self): """ self._seqs = [] for step in range(self.nsteps): - self._seqs += self._step_segments[step] + self._seqs += self.segments(step) self._seqs += self._step_adapters[step] @@ -248,7 +259,7 @@ def _place_adapters(self): extended_blocks = [] for step in range(self.nsteps): for blk in self.segments(step): - extend_blk = Block(blk.content, blk.mid + cnts[blk.content]) + extend_blk = Block(blk.content, blk.mid + cnts[blk.content], blk.span) extended_blocks.append(extend_blk) # step2: generate adapters for each step all_blocks = self.all_segments() @@ -262,14 +273,14 @@ def _place_adapters(self): # - we don't allow send and recver in un-neighbored repetend # 3) its recver are outside the repetend recver = self._dependency.recvers[adapter] - rblock = Block(recver, block.mid) - ablock = Block(adapter, block.mid) + rblock = Block(recver, block.mid, block.span) + ablock = Block(adapter, block.mid, 1) # case 1) if rblock in all_blocks: - self._step_adapters[step].append(ablock) + self._step_adapters[step+block.span-1].append(ablock) # case 2) elif rblock in extended_blocks: - self._step_adapters[self.nsteps-1].append(Block(adapter, block.mid - cnts[blk.content])) + self._step_adapters[self.nsteps-1].append(Block(adapter, block.mid - cnts[blk.content], 1)) self._post_adapters.append(ablock) # case 3) else: @@ -394,8 +405,8 @@ def _place_adapters(self): for adapter in self._dependency.adapters: sender: IRSegment = self._dependency.senders[adapter] # find sender step and insert adapter - for step, blocks in enumerate(self._step_segments): - if len(blocks) == 0: continue + for step in range(self.nsteps): + blocks = self.segments(step) if len(blocks) == 1 and isinstance(blocks[0], Repetend): self._step_adapters[step] += list(blocks[0].get_post_adapters()) else: @@ -403,8 +414,9 @@ def _place_adapters(self): segments = [block.content for block in blocks] mids = [block.mid for block in blocks] if sender in segments: + span = blocks[segments.index(sender)].span mid = mids[segments.index(sender)] - self._step_adapters[step].append(Block(adapter, mid)) + self._step_adapters[step+span-1].append(Block(adapter, mid, 1)) def topo_sort(self): super().topo_sort() @@ -428,7 +440,8 @@ def __repr__(self) -> str: for devid in sorted(self.device): timeline = '\n' - for step in range(self.nsteps): + step = 0 + while step < self.nsteps: # segment have_block = False for block in self._step_segments[step]: @@ -437,9 +450,11 @@ def __repr__(self) -> str: break if have_block: blk_repr = f"{sids[block.content]}{'f' if block.content.isfw() else 'b'}{block.mid}" - timeline += f" {blk_repr}" + timeline += f" {'-'.join([blk_repr] * block.span)}" + step += block.span else: timeline += f" ---" + step += 1 # adapter # have_block = False # for block in self._step_adapters[step]: From e2010d0d849ee3ce1747b4589fc95037fb8aea60 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 28 Feb 2023 17:28:29 -0800 Subject: [PATCH 1263/1892] support squeeze and unsqueeze in fxparser --- cube/graph/function/function.py | 32 ++++++++++++++++++++++++++++++ cube/graph/parser/frame.py | 2 +- cube/graph/parser/mappingfx.py | 2 ++ cube/graph/parser/parserfx.py | 35 ++++++++++++++++++++++++++++++++- examples/mlp/linearsfx.py | 2 ++ 5 files changed, 71 insertions(+), 2 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index fcdfc45a..7a55ed51 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -763,6 +763,38 @@ def Reshape(signature, inputs): return View(signature, inputs) +def Squeeze(signature, inputs): + """ + out = torch.squeeze(tensor) + """ + assert len(inputs) == 1 + input = inputs[0] + + edim_in = ShapeAnno.create_shape_str(input.shape) + assert len(edim_in) == len(input.shape) + edim_ou = [] + for dim_anno, dim_size in zip(edim_in, input.shape): + if dim_size > 1: + edim_ou.append(copy.copy(dim_anno)) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(Squeeze, 'squeeze', signature, [anno], [input]) + +def Unsqueeze(signature, inputs): + """ + out = torch.unsqueeze(tensor, dim) + """ + assert len(inputs) == 2 + input, dim = inputs + + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = copy.copy(edim_in) + edim_ou.insert(dim, '1') + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input], + dim=dim) + # def Pad(signature, inputs): # """ diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index 40ec379b..d3b82936 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -90,7 +90,7 @@ def get_var(self, var_name: str) -> Any: # first check whether we have variable in this frame if var_name in self._vars[-1]: return self._vars[-1][var_name] - raise KeyError(f"Cannot find var name {var_name}") + raise KeyError(f"Cannot find var name {var_name} in {self._vars}") def push_attr(self): """ diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index ae94df09..83627ee9 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -56,6 +56,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, + __ttemplate('squeeze'): function.Squeeze, + __ttemplate('unsqueeze'): function.Unsqueeze, # # torch nn functional # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index fbc08c6d..4e934693 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -96,7 +96,8 @@ def parse(module: torch.fx.GraphModule, input_val = [frame.get_var(input.name) for input in inputs] # add activations to frame, including call_func output and final output - activation_nodes = [node for node in module.graph.nodes if (node.op == 'call_function' or node.op == 'output')] + activation_op_strs = {'call_function', 'output', 'call_method'} + activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] for node in activation_nodes: assert isinstance(node, torch.fx.Node) shape = node.meta['tensor_meta'].shape @@ -308,6 +309,7 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, """ # get signature fsig = FxModuleParser._get_qualified_name(node.target) + print(node.target, type(node.target)) print(f'parse_prim_function_node: {fsig}') # get inputs @@ -351,6 +353,34 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, return [ir_node] + @staticmethod + def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + # get signature + fsig = FxModuleParser._get_qualified_name(node.target) + print(f'parse_prim_method_node: {fsig}') + + # get inputs + input_nodes = [input_node for input_node in node.args] + input_vals = list() + for index, input_node in enumerate(input_nodes): + if isinstance(input_node, torch.fx.Node): + var_name = input_node.name + val = frame.get_var(var_name) + input_vals.append(val) + elif isinstance(input_node, int): + input_vals.append(input_node) + else: + input_vals.append(None) + + # map to IR operator + ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + + output_name = node.name + output_val = frame.get_var(output_name) + ir_node.set_output(0, output_val) + + return [ir_node] + @staticmethod def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: assert node is not None @@ -375,6 +405,9 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: # # things like getattr just appear in builtins # if getattr(builtins, func.__name__, None) is func: # return func.__name__ + if isinstance(func, str): + # TODO(yizhu1): find a general solution + return f'torch.{func}' name = func.__name__ module = FxModuleParser._find_module_of_method(func) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 9d1de60d..75bbe5e4 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -63,7 +63,9 @@ def forward(self, data): for layer in self.layers: x = layer(x) # x = self.layer_norm(x) + x = x.unsqueeze(0) x = self.drop_out(x) + x = x.squeeze() # x = torch.nn.functional.dropout(x, self.p) # x = x * self.y loss = torch.sum(x) From 6a2d286131c0cda1b2073611731f75d0891bffd8 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 28 Feb 2023 17:32:06 -0800 Subject: [PATCH 1264/1892] refine code --- cube/graph/parser/parserfx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 4e934693..c0640b9f 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -309,7 +309,6 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, """ # get signature fsig = FxModuleParser._get_qualified_name(node.target) - print(node.target, type(node.target)) print(f'parse_prim_function_node: {fsig}') # get inputs From 01f1a76682df3e89dd0c2e04c3be22b0c9930d88 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 28 Feb 2023 19:10:19 -0800 Subject: [PATCH 1265/1892] update example code --- examples/mlp/linearsfx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 75bbe5e4..6b7f891e 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -63,7 +63,7 @@ def forward(self, data): for layer in self.layers: x = layer(x) # x = self.layer_norm(x) - x = x.unsqueeze(0) + x = x.unsqueeze(1) x = self.drop_out(x) x = x.squeeze() # x = torch.nn.functional.dropout(x, self.p) From 09f7a055d53da36770dcddcea54a9f9ebccf75c6 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 28 Feb 2023 21:48:58 -0800 Subject: [PATCH 1266/1892] support type_as --- cube/graph/function/function.py | 14 ++++++++++++++ cube/graph/parser/mappingfx.py | 4 ++++ cube/graph/parser/parserfx.py | 9 +++++++-- examples/mlp/linearsfx.py | 1 + 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 7a55ed51..e7b247aa 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -795,6 +795,20 @@ def Unsqueeze(signature, inputs): return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input], dim=dim) +def TypeAs(signature, inputs): + """ + out = torch.Tensor.type_as(tensor0, tensor1) + """ + assert len(inputs) == 2 + input0, input1 = inputs + + edim_in0 = ShapeAnno.create_shape_str(input0.shape) + edim_in1 = ShapeAnno.create_shape_str(input1.shape) + edim_ou = copy.copy(edim_in0) + anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) + + return IRDimops(TypeAs, 'type_as', signature, [anno], [input0, input1]) + # def Pad(signature, inputs): # """ diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 83627ee9..d9c4bb5e 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -43,6 +43,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # tensor template __ttemplate = lambda name: f'torch.{name}' + # torch.Tensor template + __tttemplate = lambda name: f'torch.Tensor.{name}' + # runtime template __rtemplate = lambda name: f'cube.runtime.function.function.{name}' @@ -58,6 +61,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('sum'): function.Sum, __ttemplate('squeeze'): function.Squeeze, __ttemplate('unsqueeze'): function.Unsqueeze, + __tttemplate('type_as'): function.TypeAs, # # torch nn functional # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index c0640b9f..e94c9554 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -404,9 +404,14 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: # # things like getattr just appear in builtins # if getattr(builtins, func.__name__, None) is func: # return func.__name__ + # TODO(yizhu1): find a general solution if isinstance(func, str): - # TODO(yizhu1): find a general solution - return f'torch.{func}' + if getattr(torch, func, None) is not None: + return f'torch.{func}' + elif getattr(torch.Tensor, func, None) is not None: + return f'torch.Tensor.{func}' + else: + raise RuntimeError(f'cannot find module for {func}') name = func.__name__ module = FxModuleParser._find_module_of_method(func) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 6b7f891e..fd0f3565 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -63,6 +63,7 @@ def forward(self, data): for layer in self.layers: x = layer(x) # x = self.layer_norm(x) + x = x.type_as(data) x = x.unsqueeze(1) x = self.drop_out(x) x = x.squeeze() From 32ec475eeca2c50f16f425f704a4ccf88c90ff9d Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 28 Feb 2023 22:56:20 -0800 Subject: [PATCH 1267/1892] support for triu --- cube/graph/function/function.py | 13 +++++++++++++ cube/graph/parser/mappingfx.py | 1 + cube/graph/parser/parserfx.py | 2 ++ examples/mlp/linearsfx.py | 1 + 4 files changed, 17 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index e7b247aa..85821263 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -809,6 +809,19 @@ def TypeAs(signature, inputs): return IRDimops(TypeAs, 'type_as', signature, [anno], [input0, input1]) +def Triu(signature, inputs): + """ + out = torch.triu(tensor, diagonal) + """ + assert len(inputs) == 2 + input, diagonal = inputs + + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(Triu, 'triu', signature, [anno], [input], + diagonal=diagonal) # def Pad(signature, inputs): # """ diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index d9c4bb5e..cb83209d 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -62,6 +62,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('squeeze'): function.Squeeze, __ttemplate('unsqueeze'): function.Unsqueeze, __tttemplate('type_as'): function.TypeAs, + __ttemplate('triu'): function.Triu, # # torch nn functional # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index e94c9554..9c254e9d 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -319,6 +319,8 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, var_name = input_node.name val = frame.get_var(var_name) input_vals.append(val) + elif isinstance(input_node, int): + input_vals.append(input_node) else: input_vals.append(None) diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index fd0f3565..fc1f4c84 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -67,6 +67,7 @@ def forward(self, data): x = x.unsqueeze(1) x = self.drop_out(x) x = x.squeeze() + x = torch.triu(x, 1) # x = torch.nn.functional.dropout(x, self.p) # x = x * self.y loss = torch.sum(x) From 846107e22cc7d5c51824ed94bbcdc337e95726f1 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 1 Mar 2023 00:53:26 -0800 Subject: [PATCH 1268/1892] support relu and ne --- cube/graph/function/function.py | 23 +++++++++++++++++++++++ cube/graph/parser/mappingfx.py | 2 ++ cube/graph/parser/parserfx.py | 4 ++-- cube/ir/dtype.py | 3 +++ examples/mlp/linearsfx.py | 3 +++ 5 files changed, 33 insertions(+), 2 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 85821263..46e88380 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -452,6 +452,14 @@ def SiLU(signature, inputs): return IRDimops(SiLU, 'silu', signature, annos, tensor) +def ReLU(signature, inputs): + assert len(inputs) == 1 + annos = ['* -> *'] + signature = 'torch.nn.functional.relu' + tensor = inputs[0:1] + return IRDimops(ReLU, 'relu', signature, annos, tensor) + + def Softmax(signature, inputs): assert len(inputs) == 4 annos = ['* -> *'] @@ -472,6 +480,21 @@ def Dropout(signature, inputs): p=p, training=training, inplace=inplace) +def NE(signature, inputs): + assert len(inputs) == 2 + input0, input1 = inputs + + edim_in0 = ShapeAnno.create_shape_str(input0.shape) + edim_ou = copy.copy(edim_in0) + if isinstance(input1, float): + anno = OpAnno.create_op_str([edim_in0], [edim_ou]) + return IRDimops(NE, 'ne', signature, [anno], [input0], other=input1) + else: + edim_in1 = copy.copy(edim_in0) + anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) + return IRDimops(NE, 'ne', signature, [anno], [input0], other=input1) + + def LayerNorm(signature, inputs): """ torch.nn.functional.layer_norm(input, normliazed_shape, weight=None, bias=None, eps) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index cb83209d..de839451 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -63,6 +63,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('unsqueeze'): function.Unsqueeze, __tttemplate('type_as'): function.TypeAs, __ttemplate('triu'): function.Triu, + __ftemplate('relu') : function.ReLU, + __ttemplate('ne') : function.NE, # # torch nn functional # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 9c254e9d..f5a9326a 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -319,7 +319,7 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, var_name = input_node.name val = frame.get_var(var_name) input_vals.append(val) - elif isinstance(input_node, int): + elif isinstance(input_node, (int, float)): input_vals.append(input_node) else: input_vals.append(None) @@ -368,7 +368,7 @@ def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr var_name = input_node.name val = frame.get_var(var_name) input_vals.append(val) - elif isinstance(input_node, int): + elif isinstance(input_node, (int, float)): input_vals.append(input_node) else: input_vals.append(None) diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index 89b8f6dd..edba56d1 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -46,6 +46,9 @@ def infer(node, dtypes: List[IRDType]) -> IRDType: raise RuntimeError(f"Find an unkown dtype") if IRDType.float32 in dtypes and IRDType.float16 in dtypes: raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") + # TODO(yizhu1): hack + if node.signature == 'torch.ne': + return IRDType.boolean # in priority: fp32 > fp16 > bool > int64 > int16 > priority = [ IRDType.float64, IRDType.float32, IRDType.float16, diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index fc1f4c84..afd36e2d 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -62,12 +62,15 @@ def forward(self, data): x = data for layer in self.layers: x = layer(x) + x = torch.nn.functional.relu(x) # x = self.layer_norm(x) x = x.type_as(data) x = x.unsqueeze(1) x = self.drop_out(x) x = x.squeeze() x = torch.triu(x, 1) + # ne cannot backward + # x = torch.ne(x, 1.0) # x = torch.nn.functional.dropout(x, self.p) # x = x * self.y loss = torch.sum(x) From b0ca7114100259e8178934605132995d33e0a9b5 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 1 Mar 2023 04:22:11 -0800 Subject: [PATCH 1269/1892] improve fx parser during supporting bloom --- cube/graph/parser/parserfx.py | 37 ++++++++---- tests/parser/test_bloom.py | 108 ++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 11 deletions(-) create mode 100644 tests/parser/test_bloom.py diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index fbc08c6d..6b524a92 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -62,7 +62,7 @@ def shape_refine(shape: torch.Size) -> torch.Size: @staticmethod def parse(module: torch.fx.GraphModule, - input_shapes: Optional[Tuple[List[int],]] = None, + dummy_inputs: Optional[Any] = None, frame: Frame = None) \ -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: """ @@ -74,22 +74,37 @@ def parse(module: torch.fx.GraphModule, inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] print(f'inputs = {inputs}') - if input_shapes is not None and len(input_shapes) != len(inputs): - raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") ## shape propagation - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - sample_inputs = [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] - from torch.fx.passes.shape_prop import ShapeProp - ShapeProp(module).propagate(*sample_inputs) + from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp + KwargsShapeProp(module).propagate(dummy_inputs) for node in module.graph.nodes: - print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) + if 'tensor_meta' in node.meta: + if node.meta['type'] is type(tuple()): + print(f'{node.name} is tuple type') + elif node.meta['type'] is type(torch.fx.immutable_collections.immutable_dict()): + print(f'{node.name} is immutable_dict type') + assert isinstance(node.meta['tensor_meta'], dict) + else: + assert node.meta['type'] is type(torch.Tensor()) or node.meta['type'] is type(torch.nn.parameter.Parameter()) + print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) + else: + print(f'{node.name} does not has tensor_meta') - # handle graph input -- Assuming all the inputs are tensors + # handle graph input -- some inputs could be None or not tensor + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, extend to other input types for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) - shape = None if input_shapes is None else input_shapes[idx] + if hasattr(dummy_inputs, input.name): + print(f'dummy_inputs has {input.name}') + shape = getattr(dummy_inputs, input.name).size()# None if dummy_inputs is None else dummy_inputs[idx].size() + else: + # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is aligned with input.name + print(f'dummy_inputs does not have {input.name}') + shape = None + # FIXME: use the input's real dtype dtype = kDefaultType val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) diff --git a/tests/parser/test_bloom.py b/tests/parser/test_bloom.py new file mode 100644 index 00000000..69829bfa --- /dev/null +++ b/tests/parser/test_bloom.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +# model_name = "bigscience/bloom-7b1" +# model_path = "/home/quzha/bloom7b1" +model_name = "bigscience/bloom-560m" +model_path = "/home/quzha/bloom560m" +# model_name = "facebook/opt-66b" +# model_name = "facebook/opt-iml-30b" +# model_name = "facebook/optiml30b" +# model_name = "facebook/opt-iml-1.3b" +# model_name = "facebook/opt-13b" +# model_path = "/home/quzha/opt13b" + +print("Loading model...") #device_map="balanced", +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, cache_dir=model_path)#.cuda() +print(type(model), '; is nn.Module? ', isinstance(model, nn.Module)) +print("Model's generation config which does not list default values: ", model.generation_config) +print("Loading tokenizer...") +tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) +print("Loading Done!") +prompt = "If I want to travel to a new city, I should plan my trip as follows:" +#input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() +inputs = tokenizer(prompt, return_tensors="pt")#.to('cuda:0') + +# Cube +# from cube.graph import parser +# ir_graph = parser.convert_model(model, input_shapes=[1, 17], save_content=False) + +# model(input_ids, None, None, None, None, None, None, None, None, None) + +print("concrete tracing model...") +from nni.common.concrete_trace_utils import concrete_trace +#traced_graph = concrete_trace(model, (input_ids, None, None, None, None, None, None, None, None, None), use_function_patch=True, +# autowrap_leaf_class={torch.finfo: ((), False)}) +#traced_graph = concrete_trace(model, inputs, use_function_patch=True, +traced_graph = concrete_trace(model, inputs, use_operator_patch=True, + autowrap_leaf_class={torch.finfo: ((), False)}) +# traced_graph.graph.print_tabular() +print("tracing model done.") + +print("parsing fx graph to cube graph...") +from cube.graph.parser import FxModuleParser +# dummy_inputs = [inputs.input_ids, None, inputs.attention_mask, None, None, None, None, None, None, None, {}] +# FxModuleParser.parse(traced_graph, dummy_inputs) +FxModuleParser.parse(traced_graph, inputs) +print("parsing done.") + +# AutoDist +# from autodist.apis import compile +# from cube.runtime.resource import EnvResource +# resource = EnvResource() +# graph = compile(ir_graph, resource) + + +# print(type(model), '; is nn.Module? ', isinstance(model, nn.Module)) +# print("Model's generation config which does not list default values: ", model.generation_config) +# print("Loading tokenizer...") +# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) +# print("Loading Done!") + +# #prompt = "If you are a calculator, please tell me the results of 32 x 23 =" +# # prompt = "what is the english word that means little modification and starts with character "t"? the english word is" +# # prompt = "If I want to travel to USA, I need to apply for a" +# prompt = "If I want to travel to a new city, I should plan my trip as follows:" +# # prompt = "I look forward to" +# # prompt = "Today was an amazing day because" +# # prompt = "What is the color of a carrot?\nA:" + + +# # Some of the commonly adjusted parameters: max_new_tokens, num_beams, do_sample, num_return_sequences +# # https://huggingface.co/blog/how-to-generate +# # https://huggingface.co/docs/transformers/v4.26.1/en/generation_strategies#text-generation-strategies +# # Beam-search decoding +# generation_config_beam = GenerationConfig( +# num_beams=4, +# do_sample=False, +# early_stopping=True, +# decoder_start_token_id=0, +# eos_token_id=model.config.eos_token_id, +# pad_token=model.config.pad_token_id, +# ) +# # Beam-search decoding without early stopping +# generation_config_beam_fixed_len = GenerationConfig( +# num_beams=4, +# do_sample=False, +# early_stopping=False, +# max_new_tokens=20, +# decoder_start_token_id=0, +# eos_token_id=model.config.eos_token_id, +# pad_token=model.config.pad_token_id, +# ) +# # Contrastive search +# generation_config_contrastive = GenerationConfig( +# penalty_alpha=0.6, +# top_k=4, +# max_new_tokens=100, +# ) + +# print("Tokenizing prompt...") +# input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() +# print("input_ids shape: ", input_ids.size()) +# print("Generating sequence ids...") +# generated_ids = model.generate(input_ids, generation_config=generation_config_beam_fixed_len) +# print("Decoding sequence ids...") +# output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) +# print(output) From f925f0aa5b7373becc79a3c3b5e090a5205b4a86 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Mar 2023 20:56:09 +0800 Subject: [PATCH 1270/1892] fix async op --- cube/codegen/emit.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 73047db7..a1e12706 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -157,12 +157,13 @@ def emit_adapter(node: IRAdapter, prefix_attr: Optional[str] = None) -> List[str # only adapter that is non-differentiable can be executed as async async_op = CompileFlag.async_comm and (not node.differentiable) - for idx, prim in enumerate(prims): - if isinstance(prim, IRAdapterPrim) and prim.volume() == 0: - continue - break - #TODO: support more general cases: independent same-group primitives - async_op = False if len(prims[idx:]) != 1 else async_op + if async_op: + for idx, prim in enumerate(prims): + if isinstance(prim, IRAdapterPrim) and prim.volume() == 0: + continue + break + #TODO: support more general cases: independent same-group primitives + async_op = False if len(prims[idx:]) != 1 else async_op for prim in prims: if len(prim.inputs()) == 1: From 9784d4e67c7683975f92d3a1400f853628ee5597 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Mar 2023 21:21:08 +0800 Subject: [PATCH 1271/1892] fix optimization on staging --- cube/graph/graph.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 3c0cd597..446503ff 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -689,13 +689,14 @@ def staging(self, nodes: Tuple[IRFwOperation]): f"Find node is not IRFwOperation or IRDataOperation: {node}" assert all(node in self._nodes for node in nodes), \ f"Exist node is not in graph nodes" - starts = tuple(self._nodes.index(node) for node in nodes) + starts = list(self._nodes.index(node) for node in nodes) assert len(starts) > 0 # multiref (created by graph.auto_multiref) will be moved to the next stage (if possible) for optimization for sid in range(len(starts)): while starts[sid] > 0: - if self.node(starts[sid]-1).name == 'multiref': + node = self.node(starts[sid]-1) + if node.name == 'multiref' or isinstance(node, IRGraphAnchor): starts[sid] -= 1 continue break From 6f5f419453e312527248fed33200832626d7af27 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Mar 2023 21:51:56 +0800 Subject: [PATCH 1272/1892] local indicator --- cube/ir/adapter/prim.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index 0dab9656..b29e3232 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -19,6 +19,8 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor], **kwar for arg, val in kwargs.items(): self.kwargs[arg] = val self.signature = None + # whether the primitive is happened locally + self.local: bool = False def input(self, idx:int): return self._inputs[idx] @@ -64,6 +66,7 @@ class SpatialPrim(IRAdapterPrim): def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor], **kwargs): super().__init__(inputs, outputs, **kwargs) self.device = list(set(t.device[0] for t in inputs)) + self.local = True def volume(self) -> int: return 0 @@ -77,6 +80,7 @@ class ValuePrim(IRAdapterPrim): def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): super().__init__(inputs, outputs) self.device = list(set(t.device[0] for t in inputs)) + self.local = True def volume(self) -> int: return 0 @@ -93,6 +97,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k for t in list(itensors) + list(otensors): devices += t.device self.device = list(set(devices)) + self.local = False def dispatch(self, devid: int) -> Optional[IRAdapterPrim]: """ From d5dbd609571c50b268171a5b3cb09668fc2bf906 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 1 Mar 2023 07:26:21 -0800 Subject: [PATCH 1273/1892] support long, masked fill --- cube/graph/function/function.py | 25 +++++++++++++++++++++++++ cube/graph/parser/mappingfx.py | 3 +++ cube/graph/parser/parserfx.py | 11 +++++------ cube/ir/dtype.py | 2 ++ cube/runtime/syndata.py | 4 ++-- examples/mlp/linearsfx.py | 19 +++++++++++-------- 6 files changed, 48 insertions(+), 16 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 46e88380..aad1623b 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -495,6 +495,31 @@ def NE(signature, inputs): return IRDimops(NE, 'ne', signature, [anno], [input0], other=input1) +def NanToNum(signature, inputs): + assert len(inputs) == 1 + annos = ['* -> *'] + tensor = inputs[0:1] + return IRDimops(NanToNum, 'nan_to_num', signature, annos, tensor) + + +def Long(signature, inputs): + assert len(inputs) == 1 + annos = ['* -> *'] + tensor = inputs[0:1] + return IRDimops(Long, 'long', signature, annos, tensor) + + +def MaskedFill(signature, inputs): + assert len(inputs) == 3 + input0, input1, value = inputs + + edim_in0 = ShapeAnno.create_shape_str(input0.shape) + edim_in1 = ShapeAnno.create_shape_str(input1.shape) + edim_ou = copy.copy(edim_in0) + anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) + return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input0, input1], value=value) + + def LayerNorm(signature, inputs): """ torch.nn.functional.layer_norm(input, normliazed_shape, weight=None, bias=None, eps) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index de839451..a91be34f 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -65,6 +65,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('triu'): function.Triu, __ftemplate('relu') : function.ReLU, __ttemplate('ne') : function.NE, + __ttemplate('nan_to_num') : function.NanToNum, + __tttemplate('long'): function.Long, + __ttemplate('masked_fill'): function.MaskedFill, # # torch nn functional # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index f5a9326a..be4b81f8 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -408,12 +408,11 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: # return func.__name__ # TODO(yizhu1): find a general solution if isinstance(func, str): - if getattr(torch, func, None) is not None: - return f'torch.{func}' - elif getattr(torch.Tensor, func, None) is not None: - return f'torch.Tensor.{func}' - else: - raise RuntimeError(f'cannot find module for {func}') + for module, module_name in [(torch, 'torch'), (torch.Tensor, 'torch.Tensor')]: + lib_func = getattr(module, func, None) + if lib_func is not None and callable(lib_func): + return f'{module_name}.{func}' + raise RuntimeError(f'cannot find module for {func}') name = func.__name__ module = FxModuleParser._find_module_of_method(func) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index edba56d1..ab3a463b 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -49,6 +49,8 @@ def infer(node, dtypes: List[IRDType]) -> IRDType: # TODO(yizhu1): hack if node.signature == 'torch.ne': return IRDType.boolean + elif node.signature == 'torch.Tensor.long': + return IRDType.int64 # in priority: fp32 > fp16 > bool > int64 > int16 > priority = [ IRDType.float64, IRDType.float32, IRDType.float16, diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index b869fada..47e4baee 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -153,9 +153,9 @@ def random_sample(self) -> Tuple[torch.Tensor]: for shape, dtype in zip(self.shapes, self.dtypes): datas.append( torch.rand( - shape, dtype=dtype, + shape, device=torch.cuda.current_device(), - requires_grad=False) + requires_grad=False).to(dtype) ) return tuple(datas) diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index afd36e2d..78b6bd91 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -58,8 +58,8 @@ def __init__(self, dim, mult=1, nlayers=4): self.drop_out = nn.Dropout() # self.y = torch.nn.Parameter(torch.empty(128, last_dim)) - def forward(self, data): - x = data + def forward(self, data, mask): + x = data.masked_fill(mask, 0.0) for layer in self.layers: x = layer(x) x = torch.nn.functional.relu(x) @@ -69,11 +69,14 @@ def forward(self, data): x = self.drop_out(x) x = x.squeeze() x = torch.triu(x, 1) + x = torch.nan_to_num(x) # ne cannot backward # x = torch.ne(x, 1.0) # x = torch.nn.functional.dropout(x, self.p) # x = x * self.y loss = torch.sum(x) + # long cannot backward + # loss = loss.long() return loss @@ -83,19 +86,19 @@ def train(): model = MLP(dim=dim) model = cube.SemanticModel( - model, input_shapes=([batch_size, dim],), + model, input_shapes=([batch_size, dim], [batch_size, dim]), ) dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, dim],), - dtypes=(torch.float32,), - batch_dims=(0,) + shapes=([batch_size, dim], [batch_size, dim],), + dtypes=(torch.float32, torch.bool,), + batch_dims=(0, 0,) ) @cube.compile(model, dataloader, PAS=PAS, load_content=False) def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) + data, mask = next(dataloader) + loss = model(data, mask) loss.backward() model = model.get_gen_module() From 77209cbfd48b1290ea588a687978468bd53db319 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 1 Mar 2023 17:29:48 -0800 Subject: [PATCH 1274/1892] contain youshan's update --- cube/graph/parser/parserfx.py | 55 +++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 6b524a92..b621af50 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -101,7 +101,7 @@ def parse(module: torch.fx.GraphModule, print(f'dummy_inputs has {input.name}') shape = getattr(dummy_inputs, input.name).size()# None if dummy_inputs is None else dummy_inputs[idx].size() else: - # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is aligned with input.name + # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name print(f'dummy_inputs does not have {input.name}') shape = None # FIXME: use the input's real dtype @@ -110,15 +110,21 @@ def parse(module: torch.fx.GraphModule, frame.add_var(input.name, val, graph_arg=idx) input_val = [frame.get_var(input.name) for input in inputs] - # add activations to frame, including call_func output and final output - activation_nodes = [node for node in module.graph.nodes if (node.op == 'call_function' or node.op == 'output')] + # add activations to frame, including call_func/call_method output and final output + activation_op_strs = {'call_function', 'output', 'call_method'} + activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] for node in activation_nodes: - assert isinstance(node, torch.fx.Node) - shape = node.meta['tensor_meta'].shape - shape = FxModuleParser.shape_refine(shape) - dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) - frame.add_var(node.name, val) + if hasattr(node, 'meta') and node.meta.get('tensor_meta') and hasattr(node.meta['tensor_meta'], 'dtype'): + assert isinstance(node, torch.fx.Node) + shape = node.meta['tensor_meta'].shape + shape = FxModuleParser.shape_refine(shape) + dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) + val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) + frame.add_var(node.name, val) + else: + print(f'WARNING: creation of no-shaped activation for {node.name}') + val = IRFullTensor(shape=[1], requires_grad=True, dtype=ir.int32, name=node.name) # TODO fixme + frame.add_var(node.name, val) # handle nodes all_ir_nodes: List[IRFwOperation] = list() @@ -366,6 +372,34 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, return [ir_node] + @staticmethod + def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + # get signature + fsig = FxModuleParser._get_qualified_name(node.target) + print(f'parse_prim_method_node: {fsig}') + + # get inputs + input_nodes = [input_node for input_node in node.args] + input_vals = list() + for index, input_node in enumerate(input_nodes): + if isinstance(input_node, torch.fx.Node): + var_name = input_node.name + val = frame.get_var(var_name) + input_vals.append(val) + elif isinstance(input_node, int): + input_vals.append(input_node) + else: + input_vals.append(None) + + # map to IR operator + ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + + output_name = node.name + output_val = frame.get_var(output_name) + ir_node.set_output(0, output_val) + + return [ir_node] + @staticmethod def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: assert node is not None @@ -390,6 +424,9 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: # # things like getattr just appear in builtins # if getattr(builtins, func.__name__, None) is func: # return func.__name__ + if isinstance(func, str): + # TODO(yizhu1): find a general solution + return f'torch.{func}' name = func.__name__ module = FxModuleParser._find_module_of_method(func) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module From 1e1ae641a8dbb6dd401684a0ea07fb804d798e37 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 1 Mar 2023 19:03:03 -0800 Subject: [PATCH 1275/1892] fix triu anno --- cube/graph/function/function.py | 3 +++ examples/mlp/linearsfx.py | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index aad1623b..546c1bd6 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -865,6 +865,9 @@ def Triu(signature, inputs): input, diagonal = inputs edim_in = ShapeAnno.create_shape_str(input.shape) + assert len(edim_in) >= 2 + edim_in[-1] += '^' + edim_in[-2] += '^' edim_ou = copy.copy(edim_in) anno = OpAnno.create_op_str([edim_in], [edim_ou]) diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 78b6bd91..1496c861 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -81,16 +81,16 @@ def forward(self, data, mask): def train(): - batch_size = 128 - dim = 8192 + batch_size = 32 + dim = 1024 model = MLP(dim=dim) model = cube.SemanticModel( - model, input_shapes=([batch_size, dim], [batch_size, dim]), + model, input_shapes=([batch_size, dim, dim], [batch_size, dim, dim]), ) dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, dim], [batch_size, dim],), + shapes=([batch_size, dim, dim], [batch_size, dim, dim],), dtypes=(torch.float32, torch.bool,), batch_dims=(0, 0,) ) From 07f6c6aff4e40a8f05da17781aac7b423c74aabb Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 1 Mar 2023 23:09:50 -0800 Subject: [PATCH 1276/1892] minor --- cube/graph/parser/parserfx.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 4f7f7f1e..f089371f 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -91,6 +91,8 @@ def parse(module: torch.fx.GraphModule, else: print(f'{node.name} does not has tensor_meta') + # return + # handle graph input -- some inputs could be None or not tensor default_dtype = torch.get_default_dtype() kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype @@ -129,6 +131,7 @@ def parse(module: torch.fx.GraphModule, # handle nodes all_ir_nodes: List[IRFwOperation] = list() for node in module.graph.nodes: + print('zql handle node: ', node, node.op, node.meta) ir_nodes = FxModuleParser.parse_node(node, module, frame) all_ir_nodes += ir_nodes @@ -376,6 +379,10 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, @staticmethod def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + print('zql: ', node, node.__dir__()) + print('zql: ', node.args, node.kwargs) + print('zql: ', type(node.args[0])) + print('zql: ', node.args[0].name, node.args[0].op, node.args[0].meta) # get signature fsig = FxModuleParser._get_qualified_name(node.target) print(f'parse_prim_method_node: {fsig}') From 2b229f8143e1cfefa635aedafcd776228e988fee Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Thu, 2 Mar 2023 09:24:48 +0000 Subject: [PATCH 1277/1892] Merged PR 1459: Support complex dictionary input Support complex dictionary input --- .../parser/concrete_trace_utils/__init__.py | 14 + .../concrete_trace_utils/concrete_proxy.py | 422 +++++ .../concrete_trace_utils/concrete_tracer.py | 1416 +++++++++++++++++ .../concrete_trace_utils/operator_patcher.py | 270 ++++ .../parser/concrete_trace_utils/utils.py | 97 ++ cube/graph/parser/converter.py | 42 +- cube/graph/parser/parserfx.py | 47 +- cube/program.py | 5 +- cube/runtime/syndata.py | 22 +- examples/nlp/torchscale/fx_test.py | 203 +++ examples/nlp/torchscale/policy/mpmd.py | 103 ++ examples/nlp/torchscale/policy/spmd.py | 231 +++ 12 files changed, 2847 insertions(+), 25 deletions(-) create mode 100644 cube/graph/parser/concrete_trace_utils/__init__.py create mode 100644 cube/graph/parser/concrete_trace_utils/concrete_proxy.py create mode 100644 cube/graph/parser/concrete_trace_utils/concrete_tracer.py create mode 100644 cube/graph/parser/concrete_trace_utils/operator_patcher.py create mode 100644 cube/graph/parser/concrete_trace_utils/utils.py create mode 100644 examples/nlp/torchscale/fx_test.py create mode 100644 examples/nlp/torchscale/policy/mpmd.py create mode 100644 examples/nlp/torchscale/policy/spmd.py diff --git a/cube/graph/parser/concrete_trace_utils/__init__.py b/cube/graph/parser/concrete_trace_utils/__init__.py new file mode 100644 index 00000000..b825294a --- /dev/null +++ b/cube/graph/parser/concrete_trace_utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +FX is a toolkit for developers to use to transform ``nn.Module`` instances. FX consists of three main components, and this pipeline of +components (symbolic tracing -> intermediate representation -> transforms -> Python code generation) constitutes the Python-to-Python +transformation pipeline. + +This util consists a **concrete tracer** which extends the **symbolic tracer** in FX. It performs "concrete execution" of the Python code. +Then we can get the **intermediate representation** of ``nn.Module`` instances. + +More information about concrete tracing can be found in the :func:`concrete_trace` documentation. +""" +from .concrete_tracer import ConcreteTracer, concrete_trace \ No newline at end of file diff --git a/cube/graph/parser/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/concrete_trace_utils/concrete_proxy.py new file mode 100644 index 00000000..8b2fcc57 --- /dev/null +++ b/cube/graph/parser/concrete_trace_utils/concrete_proxy.py @@ -0,0 +1,422 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import dis +import logging +import inspect +import operator + +from typing import List, Optional, Iterable, Any, Set, Union + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.graph import magic_methods, reflectable_magic_methods +from torch.fx.node import Node +from torch.fx.proxy import Proxy +from torch.overrides import is_tensor_method_or_property + +from . import concrete_tracer as et +from .utils import ( + _orig_tuple, + _orig_list, + _orig_type, + _orig_isinstance, + _orig_getattr, + _orig_range, + _orig_dict, + _orig_len, + _orig_index, + _orig_bool, + _orig_slice, + _orig_set, + map_recursive, +) + +_logger = logging.getLogger(__name__) + +@compatibility(is_backward_compatible=True) +class ConcreteProxy(Proxy): + """ + `ConcreteProxy` is a wrapped proxy carried the real intermediate value. + We can use it to trace a more compatibal model, and pass the branches. + """ + + # TODO: python bytecode changes a lot in version 3.11. these ops should be updated. + jump_opnames = ( + 'JUMP_IF_FALSE_OR_POP', + 'JUMP_IF_TRUE_OR_POP', + 'POP_JUMP_IF_FALSE', + 'POP_JUMP_IF_TRUE', + 'JUMP_IF_NOT_EXC_MATCH', # occurred in new python vertion, not tested + ) + jump_opcodes = _orig_tuple(dis.opmap[name] for name in jump_opnames if name in dis.opmap) + op_compare = dis.opmap['COMPARE_OP'] + op_extended_arg = dis.opmap['EXTENDED_ARG'] + op_call = dis.opmap['CALL_FUNCTION'] + op_call_ex = dis.opmap['CALL_FUNCTION_EX'] + op_not = dis.opmap['UNARY_NOT'] + op_unpack_sequence = dis.opmap['UNPACK_SEQUENCE'] + jump_before_opcodes = (op_compare, op_not) + + # occurred in different python versions + op_list_extend = dis.opmap['LIST_EXTEND'] if 'LIST_EXTEND' in dis.opmap else None + op_tuple_unpack_call = dis.opmap['BUILD_TUPLE_UNPACK_WITH_CALL'] if 'BUILD_TUPLE_UNPACK_WITH_CALL' in dis.opmap else None + + def __init__(self, node: Node, value: Any, tracer: Optional[et.ConcreteTracer] = None): + if tracer is None: + # This allows you to create a ConcreteProxy object around a raw Node + tracer = et.GraphAppendingConcreteTracer(node.graph) + self.tracer = tracer + self.value = value + self.node = node + + def __repr__(self) -> str: + # to detect if in debugging or in code + calling_frame_name = inspect.stack()[1][1] + if calling_frame_name.endswith('pydevd_exe2.py') or calling_frame_name.endswith('pydevd_safe_repr.py'): + return f'ConcreteProxy({self.node.name})' + return repr(self.value) + + def __getattr__(self, k) -> ConcreteProxy: + return ConcreteAttrProxy(self, k) + + def __call__(self, *args, **kwargs) -> ConcreteProxy: + return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) + + def __iter__(self) -> Union[Iterable, ConcreteProxy]: + # to detect if in executing `*proxy`, or `a, b, c = atuple` + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + cur = calling_frame.f_lasti // 2 + insts: List[dis.Instruction] = _orig_list(dis.get_instructions(calling_frame.f_code)) + while insts[cur].opcode == self.op_extended_arg: + cur += 1 + + if insts[cur].opcode == self.op_call_ex: + # in executing func(..., *proxy) + # todo: don't know the func has type_guard or not + return ConcreteUnpackIterProxy(self) + elif insts[cur].opcode == self.op_tuple_unpack_call: + # in executing func(*..., *proxy) + # todo: don't know the func has type_guard or not + # <= python 3.8 + return ConcreteUnpackIterProxy(self) + elif insts[cur].opcode == self.op_list_extend: + # in executing x.extend(proxy) or [x, *proxy] + # >= python 3.9 + return ConcreteUnpackIterProxy(self) + elif insts[cur].opcode == self.op_unpack_sequence: + # in executing `a, b, c = atuple` + return ConcreteUnpackIterProxy(self) + elif insts[cur].opname == 'GET_ITER' and insts[cur + 1].opname == 'FOR_ITER' and _orig_isinstance(self.value, _orig_range): + # in executing `for i in range(...)` + return iter(self.value) + # elif insts[cur].opname == 'CONTAINS_OP': + # # in executing `for i in range(...)` + # return iter(self.value) + else: + return self.tracer.create_proxy('call_function', iter, (self,), {}) + + def __next__(self) -> ConcreteProxy: + return self.tracer.create_proxy('call_function', next, (self,), {}) + + def __len__(self) -> Union[int, ConcreteProxy]: + # to detect if in executing `*proxy` + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + cur = calling_frame.f_lasti // 2 + insts: List[dis.Instruction] = _orig_list(dis.get_instructions(calling_frame.f_code)) + while insts[cur].opcode == self.op_extended_arg: + cur += 1 + + if insts[cur].opcode == self.op_call_ex: + # in executing func(..., *proxy) + return _orig_len(self.value) + elif insts[cur].opcode == self.op_tuple_unpack_call: + # in executing func(*..., *proxy) + # <= python 3.8 + return _orig_len(self.value) + elif insts[cur].opcode == self.op_list_extend: + # in executing x.extend(*proxy) or [x, *proxy] + # >= python 3.9 + return _orig_len(self.value) + else: + return self.tracer.create_proxy('call_function', _orig_len, (self,), {}) + + def __getitem__(self, *args, **kwargs) -> ConcreteProxy: + return self.tracer.create_proxy('call_function', operator.getitem, (self,) + args, kwargs) + + def __setitem__(self, *args, **kwargs) -> ConcreteProxy: + return self.tracer.create_proxy('call_function', operator.setitem, (self,) + args, kwargs) + + def __bool__(self) -> Union[bool, ConcreteProxy]: + # to detect if in executing branch condition + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + cur = calling_frame.f_lasti // 2 + insts: List[dis.Instruction] = _orig_list(dis.get_instructions(calling_frame.f_code)) + while insts[cur].opcode == self.op_extended_arg: + cur += 1 + + if insts[cur].opcode in self.jump_opcodes or ( + insts[cur].opcode in self.jump_before_opcodes and insts[cur + 1].opcode in self.jump_opcodes): + # in executing branch condition + return _orig_bool(self.value) + elif insts[cur].opname == 'CONTAINS_OP': + # in executing 'in' + return _orig_bool(self.value) + elif insts[cur].opcode == self.op_not: + # We cannot return a proxy because 'UNARY_NOT' op will check the type. + _logger.warning('please use the function patcher, or use "x = operator.not_(y)" instead of "x = not y",' + 'otherwise the traced graph may be wrong') + return _orig_bool(self.value) + else: + return self.tracer.create_proxy('call_function', _orig_bool, (self,), {}) + + def __index__(self) -> Union[int, ConcreteProxy]: + # should only be in list/tuple getitem + return _orig_index(self.value) + + def __hash__(self) -> Union[int, ConcreteProxy]: + # should only be in dict getitem + return hash(self.value) + + @compatibility(is_backward_compatible=True) + def keys(self): + # to detect if in executing `**proxy` + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + cur = calling_frame.f_lasti // 2 + insts: List[dis.Instruction] = _orig_list(dis.get_instructions(calling_frame.f_code)) + while insts[cur].opcode == self.op_extended_arg: + cur += 1 + + if insts[cur].opcode == self.op_call_ex: + # in executing `**proxy` + return self.value.keys() + else: + return self.tracer.create_proxy('call_method', 'keys', (self,), {}) + + @compatibility(is_backward_compatible=True) + def values(self): + return self.tracer.create_proxy('call_method', 'values', (self,), {}) + + @compatibility(is_backward_compatible=True) + def items(self): + return self.tracer.create_proxy('call_method', 'items', (self,), {}) + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + # to wrap all the functions/methods with tensor inputs in the namespace 'torch.*'. + # actually a simple way to do wrap, but may get wrong in functions with no tensor inputs. + # TODO: now for most functions in torch namespace, we do wrap directly and not use __torch_function__ + + args = args if args else () + kwargs = kwargs if kwargs else {} + + tracers: Set[Any] = _orig_set() + + def find_tracer(a): + if _orig_isinstance(a, cls): + tracers.add(a.tracer) + map_recursive(find_tracer, args) + map_recursive(find_tracer, kwargs) + + if _orig_len(tracers) > 1: + raise RuntimeError(f'Found multiple different tracers {_orig_list(tracers)} while ' + f'trying to trace operations {orig_method}') + tracer, = tracers + + if isinstance(orig_method, torch._C.ScriptMethod): + args = (orig_method.owner,) + args + return tracer.create_proxy('call_method', orig_method.name, args, kwargs) + if is_tensor_method_or_property(orig_method): + return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) + else: + return tracer.create_proxy('call_function', orig_method, args, kwargs, + name=tracer.graph._target_to_str(orig_method.__name__)) + + +@compatibility(is_backward_compatible=True) +class ConcreteAttrProxy(ConcreteProxy): + """ + A more understandable way to deal with sub-field like 'x.y'. + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, root: ConcreteProxy, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node: Optional[Node] = None + self.value = _orig_getattr(root.value, attr) + + def __repr__(self) -> str: + calling_frame_name = inspect.stack()[1][1] + if calling_frame_name.endswith('pydevd_exe2.py') or calling_frame_name.endswith('pydevd_safe_repr.py'): + return f'ConcreteAttrProxy({self.node.name})' + return repr(self.value) + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy( + 'call_function', _orig_getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + +@compatibility(is_backward_compatible=True) +class ConcreteUnpackIterProxy(ConcreteProxy): + """ + A more understandable way to deal with iterables. + Only support 'tuple' and 'list'. Will transfer un-subscriptables such as 'set', to 'tuple'. + todo: support for 'zip' + + examples: + 1. `a, b = c` => + ori: + iter1 = c.__iter__() + a = iter1.__next__() + b = iter1.__next__() + new: + a = c[0] + b = c[1] + + 2. `y = [x, *proxy]` => + ori: + iter1 = c.__iter__() + a = iter1.__next__() + b = iter1.__next__() + y = [x, a, b] + new: + a = proxy[0] + b = proxy[1] + y = [x, a, b] + """ + + @staticmethod + def try_create(root: Any): + if isinstance(root, ConcreteProxy): + return ConcreteUnpackIterProxy(root) + else: + return iter(root) + + @compatibility(is_backward_compatible=True) + def __init__(self, root: ConcreteProxy): + if not hasattr(root.value, '__getitem__'): + # transfer 'set' to 'tuple' + # it's tuple not _orig_tuple! + # root = tuple(root) + root = root.tracer.create_proxy('call_function', _orig_tuple, (root,), {}) + self.root = root + self.tracer = root.tracer + self._node: Optional[Node] = None + self._value: List[Any] = [] + self.index = -1 + self.len = _orig_len(root.value) + + def __repr__(self) -> str: + return f'ConcreteUnpackIterProxy({self.node.name})' + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy( + 'call_function', iter, (self.root,), {}).node + return self._node + + @property + def value(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if _orig_len(self._value) == 0: + self._value.append(iter(self.root.value)) + return self._value[0] + + def __next__(self): + self.index += 1 + if self.index == self.len: + raise StopIteration() + return self.tracer.create_proxy('call_function', operator.getitem, (self.root, self.index), {}) + +@compatibility(is_backward_compatible=True) +def map_aggregate_not_proxy(a, fn): + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ + if _orig_isinstance(a, ConcreteProxy): + return fn(a) + elif _orig_isinstance(a, _orig_tuple): + t = _orig_tuple(map_aggregate_not_proxy(elem, fn) for elem in a) + # Support NamedTuple (if it has `_fields`) by repacking into original type. + return t if not hasattr(a, '_fields') else _orig_type(a)(*t) + elif _orig_type(a) == _orig_list: + return _orig_list(map_aggregate_not_proxy(elem, fn) for elem in a) + elif _orig_isinstance(a, _orig_dict): + return _orig_dict((k, map_aggregate_not_proxy(v, fn)) for k, v in a.items()) + elif _orig_isinstance(a, _orig_slice): + return _orig_slice(map_aggregate_not_proxy(a.start, fn), map_aggregate_not_proxy(a.stop, fn), map_aggregate_not_proxy(a.step, fn)) + else: + return fn(a) + +# register or wrap common methods on 'ConcreteProxy' +# for method in magic_methods: +# torch.fx.graph.inplace_methods may not exist on some verion of pytorch +inplace_methods = { + 'iadd': '{} += {}', + 'iand': '{} &= {}', + 'ifloordiv': '{} //= {}', + 'ilshift': '{} <<= {}', + 'imod': '{} %= {}', + 'imul': '{} *= {}', + 'imatmul': '{} @= {}', + 'ior': '{} |= {}', + 'ipow': '{} **= {}', + 'irshift': '{} >>= {}', + 'isub': '{} -= {}', + 'itruediv': '{} /= {}', + 'ixor': '{} ^= {}', + 'setitem': '{}[{}] = {}', +} +for method in {**magic_methods, **inplace_methods}: + def _scope(method): + def impl(*args, **kwargs): + tracer = args[0].tracer + target = _orig_getattr(operator, method) + return tracer.create_proxy('call_function', target, args, kwargs) + impl.__name__ = method + as_magic = f'__{method.strip("_")}__' + setattr(ConcreteProxy, as_magic, impl) + _scope(method) + + +def _define_reflectable(orig_method_name): + method_name = f'__r{orig_method_name.strip("_")}__' + + def impl(self, rhs): + target = _orig_getattr(operator, orig_method_name) + return self.tracer.create_proxy('call_function', target, (rhs, self), {}) + impl.__name__ = method_name + impl.__qualname__ = method_name + setattr(ConcreteProxy, method_name, impl) + + +for orig_method_name in reflectable_magic_methods: + _define_reflectable(orig_method_name) \ No newline at end of file diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py new file mode 100644 index 00000000..db1e3b84 --- /dev/null +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -0,0 +1,1416 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import sys +import inspect +import operator +import functools +import builtins +import copy + +from itertools import chain +from types import BuiltinMethodType, FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType +from typing import Any, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, List, Callable, Union +from contextlib import contextmanager + +import torch +from torch._C import ScriptObject +from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict + +import torch.fx +from torch.fx import GraphModule +from torch.fx._compatibility import compatibility +from torch.fx._symbolic_trace import _Patcher, _proxyable_classes +from torch.fx.graph import Graph +from torch.fx.node import Target, Node +from torch.fx.proxy import TracerBase + +from . import concrete_proxy as ep +from .operator_patcher import OperatorPatcherContext +from .utils import ( + _orig_module_call, + _orig_module_getattr, + _orig_module_getattribute, + + _orig_agfunc_apply, + _orig_torch_assert, + + _orig_type, + _orig_isinstance, + _orig_getattr, + + _orig_range, + _orig_int, + _orig_bool, + _orig_tuple, + _orig_list, + _orig_set, + _orig_frozenset, + _orig_dict, + _orig_map, + _orig_zip, + _orig_enumerate, + _orig_slice, + + _orig_len, + _orig_not, + _orig_is, + _orig_is_not, + _orig_contains, + _orig_index, +) + +HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS + +@compatibility(is_backward_compatible=True) +class ConcreteTracer(TracerBase): + """ + A model tracer similar to _symbolic_trace.Tracer, but with concrete execution and real value so we can pass complecate conditions + and go into correct brunches. + """ + + default_autowrap_modules = ( + 'math', + ) + default_autowrap_leaf_function: Dict[Any, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool, Optional[Callable]]] = { + # function + _orig_len: ([], False, None), + _orig_not: ([], False, None), + _orig_is: ([], False, None), + _orig_is_not: ([], False, None), + _orig_contains: ([], False, None), + _orig_index: ([], False, None), + + # force-traced function + torch.rand: ([], True, None), + torch.randn: ([], True, None), + torch.randint: ([], True, None), + torch.rand_like: ([], True, None), + torch.randn_like: ([], True, None), + torch.randint_like: ([], True, None), + torch.randperm: ([], True, None), + + # method + Sequential.__getitem__: ([], False, operator.getitem), + Sequential.__len__: ([], False, _orig_len), + Sequential.__iter__: ([], False, iter), + + ModuleList.__getitem__: ([], False, operator.getitem), + ModuleList.__len__: ([], False, _orig_len), + ModuleList.__iter__: ([], False, iter), + + ModuleDict.__getitem__: ([], False, operator.getitem), + ModuleDict.__len__: ([], False, _orig_len), + ModuleDict.__iter__: ([], False, iter), + ModuleDict.__contains__: ([], False, _orig_contains), + + ParameterList.__getitem__: ([], False, operator.getitem), + ParameterList.__len__: ([], False, _orig_len), + ParameterList.__iter__: ([], False, iter), + + ParameterDict.__getitem__: ([], False, operator.getitem), + ParameterDict.__len__: ([], False, _orig_len), + ParameterDict.__iter__: ([], False, iter), + ParameterDict.__contains__: ([], False, _orig_contains), + } + # equals to `from torch.nn import functional as nn_functional` + # to pass pyright check + nn_functional = getattr(torch.nn, 'functional') + # order: torch.nn.functional > torch._C._VariableFunctions > torch._C._nn > torch._C._TensorBase + for name in torch.functional.__all__: + attr = getattr(torch.functional, name) + if attr not in default_autowrap_leaf_function: + default_autowrap_leaf_function[attr] = ([], False, attr) + for name in dir(nn_functional): + attr = getattr(nn_functional, name) + if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__')\ + and getattr(attr, '__module__', None) not in ('typing', 'torch.nn.modules.utils'): + if attr not in default_autowrap_leaf_function: + default_autowrap_leaf_function[attr] = ([], False, getattr(torch.functional, name, None)) + if hasattr(attr, '__module__') and attr.__module__ != 'torch.nn.functional': + default_autowrap_leaf_function[attr][0].append((nn_functional, name)) + for name in dir(torch._C._VariableFunctions): + attr = getattr(torch._C._VariableFunctions, name) + if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): + if attr not in default_autowrap_leaf_function: + default_autowrap_leaf_function[attr] = ([], False, getattr(torch.functional, name, None)) + for name in dir(torch._C._nn): + attr = getattr(torch._C._nn, name) + if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): + if attr not in default_autowrap_leaf_function: + default_autowrap_leaf_function[attr] = ([], False, getattr(torch.functional, name, None)) + if hasattr(attr, '__module__') and attr.__module__ != 'torch._C._nn': + default_autowrap_leaf_function[attr][0].append((torch._C._nn, name)) + for name in dir(torch._C._TensorBase): + attr = getattr(torch._C._TensorBase, name) + if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): + if attr not in default_autowrap_leaf_function: + to_func = getattr(torch.Tensor, name, None) + to_func = None if to_func == attr else to_func + default_autowrap_leaf_function[attr] = ([], False, to_func) + + default_autowrap_leaf_class: Dict[Type, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool]] = { + # class + _orig_bool: ([], False), + _orig_zip: ([], False), + _orig_int: ([], False), + + # iterable class + _orig_tuple: ([], True), + _orig_list: ([], True), + _orig_set: ([], True), + _orig_frozenset: ([], True), + _orig_dict: ([], True), + } + + current_module_qualified_name : str = '' + node_to_originating_module : Dict[torch.fx.Node, str] = {} + + @compatibility(is_backward_compatible=True) + def __init__(self): + """ + similar to _symbolic_trace.Tracer.__init__. + remove the 'param_shapes_constant' because we can get real shape when executing. + """ + super().__init__() + + @contextmanager + def do_temp_disable(self, call=False, attr=False, agfunc_apply=False): + assert call | attr | agfunc_apply + # to pass pyright check + temp_disable_call, temp_disable_attr, temp_disable_agfunc_apply = False, False, False + if call: + self.temp_disable_call_level += 1 + temp_disable_call = self.temp_disable_call + self.temp_disable_call = True + if attr: + self.temp_disable_attr_level += 1 + temp_disable_attr = self.temp_disable_attr + self.temp_disable_attr = True + if agfunc_apply: + self.temp_disable_agfunc_apply_level += 1 + temp_disable_agfunc_apply = self.temp_disable_agfunc_apply + self.temp_disable_agfunc_apply = True + try: + yield + finally: + if agfunc_apply: + self.temp_disable_agfunc_apply = temp_disable_agfunc_apply + self.temp_disable_agfunc_apply_level -= 1 + if attr: + self.temp_disable_attr = temp_disable_attr + self.temp_disable_attr_level -= 1 + if call: + self.temp_disable_call = temp_disable_call + self.temp_disable_call_level -= 1 + + @compatibility(is_backward_compatible=True) + def fetch_attr(self, target: str) -> Any: + """ + to get the attr in self.root. only for execution of 'call_module' nodes. + """ + with self.do_temp_disable(attr=True): + target_atoms = target.split('.') + attr_itr = self.root + for i, atom in _orig_enumerate(target_atoms): + # if atom == '': + # continue + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target \'{'.'.join(target_atoms[:i])}\'") + attr_itr = _orig_getattr(attr_itr, atom) + return attr_itr + + def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]): + """ + actually execute the code. + apply the patcher, and the _autowrap_check to the target function. + """ + if kind == 'call_function': + assert isinstance(target, Callable) + fn = target + if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): + _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + with self.do_temp_disable(call=True): + return OperatorPatcherContext.patch_run(fn, *args, **kwargs) + elif kind == 'call_method': + self_obj, *args_tail = args + fn = _orig_getattr(self_obj, target) + if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): + _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + with self.do_temp_disable(call=True): + return OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) + elif kind == 'call_module': + assert isinstance(target, str) + mod = self.fetch_attr(target) + if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(mod, '__globals__'): + _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + with self.do_temp_disable(call=True): + return OperatorPatcherContext.patch_run(mod, *args, **kwargs) + elif kind == 'get_attr': + assert isinstance(target, str) + return self.fetch_attr(target) + elif kind == 'output': + return args[0] + elif kind == 'placeholder': + return self.placeholder_dict[target] + else: + raise RuntimeError() + + @compatibility(is_backward_compatible=True) + def proxy(self, value: Any, node: Node) -> ep.ConcreteProxy: + """ + overloaded to use custom 'proxy'. + """ + return ep.ConcreteProxy(node, value, self) + + @compatibility(is_backward_compatible=True) + def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], + name: Optional[str] = None, type_expr: Optional[Any] = None): + """ + similar to _symbolic_trace.Tracer.create_proxy. + use the 'run_target' to actually execute the code, and store the value in 'value' field. + """ + def upwrapper(obj: Any): + while _orig_isinstance(obj, ep.ConcreteProxy): + obj = obj.value + return obj + args_unwrapped = ep.map_aggregate_not_proxy(args, upwrapper) + kwargs_unwrapped = ep.map_aggregate_not_proxy(kwargs, upwrapper) + + # real value by execution + value_unwrapped = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) + + args_noded = self.create_arg(args) + kwargs_noded = self.create_arg(kwargs) + + assert isinstance(args_noded, tuple) + assert isinstance(kwargs_noded, dict) + + node = self.create_node(kind, target, args_noded, kwargs_noded, name, type_expr) + # return self.proxy(value_unwrapped, node) + proxy = self.proxy(value_unwrapped, node) + self.node_to_originating_module[proxy.node] = self.current_module_qualified_name + return proxy + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> Union[Node, Any]: + """ + similar to _symbolic_trace.Tracer.create_arg + move the base case to the top in case the wrapping of the function 'isinstance' + """ + # base case: we unwrap the Proxy object + if isinstance(a, ep.ConcreteProxy): + return a.node + + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node('get_attr', n, (), {}) + raise NameError('parameter is not a member of this module') + elif isinstance(a, torch.Tensor): + for n_, p_ in self.root.named_buffers(): + if a is p_: + return self.create_node('get_attr', n_, (), {}) + elif isinstance(a, torch.nn.Module): + for n_, p_ in self.root.named_modules(): + if a is p_: + return self.create_node('get_attr', n_, (), {}) + # for slice + if isinstance(a, slice): + start = self.create_arg(a.start) + stop = self.create_arg(a.stop) + step = self.create_arg(a.step) + if _orig_isinstance(start, Node)\ + or _orig_isinstance(stop, Node)\ + or _orig_isinstance(step, Node): + return self.create_node('call_function', _orig_slice, (start, stop, step), {}) + else: + return a + # For NamedTuple instances that appear literally as args, we emit + # a node to construct the NamedTuple and use that Node as the argument. + if isinstance(a, tuple) and hasattr(a, '_fields'): + args = tuple(self.create_arg(elem) for elem in a) + return self.create_node('call_function', a.__class__, args, {}) + + # Tensors do not have a reliable string repr() from which they can be + # constructed (and we probably don't want to rely on that, either), so + # for any constant Tensor values we encounter, first search for if they + # are an attribute of some module in the module hierarchy. If so, emit + # a get_attr to retrieve that tensor. Otherwise, we'll store away the + # tensor value into a special attribute on the Module s.t. we can + # retrieve it with a get_attr. + if isinstance(a, (torch.Tensor, ScriptObject)): + qualname: Optional[str] = self.tensor_attrs.get(a) + + # Tensor was not found in the Module hierarchy, stow it away in a + # TODO: warning for the not found tensor + if not qualname: + i = 0 + while True: + qualname = f'_tensor_constant{i}' + if not hasattr(self.root, qualname): + break + i += 1 + self.tensor_attrs[a] = qualname + setattr(self.root, qualname, a) + + return self.create_node('get_attr', qualname, (), {}) + + if _orig_type(a) in _proxyable_classes: + # This is an instance of a proxyable class for which we did not + # witness its construction. Intern this as a constant attribute + + # TODO: binary search + i = 0 + while True: + qualname = f'_{a.__class__.__name__}_constant_{i}' + if not hasattr(self.root, qualname): + break + i += 1 + setattr(self.root, qualname, a) + + return self.create_node('get_attr', qualname, (), {}) + + if isinstance(a, (torch.autograd.function.Function, torch.autograd.function.FunctionMeta)): + return a + + return super().create_arg(a) + + @compatibility(is_backward_compatible=True) + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + similar to _symbolic_trace.Tracer.is_leaf_module + """ + # return (m.__module__.startswith('torch.nn') and not _orig_isinstance(m, (Sequential, ModuleList, ModuleDict)))\ + # or _orig_isinstance(m, self.leaf_module) + return (m.__module__.startswith('torch.nn.functional') and not _orig_isinstance(m, (Sequential, ModuleList, ModuleDict)))\ + or _orig_isinstance(m, self.leaf_module) + + @compatibility(is_backward_compatible=True) + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + similar to _symbolic_trace.Tracer.path_of_module + """ + # Prefer the O(1) algorithm + if self.submodule_paths: + path = self.submodule_paths.get(mod) + # TODO: better infomation + if path is None: + if not hasattr(self.root, '_module_constants'): + self.root._module_constants = torch.nn.ModuleList() + module_constants = self.root._module_constants + assert isinstance(module_constants, torch.nn.ModuleList) + if hasattr(mod, 'extra_repr'): + sub_path = _orig_type(mod).__name__ + mod.extra_repr() + else: + sub_path = str(_orig_len(module_constants)) + if not hasattr(module_constants, sub_path): + module_constants.add_module(sub_path, mod) + path = '_module_constants.%s' % sub_path + self.submodule_paths[mod] = path + return path + assert isinstance(path, str) + return path + # O(N^2) fallback in the case that we didn't store the submodule + # paths. + else: + for n, p in self.root.named_modules(): + if mod is p: + return n + raise NameError('module is not installed as a submodule') + + # This method will be refactored + @compatibility(is_backward_compatible=False) + def create_args_for_root(self, root_fn, is_module, concrete_args: Union[Dict[str, Any], Tuple]) -> Tuple[Any, list, Any, Any]: + """ + for wrapping all the parameters of the function with dummy_input. + in concrete tracer, we need all the parameters input by users. + + todo: this function should be refactored after the same function in torch.fx be refactored. + """ + # In some cases, a function or method has been decorated with a wrapper + # defined via ``functools.wraps``. In this case, the outer code object + # will likely not contain the actual parameters we care about, so unwrap + # the function to get to the innermost callable. + fn_for_analysis = inspect.unwrap(root_fn) + default_value_list = fn_for_analysis.__defaults__ + if default_value_list is None: + default_value_list = tuple() + co = fn_for_analysis.__code__ + total_args = co.co_argcount + co.co_kwonlyargcount + # orig_args = list(co.co_varnames) + names_iter = iter(co.co_varnames) + args: List[Any] = [] + more_args = [] + kwargs = {} + skip_arg_idx = 0 + if is_module: + if total_args == 0: + raise RuntimeError('``self`` argument cannot be part of *args expansion!') + skip_arg_idx = 1 + next(names_iter) # skip self + args.append(self.root) + + cnt = 0 + self.placeholder_dict = {} + arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] + diff_len = len(arg_names) - len(default_value_list) + default_args = {arg_names[idx + diff_len]: default_value_list[idx] for idx in range(len(default_value_list))} + if isinstance(concrete_args, tuple): + if len(arg_names) != len(concrete_args): + raise RuntimeError(f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments") + concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} + def proxy_placeholder(name: str): + nonlocal cnt + cnt += 1 + + default_arg = () + if name in default_args and not name.startswith('*'): + default_arg = (default_args[name],) + + if name in concrete_args: + self.placeholder_dict[name] = concrete_args[name] + else: + # TODO: better infomation + assert name in default_args + self.placeholder_dict[name] = default_args[name] + return self.create_proxy('placeholder', name, default_arg, {}) + args.extend(proxy_placeholder(names) for names in arg_names) + + + if hasattr(co, 'co_kwonlyargcount') and ( + co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF): + # TODO: type annotations for *args and **kwargs + if co.co_flags & inspect.CO_VARARGS: + name = '*' + next(names_iter) + default_args[name] = () + more_args = proxy_placeholder(name) + if co.co_flags & inspect.CO_VARKEYWORDS: + name = '**' + next(names_iter) + default_args[name] = {} + kwargs = proxy_placeholder(name) + + return root_fn, args, more_args, kwargs + + @compatibility(is_backward_compatible=True) + def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, + autowrap_modules: Tuple[str] = (), + autowrap_leaf_function = {}, + autowrap_leaf_class = {}, + leaf_module = (), + fake_middle_class = (), + concrete_args: Union[Dict[str, Any], Tuple], + use_operator_patch: bool = True, + operator_patch_backlist: List[str] = [], + forwrad_function_name: str = 'forward') -> Graph: + """ + similar to _symbolic_trace.Tracer.trace + different args: + use_operator_patch: + the operators 'not/is/is not/in/not in' cannot be wrapped after + compiled. so we re-parse the functions, replace these operators + with functions 'operator.not_/is_/is_not/contains', then we + could wrap and trace these. + for example: in ``if x is None:``, if x is a proxy, the tracer will + never go into the branch, even x is a proxy with value 'None'. + values: + true: before executing a func, the func will be patched if the func + is not in operator_patch_backlist + false: before executing a func, the func will be patched if the func + is in operator_patch_backlist + + operator_patch_backlist: + such as '__main__.FooModel' or '__main__.bar_func'. the namespace is + always needed. + """ + + # Python modules to apply autowrap to at the start, in addition to + # modules we see while tracing + self._autowrap_search: List[ModuleType] = list(sys.modules[m] for m in (*autowrap_modules, *ConcreteTracer.default_autowrap_modules)) + # Functions we will eagerly wrap when we see them while tracing + # this captures both `math.sqrt()` and `from math import sqrt` automatically + self._autowrap_function_ids: Set[int] = { + id(value) for name, value in chain(*[m.__dict__.items() for m in self._autowrap_search]) + if not name.startswith("_") and callable(value)} + + self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None + + self.autowrap_leaf_function = {**autowrap_leaf_function, **ConcreteTracer.default_autowrap_leaf_function} + self.autowrap_leaf_class = {**autowrap_leaf_class, **ConcreteTracer.default_autowrap_leaf_class} + self.leaf_module = leaf_module + self.fake_middle_class = fake_middle_class + if isinstance(root, torch.nn.Module): + self.root = root + + # TODO: better infomation + assert hasattr( + root, forwrad_function_name + ), f"traced_func_name={forwrad_function_name} doesn't exist in {_orig_type(root).__name__}" + + fn = getattr(root, forwrad_function_name) + self.submodule_paths = {mod: name for name, mod in root.named_modules()} + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls = getattr(self, '__class__', None) + self.graph = Graph(tracer_cls=tracer_cls, tracer_extras={ + 'autowrap_modules': autowrap_modules, + 'autowrap_leaf_function': autowrap_leaf_function, + 'autowrap_leaf_class': autowrap_leaf_class, + 'leaf_module': leaf_module, + 'fake_middle_class': fake_middle_class, + 'concrete_args': concrete_args, + 'use_operator_patch': use_operator_patch, + 'operator_patch_backlist': operator_patch_backlist, + 'forwrad_function_name': 'forward', + }) + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + if isinstance(fn, MethodType): + fn = fn.__func__ + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args, more_args, kwargs = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args) + + self.the_path_of_parameter = {id(v): k for k, v in self.root.named_parameters()} + self.the_path_of_buffer = {id(v): k for k, v in self.root.named_buffers()} + + def get_middle_class(node, memo = set(), prefix = ''): + if node not in memo: + memo.add(node) + yield prefix, node + if isinstance(node, torch.nn.Module): + items = (*((k, v) for k, v in node.__dict__.items() if not k.startswith('_')), *node._modules.items()) + else: + items = ((k, v) for k, v in node.__dict__.items() if not k.startswith('_')) + for name, subfield in items: + if isinstance(subfield, (torch.nn.Module, self.fake_middle_class)): + submodule_prefix = prefix + ('.' if prefix else '') + name + for m in get_middle_class(subfield, memo, submodule_prefix): + yield m + self.the_path_of_middle_class = {id(v): k for k, v in get_middle_class(self.root)} + + @functools.wraps(_orig_module_getattribute) + def module_getattribute_wrapper(mod, attr): + if self.temp_disable_call | self.temp_disable_attr: + try: + return _orig_module_getattribute(mod, attr) + except AttributeError: + return _orig_module_getattr(mod, attr) + with self.do_temp_disable(call=True, attr=True): + try: + attr_val = _orig_module_getattribute(mod, attr) + except AttributeError: + attr_val = _orig_module_getattr(mod, attr) + if callable(attr_val): + if attr_val in self.wrapped_leaf: + return self.wrapped_leaf[attr_val][1] + return attr_val + elif _orig_isinstance(attr_val, (_orig_tuple, _orig_list)): + if self.the_path_of_middle_class[id(mod)] == '': + return self.create_proxy('get_attr', f'{attr}', (), {}) + else: + return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) + elif id(attr_val) in self.the_path_of_parameter: + return self.create_proxy('get_attr', self.the_path_of_parameter[id(attr_val)], (), {}) + elif id(attr_val) in self.the_path_of_buffer: + return self.create_proxy('get_attr', self.the_path_of_buffer[id(attr_val)], (), {}) + return attr_val + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + if self.temp_disable_call: + return _orig_module_call(mod, *args, **kwargs) + else: + # corresponding to call_module + old_qualname = self.current_module_qualified_name + try: + self.current_module_qualified_name = self.path_of_module(mod) + module_qualified_name = self.path_of_module(mod) + if not self.is_leaf_module(mod, module_qualified_name): + _autowrap_check(self, mod.forward.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + _autowrap_check(self, mod.__dict__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + return _orig_module_call(mod, *args, **kwargs) + else: + return self.create_proxy('call_module', module_qualified_name, args, kwargs) + finally: + self.current_module_qualified_name = old_qualname + + class map_wrapper_clz: + @functools.wraps(_orig_map) + def __call__(self, the_func, *iterables: Any): + tracers = _orig_set() + for one_iter in iterables: + if _orig_isinstance(one_iter, ep.Proxy): + tracers.add(one_iter.tracer) + if _orig_len(tracers) > 1: + raise Exception('more than 1 tracer detected. please report the issue') + elif _orig_len(tracers) == 1: + results = _orig_list() + for args in _orig_zip(*iterables): + results.append(the_func(*args)) + return next(iter(tracers)).create_proxy('call_function', _orig_tuple, (results,), {}) + + ## for the multi-level list/tuple + iterables = _orig_list(_orig_list(it) for it in iterables) + for it in iterables: + for arg in it: + if _orig_isinstance(arg, ep.Proxy): + tracers.add(arg.tracer) + if _orig_len(tracers) > 1: + raise Exception('more than 1 tracer detected. please report the issue') + elif _orig_len(tracers) == 1: + results = _orig_list() + for args in _orig_zip(*iterables): + results.append(the_func(*args)) + return next(iter(tracers)).create_proxy('call_function', _orig_tuple, (results,), {}) + ## for the multi-level list/tuple end + + return _orig_map(the_func, *iterables) + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(_orig_map)) + def __hash__(self): + return id(self) + map_wrapper = map_wrapper_clz() + + class range_wrapper_clz: + @functools.wraps(_orig_range) + def __call__(self, *args): + # TODO: better infomation + assert 1 <= _orig_len(args) <= 3 + args = (arg.value if _orig_isinstance(arg, ep.ConcreteProxy) else arg for arg in args) + return _orig_range(*args) + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(_orig_range)) + def __hash__(self): + return id(self) + range_wrapper = range_wrapper_clz() + + class enumerate_wrapper_clz: + @functools.wraps(_orig_enumerate) + def __call__(self, iterable, start=0): + count = start + for elem in iterable: + if _orig_isinstance(elem, ep.ConcreteProxy) and _orig_isinstance(elem.value, (_orig_int, str)): + yield count, elem.value + else: + yield count, elem + count += 1 + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(_orig_enumerate)) + def __hash__(self): + return id(self) + enumerate_wrapper = enumerate_wrapper_clz() + + class type_wrapper_clz: + @functools.wraps(_orig_type) + def __call__(self, instance): + orig_type = _orig_type(instance) + if orig_type in (ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): + return _orig_type(instance.value) + else: + return orig_type + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(_orig_enumerate)) + def __hash__(self): + return id(self) + type_wrapper = type_wrapper_clz() + + @classmethod + @functools.wraps(_orig_agfunc_apply) + def agfunc_apply_wrapper(clz, *args, **kwargs): + if clz not in self.agfunc_dict: + self.agfunc_dict[clz] = torch._C._FunctionBase.__dict__['apply'].__get__(None, clz) + if self.temp_disable_agfunc_apply or self.temp_disable_call: + return self.agfunc_dict[clz](*args, **kwargs) + tracers = _orig_set() + def unwrap_detect_tracers(obj): + if isinstance(obj, ep.ConcreteProxy): + tracers.add(obj.tracer) + ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) + ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) + if _orig_len(tracers) == 0: + return self.agfunc_dict[clz](*args, **kwargs) + elif _orig_len(tracers) == 1 and next(iter(tracers)) == self: + return self.create_proxy('call_function', self.agfunc_dict[clz], args, kwargs) + else: + raise Exception('more than 1 tracer detected. please report the issue') + + @functools.wraps(_orig_torch_assert) + def torch_assert_wrapper(condition, message): + while _orig_isinstance(condition, ep.ConcreteProxy): + condition = condition.value + return _orig_torch_assert(condition, message) + + self.agfunc_dict: dict[Type, Any] = {} + self.autowrap_leaf_pairs = { + id(_orig_torch_assert): torch_assert_wrapper, + } + self.wrapped_leaf = dict() + + for func, (positions, is_force_trace, to_func) in self.autowrap_leaf_function.items(): + if _orig_isinstance(func, BuiltinMethodType) and getattr(func, '__name__', None) == 'apply'\ + and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function): + # torch.autograd.function + assert to_func == None, '.apply should set to_func to None!' + if func.__self__ not in self.agfunc_dict: + self.agfunc_dict[func.__self__] = _create_wrapped_leaf_func(self, func, func) + wrapped = self.agfunc_dict[func.__self__] + else: + if func.__qualname__.startswith('_TensorBase'): + positions = (*positions, (torch.Tensor, func.__name__)) + wrapped = _create_wrapped_leaf_method(self, getattr(torch.Tensor, func.__name__), func.__name__, to_func) + elif func.__qualname__.startswith('_VariableFunctionsClass'): + if hasattr(torch, func.__name__): + # avoid bad attr like 'unique_dim' + positions = (*positions, (torch, func.__name__)) + if is_force_trace: + wrapped = _create_wrapped_leaf_func(self, func, to_func, (self,)) + else: + wrapped = _create_wrapped_leaf_func(self, func, to_func) + elif _orig_isinstance(func, (MethodDescriptorType, MethodWrapperType)): + wrapped = _create_wrapped_leaf_method(self, func, func.__name__, to_func) + elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn': + # method + if func.__module__.startswith('_') and func.__module__ != '__main__': + path = sys.modules[func.__module__[1:]] + else: + path = sys.modules[func.__module__] + path = getattr(path, func.__qualname__.split('.')[0]) + positions = (*positions, (path, func.__name__)) + wrapped = _create_wrapped_leaf_method(self, func, func.__name__, to_func) + else: + # common function + if func.__module__.startswith('_') and func.__module__ != '__main__': + path = sys.modules[func.__module__[1:]] + else: + path = sys.modules[func.__module__] + positions = (*positions, (path, func.__name__)) + if is_force_trace: + wrapped = _create_wrapped_leaf_func(self, func, to_func, (self,)) + else: + wrapped = _create_wrapped_leaf_func(self, func, to_func) + self.wrapped_leaf[func] = (positions, wrapped) + + self.clz_wrapper_map: Dict[Any, Type] = { + map_wrapper: _orig_map, + enumerate_wrapper: _orig_enumerate, + range_wrapper: _orig_range, + type_wrapper: _orig_type, + } + for clz, (positions, is_iterable) in self.autowrap_leaf_class.items(): + if clz.__module__.startswith('_') and clz.__module__ != '__main__': + path = sys.modules[clz.__module__[1:]] + else: + path = sys.modules[clz.__module__] + if is_iterable: + wrapped = _create_wrapped_leaf_iterable_class(self, clz) + else: + wrapped = _create_wrapped_leaf_class(self, clz) + positions = (*positions, (path, clz.__name__)) + self.wrapped_leaf[clz] = (positions, wrapped) + self.clz_wrapper_map[wrapped] = clz + + for clz in self.fake_middle_class: + wrapped = _create_wrapped_attr_for_middle_class(self, clz, self.the_path_of_middle_class) + self.wrapped_leaf[clz.__getattribute__] = (((clz, '__getattribute__'),), wrapped) + + @functools.wraps(_orig_isinstance) + def isinstance_wrapper(instance, clz): + if _orig_type(clz) in (slice, tuple, list, _orig_slice, _orig_tuple, _orig_list): + clz_wrapped = [] + for wrapped_type, orig_type in self.clz_wrapper_map.items(): + if wrapped_type in clz: + clz_wrapped.append(orig_type) + clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in self.clz_wrapper_map)) + # use _orig_isinstance(clz, Iterable) will cause an endless recursive loop + for cls in (object, ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): + if cls in clz and _orig_isinstance(instance, cls): + return True + if _orig_isinstance(instance, ep.ConcreteProxy): + return _orig_isinstance(instance.value, clz) + else: + return _orig_isinstance(instance, clz) + else: + if clz in (object, ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): + return _orig_isinstance(instance, clz) + if clz in self.clz_wrapper_map: + clz = self.clz_wrapper_map[clz] + if _orig_isinstance(instance, ep.ConcreteProxy): + instance = instance.value + return _orig_isinstance(instance, clz) + + @functools.wraps(_orig_getattr) + def getattr_wrapper(obj, *args): + # TODO: better infomation + if not 1 <= _orig_len(args) <= 2: + raise Exception() + args = _orig_list(args) + if _orig_isinstance(args[0], ep.ConcreteProxy): + args[0] = args[0].value + return _orig_getattr(obj, *args) + + # for passing the tracing of leaf modules + self.temp_disable_call = False + self.temp_disable_attr = False + self.temp_disable_agfunc_apply = False + self.temp_disable_call_level = 0 + self.temp_disable_attr_level = 0 + self.temp_disable_agfunc_apply_level = 0 + try: + with _Patcher() as self.patcher: + # allow duplicate patches to support the case of nested calls + self.patcher.patch_method(torch.nn.Module, "__getattribute__", module_getattribute_wrapper, deduplicate=False) + + self.patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False) + self.patcher.patch_method(torch.autograd.Function, "apply", agfunc_apply_wrapper, deduplicate=False) + self.patcher.patch_method(torch, "_assert", torch_assert_wrapper, deduplicate=False) + + self.patcher.patch_method(builtins, "map", map_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "enumerate", enumerate_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "range", range_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "type", type_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "isinstance", isinstance_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "getattr", getattr_wrapper, deduplicate=False) + + for obj, (positions, wrapped) in self.wrapped_leaf.items(): + for path, name in positions: + self.patcher.patch_method(path, name, wrapped, deduplicate=False) + self.autowrap_leaf_pairs[id(obj)] = wrapped + + _patch_wrapped_functions(self.patcher) + _autowrap_check(self, fn_globals, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + for module in self._autowrap_search: + _autowrap_check(self, module.__dict__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + with OperatorPatcherContext(self, use_operator_patch, operator_patch_backlist): + self.create_node('output', 'output', + (self.create_arg(OperatorPatcherContext.patch_run(fn, *args, *more_args, **kwargs)),), + {}, type_expr=fn.__annotations__.get('return', None)) + finally: + # for cuda versions of pytorch, autograd.Function.apply should be reverted manually + delattr(torch.autograd.Function, 'apply') + pass + + self.submodule_paths = None + return self.graph + +# List of pairs of (global dict, function name) functions +# to patch for the purposes of the wrap() API. +_wrapped_fns_to_patch : List[Tuple[dict, str]] = [] + +# List of methods on classes to wrap (class type, function name) +# this currently only works for Tensor.* methods that aren't traced properly +_wrapped_methods_to_patch : List[Tuple[type, str]] = [] + + +def _find_proxy(*objects_to_search): + """ + Recursively search a data structure for a Proxy() and return it, + return None if not found. + """ + proxy = None + + def find_proxy(x): + nonlocal proxy + if isinstance(x, ep.ConcreteProxy): + proxy = x + + ep.map_aggregate_not_proxy(objects_to_search, find_proxy) + return proxy + +def _create_wrapped_func(orig_fn): + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Given an closed-over ``orig_function`` to invoke, search the args and kwargs for + a Proxy object. If there is one, emit a ``call_function`` node to preserve the + call to this leaf function directly. Otherwise, just return the results of + this function call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return_proxy = proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs) + return_proxy.node.meta['is_wrapped'] = True + return return_proxy + return orig_fn(*args, **kwargs) + + return wrapped + +def _patch_wrapped_functions(patcher : _Patcher): + """ + Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap + the listed global functions in the `_create_wrapped_func` wrapper. + """ + for frame_dict, name in _wrapped_fns_to_patch: + if name not in frame_dict and hasattr(builtins, name): + orig_fn = _orig_getattr(builtins, name) + else: + orig_fn = frame_dict[name] + patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) + + for cls, name in _wrapped_methods_to_patch: + patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) + +def _autowrap_check(tracer: ConcreteTracer, frame_dict : Dict[str, Any], function_ids : Set[int],\ + function_pairs : Dict[int, Callable], agfunc_dict: dict[Type, Any]): + """ + Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. + This method searches a scope for them and patches them if found. + """ + patcher = tracer.patcher + if patcher.visit_once(frame_dict): + for name, value in frame_dict.items(): + # if callable(value) and (not name.startswith('_') or name == '_assert'): + if callable(value) and not name.startswith('__') and not name.startswith('_orig_'): + if id(value) in function_ids: + patcher.patch(frame_dict, name, _create_wrapped_func(value)) + elif id(value) in function_pairs: + patcher.patch(frame_dict, name, function_pairs[id(value)]) + elif _orig_isinstance(value, BuiltinMethodType) and getattr(value, '__name__', None) == 'apply'\ + and _orig_isinstance(getattr(value, '__self__', None), Type) and issubclass(value.__self__, torch.autograd.Function): + # torch.autograd.function + if value.__self__ not in agfunc_dict: + agfunc_dict[value.__self__] = _create_wrapped_leaf_func(tracer, value, value) + patcher.patch(frame_dict, name, agfunc_dict[value.__self__]) + +def _create_wrapped_method(cls, name): + orig_fn = _orig_getattr(cls, name) + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Search the args and kwargs for a Proxy object. If there is one, + emit a ``call_method`` node to preserve the call to this method + directly. Otherwise, just return the results of this function + call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return proxy.tracer.create_proxy('call_method', name, args, kwargs) + return orig_fn(*args, **kwargs) + + return wrapped + + +@compatibility(is_backward_compatible=True) +class GraphAppendingConcreteTracer(ConcreteTracer): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + +class MagicMethodPatcher: + from torch.fx import graph as fx_graph + from torch.fx import graph_module as fx_graph_module + from torch.fx import node as fx_node + magic_methods_ori = fx_graph.magic_methods + magic_methods_new = { + **fx_graph.magic_methods, + 'not_': 'not {}', + 'is_': '{} is {}', + 'is_not': '{} is not {}', + 'contains': '{1} in {0}', + } + copy_attr_ori: Any = fx_graph_module._copy_attr + find_module_of_method_ori: Any = fx_node._find_module_of_method + format_import_statement_ori: Any = fx_graph_module._format_import_statement + + @staticmethod + def copy_attr_new(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): + *prefix, field = target.split('.') + for item in prefix: + f = getattr(from_module, item) + t = getattr(to_module, item, None) + if f is t: + return + + if t is None: + if isinstance(f, Sequential): + t = Sequential() + elif isinstance(f, ModuleList): + t = ModuleList() + elif isinstance(f, ModuleDict): + t = ModuleDict() + else: + t = torch.nn.Module() + if hasattr(f, '_get_name'): + t._get_name = f._get_name + to_module.add_module(item, t) + from_module, to_module = f, t + + orig = getattr(from_module, field) + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + to_module.register_buffer(field, orig) + else: + setattr(to_module, field, orig) + + @staticmethod + def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: + name = orig_method.__name__ + module = orig_method.__module__ + if module is not None: + return module + elif hasattr(orig_method, '__qualname__')\ + and isinstance(orig_method.__qualname__, str) and orig_method.__qualname__.startswith('_VariableFunctionsClass.'): + return 'torch._C._VariableFunctions' + elif hasattr(orig_method, '__self__')\ + and isinstance(orig_method.__self__, Type) and issubclass(orig_method.__self__, torch.autograd.Function): + # for torch.autograd.Function + return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' + for guess in [torch, getattr(torch.nn, 'functional')]: + if getattr(guess, name, None) is orig_method: + return guess.__name__ + raise RuntimeError(f'cannot find module for {orig_method}') + + @staticmethod + def format_import_statement_new(name: str, obj: Any, importer) -> str: + if isinstance(obj, BuiltinMethodType) and getattr(obj, '__name__', None) == 'apply'\ + and isinstance(getattr(obj, '__self__', None), Type) and issubclass(obj.__self__, torch.autograd.Function): + # torch.autograd.function + return MagicMethodPatcher.format_import_statement_ori(name, obj.__self__, importer) + f'\n{name} = {name}.apply' + return MagicMethodPatcher.format_import_statement_ori(name, obj, importer) + + def __enter__(self): + MagicMethodPatcher.fx_graph.magic_methods = self.magic_methods_new + MagicMethodPatcher.fx_graph_module._copy_attr = self.copy_attr_new + MagicMethodPatcher.fx_node._find_module_of_method = self.find_module_of_method_new + MagicMethodPatcher.fx_graph_module._format_import_statement = self.format_import_statement_new + MagicMethodPatcher.available = True + + def __exit__(self, exc_type, exc_value, tb): + MagicMethodPatcher.fx_graph.magic_methods = MagicMethodPatcher.magic_methods_ori + MagicMethodPatcher.fx_graph_module._copy_attr = MagicMethodPatcher.copy_attr_ori + MagicMethodPatcher.fx_node._find_module_of_method = MagicMethodPatcher.find_module_of_method_ori + MagicMethodPatcher.fx_graph_module._format_import_statement = MagicMethodPatcher.format_import_statement_ori + MagicMethodPatcher.available = False + return exc_type is None + +def _create_wrapped_leaf_func(tracer: ConcreteTracer, func: Callable, to_func: Optional[Callable], init_tracers = ()): + # to_func: to call correct replacement instead of the original (the original func may be wrong). + # such as: call torch.nn.norm instead of torch._C._VariableFunctions.norm. + # torch.nn.norm will help to pack dim to list if dim is an int. + if to_func is None: + to_func = func + @functools.wraps(func) + def func_wrapper(*args, **kwargs): + if tracer.temp_disable_call: + return func(*args, **kwargs) + tracers = _orig_set(init_tracers) + def unwrap_detect_tracers(obj): + if isinstance(obj, ep.ConcreteProxy): + tracers.add(obj.tracer) + ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) + ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) + if _orig_len(tracers) == 0: + return to_func(*args, **kwargs) + elif _orig_len(tracers) == 1 and next(iter(tracers)) == tracer: + return tracer.create_proxy('call_function', to_func, args, kwargs) + else: + raise Exception('more than 1 tracer detected. please report the issue') + return func_wrapper + +def _create_wrapped_leaf_method(tracer: ConcreteTracer, method, name: str, to_func: Optional[Callable]): + @functools.wraps(method) + def method_wrapper(*args, **kwargs): + if tracer.temp_disable_call: + return method(*args, **kwargs) + tracers = _orig_set() + def unwrap_detect_tracers(obj): + if isinstance(obj, ep.ConcreteProxy): + tracers.add(obj.tracer) + ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) + ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) + if _orig_len(tracers) == 0: + return method(*args, **kwargs) + elif _orig_len(tracers) == 1 and next(iter(tracers)) == tracer: + if to_func is not None: + return tracer.create_proxy('call_function', to_func, args, kwargs) + else: + return tracer.create_proxy('call_method', name, args, kwargs) + else: + raise Exception('more than 1 tracer detected. please report the issue') + return method_wrapper + +def _create_wrapped_leaf_class(tracer: ConcreteTracer, clz): + class clz_wrapper_clz: + @functools.wraps(clz) + def __call__(self, *args, **kwargs): + if tracer.temp_disable_call: + return clz(*args, **kwargs) + tracers = _orig_set() + def unwrap_detect_tracers(obj): + if isinstance(obj, ep.ConcreteProxy): + tracers.add(obj.tracer) + ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) + ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) + if _orig_len(tracers) == 0: + return clz(*args, **kwargs) + elif _orig_len(tracers) == 1 and next(iter(tracers)) == tracer: + return tracer.create_proxy('call_function', clz, args, kwargs) + else: + raise Exception('more than 1 tracer detected. please report the issue') + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(clz)) + def __hash__(self): + return id(self) + return clz_wrapper_clz() + +def _create_wrapped_leaf_iterable_class(tracer: ConcreteTracer, clz): + class clz_wrapper_clz: + @functools.wraps(clz) + def __call__(self, *args, **kwargs): + if tracer.temp_disable_call: + return clz(*args, **kwargs) + tracers = _orig_set() + if _orig_len(args) != 0: + if _orig_isinstance(args[0], ep.Proxy): + tracers.add(args[0].tracer) + if _orig_isinstance(args[0], Iterator): + args = (clz(args[0]), *args[1:]) + if _orig_isinstance(args[0], Iterable): + for item in args[0]: + if _orig_isinstance(item, ep.Proxy): + tracers.add(item.tracer) + if _orig_len(tracers) == 0: + return clz(*args, **kwargs) + elif _orig_len(tracers) == 1 and next(iter(tracers)) == tracer: + return tracer.create_proxy('call_function', + clz, args, kwargs) + else: + raise Exception('more than 1 tracer detected. please report the issue') + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(clz)) + def __hash__(self): + return id(self) + clz_wrapper = clz_wrapper_clz() + for name in dir(clz): + attr = _orig_getattr(clz, name) + if not name.startswith('_') or name in ('__getitem__', '__setitem__', '__iter__', '__len__'): + if _orig_isinstance(attr, Callable): + setattr(clz_wrapper, name, _create_wrapped_leaf_method(tracer, attr, name, None)) + else: + setattr(clz_wrapper, name, attr) + return clz_wrapper + +def _create_wrapped_attr_for_middle_class(tracer: ConcreteTracer, clz, the_path_of_middle_class): + _orig_clz_getattribute = clz.__getattribute__ + if hasattr(clz, '__getattr__'): + _orig_clz_getattr = clz.__getattr__ + else: + _orig_clz_getattr = None + @functools.wraps(_orig_clz_getattribute) + def clz_getattr_wrapper(obj, attr): + if tracer.temp_disable_call | tracer.temp_disable_attr: + if _orig_clz_getattr == None: + return _orig_clz_getattribute(obj, attr) + else: + try: + return _orig_clz_getattribute(obj, attr) + except AttributeError: + return _orig_clz_getattr(obj, attr) + else: + return tracer.create_proxy('get_attr', f'{the_path_of_middle_class[id(obj)]}.{attr}', (), {}) + return clz_getattr_wrapper + +def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Union[Dict[str, Any], Tuple], + *, + use_operator_patch: bool = False, + operator_patch_backlist: List[str] = [], + forwrad_function_name: str = 'forward', + check_args: Optional[Dict[str, Any]] = None, + autowrap_leaf_function = {}, + autowrap_leaf_class = {}, + leaf_module = (), + fake_middle_class = ()) -> GraphModule: + """ + Concrete tracing API + + Given an ``nn.Module`` or function instance ``root`` and a dummy input `concrete_args`, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + It has solved many problems compared to fx.symbolic_trace, and can execute on many third-party models. + + For example:: + + def f(a, b): + return a + b + + traced_f = concrete_trace(f, concrete_args={'a': 1, 'b': 2}) + # or `traced_f = concrete_trace(f, (1, 2))` + assert traced_f(3, 4) == 7 + + def f(x): + out1, out2 = 0, 0 + for k, v in x.items(): + out1 += k + out2 += v + return out1, out2 + traced_f = concrete_trace(f, ({1: 1, 2: 2}, )) + assert traced_f({2: 3, 4: 5}) == (6, 8) + + Note that we can only record static structure, so all the branches such as if-else or loop will be flattened:: + + def f(x): + out1, out2 = 0, 0 + for k, v in x.items(): + out1 += k + out2 += v + return out1, out2 + traced_f = concrete_trace(f, ({1: 1, 2: 2}, )) + assert traced_f({2: 3, 4: 5, 6:7}) == (6, 8) # not (12, 15) + + # traced code like: + def traced_f(self, x): + out1, out2 = 0, 0 + items = x.items() + + # for loop + iter = iter(items) + + # first loop content + items0 = next(iter) + out1 += items0[0] + out2 += items0[1] + + # second loop content + items1 = next(iter) + out1 += items1[0] + out2 += items1[1] + + return (out1, out2) + + If you want to trace 'is', 'is not', 'in' or 'not in' in your module, you can set use_function_patch to True:: + + def f(x, y): + if x is None: + return y + else: + return x - y + # traced_f = concrete_trace(f, (None, 1)) # bad + traced_f = concrete_trace(f, (None, 1), use_function_patch=True) # f should exist in a file. + + If you have a function/method that should be treated as a leaf function but not trace into it, use autowrap_leaf_function to mark it:: + + def leaf_op(x, y, z): + # if not treated as a leaf function, then only 1 branch will exist. + if x > 0: + return y + z + else: + return y - z + + def f(x): + return leaf_op(x, 3, 2) + + traced_f = concrete_trace(f, (1, ), autowrap_leaf_function = { + leaf_op: ([], False, None), **ConcreteTracer.default_autowrap_leaf_function}) + assert traced_f(1) == 5 and traced_f(-1) == 1 + + If you have a class that should be treated as a leaf class, use autowrap_leaf_class to mark it:: + + class leaf_clz: + def __init__(self, a, b): + self.c = a + b + + def f(x, y): + return leaf_clz(x, y) + + traced_f = concrete_trace(f, (1, 2), autowrap_leaf_class = { + leaf_clz: ([], False), **ConcreteTracer.default_autowrap_leaf_class}) + assert isinstance(traced_f(3, 4), leaf_clz) and traced_f(3, 4).c == 7 + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted into a Graph representation. + concrete_args (Union[Dict[str, Any], Tuple]): Dummy inputs to do concrete trace. + + use_function_patch (bool): Use operator patcher recursively on function calls. Operator patcher will re-compile the function and + translate '{} is {}' into 'operator.is_({}, {})', then we can treat 'is', 'is not', 'in' and 'not in' as function calls. + + operator_patch_backlist (List[str]): Blacklist of the operator patcher. + + autowrap_leaf_function (Dict[Any, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool, Optional[Callable]]]): Leaf function dict, + such as 'add' or 'torch.xxx'. You can add your own leaf functions. + + The struct of dict is: leaf_function: ([(module_path, module_name)], force_to_trace, replace_to_function). + (module_path, module_name): The place the function exists. Such as torch.meshgrid, there are `torch.meshgrid`, + 'torch.functional.meshgrid', 'torch._C._VariableFunctions.meshgrid', we should wrap them all. + force_to_trace: If set to false, the function will only be traced if input relates to concrete_args. + Such as 'torch.rand', we should trace it even if it doesn't relate to concrete_args. + replace_to_function: If not `None`, we will use it to replace the original function in traced code. + Such as ModuleList.__getitem__, we can use operator.getitem to replace it. + + default_autowrap_leaf_class (Dict[Type, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool]]): Leaf class dict, such as 'int', + 'range' or 'zip'. You can add your own leaf functions such as 'torch.finfo' or 'modeling_outputs.SequenceClassifierOutput'. + + The struct of dict is: leaf_class: ([(module_path, module_name)], is_iterator_class). + is_iterator_class: Is the class init from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True. + + Returns: + fx.GraphModule: a Module created from the recorded operations from ``root``. + """ + tracer = ConcreteTracer() + graph = tracer.trace(root, + autowrap_leaf_function = autowrap_leaf_function, + autowrap_leaf_class = autowrap_leaf_class, + leaf_module = leaf_module, + fake_middle_class = fake_middle_class, + concrete_args=concrete_args, + use_operator_patch=use_operator_patch, + operator_patch_backlist=operator_patch_backlist, + forwrad_function_name=forwrad_function_name, + ) + graph_check = tracer.trace(root, + autowrap_leaf_function = autowrap_leaf_function, + autowrap_leaf_class = autowrap_leaf_class, + leaf_module = leaf_module, + fake_middle_class = fake_middle_class, + concrete_args=concrete_args, + use_operator_patch=use_operator_patch, + operator_patch_backlist=operator_patch_backlist, + forwrad_function_name=forwrad_function_name, + ) + # compare to check equal + assert len(graph.nodes) == len(graph_check.nodes) + for node_a, node_b in zip(graph.nodes, graph_check.nodes): + node_a: Node + node_b: Node + target_a = node_a.target + target_b = node_b.target + if node_a.op == 'get_attr' and node_a.name.startswith('_tensor_constant'): + assert node_b.op == 'get_attr' and node_b.name.startswith('_tensor_constant') + assert torch.equal(getattr(root, node_a.name), getattr(root, node_b.name)) + elif node_a.op == 'call_function' and isinstance(target_a, Callable) and target_a.__name__ == 'apply' and\ + hasattr(target_a, '__self__') and issubclass(target_a.__self__, torch.autograd.Function): + assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\ + hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function) + else: + assert node_a.op == node_b.op and target_a == target_b + + with MagicMethodPatcher(): + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + traced = GraphModule(tracer.root, graph, name) + + # TODO: better infomation + # # assert root(**concrete_args) == traced(**concrete_args) + if check_args is not None: + assert root(**check_args) == traced(**check_args) + return traced, tracer \ No newline at end of file diff --git a/cube/graph/parser/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/concrete_trace_utils/operator_patcher.py new file mode 100644 index 00000000..1f6eb3ac --- /dev/null +++ b/cube/graph/parser/concrete_trace_utils/operator_patcher.py @@ -0,0 +1,270 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .concrete_tracer import ConcreteTracer + +import ast +import inspect +import logging + +from textwrap import dedent +from types import MethodType, FunctionType +from typing import List, Optional, Callable, Dict + +import torch + +from .utils import ( + _orig_type, + _orig_isinstance, + _orig_len, + _orig_dict, + _orig_zip, +) + +_logger = logging.getLogger(__name__) + +class TransformerOp(ast.NodeTransformer): + """ + An ast transformer, to check and replace the python ops 'not/is/is not/in/not in' to functions in 'operator' module. + """ + + def visit_start(self, node): + # to mark if the ast is changed + self.is_transformed = False + + # detect the expr now is in a branch test expr + # 0: not in a branch test expr. + # 1: in propagate if not in func 'visit', or not in a branch test expr in func 'visit' + # 2: in a branch test expr + self.is_incond_status = 0 + ret = super().visit(node) + return self.is_transformed, ret + + def visit(self, node): + if self.is_incond_status != 0: + # if the status is 'in branch test', + self.is_incond_status -= 1 + return super().visit(node) + + def visit_Call(self, node: ast.Call): + if not isinstance(node.func, ast.Name) or node.func.id != 'patch_run': + self.is_transformed = True + return self.generic_visit(ast.Call( + func=ast.Name(id='patch_run', ctx=ast.Load()), + args=[node.func, *node.args], + keywords=node.keywords, + )) + else: + return self.generic_visit(node) + + def visit_While(self, node: ast.While): + self.is_incond_status = 2 + node.test = self.visit(node.test) + self.is_incond_status = 0 + node.body = [self.visit(item) for item in node.body] + node.orelse = [self.visit(item) for item in node.orelse] + return node + + def visit_If(self, node: ast.If): + self.is_incond_status = 2 + node.test = self.visit(node.test) + self.is_incond_status = 0 + node.body = [self.visit(item) for item in node.body] + node.orelse = [self.visit(item) for item in node.orelse] + return node + + def visit_IfExp(self, node: ast.IfExp): + node.body = self.visit(node.body) + self.visit(node.body) + self.is_incond_status = 2 + node.test = self.visit(node.test) + self.is_incond_status = 0 + node.orelse = self.visit(node.orelse) + return node + + def visit_UnaryOp(self, node: ast.UnaryOp): + if self.is_incond_status != 0: + # in branch cond test expr, need no replacement + self.is_incond_status = 2 + return self.generic_visit(node) + elif _orig_isinstance(node.op, ast.Not): + self.is_transformed = True + return self.generic_visit(ast.Call( + func=ast.Name(id='not_', ctx=ast.Load()), + args=[node.operand], + keywords=[], + )) + else: + return self.generic_visit(node) + + def visit_BoolOp(self, node: ast.BoolOp): + if self.is_incond_status != 0: + # in branch cond test expr, need no replacement + self.is_incond_status = 2 + return self.generic_visit(node) + else: + if not _orig_isinstance(node.values[1], (ast.Call, ast.BoolOp)): + _logger.warning('warning: "and/or" will generate branch expr. The 2nd arg can\'t be traced if the 1st arg returns a True.' + ' Don\'t mix up "and/or" and "&/|"!') + return self.generic_visit(node) + + def visit_Compare(self, node: ast.Compare): + should_replace = False + for op in node.ops: + if _orig_type(op) in (ast.Is, ast.IsNot, ast.In, ast.NotIn): + should_replace = True + break + if should_replace: + if _orig_len(node.ops) != 1: + raise RuntimeError( + 'not supported in "{} cmp_op {} cmp_op {}" when cmp_op contains "is/is not/in/not in"') + self.is_transformed = True + func_id = { + ast.Is: 'is_', + ast.IsNot: 'is_not', + ast.In: 'contains', + ast.NotIn: 'contains', + }[_orig_type(node.ops[0])] + if _orig_isinstance(node.ops[0], (ast.In, ast.NotIn)): + args = [node.comparators[0], node.left] + else: + args = [node.left, node.comparators[0]] + ret_node = ast.Call( + func=ast.Name(id=func_id, ctx=ast.Load()), + args=args, + keywords=[], + ) + if _orig_isinstance(node.ops[0], ast.NotIn): + ret_node = ast.Call( + func=ast.Name(id='not_', ctx=ast.Load()), + args=[ret_node], + keywords=[], + ) + return self.generic_visit(ret_node) + else: + return self.generic_visit(node) + + +class OperatorPatcher: + """ + An function patcher, to patch the un-wrappable operator 'not/is/is not/in/not in' to wrappable functions. + """ + + transformer_op = TransformerOp() + + def __init__(self, use_operator_patch: bool, operator_patch_backlist: List[str]): + self.use_operator_patch = use_operator_patch + self.operator_patch_backlist = operator_patch_backlist + self.function_cache: Dict[int, Callable] = {} + self.function_cache_orig: Dict[int, Callable] = {} + + def patch_inner(self, func): + if id(func) not in self.function_cache: + self.function_cache[id(func)] = self.patch_inner_helper(func) + self.function_cache_orig[id(func)] = func + return self.function_cache[id(func)] + + def patch_inner_helper(self, func): + if not hasattr(func, '__module__') or func.__module__ is None or func.__module__.startswith('torch'): + return func + if hasattr(func, '_Patcher__fx_already_patched'): + return func + if self.use_operator_patch == (func in self.operator_patch_backlist): + return func + if _orig_isinstance(func, torch.nn.Module): + func = func.forward + if _orig_isinstance(func, MethodType): + func_inner = func.__func__ + the_self = func.__self__ + else: + func_inner = func + the_self = None + if not _orig_isinstance(func_inner, FunctionType) or not hasattr(func_inner, '__code__'): + return func + + lines, lnum = inspect.findsource(func_inner) + # align with original source code + source = ''.join(('\n' * lnum, *inspect.getblock(lines[lnum:]))) + dedent_src = dedent(source) + tree = ast.parse(dedent_src) + + is_transformed, new_tree = OperatorPatcher.transformer_op.visit_start(tree) + if not is_transformed: + return func + else: + body0: ast.FunctionDef = new_tree.body[0] + body0.body = [ + # equals to: + # from operator import not_, is_, is_not, contains + ast.ImportFrom( + module='operator', + names=[ + ast.alias(name='not_'), + ast.alias(name='is_'), + ast.alias(name='is_not'), + ast.alias(name='contains'), + ], + level=0 + ), + *body0.body + ] + body0.name = 'new_func' + # for deleting some annotations like 'add_start_docstrings_to_model_forward' or 'add_code_sample_docstrings' + body0.decorator_list = [i for i in body0.decorator_list + if isinstance(i, ast.Call) and isinstance(i.func, ast.Name) and i.func.id == 'patch_run' and + isinstance(i.args[0], ast.Name) and + i.args[0].id not in ('add_start_docstrings_to_model_forward', 'add_code_sample_docstrings')] + ast.fix_missing_locations(new_tree) + + # closure info + closure_dict = {} + closures = func_inner.__closure__ + co_freevars = func_inner.__code__.co_freevars + if (closures != None and _orig_len(closures) != 0) or _orig_len(co_freevars) != 0: + assert _orig_len(closures) == _orig_len(co_freevars) + closure_dict = _orig_dict(_orig_zip(co_freevars, [c.cell_contents for c in closures])) + + var_dict = {} + exec( + # use func.__code__.co_filename to make the new function easily debuggable. + compile(new_tree, func_inner.__code__.co_filename, 'exec'), + { + 'patch_run': OperatorPatcherContext.patch_run, + **func_inner.__globals__, + **closure_dict, + }, + var_dict) + if the_self is not None: + return var_dict['new_func'].__get__(the_self) + else: + return var_dict['new_func'] + +class OperatorPatcherContext: + ctx_tracer: Optional['ConcreteTracer'] = None + ctx_patcher: Optional[OperatorPatcher] = None + + def __init__(self, tracer: 'ConcreteTracer', use_operator_patch: bool, operator_patch_backlist: List[str]): + self.tracer = tracer + self.patcher = OperatorPatcher(use_operator_patch, operator_patch_backlist) + + def __enter__(self): + assert OperatorPatcherContext.ctx_tracer is None + assert OperatorPatcherContext.ctx_patcher is None + OperatorPatcherContext.ctx_tracer = self.tracer + OperatorPatcherContext.ctx_patcher = self.patcher + + def __exit__(self, exc_type, exc_value, tb): + assert OperatorPatcherContext.ctx_tracer == self.tracer + OperatorPatcherContext.ctx_tracer = None + OperatorPatcherContext.ctx_patcher = None + return exc_type is None + + @staticmethod + def patch_run(func, *args, **kwargs): + assert OperatorPatcherContext.ctx_tracer is not None + assert OperatorPatcherContext.ctx_patcher is not None + with OperatorPatcherContext.ctx_tracer.do_temp_disable(True, True, True): + new_func = OperatorPatcherContext.ctx_patcher.patch_inner(func) + return new_func(*args, **kwargs) \ No newline at end of file diff --git a/cube/graph/parser/concrete_trace_utils/utils.py b/cube/graph/parser/concrete_trace_utils/utils.py new file mode 100644 index 00000000..3cfc95dc --- /dev/null +++ b/cube/graph/parser/concrete_trace_utils/utils.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import builtins +import operator +from typing import Any, Callable, Type +import functools + +import torch + +# These need to run in global scope to handle nested calls correctly +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ +_orig_module_getattribute: Callable = torch.nn.Module.__getattribute__ + +_orig_agfunc_apply: Callable = torch.autograd.function.Function.apply +_orig_torch_assert: Callable = torch._assert + +_orig_type: Callable = builtins.type +_orig_isinstance: Callable = builtins.isinstance +_orig_getattr: Callable = builtins.getattr + +_orig_range: Type[Any] = builtins.range +_orig_int: Type[Any] = builtins.int +_orig_bool: Type[Any] = builtins.bool +_orig_tuple: Type[Any] = builtins.tuple +_orig_list: Type[Any] = builtins.list +_orig_set: Type[Any] = builtins.set +_orig_frozenset: Type[Any] = builtins.frozenset +_orig_dict: Type[Any] = builtins.dict +_orig_map: Type[Any] = builtins.map +_orig_zip: Type[Any] = builtins.zip +_orig_enumerate: Type[Any] = builtins.enumerate +_orig_slice: Type[Any] = builtins.slice + +_orig_len: Callable = builtins.len +_orig_not: Callable = operator.not_ +_orig_is: Callable = operator.is_ +_orig_is_not: Callable = operator.is_not +_orig_contains: Callable = operator.contains +_orig_index: Callable = operator.index + +def run_onlyif_instance(cond_type: Type[Any], return_orig: bool = True, return_const: Any = None): + def helper(fn): + if return_orig: + @functools.wraps(fn) + def wrapper_orig(*args): + if _orig_isinstance(args[-1], cond_type): + return fn(*args) + return args[-1] + return wrapper_orig + else: + @functools.wraps(fn) + def wrapper_const(*args): + if _orig_isinstance(args[-1], cond_type): + return fn(*args) + return return_const + return wrapper_const + return helper + +def map_recursive(fn: Callable, arg) -> Any: + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ + if _orig_type(arg) != torch.Size and _orig_isinstance(arg, _orig_tuple): + t = _orig_tuple(map_recursive(fn, elem) for elem in arg) + # Support NamedTuple (if it has `_fields`) by repacking into original type. + return t if not hasattr(arg, '_fields') else _orig_type(arg)(*t) + elif _orig_isinstance(arg, _orig_list): + return _orig_list(map_recursive(fn, elem) for elem in arg) + elif _orig_isinstance(arg, _orig_dict): + return {k: map_recursive(fn, v) for k, v in arg.items()} + else: + return fn(arg) + +def map_recursive_zip(fn: Callable, arg0, *args) -> Any: + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ + if _orig_type(arg0) != torch.Size and _orig_isinstance(arg0, _orig_tuple): + for arg in args: + assert (not _orig_isinstance(arg, torch.Size)) and _orig_isinstance(arg, _orig_tuple) + assert len(arg0) == len(arg) + return _orig_tuple(map_recursive_zip(fn, *sub_args) for sub_args in _orig_zip(arg0, *args)) + elif _orig_isinstance(arg0, _orig_list): + for arg in args: + assert _orig_isinstance(arg, _orig_list) + assert len(arg0) == len(arg) + return _orig_list(map_recursive_zip(fn, *sub_args) for sub_args in _orig_zip(arg0, *args)) + elif _orig_isinstance(arg0, _orig_dict): + keys = _orig_set(arg0.keys()) + for arg in args: + assert _orig_isinstance(arg, _orig_dict) and len(keys.symmetric_difference(arg.keys())) == 0 + return {k: map_recursive_zip(fn, arg0[k], *(arg[k] for arg in args)) for k in keys} + else: + # assert not _orig_isinstance(arg0, slice) + return fn(arg0, *args) \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 5f3d8210..092f91ce 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -3,6 +3,7 @@ from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser from cube.graph.parser import FxModuleParser, FxFuncOpTracer +from cube.graph.parser.concrete_trace_utils import concrete_trace from cube.graph import IRGraph from cube.flags import CompileFlag @@ -11,18 +12,36 @@ def convert_model(model: torch.nn.Module, input_shapes: Optional[ List[List[int],] ] = None, + dummy_input = None, save_content: bool = True) -> IRGraph: """ Convert torch.nn.Module based model into IRGraph """ try: if CompileFlag.use_torchfx: - # from torch.fx import symbolic_trace - # # Symbolic tracing frontend - captures the semantics of the module - tracer = FxFuncOpTracer() - traced_graph: torch.fx.Graph = tracer.trace(model) - smodule: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) - smodule.graph.print_tabular() + if not dummy_input: + from torch.fx import symbolic_trace + # Symbolic tracing frontend - captures the semantics of the module + tracer = FxFuncOpTracer() + traced_graph: torch.fx.Graph = tracer.trace(model) + smodule: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) + smodule.graph.print_tabular() + else: + print(f'input_shapes = {input_shapes}, {type(input_shapes)}') + print(f'dummy_input = {dummy_input}, {type(dummy_input)}') + with torch.no_grad(): + output_origin = model(**dummy_input) + traced_model, _ = concrete_trace( + model, + dummy_input, + use_operator_patch=True, + autowrap_leaf_class={ + torch.finfo: ((), False), + type(output_origin): ((), False), + }, + ) + print(f'type(traced_model = {type(traced_model)}') + traced_model.graph.print_tabular() else: smodule = torch.jit.script(model) @@ -31,9 +50,14 @@ def convert_model(model: torch.nn.Module, raise RuntimeError("Cannot convert module into torchscript/torch.fx module.") if CompileFlag.use_torchfx: - FxModuleParser.save_content = save_content - inputs, nodes, outputs = FxModuleParser.parse(smodule, input_shapes) - module_name = model.__class__.__name__ + if not dummy_input: + FxModuleParser.save_content = save_content + inputs, nodes, outputs = FxModuleParser.parse(smodule, input_shapes) + module_name = model.__class__.__name__ + else: + FxModuleParser.save_content = save_content + inputs, nodes, outputs = FxModuleParser.parse(traced_model, input_shapes=input_shapes, dummy_input=dummy_input) + module_name = model.__class__.__name__ else: ScriptModuleParser.save_content = save_content inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index be4b81f8..4297efa7 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -63,6 +63,7 @@ def shape_refine(shape: torch.Size) -> torch.Size: @staticmethod def parse(module: torch.fx.GraphModule, input_shapes: Optional[Tuple[List[int],]] = None, + dummy_input = None, frame: Frame = None) \ -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: """ @@ -75,21 +76,41 @@ def parse(module: torch.fx.GraphModule, inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] print(f'inputs = {inputs}') if input_shapes is not None and len(input_shapes) != len(inputs): - raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + print(f'module(type = {type(module)}.__dict__.keys() = {module.__dict__.keys()}') + print(f'input shape mismatch (got {len(input_shapes)} != {len(inputs)})') + # TODO fixme raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") ## shape propagation default_dtype = torch.get_default_dtype() kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - sample_inputs = [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] + sample_inputs = dummy_input if dummy_input else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] + sample_input_tensors = [sample_inputs[input] for input in sample_inputs] if type(sample_inputs) is dict else sample_inputs + from torch.fx.passes.shape_prop import ShapeProp - ShapeProp(module).propagate(*sample_inputs) + ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) + + # for node in module.graph.nodes: + # print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) for node in module.graph.nodes: - print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) + print(f'node.name = {node.name}') + if hasattr(node, 'meta') and node.meta.get('tensor_meta') is not None: + if node.name == 'output': + print('pause here') + if not hasattr(node.meta['tensor_meta'], 'dtype'): + for per_output_meta in node.meta['tensor_meta']: + if isinstance(per_output_meta, torch.fx.passes.shape_prop.TensorMetadata): + print(node.name, '-sub-output', per_output_meta.dtype, per_output_meta.shape) + else: + print(f'ERROR: skip {node.name}\'s non TensorMetadata sub-output') + else: + print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) + else: + print(f'ERROR: none tensor_meta of Node {node.name}') # handle graph input -- Assuming all the inputs are tensors for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) - shape = None if input_shapes is None else input_shapes[idx] + shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] dtype = kDefaultType val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) @@ -99,12 +120,16 @@ def parse(module: torch.fx.GraphModule, activation_op_strs = {'call_function', 'output', 'call_method'} activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] for node in activation_nodes: - assert isinstance(node, torch.fx.Node) - shape = node.meta['tensor_meta'].shape - shape = FxModuleParser.shape_refine(shape) - dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) - frame.add_var(node.name, val) + if hasattr(node, 'meta') and node.meta.get('tensor_meta') and hasattr(node.meta['tensor_meta'], 'dtype'): + shape = node.meta['tensor_meta'].shape + shape = FxModuleParser.shape_refine(shape) + dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) + val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) + frame.add_var(node.name, val) + else: + print(f'WARNING: creation of no-shaped activation for {node.name}') + val = IRFullTensor(shape=[1], requires_grad=True, dtype=ir.int32, name=node.name) # TODO fixme + frame.add_var(node.name, val) # handle nodes all_ir_nodes: List[IRFwOperation] = list() diff --git a/cube/program.py b/cube/program.py index b03b0dd0..ef8a49dd 100644 --- a/cube/program.py +++ b/cube/program.py @@ -120,7 +120,7 @@ def __next__(self): class SemanticModel: - def __init__(self, model: Optional[torch.nn.Module], input_shapes=None): + def __init__(self, model: Optional[torch.nn.Module], input_shapes=None, dummy_input=None): """ Create semantic model based on AI Scientist description. @@ -131,6 +131,7 @@ def __init__(self, model: Optional[torch.nn.Module], input_shapes=None): assert isinstance(model, torch.nn.Module), f"device of local_rank == 0 must provide model" self.model = model self.input_shapes = None + self.dummy_input = dummy_input self.ir_graph = None self._loaded_module: CubeModule = None self._save_content = True @@ -176,7 +177,7 @@ def __call__(self, *args): if DeviceGroup().local_rank == 0: if self.ir_graph is None: self.ir_graph = parser.convert_model( - self.model, input_shapes=input_shapes, save_content=self.save_content + self.model, input_shapes=input_shapes, dummy_input=self.dummy_input, save_content=self.save_content ) self.input_shapes = input_shapes else: diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 47e4baee..2a203b76 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -123,7 +123,7 @@ class SynDataLoader(CubeDataLoader): for given shapes, dtypes. """ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, - batch_dims: Tuple[int] = None): + batch_dims: Tuple[int] = None, names: Tuple[str] = None, append_args=None, device=torch.cuda.current_device()): """ shapes Tuple[Tuple[int]]: The shape for each data @@ -138,6 +138,11 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, dtypes = tuple([torch.float] * len(shapes)) super().__init__(shapes, dtypes, batch_dims) + self.names = names + print(f'### SynDataLoader.names = {names}') + self.append_args=append_args + self.device = device + self.buffer: Union[torch.Tensor, Tuple[torch.Tensor]] = None datas = self.random_sample() self.set_output(datas) @@ -145,7 +150,12 @@ def __iter__(self): return self def __next__(self): - return self.buffer + if self.names: + assert len(self.names) == len(self.buffer) + print(f'### named_syn_data') + return dict(zip(self.names, self.buffer)).update(self.append_args) + else: + return self.buffer def random_sample(self) -> Tuple[torch.Tensor]: torch.manual_seed(0) @@ -154,8 +164,14 @@ def random_sample(self) -> Tuple[torch.Tensor]: datas.append( torch.rand( shape, - device=torch.cuda.current_device(), + device=self.device, requires_grad=False).to(dtype) + if torch.is_floating_point(torch.zeros([1], dtype=dtype)) else + torch.ones( + shape, + device=self.device, + requires_grad=False + ).to(dtype) ) return tuple(datas) diff --git a/examples/nlp/torchscale/fx_test.py b/examples/nlp/torchscale/fx_test.py new file mode 100644 index 00000000..125aa260 --- /dev/null +++ b/examples/nlp/torchscale/fx_test.py @@ -0,0 +1,203 @@ +# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/fx_test.py examples/nlp/torchscale/input --arch mt_base --share-decoder-input-output-embed --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --dropout 0.3 --weight-decay 0.0001 --max-tokens 4096 --fp16 --policy PASData + +import torch +import pickle +from fairseq import ( + tasks, + options, + checkpoint_utils +) +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.trainer import Trainer +from fairseq.data import iterators + +import sys + +import os +print(f'os.getcwd() = {os.getcwd()}') + + +# https://github.com/microsoft/torchscale/tree/main/examples/fairseq +# sys.path.append('/home/v-junliang/torchscaletest/torchscale/examples/fairseq') +# sys.path.append('./torchscaletest/torchscale/examples/fairseq') +sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') +sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale') +print(f'sys.path = {sys.path}') +import models + +#:torchscaletest/torchscale +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +sys.path.append('.') +from policy import mpmd, spmd +# import examples.nlp.torchscale.policy.spmd as spmd + +# import argparse + +# parser = argparse.ArgumentParser(description='comm primitive') +# parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +# parser.add_argument('--local_rank', type=int, default=0) +# args = parser.parse_args() + +# build model +parser = options.get_training_parser() +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +# parser.add_argument('--local_rank', type=int, default=0) + +args = options.parse_args_and_arch(parser) + +cube.init() +# set up policy +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + +cfg = convert_namespace_to_omegaconf(args) +task = tasks.setup_task(cfg.task) +model = task.build_model(cfg.model) +model.eval() +print("building model succeed: ", type(model)) + +# create dummy input +with open('examples/nlp/torchscale/input_tl', 'rb') as f: + dummy_input = pickle.load(f) +device = next(model.parameters()).device +print(f'device = {device}') +for key in dummy_input.keys(): + dummy_input[key] = dummy_input[key].to(device) +print("creating dummy input succeed") +print(f'dummy_input = {dummy_input}, {type(dummy_input)}') +dummy_input['features_only'] = False +dummy_input['return_all_hiddens'] = False +print(f'dummy_input = {dummy_input}, {type(dummy_input)}') + +with torch.no_grad(): + output_origin = model(**dummy_input) + + +input_shapes = [list(dummy_input[input].size()) for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] +input_dtypes = [dummy_input[input].dtype for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] +input_names = tuple([input for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)]) +print(f'input_shapes(out) = {input_shapes}, {type(input_shapes)}, {type(input_shapes[0])}') +print(f'input_dtypes = {input_dtypes}') +kwargs_keys = [input for input in dummy_input if not isinstance(dummy_input[input], torch.Tensor)] +kwargs = dict() +for key in kwargs_keys: + kwargs[key] = dummy_input[key] +print(f'kwargs = {kwargs}') +# model = cube.SemanticModel( +# model, input_shapes=(input_shapes,), +# ) +model = cube.SemanticModel( + model, dummy_input=dummy_input, +) + +dataloader = cube.runtime.syndata.SynDataLoader( + shapes=(input_shapes), + dtypes=input_dtypes, + batch_dims=(0,0,0), + names=input_names, + append_args=kwargs +) +print(f'next(dataloader) = {next(dataloader)}') + +@cube.compile(model, dataloader, PAS=PAS, load_content=False) +def train_iter(model, dataloader): + data = next(dataloader) + # loss = model({'src_tokens':data[0],'src_lengths':data[1],'prev_output_tokens':data[2], }) + loss = model(*data) + loss.backward() + + + +# Conduct concrete trace below +# sys.path.append('/home/v-junliang/torchscaletest/nni') +# sys.path.append('./torchscaletest/nni') +# from nni.common.concrete_trace_utils import concrete_trace +# from concrete_trace_utils import concrete_trace +from examples.nlp.torchscale.concrete_trace_utils import concrete_trace +import examples.nlp.torchscale.torchscaletest.torchscale + + +def check_equal(a, b): + if type(a) != type(b): + return False + if isinstance(a, (list, tuple, set)): + if len(a) != len(b): + return False + for sub_a, sub_b in zip(a, b): + if not check_equal(sub_a, sub_b): + return False + return True + elif isinstance(a, dict): + keys_a, kes_b = set(a.keys()), set(b.keys()) + if keys_a != kes_b: + return False + for key in keys_a: + if not check_equal(a[key], b[key]): + return False + return True + elif isinstance(a, torch.Tensor): + return torch.equal(a, b) + else: + return a == b + + +print("start tracing...") +traced_model, _ = concrete_trace( + model, + dummy_input, + use_operator_patch=True, + autowrap_leaf_class={ + torch.finfo: ((), False), + type(output_origin): ((), False), + }, +) +print("trace succeed") +print("checking equal...") +with torch.no_grad(): + output_traced = traced_model(**dummy_input) +assert check_equal(output_origin, output_traced), "check equal failed" +print("checked") + +# check graph +traced_model.graph.print_tabular() + +# with open('input_tl', 'wb') as f: +# pickle.dump(dummy_input, f) + +# try to save traced model with pickle +# from concrete_trace_utils.concrete_tracer import MagicMethodPatcher +# from pickle import _Pickler, _Unpickler + +# with open("save/through_nn_Module/tl_traced_v2.model", "wb") as f: +# # pickle.dump(traced_model, f) +# with MagicMethodPatcher(): +# _Pickler(f).dump(traced_model) + +# with open("save/through_nn_Module/tl_traced.model", "rb") as f: +# with MagicMethodPatcher(): +# reload_model = _Unpickler(f).load() + + +# with torch.no_grad(): +# output_reload = reload_model(**dummy_input) +# assert check_equal(output_origin, output_reload), "reload check equal failed" +# print("reload is good!") + +# with open("save/through_nn_Module/tl_origin_v2.model", "wb") as f: +# with MagicMethodPatcher(): +# _Pickler(f).dump(model) + +# with open("save/through_nn_Module/tl_input_v2.pkl", "wb") as f: +# with MagicMethodPatcher(): +# _Pickler(f).dump(dummy_input) + diff --git a/examples/nlp/torchscale/policy/mpmd.py b/examples/nlp/torchscale/policy/mpmd.py new file mode 100644 index 00000000..bf00f3ee --- /dev/null +++ b/examples/nlp/torchscale/policy/mpmd.py @@ -0,0 +1,103 @@ +import random +from typing import Tuple +import numpy as np + +from cube.graph.graph import IRGraph +from cube.graph.segment import IRSegment +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.schedule.predefined import PredefinedSched + + +def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + + +def PASRandom(graph, resource): + """ + Random pipeline + """ + assert len(graph.nodes()) // 2 >= resource.ngpus, "not enough operator number." + remain_device = set(range(resource.ngpus)) + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if len(remain_device) != 0: + idx = random.randint(0, len(remain_device) - 1) + device = list(remain_device)[idx] + remain_device.remove(device) + else: + device = random.randint(0, resource.ngpus - 1) + graph.assign(node, device) + elif isinstance(node, IRDataOperation): + device = random.randint(0, resource.ngpus - 1) + graph.assign(node, device) + print(graph.extra_repr()) + return graph + + +def PASMegatron(graph: IRGraph, resource): + + # assert resource.ngpus == 8, "should apply on 8 gpus" + num_stage = 4 + num_tp = resource.ngpus // num_stage + num_microbatch = resource.ngpus * 8 + + _, tp_mesh = _create_mesh(resource.ngpus, (num_stage, num_tp)) + print(f'> pipeline-tensor parallel group: {tp_mesh}') + assert len(tp_mesh) == num_stage + + linears = graph.select('linear') + stage_start_nodes = linears[::len(linears) // num_stage] + stage_start_nodes = stage_start_nodes[:num_stage] + assert len(stage_start_nodes) == num_stage, f"{len(stage_start_nodes)} != {num_stage}" + graph.staging(stage_start_nodes) + + segments = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + assert len(fsegs) == num_stage + + for sid, segment in enumerate(fsegs): + # get tensor parallel group + tp_group = tp_mesh[sid] + for idx, node in enumerate(segment.nodes()): + # partition + if node.name == 'linear': + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx % 2, num=num_tp) + else: + tp_nodes = graph.replicate(node, times=num_tp) + # assign + for devid, node in zip(tp_group, tp_nodes): + graph.assign(node, devid) + + for dl in graph.select(ntype=IRDataOperation): + mesh = tp_mesh[0] + dls = graph.replicate(dl, times=num_tp) + for devid, dl in zip(mesh, dls): + graph.assign(dl, devid) + + # setup schedule to 1F1B + # schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) + # graph.schedule_plan = schedule + if graph.train: + schedule = PredefinedSched.sched_1f1b(graph, num_microbatch, num_stage) + else: + schedule = PredefinedSched.sched_infer_pipe(graph, num_microbatch, num_stage) + return graph diff --git a/examples/nlp/torchscale/policy/spmd.py b/examples/nlp/torchscale/policy/spmd.py new file mode 100644 index 00000000..e3bae1a4 --- /dev/null +++ b/examples/nlp/torchscale/policy/spmd.py @@ -0,0 +1,231 @@ +from typing import List +from cube.graph import IRGraph +from cube.graph.segment import IRSegment +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.gener.rvd.intra import IntraAutoPlacer + + +# tensor parallelism with auto-placer +# This is an implementation example of SPMD auto placer usage +def _tp_autoplace(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): + + if len(devs) == 1: + graph.assign(node, devs[0]) + return [node] + + segment: IRSegment = graph.segment(node) + ftensor = node.input(configs['idx']).parent + + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + producers = segment.producers(ftensor) + if ftensor.is_param() or len(producers) != len(sub_nodes): + print(f"> skip auto placer due to condition not matched: " + f"nproducers: {len(producers)}, nconsumers: {len(sub_nodes)}, " f"producer name: {producers[0].name if len(producers) > 0 else None}") + devs = sorted(list(devs)) + for devid, node in zip(devs, sub_nodes): + graph.assign(node, devid) + else: + devices = IntraAutoPlacer.auto_place( + segment, ftensor, producers, sub_nodes) + for devid, subnode in zip(devices, sub_nodes): + graph.assign(subnode, devid) + return sub_nodes + +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def PASSingle(graph: IRGraph, resource): + """ + Single device + """ + assert resource.ngpus == 1, "only apply for single gpu case" + for node in graph.nodes(): + if isinstance(node, (IRDataOperation, IRFwOperation)): + graph.assign(node, 0) + return graph + + +def PASData(graph: IRGraph, resource): + """ + Data Parallel + """ + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=resource.ngpus) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + batch_dim = node.get_batch_dims()[0] + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=0, dim=batch_dim, num=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def PASCol(graph: IRGraph, resource): + """ + Linear Column Parallel + """ + linears = [node for node in graph.nodes() if node.name == 'linear'] + for idx, node in enumerate(linears): + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=0, num=resource.ngpus + ) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def PASRow(graph: IRGraph, resource): + """ + Linear Column Parallel + """ + devs = list(range(resource.ngpus)) + + for dl in graph.select(ntype=IRDataOperation): + sub_nodes = graph.replicate(dl, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + + for node in graph.select(ntype=IRFwOperation): + if node.name == 'linear': + _tp(graph, node, devs, idx=0, dim=1, num=len(devs)) + else: + _replica(graph, node, devs) + + return graph + + +def PASHybrid(graph: IRGraph, resource): + """ + Linear Hybrid Parallelism (Megatron) + """ + linears = [node for node in graph.nodes() if node.name == 'linear'] + for idx, node in enumerate(linears): + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=resource.ngpus) + for idx, node in enumerate(tp_nodes): + graph.assign(node, idx) + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + print(graph.extra_repr()) + return graph + + +def PASMegatronTP(graph: IRGraph, resource): + """ + Tensor + Data Parallelism + """ + tp = min(2, resource.ngpus) + dp = resource.ngpus // tp + linears = [node for node in graph.nodes() if node.name == 'linear'] + for idx, node in enumerate(linears): + sub_nodes = [] + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=tp) + for tp_node in tp_nodes: + algo = tp_node.algorithms('dim') + dp_nodes = graph.partition(tp_node, algo, idx=0, dim=0, num=dp) + sub_nodes += dp_nodes + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + for node in graph.nodes(): + if isinstance(node, (IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + # print(graph.extra_repr()) + return graph + + +def PASOptimal(graph: IRGraph, resource): + """ + Square Linear optimal parallelism (4GPU) + """ + assert resource.ngpus == 4, "only apply to 4 GPU case" + + # replicate data operation + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + + # replicate loss operation + fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] + loss = fnodes[-1] + sub_nodes = graph.replicate(loss, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + + fnodes = fnodes[:-1] + # linear0 config + config0 = [ + None, + dict(idx=1, dim=0, num=4) # col + ] + # linear1 config + config1 = [ + dict(idx=0, dim=1, num=2), # row + dict(idx=1, dim=0, num=2), # col + ] + # linear2 config + config2 = [ + dict(idx=0, dim=0, num=2), # dat + dict(idx=0, dim=1, num=2), # row + ] + # linear3 config + config3 = [ + dict(idx=0, dim=0, num=2), # dat + dict(idx=0, dim=1, num=2), # row + ] + configs = [config0, config1, config2, config3] + assert len(fnodes) == len(configs) + for fnode, config in zip(fnodes, configs): + all_nodes = [fnode] + for conf in config: + if conf is None: + continue + sub_nodes = list() + for node in all_nodes: + algo = node.algorithms('dim') + nodes = graph.partition(node, algo, **conf) + sub_nodes += nodes + all_nodes = sub_nodes + assert len(all_nodes) == 4 + for idx, node in enumerate(all_nodes): + graph.assign(node, idx) + return graph + From cf64c5550caf20774c1dbe958f77f1e15440aa2b Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 2 Mar 2023 02:53:11 -0800 Subject: [PATCH 1278/1892] save work --- cube/graph/function/function.py | 35 +++++++++++++++++++++++++++++++++ cube/graph/parser/mappingfx.py | 2 ++ cube/ir/dtype.py | 2 +- examples/mlp/linearsfx.py | 2 ++ 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 546c1bd6..73d94f80 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -480,6 +480,21 @@ def Dropout(signature, inputs): p=p, training=training, inplace=inplace) +def EQ(signature, inputs): + assert len(inputs) == 2 + input0, input1 = inputs + + edim_in0 = ShapeAnno.create_shape_str(input0.shape) + edim_ou = copy.copy(edim_in0) + if isinstance(input1, float): + anno = OpAnno.create_op_str([edim_in0], [edim_ou]) + return IRDimops(EQ, 'eq', signature, [anno], [input0], other=input1) + else: + edim_in1 = copy.copy(edim_in0) + anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) + return IRDimops(EQ, 'eq', signature, [anno], [input0], other=input1) + + def NE(signature, inputs): assert len(inputs) == 2 input0, input1 = inputs @@ -843,6 +858,7 @@ def Unsqueeze(signature, inputs): return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input], dim=dim) + def TypeAs(signature, inputs): """ out = torch.Tensor.type_as(tensor0, tensor1) @@ -857,6 +873,7 @@ def TypeAs(signature, inputs): return IRDimops(TypeAs, 'type_as', signature, [anno], [input0, input1]) + def Triu(signature, inputs): """ out = torch.triu(tensor, diagonal) @@ -874,6 +891,24 @@ def Triu(signature, inputs): return IRDimops(Triu, 'triu', signature, [anno], [input], diagonal=diagonal) + +def CumSum(signature, inputs): + """ + out = torch.cumsum(tensor, dim) + """ + assert len(inputs) == 2 + input, dim = inputs + assert isinstance(dim, int) + + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_in[dim] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(CumSum, 'cumsum', signature, [anno], [input], + dim=dim) + + # def Pad(signature, inputs): # """ # torch.nn.functional.pad(input: torch.Tensor, pad: List[int], mode='constant', value=0.0) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index a91be34f..1857accb 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -64,10 +64,12 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __tttemplate('type_as'): function.TypeAs, __ttemplate('triu'): function.Triu, __ftemplate('relu') : function.ReLU, + __ttemplate('eq') : function.EQ, __ttemplate('ne') : function.NE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, __ttemplate('masked_fill'): function.MaskedFill, + __ttemplate('cumsum'): function.CumSum, # # torch nn functional # diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index ab3a463b..4bda60c0 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -47,7 +47,7 @@ def infer(node, dtypes: List[IRDType]) -> IRDType: if IRDType.float32 in dtypes and IRDType.float16 in dtypes: raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") # TODO(yizhu1): hack - if node.signature == 'torch.ne': + if node.signature in ['torch.ne', 'torch.eq']: return IRDType.boolean elif node.signature == 'torch.Tensor.long': return IRDType.int64 diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 1496c861..d54943c9 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -74,6 +74,8 @@ def forward(self, data, mask): # x = torch.ne(x, 1.0) # x = torch.nn.functional.dropout(x, self.p) # x = x * self.y + x = torch.cumsum(x, -1) + # y = torch.eq(x, 1.0) loss = torch.sum(x) # long cannot backward # loss = loss.long() From 1e830f07a32da1e895a95ae4690cc67a9b6d173e Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Thu, 2 Mar 2023 11:24:31 +0000 Subject: [PATCH 1279/1892] Merged PR 1462: fix SynData device misplace issue fix SynData device misplace issue --- cube/runtime/syndata.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 2a203b76..77eb5a08 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -123,7 +123,7 @@ class SynDataLoader(CubeDataLoader): for given shapes, dtypes. """ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, - batch_dims: Tuple[int] = None, names: Tuple[str] = None, append_args=None, device=torch.cuda.current_device()): + batch_dims: Tuple[int] = None, names: Tuple[str] = None, append_args=None, device=None): """ shapes Tuple[Tuple[int]]: The shape for each data @@ -139,9 +139,8 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, super().__init__(shapes, dtypes, batch_dims) self.names = names - print(f'### SynDataLoader.names = {names}') self.append_args=append_args - self.device = device + self.device = device if device else torch.cuda.current_device() self.buffer: Union[torch.Tensor, Tuple[torch.Tensor]] = None datas = self.random_sample() self.set_output(datas) From 8e5ca1bda338b78a2b795835b5bbc7f348376e22 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Mar 2023 12:14:20 +0000 Subject: [PATCH 1280/1892] Merged PR 1464: General Dataflow with IRObject - Support `IRObject` as basic data type in operator input/output - Support `dataloader` output non-tensor data - Add simple test for graph inputs with non-tensor data A general dataflow graph has `IRObject` as inputs/outputs of each operator. Constant values will be put inside `op.kwargs`. For supporting general graph with python runtime functions: * `IRObject` can be any data type. * `IRPyFunc` can be any python runtime function. * `IRTensor` is derived from `IRObject`. --- cube/codegen/emit.py | 4 +- cube/codegen/module/module.py | 2 +- cube/graph/function/pyfunc.py | 7 +- cube/graph/graph.py | 10 +-- cube/graph/parser/parser.py | 3 +- cube/graph/segment.py | 45 +++++------ cube/ir/cten.py | 145 +++++++++++++++++++++------------- cube/program.py | 26 +++--- cube/runtime/syndata.py | 71 +++++++++-------- tests/parser/test_jit_ops.py | 35 ++++++-- 10 files changed, 206 insertions(+), 142 deletions(-) diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index a1e12706..3243436c 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -1,6 +1,6 @@ from typing import Generator, Iterable, List, Any, Optional, Tuple -from cube.ir.cten import IRCell, IRTensor +from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.dtype import IRDType from cube.ir.tensor import IRSubTensor from cube.ir.operator import IRDataOperation, IRFwOperation @@ -39,7 +39,7 @@ def tensor_name(tensor: Any, prefix_attr: Optional[str] = None) -> str: @return str """ - if isinstance(tensor, IRTensor): + if isinstance(tensor, IRObject): tensor_name = tensor.name if '.' in tensor_name: tensor_name = tensor_name.split('.')[0] diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 4a09aac9..54ca1ffd 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -369,7 +369,7 @@ def init_batchsize(self, node: IRDataOperation): Emit batch size declare """ signature = 'self.set_batch_size({bs})' - bs = [t.shape[dim] for t, dim in zip(node.outputs(), node.get_batch_dims())] + bs = [t.shape[dim] for t, dim in zip(node.outputs(), node.get_batch_dims()) if dim is not None] bs = set(bs) if len(bs) > 1: warnings.warn(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') diff --git a/cube/graph/function/pyfunc.py b/cube/graph/function/pyfunc.py index 5d6c4fdf..074cb859 100644 --- a/cube/graph/function/pyfunc.py +++ b/cube/graph/function/pyfunc.py @@ -1,8 +1,7 @@ -from typing import List, Optional, Tuple, Any -import itertools +from typing import Tuple from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor +from cube.ir.cten import IRObject class IRPyFunc(IRFwOperation): @@ -11,7 +10,7 @@ class IRPyFunc(IRFwOperation): """ def __init__(self, signature: str, - inputs: Tuple[Any], outputs: Tuple[Any], **kwargs): + inputs: Tuple[IRObject], outputs: Tuple[IRObject], **kwargs): name = signature.split('.')[-1] super().__init__(name, signature, len(inputs), len(outputs)) for idx, t in enumerate(inputs): diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 446503ff..c8214e19 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,11 +7,11 @@ will be inserted at scheduling time. """ -from typing import Sequence, Set, Union, Tuple, List, Optional, Dict +from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any import warnings import copy -from cube.ir.cten import IRTensor, IRCell +from cube.ir.cten import IRTensor, IRCell, IRObject from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap @@ -59,7 +59,7 @@ def __call__(self, *args): """ return self.forward(*args) - def forward(self, *args: Tuple[IRSubTensor]) -> Union[IRTensor, Tuple[IRTensor]]: + def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: """ forward will divide the graph into Actions according to node device assignment @@ -72,7 +72,7 @@ def forward(self, *args: Tuple[IRSubTensor]) -> Union[IRTensor, Tuple[IRTensor]] @return outputs Union[IRSubTensor, Tuple[IRSubTensor]] """ # align graph with input tensors - itensors: Tuple[IRSubTensor, ...] = self.inputs() + itensors: Tuple[IRObject, ...] = self.inputs() assert len(args) == len(itensors) for idx, (itensor, arg) in enumerate(zip(itensors, args)): self.set_input(idx, arg) @@ -226,7 +226,7 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: @staticmethod def from_logic_graph(nodes: List[IRCell], - inputs: List[IRFullTensor], outputs: List[IRFullTensor], + inputs: List[IRObject], outputs: List[IRObject], module_name: str): """ Generate IRGraph from logical graph (IRFullTensor) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index 778eaf77..ffdf02fe 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -3,6 +3,7 @@ import re from typing import Any, List, Tuple, Optional +from cube.ir.cten import IRObject from cube.ir.operator import IRFwOperation from cube.graph.function.pyfunc import IRPyFunc from cube.ir.tensor import IRFullTensor @@ -61,7 +62,7 @@ def parse_module(module, dtype = ir.IRDType.unknown # kDefaultType val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.debugName()) else: - raise NotImplementedError("Graph inputs only accepts Tensor") + val = IRObject(name=input.debugName()) frame.add_var(input.debugName(), val, graph_arg=idx) input_val = [frame.get_var(input.debugName()) for input in inputs] diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 1ad34131..5bdfc27c 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -3,11 +3,12 @@ import numpy as np from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap -from cube.ir.cten import IRTensor, IRCell +from cube.ir.cten import IRTensor, IRCell, IRObject from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation from cube.ir.adapter import IRAdapter from cube.graph.function.function import MultiRef +from cube.graph.function.pyfunc import IRPyFunc class CellPosition: @@ -111,7 +112,7 @@ def full_tensors(self) -> Tuple[IRFullTensor]: @return ftensors List[IRFullTensor] """ - return tuple(self._ftensors) + return tuple(t for t in self._ftensors if isinstance(t, IRFullTensor)) def attributes(self) -> Tuple[IRFullTensor]: """ @@ -362,11 +363,11 @@ def create_bwop(self, fwop: IRFwOperation) -> Union[IRBpOperation, IRBpOperation # ====================== Basic Graph manipulations ====================== - def _add_ftensor(self, ftensor: IRFullTensor): + def _add_ftensor(self, ftensor: IRObject): """ Add a full tensor in segment if the segment doesn't have the tensor. """ - assert isinstance(ftensor, IRFullTensor) + assert isinstance(ftensor, IRObject) if ftensor not in self._ftensors: self._ftensors.add(ftensor) self._producers[ftensor] = [] @@ -376,11 +377,11 @@ def _add_ftensor(self, ftensor: IRFullTensor): if ftensor.is_attr(): self._attributes.add(ftensor) - def _remove_ftensor(self, ftensor: IRFullTensor): + def _remove_ftensor(self, ftensor: IRObject): """ Remove a full tensor in segment """ - assert isinstance(ftensor, IRFullTensor) + assert isinstance(ftensor, IRObject) if ftensor in self._ftensors: self._ftensors.remove(ftensor) del self._producers[ftensor] @@ -404,13 +405,13 @@ def _reorder_producer_consumer(self): # set producer and consumer for node in self._nodes: if isinstance(node, IRAdapter): continue - itensors = set(t for t in node.inputs() if isinstance(t, IRSubTensor)) + itensors = set(t for t in node.inputs() if isinstance(t, IRObject)) for itensor in itensors: ftensor = itensor.parent self._add_ftensor(ftensor) self._consumers[ftensor].append(node) self._ctensors[ftensor].append(itensor) - otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) + otensors = set(t for t in node.outputs() if isinstance(t, IRObject)) for otensor in otensors: ftensor = otensor.parent self._add_ftensor(ftensor) @@ -440,23 +441,17 @@ def insert(self, node: IRCell, index: Union[int, CellPosition]): # update producer and consumer if isinstance(node, IRAdapter): return # consumer - itensors = set(t for t in node.inputs() if isinstance(t, IRSubTensor)) + itensors = set(t for t in node.inputs() if isinstance(t, IRObject)) for itensor in itensors: ftensor = itensor.parent self._add_ftensor(ftensor) - # idx = len([c for c in self._consumers[ftensor] if self._nodes.index(c) < index]) - # self._consumers[ftensor].insert(idx, node) - # self._ctensors[ftensor].insert(idx, itensor) self._consumers[ftensor].append(node) self._ctensors[ftensor].append(itensor) # producer - otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) + otensors = set(t for t in node.outputs() if isinstance(t, IRObject)) for otensor in otensors: ftensor = otensor.parent self._add_ftensor(ftensor) - # idx = len([c for c in self._producers[ftensor] if self._nodes.index(c) < index]) - # self._producers[ftensor].insert(idx, node) - # self._ptensors[ftensor].insert(idx, otensor) self._producers[ftensor].append(node) self._ptensors[ftensor].append(otensor) else: @@ -487,7 +482,7 @@ def remove(self, node: IRCell, _pos: Union[int, CellPosition] = None) -> CellPos # update producer and consumer if isinstance(node, IRAdapter): return pos # consumer - itensors = set(t for t in node.inputs() if isinstance(t, IRSubTensor)) + itensors = set(t for t in node.inputs() if isinstance(t, IRObject)) for itensor in itensors: ftensor = itensor.parent idx = self._consumers[ftensor].index(node) @@ -496,7 +491,7 @@ def remove(self, node: IRCell, _pos: Union[int, CellPosition] = None) -> CellPos if len(self._consumers[ftensor]) == 0 and len(self._producers[ftensor]) == 0: self._remove_ftensor(ftensor) # producer - otensors = set(t for t in node.outputs() if isinstance(t, IRSubTensor)) + otensors = set(t for t in node.outputs() if isinstance(t, IRObject)) for otensor in otensors: ftensor = otensor.parent idx = self._producers[ftensor].index(node) @@ -926,7 +921,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: "Non-differentiable IRAdapter is not allowed to be grouped" continue # update inputs - itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] + itensors = [t for t in node.inputs() if isinstance(t, IRObject)] for itensor in itensors: ftensor = itensor.parent if itensor.is_attr(): continue @@ -935,7 +930,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: if len(node.device) > 0 and set(itensor.device).issubset(adapter_ous[itensor]): continue # from segment inputs - if any(t.overlap(itensor) for t in segment.inputs() if isinstance(t, IRSubTensor)): + if any(t.overlap(itensor) for t in segment.inputs() if isinstance(t, IRObject)): inputs.add(itensor) continue # from outside producers @@ -948,7 +943,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: inputs.add(itensor) continue # update outputs - otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] + otensors = [t for t in node.outputs() if isinstance(t, IRObject)] for otensor in otensors: ftensor = otensor.parent if otensor.is_attr(): continue @@ -957,7 +952,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: if len(node.device) > 0 and set(otensor.device).issubset(adapter_ins[otensor]): continue # from segment outputs - if any(t.overlap(otensor) for t in segment.outputs() if isinstance(t, IRSubTensor)): + if any(t.overlap(otensor) for t in segment.outputs() if isinstance(t, IRObject)): outputs.add(otensor) continue # loss doesn't have consumers @@ -976,7 +971,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: outputs.add(otensor) continue - def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: + def order(tensors: Set[IRObject]) -> Tuple[IRObject]: """Reorder by logical tensor id. Temporally necessary for pipeline scheduling""" tensors = list(tensors) tids = np.array([t.parent.tid for t in tensors]) @@ -1012,7 +1007,7 @@ def dispatch(self, devid: int, _gen_mirror: bool = True) -> Optional[IRCell]: # if otensor in self._outputs and otensor not in outputs: # outputs.append(otensor) - def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: + def order(tensors: Set[IRObject]) -> Tuple[IRObject]: """Reorder by logical tensor id. Temporally necessary for pipeline scheduling""" tensors = list(tensors) tids = np.array([t.parent.tid for t in tensors]) @@ -1035,7 +1030,7 @@ def order(tensors: Set[IRSubTensor]) -> Tuple[IRSubTensor]: def __repr__(self): fw = 'f' if self.isfw() else 'b' - inputs = tuple(t for t in self.inputs() if isinstance(t, IRTensor) and not t.is_param()) + inputs = tuple(t for t in self.inputs() if isinstance(t, IRObject) and not t.is_param()) if self.isfw(): dscp = f"{fw}Graph{self.cid}-{self.device}(inputs={inputs}, outputs={self.outputs()})" else: diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 0f1cd79e..50fbfced 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -22,9 +22,6 @@ from cube.ir.dtype import IRDType, dtype2byte_size -__all__ = ['IRCell', 'IRDType', 'IRTensor'] - - class IRCell: r""" IRCell serves as a general node for different purpose @@ -431,64 +428,38 @@ def __repr__(self) -> str: return dscp -class IRTensor: +class IRObject: """ - IRTensor serves as IRGraph edge - - Note by setting IRTensor name to "None" indicates this tensor holds nothing - and will be translated to None in code generation. + IRObject serves as non-tensor inputs/outputs for IRCell. """ - _meta = ['name', '_is_attr', '_is_grad', '_requires_grad', '_dtype'] - - def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): - + def __init__(self, name: Optional[str] = None, tid: Optional[int] = None): + """ + @param name str: object name + @param tid int: object unique id + """ self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() - self._shape: Tuple[int] = () if shape is None else tuple(shape) self.name: str = name if name else 'tensor' - - # device self._cell: Optional[IRCell] = None - - self._dtype: IRDType = dtype self._is_attr: bool = False - self._is_grad: bool = False - # tensor gradient - self._requires_grad: bool = False - self._grad: Optional[Union[IRTensor, float]] = None - - @property - def tid(self) -> int: - """ - Get tensor id + def __eq__(self, obj): + if not isinstance(obj, IRObject): + return False + return self._id == obj.tid - @return cid int: the tensor id. - """ + def __hash__(self) -> int: return self._id @property - def dtype(self) -> IRDType: - """ - Tensor data type - """ - return self._dtype - - @dtype.setter - def dtype(self, val: IRDType): - """ - Set data type - """ - if not isinstance(val, IRDType): - raise TypeError(f"Expected IRDType but got {val}") - self._dtype = val - if isinstance(self._grad, IRTensor): - self._dtype = val + def tid(self) -> int: + """Get object id""" + return self._id @property - def cell(self) -> Optional[IRCell]: + def cell(self) -> IRCell: return self._cell - + @cell.setter def cell(self, val: Optional[IRCell]): assert isinstance(val, IRCell) or val is None, "Expected cell to be Optional[IRCell]" @@ -504,17 +475,90 @@ def device(self) -> Tuple[int]: @device.setter def device(self, val: Union[int, List[int]]): raise RuntimeError( - "tensor placement is not allowed to set manually" + "IRObject placement is not allowed to set manually" ) + + @property + def parent(self): + """Get parent""" + return self + + def __eq__(self, obj) -> bool: + if not isinstance(obj, IRObject): + return False + return self._id == obj.tid + + def __copy__(self): + """Copy this object but remove the cell information""" + return IRObject(self.name, self._id) + + def as_attr(self): + """ + Set the obj as graph attributes + """ + self._is_attr = True + return self def is_attr(self) -> bool: """! - Check if the tensor is graph attribute. + Check if the object is graph attribute. @return is_attr boolean: True if is graph attribute (buffer or parameter or gradient of parameter) """ return self._is_attr + def overlap(self, other: Any) -> bool: + """! + Check whether two object can be overlapped + """ + if isinstance(other, IRObject): + return other.tid == self._id + else: + return False + + def __repr__(self): + return f'Object({self.name}{self.tid})' + + +class IRTensor(IRObject): + """ + IRTensor serves as IRGraph edge + + Note by setting IRTensor name to "None" indicates this tensor holds nothing + and will be translated to None in code generation. + """ + + _meta = ['name', '_is_attr', '_is_grad', '_requires_grad', '_dtype'] + + def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): + + super().__init__(name, tid) + self._shape: Tuple[int] = () if shape is None else tuple(shape) + self._cell: Optional[IRCell] = None + self._dtype: IRDType = dtype + # tensor gradient + self._is_grad: bool = False + self._requires_grad: bool = False + self._grad: Optional[Union[IRTensor, float]] = None + + @property + def dtype(self) -> IRDType: + """ + Tensor data type + """ + return self._dtype + + @dtype.setter + def dtype(self, val: IRDType): + """ + Set data type + """ + if not isinstance(val, IRDType): + raise TypeError(f"Expected IRDType but got {val}") + self._dtype = val + if isinstance(self._grad, IRTensor): + self._dtype = val + def is_param(self) -> bool: """! Check if the tensor is parameter @@ -586,11 +630,6 @@ def __copy__(self): tensor.cell = None return tensor - def __eq__(self, tensor): - if not isinstance(tensor, IRTensor): - return False - return self._id == tensor._id - @property def shape(self) -> Tuple[int]: return list(self._shape) diff --git a/cube/program.py b/cube/program.py index ef8a49dd..948f42e3 100644 --- a/cube/program.py +++ b/cube/program.py @@ -1,6 +1,6 @@ from typing import List, Tuple, Optional -from cube.ir.cten import IRCell, IRTensor +from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.operator import IRBpOperation, IRDataOperation @@ -82,11 +82,14 @@ def __init__(self, dataloader: CubeDataLoader): raise TypeError("Expected data loader derived from CubeDataLoader") self.dataloader: CubeDataLoader = iter(dataloader) dtype_map = DType2IRDType - self.dtypes = [dtype_map.map(dtype) for dtype in dataloader.dtypes] - self.shapes = [list(shape) for shape in dataloader.shapes] + sample = next(dataloader) + if not isinstance(sample, tuple): + sample = (sample,) + self.dtypes = [dtype_map.map(t.dtype) if torch.is_tensor(t) else None for t in sample] + self.shapes = [list(t.shape) if torch.is_tensor(t) else None for t in sample] - def get_batch_dims(self) -> Tuple[int]: - return tuple(self.dataloader.batch_dims) + def get_batch_dims(self) -> Tuple[Optional[int]]: + return tuple(self.dataloader.get_batch_dims()) def get_batch_size(self) -> int: return self.dataloader.get_batch_size() @@ -101,9 +104,12 @@ def __iter__(self): def __next__(self): outputs = list() for dtype, shape in zip(self.dtypes, self.shapes): - data = IRFullTensor( - shape, 'data', requires_grad=False, dtype=dtype - ).tosub() + if shape is not None: + data = IRFullTensor( + shape, 'data', requires_grad=False, dtype=dtype + ).tosub() + else: + data = IRObject('data') outputs.append(data) data_op = IRDataOperation( @@ -172,8 +178,8 @@ def __call__(self, *args): if self._loaded_module: return self._loaded_module(*args) else: - assert all(isinstance(t, IRSubTensor) for t in args), f"Only support tensors as model inputs" - input_shapes = [tuple(t.shape) for t in args] + # assert all(isinstance(t, IRSubTensor) for t in args), f"Only support tensors as model inputs" + input_shapes = [tuple(t.shape) if isinstance(t, IRTensor) else None for t in args] if DeviceGroup().local_rank == 0: if self.ir_graph is None: self.ir_graph = parser.convert_model( diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 77eb5a08..bd50cf0a 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -9,46 +9,46 @@ class CubeDataLoader: r""" - Cube Dataloader + Cube Dataloader. + User should provide a dataloader to runtime with at least these functionalities: + + 1) `__iter__()`: get the dataloder iterator + 2) `__next__()` get the next batch of data + 3) `get_batch_size()` return the batch size (int) + 4) `set_batch_size(bs)` reset the batch size (int) + 5) `get_batch_dims(self)` get the batch dimension of each output data """ - def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype], batch_dims: Tuple[int] = None): + def __init__(self, batch_size: int, batch_dims: Tuple[Optional[int]]): """ - shapes Tuple[Tuple[int]]: - The shape for each data - dtypes Tuple[torch.dtype]: - The dtype for each data - batch_dims Tuple[int]: - The batch dimension of each data + Create a dataloader for cube runtime + + @param batch_size int: dataloader batch size + @param batch_dims Tuple[Optional[int]]: the batch dimension of each output data, + None indicates the output (tensor or non-tensor) doesn't have the batch dimension. """ - if not all(isinstance(shape, list) for shape in shapes): - raise TypeError("Expected each shape in shapes to be a list") - if len(shapes) != len(batch_dims) or len(shapes) != len(dtypes): - raise TypeError("Expected number batch dim and dtypes to len(shapes)") - self.shapes = tuple([list(shape) for shape in shapes]) - self.dtypes = dtypes - self.batch_dims = (0,) * len(self.shapes) if batch_dims is None else batch_dims - bs = [shape[dim] for shape, dim in zip(self.shapes, self.batch_dims)] - assert len(set(bs)) == 1, f"Expect batch size same in each data shapes" - self.batch_size = bs[0] + self.batch_size: int = batch_size + self.batch_dims: Tuple[Optional[int]] = batch_dims + + def __iter__(self): + raise NotImplementedError("Required implementation for derived class") + + def __next__(self): + return NotImplementedError("Required implementation for derived class") def get_batch_size(self) -> int: """ get batch size """ - all_batch_size = set([shape[dim] for shape, dim in zip(self.shapes, self.batch_dims)]) - if len(all_batch_size) != 1: - raise ValueError("Heterogenous batch size in dataloader") - return list(all_batch_size)[0] + return self.batch_size def set_batch_size(self, batch_size: int): """ set batch size """ - self.batch_size = batch_size - for shape, dim in zip(self.shapes, self.batch_dims): - shape[dim] = batch_size - rank = 0 if not torch.distributed.is_initialized() else torch.distributed.get_rank() - print(f'rank [{rank}]: > set batch size to {batch_size}. dataloader outputs change to: {self.shapes}') + return NotImplementedError("Required implementation for derived class") + + def get_batch_dims(self) -> Tuple[Optional[int]]: + return tuple(self.batch_dims) class SciLoopVariables(CubeDataLoader): @@ -64,8 +64,7 @@ def __init__(self, variables: List[Any], constants: List[Any]): else: shapes.append([1,]) dtypes.append(type(var)) - batch_dims = [-1] * (len(variables) + len(constants)) - super().__init__(shapes, dtypes, batch_dims) + super().__init__(0, [None] * len(shapes)) self.variables = list() self.constants = list() for var in variables: @@ -136,8 +135,10 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, batch_dims = tuple([0] * len(shapes)) if dtypes is None: dtypes = tuple([torch.float] * len(shapes)) - - super().__init__(shapes, dtypes, batch_dims) + self.shapes = tuple([list(shape) for shape in shapes]) + self.dtypes = dtypes + batch_size = shapes[0][batch_dims[0]] + super().__init__(batch_size, batch_dims) self.names = names self.append_args=append_args self.device = device if device else torch.cuda.current_device() @@ -151,7 +152,6 @@ def __iter__(self): def __next__(self): if self.names: assert len(self.names) == len(self.buffer) - print(f'### named_syn_data') return dict(zip(self.names, self.buffer)).update(self.append_args) else: return self.buffer @@ -182,7 +182,10 @@ def set_output(self, datas: Union[torch.Tensor, Tuple[torch.Tensor]]): self.buffer = datas[0] if len(datas) == 1 else datas def set_batch_size(self, batch_size: int): - super().set_batch_size(batch_size) + self.batch_size = batch_size + for shape, dim in zip(self.shapes, self.batch_dims): + shape[dim] = batch_size + rank = 0 if not torch.distributed.is_initialized() else torch.distributed.get_rank() + print(f'rank [{rank}]: > set batch size to {batch_size}. dataloader outputs change to: {self.shapes}') datas = self.random_sample() self.set_output(datas) - diff --git a/tests/parser/test_jit_ops.py b/tests/parser/test_jit_ops.py index 50456120..7f489540 100644 --- a/tests/parser/test_jit_ops.py +++ b/tests/parser/test_jit_ops.py @@ -1,5 +1,5 @@ """ -torchrun --nproc_per_node=1 tests/parser/test_torch_ops.py +torchrun --nproc_per_node=1 tests/parser/test_jit_ops.py """ import torch @@ -14,7 +14,8 @@ def __init__(self, shape=[256, 512]): super().__init__() self.param = torch.nn.Parameter(torch.empty(shape, dtype=torch.float32)) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, cache: int): + x = x + cache # [256, 512], [256, 512] -> [256, 512] x = x * self.param # [256, 512] -> [512] @@ -26,15 +27,34 @@ def forward(self, x: torch.Tensor): # [512, 512] -> [256, 512]: this will be parsed to 2 slice operations x4 = x3[:256,:] return x4 + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self) -> None: + self.sample = ( + torch.rand([256, 512], dtype=torch.float32, device=torch.cuda.current_device()), + 4 + ) + batch_size = self.sample[0][0] + super().__init__(batch_size, (0, None)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + def set_batch_size(self, batch_size: int): + return True + def test_parse_ops(): cube.init() model = TestOpModule() - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([256, 512],), dtypes=(torch.float32,), batch_dims=(0,)) + dataloader = TestDataLoader() def policy(graph, resource): assert resource.ngpus == 1 @@ -52,13 +72,14 @@ def policy(graph, resource): @cube.compile(model, dataloader, policy, load_content=False) def eval_iter(model, dataloader): - data = next(dataloader) - out = model(data) + data1, data2 = next(dataloader) + out = model(data1, data2) model = model.get_gen_module() - for _ in range(3): + for idx in range(3): eval_iter(model, dataloader) + print(f"iter {idx}/3") if __name__ == '__main__': From bc29ca07867bcc971e0c8f8178db6aae5b08a849 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Thu, 2 Mar 2023 12:21:03 +0000 Subject: [PATCH 1281/1892] Merged PR 1463: Support concrete parser for linearsfx.py Support concrete parser for linearsfx.py --- cube/graph/parser/converter.py | 9 ++++++--- examples/mlp/linearsfx.py | 20 +++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 092f91ce..a2dfb73e 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -27,10 +27,13 @@ def convert_model(model: torch.nn.Module, smodule: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) smodule.graph.print_tabular() else: - print(f'input_shapes = {input_shapes}, {type(input_shapes)}') - print(f'dummy_input = {dummy_input}, {type(dummy_input)}') with torch.no_grad(): - output_origin = model(**dummy_input) + if isinstance(dummy_input, tuple): + output_origin = model(*dummy_input) + elif isinstance(dummy_input, dict): + output_origin = model(**dummy_input) + else: + raise RuntimeError(f'Unknown dummy_input = {dummy_input}') traced_model, _ = concrete_trace( model, dummy_input, diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 1496c861..15e288de 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -84,17 +84,27 @@ def train(): batch_size = 32 dim = 1024 - model = MLP(dim=dim) - model = cube.SemanticModel( - model, input_shapes=([batch_size, dim, dim], [batch_size, dim, dim]), - ) - dataloader = cube.runtime.syndata.SynDataLoader( shapes=([batch_size, dim, dim], [batch_size, dim, dim],), dtypes=(torch.float32, torch.bool,), batch_dims=(0, 0,) ) + model = MLP(dim=dim) + + # shape based input (will trigger standard torch.fx.Tracer) + # model = cube.SemanticModel( + # model, input_shapes=([batch_size, dim, dim], [batch_size, dim, dim]), + # ) + + # dummy based input (will trigger concrete Tracer) + device = next(model.parameters()).device + dummy_input = next(dataloader) + dummy_input = tuple([input.to(device) for input in dummy_input]) + model = cube.SemanticModel( + model, dummy_input=dummy_input, + ) + @cube.compile(model, dataloader, PAS=PAS, load_content=False) def train_iter(model, dataloader): data, mask = next(dataloader) From ebd3284534b79c28df4f7003733c0f5a1ee708b7 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 2 Mar 2023 04:58:14 -0800 Subject: [PATCH 1282/1892] support fill --- cube/graph/function/function.py | 10 ++++++++++ cube/graph/parser/mappingfx.py | 1 + examples/mlp/linearsfx.py | 1 + 3 files changed, 12 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 73d94f80..08ac975a 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -524,6 +524,16 @@ def Long(signature, inputs): return IRDimops(Long, 'long', signature, annos, tensor) +def Fill(signature, inputs): + assert len(inputs) == 2 + input, value = inputs + + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Fill, 'fill', signature, [anno], [input], value=value) + + def MaskedFill(signature, inputs): assert len(inputs) == 3 input0, input1, value = inputs diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 1857accb..0708438d 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -68,6 +68,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ne') : function.NE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, + __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index d54943c9..d50a8215 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -60,6 +60,7 @@ def __init__(self, dim, mult=1, nlayers=4): def forward(self, data, mask): x = data.masked_fill(mask, 0.0) + x = x.fill_(0.0) for layer in self.layers: x = layer(x) x = torch.nn.functional.relu(x) From a4bae17f867440219066ab50235ef4518a168e71 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Mar 2023 22:14:41 +0800 Subject: [PATCH 1283/1892] hotfix: fix view partition --- cube/graph/function/function.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 546c1bd6..944e93bc 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -665,7 +665,7 @@ def nele(shape, nele=1): if -1 in ou_shape: idx = ou_shape.index(-1) ou_shape[idx] = cnt // (-nele(ou_shape)) - assert nele(in_shape) == nele(ou_shape), "shape mismatch" + assert nele(in_shape) == nele(ou_shape), f"shape mismatch: {in_shape}, {ou_shape}" # generate annotation rest_inshape = [dimlen for dimlen in in_shape] @@ -778,7 +778,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: kwargs = dict(**kwargs) ofirst = [bracket[0] for bracket in ou_anno] - identifier = in_anno[idx][0] + identifier = in_anno[dim][0] oidx = ofirst.index(identifier) size = list(kwargs['size']) size[oidx] = size[oidx] // num From 48d41b4a55aaf5c4f10aef7f9bbc03f06bca0af7 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Thu, 2 Mar 2023 21:18:20 -0800 Subject: [PATCH 1284/1892] minor --- cube/graph/parser/parserfx.py | 134 +++++++++++++++++++++------------- 1 file changed, 85 insertions(+), 49 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index f089371f..ec37b6f6 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -1,7 +1,7 @@ import torch import enum import re -from typing import Any, List, Tuple, Optional +from typing import Any, List, Tuple, Optional, Callable from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor @@ -12,6 +12,8 @@ import torch.fx +global_not_supported = set() + class ErasedDevice: pass @@ -75,7 +77,10 @@ def parse(module: torch.fx.GraphModule, inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] print(f'inputs = {inputs}') - ## shape propagation + # remove dead nodes + from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler + DCEHandler(module).eliminate_dead_code() + # shape propagation from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp KwargsShapeProp(module).propagate(dummy_inputs) for node in module.graph.nodes: @@ -131,10 +136,12 @@ def parse(module: torch.fx.GraphModule, # handle nodes all_ir_nodes: List[IRFwOperation] = list() for node in module.graph.nodes: - print('zql handle node: ', node, node.op, node.meta) + # print('zql handle node: ', node, node.op, node.meta) ir_nodes = FxModuleParser.parse_node(node, module, frame) all_ir_nodes += ir_nodes - + + print(global_not_supported) + return # handle outputs output_nodes = [node.all_input_nodes for node in module.graph.nodes if node.op == 'output'] print(f'outputs = {output_nodes}') @@ -331,29 +338,35 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, %12 : (Tensor, Tensor) = prim::CallFunction(%5, %x1.1, %x2.1) """ # get signature - fsig = FxModuleParser._get_qualified_name(node.target) - print(f'parse_prim_function_node: {fsig}') - - # get inputs - input_nodes = [input_node for input_node in node.args] - input_vals = list() - for index, input_node in enumerate(input_nodes): - if isinstance(input_node, torch.fx.Node): - var_name = input_node.name - val = frame.get_var(var_name) - input_vals.append(val) - elif isinstance(input_node, (int, float)): - input_vals.append(input_node) - else: - input_vals.append(None) + fsig = FxModuleParser._get_qualified_name_of_call_function(node.target) + # print(f'parse_prim_function_node: {fsig}') + + # # get inputs + # input_nodes = [input_node for input_node in node.args] + # input_vals = list() + # for index, input_node in enumerate(input_nodes): + # if isinstance(input_node, torch.fx.Node): + # var_name = input_node.name + # val = frame.get_var(var_name) + # input_vals.append(val) + # elif isinstance(input_node, (int, float)): + # input_vals.append(input_node) + # else: + # input_vals.append(None) # map to IR operator - ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + if fsig in SignFx2Op.kOpMap: + # ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + ir_node = None + else: + print(f'Function {fsig} has not been supported yet!') + global_not_supported.add(fsig) + ir_node = None # TODO gracefully set output output_name = node.name output_val = frame.get_var(output_name) - ir_node.set_output(0, output_val) + # ir_node.set_output(0, output_val) # # push output in the frame # # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) @@ -379,13 +392,13 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, @staticmethod def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: - print('zql: ', node, node.__dir__()) - print('zql: ', node.args, node.kwargs) - print('zql: ', type(node.args[0])) - print('zql: ', node.args[0].name, node.args[0].op, node.args[0].meta) + # print('zql: ', node, node.__dir__()) + # print('zql: ', node.args, node.kwargs) + # print('zql: ', type(node.args[0])) + # print('zql: ', node.args[0].name, node.args[0].op, node.args[0].meta) # get signature - fsig = FxModuleParser._get_qualified_name(node.target) - print(f'parse_prim_method_node: {fsig}') + fsig = FxModuleParser._get_qualified_name_of_call_method(node.target, node) + # print(f'parse_prim_method_node: {fsig}') # get inputs input_nodes = [input_node for input_node in node.args] @@ -401,11 +414,16 @@ def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr input_vals.append(None) # map to IR operator - ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + if fsig in SignFx2Op.kOpMap: + ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + else: + print(f'Method {fsig} has not been supported yet!') + global_not_supported.add(fsig) + ir_node = None output_name = node.name output_val = frame.get_var(output_name) - ir_node.set_output(0, output_val) + # ir_node.set_output(0, output_val) return [ir_node] @@ -413,38 +431,56 @@ def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: assert node is not None tensor_name = node.name - - tensor_shape = node.meta['tensor_meta'].shape - # tensor_dtype = node.meta['tensor_meta'].dtype - #TODO assume it is weight - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=kDefaultType) - ir_tensor.as_param() - frame.add_var(tensor_name, ir_tensor) + if 'tensor_meta' in node.meta: + tensor_shape = node.meta['tensor_meta'].shape + # tensor_dtype = node.meta['tensor_meta'].dtype + #TODO assume it is weight + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=kDefaultType) + ir_tensor.as_param() + frame.add_var(tensor_name, ir_tensor) + else: + print(f'attr {node.op} has not been supported yet!') + global_not_supported.add(node.op) return list() - - from typing import Callable # import python_stubs.buildins @staticmethod - def _get_qualified_name(func: Callable[..., Any]) -> str: + def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str: + """ + The target field of call_function node must be an callable object. + """ # # things like getattr just appear in builtins # if getattr(builtins, func.__name__, None) is func: # return func.__name__ # TODO(yizhu1): find a general solution - if isinstance(func, str): - for module, module_name in [(torch, 'torch'), (torch.Tensor, 'torch.Tensor')]: - lib_func = getattr(module, func, None) - if lib_func is not None and callable(lib_func): - return f'{module_name}.{func}' - raise RuntimeError(f'cannot find module for {func}') - name = func.__name__ - module = FxModuleParser._find_module_of_method(func) + assert callable(node_target) + name = node_target.__name__ + module = FxModuleParser._find_module_of_method(node_target) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module return f'{module}.{name}' + @staticmethod + def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> str: + """ + The target field of call_method node must be a string. + """ + assert isinstance(node_target, str) + # assert len(node.args) == 1 + # assert len(node.kwargs) == 0 + for module, module_name in [(torch, 'torch'), (torch.Tensor, 'torch.Tensor')]: + lib_func = getattr(module, node_target, None) + if lib_func is not None and callable(lib_func): + return f'{module_name}.{node_target}' + # example node.args[0].meta is {'type': } + in_type = node.args[0].meta['type'] + assert node_target in in_type().__dir__() + sig = f'{in_type.__name__}.{node_target}' + print(f'The method is not torch or Tensor, but {sig}') + return sig + # this is fixed on master, WAR for 1.5 @staticmethod def _find_module_of_method(orig_method: Callable[..., Any]) -> str: From a4d222fc21adb3fbdd4971011f181b08e6407a2a Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 2 Mar 2023 22:22:30 -0800 Subject: [PATCH 1285/1892] save work --- cube/graph/function/function.py | 1 + cube/graph/parser/mappingfx.py | 3 +-- cube/graph/parser/parserfx.py | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 08ac975a..ff6d3c81 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -551,6 +551,7 @@ def LayerNorm(signature, inputs): cube.runtime.function.layer_norm(input, weight, bias, normliazed_shape, eps) """ if 'torch.' in signature: + print(inputs) tensor, normalized_shape, weight, bias, eps = inputs assert isinstance(normalized_shape, list), f"normalized_shape for layer_norm can only be List[int]" else: diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 0708438d..cf8313ab 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -71,6 +71,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, + __ftemplate('layer_norm'): function.LayerNorm, # # torch nn functional # @@ -90,8 +91,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # # __ftemplate('_pad'): function.Pad, # - # __ftemplate('layer_norm'): function.LayerNorm, - # # __ftemplate('embedding'): function.Embedding, # # __ftemplate('cross_entropy'): function.CrossEntropy, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index be4b81f8..834cdb5c 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -320,10 +320,15 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, val = frame.get_var(var_name) input_vals.append(val) elif isinstance(input_node, (int, float)): + # kw scalar args input_vals.append(input_node) else: input_vals.append(None) + if 'layer_norm' in fsig: + print(input_nodes) + print(frame._vars) + # map to IR operator ir_node = SignFx2Op.map(fsig)(inputs=input_vals) From 577bcdd96e7d0f7ba0beea0706208095cfd253cf Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Fri, 3 Mar 2023 08:34:05 +0000 Subject: [PATCH 1286/1892] Merged PR 1465: fix torchscale dataloading to adapter new IRPython object fix torchscale dataloading to adapter new IRPython object --- cube/graph/parser/converter.py | 10 +++++--- cube/runtime/syndata.py | 28 +++++++++++--------- examples/nlp/torchscale/fx_test.py | 41 +++++++++++++++--------------- 3 files changed, 43 insertions(+), 36 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index a2dfb73e..57c73273 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -20,6 +20,7 @@ def convert_model(model: torch.nn.Module, try: if CompileFlag.use_torchfx: if not dummy_input: + print('using torch.fx tracer') from torch.fx import symbolic_trace # Symbolic tracing frontend - captures the semantics of the module tracer = FxFuncOpTracer() @@ -27,13 +28,14 @@ def convert_model(model: torch.nn.Module, smodule: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) smodule.graph.print_tabular() else: + print('using concrete tracer') with torch.no_grad(): if isinstance(dummy_input, tuple): output_origin = model(*dummy_input) - elif isinstance(dummy_input, dict): - output_origin = model(**dummy_input) + elif isinstance(dummy_input, torch.Tensor): + output_origin = model(dummy_input) else: - raise RuntimeError(f'Unknown dummy_input = {dummy_input}') + raise RuntimeError(f'dummy_input should be a tuple = {dummy_input}') traced_model, _ = concrete_trace( model, dummy_input, @@ -43,9 +45,9 @@ def convert_model(model: torch.nn.Module, type(output_origin): ((), False), }, ) - print(f'type(traced_model = {type(traced_model)}') traced_model.graph.print_tabular() else: + print('using torchscript tracer') smodule = torch.jit.script(model) except Exception as ex: diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index bd50cf0a..daea5538 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -138,6 +138,7 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, self.shapes = tuple([list(shape) for shape in shapes]) self.dtypes = dtypes batch_size = shapes[0][batch_dims[0]] + assert not names super().__init__(batch_size, batch_dims) self.names = names self.append_args=append_args @@ -160,18 +161,21 @@ def random_sample(self) -> Tuple[torch.Tensor]: torch.manual_seed(0) datas = [] for shape, dtype in zip(self.shapes, self.dtypes): - datas.append( - torch.rand( - shape, - device=self.device, - requires_grad=False).to(dtype) - if torch.is_floating_point(torch.zeros([1], dtype=dtype)) else - torch.ones( - shape, - device=self.device, - requires_grad=False - ).to(dtype) - ) + if shape and all(isinstance(dim, int) for dim in list(shape)): + datas.append( + torch.rand( + shape, + device=self.device, + requires_grad=False).to(dtype) + if torch.is_floating_point(torch.zeros([1], dtype=dtype)) else + torch.ones( + shape, + device=self.device, + requires_grad=False + ).to(dtype) + ) + else: + datas.append(dtype()) return tuple(datas) def set_output(self, datas: Union[torch.Tensor, Tuple[torch.Tensor]]): diff --git a/examples/nlp/torchscale/fx_test.py b/examples/nlp/torchscale/fx_test.py index 125aa260..20588758 100644 --- a/examples/nlp/torchscale/fx_test.py +++ b/examples/nlp/torchscale/fx_test.py @@ -74,49 +74,50 @@ for key in dummy_input.keys(): dummy_input[key] = dummy_input[key].to(device) print("creating dummy input succeed") -print(f'dummy_input = {dummy_input}, {type(dummy_input)}') dummy_input['features_only'] = False dummy_input['return_all_hiddens'] = False print(f'dummy_input = {dummy_input}, {type(dummy_input)}') +# create input as list of tensors/objects +dummy_input_list = [val for key, val in dict(dummy_input).items()] +print(f'dummy_input_list = {dummy_input_list}') + with torch.no_grad(): - output_origin = model(**dummy_input) + # output_origin = model(**dummy_input) + output_origin = model(*dummy_input_list) + # print(f'output_origin = {output_origin}') input_shapes = [list(dummy_input[input].size()) for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] input_dtypes = [dummy_input[input].dtype for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] input_names = tuple([input for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)]) -print(f'input_shapes(out) = {input_shapes}, {type(input_shapes)}, {type(input_shapes[0])}') + +input_shapes += [[None], [None]] +input_dtypes += [bool, bool] + +print(f'input_shapes = {input_shapes}') print(f'input_dtypes = {input_dtypes}') -kwargs_keys = [input for input in dummy_input if not isinstance(dummy_input[input], torch.Tensor)] -kwargs = dict() -for key in kwargs_keys: - kwargs[key] = dummy_input[key] -print(f'kwargs = {kwargs}') -# model = cube.SemanticModel( -# model, input_shapes=(input_shapes,), -# ) -model = cube.SemanticModel( - model, dummy_input=dummy_input, -) dataloader = cube.runtime.syndata.SynDataLoader( shapes=(input_shapes), dtypes=input_dtypes, - batch_dims=(0,0,0), - names=input_names, - append_args=kwargs + batch_dims=(0,0,0, None, None), +) +sample_input = next(dataloader) +print(f'next(dataloader) = {sample_input}') +sample_input_cpu = tuple([val.to(device) if isinstance(val, torch.Tensor) else val for val in sample_input]) + +model = cube.SemanticModel( + model, dummy_input=sample_input_cpu, ) -print(f'next(dataloader) = {next(dataloader)}') @cube.compile(model, dataloader, PAS=PAS, load_content=False) def train_iter(model, dataloader): data = next(dataloader) - # loss = model({'src_tokens':data[0],'src_lengths':data[1],'prev_output_tokens':data[2], }) loss = model(*data) loss.backward() - +train_iter(model, dataloader) # Conduct concrete trace below # sys.path.append('/home/v-junliang/torchscaletest/nni') From 47ec9c50a900a235e43a014ab793cb61ffe67a34 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 3 Mar 2023 10:04:29 +0000 Subject: [PATCH 1287/1892] Merged PR 1467: hotfix: view partition rules hotfix: view partition rules --- cube/graph/function/function.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 944e93bc..454bb6aa 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -756,8 +756,6 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s for hdim in range(len(bracket)): if bracket[hdim] == '1': continue sdim = bracket[hdim] - ospatial.add(bracket[hdim]) - ofirst.append(bracket[hdim]) break if sdim is not None: ospatial.add(sdim) From 9af03db7da2ad0f6b97e245f7c3448b75b43e442 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Fri, 3 Mar 2023 05:27:49 -0800 Subject: [PATCH 1288/1892] update --- cube/graph/parser/mappingfx.py | 10 ++++- cube/graph/parser/parserfx.py | 73 ++++++++++++++++------------------ 2 files changed, 43 insertions(+), 40 deletions(-) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index a91be34f..92bfb613 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -24,6 +24,12 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: raise KeyError(f"{signature} is not supported yet") # return partial(function.UnkownOperator, signature=signature) + @staticmethod + def exist(signature: str) -> bool: + if 'torch.' not in signature and 'cube.runtime.' not in signature: + signature = signature.split('.')[-1] + return signature in SignFx2Op.kOpMap + @staticmethod def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]], code): """ @@ -68,6 +74,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, __ttemplate('masked_fill'): function.MaskedFill, + __ftemplate('embedding'): function.Embedding, # # torch nn functional # @@ -103,7 +110,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('rand'): function.Rand, # __ttemplate('clone'): function.Clone, # - # __ttemplate('add') : function.Add, + __ttemplate('add') : function.Add, + 'add': function.Add, # # __ttemplate('sub') : function.Sub, # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 13332322..00208747 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -5,15 +5,15 @@ from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor +from cube.ir.cten import IRObject import cube.ir as ir from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import DType2IRDType from cube.graph.parser.mappingfx import SignFx2Op +from cube.graph.function.pyfunc import IRPyFunc import torch.fx -global_not_supported = set() - class ErasedDevice: pass @@ -132,19 +132,15 @@ def parse(module: torch.fx.GraphModule, val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) frame.add_var(node.name, val) else: - print(f'WARNING: creation of no-shaped activation for {node.name}') - val = IRFullTensor(shape=[1], requires_grad=True, dtype=ir.int32, name=node.name) # TODO fixme - frame.add_var(node.name, val) + frame.add_var(node.name, IRObject()) # handle nodes all_ir_nodes: List[IRFwOperation] = list() for node in module.graph.nodes: - # print('zql handle node: ', node, node.op, node.meta) + print('zql handle node: ', node, node.op, node.meta) ir_nodes = FxModuleParser.parse_node(node, module, frame) all_ir_nodes += ir_nodes - - print(global_not_supported) - return + # handle outputs output_nodes = [node.all_input_nodes for node in module.graph.nodes if node.op == 'output'] print(f'outputs = {output_nodes}') @@ -210,7 +206,6 @@ def ntype(node: torch.fx.Node): @staticmethod def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation]: - # print("### parse_node {}".format(node)) """ Parse the node and return the IRFwOperation nodes """ @@ -342,34 +337,33 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, """ # get signature fsig = FxModuleParser._get_qualified_name_of_call_function(node.target) - # print(f'parse_prim_function_node: {fsig}') - - # # get inputs - # input_nodes = [input_node for input_node in node.args] - # input_vals = list() - # for index, input_node in enumerate(input_nodes): - # if isinstance(input_node, torch.fx.Node): - # var_name = input_node.name - # val = frame.get_var(var_name) - # input_vals.append(val) - # elif isinstance(input_node, (int, float)): - # input_vals.append(input_node) - # else: - # input_vals.append(None) + print(f'parse_prim_function_node: {fsig}') + + # get inputs + assert len(node.kwargs) == 0 + input_nodes = [input_node for input_node in node.args] + input_vals = list() + for _, input_node in enumerate(input_nodes): + if isinstance(input_node, torch.fx.Node): + var_name = input_node.name + val = frame.get_var(var_name) + input_vals.append(val) + elif isinstance(input_node, (int, float, str)) or input_node is None: + input_vals.append(input_node) + else: + raise RuntimeError(f'Unsupported input node {input_node}, {type(input_node)} in parse function!') # map to IR operator - if fsig in SignFx2Op.kOpMap: - # ir_node = SignFx2Op.map(fsig)(inputs=input_vals) - ir_node = None + if SignFx2Op.exist(fsig): + ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: - print(f'Function {fsig} has not been supported yet!') - global_not_supported.add(fsig) - ir_node = None + assert 'torch.' not in fsig, f'{fsig} is not supported' + ir_node = IRPyFunc(fsig, input_vals, [None]) # TODO gracefully set output output_name = node.name output_val = frame.get_var(output_name) - # ir_node.set_output(0, output_val) + ir_node.set_output(0, output_val) # # push output in the frame # # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) @@ -401,7 +395,7 @@ def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr # print('zql: ', node.args[0].name, node.args[0].op, node.args[0].meta) # get signature fsig = FxModuleParser._get_qualified_name_of_call_method(node.target, node) - # print(f'parse_prim_method_node: {fsig}') + print(f'parse_prim_method_node: {fsig}') # get inputs input_nodes = [input_node for input_node in node.args] @@ -417,16 +411,15 @@ def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr input_vals.append(None) # map to IR operator - if fsig in SignFx2Op.kOpMap: + if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: - print(f'Method {fsig} has not been supported yet!') - global_not_supported.add(fsig) - ir_node = None + assert 'torch.' not in fsig, f'{fsig} is not supported' + ir_node = IRPyFunc(fsig, input_vals, [None]) output_name = node.name output_val = frame.get_var(output_name) - # ir_node.set_output(0, output_val) + ir_node.set_output(0, output_val) return [ir_node] @@ -444,8 +437,10 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram ir_tensor.as_param() frame.add_var(tensor_name, ir_tensor) else: - print(f'attr {node.op} has not been supported yet!') - global_not_supported.add(node.op) + # FIXME: why no need to record the constant value of this var? + # the value can be obtained below: + # var = FxModuleParser.fetch_attr(module, node.target) + frame.add_var(tensor_name, IRObject()) return list() From d604ab73de3249db9a5315e2f31cda67faeeeb23 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 3 Mar 2023 13:28:31 +0000 Subject: [PATCH 1289/1892] Merged PR 1466: Python Runtime Function Support - support torch.Tensor.size --- cube/codegen/module/module.py | 2 +- cube/graph/parser/mappingfx.py | 9 +++++- cube/graph/parser/parserfx.py | 51 ++++++++++++++++++++++++++++------ 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 54ca1ffd..e084873d 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -82,7 +82,7 @@ def __init__(self, execplan: ExecutionPlan) -> None: '\n\n########## Generated Model Code ###########', 'from typing import *', 'import torch', 'import torch.utils.checkpoint as ckpt', - 'import cube', '', ''] + 'import cube', 'import _operator', '', ''] if CompileFlag.use_nnfusion: self.init_code.extend(['import nnfusion', '']) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index a91be34f..eec566a1 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -24,6 +24,12 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: raise KeyError(f"{signature} is not supported yet") # return partial(function.UnkownOperator, signature=signature) + @staticmethod + def exist(signature: str) -> bool: + if 'torch.' not in signature and 'cube.runtime.' not in signature: + signature = signature.split('.')[-1] + return signature in SignFx2Op.kOpMap + @staticmethod def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]], code): """ @@ -103,7 +109,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('rand'): function.Rand, # __ttemplate('clone'): function.Clone, # - # __ttemplate('add') : function.Add, + __ttemplate('add') : function.Add, + 'add': function.Add, # # __ttemplate('sub') : function.Sub, # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 4297efa7..2ea84514 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -5,10 +5,12 @@ from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor +from cube.ir.cten import IRObject import cube.ir as ir from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import DType2IRDType from cube.graph.parser.mappingfx import SignFx2Op +from cube.graph.function.pyfunc import IRPyFunc import torch.fx @@ -127,9 +129,7 @@ def parse(module: torch.fx.GraphModule, val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) frame.add_var(node.name, val) else: - print(f'WARNING: creation of no-shaped activation for {node.name}') - val = IRFullTensor(shape=[1], requires_grad=True, dtype=ir.int32, name=node.name) # TODO fixme - frame.add_var(node.name, val) + frame.add_var(node.name, IRObject()) # handle nodes all_ir_nodes: List[IRFwOperation] = list() @@ -328,9 +328,7 @@ def fetch_attr(mod: torch.fx.GraphModule, target: str): def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: """ parse node like: - Tensor = prim::CallFunction(%5, %input.1, %3, %4) - %5 : Function = prim::Constant[name="linear"]() - %12 : (Tensor, Tensor) = prim::CallFunction(%5, %x1.1, %x2.1) + %add : [#users=2] = call_function[target=operator.add](args = (%x, %x), kwargs = {}) """ # get signature fsig = FxModuleParser._get_qualified_name(node.target) @@ -350,7 +348,20 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, input_vals.append(None) # map to IR operator - ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + if SignFx2Op.exist(fsig): + ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + else: + # case1: unknown torch operator + if FxModuleParser._is_torch_autograd_op(node, frame): + print(f'>>> Find unkown pytorch operation: {fsig}') + fname = fsig.split('.')[-1] if '.' in fsig else fname + ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) + for idx, t in enumerate(input_vals): + ir_node.set_input(idx, t) + # case2: python runtime function + else: + print(f'>>> Set python runtime function: {fsig}') + ir_node = IRPyFunc(fsig, input_vals, [None]) # TODO gracefully set output output_name = node.name @@ -399,7 +410,20 @@ def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr input_vals.append(None) # map to IR operator - ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + if SignFx2Op.exist(fsig): + ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + else: + # case1: unknown torch operator + if FxModuleParser._is_torch_autograd_op(node, frame): + print(f'>>> Find unkown pytorch operation: {fsig}') + fname = fsig.split('.')[-1] if '.' in fsig else fname + ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) + for idx, t in enumerate(input_vals): + ir_node.set_input(idx, t) + # case2: python runtime function + else: + print(f'>>> Set python runtime function: {fsig}') + ir_node = IRPyFunc(fsig, input_vals, [None]) output_name = node.name output_val = frame.get_var(output_name) @@ -453,4 +477,13 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str: for guess in [torch, torch.nn.functional]: if getattr(guess, name, None) is orig_method: return guess.__name__ - raise RuntimeError(f'cannot find module for {orig_method}') \ No newline at end of file + raise RuntimeError(f'cannot find module for {orig_method}') + + @staticmethod + def _is_torch_autograd_op(node: torch.fx.Node, frame: Frame) -> bool: + """Check whether the node is of a pytorch autograd operation.""" + signature: str = FxModuleParser._get_qualified_name(node.target) + # note: some python operations like torch.Tensor.size() doesn't return + # an IRTensor, thus cannot be considered as a pytorch autograd operator. + return signature.startswith('torch.') and \ + isinstance(frame.get_var(node.name), IRFullTensor) From a8f306e97bd49ab09a4bc9d662cfcd31fc1287cc Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 3 Mar 2023 08:25:46 -0800 Subject: [PATCH 1290/1892] support several torch ops --- cube/graph/function/function.py | 45 ++++++++++++++++++++++++++++++--- cube/graph/parser/mappingfx.py | 21 +++++++-------- cube/graph/parser/parserfx.py | 32 +++++++++++++---------- examples/mlp/linearsfx.py | 13 +++++++--- 4 files changed, 77 insertions(+), 34 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 6faa88b2..f5327239 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -45,6 +45,14 @@ def BatchLinear(signature, inputs): return IRDimops(BatchLinear, 'bmm', signature, annos, inputs) +def BMMAdd(signature, inputs): + assert len(inputs) == 3 + annos = [ + 'b m n, b m k^, b k^ n -> b m n' + ] + return IRDimops(BMMAdd, 'baddbmm', signature, annos, inputs) + + def Matmul(signature, inputs: Tuple[IRTensor, IRTensor]): assert len(inputs) == 2 annos = [ @@ -461,12 +469,22 @@ def ReLU(signature, inputs): def Softmax(signature, inputs): - assert len(inputs) == 4 + assert len(inputs) >= 1 annos = ['* -> *'] tensor = inputs[0:1] - dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] - return IRDimops(Softmax, 'softmax', signature, annos, tensor, - dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if len(inputs) == 2: + if isinstance(inputs[1], dict): + return IRDimops(Softmax, 'softmax', signature, annos, tensor, **inputs[1]) + elif isinstance(inputs[1], int): + return IRDimops(Softmax, 'softmax', signature, annos, tensor, dim=inputs[1]) + else: + raise RuntimeError(f'Unexpect intput type {inputs[1]}, {type(inputs[1])}') + elif len(inputs) == 4: + dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] + return IRDimops(Softmax, 'softmax', signature, annos, tensor, + dim=dim, _stacklevel=_stacklevel, dtype=dtype) + else: + raise RuntimeError('Unexpected input num {inputs}') def Dropout(signature, inputs): @@ -837,6 +855,25 @@ def Reshape(signature, inputs): return View(signature, inputs) + +def Permute(signature, inputs): + if isinstance(inputs[1], list): + in_tensor, dims = inputs[0], inputs[1] + else: + in_tensor, dims = inputs[0], inputs[1:] + edim_in = ShapeAnno.create_shape_str(in_tensor.shape) + for idx, dim in enumerate(dims): + if idx != dim: + edim_in[idx] += '^' + assert len(edim_in) == len(dims), f'{len(edim_in)} vs {len(dims)}' + edim_ou = [] + for dim in dims: + assert isinstance(dim, int) + edim_ou.append(copy.copy(edim_in[dim])) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Permute, 'permute', signature, [anno], [in_tensor], dims=dims) + + def Squeeze(signature, inputs): """ out = torch.squeeze(tensor) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index cf8313ab..146c9285 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -71,6 +71,15 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, + __ttemplate('tanh'): function.Tanh, + __ftemplate('softmax') : function.Softmax, + __ttemplate('bmm') : function.BatchLinear, + __ttemplate('pow'): function.Pow, + __ttemplate('baddbmm'): function.BMMAdd, + __ttemplate('permute'): function.Permute, + __ttemplate('transpose'): function.Transpose, + + # TODO __ftemplate('layer_norm'): function.LayerNorm, # # torch nn functional @@ -79,10 +88,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # # __ttemplate('matmul'): function.Matmul, # - # __ftemplate('softmax') : function.Softmax, - # - # __ftemplate('dropout') : function.Dropout, - # # __ftemplate('gelu') : function.GeLU, # __ttemplate('gelu') : function.GeLU, # @@ -122,21 +127,13 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('ge'): function.CompareGE, # __ttemplate('le'): function.CompareLE, # - # __ttemplate('pow'): function.Pow, - # # __ttemplate('sin'): function.Sin, # # __ttemplate('cos'): function.Cos, # - # __ttemplate('tanh'): function.Tanh, - # - # __ttemplate('bmm') : function.BatchLinear, - # # __ttemplate('sum') : function.Sum, # __ttemplate('mean') : function.Mean, # - # __ttemplate('transpose') : function.Transpose, - # # __ttemplate('view'): function.View, # # __ttemplate('reshape'): function.Reshape, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index b61b3d89..db77bccd 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -337,22 +337,26 @@ def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, print(f'parse_prim_function_node: {fsig}') # get inputs - input_nodes = [input_node for input_node in node.args] - input_vals = list() - for index, input_node in enumerate(input_nodes): - if isinstance(input_node, torch.fx.Node): - var_name = input_node.name - val = frame.get_var(var_name) - input_vals.append(val) - elif isinstance(input_node, (int, float)): - # kw scalar args - input_vals.append(input_node) + def extract_val(fx_node): + if isinstance(fx_node, torch.fx.Node): + var_name = fx_node.name + return frame.get_var(var_name) + elif isinstance(fx_node, (int, float)): + # scalar args + return fx_node + elif fx_node is None: + return None else: - input_vals.append(None) + raise RuntimeError(f'Unsupported input node {fx_node}, {type(fx_node)}') - if 'layer_norm' in fsig: - print(input_nodes) - print(frame._vars) + input_vals = list() + for item in node.args: + input_vals.append(extract_val(item)) + if node.kwargs: + input_kwvals = {} + for k, v in node.kwargs.items(): + input_kwvals[k] = extract_val(v) + input_vals.append(input_kwvals) # map to IR operator ir_node = SignFx2Op.map(fsig)(inputs=input_vals) diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 1b935b20..3201ce4f 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -61,6 +61,11 @@ def __init__(self, dim, mult=1, nlayers=4): def forward(self, data, mask): x = data.masked_fill(mask, 0.0) x = x.fill_(0.0) + x = torch.nn.functional.softmax(x, dim=-1) + x = torch.bmm(x, x) + x = torch.baddbmm(x, x, x) + x = torch.tanh(x) + x = torch.pow(x, x) for layer in self.layers: x = layer(x) x = torch.nn.functional.relu(x) @@ -71,12 +76,12 @@ def forward(self, data, mask): x = x.squeeze() x = torch.triu(x, 1) x = torch.nan_to_num(x) - # ne cannot backward + # ne and eq cannot backward # x = torch.ne(x, 1.0) - # x = torch.nn.functional.dropout(x, self.p) - # x = x * self.y - x = torch.cumsum(x, -1) # y = torch.eq(x, 1.0) + x = torch.cumsum(x, -1) + x = x.permute(0, 2, 1) + x = x.transpose(1, 2) loss = torch.sum(x) # long cannot backward # loss = loss.long() From 973c8ef832dd50dbe9bc2e605b5c0c6301afb917 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Fri, 3 Mar 2023 23:58:48 -0800 Subject: [PATCH 1291/1892] update --- cube/graph/parser/parserfx.py | 29 +++++++++---- tests/parser/test_bloom.py | 82 ++--------------------------------- 2 files changed, 24 insertions(+), 87 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index ee79b1b8..24f37fdf 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -78,13 +78,28 @@ def parse(module: torch.fx.GraphModule, inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] print(f'inputs = {inputs}') - # remove dead nodes - from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler - DCEHandler(module).eliminate_dead_code() + if input_shapes is not None and len(input_shapes) != len(inputs): + print(f'module(type = {type(module)}.__dict__.keys() = {module.__dict__.keys()}') + print(f'input shape mismatch (got {len(input_shapes)} != {len(inputs)})') + # TODO fixme raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + + if input_shapes is not None: + ## shape propagation + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + sample_inputs = dummy_inputs if dummy_inputs else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] + sample_input_tensors = [sample_inputs[input] for input in sample_inputs] if type(sample_inputs) is dict else sample_inputs + from torch.fx.passes.shape_prop import ShapeProp + ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) + else: + assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' + # remove dead nodes + from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler + DCEHandler(module).eliminate_dead_code() - # shape propagation - from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp - KwargsShapeProp(module).propagate(dummy_inputs) + # shape propagation + from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp + KwargsShapeProp(module).propagate(dummy_inputs) for node in module.graph.nodes: if 'tensor_meta' in node.meta: @@ -99,8 +114,6 @@ def parse(module: torch.fx.GraphModule, else: print(f'{node.name} does not has tensor_meta') - # return - # handle graph input -- some inputs could be None or not tensor default_dtype = torch.get_default_dtype() kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype diff --git a/tests/parser/test_bloom.py b/tests/parser/test_bloom.py index 232cc09a..41aff42b 100644 --- a/tests/parser/test_bloom.py +++ b/tests/parser/test_bloom.py @@ -2,19 +2,11 @@ import torch.nn as nn from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig -# model_name = "bigscience/bloom-7b1" -# model_path = "/home/quzha/bloom7b1" model_name = "bigscience/bloom-560m" model_path = "/home/quzha/bloom560m" -# model_name = "facebook/opt-66b" -# model_name = "facebook/opt-iml-30b" -# model_name = "facebook/optiml30b" -# model_name = "facebook/opt-iml-1.3b" -# model_name = "facebook/opt-13b" -# model_path = "/home/quzha/opt13b" -print("Loading model...") #device_map="balanced", -model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, cache_dir=model_path)#.cuda() +print("Loading model...") +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, cache_dir=model_path) print(type(model), '; is nn.Module? ', isinstance(model, nn.Module)) print("Model's generation config which does not list default values: ", model.generation_config) print("Loading tokenizer...") @@ -22,87 +14,19 @@ print("Loading Done!") prompt = "If I want to travel to a new city, I should plan my trip as follows:" #input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() -inputs = tokenizer(prompt, return_tensors="pt")#.to('cuda:0') +inputs = tokenizer(prompt, return_tensors="pt") # Cube # from cube.graph import parser # ir_graph = parser.convert_model(model, input_shapes=[1, 17], save_content=False) -# model(input_ids, None, None, None, None, None, None, None, None, None) - print("concrete tracing model...") from nni.common.concrete_trace_utils import concrete_trace -#traced_graph = concrete_trace(model, (input_ids, None, None, None, None, None, None, None, None, None), use_function_patch=True, -# autowrap_leaf_class={torch.finfo: ((), False)}) -#traced_graph = concrete_trace(model, inputs, use_function_patch=True, traced_graph = concrete_trace(model, inputs, use_operator_patch=True, autowrap_leaf_class={torch.finfo: ((), False)}) -# traced_graph.graph.print_tabular() print("tracing model done.") print("parsing fx graph to cube graph...") from cube.graph.parser import FxModuleParser -# dummy_inputs = [inputs.input_ids, None, inputs.attention_mask, None, None, None, None, None, None, None, {}] -# FxModuleParser.parse(traced_graph, dummy_inputs) FxModuleParser.parse(traced_graph, dummy_inputs=inputs) print("parsing done.") - -# AutoDist -# from autodist.apis import compile -# from cube.runtime.resource import EnvResource -# resource = EnvResource() -# graph = compile(ir_graph, resource) - - -# print(type(model), '; is nn.Module? ', isinstance(model, nn.Module)) -# print("Model's generation config which does not list default values: ", model.generation_config) -# print("Loading tokenizer...") -# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) -# print("Loading Done!") - -# #prompt = "If you are a calculator, please tell me the results of 32 x 23 =" -# # prompt = "what is the english word that means little modification and starts with character "t"? the english word is" -# # prompt = "If I want to travel to USA, I need to apply for a" -# prompt = "If I want to travel to a new city, I should plan my trip as follows:" -# # prompt = "I look forward to" -# # prompt = "Today was an amazing day because" -# # prompt = "What is the color of a carrot?\nA:" - - -# # Some of the commonly adjusted parameters: max_new_tokens, num_beams, do_sample, num_return_sequences -# # https://huggingface.co/blog/how-to-generate -# # https://huggingface.co/docs/transformers/v4.26.1/en/generation_strategies#text-generation-strategies -# # Beam-search decoding -# generation_config_beam = GenerationConfig( -# num_beams=4, -# do_sample=False, -# early_stopping=True, -# decoder_start_token_id=0, -# eos_token_id=model.config.eos_token_id, -# pad_token=model.config.pad_token_id, -# ) -# # Beam-search decoding without early stopping -# generation_config_beam_fixed_len = GenerationConfig( -# num_beams=4, -# do_sample=False, -# early_stopping=False, -# max_new_tokens=20, -# decoder_start_token_id=0, -# eos_token_id=model.config.eos_token_id, -# pad_token=model.config.pad_token_id, -# ) -# # Contrastive search -# generation_config_contrastive = GenerationConfig( -# penalty_alpha=0.6, -# top_k=4, -# max_new_tokens=100, -# ) - -# print("Tokenizing prompt...") -# input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() -# print("input_ids shape: ", input_ids.size()) -# print("Generating sequence ids...") -# generated_ids = model.generate(input_ids, generation_config=generation_config_beam_fixed_len) -# print("Decoding sequence ids...") -# output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) -# print(output) From 483eca4541f816ecc0cbfc483f522a51ea5dea10 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sat, 4 Mar 2023 02:01:45 -0800 Subject: [PATCH 1292/1892] refine code --- cube/graph/function/dimops.py | 2 +- cube/graph/function/function.py | 25 +++++++++++++++++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 7204c444..c8811348 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -409,7 +409,7 @@ def parse(anno: str) -> Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]]: """ # to inputs and outputs if '->' not in anno: - raise ValueError("Syntax Error: Expected -> in operator anno") + raise ValueError(f"Syntax Error: Expected -> in operator anno: {anno}") inputs, outputs = anno.split('->') inputs = inputs.split(',') outputs = outputs.split(',') diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index f5327239..9f2e711a 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -470,18 +470,28 @@ def ReLU(signature, inputs): def Softmax(signature, inputs): assert len(inputs) >= 1 - annos = ['* -> *'] - tensor = inputs[0:1] + tensor = inputs[0] + edim_in = ShapeAnno.create_shape_str(tensor.shape) if len(inputs) == 2: if isinstance(inputs[1], dict): - return IRDimops(Softmax, 'softmax', signature, annos, tensor, **inputs[1]) + edim_in[inputs[1]['dim']] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], **inputs[1]) elif isinstance(inputs[1], int): - return IRDimops(Softmax, 'softmax', signature, annos, tensor, dim=inputs[1]) + dim = inputs[1] + edim_in[dim] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], dim=inputs[1]) else: raise RuntimeError(f'Unexpect intput type {inputs[1]}, {type(inputs[1])}') elif len(inputs) == 4: dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] - return IRDimops(Softmax, 'softmax', signature, annos, tensor, + dim = inputs[1] + edim_in[dim] += '^' + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], dim=dim, _stacklevel=_stacklevel, dtype=dtype) else: raise RuntimeError('Unexpected input num {inputs}') @@ -504,7 +514,7 @@ def EQ(signature, inputs): edim_in0 = ShapeAnno.create_shape_str(input0.shape) edim_ou = copy.copy(edim_in0) - if isinstance(input1, float): + if isinstance(input1, (int, float)): anno = OpAnno.create_op_str([edim_in0], [edim_ou]) return IRDimops(EQ, 'eq', signature, [anno], [input0], other=input1) else: @@ -519,7 +529,7 @@ def NE(signature, inputs): edim_in0 = ShapeAnno.create_shape_str(input0.shape) edim_ou = copy.copy(edim_in0) - if isinstance(input1, float): + if isinstance(input1, (int, float)): anno = OpAnno.create_op_str([edim_in0], [edim_ou]) return IRDimops(NE, 'ne', signature, [anno], [input0], other=input1) else: @@ -569,7 +579,6 @@ def LayerNorm(signature, inputs): cube.runtime.function.layer_norm(input, weight, bias, normliazed_shape, eps) """ if 'torch.' in signature: - print(inputs) tensor, normalized_shape, weight, bias, eps = inputs assert isinstance(normalized_shape, list), f"normalized_shape for layer_norm can only be List[int]" else: From a0df8d4e279b711c4c9ab794bfd78b7d56a69cf6 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Sat, 4 Mar 2023 12:13:21 +0000 Subject: [PATCH 1293/1892] Merged PR 1458: [do not merge] improve fx parser during supporting bloom improve fx parser during supporting bloom --- cube/graph/parser/mappingfx.py | 4 +- cube/graph/parser/parserfx.py | 388 +++++++++++++-------------------- tests/parser/test_bloom.py | 32 +++ 3 files changed, 181 insertions(+), 243 deletions(-) create mode 100644 tests/parser/test_bloom.py diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index eec566a1..790742a3 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -73,7 +73,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ne') : function.NE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, - __ttemplate('masked_fill'): function.MaskedFill, + # __ttemplate('masked_fill'): function.MaskedFill, + __ftemplate('embedding'): function.Embedding, # # torch nn functional # @@ -110,7 +111,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('clone'): function.Clone, # __ttemplate('add') : function.Add, - 'add': function.Add, # # __ttemplate('sub') : function.Sub, # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 2ea84514..24f37fdf 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -1,7 +1,7 @@ import torch import enum import re -from typing import Any, List, Tuple, Optional +from typing import Any, List, Tuple, Optional, Callable from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor @@ -65,7 +65,7 @@ def shape_refine(shape: torch.Size) -> torch.Size: @staticmethod def parse(module: torch.fx.GraphModule, input_shapes: Optional[Tuple[List[int],]] = None, - dummy_input = None, + dummy_inputs = None, frame: Frame = None) \ -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: """ @@ -77,52 +77,68 @@ def parse(module: torch.fx.GraphModule, inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] print(f'inputs = {inputs}') + if input_shapes is not None and len(input_shapes) != len(inputs): print(f'module(type = {type(module)}.__dict__.keys() = {module.__dict__.keys()}') print(f'input shape mismatch (got {len(input_shapes)} != {len(inputs)})') # TODO fixme raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + + if input_shapes is not None: + ## shape propagation + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + sample_inputs = dummy_inputs if dummy_inputs else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] + sample_input_tensors = [sample_inputs[input] for input in sample_inputs] if type(sample_inputs) is dict else sample_inputs + from torch.fx.passes.shape_prop import ShapeProp + ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) + else: + assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' + # remove dead nodes + from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler + DCEHandler(module).eliminate_dead_code() - ## shape propagation - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - sample_inputs = dummy_input if dummy_input else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] - sample_input_tensors = [sample_inputs[input] for input in sample_inputs] if type(sample_inputs) is dict else sample_inputs - - from torch.fx.passes.shape_prop import ShapeProp - ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) + # shape propagation + from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp + KwargsShapeProp(module).propagate(dummy_inputs) - # for node in module.graph.nodes: - # print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) for node in module.graph.nodes: - print(f'node.name = {node.name}') - if hasattr(node, 'meta') and node.meta.get('tensor_meta') is not None: - if node.name == 'output': - print('pause here') - if not hasattr(node.meta['tensor_meta'], 'dtype'): - for per_output_meta in node.meta['tensor_meta']: - if isinstance(per_output_meta, torch.fx.passes.shape_prop.TensorMetadata): - print(node.name, '-sub-output', per_output_meta.dtype, per_output_meta.shape) - else: - print(f'ERROR: skip {node.name}\'s non TensorMetadata sub-output') + if 'tensor_meta' in node.meta: + if node.meta['type'] is type(tuple()): + print(f'{node.name} is tuple type') + elif node.meta['type'] is type(torch.fx.immutable_collections.immutable_dict()): + print(f'{node.name} is immutable_dict type') + assert isinstance(node.meta['tensor_meta'], dict) else: + assert node.meta['type'] is type(torch.Tensor()) or node.meta['type'] is type(torch.nn.parameter.Parameter()) print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) else: - print(f'ERROR: none tensor_meta of Node {node.name}') + print(f'{node.name} does not has tensor_meta') - # handle graph input -- Assuming all the inputs are tensors + # handle graph input -- some inputs could be None or not tensor + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, extend to other input types for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) - shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] + if hasattr(dummy_inputs, input.name): + print(f'dummy_inputs has {input.name}') + shape = getattr(dummy_inputs, input.name).size()# None if dummy_inputs is None else dummy_inputs[idx].size() + else: + # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name + print(f'dummy_inputs does not have {input.name}') + shape = None + # FIXME: use the input's real dtype dtype = kDefaultType val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) input_val = [frame.get_var(input.name) for input in inputs] - # add activations to frame, including call_func output and final output + # add activations to frame, including call_func/call_method output and final output activation_op_strs = {'call_function', 'output', 'call_method'} activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] for node in activation_nodes: if hasattr(node, 'meta') and node.meta.get('tensor_meta') and hasattr(node.meta['tensor_meta'], 'dtype'): + assert isinstance(node, torch.fx.Node) shape = node.meta['tensor_meta'].shape shape = FxModuleParser.shape_refine(shape) dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) @@ -174,35 +190,10 @@ def ntype(node: torch.fx.Node): return FxNodeKind.Output if node.op == 'call_method': return FxNodeKind.PrimCallMethod - # if node.kind() == 'prim::CallMethod': - # return FxNodeKind.PrimCallMethod - # if node.kind() == 'prim::CallFunction': # the op call - # return FxNodeKind.PrimCallFunction - # if node.kind() == 'prim::Constant': - # return FxNodeKind.PrimConstant - # if node.kind().startswith('aten::'): - # return FxNodeKind.AtenOp - # if node.kind() == 'prim::If': - # return FxNodeKind.PrimIf - # if node.kind() == 'prim::Loop': - # return FxNodeKind.PrimLoop - # if node.kind() == 'prim::ListConstruct': - # return FxNodeKind.PrimListConstruct - # if node.kind() == 'prim::TupleConstruct': - # return FxNodeKind.PrimListConstruct - # if node.kind() == 'prim::ListUnpack': - # return FxNodeKind.PrimListUnpack - # if node.kind() == 'prim::TupleUnpack': - # return FxNodeKind.PrimListUnpack - # if node.kind() == 'prim::PythonOp': - # return FxNodeKind.PrimPythonOp - # if node.kind() == 'prim::device': - # return FxNodeKind.PrimDevice raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") @staticmethod def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation]: - # print("### parse_node {}".format(node)) """ Parse the node and return the IRFwOperation nodes """ @@ -213,14 +204,12 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == FxNodeKind.Output: return [] - if node_type == FxNodeKind.PrimCallFunction: - return FxModuleParser.parse_prim_function_node(node, module, frame) - if node_type == FxNodeKind.PrimCallMethod: - return FxModuleParser.parse_prim_method_node(node, module, frame) + if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): + return FxModuleParser.parse_prim_function_method(node, module, frame) if node_type == FxNodeKind.PrimGetAttr: return FxModuleParser.parse_prim_attr_node(node, module, frame) if node_type == FxNodeKind.PrimCallModule: - return FxModuleParser.parse_prim_module(node, module, frame) + raise RuntimeError(f"parse_prim_module is not supported.") # TODO bother assigning all ignored prim functions new NodeKinds? if node_type == FxNodeKind.PrimDevice: @@ -229,91 +218,6 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] except Exception: raise RuntimeError(f"\n\nParsing error at node:\n\t{node}\n") - - @staticmethod - def parse_prim_module(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation]: - """ - :param node: - :param module: - :param frame: - :return: - """ - raise RuntimeError(f"parse_prim_module needs update") - - input_nodes = node.all_input_nodes - for input_node in input_nodes: - var_name = input_node.name - val = frame.get_var(var_name) - frame.push_param(var_name) - - # TODO skip self_module in torchscript - call_module = module - - label = node.name - node_target_stack = node.target.split('.') - - #TODO check leaf module, iterate if not - - leaf_module = module - for node_target_stack_iter in node_target_stack: - leaf_module = getattr(leaf_module, node_target_stack_iter) - - _, ir_nodes, outputs_val = FxModuleParser.parse_nn_module(node, leaf_module, frame=frame) - - # pop out the frame - frame.pop_param(times=len(input_nodes) - 1) - - # # handle outputs - # # TODO outputs vs output - # outputs = [node] - # # outputs = [output for output in node.outputs()] - # for output, val in zip(outputs, outputs_val): - # frame.add_var(output.name, val) - - return ir_nodes - - @staticmethod - def parse_nn_module(node: torch.fx.Node, method: torch.nn.Module, frame: Frame): - """ - Parse module method - """ - - input_var_name = [input_node.name for input_node in node.all_input_nodes] - input_val = [frame.get_var(var_name) for var_name in input_var_name] - - all_ir_nodes: List[IRFwOperation] = list() - - # handle graph output - - fsig = type(method).__name__ - # ir_node = SignFx2Op.map(fsig)(input=input_val) - func = SignFx2Op.map(fsig) - weights_names = None #TODO obtain parameter name list - if weights_names is not None: - for idx, weight_name in enumerate(weights_names): - #create FullTensor - weight = getattr(method, weight_name) - if weight is not None: - weight_fulltensor_name = node.name + "-" + weight_name - weight_fulltensor = IRFullTensor(weight.shape, weight_fulltensor_name, requires_grad=False) - # frame.add_attr(weight_fulltensor_name, weight_fulltensor) - # tmp_ones_tensor = torch.ones(weight.shape, dtype=torch.get_default_dtype()) - # frame.add_attr_content(weight_fulltensor.tid, tmp_ones_tensor) - frame.add_var(weight_fulltensor_name, weight_fulltensor) - input_val.append(weight_fulltensor) - else: - input_val.append(None) - - ir_node = func(inputs=input_val) - all_ir_nodes += [ir_node] - - outputs = [node] - output_var_name = [output.name for output in outputs] - output_val = [frame.get_var(var_name) for var_name in output_var_name] - - # frame.pop_var() - return input_val, all_ir_nodes, output_val - @staticmethod def fetch_attr(mod: torch.fx.GraphModule, target: str): target_atoms = target.split('.') @@ -325,96 +229,53 @@ def fetch_attr(mod: torch.fx.GraphModule, target: str): return attr_itr @staticmethod - def parse_prim_function_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: - """ - parse node like: - %add : [#users=2] = call_function[target=operator.add](args = (%x, %x), kwargs = {}) - """ + def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: # get signature - fsig = FxModuleParser._get_qualified_name(node.target) - print(f'parse_prim_function_node: {fsig}') - - # get inputs - input_nodes = [input_node for input_node in node.args] - input_vals = list() - for index, input_node in enumerate(input_nodes): - if isinstance(input_node, torch.fx.Node): - var_name = input_node.name - val = frame.get_var(var_name) - input_vals.append(val) - elif isinstance(input_node, (int, float)): - input_vals.append(input_node) - else: - input_vals.append(None) - - # map to IR operator - if SignFx2Op.exist(fsig): - ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + fsig = FxModuleParser._get_qualified_name(node.target, node) + if isinstance(node.target, str): + print(f'parse_prim_method_node: {fsig}') else: - # case1: unknown torch operator - if FxModuleParser._is_torch_autograd_op(node, frame): - print(f'>>> Find unkown pytorch operation: {fsig}') - fname = fsig.split('.')[-1] if '.' in fsig else fname - ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) - for idx, t in enumerate(input_vals): - ir_node.set_input(idx, t) - # case2: python runtime function + print(f'parse_prim_function_node: {fsig}') + + def handle_tuple(fx_node: tuple) -> tuple: + vals = [] + for ele in fx_node: + if isinstance(ele, torch.fx.Node): + vals.append(frame.get_var(ele.name)) + elif isinstance(ele, tuple): + vals.append(handle_tuple(ele)) + else: + assert not isinstance(ele, (list, dict)) + vals.append(ele) + return tuple(vals) + + def extract_val(fx_node): + if isinstance(fx_node, torch.fx.Node): + var_name = fx_node.name + return frame.get_var(var_name) + elif isinstance(fx_node, (int, float, str, torch.dtype)) or fx_node is None: + return fx_node + elif isinstance(fx_node, tuple): + return handle_tuple(fx_node) else: - print(f'>>> Set python runtime function: {fsig}') - ir_node = IRPyFunc(fsig, input_vals, [None]) - - # TODO gracefully set output - output_name = node.name - output_val = frame.get_var(output_name) - ir_node.set_output(0, output_val) - - # # push output in the frame - # # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) - # # : >>> dir(a) - # # : >>> a.elements() # [TensorType, TensorType] - # cnt = 0 - # for output in node.outputs(): - # if isinstance(output.type(), torch._C.TupleType): - # tuplen = len(output.type().elements()) - # ir_output = [ir_node.output(idx) for idx in range(cnt, cnt + tuplen)] - # cnt += tuplen - # else: - # ir_output = ir_node.output(cnt) - # cnt += 1 - # frame.add_var(output.debugName(), ir_output) - # - # if cnt != len(ir_node.outputs()): - # raise RuntimeError( - # f"Parse fail: {fsig} has {cnt} outputs != pre-defined {len(ir_node.outputs())}" - # ) - - return [ir_node] - - @staticmethod - def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: - # get signature - fsig = FxModuleParser._get_qualified_name(node.target) - print(f'parse_prim_method_node: {fsig}') + raise RuntimeError(f'Unsupported input node {fx_node}, {type(fx_node)} in parse function!') # get inputs - input_nodes = [input_node for input_node in node.args] input_vals = list() - for index, input_node in enumerate(input_nodes): - if isinstance(input_node, torch.fx.Node): - var_name = input_node.name - val = frame.get_var(var_name) - input_vals.append(val) - elif isinstance(input_node, (int, float)): - input_vals.append(input_node) - else: - input_vals.append(None) + for item in node.args: + input_vals.append(extract_val(item)) + if node.kwargs: + input_kwvals = {} + for k, v in node.kwargs.items(): + input_kwvals[k] = extract_val(v) + input_vals.append(input_kwvals) # map to IR operator if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: # case1: unknown torch operator - if FxModuleParser._is_torch_autograd_op(node, frame): + if FxModuleParser._is_torch_autograd_op(node, frame, fsig): print(f'>>> Find unkown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fname ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) @@ -425,6 +286,7 @@ def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr print(f'>>> Set python runtime function: {fsig}') ir_node = IRPyFunc(fsig, input_vals, [None]) + # TODO gracefully set output output_name = node.name output_val = frame.get_var(output_name) ir_node.set_output(0, output_val) @@ -435,38 +297,83 @@ def parse_prim_method_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: assert node is not None tensor_name = node.name - - tensor_shape = node.meta['tensor_meta'].shape - # tensor_dtype = node.meta['tensor_meta'].dtype - #TODO assume it is weight - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=kDefaultType) - ir_tensor.as_param() - frame.add_var(tensor_name, ir_tensor) + if 'tensor_meta' in node.meta: + tensor_shape = node.meta['tensor_meta'].shape + # tensor_dtype = node.meta['tensor_meta'].dtype + #TODO assume it is weight + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=kDefaultType) + ir_tensor.as_param() + frame.add_var(tensor_name, ir_tensor) + else: + # FIXME: why no need to record the constant value of this var? + # the value can be obtained below: + # var = FxModuleParser.fetch_attr(module, node.target) + print(f'WARNING: {node.name} {node.meta} in attr node uses empty IRObject!') + frame.add_var(tensor_name, IRObject()) return list() + # # NOTE: this is a function in torch.fx + # @staticmethod + # def _get_qualified_name(func: Callable[..., Any]) -> str: + # # things like getattr just appear in builtins + # if getattr(builtins, func.__name__, None) is func: + # return func.__name__ + # # torch.Tensor.{fn} + # if isinstance(func, types.MethodDescriptorType) and func is getattr(torch.Tensor, func.__name__, None): + # return f"torch.Tensor.{func.__name__}" + # name = func.__name__ + # module = FxModuleParser._find_module_of_method(func) + # module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + # # Fixup segment_reduce mismatch + # if module == "torch" and name == "segment_reduce": + # name = "_" + name + # return f'{module}.{name}' - from typing import Callable - # import python_stubs.buildins @staticmethod - def _get_qualified_name(func: Callable[..., Any]) -> str: + def _get_qualified_name(node_target: str | Callable[..., Any], node: torch.fx.Node = None) -> str: + if isinstance(node_target, str): + assert node is not None + return FxModuleParser._get_qualified_name_of_call_method(node_target, node) + else: + return FxModuleParser._get_qualified_name_of_call_function(node_target) + + @staticmethod + def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str: + """ + The target field of call_function node must be an callable object. + """ # # things like getattr just appear in builtins # if getattr(builtins, func.__name__, None) is func: # return func.__name__ # TODO(yizhu1): find a general solution - if isinstance(func, str): - for module, module_name in [(torch, 'torch'), (torch.Tensor, 'torch.Tensor')]: - lib_func = getattr(module, func, None) - if lib_func is not None and callable(lib_func): - return f'{module_name}.{func}' - raise RuntimeError(f'cannot find module for {func}') - name = func.__name__ - module = FxModuleParser._find_module_of_method(func) + assert callable(node_target) + name = node_target.__name__ + module = FxModuleParser._find_module_of_method(node_target) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module return f'{module}.{name}' + @staticmethod + def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> str: + """ + The target field of call_method node must be a string. + """ + assert isinstance(node_target, str) + for module, module_name in [(torch, 'torch'), (torch.Tensor, 'torch.Tensor')]: + lib_func = getattr(module, node_target, None) + if lib_func is not None and callable(lib_func): + return f'{module_name}.{node_target}' + assert len(node.args) == 1, f'invalid args {node.args} in {node.name}, {node.target}, {node.meta}' + assert len(node.kwargs) == 0, f'invalid kwargs {node.kwargs} in {node.name}, {node.target}, {node.meta}' + # example node.args[0].meta is {'type': } + in_type = node.args[0].meta['type'] + assert node_target in in_type().__dir__() + sig = f'{in_type.__name__}.{node_target}' + print(f'The method is not torch or Tensor, but {sig}') + return sig + # this is fixed on master, WAR for 1.5 @staticmethod def _find_module_of_method(orig_method: Callable[..., Any]) -> str: @@ -480,9 +387,8 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str: raise RuntimeError(f'cannot find module for {orig_method}') @staticmethod - def _is_torch_autograd_op(node: torch.fx.Node, frame: Frame) -> bool: + def _is_torch_autograd_op(node: torch.fx.Node, frame: Frame, signature: str) -> bool: """Check whether the node is of a pytorch autograd operation.""" - signature: str = FxModuleParser._get_qualified_name(node.target) # note: some python operations like torch.Tensor.size() doesn't return # an IRTensor, thus cannot be considered as a pytorch autograd operator. return signature.startswith('torch.') and \ diff --git a/tests/parser/test_bloom.py b/tests/parser/test_bloom.py new file mode 100644 index 00000000..41aff42b --- /dev/null +++ b/tests/parser/test_bloom.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +model_name = "bigscience/bloom-560m" +model_path = "/home/quzha/bloom560m" + +print("Loading model...") +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, cache_dir=model_path) +print(type(model), '; is nn.Module? ', isinstance(model, nn.Module)) +print("Model's generation config which does not list default values: ", model.generation_config) +print("Loading tokenizer...") +tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) +print("Loading Done!") +prompt = "If I want to travel to a new city, I should plan my trip as follows:" +#input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() +inputs = tokenizer(prompt, return_tensors="pt") + +# Cube +# from cube.graph import parser +# ir_graph = parser.convert_model(model, input_shapes=[1, 17], save_content=False) + +print("concrete tracing model...") +from nni.common.concrete_trace_utils import concrete_trace +traced_graph = concrete_trace(model, inputs, use_operator_patch=True, + autowrap_leaf_class={torch.finfo: ((), False)}) +print("tracing model done.") + +print("parsing fx graph to cube graph...") +from cube.graph.parser import FxModuleParser +FxModuleParser.parse(traced_graph, dummy_inputs=inputs) +print("parsing done.") From c8b480e84c882e23d0ffbe0624d6ee4308d7b391 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sat, 4 Mar 2023 05:05:55 -0800 Subject: [PATCH 1294/1892] save local work --- cube/graph/function/function.py | 16 ++++++++++++++++ cube/graph/parser/mappingfx.py | 1 + examples/mlp/linearsfx.py | 2 ++ 3 files changed, 19 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 9f2e711a..cceb16f1 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -260,6 +260,21 @@ def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: return lhs_shape, rhs_shape, out_shape +def Expand(signature, inputs): + input = inputs[0] + sizes = inputs[1:] + + edim_in = ShapeAnno.create_shape_str(input.shape) + assert len(input.shape) == len(sizes) + for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes)): + if dim == 1 and dim != expand_dim: + edim_in[idx] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(Expand, 'expand', signature, [anno], [input], sizes=sizes) + + def Clone(signature, inputs): """ torch.clone(input, *, memory_format=torch.preserve_format) @@ -900,6 +915,7 @@ def Squeeze(signature, inputs): return IRDimops(Squeeze, 'squeeze', signature, [anno], [input]) + def Unsqueeze(signature, inputs): """ out = torch.unsqueeze(tensor, dim) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 146c9285..8fdf70c5 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -78,6 +78,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('baddbmm'): function.BMMAdd, __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, + __tttemplate('expand'): function.Expand, # TODO __ftemplate('layer_norm'): function.LayerNorm, diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 3201ce4f..34422de5 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -83,6 +83,8 @@ def forward(self, data, mask): x = x.permute(0, 2, 1) x = x.transpose(1, 2) loss = torch.sum(x) + # loss = loss.expand(2) + # loss = torch.sum(loss) # long cannot backward # loss = loss.long() return loss From 8635f2beb07e5036bb1e086e383f2afb66c6d3e2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 4 Mar 2023 13:31:00 +0000 Subject: [PATCH 1295/1892] Merged PR 1469: parsing output node with complex data type support --- cube/graph/parser/parserfx.py | 62 +++++++++++++++------- tests/parser/test_fx_ops.py | 97 +++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 19 deletions(-) create mode 100644 tests/parser/test_fx_ops.py diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 24f37fdf..af2857dc 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -1,11 +1,11 @@ import torch import enum import re -from typing import Any, List, Tuple, Optional, Callable +from typing import Any, List, Tuple, Optional, Callable, Union from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor -from cube.ir.cten import IRObject +from cube.ir.cten import IRObject, IRCell import cube.ir as ir from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import DType2IRDType @@ -153,20 +153,7 @@ def parse(module: torch.fx.GraphModule, ir_nodes = FxModuleParser.parse_node(node, module, frame) all_ir_nodes += ir_nodes - # handle outputs - output_nodes = [node.all_input_nodes for node in module.graph.nodes if node.op == 'output'] - print(f'outputs = {output_nodes}') - output_var_name = [output.name for output in [item for sublist in output_nodes for item in sublist]] - output_val = [frame.get_var(var_name) for var_name in output_var_name] - - # flatten output_val - outputs = list() - for val in output_val: - if isinstance(val, list): - outputs += val - else: - outputs.append(val) - output_val = outputs + output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] frame.pop_var() frame.pop_attr() @@ -202,7 +189,7 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == FxNodeKind.Placeholder: return [] if node_type == FxNodeKind.Output: - return [] + return FxModuleParser.parse_prim_output_node(node, module, frame) if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): return FxModuleParser.parse_prim_function_method(node, module, frame) @@ -274,17 +261,20 @@ def extract_val(fx_node): if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: + input_vals = [extract_val(v) for v in node.args] + kwargs = {key: extract_val(v) for key, v in node.kwargs.items()} # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): print(f'>>> Find unkown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fname ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) + ir_node.kwargs = kwargs for idx, t in enumerate(input_vals): ir_node.set_input(idx, t) # case2: python runtime function else: print(f'>>> Set python runtime function: {fsig}') - ir_node = IRPyFunc(fsig, input_vals, [None]) + ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) # TODO gracefully set output output_name = node.name @@ -315,6 +305,40 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram return list() + @staticmethod + def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: + assert len(node.args) == 1 and len(node.kwargs) == 0 + ir_nodes = [] + + # handle complex outputs + def generate_outputs(val: Any, _ops: List) -> IRObject: + """Support complex data type of List, Tuple, Dict, Tensor/Object""" + if isinstance(val, list): + inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) + output = IRObject() + _ops.append(IRPyFunc('(lambda *args: list(args))', inputs, [output])) + return output + if isinstance(val, tuple): + inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) + output = IRObject() + _ops.append(IRPyFunc('(lambda *args: args)', inputs, [output])) + return output + if isinstance(val, dict): + output = IRObject() + assert all(not isinstance(key, torch.fx.Node) for key in val.keys()), f"output dict cannot have torch.fx.Node is key" + keys = tuple(str(key) for key in val.keys()) + values = generate_outputs(tuple(generate_outputs(value, _ops) for value in val.values()), _ops) + _ops.append(IRPyFunc('(lambda vals, keys: {key:val for key,val in zip(keys,vals)})', [values], [output], keys=keys)) + return output + if isinstance(val, torch.fx.Node): + return frame.get_var(val.name) + return val + + generate_outputs(node.args[0], ir_nodes) + if len(ir_nodes) > 0: + ir_nodes[-1].set_output(0, frame.get_var(node.name)) + return ir_nodes + # # NOTE: this is a function in torch.fx # @staticmethod # def _get_qualified_name(func: Callable[..., Any]) -> str: @@ -333,7 +357,7 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram # return f'{module}.{name}' @staticmethod - def _get_qualified_name(node_target: str | Callable[..., Any], node: torch.fx.Node = None) -> str: + def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: if isinstance(node_target, str): assert node is not None return FxModuleParser._get_qualified_name_of_call_method(node_target, node) diff --git a/tests/parser/test_fx_ops.py b/tests/parser/test_fx_ops.py new file mode 100644 index 00000000..74de8f91 --- /dev/null +++ b/tests/parser/test_fx_ops.py @@ -0,0 +1,97 @@ +""" +USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_fx_ops.py +""" +from typing import List +import torch + +import cube +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops + + +class TestOpModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.param1 = torch.nn.Parameter(torch.empty([512, 256], dtype=torch.float32)) + self.param2 = torch.nn.Parameter(torch.empty([512, 256], dtype=torch.float32)) + self.ints = [1, 2, 3] + + def forward(self, x: torch.Tensor): + # matmul: [256, 512], [512, 256] -> [256, 256] + x1 = torch.matmul(x, self.param1) + x1 = torch.matmul(x, self.param1) + x2 = torch.chunk(x, 2, dim=1) + x3 = x2[0] + x = x + x.size(0) + x = x + self.ints[0] + return {'x': x}, [x3,] + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int = 256) -> None: + # self.sample = ( + # torch.rand( + # [batch_size, 512], + # dtype=torch.float32, + # device=torch.cuda.current_device() + # ), + # [torch.tensor([1], dtype=torch.float32),] + # ) + # super().__init__(batch_size, (0, None)) + self.sample = torch.rand( + [batch_size, 512], + dtype=torch.float32, + device=torch.cuda.current_device() + ) + super().__init__(batch_size, (0,)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +def test_parse_ops(): + + cube.init() + + model = TestOpModule() + dataloader = TestDataLoader() + + def policy(graph, resource): + print(graph.extra_repr()) + assert resource.ngpus == 1 + for node in graph.nodes(): + if isinstance(node, IRDimops): + print(f'# {node.anno}') + print(node) + elif isinstance(node, (IRFwOperation, IRDataOperation)): + print(node) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + model = cube.SemanticModel(model) + + @cube.compile(model, dataloader, policy, load_content=False) + def eval_iter(model, dataloader): + data = next(dataloader) + out = model(data) + return out + + model = model.get_gen_module() + + for idx in range(3): + eval_iter(model, dataloader) + print(f"iter {idx}/3") + + +if __name__ == '__main__': + test_parse_ops() + From cd3748c5e63748d4d8ab8366c76b8b3d63e04431 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sat, 4 Mar 2023 17:15:43 -0800 Subject: [PATCH 1296/1892] bug fix --- cube/graph/parser/converter.py | 2 +- cube/graph/parser/mappingfx.py | 4 ++-- cube/graph/parser/parserfx.py | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 57c73273..b05221cc 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -61,7 +61,7 @@ def convert_model(model: torch.nn.Module, module_name = model.__class__.__name__ else: FxModuleParser.save_content = save_content - inputs, nodes, outputs = FxModuleParser.parse(traced_model, input_shapes=input_shapes, dummy_input=dummy_input) + inputs, nodes, outputs = FxModuleParser.parse(traced_model, input_shapes=input_shapes, dummy_inputs=dummy_input) module_name = model.__class__.__name__ else: ScriptModuleParser.save_content = save_content diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 790742a3..055fd17f 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -73,7 +73,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ne') : function.NE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, - # __ttemplate('masked_fill'): function.MaskedFill, + __ttemplate('masked_fill'): function.MaskedFill, __ftemplate('embedding'): function.Embedding, # # torch nn functional @@ -191,4 +191,4 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] } # customized operator code: signature -> code - kOpCodeDef: Dict[str, str] = {} \ No newline at end of file + kOpCodeDef: Dict[str, str] = {} diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 24f37fdf..ccab5351 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -92,6 +92,7 @@ def parse(module: torch.fx.GraphModule, from torch.fx.passes.shape_prop import ShapeProp ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) else: + assert False assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' # remove dead nodes from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler @@ -128,6 +129,7 @@ def parse(module: torch.fx.GraphModule, print(f'dummy_inputs does not have {input.name}') shape = None # FIXME: use the input's real dtype + #shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] dtype = kDefaultType val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) From 2a4b9dd38f578644fd8ce4f66a08d787af731990 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sat, 4 Mar 2023 17:22:50 -0800 Subject: [PATCH 1297/1892] update --- cube/graph/parser/parserfx.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index ccab5351..a2119211 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -116,20 +116,20 @@ def parse(module: torch.fx.GraphModule, print(f'{node.name} does not has tensor_meta') # handle graph input -- some inputs could be None or not tensor - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + #default_dtype = torch.get_default_dtype() + #kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, extend to other input types for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) - if hasattr(dummy_inputs, input.name): + '''if hasattr(dummy_inputs, input.name): print(f'dummy_inputs has {input.name}') shape = getattr(dummy_inputs, input.name).size()# None if dummy_inputs is None else dummy_inputs[idx].size() else: # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name print(f'dummy_inputs does not have {input.name}') shape = None - # FIXME: use the input's real dtype - #shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] + # FIXME: use the input's real dtype''' + shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] dtype = kDefaultType val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) From 20c32706bab79c5a806223b4499a32c417c400ca Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sat, 4 Mar 2023 17:41:15 -0800 Subject: [PATCH 1298/1892] update --- cube/graph/parser/parserfx.py | 56 +++++++++++++++-------------------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 93985ccf..14341bfe 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -1,11 +1,7 @@ import torch import enum import re -<<<<<<< HEAD -from typing import Any, List, Tuple, Optional, Callable -======= from typing import Any, List, Tuple, Optional, Callable, Union ->>>>>>> 8635f2beb07e5036bb1e086e383f2afb66c6d3e2 from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor @@ -96,10 +92,6 @@ def parse(module: torch.fx.GraphModule, from torch.fx.passes.shape_prop import ShapeProp ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) else: -<<<<<<< HEAD - assert False -======= ->>>>>>> 8635f2beb07e5036bb1e086e383f2afb66c6d3e2 assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' # remove dead nodes from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler @@ -162,7 +154,21 @@ def parse(module: torch.fx.GraphModule, ir_nodes = FxModuleParser.parse_node(node, module, frame) all_ir_nodes += ir_nodes - output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + #output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + # handle outputs + output_nodes = [node.all_input_nodes for node in module.graph.nodes if node.op == 'output'] + print(f'outputs = {output_nodes}') + output_var_name = [output.name for output in [item for sublist in output_nodes for item in sublist]] + output_val = [frame.get_var(var_name) for var_name in output_var_name] + + # flatten output_val + outputs = list() + for val in output_val: + if isinstance(val, list): + outputs += val + else: + outputs.append(val) + output_val = outputs frame.pop_var() frame.pop_attr() @@ -198,7 +204,8 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == FxNodeKind.Placeholder: return [] if node_type == FxNodeKind.Output: - return FxModuleParser.parse_prim_output_node(node, module, frame) + #return FxModuleParser.parse_prim_output_node(node, module, frame) + return [] if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): return FxModuleParser.parse_prim_function_method(node, module, frame) @@ -270,20 +277,21 @@ def extract_val(fx_node): if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: - input_vals = [extract_val(v) for v in node.args] - kwargs = {key: extract_val(v) for key, v in node.kwargs.items()} + #input_vals = [extract_val(v) for v in node.args] + #kwargs = {key: extract_val(v) for key, v in node.kwargs.items()} # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): print(f'>>> Find unkown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fname ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) - ir_node.kwargs = kwargs + #ir_node.kwargs = kwargs for idx, t in enumerate(input_vals): ir_node.set_input(idx, t) # case2: python runtime function else: print(f'>>> Set python runtime function: {fsig}') - ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) + #ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) + ir_node = IRPyFunc(fsig, input_vals, [None]) # TODO gracefully set output output_name = node.name @@ -314,9 +322,7 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram return list() -<<<<<<< HEAD -======= - @staticmethod + '''@staticmethod def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 ir_nodes = [] @@ -348,9 +354,8 @@ def generate_outputs(val: Any, _ops: List) -> IRObject: generate_outputs(node.args[0], ir_nodes) if len(ir_nodes) > 0: ir_nodes[-1].set_output(0, frame.get_var(node.name)) - return ir_nodes + return ir_nodes''' ->>>>>>> 8635f2beb07e5036bb1e086e383f2afb66c6d3e2 # # NOTE: this is a function in torch.fx # @staticmethod # def _get_qualified_name(func: Callable[..., Any]) -> str: @@ -367,18 +372,6 @@ def generate_outputs(val: Any, _ops: List) -> IRObject: # if module == "torch" and name == "segment_reduce": # name = "_" + name # return f'{module}.{name}' -<<<<<<< HEAD - - @staticmethod - def _get_qualified_name(node_target: str | Callable[..., Any], node: torch.fx.Node = None) -> str: - if isinstance(node_target, str): - assert node is not None - return FxModuleParser._get_qualified_name_of_call_method(node_target, node) - else: - return FxModuleParser._get_qualified_name_of_call_function(node_target) - - @staticmethod -======= @staticmethod def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: @@ -389,7 +382,6 @@ def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch return FxModuleParser._get_qualified_name_of_call_function(node_target) @staticmethod ->>>>>>> 8635f2beb07e5036bb1e086e383f2afb66c6d3e2 def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str: """ The target field of call_function node must be an callable object. From ca579f37de40287a0d2f1530fdbab279bcb7246f Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sat, 4 Mar 2023 23:47:11 -0800 Subject: [PATCH 1299/1892] update --- cube/graph/parser/parserfx.py | 50 ++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 14341bfe..6824bbd7 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -83,23 +83,45 @@ def parse(module: torch.fx.GraphModule, print(f'input shape mismatch (got {len(input_shapes)} != {len(inputs)})') # TODO fixme raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype if input_shapes is not None: - ## shape propagation - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + # shape propagation sample_inputs = dummy_inputs if dummy_inputs else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] sample_input_tensors = [sample_inputs[input] for input in sample_inputs] if type(sample_inputs) is dict else sample_inputs from torch.fx.passes.shape_prop import ShapeProp ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) + # handle graph inputs + for idx, input in enumerate(inputs): + assert isinstance(input, torch.fx.Node) + shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] + dtype = kDefaultType + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + frame.add_var(input.name, val, graph_arg=idx) else: assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' # remove dead nodes from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler DCEHandler(module).eliminate_dead_code() - # shape propagation from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp KwargsShapeProp(module).propagate(dummy_inputs) + # handle graph inputs + for idx, input in enumerate(inputs): + assert isinstance(input, torch.fx.Node) + # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, + # extend to other input types + if hasattr(dummy_inputs, input.name): + print(f'dummy_inputs has {input.name}') + shape = getattr(dummy_inputs, input.name).size() + else: + # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name + print(f'dummy_inputs does not have {input.name}') + shape = None + dtype = kDefaultType + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + frame.add_var(input.name, val, graph_arg=idx) + input_val = [frame.get_var(input.name) for input in inputs] for node in module.graph.nodes: if 'tensor_meta' in node.meta: @@ -114,26 +136,6 @@ def parse(module: torch.fx.GraphModule, else: print(f'{node.name} does not has tensor_meta') - # handle graph input -- some inputs could be None or not tensor - #default_dtype = torch.get_default_dtype() - #kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, extend to other input types - for idx, input in enumerate(inputs): - assert isinstance(input, torch.fx.Node) - '''if hasattr(dummy_inputs, input.name): - print(f'dummy_inputs has {input.name}') - shape = getattr(dummy_inputs, input.name).size()# None if dummy_inputs is None else dummy_inputs[idx].size() - else: - # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name - print(f'dummy_inputs does not have {input.name}') - shape = None - # FIXME: use the input's real dtype''' - shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] - dtype = kDefaultType - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) - frame.add_var(input.name, val, graph_arg=idx) - input_val = [frame.get_var(input.name) for input in inputs] - # add activations to frame, including call_func/call_method output and final output activation_op_strs = {'call_function', 'output', 'call_method'} activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] From 908b3ca7a64b0bc93758fa1c203d7f308199952a Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Sun, 5 Mar 2023 07:56:20 +0000 Subject: [PATCH 1300/1892] Merged PR 1472: fix test --- cube/graph/parser/converter.py | 2 +- cube/graph/parser/mappingfx.py | 4 +- cube/graph/parser/parserfx.py | 81 +++++++++++++++++++++------------- 3 files changed, 53 insertions(+), 34 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 57c73273..b05221cc 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -61,7 +61,7 @@ def convert_model(model: torch.nn.Module, module_name = model.__class__.__name__ else: FxModuleParser.save_content = save_content - inputs, nodes, outputs = FxModuleParser.parse(traced_model, input_shapes=input_shapes, dummy_input=dummy_input) + inputs, nodes, outputs = FxModuleParser.parse(traced_model, input_shapes=input_shapes, dummy_inputs=dummy_input) module_name = model.__class__.__name__ else: ScriptModuleParser.save_content = save_content diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 790742a3..055fd17f 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -73,7 +73,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ne') : function.NE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, - # __ttemplate('masked_fill'): function.MaskedFill, + __ttemplate('masked_fill'): function.MaskedFill, __ftemplate('embedding'): function.Embedding, # # torch nn functional @@ -191,4 +191,4 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] } # customized operator code: signature -> code - kOpCodeDef: Dict[str, str] = {} \ No newline at end of file + kOpCodeDef: Dict[str, str] = {} diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index af2857dc..6824bbd7 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -83,23 +83,45 @@ def parse(module: torch.fx.GraphModule, print(f'input shape mismatch (got {len(input_shapes)} != {len(inputs)})') # TODO fixme raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + default_dtype = torch.get_default_dtype() + kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype if input_shapes is not None: - ## shape propagation - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype + # shape propagation sample_inputs = dummy_inputs if dummy_inputs else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] sample_input_tensors = [sample_inputs[input] for input in sample_inputs] if type(sample_inputs) is dict else sample_inputs from torch.fx.passes.shape_prop import ShapeProp ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) + # handle graph inputs + for idx, input in enumerate(inputs): + assert isinstance(input, torch.fx.Node) + shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] + dtype = kDefaultType + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + frame.add_var(input.name, val, graph_arg=idx) else: assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' # remove dead nodes from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler DCEHandler(module).eliminate_dead_code() - # shape propagation from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp KwargsShapeProp(module).propagate(dummy_inputs) + # handle graph inputs + for idx, input in enumerate(inputs): + assert isinstance(input, torch.fx.Node) + # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, + # extend to other input types + if hasattr(dummy_inputs, input.name): + print(f'dummy_inputs has {input.name}') + shape = getattr(dummy_inputs, input.name).size() + else: + # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name + print(f'dummy_inputs does not have {input.name}') + shape = None + dtype = kDefaultType + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + frame.add_var(input.name, val, graph_arg=idx) + input_val = [frame.get_var(input.name) for input in inputs] for node in module.graph.nodes: if 'tensor_meta' in node.meta: @@ -114,25 +136,6 @@ def parse(module: torch.fx.GraphModule, else: print(f'{node.name} does not has tensor_meta') - # handle graph input -- some inputs could be None or not tensor - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, extend to other input types - for idx, input in enumerate(inputs): - assert isinstance(input, torch.fx.Node) - if hasattr(dummy_inputs, input.name): - print(f'dummy_inputs has {input.name}') - shape = getattr(dummy_inputs, input.name).size()# None if dummy_inputs is None else dummy_inputs[idx].size() - else: - # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name - print(f'dummy_inputs does not have {input.name}') - shape = None - # FIXME: use the input's real dtype - dtype = kDefaultType - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) - frame.add_var(input.name, val, graph_arg=idx) - input_val = [frame.get_var(input.name) for input in inputs] - # add activations to frame, including call_func/call_method output and final output activation_op_strs = {'call_function', 'output', 'call_method'} activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] @@ -153,7 +156,21 @@ def parse(module: torch.fx.GraphModule, ir_nodes = FxModuleParser.parse_node(node, module, frame) all_ir_nodes += ir_nodes - output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + #output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + # handle outputs + output_nodes = [node.all_input_nodes for node in module.graph.nodes if node.op == 'output'] + print(f'outputs = {output_nodes}') + output_var_name = [output.name for output in [item for sublist in output_nodes for item in sublist]] + output_val = [frame.get_var(var_name) for var_name in output_var_name] + + # flatten output_val + outputs = list() + for val in output_val: + if isinstance(val, list): + outputs += val + else: + outputs.append(val) + output_val = outputs frame.pop_var() frame.pop_attr() @@ -189,7 +206,8 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == FxNodeKind.Placeholder: return [] if node_type == FxNodeKind.Output: - return FxModuleParser.parse_prim_output_node(node, module, frame) + #return FxModuleParser.parse_prim_output_node(node, module, frame) + return [] if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): return FxModuleParser.parse_prim_function_method(node, module, frame) @@ -261,20 +279,21 @@ def extract_val(fx_node): if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: - input_vals = [extract_val(v) for v in node.args] - kwargs = {key: extract_val(v) for key, v in node.kwargs.items()} + #input_vals = [extract_val(v) for v in node.args] + #kwargs = {key: extract_val(v) for key, v in node.kwargs.items()} # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): print(f'>>> Find unkown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fname ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) - ir_node.kwargs = kwargs + #ir_node.kwargs = kwargs for idx, t in enumerate(input_vals): ir_node.set_input(idx, t) # case2: python runtime function else: print(f'>>> Set python runtime function: {fsig}') - ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) + #ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) + ir_node = IRPyFunc(fsig, input_vals, [None]) # TODO gracefully set output output_name = node.name @@ -305,7 +324,7 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram return list() - @staticmethod + '''@staticmethod def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 ir_nodes = [] @@ -337,7 +356,7 @@ def generate_outputs(val: Any, _ops: List) -> IRObject: generate_outputs(node.args[0], ir_nodes) if len(ir_nodes) > 0: ir_nodes[-1].set_output(0, frame.get_var(node.name)) - return ir_nodes + return ir_nodes''' # # NOTE: this is a function in torch.fx # @staticmethod From 260cdd9d9d72fea2d143c82e232bd225ce314a8d Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sun, 5 Mar 2023 00:06:35 -0800 Subject: [PATCH 1301/1892] debug --- cube/graph/parser/parserfx.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 6824bbd7..ea5692c1 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -156,7 +156,7 @@ def parse(module: torch.fx.GraphModule, ir_nodes = FxModuleParser.parse_node(node, module, frame) all_ir_nodes += ir_nodes - #output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + # output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] # handle outputs output_nodes = [node.all_input_nodes for node in module.graph.nodes if node.op == 'output'] print(f'outputs = {output_nodes}') @@ -171,6 +171,8 @@ def parse(module: torch.fx.GraphModule, else: outputs.append(val) output_val = outputs + print(f'zql: {output_val}') + # exit(1) frame.pop_var() frame.pop_attr() @@ -206,8 +208,7 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == FxNodeKind.Placeholder: return [] if node_type == FxNodeKind.Output: - #return FxModuleParser.parse_prim_output_node(node, module, frame) - return [] + return FxModuleParser.parse_prim_output_node(node, module, frame) if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): return FxModuleParser.parse_prim_function_method(node, module, frame) @@ -324,7 +325,7 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram return list() - '''@staticmethod + @staticmethod def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 ir_nodes = [] @@ -356,7 +357,7 @@ def generate_outputs(val: Any, _ops: List) -> IRObject: generate_outputs(node.args[0], ir_nodes) if len(ir_nodes) > 0: ir_nodes[-1].set_output(0, frame.get_var(node.name)) - return ir_nodes''' + return ir_nodes # # NOTE: this is a function in torch.fx # @staticmethod From 72f09ef333857ac300897a88157509d096a9b7f5 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sun, 5 Mar 2023 00:26:02 -0800 Subject: [PATCH 1302/1892] update test --- tests/test_examples.sh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 771e13e7..9e5c526d 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -1,4 +1,12 @@ +# NOTE: This test should run in the root directory. +# Before running this test, you should run `export PYTHONPATH=.:$PYTHONPATH` first. +# test torch.fx +# working path +OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:$PYTHONPATH \ + python -m torch.distributed.launch \ + --nproc_per_node=1 \ + examples/mlp/linearsfx.py --policy PASData # test MLP From ba0e1eefecc90de68d1e9cd9fb5617231a5a7461 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Sun, 5 Mar 2023 08:34:18 +0000 Subject: [PATCH 1303/1892] Merged PR 1473: fix the issue of pr 1469 --- cube/graph/parser/parserfx.py | 7 +++---- tests/test_examples.sh | 8 ++++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 6824bbd7..3733dfa1 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -206,8 +206,7 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == FxNodeKind.Placeholder: return [] if node_type == FxNodeKind.Output: - #return FxModuleParser.parse_prim_output_node(node, module, frame) - return [] + return FxModuleParser.parse_prim_output_node(node, module, frame) if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): return FxModuleParser.parse_prim_function_method(node, module, frame) @@ -324,7 +323,7 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram return list() - '''@staticmethod + @staticmethod def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 ir_nodes = [] @@ -356,7 +355,7 @@ def generate_outputs(val: Any, _ops: List) -> IRObject: generate_outputs(node.args[0], ir_nodes) if len(ir_nodes) > 0: ir_nodes[-1].set_output(0, frame.get_var(node.name)) - return ir_nodes''' + return ir_nodes # # NOTE: this is a function in torch.fx # @staticmethod diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 771e13e7..9e5c526d 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -1,4 +1,12 @@ +# NOTE: This test should run in the root directory. +# Before running this test, you should run `export PYTHONPATH=.:$PYTHONPATH` first. +# test torch.fx +# working path +OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:$PYTHONPATH \ + python -m torch.distributed.launch \ + --nproc_per_node=1 \ + examples/mlp/linearsfx.py --policy PASData # test MLP From 72d363a9969ac87e5c0e9d29c25878101d752636 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Sun, 5 Mar 2023 08:51:16 +0000 Subject: [PATCH 1304/1892] Merged PR 1470: add torch language model example for fx test add torch language model example for fx test --- cube/graph/parser/converter.py | 7 +- examples/nlp/torchscale/lm_fx_test.py | 204 ++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 examples/nlp/torchscale/lm_fx_test.py diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index b05221cc..64d9fc35 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -30,12 +30,15 @@ def convert_model(model: torch.nn.Module, else: print('using concrete tracer') with torch.no_grad(): - if isinstance(dummy_input, tuple): + if isinstance(dummy_input, tuple) or isinstance(dummy_input, list): output_origin = model(*dummy_input) elif isinstance(dummy_input, torch.Tensor): output_origin = model(dummy_input) + elif isinstance(dummy_input, dict): + print(f'WARNING dict dummy_input') + output_origin = model(**dummy_input) else: - raise RuntimeError(f'dummy_input should be a tuple = {dummy_input}') + raise RuntimeError(f'dummy_input should be a tuple (not a {type(dummy_input)}) = {dummy_input}') traced_model, _ = concrete_trace( model, dummy_input, diff --git a/examples/nlp/torchscale/lm_fx_test.py b/examples/nlp/torchscale/lm_fx_test.py new file mode 100644 index 00000000..382174e6 --- /dev/null +++ b/examples/nlp/torchscale/lm_fx_test.py @@ -0,0 +1,204 @@ +# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/lm_fx_test.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData + +import torch +import pickle +from fairseq import ( + tasks, + options, + checkpoint_utils +) +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.trainer import Trainer +from fairseq.data import iterators + +import sys + +import os +print(f'os.getcwd() = {os.getcwd()}') + + +# https://github.com/microsoft/torchscale/tree/main/examples/fairseq +# sys.path.append('/home/v-junliang/torchscaletest/torchscale/examples/fairseq') +# sys.path.append('./torchscaletest/torchscale/examples/fairseq') +sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') +sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale') +print(f'sys.path = {sys.path}') +import models + +#:torchscaletest/torchscale +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +sys.path.append('.') +from policy import mpmd, spmd +# import examples.nlp.torchscale.policy.spmd as spmd + +# import argparse + +# parser = argparse.ArgumentParser(description='comm primitive') +# parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +# parser.add_argument('--local_rank', type=int, default=0) +# args = parser.parse_args() + +# build model +parser = options.get_training_parser() +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +# parser.add_argument('--local_rank', type=int, default=0) + +args = options.parse_args_and_arch(parser) + +cube.init() +# set up policy +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + +cfg = convert_namespace_to_omegaconf(args) +task = tasks.setup_task(cfg.task) +model = task.build_model(cfg.model) +model.eval() +print("building model succeed: ", type(model)) + +# create dummy input +with open('examples/nlp/torchscale/input_lm.bak', 'rb') as f: +# with open('examples/nlp/torchscale/lm_input_v2.pkl', 'rb') as f: + dummy_input = pickle.load(f) +device = next(model.parameters()).device +print(f'device = {device}') +for key in dummy_input.keys(): + dummy_input[key] = dummy_input[key].to(device) +print(f'dummy_input <{type(dummy_input)}> = {dummy_input}') + +# create input as list of tensors/objects +dummy_input_list = [val for key, val in dict(dummy_input).items()] +# print(f'dummy_input_list = {dummy_input_list}, len = {len(dummy_input_list)}') + +with torch.no_grad(): + output_origin = model(**dummy_input) + # output_origin = model(*dummy_input_list) + # print(f'output_origin = {output_origin}') + + +input_shapes = [list(dummy_input[input].size()) for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] +input_dtypes = [dummy_input[input].dtype for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] +input_names = tuple([input for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)]) + +# input_shapes += [[None], [None]] +# input_dtypes += [bool, bool] + +print(f'input_shapes = {input_shapes}') +print(f'input_dtypes = {input_dtypes}') + +dataloader = cube.runtime.syndata.SynDataLoader( + shapes=(input_shapes), + dtypes=input_dtypes, + batch_dims=(0,0), +) +sample_input = next(dataloader) +print(f'next(dataloader) = {sample_input}') +sample_input_cpu = tuple([val.to(device) if isinstance(val, torch.Tensor) else val for val in sample_input]) + +model = cube.SemanticModel( + #TODO fix me model, dummy_input=sample_input_cpu, + # model, dummy_input=dummy_input_list, + model, dummy_input=dummy_input, +) + +@cube.compile(model, dataloader, PAS=PAS, load_content=False) +def train_iter(model, dataloader): + data = next(dataloader) + loss = model(*data) + loss.backward() + +train_iter(model, dataloader) + +# Conduct concrete trace below +# sys.path.append('/home/v-junliang/torchscaletest/nni') +# sys.path.append('./torchscaletest/nni') +# from nni.common.concrete_trace_utils import concrete_trace +# from concrete_trace_utils import concrete_trace +from examples.nlp.torchscale.concrete_trace_utils import concrete_trace +import examples.nlp.torchscale.torchscaletest.torchscale + + +def check_equal(a, b): + if type(a) != type(b): + return False + if isinstance(a, (list, tuple, set)): + if len(a) != len(b): + return False + for sub_a, sub_b in zip(a, b): + if not check_equal(sub_a, sub_b): + return False + return True + elif isinstance(a, dict): + keys_a, kes_b = set(a.keys()), set(b.keys()) + if keys_a != kes_b: + return False + for key in keys_a: + if not check_equal(a[key], b[key]): + return False + return True + elif isinstance(a, torch.Tensor): + return torch.equal(a, b) + else: + return a == b + + +print("start tracing...") +traced_model, _ = concrete_trace( + model, + dummy_input, + use_operator_patch=True, + autowrap_leaf_class={ + torch.finfo: ((), False), + type(output_origin): ((), False), + }, +) +print("trace succeed") +print("checking equal...") +with torch.no_grad(): + output_traced = traced_model(**dummy_input) +assert check_equal(output_origin, output_traced), "check equal failed" +print("checked") + +# check graph +traced_model.graph.print_tabular() + +# with open('input_tl', 'wb') as f: +# pickle.dump(dummy_input, f) + +# try to save traced model with pickle +# from concrete_trace_utils.concrete_tracer import MagicMethodPatcher +# from pickle import _Pickler, _Unpickler + +# with open("save/through_nn_Module/tl_traced_v2.model", "wb") as f: +# # pickle.dump(traced_model, f) +# with MagicMethodPatcher(): +# _Pickler(f).dump(traced_model) + +# with open("save/through_nn_Module/tl_traced.model", "rb") as f: +# with MagicMethodPatcher(): +# reload_model = _Unpickler(f).load() + + +# with torch.no_grad(): +# output_reload = reload_model(**dummy_input) +# assert check_equal(output_origin, output_reload), "reload check equal failed" +# print("reload is good!") + +# with open("save/through_nn_Module/tl_origin_v2.model", "wb") as f: +# with MagicMethodPatcher(): +# _Pickler(f).dump(model) + +# with open("save/through_nn_Module/tl_input_v2.pkl", "wb") as f: +# with MagicMethodPatcher(): +# _Pickler(f).dump(dummy_input) + From b4e0648676915323588ecd9ceb24247220d05987 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 5 Mar 2023 01:06:40 -0800 Subject: [PATCH 1305/1892] support operators --- cube/codegen/frontend_mapping.py | 6 ++++ cube/graph/function/dimops.py | 3 ++ cube/graph/function/function.py | 52 ++++++++++++++++++++++++++++++-- cube/graph/parser/mappingfx.py | 43 +++++++++++++++----------- cube/graph/parser/parserfx.py | 6 +++- cube/ir/dtype.py | 4 +-- cube/ir/tensor.py | 2 +- examples/mlp/linearsfx.py | 30 +++++++++++++----- 8 files changed, 114 insertions(+), 32 deletions(-) diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index c3ad7aa3..37d3cba0 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -169,6 +169,11 @@ def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: member = f'"{arg_vars[1][5:]}"' return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" +def emit_index_select(node, arg_vars:list, kw_pairs:dict) -> str: + assert 'dim' in kw_pairs + dim = kw_pairs['dim'] + return f'{node.signature}({arg_vars[0]}, {dim}, {arg_vars[1]})' + class Sign2EmitRule: @@ -200,6 +205,7 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: 'torch.Tensor.to': emit_to, 'torch.rand': emit_rand, 'torch.tensor': emit_new_tensor, + 'torch.index_select': emit_index_select, 'setattr': emit_setattr, } diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index c8811348..3e47a510 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -438,6 +438,9 @@ def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], ou_annos = list() for shape in ins: flatten = list() + if isinstance(shape, str): + in_annos.append(shape) + continue for edim in shape: if isinstance(edim, str): flatten.append(edim) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index cceb16f1..669564fe 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -523,6 +523,13 @@ def Dropout(signature, inputs): p=p, training=training, inplace=inplace) +def Detach(signature, inputs): + assert len(inputs) == 1 + annos = ['* -> *'] + tensor = inputs[0:1] + return IRDimops(Detach, 'detach', signature, annos, tensor) + + def EQ(signature, inputs): assert len(inputs) == 2 input0, input1 = inputs @@ -718,8 +725,11 @@ def View(signature, inputs): """ out = torch.Tensor.view(tensor: torch.Tensor, size: List[int]) """ - assert len(inputs) == 2 - input, shape = inputs + if len(inputs) == 2: + input, shape = inputs + else: + input = inputs[0] + shape = inputs[1:] if not all([isinstance(dim, int) for dim in shape]): raise TypeError("Expected tensor.view has static int shape") in_shape, ou_shape = list(input.shape), shape @@ -1136,6 +1146,9 @@ def Stack(signature, inputs: Tuple[List[IRTensor], int]): tensors, dim = inputs else: tensors, dim = inputs[:-1], inputs[-1] + if isinstance(dim, dict): + assert 'dim' in dim + dim = dim['dim'] assert all(isinstance(tensor, IRTensor) for tensor in tensors) iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] oannos = [copy.copy(iannos[-1])] @@ -1173,6 +1186,25 @@ def Select(signature, inputs: Tuple[IRTensor, int, int]): return IRDimops(Select, 'select', signature, [anno], [tensor], dim=dim, index=index) +def IndexSelect(signature, inputs): + assert len(inputs) == 3 + # hack + if isinstance(inputs[1], int): + tensor, dim, idx = inputs + else: + assert isinstance(inputs[2], int) + tensor, idx, dim = inputs + + edim_in = ShapeAnno.create_shape_str(tensor.shape) + edim_in[dim] += '^' + idx_anno = chr(ord(edim_in[-1]) + 1) + '^' + edim_ou = copy.copy(edim_in) + edim_ou[dim] = copy.copy(idx_anno) + anno = OpAnno.create_op_str([edim_in, idx_anno], [edim_ou]) + + return IRDimops(IndexSelect, 'index_select', signature, [anno], [tensor, idx], dim=dim) + + def Slice(signature, inputs): """ aten::slice(input:Tensor, dim:int, start:Optional[int], end:Optional[int], step:int) -> Tensor @@ -1262,7 +1294,10 @@ def Embedding(signature, inputs: List): def Flatten(signature, inputs: List): tensor: IRTensor = inputs[0] - start_dim, end_dim = inputs[1:] + if len(inputs) == 1: + start_dim, end_dim = 0, len(tensor.shape) - 1 + else: + start_dim, end_dim = inputs[1:] end_dim = len(tensor.shape) + end_dim if end_dim < 0 else end_dim ishape = ShapeAnno.create_shape_str(tensor.shape) for dim in range(start_dim, end_dim+1): @@ -1385,3 +1420,14 @@ def CompareLE(signature, inputs): torch.gt(input, other, *, out=None) -> Tensor """ return _comparison(CompareLE, operator.le, 'le', signature, inputs) + + +def ShapeAsTensor(signature, inputs): + assert len(inputs) == 1 + input = inputs[0] + + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = [str(len(input.shape))] + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(ShapeAsTensor, '_shape_as_tensor', signature, [anno], [input]) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 8fdf70c5..c6da4617 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -63,7 +63,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('unsqueeze'): function.Unsqueeze, __tttemplate('type_as'): function.TypeAs, __ttemplate('triu'): function.Triu, - __ftemplate('relu') : function.ReLU, + __ftemplate('relu'): function.ReLU, + __fcntemplate('gelu'): function.GeLU, __ttemplate('eq') : function.EQ, __ttemplate('ne') : function.NE, __ttemplate('nan_to_num') : function.NanToNum, @@ -79,6 +80,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, __tttemplate('expand'): function.Expand, + __ttemplate('detach'): function.Detach, + __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, + __ttemplate('index_select'): function.IndexSelect, # TODO __ftemplate('layer_norm'): function.LayerNorm, @@ -116,17 +120,19 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('sub') : function.Sub, # # __ttemplate('mul') : function.Mul, - # - # __ttemplate('div') : function.Div, - # - # __ttemplate('floordiv') : function.FloorDiv, - # - # __ttemplate('neg'): function.Neg, - # - # __ttemplate('gt'): function.CompareGT, - # __ttemplate('lt'): function.CompareLT, - # __ttemplate('ge'): function.CompareGE, - # __ttemplate('le'): function.CompareLE, + + __ttemplate('div') : function.Div, + __ttemplate('truediv'): function.Div, + __ttemplate('true_divide'): function.Div, + __ttemplate('floordiv') : function.FloorDiv, + __ttemplate('floor_divide') : function.FloorDiv, + + __ttemplate('neg'): function.Neg, + # + __ttemplate('gt'): function.CompareGT, + __ttemplate('lt'): function.CompareLT, + __ttemplate('ge'): function.CompareGE, + __ttemplate('le'): function.CompareLE, # # __ttemplate('sin'): function.Sin, # @@ -136,8 +142,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('mean') : function.Mean, # # __ttemplate('view'): function.View, - # - # __ttemplate('reshape'): function.Reshape, + __tttemplate('view'): function.View, + + __ttemplate('reshape'): function.Reshape, # # __ttemplate('conv2d'): function.Conv2D, # @@ -158,12 +165,12 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('linear'): function.Linear, # # __ttemplate('cat'): function.Cat, - # - # __ttemplate('stack'): function.Stack, + + __ttemplate('stack'): function.Stack, # # __ttemplate('chunk'): function.Chunk, - # - # __ttemplate('flatten'): function.Flatten, + + __ttemplate('flatten'): function.Flatten, # # __ttemplate('roll'): function.Roll, # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index db77bccd..01ae17c1 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -351,7 +351,11 @@ def extract_val(fx_node): input_vals = list() for item in node.args: - input_vals.append(extract_val(item)) + if isinstance(item, tuple): + for _ in item: + input_vals.append(extract_val(_)) + else: + input_vals.append(extract_val(item)) if node.kwargs: input_kwvals = {} for k, v in node.kwargs.items(): diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index 4bda60c0..f654cd62 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -47,9 +47,9 @@ def infer(node, dtypes: List[IRDType]) -> IRDType: if IRDType.float32 in dtypes and IRDType.float16 in dtypes: raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") # TODO(yizhu1): hack - if node.signature in ['torch.ne', 'torch.eq']: + if node.signature in ('torch.ne', 'torch.eq', 'torch.gt'): return IRDType.boolean - elif node.signature == 'torch.Tensor.long': + elif node.signature in ('torch.Tensor.long', 'torch._shape_as_tensor'): return IRDType.int64 # in priority: fp32 > fp16 > bool > int64 > int16 > priority = [ diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 0054005a..2798631b 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -633,7 +633,7 @@ def grad(self) -> bool: @grad.setter def grad(self, val: Optional[IRTensor]): if isinstance(val, IRSubTensor): - assert self.requires_grad and val.shape == self.shape + assert self.requires_grad and val.shape == self.shape, f'info: {self.requires_grad} {val.shape == self.shape}' self._grad = val elif val is None: assert not self.requires_grad diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 34422de5..f62e0274 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -60,6 +60,8 @@ def __init__(self, dim, mult=1, nlayers=4): def forward(self, data, mask): x = data.masked_fill(mask, 0.0) + y = torch._shape_as_tensor(x) + z = torch.gt(x, x) x = x.fill_(0.0) x = torch.nn.functional.softmax(x, dim=-1) x = torch.bmm(x, x) @@ -69,6 +71,7 @@ def forward(self, data, mask): for layer in self.layers: x = layer(x) x = torch.nn.functional.relu(x) + x = torch.nn.functional.gelu(x) # x = self.layer_norm(x) x = x.type_as(data) x = x.unsqueeze(1) @@ -76,17 +79,30 @@ def forward(self, data, mask): x = x.squeeze() x = torch.triu(x, 1) x = torch.nan_to_num(x) - # ne and eq cannot backward - # x = torch.ne(x, 1.0) - # y = torch.eq(x, 1.0) + ne_var = x.detach() + ne_var = torch.ne(ne_var, 1.0) + eq_var = x.detach() + eq_var = torch.eq(eq_var, 1.0) + long_var = x.detach() + long_var = long_var.long() + floor_div_var = x.detach() + floor_div_var = torch.floor_divide(floor_div_var, 2.0) + x = torch.true_divide(x, 1.0) x = torch.cumsum(x, -1) x = x.permute(0, 2, 1) x = x.transpose(1, 2) + x = torch.div(x, 1.0) + # concrete_trace not support + # x = torch.Tensor.view(x, [32 * 1024, 1024]) + x = x.view(32 * 1024, 1024) + x = x.reshape(32, 1024, 1024) + # indices = torch.arange(4, dtype=torch.int64) + # x = torch.index_select(x, 1, indices) + p = torch.div(x, 2.0) + x = torch.stack((x, p), dim=1) + x = torch.flatten(x, 2, 3) loss = torch.sum(x) - # loss = loss.expand(2) - # loss = torch.sum(loss) - # long cannot backward - # loss = loss.long() + loss = torch.neg(loss) return loss From 43483650f6430f73dea6f8725440aa5e4fce0f15 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 5 Mar 2023 02:07:46 -0800 Subject: [PATCH 1306/1892] fix bugs --- cube/graph/function/function.py | 22 +++++++++++++++++----- examples/mlp/linearsfx.py | 4 ++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 623e4fd0..74088053 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -46,11 +46,18 @@ def BatchLinear(signature, inputs): def BMMAdd(signature, inputs): - assert len(inputs) == 3 + assert len(inputs) >= 3, f'{inputs}' + alpha, beta = 1, 1 + if len(inputs) == 4: + assert isinstance(inputs[3], dict) + alpha = inputs[3]['alpha'] + beta = inputs[3]['beta'] + elif len(inputs) == 5: + alpha, beta = inputs[3:] annos = [ 'b m n, b m k^, b k^ n -> b m n' ] - return IRDimops(BMMAdd, 'baddbmm', signature, annos, inputs) + return IRDimops(BMMAdd, 'baddbmm', signature, annos, inputs[:3], alpha=alpha, beta=beta) def Matmul(signature, inputs: Tuple[IRTensor, IRTensor]): @@ -591,6 +598,9 @@ def MaskedFill(signature, inputs): edim_in0 = ShapeAnno.create_shape_str(input0.shape) edim_in1 = ShapeAnno.create_shape_str(input1.shape) edim_ou = copy.copy(edim_in0) + for idx, (lhs, rhs) in enumerate(zip(input0.shape, input1.shape)): + if lhs != rhs and rhs == 1: + edim_ou[idx] = '1' anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input0, input1], value=value) @@ -948,9 +958,8 @@ def TypeAs(signature, inputs): input0, input1 = inputs edim_in0 = ShapeAnno.create_shape_str(input0.shape) - edim_in1 = ShapeAnno.create_shape_str(input1.shape) edim_ou = copy.copy(edim_in0) - anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) + anno = OpAnno.create_op_str([edim_in0, '*'], [edim_ou]) return IRDimops(TypeAs, 'type_as', signature, [anno], [input0, input1]) @@ -979,7 +988,10 @@ def CumSum(signature, inputs): """ assert len(inputs) == 2 input, dim = inputs - assert isinstance(dim, int) + if isinstance(dim, dict): + dim = dim['dim'] + else: + assert isinstance(dim, int) edim_in = ShapeAnno.create_shape_str(input.shape) edim_in[dim] += '^' diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index de31227e..233400f5 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -65,7 +65,7 @@ def forward(self, data, mask): x = x.fill_(0.0) x = torch.nn.functional.softmax(x, dim=-1) x = torch.bmm(x, x) - x = torch.baddbmm(x, x, x) + x = torch.baddbmm(x, x, x, alpha=0.125, beta=1.0) x = torch.tanh(x) x = torch.pow(x, x) for layer in self.layers: @@ -88,7 +88,7 @@ def forward(self, data, mask): floor_div_var = x.detach() floor_div_var = torch.floor_divide(floor_div_var, 2.0) x = torch.true_divide(x, 1.0) - x = torch.cumsum(x, -1) + x = torch.cumsum(x, dim=-1) x = x.permute(0, 2, 1) x = x.transpose(1, 2) x = torch.div(x, 1.0) From b38735729f8584f2024af6f68df76d4a338ec543 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 5 Mar 2023 10:08:51 +0000 Subject: [PATCH 1307/1892] Merged PR 1468: Support torch related operators TMP PR for quick dev --- cube/codegen/frontend_mapping.py | 6 + cube/graph/function/dimops.py | 5 +- cube/graph/function/function.py | 192 ++++++++++++++++++++++++++++--- cube/graph/parser/mappingfx.py | 71 ++++++------ cube/ir/dtype.py | 4 +- cube/ir/tensor.py | 2 +- examples/mlp/linearsfx.py | 38 +++++- 7 files changed, 263 insertions(+), 55 deletions(-) diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index c3ad7aa3..37d3cba0 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -169,6 +169,11 @@ def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: member = f'"{arg_vars[1][5:]}"' return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" +def emit_index_select(node, arg_vars:list, kw_pairs:dict) -> str: + assert 'dim' in kw_pairs + dim = kw_pairs['dim'] + return f'{node.signature}({arg_vars[0]}, {dim}, {arg_vars[1]})' + class Sign2EmitRule: @@ -200,6 +205,7 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: 'torch.Tensor.to': emit_to, 'torch.rand': emit_rand, 'torch.tensor': emit_new_tensor, + 'torch.index_select': emit_index_select, 'setattr': emit_setattr, } diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 7204c444..3e47a510 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -409,7 +409,7 @@ def parse(anno: str) -> Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]]: """ # to inputs and outputs if '->' not in anno: - raise ValueError("Syntax Error: Expected -> in operator anno") + raise ValueError(f"Syntax Error: Expected -> in operator anno: {anno}") inputs, outputs = anno.split('->') inputs = inputs.split(',') outputs = outputs.split(',') @@ -438,6 +438,9 @@ def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], ou_annos = list() for shape in ins: flatten = list() + if isinstance(shape, str): + in_annos.append(shape) + continue for edim in shape: if isinstance(edim, str): flatten.append(edim) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 454bb6aa..74088053 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -45,6 +45,21 @@ def BatchLinear(signature, inputs): return IRDimops(BatchLinear, 'bmm', signature, annos, inputs) +def BMMAdd(signature, inputs): + assert len(inputs) >= 3, f'{inputs}' + alpha, beta = 1, 1 + if len(inputs) == 4: + assert isinstance(inputs[3], dict) + alpha = inputs[3]['alpha'] + beta = inputs[3]['beta'] + elif len(inputs) == 5: + alpha, beta = inputs[3:] + annos = [ + 'b m n, b m k^, b k^ n -> b m n' + ] + return IRDimops(BMMAdd, 'baddbmm', signature, annos, inputs[:3], alpha=alpha, beta=beta) + + def Matmul(signature, inputs: Tuple[IRTensor, IRTensor]): assert len(inputs) == 2 annos = [ @@ -252,6 +267,21 @@ def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: return lhs_shape, rhs_shape, out_shape +def Expand(signature, inputs): + input = inputs[0] + sizes = inputs[1:] + + edim_in = ShapeAnno.create_shape_str(input.shape) + assert len(input.shape) == len(sizes) + for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes)): + if dim == 1 and dim != expand_dim: + edim_in[idx] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(Expand, 'expand', signature, [anno], [input], sizes=sizes) + + def Clone(signature, inputs): """ torch.clone(input, *, memory_format=torch.preserve_format) @@ -461,12 +491,32 @@ def ReLU(signature, inputs): def Softmax(signature, inputs): - assert len(inputs) == 4 - annos = ['* -> *'] - tensor = inputs[0:1] - dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] - return IRDimops(Softmax, 'softmax', signature, annos, tensor, - dim=dim, _stacklevel=_stacklevel, dtype=dtype) + assert len(inputs) >= 1 + tensor = inputs[0] + edim_in = ShapeAnno.create_shape_str(tensor.shape) + if len(inputs) == 2: + if isinstance(inputs[1], dict): + edim_in[inputs[1]['dim']] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], **inputs[1]) + elif isinstance(inputs[1], int): + dim = inputs[1] + edim_in[dim] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], dim=inputs[1]) + else: + raise RuntimeError(f'Unexpect intput type {inputs[1]}, {type(inputs[1])}') + elif len(inputs) == 4: + dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] + dim = inputs[1] + edim_in[dim] += '^' + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], + dim=dim, _stacklevel=_stacklevel, dtype=dtype) + else: + raise RuntimeError('Unexpected input num {inputs}') def Dropout(signature, inputs): @@ -480,13 +530,35 @@ def Dropout(signature, inputs): p=p, training=training, inplace=inplace) +def Detach(signature, inputs): + assert len(inputs) == 1 + annos = ['* -> *'] + tensor = inputs[0:1] + return IRDimops(Detach, 'detach', signature, annos, tensor) + + +def EQ(signature, inputs): + assert len(inputs) == 2 + input0, input1 = inputs + + edim_in0 = ShapeAnno.create_shape_str(input0.shape) + edim_ou = copy.copy(edim_in0) + if isinstance(input1, (int, float)): + anno = OpAnno.create_op_str([edim_in0], [edim_ou]) + return IRDimops(EQ, 'eq', signature, [anno], [input0], other=input1) + else: + edim_in1 = copy.copy(edim_in0) + anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) + return IRDimops(EQ, 'eq', signature, [anno], [input0], other=input1) + + def NE(signature, inputs): assert len(inputs) == 2 input0, input1 = inputs edim_in0 = ShapeAnno.create_shape_str(input0.shape) edim_ou = copy.copy(edim_in0) - if isinstance(input1, float): + if isinstance(input1, (int, float)): anno = OpAnno.create_op_str([edim_in0], [edim_ou]) return IRDimops(NE, 'ne', signature, [anno], [input0], other=input1) else: @@ -509,6 +581,16 @@ def Long(signature, inputs): return IRDimops(Long, 'long', signature, annos, tensor) +def Fill(signature, inputs): + assert len(inputs) == 2 + input, value = inputs + + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Fill, 'fill', signature, [anno], [input], value=value) + + def MaskedFill(signature, inputs): assert len(inputs) == 3 input0, input1, value = inputs @@ -516,6 +598,9 @@ def MaskedFill(signature, inputs): edim_in0 = ShapeAnno.create_shape_str(input0.shape) edim_in1 = ShapeAnno.create_shape_str(input1.shape) edim_ou = copy.copy(edim_in0) + for idx, (lhs, rhs) in enumerate(zip(input0.shape, input1.shape)): + if lhs != rhs and rhs == 1: + edim_ou[idx] = '1' anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input0, input1], value=value) @@ -650,8 +735,11 @@ def View(signature, inputs): """ out = torch.Tensor.view(tensor: torch.Tensor, size: List[int]) """ - assert len(inputs) == 2 - input, shape = inputs + if len(inputs) == 2: + input, shape = inputs + else: + input = inputs[0] + shape = inputs[1:] if not all([isinstance(dim, int) for dim in shape]): raise TypeError("Expected tensor.view has static int shape") in_shape, ou_shape = list(input.shape), shape @@ -809,6 +897,25 @@ def Reshape(signature, inputs): return View(signature, inputs) + +def Permute(signature, inputs): + if isinstance(inputs[1], list): + in_tensor, dims = inputs[0], inputs[1] + else: + in_tensor, dims = inputs[0], inputs[1:] + edim_in = ShapeAnno.create_shape_str(in_tensor.shape) + for idx, dim in enumerate(dims): + if idx != dim: + edim_in[idx] += '^' + assert len(edim_in) == len(dims), f'{len(edim_in)} vs {len(dims)}' + edim_ou = [] + for dim in dims: + assert isinstance(dim, int) + edim_ou.append(copy.copy(edim_in[dim])) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Permute, 'permute', signature, [anno], [in_tensor], dims=dims) + + def Squeeze(signature, inputs): """ out = torch.squeeze(tensor) @@ -826,6 +933,7 @@ def Squeeze(signature, inputs): return IRDimops(Squeeze, 'squeeze', signature, [anno], [input]) + def Unsqueeze(signature, inputs): """ out = torch.unsqueeze(tensor, dim) @@ -841,6 +949,7 @@ def Unsqueeze(signature, inputs): return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input], dim=dim) + def TypeAs(signature, inputs): """ out = torch.Tensor.type_as(tensor0, tensor1) @@ -849,12 +958,12 @@ def TypeAs(signature, inputs): input0, input1 = inputs edim_in0 = ShapeAnno.create_shape_str(input0.shape) - edim_in1 = ShapeAnno.create_shape_str(input1.shape) edim_ou = copy.copy(edim_in0) - anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) + anno = OpAnno.create_op_str([edim_in0, '*'], [edim_ou]) return IRDimops(TypeAs, 'type_as', signature, [anno], [input0, input1]) + def Triu(signature, inputs): """ out = torch.triu(tensor, diagonal) @@ -872,6 +981,27 @@ def Triu(signature, inputs): return IRDimops(Triu, 'triu', signature, [anno], [input], diagonal=diagonal) + +def CumSum(signature, inputs): + """ + out = torch.cumsum(tensor, dim) + """ + assert len(inputs) == 2 + input, dim = inputs + if isinstance(dim, dict): + dim = dim['dim'] + else: + assert isinstance(dim, int) + + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_in[dim] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(CumSum, 'cumsum', signature, [anno], [input], + dim=dim) + + # def Pad(signature, inputs): # """ # torch.nn.functional.pad(input: torch.Tensor, pad: List[int], mode='constant', value=0.0) @@ -1026,7 +1156,10 @@ def Stack(signature, inputs: Tuple[List[IRTensor], int]): tensors, dim = inputs else: tensors, dim = inputs[:-1], inputs[-1] - assert all(isinstance(tensor, IRTensor) for tensor in tensors) + if isinstance(dim, dict): + assert 'dim' in dim + dim = dim['dim'] + assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'{tensors}' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] oannos = [copy.copy(iannos[-1])] oannos[0].insert(dim, str(len(tensors))) @@ -1063,6 +1196,25 @@ def Select(signature, inputs: Tuple[IRTensor, int, int]): return IRDimops(Select, 'select', signature, [anno], [tensor], dim=dim, index=index) +def IndexSelect(signature, inputs): + assert len(inputs) == 3 + # hack + if isinstance(inputs[1], int): + tensor, dim, idx = inputs + else: + assert isinstance(inputs[2], int) + tensor, idx, dim = inputs + + edim_in = ShapeAnno.create_shape_str(tensor.shape) + edim_in[dim] += '^' + idx_anno = chr(ord(edim_in[-1]) + 1) + '^' + edim_ou = copy.copy(edim_in) + edim_ou[dim] = copy.copy(idx_anno) + anno = OpAnno.create_op_str([edim_in, idx_anno], [edim_ou]) + + return IRDimops(IndexSelect, 'index_select', signature, [anno], [tensor, idx], dim=dim) + + def Slice(signature, inputs): """ aten::slice(input:Tensor, dim:int, start:Optional[int], end:Optional[int], step:int) -> Tensor @@ -1152,7 +1304,10 @@ def Embedding(signature, inputs: List): def Flatten(signature, inputs: List): tensor: IRTensor = inputs[0] - start_dim, end_dim = inputs[1:] + if len(inputs) == 1: + start_dim, end_dim = 0, len(tensor.shape) - 1 + else: + start_dim, end_dim = inputs[1:] end_dim = len(tensor.shape) + end_dim if end_dim < 0 else end_dim ishape = ShapeAnno.create_shape_str(tensor.shape) for dim in range(start_dim, end_dim+1): @@ -1275,3 +1430,14 @@ def CompareLE(signature, inputs): torch.gt(input, other, *, out=None) -> Tensor """ return _comparison(CompareLE, operator.le, 'le', signature, inputs) + + +def ShapeAsTensor(signature, inputs): + assert len(inputs) == 1 + input = inputs[0] + + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = [str(len(input.shape))] + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + return IRDimops(ShapeAsTensor, '_shape_as_tensor', signature, [anno], [input]) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 055fd17f..bfa5a460 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -69,11 +69,29 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('unsqueeze'): function.Unsqueeze, __tttemplate('type_as'): function.TypeAs, __ttemplate('triu'): function.Triu, - __ftemplate('relu') : function.ReLU, + __ftemplate('relu'): function.ReLU, + __fcntemplate('gelu'): function.GeLU, + __ttemplate('eq') : function.EQ, __ttemplate('ne') : function.NE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, + __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, + __ttemplate('cumsum'): function.CumSum, + __ttemplate('tanh'): function.Tanh, + __ftemplate('softmax') : function.Softmax, + __ttemplate('bmm') : function.BatchLinear, + __ttemplate('pow'): function.Pow, + __ttemplate('baddbmm'): function.BMMAdd, + __ttemplate('permute'): function.Permute, + __ttemplate('transpose'): function.Transpose, + __tttemplate('expand'): function.Expand, + __ttemplate('detach'): function.Detach, + __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, + __ttemplate('index_select'): function.IndexSelect, + + # TODO + __ftemplate('layer_norm'): function.LayerNorm, __ftemplate('embedding'): function.Embedding, # # torch nn functional @@ -82,10 +100,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # # __ttemplate('matmul'): function.Matmul, # - # __ftemplate('softmax') : function.Softmax, - # - # __ftemplate('dropout') : function.Dropout, - # # __ftemplate('gelu') : function.GeLU, # __ttemplate('gelu') : function.GeLU, # @@ -94,8 +108,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # # __ftemplate('_pad'): function.Pad, # - # __ftemplate('layer_norm'): function.LayerNorm, - # # __ftemplate('embedding'): function.Embedding, # # __ftemplate('cross_entropy'): function.CrossEntropy, @@ -115,36 +127,31 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('sub') : function.Sub, # # __ttemplate('mul') : function.Mul, - # - # __ttemplate('div') : function.Div, - # - # __ttemplate('floordiv') : function.FloorDiv, - # - # __ttemplate('neg'): function.Neg, - # - # __ttemplate('gt'): function.CompareGT, - # __ttemplate('lt'): function.CompareLT, - # __ttemplate('ge'): function.CompareGE, - # __ttemplate('le'): function.CompareLE, - # - # __ttemplate('pow'): function.Pow, + + __ttemplate('div') : function.Div, + __ttemplate('truediv'): function.Div, + __ttemplate('true_divide'): function.Div, + __ttemplate('floordiv') : function.FloorDiv, + __ttemplate('floor_divide') : function.FloorDiv, + + __ttemplate('neg'): function.Neg, + # + __ttemplate('gt'): function.CompareGT, + __ttemplate('lt'): function.CompareLT, + __ttemplate('ge'): function.CompareGE, + __ttemplate('le'): function.CompareLE, # # __ttemplate('sin'): function.Sin, # # __ttemplate('cos'): function.Cos, # - # __ttemplate('tanh'): function.Tanh, - # - # __ttemplate('bmm') : function.BatchLinear, - # # __ttemplate('sum') : function.Sum, # __ttemplate('mean') : function.Mean, # - # __ttemplate('transpose') : function.Transpose, - # # __ttemplate('view'): function.View, - # - # __ttemplate('reshape'): function.Reshape, + __tttemplate('view'): function.View, + + __ttemplate('reshape'): function.Reshape, # # __ttemplate('conv2d'): function.Conv2D, # @@ -165,12 +172,12 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('linear'): function.Linear, # # __ttemplate('cat'): function.Cat, - # - # __ttemplate('stack'): function.Stack, + + __ttemplate('stack'): function.Stack, # # __ttemplate('chunk'): function.Chunk, - # - # __ttemplate('flatten'): function.Flatten, + + __ttemplate('flatten'): function.Flatten, # # __ttemplate('roll'): function.Roll, # diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index ab3a463b..f654cd62 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -47,9 +47,9 @@ def infer(node, dtypes: List[IRDType]) -> IRDType: if IRDType.float32 in dtypes and IRDType.float16 in dtypes: raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") # TODO(yizhu1): hack - if node.signature == 'torch.ne': + if node.signature in ('torch.ne', 'torch.eq', 'torch.gt'): return IRDType.boolean - elif node.signature == 'torch.Tensor.long': + elif node.signature in ('torch.Tensor.long', 'torch._shape_as_tensor'): return IRDType.int64 # in priority: fp32 > fp16 > bool > int64 > int16 > priority = [ diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 0054005a..2798631b 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -633,7 +633,7 @@ def grad(self) -> bool: @grad.setter def grad(self, val: Optional[IRTensor]): if isinstance(val, IRSubTensor): - assert self.requires_grad and val.shape == self.shape + assert self.requires_grad and val.shape == self.shape, f'info: {self.requires_grad} {val.shape == self.shape}' self._grad = val elif val is None: assert not self.requires_grad diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 15e288de..233400f5 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -60,9 +60,18 @@ def __init__(self, dim, mult=1, nlayers=4): def forward(self, data, mask): x = data.masked_fill(mask, 0.0) + y = torch._shape_as_tensor(x) + z = torch.gt(x, x) + x = x.fill_(0.0) + x = torch.nn.functional.softmax(x, dim=-1) + x = torch.bmm(x, x) + x = torch.baddbmm(x, x, x, alpha=0.125, beta=1.0) + x = torch.tanh(x) + x = torch.pow(x, x) for layer in self.layers: x = layer(x) x = torch.nn.functional.relu(x) + x = torch.nn.functional.gelu(x) # x = self.layer_norm(x) x = x.type_as(data) x = x.unsqueeze(1) @@ -70,13 +79,30 @@ def forward(self, data, mask): x = x.squeeze() x = torch.triu(x, 1) x = torch.nan_to_num(x) - # ne cannot backward - # x = torch.ne(x, 1.0) - # x = torch.nn.functional.dropout(x, self.p) - # x = x * self.y + ne_var = x.detach() + ne_var = torch.ne(ne_var, 1.0) + eq_var = x.detach() + eq_var = torch.eq(eq_var, 1.0) + long_var = x.detach() + long_var = long_var.long() + floor_div_var = x.detach() + floor_div_var = torch.floor_divide(floor_div_var, 2.0) + x = torch.true_divide(x, 1.0) + x = torch.cumsum(x, dim=-1) + x = x.permute(0, 2, 1) + x = x.transpose(1, 2) + x = torch.div(x, 1.0) + # concrete_trace not support + # x = torch.Tensor.view(x, [32 * 1024, 1024]) + x = x.view(32 * 1024, 1024) + x = x.reshape(32, 1024, 1024) + # indices = torch.arange(4, dtype=torch.int64) + # x = torch.index_select(x, 1, indices) + p = torch.div(x, 2.0) + x = torch.stack((x, p), dim=1) + x = torch.flatten(x, 2, 3) + x = torch.neg(x) loss = torch.sum(x) - # long cannot backward - # loss = loss.long() return loss From 8a8eb9a992383005381d803140fcff04ef36bfdb Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sun, 5 Mar 2023 02:18:01 -0800 Subject: [PATCH 1308/1892] minor --- cube/graph/parser/mappingfx.py | 6 +++--- cube/graph/parser/parserfx.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 128a2791..0b2c2c7f 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -75,14 +75,14 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, __ttemplate('fill_'): function.Fill, - __ttemplate('masked_fill'): function.MaskedFill, + # __ttemplate('masked_fill'): function.MaskedFill, __ftemplate('embedding'): function.Embedding, - __ttemplate('cumsum'): function.CumSum, + # __ttemplate('cumsum'): function.CumSum, __ttemplate('tanh'): function.Tanh, __ftemplate('softmax') : function.Softmax, __ttemplate('bmm') : function.BatchLinear, __ttemplate('pow'): function.Pow, - __ttemplate('baddbmm'): function.BMMAdd, + # __ttemplate('baddbmm'): function.BMMAdd, __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 3733dfa1..905ce85f 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -276,6 +276,7 @@ def extract_val(fx_node): # map to IR operator if SignFx2Op.exist(fsig): + print(f'zql input_vals: {input_vals}') ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: #input_vals = [extract_val(v) for v in node.args] From a66b23c5b808a4f5bbb22234e068b08f741f1a63 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 5 Mar 2023 04:16:06 -0800 Subject: [PATCH 1309/1892] fix bug, support repeat and einsum --- cube/codegen/frontend_mapping.py | 7 +++++++ cube/graph/function/function.py | 21 ++++++++++++++++++++- cube/graph/parser/mappingfx.py | 5 +++-- examples/mlp/linearsfx.py | 11 ++++++++--- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index 37d3cba0..c145dd0f 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -174,6 +174,12 @@ def emit_index_select(node, arg_vars:list, kw_pairs:dict) -> str: dim = kw_pairs['dim'] return f'{node.signature}({arg_vars[0]}, {dim}, {arg_vars[1]})' +def emit_einsum(node, arg_vars:list, kw_pairs:dict) -> str: + assert 'equation' in kw_pairs + equation = kw_pairs['equation'] + args_str = ', '.join(arg_vars) + return f'{node.signature}({equation}, {args_str})' + class Sign2EmitRule: @@ -206,6 +212,7 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: 'torch.rand': emit_rand, 'torch.tensor': emit_new_tensor, 'torch.index_select': emit_index_select, + 'torch.functional.einsum': emit_einsum, 'setattr': emit_setattr, } diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 74088053..9d0f80c5 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -60,6 +60,21 @@ def BMMAdd(signature, inputs): return IRDimops(BMMAdd, 'baddbmm', signature, annos, inputs[:3], alpha=alpha, beta=beta) +def EinSum(signature, inputs): + if isinstance(inputs[0], str): + equation, tensors = inputs[0], inputs[1:] + else: + tensors, equation = inputs[:-1], inputs[-1] + lhs, rhs = equation.split('->') + assert ',' not in rhs + lhs_dims = set(lhs.replace(',', ' ').split(' ')) + for dim in lhs_dims: + if dim not in rhs: + lhs = lhs.replace(dim, f'{dim}+') + anno = f'{lhs} -> {rhs}' + return IRDimops(EinSum, 'einsum', signature, [anno], tensors, equation=equation) + + def Matmul(signature, inputs: Tuple[IRTensor, IRTensor]): assert len(inputs) == 2 annos = [ @@ -1265,7 +1280,11 @@ def Repeat(signature, inputs: Tuple[IRTensor, List[int]]): torch.repeat(tensor:Tensor, repeats: List[int]) -> Tensor """ signature = 'torch.ops.aten.repeat' - tensor, repeats = inputs + tensor = inputs[0] + if isinstance(inputs[1], list): + repeats = inputs[1] + else: + repeats = inputs[1:] in_shape = tensor.shape assert len(in_shape) <= len(repeats), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor" expand = len(repeats) - len(tensor.shape) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index bfa5a460..536059e5 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -89,10 +89,11 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('detach'): function.Detach, __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, __ttemplate('index_select'): function.IndexSelect, + __ftemplate('embedding'): function.Embedding, + 'torch.functional.einsum': function.EinSum, # TODO __ftemplate('layer_norm'): function.LayerNorm, - __ftemplate('embedding'): function.Embedding, # # torch nn functional # @@ -166,7 +167,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # #pytorch1.11 # __ttemplate('select_scatter'): function.SelectScatter, # - # __ttemplate('repeat'): function.Repeat, + __tttemplate('repeat'): function.Repeat, # # #pytorch1.11 # __ttemplate('linear'): function.Linear, diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 233400f5..3cac42ed 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -73,7 +73,8 @@ def forward(self, data, mask): x = torch.nn.functional.relu(x) x = torch.nn.functional.gelu(x) # x = self.layer_norm(x) - x = x.type_as(data) + type_x = torch.pow(x, 1.0) + x = x.type_as(type_x) x = x.unsqueeze(1) x = self.drop_out(x) x = x.squeeze() @@ -96,12 +97,16 @@ def forward(self, data, mask): # x = torch.Tensor.view(x, [32 * 1024, 1024]) x = x.view(32 * 1024, 1024) x = x.reshape(32, 1024, 1024) - # indices = torch.arange(4, dtype=torch.int64) + neg_x = torch.neg(x) + x = torch.einsum('a b c, a c d -> a b d', x, neg_x) + # TODO(yizhu1): uncomment and check + # bs = x.size(1) + # indices = torch.arange(bs, dtype=torch.int64) # x = torch.index_select(x, 1, indices) p = torch.div(x, 2.0) x = torch.stack((x, p), dim=1) x = torch.flatten(x, 2, 3) - x = torch.neg(x) + x = x.repeat(1, 2, 1) loss = torch.sum(x) return loss From 46adde7ee0ea8dffe8f8fd943fdc8d10afb4bac5 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sun, 5 Mar 2023 05:50:27 -0800 Subject: [PATCH 1310/1892] support layer_norm, constant folding of get_attr --- cube/graph/function/function.py | 18 +++++++++++------- cube/graph/parser/mappingfx.py | 10 ++++------ cube/graph/parser/parserfx.py | 12 +++++++----- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 74088053..143cca71 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -612,28 +612,32 @@ def LayerNorm(signature, inputs): """ if 'torch.' in signature: tensor, normalized_shape, weight, bias, eps = inputs - assert isinstance(normalized_shape, list), f"normalized_shape for layer_norm can only be List[int]" + # FIXME: uncomment the assert + assert isinstance(normalized_shape, (list, tuple, torch.Size)), \ + f"normalized_shape for layer_norm can only be tuple or list or torch.Size, NOT {type(normalized_shape)}" else: + assert 'cube.runtime.function.layer_norm' == signature, f'{signature} of LayerNorm is not supported.' tensor, weight, bias, normalized_shape, eps = inputs letters = iter(string.ascii_lowercase) einput = ShapeAnno.create_shape_str(tensor.shape, iterator=letters) eoutput = copy.copy(einput) ndims = len(tensor.shape) - for dim in range(len(normalized_shape)): + ndims_normshape = len(normalized_shape) + for dim in range(ndims_normshape): + # though these dimensions can be partitioned, + # such partition induces additional communication and complexity einput[ndims-1-dim] += '^' eoutput[ndims-1-dim] += '^' - assert not (bias is None is weight is not None), f"Not support for None of weight and parameter of bias" + assert not (bias is not None and weight is None), f"Not support for None of weight and parameter of bias" einputs, inputs = [einput], [tensor] kwargs = {} if weight is not None: - eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) - einputs.append(eweight) + einputs.append(einput[ndims-ndims_normshape:]) inputs.append(weight) else: kwargs['weight'] = weight if bias is not None: - ebias = ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) - einputs.append(ebias) + einputs.append(einput[ndims-ndims_normshape:]) inputs.append(bias) else: kwargs['bias'] = bias diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 74d82fe9..82e1888e 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -76,21 +76,19 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, __ttemplate('fill_'): function.Fill, - __ttemplate('masked_fill'): function.MaskedFill, + # __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, __ttemplate('tanh'): function.Tanh, __ftemplate('softmax') : function.Softmax, __ttemplate('bmm') : function.BatchLinear, __ttemplate('pow'): function.Pow, - __ttemplate('baddbmm'): function.BMMAdd, + # __ttemplate('baddbmm'): function.BMMAdd, __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, __tttemplate('expand'): function.Expand, __ttemplate('detach'): function.Detach, __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, __ttemplate('index_select'): function.IndexSelect, - - # TODO __ftemplate('layer_norm'): function.LayerNorm, # # torch nn functional @@ -148,9 +146,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('mean') : function.Mean, # # __ttemplate('view'): function.View, - __tttemplate('view'): function.View, + # __tttemplate('view'): function.View, - __ttemplate('reshape'): function.Reshape, + # __ttemplate('reshape'): function.Reshape, # # __ttemplate('conv2d'): function.Conv2D, # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 905ce85f..953508bb 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -154,7 +154,8 @@ def parse(module: torch.fx.GraphModule, all_ir_nodes: List[IRFwOperation] = list() for node in module.graph.nodes: ir_nodes = FxModuleParser.parse_node(node, module, frame) - all_ir_nodes += ir_nodes + if ir_nodes is not None: + all_ir_nodes += ir_nodes #output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] # handle outputs @@ -318,11 +319,12 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram else: # FIXME: why no need to record the constant value of this var? # the value can be obtained below: - # var = FxModuleParser.fetch_attr(module, node.target) - print(f'WARNING: {node.name} {node.meta} in attr node uses empty IRObject!') - frame.add_var(tensor_name, IRObject()) + var = FxModuleParser.fetch_attr(module, node.target) + frame.add_var(tensor_name, var) + # print(f'WARNING: {node.name} {node.meta} in attr node uses empty IRObject!') + # frame.add_var(tensor_name, IRObject()) - return list() + return None @staticmethod def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: From 3164d932bb040ee615b304a31061cc936567faa1 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sun, 5 Mar 2023 05:53:52 -0800 Subject: [PATCH 1311/1892] minor --- cube/graph/function/function.py | 1 - cube/graph/parser/parserfx.py | 5 ----- 2 files changed, 6 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 143cca71..2a315d0c 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -612,7 +612,6 @@ def LayerNorm(signature, inputs): """ if 'torch.' in signature: tensor, normalized_shape, weight, bias, eps = inputs - # FIXME: uncomment the assert assert isinstance(normalized_shape, (list, tuple, torch.Size)), \ f"normalized_shape for layer_norm can only be tuple or list or torch.Size, NOT {type(normalized_shape)}" else: diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 953508bb..a37583b3 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -277,7 +277,6 @@ def extract_val(fx_node): # map to IR operator if SignFx2Op.exist(fsig): - print(f'zql input_vals: {input_vals}') ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: #input_vals = [extract_val(v) for v in node.args] @@ -317,12 +316,8 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram ir_tensor.as_param() frame.add_var(tensor_name, ir_tensor) else: - # FIXME: why no need to record the constant value of this var? - # the value can be obtained below: var = FxModuleParser.fetch_attr(module, node.target) frame.add_var(tensor_name, var) - # print(f'WARNING: {node.name} {node.meta} in attr node uses empty IRObject!') - # frame.add_var(tensor_name, IRObject()) return None From 083d4f7fc51ff1c882bf2a3e5178c9253e52019e Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 5 Mar 2023 06:09:38 -0800 Subject: [PATCH 1312/1892] refine code --- cube/graph/parser/mappingfx.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 66794a61..536059e5 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -89,9 +89,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('detach'): function.Detach, __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, __ttemplate('index_select'): function.IndexSelect, - - # TODO - __ftemplate('layer_norm'): function.LayerNorm, __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, From c70dad6c1b5e06f51aa3ec432a3ca4510b6fe42f Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Mon, 6 Mar 2023 01:18:28 +0000 Subject: [PATCH 1313/1892] Merged PR 1476: add torchscale tester besides linearfx add examples/nlp/torchscale/basic_test.py which contains more torchscale usage. Users can perform more complex tests than linearfx without installing extra dependencies, e.g., TorchScale, Fairseq --- .../{lm_fx_test.py => run_torchscale_lm.py} | 2 +- .../{fx_test.py => run_torchscale_tl.py} | 2 +- tests/parser/test_torchscale_basic.py | 84 +++++++++++++++++++ 3 files changed, 86 insertions(+), 2 deletions(-) rename examples/nlp/torchscale/{lm_fx_test.py => run_torchscale_lm.py} (89%) rename examples/nlp/torchscale/{fx_test.py => run_torchscale_tl.py} (94%) create mode 100644 tests/parser/test_torchscale_basic.py diff --git a/examples/nlp/torchscale/lm_fx_test.py b/examples/nlp/torchscale/run_torchscale_lm.py similarity index 89% rename from examples/nlp/torchscale/lm_fx_test.py rename to examples/nlp/torchscale/run_torchscale_lm.py index 382174e6..3813681b 100644 --- a/examples/nlp/torchscale/lm_fx_test.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -1,4 +1,4 @@ -# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/lm_fx_test.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData +# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData import torch import pickle diff --git a/examples/nlp/torchscale/fx_test.py b/examples/nlp/torchscale/run_torchscale_tl.py similarity index 94% rename from examples/nlp/torchscale/fx_test.py rename to examples/nlp/torchscale/run_torchscale_tl.py index 20588758..f7f58ff3 100644 --- a/examples/nlp/torchscale/fx_test.py +++ b/examples/nlp/torchscale/run_torchscale_tl.py @@ -1,4 +1,4 @@ -# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/fx_test.py examples/nlp/torchscale/input --arch mt_base --share-decoder-input-output-embed --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --dropout 0.3 --weight-decay 0.0001 --max-tokens 4096 --fp16 --policy PASData +# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/run_torchscale_tl.py examples/nlp/torchscale/input --arch mt_base --share-decoder-input-output-embed --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --dropout 0.3 --weight-decay 0.0001 --max-tokens 4096 --fp16 --policy PASData import torch import pickle diff --git a/tests/parser/test_torchscale_basic.py b/tests/parser/test_torchscale_basic.py new file mode 100644 index 00000000..6b91cd4d --- /dev/null +++ b/tests/parser/test_torchscale_basic.py @@ -0,0 +1,84 @@ +# OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 --master_port=25648 tests/parser/test_torchscale_basic.py --policy PASData + +import torch +from torch import nn + +import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + +import examples.mlp.policy.spmd as spmd +import examples.mlp.policy.mpmd as mpmd + +import argparse + +parser = argparse.ArgumentParser(description='comm primitive') +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +parser.add_argument('--local_rank', type=int, default=0) +args = parser.parse_args() + +cube.init() + +# set up policy +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + +class SimpleNLP(nn.Module): + def __init__(self): + super().__init__() + self._tensor_constant0 = 1 + self.linear = torch.nn.Linear(2, 3) + + def forward(self, src_tokens, num): + _shape_as_tensor = torch._shape_as_tensor(src_tokens) + getitem_1 = _shape_as_tensor[1] + add = 2 + getitem_1 + arange = torch.arange(add, dtype=torch.float32) + unsqueeze = arange.unsqueeze(1) + _tensor_constant0 = self._tensor_constant0 + mul = unsqueeze * _tensor_constant0 + sin = torch.sin(mul) + cos = torch.cos(mul) + cat = torch.cat([sin, cos], dim=1) + view = cat.view(add, -1) + linear = self.linear(view) + return linear + +def run(): + dataloader = cube.runtime.syndata.SynDataLoader( + shapes=([4, 16], [2],), + dtypes=(torch.int64, torch.int64,), + batch_dims=(0, 0,) + ) + + sample_input = next(dataloader) + print(f'next(dataloader) = {sample_input}') + + model = SimpleNLP() + output = model(*sample_input) + print(f'output = {output}') + + device = next(model.parameters()).device + sample_input = next(dataloader) + sample_input_cpu = tuple([input.to(device) for input in sample_input]) + model = cube.SemanticModel( + model, dummy_input=sample_input_cpu, + ) + + # @cube.compile(model, dataloader, PAS=PAS, load_content=False) + def train_iter(model, dataloader): + data = next(dataloader) + out = model(*data) + return out + + train_iter(model, dataloader) + +run() \ No newline at end of file From 16df9ba57c098b5acaa86e1e0a69c480621151eb Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 6 Mar 2023 05:18:41 +0000 Subject: [PATCH 1314/1892] Merged PR 1475: constant folding frame function.py can potentially return a concrete value instead of IRCell. In this case, vars in frame will be reset to the concrete value so that the following consumers will take the concrete value instead an IRObject --- cube/graph/function/function.py | 32 +++++++++++++++++- cube/graph/gener/gen.py | 11 +++--- cube/graph/graph.py | 8 ++--- cube/graph/parser/frame.py | 10 ++++++ cube/graph/parser/mappingfx.py | 12 +++++-- cube/graph/parser/parserfx.py | 53 ++++++++++++----------------- cube/graph/segment.py | 5 ++- cube/ir/cten.py | 10 +++--- cube/ir/operator.py | 9 +++-- cube/program.py | 59 ++++++++++++++++++--------------- tests/parser/test_fx_ops.py | 1 + 11 files changed, 125 insertions(+), 85 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 9d0f80c5..ec673c05 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -5,9 +5,10 @@ import warnings import operator -from cube.ir.cten import IRTensor +from cube.ir.cten import IRTensor, IRObject from cube.ir.tensor import IRSubTensor from cube.ir.dtype import IRDType +from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D from cube.graph.function.creators import IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor @@ -1460,3 +1461,32 @@ def ShapeAsTensor(signature, inputs): anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(ShapeAsTensor, '_shape_as_tensor', signature, [anno], [input]) + + + +# ================== Non-autograd Function Space ================= + +def Size(signature, inputs) -> Union[List[int], IRPyFunc]: + """ + torch.Tensor.size(tensor, dim=None) + """ + assert len(inputs) == 2 or len(inputs) == 1, f"but got {len(inputs)}, {inputs}" + tensor, dim = inputs if len(inputs) == 2 else (inputs[0], None) + assert isinstance(tensor, IRTensor) + # constant + if all(isinstance(dimlen, int) for dimlen in tensor.shape) and not isinstance(dim, IRObject): + return tensor.shape[dim] if isinstance(dim, int) else list(tensor.shape) + return IRPyFunc(signature, inputs, [IRObject()]) + + +def GetItem(signature, inputs) -> Union[Any, IRPyFunc]: + """ + _operator.getitem(obj, index: int) + """ + assert len(inputs) == 2, f"but got {inputs}" + obj, index = inputs + if (not isinstance(obj, IRObject)) and isinstance(index, int): + return obj[index] + else: + return IRPyFunc(signature, inputs, [IRObject()]) + \ No newline at end of file diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index ce7b8ccd..ac0fbc3d 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -9,7 +9,7 @@ from cube.graph.segment import IRSegment from cube.graph.function.pyfunc import IRPyFunc -from cube.ir.cten import IRCell +from cube.ir.cten import IRCell, IRObject from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.operator import IRFwOperation @@ -147,15 +147,14 @@ def remove_anchor(graph: IRSegment): def auto_pyfunc(graph: IRSegment): """ Make pyfunc to be local + IRPyFunc will be replicated to devices with its producers output """ for func in graph.select(ntype=IRPyFunc, flatten=False): assert func.mirror is None, "PyFunc is only supported by inference" - assert all(not isinstance(t, IRSubTensor) for t in func.outputs()), \ - "PyFunc doesn't support tensor outputs" # get devices it will lowered to devices = set() for t in func.inputs(): - if not isinstance(t, IRSubTensor): continue + if not isinstance(t, IRObject): continue producers = graph.producers(t.parent) for p in producers: devices.update(p.device) @@ -164,13 +163,13 @@ def auto_pyfunc(graph: IRSegment): for devid in devices: inputs = [] for t in func.inputs(): - if isinstance(t, IRSubTensor): + if isinstance(t, IRObject): if t.is_attr(): tensors = set(tensor for tensor in graph.ctensors(t.parent) if devid in tensor.device and tensor.cell != func) else: tensors = set(tensor for tensor in graph.ptensors(t.parent) if devid in tensor.device) assert len(tensors) == 1, \ - f"Find {len(tensors)} != 1 versions of tensor {t} on a same device." + f"Find {len(tensors)} != 1: {tensors} versions of tensor {t} on a same device." t = list(tensors)[0] inputs.append(t) lower_func = IRPyFunc(func.signature, inputs, func.outputs(), **func.kwargs) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index c8214e19..5528dd4b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -251,12 +251,12 @@ def from_logic_graph(nodes: List[IRCell], # instantiate to subtensor for node in nodes: for idx, ftensor in enumerate(node.inputs()): - if isinstance(ftensor, IRFullTensor): - subtensor = ftensor.tosub() + if isinstance(ftensor, IRObject): + subtensor = ftensor.tosub() if isinstance(ftensor, IRFullTensor) else ftensor node.set_input(idx, subtensor) for idx, ftensor in enumerate(node.outputs()): - if isinstance(ftensor, IRFullTensor): - subtensor = ftensor.tosub() + if isinstance(ftensor, IRObject): + subtensor = ftensor.tosub() if isinstance(ftensor, IRFullTensor) else ftensor node.set_output(idx, subtensor) graph = IRGraph(nodes, inputs, outputs, module_name) return graph diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index d3b82936..29f71419 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -75,6 +75,16 @@ def add_var(self, var_name: str, val: Any, graph_arg: int = -1): self._vars[-1][var_name] = val else: raise ValueError("graph_arg (int) must be >= 0") + + def set_var(self, var_name: str, val: Any): + """ + Reset a variable with arbitrary value. + If `var_name` doesn't exist, will create a new one + + @param var_name str: variable name + @param val Any + """ + self._vars[-1][var_name] = val def get_var(self, var_name: str) -> Any: """ diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 536059e5..e78ce876 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -14,7 +14,9 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: """ Map the signature to GenericLogicalOp """ - if 'torch.' not in signature and 'cube.runtime.' not in signature: + bultin_regions = ['torch.', 'cube.runtime.', '_operator.'] + # customized function + if all(not signature.startswith(region) for region in bultin_regions): signature = signature.split('.')[-1] if signature in SignFx2Op.kOpMap: function = SignFx2Op.kOpMap[signature] @@ -26,7 +28,9 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: @staticmethod def exist(signature: str) -> bool: - if 'torch.' not in signature and 'cube.runtime.' not in signature: + bultin_regions = ['torch.', 'cube.runtime.', '_operator.'] + # customized function + if all(not signature.startswith(region) for region in bultin_regions): signature = signature.split('.')[-1] return signature in SignFx2Op.kOpMap @@ -95,6 +99,10 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # TODO __ftemplate('layer_norm'): function.LayerNorm, + # ============== runtime function ================= + __tttemplate('size'): function.Size, + '_operator.getitem': function.GetItem, + # # torch nn functional # # __ftemplate('linear') : function.Linear, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 3733dfa1..862f1393 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -11,6 +11,7 @@ from cube.graph.parser.mapping import DType2IRDType from cube.graph.parser.mappingfx import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc +from cube.graph.function import Identity import torch.fx @@ -156,22 +157,8 @@ def parse(module: torch.fx.GraphModule, ir_nodes = FxModuleParser.parse_node(node, module, frame) all_ir_nodes += ir_nodes - #output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] - # handle outputs - output_nodes = [node.all_input_nodes for node in module.graph.nodes if node.op == 'output'] - print(f'outputs = {output_nodes}') - output_var_name = [output.name for output in [item for sublist in output_nodes for item in sublist]] - output_val = [frame.get_var(var_name) for var_name in output_var_name] - - # flatten output_val - outputs = list() - for val in output_val: - if isinstance(val, list): - outputs += val - else: - outputs.append(val) - output_val = outputs - + output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + frame.pop_var() frame.pop_attr() if FxModuleParser.save_content: @@ -278,28 +265,31 @@ def extract_val(fx_node): if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: - #input_vals = [extract_val(v) for v in node.args] - #kwargs = {key: extract_val(v) for key, v in node.kwargs.items()} + input_vals = [extract_val(v) for v in node.args] + # FIXME: handle cases for IRObject in kwargs + kwargs = {key: extract_val(v) for key, v in node.kwargs.items()} # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): print(f'>>> Find unkown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fname ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) - #ir_node.kwargs = kwargs + ir_node.kwargs = kwargs for idx, t in enumerate(input_vals): ir_node.set_input(idx, t) # case2: python runtime function else: print(f'>>> Set python runtime function: {fsig}') - #ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) - ir_node = IRPyFunc(fsig, input_vals, [None]) - - # TODO gracefully set output - output_name = node.name - output_val = frame.get_var(output_name) - ir_node.set_output(0, output_val) - - return [ir_node] + ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) + + if isinstance(ir_node, IRCell): + # TODO gracefully set output + output_name = node.name + output_val = frame.get_var(output_name) + ir_node.set_output(0, output_val) + return [ir_node] + else: + frame.set_var(node.name, ir_node) + return [] @staticmethod def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: @@ -351,10 +341,9 @@ def generate_outputs(val: Any, _ops: List) -> IRObject: if isinstance(val, torch.fx.Node): return frame.get_var(val.name) return val - - generate_outputs(node.args[0], ir_nodes) - if len(ir_nodes) > 0: - ir_nodes[-1].set_output(0, frame.get_var(node.name)) + + output = generate_outputs(node.args[0], ir_nodes) + frame.set_var(node.name, output) return ir_nodes # # NOTE: this is a function in torch.fx diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 5bdfc27c..7c66bf8c 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -957,8 +957,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: continue # loss doesn't have consumers if len(segment.consumers(ftensor)) == 0: - # TODO: loss judgement should be more robust - if ftensor.nelement() == 1: + if isinstance(ftensor, IRFullTensor) and ftensor.is_loss(): outputs.add(otensor) continue # for outside consumers @@ -1030,7 +1029,7 @@ def order(tensors: Set[IRObject]) -> Tuple[IRObject]: def __repr__(self): fw = 'f' if self.isfw() else 'b' - inputs = tuple(t for t in self.inputs() if isinstance(t, IRObject) and not t.is_param()) + inputs = tuple(t for t in self.inputs() if isinstance(t, IRObject) and not t.is_attr()) if self.isfw(): dscp = f"{fw}Graph{self.cid}-{self.device}(inputs={inputs}, outputs={self.outputs()})" else: diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 50fbfced..c636be1e 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -265,7 +265,7 @@ def set_input(self, input_index: int, val): raise RuntimeError( f"Set the input out of range ({input_index} >= {c} or {input_index} < {-c})" ) - if isinstance(val, IRTensor): + if isinstance(val, IRObject): # copy the val val = copy.copy(val) # set tensor dst @@ -298,7 +298,7 @@ def set_output(self, output_index: int, val): raise RuntimeError( f"Set the input out of range ({output_index} >= {c} or {output_index} < {-c})" ) - if isinstance(val, IRTensor): + if isinstance(val, IRObject): val = copy.copy(val) val.cell = self @@ -430,7 +430,7 @@ def __repr__(self) -> str: class IRObject: """ - IRObject serves as non-tensor inputs/outputs for IRCell. + IRObject serves as general data of IRGraph edge """ def __init__(self, name: Optional[str] = None, tid: Optional[int] = None): @@ -439,7 +439,7 @@ def __init__(self, name: Optional[str] = None, tid: Optional[int] = None): @param tid int: object unique id """ self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() - self.name: str = name if name else 'tensor' + self.name: str = name if name else 'obj' self._cell: Optional[IRCell] = None self._is_attr: bool = False @@ -522,7 +522,7 @@ def __repr__(self): class IRTensor(IRObject): """ - IRTensor serves as IRGraph edge + IRTensor serves as tensor data of IRGraph edge Note by setting IRTensor name to "None" indicates this tensor holds nothing and will be translated to None in code generation. diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 4fabc785..fa234bef 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -1,7 +1,7 @@ from typing import Optional, Tuple, Any, Union import copy -from cube.ir.cten import IRCell, IRTensor +from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.tensor import IRFullTensor from cube.algorithm.factory import DistAlgorithmFactory from cube.algorithm.generics import GenericDistAlgo @@ -128,17 +128,16 @@ def replicate(self): def __repr__(self) -> str: sign = self.signature.split('.')[-1] - ins = [t for t in self.inputs() if isinstance(t, IRTensor) and not t.is_attr()] dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " - f"inputs={ins}, " + f"inputs={self.inputs()}, " f"outputs={self.outputs()})") return dscp def extra_repr(self) -> str: sign = self.signature.split('.')[-1] - ins = [t for t in self.inputs()] + # ins = [t for t in self.inputs()] dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " - f"inputs={ins}, " + f"inputs={self.inputs()}, " f"outputs={self.outputs()})") return dscp diff --git a/cube/program.py b/cube/program.py index 948f42e3..d165b4ca 100644 --- a/cube/program.py +++ b/cube/program.py @@ -43,10 +43,8 @@ def add_nodes(self, nodes: List[IRCell]): def get_graph(self) -> IRGraph: return self.instance._graph - def set_output(self, outputs: List[IRTensor]): - for otensor in outputs: - if not isinstance(otensor, IRTensor): - raise NotImplementedError("Not support for non-tensor graph output") + def set_output(self, outputs: List[IRObject]): + assert all(isinstance(t, IRObject) for t in outputs) self.instance._graph.reset_outputs(len(outputs)) for idx, otensor in enumerate(outputs): self.instance._graph.set_output(idx, otensor) @@ -60,7 +58,6 @@ def finalize(self): if not any(isinstance(node, IRBpOperation) for node in graph.nodes()): for ftensor in graph.full_tensors(): ftensor.requires_grad = False - def mirror_as_self(self): """ @@ -81,12 +78,6 @@ def __init__(self, dataloader: CubeDataLoader): if not isinstance(dataloader, CubeDataLoader): raise TypeError("Expected data loader derived from CubeDataLoader") self.dataloader: CubeDataLoader = iter(dataloader) - dtype_map = DType2IRDType - sample = next(dataloader) - if not isinstance(sample, tuple): - sample = (sample,) - self.dtypes = [dtype_map.map(t.dtype) if torch.is_tensor(t) else None for t in sample] - self.shapes = [list(t.shape) if torch.is_tensor(t) else None for t in sample] def get_batch_dims(self) -> Tuple[Optional[int]]: return tuple(self.dataloader.get_batch_dims()) @@ -97,31 +88,45 @@ def get_batch_size(self) -> int: def set_batch_size(self, bs: int): self.dataloader.set_batch_size(bs) return + + def get_runtime_sample(self): + return next(self.dataloader) def __iter__(self): return self def __next__(self): - outputs = list() - for dtype, shape in zip(self.dtypes, self.shapes): - if shape is not None: - data = IRFullTensor( - shape, 'data', requires_grad=False, dtype=dtype - ).tosub() + dtype_map = DType2IRDType + def generate_output(sample): + """Support complex of types: List, Tuple, torch.Tensor, object""" + if isinstance(sample, tuple): + return tuple(generate_output(t) for t in sample) + if isinstance(sample, list): + return list(generate_output(t) for t in sample) + # if isinstance(sample, dict): + # assert all(isinstance(key, (str, int)) for key in sample.keys()) + # return {key:generate_output(val) for key, val in sample.items()} + # if isinstance(sample, set): + # return {generate_output(t) for t in sample} + if isinstance(sample, torch.Tensor): + shape, dtype = list(sample.shape), dtype_map.map(sample.dtype) + return IRFullTensor(shape, 'data', dtype=dtype).tosub() else: - data = IRObject('data') - outputs.append(data) + return IRObject('data') - data_op = IRDataOperation( - data_num=len(outputs), batch_dims=self.get_batch_dims(), - ) - for idx, output in enumerate(outputs): - data_op.set_output(idx, output) + sample = next(self.dataloader) + outputs = generate_output(sample) + # create dataloader + data_num = len(outputs) if isinstance(outputs, tuple) else 1 + data_op = IRDataOperation(data_num=data_num, batch_dims=self.get_batch_dims()) + if not isinstance(outputs, tuple): + data_op.set_output(0, outputs) + else: + for idx, t in enumerate(outputs): + data_op.set_output(idx, t) Program().add_node(data_op) - if len(outputs) == 0: return - elif len(outputs) == 1: return outputs[0] - else: return tuple(outputs) + return outputs class SemanticModel: diff --git a/tests/parser/test_fx_ops.py b/tests/parser/test_fx_ops.py index 74de8f91..b224ed40 100644 --- a/tests/parser/test_fx_ops.py +++ b/tests/parser/test_fx_ops.py @@ -21,6 +21,7 @@ def forward(self, x: torch.Tensor): # matmul: [256, 512], [512, 256] -> [256, 256] x1 = torch.matmul(x, self.param1) x1 = torch.matmul(x, self.param1) + x1 = x1 + x1.size(0) + x1.size()[0] x2 = torch.chunk(x, 2, dim=1) x3 = x2[0] x = x + x.size(0) From 3c532ca1f0ec16915bc0b56af0c9896ec62b1bed Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Mon, 6 Mar 2023 06:03:27 +0000 Subject: [PATCH 1315/1892] Merged PR 1478: fix for test_torchscale_basic error fix for test_torchscale_basic errors --- cube/graph/parser/parserfx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 862f1393..dbd8fb12 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -246,7 +246,7 @@ def extract_val(fx_node): return frame.get_var(var_name) elif isinstance(fx_node, (int, float, str, torch.dtype)) or fx_node is None: return fx_node - elif isinstance(fx_node, tuple): + elif isinstance(fx_node, (tuple, list)): return handle_tuple(fx_node) else: raise RuntimeError(f'Unsupported input node {fx_node}, {type(fx_node)} in parse function!') From 4a405afae9ac6cb642a1fb7981a623fa84561d98 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Mon, 6 Mar 2023 01:08:03 -0800 Subject: [PATCH 1316/1892] update --- cube/graph/function/function.py | 21 ++++++++++++++++++--- cube/graph/parser/mappingfx.py | 15 +++++++++------ cube/graph/parser/parserfx.py | 11 +++++++++++ examples/__init__.py | 0 tests/parser/test_bloom.py | 12 +++++++++++- 5 files changed, 49 insertions(+), 10 deletions(-) create mode 100644 examples/__init__.py diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index a49adf6c..6251902e 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -6,7 +6,7 @@ import operator from cube.ir.cten import IRTensor, IRObject -from cube.ir.tensor import IRSubTensor +from cube.ir.tensor import IRSubTensor, IRFullTensor from cube.ir.dtype import IRDType from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule @@ -616,7 +616,7 @@ def MaskedFill(signature, inputs): edim_ou = copy.copy(edim_in0) for idx, (lhs, rhs) in enumerate(zip(input0.shape, input1.shape)): if lhs != rhs and rhs == 1: - edim_ou[idx] = '1' + edim_in1[idx] = '1' anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input0, input1], value=value) @@ -1492,4 +1492,19 @@ def GetItem(signature, inputs) -> Union[Any, IRPyFunc]: return obj[index] else: return IRPyFunc(signature, inputs, [IRObject()]) - \ No newline at end of file + +def GetAttr(signature, inputs) -> Union[List[int], IRPyFunc]: + """ + builtins.getattr(object, name[, default]) + NOTE: only deal with the attr "shape" of IRFullTensor, because other type of object may not + have instantiated object or the attr is not simple value. + """ + assert len(inputs) == 2, f"but got {inputs}" + obj, name = inputs + if name == 'shape': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + assert hasattr(obj, name), f"attr {name} is not existed in {obj}" + return getattr(obj, name) + else: + # FIXME: is it right? + return IRPyFunc(signature, inputs, [IRObject()]) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 834c9557..19300771 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -14,7 +14,7 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: """ Map the signature to GenericLogicalOp """ - bultin_regions = ['torch.', 'cube.runtime.', '_operator.'] + bultin_regions = ['torch.', 'cube.runtime.', '_operator.', 'builtins.'] # customized function if all(not signature.startswith(region) for region in bultin_regions): signature = signature.split('.')[-1] @@ -28,7 +28,7 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: @staticmethod def exist(signature: str) -> bool: - bultin_regions = ['torch.', 'cube.runtime.', '_operator.'] + bultin_regions = ['torch.', 'cube.runtime.', '_operator.', 'builtins.'] # customized function if all(not signature.startswith(region) for region in bultin_regions): signature = signature.split('.')[-1] @@ -80,7 +80,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, __ttemplate('fill_'): function.Fill, - # __ttemplate('masked_fill'): function.MaskedFill, + __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, __ttemplate('tanh'): function.Tanh, __ftemplate('softmax') : function.Softmax, @@ -101,6 +101,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # ============== runtime function ================= __tttemplate('size'): function.Size, '_operator.getitem': function.GetItem, + 'builtins.getattr': function.GetAttr, # # torch nn functional # @@ -135,13 +136,15 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('sub') : function.Sub, # # __ttemplate('mul') : function.Mul, + '_operator.mul': function.Mul, __ttemplate('div') : function.Div, __ttemplate('truediv'): function.Div, __ttemplate('true_divide'): function.Div, __ttemplate('floordiv') : function.FloorDiv, __ttemplate('floor_divide') : function.FloorDiv, - + '_operator.floordiv': function.FloorDiv, + __ttemplate('neg'): function.Neg, # __ttemplate('gt'): function.CompareGT, @@ -157,9 +160,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('mean') : function.Mean, # # __ttemplate('view'): function.View, - # __tttemplate('view'): function.View, + __tttemplate('view'): function.View, - # __ttemplate('reshape'): function.Reshape, + __ttemplate('reshape'): function.Reshape, # # __ttemplate('conv2d'): function.Conv2D, # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 0718abec..8dfe420b 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -264,6 +264,7 @@ def extract_val(fx_node): # map to IR operator if SignFx2Op.exist(fsig): + print('zql: ', fsig, input_vals, node.meta, node.args, node.kwargs) ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: input_vals = [extract_val(v) for v in node.args] @@ -280,7 +281,12 @@ def extract_val(fx_node): # case2: python runtime function else: print(f'>>> Set python runtime function: {fsig}') + # if fsig == 'builtins.getattr': + # print('zql func getattr: ', FxModuleParser.ntype(node), node.name, node.target, node.meta, node.args, node.kwargs) ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) + # if fsig == 'builtins.getattr': + # print('zql ir_node: ', ir_node) + # exit(1) if isinstance(ir_node, IRCell): # TODO gracefully set output @@ -294,6 +300,11 @@ def extract_val(fx_node): @staticmethod def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + """ + There are two types of get_attr, one is `FxNodeKind.PrimGetAttr` which is dealt with in this function. + The other is `FxNodeKind.PrimCallFunction ` (i.e., ) + which is dealt with by parse_prim_function_method. + """ assert node is not None tensor_name = node.name if 'tensor_meta' in node.meta: diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/parser/test_bloom.py b/tests/parser/test_bloom.py index 41aff42b..e286606d 100644 --- a/tests/parser/test_bloom.py +++ b/tests/parser/test_bloom.py @@ -28,5 +28,15 @@ print("parsing fx graph to cube graph...") from cube.graph.parser import FxModuleParser -FxModuleParser.parse(traced_graph, dummy_inputs=inputs) +cube_graph = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) print("parsing done.") + +# # AutoDist +# # profile communication cost +# import os +# comm_gpu_num = (2, 4) +# for gpu_num in comm_gpu_num: +# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --comm_profile_dir=./ --connect_type=NV') +# # find the best partition plan +# from autodist.apis import compile +# compile(cube_graph, ...) From 6fb157b21a38dbd3ed4d944630dd28415dd5a4ee Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Mon, 6 Mar 2023 01:15:05 -0800 Subject: [PATCH 1317/1892] update --- cube/graph/parser/mappingfx.py | 2 +- cube/graph/parser/parserfx.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 19300771..f5f79b56 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -86,7 +86,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('softmax') : function.Softmax, __ttemplate('bmm') : function.BatchLinear, __ttemplate('pow'): function.Pow, - # __ttemplate('baddbmm'): function.BMMAdd, + __ttemplate('baddbmm'): function.BMMAdd, __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, __tttemplate('expand'): function.Expand, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 8dfe420b..9657be84 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -264,7 +264,6 @@ def extract_val(fx_node): # map to IR operator if SignFx2Op.exist(fsig): - print('zql: ', fsig, input_vals, node.meta, node.args, node.kwargs) ir_node = SignFx2Op.map(fsig)(inputs=input_vals) else: input_vals = [extract_val(v) for v in node.args] @@ -281,12 +280,7 @@ def extract_val(fx_node): # case2: python runtime function else: print(f'>>> Set python runtime function: {fsig}') - # if fsig == 'builtins.getattr': - # print('zql func getattr: ', FxModuleParser.ntype(node), node.name, node.target, node.meta, node.args, node.kwargs) ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) - # if fsig == 'builtins.getattr': - # print('zql ir_node: ', ir_node) - # exit(1) if isinstance(ir_node, IRCell): # TODO gracefully set output From 7c063fb840d70d8ba85565752ad185528584ee65 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Mon, 6 Mar 2023 01:25:39 -0800 Subject: [PATCH 1318/1892] add test and remove unused mapping --- cube/graph/parser/mappingfx.py | 2 -- examples/mlp/linearsfx.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index f5f79b56..622cc497 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -139,9 +139,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] '_operator.mul': function.Mul, __ttemplate('div') : function.Div, - __ttemplate('truediv'): function.Div, __ttemplate('true_divide'): function.Div, - __ttemplate('floordiv') : function.FloorDiv, __ttemplate('floor_divide') : function.FloorDiv, '_operator.floordiv': function.FloorDiv, diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 3cac42ed..50c7cfd7 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -53,7 +53,7 @@ def __init__(self, dim, mult=1, nlayers=4): self.layers.append(nn.Linear(dim * mult, dim, bias=False)) last_dim = dim - # self.layer_norm = nn.LayerNorm(last_dim) #TODO CHECK torch.fx ignores LayerNorm + self.layer_norm = nn.LayerNorm(last_dim) #TODO CHECK torch.fx ignores LayerNorm # self.p = 0.5 self.drop_out = nn.Dropout() # self.y = torch.nn.Parameter(torch.empty(128, last_dim)) @@ -72,7 +72,7 @@ def forward(self, data, mask): x = layer(x) x = torch.nn.functional.relu(x) x = torch.nn.functional.gelu(x) - # x = self.layer_norm(x) + x = self.layer_norm(x) type_x = torch.pow(x, 1.0) x = x.type_as(type_x) x = x.unsqueeze(1) From 15f6da4a28a7e7cdb58131853e683467fd29e6a9 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Mon, 6 Mar 2023 09:28:08 +0000 Subject: [PATCH 1319/1892] Merged PR 1477: support layer_norm and constant folding of get_attr --- cube/graph/function/function.py | 38 ++++++++++++++++++++++++--------- cube/graph/parser/mappingfx.py | 12 +++++------ cube/graph/parser/parserfx.py | 17 +++++++++------ examples/__init__.py | 0 examples/mlp/linearsfx.py | 4 ++-- tests/parser/test_bloom.py | 12 ++++++++++- tests/test_examples.sh | 2 ++ 7 files changed, 59 insertions(+), 26 deletions(-) create mode 100644 examples/__init__.py diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index ec673c05..6251902e 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -6,7 +6,7 @@ import operator from cube.ir.cten import IRTensor, IRObject -from cube.ir.tensor import IRSubTensor +from cube.ir.tensor import IRSubTensor, IRFullTensor from cube.ir.dtype import IRDType from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule @@ -616,7 +616,7 @@ def MaskedFill(signature, inputs): edim_ou = copy.copy(edim_in0) for idx, (lhs, rhs) in enumerate(zip(input0.shape, input1.shape)): if lhs != rhs and rhs == 1: - edim_ou[idx] = '1' + edim_in1[idx] = '1' anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input0, input1], value=value) @@ -628,28 +628,31 @@ def LayerNorm(signature, inputs): """ if 'torch.' in signature: tensor, normalized_shape, weight, bias, eps = inputs - assert isinstance(normalized_shape, list), f"normalized_shape for layer_norm can only be List[int]" + assert isinstance(normalized_shape, (list, tuple, torch.Size)), \ + f"normalized_shape for layer_norm can only be tuple or list or torch.Size, NOT {type(normalized_shape)}" else: + assert 'cube.runtime.function.layer_norm' == signature, f'{signature} of LayerNorm is not supported.' tensor, weight, bias, normalized_shape, eps = inputs letters = iter(string.ascii_lowercase) einput = ShapeAnno.create_shape_str(tensor.shape, iterator=letters) eoutput = copy.copy(einput) ndims = len(tensor.shape) - for dim in range(len(normalized_shape)): + ndims_normshape = len(normalized_shape) + for dim in range(ndims_normshape): + # though these dimensions can be partitioned, + # such partition induces additional communication and complexity einput[ndims-1-dim] += '^' eoutput[ndims-1-dim] += '^' - assert not (bias is None is weight is not None), f"Not support for None of weight and parameter of bias" + assert not (bias is not None and weight is None), f"Not support for None of weight and parameter of bias" einputs, inputs = [einput], [tensor] kwargs = {} if weight is not None: - eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) - einputs.append(eweight) + einputs.append(einput[ndims-ndims_normshape:]) inputs.append(weight) else: kwargs['weight'] = weight if bias is not None: - ebias = ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) - einputs.append(ebias) + einputs.append(einput[ndims-ndims_normshape:]) inputs.append(bias) else: kwargs['bias'] = bias @@ -1489,4 +1492,19 @@ def GetItem(signature, inputs) -> Union[Any, IRPyFunc]: return obj[index] else: return IRPyFunc(signature, inputs, [IRObject()]) - \ No newline at end of file + +def GetAttr(signature, inputs) -> Union[List[int], IRPyFunc]: + """ + builtins.getattr(object, name[, default]) + NOTE: only deal with the attr "shape" of IRFullTensor, because other type of object may not + have instantiated object or the attr is not simple value. + """ + assert len(inputs) == 2, f"but got {inputs}" + obj, name = inputs + if name == 'shape': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + assert hasattr(obj, name), f"attr {name} is not existed in {obj}" + return getattr(obj, name) + else: + # FIXME: is it right? + return IRPyFunc(signature, inputs, [IRObject()]) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index e78ce876..622cc497 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -14,7 +14,7 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: """ Map the signature to GenericLogicalOp """ - bultin_regions = ['torch.', 'cube.runtime.', '_operator.'] + bultin_regions = ['torch.', 'cube.runtime.', '_operator.', 'builtins.'] # customized function if all(not signature.startswith(region) for region in bultin_regions): signature = signature.split('.')[-1] @@ -28,7 +28,7 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: @staticmethod def exist(signature: str) -> bool: - bultin_regions = ['torch.', 'cube.runtime.', '_operator.'] + bultin_regions = ['torch.', 'cube.runtime.', '_operator.', 'builtins.'] # customized function if all(not signature.startswith(region) for region in bultin_regions): signature = signature.split('.')[-1] @@ -96,12 +96,12 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, - # TODO __ftemplate('layer_norm'): function.LayerNorm, # ============== runtime function ================= __tttemplate('size'): function.Size, '_operator.getitem': function.GetItem, + 'builtins.getattr': function.GetAttr, # # torch nn functional # @@ -136,13 +136,13 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('sub') : function.Sub, # # __ttemplate('mul') : function.Mul, + '_operator.mul': function.Mul, __ttemplate('div') : function.Div, - __ttemplate('truediv'): function.Div, __ttemplate('true_divide'): function.Div, - __ttemplate('floordiv') : function.FloorDiv, __ttemplate('floor_divide') : function.FloorDiv, - + '_operator.floordiv': function.FloorDiv, + __ttemplate('neg'): function.Neg, # __ttemplate('gt'): function.CompareGT, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index dbd8fb12..9657be84 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -155,7 +155,8 @@ def parse(module: torch.fx.GraphModule, all_ir_nodes: List[IRFwOperation] = list() for node in module.graph.nodes: ir_nodes = FxModuleParser.parse_node(node, module, frame) - all_ir_nodes += ir_nodes + if ir_nodes is not None: + all_ir_nodes += ir_nodes output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] @@ -293,6 +294,11 @@ def extract_val(fx_node): @staticmethod def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + """ + There are two types of get_attr, one is `FxNodeKind.PrimGetAttr` which is dealt with in this function. + The other is `FxNodeKind.PrimCallFunction ` (i.e., ) + which is dealt with by parse_prim_function_method. + """ assert node is not None tensor_name = node.name if 'tensor_meta' in node.meta: @@ -305,13 +311,10 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram ir_tensor.as_param() frame.add_var(tensor_name, ir_tensor) else: - # FIXME: why no need to record the constant value of this var? - # the value can be obtained below: - # var = FxModuleParser.fetch_attr(module, node.target) - print(f'WARNING: {node.name} {node.meta} in attr node uses empty IRObject!') - frame.add_var(tensor_name, IRObject()) + var = FxModuleParser.fetch_attr(module, node.target) + frame.add_var(tensor_name, var) - return list() + return None @staticmethod def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 3cac42ed..50c7cfd7 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -53,7 +53,7 @@ def __init__(self, dim, mult=1, nlayers=4): self.layers.append(nn.Linear(dim * mult, dim, bias=False)) last_dim = dim - # self.layer_norm = nn.LayerNorm(last_dim) #TODO CHECK torch.fx ignores LayerNorm + self.layer_norm = nn.LayerNorm(last_dim) #TODO CHECK torch.fx ignores LayerNorm # self.p = 0.5 self.drop_out = nn.Dropout() # self.y = torch.nn.Parameter(torch.empty(128, last_dim)) @@ -72,7 +72,7 @@ def forward(self, data, mask): x = layer(x) x = torch.nn.functional.relu(x) x = torch.nn.functional.gelu(x) - # x = self.layer_norm(x) + x = self.layer_norm(x) type_x = torch.pow(x, 1.0) x = x.type_as(type_x) x = x.unsqueeze(1) diff --git a/tests/parser/test_bloom.py b/tests/parser/test_bloom.py index 41aff42b..e286606d 100644 --- a/tests/parser/test_bloom.py +++ b/tests/parser/test_bloom.py @@ -28,5 +28,15 @@ print("parsing fx graph to cube graph...") from cube.graph.parser import FxModuleParser -FxModuleParser.parse(traced_graph, dummy_inputs=inputs) +cube_graph = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) print("parsing done.") + +# # AutoDist +# # profile communication cost +# import os +# comm_gpu_num = (2, 4) +# for gpu_num in comm_gpu_num: +# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --comm_profile_dir=./ --connect_type=NV') +# # find the best partition plan +# from autodist.apis import compile +# compile(cube_graph, ...) diff --git a/tests/test_examples.sh b/tests/test_examples.sh index 9e5c526d..87e366d1 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -1,6 +1,8 @@ # NOTE: This test should run in the root directory. # Before running this test, you should run `export PYTHONPATH=.:$PYTHONPATH` first. +set -e + # test torch.fx # working path OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:$PYTHONPATH \ From afc28532ff5092a0b668f6281649beee0885a99f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 7 Mar 2023 06:26:57 +0000 Subject: [PATCH 1320/1892] Merged PR 1479: function refactor with *args and **kwargs support function refactor with *args and **kwargs support --- cube/graph/function/dimops.py | 3 +- cube/graph/function/function.py | 1035 +++++++++++++------------------ cube/graph/gener/gen.py | 10 +- cube/graph/graph.py | 6 +- cube/graph/parser/mappingfx.py | 4 +- cube/graph/parser/parser.py | 4 +- cube/graph/parser/parserfx.py | 49 +- cube/graph/parser/register.py | 7 +- cube/graph/segment.py | 6 +- examples/mlp/linearsfx.py | 3 +- 10 files changed, 452 insertions(+), 675 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 3e47a510..d095f679 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -676,8 +676,7 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): @return op IRDimop: the new constructed operator """ - inputs = inputs + [kwargs[key] for key in kwargs.keys()] - op = self._create_fn[0](self.signature, inputs) + op = self._create_fn[0](*inputs, **kwargs, signature=self.signature) # annos = self._annos_candidates # rules = self._trans_rules # op = IRDimops(self.signature, annos, inputs, self.name, rules, **kwargs) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 6251902e..d9c9b987 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -18,54 +18,73 @@ ErasedDevice = 'str' -def Identity(signature, inputs: List[IRTensor]): +def Identity(tensor: IRObject, signature = None): signature = 'cube.runtime.function.identity' - eshape = ShapeAnno.create_shape_str(inputs[0].shape) + eshape = ShapeAnno.create_shape_str(tensor.shape) anno = OpAnno.create_op_str([eshape], [eshape]) - return IRDimops(Identity, 'identity', signature, [anno], inputs) + return IRDimops(Identity, 'identity', signature, [anno], [tensor]) -def Linear(signature, inputs): - assert len(inputs) == 3 +def MultiRef(tensor: IRTensor, times: int, signature = None): + """ + cube.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] + """ + signature = 'cube.runtime.function.multiref' + assert isinstance(tensor, IRTensor), "require all inputs to be IRSubTensor" + assert isinstance(times, int), "require int for second input" + anno = '* -> ' + ', '.join('*' for _ in range(times)) + node = IRDimops(MultiRef, 'multiref', signature, [anno], [tensor], times=times) + return node + + +def Accum(*inputs, signature = None): + """ + tensor = cube.runtime.function.accum(tensors) + """ + assert all(isinstance(t, IRTensor) for t in inputs) + signature = 'cube.runtime.function.accum' + iannos = [ShapeAnno.create_shape_str(t.shape) for t in inputs] + oannos = [copy.copy(iannos[0])] + anno = OpAnno.create_op_str(iannos, oannos) + return IRDimops(Cat, 'accum', signature, [anno], inputs) + + +def Linear(input, weight, bias=None, signature = None): signature = 'torch.nn.functional.linear' - if inputs[2] is None: + if bias is None: annos = ['b * k+, n k+ -> b * n'] - return IRDimops(Linear, 'linear', signature, annos, inputs[:2], bias=None) + return IRDimops(Linear, 'linear', signature, annos, [input, weight], bias=None) else: annos = ['b * k+, n k+, n -> b * n'] rules = [TransformRule( [DimopSplit.D(-1), DimopSplit.D(1), DimopSplit.V()], [DimopSplit.V()] )] - return IRDimops(Linear, 'linear', signature, annos, inputs, rules) + return IRDimops(Linear, 'linear', signature, annos, [input, weight, bias], rules) -def BatchLinear(signature, inputs): - annos = [ - 'b m k+, b k+ n -> b m n' - ] - return IRDimops(BatchLinear, 'bmm', signature, annos, inputs) - - -def BMMAdd(signature, inputs): - assert len(inputs) >= 3, f'{inputs}' - alpha, beta = 1, 1 - if len(inputs) == 4: - assert isinstance(inputs[3], dict) - alpha = inputs[3]['alpha'] - beta = inputs[3]['beta'] - elif len(inputs) == 5: - alpha, beta = inputs[3:] - annos = [ - 'b m n, b m k^, b k^ n -> b m n' - ] - return IRDimops(BMMAdd, 'baddbmm', signature, annos, inputs[:3], alpha=alpha, beta=beta) +def BatchLinear(input, mat2, *, out=None, signature = None): + assert out is None + annos = ['b m k+, b k+ n -> b m n'] + return IRDimops(BatchLinear, 'bmm', signature, annos, [input, mat2]) -def EinSum(signature, inputs): - if isinstance(inputs[0], str): - equation, tensors = inputs[0], inputs[1:] - else: - tensors, equation = inputs[:-1], inputs[-1] +def BMMAdd(input, batch1, batch2, *, beta=1, alpha=1, out=None, signature = None): + """ + torch.baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) + """ + assert out is None + in_dims = ['b', 'm', 'n'] + assert len(input.shape) == 3 + for i, size in enumerate(input.shape): + if size == 1: + in_dims[i] = '1' + in_anno = ' '.join(in_dims) + anno = f'{in_anno}, b m k^, b k^ n -> b m n' + return IRDimops(BMMAdd, 'baddbmm', signature, [anno], [input, batch1, batch2], alpha=alpha, beta=beta) + + +def CubeEinSum(*operands, equation=None, signature = None): + assert isinstance(equation, str) lhs, rhs = equation.split('->') assert ',' not in rhs lhs_dims = set(lhs.replace(',', ' ').split(' ')) @@ -73,11 +92,14 @@ def EinSum(signature, inputs): if dim not in rhs: lhs = lhs.replace(dim, f'{dim}+') anno = f'{lhs} -> {rhs}' - return IRDimops(EinSum, 'einsum', signature, [anno], tensors, equation=equation) + return IRDimops(CubeEinSum, 'einsum', signature, [anno], operands, equation=equation) +def EinSum(equation: str, *operands, signature = None): + return CubeEinSum(*operands, equation=equation, signature=signature) -def Matmul(signature, inputs: Tuple[IRTensor, IRTensor]): - assert len(inputs) == 2 + +def Matmul(signature, input, other, *, out=None): + assert out is None annos = [ 'm k+, k+ n -> m n', 'k+, k+ n -> n', @@ -85,10 +107,9 @@ def Matmul(signature, inputs: Tuple[IRTensor, IRTensor]): '* m k+, k+ n -> * m n', '* m k+, * k+ n -> * m n' # TODO: broadcast ] - lhs, rhs = inputs - if len(lhs.shape) > 2 and len(rhs.shape) > 2: - assert tuple(lhs.shape[:-2]) == tuple(rhs.shape[:-2]), "broadcast of matmul (bmm) is not supported" - return IRDimops(Matmul, 'matmul', signature, annos, inputs) + if len(input.shape) > 2 and len(other.shape) > 2: + assert tuple(input.shape[:-2]) == tuple(other.shape[:-2]), "broadcast of matmul (bmm) is not supported" + return IRDimops(Matmul, 'matmul', signature, annos, [input, other]) def Zeros(signature, @@ -283,10 +304,10 @@ def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: return lhs_shape, rhs_shape, out_shape -def Expand(signature, inputs): - input = inputs[0] - sizes = inputs[1:] - +def Expand(input, *sizes, signature = None): + """ + torch.Tensor.expand(*sizes) + """ edim_in = ShapeAnno.create_shape_str(input.shape) assert len(input.shape) == len(sizes) for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes)): @@ -294,406 +315,256 @@ def Expand(signature, inputs): edim_in[idx] += '^' edim_ou = copy.copy(edim_in) anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Expand, 'expand', signature, [anno], [input], sizes=sizes) -def Clone(signature, inputs): +def Clone(input, *, memory_format=None, signature = None): """ torch.clone(input, *, memory_format=torch.preserve_format) """ - assert len(inputs) == 2, f"inputs: {inputs}" - tensor, memory_format = inputs - annos = ['* -> *'] - tensor = inputs[0] assert memory_format is None, f"Not supported for a specific memory format" - return IRDimops(Clone, 'clone', signature, annos, [tensor]) - - -def Add(signature, inputs): - if len(inputs) == 2: - kwargs = {} - elif len(inputs) == 3: - alpha = inputs[2] - kwargs = {'alpha': alpha} - inputs = inputs[0:2] - else: - raise RuntimeError("The number of inputs must be 2 or 3") - - lhs, rhs = inputs + annos = ['* -> *'] + return IRDimops(Clone, 'clone', signature, annos, [input]) - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): - # In this case there won't be an 'alpha' parameter. - assert not('alpha' in kwargs) - return lhs + rhs - annos = [ - '*, ? -> *', - '?, * -> *', - ] - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): - lshape, rshape, oshape = _handle_broadcast(lhs, rhs) +def Add(input, other, alpha=1, *, out=None, signature = None): + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input + alpha * other + annos = ['*, ? -> *', '?, * -> *',] + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Add, 'add', signature, annos, inputs, **kwargs) - - -def Sub(signature, inputs): - if len(inputs) == 2: - alpha = 1 - kwargs = {} - elif len(inputs) == 3: - alpha = inputs[2] - kwargs = {'alpha': alpha} - inputs = inputs[0:2] - else: - raise RuntimeError("The number of inputs must be 2 or 3") + return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) - lhs, rhs = inputs - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): - # In this case there won't be an 'alpha' parameter. - assert not('alpha' in kwargs) - return lhs - rhs - - annos = [ - '*, ? -> *', - '?, * -> *', - ] - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): - lshape, rshape, oshape = _handle_broadcast(lhs, rhs) +def Sub(input, other, alpha=1, *, out=None, signature = None): + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input - alpha * other + annos = ['*, ? -> *', '?, * -> *',] + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Sub, 'sub', signature, annos, inputs, **kwargs) + return IRDimops(Sub, 'sub', signature, annos, [input, other], alpha=1) -def Mul(signature, inputs): - lhs, rhs = inputs - - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): - return lhs * rhs - - annos = [ - '*, ? -> *', - '?, * -> *', - ] - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): - lshape, rshape, oshape = _handle_broadcast(lhs, rhs) +def Mul(input, other, *, out=None, signature = None): + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input * other + annos = ['*, ? -> *', '?, * -> *',] + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Mul, 'mul', signature, annos, inputs) - + return IRDimops(Mul, 'mul', signature, annos, [input, other]) -def Div(signature, inputs): - lhs, rhs = inputs - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): - # For `aten::div` we always do floating division, even operands are both ints. - # TorchScript would dispatch frontend `a // b` to another op `aten::floordiv`. - return lhs / rhs - annos = [ - '*, ? -> *', - '?, * -> *', - ] - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): - lshape, rshape, oshape = _handle_broadcast(lhs, rhs) +def Div(input, other, *, rounding_mode=None, out=None, signature = None): + assert rounding_mode is None and out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input / other + annos = ['*, ? -> *', '?, * -> *',] + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Div, 'div', signature, annos, inputs) - - -def FloorDiv(signature, inputs): - lhs, rhs = inputs - - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): - return lhs // rhs - - annos = [ - '*, ? -> *', - '?, * -> *', - ] - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): - lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + return IRDimops(Div, 'div', signature, annos, [input, other]) + + +def FloorDiv(input, other, *, out=None, signature = None): + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input // other + if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): + return IRPyFunc(signature, [input, other], [IRObject()]) + annos = ['*, ? -> *', '?, * -> *',] + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(FloorDiv, 'floordiv', signature, annos, inputs) + return IRDimops(FloorDiv, 'floordiv', signature, annos, [input, other]) -def Pow(signature, inputs): - lhs, rhs = inputs - - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): - return lhs ** rhs - - annos = [ - '*, ? -> *', - '?, * -> *', - ] - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): - lshape, rshape, oshape = _handle_broadcast(lhs, rhs) +def Pow(input, exponent, *, out=None, signature = None): + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(exponent, IRObject)): + return input ** exponent + annos = ['*, ? -> *', '?, * -> *',] + if isinstance(input, IRTensor) and isinstance(exponent, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, exponent) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Pow, 'pow', signature, annos, inputs) + return IRDimops(Pow, 'pow', signature, annos, [input, exponent]) -def Neg(signature, inputs): - assert len(inputs) == 1 or len(inputs) == 2 - kwargs = {} if len(inputs) == 1 else {'approximate': inputs[1]} - tensors = inputs[0:1] - - if isinstance(tensors[0], (int, float)): - assert not('approximate' in kwargs) - return -tensors[0] - +def Neg(input, *, out=None, signature = None): + assert out is None + if not isinstance(input, IRObject): return -1 * input annos = ['* -> *'] - return IRDimops(Neg, 'neg', signature, annos, inputs, **kwargs) + return IRDimops(Neg, 'neg', signature, annos, [input]) -def Sin(signature, inputs): +def Sin(input, *, out=None, signature = None): + assert out is None annos = ['* -> *'] - tensor = inputs[0:1] - if len(inputs) == 2: - # adapt for newest pytorch version - approximate = inputs[1] - return IRDimops(Sin, 'sin', signature, annos, tensor, - approximate=approximate) - else: - return IRDimops(Sin, 'sin', signature, annos, tensor) + return IRDimops(Sin, 'sin', signature, annos, [input]) -def Cos(signature, inputs): +def Cos(input, *, out=None, signature = None): + assert out is None annos = ['* -> *'] - tensor = inputs[0:1] - if len(inputs) == 2: - # adapt for newest pytorch version - approximate = inputs[1] - return IRDimops(Cos, 'cos', signature, annos, tensor, - approximate=approximate) - else: - return IRDimops(Cos, 'cos', signature, annos, tensor) + return IRDimops(Cos, 'cos', signature, annos, [input]) -def Tanh(signature, inputs): - """ - torch.tanh(input, *, out=None) - """ - assert len(inputs) == 1, f"inputs: {inputs}" +def Tanh(input, *, out=None, signature = None): + assert out is None annos = ['* -> *'] - tensor = inputs[0:1] - return IRDimops(Tanh, 'tanh', signature, annos, tensor) + return IRDimops(Tanh, 'tanh', signature, annos, [input]) -def GeLU(signature, inputs): +def GeLU(input, approximate='none', signature = None): annos = ['* -> *'] signature = 'torch.nn.functional.gelu' - tensor = inputs[0:1] - if len(inputs) == 2: - # adapt for newest pytorch version - approximate = inputs[1] - return IRDimops(GeLU, 'gelu', signature, annos, tensor, - approximate=approximate) - else: - return IRDimops(GeLU, 'gelu', signature, annos, tensor) + return IRDimops(GeLU, 'gelu', signature, annos, [input], approximate=approximate) -def SiLU(signature, inputs): - assert len(inputs) == 1 +def SiLU(input, inplace=False, signature = None): annos = ['* -> *'] signature = 'torch.nn.functional.silu' - tensor = inputs[0:1] - return IRDimops(SiLU, 'silu', signature, annos, tensor) + return IRDimops(SiLU, 'silu', signature, annos, [input], inplace=inplace) -def ReLU(signature, inputs): - assert len(inputs) == 1 +def ReLU(input, inplace=False, signature = None): annos = ['* -> *'] signature = 'torch.nn.functional.relu' - tensor = inputs[0:1] - return IRDimops(ReLU, 'relu', signature, annos, tensor) + return IRDimops(ReLU, 'relu', signature, annos, [input], inplace=inplace) -def Softmax(signature, inputs): - assert len(inputs) >= 1 - tensor = inputs[0] - edim_in = ShapeAnno.create_shape_str(tensor.shape) - if len(inputs) == 2: - if isinstance(inputs[1], dict): - edim_in[inputs[1]['dim']] += '^' - edim_ou = copy.copy(edim_in) - anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], **inputs[1]) - elif isinstance(inputs[1], int): - dim = inputs[1] - edim_in[dim] += '^' - edim_ou = copy.copy(edim_in) - anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], dim=inputs[1]) - else: - raise RuntimeError(f'Unexpect intput type {inputs[1]}, {type(inputs[1])}') - elif len(inputs) == 4: - dim, _stacklevel, dtype = inputs[1], inputs[2], inputs[3] - dim = inputs[1] +def Softmax(input, dim=None, _stacklevel=3, dtype=None, signature = None): + """ + torch.nn.functional.softmax(input, dim=None, _stacklevel=3, dtype=None) + """ + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = copy.copy(edim_in) + if dim is not None: edim_in[dim] += '^' - anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Softmax, 'softmax', signature, [anno], [tensor], - dim=dim, _stacklevel=_stacklevel, dtype=dtype) - else: - raise RuntimeError('Unexpected input num {inputs}') - + edim_ou[dim] += '^' + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Softmax, 'softmax', signature, [anno], [input], + dim=dim, _stacklevel=_stacklevel, dtype=dtype) -def Dropout(signature, inputs): - assert len(inputs) <= 4, f'but the length is {len(inputs)}' - default_inputs = [None, 0.5, True, False] - inputs = inputs + default_inputs[len(inputs):] +def Dropout(input, p=0.5, training=True, inplace=False, signature = None): + """ + torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False) + """ annos = ['* -> *'] - tensor = inputs[0:1] - p, training, inplace = inputs[1], inputs[2], inputs[3] - return IRDimops(Dropout, 'dropout', signature, annos, tensor, + return IRDimops(Dropout, 'dropout', signature, annos, [input], p=p, training=training, inplace=inplace) -def Detach(signature, inputs): - assert len(inputs) == 1 +def Detach(input, signature = None): + """ + torch.Tensor.detach(input) + """ annos = ['* -> *'] - tensor = inputs[0:1] - return IRDimops(Detach, 'detach', signature, annos, tensor) - - -def EQ(signature, inputs): - assert len(inputs) == 2 - input0, input1 = inputs - - edim_in0 = ShapeAnno.create_shape_str(input0.shape) - edim_ou = copy.copy(edim_in0) - if isinstance(input1, (int, float)): - anno = OpAnno.create_op_str([edim_in0], [edim_ou]) - return IRDimops(EQ, 'eq', signature, [anno], [input0], other=input1) - else: - edim_in1 = copy.copy(edim_in0) - anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) - return IRDimops(EQ, 'eq', signature, [anno], [input0], other=input1) - - -def NE(signature, inputs): - assert len(inputs) == 2 - input0, input1 = inputs - - edim_in0 = ShapeAnno.create_shape_str(input0.shape) - edim_ou = copy.copy(edim_in0) - if isinstance(input1, (int, float)): - anno = OpAnno.create_op_str([edim_in0], [edim_ou]) - return IRDimops(NE, 'ne', signature, [anno], [input0], other=input1) - else: - edim_in1 = copy.copy(edim_in0) - anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) - return IRDimops(NE, 'ne', signature, [anno], [input0], other=input1) + return IRDimops(Detach, 'detach', signature, annos, [input]) -def NanToNum(signature, inputs): - assert len(inputs) == 1 +def NanToNum(input, nan=0.0, posinf=None, neginf=None, *, out=None, signature = None): + assert out is None annos = ['* -> *'] - tensor = inputs[0:1] - return IRDimops(NanToNum, 'nan_to_num', signature, annos, tensor) + return IRDimops(NanToNum, 'nan_to_num', signature, annos, [input], nan=nan, posinf=posinf, neginf=neginf) -def Long(signature, inputs): - assert len(inputs) == 1 +def Long(input, memory_format=None, signature = None): + """ + torch.Tensor.long(memory_format=torch.preserve_format) + """ + assert memory_format is None annos = ['* -> *'] - tensor = inputs[0:1] - return IRDimops(Long, 'long', signature, annos, tensor) + return IRDimops(Long, 'long', signature, annos, [input]) -def Fill(signature, inputs): - assert len(inputs) == 2 - input, value = inputs - +def Fill(input, value, signature = None): + """ + torch.Tensor.fill_(value) + """ edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = copy.copy(edim_in) anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(Fill, 'fill', signature, [anno], [input], value=value) -def MaskedFill(signature, inputs): - assert len(inputs) == 3 - input0, input1, value = inputs - - edim_in0 = ShapeAnno.create_shape_str(input0.shape) - edim_in1 = ShapeAnno.create_shape_str(input1.shape) +def MaskedFill(input, mask, value, signature = None): + """ + torch.Tensor.masked_fill_(mask, value) + """ + edim_in0 = ShapeAnno.create_shape_str(input.shape) + edim_in1 = ShapeAnno.create_shape_str(mask.shape) edim_ou = copy.copy(edim_in0) - for idx, (lhs, rhs) in enumerate(zip(input0.shape, input1.shape)): + #TODO: add broadcast rule + for idx, (lhs, rhs) in enumerate(zip(input.shape, mask.shape)): if lhs != rhs and rhs == 1: edim_in1[idx] = '1' anno = OpAnno.create_op_str([edim_in0, edim_in1], [edim_ou]) - return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input0, input1], value=value) + return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input, mask], value=value) -def LayerNorm(signature, inputs): +def CubeLayerNorm(input, weight=None, bias=None, normalized_shape=None, eps=1e-05, signature = None): """ - torch.nn.functional.layer_norm(input, normliazed_shape, weight=None, bias=None, eps) cube.runtime.function.layer_norm(input, weight, bias, normliazed_shape, eps) """ - if 'torch.' in signature: - tensor, normalized_shape, weight, bias, eps = inputs - assert isinstance(normalized_shape, (list, tuple, torch.Size)), \ - f"normalized_shape for layer_norm can only be tuple or list or torch.Size, NOT {type(normalized_shape)}" - else: - assert 'cube.runtime.function.layer_norm' == signature, f'{signature} of LayerNorm is not supported.' - tensor, weight, bias, normalized_shape, eps = inputs + signature = 'cube.runtime.function.layer_norm' + assert not (weight is None and bias is not None), f"Not support for None of weight and parameter of bias" letters = iter(string.ascii_lowercase) - einput = ShapeAnno.create_shape_str(tensor.shape, iterator=letters) + einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) eoutput = copy.copy(einput) - ndims = len(tensor.shape) - ndims_normshape = len(normalized_shape) - for dim in range(ndims_normshape): - # though these dimensions can be partitioned, - # such partition induces additional communication and complexity + ndims = len(input.shape) + for dim in range(len(normalized_shape)): einput[ndims-1-dim] += '^' eoutput[ndims-1-dim] += '^' - assert not (bias is not None and weight is None), f"Not support for None of weight and parameter of bias" - einputs, inputs = [einput], [tensor] + einputs, inputs = [einput], [input] kwargs = {} if weight is not None: - einputs.append(einput[ndims-ndims_normshape:]) + eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) + einputs.append(eweight) inputs.append(weight) else: kwargs['weight'] = weight if bias is not None: - einputs.append(einput[ndims-ndims_normshape:]) + ebias = ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) + einputs.append(ebias) inputs.append(bias) else: kwargs['bias'] = bias anno = OpAnno.create_op_str(einputs, [eoutput]) kwargs['normalized_shape'] = normalized_shape kwargs['eps'] = eps - signature = 'cube.runtime.function.layer_norm' - return IRDimops(LayerNorm, 'layernorm', signature, [anno], inputs, **kwargs) + return IRDimops(CubeLayerNorm, 'layernorm', signature, [anno], inputs, **kwargs) -def Sum(signature, inputs): +def LayerNorm(input, normalized_shape, weight=None, bias=None, eps=1e-05, signature = None): + """ + torch.nn.functional.layer_norm(input, normliazed_shape, weight=None, bias=None, eps) + """ + return CubeLayerNorm(input, weight, bias, normalized_shape, eps, signature=signature) + + +def Sum(input, dim=None, keepdim=False, *, dtype=None, signature = None): """ torch.sum(input, *, dtype=None) -> Tensor torch.sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + + @note troch.sum is overrided by two signatures, which may lead mismatch in torch.jit.script: + may get (input, dtype) as input """ - assert len(inputs) == 1 or len(inputs) == 2 or len(inputs) == 4, f"{inputs}" - tensor = inputs[0] - einput = ShapeAnno.create_shape_str(tensor.shape) + assert dtype is None, "Currently Sum only support dtype=None" + einput = ShapeAnno.create_shape_str(input.shape) eoutput = copy.copy(einput) - if len(inputs) == 1: - inputs = [tensor] - eoutput = ['1'] - # every dimension can be reduced - einput = [edim + '+' for edim in einput] - anno = OpAnno.create_op_str([einput], [eoutput]) - return IRDimops(Sum, 'sum', signature, [anno], [tensor]) - elif len(inputs) == 2: - dtype = inputs[1] - assert dtype is None, "Currently Sum only support dtype=None" - # torch.sum(input) - inputs = [tensor] - eoutput = ['1'] - # every dimension can be reduced + if dim is None: einput = [edim + '+' for edim in einput] - anno = OpAnno.create_op_str([einput], [eoutput]) - return IRDimops(Sum, 'sum', signature, [anno], [tensor], dtype=dtype) + anno = OpAnno.create_op_str([einput], ['1']) + return IRDimops(Sum, 'sum', signature, [anno], [input]) else: - # torch.sum(input, dim, keepdim, *, dtype) - dim, keepdim, dtype = inputs[1:4] - assert dtype is None, "Currently Sum only support dtype=None" - assert isinstance(dim, list), f"Expect dim to be list but got: {dim}" + dim = (dim,) if isinstance(dim, int) else dim for dimidx in dim: einput[dimidx] += '+' if keepdim: @@ -705,63 +576,53 @@ def Sum(signature, inputs): for dimidx in sort_dim[::-1]: eoutput.pop(dimidx) anno = OpAnno.create_op_str([einput], [eoutput]) - return IRDimops(Sum, 'sum', signature, [anno], [tensor], dim=dim, keepdim=keepdim, dtype=dtype) - + return IRDimops(Sum, 'sum', signature, [anno], [input], dim=dim, keepdim=keepdim) -def Mean(signature, inputs): - if len(inputs) >= 2: - tensor, dim = inputs[:2] - elif len(inputs) == 1: - tensor = inputs[0] - dim = None - einput = ShapeAnno.create_shape_str(tensor.shape) +def Mean(input, dim=None, keepdim=False, *, dtype=None, signature = None): + """ + torch.mean(input, *, dtype=None) -> Tensor + torch.mean(input, dim, keepdim=False, *, dtype=None) -> Tensor + """ + assert dtype is None + einput = ShapeAnno.create_shape_str(input.shape) eoutput = copy.copy(einput) + dim = (dim,) if isinstance(dim, int) else dim if dim is not None: - keepdim = inputs[2] - sort_dim = list(dim) - sort_dim.sort() + sort_dim = sorted(dim) for dimidx in sort_dim[::-1]: eoutput.pop(dimidx) - einput[dimidx] = einput[dimidx] + '+' + einput[dimidx] = einput[dimidx] + '^' else: eoutput = ['1'] - # every dimension is reduced - einput = [edim + '+' for edim in einput] + einput = [edim + '^' for edim in einput] anno = OpAnno.create_op_str([einput], [eoutput]) if dim is not None: - return IRDimops(Mean, 'mean', signature, [anno], [tensor], dim=dim, keepdim=keepdim) + return IRDimops(Mean, 'mean', signature, [anno], [input], dim=dim, keepdim=keepdim) else: - return IRDimops(Mean, 'mean', signature, [anno], [tensor]) + return IRDimops(Mean, 'mean', signature, [anno], [input]) -def Transpose(signature, inputs): +def Transpose(input, dim0, dim1, signature = None): """ out = torch.transpose(tensor, dim0, dim1) """ - assert len(inputs) == 3 - input, dim0, dim1 = inputs - edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = copy.copy(edim_in) edim_ou[dim0], edim_ou[dim1] = edim_ou[dim1], edim_ou[dim0] anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Transpose, 'transpose', signature, [anno], [input], dim0=dim0, dim1=dim1) -def View(signature, inputs): +def View(input, size: Tuple[int], *arg_size, signature = None): """ - out = torch.Tensor.view(tensor: torch.Tensor, size: List[int]) + out = torch.Tensor.view(tensor: torch.Tensor, *size) """ - if len(inputs) == 2: - input, shape = inputs - else: - input = inputs[0] - shape = inputs[1:] - if not all([isinstance(dim, int) for dim in shape]): - raise TypeError("Expected tensor.view has static int shape") - in_shape, ou_shape = list(input.shape), shape + size = (size,) if isinstance(size, int) else tuple(size) + size = size + arg_size + assert all([isinstance(dim, int) for dim in size]), \ + f"Expected tensor.view has static int shape but got: {size}" + in_shape, ou_shape = list(input.shape), list(size) # infer -1 def nele(shape, nele=1): @@ -901,10 +762,10 @@ def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: anno = OpAnno.create_op_str([in_anno], [ou_anno]) signature = 'torch.Tensor.view' - return IRDimops(View, 'view', signature, [anno], [input], rules, size=tuple(shape)) + return IRDimops(View, 'view', signature, [anno], [input], rules, size=tuple(size)) -def Reshape(signature, inputs): +def Reshape(input, *shape, signature = None): """ torch.reshape(Tensor self, int[] shape) -> Tensor """ @@ -914,34 +775,28 @@ def Reshape(signature, inputs): but 'reshape' has keyword parameter 'shape' while 'view' has 'size'. ArgumentMissing error may be raised during codegen.""") - return View(signature, inputs) + return View(input, *shape, signature='torch.Tensor.view') -def Permute(signature, inputs): - if isinstance(inputs[1], list): - in_tensor, dims = inputs[0], inputs[1] - else: - in_tensor, dims = inputs[0], inputs[1:] - edim_in = ShapeAnno.create_shape_str(in_tensor.shape) - for idx, dim in enumerate(dims): - if idx != dim: - edim_in[idx] += '^' - assert len(edim_in) == len(dims), f'{len(edim_in)} vs {len(dims)}' - edim_ou = [] - for dim in dims: - assert isinstance(dim, int) - edim_ou.append(copy.copy(edim_in[dim])) +def Permute(input, dims: Tuple[int], *arg_dims, signature = None): + """ + torch.Tensor.permute(input, *dims) + torch.permute(input, dims: Tuple[int]) + """ + dims = (dims,) if isinstance(dims, int) else tuple(dims) + dims = dims + arg_dims + assert all(isinstance(dim, int) for dim in dims), f"but got {dims}" + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = [copy.copy(edim_in[dim]) for dim in dims] anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Permute, 'permute', signature, [anno], [in_tensor], dims=dims) + return IRDimops(Permute, 'permute', signature, [anno], [input], dims=dims) -def Squeeze(signature, inputs): +def Squeeze(input, dim=None, signature = None): """ out = torch.squeeze(tensor) """ - assert len(inputs) == 1 - input = inputs[0] - + assert dim is None, "got dim: {dim} != None, which is not supported" edim_in = ShapeAnno.create_shape_str(input.shape) assert len(edim_in) == len(input.shape) edim_ou = [] @@ -949,76 +804,51 @@ def Squeeze(signature, inputs): if dim_size > 1: edim_ou.append(copy.copy(dim_anno)) anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Squeeze, 'squeeze', signature, [anno], [input]) -def Unsqueeze(signature, inputs): +def Unsqueeze(input, dim, signature = None): """ out = torch.unsqueeze(tensor, dim) """ - assert len(inputs) == 2 - input, dim = inputs - edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = copy.copy(edim_in) edim_ou.insert(dim, '1') anno = OpAnno.create_op_str([edim_in], [edim_ou]) - - return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input], - dim=dim) + return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input],dim=dim) -def TypeAs(signature, inputs): +def TypeAs(input, tensor, signature = None): """ out = torch.Tensor.type_as(tensor0, tensor1) """ - assert len(inputs) == 2 - input0, input1 = inputs - - edim_in0 = ShapeAnno.create_shape_str(input0.shape) - edim_ou = copy.copy(edim_in0) - anno = OpAnno.create_op_str([edim_in0, '*'], [edim_ou]) + edim_in0 = ShapeAnno.create_shape_str(tensor.shape) + anno = OpAnno.create_op_str(['*', edim_in0], ['*']) + return IRDimops(TypeAs, 'type_as', signature, [anno], [input, tensor]) - return IRDimops(TypeAs, 'type_as', signature, [anno], [input0, input1]) - -def Triu(signature, inputs): +def Triu(input, diagonal=0, *, out=None, signature = None): """ out = torch.triu(tensor, diagonal) """ - assert len(inputs) == 2 - input, diagonal = inputs - edim_in = ShapeAnno.create_shape_str(input.shape) assert len(edim_in) >= 2 edim_in[-1] += '^' edim_in[-2] += '^' edim_ou = copy.copy(edim_in) anno = OpAnno.create_op_str([edim_in], [edim_ou]) - - return IRDimops(Triu, 'triu', signature, [anno], [input], - diagonal=diagonal) + return IRDimops(Triu, 'triu', signature, [anno], [input], diagonal=diagonal) -def CumSum(signature, inputs): +def CumSum(tensor, dim, signature = None): """ out = torch.cumsum(tensor, dim) """ - assert len(inputs) == 2 - input, dim = inputs - if isinstance(dim, dict): - dim = dim['dim'] - else: - assert isinstance(dim, int) - - edim_in = ShapeAnno.create_shape_str(input.shape) + edim_in = ShapeAnno.create_shape_str(tensor.shape) edim_in[dim] += '^' edim_ou = copy.copy(edim_in) anno = OpAnno.create_op_str([edim_in], [edim_ou]) - - return IRDimops(CumSum, 'cumsum', signature, [anno], [input], - dim=dim) + return IRDimops(CumSum, 'cumsum', signature, [anno], [tensor], dim=dim) # def Pad(signature, inputs): @@ -1042,23 +872,11 @@ def CumSum(signature, inputs): # return IRDimops(Pad, 'pad', signature, [anno], [tensor], pad=pad, mode=mode, value=value) -def Pad(signature, inputs): +def Pad(input, pad, mode='constant', value=0.0, signature = None): """ torch.nn.functional.pad(input, pad, mode='constant', value=0.0) - https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad - :param signature: - :param inputs: - :return: - """ - # print("#Pad::inputs.len: {}".format(len(inputs))) - # idx = 0 - # for input in inputs: - # if idx >= 0: - # print("#Pad::input[{}]: {}".format(idx, input)) - # idx += 1 - tensors = inputs[0:1] - pad, mode, value = inputs[1:] - return IRPad(signature, tensors, 'pad', pad=pad, mode=mode, value=value) + """ + return IRPad(signature, [input], 'pad', pad=pad, mode=mode, value=value) # def Conv2D(signature, inputs): @@ -1090,65 +908,41 @@ def Pad(signature, inputs): # stride=stride, padding=padding, dilation=dilation, groups=groups) -def Conv2D(signature, inputs): +def Conv2D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, signature = None): """ - torch.conv2d(input, weight, bias, stride, padding, dialation, groups) - https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html?highlight=torch%20conv2d#torch.nn.functional.conv2d + torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) """ - assert len(inputs) == 7, f"Expected 7 inputs but only got {len(inputs)}" - tensors = inputs[0:3] - stride, padding, dilation, groups = inputs[3:] if isinstance(padding, int): padding = [padding] * 4 elif len(padding) == 2: padH, padW = padding padding = [padH, padH, padW, padW] - return IRConv2D(signature, tensors, 'conv2d', + return IRConv2D(signature, [input, weight, bias], 'conv2d', stride=stride, padding=padding, dilation=dilation, groups=groups) -def Conv3D(signature, inputs): +def Conv3D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, signature = None): """ - conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor - https://pytorch.org/docs/stable/generated/torch.nn.functional.conv3d.html?highlight=conv3d#torch.nn.functional.conv3d + torch.nn.functional.conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) """ - assert len(inputs) == 7, f"Expected 7 inputs but only got {len(inputs)}" - tensors = inputs[0:3] - stride, padding, dilation, groups = inputs[3:] if isinstance(padding, int): padding = [padding] * 4 elif len(padding) == 2: padH, padW = padding padding = [padH, padH, padW, padW] - return IRConv3D(signature, tensors, 'conv3d', + return IRConv3D(signature, [input, weight, bias], 'conv3d', stride=stride, padding=padding, dilation=dilation, groups=groups) -def Accum(signature, inputs: Tuple[IRTensor]): +def CubeCat(*tensors, dim: int, signature = None): """ - tensor = cube.runtime.function.accum(tensors) + torch.cat(tensors, dim=0, *, out=None) """ - assert all(isinstance(t, IRTensor) for t in inputs) - signature = 'cube.runtime.function.accum' - iannos = [ShapeAnno.create_shape_str(t.shape) for t in inputs] - oannos = [copy.copy(iannos[0])] - anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(Cat, 'accum', signature, [anno], inputs) - - -def Cat(signature, inputs: Tuple[List[IRTensor], int]): - """ - torch.cat(inputs: List[Tensor], dim: int) -> Tensor - torch.cat(tensor1: Tensor, tensor2: Tensor, ..., dim: int) - - e.g. cat(tensor([2,3]), tensor([2,3])).shape == [4,3] - """ - assert len(inputs) >= 2 - if len(inputs) == 2: - tensors, dim = inputs - else: - tensors, dim = inputs[:-1], inputs[-1] + # REMARK: IRFwOperation doesn't support taking a list of IRTensors. + # Therefore, the argument interface is adapted to take unpacked tensors + # with dimension. dim=None is for the support of kwarg inputs from torchfx assert all(isinstance(tensor, IRTensor) for tensor in tensors) + assert isinstance(dim, int) iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] dimlens = [t.shape[dim] for t in tensors] for ashape, dimlen in zip(iannos, dimlens): @@ -1156,90 +950,95 @@ def Cat(signature, inputs: Tuple[List[IRTensor], int]): oannos = [copy.copy(iannos[-1])] oannos[0][dim] = str(sum(dimlens)) anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(Cat, 'cat', signature, [anno], tensors, dim=dim) + return IRDimops(CubeCat, 'cat', signature, [anno], tensors, dim=dim) -def Stack(signature, inputs: Tuple[List[IRTensor], int]): +def Cat(*tensors_and_dim, dim=0, out=None, signature=None): + """ + torch.cat(tensors, dim=0, *, out=None) """ - torch.stack(inputs: List[Tensor], dim: int) -> Tensor - torch.stack(tensor1: Tensor, tensor2: Tensor, ..., dim: int) -> Tensor + assert out is None + if len(tensors_and_dim) == 2: + tensors, dim = tensors_and_dim[0], tensors_and_dim[1] + else: + tensors = tensors_and_dim[0] + return CubeCat(*tensors, dim=dim, signature=signature) - inputs: - tensors: List[Tensor]: all tensors need to have same size - dim: the new inserted dim - e.g. stack(tensor([2,3]), tensor([2,3])).shape == [2,2,3] +def CubeStack(*tensors, dim: int, signature=None): """ - assert len(inputs) >= 2 - if len(inputs) == 2: - tensors, dim = inputs - else: - tensors, dim = inputs[:-1], inputs[-1] - if isinstance(dim, dict): - assert 'dim' in dim - dim = dim['dim'] - assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'{tensors}' + torch.stack(tensors, dim=0, *, out=None) + """ + # REMARK: IRFwOperation doesn't support taking a list of IRTensors. + # Therefore, the argument interface is adapted to take unpacked tensors + # with dimension. + assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'but got {tensors}' + assert isinstance(dim, int), f"but not {dim}" iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] oannos = [copy.copy(iannos[-1])] oannos[0].insert(dim, str(len(tensors))) anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(Stack, 'stack', signature, [anno], tensors, dim=dim) + return IRDimops(CubeStack, 'stack', signature, [anno], tensors, dim=dim) + + +def Stack(*tensors_and_dim, dim=0, out=None, signature = None): + """ + torch.stack(tensors, dim=0, *, out=None) + """ + if len(tensors_and_dim) == 2: + tensors, dim = tensors_and_dim[0], tensors_and_dim[1] + else: + tensors, dim = tensors_and_dim[0], dim + return CubeStack(*tensors, dim=dim, signature=signature) -def Chunk(signature, inputs: Tuple[IRTensor, int, int]): +def Chunk(input, chunks, dim=0, signature = None): """ torch.chunk(input, chunks, dim=0) """ - assert len(inputs) == 3 - tensor, chunks, dim = inputs - assert tensor.shape[dim] % chunks == 0 - iannos = [ShapeAnno.create_shape_str(tensor.shape)] + assert input.shape[dim] % chunks == 0 + iannos = [ShapeAnno.create_shape_str(input.shape)] oannos = [copy.copy(iannos[0]) for _ in range(chunks)] - iannos[0][dim] = str(tensor.shape[dim]) + iannos[0][dim] = str(input.shape[dim]) for oanno in oannos: - oanno[dim] = str(tensor.shape[dim] // chunks) + oanno[dim] = str(input.shape[dim] // chunks) anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(Chunk, 'chunk', signature, [anno], [tensor], chunks=chunks, dim=dim) + return IRDimops(Chunk, 'chunk', signature, [anno], [input], chunks=chunks, dim=dim) -def Select(signature, inputs: Tuple[IRTensor, int, int]): +def Select(input, dim, index, signature = None): """ torch.select(self:Tensor, dim:int, index:int) -> Tensor """ - tensor, dim, index = inputs - ianno = ShapeAnno.create_shape_str(tensor.shape) + ianno = ShapeAnno.create_shape_str(input.shape) oanno = copy.copy(ianno) ianno[dim] += '^' oanno.pop(dim) anno = OpAnno.create_op_str([ianno], [oanno]) - return IRDimops(Select, 'select', signature, [anno], [tensor], dim=dim, index=index) + return IRDimops(Select, 'select', signature, [anno], [input], dim=dim, index=index) -def IndexSelect(signature, inputs): - assert len(inputs) == 3 - # hack - if isinstance(inputs[1], int): - tensor, dim, idx = inputs - else: - assert isinstance(inputs[2], int) - tensor, idx, dim = inputs - - edim_in = ShapeAnno.create_shape_str(tensor.shape) +def CubeIndexSelect(input: torch.Tensor, index: torch.Tensor, dim: int, signature = None): + edim_in = ShapeAnno.create_shape_str(input.shape) edim_in[dim] += '^' idx_anno = chr(ord(edim_in[-1]) + 1) + '^' edim_ou = copy.copy(edim_in) edim_ou[dim] = copy.copy(idx_anno) anno = OpAnno.create_op_str([edim_in, idx_anno], [edim_ou]) + # FIXME: runtime function support + return IRDimops(CubeIndexSelect, 'index_select', signature, [anno], [input, index], dim=dim) + - return IRDimops(IndexSelect, 'index_select', signature, [anno], [tensor, idx], dim=dim) +def IndexSelect(input: torch.Tensor, dim: int, index: torch.Tensor, *, out=None, signature = None): + assert out is None + return CubeIndexSelect(input, index, dim, signature=signature) -def Slice(signature, inputs): +def Slice(tensor: torch.Tensor, dim, start, end, step, signature = None): """ aten::slice(input:Tensor, dim:int, start:Optional[int], end:Optional[int], step:int) -> Tensor """ signature = 'torch.ops.aten.slice' - tensor, dim, start, end, step = inputs ianno = ShapeAnno.create_shape_str(tensor.shape) oanno = copy.copy(ianno) ianno[dim] = str(tensor.shape[dim]) @@ -1258,13 +1057,12 @@ def clip(ofst): return IRDimops(Slice, 'slice', signature, [anno], [tensor], dim=dim, start=start, end=end, step=step) -def SelectScatter(signature, inputs: Tuple[IRTensor, IRTensor, int, int]): +def SelectScatter(self: torch.Tensor, input: torch.Tensor, dim: int, index: int, signature = None): """ torch.select_scatter(self:Tensor, input:Tensor, dim:int, index:int) -> Tensor """ # 'torch.select_scatter' isn't supported by Torch2ONNX yet. signature = 'cube.runtime.function.select_scatter' - self, input, dim, index = inputs # shape check self_shape, input_shape = self.shape, input.shape self_shape.pop(dim) @@ -1279,16 +1077,13 @@ def SelectScatter(signature, inputs: Tuple[IRTensor, IRTensor, int, int]): [anno], [self, input], dim=dim, index=index) -def Repeat(signature, inputs: Tuple[IRTensor, List[int]]): +def Repeat(tensor, repeats: Tuple[int], *arg_repeats, signature = None): """ - torch.repeat(tensor:Tensor, repeats: List[int]) -> Tensor + torch.Tensor.repeat(*sizes) """ signature = 'torch.ops.aten.repeat' - tensor = inputs[0] - if isinstance(inputs[1], list): - repeats = inputs[1] - else: - repeats = inputs[1:] + repeats = (repeats,) if isinstance(repeats, int) else tuple(repeats) + repeats = repeats + arg_repeats in_shape = tensor.shape assert len(in_shape) <= len(repeats), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor" expand = len(repeats) - len(tensor.shape) @@ -1309,65 +1104,73 @@ def Repeat(signature, inputs: Tuple[IRTensor, List[int]]): return IRDimops(Repeat, 'repeat', signature, [anno], [tensor], repeats=repeats) -def Embedding(signature, inputs: List): +def CubeEmbedding(input, weight, padding_idx, signature = None, **kwargs): """ - torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False) + cube.runtime.function.embedding(input, weight, padding_idx, start, stop) """ signature = 'cube.runtime.function.embedding' - itensor, weight = inputs[:2] - padding_idx = inputs[2] if isinstance(weight, IRSubTensor): start, stop = weight.indmap[0] else: start, stop = 0, weight.shape[0] annos = ['*, n+ e -> * e'] - return IRDimops(Embedding, 'embedding', signature, annos, [itensor, weight], + return IRDimops(CubeEmbedding, 'embedding', signature, annos, [input, weight], padding_idx=padding_idx, start=start, stop=stop) -def Flatten(signature, inputs: List): - tensor: IRTensor = inputs[0] - if len(inputs) == 1: - start_dim, end_dim = 0, len(tensor.shape) - 1 - else: - start_dim, end_dim = inputs[1:] - end_dim = len(tensor.shape) + end_dim if end_dim < 0 else end_dim - ishape = ShapeAnno.create_shape_str(tensor.shape) +def Embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, + scale_grad_by_freq=False, sparse=False, signature = None): + """ + torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False) + """ + assert max_norm is None and norm_type == 2.0 and (not scale_grad_by_freq) and (not sparse) + return CubeEmbedding(input, weight, padding_idx, signature=signature) + + +def Flatten(input, start_dim=0, end_dim=-1, signature = None): + end_dim = len(input.shape) + end_dim if end_dim < 0 else end_dim + ishape = ShapeAnno.create_shape_str(input.shape) for dim in range(start_dim, end_dim+1): ishape[dim] += '^' oshape = ishape[:start_dim] oshape.append(ishape[start_dim:end_dim+1]) anno = OpAnno.create_op_str([ishape], [oshape]) - return IRDimops(Flatten, 'flatten', signature, [anno], [tensor], start_dim=start_dim, end_dim=end_dim) + return IRDimops(Flatten, 'flatten', signature, [anno], [input], + start_dim=start_dim, end_dim=end_dim) -def Roll(signature, inputs: Tuple[IRTensor, Union[int, Tuple[int]], Union[int, Tuple[int]]]): - tensor = inputs[0] - shifts, dims = inputs[1:] - ishape = ShapeAnno.create_shape_str(tensor.shape) +def Roll(input, shifts: Union[int, Tuple[int]], dims=None, signature = None): + shifts = (shifts,) if isinstance(shifts, int) else shifts + ishape = ShapeAnno.create_shape_str(input.shape) for dim in range(len(ishape)): if dims is None or dim in dims: ishape[dim] += '^' anno = OpAnno.create_op_str([ishape], [ishape]) - return IRDimops(Roll, 'roll', signature, [anno], [tensor], shifts=shifts, dims=dims) + return IRDimops(Roll, 'roll', signature, [anno], [input], shifts=shifts, dims=dims) -def AdaptiveAvgPool1d(signature, inputs: Tuple[IRTensor, Tuple[int]]): - tensor = inputs[0] - out_size = inputs[1] - ishape = ShapeAnno.create_shape_str(tensor.shape) +def AdaptiveAvgPool1d(input, output_size, signature = None): + """ + torch.nn.functional.adaptive_avg_pool2d(input, output_size) + """ + ishape = ShapeAnno.create_shape_str(input.shape) ishape[-1] += '^' - oshape = ishape[:-1] + [str(size) for size in out_size] + oshape = ishape[:-1] + [str(size) for size in output_size] anno = OpAnno.create_op_str([ishape], [oshape]) - return IRDimops(AdaptiveAvgPool1d, 'adaptive_avg_pool1d', signature, [anno], [tensor], output_size=out_size) + return IRDimops(AdaptiveAvgPool1d, 'adaptive_avg_pool1d', signature, [anno], [input], output_size=output_size) -def CrossEntropy(signature, inputs): +def CrossEntropy(input, target, weight=None, + size_average=None, ignore_index=- 100, reduce=None, + reduction='mean', label_smoothing=0.0, signature = None): + """ + torch.nn.functional.cross_entropy( + input, target, weight=None, + size_average=None, ignore_index=- 100, reduce=None, + reduction='mean', label_smoothing=0.0) + """ # FIXME: reduction is by default 'mean', in this way it cannot be partitioned # no N dimension. - tensor, target, weight = inputs[0:3] - assert weight is None, "weight not supported for cross entropy" - size_average, ignore_index, reduce, reduction, label_smoothing = inputs[3:] annos = [ 'C^, N -> 1', 'N+ C, N+ -> 1', @@ -1375,36 +1178,22 @@ def CrossEntropy(signature, inputs): ] return IRDimops( CrossEntropy, 'cross_entropy', - signature, annos, [tensor, target], + signature, annos, [input, target], weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce, reduction=reduction, label_smoothing=label_smoothing ) - -def MultiRef(signature, inputs: List[IRTensor]): - """ - cube.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] - """ - signature = 'cube.runtime.function.multiref' - itensor, times = inputs - assert isinstance(itensor, IRTensor), "require all inputs to be IRSubTensor" - assert isinstance(times, int), "require int for second input" - anno = '* -> ' + ', '.join('*' for _ in range(times)) - node = IRDimops(MultiRef, 'multiref', signature, [anno], [itensor], times=times) - return node - - -def GraphAnchor(signature, inputs: List[IRSubTensor]): +def GraphAnchor(name: str, signature = None): """ cube.runtime.function.anchor() -> None """ - name: str = inputs[0] node = IRGraphAnchor(signature, name) return node -def _comparison(creator: Callable, f: Callable, name: str, signature: str, inputs): +def _comparison(creator: Callable, f: Callable, name: str, signature: str, + input, other): """ if both operands are scalars, returns bool. if one operand is a tensor, returns a broadcasted tensor with dtype being bool. @@ -1412,99 +1201,111 @@ def _comparison(creator: Callable, f: Callable, name: str, signature: str, input @param creator Callable: the outside creation function @param f Callable: (Scalar, Scalar) -> bools """ - assert len(inputs) == 2 - lhs, rhs = inputs - - if isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)): - return f(lhs, rhs) - - annos = [ - '*, ? -> *', - '?, * -> *', - ] - if isinstance(lhs, IRTensor) and isinstance(rhs, IRTensor): - lshape, rshape, oshape = _handle_broadcast(lhs, rhs) + # case 0: return constant + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return f(input, other) + # case1: torch.equal(tensor1, tensor2) + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(creator, name, signature, annos, inputs) + return IRDimops(creator, name, signature, annos, [input, other]) + # case2: torch.equal(tensor1, obj2) / torch.equal(obj1, tensor2) + if isinstance(input, IRTensor) or isinstance(other, IRTensor): + annos = ['*, ? -> *', '?, * -> *',] + return IRDimops(creator, name, signature, annos, [input, other]) + # case3: torch.equal(obj1, obj2) + else: + return IRPyFunc(signature, [input, other], [IRObject()]) -def CompareGT(signature, inputs): +def CompareGT(input, other, *, out=None, signature = None): """ torch.gt(input, other, *, out=None) -> Tensor """ - return _comparison(CompareGT, operator.gt, 'gt', signature, inputs) + return _comparison(CompareGT, operator.gt, 'gt', signature, input, other) -def CompareLT(signature, inputs): +def CompareLT(input, other, *, out=None, signature = None): """ torch.lt(input, other, *, out=None) -> Tensor """ - return _comparison(CompareLT, operator.lt, 'lt', signature, inputs) + return _comparison(CompareLT, operator.lt, 'lt', signature, input, other) -def CompareGE(signature, inputs): +def CompareGE(input, other, *, out=None, signature = None): """ torch.ge(input, other, *, out=None) -> Tensor """ - return _comparison(CompareGE, operator.ge, 'ge', signature, inputs) + return _comparison(CompareGE, operator.ge, 'ge', signature, input, other) + -def CompareLE(signature, inputs): +def CompareLE(input, other, *, out=None, signature = None): """ torch.gt(input, other, *, out=None) -> Tensor """ - return _comparison(CompareLE, operator.le, 'le', signature, inputs) + return _comparison(CompareLE, operator.le, 'le', signature, input, other) -def ShapeAsTensor(signature, inputs): - assert len(inputs) == 1 - input = inputs[0] +def CompareEQ(input, other, *, out=None, signature = None): + """ + torch.eq(input, other, *, out=None) + """ + return _comparison(CompareEQ, operator.eq, 'eq', signature, input, other) + +def CompareNE(input, other, *, out=None, signature = None): + """ + torch.ne(input, other, *, out=None) + """ + return _comparison(CompareNE, operator.eq, 'ne', signature, input, other) + + +def ShapeAsTensor(input: IRTensor, signature = None): + """ + torch._shape_as_tensor + """ edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = [str(len(input.shape))] anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(ShapeAsTensor, '_shape_as_tensor', signature, [anno], [input]) # ================== Non-autograd Function Space ================= -def Size(signature, inputs) -> Union[List[int], IRPyFunc]: +def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: """ torch.Tensor.size(tensor, dim=None) """ - assert len(inputs) == 2 or len(inputs) == 1, f"but got {len(inputs)}, {inputs}" - tensor, dim = inputs if len(inputs) == 2 else (inputs[0], None) assert isinstance(tensor, IRTensor) # constant if all(isinstance(dimlen, int) for dimlen in tensor.shape) and not isinstance(dim, IRObject): return tensor.shape[dim] if isinstance(dim, int) else list(tensor.shape) - return IRPyFunc(signature, inputs, [IRObject()]) + return IRPyFunc(signature, [tensor, dim], [IRObject()]) -def GetItem(signature, inputs) -> Union[Any, IRPyFunc]: +def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: """ _operator.getitem(obj, index: int) """ - assert len(inputs) == 2, f"but got {inputs}" - obj, index = inputs + obj, index = a, b if (not isinstance(obj, IRObject)) and isinstance(index, int): return obj[index] else: - return IRPyFunc(signature, inputs, [IRObject()]) + return IRPyFunc(signature, [obj, index], [IRObject()]) + -def GetAttr(signature, inputs) -> Union[List[int], IRPyFunc]: +def GetAttr(instance: object, field: str, signature=None) -> Union[List[int], IRPyFunc]: """ builtins.getattr(object, name[, default]) NOTE: only deal with the attr "shape" of IRFullTensor, because other type of object may not have instantiated object or the attr is not simple value. """ - assert len(inputs) == 2, f"but got {inputs}" - obj, name = inputs + obj, name = instance, field if name == 'shape': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" assert hasattr(obj, name), f"attr {name} is not existed in {obj}" return getattr(obj, name) else: # FIXME: is it right? - return IRPyFunc(signature, inputs, [IRObject()]) + return IRPyFunc(signature, [instance, field], [IRObject()]) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index ac0fbc3d..19529200 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -485,7 +485,7 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens f"{graph.debug_tensor_map_str(ftensor)}" ) # set concat input / output - node = Cat('torch.cat', (ptensors, catdim)) + node = Cat(ptensors, dim=catdim) node.set_output(0, new_ftensor.select(otensor.indmap, otensor.valmap)) # set gradient for idx, ptensor in enumerate(ptensors): @@ -518,7 +518,7 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens for ptensor in ptensors[1:]: rhs = ptensor output = ftensor.like().select(ptensors[0].indmap, (0,1)) - node = Accum('cube.runtime.accum', [lhs, rhs]) + node = Accum(lhs, rhs) node.set_output(0, output) node.device = devid node.recompute = rcid @@ -528,7 +528,7 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens graph.remove(node) # === Orignal way to at alst release tensor - # node = Accum('cube.runtime.accum', ptensors) + # node = Accum(*ptensors) # # set gradient # for idx, ptensor in enumerate(ptensors): # node.input(idx).grad = ftensor.grad.select(ptensor.indmap, (0,1)) @@ -619,7 +619,7 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): f"{graph.mirror.debug_tensor_map_str(ftensor.grad)}" ) - multiref = MultiRef(None, [devtensors[devid][0], len(grads)]) + multiref = MultiRef(devtensors[devid][0], len(grads)) # set input gradient multiref.input(0).grad = accum_grad # set output and its gradient @@ -660,7 +660,7 @@ def autoref(graph: IRSegment) -> IRGraph: ftensor: IRFullTensor = multiref.input(0).parent multirefs = [] for otensor in graph.ptensors(ftensor): - mr = MultiRef(None, [otensor, len(multiref.outputs())]) + mr = MultiRef(otensor, len(multiref.outputs())) for idx in range(len(multiref.outputs())): output = multiref.output(idx).parent.select(otensor.indmap, otensor.valmap) if otensor.requires_grad: diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5528dd4b..9cd54f85 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -742,7 +742,7 @@ def get_sid(fnode: IRCell) -> Optional[int]: return None def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: - fwop = Identity('', [tensor]) + fwop = Identity(tensor) fwop.infer_shape() fwop.set_output(0, fwop.output(0).tosub()) if tensor.requires_grad: @@ -898,7 +898,7 @@ def auto_multiref(self): otensors = [] for otensor in node.outputs(): otensors.append(otensor.parent.select(ctensor.indmap, ctensor.valmap)) - multiref = MultiRef('', [itensor, len(otensors)]) + multiref = MultiRef(itensor, len(otensors)) for idx, otensor in enumerate(otensors): multiref.set_output(idx, otensor) multiref.device = devid @@ -913,7 +913,7 @@ def auto_multiref(self): outputs = [] for output in node.outputs(): outputs.append(output.parent.select(ptensor.indmap, ptensor.valmap)) - multiref = MultiRef('', [ptensor, len(outputs)]) + multiref = MultiRef(ptensor, len(outputs)) for idx, otensor in enumerate(outputs): multiref.set_output(idx, otensor) multiref.device = devid diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 622cc497..53acdd89 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -75,8 +75,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('triu'): function.Triu, __ftemplate('relu'): function.ReLU, __fcntemplate('gelu'): function.GeLU, - __ttemplate('eq') : function.EQ, - __ttemplate('ne') : function.NE, + __ttemplate('eq') : function.CompareEQ, + __ttemplate('ne') : function.CompareNE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, __ttemplate('fill_'): function.Fill, diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index ffdf02fe..d10aba18 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -238,7 +238,7 @@ def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: input_vals.append(val) # map to IR operator - ir_node = Sign2Op.map(fsig)(inputs=input_vals) + ir_node = Sign2Op.map(fsig)(*input_vals) # push output in the frame # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) @@ -309,7 +309,7 @@ def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: # May be a symbolic object i.e. IRFwOperation, # or, occasionally this node can be statically evaluated, therefore a concrete value - result = Sign2Op.map(fsig)(inputs=input_val) + result = Sign2Op.map(fsig)(*input_val) if isinstance(result, IRFwOperation): # to create IR node diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 9657be84..71112cae 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -6,12 +6,10 @@ from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor from cube.ir.cten import IRObject, IRCell -import cube.ir as ir from cube.graph.parser.frame import Frame from cube.graph.parser.mapping import DType2IRDType from cube.graph.parser.mappingfx import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.function import Identity import torch.fx @@ -229,46 +227,25 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule else: print(f'parse_prim_function_node: {fsig}') - def handle_tuple(fx_node: tuple) -> tuple: - vals = [] - for ele in fx_node: - if isinstance(ele, torch.fx.Node): - vals.append(frame.get_var(ele.name)) - elif isinstance(ele, tuple): - vals.append(handle_tuple(ele)) - else: - assert not isinstance(ele, (list, dict)) - vals.append(ele) - return tuple(vals) - - def extract_val(fx_node): - if isinstance(fx_node, torch.fx.Node): - var_name = fx_node.name - return frame.get_var(var_name) - elif isinstance(fx_node, (int, float, str, torch.dtype)) or fx_node is None: - return fx_node - elif isinstance(fx_node, (tuple, list)): - return handle_tuple(fx_node) - else: - raise RuntimeError(f'Unsupported input node {fx_node}, {type(fx_node)} in parse function!') - + def get_complex_data(val: Any) -> Any: + """Change inner fx.Node into IRObject""" + if isinstance(val, tuple): + return tuple(get_complex_data(t) for t in val) + if isinstance(val, list): + return list(get_complex_data(t) for t in val) + if isinstance(val, torch.fx.Node): + return frame.get_var(val.name) + return val + # get inputs - input_vals = list() - for item in node.args: - input_vals.append(extract_val(item)) - if node.kwargs: - input_kwvals = {} - for k, v in node.kwargs.items(): - input_kwvals[k] = extract_val(v) - input_vals.append(input_kwvals) + input_vals = [get_complex_data(val) for val in node.args] + kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} # map to IR operator if SignFx2Op.exist(fsig): - ir_node = SignFx2Op.map(fsig)(inputs=input_vals) + ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) else: - input_vals = [extract_val(v) for v in node.args] # FIXME: handle cases for IRObject in kwargs - kwargs = {key: extract_val(v) for key, v in node.kwargs.items()} # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): print(f'>>> Find unkown pytorch operation: {fsig}') diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 6be4e316..80047d63 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -58,14 +58,13 @@ def decorator(fn: Callable): code = inspect.getsource(fn) code = code[code.index('def'):] - def udfop(signature: str, inputs: List[Any]): + def udfop(*args, signature=None, **kwargs): manno = OpAnno(anno) - tensors = inputs[:ninputs] + tensors = args[:ninputs] for idx in range(ninputs): if arg_kinds[idx] == Optional[torch.Tensor] and tensors[idx] is None: manno.set_input(idx, '?') - kwarg_vals = inputs[ninputs:] - kwargs = dict() + kwarg_vals = args[ninputs:] for name, val in zip(kwarg_names, kwarg_vals): kwargs[name] = val return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, transform_rules=rules, **kwargs) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 7c66bf8c..1e97c257 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -663,7 +663,7 @@ def multiref(self, ftensor: IRFullTensor, node_groups: List[List[IRFwOperation]] ftensors: List[IRSubTensor] = [ftensor.like() for _ in node_groups] otensors: List[IRSubTensor] = [ft.select(tensor.indmap, tensor.valmap) for ft in ftensors] # create multiref - multiref = MultiRef('cube.runtime.function.multiref', [tensor, len(node_groups)]) + multiref = MultiRef(tensor, len(node_groups)) for idx, otensor in enumerate(otensors): multiref.set_output(idx, otensor) # setup gradient @@ -767,7 +767,7 @@ def single_consume(self, one_for_all: bool = True): consumer = cnodes.pop(0) if len(cnodes) > 0: itensors = [ftensor.like() for _ in range(2)] - multiref = MultiRef(None, [reftensor, 2]) + multiref = MultiRef(reftensor, 2) for idx, itensor in enumerate(itensors): multiref.set_output(idx, itensor) multiref.infer_shape() @@ -807,7 +807,7 @@ def single_consume(self, one_for_all: bool = True): idx = consumer.inputs().index(ftensor) consumer.set_input(idx, itensor) # create and insert multiref operation - multiref = MultiRef(None, [ftensor, len(cnodes)]) + multiref = MultiRef(ftensor, len(cnodes)) for idx, itensor in enumerate(itensors): multiref.set_output(idx, itensor) multiref.infer_shape() diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 50c7cfd7..de66022b 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -65,7 +65,8 @@ def forward(self, data, mask): x = x.fill_(0.0) x = torch.nn.functional.softmax(x, dim=-1) x = torch.bmm(x, x) - x = torch.baddbmm(x, x, x, alpha=0.125, beta=1.0) + adder = torch.sum(x, dim=2, keepdim=True) + x = torch.baddbmm(adder, batch1=x, batch2=x, alpha=0.125, beta=1.0) x = torch.tanh(x) x = torch.pow(x, x) for layer in self.layers: From 7bd5484ab71ee965d3acf3ed57fddcea0cde63d6 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 6 Mar 2023 23:19:42 -0800 Subject: [PATCH 1321/1892] refine code --- cube/profiler/database.py | 12 ++++++------ cube/runtime/schedule/sched1f1b.py | 2 -- examples/alphafold2/alphafold2.py | 7 +++---- examples/alphafold2/model.py | 4 ++++ examples/alphafold2/policy/spmd.py | 2 +- examples/nlp/gpt/model.py | 3 --- scripts/megatron.sh | 2 -- scripts/pre_install.sh | 2 -- 8 files changed, 14 insertions(+), 20 deletions(-) delete mode 100644 scripts/megatron.sh delete mode 100644 scripts/pre_install.sh diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 61dae614..911ea9ec 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -216,14 +216,14 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): torch.cuda.set_device(device) in_mem_info, param_mem_info = [], [] - Residual_mem, input_count = 0, 0 + residual_mem, input_count = 0, 0 for t in node.inputs(): if t.is_param(): param_mem_info.append(t.byte_size()) else: input_count += 1 if input_count == 1: - Residual_mem += t.byte_size() + residual_mem += t.byte_size() in_mem_info.append(t.byte_size()) @@ -232,7 +232,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): CompProfiler.profile(fn, shapes, dtypes, **kwargs) # log to database key = self._serialize(node) - self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, Residual_mem) + self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | fw: {round(fw_span, 2)} ms | " @@ -240,11 +240,11 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): if isinstance(device, int): torch.cuda.set_device(orig_device) - return tuple(in_mem_info), tuple(param_mem_info), fw_span, bw_span, infer_memory, train_mem_info, Residual_mem + return tuple(in_mem_info), tuple(param_mem_info), fw_span, bw_span, infer_memory, train_mem_info, residual_mem def insert(self, name: str, key: str, in_mem_info: Tuple[int], param_mem_info: Tuple[int], fw_span: float, bw_span: float, infer_memory: int, train_mem_info: Tuple[int], - Residual_mem: int): + residual_mem: int): """ log the span of a function name with key @@ -260,7 +260,7 @@ def insert(self, name: str, key: str, in_mem_info: Tuple[int], param_mem_info: T assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = (in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, Residual_mem) + self._data[name][key] = (in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) def exist(self, node: IRFwOperation) -> bool: """ diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py index e2c9d298..c8d4bdde 100644 --- a/cube/runtime/schedule/sched1f1b.py +++ b/cube/runtime/schedule/sched1f1b.py @@ -22,8 +22,6 @@ def run(segment: Callable, # forward body # special case: num_stages == 1: use gradient accum if num_stages == 1: for _ in range(num_microbatch): - # if torch.distributed.get_rank() == 0: - # print(_) inputs = Schedule1F1B.dataloader_step(dataloader) outputs = Schedule1F1B.forward_step(segment, *inputs) input_grads = Schedule1F1B.backward_step(inputs, outputs, (None,)) diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 3dd72d5f..5936be61 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -94,14 +94,14 @@ def train_iter(model, dataloader): def test_main(): # Training && Evoformer Stack # initial training - # bs, s, r, cm, cz = 1, 128, 256, 256, 128 + bs, s, r, cm, cz = 1, 128, 256, 256, 128 # first fine-tuning # bs, s, r, cm, cz = 1, 512, 256, 256, 128 # second fine-tuning # bs, s, r, cm, cz = 1, 512, 384, 256, 128 - bs, s, r, cm, cz = 1, 512, 512, 256, 128 + # bs, s, r, cm, cz = 1, 512, 512, 256, 128 - dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 4, False, True, False + dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False policy = spmd.PASDAP # Training && Extra Sequence @@ -127,4 +127,3 @@ def test_main(): if __name__ == '__main__': cube.init() test_main() - # build_alphafold_config(1) diff --git a/examples/alphafold2/model.py b/examples/alphafold2/model.py index a2be9bfe..b78f4ded 100644 --- a/examples/alphafold2/model.py +++ b/examples/alphafold2/model.py @@ -5,11 +5,15 @@ from examples.alphafold2.module import * from dataclasses import dataclass + + """ a simplified version for evoformer in alphafold2 - dropout layers are omitted - masks are omitted """ + + @dataclass class Config: bs: int = 1 diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 9e734f7e..24d07f1c 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -171,7 +171,7 @@ def PASDAP(graph: IRGraph, resource): if isinstance(fnodes[j], IRGraphAnchor): sub_indices.append(j) sub_indices.append(rhs) - # graph.recompute(fnodes[lhs:rhs]) + graph.recompute(fnodes[lhs:rhs]) for j in range(len(sub_indices) - 1): sub_l, sub_r = sub_indices[j], sub_indices[j + 1] names = [] diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index c13915e4..910908af 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -234,6 +234,3 @@ def random_sample(self): device=torch.cuda.current_device() ).repeat(self.batch_size).view(self.batch_size, -1) return input_ids, position_ids - - def __next__(self): - return self.samples[0] diff --git a/scripts/megatron.sh b/scripts/megatron.sh deleted file mode 100644 index 82466184..00000000 --- a/scripts/megatron.sh +++ /dev/null @@ -1,2 +0,0 @@ -torchrun --nproc_per_node=2 --nnodes=1 \ - examples/nlp/gpt/train.py --policy=PASMegatronWSRTP --fp16 | tee -a LogForMegatronRecompute.txt \ No newline at end of file diff --git a/scripts/pre_install.sh b/scripts/pre_install.sh deleted file mode 100644 index 08e65899..00000000 --- a/scripts/pre_install.sh +++ /dev/null @@ -1,2 +0,0 @@ -pip install -r requirements.txt -python setup.py develop --user \ No newline at end of file From 69464127315931811c1490e22d0cb46514efe555 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Tue, 7 Mar 2023 03:17:56 -0800 Subject: [PATCH 1322/1892] debug --- cube/algorithm/ops/dimops.py | 6 +++++- cube/graph/function/function.py | 12 +++++++++++- cube/graph/parser/mappingfx.py | 1 + cube/graph/parser/parserfx.py | 1 + cube/profiler/database.py | 32 ++++++++++++++++++++++---------- 5 files changed, 40 insertions(+), 12 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index dc412f84..28764af5 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -331,7 +331,11 @@ def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: def gen_hash(node: IRFwOperation) -> str: ret = node.signature for it in node.inputs(): - ret = ret + '-' + str(it.shape) + # FIXME: this is hack + if isinstance(it, float): + ret = ret + '-' + str(it) + else: + ret = ret + '-' + str(it.shape) return ret dq = deque() diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index d9c9b987..f88570c6 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1295,7 +1295,7 @@ def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: return IRPyFunc(signature, [obj, index], [IRObject()]) -def GetAttr(instance: object, field: str, signature=None) -> Union[List[int], IRPyFunc]: +def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], IRPyFunc]: """ builtins.getattr(object, name[, default]) NOTE: only deal with the attr "shape" of IRFullTensor, because other type of object may not @@ -1306,6 +1306,16 @@ def GetAttr(instance: object, field: str, signature=None) -> Union[List[int], IR assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" assert hasattr(obj, name), f"attr {name} is not existed in {obj}" return getattr(obj, name) + elif name == 'dtype': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + assert hasattr(obj, name), f"attr {name} is not existed in {obj}" + return getattr(obj, name) + elif isinstance(obj, torch.finfo): + return getattr(obj, name) else: # FIXME: is it right? return IRPyFunc(signature, [instance, field], [IRObject()]) + +def FInfo(dtype: IRDType, signature = None) -> torch.finfo: + assert isinstance(dtype, IRDType) + return torch.finfo(eval('torch.' + dtype.value)) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 53acdd89..4172f372 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -93,6 +93,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('detach'): function.Detach, __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, __ttemplate('index_select'): function.IndexSelect, + __ttemplate('finfo'): function.FInfo, __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 71112cae..93689324 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -243,6 +243,7 @@ def get_complex_data(val: Any) -> Any: # map to IR operator if SignFx2Op.exist(fsig): + print('zql: ', input_vals, kwargs) ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) else: # FIXME: handle cases for IRObject in kwargs diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 079baaf7..a4fa8226 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -11,7 +11,8 @@ import cube from cube.ir.cten import IRTensor from cube.ir.operator import IRFwOperation -from cube.graph.parser.mapping import Sign2Op, IRDType2TorchDType +from cube.graph.parser.mapping import IRDType2TorchDType +from cube.graph.parser.mappingfx import SignFx2Op as Sign2Op Shapes = NewType('Shapes', Tuple[Tuple[int]]) @@ -50,12 +51,13 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, # create data dtypes = [torch.float32] * len(shapes) if dtypes is None else dtypes def gen_torch_tensors(shape, dtype): - constructor = torch.zeros if dtype == torch.int64 else torch.rand - requires_grad = False if dtype == torch.int64 else True + constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand + requires_grad = False if dtype in (torch.int64, torch.int32, torch.bool) else True return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) tensors = tuple( gen_torch_tensors(shape, dtype) for shape, dtype in zip(shapes, dtypes) ) + require_backward = any([t.requires_grad for t in tensors]) # repalce kwargs starting with 'self.xxx' train_kwargs, eval_kwargs = {}, {} for name, value in kwargs.items(): @@ -108,12 +110,12 @@ def unpack_hook(x): torch.cuda.synchronize() torch.cuda.empty_cache() with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): - outs = run_step(func, tensors, train_kwargs, backward=True) + outs = run_step(func, tensors, train_kwargs, backward=require_backward) # warmup tic = time.time() while time.time() - tic < warmup_sec: - run_step(func, tensors, train_kwargs, backward=True) + run_step(func, tensors, train_kwargs, backward=require_backward) # profile forward only torch.cuda.synchronize() @@ -129,7 +131,7 @@ def unpack_hook(x): torch.cuda.synchronize() tic = time.perf_counter() for _ in range(prof_times): - run_step(func, tensors, train_kwargs, backward=True) + run_step(func, tensors, train_kwargs, backward=require_backward) torch.cuda.synchronize() toc = time.perf_counter() fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds @@ -169,6 +171,8 @@ def get_dep_names(sign: str): return ret if node.signature in Sign2Op.kOpCodeDef: + # FIXME: ... + assert False, 'Sing2Op.kOpCodeDef is not empty' dep_code_impl = '' for dep_name in get_dep_names(node.signature): dep_code_impl = dep_code_impl + Sign2Op.kOpCodeDef[dep_name] @@ -184,7 +188,10 @@ def get_dep_names(sign: str): exec(code_impl, globals(), local) fn = list(local.values())[0] else: - fn = eval(node.signature) + if '_operator.' in node.signature: + fn = eval(node.signature.replace('_operator.', 'torch.')) + else: + fn = eval(node.signature) shapes, dtypes = [], [] for t in node.inputs(): assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" @@ -206,10 +213,15 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @return infer_memory int: the peak memory in bytes after inference of the function @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ + if node.name in ('mul',): + key = self._serialize(node) + self.insert(node.signature, key, (0,), (0,), 0, 0, 0, (0,), 0) + return (0,), (0,), 0, 0, 0, (0,), 0 fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) if self.exist(node): - return self.query(node) + ret = list(self.query(node)) + return ret + [0] if isinstance(device, int): orig_device = torch.cuda.current_device() @@ -227,7 +239,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): CompProfiler.profile(fn, shapes, dtypes, **kwargs) # log to database key = self._serialize(node) - self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info) + self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, 0) print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | fw: {round(fw_span, 2)} ms | " @@ -235,7 +247,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): if isinstance(device, int): torch.cuda.set_device(orig_device) - return tuple(in_mem_info), tuple(param_mem_info), fw_span, bw_span, infer_memory, train_mem_info + return tuple(in_mem_info), tuple(param_mem_info), fw_span, bw_span, infer_memory, train_mem_info, 0 def insert(self, name: str, key: str, in_mem_info: Tuple[int], param_mem_info: Tuple[int], fw_span: float, bw_span: float, infer_memory: int, train_mem_info: Tuple[int]): From 6297ffc1ed643548b8c35d577ffb27352ce12b84 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Tue, 7 Mar 2023 04:47:43 -0800 Subject: [PATCH 1323/1892] bloom partition --- cube/graph/function/function.py | 20 +++- cube/profiler/database.py | 11 +-- cube/runtime/function/function.py | 4 +- tests/parser/test_bloom.py | 150 ++++++++++++++++++++++++++++-- 4 files changed, 167 insertions(+), 18 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index f88570c6..4295560d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -349,15 +349,25 @@ def Sub(input, other, alpha=1, *, out=None, signature = None): return IRDimops(Sub, 'sub', signature, annos, [input, other], alpha=1) +def CubeMul(input, other, *, out=None, signature = None): + signature = 'cube.runtime.function.mul' + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(CubeMul, 'mul', signature, annos, [input, other]) + else: + annos = ['* -> *'] + if isinstance(input, IRTensor): + return IRDimops(CubeMul, 'mul', signature, annos, [input], other=other) + else: + return IRDimops(CubeMul, 'mul', signature, annos, [other], other=input) + + def Mul(input, other, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input * other - annos = ['*, ? -> *', '?, * -> *',] - if isinstance(input, IRTensor) and isinstance(other, IRTensor): - lshape, rshape, oshape = _handle_broadcast(input, other) - annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Mul, 'mul', signature, annos, [input, other]) + return CubeMul(input, other, out=out, signature=signature) def Div(input, other, *, rounding_mode=None, out=None, signature = None): diff --git a/cube/profiler/database.py b/cube/profiler/database.py index aa941d35..67d5b08f 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -213,15 +213,14 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @return infer_memory int: the peak memory in bytes after inference of the function @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ - if node.name in ('mul',): + if node.name in ('mul', 'expand'): key = self._serialize(node) self.insert(node.signature, key, (0,), (0,), 0, 0, 0, (0,), 0) return (0,), (0,), 0, 0, 0, (0,), 0 fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) if self.exist(node): - ret = list(self.query(node)) - return ret + [0] + return self.query(node) if isinstance(device, int): orig_device = torch.cuda.current_device() @@ -360,9 +359,9 @@ def _serialize(self, node: IRFwOperation) -> str: """ shapes, dtypes = [], [] for t in node.inputs(): - assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" - shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) + if isinstance(t, IRTensor):#, f"Only support node inputs with tensor shape" + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) shapes = '-'.join(str(tuple(shape)) for shape in shapes) dtypes = '-'.join(str(dtype) for dtype in dtypes) return shapes + ' : ' + dtypes diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index bd7c5ec3..02546ace 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Tuple +from typing import Optional, List, Tuple, Union import torch import torch.nn.functional as TorchF @@ -32,6 +32,8 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) +def mul(input: torch.Tensor, other: Union[float, torch.Tensor]) -> torch.Tensor: + return torch.mul(input, other) def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], stride: int, padding: List[int], dilation, groups: int = 1): diff --git a/tests/parser/test_bloom.py b/tests/parser/test_bloom.py index e286606d..934fc1cd 100644 --- a/tests/parser/test_bloom.py +++ b/tests/parser/test_bloom.py @@ -1,3 +1,4 @@ +from pathlib import Path import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig @@ -28,15 +29,152 @@ print("parsing fx graph to cube graph...") from cube.graph.parser import FxModuleParser -cube_graph = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) +inputs, nodes, outputs = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) print("parsing done.") +from cube.graph import IRGraph +module_name = model.__class__.__name__ +cube_graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) -# # AutoDist +# AutoDist # # profile communication cost # import os # comm_gpu_num = (2, 4) # for gpu_num in comm_gpu_num: -# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --comm_profile_dir=./ --connect_type=NV') -# # find the best partition plan -# from autodist.apis import compile -# compile(cube_graph, ...) +# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --connect_type=NV') +# profile computation cost +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ +config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': 'bloom'}) +config.autodist_config = dotdict({'ngpus': 2}) +# NOTE add SINGLE_DEV_MODE=1 before the running command +from autodist.cost_model.cost_database import CostDatabase +cost_database = CostDatabase(cube_graph, config) +# find the best partition plan +from autodist.task_config import TaskConfig +class BloomTaskConfig(TaskConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model = 'Bloom' + # self.Bloom_setting = kwargs['Bloom_setting'] + # self.fine_grained_Bloom = kwargs['fine_grained_Bloom'] + # self.bloom_config = build_bloom_config(self.Bloom_setting) + self.task_name = f'bloom-{self.autodist_config.ngpus}gpu-'\ + f'{self.autodist_config.micro_batch_size}batch_size' + self.estimated_fname, self.backup_fname, self.runtime_fname = self._build_file_name( + self.task_name) + self.allow_recom_ops = [] + self.del_dim = [] +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Bloom benchmark') + parser.add_argument('--fp16', + action='store_true', + help='use fp16 for the training') + parser.add_argument('--fine_grained_GPT', + action='store_true', + help='model = GPTFineGrained') + parser.add_argument('--GPT_setting', + type=str, + default='6.7B', + help='set GPT model type') + parser.add_argument('--save_folder', + type=str, + default='exp_data', + help='set the save folder for experiment data') + parser.add_argument('--micro_batch_size', + type=int, + default=8, + help='set micro batch size') + parser.add_argument('--global_batch_size', + type=int, + default=8, + help='set the global batch size') + parser.add_argument('--iter_num', + type=int, + default=2, + help='set the number of all iterations') + parser.add_argument('--warm_num', + type=int, + default=1, + help='set the number of warmup iterations') + parser.add_argument('--recompute', + action='store_true', + help='set recompute flag') + parser.add_argument('--memory_constraint', + type=float, + default=32, + help='memory constraint for program') + parser.add_argument('--memory_granularity', + type=int, + default=1, + help='memory granularity in byte') + parser.add_argument('--profile_dir', + type=str, + default=str(Path.home()) + '/.autodist', + help='profile dir') + parser.add_argument('--connect_type', + type=str, + default='NV2', + help='connect type from nvidia-smi topo -m') + parser.add_argument('--use_prev_plan', + action='store_true', + help='run from previous plan') + parser.add_argument('--is_train', + action='store_true', + help='True: train, False: inference') + parser.add_argument('--topk', + type=int, + default=20, + help='generate multiple plans for robustness') + parser.add_argument('--mesh_row', type=int, default=1, help='node num') + parser.add_argument('--mesh_col', + type=int, + default=2, + help='dev num in a node') + parser.add_argument('--compile', + action='store_true', + help='compile stage: true, runtime stage: false') + parser.add_argument('--pipeline', + action='store_true', + help='pipeline: true, tensor parallel: false') + parser.add_argument('--nproc', + type=int, + default=12, + help='multiprocess deg in pipeline') + parser.add_argument('--adaptive_recom', + action='store_true', + help='allow adaptive recompute') + parser.add_argument('--plan_idx', + type=int, + default=0, + help='runtime plan idx') + parser.add_argument('--verbose', action='store_true', help='verbose mode') + parser.add_argument('--ignore_small_tensor_threshold', + type=int, + default=0, + help='set the tensor size threshold to ignore') + parser.add_argument('--parse_plan', + action='store_true', + help='parse plan to user-friendly format') + parser.add_argument('--alphafold', + action='store_true', + help='use alphafold2') + parser.add_argument('--alphafold_setting', + type=int, + default=1, + help='1: bs, s, r = 1, 128, 256'\ + '2: bs, s, r = 1, 512, 256'\ + '3: bs, s, r = 1, 512, 384') + args = parser.parse_args() + + # if args.compile: + # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' + + task_config = BloomTaskConfig(**vars(args)) + from autodist.apis import calc_parallel_plan + topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) + # from autodist.apis import compile + # compile(cube_graph, None, task_config) From 5a9e59a41bdc6183238156d26fd30014e89ee8de Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 8 Mar 2023 03:58:33 -0800 Subject: [PATCH 1324/1892] support more ops with IRDimops --- cube/graph/function/function.py | 82 ++++++++++++++++++++++++++----- cube/graph/parser/mappingfx.py | 7 ++- cube/graph/parser/parserfx.py | 4 ++ cube/profiler/database.py | 4 -- cube/runtime/function/function.py | 11 +++++ 5 files changed, 90 insertions(+), 18 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 4295560d..c4ef5c18 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -308,6 +308,7 @@ def Expand(input, *sizes, signature = None): """ torch.Tensor.expand(*sizes) """ + signature = 'cube.runtime.function.expand' edim_in = ShapeAnno.create_shape_str(input.shape) assert len(input.shape) == len(sizes) for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes)): @@ -327,26 +328,55 @@ def Clone(input, *, memory_format=None, signature = None): return IRDimops(Clone, 'clone', signature, annos, [input]) +def BitwiseOr(input, other, *, out=None, signature=None): + """ + torch.bitwise_or(input, other, *, out=None) → Tensor + """ + assert isinstance(input, IRTensor) and isinstance(other, IRTensor) + annos = ['* -> *'] + return IRDimops(BitwiseOr, 'bitwise_or', signature, annos, [input, other]) + + +def CubeAdd(input, other, alpha=1, *, out=None, signature = None): + signature = 'cube.runtime.function.add' + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) + else: + annos = ['* -> *'] + if isinstance(input, IRTensor): + return IRDimops(CubeAdd, 'add', signature, annos, [input], other=other, alpha=alpha) + else: + return IRDimops(CubeAdd, 'add', signature, annos, [other], other=input, alpha=alpha) + + def Add(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input + alpha * other - annos = ['*, ? -> *', '?, * -> *',] + return CubeAdd(input, other, alpha, out=out, signature=signature) + + +def CubeSub(input, other, alpha=1, *, out=None, signature = None): + signature = 'cube.runtime.function.sub' if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) + return IRDimops(Sub, 'sub', signature, annos, [input, other], alpha=alpha) + else: + annos = ['* -> *'] + if isinstance(input, IRTensor): + return IRDimops(CubeAdd, 'sub', signature, annos, [input], other=other, alpha=alpha) + else: + return IRDimops(CubeAdd, 'sub', signature, annos, [other], other=input, alpha=alpha) def Sub(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input - alpha * other - annos = ['*, ? -> *', '?, * -> *',] - if isinstance(input, IRTensor) and isinstance(other, IRTensor): - lshape, rshape, oshape = _handle_broadcast(input, other) - annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Sub, 'sub', signature, annos, [input, other], alpha=1) + return CubeSub(input, other, alpha, out=out, signature=signature) def CubeMul(input, other, *, out=None, signature = None): @@ -1159,6 +1189,17 @@ def Roll(input, shifts: Union[int, Tuple[int]], dims=None, signature = None): return IRDimops(Roll, 'roll', signature, [anno], [input], shifts=shifts, dims=dims) +def Inverse(input, *, out=None, signature=None): + """ + torch.inverse(input, *, out=None) → Tensor + """ + ishape = ShapeAnno.create_shape_str(input.shape) + ishape = [i + '^' for i in ishape] + oshape = copy.copy(ishape) + anno = OpAnno.create_op_str([ishape], [oshape]) + return IRDimops(Inverse, 'inverse', signature, [anno], [input]) + + def AdaptiveAvgPool1d(input, output_size, signature = None): """ torch.nn.functional.adaptive_avg_pool2d(input, output_size) @@ -1294,6 +1335,23 @@ def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: return IRPyFunc(signature, [tensor, dim], [IRObject()]) +def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): + """ + torch.Tensor.to(*args, **kwargs) → Tensor + """ + assert out is None + # FIXME: support full version of torch.Tensor.to + # create "to" in cube runtime functions because dtype if not kwarg in torch.Tensor.to + signature = 'cube.runtime.function.to' + annos = ['* -> *'] + if isinstance(dtype_or_device, torch.device): + return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device) + else: + assert isinstance(dtype_or_device, (IRDType, torch.dtype)) + dtype = dtype_or_device if isinstance(dtype_or_device, torch.dtype) else eval('torch.'+dtype_or_device.value) + return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) + + def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: """ _operator.getitem(obj, index: int) @@ -1304,7 +1362,7 @@ def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: else: return IRPyFunc(signature, [obj, index], [IRObject()]) - + def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], IRPyFunc]: """ builtins.getattr(object, name[, default]) @@ -1312,14 +1370,14 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], have instantiated object or the attr is not simple value. """ obj, name = instance, field - if name == 'shape': + if name in ('shape', 'dtype'): assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" assert hasattr(obj, name), f"attr {name} is not existed in {obj}" return getattr(obj, name) - elif name == 'dtype': + elif name == 'device': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - assert hasattr(obj, name), f"attr {name} is not existed in {obj}" - return getattr(obj, name) + # FIXME: this is hack, IRFullTensor does not have attribute "device" + return torch.device('cpu') elif isinstance(obj, torch.finfo): return getattr(obj, name) else: diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 4172f372..b53102ed 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -94,6 +94,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, __ttemplate('index_select'): function.IndexSelect, __ttemplate('finfo'): function.FInfo, + __ttemplate('inverse'): function.Inverse, + __ttemplate('bitwise_or'): function.BitwiseOr, __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, @@ -101,6 +103,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # ============== runtime function ================= __tttemplate('size'): function.Size, + __tttemplate('to'): function.To, '_operator.getitem': function.GetItem, 'builtins.getattr': function.GetAttr, @@ -133,9 +136,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('clone'): function.Clone, # __ttemplate('add') : function.Add, - # + '_operator.add': function.Add, # __ttemplate('sub') : function.Sub, - # + '_operator.sub': function.Sub, # __ttemplate('mul') : function.Mul, '_operator.mul': function.Mul, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 93689324..d1a30a7a 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -241,10 +241,14 @@ def get_complex_data(val: Any) -> Any: input_vals = [get_complex_data(val) for val in node.args] kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} + if '.or' in fsig: + print('zql find to: ', fsig, node.name, node.target, node.meta, node.args, node.kwargs, input_vals, kwargs) + # exit(1) # map to IR operator if SignFx2Op.exist(fsig): print('zql: ', input_vals, kwargs) ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) + print('zql ir_node: ', ir_node) else: # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 67d5b08f..b51a72de 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -213,10 +213,6 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @return infer_memory int: the peak memory in bytes after inference of the function @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ - if node.name in ('mul', 'expand'): - key = self._serialize(node) - self.insert(node.signature, key, (0,), (0,), 0, 0, 0, (0,), 0) - return (0,), (0,), 0, 0, 0, (0,), 0 fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) if self.exist(node): diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 02546ace..0f83d4da 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -22,6 +22,8 @@ def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: """ return tensor if times == 1 else tuple([tensor] * times) +def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) -> torch.Tensor: + return tensor.to(dtype_or_device) def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: """ @@ -32,9 +34,18 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) +def add(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float]) -> torch.Tensor: + return torch.add(input, other, alpha=alpha) + +def sub(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float]) -> torch.Tensor: + return torch.sub(input, other, alpha=alpha) + def mul(input: torch.Tensor, other: Union[float, torch.Tensor]) -> torch.Tensor: return torch.mul(input, other) +def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: + return input.expand(*sizes) + def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], stride: int, padding: List[int], dilation, groups: int = 1): """ From ceb08f21bb3a8e65785519f63ba84e58f1c9911f Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 8 Mar 2023 04:12:09 -0800 Subject: [PATCH 1325/1892] resolve comments --- cube/algorithm/ops/dimops.py | 6 +----- cube/graph/function/function.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 28764af5..dc412f84 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -331,11 +331,7 @@ def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: def gen_hash(node: IRFwOperation) -> str: ret = node.signature for it in node.inputs(): - # FIXME: this is hack - if isinstance(it, float): - ret = ret + '-' + str(it) - else: - ret = ret + '-' + str(it.shape) + ret = ret + '-' + str(it.shape) return ret dq = deque() diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index c4ef5c18..ed89e85c 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -333,7 +333,7 @@ def BitwiseOr(input, other, *, out=None, signature=None): torch.bitwise_or(input, other, *, out=None) → Tensor """ assert isinstance(input, IRTensor) and isinstance(other, IRTensor) - annos = ['* -> *'] + annos = ['*, * -> *'] return IRDimops(BitwiseOr, 'bitwise_or', signature, annos, [input, other]) From ba9e92d8f723cfdaaf5589171d1eb00241ce3caa Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 8 Mar 2023 06:27:09 -0800 Subject: [PATCH 1326/1892] save work --- cube/graph/function/function.py | 44 ++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 4295560d..0c5e1c7f 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1299,7 +1299,49 @@ def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: _operator.getitem(obj, index: int) """ obj, index = a, b - if (not isinstance(obj, IRObject)) and isinstance(index, int): + + def try_reshape(tensor, expr): + if expr[0] == Ellipsis and expr[1] == None: + return True, copy.copy(tensor.shape) + [1] + dim_cnt = 0 + idx = 0 + dst_shape = [] + for item in expr: + if item == slice(None, None, None): + dst_shape.append(tensor.shape[idx]) + idx += 1 + elif item == None: + dst_shape.append(1) + else: + return False, [] + if idx != len(tensor.shape): + return False, [] + return True, dst_shape + + def try_select(tensor, expr): + int_cnt = 0 + dim = -1 + val = -1 + for i, item in enumerate(expr): + if isinstance(item, int): + int_cnt += 1 + dim = i + val = item + if int_cnt != 1: + return False, -1, -1 + if expr[0] == Ellipsis: + dim = dim - len(expr) + return True, dim, val + + if isinstance(obj, IRTensor): + is_reshape, dst_shape = try_reshape(obj, index) + if is_reshape: + return Reshape(obj, dst_shape, signature='torch.reshape') + is_select, dim, val = try_select(obj, index) + if is_select: + return Select(obj, dim, val, 'torch.select') + assert False, f'{obj}, {index}' + elif (not isinstance(obj, IRObject)) and isinstance(index, int): return obj[index] else: return IRPyFunc(signature, [obj, index], [IRObject()]) From ec942abf8aeadd7d180e167c9a2a6bfc9f086de7 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 8 Mar 2023 07:19:55 -0800 Subject: [PATCH 1327/1892] save work --- tests/parser/yizhu1_bloom.py | 219 +++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 tests/parser/yizhu1_bloom.py diff --git a/tests/parser/yizhu1_bloom.py b/tests/parser/yizhu1_bloom.py new file mode 100644 index 00000000..5bb44b14 --- /dev/null +++ b/tests/parser/yizhu1_bloom.py @@ -0,0 +1,219 @@ +from pathlib import Path +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +import time +import json + +def convert_mem_into_GB(mem): + if type(mem) in [int, float]: + return mem / 1024 / 1024 / 1024 + else: + return [x / 1024 / 1024 / 1024 for x in mem] + +model_name = "bigscience/bloom-560m" +model_path = "/home/quzha/bloom560m" + +print("Loading model...") +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, cache_dir=model_path) +print(type(model), '; is nn.Module? ', isinstance(model, nn.Module)) +print("Model's generation config which does not list default values: ", model.generation_config) +print("Loading tokenizer...") +tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) +print("Loading Done!") +prompt = "If I want to travel to a new city, I should plan my trip as follows:" +#input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() +inputs = tokenizer(prompt, return_tensors="pt") + +# Cube +# from cube.graph import parser +# ir_graph = parser.convert_model(model, input_shapes=[1, 17], save_content=False) + +print("concrete tracing model...") +import sys +nni_path = "/home/quzha/yizhu1/yizhu1_autodist/nni/" +sys.path.append(nni_path) + +# from concrete_trace_utils import concrete_trace +from nni.common.concrete_trace_utils import concrete_trace +# from cube.graph.parser.concrete_trace_utils import concrete_trace + +traced_graph = concrete_trace(model, inputs, use_operator_patch=True, + autowrap_leaf_class={torch.finfo: ((), False)}) +print("tracing model done.") +# print(traced_graph) + +print("parsing fx graph to cube graph...") +from cube.graph.parser import FxModuleParser +inputs, nodes, outputs = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) +print("parsing done.") +from cube.graph import IRGraph +module_name = model.__class__.__name__ +cube_graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) + +# AutoDist +# # profile communication cost +# import os +# comm_gpu_num = (2, 4) +# for gpu_num in comm_gpu_num: +# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --connect_type=NV') +# profile computation cost +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ +config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': 'bloom'}) +config.autodist_config = dotdict({'ngpus': 2}) +# NOTE add SINGLE_DEV_MODE=1 before the running command +from autodist.cost_model.cost_database import CostDatabase +cost_database = CostDatabase(cube_graph, config) +# find the best partition plan +from autodist.task_config import TaskConfig +class BloomTaskConfig(TaskConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model = 'Bloom' + # self.Bloom_setting = kwargs['Bloom_setting'] + # self.fine_grained_Bloom = kwargs['fine_grained_Bloom'] + # self.bloom_config = build_bloom_config(self.Bloom_setting) + self.task_name = f'bloom-{self.autodist_config.ngpus}gpu-'\ + f'{self.autodist_config.micro_batch_size}batch_size' + self.estimated_fname, self.backup_fname, self.runtime_fname = self._build_file_name( + self.task_name) + self.allow_recom_ops = [] + self.del_dim = [] +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Bloom benchmark') + parser.add_argument('--fp16', + action='store_true', + help='use fp16 for the training') + parser.add_argument('--fine_grained_GPT', + action='store_true', + help='model = GPTFineGrained') + parser.add_argument('--GPT_setting', + type=str, + default='6.7B', + help='set GPT model type') + parser.add_argument('--save_folder', + type=str, + default='exp_data', + help='set the save folder for experiment data') + parser.add_argument('--micro_batch_size', + type=int, + default=8, + help='set micro batch size') + parser.add_argument('--global_batch_size', + type=int, + default=8, + help='set the global batch size') + parser.add_argument('--iter_num', + type=int, + default=2, + help='set the number of all iterations') + parser.add_argument('--warm_num', + type=int, + default=1, + help='set the number of warmup iterations') + parser.add_argument('--recompute', + action='store_true', + help='set recompute flag') + parser.add_argument('--memory_constraint', + type=float, + default=32, + help='memory constraint for program') + parser.add_argument('--memory_granularity', + type=int, + default=1, + help='memory granularity in byte') + parser.add_argument('--profile_dir', + type=str, + default=str(Path.home()) + '/.autodist', + help='profile dir') + parser.add_argument('--connect_type', + type=str, + default='NV2', + help='connect type from nvidia-smi topo -m') + parser.add_argument('--use_prev_plan', + action='store_true', + help='run from previous plan') + parser.add_argument('--is_train', + action='store_true', + help='True: train, False: inference') + parser.add_argument('--topk', + type=int, + default=20, + help='generate multiple plans for robustness') + parser.add_argument('--mesh_row', type=int, default=1, help='node num') + parser.add_argument('--mesh_col', + type=int, + default=2, + help='dev num in a node') + parser.add_argument('--compile', + action='store_true', + help='compile stage: true, runtime stage: false') + parser.add_argument('--pipeline', + action='store_true', + help='pipeline: true, tensor parallel: false') + parser.add_argument('--nproc', + type=int, + default=12, + help='multiprocess deg in pipeline') + parser.add_argument('--adaptive_recom', + action='store_true', + help='allow adaptive recompute') + parser.add_argument('--plan_idx', + type=int, + default=0, + help='runtime plan idx') + parser.add_argument('--verbose', action='store_true', help='verbose mode') + parser.add_argument('--ignore_small_tensor_threshold', + type=int, + default=0, + help='set the tensor size threshold to ignore') + parser.add_argument('--parse_plan', + action='store_true', + help='parse plan to user-friendly format') + parser.add_argument('--alphafold', + action='store_true', + help='use alphafold2') + parser.add_argument('--alphafold_setting', + type=int, + default=1, + help='1: bs, s, r = 1, 128, 256'\ + '2: bs, s, r = 1, 512, 256'\ + '3: bs, s, r = 1, 512, 384') + args = parser.parse_args() + + # if args.compile: + # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' + + task_config = BloomTaskConfig(**vars(args)) + from autodist.apis import calc_parallel_plan + compile_start_time = time.time() + topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) + compile_cost_time = time.time() - compile_start_time + plan_info = [] + for plan in topk_plans: + cur_info = {} + if task_config.pipeline: + cur_spmd_descs, cur_time, cur_mems, cur_devs, cur_times = plan + cur_info['plan'] = [] + for item in cur_spmd_descs: + cur_info['plan'].append(item.to_json_object()) + cur_info['estimated time'] = cur_time + cur_info['estimated memory'] = convert_mem_into_GB(cur_mems) + cur_info['estimated time list'] = cur_times + cur_info['compile time'] = compile_cost_time + plan_info.append(cur_info) + else: + cur_spmd_desc, cur_mem, cur_time = plan + cur_info['plan'] = cur_spmd_desc.to_json_object() + cur_info['estimated time'] = cur_time + cur_info['estimated memory'] = convert_mem_into_GB(cur_mem) + cur_info['compile time'] = compile_cost_time + plan_info.append(cur_info) + + with open(task_config.backup_fname, 'w') as f: + json.dump(plan_info, f) From 57c51c97e8052e5c0d6cd95b254909156e72b3bb Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 8 Mar 2023 17:26:33 -0800 Subject: [PATCH 1328/1892] resolve comments --- cube/graph/parser/mappingfx.py | 1 + cube/graph/parser/parserfx.py | 5 ----- cube/profiler/database.py | 13 +++++++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index b53102ed..870c9f13 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -96,6 +96,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('finfo'): function.FInfo, __ttemplate('inverse'): function.Inverse, __ttemplate('bitwise_or'): function.BitwiseOr, + '_operator.or_': function.BitwiseOr, __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index d1a30a7a..71112cae 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -241,14 +241,9 @@ def get_complex_data(val: Any) -> Any: input_vals = [get_complex_data(val) for val in node.args] kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} - if '.or' in fsig: - print('zql find to: ', fsig, node.name, node.target, node.meta, node.args, node.kwargs, input_vals, kwargs) - # exit(1) # map to IR operator if SignFx2Op.exist(fsig): - print('zql: ', input_vals, kwargs) ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) - print('zql ir_node: ', ir_node) else: # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator diff --git a/cube/profiler/database.py b/cube/profiler/database.py index b51a72de..65b18e94 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -171,8 +171,6 @@ def get_dep_names(sign: str): return ret if node.signature in Sign2Op.kOpCodeDef: - # FIXME: ... - assert False, 'Sing2Op.kOpCodeDef is not empty' dep_code_impl = '' for dep_name in get_dep_names(node.signature): dep_code_impl = dep_code_impl + Sign2Op.kOpCodeDef[dep_name] @@ -189,7 +187,10 @@ def get_dep_names(sign: str): fn = list(local.values())[0] else: if '_operator.' in node.signature: - fn = eval(node.signature.replace('_operator.', 'torch.')) + if '_operator.or_' == node.signature: + fn = torch.bitwise_or + else: + fn = eval(node.signature.replace('_operator.', 'torch.')) else: fn = eval(node.signature) shapes, dtypes = [], [] @@ -355,9 +356,9 @@ def _serialize(self, node: IRFwOperation) -> str: """ shapes, dtypes = [], [] for t in node.inputs(): - if isinstance(t, IRTensor):#, f"Only support node inputs with tensor shape" - shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) + assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) shapes = '-'.join(str(tuple(shape)) for shape in shapes) dtypes = '-'.join(str(dtype) for dtype in dtypes) return shapes + ' : ' + dtypes From 19ebbd98b485722c96397d6a730ab732db33aba5 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 8 Mar 2023 22:29:57 -0800 Subject: [PATCH 1329/1892] support more ops --- cube/graph/function/creators.py | 55 ++++++++++++++++++++++++++++++++- cube/graph/function/function.py | 54 +++++++++++++++++++++++++++++++- cube/graph/parser/mappingfx.py | 4 +++ cube/graph/parser/parserfx.py | 5 +++ cube/profiler/database.py | 2 ++ 5 files changed, 118 insertions(+), 2 deletions(-) diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index c9f43e0e..1af0b7e2 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -7,6 +7,59 @@ import numpy as np +class IRArange(IRFwOperation): + def __init__(self, signature: str, shape: List[int], name: str, **kwargs): + + # The shape information must be statically known integer values + assert all(isinstance(dim, int) for dim in shape) + assert 'dtype' in kwargs + assert isinstance(kwargs['dtype'], IRDType) + + super().__init__(name, signature, input_length=0, output_length=1) + + # Customize output's dtype only after 'super().__init__' and 'self.set_input', + # otherwise it gets overwritten. + self.output(0).dtype = kwargs['dtype'] + self.shape = shape + + def infer_shape(self) -> bool: + self.output(0).shape = copy(self.shape) + return True + + def new(self, outputs: List[IRTensor]): + op = IRArange(self.signature, outputs[0].shape, self.name, **self.kwargs) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRArange::new infer_shape failed" + return op + +class IREmpty(IRFwOperation): + def __init__(self, signature: str, shape: List[int], name: str, **kwargs): + + # The shape information must be statically known integer values + assert all(isinstance(dim, int) for dim in shape) + assert 'dtype' in kwargs + assert isinstance(kwargs['dtype'], IRDType) + + super().__init__(name, signature, input_length=0, output_length=1) + + # Customize output's dtype only after 'super().__init__' and 'self.set_input', + # otherwise it gets overwritten. + self.output(0).dtype = kwargs['dtype'] + + # The positional argument to specify the shape is actually called 'size'. + self.kwargs.update({"size": copy(shape)}) + + def infer_shape(self) -> bool: + shape : list = copy(self.kwargs["size"]) + self.output(0).shape = shape + return True + + def new(self, outputs: List[IRTensor]): + op = IREmpty(self.signature, outputs[0].shape, self.name, **self.kwargs) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IREmpty::new infer_shape failed" + return op + class IRZeros(IRFwOperation): def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): @@ -29,7 +82,7 @@ def infer_shape(self) -> bool: return True def new(self, outputs: List[IRTensor]): - op = IROnes(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) + op = IRZeros(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) op.set_output(0, outputs[0]) assert op.infer_shape(), "IRZeros::new infer_shape failed" return op diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index ed89e85c..b8390163 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -11,7 +11,7 @@ from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D -from cube.graph.function.creators import IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor +from cube.graph.function.creators import IRArange, IREmpty, IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor from cube.graph.function.anchor import IRGraphAnchor @@ -112,6 +112,46 @@ def Matmul(signature, input, other, *, out=None): return IRDimops(Matmul, 'matmul', signature, annos, [input, other]) +def Arange(*args, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, signature=None): + """ + torch.arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor + """ + if len(args) == 1: + start, end, step = 0, args[0], 1 + elif len(args) == 2: + start, end, step = args[0], args[1], 1 + elif len(args) == 3: + start, end, step = args + else: + raise RuntimeError(f'Invalid number {len(args)} of args in Arange.') + assert isinstance(start, int) and isinstance(end, int) and isinstance(step, int) + from cube.graph.parser.mapping import DType2IRDType + if dtype is None: + dtype = torch.get_default_dtype() + ir_dtype : IRDType = DType2IRDType.map(dtype) + import math + size = (math.ceil((end-start)/step),) + kwargs = {'start': start, 'end': end, 'step': step, 'out': out, 'dtype': ir_dtype, + 'layout': layout, 'device': device, 'requires_grad': requires_grad} + return IRArange(signature, size, 'arange', **kwargs) + + +def Empty(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, + pin_memory=False, memory_format=torch.contiguous_format, signature=None): + from cube.graph.parser.mapping import DType2IRDType + if dtype is None: + dtype = torch.get_default_dtype() + ir_dtype : IRDType = DType2IRDType.map(dtype) + # example size: ((17, 17),) + assert isinstance(size, tuple) and isinstance(size[0], tuple) + for dim, i in enumerate(size[0]): + if not isinstance(dim, int) and not dim >= 0: + raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") + kwargs = {'dtype': ir_dtype, 'layout': layout, 'device': device, 'requires_grad': requires_grad, + 'pin_memory': pin_memory, 'memory_format': memory_format} + return IREmpty(signature, size[0], 'empty', **kwargs) + + def Zeros(signature, inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): # zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -332,11 +372,23 @@ def BitwiseOr(input, other, *, out=None, signature=None): """ torch.bitwise_or(input, other, *, out=None) → Tensor """ + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input | other assert isinstance(input, IRTensor) and isinstance(other, IRTensor) annos = ['*, * -> *'] return IRDimops(BitwiseOr, 'bitwise_or', signature, annos, [input, other]) +def BitwiseNot(input, *, out=None, signature=None): + assert out is None + if not isinstance(input, IRObject): + return ~input + assert isinstance(input, IRTensor) + annos = ['* -> *'] + return IRDimops(BitwiseNot, 'bitwise_not', signature, annos, [input]) + + def CubeAdd(input, other, alpha=1, *, out=None, signature = None): signature = 'cube.runtime.function.add' if isinstance(input, IRTensor) and isinstance(other, IRTensor): diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 870c9f13..a2798e8f 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -90,6 +90,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, __tttemplate('expand'): function.Expand, + __ttemplate('arange'): function.Arange, __ttemplate('detach'): function.Detach, __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, __ttemplate('index_select'): function.IndexSelect, @@ -97,6 +98,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('inverse'): function.Inverse, __ttemplate('bitwise_or'): function.BitwiseOr, '_operator.or_': function.BitwiseOr, + __ttemplate('bitwise_not'): function.BitwiseOr, + '_operator.invert': function.BitwiseNot, __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, @@ -129,6 +132,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # # torch aten # # # creators + __ttemplate('empty'): function.Empty, # __ttemplate('zeros'): function.Zeros, # __ttemplate('ones'): function.Ones, # __ttemplate('tensor'): function.NewTensor, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 71112cae..d0cf2185 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -241,9 +241,14 @@ def get_complex_data(val: Any) -> Any: input_vals = [get_complex_data(val) for val in node.args] kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} + # if 'invert' in fsig: + # print('zql find invert: ', fsig, node.name, node.target, node.meta, node.args, node.kwargs, input_vals, kwargs) + # exit(1) # map to IR operator if SignFx2Op.exist(fsig): + print('zql: ', input_vals, kwargs) ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) + print('zql ir_node: ', ir_node) else: # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 65b18e94..89a332e3 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -189,6 +189,8 @@ def get_dep_names(sign: str): if '_operator.' in node.signature: if '_operator.or_' == node.signature: fn = torch.bitwise_or + elif '_operator.invert' == node.signature: + fn = torch.bitwise_not else: fn = eval(node.signature.replace('_operator.', 'torch.')) else: From b9286244c88280408c8202957ddd32978e199ffc Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Thu, 9 Mar 2023 02:28:48 -0800 Subject: [PATCH 1330/1892] update --- cube/graph/function/function.py | 44 +++++++++++++------------------ cube/runtime/function/function.py | 6 ----- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index b8390163..8ba37fd1 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -138,6 +138,10 @@ def Arange(*args, out=None, dtype=None, layout=torch.strided, device=None, requi def Empty(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format, signature=None): + """ + torch.empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, + requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) → Tensor + """ from cube.graph.parser.mapping import DType2IRDType if dtype is None: dtype = torch.get_default_dtype() @@ -389,8 +393,10 @@ def BitwiseNot(input, *, out=None, signature=None): return IRDimops(BitwiseNot, 'bitwise_not', signature, annos, [input]) -def CubeAdd(input, other, alpha=1, *, out=None, signature = None): - signature = 'cube.runtime.function.add' +def Add(input, other, alpha=1, *, out=None, signature = None): + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input + alpha * other if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] @@ -398,16 +404,9 @@ def CubeAdd(input, other, alpha=1, *, out=None, signature = None): else: annos = ['* -> *'] if isinstance(input, IRTensor): - return IRDimops(CubeAdd, 'add', signature, annos, [input], other=other, alpha=alpha) + return IRDimops(Add, 'add', signature, annos, [input], other=other, alpha=alpha) else: - return IRDimops(CubeAdd, 'add', signature, annos, [other], other=input, alpha=alpha) - - -def Add(input, other, alpha=1, *, out=None, signature = None): - assert out is None - if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): - return input + alpha * other - return CubeAdd(input, other, alpha, out=out, signature=signature) + return IRDimops(Add, 'add', signature, annos, [other], other=input, alpha=alpha) def CubeSub(input, other, alpha=1, *, out=None, signature = None): @@ -419,9 +418,9 @@ def CubeSub(input, other, alpha=1, *, out=None, signature = None): else: annos = ['* -> *'] if isinstance(input, IRTensor): - return IRDimops(CubeAdd, 'sub', signature, annos, [input], other=other, alpha=alpha) + return IRDimops(CubeSub, 'sub', signature, annos, [input], other=other, alpha=alpha) else: - return IRDimops(CubeAdd, 'sub', signature, annos, [other], other=input, alpha=alpha) + return IRDimops(CubeSub, 'sub', signature, annos, [other], other=input, alpha=alpha) def Sub(input, other, alpha=1, *, out=None, signature = None): @@ -431,25 +430,20 @@ def Sub(input, other, alpha=1, *, out=None, signature = None): return CubeSub(input, other, alpha, out=out, signature=signature) -def CubeMul(input, other, *, out=None, signature = None): - signature = 'cube.runtime.function.mul' +def Mul(input, other, *, out=None, signature = None): + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input * other if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(CubeMul, 'mul', signature, annos, [input, other]) + return IRDimops(Mul, 'mul', signature, annos, [input, other]) else: annos = ['* -> *'] if isinstance(input, IRTensor): - return IRDimops(CubeMul, 'mul', signature, annos, [input], other=other) + return IRDimops(Mul, 'mul', signature, annos, [input], other=other) else: - return IRDimops(CubeMul, 'mul', signature, annos, [other], other=input) - - -def Mul(input, other, *, out=None, signature = None): - assert out is None - if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): - return input * other - return CubeMul(input, other, out=out, signature=signature) + return IRDimops(Mul, 'mul', signature, annos, [other], other=input) def Div(input, other, *, rounding_mode=None, out=None, signature = None): diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 0f83d4da..6cb2414b 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -34,15 +34,9 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) -def add(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float]) -> torch.Tensor: - return torch.add(input, other, alpha=alpha) - def sub(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float]) -> torch.Tensor: return torch.sub(input, other, alpha=alpha) -def mul(input: torch.Tensor, other: Union[float, torch.Tensor]) -> torch.Tensor: - return torch.mul(input, other) - def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: return input.expand(*sizes) From 16657d7a8166cc7f29864757623662d7375baf47 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Thu, 9 Mar 2023 03:08:41 -0800 Subject: [PATCH 1331/1892] update --- cube/graph/function/creators.py | 40 +++++++++++++---- cube/graph/function/function.py | 71 ++++++++++++++++++------------- cube/graph/parser/mappingfx.py | 2 +- cube/graph/parser/parserfx.py | 5 --- cube/runtime/function/function.py | 7 ++- 5 files changed, 78 insertions(+), 47 deletions(-) diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index 1af0b7e2..356b0698 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -21,6 +21,7 @@ def __init__(self, signature: str, shape: List[int], name: str, **kwargs): # otherwise it gets overwritten. self.output(0).dtype = kwargs['dtype'] self.shape = shape + self.kwargs = kwargs def infer_shape(self) -> bool: self.output(0).shape = copy(self.shape) @@ -47,6 +48,7 @@ def __init__(self, signature: str, shape: List[int], name: str, **kwargs): self.output(0).dtype = kwargs['dtype'] # The positional argument to specify the shape is actually called 'size'. + self.kwargs = kwargs self.kwargs.update({"size": copy(shape)}) def infer_shape(self) -> bool: @@ -60,6 +62,26 @@ def new(self, outputs: List[IRTensor]): assert op.infer_shape(), "IREmpty::new infer_shape failed" return op +class IRNewTensor(IRFwOperation): + def __init__(self, signature: str, data: list, name: str, **kwargs): + super().__init__(name, signature, input_length=0, output_length=1) + assert 'dtype' in kwargs + assert isinstance(kwargs['dtype'], IRDType) + self.output(0).dtype = kwargs['dtype'] + self.data = data + self.shape = np.array(data).shape + self.kwargs = kwargs + + def infer_shape(self) -> bool: + self.output(0).shape = copy(self.shape) + return True + + def new(self, outputs: List[IRTensor]): + op = IRNewTensor(self.signature, self.data, self.name, **self.kwargs) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRNewTensor::new infer_shape failed" + return op + class IRZeros(IRFwOperation): def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): @@ -141,16 +163,16 @@ def new(self, outputs: List[IRTensor]): assert op.infer_shape(), "IRRand::new infer_shape failed" return op -class IRNewTensor(IRFwOperation): - def __init__(self, signature: str, data: list, name: str, ir_dtype: IRDType): - super().__init__(name, signature, input_length=0, output_length=1) - self.output(0).dtype = ir_dtype - self.kwargs.update({'data': data, 'shape': np.array(data).shape, 'dtype': ir_dtype}) +# class IRNewTensor(IRFwOperation): +# def __init__(self, signature: str, data: list, name: str, ir_dtype: IRDType): +# super().__init__(name, signature, input_length=0, output_length=1) +# self.output(0).dtype = ir_dtype +# self.kwargs.update({'data': data, 'shape': np.array(data).shape, 'dtype': ir_dtype}) - def infer_shape(self) -> bool: - shape : list = copy(self.kwargs['shape']) - self.output(0).shape = shape - return True +# def infer_shape(self) -> bool: +# shape : list = copy(self.kwargs['shape']) +# self.output(0).shape = shape +# return True diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 8ba37fd1..64c5b6ea 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -235,39 +235,50 @@ def Rand(signature, raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") return IRRand(signature, size, 'rand', ir_dtype) -def NewTensor(signature, - inputs: Tuple[ list, Optional[int], ErasedDevice, bool ]): - # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor - # - # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of - # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. +def NewTensor(data: Union[int, float, list], dtype=None, device=None, requires_grad=False, pin_memory=False, signature=None): + # NOTE: not sure all the keys of torch.tensor + assert requires_grad == False + from cube.graph.parser.mapping import DType2IRDType + if dtype is None: + dtype = torch.get_default_dtype() + ir_dtype : IRDType = DType2IRDType.map(dtype) + kwargs = {'dtype': ir_dtype, 'device': device, 'requires_grad': requires_grad, 'pin_memory': pin_memory} + return IRNewTensor(signature, data, 'tensor', **kwargs) - data, dtype_underlying, _erased_device, requires_grad = inputs - # TODO parameters to support, currently they are all None - assert requires_grad == False - from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap +# def NewTensor(signature, +# inputs: Tuple[ list, Optional[int], ErasedDevice, bool ]): +# # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor +# # +# # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of +# # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. - if dtype_underlying is not None: - # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, - # which is the underlying type of PyTorch C++ enum 'ScalarType'. - dtype = TorchScalarTypeEnumMap.map(dtype_underlying) - else: - dtype = torch.get_default_dtype() +# data, dtype_underlying, _erased_device, requires_grad = inputs - ir_dtype : IRDType = DType2IRDType.map(dtype) +# # TODO parameters to support, currently they are all None +# assert requires_grad == False +# from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap + +# if dtype_underlying is not None: +# # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, +# # which is the underlying type of PyTorch C++ enum 'ScalarType'. +# dtype = TorchScalarTypeEnumMap.map(dtype_underlying) +# else: +# dtype = torch.get_default_dtype() + +# ir_dtype : IRDType = DType2IRDType.map(dtype) - # if 'data' is not: - # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 - # 2) non-ragged - # ... then this call will throw. - arr = torch.tensor(data, dtype=dtype) +# # if 'data' is not: +# # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 +# # 2) non-ragged +# # ... then this call will throw. +# arr = torch.tensor(data, dtype=dtype) - # TODO temporarily fake creation with Zeros - # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', - # but since we have omitted the 'data', we must do type inferrence ourselves, - # only in this way we get correct dtype e.g. ints or bools. - return IRNewTensor(signature, data, 'tensor', ir_dtype=ir_dtype) +# # TODO temporarily fake creation with Zeros +# # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', +# # but since we have omitted the 'data', we must do type inferrence ourselves, +# # only in this way we get correct dtype e.g. ints or bools. +# return IRNewTensor(signature, data, 'tensor', ir_dtype=ir_dtype) def ToTensor(signature, inputs: Tuple[ IRTensor, ... ]): @@ -414,13 +425,13 @@ def CubeSub(input, other, alpha=1, *, out=None, signature = None): if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Sub, 'sub', signature, annos, [input, other], alpha=alpha) + return IRDimops(CubeSub, 'sub', signature, annos, [input, other], alpha=alpha, swap_operands=False) else: annos = ['* -> *'] if isinstance(input, IRTensor): - return IRDimops(CubeSub, 'sub', signature, annos, [input], other=other, alpha=alpha) + return IRDimops(CubeSub, 'sub', signature, annos, [input], other=other, alpha=alpha, swap_operands=False) else: - return IRDimops(CubeSub, 'sub', signature, annos, [other], other=input, alpha=alpha) + return IRDimops(CubeSub, 'sub', signature, annos, [other], other=input, alpha=alpha, swap_operands=True) def Sub(input, other, alpha=1, *, out=None, signature = None): diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index a2798e8f..c4b0c978 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -135,7 +135,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('empty'): function.Empty, # __ttemplate('zeros'): function.Zeros, # __ttemplate('ones'): function.Ones, - # __ttemplate('tensor'): function.NewTensor, + __ttemplate('tensor'): function.NewTensor, # __ttemplate('to'): function.ToTensor, # __ttemplate('rand'): function.Rand, # __ttemplate('clone'): function.Clone, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index d0cf2185..71112cae 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -241,14 +241,9 @@ def get_complex_data(val: Any) -> Any: input_vals = [get_complex_data(val) for val in node.args] kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} - # if 'invert' in fsig: - # print('zql find invert: ', fsig, node.name, node.target, node.meta, node.args, node.kwargs, input_vals, kwargs) - # exit(1) # map to IR operator if SignFx2Op.exist(fsig): - print('zql: ', input_vals, kwargs) ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) - print('zql ir_node: ', ir_node) else: # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 6cb2414b..6d78c301 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -34,8 +34,11 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) -def sub(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float]) -> torch.Tensor: - return torch.sub(input, other, alpha=alpha) +def sub(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float], swap_operands: bool) -> torch.Tensor: + if swap_operands: + return torch.sub(other, input, alpha=alpha) + else: + return torch.sub(input, other, alpha=alpha) def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: return input.expand(*sizes) From 6441f5af64622ce292500a5f3d525121b5094d51 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Thu, 9 Mar 2023 11:21:45 +0000 Subject: [PATCH 1332/1892] Merged PR 1481: support more ops for running autodist --- cube/graph/function/creators.py | 95 +++++++++++-- cube/graph/function/function.py | 215 ++++++++++++++++++++++++------ cube/graph/parser/mappingfx.py | 15 ++- cube/profiler/database.py | 24 +++- cube/runtime/function/function.py | 12 +- tests/parser/test_bloom.py | 150 ++++++++++++++++++++- 6 files changed, 444 insertions(+), 67 deletions(-) diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index c9f43e0e..356b0698 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -7,6 +7,81 @@ import numpy as np +class IRArange(IRFwOperation): + def __init__(self, signature: str, shape: List[int], name: str, **kwargs): + + # The shape information must be statically known integer values + assert all(isinstance(dim, int) for dim in shape) + assert 'dtype' in kwargs + assert isinstance(kwargs['dtype'], IRDType) + + super().__init__(name, signature, input_length=0, output_length=1) + + # Customize output's dtype only after 'super().__init__' and 'self.set_input', + # otherwise it gets overwritten. + self.output(0).dtype = kwargs['dtype'] + self.shape = shape + self.kwargs = kwargs + + def infer_shape(self) -> bool: + self.output(0).shape = copy(self.shape) + return True + + def new(self, outputs: List[IRTensor]): + op = IRArange(self.signature, outputs[0].shape, self.name, **self.kwargs) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRArange::new infer_shape failed" + return op + +class IREmpty(IRFwOperation): + def __init__(self, signature: str, shape: List[int], name: str, **kwargs): + + # The shape information must be statically known integer values + assert all(isinstance(dim, int) for dim in shape) + assert 'dtype' in kwargs + assert isinstance(kwargs['dtype'], IRDType) + + super().__init__(name, signature, input_length=0, output_length=1) + + # Customize output's dtype only after 'super().__init__' and 'self.set_input', + # otherwise it gets overwritten. + self.output(0).dtype = kwargs['dtype'] + + # The positional argument to specify the shape is actually called 'size'. + self.kwargs = kwargs + self.kwargs.update({"size": copy(shape)}) + + def infer_shape(self) -> bool: + shape : list = copy(self.kwargs["size"]) + self.output(0).shape = shape + return True + + def new(self, outputs: List[IRTensor]): + op = IREmpty(self.signature, outputs[0].shape, self.name, **self.kwargs) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IREmpty::new infer_shape failed" + return op + +class IRNewTensor(IRFwOperation): + def __init__(self, signature: str, data: list, name: str, **kwargs): + super().__init__(name, signature, input_length=0, output_length=1) + assert 'dtype' in kwargs + assert isinstance(kwargs['dtype'], IRDType) + self.output(0).dtype = kwargs['dtype'] + self.data = data + self.shape = np.array(data).shape + self.kwargs = kwargs + + def infer_shape(self) -> bool: + self.output(0).shape = copy(self.shape) + return True + + def new(self, outputs: List[IRTensor]): + op = IRNewTensor(self.signature, self.data, self.name, **self.kwargs) + op.set_output(0, outputs[0]) + assert op.infer_shape(), "IRNewTensor::new infer_shape failed" + return op + class IRZeros(IRFwOperation): def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): @@ -29,7 +104,7 @@ def infer_shape(self) -> bool: return True def new(self, outputs: List[IRTensor]): - op = IROnes(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) + op = IRZeros(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) op.set_output(0, outputs[0]) assert op.infer_shape(), "IRZeros::new infer_shape failed" return op @@ -88,16 +163,16 @@ def new(self, outputs: List[IRTensor]): assert op.infer_shape(), "IRRand::new infer_shape failed" return op -class IRNewTensor(IRFwOperation): - def __init__(self, signature: str, data: list, name: str, ir_dtype: IRDType): - super().__init__(name, signature, input_length=0, output_length=1) - self.output(0).dtype = ir_dtype - self.kwargs.update({'data': data, 'shape': np.array(data).shape, 'dtype': ir_dtype}) +# class IRNewTensor(IRFwOperation): +# def __init__(self, signature: str, data: list, name: str, ir_dtype: IRDType): +# super().__init__(name, signature, input_length=0, output_length=1) +# self.output(0).dtype = ir_dtype +# self.kwargs.update({'data': data, 'shape': np.array(data).shape, 'dtype': ir_dtype}) - def infer_shape(self) -> bool: - shape : list = copy(self.kwargs['shape']) - self.output(0).shape = shape - return True +# def infer_shape(self) -> bool: +# shape : list = copy(self.kwargs['shape']) +# self.output(0).shape = shape +# return True diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index d9c9b987..64c5b6ea 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -11,7 +11,7 @@ from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D -from cube.graph.function.creators import IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor +from cube.graph.function.creators import IRArange, IREmpty, IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor from cube.graph.function.anchor import IRGraphAnchor @@ -112,6 +112,50 @@ def Matmul(signature, input, other, *, out=None): return IRDimops(Matmul, 'matmul', signature, annos, [input, other]) +def Arange(*args, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, signature=None): + """ + torch.arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor + """ + if len(args) == 1: + start, end, step = 0, args[0], 1 + elif len(args) == 2: + start, end, step = args[0], args[1], 1 + elif len(args) == 3: + start, end, step = args + else: + raise RuntimeError(f'Invalid number {len(args)} of args in Arange.') + assert isinstance(start, int) and isinstance(end, int) and isinstance(step, int) + from cube.graph.parser.mapping import DType2IRDType + if dtype is None: + dtype = torch.get_default_dtype() + ir_dtype : IRDType = DType2IRDType.map(dtype) + import math + size = (math.ceil((end-start)/step),) + kwargs = {'start': start, 'end': end, 'step': step, 'out': out, 'dtype': ir_dtype, + 'layout': layout, 'device': device, 'requires_grad': requires_grad} + return IRArange(signature, size, 'arange', **kwargs) + + +def Empty(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, + pin_memory=False, memory_format=torch.contiguous_format, signature=None): + """ + torch.empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, + requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) → Tensor + """ + from cube.graph.parser.mapping import DType2IRDType + if dtype is None: + dtype = torch.get_default_dtype() + ir_dtype : IRDType = DType2IRDType.map(dtype) + # example size: ((17, 17),) + assert isinstance(size, tuple) and isinstance(size[0], tuple) + for dim, i in enumerate(size[0]): + if not isinstance(dim, int) and not dim >= 0: + raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") + kwargs = {'dtype': ir_dtype, 'layout': layout, 'device': device, 'requires_grad': requires_grad, + 'pin_memory': pin_memory, 'memory_format': memory_format} + return IREmpty(signature, size[0], 'empty', **kwargs) + + def Zeros(signature, inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): # zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -191,39 +235,50 @@ def Rand(signature, raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") return IRRand(signature, size, 'rand', ir_dtype) -def NewTensor(signature, - inputs: Tuple[ list, Optional[int], ErasedDevice, bool ]): - # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor - # - # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of - # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. +def NewTensor(data: Union[int, float, list], dtype=None, device=None, requires_grad=False, pin_memory=False, signature=None): + # NOTE: not sure all the keys of torch.tensor + assert requires_grad == False + from cube.graph.parser.mapping import DType2IRDType + if dtype is None: + dtype = torch.get_default_dtype() + ir_dtype : IRDType = DType2IRDType.map(dtype) + kwargs = {'dtype': ir_dtype, 'device': device, 'requires_grad': requires_grad, 'pin_memory': pin_memory} + return IRNewTensor(signature, data, 'tensor', **kwargs) - data, dtype_underlying, _erased_device, requires_grad = inputs - # TODO parameters to support, currently they are all None - assert requires_grad == False - from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap +# def NewTensor(signature, +# inputs: Tuple[ list, Optional[int], ErasedDevice, bool ]): +# # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor +# # +# # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of +# # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. - if dtype_underlying is not None: - # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, - # which is the underlying type of PyTorch C++ enum 'ScalarType'. - dtype = TorchScalarTypeEnumMap.map(dtype_underlying) - else: - dtype = torch.get_default_dtype() +# data, dtype_underlying, _erased_device, requires_grad = inputs - ir_dtype : IRDType = DType2IRDType.map(dtype) +# # TODO parameters to support, currently they are all None +# assert requires_grad == False +# from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap + +# if dtype_underlying is not None: +# # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, +# # which is the underlying type of PyTorch C++ enum 'ScalarType'. +# dtype = TorchScalarTypeEnumMap.map(dtype_underlying) +# else: +# dtype = torch.get_default_dtype() + +# ir_dtype : IRDType = DType2IRDType.map(dtype) - # if 'data' is not: - # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 - # 2) non-ragged - # ... then this call will throw. - arr = torch.tensor(data, dtype=dtype) +# # if 'data' is not: +# # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 +# # 2) non-ragged +# # ... then this call will throw. +# arr = torch.tensor(data, dtype=dtype) - # TODO temporarily fake creation with Zeros - # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', - # but since we have omitted the 'data', we must do type inferrence ourselves, - # only in this way we get correct dtype e.g. ints or bools. - return IRNewTensor(signature, data, 'tensor', ir_dtype=ir_dtype) +# # TODO temporarily fake creation with Zeros +# # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', +# # but since we have omitted the 'data', we must do type inferrence ourselves, +# # only in this way we get correct dtype e.g. ints or bools. +# return IRNewTensor(signature, data, 'tensor', ir_dtype=ir_dtype) def ToTensor(signature, inputs: Tuple[ IRTensor, ... ]): @@ -308,6 +363,7 @@ def Expand(input, *sizes, signature = None): """ torch.Tensor.expand(*sizes) """ + signature = 'cube.runtime.function.expand' edim_in = ShapeAnno.create_shape_str(input.shape) assert len(input.shape) == len(sizes) for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes)): @@ -327,37 +383,78 @@ def Clone(input, *, memory_format=None, signature = None): return IRDimops(Clone, 'clone', signature, annos, [input]) +def BitwiseOr(input, other, *, out=None, signature=None): + """ + torch.bitwise_or(input, other, *, out=None) → Tensor + """ + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input | other + assert isinstance(input, IRTensor) and isinstance(other, IRTensor) + annos = ['*, * -> *'] + return IRDimops(BitwiseOr, 'bitwise_or', signature, annos, [input, other]) + + +def BitwiseNot(input, *, out=None, signature=None): + assert out is None + if not isinstance(input, IRObject): + return ~input + assert isinstance(input, IRTensor) + annos = ['* -> *'] + return IRDimops(BitwiseNot, 'bitwise_not', signature, annos, [input]) + + def Add(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input + alpha * other - annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) + return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) + else: + annos = ['* -> *'] + if isinstance(input, IRTensor): + return IRDimops(Add, 'add', signature, annos, [input], other=other, alpha=alpha) + else: + return IRDimops(Add, 'add', signature, annos, [other], other=input, alpha=alpha) + + +def CubeSub(input, other, alpha=1, *, out=None, signature = None): + signature = 'cube.runtime.function.sub' + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(CubeSub, 'sub', signature, annos, [input, other], alpha=alpha, swap_operands=False) + else: + annos = ['* -> *'] + if isinstance(input, IRTensor): + return IRDimops(CubeSub, 'sub', signature, annos, [input], other=other, alpha=alpha, swap_operands=False) + else: + return IRDimops(CubeSub, 'sub', signature, annos, [other], other=input, alpha=alpha, swap_operands=True) def Sub(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input - alpha * other - annos = ['*, ? -> *', '?, * -> *',] - if isinstance(input, IRTensor) and isinstance(other, IRTensor): - lshape, rshape, oshape = _handle_broadcast(input, other) - annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Sub, 'sub', signature, annos, [input, other], alpha=1) + return CubeSub(input, other, alpha, out=out, signature=signature) def Mul(input, other, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input * other - annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Mul, 'mul', signature, annos, [input, other]) + return IRDimops(Mul, 'mul', signature, annos, [input, other]) + else: + annos = ['* -> *'] + if isinstance(input, IRTensor): + return IRDimops(Mul, 'mul', signature, annos, [input], other=other) + else: + return IRDimops(Mul, 'mul', signature, annos, [other], other=input) def Div(input, other, *, rounding_mode=None, out=None, signature = None): @@ -1149,6 +1246,17 @@ def Roll(input, shifts: Union[int, Tuple[int]], dims=None, signature = None): return IRDimops(Roll, 'roll', signature, [anno], [input], shifts=shifts, dims=dims) +def Inverse(input, *, out=None, signature=None): + """ + torch.inverse(input, *, out=None) → Tensor + """ + ishape = ShapeAnno.create_shape_str(input.shape) + ishape = [i + '^' for i in ishape] + oshape = copy.copy(ishape) + anno = OpAnno.create_op_str([ishape], [oshape]) + return IRDimops(Inverse, 'inverse', signature, [anno], [input]) + + def AdaptiveAvgPool1d(input, output_size, signature = None): """ torch.nn.functional.adaptive_avg_pool2d(input, output_size) @@ -1284,6 +1392,23 @@ def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: return IRPyFunc(signature, [tensor, dim], [IRObject()]) +def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): + """ + torch.Tensor.to(*args, **kwargs) → Tensor + """ + assert out is None + # FIXME: support full version of torch.Tensor.to + # create "to" in cube runtime functions because dtype if not kwarg in torch.Tensor.to + signature = 'cube.runtime.function.to' + annos = ['* -> *'] + if isinstance(dtype_or_device, torch.device): + return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device) + else: + assert isinstance(dtype_or_device, (IRDType, torch.dtype)) + dtype = dtype_or_device if isinstance(dtype_or_device, torch.dtype) else eval('torch.'+dtype_or_device.value) + return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) + + def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: """ _operator.getitem(obj, index: int) @@ -1294,18 +1419,28 @@ def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: else: return IRPyFunc(signature, [obj, index], [IRObject()]) - -def GetAttr(instance: object, field: str, signature=None) -> Union[List[int], IRPyFunc]: + +def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], IRPyFunc]: """ builtins.getattr(object, name[, default]) NOTE: only deal with the attr "shape" of IRFullTensor, because other type of object may not have instantiated object or the attr is not simple value. """ obj, name = instance, field - if name == 'shape': + if name in ('shape', 'dtype'): assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" assert hasattr(obj, name), f"attr {name} is not existed in {obj}" return getattr(obj, name) + elif name == 'device': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + # FIXME: this is hack, IRFullTensor does not have attribute "device" + return torch.device('cpu') + elif isinstance(obj, torch.finfo): + return getattr(obj, name) else: # FIXME: is it right? return IRPyFunc(signature, [instance, field], [IRObject()]) + +def FInfo(dtype: IRDType, signature = None) -> torch.finfo: + assert isinstance(dtype, IRDType) + return torch.finfo(eval('torch.' + dtype.value)) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 53acdd89..c4b0c978 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -90,9 +90,16 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, __tttemplate('expand'): function.Expand, + __ttemplate('arange'): function.Arange, __ttemplate('detach'): function.Detach, __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, __ttemplate('index_select'): function.IndexSelect, + __ttemplate('finfo'): function.FInfo, + __ttemplate('inverse'): function.Inverse, + __ttemplate('bitwise_or'): function.BitwiseOr, + '_operator.or_': function.BitwiseOr, + __ttemplate('bitwise_not'): function.BitwiseOr, + '_operator.invert': function.BitwiseNot, __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, @@ -100,6 +107,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # ============== runtime function ================= __tttemplate('size'): function.Size, + __tttemplate('to'): function.To, '_operator.getitem': function.GetItem, 'builtins.getattr': function.GetAttr, @@ -124,17 +132,18 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # # torch aten # # # creators + __ttemplate('empty'): function.Empty, # __ttemplate('zeros'): function.Zeros, # __ttemplate('ones'): function.Ones, - # __ttemplate('tensor'): function.NewTensor, + __ttemplate('tensor'): function.NewTensor, # __ttemplate('to'): function.ToTensor, # __ttemplate('rand'): function.Rand, # __ttemplate('clone'): function.Clone, # __ttemplate('add') : function.Add, - # + '_operator.add': function.Add, # __ttemplate('sub') : function.Sub, - # + '_operator.sub': function.Sub, # __ttemplate('mul') : function.Mul, '_operator.mul': function.Mul, diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 911ea9ec..89a332e3 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -11,7 +11,8 @@ import cube from cube.ir.cten import IRTensor from cube.ir.operator import IRFwOperation -from cube.graph.parser.mapping import Sign2Op, IRDType2TorchDType +from cube.graph.parser.mapping import IRDType2TorchDType +from cube.graph.parser.mappingfx import SignFx2Op as Sign2Op Shapes = NewType('Shapes', Tuple[Tuple[int]]) @@ -50,12 +51,13 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, # create data dtypes = [torch.float32] * len(shapes) if dtypes is None else dtypes def gen_torch_tensors(shape, dtype): - constructor = torch.zeros if dtype == torch.int64 else torch.rand - requires_grad = False if dtype == torch.int64 else True + constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand + requires_grad = False if dtype in (torch.int64, torch.int32, torch.bool) else True return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) tensors = tuple( gen_torch_tensors(shape, dtype) for shape, dtype in zip(shapes, dtypes) ) + require_backward = any([t.requires_grad for t in tensors]) # repalce kwargs starting with 'self.xxx' train_kwargs, eval_kwargs = {}, {} for name, value in kwargs.items(): @@ -108,12 +110,12 @@ def unpack_hook(x): torch.cuda.synchronize() torch.cuda.empty_cache() with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): - outs = run_step(func, tensors, train_kwargs, backward=True) + outs = run_step(func, tensors, train_kwargs, backward=require_backward) # warmup tic = time.time() while time.time() - tic < warmup_sec: - run_step(func, tensors, train_kwargs, backward=True) + run_step(func, tensors, train_kwargs, backward=require_backward) # profile forward only torch.cuda.synchronize() @@ -129,7 +131,7 @@ def unpack_hook(x): torch.cuda.synchronize() tic = time.perf_counter() for _ in range(prof_times): - run_step(func, tensors, train_kwargs, backward=True) + run_step(func, tensors, train_kwargs, backward=require_backward) torch.cuda.synchronize() toc = time.perf_counter() fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds @@ -184,7 +186,15 @@ def get_dep_names(sign: str): exec(code_impl, globals(), local) fn = list(local.values())[0] else: - fn = eval(node.signature) + if '_operator.' in node.signature: + if '_operator.or_' == node.signature: + fn = torch.bitwise_or + elif '_operator.invert' == node.signature: + fn = torch.bitwise_not + else: + fn = eval(node.signature.replace('_operator.', 'torch.')) + else: + fn = eval(node.signature) shapes, dtypes = [], [] for t in node.inputs(): assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index bd7c5ec3..6d78c301 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Tuple +from typing import Optional, List, Tuple, Union import torch import torch.nn.functional as TorchF @@ -22,6 +22,8 @@ def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: """ return tensor if times == 1 else tuple([tensor] * times) +def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) -> torch.Tensor: + return tensor.to(dtype_or_device) def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: """ @@ -32,6 +34,14 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) +def sub(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float], swap_operands: bool) -> torch.Tensor: + if swap_operands: + return torch.sub(other, input, alpha=alpha) + else: + return torch.sub(input, other, alpha=alpha) + +def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: + return input.expand(*sizes) def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], stride: int, padding: List[int], dilation, groups: int = 1): diff --git a/tests/parser/test_bloom.py b/tests/parser/test_bloom.py index e286606d..934fc1cd 100644 --- a/tests/parser/test_bloom.py +++ b/tests/parser/test_bloom.py @@ -1,3 +1,4 @@ +from pathlib import Path import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig @@ -28,15 +29,152 @@ print("parsing fx graph to cube graph...") from cube.graph.parser import FxModuleParser -cube_graph = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) +inputs, nodes, outputs = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) print("parsing done.") +from cube.graph import IRGraph +module_name = model.__class__.__name__ +cube_graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) -# # AutoDist +# AutoDist # # profile communication cost # import os # comm_gpu_num = (2, 4) # for gpu_num in comm_gpu_num: -# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --comm_profile_dir=./ --connect_type=NV') -# # find the best partition plan -# from autodist.apis import compile -# compile(cube_graph, ...) +# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --connect_type=NV') +# profile computation cost +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ +config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': 'bloom'}) +config.autodist_config = dotdict({'ngpus': 2}) +# NOTE add SINGLE_DEV_MODE=1 before the running command +from autodist.cost_model.cost_database import CostDatabase +cost_database = CostDatabase(cube_graph, config) +# find the best partition plan +from autodist.task_config import TaskConfig +class BloomTaskConfig(TaskConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model = 'Bloom' + # self.Bloom_setting = kwargs['Bloom_setting'] + # self.fine_grained_Bloom = kwargs['fine_grained_Bloom'] + # self.bloom_config = build_bloom_config(self.Bloom_setting) + self.task_name = f'bloom-{self.autodist_config.ngpus}gpu-'\ + f'{self.autodist_config.micro_batch_size}batch_size' + self.estimated_fname, self.backup_fname, self.runtime_fname = self._build_file_name( + self.task_name) + self.allow_recom_ops = [] + self.del_dim = [] +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Bloom benchmark') + parser.add_argument('--fp16', + action='store_true', + help='use fp16 for the training') + parser.add_argument('--fine_grained_GPT', + action='store_true', + help='model = GPTFineGrained') + parser.add_argument('--GPT_setting', + type=str, + default='6.7B', + help='set GPT model type') + parser.add_argument('--save_folder', + type=str, + default='exp_data', + help='set the save folder for experiment data') + parser.add_argument('--micro_batch_size', + type=int, + default=8, + help='set micro batch size') + parser.add_argument('--global_batch_size', + type=int, + default=8, + help='set the global batch size') + parser.add_argument('--iter_num', + type=int, + default=2, + help='set the number of all iterations') + parser.add_argument('--warm_num', + type=int, + default=1, + help='set the number of warmup iterations') + parser.add_argument('--recompute', + action='store_true', + help='set recompute flag') + parser.add_argument('--memory_constraint', + type=float, + default=32, + help='memory constraint for program') + parser.add_argument('--memory_granularity', + type=int, + default=1, + help='memory granularity in byte') + parser.add_argument('--profile_dir', + type=str, + default=str(Path.home()) + '/.autodist', + help='profile dir') + parser.add_argument('--connect_type', + type=str, + default='NV2', + help='connect type from nvidia-smi topo -m') + parser.add_argument('--use_prev_plan', + action='store_true', + help='run from previous plan') + parser.add_argument('--is_train', + action='store_true', + help='True: train, False: inference') + parser.add_argument('--topk', + type=int, + default=20, + help='generate multiple plans for robustness') + parser.add_argument('--mesh_row', type=int, default=1, help='node num') + parser.add_argument('--mesh_col', + type=int, + default=2, + help='dev num in a node') + parser.add_argument('--compile', + action='store_true', + help='compile stage: true, runtime stage: false') + parser.add_argument('--pipeline', + action='store_true', + help='pipeline: true, tensor parallel: false') + parser.add_argument('--nproc', + type=int, + default=12, + help='multiprocess deg in pipeline') + parser.add_argument('--adaptive_recom', + action='store_true', + help='allow adaptive recompute') + parser.add_argument('--plan_idx', + type=int, + default=0, + help='runtime plan idx') + parser.add_argument('--verbose', action='store_true', help='verbose mode') + parser.add_argument('--ignore_small_tensor_threshold', + type=int, + default=0, + help='set the tensor size threshold to ignore') + parser.add_argument('--parse_plan', + action='store_true', + help='parse plan to user-friendly format') + parser.add_argument('--alphafold', + action='store_true', + help='use alphafold2') + parser.add_argument('--alphafold_setting', + type=int, + default=1, + help='1: bs, s, r = 1, 128, 256'\ + '2: bs, s, r = 1, 512, 256'\ + '3: bs, s, r = 1, 512, 384') + args = parser.parse_args() + + # if args.compile: + # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' + + task_config = BloomTaskConfig(**vars(args)) + from autodist.apis import calc_parallel_plan + topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) + # from autodist.apis import compile + # compile(cube_graph, None, task_config) From 95e74f18166b390061795d4520e1dff183d38cae Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 9 Mar 2023 11:23:26 +0000 Subject: [PATCH 1333/1892] Merged PR 1484: support dimop partition on the first non-1 hidden dimension support dimop partition on the first non-1 hidden dimension --- cube/algorithm/ops/dimops.py | 43 +++++++++++++++++++++++---------- cube/graph/function/function.py | 7 +++--- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index dc412f84..3a60cf25 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Any, Dict, Union +from typing import List, Optional, Any, Dict, Union, Tuple from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule @@ -44,7 +44,30 @@ def __init__(self, node: IRDimops): if not isinstance(node, IRDimops): raise TypeError(f"Expect IRDimops") super().__init__(node) - + + def get_identifier_reduce(self, idx: int, dim: int, num: int) -> Tuple[str, DimAnno.ReduceType]: + """ + Get the partitioned identifier and reduction type. + If the partitioned number is 1, return the first hidden identitifer + Otherwise, return the first hidden identifier whose length > 1 + + @param idx int: input index + @param dim int: input dimension + + @return identifier Optional[str]: annotated dimension identifier + @return reduction Optional[DimAnno.ReduceType] + """ + node: IRDimops = self.node + hidx = None + for hidx, adim in enumerate(node.anno.input(idx).dims[dim].identifiers): + if num == 1: break + dimlen = node.anno.getlen(adim) + if adim == '1^' or dimlen == 1: continue + break + if hidx is None: return (None, None) + reduce = node.anno.input(idx).dims[dim].reduces[hidx] + return adim, reduce + def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: """ Check whether the condition satisfies. @@ -70,16 +93,14 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: # try split at tensor spatial dimension if isinstance(dim, int): - for adim in node.anno.input(idx).dims[dim].identifiers: - if adim == '1^': continue - break + adim, reduce = self.get_identifier_reduce(idx, dim, num) + if adim is None: return False dimlen = node.anno.getlen(adim) # first check node special rules first for rule in node.transform_rules: if rule.input(idx) == DimopSplit.D(dim): return dimlen >= num # then check default rules - reduce = node.anno.input(idx).dims[dim].reduces[0] if reduce == DimAnno.ReduceType.Freeze: return False return dimlen >= num @@ -96,10 +117,7 @@ def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List satisfy = self.satisfy(idx, dim, num) if isinstance(dim, int): - for adim in node.anno.input(idx).dims[dim].identifiers: - if adim == '1^': continue - break - reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] + adim, reduce = self.get_identifier_reduce(idx, dim, num) else: adim, reduce = 'Value', None @@ -162,15 +180,14 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR return r # otherwise use default rule assert isinstance(dim, int), f"Error: expect dim to be int for default rules" - adim: str = node.anno.input(idx).dims[dim].identifiers[0] - reduce: DimAnno.ReduceType = node.anno.input(idx).dims[dim].reduces[0] + adim, reduce = self.get_identifier_reduce(idx, dim, num) if reduce == DimAnno.ReduceType.Freeze: return None itransform, otransform = [], [] # input for idx, idim in enumerate(node.anno.inputs()): dims = idim.getdims(adim) - assert len(dims) <= 1, "Cannot split on multple same tensors" + assert len(dims) <= 1, "Cannot split on multiple same tensors" if len(dims) == 1: itransform.append(DimopSplit.D(dims[0])) else: diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 64c5b6ea..ec918100 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -808,7 +808,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s for bracket in in_anno: sdim = None for hdim in range(len(bracket)): - if bracket[hdim] == '1': continue + if bracket[hdim] == '1' or shape_map[bracket[hdim]] == 1: continue sdim = bracket[hdim] break if sdim is not None: @@ -819,7 +819,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s for bracket in ou_anno: sdim = None for hdim in range(len(bracket)): - if bracket[hdim] == '1': continue + if bracket[hdim] == '1' or shape_map[bracket[hdim]] == 1: continue sdim = bracket[hdim] break if sdim is not None: @@ -840,8 +840,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s # the last one. def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: kwargs = dict(**kwargs) - ofirst = [bracket[0] for bracket in ou_anno] - identifier = in_anno[dim][0] + identifier = ifirst[dim] oidx = ofirst.index(identifier) size = list(kwargs['size']) size[oidx] = size[oidx] // num From 8c449bdf0937e0e6c4005091c74f832b159c07df Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Fri, 10 Mar 2023 02:53:47 +0000 Subject: [PATCH 1334/1892] Merged PR 1483: fix issues to enable TorchScale language model graph capturing fix issues to enable TorchScale language model graph capturing > confirmed op.view > fixed constant folding of ShapeAsTensor > skipped illegal subtensor dtype reset > mappingfx updated execute by run_torchscale_fx.py --- cube/flags.py | 1 + cube/graph/function/function.py | 12 ++++++++++-- cube/graph/parser/converter.py | 13 +++++++------ cube/graph/parser/mappingfx.py | 1 + cube/graph/parser/parserfx.py | 18 ++++++++++++------ cube/ir/tensor.py | 9 +++++++-- examples/nlp/torchscale/policy/spmd.py | 18 +++++++++++++++--- examples/nlp/torchscale/run_torchscale_lm.py | 8 ++++++-- 8 files changed, 59 insertions(+), 21 deletions(-) diff --git a/cube/flags.py b/cube/flags.py index d5f94e83..ed2a4354 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -23,6 +23,7 @@ class CompileFlag: # ================ compiling ======================== use_torchfx = _to_bool('USE_TORCHFX') # using torch.fx or torchscript as frontend to capture dataflow graph + use_default_fx_tracer = _to_bool('USE_DEFAULT_FX_TRACER') # using default fx tracer or more powerful concrete_tracer # worker sleep in seconds worker_sleep = _to_int('WORKER_SLEEP') disable_intra_rvd = _to_bool('DISABLE_INTRA_RVD') diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index ec918100..23b337c3 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1224,6 +1224,7 @@ def Embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, def Flatten(input, start_dim=0, end_dim=-1, signature = None): + start_dim = len(input.shape) + start_dim if start_dim < 0 else start_dim end_dim = len(input.shape) + end_dim if end_dim < 0 else end_dim ishape = ShapeAnno.create_shape_str(input.shape) for dim in range(start_dim, end_dim+1): @@ -1371,6 +1372,9 @@ def ShapeAsTensor(input: IRTensor, signature = None): """ torch._shape_as_tensor """ + if isinstance(input.shape, list) and all(isinstance(dim, int) for dim in input.shape): + return input.shape + edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = [str(len(input.shape))] anno = OpAnno.create_op_str([edim_in], [edim_ou]) @@ -1402,10 +1406,14 @@ def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): annos = ['* -> *'] if isinstance(dtype_or_device, torch.device): return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device) - else: - assert isinstance(dtype_or_device, (IRDType, torch.dtype)) + elif isinstance(dtype_or_device, (IRDType, torch.dtype)): dtype = dtype_or_device if isinstance(dtype_or_device, torch.dtype) else eval('torch.'+dtype_or_device.value) return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) + elif isinstance(dtype_or_device, IRFullTensor): + return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device.dtype) + else: + raise RuntimeError(f'function.To with unknown arg: {dtype_or_device}') + def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 64d9fc35..feb70b35 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -19,8 +19,8 @@ def convert_model(model: torch.nn.Module, """ try: if CompileFlag.use_torchfx: - if not dummy_input: - print('using torch.fx tracer') + if CompileFlag.use_default_fx_tracer: + print('INFO: using torch.fx tracer') from torch.fx import symbolic_trace # Symbolic tracing frontend - captures the semantics of the module tracer = FxFuncOpTracer() @@ -28,12 +28,13 @@ def convert_model(model: torch.nn.Module, smodule: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) smodule.graph.print_tabular() else: - print('using concrete tracer') + print('INFO: using concrete tracer') with torch.no_grad(): - if isinstance(dummy_input, tuple) or isinstance(dummy_input, list): - output_origin = model(*dummy_input) - elif isinstance(dummy_input, torch.Tensor): + if isinstance(dummy_input, torch.Tensor): output_origin = model(dummy_input) + dummy_input = (dummy_input, ) + elif isinstance(dummy_input, tuple) or isinstance(dummy_input, list): + output_origin = model(*dummy_input) elif isinstance(dummy_input, dict): print(f'WARNING dict dummy_input') output_origin = model(**dummy_input) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index c4b0c978..31b271f4 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -153,6 +153,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] '_operator.floordiv': function.FloorDiv, __ttemplate('neg'): function.Neg, + '_operator.neg': function.Neg, # __ttemplate('gt'): function.CompareGT, __ttemplate('lt'): function.CompareLT, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 71112cae..d2af6113 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -130,8 +130,10 @@ def parse(module: torch.fx.GraphModule, print(f'{node.name} is immutable_dict type') assert isinstance(node.meta['tensor_meta'], dict) else: - assert node.meta['type'] is type(torch.Tensor()) or node.meta['type'] is type(torch.nn.parameter.Parameter()) - print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) + if node.meta['type'] is type(torch.Tensor()) or node.meta['type'] is type(torch.nn.parameter.Parameter()): + print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) + else: + print(f'WARNING node {node.name} is neither Tensor nor Parameter') else: print(f'{node.name} does not has tensor_meta') @@ -151,7 +153,12 @@ def parse(module: torch.fx.GraphModule, # handle nodes all_ir_nodes: List[IRFwOperation] = list() + total_node_num = len(module.graph.nodes) + node_idx = 1 for node in module.graph.nodes: + print(f'[{node_idx}/{total_node_num}]') + node_idx += 1 + ir_nodes = FxModuleParser.parse_node(node, module, frame) if ir_nodes is not None: all_ir_nodes += ir_nodes @@ -180,7 +187,7 @@ def ntype(node: torch.fx.Node): return FxNodeKind.Output if node.op == 'call_method': return FxNodeKind.PrimCallMethod - raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") + raise RuntimeError(f"Unknown node kind {node.kind()} from torchscript module") @staticmethod def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation]: @@ -193,7 +200,6 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] return [] if node_type == FxNodeKind.Output: return FxModuleParser.parse_prim_output_node(node, module, frame) - if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): return FxModuleParser.parse_prim_function_method(node, module, frame) if node_type == FxNodeKind.PrimGetAttr: @@ -248,7 +254,7 @@ def get_complex_data(val: Any) -> Any: # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): - print(f'>>> Find unkown pytorch operation: {fsig}') + print(f'>>> Find unknown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fname ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) ir_node.kwargs = kwargs @@ -380,7 +386,7 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> assert len(node.kwargs) == 0, f'invalid kwargs {node.kwargs} in {node.name}, {node.target}, {node.meta}' # example node.args[0].meta is {'type': } in_type = node.args[0].meta['type'] - assert node_target in in_type().__dir__() + assert node_target in in_type().__dir__(), f'node_target = {node_target}, in_type().__dir__() = {in_type().__dir__()}' sig = f'{in_type.__name__}.{node_target}' print(f'The method is not torch or Tensor, but {sig}') return sig diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 2798631b..96bdef4e 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -503,8 +503,13 @@ def dtype(self, val: IRDType): if self.parent.dtype == IRDType.unknown: self.parent.dtype = val else: - assert self.parent.dtype == val, \ - f"dtype mis-matched with previous setting: {val} != {self.parent.dtype}" + if self.parent.dtype != val: + print(f'ERROR (skipped) reset IRSubTensor({self.name}) dtype {self.parent.dtype}->{val}') + self.parent.dtype = val + + # TODO recover me + # assert self.parent.dtype == val, \ + # f"dtype mis-matched with previous setting: {val} != {self.parent.dtype}" def splitdims(self) -> Tuple[int]: """! diff --git a/examples/nlp/torchscale/policy/spmd.py b/examples/nlp/torchscale/policy/spmd.py index e3bae1a4..52cf2b99 100644 --- a/examples/nlp/torchscale/policy/spmd.py +++ b/examples/nlp/torchscale/policy/spmd.py @@ -74,9 +74,21 @@ def PASData(graph: IRGraph, resource): batch_dim = node.get_batch_dims()[0] for node in graph.nodes(): if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=0, dim=batch_dim, num=resource.ngpus) + # if not isinstance(node, IRPyFunc): # and node.signature in ('torch.arange', 'torch.sin'): + # algo = node.algorithms('dim') + # sub_nodes = graph.partition( + # node, algo, idx=0, dim=batch_dim, num=resource.ngpus) + # else: + # print(f'WARNING: {node} cannot find dim algo, using replicate instead') + # sub_nodes = graph.replicate(node, resource.ngpus) + try: + algo = node.algorithms('dim') + sub_nodes = graph.partition( + node, algo, idx=0, dim=batch_dim, num=resource.ngpus) + except AssertionError: + print(f'WARNING: {node} cannot find dim algo, using replicate instead') + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): graph.assign(node, idx) return graph diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index 3813681b..589d9e55 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -115,9 +115,13 @@ def train_iter(model, dataloader): data = next(dataloader) loss = model(*data) - loss.backward() + # loss.backward() + return loss -train_iter(model, dataloader) +model = model.get_gen_module() + +iter_ret = train_iter(model, dataloader) +print(f'iter_ret = {iter_ret}') # Conduct concrete trace below # sys.path.append('/home/v-junliang/torchscaletest/nni') From b889d75df66586a835a7500cfcff08a2d15f6a3e Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Fri, 10 Mar 2023 06:01:53 +0000 Subject: [PATCH 1335/1892] Merged PR 1485: update concrete trace update concrete trace fix kwargs parsing error in torchscale --- .../concrete_trace_utils/concrete_tracer.py | 33 ++-- .../kwargs_shape_prop/__init__.py | 0 .../kwargs_shape_prop/kwargs_interpreter.py | 147 ++++++++++++++++++ .../kwargs_shape_prop/kwargs_shape_prop.py | 126 +++++++++++++++ cube/graph/parser/converter.py | 2 +- cube/graph/parser/parserfx.py | 8 +- 6 files changed, 289 insertions(+), 27 deletions(-) create mode 100644 cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/__init__.py create mode 100644 cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py create mode 100644 cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index db1e3b84..1c9d6d9c 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -19,7 +19,6 @@ from torch._C import ScriptObject from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict -import torch.fx from torch.fx import GraphModule from torch.fx._compatibility import compatibility from torch.fx._symbolic_trace import _Patcher, _proxyable_classes @@ -164,10 +163,6 @@ class ConcreteTracer(TracerBase): _orig_frozenset: ([], True), _orig_dict: ([], True), } - - current_module_qualified_name : str = '' - node_to_originating_module : Dict[torch.fx.Node, str] = {} - @compatibility(is_backward_compatible=True) def __init__(self): """ @@ -289,10 +284,7 @@ def upwrapper(obj: Any): assert isinstance(kwargs_noded, dict) node = self.create_node(kind, target, args_noded, kwargs_noded, name, type_expr) - # return self.proxy(value_unwrapped, node) - proxy = self.proxy(value_unwrapped, node) - self.node_to_originating_module[proxy.node] = self.current_module_qualified_name - return proxy + return self.proxy(value_unwrapped, node) @compatibility(is_backward_compatible=True) def create_arg(self, a: Any) -> Union[Node, Any]: @@ -384,7 +376,6 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool similar to _symbolic_trace.Tracer.is_leaf_module """ # return (m.__module__.startswith('torch.nn') and not _orig_isinstance(m, (Sequential, ModuleList, ModuleDict)))\ - # or _orig_isinstance(m, self.leaf_module) return (m.__module__.startswith('torch.nn.functional') and not _orig_isinstance(m, (Sequential, ModuleList, ModuleDict)))\ or _orig_isinstance(m, self.leaf_module) @@ -640,19 +631,13 @@ def module_call_wrapper(mod, *args, **kwargs): if self.temp_disable_call: return _orig_module_call(mod, *args, **kwargs) else: - # corresponding to call_module - old_qualname = self.current_module_qualified_name - try: - self.current_module_qualified_name = self.path_of_module(mod) - module_qualified_name = self.path_of_module(mod) - if not self.is_leaf_module(mod, module_qualified_name): - _autowrap_check(self, mod.forward.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - _autowrap_check(self, mod.__dict__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - return _orig_module_call(mod, *args, **kwargs) - else: - return self.create_proxy('call_module', module_qualified_name, args, kwargs) - finally: - self.current_module_qualified_name = old_qualname + module_qualified_name = self.path_of_module(mod) + if not self.is_leaf_module(mod, module_qualified_name): + _autowrap_check(self, mod.forward.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + _autowrap_check(self, mod.__dict__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + return _orig_module_call(mod, *args, **kwargs) + else: + return self.create_proxy('call_module', module_qualified_name, args, kwargs) class map_wrapper_clz: @functools.wraps(_orig_map) @@ -1413,4 +1398,4 @@ def f(x, y): # # assert root(**concrete_args) == traced(**concrete_args) if check_args is not None: assert root(**check_args) == traced(**check_args) - return traced, tracer \ No newline at end of file + return traced diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/__init__.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py new file mode 100644 index 00000000..223c3e2a --- /dev/null +++ b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py @@ -0,0 +1,147 @@ +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +from torch.fx import Interpreter, Node +from typing import Optional, Union, Tuple, Dict, List, Any, Iterator, Callable, MutableMapping, Mapping + +Target = Union[Callable[..., Any], str] + +BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, + torch.Tensor, torch.device, torch.memory_format, torch.layout] + +Argument = Optional[Union[ + Tuple[Any, ...], + List[Any], + Dict[str, Any], + slice, + Node, + BaseArgumentTypes +]] + + +class KwargsInterpreter(Interpreter): + def run(self, + concrete_args: Union[Dict[str, Any], Tuple, MutableMapping[str, Any], Mapping[str, Any]] = None, + initial_env: Optional[Dict[Node, Any]] = None, + enable_io_preocessing: bool = True) -> Any: + + self.env = initial_env if initial_env else {} + + if isinstance(concrete_args, tuple): + # if concrete_args is a tuple, then they are positional args + # then they are consumed left-to-right by `placeholder` nodes. + # Use an iterator to keep track of position and extract those values + if enable_io_preocessing: + args = self.module.graph.process_inputs(*concrete_args) + self.args_iter: Iterator[Any] = iter(args) + self.concrete_kwargs = None + else: + try: + # concrete_args is a kwargs dict/mapping + self.args_iter = None + self.concrete_kwargs = concrete_args + self.used_concrete_kwargs = [] + # get default values of parameters in `forward()` method + import inspect + fw = inspect.unwrap(self.module.forward) + args_default_values = fw.__defaults__ + if args_default_values is not None: + fw_code = fw.__code__ + n_args = fw_code.co_argcount + fw_code.co_kwonlyargcount + names_iter = iter(fw_code.co_varnames) + start_idx = 0 + if fw_code.co_varnames[0] == 'self': + _ = next(names_iter) # skip self + start_idx = 1 + args_names = [next(names_iter) for idx in range(start_idx, n_args)] + diff_len = len(args_names) - len(args_default_values) + self.default_args = {args_names[idx + diff_len]: args_default_values[idx] for idx in + range(len(args_default_values))} + else: + self.default_args = {} + except: + raise RuntimeError(f'invalid concrete_args type: {type(concrete_args)}') + + assert ( + self.args_iter is None or self.concrete_kwargs is None), 'can not use positional args and keyword args at the same time' + + for node in self.module.graph.nodes: + if node in self.env: + continue + try: + self.env[node] = self.run_node(node) + except Exception as e: + print(node.name, node.op, node.target) + msg = f'While executing {node.format_node()}' + msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg) + msg += f"\nOriginal traceback:\n{node.stack_trace}" + e.args = (msg,) + e.args[1:] + if isinstance(e, KeyError): + raise RuntimeError(*e.args) + raise + + if self.garbage_collect_values: + for to_delete in self.user_to_last_uses.get(node, []): + del self.env[to_delete] + + if node.op == 'output': + output_val = self.env[node] + return self.module.graph.process_outputs(output_val) if enable_io_preocessing else output_val + + def run_node(self, n: Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + with fx_traceback.append_stack_trace(n.stack_trace): + args, kwargs = self.fetch_args_kwargs_from_env(n) + return getattr(self, n.op)(n.target, args, kwargs) + + def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a `placeholder` node. + + Args: + target(Target): The call target for this node, + exactly the argument name of the forward function + args(Tuple): Tuple of positional args for this invocation + kwargs(Dict): Dict of keyword arguments for this invocation + + Returns: + Any: The argument value that was retrieved. + """ + assert isinstance(target, str) + if target.startswith('**'): + # For a douvle-starred parameter, e.g., `**kwargs`, + # retrieve all the remaining values from the concrete kwargs dict + remaining_keys = [key for key in self.concrete_kwargs if key not in self.used_concrete_kwargs] + return {key: self.concrete_kwargs[key] for key in remaining_keys} + elif target.startswith('*'): + assert self.concrete_kwargs is None, 'unexpected positional args in kwargs mode' + return list(self.args_iter) + else: + if self.concrete_kwargs is not None: + try: + ret_arg = self.concrete_kwargs[target] + except KeyError: + return self.default_args[target] + else: + self.used_concrete_kwargs.append(target) + return ret_arg + else: + try: + return next(self.args_iter) + except StopAsyncIteration: + if len(args) > 0: + return args[0] + else: + raise RuntimeError( + f'Expected positional argument for parameter {target}, but one was not passed in!') diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py new file mode 100644 index 00000000..9ad13c01 --- /dev/null +++ b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py @@ -0,0 +1,126 @@ +import torch +import traceback +from torch.fx.node import Node, map_aggregate +from typing import Optional, Union, NamedTuple, Tuple, Any, Dict +from .kwargs_interpreter import KwargsInterpreter + + +__all__ = ['TensorMetadata', 'KwargsShapeProp', 'DCEHandler'] + + +class TensorMetadata(NamedTuple): + # TensorMetadata is a structure containing pertinent information + # about a tensor within a PyTorch program. + + # General Tensor metadata + shape : torch.Size + dtype : torch.dtype + requires_grad : bool + stride : Tuple[int] + memory_format : Optional[torch.memory_format] + + # Quantization metadata + is_quantized : bool + qparams: Dict[str, Any] + + +def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata: + """ + Extract a TensorMetadata NamedTuple describing `result`. + """ + shape = result.shape + dtype = result.dtype + requires_grad = result.requires_grad + stride = result.stride() + + memory_formats = { + torch.contiguous_format, + torch.channels_last, + torch.channels_last_3d, + } + + memory_format = None + + for query_format in memory_formats: + if result.is_contiguous(memory_format=query_format): + memory_format = query_format + break + + is_quantized = result.is_quantized + qparams: Dict[str, Any] = {} + if is_quantized: + qscheme = result.qscheme() + qparams["qscheme"] = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + qparams["scale"] = result.q_scale() # type: ignore[assignment] + qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] + elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: + # In this branch, scale and zero_point are expected to be tensors, + # we store the values as immutable_list in TensorMetadata for + # easier serialization downstream + qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] + qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] + qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] + + return TensorMetadata( + shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + + +class KwargsShapeProp(KwargsInterpreter): + def run_node(self, n: Node): + try: + result = super().run_node(n) + except Exception: + traceback.print_exc() + raise RuntimeError( + f"ShapeProp error for: node={n.format_node()} with " + f"meta={n.meta}" + ) + + found_tensor = False + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + nonlocal found_tensor + found_tensor = True + return _extract_tensor_metadata(obj) + else: + return obj + + # if the obj is a tensor, then wrap it into a TensorMetaData + # else recursively descend and wrap + meta = map_aggregate(result, extract_tensor_meta) + if found_tensor: + n.meta['tensor_meta'] = meta + n.meta['type'] = type(result) + return result + + def propagate(self, concrete_args: Union[Dict[str, Any], Tuple]): + return super().run(concrete_args) + + +class DCEHandler: + def __init__(self, gm: torch.fx.GraphModule): + self.gm = gm + + def eliminate_dead_code(self): + # set a loop to make sure clean all dead nodes + # because some nodes may be used by some dead nodes are also dead nodes + # !pay attention that the `output` node should be ignored for users checking + while True: + removed = False + for node in self.gm.graph.nodes: + if node.op == 'output': + continue + users = list(node.users) + if not users: + # make input nodes pop this node out of their users list + # before the node is removed + input_nodes = node.all_input_nodes + for input_node in input_nodes: + input_node.users.pop(node) + node._remove_from_list() + removed = True + if not removed: + break + self.gm.recompile() diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index feb70b35..5e53a4f7 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -40,7 +40,7 @@ def convert_model(model: torch.nn.Module, output_origin = model(**dummy_input) else: raise RuntimeError(f'dummy_input should be a tuple (not a {type(dummy_input)}) = {dummy_input}') - traced_model, _ = concrete_trace( + traced_model = concrete_trace( model, dummy_input, use_operator_patch=True, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index d2af6113..681e71f8 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -88,8 +88,12 @@ def parse(module: torch.fx.GraphModule, # shape propagation sample_inputs = dummy_inputs if dummy_inputs else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] sample_input_tensors = [sample_inputs[input] for input in sample_inputs] if type(sample_inputs) is dict else sample_inputs - from torch.fx.passes.shape_prop import ShapeProp - ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) + + # from torch.fx.passes.shape_prop import ShapeProp + # ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp + ShapeProp(module).propagate(sample_inputs) + # handle graph inputs for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) From d9db80d810ab72f6359f030bc8381d95a83dffc3 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Thu, 9 Mar 2023 22:04:05 -0800 Subject: [PATCH 1336/1892] minor --- examples/nlp/torchscale/run_torchscale_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index 3813681b..96a2f9bc 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -67,7 +67,7 @@ print("building model succeed: ", type(model)) # create dummy input -with open('examples/nlp/torchscale/input_lm.bak', 'rb') as f: +with open('examples/nlp/torchscale/input_lm', 'rb') as f: # with open('examples/nlp/torchscale/lm_input_v2.pkl', 'rb') as f: dummy_input = pickle.load(f) device = next(model.parameters()).device From a0c1e5477dd6aed66caf383b78b38a5f064d7fb1 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sat, 11 Mar 2023 05:22:13 -0800 Subject: [PATCH 1337/1892] support profiling torchscale --- cube/algorithm/ops/dimops.py | 2 ++ cube/graph/function/function.py | 57 +++++++++++++++++++++++-------- cube/graph/parser/parserfx.py | 41 +++++++++++++++------- cube/profiler/database.py | 22 ++++++++---- cube/runtime/function/function.py | 8 +++++ 5 files changed, 97 insertions(+), 33 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 3a60cf25..a3b53885 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -3,6 +3,7 @@ from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule from cube.ir.tensor import IRSubTensor +from cube.ir.cten import IRTensor from cube.ir.operator import IRFwOperation from collections import deque @@ -348,6 +349,7 @@ def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: def gen_hash(node: IRFwOperation) -> str: ret = node.signature for it in node.inputs(): + if not isinstance(it, IRTensor): continue ret = ret + '-' + str(it.shape) return ret diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 23b337c3..83e1d6da 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -85,6 +85,7 @@ def BMMAdd(input, batch1, batch2, *, beta=1, alpha=1, out=None, signature = None def CubeEinSum(*operands, equation=None, signature = None): assert isinstance(equation, str) + signature = 'cube.runtime.function.einsum' lhs, rhs = equation.split('->') assert ',' not in rhs lhs_dims = set(lhs.replace(',', ' ').split(' ')) @@ -461,11 +462,15 @@ def Div(input, other, *, rounding_mode=None, out=None, signature = None): assert rounding_mode is None and out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input / other - annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Div, 'div', signature, annos, [input, other]) + return IRDimops(Div, 'div', signature, annos, [input, other]) + else: + # if not all tensors, the second must not be IRObject + assert isinstance(input, IRTensor) and not isinstance(other, IRObject) + annos = ['* -> *'] + return IRDimops(Div, 'div', signature, annos, [input], other=other) def FloorDiv(input, other, *, out=None, signature = None): @@ -1061,15 +1066,13 @@ def Cat(*tensors_and_dim, dim=0, out=None, signature=None): return CubeCat(*tensors, dim=dim, signature=signature) -def CubeStack(*tensors, dim: int, signature=None): - """ - torch.stack(tensors, dim=0, *, out=None) - """ +def CubeStack(*tensors, dim=0, signature=None): # REMARK: IRFwOperation doesn't support taking a list of IRTensors. # Therefore, the argument interface is adapted to take unpacked tensors # with dimension. assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'but got {tensors}' assert isinstance(dim, int), f"but not {dim}" + signature = 'cube.runtime.function.stack' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] oannos = [copy.copy(iannos[-1])] oannos[0].insert(dim, str(len(tensors))) @@ -1077,14 +1080,10 @@ def CubeStack(*tensors, dim: int, signature=None): return IRDimops(CubeStack, 'stack', signature, [anno], tensors, dim=dim) -def Stack(*tensors_and_dim, dim=0, out=None, signature = None): +def Stack(tensors, dim=0, out=None, signature = None): """ torch.stack(tensors, dim=0, *, out=None) """ - if len(tensors_and_dim) == 2: - tensors, dim = tensors_and_dim[0], tensors_and_dim[1] - else: - tensors, dim = tensors_and_dim[0], dim return CubeStack(*tensors, dim=dim, signature=signature) @@ -1115,6 +1114,7 @@ def Select(input, dim, index, signature = None): def CubeIndexSelect(input: torch.Tensor, index: torch.Tensor, dim: int, signature = None): + signature = 'cube.runtime.function.index_select' edim_in = ShapeAnno.create_shape_str(input.shape) edim_in[dim] += '^' idx_anno = chr(ord(edim_in[-1]) + 1) + '^' @@ -1325,6 +1325,34 @@ def _comparison(creator: Callable, f: Callable, name: str, signature: str, else: return IRPyFunc(signature, [input, other], [IRObject()]) +def _comparison_hack(creator: Callable, f: Callable, name: str, signature: str, + input, other): + """ + if both operands are scalars, returns bool. + if one operand is a tensor, returns a broadcasted tensor with dtype being bool. + + @param creator Callable: the outside creation function + @param f Callable: (Scalar, Scalar) -> bools + """ + # case 0: return constant + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return f(input, other) + # case1: torch.equal(tensor1, tensor2) + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(creator, name, signature, annos, [input, other]) + # case2: torch.equal(tensor1, obj2) / torch.equal(obj1, tensor2) + if isinstance(input, IRTensor) or isinstance(other, IRTensor): + annos = ['* -> *'] + if isinstance(input, IRTensor): + return IRDimops(creator, name, signature, annos, [input], other=other) + else: + return IRDimops(creator, name, signature, annos, [other], other=input) + # case3: torch.equal(obj1, obj2) + else: + return IRPyFunc(signature, [input, other], [IRObject()]) + def CompareGT(input, other, *, out=None, signature = None): """ @@ -1358,14 +1386,14 @@ def CompareEQ(input, other, *, out=None, signature = None): """ torch.eq(input, other, *, out=None) """ - return _comparison(CompareEQ, operator.eq, 'eq', signature, input, other) + return _comparison_hack(CompareEQ, operator.eq, 'eq', signature, input, other) def CompareNE(input, other, *, out=None, signature = None): """ torch.ne(input, other, *, out=None) """ - return _comparison(CompareNE, operator.eq, 'ne', signature, input, other) + return _comparison_hack(CompareNE, operator.eq, 'ne', signature, input, other) def ShapeAsTensor(input: IRTensor, signature = None): @@ -1410,7 +1438,8 @@ def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): dtype = dtype_or_device if isinstance(dtype_or_device, torch.dtype) else eval('torch.'+dtype_or_device.value) return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) elif isinstance(dtype_or_device, IRFullTensor): - return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device.dtype) + dtype = eval('torch.'+dtype_or_device.dtype.value) + return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) else: raise RuntimeError(f'function.To with unknown arg: {dtype_or_device}') diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 681e71f8..64806b92 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -104,23 +104,39 @@ def parse(module: torch.fx.GraphModule, else: assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' # remove dead nodes - from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler + # from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler DCEHandler(module).eliminate_dead_code() # shape propagation - from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp - KwargsShapeProp(module).propagate(dummy_inputs) + # from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp + # KwargsShapeProp(module).propagate(dummy_inputs) + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp + ShapeProp(module).propagate(dummy_inputs) + # print('zql dummy inputs: ', dummy_inputs) + # print('zql graph inputs: ', inputs) + # print(inputs[0].__dir__()) + # print(inputs[0].meta, inputs[0].name) + # print(inputs[1].meta, inputs[1].name, inputs[1].target, inputs[1].args, inputs[1].kwargs) + # # print(inputs[1]['src_lengths'].meta, inputs[1].name) + # exit(1) # handle graph inputs for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) - # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, - # extend to other input types - if hasattr(dummy_inputs, input.name): - print(f'dummy_inputs has {input.name}') - shape = getattr(dummy_inputs, input.name).size() + if isinstance(dummy_inputs, dict): + if input.name in dummy_inputs: + shape = input.meta['tensor_meta'].shape + else: + shape = None else: - # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name - print(f'dummy_inputs does not have {input.name}') - shape = None + # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, + # extend to other input types + if hasattr(dummy_inputs, input.name): + print(f'dummy_inputs has {input.name}') + shape = getattr(dummy_inputs, input.name).size() + else: + # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name + print(f'dummy_inputs does not have {input.name}') + shape = None dtype = kDefaultType val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) @@ -150,7 +166,8 @@ def parse(module: torch.fx.GraphModule, shape = node.meta['tensor_meta'].shape shape = FxModuleParser.shape_refine(shape) dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) + requires_grad = node.meta['tensor_meta'].requires_grad + val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=node.name) frame.add_var(node.name, val) else: frame.add_var(node.name, IRObject()) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 89a332e3..5c7e3a5b 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -29,6 +29,7 @@ class CompProfiler: @staticmethod def profile(func: Callable, shapes: Shapes, dtypes: DTypes, + requires_grads: Tuple[bool], warmup_sec: float = 2, prof_times: int = 50, **kwargs) -> Tuple[float, float, int, Tuple[int]]: """ @@ -50,14 +51,19 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" # create data dtypes = [torch.float32] * len(shapes) if dtypes is None else dtypes - def gen_torch_tensors(shape, dtype): + def gen_torch_tensors(shape, dtype, requires_grad): constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand - requires_grad = False if dtype in (torch.int64, torch.int32, torch.bool) else True + # requires_grad = False if dtype in (torch.int64, torch.int32, torch.bool) else True return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) tensors = tuple( - gen_torch_tensors(shape, dtype) for shape, dtype in zip(shapes, dtypes) + gen_torch_tensors(shape, dtype, requires_grad) for shape, dtype, requires_grad in zip(shapes, dtypes, requires_grads) ) require_backward = any([t.requires_grad for t in tensors]) + # require_backward = True + print('zql: ', func.__name__, [t.requires_grad for t in tensors]) + # FIXME: reconsidering requires_grad + if func.__name__ in ('type_as'): + require_backward = False # repalce kwargs starting with 'self.xxx' train_kwargs, eval_kwargs = {}, {} for name, value in kwargs.items(): @@ -69,6 +75,7 @@ def gen_torch_tensors(shape, dtype): train_kwargs[name] = train_val eval_kwargs[name] = eval_val # run one sample + # print(func, func.__name__, tensors, train_kwargs) outputs = func(*tensors, **train_kwargs) outputs = (outputs,) if torch.is_tensor(outputs) else outputs assert all(torch.is_tensor(otensor) for otensor in outputs), \ @@ -195,12 +202,13 @@ def get_dep_names(sign: str): fn = eval(node.signature.replace('_operator.', 'torch.')) else: fn = eval(node.signature) - shapes, dtypes = [], [] + shapes, dtypes, requires_grads = [], [], [] for t in node.inputs(): assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" shapes.append(t.shape) dtypes.append(IRDType2TorchDType.map(t.dtype)) - return fn, shapes, dtypes, node.kwargs + requires_grads.append(t.requires_grad) + return fn, shapes, dtypes, requires_grads, node.kwargs def profile(self, node: IRFwOperation, device: Optional[int] = None): """ @@ -216,7 +224,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @return infer_memory int: the peak memory in bytes after inference of the function @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ - fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) + fn, shapes, dtypes, requires_grads, kwargs = ProfileDataBase.get_func(node) if self.exist(node): return self.query(node) @@ -239,7 +247,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): # run profiling fw_span, bw_span, infer_memory, train_mem_info = \ - CompProfiler.profile(fn, shapes, dtypes, **kwargs) + CompProfiler.profile(fn, shapes, dtypes, requires_grads, **kwargs) # log to database key = self._serialize(node) self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 6d78c301..2da5a40b 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -122,3 +122,11 @@ def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): return torch.masked_scatter(input, mask, src) +def index_select(input: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor: + return torch.index_select(input, dim, index) + +def einsum(*operands, equation=None) -> torch.Tensor: + return torch.einsum(equation, *operands) + +def stack(*tensors, dim=0) -> torch.Tensor: + return torch.stack(tensors, dim) \ No newline at end of file From 72b532aaf535c12156d303fde8a61c89c45daf45 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sat, 11 Mar 2023 05:25:33 -0800 Subject: [PATCH 1338/1892] update --- examples/nlp/torchscale/run_torchscale_lm.py | 233 ++++++++++++++----- 1 file changed, 178 insertions(+), 55 deletions(-) diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index 589d9e55..160e2337 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -1,5 +1,9 @@ +# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:/home/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale:$PYTHONPATH python3 run_torchscale_lm.py /home/quzha/MagicCube/examples/nlp/torchscale/input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData + + # USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData +from pathlib import Path import torch import pickle from fairseq import ( @@ -20,8 +24,8 @@ # https://github.com/microsoft/torchscale/tree/main/examples/fairseq # sys.path.append('/home/v-junliang/torchscaletest/torchscale/examples/fairseq') # sys.path.append('./torchscaletest/torchscale/examples/fairseq') -sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') -sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale') +sys.path.append('/home/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') +sys.path.append('/home/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale') print(f'sys.path = {sys.path}') import models @@ -67,7 +71,7 @@ print("building model succeed: ", type(model)) # create dummy input -with open('examples/nlp/torchscale/input_lm.bak', 'rb') as f: +with open('/home/quzha/MagicCube/examples/nlp/torchscale/input_lm', 'rb') as f: # with open('examples/nlp/torchscale/lm_input_v2.pkl', 'rb') as f: dummy_input = pickle.load(f) device = next(model.parameters()).device @@ -105,32 +109,32 @@ print(f'next(dataloader) = {sample_input}') sample_input_cpu = tuple([val.to(device) if isinstance(val, torch.Tensor) else val for val in sample_input]) -model = cube.SemanticModel( - #TODO fix me model, dummy_input=sample_input_cpu, - # model, dummy_input=dummy_input_list, - model, dummy_input=dummy_input, -) +# model = cube.SemanticModel( +# #TODO fix me model, dummy_input=sample_input_cpu, +# # model, dummy_input=dummy_input_list, +# model, dummy_input=dummy_input, +# ) -@cube.compile(model, dataloader, PAS=PAS, load_content=False) -def train_iter(model, dataloader): - data = next(dataloader) - loss = model(*data) - # loss.backward() - return loss +# @cube.compile(model, dataloader, PAS=PAS, load_content=False) +# def train_iter(model, dataloader): +# data = next(dataloader) +# loss = model(*data) +# # loss.backward() +# return loss -model = model.get_gen_module() +# model = model.get_gen_module() -iter_ret = train_iter(model, dataloader) -print(f'iter_ret = {iter_ret}') +# iter_ret = train_iter(model, dataloader) +# print(f'iter_ret = {iter_ret}') # Conduct concrete trace below # sys.path.append('/home/v-junliang/torchscaletest/nni') # sys.path.append('./torchscaletest/nni') # from nni.common.concrete_trace_utils import concrete_trace # from concrete_trace_utils import concrete_trace -from examples.nlp.torchscale.concrete_trace_utils import concrete_trace -import examples.nlp.torchscale.torchscaletest.torchscale - +# from examples.nlp.torchscale.concrete_trace_utils import concrete_trace +# import examples.nlp.torchscale.torchscaletest.torchscale +from cube.graph.parser.concrete_trace_utils.concrete_tracer import concrete_trace def check_equal(a, b): if type(a) != type(b): @@ -155,9 +159,8 @@ def check_equal(a, b): else: return a == b - print("start tracing...") -traced_model, _ = concrete_trace( +traced_graph = concrete_trace( model, dummy_input, use_operator_patch=True, @@ -169,40 +172,160 @@ def check_equal(a, b): print("trace succeed") print("checking equal...") with torch.no_grad(): - output_traced = traced_model(**dummy_input) + output_traced = traced_graph(**dummy_input) assert check_equal(output_origin, output_traced), "check equal failed" print("checked") # check graph -traced_model.graph.print_tabular() - -# with open('input_tl', 'wb') as f: -# pickle.dump(dummy_input, f) - -# try to save traced model with pickle -# from concrete_trace_utils.concrete_tracer import MagicMethodPatcher -# from pickle import _Pickler, _Unpickler - -# with open("save/through_nn_Module/tl_traced_v2.model", "wb") as f: -# # pickle.dump(traced_model, f) -# with MagicMethodPatcher(): -# _Pickler(f).dump(traced_model) - -# with open("save/through_nn_Module/tl_traced.model", "rb") as f: -# with MagicMethodPatcher(): -# reload_model = _Unpickler(f).load() - - -# with torch.no_grad(): -# output_reload = reload_model(**dummy_input) -# assert check_equal(output_origin, output_reload), "reload check equal failed" -# print("reload is good!") - -# with open("save/through_nn_Module/tl_origin_v2.model", "wb") as f: -# with MagicMethodPatcher(): -# _Pickler(f).dump(model) - -# with open("save/through_nn_Module/tl_input_v2.pkl", "wb") as f: -# with MagicMethodPatcher(): -# _Pickler(f).dump(dummy_input) - +traced_graph.graph.print_tabular() + +print("parsing fx graph to cube graph...") +from cube.graph.parser import FxModuleParser +inputs, nodes, outputs = FxModuleParser.parse(traced_graph, dummy_inputs=dummy_input) +print("parsing done.") +from cube.graph import IRGraph +module_name = model.__class__.__name__ +cube_graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) +print("generating cube ir graph done.") + +# AutoDist +# # profile communication cost +# import os +# comm_gpu_num = (2, 4) +# for gpu_num in comm_gpu_num: +# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --connect_type=NV') +# profile computation cost +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ +config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': 'torchscale'}) +config.autodist_config = dotdict({'ngpus': 2}) +# NOTE add SINGLE_DEV_MODE=1 before the running command +from autodist.cost_model.cost_database import CostDatabase +cost_database = CostDatabase(cube_graph, config) +# find the best partition plan +from autodist.task_config import TaskConfig +class TorchscaleTaskConfig(TaskConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model = 'Bloom' + # self.Bloom_setting = kwargs['Bloom_setting'] + # self.fine_grained_Bloom = kwargs['fine_grained_Bloom'] + # self.bloom_config = build_bloom_config(self.Bloom_setting) + self.task_name = f'torchscale-{self.autodist_config.ngpus}gpu-'\ + f'{self.autodist_config.micro_batch_size}batch_size' + self.estimated_fname, self.backup_fname, self.runtime_fname = self._build_file_name( + self.task_name) + self.allow_recom_ops = [] + self.del_dim = [] +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Torchscale benchmark') + parser.add_argument('--fp16', + action='store_true', + help='use fp16 for the training') + parser.add_argument('--fine_grained_GPT', + action='store_true', + help='model = GPTFineGrained') + parser.add_argument('--GPT_setting', + type=str, + default='6.7B', + help='set GPT model type') + parser.add_argument('--save_folder', + type=str, + default='exp_data', + help='set the save folder for experiment data') + parser.add_argument('--micro_batch_size', + type=int, + default=8, + help='set micro batch size') + parser.add_argument('--global_batch_size', + type=int, + default=8, + help='set the global batch size') + parser.add_argument('--iter_num', + type=int, + default=2, + help='set the number of all iterations') + parser.add_argument('--warm_num', + type=int, + default=1, + help='set the number of warmup iterations') + parser.add_argument('--recompute', + action='store_true', + help='set recompute flag') + parser.add_argument('--memory_constraint', + type=float, + default=32, + help='memory constraint for program') + parser.add_argument('--memory_granularity', + type=int, + default=1, + help='memory granularity in byte') + parser.add_argument('--profile_dir', + type=str, + default=str(Path.home()) + '/.autodist', + help='profile dir') + parser.add_argument('--connect_type', + type=str, + default='NV2', + help='connect type from nvidia-smi topo -m') + parser.add_argument('--use_prev_plan', + action='store_true', + help='run from previous plan') + parser.add_argument('--is_train', + action='store_true', + help='True: train, False: inference') + parser.add_argument('--topk', + type=int, + default=20, + help='generate multiple plans for robustness') + parser.add_argument('--mesh_row', type=int, default=1, help='node num') + parser.add_argument('--mesh_col', + type=int, + default=2, + help='dev num in a node') + parser.add_argument('--compile', + action='store_true', + help='compile stage: true, runtime stage: false') + parser.add_argument('--pipeline', + action='store_true', + help='pipeline: true, tensor parallel: false') + parser.add_argument('--nproc', + type=int, + default=12, + help='multiprocess deg in pipeline') + parser.add_argument('--adaptive_recom', + action='store_true', + help='allow adaptive recompute') + parser.add_argument('--plan_idx', + type=int, + default=0, + help='runtime plan idx') + parser.add_argument('--verbose', action='store_true', help='verbose mode') + parser.add_argument('--ignore_small_tensor_threshold', + type=int, + default=0, + help='set the tensor size threshold to ignore') + parser.add_argument('--parse_plan', + action='store_true', + help='parse plan to user-friendly format') + parser.add_argument('--alphafold', + action='store_true', + help='use alphafold2') + parser.add_argument('--alphafold_setting', + type=int, + default=1, + help='1: bs, s, r = 1, 128, 256'\ + '2: bs, s, r = 1, 512, 256'\ + '3: bs, s, r = 1, 512, 384') + args = parser.parse_args() + + # if args.compile: + # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' + + task_config = TorchscaleTaskConfig(**vars(args)) + from autodist.apis import calc_parallel_plan + topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) From 772bb26e88611a70fe02c7d05eaab12faac906d6 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Sat, 11 Mar 2023 05:32:00 -0800 Subject: [PATCH 1339/1892] minor --- cube/graph/parser/parserfx.py | 7 ------- cube/profiler/database.py | 3 --- 2 files changed, 10 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 64806b92..fd73182f 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -112,13 +112,6 @@ def parse(module: torch.fx.GraphModule, # KwargsShapeProp(module).propagate(dummy_inputs) from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp ShapeProp(module).propagate(dummy_inputs) - # print('zql dummy inputs: ', dummy_inputs) - # print('zql graph inputs: ', inputs) - # print(inputs[0].__dir__()) - # print(inputs[0].meta, inputs[0].name) - # print(inputs[1].meta, inputs[1].name, inputs[1].target, inputs[1].args, inputs[1].kwargs) - # # print(inputs[1]['src_lengths'].meta, inputs[1].name) - # exit(1) # handle graph inputs for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 5c7e3a5b..3fe6bbd5 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -59,8 +59,6 @@ def gen_torch_tensors(shape, dtype, requires_grad): gen_torch_tensors(shape, dtype, requires_grad) for shape, dtype, requires_grad in zip(shapes, dtypes, requires_grads) ) require_backward = any([t.requires_grad for t in tensors]) - # require_backward = True - print('zql: ', func.__name__, [t.requires_grad for t in tensors]) # FIXME: reconsidering requires_grad if func.__name__ in ('type_as'): require_backward = False @@ -75,7 +73,6 @@ def gen_torch_tensors(shape, dtype, requires_grad): train_kwargs[name] = train_val eval_kwargs[name] = eval_val # run one sample - # print(func, func.__name__, tensors, train_kwargs) outputs = func(*tensors, **train_kwargs) outputs = (outputs,) if torch.is_tensor(outputs) else outputs assert all(torch.is_tensor(otensor) for otensor in outputs), \ From 464a75f12cc802eec1930f8988dfec8aec2aa52d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 12 Mar 2023 04:10:56 +0000 Subject: [PATCH 1340/1892] Merged PR 1486: pyfunc folding for tuple, list, getitem pyfunc folding for tuple, list, getitem --- cube/graph/function/function.py | 54 +++++++++++++++++-- cube/graph/function/pyfunc.py | 13 ++++- cube/graph/parser/mappingfx.py | 8 +-- cube/graph/parser/parserfx.py | 86 +++++++++++++++++++++---------- cube/ir/tensor.py | 2 +- cube/runtime/function/function.py | 4 +- 6 files changed, 129 insertions(+), 38 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 23b337c3..8c4c0dbf 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Tuple, Dict, Union +from typing import Any, Callable, List, Optional, Tuple, Dict, Union, Iterable import string import copy import torch @@ -98,7 +98,7 @@ def EinSum(equation: str, *operands, signature = None): return CubeEinSum(*operands, equation=equation, signature=signature) -def Matmul(signature, input, other, *, out=None): +def Matmul(input, other, *, out=None, signature=None): assert out is None annos = [ 'm k+, k+ n -> m n', @@ -408,6 +408,7 @@ def Add(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input + alpha * other + signature = 'torch.add' if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] @@ -445,6 +446,7 @@ def Mul(input, other, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input * other + signature = 'torch.mul' if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] @@ -1130,6 +1132,40 @@ def IndexSelect(input: torch.Tensor, dim: int, index: torch.Tensor, *, out=None, return CubeIndexSelect(input, index, dim, signature=signature) +def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice]], signature=None): + """ + subtensor = tensor[:,128:] + subtensor = tensor[0,128:] + subtensor = tensor[0] + """ + signature = 'cube.runtime.function.fullslice' + slicers = tuple(slicers) + (None,) * (len(tensor.shape) - len(slicers)) + edim_in = ShapeAnno.create_shape_str(tensor.shape) + edim_ou = [] + for dim, slicer in enumerate(slicers): + if slicer is None: + if dim < len(edim_in): + edim_ou.append(edim_in[dim]) + else: + # expand the dimension + edim_ou.append('1') + else: + edim_in[dim] += '^' + if isinstance(slicer, slice): + stop = tensor.shape[dim] if slicer.stop is None else slicer.stop + start = 0 if slicer.start is None else slicer.start + step = 1 if slicer.step is None else slicer.step + dimlen = len(range(start, stop, step)) + edim_ou.append(str(dimlen)) + else: + pass # no shape for int + # special case for loss = torch.Tensor([1,2,3])[0] + if len(edim_ou) == 0: + edim_ou = ['1^'] + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(FullSlice, 'fullslice', signature, [anno], [tensor], slicers=slicers) + + def Slice(tensor: torch.Tensor, dim, start, end, step, signature = None): """ aten::slice(input:Tensor, dim:int, start:Optional[int], end:Optional[int], step:int) -> Tensor @@ -1423,8 +1459,10 @@ def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: obj, index = a, b if (not isinstance(obj, IRObject)) and isinstance(index, int): return obj[index] - else: - return IRPyFunc(signature, [obj, index], [IRObject()]) + # case: subtensor = tensor[1,:2] + if isinstance(obj, IRTensor): + return FullSlice(obj, b) + return IRPyFunc(signature, [obj, index], [IRObject()]) def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], IRPyFunc]: @@ -1451,3 +1489,11 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], def FInfo(dtype: IRDType, signature = None) -> torch.finfo: assert isinstance(dtype, IRDType) return torch.finfo(eval('torch.' + dtype.value)) + + +def MakeTuple(inputs: Iterable, signature=None): + return tuple(inputs) + + +def MakeList(inputs: Iterable, signature=None): + return list(inputs) diff --git a/cube/graph/function/pyfunc.py b/cube/graph/function/pyfunc.py index 074cb859..e0b2aa85 100644 --- a/cube/graph/function/pyfunc.py +++ b/cube/graph/function/pyfunc.py @@ -25,5 +25,16 @@ def infer_shape(self) -> bool: """ return True - + def __repr__(self) -> str: + sign = self.signature.split('.')[-1] + dscp = (f"PyOp{self._id}-{self.device}(sign={sign}, " + f"inputs={self.inputs()}, " + f"outputs={self.outputs()})") + return dscp + def extra_repr(self) -> str: + sign = self.signature.split('.')[-1] + dscp = (f"PyOp{self._id}-{self.device}(sign={sign}, " + f"inputs={self.inputs()}, " + f"outputs={self.outputs()})") + return dscp diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 31b271f4..827ab515 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -110,12 +110,14 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __tttemplate('to'): function.To, '_operator.getitem': function.GetItem, 'builtins.getattr': function.GetAttr, + 'builtins.tuple': function.MakeTuple, + 'builtins.list': function.MakeList, # # torch nn functional # # __ftemplate('linear') : function.Linear, # - # __ttemplate('matmul'): function.Matmul, + __ttemplate('matmul'): function.Matmul, # # __ftemplate('gelu') : function.GeLU, # __ttemplate('gelu') : function.GeLU, @@ -190,11 +192,11 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # #pytorch1.11 # __ttemplate('linear'): function.Linear, # - # __ttemplate('cat'): function.Cat, + __ttemplate('cat'): function.Cat, __ttemplate('stack'): function.Stack, # - # __ttemplate('chunk'): function.Chunk, + __ttemplate('chunk'): function.Chunk, __ttemplate('flatten'): function.Flatten, # diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 681e71f8..1c111c81 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -166,8 +166,10 @@ def parse(module: torch.fx.GraphModule, ir_nodes = FxModuleParser.parse_node(node, module, frame) if ir_nodes is not None: all_ir_nodes += ir_nodes - - output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + + output_nodes = [node for node in module.graph.nodes if node.op == 'output'] + assert len(output_nodes) == 1, f"get mutiple {len(all_ir_nodes)} output nodes" + output_val = frame.get_var(output_nodes[0].name) frame.pop_var() frame.pop_attr() @@ -253,6 +255,7 @@ def get_complex_data(val: Any) -> Any: # map to IR operator if SignFx2Op.exist(fsig): + print(input_vals) ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) else: # FIXME: handle cases for IRObject in kwargs @@ -271,9 +274,17 @@ def get_complex_data(val: Any) -> Any: if isinstance(ir_node, IRCell): # TODO gracefully set output - output_name = node.name - output_val = frame.get_var(output_name) - ir_node.set_output(0, output_val) + if len(ir_node.outputs()) > 1: + # REMARK: some nodes will return multiple outputs, e.g., torch.chunk, + # while torch.fx always return one output. This will cause + # getitem or unpacking operation on the output, which can be folded by + # setting the list of the output tensor + print('>> parsing {ir_node}') + ir_node.infer_shape() + frame.set_var(node.name, ir_node.outputs()) + else: + output_val = frame.get_var(node.name) + ir_node.set_output(0, output_val) return [ir_node] else: frame.set_var(node.name, ir_node) @@ -309,30 +320,49 @@ def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr ir_nodes = [] # handle complex outputs - def generate_outputs(val: Any, _ops: List) -> IRObject: - """Support complex data type of List, Tuple, Dict, Tensor/Object""" - if isinstance(val, list): - inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) - output = IRObject() - _ops.append(IRPyFunc('(lambda *args: list(args))', inputs, [output])) - return output - if isinstance(val, tuple): - inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) - output = IRObject() - _ops.append(IRPyFunc('(lambda *args: args)', inputs, [output])) - return output - if isinstance(val, dict): - output = IRObject() - assert all(not isinstance(key, torch.fx.Node) for key in val.keys()), f"output dict cannot have torch.fx.Node is key" - keys = tuple(str(key) for key in val.keys()) - values = generate_outputs(tuple(generate_outputs(value, _ops) for value in val.values()), _ops) - _ops.append(IRPyFunc('(lambda vals, keys: {key:val for key,val in zip(keys,vals)})', [values], [output], keys=keys)) - return output - if isinstance(val, torch.fx.Node): - return frame.get_var(val.name) - return val + # def generate_outputs(val: Any, _ops: List) -> IRObject: + # """Support complex data type of List, Tuple, Dict, Tensor/Object""" + # if isinstance(val, list): + # inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) + # output = IRObject() + # _ops.append(IRPyFunc('(lambda *args: list(args))', inputs, [output])) + # return output + # if isinstance(val, tuple): + # inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) + # output = IRObject() + # _ops.append(IRPyFunc('(lambda *args: args)', inputs, [output])) + # return output + # if isinstance(val, dict): + # output = IRObject() + # assert all(not isinstance(key, torch.fx.Node) for key in val.keys()), f"output dict cannot have torch.fx.Node is key" + # keys = tuple(str(key) for key in val.keys()) + # values = generate_outputs(tuple(generate_outputs(value, _ops) for value in val.values()), _ops) + # _ops.append(IRPyFunc('(lambda vals, keys: {key:val for key,val in zip(keys,vals)})', [values], [output], keys=keys)) + # return output + # if isinstance(val, torch.fx.Node): + # return frame.get_var(val.name) + # return val + # output = generate_outputs(node.args[0], ir_nodes) - output = generate_outputs(node.args[0], ir_nodes) + # def generate_outputs(val: Any) -> Any: + # """Support complex data type of List, Tuple, Dict, Tensor/Object""" + # if isinstance(val, list): + # return list(generate_outputs(item) for item in val) + # if isinstance(val, tuple): + # return tuple(generate_outputs(item) for item in val) + # if isinstance(val, dict): + # return {generate_outputs(key) : generate_outputs(value) for key, value in val.items()} + # if isinstance(val, torch.fx.Node): + # return frame.get_var(val.name) + # # for other types like int, float, ... + # return val + # output = generate_outputs(node.args[0]) + + # TODO: support more complex data type + outs = (node.args[0],) if isinstance(node.args[0], torch.fx.Node) else node.args[0] + assert all(isinstance(t, torch.fx.Node) for t in outs), "Only support model return with tuple of " + output = [frame.get_var(t.name) for t in outs] + frame.set_var(node.name, output) return ir_nodes diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 96bdef4e..0740560e 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -256,7 +256,7 @@ class IRFullTensor(IRTensor): the sequentail execution order by its graph. """ - def __init__(self, shape=None, name=None, requires_grad=False, dtype=IRDType.unknown): + def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=IRDType.unknown): super().__init__(shape, name, dtype) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 6d78c301..d16155ae 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -43,6 +43,9 @@ def sub(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Unio def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: return input.expand(*sizes) +def fullslice(input: torch.Tensor, slicers: Tuple[Union[None, slice]]): + return input[slicers] + def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], stride: int, padding: List[int], dilation, groups: int = 1): """ @@ -73,7 +76,6 @@ def conv3d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso input = TorchF.pad(input, pad_padding, 'constant', 0) return TorchF.conv3d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) - def embedding(input: torch.Tensor, weight: torch.Tensor, padding_idx: Optional[int], start: int, stop: int): """ Embedding From 3ab9fc5369cf96a97246c7c066d0d274363af261 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sat, 11 Mar 2023 21:43:42 -0800 Subject: [PATCH 1341/1892] save work --- cube/graph/function/function.py | 35 +++++++++++++++++++++++++++++++++ cube/graph/parser/mappingfx.py | 2 ++ 2 files changed, 37 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index da744def..4802cc49 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -554,6 +554,21 @@ def Softmax(input, dim=None, _stacklevel=3, dtype=None, signature = None): return IRDimops(Softmax, 'softmax', signature, [anno], [input], dim=dim, _stacklevel=_stacklevel, dtype=dtype) + +def LogSoftmax(input, dim=None, _stacklevel=3, dtype=None, signature=None): + """ + torch.nn.functional.log_softmax(input, dim=None, _stacklevel=3, dtype=None) + """ + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = copy.copy(edim_in) + if dim is not None: + edim_in[dim] += '^' + edim_ou[dim] += '^' + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(LogSoftmax, 'log_softmax', signature, [anno], [input], + dim=dim, _stacklevel=_stacklevel, dtype=dtype) + + def Dropout(input, p=0.5, training=True, inplace=False, signature = None): """ torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False) @@ -1523,3 +1538,23 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], def FInfo(dtype: IRDType, signature = None) -> torch.finfo: assert isinstance(dtype, IRDType) return torch.finfo(eval('torch.' + dtype.value)) + + +def NLLLoss(input, target, weight=None, size_average=None, + ignore_index=-100, reduce=None, reduction='mean', + signature=None): + """ + torch.nn.functional.nll_loss(input, target, weight=None, size_average=None, + ignore_index=-100, reduce=None, reduction='mean') + """ + assert weight is None + annos = [ + 'C^, N -> 1', + 'N+ C, N+ -> 1', + 'N+ C *, N+ * -> 1' + ] + return IRDimops( + NLLLoss, 'nll_loss', + signature, annos, [input, target], + weight=weight, size_average=size_average, ignore_index=ignore_index, + reduce=reduce, reduction=reduction) \ No newline at end of file diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 31b271f4..f9bf4a16 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -84,6 +84,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('cumsum'): function.CumSum, __ttemplate('tanh'): function.Tanh, __ftemplate('softmax') : function.Softmax, + __ftemplate('log_softmax') : function.LogSoftmax, __ttemplate('bmm') : function.BatchLinear, __ttemplate('pow'): function.Pow, __ttemplate('baddbmm'): function.BMMAdd, @@ -102,6 +103,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] '_operator.invert': function.BitwiseNot, __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, + __ftemplate('nll_loss') : function.NLLLoss, __ftemplate('layer_norm'): function.LayerNorm, From ea68abb2574a9a1cf70d18f03631d616ea8defb8 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Mon, 13 Mar 2023 03:06:32 +0000 Subject: [PATCH 1342/1892] Merged PR 1488: Fix torchscale LM inference @ may help look at the changes on output parsing --- cube/codegen/frontend_mapping.py | 7 +- cube/codegen/module/module.py | 2 +- cube/graph/function/creators.py | 9 +- cube/graph/function/function.py | 157 +++++++++++++++++-- cube/graph/gener/gen.py | 2 +- cube/graph/graph.py | 9 +- cube/graph/parser/frame.py | 1 + cube/graph/parser/mappingfx.py | 4 + cube/graph/parser/parserfx.py | 76 +++++---- cube/program.py | 14 +- cube/runtime/syndata.py | 10 +- examples/nlp/torchscale/run_torchscale_lm.py | 106 ++----------- 12 files changed, 250 insertions(+), 147 deletions(-) diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index c145dd0f..0f08dfe4 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -82,10 +82,13 @@ def emit_zeros(node, arg_vars:list, kw_pairs:dict) -> str: kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. - assert 'device' not in kw_pairs + if 'device' in kw_pairs: + print(f'WARNING: overload device info. of {node}') kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. - assert len(arg_vars) == 0 + if len(arg_vars) != 0: + print(f'WARNING: emit_zero with len(arg_vars) {len(arg_vars)} != 0') + return _common_rule_join_all(node, arg_vars, kw_pairs) def emit_ones(node, arg_vars:list, kw_pairs:dict) -> str: diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index e084873d..7b70cfff 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -82,7 +82,7 @@ def __init__(self, execplan: ExecutionPlan) -> None: '\n\n########## Generated Model Code ###########', 'from typing import *', 'import torch', 'import torch.utils.checkpoint as ckpt', - 'import cube', 'import _operator', '', ''] + 'import cube', 'import _operator', 'from numpy import inf', '', ''] if CompileFlag.use_nnfusion: self.init_code.extend(['import nnfusion', '']) diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py index 356b0698..f71ef98e 100644 --- a/cube/graph/function/creators.py +++ b/cube/graph/function/creators.py @@ -13,14 +13,19 @@ def __init__(self, signature: str, shape: List[int], name: str, **kwargs): # The shape information must be statically known integer values assert all(isinstance(dim, int) for dim in shape) assert 'dtype' in kwargs - assert isinstance(kwargs['dtype'], IRDType) + dtype = kwargs['dtype'] + assert not isinstance(dtype, IRDType) + + from cube.graph.parser.mapping import DType2IRDType + ir_dtype: IRDType = DType2IRDType.map(dtype) super().__init__(name, signature, input_length=0, output_length=1) # Customize output's dtype only after 'super().__init__' and 'self.set_input', # otherwise it gets overwritten. - self.output(0).dtype = kwargs['dtype'] + self.output(0).dtype = ir_dtype self.shape = shape + kwargs.update({'dtype': dtype, 'device': 'cuda'}) #TODO check me and fix more, e.g., ones, zeros, empty self.kwargs = kwargs def infer_shape(self) -> bool: diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 8c4c0dbf..4b1b92be 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -128,11 +128,11 @@ def Arange(*args, out=None, dtype=None, layout=torch.strided, device=None, requi from cube.graph.parser.mapping import DType2IRDType if dtype is None: dtype = torch.get_default_dtype() - ir_dtype : IRDType = DType2IRDType.map(dtype) + import math size = (math.ceil((end-start)/step),) - kwargs = {'start': start, 'end': end, 'step': step, 'out': out, 'dtype': ir_dtype, - 'layout': layout, 'device': device, 'requires_grad': requires_grad} + kwargs = {'start': start, 'end': end, 'step': step, 'out': out, 'dtype': dtype, + 'layout': layout, 'requires_grad': requires_grad} return IRArange(signature, size, 'arange', **kwargs) @@ -863,17 +863,156 @@ def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: return IRDimops(View, 'view', signature, [anno], [input], rules, size=tuple(size)) -def Reshape(input, *shape, signature = None): +def Reshape(input, shape: Tuple[int], *arg_shape, signature = None): """ torch.reshape(Tensor self, int[] shape) -> Tensor """ - warnings.warn(""" - 'torch.reshape' is currently dispatched to 'torch.Tensor.view', - but 'reshape' has keyword parameter 'shape' while 'view' has 'size'. - ArgumentMissing error may be raised during codegen.""") + size = (shape,) if isinstance(shape, int) else tuple(shape) + size = size + arg_shape + assert all([isinstance(dim, int) for dim in size]), \ + f"Expected tensor.view has static int shape but got: {size}" + in_shape, ou_shape = list(input.shape), list(size) + + # infer -1 + def nele(shape, nele=1): + for dimlen in shape: nele *= dimlen + return nele + + cnt = nele(in_shape) + if -1 in ou_shape: + idx = ou_shape.index(-1) + ou_shape[idx] = cnt // (-nele(ou_shape)) + assert nele(in_shape) == nele(ou_shape), f"shape mismatch: {in_shape}, {ou_shape}" + + # generate annotation + rest_inshape = [dimlen for dimlen in in_shape] + rest_oushape = [dimlen for dimlen in ou_shape] + chain = [] + can_bucket = True + while len(rest_inshape) != 0 or len(rest_oushape) != 0: + if len(rest_inshape) == 0: + chain = chain + rest_oushape + rest_oushape = [] + elif len(rest_oushape) == 0: + chain = chain + rest_inshape + rest_inshape = [] + else: + dimlen = min(rest_inshape[0], rest_oushape[0]) + if max(rest_inshape[0], rest_oushape[0]) % dimlen == 0: + chain.append(dimlen) + if dimlen == rest_inshape[0]: + rest_inshape.pop(0) + else: + rest_inshape[0] = rest_inshape[0] // dimlen + if dimlen == rest_oushape[0]: + rest_oushape.pop(0) + else: + rest_oushape[0] = rest_oushape[0] // dimlen + else: + can_bucket = False + break + + letters = iter(string.ascii_lowercase) + if can_bucket: + inchain = ouchain = chain + inedims = ouedims = edims = [next(letters) for _ in chain] + else: + inchain, ouchain = in_shape, ou_shape + inedims = [str(dimlen) for dimlen in in_shape] + ouedims = [str(dimlen) for dimlen in ou_shape] + chain = inchain + ouchain + edims = inedims + ouedims + shape_map: Dict[str, int] = {edim: eshape for (edim, eshape) in zip(edims, chain)} + + # generate input and output shape annotations + def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[str]]: + anno = [] + dimidx = 0 + for idx, dimlen in enumerate(shape): + elements, bracket = 1, [] + maxele = len(chain) - dimidx - (len(shape) - 1 - idx) + while True: + if len(bracket) == maxele: + assert elements == dimlen, f"internal match error1: {bracket}" + break + if dimidx >= len(chain) or elements * chain[dimidx] > dimlen: + assert elements == dimlen, f"internal match error2: {bracket}" + break + else: + elements *= chain[dimidx] + bracket.append(edims[dimidx]) + dimidx += 1 + anno.append(bracket) + return anno + + in_anno = buckets(in_shape, inchain, inedims) + ou_anno = buckets(ou_shape, ouchain, ouedims) + + # postprocess on dimlen == 1 + shape_map['1'] = 1 + for bracket in in_anno + ou_anno: + for subdim, edim in enumerate(bracket): + if shape_map[edim] == 1: + bracket[subdim] = str(shape_map[edim]) + + # find out the axis that can be partitioned + ispatial, ifirst = set(), [] + for bracket in in_anno: + sdim = None + for hdim in range(len(bracket)): + if bracket[hdim] == '1' or shape_map[bracket[hdim]] == 1: continue + sdim = bracket[hdim] + break + if sdim is not None: + ispatial.add(sdim) + ifirst.append(sdim) + + ospatial, ofirst = set(), [] + for bracket in ou_anno: + sdim = None + for hdim in range(len(bracket)): + if bracket[hdim] == '1' or shape_map[bracket[hdim]] == 1: continue + sdim = bracket[hdim] + break + if sdim is not None: + ospatial.add(sdim) + ofirst.append(sdim) + + # intersection for spatial partitioned dimensions + spatial = ispatial.intersection(ospatial) + + # set dimension cannot be partitioned + for bracket in in_anno + ou_anno: + for hdim in range(len(bracket)): + if bracket[hdim] not in spatial: + bracket[hdim] = str(shape_map[bracket[hdim]]) + + # TODO: strange behaviour if every identitifer creates own + # modifier, seems all previous modifiers will be overrided by + # the last one. + def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: + kwargs = dict(**kwargs) + identifier = ifirst[dim] + oidx = ofirst.index(identifier) + size = list(kwargs['shape']) + size[oidx] = size[oidx] // num + kwargs['shape'] = tuple(size) + return kwargs + + # special rules: to change output size argument + rules = [] + for identifier in spatial: + iidx = ifirst.index(identifier) + oidx = ofirst.index(identifier) + rules.append( + TransformRule([DimopSplit.D(iidx)], [DimopSplit.D(oidx)], view_modifier) + ) + + anno = OpAnno.create_op_str([in_anno], [ou_anno]) - return View(input, *shape, signature='torch.Tensor.view') + new_signature = 'torch.Tensor.reshape' + return IRDimops(Reshape, 'shape', new_signature, [anno], [input], rules, shape=tuple(size)) def Permute(input, dims: Tuple[int], *arg_dims, signature = None): diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 19529200..dc0c9935 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -165,7 +165,7 @@ def auto_pyfunc(graph: IRSegment): for t in func.inputs(): if isinstance(t, IRObject): if t.is_attr(): - tensors = set(tensor for tensor in graph.ctensors(t.parent) if devid in tensor.device and tensor.cell != func) + tensors = set(tensor for tensor in graph.ctensors(t.parent) if devid in tensor.device) else: tensors = set(tensor for tensor in graph.ptensors(t.parent) if devid in tensor.device) assert len(tensors) == 1, \ diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 9cd54f85..ae8a727c 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -73,7 +73,14 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: """ # align graph with input tensors itensors: Tuple[IRObject, ...] = self.inputs() - assert len(args) == len(itensors) + if len(args) != len(itensors): + print(f'ERROR(skipping) len(args) != len(itensors): {len(args)} != {len(itensors)}') + if len(args) > len(itensors): + args = args[:len(itensors)] + print(f'WARNING: args shrinked into {args}') + else: + raise RuntimeError('len(args) < len(itensors)') + for idx, (itensor, arg) in enumerate(zip(itensors, args)): self.set_input(idx, arg) for producer in self.producers(itensor.parent): diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index 29f71419..0480a407 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -55,6 +55,7 @@ def add_var(self, var_name: str, val: Any, graph_arg: int = -1): and link the name of the argument name from the callee function to the names of the argument passed-in. """ + if not isinstance(var_name, str): raise RuntimeError("Expected var_name is str") if var_name in self._vars[-1]: diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 827ab515..9eef0678 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -76,6 +76,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('relu'): function.ReLU, __fcntemplate('gelu'): function.GeLU, __ttemplate('eq') : function.CompareEQ, + '_operator.eq': function.CompareEQ, __ttemplate('ne') : function.CompareNE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, @@ -86,6 +87,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ftemplate('softmax') : function.Softmax, __ttemplate('bmm') : function.BatchLinear, __ttemplate('pow'): function.Pow, + '_operator.pow': function.Pow, __ttemplate('baddbmm'): function.BMMAdd, __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, @@ -151,6 +153,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('div') : function.Div, __ttemplate('true_divide'): function.Div, + '_operator.truediv': function.Div, __ttemplate('floor_divide') : function.FloorDiv, '_operator.floordiv': function.FloorDiv, @@ -158,6 +161,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] '_operator.neg': function.Neg, # __ttemplate('gt'): function.CompareGT, + '_operator.gt': function.CompareGT, __ttemplate('lt'): function.CompareLT, __ttemplate('ge'): function.CompareGE, __ttemplate('le'): function.CompareLE, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 1c111c81..756cfaf3 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -67,6 +67,10 @@ def parse(module: torch.fx.GraphModule, dummy_inputs = None, frame: Frame = None) \ -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler + dce_module = DCEHandler(module).eliminate_dead_code() + model = dce_module + """ The overall entry to parse a torch.fx graph module """ @@ -78,9 +82,12 @@ def parse(module: torch.fx.GraphModule, print(f'inputs = {inputs}') if input_shapes is not None and len(input_shapes) != len(inputs): - print(f'module(type = {type(module)}.__dict__.keys() = {module.__dict__.keys()}') - print(f'input shape mismatch (got {len(input_shapes)} != {len(inputs)})') - # TODO fixme raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + print(f'WARNING input shape mismatch (got {len(input_shapes)} != {len(inputs)})') + if len(input_shapes) < len(inputs): + raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") + else: + input_shapes = input_shapes[:len(inputs)] + print(f'WARNING input_shapes shrinked to {input_shapes})') default_dtype = torch.get_default_dtype() kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype @@ -167,9 +174,12 @@ def parse(module: torch.fx.GraphModule, if ir_nodes is not None: all_ir_nodes += ir_nodes - output_nodes = [node for node in module.graph.nodes if node.op == 'output'] - assert len(output_nodes) == 1, f"get mutiple {len(all_ir_nodes)} output nodes" - output_val = frame.get_var(output_nodes[0].name) + # output_nodes = [node for node in module.graph.nodes if node.op == 'output'] + # assert len(output_nodes) == 1, f"get mutiple {len(all_ir_nodes)} output nodes" + # output_val = frame.get_var(output_nodes[0].name) + output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + + frame.pop_var() frame.pop_attr() @@ -320,29 +330,29 @@ def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr ir_nodes = [] # handle complex outputs - # def generate_outputs(val: Any, _ops: List) -> IRObject: - # """Support complex data type of List, Tuple, Dict, Tensor/Object""" - # if isinstance(val, list): - # inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) - # output = IRObject() - # _ops.append(IRPyFunc('(lambda *args: list(args))', inputs, [output])) - # return output - # if isinstance(val, tuple): - # inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) - # output = IRObject() - # _ops.append(IRPyFunc('(lambda *args: args)', inputs, [output])) - # return output - # if isinstance(val, dict): - # output = IRObject() - # assert all(not isinstance(key, torch.fx.Node) for key in val.keys()), f"output dict cannot have torch.fx.Node is key" - # keys = tuple(str(key) for key in val.keys()) - # values = generate_outputs(tuple(generate_outputs(value, _ops) for value in val.values()), _ops) - # _ops.append(IRPyFunc('(lambda vals, keys: {key:val for key,val in zip(keys,vals)})', [values], [output], keys=keys)) - # return output - # if isinstance(val, torch.fx.Node): - # return frame.get_var(val.name) - # return val - # output = generate_outputs(node.args[0], ir_nodes) + def generate_outputs(val: Any, _ops: List) -> IRObject: + """Support complex data type of List, Tuple, Dict, Tensor/Object""" + if isinstance(val, list): + inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) + output = IRObject() + _ops.append(IRPyFunc('(lambda *args: list(args))', inputs, [output])) + return output + if isinstance(val, tuple): + inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) + output = IRObject() + _ops.append(IRPyFunc('(lambda *args: args)', inputs, [output])) + return output + if isinstance(val, dict): + output = IRObject() + assert all(not isinstance(key, torch.fx.Node) for key in val.keys()), f"output dict cannot have torch.fx.Node is key" + keys = tuple(str(key) for key in val.keys()) + values = generate_outputs(tuple(generate_outputs(value, _ops) for value in val.values()), _ops) + _ops.append(IRPyFunc('(lambda vals, keys: {key:val for key,val in zip(keys,vals)})', [values], [output], keys=keys)) + return output + if isinstance(val, torch.fx.Node): + return frame.get_var(val.name) + return val + output = generate_outputs(node.args[0], ir_nodes) # def generate_outputs(val: Any) -> Any: # """Support complex data type of List, Tuple, Dict, Tensor/Object""" @@ -358,10 +368,10 @@ def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr # return val # output = generate_outputs(node.args[0]) - # TODO: support more complex data type - outs = (node.args[0],) if isinstance(node.args[0], torch.fx.Node) else node.args[0] - assert all(isinstance(t, torch.fx.Node) for t in outs), "Only support model return with tuple of " - output = [frame.get_var(t.name) for t in outs] + # # TODO: support more complex data type + # outs = (node.args[0],) if isinstance(node.args[0], torch.fx.Node) else node.args[0] + # assert all(isinstance(t, torch.fx.Node) for t in outs), "Only support model return with tuple of " + # output = [frame.get_var(t.name) for t in outs] frame.set_var(node.name, output) return ir_nodes diff --git a/cube/program.py b/cube/program.py index d165b4ca..ef32d95b 100644 --- a/cube/program.py +++ b/cube/program.py @@ -103,9 +103,9 @@ def generate_output(sample): return tuple(generate_output(t) for t in sample) if isinstance(sample, list): return list(generate_output(t) for t in sample) - # if isinstance(sample, dict): - # assert all(isinstance(key, (str, int)) for key in sample.keys()) - # return {key:generate_output(val) for key, val in sample.items()} + if isinstance(sample, dict): + assert all(isinstance(key, (str, int)) for key in sample.keys()) + return {key:generate_output(val) for key, val in sample.items()} # if isinstance(sample, set): # return {generate_output(t) for t in sample} if isinstance(sample, torch.Tensor): @@ -118,7 +118,13 @@ def generate_output(sample): outputs = generate_output(sample) # create dataloader - data_num = len(outputs) if isinstance(outputs, tuple) else 1 + if isinstance(outputs, (tuple, list)): + data_num = len(outputs) + elif isinstance(outputs, dict): + data_num = len(outputs.keys()) + else: + data_num = 1 + data_op = IRDataOperation(data_num=data_num, batch_dims=self.get_batch_dims()) if not isinstance(outputs, tuple): data_op.set_output(0, outputs) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index daea5538..5b00633a 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -138,7 +138,6 @@ def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, self.shapes = tuple([list(shape) for shape in shapes]) self.dtypes = dtypes batch_size = shapes[0][batch_dims[0]] - assert not names super().__init__(batch_size, batch_dims) self.names = names self.append_args=append_args @@ -151,9 +150,12 @@ def __iter__(self): return self def __next__(self): - if self.names: + if self.names is not None: assert len(self.names) == len(self.buffer) - return dict(zip(self.names, self.buffer)).update(self.append_args) + ret_dict = dict(zip(self.names, self.buffer)) + if self.append_args is not None: + ret_dict = ret_dict.update(self.append_args) + return ret_dict else: return self.buffer @@ -183,7 +185,7 @@ def set_output(self, datas: Union[torch.Tensor, Tuple[torch.Tensor]]): if len(datas) == 0: self.buffer = None else: - self.buffer = datas[0] if len(datas) == 1 else datas + self.buffer = datas #will not convert like: datas[0] if len(datas) == 1 else datas def set_batch_size(self, batch_size: int): self.batch_size = batch_size diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index 589d9e55..cbe018b8 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -67,8 +67,7 @@ print("building model succeed: ", type(model)) # create dummy input -with open('examples/nlp/torchscale/input_lm.bak', 'rb') as f: -# with open('examples/nlp/torchscale/lm_input_v2.pkl', 'rb') as f: +with open('examples/nlp/torchscale/input_lm', 'rb') as f: dummy_input = pickle.load(f) device = next(model.parameters()).device print(f'device = {device}') @@ -97,21 +96,29 @@ print(f'input_dtypes = {input_dtypes}') dataloader = cube.runtime.syndata.SynDataLoader( + # names=('src_tokens',), shapes=(input_shapes), dtypes=input_dtypes, - batch_dims=(0,0), + batch_dims=(0, 0), ) + sample_input = next(dataloader) print(f'next(dataloader) = {sample_input}') -sample_input_cpu = tuple([val.to(device) if isinstance(val, torch.Tensor) else val for val in sample_input]) +if isinstance(sample_input, tuple): + sample_input_cpu = tuple([val.to(device) if isinstance(val, torch.Tensor) else val for val in sample_input]) +elif isinstance(sample_input, dict): + sample_input_cpu = sample_input + for key in sample_input_cpu.keys(): + sample_input_cpu[key] = sample_input_cpu[key].to(device) +else: + raise RuntimeError(f'To fix sample_input with type{type(sample_input)}') + model = cube.SemanticModel( - #TODO fix me model, dummy_input=sample_input_cpu, - # model, dummy_input=dummy_input_list, model, dummy_input=dummy_input, ) -@cube.compile(model, dataloader, PAS=PAS, load_content=False) +@cube.compile(model, dataloader, PAS=PAS, load_content=False, override=True) def train_iter(model, dataloader): data = next(dataloader) loss = model(*data) @@ -123,86 +130,5 @@ def train_iter(model, dataloader): iter_ret = train_iter(model, dataloader) print(f'iter_ret = {iter_ret}') -# Conduct concrete trace below -# sys.path.append('/home/v-junliang/torchscaletest/nni') -# sys.path.append('./torchscaletest/nni') -# from nni.common.concrete_trace_utils import concrete_trace -# from concrete_trace_utils import concrete_trace -from examples.nlp.torchscale.concrete_trace_utils import concrete_trace -import examples.nlp.torchscale.torchscaletest.torchscale - - -def check_equal(a, b): - if type(a) != type(b): - return False - if isinstance(a, (list, tuple, set)): - if len(a) != len(b): - return False - for sub_a, sub_b in zip(a, b): - if not check_equal(sub_a, sub_b): - return False - return True - elif isinstance(a, dict): - keys_a, kes_b = set(a.keys()), set(b.keys()) - if keys_a != kes_b: - return False - for key in keys_a: - if not check_equal(a[key], b[key]): - return False - return True - elif isinstance(a, torch.Tensor): - return torch.equal(a, b) - else: - return a == b - - -print("start tracing...") -traced_model, _ = concrete_trace( - model, - dummy_input, - use_operator_patch=True, - autowrap_leaf_class={ - torch.finfo: ((), False), - type(output_origin): ((), False), - }, -) -print("trace succeed") -print("checking equal...") -with torch.no_grad(): - output_traced = traced_model(**dummy_input) -assert check_equal(output_origin, output_traced), "check equal failed" -print("checked") - -# check graph -traced_model.graph.print_tabular() - -# with open('input_tl', 'wb') as f: -# pickle.dump(dummy_input, f) - -# try to save traced model with pickle -# from concrete_trace_utils.concrete_tracer import MagicMethodPatcher -# from pickle import _Pickler, _Unpickler - -# with open("save/through_nn_Module/tl_traced_v2.model", "wb") as f: -# # pickle.dump(traced_model, f) -# with MagicMethodPatcher(): -# _Pickler(f).dump(traced_model) - -# with open("save/through_nn_Module/tl_traced.model", "rb") as f: -# with MagicMethodPatcher(): -# reload_model = _Unpickler(f).load() - - -# with torch.no_grad(): -# output_reload = reload_model(**dummy_input) -# assert check_equal(output_origin, output_reload), "reload check equal failed" -# print("reload is good!") - -# with open("save/through_nn_Module/tl_origin_v2.model", "wb") as f: -# with MagicMethodPatcher(): -# _Pickler(f).dump(model) - -# with open("save/through_nn_Module/tl_input_v2.pkl", "wb") as f: -# with MagicMethodPatcher(): -# _Pickler(f).dump(dummy_input) - +import sys +sys.exit(0) \ No newline at end of file From 6d11870a0fb2e25f94a74e924c6d853af0812137 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Mon, 13 Mar 2023 00:07:38 -0700 Subject: [PATCH 1343/1892] minor --- examples/nlp/torchscale/run_torchscale_lm.py | 264 ++++++++----------- 1 file changed, 114 insertions(+), 150 deletions(-) diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index 0cf4d2ea..ba2e6bfc 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -17,53 +17,18 @@ import sys -import os -print(f'os.getcwd() = {os.getcwd()}') - - # https://github.com/microsoft/torchscale/tree/main/examples/fairseq -# sys.path.append('/home/v-junliang/torchscaletest/torchscale/examples/fairseq') -# sys.path.append('./torchscaletest/torchscale/examples/fairseq') -sys.path.append('/home/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') -sys.path.append('/home/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale') +# sys.path.append('/home/quzha/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') +# sys.path.append('/home/quzha/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale') +sys.path.append('/home/quzha/quzha/torchscale/examples/fairseq') print(f'sys.path = {sys.path}') import models -#:torchscaletest/torchscale import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -sys.path.append('.') -from policy import mpmd, spmd -# import examples.nlp.torchscale.policy.spmd as spmd - -# import argparse -# parser = argparse.ArgumentParser(description='comm primitive') -# parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -# parser.add_argument('--local_rank', type=int, default=0) -# args = parser.parse_args() - -# build model +# # build model parser = options.get_training_parser() -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -# parser.add_argument('--local_rank', type=int, default=0) - args = options.parse_args_and_arch(parser) - -cube.init() -# set up policy -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - cfg = convert_namespace_to_omegaconf(args) task = tasks.setup_task(cfg.task) model = task.build_model(cfg.model) @@ -71,8 +36,7 @@ print("building model succeed: ", type(model)) # create dummy input -with open('/home/quzha/MagicCube/examples/nlp/torchscale/input_lm', 'rb') as f: -# with open('examples/nlp/torchscale/lm_input_v2.pkl', 'rb') as f: +with open('/home/quzha/quzha/MagicCube/examples/nlp/torchscale/input_lm', 'rb') as f: dummy_input = pickle.load(f) device = next(model.parameters()).device print(f'device = {device}') @@ -230,112 +194,112 @@ def __init__(self, **kwargs): self.task_name) self.allow_recom_ops = [] self.del_dim = [] -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Torchscale benchmark') - parser.add_argument('--fp16', - action='store_true', - help='use fp16 for the training') - parser.add_argument('--fine_grained_GPT', - action='store_true', - help='model = GPTFineGrained') - parser.add_argument('--GPT_setting', - type=str, - default='6.7B', - help='set GPT model type') - parser.add_argument('--save_folder', - type=str, - default='exp_data', - help='set the save folder for experiment data') - parser.add_argument('--micro_batch_size', - type=int, - default=8, - help='set micro batch size') - parser.add_argument('--global_batch_size', - type=int, - default=8, - help='set the global batch size') - parser.add_argument('--iter_num', - type=int, - default=2, - help='set the number of all iterations') - parser.add_argument('--warm_num', - type=int, - default=1, - help='set the number of warmup iterations') - parser.add_argument('--recompute', - action='store_true', - help='set recompute flag') - parser.add_argument('--memory_constraint', - type=float, - default=32, - help='memory constraint for program') - parser.add_argument('--memory_granularity', - type=int, - default=1, - help='memory granularity in byte') - parser.add_argument('--profile_dir', - type=str, - default=str(Path.home()) + '/.autodist', - help='profile dir') - parser.add_argument('--connect_type', - type=str, - default='NV2', - help='connect type from nvidia-smi topo -m') - parser.add_argument('--use_prev_plan', - action='store_true', - help='run from previous plan') - parser.add_argument('--is_train', - action='store_true', - help='True: train, False: inference') - parser.add_argument('--topk', - type=int, - default=20, - help='generate multiple plans for robustness') - parser.add_argument('--mesh_row', type=int, default=1, help='node num') - parser.add_argument('--mesh_col', - type=int, - default=2, - help='dev num in a node') - parser.add_argument('--compile', - action='store_true', - help='compile stage: true, runtime stage: false') - parser.add_argument('--pipeline', - action='store_true', - help='pipeline: true, tensor parallel: false') - parser.add_argument('--nproc', - type=int, - default=12, - help='multiprocess deg in pipeline') - parser.add_argument('--adaptive_recom', - action='store_true', - help='allow adaptive recompute') - parser.add_argument('--plan_idx', - type=int, - default=0, - help='runtime plan idx') - parser.add_argument('--verbose', action='store_true', help='verbose mode') - parser.add_argument('--ignore_small_tensor_threshold', - type=int, - default=0, - help='set the tensor size threshold to ignore') - parser.add_argument('--parse_plan', - action='store_true', - help='parse plan to user-friendly format') - parser.add_argument('--alphafold', - action='store_true', - help='use alphafold2') - parser.add_argument('--alphafold_setting', - type=int, - default=1, - help='1: bs, s, r = 1, 128, 256'\ - '2: bs, s, r = 1, 512, 256'\ - '3: bs, s, r = 1, 512, 384') - args = parser.parse_args() - - # if args.compile: - # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' - - task_config = TorchscaleTaskConfig(**vars(args)) - from autodist.apis import calc_parallel_plan - topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) +# if __name__ == '__main__': +# import argparse +# parser = argparse.ArgumentParser(description='Torchscale benchmark') +# parser.add_argument('--fp16', +# action='store_true', +# help='use fp16 for the training') +# parser.add_argument('--fine_grained_GPT', +# action='store_true', +# help='model = GPTFineGrained') +# parser.add_argument('--GPT_setting', +# type=str, +# default='6.7B', +# help='set GPT model type') +# parser.add_argument('--save_folder', +# type=str, +# default='exp_data', +# help='set the save folder for experiment data') +# parser.add_argument('--micro_batch_size', +# type=int, +# default=8, +# help='set micro batch size') +# parser.add_argument('--global_batch_size', +# type=int, +# default=8, +# help='set the global batch size') +# parser.add_argument('--iter_num', +# type=int, +# default=2, +# help='set the number of all iterations') +# parser.add_argument('--warm_num', +# type=int, +# default=1, +# help='set the number of warmup iterations') +# parser.add_argument('--recompute', +# action='store_true', +# help='set recompute flag') +# parser.add_argument('--memory_constraint', +# type=float, +# default=32, +# help='memory constraint for program') +# parser.add_argument('--memory_granularity', +# type=int, +# default=1, +# help='memory granularity in byte') +# parser.add_argument('--profile_dir', +# type=str, +# default=str(Path.home()) + '/.autodist', +# help='profile dir') +# parser.add_argument('--connect_type', +# type=str, +# default='NV2', +# help='connect type from nvidia-smi topo -m') +# parser.add_argument('--use_prev_plan', +# action='store_true', +# help='run from previous plan') +# parser.add_argument('--is_train', +# action='store_true', +# help='True: train, False: inference') +# parser.add_argument('--topk', +# type=int, +# default=20, +# help='generate multiple plans for robustness') +# parser.add_argument('--mesh_row', type=int, default=1, help='node num') +# parser.add_argument('--mesh_col', +# type=int, +# default=2, +# help='dev num in a node') +# parser.add_argument('--compile', +# action='store_true', +# help='compile stage: true, runtime stage: false') +# parser.add_argument('--pipeline', +# action='store_true', +# help='pipeline: true, tensor parallel: false') +# parser.add_argument('--nproc', +# type=int, +# default=12, +# help='multiprocess deg in pipeline') +# parser.add_argument('--adaptive_recom', +# action='store_true', +# help='allow adaptive recompute') +# parser.add_argument('--plan_idx', +# type=int, +# default=0, +# help='runtime plan idx') +# parser.add_argument('--verbose', action='store_true', help='verbose mode') +# parser.add_argument('--ignore_small_tensor_threshold', +# type=int, +# default=0, +# help='set the tensor size threshold to ignore') +# parser.add_argument('--parse_plan', +# action='store_true', +# help='parse plan to user-friendly format') +# parser.add_argument('--alphafold', +# action='store_true', +# help='use alphafold2') +# parser.add_argument('--alphafold_setting', +# type=int, +# default=1, +# help='1: bs, s, r = 1, 128, 256'\ +# '2: bs, s, r = 1, 512, 256'\ +# '3: bs, s, r = 1, 512, 384') +# args = parser.parse_args() + +# # if args.compile: +# # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' + +# task_config = TorchscaleTaskConfig(**vars(args)) +# from autodist.apis import calc_parallel_plan +# topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) From 51df03da7599d5d8babe5b1cb157b91da5aae86f Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Mon, 13 Mar 2023 01:41:24 -0700 Subject: [PATCH 1344/1892] profiling torchscale --- cube/graph/function/dimops.py | 6 +-- cube/graph/function/function.py | 61 ++++++++++--------------------- cube/graph/parser/parserfx.py | 11 ++---- cube/runtime/function/function.py | 11 ++---- 4 files changed, 31 insertions(+), 58 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index d095f679..132f0469 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -658,9 +658,9 @@ def infer_shape(self) -> bool: shape.append(accum) otensor.shape = shape # set output shape - if isinstance(otensor, IRSubTensor): - otensor.parent.dtype = odtype - otensor.dtype = odtype + # if isinstance(otensor, IRSubTensor): + # otensor.parent.dtype = odtype + # otensor.dtype = odtype # print(f'=> sign: {self.signature} anno: {self.anno}\n' # f'=> inputs: {self.inputs()}\n' # f'=> outputs: {self.outputs()}') diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 4c4f1a03..ea367d2e 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -410,37 +410,22 @@ def Add(input, other, alpha=1, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input + alpha * other signature = 'torch.add' + annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) - else: - annos = ['* -> *'] - if isinstance(input, IRTensor): - return IRDimops(Add, 'add', signature, annos, [input], other=other, alpha=alpha) - else: - return IRDimops(Add, 'add', signature, annos, [other], other=input, alpha=alpha) - - -def CubeSub(input, other, alpha=1, *, out=None, signature = None): - signature = 'cube.runtime.function.sub' - if isinstance(input, IRTensor) and isinstance(other, IRTensor): - lshape, rshape, oshape = _handle_broadcast(input, other) - annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(CubeSub, 'sub', signature, annos, [input, other], alpha=alpha, swap_operands=False) - else: - annos = ['* -> *'] - if isinstance(input, IRTensor): - return IRDimops(CubeSub, 'sub', signature, annos, [input], other=other, alpha=alpha, swap_operands=False) - else: - return IRDimops(CubeSub, 'sub', signature, annos, [other], other=input, alpha=alpha, swap_operands=True) + return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) def Sub(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input - alpha * other - return CubeSub(input, other, alpha, out=out, signature=signature) + annos = ['*, ? -> *', '?, * -> *',] + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(CubeSub, 'sub', signature, annos, [input, other], alpha=alpha) def Mul(input, other, *, out=None, signature = None): @@ -448,31 +433,22 @@ def Mul(input, other, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input * other signature = 'torch.mul' + annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Mul, 'mul', signature, annos, [input, other]) - else: - annos = ['* -> *'] - if isinstance(input, IRTensor): - return IRDimops(Mul, 'mul', signature, annos, [input], other=other) - else: - return IRDimops(Mul, 'mul', signature, annos, [other], other=input) + return IRDimops(Mul, 'mul', signature, annos, [input, other]) def Div(input, other, *, rounding_mode=None, out=None, signature = None): assert rounding_mode is None and out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input / other + annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Div, 'div', signature, annos, [input, other]) - else: - # if not all tensors, the second must not be IRObject - assert isinstance(input, IRTensor) and not isinstance(other, IRObject) - annos = ['* -> *'] - return IRDimops(Div, 'div', signature, annos, [input], other=other) + return IRDimops(Div, 'div', signature, annos, [input, other], rounding_mode=rounding_mode) def FloorDiv(input, other, *, out=None, signature = None): @@ -1176,7 +1152,7 @@ def Conv3D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, stride=stride, padding=padding, dilation=dilation, groups=groups) -def CubeCat(*tensors, dim: int, signature = None): +def CubeCat(*tensors, dim=0, signature = None): """ torch.cat(tensors, dim=0, *, out=None) """ @@ -1185,6 +1161,7 @@ def CubeCat(*tensors, dim: int, signature = None): # with dimension. dim=None is for the support of kwarg inputs from torchfx assert all(isinstance(tensor, IRTensor) for tensor in tensors) assert isinstance(dim, int) + signature = 'cube.runtime.function.cat' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] dimlens = [t.shape[dim] for t in tensors] for ashape, dimlen in zip(iannos, dimlens): @@ -1195,15 +1172,11 @@ def CubeCat(*tensors, dim: int, signature = None): return IRDimops(CubeCat, 'cat', signature, [anno], tensors, dim=dim) -def Cat(*tensors_and_dim, dim=0, out=None, signature=None): +def Cat(tensors, dim=0, out=None, signature=None): """ torch.cat(tensors, dim=0, *, out=None) """ assert out is None - if len(tensors_and_dim) == 2: - tensors, dim = tensors_and_dim[0], tensors_and_dim[1] - else: - tensors = tensors_and_dim[0] return CubeCat(*tensors, dim=dim, signature=signature) @@ -1224,7 +1197,13 @@ def CubeStack(*tensors, dim=0, signature=None): def Stack(tensors, dim=0, out=None, signature = None): """ torch.stack(tensors, dim=0, *, out=None) + It needs CubeStack and runtime.function.stack, because + (i) if the tensors are packed in a list or tuple, it is treated as a whole tensor which is not aligned + with tensor partitioning; + (ii) if the tensors are not packed in a list or tuple, torch.stack cannot receive unpacked tensors. + """ + assert out is None return CubeStack(*tensors, dim=dim, signature=signature) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 788468c1..be193687 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -90,7 +90,6 @@ def parse(module: torch.fx.GraphModule, print(f'WARNING input_shapes shrinked to {input_shapes})') default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype if input_shapes is not None: # shape propagation sample_inputs = dummy_inputs if dummy_inputs else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] @@ -105,7 +104,7 @@ def parse(module: torch.fx.GraphModule, for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] - dtype = kDefaultType + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) else: @@ -137,7 +136,7 @@ def parse(module: torch.fx.GraphModule, # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name print(f'dummy_inputs does not have {input.name}') shape = None - dtype = kDefaultType + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) input_val = [frame.get_var(input.name) for input in inputs] @@ -321,11 +320,9 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram tensor_name = node.name if 'tensor_meta' in node.meta: tensor_shape = node.meta['tensor_meta'].shape - # tensor_dtype = node.meta['tensor_meta'].dtype #TODO assume it is weight - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=kDefaultType) + dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) + ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=dtype) ir_tensor.as_param() frame.add_var(tensor_name, ir_tensor) else: diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 084f042f..8378519a 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -34,12 +34,6 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) -def sub(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float], swap_operands: bool) -> torch.Tensor: - if swap_operands: - return torch.sub(other, input, alpha=alpha) - else: - return torch.sub(input, other, alpha=alpha) - def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: return input.expand(*sizes) @@ -131,4 +125,7 @@ def einsum(*operands, equation=None) -> torch.Tensor: return torch.einsum(equation, *operands) def stack(*tensors, dim=0) -> torch.Tensor: - return torch.stack(tensors, dim) \ No newline at end of file + return torch.stack(tensors, dim) + +def cat(*tensors, dim=0) -> torch.Tensor: + return torch.cat(tensors, dim) From 80d04c8822b22838ce7026a6617a87bff802c6a6 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Mon, 13 Mar 2023 04:10:16 -0700 Subject: [PATCH 1345/1892] update --- cube/graph/function/function.py | 1 + cube/profiler/database.py | 73 +++++++----- examples/nlp/torchscale/run_torchscale_lm.py | 117 ++----------------- 3 files changed, 55 insertions(+), 136 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index ea367d2e..40eb7b06 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -444,6 +444,7 @@ def Div(input, other, *, rounding_mode=None, out=None, signature = None): assert rounding_mode is None and out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input / other + signature = 'torch.div' annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 3fe6bbd5..cb421350 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -7,9 +7,10 @@ import time import os import json +import _operator import cube -from cube.ir.cten import IRTensor +from cube.ir.cten import IRTensor, IRObject from cube.ir.operator import IRFwOperation from cube.graph.parser.mapping import IRDType2TorchDType from cube.graph.parser.mappingfx import SignFx2Op as Sign2Op @@ -29,7 +30,7 @@ class CompProfiler: @staticmethod def profile(func: Callable, shapes: Shapes, dtypes: DTypes, - requires_grads: Tuple[bool], + requires_grads: Tuple[bool], values: Tuple[Any], warmup_sec: float = 2, prof_times: int = 50, **kwargs) -> Tuple[float, float, int, Tuple[int]]: """ @@ -50,15 +51,15 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, assert len(shapes) == len(dtypes), \ f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" # create data - dtypes = [torch.float32] * len(shapes) if dtypes is None else dtypes + assert dtypes is not None def gen_torch_tensors(shape, dtype, requires_grad): constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand # requires_grad = False if dtype in (torch.int64, torch.int32, torch.bool) else True return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) tensors = tuple( - gen_torch_tensors(shape, dtype, requires_grad) for shape, dtype, requires_grad in zip(shapes, dtypes, requires_grads) + gen_torch_tensors(shape, dtype, requires_grad) if value is None else value for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) ) - require_backward = any([t.requires_grad for t in tensors]) + require_backward = any([t.requires_grad for t in tensors if hasattr(t, 'requires_grad')]) # FIXME: reconsidering requires_grad if func.__name__ in ('type_as'): require_backward = False @@ -190,22 +191,30 @@ def get_dep_names(sign: str): exec(code_impl, globals(), local) fn = list(local.values())[0] else: - if '_operator.' in node.signature: - if '_operator.or_' == node.signature: - fn = torch.bitwise_or - elif '_operator.invert' == node.signature: - fn = torch.bitwise_not - else: - fn = eval(node.signature.replace('_operator.', 'torch.')) - else: - fn = eval(node.signature) - shapes, dtypes, requires_grads = [], [], [] + # if '_operator.' in node.signature: + # if '_operator.or_' == node.signature: + # fn = torch.bitwise_or + # elif '_operator.invert' == node.signature: + # fn = torch.bitwise_not + # else: + # fn = eval(node.signature.replace('_operator.', 'torch.')) + # else: + fn = eval(node.signature) + shapes, dtypes, requires_grads, values = [], [], [], [] for t in node.inputs(): - assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" - shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) - requires_grads.append(t.requires_grad) - return fn, shapes, dtypes, requires_grads, node.kwargs + if isinstance(t, IRTensor): + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) + requires_grads.append(t.requires_grad) + values.append(None) + elif isinstance(t, IRObject): + raise RuntimeError('IRObject has not been supported in profiling.') + else: + shapes.append(None) + dtypes.append(type(t).__name__) + requires_grads.append(None) + values.append(t) + return fn, shapes, dtypes, requires_grads, values, node.kwargs def profile(self, node: IRFwOperation, device: Optional[int] = None): """ @@ -221,7 +230,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @return infer_memory int: the peak memory in bytes after inference of the function @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ - fn, shapes, dtypes, requires_grads, kwargs = ProfileDataBase.get_func(node) + fn, shapes, dtypes, requires_grads, values, kwargs = ProfileDataBase.get_func(node) if self.exist(node): return self.query(node) @@ -233,18 +242,19 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): in_mem_info, param_mem_info = [], [] residual_mem, input_count = 0, 0 for t in node.inputs(): - if t.is_param(): + if hasattr(t, 'is_param') and t.is_param(): param_mem_info.append(t.byte_size()) - else: + elif hasattr(t, 'byte_size'): input_count += 1 if input_count == 1: residual_mem += t.byte_size() in_mem_info.append(t.byte_size()) + else: + print(f'WARNING: input {t} is skipped.') - # run profiling fw_span, bw_span, infer_memory, train_mem_info = \ - CompProfiler.profile(fn, shapes, dtypes, requires_grads, **kwargs) + CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) # log to database key = self._serialize(node) self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) @@ -363,10 +373,15 @@ def _serialize(self, node: IRFwOperation) -> str: """ shapes, dtypes = [], [] for t in node.inputs(): - assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" - shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) - shapes = '-'.join(str(tuple(shape)) for shape in shapes) + if isinstance(t, IRTensor): + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) + elif isinstance(t, IRObject): + raise RuntimeError('IRObject has not been supported in _serialize') + else: + shapes.append(None) + dtypes.append(type(t)) + shapes = '-'.join(str(tuple(shape)) if shape is not None else str(None) for shape in shapes) dtypes = '-'.join(str(dtype) for dtype in dtypes) return shapes + ' : ' + dtypes diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index ba2e6bfc..85965f2a 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -179,6 +179,7 @@ class dotdict(dict): # NOTE add SINGLE_DEV_MODE=1 before the running command from autodist.cost_model.cost_database import CostDatabase cost_database = CostDatabase(cube_graph, config) + # find the best partition plan from autodist.task_config import TaskConfig class TorchscaleTaskConfig(TaskConfig): @@ -194,112 +195,14 @@ def __init__(self, **kwargs): self.task_name) self.allow_recom_ops = [] self.del_dim = [] -# if __name__ == '__main__': -# import argparse -# parser = argparse.ArgumentParser(description='Torchscale benchmark') -# parser.add_argument('--fp16', -# action='store_true', -# help='use fp16 for the training') -# parser.add_argument('--fine_grained_GPT', -# action='store_true', -# help='model = GPTFineGrained') -# parser.add_argument('--GPT_setting', -# type=str, -# default='6.7B', -# help='set GPT model type') -# parser.add_argument('--save_folder', -# type=str, -# default='exp_data', -# help='set the save folder for experiment data') -# parser.add_argument('--micro_batch_size', -# type=int, -# default=8, -# help='set micro batch size') -# parser.add_argument('--global_batch_size', -# type=int, -# default=8, -# help='set the global batch size') -# parser.add_argument('--iter_num', -# type=int, -# default=2, -# help='set the number of all iterations') -# parser.add_argument('--warm_num', -# type=int, -# default=1, -# help='set the number of warmup iterations') -# parser.add_argument('--recompute', -# action='store_true', -# help='set recompute flag') -# parser.add_argument('--memory_constraint', -# type=float, -# default=32, -# help='memory constraint for program') -# parser.add_argument('--memory_granularity', -# type=int, -# default=1, -# help='memory granularity in byte') -# parser.add_argument('--profile_dir', -# type=str, -# default=str(Path.home()) + '/.autodist', -# help='profile dir') -# parser.add_argument('--connect_type', -# type=str, -# default='NV2', -# help='connect type from nvidia-smi topo -m') -# parser.add_argument('--use_prev_plan', -# action='store_true', -# help='run from previous plan') -# parser.add_argument('--is_train', -# action='store_true', -# help='True: train, False: inference') -# parser.add_argument('--topk', -# type=int, -# default=20, -# help='generate multiple plans for robustness') -# parser.add_argument('--mesh_row', type=int, default=1, help='node num') -# parser.add_argument('--mesh_col', -# type=int, -# default=2, -# help='dev num in a node') -# parser.add_argument('--compile', -# action='store_true', -# help='compile stage: true, runtime stage: false') -# parser.add_argument('--pipeline', -# action='store_true', -# help='pipeline: true, tensor parallel: false') -# parser.add_argument('--nproc', -# type=int, -# default=12, -# help='multiprocess deg in pipeline') -# parser.add_argument('--adaptive_recom', -# action='store_true', -# help='allow adaptive recompute') -# parser.add_argument('--plan_idx', -# type=int, -# default=0, -# help='runtime plan idx') -# parser.add_argument('--verbose', action='store_true', help='verbose mode') -# parser.add_argument('--ignore_small_tensor_threshold', -# type=int, -# default=0, -# help='set the tensor size threshold to ignore') -# parser.add_argument('--parse_plan', -# action='store_true', -# help='parse plan to user-friendly format') -# parser.add_argument('--alphafold', -# action='store_true', -# help='use alphafold2') -# parser.add_argument('--alphafold_setting', -# type=int, -# default=1, -# help='1: bs, s, r = 1, 128, 256'\ -# '2: bs, s, r = 1, 512, 256'\ -# '3: bs, s, r = 1, 512, 384') -# args = parser.parse_args() -# # if args.compile: -# # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' +kwargs = {'save_folder': 'exp_data', 'micro_batch_size': 8, 'global_batch_size': 8, 'iter_num': 2, + 'warm_num': 1, 'recompute': False, 'memory_constraint': 32, 'memory_granularity': 1, + 'profile_dir': str(Path.home())+'/.autodist/', 'connect_type': 'NV2', 'use_prev_plan': False, + 'is_train': True, 'topk': 20, 'mesh_row': 1, 'mesh_col': 2, 'compile': True, 'pipeline': False, + 'nproc': 12, 'adaptive_recom': False, 'plan_idx': 0, 'verbose': True, 'ignore_small_tensor_threshold': 0, + 'parse_plan': True} -# task_config = TorchscaleTaskConfig(**vars(args)) -# from autodist.apis import calc_parallel_plan -# topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) +task_config = TorchscaleTaskConfig(**kwargs) +from autodist.apis import calc_parallel_plan +topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) From 11b4d63fd2594f1fea0a022025c796fa67cea683 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 13 Mar 2023 12:20:22 +0000 Subject: [PATCH 1346/1892] Merged PR 1489: training support with IRPyFunc training support with IRPyFunc @ Please check the modification to cube/graph/gener/gen.py still applies to torchscale --- cube/codegen/emit.py | 34 ++++++++++++++- cube/codegen/lifecycle.py | 19 +++++---- cube/codegen/schedule/schedule.py | 2 +- cube/compiler.py | 5 ++- cube/graph/gener/gen.py | 37 ++++++++++------ cube/graph/graph.py | 29 ++++++------- cube/graph/parser/parserfx.py | 60 +++++++++++++------------- cube/graph/segment.py | 71 ++++++++++++++++++++++++++++--- cube/program.py | 38 ++++++++++++++--- cube/runtime/syndata.py | 4 +- examples/mlp/infer.py | 33 ++++++++++---- examples/mlp/linearsfx.py | 14 +----- examples/mlp/train.py | 35 ++++++++++----- examples/nlp/gpt/infer.py | 2 +- examples/nlp/gpt/model.py | 58 +++++++++++++------------ examples/nlp/mbart/model.py | 32 ++++++-------- examples/vision/swin/model.py | 27 ++++++++---- examples/vision/swin/train.py | 4 +- tests/parser/test_fx_ops.py | 37 ++++++++-------- 19 files changed, 348 insertions(+), 193 deletions(-) diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 3243436c..59436b9f 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -7,11 +7,22 @@ from cube.ir.adapter import IRWeightReducer, IRAdapter from cube.ir.adapter.prim import IRAdapterPrim +from cube.graph.segment import IRSegment + from cube.codegen.frontend_mapping import Sign2EmitRule from cube.flags import CompileFlag +class IRValue: + + def __init__(self, name: str): + self.name = name + + def __repr__(self): + return self.name + + class CodeEmission: """ Basic emission @@ -49,7 +60,17 @@ def tensor_name(tensor: Any, prefix_attr: Optional[str] = None) -> str: else: name = str(tensor) return name - + + @staticmethod + def complex_name(val: Any, prefix_attr: Optional[str]=None) -> str: + """ + Return the val name with complex data type over IRObject + Currently support complex data type of Dict, List, Tuple, IRObject + """ + modifier = lambda t: IRValue(CodeEmission.tensor_name(t, prefix_attr)) + val = IRSegment.modify_objects_of_complex(val, modifier) + return str(val) + @staticmethod def tuple_name(tensors: List[Any], skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: @@ -81,6 +102,17 @@ def return_name(tensors: List[Any], names = '_' if len(names) == 0 else ', '.join(names) return names + @staticmethod + def return_name_complex(vals: List[Any], + skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: + names = [] + for t in vals: + if isinstance(t, IRObject) and skip_attr and t.is_attr(): + continue + names.append(CodeEmission.complex_name(t, prefix_attr)) + names = '_' if len(names) == 0 else ', '.join(names) + return names + @staticmethod def kwargs_name(**kwargs) -> str: names = [] diff --git a/cube/codegen/lifecycle.py b/cube/codegen/lifecycle.py index a3c7833b..a76135d3 100644 --- a/cube/codegen/lifecycle.py +++ b/cube/codegen/lifecycle.py @@ -1,7 +1,7 @@ -from typing import Iterable, Dict, List +from typing import Iterable, Dict, List, Any import itertools -from cube.ir.cten import IRCell, IRTensor +from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.tensor import IRSubTensor from cube.graph.segment import IRSegment from cube.execplan.execplan import ExeReuseCell @@ -11,22 +11,25 @@ class LifeCycle: - def __init__(self, nodes: List[IRCell], graph_inputs: List[IRSubTensor], graph_outputs: List[IRSubTensor]): + def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: List[Any]): + + graph_inputs = IRSegment.get_objects_from_complex(graph_inputs) + graph_outputs = IRSegment.get_objects_from_complex(graph_outputs) self.nodes: Dict[int] = {node: lid for lid, node in enumerate(nodes)} # the last line id of consuming or producing a tensor - self.lifetime: Dict[IRSubTensor, int] = {} + self.lifetime: Dict[IRObject, int] = {} # the tensors can be released given the finish of line id - self.release: Dict[int, List[IRSubTensor]] = {} + self.release: Dict[int, List[IRObject]] = {} - is_activation = lambda t: isinstance(t, IRSubTensor) and not t.is_attr() + is_activation = lambda t: isinstance(t, IRObject) and not t.is_attr() self.lifetime.update((tsin, 0) for tsin in graph_inputs if is_activation(tsin)) for i, node in enumerate(nodes): - outputs : Iterable[IRTensor] - inputs : Iterable[IRTensor] + outputs : Iterable[IRObject] + inputs : Iterable[IRObject] if isinstance(node, (IRSegment, ExeReuseCell)): # forward segment diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index e60f4f01..90a9a56e 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -65,7 +65,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: fb.insert_body(ScheduleCodeGen.emit_release(tensors)) # return code - outputs = ScheduleCodeGen.return_name(self.execplan.graph.outputs()) + outputs = ScheduleCodeGen.return_name_complex(self.execplan.graph.outputs()) code = f'return {outputs}' fb.insert_body(code) gencode += fb.code diff --git a/cube/compiler.py b/cube/compiler.py index fa093d6d..238f6950 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -74,6 +74,8 @@ def train_step(model, dataloader): raise TypeError("Expect dataloader derived from CubeDataLoader") model.save_content = load_content + if model.dummy_input is None: + model.dummy_input = next(dataloader) ir_dataloader = SemanticDataLoader(dataloader) myrank = DeviceGroup().rank @@ -104,7 +106,6 @@ def decorator(fn: Callable) -> Callable: if DeviceGroup().local_rank == 0: compile_start = time.time() - resource = cube.runtime.resource.EnvResource() # run once to get model structure and tensor shape @@ -222,8 +223,10 @@ def decorator(fn: Callable) -> Callable: if torch.distributed.is_initialized(): torch.distributed.barrier() + model.dummy_input = None # set dataloder batch size (serialize output) bs = model.get_gen_module().get_batch_size() + print_each_rank(f'> setting batch size to: {bs}') if torch.distributed.is_initialized(): for rank in range(torch.distributed.get_world_size()): if rank == torch.distributed.get_rank(): diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index dc0c9935..9badd327 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -36,7 +36,8 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) # create inputs if inputs: - for tensor in segment.inputs(): + input_objects = IRGraph.get_objects_from_complex(segment.inputs()) + for tensor in input_objects: devices = [consumer.device for consumer in segment.consumers(tensor.parent)][::-1] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" @@ -53,7 +54,8 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) # create outputs if outputs: - for tensor in segment.outputs(): + output_objects = IRGraph.get_objects_from_complex(segment.outputs()) + for tensor in output_objects: devices = [producer.device for producer in segment.producers(tensor.parent)] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" @@ -155,23 +157,30 @@ def auto_pyfunc(graph: IRSegment): devices = set() for t in func.inputs(): if not isinstance(t, IRObject): continue - producers = graph.producers(t.parent) - for p in producers: - devices.update(p.device) + if t.is_attr(): + cells = graph.consumers(t.parent) + else: + cells = graph.producers(t.parent) + for cell in cells: + devices.update(cell.device) pyfuncs = [] # lower to each device for devid in devices: inputs = [] + # automatic partition to align with consumer (attr) or producer (activation) for t in func.inputs(): - if isinstance(t, IRObject): - if t.is_attr(): - tensors = set(tensor for tensor in graph.ctensors(t.parent) if devid in tensor.device) - else: - tensors = set(tensor for tensor in graph.ptensors(t.parent) if devid in tensor.device) - assert len(tensors) == 1, \ - f"Find {len(tensors)} != 1: {tensors} versions of tensor {t} on a same device." - t = list(tensors)[0] - inputs.append(t) + sub_ts = set() + if not isinstance(t, IRSubTensor): + sub_ts.add(t) # replica for non-tensor + elif t.is_attr(): + # get local consumers except func itself + sub_ts = set(tensor for tensor in graph.ctensors(t.parent) \ + if devid in tensor.device and tensor.cell != func) + else: + # get local producers + sub_ts = set(tensor for tensor in graph.ptensors(t.parent) \ + if devid in tensor.device) + inputs.append(t if len(sub_ts) == 0 else list(sub_ts)[0]) lower_func = IRPyFunc(func.signature, inputs, func.outputs(), **func.kwargs) lower_func.device = devid pyfuncs.append(lower_func) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ae8a727c..4bac021b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,7 +7,7 @@ will be inserted at scheduling time. """ -from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any +from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any, Callable import warnings import copy @@ -153,7 +153,7 @@ def backward(self, loss: IRSubTensor): loss.parent.to_loss() # infer gradient - for ftensor in self._ftensors: + for ftensor in self.full_tensors(): self.infer_grad(ftensor) # create backward node for fnode in self.nodes()[::-1]: @@ -216,7 +216,8 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: # reset fsegment gradient for itensor in fsegment.inputs(): - fgraph.infer_grad(itensor.parent) + if isinstance(itensor, IRTensor): + fgraph.infer_grad(itensor.parent) # update backward if len(bnodes) > 0: @@ -233,29 +234,23 @@ def group(self, fnodes: List[IRCell]) -> IRSegment: @staticmethod def from_logic_graph(nodes: List[IRCell], - inputs: List[IRObject], outputs: List[IRObject], + inputs: List[Any], outputs: List[Any], module_name: str): """ Generate IRGraph from logical graph (IRFullTensor) @param nodes: nodes of the graph - @param inputs List[IRFullTensor]: graph inputs - @param outputs List[IRFullTensor]: graph outputs + @param inputs List[Any]: graph inputs + @param outputs List[Any]: graph outputs @param module_name str: graph name @return graph IRGraph """ - # instantiate graph inputs / outputs - for idx, tensor in enumerate(inputs): - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - inputs[idx] = tensor - for idx, tensor in enumerate(outputs): - if isinstance(tensor, IRFullTensor): - tensor = tensor.tosub() - outputs[idx] = tensor - - # instantiate to subtensor + modifier = lambda t: t.tosub() if isinstance(t, IRFullTensor) else t + # input / output + inputs = [IRGraph.modify_objects_of_complex(t, modifier) for t in inputs] + outputs = [IRGraph.modify_objects_of_complex(t, modifier) for t in outputs] + # nodes for node in nodes: for idx, ftensor in enumerate(node.inputs()): if isinstance(ftensor, IRObject): diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 756cfaf3..4a7d5d04 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -330,43 +330,43 @@ def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr ir_nodes = [] # handle complex outputs - def generate_outputs(val: Any, _ops: List) -> IRObject: - """Support complex data type of List, Tuple, Dict, Tensor/Object""" - if isinstance(val, list): - inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) - output = IRObject() - _ops.append(IRPyFunc('(lambda *args: list(args))', inputs, [output])) - return output - if isinstance(val, tuple): - inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) - output = IRObject() - _ops.append(IRPyFunc('(lambda *args: args)', inputs, [output])) - return output - if isinstance(val, dict): - output = IRObject() - assert all(not isinstance(key, torch.fx.Node) for key in val.keys()), f"output dict cannot have torch.fx.Node is key" - keys = tuple(str(key) for key in val.keys()) - values = generate_outputs(tuple(generate_outputs(value, _ops) for value in val.values()), _ops) - _ops.append(IRPyFunc('(lambda vals, keys: {key:val for key,val in zip(keys,vals)})', [values], [output], keys=keys)) - return output - if isinstance(val, torch.fx.Node): - return frame.get_var(val.name) - return val - output = generate_outputs(node.args[0], ir_nodes) - - # def generate_outputs(val: Any) -> Any: + # def generate_outputs(val: Any, _ops: List) -> IRObject: # """Support complex data type of List, Tuple, Dict, Tensor/Object""" # if isinstance(val, list): - # return list(generate_outputs(item) for item in val) + # inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) + # output = IRObject() + # _ops.append(IRPyFunc('(lambda *args: list(args))', inputs, [output])) + # return output # if isinstance(val, tuple): - # return tuple(generate_outputs(item) for item in val) + # inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) + # output = IRObject() + # _ops.append(IRPyFunc('(lambda *args: args)', inputs, [output])) + # return output # if isinstance(val, dict): - # return {generate_outputs(key) : generate_outputs(value) for key, value in val.items()} + # output = IRObject() + # assert all(not isinstance(key, torch.fx.Node) for key in val.keys()), f"output dict cannot have torch.fx.Node is key" + # keys = tuple(str(key) for key in val.keys()) + # values = generate_outputs(tuple(generate_outputs(value, _ops) for value in val.values()), _ops) + # _ops.append(IRPyFunc('(lambda vals, keys: {key:val for key,val in zip(keys,vals)})', [values], [output], keys=keys)) + # return output # if isinstance(val, torch.fx.Node): # return frame.get_var(val.name) - # # for other types like int, float, ... # return val - # output = generate_outputs(node.args[0]) + # output = generate_outputs(node.args[0], ir_nodes) + + def generate_outputs(val: Any) -> Any: + """Support complex data type of List, Tuple, Dict, Tensor/Object""" + if isinstance(val, list): + return list(generate_outputs(item) for item in val) + if isinstance(val, tuple): + return tuple(generate_outputs(item) for item in val) + if isinstance(val, dict): + return {generate_outputs(key) : generate_outputs(value) for key, value in val.items()} + if isinstance(val, torch.fx.Node): + return frame.get_var(val.name) + # for other types like int, float, ... + return val + output = generate_outputs(node.args[0]) # # TODO: support more complex data type # outs = (node.args[0],) if isinstance(node.args[0], torch.fx.Node) else node.args[0] diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 1e97c257..00a3de1b 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1,10 +1,10 @@ from contextlib import contextmanager -from typing import Dict, Union, List, Optional, Set, Tuple +from typing import Dict, Union, List, Optional, Set, Tuple, Any, Callable import numpy as np from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap from cube.ir.cten import IRTensor, IRCell, IRObject -from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation +from cube.ir.operator import IRFwOperation, IRBpOperation from cube.ir.adapter import IRAdapter from cube.graph.function.function import MultiRef @@ -70,6 +70,8 @@ class IRSegment(IRCell): """ A distributed sub-graph representing a piece of workload in parent IRGraph + Input/output can be complex data type of Dict, List, Tuple on IRObjects + Once the segment is generated, its input and output will be fixed. Inserting and removing nodes that could change input/output are not allowed. """ @@ -301,6 +303,8 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: f"{self.debug_tensor_map_str(ftensor)}" ) for ptensor, producer in zip(self.ptensors(ftensor), self.producers(ftensor)): + # filter out non-autograd operators of IRPyFunc + if isinstance(producer, IRPyFunc): continue idx = producer.outputs().index(ptensor) if fgrad is None: grad = None @@ -315,8 +319,16 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: f"{self.debug_tensor_map_str(ftensor)}" ) curr_valmap = ValueMap((0, 1)) - nconsumers = len(self.consumers(ftensor)) - for cidx, (ctensor, consumer) in enumerate(zip(self.ctensors(ftensor), self.consumers(ftensor))): + + # filter out non-autograd operators of IRPyFunc + consumers, ctensors = [], [] + for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): + if isinstance(consumer, IRPyFunc): continue + consumers.append(consumer) + ctensors.append(ctensor) + + nconsumers = len(consumers) + for cidx, (ctensor, consumer) in enumerate(zip(ctensors, consumers)): idx = consumer.inputs().index(ctensor) if fgrad is None: grad = None @@ -878,6 +890,9 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: @return segment IRSegment: the grouped segment. """ segment = self + segment_inputs = IRSegment.get_objects_from_complex(segment.inputs()) + segment_outputs = IRSegment.get_objects_from_complex(segment.outputs()) + # segments: List[IRSegment] = [self.segment(node) for node in nodes] # assert len(set(segments)) == 1, "Cross segment hierarchy grouping is not allowed" # segment = segments[0] @@ -930,7 +945,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: if len(node.device) > 0 and set(itensor.device).issubset(adapter_ous[itensor]): continue # from segment inputs - if any(t.overlap(itensor) for t in segment.inputs() if isinstance(t, IRObject)): + if any(t.overlap(itensor) for t in segment_inputs if isinstance(t, IRObject)): inputs.add(itensor) continue # from outside producers @@ -952,7 +967,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: if len(node.device) > 0 and set(otensor.device).issubset(adapter_ins[otensor]): continue # from segment outputs - if any(t.overlap(otensor) for t in segment.outputs() if isinstance(t, IRObject)): + if any(t.overlap(otensor) for t in segment_outputs if isinstance(t, IRObject)): outputs.add(otensor) continue # loss doesn't have consumers @@ -1048,3 +1063,47 @@ def extra_repr(self) -> str: # outputs dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" return dscp + + @staticmethod + def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[IRObject]: + """ + Get objects from val of complex data type + Support complex of types: List, Tuple, Dict, torch.Tensor, object + + @param val Any + + @return _objects List[IRObject]: all IRObject + """ + _objects = [] if _objects is None else _objects + if isinstance(val, (tuple, list)): + for item in val: + IRSegment.get_objects_from_complex(item, _objects) + if isinstance(val, dict): + for key, value in val.items(): + IRSegment.get_objects_from_complex(key, _objects) + IRSegment.get_objects_from_complex(value, _objects) + if isinstance(val, IRObject): + _objects.append(val) + return _objects + + @staticmethod + def modify_objects_of_complex(val: Any, modifier: Callable) -> Any: + """ + Get objects from val of complex data type + Support complex of types: List, Tuple, Dict, torch.Tensor, object + + @param val Any + @param modifier Callable: modify IRObject to another one + + @return new_val List[IRObject]: all IRObject + """ + rcall = IRSegment.modify_objects_of_complex + if isinstance(val, tuple): + return tuple(rcall(item, modifier) for item in val) + if isinstance(val, list): + return list(rcall(item, modifier) for item in val) + if isinstance(val, dict): + return {rcall(key, modifier):rcall(value, modifier) for key, value in val.items()} + if isinstance(val, IRObject): + return modifier(val) + return val diff --git a/cube/program.py b/cube/program.py index ef32d95b..581e8b68 100644 --- a/cube/program.py +++ b/cube/program.py @@ -1,4 +1,5 @@ -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Any +import warnings from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -43,8 +44,7 @@ def add_nodes(self, nodes: List[IRCell]): def get_graph(self) -> IRGraph: return self.instance._graph - def set_output(self, outputs: List[IRObject]): - assert all(isinstance(t, IRObject) for t in outputs) + def set_output(self, outputs: Tuple[Any]): self.instance._graph.reset_outputs(len(outputs)) for idx, otensor in enumerate(outputs): self.instance._graph.set_output(idx, otensor) @@ -106,8 +106,8 @@ def generate_output(sample): if isinstance(sample, dict): assert all(isinstance(key, (str, int)) for key in sample.keys()) return {key:generate_output(val) for key, val in sample.items()} - # if isinstance(sample, set): - # return {generate_output(t) for t in sample} + if isinstance(sample, set): + return {generate_output(t) for t in sample} if isinstance(sample, torch.Tensor): shape, dtype = list(sample.shape), dtype_map.map(sample.dtype) return IRFullTensor(shape, 'data', dtype=dtype).tosub() @@ -148,7 +148,7 @@ def __init__(self, model: Optional[torch.nn.Module], input_shapes=None, dummy_in assert isinstance(model, torch.nn.Module), f"device of local_rank == 0 must provide model" self.model = model self.input_shapes = None - self.dummy_input = dummy_input + self._dummy_input = dummy_input self.ir_graph = None self._loaded_module: CubeModule = None self._save_content = True @@ -161,6 +161,32 @@ def save_content(self) -> bool: def save_content(self, val: bool): self._save_content = val + @property + def dummy_input(self) -> Any: + """ + Get dummy real-tensor input from on CPU + """ + return self._dummy_input + + @dummy_input.setter + def dummy_input(self, val): + + def complex(val: Any): + """Complex to CPU""" + if isinstance(val, tuple): + return tuple(complex(t) for t in val) + if isinstance(val, list): + return list(complex(t) for t in val) + if isinstance(val, dict): + return {complex(key):complex(val) for key, val in val.items()} + if isinstance(val, set): + return {complex(t) for t in val} + if isinstance(val, torch.Tensor): + return val.cpu() + return val + + self._dummy_input = complex(val) + def get_graph(self): return self.ir_graph diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 5b00633a..50103dee 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -33,7 +33,7 @@ def __iter__(self): raise NotImplementedError("Required implementation for derived class") def __next__(self): - return NotImplementedError("Required implementation for derived class") + raise NotImplementedError("Required implementation for derived class") def get_batch_size(self) -> int: """ @@ -45,7 +45,7 @@ def set_batch_size(self, batch_size: int): """ set batch size """ - return NotImplementedError("Required implementation for derived class") + raise NotImplementedError("Required implementation for derived class") def get_batch_dims(self) -> Tuple[Optional[int]]: return tuple(self.batch_dims) diff --git a/examples/mlp/infer.py b/examples/mlp/infer.py index 6871a553..fe63a445 100644 --- a/examples/mlp/infer.py +++ b/examples/mlp/infer.py @@ -59,20 +59,35 @@ def forward(self, data): return loss +class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, bs: int, dim: int): + super().__init__(bs, [0]) + self.sample = None + self.dim = dim + self.set_batch_size(bs) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + self.batch_size = batch_size + self.sample = torch.rand( + [batch_size, self.dim], dtype=torch.float32, + device=torch.cuda.current_device() + ) + + def infer(): batch_size = 128 dim = 4096 model = MLP(dim=dim) - model = cube.SemanticModel( - model, input_shapes=([batch_size, dim],), - ) - - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, dim],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) + model = cube.SemanticModel(model) + dataloader = MLPDataLoader(batch_size, dim) @cube.compile(model, dataloader, PAS=PAS) def infer_iter(model, dataloader): diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index de66022b..03eb1937 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -123,19 +123,7 @@ def train(): ) model = MLP(dim=dim) - - # shape based input (will trigger standard torch.fx.Tracer) - # model = cube.SemanticModel( - # model, input_shapes=([batch_size, dim, dim], [batch_size, dim, dim]), - # ) - - # dummy based input (will trigger concrete Tracer) - device = next(model.parameters()).device - dummy_input = next(dataloader) - dummy_input = tuple([input.to(device) for input in dummy_input]) - model = cube.SemanticModel( - model, dummy_input=dummy_input, - ) + model = cube.SemanticModel(model) @cube.compile(model, dataloader, PAS=PAS, load_content=False) def train_iter(model, dataloader): diff --git a/examples/mlp/train.py b/examples/mlp/train.py index 2d58eee9..36cf7d01 100644 --- a/examples/mlp/train.py +++ b/examples/mlp/train.py @@ -4,7 +4,7 @@ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASMegatron + examples/mlp/train.py --policy PASMegatron """ import torch @@ -58,20 +58,35 @@ def forward(self, data): return loss +class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, bs: int, dim: int): + super().__init__(bs, [0]) + self.sample = None + self.dim = dim + self.set_batch_size(bs) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + self.batch_size = batch_size + self.sample = torch.rand( + [batch_size, self.dim], dtype=torch.float32, + device=torch.cuda.current_device() + ) + + def train(): batch_size = 128 dim = 4096 model = MLP(dim=dim) - model = cube.SemanticModel( - model, input_shapes=([batch_size, dim],), - ) - - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, dim],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) + model = cube.SemanticModel(model) + dataloader = MLPDataLoader(batch_size, dim) @cube.compile(model, dataloader, PAS=PAS) def train_iter(model, dataloader): diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py index 5e5d49f1..c6e8b646 100644 --- a/examples/nlp/gpt/infer.py +++ b/examples/nlp/gpt/infer.py @@ -67,7 +67,7 @@ def inter(): dataloader = GPTInferDataLoader(batch_size) ################## SuperScaler run - model = cube.SemanticModel(model, dataloader.shapes) + model = cube.SemanticModel(model) @cube.compile(model, dataloader, PAS=PAS, override=True) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 910908af..13cbcd56 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -184,20 +184,22 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): return logits -class GPTDataLoader(cube.runtime.syndata.SynDataLoader): - - def __init__(self, batch_size: int, cfg: Config = Config()): - - self.cfg = cfg - super().__init__( - shapes=([batch_size, self.cfg.seqlen], - [batch_size, self.cfg.seqlen], - ), - dtypes=(torch.int64, torch.int64), - batch_dims=(0, 0) - ) - - def random_sample(self): +class GPTDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, bs: int, cfg: Config = None): + self.cfg = Config() if cfg is None else cfg + super().__init__(bs, [0, 0]) + self.sample = None + self.set_batch_size(bs) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + self.batch_size = batch_size input_ids = torch.randint( 0, self.cfg.num_embeddings, size=(self.batch_size, self.cfg.seqlen), @@ -206,23 +208,25 @@ def random_sample(self): position_ids = torch.arange( 0, self.cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() ).repeat(self.batch_size).view(self.batch_size, -1) - return input_ids, position_ids + self.sample = (input_ids, position_ids) -class GPTInferDataLoader(cube.runtime.syndata.SynDataLoader): +class GPTInferDataLoader(cube.runtime.syndata.CubeDataLoader): - def __init__(self, batch_size: int, cfg: Config = Config()): + def __init__(self, bs: int, cfg: Config = None): + self.cfg = Config() if cfg is None else cfg + super().__init__(bs, [0, 0]) + self.sample = None + self.set_batch_size(bs) - self.cfg = cfg - super().__init__( - shapes=([batch_size, 1], - [batch_size, 1], - ), - dtypes=(torch.int64, torch.int64), - batch_dims=(0, 0) - ) + def __iter__(self): + return self - def random_sample(self): + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + self.batch_size = batch_size input_ids = torch.randint( 0, self.cfg.num_embeddings, size=(self.batch_size, 1), @@ -233,4 +237,4 @@ def random_sample(self): 0, 1, dtype=torch.int64, device=torch.cuda.current_device() ).repeat(self.batch_size).view(self.batch_size, -1) - return input_ids, position_ids + self.sample = (input_ids, position_ids) \ No newline at end of file diff --git a/examples/nlp/mbart/model.py b/examples/nlp/mbart/model.py index 3c2ee51a..302724a5 100644 --- a/examples/nlp/mbart/model.py +++ b/examples/nlp/mbart/model.py @@ -231,30 +231,26 @@ def forward(self, input_ids: torch.Tensor): class MBartSyntheticDataLoader(cube.runtime.syndata.CubeDataLoader): - def __init__(self, batch_size: int): + def __init__(self, bs: int, cfg: Config = None): + self.cfg = Config() if cfg is None else cfg + super().__init__(bs, [0, 0]) + self.sample = None + self.set_batch_size(bs) - self.bs = batch_size - self.cfg = Config() - super().__init__( - shapes=([batch_size, self.cfg.max_source_positions,],), - dtypes=(torch.int64,), - batch_dims=(0,) - ) - self.samples = [self.random_sample()] - - def random_sample(self): + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + self.batch_size = batch_size input_ids = torch.randint( 0, self.cfg.num_embeddings, size=(self.bs, self.cfg.max_source_positions), dtype=torch.int64, device=torch.cuda.current_device() ) - return input_ids - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] + self.sample = input_ids class MBartDataLoader(cube.runtime.syndata.CubeDataLoader): diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index 7381fc56..10e29dd9 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -221,16 +221,27 @@ def flops(self): # =========================== Data Loader ======================= - -class ImageDataLoader(cube.runtime.syndata.SynDataLoader): +class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): def __init__(self, batch_size: int, img_size: int, num_classes: int, dtype=torch.float32): - - self.bs = batch_size + super().__init__(batch_size, [0]) self.img_size = img_size self.num_classes = num_classes - super().__init__( - shapes=([batch_size, 3, img_size, img_size,],), - dtypes=(dtype,), - batch_dims=(0,) + self.dtype = dtype + + self.sample = None + self.set_batch_size(batch_size) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + self.batch_size = batch_size + input_ids = torch.rand( + [self.batch_size, 3, self.img_size, self.img_size], + dtype=self.dtype, device=torch.cuda.current_device() ) + self.sample = input_ids diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 212dd76a..3b720d52 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -44,7 +44,7 @@ def train(): - batch_size = 1 + batch_size = 4 load_content: bool = False cfg = Config() @@ -80,7 +80,7 @@ def train_iter(model, dataloader): print_each_rank(f'model parameter: {nparams}') CudaTimer(enable=False).warmup() - iter_num, warmup = 10, 2 + iter_num, warmup = 5, 2 for step in range(iter_num): if step >= warmup: diff --git a/tests/parser/test_fx_ops.py b/tests/parser/test_fx_ops.py index b224ed40..e78c9ff1 100644 --- a/tests/parser/test_fx_ops.py +++ b/tests/parser/test_fx_ops.py @@ -9,38 +9,36 @@ from cube.graph.function.dimops import IRDimops +def _param(size, dtype=torch.float32): + return torch.nn.Parameter(torch.empty(size, dtype=dtype)) + + class TestOpModule(torch.nn.Module): def __init__(self): super().__init__() - self.param1 = torch.nn.Parameter(torch.empty([512, 256], dtype=torch.float32)) - self.param2 = torch.nn.Parameter(torch.empty([512, 256], dtype=torch.float32)) + self.param1 = _param([512, 256]) + self.param2 = _param([512, 256]) self.ints = [1, 2, 3] def forward(self, x: torch.Tensor): - # matmul: [256, 512], [512, 256] -> [256, 256] - x1 = torch.matmul(x, self.param1) + # matmul: [bs, 512], [512, 256] -> [bs, 256] x1 = torch.matmul(x, self.param1) + # [bs, 256] -> [bs, 256] x1 = x1 + x1.size(0) + x1.size()[0] - x2 = torch.chunk(x, 2, dim=1) - x3 = x2[0] - x = x + x.size(0) - x = x + self.ints[0] - return {'x': x}, [x3,] + # [bs, 256] -> [bs, 128], [bs, 128] + x2 = torch.chunk(x1, 2, dim=1)[0] + # [bs, 128] -> [bs, 128] + x3 = x2 + x2.size(0) + x4 = x3 + self.ints[0] + # [bs, 128] -> [1] + loss = torch.sum(x4) + return {'x': x4, 'loss': loss} # , [x3,] class TestDataLoader(cube.runtime.syndata.CubeDataLoader): def __init__(self, batch_size: int = 256) -> None: - # self.sample = ( - # torch.rand( - # [batch_size, 512], - # dtype=torch.float32, - # device=torch.cuda.current_device() - # ), - # [torch.tensor([1], dtype=torch.float32),] - # ) - # super().__init__(batch_size, (0, None)) self.sample = torch.rand( [batch_size, 512], dtype=torch.float32, @@ -84,7 +82,8 @@ def policy(graph, resource): def eval_iter(model, dataloader): data = next(dataloader) out = model(data) - return out + out['loss'].backward() + # return out model = model.get_gen_module() From ac99fda22c5afa0dc0a74decad189aa2ced7b3b5 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Mon, 13 Mar 2023 05:21:13 -0700 Subject: [PATCH 1347/1892] support more ops --- cube/graph/function/function.py | 11 ++++++++++- cube/graph/parser/mappingfx.py | 11 +++++++---- cube/graph/parser/parserfx.py | 4 ++++ examples/nlp/torchscale/run_torchscale_lm.py | 3 +++ 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 40eb7b06..bc0be49e 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -425,7 +425,7 @@ def Sub(input, other, alpha=1, *, out=None, signature = None): if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(CubeSub, 'sub', signature, annos, [input, other], alpha=alpha) + return IRDimops(Sub, 'sub', signature, annos, [input, other], alpha=alpha) def Mul(input, other, *, out=None, signature = None): @@ -564,6 +564,15 @@ def Long(input, memory_format=None, signature = None): return IRDimops(Long, 'long', signature, annos, [input]) +def Int(input, memory_format=None, signature = None): + """ + Tensor.int(memory_format=torch.preserve_format) → Tensor + """ + assert memory_format is None + annos = ['* -> *'] + return IRDimops(Int, 'int', signature, annos, [input]) + + def Fill(input, value, signature = None): """ torch.Tensor.fill_(value) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 9eef0678..bd3d27c9 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -80,6 +80,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ne') : function.CompareNE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, + __tttemplate('int'): function.Int, __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, @@ -146,10 +147,12 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('add') : function.Add, '_operator.add': function.Add, - # __ttemplate('sub') : function.Sub, + '_operator.iadd': function.Add, # FIXME: may waste memory + __ttemplate('sub') : function.Sub, '_operator.sub': function.Sub, - # __ttemplate('mul') : function.Mul, + __ttemplate('mul') : function.Mul, '_operator.mul': function.Mul, + '_operator.imul': function.Mul, # FIXME: may waste memory __ttemplate('div') : function.Div, __ttemplate('true_divide'): function.Div, @@ -166,9 +169,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ge'): function.CompareGE, __ttemplate('le'): function.CompareLE, # - # __ttemplate('sin'): function.Sin, + __ttemplate('sin'): function.Sin, # - # __ttemplate('cos'): function.Cos, + __ttemplate('cos'): function.Cos, # # __ttemplate('sum') : function.Sum, # __ttemplate('mean') : function.Mean, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index be193687..2b867698 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -272,6 +272,10 @@ def get_complex_data(val: Any) -> Any: input_vals = [get_complex_data(val) for val in node.args] kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} + if 'int' in fsig: + print(fsig) + exit(1) + # map to IR operator if SignFx2Op.exist(fsig): print(input_vals) diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index 85965f2a..d52fd96f 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -162,6 +162,9 @@ def check_equal(a, b): cube_graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) print("generating cube ir graph done.") +# move simple type inputs to kwargs +# for node in cube_graph.nodes + # AutoDist # # profile communication cost # import os From b342a3a00db6cd8def4677c6619b8afd4125cf4e Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Mon, 13 Mar 2023 05:38:46 -0700 Subject: [PATCH 1348/1892] update --- cube/graph/function/function.py | 9 +++++++++ cube/graph/parser/mappingfx.py | 1 + cube/graph/parser/parserfx.py | 6 +++--- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index bc0be49e..ed68761e 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -573,6 +573,15 @@ def Int(input, memory_format=None, signature = None): return IRDimops(Int, 'int', signature, annos, [input]) +def Float(input, memory_format=None, signature = None): + """ + Tensor.float(memory_format=torch.preserve_format) → Tensor + """ + assert memory_format is None + annos = ['* -> *'] + return IRDimops(Float, 'float', signature, annos, [input]) + + def Fill(input, value, signature = None): """ torch.Tensor.fill_(value) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index bd3d27c9..e6387617 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -81,6 +81,7 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, __tttemplate('int'): function.Int, + __tttemplate('float'): function.Float, __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index deaf523f..ffcdc39f 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -272,9 +272,9 @@ def get_complex_data(val: Any) -> Any: input_vals = [get_complex_data(val) for val in node.args] kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} - if 'int' in fsig: - print(fsig) - exit(1) + # if 'int' in fsig: + # print(fsig) + # exit(1) # map to IR operator if SignFx2Op.exist(fsig): From bbf153bdb17a86c76bc52dedea67f1972150eb52 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 13 Mar 2023 05:40:20 -0700 Subject: [PATCH 1349/1892] merge changes --- cube/graph/function/function.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 2f113b3b..5fbbbe78 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1662,12 +1662,11 @@ def try_select(tensor, expr): is_select, dim, val = try_select(obj, index) if is_select: return Select(obj, dim, val, 'torch.select') - assert False, f'{obj}, {index}' + # case: subtensor = tensor[1,:2] + return FullSlice(obj, b) + # assert False, f'{obj}, {index}' elif (not isinstance(obj, IRObject)) and isinstance(index, int): return obj[index] - # case: subtensor = tensor[1,:2] - if isinstance(obj, IRTensor): - return FullSlice(obj, b) return IRPyFunc(signature, [obj, index], [IRObject()]) From db29882ae6160b5909838e49bb63803370af1f2c Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Tue, 14 Mar 2023 05:34:02 +0000 Subject: [PATCH 1350/1892] Merged PR 1487: support op execution of torchscale --- cube/algorithm/ops/dimops.py | 2 + cube/codegen/frontend_mapping.py | 5 -- cube/graph/function/dimops.py | 7 ++- cube/graph/function/function.py | 91 +++++++++++++++---------------- cube/graph/parser/mappingfx.py | 12 ++-- cube/graph/parser/parserfx.py | 43 ++++++++------- cube/profiler/database.py | 70 ++++++++++++++---------- cube/runtime/function/function.py | 17 ++++-- 8 files changed, 134 insertions(+), 113 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 3a60cf25..a3b53885 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -3,6 +3,7 @@ from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule from cube.ir.tensor import IRSubTensor +from cube.ir.cten import IRTensor from cube.ir.operator import IRFwOperation from collections import deque @@ -348,6 +349,7 @@ def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: def gen_hash(node: IRFwOperation) -> str: ret = node.signature for it in node.inputs(): + if not isinstance(it, IRTensor): continue ret = ret + '-' + str(it.shape) return ret diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index 0f08dfe4..f60ff37c 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -205,17 +205,12 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: _signMap = { - 'torch.cat': _common_rule_input_as_list, - 'torch.stack': _common_rule_input_as_list, - 'torch.slice': emit_slice, 'torch.zeros': emit_zeros, 'torch.ones': emit_ones, 'torch.Tensor.to': emit_to, 'torch.rand': emit_rand, 'torch.tensor': emit_new_tensor, - 'torch.index_select': emit_index_select, - 'torch.functional.einsum': emit_einsum, 'setattr': emit_setattr, } diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index d095f679..6079f4e1 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -657,10 +657,11 @@ def infer_shape(self) -> bool: accum *= self.anno.getlen(identifier) shape.append(accum) otensor.shape = shape + # commented because fx has assigned dtype to nodes # set output shape - if isinstance(otensor, IRSubTensor): - otensor.parent.dtype = odtype - otensor.dtype = odtype + # if isinstance(otensor, IRSubTensor): + # otensor.parent.dtype = odtype + # otensor.dtype = odtype # print(f'=> sign: {self.signature} anno: {self.anno}\n' # f'=> inputs: {self.inputs()}\n' # f'=> outputs: {self.outputs()}') diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 4b1b92be..a6d744b7 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -85,6 +85,7 @@ def BMMAdd(input, batch1, batch2, *, beta=1, alpha=1, out=None, signature = None def CubeEinSum(*operands, equation=None, signature = None): assert isinstance(equation, str) + signature = 'cube.runtime.function.einsum' lhs, rhs = equation.split('->') assert ',' not in rhs lhs_dims = set(lhs.replace(',', ' ').split(' ')) @@ -409,37 +410,22 @@ def Add(input, other, alpha=1, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input + alpha * other signature = 'torch.add' + annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) - else: - annos = ['* -> *'] - if isinstance(input, IRTensor): - return IRDimops(Add, 'add', signature, annos, [input], other=other, alpha=alpha) - else: - return IRDimops(Add, 'add', signature, annos, [other], other=input, alpha=alpha) - - -def CubeSub(input, other, alpha=1, *, out=None, signature = None): - signature = 'cube.runtime.function.sub' - if isinstance(input, IRTensor) and isinstance(other, IRTensor): - lshape, rshape, oshape = _handle_broadcast(input, other) - annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(CubeSub, 'sub', signature, annos, [input, other], alpha=alpha, swap_operands=False) - else: - annos = ['* -> *'] - if isinstance(input, IRTensor): - return IRDimops(CubeSub, 'sub', signature, annos, [input], other=other, alpha=alpha, swap_operands=False) - else: - return IRDimops(CubeSub, 'sub', signature, annos, [other], other=input, alpha=alpha, swap_operands=True) + return IRDimops(Add, 'add', signature, annos, [input, other], alpha=alpha) def Sub(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input - alpha * other - return CubeSub(input, other, alpha, out=out, signature=signature) + annos = ['*, ? -> *', '?, * -> *',] + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(Sub, 'sub', signature, annos, [input, other], alpha=alpha) def Mul(input, other, *, out=None, signature = None): @@ -447,27 +433,23 @@ def Mul(input, other, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input * other signature = 'torch.mul' + annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Mul, 'mul', signature, annos, [input, other]) - else: - annos = ['* -> *'] - if isinstance(input, IRTensor): - return IRDimops(Mul, 'mul', signature, annos, [input], other=other) - else: - return IRDimops(Mul, 'mul', signature, annos, [other], other=input) + return IRDimops(Mul, 'mul', signature, annos, [input, other]) def Div(input, other, *, rounding_mode=None, out=None, signature = None): assert rounding_mode is None and out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input / other + signature = 'torch.div' annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Div, 'div', signature, annos, [input, other]) + return IRDimops(Div, 'div', signature, annos, [input, other], rounding_mode=rounding_mode) def FloorDiv(input, other, *, out=None, signature = None): @@ -582,6 +564,24 @@ def Long(input, memory_format=None, signature = None): return IRDimops(Long, 'long', signature, annos, [input]) +def Int(input, memory_format=None, signature = None): + """ + Tensor.int(memory_format=torch.preserve_format) → Tensor + """ + assert memory_format is None + annos = ['* -> *'] + return IRDimops(Int, 'int', signature, annos, [input]) + + +def Float(input, memory_format=None, signature = None): + """ + Tensor.float(memory_format=torch.preserve_format) → Tensor + """ + assert memory_format is None + annos = ['* -> *'] + return IRDimops(Float, 'float', signature, annos, [input]) + + def Fill(input, value, signature = None): """ torch.Tensor.fill_(value) @@ -1171,7 +1171,7 @@ def Conv3D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, stride=stride, padding=padding, dilation=dilation, groups=groups) -def CubeCat(*tensors, dim: int, signature = None): +def CubeCat(*tensors, dim=0, signature = None): """ torch.cat(tensors, dim=0, *, out=None) """ @@ -1180,6 +1180,7 @@ def CubeCat(*tensors, dim: int, signature = None): # with dimension. dim=None is for the support of kwarg inputs from torchfx assert all(isinstance(tensor, IRTensor) for tensor in tensors) assert isinstance(dim, int) + signature = 'cube.runtime.function.cat' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] dimlens = [t.shape[dim] for t in tensors] for ashape, dimlen in zip(iannos, dimlens): @@ -1190,27 +1191,21 @@ def CubeCat(*tensors, dim: int, signature = None): return IRDimops(CubeCat, 'cat', signature, [anno], tensors, dim=dim) -def Cat(*tensors_and_dim, dim=0, out=None, signature=None): +def Cat(tensors, dim=0, out=None, signature=None): """ torch.cat(tensors, dim=0, *, out=None) """ assert out is None - if len(tensors_and_dim) == 2: - tensors, dim = tensors_and_dim[0], tensors_and_dim[1] - else: - tensors = tensors_and_dim[0] return CubeCat(*tensors, dim=dim, signature=signature) -def CubeStack(*tensors, dim: int, signature=None): - """ - torch.stack(tensors, dim=0, *, out=None) - """ +def CubeStack(*tensors, dim=0, signature=None): # REMARK: IRFwOperation doesn't support taking a list of IRTensors. # Therefore, the argument interface is adapted to take unpacked tensors # with dimension. assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'but got {tensors}' assert isinstance(dim, int), f"but not {dim}" + signature = 'cube.runtime.function.stack' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] oannos = [copy.copy(iannos[-1])] oannos[0].insert(dim, str(len(tensors))) @@ -1218,14 +1213,16 @@ def CubeStack(*tensors, dim: int, signature=None): return IRDimops(CubeStack, 'stack', signature, [anno], tensors, dim=dim) -def Stack(*tensors_and_dim, dim=0, out=None, signature = None): +def Stack(tensors, dim=0, out=None, signature = None): """ torch.stack(tensors, dim=0, *, out=None) + It needs CubeStack and runtime.function.stack, because + (i) if the tensors are packed in a list or tuple, it is treated as a whole tensor which is not aligned + with tensor partitioning; + (ii) if the tensors are not packed in a list or tuple, torch.stack cannot receive unpacked tensors. + """ - if len(tensors_and_dim) == 2: - tensors, dim = tensors_and_dim[0], tensors_and_dim[1] - else: - tensors, dim = tensors_and_dim[0], dim + assert out is None return CubeStack(*tensors, dim=dim, signature=signature) @@ -1256,6 +1253,7 @@ def Select(input, dim, index, signature = None): def CubeIndexSelect(input: torch.Tensor, index: torch.Tensor, dim: int, signature = None): + signature = 'cube.runtime.function.index_select' edim_in = ShapeAnno.create_shape_str(input.shape) edim_in[dim] += '^' idx_anno = chr(ord(edim_in[-1]) + 1) + '^' @@ -1585,7 +1583,8 @@ def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): dtype = dtype_or_device if isinstance(dtype_or_device, torch.dtype) else eval('torch.'+dtype_or_device.value) return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) elif isinstance(dtype_or_device, IRFullTensor): - return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device.dtype) + dtype = eval('torch.'+dtype_or_device.dtype.value) + return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) else: raise RuntimeError(f'function.To with unknown arg: {dtype_or_device}') diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 9eef0678..e6387617 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -80,6 +80,8 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ne') : function.CompareNE, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, + __tttemplate('int'): function.Int, + __tttemplate('float'): function.Float, __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, @@ -146,10 +148,12 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __ttemplate('add') : function.Add, '_operator.add': function.Add, - # __ttemplate('sub') : function.Sub, + '_operator.iadd': function.Add, # FIXME: may waste memory + __ttemplate('sub') : function.Sub, '_operator.sub': function.Sub, - # __ttemplate('mul') : function.Mul, + __ttemplate('mul') : function.Mul, '_operator.mul': function.Mul, + '_operator.imul': function.Mul, # FIXME: may waste memory __ttemplate('div') : function.Div, __ttemplate('true_divide'): function.Div, @@ -166,9 +170,9 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('ge'): function.CompareGE, __ttemplate('le'): function.CompareLE, # - # __ttemplate('sin'): function.Sin, + __ttemplate('sin'): function.Sin, # - # __ttemplate('cos'): function.Cos, + __ttemplate('cos'): function.Cos, # # __ttemplate('sum') : function.Sum, # __ttemplate('mean') : function.Mean, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 4a7d5d04..794de871 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -90,7 +90,6 @@ def parse(module: torch.fx.GraphModule, print(f'WARNING input_shapes shrinked to {input_shapes})') default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype if input_shapes is not None: # shape propagation sample_inputs = dummy_inputs if dummy_inputs else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] @@ -105,30 +104,36 @@ def parse(module: torch.fx.GraphModule, for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] - dtype = kDefaultType + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) else: assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' # remove dead nodes - from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler DCEHandler(module).eliminate_dead_code() # shape propagation - from nni.common.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp - KwargsShapeProp(module).propagate(dummy_inputs) + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp + ShapeProp(module).propagate(dummy_inputs) # handle graph inputs for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) - # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, - # extend to other input types - if hasattr(dummy_inputs, input.name): - print(f'dummy_inputs has {input.name}') - shape = getattr(dummy_inputs, input.name).size() + if isinstance(dummy_inputs, dict): + if input.name in dummy_inputs: + shape = input.meta['tensor_meta'].shape + else: + shape = None else: - # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name - print(f'dummy_inputs does not have {input.name}') - shape = None - dtype = kDefaultType + # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, + # extend to other input types + if hasattr(dummy_inputs, input.name): + print(f'dummy_inputs has {input.name}') + shape = getattr(dummy_inputs, input.name).size() + else: + # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name + print(f'dummy_inputs does not have {input.name}') + shape = None + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) input_val = [frame.get_var(input.name) for input in inputs] @@ -157,7 +162,8 @@ def parse(module: torch.fx.GraphModule, shape = node.meta['tensor_meta'].shape shape = FxModuleParser.shape_refine(shape) dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=True, dtype=dtype, name=node.name) + requires_grad = node.meta['tensor_meta'].requires_grad + val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=node.name) frame.add_var(node.name, val) else: frame.add_var(node.name, IRObject()) @@ -265,7 +271,6 @@ def get_complex_data(val: Any) -> Any: # map to IR operator if SignFx2Op.exist(fsig): - print(input_vals) ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) else: # FIXME: handle cases for IRObject in kwargs @@ -311,11 +316,9 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram tensor_name = node.name if 'tensor_meta' in node.meta: tensor_shape = node.meta['tensor_meta'].shape - # tensor_dtype = node.meta['tensor_meta'].dtype #TODO assume it is weight - default_dtype = torch.get_default_dtype() - kDefaultType = DType2IRDType.map(default_dtype) # TODO specify dtype - ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=kDefaultType) + dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) + ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=dtype) ir_tensor.as_param() frame.add_var(tensor_name, ir_tensor) else: diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 89a332e3..81ff7008 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -7,9 +7,10 @@ import time import os import json +import _operator import cube -from cube.ir.cten import IRTensor +from cube.ir.cten import IRTensor, IRObject from cube.ir.operator import IRFwOperation from cube.graph.parser.mapping import IRDType2TorchDType from cube.graph.parser.mappingfx import SignFx2Op as Sign2Op @@ -29,6 +30,7 @@ class CompProfiler: @staticmethod def profile(func: Callable, shapes: Shapes, dtypes: DTypes, + requires_grads: Tuple[bool], values: Tuple[Any], warmup_sec: float = 2, prof_times: int = 50, **kwargs) -> Tuple[float, float, int, Tuple[int]]: """ @@ -49,15 +51,18 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, assert len(shapes) == len(dtypes), \ f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" # create data - dtypes = [torch.float32] * len(shapes) if dtypes is None else dtypes - def gen_torch_tensors(shape, dtype): + assert dtypes is not None + def gen_torch_tensors(shape, dtype, requires_grad): constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand - requires_grad = False if dtype in (torch.int64, torch.int32, torch.bool) else True + # requires_grad = False if dtype in (torch.int64, torch.int32, torch.bool) else True return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) tensors = tuple( - gen_torch_tensors(shape, dtype) for shape, dtype in zip(shapes, dtypes) + gen_torch_tensors(shape, dtype, requires_grad) if value is None else value for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) ) - require_backward = any([t.requires_grad for t in tensors]) + require_backward = any([t.requires_grad for t in tensors if hasattr(t, 'requires_grad')]) + # FIXME: reconsidering requires_grad + if func.__name__ in ('type_as'): + require_backward = False # repalce kwargs starting with 'self.xxx' train_kwargs, eval_kwargs = {}, {} for name, value in kwargs.items(): @@ -186,21 +191,22 @@ def get_dep_names(sign: str): exec(code_impl, globals(), local) fn = list(local.values())[0] else: - if '_operator.' in node.signature: - if '_operator.or_' == node.signature: - fn = torch.bitwise_or - elif '_operator.invert' == node.signature: - fn = torch.bitwise_not - else: - fn = eval(node.signature.replace('_operator.', 'torch.')) - else: - fn = eval(node.signature) - shapes, dtypes = [], [] + fn = eval(node.signature) + shapes, dtypes, requires_grads, values = [], [], [], [] for t in node.inputs(): - assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" - shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) - return fn, shapes, dtypes, node.kwargs + if isinstance(t, IRTensor): + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) + requires_grads.append(t.requires_grad) + values.append(None) + elif isinstance(t, IRObject): + raise RuntimeError('IRObject has not been supported in profiling.') + else: + shapes.append(None) + dtypes.append(type(t).__name__) + requires_grads.append(None) + values.append(t) + return fn, shapes, dtypes, requires_grads, values, node.kwargs def profile(self, node: IRFwOperation, device: Optional[int] = None): """ @@ -216,7 +222,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @return infer_memory int: the peak memory in bytes after inference of the function @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward """ - fn, shapes, dtypes, kwargs = ProfileDataBase.get_func(node) + fn, shapes, dtypes, requires_grads, values, kwargs = ProfileDataBase.get_func(node) if self.exist(node): return self.query(node) @@ -228,18 +234,19 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): in_mem_info, param_mem_info = [], [] residual_mem, input_count = 0, 0 for t in node.inputs(): - if t.is_param(): + if hasattr(t, 'is_param') and t.is_param(): param_mem_info.append(t.byte_size()) - else: + elif hasattr(t, 'byte_size'): input_count += 1 if input_count == 1: residual_mem += t.byte_size() in_mem_info.append(t.byte_size()) + else: + print(f'WARNING: input {t} is skipped.') - # run profiling fw_span, bw_span, infer_memory, train_mem_info = \ - CompProfiler.profile(fn, shapes, dtypes, **kwargs) + CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) # log to database key = self._serialize(node) self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) @@ -358,10 +365,15 @@ def _serialize(self, node: IRFwOperation) -> str: """ shapes, dtypes = [], [] for t in node.inputs(): - assert isinstance(t, IRTensor), f"Only support node inputs with tensor shape" - shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) - shapes = '-'.join(str(tuple(shape)) for shape in shapes) + if isinstance(t, IRTensor): + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) + elif isinstance(t, IRObject): + raise RuntimeError('IRObject has not been supported in _serialize') + else: + shapes.append(None) + dtypes.append(type(t)) + shapes = '-'.join(str(tuple(shape)) if shape is not None else str(None) for shape in shapes) dtypes = '-'.join(str(dtype) for dtype in dtypes) return shapes + ' : ' + dtypes diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index d16155ae..8378519a 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -34,12 +34,6 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) -def sub(input: torch.Tensor, other: Union[int, float, torch.Tensor], alpha: Union[int, float], swap_operands: bool) -> torch.Tensor: - if swap_operands: - return torch.sub(other, input, alpha=alpha) - else: - return torch.sub(input, other, alpha=alpha) - def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: return input.expand(*sizes) @@ -124,3 +118,14 @@ def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): return torch.masked_scatter(input, mask, src) +def index_select(input: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor: + return torch.index_select(input, dim, index) + +def einsum(*operands, equation=None) -> torch.Tensor: + return torch.einsum(equation, *operands) + +def stack(*tensors, dim=0) -> torch.Tensor: + return torch.stack(tensors, dim) + +def cat(*tensors, dim=0) -> torch.Tensor: + return torch.cat(tensors, dim) From ab73d41eaae3617c887c0708d37bb1a585df264d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 14 Mar 2023 07:26:52 +0000 Subject: [PATCH 1351/1892] Merged PR 1490: Creators to Dimops and On-demand Dtype Inference 1) creators use dimops for implementation 2) dimops allow to partition with configuration on output 3) on-demand dtype inference --- cube/algorithm/factory.py | 3 - cube/algorithm/ops/dimops.py | 34 ++-- cube/codegen/frontend_mapping.py | 106 ----------- cube/graph/function/dimops.py | 13 +- cube/graph/function/function.py | 293 +++++++++++++++--------------- cube/graph/graph.py | 20 +- cube/graph/parser/mappingfx.py | 8 +- cube/ir/cten.py | 2 +- cube/ir/dtype.py | 5 - cube/ir/tensor.py | 14 +- cube/runtime/function/function.py | 49 +++++ 11 files changed, 234 insertions(+), 313 deletions(-) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 02b14e50..bdcb5d02 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -74,6 +74,3 @@ def _load_predefined_algos(self): import cube.algorithm.ops.creators as creators self.register(creators.IRToTensor, creators.DimSplitTo, tag='dim') - self.register(creators.IRZeros, creators.DimSplitZeros, tag='dim') - self.register(creators.IROnes, creators.DimSplitOnes, tag='dim') - self.register(creators.IRRand, creators.DimSplitRand, tag='dim') diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index a3b53885..9e06660a 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -52,29 +52,30 @@ def get_identifier_reduce(self, idx: int, dim: int, num: int) -> Tuple[str, DimA If the partitioned number is 1, return the first hidden identitifer Otherwise, return the first hidden identifier whose length > 1 - @param idx int: input index + @param idx int: input/output index. Take the idx-th input tensor or (idx-ninputs)-th output @param dim int: input dimension @return identifier Optional[str]: annotated dimension identifier @return reduction Optional[DimAnno.ReduceType] """ node: IRDimops = self.node + eshapes = node.anno.inputs() + node.anno.outputs() hidx = None - for hidx, adim in enumerate(node.anno.input(idx).dims[dim].identifiers): + for hidx, adim in enumerate(eshapes[idx].dims[dim].identifiers): if num == 1: break dimlen = node.anno.getlen(adim) if adim == '1^' or dimlen == 1: continue break if hidx is None: return (None, None) - reduce = node.anno.input(idx).dims[dim].reduces[hidx] + reduce = eshapes[idx].dims[dim].reduces[hidx] return adim, reduce def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: """ Check whether the condition satisfies. - @param idx int: input index - @param dim Union[int, str]: input dimension or 'v', ie., partition at value dimension + @param idx int: input/output index. Take the idx-th input tensor or (idx-ninputs)-th output tensor + @param dim Union[int, str]: tensor dimension or 'v', i.e., partition at value dimension. @param num int: chunks to partition the dimension @return satisfy bool: true if can be partitioned, elsewise false. @@ -83,14 +84,14 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: assert isinstance(dim, int) or dim == 'v', f"expect dim to be int or 'v'" node: IRDimops = self.node - assert isinstance(node.input(idx), IRSubTensor), f"partitioning on a non-tensor input" - ninputs = len(node.inputs()) - idx = idx if idx >= 0 else idx + ninputs - assert idx < ninputs, f"index out of boundary: {idx} >= {ninputs}" + tensors = node.inputs() + node.outputs() + assert isinstance(tensors[idx], IRSubTensor), f"partition on a non-tensor input/output" + assert 0 <= idx and idx < len(tensors), f"index out of boundary: {idx} >= {len(tensors)}" + tensors = node.inputs() + node.outputs() if isinstance(dim, int): - dim = dim if dim >= 0 else dim + node.input(idx).ndims - assert dim < node.input(idx).ndims, f"dimension output of boundary: {dim} >= {node.input(idx).ndims}" + dim = dim if dim >= 0 else dim + tensors[idx].ndims + assert dim < tensors[idx].ndims, f"dimension output of boundary: {dim} >= {node.input(idx).ndims}" # try split at tensor spatial dimension if isinstance(dim, int): @@ -99,7 +100,8 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: dimlen = node.anno.getlen(adim) # first check node special rules first for rule in node.transform_rules: - if rule.input(idx) == DimopSplit.D(dim): + splits = rule.inputs() + rule.outputs() + if splits[idx] == DimopSplit.D(dim): return dimlen >= num # then check default rules if reduce == DimAnno.ReduceType.Freeze: @@ -107,7 +109,8 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: return dimlen >= num else: for rule in node.transform_rules: - if rule.input(idx).isV(): + splits = rule.inputs() + rule.outputs() + if splits[idx].isV(): return True return False @@ -173,11 +176,12 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR assert isinstance(dim, int) or dim == 'v', f"expect dim to be int or 'v'" # check node special rules first for r in node.transform_rules: + splits = r.inputs() + r.outputs() if isinstance(dim, int): - if r.input(idx) == DimopSplit.D(dim): + if splits[idx] == DimopSplit.D(dim): return r else: - if r.input(idx).isV(): + if splits[idx].isV(): return r # otherwise use default rule assert isinstance(dim, int), f"Error: expect dim to be int for default rules" diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index f60ff37c..ccc4008b 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -65,96 +65,6 @@ def emit_slice(node, arg_vars:list, kw_pairs:dict) -> str: return f"{in_tensor_var}[{', '.join(subscript_components)}]" -# TODO consider making the IR-Torch conversion like IRDType2TorchDType intrinsic to codegen, -# so that we don't need to ad hoc do the conversion as in these emission functions. -# Also, we'd better limit the complexity of the values in 'kw_pairs' so we know for sure we have -# done all necessary conversion. -# -# Basically to convert internal 'IRDType' to frontend 'torch.dtype' -def emit_zeros(node, arg_vars:list, kw_pairs:dict) -> str: - """ - zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - """ - kw_pairs = kw_pairs.copy() - if 'dtype' in kw_pairs: - ir_dtype : IRDType = kw_pairs['dtype'] - if ir_dtype is not None: - kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) - - # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. - if 'device' in kw_pairs: - print(f'WARNING: overload device info. of {node}') - kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. - - if len(arg_vars) != 0: - print(f'WARNING: emit_zero with len(arg_vars) {len(arg_vars)} != 0') - - return _common_rule_join_all(node, arg_vars, kw_pairs) - -def emit_ones(node, arg_vars:list, kw_pairs:dict) -> str: - """ - ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - """ - kw_pairs = kw_pairs.copy() - if 'dtype' in kw_pairs: - ir_dtype : IRDType = kw_pairs['dtype'] - if ir_dtype is not None: - kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) - - # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. - assert 'device' not in kw_pairs - kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. - - assert len(arg_vars) == 0 - return _common_rule_join_all(node, arg_vars, kw_pairs) - -def emit_rand(node, arg_vars:list, kw_pairs:dict) -> str: - """ - rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - """ - kw_pairs = kw_pairs.copy() - if 'dtype' in kw_pairs: - ir_dtype : IRDType = kw_pairs['dtype'] - if ir_dtype is not None: - kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) - - # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. - assert 'device' not in kw_pairs - kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. - - assert len(arg_vars) == 0 - return _common_rule_join_all(node, arg_vars, kw_pairs) - - -def emit_new_tensor(node, arg_vars:list, kw_pairs:dict) -> str: - """ - rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - """ - kw_pairs = kw_pairs.copy() - if 'dtype' in kw_pairs: - ir_dtype : IRDType = kw_pairs['dtype'] - if ir_dtype is not None: - kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) - - # TODO make all intermediately created tensors CUDA, to fit with other parts of the system, like SynDataLoader. - assert 'device' not in kw_pairs - kw_pairs['device'] = 'torch.cuda.current_device()' # str will get directly dumped as it's. - - assert len(arg_vars) == 0 - assert 'data' in kw_pairs - assert 'shape' in kw_pairs - data_str = str(kw_pairs['data']) - _ = kw_pairs.pop('data') - _ = kw_pairs.pop('shape') - - kw_assigns = list() - for key, val in kw_pairs.items(): - assert key != 'data' - code = f'{key}={val}' - kw_assigns.append(code) - args = data_str + ', ' + ', '.join(kw_assigns) - return f'{node.signature}({args})' - # Basically to convert internal 'IRDType' to frontend 'torch.dtype' def emit_to(node, arg_vars:list, kw_pairs:dict) -> str: kw_pairs = kw_pairs.copy() @@ -172,17 +82,6 @@ def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: member = f'"{arg_vars[1][5:]}"' return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" -def emit_index_select(node, arg_vars:list, kw_pairs:dict) -> str: - assert 'dim' in kw_pairs - dim = kw_pairs['dim'] - return f'{node.signature}({arg_vars[0]}, {dim}, {arg_vars[1]})' - -def emit_einsum(node, arg_vars:list, kw_pairs:dict) -> str: - assert 'equation' in kw_pairs - equation = kw_pairs['equation'] - args_str = ', '.join(arg_vars) - return f'{node.signature}({equation}, {args_str})' - class Sign2EmitRule: @@ -206,12 +105,7 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: _signMap = { 'torch.slice': emit_slice, - 'torch.zeros': emit_zeros, - 'torch.ones': emit_ones, 'torch.Tensor.to': emit_to, - 'torch.rand': emit_rand, - 'torch.tensor': emit_new_tensor, - 'setattr': emit_setattr, } diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 6079f4e1..d991b4d7 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -411,8 +411,11 @@ def parse(anno: str) -> Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]]: if '->' not in anno: raise ValueError(f"Syntax Error: Expected -> in operator anno: {anno}") inputs, outputs = anno.split('->') - inputs = inputs.split(',') - outputs = outputs.split(',') + + inputs = inputs.strip() + inputs = [] if len(inputs) == 0 else inputs.split(',') + outputs = outputs.strip() + outputs = [] if len(outputs) == 0 else outputs.split(',') # to ShapeAnnos inputs: Tuple[ShapeAnno] = tuple(ShapeAnno(shape) for shape in inputs) outputs: Tuple[ShapeAnno] = tuple(ShapeAnno(shape) for shape in outputs) @@ -647,7 +650,6 @@ def infer_shape(self) -> bool: @return sucess: True if successfully inferred shape """ idtypes = [t.dtype for t in self._inputs if isinstance(t, IRTensor)] - odtype = DTypeInferRule.infer(self, idtypes) for oidx, otensor in enumerate(self.outputs()): shape_anno = self.oanno(oidx) shape = [] @@ -657,11 +659,6 @@ def infer_shape(self) -> bool: accum *= self.anno.getlen(identifier) shape.append(accum) otensor.shape = shape - # commented because fx has assigned dtype to nodes - # set output shape - # if isinstance(otensor, IRSubTensor): - # otensor.parent.dtype = odtype - # otensor.dtype = odtype # print(f'=> sign: {self.signature} anno: {self.anno}\n' # f'=> inputs: {self.inputs()}\n' # f'=> outputs: {self.outputs()}') diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index a6d744b7..9a953409 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -2,8 +2,9 @@ import string import copy import torch -import warnings import operator +import numpy as np +import math from cube.ir.cten import IRTensor, IRObject from cube.ir.tensor import IRSubTensor, IRFullTensor @@ -11,7 +12,7 @@ from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D -from cube.graph.function.creators import IRArange, IREmpty, IROnes, IRToTensor, IRZeros, IRRand, IRNewTensor +from cube.graph.function.creators import IRToTensor from cube.graph.function.anchor import IRGraphAnchor @@ -113,10 +114,52 @@ def Matmul(input, other, *, out=None, signature=None): return IRDimops(Matmul, 'matmul', signature, annos, [input, other]) -def Arange(*args, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, signature=None): +# =============================================== creators ========================================== + +def _get_creator_anno_rules(size: Tuple[int], partitionable: bool) -> str: + """ + Create annotation and transformation rules for creator + """ + eshape = [str(dimlen) + ('' if partitionable else '^') for dimlen in size] + anno = OpAnno.create_op_str([], [eshape]) + rules = [] + if partitionable: + for dim in range(len(size)): + def creator_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: + kwargs = dict(**kwargs) + size = list(kwargs['size']) + size[dim] = size[dim] // num + kwargs['size'] = tuple(size) + return kwargs + + rules.append(TransformRule([], [DimopSplit.D(dim)], creator_modifier)) + + return anno, rules + + +def CubeArange(start: int, end: int, step: int, dtype=None, + requires_grad=False, signature=None): + signature = 'cube.runtime.function.arange' + size = (math.ceil((end-start)/step),) + # FIXME: torch.jit.script has dtype with int + # from cube.graph.parser.mapping import TorchScalarTypeEnumMap + # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) + dtype = dtype if dtype is not None else torch.get_default_dtype() + kwargs = {'start': start, 'end': end, 'step': step, + 'dtype': dtype, 'requires_grad': requires_grad} + anno, rules = _get_creator_anno_rules(size, False) + dimop = IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) + from cube.graph.parser.mapping import DType2IRDType + dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + return dimop + + +def Arange(*args, out=None, dtype=None, layout=None, + device=None, requires_grad=False, signature=None): """ torch.arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor """ + assert layout is None if len(args) == 1: start, end, step = 0, args[0], 1 elif len(args) == 2: @@ -125,161 +168,109 @@ def Arange(*args, out=None, dtype=None, layout=torch.strided, device=None, requi start, end, step = args else: raise RuntimeError(f'Invalid number {len(args)} of args in Arange.') - assert isinstance(start, int) and isinstance(end, int) and isinstance(step, int) - from cube.graph.parser.mapping import DType2IRDType - if dtype is None: - dtype = torch.get_default_dtype() - - import math - size = (math.ceil((end-start)/step),) - kwargs = {'start': start, 'end': end, 'step': step, 'out': out, 'dtype': dtype, - 'layout': layout, 'requires_grad': requires_grad} - return IRArange(signature, size, 'arange', **kwargs) + return CubeArange(start, end, step, dtype, requires_grad=requires_grad) -def Empty(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, - pin_memory=False, memory_format=torch.contiguous_format, signature=None): - """ - torch.empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, - requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) → Tensor - """ +def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, + pin_memory=False, memory_format=None, signature=None): + # note: device is ignored + signature = 'cube.runtime.function.empty' + size = (size,) if isinstance(size, int) else tuple(size) + size: Tuple[int] = size + arg_size + assert all(isinstance(dimlen, int) for dimlen in size), f"Empty only supports static size but got {size}" + assert layout is None and memory_format is None, f"Not support for non-default memory_format and layout" + # FIXME: torch.jit.script has dtype with int + # from cube.graph.parser.mapping import TorchScalarTypeEnumMap + # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) + dtype = dtype if dtype is not None else torch.get_default_dtype() + kwargs = {'size': size, 'requires_grad': requires_grad, + 'dtype': dtype, 'pin_memory': pin_memory} + anno, rules = _get_creator_anno_rules(size, True) + dimop = IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) from cube.graph.parser.mapping import DType2IRDType - if dtype is None: - dtype = torch.get_default_dtype() - ir_dtype : IRDType = DType2IRDType.map(dtype) - # example size: ((17, 17),) - assert isinstance(size, tuple) and isinstance(size[0], tuple) - for dim, i in enumerate(size[0]): - if not isinstance(dim, int) and not dim >= 0: - raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") - kwargs = {'dtype': ir_dtype, 'layout': layout, 'device': device, 'requires_grad': requires_grad, - 'pin_memory': pin_memory, 'memory_format': memory_format} - return IREmpty(signature, size[0], 'empty', **kwargs) - - -def Zeros(signature, - inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): - # zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - # - # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of - # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. - - from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap - - size, dtype_underlying, layout, _erased_device, pin_memory = inputs - - # TODO parameters to support, currently they are all None - assert layout is None - assert pin_memory is None - - if dtype_underlying is not None: - # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, - # which is the underlying type of PyTorch C++ enum 'ScalarType'. - dtype = TorchScalarTypeEnumMap.map(dtype_underlying) - else: - dtype = torch.get_default_dtype() - - ir_dtype : IRDType = DType2IRDType.map(dtype) - - for dim, i in enumerate(size): - if not isinstance(dim, int) and not dim >= 0: - raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") - return IRZeros(signature, size, 'zeros', ir_dtype) + dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + return dimop -def Ones(signature, - inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): - # ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - size, dtype_underlying, layout, _erased_device, pin_memory = inputs - - # TODO parameters to support, currently they are all None - assert layout is None - assert pin_memory is None - from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap - - if dtype_underlying is not None: - # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, - # which is the underlying type of PyTorch C++ enum 'ScalarType'. - dtype = TorchScalarTypeEnumMap.map(dtype_underlying) - else: - dtype = torch.get_default_dtype() - - ir_dtype : IRDType = DType2IRDType.map(dtype) - - for dim, i in enumerate(size): - if not isinstance(dim, int) and not dim >= 0: - raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") - return IROnes(signature, size, 'ones', ir_dtype) - -def Rand(signature, - inputs: Tuple[ List[int], Optional[int], Optional[Any], ErasedDevice, Optional[bool] ]): - # ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - - size, dtype_underlying, layout, _erased_device, pin_memory = inputs - - # TODO parameters to support, currently they are all None - assert layout is None - assert pin_memory is None - from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap - - if dtype_underlying is not None: - # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, - # which is the underlying type of PyTorch C++ enum 'ScalarType'. - dtype = TorchScalarTypeEnumMap.map(dtype_underlying) - else: - dtype = torch.get_default_dtype() - - ir_dtype : IRDType = DType2IRDType.map(dtype) - - for dim, i in enumerate(size): - if not isinstance(dim, int) and not dim >= 0: - raise RuntimeWarning(f"The {i}-th component of the size must be non-negative integer") - return IRRand(signature, size, 'rand', ir_dtype) - -def NewTensor(data: Union[int, float, list], dtype=None, device=None, requires_grad=False, pin_memory=False, signature=None): - # NOTE: not sure all the keys of torch.tensor - assert requires_grad == False +def Zeros(size, *arg_size, out=None, dtype=None, layout=None, + device=None, requires_grad=False, signature=None): + # note: device is ignored + signature = 'cube.runtime.function.zeros' + size = (size,) if isinstance(size, int) else tuple(size) + size: Tuple[int] = size + arg_size + assert all(isinstance(dimlen, int) for dimlen in size), f"Zeros only supports static size but got {size}" + assert layout is None, f"Not support for non-default layout" + # FIXME: torch.jit.script has dtype with int + # from cube.graph.parser.mapping import TorchScalarTypeEnumMap + # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) + dtype = dtype if dtype is not None else torch.get_default_dtype() + kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} + anno, rules = _get_creator_anno_rules(size, True) + dimop = IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) from cube.graph.parser.mapping import DType2IRDType - if dtype is None: - dtype = torch.get_default_dtype() - ir_dtype : IRDType = DType2IRDType.map(dtype) - kwargs = {'dtype': ir_dtype, 'device': device, 'requires_grad': requires_grad, 'pin_memory': pin_memory} - return IRNewTensor(signature, data, 'tensor', **kwargs) - + dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + return dimop -# def NewTensor(signature, -# inputs: Tuple[ list, Optional[int], ErasedDevice, bool ]): -# # aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor -# # -# # REMARK: in the PyTorch-internal operator definition expression, an asterisk ("*") is merely a marker of -# # the beginning of the sublist of _keyword arguments_, and does not result in an actual argument. -# data, dtype_underlying, _erased_device, requires_grad = inputs - -# # TODO parameters to support, currently they are all None -# assert requires_grad == False -# from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap - -# if dtype_underlying is not None: -# # If some torch.dtype is specified at the frontend, in TorchScript it becomes an int, -# # which is the underlying type of PyTorch C++ enum 'ScalarType'. -# dtype = TorchScalarTypeEnumMap.map(dtype_underlying) -# else: -# dtype = torch.get_default_dtype() +def Ones(size, *arg_size, out=None, dtype=None, layout=None, + device=None, requires_grad=False, signature=None): + # note: device is ignored + signature = 'cube.runtime.function.ones' + size = (size,) if isinstance(size, int) else tuple(size) + size: Tuple[int] = size + arg_size + assert all(isinstance(dimlen, int) for dimlen in size), f"Ones only supports static size but got {size}" + assert layout is None, f"Not support for non-default layout" + # FIXME: torch.jit.script has dtype with int + # from cube.graph.parser.mapping import TorchScalarTypeEnumMap + # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) + dtype = dtype if dtype is not None else torch.get_default_dtype() + kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} + anno, rules = _get_creator_anno_rules(size, True) + dimop = IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) + from cube.graph.parser.mapping import DType2IRDType + dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + return dimop -# ir_dtype : IRDType = DType2IRDType.map(dtype) -# # if 'data' is not: -# # 1) ints or floats of any precision, e.g. i8, i64, f16, f32 -# # 2) non-ragged -# # ... then this call will throw. -# arr = torch.tensor(data, dtype=dtype) +def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, + pin_memory=False, memory_format=None, signature=None): + # note: device is ignored + signature = 'cube.runtime.function.rand' + size = (size,) if isinstance(size, int) else tuple(size) + size: Tuple[int] = size + arg_size + assert all(isinstance(dimlen, int) for dimlen in size), f"Rand only supports static size but got {size}" + assert layout is None and memory_format is None, f"Not support for non-default memory_format and layout" + # FIXME: torch.jit.script has dtype with int + # from cube.graph.parser.mapping import TorchScalarTypeEnumMap + # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) + dtype = dtype if dtype is not None else torch.get_default_dtype() + kwargs = {'size': size, 'requires_grad': requires_grad, + 'dtype': dtype, 'pin_memory': pin_memory} + anno, rules = _get_creator_anno_rules(size, True) + dimop = IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) + from cube.graph.parser.mapping import DType2IRDType + dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + return dimop + + +def NewTensor(data, *, dtype=None, device=None, + requires_grad=False, pin_memory=False, signature=None): + # note: device is ignored + signature = 'cube.runtime.function.tensor' + size = tuple(np.array(data).shape) + assert all(isinstance(dimlen, int) for dimlen in size), f"Ones only supports static size but got {size}" + # FIXME: torch.jit.script has dtype with int + # from cube.graph.parser.mapping import TorchScalarTypeEnumMap + # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) + dtype = dtype if dtype is not None else torch.get_default_dtype() + kwargs = {'size': size, 'requires_grad': requires_grad, + 'dtype': dtype, 'pin_memory': pin_memory} + anno, rules = _get_creator_anno_rules(size, True) + dimop = IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) + from cube.graph.parser.mapping import DType2IRDType + dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + return dimop -# # TODO temporarily fake creation with Zeros -# # and remark that originally aten::tensor should be able to infer the dtype from the specified 'data', -# # but since we have omitted the 'data', we must do type inferrence ourselves, -# # only in this way we get correct dtype e.g. ints or bools. -# return IRNewTensor(signature, data, 'tensor', ir_dtype=ir_dtype) def ToTensor(signature, inputs: Tuple[ IRTensor, ... ]): @@ -1489,11 +1480,15 @@ def _comparison(creator: Callable, f: Callable, name: str, signature: str, if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(creator, name, signature, annos, [input, other]) + dimop = IRDimops(creator, name, signature, annos, [input, other]) + dimop.output(0).parent.dtype = IRDType.boolean + return dimop # case2: torch.equal(tensor1, obj2) / torch.equal(obj1, tensor2) if isinstance(input, IRTensor) or isinstance(other, IRTensor): annos = ['*, ? -> *', '?, * -> *',] - return IRDimops(creator, name, signature, annos, [input, other]) + dimop = IRDimops(creator, name, signature, annos, [input, other]) + dimop.output(0).parent.dtype = IRDType.boolean + return dimop # case3: torch.equal(obj1, obj2) else: return IRPyFunc(signature, [input, other], [IRObject()]) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 4bac021b..3b18c68b 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -102,20 +102,18 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: # dtype inference for node in self._nodes: - itensors = [t for t in node.inputs() if isinstance(t, IRSubTensor)] - # setup gradient + # reset input + itensors: List[IRTensor] = [t for t in node.inputs() if isinstance(t, IRSubTensor)] for itensor in itensors: - if itensor.parent.grad is not None: - itensor.parent.dtype = itensor.dtype + itensor.parent.dtype = itensor.dtype + # infer output dtype with default dtype promotion rules if len(itensors) == 0: continue - odtype = DTypeInferRule.infer(node, [t.dtype for t in itensors]) - assert odtype != IRDType.unknown, f"{node} : {[t.dtype for t in itensors]}" + default_dtype = DTypeInferRule.infer(node, [t.dtype for t in itensors]) + # set output tensors if it has unkown tensor dtype otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] - for tensor in otensors: - tensor.dtype = odtype - # setup graidient - if tensor.parent.grad is not None: - tensor.parent.grad.dtype = odtype + for otensor in otensors: + if otensor.dtype == IRDType.unknown: + otensor.parent.dtype = default_dtype from cube.program import Program Program().add_nodes(self.nodes()) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index e6387617..32e28d8a 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -135,15 +135,13 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # # __ftemplate('cross_entropy'): function.CrossEntropy, # - # # torch aten - # # # creators __ttemplate('empty'): function.Empty, - # __ttemplate('zeros'): function.Zeros, - # __ttemplate('ones'): function.Ones, + __ttemplate('zeros'): function.Zeros, + __ttemplate('ones'): function.Ones, __ttemplate('tensor'): function.NewTensor, # __ttemplate('to'): function.ToTensor, - # __ttemplate('rand'): function.Rand, + __ttemplate('rand'): function.Rand, # __ttemplate('clone'): function.Clone, # __ttemplate('add') : function.Add, diff --git a/cube/ir/cten.py b/cube/ir/cten.py index c636be1e..fec5509a 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -557,7 +557,7 @@ def dtype(self, val: IRDType): raise TypeError(f"Expected IRDType but got {val}") self._dtype = val if isinstance(self._grad, IRTensor): - self._dtype = val + self._grad._dtype = val def is_param(self) -> bool: """! diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index f654cd62..89b8f6dd 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -46,11 +46,6 @@ def infer(node, dtypes: List[IRDType]) -> IRDType: raise RuntimeError(f"Find an unkown dtype") if IRDType.float32 in dtypes and IRDType.float16 in dtypes: raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") - # TODO(yizhu1): hack - if node.signature in ('torch.ne', 'torch.eq', 'torch.gt'): - return IRDType.boolean - elif node.signature in ('torch.Tensor.long', 'torch._shape_as_tensor'): - return IRDType.int64 # in priority: fp32 > fp16 > bool > int64 > int16 > priority = [ IRDType.float64, IRDType.float32, IRDType.float16, diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 0740560e..5f3937a0 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -500,16 +500,10 @@ def dtype(self) -> IRDType: @dtype.setter def dtype(self, val: IRDType): - if self.parent.dtype == IRDType.unknown: - self.parent.dtype = val - else: - if self.parent.dtype != val: - print(f'ERROR (skipped) reset IRSubTensor({self.name}) dtype {self.parent.dtype}->{val}') - self.parent.dtype = val - - # TODO recover me - # assert self.parent.dtype == val, \ - # f"dtype mis-matched with previous setting: {val} != {self.parent.dtype}" + raise RuntimeError( + f"IRSubTensor dtype must follow IRFullTensor dtype. " + f"Please set it by subtensor.parent.dtype = {val}" + ) def splitdims(self) -> Tuple[int]: """! diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 8378519a..813985a7 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -9,6 +9,7 @@ def identity(tensor: torch.Tensor) -> torch.Tensor: """ return tensor + def anchor(name: str): """ anchor operation for graph navigation @@ -22,9 +23,11 @@ def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: """ return tensor if times == 1 else tuple([tensor] * times) + def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) -> torch.Tensor: return tensor.to(dtype_or_device) + def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: """ accumulate tensors in to one tensor @@ -34,12 +37,15 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) + def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: return input.expand(*sizes) + def fullslice(input: torch.Tensor, slicers: Tuple[Union[None, slice]]): return input[slicers] + def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], stride: int, padding: List[int], dilation, groups: int = 1): """ @@ -54,6 +60,7 @@ def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso input = TorchF.pad(input, padding, 'constant', 0) return TorchF.conv2d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) + def conv3d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], stride: int, padding: List[int], dilation, groups: int = 1): """ @@ -70,6 +77,7 @@ def conv3d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso input = TorchF.pad(input, pad_padding, 'constant', 0) return TorchF.conv3d(input, weight, bias, stride=stride, dilation=dilation, groups=groups) + def embedding(input: torch.Tensor, weight: torch.Tensor, padding_idx: Optional[int], start: int, stop: int): """ Embedding @@ -118,14 +126,55 @@ def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): return torch.masked_scatter(input, mask, src) + +def empty(size: Tuple[int], dtype=None, requires_grad=False, pin_memory=False): + return torch.empty( + size, dtype=torch.get_default_dtype() if dtype is None else dtype, + device=torch.cuda.current_device(), + requires_grad=requires_grad, pin_memory=pin_memory + ) + + +def zeros(size: Tuple[int], dtype=None, requires_grad=False): + return torch.zeros( + size, dtype=torch.get_default_dtype() if dtype is None else dtype, + device=torch.cuda.current_device(), + requires_grad=requires_grad + ) + + +def ones(size: Tuple[int], dtype=None, requires_grad=False): + return torch.ones( + size, dtype=torch.get_default_dtype() if dtype is None else dtype, + device=torch.cuda.current_device(), + requires_grad=requires_grad + ) + + +def rand(size: Tuple[int], dtype=None, requires_grad=False): + return torch.rand( + size, dtype=torch.get_default_dtype() if dtype is None else dtype, + device=torch.cuda.current_device(), + requires_grad=requires_grad + ) + +def arange(start: int, end: int, step: int, dtype: torch.dtype, requires_grad=False): + return torch.arange(start=start, end=end, step=step, + dtype=dtype, requires_grad=requires_grad, + device=torch.cuda.current_device()) + + def index_select(input: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor: return torch.index_select(input, dim, index) + def einsum(*operands, equation=None) -> torch.Tensor: return torch.einsum(equation, *operands) + def stack(*tensors, dim=0) -> torch.Tensor: return torch.stack(tensors, dim) + def cat(*tensors, dim=0) -> torch.Tensor: return torch.cat(tensors, dim) From e4cb02b43b15fca143564d5451679ae5b97198f3 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 14 Mar 2023 02:05:52 -0700 Subject: [PATCH 1352/1892] save work --- cube/profiler/database.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index cb421350..44a290eb 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -378,9 +378,9 @@ def _serialize(self, node: IRFwOperation) -> str: dtypes.append(IRDType2TorchDType.map(t.dtype)) elif isinstance(t, IRObject): raise RuntimeError('IRObject has not been supported in _serialize') - else: - shapes.append(None) - dtypes.append(type(t)) + # else: + # shapes.append(None) + # dtypes.append(type(t)) shapes = '-'.join(str(tuple(shape)) if shape is not None else str(None) for shape in shapes) dtypes = '-'.join(str(dtype) for dtype in dtypes) return shapes + ' : ' + dtypes From 7aed597598bd6be3e5b9da2cd73befdf827423bc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 15 Mar 2023 08:25:30 +0000 Subject: [PATCH 1353/1892] Merged PR 1493: fix bug for script parser on jit.ignore fix bug for script parser on jit.ignore --- cube/graph/parser/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index d10aba18..d3c6e856 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -746,7 +746,7 @@ def parse_prim_python_op_node(node, module, frame): fsig: str = str(node.pyname()) # map to IR operator - ir_node = Sign2Op.map(fsig)(inputs=input_vals) + ir_node = Sign2Op.map(fsig)(*input_vals) # push output in the frame # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) From e6d91e6c988a0aed8665b8122bc045e7ebd29716 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 15 Mar 2023 17:59:42 -0700 Subject: [PATCH 1354/1892] fix stack anno bug --- cube/graph/function/function.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 9a953409..1daf99ba 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1198,9 +1198,14 @@ def CubeStack(*tensors, dim=0, signature=None): assert isinstance(dim, int), f"but not {dim}" signature = 'cube.runtime.function.stack' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] - oannos = [copy.copy(iannos[-1])] - oannos[0].insert(dim, str(len(tensors))) - anno = OpAnno.create_op_str(iannos, oannos) + oanno = [None for i in range(len(tensors[0].shape) + 1)] + oanno[dim] = f'{len(tensors)}^' + offset = 0 + for i in range(len(oanno)): + if oanno[i] is None: + oanno[i] = copy.copy(iannos[-1][offset]) + offset += 1 + anno = OpAnno.create_op_str(iannos, [oanno]) return IRDimops(CubeStack, 'stack', signature, [anno], tensors, dim=dim) From cc65e5a9ee1ac862a4a01dda9f58e9b5585fc8b9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 Mar 2023 10:06:14 +0000 Subject: [PATCH 1355/1892] Merged PR 1494: fix database bugs and add estimator fix database bugs and add estimator: None in node inputs will crash the profiler old estimator is deprecated. The new estimator will get computation / memory cost of executing a sub-graph. --- cube/profiler/database.py | 12 ++-- cube/profiler/estimator.py | 142 ++++++++++++------------------------- 2 files changed, 52 insertions(+), 102 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 81ff7008..18f53903 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -13,6 +13,7 @@ from cube.ir.cten import IRTensor, IRObject from cube.ir.operator import IRFwOperation from cube.graph.parser.mapping import IRDType2TorchDType +# from cube.graph.parser.mapping import Sign2Op from cube.graph.parser.mappingfx import SignFx2Op as Sign2Op @@ -53,11 +54,13 @@ def profile(func: Callable, shapes: Shapes, dtypes: DTypes, # create data assert dtypes is not None def gen_torch_tensors(shape, dtype, requires_grad): + """Generate dummy input tenosrs""" constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand - # requires_grad = False if dtype in (torch.int64, torch.int32, torch.bool) else True return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) + tensors = tuple( - gen_torch_tensors(shape, dtype, requires_grad) if value is None else value for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) + gen_torch_tensors(shape, dtype, requires_grad) if isinstance(value, IRTensor) else value \ + for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) ) require_backward = any([t.requires_grad for t in tensors if hasattr(t, 'requires_grad')]) # FIXME: reconsidering requires_grad @@ -198,7 +201,7 @@ def get_dep_names(sign: str): shapes.append(t.shape) dtypes.append(IRDType2TorchDType.map(t.dtype)) requires_grads.append(t.requires_grad) - values.append(None) + values.append(t) elif isinstance(t, IRObject): raise RuntimeError('IRObject has not been supported in profiling.') else: @@ -221,6 +224,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): @return bw_span float: the backward span time in milliseconds @return infer_memory int: the peak memory in bytes after inference of the function @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward + @return residual_mem: ?? """ fn, shapes, dtypes, requires_grads, values, kwargs = ProfileDataBase.get_func(node) @@ -234,7 +238,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): in_mem_info, param_mem_info = [], [] residual_mem, input_count = 0, 0 for t in node.inputs(): - if hasattr(t, 'is_param') and t.is_param(): + if isinstance(t, IRTensor) and t.is_param(): param_mem_info.append(t.byte_size()) elif hasattr(t, 'byte_size'): input_count += 1 diff --git a/cube/profiler/estimator.py b/cube/profiler/estimator.py index 5aec2d05..3b8fcb89 100644 --- a/cube/profiler/estimator.py +++ b/cube/profiler/estimator.py @@ -1,109 +1,55 @@ +from typing import Union, Tuple +import sys +import os -from cube.ir.operator import IRBpOperation, IRFwOperation -from cube.ir.tensor import IRSubTensor, ValueMap -from cube.ir.adapter import IRAdapter -from cube.graph import IRGraph -from cube.ir.cten import IRCell, IRTensor +from cube.ir.operator import IRFwOperation +from cube.graph.segment import IRSegment +from cube.graph.function import IRGraphAnchor +from cube.profiler.database import ProfileDataBase class Estimator: + """ + Estimator to measture the computation / memory cost of a subgraph + """ - def __init__(self, graph: IRGraph): - """ - Estimator for policy use - """ + def __init__(self, cache='./profile_database.json'): - self.graph = graph + self.cache_file = cache + reload = cache if os.path.exists(cache) else None + self.database = ProfileDataBase(reload) - def comm_volume(self, device: int) -> int: + def __call__(self, nodes_or_segment: Union[Tuple[IRFwOperation], IRSegment], + train: bool=False): """ - Estimate message recv volume of device id. - This has no requirement for generating adapters in graph. + Profile the computation cost of a subgraph - Node that is not assigned to a particular device will not - be considered. - """ - volume = 0 - for node in self.graph.nodes(): - if isinstance(node, IRAdapter): - continue - if device in node.device: - volume += self.comm_volume_node(node) - return volume + @param nodes_or_segment Tuple[IRFwOperation] | IRSegment - def comm_volume_node(self, node: IRCell) -> int: + @return latency float: latency in ms + @return memory int: memory in bytes """ - Estimate node message recv volume. - This has no requirement for generating adapters in graph. - - Note for intermediate tensor communication, the estimated - communication volume is: - Volume = 0 if local produced tensor can covor all the needed region. - else N#(remote produced overlapping region) - """ - if node not in self.graph.nodes(): - raise KeyError(f"node {node} not in graph") - if len(node.device) == 0: - raise RuntimeError(f"node {node} device is not assigned") - volume = 0 - for input in node.inputs(): - if isinstance(input, IRSubTensor): - # reducer - if input.is_param(): - if input.grad.valmap != ValueMap(0, 1): - volume += input.nele() * (input.grad.valmap.chunk_num - 1) - # adapter - else: - local, remote = list(), list() - for ptensor in input.parent.ptensors: - if ptensor.device != input.device: - remote.append(ptensor) - else: - local.append(ptensor) - # check local producer - local_cover = False - for ptensor in local: - if input.overlap(ptensor): - intersection = input.common(ptensor) - if intersection == input: - local_cover = True - break - if local_cover: - continue - # check remote producer - remote_producer_volume = 0 - for ptensor in remote: - if input.overlap(ptensor): - intersection = input.common(ptensor) - remote_producer_volume += intersection.nele() - # check remote consumer - # TODO: need to check if all consumers can be - # merged to input - remote_consumer_volume = None - index = input.parent.consumers.index(node) - for ctensor in input.parent.ctensors[:index]: - if input.overlap(ctensor): - if remote_consumer_volume is None: - remote_consumer_volume = 0 - intersection = input.common(ctensor) - remote_consumer_volume += intersection.nele() - if intersection == input: - break - if remote_consumer_volume is None: - volume += remote_producer_volume - else: - volume += min(remote_consumer_volume, remote_producer_volume) - # debug info - # if isinstance(node, IRFwOperation): - # print(f'fw{node._id}-{node.device}-{node.name}: {volume}') - # elif isinstance(node, IRBpOperation): - # print(f'bw{node._id}(fw{node.mirror._id}): {volume}') - # else: - # print(f'cell{node._id}-{node.device}-{node.name}: {volume}') - return volume - - def flops(self) -> int: - raise NotImplementedError - - def flops_node(self, node: IRCell) -> int: - raise NotImplementedError + nodes = nodes_or_segment.nodes() if isinstance(nodes_or_segment, IRSegment) else nodes_or_segment + memory, latency = 0.0, 0.0 + for node in nodes: + if node.name == 'multiref' or isinstance(node, IRGraphAnchor): + continue + # _, _, fw_span, bw_span, infer_mem, train_mem_info, _ = self.database.profile(node) + try: + _, _, fw_span, bw_span, infer_mem, train_mem_info, _ = self.database.profile(node) + except Exception as e: + color, default = '\033[31m', '\033[0m' + error_msg = f'fail to run node: {node}\nerror: {e}' + print(f'{color}{error_msg}{default}', file=sys.stderr) + fw_span, bw_span, infer_mem, train_mem_info = 0, 0, 0, [0] + + if train: + memory += sum(train_mem_info) + latency += fw_span + bw_span + else: + memory = max(memory, infer_mem) + latency += fw_span + return latency, memory + + def save(self): + self.database.dump(self.cache_file, override=True) From 5744fd78c7ab2b14a6bd8464557b78a60441e5fa Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Thu, 16 Mar 2023 11:13:07 +0000 Subject: [PATCH 1356/1892] Merged PR 1492: update run_torchscale_lm.py to allow multi-GPU training and inference update run_torchscale_lm.py to allow single-GPU training and multi-GPU inference # changes required on torchscale: _torchscale/architecture/decoder.py" 462L_ ```python x = torch.sum(x) return x ``` --- cube/flags.py | 2 +- cube/graph/graph.py | 23 ++++---- cube/runtime/adapter/reducer.py | 2 +- examples/nlp/torchscale/policy/spmd.py | 49 ++++++++++++----- examples/nlp/torchscale/run_torchscale_lm.py | 55 +++++++++++++------- 5 files changed, 87 insertions(+), 44 deletions(-) diff --git a/cube/flags.py b/cube/flags.py index ed2a4354..9b7dceda 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -41,7 +41,7 @@ class CompileFlag: async_comm = _to_bool('ASYNC_COMM') # maximal reducer weight bytes for one allreduce - max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=5e8) + max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=6e8) # use automate mixture precision training, where weights, gradients # and optimizer status are kept in its original data type (can be float32), diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 3b18c68b..8c4a738c 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -11,6 +11,7 @@ import warnings import copy +import cube.flags from cube.ir.cten import IRTensor, IRCell, IRObject from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation @@ -136,16 +137,20 @@ def backward(self, loss: IRSubTensor): # set mirror as self self._mirror = self - # infer gradient requirement - for node in self.nodes(): - itensors = [t for t in node.inputs() if isinstance(t, IRTensor)] - require_grad = any(t.requires_grad for t in itensors) - for otensor in node.outputs(): - if not isinstance(otensor, IRTensor): continue - if isinstance(otensor, IRSubTensor): - otensor.parent.requires_grad = require_grad + if not cube.flags.CompileFlag.use_torchfx: + # infer gradient requirement + for node in self.nodes(): + itensors = [t for t in node.inputs() if isinstance(t, IRTensor)] + if node.name == 'type_as': + require_grad = itensors[0].requires_grad else: - otensor.requires_grad = require_grad + require_grad = any(t.requires_grad for t in itensors) + for otensor in node.outputs(): + if not isinstance(otensor, IRTensor): continue + if isinstance(otensor, IRSubTensor): + otensor.parent.requires_grad = require_grad + else: + otensor.requires_grad = require_grad # set loss gradient loss.parent.to_loss() diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 8aedc87c..019daefb 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -43,7 +43,7 @@ def allreduce(self): for param in self._params: if param.requires_grad and param.grad is not None: cur_byte_size = param.nelement() * param.element_size() - assert cur_byte_size <= self.bucket_size + assert cur_byte_size <= self.bucket_size, f'cur_byte_size = {cur_byte_size}' tp = param.data.type() if tp not in buckets: diff --git a/examples/nlp/torchscale/policy/spmd.py b/examples/nlp/torchscale/policy/spmd.py index 52cf2b99..f9cc611e 100644 --- a/examples/nlp/torchscale/policy/spmd.py +++ b/examples/nlp/torchscale/policy/spmd.py @@ -3,6 +3,7 @@ from cube.graph.segment import IRSegment from cube.ir.operator import IRDataOperation, IRFwOperation from cube.graph.gener.rvd.intra import IntraAutoPlacer +from cube.graph.function import IRTensor # tensor parallelism with auto-placer @@ -74,17 +75,32 @@ def PASData(graph: IRGraph, resource): batch_dim = node.get_batch_dims()[0] for node in graph.nodes(): if isinstance(node, IRFwOperation): - # if not isinstance(node, IRPyFunc): # and node.signature in ('torch.arange', 'torch.sin'): - # algo = node.algorithms('dim') - # sub_nodes = graph.partition( - # node, algo, idx=0, dim=batch_dim, num=resource.ngpus) - # else: - # print(f'WARNING: {node} cannot find dim algo, using replicate instead') - # sub_nodes = graph.replicate(node, resource.ngpus) try: algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=0, dim=batch_dim, num=resource.ngpus) + + must_replicate = False + for itensor in node.inputs(): + if not isinstance(itensor, IRTensor): + continue + + print(f'itersor = {itensor}') + for consumer in graph.consumers(itensor.parent): + if consumer.name == 'fullslice': + must_replicate = True + break + if must_replicate == True: + break + + if must_replicate: + print(f'##### must_replicate {node.name}') + sub_nodes = graph.replicate(node, resource.ngpus) + else: + idx = 0 + if node.name in {'type_as'}: + print(f"###### {node.name}") + idx = 1 + sub_nodes = graph.partition( + node, algo, idx=idx, dim=batch_dim, num=resource.ngpus) except AssertionError: print(f'WARNING: {node} cannot find dim algo, using replicate instead') sub_nodes = graph.replicate(node, resource.ngpus) @@ -141,10 +157,17 @@ def PASHybrid(graph: IRGraph, resource): """ linears = [node for node in graph.nodes() if node.name == 'linear'] for idx, node in enumerate(linears): - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=resource.ngpus) - for idx, node in enumerate(tp_nodes): - graph.assign(node, idx) + try: + algo = node.algorithms('dim') + tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=resource.ngpus) + for idx, node in enumerate(tp_nodes): + graph.assign(node, idx) + except AssertionError: + print(f'WARNING: {node} cannot find dim algo, using replicate instead') + sub_nodes = graph.replicate(node, resource.ngpus) + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + for node in graph.nodes(): if isinstance(node, (IRFwOperation, IRDataOperation)): if len(node.device) == 0: diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index cbe018b8..90c679d4 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -1,4 +1,11 @@ +# single GPU inference debug # USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData +# multi-GPU inference test +# OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 --master_port=25642 examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --subln --xpos-rel-pos --fp16 --policy PASData +# single-GPU training test +# OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 --master_port=25642 examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --subln --xpos-rel-pos --fp16 --policy PASData --do_train +# multi-GPU training test +# OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 --master_port=25642 examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --subln --xpos-rel-pos --fp16 --policy PASData --do_train import torch import pickle @@ -12,14 +19,9 @@ from fairseq.data import iterators import sys - import os -print(f'os.getcwd() = {os.getcwd()}') - # https://github.com/microsoft/torchscale/tree/main/examples/fairseq -# sys.path.append('/home/v-junliang/torchscaletest/torchscale/examples/fairseq') -# sys.path.append('./torchscaletest/torchscale/examples/fairseq') sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale') print(f'sys.path = {sys.path}') @@ -43,9 +45,11 @@ # build model parser = options.get_training_parser() parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +parser.add_argument('--do_train', action='store_true', default=False) # parser.add_argument('--local_rank', type=int, default=0) args = options.parse_args_and_arch(parser) +print(f"Running mode: {'TRAIN' if args.do_train else 'EVAL'}") cube.init() # set up policy @@ -63,7 +67,10 @@ cfg = convert_namespace_to_omegaconf(args) task = tasks.setup_task(cfg.task) model = task.build_model(cfg.model) -model.eval() +if args.do_train: + model.train() +else: + model.eval() print("building model succeed: ", type(model)) # create dummy input @@ -118,17 +125,25 @@ model, dummy_input=dummy_input, ) -@cube.compile(model, dataloader, PAS=PAS, load_content=False, override=True) -def train_iter(model, dataloader): - data = next(dataloader) - loss = model(*data) - # loss.backward() - return loss - -model = model.get_gen_module() - -iter_ret = train_iter(model, dataloader) -print(f'iter_ret = {iter_ret}') - -import sys -sys.exit(0) \ No newline at end of file +if args.do_train: + @cube.compile(model, dataloader, PAS=PAS, load_content=False, override=True) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(*data) + loss.backward() + # TODO fix loss.mirror DummyInputOutput issue + + model = model.get_gen_module() + train_iter(model, dataloader) +else: # do_eval + @cube.compile(model, dataloader, PAS=PAS, load_content=False, override=True) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(*data) + return loss + + model = model.get_gen_module() + iter_ret = train_iter(model, dataloader) + print(f'iter_ret = {iter_ret}') + +print('DONE') \ No newline at end of file From 6b3dc3efbbc95ce06e75c727c2a360bba5e35da8 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Mar 2023 02:43:20 +0000 Subject: [PATCH 1357/1892] Merged PR 1496: Make both torch.script and torch.fx available for training 1) Inferring of tensor `require_grad` attribute moved into IRFwOperation initialization. (can be reset in function.py or parser) 2) separate dtype mapping as a single module 3) separate register operator interface as a single module @ please help check if this PR works with torchscale example --- cube/algorithm/factory.py | 3 - cube/algorithm/ops/creators.py | 167 ------------------------ cube/codegen/frontend_mapping.py | 12 -- cube/codegen/module/module.py | 4 +- cube/graph/function/anchor.py | 2 +- cube/graph/function/conv.py | 15 +-- cube/graph/function/creators.py | 212 ------------------------------- cube/graph/function/dimops.py | 12 +- cube/graph/function/function.py | 57 ++------- cube/graph/function/pyfunc.py | 4 +- cube/graph/gener/utils.py | 9 +- cube/graph/graph.py | 19 +-- cube/graph/parser/dtype.py | 42 ++++++ cube/graph/parser/mapping.py | 78 ++---------- cube/graph/parser/mappingfx.py | 40 ++---- cube/graph/parser/parser.py | 3 +- cube/graph/parser/parserfx.py | 8 +- cube/graph/parser/register.py | 39 +++++- cube/ir/operator.py | 42 +++--- cube/profiler/database.py | 15 +-- cube/program.py | 2 +- 21 files changed, 162 insertions(+), 623 deletions(-) delete mode 100644 cube/algorithm/ops/creators.py delete mode 100644 cube/graph/function/creators.py create mode 100644 cube/graph/parser/dtype.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index bdcb5d02..279757e2 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -71,6 +71,3 @@ def _load_predefined_algos(self): self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') self.register(conv.IRConv2D, conv.HaloSplitConv2D, tag='halo') self.register(conv.IRConv3D, conv.HaloSplitConv3D, tag='halo') - - import cube.algorithm.ops.creators as creators - self.register(creators.IRToTensor, creators.DimSplitTo, tag='dim') diff --git a/cube/algorithm/ops/creators.py b/cube/algorithm/ops/creators.py deleted file mode 100644 index 7119c3fc..00000000 --- a/cube/algorithm/ops/creators.py +++ /dev/null @@ -1,167 +0,0 @@ -from typing import List, Tuple, Optional - -from cube.algorithm.generics import GenericDistAlgo - -from cube.graph.function.creators import IRToTensor, IROnes, IRRand, IRZeros -from cube.ir.tensor import IRSubTensor - - -class DimSplitTo(GenericDistAlgo): - """ - split Pad at dimension level - - """ - def __init__(self, node: IRToTensor): - if not isinstance(node, IRToTensor): - raise TypeError(f"Expect IRToTensor") - super().__init__(node) - - def satisfy(self, dim: int, num: int): - """ - config = dict(idx=int, dim=int, num=num) - - """ - assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" - node: IRToTensor = self.node - - assert dim < len(node.input(0).shape), "Split dimension should be smaller than tensor dimension" - - # split non-pad dim - return node.input(0).shape[dim] >= num - - def instantiate(self, dim: int, num: int) -> Optional[List[IRToTensor]]: - - node: IRToTensor = self.node - satisfy = self.satisfy(dim, num) - if not satisfy: - return None - - ins, ous = list(), list() - for iidx, itensor in enumerate(node.inputs()): - assert isinstance(itensor, IRSubTensor), "Input of select shoud be IRSubTensor" - ins.append(itensor.split_dim(dim, num)) - - odim = dim - - for oidx, otensor in enumerate(node.outputs()): - assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" - ous.append(otensor.split_dim(odim, num)) - - sub_nodes = list() - for nid in range(num): - inputs = [t[nid] for t in ins] - outputs = [t[nid] for t in ous] - sub_nodes.append(node.new(inputs, outputs)) - return sub_nodes - -class DimSplitZeros(GenericDistAlgo): - def __init__(self, node: IRZeros): - if not isinstance(node, IRZeros): - raise TypeError(f"Expect IRZeros") - super().__init__(node) - - def satisfy(self, dim: int, num: int): - """ - config = dict(idx=int, dim=int, num=num) - - """ - assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" - node: IRZeros = self.node - - assert dim < len(node.output(0).shape), "Split dimension should be smaller than tensor dimension" - - # split non-pad dim - return node.output(0).shape[dim] >= num - - def instantiate(self, dim: int, num: int) -> Optional[List[IRZeros]]: - - node: IRZeros = self.node - satisfy = self.satisfy(dim, num) - if not satisfy: - return None - - ous = list() - for oidx, otensor in enumerate(node.outputs()): - assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" - ous.append(otensor.split_dim(dim, num)) - - sub_nodes = list() - for nid in range(num): - outputs = [t[nid] for t in ous] - sub_nodes.append(node.new(outputs)) - return sub_nodes - - -class DimSplitOnes(GenericDistAlgo): - def __init__(self, node: IROnes): - if not isinstance(node, IROnes): - raise TypeError(f"Expect IROnes") - super().__init__(node) - - def satisfy(self, dim: int, num: int): - """ - config = dict(idx=int, dim=int, num=num) - - """ - assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" - node: IROnes = self.node - - assert dim < len(node.outputs(0).shape), "Split dimension should be smaller than tensor dimension" - - # split non-pad dim - return node.outputs(0).shape[dim] >= num - - def instantiate(self, dim: int, num: int) -> Optional[List[IROnes]]: - - node: IROnes = self.node - satisfy = self.satisfy(dim, num) - if not satisfy: - return None - - ous = list() - for oidx, otensor in enumerate(node.outputs()): - assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" - ous.append(otensor.split_dim(dim, num)) - - sub_nodes = list() - for nid in range(num): - outputs = [t[nid] for t in ous] - sub_nodes.append(node.new(outputs)) - return sub_nodes - -class DimSplitRand(GenericDistAlgo): - def __init__(self, node: IRRand): - if not isinstance(node, IRRand): - raise TypeError(f"Expect IRRand") - super().__init__(node) - - def satisfy(self, dim: int, num: int): - """ - config = dict(idx=int, dim=int, num=num) - - """ - assert all(isinstance(t, int) for t in [dim, num]), "dim and num should be integer" - node: IRRand = self.node - - assert dim < len(node.outputs(0).shape), "Split dimension should be smaller than tensor dimension" - - # split non-pad dim - return node.outputs(0).shape[dim] >= num - - def instantiate(self, dim: int, num: int) -> Optional[List[IRRand]]: - - node: IRRand = self.node - satisfy = self.satisfy(dim, num) - if not satisfy: - return None - - ous = list() - for oidx, otensor in enumerate(node.outputs()): - assert isinstance(otensor, IRSubTensor), "Output of select should be IRSubTensor" - ous.append(otensor.split_dim(dim, num)) - - sub_nodes = list() - for nid in range(num): - outputs = [t[nid] for t in ous] - sub_nodes.append(node.new(outputs)) - return sub_nodes \ No newline at end of file diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index ccc4008b..e5e262f3 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -65,17 +65,6 @@ def emit_slice(node, arg_vars:list, kw_pairs:dict) -> str: return f"{in_tensor_var}[{', '.join(subscript_components)}]" -# Basically to convert internal 'IRDType' to frontend 'torch.dtype' -def emit_to(node, arg_vars:list, kw_pairs:dict) -> str: - kw_pairs = kw_pairs.copy() - - # Unlike 'zeros' who has 'ScalarType? dtype', 'to' has a non-nullable 'dtype'. - ir_dtype : IRDType = kw_pairs['dtype'] - assert ir_dtype is not None - kw_pairs['dtype'] = IRDType2DType.map(ir_dtype) - - return _common_rule_join_all(node, arg_vars, kw_pairs) - def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: assert arg_vars[1].startswith('self.') @@ -105,7 +94,6 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: _signMap = { 'torch.slice': emit_slice, - 'torch.Tensor.to': emit_to, 'setattr': emit_setattr, } diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 7b70cfff..f7138264 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -11,7 +11,7 @@ from cube.ir.adapter.prim import CollectivePrim from cube.graph.graph import IRSegment -from cube.graph.parser.mapping import Sign2Op +from cube.graph.parser.register import CustomizedOps from cube.execplan import ExecutionPlan from cube.execplan.execplan import ExeRepetend, ExeReuseCell @@ -88,7 +88,7 @@ def __init__(self, execplan: ExecutionPlan) -> None: self.init_code.extend(['import nnfusion', '']) # customized op code - for _, op_impl in Sign2Op.kOpCodeDef.items(): + for _, op_impl in CustomizedOps.kOpCodeDef.items(): # self.init_code.append('@torch.jit.script') self.init_code.append(op_impl) self.init_code += [''] diff --git a/cube/graph/function/anchor.py b/cube/graph/function/anchor.py index 2fec2200..0b4a3e50 100644 --- a/cube/graph/function/anchor.py +++ b/cube/graph/function/anchor.py @@ -12,7 +12,7 @@ class IRGraphAnchor(IRFwOperation): user doesn't need to manipulate it. """ def __init__(self, signature: str, name: str): - super().__init__(name, signature, 0, 1) + super().__init__(name, signature, [], 1) self.kwargs['name'] = name self.set_output(0, None) diff --git a/cube/graph/function/conv.py b/cube/graph/function/conv.py index 27ed3dff..771343fa 100644 --- a/cube/graph/function/conv.py +++ b/cube/graph/function/conv.py @@ -12,10 +12,7 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, signature = 'torch.nn.functional.pad' assert len(inputs) == 1, "Expected only input, weight, bias as inputs" assert len(kwargs) == 3, "Expected 2 kwargs: mode, value" - super().__init__(name, signature, 1, 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update(kwargs) + super().__init__(name, signature, inputs, 1, **kwargs) def infer_shape(self) -> bool: """ @@ -58,10 +55,7 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, signature = 'cube.runtime.function.conv2d' assert len(inputs) == 3, "Expected only input, weight, bias as inputs" assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" - super().__init__(name, signature, 3, 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update(kwargs) + super().__init__(name, signature, inputs, 1, **kwargs) def infer_shape(self) -> bool: """ @@ -107,10 +101,7 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, signature = 'cube.runtime.function.conv3d' assert len(inputs) == 3, "Expected only input, weight, bias as inputs" assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" - super().__init__(name, signature, 3, 1) - for idx, input in enumerate(inputs): - self.set_input(idx, input) - self.kwargs.update(kwargs) + super().__init__(name, signature, inputs, 1, **kwargs) def infer_shape(self) -> bool: """ diff --git a/cube/graph/function/creators.py b/cube/graph/function/creators.py deleted file mode 100644 index f71ef98e..00000000 --- a/cube/graph/function/creators.py +++ /dev/null @@ -1,212 +0,0 @@ -from copy import copy -from typing import List, Optional -from cube.ir.dtype import IRDType - -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor - -import numpy as np - -class IRArange(IRFwOperation): - def __init__(self, signature: str, shape: List[int], name: str, **kwargs): - - # The shape information must be statically known integer values - assert all(isinstance(dim, int) for dim in shape) - assert 'dtype' in kwargs - dtype = kwargs['dtype'] - assert not isinstance(dtype, IRDType) - - from cube.graph.parser.mapping import DType2IRDType - ir_dtype: IRDType = DType2IRDType.map(dtype) - - super().__init__(name, signature, input_length=0, output_length=1) - - # Customize output's dtype only after 'super().__init__' and 'self.set_input', - # otherwise it gets overwritten. - self.output(0).dtype = ir_dtype - self.shape = shape - kwargs.update({'dtype': dtype, 'device': 'cuda'}) #TODO check me and fix more, e.g., ones, zeros, empty - self.kwargs = kwargs - - def infer_shape(self) -> bool: - self.output(0).shape = copy(self.shape) - return True - - def new(self, outputs: List[IRTensor]): - op = IRArange(self.signature, outputs[0].shape, self.name, **self.kwargs) - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IRArange::new infer_shape failed" - return op - -class IREmpty(IRFwOperation): - def __init__(self, signature: str, shape: List[int], name: str, **kwargs): - - # The shape information must be statically known integer values - assert all(isinstance(dim, int) for dim in shape) - assert 'dtype' in kwargs - assert isinstance(kwargs['dtype'], IRDType) - - super().__init__(name, signature, input_length=0, output_length=1) - - # Customize output's dtype only after 'super().__init__' and 'self.set_input', - # otherwise it gets overwritten. - self.output(0).dtype = kwargs['dtype'] - - # The positional argument to specify the shape is actually called 'size'. - self.kwargs = kwargs - self.kwargs.update({"size": copy(shape)}) - - def infer_shape(self) -> bool: - shape : list = copy(self.kwargs["size"]) - self.output(0).shape = shape - return True - - def new(self, outputs: List[IRTensor]): - op = IREmpty(self.signature, outputs[0].shape, self.name, **self.kwargs) - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IREmpty::new infer_shape failed" - return op - -class IRNewTensor(IRFwOperation): - def __init__(self, signature: str, data: list, name: str, **kwargs): - super().__init__(name, signature, input_length=0, output_length=1) - assert 'dtype' in kwargs - assert isinstance(kwargs['dtype'], IRDType) - self.output(0).dtype = kwargs['dtype'] - self.data = data - self.shape = np.array(data).shape - self.kwargs = kwargs - - def infer_shape(self) -> bool: - self.output(0).shape = copy(self.shape) - return True - - def new(self, outputs: List[IRTensor]): - op = IRNewTensor(self.signature, self.data, self.name, **self.kwargs) - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IRNewTensor::new infer_shape failed" - return op - -class IRZeros(IRFwOperation): - def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): - - # The shape information must be statically known integer values - assert all(isinstance(dim, int) for dim in shape) - assert isinstance(ir_dtype, IRDType) - - super().__init__(name, signature, input_length=0, output_length=1) - - # Customize output's dtype only after 'super().__init__' and 'self.set_input', - # otherwise it gets overwritten. - self.output(0).dtype = ir_dtype - - # The positional argument to specify the shape is actually called 'size'. - self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) - - def infer_shape(self) -> bool: - shape : list = copy(self.kwargs["size"]) - self.output(0).shape = shape - return True - - def new(self, outputs: List[IRTensor]): - op = IRZeros(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IRZeros::new infer_shape failed" - return op - -class IROnes(IRFwOperation): - def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): - - # The shape information must be statically known integer values - assert all(isinstance(dim, int) for dim in shape) - assert isinstance(ir_dtype, IRDType) - - super().__init__(name, signature, input_length=0, output_length=1) - - # Customize output's dtype only after 'super().__init__' and 'self.set_input', - # otherwise it gets overwritten. - self.output(0).dtype = ir_dtype - - # The positional argument to specify the shape is actually called 'size'. - self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) - - def infer_shape(self) -> bool: - shape : list = copy(self.kwargs["size"]) - self.output(0).shape = shape - return True - - def new(self, outputs: List[IRTensor]): - op = IROnes(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IROnes::new infer_shape failed" - return op - -class IRRand(IRFwOperation): - def __init__(self, signature: str, shape: List[int], name: str, ir_dtype:IRDType): - - # The shape information must be statically known integer values - assert all(isinstance(dim, int) for dim in shape) - assert isinstance(ir_dtype, IRDType) - - super().__init__(name, signature, input_length=0, output_length=1) - - # Customize output's dtype only after 'super().__init__' and 'self.set_input', - # otherwise it gets overwritten. - self.output(0).dtype = ir_dtype - - # The positional argument to specify the shape is actually called 'size'. - self.kwargs.update({"size": copy(shape), "dtype": ir_dtype}) - - def infer_shape(self) -> bool: - shape : list = copy(self.kwargs["size"]) - self.output(0).shape = shape - return True - - def new(self, outputs: List[IRTensor]): - op = IRRand(self.signature, outputs[0].shape, self.name, self.kwargs['dtype']) - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IRRand::new infer_shape failed" - return op - -# class IRNewTensor(IRFwOperation): -# def __init__(self, signature: str, data: list, name: str, ir_dtype: IRDType): -# super().__init__(name, signature, input_length=0, output_length=1) -# self.output(0).dtype = ir_dtype -# self.kwargs.update({'data': data, 'shape': np.array(data).shape, 'dtype': ir_dtype}) - -# def infer_shape(self) -> bool: -# shape : list = copy(self.kwargs['shape']) -# self.output(0).shape = shape -# return True - - - -# `aten::to` has several overloading, which one should be dispatched is determined by the argument types -# See -# https://github.com/pytorch/pytorch/blob/483bb4f0cb273f42f655aa30eee6a1fbbaba69b0/torch/csrc/jit/runtime/register_prim_ops.cpp#L1057 -# https://github.com/pytorch/pytorch/blob/483bb4f0cb273f42f655aa30eee6a1fbbaba69b0/torch/csrc/jit/runtime/register_prim_ops.cpp#L2215 -class IRToTensor(IRFwOperation): - def __init__(self, signature: str, inputs, name:str, ir_dtype:IRDType): - - assert isinstance(ir_dtype, IRDType) - - super().__init__(name, signature, input_length=1, output_length=1) - self.set_input(0, inputs[0]) - - # Customize output's dtype only after 'super().__init__' and 'self.set_input', - # otherwise it gets overwritten. - self.output(0).dtype = ir_dtype - - self.kwargs.update({"dtype": ir_dtype}) - - def infer_shape(self) -> bool: - self.output(0).shape = self.input(0).shape - return True - - def new(self, inputs: List[IRTensor], outputs: List[IRTensor]): - op = IRToTensor(self.signature, inputs, self.name, self.kwargs['dtype']) - op.set_output(0, outputs[0]) - assert op.infer_shape(), "IRToTensor::new infer_shape failed" - return op - - diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index d991b4d7..0abb3a94 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -67,7 +67,7 @@ import re import string -from cube.ir.cten import IRTensor +from cube.ir.cten import IRTensor, IRObject from cube.ir.dtype import DTypeInferRule from cube.ir.operator import IRFwOperation from cube.algorithm.factory import DistAlgorithmFactory @@ -567,7 +567,7 @@ class IRDimops(IRFwOperation): """ def __init__(self, create_fn: Callable, name: str, signature: str, annos: Tuple[str], - inputs: List[IRTensor], + inputs: List[Union[IRTensor, IRObject]], transform_rules: Optional[Tuple[TransformRule]] = None, **kwargs): """! @@ -606,12 +606,7 @@ def __init__(self, create_fn: Callable, name: str, ) n_outputs = len(self._oannos) - super().__init__(name, signature, len(inputs), n_outputs) - # set input - for idx, input in enumerate(inputs): - self.set_input(idx, input) - for name in kwargs: - self.kwargs[name] = kwargs[name] + super().__init__(name, signature, inputs, n_outputs, **kwargs) @property def anno(self) -> OpAnno: @@ -649,7 +644,6 @@ def infer_shape(self) -> bool: @return sucess: True if successfully inferred shape """ - idtypes = [t.dtype for t in self._inputs if isinstance(t, IRTensor)] for oidx, otensor in enumerate(self.outputs()): shape_anno = self.oanno(oidx) shape = [] diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 1daf99ba..940451d5 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -12,13 +12,9 @@ from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D -from cube.graph.function.creators import IRToTensor from cube.graph.function.anchor import IRGraphAnchor -ErasedDevice = 'str' - - def Identity(tensor: IRObject, signature = None): signature = 'cube.runtime.function.identity' eshape = ShapeAnno.create_shape_str(tensor.shape) @@ -149,7 +145,7 @@ def CubeArange(start: int, end: int, step: int, dtype=None, 'dtype': dtype, 'requires_grad': requires_grad} anno, rules = _get_creator_anno_rules(size, False) dimop = IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.mapping import DType2IRDType + from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) return dimop @@ -187,7 +183,7 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(size, True) dimop = IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.mapping import DType2IRDType + from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) return dimop @@ -207,7 +203,7 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules(size, True) dimop = IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.mapping import DType2IRDType + from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) return dimop @@ -227,7 +223,7 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules(size, True) dimop = IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.mapping import DType2IRDType + from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) return dimop @@ -248,7 +244,7 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(size, True) dimop = IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.mapping import DType2IRDType + from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) return dimop @@ -267,46 +263,11 @@ def NewTensor(data, *, dtype=None, device=None, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(size, True) dimop = IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.mapping import DType2IRDType + from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) return dimop -def ToTensor(signature, - inputs: Tuple[ IRTensor, ... ]): - """ - 'aten::to' has many overloadings that need resolution, - they differ by both the arity and the type of the argument (possibly at the same position): - - ``` - aten::to.device(Tensor self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor): - aten::to.dtype(Tensor self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor): - aten::to.dtype_layout(Tensor self, *, int dtype, int layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor): - aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor): - aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)): - aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)): - aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)): - ``` - ... where the 'int? dtype' is the underlying type for the enum 'ScalarType'. - """ - - # in our case we only care the overloading 'to.dtype' (arity=5) - assert len(inputs) == 5 - tensor : IRTensor - dtype_underlying : int - non_blocking : bool - copy : bool - opt_memory_format : Optional[int] - tensor, dtype_underlying, non_blocking, copy, opt_memory_format = inputs - - from cube.graph.parser.mapping import DType2IRDType, TorchScalarTypeEnumMap - dtype : torch.dtype = TorchScalarTypeEnumMap.map(dtype_underlying) - ir_dtype : IRDType = DType2IRDType.map(dtype) - - signature = 'torch.Tensor.to' - return IRToTensor(signature, [tensor], 'to', ir_dtype=ir_dtype) - - def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: """! Create shape annotations for element wise operator following broadcastable rules: @@ -1046,13 +1007,15 @@ def Unsqueeze(input, dim, signature = None): return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input],dim=dim) -def TypeAs(input, tensor, signature = None): +def TypeAs(input: IRTensor, tensor: IRTensor, signature = None): """ out = torch.Tensor.type_as(tensor0, tensor1) """ edim_in0 = ShapeAnno.create_shape_str(tensor.shape) anno = OpAnno.create_op_str(['*', edim_in0], ['*']) - return IRDimops(TypeAs, 'type_as', signature, [anno], [input, tensor]) + dimop = IRDimops(TypeAs, 'type_as', signature, [anno], [input, tensor]) + dimop.output(0).requires_grad = input.requires_grad + return dimop def Triu(input, diagonal=0, *, out=None, signature = None): diff --git a/cube/graph/function/pyfunc.py b/cube/graph/function/pyfunc.py index e0b2aa85..68430e31 100644 --- a/cube/graph/function/pyfunc.py +++ b/cube/graph/function/pyfunc.py @@ -12,9 +12,7 @@ class IRPyFunc(IRFwOperation): def __init__(self, signature: str, inputs: Tuple[IRObject], outputs: Tuple[IRObject], **kwargs): name = signature.split('.')[-1] - super().__init__(name, signature, len(inputs), len(outputs)) - for idx, t in enumerate(inputs): - self.set_input(idx, t) + super().__init__(name, signature, inputs, len(outputs)) for idx, t in enumerate(outputs): self.set_output(idx, t) self.kwargs.update(**kwargs) diff --git a/cube/graph/gener/utils.py b/cube/graph/gener/utils.py index cd11495b..35325c45 100644 --- a/cube/graph/gener/utils.py +++ b/cube/graph/gener/utils.py @@ -13,13 +13,10 @@ class DummyInputOuput(IRFwOperation): def __init__(self, tensor: IRSubTensor, device: int, is_input=False, is_output=False, name='dummy'): - super().__init__(name, name, - 1 if is_input else 0, - 1 if is_output else 0 - ) assert (is_input and not is_output) or (is_output and not is_input) - if is_input: - self.set_input(0, tensor) + inputs = [tensor] if is_input else [] + outputs = [tensor] if is_output else [] + super().__init__(name, name, inputs, len(outputs)) if is_output: self.set_output(0, tensor) self.device = device diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 8c4a738c..de6b9239 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,11 +7,10 @@ will be inserted at scheduling time. """ -from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any, Callable +from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any import warnings import copy -import cube.flags from cube.ir.cten import IRTensor, IRCell, IRObject from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation @@ -136,22 +135,6 @@ def backward(self, loss: IRSubTensor): """ # set mirror as self self._mirror = self - - if not cube.flags.CompileFlag.use_torchfx: - # infer gradient requirement - for node in self.nodes(): - itensors = [t for t in node.inputs() if isinstance(t, IRTensor)] - if node.name == 'type_as': - require_grad = itensors[0].requires_grad - else: - require_grad = any(t.requires_grad for t in itensors) - for otensor in node.outputs(): - if not isinstance(otensor, IRTensor): continue - if isinstance(otensor, IRSubTensor): - otensor.parent.requires_grad = require_grad - else: - otensor.requires_grad = require_grad - # set loss gradient loss.parent.to_loss() diff --git a/cube/graph/parser/dtype.py b/cube/graph/parser/dtype.py new file mode 100644 index 00000000..97f264fd --- /dev/null +++ b/cube/graph/parser/dtype.py @@ -0,0 +1,42 @@ +import torch +import cube.ir as ir + + +class DType2IRDType: + + @staticmethod + def map(dtype: torch.dtype): + """ + Map the torch dtype to IRDType + """ + return DType2IRDType.kDtypeMap[dtype] + + kDtypeMap = { + torch.double: ir.float64, + torch.float64: ir.float64, + torch.float32: ir.float32, + torch.float : ir.float32, + torch.float16: ir.float16, + torch.half : ir.float16, + torch.uint8 : ir.uint8, + torch.int8 : ir.int8, + torch.int16 : ir.int16, + torch.short : ir.int16, + torch.int32 : ir.int32, + torch.int : ir.int32, + torch.int64 : ir.int64, + torch.long : ir.int64, + torch.bool : ir.boolean + } + + +class IRDType2TorchDType: + + @staticmethod + def map(ir_dtype: ir.IRDType): + """ + Map the IRDtype to torch dtype + """ + return IRDType2TorchDType.kDtypeMap[ir_dtype] + + kDtypeMap = {val: key for key, val in DType2IRDType.kDtypeMap.items()} \ No newline at end of file diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/mapping.py index 257b4dc3..126a18f4 100644 --- a/cube/graph/parser/mapping.py +++ b/cube/graph/parser/mapping.py @@ -1,17 +1,12 @@ -""" -Mapping of - signature -> IROperator - torch.dtype -> cube.ir.IRDType - cube.ir.IRDType -> torch.dtype -""" + import torch -from typing import Callable, Dict, Union +from typing import Callable, Union from functools import partial import cube.graph.function as function -import cube.ir as ir from cube.ir.operator import IRFwOperation +from cube.graph.parser.register import CustomizedOps class Sign2Op: @@ -21,26 +16,19 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: """ Map the signature to GenericLogicalOp """ - if 'torch.' not in signature and 'cube.runtime.' not in signature: - signature = signature.split('.')[-1] if signature in Sign2Op.kOpMap: return partial(Sign2Op.kOpMap[signature], signature=signature) - else: - raise KeyError(f"{signature} is not supported yet") - # print(f'warning: {signature} is not recognized') - # return partial(function.UnkownOperator, signature=signature) + if CustomizedOps.exist(signature): + return CustomizedOps.map(signature) + raise KeyError(f"{signature} is not supported yet") @staticmethod - def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]], code): - """ - Register an operator - """ - if not isinstance(signature, str): - raise TypeError(f"Expected signature to be str but got {type(signature)}") + def exist(signature: str) -> bool: if signature in Sign2Op.kOpMap: - raise KeyError(f"function {signature} is already registered") - Sign2Op.kOpMap[signature] = op - Sign2Op.kOpCodeDef[signature] = code + return True + if CustomizedOps.exist(signature): + return True + return False # functional templates __ftemplate = lambda name: f'torch.nn.functional.{name}' @@ -84,7 +72,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] __ttemplate('zeros'): function.Zeros, __ttemplate('ones'): function.Ones, __ttemplate('tensor'): function.NewTensor, - __ttemplate('to'): function.ToTensor, __ttemplate('rand'): function.Rand, __ttemplate('clone'): function.Clone, @@ -165,49 +152,6 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] } - # customized operator code: signature -> code - kOpCodeDef: Dict[str, str] = {} - - -class DType2IRDType: - - @staticmethod - def map(dtype: torch.dtype): - """ - Map the torch dtype to IRDType - """ - return DType2IRDType.kDtypeMap[dtype] - - kDtypeMap = { - torch.double: ir.float64, - torch.float64: ir.float64, - torch.float32: ir.float32, - torch.float : ir.float32, - torch.float16: ir.float16, - torch.half : ir.float16, - torch.uint8 : ir.uint8, - torch.int8 : ir.int8, - torch.int16 : ir.int16, - torch.short : ir.int16, - torch.int32 : ir.int32, - torch.int : ir.int32, - torch.int64 : ir.int64, - torch.long : ir.int64, - torch.bool : ir.boolean - } - - -class IRDType2TorchDType: - - @staticmethod - def map(ir_dtype: ir.IRDType): - """ - Map the IRDtype to torch dtype - """ - return IRDType2TorchDType.kDtypeMap[ir_dtype] - - kDtypeMap = {val: key for key, val in DType2IRDType.kDtypeMap.items()} - # see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h # diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 32e28d8a..79558439 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -1,11 +1,11 @@ -import torch -from typing import Callable, Dict, Union +from typing import Callable, Union from functools import partial import cube.graph.function as function -import cube.ir as ir from cube.ir.operator import IRFwOperation +from cube.graph.parser.register import CustomizedOps + class SignFx2Op: @@ -14,37 +14,20 @@ def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: """ Map the signature to GenericLogicalOp """ - bultin_regions = ['torch.', 'cube.runtime.', '_operator.', 'builtins.'] - # customized function - if all(not signature.startswith(region) for region in bultin_regions): - signature = signature.split('.')[-1] if signature in SignFx2Op.kOpMap: function = SignFx2Op.kOpMap[signature] - # signature = 'torch.sum' if signature == 'sum' else signature #TODO fixme return partial(function, signature=signature) - else: - raise KeyError(f"{signature} is not supported yet") - # return partial(function.UnkownOperator, signature=signature) + if CustomizedOps.exist(signature): + return CustomizedOps.map(signature) + raise KeyError(f"{signature} is not supported yet") @staticmethod def exist(signature: str) -> bool: - bultin_regions = ['torch.', 'cube.runtime.', '_operator.', 'builtins.'] - # customized function - if all(not signature.startswith(region) for region in bultin_regions): - signature = signature.split('.')[-1] - return signature in SignFx2Op.kOpMap - - @staticmethod - def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]], code): - """ - Register an operator - """ - if not isinstance(signature, str): - raise TypeError(f"Expected signature to be str but got {type(signature)}") if signature in SignFx2Op.kOpMap: - raise KeyError(f"function {signature} is already registered") - SignFx2Op.kOpMap[signature] = op - SignFx2Op.kOpCodeDef[signature] = code + return True + if CustomizedOps.exist(signature): + return True + return False # functional templates __ftemplate = lambda name: f'torch.nn.functional.{name}' @@ -223,6 +206,3 @@ def register(signature: str, op: Callable[..., Union[IRFwOperation, int, float]] # __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, } - - # customized operator code: signature -> code - kOpCodeDef: Dict[str, str] = {} diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/parser.py index d3c6e856..0bc8aad4 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/parser.py @@ -9,7 +9,8 @@ from cube.ir.tensor import IRFullTensor import cube.ir as ir from cube.graph.parser.frame import Frame -from cube.graph.parser.mapping import Sign2Op, DType2IRDType +from cube.graph.parser.mapping import Sign2Op +from cube.graph.parser.dtype import DType2IRDType _refmodule = torch.nn.Module() diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 794de871..19b5dbbf 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -7,7 +7,7 @@ from cube.ir.tensor import IRFullTensor from cube.ir.cten import IRObject, IRCell from cube.graph.parser.frame import Frame -from cube.graph.parser.mapping import DType2IRDType +from cube.graph.parser.dtype import DType2IRDType from cube.graph.parser.mappingfx import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc @@ -278,10 +278,7 @@ def get_complex_data(val: Any) -> Any: if FxModuleParser._is_torch_autograd_op(node, frame, fsig): print(f'>>> Find unknown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fname - ir_node = IRFwOperation(fname, fsig, len(input_vals), 1) - ir_node.kwargs = kwargs - for idx, t in enumerate(input_vals): - ir_node.set_input(idx, t) + ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) # case2: python runtime function else: print(f'>>> Set python runtime function: {fsig}') @@ -296,6 +293,7 @@ def get_complex_data(val: Any) -> Any: # setting the list of the output tensor print('>> parsing {ir_node}') ir_node.infer_shape() + ir_node.infer_dtype() frame.set_var(node.name, ir_node.outputs()) else: output_val = frame.get_var(node.name) diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 80047d63..310d4eae 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -2,12 +2,45 @@ Register cutomized function """ -from typing import Any, Callable, List, Optional +from typing import Dict, Callable, List, Optional +from functools import partial import inspect import torch from cube.graph.function.dimops import IRDimops, OpAnno -from cube.graph.parser.mapping import Sign2Op + + +class CustomizedOps: + + kOpMap: Dict[str, Callable] = {} + # customized operator code: signature -> code + kOpCodeDef: Dict[str, str] = {} + + @staticmethod + def map(signature: str) -> Callable: + signature = signature.split('.')[-1] + if signature in CustomizedOps.kOpMap: + return partial(CustomizedOps.kOpMap[signature], signature=signature) + else: + raise KeyError(f"{signature} is not found in registered ops") + + @staticmethod + def exist(signature: str) -> bool: + signature = signature.split('.')[-1] + return signature in CustomizedOps.kOpMap + + @staticmethod + def register(signature: str, op: Callable, code: str): + """ + Register an operator + """ + builtins = ['_operator', 'torch', 'cube.runtime.function'] + if any(signature.startswith(builtin) for builtin in builtins): + raise RuntimeError(f"Cannot register operators with signature starting from any of {builtins}") + signature = signature.split('.')[-1] + assert signature not in CustomizedOps.kOpMap, f"function {signature} is already registered" + CustomizedOps.kOpMap[signature] = op + CustomizedOps.kOpCodeDef[signature] = code def register(anno: str, name: Optional[str] = None, rules: Optional[List] = None): @@ -70,7 +103,7 @@ def udfop(*args, signature=None, **kwargs): return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, transform_rules=rules, **kwargs) print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') - Sign2Op.register(fsig, udfop, code) + CustomizedOps.register(fsig, udfop, code) return fn return decorator diff --git a/cube/ir/operator.py b/cube/ir/operator.py index fa234bef..995450b2 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -1,11 +1,10 @@ -from typing import Optional, Tuple, Any, Union +from typing import Optional, Tuple, Any, Union, List import copy from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.tensor import IRFullTensor from cube.algorithm.factory import DistAlgorithmFactory from cube.algorithm.generics import GenericDistAlgo -from cube.ir.unique import IDGenerator from cube.ir.dtype import IRDType, DTypeInferRule @@ -14,11 +13,8 @@ class IRFwOperation(IRCell): Forward operation """ - def __init__(self, - name: str, - signature: str, - input_length: int, - output_length: int): + def __init__(self, name: str, signature: str, + inputs: List[IRObject], num_outputs: int, **kwargs): """! Create a forward operation. @@ -27,28 +23,42 @@ def __init__(self, @param input_length int: number of inputs @param output_length int: number of outputs """ - # additional argument - self.kwargs = dict() # recompute schedule self._recompute = None - super().__init__(name, signature, input_length, output_length, init_outputs=False) - outputs = [IRFullTensor() for _ in range(output_length)] + super().__init__(name, signature, len(inputs), + num_outputs, init_outputs=False) + + # setup input + for idx, input in enumerate(inputs): + self.set_input(idx, input) + + # additional argument + self.kwargs = kwargs + + # default infer rule + requires_grad = any( + t.requires_grad for t in inputs if isinstance(t, IRTensor)) + + # setup output + outputs = [IRFullTensor(requires_grad=requires_grad) for _ in range(num_outputs)] for idx, output in enumerate(outputs): self.set_output(idx, output) def infer_dtype(self): """ Infer output value dtype. - By default will follow the same dtype promotion rule with PyTorch. """ itensors = [t for t in self.inputs() if isinstance(t, IRTensor)] - assert len(itensors) > 0, "Missing input tensors, need to customize the infer rule" - odtype = DTypeInferRule.infer(self, [t.dtype for t in itensors]) - assert odtype != IRDType.unknown, f"{self} : {[t.dtype for t in itensors]}" otensors = [t for t in self.outputs() if isinstance(t, IRTensor)] + odtype = DTypeInferRule.infer(self, [t.dtype for t in itensors]) for tensor in otensors: - tensor.dtype = odtype + # in case of setting manually due to special rules + if tensor.dtype == IRDType.unknown: + if isinstance(tensor, IRFullTensor): + tensor.dtype = odtype + else: + tensor.parent.dtype = odtype def infer_shape(self): """ diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 18f53903..b0488d11 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -12,9 +12,8 @@ import cube from cube.ir.cten import IRTensor, IRObject from cube.ir.operator import IRFwOperation -from cube.graph.parser.mapping import IRDType2TorchDType -# from cube.graph.parser.mapping import Sign2Op -from cube.graph.parser.mappingfx import SignFx2Op as Sign2Op +from cube.graph.parser.dtype import IRDType2TorchDType +from cube.graph.parser.register import CustomizedOps Shapes = NewType('Shapes', Tuple[Tuple[int]]) @@ -168,21 +167,21 @@ def get_func(node: IRFwOperation) -> Tuple[Callable, Shapes, DTypes, Dict]: def get_dep_names(sign: str): ret = [] - code_impl = Sign2Op.kOpCodeDef[sign] + code_impl = CustomizedOps.kOpCodeDef[sign] for code_line in code_impl.split('\n'): idx = code_line.find('# call: ') if idx != -1: dep_name = code_line[idx + 8:] - assert dep_name in Sign2Op.kOpCodeDef, dep_name + assert dep_name in CustomizedOps.kOpCodeDef, dep_name ret = ret + get_dep_names(dep_name) ret.append(dep_name) return ret - if node.signature in Sign2Op.kOpCodeDef: + if node.signature in CustomizedOps.kOpCodeDef: dep_code_impl = '' for dep_name in get_dep_names(node.signature): - dep_code_impl = dep_code_impl + Sign2Op.kOpCodeDef[dep_name] - code_impl: str = Sign2Op.kOpCodeDef[node.signature] + dep_code_impl = dep_code_impl + CustomizedOps.kOpCodeDef[dep_name] + code_impl: str = CustomizedOps.kOpCodeDef[node.signature] def_end = code_impl.find(':\n') assert def_end >= 0 prev_code_lines = code_impl[:def_end+2] diff --git a/cube/program.py b/cube/program.py index 581e8b68..b14075a0 100644 --- a/cube/program.py +++ b/cube/program.py @@ -7,7 +7,7 @@ from cube.graph import IRGraph from cube.graph import parser -from cube.graph.parser.mapping import DType2IRDType +from cube.graph.parser.dtype import DType2IRDType from cube.runtime.syndata import CubeDataLoader from cube.runtime.module import CubeModule From 35bcac4f46273fdb326d687bbc77e46338f26f03 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 17 Mar 2023 08:11:13 +0000 Subject: [PATCH 1358/1892] Merged PR 1498: fix view and reshape bug on bracket fix view and reshape bug on bracket --- cube/graph/function/function.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 940451d5..d0cef82d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -727,6 +727,7 @@ def nele(shape, nele=1): shape_map: Dict[str, int] = {edim: eshape for (edim, eshape) in zip(edims, chain)} # generate input and output shape annotations + # greedy fuse suffix number def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[str]]: anno = [] dimidx = 0 @@ -744,6 +745,18 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s elements *= chain[dimidx] bracket.append(edims[dimidx]) dimidx += 1 + # fetch as many 1^ as possible from tail of the previous bracket + if len(bracket) == 0: + assert dimlen == 1, f"internal match error3: dimlen={dimlen}" + back = 0 + for edim in anno[-1][1:][::-1]: + if chain[edims.index(edim)] != 1: + break + back += 1 + assert back > 0, f"internal match error4: dimlen={dimlen}" + bracket = anno[-1][-back:] + anno[-1] = anno[-1][:-back] + assert len(bracket) > 0, f"got a dimension with no edim" anno.append(bracket) return anno @@ -895,6 +908,18 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s elements *= chain[dimidx] bracket.append(edims[dimidx]) dimidx += 1 + # fetch as many 1^ as possible from tail of the previous bracket + if len(bracket) == 0: + assert dimlen == 1, f"internal match error3: dimlen={dimlen}" + back = 0 + for edim in anno[-1][1:][::-1]: + if chain[edims.index(edim)] != 1: + break + back += 1 + assert back > 0, f"internal match error4: dimlen={dimlen}" + bracket = anno[-1][-back:] + anno[-1] = anno[-1][:-back] + assert len(bracket) > 0, f"got a dimension with no edim" anno.append(bracket) return anno From f70ab8ea86fd7285cbc2a1b2e5285ae838ff3126 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sat, 18 Mar 2023 08:40:16 -0700 Subject: [PATCH 1359/1892] save work --- cube/algorithm/ops/dimops.py | 7 ++-- cube/graph/function/function.py | 38 +++++++++++++++++--- cube/profiler/database.py | 9 +++-- examples/nlp/torchscale/run_torchscale_lm.py | 27 ++++++++++---- 4 files changed, 65 insertions(+), 16 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index a3b53885..fa532c1d 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -344,7 +344,6 @@ def collect_split_info(node: IRFwOperation): return split_info def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: - split_info = collect_split_info(node) def gen_hash(node: IRFwOperation) -> str: ret = node.signature @@ -363,10 +362,11 @@ def gen_hash(node: IRFwOperation) -> str: while dq: cur_node, cur_ngpus = dq.popleft() gen_nodes.append(cur_node) + split_info = collect_split_info(cur_node) for key, val in split_info.items(): idx_1st, dim_1st, _ = val - dim_size = cur_node.inputs()[idx_1st].shape[dim_1st] + dim_size = cur_node.anno.getlen(key) # TODO(yizhu1): only consider powers of 2 currently split_deg = 2 @@ -374,7 +374,8 @@ def gen_hash(node: IRFwOperation) -> str: if dim_size % split_deg != 0: break - new_node = cur_node.algorithms('dim').instantiate(idx=idx_1st, dim=dim_1st, num=split_deg)[0] + new_nodes = cur_node.algorithms('dim').instantiate(idx=idx_1st, dim=dim_1st, num=split_deg) + new_node = new_nodes[0] new_ngpus = cur_ngpus // split_deg cur_key = gen_hash(new_node) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index c1fa5dc8..325fb468 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -808,6 +808,18 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s elements *= chain[dimidx] bracket.append(edims[dimidx]) dimidx += 1 + # fetch as many 1^ as possible from tail of the previous bracket + if len(bracket) == 0: + assert dimlen == 1, f"internal match error3: dimlen={dimlen}" + back = 0 + for edim in anno[-1][1:][::-1]: + if chain[edims.index(edim)] != 1: + break + back += 1 + assert back > 0, f"internal match error4: dimlen={dimlen}" + bracket = anno[-1][-back:] + anno[-1] = anno[-1][:-back] + assert len(bracket) > 0, f"got a dimension with no edim" anno.append(bracket) return anno @@ -959,6 +971,18 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s elements *= chain[dimidx] bracket.append(edims[dimidx]) dimidx += 1 + # fetch as many 1^ as possible from tail of the previous bracket + if len(bracket) == 0: + assert dimlen == 1, f"internal match error3: dimlen={dimlen}" + back = 0 + for edim in anno[-1][1:][::-1]: + if chain[edims.index(edim)] != 1: + break + back += 1 + assert back > 0, f"internal match error4: dimlen={dimlen}" + bracket = anno[-1][-back:] + anno[-1] = anno[-1][:-back] + assert len(bracket) > 0, f"got a dimension with no edim" anno.append(bracket) return anno @@ -1223,9 +1247,14 @@ def CubeStack(*tensors, dim=0, signature=None): assert isinstance(dim, int), f"but not {dim}" signature = 'cube.runtime.function.stack' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] - oannos = [copy.copy(iannos[-1])] - oannos[0].insert(dim, str(len(tensors))) - anno = OpAnno.create_op_str(iannos, oannos) + oanno = [None for i in range(len(tensors[0].shape) + 1)] + oanno[dim] = f'{len(tensors)}^' + offset = 0 + for i in range(len(oanno)): + if oanno[i] is None: + oanno[i] = copy.copy(iannos[-1][offset]) + offset += 1 + anno = OpAnno.create_op_str(iannos, [oanno]) return IRDimops(CubeStack, 'stack', signature, [anno], tensors, dim=dim) @@ -1303,7 +1332,8 @@ def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice]], signature=No # expand the dimension edim_ou.append('1') else: - edim_in[dim] += '^' + if slicer != slice(None, None, None): + edim_in[dim] += '^' if isinstance(slicer, slice): stop = tensor.shape[dim] if slicer.stop is None else slicer.stop start = 0 if slicer.start is None else slicer.start diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 44a290eb..ce815e9c 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -140,7 +140,7 @@ def unpack_hook(x): torch.cuda.synchronize() toc = time.perf_counter() fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds - bw_span = fwbw_span - fw_span + bw_span = max(fwbw_span - fw_span, 0.0) return fw_span, bw_span, infer_memory, tuple(train_mem_info) @@ -253,8 +253,11 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): print(f'WARNING: input {t} is skipped.') # run profiling - fw_span, bw_span, infer_memory, train_mem_info = \ - CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) + try: + fw_span, bw_span, infer_memory, train_mem_info = \ + CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) + except: + fw_span, bw_span, infer_memory, train_mem_info = float('inf'), float('inf'), 0, [0] # log to database key = self._serialize(node) self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index d52fd96f..fb9f2e40 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -26,6 +26,8 @@ import cube +bs, seql, ngpu = 2, 2048, 4 + # # build model parser = options.get_training_parser() args = options.parse_args_and_arch(parser) @@ -36,8 +38,10 @@ print("building model succeed: ", type(model)) # create dummy input -with open('/home/quzha/quzha/MagicCube/examples/nlp/torchscale/input_lm', 'rb') as f: +# with open('/home/quzha/quzha/MagicCube/examples/nlp/torchscale/input_lm', 'rb') as f: +with open(f'/home/quzha/torchscale_{bs}_{seql}.pkl', 'rb') as f: dummy_input = pickle.load(f) +dummy_input = dummy_input['net_input'] device = next(model.parameters()).device print(f'device = {device}') for key in dummy_input.keys(): @@ -177,8 +181,9 @@ class dotdict(dict): __getattr__ = dict.get __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ -config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': 'torchscale'}) -config.autodist_config = dotdict({'ngpus': 2}) + +config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': f'torchscale_{bs}_{seql}_{ngpu}'}) +config.autodist_config = dotdict({'ngpus': ngpu}) # NOTE add SINGLE_DEV_MODE=1 before the running command from autodist.cost_model.cost_database import CostDatabase cost_database = CostDatabase(cube_graph, config) @@ -199,13 +204,23 @@ def __init__(self, **kwargs): self.allow_recom_ops = [] self.del_dim = [] -kwargs = {'save_folder': 'exp_data', 'micro_batch_size': 8, 'global_batch_size': 8, 'iter_num': 2, - 'warm_num': 1, 'recompute': False, 'memory_constraint': 32, 'memory_granularity': 1, +kwargs = {'consider_mem': False, 'save_folder': 'exp_data', 'micro_batch_size': bs, 'global_batch_size': bs, 'iter_num': 2, + 'warm_num': 1, 'recompute': False, 'memory_constraint': 40, 'memory_granularity': 1, 'profile_dir': str(Path.home())+'/.autodist/', 'connect_type': 'NV2', 'use_prev_plan': False, - 'is_train': True, 'topk': 20, 'mesh_row': 1, 'mesh_col': 2, 'compile': True, 'pipeline': False, + 'is_train': True, 'topk': 20, 'mesh_row': 1, 'mesh_col': ngpu, 'compile': True, 'pipeline': False, 'nproc': 12, 'adaptive_recom': False, 'plan_idx': 0, 'verbose': True, 'ignore_small_tensor_threshold': 0, 'parse_plan': True} task_config = TorchscaleTaskConfig(**kwargs) from autodist.apis import calc_parallel_plan topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) + +best_plan = topk_plans[0][0].partition_descs + +from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation + +for node in cube_graph.select(ntype=IRFwOperation): + if node.cid in best_plan: + print(f'{node}, {node.anno}, autodist ret: {best_plan[node.cid]}') + else: + print(f'{node}, switch to default replica') \ No newline at end of file From 226509eba858a0e211b9b70ca90c6351e890e18d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 20 Mar 2023 02:41:19 +0000 Subject: [PATCH 1360/1892] Merged PR 1497: support load checkpointed weights support load checkpointed weights --- cube/graph/parser/parserfx.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 19b5dbbf..90ef2363 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -314,11 +314,16 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram tensor_name = node.name if 'tensor_meta' in node.meta: tensor_shape = node.meta['tensor_meta'].shape - #TODO assume it is weight dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) - ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=True, dtype=dtype) - ir_tensor.as_param() + requires_grad = node.meta['tensor_meta'].requires_grad + ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=requires_grad, dtype=dtype) + if requires_grad: # case for registered parameters + ir_tensor.as_param() + else: # case for registered buffers + ir_tensor.as_buffer() frame.add_var(tensor_name, ir_tensor) + value = FxModuleParser.fetch_attr(module, node.target) + frame.add_attr_content(ir_tensor.tid, value) else: var = FxModuleParser.fetch_attr(module, node.target) frame.add_var(tensor_name, var) From 8aec516901bad5c7a9d2c7cde06d7b0c5d24eeba Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Tue, 21 Mar 2023 01:07:44 +0000 Subject: [PATCH 1361/1892] Merged PR 1500: Refine torchscale spmd.PASData Removing special treatment on certain operators Adding auto multiref --- cube/graph/gener/gen.py | 4 ++- examples/nlp/torchscale/policy/spmd.py | 34 ++++++++------------------ 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 9badd327..064a9260 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -676,7 +676,9 @@ def autoref(graph: IRSegment) -> IRGraph: output.grad = multiref.output(idx).grad.parent.select(otensor.indmap, (0,1)) mr.set_output(idx, output) mr.device = otensor.device - mr.recompute = otensor.cell.recompute + # Dataloader has no recompute + if hasattr(otensor.cell, 'recompute'): + mr.recompute = otensor.cell.recompute multirefs.append(mr) # remove original multiref fidx = graph.remove(multiref) diff --git a/examples/nlp/torchscale/policy/spmd.py b/examples/nlp/torchscale/policy/spmd.py index f9cc611e..15e39c53 100644 --- a/examples/nlp/torchscale/policy/spmd.py +++ b/examples/nlp/torchscale/policy/spmd.py @@ -66,6 +66,13 @@ def PASData(graph: IRGraph, resource): """ Data Parallel """ + # auto multi-ref + for ftensor in graph.full_tensors(): + if len(graph.consumers(ftensor)) > 1: + if ftensor.is_attr(): + continue + graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) + for node in graph.nodes(): if isinstance(node, IRDataOperation): algo = node.algorithms('data') @@ -77,30 +84,9 @@ def PASData(graph: IRGraph, resource): if isinstance(node, IRFwOperation): try: algo = node.algorithms('dim') - - must_replicate = False - for itensor in node.inputs(): - if not isinstance(itensor, IRTensor): - continue - - print(f'itersor = {itensor}') - for consumer in graph.consumers(itensor.parent): - if consumer.name == 'fullslice': - must_replicate = True - break - if must_replicate == True: - break - - if must_replicate: - print(f'##### must_replicate {node.name}') - sub_nodes = graph.replicate(node, resource.ngpus) - else: - idx = 0 - if node.name in {'type_as'}: - print(f"###### {node.name}") - idx = 1 - sub_nodes = graph.partition( - node, algo, idx=idx, dim=batch_dim, num=resource.ngpus) + idx = 0 + sub_nodes = graph.partition( + node, algo, idx=idx, dim=batch_dim, num=resource.ngpus) except AssertionError: print(f'WARNING: {node} cannot find dim algo, using replicate instead') sub_nodes = graph.replicate(node, resource.ngpus) From 3e81df9a413f42458028eff0170554b03331905e Mon Sep 17 00:00:00 2001 From: rwlu Date: Wed, 22 Mar 2023 01:02:35 -0700 Subject: [PATCH 1362/1892] implement fined_grained profile functions --- cube/profiler/database.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index b0488d11..f19f4a9b 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -408,6 +408,21 @@ def dump(self, file: str, override=False): with open(file, 'w') as f: json.dump(self._data, f) + + def dump_fine_grained(self, file: str, override=False): + if os.path.exists(file): + assert override, f"File {file} exists. Set override = True to force dump." + for signature in self._data.keys(): + file_n = os.path.join(file, signature +'.json') + with open(file_n, 'w') as f: + json.dump(self._data[signature],f) + + def dump_single(self, file: str, signature ,override=False): + assert signature in self._data.keys(), f'this node not be profiled' + file_n = os.path.join(file, signature +'.json') + with open(file_n, 'w') as f: + json.dump(self._data[signature],f) + def load(self, file: str): """! load the profiled data into data base. The original existed one will be @@ -418,6 +433,13 @@ def load(self, file: str): with open(file, 'r') as f: self._data = json.load(f) + def load_fine_grained(self, file: str): + for filename in os.listdir(file): + if filename.endswith('.json'): + with open(os.path.join(file, filename)) as f: + signature = filename[:-len('.json')] + self._data[signature] = json.load(f) + def __repr__(self) -> str: data = [] for signature in self._data: From 0acabc14e1349200bec06c29a9c2fb43a853689a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 22 Mar 2023 08:46:32 +0000 Subject: [PATCH 1363/1892] Merged PR 1501: Fix dummy input output 1) FX parser allow to give IRObject input 2) add input/output IRObject into producer and consumer 2) dummy input/ouput operators are not inserted into segment --- cube/graph/gener/gen.py | 68 +++++++++++++++++------------------ cube/graph/graph.py | 19 ++++++---- cube/graph/parser/parserfx.py | 11 ++++-- cube/graph/segment.py | 40 ++++++++++++++++----- cube/program.py | 5 +++ 5 files changed, 89 insertions(+), 54 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 064a9260..7dc39513 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -6,7 +6,7 @@ from cube.graph.gener.concurrent import ConcurrentGener import cube.graph.gener.utils as utils from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment +from cube.graph.segment import IRSegment, CellPosition from cube.graph.function.pyfunc import IRPyFunc from cube.ir.cten import IRCell, IRObject @@ -23,7 +23,6 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) -> List[IRFwOperation]: """ Create dummy operators segment inputs and outputs. - The backward operator is also inserted. @param segment IRSegment: the target segment @param inputs bool: True for creating dummy operators to produce segement's inputs @@ -32,8 +31,8 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) @return nodes List[IRCell]: the generated operation """ # devices = segment.device - fwops = [] - + input_producers: Dict[IRFullTensor, List[IRCell]] = {} + output_consumers: Dict[IRFullTensor, List[IRCell]] = {} # create inputs if inputs: input_objects = IRGraph.get_objects_from_complex(segment.inputs()) @@ -46,12 +45,9 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) fop = fwop.replicate() fop.device = devid if tensor.requires_grad: - fop.output(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) - segment.finsert(fop, 0) - else: - segment.insert(fop, 0) - fwops.append(fop) - + fop.output(0).grad = tensor.parent.grad.select(tensor.indmap, (0, 1)) + fop.output(0).grad.cell = fop + input_producers.setdefault(tensor.parent, []).append(fop) # create outputs if outputs: output_objects = IRGraph.get_objects_from_complex(segment.outputs()) @@ -63,17 +59,14 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) for devid in devices: fop = fwop.replicate() fop.device = devid - if tensor.requires_grad and segment.mirror != segment: - fop.input(0).grad = tensor.grad.select(tensor.indmap, (0, 1)) - segment.finsert(fop, segment.nnodes) - else: - segment.insert(fop, segment.nnodes) - fwops.append(fop) - - return fwops + if tensor.requires_grad: + fop.input(0).grad = tensor.parent.grad.select(tensor.indmap, (0, 1)) + fop.input(0).grad.cell = fop + output_consumers.setdefault(tensor.parent, []).append(fop) + return input_producers, output_consumers -def expand_devices(tensors: List[IRSubTensor], +def expand_devices(tensors: List[Optional[IRSubTensor]], producer: bool = False, consumer: bool = False) -> List[IRSubTensor]: """ Scatter a tensor if it is on multiple devices. It produces a tensor list where @@ -87,6 +80,7 @@ def expand_devices(tensors: List[IRSubTensor], """ dtensors = [] for tensor in tensors: + if tensor is None: continue if len(tensor.device) == 1: dtensors.append(tensor) continue @@ -301,8 +295,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: return False return True - fdummies = create_dummy(graph, inputs=True, outputs=True) - bdummies = [fwop.mirror for fwop in fdummies if fwop.mirror is not None] + input_producer, output_consumer = create_dummy(graph, inputs=True, outputs=True) bgraph: Optional[IRSegment] = graph.mirror # local producer fusion and local consumer multiref @@ -321,7 +314,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # reorder again since inserted multiref could be mis-ordered graph._reorder_producer_consumer() - # generate adapter for inter-segments + # generate adapter for intra-segments # FIXME: assume producers and consumers can run in parallel for ftensor in ftensors: @@ -330,9 +323,15 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # producers can be operators and graph inputs fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) + if ftensor in input_producer: + fptensors = fptensors + tuple(fop.output(0) for fop in input_producer[ftensor]) fptensors = expand_devices(fptensors, producer=True) assert all(len(ptensor.device) == 1 for ptensor in fptensors), "Not support for multi-device" + + # consumers can be operators and graph outputs fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) + if ftensor in output_consumer: + fctensors = fctensors + tuple(fwop.input(0) for fwop in output_consumer[ftensor]) fctensors = expand_devices(fctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" @@ -340,13 +339,17 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bconsumers, bctensors = [], [] if isinstance(ftensor.grad, IRFullTensor): bproducers, bptensors = bgraph.producers(ftensor.grad), bgraph.ptensors(ftensor.grad) + if ftensor in output_consumer: + bptensors = bptensors + tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) bptensors = expand_devices(bptensors, producer=True) assert all(len(ptensor.device) == 1 for ptensor in bptensors), ( f"Not support for multi-device:\n" f"{[ptensor.device for ptensor in bptensors]}" f"{[ptensor.cell for ptensor in bptensors]}" ) - bconsumers, bctensors = bgraph.consumers(ftensor.grad), bgraph.ctensors(ftensor.grad) + bconsumers, bctensors = bgraph.consumers(ftensor.grad), bgraph.ctensors(ftensor.grad) + if ftensor in input_producer: + bctensors = bctensors + tuple(fwop.output(0).grad for fwop in input_producer[ftensor]) bctensors = expand_devices(bctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" @@ -376,29 +379,22 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # insert forward adapter # graph.insert(fadapter, max(producers) + 1) - fidx = min(graph.index(c) for c in fconsumers) + tail = CellPosition((graph.nnodes,)) + fconsumer_idx = [tail] + [graph.index(c) for c in fconsumers] + fidx = min(fconsumer_idx) # setup recompute if fadapter.differentiable and allow_recompute: - fadapter.recompute = graph.node(fidx).recompute + fadapter.recompute = graph.node(fidx if fidx != tail else fidx-1).recompute graph.insert(fadapter, fidx) # insert backward adapter if badapter is not None: assert isinstance(badapter, IRAdapter) assert isinstance(bgraph, IRSegment) - bproducers = [ - bgraph.index(consumer.mirror) + 1 for \ - consumer in graph.consumers(ftensor) - ] - bidx = max(bproducers) if len(bproducers) > 0 else 0 + bproducer_idx = [CellPosition((0,))] + [bgraph.index(consumer.mirror) + 1 for consumer in fconsumers] + bidx = max(bproducer_idx) bgraph.insert(badapter, bidx) - # remove dummy op - for dummy_op in fdummies: - graph.remove(dummy_op) - for dummy_op in bdummies: - bgraph.remove(dummy_op) - # generate adapter for each segment segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] for segment in segments: diff --git a/cube/graph/graph.py b/cube/graph/graph.py index de6b9239..35680f6d 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -81,8 +81,14 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: else: raise RuntimeError('len(args) < len(itensors)') + arg_objs = IRGraph.get_objects_from_complex(args) + graph_objs = IRGraph.get_objects_from_complex(self.inputs()) + assert len(arg_objs) == len(graph_objs), f"input object number not match: {len(arg_objs)} != {len(graph_objs)}" + for idx, (itensor, arg) in enumerate(zip(itensors, args)): self.set_input(idx, arg) + + for arg, itensor in zip(arg_objs, graph_objs): for producer in self.producers(itensor.parent): with self.update(producer): while itensor in producer.outputs(): @@ -93,13 +99,12 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: while itensor in consumer.inputs(): iidx = consumer.inputs().index(itensor) consumer.set_input(iidx, arg) - while itensor in self.outputs(): - oidx = self.outputs().index(itensor) - self.set_output(oidx, arg) - while itensor in self.inputs(): - iidx = self.inputs().index(itensor) - self.set_input(iidx, arg) - + # reset output + for oidx, output in enumerate(self.outputs()): + output = IRGraph.modify_objects_of_complex( + self.output(oidx), lambda t: t if t != itensor else arg) + self.set_output(oidx, output) + # dtype inference for node in self._nodes: # reset input diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 90ef2363..e7855453 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -103,9 +103,14 @@ def parse(module: torch.fx.GraphModule, # handle graph inputs for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) - shape = None if (input_shapes is None or len(input_shapes) <= idx) else input_shapes[idx] - dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + if 'tensor_meta' in input.meta: # tensor type + shape = None if len(input_shapes) <= idx else input_shapes[idx] + if shape is not None and len(shape) == 0: + shape = [1] + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + else: + val = IRObject(input.name) frame.add_var(input.name, val, graph_arg=idx) else: assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 00a3de1b..88461b46 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -76,16 +76,11 @@ class IRSegment(IRCell): Inserting and removing nodes that could change input/output are not allowed. """ - def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRSubTensor], name='segment'): + def __init__(self, nodes: List[IRCell], inputs: List[Any], outputs: List[Any], name='segment'): super().__init__(name, '', len(inputs), len(outputs), init_outputs=False) self._nodes: List[IRCell] = [] - for idx, val in enumerate(inputs): - self.set_input(idx, val) - for idx, val in enumerate(outputs): - self.set_output(idx, val) - # full-tensor / sub-tensor mapping self._ftensors: Set[IRFullTensor] = set() self._producers: Dict[IRFullTensor, List[IRCell]] = dict() @@ -96,6 +91,14 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR # attributes self._attributes: Set[IRFullTensor] = set() + for idx, val in enumerate(inputs): + self.set_input(idx, val) + for idx, val in enumerate(outputs): + self.set_output(idx, val) + + for t in IRSegment.get_objects_from_complex(list(inputs) + list(outputs)): + self._add_ftensor(t.parent) + for node in nodes: self.insert(node, self.nnodes) @@ -103,6 +106,16 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IR # self.reset_dependency() + def set_input(self, idx: int, val: Any): + for t in IRSegment.get_objects_from_complex(val): + self._add_ftensor(t.parent) + return super().set_input(idx, val) + + def set_output(self, idx: int, val: Any): + for t in IRSegment.get_objects_from_complex(val): + self._add_ftensor(t.parent) + return super().set_output(idx, val) + def isfw(self) -> bool: return all(n.isfw() for n in self._nodes) # return self._have_forward @@ -414,6 +427,13 @@ def _reorder_producer_consumer(self): self._ftensors, self._attributes = set(), set() self._producers, self._ptensors = dict(), dict() self._consumers, self._ctensors = dict(), dict() + + # set input and output + for obj in IRSegment.get_objects_from_complex(self.inputs()): + self._add_ftensor(obj.parent) + for obj in IRSegment.get_objects_from_complex(self.outputs()): + self._add_ftensor(obj.parent) + # set producer and consumer for node in self._nodes: if isinstance(node, IRAdapter): continue @@ -833,7 +853,7 @@ def single_consume(self, one_for_all: bool = True): # ====================== Graph Generations ============================ @staticmethod - def get_inputs(nodes: List[IRCell]): + def get_inputs(nodes: List[IRCell], exclude_attr: bool = True): """ Get all the input tensors that are required by nodes. @@ -848,13 +868,15 @@ def get_inputs(nodes: List[IRCell]): for node in nodes: for input in node.inputs(): if isinstance(input, IRTensor): + if exclude_attr and input.is_attr(): + continue if input not in all_outputs: if input not in inputs: inputs.append(input) return inputs @staticmethod - def get_outputs(nodes: List[IRCell]): + def get_outputs(nodes: List[IRCell], exclude_attr: bool = True): """ Get tensors that are produced but not consumed by nodes @@ -874,6 +896,8 @@ def get_outputs(nodes: List[IRCell]): for output in node.outputs(): # not consumed tensor if isinstance(output, IRTensor): + if exclude_attr and output.is_attr(): + continue if output not in all_inputs: if output not in outputs: outputs.append(output) diff --git a/cube/program.py b/cube/program.py index b14075a0..f669b887 100644 --- a/cube/program.py +++ b/cube/program.py @@ -44,6 +44,11 @@ def add_nodes(self, nodes: List[IRCell]): def get_graph(self) -> IRGraph: return self.instance._graph + def set_input(self, inputs: Tuple[Any]): + self.instance._graph.reset_inputs(len(inputs)) + for idx, obj in enumerate(inputs): + self.instance._graph.set_input(idx, obj) + def set_output(self, outputs: Tuple[Any]): self.instance._graph.reset_outputs(len(outputs)) for idx, otensor in enumerate(outputs): From a7f88b27acf002f7b430ede7a25c894a446246e3 Mon Sep 17 00:00:00 2001 From: rwlu Date: Wed, 22 Mar 2023 05:18:11 -0700 Subject: [PATCH 1364/1892] change the name of apis --- cube/profiler/database.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index f19f4a9b..c2a65243 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -409,7 +409,7 @@ def dump(self, file: str, override=False): json.dump(self._data, f) - def dump_fine_grained(self, file: str, override=False): + def dump_nodes(self, file: str, override=False): if os.path.exists(file): assert override, f"File {file} exists. Set override = True to force dump." for signature in self._data.keys(): @@ -417,7 +417,7 @@ def dump_fine_grained(self, file: str, override=False): with open(file_n, 'w') as f: json.dump(self._data[signature],f) - def dump_single(self, file: str, signature ,override=False): + def dump_node(self, file: str, signature ,override=False): assert signature in self._data.keys(), f'this node not be profiled' file_n = os.path.join(file, signature +'.json') with open(file_n, 'w') as f: @@ -433,13 +433,13 @@ def load(self, file: str): with open(file, 'r') as f: self._data = json.load(f) - def load_fine_grained(self, file: str): + def load_nodes(self, file: str): for filename in os.listdir(file): if filename.endswith('.json'): with open(os.path.join(file, filename)) as f: signature = filename[:-len('.json')] self._data[signature] = json.load(f) - + def __repr__(self) -> str: data = [] for signature in self._data: From eca7d97ffbc9d975ee62dbc3a172623cabc6d304 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 22 Mar 2023 22:52:37 -0700 Subject: [PATCH 1365/1892] save work --- examples/nlp/torchscale/run_torchscale_lm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index fb9f2e40..643d5289 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -26,7 +26,8 @@ import cube -bs, seql, ngpu = 2, 2048, 4 +# bs, seql, ngpu = 2, 2048, 4 +bs, seql, hidden_dim, ngpu = 2, 2048, 2048, 8 # # build model parser = options.get_training_parser() @@ -182,7 +183,7 @@ class dotdict(dict): __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ -config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': f'torchscale_{bs}_{seql}_{ngpu}'}) +config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': f'torchscale_{bs}_{seql}_{hidden_dim}_{ngpu}'}) config.autodist_config = dotdict({'ngpus': ngpu}) # NOTE add SINGLE_DEV_MODE=1 before the running command from autodist.cost_model.cost_database import CostDatabase From 7595d2e180dc1d5d4e275cb1705e5afc634c7fcc Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 22 Mar 2023 23:27:41 -0700 Subject: [PATCH 1366/1892] refine code --- cube/graph/function/function.py | 1 - cube/graph/parser/parserfx.py | 4 ---- cube/runtime/function/function.py | 1 + 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 3da528bc..0316929b 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -382,7 +382,6 @@ def Sub(input, other, alpha=1, *, out=None, signature = None): def Mul(input, other, *, out=None, signature = None): assert out is None - signature = 'torch.mul' if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input * other signature = 'torch.mul' diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 23326c09..e7855453 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -274,10 +274,6 @@ def get_complex_data(val: Any) -> Any: input_vals = [get_complex_data(val) for val in node.args] kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} - # if 'int' in fsig: - # print(fsig) - # exit(1) - # map to IR operator if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 95289f52..813985a7 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -37,6 +37,7 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: else: return torch.sum(torch.stack(tensors, dim=0), dim=0) + def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: return input.expand(*sizes) From 730e9b10f96f74c4896a9597c5a610bbabeffacc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 23 Mar 2023 12:26:02 +0000 Subject: [PATCH 1367/1892] Merged PR 1505: Support generate code for evaluation mode running examples: ```python @cube.compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) loss.backward() for _ in range(num_train_iters): loss = train_iter(model, dataloader) eval_fn = cube.load_eval_schedule() for step in range(4): loss = eval_fn(model, dataloader) print(f'iter [{step + 1}/{iter_num}]) ``` --- cube/__init__.py | 1 + cube/codegen/schedule/schedule.py | 43 ++++++++++++++++++++++++++----- cube/compiler.py | 13 ++-------- cube/utils.py | 34 ++++++++++++++++++++++++ 4 files changed, 74 insertions(+), 17 deletions(-) create mode 100644 cube/utils.py diff --git a/cube/__init__.py b/cube/__init__.py index d2c5bbd7..4607f7d4 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -3,6 +3,7 @@ from cube import profiler from cube.compiler import SemanticModel, compile +from cube.utils import load_model, load_default_schedule, load_eval_schedule def _check_torch_version(): diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index 90a9a56e..f73d5467 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -1,6 +1,7 @@ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any import copy +import warnings from cube.ir.cten import IRCell, IRTensor from cube.ir.operator import IRDataOperation, IRFwOperation @@ -54,10 +55,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: else: for line, node in enumerate(device_nodes): # execute - if isinstance(node, ExeRepetend): - codes = ScheduleCodeGen.emit_repetend(node) - else: - codes = ScheduleCodeGen.emit_node(node) + codes = ScheduleCodeGen.emit_node(node) fb.insert_body(codes) # release tensors = lifetime.release_tensors_after_line(line) @@ -69,6 +67,38 @@ def gen(self, device: int, outfile=None, attach=None) -> str: code = f'return {outputs}' fb.insert_body(code) gencode += fb.code + + gencode += ['', ''] + + # infer code + if not any(not node.isfw() for node in device_nodes): + gencode += ['_infer_step = _train_step'] + else: + # legacy hardcode strategy + if isinstance(self.execplan.graph.sched, IRScheduleStrategy): + warnings.warn('using legacy IRScheduleStrategy cannot generate inference code. ' + 'Switch to use scheduling without strategy') + with FunctionBlock(func_name='_infer_step', + args=['model', 'dataloader']) as fb: + fb.insert_body('_ = None') + # body code + if len(device_nodes) == 0: + fb.insert_body('pass') + for line, node in enumerate(device_nodes): + if not node.isfw(): continue # skip backward segments and adapters + # execute + codes = ScheduleCodeGen.emit_node(node, force_no_grad=True) + fb.insert_body(codes) + # release + tensors = lifetime.release_tensors_after_line(line) + tensors = [t for t in tensors if not t.is_grad()] + if len(tensors) > 0 : # not necessarily to have one after each line + fb.insert_body(ScheduleCodeGen.emit_release(tensors)) + # return code + outputs = ScheduleCodeGen.return_name_complex(self.execplan.graph.outputs()) + code = f'return {outputs}' + fb.insert_body(code) + gencode += fb.code gencode += [''] code = '\n'.join(gencode) @@ -79,7 +109,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: return code @staticmethod - def emit_node(node: IRCell) -> List[str]: + def emit_node(node: IRCell, force_no_grad: bool = False) -> List[str]: """ Emit node / subgraph code """ @@ -89,6 +119,7 @@ def emit_node(node: IRCell) -> List[str]: node_inputs, node_outputs = node.inputs(), node.outputs() req_grad = any(t.requires_grad for t in node.outputs() if isinstance(t, IRTensor)) + req_grad = False if force_no_grad else req_grad # handle for forward inputs = ScheduleCodeGen.tuple_name(node_inputs, skip_attr=True, prefix_attr='model.') diff --git a/cube/compiler.py b/cube/compiler.py index 238f6950..9eee8d5f 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -80,15 +80,6 @@ def train_step(model, dataloader): myrank = DeviceGroup().rank - def _load_tschedule_fn(filename) -> Callable: - import importlib.util - spec = importlib.util.spec_from_file_location( - "_train_step", filename - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module._train_step - def decorator(fn: Callable) -> Callable: filename = 'gencode{}.py' @@ -101,7 +92,7 @@ def decorator(fn: Callable) -> Callable: model.load_module(filename) # load schedule code print_each_rank(f'loading existed schedule from {filename} ...') - return _load_tschedule_fn(filename) + return cube.load_default_schedule(filename) if DeviceGroup().local_rank == 0: @@ -242,6 +233,6 @@ def decorator(fn: Callable) -> Callable: # load temporal schedule print_each_rank(f'loading generated schedule from {filename} ...') - return _load_tschedule_fn(filename) + return cube.load_default_schedule(filename) return decorator diff --git a/cube/utils.py b/cube/utils.py new file mode 100644 index 00000000..623c6431 --- /dev/null +++ b/cube/utils.py @@ -0,0 +1,34 @@ +from typing import Optional +from cube.profiler.timer import print_each_rank +from cube.runtime.device import DeviceGroup + + +def _load_module_attr(filename: str, name: str): + import importlib.util + spec = importlib.util.spec_from_file_location(name, filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def load_model(filename: Optional[str] = None, load_content: bool = True): + filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename + module = _load_module_attr(filename, 'GenModel') + loaded_module = module.GenModel().cuda() + if load_content: + print_each_rank("> loading parameter content...") + loaded_module.load_attr_content('./fullmodel.pt') + return loaded_module + + +def load_default_schedule(filename: Optional[str] = None): + filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename + module = _load_module_attr(filename, '_train_step') + return module._train_step + + +def load_eval_schedule(filename: Optional[str] = None): + filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename + module = _load_module_attr(filename, '_infer_step') + return module._infer_step + From c30d8a7ccee69502c57d2d83bf181e5730fde46f Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 23 Mar 2023 05:37:24 -0700 Subject: [PATCH 1368/1892] typeas to To --- cube/graph/function/function.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 0316929b..3a6ba762 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1049,13 +1049,9 @@ def Unsqueeze(input, dim, signature = None): def TypeAs(input: IRTensor, tensor: IRTensor, signature = None): """ - out = torch.Tensor.type_as(tensor0, tensor1) + translate to To """ - edim_in0 = ShapeAnno.create_shape_str(tensor.shape) - anno = OpAnno.create_op_str(['*', edim_in0], ['*']) - dimop = IRDimops(TypeAs, 'type_as', signature, [anno], [input, tensor]) - dimop.output(0).requires_grad = input.requires_grad - return dimop + return To(input, tensor) def Triu(input, diagonal=0, *, out=None, signature = None): From d244ad60d8e14e1d9433d667c97800fc58fda1b0 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 23 Mar 2023 05:39:59 -0700 Subject: [PATCH 1369/1892] typeas to to --- cube/graph/function/function.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index d0cef82d..4f82f361 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1034,13 +1034,9 @@ def Unsqueeze(input, dim, signature = None): def TypeAs(input: IRTensor, tensor: IRTensor, signature = None): """ - out = torch.Tensor.type_as(tensor0, tensor1) + translate to To """ - edim_in0 = ShapeAnno.create_shape_str(tensor.shape) - anno = OpAnno.create_op_str(['*', edim_in0], ['*']) - dimop = IRDimops(TypeAs, 'type_as', signature, [anno], [input, tensor]) - dimop.output(0).requires_grad = input.requires_grad - return dimop + return To(input, tensor) def Triu(input, diagonal=0, *, out=None, signature = None): From 3c90bebae7c61f17db5d2660b16e4fc744a7f3c0 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 24 Mar 2023 02:43:06 -0700 Subject: [PATCH 1370/1892] for pr --- examples/nlp/torchscale/run_torchscale_lm.py | 217 ++++++------------ tests/parser/yizhu1_bloom.py | 219 ------------------- 2 files changed, 66 insertions(+), 370 deletions(-) delete mode 100644 tests/parser/yizhu1_bloom.py diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py index 74b58350..90c679d4 100644 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ b/examples/nlp/torchscale/run_torchscale_lm.py @@ -1,6 +1,4 @@ -# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:/home/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale:$PYTHONPATH python3 run_torchscale_lm.py /home/quzha/MagicCube/examples/nlp/torchscale/input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData - - +# single GPU inference debug # USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData # multi-GPU inference test # OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 --master_port=25642 examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --subln --xpos-rel-pos --fp16 --policy PASData @@ -9,7 +7,6 @@ # multi-GPU training test # OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 --master_port=25642 examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --subln --xpos-rel-pos --fp16 --policy PASData --do_train -from pathlib import Path import torch import pickle from fairseq import ( @@ -22,33 +19,63 @@ from fairseq.data import iterators import sys +import os # https://github.com/microsoft/torchscale/tree/main/examples/fairseq -# sys.path.append('/home/quzha/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') -# sys.path.append('/home/quzha/quzha/MagicCube/examples/nlp/torchscale/torchscaletest/torchscale') -sys.path.append('/home/quzha/quzha/torchscale/examples/fairseq') +sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') +sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale') print(f'sys.path = {sys.path}') import models +#:torchscaletest/torchscale import cube +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +sys.path.append('.') +from policy import mpmd, spmd +# import examples.nlp.torchscale.policy.spmd as spmd + +# import argparse -# bs, seql, ngpu = 2, 2048, 4 -bs, seql, hidden_dim, ngpu = 2, 2048, 2048, 8 +# parser = argparse.ArgumentParser(description='comm primitive') +# parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +# parser.add_argument('--local_rank', type=int, default=0) +# args = parser.parse_args() -# # build model +# build model parser = options.get_training_parser() +parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +parser.add_argument('--do_train', action='store_true', default=False) +# parser.add_argument('--local_rank', type=int, default=0) + args = options.parse_args_and_arch(parser) +print(f"Running mode: {'TRAIN' if args.do_train else 'EVAL'}") + +cube.init() +# set up policy +PAS = None +policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) +if args.policy in spmd.__dict__: + PAS = spmd.__dict__[args.policy] + print_each_rank(f'using policy from spmd.{args.policy}') +elif args.policy in mpmd.__dict__: + PAS = mpmd.__dict__[args.policy] + print_each_rank(f'using policy from mpmd.{args.policy}') +else: + raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + cfg = convert_namespace_to_omegaconf(args) task = tasks.setup_task(cfg.task) model = task.build_model(cfg.model) -model.train() +if args.do_train: + model.train() +else: + model.eval() print("building model succeed: ", type(model)) # create dummy input -# with open('/home/quzha/quzha/MagicCube/examples/nlp/torchscale/input_lm', 'rb') as f: -with open(f'/home/quzha/torchscale_{bs}_{seql}.pkl', 'rb') as f: +with open('examples/nlp/torchscale/input_lm', 'rb') as f: dummy_input = pickle.load(f) -dummy_input = dummy_input['net_input'] device = next(model.parameters()).device print(f'device = {device}') for key in dummy_input.keys(): @@ -94,141 +121,29 @@ raise RuntimeError(f'To fix sample_input with type{type(sample_input)}') -# model = cube.SemanticModel( -# #TODO fix me model, dummy_input=sample_input_cpu, -# # model, dummy_input=dummy_input_list, -# model, dummy_input=dummy_input, -# ) - -# @cube.compile(model, dataloader, PAS=PAS, load_content=False) -# def train_iter(model, dataloader): -# data = next(dataloader) -# loss = model(*data) -# # loss.backward() -# return loss - -# model = model.get_gen_module() - -# iter_ret = train_iter(model, dataloader) -# print(f'iter_ret = {iter_ret}') - -# Conduct concrete trace below -# sys.path.append('/home/v-junliang/torchscaletest/nni') -# sys.path.append('./torchscaletest/nni') -# from nni.common.concrete_trace_utils import concrete_trace -# from concrete_trace_utils import concrete_trace -# from examples.nlp.torchscale.concrete_trace_utils import concrete_trace -# import examples.nlp.torchscale.torchscaletest.torchscale -from cube.graph.parser.concrete_trace_utils.concrete_tracer import concrete_trace - -def check_equal(a, b): - if type(a) != type(b): - return False - if isinstance(a, (list, tuple, set)): - if len(a) != len(b): - return False - for sub_a, sub_b in zip(a, b): - if not check_equal(sub_a, sub_b): - return False - return True - elif isinstance(a, dict): - keys_a, kes_b = set(a.keys()), set(b.keys()) - if keys_a != kes_b: - return False - for key in keys_a: - if not check_equal(a[key], b[key]): - return False - return True - elif isinstance(a, torch.Tensor): - return torch.equal(a, b) - else: - return a == b - -print("start tracing...") -traced_graph = concrete_trace( - model, - dummy_input, - use_operator_patch=True, - autowrap_leaf_class={ - torch.finfo: ((), False), - type(output_origin): ((), False), - }, +model = cube.SemanticModel( + model, dummy_input=dummy_input, ) -print("trace succeed") -print("checking equal...") -with torch.no_grad(): - output_traced = traced_graph(**dummy_input) -# assert check_equal(output_origin, output_traced), "check equal failed" -print("checked") - -# check graph -traced_graph.graph.print_tabular() - -print("parsing fx graph to cube graph...") -from cube.graph.parser import FxModuleParser -inputs, nodes, outputs = FxModuleParser.parse(traced_graph, dummy_inputs=dummy_input) -print("parsing done.") -from cube.graph import IRGraph -module_name = model.__class__.__name__ -cube_graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) -print("generating cube ir graph done.") - -# move simple type inputs to kwargs -# for node in cube_graph.nodes - -# AutoDist -# # profile communication cost -# import os -# comm_gpu_num = (2, 4) -# for gpu_num in comm_gpu_num: -# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --connect_type=NV') -# profile computation cost -class dotdict(dict): - """dot.notation access to dictionary attributes""" - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - -# config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': f'torchscale_{bs}_{seql}_{ngpu}'}) -config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': f'torchscale_{bs}_{seql}_{hidden_dim}_{ngpu}'}) -config.autodist_config = dotdict({'ngpus': ngpu}) -# NOTE add SINGLE_DEV_MODE=1 before the running command -from autodist.cost_model.cost_database import CostDatabase -cost_database = CostDatabase(cube_graph, config) - -# find the best partition plan -from autodist.task_config import TaskConfig -class TorchscaleTaskConfig(TaskConfig): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.model = 'Bloom' - # self.Bloom_setting = kwargs['Bloom_setting'] - # self.fine_grained_Bloom = kwargs['fine_grained_Bloom'] - # self.bloom_config = build_bloom_config(self.Bloom_setting) - self.task_name = f'torchscale-{self.autodist_config.ngpus}gpu-'\ - f'{self.autodist_config.micro_batch_size}batch_size' - self.estimated_fname, self.backup_fname, self.runtime_fname = self._build_file_name( - self.task_name) - self.allow_recom_ops = [] - self.del_dim = [] - -kwargs = {'consider_mem': False, 'save_folder': 'exp_data', 'micro_batch_size': bs, 'global_batch_size': bs, 'iter_num': 2, - 'warm_num': 1, 'recompute': False, 'memory_constraint': 40, 'memory_granularity': 1, - 'profile_dir': str(Path.home())+'/.autodist/', 'connect_type': 'NV2', 'use_prev_plan': False, - 'is_train': True, 'topk': 20, 'mesh_row': 1, 'mesh_col': ngpu, 'compile': True, 'pipeline': False, - 'nproc': 12, 'adaptive_recom': False, 'plan_idx': 0, 'verbose': True, 'ignore_small_tensor_threshold': 0, - 'parse_plan': True} - -task_config = TorchscaleTaskConfig(**kwargs) -from autodist.apis import calc_parallel_plan -topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) - -best_plan = topk_plans[0][0].partition_descs - -from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation - -for node in cube_graph.select(ntype=IRFwOperation): - if node.cid in best_plan: - print(f'{node}, {node.anno}, autodist ret: {best_plan[node.cid]}') - else: - print(f'{node}, switch to default replica') + +if args.do_train: + @cube.compile(model, dataloader, PAS=PAS, load_content=False, override=True) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(*data) + loss.backward() + # TODO fix loss.mirror DummyInputOutput issue + + model = model.get_gen_module() + train_iter(model, dataloader) +else: # do_eval + @cube.compile(model, dataloader, PAS=PAS, load_content=False, override=True) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(*data) + return loss + + model = model.get_gen_module() + iter_ret = train_iter(model, dataloader) + print(f'iter_ret = {iter_ret}') + +print('DONE') \ No newline at end of file diff --git a/tests/parser/yizhu1_bloom.py b/tests/parser/yizhu1_bloom.py deleted file mode 100644 index 5bb44b14..00000000 --- a/tests/parser/yizhu1_bloom.py +++ /dev/null @@ -1,219 +0,0 @@ -from pathlib import Path -import torch -import torch.nn as nn -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig -import time -import json - -def convert_mem_into_GB(mem): - if type(mem) in [int, float]: - return mem / 1024 / 1024 / 1024 - else: - return [x / 1024 / 1024 / 1024 for x in mem] - -model_name = "bigscience/bloom-560m" -model_path = "/home/quzha/bloom560m" - -print("Loading model...") -model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, cache_dir=model_path) -print(type(model), '; is nn.Module? ', isinstance(model, nn.Module)) -print("Model's generation config which does not list default values: ", model.generation_config) -print("Loading tokenizer...") -tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) -print("Loading Done!") -prompt = "If I want to travel to a new city, I should plan my trip as follows:" -#input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() -inputs = tokenizer(prompt, return_tensors="pt") - -# Cube -# from cube.graph import parser -# ir_graph = parser.convert_model(model, input_shapes=[1, 17], save_content=False) - -print("concrete tracing model...") -import sys -nni_path = "/home/quzha/yizhu1/yizhu1_autodist/nni/" -sys.path.append(nni_path) - -# from concrete_trace_utils import concrete_trace -from nni.common.concrete_trace_utils import concrete_trace -# from cube.graph.parser.concrete_trace_utils import concrete_trace - -traced_graph = concrete_trace(model, inputs, use_operator_patch=True, - autowrap_leaf_class={torch.finfo: ((), False)}) -print("tracing model done.") -# print(traced_graph) - -print("parsing fx graph to cube graph...") -from cube.graph.parser import FxModuleParser -inputs, nodes, outputs = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) -print("parsing done.") -from cube.graph import IRGraph -module_name = model.__class__.__name__ -cube_graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) - -# AutoDist -# # profile communication cost -# import os -# comm_gpu_num = (2, 4) -# for gpu_num in comm_gpu_num: -# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --connect_type=NV') -# profile computation cost -class dotdict(dict): - """dot.notation access to dictionary attributes""" - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ -config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': 'bloom'}) -config.autodist_config = dotdict({'ngpus': 2}) -# NOTE add SINGLE_DEV_MODE=1 before the running command -from autodist.cost_model.cost_database import CostDatabase -cost_database = CostDatabase(cube_graph, config) -# find the best partition plan -from autodist.task_config import TaskConfig -class BloomTaskConfig(TaskConfig): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.model = 'Bloom' - # self.Bloom_setting = kwargs['Bloom_setting'] - # self.fine_grained_Bloom = kwargs['fine_grained_Bloom'] - # self.bloom_config = build_bloom_config(self.Bloom_setting) - self.task_name = f'bloom-{self.autodist_config.ngpus}gpu-'\ - f'{self.autodist_config.micro_batch_size}batch_size' - self.estimated_fname, self.backup_fname, self.runtime_fname = self._build_file_name( - self.task_name) - self.allow_recom_ops = [] - self.del_dim = [] -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Bloom benchmark') - parser.add_argument('--fp16', - action='store_true', - help='use fp16 for the training') - parser.add_argument('--fine_grained_GPT', - action='store_true', - help='model = GPTFineGrained') - parser.add_argument('--GPT_setting', - type=str, - default='6.7B', - help='set GPT model type') - parser.add_argument('--save_folder', - type=str, - default='exp_data', - help='set the save folder for experiment data') - parser.add_argument('--micro_batch_size', - type=int, - default=8, - help='set micro batch size') - parser.add_argument('--global_batch_size', - type=int, - default=8, - help='set the global batch size') - parser.add_argument('--iter_num', - type=int, - default=2, - help='set the number of all iterations') - parser.add_argument('--warm_num', - type=int, - default=1, - help='set the number of warmup iterations') - parser.add_argument('--recompute', - action='store_true', - help='set recompute flag') - parser.add_argument('--memory_constraint', - type=float, - default=32, - help='memory constraint for program') - parser.add_argument('--memory_granularity', - type=int, - default=1, - help='memory granularity in byte') - parser.add_argument('--profile_dir', - type=str, - default=str(Path.home()) + '/.autodist', - help='profile dir') - parser.add_argument('--connect_type', - type=str, - default='NV2', - help='connect type from nvidia-smi topo -m') - parser.add_argument('--use_prev_plan', - action='store_true', - help='run from previous plan') - parser.add_argument('--is_train', - action='store_true', - help='True: train, False: inference') - parser.add_argument('--topk', - type=int, - default=20, - help='generate multiple plans for robustness') - parser.add_argument('--mesh_row', type=int, default=1, help='node num') - parser.add_argument('--mesh_col', - type=int, - default=2, - help='dev num in a node') - parser.add_argument('--compile', - action='store_true', - help='compile stage: true, runtime stage: false') - parser.add_argument('--pipeline', - action='store_true', - help='pipeline: true, tensor parallel: false') - parser.add_argument('--nproc', - type=int, - default=12, - help='multiprocess deg in pipeline') - parser.add_argument('--adaptive_recom', - action='store_true', - help='allow adaptive recompute') - parser.add_argument('--plan_idx', - type=int, - default=0, - help='runtime plan idx') - parser.add_argument('--verbose', action='store_true', help='verbose mode') - parser.add_argument('--ignore_small_tensor_threshold', - type=int, - default=0, - help='set the tensor size threshold to ignore') - parser.add_argument('--parse_plan', - action='store_true', - help='parse plan to user-friendly format') - parser.add_argument('--alphafold', - action='store_true', - help='use alphafold2') - parser.add_argument('--alphafold_setting', - type=int, - default=1, - help='1: bs, s, r = 1, 128, 256'\ - '2: bs, s, r = 1, 512, 256'\ - '3: bs, s, r = 1, 512, 384') - args = parser.parse_args() - - # if args.compile: - # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' - - task_config = BloomTaskConfig(**vars(args)) - from autodist.apis import calc_parallel_plan - compile_start_time = time.time() - topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) - compile_cost_time = time.time() - compile_start_time - plan_info = [] - for plan in topk_plans: - cur_info = {} - if task_config.pipeline: - cur_spmd_descs, cur_time, cur_mems, cur_devs, cur_times = plan - cur_info['plan'] = [] - for item in cur_spmd_descs: - cur_info['plan'].append(item.to_json_object()) - cur_info['estimated time'] = cur_time - cur_info['estimated memory'] = convert_mem_into_GB(cur_mems) - cur_info['estimated time list'] = cur_times - cur_info['compile time'] = compile_cost_time - plan_info.append(cur_info) - else: - cur_spmd_desc, cur_mem, cur_time = plan - cur_info['plan'] = cur_spmd_desc.to_json_object() - cur_info['estimated time'] = cur_time - cur_info['estimated memory'] = convert_mem_into_GB(cur_mem) - cur_info['compile time'] = compile_cost_time - plan_info.append(cur_info) - - with open(task_config.backup_fname, 'w') as f: - json.dump(plan_info, f) From 059daa2f111141cf89f59bf55c56c5166bfcd831 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 24 Mar 2023 02:45:33 -0700 Subject: [PATCH 1371/1892] refine code --- cube/graph/function/function.py | 32 ++------------------------------ 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 3a6ba762..a461fde2 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1498,34 +1498,6 @@ def _comparison(creator: Callable, f: Callable, name: str, signature: str, else: return IRPyFunc(signature, [input, other], [IRObject()]) -def _comparison_hack(creator: Callable, f: Callable, name: str, signature: str, - input, other): - """ - if both operands are scalars, returns bool. - if one operand is a tensor, returns a broadcasted tensor with dtype being bool. - - @param creator Callable: the outside creation function - @param f Callable: (Scalar, Scalar) -> bools - """ - # case 0: return constant - if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): - return f(input, other) - # case1: torch.equal(tensor1, tensor2) - if isinstance(input, IRTensor) and isinstance(other, IRTensor): - lshape, rshape, oshape = _handle_broadcast(input, other) - annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(creator, name, signature, annos, [input, other]) - # case2: torch.equal(tensor1, obj2) / torch.equal(obj1, tensor2) - if isinstance(input, IRTensor) or isinstance(other, IRTensor): - annos = ['* -> *'] - if isinstance(input, IRTensor): - return IRDimops(creator, name, signature, annos, [input], other=other) - else: - return IRDimops(creator, name, signature, annos, [other], other=input) - # case3: torch.equal(obj1, obj2) - else: - return IRPyFunc(signature, [input, other], [IRObject()]) - def CompareGT(input, other, *, out=None, signature = None): """ @@ -1559,14 +1531,14 @@ def CompareEQ(input, other, *, out=None, signature = None): """ torch.eq(input, other, *, out=None) """ - return _comparison_hack(CompareEQ, operator.eq, 'eq', signature, input, other) + return _comparison(CompareEQ, operator.eq, 'eq', signature, input, other) def CompareNE(input, other, *, out=None, signature = None): """ torch.ne(input, other, *, out=None) """ - return _comparison_hack(CompareNE, operator.eq, 'ne', signature, input, other) + return _comparison(CompareNE, operator.eq, 'ne', signature, input, other) def ShapeAsTensor(input: IRTensor, signature = None): From ed975db2071d88ab1431d265f5ef79ae60a6103c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 28 Mar 2023 11:05:56 +0000 Subject: [PATCH 1372/1892] Merged PR 1512: fix fullslice fix fullslice --- cube/graph/function/function.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index a461fde2..a4bcfe00 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1289,7 +1289,10 @@ def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice]], signature=No start = 0 if slicer.start is None else slicer.start step = 1 if slicer.step is None else slicer.step dimlen = len(range(start, stop, step)) - edim_ou.append(str(dimlen)) + if dimlen == tensor.shape[dim]: + edim_ou.append(edim_in[dim]) + else: + edim_ou.append(str(dimlen)) else: pass # no shape for int # special case for loss = torch.Tensor([1,2,3])[0] From 679d453c6a0894111f87c1083480e75224349e16 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Thu, 30 Mar 2023 12:37:26 +0000 Subject: [PATCH 1373/1892] Merged PR 1516: related changes when integrating with fairseq --- cube/codegen/frontend_mapping.py | 4 + cube/codegen/module/module.py | 2 +- cube/codegen/schedule/schedule.py | 7 +- cube/graph/function/dimops.py | 14 ++++ cube/graph/function/function.py | 10 +++ cube/graph/gener/gen.py | 84 +++++++++---------- .../concrete_trace_utils/concrete_tracer.py | 59 ++++++++++++- .../kwargs_shape_prop/kwargs_interpreter.py | 26 ++++++ .../kwargs_shape_prop/kwargs_shape_prop.py | 36 ++++---- cube/graph/parser/converter.py | 40 +++++---- cube/graph/parser/mappingfx.py | 4 + cube/graph/parser/parserfx.py | 42 +++++++--- cube/runtime/function/function.py | 3 + 13 files changed, 241 insertions(+), 90 deletions(-) diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index e5e262f3..e131d9e6 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -72,6 +72,9 @@ def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" +def emit_getattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" + class Sign2EmitRule: @staticmethod @@ -95,6 +98,7 @@ def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: _signMap = { 'torch.slice': emit_slice, 'setattr': emit_setattr, + 'builtins.getattr': emit_getattr, } diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index f7138264..aa611b9c 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -82,7 +82,7 @@ def __init__(self, execplan: ExecutionPlan) -> None: '\n\n########## Generated Model Code ###########', 'from typing import *', 'import torch', 'import torch.utils.checkpoint as ckpt', - 'import cube', 'import _operator', 'from numpy import inf', '', ''] + 'import cube', 'import _operator', 'from numpy import inf', 'import builtins', '', ''] if CompileFlag.use_nnfusion: self.init_code.extend(['import nnfusion', '']) diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index f73d5467..c0d37780 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -42,8 +42,11 @@ def gen(self, device: int, outfile=None, attach=None) -> str: lifetime = LifeCycle(device_nodes, [], self.execplan.graph.outputs()) + model_inputs = ['{}_{}'.format(_input.name, _input.tid) for _input in self.execplan.graph.inputs()] + args = ['model'] + model_inputs + with FunctionBlock(func_name='_train_step', - args=['model', 'dataloader']) as fb: + args=args) as fb: fb.insert_body('_ = None') # body code if len(device_nodes) == 0: @@ -79,7 +82,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: warnings.warn('using legacy IRScheduleStrategy cannot generate inference code. ' 'Switch to use scheduling without strategy') with FunctionBlock(func_name='_infer_step', - args=['model', 'dataloader']) as fb: + args=args) as fb: fb.insert_body('_ = None') # body code if len(device_nodes) == 0: diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 0abb3a94..181804ae 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -64,6 +64,7 @@ from typing import Callable, Dict, Iterable, List, Union, Set, Tuple, Optional import enum +import importlib import re import string @@ -616,6 +617,19 @@ def anno(self) -> OpAnno: def transform_rules(self) -> Tuple[TransformRule]: return self._trans_rules + def getstate_for_dump(self): + state = self.__dict__.copy() + state['_create_fn'] = { + 'name': self._create_fn[0].__name__, + 'module': self._create_fn[0].__module__, + } + return state + + def setstate_for_load(self, state): + module = importlib.import_module(state['_create_fn']['module']) + state['_create_fn'] = (getattr(module, state['_create_fn']['name']),) + self.__dict__.update(state) + def ianno(self, index: int) -> Tuple[DimAnno]: """! Get index-th input tensor shape annotation diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index a4bcfe00..32b666cb 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -508,6 +508,16 @@ def Dropout(input, p=0.5, training=True, inplace=False, signature = None): p=p, training=training, inplace=inplace) +def nnDropout(input, p=0.5, inplace=False, signature=None): + """ + torch.nn.Dropout(p=0.5, inplace=False) + """ + signature = 'cube.runtime.function.nndropout' + annos = ['* -> *'] + return IRDimops(nnDropout, 'Dropout', signature, annos, [input], + p=p, inplace=inplace) + + def Detach(input, signature = None): """ torch.Tensor.detach(input) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 7dc39513..dd57cb5a 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -140,49 +140,33 @@ def remove_anchor(graph: IRSegment): return graph @staticmethod - def auto_pyfunc(graph: IRSegment): + def auto_pyfunc(graph: IRGraph): """ Make pyfunc to be local IRPyFunc will be replicated to devices with its producers output """ - for func in graph.select(ntype=IRPyFunc, flatten=False): - assert func.mirror is None, "PyFunc is only supported by inference" + for func in graph.select(ntype=IRPyFunc, flatten=True): # get devices it will lowered to + segment: IRSegment = graph.segment(func) + segment_outputs = IRSegment.get_objects_from_complex(segment.outputs()) devices = set() for t in func.inputs(): if not isinstance(t, IRObject): continue - if t.is_attr(): - cells = graph.consumers(t.parent) - else: - cells = graph.producers(t.parent) + cells = segment.consumers(t.parent) if t.is_attr() else segment.producers(t.parent) for cell in cells: devices.update(cell.device) - pyfuncs = [] - # lower to each device - for devid in devices: - inputs = [] - # automatic partition to align with consumer (attr) or producer (activation) - for t in func.inputs(): - sub_ts = set() - if not isinstance(t, IRSubTensor): - sub_ts.add(t) # replica for non-tensor - elif t.is_attr(): - # get local consumers except func itself - sub_ts = set(tensor for tensor in graph.ctensors(t.parent) \ - if devid in tensor.device and tensor.cell != func) - else: - # get local producers - sub_ts = set(tensor for tensor in graph.ptensors(t.parent) \ - if devid in tensor.device) - inputs.append(t if len(sub_ts) == 0 else list(sub_ts)[0]) - lower_func = IRPyFunc(func.signature, inputs, func.outputs(), **func.kwargs) - lower_func.device = devid - pyfuncs.append(lower_func) - position = graph.remove(func) - for pyfunc in pyfuncs: - graph.insert(pyfunc, position) - for segment in graph.select(ntype=IRSegment, flatten=False): - IRAdapterGener.auto_pyfunc(segment) + for t in func.outputs(): + if not isinstance(t, IRObject): continue + if t in segment_outputs: + devices.update(segment.device) + # replicate + pyfuncs = [func.replicate() for _ in devices] + for devid, pyfunc in zip(sorted(devices), pyfuncs): + pyfunc.device = devid + # insert + position = segment.remove(func) + for pyfunc in pyfuncs[::-1]: + segment.insert(pyfunc, position) return graph @staticmethod @@ -331,7 +315,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # consumers can be operators and graph outputs fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) if ftensor in output_consumer: - fctensors = fctensors + tuple(fwop.input(0) for fwop in output_consumer[ftensor]) + fctensors = fctensors + tuple(fwop.input(0) for fwop in output_consumer[ftensor] if fwop.input(0) not in fctensors) fctensors = expand_devices(fctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" @@ -379,20 +363,32 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # insert forward adapter # graph.insert(fadapter, max(producers) + 1) - tail = CellPosition((graph.nnodes,)) - fconsumer_idx = [tail] + [graph.index(c) for c in fconsumers] - fidx = min(fconsumer_idx) + if len(fconsumers) > 0: + fidx = min(graph.nodes().index(c) for c in fconsumers) + else: + # no consumer: find the last forward node + for fidx, node in enumerate(graph.nodes()[::-1]): + if node.isfw(): + fidx = graph.nnodes - fidx + break + graph.insert(fadapter, fidx) # setup recompute if fadapter.differentiable and allow_recompute: - fadapter.recompute = graph.node(fidx if fidx != tail else fidx-1).recompute - graph.insert(fadapter, fidx) + if fidx > 0: + prev_node = graph.node(fidx-1) + if isinstance(prev_node, IRFwOperation): + fadapter.recompute = prev_node.recompute # insert backward adapter if badapter is not None: assert isinstance(badapter, IRAdapter) assert isinstance(bgraph, IRSegment) - bproducer_idx = [CellPosition((0,))] + [bgraph.index(consumer.mirror) + 1 for consumer in fconsumers] - bidx = max(bproducer_idx) + if len(bproducers) > 0: + bidx = max(bgraph.nodes().index(p) for p in bproducers) + 1 + else: + # no producer: find the first backward node + for bidx, node in enumerate(bgraph.nodes()): + if not node.isfw(): break bgraph.insert(badapter, bidx) # generate adapter for each segment @@ -664,7 +660,8 @@ def autoref(graph: IRSegment) -> IRGraph: for multiref in graph.select(name='multiref', flatten=False): ftensor: IRFullTensor = multiref.input(0).parent multirefs = [] - for otensor in graph.ptensors(ftensor): + tensors = graph.ptensors(ftensor) if len(graph.ptensors(ftensor)) > 0 else graph.ctensors(ftensor) + for otensor in tensors: mr = MultiRef(otensor, len(multiref.outputs())) for idx in range(len(multiref.outputs())): output = multiref.output(idx).parent.select(otensor.indmap, otensor.valmap) @@ -672,8 +669,7 @@ def autoref(graph: IRSegment) -> IRGraph: output.grad = multiref.output(idx).grad.parent.select(otensor.indmap, (0,1)) mr.set_output(idx, output) mr.device = otensor.device - # Dataloader has no recompute - if hasattr(otensor.cell, 'recompute'): + if isinstance(otensor.cell, IRFwOperation): mr.recompute = otensor.cell.recompute multirefs.append(mr) # remove original multiref diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index 1c9d6d9c..73d1d406 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -18,6 +18,7 @@ import torch from torch._C import ScriptObject from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict +from torch.utils._pytree import tree_map from torch.fx import GraphModule from torch.fx._compatibility import compatibility @@ -228,7 +229,33 @@ def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: D if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) with self.do_temp_disable(call=True): - return OperatorPatcherContext.patch_run(fn, *args, **kwargs) + to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + result = OperatorPatcherContext.patch_run(fn, *args, **kwargs) + for arg in args: + if _orig_isinstance(arg, torch.Tensor): + del arg + del args + for key, value in kwargs.items(): + if _orig_isinstance(value, torch.Tensor): + del value + del kwargs + if _orig_isinstance(result, torch.Tensor): + result_cpu = result.cpu() + del result + torch.cuda.empty_cache() + return result_cpu + if not isinstance(result, (tuple, list, dict)): + torch.cuda.empty_cache() + return result + to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t + result_cpu = tree_map(to_cpu, result) + for ret in result: + if _orig_isinstance(ret, torch.Tensor): + del ret + torch.cuda.empty_cache() + return result_cpu elif kind == 'call_method': self_obj, *args_tail = args fn = _orig_getattr(self_obj, target) @@ -239,10 +266,38 @@ def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: D elif kind == 'call_module': assert isinstance(target, str) mod = self.fetch_attr(target) + mod.cuda() if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(mod, '__globals__'): _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) with self.do_temp_disable(call=True): - return OperatorPatcherContext.patch_run(mod, *args, **kwargs) + to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + result = OperatorPatcherContext.patch_run(mod, *args, **kwargs) + for arg in args: + if _orig_isinstance(arg, torch.Tensor): + del arg + del args + for key, value in kwargs.items(): + if _orig_isinstance(value, torch.Tensor): + del value + del kwargs + mod.cpu() + if _orig_isinstance(result, torch.Tensor): + result_cpu = result.cpu() + del result + torch.cuda.empty_cache() + return result_cpu + if not isinstance(result, (tuple, list, dict)): + torch.cuda.empty_cache() + return result + to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t + result_cpu = tree_map(to_cpu, result) + for ret in result: + if _orig_isinstance(ret, torch.Tensor): + del ret + torch.cuda.empty_cache() + return result_cpu elif kind == 'get_attr': assert isinstance(target, str) return self.fetch_attr(target) diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py index 223c3e2a..1d9aac75 100644 --- a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py +++ b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py @@ -3,6 +3,7 @@ import torch.fx.traceback as fx_traceback from torch.fx import Interpreter, Node from typing import Optional, Union, Tuple, Dict, List, Any, Iterator, Callable, MutableMapping, Mapping +from torch.utils._pytree import tree_map Target = Union[Callable[..., Any], str] @@ -145,3 +146,28 @@ def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[s else: raise RuntimeError( f'Expected positional argument for parameter {target}, but one was not passed in!') + + def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + result = super().call_function(target, args, kwargs) + if isinstance(result, torch.Tensor): + return result.cpu() + else: + to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t + return tree_map(to_cpu, result) + + def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + assert isinstance(target, str) + mod = self.fetch_attr(target) + mod = mod.cuda() + to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + result = mod(*args, **kwargs) + if isinstance(result, torch.Tensor): + return result.cpu() + else: + to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t + return tree_map(to_cpu, result) diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py index 9ad13c01..2cbad447 100644 --- a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py +++ b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py @@ -1,5 +1,8 @@ +import builtins +import operator import torch import traceback +from torch.fx import GraphModule from torch.fx.node import Node, map_aggregate from typing import Optional, Union, NamedTuple, Tuple, Any, Dict from .kwargs_interpreter import KwargsInterpreter @@ -100,27 +103,32 @@ def propagate(self, concrete_args: Union[Dict[str, Any], Tuple]): class DCEHandler: + dont_delete = [operator.setitem, builtins.next] + def __init__(self, gm: torch.fx.GraphModule): self.gm = gm def eliminate_dead_code(self): - # set a loop to make sure clean all dead nodes - # because some nodes may be used by some dead nodes are also dead nodes - # !pay attention that the `output` node should be ignored for users checking + to_check = set() + for node in self.gm.graph.nodes: + to_check.add(node) while True: - removed = False - for node in self.gm.graph.nodes: + deleted = False + modified = set() + for node in to_check: if node.op == 'output': continue - users = list(node.users) - if not users: - # make input nodes pop this node out of their users list - # before the node is removed - input_nodes = node.all_input_nodes - for input_node in input_nodes: + if not node.users and node.op != 'placeholder' and node.target not in DCEHandler.dont_delete: + for input_node in node.all_input_nodes: input_node.users.pop(node) + modified.add(input_node) node._remove_from_list() - removed = True - if not removed: + if node in modified: + modified.remove(node) + deleted = True + if deleted is False: break - self.gm.recompile() + else: + to_check = modified + name = self.gm.__class__.__name__ + return GraphModule(self.gm, self.gm.graph, name) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 5e53a4f7..c5c8d2b4 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -29,27 +29,39 @@ def convert_model(model: torch.nn.Module, smodule.graph.print_tabular() else: print('INFO: using concrete tracer') - with torch.no_grad(): - if isinstance(dummy_input, torch.Tensor): - output_origin = model(dummy_input) - dummy_input = (dummy_input, ) - elif isinstance(dummy_input, tuple) or isinstance(dummy_input, list): - output_origin = model(*dummy_input) - elif isinstance(dummy_input, dict): - print(f'WARNING dict dummy_input') - output_origin = model(**dummy_input) - else: - raise RuntimeError(f'dummy_input should be a tuple (not a {type(dummy_input)}) = {dummy_input}') + # NOTE: remove this part because when model is too large to fit in one GPU, + # this model forward cannot be successfully done, thus remove it. + # with torch.no_grad(): + # if isinstance(dummy_input, torch.Tensor): + # output_origin = model(dummy_input) + # dummy_input = (dummy_input, ) + # elif isinstance(dummy_input, tuple) or isinstance(dummy_input, list): + # output_origin = model(*dummy_input) + # elif isinstance(dummy_input, dict): + # print(f'WARNING dict dummy_input') + # output_origin = model(**dummy_input) + # else: + # raise RuntimeError(f'dummy_input should be a tuple (not a {type(dummy_input)}) = {dummy_input}') traced_model = concrete_trace( model, dummy_input, use_operator_patch=True, autowrap_leaf_class={ torch.finfo: ((), False), - type(output_origin): ((), False), + # type(output_origin): ((), False), }, + # FIXME: check if dropout is not included in it, can self.training be handled properly in the new version of + # concrete_trace + leaf_module=( + torch.nn.Dropout, torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, + # NOTE: the following modules also have different behavior depending on self.training. but currently in used. + # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, + # torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, + # torch.nn.LazyBatchNorm1d, torch.nn.LazyBatchNorm2d, torch.nn.LazyBatchNorm3d, torch.nn.SyncBatchNorm, + # torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, + # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, + ), ) - traced_model.graph.print_tabular() else: print('using torchscript tracer') smodule = torch.jit.script(model) @@ -65,7 +77,7 @@ def convert_model(model: torch.nn.Module, module_name = model.__class__.__name__ else: FxModuleParser.save_content = save_content - inputs, nodes, outputs = FxModuleParser.parse(traced_model, input_shapes=input_shapes, dummy_inputs=dummy_input) + inputs, nodes, outputs = FxModuleParser.parse(traced_model, input_shapes=None, dummy_inputs=dummy_input) module_name = model.__class__.__name__ else: ScriptModuleParser.save_content = save_content diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 102767ca..56676167 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -36,6 +36,9 @@ def exist(signature: str) -> bool: # tensor template __ttemplate = lambda name: f'torch.{name}' + # torch nn module + __tnmtemplate = lambda name: f'torch.nn.{name}' + # torch.Tensor template __tttemplate = lambda name: f'torch.Tensor.{name}' @@ -49,6 +52,7 @@ def exist(signature: str) -> bool: __customops = lambda name: f'examples.custom_ops.{name}' kOpMap = { + __tnmtemplate('Dropout'): function.nnDropout, __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index e7855453..4060279a 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -68,8 +68,7 @@ def parse(module: torch.fx.GraphModule, frame: Frame = None) \ -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler - dce_module = DCEHandler(module).eliminate_dead_code() - model = dce_module + DCEHandler(module).eliminate_dead_code() """ The overall entry to parse a torch.fx graph module @@ -114,20 +113,25 @@ def parse(module: torch.fx.GraphModule, frame.add_var(input.name, val, graph_arg=idx) else: assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' - # remove dead nodes - from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler - DCEHandler(module).eliminate_dead_code() # shape propagation from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp ShapeProp(module).propagate(dummy_inputs) # handle graph inputs for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) + # dealing with different types of dummy_inputs if isinstance(dummy_inputs, dict): - if input.name in dummy_inputs: - shape = input.meta['tensor_meta'].shape + if input.name not in dummy_inputs: + val = IRObject(input.name) else: - shape = None + if 'tensor_meta' in input.meta: + shape = input.meta['tensor_meta'].shape + if len(shape) == 0: + shape = [1] + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + else: + val = IRObject(input.name) else: # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, # extend to other input types @@ -138,8 +142,8 @@ def parse(module: torch.fx.GraphModule, # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name print(f'dummy_inputs does not have {input.name}') shape = None - dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) input_val = [frame.get_var(input.name) for input in inputs] @@ -159,7 +163,8 @@ def parse(module: torch.fx.GraphModule, print(f'{node.name} does not has tensor_meta') # add activations to frame, including call_func/call_method output and final output - activation_op_strs = {'call_function', 'output', 'call_method'} + # call_module corresponds to leaf torch.nn.module + activation_op_strs = {'call_function', 'output', 'call_method', 'call_module'} activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] for node in activation_nodes: if hasattr(node, 'meta') and node.meta.get('tensor_meta') and hasattr(node.meta['tensor_meta'], 'dtype'): @@ -232,7 +237,7 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == FxNodeKind.PrimGetAttr: return FxModuleParser.parse_prim_attr_node(node, module, frame) if node_type == FxNodeKind.PrimCallModule: - raise RuntimeError(f"parse_prim_module is not supported.") + return FxModuleParser.parse_prim_module(node, module, frame) # TODO bother assigning all ignored prim functions new NodeKinds? if node_type == FxNodeKind.PrimDevice: @@ -251,6 +256,13 @@ def fetch_attr(mod: torch.fx.GraphModule, target: str): attr_itr = getattr(attr_itr, atom) return attr_itr + @staticmethod + def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + prim_module = FxModuleParser.fetch_attr(module, node.target) + assert prim_module.__class__.__module__.startswith('torch.nn.modules'), f'{module.__class__.__module__}' + fsig = 'torch.nn.{}'.format(prim_module.__class__.__name__) + return FxModuleParser._parse_node(fsig, node, module, frame) + @staticmethod def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: # get signature @@ -259,7 +271,11 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule print(f'parse_prim_method_node: {fsig}') else: print(f'parse_prim_function_node: {fsig}') + return FxModuleParser._parse_node(fsig, node, module, frame) + @staticmethod + def _parse_node(fsig: str, node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + def get_complex_data(val: Any) -> Any: """Change inner fx.Node into IRObject""" if isinstance(val, tuple): @@ -269,7 +285,7 @@ def get_complex_data(val: Any) -> Any: if isinstance(val, torch.fx.Node): return frame.get_var(val.name) return val - + # get inputs input_vals = [get_complex_data(val) for val in node.args] kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 813985a7..c3c55cb3 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -178,3 +178,6 @@ def stack(*tensors, dim=0) -> torch.Tensor: def cat(*tensors, dim=0) -> torch.Tensor: return torch.cat(tensors, dim) + +def nndropout(input: torch.Tensor, p=0.5, inplace=False): + return torch.nn.Dropout(0.0, inplace)(input) \ No newline at end of file From 20611a1800b667403f32949c7f0402399cf56282 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 31 Mar 2023 07:10:50 +0000 Subject: [PATCH 1374/1892] Merged PR 1517: Support torchscale in AutoDist --- cube/algorithm/ops/dimops.py | 11 +---------- cube/graph/parser/dtype.py | 1 + cube/ir/cten.py | 1 + cube/profiler/database.py | 6 +++--- cube/runtime/function/function.py | 2 +- 5 files changed, 7 insertions(+), 14 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index e7e5c5f9..e826f030 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -322,9 +322,6 @@ def instantiate(self, idx: int, dimi: int, dimo: int, num: int) -> Optional[List return sub_nodes def collect_split_info(node: IRFwOperation): - # TODO(yizhu1): workaround - split_batch_ops = {} - anno = node.anno split_info = {} @@ -339,13 +336,7 @@ def collect_split_info(node: IRFwOperation): if identifier not in split_info: split_info[identifier] = (idx_shape, idx_dim, idx_id) - if node.signature in split_batch_ops: - for key, val in split_info.items(): - if val == (0, 0, 0): - return {key: val} - assert False - else: - return split_info + return split_info def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: diff --git a/cube/graph/parser/dtype.py b/cube/graph/parser/dtype.py index 97f264fd..9d74ee53 100644 --- a/cube/graph/parser/dtype.py +++ b/cube/graph/parser/dtype.py @@ -37,6 +37,7 @@ def map(ir_dtype: ir.IRDType): """ Map the IRDtype to torch dtype """ + assert ir_dtype in IRDType2TorchDType.kDtypeMap, f'unexpected ir_dtype {ir_dtype}' return IRDType2TorchDType.kDtypeMap[ir_dtype] kDtypeMap = {val: key for key, val in DType2IRDType.kDtypeMap.items()} \ No newline at end of file diff --git a/cube/ir/cten.py b/cube/ir/cten.py index fec5509a..41bdda31 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -535,6 +535,7 @@ def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): super().__init__(name, tid) self._shape: Tuple[int] = () if shape is None else tuple(shape) self._cell: Optional[IRCell] = None + assert isinstance(dtype, IRDType), f'expect IRDType, get {dtype} with type {type(dtype)}' self._dtype: IRDType = dtype # tensor gradient self._is_grad: bool = False diff --git a/cube/profiler/database.py b/cube/profiler/database.py index d1306a93..098148b4 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -412,7 +412,7 @@ def dump(self, file: str, override=False): json.dump(self._data, f) - def dump_nodes(self, file: str, override=False): + def dump_ops(self, file: str, override=False): if os.path.exists(file): assert override, f"File {file} exists. Set override = True to force dump." for signature in self._data.keys(): @@ -420,7 +420,7 @@ def dump_nodes(self, file: str, override=False): with open(file_n, 'w') as f: json.dump(self._data[signature],f) - def dump_node(self, file: str, signature ,override=False): + def dump_op(self, file: str, signature, override=False): assert signature in self._data.keys(), f'this node not be profiled' file_n = os.path.join(file, signature +'.json') with open(file_n, 'w') as f: @@ -436,7 +436,7 @@ def load(self, file: str): with open(file, 'r') as f: self._data = json.load(f) - def load_nodes(self, file: str): + def load_ops(self, file: str): for filename in os.listdir(file): if filename.endswith('.json'): with open(os.path.join(file, filename)) as f: diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index c3c55cb3..e97ffd86 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -180,4 +180,4 @@ def cat(*tensors, dim=0) -> torch.Tensor: return torch.cat(tensors, dim) def nndropout(input: torch.Tensor, p=0.5, inplace=False): - return torch.nn.Dropout(0.0, inplace)(input) \ No newline at end of file + return torch.nn.Dropout(0.0, inplace)(input) From 5a24346bc9581771c34c99ba68aab51e28387af2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 31 Mar 2023 12:34:23 +0000 Subject: [PATCH 1375/1892] Merged PR 1511: Add utility and refine code 1) extend EnvResource to include total gpu memory 2) extend OpAnno to get reduce type given an identifier 3) extend IRDimops to get transformation space 4) Remove useless code --- cube/algorithm/factory.py | 1 - cube/algorithm/ops/dimops.py | 103 -------------------------------- cube/graph/function/dimops.py | 45 ++++++++++++-- cube/runtime/adapter/reducer.py | 7 ++- cube/runtime/resource.py | 31 ++++++++-- 5 files changed, 70 insertions(+), 117 deletions(-) diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 279757e2..602df88e 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -64,7 +64,6 @@ def _load_predefined_algos(self): import cube.algorithm.ops.dimops as dimops self.register(dimops.IRDimops, dimops.DimSplitEinops, tag='dim') - self.register(dimops.IRDimops, dimops.SimpleViewSplitEinops, tag='view_simp') import cube.algorithm.ops.conv as conv self.register(conv.IRPad, conv.DimSplitPad, tag='dim') diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index e826f030..f7972ec9 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -218,109 +218,6 @@ def modify(kwargs: Dict, idx: int, dim: int, num: int): return TransformRule(itransform, otransform, modify) -class SimpleViewSplitEinops(GenericDistAlgo): - """ - split Einops at dimension level. - - The sum-reduce dimension and non-reduce dimension can be splitted. - - For sum-reduce dimension, the output keeps same shape but has partial-sum valmap result. - For non-reduce dimension, the output keeps same valmap but has partial output shape. - For stay-reduce dimension, this dimension is not allowed to be splitted. - """ - - def __init__(self, node: IRDimops): - if not isinstance(node, IRDimops): - raise TypeError(f"Expect IRDimops") - super().__init__(node) - self._adim: str = None - self._reduce: DimAnno.ReduceType = None - - def satisfy(self, idx: int, dimi: int, dimo: int, num: int) -> bool: - """ - Check whether the condition satisfies. - - @param idx int: input index - @param dimi int: input dimension - @param dimo int: corresponding output dimension - @param num int: chunks to partition the dimension - - @return satisfy bool: true if can be partitioned, elsewise false. - """ - # assert all(isinstance(cond, int) for cond in [idx, dim, num]), "expect int condition" - node: IRDimops = self.node - assert idx == 0, f"Index should be 0" - assert len(node.inputs()) == 1, f"Inputs size should be 1" - assert len(node.outputs()) == 1, f"Outputs size should be 1" - dimi = dimi if dimi >= 0 else dimi + node.input(0).ndims - dimo = dimo if dimo >= 0 else dimo + node.output(0).ndims - assert dimi < node.input(0).ndims, f"dimension out of boundary: {dimi} >= {node.input(0).ndims}" - assert dimo < node.output(0).ndims, f"dimension out of boundary" - # # due to implementation limits, we only partition the first annotated dimension - # # for inner-dimension cases. - idi = 1 if dimi == 0 else 0 - ido = 1 if dimo == 0 else 0 - self._adimi: str = node.anno.input(0).dims[dimi].identifiers[idi] - self._adimo: str = node.anno.output(0).dims[dimo].identifiers[ido] - dimlen = node.anno.getlen(self._adimi) - if dimlen < num: - return False - return True - - def instantiate(self, idx: int, dimi: int, dimo: int, num: int) -> Optional[List[IRDimops]]: - - node: IRDimops = self.node - satisfy = self.satisfy(idx, dimi, dimo, num) - if not satisfy: - return None - - ins, ous = list(), list() - for iidx, itensor in enumerate(node.inputs()): - if not isinstance(itensor, IRSubTensor): - assert 0, "should not happen" - shape_anno = node.anno.input(iidx) - split_dims = shape_anno.getdims(self._adimi) - assert len(split_dims) <= 1, f"find split dims ({self._adimi}) more than 1: {shape_anno}" - if len(split_dims) == 1: - dim = split_dims[0] - # split axis - # print('dimi =', dim) - ins.append(itensor.split_dim(dim, num)) - else: - assert 0, "should not happen" - - for oidx, otensor in enumerate(node.outputs()): - if not isinstance(otensor, IRSubTensor): - assert 0, f"should not happen" - shape_anno = node.anno.output(oidx) - split_dims = shape_anno.getdims(self._adimo) - assert len(split_dims) <= 1, f"find split dims ({self._adimo}) more than 1: {shape_anno}" - # split axis - if self._reduce != DimAnno.ReduceType.Dim: - assert len(split_dims) == 1, f"expect only one spatial dimension in output tensor but got {len(split_dims)}" - dim = split_dims[0] - # print('dimo =', dim) - ous.append(otensor.split_dim(dim, num)) - # split numerical dimension - else: - assert 0, f"not implemented" - - sub_nodes = list() - for nid in range(num): - inputs = [t[nid] for t in ins] - outputs = [t[nid] for t in ous] - updated_kwargs = dict() - if self._adimi in node.kwargs and isinstance(node.kwargs[self._adimi], int): - assert 0, "should not happen" - if self._adimo in node.kwargs and isinstance(node.kwargs[self._adimo], int): - assert 0, "should not happen" - assert len(outputs) == 1, f"outputs len should be one" - updated_kwargs['size'] = outputs[0].shape - sub_node: IRDimops = node.new(inputs, outputs, **updated_kwargs) - sub_node.infer_shape() - sub_nodes.append(sub_node) - return sub_nodes - def collect_split_info(node: IRFwOperation): anno = node.anno diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 181804ae..0a696dc3 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -304,11 +304,12 @@ def __init__(self, anno: Union[str, Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]]]): inputs, outputs = anno self._inputs: Tuple[ShapeAnno] = tuple(inputs) self._outputs: Tuple[ShapeAnno] = tuple(outputs) - self._identifiers: Dict[str, int] = dict() + self._identifiers: Dict[str, int] = dict() # identifier -> dimension length + self._reduces: Dict[str, DimAnno.ReduceType] = dict() # identifier -> reducer self.reset_identifiers() @property - def identifiers(self) -> Set[str]: + def identifiers(self) -> Tuple[str]: """! Get all identifier set @@ -360,7 +361,10 @@ def reset_identifiers(self): shape_annos = list(self._inputs) + list(self._outputs) for ashape in shape_annos: for adim in ashape.dims: - self._identifiers.update({identifier: None for identifier in adim.identifiers}) + for identifier, reduce in zip(adim.identifiers, adim.reduces): + self._identifiers[identifier] = None + # TODO: check consistency + self._reduces[identifier] = reduce for identifier in self._identifiers.keys(): if str.isdecimal(identifier): self._identifiers[identifier] = int(identifier) @@ -391,9 +395,20 @@ def getlen(self, identifier: str) -> Optional[int]: @return length Optional[int]: the length of identifier """ - assert identifier in self._identifiers, f"{identifier} not int identifier set {self._identifiers}" + assert identifier in self._identifiers, f"{identifier} not exists {set(self._identifiers.keys())}" return self._identifiers[identifier] + def get_reduce(self, identifier: str) -> DimAnno.ReduceType: + """ + Get identifier reduce type + + @param identifier str: identifier name + + @return reduce DimAnno.ReduceType + """ + assert identifier in self._reduces, f"{identifier} not exists {set(self._reduces.keys())}" + return self._reduces[identifier] + def __repr__(self) -> str: inputs = ', '.join(repr(input) for input in self.inputs()) outputs = ', '.join(repr(output) for output in self.outputs()) @@ -790,3 +805,25 @@ def algorithms(self, tag: Optional[str] = None): template = factory.algorithms(IRDimops, tag) return template(self) return None + + def transform_space(self) -> List[Tuple[int, int]]: + """ + Get transformation space of the operator + + @return List[Tuple[int, int]]: list of (idx, dim) + """ + visited : Set[str] = set() + configs = [] + ashapes = self.anno.inputs() + self.anno.outputs() + for idx, eshape in enumerate(ashapes): + if idx < len(self.inputs()): + if not isinstance(self.input(idx), IRTensor): continue + for dim, edim in enumerate(eshape.dims): + for identifier, reduce in zip(edim.identifiers, edim.reduces): + if identifier in visited: continue + visited.add(identifier) + if identifier == '1' or self.anno.getlen(identifier) == 1: continue + if reduce == DimAnno.ReduceType.Freeze: break + configs.append((idx, dim)) + break + return configs diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 019daefb..b89c3099 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -43,14 +43,15 @@ def allreduce(self): for param in self._params: if param.requires_grad and param.grad is not None: cur_byte_size = param.nelement() * param.element_size() - assert cur_byte_size <= self.bucket_size, f'cur_byte_size = {cur_byte_size}' - tp = param.data.type() if tp not in buckets: buckets[tp] = [[param]] tp2size[tp] = cur_byte_size else: - if tp2size[tp] + cur_byte_size <= self.bucket_size: + if cur_byte_size > self.bucket_size: + warnings.warn(f'find one parameter {param.shape} ({cur_byte_size} bytes) larger than bucket size {self.bucket_size}') + buckets[tp].insert(0, [param]) + elif tp2size[tp] + cur_byte_size <= self.bucket_size: tp2size[tp] = tp2size[tp] + cur_byte_size buckets[tp][-1].append(param) else: diff --git a/cube/runtime/resource.py b/cube/runtime/resource.py index 06dc1c2f..d8f3dafa 100644 --- a/cube/runtime/resource.py +++ b/cube/runtime/resource.py @@ -1,9 +1,17 @@ r""" Runtime information """ +from typing import Tuple import torch -import os +from cube.flags import CompileFlag +from dataclasses import dataclass + + +@dataclass +class DeviceInfo: + # memory in btypes + memory: int = None class EnvResource: @@ -12,13 +20,24 @@ class __EnvResource: def __init__(self): # number of gpus - single_device_mode = os.environ.get('SINGLE_DEV_MODE') - if single_device_mode: - self.ngpus = 1 - else: - self.ngpus = torch.distributed.get_world_size() + self.ngpus = 1 if CompileFlag.dev_mode else torch.distributed.get_world_size() # device topology self.topo = None + self.gpus: Tuple[DeviceInfo] = self.get_device_capability() + + def get_device_capability(self) -> Tuple[DeviceInfo]: + if CompileFlag.dev_mode: + memory = [torch.cuda.get_device_properties(0).total_memory] + else: + rank = torch.distributed.get_rank() + memory = torch.tensor(torch.cuda.get_device_properties(0).total_memory, + dtype=torch.int64, device=torch.cuda.current_device()) + all_device_mem = [torch.empty_like(memory) for _ in range(self.ngpus)] + all_device_mem[rank] = memory.data + torch.distributed.all_gather(all_device_mem, memory) + torch.cuda.synchronize() + memory = [t.item() for t in all_device_mem] + return tuple(DeviceInfo(memory=mem) for mem in memory) instance = None From 3552ba22a8d8d6e6b416f65ba7df2514f64f58e3 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Sat, 1 Apr 2023 11:47:09 +0000 Subject: [PATCH 1376/1892] Merged PR 1520: fix embed sharing fix embed sharing by detect duplicated node.target and re-direct to the same IRTensor --- cube/graph/parser/frame.py | 25 +++++++++++++++++++++++++ cube/graph/parser/parserfx.py | 26 +++++++++++++++++--------- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index 0480a407..bfd43527 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -15,6 +15,7 @@ def __init__(self): # module attributes self._attributes: List[dict[str, Any]] = list() self._attr_vals: Dict[int, Any] = dict() # tensor tid to real value mapping + self._name_map: Dict[Any, Any] = dict() # tensor name to real tensor name def push_var(self, inherit_from_top=False): """ @@ -153,6 +154,30 @@ def save_attr_content(self, save_file: str = 'fullmodel.pt'): """ torch.save(self._attr_vals, save_file) + def add_attr_map(self, key, value): + """ + Add names map to connect internal parameter name and original parameter + """ + self._name_map[str(key)] = value + + def has_attr_value(self, value): + return value in self._name_map.values() + + def get_attr_key(self, value): + ret = None + for key, val in self._name_map.items(): + if val == value: + ret = key + break + return ret + + def save_attr_map(self, save_file: str = 'dist_param_map.pt'): + """ + Save local_param -> origin_param name map. + """ + torch.save(self._name_map, save_file) + + def push_param(self, var_name): """ push var name to the method stack diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 4060279a..c6089f2b 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -298,7 +298,7 @@ def get_complex_data(val: Any) -> Any: # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): print(f'>>> Find unknown pytorch operation: {fsig}') - fname = fsig.split('.')[-1] if '.' in fsig else fname + fname = fsig.split('.')[-1] if '.' in fsig else fsig ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) # case2: python runtime function else: @@ -337,14 +337,22 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram tensor_shape = node.meta['tensor_meta'].shape dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) requires_grad = node.meta['tensor_meta'].requires_grad - ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=requires_grad, dtype=dtype) - if requires_grad: # case for registered parameters - ir_tensor.as_param() - else: # case for registered buffers - ir_tensor.as_buffer() - frame.add_var(tensor_name, ir_tensor) - value = FxModuleParser.fetch_attr(module, node.target) - frame.add_attr_content(ir_tensor.tid, value) + + # check if existing param + if requires_grad and frame.has_attr_value(node.target): # existing param + prev_tensor_name = frame.get_attr_key(node.target) + print(f'INFO: link {tensor_name} to existing param {prev_tensor_name}') + frame.add_var(tensor_name, frame.get_var(prev_tensor_name)) + else: # new param / activation + ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=requires_grad, dtype=dtype) + if requires_grad: # case for registered parameters + ir_tensor.as_param() + else: # case for registered buffers + ir_tensor.as_buffer() + frame.add_var(tensor_name, ir_tensor) + value = FxModuleParser.fetch_attr(module, node.target) + frame.add_attr_content(ir_tensor.tid, value) + frame.add_attr_map(ir_tensor.name, node.target) else: var = FxModuleParser.fetch_attr(module, node.target) frame.add_var(tensor_name, var) From 5d22d5e7ea63766600ad4e2c8a24bea4e8fb8b7e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 3 Apr 2023 05:25:46 +0000 Subject: [PATCH 1377/1892] Merged PR 1519: support general schedule signature. support general schedule signature --- cube/codegen/schedule/schedule.py | 5 +- cube/compiler.py | 78 +++++++++----- examples/mlp/linearsfx.py | 4 +- tests/parser/test_compile.py | 165 ++++++++++++++++++++++++++++++ 4 files changed, 221 insertions(+), 31 deletions(-) create mode 100644 tests/parser/test_compile.py diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index c0d37780..081e42af 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -42,8 +42,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: lifetime = LifeCycle(device_nodes, [], self.execplan.graph.outputs()) - model_inputs = ['{}_{}'.format(_input.name, _input.tid) for _input in self.execplan.graph.inputs()] - args = ['model'] + model_inputs + args = ['model'] + [ScheduleCodeGen.tensor_name(t) for t in self.execplan.graph.inputs()] with FunctionBlock(func_name='_train_step', args=args) as fb: @@ -94,7 +93,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: fb.insert_body(codes) # release tensors = lifetime.release_tensors_after_line(line) - tensors = [t for t in tensors if not t.is_grad()] + tensors = [t for t in tensors if isinstance(t, IRTensor) and not t.is_grad()] if len(tensors) > 0 : # not necessarily to have one after each line fb.insert_body(ScheduleCodeGen.emit_release(tensors)) # return code diff --git a/cube/compiler.py b/cube/compiler.py index 9eee8d5f..ad4e9449 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,4 +1,4 @@ -from typing import Callable, Tuple, Union, Optional +from typing import Callable, Tuple, Union, Optional, Any import torch import time import os @@ -7,7 +7,9 @@ from cube.graph.gener.gen import IRAdapterGener from cube.graph.graph import IRGraph -from cube.ir.operator import IRDataOperation +from cube.ir.cten import IRObject +from cube.graph.parser.dtype import DType2IRDType +from cube.ir.tensor import IRFullTensor from cube.graph.function.anchor import IRGraphAnchor from cube.graph.schedule.schedplan import SchedulePlan from cube.graph.function.pyfunc import IRPyFunc @@ -27,8 +29,9 @@ from cube.flags import CompileFlag -def compile(model: SemanticModel, dataloader: Optional[CubeDataLoader] = None, +def compile(model: SemanticModel, *args, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, + model_dummy_inputs: Tuple[Any] = None, comm_cost_fn: Optional[Callable] = None, override = True, load_content = True) -> Callable: """ AI Scientist calls like: @@ -48,8 +51,9 @@ def train_step(model, dataloader): ... @param model SemanticModel: AI Scientist specified SemanticModel - @param dataloader CubDataLoader: dataloader used for training + @param args: compile function example inputs @param PAS Callable: policy to transform and schedule graph + @param model_dummy_inputs Tuple[Any]: model example inputs when using torch.fx parser @param comm_cost_fn: Optional[Callable]: communication cost function, which takes in an IRAdapterPrim, and outputs a cost in float. By default (None) use communication volume. @@ -64,19 +68,28 @@ def train_step(model, dataloader): # clean global status Program().clear() IDGenerator().clear() + assert PAS is not None, f'PAS should be callable function' - if not isinstance(model, SemanticModel): - raise TypeError("Expect Semantic Model") - if dataloader is None: - # create empty dataloader - dataloader = cube.runtime.syndata.SynDataLoader(shapes=(),dtypes=()) - if not isinstance(dataloader, CubeDataLoader): - raise TypeError("Expect dataloader derived from CubeDataLoader") - + model = SemanticModel(model) if isinstance(model, torch.nn.Module) else model + assert isinstance(model, SemanticModel), f'Require cube.SemanticModel or torch.nn.Module, but got model: {type(model)}' model.save_content = load_content - if model.dummy_input is None: - model.dummy_input = next(dataloader) - ir_dataloader = SemanticDataLoader(dataloader) + model.dummy_input = model_dummy_inputs + + dataloader = None + inputs = [model] + for arg in args: + assert not isinstance(arg, (torch.nn.Module, SemanticModel)), f"Only one model can be input for compile" + if isinstance(arg, (torch.utils.data.Dataset, CubeDataLoader)): + assert dataloader is None + dataloader = arg + arg = SemanticDataLoader(dataloader) + elif isinstance(arg, torch.Tensor): + arg = IRFullTensor(arg.shape, name='tensor', + requires_grad=arg.requires_grad, + dtype=DType2IRDType.map(arg.dtype)).tosub() + else: + arg= IRObject('obj') + inputs.append(arg) myrank = DeviceGroup().rank @@ -101,12 +114,22 @@ def decorator(fn: Callable) -> Callable: # run once to get model structure and tensor shape start = time.time() - outputs = fn(model, ir_dataloader) + outputs = fn(*inputs) Program().finalize() if outputs is None: outputs = [] elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): outputs = [outputs] + # setup program input + pinputs = [] + for input in inputs[1:]: # we don't consider `model` as inputs + if isinstance(input, SemanticModel): + pinputs.append('model') + elif isinstance(input, SemanticDataLoader): + pinputs.append('dataloader') + else: + pinputs.append(input) + Program().set_input(pinputs) # setup program output Program().set_output(outputs) span = time.time() - start @@ -216,17 +239,18 @@ def decorator(fn: Callable) -> Callable: model.dummy_input = None # set dataloder batch size (serialize output) - bs = model.get_gen_module().get_batch_size() - print_each_rank(f'> setting batch size to: {bs}') - if torch.distributed.is_initialized(): - for rank in range(torch.distributed.get_world_size()): - if rank == torch.distributed.get_rank(): - if bs is not None and dataloader is not None: - dataloader.set_batch_size(bs) - torch.distributed.barrier() - else: - if bs is not None and dataloader is not None: - dataloader.set_batch_size(bs) + if dataloader is not None: + bs = model.get_gen_module().get_batch_size() + print_each_rank(f'> setting batch size to: {bs}') + if torch.distributed.is_initialized(): + for rank in range(torch.distributed.get_world_size()): + if rank == torch.distributed.get_rank(): + if bs is not None and dataloader is not None: + dataloader.set_batch_size(bs) + torch.distributed.barrier() + else: + if bs is not None and dataloader is not None: + dataloader.set_batch_size(bs) if torch.distributed.is_initialized(): torch.distributed.barrier() diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py index 03eb1937..f51b21e5 100644 --- a/examples/mlp/linearsfx.py +++ b/examples/mlp/linearsfx.py @@ -125,7 +125,9 @@ def train(): model = MLP(dim=dim) model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=PAS, load_content=False) + data, mask = next(dataloader) + @cube.compile(model, dataloader, PAS=PAS, load_content=False, + model_dummy_inputs={'data': data, 'mask': mask}) def train_iter(model, dataloader): data, mask = next(dataloader) loss = model(data, mask) diff --git a/tests/parser/test_compile.py b/tests/parser/test_compile.py new file mode 100644 index 00000000..fbcdc121 --- /dev/null +++ b/tests/parser/test_compile.py @@ -0,0 +1,165 @@ +""" +USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_compile.py +""" +import torch + +import cube +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.ir.tensor import IRFullTensor +from cube.graph.function.dimops import IRDimops + + +cube.init() + + +class TestOpModule(torch.nn.Module): + + def __init__(self, shape=[256, 512]): + super().__init__() + self.param = torch.nn.Parameter(torch.empty(shape, dtype=torch.float32)) + + def forward(self, x: torch.Tensor, cache: torch.Tensor): + x = x + cache + # [256, 512], [256, 512] -> [256, 512] + x = x * self.param + # [256, 512] -> [512] + x1 = x.select(0, 6) + # [256, 512], [512] -> [256, 512] + x2 = x.select_scatter(x1, 0, 7) + # [256, 512] -> [512, 512] + x3 = x2.repeat(2, 1) + # [512, 512] -> [256, 512]: this will be parsed to 2 slice operations + x4 = x3[:256,:] + loss = x4.sum() + return loss + + +class TestDataLoader1(cube.runtime.syndata.CubeDataLoader): + + def __init__(self) -> None: + self.sample = ( + torch.rand([256, 512], dtype=torch.float32, device=torch.cuda.current_device()), + torch.rand([256, 512], dtype=torch.float32, device=torch.cuda.current_device()), + ) + batch_size = self.sample[0][0] + super().__init__(batch_size, (0, 0)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +class TestDataLoader2(cube.runtime.syndata.CubeDataLoader): + + def __init__(self) -> None: + self.sample = torch.rand( + [256, 512], dtype=torch.float32, device=torch.cuda.current_device()) + batch_size = self.sample[0] + super().__init__(batch_size, (0,)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +model = TestOpModule() +dataloader1 = TestDataLoader1() +dataloader2 = TestDataLoader2() + + +def graph_check(graph): + for t in graph.inputs(): + assert not isinstance(t, IRFullTensor) + for node in graph.nodes(): + for t in node.inputs() + node.outputs(): + assert not isinstance(t, IRFullTensor) + for t in graph.outputs(): + assert not isinstance(t, IRFullTensor) + + +def policy(graph, resource): + graph_check(graph) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + +def test_compile_with_dataloader(): + global model + + sample, cache = next(dataloader1) + + @cube.compile(model, dataloader1, PAS=policy, + model_dummy_inputs={'x': sample, 'cache': cache}) + def train_step(model, dataloader): + data = next(dataloader) + print(data) + loss = model(*data) + loss.backward() + + gmodel = cube.load_model() + + for step in range(4): + train_step(gmodel, dataloader1) + print(f'step [{step}/4]') + + +def test_compile_without_dataloader(): + global model + + dummy_args = next(dataloader1) + sample, cache = dummy_args + + @cube.compile(model, sample, cache, PAS=policy, + model_dummy_inputs={'x': sample, 'cache': cache}) + def train_step(model, x, cache): + loss = model(x, cache) + loss.backward() + + gmodel = cube.load_model() + + for step in range(4): + x, cache = next(dataloader1) + train_step(gmodel, x, cache) + print(f'step [{step}/4]') + + + +def test_compile_with_complex(): + global model + + sample = next(dataloader2) + cache = torch.rand([256, 512], dtype=torch.float32, device=torch.cuda.current_device()) + + # @cube.compile(model, dataloader2, cache, PAS=policy) + # print(sample.size(), cache.size()) + + @cube.compile(model, dataloader2, cache, PAS=policy, + model_dummy_inputs={'x': sample, 'cache': cache}) + def train_step(model, dataloader, cache): + sample = next(dataloader) + loss = model(sample, cache) + loss.backward() + + gmodel = cube.load_model() + + for step in range(4): + train_step(gmodel, dataloader2, step) + print(f'step [{step}/4]') + + + +if __name__ == '__main__': + test_compile_with_dataloader() + test_compile_without_dataloader() + test_compile_with_complex() \ No newline at end of file From dc20936a37764ef3fd4e0cee931399880dfee2a6 Mon Sep 17 00:00:00 2001 From: Juntao Liang Date: Mon, 3 Apr 2023 06:42:42 +0000 Subject: [PATCH 1378/1892] Merged PR 1513: update tracer to support cpu/gpu mixed trace The new tracer modifications include: 1) `run_target`: to support tracing model on cpu, and move `call_function` & `call_module` to cuda in case they do not support cpu operation. 2) other features that keep up with the latest master branch of nni's tracer. --- .../parser/concrete_trace_utils/__init__.py | 2 +- .../concrete_trace_utils/concrete_proxy.py | 13 +- .../concrete_trace_utils/concrete_tracer.py | 216 +++++++++++++----- .../kwargs_shape_prop/kwargs_interpreter.py | 5 +- .../concrete_trace_utils/operator_patcher.py | 51 +++-- .../parser/concrete_trace_utils/utils.py | 10 +- 6 files changed, 215 insertions(+), 82 deletions(-) diff --git a/cube/graph/parser/concrete_trace_utils/__init__.py b/cube/graph/parser/concrete_trace_utils/__init__.py index b825294a..0e04acba 100644 --- a/cube/graph/parser/concrete_trace_utils/__init__.py +++ b/cube/graph/parser/concrete_trace_utils/__init__.py @@ -11,4 +11,4 @@ More information about concrete tracing can be found in the :func:`concrete_trace` documentation. """ -from .concrete_tracer import ConcreteTracer, concrete_trace \ No newline at end of file +from .concrete_tracer import ConcreteTracer, concrete_trace diff --git a/cube/graph/parser/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/concrete_trace_utils/concrete_proxy.py index 8b2fcc57..ae05767a 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_proxy.py @@ -40,7 +40,7 @@ class ConcreteProxy(Proxy): """ `ConcreteProxy` is a wrapped proxy carried the real intermediate value. - We can use it to trace a more compatibal model, and pass the branches. + We can use it to trace a more compatible model, and pass the branches. """ # TODO: python bytecode changes a lot in version 3.11. these ops should be updated. @@ -58,6 +58,7 @@ class ConcreteProxy(Proxy): op_call_ex = dis.opmap['CALL_FUNCTION_EX'] op_not = dis.opmap['UNARY_NOT'] op_unpack_sequence = dis.opmap['UNPACK_SEQUENCE'] + op_dict_merge = dis.opmap.get('DICT_MERGE', None) # DICT_MERGE is new in python 3.9 jump_before_opcodes = (op_compare, op_not) # occurred in different python versions @@ -73,11 +74,7 @@ def __init__(self, node: Node, value: Any, tracer: Optional[et.ConcreteTracer] = self.node = node def __repr__(self) -> str: - # to detect if in debugging or in code - calling_frame_name = inspect.stack()[1][1] - if calling_frame_name.endswith('pydevd_exe2.py') or calling_frame_name.endswith('pydevd_safe_repr.py'): - return f'ConcreteProxy({self.node.name})' - return repr(self.value) + return f'ConcreteProxy({self.node.name}, {self.value})' def __getattr__(self, k) -> ConcreteProxy: return ConcreteAttrProxy(self, k) @@ -201,7 +198,7 @@ def keys(self): while insts[cur].opcode == self.op_extended_arg: cur += 1 - if insts[cur].opcode == self.op_call_ex: + if insts[cur].opcode == self.op_call_ex or insts[cur].opcode == self.op_dict_merge: # in executing `**proxy` return self.value.keys() else: @@ -419,4 +416,4 @@ def impl(self, rhs): for orig_method_name in reflectable_magic_methods: - _define_reflectable(orig_method_name) \ No newline at end of file + _define_reflectable(orig_method_name) diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index 73d1d406..72f48be5 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -3,12 +3,13 @@ from __future__ import annotations +import collections import sys import inspect +import logging import operator import functools import builtins -import copy from itertools import chain from types import BuiltinMethodType, FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType @@ -20,6 +21,7 @@ from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict from torch.utils._pytree import tree_map +import torch.fx from torch.fx import GraphModule from torch.fx._compatibility import compatibility from torch.fx._symbolic_trace import _Patcher, _proxyable_classes @@ -27,6 +29,19 @@ from torch.fx.node import Target, Node from torch.fx.proxy import TracerBase +try: + # Scope is a new class to record module path in pytorch 2.0 + from torch.fx.proxy import Scope +except ImportError: + # copy from pytorch 2.0 + @compatibility(is_backward_compatible=False) + class Scope: + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + + from . import concrete_proxy as ep from .operator_patcher import OperatorPatcherContext from .utils import ( @@ -39,6 +54,7 @@ _orig_type, _orig_isinstance, + _orig_issubclass, _orig_getattr, _orig_range, @@ -53,6 +69,8 @@ _orig_zip, _orig_enumerate, _orig_slice, + _orig_reversed, + _orig_torch_size, _orig_len, _orig_not, @@ -60,17 +78,26 @@ _orig_is_not, _orig_contains, _orig_index, + + _orig_all, + _orig_min, + _orig_max, ) + +_logger = logging.getLogger(__name__) HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS @compatibility(is_backward_compatible=True) class ConcreteTracer(TracerBase): """ - A model tracer similar to _symbolic_trace.Tracer, but with concrete execution and real value so we can pass complecate conditions + A model tracer similar to _symbolic_trace.Tracer, but with concrete execution and real value so we can pass complex conditions and go into correct brunches. """ + default_module_getattr = ( + 'training', + ) default_autowrap_modules = ( 'math', ) @@ -82,15 +109,27 @@ class ConcreteTracer(TracerBase): _orig_is_not: ([], False, None), _orig_contains: ([], False, None), _orig_index: ([], False, None), - - # force-traced function + _orig_all: ((), False, None), + _orig_min: ((), False, None), + _orig_max: ((), False, None), + + # force-traced function (the factory functions of tensor creation) + torch.arange: ([], True, None), + torch.empty: ([], True, None), + torch.eye: ([], True, None), + torch.full: ([], True, None), + torch.linspace: ([], True, None), + torch.logspace: ([], True, None), + torch.ones: ([], True, None), torch.rand: ([], True, None), - torch.randn: ([], True, None), torch.randint: ([], True, None), - torch.rand_like: ([], True, None), - torch.randn_like: ([], True, None), - torch.randint_like: ([], True, None), + torch.randn: ([], True, None), + # torch.rand_like: ([], True, None), # seems that xxx_like will not directly call torch._TensorBase.xxx + # torch.randn_like: ([], True, None), + # torch.randint_like: ([], True, None), torch.randperm: ([], True, None), + torch.tensor: ([], True, None), + torch.zeros: ([], True, None), # method Sequential.__getitem__: ([], False, operator.getitem), @@ -163,7 +202,15 @@ class ConcreteTracer(TracerBase): _orig_set: ([], True), _orig_frozenset: ([], True), _orig_dict: ([], True), + _orig_reversed: ((), False), + + _orig_torch_size: ((), False), } + + # add these to record module path information during tracing + current_module_qualified_name : str = '' + node_to_originating_module : Dict[torch.fx.Node, str] = {} + @compatibility(is_backward_compatible=True) def __init__(self): """ @@ -171,6 +218,9 @@ def __init__(self): remove the 'param_shapes_constant' because we can get real shape when executing. """ super().__init__() + self.scope = Scope("", None) + self.module_stack = collections.OrderedDict() + self.node_name_to_scope = {} @contextmanager def do_temp_disable(self, call=False, attr=False, agfunc_apply=False): @@ -317,7 +367,8 @@ def proxy(self, value: Any, node: Node) -> ep.ConcreteProxy: @compatibility(is_backward_compatible=True) def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], - name: Optional[str] = None, type_expr: Optional[Any] = None): + name: Optional[str] = None, type_expr: Optional[Any] = None, + proxy_factory_fn: Optional[Callable[[Node], Any]] = None): """ similar to _symbolic_trace.Tracer.create_proxy. use the 'run_target' to actually execute the code, and store the value in 'value' field. @@ -332,14 +383,16 @@ def upwrapper(obj: Any): # real value by execution value_unwrapped = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) - args_noded = self.create_arg(args) - kwargs_noded = self.create_arg(kwargs) + args_ = self.create_arg(args) + kwargs_ = self.create_arg(kwargs) + assert isinstance(args_, tuple) + assert isinstance(kwargs_, dict) - assert isinstance(args_noded, tuple) - assert isinstance(kwargs_noded, dict) + node = self.create_node(kind, target, args_, kwargs_, name, type_expr) - node = self.create_node(kind, target, args_noded, kwargs_noded, name, type_expr) - return self.proxy(value_unwrapped, node) + proxy = self.proxy(value_unwrapped, node) + self.node_to_originating_module[proxy.node] = self.current_module_qualified_name + return proxy @compatibility(is_backward_compatible=True) def create_arg(self, a: Any) -> Union[Node, Any]: @@ -542,15 +595,15 @@ def proxy_placeholder(name: str): @compatibility(is_backward_compatible=True) def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, - autowrap_modules: Tuple[str] = (), - autowrap_leaf_function = {}, - autowrap_leaf_class = {}, - leaf_module = (), - fake_middle_class = (), - concrete_args: Union[Dict[str, Any], Tuple], - use_operator_patch: bool = True, - operator_patch_backlist: List[str] = [], - forwrad_function_name: str = 'forward') -> Graph: + autowrap_modules: Tuple[str] | None = None, + autowrap_leaf_function = None, + autowrap_leaf_class = None, + leaf_module = None, + fake_middle_class = None, + concrete_args: Union[Dict[str, Any], Tuple], + use_operator_patch: bool = True, + operator_patch_backlist: List[str] | None = None, + forward_function_name: str = 'forward') -> Graph: """ similar to _symbolic_trace.Tracer.trace different args: @@ -571,18 +624,38 @@ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, such as '__main__.FooModel' or '__main__.bar_func'. the namespace is always needed. """ + # fill default values + args = inspect.getfullargspec(root.forward).args[1:] + defaults = inspect.getfullargspec(root.forward).defaults + defaults = tuple() if defaults is None else defaults + if isinstance(concrete_args, (tuple, list)): + concrete_args = (*concrete_args, *defaults[len(concrete_args) + len(defaults) - len(args):]) + else: + kv_default = {k: v for k, v in zip(args[-len(defaults):], defaults)} + concrete_args = { + **concrete_args, + **{n: kv_default[n] for n in args if n not in concrete_args} + } + + # preprocess arguments + autowrap_modules = autowrap_modules if autowrap_modules is not None else tuple() + autowrap_leaf_function = autowrap_leaf_function if autowrap_leaf_function is not None else {} + autowrap_leaf_class = autowrap_leaf_class if autowrap_leaf_class is not None else {} + leaf_module = leaf_module if leaf_module is not None else () + fake_middle_class = fake_middle_class if fake_middle_class is not None else () + operator_patch_backlist = operator_patch_backlist if operator_patch_backlist is not None else [] # Python modules to apply autowrap to at the start, in addition to # modules we see while tracing - self._autowrap_search: List[ModuleType] = list(sys.modules[m] for m in (*autowrap_modules, *ConcreteTracer.default_autowrap_modules)) + self._autowrap_search: List[ModuleType] = list( + sys.modules[m] for m in (*autowrap_modules, *ConcreteTracer.default_autowrap_modules) + ) # Functions we will eagerly wrap when we see them while tracing # this captures both `math.sqrt()` and `from math import sqrt` automatically self._autowrap_function_ids: Set[int] = { id(value) for name, value in chain(*[m.__dict__.items() for m in self._autowrap_search]) if not name.startswith("_") and callable(value)} - self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None - self.autowrap_leaf_function = {**autowrap_leaf_function, **ConcreteTracer.default_autowrap_leaf_function} self.autowrap_leaf_class = {**autowrap_leaf_class, **ConcreteTracer.default_autowrap_leaf_class} self.leaf_module = leaf_module @@ -592,27 +665,17 @@ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, # TODO: better infomation assert hasattr( - root, forwrad_function_name - ), f"traced_func_name={forwrad_function_name} doesn't exist in {_orig_type(root).__name__}" + root, forward_function_name + ), f"traced_func_name={forward_function_name} doesn't exist in {_orig_type(root).__name__}" - fn = getattr(root, forwrad_function_name) + fn = getattr(root, forward_function_name) self.submodule_paths = {mod: name for name, mod in root.named_modules()} else: self.root = torch.nn.Module() fn = root tracer_cls = getattr(self, '__class__', None) - self.graph = Graph(tracer_cls=tracer_cls, tracer_extras={ - 'autowrap_modules': autowrap_modules, - 'autowrap_leaf_function': autowrap_leaf_function, - 'autowrap_leaf_class': autowrap_leaf_class, - 'leaf_module': leaf_module, - 'fake_middle_class': fake_middle_class, - 'concrete_args': concrete_args, - 'use_operator_patch': use_operator_patch, - 'operator_patch_backlist': operator_patch_backlist, - 'forwrad_function_name': 'forward', - }) + self.graph = Graph(tracer_cls=tracer_cls) # When we encounter a Tensor value that's not a parameter, we look if it # is some other attribute on the model. Construct a dict mapping Tensor @@ -670,6 +733,8 @@ def module_getattribute_wrapper(mod, attr): if attr_val in self.wrapped_leaf: return self.wrapped_leaf[attr_val][1] return attr_val + elif attr in self.default_module_getattr: + return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) elif _orig_isinstance(attr_val, (_orig_tuple, _orig_list)): if self.the_path_of_middle_class[id(mod)] == '': return self.create_proxy('get_attr', f'{attr}', (), {}) @@ -688,8 +753,16 @@ def module_call_wrapper(mod, *args, **kwargs): else: module_qualified_name = self.path_of_module(mod) if not self.is_leaf_module(mod, module_qualified_name): - _autowrap_check(self, mod.forward.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - _autowrap_check(self, mod.__dict__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + _autowrap_check(self, + mod.forward.__globals__, + self._autowrap_function_ids, + self.autowrap_leaf_pairs, + self.agfunc_dict) + _autowrap_check(self, + mod.__dict__, + self._autowrap_function_ids, + self.autowrap_leaf_pairs, + self.agfunc_dict) return _orig_module_call(mod, *args, **kwargs) else: return self.create_proxy('call_module', module_qualified_name, args, kwargs) @@ -898,6 +971,20 @@ def isinstance_wrapper(instance, clz): instance = instance.value return _orig_isinstance(instance, clz) + @functools.wraps(_orig_issubclass) + def issubclass_wrapper(subclass, clz): + if _orig_type(clz) in (slice, tuple, list, _orig_slice, _orig_tuple, _orig_list): + clz_wrapped = [] + for wrapped_type, orig_type in self.clz_wrapper_map.items(): + if wrapped_type in clz: + clz_wrapped.append(orig_type) + clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in self.clz_wrapper_map)) + return _orig_issubclass(subclass, clz) + else: + if clz in self.clz_wrapper_map: + clz = self.clz_wrapper_map[clz] + return _orig_issubclass(subclass, clz) + @functools.wraps(_orig_getattr) def getattr_wrapper(obj, *args): # TODO: better infomation @@ -929,6 +1016,7 @@ def getattr_wrapper(obj, *args): self.patcher.patch_method(builtins, "range", range_wrapper, deduplicate=False) self.patcher.patch_method(builtins, "type", type_wrapper, deduplicate=False) self.patcher.patch_method(builtins, "isinstance", isinstance_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "issubclass", issubclass_wrapper, deduplicate=False) self.patcher.patch_method(builtins, "getattr", getattr_wrapper, deduplicate=False) for obj, (positions, wrapped) in self.wrapped_leaf.items(): @@ -947,6 +1035,7 @@ def getattr_wrapper(obj, *args): finally: # for cuda versions of pytorch, autograd.Function.apply should be reverted manually delattr(torch.autograd.Function, 'apply') + _retain_weight_consistency(self.root) pass self.submodule_paths = None @@ -1124,7 +1213,7 @@ def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: @staticmethod def format_import_statement_new(name: str, obj: Any, importer) -> str: if isinstance(obj, BuiltinMethodType) and getattr(obj, '__name__', None) == 'apply'\ - and isinstance(getattr(obj, '__self__', None), Type) and issubclass(obj.__self__, torch.autograd.Function): + and isinstance(getattr(obj, '__self__', None), Type) and issubclass(obj.__self__, torch.autograd.Function): # type: ignore # torch.autograd.function return MagicMethodPatcher.format_import_statement_ori(name, obj.__self__, importer) + f'\n{name} = {name}.apply' return MagicMethodPatcher.format_import_statement_ori(name, obj, importer) @@ -1271,17 +1360,37 @@ def clz_getattr_wrapper(obj, attr): return tracer.create_proxy('get_attr', f'{the_path_of_middle_class[id(obj)]}.{attr}', (), {}) return clz_getattr_wrapper +def _retain_weight_consistency(root: torch.nn.Module): + _flag = 0 + for module in root.modules(): + for name, param in module.named_parameters(): + if _orig_isinstance(param, ep.ConcreteProxy): + param: ep.ConcreteProxy # pyright: reportGeneralTypeIssues=false + _logger.warning(f'Parameter {name} of {module} is a ConcreteProxy. Some weight may be modified inplace within forward().') + setattr(module, name, param.value) + _flag |= 1 + for name, buffer in module.named_buffers(): + if _orig_isinstance(buffer, ep.ConcreteProxy): + buffer: ep.ConcreteProxy # pyright: reportGeneralTypeIssues=false + _logger.warning(f'Buffer {name} of {module} is a ConcreteProxy. Some buffer may be modified inplace within forward().') + setattr(module, name, buffer.value) + _flag |= 1 + if _flag: + _logger.warning('Some weight or buffer is modified inplace within forward(). This may cause unexpected behavior.' + ' ``concrete_trace`` may not guarantee the consistency of the traced graph.') + return root + def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Union[Dict[str, Any], Tuple], *, - use_operator_patch: bool = False, - operator_patch_backlist: List[str] = [], - forwrad_function_name: str = 'forward', + use_operator_patch: bool = True, + operator_patch_backlist: List[str] | None = None, + forward_function_name: str = 'forward', check_args: Optional[Dict[str, Any]] = None, - autowrap_leaf_function = {}, - autowrap_leaf_class = {}, - leaf_module = (), - fake_middle_class = ()) -> GraphModule: + autowrap_leaf_function = None, + autowrap_leaf_class = None, + leaf_module: Tuple | None = None, + fake_middle_class = None,) -> GraphModule: """ Concrete tracing API @@ -1408,6 +1517,7 @@ def f(x, y): fx.GraphModule: a Module created from the recorded operations from ``root``. """ tracer = ConcreteTracer() + graph = tracer.trace(root, autowrap_leaf_function = autowrap_leaf_function, autowrap_leaf_class = autowrap_leaf_class, @@ -1416,7 +1526,7 @@ def f(x, y): concrete_args=concrete_args, use_operator_patch=use_operator_patch, operator_patch_backlist=operator_patch_backlist, - forwrad_function_name=forwrad_function_name, + forward_function_name=forward_function_name, ) graph_check = tracer.trace(root, autowrap_leaf_function = autowrap_leaf_function, @@ -1426,7 +1536,7 @@ def f(x, y): concrete_args=concrete_args, use_operator_patch=use_operator_patch, operator_patch_backlist=operator_patch_backlist, - forwrad_function_name=forwrad_function_name, + forward_function_name=forward_function_name, ) # compare to check equal assert len(graph.nodes) == len(graph_check.nodes) diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py index 1d9aac75..8d9a5e11 100644 --- a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py +++ b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py @@ -102,9 +102,8 @@ def run_node(self, n: Node) -> Any: Returns: Any: The result of executing ``n`` """ - with fx_traceback.append_stack_trace(n.stack_trace): - args, kwargs = self.fetch_args_kwargs_from_env(n) - return getattr(self, n.op)(n.target, args, kwargs) + args, kwargs = self.fetch_args_kwargs_from_env(n) + return getattr(self, n.op)(n.target, args, kwargs) def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ diff --git a/cube/graph/parser/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/concrete_trace_utils/operator_patcher.py index 1f6eb3ac..ef9756cd 100644 --- a/cube/graph/parser/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/concrete_trace_utils/operator_patcher.py @@ -6,8 +6,10 @@ from .concrete_tracer import ConcreteTracer import ast +import builtins import inspect import logging +import platform from textwrap import dedent from types import MethodType, FunctionType @@ -21,6 +23,7 @@ _orig_len, _orig_dict, _orig_zip, + _orig_tuple, ) _logger = logging.getLogger(__name__) @@ -49,7 +52,16 @@ def visit(self, node): return super().visit(node) def visit_Call(self, node: ast.Call): - if not isinstance(node.func, ast.Name) or node.func.id != 'patch_run': + if isinstance(node.func, ast.Name) and node.func.id == 'super' and len(node.args) == 0: + return self.generic_visit(ast.Call( + func=ast.Name(id='super', ctx=ast.Load()), + args=[ + ast.Attribute(value=ast.Name(id='self', ctx=ast.Load()), attr='__class__', ctx=ast.Load()), + ast.Name(id='self', ctx=ast.Load()), + ], + keywords=node.keywords, + )) + elif not isinstance(node.func, ast.Name) or node.func.id != 'patch_run': self.is_transformed = True return self.generic_visit(ast.Call( func=ast.Name(id='patch_run', ctx=ast.Load()), @@ -226,20 +238,27 @@ def patch_inner_helper(self, func): assert _orig_len(closures) == _orig_len(co_freevars) closure_dict = _orig_dict(_orig_zip(co_freevars, [c.cell_contents for c in closures])) - var_dict = {} - exec( - # use func.__code__.co_filename to make the new function easily debuggable. - compile(new_tree, func_inner.__code__.co_filename, 'exec'), - { - 'patch_run': OperatorPatcherContext.patch_run, - **func_inner.__globals__, - **closure_dict, - }, - var_dict) - if the_self is not None: - return var_dict['new_func'].__get__(the_self) - else: - return var_dict['new_func'] + tuple_wrapped = tuple + try: + if platform.python_version_tuple() < ('3', '9'): + setattr(builtins, 'tuple', _orig_tuple) + var_dict = {} + exec( + # use func.__code__.co_filename to make the new function easily debuggable. + compile(new_tree, func_inner.__code__.co_filename, 'exec'), + { + 'patch_run': OperatorPatcherContext.patch_run, + **func_inner.__globals__, + **closure_dict, + }, + var_dict) + if the_self is not None: + return var_dict['new_func'].__get__(the_self) + else: + return var_dict['new_func'] + finally: + if platform.python_version_tuple() < ('3', '9'): + setattr(builtins, 'tuple', tuple_wrapped) class OperatorPatcherContext: ctx_tracer: Optional['ConcreteTracer'] = None @@ -267,4 +286,4 @@ def patch_run(func, *args, **kwargs): assert OperatorPatcherContext.ctx_patcher is not None with OperatorPatcherContext.ctx_tracer.do_temp_disable(True, True, True): new_func = OperatorPatcherContext.ctx_patcher.patch_inner(func) - return new_func(*args, **kwargs) \ No newline at end of file + return new_func(*args, **kwargs) diff --git a/cube/graph/parser/concrete_trace_utils/utils.py b/cube/graph/parser/concrete_trace_utils/utils.py index 3cfc95dc..1d8a48f1 100644 --- a/cube/graph/parser/concrete_trace_utils/utils.py +++ b/cube/graph/parser/concrete_trace_utils/utils.py @@ -18,6 +18,7 @@ _orig_type: Callable = builtins.type _orig_isinstance: Callable = builtins.isinstance +_orig_issubclass: Callable = builtins.issubclass _orig_getattr: Callable = builtins.getattr _orig_range: Type[Any] = builtins.range @@ -32,6 +33,8 @@ _orig_zip: Type[Any] = builtins.zip _orig_enumerate: Type[Any] = builtins.enumerate _orig_slice: Type[Any] = builtins.slice +_orig_reversed: Type[Any] = builtins.reversed +_orig_torch_size: Type[Any] = torch.Size _orig_len: Callable = builtins.len _orig_not: Callable = operator.not_ @@ -40,6 +43,11 @@ _orig_contains: Callable = operator.contains _orig_index: Callable = operator.index +_orig_all: Callable = builtins.all +_orig_min: Callable = builtins.min +_orig_max: Callable = builtins.max + + def run_onlyif_instance(cond_type: Type[Any], return_orig: bool = True, return_const: Any = None): def helper(fn): if return_orig: @@ -94,4 +102,4 @@ def map_recursive_zip(fn: Callable, arg0, *args) -> Any: return {k: map_recursive_zip(fn, arg0[k], *(arg[k] for arg in args)) for k in keys} else: # assert not _orig_isinstance(arg0, slice) - return fn(arg0, *args) \ No newline at end of file + return fn(arg0, *args) From cb1b5ab7e0bdf953c7ae6be06ccab18507748557 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 3 Apr 2023 10:25:23 +0000 Subject: [PATCH 1379/1892] Merged PR 1522: [need test] graph dump and load support --- cube/graph/function/dimops.py | 13 --- cube/graph/graph.py | 57 ++++++++++++ cube/ir/cten.py | 11 +++ cube/ir/tensor.py | 8 +- cube/ir/unique.py | 7 ++ requirements.txt | 1 + tests/graph/test_dump_load.py | 157 ++++++++++++++++++++++++++++++++++ 7 files changed, 237 insertions(+), 17 deletions(-) create mode 100644 tests/graph/test_dump_load.py diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 0a696dc3..da496adc 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -632,19 +632,6 @@ def anno(self) -> OpAnno: def transform_rules(self) -> Tuple[TransformRule]: return self._trans_rules - def getstate_for_dump(self): - state = self.__dict__.copy() - state['_create_fn'] = { - 'name': self._create_fn[0].__name__, - 'module': self._create_fn[0].__module__, - } - return state - - def setstate_for_load(self, state): - module = importlib.import_module(state['_create_fn']['module']) - state['_create_fn'] = (getattr(module, state['_create_fn']['name']),) - self.__dict__.update(state) - def ianno(self, index: int) -> Tuple[DimAnno]: """! Get index-th input tensor shape annotation diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 35680f6d..a8707742 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -10,6 +10,8 @@ from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any import warnings import copy +import pickle +import dill from cube.ir.cten import IRTensor, IRCell, IRObject from cube.ir.unique import IDGenerator @@ -20,6 +22,7 @@ from cube.graph.function.function import Identity, MultiRef from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.pyfunc import IRPyFunc +from cube.graph.function.dimops import IRDimops from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo @@ -921,3 +924,57 @@ def auto_multiref(self): self.finsert(multiref, fidx) else: self.insert(multiref, fidx) + + def dump(self, filename: str) -> None: + """ + Dump the graph into pickled format + + @param filename str + """ + # FIXME: dump doesn't support customized op + class PicklingContextSave: + def __enter__(self): + IRObject.__getstate__ = IRObject.getstate_for_dump + def __exit__(self, exc_type, exc_value, traceback): + IRObject.__getstate__ = lambda self: self.__dict__.copy() + + with PicklingContextSave(): + with open(filename, 'wb') as f: + save = (IDGenerator().get_states(), self) + dill.dump(save, f) + + @staticmethod + def load(filename: str): + """ + Load the graph from pickled file. + Note IDGenerator will also be reset to match with graph status + + @param filename str + + @return graph IRGraph + """ + with open(filename, 'rb') as f: + id_state, graph = dill.load(f) + + # recover IRGenerator + IDGenerator().load_states(id_state) + # recover cell + def reset_node(segment: IRSegment): + # input + for t in segment.inputs(): + if isinstance(t, IRObject): + t.cell = segment + # nodes + for node in segment.nodes(): + for t in node.inputs() + node.outputs(): + if isinstance(t, IRObject): + t.cell = node + # recursively recover segments + if isinstance(node, IRSegment): + reset_node(node) + # output + for t in IRSegment.get_objects_from_complex(segment.outputs()): + t.cell = segment + + reset_node(graph) + return graph diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 41bdda31..e7c0c1ff 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -451,6 +451,17 @@ def __eq__(self, obj): def __hash__(self) -> int: return self._id + def getstate_for_dump(self): + """ + __getstate__ method for pickle dump + + @warning: dump an IRObject will disconnect the tensor to its cell + """ + state = self.__dict__.copy() + # this will decouple the interconnected object and cell during dump. + state['_cell'] = None + return state + @property def tid(self) -> int: """Get object id""" diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 5f3937a0..1fa768f6 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -261,7 +261,7 @@ def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=IRDType super().__init__(shape, name, dtype) # record all created sub_tensors - self._segments : Dict[(ValueMap, IndexMap), int] = dict() + self._subtensors : Dict[(ValueMap, IndexMap), int] = dict() self.requires_grad = requires_grad self._is_loss = False @@ -397,12 +397,12 @@ def select(self, indmap: IndexMap, valmap: ValueMap): keys = (indmap, valmap) # print(f'key: {keys}, hash {hash(keys)}') # return tensor to keep id same for same sub tensor - if keys in self._segments: - tid = self._segments[keys] + if keys in self._subtensors: + tid = self._subtensors[keys] sub_tensor = IRSubTensor(self, indmap, valmap, tid=tid) else: sub_tensor = IRSubTensor(self, indmap, valmap) - self._segments[keys] = sub_tensor.tid + self._subtensors[keys] = sub_tensor.tid return sub_tensor def tosub(self): diff --git a/cube/ir/unique.py b/cube/ir/unique.py index 635a456c..b40851e4 100644 --- a/cube/ir/unique.py +++ b/cube/ir/unique.py @@ -29,6 +29,13 @@ def gen_cell_id(self): self.instance._cell_id += 1 return self.instance._cell_id + def get_states(self): + return (self._tensor_id, self._cell_id) + + def load_states(self, states: tuple): + IDGenerator.instance._tensor_id = states[0] + IDGenerator.instance._cell_id = states[1] + def clear(self): self.instance._tensor_id = 0 self.instance._cell_id = 0 diff --git a/requirements.txt b/requirements.txt index 94cf5b68..b1866504 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ matplotlib pytest setuptools==60.7.0 more-itertools +dill diff --git a/tests/graph/test_dump_load.py b/tests/graph/test_dump_load.py new file mode 100644 index 00000000..df036df3 --- /dev/null +++ b/tests/graph/test_dump_load.py @@ -0,0 +1,157 @@ +""" +USE_TORCHFX=1 torchrun --nproc_per_node=2 tests/graph/test_dump_load.py +""" +from typing import List +import torch +from cube.ir.cten import IRObject + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops + + +cube.init() + + +def _param(size, dtype=torch.float32): + return torch.nn.Parameter(torch.empty(size, dtype=dtype)) + + +class TestOpModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.param1 = _param([512, 256]) + self.param2 = _param([512, 256]) + self.ints = [1, 2, 3] + + def forward(self, x: torch.Tensor): + # matmul: [bs, 512], [512, 256] -> [bs, 256] + x1 = torch.matmul(x, self.param1) + # [bs, 256] -> [bs, 256] + x1 = x1 + x1.size(0) + x1.size()[0] + # [bs, 256] -> [bs, 128], [bs, 128] + x2 = torch.chunk(x1, 2, dim=1)[0] + # [bs, 128] -> [bs, 128] + x3 = x2 + x2.size(0) + x4 = x3 + self.ints[0] + # [bs, 128] -> [1] + loss = torch.sum(x4) + return {'x': x4, 'loss': loss} # , [x3,] + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int = 256) -> None: + self.sample = torch.rand( + [batch_size, 512], + dtype=torch.float32, + device=torch.cuda.current_device() + ) + super().__init__(batch_size, (0,)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +def test_graph_dump_load_single(): + + model = TestOpModule() + dataloader = TestDataLoader() + + def policy(graph: IRGraph, resource): + print('================ original one:') + print(graph.extra_repr()) + + graph.dump('graph.pickle') + new_graph = IRGraph.load('graph.pickle') + + print('================ loaded from pickled one:') + print(graph.extra_repr()) + + for node in graph.nodes(): + for t in node.inputs(): + if isinstance(t, IRObject): + assert t.cell is not None + + assert graph.extra_repr() == new_graph.extra_repr() + + assert resource.ngpus == 1 + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'add': + sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=0, num=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph + + + @cube.compile(model, dataloader, PAS=policy, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) + def train_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out['loss'].backward() + + model = cube.load_model(load_content=False) + + for idx in range(3): + train_iter(model, dataloader) + print(f"iter {idx}/3") + + +def test_graph_dump_load_with_transform(): + + model = TestOpModule() + dataloader = TestDataLoader() + + def policy(graph: IRGraph, resource): + print('================ original one:') + print(graph.extra_repr()) + old_repr = graph.extra_repr() + + graph.dump('graph.pickle') + graph = IRGraph.load('graph.pickle') + + print('================ loaded from pickled one:') + print(graph.extra_repr()) + new_repr = graph.extra_repr() + + for node in graph.nodes(): + for t in node.inputs(): + if isinstance(t, IRObject): + assert t.cell is not None + + assert new_repr == old_repr + + assert resource.ngpus == 2 + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + + @cube.compile(model, dataloader, PAS=policy, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) + def train_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out['loss'].backward() + + model = cube.load_model(load_content=False) + + for idx in range(3): + train_iter(model, dataloader) + print(f"iter {idx}/3") + + +if __name__ == '__main__': + # test_graph_dump_load_single() + test_graph_dump_load_with_transform() From c1e75192e9777c42ed23f43acc6041fe812879ac Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 7 Apr 2023 05:22:07 +0000 Subject: [PATCH 1380/1892] Merged PR 1529: dedup same tensors in backward dedup same tensors --- cube/runtime/executor.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 0247854c..65ba43c8 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -148,7 +148,14 @@ def backward(name: str, dtensors: List[torch.Tensor] = [pair[1] for pair in saved_pairs] for t in input_tensors: if id(t) not in tensor_ids: - warnings.warn("input doesn't match. Make sure in scheduling that earlier forward perform earlier backward") + import traceback + warnings.warn( + f"rank {torch.distributed.get_rank()}: input {name} doesn't match. " + f"Make sure in scheduling, earlier forward perform earlier backward. " + f"Remain {len(Executor._detach[name])} segments.\n" + f"{''.join(traceback.format_stack())}" + ) + input_tensors = [] for t in dtensors: @@ -156,9 +163,18 @@ def backward(name: str, t.retain_grad() input_tensors.append(t) + visited = set() + dedup_output_tensors = [] + dedup_output_tensor_grads = [] + for t, g in zip(output_tensors, output_tensor_grads): + if id(t) not in visited: + visited.add(id(t)) + dedup_output_tensors.append(t) + dedup_output_tensor_grads.append(g) + torch.autograd.backward( - output_tensors, - grad_tensors=output_tensor_grads, + dedup_output_tensors, + grad_tensors=dedup_output_tensor_grads, ) grads = tuple(t.grad for t in input_tensors) assert all(grad is not None for grad in grads), "RuntimeError: got gradient None" From b7b34b02faaccdeb660e06aeaebeb3fda7ec02dd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 7 Apr 2023 08:26:47 +0000 Subject: [PATCH 1381/1892] Merged PR 1527: Allow flexible graph output returning --- cube/codegen/emit.py | 24 ++- cube/graph/gener/gen.py | 126 ++++++++++------ cube/graph/gener/utils.py | 7 +- cube/graph/graph.py | 266 ++++++++++++++------------------- cube/graph/segment.py | 166 ++++++++++---------- cube/ir/tensor.py | 26 ++-- tests/graph/test_infer_grad.py | 134 +++++++++++++++++ tests/graph/test_multiref.py | 115 ++++++++++++++ 8 files changed, 558 insertions(+), 306 deletions(-) create mode 100644 tests/graph/test_infer_grad.py create mode 100644 tests/graph/test_multiref.py diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 59436b9f..965ed86b 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -229,7 +229,7 @@ def emit_release(tensors: Iterable[IRTensor]) -> str: return 'del ' + ', '.join(tnames) @staticmethod - def get_backward_callsite_io_tensors(bw_cell: IRCell) -> Tuple: + def get_backward_callsite_io_tensors(bwop: IRCell) -> Tuple: """ Get backward inputs and outputs ``` @@ -245,19 +245,17 @@ def get_backward_callsite_io_tensors(bw_cell: IRCell) -> Tuple: @return input_grads List[IRSubTensor]: gradient of forward input tensors (backward output) """ - assert not bw_cell.isfw() + assert not bwop.isfw() + fwop: IRCell = bwop.mirror - input_tensors = [t for t in bw_cell.mirror.inputs() if \ - isinstance(t, IRSubTensor) and \ - t.requires_grad and \ - not t.is_attr() - ] - output_tensors = [t for t in bw_cell.mirror.outputs() if isinstance(t, IRSubTensor) and t.grad is not None] - input_grads = [t.grad for t in input_tensors] + grad2tensor = {} + for t in fwop.inputs() + fwop.outputs(): + if isinstance(t, IRSubTensor) and t.grad is not None: + grad2tensor[t.grad] = t - # WARNING !!! - # non-tensor gradients like scalar '1.0f' are removed in 'bpSeg.inputs()' - # so the items of 'bpSeg.inputs()' are generally disaligned with 'output_grads' here. - output_grads = [t.grad for t in output_tensors] + input_grads = [t for t in bwop.outputs() if isinstance(t, IRSubTensor)] + output_grads = [t for t in bwop.inputs() if isinstance(t, IRSubTensor)] + input_tensors = [grad2tensor[g] for g in input_grads] + output_tensors = [grad2tensor[g] for g in output_grads] return input_tensors, output_tensors, output_grads, input_grads diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index dd57cb5a..152bd73a 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,12 +1,13 @@ -from typing import Dict, List, Optional, Tuple, Callable +from typing import Dict, List, Optional, Tuple, Callable, Set import numpy as np import itertools +import warnings from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener import cube.graph.gener.utils as utils from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment, CellPosition +from cube.graph.segment import IRSegment from cube.graph.function.pyfunc import IRPyFunc from cube.ir.cten import IRCell, IRObject @@ -33,36 +34,29 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) # devices = segment.device input_producers: Dict[IRFullTensor, List[IRCell]] = {} output_consumers: Dict[IRFullTensor, List[IRCell]] = {} + devices = segment.device # create inputs if inputs: input_objects = IRGraph.get_objects_from_complex(segment.inputs()) for tensor in input_objects: - devices = [consumer.device for consumer in segment.consumers(tensor.parent)][::-1] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = utils.DummyInputOuput(tensor, 0, is_output=True, name=f'segment{segment.cid}_input') - for devid in devices: - fop = fwop.replicate() - fop.device = devid - if tensor.requires_grad: - fop.output(0).grad = tensor.parent.grad.select(tensor.indmap, (0, 1)) - fop.output(0).grad.cell = fop - input_producers.setdefault(tensor.parent, []).append(fop) + fwop = utils.DummyInputOuput(tensor, devices, is_output=True, name=f'segment{segment.cid}_input') + if tensor.grad is not None: + fwop.output(0).grad = tensor.parent.grad.select(tensor.indmap, (0, 1)) + fwop.output(0).grad.cell = fwop + input_producers.setdefault(tensor.parent, []).append(fwop) # create outputs if outputs: output_objects = IRGraph.get_objects_from_complex(segment.outputs()) for tensor in output_objects: - devices = [producer.device for producer in segment.producers(tensor.parent)] if not isinstance(tensor, IRSubTensor): continue assert tensor.valmap == (0, 1), f"valmap != (0, 1):\n{segment.extra_repr()}" - fwop = utils.DummyInputOuput(tensor, 0, is_input=True, name=f'segment{segment.cid}_output') - for devid in devices: - fop = fwop.replicate() - fop.device = devid - if tensor.requires_grad: - fop.input(0).grad = tensor.parent.grad.select(tensor.indmap, (0, 1)) - fop.input(0).grad.cell = fop - output_consumers.setdefault(tensor.parent, []).append(fop) + fwop = utils.DummyInputOuput(tensor, devices, is_input=True, name=f'segment{segment.cid}_output') + if tensor.grad is not None: + fwop.input(0).grad = tensor.parent.grad.select(tensor.indmap, (0, 1)) + fwop.input(0).grad.cell = fwop + output_consumers.setdefault(tensor.parent, []).append(fwop) return input_producers, output_consumers @@ -78,22 +72,24 @@ def expand_devices(tensors: List[Optional[IRSubTensor]], @return dtensors List[IRSubTensor]: each tensor is on one device """ - dtensors = [] + dtensors: Dict[int, List[IRSubTensor]] = {} for tensor in tensors: if tensor is None: continue - if len(tensor.device) == 1: - dtensors.append(tensor) - continue for devid in tensor.device: + if tensor in dtensors.setdefault(devid, []): + continue if producer: fwop = utils.DummyInputOuput(tensor, devid, is_output=True, name=tensor.cell.name) - dtensors.append(fwop.output(0)) + dtensors[devid].append(fwop.output(0)) elif consumer: fwop = utils.DummyInputOuput(tensor, devid, is_input=True, name=tensor.cell.name) - dtensors.append(fwop.input(0)) + dtensors[devid].append(fwop.input(0)) else: raise ValueError("At least one of producer or consumer") - return dtensors + all_tensors = [] + for device_tensors in dtensors.values(): + all_tensors += device_tensors + return all_tensors class IRAdapterGener: @@ -182,6 +178,27 @@ def gen_weight(graph: IRGraph) -> IRGraph: 1. same number of nodes per cid group 2. same device set or no-overlapping device set per cid group """ + def check_consistent_local_partition(graph: IRSegment): + """each weight full tensor inside one device should in same format.""" + for ftensor in graph.full_tensors(): + if not ftensor.is_attr(): continue + device_tensors: Dict[int, Set[IRSubTensor]] = {} + for ctensor in graph.ctensors(ftensor): + for devid in ctensor.device: + local_tensors = device_tensors.setdefault(devid, set()) + for t in local_tensors: + assert t == ctensor or not t.overlap(ctensor), ( + f"Detected graph attribute is partitioned with shared part on device {devid}.\n" + f"To achieve this, need call graph.multiref at the front of sProgram.\n" + f"{graph.debug_tensor_map_str(ftensor)}" + ) + local_tensors.add(ctensor) + for segment in graph.select(ntype=IRSegment, flatten=False): + if segment.isfw(): + check_consistent_local_partition(segment) + + check_consistent_local_partition(graph) + # collect subtensor and consumer fweights: Dict[IRFullTensor, List[IRSubTensor]] = dict() fgrads: Dict[IRFullTensor, List[IRSubTensor]] = dict() @@ -212,8 +229,6 @@ def gen_weight(graph: IRGraph) -> IRGraph: if grad not in weight_grads[weight]: weight_grads[weight][grad] = [] weight_grads[weight][grad].append(consumer) - - # TODO: check sub_weight is no-overlapping # assert all(sw.valmap[1] == len(weight_grads) for sw in weight_grads.keys()) for sub_weight in weight_grads: @@ -315,7 +330,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # consumers can be operators and graph outputs fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) if ftensor in output_consumer: - fctensors = fctensors + tuple(fwop.input(0) for fwop in output_consumer[ftensor] if fwop.input(0) not in fctensors) + fctensors = fctensors + tuple(fwop.input(0) for fwop in output_consumer[ftensor]) fctensors = expand_devices(fctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" @@ -585,6 +600,7 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): # collect consumer of each device for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): + if consumer.mirror is None: continue for devid in ctensor.device: devtensors.setdefault(devid, []).append(ctensor) devops.setdefault(devid, []).append(consumer) @@ -658,27 +674,51 @@ def autoref(graph: IRSegment) -> IRGraph: @return None """ for multiref in graph.select(name='multiref', flatten=False): + # setup recompute + idx = graph.index(multiref).indices[0] + recompute = None + neighbor = graph.node(idx-1) if idx > 0 else graph.node(idx+1) + if isinstance(neighbor, IRFwOperation): + recompute = neighbor.recompute + ftensor: IRFullTensor = multiref.input(0).parent multirefs = [] - tensors = graph.ptensors(ftensor) if len(graph.ptensors(ftensor)) > 0 else graph.ctensors(ftensor) - for otensor in tensors: - mr = MultiRef(otensor, len(multiref.outputs())) - for idx in range(len(multiref.outputs())): - output = multiref.output(idx).parent.select(otensor.indmap, otensor.valmap) - if otensor.requires_grad: - output.grad = multiref.output(idx).grad.parent.select(otensor.indmap, (0,1)) - mr.set_output(idx, output) - mr.device = otensor.device - if isinstance(otensor.cell, IRFwOperation): - mr.recompute = otensor.cell.recompute - multirefs.append(mr) + # by default follow producer transformation strategy + ptensors = graph.ptensors(ftensor) + if len(ptensors) > 0: + for tensor in ptensors: + mr = MultiRef(tensor, len(multiref.outputs())) + mr.input(0).grad = tensor.grad + for idx, out in enumerate(multiref.outputs()): + output = out.parent.select(tensor.indmap, tensor.valmap) + if out.grad is not None: + output.grad = out.grad.parent.select(tensor.indmap, (0,1)) + mr.set_output(idx, output) + mr.device = tensor.device + mr.recompute = recompute + multirefs.append(mr) + # otherwise replicate: usually for weight / graph inputs + else: + devices = set() + for otensor in multiref.outputs(): + ftensor = otensor.parent + for consumer in graph.consumers(ftensor): + devices.update(consumer.device) + devices = sorted(devices) + for devid in devices: + mr = multiref.replicate() + mr.device = devid + mr.recompute = recompute + multirefs.append(mr) + assert len(multirefs) > 0 # remove original multiref fidx = graph.remove(multiref) if multiref.mirror is not None: graph.mirror.remove(multiref.mirror) # insert multirefs + req_bw = multiref.mirror is not None for ofst, multiref in enumerate(multirefs): - if ftensor.requires_grad: + if req_bw: graph.finsert(multiref, fidx + ofst) else: graph.insert(multiref, fidx + ofst) diff --git a/cube/graph/gener/utils.py b/cube/graph/gener/utils.py index 35325c45..0e8f8e5e 100644 --- a/cube/graph/gener/utils.py +++ b/cube/graph/gener/utils.py @@ -1,7 +1,7 @@ """ Utilities for gradient modification """ -from typing import Dict, List +from typing import Dict, List, Union, Tuple from cube.graph import IRGraph from cube.graph.segment import IRSegment from cube.ir.operator import IRFwOperation @@ -10,7 +10,7 @@ class DummyInputOuput(IRFwOperation): - def __init__(self, tensor: IRSubTensor, device: int, + def __init__(self, tensor: IRSubTensor, device: Union[int, Tuple[int]], is_input=False, is_output=False, name='dummy'): assert (is_input and not is_output) or (is_output and not is_input) @@ -71,7 +71,7 @@ def convert_add_to_valmap(graph: IRGraph, add_node: IRFwOperation): def flatten_grad(graph: IRSegment, ftensor: IRFullTensor): """ Reset gradient for consumers that are different (no replica) - Gradient valuemap will be flatten iter-devices, e.g.,(0,3), (1,3), (2,3) + Gradient valuemap will be flatten inter-devices, e.g.,(0,3), (1,3), (2,3) Gradient valuemap will be exponent intra-devices, e.g., (0,2), (2,4), (3,4) @param graph IRGraph: the graph @@ -93,6 +93,7 @@ def flatten_grad(graph: IRSegment, ftensor: IRFullTensor): if len(ctensor.device) > 1: return devtensors[ctensor][ctensor.device[0]] = [] for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): + if consumer.mirror is None: continue devid = ctensor.device[0] devtensors[ctensor][devid].append(consumer) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index a8707742..f840a1b7 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -10,19 +10,18 @@ from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any import warnings import copy -import pickle import dill +import sys from cube.ir.cten import IRTensor, IRCell, IRObject from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap from cube.ir.dtype import IRDType, DTypeInferRule -from cube.graph.function.function import Identity, MultiRef +from cube.graph.function.function import Identity from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.function.dimops import IRDimops from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo @@ -135,6 +134,15 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: def backward(self, loss: IRSubTensor): """ Backward the graph from the entry tensor of loss. + + This will infer tensors' gradients by following rules: + + Conditions must satisfy for an forward op having its backward op: + * one of its output tensors requires gradient + * one of its output tensors is consumed by other forward ops + + For operators that doesn't need backward, all gradients of their + input/output tensors will make to None (despite require_grad is True) @param loss IRSubTensor: the loss tensor, must be in the output of current graph. The loss shape should be (1,) @@ -146,17 +154,27 @@ def backward(self, loss: IRSubTensor): # set loss gradient loss.parent.to_loss() + # update require gradient: for tensors that have no consumers, + # make their gradient to be False + for ftensor in self.full_tensors(): + if ftensor.is_loss(): continue + consumers = [n for n in self.consumers(ftensor) if isinstance(n, IRFwOperation)] + if len(consumers) == 0 and ftensor.requires_grad: + print(f"warning: detected a dead ftensor which is not consumed by any nodes:\n\t{ftensor.name}: {ftensor}", file=sys.stderr) + ftensor.requires_grad = False + # infer gradient for ftensor in self.full_tensors(): self.infer_grad(ftensor) + # create backward node for fnode in self.nodes()[::-1]: assert not isinstance(fnode, IRSegment), "Internal Error: Segment should not appear for now" if not isinstance(fnode, IRFwOperation): continue - tensors = [t for t in fnode.inputs() + fnode.outputs() if isinstance(t, IRSubTensor)] - grads = [t.grad for t in tensors] + outputs = [t for t in fnode.outputs() if isinstance(t, IRSubTensor)] # no backward op generated for fnode - if all(grad is None for grad in grads): continue + if all(t.grad is None for t in outputs): + continue # create backward op and insert to graph bwop = self.create_bwop(fnode) self.insert(bwop, self.nnodes) @@ -166,63 +184,61 @@ def backward(self, loss: IRSubTensor): # ========================= Graph Manipulation ======================== - def group(self, fnodes: List[IRCell]) -> IRSegment: + def group(self, nodes: List[IRCell]) -> IRSegment: """! - Group consecutive forward nodes into IRSegment. - Note the fnodes should not apply any transformation. - TODO: update operator dependency - - The corresponding backward nodes will also be grouped. + Group consecutive nodes into IRSegment. + Note nodes should not have applied by any transformation. - @param nodes List[IRCell]: the consecutive node subset of this graph + @param nodes List[IRCell]: consecutive nodes in forward procedure @return segment IRSegment: the grouped segment """ - assert any(not isinstance(node, (IRBpOperation, IRDataOperation)) for node in fnodes), \ - "grouped nodes cannot be backward operation, segment or data operation" - - fgraphs = [self.segment(fnode) for fnode in fnodes] - assert len(set(fgraphs)) == 1, "Cross-segment grouping is not allowed yet." - - # get backward nodes - bnodes = [fnode.mirror for fnode in fnodes[::-1] if fnode.mirror is not None] - - fgraph: IRSegment = fgraphs[0] - bgraph: IRSegment = fgraph.mirror + assert all(node.isfw() for node in nodes), f"Expected all nodes in forward procedure" + fgraphs = [self.segment(fnode) for fnode in nodes] + assert len(set(fgraphs)) == 1, "cross-segment grouping is not allowed yet." - findices: Tuple[int] = tuple(fgraph.index(fnode)[0] for fnode in fnodes) - bindices: Tuple[int] = tuple(bgraph.index(bnode)[0] for bnode in bnodes) + fgraph: IRSegment = fgraphs[0] + findices: Tuple[int] = tuple(fgraph.index(node)[0] for node in nodes) + min_fidx, max_fidx = min(findices), max(findices) + assert max_fidx - min_fidx + 1 == len(nodes), "nodes should be in consecutive order" - minfidx, maxfidx = min(findices), max(findices) - assert maxfidx - minfidx + 1 == len(fnodes), \ - "Forward nodes are not consecutive" + fsegment: IRSegment = fgraph.create_segment(nodes) + for node in nodes: + idx = fgraph.remove(node) + fgraph.insert(fsegment, idx) - if len(bnodes) > 0: - minbidx, maxbidx = min(bindices), max(bindices) - assert maxbidx - minbidx + 1 == len(bnodes), \ - f"Internal Error: backward nodes are not consecutive. maxbidx: {maxbidx}, minbidx: {minbidx}" + # group for mirror nodes + bnodes = [node.mirror for node in nodes if node.mirror is not None] + if len(bnodes) == 0: return fsegment - # remove fnodes and insert fsegment - fsegment: IRSegment = fgraph.create_segment(fnodes) - for fnode in fnodes: - fidx = fgraph.remove(fnode) - fgraph.insert(fsegment, fidx) + # check consecutive + bgraph: IRSegment = fgraph.mirror + bindices = [bgraph.index(bnode)[0] for bnode in bnodes] + min_bidx, max_bidx = min(bindices), max(bindices) + assert max_bidx - min_bidx + 1 == len(bnodes), \ + f"backward nodes are not consecutive. minbidx: {min_bidx}, maxbidx: {max_bidx}" - # reset fsegment gradient + # update gradient for fgraph for itensor in fsegment.inputs(): - if isinstance(itensor, IRTensor): - fgraph.infer_grad(itensor.parent) - - # update backward - if len(bnodes) > 0: - # remove backward nodes - for bnode in bnodes: - bidx = bgraph.remove(bnode) - # create new backward node - bnodes = [fsegment.create_bwop(fnode) for fnode in fnodes[::-1]] - # create and insert backward segment - bsegment = fgraph.create_bwop(fsegment) - bgraph.insert(bsegment, bidx) + fgraph.infer_grad(itensor.parent) + # update gradient inside segment + for ftensor in fsegment.full_tensors(): + fsegment.infer_grad(ftensor) + + # create backward segment + for bnode in bnodes: + bidx = bgraph.remove(bnode) + bnodes = [fsegment.create_bwop(fnode) for fnode in nodes[::-1] if fnode.mirror is not None] + # get backward graph inputs + output_grads = [t.grad for t in fsegment.outputs() if isinstance(t, IRSubTensor) and t.grad is not None] + # get backward graph outputs + input_grads = [t.grad for t in fsegment.inputs() if \ + isinstance(t, IRSubTensor) and t.grad is not None] + bsegment = IRSegment(bnodes, output_grads, input_grads) + + bgraph.insert(bsegment, bidx) + IRCell.make_pair(fsegment, bsegment) + return fsegment # ========================== Graph Creation ======================== @@ -357,61 +373,61 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], fnodes = algo.instantiate(**config) assert fnodes is not None, f"Fail to partition node: {node} use algorithm and config: {config}" - # set gradient - valmaps: Dict[IRFullTensor, ValueMap] = dict() - for t in node.inputs(): - if isinstance(t, IRSubTensor) and t.requires_grad: - valmaps[t.parent] = ValueMap(t.grad.valmap) - # set up consumers + # insert forward node + fsegment: IRSegment = self.segment(node) + for fnode in fnodes: + if isinstance(node, IRFwOperation): + fnode.recompute = node.recompute + if isinstance(node.comment, str): + fnode.comment = node.comment + fnode.device = node.device + fsegment.replace(node, fnodes) + + if node.mirror is None: return fnodes + + valmaps: Dict[IRFullTensor, Optional[ValueMap]] = dict() + for t in node.inputs() + node.outputs(): + if isinstance(t, IRSubTensor): + valmaps[t.parent] = None if t.grad is None else ValueMap(t.grad.valmap) + + # gather consumers ctensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() consumers: Dict[IRFullTensor, List[IRCell]] = dict() for fnode in fnodes: - for itensor in fnode.inputs(): + for itensor in set(fnode.inputs()): if not isinstance(itensor, IRSubTensor): continue - if not itensor.requires_grad: continue - if itensor.parent not in ctensors: - ctensors[itensor.parent] = [] - consumers[itensor.parent] = [] - ctensors[itensor.parent].append(itensor) - consumers[itensor.parent].append(fnode) + ctensors.setdefault(itensor.parent, []).append(itensor) + consumers.setdefault(itensor.parent, []).append(fnode) # set up gradient for fnode in fnodes: for itensor in fnode.inputs(): if not isinstance(itensor, IRSubTensor): continue - if not itensor.requires_grad: continue ftensor = itensor.parent - # the [::-1] only makes the valuemap to grow with execution order - cs = [c for c, t in zip(consumers[ftensor], ctensors[ftensor]) if t == itensor][::-1] - valmap = valmaps[itensor.parent].map((cs.index(fnode), len(cs))) + itensor.grad = None + if valmaps[ftensor] is None: continue + # collect consumers that consume the same sub_tensor + consumers_of_same_tensor = [] + for idx, t in enumerate(ctensors[ftensor]): + if t == itensor: + consumers_of_same_tensor.append(consumers[ftensor][idx]) + consumers_of_same_tensor = consumers_of_same_tensor[::-1] # make valmap grow with exec order + # calculate value map + valmap = valmaps[ftensor].map( + (consumers_of_same_tensor.index(fnode), len(consumers_of_same_tensor)) + ) grad = ftensor.grad.select(itensor.indmap, valmap) itensor.grad = grad for otensor in fnode.outputs(): if not isinstance(otensor, IRSubTensor): continue - if not otensor.requires_grad: - grad = None - else: - grad = otensor.parent.grad.select(otensor.indmap, (0,1)) - otensor.grad = grad - - # insert forward node - fsegment: IRSegment = self.segment(node) - for fnode in fnodes: - if isinstance(node, IRFwOperation): - fnode.recompute = node.recompute - if isinstance(node.comment, str): - fnode.comment = node.comment - fnode.device = node.device - fsegment.replace(node, fnodes) + otensor.grad = None if valmaps[otensor.parent] is None else \ + otensor.parent.grad.select(otensor.indmap, (0,1)) # insert backward node - if isinstance(node.mirror, IRCell): - bnodes = [fsegment.create_bwop(fnode) for fnode in fnodes[::-1]] - assert isinstance(node.mirror, IRBpOperation) - assert len(bnodes) == len(fnodes) - for bnode in bnodes: - bnode.device = node.device - bsegment: IRSegment = fsegment.mirror - bsegment.replace(node.mirror, bnodes) + bnodes = [fsegment.create_bwop(fnode) for fnode in fnodes[::-1]] + for bnode in bnodes: + bnode.device = node.device + bsegment: IRSegment = fsegment.mirror + bsegment.replace(node.mirror, bnodes) return fnodes @@ -861,70 +877,6 @@ def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: # =================== Helpers ==================== - def auto_multiref(self): - """ - Automatically partition and schedule multiref node. - This requires to call after all transformation and - scheduling. - - The policy is to partition and assign multiref - in the same way of its input producer - """ - for node in self.nodes(flatten=True): - if node.name == 'multiref': - if len(node.device) != 0: continue - segment: IRSegment = self.segment(node) - ftensor = node.input(0).parent - ptensors = segment.ptensors(ftensor) - - multirefs = [] - - # use downstream consumers - devtensors: Dict[int, List[IRSubTensor]] = dict() - for tensor in node.outputs(): - for ctensor in segment.ctensors(tensor.parent): - for devid in ctensor.device: - if devid not in devtensors: - devtensors[devid] = [] - devtensors[devid].append(ctensor) - devids = list(devtensors.keys()) - ctensors = [ts[0] for ts in devtensors.values()] - for devid, ctensor in zip(devids, ctensors): - itensor = node.input(0).parent.select(ctensor.indmap, ctensor.valmap) - otensors = [] - for otensor in node.outputs(): - otensors.append(otensor.parent.select(ctensor.indmap, ctensor.valmap)) - multiref = MultiRef(itensor, len(otensors)) - for idx, otensor in enumerate(otensors): - multiref.set_output(idx, otensor) - multiref.device = devid - multirefs.append(multiref) - - # if no downstream consumers, use upstream producers - if len(multirefs) == 0: - for ptensor in ptensors: - assert len(ptensor.device) > 0, \ - "Auto Multiref requires its producer nodes assigned to devices" - for devid in ptensor.device: - outputs = [] - for output in node.outputs(): - outputs.append(output.parent.select(ptensor.indmap, ptensor.valmap)) - multiref = MultiRef(ptensor, len(outputs)) - for idx, otensor in enumerate(outputs): - multiref.set_output(idx, otensor) - multiref.device = devid - multirefs.append(multiref) - - # replace into graph - fidx = self.remove(node) - if node.mirror is not None: - self.remove(node.mirror) - for multiref in multirefs[::-1]: - if node.mirror is not None: - self.finsert(multiref, fidx) - else: - self.insert(multiref, fidx) - def dump(self, filename: str) -> None: """ Dump the graph into pickled format diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 88461b46..53c9bc3e 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -76,7 +76,7 @@ class IRSegment(IRCell): Inserting and removing nodes that could change input/output are not allowed. """ - def __init__(self, nodes: List[IRCell], inputs: List[Any], outputs: List[Any], name='segment'): + def __init__(self, nodes: List[IRCell], inputs: List[IRObject], outputs: List[Any], name='segment'): super().__init__(name, '', len(inputs), len(outputs), init_outputs=False) self._nodes: List[IRCell] = [] @@ -262,8 +262,7 @@ def producers(self, ftensor: IRFullTensor) -> Tuple[IRCell]: @return subtensors Tuple[IRSubTensor]: the producers. """ - assert ftensor in self._producers, f"{ftensor} is not in the graph" - return tuple(self._producers[ftensor]) + return tuple(self._producers.get(ftensor, ())) def consumers(self, ftensor: IRFullTensor) -> Tuple[IRCell]: """ @@ -273,8 +272,7 @@ def consumers(self, ftensor: IRFullTensor) -> Tuple[IRCell]: @return subtensors Tuple[IRCell]: theconsumers """ - assert ftensor in self._consumers, f"{ftensor} is not in the graph" - return tuple(self._consumers[ftensor]) + return tuple(self._consumers.get(ftensor, ())) def ptensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: """ @@ -284,8 +282,7 @@ def ptensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: @return subtensors Tuple[IRSubTensor]: the consumed subtensors. """ - assert ftensor in self._ptensors, f"{ftensor} is not in the graph" - return tuple(self._ptensors[ftensor]) + return tuple(self._ptensors.get(ftensor, ())) def ctensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: """ @@ -295,8 +292,7 @@ def ctensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: @return subtensors Tuple[IRSubTensor]: the consumed subtensors. """ - assert ftensor in self._ctensors, f"{ftensor} is not in the graph" - return tuple(self._ctensors[ftensor]) + return tuple(self._ctensors.get(ftensor, ())) def infer_grad(self, ftensor: IRFullTensor) -> None: """ @@ -305,51 +301,64 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: Note this can only be called when no operator transformation is applied for this graph. + If a tensor is consumed by multiple consumers, the value map of its gradient + will be in exponential format. + + E.g., t has consumed by node1, node2, node3 and node4. + Then the gradient value_map of t (t.grad) of each consumer is (idx, nchunks): + (0, 2), (2, 4), (6, 8), (7, 8), + where: + (0, 2) + (2, 4) + (6, 8) + (7, 8) + = (0, 2) + (2, 4) + (3, 4) + = (0, 2) + (1, 2) + = FULL VALUE + @param ftensor IRFullTensor: the full tensor. @return None: gradient are set to producer/consumer tensor's .grad """ - fgrad = ftensor.grad - # set for producer + # check condition: no transformation assert len(self.producers(ftensor)) <= 1, ( f"grad can only be set when no transformation is applied but got:\n" f"{self.debug_tensor_map_str(ftensor)}" ) + assert len(set(self.ctensors(ftensor))) <= 1, ( + f"grad can only be set when no transformation is applied but got:\n" + f"{self.debug_tensor_map_str(ftensor)}" + ) + + fgrad = ftensor.grad + # set for producer for ptensor, producer in zip(self.ptensors(ftensor), self.producers(ftensor)): # filter out non-autograd operators of IRPyFunc if isinstance(producer, IRPyFunc): continue idx = producer.outputs().index(ptensor) - if fgrad is None: - grad = None - else: - grad = fgrad.select(ptensor.indmap, (0, 1)) + grad = None if fgrad is None else fgrad.select(ptensor.indmap, (0, 1)) producer.output(idx).grad = grad + # set for consumers - ctensors = self.ctensors(ftensor) - if len(ctensors) > 0: - assert all(ctensor == ctensors[0] for ctensor in ctensors), ( - f"grad can only be set when no transformation is applied but got:\n" - f"{self.debug_tensor_map_str(ftensor)}" - ) - curr_valmap = ValueMap((0, 1)) - - # filter out non-autograd operators of IRPyFunc - consumers, ctensors = [], [] + consumers, ctensors = [], [] # consumers that require gradient for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): + # set by default None + for t in consumer.inputs(): # consider an op can have multiple same-tensor inputs + if isinstance(t, IRSubTensor) and t == ctensor: + t.grad = None + # filter out non-autograd operators + if fgrad is None: continue if isinstance(consumer, IRPyFunc): continue - consumers.append(consumer) - ctensors.append(ctensor) - + if any(isinstance(t, IRSubTensor) and t.requires_grad for t in consumer.outputs()): + consumers.append(consumer) + ctensors.append(ctensor) + # set with value map + curr_valmap = ValueMap((0, 1)) nconsumers = len(consumers) for cidx, (ctensor, consumer) in enumerate(zip(ctensors, consumers)): - idx = consumer.inputs().index(ctensor) - if fgrad is None: - grad = None - else: - valmap = curr_valmap.map((0, 2)) if cidx != nconsumers - 1 else curr_valmap - grad = fgrad.select(ctensor.indmap, valmap) - curr_valmap = curr_valmap.map((1, 2)) if cidx != nconsumers - 1 else curr_valmap - consumer.input(idx).grad = grad + valmap = curr_valmap.map((0, 2)) if cidx != nconsumers - 1 else curr_valmap + grad = fgrad.select(ctensor.indmap, valmap) + curr_valmap = curr_valmap.map((1, 2)) if cidx != nconsumers - 1 else curr_valmap + for t in consumer.inputs(): + if isinstance(t, IRSubTensor) and t == ctensor: + t.grad = grad def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: dscp : str = '' @@ -364,25 +373,26 @@ def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: dscp += f'\t{consumer}\n' return dscp - def create_bwop(self, fwop: IRFwOperation) -> Union[IRBpOperation, IRBpOperation]: + def create_bwop(self, fwop: IRFwOperation) -> IRBpOperation: """ Create dummy backward operator for given forward operator. This assumes input/output tensors of fwop have been set by correct gradient tensors. + This can only be called before any transformation / grouping + @param fwop IRFwOperation: forward operation @return bwop IRBpOperation: the created backward operation """ - assert isinstance(fwop, (IRFwOperation, IRSegment)), "Expected IRFwOperation" + assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] - igrads = [t.grad if t.requires_grad else None for t in fins] - ograds = [t.grad if t.requires_grad else None for t in fous] - if isinstance(fwop, IRFwOperation): - bwop = IRBpOperation(ograds, igrads) - else: - bnodes = [fnode.mirror for fnode in fwop.nodes() if fnode.mirror is not None][::-1] - bwop = IRSegment(bnodes, ograds, igrads) + igrads = [t.grad for t in fins if t.grad is not None] + # note not all output tensors will be consumed by nodes, e.g., chunk. + # for these cases, the backward op doesn't have exactly the same number of + # backward inputs with the number of its forward outputs + ograds = [t.grad for t in fous if t.grad is not None] + bwop = IRBpOperation(ograds, igrads) IRCell.make_pair(fwop, bwop) return bwop @@ -664,7 +674,7 @@ def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwO # ===================== Advance Graph manipulations ================== - def multiref(self, ftensor: IRFullTensor, node_groups: List[List[IRFwOperation]]) -> IRFwOperation: + def multiref(self, ftensor: IRFullTensor, *deprecated_args) -> IRFwOperation: """ Add multiref to separate forward nodes that consume a same tensor into different tensor alias. This should be called before any graph transformation. @@ -674,56 +684,57 @@ def multiref(self, ftensor: IRFullTensor, node_groups: List[List[IRFwOperation]] tensor adapters. @param tensor IRSubTensor: tensor. - @param node_groups List[List[IRFwOperation]]: - operators that take the tensor as input. - @return multiref IRFwOperation: the inserted multiref operator. """ assert ftensor in self._ftensors, f"tensor: {ftensor} not in this graph." - # check no transformation if len(self.consumers(ftensor)) <= 1: return assert not ftensor.is_grad(), f"graph.multiref can only be applied on a non-gradient full tensor." - assert len(set(self.ctensors(ftensor))) == 1, \ - f"Detected happened graph transformation. This interfacee should be called before graph transformation." - # check completeness - consumers = set() - for nodes in node_groups: - consumers.update(nodes) - assert consumers == set(self.consumers(ftensor)), f"some consumer(s) are not in node_groups" + # check no transformation + assert len(self.ptensors(ftensor)) <= 1, f"no transformation should be called before multiref" + assert len(set(self.ctensors(ftensor))) == 1, f"no transformation should be called before multiref" + # create new full tensors + consumers = self.consumers(ftensor) tensor = self.ctensors(ftensor)[0] - ftensors: List[IRSubTensor] = [ftensor.like() for _ in node_groups] + ftensors: List[IRSubTensor] = [ftensor.like() for _ in consumers] otensors: List[IRSubTensor] = [ft.select(tensor.indmap, tensor.valmap) for ft in ftensors] # create multiref - multiref = MultiRef(tensor, len(node_groups)) + multiref = MultiRef(tensor, len(consumers)) for idx, otensor in enumerate(otensors): multiref.set_output(idx, otensor) # setup gradient - if tensor.requires_grad: - multiref.input(0).grad = tensor.parent.grad.select(tensor.indmap, (0, 1)) - for idx, output in enumerate(multiref.outputs()): + req_grad = ftensor.requires_grad + multiref.input(0).grad = ftensor.grad.select(tensor.indmap, (0, 1)) if req_grad else None + for idx, output in enumerate(multiref.outputs()): + if ftensor.grad is None or consumers[idx].mirror is None: + output.grad = None + else: output.grad = ftensors[idx].grad.select(tensor.indmap, (0,1)) # insert multiref if len(self.producers(ftensor)) == 0: fidx = min(self.index(consumer) for consumer in self.consumers(ftensor)) else: fidx = max(self.index(prod) for prod in self.producers(ftensor)) + 1 - if ftensor.requires_grad: + if req_grad: self.finsert(multiref, fidx) else: self.insert(multiref, fidx) # update forward / backward consumer - for otensor, nodes in zip(otensors, node_groups): - for idx, node in enumerate(nodes): - fidx = node.inputs().index(tensor) - grad = node.input(fidx).grad - with self.update(node): - node.set_input(fidx, otensor) - if tensor.requires_grad: - node.input(fidx).grad = otensor.parent.grad.select(otensor.indmap, (idx, len(nodes))) - with self.mirror.update(node.mirror) as bnode: - bidx = bnode.outputs().index(grad) - bnode.set_output(bidx, node.input(bidx).grad) + for idx, consumer in enumerate(consumers): + fidx = consumer.inputs().index(tensor) + grad = consumer.input(fidx).grad + # update forward + with self.update(consumer): + for fidx, t in enumerate(consumer.inputs()): + if tensor == t: + consumer.set_input(fidx, multiref.output(idx)) + consumer.input(fidx).grad = multiref.output(idx).grad + if consumer.mirror is None: continue + # update backward + with self.mirror.update(consumer.mirror) as bnode: + for bidx, t in enumerate(bnode.outputs()): + if grad is not None and grad == t: + bnode.set_output(bidx, multiref.output(idx).grad) return multiref def single_consume(self, one_for_all: bool = True): @@ -994,10 +1005,11 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: if any(t.overlap(otensor) for t in segment_outputs if isinstance(t, IRObject)): outputs.add(otensor) continue - # loss doesn't have consumers + # loss must be returned + if isinstance(ftensor, IRFullTensor) and ftensor.is_loss(): + outputs.add(otensor) + continue if len(segment.consumers(ftensor)) == 0: - if isinstance(ftensor, IRFullTensor) and ftensor.is_loss(): - outputs.add(otensor) continue # for outside consumers consumers, ctensors = segment.consumers(ftensor), segment.ctensors(ftensor) diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 1fa768f6..99c90eb8 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -296,14 +296,13 @@ def grad(self) -> Optional[IRTensor]: @grad.setter def grad(self, val: Optional[IRTensor]): """ - int indicates the tensor is the loss tensor. + Setup gradient for the tensor. """ - if self._requires_grad: - assert isinstance(val, IRFullTensor) + assert val is None or isinstance(val, IRFullTensor) + if val is not None: + assert self._requires_grad, f"Cannot assign {val} to no grad-required tensor" assert val.shape == self.shape assert val.is_attr() == self.is_attr() - else: - assert val is None, "The FullTensor doesn't require grad but is assigned with a grad." self._grad = val def is_loss(self) -> bool: @@ -631,14 +630,15 @@ def grad(self) -> bool: @grad.setter def grad(self, val: Optional[IRTensor]): - if isinstance(val, IRSubTensor): - assert self.requires_grad and val.shape == self.shape, f'info: {self.requires_grad} {val.shape == self.shape}' - self._grad = val - elif val is None: - assert not self.requires_grad - self._grad = None - else: - raise ValueError(f"Expected grad to be None or IRSubTensor but got: {val}") + """ + Setup gradient for the tensor. + """ + assert val is None or isinstance(val, IRSubTensor) + if val is not None: + assert self._requires_grad, f"Cannot assign {val} to no grad-required tensor" + assert val.shape == self.shape + assert val.is_attr() == self.is_attr() + self._grad = val def is_loss(self) -> bool: """ diff --git a/tests/graph/test_infer_grad.py b/tests/graph/test_infer_grad.py new file mode 100644 index 00000000..f185dd36 --- /dev/null +++ b/tests/graph/test_infer_grad.py @@ -0,0 +1,134 @@ +""" +USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/graph/test_infer_grad.py +USE_TORCHFX=1 torchrun --nproc_per_node=2 tests/graph/test_infer_grad.py +""" +from typing import List +import torch + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation, IRDataOperation + +cube.init() + + +def _param(size, dtype=torch.float32): + return torch.nn.Parameter(torch.empty(size, dtype=dtype)) + +def _rand(size, dtype=torch.float32): + return torch.rand(size, dtype=dtype, device=torch.cuda.current_device()) + + +class TestOpModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.param1 = _param([256, 512]) + self.param2 = _param([256, 512]) + self.param3 = _param([256, 512]) + + def forward(self, x: torch.Tensor): + x1 = x * self.param1 + x2 = x1 * self.param2 # no grad + + cube.runtime.function.anchor('residual') + x3 = x1 + 2 + x4 = x3 * self.param3 + + loss = torch.sum(x4) + return {'intermediate': [x3, x2], 'loss': loss}, loss.data + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self) -> None: + self.sample = _rand([256, 512]) + super().__init__(256, (0,)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +def policy_test_single_device(graph: IRGraph, resource): + print(graph.extra_repr()) + for idx, node in enumerate(graph.select(name='mul')): + if idx == 1: + assert node.mirror is None + for t in node.inputs() + node.outputs(): + assert t.grad is None + elif idx == 2: + assert node.mirror is not None + for t in node.inputs() + node.outputs(): + assert t.grad is not None + for node in graph.select(ntype=(IRDataOperation, IRFwOperation)): + graph.assign(node, 0) + return graph + + +def policy_test_multi_device(graph: IRGraph, resource): + # multiref + for ftensor in graph.full_tensors(): + if ftensor.is_attr(): continue + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) + + print(graph.extra_repr()) + assert resource.ngpus == 2 + for idx, node in enumerate(graph.select(ntype=(IRFwOperation, IRDataOperation))): + devid = 0 if idx < 4 else 1 + graph.assign(node, devid) + print(graph.extra_repr()) + return graph + + +def test_single_no_backward_ops(): + + model = TestOpModule() + dataloader = TestDataLoader() + + @cube.compile(model, dataloader, PAS=policy_test_single_device, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) + def train_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out[0]['loss'].backward() + return out + + model = cube.load_model(load_content=False) + + for idx in range(3): + train_iter(model, dataloader) + print(f"single device: iter {idx}/3") + + +def test_multidev_residual(): + + model = TestOpModule() + dataloader = TestDataLoader() + + @cube.compile(model, dataloader, PAS=policy_test_multi_device, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) + def train_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out[0]['loss'].backward() + return out + + model = cube.load_model(load_content=False) + + for idx in range(3): + train_iter(model, dataloader) + print(f"multi device: iter {idx}/3") + + +if __name__ == '__main__': + if torch.distributed.get_world_size() == 1: + test_single_no_backward_ops() + if torch.distributed.get_world_size() == 2: + test_multidev_residual() diff --git a/tests/graph/test_multiref.py b/tests/graph/test_multiref.py new file mode 100644 index 00000000..76e384a1 --- /dev/null +++ b/tests/graph/test_multiref.py @@ -0,0 +1,115 @@ +""" +torchrun --nproc_per_node=2 tests/graph/test_multiref.py +""" +import torch + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops + +cube.init() + + +def _param(shape, dtype=torch.float32): + return torch.nn.Parameter(torch.empty(shape, dtype=dtype)) + + +class TestOpModule(torch.nn.Module): + + def __init__(self, shape=[256, 512]): + super().__init__() + self.param = _param(shape) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + x = x * self.param + x = torch.sum(x) + + y = y * self.param + y = torch.sum(y) + + loss = x + y + return loss + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int = 256) -> None: + self.sample = ( + torch.rand([batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()), + torch.rand([batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()), + ) + super().__init__(batch_size, (0, 0)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +def _tp(graph, node, devs, idx, dim): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + for node, devid in zip(sub_nodes, devs): + graph.assign(node, devid) + return sub_nodes + + +def _replica(graph, node, devs): + rnodes = graph.replicate(node, times=len(devs)) + for rnode, devid in zip(rnodes, devs): + graph.assign(rnode, devid) + return rnodes + + +def test_multiref_param(): + + cube.init() + + model = TestOpModule() + dataloader = TestDataLoader() + + def policy(graph: IRGraph, resource): + + # multiref + for t in graph.full_tensors(): + if len(graph.consumers(t)) > 1: + graph.multiref(t) + + devs = list(range(resource.ngpus)) + + muls = graph.select(name='mul') + _tp(graph, muls[0], devs, idx=1, dim=0) + _tp(graph, muls[1], devs, idx=1, dim=1) + + for node in graph.select(ntype=(IRDataOperation, IRFwOperation)): + if node.name == 'multiref': continue + if node.name == 'mul': continue + _replica(graph, node, devs) + + return graph + + sample_x, sample_y = next(dataloader) + + @cube.compile(model, dataloader, PAS=policy, load_content=True, + model_dummy_inputs={'x': sample_x, 'y': sample_y}) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(*data) + loss.backward() + + model = cube.load_model() + + for idx in range(3): + train_iter(model, dataloader) + print(f"iter {idx}/3") + print('Done') + + +if __name__ == '__main__': + test_multiref_param() + exit(0) From 0697546a5845aa27a384d84f4714ce6c23fe50e3 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Fri, 7 Apr 2023 08:41:46 +0000 Subject: [PATCH 1382/1892] Merged PR 1526: create feature: save/load distributed checkpoint create feature: save/load distributed checkpoint --- cube/graph/parser/parserfx.py | 7 +- cube/runtime/module.py | 145 ++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 3 deletions(-) diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index c6089f2b..6b9f5b0d 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -196,11 +196,12 @@ def parse(module: torch.fx.GraphModule, output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] - - frame.pop_var() - frame.pop_attr() if FxModuleParser.save_content: frame.save_attr_content() + frame.save_attr_map() + + frame.pop_var() + frame.pop_attr() return input_val, all_ir_nodes, output_val diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 852c79f4..3e4d9823 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -2,6 +2,7 @@ import torch from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer +import os class CubeModule(torch.nn.Module): @@ -39,6 +40,9 @@ def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: i assert hasattr(self, attr), f"{attr} is not in the module" self._fullmap[attr] = (tid, slicers, val_chunks) + def get_full_map(self): + return self._fullmap + def set_batch_size(self, bs: Optional[int]): assert (bs is None) or (isinstance(bs, int) and bs > 0) self._batch_size = bs @@ -60,3 +64,144 @@ def init_group(self, ranks: List[int]): if not all([isinstance(rank, int) for rank in ranks]): raise TypeError("Expected ranks to be List[int]") DeviceGroup().get_group(ranks) + + def get_checkpoint(self, optimizer: torch.optim.Optimizer = None): + state_dict = super().state_dict() + assert os.path.isfile('dist_param_map.pt'), 'Cannot open distributed parameter mapping file: dist_param_map.pt' + dist_param_map = torch.load('dist_param_map.pt') + param_area_map = self._fullmap + optimizer_state_dict = optimizer.state_dict() if optimizer is not None else None + return state_dict, dist_param_map, param_area_map, optimizer_state_dict + + def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_prefix: str = None): + filename_prefix = 'dist_checkpoint' if filename_prefix is None else filename_prefix + filename = f"{filename_prefix}-{DeviceGroup().rank}.ckpt" + state_dict, dist_param_map, param_area_map, optimizer_state_dict = self.get_checkpoint(optimizer) + print(f'> Saving distributed checkpoint to {filename}') + torch.save({ + 'state_dict': state_dict, + 'dist_param_map': dist_param_map, + 'param_area_map': param_area_map, + 'optim_state_dict': optimizer_state_dict, + }, filename) + + @staticmethod + def merge_partial_states(state_dicts): + """ + :param state_dicts: list of state_dict from different ranks + state_dict(model_state_dict, optimizer_state_dict, dist_param_map, param_area_map) + :return: merged state_dict(model_state_dict, optimizer_state_dict,) + """ + assert len(state_dicts) > 0 + if len(state_dicts) == 1: + return state_dicts[0][0], state_dicts[0][1] + + # find tensor full shape + param_max_dimsize = {} + for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: + for param_area in param_area_map.items(): + local_name = param_area[0][0:param_area[0].rfind('_')] + assert len(local_name) > 0 + raw_name = dist_param_map[local_name] + slices = param_area[1][1] + if param_area[1][2] != 1: + print(f'TODO: value-split on {raw_name}') + if raw_name in param_max_dimsize: + param_max_dimsize[raw_name] = max(param_max_dimsize[raw_name], slices) + else: + param_max_dimsize[raw_name] = slices + + # create full tensors + param_full_tensors = {} + sample_step = -1 + optim_full_tensors: Dict[int, Dict[any, any]] = {} # param_id, (state_name, state_val) + for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: + if len(optimizer_state_dict['state'].items()) > 0: + optimizer_state_names = list(optimizer_state_dict['state'][0].keys()) + print(f'optimizer_state_names = {optimizer_state_names}') + if 'step' in optimizer_state_names: + sample_step = optimizer_state_dict['state'][0]['step'] + optimizer_state_names.remove('step') + print(f'optimizer_state_names (without step) = {optimizer_state_names}') + else: + optimizer_state_names = [] + + other_optim_keys = [key for key in optimizer_state_dict.keys() if key != 'state'] + optimizer_other_state_dict = {} + for key in other_optim_keys: + optimizer_other_state_dict[key] = optimizer_state_dict[key] + + # for raw_name in param_max_dimsize.keys(): + model_state_dict_keys = list(model_state_dict.keys()) + for param_area in param_area_map.items(): + local_name_with_id = param_area[0] + local_name = local_name_with_id[0:local_name_with_id.rfind('_')] + raw_name = dist_param_map[local_name] + + tensor_size_slice = param_max_dimsize[raw_name] + tensor_size = [] + for dim_slice in tensor_size_slice: + tensor_size.append(dim_slice.stop) + param_full_tensors[raw_name] = torch.zeros(tuple(tensor_size)) + + index = model_state_dict_keys.index(local_name_with_id) + if index in optimizer_state_dict['state']: + for state_name in optimizer_state_names: # 'step' + if index not in optim_full_tensors: + optim_full_tensors[index] = {} + optim_full_tensors[index][state_name] = torch.zeros(tuple(tensor_size)) + else: + print(f'INFO: merge_checkpoint skips {local_name_with_id}\'s optimizer state') + # print(f'param_full_tensors = {param_full_tensors}') + # print(f'optim_full_tensors = {optim_full_tensors}') + break # only create once + + # assign value + for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: + model_state_dict_keys = list(model_state_dict.keys()) + for param_area in param_area_map.items(): + local_name_with_id = param_area[0] + local_name = local_name_with_id[0:local_name_with_id.rfind('_')] + raw_name = dist_param_map[local_name] + slices = param_area[1][1] + partial_tensor = model_state_dict[local_name_with_id] + param_full_tensors[raw_name][slices] = partial_tensor + + index = model_state_dict_keys.index(local_name_with_id) + if index in optimizer_state_dict['state']: + states = optimizer_state_dict['state'][index] + for name in optimizer_state_names: + val = states[name] + optim_full_tensors[index][name][slices] = val + if sample_step > 0: + optim_full_tensors[index]['step'] = sample_step + + # print(f'param_full_tensors (assigned) = {param_full_tensors}') + # print(f'optim_full_tensors (assigned) = {optim_full_tensors}') + + optimizer_other_state_dict.update({'state': optim_full_tensors}) + # dump to ckpt + return param_full_tensors, optimizer_other_state_dict + + @staticmethod + def merge_checkpoints(filename_prefix='dist_checkpoint'): + ckpts = {} + for rank in range(DeviceGroup().world_size): + filename = f"{filename_prefix}-{rank}.ckpt" + ckpts[rank] = torch.load(filename) + print(f'checkpoints = {ckpts}') + + state_dicts = [] + for ckpt in ckpts.values(): + model_state_dict = ckpt['state_dict'] + dist_param_map = ckpt['dist_param_map'] + param_area_map = ckpt['param_area_map'] + optimizer_state_dict = ckpt['optim_state_dict'] + state_dicts.push(model_state_dict, optimizer_state_dict, dist_param_map, param_area_map, ) + + merged_model_state_dict, merged_optimizer_state_dict = CubeModule.merge_partial_states(state_dicts) + + # dump to ckpt + torch.save({'state_dict': merged_model_state_dict, + 'optim_state_dict': merged_optimizer_state_dict + }, filename_prefix + '.full.ckpt') \ No newline at end of file From 22132edae6e907e662a75c182110535b40ffd905 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 9 Apr 2023 13:29:29 +0000 Subject: [PATCH 1383/1892] Merged PR 1537: fix backward bugs fix backward bugs --- cube/runtime/executor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 65ba43c8..93ef0af6 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -167,8 +167,10 @@ def backward(name: str, dedup_output_tensors = [] dedup_output_tensor_grads = [] for t, g in zip(output_tensors, output_tensor_grads): - if id(t) not in visited: - visited.add(id(t)) + # filter out duplicated output tensor and its grad. + pair = (id(t), id(g)) + if pair not in visited: + visited.add(pair) dedup_output_tensors.append(t) dedup_output_tensor_grads.append(g) From 5abc1467215919a0b2271ff19da3c33d93d9015c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 10 Apr 2023 02:24:10 +0000 Subject: [PATCH 1384/1892] Merged PR 1530: general policy example general policy example --- examples/policies/__init__.py | 2 + examples/policies/gshard.py | 96 ++++++++++++++++++++++++++++++++ examples/policies/random_spmd.py | 70 +++++++++++++++++++++++ 3 files changed, 168 insertions(+) create mode 100644 examples/policies/__init__.py create mode 100644 examples/policies/gshard.py create mode 100644 examples/policies/random_spmd.py diff --git a/examples/policies/__init__.py b/examples/policies/__init__.py new file mode 100644 index 00000000..c7b028f8 --- /dev/null +++ b/examples/policies/__init__.py @@ -0,0 +1,2 @@ +from examples.policies.gshard import PASGShard +from examples.policies.random_spmd import PASRandomSPMD \ No newline at end of file diff --git a/examples/policies/gshard.py b/examples/policies/gshard.py new file mode 100644 index 00000000..7e836dd0 --- /dev/null +++ b/examples/policies/gshard.py @@ -0,0 +1,96 @@ +""" +Policy example following GShard +""" + +from typing import List + +from cube.ir.tensor import IRSubTensor +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.graph import IRGraph +from cube.graph.function.dimops import IRDimops +from cube.graph.function.anchor import IRGraphAnchor + + +def follow(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int, + nodes: List[IRDimops]) -> List[IRDimops]: + """ + Partition nodes along one tensor dimension + + @param node IRDimops: the entry node + @param devs List[int]: the devices + @param idx int: entry node partition config idx + @param dim int: entry node partition config dim + @param nodes List[IRDimops]: partition node scopes + + @return remain_nodes List[IRDimops]: remaining nodes that are not partitioned + """ + assert node in nodes + algo = node.algorithms('dim') + if not algo.satisfy(idx=idx, dim=dim, num=len(devs)): return nodes + # tensor parallelism + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + # partition successors + nodes.remove(node) + for oidx, tensor in enumerate(node.outputs()): + if not isinstance(tensor, IRSubTensor): continue + ftensor = tensor.parent + for pdim in range(len(ftensor.shape)): + if sub_nodes[0].output(oidx).shape[pdim] != ftensor.shape[pdim]: + break + else: + continue + for consumer, ctensor in zip(graph.consumers(ftensor), graph.ctensors(ftensor)): + if not isinstance(consumer, IRDimops): continue + if isinstance(consumer, IRGraphAnchor) or consumer.name == 'multiref': continue + if consumer in nodes: + cidx = consumer.inputs().index(ctensor) + follow(graph, consumer, devs, cidx, pdim, nodes) + return nodes + + +def PASGShard(graph: IRGraph, resource): + + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): continue + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor) + + devs = list(range(resource.ngpus)) + + def replicate(node): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + + # print(graph.extra_repr()) + + fwops = graph.select(ntype=(IRDataOperation, IRFwOperation)) + print(f'> total fwops: {len(fwops)}') + while len(fwops) > 0: + fwop = fwops[0] + if isinstance(fwop, IRGraphAnchor) or fwop.name == 'multiref': + fwops.pop(0) + continue + # replicate if the node is not IRDimops + if not isinstance(fwop, IRDimops): + replicate(fwop) + fwops.pop(0) + continue + # partition along the longest dimension + configs = fwop.transform_space() + configs = sorted(configs, reverse=True, + key=lambda config: fwop.input(config[0]).shape[config[1]]) + for (idx, dim) in configs: + if fwop.input(idx).shape[dim] % len(devs) != 0: continue + if fwop.algorithms('dim').satisfy(idx=idx, dim=dim, num=len(devs)): + print(f'> policy partition: entry Fwop{fwop.cid}: {fwop.name} idx={idx}, dim={dim}') + follow(graph, fwop, devs, idx, dim, fwops) + print(f'> remaining fwops: {len(fwops)}') + break + else: + replicate(fwop) + fwops.pop(0) + return graph diff --git a/examples/policies/random_spmd.py b/examples/policies/random_spmd.py new file mode 100644 index 00000000..686dd8d9 --- /dev/null +++ b/examples/policies/random_spmd.py @@ -0,0 +1,70 @@ +""" +Random SPMD policy +""" +from typing import List, Optional +from cube.graph.graph import IRGraph +from cube.graph.function.dimops import IRDimops +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.function.anchor import IRGraphAnchor +from datetime import datetime + +import random + + +def _tp(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int): + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def _replica(graph: IRGraph, node, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def PASRandomSPMD(graph: IRGraph, resource, seed: Optional[int] = None): + """ + Random SPMD policy + """ + # get the current random state + state = random.getstate() + + seed = int(datetime.now().timestamp()) if seed is None else seed + print(f'> set random SPDM policy seed to {seed}') + random.seed(seed) + devs = list(range(resource.ngpus)) + + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): continue + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor) + + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'multiref' or isinstance(node, IRGraphAnchor): + continue + if isinstance(node, IRDimops): + configs = node.transform_space() + if len(configs) == 0: + _replica(graph, node, devs) + else: + configs = sorted(configs, reverse=True, + key=lambda config: node.input(config[0]).shape[config[1]]) + random.shuffle(configs) + for (idx, dim) in configs: + if node.input(idx).shape[dim] % len(devs) != 0: continue + if node.algorithms('dim').satisfy(idx=idx, dim=dim, num=len(devs)): + print(f'> partition node {node.name} ({node.cid}) with config idx={idx}, dim={dim}') + _tp(graph, node, devs, idx, dim) + break + else: + _replica(graph, node, devs) + else: + _replica(graph, node, devs) + + # restore the random state + random.setstate(state) + return graph From 2ff3ca693722a46b96c9928077befbd6c52a3e0f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 10 Apr 2023 11:44:58 +0000 Subject: [PATCH 1385/1892] Merged PR 1533: fix runtime require grad bugs fix runtime require grad bugs --- cube/graph/gener/concurrent.py | 7 +++---- cube/runtime/adapter/collectives.py | 12 ++++-------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index efd504b2..d946842e 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -54,7 +54,6 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # case 1: sharing device (intra-rvd) inshard = (set(pdevs) == set(cdevs)) and (len(fptensors) == len(fctensors)) and (len(pdevs) == len(fptensors)) if (not CompileFlag.disable_intra_rvd) and inshard and len(pdevs) > 1: - # fadapter = ConcurrentGener.gen_in_shard(fptensors, fctensors, bptensors, bctensors, allow_reassign=True) try: fadapter = ConcurrentGener.gen_intra_rvd(fptensors, fctensors, bptensors, bctensors, cost_fn) except Exception as e: @@ -70,7 +69,6 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # Case 2: sperating device (inter-rvd) if (not CompileFlag.disable_inter_rvd) and len(set(pdevs).intersection(cdevs)) == 0: - # fadapter = ConcurrentGener.gen_cross_shard(fptensors, fctensors, bptensors, bctensors) try: fadapter = ConcurrentGener.gen_inter_rvd(fptensors, fctensors, bptensors, bctensors, cost_fn) except Exception as e: @@ -130,13 +128,14 @@ def gen_intra_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # generate backward grad: IRFullTensor = ftensor.grad bprims = [] - if grad is not None and (len(bptensors) != 0 or len(bctensors) != 0): + if len(bptensors) > 0 and len(bctensors) > 0: # reorder ptensors to match with forward ptensors = [None] * len(devs) for bptensor in bptensors: idx = devs.index(bptensor.device) assert ptensors[idx] is None, "same device of different tensors" ptensors[idx] = bptensor + assert all(t is not None for t in ptensors), f"empty device slot from {bptensors}" ilayout = RVDLayout.togrid(grad, ptensors) olayout = RVDLayout.togrid(grad, bctensors) # paths, bprims = ilayout.path(olayout) @@ -172,7 +171,7 @@ def gen_inter_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], fadapter.prims = fprims grad: IRFullTensor = ftensor.grad - if grad is not None and (len(bptensors) != 0 or len(bctensors) != 0): + if len(bptensors) > 0 or len(bctensors) > 0: ilayout = RVDLayout.togrid(grad, bptensors) olayout = RVDLayout.togrid(grad, bctensors) bprims = InterPathFinder.path(ilayout, olayout, cost_fn) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 54e0a1d2..b946e16a 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -134,12 +134,9 @@ def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int], async_op=False) -> """ group = DeviceGroup().get_group(ranks) idx = torch.distributed.get_rank(group) - require_grad = itensor.requires_grad with torch.no_grad(): otensor = itensor.chunk(len(ranks), dim)[idx] otensor = otensor.detach() - if require_grad: - otensor = otensor.requires_grad_() return otensor @@ -167,7 +164,7 @@ def rdscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, shape = list(shape) shape[dim] = shape[dim] // len(dsts) otensor = torch.empty( - shape, requires_grad=True, dtype=dtype, + shape, requires_grad=False, dtype=dtype, device=torch.cuda.current_device() ) if async_op: @@ -190,7 +187,7 @@ def rvscatter(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, group = DeviceGroup().get_group((src,) + dsts) rank = torch.distributed.get_rank() tensor: torch.Tensor = itensor / len(dsts) if src == rank else \ - torch.empty(shape, dtype=dtype, requires_grad = True) + torch.empty(shape, dtype=dtype, requires_grad=False) tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor work = torch.distributed.broadcast(tensor, src, group=group, async_op=async_op) if work: @@ -226,7 +223,6 @@ def rdgather(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, otensor = itensor else: otensor = torch.cat(tuple(recv_tensors), dim=dim) - otensor = otensor.requires_grad_() else: assert rank in srcs otensor = itensor.contiguous() if not itensor.is_contiguous() else itensor @@ -249,7 +245,7 @@ def rvgather(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, CudaTimer().start(field_name='comm', predefined=True) rank = torch.distributed.get_rank() group = DeviceGroup().get_group(srcs + (dst,)) - tensor = torch.zeros(shape, dtype=dtype, requires_grad=True) if rank == dst else itensor + tensor = torch.zeros(shape, dtype=dtype, requires_grad=False) if rank == dst else itensor work = torch.distributed.reduce(tensor, dst, group=group, async_op=async_op) if work and rank == dst: AsyncCommHandler().submit(tensor, [work]) @@ -272,7 +268,7 @@ def broadcast(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, src: else: assert rank in ranks tensor = torch.empty(shape, - device=torch.cuda.current_device(), requires_grad=True, dtype=dtype) + device=torch.cuda.current_device(), requires_grad=False, dtype=dtype) work = torch.distributed.broadcast(tensor, src, group=group, async_op=async_op) if work and rank != src: AsyncCommHandler().submit(tensor, [work]) From 2aabf33613f55f79edf910939b64bee9d7818e6a Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Tue, 11 Apr 2023 05:57:17 +0000 Subject: [PATCH 1386/1892] Merged PR 1535: six layers parity checked 1. support apex layernorm 2. support loss scaling 3. updated concrete tracer 4. fix bug of dropout module (i.e., dropout rate) 5. revert to naive implementation to workaround the multiref bug --- cube/graph/function/function.py | 34 ++++ .../concrete_trace_utils/concrete_tracer.py | 169 +++++++++++------- .../kwargs_shape_prop/kwargs_interpreter.py | 67 +++++-- cube/graph/parser/converter.py | 44 ++--- cube/graph/parser/mappingfx.py | 1 + cube/graph/parser/parserfx.py | 78 +++++--- cube/runtime/executor.py | 5 + cube/runtime/function/function.py | 14 +- 8 files changed, 289 insertions(+), 123 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 32b666cb..96f393df 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -624,6 +624,39 @@ def LayerNorm(input, normalized_shape, weight=None, bias=None, eps=1e-05, signat return CubeLayerNorm(input, weight, bias, normalized_shape, eps, signature=signature) +def FusedLayerNorm(input, weight, bias, normalized_shape, eps=1e-5, signature = None): + """ + apex.normalization.fused_layer_norm.FusedLayerNorm + """ + signature = 'cube.runtime.function.fused_layer_norm' + assert not (weight is None and bias is not None), f"Not support for None of weight and parameter of bias" + letters = iter(string.ascii_lowercase) + einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) + eoutput = copy.copy(einput) + ndims = len(input.shape) + for dim in range(len(normalized_shape)): + einput[ndims-1-dim] += '^' + eoutput[ndims-1-dim] += '^' + einputs, inputs = [einput], [input] + kwargs = {} + if weight is not None: + eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) + einputs.append(eweight) + inputs.append(weight) + else: + kwargs['weight'] = weight + if bias is not None: + ebias = ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) + einputs.append(ebias) + inputs.append(bias) + else: + kwargs['bias'] = bias + anno = OpAnno.create_op_str(einputs, [eoutput]) + kwargs['normalized_shape'] = normalized_shape + kwargs['eps'] = eps + return IRDimops(FusedLayerNorm, 'fusedlayernorm', signature, [anno], inputs, **kwargs) + + def Sum(input, dim=None, keepdim=False, *, dtype=None, signature = None): """ torch.sum(input, *, dtype=None) -> Tensor @@ -654,6 +687,7 @@ def Sum(input, dim=None, keepdim=False, *, dtype=None, signature = None): anno = OpAnno.create_op_str([einput], [eoutput]) return IRDimops(Sum, 'sum', signature, [anno], [input], dim=dim, keepdim=keepdim) + def Mean(input, dim=None, keepdim=False, *, dtype=None, signature = None): """ torch.mean(input, *, dtype=None) -> Tensor diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index 72f48be5..6ec2f930 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -212,7 +212,7 @@ class ConcreteTracer(TracerBase): node_to_originating_module : Dict[torch.fx.Node, str] = {} @compatibility(is_backward_compatible=True) - def __init__(self): + def __init__(self, fake_device_type='cpu'): """ similar to _symbolic_trace.Tracer.__init__. remove the 'param_shapes_constant' because we can get real shape when executing. @@ -221,6 +221,8 @@ def __init__(self): self.scope = Scope("", None) self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} + assert fake_device_type in ('cuda', 'cpu') + self.fake_device_type = fake_device_type @contextmanager def do_temp_disable(self, call=False, attr=False, agfunc_apply=False): @@ -279,75 +281,107 @@ def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: D if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) with self.do_temp_disable(call=True): - to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - result = OperatorPatcherContext.patch_run(fn, *args, **kwargs) - for arg in args: - if _orig_isinstance(arg, torch.Tensor): - del arg - del args - for key, value in kwargs.items(): - if _orig_isinstance(value, torch.Tensor): - del value - del kwargs - if _orig_isinstance(result, torch.Tensor): - result_cpu = result.cpu() - del result + if self.fake_device_type == 'cpu': + to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + result = OperatorPatcherContext.patch_run(fn, *args, **kwargs) + for arg in args: + if _orig_isinstance(arg, torch.Tensor): + del arg + del args + for key, value in kwargs.items(): + if _orig_isinstance(value, torch.Tensor): + del value + del kwargs + if _orig_isinstance(result, torch.Tensor): + result_cpu = result.cpu() + del result + torch.cuda.empty_cache() + return result_cpu + if not isinstance(result, (tuple, list, dict)): + torch.cuda.empty_cache() + return result + to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t + result_cpu = tree_map(to_cpu, result) + for ret in result: + if _orig_isinstance(ret, torch.Tensor): + del ret torch.cuda.empty_cache() return result_cpu - if not isinstance(result, (tuple, list, dict)): - torch.cuda.empty_cache() - return result - to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t - result_cpu = tree_map(to_cpu, result) - for ret in result: - if _orig_isinstance(ret, torch.Tensor): - del ret - torch.cuda.empty_cache() - return result_cpu + else: + return OperatorPatcherContext.patch_run(fn, *args, **kwargs) elif kind == 'call_method': - self_obj, *args_tail = args - fn = _orig_getattr(self_obj, target) - if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): - _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) with self.do_temp_disable(call=True): - return OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) + if self.fake_device_type == 'cpu': + to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + self_obj, *args_tail = args + fn = _orig_getattr(self_obj, target) + if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): + _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + result = OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) + if _orig_isinstance(result, torch.Tensor): + result_cpu = result.cpu() + del result + torch.cuda.empty_cache() + return result_cpu + if not isinstance(result, (tuple, list, dict)): + torch.cuda.empty_cache() + return result + to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t + result_cpu = tree_map(to_cpu, result) + for ret in result: + if _orig_isinstance(ret, torch.Tensor): + del ret + torch.cuda.empty_cache() + return result_cpu + else: + self_obj, *args_tail = args + fn = _orig_getattr(self_obj, target) + if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): + _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + return OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) elif kind == 'call_module': assert isinstance(target, str) mod = self.fetch_attr(target) - mod.cuda() + if self.fake_device_type == 'cpu': + mod.cuda() if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(mod, '__globals__'): _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) with self.do_temp_disable(call=True): - to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - result = OperatorPatcherContext.patch_run(mod, *args, **kwargs) - for arg in args: - if _orig_isinstance(arg, torch.Tensor): - del arg - del args - for key, value in kwargs.items(): - if _orig_isinstance(value, torch.Tensor): - del value - del kwargs - mod.cpu() - if _orig_isinstance(result, torch.Tensor): - result_cpu = result.cpu() - del result + if self.fake_device_type == 'cpu': + to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + result = OperatorPatcherContext.patch_run(mod, *args, **kwargs) + for arg in args: + if _orig_isinstance(arg, torch.Tensor): + del arg + del args + for key, value in kwargs.items(): + if _orig_isinstance(value, torch.Tensor): + del value + del kwargs + mod.cpu() + if _orig_isinstance(result, torch.Tensor): + result_cpu = result.cpu() + del result + torch.cuda.empty_cache() + return result_cpu + if not isinstance(result, (tuple, list, dict)): + torch.cuda.empty_cache() + return result + to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t + result_cpu = tree_map(to_cpu, result) + for ret in result: + if _orig_isinstance(ret, torch.Tensor): + del ret torch.cuda.empty_cache() return result_cpu - if not isinstance(result, (tuple, list, dict)): - torch.cuda.empty_cache() - return result - to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t - result_cpu = tree_map(to_cpu, result) - for ret in result: - if _orig_isinstance(ret, torch.Tensor): - del ret - torch.cuda.empty_cache() - return result_cpu + else: + return OperatorPatcherContext.patch_run(mod, *args, **kwargs) elif kind == 'get_attr': assert isinstance(target, str) return self.fetch_attr(target) @@ -1390,7 +1424,8 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], autowrap_leaf_function = None, autowrap_leaf_class = None, leaf_module: Tuple | None = None, - fake_middle_class = None,) -> GraphModule: + fake_middle_class = None, + fake_device_type='cpu') -> GraphModule: """ Concrete tracing API @@ -1516,8 +1551,18 @@ def f(x, y): Returns: fx.GraphModule: a Module created from the recorded operations from ``root``. """ - tracer = ConcreteTracer() + tracer = ConcreteTracer(fake_device_type=fake_device_type) + graph = tracer.trace(root, + autowrap_leaf_function = autowrap_leaf_function, + autowrap_leaf_class = autowrap_leaf_class, + leaf_module = leaf_module, + fake_middle_class = fake_middle_class, + concrete_args=concrete_args, + use_operator_patch=use_operator_patch, + operator_patch_backlist=operator_patch_backlist, + forward_function_name=forward_function_name, + ) graph = tracer.trace(root, autowrap_leaf_function = autowrap_leaf_function, autowrap_leaf_class = autowrap_leaf_class, @@ -1539,7 +1584,7 @@ def f(x, y): forward_function_name=forward_function_name, ) # compare to check equal - assert len(graph.nodes) == len(graph_check.nodes) + assert len(graph.nodes) == len(graph_check.nodes), f'number nodes: {len(graph.nodes)} vs {len(graph_check.nodes)}' for node_a, node_b in zip(graph.nodes, graph_check.nodes): node_a: Node node_b: Node @@ -1553,7 +1598,7 @@ def f(x, y): assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\ hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function) else: - assert node_a.op == node_b.op and target_a == target_b + assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}' with MagicMethodPatcher(): name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py index 8d9a5e11..4799a2b3 100644 --- a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py +++ b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py @@ -1,7 +1,7 @@ import torch import torch.fx import torch.fx.traceback as fx_traceback -from torch.fx import Interpreter, Node +from torch.fx import Interpreter, Node, GraphModule from typing import Optional, Union, Tuple, Dict, List, Any, Iterator, Callable, MutableMapping, Mapping from torch.utils._pytree import tree_map @@ -21,6 +21,11 @@ class KwargsInterpreter(Interpreter): + def __init__(self, module : GraphModule, garbage_collect_values : bool = True, fake_device_type='cpu'): + super().__init__(module, garbage_collect_values) + assert fake_device_type in ('cpu', 'cuda') + self.fake_device_type = fake_device_type + def run(self, concrete_args: Union[Dict[str, Any], Tuple, MutableMapping[str, Any], Mapping[str, Any]] = None, initial_env: Optional[Dict[Node, Any]] = None, @@ -147,26 +152,52 @@ def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[s f'Expected positional argument for parameter {target}, but one was not passed in!') def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: - to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - result = super().call_function(target, args, kwargs) - if isinstance(result, torch.Tensor): - return result.cpu() + assert not isinstance(target, str) + if self.fake_device_type == 'cpu': + to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + result = target(*args, **kwargs) + if isinstance(result, torch.Tensor): + return result.cpu() + else: + to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t + return tree_map(to_cpu, result) else: - to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t - return tree_map(to_cpu, result) + return target(*args, **kwargs) + + def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + assert isinstance(target, str) + if self.fake_device_type == 'cpu': + self_obj = self_obj.cuda() + to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t + args_tail = tree_map(to_cuda, args_tail) + kwargs = tree_map(to_cuda, kwargs) + result = getattr(self_obj, target)(*args_tail, **kwargs) + self_obj = self_obj.cpu() + if isinstance(result, torch.Tensor): + return result.cpu() + else: + to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t + return tree_map(to_cpu, result) + else: + return getattr(self_obj, target)(*args_tail, **kwargs) def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: assert isinstance(target, str) mod = self.fetch_attr(target) - mod = mod.cuda() - to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - result = mod(*args, **kwargs) - if isinstance(result, torch.Tensor): - return result.cpu() + if self.fake_device_type == 'cpu': + mod = mod.cuda() + to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + result = mod(*args, **kwargs) + if isinstance(result, torch.Tensor): + return result.cpu() + else: + to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t + return tree_map(to_cpu, result) else: - to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t - return tree_map(to_cpu, result) + return mod(*args, **kwargs) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index c5c8d2b4..09dd2b94 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -10,6 +10,12 @@ import torch import torch.fx +try: + import apex + HAS_APEX = True +except: + HAS_APEX = False + def convert_model(model: torch.nn.Module, input_shapes: Optional[ List[List[int],] ] = None, dummy_input = None, @@ -29,19 +35,20 @@ def convert_model(model: torch.nn.Module, smodule.graph.print_tabular() else: print('INFO: using concrete tracer') - # NOTE: remove this part because when model is too large to fit in one GPU, - # this model forward cannot be successfully done, thus remove it. - # with torch.no_grad(): - # if isinstance(dummy_input, torch.Tensor): - # output_origin = model(dummy_input) - # dummy_input = (dummy_input, ) - # elif isinstance(dummy_input, tuple) or isinstance(dummy_input, list): - # output_origin = model(*dummy_input) - # elif isinstance(dummy_input, dict): - # print(f'WARNING dict dummy_input') - # output_origin = model(**dummy_input) - # else: - # raise RuntimeError(f'dummy_input should be a tuple (not a {type(dummy_input)}) = {dummy_input}') + if HAS_APEX: + leaf_module = ( + torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, + apex.normalization.FusedLayerNorm, + # NOTE: the following modules also have different behavior depending on self.training. but currently in used. + # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, + # torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, + # torch.nn.LazyBatchNorm1d, torch.nn.LazyBatchNorm2d, torch.nn.LazyBatchNorm3d, torch.nn.SyncBatchNorm, + # torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, + # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, + ) + else: + print('WARNING: apex package is not installed') + leaf_module = (torch.nn.Dropout, ) traced_model = concrete_trace( model, dummy_input, @@ -52,15 +59,8 @@ def convert_model(model: torch.nn.Module, }, # FIXME: check if dropout is not included in it, can self.training be handled properly in the new version of # concrete_trace - leaf_module=( - torch.nn.Dropout, torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, - # NOTE: the following modules also have different behavior depending on self.training. but currently in used. - # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, - # torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, - # torch.nn.LazyBatchNorm1d, torch.nn.LazyBatchNorm2d, torch.nn.LazyBatchNorm3d, torch.nn.SyncBatchNorm, - # torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, - # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, - ), + leaf_module=leaf_module, + fake_device_type='cpu', ) else: print('using torchscript tracer') diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 56676167..6f5cc357 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -97,6 +97,7 @@ def exist(signature: str) -> bool: __ftemplate('nll_loss') : function.NLLLoss, __ftemplate('layer_norm'): function.LayerNorm, + 'apex.normalization.fused_layer_norm.FusedLayerNorm': function.FusedLayerNorm, # ============== runtime function ================= __tttemplate('size'): function.Size, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 6b9f5b0d..1ea91ecb 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -47,6 +47,17 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo return m.__module__.startswith('torch.nn.functional') and not isinstance(m, torch.nn.Sequential) +def get_complex_data(val: Any, frame: Frame) -> Any: + """Change inner fx.Node into IRObject""" + if isinstance(val, tuple): + return tuple(get_complex_data(t, frame) for t in val) + if isinstance(val, list): + return list(get_complex_data(t, frame) for t in val) + if isinstance(val, torch.fx.Node): + return frame.get_var(val.name) + return val + + class FxModuleParser: save_content: bool = True @@ -195,7 +206,6 @@ def parse(module: torch.fx.GraphModule, # output_val = frame.get_var(output_nodes[0].name) output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] - if FxModuleParser.save_content: frame.save_attr_content() frame.save_attr_map() @@ -260,9 +270,47 @@ def fetch_attr(mod: torch.fx.GraphModule, target: str): @staticmethod def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: prim_module = FxModuleParser.fetch_attr(module, node.target) - assert prim_module.__class__.__module__.startswith('torch.nn.modules'), f'{module.__class__.__module__}' - fsig = 'torch.nn.{}'.format(prim_module.__class__.__name__) - return FxModuleParser._parse_node(fsig, node, module, frame) + input_vals = [get_complex_data(val, frame) for val in node.args] + kwargs = {key: get_complex_data(val, frame) for key, val in node.kwargs.items()} + if prim_module.__class__.__module__.startswith('torch.nn.modules'): + fsig = 'torch.nn.{}'.format(prim_module.__class__.__name__) + # specifically deal with torch.nn.Dropout, because some inputs of nn.module are passed + # in module instantiating phase, besides during forward + assert prim_module.__class__.__name__ == 'Dropout', f'{prim_module.__class__.__name__}, {fsig}' + kwargs.update({'p': prim_module.p, 'inplace': prim_module.inplace}) + return FxModuleParser._parse_node(fsig, node, input_vals, kwargs, frame) + elif prim_module.__class__.__module__ == 'apex.normalization.fused_layer_norm': + fsig = '{}.{}'.format(prim_module.__class__.__module__, prim_module.__class__.__name__) + assert prim_module.elementwise_affine is True + assert SignFx2Op.exist(fsig) + assert len(kwargs) == 0 + # add var of weight and bias into frame + shape = FxModuleParser.shape_refine(prim_module.weight.size()) + dtype = DType2IRDType.map(prim_module.weight.dtype) + requires_grad = prim_module.weight.requires_grad + ir_weight_val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=f'{node.name}_weight') + ir_weight_val.as_param() + frame.add_var(ir_weight_val.name, ir_weight_val) + frame.add_attr_content(ir_weight_val.tid, prim_module.weight) + frame.add_attr_map(ir_weight_val.name, node.target+'.weight') + shape = FxModuleParser.shape_refine(prim_module.bias.size()) + dtype = DType2IRDType.map(prim_module.bias.dtype) + requires_grad = prim_module.bias.requires_grad + ir_bias_val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=f'{node.name}_bias') + ir_bias_val.as_param() + frame.add_var(ir_bias_val.name, ir_bias_val) + frame.add_attr_content(ir_bias_val.tid, prim_module.bias) + frame.add_attr_map(ir_bias_val.name, node.target+'.bias') + input_vals.extend([ir_weight_val, ir_bias_val]) + kwargs.update({'normalized_shape': prim_module.normalized_shape, 'eps': prim_module.eps}) + ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) + assert isinstance(ir_node, IRCell) + assert len(ir_node.outputs()) == 1 + output_val = frame.get_var(node.name) + ir_node.set_output(0, output_val) + return [ir_node] + else: + raise RuntimeError(f'unknown module: {prim_module.__class__.__module__}') @staticmethod def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: @@ -272,25 +320,15 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule print(f'parse_prim_method_node: {fsig}') else: print(f'parse_prim_function_node: {fsig}') - return FxModuleParser._parse_node(fsig, node, module, frame) - - @staticmethod - def _parse_node(fsig: str, node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: - - def get_complex_data(val: Any) -> Any: - """Change inner fx.Node into IRObject""" - if isinstance(val, tuple): - return tuple(get_complex_data(t) for t in val) - if isinstance(val, list): - return list(get_complex_data(t) for t in val) - if isinstance(val, torch.fx.Node): - return frame.get_var(val.name) - return val # get inputs - input_vals = [get_complex_data(val) for val in node.args] - kwargs = {key: get_complex_data(val) for key, val in node.kwargs.items()} + input_vals = [get_complex_data(val, frame) for val in node.args] + kwargs = {key: get_complex_data(val, frame) for key, val in node.kwargs.items()} + return FxModuleParser._parse_node(fsig, node, input_vals, kwargs, frame) + + @staticmethod + def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, frame: Frame) -> List[IRFwOperation]: # map to IR operator if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 93ef0af6..c7dd8f0b 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -74,6 +74,7 @@ class Executor: # Each graph has its name, and multiple call for the graph will append # (instant id -> detached) input tensor pairs for backward reference. _detach: Dict[str, List[TensorPairs]] = dict() + _fn: Callable = None @staticmethod def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): @@ -140,6 +141,10 @@ def backward(name: str, gradient tensors corresponding to input_tensors. """ output_tensor_grads = Executor.sync_tensors(output_tensor_grads) + if Executor._fn is not None and output_tensor_grads[0] is None: + assert len(output_tensor_grads) == 1 + assert len(output_tensors) == 1 + output_tensors = (Executor._fn(output_tensors[0]), ) if len(output_tensors) == 0: return None diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index e97ffd86..059bb6da 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -2,6 +2,11 @@ import torch import torch.nn.functional as TorchF +try: + from apex.normalization.fused_layer_norm import fused_layer_norm_affine +except: + print('WARNING: apex is not installed, skip it.') + def identity(tensor: torch.Tensor) -> torch.Tensor: """ @@ -112,6 +117,12 @@ def layer_norm(input: torch.Tensor, return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps) +def fused_layer_norm(input: torch.Tensor, + weight: torch.Tensor, bias: torch.Tensor, + normalized_shape: List[int], eps: float = 1e-05) -> torch.Tensor: + return fused_layer_norm_affine(input, weight, bias, normalized_shape, eps) + + # 'torch.select_scatter' isn't supported by Torch2ONNX yet. # Implement it with 'torch.masked_scatter' which is supported with ONNX opset=11. def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): @@ -179,5 +190,6 @@ def stack(*tensors, dim=0) -> torch.Tensor: def cat(*tensors, dim=0) -> torch.Tensor: return torch.cat(tensors, dim) + def nndropout(input: torch.Tensor, p=0.5, inplace=False): - return torch.nn.Dropout(0.0, inplace)(input) + return torch.nn.Dropout(p, inplace)(input) From fae15a1b0055ef2090b8dfc34032c80d76be3215 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 12 Apr 2023 07:55:26 +0000 Subject: [PATCH 1387/1892] Merged PR 1545: Fix indexselect annotation bug --- cube/graph/function/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 96f393df..bb22d207 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1295,7 +1295,7 @@ def CubeIndexSelect(input: torch.Tensor, index: torch.Tensor, dim: int, signatur signature = 'cube.runtime.function.index_select' edim_in = ShapeAnno.create_shape_str(input.shape) edim_in[dim] += '^' - idx_anno = chr(ord(edim_in[-1]) + 1) + '^' + idx_anno = chr(ord(edim_in[-1]) + 1) edim_ou = copy.copy(edim_in) edim_ou[dim] = copy.copy(idx_anno) anno = OpAnno.create_op_str([edim_in, idx_anno], [edim_ou]) From a849a198c0ca12bbb9a3fe4e3cd7505317f1a6e7 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Mon, 24 Apr 2023 07:24:30 +0000 Subject: [PATCH 1388/1892] Merged PR 1555: change dropout from module to functional change dropout from module to functional --- cube/graph/function/function.py | 2 +- cube/graph/parser/converter.py | 6 ++---- cube/graph/parser/mappingfx.py | 2 +- cube/graph/parser/parserfx.py | 1 + 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index bb22d207..cf8a8aad 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -505,7 +505,7 @@ def Dropout(input, p=0.5, training=True, inplace=False, signature = None): """ annos = ['* -> *'] return IRDimops(Dropout, 'dropout', signature, annos, [input], - p=p, training=training, inplace=inplace) + p=p, training='self.training', inplace=inplace) def nnDropout(input, p=0.5, inplace=False, signature=None): diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 09dd2b94..9e2c204d 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -37,7 +37,7 @@ def convert_model(model: torch.nn.Module, print('INFO: using concrete tracer') if HAS_APEX: leaf_module = ( - torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, + # torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, apex.normalization.FusedLayerNorm, # NOTE: the following modules also have different behavior depending on self.training. but currently in used. # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, @@ -48,7 +48,7 @@ def convert_model(model: torch.nn.Module, ) else: print('WARNING: apex package is not installed') - leaf_module = (torch.nn.Dropout, ) + leaf_module = None traced_model = concrete_trace( model, dummy_input, @@ -57,8 +57,6 @@ def convert_model(model: torch.nn.Module, torch.finfo: ((), False), # type(output_origin): ((), False), }, - # FIXME: check if dropout is not included in it, can self.training be handled properly in the new version of - # concrete_trace leaf_module=leaf_module, fake_device_type='cpu', ) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 6f5cc357..6033aecc 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -52,7 +52,7 @@ def exist(signature: str) -> bool: __customops = lambda name: f'examples.custom_ops.{name}' kOpMap = { - __tnmtemplate('Dropout'): function.nnDropout, + # __tnmtemplate('Dropout'): function.nnDropout, __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 1ea91ecb..43e7b64b 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -273,6 +273,7 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: input_vals = [get_complex_data(val, frame) for val in node.args] kwargs = {key: get_complex_data(val, frame) for key, val in node.kwargs.items()} if prim_module.__class__.__module__.startswith('torch.nn.modules'): + assert False, 'Dropout is not supposed to be treated as module.' fsig = 'torch.nn.{}'.format(prim_module.__class__.__name__) # specifically deal with torch.nn.Dropout, because some inputs of nn.module are passed # in module instantiating phase, besides during forward From 63c8382070376baa08c1fad6d7823492c7cc1879 Mon Sep 17 00:00:00 2001 From: Rongwei Lu Date: Tue, 25 Apr 2023 10:32:04 +0000 Subject: [PATCH 1389/1892] Merged PR 1559: add a override switch for profile add a override switch for profile --- cube/profiler/database.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 098148b4..1d54a60d 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -210,7 +210,7 @@ def get_dep_names(sign: str): values.append(t) return fn, shapes, dtypes, requires_grads, values, node.kwargs - def profile(self, node: IRFwOperation, device: Optional[int] = None): + def profile(self, node: IRFwOperation, device: Optional[int] = None, override: bool = False): """ Profile a forward node in IRGraph on a specific device (default current device) @@ -227,7 +227,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): """ fn, shapes, dtypes, requires_grads, values, kwargs = ProfileDataBase.get_func(node) - if self.exist(node): + if not override and self.exist(node): return self.query(node) if isinstance(device, int): @@ -252,7 +252,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None): fw_span, bw_span, infer_memory, train_mem_info = \ CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) except: - fw_span, bw_span, infer_memory, train_mem_info = float('inf'), float('inf'), 0, [0] + fw_span, bw_span, infer_memory, train_mem_info = float('inf'), float('inf'), 0, [] # log to database key = self._serialize(node) self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) From d095c034fd3ba2421b5a3b508e3b8ab6965c97cf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sat, 6 May 2023 11:58:15 +0000 Subject: [PATCH 1390/1892] Merged PR 1562: Alpa Policy Example --- examples/policies/__init__.py | 3 +- examples/policies/alpa/README.md | 26 ++ examples/policies/alpa/__init__.py | 214 ++++++++++++++ examples/policies/alpa/cost_model.py | 227 +++++++++++++++ examples/policies/alpa/estimator.py | 413 +++++++++++++++++++++++++++ examples/policies/alpa/inter_op.py | 174 +++++++++++ examples/policies/alpa/intra_op.py | 230 +++++++++++++++ examples/policies/alpa/layer_op.py | 42 +++ examples/policies/alpa/plan.py | 105 +++++++ 9 files changed, 1433 insertions(+), 1 deletion(-) create mode 100644 examples/policies/alpa/README.md create mode 100644 examples/policies/alpa/__init__.py create mode 100644 examples/policies/alpa/cost_model.py create mode 100644 examples/policies/alpa/estimator.py create mode 100644 examples/policies/alpa/inter_op.py create mode 100644 examples/policies/alpa/intra_op.py create mode 100644 examples/policies/alpa/layer_op.py create mode 100644 examples/policies/alpa/plan.py diff --git a/examples/policies/__init__.py b/examples/policies/__init__.py index c7b028f8..749781c0 100644 --- a/examples/policies/__init__.py +++ b/examples/policies/__init__.py @@ -1,2 +1,3 @@ from examples.policies.gshard import PASGShard -from examples.policies.random_spmd import PASRandomSPMD \ No newline at end of file +from examples.policies.random_spmd import PASRandomSPMD +from examples.policies.alpa import PASAlpa \ No newline at end of file diff --git a/examples/policies/alpa/README.md b/examples/policies/alpa/README.md new file mode 100644 index 00000000..359e380b --- /dev/null +++ b/examples/policies/alpa/README.md @@ -0,0 +1,26 @@ + +# Alpa Implementation + +## Prerequisite + +```sh +pip install pulp +``` + +## Implementation Notes + +* The implementation doesn't support auto_layer construction, and relies on the `cube.runtime.function.anchor` as stage division candidates. + +* The implementation doesn't support `follow`, which relies on the user customized operator to achieve manual fusion. + +* For computation cost: + + * we assume the full efficiency, which is calculated by `cost/tp/dp` + + * Similar with Alpa, we force computation-intensive operators to be partitioned, and allow computation-light operators to be replicated. The computation-intensive operators are defined as operators that require weight for input (usually are customized operators). + +* For communication cost: + + * Similar with Alpa, we calculate the cost of communication by `bytes / bandwidth`. + + diff --git a/examples/policies/alpa/__init__.py b/examples/policies/alpa/__init__.py new file mode 100644 index 00000000..b028e23d --- /dev/null +++ b/examples/policies/alpa/__init__.py @@ -0,0 +1,214 @@ +from typing import List, Optional +from functools import partial +import warnings +import torch + +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.dimops import IRDimops +from cube.graph.graph import IRGraph +from cube.graph.segment import IRSegment +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.schedule.predefined import PredefinedSched +from cube.runtime.device import DeviceGroup + +from examples.policies.alpa.plan import ParallelSpec +from examples.policies.alpa.inter_op import inter_op +from examples.policies.alpa.intra_op import intra_op +from examples.policies.alpa.layer_op import annotate_structure +from examples.policies.alpa.cost_model import CostModel +from examples.policies.alpa.estimator import Estimator + + +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]) -> List[IRDimops]: + """Replicate a node""" + sub_nodes = [node] if len(devs) == 1 else graph.replicate(node, len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def _tp(graph: IRGraph, node: IRDimops, devs: List[int], **configs) -> List[IRDimops]: + """Tensor parallelism on a node""" + sub_nodes = [node] if len(devs) == 1 \ + else graph.partition(node, node.algorithms('dim'), **configs) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def PASAlpa(graph: IRGraph, resource, + recompute: bool = False, + nmicros: int = 1, + db_cache: str = 'db_train.json', + load_spec_file: Optional[str] = None, + save_spec_file: Optional[str] = None, + use_multiref: bool = False, + max_pp_size: Optional[int] = None, + max_tp_size: Optional[int] = None, + max_layer_number: int = 12) -> IRGraph: + """ + Alpa policy examples. + + Require user to manually add cune.runtime.anchor inside model + for AutoLayer partition position + + @param graph IRGraph: model graph + @param rresource Resource: resource + @param recompute bool: whether to enable recompute on each layer + @param nmicros int: number of micro-batches + @param db_cache str: database cache file + @param load_spec_file str: reuse spec file + @param save_spec_file str: save spec file + @param max_pp_size Optional[int]: limit the maximum number of pipeline parallelism size + @param max_tp_size Optional[int]: limit the maximum number of tensor parallelism size + @param max_layer_number Optional[int]: maximum number of layers to search + """ + # enable this for multiref + if use_multiref: + for ftensor in graph.full_tensors(): + if ftensor.is_grad() or ftensor.is_attr(): continue + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor) + + # recompute granularity will follow original anchor scope + layers = annotate_structure(graph) + if recompute: + for layer in layers: + graph.recompute(layer) + + anchors = graph.select(ntype=IRGraphAnchor) + nlayers = len(anchors) + 1 + removed = 0 + while removed < nlayers - max_layer_number: + for anchor in list(anchors[::2]): + graph.remove(anchor) + anchors.remove(anchor) + removed += 1 + if removed >= nlayers - max_layer_number: break + anchors = graph.select(ntype=IRGraphAnchor) + if removed > 0: + print(f'> shrink search space to {len(anchors)+1} layers') + + # enable this will follow alpa's policy: recompute on auto-layer granularity + # layers = annotate_structure(graph) + # if recompute: + # for layer in layers: + # graph.recompute(layer) + nodes = tuple(graph.select(ntype=IRFwOperation)) + + dl: IRDataOperation = graph.select(ntype=IRDataOperation)[0] + mbs: int = dl.output(0).shape[dl.get_batch_dims()[0]] + + # reserve 2GB memory for nccl + mem_limit = resource.gpus[0].memory - 2 * 1024 * 1024 * 1024 + print(f'> search [constraints]: device limitied memory: {mem_limit}') + # profile + print(f'> profiling model...') + estimator = Estimator(db_cache) + latency, memory = estimator(nodes, train=graph.train) + print(f'> search [estimation]: single device latency: {latency} ms, memory: {memory/1024/1024/1024} GB') + if DeviceGroup().rank == 0: + print(f'> search [dump]: saving profiled database...') + estimator.save() + # build cost model + print(f'> building cost model...') + cost_model = CostModel(graph, estimator) + + # alpa search -- only apply on rank 0 to ensure deterministic + if DeviceGroup().rank == 0: + if isinstance(load_spec_file, str): + print(f'loading spec from {load_spec_file}...') + config = ParallelSpec.load(load_spec_file, graph) + else: + print(f'> start searching...') + intra_solver = partial(intra_op, recompute=recompute, memory_limit=mem_limit, cost_model=cost_model) + config = inter_op(nodes, resource.ngpus, intra_solver, mbs, + max_p=max_pp_size, max_t=max_tp_size) + print(f'> parallel spec results:\n{config}') + + if isinstance(save_spec_file, str): + print(f'> saving spec to {save_spec_file}...') + config.save(save_spec_file) + + state: str = config.getstate() + state = torch.tensor([ord(c) for c in state], dtype=torch.int, device=torch.cuda.current_device()) + # notify -suppose each node has 8 gpus + for rank in range(8, DeviceGroup().world_size, 8): + print(f'> notify rank {rank} has finished searching...') + torch.distributed.send(torch.tensor([state.size(0)], device=torch.cuda.current_device()), dst=rank) + torch.distributed.send(state, dst=rank) + + else: + print('> waiting for rank 0 to finish searching...') + length = torch.tensor([0], device=torch.cuda.current_device()) + torch.distributed.recv(length, src=0) + state = torch.empty(length.item(), dtype=torch.int, device=torch.cuda.current_device()) + torch.distributed.recv(state, src=0) + state = ''.join([chr(c) for c in state.tolist()]) + config = ParallelSpec.loadstate(state) + print(f'> parallel spec results:\n{config}') + + print(f'> instantiate plan...') + # print(graph.extra_repr()) + + # staging + cid2node = {n.cid : n for n in nodes} + leading_cids = [list(stage.tp_spec.keys())[0] for stage in config.stages] + leading_nodes = [cid2node[cid] for cid in leading_cids] + graph.staging(leading_nodes) + segments = graph.select(ntype=IRSegment, flatten=False) + fsegments = [seg for seg in segments if seg.isfw()] + assert len(fsegments) == len(config.stages) + + # replicate data loader + devices = list(range(resource.ngpus)) + _replica(graph, dl, devices) + + # partition + # TODO: make data parallel to be outside of pipeline parallelism + for sidx, stage in enumerate(config.stages): + tp, dp = stage.tp_size, stage.dp_size + spec = stage.tp_spec + stage_devices, devices = devices[:tp*dp], devices[tp*dp:] + print(f'> applying spec: tp={tp}, dp={dp} for stage {sidx}...') + for node in fsegments[sidx].nodes(): + if isinstance(node, IRGraphAnchor) or node.name == 'multiref': + continue + if node.cid not in spec: + print(f'warning: node {node.name}({node.cid}) not in spec, replicate') + _replica(graph, node, stage_devices) + continue + if mbs not in node.input(0).shape: + if dp > 1: + print(f'warning: cannot find batch dimension of {node.name}({node.cid}), assuming idx=0, dim=0') + batch_dim = 0 + else: + batch_dim = node.input(0).shape.index(mbs) + strategy = spec[node.cid] if node.cid in spec else None + # data parallel + if not isinstance(node, IRDimops): + warnings.warn(f'detected a node {node.name} is not IRDimops, replicate for data parallel') + dp_nodes = [node] if dp == 1 else graph.replicate(node, times=dp) + else: + dp_nodes = [node] if dp == 1 else \ + graph.partition(node, node.algorithms('dim'), idx=0, dim=batch_dim, num=dp) + # tensor parallelism + tp_nodes = [] + for dp_node in dp_nodes: + if strategy is None: + ts = [dp_node] if tp == 1 else graph.replicate(dp_node, times=tp) + else: + idx, dim = strategy + ts = [dp_node] if tp == 1 else \ + graph.partition(dp_node, dp_node.algorithms('dim'), idx=idx, dim=dim, num=tp) + assert len(ts) == tp, f"got tp nodes: {ts} | partition {dp_node} with {strategy}" + tp_nodes += ts + for devid, tp_node in zip(stage_devices, tp_nodes): + graph.assign(tp_node, devid) + # print(graph.extra_repr()) + # setup schedule + if graph.train: + sched = PredefinedSched.sched_1f1b(graph, nmicros, len(config.stages)) + else: + sched = PredefinedSched.sched_infer_pipe(graph, nmicros, len(config.stages)) + return graph diff --git a/examples/policies/alpa/cost_model.py b/examples/policies/alpa/cost_model.py new file mode 100644 index 00000000..a14b2125 --- /dev/null +++ b/examples/policies/alpa/cost_model.py @@ -0,0 +1,227 @@ +""" +Cost model for intra-op plan search +""" +from typing import List, Callable, Tuple, Dict +import numpy as np + +from cube.graph import IRGraph +from cube.ir.cten import IRTensor +from cube.ir.operator import IRFwOperation +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.dimops import IRDimops, TransformRule, DimopSplit + + +DistSpec = Dict[int, Tuple[Tuple[int, int]]] + + +class CommCost: + """ + Get communication cost in milliseconds + """ + @staticmethod + def get_bandwidth(ranks: List[int]): + """ + TODO: support with real runtime information + """ + if len(ranks) < 8: + return 150 * 1e9 # 150 GB/s for intra-node (NVLink) + else: + return 12.5 * 1e9 # 12.5 GB/s for inter-node (IB) + + @staticmethod + def allreduce_cost(tensor: IRTensor, num_devices: int) -> float: + bandwidth = CommCost.get_bandwidth(list(range(num_devices))) + return 2 * (num_devices - 1) * tensor.byte_size() / num_devices / bandwidth * 1000 + + @staticmethod + def alltoall_cost(tensor: IRTensor, num_devices: int) -> float: + # bandwidth in all-to-all is really worse (1GB/s) and should not use + return 1e6 + bandwidth = CommCost.get_bandwidth(list(range(num_devices))) + return tensor.byte_size() / num_devices / num_devices * (num_devices - 1) / bandwidth * 1000 + + @staticmethod + def allgather_cost(tensor: IRTensor, num_devices: int) -> float: + # bandwidth in allgather can only be half due to torch implementation issues + # return 1e6 + bandwidth = CommCost.get_bandwidth(list(range(num_devices))) / 2.98 + return tensor.byte_size() / num_devices * (num_devices - 1) / bandwidth * 1000 + + @staticmethod + def reducescatter_cost(tensor: IRTensor, num_devices: int) -> float: + # bandwidth in reduce-scatter can only be half due to torch implementation issues + # return 1e6 + bandwidth = CommCost.get_bandwidth(list(range(num_devices))) / 2.38 + return tensor.byte_size() / num_devices * (num_devices - 1) / bandwidth * 1000 + + +class CostModel: + + def __init__(self, graph: IRGraph, estimator: Callable): + + self.graph = graph + self.estimator = estimator + + # node property + self.comp_cost = {} + self.mem_cost = {} + + self.edges: Dict[int, List[int]] = {} + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): continue + for producer in graph.producers(ftensor): + if not isinstance(producer, IRFwOperation): continue + for consumer in graph.consumers(ftensor): + if not isinstance(consumer, IRFwOperation): continue + self.edges.setdefault(producer.cid, []).append(consumer.cid) + + # node.cid -> ((idx, dim),) + self.partition_algos: Dict[int, Tuple[int, int]] = {} + + fnodes = graph.select(ntype=IRFwOperation) + fnodes = [n for n in fnodes if not (isinstance(n, IRGraphAnchor) or n.name == 'multiref')] + + for fnode in fnodes: + latency, memory = self.estimator((fnode,)) + self.comp_cost[fnode.cid] = latency + self.mem_cost[fnode.cid] = memory + self.partition_algos[fnode.cid] = self.get_transform_space(fnode) + + def get_transform_space(self, node: IRFwOperation) -> List[Tuple[int, int]]: + """ + Get the transform space of a node + + None indicates replicate + """ + light_op_names = ('add', 'sub', 'mul', 'layernorm') + # light_op_names = () + if isinstance(node, IRDimops): + params = [t for t in node.inputs() if isinstance(t, IRTensor) and t.is_attr()] + # must be partitioned for computation-intensive ops + if len(params) > 0 and node.name not in light_op_names: # not node.signature.startswith('torch.'): + return list(node.transform_space()) + # can be partitioned or replicated for computation-light ops + else: + return [None] + node.transform_space() + return [None] + + def get_memory_cost(self, fnode: IRFwOperation) -> int: + if fnode.cid not in self.mem_cost: + if not (isinstance(fnode, IRGraphAnchor) or fnode.name == 'multiref'): + print(f'warning: cannot find memory cost for node {fnode.name}({fnode.cid})') + return 0 + return self.mem_cost[fnode.cid] + + def get_comp_cost(self, fnode: IRFwOperation, num_devices: int) -> np.ndarray: + """ + Get computation cost related to different partition strategies + """ + return np.zeros(len(self.partition_algos[fnode.cid]), dtype=float) + # cost = [] + # original_cost = self.comp_cost[fnode.cid] + # for strategy in self.partition_algos[fnode.cid]: + # if strategy is None: + # cost.append(original_cost) + # else: + # # computation efficiency simulation + # efficiency = 1 - (num_devices-1)*0.1/2 + # cost.append(original_cost / num_devices / efficiency) + # return np.array(cost, dtype=float) + + def get_comm_cost(self, fnode: IRFwOperation, num_devices) -> np.ndarray: + """ + Get communication cost for a node given a strategy + + This only calucates the cases for partitioning on value dimension + + @return cost: np.ndarray: 1-D array of the cost on allreduce + """ + cost = [] + for strategy in self.partition_algos[fnode.cid]: + if strategy is None: + cost.append(0.) + continue + s_cost = 0 + idx, dim = strategy + rule: TransformRule = fnode.algorithms('dim').infer(idx, dim, num_devices) + for idx, output in enumerate(rule.outputs()): + if output.isV(): + s_cost += CommCost.allreduce_cost(fnode.output(idx), num_devices) + cost.append(s_cost) + return np.array(cost, dtype=float) + + def get_pair_reshard_cost(self, fnode_src: IRFwOperation, fnode_dst: IRFwOperation, + num_devices: int) -> np.ndarray: + """ + Get cost of resharding between two nodes + @return cost: np.ndarray: 1-D tensor of (nsrc * ndst,) shape, + nsrc is the number of partitioned ways of the source node + ndst is the number of partitioned ways of the destination node + """ + nsrc = len(self.partition_algos[fnode_src.cid]) + ndst = len(self.partition_algos[fnode_dst.cid]) + cost = np.zeros((nsrc, ndst), dtype=float) + + def comm_cost(tensor: IRTensor, num_devices: int, + src_split: DimopSplit, dst_split: DimopSplit, dst_replica: bool): + # note for data parallel, we don't consider allreduce cost as it + # will only be performed at the last of iteration. + if tensor.is_attr(): return 0.0 + if src_split.isV() or src_split.isR(): + # identity-allreduce or identity-identity + if dst_split.isR(): + return 0.0 if dst_replica else CommCost.allreduce_cost(tensor, num_devices) + # split-allgather + if dst_split.isD(): + return CommCost.allgather_cost(tensor, num_devices) + if src_split.isD(): + # allgahter-reducescatter or allgather-split + if dst_split.isR(): + return CommCost.allgather_cost(tensor, num_devices) if dst_replica else \ + CommCost.allgather_cost(tensor, num_devices) + CommCost.reducescatter_cost(tensor, num_devices) + # all2all-all2all or identity-identity + if dst_split.isD(): + return 0.0 if src_split.dim == dst_split.dim else 2 * CommCost.alltoall_cost(tensor, num_devices) + raise NotImplementedError(f"Unknown split type: {src_split} -> {dst_split}") + + # FIXME: need consider cases that an operator has multiple **same** inputs + tensors: Dict[IRTensor, Tuple[int, int]] = {} + for idx, output in enumerate(fnode_src.outputs()): + tensors[output.parent] = [idx] + for idx, input in enumerate(fnode_dst.inputs()): + if not isinstance(input, IRTensor): continue + tensors.setdefault(input.parent, []).append(idx) + tensors = {t: tuple(v) for t, v in tensors.items() if len(v) == 2} + + for i, strategy_src in enumerate(self.partition_algos[fnode_src.cid]): + + rule_src = None + if strategy_src is not None: + idx, dim = strategy_src + rule_src = fnode_src.algorithms('dim').infer(idx, dim, num_devices) + + for j, strategy_dst in enumerate(self.partition_algos[fnode_dst.cid]): + rule_dst = None + if strategy_dst is not None: + idx, dim = strategy_dst + rule_dst = fnode_dst.algorithms('dim').infer(idx, dim, num_devices) + + for tensor, (idx_src, idx_dst) in tensors.items(): + cost[i, j] += comm_cost( + tensor, num_devices, + rule_src.outputs()[idx_src] if rule_src is not None else DimopSplit(r=True), + rule_dst.inputs()[idx_dst] if rule_dst is not None else DimopSplit(r=True), + strategy_dst is None + ) + return cost + + def get_edges(self, nodes: List[IRFwOperation]) -> Dict[IRFwOperation, Tuple[IRFwOperation]]: + """ + Get edges of a subgraph + """ + edges: Dict[IRFwOperation, List[IRFwOperation]] = {} + cid2nodes: Dict[int, IRFwOperation] = {n.cid : n for n in nodes} + for node in nodes: + if node.cid in self.edges: + edges[node] = [cid2nodes[cid] for cid in self.edges[node.cid] if cid in cid2nodes] + return edges diff --git a/examples/policies/alpa/estimator.py b/examples/policies/alpa/estimator.py new file mode 100644 index 00000000..2f6b67c7 --- /dev/null +++ b/examples/policies/alpa/estimator.py @@ -0,0 +1,413 @@ +from typing import Callable, Tuple, Union, Optional, Dict, NewType, List +import time +import os +import json + +# ===== neccesaary for profiling ===== +import cube +import torch +# ==================================== + +from cube.ir.cten import IRTensor, IRObject, IRCell +from cube.ir.operator import IRFwOperation +from cube.graph.parser.dtype import IRDType2TorchDType +from cube.graph.parser.register import CustomizedOps +from cube.graph.segment import IRSegment +from cube.graph.function.dimops import IRDimops +from cube.graph.function import IRGraphAnchor + + +Shapes = NewType('Shapes', Tuple[Tuple[int]]) +DTypes = NewType('DTypes', Tuple[torch.dtype]) +ShapesDTypes = NewType('ShapesDTypes', Tuple[Shapes, DTypes]) +NameOrFunc = Union[str, Callable] + + +_train_module_ref: torch.nn.Module = torch.nn.Module().train() +_eval_module_ref: torch.nn.Module = torch.nn.Module().eval() + + +class CompProfiler: + + @staticmethod + def profile(node: IRCell, train: bool = True, + warmup_sec: float = 2, prof_times: int = 50) -> Tuple[float, float, int, Tuple[int]]: + """ + Profile a function + + @param func Callable: the callable function, e.g., torch.nn.functional.linear + @param warmup_sec float: warmup seconds + @param prof_times int: profile times + + @return latency float: average latency in ms + @return memory int: average memory in bytes + """ + torch.cuda.empty_cache() + # print(f'current GPU memory: {torch.cuda.memory_allocated() / 1024 / 1024 / 1024} GB') + + func: Callable = CompProfiler.get_func(node) + args, kwargs = CompProfiler.get_inputs(node, train=train) + + # prepare gradients + with torch.no_grad(): + outputs = func(*args, **kwargs) + outputs = (outputs,) if torch.is_tensor(outputs) else outputs + assert all(torch.is_tensor(otensor) for otensor in outputs), \ + f"{func.__name__}: require all the outputs to be tensors" + grads = tuple(torch.zeros_like(otensor) for otensor in outputs) + del outputs + + def run_step(func, tensors, kwargs, backward: bool): + if not backward: + with torch.no_grad(): + outputs = func(*tensors, **kwargs) + else: + outputs = func(*tensors, **kwargs) + torch.autograd.backward(outputs, grads) + + # memory + torch.cuda.synchronize() + torch.cuda.empty_cache() + mtic = torch.cuda.max_memory_allocated() # in bytes + memory = 0 + if train: + used_tensor = set() + def pack_hook(x): + nonlocal memory, used_tensor + if x.storage().data_ptr() not in used_tensor: + used_tensor.add(x.storage().data_ptr()) + byte_size = x.element_size() + for dim in list(x.size()): + byte_size = byte_size * dim + memory += byte_size + return x + def unpack_hook(x): return x + + with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + run_step(func, args, kwargs, backward=True) + torch.cuda.synchronize() + del used_tensor + else: + run_step(func, args, kwargs, backward=False) + torch.cuda.synchronize() + mtoc = torch.cuda.max_memory_allocated() + memory = mtoc - mtic + + # warmup + torch.cuda.synchronize() + tic = time.time() + while time.time() - tic < warmup_sec: + run_step(func, args, kwargs, backward=train) + torch.cuda.synchronize() + + torch.cuda.synchronize() + tic = time.perf_counter() + for _ in range(prof_times): + run_step(func, args, kwargs, backward=train) + torch.cuda.synchronize() + toc = time.perf_counter() + latency = (toc - tic) / prof_times * 1000 # in milliseconds + + return latency, memory + + @staticmethod + def get_inputs(node: IRFwOperation, train: bool) -> Tuple[List, Dict]: + # create data + def dummy_torch_tensor(tensor: IRTensor): + """Generate dummy input tenosrs""" + dtype = IRDType2TorchDType.map(tensor.dtype) + constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand + return constructor(tuple(tensor.shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=tensor.requires_grad) + + args = [dummy_torch_tensor(t) if isinstance(t, IRTensor) else t for t in node.inputs()] + # replace kwargs starting with 'self.xxx' + kwargs = {} + for name, value in node.kwargs.items(): + if isinstance(value, str) and value.startswith('self.'): + value = getattr(_train_module_ref, value[5:]) if train else getattr(_eval_module_ref, value[5:]) + kwargs[name] = value + + return args, kwargs + + @staticmethod + def get_func(node: IRFwOperation) -> Callable: + """ + Get function call + """ + assert isinstance(node, IRFwOperation), f"Only support profiling forward operation but got {type(node)}" + + def get_dep_names(sign: str): + ret = [] + code_impl = CustomizedOps.kOpCodeDef[sign] + for code_line in code_impl.split('\n'): + idx = code_line.find('# call: ') + if idx != -1: + dep_name = code_line[idx + 8:] + assert dep_name in CustomizedOps.kOpCodeDef, dep_name + ret = ret + get_dep_names(dep_name) + ret.append(dep_name) + return ret + + if node.signature in CustomizedOps.kOpCodeDef: + dep_code_impl = '' + for dep_name in get_dep_names(node.signature): + dep_code_impl = dep_code_impl + CustomizedOps.kOpCodeDef[dep_name] + code_impl: str = CustomizedOps.kOpCodeDef[node.signature] + def_end = code_impl.find(':\n') + assert def_end >= 0 + prev_code_lines = code_impl[:def_end+2] + succ_code_lines = code_impl[def_end+2:] + for line in dep_code_impl.split('\n'): + prev_code_lines = prev_code_lines + ' ' + line + '\n' + code_impl = prev_code_lines + succ_code_lines + local = {} + exec(code_impl, globals(), local) + fn = list(local.values())[0] + else: + fn = eval(node.signature) + return fn + + +class ProfileDataBase: + + def __init__(self, filename: Optional[str] = None) -> None: + """! + Create a database for profiling result + """ + self._data: Dict[str, Dict[str, Tuple[float, float, int]]] = dict() + if filename is not None: + self.load(filename) + + def profile(self, node: IRFwOperation, train: bool = True, device: Optional[int] = None): + """ + Profile a forward node in IRGraph on a specific device (default current device) + + @param node IRFwOperation: node of IRGraph + @param device int: the device that the node will execute on + + @return latency float: average latency in ms + @return memory int: average memory in bytes + """ + if self.exist(node): + return self.query(node) + + if isinstance(device, int): + orig_device = torch.cuda.current_device() + torch.cuda.set_device(device) + + color, default = '\033[31m', '\033[0m' + + #FIXME: OOM will increase cuda allocated memory + try: + latency, memory = CompProfiler.profile(node, train) + # log to database + self.insert(node, latency, memory) + except Exception as e: + err = f'{color}profil error:\n {str(e)}{default}' + print(err) + latency, memory = e, e + + shapes = tuple(t.shape if isinstance(t, IRTensor) else None for t in node.inputs()) + dtypes = tuple(IRDType2TorchDType.map(t.dtype) if isinstance(t, IRTensor) else None for t in node.inputs()) + error = f'{color}None{default}' + print( + f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} | train {train} => " + f"latency: {round(latency, 2) if isinstance(latency, float) else error} ms | " + f"memory {memory if isinstance(memory, int) else None} bytes") + + if isinstance(device, int): + torch.cuda.set_device(orig_device) + return latency, memory + + def insert(self, node: IRCell, latency: float, memory: int): + """ + log (reset) the span of a node with key + + @param node IRCell + @param latency float: inference time in milliseconds + @param memory int: inference peak memory in bytes + """ + name = node.signature + key = self._serialize(node) + assert isinstance(name, str) and isinstance(key, str) + if name not in self._data: + self._data[name] = dict() + latency = latency if isinstance(latency, float) else None + memory = memory if isinstance(memory, int) else None + self._data[name][key] = (latency, memory) + + def exist(self, node: IRFwOperation) -> bool: + """ + Check if the node has the performance recorded in the database + + @param node IRFwOperation: forward operation + + @return exist bool: True if the performance is recorded, else False + """ + key = self._serialize(node) + if node.signature not in self._data: + return False + if key not in self._data[node.signature]: + return False + return True + + def query(self, node: IRFwOperation) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int]]: + """! + Get the performance number of a node in IRGraph + + @param node IRFwOperation: node in IRGraph + + @return latency float: average latency in ms + @return memory int: average memory in bytes + """ + key = self._serialize(node) + if node.signature not in self._data: + return None + if key not in self._data[node.signature]: + return None + return self._data[node.signature][key] + + def _serialize(self, node: IRFwOperation) -> str: + """ + Serialize the shapes, dtypes and kwargs into a string + + e.g., + shapes: ((1024,), (1024,1024)) + dtypes: (torch.float32, torch.float32) + => ((1024,), (1024,1024)) : (torch.float32, torch.float32) + + @param shapes Tuple[Tuple[int]]: the shape of each tensor + @param dtypes Tuple[torch.dtype]: the dtype of each tensor + + @return key str: the serialized string + """ + shapes, dtypes = [], [] + for t in node.inputs(): + if isinstance(t, IRTensor): + shapes.append(t.shape) + dtypes.append(IRDType2TorchDType.map(t.dtype)) + elif isinstance(t, IRObject): + raise RuntimeError('IRObject has not been supported in _serialize') + else: + shapes.append(None) + dtypes.append(type(t)) + shapes = str(tuple(shapes)) + dtypes= str(tuple(dtypes)) + return shapes + ' : ' + dtypes + + def _deserialize(self, key: str) -> ShapesDTypes: + """ + De-serialize the key string to shapes and dtypes + + e.g., (1024,)-(1024,1024)=torch.float32-torch.float32 + => shapes: ((1024,), (1024,1024)) + dtypes: (torch.float32, torch.float32) + + @param key str: the serialized string + @return shapes_and_dtypes ShapesDTypes: shapes and dtypes + """ + shapes, dtypes = key.split(' : ') + shapes = eval(shapes) + dtypes = eval(dtypes) + # shapes = tuple(eval(shape) for shape in shapes.split('-')) + # dtypes = tuple(eval(dtype) for dtype in dtypes.split('-')) + return shapes, dtypes + + def dump(self, file: str, override=False): + """! + dump the profiled data into json format + + @param file str: the file name + @param override bool: True if the existed can be overrided else False + """ + if os.path.exists(file): + assert override, f"File {file} exists. Set override = True to force dump." + with open(file, 'w') as f: + json.dump(self._data, f) + + def load(self, file: str): + """! + load the profiled data into data base. The original existed one will be + overrided by the loaded data. + + @param file str: the file name + """ + with open(file, 'r') as f: + self._data = json.load(f) + + def __repr__(self) -> str: + data = [] + for signature in self._data: + for key in self._data[signature]: + shapes, dtypes = self._deserialize(key) + latency, memory = self._data[signature][key] + data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, latency {latency:.2f} msm, memory {memory} bytes') + data = '\n'.join(data) + return data + + +class Estimator: + """ + Estimator to measture the computation / memory cost of a subgraph + """ + def __init__(self, cache='./profile_database.json'): + + self.cache_file = cache + reload = cache if os.path.exists(cache) else None + self.database = ProfileDataBase(reload) + + def profile(self, node: IRFwOperation, train: bool) -> Tuple[float, int]: + if node.name == 'multiref' or isinstance(node, IRGraphAnchor): return 0.0, 0 + trials = Estimator.special_rules(node, [None]) + for config in trials: + if config is None: + num = 1 + latency, memory = self.database.profile(node, train) + else: + idx, dim, num = config + print(f'> ... try node {node.name} with idx={idx}, dim={dim}, num={num}') + sub_node = node.algorithms('dim').instantiate(idx=idx, dim=dim, num=num)[0] + latency, memory = self.database.profile(sub_node, train) + if isinstance(latency, float): break + if isinstance(latency, float): break + assert isinstance(latency, float), f"Failed to profile: {node}" + latency, memory = latency * num, memory * num + self.database.insert(node, latency, memory) + return latency, memory + + def __call__(self, nodes_or_segment: Union[Tuple[IRFwOperation], IRSegment], + train: bool = True): + """ + Profile the computation cost of a subgraph + + @param nodes_or_segment Tuple[IRFwOperation] | IRSegment + + @return latency float: latency in ms + @return memory int: memory in bytes + """ + nodes = nodes_or_segment.nodes() if isinstance(nodes_or_segment, IRSegment) else nodes_or_segment + memory, latency = 0.0, 0.0 + for node in nodes: + if self.database.exist(node): + node_latency, node_memory = self.database.query(node) + else: + node_latency, node_memory = self.profile(node, train) + if train: + memory += node_memory + latency += node_latency + else: + memory = max(memory, node_memory) + latency += node_latency + return latency, memory + + def save(self): + self.database.dump(self.cache_file, override=True) + + def special_rules(node, trials): + # if node.name == 'embedding': # for GPT + # trials = [(1, 0, 4),] + # if node.name == 'self_attention': # for GPT + # trials = [(1, 0, 4),] + # if node.name == 'window_attn': # for Swin + # trials = [(1, 0, 4),] + return trials diff --git a/examples/policies/alpa/inter_op.py b/examples/policies/alpa/inter_op.py new file mode 100644 index 00000000..a11585f1 --- /dev/null +++ b/examples/policies/alpa/inter_op.py @@ -0,0 +1,174 @@ +""" +Piper policy + +https://openreview.net/attachment?id=-U9I0f2S7W&name=supplementary_material + +The implementation is a little bit adapted to fit with cube's view +""" +from typing import List, Callable, Tuple, Dict, Optional +import time + +from cube.ir.operator import IRFwOperation +from examples.policies.alpa.layer_op import IRLayerOp, cluster_to_layer_ops +from examples.policies.alpa.plan import StageSpec, ParallelSpec + + +def iter_subgraph(nodes: Tuple[IRLayerOp], s: int): + """ + Iterate sub-graphs of the nodes + + @param nodes Tuple[IRFwOperation] + @param s int: number of stages + + @return (sub_graph1, sub_graph2) Tuple[Tuple[IRFwOp], Tuple[IRFwOp]] + """ + assert s > 0 + if s > 1: + # don't consider the head and tail to be anchor + assert len(nodes) >= s - 1, f"layer op: {len(nodes)}, stage: {s}" + for idx in range(len(nodes)): + remain_nodes = len(nodes) - (idx + 1) + # sub-problem of iter(sub_graph2, s-1) must iterable + if remain_nodes < s - 2: continue + sub_graph1, sub_graph2 = nodes[:idx+1], nodes[idx+1:] + yield sub_graph1, sub_graph2 + else: + # s == 1, take all + yield nodes, () + + +def DP(nodes: Tuple[IRLayerOp], k: int, s: int, intra_solver: Callable, + mbs: int, max_d: Optional[int] = None, max_t: Optional[int] = None, + _cost : Dict[Tuple, float] = None, + _config : Dict[Tuple, List[StageSpec]] = None, + _intra_cache = None) -> Tuple[Dict, Dict]: + """ + DP algorithm to search for balanced pipeline stage divisions by considering + tensor parallelism and pipeline parallelism. + + cost[D][k][s] = min_{D' \in D} min_{t, d where t*d<=k} max( + TPS(D\D',t,d,s), cost[D'][k-d*t][s-1] ) + + D: subgraph + K: number of devices + t: tensor parallelism size + d: data parallelism size + s: number of pipeline stages + + @param nodes Tuple[IRFwOperation]: sub-graph + @param k int: number of devices + @param s: number of pipeline stages + @param intra_solver: + which takes nodes, tensor parallelism size, data parallelism size + and in-flight number of microbatches, and outputs the + @param mbs: micro-batch size + @param max_d int: maximal data parallelism size constraint + @param max_t int: maximal tensor parallelism size constraint + + @return costs Dict[( (IRCell,), k, s ), latency] + @return config Dict[( (IRCell,), k, s ), [(IRCell,),] ] + """ + nodes = nodes if isinstance(nodes, tuple) else tuple(nodes) + key = (nodes, k, s) + + # initialize: dp[((), k, s)] = 0 for every k and s + _cost = dict() if _cost is None else _cost + _config = dict() if _config is None else _config + _intra_cache = dict() if _intra_cache is None else _intra_cache + max_d = k if max_d is None else max_d + max_t = k if max_t is None else max_t + if key in _cost: return _cost, _config + + # dp tatble boundary + if len(nodes) == 0: + _cost[key], _config[key] = 0, [] + return _cost, _config + + assert not (k == 0 or s == 0), \ + f"Illegal configuration: nodes: {len(nodes)} k={k}, s={s}: device number (k) cannot be smaller than pipeline stages (s)" + assert k >= s, f"Expected k >= s but got k={k}, s={s}" + + # True for 1,2,4,8,16,... + is_of_power2 = lambda n: (n & (n-1) == 0) and n != 0 + + # construct dynamic programming table + min_val = None # None means no solution + for sub1, sub2 in iter_subgraph(nodes, s): + for d in range(1, min(k + 1, max_d + 1)): + if mbs % d != 0: continue + for t in range(1, min(k // d + 1, max_t + 1)): + # constraints: all devices must be used + if s == 1 and d * t != k: continue + # only search for gpu# of power of 2 + if not is_of_power2(t * d): continue + # guarantee sub-problem searchable + if k - d * t < s - 1: continue + # sub2 cost + DP(sub2, k-d*t, s-1, intra_solver, mbs, max_d, max_t, + _cost, _config, _intra_cache) + sub2_cost = _cost[(sub2, k-d*t, s-1)] + if sub2_cost is None: continue + # sub1 cost: s is also the in-flight microbatch number + sub1_config = intra_solver(sub1, d, t, s, _cache=_intra_cache) + if sub1_config is None: continue + sub1_cost = sub1_config.est_latency + # pipeline cost + cost = max(sub1_cost, sub2_cost) + config = [sub1_config] + _config[(sub2, k-d*t, s-1)] + # update + if min_val is None or cost < min_val: + min_val = cost + _config[(nodes, k, s)] = config + + _cost[key] = min_val + return _cost, _config + + +def inter_op(nodes: Tuple[IRFwOperation], ndevs: int, intra_solver: Callable, mbs: int, + max_d: Optional[int]=None, max_t: Optional[int]=None, max_p: Optional[int]=None) -> ParallelSpec: + """ + DP algorithm to search for balanced pipeline stage divisions by considering + tensor parallelism and pipeline parallelism. + + @param nodes List[IRFwOperation]: graph + @param ndevs int: number of devices + @param intra_solver Callable: estimator + which takes nodes, tensor parallelism size, data parallelism size + and in-flight number of microbatches, and outputs of + cost (latency in ms) and config (intra-tp config) + @param mbs: micro-batch size + @param max_d int: maximal data parallelism size constraint + @param max_t int: maximal tensor parallelism size constraint + + @return best_config + """ + nodes: List[IRLayerOp] = cluster_to_layer_ops(nodes) + nodes = tuple(nodes) + print(f'> search [search]: constructing dp tables ({len(nodes)} layer ops)...') + tic = time.time() + max_d = mbs if max_d is None else max_d + max_d = min(max_d, mbs, ndevs) + max_t = ndevs if max_t is None else max_t + max_t = min(max_t, ndevs) + max_p = ndevs if max_p is None else min(max_p, ndevs) + max_p = min(len(nodes), max_p) + cost, config = None, None + for nstages in range(1, max_p+1): + cost, config = DP(nodes, ndevs, nstages, intra_solver, mbs, + max_d, max_t, cost, config) + print(f'> search [search]: getting optimal results...') + min_cost, best_config = None, None + for nstages in range(1, max_p+1): + tcost = cost[(nodes, ndevs, nstages)] + if tcost is None: continue + if min_cost is None or tcost < min_cost: + min_cost = tcost + best_config = config[(nodes, ndevs, nstages)] + assert best_config is not None, f"no solution" + toc = time.time() + span = toc - tic + print(f'> search [finish]: searching time: {span} s') + print(f'> search [result]: minimal latency per microbatch {min_cost} ms') + assert all(isinstance(config, StageSpec) for config in best_config) + spec = ParallelSpec(stages=best_config) + return spec diff --git a/examples/policies/alpa/intra_op.py b/examples/policies/alpa/intra_op.py new file mode 100644 index 00000000..e595f5b2 --- /dev/null +++ b/examples/policies/alpa/intra_op.py @@ -0,0 +1,230 @@ + +from typing import List, Tuple, Dict, Optional +import multiprocessing +import numpy as np +import warnings +import time + +from cube.ir.cten import IRTensor +from cube.ir.operator import IRFwOperation +from cube.graph.function.anchor import IRGraphAnchor + +from examples.policies.alpa.layer_op import IRLayerOp +from examples.policies.alpa.cost_model import CostModel +from examples.policies.alpa.plan import StageSpec + +# ILP solver +import pulp +from pulp import LpVariable, LpProblem, LpMinimize, LpStatus, lpSum, lpDot, LpStatus + + +def intra_op(layer_nodes: List[IRLayerOp], dp_size: int, tp_size: int, + inflights: int, recompute: bool, memory_limit: int, + cost_model: CostModel, _cache: Dict = None) -> Optional[StageSpec]: + """ + Search for the best intra-op parallelism configuration given device mesh. + The search is only suitable for training. + """ + key = (layer_nodes, dp_size, tp_size) + if isinstance(_cache, dict) and key in _cache: return _cache[key] + + tic = time.time() + + fnodes: List[IRFwOperation] = [] + for layer_op in layer_nodes: + for node in layer_op.nodes: + if isinstance(node, IRGraphAnchor) or node.name == 'multiref': continue + fnodes.append(node) + + # search for tp configuration + + # create variables (nodes) + s, d, c = {}, {}, {} # partition index, computation cost, communication cost + e, r = [], [] # inter-node resharding cost + + num_nodes = 0 + for fnode in fnodes: + cid = fnode.cid + npartitions = len(cost_model.partition_algos[fnode.cid]) + s[cid] = LpVariable.matrix(f's[{num_nodes}]', (range(npartitions),), cat='Binary') + d[cid] = cost_model.get_comp_cost(fnode, tp_size).flatten() / dp_size + c[cid] = cost_model.get_comm_cost(fnode, tp_size).flatten() / dp_size + # setup initial value + for pidx, strategy in enumerate(cost_model.partition_algos[fnode.cid]): + if strategy is None: continue + idx, dim = strategy + identifier = fnode.anno.input(idx)[dim].identifiers[0] + if fnode.anno.getlen(identifier) % (tp_size * dp_size) != 0: + # print(f'remove transform choice on {fnode.name}({fnode.cid}) ' + # f'of strategy: {strategy} for tp={tp_size}, dp={dp_size}') + s[cid][pidx].setInitialValue(False) + s[cid][pidx].fixValue() + num_nodes += 1 + + edges = cost_model.get_edges(fnodes) + num_edges = 0 + for src, dsts in edges.items(): + for dst in dsts: + nsrc = len(cost_model.partition_algos[src.cid]) + ndst = len(cost_model.partition_algos[dst.cid]) + e.append(LpVariable.matrix(f"e[{src.cid}, {dst.cid}]", + (range(nsrc * ndst),), + cat='Binary')) + r.append(cost_model.get_pair_reshard_cost(src, dst, tp_size).flatten()) + num_edges += 1 + + # initial value: --skip + + # objective + prob = LpProblem('intra_op', LpMinimize) + # computation cost + obj = 0 + for fnode in fnodes: + cid = fnode.cid + obj += lpDot(s[cid], c[cid]) + lpDot(s[cid], d[cid]) + # communication cost + for i in range(num_edges): + obj += lpDot(e[i], r[i]) + + prob += obj + + # constraints + + # a) only one partition can be selected + for fnode in fnodes: + prob += lpSum(s[fnode.cid]) == 1 + for i in range(num_edges): + prob += lpSum(e[i]) == 1 + + # e_src_dst[i][j] = 1 => s_src[i] == 1 and s_dst[j] == 1 + eidx = 0 + for src, dsts in edges.items(): + for dst in dsts: + for row in range(len(s[src.cid])): + C = len(s[dst.cid]) + prob += lpSum( + e[eidx][row * C + col] for col in range(0, C)) <= s[src.cid][row] + for col in range(len(s[dst.cid])): + R = len(s[src.cid]) + C = len(s[dst.cid]) + prob += lpSum( + e[eidx][row * C + col] for row in range(0, R)) <= s[dst.cid][col] + eidx += 1 + + # b) memory constraint --skip + + assert "PULP_CBC_CMD" in pulp.listSolvers(onlyAvailable=True), ( + "Please install ILP solvers by 'sudo apt install coinor-cbc' or 'pip install pulp'") + + time_limit = 600 + solver = pulp.PULP_CBC_CMD( + mip=True, msg=0, + timeLimit=time_limit, + threads=multiprocessing.cpu_count()) + prob.solve(solver) + + status = prob.status + objective = pulp.value(prob.objective) + objective = float(objective) if objective is not None else -1.0 + # print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}") + # print(f"#nodes: {num_nodes}, #edges: {num_edges}") + # print(f'ILP search time: {time.time() - tic:.2f} seconds') + + # reshard_cost = 0 + # for i in range(num_edges): + # reshard_cost += lpDot(e[i], r[i]) + # reshard_cost = pulp.value(reshard_cost) + # print(f'debug info: reshard cost: {reshard_cost}') + + if prob.status in [pulp.LpStatusInfeasible]: + raise RuntimeError("Cannot run the function under the given memory budget.") + + def get_non_zero_index(binary_vector): + """Get the index of non-zero item in a vector.""" + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + + assert ct == 1 + return ret + + tp_spec: Dict[int, int] = {} + for fnode in fnodes: + index = get_non_zero_index(s[fnode.cid]) + tp_spec[fnode.cid] = index + + # check results + e_val = np.full((num_edges,), -1, dtype=np.int32) + eidx = 0 + for (src, dsts) in edges.items(): + for dst in dsts: + e_val[eidx] = get_non_zero_index(e[eidx]) + src_spec_index = e_val[eidx] // len(s[dst.cid]) + dst_spec_index = e_val[eidx] % len(s[dst.cid]) + assert src_spec_index == tp_spec[src.cid] + assert dst_spec_index == tp_spec[dst.cid] + eidx += 1 + + if objective > 1e13: + warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") + + # estimate activation memory + non_recompute_mem = 0 + recompute_mem, curr_recomp_id = [0], None + for node in fnodes: + strat = cost_model.partition_algos[node.cid][tp_spec[node.cid]] + op_tp_size = 1 if strat is None else tp_size + node_mem = cost_model.get_memory_cost(node) // (dp_size * op_tp_size) + if node.recompute != curr_recomp_id: + recompute_mem.append(0) + curr_recomp_id = node.recompute + if node.recompute is None: + non_recompute_mem += node_mem + else: + recompute_mem[-1] += node_mem + act_memory = non_recompute_mem * inflights + max(recompute_mem) + + # estimate parameter memory + param_mem = 0 + pids = set() + for node in fnodes: + attrs = [t for t in node.inputs() if \ + isinstance(t, IRTensor) and t.is_attr()] + for attr in attrs: + if attr.tid in pids: continue + opt = 4 if attr.is_param() else 1 + # we estimate parameter size by assuming it will partition on weight + param_mem += opt * attr.byte_size() // tp_size + pids.add(attr.tid) + + # print(f'debug: inflights: {inflights}, act memory: {act_memory/1024/1024/1024}, param mem: {param_mem/1024/1024/1024}') + mem_cost = act_memory + param_mem + if mem_cost > memory_limit: + print(f'searching results of {len(tp_spec)} nodes: tp={tp_size}, dp={dp_size}: no solution (memory: {mem_cost/1024/1024/1024} GB)') + return None + + # get tensor parallelism spec + stage_tp_spec = {} + names = {} + for fnode in fnodes: + strategy = None if tp_size == 1 else \ + cost_model.partition_algos[fnode.cid][tp_spec[fnode.cid]] + stage_tp_spec[fnode.cid] = strategy + names[fnode.cid] = fnode.name + + config = StageSpec( + est_latency=objective / 3 * 4 if recompute else objective, + est_memory=mem_cost, + tp_size=tp_size, + dp_size=dp_size, + tp_spec=stage_tp_spec, + names=names, + ) + print(f'searching results of {len(stage_tp_spec)} nodes: tp={tp_size}, dp={dp_size} ' + f'latency={objective}, memory={mem_cost/1024/1024/1024} GB') + if isinstance(_cache, dict): _cache[key] = config + # print(config) + return config diff --git a/examples/policies/alpa/layer_op.py b/examples/policies/alpa/layer_op.py new file mode 100644 index 00000000..d87504af --- /dev/null +++ b/examples/policies/alpa/layer_op.py @@ -0,0 +1,42 @@ +from typing import List, Dict, Tuple +import more_itertools + +from cube.ir.cten import IRCell +from cube.ir.operator import IRFwOperation +from cube.graph.graph import IRGraph +from cube.graph.function.anchor import IRGraphAnchor + + +class IRLayerOp(IRCell): + + def __init__(self, nodes: List[IRCell], layer_id: int = None): + super().__init__('layer_op', 'layer_op', 0, 0, init_outputs=False) + self.nodes = nodes + self.layer_id : int = layer_id + + +def cluster_to_layer_ops(nodes: List[IRFwOperation]) -> List[IRLayerOp]: + layer_ops: List[IRLayerOp] = [] + ops = [] + for node in nodes: + if isinstance(node, IRGraphAnchor): + if len(ops) != 0: + layer_ops.append(IRLayerOp(ops, layer_id=len(layer_ops))) + ops = [node] + elif isinstance(node, IRFwOperation): + ops.append(node) + if len(ops) != 0: + layer_ops.append(IRLayerOp(ops, layer_id=len(layer_ops))) + return layer_ops + + +def annotate_structure(graph: IRGraph) -> List[Tuple[IRFwOperation]]: + """Annotate graph stucture in generated code""" + anchors = graph.select(ntype=IRGraphAnchor) + for idx, anchor in enumerate(anchors): + nidx = graph.index(anchor) + graph.node(nidx + 1).comment = f'===> split position {idx}: {anchor.name}' + fnodes = graph.select(ntype=IRFwOperation) + subgraphs = more_itertools.split_before(fnodes, lambda n: isinstance(n, IRGraphAnchor)) + return list(subgraphs) + \ No newline at end of file diff --git a/examples/policies/alpa/plan.py b/examples/policies/alpa/plan.py new file mode 100644 index 00000000..80e9212d --- /dev/null +++ b/examples/policies/alpa/plan.py @@ -0,0 +1,105 @@ +from typing import Dict, Tuple, Optional +from dataclasses import dataclass +import json + +from cube.ir.operator import IRFwOperation +from cube.graph.graph import IRGraph + +@dataclass +class StageSpec: + # estimation + est_latency: float # in milliseconds + est_memory: float # in types + # config + tp_size: int + dp_size: int + # node.cid -> (idx, num) | None + tp_spec: Dict[int, Optional[Tuple[int, int]]] + # node.cid -> node.name + names: Dict[int, str] + + def __repr__(self) -> str: + dscp = '' + for cid, strategy in self.tp_spec.items(): + strategy = 'Replicate' if strategy is None else f"idx={strategy[0]}, dim={strategy[1]}, num={self.tp_size}" + dscp += f' {self.names[cid]}({cid}): {strategy}\n' + return dscp + + def to_dict(self) -> Dict: + return { + 'est_latency': self.est_latency, + 'est_memory': self.est_memory, + 'tp_size': self.tp_size, + 'dp_size': self.dp_size, + 'tp_spec': self.tp_spec, + 'names': self.names + } + + @staticmethod + def from_dict(d: Dict): + tp_spec = {int(cid): spec for cid, spec in d['tp_spec'].items()} + names = {int(cid): name for cid, name in d['names'].items()} + return StageSpec( + est_latency=d['est_latency'], + est_memory=d['est_memory'], + tp_size=d['tp_size'], + dp_size=d['dp_size'], + tp_spec=tp_spec, + names=names + ) + + +@dataclass +class ParallelSpec: + stages: Tuple[StageSpec] + + @property + def est_latency(self) -> float: + return max(s.est_latency for s in self.stages) + + def save(self, filename: str): + """ + Save plan into json file + """ + with open(filename, 'w') as f: + json.dump([s.to_dict() for s in self.stages], f) + + def getstate(self) -> str: + """ + Get plan state as json string + """ + return json.dumps([s.to_dict() for s in self.stages]) + + @staticmethod + def loadstate(state: str): + """ + Load plan from json string + """ + stages = json.loads(state) + return ParallelSpec(tuple(StageSpec.from_dict(s) for s in stages)) + + @staticmethod + def load(filename: str, check_graph_consistent: IRGraph = None): + """ + Load plan from json file + """ + with open(filename, 'r') as f: + stages = json.load(f) + spec = ParallelSpec(tuple(StageSpec.from_dict(s) for s in stages)) + if check_graph_consistent is not None: + graph = check_graph_consistent + cid2name = {n.cid: n.name for n in graph.select(ntype=IRFwOperation)} + for stage in spec.stages: + for cid, name in stage.names.items(): + assert cid in cid2name, f'graph is not consistent with plan: node cid {cid}:{name} not found in graph' + assert cid2name[cid] == name, f'graph is not consistent with plan: cid {cid}:{name} name mismatch' + return spec + + def __repr__(self) -> str: + dscp = f'nstages: {len(self.stages)} | latency: {self.est_latency} ms' + for sidx, stage in enumerate(self.stages): + tp, dp = stage.tp_size, stage.dp_size + latency, memory = stage.est_latency, stage.est_memory / 1024 / 1024 / 1024 + dscp += f'\nStage {sidx} (tp={tp}, dp={dp}, latency={latency} ms, memory={memory}):\n' + dscp += f'{stage}' + return dscp From c9e95e4b3eddc066d24376b671c40b73bd74b507 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 May 2023 09:06:43 +0000 Subject: [PATCH 1391/1892] Merged PR 1560: fix communication cost fn fix communication cost fn --- cube/graph/gener/gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 152bd73a..f425435a 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -409,7 +409,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # generate adapter for each segment segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] for segment in segments: - IRAdapterGener.gen_activation(segment) + IRAdapterGener.gen_activation(segment, allow_recompute=allow_recompute, cost_fn=cost_fn) return graph From 27a1742d3d622324e9d589a07aa3c7f315e15f6d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 May 2023 01:43:40 +0000 Subject: [PATCH 1392/1892] Merged PR 1569: allow recompute on intermediate non-autograd op allow recompute on intermediate non-autograd op --- cube/graph/gener/gen.py | 4 ++-- cube/graph/graph.py | 24 +++++++++++++++++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index f425435a..77cd21af 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -388,10 +388,10 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: break graph.insert(fadapter, fidx) # setup recompute - if fadapter.differentiable and allow_recompute: + if allow_recompute: if fidx > 0: prev_node = graph.node(fidx-1) - if isinstance(prev_node, IRFwOperation): + if isinstance(prev_node, (IRFwOperation, IRAdapter)): fadapter.recompute = prev_node.recompute # insert backward adapter diff --git a/cube/graph/graph.py b/cube/graph/graph.py index f840a1b7..edf17f82 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -864,15 +864,29 @@ def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: assert all(segment == segments[0] for segment in segments), \ "Cross-segment recompute is not allowed yet" recompute_group_id: int = IDGenerator().gen_cell_id() + start = 0 for fnode in nodes: - if isinstance(fnode, IRGraphAnchor): + tensors = [t for t in fnode.inputs() if isinstance(t, IRSubTensor) and (not t.is_attr())] + if all(t.grad is None for t in tensors): + start += 1 continue - # pytorch limitation - if all(not t.requires_grad for t in fnode.inputs() if isinstance(t, IRSubTensor) and (not t.is_attr())): - print(f"skipping recompute node: {fnode}\n\tbecause all its input tensors doesn't require grad.") + break + skip = nodes[:start] + nodes = nodes[start:] + end = len(nodes) + for fnode in nodes[::-1]: + tensors = [t for t in fnode.inputs() if isinstance(t, IRSubTensor) and (not t.is_attr())] + if all(t.grad is None for t in tensors): + end -= 1 continue + break + skip += nodes[end:] + for node in skip: + if isinstance(node, IRGraphAnchor): continue + print(f"skip recompute node: {node.name} ({node.cid}) as it doesn't require gradient and appears at head or tail.") + nodes = nodes[:end] + for fnode in nodes: fnode.recompute = recompute_group_id - return True # =================== Helpers ==================== From 94873a6668c9e7daaab6bf7949eb4390c09c2a17 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 May 2023 01:45:55 +0000 Subject: [PATCH 1393/1892] Merged PR 1563: support grad accum and fix schedule bugs support grad accum and fix schedule bugs usage ```python accum_times = 4 for idx in range(accum_times): with cube.accum_mode(enable=(idx!=accum_times-1)): data = next(dataloader) train_iter(model, data) ``` --- cube/__init__.py | 4 ++++ cube/codegen/module/module.py | 33 ++++++++++++------------------- cube/flags.py | 7 +++++++ cube/graph/schedule/predefined.py | 16 +++++++++++++++ cube/graph/schedule/schedplan.py | 17 +++++++++++----- cube/runtime/adapter/reducer.py | 2 ++ cube/runtime/module.py | 9 ++++++--- cube/utils.py | 22 +++++++++++++++++++++ 8 files changed, 82 insertions(+), 28 deletions(-) diff --git a/cube/__init__.py b/cube/__init__.py index 4607f7d4..ecce5e1c 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,9 +1,13 @@ import warnings from cube import runtime + from cube import profiler +from cube.profiler.timer import CudaTimer from cube.compiler import SemanticModel, compile + from cube.utils import load_model, load_default_schedule, load_eval_schedule +from cube.utils import accum_mode def _check_torch_version(): diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index aa611b9c..5acb16bd 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -111,34 +111,27 @@ def gen(self, device: int, outfile=None, attach=False) -> str: node_args: List[List[str]] = list() gen_nodes: List[IRCell] = list() + unrolled_seqs = [] + for node in self.execplan.seq(device): + # unwrap from ExeReuseCell + node = node.cell if isinstance(node, ExeReuseCell) else node + unrolled_seqs.append(node) + # we use ordered dict as ordered set + sequence = tuple(dict.fromkeys(unrolled_seqs)) + # init customized adapter - for seg in [seg for seg in self.execplan.seq(device) if isinstance(seg, IRSegment)]: - for adapter in [n for n in seg.nodes() if isinstance(n, IRAdapter)]: - if adapter.isfw() and adapter.differentiable and adapter.custom: + fsegments = [node for node in sequence if isinstance(node, IRSegment) and node.isfw()] + for seg in fsegments: + for adapter in seg.select(ntype=IRAdapter): + if adapter.differentiable and adapter.custom: gencode += AutogradAdapterCodeGen().gen(adapter) + ['', ''] adapter.signature = AutogradAdapterCodeGen.name(adapter) + '.apply' # initialize communication groups self.init_comm_groups() - # parse graph body - unrolled_seqs = [] - - for node in self.execplan.seq(device): - # unwrap from ExeReuseCell and ExeRepetend - if isinstance(node, ExeReuseCell): - node = node.cell - if isinstance(node, ExeRepetend): - for node in node.nodes(): - if isinstance(node, ExeReuseCell): - node = node.cell - unrolled_seqs.append(node) - else: - unrolled_seqs.append(node) - # we use ordered dict as ordered set - unrolled_seqs = tuple(dict.fromkeys(unrolled_seqs)) # emit code - for node in unrolled_seqs: + for node in sequence: if isinstance(node, IRSegment): if not node.isfw(): continue # skip backward segment codes = self.emit_segment(node) diff --git a/cube/flags.py b/cube/flags.py index 9b7dceda..a5fbb8de 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -47,3 +47,10 @@ class CompileFlag: # and optimizer status are kept in its original data type (can be float32), # but some of the forward operators will be converted to float16. use_amp = _to_bool('USE_AMP') + + +class RuntimeFlag: + + # turn execution in accumulation mode + # where reducers will not allpy allreduce on gradients + accum_mode: bool = False diff --git a/cube/graph/schedule/predefined.py b/cube/graph/schedule/predefined.py index d35e7f4b..a2e81503 100644 --- a/cube/graph/schedule/predefined.py +++ b/cube/graph/schedule/predefined.py @@ -11,6 +11,22 @@ class PredefinedSched: + @staticmethod + def grad_accum(graph: IRGraph, num_microbatches: int) -> SchedulePlan: + """ + Gradient accumulation for SPMD scenario. + """ + segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + # describe schedule + sched = SchedulePlan(graph, num_microbatches) + step = 0 + for midx in range(num_microbatches): + for seg in segments: + sched.add_segment(seg, midx, step) + step += 1 + sched.finish() + return sched + @staticmethod def sched_1f1b(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: """ diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py index 115f8d73..0f45731a 100644 --- a/cube/graph/schedule/schedplan.py +++ b/cube/graph/schedule/schedplan.py @@ -84,11 +84,11 @@ def build(self): for segment in segments: if self.graph.depends(adapter, segment): assert adapter not in self.recvers, \ - f"Detected more than one segments to recv data from a same adapter" + f"Detected one adapter receives data from more than one segments" self.recvers[adapter] = segment elif self.graph.depends(segment, adapter): assert adapter not in self.senders, \ - f"Detected more than one segments to send data from a same adapter" + f"Detected one adapter {adapter} sends data to more than one segments" self.senders[adapter] = segment # get all weight reducers self.reducers = self.graph.select(ntype=IRWeightReducer, flatten=False) @@ -183,17 +183,21 @@ def _place_dataloader(self): """ Place dataloaders together with segments """ - # FIXME: this may not work for multiple segments in a same - # micro-batch require for the data + # insert dataloaders to its devices before the first required segment for dl in self._dependency.dataloaders: + inserted_mids = set() for step in range(self.nsteps): blocks = self.segments(step) for block in blocks: segment, mid = block.content, block.mid + if mid in inserted_mids: continue + if dl.device[0] not in segment.device: continue if self.graph.depends(dl, segment): dl_block = Block(dl, mid, 1) + # print(f'inserting microbatch {mid} at step {step} before {segment.name}{segment.cid}') self._step_segments[step+block.span-1].insert(0, dl_block) self._block_step[dl_block] = step+block.span-1 + inserted_mids.add(mid) break def topo_sort(self): @@ -401,8 +405,11 @@ def _place_adapters(self): Place adapters to make sure the communication happens correctly and efficiently. """ - assert len(self._dependency.adapters) > 0 for adapter in self._dependency.adapters: + assert adapter in self._dependency.senders, ( + f"Detected an adapter\n\t{adapter}\ndoesn't have a sender segment. " + f"This usually happens when its sender is dataloader or graph inputs." + f"Please replicate dataloader to remove this adapter.") sender: IRSegment = self._dependency.senders[adapter] # find sender step and insert adapter for step in range(self.nsteps): diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index b89c3099..559243cf 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -8,6 +8,7 @@ from cube.runtime.device import DeviceGroup from cube.profiler.timer import CudaTimer, print_each_rank +from cube.flags import RuntimeFlag def get_nbytes(dtype: torch.dtype) -> int: @@ -38,6 +39,7 @@ def allreduce(self): """ Reduce gradients across given group """ + if RuntimeFlag.accum_mode: return buckets = {} tp2size = {} for param in self._params: diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 3e4d9823..a43ee395 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -13,7 +13,7 @@ class CubeModule(torch.nn.Module): def __init__(self): super().__init__() - self._reducers = list() + self._reducers: List[Reducer] = list() self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() self._batch_size: Optional[int] = None @@ -22,9 +22,12 @@ def add_reducer(self, reducer: Reducer): raise RuntimeError(f"Expected a Reducer but got {type(reducer)}") self._reducers.append(reducer) - def sync_params(self): + def reduce_grads(self): + """ + Mannually allreduce gradients on the weight + """ for reducer in self._reducers: - reducer.sync() + reducer.allreduce() def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: int): """ diff --git a/cube/utils.py b/cube/utils.py index 623c6431..7128c6f1 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -2,6 +2,8 @@ from cube.profiler.timer import print_each_rank from cube.runtime.device import DeviceGroup +from cube.flags import RuntimeFlag + def _load_module_attr(filename: str, name: str): import importlib.util @@ -32,3 +34,23 @@ def load_eval_schedule(filename: Optional[str] = None): module = _load_module_attr(filename, '_infer_step') return module._infer_step + +class accum_mode: + """ + Make cube execution in accumulation mode, where weight + gradient allreduce will be skipped. + + need manually call `model.reduce_grads()` to reduce gradients + after finish accumulation, or make `enable=False` for the last + accumulation step. + """ + def __init__(self, enable: bool = True): + self.enable = enable + self.old = None + + def __enter__(self): + self.old = RuntimeFlag.accum_mode + RuntimeFlag.accum_mode = self.enable + + def __exit__(self, *args): + RuntimeFlag.accum_mode = self.old From 714e4198db6f42b356c5e234189cb642a465b99a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 9 May 2023 02:12:01 +0000 Subject: [PATCH 1394/1892] Merged PR 1534: allow merge inference adapter into segment allow merge inference adapter into segment --- cube/codegen/module/module.py | 3 +- cube/execplan/planpass/grouping.py | 5 +- cube/graph/segment.py | 114 +++++++++-------------------- 3 files changed, 38 insertions(+), 84 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 5acb16bd..602a71a3 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -417,8 +417,9 @@ def emit_segment(segment: IRSegment) -> List[str]: else: # get recompute excution code rc_segment = segment.create_segment(rc_group) + rc_lifetime = LifeCycle(rc_group, rc_segment.inputs(), rc_segment.outputs()) rc_codes = ModuleCodeGen._emit_recompute(rc_group, - rc_segment.inputs(), rc_segment.outputs(), lifetime) + rc_segment.inputs(), rc_segment.outputs(), rc_lifetime) codes += rc_codes # release input tensors after exiting a RC group: last_node = rc_group[-1] diff --git a/cube/execplan/planpass/grouping.py b/cube/execplan/planpass/grouping.py index bb0de6d5..3f6ea81d 100644 --- a/cube/execplan/planpass/grouping.py +++ b/cube/execplan/planpass/grouping.py @@ -65,8 +65,9 @@ def differentiable(fnode): if isinstance(fnode, IRAdapter): return False if isinstance(fnode, IRFwOperation): return True - if isinstance(fnode, IRAdapter) and fnode.differentiable and fnode.isfw(): - return True + if isinstance(fnode, IRAdapter) and fnode.isfw(): + if fnode.differentiable: return True + if fnode.mirror is None: return True # not require backward return False fgroups, bgroups = dict(), dict() diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 53c9bc3e..944ed326 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -920,107 +920,59 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: Create a segment with part of the nodes. This only return the created segment wihout modifying the graph. + Calling this requires that the dependencies are already materialized, + i.e., every input IRSubTensor should have a corresponding producer. + @param nodes List[IRCell]: the subset nodes of this graph @return segment IRSegment: the grouped segment. """ segment = self - segment_inputs = IRSegment.get_objects_from_complex(segment.inputs()) segment_outputs = IRSegment.get_objects_from_complex(segment.outputs()) - # segments: List[IRSegment] = [self.segment(node) for node in nodes] - # assert len(set(segments)) == 1, "Cross segment hierarchy grouping is not allowed" - # segment = segments[0] - - inputs, outputs = set(), set() - - # go through adapters - adapter_ins: Dict[IRSubTensor, Set[int]] = dict() - adapter_ous: Dict[IRSubTensor, Set[int]] = dict() - for adapter in nodes: - if not isinstance(adapter, IRAdapter): - continue + # setup adapter dependency + ad_consumers: Dict[Tuple[IRSubTensor,int], Set[int]] = dict() + ad_producers: Dict[Tuple[IRSubTensor,int], Set[int]] = dict() + for adapter in self.select(ntype=IRAdapter): for itensor in adapter.inputs(): if not isinstance(itensor, IRSubTensor): continue - if itensor not in adapter_ins: - adapter_ins[itensor] = set() - adapter_ins[itensor].update(itensor.device) - # producers can from out side node - producers = [] - for ptensor, prod in zip(segment.ptensors(itensor.parent), segment.producers(itensor.parent)): - if ptensor == itensor and set(itensor.device).issubset(set(prod.device)): - producers.append(prod) - if not any(p in nodes for p in producers): - inputs.add(itensor) + ad_consumers.setdefault((itensor, itensor.device[0]), set()).add(adapter.cid) for otensor in adapter.outputs(): if not isinstance(otensor, IRSubTensor): continue - if otensor not in adapter_ous: - adapter_ous[otensor] = set() - adapter_ous[otensor].update(otensor.device) - consumers = [] - for ctensor, cons in zip(segment.ctensors(otensor.parent), segment.consumers(otensor.parent)): - if ctensor == otensor and set(otensor.device).issubset(set(cons.device)): - consumers.append(cons) - if not any(c in nodes for c in consumers): - outputs.add(otensor) + ad_producers.setdefault((otensor, otensor.device[0]), set()).add(adapter.cid) - # go through non-adapter nodes + # tensor and its device match + dmatch = lambda t1, t2: t1 == t2 and t1.device == t2.device + + inputs, outputs = set(), set() + sub_cids = set(node.cid for node in nodes) for node in nodes: - if isinstance(node, IRAdapter): - assert node.differentiable, \ - "Non-differentiable IRAdapter is not allowed to be grouped" - continue - # update inputs - itensors = [t for t in node.inputs() if isinstance(t, IRObject)] - for itensor in itensors: - ftensor = itensor.parent + for itensor in node.inputs(): + if not isinstance(itensor, IRTensor): continue if itensor.is_attr(): continue - # from inside adapters - if itensor in adapter_ous: - if len(node.device) > 0 and set(itensor.device).issubset(adapter_ous[itensor]): - continue - # from segment inputs - if any(t.overlap(itensor) for t in segment_inputs if isinstance(t, IRObject)): - inputs.add(itensor) - continue - # from outside producers - producers, ptensors = segment.producers(ftensor), segment.ptensors(ftensor) - producers = [p for p, t in zip(producers, ptensors) if t == itensor] + producers, ptensors = self.producers(itensor.parent), self.ptensors(itensor.parent) + pids = set(p.cid for p, t in zip(producers, ptensors) if dmatch(t, itensor)) if len(itensor.device) > 0: - producers = [p for p in producers if set(itensor.device).issubset(set(p.device))] - # from graph inputs or outside adapter (no producer) - if len(producers) == 0 or any(p not in nodes for p in producers): + assert len(itensor.device) == 1 + pids.update(cid for cid in ad_producers.get((itensor, itensor.device[0]), [])) + # if no producers inside the nodes can produce data, set as input + if all(pid not in sub_cids for pid in pids): inputs.add(itensor) - continue - # update outputs - otensors = [t for t in node.outputs() if isinstance(t, IRObject)] - for otensor in otensors: - ftensor = otensor.parent - if otensor.is_attr(): continue - # from inside adapters - if otensor in adapter_ins: - if len(node.device) > 0 and set(otensor.device).issubset(adapter_ins[otensor]): - continue - # from segment outputs - if any(t.overlap(otensor) for t in segment_outputs if isinstance(t, IRObject)): + for otensor in node.outputs(): + if not isinstance(otensor, IRTensor): continue + # if the tensor is required by segment outputs or is loss during train, set as output + if otensor.is_loss() or otensor in segment_outputs: outputs.add(otensor) continue - # loss must be returned - if isinstance(ftensor, IRFullTensor) and ftensor.is_loss(): - outputs.add(otensor) - continue - if len(segment.consumers(ftensor)) == 0: - continue - # for outside consumers - consumers, ctensors = segment.consumers(ftensor), segment.ctensors(ftensor) - consumers = [c for c, t in zip(consumers, ctensors) if t == otensor] + consumers, ctensors = self.consumers(otensor.parent), self.ctensors(otensor.parent) + cids = set(c.cid for c, t in zip(consumers, ctensors) if dmatch(t, otensor)) if len(otensor.device) > 0: - consumers = [c for c in consumers if set(otensor.device).issubset(set(c.device))] - # for adapter (no consumer) - if len(consumers) == 0 or any(c not in nodes for c in consumers): + assert len(otensor.device) == 1 + cids.update(cid for cid in ad_consumers.get((otensor, otensor.device[0]), [])) + # if the tensor is required by other nodes outside the nodes, set as output + if any(cid not in sub_cids for cid in cids): outputs.add(otensor) - continue - + def order(tensors: Set[IRObject]) -> Tuple[IRObject]: """Reorder by logical tensor id. Temporally necessary for pipeline scheduling""" tensors = list(tensors) From 94e6d844f5d62433c7b08e9a0b2a00341fe74367 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 28 May 2023 01:17:19 +0000 Subject: [PATCH 1395/1892] Merged PR 1579: fix negative offset of transformation rules fix negative offset of transformation rules --- cube/algorithm/ops/dimops.py | 9 ++++++--- cube/graph/function/function.py | 12 +++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index f7972ec9..32e1b8fe 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -113,7 +113,6 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: if splits[idx].isV(): return True return False - def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List[IRDimops]]: @@ -178,8 +177,12 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR for r in node.transform_rules: splits = r.inputs() + r.outputs() if isinstance(dim, int): - if splits[idx] == DimopSplit.D(dim): - return r + if splits[idx].isD(): + # make negative offset to be possitive + ndims = len(node.input(idx).shape) + rdim = (splits[idx].dim + ndims) % ndims + if rdim == dim: + return r else: if splits[idx].isV(): return r diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index cf8a8aad..8f42334e 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -5,6 +5,7 @@ import operator import numpy as np import math +import warnings from cube.ir.cten import IRTensor, IRObject from cube.ir.tensor import IRSubTensor, IRFullTensor @@ -52,11 +53,12 @@ def Linear(input, weight, bias=None, signature = None): annos = ['b * k+, n k+ -> b * n'] return IRDimops(Linear, 'linear', signature, annos, [input, weight], bias=None) else: - annos = ['b * k+, n k+, n -> b * n'] - rules = [TransformRule( - [DimopSplit.D(-1), DimopSplit.D(1), DimopSplit.V()], [DimopSplit.V()] - )] - return IRDimops(Linear, 'linear', signature, annos, [input, weight, bias], rules) + annos = ['b * k^, n k^, n -> b * n'] + # rules = [TransformRule( + # [DimopSplit.D(-1), DimopSplit.D(1), DimopSplit.V()], [DimopSplit.V()] + # )] + warnings.warn('detected a linear operator has bias, the partition on reduction dimension is disabled.') + return IRDimops(Linear, 'linear', signature, annos, [input, weight, bias]) def BatchLinear(input, mat2, *, out=None, signature = None): From c61d1586d7ab683ab73ec9930a9324635fb5fe73 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Sun, 28 May 2023 08:29:07 +0000 Subject: [PATCH 1396/1892] Merged PR 1592: support exp, tril, @, bool support exp, tril, @, bool --- cube/graph/function/function.py | 42 +++++++++++++++++++++++++++++++++ cube/graph/parser/mappingfx.py | 7 +++--- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 8f42334e..e32a5ead 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -99,6 +99,11 @@ def EinSum(equation: str, *operands, signature = None): def Matmul(input, other, *, out=None, signature=None): + """ + torch.matmul + _operator.matmul + """ + signature = 'torch.matmul' assert out is None annos = [ 'm k+, k+ n -> m n', @@ -406,6 +411,18 @@ def Div(input, other, *, rounding_mode=None, out=None, signature = None): return IRDimops(Div, 'div', signature, annos, [input, other], rounding_mode=rounding_mode) +def Exp(input, *, out=None, signature=None): + """ + torch.exp(input, *, out=None) + """ + assert out is None + if not isinstance(input, IRTensor): + return torch.exp(input) + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(Exp, 'exp', signature, annos, [input]) + + def FloorDiv(input, other, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): @@ -561,6 +578,15 @@ def Float(input, memory_format=None, signature = None): return IRDimops(Float, 'float', signature, annos, [input]) +def Bool(input, memory_format=None, signature = None): + """ + torch.Tensor.bool(memory_format=torch.preserve_format) + """ + assert memory_format is None + annos = ['* -> *'] + return IRDimops(Bool, 'bool', signature, annos, [input]) + + def Fill(input, value, signature = None): """ torch.Tensor.fill_(value) @@ -1113,6 +1139,22 @@ def Triu(input, diagonal=0, *, out=None, signature = None): return IRDimops(Triu, 'triu', signature, [anno], [input], diagonal=diagonal) +def Tril(input, diagonal=0, *, out=None, signature=None): + """ + torch.tril(input, diagonal=0, *, out=None) + """ + assert out is None + assert isinstance(input, IRTensor) + edim_in = ShapeAnno.create_shape_str(input.shape) + assert len(edim_in) >= 2 + edim_in[-1] += '^' + edim_in[-2] += '^' + edim_ou = copy.copy(edim_in) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(Tril, 'tril', signature, [anno], [input], + diagonal=diagonal) + + def CumSum(tensor, dim, signature = None): """ out = torch.cumsum(tensor, dim) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 6033aecc..37e0af53 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -56,10 +56,12 @@ def exist(signature: str) -> bool: __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, + __ttemplate('exp'): function.Exp, __ttemplate('squeeze'): function.Squeeze, __ttemplate('unsqueeze'): function.Unsqueeze, __tttemplate('type_as'): function.TypeAs, __ttemplate('triu'): function.Triu, + __ttemplate('tril'): function.Tril, __ftemplate('relu'): function.ReLU, __fcntemplate('gelu'): function.GeLU, __ttemplate('eq') : function.CompareEQ, @@ -69,6 +71,7 @@ def exist(signature: str) -> bool: __tttemplate('long'): function.Long, __tttemplate('int'): function.Int, __tttemplate('float'): function.Float, + __tttemplate('bool'): function.Bool, __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, @@ -108,9 +111,7 @@ def exist(signature: str) -> bool: 'builtins.list': function.MakeList, # # torch nn functional - # - # __ftemplate('linear') : function.Linear, - # + '_operator.matmul': function.Matmul, __ttemplate('matmul'): function.Matmul, # # __ftemplate('gelu') : function.GeLU, From 08b4e57ec746044807a7a620d93abfb4271ec7d4 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 5 Jun 2023 05:45:05 +0000 Subject: [PATCH 1397/1892] Merged PR 1600: Merge work in deploy branch When review this pull request, please compare deploy_resolve_conflict with deploy. --- cube/codegen/lifecycle.py | 2 +- cube/codegen/module/module.py | 179 +++++++++++------- cube/codegen/schedule/schedule.py | 22 ++- cube/compiler.py | 9 +- cube/graph/function/function.py | 27 +++ .../concrete_trace_utils/concrete_tracer.py | 75 ++++---- cube/graph/parser/mappingfx.py | 3 + cube/ir/adapter/prim.py | 4 +- cube/profiler/database.py | 3 +- cube/runtime/adapter/reducer.py | 12 +- cube/runtime/module.py | 58 +++++- scripts/keep.py | 10 +- 12 files changed, 285 insertions(+), 119 deletions(-) diff --git a/cube/codegen/lifecycle.py b/cube/codegen/lifecycle.py index a76135d3..6fe19680 100644 --- a/cube/codegen/lifecycle.py +++ b/cube/codegen/lifecycle.py @@ -61,7 +61,7 @@ def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: # Generally we don't manually release those tensors since the enclosing function is about to # return, all local variables are automatically released. # But we do need to update the lifetime of all outputs, to avoid early releasing. - self.lifetime.update((tsout, i+1) for tsout in graph_outputs if is_activation(tsout)) + self.lifetime.update((tsout, len(nodes)) for tsout in graph_outputs if is_activation(tsout)) for tensor, line_id in self.lifetime.items(): self.release.setdefault(line_id, []).append(tensor) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 602a71a3..967ff75e 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -3,6 +3,7 @@ import warnings import copy import torch +import numpy as np from cube.ir.cten import IRCell from cube.ir.tensor import IRSubTensor @@ -74,9 +75,17 @@ class ModuleCodeGen(FuncEmission): ``` """ - def __init__(self, execplan: ExecutionPlan) -> None: + def __init__(self, execplan: ExecutionPlan, scale_ndevs: Optional[int] = None) -> None: + """ + Create Module code generator + + @param execplan ExecutionPlan + @param scale_ndevs Optional[int]: scale to number of devices + """ + super().__init__() self.execplan: ExecutionPlan = execplan + self.devices: Tuple[int] = tuple(sorted(execplan.graph.device)) self.init_code: List[str] = [ '\n\n########## Generated Model Code ###########', @@ -102,6 +111,89 @@ def __init__(self, execplan: ExecutionPlan) -> None: self._ref_module = torch.nn.Module() # batch size self.batch_size = None + # communication groups + self.comm_groups: List[Tuple[int]] = self.get_comm_groups(scale_ndevs) + # whether to scale (with data parallelism) + self._scale_to_ndevs = scale_ndevs + self._scale_reducers = [] + + def get_comm_groups(self, scale_ndevs: Optional[int] = None): + """ + Scale the communication groups to multiple devices + using data parallelism. + + @warn this requires user side to setup dataloader + for different GPUs + + @param scale_ndevs Optional[int]: scale to number of devices + """ + scale_ndevs = scale_ndevs if scale_ndevs is not None else len(self.devices) + assert len(self.devices) == max(self.devices) + 1, f'device must be consecutive' + assert scale_ndevs % len(self.devices) == 0, f'ngpus must be a multiple of {len(self.devices)}' + nreplica = scale_ndevs // len(self.devices) + # scale communication groups + graph = self.execplan.graph + comm_groups = [] + reducers: List[IRWeightReducer] = graph.select(ntype=IRWeightReducer) + for reducer in reducers: + ranks = np.array(tuple(sorted(reducer.device)), dtype=int) + for i in range(nreplica): + shifted_ranks = tuple(ranks + i * len(self.devices)) + shifted_ranks = tuple(int(rank) for rank in shifted_ranks) + if shifted_ranks not in comm_groups: + comm_groups.append(shifted_ranks) + adapters = graph.select(ntype=IRAdapter) + for adapter in adapters: + for prim in adapter.prims: + if isinstance(prim, CollectivePrim): + ranks = np.array(tuple(sorted(prim.kwargs['ranks'])), dtype=int) + for i in range(nreplica): + shifted_ranks = tuple(ranks + i * len(self.devices)) + shifted_ranks = tuple(int(rank) for rank in shifted_ranks) + if shifted_ranks not in comm_groups: + comm_groups.append(shifted_ranks) + # allreduce all gradients + if nreplica > 1: + for rank in self.devices: + ranks = tuple(range(rank, scale_ndevs, len(self.devices))) + if ranks not in comm_groups: + comm_groups.append(ranks) + return comm_groups + + def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: + assert len(self.devices) == max(self.devices) + 1, f'device must be consecutive' + assert ndevs % len(self.devices) == 0, f'ngpus must be a multiple of {len(self.devices)}' + shift = (device // len(self.devices)) * len(self.devices) + if isinstance(node, IRAdapter): + adapter = copy.copy(node) + adapter._id = node.cid + adapter.kwargs = dict(**node.kwargs) + prims = [] + for prim in adapter.prims: + p = copy.copy(prim) + p.kwargs = dict(**prim.kwargs) + if 'ranks' in prim.kwargs: + p.kwargs['ranks'] = [rank + shift for rank in prim.kwargs['ranks']] + if 'src' in prim.kwargs: + p.kwargs['src'] = prim.kwargs['src'] + shift + if 'srcs' in prim.kwargs: + p.kwargs['srcs'] = [src + shift for src in prim.kwargs['srcs']] + if 'dst' in prim.kwargs: + p.kwargs['dst'] = prim.kwargs['dst'] + shift + if 'dsts' in prim.kwargs: + p.kwargs['dsts'] = [dst + shift for dst in prim.kwargs['dsts']] + prims.append(p) + adapter.prims = prims + if node.isfw() and node.differentiable and node.custom: + badapter = self.scale(node.mirror, ndevs, device) + IRCell.make_pair(adapter, badapter) + return adapter + if isinstance(node, IRSegment) and node.isfw(): + nodes = [self.scale(n, ndevs, device) for n in node.nodes()] + segment = IRSegment(nodes, node.inputs(), node.outputs(), node.name) + segment._id = node.cid + return segment + return node def gen(self, device: int, outfile=None, attach=False) -> str: """ @@ -111,13 +203,21 @@ def gen(self, device: int, outfile=None, attach=False) -> str: node_args: List[List[str]] = list() gen_nodes: List[IRCell] = list() + device_map = device if self._scale_to_ndevs is None else \ + device % len(self.devices) + sequence = self.execplan.seq(device_map) unrolled_seqs = [] - for node in self.execplan.seq(device): + for node in sequence: # unwrap from ExeReuseCell node = node.cell if isinstance(node, ExeReuseCell) else node unrolled_seqs.append(node) # we use ordered dict as ordered set sequence = tuple(dict.fromkeys(unrolled_seqs)) + + # scale to multiple devices + if self._scale_to_ndevs is not None: + sequence = [self.scale(node, self._scale_to_ndevs, device) \ + for node in sequence] # init customized adapter fsegments = [node for node in sequence if isinstance(node, IRSegment) and node.isfw()] @@ -128,7 +228,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: adapter.signature = AutogradAdapterCodeGen.name(adapter) + '.apply' # initialize communication groups - self.init_comm_groups() + self.emit_comm_groups() # emit code for node in sequence: @@ -140,7 +240,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: elif isinstance(node, IRAdapter): codes = self.emit_adapter(node, prefix_attr='self.') elif isinstance(node, IRWeightReducer): - self.init_reducer(node) + self.init_reducer(node, device) codes = self.emit_reducer(node) elif isinstance(node, IRBpOperation): continue @@ -210,78 +310,18 @@ def gen(self, device: int, outfile=None, attach=False) -> str: self.clear() return code - def init_comm_groups(self): - """ - Get all communication groups. - - Creating communication group requires all the devices - enter the same call. - - The fields storing intermediate codes that are populated by this method: - - `model_init_statements` + def emit_comm_groups(self): """ - graph = self.execplan.graph - sign = 'self.init_group(ranks={ranks})' - # collect groups from weight reducer - comm_groups: Dict[Tuple[int]] = list() - for node in graph.nodes(): - if isinstance(node, IRWeightReducer): - ranks = list(node.device) - ranks.sort() - ranks = tuple(ranks) - if ranks not in comm_groups: - comm_groups.append(ranks) - # collect groups from p2p fusion - adapters = [n for n in graph.nodes(flatten=True) if isinstance(n, IRAdapter)] - for adapter in adapters: - for prim in adapter.prims: - if isinstance(prim, CollectivePrim): - ranks = list(prim.kwargs['ranks']) - ranks.sort() - ranks = tuple(ranks) - if ranks not in comm_groups: - comm_groups.append(ranks) - # create communication group - self.model_init_statements.append('# communication groups') - for ranks in comm_groups: - code = sign.format(ranks=list(ranks)) - self.model_init_statements.append(code) - self.model_init_statements.append(' ') - - def init_comm_groups(self): - """ - Get all communication groups. - Creating communication group requires all the devices enter the same call. The fields storing intermediate codes that are populated by this method: - `model_init_statements` """ - graph = self.execplan.graph sign = 'self.init_group(ranks={ranks})' - # collect groups from weight reducer - comm_groups: Dict[Tuple[int]] = list() - for node in graph.nodes(): - if isinstance(node, IRWeightReducer): - ranks = list(node.device) - ranks.sort() - ranks = tuple(ranks) - if ranks not in comm_groups: - comm_groups.append(ranks) - # collect groups from p2p fusion - adapters = [n for n in graph.nodes(flatten=True) if isinstance(n, IRAdapter)] - for adapter in adapters: - for prim in adapter.prims: - if isinstance(prim, CollectivePrim): - ranks = list(prim.kwargs['ranks']) - ranks.sort() - ranks = tuple(ranks) - if ranks not in comm_groups: - comm_groups.append(ranks) # create communication group self.model_init_statements.append('# communication groups') - for ranks in comm_groups: + for ranks in self.comm_groups: code = sign.format(ranks=list(ranks)) self.model_init_statements.append(code) self.model_init_statements.append(' ') @@ -332,7 +372,7 @@ def init_attributes(self, node: IRCell): self.init_attributes(sub_node) return - def init_reducer(self, node: IRWeightReducer) -> None: + def init_reducer(self, node: IRWeightReducer, device: int) -> None: """ Emit code to initialize involved reducer objects in `__init__`. @@ -348,7 +388,10 @@ def init_reducer(self, node: IRWeightReducer) -> None: weights = node.inputs() reducer_name = f'self.wreducer{node._id}' self.model_init_statements.append('') - init_code = reducer_init.format(reducer=reducer_name, ranks=node.device, max_nbytes=max_nbytes) + ranks = list(sorted(node.device)) + shift = (device // len(self.devices)) * len(self.devices) + ranks = [r + shift for r in ranks] + init_code = reducer_init.format(reducer=reducer_name, ranks=ranks, max_nbytes=max_nbytes) self.model_init_statements.append(init_code) weights = [ModuleCodeGen.tensor_name(t, prefix_attr='self.') for t in weights] for weight in weights: diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index 081e42af..e81d6b99 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -1,5 +1,5 @@ -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional, Tuple import copy import warnings @@ -21,22 +21,32 @@ class ScheduleCodeGen(FuncEmission): - def __init__(self, execplan: ExecutionPlan): + def __init__(self, execplan: ExecutionPlan, scale_ndevs: Optional[int] = None): + """ + Create Module code generator + @param execplan ExecutionPlan + @param scale_ndevs Optional[int]: scale to number of devices + """ self.execplan = execplan + self.devices: Tuple[int] = tuple(sorted(execplan.graph.device)) # model full code self.init_code: List[str] = [ '\n\n########## Generated Schedule Code ###########', 'import torch', 'import cube', ''] # module member name self.symbols = SymbolTable() + self._scale_to_ndevs = scale_ndevs def gen(self, device: int, outfile=None, attach=None) -> str: """ Generate scheduling code on device """ gencode = copy.copy(self.init_code) - device_nodes: List[IRCell] = self.execplan.seq(device) + device_map = device if self._scale_to_ndevs is None else \ + device % len(self.devices) + device_nodes = self.execplan.seq(device_map) + assert all(not isinstance(n, IRFwOperation) for n in device_nodes), \ "Expected all forward operators have been grouped into IRSegment" @@ -63,7 +73,11 @@ def gen(self, device: int, outfile=None, attach=None) -> str: tensors = lifetime.release_tensors_after_line(line) if len(tensors) > 0 : # not necessarily to have one after each line fb.insert_body(ScheduleCodeGen.emit_release(tensors)) - + # scale sync gradients + if self.execplan.graph.train and self._scale_to_ndevs is not None: + ranks = tuple(range(device_map, self._scale_to_ndevs, len(self.devices))) + if len(ranks) > 1: + fb.insert_body(f'model.reduce_all_gradients({ranks})') # return code outputs = ScheduleCodeGen.return_name_complex(self.execplan.graph.outputs()) code = f'return {outputs}' diff --git a/cube/compiler.py b/cube/compiler.py index ad4e9449..2926f028 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -32,7 +32,10 @@ def compile(model: SemanticModel, *args, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, model_dummy_inputs: Tuple[Any] = None, - comm_cost_fn: Optional[Callable] = None, override = True, load_content = True) -> Callable: + comm_cost_fn: Optional[Callable] = None, + override = True, + load_content = True, + scale_ndevs: Optional[int] = None) -> Callable: """ AI Scientist calls like: @@ -203,8 +206,8 @@ def decorator(fn: Callable) -> Callable: start = time.time() local_world_size = DeviceGroup().local_world_size # code generation - mgener = ModuleCodeGen(execplan) - sgener = ScheduleCodeGen(execplan) + mgener = ModuleCodeGen(execplan, scale_ndevs=scale_ndevs) + sgener = ScheduleCodeGen(execplan, scale_ndevs=scale_ndevs) for local_rank in range(local_world_size): rank = DeviceGroup().node_rank * local_world_size + local_rank fname = filename.format(rank) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index e32a5ead..4f45b5c2 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -378,6 +378,7 @@ def Add(input, other, alpha=1, *, out=None, signature = None): def Sub(input, other, alpha=1, *, out=None, signature = None): assert out is None + signature = 'torch.sub' if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input - alpha * other annos = ['*, ? -> *', '?, * -> *',] @@ -436,6 +437,18 @@ def FloorDiv(input, other, *, out=None, signature = None): return IRDimops(FloorDiv, 'floordiv', signature, annos, [input, other]) +def Exp(input, *, out=None, signature=None): + """ + torch.exp(input, *, out=None) + """ + assert out is None + if not isinstance(input, IRTensor): + return torch.exp(input) + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(Exp, 'exp', signature, annos, [input]) + + def Pow(input, exponent, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(exponent, IRObject)): @@ -490,6 +503,18 @@ def ReLU(input, inplace=False, signature = None): return IRDimops(ReLU, 'relu', signature, annos, [input], inplace=inplace) +def Abs(input, *, out=None, signature = None): + assert out is None + annos = ['* -> *'] + return IRDimops(Abs, 'abs', signature, annos, [input]) + + +def Clamp(input, min=None, max=None, *, out=None, signature = None): + assert out is None + annos = ['* -> *'] + return IRDimops(Clamp, 'clamp', signature, annos, [input], min=min, max=max) + + def Softmax(input, dim=None, _stacklevel=3, dtype=None, signature = None): """ torch.nn.functional.softmax(input, dim=None, _stacklevel=3, dtype=None) @@ -1114,6 +1139,8 @@ def Unsqueeze(input, dim, signature = None): """ edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = copy.copy(edim_in) + if dim == -1: + dim = len(edim_ou) edim_ou.insert(dim, '1') anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input],dim=dim) diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index 6ec2f930..b79de5cd 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -321,7 +321,8 @@ def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: D fn = _orig_getattr(self_obj, target) if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - result = OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) + # result = OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) + result = fn(*args_tail, **kwargs) # quick fix from yanjun result = OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) if _orig_isinstance(result, torch.Tensor): result_cpu = result.cpu() del result @@ -1563,42 +1564,42 @@ def f(x, y): operator_patch_backlist=operator_patch_backlist, forward_function_name=forward_function_name, ) - graph = tracer.trace(root, - autowrap_leaf_function = autowrap_leaf_function, - autowrap_leaf_class = autowrap_leaf_class, - leaf_module = leaf_module, - fake_middle_class = fake_middle_class, - concrete_args=concrete_args, - use_operator_patch=use_operator_patch, - operator_patch_backlist=operator_patch_backlist, - forward_function_name=forward_function_name, - ) - graph_check = tracer.trace(root, - autowrap_leaf_function = autowrap_leaf_function, - autowrap_leaf_class = autowrap_leaf_class, - leaf_module = leaf_module, - fake_middle_class = fake_middle_class, - concrete_args=concrete_args, - use_operator_patch=use_operator_patch, - operator_patch_backlist=operator_patch_backlist, - forward_function_name=forward_function_name, - ) - # compare to check equal - assert len(graph.nodes) == len(graph_check.nodes), f'number nodes: {len(graph.nodes)} vs {len(graph_check.nodes)}' - for node_a, node_b in zip(graph.nodes, graph_check.nodes): - node_a: Node - node_b: Node - target_a = node_a.target - target_b = node_b.target - if node_a.op == 'get_attr' and node_a.name.startswith('_tensor_constant'): - assert node_b.op == 'get_attr' and node_b.name.startswith('_tensor_constant') - assert torch.equal(getattr(root, node_a.name), getattr(root, node_b.name)) - elif node_a.op == 'call_function' and isinstance(target_a, Callable) and target_a.__name__ == 'apply' and\ - hasattr(target_a, '__self__') and issubclass(target_a.__self__, torch.autograd.Function): - assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\ - hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function) - else: - assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}' + # graph = tracer.trace(root, + # autowrap_leaf_function = autowrap_leaf_function, + # autowrap_leaf_class = autowrap_leaf_class, + # leaf_module = leaf_module, + # fake_middle_class = fake_middle_class, + # concrete_args=concrete_args, + # use_operator_patch=use_operator_patch, + # operator_patch_backlist=operator_patch_backlist, + # forward_function_name=forward_function_name, + # ) + # graph_check = tracer.trace(root, + # autowrap_leaf_function = autowrap_leaf_function, + # autowrap_leaf_class = autowrap_leaf_class, + # leaf_module = leaf_module, + # fake_middle_class = fake_middle_class, + # concrete_args=concrete_args, + # use_operator_patch=use_operator_patch, + # operator_patch_backlist=operator_patch_backlist, + # forward_function_name=forward_function_name, + # ) + # # compare to check equal + # assert len(graph.nodes) == len(graph_check.nodes), f'number nodes: {len(graph.nodes)} vs {len(graph_check.nodes)}' + # for node_a, node_b in zip(graph.nodes, graph_check.nodes): + # node_a: Node + # node_b: Node + # target_a = node_a.target + # target_b = node_b.target + # if node_a.op == 'get_attr' and node_a.name.startswith('_tensor_constant'): + # assert node_b.op == 'get_attr' and node_b.name.startswith('_tensor_constant') + # assert torch.equal(getattr(root, node_a.name), getattr(root, node_b.name)) + # elif node_a.op == 'call_function' and isinstance(target_a, Callable) and target_a.__name__ == 'apply' and\ + # hasattr(target_a, '__self__') and issubclass(target_a.__self__, torch.autograd.Function): + # assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\ + # hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function) + # else: + # assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}' with MagicMethodPatcher(): name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 37e0af53..888a4877 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -56,13 +56,16 @@ def exist(signature: str) -> bool: __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, + __ttemplate('abs'): function.Abs, __ttemplate('exp'): function.Exp, + __ttemplate('clamp'): function.Clamp, __ttemplate('squeeze'): function.Squeeze, __ttemplate('unsqueeze'): function.Unsqueeze, __tttemplate('type_as'): function.TypeAs, __ttemplate('triu'): function.Triu, __ttemplate('tril'): function.Tril, __ftemplate('relu'): function.ReLU, + __ftemplate('silu'): function.SiLU, __fcntemplate('gelu'): function.GeLU, __ttemplate('eq') : function.CompareEQ, '_operator.eq': function.CompareEQ, diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index b29e3232..eab15dbd 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -393,7 +393,9 @@ def volume(self) -> int: Use ring-based communication cost """ ndevs = len(self.inputs()) - return (ndevs - 1) * self.input(0).nelement() // ndevs + # FIXME: temporally disable reduce scatter in code generation + # which has parity issues for now. + return 100 * (ndevs - 1) * self.input(0).nelement() // ndevs def __repr__(self) -> str: return f'{self.outputs()} = reduce_scatter[{self.device}]({self.inputs()})' diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 1d54a60d..40c80e07 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -252,7 +252,8 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b fw_span, bw_span, infer_memory, train_mem_info = \ CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) except: - fw_span, bw_span, infer_memory, train_mem_info = float('inf'), float('inf'), 0, [] + print(f'WARNING: fail to profile {node}') + fw_span, bw_span, infer_memory, train_mem_info = 0, 0, 0, [] # log to database key = self._serialize(node) self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 559243cf..a4866154 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -24,13 +24,18 @@ def get_nbytes(dtype: torch.dtype) -> int: class Reducer: - def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912): + def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, use_mean=False): self._params: List[torch.nn.Parameter] = list() # note this need to be called for every device self.ranks = ranks self._group = DeviceGroup().get_group(ranks) self.bucket_size = max_bucket_size_bytes + self.use_mean = use_mean + + @property + def params(self): + return self._params def add_param(self, param: torch.nn.Parameter): self._params.append(param) @@ -66,7 +71,10 @@ def allreduce(self): for bucket in buckets[tp]: grads = [param.grad.data for param in bucket] coalesced = self._flatten_dense_tensors(grads) - torch.distributed.all_reduce(coalesced, group=self._group) + torch.distributed.all_reduce( + coalesced, + op=torch.distributed.ReduceOp.AVG if self.use_mean else torch.distributed.ReduceOp.SUM, + group=self._group) all_synced = self._unflatten_dense_tensors(coalesced, grads) for grad, synced in zip(grads, all_synced): grad.copy_(synced, non_blocking=True) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index a43ee395..b28c2bf6 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -17,6 +17,10 @@ def __init__(self): self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() self._batch_size: Optional[int] = None + @property + def reducers(self): + return self._reducers + def add_reducer(self, reducer: Reducer): if not isinstance(reducer, Reducer): raise RuntimeError(f"Expected a Reducer but got {type(reducer)}") @@ -43,6 +47,16 @@ def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: i assert hasattr(self, attr), f"{attr} is not in the module" self._fullmap[attr] = (tid, slicers, val_chunks) + def reduce_all_gradients(self, ranks: Tuple[int]): + """ + reduce gradients for the whole model. + This can only be used for data parallel + """ + reducer = Reducer(ranks, use_mean=True) + for parameter in self.parameters(): + reducer.add_param(parameter) + reducer.allreduce() + def get_full_map(self): return self._fullmap @@ -89,7 +103,7 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref }, filename) @staticmethod - def merge_partial_states(state_dicts): + def merge_partial_states(state_dicts, zero_idx_maps=None): """ :param state_dicts: list of state_dict from different ranks state_dict(model_state_dict, optimizer_state_dict, dist_param_map, param_area_map) @@ -99,6 +113,48 @@ def merge_partial_states(state_dicts): if len(state_dicts) == 1: return state_dicts[0][0], state_dicts[0][1] + # at first, merge the partitioned optimizer states due to zero to the zero-disabled format + if zero_idx_maps is not None: + def _check_opt_state(opt_state): + cnt = 0 + sorted_opt_state = {} + for idx in sorted(opt_state.keys()): + assert cnt == idx, f'opt state error: {idx} vs {cnt}, in {opt_state.keys()}' + sorted_opt_state[idx] = opt_state[idx] + cnt += 1 + return sorted_opt_state + optimizer_state_dict = {} + worker_cnt = len(state_dicts) + opt_state_list = [] + for work_idx in range(worker_cnt): + zero_idx2model_idx, model_idx2zero_idx, zero_rank_groups = zero_idx_maps[work_idx] + opt_state = {} + # first place local opt state to right index + if len(zero_idx2model_idx) == 0: + assert len(state_dicts[work_idx][1]['state']) == 0 + for local_idx, val in state_dicts[work_idx][1]['state'].items(): # worker / last_optimizer_state / state + print(f'{work_idx}, {local_idx}') + global_idx = zero_idx2model_idx[local_idx] + assert global_idx not in opt_state + opt_state[global_idx] = val + # for each rank group, copy opt state from other buckets + for rank_group, param_idx_buckets in zero_rank_groups.items(): + for bucket_idx, rank in enumerate(rank_group): + if rank == work_idx: continue + for global_idx in param_idx_buckets[bucket_idx]: + other_local_idx = zero_idx_maps[rank][1][global_idx] # rank / model_idx2zero_idx / global_idx + assert global_idx not in opt_state + opt_state[global_idx] = state_dicts[rank][1]['state'][other_local_idx] # worker / last_optimizer_state / state / local idx + opt_state = _check_opt_state(opt_state) + opt_state_list.append(opt_state) + assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' + # assign opt_state to state_dicts, cannot be assigned in the above loop + opt_state_len = len(opt_state_list[0]) + for work_idx in range(worker_cnt): + state_dicts[work_idx][1]['state'] = opt_state_list[work_idx] + state_dicts[work_idx][1]['param_groups'][0]['params'] = sorted(opt_state_list[work_idx].keys()) + assert len(opt_state_list[work_idx]) == opt_state_len + # find tensor full shape param_max_dimsize = {} for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: diff --git a/scripts/keep.py b/scripts/keep.py index b580a83e..d45a02fc 100644 --- a/scripts/keep.py +++ b/scripts/keep.py @@ -6,9 +6,17 @@ import re def get_gpu_util(rank): + from shutil import which + smi = None + if which('nvidia-smi') is not None: + smi = 'nvidia-smi' + elif which('rocm-smi') is not None: + smi = 'rocm-smi' + else: + raise Exception('Cannot find either nvidia-smi or rocm-smi!') cmds = [ - 'nvidia-smi', + smi, '-i', str(rank), ] From e40a46b7e4a9325446cb843f774e948a428dcb94 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 6 Jun 2023 06:29:29 +0000 Subject: [PATCH 1398/1892] Merged PR 1605: add info in profiler --- cube/profiler/database.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 40c80e07..9dc22f01 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -99,6 +99,7 @@ def run_step(func, tensors, kwargs, backward: bool): infer_memory = mtoc - mtic train_mem_info = [] + train_mem2in_idx = [] used_tensor = set() # ref torch/utils/checkpoint.py/_checkpoint_without_reentrant def pack_hook(x): @@ -109,6 +110,12 @@ def pack_hook(x): for dim in list(x.size()): byte_size = byte_size * dim train_mem_info.append(byte_size) + idx = -1 + for i, t in enumerate(tensors): + if t.storage().data_ptr() == x.storage().data_ptr(): + idx = i + break + train_mem2in_idx.append(idx) return x def unpack_hook(x): @@ -144,7 +151,7 @@ def unpack_hook(x): fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds bw_span = max(fwbw_span - fw_span, 0.0) - return fw_span, bw_span, infer_memory, tuple(train_mem_info) + return fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx class ProfileDataBase: @@ -249,26 +256,28 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b # run profiling try: - fw_span, bw_span, infer_memory, train_mem_info = \ + fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = \ CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) except: print(f'WARNING: fail to profile {node}') - fw_span, bw_span, infer_memory, train_mem_info = 0, 0, 0, [] + fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = 0, 0, 0, [], [] # log to database key = self._serialize(node) - self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) + self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span,\ + infer_memory, train_mem_info, residual_mem, train_mem2in_idx) print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | fw: {round(fw_span, 2)} ms | " - f"bw: {round(bw_span, 2)} ms | infer mem: {infer_memory} | train mem info: {train_mem_info}") + f"bw: {round(bw_span, 2)} ms | infer mem: {infer_memory} | train mem info: {train_mem_info} | idx: {train_mem2in_idx}") if isinstance(device, int): torch.cuda.set_device(orig_device) - return tuple(in_mem_info), tuple(param_mem_info), fw_span, bw_span, infer_memory, train_mem_info, residual_mem + return tuple(in_mem_info), tuple(param_mem_info), fw_span, bw_span, infer_memory, \ + tuple(train_mem_info), residual_mem, tuple(train_mem2in_idx) def insert(self, name: str, key: str, in_mem_info: Tuple[int], param_mem_info: Tuple[int], fw_span: float, bw_span: float, infer_memory: int, train_mem_info: Tuple[int], - residual_mem: int): + residual_mem: int, train_mem2in_idx: Tuple[int]): """ log the span of a function name with key @@ -284,7 +293,7 @@ def insert(self, name: str, key: str, in_mem_info: Tuple[int], param_mem_info: T assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = (in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem) + self._data[name][key] = (in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem, train_mem2in_idx) def exist(self, node: IRFwOperation) -> bool: """ @@ -301,7 +310,7 @@ def exist(self, node: IRFwOperation) -> bool: return False return True - def query(self, node: IRFwOperation) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int]]: + def query(self, node: IRFwOperation) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int], int, Tuple[int]]: """! Get the performance number of a node in IRGraph @@ -321,7 +330,7 @@ def query(self, node: IRFwOperation) -> Tuple[Tuple[int], Tuple[int], float, flo return None return self._data[node.signature][key] - def query_func(self, signature, shapes, dtypes) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int]]: + def query_func(self, signature, shapes, dtypes) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int], int, Tuple[int]]: """ Get performance number of given name (signature), shapes and dtypes From 9803c4c297cbd8163dda353c8980a7ba3a37cc0e Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 7 Jun 2023 07:23:10 +0000 Subject: [PATCH 1399/1892] Merged PR 1609: support operator: sqrt --- cube/graph/function/function.py | 12 ++++++++++++ cube/graph/parser/mappingfx.py | 1 + 2 files changed, 13 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 4f45b5c2..140c0e55 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -424,6 +424,18 @@ def Exp(input, *, out=None, signature=None): return IRDimops(Exp, 'exp', signature, annos, [input]) +def Sqrt(input, *, out=None, signature=None): + """ + torch.sqrt(input, *, out=None) + """ + assert out is None + if not isinstance(input, IRTensor): + return torch.sqrt(input) + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(Sqrt, 'sqrt', signature, annos, [input]) + + def FloorDiv(input, other, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 888a4877..ba1480f8 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -58,6 +58,7 @@ def exist(signature: str) -> bool: __ttemplate('sum'): function.Sum, __ttemplate('abs'): function.Abs, __ttemplate('exp'): function.Exp, + __ttemplate('sqrt'): function.Sqrt, __ttemplate('clamp'): function.Clamp, __ttemplate('squeeze'): function.Squeeze, __ttemplate('unsqueeze'): function.Unsqueeze, From c7c92af4878498bac57b91a4c427ecc0e255c74e Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 14 Jun 2023 11:37:40 +0000 Subject: [PATCH 1400/1892] Merged PR 1622: Support new model on cube TODO: check logic of pre div ahead of reduce-sum --- cube/codegen/emit.py | 2 +- cube/codegen/module/module.py | 89 ++++-- cube/codegen/schedule/schedule.py | 5 - cube/compiler.py | 8 +- cube/flags.py | 13 +- cube/program.py | 7 +- cube/runtime/adapter/reducer.py | 502 +++++++++++++++++++++++++----- cube/runtime/device.py | 19 +- cube/runtime/module.py | 176 ++++++++--- cube/utils.py | 9 +- tests/codegen/test_scale.py | 202 ++++++++++++ tests/runtime/test_reducer.py | 189 +++++++++++ 12 files changed, 1064 insertions(+), 157 deletions(-) create mode 100644 tests/codegen/test_scale.py create mode 100644 tests/runtime/test_reducer.py diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 965ed86b..4a3429dd 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -220,7 +220,7 @@ def emit_reducer(node: IRWeightReducer) -> List[str]: - NONE """ reducer_name = f'self.wreducer{node._id}' - code = f'{reducer_name}.allreduce()' + code = f'{reducer_name}.sync_grads()' return [code] @staticmethod diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 967ff75e..45e69e47 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -1,5 +1,5 @@ from typing import Dict, List, Optional, Tuple -from more_itertools import split_when +import more_itertools import warnings import copy import torch @@ -115,7 +115,40 @@ def __init__(self, execplan: ExecutionPlan, scale_ndevs: Optional[int] = None) - self.comm_groups: List[Tuple[int]] = self.get_comm_groups(scale_ndevs) # whether to scale (with data parallelism) self._scale_to_ndevs = scale_ndevs - self._scale_reducers = [] + if scale_ndevs is not None: + self.add_scale_reducers() + + def add_scale_reducers(self): + """ + Insert reducers to for scale scenario + """ + if self._scale_to_ndevs is None: + return + graph = self.execplan.graph + # for each device, collect parameters in the all reducers and create a reducer for the rest + for device in self.devices: + # collect parameters in the all reducers belonging to this device + all_params = set() + for reducer in graph.select(ntype=IRWeightReducer): + if device not in reducer.device: continue + for param in reducer.inputs(): + assert param not in all_params, \ + f'detected a parameter {param} in multiple reducers on device {device}' + all_params.update(reducer.inputs()) + # create a reducer for the rest parameters used for this device + rest_params = [] + for param in self.execplan.graph.attributes(): + if not param.is_param(): continue + for ctensor in graph.ctensors(param): + if device not in ctensor.device: continue + if ctensor not in all_params and ctensor not in rest_params: + rest_params.append(ctensor) + if len(rest_params) == 0: + continue + # create reducer and append to the execution + reducer = IRWeightReducer(rest_params) + reducer.device = device # will be scaled in `self.scale` + self.execplan.at(device).append(reducer) def get_comm_groups(self, scale_ndevs: Optional[int] = None): """ @@ -134,14 +167,19 @@ def get_comm_groups(self, scale_ndevs: Optional[int] = None): # scale communication groups graph = self.execplan.graph comm_groups = [] + # communication groups for parameters that are in reducers reducers: List[IRWeightReducer] = graph.select(ntype=IRWeightReducer) for reducer in reducers: - ranks = np.array(tuple(sorted(reducer.device)), dtype=int) - for i in range(nreplica): - shifted_ranks = tuple(ranks + i * len(self.devices)) - shifted_ranks = tuple(int(rank) for rank in shifted_ranks) - if shifted_ranks not in comm_groups: - comm_groups.append(shifted_ranks) + ranks = more_itertools.flatten(list(range(device, scale_ndevs, len(self.devices))) \ + for device in reducer.device) + ranks = tuple(sorted(ranks)) + comm_groups.append(ranks) + # communication groups for parameters that are outside reducers + for device in self.devices: + ranks = list(range(device, scale_ndevs, len(self.devices))) + if len(ranks) > 1: + comm_groups.append(ranks) + # communication groups for activations adapters = graph.select(ntype=IRAdapter) for adapter in adapters: for prim in adapter.prims: @@ -152,12 +190,6 @@ def get_comm_groups(self, scale_ndevs: Optional[int] = None): shifted_ranks = tuple(int(rank) for rank in shifted_ranks) if shifted_ranks not in comm_groups: comm_groups.append(shifted_ranks) - # allreduce all gradients - if nreplica > 1: - for rank in self.devices: - ranks = tuple(range(rank, scale_ndevs, len(self.devices))) - if ranks not in comm_groups: - comm_groups.append(ranks) return comm_groups def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: @@ -188,6 +220,15 @@ def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: badapter = self.scale(node.mirror, ndevs, device) IRCell.make_pair(adapter, badapter) return adapter + if isinstance(node, IRWeightReducer): + reducer = IRWeightReducer(node.inputs(), name=node.name) + reducer._id = node.cid + ranks = list(node.device) + scale_ranks = [] + for rank in ranks: + scale_ranks += list(range(rank, ndevs, len(self.devices))) + reducer.device = sorted(scale_ranks) + return reducer if isinstance(node, IRSegment) and node.isfw(): nodes = [self.scale(n, ndevs, device) for n in node.nodes()] segment = IRSegment(nodes, node.inputs(), node.outputs(), node.name) @@ -380,8 +421,15 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: - `model_init_statements` """ max_nbytes = CompileFlag.max_reducer_bucket + async_op = CompileFlag.async_reducer + zero = CompileFlag.use_zero + reduce_op = f"'{CompileFlag.reducer_op}'" # reducer init interface - reducer_init = '{reducer} = cube.runtime.adapter.Reducer(ranks={ranks}, max_bucket_size_bytes={max_nbytes})' + reducer_init = ( + "{reducer} = cube.runtime.adapter.Reducer(" + "ranks={ranks}, reduce_op={reduce_op}, " + "async_op={async_op}, zero={zero}, max_bucket_size_bytes={max_nbytes})" + ) reducer_add = 'self.add_reducer({reducer})' add_param = '{reducer}.add_param({weight})' # create reducer in declare region @@ -389,9 +437,9 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: reducer_name = f'self.wreducer{node._id}' self.model_init_statements.append('') ranks = list(sorted(node.device)) - shift = (device // len(self.devices)) * len(self.devices) - ranks = [r + shift for r in ranks] - init_code = reducer_init.format(reducer=reducer_name, ranks=ranks, max_nbytes=max_nbytes) + init_code = reducer_init.format( + reducer=reducer_name, ranks=ranks, reduce_op=reduce_op, + async_op=async_op, zero=zero, max_nbytes=max_nbytes) self.model_init_statements.append(init_code) weights = [ModuleCodeGen.tensor_name(t, prefix_attr='self.') for t in weights] for weight in weights: @@ -449,7 +497,7 @@ def emit_segment(segment: IRSegment) -> List[str]: nodes : List[IRCell] = segment.nodes() lifetime = LifeCycle(nodes, segment.inputs(), segment.outputs()) rc_groups: List[List[IRCell]] = list( - split_when(nodes, lambda prev, curr: prev.recompute != curr.recompute)) + more_itertools.split_when(nodes, lambda prev, curr: prev.recompute != curr.recompute)) codes: List[str] = [] for rc_group in rc_groups: @@ -460,9 +508,8 @@ def emit_segment(segment: IRSegment) -> List[str]: else: # get recompute excution code rc_segment = segment.create_segment(rc_group) - rc_lifetime = LifeCycle(rc_group, rc_segment.inputs(), rc_segment.outputs()) rc_codes = ModuleCodeGen._emit_recompute(rc_group, - rc_segment.inputs(), rc_segment.outputs(), rc_lifetime) + rc_segment.inputs(), rc_segment.outputs(), lifetime) codes += rc_codes # release input tensors after exiting a RC group: last_node = rc_group[-1] diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index e81d6b99..d0be4374 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -73,11 +73,6 @@ def gen(self, device: int, outfile=None, attach=None) -> str: tensors = lifetime.release_tensors_after_line(line) if len(tensors) > 0 : # not necessarily to have one after each line fb.insert_body(ScheduleCodeGen.emit_release(tensors)) - # scale sync gradients - if self.execplan.graph.train and self._scale_to_ndevs is not None: - ranks = tuple(range(device_map, self._scale_to_ndevs, len(self.devices))) - if len(ranks) > 1: - fb.insert_body(f'model.reduce_all_gradients({ranks})') # return code outputs = ScheduleCodeGen.return_name_complex(self.execplan.graph.outputs()) code = f'return {outputs}' diff --git a/cube/compiler.py b/cube/compiler.py index 2926f028..4bc38621 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -35,7 +35,7 @@ def compile(model: SemanticModel, *args, comm_cost_fn: Optional[Callable] = None, override = True, load_content = True, - scale_ndevs: Optional[int] = None) -> Callable: + scale: Union[bool, int] = False) -> Callable: """ AI Scientist calls like: @@ -65,6 +65,9 @@ def train_step(model, dataloader): generated code, i.e., the policy won't take effect. Default true. @param load_content bool: If true, will load parameter from exsiting saved models. Otherwise, will initial model parameters with empty tensor. + @param scale Union[bool, int]: If true, will scale the generated code to the + total launched number of GPUs. If int, will scale to the specified number. + Default False, no scaling. @return sched_fn Callable: the scheduling function loaded from generated code. """ @@ -206,6 +209,9 @@ def decorator(fn: Callable) -> Callable: start = time.time() local_world_size = DeviceGroup().local_world_size # code generation + scale_ndevs = None + if scale: + scale_ndevs = resource.ngpus if isinstance(scale, bool) else scale mgener = ModuleCodeGen(execplan, scale_ndevs=scale_ndevs) sgener = ScheduleCodeGen(execplan, scale_ndevs=scale_ndevs) for local_rank in range(local_world_size): diff --git a/cube/flags.py b/cube/flags.py index a5fbb8de..5c88b273 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -40,8 +40,17 @@ class CompileFlag: dev_mode = _to_bool('SINGLE_DEV_MODE') # allow to use python xx.py async_comm = _to_bool('ASYNC_COMM') - # maximal reducer weight bytes for one allreduce - max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=6e8) + # ============== reducer ================== + # use zero optimization on optimizer status. + # to cooperate with zero, user needs to call `model.parameters_for_optimizer()` + # to get parameters for optimizer, and `model.gather_params()` after `optimizer.step()` + use_zero = _to_bool('USE_ZERO') + # use async communication to overlap gradient synchronization and backward computation + async_reducer = _to_bool('ASYNC_REDUCER') # use async reducer + # maximal reducer weight bytes for one allreduce (only effective for async): default 128MB + max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=137217728) + # perform reducer op on gradients, can be sum, avg, mean, max, min. Default is sum + reducer_op = os.environ.get('REDUCER_OP', default='sum') # use automate mixture precision training, where weights, gradients # and optimizer status are kept in its original data type (can be float32), diff --git a/cube/program.py b/cube/program.py index f669b887..d9dc4595 100644 --- a/cube/program.py +++ b/cube/program.py @@ -200,11 +200,14 @@ def load_module(self, filename: str): spec = importlib.util.spec_from_file_location("GenModel", filename) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - self._loaded_module = module.GenModel().cuda() + self._loaded_module: CubeModule = module.GenModel().cuda() + # load parameter content if self.save_content: print_each_rank("> loading parameter content...") - # TODO: make hardcode ./fullmodel.pt programmable self._loaded_module.load_attr_content('./fullmodel.pt') + # initialize reducer + for reducer in self._loaded_module.reducers: + reducer.build_buckets() def get_gen_module(self) -> Optional[torch.nn.Module]: return self._loaded_module diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index a4866154..94b68499 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -1,122 +1,480 @@ -""" -Borrowed from Megatron Implementation -""" - -from typing import List -import torch +from typing import List, Dict, Tuple, Any, Callable, Optional +from functools import partial import warnings +import torch +from torch.utils.hooks import RemovableHandle from cube.runtime.device import DeviceGroup -from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.timer import CudaTimer from cube.flags import RuntimeFlag -def get_nbytes(dtype: torch.dtype) -> int: - try: - if dtype.is_floating_point(): - return torch.finfo(dtype).bits // 8 +def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: + """ + Get reduce op from string + """ + reduce_op = reduce_op.lower() # to lower case + supported = ['sum', 'avg', 'mean', 'min', 'max'] + if reduce_op == 'sum': + return torch.distributed.ReduceOp.SUM + elif reduce_op == 'avg' or reduce_op == 'mean': + return torch.distributed.ReduceOp.AVG + elif reduce_op == 'min': + return torch.distributed.ReduceOp.MIN + elif reduce_op == 'max': + return torch.distributed.ReduceOp.MAX + raise KeyError(f"Unsupported reduce op {reduce_op}. Supported reduce op: {supported}") + + +class Bucket: + + # config: whether to use allreduce for zero (default: reduce-scatter) + use_allreduce_for_zero: bool = False + + def __init__(self, params: List[torch.nn.Parameter], + param_buffer: torch.Tensor, grad_buffer: torch.Tensor, + reduce_op: torch.distributed.ReduceOp, + group, async_op: bool, zero: bool): + """ + Create a communication unit for parameter allreduce. + + One allreduce will be called for all gradients associated to the parameters. + The parameters are assumed to participate in backward and generate gradient. + + @param params List[torch.nn.Parameter]: the parameters + @param param_buffer torch.Tensor: Paramter contiguous buffer + @param grad_buffer torch.Tensor: gradient contiguous buffer + @param reduce_op torch.distributed.ReduceOp: the reduce op used by collectives + @param group: communication group + @param async_op bool: whether to use asynchronous operation + @param zero bool: whether to use zero optimization on gradients + """ + + self._params: List[torch.nn.Parameter] = params + self._pofset: Dict[torch.nn.Parameter, int] = {} + self._reduce_op = reduce_op + self._group = group + self._wsz: int = torch.distributed.get_world_size(group=self._group) + self._cnt = 0 + self._work = None # communication handle + self._hooks: List[Tuple[Any, RemovableHandle]] = [] + + self._async: bool = async_op + self._zero: bool = zero + self._contiguous_params = param_buffer + self._contiguous_grads = grad_buffer + assert grad_buffer.size() == param_buffer.size() + assert grad_buffer.size(0) % self._wsz == 0, "internal error: buffer size not chunkable" + # the parameter exposed for optimizer + self._param_for_optimizer: torch.nn.Parameter = None + # total number of parameters + self._numel: int = sum(p.numel() for p in self._params) + self._padding: int = self._contiguous_grads.size(0) - self._numel + + # only async will enable contiguous gradient + self.build() + self.register_hooks() + + @property + def numel(self) -> int: + """total number of parameters in the bucket""" + return self._numel + + @property + def params(self) -> Tuple: + """Parameter list""" + return self._params + + @property + def zero(self) -> bool: + """Whether enable zero for this bucket""" + return self._zero + + def build(self): + """ + Build offset for each parameter + This should only be called once during the construction of bucket. + """ + self._numel = sum(p.numel() for p in self._params) + ofst = 0 + for param in self._params: + self._pofset[param] = ofst + ofst += param.numel() + # build parameter for optimizer (shared storage). + # Its gradient will be updated everytime calling `self.sync_grads()` + if not self._zero: + opt = self._contiguous_params[:self._numel] + else: + rank = torch.distributed.get_rank(group=self._group) + opt = self._contiguous_params.chunk(self._wsz)[rank] + if rank == self._wsz - 1 and self._padding != 0: + opt = opt[:-self._padding] + self._param_for_optimizer = torch.nn.Parameter(opt) + + def register_hooks(self): + """ + Register post-backward hook to each paramter + + The post-backward will change the generated gradient from `.grad` to `self._contiguous_grads`. + The `.grad` will always keep as None until the finish of allreduce sync. + After allreduce sync, each parameter will be reset by its `.grad` attribute, which + shares the same storage from `self._contiguous_grads`. + + This should only be called once during the construction of bucket. + """ + + @torch.no_grad() + def post_grad_hook(param: torch.nn.Parameter, *unused): + # stream = DeviceGroup().get_stream('reducer') + ofst = self._pofset[param] + # due to **unknown** reasons, multi-stream computation has incorrect results + # with torch.cuda.stream(stream): # async update to overlap backward computation + # TODO: need to handle sparse gradients in torch.nn.Embedding + # print('yizhu1', param.requires_grad, param.size(), param.data_ptr) + self._contiguous_grads[ofst:ofst+param.numel()].add_(param.grad.data.view(-1)) + param.grad = None + + if RuntimeFlag.accum_mode: return + + self._cnt += 1 + assert self._cnt <= len(self._params), \ + "detected double backward for a weight (not supported), or not use `model.zero_grad()` after optimizer" + + # perform all-reduce + if self._async and self._cnt == len(self._params): + # wait until all gradients are accumulated in the gradient buffer + # stream.synchronize() + if self._zero: # zero will use reduce-scatter (default) or allreduce on gradient shards + if Bucket.use_allreduce_for_zero: + self._work = torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, + group=self._group, async_op=True) + else: + rank = torch.distributed.get_rank(group=self._group) + shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) + self._work = torch.distributed.reduce_scatter( + shards[rank], shards, op=self._reduce_op, + group=self._group, async_op=True) + else: + self._work = torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, + group=self._group, async_op=True) + + for param in self._params: + # same trick with FSDP and Megatron + # reference: https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/fully_sharded_data_parallel.py#L3177-L3188 + param_tmp = param.expand_as(param) + # gets its AccumulateGrad object. + grad_acc = param_tmp.grad_fn.next_functions[0][0] + hook = grad_acc.register_hook(partial(post_grad_hook, param)) + # grad_acc must keep, otherwise the hook won't take effect + self._hooks.append((grad_acc, hook)) + + def sync_grads(self): + """ + Wait until allreduce finished (async), or perform allreduce (sync). + + The `.grad` attribute for each parameter will also be set after + the completion of allreduce. + """ + rank = torch.distributed.get_rank(group=self._group) + # async + if self._async: + if CudaTimer().enabled and CudaTimer().predefined: + warnings.warn(f'CudaTimer: the communication time of async ' + f'reducer will not be recorded in `comm`') + assert self._work is not None + self._work.wait() + else: + if self._reduce_op == torch.distributed.ReduceOp.SUM: + self._contiguous_grads.div_(torch.distributed.get_world_size(group=self._group)) + CudaTimer().start('comm', predefined=True) + if self._zero: + if Bucket.use_allreduce_for_zero: + torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, group=self._group) + else: + shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) + torch.distributed.reduce_scatter( + shards[rank], shards, op=self._reduce_op, group=self._group) + else: + torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, group=self._group) + CudaTimer().stop('comm', predefined=True) + # grads = self._contiguous_grads.clone() + for param in self._params: + assert param.grad is None + pofst = self._pofset[param] + param.grad = self._contiguous_grads[pofst:pofst+param.numel()].view(param.size()) + # the following two methods can make `.grad` isolate with `self._contiguous_grads`, + # thereby enabling torch.zero_ inside the reducer without requiring user to modify. + # param.grad = grads[pofst:pofst+param.numel()].view(param.size()) + # param.grad = self._contiguous_grads[pofst:pofst+param.numel()].clone().view(param.size()) + + # setup gradient for optimizer parameters + if self._zero: + grad = self._contiguous_grads.chunk(self._wsz, dim=0)[rank] + if rank == self._wsz - 1 and self._padding != 0: + grad = grad[:-self._padding] + self._param_for_optimizer.grad = grad else: - return torch.iinfo(dtype).bits // 8 - except Exception as e: - warnings.warn(f'Cannot figure out bytes of dtype: {dtype}, set default as 4.') - return 4 + self._param_for_optimizer.grad = self._contiguous_grads[:self._numel] + + def gather_params(self): + """ + All-gather parameters + """ + assert self._zero, "gathering paramters is only for zero optimization." + rank = torch.distributed.get_rank(group=self._group) + shards = list(self._contiguous_params.chunk(self._wsz, dim=0)) + torch.distributed.all_gather(shards, shards[rank], group=self._group) + + def reset(self): + """Reset status.""" + self._cnt = 0 + self._work = None class Reducer: - def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, use_mean=False): + def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, + reduce_op: str = 'sum', async_op: bool = False, zero: bool = False): + """ + Create a reducer applied on a set of weights for weight reduction + + This assumes the communication group is already created by every rank. + @param ranks List[int]: reducer communication group + @param max_bucket_size_bytes int: largest bucket size for one-time communication, + only work for asynchronous reducer. + @param reduce_op str: reduce operation, can be 'sum', 'avg', 'max' or 'min' (default 'sum') + @param async_op bool: whether to overlap with backward computation (default False) + @param zero bool: whether to apply zero optimization on gradients + """ self._params: List[torch.nn.Parameter] = list() - # note this need to be called for every device - self.ranks = ranks + self._numel: int = 0 + self._ranks = ranks self._group = DeviceGroup().get_group(ranks) - self.bucket_size = max_bucket_size_bytes - self.use_mean = use_mean + self._bucket_size: Optional[int] = max_bucket_size_bytes if async_op else None + self._reduce_op = _get_reduce_op(reduce_op) + # buckets stands for a transission unit + self._buckets: List[Bucket] = list() + self._async: bool = async_op + self._zero: bool = zero + # contiguous parameter buffer and gradient buffer + self._contiguous_params: torch.Tensor = None + self._contiguous_grads: torch.Tensor = None + # hooks + self._hooks: List[Callable] = [] @property - def params(self): - return self._params + def params(self) -> Tuple[torch.nn.Parameter]: + return tuple(self._params) + + @property + def ranks(self) -> Tuple[int]: + return tuple(self._ranks) + + @property + def numel(self) -> int: + """Total number of parameters""" + return self._numel + + @property + def zero(self) -> bool: + """Whether to apply zero optimization on gradients""" + return self._zero + + @property + def buckets(self) -> Tuple[Bucket]: + return tuple(self._buckets) def add_param(self, param: torch.nn.Parameter): + """ + Add a parameter to the reducer + + The reducer assumes the ordering of added parameter + is consistent with forward order. Otherwise, the overlapping + will show less benefits. + + @param param torch.nn.Parameter: the added parameter + """ self._params.append(param) + self._numel += param.numel() - def allreduce(self): + def build_buckets(self): """ - Reduce gradients across given group + Build buckets the reducer. + + The parameters in each bucket have consistent data types, + and each bucket contains at least one parameter. + If the bucket contains more than 2 parameters, than the total size is samller + than the max_bucket_size_bytes. """ - if RuntimeFlag.accum_mode: return + # step 1: build bucket for overlapping gradient synchronization + bucket_size = self._numel * 8 + 1 if self._bucket_size is None else self._bucket_size buckets = {} - tp2size = {} + dtype2size = {} for param in self._params: - if param.requires_grad and param.grad is not None: + if param.requires_grad: cur_byte_size = param.nelement() * param.element_size() tp = param.data.type() if tp not in buckets: buckets[tp] = [[param]] - tp2size[tp] = cur_byte_size + dtype2size[tp] = cur_byte_size else: - if cur_byte_size > self.bucket_size: - warnings.warn(f'find one parameter {param.shape} ({cur_byte_size} bytes) larger than bucket size {self.bucket_size}') + if cur_byte_size > bucket_size: + warnings.warn(f'find one parameter {param.shape} ({cur_byte_size} bytes) larger than bucket size {self._bucket_size}') buckets[tp].insert(0, [param]) - elif tp2size[tp] + cur_byte_size <= self.bucket_size: - tp2size[tp] = tp2size[tp] + cur_byte_size + elif dtype2size[tp] + cur_byte_size <= bucket_size: + dtype2size[tp] = dtype2size[tp] + cur_byte_size buckets[tp][-1].append(param) else: - tp2size[tp] = cur_byte_size + dtype2size[tp] = cur_byte_size buckets[tp].append([param]) + seq_buckets: List[List[torch.nn.Parameter]] = [] + for dtype in buckets: + if not self._async: + assert len(buckets[dtype]) == 1, \ + f"internal error: synchronized reducer only needs one bucket, but got {len(buckets[dtype])}" + for bucket in buckets[dtype]: + seq_buckets.append(bucket) - # for each bucket, do all-reduce - CudaTimer().start(field_name='comm', predefined=True) - for tp in buckets: - for bucket in buckets[tp]: - grads = [param.grad.data for param in bucket] - coalesced = self._flatten_dense_tensors(grads) - torch.distributed.all_reduce( - coalesced, - op=torch.distributed.ReduceOp.AVG if self.use_mean else torch.distributed.ReduceOp.SUM, - group=self._group) - all_synced = self._unflatten_dense_tensors(coalesced, grads) - for grad, synced in zip(grads, all_synced): - grad.copy_(synced, non_blocking=True) + # step 2: build meta data for the offset of each bucket + # the start of each bucket will be padded to the next multiple of `len(self.ranks)` + buffer_length: int = 0 + starts, stops = [], [] + for params in seq_buckets: + starts.append(buffer_length) + numel = sum(p.numel() for p in params) + padding = len(self._ranks) - numel % len(self._ranks) + buffer_length += numel + padding + stops.append(buffer_length) + + # step3: allocate memory + # gradient buffer + self._contiguous_grads: torch.Tensor = torch.zeros( + (buffer_length,), dtype=self._params[0].dtype, + device=torch.cuda.current_device(), requires_grad=False) + # parameter buffer + self._contiguous_params: torch.Tensor = torch.empty( + (buffer_length,), dtype=self._params[0].dtype, + device=torch.cuda.current_device(), requires_grad=False) + + # step 4: build buckets + buckets: List[Bucket] = [] + for params, start, stop in zip(seq_buckets, starts, stops): + # replace underlying parameter content using shared storage from parameter + ofst = start + for param in params: + with torch.no_grad(): + self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) + param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) + ofst += param.numel() + # initialize buckets + bucket = Bucket( + params, + self._contiguous_params[start:stop], + self._contiguous_grads[start:stop], + self._reduce_op, + self._group, + self._async, + self._zero, + ) + buckets.append(bucket) + torch.cuda.empty_cache() + # make it in reverse order as the backward happens from tail to head + self._buckets: List[Bucket] = list(reversed(buckets)) + assert len(self._buckets) > 0, ( + f"Find {len(self._params)} parameters in the reducer. " + f"Make sure adding all parameters before building buckets") + + def sync_grads(self): + """ + synchronize gradients using allreuce (non-zero) or reduce-scatter (zero) + """ + if RuntimeFlag.accum_mode: return + for bucket in self._buckets: + bucket.sync_grads() + self._apply_post_hooks() + + def gather_params(self): + """ + Gather parameters + """ + if RuntimeFlag.accum_mode: return + assert self._zero, "gathering paramters is only for zero optimization." + for bucket in self._buckets: + bucket.gather_params() + + def zero_grad(self): + """ + Make gradient to be zero. This needs to be called + after `optimizer.step()` if `optmizer.zero_grad(set_to_none=True)`. + """ + if RuntimeFlag.accum_mode: return torch.cuda.synchronize() - CudaTimer().stop(field_name='comm', predefined=True) + self._contiguous_grads.zero_() + for bucket in self._buckets: + bucket.reset() + bucket._param_for_optimizer.grad = None + for param in self.params: + param.grad = None + + def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: + """ + Get parameters for optimizers + """ + params = [] + for bucket in self._buckets: + params.append(bucket._param_for_optimizer) + return params - def sync(self): + def broadcast_params(self): """ - Sync parameters before training + broadcast parameters before training """ for param in self._params: torch.distributed.broadcast(param, self.ranks[0], group=self._group) torch.cuda.synchronize() - def _flatten_dense_tensors(self, tensors): + def register_post_hook(self, fn: Callable): """ - Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of - same dense type. + Register a post hook function after gradient update. - Since inputs are dense, the resulting tensor will be a concatenated 1D - buffer. Element-wise operation on this buffer will be equivalent to - operating individually. + A reducer can be registered by multiple hooks and the hooks will be + applied in the order of registration. + + The hook function takes a contiguous buffer of updated gradients + and can only apply in-place operations on it. - Args: - tensors (Iterable[Tensor]): dense tensors to flatten. - Returns: - A contiguous 1D buffer containing input tensors. - """ - return torch._utils._flatten_dense_tensors(tensors) + Example: + + ``` + hook = lambda grad: grad.clamp_(min=-1, max=1) + reducer.register_post_hook(hook) + ``` - def _unflatten_dense_tensors(self, flat, tensors): + @param fn Callable: hook function that takes a gradient buffer as input """ - View a flat buffer using the sizes of tensors. Assume that tensors are of - same dense type, and that flat is given by _flatten_dense_tensors. + assert callable(fn), f"post hook function must be callable, but got {type(fn)}" + self._hooks.append(fn) - Args: - flat (Tensor): flattened dense tensors to unflatten. - tensors (Iterable[Tensor]): dense tensors whose sizes will be used to - unflatten flat. + def _apply_post_hooks(self): + """ + Apply registered post hooks one by one after gradient update + """ + if len(self._hooks) == 0: return + # get updated gradients + grads = tuple(bucket._param_for_optimizer.grad for bucket in self._buckets) + assert all(grad is not None for grad in grads) + # apply hooks + for grad in grads: + for hook in self._hooks: + hook(grad) - Returns: - Unflattened dense tensors with sizes same as tensors and values from - flat. + def clear_post_hooks(self): + """ + Clear all post hooks """ - return torch._utils._unflatten_dense_tensors(flat, tensors) + self._hooks = [] diff --git a/cube/runtime/device.py b/cube/runtime/device.py index de8d7802..a8065400 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -1,7 +1,7 @@ """ Communication group settings among devices """ -from typing import List +from typing import List, Dict import numpy as np import torch import os @@ -21,8 +21,6 @@ def __init__(self): self.local_world_size = 1 self.local_rank = 0 self.node_rank = 0 - self.groups = dict() - torch.cuda.set_device(0) else: if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend='nccl') @@ -32,8 +30,11 @@ def __init__(self): self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) self.local_rank = int(os.environ.get('LOCAL_RANK')) self.node_rank = int(os.environ.get('GROUP_RANK')) - self.groups = dict() - torch.cuda.set_device(self.local_rank) + + torch.cuda.set_device(self.local_rank) + self.groups: Dict = dict() + self.streams: Dict[str, torch.cuda.Stream] = { + 'default': torch.cuda.default_stream()} instance = None @@ -63,6 +64,14 @@ def get_group(self, ranks): self.groups[rank_bits] = torch.distributed.new_group(list(ranks)) return self.groups[rank_bits] + def get_stream(self, name: str) -> torch.cuda.Stream: + """ + Get stream by name. If name doesn't exist, + will create a new one. + """ + return DeviceGroup.instance.streams.setdefault( + name, torch.cuda.Stream()) + def create_hybrid(self, group_num: List[int]) -> List[List[int]]: """ Create hybrid (nested) groups given the each group number. diff --git a/cube/runtime/module.py b/cube/runtime/module.py index b28c2bf6..af7e505f 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -26,12 +26,35 @@ def add_reducer(self, reducer: Reducer): raise RuntimeError(f"Expected a Reducer but got {type(reducer)}") self._reducers.append(reducer) - def reduce_grads(self): + def zero_grad(self): """ - Mannually allreduce gradients on the weight + Make zero for gradients caused by async weight reducer """ for reducer in self._reducers: - reducer.allreduce() + reducer.zero_grad() + + def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: + """Get parameter list for optimizer""" + params = [] + reducer_pids = set() + for reducer in self._reducers: + params += reducer.parameters_for_optimizer() + reducer_pids.update(id(p) for p in reducer.params) + for param in self.parameters(): + if id(param) not in reducer_pids: + params.append(param) + # print(f'> get out parameters: {sum(p.numel() for p in params)}') + return params + + def gather_params(self): + """ + Gather parameters + + This won't take effect when zero is not enabled. + """ + for reducer in self._reducers: + if reducer.zero: + reducer.gather_params() def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: int): """ @@ -47,16 +70,6 @@ def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: i assert hasattr(self, attr), f"{attr} is not in the module" self._fullmap[attr] = (tid, slicers, val_chunks) - def reduce_all_gradients(self, ranks: Tuple[int]): - """ - reduce gradients for the whole model. - This can only be used for data parallel - """ - reducer = Reducer(ranks, use_mean=True) - for parameter in self.parameters(): - reducer.add_param(parameter) - reducer.allreduce() - def get_full_map(self): return self._fullmap @@ -115,39 +128,108 @@ def merge_partial_states(state_dicts, zero_idx_maps=None): # at first, merge the partitioned optimizer states due to zero to the zero-disabled format if zero_idx_maps is not None: - def _check_opt_state(opt_state): - cnt = 0 - sorted_opt_state = {} - for idx in sorted(opt_state.keys()): - assert cnt == idx, f'opt state error: {idx} vs {cnt}, in {opt_state.keys()}' - sorted_opt_state[idx] = opt_state[idx] - cnt += 1 - return sorted_opt_state - optimizer_state_dict = {} - worker_cnt = len(state_dicts) - opt_state_list = [] - for work_idx in range(worker_cnt): - zero_idx2model_idx, model_idx2zero_idx, zero_rank_groups = zero_idx_maps[work_idx] - opt_state = {} - # first place local opt state to right index - if len(zero_idx2model_idx) == 0: - assert len(state_dicts[work_idx][1]['state']) == 0 - for local_idx, val in state_dicts[work_idx][1]['state'].items(): # worker / last_optimizer_state / state - print(f'{work_idx}, {local_idx}') - global_idx = zero_idx2model_idx[local_idx] - assert global_idx not in opt_state - opt_state[global_idx] = val - # for each rank group, copy opt state from other buckets - for rank_group, param_idx_buckets in zero_rank_groups.items(): - for bucket_idx, rank in enumerate(rank_group): - if rank == work_idx: continue - for global_idx in param_idx_buckets[bucket_idx]: - other_local_idx = zero_idx_maps[rank][1][global_idx] # rank / model_idx2zero_idx / global_idx - assert global_idx not in opt_state - opt_state[global_idx] = state_dicts[rank][1]['state'][other_local_idx] # worker / last_optimizer_state / state / local idx - opt_state = _check_opt_state(opt_state) - opt_state_list.append(opt_state) - assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' + if bool(int(os.environ.get('USE_ZERO', default=0))): + def _check_state_size(opt_state_keys, bucket_state): + if len(opt_state_keys) <= 1: + return True + return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape + for key in opt_state_keys) + + def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): + assert bucket_size % len(bucket_states) == 0 + opt_state_keys = list(bucket_states[0].keys()) + print(bucket_states[0], opt_state_keys) + if 'step' in bucket_states[0]: + opt_state_keys.remove('step') + assert _check_state_size(opt_state_keys, bucket_states[0]), f'the keys {opt_state_keys} have different shape' + # NOTE: only support adam for now + assert 'exp_avg' in opt_state_keys + assert 'exp_avg_sq' in opt_state_keys + chunk_size = bucket_size // len(bucket_states) + start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size + end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + opt_states, opt_states_1d = {}, {} + for key in opt_state_keys: + opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, + device=bucket_states[0][key].device, requires_grad=False) + opt_states_1d[key] = opt_states[key].view(-1) + + if start_rank_id == end_rank_id: + for key in opt_state_keys: + opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] + else: + offset = chunk_size-start_offset + for key in opt_state_keys: + opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] + for i in range(start_rank_id+1, end_rank_id): + for key in opt_state_keys: + opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] + offset += chunk_size + for key in opt_state_keys: + opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + + if 'step' in bucket_states[0]: + opt_states['step'] = bucket_states[0]['step'] + return opt_states + + opt_state_list = [] + worker_cnt = len(state_dicts) + for work_idx in range(worker_cnt): + model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[work_idx] + opt_state = {} + for model_idx, opt_idx in model_idx2opt_idx.items(): + if isinstance(opt_idx, int): + # the param without reducer + assert opt_idx2ranks[opt_idx] is None + # state_dicts [worker idx][opt state]['state'][param idx] + opt_state[model_idx] = state_dicts[work_idx][1]['state'][opt_idx] + else: + # the param in reducer bucket + opt_idx, pstart, pend, pshape = opt_idx + ranks, bucket_size = opt_idx2ranks[opt_idx] + bucket_states = [state_dicts[rank][1]['state'][opt_idx] for rank in ranks] + opt_state[model_idx] = _retrieve_param_opt_state( + bucket_states, + pstart, + pend, + pshape, + bucket_size) + opt_state_list.append(opt_state) + assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' + else: + def _check_opt_state(opt_state): + cnt = 0 + sorted_opt_state = {} + for idx in sorted(opt_state.keys()): + assert cnt == idx, f'opt state error: {idx} vs {cnt}, in {opt_state.keys()}' + sorted_opt_state[idx] = opt_state[idx] + cnt += 1 + return sorted_opt_state + optimizer_state_dict = {} + worker_cnt = len(state_dicts) + opt_state_list = [] + for work_idx in range(worker_cnt): + zero_idx2model_idx, model_idx2zero_idx, zero_rank_groups = zero_idx_maps[work_idx] + opt_state = {} + # first place local opt state to right index + if len(zero_idx2model_idx) == 0: + assert len(state_dicts[work_idx][1]['state']) == 0 + for local_idx, val in state_dicts[work_idx][1]['state'].items(): # worker / last_optimizer_state / state + print(f'{work_idx}, {local_idx}') + global_idx = zero_idx2model_idx[local_idx] + assert global_idx not in opt_state + opt_state[global_idx] = val + # for each rank group, copy opt state from other buckets + for rank_group, param_idx_buckets in zero_rank_groups.items(): + for bucket_idx, rank in enumerate(rank_group): + if rank == work_idx: continue + for global_idx in param_idx_buckets[bucket_idx]: + other_local_idx = zero_idx_maps[rank][1][global_idx] # rank / model_idx2zero_idx / global_idx + assert global_idx not in opt_state + opt_state[global_idx] = state_dicts[rank][1]['state'][other_local_idx] # worker / last_optimizer_state / state / local idx + opt_state = _check_opt_state(opt_state) + opt_state_list.append(opt_state) + assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' # assign opt_state to state_dicts, cannot be assigned in the above loop opt_state_len = len(opt_state_list[0]) for work_idx in range(worker_cnt): @@ -263,4 +345,4 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): # dump to ckpt torch.save({'state_dict': merged_model_state_dict, 'optim_state_dict': merged_optimizer_state_dict - }, filename_prefix + '.full.ckpt') \ No newline at end of file + }, filename_prefix + '.full.ckpt') diff --git a/cube/utils.py b/cube/utils.py index 7128c6f1..940b2c0c 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -1,6 +1,9 @@ from typing import Optional + +import cube from cube.profiler.timer import print_each_rank from cube.runtime.device import DeviceGroup +from cube.flags import RuntimeFlag from cube.flags import RuntimeFlag @@ -16,10 +19,14 @@ def _load_module_attr(filename: str, name: str): def load_model(filename: Optional[str] = None, load_content: bool = True): filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename module = _load_module_attr(filename, 'GenModel') - loaded_module = module.GenModel().cuda() + loaded_module: cube.runtime.module.CubeModule = module.GenModel().cuda() + # load parameter content if load_content: print_each_rank("> loading parameter content...") loaded_module.load_attr_content('./fullmodel.pt') + # initialize reducer + for reducer in loaded_module.reducers: + reducer.build_buckets() return loaded_module diff --git a/tests/codegen/test_scale.py b/tests/codegen/test_scale.py new file mode 100644 index 00000000..93dbf19b --- /dev/null +++ b/tests/codegen/test_scale.py @@ -0,0 +1,202 @@ + + +""" +example: + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + tests/codegen/test_scale.py + +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + --nnodes=1 \ + tests/codegen/test_scale.py +""" + +from typing import List +import torch +from torch import nn + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank + + +cube.init() + + +class MLP(nn.Module): + def __init__(self, dim, mult=1, nlayers=16): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for lid in range(nlayers): + if lid % 2 == 0: + self.layers.append(nn.Linear(dim, dim * mult, bias=False)) + else: + self.layers.append(nn.Linear(dim * mult, dim, bias=False)) + + def forward(self, data): + x = data + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) + return loss + + +class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, bs: int, dim: int): + super().__init__(bs, [0]) + self.sample = None + self.dim = dim + self.set_batch_size(bs) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + self.batch_size = batch_size + self.sample = torch.rand( + [batch_size, self.dim], dtype=torch.float32, + device=torch.cuda.current_device() + ) + + +# tensor parallelism +def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], + idx: int, dim: int, tag='dim'): + algo = node.algorithms(tag) + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + +# replicate +def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def _run(train_iter, model, dataloader, optimizer): + iter_num, warmup = 5, 2 + for step in range(iter_num): + if step >= warmup: + CudaTimer(enable=True).start('e2e') + loss = train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + # model.zero_grad() + # model.gather_params() + if step >= warmup: + CudaTimer().stop('e2e') + print_each_rank(f'loss: {loss.item()}', rank_only=0) + print_each_rank('e2e time (ms) per iteration: {} ms'.format( + CudaTimer().duration(iter_num-warmup, field_name='e2e'))) + CudaTimer().print_all(times=iter_num-warmup) + + +def test_scale_full_dp(): + + model = MLP(dim=4096) + dataloader = MLPDataLoader(bs=8, dim=4096) + + def policy(graph: IRGraph, resource): + assert resource.ngpus > 2 + ngpus = 2 + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, list(range(ngpus))) + for node in graph.select(ntype=IRFwOperation): + if node.name == 'linear': + _tp(graph, node, list(range(ngpus)), idx=0, dim=0, tag='dim') + else: + _replica(graph, node, list(range(ngpus))) + return graph + + @cube.compile(model, dataloader, PAS=policy, scale=True) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + model = cube.load_model() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + _run(train_iter, model, dataloader, optimizer) + + +def test_scale_partial_dp(): + + model = MLP(dim=4096) + dataloader = MLPDataLoader(bs=8, dim=4096) + + def policy(graph: IRGraph, resource): + assert resource.ngpus > 2 + ngpus = 2 + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, list(range(ngpus))) + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if node.name == 'linear': + if idx % 4 == 0: + _tp(graph, node, list(range(ngpus)), idx=0, dim=0, tag='dim') + if idx % 4 == 1: # partition weight, partition input (reduction) + _tp(graph, node, list(range(ngpus)), idx=0, dim=1, tag='dim') + if idx % 4 == 2: # partition weight, replicate input + _tp(graph, node, list(range(ngpus)), idx=1, dim=0, tag='dim') + if idx % 4 == 3: # replicate + _replica(graph, node, list(range(ngpus))) + else: + _replica(graph, node, list(range(ngpus))) + return graph + + @cube.compile(model, dataloader, PAS=policy, scale=True) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + model = cube.load_model() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + _run(train_iter, model, dataloader, optimizer) + + +def test_scale_no_dp(): + + model = MLP(dim=4096) + dataloader = MLPDataLoader(bs=8, dim=4096) + + def policy(graph: IRGraph, resource): + assert resource.ngpus > 2 + ngpus = 2 + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, list(range(ngpus))) + for node in graph.select(ntype=IRFwOperation): + _replica(graph, node, list(range(ngpus))) + return graph + + @cube.compile(model, dataloader, PAS=policy, scale=True) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + model = cube.load_model() + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + _run(train_iter, model, dataloader, optimizer) + + +if __name__ == '__main__': + + # test_scale_full_dp() + # test_scale_partial_dp() + test_scale_no_dp() \ No newline at end of file diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py new file mode 100644 index 00000000..ba9350a0 --- /dev/null +++ b/tests/runtime/test_reducer.py @@ -0,0 +1,189 @@ +""" +example: + +ASYNC_REDUCER=0 USE_ZERO=1 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ + tests/runtime/test_reducer.py +""" +from typing import List +from functools import partial + +import torch +import random +from torch import nn + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.profiler import CudaTimer +from cube.profiler.timer import print_each_rank +# from cube.tools.debug import DebugTool + +cube.init() + + +class MLP(nn.Module): + def __init__(self, dim, nlayers=16): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + self.param = torch.nn.Parameter(torch.ones([1])) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = x * self.param # for padding test + loss = torch.sum(x) + return loss + + +class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, bs: int, dim: int): + super().__init__(bs, [0]) + self.sample = None + self.dim = dim + self.set_batch_size(bs) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + torch.random.manual_seed(0) + self.batch_size = batch_size + self.sample = torch.randn( + [batch_size, self.dim], dtype=torch.float32, + device=torch.cuda.current_device() + ) + self.sample = (self.sample - 1) * 1e3 + + +def init_model_dataloader(): + batch_size = 4 + dim = 4096 + torch.random.manual_seed(0) + random.seed(0) + model = MLP(dim=dim) + # torch.random.manual_seed(0) + dataloader = MLPDataLoader(batch_size, dim) + return model, dataloader + + +def policy(graph: IRGraph, resource): + + # tensor parallelism + def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + # replicate + def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + devs = list(range(resource.ngpus)) + for node in graph.select(ntype=IRDataOperation): + _replica(graph, node, devs) + for node in graph.select(ntype=IRFwOperation): + _tp(graph, node, devs, idx=0, dim=0, num=resource.ngpus) + # if node.name == 'linear': + # _tp(graph, node, devs, idx=0, dim=0, num=resource.ngpus) + # else: + # _replica(graph, node, devs) + return graph + + +def get_baseline(): + + model, dataloader = init_model_dataloader() + model = model.cuda() + # optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + + wsz = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + niters = 4 + losses = [] + for idx in range(niters): + loss = train_iter(model, dataloader) + # loss = DebugTool.record( + # model, + # partial(train_iter, model, dataloader), + # filename=f'base-{wsz}gpus-{rank}.iter{idx}.log' + # ) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + losses.append(loss.item()) + + for idx, loss in enumerate(losses): + print_each_rank(f'baseline loss[{idx}]: {loss}', rank_only=0) + + return losses + + +baseline_losses = get_baseline() + + +def test_reducer(): + + # nonlocal baseline_losses + + model, dataloader = init_model_dataloader() + + @cube.compile(model, dataloader, PAS=policy) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + + model = cube.load_model() + # optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-4) # not match for adam + optimizer = torch.optim.SGD(model.parameters_for_optimizer(), lr=1e-2) + + def post_hook(grad): + grad.mul_(0.1) + for reducer in model.reducers: + reducer.register_post_hook(post_hook) + + wsz = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + niters = 4 + losses = [] + for idx in range(niters): + loss = train_iter(model, dataloader) + # loss = DebugTool.record( + # model, + # partial(train_iter, model, dataloader), + # filename=f'reducer-{wsz}gpus-{rank}.iter{idx}.log' + # ) + optimizer.step() + optimizer.zero_grad() + model.zero_grad() + model.gather_params() + losses.append(loss.item()) + + for idx, loss in enumerate(losses): + print_each_rank(f'reducer loss[{idx}]: {loss}', rank_only=0) + + +if __name__ == '__main__': + + test_reducer() \ No newline at end of file From 6ae95a87545fef3976be29483d1c25eef99f0451 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 15 Jun 2023 02:04:42 +0000 Subject: [PATCH 1401/1892] Merged PR 1618: update docstring example update docstring example --- README.md | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/README.md b/README.md index 5f30c5bd..b7a01fae 100644 --- a/README.md +++ b/README.md @@ -54,3 +54,44 @@ OMP_NUM_THREADS=4 torchrun \ --nnodes=1 \ examples/mlp/linears.py --policy PASCol ``` + + +## Development Docstring + +We follow [Google Style Python Docstring](https://google.github.io/styleguide/pyguide.html) for development. + +Following is an typical example: + +```python +class SampleClass: + """Summary of class here. + + Longer class information... + Longer class information... + + """ + + def __init__(self, likes_spam: bool = False): + """Initializes the instance based on spam preference. + + Args: + likes_spam: Defines if instance exhibits this preference. + """ + self.likes_spam = likes_spam + self.eggs = 0 + + def public_method(self, a, b): + """Performs operation blah. + + Long description here. + + Args: + a (int): xxx + b (int/str): xxx + + Returns: + t (bool): xxx + k (int): xxx + """ + # function implementation goes here +``` \ No newline at end of file From 9f2425d179e02d66337593ef7c0e1fb5afc4caeb Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Thu, 15 Jun 2023 02:49:11 +0000 Subject: [PATCH 1402/1892] Merged PR 1625: optimize merge_partial_states by eliminating redundant data loading. optimize merge_partial_states by eliminating redundant data loading. --- cube/runtime/module.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index af7e505f..2de98ecb 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,3 +1,4 @@ +import logging from typing import List, Dict, Tuple, Optional import torch from cube.runtime.device import DeviceGroup @@ -126,6 +127,14 @@ def merge_partial_states(state_dicts, zero_idx_maps=None): if len(state_dicts) == 1: return state_dicts[0][0], state_dicts[0][1] + plan_ngpus = -1 + if 'PLAN_NGPUS' in os.environ: + plan_ngpus = int(os.environ['PLAN_NGPUS']) + assert plan_ngpus >= 1, plan_ngpus + assert plan_ngpus <= len(state_dicts), f'plan_ngpus = {plan_ngpus}, len(state_dicts) = {len(state_dicts)}' + assert len(state_dicts) % plan_ngpus == 0, f'plan_ngpus = {plan_ngpus}, len(state_dicts) = {len(state_dicts)}' + logging.info(f'plan_ngpus = {plan_ngpus}') + # at first, merge the partitioned optimizer states due to zero to the zero-disabled format if zero_idx_maps is not None: if bool(int(os.environ.get('USE_ZERO', default=0))): @@ -174,7 +183,7 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): opt_state_list = [] worker_cnt = len(state_dicts) - for work_idx in range(worker_cnt): + for work_idx in (range(worker_cnt) if plan_ngpus < 0 else range(plan_ngpus)): model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[work_idx] opt_state = {} for model_idx, opt_idx in model_idx2opt_idx.items(): @@ -197,6 +206,8 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): opt_state_list.append(opt_state) assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' else: + if plan_ngpus > 0: + logging.warning(f'plan_ngpus {plan_ngpus} not handled USE_ZERO == False') def _check_opt_state(opt_state): cnt = 0 sorted_opt_state = {} @@ -232,13 +243,15 @@ def _check_opt_state(opt_state): assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' # assign opt_state to state_dicts, cannot be assigned in the above loop opt_state_len = len(opt_state_list[0]) - for work_idx in range(worker_cnt): + for work_idx in (range(worker_cnt) if plan_ngpus < 0 else range(plan_ngpus)): state_dicts[work_idx][1]['state'] = opt_state_list[work_idx] state_dicts[work_idx][1]['param_groups'][0]['params'] = sorted(opt_state_list[work_idx].keys()) assert len(opt_state_list[work_idx]) == opt_state_len # find tensor full shape param_max_dimsize = {} + if plan_ngpus > 0: + state_dicts = state_dicts[0:plan_ngpus] for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: for param_area in param_area_map.items(): local_name = param_area[0][0:param_area[0].rfind('_')] From 1145bf3a0f32beb3cc2d16e4f38c9c8d169adae2 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 27 Jun 2023 12:34:16 +0000 Subject: [PATCH 1403/1892] Merged PR 1626: Hide model.zero_grad() with updated accum mode - Hide model.zero_grad into generated code - Update accum mode with awareness of the first iteration and the last iteration - Bug fix on reducer allreduce SUM - Reducer more robust in adding param --- cube/codegen/module/module.py | 8 +- cube/codegen/schedule/schedule.py | 1 + cube/flags.py | 17 ++- cube/runtime/adapter/reducer.py | 206 +++++++++++++++++++--------- cube/runtime/module.py | 10 +- cube/utils.py | 102 ++++++++++++-- tests/runtime/test_reducer.py | 107 +++++++++------ tests/runtime/test_runtime_flag.py | 210 +++++++++++++++++++++++++++++ 8 files changed, 535 insertions(+), 126 deletions(-) create mode 100644 tests/runtime/test_runtime_flag.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 45e69e47..35c5b7e3 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -15,7 +15,7 @@ from cube.graph.parser.register import CustomizedOps from cube.execplan import ExecutionPlan -from cube.execplan.execplan import ExeRepetend, ExeReuseCell +from cube.execplan.execplan import ExeReuseCell from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock @@ -141,8 +141,10 @@ def add_scale_reducers(self): if not param.is_param(): continue for ctensor in graph.ctensors(param): if device not in ctensor.device: continue - if ctensor not in all_params and ctensor not in rest_params: - rest_params.append(ctensor) + if ctensor not in all_params: + # a same parameter can be consumed multiple times by different operators + if ctensor not in rest_params: + rest_params.append(ctensor) if len(rest_params) == 0: continue # create reducer and append to the execution diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index d0be4374..c50d34e5 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -57,6 +57,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: with FunctionBlock(func_name='_train_step', args=args) as fb: fb.insert_body('_ = None') + fb.insert_body('model.zero_grad()') # body code if len(device_nodes) == 0: fb.insert_body('pass') diff --git a/cube/flags.py b/cube/flags.py index 5c88b273..35168e26 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -60,6 +60,17 @@ class CompileFlag: class RuntimeFlag: - # turn execution in accumulation mode - # where reducers will not allpy allreduce on gradients - accum_mode: bool = False + # if True, skip model.zero_grad(). + # when applying gradient accumulation, + # this flag should be set to True at the first accumulation step, + # and set to False at other accumulation steps. + # By default False, which means the gradients of parameters in the reducers + # will be zeroed at the beginning of every iteration. + skip_zero_grad: bool = False + + # if True, skip reducer.sync_grads(). + # when applying gradient accumulation, + # this flag should be set to True at the last accumulation step, + # .and set to False at other accumulation steps. + # By default False, which means the gradients will be reduced at the end of every iteration. + skip_reducer: bool = False diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 94b68499..9fa5e46a 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Tuple, Any, Callable, Optional +from typing import List, Dict, Tuple, Any, Callable, Optional, Set from functools import partial import warnings import torch @@ -28,8 +28,12 @@ def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: class Bucket: - # config: whether to use allreduce for zero (default: reduce-scatter) - use_allreduce_for_zero: bool = False + # config: whether to use reduce scatter for zero (default False). + # By default we use `allreduce` for zero, which is due to + # 1) `reduce_scatter` will make some parameters have stale gradient after synchronization, + # hence break the consistency of `.data` and `.grad` of parameters. Need to be careful when using optimizer. + # 2) `reduce_scatter`` doesn't significantly improve performance comparing with `allreduce`. + use_reduce_scatter_for_zero: bool = False def __init__(self, params: List[torch.nn.Parameter], param_buffer: torch.Tensor, grad_buffer: torch.Tensor, @@ -71,6 +75,10 @@ def __init__(self, params: List[torch.nn.Parameter], self._numel: int = sum(p.numel() for p in self._params) self._padding: int = self._contiguous_grads.size(0) - self._numel + # pre and post hooks for gradient synchronization + self._pre_hooks: List[Callable] = [] + self._post_hooks: List[Callable] = [] + # only async will enable contiguous gradient self.build() self.register_hooks() @@ -127,14 +135,11 @@ def register_hooks(self): def post_grad_hook(param: torch.nn.Parameter, *unused): # stream = DeviceGroup().get_stream('reducer') ofst = self._pofset[param] - # due to **unknown** reasons, multi-stream computation has incorrect results - # with torch.cuda.stream(stream): # async update to overlap backward computation # TODO: need to handle sparse gradients in torch.nn.Embedding - # print('yizhu1', param.requires_grad, param.size(), param.data_ptr) self._contiguous_grads[ofst:ofst+param.numel()].add_(param.grad.data.view(-1)) param.grad = None - if RuntimeFlag.accum_mode: return + if RuntimeFlag.skip_reducer: return self._cnt += 1 assert self._cnt <= len(self._params), \ @@ -142,19 +147,15 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): # perform all-reduce if self._async and self._cnt == len(self._params): - # wait until all gradients are accumulated in the gradient buffer - # stream.synchronize() - if self._zero: # zero will use reduce-scatter (default) or allreduce on gradient shards - if Bucket.use_allreduce_for_zero: - self._work = torch.distributed.all_reduce( - self._contiguous_grads, op=self._reduce_op, - group=self._group, async_op=True) - else: - rank = torch.distributed.get_rank(group=self._group) - shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) - self._work = torch.distributed.reduce_scatter( - shards[rank], shards, op=self._reduce_op, - group=self._group, async_op=True) + # apply pre hooks + self._apply_pre_hooks() + # communication + if self._zero and Bucket.use_reduce_scatter_for_zero: + rank = torch.distributed.get_rank(group=self._group) + shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) + self._work = torch.distributed.reduce_scatter( + shards[rank], shards, op=self._reduce_op, + group=self._group, async_op=True) else: self._work = torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, @@ -186,17 +187,14 @@ def sync_grads(self): assert self._work is not None self._work.wait() else: - if self._reduce_op == torch.distributed.ReduceOp.SUM: - self._contiguous_grads.div_(torch.distributed.get_world_size(group=self._group)) CudaTimer().start('comm', predefined=True) - if self._zero: - if Bucket.use_allreduce_for_zero: - torch.distributed.all_reduce( - self._contiguous_grads, op=self._reduce_op, group=self._group) - else: - shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) - torch.distributed.reduce_scatter( - shards[rank], shards, op=self._reduce_op, group=self._group) + # apply pre-hooks + self._apply_pre_hooks() + # synchrnoize gradients + if self._zero and Bucket.use_reduce_scatter_for_zero: + shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) + torch.distributed.reduce_scatter( + shards[rank], shards, op=self._reduce_op, group=self._group) else: torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, group=self._group) @@ -206,11 +204,7 @@ def sync_grads(self): assert param.grad is None pofst = self._pofset[param] param.grad = self._contiguous_grads[pofst:pofst+param.numel()].view(param.size()) - # the following two methods can make `.grad` isolate with `self._contiguous_grads`, - # thereby enabling torch.zero_ inside the reducer without requiring user to modify. - # param.grad = grads[pofst:pofst+param.numel()].view(param.size()) - # param.grad = self._contiguous_grads[pofst:pofst+param.numel()].clone().view(param.size()) - + # setup gradient for optimizer parameters if self._zero: grad = self._contiguous_grads.chunk(self._wsz, dim=0)[rank] @@ -220,6 +214,9 @@ def sync_grads(self): else: self._param_for_optimizer.grad = self._contiguous_grads[:self._numel] + # apply post-hooks + self._apply_post_hooks() + def gather_params(self): """ All-gather parameters @@ -229,6 +226,56 @@ def gather_params(self): shards = list(self._contiguous_params.chunk(self._wsz, dim=0)) torch.distributed.all_gather(shards, shards[rank], group=self._group) + def register_pre_hook(self, fn: Callable): + """Register pre hooks to be applied before gradient synchronization. + + The pre-hooks will be applied one by one following the order of registration. + + Args: + fn (Callable): a callable function that takes a gradient as input and optionally updates the gradient. + """ + assert callable(fn), f"fn must be callable for pre hooks, but got {type(fn)}" + self._pre_hooks.append(fn) + + def register_post_hook(self, fn: Callable): + """Register post hooks to be applied after gradient synchronization. + + The post-hooks will be applied one by one following the order of registration. + + Args: + fn (Callable): a callable function that takes a gradient as input and optionally updates the gradient. + """ + assert callable(fn), f"fn must be callable for post hooks, but got {type(fn)}" + self._post_hooks.append(fn) + + def _apply_pre_hooks(self): + """Apply pre hooks before gradient synchronization. + + The pre-hooks will be applied one by one following the order of registration. + """ + if len(self._pre_hooks) == 0: return + grads = self._contiguous_grads[:self._numel] + for hook in self._pre_hooks: + hook(grads) + + def _apply_post_hooks(self): + """Apply post hooks after gradient synchronization. + + The post-hooks will be applied one by one following the order of registration. + """ + if len(self._post_hooks) == 0: return + grads = self._contiguous_grads[:self._numel] + for hook in self._post_hooks: + hook(grads) + + def clear_pre_hooks(self): + """Clear all pre hooks.""" + self._pre_hooks = [] + + def clear_post_hooks(self): + """Clear all post hooks.""" + self._post_hooks = [] + def reset(self): """Reset status.""" self._cnt = 0 @@ -252,6 +299,7 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, @param zero bool: whether to apply zero optimization on gradients """ self._params: List[torch.nn.Parameter] = list() + self._param_ids: Set[int] = set() self._numel: int = 0 self._ranks = ranks self._group = DeviceGroup().get_group(ranks) @@ -264,8 +312,6 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, # contiguous parameter buffer and gradient buffer self._contiguous_params: torch.Tensor = None self._contiguous_grads: torch.Tensor = None - # hooks - self._hooks: List[Callable] = [] @property def params(self) -> Tuple[torch.nn.Parameter]: @@ -289,6 +335,11 @@ def zero(self) -> bool: def buckets(self) -> Tuple[Bucket]: return tuple(self._buckets) + @property + def reduce_op(self) -> torch.distributed.ReduceOp: + """Get reduce operation""" + return self._reduce_op + def add_param(self, param: torch.nn.Parameter): """ Add a parameter to the reducer @@ -299,7 +350,13 @@ def add_param(self, param: torch.nn.Parameter): @param param torch.nn.Parameter: the added parameter """ + if param.data.data_ptr() in self._param_ids: + warnings.warn( + f'rank [{torch.distributed.get_rank()}]: detected duplicated or shared parameters, ignored.', + category=RuntimeWarning) + return self._params.append(param) + self._param_ids.add(param.data.data_ptr()) self._numel += param.numel() def build_buckets(self): @@ -393,26 +450,25 @@ def sync_grads(self): """ synchronize gradients using allreuce (non-zero) or reduce-scatter (zero) """ - if RuntimeFlag.accum_mode: return + if RuntimeFlag.skip_reducer: return for bucket in self._buckets: bucket.sync_grads() - self._apply_post_hooks() def gather_params(self): + """Gather parameters with Zero optimizations after `optimizer.step()`. + + This is required when zero optimization is turned on. """ - Gather parameters - """ - if RuntimeFlag.accum_mode: return - assert self._zero, "gathering paramters is only for zero optimization." + if not self._zero: return for bucket in self._buckets: bucket.gather_params() def zero_grad(self): + """Make gradient to be zero. + + This needs to be called at the beginning of every training iteration. """ - Make gradient to be zero. This needs to be called - after `optimizer.step()` if `optmizer.zero_grad(set_to_none=True)`. - """ - if RuntimeFlag.accum_mode: return + if RuntimeFlag.skip_zero_grad: return torch.cuda.synchronize() self._contiguous_grads.zero_() for bucket in self._buckets: @@ -438,6 +494,30 @@ def broadcast_params(self): torch.distributed.broadcast(param, self.ranks[0], group=self._group) torch.cuda.synchronize() + def register_pre_hook(self, fn: Callable): + """Register a pre hook function before gradient update. + + A reducer can be registered by multiple hooks and the hooks will be + applied in the order of registration. + + The hook function takes a contiguous buffer of local computed gradient + and can optionally apply in-place operations on it. + + Example: + + ``` + hook = lambda grad: grad.div_(4) + reducer.register_pre_hook(hook) + ``` + + Args: + fn Callable: + hook function that takes a gradient as input and optionally inplacemently updates it + """ + assert callable(fn), f"pre hook function must be callable, but got {type(fn)}" + for bucket in self._buckets: + bucket.register_pre_hook(fn) + def register_post_hook(self, fn: Callable): """ Register a post hook function after gradient update. @@ -445,7 +525,7 @@ def register_post_hook(self, fn: Callable): A reducer can be registered by multiple hooks and the hooks will be applied in the order of registration. - The hook function takes a contiguous buffer of updated gradients + The hook function takes a contiguous buffer of updated gradient and can only apply in-place operations on it. Example: @@ -455,26 +535,20 @@ def register_post_hook(self, fn: Callable): reducer.register_post_hook(hook) ``` - @param fn Callable: hook function that takes a gradient buffer as input + Args: + fn Callable: + hook function that takes a gradient as input and optionally inplacemently updates it """ assert callable(fn), f"post hook function must be callable, but got {type(fn)}" - self._hooks.append(fn) + for bucket in self._buckets: + bucket.register_post_hook(fn) - def _apply_post_hooks(self): - """ - Apply registered post hooks one by one after gradient update - """ - if len(self._hooks) == 0: return - # get updated gradients - grads = tuple(bucket._param_for_optimizer.grad for bucket in self._buckets) - assert all(grad is not None for grad in grads) - # apply hooks - for grad in grads: - for hook in self._hooks: - hook(grad) + def clear_pre_hooks(self): + """Clear all pre hooks.""" + for bucket in self._buckets: + bucket.clear_pre_hooks() def clear_post_hooks(self): - """ - Clear all post hooks - """ - self._hooks = [] + """Clear all post hooks.""" + for bucket in self._buckets: + bucket.clear_post_hooks() diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 2de98ecb..11f25a21 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -28,8 +28,14 @@ def add_reducer(self, reducer: Reducer): self._reducers.append(reducer) def zero_grad(self): - """ - Make zero for gradients caused by async weight reducer + """Make zero for gradients in weight reducer + + This only applies on the gradients of the parameters in each reducer. + This function will be automatically inserted inside the generated code + at the beginning of each iteration. + + If the function is under the context of `with cube.accum_mode()`, the zero of gradients + will be skipped. """ for reducer in self._reducers: reducer.zero_grad() diff --git a/cube/utils.py b/cube/utils.py index 940b2c0c..d2c31f76 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import cube from cube.profiler.timer import print_each_rank @@ -43,21 +43,97 @@ def load_eval_schedule(filename: Optional[str] = None): class accum_mode: - """ - Make cube execution in accumulation mode, where weight - gradient allreduce will be skipped. + """Make cube execution in gradient accumulation mode. + + A typical usage is: + + ``` + for _ in range(num_iters): + for step in range(accum_steps): + datas = next(dataloader) + with cube.accum_mode(begin=(step == 0), end=(step == accum_steps - 1)): + train_iter(model, *datas) + optimizer.step() + optimizer.zero_grad() + ``` - need manually call `model.reduce_grads()` to reduce gradients - after finish accumulation, or make `enable=False` for the last - accumulation step. + Or, + + ``` + for _ in range(num_iters): + for step in cube.accum_mode.steps(accum_steps): + datas = next(dataloader) + train_iter(model, *datas) + optimizer.step() + optimizer.zero_grad() + ``` """ - def __init__(self, enable: bool = True): - self.enable = enable - self.old = None + def __init__(self, begin: bool = True, end: bool = True): + """Turn on/off accumulation mode. + + Args: + begin (bool): Whether the iteration is the first accumulation step. + If True, the `model.zero_grad()` will be enabled to zero out gradients + of the parameters in the reducer. + end (bool): Whether the iteration is the last accumulation step. + If True, the `model.reduce_grad()` will be enabled to reduce gradients at + the end of the iteration. + """ + self.begin: bool = begin + self.end: bool = end + self.old: Tuple[bool, bool] = None def __enter__(self): - self.old = RuntimeFlag.accum_mode - RuntimeFlag.accum_mode = self.enable + """Enter the accumulation mode. + + Example usage: + + ``` + for _ in range(num_iters): + for step in range(accum_steps): + datas = next(dataloader) + with cube.accum_mode(begin=(step == 0), end=(step == accum_steps - 1)): + train_iter(model, *datas) + optimizer.step() + optimizer.zero_grad() + ``` + + """ + self.old = (RuntimeFlag.skip_zero_grad, RuntimeFlag.skip_reducer) + RuntimeFlag.skip_zero_grad = (not self.begin) + RuntimeFlag.skip_reducer = (not self.end) def __exit__(self, *args): - RuntimeFlag.accum_mode = self.old + RuntimeFlag.skip_zero_grad, RuntimeFlag.skip_reducer = self.old + self.old = None + + @staticmethod + def steps(nsteps: int): + """Perform the accumulation in `nsteps` steps. + + This interface doesn't require to set the `begin` and `end` flags + during the initilization of `accum_mode`. + + Example usage: + + ``` + for _ in range(num_iters): + for step in cube.accum_mode.steps(accum_steps): + datas = next(dataloader) + train_iter(model, *datas) + optimizer.step() + optimizer.zero_grad() + ``` + + Args: + nsteps (int): The number of accumulation steps. + + Yield: + int: The current step index. + """ + old = (RuntimeFlag.skip_zero_grad, RuntimeFlag.skip_reducer) + for step in range(nsteps): + RuntimeFlag.skip_zero_grad = (not (step == 0)) + RuntimeFlag.skip_reducer = (not (step == nsteps - 1)) + yield step + RuntimeFlag.skip_zero_grad, RuntimeFlag.skip_reducer = old diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py index ba9350a0..0ce9486a 100644 --- a/tests/runtime/test_reducer.py +++ b/tests/runtime/test_reducer.py @@ -3,9 +3,17 @@ ASYNC_REDUCER=0 USE_ZERO=1 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ tests/runtime/test_reducer.py + +ASYNC_REDUCER=0 USE_ZERO=0 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ + tests/runtime/test_reducer.py + +ASYNC_REDUCER=1 USE_ZERO=1 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ + tests/runtime/test_reducer.py + +ASYNC_REDUCER=1 USE_ZERO=0 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ + tests/runtime/test_reducer.py """ from typing import List -from functools import partial import torch import random @@ -14,9 +22,7 @@ import cube from cube.graph import IRGraph from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -# from cube.tools.debug import DebugTool cube.init() @@ -102,47 +108,77 @@ def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): return graph +def cal_gnorms(model): + """Calculate gradient normalization for gradients""" + gnorms = [] + for p in model.parameters(): + if p.grad is None: + continue + gnorms.append(p.grad.norm().item()) + return sum(gnorms) + + def get_baseline(): model, dataloader = init_model_dataloader() model = model.cuda() - # optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + # optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) loss.backward() return loss - - wsz = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() niters = 4 - losses = [] - for idx in range(niters): + losses, gnorms = [], [] + for _ in range(niters): loss = train_iter(model, dataloader) - # loss = DebugTool.record( - # model, - # partial(train_iter, model, dataloader), - # filename=f'base-{wsz}gpus-{rank}.iter{idx}.log' - # ) + gnorms.append(cal_gnorms(model)) optimizer.step() optimizer.zero_grad(set_to_none=True) losses.append(loss.item()) - - for idx, loss in enumerate(losses): - print_each_rank(f'baseline loss[{idx}]: {loss}', rank_only=0) + return losses, gnorms - return losses +def test_reducer(): -baseline_losses = get_baseline() + losses, gnorms = get_baseline() + for idx, (loss, gnorm) in enumerate(zip(losses, gnorms)): + print_each_rank(f'baseline step [{idx}]: loss: {loss} | gnorm: {gnorm}', rank_only=0) + model, dataloader = init_model_dataloader() -def test_reducer(): + @cube.compile(model, dataloader, PAS=policy) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + + model = cube.load_model() + optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-5) - # nonlocal baseline_losses + niters = 4 + losses, gnorms = [], [] + for idx in range(niters): + loss = train_iter(model, dataloader) + gnorms.append(cal_gnorms(model)) + optimizer.step() + optimizer.zero_grad() + model.gather_params() + losses.append(loss.item()) + + for idx, (loss, gnorm) in enumerate(zip(losses, gnorms)): + print_each_rank(f'reducer step [{idx}]: loss: {loss} | gnorm: {gnorm}', rank_only=0) + + +def test_reducer_hooks(): + + losses, gnorms = get_baseline() + for idx, (loss, gnorm) in enumerate(zip(losses, gnorms)): + print_each_rank(f'baseline step [{idx}]: loss: {loss} | gnorm: {gnorm}', rank_only=0) model, dataloader = init_model_dataloader() @@ -154,36 +190,29 @@ def train_iter(model, dataloader): return loss model = cube.load_model() - # optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-4) # not match for adam - optimizer = torch.optim.SGD(model.parameters_for_optimizer(), lr=1e-2) + optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-5) - def post_hook(grad): - grad.mul_(0.1) for reducer in model.reducers: + pre_hook = lambda grad: grad.div_(len(reducer.ranks)) + post_hook = lambda grad: grad.mul_(len(reducer.ranks)) + reducer.register_pre_hook(pre_hook) reducer.register_post_hook(post_hook) - wsz = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - niters = 4 - losses = [] + losses, gnorms = [], [] for idx in range(niters): loss = train_iter(model, dataloader) - # loss = DebugTool.record( - # model, - # partial(train_iter, model, dataloader), - # filename=f'reducer-{wsz}gpus-{rank}.iter{idx}.log' - # ) + gnorms.append(cal_gnorms(model)) optimizer.step() optimizer.zero_grad() - model.zero_grad() model.gather_params() losses.append(loss.item()) - for idx, loss in enumerate(losses): - print_each_rank(f'reducer loss[{idx}]: {loss}', rank_only=0) + for idx, (loss, gnorm) in enumerate(zip(losses, gnorms)): + print_each_rank(f'reducer step [{idx}]: loss: {loss} | gnorm: {gnorm}', rank_only=0) if __name__ == '__main__': - test_reducer() \ No newline at end of file + # test_reducer() + test_reducer_hooks() \ No newline at end of file diff --git a/tests/runtime/test_runtime_flag.py b/tests/runtime/test_runtime_flag.py new file mode 100644 index 00000000..cba0a7ca --- /dev/null +++ b/tests/runtime/test_runtime_flag.py @@ -0,0 +1,210 @@ +""" +example: + +ASYNC_REDUCER=0 USE_ZERO=1 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ + tests/runtime/test_runtime_flag.py +""" + +from typing import List +from functools import partial + +import torch +import random +from torch import nn + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.profiler.timer import print_each_rank +# from cube.tools.debug import DebugTool + +cube.init() + + +class MLP(nn.Module): + def __init__(self, dim, nlayers=16): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + self.param = torch.nn.Parameter(torch.ones([1])) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = x * self.param # for padding test + loss = torch.sum(x) + return loss + + +class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, bs: int, dim: int): + super().__init__(bs, [0]) + self.sample = None + self.dim = dim + self.set_batch_size(bs) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + torch.random.manual_seed(0) + self.batch_size = batch_size + self.sample = torch.randn( + [batch_size, self.dim], dtype=torch.float32, + device=torch.cuda.current_device() + ) + self.sample = (self.sample - 1) * 1e3 + + +def init_model_dataloader(): + batch_size = 4 + dim = 4096 + torch.random.manual_seed(0) + random.seed(0) + model = MLP(dim=dim) + # torch.random.manual_seed(0) + dataloader = MLPDataLoader(batch_size, dim) + return model, dataloader + + +def policy(graph: IRGraph, resource): + + # tensor parallelism + def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, **configs) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + # replicate + def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + devs = list(range(resource.ngpus)) + for node in graph.select(ntype=IRDataOperation): + _replica(graph, node, devs) + for node in graph.select(ntype=IRFwOperation): + _tp(graph, node, devs, idx=0, dim=0, num=resource.ngpus) + # if node.name == 'linear': + # _tp(graph, node, devs, idx=0, dim=0, num=resource.ngpus) + # else: + # _replica(graph, node, devs) + return graph + + +def get_baseline(): + + model, dataloader = init_model_dataloader() + model = model.cuda() + # optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + # optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + + wsz = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + accum_steps = 4 + + niters = 4 + losses = [] + for idx in range(niters): + for _ in range(accum_steps): + loss = train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + losses.append(loss.item()) + + for idx, loss in enumerate(losses): + print_each_rank(f'baseline loss[{idx}]: {loss}', rank_only=0) + + +def test_runtime_accum_mode_v1(): + + model, dataloader = init_model_dataloader() + model = model.cuda() + + @cube.compile(model, dataloader, PAS=policy) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + + model = cube.load_model() + # optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + # optimizer = torch.optim.SGD(model.parameters_for_optimizer(), lr=1e-4) + optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-5) + + accum_steps = 4 + + niters = 4 + losses = [] + for idx in range(niters): + for step in cube.accum_mode.steps(accum_steps): + # print(f'enter step {step}') + loss = train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + model.gather_params() + losses.append(loss.item()) + + for idx, loss in enumerate(losses): + print_each_rank(f'reducer loss[{idx}]: {loss}', rank_only=0) + + +def test_runtime_accum_mode_v2(): + + model, dataloader = init_model_dataloader() + model = model.cuda() + + @cube.compile(model, dataloader, PAS=policy) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + return loss + + model = cube.load_model() + # optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + # optimizer = torch.optim.SGD(model.parameters_for_optimizer(), lr=1e-4) + optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-5) + + accum_steps = 4 + + niters = 4 + losses = [] + for idx in range(niters): + for step in range(accum_steps): + # print(f'enter step {step}') + with cube.accum_mode(start=(step==0), end=(step==accum_steps-1)): + loss = train_iter(model, dataloader) + optimizer.step() + optimizer.zero_grad() + model.gather_params() + losses.append(loss.item()) + + for idx, loss in enumerate(losses): + print_each_rank(f'reducer loss[{idx}]: {loss}', rank_only=0) + + +if __name__ == '__main__': + + get_baseline() + test_runtime_accum_mode_v1() + # test_runtime_accum_mode_v2() \ No newline at end of file From b14e4b5a69e7da52764eb136b664832d2ec7cd87 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 29 Jun 2023 14:20:33 +0800 Subject: [PATCH 1404/1892] migrate concrete tracer from nni --- .../concrete_trace_utils/concrete_proxy.py | 4 +- .../concrete_trace_utils/concrete_tracer.py | 426 ++++++++++-------- .../concrete_trace_utils/operator_patcher.py | 4 +- .../parser/concrete_trace_utils/utils.py | 5 + cube/graph/parser/converter.py | 6 +- 5 files changed, 250 insertions(+), 195 deletions(-) diff --git a/cube/graph/parser/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/concrete_trace_utils/concrete_proxy.py index ae05767a..8706a07c 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_proxy.py @@ -54,7 +54,6 @@ class ConcreteProxy(Proxy): jump_opcodes = _orig_tuple(dis.opmap[name] for name in jump_opnames if name in dis.opmap) op_compare = dis.opmap['COMPARE_OP'] op_extended_arg = dis.opmap['EXTENDED_ARG'] - op_call = dis.opmap['CALL_FUNCTION'] op_call_ex = dis.opmap['CALL_FUNCTION_EX'] op_not = dis.opmap['UNARY_NOT'] op_unpack_sequence = dis.opmap['UNPACK_SEQUENCE'] @@ -170,6 +169,9 @@ def __bool__(self) -> Union[bool, ConcreteProxy]: elif insts[cur].opname == 'CONTAINS_OP': # in executing 'in' return _orig_bool(self.value) + elif insts[cur].opcode == self.op_call_ex: + # in executing func(..., *proxy) + return _orig_bool(self.value) elif insts[cur].opcode == self.op_not: # We cannot return a proxy because 'UNARY_NOT' op will check the type. _logger.warning('please use the function patcher, or use "x = operator.not_(y)" instead of "x = not y",' diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index b79de5cd..197d6c72 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -4,6 +4,7 @@ from __future__ import annotations import collections +import copy import sys import inspect import logging @@ -26,8 +27,9 @@ from torch.fx._compatibility import compatibility from torch.fx._symbolic_trace import _Patcher, _proxyable_classes from torch.fx.graph import Graph -from torch.fx.node import Target, Node +from torch.fx.node import Target, Node, Argument, _side_effectful_functions from torch.fx.proxy import TracerBase +from torch.fx.operator_schemas import check_for_mutable_operation try: # Scope is a new class to record module path in pytorch 2.0 @@ -41,6 +43,39 @@ def __init__(self, module_path: str, module_type: Any): self.module_path = module_path self.module_type = module_type +try: + # comes with Scope + from torch.fx.proxy import ScopeContextManager +except ImportError: + # copy from pytorch 2.0 + @compatibility(is_backward_compatible=False) + class ScopeContextManager: + """ A context manager to track the Scope of Node during symbolic tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. + """ + + def __init__( + self, + scope: Scope, + current_scope: Scope, + ): + super().__init__() + # Keep a copy of prev scope to restore on exit + self._prev_scope = copy.copy(scope) + # Update scope to current scope + scope.module_path = current_scope.module_path + scope.module_type = current_scope.module_type + # Save a reference so we can restore it + self._scope = scope + + def __enter__(self): + return self._scope + + def __exit__(self, *args): + self._scope.module_path = self._prev_scope.module_path + self._scope.module_type = self._prev_scope.module_type + return from . import concrete_proxy as ep from .operator_patcher import OperatorPatcherContext @@ -70,7 +105,9 @@ def __init__(self, module_path: str, module_type: Any): _orig_enumerate, _orig_slice, _orig_reversed, + _orig_torch_size, + _orig_torch_finfo, _orig_len, _orig_not, @@ -82,9 +119,19 @@ def __init__(self, module_path: str, module_type: Any): _orig_all, _orig_min, _orig_max, + + _orig_node_is_impure, ) +# some side effectful functions that should not be deleted during dead code elimination +# there may be more than listed here +extra_side_effectful_functions = { + operator.setitem, + builtins.next, +} +_side_effectful_functions = _side_effectful_functions.union(extra_side_effectful_functions) +# pyright: reportGeneralTypeIssues=false _logger = logging.getLogger(__name__) HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS @@ -205,14 +252,11 @@ class ConcreteTracer(TracerBase): _orig_reversed: ((), False), _orig_torch_size: ((), False), + _orig_torch_finfo: ((), False), } - # add these to record module path information during tracing - current_module_qualified_name : str = '' - node_to_originating_module : Dict[torch.fx.Node, str] = {} - @compatibility(is_backward_compatible=True) - def __init__(self, fake_device_type='cpu'): + def __init__(self, cpu_offload = False): """ similar to _symbolic_trace.Tracer.__init__. remove the 'param_shapes_constant' because we can get real shape when executing. @@ -221,8 +265,7 @@ def __init__(self, fake_device_type='cpu'): self.scope = Scope("", None) self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} - assert fake_device_type in ('cuda', 'cpu') - self.fake_device_type = fake_device_type + self.cpu_offload = cpu_offload @contextmanager def do_temp_disable(self, call=False, attr=False, agfunc_apply=False): @@ -275,123 +318,91 @@ def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: D actually execute the code. apply the patcher, and the _autowrap_check to the target function. """ - if kind == 'call_function': - assert isinstance(target, Callable) - fn = target - if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): - _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - with self.do_temp_disable(call=True): - if self.fake_device_type == 'cpu': - to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - result = OperatorPatcherContext.patch_run(fn, *args, **kwargs) - for arg in args: - if _orig_isinstance(arg, torch.Tensor): - del arg - del args - for key, value in kwargs.items(): - if _orig_isinstance(value, torch.Tensor): - del value - del kwargs - if _orig_isinstance(result, torch.Tensor): - result_cpu = result.cpu() - del result - torch.cuda.empty_cache() - return result_cpu - if not isinstance(result, (tuple, list, dict)): - torch.cuda.empty_cache() - return result - to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t - result_cpu = tree_map(to_cpu, result) - for ret in result: - if _orig_isinstance(ret, torch.Tensor): - del ret - torch.cuda.empty_cache() - return result_cpu - else: - return OperatorPatcherContext.patch_run(fn, *args, **kwargs) - elif kind == 'call_method': - with self.do_temp_disable(call=True): - if self.fake_device_type == 'cpu': - to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - self_obj, *args_tail = args - fn = _orig_getattr(self_obj, target) - if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): - _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - # result = OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) - result = fn(*args_tail, **kwargs) # quick fix from yanjun result = OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) - if _orig_isinstance(result, torch.Tensor): - result_cpu = result.cpu() - del result - torch.cuda.empty_cache() - return result_cpu - if not isinstance(result, (tuple, list, dict)): - torch.cuda.empty_cache() - return result - to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t - result_cpu = tree_map(to_cpu, result) - for ret in result: - if _orig_isinstance(ret, torch.Tensor): - del ret - torch.cuda.empty_cache() - return result_cpu - else: - self_obj, *args_tail = args - fn = _orig_getattr(self_obj, target) - if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'): - _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - return OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs) - elif kind == 'call_module': - assert isinstance(target, str) - mod = self.fetch_attr(target) - if self.fake_device_type == 'cpu': - mod.cuda() - if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(mod, '__globals__'): - _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - with self.do_temp_disable(call=True): - if self.fake_device_type == 'cpu': - to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - result = OperatorPatcherContext.patch_run(mod, *args, **kwargs) - for arg in args: - if _orig_isinstance(arg, torch.Tensor): - del arg - del args - for key, value in kwargs.items(): - if _orig_isinstance(value, torch.Tensor): - del value - del kwargs - mod.cpu() - if _orig_isinstance(result, torch.Tensor): - result_cpu = result.cpu() - del result - torch.cuda.empty_cache() - return result_cpu - if not isinstance(result, (tuple, list, dict)): - torch.cuda.empty_cache() - return result - to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t - result_cpu = tree_map(to_cpu, result) - for ret in result: - if _orig_isinstance(ret, torch.Tensor): - del ret - torch.cuda.empty_cache() - return result_cpu - else: - return OperatorPatcherContext.patch_run(mod, *args, **kwargs) - elif kind == 'get_attr': - assert isinstance(target, str) - return self.fetch_attr(target) - elif kind == 'output': + if kind == 'output': return args[0] elif kind == 'placeholder': return self.placeholder_dict[target] + + to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t + to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t + + def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]): + if self.cpu_offload: + args = tree_map(to_cuda, args) + kwargs = tree_map(to_cuda, kwargs) + + if kind == 'call_function': + assert isinstance(target, Callable) + fn = target + if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \ + and hasattr(fn, '__globals__'): + _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + return OperatorPatcherContext.patch_run(fn, *args, **kwargs) + elif kind == 'call_method': + self_obj, *args_tail = args + fn = _orig_getattr(self_obj, target) + if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \ + and hasattr(fn, '__globals__'): + _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + result = fn(*args_tail, **kwargs) + elif kind == 'call_module': + assert isinstance(target, str) + mod = self.fetch_attr(target) + if self.cpu_offload: + mod.cuda() # how it works in ddp? + if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \ + and hasattr(mod, '__globals__'): + _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) + result = OperatorPatcherContext.patch_run(mod, *args, **kwargs) + if self.cpu_offload: + mod.cpu() + elif kind == 'get_attr': + assert isinstance(target, str) + return self.fetch_attr(target) + else: + raise RuntimeError() + return result + + with self.do_temp_disable(call=True): + result = run(kind, target, args, kwargs) + if self.cpu_offload: + if isinstance(result, torch.Tensor): + result = result.cpu() + elif isinstance(result, (list, dict, tuple)): + result = tree_map(to_cpu, result) + else: + _logger.warning(f"result of target {target} is {type(result)}, which is not a common behavior.") + + torch.cuda.empty_cache() + + self.temp_disable_call = False + return result + + @compatibility(is_backward_compatible=True) + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: + """ + This method is almost the same as the one in `TracerBase` class of Pytorch2.0. + Add it here because this method of Pytorch1.13 and older version + doesn't have the part related to `module_stack` and `node_name_to_scope`. + If we don't add it here, we can not use these two attributes in Pytorch1.13 and older version. + """ + if kind == 'call_function' and self.check_mutable_operations: + check_for_mutable_operation(target, args, kwargs) + + node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) + # TODO node_name_to_scope will be depricated in favor of + # node.meta['nn_module_stack'] + self.node_name_to_scope[node.name] = ( + self.scope.module_path, + self.scope.module_type, + ) + if self.module_stack: + node.meta['nn_module_stack'] = copy.copy(self.module_stack) else: - raise RuntimeError() + node.meta['nn_module_stack'] = collections.OrderedDict() + return node @compatibility(is_backward_compatible=True) def proxy(self, value: Any, node: Node) -> ep.ConcreteProxy: @@ -402,8 +413,8 @@ def proxy(self, value: Any, node: Node) -> ep.ConcreteProxy: @compatibility(is_backward_compatible=True) def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], - name: Optional[str] = None, type_expr: Optional[Any] = None, - proxy_factory_fn: Optional[Callable[[Node], Any]] = None): + name: Optional[str] = None, type_expr: Optional[Any] = None, + proxy_factory_fn: Optional[Callable[[Node], Any]] = None): """ similar to _symbolic_trace.Tracer.create_proxy. use the 'run_target' to actually execute the code, and store the value in 'value' field. @@ -426,7 +437,6 @@ def upwrapper(obj: Any): node = self.create_node(kind, target, args_, kwargs_, name, type_expr) proxy = self.proxy(value_unwrapped, node) - self.node_to_originating_module[proxy.node] = self.current_module_qualified_name return proxy @compatibility(is_backward_compatible=True) @@ -518,7 +528,6 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool """ similar to _symbolic_trace.Tracer.is_leaf_module """ - # return (m.__module__.startswith('torch.nn') and not _orig_isinstance(m, (Sequential, ModuleList, ModuleDict)))\ return (m.__module__.startswith('torch.nn.functional') and not _orig_isinstance(m, (Sequential, ModuleList, ModuleDict)))\ or _orig_isinstance(m, self.leaf_module) @@ -590,10 +599,10 @@ def create_args_for_root(self, root_fn, is_module, concrete_args: Union[Dict[str cnt = 0 self.placeholder_dict = {} arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] - diff_len = len(arg_names) - len(default_value_list) + diff_len = _orig_len(arg_names) - _orig_len(default_value_list) default_args = {arg_names[idx + diff_len]: default_value_list[idx] for idx in range(len(default_value_list))} if isinstance(concrete_args, tuple): - if len(arg_names) != len(concrete_args): + if _orig_len(arg_names) != _orig_len(concrete_args): raise RuntimeError(f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments") concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} def proxy_placeholder(name: str): @@ -759,7 +768,7 @@ def module_getattribute_wrapper(mod, attr): return _orig_module_getattribute(mod, attr) except AttributeError: return _orig_module_getattr(mod, attr) - with self.do_temp_disable(call=True, attr=True): + with self.do_temp_disable(attr=True): try: attr_val = _orig_module_getattribute(mod, attr) except AttributeError: @@ -769,7 +778,9 @@ def module_getattribute_wrapper(mod, attr): return self.wrapped_leaf[attr_val][1] return attr_val elif attr in self.default_module_getattr: - return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) + path = self.the_path_of_middle_class[id(mod)] + path = path + '.' if path else '' + return self.create_proxy('get_attr', f'{path + attr}', (), {}) elif _orig_isinstance(attr_val, (_orig_tuple, _orig_list)): if self.the_path_of_middle_class[id(mod)] == '': return self.create_proxy('get_attr', f'{attr}', (), {}) @@ -786,21 +797,27 @@ def module_call_wrapper(mod, *args, **kwargs): if self.temp_disable_call: return _orig_module_call(mod, *args, **kwargs) else: + # codes below corresponds to symbolic tracer's call_module module_qualified_name = self.path_of_module(mod) - if not self.is_leaf_module(mod, module_qualified_name): - _autowrap_check(self, - mod.forward.__globals__, - self._autowrap_function_ids, - self.autowrap_leaf_pairs, - self.agfunc_dict) - _autowrap_check(self, - mod.__dict__, - self._autowrap_function_ids, - self.autowrap_leaf_pairs, - self.agfunc_dict) - return _orig_module_call(mod, *args, **kwargs) - else: - return self.create_proxy('call_module', module_qualified_name, args, kwargs) + with ScopeContextManager(self.scope, Scope(module_qualified_name, type(mod))) as _scope: + self.module_stack[_scope.module_path] = _scope.module_type + if not self.is_leaf_module(mod, module_qualified_name): + _autowrap_check(self, + mod.forward.__globals__, + self._autowrap_function_ids, + self.autowrap_leaf_pairs, + self.agfunc_dict) + _autowrap_check(self, + mod.__dict__, + self._autowrap_function_ids, + self.autowrap_leaf_pairs, + self.agfunc_dict) + ret_val = _orig_module_call(mod, *args, **kwargs) + else: + ret_val = self.create_proxy('call_module', module_qualified_name, args, kwargs) + key, _ = self.module_stack.popitem(last=True) + assert key == _scope.module_path, f" Unexpected key {key}" + return ret_val class map_wrapper_clz: @functools.wraps(_orig_map) @@ -1400,13 +1417,13 @@ def _retain_weight_consistency(root: torch.nn.Module): for module in root.modules(): for name, param in module.named_parameters(): if _orig_isinstance(param, ep.ConcreteProxy): - param: ep.ConcreteProxy # pyright: reportGeneralTypeIssues=false + param: ep.ConcreteProxy _logger.warning(f'Parameter {name} of {module} is a ConcreteProxy. Some weight may be modified inplace within forward().') setattr(module, name, param.value) _flag |= 1 for name, buffer in module.named_buffers(): if _orig_isinstance(buffer, ep.ConcreteProxy): - buffer: ep.ConcreteProxy # pyright: reportGeneralTypeIssues=false + buffer: ep.ConcreteProxy _logger.warning(f'Buffer {name} of {module} is a ConcreteProxy. Some buffer may be modified inplace within forward().') setattr(module, name, buffer.value) _flag |= 1 @@ -1415,6 +1432,29 @@ def _retain_weight_consistency(root: torch.nn.Module): ' ``concrete_trace`` may not guarantee the consistency of the traced graph.') return root +@functools.wraps(_orig_node_is_impure) +def node_is_impure_wrapper(node): + if node.op in {"placeholder", "output"}: + return True + + if node.op == "call_function": + return node.target in _side_effectful_functions + + if node.op == "call_method": + return node.target.endswith("_") + + if node.op == "call_module": + assert ( + node.graph.owning_module is not None + ), "self.graph.owning_module not set for purity check" + target_mod = node.graph.owning_module.get_submodule(node.target) + assert ( + target_mod is not None + ), f"Did not find expected submodule target {node.target}" + return getattr(target_mod, "_is_impure", False) + + return False + def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Union[Dict[str, Any], Tuple], *, @@ -1426,7 +1466,10 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], autowrap_leaf_class = None, leaf_module: Tuple | None = None, fake_middle_class = None, - fake_device_type='cpu') -> GraphModule: + dce = True, + cpu_offload = False, + trace_twice = False, + ) -> GraphModule: """ Concrete tracing API @@ -1549,64 +1592,71 @@ def f(x, y): The struct of dict is: leaf_class: ([(module_path, module_name)], is_iterator_class). is_iterator_class: Is the class init from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True. + cpu_offload (bool): Whether to offload the module to CPU during tracing. If set to True, the traced code will be executed on GPU, + but is offloaded to CPU afterward. This is useful for reducing memory usage during tracing, but may cause performance issues. + If set to False, there will be no offloading during tracing, but the traced code will be executed on default device. + Returns: fx.GraphModule: a Module created from the recorded operations from ``root``. """ - tracer = ConcreteTracer(fake_device_type=fake_device_type) + tracer = ConcreteTracer(cpu_offload = cpu_offload) + is_training = root.training + root.eval() graph = tracer.trace(root, autowrap_leaf_function = autowrap_leaf_function, autowrap_leaf_class = autowrap_leaf_class, leaf_module = leaf_module, fake_middle_class = fake_middle_class, - concrete_args=concrete_args, - use_operator_patch=use_operator_patch, - operator_patch_backlist=operator_patch_backlist, - forward_function_name=forward_function_name, + concrete_args = concrete_args, + use_operator_patch = use_operator_patch, + operator_patch_backlist = operator_patch_backlist, + forward_function_name = forward_function_name, ) - # graph = tracer.trace(root, - # autowrap_leaf_function = autowrap_leaf_function, - # autowrap_leaf_class = autowrap_leaf_class, - # leaf_module = leaf_module, - # fake_middle_class = fake_middle_class, - # concrete_args=concrete_args, - # use_operator_patch=use_operator_patch, - # operator_patch_backlist=operator_patch_backlist, - # forward_function_name=forward_function_name, - # ) - # graph_check = tracer.trace(root, - # autowrap_leaf_function = autowrap_leaf_function, - # autowrap_leaf_class = autowrap_leaf_class, - # leaf_module = leaf_module, - # fake_middle_class = fake_middle_class, - # concrete_args=concrete_args, - # use_operator_patch=use_operator_patch, - # operator_patch_backlist=operator_patch_backlist, - # forward_function_name=forward_function_name, - # ) - # # compare to check equal - # assert len(graph.nodes) == len(graph_check.nodes), f'number nodes: {len(graph.nodes)} vs {len(graph_check.nodes)}' - # for node_a, node_b in zip(graph.nodes, graph_check.nodes): - # node_a: Node - # node_b: Node - # target_a = node_a.target - # target_b = node_b.target - # if node_a.op == 'get_attr' and node_a.name.startswith('_tensor_constant'): - # assert node_b.op == 'get_attr' and node_b.name.startswith('_tensor_constant') - # assert torch.equal(getattr(root, node_a.name), getattr(root, node_b.name)) - # elif node_a.op == 'call_function' and isinstance(target_a, Callable) and target_a.__name__ == 'apply' and\ - # hasattr(target_a, '__self__') and issubclass(target_a.__self__, torch.autograd.Function): - # assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\ - # hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function) - # else: - # assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}' + + if trace_twice: + graph_check = tracer.trace(root, + autowrap_leaf_function = autowrap_leaf_function, + autowrap_leaf_class = autowrap_leaf_class, + leaf_module = leaf_module, + fake_middle_class = fake_middle_class, + concrete_args = concrete_args, + use_operator_patch = use_operator_patch, + operator_patch_backlist = operator_patch_backlist, + forward_function_name = forward_function_name, + ) + # compare to check equal + assert len(graph.nodes) == len(graph_check.nodes), f'number nodes: {len(graph.nodes)} vs {len(graph_check.nodes)}' + for node_a, node_b in zip(graph.nodes, graph_check.nodes): + node_a: Node + node_b: Node + target_a = node_a.target + target_b = node_b.target + if node_a.op == 'get_attr' and node_a.name.startswith('_tensor_constant'): + assert node_b.op == 'get_attr' and node_b.name.startswith('_tensor_constant') + assert torch.equal(getattr(root, node_a.name), getattr(root, node_b.name)) + elif node_a.op == 'call_function' and isinstance(target_a, Callable) and target_a.__name__ == 'apply' and\ + hasattr(target_a, '__self__') and issubclass(target_a.__self__, torch.autograd.Function): + assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\ + hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function) + else: + assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}' with MagicMethodPatcher(): name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ traced = GraphModule(tracer.root, graph, name) + if dce: + with _Patcher() as patcher: + patcher.patch_method(Node, 'is_impure', node_is_impure_wrapper, deduplicate=False) + traced.graph.eliminate_dead_code() + traced.recompile() # this need to be done in MagicMethodPatcher context + # TODO: better infomation - # # assert root(**concrete_args) == traced(**concrete_args) if check_args is not None: assert root(**check_args) == traced(**check_args) + + if is_training: + root.train() + return traced diff --git a/cube/graph/parser/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/concrete_trace_utils/operator_patcher.py index ef9756cd..7f2b109c 100644 --- a/cube/graph/parser/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/concrete_trace_utils/operator_patcher.py @@ -52,7 +52,7 @@ def visit(self, node): return super().visit(node) def visit_Call(self, node: ast.Call): - if isinstance(node.func, ast.Name) and node.func.id == 'super' and len(node.args) == 0: + if isinstance(node.func, ast.Name) and node.func.id == 'super' and _orig_len(node.args) == 0: return self.generic_visit(ast.Call( func=ast.Name(id='super', ctx=ast.Load()), args=[ @@ -173,6 +173,8 @@ def __init__(self, use_operator_patch: bool, operator_patch_backlist: List[str]) self.function_cache_orig: Dict[int, Callable] = {} def patch_inner(self, func): + if _orig_isinstance(func, torch.nn.Module): + return self.patch_inner_helper(func) # better not cache this if id(func) not in self.function_cache: self.function_cache[id(func)] = self.patch_inner_helper(func) self.function_cache_orig[id(func)] = func diff --git a/cube/graph/parser/concrete_trace_utils/utils.py b/cube/graph/parser/concrete_trace_utils/utils.py index 1d8a48f1..2604340a 100644 --- a/cube/graph/parser/concrete_trace_utils/utils.py +++ b/cube/graph/parser/concrete_trace_utils/utils.py @@ -7,6 +7,7 @@ import functools import torch +from torch.fx import Node # These need to run in global scope to handle nested calls correctly _orig_module_call: Callable = torch.nn.Module.__call__ @@ -34,7 +35,9 @@ _orig_enumerate: Type[Any] = builtins.enumerate _orig_slice: Type[Any] = builtins.slice _orig_reversed: Type[Any] = builtins.reversed + _orig_torch_size: Type[Any] = torch.Size +_orig_torch_finfo: Type[Any] = torch.finfo _orig_len: Callable = builtins.len _orig_not: Callable = operator.not_ @@ -47,6 +50,8 @@ _orig_min: Callable = builtins.min _orig_max: Callable = builtins.max +_orig_node_is_impure: Callable = Node.is_impure + def run_onlyif_instance(cond_type: Type[Any], return_orig: bool = True, return_const: Any = None): def helper(fn): diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 9e2c204d..678106d0 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -53,12 +53,8 @@ def convert_model(model: torch.nn.Module, model, dummy_input, use_operator_patch=True, - autowrap_leaf_class={ - torch.finfo: ((), False), - # type(output_origin): ((), False), - }, leaf_module=leaf_module, - fake_device_type='cpu', + cpu_offload=True, ) else: print('using torchscript tracer') From 77672e484b8f4d3283997edd006b15ded76a26bd Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 3 Jul 2023 19:43:48 +0800 Subject: [PATCH 1405/1892] fix autograd apply in torch 2.0 --- .../concrete_trace_utils/concrete_proxy.py | 4 ++ .../concrete_trace_utils/concrete_tracer.py | 54 +++++++++++-------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/cube/graph/parser/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/concrete_trace_utils/concrete_proxy.py index 8706a07c..e06bf93c 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_proxy.py @@ -188,6 +188,10 @@ def __hash__(self) -> Union[int, ConcreteProxy]: # should only be in dict getitem return hash(self.value) + def __contains__(self, item) -> bool: + # should only be in iterable + return self.value.__contains__(item) + @compatibility(is_backward_compatible=True) def keys(self): # to detect if in executing `**proxy` diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index 197d6c72..71ff1148 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -11,6 +11,7 @@ import operator import functools import builtins +from packaging import version from itertools import chain from types import BuiltinMethodType, FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType @@ -135,6 +136,17 @@ def __exit__(self, *args): _logger = logging.getLogger(__name__) HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS + +def is_autograd_apply(func) -> bool: + # FIXME: version need check + if version.parse(torch.__version__) >= version.parse('2.0'): + return getattr(func, '__name__', None) == 'apply' \ + and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) + else: + return _orig_isinstance(func, BuiltinMethodType) and getattr(func, '__name__', None) == 'apply'\ + and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) + + @compatibility(is_backward_compatible=True) class ConcreteTracer(TracerBase): """ @@ -773,19 +785,24 @@ def module_getattribute_wrapper(mod, attr): attr_val = _orig_module_getattribute(mod, attr) except AttributeError: attr_val = _orig_module_getattr(mod, attr) - if callable(attr_val): + if _orig_isinstance(attr_val, ep.ConcreteProxy): + warn_msg = f'Detected {self.the_path_of_middle_class[id(mod)]}.{attr} is a ConcreteProxy, ' + \ + 'this is usually caused by directly assigning the return value of some leaf function to the attribute of the module. ' + \ + 'Please note that this writing method may cause some trace errors.' + _logger.warning(warn_msg) + if callable(attr_val) and not _orig_isinstance(attr_val, ep.ConcreteProxy): if attr_val in self.wrapped_leaf: return self.wrapped_leaf[attr_val][1] return attr_val - elif attr in self.default_module_getattr: - path = self.the_path_of_middle_class[id(mod)] - path = path + '.' if path else '' - return self.create_proxy('get_attr', f'{path + attr}', (), {}) - elif _orig_isinstance(attr_val, (_orig_tuple, _orig_list)): + # using isinstance instead of _orig_isinstance to judge whether + # the ConcreteProxy.value is the following three types if the attr_val is a ConcreteProxy + elif isinstance(attr_val, (_orig_tuple, _orig_list)): if self.the_path_of_middle_class[id(mod)] == '': return self.create_proxy('get_attr', f'{attr}', (), {}) else: return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) + elif attr in self.default_module_getattr: + return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) elif id(attr_val) in self.the_path_of_parameter: return self.create_proxy('get_attr', self.the_path_of_parameter[id(attr_val)], (), {}) elif id(attr_val) in self.the_path_of_buffer: @@ -932,8 +949,7 @@ def torch_assert_wrapper(condition, message): self.wrapped_leaf = dict() for func, (positions, is_force_trace, to_func) in self.autowrap_leaf_function.items(): - if _orig_isinstance(func, BuiltinMethodType) and getattr(func, '__name__', None) == 'apply'\ - and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function): + if is_autograd_apply(func): # torch.autograd.function assert to_func == None, '.apply should set to_func to None!' if func.__self__ not in self.agfunc_dict: @@ -953,7 +969,8 @@ def torch_assert_wrapper(condition, message): wrapped = _create_wrapped_leaf_func(self, func, to_func) elif _orig_isinstance(func, (MethodDescriptorType, MethodWrapperType)): wrapped = _create_wrapped_leaf_method(self, func, func.__name__, to_func) - elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn': + elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn' \ + and not func.__qualname__.startswith('PyCapsule'): # method if func.__module__.startswith('_') and func.__module__ != '__main__': path = sys.modules[func.__module__[1:]] @@ -1165,8 +1182,7 @@ def _autowrap_check(tracer: ConcreteTracer, frame_dict : Dict[str, Any], functio patcher.patch(frame_dict, name, _create_wrapped_func(value)) elif id(value) in function_pairs: patcher.patch(frame_dict, name, function_pairs[id(value)]) - elif _orig_isinstance(value, BuiltinMethodType) and getattr(value, '__name__', None) == 'apply'\ - and _orig_isinstance(getattr(value, '__self__', None), Type) and issubclass(value.__self__, torch.autograd.Function): + elif is_autograd_apply(value): # torch.autograd.function if value.__self__ not in agfunc_dict: agfunc_dict[value.__self__] = _create_wrapped_leaf_func(tracer, value, value) @@ -1248,15 +1264,14 @@ def copy_attr_new(from_module: torch.nn.Module, to_module: torch.nn.Module, targ def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: name = orig_method.__name__ module = orig_method.__module__ + if is_autograd_apply(orig_method): + # for torch.autograd.Function + return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' if module is not None: return module elif hasattr(orig_method, '__qualname__')\ and isinstance(orig_method.__qualname__, str) and orig_method.__qualname__.startswith('_VariableFunctionsClass.'): return 'torch._C._VariableFunctions' - elif hasattr(orig_method, '__self__')\ - and isinstance(orig_method.__self__, Type) and issubclass(orig_method.__self__, torch.autograd.Function): - # for torch.autograd.Function - return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' for guess in [torch, getattr(torch.nn, 'functional')]: if getattr(guess, name, None) is orig_method: return guess.__name__ @@ -1264,8 +1279,7 @@ def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: @staticmethod def format_import_statement_new(name: str, obj: Any, importer) -> str: - if isinstance(obj, BuiltinMethodType) and getattr(obj, '__name__', None) == 'apply'\ - and isinstance(getattr(obj, '__self__', None), Type) and issubclass(obj.__self__, torch.autograd.Function): # type: ignore + if is_autograd_apply(obj): # type: ignore # torch.autograd.function return MagicMethodPatcher.format_import_statement_ori(name, obj.__self__, importer) + f'\n{name} = {name}.apply' return MagicMethodPatcher.format_import_statement_ori(name, obj, importer) @@ -1635,10 +1649,8 @@ def f(x, y): if node_a.op == 'get_attr' and node_a.name.startswith('_tensor_constant'): assert node_b.op == 'get_attr' and node_b.name.startswith('_tensor_constant') assert torch.equal(getattr(root, node_a.name), getattr(root, node_b.name)) - elif node_a.op == 'call_function' and isinstance(target_a, Callable) and target_a.__name__ == 'apply' and\ - hasattr(target_a, '__self__') and issubclass(target_a.__self__, torch.autograd.Function): - assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\ - hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function) + elif node_a.op == 'call_function' and is_autograd_apply(target_a): + assert node_b.op == 'call_function' and is_autograd_apply(target_b) else: assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}' From 3606c2885f634427319433b50ac168217965ef36 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 5 Jul 2023 05:10:20 +0000 Subject: [PATCH 1406/1892] Merged PR 1646: Support ops introduced by MoE Related work items: #1455 --- cube/graph/function/function.py | 44 ++++++++++++++++++++++++++++++--- cube/graph/parser/mappingfx.py | 5 +++- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 140c0e55..e96e413c 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -11,7 +11,7 @@ from cube.ir.tensor import IRSubTensor, IRFullTensor from cube.ir.dtype import IRDType from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule +from cube.graph.function.dimops import DimAnno, DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D from cube.graph.function.anchor import IRGraphAnchor @@ -211,7 +211,7 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, anno, rules = _get_creator_anno_rules(size, True) dimop = IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) from cube.graph.parser.dtype import DType2IRDType - dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + dimop.output(0).parent.dtype = dtype if isinstance(dtype, IRDType) else DType2IRDType.map(dtype) return dimop @@ -334,6 +334,10 @@ def Expand(input, *sizes, signature = None): return IRDimops(Expand, 'expand', signature, [anno], [input], sizes=sizes) +def ExpandAs(input, other, signature = None): + return Expand(input, *other.shape, signature = signature) + + def Clone(input, *, memory_format=None, signature = None): """ torch.clone(input, *, memory_format=torch.preserve_format) @@ -527,6 +531,10 @@ def Clamp(input, min=None, max=None, *, out=None, signature = None): return IRDimops(Clamp, 'clamp', signature, annos, [input], min=min, max=max) +def ClampMin(input, min, *, out=None, signature = None): + return Clamp(input, min=min, out=out, signature='torch.clamp') + + def Softmax(input, dim=None, _stacklevel=3, dtype=None, signature = None): """ torch.nn.functional.softmax(input, dim=None, _stacklevel=3, dtype=None) @@ -722,6 +730,37 @@ def FusedLayerNorm(input, weight, bias, normalized_shape, eps=1e-5, signature = return IRDimops(FusedLayerNorm, 'fusedlayernorm', signature, [anno], inputs, **kwargs) +def Norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None, signature=None): + assert dtype is None, "Currently Norm only support dtype=None" + einput = ShapeAnno.create_shape_str(input.shape) + eoutput = copy.copy(einput) + kwargs = { + 'p': p, + 'dim': dim, + 'keepdim': keepdim, + 'out': out, + 'dtype': dtype, + } + if dim is None: + einput = [edim + '^' for edim in einput] + anno = OpAnno.create_op_str([einput], ['1']) + return IRDimops(Norm, 'norm', signature, [anno], [input], **kwargs) + else: + dim = (dim,) if isinstance(dim, int) else dim + for dimidx in dim: + einput[dimidx] += '^' + if keepdim: + for dimidx in dim: + eoutput[dimidx] = '1' + else: + sort_dim = list(dim) + sort_dim.sort() + for dimidx in sort_dim[::-1]: + eoutput.pop(dimidx) + anno = OpAnno.create_op_str([einput], [eoutput]) + return IRDimops(Norm, 'norm', signature, [anno], [input], **kwargs) + + def Sum(input, dim=None, keepdim=False, *, dtype=None, signature = None): """ torch.sum(input, *, dtype=None) -> Tensor @@ -1719,7 +1758,6 @@ def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): raise RuntimeError(f'function.To with unknown arg: {dtype_or_device}') - def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: """ _operator.getitem(obj, index: int) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index ba1480f8..62a7b984 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -60,6 +60,7 @@ def exist(signature: str) -> bool: __ttemplate('exp'): function.Exp, __ttemplate('sqrt'): function.Sqrt, __ttemplate('clamp'): function.Clamp, + __ttemplate('clamp_min'): function.ClampMin, __ttemplate('squeeze'): function.Squeeze, __ttemplate('unsqueeze'): function.Unsqueeze, __tttemplate('type_as'): function.TypeAs, @@ -89,6 +90,7 @@ def exist(signature: str) -> bool: __ttemplate('permute'): function.Permute, __ttemplate('transpose'): function.Transpose, __tttemplate('expand'): function.Expand, + __tttemplate('expand_as'): function.ExpandAs, __ttemplate('arange'): function.Arange, __ttemplate('detach'): function.Detach, __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, @@ -102,7 +104,7 @@ def exist(signature: str) -> bool: __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, __ftemplate('nll_loss') : function.NLLLoss, - + 'torch.functional.norm': function.Norm, __ftemplate('layer_norm'): function.LayerNorm, 'apex.normalization.fused_layer_norm.FusedLayerNorm': function.FusedLayerNorm, @@ -116,6 +118,7 @@ def exist(signature: str) -> bool: # # torch nn functional '_operator.matmul': function.Matmul, + 'torch.mm': function.Matmul, __ttemplate('matmul'): function.Matmul, # # __ftemplate('gelu') : function.GeLU, From 43380be05a544a21c5dc4b1b9c70173c1c6b78ab Mon Sep 17 00:00:00 2001 From: nishang Date: Wed, 5 Jul 2023 06:01:43 +0000 Subject: [PATCH 1407/1892] fix comments --- .../concrete_trace_utils/concrete_tracer.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index 71ff1148..7f9e195b 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -138,13 +138,8 @@ def __exit__(self, *args): def is_autograd_apply(func) -> bool: - # FIXME: version need check - if version.parse(torch.__version__) >= version.parse('2.0'): - return getattr(func, '__name__', None) == 'apply' \ - and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) - else: - return _orig_isinstance(func, BuiltinMethodType) and getattr(func, '__name__', None) == 'apply'\ - and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) + return getattr(func, '__name__', None) == 'apply' \ + and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) @compatibility(is_backward_compatible=True) @@ -346,14 +341,14 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] if kind == 'call_function': assert isinstance(target, Callable) fn = target - if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \ + if _orig_getattr(fn, '__module__', None) != 'cube.graph.parser.concrete_trace_utils.concrete_tracer' \ and hasattr(fn, '__globals__'): _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) return OperatorPatcherContext.patch_run(fn, *args, **kwargs) elif kind == 'call_method': self_obj, *args_tail = args fn = _orig_getattr(self_obj, target) - if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \ + if _orig_getattr(fn, '__module__', None) != 'cube.graph.parser.concrete_trace_utils.concrete_tracer' \ and hasattr(fn, '__globals__'): _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) result = fn(*args_tail, **kwargs) @@ -362,7 +357,7 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] mod = self.fetch_attr(target) if self.cpu_offload: mod.cuda() # how it works in ddp? - if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \ + if _orig_getattr(mod, '__module__', None) != 'cube.graph.parser.concrete_trace_utils.concrete_tracer' \ and hasattr(mod, '__globals__'): _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) result = OperatorPatcherContext.patch_run(mod, *args, **kwargs) @@ -1606,10 +1601,14 @@ def f(x, y): The struct of dict is: leaf_class: ([(module_path, module_name)], is_iterator_class). is_iterator_class: Is the class init from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True. + dce (bool): If set to True, dead code eliminatation will be applied on the graph. + cpu_offload (bool): Whether to offload the module to CPU during tracing. If set to True, the traced code will be executed on GPU, but is offloaded to CPU afterward. This is useful for reducing memory usage during tracing, but may cause performance issues. If set to False, there will be no offloading during tracing, but the traced code will be executed on default device. + trace_twice (bool): If set to True, a second trace will be performed, and the two obtained graphs will be checked for consistency. + Returns: fx.GraphModule: a Module created from the recorded operations from ``root``. """ From 5d88498968e49f980429030cde43e5ed416aa567 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jul 2023 02:02:35 +0000 Subject: [PATCH 1408/1892] Merged PR 1635: general dynamic shape support --- cube/codegen/emit.py | 32 +- cube/codegen/frontend_mapping.py | 138 ++++---- cube/codegen/lifecycle.py | 4 +- cube/compiler.py | 72 ++-- cube/flags.py | 1 + cube/graph/function/function.py | 524 +++++++++++------------------- cube/graph/gener/gen.py | 4 +- cube/graph/parser/converter.py | 42 ++- cube/graph/parser/parserfx.py | 241 ++++++-------- cube/graph/segment.py | 6 +- cube/ir/cten.py | 12 +- cube/program.py | 66 ++-- cube/runtime/function/function.py | 17 +- tests/parser/test_fx_ops.py | 3 +- 14 files changed, 505 insertions(+), 657 deletions(-) diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 4a3429dd..9eb5b707 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -115,11 +115,16 @@ def return_name_complex(vals: List[Any], @staticmethod def kwargs_name(**kwargs) -> str: + """Get kwarg name""" names = [] + # FIXME make the str include `""` + # for name, val in kwargs.items(): + # if isinstance(val, str) and not val.startswith('self.'): + # kwargs[name] = '"' + val + '"' + # turn object into name + modifier = lambda t: IRValue(CodeEmission.tensor_name(t)) + kwargs = IRSegment.modify_objects_of_complex(kwargs, modifier) for name, val in kwargs.items(): - # TODO: Ad-hoc patch for amp - if CompileFlag.use_amp and val == 'torch.float32': - val = 'torch.float16' names.append(f'{name}={val}') name = ', '.join(names) return name @@ -133,12 +138,11 @@ def emit_dataloader(node: IRDataOperation) -> List[str]: @staticmethod def emit_fnode(node: IRFwOperation, prefix_attr: str = None) -> List[str]: - """ - Emit the statement to call the op in the forward code - (e.g. in Segments, Adapter or CodeGen.Main) + """Emit forward node code The result will look like (the lines are split into `List[str]`) ``` + # comment if have tensor_3333 = torch.view(tensor_2222, [1,2,3,4,5]) ``` @@ -150,14 +154,18 @@ def emit_fnode(node: IRFwOperation, prefix_attr: str = None) -> List[str]: # insert comment if node.comment is not None: codes.append(f'# {node.comment}') + signature = node.signature + # setup arg string inputs = [FuncEmission.tensor_name(t, prefix_attr=prefix_attr) for t in node.inputs()] - kwargs = {} - for key in node.kwargs: - val = node.kwargs[key] - if isinstance(val, str) and 'self.' not in val: - val = '"' + val + '"' - kwargs[key] = val + # setup kwarg string + kwargs = dict(**node.kwargs) + for name, val in kwargs.items(): + if isinstance(val, str) and not val.startswith('self.'): + kwargs[name] = '"' + val + '"' + # turn IRObject into name + modifier = lambda t: IRValue(CodeEmission.tensor_name(t)) + kwargs = IRSegment.modify_objects_of_complex(kwargs, modifier) emit_rule = Sign2EmitRule.map(signature) body = emit_rule(node, inputs, kwargs) diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index e131d9e6..60b65856 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -1,7 +1,7 @@ # Some operators should be specially handled during codegen to the frontend code, # here we define the customized rule for code emisson. -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional from cube import ir from cube.ir.cten import IRTensor @@ -10,100 +10,96 @@ import torch -# By default, we flatten all args and join them by "," -# this includes ops with a fixed number of parameters like 'add(x,y)', -# or ops allowing multiple parameters at the frontend like 'block_diag(t1,t2' -def _common_rule_join_all(node:IRFwOperation, arg_vars:List[str], kw_pairs:dict) -> str: - signature = node.signature - kw_assigns = list() - for key, val in kw_pairs.items(): - code = f'{key}={val}' - kw_assigns.append(code) +class Sign2EmitRule: + """Emit rule for frontend PyTorch codegen""" - args = ", ".join(arg_vars + kw_assigns) - return f"{signature}({args})" + _sign2rule = {} -def _common_rule_input_as_list(node:IRFwOperation, arg_vars:List[str], kw_pairs:dict) -> str: - signature = node.signature + @staticmethod + def map(signature: str) -> Callable: + """Get the emit rule for the given signature + + Args: + signature (str): signature of the operator - kw_assigns = list() - for key, val in kw_pairs.items(): - code = f'{key}={val}' - kw_assigns.append(code) - - args = ", ".join(arg_vars) - kwargs = ", ".join(kw_assigns) - return f"{signature}([{args}], {kwargs})" + Returns: + Callable: emit rule that takes the node, args (List[str]) and kwargs (Dict[str, str]) as input + """ + return Sign2EmitRule._sign2rule.get(signature, Sign2EmitRule.emit_common) -def emit_slice(node, arg_vars:list, kw_pairs:dict) -> str: - """ - The op is: - aten::slice(input:Tensor, dim:int=0, start:Optional[int]=None, end:Optional[int]=None, step:int=1) -> Tensor - - but at the frontend such an invocation must be rewritten as 'x[:, l:h:s, :, :]' - depending on the 'input's rank and the 'dim' value. - """ - out_tensors : tuple = node.outputs() - assert len(out_tensors) == 1 - out_tensor : IRTensor = out_tensors[0] + @staticmethod + def emit_common(node: IRFwOperation, args: List[str], kwargs: Dict[str, str]) -> str: + """Default rule to join all args and kwargs""" - assert len(arg_vars) == 1 - in_tensor_var : str = arg_vars[0] + signature = node.signature - dim : int = kw_pairs["dim"] - start : Optional[int] = kw_pairs["start"] - end : Optional[int] = kw_pairs["end"] - step : int = kw_pairs["step"] - - rank = len(out_tensor.shape) - subscript_components = [":"] * rank + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) - slice_str = f"{start or ''}:{end or ''}:{step}" - subscript_components[dim] = slice_str + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" - return f"{in_tensor_var}[{', '.join(subscript_components)}]" + @staticmethod + def emit_slice(node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + """Special rule for generating slice node + The op is: + aten::slice(input:Tensor, dim:int=0, start:Optional[int]=None, end:Optional[int]=None, step:int=1) -> Tensor -def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + but at the frontend such an invocation must be rewritten as 'x[:, l:h:s, :, :]' + depending on the 'input's rank and the 'dim' value. + """ + out_tensors : tuple = node.outputs() + assert len(out_tensors) == 1 + out_tensor : IRTensor = out_tensors[0] - assert arg_vars[1].startswith('self.') - member = f'"{arg_vars[1][5:]}"' - return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" + assert len(arg_vars) == 1 + in_tensor_var : str = arg_vars[0] + dim : int = kw_pairs["dim"] + start : Optional[int] = kw_pairs["start"] + end : Optional[int] = kw_pairs["end"] + step : int = kw_pairs["step"] -def emit_getattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: - return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" + rank = len(out_tensor.shape) + subscript_components = [":"] * rank -class Sign2EmitRule: + slice_str = f"{start or ''}:{end or ''}:{step}" + subscript_components[dim] = slice_str + + return f"{in_tensor_var}[{', '.join(subscript_components)}]" @staticmethod - def map(signature:str) -> Callable[[IRFwOperation, List[str], Dict[str, Any]], str]: + def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + """Special rule for generating setattr node """ - The definition of the emit rule is like: - - ``` - def emit_for_lstm_cell(node, arg_vars, kw_pairs) -> str: - x_var, h_var, c_var = arg_vars - return f"lstm({x_var}, [{h_var}, {c_var}], OTHER_ARG_VARS)" - ``` - - 'arg_vars' are inputs (all are Tensor-typed) variable names as string, e.g., ["x", "y"] - 'kw_pairs' are dict whose values has been preprocessed and can be directly stringified, - e.g., {"dim":1, "layout"="nchw"} + + assert arg_vars[1].startswith('self.') + member = f'"{arg_vars[1][5:]}"' + return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" + + @staticmethod + def emit_getattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + """Special rule for generating getattr node """ - return Sign2EmitRule._signMap.get(signature) or _common_rule_join_all + return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" - _signMap = { - 'torch.slice': emit_slice, - 'setattr': emit_setattr, - 'builtins.getattr': emit_getattr, - } +# the registered emit rules +Sign2EmitRule._sign2rule = { + 'torch.slice': Sign2EmitRule.emit_slice, + 'setattr': Sign2EmitRule.emit_setattr, + 'builtins.getattr': Sign2EmitRule.emit_getattr, +} -# The reverse mapping of DType2IRDType in /graph/parser/mapping.py class IRDType2DType: + """ + The reverse mapping of DType2IRDType in /graph/parser/mapping.py + """ @staticmethod def map(ir_dtype:IRDType) -> torch.dtype: diff --git a/cube/codegen/lifecycle.py b/cube/codegen/lifecycle.py index 6fe19680..1767bb38 100644 --- a/cube/codegen/lifecycle.py +++ b/cube/codegen/lifecycle.py @@ -22,7 +22,9 @@ def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: # the tensors can be released given the finish of line id self.release: Dict[int, List[IRObject]] = {} - is_activation = lambda t: isinstance(t, IRObject) and not t.is_attr() + # FIXME: consider the case of IRObject in the kwargs of IRFwOperation + # is_activation = lambda t: isinstance(t, IRObject) and not t.is_attr() + is_activation = lambda t: isinstance(t, IRSubTensor) and not t.is_attr() self.lifetime.update((tsin, 0) for tsin in graph_inputs if is_activation(tsin)) diff --git a/cube/compiler.py b/cube/compiler.py index 4bc38621..19a01783 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -32,53 +32,55 @@ def compile(model: SemanticModel, *args, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, model_dummy_inputs: Tuple[Any] = None, + model_dynamic_shape: bool = False, comm_cost_fn: Optional[Callable] = None, override = True, load_content = True, scale: Union[bool, int] = False) -> Callable: + """Cube compile entry + + Examples: + + ``` + @cube.compile(model, data, PAS=policy) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(data) + loss.backward() + ``` + + Args: + model (SemanticModel | torch.nn.Module): single-device model + args (Tuple[Any]): compile function example inputs + PAS (Callable | Tuple[Callable, Callable, Callable]): policy to transform and schedule graph + model_dummy_inputs (Tuple[Any]): model example inputs when using torch.fx parser + model_dynamic_shape (bool): whether to compile model with dynamic shape + comm_cost_fn (Optional[Callable]): communication cost function, which + takes in an IRAdapterPrim, and outputs a cost in float. By default (None) use + communication volume. + override (bool): If true, the generated code will override exsisting + files (if they are already existed.), otherwise, use the already existed + generated code, i.e., the policy won't take effect. Default true. + load_content (bool): If true, will load parameter from exsiting saved models. + Otherwise, will initial model parameters with empty tensor. + scale (Union[bool, int]): If true, will scale the generated code to the + total launched number of GPUs. If int, will scale to the specified number. + Default False, no scaling. + + Returns: + Callable: compiled training iteration """ - AI Scientist calls like: - - @cube.compile(model, dataloader, PAS=policy) - def train_step(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - ... - - for epoch in range(100): - train_step(model, data_loader) - optimizer.step() - optimizer.zero_grad() - - ... - - @param model SemanticModel: AI Scientist specified SemanticModel - @param args: compile function example inputs - @param PAS Callable: policy to transform and schedule graph - @param model_dummy_inputs Tuple[Any]: model example inputs when using torch.fx parser - @param comm_cost_fn: Optional[Callable]: communication cost function, which - takes in an IRAdapterPrim, and outputs a cost in float. By default (None) use - communication volume. - @param override bool: If true, the generated code will override exsisting - files (if they are already existed.), otherwise, use the already existed - generated code, i.e., the policy won't take effect. Default true. - @param load_content bool: If true, will load parameter from exsiting saved models. - Otherwise, will initial model parameters with empty tensor. - @param scale Union[bool, int]: If true, will scale the generated code to the - total launched number of GPUs. If int, will scale to the specified number. - Default False, no scaling. - - @return sched_fn Callable: the scheduling function loaded from generated code. - """ + # clean global status Program().clear() IDGenerator().clear() assert PAS is not None, f'PAS should be callable function' - model = SemanticModel(model) if isinstance(model, torch.nn.Module) else model + if isinstance(model, torch.nn.Module): + model = SemanticModel(model) assert isinstance(model, SemanticModel), f'Require cube.SemanticModel or torch.nn.Module, but got model: {type(model)}' model.save_content = load_content + model.dynamic_shape = model_dynamic_shape model.dummy_input = model_dummy_inputs dataloader = None diff --git a/cube/flags.py b/cube/flags.py index 35168e26..18cdb3b7 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -20,6 +20,7 @@ class CompileFlag: # ============= loggings =================== log_transform = _to_bool('LOG_TRANSFORM') log_schedule = _to_bool('LOG_SCHEDULE') + log_parser = _to_bool('LOG_PARSER') # ================ compiling ======================== use_torchfx = _to_bool('USE_TORCHFX') # using torch.fx or torchscript as frontend to capture dataflow graph diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index e96e413c..4c53b44f 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -6,6 +6,7 @@ import numpy as np import math import warnings +import functools from cube.ir.cten import IRTensor, IRObject from cube.ir.tensor import IRSubTensor, IRFullTensor @@ -134,23 +135,23 @@ def creator_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: size[dim] = size[dim] // num kwargs['size'] = tuple(size) return kwargs - rules.append(TransformRule([], [DimopSplit.D(dim)], creator_modifier)) - return anno, rules -def CubeArange(start: int, end: int, step: int, dtype=None, - requires_grad=False, signature=None): - signature = 'cube.runtime.function.arange' - size = (math.ceil((end-start)/step),) - # FIXME: torch.jit.script has dtype with int - # from cube.graph.parser.mapping import TorchScalarTypeEnumMap - # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) +def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Union[int, IRObject], + dtype=None, requires_grad=False, signature=None): dtype = dtype if dtype is not None else torch.get_default_dtype() + assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + signature = 'cube.runtime.function.arange' kwargs = {'start': start, 'end': end, 'step': step, 'dtype': dtype, 'requires_grad': requires_grad} - anno, rules = _get_creator_anno_rules(size, False) + start_val = start.value if isinstance(start, IRObject) else start + end_val = end.value if isinstance(end, IRObject) else end + step_val = step.value if isinstance(step, IRObject) else step + size = (math.ceil((end_val-start_val)/step_val),) + anno, rules = _get_creator_anno_rules( + tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), False) dimop = IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) @@ -177,18 +178,16 @@ def Arange(*args, out=None, dtype=None, layout=None, def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): # note: device is ignored - signature = 'cube.runtime.function.empty' - size = (size,) if isinstance(size, int) else tuple(size) - size: Tuple[int] = size + arg_size - assert all(isinstance(dimlen, int) for dimlen in size), f"Empty only supports static size but got {size}" assert layout is None and memory_format is None, f"Not support for non-default memory_format and layout" - # FIXME: torch.jit.script has dtype with int - # from cube.graph.parser.mapping import TorchScalarTypeEnumMap - # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) dtype = dtype if dtype is not None else torch.get_default_dtype() + assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + signature = 'cube.runtime.function.empty' + size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) + size: Tuple[Union[int, IRObject]] = size + arg_size kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} - anno, rules = _get_creator_anno_rules(size, True) + anno, rules = _get_creator_anno_rules( + tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) @@ -198,17 +197,15 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi def Zeros(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): # note: device is ignored - signature = 'cube.runtime.function.zeros' - size = (size,) if isinstance(size, int) else tuple(size) - size: Tuple[int] = size + arg_size - assert all(isinstance(dimlen, int) for dimlen in size), f"Zeros only supports static size but got {size}" assert layout is None, f"Not support for non-default layout" - # FIXME: torch.jit.script has dtype with int - # from cube.graph.parser.mapping import TorchScalarTypeEnumMap - # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) dtype = dtype if dtype is not None else torch.get_default_dtype() + assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + signature = 'cube.runtime.function.zeros' + size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) + size: Tuple[Union[int, IRObject]] = size + arg_size kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} - anno, rules = _get_creator_anno_rules(size, True) + anno, rules = _get_creator_anno_rules( + tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = dtype if isinstance(dtype, IRDType) else DType2IRDType.map(dtype) @@ -218,17 +215,15 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, def Ones(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): # note: device is ignored - signature = 'cube.runtime.function.ones' - size = (size,) if isinstance(size, int) else tuple(size) - size: Tuple[int] = size + arg_size - assert all(isinstance(dimlen, int) for dimlen in size), f"Ones only supports static size but got {size}" assert layout is None, f"Not support for non-default layout" - # FIXME: torch.jit.script has dtype with int - # from cube.graph.parser.mapping import TorchScalarTypeEnumMap - # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) dtype = dtype if dtype is not None else torch.get_default_dtype() + assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + signature = 'cube.runtime.function.ones' + size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) + size: Tuple[Union[int, IRObject]] = size + arg_size kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} - anno, rules = _get_creator_anno_rules(size, True) + anno, rules = _get_creator_anno_rules( + tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) @@ -238,18 +233,16 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): # note: device is ignored - signature = 'cube.runtime.function.rand' - size = (size,) if isinstance(size, int) else tuple(size) - size: Tuple[int] = size + arg_size - assert all(isinstance(dimlen, int) for dimlen in size), f"Rand only supports static size but got {size}" assert layout is None and memory_format is None, f"Not support for non-default memory_format and layout" - # FIXME: torch.jit.script has dtype with int - # from cube.graph.parser.mapping import TorchScalarTypeEnumMap - # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) dtype = dtype if dtype is not None else torch.get_default_dtype() + assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + signature = 'cube.runtime.function.rand' + size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) + size: Tuple[Union[int, IRObject]] = size + arg_size kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} - anno, rules = _get_creator_anno_rules(size, True) + anno, rules = _get_creator_anno_rules( + tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) from cube.graph.parser.dtype import DType2IRDType dimop.output(0).parent.dtype = DType2IRDType.map(dtype) @@ -259,13 +252,10 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir def NewTensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False, signature=None): # note: device is ignored + dtype = dtype if dtype is not None else torch.get_default_dtype() + assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" signature = 'cube.runtime.function.tensor' size = tuple(np.array(data).shape) - assert all(isinstance(dimlen, int) for dimlen in size), f"Ones only supports static size but got {size}" - # FIXME: torch.jit.script has dtype with int - # from cube.graph.parser.mapping import TorchScalarTypeEnumMap - # dtype = TorchScalarTypeEnumMap.map(dtype_underlying) - dtype = dtype if dtype is not None else torch.get_default_dtype() kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(size, True) @@ -372,6 +362,10 @@ def Add(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input + alpha * other + if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): + iv = input.value if isinstance(input, IRObject) else input + ov = other.value if isinstance(other, IRObject) else other + return IRPyFunc(signature, [input, other], [IRObject(name='add', value=iv+ov)]) signature = 'torch.add' annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): @@ -385,6 +379,10 @@ def Sub(input, other, alpha=1, *, out=None, signature = None): signature = 'torch.sub' if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input - alpha * other + if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): + iv = input.value if isinstance(input, IRObject) else input + ov = other.value if isinstance(other, IRObject) else other + return IRPyFunc(signature, [input, other], [IRObject(name='sub', value=iv-ov)]) annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) @@ -396,6 +394,10 @@ def Mul(input, other, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input * other + if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): + iv = input.value if isinstance(input, IRObject) else input + ov = other.value if isinstance(other, IRObject) else other + return IRPyFunc(signature, [input, other], [IRObject(name='mul', value=iv*ov)]) signature = 'torch.mul' annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): @@ -408,6 +410,10 @@ def Div(input, other, *, rounding_mode=None, out=None, signature = None): assert rounding_mode is None and out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input / other + if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): + iv = input.value if isinstance(input, IRObject) else input + ov = other.value if isinstance(other, IRObject) else other + return IRPyFunc(signature, [input, other], [IRObject(name='div', value=iv/ov)]) signature = 'torch.div' annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): @@ -421,8 +427,11 @@ def Exp(input, *, out=None, signature=None): torch.exp(input, *, out=None) """ assert out is None - if not isinstance(input, IRTensor): + if not isinstance(input, IRObject): return torch.exp(input) + if not isinstance(input, IRTensor): + assert input.value is not None + return IRPyFunc(signature, [input], [IRObject(name='exp', value=torch.exp(input.value))]) shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] return IRDimops(Exp, 'exp', signature, annos, [input]) @@ -435,6 +444,9 @@ def Sqrt(input, *, out=None, signature=None): assert out is None if not isinstance(input, IRTensor): return torch.sqrt(input) + if not isinstance(input, IRTensor): + iv = input.value if isinstance(input, IRObject) else input + return IRPyFunc(signature, [input], [IRObject(name='sqrt', value=torch.sqrt(iv))]) shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] return IRDimops(Sqrt, 'sqrt', signature, annos, [input]) @@ -445,7 +457,9 @@ def FloorDiv(input, other, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input // other if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): - return IRPyFunc(signature, [input, other], [IRObject()]) + iv = input.value if isinstance(input, IRObject) else input + ov = other.value if isinstance(other, IRObject) else other + return IRPyFunc(signature, [input, other], [IRObject(name='fdiv', value=iv//ov)]) annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) @@ -453,22 +467,14 @@ def FloorDiv(input, other, *, out=None, signature = None): return IRDimops(FloorDiv, 'floordiv', signature, annos, [input, other]) -def Exp(input, *, out=None, signature=None): - """ - torch.exp(input, *, out=None) - """ - assert out is None - if not isinstance(input, IRTensor): - return torch.exp(input) - shape = ShapeAnno.create_shape_str(input.shape) - annos = [OpAnno.create_op_str([shape], [shape])] - return IRDimops(Exp, 'exp', signature, annos, [input]) - - def Pow(input, exponent, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(exponent, IRObject)): return input ** exponent + if (not isinstance(input, IRTensor)) and (not isinstance(exponent, IRTensor)): + iv = input.value if isinstance(input, IRObject) else input + ev = exponent.value if isinstance(exponent, IRObject) else exponent + return IRPyFunc(signature, [input, exponent], [IRObject(name='pow', value=iv**ev)]) annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(exponent, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, exponent) @@ -479,6 +485,9 @@ def Pow(input, exponent, *, out=None, signature = None): def Neg(input, *, out=None, signature = None): assert out is None if not isinstance(input, IRObject): return -1 * input + if not isinstance(input, IRTensor): + iv = input.value if isinstance(input, IRObject) else input + return IRPyFunc(signature, [input], [IRObject(name='neg', value=-iv)]) annos = ['* -> *'] return IRDimops(Neg, 'neg', signature, annos, [input]) @@ -828,21 +837,24 @@ def Transpose(input, dim0, dim1, signature = None): dim0=dim0, dim1=dim1) -def View(input, size: Tuple[int], *arg_size, signature = None): +def _reshape_anno(in_shape: List[int], ou_shape: List[int], kwarg_name: str) -> Tuple[str, List[TransformRule]]: """ - out = torch.Tensor.view(tensor: torch.Tensor, *size) - """ - size = (size,) if isinstance(size, int) else tuple(size) - size = size + arg_size - assert all([isinstance(dim, int) for dim in size]), \ - f"Expected tensor.view has static int shape but got: {size}" - in_shape, ou_shape = list(input.shape), list(size) + reshape / view annotation and transformation rule generator - # infer -1 + Args: + in_shape List[int]: input shape + ou_shape List[int]: output shape + kwarg_name str: kwarg name of reshape / view op + + Returns: + str: annotation string + List[TransformRule]: transformation rules + """ def nele(shape, nele=1): for dimlen in shape: nele *= dimlen return nele + # infer -1 cnt = nele(in_shape) if -1 in ou_shape: idx = ou_shape.index(-1) @@ -965,194 +977,74 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s if bracket[hdim] not in spatial: bracket[hdim] = str(shape_map[bracket[hdim]]) - # TODO: strange behaviour if every identitifer creates own - # modifier, seems all previous modifiers will be overrided by - # the last one. - def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: + def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: kwargs = dict(**kwargs) identifier = ifirst[dim] oidx = ofirst.index(identifier) - size = list(kwargs['size']) + size = list(kwargs[kwarg_name]) + assert isinstance(size[oidx], int), \ + f'dynamic size cannot be partitioned but got: {size}' size[oidx] = size[oidx] // num - kwargs['size'] = tuple(size) + kwargs[kwarg_name] = tuple(size) return kwargs # special rules: to change output size argument - rules = [] + rules: TransformRule = [] for identifier in spatial: iidx = ifirst.index(identifier) oidx = ofirst.index(identifier) rules.append( - TransformRule([DimopSplit.D(iidx)], [DimopSplit.D(oidx)], view_modifier) + TransformRule([DimopSplit.D(iidx)], [DimopSplit.D(oidx)], modifier) ) anno = OpAnno.create_op_str([in_anno], [ou_anno]) + return anno, rules + + +def View(input, size: Tuple[int], *arg_size, signature = None): + """ + out = torch.Tensor.view(tensor: torch.Tensor, *size) + """ + in_shape = list(input.shape) + if isinstance(size, IRObject): + assert size.value is not None, f"shape should have a reference value but got: {size}" + if isinstance(size.value, int): + size = (size,) + arg_size + ou_shape = [d.value if isinstance(d, IRObject) else d for d in size] + else: # tuple[int] / list[int] + assert len(arg_size) == 0, f"already got a tuple of int shape" + ou_shape = list(size.value) + else: # int / tuple[int] + size = ((size,) if isinstance(size, int) else tuple(size)) + arg_size + ou_shape = [d.value if isinstance(d, IRObject) else d for d in size] + assert all(isinstance(d, int) for d in ou_shape), f"but got {ou_shape}" + + anno, rules = _reshape_anno(in_shape, ou_shape, kwarg_name='size') signature = 'torch.Tensor.view' - return IRDimops(View, 'view', signature, [anno], [input], rules, size=tuple(size)) + return IRDimops(View, 'view', signature, [anno], [input], rules, size=size) def Reshape(input, shape: Tuple[int], *arg_shape, signature = None): """ torch.reshape(Tensor self, int[] shape) -> Tensor """ - - size = (shape,) if isinstance(shape, int) else tuple(shape) - size = size + arg_shape - assert all([isinstance(dim, int) for dim in size]), \ - f"Expected tensor.view has static int shape but got: {size}" - in_shape, ou_shape = list(input.shape), list(size) - - # infer -1 - def nele(shape, nele=1): - for dimlen in shape: nele *= dimlen - return nele - - cnt = nele(in_shape) - if -1 in ou_shape: - idx = ou_shape.index(-1) - ou_shape[idx] = cnt // (-nele(ou_shape)) - assert nele(in_shape) == nele(ou_shape), f"shape mismatch: {in_shape}, {ou_shape}" - - # generate annotation - rest_inshape = [dimlen for dimlen in in_shape] - rest_oushape = [dimlen for dimlen in ou_shape] - chain = [] - can_bucket = True - while len(rest_inshape) != 0 or len(rest_oushape) != 0: - if len(rest_inshape) == 0: - chain = chain + rest_oushape - rest_oushape = [] - elif len(rest_oushape) == 0: - chain = chain + rest_inshape - rest_inshape = [] - else: - dimlen = min(rest_inshape[0], rest_oushape[0]) - if max(rest_inshape[0], rest_oushape[0]) % dimlen == 0: - chain.append(dimlen) - if dimlen == rest_inshape[0]: - rest_inshape.pop(0) - else: - rest_inshape[0] = rest_inshape[0] // dimlen - if dimlen == rest_oushape[0]: - rest_oushape.pop(0) - else: - rest_oushape[0] = rest_oushape[0] // dimlen - else: - can_bucket = False - break - - letters = iter(string.ascii_lowercase) - if can_bucket: - inchain = ouchain = chain - inedims = ouedims = edims = [next(letters) for _ in chain] - else: - inchain, ouchain = in_shape, ou_shape - inedims = [str(dimlen) for dimlen in in_shape] - ouedims = [str(dimlen) for dimlen in ou_shape] - chain = inchain + ouchain - edims = inedims + ouedims - shape_map: Dict[str, int] = {edim: eshape for (edim, eshape) in zip(edims, chain)} - - # generate input and output shape annotations - def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[str]]: - anno = [] - dimidx = 0 - for idx, dimlen in enumerate(shape): - elements, bracket = 1, [] - maxele = len(chain) - dimidx - (len(shape) - 1 - idx) - while True: - if len(bracket) == maxele: - assert elements == dimlen, f"internal match error1: {bracket}" - break - if dimidx >= len(chain) or elements * chain[dimidx] > dimlen: - assert elements == dimlen, f"internal match error2: {bracket}" - break - else: - elements *= chain[dimidx] - bracket.append(edims[dimidx]) - dimidx += 1 - # fetch as many 1^ as possible from tail of the previous bracket - if len(bracket) == 0: - assert dimlen == 1, f"internal match error3: dimlen={dimlen}" - back = 0 - for edim in anno[-1][1:][::-1]: - if chain[edims.index(edim)] != 1: - break - back += 1 - assert back > 0, f"internal match error4: dimlen={dimlen}" - bracket = anno[-1][-back:] - anno[-1] = anno[-1][:-back] - assert len(bracket) > 0, f"got a dimension with no edim" - anno.append(bracket) - return anno - - in_anno = buckets(in_shape, inchain, inedims) - ou_anno = buckets(ou_shape, ouchain, ouedims) - - # postprocess on dimlen == 1 - shape_map['1'] = 1 - for bracket in in_anno + ou_anno: - for subdim, edim in enumerate(bracket): - if shape_map[edim] == 1: - bracket[subdim] = str(shape_map[edim]) - - # find out the axis that can be partitioned - ispatial, ifirst = set(), [] - for bracket in in_anno: - sdim = None - for hdim in range(len(bracket)): - if bracket[hdim] == '1' or shape_map[bracket[hdim]] == 1: continue - sdim = bracket[hdim] - break - if sdim is not None: - ispatial.add(sdim) - ifirst.append(sdim) - - ospatial, ofirst = set(), [] - for bracket in ou_anno: - sdim = None - for hdim in range(len(bracket)): - if bracket[hdim] == '1' or shape_map[bracket[hdim]] == 1: continue - sdim = bracket[hdim] - break - if sdim is not None: - ospatial.add(sdim) - ofirst.append(sdim) - - # intersection for spatial partitioned dimensions - spatial = ispatial.intersection(ospatial) - - # set dimension cannot be partitioned - for bracket in in_anno + ou_anno: - for hdim in range(len(bracket)): - if bracket[hdim] not in spatial: - bracket[hdim] = str(shape_map[bracket[hdim]]) - - # TODO: strange behaviour if every identitifer creates own - # modifier, seems all previous modifiers will be overrided by - # the last one. - def view_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: - kwargs = dict(**kwargs) - identifier = ifirst[dim] - oidx = ofirst.index(identifier) - size = list(kwargs['shape']) - size[oidx] = size[oidx] // num - kwargs['shape'] = tuple(size) - return kwargs - - # special rules: to change output size argument - rules = [] - for identifier in spatial: - iidx = ifirst.index(identifier) - oidx = ofirst.index(identifier) - rules.append( - TransformRule([DimopSplit.D(iidx)], [DimopSplit.D(oidx)], view_modifier) - ) - - anno = OpAnno.create_op_str([in_anno], [ou_anno]) - - new_signature = 'torch.Tensor.reshape' - return IRDimops(Reshape, 'shape', new_signature, [anno], [input], rules, shape=tuple(size)) + in_shape = list(input.shape) + if isinstance(shape, IRObject): + assert shape.value is not None, f"shape should have a reference value but got: {shape}" + if isinstance(shape.value, int): + shape = (shape,) + arg_shape + ou_shape = [d.value if isinstance(d, IRObject) else d for d in shape] + else: # tuple[int] / list[int] + assert len(arg_shape) == 0, f"already got a tuple of int shape" + ou_shape = list(shape.value) + else: # int / tuple[int] + shape = ((shape,) if isinstance(shape, int) else tuple(shape)) + arg_shape + ou_shape = [d.value if isinstance(d, IRObject) else d for d in shape] + assert all(isinstance(d, int) for d in ou_shape), f"but got {ou_shape}" + + anno, rules = _reshape_anno(in_shape, ou_shape, kwarg_name='shape') + signature = 'torch.Tensor.reshape' + return IRDimops(Reshape, 'reshape', signature, [anno], [input], rules, shape=shape) def Permute(input, dims: Tuple[int], *arg_dims, signature = None): @@ -1430,40 +1322,46 @@ def IndexSelect(input: torch.Tensor, dim: int, index: torch.Tensor, *, out=None, return CubeIndexSelect(input, index, dim, signature=signature) -def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice]], signature=None): +def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice, int]], signature=None): """ - subtensor = tensor[:,128:] - subtensor = tensor[0,128:] - subtensor = tensor[0] + Examples: + >>> a = torch.randn((4,2)) + >>> a[(2,)], a[2] # shape [2] + >>> a[2:3], a[2:3,:] # shape [1,2] + >>> a[(2, slice(None, None, None))] # shape [2] + >>> a[(2, None)] # shape [1,2] + >>> a[(2, slice(None, None, None)), None] # shape [2,1] + >>> a[(2, None, slice(None, None, None))] # shape [1,2] """ signature = 'cube.runtime.function.fullslice' - slicers = tuple(slicers) + (None,) * (len(tensor.shape) - len(slicers)) + slicers = tuple(slicers) edim_in = ShapeAnno.create_shape_str(tensor.shape) edim_ou = [] - for dim, slicer in enumerate(slicers): + in_idx = 0 + for slicer in slicers: if slicer is None: - if dim < len(edim_in): - edim_ou.append(edim_in[dim]) - else: - # expand the dimension - edim_ou.append('1') - else: + edim_ou.append('1') + elif isinstance(slicer, int): + edim_in[in_idx] += '^' + in_idx += 1 + elif isinstance(slicer, slice): if slicer != slice(None, None, None): - edim_in[dim] += '^' - if isinstance(slicer, slice): - stop = tensor.shape[dim] if slicer.stop is None else slicer.stop - start = 0 if slicer.start is None else slicer.start - step = 1 if slicer.step is None else slicer.step - dimlen = len(range(start, stop, step)) - if dimlen == tensor.shape[dim]: - edim_ou.append(edim_in[dim]) - else: - edim_ou.append(str(dimlen)) + edim_in[in_idx] += '^' + start = 0 if slicer.start is None else slicer.start + stop = tensor.shape[in_idx] if slicer.stop is None else slicer.stop + step = 1 if slicer.step is None else slicer.step + dimlen = len(range(start, stop, step)) + if dimlen == tensor.shape[in_idx]: + edim_ou.append(edim_in[in_idx]) else: - pass # no shape for int - # special case for loss = torch.Tensor([1,2,3])[0] + edim_ou.append(str(dimlen)) + in_idx += 1 + else: + raise RuntimeError(f"Unsupported slicer {slicer}") + edim_ou += edim_in[in_idx:] + # special case for scalar = torch.Tensor([1,2,3])[0] if len(edim_ou) == 0: - edim_ou = ['1^'] + edim_ou.append('1') anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(FullSlice, 'fullslice', signature, [anno], [tensor], slicers=slicers) @@ -1714,16 +1612,16 @@ def ShapeAsTensor(input: IRTensor, signature = None): """ torch._shape_as_tensor """ - if isinstance(input.shape, list) and all(isinstance(dim, int) for dim in input.shape): - return input.shape - + warnings.warn('shape_as_tensor is interpreted as an IRPyFunc' + ' and generate an IRObject instead of IRTensor') + signature = 'torch._shape_as_tensor' + return IRPyFunc(signature, [input], [IRObject(name='shape', value=input.shape)]) edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = [str(len(input.shape))] anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(ShapeAsTensor, '_shape_as_tensor', signature, [anno], [input]) - # ================== Non-autograd Function Space ================= def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: @@ -1731,10 +1629,12 @@ def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: torch.Tensor.size(tensor, dim=None) """ assert isinstance(tensor, IRTensor) - # constant - if all(isinstance(dimlen, int) for dimlen in tensor.shape) and not isinstance(dim, IRObject): - return tensor.shape[dim] if isinstance(dim, int) else list(tensor.shape) - return IRPyFunc(signature, [tensor, dim], [IRObject()]) + val = tensor.shape[dim] if isinstance(dim, int) else list(tensor.shape) + assert val is not None + if dim is None: + return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)]) + else: + return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)], dim=dim) def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): @@ -1763,53 +1663,17 @@ def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: _operator.getitem(obj, index: int) """ obj, index = a, b - - def try_reshape(tensor, expr): - if expr[0] == Ellipsis and expr[1] == None: - return True, copy.copy(tensor.shape) + [1] - dim_cnt = 0 - idx = 0 - dst_shape = [] - for item in expr: - if item == slice(None, None, None): - dst_shape.append(tensor.shape[idx]) - idx += 1 - elif item == None: - dst_shape.append(1) - else: - return False, [] - if idx != len(tensor.shape): - return False, [] - return True, dst_shape - - def try_select(tensor, expr): - int_cnt = 0 - dim = -1 - val = -1 - for i, item in enumerate(expr): - if isinstance(item, int): - int_cnt += 1 - dim = i - val = item - if int_cnt != 1: - return False, -1, -1 - if expr[0] == Ellipsis: - dim = dim - len(expr) - return True, dim, val - + # tensor slice if isinstance(obj, IRTensor): - is_reshape, dst_shape = try_reshape(obj, index) - if is_reshape: - return Reshape(obj, dst_shape, signature='torch.reshape') - is_select, dim, val = try_select(obj, index) - if is_select: - return Select(obj, dim, val, 'torch.select') - # case: subtensor = tensor[1,:2] - return FullSlice(obj, b) - # assert False, f'{obj}, {index}' - elif (not isinstance(obj, IRObject)) and isinstance(index, int): - return obj[index] - return IRPyFunc(signature, [obj, index], [IRObject()]) + # note `None` will always + index = (index,) if isinstance(index, int) else tuple(index) + return FullSlice(obj, index) + # object slice + if isinstance(obj, IRObject): + assert obj.value is not None + out = IRObject(name='getitem', value=obj.value[index]) + return IRPyFunc(signature, [obj, index], [out]) + return obj[index] def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], IRPyFunc]: @@ -1819,19 +1683,23 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], have instantiated object or the attr is not simple value. """ obj, name = instance, field - if name in ('shape', 'dtype'): - assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - assert hasattr(obj, name), f"attr {name} is not existed in {obj}" - return getattr(obj, name) - elif name == 'device': - assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - # FIXME: this is hack, IRFullTensor does not have attribute "device" - return torch.device('cpu') - elif isinstance(obj, torch.finfo): + if isinstance(obj, IRTensor): + if name == 'shape': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + shape = IRObject('shape', value=obj.shape) + return IRPyFunc(signature, [instance, field], [shape]) + if name == 'dtype': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + assert hasattr(obj, name), f"attr {name} is not existed in {obj}" + return getattr(obj, name) + if name == 'device': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + # FIXME: this is hack, IRFullTensor does not have attribute "device" + return torch.device('cpu') + if isinstance(obj, torch.finfo): return getattr(obj, name) - else: - # FIXME: is it right? - return IRPyFunc(signature, [instance, field], [IRObject()]) + return IRPyFunc(signature, [instance, field], [IRObject()]) + def FInfo(dtype: IRDType, signature = None) -> torch.finfo: assert isinstance(dtype, IRDType) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 77cd21af..23cfc806 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -110,10 +110,10 @@ def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: graph._reorder_producer_consumer() # remove anchor node graph = IRAdapterGener.remove_anchor(graph) - # automatic replace pyfunc - graph = IRAdapterGener.auto_pyfunc(graph) # automatic transform multiref graph = IRAdapterGener.autoref(graph) + # automatic replace pyfunc + graph = IRAdapterGener.auto_pyfunc(graph) # generate adapters for activation graph = IRAdapterGener.gen_activation(graph, cost_fn=cost_fn) # generate weight reducer diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 9e2c204d..8fc47fa2 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,4 +1,5 @@ from typing import Optional, List +import warnings from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser @@ -19,22 +20,25 @@ def convert_model(model: torch.nn.Module, input_shapes: Optional[ List[List[int],] ] = None, dummy_input = None, - save_content: bool = True) -> IRGraph: + save_content: bool = True, + dynamic_shape: bool = False) -> IRGraph: """ Convert torch.nn.Module based model into IRGraph """ try: if CompileFlag.use_torchfx: if CompileFlag.use_default_fx_tracer: - print('INFO: using torch.fx tracer') - from torch.fx import symbolic_trace + if CompileFlag.log_parser: + print('> use default torch.fx tracer') # Symbolic tracing frontend - captures the semantics of the module tracer = FxFuncOpTracer() traced_graph: torch.fx.Graph = tracer.trace(model) - smodule: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) - smodule.graph.print_tabular() + traced_model: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) + if CompileFlag.log_parser: + traced_model.graph.print_tabular() else: - print('INFO: using concrete tracer') + if CompileFlag.log_parser: + print('> use concrete torch.fx tracer') if HAS_APEX: leaf_module = ( # torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, @@ -61,26 +65,28 @@ def convert_model(model: torch.nn.Module, fake_device_type='cpu', ) else: - print('using torchscript tracer') - smodule = torch.jit.script(model) + if CompileFlag.log_parser: + print('> use default torch.jit.script tracer') + traced_model = torch.jit.script(model) except Exception as ex: print(ex) raise RuntimeError("Cannot convert module into torchscript/torch.fx module.") if CompileFlag.use_torchfx: - if not dummy_input: - FxModuleParser.save_content = save_content - inputs, nodes, outputs = FxModuleParser.parse(smodule, input_shapes) - module_name = model.__class__.__name__ - else: - FxModuleParser.save_content = save_content - inputs, nodes, outputs = FxModuleParser.parse(traced_model, input_shapes=None, dummy_inputs=dummy_input) - module_name = model.__class__.__name__ + FxModuleParser.save_content = save_content + FxModuleParser.dynamic_shape = dynamic_shape + if CompileFlag.log_parser: + print(f"> use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") + inputs, nodes, outputs = FxModuleParser.parse(traced_model, dummy_input) + module_name = model.__class__.__name__ else: + if dynamic_shape: + warnings.warn('dynamic shape is not supported in torch.jit.script', + category=RuntimeWarning) ScriptModuleParser.save_content = save_content - inputs, nodes, outputs = ScriptModuleParser.parse_module(smodule, input_shapes) - module_name = smodule.original_name + inputs, nodes, outputs = ScriptModuleParser.parse_module(traced_model, input_shapes) + module_name = traced_model.original_name for input in inputs: if isinstance(input, IRFullTensor): diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 43e7b64b..3de7267e 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -1,7 +1,7 @@ import torch import enum -import re -from typing import Any, List, Tuple, Optional, Callable, Union +import warnings +from typing import Any, List, Tuple, Optional, Callable, Union, Dict from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor @@ -10,6 +10,9 @@ from cube.graph.parser.dtype import DType2IRDType from cube.graph.parser.mappingfx import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc +from cube.graph.function.dimops import IRDimops + +from cube.flags import CompileFlag import torch.fx @@ -59,119 +62,85 @@ def get_complex_data(val: Any, frame: Frame) -> Any: class FxModuleParser: + """torch.fx module parser + + Attributes: + save_content (bool): whether to save the content of the module + dynamic_shape (bool): whether to parse the module with dynamic shape + """ save_content: bool = True + dynamic_shape: bool = False + @staticmethod def shape_refine(shape: torch.Size) -> torch.Size: + """Replacing scale shape [] to [1] + + Args: + shape (torch.Size): tensor shape + + Returns: + torch.Size: refined shape """ - replacing scale shape [] to [1] - :param shape: - :return: - """ - # TODO update return torch.Size([1]) if shape == torch.Size([]) else shape @staticmethod def parse(module: torch.fx.GraphModule, - input_shapes: Optional[Tuple[List[int],]] = None, - dummy_inputs = None, + dummy_inputs: Dict[str, Any], frame: Frame = None) \ -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: - from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler - DCEHandler(module).eliminate_dead_code() + """Parse torch.fx module into cube IR - """ The overall entry to parse a torch.fx graph module """ + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp + DCEHandler(module).eliminate_dead_code() + frame = frame if frame is not None else Frame() frame.push_var() frame.push_attr() - inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] - print(f'inputs = {inputs}') + assert isinstance(dummy_inputs, dict), "Expected dummy inputs to parse module" - if input_shapes is not None and len(input_shapes) != len(inputs): - print(f'WARNING input shape mismatch (got {len(input_shapes)} != {len(inputs)})') - if len(input_shapes) < len(inputs): - raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") - else: - input_shapes = input_shapes[:len(inputs)] - print(f'WARNING input_shapes shrinked to {input_shapes})') + inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] + if CompileFlag.log_parser: + print(f'> torch.fx parser: graph inputs: {inputs}') - default_dtype = torch.get_default_dtype() - if input_shapes is not None: - # shape propagation - sample_inputs = dummy_inputs if dummy_inputs else [torch.ones(shape, dtype=default_dtype) for shape in input_shapes] - sample_input_tensors = [sample_inputs[input] for input in sample_inputs] if type(sample_inputs) is dict else sample_inputs - - # from torch.fx.passes.shape_prop import ShapeProp - # ShapeProp(module).propagate(*sample_input_tensors) # TODO fixme ShapeProp(module).propagate(*sample_inputs) - from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp - ShapeProp(module).propagate(sample_inputs) - - # handle graph inputs - for idx, input in enumerate(inputs): - assert isinstance(input, torch.fx.Node) - if 'tensor_meta' in input.meta: # tensor type - shape = None if len(input_shapes) <= idx else input_shapes[idx] - if shape is not None and len(shape) == 0: - shape = [1] - dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) - else: + # shape propagation + ShapeProp(module).propagate(dummy_inputs) + # handle graph inputs + for idx, input in enumerate(inputs): + assert isinstance(input, torch.fx.Node) + # dealing with different types of dummy_inputs + if isinstance(dummy_inputs, dict): + if input.name not in dummy_inputs: val = IRObject(input.name) - frame.add_var(input.name, val, graph_arg=idx) - else: - assert dummy_inputs is not None, 'input_shapes and dummy_inputs cannot be None at the same time.' - # shape propagation - from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp - ShapeProp(module).propagate(dummy_inputs) - # handle graph inputs - for idx, input in enumerate(inputs): - assert isinstance(input, torch.fx.Node) - # dealing with different types of dummy_inputs - if isinstance(dummy_inputs, dict): - if input.name not in dummy_inputs: - val = IRObject(input.name) - else: - if 'tensor_meta' in input.meta: - shape = input.meta['tensor_meta'].shape - if len(shape) == 0: - shape = [1] - dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) - else: - val = IRObject(input.name) else: - # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, - # extend to other input types - if hasattr(dummy_inputs, input.name): - print(f'dummy_inputs has {input.name}') - shape = getattr(dummy_inputs, input.name).size() + if 'tensor_meta' in input.meta: + shape = input.meta['tensor_meta'].shape + if len(shape) == 0: + shape = [1] + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) else: - # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name - print(f'dummy_inputs does not have {input.name}') - shape = None - dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) - frame.add_var(input.name, val, graph_arg=idx) - input_val = [frame.get_var(input.name) for input in inputs] - - for node in module.graph.nodes: - if 'tensor_meta' in node.meta: - if node.meta['type'] is type(tuple()): - print(f'{node.name} is tuple type') - elif node.meta['type'] is type(torch.fx.immutable_collections.immutable_dict()): - print(f'{node.name} is immutable_dict type') - assert isinstance(node.meta['tensor_meta'], dict) - else: - if node.meta['type'] is type(torch.Tensor()) or node.meta['type'] is type(torch.nn.parameter.Parameter()): - print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape) - else: - print(f'WARNING node {node.name} is neither Tensor nor Parameter') + val = IRObject(input.name) else: - print(f'{node.name} does not has tensor_meta') + # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, + # extend to other input types + if hasattr(dummy_inputs, input.name): + # print(f'dummy_inputs has {input.name}') + shape = getattr(dummy_inputs, input.name).size() + else: + # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name + # print(f'dummy_inputs does not have {input.name}') + shape = None + dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + frame.add_var(input.name, val, graph_arg=idx) + + input_val = [frame.get_var(input.name) for input in inputs] # add activations to frame, including call_func/call_method output and final output # call_module corresponds to leaf torch.nn.module @@ -192,11 +161,9 @@ def parse(module: torch.fx.GraphModule, # handle nodes all_ir_nodes: List[IRFwOperation] = list() total_node_num = len(module.graph.nodes) - node_idx = 1 - for node in module.graph.nodes: - print(f'[{node_idx}/{total_node_num}]') - node_idx += 1 - + for nidx, node in enumerate(module.graph.nodes): + if CompileFlag.log_parser: + print(f'> torch.fx parser: [{nidx}/{total_node_num}] parsing node {node}...', flush=True) ir_nodes = FxModuleParser.parse_node(node, module, frame) if ir_nodes is not None: all_ir_nodes += ir_nodes @@ -273,13 +240,7 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: input_vals = [get_complex_data(val, frame) for val in node.args] kwargs = {key: get_complex_data(val, frame) for key, val in node.kwargs.items()} if prim_module.__class__.__module__.startswith('torch.nn.modules'): - assert False, 'Dropout is not supposed to be treated as module.' - fsig = 'torch.nn.{}'.format(prim_module.__class__.__name__) - # specifically deal with torch.nn.Dropout, because some inputs of nn.module are passed - # in module instantiating phase, besides during forward - assert prim_module.__class__.__name__ == 'Dropout', f'{prim_module.__class__.__name__}, {fsig}' - kwargs.update({'p': prim_module.p, 'inplace': prim_module.inplace}) - return FxModuleParser._parse_node(fsig, node, input_vals, kwargs, frame) + assert False, 'torch.nn.modules can not be parsed as leaf nodes' elif prim_module.__class__.__module__ == 'apex.normalization.fused_layer_norm': fsig = '{}.{}'.format(prim_module.__class__.__module__, prim_module.__class__.__name__) assert prim_module.elementwise_affine is True @@ -317,10 +278,6 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: # get signature fsig = FxModuleParser._get_qualified_name(node.target, node) - if isinstance(node.target, str): - print(f'parse_prim_method_node: {fsig}') - else: - print(f'parse_prim_function_node: {fsig}') # get inputs input_vals = [get_complex_data(val, frame) for val in node.args] @@ -337,32 +294,51 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): - print(f'>>> Find unknown pytorch operation: {fsig}') + warnings.warn(f'Find unknown pytorch operation: {fsig}', + category=RuntimeWarning) fname = fsig.split('.')[-1] if '.' in fsig else fsig ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) # case2: python runtime function else: - print(f'>>> Set python runtime function: {fsig}') - ir_node = IRPyFunc(fsig, input_vals, [None], **kwargs) + warnings.warn(f'Set python runtime function: {fsig}', + category=RuntimeWarning) + ir_node = IRPyFunc(fsig, input_vals, [IRObject()], **kwargs) + ir_nodes = [] if isinstance(ir_node, IRCell): - # TODO gracefully set output + ir_nodes.append(ir_node) if len(ir_node.outputs()) > 1: - # REMARK: some nodes will return multiple outputs, e.g., torch.chunk, - # while torch.fx always return one output. This will cause - # getitem or unpacking operation on the output, which can be folded by - # setting the list of the output tensor - print('>> parsing {ir_node}') + # REMARK: some nodes will return multiple outputs, e.g., torch.chunk, while torch.fx always + # return one output. This will cause `getitem`` or `unpacking`` operation on the output, + # which can be folded by setting the list of the output tensor ir_node.infer_shape() ir_node.infer_dtype() frame.set_var(node.name, ir_node.outputs()) + elif ir_node.output(0).value is not None: + if FxModuleParser.dynamic_shape: + frame.set_var(node.name, ir_node.output(0)) + ir_node.output(0).name = node.name + else: + # if use static shape graph, all IRObject will be converted to real traced value. + # the ir_node will be folded and not appeared in the final graph + frame.set_var(node.name, ir_node.output(0).value) + ir_nodes.pop(-1) else: output_val = frame.get_var(node.name) + if isinstance(ir_node, IRDimops): + ir_node.infer_shape() + assert output_val.shape == ir_node.output(0).shape, ( + f'find shape inference not match: {output_val.shape} vs {ir_node.output(0).shape}' + f'\nnode: {node}' + ) ir_node.set_output(0, output_val) - return [ir_node] else: frame.set_var(node.name, ir_node) - return [] + + if CompileFlag.log_parser: + print(f'parsing result: {ir_node}', flush=True) + + return ir_nodes @staticmethod def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: @@ -381,7 +357,6 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram # check if existing param if requires_grad and frame.has_attr_value(node.target): # existing param prev_tensor_name = frame.get_attr_key(node.target) - print(f'INFO: link {tensor_name} to existing param {prev_tensor_name}') frame.add_var(tensor_name, frame.get_var(prev_tensor_name)) else: # new param / activation ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=requires_grad, dtype=dtype) @@ -403,31 +378,6 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 ir_nodes = [] - - # handle complex outputs - # def generate_outputs(val: Any, _ops: List) -> IRObject: - # """Support complex data type of List, Tuple, Dict, Tensor/Object""" - # if isinstance(val, list): - # inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) - # output = IRObject() - # _ops.append(IRPyFunc('(lambda *args: list(args))', inputs, [output])) - # return output - # if isinstance(val, tuple): - # inputs = tuple(generate_outputs(sub_node, _ops) for sub_node in val) - # output = IRObject() - # _ops.append(IRPyFunc('(lambda *args: args)', inputs, [output])) - # return output - # if isinstance(val, dict): - # output = IRObject() - # assert all(not isinstance(key, torch.fx.Node) for key in val.keys()), f"output dict cannot have torch.fx.Node is key" - # keys = tuple(str(key) for key in val.keys()) - # values = generate_outputs(tuple(generate_outputs(value, _ops) for value in val.values()), _ops) - # _ops.append(IRPyFunc('(lambda vals, keys: {key:val for key,val in zip(keys,vals)})', [values], [output], keys=keys)) - # return output - # if isinstance(val, torch.fx.Node): - # return frame.get_var(val.name) - # return val - # output = generate_outputs(node.args[0], ir_nodes) def generate_outputs(val: Any) -> Any: """Support complex data type of List, Tuple, Dict, Tensor/Object""" @@ -443,11 +393,6 @@ def generate_outputs(val: Any) -> Any: return val output = generate_outputs(node.args[0]) - # # TODO: support more complex data type - # outs = (node.args[0],) if isinstance(node.args[0], torch.fx.Node) else node.args[0] - # assert all(isinstance(t, torch.fx.Node) for t in outs), "Only support model return with tuple of " - # output = [frame.get_var(t.name) for t in outs] - frame.set_var(node.name, output) return ir_nodes diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 944ed326..f5e5704c 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -948,7 +948,7 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: sub_cids = set(node.cid for node in nodes) for node in nodes: for itensor in node.inputs(): - if not isinstance(itensor, IRTensor): continue + if not isinstance(itensor, IRObject): continue if itensor.is_attr(): continue producers, ptensors = self.producers(itensor.parent), self.ptensors(itensor.parent) pids = set(p.cid for p, t in zip(producers, ptensors) if dmatch(t, itensor)) @@ -959,9 +959,9 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: if all(pid not in sub_cids for pid in pids): inputs.add(itensor) for otensor in node.outputs(): - if not isinstance(otensor, IRTensor): continue + if not isinstance(otensor, IRObject): continue # if the tensor is required by segment outputs or is loss during train, set as output - if otensor.is_loss() or otensor in segment_outputs: + if (isinstance(otensor, IRSubTensor) and otensor.is_loss()) or otensor in segment_outputs: outputs.add(otensor) continue consumers, ctensors = self.consumers(otensor.parent), self.ctensors(otensor.parent) diff --git a/cube/ir/cten.py b/cube/ir/cten.py index e7c0c1ff..a73f3c5d 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -433,7 +433,7 @@ class IRObject: IRObject serves as general data of IRGraph edge """ - def __init__(self, name: Optional[str] = None, tid: Optional[int] = None): + def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: Optional[None] = None): """ @param name str: object name @param tid int: object unique id @@ -442,6 +442,7 @@ def __init__(self, name: Optional[str] = None, tid: Optional[int] = None): self.name: str = name if name else 'obj' self._cell: Optional[IRCell] = None self._is_attr: bool = False + self._value: Optional[Any] = value def __eq__(self, obj): if not isinstance(obj, IRObject): @@ -494,6 +495,11 @@ def parent(self): """Get parent""" return self + @property + def value(self) -> Any: + """Get example value""" + return self._value + def __eq__(self, obj) -> bool: if not isinstance(obj, IRObject): return False @@ -501,7 +507,7 @@ def __eq__(self, obj) -> bool: def __copy__(self): """Copy this object but remove the cell information""" - return IRObject(self.name, self._id) + return IRObject(self.name, self._id, self._value) def as_attr(self): """ @@ -528,7 +534,7 @@ def overlap(self, other: Any) -> bool: return False def __repr__(self): - return f'Object({self.name}{self.tid})' + return f'Object({self.name}{self.tid}, val={self.value})' class IRTensor(IRObject): diff --git a/cube/program.py b/cube/program.py index d9dc4595..504f741b 100644 --- a/cube/program.py +++ b/cube/program.py @@ -142,35 +142,33 @@ def generate_output(sample): class SemanticModel: - def __init__(self, model: Optional[torch.nn.Module], input_shapes=None, dummy_input=None): + def __init__(self, model: Optional[torch.nn.Module], + save_content: bool = True, + dynamic_shape: bool = False): """ Create semantic model based on AI Scientist description. - @param model Optional[torch.nn.Module]: Model description. Each device of local_rank == 0 needs to provide. - @param input_shapes Any: to compatable with previous interface. No more need. + Args: + model (Optional[torch.nn.Module]): + single-device model description, only required for rank 0 + save_content (bool): + whether to save the content of model and load it into generated model. Default True. + dynamic_shape (bool): + whether to use dynamic shape. Default False. """ if DeviceGroup().local_rank == 0: assert isinstance(model, torch.nn.Module), f"device of local_rank == 0 must provide model" self.model = model - self.input_shapes = None - self._dummy_input = dummy_input - self.ir_graph = None + self._dummy_input = None + self._ir_graph = None self._loaded_module: CubeModule = None - self._save_content = True - - @property - def save_content(self) -> bool: - return self._save_content - - @save_content.setter - def save_content(self, val: bool): - self._save_content = val + # parser configuration + self.save_content: bool = save_content + self.dynamic_shape: bool = dynamic_shape @property def dummy_input(self) -> Any: - """ - Get dummy real-tensor input from on CPU - """ + """Get dummy real-tensor input from on CPU""" return self._dummy_input @dummy_input.setter @@ -216,22 +214,22 @@ def clear_module(self): self._loaded_module = None def __call__(self, *args): - """ - Forward the model. + """Forward the semantic model. + This will trigger torch.jit.script to parse the model. + + Args: + *args: input IRObjects """ - if self._loaded_module: - return self._loaded_module(*args) - else: - # assert all(isinstance(t, IRSubTensor) for t in args), f"Only support tensors as model inputs" + assert self._ir_graph is None, \ + f"multiple forward on a semantic model is not allowed" + if DeviceGroup().local_rank == 0: input_shapes = [tuple(t.shape) if isinstance(t, IRTensor) else None for t in args] - if DeviceGroup().local_rank == 0: - if self.ir_graph is None: - self.ir_graph = parser.convert_model( - self.model, input_shapes=input_shapes, dummy_input=self.dummy_input, save_content=self.save_content - ) - self.input_shapes = input_shapes - else: - assert tuple(self.input_shapes) == tuple(input_shapes), \ - f"Multiple forwarding of a same model, which require input shapes to be same." - return self.ir_graph(*args) + self._ir_graph = parser.convert_model( + self.model, + input_shapes=input_shapes, + dummy_input=self.dummy_input, + save_content=self.save_content, + dynamic_shape=self.dynamic_shape + ) + return self._ir_graph(*args) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 059bb6da..eb41e226 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -47,7 +47,22 @@ def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Te return input.expand(*sizes) -def fullslice(input: torch.Tensor, slicers: Tuple[Union[None, slice]]): +def fullslice(input: torch.Tensor, slicers: Tuple[Union[None, slice, int]]): + """Slice tensors + + Note: + 1) `None` will always extend a dimension at current position. + 2) `slice(None, None, None)` equals to `:`, + meaning select every element at its dimension. + + Args: + input (torch.Tensor): input tensor + slicers (Tuple[None | slicer | int]): slicer tuple + + + Returns: + torch.Tensor: sliced tensor + """ return input[slicers] diff --git a/tests/parser/test_fx_ops.py b/tests/parser/test_fx_ops.py index e78c9ff1..d6cf09f0 100644 --- a/tests/parser/test_fx_ops.py +++ b/tests/parser/test_fx_ops.py @@ -78,7 +78,8 @@ def policy(graph, resource): model = cube.SemanticModel(model) - @cube.compile(model, dataloader, policy, load_content=False) + @cube.compile(model, dataloader, PAS=policy, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) def eval_iter(model, dataloader): data = next(dataloader) out = model(data) From 453a2adf9a81225b87da3961fcefaa0729772373 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 5 Jul 2023 17:20:50 +0800 Subject: [PATCH 1409/1892] update with fusion interface --- cube/graph/graph.py | 78 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index edf17f82..db7f520f 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -22,6 +22,7 @@ from cube.graph.function.function import Identity from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.pyfunc import IRPyFunc +from cube.graph.function.dimops import IRDimops, OpAnno from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo @@ -431,6 +432,83 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], return fnodes + def fuse(self, nodes: List[IRFwOperation], + signature: Optional[str] = None, + args: Optional[List[IRObject]] = None, kwargs: Optional[Dict[str, Any]] = None) -> IRDimops: + """Fuse primitive. + + Fuse a list of forward operators into a single operator. + The backward operators will be fused automatically. + + Note: + 1) fusion can by applied for consecutive operators on the same device (developer-level call). + 2) fusion can be applied before any node paritioning or after generation of adapters (system-level call). + + Args: + nodes (List[IRFwOperation]): the operators to fuse. + signature (Optional[str], optional): + the signature of the fused operator. If not provided, the fusion will perform a simple grouping of operators, + where the underlying runtime still call the unfused kernel one by one. If the signature is provided, + the fusion will generate an IRDimops calling `signature`, which is expected to be a function signature + of the fused operator. Defaults to None. + args (Optional[List[IRObject]], optional): the arguments of the fused operator. Defaults to None. + kwargs (Optional[Dict[str, Any]], optional): the keyword arguments of the fused operator. Defaults to None. + + Returns: + IRDimops: the fused operator. + """ + assert len(nodes) > 0, "Cannot fuse empty list of nodes" + assert all([isinstance(node, IRFwOperation) for node in nodes]), \ + "Only forward operators are allowed to fuse" + indices = [self.index(node) for node in nodes] + assert max(indices) - min(indices) + 1 == len(nodes), \ + "Only consecutive operators can be fused" + + segment: IRSegment = self.create_segment(nodes) + # get inputs where tensors should appear in the front. + inputs = list(segment.inputs()) + attributes = [segment.ctensors(attr)[0] for attr in segment.attributes()] + inputs += attributes + inputs = [t for t in inputs if isinstance(t, IRTensor)] + [t for t in inputs if not isinstance(t, IRTensor)] + # get outputs + outputs = list(segment.outputs()) + + if args is not None: + assert len(inputs) == len(args) and set(inputs) == set(args), \ + "inputs don't match" + inputs = args + kwargs = {} if kwargs is None else kwargs + + # create annotation. TODO: support partition + in_shapes = [[str(dimlen) for dimlen in t.shape] for t in inputs if isinstance(t, IRTensor)] + ou_shapes = [[str(dimlen) for dimlen in t.shape] for t in outputs if isinstance(t, IRTensor)] + anno: str = OpAnno.create_op_str(in_shapes, ou_shapes) + + if signature is None: + assert False, "TODO: register function" + + if len(nodes) < 4: + name = '_'.join(['fused'] + [node.name for node in nodes]) + else: + name = '_'.join(['fused'] + [node.name for node in nodes[:3]] + ['etc']) + + def fuse_ops(*args, **kwargs) -> IRDimops: + return IRDimops(fuse_ops, name, signature, [anno], args, **kwargs) + + fuse_op = fuse_ops(*inputs, **kwargs) + for idx, output in enumerate(outputs): + fuse_op.set_output(idx, output) + + # setup device + if len(nodes[0].device) != 0: + fuse_op.device = nodes[0].device + + # replace + segment = self.segment(nodes[0]) + segment.replace(nodes, [fuse_op]) + + return fuse_op + ## Spatial Primitives ## def assign(self, node: Union[IRFwOperation, IRDataOperation], device: int) -> bool: From b64f8fcfb11e2fb6ef5dd3fa48d47c3d028715e1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jul 2023 12:27:23 +0800 Subject: [PATCH 1410/1892] enable customized function registration --- cube/graph/parser/converter.py | 6 ++++ cube/graph/parser/register.py | 61 ++++++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index b66c4b47..7ae15cae 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -5,6 +5,7 @@ from cube.graph.parser import ScriptModuleParser from cube.graph.parser import FxModuleParser, FxFuncOpTracer from cube.graph.parser.concrete_trace_utils import concrete_trace +from cube.graph.parser.register import CustomizedOps from cube.graph import IRGraph from cube.flags import CompileFlag @@ -25,6 +26,10 @@ def convert_model(model: torch.nn.Module, """ Convert torch.nn.Module based model into IRGraph """ + # get registered leaf function + customized_funcs = CustomizedOps.kOpRuntime.values() + leaf_functions = {func: ([], False, None) for func in customized_funcs} + try: if CompileFlag.use_torchfx: if CompileFlag.use_default_fx_tracer: @@ -58,6 +63,7 @@ def convert_model(model: torch.nn.Module, dummy_input, use_operator_patch=True, leaf_module=leaf_module, + autowrap_leaf_function=leaf_functions, cpu_offload=True, ) else: diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 310d4eae..44f64a48 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -2,7 +2,7 @@ Register cutomized function """ -from typing import Dict, Callable, List, Optional +from typing import Dict, Callable, List, Optional, Any from functools import partial import inspect import torch @@ -11,13 +11,25 @@ class CustomizedOps: + """Customized op registry.""" + # signature -> IRDimop creation function kOpMap: Dict[str, Callable] = {} - # customized operator code: signature -> code + # singature -> runtime function + kOpRuntime: Dict[str, Callable] = {} + # signature -> runtime function implementation code kOpCodeDef: Dict[str, str] = {} @staticmethod def map(signature: str) -> Callable: + """Get IRDimop creation function by signature + + Args: + signature (str): operator signature + + Returns: + Callable: IRDimop creation function + """ signature = signature.split('.')[-1] if signature in CustomizedOps.kOpMap: return partial(CustomizedOps.kOpMap[signature], signature=signature) @@ -26,13 +38,22 @@ def map(signature: str) -> Callable: @staticmethod def exist(signature: str) -> bool: + """Check if the signature is registered""" signature = signature.split('.')[-1] return signature in CustomizedOps.kOpMap @staticmethod - def register(signature: str, op: Callable, code: str): - """ - Register an operator + def register(signature: str, op: Callable, code: str, runtime_fn: Callable): + """Register an operator + + Args: + signature (str): operator signature + op (Callable): IRDimop creation function + code (str): runtime function implementation code + runtime_fn (Callable): runtime function + + Returns: + None """ builtins = ['_operator', 'torch', 'cube.runtime.function'] if any(signature.startswith(builtin) for builtin in builtins): @@ -40,14 +61,17 @@ def register(signature: str, op: Callable, code: str): signature = signature.split('.')[-1] assert signature not in CustomizedOps.kOpMap, f"function {signature} is already registered" CustomizedOps.kOpMap[signature] = op + CustomizedOps.kOpRuntime[signature] = runtime_fn CustomizedOps.kOpCodeDef[signature] = code -def register(anno: str, name: Optional[str] = None, rules: Optional[List] = None): +def register(anno: str, name: Optional[str] = None, + rules: Optional[List] = None, + input_type_annos: Optional[List[Any]] = None) -> Callable: """ Register a function with einop annotations. - This function is cooperated with IREinOp. + This function is cooperated with IRDimops. User needs to define a python function that satisfies 1). Has type annotations for each input 2). Tensor inputs goes first then other inputs @@ -63,11 +87,17 @@ def funcname(x: torch.Tensor, b: int = 4): Note: for Optional[torch.Tensor] type, user should annotate the dimension when the input is not None. - @param anno str: operator annotation - @param name str: operator name - @param rules Optional[List[TransformRule]]: additional transformation rules. + Args: + anno (str): operator annotation + name (str): operator name + rules (Optional[List[TransformRule]]): + additional transformation rules. + input_type_annos (Optional[List[Any]]): + type annotations for inputs. If not provided, the function + should be annotated with types. - @return fn Callable: the runtime function + Returns: + fn (Callable): the runtime function """ def decorator(fn: Callable): if not callable(fn): @@ -76,7 +106,12 @@ def decorator(fn: Callable): op_name = name if name is not None else fsig args = inspect.signature(fn) arg_names = list(args.parameters.keys()) - arg_kinds = [args.parameters[name].annotation for name in arg_names] + # get argument types + arg_kinds = input_type_annos if input_type_annos is not None else \ + [args.parameters[name].annotation for name in arg_names] + assert len(arg_kinds) == len(arg_names), \ + "Number of annotations should match with number of arguments" + # parse for number of inputs and kwargs allow_types = (torch.Tensor, Optional[torch.Tensor]) for ninputs, kind in enumerate(arg_kinds): if kind in allow_types: @@ -103,7 +138,7 @@ def udfop(*args, signature=None, **kwargs): return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, transform_rules=rules, **kwargs) print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') - CustomizedOps.register(fsig, udfop, code) + CustomizedOps.register(fsig, udfop, code, fn) return fn return decorator From da3a83d37f4206be7eb1bf55d3bf00d7b3dbc43c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jul 2023 12:27:45 +0800 Subject: [PATCH 1411/1892] add test for customized function for torch.fx --- tests/parser/test_fx_ops.py | 70 +++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/tests/parser/test_fx_ops.py b/tests/parser/test_fx_ops.py index d6cf09f0..41f5827d 100644 --- a/tests/parser/test_fx_ops.py +++ b/tests/parser/test_fx_ops.py @@ -8,6 +8,19 @@ from cube.ir.operator import IRFwOperation, IRDataOperation from cube.graph.function.dimops import IRDimops +cube.init() + + +@cube.graph.parser.register('a b -> a b', name='test_op1') +def test_op1(a: torch.Tensor): + return a.clone() + + +@cube.graph.parser.register('a b -> a b', name='test_op2', + input_type_annos=[torch.Tensor, int]) +def test_op2(a, b): + return a + b + def _param(size, dtype=torch.float32): return torch.nn.Parameter(torch.empty(size, dtype=dtype)) @@ -34,6 +47,24 @@ def forward(self, x: torch.Tensor): # [bs, 128] -> [1] loss = torch.sum(x4) return {'x': x4, 'loss': loss} # , [x3,] + + +class TestOpModuleForCustomizeOp(torch.nn.Module): + + def __init__(self): + super().__init__() + self.param1 = _param([512, 256]) + self.param2 = _param([512, 256]) + self.ints = [1, 2, 3] + + def forward(self, x: torch.Tensor): + # matmul: [bs, 512], [512, 256] -> [bs, 256] + x1 = torch.matmul(x, self.param1) + # [bs, 256] -> [bs, 256] + x2 = test_op1(x1) + x3 = test_op2(x2, 1) + loss = torch.sum(x3) + return {'x': x3, 'loss': loss} # , [x3,] class TestDataLoader(cube.runtime.syndata.CubeDataLoader): @@ -58,8 +89,6 @@ def set_batch_size(self, batch_size: int): def test_parse_ops(): - cube.init() - model = TestOpModule() dataloader = TestDataLoader() @@ -93,6 +122,41 @@ def eval_iter(model, dataloader): print(f"iter {idx}/3") +def test_registered_ops(): + + model = TestOpModuleForCustomizeOp() + dataloader = TestDataLoader() + + def policy(graph, resource): + print(graph.extra_repr()) + assert resource.ngpus == 1 + for node in graph.nodes(): + if isinstance(node, IRDimops): + print(f'# {node.anno}') + print(node) + elif isinstance(node, (IRFwOperation, IRDataOperation)): + print(node) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + model = cube.SemanticModel(model) + + @cube.compile(model, dataloader, PAS=policy, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) + def eval_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out['loss'].backward() + # return out + + model = model.get_gen_module() + + for idx in range(3): + eval_iter(model, dataloader) + print(f"iter {idx}/3") + + if __name__ == '__main__': test_parse_ops() - + test_registered_ops() From dfa591338f53465b02eba01af61a8354a6e1ccfd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jul 2023 13:34:29 +0800 Subject: [PATCH 1412/1892] update fusion interface --- cube/graph/graph.py | 106 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 85 insertions(+), 21 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index db7f520f..c462a8c3 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -434,7 +434,11 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], def fuse(self, nodes: List[IRFwOperation], signature: Optional[str] = None, - args: Optional[List[IRObject]] = None, kwargs: Optional[Dict[str, Any]] = None) -> IRDimops: + fuse_op_args: Optional[List[IRObject]] = None, + fuse_op_kwargs: Optional[Dict[str, Any]] = None, + fuse_op_outputs: Optional[List[IRObject]] = None, + fuse_op_anno: str = None, + fuse_op_name: str = None) -> IRDimops: """Fuse primitive. Fuse a list of forward operators into a single operator. @@ -451,8 +455,16 @@ def fuse(self, nodes: List[IRFwOperation], where the underlying runtime still call the unfused kernel one by one. If the signature is provided, the fusion will generate an IRDimops calling `signature`, which is expected to be a function signature of the fused operator. Defaults to None. - args (Optional[List[IRObject]], optional): the arguments of the fused operator. Defaults to None. - kwargs (Optional[Dict[str, Any]], optional): the keyword arguments of the fused operator. Defaults to None. + fuse_op_args (Optional[List[IRObject]], optional): + the arguments of the fused operator. Defaults to None. + fuse_op_kwargs (Optional[Dict[str, Any]], optional): + the keyword arguments of the fused operator. Defaults to None. + fuse_op_outputs (Optional[List[IRObject]], optional): + the outputs of the fused operator. Defaults to None. + fuse_op_anno (str, optional): + the annotation of the fused operator. Defaults to None. + fuse_op_name (str, optional): + the name of the fused operator. Defaults to None. Returns: IRDimops: the fused operator. @@ -460,7 +472,7 @@ def fuse(self, nodes: List[IRFwOperation], assert len(nodes) > 0, "Cannot fuse empty list of nodes" assert all([isinstance(node, IRFwOperation) for node in nodes]), \ "Only forward operators are allowed to fuse" - indices = [self.index(node) for node in nodes] + indices: List[int] = [self.index(node).indices[-1] for node in nodes] assert max(indices) - min(indices) + 1 == len(nodes), \ "Only consecutive operators can be fused" @@ -473,29 +485,69 @@ def fuse(self, nodes: List[IRFwOperation], # get outputs outputs = list(segment.outputs()) - if args is not None: - assert len(inputs) == len(args) and set(inputs) == set(args), \ + # reorder and check op inputs and outputs + if fuse_op_args is not None: + assert len(inputs) == len(fuse_op_args) and set(inputs) == set(fuse_op_args), \ "inputs don't match" - inputs = args - kwargs = {} if kwargs is None else kwargs + inputs = fuse_op_args + kwargs = {} if fuse_op_kwargs is None else fuse_op_kwargs + if fuse_op_kwargs is not None: + assert len(outputs) == len(fuse_op_outputs) and set(outputs) == set(fuse_op_outputs), \ + "outputs don't match" + outputs = fuse_op_outputs # create annotation. TODO: support partition - in_shapes = [[str(dimlen) for dimlen in t.shape] for t in inputs if isinstance(t, IRTensor)] - ou_shapes = [[str(dimlen) for dimlen in t.shape] for t in outputs if isinstance(t, IRTensor)] - anno: str = OpAnno.create_op_str(in_shapes, ou_shapes) + if fuse_op_anno is None: + in_shapes = [[str(dimlen) for dimlen in t.shape] for t in inputs if isinstance(t, IRTensor)] + ou_shapes = [[str(dimlen) for dimlen in t.shape] for t in outputs if isinstance(t, IRTensor)] + fuse_op_anno: str = OpAnno.create_op_str(in_shapes, ou_shapes) + + if fuse_op_name is None: + if len(nodes) < 4: + fuse_op_name = '_'.join(['fused'] + [node.name for node in nodes]) + else: + fuse_op_name = '_'.join(['fused'] + [node.name for node in nodes[:3]] + ['etc']) + # if signature is not provided, register the fused function by + # grouping the node implementations together inside a function. + # This doesn't make real fusion but can help reduce partition + # search space for the policy. + make_customized_op: bool = signature is None if signature is None: - assert False, "TODO: register function" + signature = f'{fuse_op_name}_{nodes[0].cid}_to_{nodes[-1].cid}' + + def fuse_op_fn(*args, **kwargs) -> IRDimops: + return IRDimops(fuse_op_fn, fuse_op_name, signature, [fuse_op_anno], args, **kwargs) - if len(nodes) < 4: - name = '_'.join(['fused'] + [node.name for node in nodes]) - else: - name = '_'.join(['fused'] + [node.name for node in nodes[:3]] + ['etc']) + if make_customized_op: + from cube.graph.parser.register import CustomizedOps - def fuse_ops(*args, **kwargs) -> IRDimops: - return IRDimops(fuse_ops, name, signature, [anno], args, **kwargs) + def to_name(t: Any) -> str: + """Convert an object to its name.""" + if isinstance(t, IRObject): + return '_'.join([t.name, str(t.tid)]) + elif isinstance(t, str) and not t.startswith('self.'): + return f"'{t}'" + return str(t) + # function inputs / outputs + func_inputs = ','.join(to_name(t) for t in inputs) + func_kwargs = ','.join(f'{k}={to_name(v)}' for k, v in kwargs.items()) + func_outputs = ','.join([to_name(t) for t in outputs]) + # generate code + code = [f'def {signature}({func_inputs}, {func_kwargs}):'] + for node in nodes: + node_inputs = ','.join(to_name(t) for t in node.inputs()) + node_kwargs = ','.join(f'{k}={to_name(v)}' for k, v in node.kwargs.items()) + node_outputs = ','.join(to_name(t) for t in node.outputs()) if len(outputs) > 0 else '_' + code += [f'\t{node_outputs} = {node.signature}({node_inputs}, {node_kwargs})'] + code.append(f'\treturn {func_outputs}') + code = '\n'.join(code) + CustomizedOps.register( + signature, fuse_op_fn, code, + lambda *args : NotImplementedError("a fused operator doesn't have runtime call") + ) - fuse_op = fuse_ops(*inputs, **kwargs) + fuse_op = fuse_op_fn(*inputs, **kwargs) for idx, output in enumerate(outputs): fuse_op.set_output(idx, output) @@ -503,9 +555,21 @@ def fuse_ops(*args, **kwargs) -> IRDimops: if len(nodes[0].device) != 0: fuse_op.device = nodes[0].device - # replace + # replace nodes with the fused operator + # remove forward operators segment = self.segment(nodes[0]) - segment.replace(nodes, [fuse_op]) + indices = [segment.remove(node).indices[-1] for node in nodes] + idx = min(indices) + # remove backward operators + have_backward = any(node.mirror is not None for node in nodes) + for node in nodes: + if node.mirror is not None: + segment.mirror.remove(node.mirror) + # insert forward/backward operators + if have_backward: + segment.finsert(fuse_op, idx) + else: + segment.insert(fuse_op, idx) return fuse_op From 3d5f0e99a6252d955125351d7cb97062130256e5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jul 2023 13:37:50 +0800 Subject: [PATCH 1413/1892] add test fusion example --- tests/graph/test_fusion.py | 99 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 tests/graph/test_fusion.py diff --git a/tests/graph/test_fusion.py b/tests/graph/test_fusion.py new file mode 100644 index 00000000..6b201f41 --- /dev/null +++ b/tests/graph/test_fusion.py @@ -0,0 +1,99 @@ +""" +USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/graph/test_fusion.py +""" +from typing import List +import torch + +import cube +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops + +cube.init() + + +def _param(size, dtype=torch.float32): + return torch.nn.Parameter(torch.empty(size, dtype=dtype)) + + +class TestModuleForFusedOp(torch.nn.Module): + + def __init__(self): + super().__init__() + self.param1 = _param([512, 256]) + self.param2 = _param([512, 256]) + self.ints = [1, 2, 3] + + def forward(self, x: torch.Tensor): + # matmul: [bs, 512], [512, 256] -> [bs, 256] + x1 = torch.matmul(x, self.param1) + # [bs, 256] -> [bs, 256] + x2 = x1.clone() + x3 = x2 + 1 + loss = torch.sum(x3) + return {'x': x3, 'loss': loss} # , [x3,] + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int = 256) -> None: + self.sample = torch.rand( + [batch_size, 512], + dtype=torch.float32, + device=torch.cuda.current_device() + ) + super().__init__(batch_size, (0,)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + + +def test_fused_op(): + + model = TestModuleForFusedOp() + dataloader = TestDataLoader() + + def policy(graph, resource): + assert resource.ngpus == 1 + print(graph.extra_repr()) + + clone = graph.select(name='clone')[0] + idx = graph.index(clone) + clonse_add = [clone, graph.node(idx+1)] + graph.fuse(clonse_add) + + for node in graph.nodes(): + if isinstance(node, IRDimops): + print(f'# {node.anno}') + print(node) + elif isinstance(node, (IRFwOperation, IRDataOperation)): + print(node) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + model = cube.SemanticModel(model) + + @cube.compile(model, dataloader, PAS=policy, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) + def eval_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out['loss'].backward() + # return out + + model = model.get_gen_module() + + for idx in range(3): + eval_iter(model, dataloader) + print(f"iter {idx}/3") + + +if __name__ == '__main__': + test_fused_op() From 3b5f17d160e9c133bfd2b0e49009021a97453c4d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jul 2023 07:22:20 +0000 Subject: [PATCH 1414/1892] Merged PR 1647: allow segment take IRObject as inputs and outputs allow segment take IRObject as inputs and outputs https://msrasrg.visualstudio.com/SuperScaler/_workitems/edit/1463 --- cube/graph/graph.py | 1 + cube/graph/segment.py | 33 +++++++++----- tests/graph/test_segment.py | 87 +++++++++++++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 11 deletions(-) create mode 100644 tests/graph/test_segment.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index edf17f82..43a56149 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -220,6 +220,7 @@ def group(self, nodes: List[IRCell]) -> IRSegment: # update gradient for fgraph for itensor in fsegment.inputs(): + if not isinstance(itensor, IRTensor): continue fgraph.infer_grad(itensor.parent) # update gradient inside segment for ftensor in fsegment.full_tensors(): diff --git a/cube/graph/segment.py b/cube/graph/segment.py index f5e5704c..f3fa0b4b 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -915,30 +915,38 @@ def get_outputs(nodes: List[IRCell], exclude_attr: bool = True): continue return outputs - def create_segment(self, nodes: List[IRCell]) -> IRCell: - """! - Create a segment with part of the nodes. + def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> IRCell: + """Create a segment (sub-graph) with part of the nodes. + This only return the created segment wihout modifying the graph. Calling this requires that the dependencies are already materialized, - i.e., every input IRSubTensor should have a corresponding producer. + i.e., every input IRSubTensor should have a corresponding producer. Two scenarios + satisfy this condition: + + 1) the node in the graph is not partitioned; + + 2) the adapters (communication) are generated. - @param nodes List[IRCell]: the subset nodes of this graph + Args: + nodes (List[IRCell]): the subset nodes of this graph + attr_as_inputs (bool): whether to treat attributes as segment inputs - @return segment IRSegment: the grouped segment. + Returns: + segment (IRSegment): the grouped segment. """ segment = self segment_outputs = IRSegment.get_objects_from_complex(segment.outputs()) # setup adapter dependency - ad_consumers: Dict[Tuple[IRSubTensor,int], Set[int]] = dict() - ad_producers: Dict[Tuple[IRSubTensor,int], Set[int]] = dict() + ad_consumers: Dict[Tuple[IRObject,int], Set[int]] = dict() + ad_producers: Dict[Tuple[IRObject,int], Set[int]] = dict() for adapter in self.select(ntype=IRAdapter): for itensor in adapter.inputs(): - if not isinstance(itensor, IRSubTensor): continue + if not isinstance(itensor, IRObject): continue ad_consumers.setdefault((itensor, itensor.device[0]), set()).add(adapter.cid) for otensor in adapter.outputs(): - if not isinstance(otensor, IRSubTensor): continue + if not isinstance(otensor, IRObject): continue ad_producers.setdefault((otensor, otensor.device[0]), set()).add(adapter.cid) # tensor and its device match @@ -949,7 +957,10 @@ def create_segment(self, nodes: List[IRCell]) -> IRCell: for node in nodes: for itensor in node.inputs(): if not isinstance(itensor, IRObject): continue - if itensor.is_attr(): continue + if itensor.is_attr(): + if attr_as_inputs: + inputs.add(itensor) + continue producers, ptensors = self.producers(itensor.parent), self.ptensors(itensor.parent) pids = set(p.cid for p, t in zip(producers, ptensors) if dmatch(t, itensor)) if len(itensor.device) > 0: diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py new file mode 100644 index 00000000..be8b245d --- /dev/null +++ b/tests/graph/test_segment.py @@ -0,0 +1,87 @@ +""" +PYTHONPATH=.:$PYTHONPATH torchrun --nproc_per_node=1 \ + tests/graph/test_segment.py +""" +import torch + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops + +cube.init() + + +def _param(shape, dtype=torch.float32): + return torch.nn.Parameter(torch.empty(shape, dtype=dtype)) + + +class TestOpModule(torch.nn.Module): + + def __init__(self, shape=[256, 512]): + super().__init__() + self.param = _param(shape) + + def forward(self, x: torch.Tensor, y: int): + x = x * self.param + x = x + y + loss = torch.sum(x) + return loss + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int = 256) -> None: + self.sample = ( + torch.rand([batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()), + 4, + ) + super().__init__(batch_size, (0, None)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +def test_segment_creation(): + + cube.init() + + model = TestOpModule() + dataloader = TestDataLoader() + + def policy(graph: IRGraph, resource): + assert resource.ngpus == 1 + fwops = graph.select(ntype=IRFwOperation) + graph.staging([fwops[0]]) + print(graph.extra_repr()) + for node in fwops: + graph.assign(node, 0) + for dl in graph.select(ntype=IRDataOperation): + graph.assign(dl, 0) + return graph + + sample_x, sample_y = next(dataloader) + + @cube.compile(model, dataloader, PAS=policy, load_content=True, + model_dummy_inputs={'x': sample_x, 'y': sample_y}) + def train_iter(model, dataloader): + data = next(dataloader) + loss = model(*data) + loss.backward() + + model = cube.load_model() + + for idx in range(3): + train_iter(model, dataloader) + print(f"iter {idx}/3") + print('Done') + + +if __name__ == '__main__': + test_segment_creation() From 9701187a6af2b7288f885916381adbe436080204 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jul 2023 08:51:30 +0000 Subject: [PATCH 1415/1892] Merged PR 1598: add support with auto-multiref; add constraints that all devices should be used add support with auto-multiref; add constraints all devices should be used --- examples/policies/alpa/__init__.py | 46 +++++++++++++++++++++++------- examples/policies/alpa/inter_op.py | 2 ++ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/examples/policies/alpa/__init__.py b/examples/policies/alpa/__init__.py index b028e23d..af910735 100644 --- a/examples/policies/alpa/__init__.py +++ b/examples/policies/alpa/__init__.py @@ -4,10 +4,11 @@ import torch from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.dimops import IRDimops +from cube.graph.function.dimops import IRDimops, TransformRule, DimopSplit from cube.graph.graph import IRGraph from cube.graph.segment import IRSegment from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.ir.tensor import IRFullTensor from cube.graph.schedule.predefined import PredefinedSched from cube.runtime.device import DeviceGroup @@ -36,13 +37,42 @@ def _tp(graph: IRGraph, node: IRDimops, devs: List[int], **configs) -> List[IRDi return sub_nodes +def _auto_multiref(graph: IRGraph, plan: ParallelSpec): + """ + Apply automated multiref on tensors that are partitioned differently by different nodes + """ + # get parallel strategy + specs = dict() + for stage in plan.stages: + for cid, spec in stage.tp_spec.items(): + specs[cid] = spec + + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): continue + if len(graph.consumers(ftensor)) <= 1: continue + consumers, ctensors = graph.consumers(ftensor), graph.ctensors(ftensor) + splits = set() + for consumer, ctensor in zip(consumers, ctensors): + spec = specs[consumer.cid] + if spec is None: + splits.add(DimopSplit.R()) + else: + idx, dim = spec + rule: TransformRule = consumer.algorithms('dim').infer(idx, dim, 1) + split = rule.inputs()[consumer.inputs().index(ctensor)] + splits.add(split) + if len(splits) > 1: + print(f"> detected a(n) {'activation' if not ftensor.is_attr() else 'parameter'}: " + f"{ftensor.name}({ftensor.tid}) is partitioned differently. Apply multierf...") + graph.multiref(ftensor) + + def PASAlpa(graph: IRGraph, resource, recompute: bool = False, nmicros: int = 1, db_cache: str = 'db_train.json', load_spec_file: Optional[str] = None, save_spec_file: Optional[str] = None, - use_multiref: bool = False, max_pp_size: Optional[int] = None, max_tp_size: Optional[int] = None, max_layer_number: int = 12) -> IRGraph: @@ -53,7 +83,7 @@ def PASAlpa(graph: IRGraph, resource, for AutoLayer partition position @param graph IRGraph: model graph - @param rresource Resource: resource + @param resource Resource: resource @param recompute bool: whether to enable recompute on each layer @param nmicros int: number of micro-batches @param db_cache str: database cache file @@ -63,13 +93,6 @@ def PASAlpa(graph: IRGraph, resource, @param max_tp_size Optional[int]: limit the maximum number of tensor parallelism size @param max_layer_number Optional[int]: maximum number of layers to search """ - # enable this for multiref - if use_multiref: - for ftensor in graph.full_tensors(): - if ftensor.is_grad() or ftensor.is_attr(): continue - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor) - # recompute granularity will follow original anchor scope layers = annotate_structure(graph) if recompute: @@ -151,6 +174,9 @@ def PASAlpa(graph: IRGraph, resource, print(f'> instantiate plan...') # print(graph.extra_repr()) + # auto-multiref + _auto_multiref(graph, config) + # staging cid2node = {n.cid : n for n in nodes} leading_cids = [list(stage.tp_spec.keys())[0] for stage in config.stages] diff --git a/examples/policies/alpa/inter_op.py b/examples/policies/alpa/inter_op.py index a11585f1..7db2a5f5 100644 --- a/examples/policies/alpa/inter_op.py +++ b/examples/policies/alpa/inter_op.py @@ -103,6 +103,8 @@ def DP(nodes: Tuple[IRLayerOp], k: int, s: int, intra_solver: Callable, if not is_of_power2(t * d): continue # guarantee sub-problem searchable if k - d * t < s - 1: continue + # constraints: every device must be used + if s - 1 > 0 and len(sub2) == 0: continue # sub2 cost DP(sub2, k-d*t, s-1, intra_solver, mbs, max_d, max_t, _cost, _config, _intra_cache) From d05abb422568cb9caa6c15501116c78401022bed Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 6 Jul 2023 23:39:56 +0800 Subject: [PATCH 1416/1892] add choice for code implementation pattern --- cube/graph/parser/register.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 44f64a48..8afbce32 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -5,6 +5,7 @@ from typing import Dict, Callable, List, Optional, Any from functools import partial import inspect +import warnings import torch from cube.graph.function.dimops import IRDimops, OpAnno @@ -67,7 +68,8 @@ def register(signature: str, op: Callable, code: str, runtime_fn: Callable): def register(anno: str, name: Optional[str] = None, rules: Optional[List] = None, - input_type_annos: Optional[List[Any]] = None) -> Callable: + input_type_annos: Optional[List[Any]] = None, + code_impl_pattern: str = 'import') -> Callable: """ Register a function with einop annotations. @@ -95,7 +97,11 @@ def funcname(x: torch.Tensor, b: int = 4): input_type_annos (Optional[List[Any]]): type annotations for inputs. If not provided, the function should be annotated with types. - + code_impl_pattern (str): + can only be 'import' or 'source'. If 'import', will generate code with + import statement. If 'source', will take the source code directly. + Default: 'import'. + Returns: fn (Callable): the runtime function """ @@ -122,9 +128,23 @@ def decorator(fn: Callable): break nkwargs = len(arg_names) - ninputs kwarg_names = [name for name in arg_names[ninputs:]] + # get customized op code - code = inspect.getsource(fn) - code = code[code.index('def'):] + if code_impl_pattern == 'import': + import_path = inspect.getmodule(fn).__name__ + if import_path == '__main__': + warnings.warn(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' + f'This may cause error when the function has inner functions from other modules. ' + f'To solve this, define the function in another module and import into main', stacklevel=0) + code = inspect.getsource(fn) + code = code[code.index('def'):] + else: + code = f'from {import_path} import {fsig}' + elif code_impl_pattern == 'source': + code = inspect.getsource(fn) + code = code[code.index('def'):] + else: + raise ValueError(f'code_impl_pattern should be either "import" or "source", got {code_impl_pattern}') def udfop(*args, signature=None, **kwargs): manno = OpAnno(anno) From 0e69ab9e020f3ff10656d2d844ff9dd84aba625a Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 7 Jul 2023 04:08:29 +0000 Subject: [PATCH 1417/1892] Merged PR 1652: Refine parser and codegen to support MoE in torchscale Related work items: #1463, #1464, #1465 --- cube/codegen/emit.py | 2 +- cube/codegen/frontend_mapping.py | 10 +++++++ cube/graph/function/dimops.py | 6 ++++ cube/graph/function/function.py | 49 +++++++++++++++++++++---------- cube/graph/parser/mappingfx.py | 1 + cube/graph/parser/parserfx.py | 50 +++++++++++++++++++++++--------- cube/graph/segment.py | 2 ++ 7 files changed, 91 insertions(+), 29 deletions(-) diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 9eb5b707..71e27f66 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -58,7 +58,7 @@ def tensor_name(tensor: Any, prefix_attr: Optional[str] = None) -> str: if prefix_attr is not None and tensor.is_attr(): name = prefix_attr + name else: - name = str(tensor) + name = str(IRSegment.modify_objects_of_complex(tensor, CodeEmission.tensor_name)).replace('\'', '') return name @staticmethod diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index 60b65856..32b40029 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -87,12 +87,22 @@ def emit_getattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: """ return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" + @staticmethod + def emit_getitem(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + """Special rule for generating getitem node + """ + if len(arg_vars) == 2 and len(kw_pairs) == 0 and not arg_vars[1].replace('_', '').isdigit(): + return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" + else: + return Sign2EmitRule.emit_common(node, arg_vars, kw_pairs) + # the registered emit rules Sign2EmitRule._sign2rule = { 'torch.slice': Sign2EmitRule.emit_slice, 'setattr': Sign2EmitRule.emit_setattr, 'builtins.getattr': Sign2EmitRule.emit_getattr, + '_operator.getitem': Sign2EmitRule.emit_getitem, } diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index da496adc..c46e1a4a 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -67,6 +67,7 @@ import importlib import re import string +import warnings from cube.ir.cten import IRTensor, IRObject from cube.ir.dtype import DTypeInferRule @@ -662,6 +663,11 @@ def infer_shape(self) -> bool: """ for oidx, otensor in enumerate(self.outputs()): shape_anno = self.oanno(oidx) + if str(shape_anno) == '?': + assert isinstance(otensor, IRObject), f"expect IRObject for unknown shape, get {otensor}" + warnings.warn('detect IRObject output in a IRDimops, please ensure the annotation is' + 'correct w.r.t the partition policy.') + continue shape = [] for odim in range(shape_anno.ndims): accum = 1 diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 4c53b44f..e626abfa 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -7,6 +7,7 @@ import math import warnings import functools +from collections.abc import Iterable from cube.ir.cten import IRTensor, IRObject from cube.ir.tensor import IRSubTensor, IRFullTensor @@ -178,7 +179,7 @@ def Arange(*args, out=None, dtype=None, layout=None, def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): # note: device is ignored - assert layout is None and memory_format is None, f"Not support for non-default memory_format and layout" + assert layout in (None, torch.strided) and memory_format is None, f"Not support for non-default memory_format and layout" dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" signature = 'cube.runtime.function.empty' @@ -197,7 +198,7 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi def Zeros(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): # note: device is ignored - assert layout is None, f"Not support for non-default layout" + assert layout in (None, torch.strided), f"Not support for non-strided layout, get {layout}" dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" signature = 'cube.runtime.function.zeros' @@ -208,14 +209,14 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) from cube.graph.parser.dtype import DType2IRDType - dimop.output(0).parent.dtype = dtype if isinstance(dtype, IRDType) else DType2IRDType.map(dtype) + dimop.output(0).parent.dtype = DType2IRDType.map(dtype) return dimop def Ones(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): # note: device is ignored - assert layout is None, f"Not support for non-default layout" + assert layout in (None, torch.strided), f"Not support for non-strided layout, get {layout}" dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" signature = 'cube.runtime.function.ones' @@ -233,7 +234,7 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): # note: device is ignored - assert layout is None and memory_format is None, f"Not support for non-default memory_format and layout" + assert layout in (None, torch.strided) and memory_format is None, f"Not support for non-default memory_format and layout" dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" signature = 'cube.runtime.function.rand' @@ -316,10 +317,11 @@ def Expand(input, *sizes, signature = None): signature = 'cube.runtime.function.expand' edim_in = ShapeAnno.create_shape_str(input.shape) assert len(input.shape) == len(sizes) + edim_ou = copy.copy(edim_in) for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes)): if dim == 1 and dim != expand_dim: edim_in[idx] += '^' - edim_ou = copy.copy(edim_in) + edim_ou[idx] = str(expand_dim) anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(Expand, 'expand', signature, [anno], [input], sizes=sizes) @@ -1338,6 +1340,11 @@ def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice, int]], signatu edim_in = ShapeAnno.create_shape_str(tensor.shape) edim_ou = [] in_idx = 0 + def obj_helper(obj): + if isinstance(obj, IRObject): + return obj.value + else: + return obj for slicer in slicers: if slicer is None: edim_ou.append('1') @@ -1347,9 +1354,10 @@ def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice, int]], signatu elif isinstance(slicer, slice): if slicer != slice(None, None, None): edim_in[in_idx] += '^' - start = 0 if slicer.start is None else slicer.start - stop = tensor.shape[in_idx] if slicer.stop is None else slicer.stop - step = 1 if slicer.step is None else slicer.step + _start, _stop, _step = obj_helper(slicer.start), obj_helper(slicer.stop), obj_helper(slicer.step) + start = 0 if _start is None else _start + stop = tensor.shape[in_idx] if _stop is None else _stop + step = 1 if _step is None else _step dimlen = len(range(start, stop, step)) if dimlen == tensor.shape[in_idx]: edim_ou.append(edim_in[in_idx]) @@ -1658,10 +1666,9 @@ def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): raise RuntimeError(f'function.To with unknown arg: {dtype_or_device}') -def GetItem(a, b, signature = None) -> Union[Any, IRPyFunc]: - """ - _operator.getitem(obj, index: int) - """ +def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: + """_operator.getitem(a, b): return a[b]""" + assert not isinstance(b, IRObject) obj, index = a, b # tensor slice if isinstance(obj, IRTensor): @@ -1689,13 +1696,18 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], shape = IRObject('shape', value=obj.shape) return IRPyFunc(signature, [instance, field], [shape]) if name == 'dtype': + from cube.graph.parser.dtype import IRDType2TorchDType assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" assert hasattr(obj, name), f"attr {name} is not existed in {obj}" - return getattr(obj, name) + return IRDType2TorchDType.map(getattr(obj, name)) if name == 'device': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" # FIXME: this is hack, IRFullTensor does not have attribute "device" return torch.device('cpu') + if name == 'layout': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + warnings.warn('hack currently, please ensure the input tensor is in torch.strided layout') + return torch.strided if isinstance(obj, torch.finfo): return getattr(obj, name) return IRPyFunc(signature, [instance, field], [IRObject()]) @@ -1731,4 +1743,11 @@ def MakeTuple(inputs: Iterable, signature=None): def MakeList(inputs: Iterable, signature=None): - return list(inputs) + if isinstance(inputs, Iterable): + return list(inputs) + else: + return IRPyFunc(signature, [inputs], [IRObject(value=list(inputs.value))]) + + +def MakeSlice(*inputs: Iterable, signature=None): + return slice(*inputs) diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 62a7b984..4fde2f20 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -115,6 +115,7 @@ def exist(signature: str) -> bool: 'builtins.getattr': function.GetAttr, 'builtins.tuple': function.MakeTuple, 'builtins.list': function.MakeList, + 'builtins.slice': function.MakeSlice, # # torch nn functional '_operator.matmul': function.Matmul, diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 3de7267e..86ad4c62 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -144,17 +144,43 @@ def parse(module: torch.fx.GraphModule, # add activations to frame, including call_func/call_method output and final output # call_module corresponds to leaf torch.nn.module + from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata activation_op_strs = {'call_function', 'output', 'call_method', 'call_module'} activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] + def parse_complex_out(meta_out): + if isinstance(meta_out, TensorMetadata): + shape = meta_out.shape + assert shape == torch.Size([]), f'{meta_out}' + return torch.zeros(shape, dtype=meta_out.dtype, requires_grad=meta_out.requires_grad) + elif isinstance(meta_out, dict): + ret = {} + for k, v in meta_out.items(): + ret[k] = parse_complex_out(v) + return ret + else: + return meta_out for node in activation_nodes: - if hasattr(node, 'meta') and node.meta.get('tensor_meta') and hasattr(node.meta['tensor_meta'], 'dtype'): + if hasattr(node, 'meta') and node.meta.get('tensor_meta'): assert isinstance(node, torch.fx.Node) - shape = node.meta['tensor_meta'].shape - shape = FxModuleParser.shape_refine(shape) - dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) - requires_grad = node.meta['tensor_meta'].requires_grad - val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=node.name) - frame.add_var(node.name, val) + if isinstance(node.meta['tensor_meta'], TensorMetadata): + meta_outs = (node.meta['tensor_meta'],) + else: + meta_outs = node.meta['tensor_meta'] + vals = list() + for meta_out in meta_outs: + if isinstance(meta_out, TensorMetadata): + shape = meta_out.shape + shape = FxModuleParser.shape_refine(shape) + dtype = DType2IRDType.map(meta_out.dtype) + requires_grad = meta_out.requires_grad + val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=node.name) + else: + val = IRObject(value=parse_complex_out(meta_out)) + vals.append(val) + if len(vals) == 1: + frame.add_var(node.name, vals[0]) + else: + frame.add_var(node.name, vals) else: frame.add_var(node.name, IRObject()) @@ -308,12 +334,10 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, if isinstance(ir_node, IRCell): ir_nodes.append(ir_node) if len(ir_node.outputs()) > 1: - # REMARK: some nodes will return multiple outputs, e.g., torch.chunk, while torch.fx always - # return one output. This will cause `getitem`` or `unpacking`` operation on the output, - # which can be folded by setting the list of the output tensor - ir_node.infer_shape() - ir_node.infer_dtype() - frame.set_var(node.name, ir_node.outputs()) + vals = frame.get_var(node.name) + assert len(vals) == len(ir_node.outputs()), f'{vals}, {ir_node.outputs()}' + for i in range(len(vals)): + ir_node.set_output(i, vals[i]) elif ir_node.output(0).value is not None: if FxModuleParser.dynamic_shape: frame.set_var(node.name, ir_node.output(0)) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index f3fa0b4b..48da8a69 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1103,6 +1103,8 @@ def modify_objects_of_complex(val: Any, modifier: Callable) -> Any: return list(rcall(item, modifier) for item in val) if isinstance(val, dict): return {rcall(key, modifier):rcall(value, modifier) for key, value in val.items()} + if isinstance(val, slice): + return slice(rcall(val.start, modifier), rcall(val.stop, modifier), rcall(val.step, modifier)) if isinstance(val, IRObject): return modifier(val) return val From d1f2cce75dc67ed42d5c41c5d9190a009e50391c Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 10 Jul 2023 07:22:40 +0000 Subject: [PATCH 1418/1892] Merged PR 1651: Add some ops introduced by megatron gpt2 --- cube/graph/function/dimops.py | 2 +- cube/graph/function/function.py | 84 +++++++++++++++++++++++++++++---- cube/graph/parser/mappingfx.py | 10 +++- cube/graph/parser/parserfx.py | 16 +++++-- cube/ir/cten.py | 5 +- 5 files changed, 99 insertions(+), 18 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index c46e1a4a..aa142a3d 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -22,7 +22,7 @@ A `reduction` can be a set of {'', '+', '^'}: '' indicates this dimension can be partitioned, and each output should have this dimension. - '+' indicates this dimension can be partitioned, and each ouutput doesn't have this and need to do sum-reduction. + '+' indicates this dimension can be partitioned, and each output doesn't have this and need to do sum-reduction. '^' means this dimension cannot be partitioned. A dimension can also be annotated with inner-dimensions using brackets, i.e., '(' and ')'. diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index e626abfa..64e11298 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -254,9 +254,8 @@ def NewTensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False, signature=None): # note: device is ignored dtype = dtype if dtype is not None else torch.get_default_dtype() - assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" signature = 'cube.runtime.function.tensor' - size = tuple(np.array(data).shape) + size = tuple(np.array(data).shape) if np.array(data).shape else (1,) # (1,) means it is a scalar kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(size, True) @@ -315,13 +314,17 @@ def Expand(input, *sizes, signature = None): torch.Tensor.expand(*sizes) """ signature = 'cube.runtime.function.expand' - edim_in = ShapeAnno.create_shape_str(input.shape) - assert len(input.shape) == len(sizes) - edim_ou = copy.copy(edim_in) - for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes)): - if dim == 1 and dim != expand_dim: + ori_len, exp_len = len(input.shape), len(sizes) + assert ori_len <= exp_len + assert all(dim == expand_dim or dim == 1 or expand_dim == -1 for dim, expand_dim in zip(input.shape, sizes[-ori_len:])) + edim_ou = ShapeAnno.create_shape_str(sizes) + edim_in = copy.copy(edim_ou[-ori_len:]) + for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes[-len(input.shape):])): + if dim == 1 and dim != expand_dim and expand_dim != -1: edim_in[idx] += '^' - edim_ou[idx] = str(expand_dim) + edim_ou[exp_len - ori_len + idx] = str(expand_dim) + for idx in range(exp_len - ori_len): + edim_ou[idx] = str(sizes[idx]) anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(Expand, 'expand', signature, [anno], [input], sizes=sizes) @@ -408,6 +411,18 @@ def Mul(input, other, *, out=None, signature = None): return IRDimops(Mul, 'mul', signature, annos, [input, other]) +def Mod(input, other, *, out = None, signature = None): + assert out is None + if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): + return input % other + signature = 'torch.fmod' + annos = ['*, ? -> *'] + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(Mod, 'mod', signature, annos, [input, other]) + + def Div(input, other, *, rounding_mode=None, out=None, signature = None): assert rounding_mode is None and out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): @@ -839,6 +854,36 @@ def Transpose(input, dim0, dim1, signature = None): dim0=dim0, dim1=dim1) +def Split(tensor, split_size_or_sections, dim = 0, signature = None): + """ + torch.functional.split(tensor, split_size_or_sections, dim=0) -> List[Tensor] + """ + if isinstance(split_size_or_sections, int): + sections = [split_size_or_sections for _ in range(tensor.shape[dim] // split_size_or_sections)] + if tensor.shape[dim] % split_size_or_sections != 0: + sections.append(tensor.shape[dim] % split_size_or_sections) + else: + sections = split_size_or_sections + assert sum(sections) == tensor.shape[dim] + edim_in = ShapeAnno.create_shape_str(tensor.shape) + edim_ous = [copy.copy(edim_in) for _ in sections] + edim_in[dim] = str(tensor.shape[dim]) + for edim_ou, dimlen in zip(edim_ous, sections): + edim_ou[dim] = str(dimlen) + anno = OpAnno.create_op_str([edim_in], edim_ous) + return IRDimops(Split, 'split', signature, [anno], [tensor], split_size_or_sections=split_size_or_sections, dim=dim) + + +def Contiguous(input, memory_format = None, signature = None): + """ + torch.Tensor.contiguous(Tensor self) -> Tensor + """ + assert memory_format is None + anno = ['* -> *'] + signature = 'torch.Tensor.contiguous' + return IRDimops(Contiguous, 'contiguous', signature, anno, [input]) + + def _reshape_anno(in_shape: List[int], ou_shape: List[int], kwarg_name: str) -> Tuple[str, List[TransformRule]]: """ reshape / view annotation and transformation rule generator @@ -1424,7 +1469,7 @@ def Repeat(tensor, repeats: Tuple[int], *arg_repeats, signature = None): signature = 'torch.ops.aten.repeat' repeats = (repeats,) if isinstance(repeats, int) else tuple(repeats) repeats = repeats + arg_repeats - in_shape = tensor.shape + in_shape = list(tensor.shape) assert len(in_shape) <= len(repeats), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor" expand = len(repeats) - len(tensor.shape) in_shape += [1] * expand @@ -1637,7 +1682,7 @@ def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: torch.Tensor.size(tensor, dim=None) """ assert isinstance(tensor, IRTensor) - val = tensor.shape[dim] if isinstance(dim, int) else list(tensor.shape) + val = tensor.shape[dim] if isinstance(dim, int) else tensor.shape assert val is not None if dim is None: return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)]) @@ -1645,6 +1690,15 @@ def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)], dim=dim) +def Dim(tensor, signature=None) -> Union[List[int], IRPyFunc]: + """ + torch.Tensor.dim(tensor) + """ + assert isinstance(tensor, IRTensor) + # constant + return len(tensor.shape) + + def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): """ torch.Tensor.to(*args, **kwargs) → Tensor @@ -1751,3 +1805,13 @@ def MakeList(inputs: Iterable, signature=None): def MakeSlice(*inputs: Iterable, signature=None): return slice(*inputs) + + +def Is(input, other, signature=None): + assert not isinstance(input, IRObject) and not isinstance(other, IRObject) + return input is other + + +def IsNot(input, other, signature=None): + assert not isinstance(input, IRObject) and not isinstance(other, IRObject) + return input is not other diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/mappingfx.py index 4fde2f20..8ba3b9e0 100644 --- a/cube/graph/parser/mappingfx.py +++ b/cube/graph/parser/mappingfx.py @@ -79,6 +79,7 @@ def exist(signature: str) -> bool: __tttemplate('bool'): function.Bool, __ttemplate('fill_'): function.Fill, __ttemplate('masked_fill'): function.MaskedFill, + __tttemplate('masked_fill_'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, __ttemplate('tanh'): function.Tanh, __ftemplate('softmax') : function.Softmax, @@ -111,6 +112,7 @@ def exist(signature: str) -> bool: # ============== runtime function ================= __tttemplate('size'): function.Size, __tttemplate('to'): function.To, + __tttemplate('dim'): function.Dim, '_operator.getitem': function.GetItem, 'builtins.getattr': function.GetAttr, 'builtins.tuple': function.MakeTuple, @@ -142,7 +144,9 @@ def exist(signature: str) -> bool: # __ttemplate('to'): function.ToTensor, __ttemplate('rand'): function.Rand, # __ttemplate('clone'): function.Clone, - # + + '_operator.is_': function.Is, + '_operator.is_not': function.IsNot, __ttemplate('add') : function.Add, '_operator.add': function.Add, '_operator.iadd': function.Add, # FIXME: may waste memory @@ -151,6 +155,7 @@ def exist(signature: str) -> bool: __ttemplate('mul') : function.Mul, '_operator.mul': function.Mul, '_operator.imul': function.Mul, # FIXME: may waste memory + '_operator.mod': function.Mod, __ttemplate('div') : function.Div, __ttemplate('true_divide'): function.Div, @@ -166,6 +171,7 @@ def exist(signature: str) -> bool: __ttemplate('lt'): function.CompareLT, __ttemplate('ge'): function.CompareGE, __ttemplate('le'): function.CompareLE, + '_operator.le': function.CompareLE, # __ttemplate('sin'): function.Sin, # @@ -176,6 +182,7 @@ def exist(signature: str) -> bool: # # __ttemplate('view'): function.View, __tttemplate('view'): function.View, + __tttemplate('contiguous'): function.Contiguous, __ttemplate('reshape'): function.Reshape, # @@ -221,4 +228,5 @@ def exist(signature: str) -> bool: # #einops # __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, + 'torch.functional.split': function.Split } diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 86ad4c62..713e441d 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -1,7 +1,7 @@ import torch import enum import warnings -from typing import Any, List, Tuple, Optional, Callable, Union, Dict +from typing import Any, List, Tuple, Callable, Union, Dict, Type from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor @@ -479,18 +479,24 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> print(f'The method is not torch or Tensor, but {sig}') return sig - # this is fixed on master, WAR for 1.5 @staticmethod def _find_module_of_method(orig_method: Callable[..., Any]) -> str: name = orig_method.__name__ module = orig_method.__module__ - if module is not None: + if getattr(orig_method, '__name__', None) == 'apply' and isinstance(getattr(orig_method, '__self__', None), Type) \ + and issubclass(orig_method.__self__, torch.autograd.Function): + # for torch.autograd.Function + return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' + elif module is not None: return module - for guess in [torch, torch.nn.functional]: + elif hasattr(orig_method, '__qualname__')\ + and isinstance(orig_method.__qualname__, str) and orig_method.__qualname__.startswith('_VariableFunctionsClass.'): + return 'torch._C._VariableFunctions' + for guess in [torch, getattr(torch.nn, 'functional')]: if getattr(guess, name, None) is orig_method: return guess.__name__ raise RuntimeError(f'cannot find module for {orig_method}') - + @staticmethod def _is_torch_autograd_op(node: torch.fx.Node, frame: Frame, signature: str) -> bool: """Check whether the node is of a pytorch autograd operation.""" diff --git a/cube/ir/cten.py b/cube/ir/cten.py index a73f3c5d..3c326461 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -650,7 +650,10 @@ def __copy__(self): @property def shape(self) -> Tuple[int]: - return list(self._shape) + # NOTE: here return a tuple but not a real torch.Size obj may have risk, here is an example: + # (torch.Size + tuple -> torch.Size) will change to (tuple + tuple -> tuple), is ok. + # (torch.Size + list -> torch.Size) will change to (tuple + list -> error), is wrong. + return self._shape @shape.setter def shape(self, val: Tuple[int]): From 887f2604712261ff0eb1a2947b2014dd8a9082be Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jul 2023 13:53:17 +0800 Subject: [PATCH 1419/1892] allow partition on mutliple dimensions of a tensor --- cube/algorithm/ops/dimops.py | 43 +++++++++++++++++++++------- cube/graph/function/dimops.py | 29 ++++++++++++------- examples/policies/alpa/cost_model.py | 2 +- 3 files changed, 53 insertions(+), 21 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 32e1b8fe..a87a69b0 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -1,4 +1,6 @@ from typing import List, Optional, Any, Dict, Union, Tuple +import warnings +import numpy as np from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule @@ -136,7 +138,14 @@ def transform(tensor: Any, split: DimopSplit) -> List[Any]: if not isinstance(tensor, IRSubTensor): return [tensor] * num if split.isD(): - return tensor.split_dim(split.dim, num) + sub_tensors = [tensor] + for dim in split.dims: + for _ in range(len(sub_tensors)): + sub_tensor = sub_tensors.pop(0) + sub_tensors += sub_tensor.split_dim(dim, num) + sub_tensors = np.array(sub_tensors, dtype=IRSubTensor).reshape((num,) * len(split.dims)) + sub_tensors = [sub_tensors[(i,) * len(split.dims)] for i in range(num)] + return sub_tensors if split.isR(): return tensor.replicate(num) if split.isV(): @@ -180,8 +189,11 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR if splits[idx].isD(): # make negative offset to be possitive ndims = len(node.input(idx).shape) - rdim = (splits[idx].dim + ndims) % ndims - if rdim == dim: + # rdim = (splits[idx].dims + ndims) % ndims + # if rdim == dim: + # return r + rdims = tuple((d + ndims) % ndims for d in splits[idx].dims) + if dim in rdims: return r else: if splits[idx].isV(): @@ -195,20 +207,31 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR # input for idx, idim in enumerate(node.anno.inputs()): dims = idim.getdims(adim) - assert len(dims) <= 1, "Cannot split on multiple same tensors" - if len(dims) == 1: - itransform.append(DimopSplit.D(dims[0])) - else: + if len(dims) == 0: itransform.append(DimopSplit.R()) + else: + if len(dims) > 1: + warnings.warn( + f'node ({self.node.name}-{self.node.cid}): detected an input tensor ' + f'is split on {len(dims)} dimensions, this will cause data loss.', + category=RuntimeWarning, stacklevel=0, + ) + itransform.append(DimopSplit.D(dims)) # output for idx, odim in enumerate(node.anno.outputs()): dims = odim.getdims(adim) - if len(dims) == 1: - otransform.append(DimopSplit.D(dims[0])) - else: + if len(dims) == 0: otransform.append( DimopSplit.R() if reduce == DimAnno.ReduceType.Dim else DimopSplit.V() ) + else: + if len(dims) > 1: + warnings.warn( + f'node ({self.node.name}-{self.node.cid}): detected an output tensor ' + f'is split on {len(dims)} dimensions, this will cause data loss.', + category=RuntimeWarning, stacklevel=0, + ) + otransform.append(DimopSplit.D(dims)) # modifier def modify(kwargs: Dict, idx: int, dim: int, num: int): updated_kwargs = dict(**kwargs) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index aa142a3d..5c30f0b0 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -488,16 +488,25 @@ class DimopSplit: """ Partition status of a tensor """ - def __init__(self, dim: Optional[int] = None, r = False, v = False) -> None: - self.dim = dim - self.rep = r - self.val = v + def __init__(self, dims: Optional[Union[int, List[int]]] = None, r = False, v = False) -> None: + """Dimension split config + + Args: + dims (Optional[Union[int, List[int]]], optional): [description]. Defaults to None. + """ + if isinstance(dims, int): + dims = (dims,) + elif isinstance(dims, Iterable): + dims = tuple(sorted(dims)) + self.dims: Optional[Tuple[int]] = dims + self.rep: bool = r + self.val: bool = v def isR(self) -> bool: return self.rep def isD(self) -> bool: - return self.dim is not None + return self.dims is not None def isV(self) -> bool: return self.val @@ -507,7 +516,7 @@ def __eq__(self, other): return False if other.isR() and self.isR(): return True - if other.isD() and self.isD() and other.dim == self.dim: + if other.isD() and self.isD() and other.dims == self.dims: return True if other.isV() and self.isV(): return True @@ -519,11 +528,11 @@ def __hash__(self) -> int: elif self.isR(): return -2 else: - return self.dim + return self.dims def __repr__(self) -> str: if self.isD(): - return f'D({self.dim})' + return f'D({self.dims})' if self.isR(): return f'R' if self.isV(): @@ -539,8 +548,8 @@ def V(): return DimopSplit(v=True) @staticmethod - def D(dim: int): - return DimopSplit(dim=dim) + def D(dims: Union[int, List[int]]): + return DimopSplit(dims=dims) class TransformRule: diff --git a/examples/policies/alpa/cost_model.py b/examples/policies/alpa/cost_model.py index a14b2125..57772c5b 100644 --- a/examples/policies/alpa/cost_model.py +++ b/examples/policies/alpa/cost_model.py @@ -181,7 +181,7 @@ def comm_cost(tensor: IRTensor, num_devices: int, CommCost.allgather_cost(tensor, num_devices) + CommCost.reducescatter_cost(tensor, num_devices) # all2all-all2all or identity-identity if dst_split.isD(): - return 0.0 if src_split.dim == dst_split.dim else 2 * CommCost.alltoall_cost(tensor, num_devices) + return 0.0 if src_split == dst_split else 2 * CommCost.alltoall_cost(tensor, num_devices) raise NotImplementedError(f"Unknown split type: {src_split} -> {dst_split}") # FIXME: need consider cases that an operator has multiple **same** inputs From ceae089ce9d03d159d6e898bf005f9c26220df3a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 11 Jul 2023 13:56:12 +0800 Subject: [PATCH 1420/1892] add tests for operator partition algorithms --- tests/algorithm/test_op_algorithm.py | 56 ++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 tests/algorithm/test_op_algorithm.py diff --git a/tests/algorithm/test_op_algorithm.py b/tests/algorithm/test_op_algorithm.py new file mode 100644 index 00000000..95f51545 --- /dev/null +++ b/tests/algorithm/test_op_algorithm.py @@ -0,0 +1,56 @@ +""" +python tests/algorithm/test_op_algorithm.py +pytest tests/algorithm/test_op_algorithm.py +""" + +from typing import Callable, Tuple, List +from functools import partial + +import cube +import cube.graph.function as F +from cube.graph.function.dimops import IRDimops +from cube.ir.tensor import IRFullTensor + + +Shape=Tuple[int] + + +def create_op(creator: Callable, + input_shapes: List[Tuple[int]], *args, **kwargs): + inputs = tuple(IRFullTensor(shape=shape).tosub() for shape in input_shapes) + return creator(*(inputs+args), **kwargs) + + +def partitionable(node: IRDimops, **config): + print(f'\n\n# {node.anno}') + print(f'testing node: {node}') + sub_nodes = node.algorithms('dim').instantiate(**config) + print(f'partitioned sub nodes:') + for sub_node in sub_nodes: + print(f'# {sub_node.anno}') + print(sub_node) + + +test_view1 = partial(partitionable, + create_op(F.Reshape, [(2048, 16, 64),], shape=[2048, 2, 512]), + idx=0, dim=1, num=2, +) + +test_view2 = partial(partitionable, + create_op(F.Reshape, [(2048, 8, 64),], shape=[2048, 1, 512]), + idx=0, dim=1, num=2, +) + +def create_udf_op1(input, weight, signature='test_udf_op1'): + anno = 'L 8^ (L 2), L E -> 8^ (L 2) E ' + return IRDimops(create_udf_op1, 'udf_op1', signature, [anno], [input, weight]) + +test_multi_dim_partition = partial(partitionable, + create_op(create_udf_op1, [(2048, 8, 4096), (2048, 4096)]), + idx=0, dim=0, num=2, +) + +if __name__ == '__main__': + test_view1() + test_view2() + test_multi_dim_partition() \ No newline at end of file From d384d8955ee7846ed4ab90b9b2efcb5da37a4d6d Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 11 Jul 2023 07:21:12 +0000 Subject: [PATCH 1421/1892] Merged PR 1650: Remove DCEHandler --- .../parser/concrete_trace_utils/__init__.py | 1 + .../kwargs_shape_prop/kwargs_interpreter.py | 1 - .../kwargs_shape_prop/kwargs_shape_prop.py | 37 +------------------ .../parser/concrete_trace_utils/utils.py | 13 +++++++ cube/graph/parser/converter.py | 4 +- cube/graph/parser/parserfx.py | 2 - 6 files changed, 18 insertions(+), 40 deletions(-) diff --git a/cube/graph/parser/concrete_trace_utils/__init__.py b/cube/graph/parser/concrete_trace_utils/__init__.py index 0e04acba..e4a574ff 100644 --- a/cube/graph/parser/concrete_trace_utils/__init__.py +++ b/cube/graph/parser/concrete_trace_utils/__init__.py @@ -12,3 +12,4 @@ More information about concrete tracing can be found in the :func:`concrete_trace` documentation. """ from .concrete_tracer import ConcreteTracer, concrete_trace +from .utils import ExtraSEFPatcher diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py index 4799a2b3..e8add705 100644 --- a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py +++ b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py @@ -1,6 +1,5 @@ import torch import torch.fx -import torch.fx.traceback as fx_traceback from torch.fx import Interpreter, Node, GraphModule from typing import Optional, Union, Tuple, Dict, List, Any, Iterator, Callable, MutableMapping, Mapping from torch.utils._pytree import tree_map diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py index 2cbad447..0a63e95c 100644 --- a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py +++ b/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py @@ -1,14 +1,11 @@ -import builtins -import operator import torch import traceback -from torch.fx import GraphModule from torch.fx.node import Node, map_aggregate from typing import Optional, Union, NamedTuple, Tuple, Any, Dict from .kwargs_interpreter import KwargsInterpreter -__all__ = ['TensorMetadata', 'KwargsShapeProp', 'DCEHandler'] +__all__ = ['TensorMetadata', 'KwargsShapeProp'] class TensorMetadata(NamedTuple): @@ -100,35 +97,3 @@ def extract_tensor_meta(obj): def propagate(self, concrete_args: Union[Dict[str, Any], Tuple]): return super().run(concrete_args) - - -class DCEHandler: - dont_delete = [operator.setitem, builtins.next] - - def __init__(self, gm: torch.fx.GraphModule): - self.gm = gm - - def eliminate_dead_code(self): - to_check = set() - for node in self.gm.graph.nodes: - to_check.add(node) - while True: - deleted = False - modified = set() - for node in to_check: - if node.op == 'output': - continue - if not node.users and node.op != 'placeholder' and node.target not in DCEHandler.dont_delete: - for input_node in node.all_input_nodes: - input_node.users.pop(node) - modified.add(input_node) - node._remove_from_list() - if node in modified: - modified.remove(node) - deleted = True - if deleted is False: - break - else: - to_check = modified - name = self.gm.__class__.__name__ - return GraphModule(self.gm, self.gm.graph, name) diff --git a/cube/graph/parser/concrete_trace_utils/utils.py b/cube/graph/parser/concrete_trace_utils/utils.py index 2604340a..202af5e2 100644 --- a/cube/graph/parser/concrete_trace_utils/utils.py +++ b/cube/graph/parser/concrete_trace_utils/utils.py @@ -108,3 +108,16 @@ def map_recursive_zip(fn: Callable, arg0, *args) -> Any: else: # assert not _orig_isinstance(arg0, slice) return fn(arg0, *args) + + +class ExtraSEFPatcher: + from torch.fx.node import _side_effectful_functions + # some side effectful functions that should not be deleted during dead code elimination + # there may be more than listed here + extra_funcs = {operator.setitem, builtins.next} - _side_effectful_functions + + def __enter__(self): + self._side_effectful_functions.update(self.extra_funcs) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._side_effectful_functions.difference_update(self.extra_funcs) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 7ae15cae..191a5c5a 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -4,7 +4,7 @@ from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser from cube.graph.parser import FxModuleParser, FxFuncOpTracer -from cube.graph.parser.concrete_trace_utils import concrete_trace +from cube.graph.parser.concrete_trace_utils import concrete_trace, ExtraSEFPatcher from cube.graph.parser.register import CustomizedOps from cube.graph import IRGraph from cube.flags import CompileFlag @@ -38,6 +38,8 @@ def convert_model(model: torch.nn.Module, # Symbolic tracing frontend - captures the semantics of the module tracer = FxFuncOpTracer() traced_graph: torch.fx.Graph = tracer.trace(model) + with ExtraSEFPatcher(): + traced_graph.eliminate_dead_code() traced_model: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) if CompileFlag.log_parser: traced_model.graph.print_tabular() diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 713e441d..73e53924 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -94,9 +94,7 @@ def parse(module: torch.fx.GraphModule, The overall entry to parse a torch.fx graph module """ - from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import DCEHandler from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp - DCEHandler(module).eliminate_dead_code() frame = frame if frame is not None else Frame() frame.push_var() From 2c32c10ead527a4698641eb2cf1f8e6a7e79b830 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 12 Jul 2023 06:45:05 +0000 Subject: [PATCH 1422/1892] Merged PR 1659: Add origin code as comments to the gen code --- .../concrete_trace_utils/concrete_tracer.py | 21 ++++++++++++++++++- .../parser/concrete_trace_utils/utils.py | 14 +++++++++++++ cube/graph/parser/parserfx.py | 8 +++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py index 7f9e195b..b809fefc 100644 --- a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/concrete_trace_utils/concrete_tracer.py @@ -11,12 +11,14 @@ import operator import functools import builtins -from packaging import version +import traceback +import importlib.util from itertools import chain from types import BuiltinMethodType, FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType from typing import Any, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, List, Callable, Union from contextlib import contextmanager +from pathlib import Path import torch from torch._C import ScriptObject @@ -122,6 +124,8 @@ def __exit__(self, *args): _orig_max, _orig_node_is_impure, + + FrameRecord, ) # some side effectful functions that should not be deleted during dead code elimination @@ -443,6 +447,21 @@ def upwrapper(obj: Any): node = self.create_node(kind, target, args_, kwargs_, name, type_expr) + # record code frame, include filename, line number, and function name + frame_record = FrameRecord(None, None, None, None) + cube_cct_path = str(Path(__file__).parent) + '/' # the cube concrete tracer path + torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path + ignore_dirs = [cube_cct_path, torch_path] + for frame in traceback.extract_stack()[-2::-1]: + if any(p in frame.filename for p in ignore_dirs): + continue + frame_record.filename = frame.filename + frame_record.lineno = frame.lineno + frame_record.line = frame.line + frame_record.name = frame.name + break + node.meta['frame_record'] = frame_record + proxy = self.proxy(value_unwrapped, node) return proxy diff --git a/cube/graph/parser/concrete_trace_utils/utils.py b/cube/graph/parser/concrete_trace_utils/utils.py index 202af5e2..8265e6df 100644 --- a/cube/graph/parser/concrete_trace_utils/utils.py +++ b/cube/graph/parser/concrete_trace_utils/utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import builtins +from dataclasses import dataclass import operator from typing import Any, Callable, Type import functools @@ -110,6 +111,19 @@ def map_recursive_zip(fn: Callable, arg0, *args) -> Any: return fn(arg0, *args) +@dataclass +class FrameRecord: + filename: str + lineno: str + line: str + name: str + + def __repr__(self) -> str: + if self.filename: + return f'File "{self.filename}", line {self.lineno}, in {self.name}, {self.line}' + else: + return '' + class ExtraSEFPatcher: from torch.fx.node import _side_effectful_functions # some side effectful functions that should not be deleted during dead code elimination diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/parserfx.py index 73e53924..7e30378d 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/parserfx.py @@ -294,6 +294,9 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: assert len(ir_node.outputs()) == 1 output_val = frame.get_var(node.name) ir_node.set_output(0, output_val) + comment = str(node.meta.get('frame_record', '')) + if comment: + ir_node.comment = comment return [ir_node] else: raise RuntimeError(f'unknown module: {prim_module.__class__.__module__}') @@ -328,6 +331,11 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, category=RuntimeWarning) ir_node = IRPyFunc(fsig, input_vals, [IRObject()], **kwargs) + if isinstance(ir_node, IRCell): + comment = str(node.meta.get('frame_record', '')) + if comment: + ir_node.comment = comment + ir_nodes = [] if isinstance(ir_node, IRCell): ir_nodes.append(ir_node) From b02cd46a89193176921b8e5d4f863de9419d0028 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 09:14:53 +0800 Subject: [PATCH 1423/1892] fix IRPyfunc auto-schedule --- cube/graph/gener/gen.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 23cfc806..95e452cd 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -137,9 +137,21 @@ def remove_anchor(graph: IRSegment): @staticmethod def auto_pyfunc(graph: IRGraph): - """ - Make pyfunc to be local + """Transform and assign IRPyFunc. + IRPyFunc will be replicated to devices with its producers output + + Note if an IRPyFunc has no input, indicating its device can not + be indicated from any other operators. In this case, the pyfunc + will be replicated to all devices in its segment. To restrict + the replicaed devices in pipeline-like scenarios, use `graph.staging` + to group the operators into segments. + + Args: + graph (IRGraph): the graph to be transformed + + Returns: + graph (IRGraph): the transformed graph """ for func in graph.select(ntype=IRPyFunc, flatten=True): # get devices it will lowered to @@ -155,6 +167,10 @@ def auto_pyfunc(graph: IRGraph): if not isinstance(t, IRObject): continue if t in segment_outputs: devices.update(segment.device) + # if a pyfunc doesn't have input, it will be replicated + # to all devices in its segment. + if len(devices) == 0: + devices = set(segment.device) # replicate pyfuncs = [func.replicate() for _ in devices] for devid, pyfunc in zip(sorted(devices), pyfuncs): From 7ab1852882543100256404a3d3ff1f4885169b78 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 14:03:02 +0800 Subject: [PATCH 1424/1892] remove useless handcraft code --- handcraft/bigbird/sparse.py | 550 --------- handcraft/bigbird/sparse_attn.py | 968 --------------- handcraft/efficientnet/efficientnet.py | 411 ------- handcraft/efficientnet/schedule.py | 250 ---- handcraft/efficientnet/train.py | 338 ------ handcraft/efficientnet/utils.py | 586 --------- handcraft/gpt3/test-1gpu.sh | 71 -- handcraft/gpt3/test-1node.sh | 284 ----- handcraft/gpt3/test-2node.sh | 115 -- handcraft/gpt3/train.py | 745 ------------ handcraft/mbart/swap.py | 146 --- handcraft/mbart/test-2node-fp32.sh | 193 --- handcraft/mbart/test-4node-fp32.sh | 216 ---- handcraft/mbart/test-fp32.sh | 153 --- handcraft/mbart/train.py | 861 -------------- handcraft/module/distnn.py | 278 ----- handcraft/module/schedule.py | 698 ----------- handcraft/module/stage.py | 166 --- handcraft/playground/dag/data_parallel_raw.py | 130 -- .../playground/dag/graph_manipulation.py | 550 --------- handcraft/playground/dag/graph_trans.py | 46 - handcraft/playground/test.sh | 144 --- handcraft/playground/transformers.py | 326 ----- handcraft/swin/test-2node.sh | 235 ---- handcraft/swin/test-4node.sh | 300 ----- handcraft/swin/test.sh | 268 ----- handcraft/swin/train.py | 1045 ----------------- handcraft/swin/utils.py | 115 -- handcraft/textnas/dataloader.py | 144 --- handcraft/textnas/dataset.sh | 6 - handcraft/textnas/ops.py | 240 ---- handcraft/textnas/train.py | 280 ----- 32 files changed, 10858 deletions(-) delete mode 100644 handcraft/bigbird/sparse.py delete mode 100644 handcraft/bigbird/sparse_attn.py delete mode 100644 handcraft/efficientnet/efficientnet.py delete mode 100644 handcraft/efficientnet/schedule.py delete mode 100644 handcraft/efficientnet/train.py delete mode 100644 handcraft/efficientnet/utils.py delete mode 100755 handcraft/gpt3/test-1gpu.sh delete mode 100755 handcraft/gpt3/test-1node.sh delete mode 100755 handcraft/gpt3/test-2node.sh delete mode 100644 handcraft/gpt3/train.py delete mode 100644 handcraft/mbart/swap.py delete mode 100755 handcraft/mbart/test-2node-fp32.sh delete mode 100755 handcraft/mbart/test-4node-fp32.sh delete mode 100755 handcraft/mbart/test-fp32.sh delete mode 100644 handcraft/mbart/train.py delete mode 100644 handcraft/module/distnn.py delete mode 100644 handcraft/module/schedule.py delete mode 100644 handcraft/module/stage.py delete mode 100644 handcraft/playground/dag/data_parallel_raw.py delete mode 100644 handcraft/playground/dag/graph_manipulation.py delete mode 100644 handcraft/playground/dag/graph_trans.py delete mode 100755 handcraft/playground/test.sh delete mode 100644 handcraft/playground/transformers.py delete mode 100755 handcraft/swin/test-2node.sh delete mode 100755 handcraft/swin/test-4node.sh delete mode 100755 handcraft/swin/test.sh delete mode 100644 handcraft/swin/train.py delete mode 100644 handcraft/swin/utils.py delete mode 100644 handcraft/textnas/dataloader.py delete mode 100755 handcraft/textnas/dataset.sh delete mode 100644 handcraft/textnas/ops.py delete mode 100644 handcraft/textnas/train.py diff --git a/handcraft/bigbird/sparse.py b/handcraft/bigbird/sparse.py deleted file mode 100644 index 05ae5997..00000000 --- a/handcraft/bigbird/sparse.py +++ /dev/null @@ -1,550 +0,0 @@ -""" -BigBird paper -https://papers.nips.cc/paper/2020/file/c8512d142a2d849725f31a9a7a361ab9-Paper.pdf - -Understanding blog: -https://github.com/huggingface/blog/blob/main/big-bird.md - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/bigbird/sparse.py \ - --hidden-size 5120 --heads 32 --seqlen 4096 \ - --bs 1 --fp16 -""" - -import torch -import torch.nn as nn -import cube -import math -import numpy as np - -import argparse -from cube.runtime.device import DeviceGroup -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - -from cube.runtime.adapter.distnn import AllGatherSplit, IdentityAllreduce - - - -parser = argparse.ArgumentParser(description='sparse_attention') - -parser.add_argument('--hidden-size', type=int, default=4096, - help='hidden size') -parser.add_argument('--heads', type=int, default=32, - help='number of heads') -parser.add_argument('--seqlen', type=int, default=4096, - help='sequence length') -parser.add_argument('--blk-size', type=int, default=64, - help='sequence length') -# training config -parser.add_argument('--bs', type=int, default=256, - help='num of micro batch') -parser.add_argument('--fp16', action='store_true', default=False) -parser.add_argument('--sparse', action='store_true', default=False) -args = parser.parse_args() - -print(args) -cube.init() - -tp_ranks = list(range(DeviceGroup().world_size)) -_tp_group = -1 -_tp_size = len(tp_ranks) -_tp_rank = DeviceGroup().rank -if len(tp_ranks) > 1: - print_each_rank(f'initializing tp ranks: {tp_ranks}') - _tp_group = DeviceGroup().get_group(tp_ranks) - - -class Config: - - num_attention_heads = args.heads - hidden_size = args.hidden_size - all_head_sie = hidden_size - seqlen = args.seqlen # seqlen - num_random_blocks = 2 - block_size=args.blk_size - use_bias = True - -config = Config() - - -def bmm(tensor1: torch.Tensor, tensor2: torch.Tensor, ndim: int, out=None): - # print(f'bmm: {tensor1.size()} {tensor2.size()}') - return torch.bmm( - tensor1.reshape((-1,) + tensor1.shape[-2:]), - tensor2.reshape((-1,) + tensor2.shape[-2:]), - out=out - ).view(tensor1.shape[: ndim - 2] + (tensor1.shape[ndim - 2], tensor2.shape[ndim-1])) - - -def stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - h: int, block_size: int, out=None): - CudaTimer().start('stride_q@k') - # print('start stride qk') - # q, k, v: (N h) L d - num_head = h - L, N = q.size(1), q.size(0) // h - dim_head = q.size(2) - assert L % block_size == 0 - - q = q.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d - k = k.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d - - # stride diagnal [2:] - middle_q = q[:, 2:] # (N h) (nblock-2) blksize d - # (N h) nblock-3 (3 blksize) d - sliding_keys = torch.cat((k[:, 1:-2], k[:, 2:-1], k[:, 3:]), dim=2) - # (N h) 1 blksize d - pad_zero = torch.zeros_like(k[:,-3:-2]) - # (N h) 1 (3 blksize) d - sliding_bottom_keys = torch.cat((pad_zero, k[:,-2:-1], k[:,-1:]), dim=2) - # (N h) (nblock-2) (3 blksize) d - sliding_keys = torch.cat((sliding_keys, sliding_bottom_keys), dim=1) - # (N h) (nblock-2) d (3 blksize) - sliding_keys = sliding_keys.transpose(2, 3) - - # (N h) (nblock-2) blksize (3 blksize) - out = bmm(middle_q, sliding_keys, ndim=4, out=out) - # (N h) ((nblock-2) blksize) (3 blksize) - qk = out.view(N * h, -1, block_size * 3) - CudaTimer().stop('stride_q@k') - return qk - - -def parallel_stride_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - h: int, block_size: int, start: int, end: int): - CudaTimer().start('parallel_stride_q@k') - # print('start stride qk') - # q, k, v: (N h) L d - num_head = h - L, N = q.size(1), q.size(0) // h - dim_head = q.size(2) - assert L % block_size == 0 - - q = q.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d - k = k.view(N * num_head, L // block_size, block_size, dim_head) # (N h) nblock blksize d - - # (N h) 1 blksize d - pad_zero = torch.zeros_like(k[:,-1:]) - # (N h) (end-start) blksize d - middle_q = q[:,2+start:2+end] - if end + 2 == L // block_size: - # (N h) (end-start) blksize d - k_right = torch.cat((k[:,start+3:], pad_zero), dim=1) - else: - k_right = k[:,start+3:end+3] - # (N h) (end-start) (3 blksize) d - sliding_keys = torch.cat((k[:, start+1:end+1], k[:, start+2:end+2], k_right), dim=2) - # (N h) (end-start) blksize (3 blksize) - qk = bmm(middle_q, sliding_keys.transpose(2, 3), ndim=4) - # (N h) (nblock-2) blksize (3 blksize) - qk = torch.nn.functional.pad( - qk, (0,0,0,0,start,L//block_size-2-end), 'constant', 0) - # (N h) ((nblock-2) blksize) (3 blksize) - qk = qk.view(N * h, -1, block_size * 3) - CudaTimer().stop('parallel_stride_q@k') - return qk - - -def stride_v(v: torch.Tensor, h: int, block_size: int): - L, N, dim_head = v.size(1), v.size(0) // h, v.size(2) - assert L % block_size == 0 - v = v.view(N * h, L // block_size, block_size, dim_head) - # (N h) 1 blksize d - pad_zero = torch.zeros_like(v[:,-3:-2]) - # (N h) (nblock-3) (3 blksize) d - stride_vals = torch.cat((v[:, 1:-2], v[:, 2:-1], v[:, 3:]), dim=2) - # (N h) 1 (3 blksize) d - stride_bottom_vals = torch.cat((v[:,-2:-1], v[:,-1:], pad_zero), dim=2) - # (N h) (nblock-2) (3 blksize) d - stride_vals = torch.cat((stride_vals, stride_bottom_vals), dim=1) - return stride_vals - - -def global_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - h: int, block_size: int): - CudaTimer().start('global_q@k') - # print('start global qk') - # q, k, v: (N h) L d - num_head = h - L, N = q.size(1), q.size(0) // h - dim_head = q.size(2) - assert L % block_size == 0 - - # first two row - head_q = q[:, :2 * block_size] # (N h) (2 blocksize) d - head_k = k.transpose(1, 2) # (N h) d L - head = bmm(head_q, head_k, ndim=3) # (N h) (2 blocksize) L - # (N h) L d - head_v = v - - # remain first two column - col_q = q[:, 2 * block_size:] # (N h) ((nblock-2) blocksize) d - col_k = k[:, :2 * block_size].transpose(1, 2) # (N h) d (2 blocksize) - # (N h) (2 blksize) d - col_v = v[:, :2 * block_size] - col = bmm(col_q, col_k, ndim=3) # (N h) ((nblock-2) blocksize) (2 blocksize) - CudaTimer().stop('global_q@k') - return head, head_v, col, col_v - - -def randn_qk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - h: int, block_size: int, rand_num: int = 2, out=None): - CudaTimer().start('rand_q@k') - torch.manual_seed(0) - # q, k, v: (N h) L d - # rand_num = 2 - num_head = h - L, N = q.size(1), q.size(0) // h - dim_head = q.size(2) - # nblock-2 2 - indices = torch.randint( - 2, L // block_size, (L // block_size-2, rand_num), - dtype=torch.int64, device=torch.cuda.current_device() - ) - # (N h) nblock blksize d - k = k.view(N * num_head, L // block_size, block_size, dim_head) - - # Optimize: remove for loop, use direct index can greatly speedup - # (N h) nblock-2 (randnum blksize) d - keys = tuple(k[:,indices[:,idx]] for idx in range(rand_num)) - gathered_k = torch.cat(keys, dim=2) - - # (N h) nblock blksize d - q = q.view(N * num_head, L // block_size, block_size, dim_head) - # (N h) nblock-2 blksize d - q = q[:,2:] - # (N h) nblock-2 blksize (randnum blksize) - out = bmm(q, gathered_k.transpose(2, 3), ndim=4, out=out) - # (N h) ((nblock-2) blksize) (randnum blksize) - qk = out.view(N * h, -1, rand_num * block_size) - CudaTimer().stop('rand_q@k') - return qk - - -def randn_v(v: torch.Tensor, h: int, block_size: int, rand_num: int = 2): - # v: (N h) L d - # CudaTimer().start('rand_v') - torch.manual_seed(0) - L, N, dim_head = v.size(1), v.size(0) // h, v.size(2) - # nblock-2 2 - indices = torch.randint( - 2, L // block_size, (L // block_size-2, rand_num), - dtype=torch.int64, device=torch.cuda.current_device() - ) - v = v.view(N * h, L // block_size, block_size, dim_head) - vals = tuple(v[:,indices[:,idx]] for idx in range(rand_num)) - gathered_v = torch.cat(vals, dim=2) - # CudaTimer().stop('rand_v') - return gathered_v - - -def sparse_attn(query: torch.Tensor, - q_proj: torch.Tensor, q_bias: torch.Tensor, - k_proj: torch.Tensor, k_bias: torch.Tensor, - v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, block_size: int): - rand_num = 2 - num_head = h - L, N = query.size(0), query.size(1) - dim_head = q_proj.size(0) // num_head - nblocks = L // block_size - - CudaTimer().start('to_qkv') - q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - CudaTimer().stop('to_qkv') - - sqk = torch.empty( - N * h, (nblocks - 2) * block_size, 3 * block_size, - dtype=torch.float16 if args.fp16 else torch.float32, - device=torch.cuda.current_device() - ) - - rqk = torch.empty( - N * h, (nblocks - 2) * block_size, 2 * block_size, - dtype=torch.float16 if args.fp16 else torch.float32, - device=torch.cuda.current_device() - ) - # we don't need pre-allocation as memory are sufficient - # sqk = rqk = None - - CudaTimer().start('q@k') - # sqk: (N h) ((nblock-2) blksize) (3 blksize) - sqk = stride_qk(q, k, v, h, block_size=block_size, out=sqk) - # head: (N h) (2 blocksize) L - # head_v: (N h) L d - # col: (N h) ((nblock-2) blocksize) (2 blocksize) - # col_v: (N h) (2 blksize) d - head, head_v, col, col_v = global_qk(q, k, v, h, block_size=block_size) - # rqk: (N h) ((nblock-2) blksize) (2 blksize) - rqk = randn_qk(q, k, v, h, block_size=block_size, out=rqk) - # (N h) ((nblock-2) blksize) (7 blksize) - middle_attn = torch.cat((col, sqk, rqk), dim=-1) - CudaTimer().stop('q@k') - - CudaTimer().start('all_softmax') - # (N h) (2 blksize) L - head_attn = torch.nn.functional.softmax(head, dim=-1) - head_attn = torch.nn.functional.dropout(head_attn, dropout_p, True, False) - # (N h) ((nblock-2) blksize) (7 blksize) - middle_attn = torch.nn.functional.softmax(middle_attn, dim=-1) - middle_attn = torch.nn.functional.dropout(middle_attn, dropout_p, True, False) - CudaTimer().stop('all_softmax') - - CudaTimer().start('qk@v') - # sqk_v: (N h) (nblock-2) (3 blksize) d - sqk_v = stride_v(v, h, block_size) - # rqk_v: (N h) (nblock-2) (2 blksize) d - rqk_v = randn_v(v, h, block_size) - - CudaTimer().start('global_qk@v') - # (N h) (2 blocksize) L, (N h) L d -> (N h) (2 blksize) d - head_output = bmm(head_attn, v, ndim=3) - - # global col v: (N h) ((nblock-2) blksize) (2 blksize), (N h) (2 blksize) d - # :-> (N h) (L-(2 blksize)) d - middle_output = bmm(middle_attn[:,:,:2 * block_size], col_v, ndim=3) - CudaTimer().stop('global_qk@v') - - CudaTimer().start('stride_qk@v') - middle_stride = middle_attn[:,:,2*block_size:5*block_size].view( - N * h, L // block_size - 2, block_size, 3 * block_size - ) - # stide v: (N h) (nblock-2) blksize (3 blksize), (N h) (nblock-2) (3 blksize) d - # : -> (N h) (nblock-2) blksize d - middle_stride_output = bmm(middle_stride, sqk_v, ndim=4) - middle_output += middle_stride_output.view(N * h, -1, sqk_v.size(-1)) - CudaTimer().stop('stride_qk@v') - - # (N h) (nblock-2) blksize (randnum blksize) - CudaTimer().start('rand_qk@v') - middle_rand = middle_attn[:,:,5*block_size:].view( - N * h, L // block_size - 2, block_size, rand_num * block_size - ) - # rand v: (N h) (nblock-2) blksize (2 blksize), (N h) (nblock-2) (2 blksize) d - # -> (N h) (nblock-2) blksize d - middle_rand_output = bmm(middle_rand, rqk_v, ndim=4) - middle_output += middle_rand_output.view(N * h, -1, rqk_v.size(-1)) - CudaTimer().stop('rand_qk@v') - - # (N h) (2 blksize) d, (N h) ((nblock-2) blksize) d -> (N h) L d - output = torch.cat((head_output, middle_output), dim=1) - CudaTimer().stop('qk@v') - - CudaTimer().start('out_proj') - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E - CudaTimer().stop('out_proj') - return output - - -def dense_attn(query: torch.Tensor, - q_proj: torch.Tensor, q_bias: torch.Tensor, - k_proj: torch.Tensor, k_bias: torch.Tensor, - v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, block_size: int): - num_head = h - L, N = query.size(0), query.size(1) - dim_head = q_proj.size(0) // num_head - - CudaTimer().start('to_qkv') - q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - # v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q.transpose(0, 1).contiguous() # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - CudaTimer().stop('to_qkv') - # k = k.transpose(1, 2).contiguous() # (N h) L d -> (N h) d L - - CudaTimer().start('allocation') - attn = torch.empty( - (N * h), L, L, dtype=torch.float16 if args.fp16 else args.fp32, - device=torch.cuda.current_device() - ) - CudaTimer().stop('allocation') - - CudaTimer().start('q@k') - attn = torch.bmm(q, k.transpose(1, 2), out=attn) # (N h) L d, (N h) d L -> (N h) L L - CudaTimer().stop('q@k') - - # attention mask - attention_mask = torch.ones(((N, L)), device=torch.cuda.current_device()) - attention_mask = attention_mask.view(N, L // block_size, block_size) - exp_blocked_to_pad = torch.cat( - [attention_mask[:, 1:-3], attention_mask[:, 2:-2], attention_mask[:, 3:-1]], dim=2 - ) - band_mask = torch.einsum("blq,blk->blqk", attention_mask[:, 2:-2], exp_blocked_to_pad) - band_mask.unsqueeze_(1) - band_mask = band_mask < 0.5 - # attn.masked_fill_(band_mask, -1000.0) - - CudaTimer().start('all_softmax') - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L - CudaTimer().stop('all_softmax') - - CudaTimer().start('qk@v') - output = torch.bmm(attn, v.transpose(0, 1)) # (N h) L L, (N h) L d -> (N h) L d - CudaTimer().stop('qk@v') - - CudaTimer().start('out_proj') - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E - CudaTimer().stop('out_proj') - return output - - -class MultiHeadSelfAttention(torch.nn.Module): - - def __init__(self): - super().__init__() - self.kdim = config.hidden_size - self.vdim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = 0.0 - self.block_size = config.block_size - # Q - self.q_proj = torch.nn.Parameter(torch.empty(config.hidden_size, config.hidden_size)) - self.q_bias = torch.nn.Parameter(torch.empty(config.hidden_size)) - # K - self.k_proj = torch.nn.Parameter(torch.empty(config.hidden_size, self.kdim)) - self.k_bias = torch.nn.Parameter(torch.empty(config.hidden_size)) - # V - self.v_proj = torch.nn.Parameter(torch.empty(config.hidden_size, self.vdim)) - self.v_bias = torch.nn.Parameter(torch.empty(config.hidden_size)) - # Out - self.out_proj = torch.nn.Parameter(torch.empty(config.hidden_size, config.hidden_size)) - self.out_bias = torch.nn.Parameter(torch.empty(config.hidden_size)) - - def forward(self, query): - if args.sparse: - if _tp_size > 1: - return parallel_sparse_attn( - query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p, self.block_size - ) - else: - return sparse_attn( - query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p, self.block_size - ) - else: - return dense_attn( - query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p, self.block_size - ) - - -class AttnDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int): - self.bs = batch_size - super().__init__( - shapes=( - [config.seqlen, args.bs, config.hidden_size], - ), - dtypes=(torch.float16 if args.fp16 else torch.float,), - batch_dims=(0,) - ) - self.samples = [self.random_sample()] - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] - - def random_sample(self): - hidden_state = torch.randn( - config.seqlen, self.bs, config.hidden_size, - dtype=torch.float16 if args.fp16 else torch.float, - device=torch.cuda.current_device() - ) - return hidden_state - - - -if __name__ == '__main__': - - model = MultiHeadSelfAttention() - nparams = sum([param.numel() for param in model.parameters()]) - print_each_rank(f'model params (M): {nparams / 1e6}. Launching model...') - model = model.half().cuda() if args.fp16 else model.cuda() - model.eval() - - dataloader = AttnDataLoader(args.bs) - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - print_each_rank('model weight consumpition:') - memory_summary() - - CudaTimer(enable=False).warmup(2) - torch.distributed.barrier() - iter_num =32 - for step in range(iter_num): - if step >= 8: - CudaTimer(enable=True).start('e2e') - - # train 1 step - # num_microbatch = 1 - with torch.no_grad(): - data = next(dataloader) - out = model(data) - # loss.backward() - - optimizer.step() - optimizer.zero_grad() - - if step >= 8: - CudaTimer().stop('e2e') - - # torch.cuda.empty_cache() - torch.distributed.barrier() - - if step == 0: - print_each_rank('memory after optimizer:', rank_only=0) - memory_summary() - - if (step + 1) % 8 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-8, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-8) - memory_summary() \ No newline at end of file diff --git a/handcraft/bigbird/sparse_attn.py b/handcraft/bigbird/sparse_attn.py deleted file mode 100644 index 7282ac34..00000000 --- a/handcraft/bigbird/sparse_attn.py +++ /dev/null @@ -1,968 +0,0 @@ -""" -BigBird paper -https://papers.nips.cc/paper/2020/file/c8512d142a2d849725f31a9a7a361ab9-Paper.pdf - -Understanding blog: -https://github.com/huggingface/blog/blob/main/big-bird.md - - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/bigbird/sparse_attn.py \ - --hidden-size 4096 --heads 32 --seqlen 4096 \ - --bs 8 --fp16 -""" - -import torch -import torch.nn as nn -import cube -import math -import numpy as np - -import argparse -from cube.runtime.device import DeviceGroup -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - - - -parser = argparse.ArgumentParser(description='sparse_attention') - -parser.add_argument('--hidden-size', type=int, default=4096, - help='hidden size') -parser.add_argument('--heads', type=int, default=32, - help='number of heads') -parser.add_argument('--seqlen', type=int, default=3096, - help='sequence length') -parser.add_argument('--blk-size', type=int, default=64, - help='sequence length') -# training config -parser.add_argument('--bs', type=int, default=256, - help='num of micro batch') -parser.add_argument('--fp16', action='store_true', default=False) -parser.add_argument('--sparse', action='store_true', default=False) -args = parser.parse_args() -print(args) -cube.init() - - -class Config: - - num_attention_heads = args.heads - hidden_size = args.hidden_size - all_head_sie = hidden_size - seqlen = args.seqlen # seqlen - num_random_blocks = 3 - block_size=args.blk_size - use_bias = True - -config = Config() - - -def create_mask(): - batch_size = args.bs - seq_length = config.seqlen - block_size = config.block_size - attention_mask = torch.ones(((batch_size, seq_length)), device=torch.cuda.current_device()) - assert ( - seq_length % block_size == 0 - ), f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block size is {block_size}." - - def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): - """ - Create 3D attention mask from a 2D tensor mask. - Args: - from_blocked_mask: 2D Tensor of shape [batch_size, - from_seq_length//from_block_size, from_block_size]. - to_blocked_mask: int32 Tensor of shape [batch_size, - to_seq_length//to_block_size, to_block_size]. - Returns: - float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, - 3*to_block_size]. - """ - exp_blocked_to_pad = torch.cat( - [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], dim=2 - ) - band_mask = torch.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad) - band_mask.unsqueeze_(1) - return band_mask - - blocked_encoder_mask = attention_mask.view(batch_size, seq_length // block_size, block_size) - band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) - from_mask = attention_mask.view(batch_size, 1, seq_length, 1) - to_mask = attention_mask.view(batch_size, 1, 1, seq_length) - return blocked_encoder_mask, band_mask, from_mask, to_mask - - - -class BigBirdBlockSparseAttention(nn.Module): - def __init__(self, seed=None): - super().__init__() - - self.max_seqlen = config.seqlen - self.seed = seed - - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size {config.hidden_size} is not a multiple of the number of attention " - f"heads {config.num_attention_heads}." - ) - - self.num_attention_heads = config.num_attention_heads - self.num_random_blocks = config.num_random_blocks - self.block_size = config.block_size - - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) - self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) - self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states): - # Currently this `class` can't be used in decoder. - blocked_mask, band_mask, from_mask, to_mask = create_mask() - from_blocked_mask = to_blocked_mask = blocked_mask - - batch_size, seqlen, _ = hidden_states.size() - to_seq_length = from_seq_length = seqlen - from_block_size = to_block_size = self.block_size - - assert from_seq_length % from_block_size == 0, "Query sided sequence length must be multiple of block size" - assert to_seq_length % to_block_size == 0, "Key/Value sided sequence length must be multiple of block size" - - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - context_layer = self.bigbird_block_sparse_attention( - query_layer, - key_layer, - value_layer, - band_mask, - from_mask, - to_mask, - from_blocked_mask, - to_blocked_mask, - self.num_attention_heads, - self.num_random_blocks, - self.attention_head_size, - from_block_size, - to_block_size, - batch_size, - from_seq_length, - to_seq_length, - seed=self.seed, - plan_from_length=None, - plan_num_rand_blocks=None, - ) - - context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) - return context_layer - - @staticmethod - def torch_bmm_nd(inp_1, inp_2, ndim=None): - """Fast nd matrix multiplication""" - # faster replacement of torch.einsum ("bhqk,bhkd->bhqd") - return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view( - inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1]) - ) - - @staticmethod - def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None): - """Fast nd matrix multiplication with transpose""" - # faster replacement of torch.einsum (bhqd,bhkd->bhqk) - return torch.bmm( - inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2) - ).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2])) - - def bigbird_block_sparse_attention( - self, - query_layer, - key_layer, - value_layer, - band_mask, - from_mask, - to_mask, - from_blocked_mask, # same with blocked encoder mask - to_blocked_mask, # same with blocked encoder mask - n_heads, - n_rand_blocks, - attention_head_size, - from_block_size, - to_block_size, - batch_size, - from_seq_len, - to_seq_len, - seed, - plan_from_length, - plan_num_rand_blocks, - ): - - # BigBird block-sparse attention as suggested in paper - - # ITC: - # global tokens: 2 x block_size - # window tokens: 3 x block_size - # random tokens: num_rand_tokens x block_size - - # ETC: - # global tokens: extra_globals_tokens + 2 x block_size - # window tokens: 3 x block_size - # random tokens: num_rand_tokens x block_size - - # Note: - # 1) Currently, ETC is not supported. - # 2) Window size is fixed to 3 blocks & it can be changed only by - # changing `block_size`. - # 3) Number of global blocks are fixed (2 blocks here) & global tokens can be - # controlled only by `block_size`. - - # attention is calculated separately for q[0], q[1], q[2:-2], q[-2], q[-1] in order to use special trick of shifting tokens (for calculating sliding attention) - # hence following code can be divided into 5 parts. - - if from_seq_len // from_block_size != to_seq_len // to_block_size: - raise ValueError("Error the number of blocks needs to be same!") - - rsqrt_d = 1 / math.sqrt(attention_head_size) - bsz = batch_size - attn_mask_penalty = -10000.0 - - # generate random attention and corresponding masks - np.random.seed(seed) - if from_seq_len in [1024, 3072, 4096]: # old plans used in paper - rand_attn = [ - self._bigbird_block_rand_mask( - self.max_seqlen, self.max_seqlen, from_block_size, to_block_size, n_rand_blocks, last_idx=1024 - )[: (from_seq_len // from_block_size - 2)] - for _ in range(n_heads) - ] - else: - if plan_from_length is None: - plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan( - from_seq_len, from_block_size, n_rand_blocks - ) - - rand_attn = self._bigbird_block_rand_mask_with_head( - from_seq_length=from_seq_len, - to_seq_length=to_seq_len, - from_block_size=from_block_size, - to_block_size=to_block_size, - num_heads=n_heads, - plan_from_length=plan_from_length, - plan_num_rand_blocks=plan_num_rand_blocks, - ) - - rand_attn = np.stack(rand_attn, axis=0) - rand_attn = torch.tensor(rand_attn, device=query_layer.device, dtype=torch.long) - rand_attn.unsqueeze_(0) - rand_attn = torch.cat([rand_attn for _ in range(batch_size)], dim=0) - - rand_mask = self._create_rand_mask_from_inputs( - from_blocked_mask, to_blocked_mask, rand_attn, n_heads, n_rand_blocks, bsz, from_seq_len, from_block_size - ) - - blocked_query_matrix = query_layer.view(bsz, n_heads, from_seq_len // from_block_size, from_block_size, -1) - blocked_key_matrix = key_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) - blocked_value_matrix = value_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) - - # preparing block for randn attn - gathered_key = self.torch_gather_b2(blocked_key_matrix, rand_attn) - gathered_key = gathered_key.view( - bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 - ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] - gathered_value = self.torch_gather_b2(blocked_value_matrix, rand_attn) - gathered_value = gathered_value.view( - bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 - ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] - - # 1st PART - # 1st block (global block) attention scores - # q[0] x (k[0], k[1], k[2], k[3], k[4] .... ) - - # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] - first_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 0], key_layer, ndim=4) - - first_product = first_product * rsqrt_d - first_product += (1.0 - to_mask) * attn_mask_penalty - first_attn_weights = nn.functional.softmax( - first_product, dim=-1 - ) # [bsz, n_heads, from_block_size, to_seq_len] - - # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] - first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4) - first_context_layer.unsqueeze_(2) - - # 2nd PART - # 2nd block attention scores - # q[1] x (sliding_keys, random_keys, global_keys) - # sliding key blocks -> 2nd, 3rd blocks - # global key blocks -> 1st block - - second_key_mat = torch.cat( - [ - blocked_key_matrix[:, :, 0], - blocked_key_matrix[:, :, 1], - blocked_key_matrix[:, :, 2], - blocked_key_matrix[:, :, -1], - gathered_key[:, :, 0], - ], - dim=2, - ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] - second_value_mat = torch.cat( - [ - blocked_value_matrix[:, :, 0], - blocked_value_matrix[:, :, 1], - blocked_value_matrix[:, :, 2], - blocked_value_matrix[:, :, -1], - gathered_value[:, :, 0], - ], - dim=2, - ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] - - # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] - second_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 1], second_key_mat, ndim=4) - second_seq_pad = torch.cat( - [ - to_mask[:, :, :, : 3 * to_block_size], - to_mask[:, :, :, -to_block_size:], - to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), - ], - dim=3, - ) - second_rand_pad = torch.cat( - [ - rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), - rand_mask[:, :, 0], - ], - dim=3, - ) - second_product = second_product * rsqrt_d - second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty - second_attn_weights = nn.functional.softmax( - second_product, dim=-1 - ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] - - # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] - second_context_layer = self.torch_bmm_nd(second_attn_weights, second_value_mat, ndim=4) - - second_context_layer.unsqueeze_(2) - - # 3rd PART - # Middle blocks attention scores - # q[-2:2] x (sliding_keys, random_keys, global_keys) - # sliding attn is calculated using special trick of shifting tokens as discussed in paper - # random keys are generated by taking random indices as per `rand_attn` - # global keys -> 1st & last block - - exp_blocked_key_matrix = torch.cat( - [blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1]], dim=3 - ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] - exp_blocked_value_matrix = torch.cat( - [blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1]], - dim=3, - ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] - middle_query_matrix = blocked_query_matrix[:, :, 2:-2] - - # sliding attention scores for q[-2:2] - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [b, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] - inner_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, exp_blocked_key_matrix, ndim=5) - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, 3*to_block_size] - inner_band_product = inner_band_product * rsqrt_d - - # randn attention scores for q[-2:2] - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] - rand_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, gathered_key[:, :, 1:-1], ndim=5) - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] - rand_band_product = rand_band_product * rsqrt_d - - # Including 1st block (since it's global) - first_band_product = torch.einsum( - "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, 0] - ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] - first_band_product = first_band_product * rsqrt_d - - # Including last block (since it's global) - last_band_product = torch.einsum( - "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, -1] - ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] - last_band_product = last_band_product * rsqrt_d - - # masking padded tokens - inner_band_product += (1.0 - band_mask) * attn_mask_penalty - first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * attn_mask_penalty - last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * attn_mask_penalty - rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty - - # completing attention scores matrix for all q[-2:2] - band_product = torch.cat( - [first_band_product, inner_band_product, rand_band_product, last_band_product], dim=-1 - ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] - - # safely doing softmax since attention matrix is completed - attn_weights = nn.functional.softmax( - band_product, dim=-1 - ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] - - # contribution of sliding keys - # [bsz, n_heads, m//from_block_size-4, from_block_size, 3*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] - context_layer = self.torch_bmm_nd( - attn_weights[:, :, :, :, to_block_size : 4 * to_block_size], exp_blocked_value_matrix, ndim=5 - ) - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - - # adding contribution of random keys - # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] - context_layer += self.torch_bmm_nd( - attn_weights[:, :, :, :, 4 * to_block_size : -to_block_size], gathered_value[:, :, 1:-1], ndim=5 - ) - # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - - # adding contribution of global keys - context_layer += torch.einsum( - "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, :to_block_size], blocked_value_matrix[:, :, 0] - ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - context_layer += torch.einsum( - "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, -to_block_size:], blocked_value_matrix[:, :, -1] - ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] - - # 4th PART - # last 2nd token attention scores - # q[-2] x (sliding_keys, random_keys, global_keys) - # sliding key blocks -> last 3 blocks - # global key block -> 1st block - # random key block -> based on indices stored in `randn_attn` - - second_last_key_mat = torch.cat( - [ - blocked_key_matrix[:, :, 0], - blocked_key_matrix[:, :, -3], - blocked_key_matrix[:, :, -2], - blocked_key_matrix[:, :, -1], - gathered_key[:, :, -1], - ], - dim=2, - ) # [bsz, n_heads, (4+n_random_blocks)*to_block_size, -1] - second_last_value_mat = torch.cat( - [ - blocked_value_matrix[:, :, 0], - blocked_value_matrix[:, :, -3], - blocked_value_matrix[:, :, -2], - blocked_value_matrix[:, :, -1], - gathered_value[:, :, -1], - ], - dim=2, - ) # [bsz, n_heads, (4+r)*to_block_size, -1] - - # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] - second_last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -2], second_last_key_mat, ndim=4) - second_last_seq_pad = torch.cat( - [ - to_mask[:, :, :, :to_block_size], - to_mask[:, :, :, -3 * to_block_size :], - to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), - ], - dim=3, - ) - second_last_rand_pad = torch.cat( - [ - rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), - rand_mask[:, :, -1], - ], - dim=3, - ) - second_last_product = second_last_product * rsqrt_d - second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty - second_last_attn_weights = nn.functional.softmax( - second_last_product, dim=-1 - ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] - - # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] - second_last_context_layer = self.torch_bmm_nd(second_last_attn_weights, second_last_value_mat, ndim=4) - second_last_context_layer.unsqueeze_(2) - - # 5th PART - # last block (global) attention scores - # q[-1] x (k[0], k[1], k[2], k[3], .... ) - - # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] - last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) - last_product = last_product * rsqrt_d - last_product += (1.0 - to_mask) * attn_mask_penalty - last_attn_weights = nn.functional.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] - - # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] - last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4) - last_context_layer.unsqueeze_(2) - - # combining representations of all tokens - context_layer = torch.cat( - [first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer], - dim=2, - ) - context_layer = context_layer.view((bsz, n_heads, from_seq_len, -1)) * from_mask - context_layer = torch.transpose(context_layer, 1, 2) - - return context_layer - - @staticmethod - def torch_gather_b2(params, indices): - # this operation is equivalent to tf.gather when batch_dims=2 - - if params.shape[:2] != indices.shape[:2]: - raise ValueError( - f"Make sure that the first two dimensions of params and indices are identical, \ - but they are params: {params.shape[:2]} vs. indices: {params.shape[:2]}" - ) - num_indices_to_gather = indices.shape[-2] * indices.shape[-1] - num_indices_to_pick_from = params.shape[2] - - indices_shift = ( - torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device) - // num_indices_to_gather - * num_indices_to_pick_from - ) - - flattened_indices = indices.view(-1) + indices_shift - flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1]) - - out_flattened = flattened_params.index_select(0, flattened_indices) - - out = out_flattened.reshape(params.shape[:2] + (num_indices_to_gather,) + params.shape[3:]) - return out - - @staticmethod - def _create_rand_mask_from_inputs( - from_blocked_mask, - to_blocked_mask, - rand_attn, - num_attention_heads, - num_rand_blocks, - batch_size, - from_seq_length, - from_block_size, - ): - """ - Create 3D attention mask from a 2D tensor mask. - Args: - from_blocked_mask: 2D Tensor of shape [batch_size, - from_seq_length//from_block_size, from_block_size]. - to_blocked_mask: int32 Tensor of shape [batch_size, - to_seq_length//to_block_size, to_block_size]. - rand_attn: [batch_size, num_attention_heads, - from_seq_length//from_block_size-2, num_rand_blocks] - num_attention_heads: int. Number of attention heads. - num_rand_blocks: int. Number of random chunks per row. - batch_size: int. Batch size for computation. - from_seq_length: int. length of from sequence. - from_block_size: int. size of block in from sequence. - Returns: - float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2, - from_block_size, num_rand_blocks*to_block_size]. - """ - num_windows = from_seq_length // from_block_size - 2 - rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]) - rand_mask = rand_mask.view(batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size) - rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask) - return rand_mask - - @staticmethod - def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks): - """ - Gives the plan of where to put random attention. - Args: - from_seq_length: int. length of from sequence. - from_block_size: int. size of block in from sequence. - num_rand_blocks: int. Number of random chunks per row. - Returns: - plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for - each block - """ - - plan_from_length = [] - plan_num_rand_blocks = [] - if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size): - plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size)) - plan_num_rand_blocks.append(num_rand_blocks) - plan_from_length.append(from_seq_length) - plan_num_rand_blocks.append(0) - elif (num_rand_blocks + 5) < (from_seq_length // from_block_size): - plan_from_length.append(int((num_rand_blocks + 5) * from_block_size)) - plan_num_rand_blocks.append(num_rand_blocks // 2) - plan_from_length.append(from_seq_length) - plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2)) - else: - plan_from_length.append(from_seq_length) - plan_num_rand_blocks.append(num_rand_blocks) - - return plan_from_length, plan_num_rand_blocks - - @staticmethod - def _bigbird_block_rand_mask( - from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1 - ): - """ - Create adjacency list of random attention. - Args: - from_seq_length: int. length of from sequence. - to_seq_length: int. length of to sequence. - from_block_size: int. size of block in from sequence. - to_block_size: int. size of block in to sequence. - num_rand_blocks: int. Number of random chunks per row. - last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence, - if positive then num_rand_blocks blocks chosen only up to last_idx. - Returns: - adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks - """ - # using this method when from_seq_length in [1024, 3072, 4096] - - assert ( - from_seq_length // from_block_size == to_seq_length // to_block_size - ), "Error the number of blocks needs to be same!" - - rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32) - middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32) - last = to_seq_length // to_block_size - 1 - if last_idx > (2 * to_block_size): - last = (last_idx // to_block_size) - 1 - - r = num_rand_blocks # shorthand - for i in range(1, from_seq_length // from_block_size - 1): - start = i - 2 - end = i - if i == 1: - rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r] - elif i == 2: - rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r] - elif i == from_seq_length // from_block_size - 3: - rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] - # Missing -3: should have been sliced till last-3 - elif i == from_seq_length // from_block_size - 2: - rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] - # Missing -4: should have been sliced till last-4 - else: - if start > last: - start = last - rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] - elif (end + 1) == last: - rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] - else: - rand_attn[i - 1, :] = np.random.permutation( - np.concatenate((middle_seq[:start], middle_seq[end + 1 : last])) - )[:r] - return rand_attn - - def _bigbird_block_rand_mask_with_head( - self, - from_seq_length, - to_seq_length, - from_block_size, - to_block_size, - num_heads, - plan_from_length, - plan_num_rand_blocks, - window_block_left=1, - window_block_right=1, - global_block_top=1, - global_block_bottom=1, - global_block_left=1, - global_block_right=1, - ): - """ - Create adjacency list of random attention. - Args: - from_seq_length: int. length of from sequence. - to_seq_length: int. length of to sequence. - from_block_size: int. size of block in from sequence. - to_block_size: int. size of block in to sequence. - num_heads: int. total number of heads. - plan_from_length: list. plan from length where num_random_blocks are chosen from. - plan_num_rand_blocks: list. number of rand blocks within the plan. - window_block_left: int. number of blocks of window to left of a block. - window_block_right: int. number of blocks of window to right of a block. - global_block_top: int. number of blocks at the top. - global_block_bottom: int. number of blocks at the bottom. - global_block_left: int. Number of blocks globally used to the left. - global_block_right: int. Number of blocks globally used to the right. - Returns: - adjacency list of size num_head where each element is of size from_seq_length//from_block_size-2 by - num_rand_blocks - """ - # using this method when from_seq_length not in [1024, 3072, 4096] - - assert ( - from_seq_length // from_block_size == to_seq_length // to_block_size - ), "Error the number of blocks needs to be same!" - - assert from_seq_length in plan_from_length, "Error from sequence length not in plan!" - - # Total number of blocks in the mmask - num_blocks = from_seq_length // from_block_size - # Number of blocks per plan - plan_block_length = np.array(plan_from_length) // from_block_size - # till when to follow plan - max_plan_idx = plan_from_length.index(from_seq_length) - # Random Attention adjacency list - rand_attn = [ - np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32) - for i in range(num_heads) - ] - - # We will go iteratively over the plan blocks and pick random number of - # Attention blocks from the legally allowed blocks - for plan_idx in range(max_plan_idx + 1): - rnd_r_cnt = 0 - if plan_idx > 0: - # set the row for all from_blocks starting from 0 to - # plan_block_length[plan_idx-1] - # column indx start fromm plan_block_length[plan_idx-1] and ends at - # plan_block_length[plan_idx] - if plan_num_rand_blocks[plan_idx] > 0: - rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) - curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) - for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]): - for h in range(num_heads): - rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( - block_id=blk_rw_idx, - to_start_block_id=plan_block_length[plan_idx - 1], - to_end_block_id=plan_block_length[plan_idx], - num_rand_blocks=plan_num_rand_blocks[plan_idx], - window_block_left=window_block_left, - window_block_right=window_block_right, - global_block_left=global_block_left, - global_block_right=global_block_right, - ) - - for pl_id in range(plan_idx): - if plan_num_rand_blocks[pl_id] == 0: - continue - for blk_rw_idx in range(plan_block_length[plan_idx - 1], plan_block_length[plan_idx]): - rnd_r_cnt = 0 - to_start_block_id = 0 - if pl_id > 0: - rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id])) - to_start_block_id = plan_block_length[pl_id - 1] - curr_r_cnt = int(np.sum(plan_num_rand_blocks[: pl_id + 1])) - for h in range(num_heads): - rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( - block_id=blk_rw_idx, - to_start_block_id=to_start_block_id, - to_end_block_id=plan_block_length[pl_id], - num_rand_blocks=plan_num_rand_blocks[pl_id], - window_block_left=window_block_left, - window_block_right=window_block_right, - global_block_left=global_block_left, - global_block_right=global_block_right, - ) - - if plan_num_rand_blocks[plan_idx] == 0: - continue - curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) - from_start_block_id = global_block_top - to_start_block_id = 0 - if plan_idx > 0: - rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) - from_start_block_id = plan_block_length[plan_idx - 1] - to_start_block_id = plan_block_length[plan_idx - 1] - - for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]): - for h in range(num_heads): - rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( - block_id=blk_rw_idx, - to_start_block_id=to_start_block_id, - to_end_block_id=plan_block_length[plan_idx], - num_rand_blocks=plan_num_rand_blocks[plan_idx], - window_block_left=window_block_left, - window_block_right=window_block_right, - global_block_left=global_block_left, - global_block_right=global_block_right, - ) - - for nh in range(num_heads): - rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] - - return rand_attn - - @staticmethod - def _get_single_block_row_attention( - block_id, - to_start_block_id, - to_end_block_id, - num_rand_blocks, - window_block_left=1, - window_block_right=1, - global_block_left=1, - global_block_right=1, - ): - """ - For a single row block get random row attention. - Args: - block_id: int. block id of row. - to_start_block_id: int. random attention column start id. - to_end_block_id: int. random attention column end id. - num_rand_blocks: int. number of random blocks to be selected. - window_block_left: int. number of blocks of window to left of a block. - window_block_right: int. number of blocks of window to right of a block. - global_block_left: int. Number of blocks globally used to the left. - global_block_right: int. Number of blocks globally used to the right. - Returns: - row containing the random attention vector of size num_rand_blocks. - """ - # list of to_blocks from which to choose random attention - to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32) - # permute the blocks - perm_block = np.random.permutation(to_block_list) - - # illegal blocks for the current block id, using window - illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1)) - - # Add blocks at the start and at the end - illegal_blocks.extend(list(range(global_block_left))) - illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id))) - - # The second from_block cannot choose random attention on second last to_block - if block_id == 1: - illegal_blocks.append(to_end_block_id - 2) - - # The second last from_block cannot choose random attention on second to_block - if block_id == to_end_block_id - 2: - illegal_blocks.append(1) - - selected_random_blokcs = [] - - for i in range(to_end_block_id - to_start_block_id): - if perm_block[i] not in illegal_blocks: - selected_random_blokcs.append(perm_block[i]) - if len(selected_random_blokcs) == num_rand_blocks: - break - return np.array(selected_random_blokcs, dtype=np.int32) - - @torch.no_grad() - def forward_dense(self, hidden_state): - N, L = hidden_state.size(0), hidden_state.size(1) - num_head = self.num_attention_heads - dim_head = self.attention_head_size - scale = 1 / math.sqrt(self.attention_head_size) - - # bs, seq, emb -> seq, bs, emb - hidden_state = hidden_state.transpose(0, 1) - q = self.query(hidden_state) - k = self.key(hidden_state) - v = self.value(hidden_state) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # attention mask - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - mask = ones # torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, 0.0, True, False) # (N h) L L -> (N h) L L - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - # output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E - output = output.transpose(0, 1).contiguous() - return output - - - -class BigBirdDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int): - self.bs = batch_size - super().__init__( - shapes=( - [args.bs, config.seqlen, config.hidden_size], - ), - dtypes=(torch.float16 if args.fp16 else torch.float,), - batch_dims=(0,) - ) - self.samples = [self.random_sample()] - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] - - def random_sample(self): - hidden_state = torch.randn( - self.bs, config.seqlen, config.hidden_size, - dtype=torch.float16 if args.fp16 else torch.float, - device=torch.cuda.current_device() - ) - return hidden_state - - - -if __name__ == '__main__': - - model = BigBirdBlockSparseAttention() - nparams = sum([param.numel() for param in model.parameters()]) - print_each_rank(f'model params (M): {nparams / 1e6}. Launching model...') - model = model.half().cuda() if args.fp16 else model.cuda() - model.eval() - - dataloader = BigBirdDataLoader(args.bs) - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - print_each_rank('model weight consumpition:') - memory_summary() - - CudaTimer(enable=False) - torch.distributed.barrier() - iter_num =32 - for step in range(iter_num): - if step >= 8: - CudaTimer(enable=True).start('e2e') - - # train 1 step - # num_microbatch = 1 - with torch.no_grad(): - data = next(dataloader) - if args.sparse: - out = model(data) - else: - out = model.forward_dense(data) - # loss.backward() - - optimizer.step() - optimizer.zero_grad() - - if step >= 8: - CudaTimer().stop('e2e') - - torch.cuda.empty_cache() - torch.distributed.barrier() - - if step == 0: - print_each_rank('memory after optimizer:', rank_only=0) - memory_summary() - - if (step + 1) % 8 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-8, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-8) - memory_summary() \ No newline at end of file diff --git a/handcraft/efficientnet/efficientnet.py b/handcraft/efficientnet/efficientnet.py deleted file mode 100644 index 4c8c1973..00000000 --- a/handcraft/efficientnet/efficientnet.py +++ /dev/null @@ -1,411 +0,0 @@ -"""model.py - Model and module class for EfficientNet. - They are built to mirror those in the official TensorFlow implementation. -""" - -# Author: lukemelas (github username) -# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch -# With adjustments and added comments by workingcoder (github username). - -# https://arxiv.org/pdf/1911.04252.pdf - -import torch -from torch import nn -from torch.nn import functional as F -from .utils import ( - round_filters, - round_repeats, - drop_connect, - get_same_padding_conv2d, - get_model_params, - efficientnet_params, - load_pretrained_weights, - Swish, - MemoryEfficientSwish, - calculate_output_image_size -) - -import torch.utils.checkpoint as checkpoint - - -VALID_MODELS = ( - 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', - 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7', - 'efficientnet-b8', - - # Support the construction of 'efficientnet-l2' without pretrained weights - 'efficientnet-l2' -) - - -class MBConvBlock(nn.Module): - """Mobile Inverted Residual Bottleneck Block. - Args: - block_args (namedtuple): BlockArgs, defined in utils.py. - global_params (namedtuple): GlobalParam, defined in utils.py. - image_size (tuple or list): [image_height, image_width]. - References: - [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) - [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) - [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) - """ - - def __init__(self, block_args, global_params, image_size=None): - super().__init__() - self._block_args = block_args - self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow - self._bn_eps = global_params.batch_norm_epsilon - self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) - self.id_skip = block_args.id_skip # whether to use skip connection and drop connect - - # Expansion phase (Inverted Bottleneck) - inp = self._block_args.input_filters # number of input channels - oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels - if self._block_args.expand_ratio != 1: - Conv2d = get_same_padding_conv2d(image_size=image_size) - self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) - self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) - # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size - - in_image_size = image_size - # Depthwise convolution phase - k = self._block_args.kernel_size - s = self._block_args.stride - Conv2d = get_same_padding_conv2d(image_size=image_size) - self._depthwise_conv = Conv2d( - in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise - kernel_size=k, stride=s, bias=False) - self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) - image_size = calculate_output_image_size(image_size, s) - - # Squeeze and Excitation layer, if desired - if self.has_se: - Conv2d = get_same_padding_conv2d(image_size=(1, 1)) - num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) - self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) - self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) - - # Pointwise convolution phase - final_oup = self._block_args.output_filters - Conv2d = get_same_padding_conv2d(image_size=image_size) - self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) - self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) - self._swish = MemoryEfficientSwish() - - self.in_size = [inp, *in_image_size] - self.out_size = [final_oup, *image_size] - - def forward(self, inputs, drop_connect_rate=None): - """MBConvBlock's forward function. - Args: - inputs (tensor): Input tensor. - drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). - Returns: - Output of this block after processing. - """ - # before_allocated = torch.cuda.max_memory_allocated() - # Expansion and Depthwise Convolution - x = inputs - assert list(x.shape)[1:] == self.in_size - if self._block_args.expand_ratio != 1: - x = self._expand_conv(inputs) - x = self._bn0(x) - x = self._swish(x) - - x = self._depthwise_conv(x) - x = self._bn1(x) - x = self._swish(x) - - # Squeeze and Excitation - if self.has_se: - x_squeezed = F.adaptive_avg_pool2d(x, 1) - x_squeezed = self._se_reduce(x_squeezed) - x_squeezed = self._swish(x_squeezed) - x_squeezed = self._se_expand(x_squeezed) - x = torch.sigmoid(x_squeezed) * x - - # Pointwise Convolution - x = self._project_conv(x) - x = self._bn2(x) - - # Skip connection and drop connect - input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters - if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: - # The combination of skip connection and drop connect brings about stochastic depth. - if drop_connect_rate: - x = drop_connect(x, p=drop_connect_rate, training=self.training) - x = x + inputs # skip connection - - assert list(x.shape)[1:] == self.out_size - # after_allocated = torch.cuda.max_memory_allocated() - # consumption = (after_allocated - before_allocated) / 1024 / 1024 - # print('{} {}'.format(self.layer_id, consumption)) - return x - - def set_swish(self, memory_efficient=True): - """Sets swish function as memory efficient (for training) or standard (for export). - Args: - memory_efficient (bool): Whether to use memory-efficient version of swish. - """ - self._swish = MemoryEfficientSwish() if memory_efficient else Swish() - - -class EfficientNet(nn.Module): - """EfficientNet model. - Most easily loaded with the .from_name or .from_pretrained methods. - Args: - blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. - global_params (namedtuple): A set of GlobalParams shared between blocks. - References: - [1] https://arxiv.org/abs/1905.11946 (EfficientNet) - Example: - >>> import torch - >>> from efficientnet.model import EfficientNet - >>> inputs = torch.rand(1, 3, 224, 224) - >>> model = EfficientNet.from_pretrained('efficientnet-b0') - >>> model.eval() - >>> outputs = model(inputs) - """ - - def __init__(self, blocks_args=None, global_params=None): - super().__init__() - assert isinstance(blocks_args, list), 'blocks_args should be a list' - assert len(blocks_args) > 0, 'block args must be greater than 0' - self._global_params = global_params - self._blocks_args = blocks_args - - # Batch norm parameters - bn_mom = 1 - self._global_params.batch_norm_momentum - bn_eps = self._global_params.batch_norm_epsilon - - # Get stem static or dynamic convolution depending on image size - image_size = global_params.image_size - Conv2d = get_same_padding_conv2d(image_size=image_size) - - # Stem - in_channels = 3 # rgb - out_channels = round_filters(32, self._global_params) # number of output channels - self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) - self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) - image_size = calculate_output_image_size(image_size, 2) - - # Build blocks - self._blocks = nn.ModuleList([]) - for block_args in self._blocks_args: - - # Update block input and output filters based on depth multiplier. - block_args = block_args._replace( - input_filters=round_filters(block_args.input_filters, self._global_params), - output_filters=round_filters(block_args.output_filters, self._global_params), - num_repeat=round_repeats(block_args.num_repeat, self._global_params) - ) - - # The first block needs to take care of stride and filter size increase. - self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) - image_size = calculate_output_image_size(image_size, block_args.stride) - if block_args.num_repeat > 1: # modify block_args to keep same output size - block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) - for _ in range(block_args.num_repeat - 1): - self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) - # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 - - # Head - in_channels = block_args.output_filters # output of final block - out_channels = round_filters(1280, self._global_params) - Conv2d = get_same_padding_conv2d(image_size=image_size) - self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) - self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) - - # Final linear layer - self._avg_pooling = nn.AdaptiveAvgPool2d(1) - if self._global_params.include_top: - self._dropout = nn.Dropout(self._global_params.dropout_rate) - self._fc = nn.Linear(out_channels, self._global_params.num_classes) - - # set activation to memory efficient swish by default - self._swish = MemoryEfficientSwish() - - self.use_checkpoint = [False] * len(self._blocks) - self.preprocess = True - self.postprocess = True - - def set_swish(self, memory_efficient=True): - """Sets swish function as memory efficient (for training) or standard (for export). - Args: - memory_efficient (bool): Whether to use memory-efficient version of swish. - """ - self._swish = MemoryEfficientSwish() if memory_efficient else Swish() - for block in self._blocks: - block.set_swish(memory_efficient) - - def extract_endpoints(self, inputs): - """Use convolution layer to extract features - from reduction levels i in [1, 2, 3, 4, 5]. - Args: - inputs (tensor): Input tensor. - Returns: - Dictionary of last intermediate features - with reduction levels i in [1, 2, 3, 4, 5]. - Example: - >>> import torch - >>> from efficientnet.model import EfficientNet - >>> inputs = torch.rand(1, 3, 224, 224) - >>> model = EfficientNet.from_pretrained('efficientnet-b0') - >>> endpoints = model.extract_endpoints(inputs) - >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) - >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) - >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) - >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) - >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) - >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) - """ - endpoints = dict() - - # Stem - x = self._swish(self._bn0(self._conv_stem(inputs))) - prev_x = x - - # Blocks - for idx, block in enumerate(self._blocks): - drop_connect_rate = self._global_params.drop_connect_rate - if drop_connect_rate: - drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate - x = block(x, drop_connect_rate=drop_connect_rate) - if prev_x.size(2) > x.size(2): - endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x - elif idx == len(self._blocks) - 1: - endpoints['reduction_{}'.format(len(endpoints) + 1)] = x - prev_x = x - - # Head - x = self._swish(self._bn1(self._conv_head(x))) - endpoints['reduction_{}'.format(len(endpoints) + 1)] = x - - return endpoints - - def forward(self, x, feature_map=None): - """EfficientNet's forward function. - Calls extract_features to extract features, applies final linear layer, and returns logits. - Args: - inputs (tensor): Input tensor. - Returns: - Output of this model after processing. - """ - if self.preprocess: - # Stem - x = self._swish(self._bn0(self._conv_stem(x))) - feature_map = x - - # Blocks - for idx, block in enumerate(self._blocks): - drop_connect_rate = self._global_params.drop_connect_rate - if drop_connect_rate: - drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate - else: - drop_connect_rate = None - if self.use_checkpoint[idx]: - feature_map = checkpoint.checkpoint(block, feature_map, drop_connect_rate) - else: - feature_map = block(feature_map, drop_connect_rate=drop_connect_rate) - # feature_map = block(feature_map, drop_connect_rate=drop_connect_rate) - x = feature_map - - if self.postprocess: - # Head - x = self._swish(self._bn1(self._conv_head(x))) - # Pooling and final linear layer - x = self._avg_pooling(x) - if self._global_params.include_top: - x = x.flatten(start_dim=1) - x = self._dropout(x) - x = self._fc(x) - x = torch.sum(x) - return x - - @classmethod - def from_name(cls, model_name, in_channels=3, **override_params): - """Create an efficientnet model according to name. - Args: - model_name (str): Name for efficientnet. - in_channels (int): Input data's channel number. - override_params (other key word params): - Params to override model's global_params. - Optional key: - 'width_coefficient', 'depth_coefficient', - 'image_size', 'dropout_rate', - 'num_classes', 'batch_norm_momentum', - 'batch_norm_epsilon', 'drop_connect_rate', - 'depth_divisor', 'min_depth' - Returns: - An efficientnet model. - """ - cls._check_model_name_is_valid(model_name) - blocks_args, global_params = get_model_params(model_name, override_params) - model = cls(blocks_args, global_params) - model._change_in_channels(in_channels) - return model - - @classmethod - def from_pretrained(cls, model_name, weights_path=None, advprop=False, - in_channels=3, num_classes=1000, **override_params): - """Create an efficientnet model according to name. - Args: - model_name (str): Name for efficientnet. - weights_path (None or str): - str: path to pretrained weights file on the local disk. - None: use pretrained weights downloaded from the Internet. - advprop (bool): - Whether to load pretrained weights - trained with advprop (valid when weights_path is None). - in_channels (int): Input data's channel number. - num_classes (int): - Number of categories for classification. - It controls the output size for final linear layer. - override_params (other key word params): - Params to override model's global_params. - Optional key: - 'width_coefficient', 'depth_coefficient', - 'image_size', 'dropout_rate', - 'batch_norm_momentum', - 'batch_norm_epsilon', 'drop_connect_rate', - 'depth_divisor', 'min_depth' - Returns: - A pretrained efficientnet model. - """ - model = cls.from_name(model_name, num_classes=num_classes, **override_params) - load_pretrained_weights(model, model_name, weights_path=weights_path, - load_fc=(num_classes == 1000), advprop=advprop) - model._change_in_channels(in_channels) - return model - - @classmethod - def get_image_size(cls, model_name): - """Get the input image size for a given efficientnet model. - Args: - model_name (str): Name for efficientnet. - Returns: - Input image size (resolution). - """ - cls._check_model_name_is_valid(model_name) - _, _, res, _ = efficientnet_params(model_name) - return res - - @classmethod - def _check_model_name_is_valid(cls, model_name): - """Validates model name. - Args: - model_name (str): Name for efficientnet. - Returns: - bool: Is a valid name or not. - """ - if model_name not in VALID_MODELS: - raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS)) - - def _change_in_channels(self, in_channels): - """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. - Args: - in_channels (int): Input data's channel number. - """ - if in_channels != 3: - Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) - out_channels = round_filters(32, self._global_params) - self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) \ No newline at end of file diff --git a/handcraft/efficientnet/schedule.py b/handcraft/efficientnet/schedule.py deleted file mode 100644 index 281d2814..00000000 --- a/handcraft/efficientnet/schedule.py +++ /dev/null @@ -1,250 +0,0 @@ -import torch - -from torch.distributed.distributed_c10d import _get_global_rank -from cube.profiler.timer import CudaTimer - - -def get_global_rank(group, group_rank): - if group is None: - return group_rank - else: - return _get_global_rank(group, group_rank) - - -def is_last_stage(group): - return torch.distributed.get_rank(group=group) == torch.distributed.get_world_size(group=group) - 1 - - -#================= WhatToDO functions ==================# - -def forward_step(model, image, trans_input=None): - CudaTimer().start("forward") - output = model(image, trans_input) - CudaTimer().stop("forward") - return output - - -def backward_step(feature_map, output_tensor, output_tensor_grad): - """ - Calculate input tensor gradient - """ - if feature_map is not None and feature_map.requires_grad: - feature_map.retain_grad() - CudaTimer().start("backward") - torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) - CudaTimer().stop("backward") - input_tensor_grad = None - if feature_map is not None and feature_map.requires_grad: - input_tensor_grad = feature_map.grad - return input_tensor_grad - -#================= WhatToDO functions ==================# - -#================= Between Stage functions ==================# - -def send(tensors, to_rank, group): - """ - send tensor to the target rank - """ - if to_rank < 0 or to_rank >= torch.distributed.get_world_size(group): - return None - if group is not None: - to_rank = get_global_rank(group, to_rank) - assert isinstance(tensors, list) or isinstance(tensors, tuple) - CudaTimer().start("send") - reqs = list() - for tensor in tensors: - if tensor is None: - continue - elif torch.is_tensor(tensor): - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, to_rank - ) - reqs.append(send_op) - else: - raise RuntimeError("Expected tensor or None") - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("send") - - -def recv(shapes, from_rank, dtype, group): - if from_rank < 0 or from_rank >= torch.distributed.get_world_size(group): - return [None] * len(shapes) - assert isinstance(shapes, list) or isinstance(shapes, tuple) - if group is not None: - from_rank = get_global_rank(group, from_rank) - # print(f'recv: {torch.distributed.get_rank()} <- {from_rank}: {shapes}') - CudaTimer().start("recv") - reqs = list() - recved_tensors = list() - for shape in shapes: - if shape is None: - recved_tensors.append(None) - continue - tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device(), - dtype=dtype - ) - recved_tensors.append(tensor) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor, from_rank - ) - reqs.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("recv") - return recved_tensors - - -def send_and_recv(send_tensors, recv_shapes, rank, dtype, group): - if rank < 0 or rank >= torch.distributed.get_world_size(group): - return [None] * len(recv_shapes) - if group is not None: - rank = get_global_rank(group, rank) - # print(f'exchange: {torch.distributed.get_rank()} <-> {rank}: {recv_shapes}') - assert isinstance(send_tensors, list) or isinstance(send_tensors, tuple) - assert isinstance(recv_shapes, list) or isinstance(recv_shapes, tuple) - CudaTimer().start("send_recv") - reqs = list() - recved_tensors = list() - for tensor in send_tensors: - if tensor is None: - continue - send_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, rank - ) - reqs.append(send_op) - for shape in recv_shapes: - if shape is None: - recved_tensors.append(None) - continue - recv_tensor = torch.empty( - shape, requires_grad=True, device=torch.cuda.current_device(), - dtype=dtype - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_tensor, rank - ) - recved_tensors.append(recv_tensor) - reqs.append(recv_op) - reqs = torch.distributed.batch_isend_irecv(reqs) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop("send_recv") - return recved_tensors - -#================= Between Stage functions ==================# - -def split_batch(inputs, num_microbatches): - """ - Split a mini-batch to micro-batches - """ - assert isinstance(inputs, list) or isinstance(inputs, tuple) - input_chunks = list() - for feature_map in inputs: - if torch.is_tensor(feature_map): - feature_map = torch.chunk(feature_map, chunks=num_microbatches, dim=0) - else: - feature_map = [feature_map] * num_microbatches - input_chunks.append(feature_map) - micro_batches = list() - for micro_data in zip(*tuple(input_chunks)): - micro_batches.append(micro_data) - return micro_batches - - -#================= Scheduling ==================# - -def scheduling_1f1b(model, inputs, bs, micro_bs, dtype, group): - myrank = torch.distributed.get_rank(group) - - num_microbatches = int(bs / micro_bs) - num_warmup_microbatches = \ - (torch.distributed.get_world_size(group) - - torch.distributed.get_rank(group) - 1) - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_warmup_remaining = num_microbatches - num_warmup_microbatches - - input_tensors = list() - output_tensors = list() - - inputs = split_batch(inputs, num_microbatches) - - # warmup forward pass - for i in range(num_warmup_microbatches): - # recv forward - # print('[warmup] rank {}: step-{}: recving forward...'.format(myrank, i)) - feature_map = recv( - (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype, group - )[0] - image = inputs[i][0] - # forward - output_tensor = forward_step(model, image, feature_map) - # send forward - # print('[warmup] rank {}: step-{}: sending forward...'.format(myrank, i)) - send((output_tensor,), myrank+1, group) - - input_tensors.append(feature_map) - output_tensors.append(output_tensor) - - # before running 1F1B, need to recieve first forward tensor - if num_warmup_remaining > 0: - # recv forward - # print('[1f1b] rank {}: step-{}: recving forward...'.format(myrank, 0)) - feature_map = recv( - (torch.Size([micro_bs] + model.in_size),), myrank-1, dtype, group - )[0] - image = inputs[num_warmup_microbatches][0] - - # run 1F1B - for i in range(num_warmup_remaining): - # forward - output_tensor = forward_step(model, image, feature_map) - # send forward + recv backward grads - # print('[1f1b] rank {}: step-{}: sending forward + recving backward...'.format(myrank, i)) - output_tensor_grad = send_and_recv( - (output_tensor,), - (torch.Size([micro_bs] + model.out_size),), - myrank+1, dtype, group - )[0] - input_tensors.append(feature_map) - output_tensors.append(output_tensor) - # backward - feature_map, output_tensor = input_tensors.pop(0), output_tensors.pop(0) - input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) - if i != (num_warmup_remaining-1): - # send backward grads + recv forward results - # print('[1f1b] rank {}: step-{}: sending backward + recving forward...'.format(myrank, i)) - feature_map = send_and_recv( - (input_tensor_grad,), - (torch.Size([micro_bs] + model.in_size),), - myrank-1, dtype, group - )[0] - image = inputs[num_warmup_microbatches+i+1][0] - else: # last iteration - no more inputs - feature_map = None - # send backward grads - # print('[1f1b] rank {}: step-{}: sending backward...'.format(myrank, i)) - send((input_tensor_grad,), myrank-1, group) - - # cooldown gradient trans back - for i in range(num_warmup_microbatches): - feature_map = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - # recv backward gradients - output_tensor_grad = recv( - (torch.Size([micro_bs] + model.out_size),), myrank+1, dtype, group - )[0] - # backward - input_tensor_grad = backward_step(feature_map, output_tensor, output_tensor_grad) - # send backward gradients - # print('[cooldown] rank {}: step-{}: sending backward...'.format(myrank, i)) - send((input_tensor_grad,), myrank-1, group) - -#================= Scheduling ==================# \ No newline at end of file diff --git a/handcraft/efficientnet/train.py b/handcraft/efficientnet/train.py deleted file mode 100644 index b272033c..00000000 --- a/handcraft/efficientnet/train.py +++ /dev/null @@ -1,338 +0,0 @@ -""" -python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/efficientnet/train.py \ - --pp 8 --gbs 32 --mbs 1 -""" -import torch -from handcraft.efficientnet.efficientnet import EfficientNet -import time -import argparse - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary -from handcraft.efficientnet.schedule import is_last_stage, scheduling_1f1b - - -def model_partition(model, in_size): - resource = cube.runtime.resource.EnvResource() - # pipeline stage - pp_rank = torch.distributed.get_rank(resource.pp_group) - pp_size = torch.distributed.get_world_size(resource.pp_group) - - layers = model._blocks - # for lid, layer in enumerate(layers): - # layer.layer_id = lid - - chunk = len(layers) // pp_size - if len(layers) % pp_size != 0: - remain = len(layers) % pp_size - if pp_rank < remain: - start = pp_rank * (chunk+1) - chunk = chunk + 1 - else: - start = remain * (chunk + 1) + (pp_rank - remain) * chunk - else: - start = pp_rank * chunk - stop = start + chunk - - use_checkpoint = [False] * (stop - start) - use_checkpoint = [True] * (stop - start) - - # 2 stage - # layer_split = [30, 58] - # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - # use_checkpoint = [False] * (stop - start) - # if pp_rank == 0: - # for idx in range(stop - start): - # if idx < 23: - # use_checkpoint[idx] = True - # if pp_rank == 1: - # for idx in range(stop - start): - # if idx < 20: - # use_checkpoint[idx] = True - - # layer_split = [8, 5, 8, 13, 16, 12, 16, 10] - # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - # - # use_checkpoint = [False] * (stop - start) - # if pp_rank == 0: - # for idx in range(stop - start): - # if idx < 3: - # use_checkpoint[idx] = True - # if pp_rank == 1: - # for idx in range(stop - start): - # if idx < 4: - # use_checkpoint[idx] = True - - # if pp_rank == 2: - # for idx in range(stop - start): - # if idx < 4: - # use_checkpoint[idx] = True - # if pp_rank == 3: - # for idx in range(stop - start): - # if idx < 3: - # use_checkpoint[idx] = True - - # 8 gpu naive partition plan - # if pp_rank == 0: - # for idx in range(stop - start): - # if idx < 10: - # use_checkpoint[idx] = True - # if pp_rank == 1: - # for idx in range(stop - start): - # if idx < 8: - # use_checkpoint[idx] = True - # if pp_rank == 2: - # for idx in range(stop - start): - # if idx < 2: - # use_checkpoint[idx] = True - - # 16GPU - # layer_split = [4, 3, 3, 3, 3, 5, 6, 7, 8, 8, 6, 6, 8, 8, 6, 4] - # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - # use_checkpoint = [False] * (stop - start) - # if pp_rank == 1: - # for idx in range(stop - start): - # if idx < 2: - # use_checkpoint[idx] = True - # if pp_rank == 2: - # for idx in range(stop - start): - # if idx < 1: - # use_checkpoint[idx] = True - use_checkpoint = [False] * (stop - start) - if pp_rank == 0: - for idx in range(stop - start): - if idx < 2: - use_checkpoint[idx] = True - if pp_rank == 1: - for idx in range(stop - start): - if idx < 4: - use_checkpoint[idx] = True - if pp_rank == 2: - for idx in range(stop - start): - if idx < 3: - use_checkpoint[idx] = True - - # 8GB memory experiments - # layer_split = [8, 5, 7, 14, 14, 13, 16, 11] - # assert sum(layer_split) == 88, f"split {sum(layer_split)} != 88" - # start = sum(layer_split[0:pp_rank]) - # stop = sum(layer_split[0:pp_rank+1]) - # - # use_checkpoint = [False] * (stop - start) - # if pp_rank == 0: - # for idx in range(stop - start): - # if idx < 8: - # use_checkpoint[idx] = True - # if pp_rank == 1: - # for idx in range(stop - start): - # if idx < 4: - # use_checkpoint[idx] = True - # if pp_rank == 2: - # for idx in range(stop - start): - # if idx < 5: - # use_checkpoint[idx] = True - # if pp_rank == 3: - # for idx in range(stop - start): - # if idx < 8: - # use_checkpoint[idx] = True - # if pp_rank == 4: - # for idx in range(stop - start): - # if idx < 5: - # use_checkpoint[idx] = True - # if pp_rank == 5: - # for idx in range(stop - start): - # if idx < 4: - # use_checkpoint[idx] = True - - print_each_rank(f'layer start -> end: {start} -> {stop}') - layers = layers[start:stop] - model._blocks = layers - model.use_checkpoint = use_checkpoint - - if pp_rank == 0: - model.preprocess = True - model.in_size = in_size - else: - model.preprocess = False - model.in_size = layers[0].in_size - - if is_last_stage(resource.pp_group): - model.postprocess = True - model.out_size = [1,] - else: - model.postprocess = False - model.out_size = layers[-1].out_size - - return model - - -def train(args): - resource = cube.runtime.resource.EnvResource() - - # L2 config - C, H, W = [3, 800, 800] - model = EfficientNet.from_name('efficientnet-l2') - - # B8 config - # C, H, W = [3, 672, 672] - # model = EfficientNet.from_name('efficientnet-b8') - - model = model_partition(model, [C, H, W]) - if args.fp16: - model == model.half() - model = model.cuda() - - nparams_million = sum(p.numel() for p in model.parameters()) / 1000 / 1000 - print_each_rank('model has {:.2f} million parameters'.format(nparams_million)) - memory_summary() - - if args.gbs % args.dp != 0: - raise RuntimeError("global bs is not divisible by DP") - dataloader = cube.runtime.syndata.SynDataLoader( - 1280, [0], [args.gbs // args.dp, C, H, W]) - - if args.fp16: - data_buff = [[e.half() for e in data] for data in dataloader.datas] - dataloader.datas = data_buff - - def train_iter(model, dataloader): - img = next(dataloader) - scheduling_1f1b(model, [img], args.gbs // args.dp, args.mbs, torch.float, resource.pp_group) - CudaTimer().start('dp_allreduce') - resource.reducer.allreduce() - CudaTimer().stop('dp_allreduce') - - optimizer = torch.optim.RMSprop(model.parameters()) - - if args.dp > 1: - print_each_rank('adding param for allreduce sync') - for param in model.parameters(): - resource.reducer.add_param(param) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - span = 0 - iter_num = 20 - for step in range(iter_num): - if step >= 10: - torch.cuda.synchronize() - start = time.time() - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step == 1: - print('> passed on 1st iteration') - memory_summary() - if step >= 10: - torch.cuda.synchronize() - stop = time.time() - span += (stop - start) * 1000 - CudaTimer().stop('e2e') - if (step + 1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - iter_time = CudaTimer().duration(iter_num-10, field_name='e2e') - throughput = args.gbs / iter_time * 1000 - print_each_rank('e2e time {:.2f} ms/iter. Throughput: {:.2f} samples/sec'.format( - iter_time, throughput) - ) - compute_time = CudaTimer().duration(iter_num-10, field_name='forward') + \ - CudaTimer().duration(iter_num-10, field_name='backward') - print_each_rank(f'compute time: {compute_time} ms') - - CudaTimer().print_all(times=iter_num-10) - memory_summary() - - -if __name__ == '__main__': - - cube.init() - - # resource allocation - parser = argparse.ArgumentParser(description='swin') - parser.add_argument('--tp', type=int, default=1, - help='tensor parallel size') - parser.add_argument('--dp', type=int, default=1, - help='data parallel size') - parser.add_argument('--pp', type=int, default=1, - help='pipeline parallel size') - parser.add_argument('--gbs', type=int, default=-1) - parser.add_argument('--mbs', type=int, default=-1) - parser.add_argument('--fp16', action='store_true', dest='fp16') - parser.add_argument('--memory-limit', type=float, default=None, - help='memory fraction limit') - args = parser.parse_args() - - - resource = cube.runtime.resource.EnvResource() - ndevs = resource.ngpus - - tp_size, tp_group_nums = args.tp, ndevs // args.tp - dp_size, dp_group_nums = args.dp, ndevs // args.dp - pp_size, pp_group_nums = args.pp, ndevs // args.pp - - if not pp_size * dp_size * tp_size == ndevs: - raise RuntimeError("Expected all devices are used") - - devs = cube.runtime.device.DeviceGroup() - - myrank = torch.distributed.get_rank() - - # initialize data parallel group - all_data_parallel_group_ranks = list() - for i in range(pp_size): - start_rank = i * pp_group_nums - end_rank = (i + 1) * pp_group_nums - for j in range(tp_size): - ranks = list(range(start_rank + j, end_rank, tp_size)) - all_data_parallel_group_ranks.append(ranks) - # initialize groups - group = devs.get_group(ranks) - if myrank in ranks: - dp_ranks = ranks - resource.dp_group = group - resource.reducer = cube.runtime.reducer.Reducer(ranks) - print_each_rank(f'initialzed data parallel group: {dp_ranks}', rank_only=myrank) - - # initialize pipelne parallel groups - resource.pp_group = -1 - for i in range(dp_size): - ranks = [data_parallel_group_ranks[i] - for data_parallel_group_ranks in all_data_parallel_group_ranks] - group = devs.get_group(ranks) - if myrank in ranks: - pp_ranks = ranks - resource.pp_group = group - print_each_rank(f'initialzed pipeline parallel group: {pp_ranks}', rank_only=myrank) - - # initialize tensor parallel groups - for i in range(tp_group_nums): - ranks = list(range(i * tp_size, (i + 1) * tp_size)) - group = devs.get_group(ranks) - if myrank in ranks: - tp_ranks = ranks - resource.tp_group = group - print_each_rank(f'initialzed tensor parallel group: {tp_ranks}', rank_only=myrank) - - if args.memory_limit is not None: - assert isinstance(args.memory_limit, float) - print_each_rank(f'set memory constraints on {args.memory_limit} fraction.') - torch.cuda.set_per_process_memory_fraction(args.memory_limit) - - train(args) diff --git a/handcraft/efficientnet/utils.py b/handcraft/efficientnet/utils.py deleted file mode 100644 index 4850a9e9..00000000 --- a/handcraft/efficientnet/utils.py +++ /dev/null @@ -1,586 +0,0 @@ -"""utils.py - Helper functions for building the model and for loading model parameters. - These helper functions are built to mirror those in the official TensorFlow implementation. -""" - -# Author: lukemelas (github username) -# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch -# With adjustments and added comments by workingcoder (github username). - -import re -import math -import collections -from functools import partial -import torch -from torch import nn -from torch.nn import functional as F -from torch.utils import model_zoo - - -################################################################################ -# Help functions for model architecture -################################################################################ - -# GlobalParams and BlockArgs: Two namedtuples -# Swish and MemoryEfficientSwish: Two implementations of the method -# round_filters and round_repeats: -# Functions to calculate params for scaling model width and depth ! ! ! -# get_width_and_height_from_size and calculate_output_image_size -# drop_connect: A structural design -# get_same_padding_conv2d: -# Conv2dDynamicSamePadding -# Conv2dStaticSamePadding -# get_same_padding_maxPool2d: -# MaxPool2dDynamicSamePadding -# MaxPool2dStaticSamePadding -# It's an additional function, not used in EfficientNet, -# but can be used in other model (such as EfficientDet). - -# Parameters for the entire model (stem, all blocks, and head) -GlobalParams = collections.namedtuple('GlobalParams', [ - 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', - 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', - 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top']) - -# Parameters for an individual model block -BlockArgs = collections.namedtuple('BlockArgs', [ - 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', - 'input_filters', 'output_filters', 'se_ratio', 'id_skip']) - -# Set GlobalParams and BlockArgs's defaults -GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) -BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) - -# Swish activation function -if hasattr(nn, 'SiLU'): - Swish = nn.SiLU -else: - # For compatibility with old PyTorch versions - class Swish(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -# A memory-efficient implementation of Swish function -class SwishImplementation(torch.autograd.Function): - @staticmethod - def forward(ctx, i): - result = i * torch.sigmoid(i) - ctx.save_for_backward(i) - return result - - @staticmethod - def backward(ctx, grad_output): - i = ctx.saved_tensors[0] - sigmoid_i = torch.sigmoid(i) - return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) - - -class MemoryEfficientSwish(nn.Module): - def forward(self, x): - return SwishImplementation.apply(x) - - -def round_filters(filters, global_params): - """Calculate and round number of filters based on width multiplier. - Use width_coefficient, depth_divisor and min_depth of global_params. - Args: - filters (int): Filters number to be calculated. - global_params (namedtuple): Global params of the model. - Returns: - new_filters: New filters number after calculating. - """ - multiplier = global_params.width_coefficient - if not multiplier: - return filters - # TODO: modify the params names. - # maybe the names (width_divisor,min_width) - # are more suitable than (depth_divisor,min_depth). - divisor = global_params.depth_divisor - min_depth = global_params.min_depth - filters *= multiplier - min_depth = min_depth or divisor # pay attention to this line when using min_depth - # follow the formula transferred from official TensorFlow implementation - new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) - if new_filters < 0.9 * filters: # prevent rounding by more than 10% - new_filters += divisor - return int(new_filters) - - -def round_repeats(repeats, global_params): - """Calculate module's repeat number of a block based on depth multiplier. - Use depth_coefficient of global_params. - Args: - repeats (int): num_repeat to be calculated. - global_params (namedtuple): Global params of the model. - Returns: - new repeat: New repeat number after calculating. - """ - multiplier = global_params.depth_coefficient - if not multiplier: - return repeats - # follow the formula transferred from official TensorFlow implementation - return int(math.ceil(multiplier * repeats)) - - -def drop_connect(inputs, p, training): - """Drop connect. - Args: - input (tensor: BCWH): Input of this structure. - p (float: 0.0~1.0): Probability of drop connection. - training (bool): The running mode. - Returns: - output: Output after drop connection. - """ - assert 0 <= p <= 1, 'p must be in range of [0,1]' - - if not training: - return inputs - - batch_size = inputs.shape[0] - keep_prob = 1 - p - - # generate binary_tensor mask according to probability (p for 0, 1-p for 1) - random_tensor = keep_prob - random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) - binary_tensor = torch.floor(random_tensor) - - output = inputs / keep_prob * binary_tensor - return output - - -def get_width_and_height_from_size(x): - """Obtain height and width from x. - Args: - x (int, tuple or list): Data size. - Returns: - size: A tuple or list (H,W). - """ - if isinstance(x, int): - return x, x - if isinstance(x, list) or isinstance(x, tuple): - return x - else: - raise TypeError() - - -def calculate_output_image_size(input_image_size, stride): - """Calculates the output image size when using Conv2dSamePadding with a stride. - Necessary for static padding. Thanks to mannatsingh for pointing this out. - Args: - input_image_size (int, tuple or list): Size of input image. - stride (int, tuple or list): Conv2d operation's stride. - Returns: - output_image_size: A list [H,W]. - """ - if input_image_size is None: - return None - image_height, image_width = get_width_and_height_from_size(input_image_size) - stride = stride if isinstance(stride, int) else stride[0] - image_height = int(math.ceil(image_height / stride)) - image_width = int(math.ceil(image_width / stride)) - return [image_height, image_width] - - -# Note: -# The following 'SamePadding' functions make output size equal ceil(input size/stride). -# Only when stride equals 1, can the output size be the same as input size. -# Don't be confused by their function names ! ! ! - -def get_same_padding_conv2d(image_size=None): - """Chooses static padding if you have specified an image size, and dynamic padding otherwise. - Static padding is necessary for ONNX exporting of models. - Args: - image_size (int or tuple): Size of the image. - Returns: - Conv2dDynamicSamePadding or Conv2dStaticSamePadding. - """ - if image_size is None: - return Conv2dDynamicSamePadding - else: - return partial(Conv2dStaticSamePadding, image_size=image_size) - - -class Conv2dDynamicSamePadding(nn.Conv2d): - """2D Convolutions like TensorFlow, for a dynamic image size. - The padding is operated in forward function by calculating dynamically. - """ - - # Tips for 'SAME' mode padding. - # Given the following: - # i: width or height - # s: stride - # k: kernel size - # d: dilation - # p: padding - # Output after Conv2d: - # o = floor((i+p-((k-1)*d+1))/s+1) - # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1), - # => p = (i-1)*s+((k-1)*d+1)-i - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): - super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) - self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 - - def forward(self, x): - ih, iw = x.size()[-2:] - kh, kw = self.weight.size()[-2:] - sh, sw = self.stride - oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! - pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) - pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - - -class Conv2dStaticSamePadding(nn.Conv2d): - """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. - The padding mudule is calculated in construction function, then used in forward. - """ - - # With the same calculation as Conv2dDynamicSamePadding - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs): - super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs) - self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 - - # Calculate padding based on image size and save it - assert image_size is not None - ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size - kh, kw = self.weight.size()[-2:] - sh, sw = self.stride - oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) - pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) - pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) - if pad_h > 0 or pad_w > 0: - self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, - pad_h // 2, pad_h - pad_h // 2)) - else: - self.static_padding = nn.Identity() - - def forward(self, x): - x = self.static_padding(x) - x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - return x - - -def get_same_padding_maxPool2d(image_size=None): - """Chooses static padding if you have specified an image size, and dynamic padding otherwise. - Static padding is necessary for ONNX exporting of models. - Args: - image_size (int or tuple): Size of the image. - Returns: - MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. - """ - if image_size is None: - return MaxPool2dDynamicSamePadding - else: - return partial(MaxPool2dStaticSamePadding, image_size=image_size) - - -class MaxPool2dDynamicSamePadding(nn.MaxPool2d): - """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. - The padding is operated in forward function by calculating dynamically. - """ - - def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False): - super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) - self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride - self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size - self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation - - def forward(self, x): - ih, iw = x.size()[-2:] - kh, kw = self.kernel_size - sh, sw = self.stride - oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) - pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) - pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, - self.dilation, self.ceil_mode, self.return_indices) - - -class MaxPool2dStaticSamePadding(nn.MaxPool2d): - """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. - The padding mudule is calculated in construction function, then used in forward. - """ - - def __init__(self, kernel_size, stride, image_size=None, **kwargs): - super().__init__(kernel_size, stride, **kwargs) - self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride - self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size - self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation - - # Calculate padding based on image size and save it - assert image_size is not None - ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size - kh, kw = self.kernel_size - sh, sw = self.stride - oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) - pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) - pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) - if pad_h > 0 or pad_w > 0: - self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) - else: - self.static_padding = nn.Identity() - - def forward(self, x): - x = self.static_padding(x) - x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, - self.dilation, self.ceil_mode, self.return_indices) - return x - - -################################################################################ -# Helper functions for loading model params -################################################################################ - -# BlockDecoder: A Class for encoding and decoding BlockArgs -# efficientnet_params: A function to query compound coefficient -# get_model_params and efficientnet: -# Functions to get BlockArgs and GlobalParams for efficientnet -# url_map and url_map_advprop: Dicts of url_map for pretrained weights -# load_pretrained_weights: A function to load pretrained weights - -class BlockDecoder(object): - """Block Decoder for readability, - straight from the official TensorFlow repository. - """ - - @staticmethod - def _decode_block_string(block_string): - """Get a block through a string notation of arguments. - Args: - block_string (str): A string notation of arguments. - Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. - Returns: - BlockArgs: The namedtuple defined at the top of this file. - """ - assert isinstance(block_string, str) - - ops = block_string.split('_') - options = {} - for op in ops: - splits = re.split(r'(\d.*)', op) - if len(splits) >= 2: - key, value = splits[:2] - options[key] = value - - # Check stride - assert (('s' in options and len(options['s']) == 1) or - (len(options['s']) == 2 and options['s'][0] == options['s'][1])) - - return BlockArgs( - num_repeat=int(options['r']), - kernel_size=int(options['k']), - stride=[int(options['s'][0])], - expand_ratio=int(options['e']), - input_filters=int(options['i']), - output_filters=int(options['o']), - se_ratio=float(options['se']) if 'se' in options else None, - id_skip=('noskip' not in block_string)) - - @staticmethod - def _encode_block_string(block): - """Encode a block to a string. - Args: - block (namedtuple): A BlockArgs type argument. - Returns: - block_string: A String form of BlockArgs. - """ - args = [ - 'r%d' % block.num_repeat, - 'k%d' % block.kernel_size, - 's%d%d' % (block.strides[0], block.strides[1]), - 'e%s' % block.expand_ratio, - 'i%d' % block.input_filters, - 'o%d' % block.output_filters - ] - if 0 < block.se_ratio <= 1: - args.append('se%s' % block.se_ratio) - if block.id_skip is False: - args.append('noskip') - return '_'.join(args) - - @staticmethod - def decode(string_list): - """Decode a list of string notations to specify blocks inside the network. - Args: - string_list (list[str]): A list of strings, each string is a notation of block. - Returns: - blocks_args: A list of BlockArgs namedtuples of block args. - """ - assert isinstance(string_list, list) - blocks_args = [] - for block_string in string_list: - blocks_args.append(BlockDecoder._decode_block_string(block_string)) - return blocks_args - - @staticmethod - def encode(blocks_args): - """Encode a list of BlockArgs to a list of strings. - Args: - blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. - Returns: - block_strings: A list of strings, each string is a notation of block. - """ - block_strings = [] - for block in blocks_args: - block_strings.append(BlockDecoder._encode_block_string(block)) - return block_strings - - -def efficientnet_params(model_name): - """Map EfficientNet model name to parameter coefficients. - Args: - model_name (str): Model name to be queried. - Returns: - params_dict[model_name]: A (width,depth,res,dropout) tuple. - """ - params_dict = { - # Coefficients: width,depth,res,dropout - 'efficientnet-b0': (1.0, 1.0, 224, 0.2), - 'efficientnet-b1': (1.0, 1.1, 240, 0.2), - 'efficientnet-b2': (1.1, 1.2, 260, 0.3), - 'efficientnet-b3': (1.2, 1.4, 300, 0.3), - 'efficientnet-b4': (1.4, 1.8, 380, 0.4), - 'efficientnet-b5': (1.6, 2.2, 456, 0.4), - 'efficientnet-b6': (1.8, 2.6, 528, 0.5), - 'efficientnet-b7': (2.0, 3.1, 600, 0.5), - 'efficientnet-b8': (2.2, 3.6, 672, 0.5), - 'efficientnet-l2': (4.3, 5.3, 800, 0.5), - } - return params_dict[model_name] - - -def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None, - dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True): - """Create BlockArgs and GlobalParams for efficientnet model. - Args: - width_coefficient (float) - depth_coefficient (float) - image_size (int) - dropout_rate (float) - drop_connect_rate (float) - num_classes (int) - Meaning as the name suggests. - Returns: - blocks_args, global_params. - """ - - # Blocks args for the whole model(efficientnet-b0 by default) - # It will be modified in the construction of EfficientNet Class according to model - blocks_args = [ - 'r1_k3_s11_e1_i32_o16_se0.25', - 'r2_k3_s22_e6_i16_o24_se0.25', - 'r2_k5_s22_e6_i24_o40_se0.25', - 'r3_k3_s22_e6_i40_o80_se0.25', - 'r3_k5_s11_e6_i80_o112_se0.25', - 'r4_k5_s22_e6_i112_o192_se0.25', - 'r1_k3_s11_e6_i192_o320_se0.25', - ] - blocks_args = BlockDecoder.decode(blocks_args) - - global_params = GlobalParams( - width_coefficient=width_coefficient, - depth_coefficient=depth_coefficient, - image_size=image_size, - dropout_rate=dropout_rate, - - num_classes=num_classes, - batch_norm_momentum=0.99, - batch_norm_epsilon=1e-3, - drop_connect_rate=drop_connect_rate, - depth_divisor=8, - min_depth=None, - include_top=include_top, - ) - - return blocks_args, global_params - - -def get_model_params(model_name, override_params): - """Get the block args and global params for a given model name. - Args: - model_name (str): Model's name. - override_params (dict): A dict to modify global_params. - Returns: - blocks_args, global_params - """ - if model_name.startswith('efficientnet'): - w, d, s, p = efficientnet_params(model_name) - # note: all models have drop connect rate = 0.2 - blocks_args, global_params = efficientnet( - width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) - else: - raise NotImplementedError('model name is not pre-defined: {}'.format(model_name)) - if override_params: - # ValueError will be raised here if override_params has fields not included in global_params. - global_params = global_params._replace(**override_params) - return blocks_args, global_params - - -# train with Standard methods -# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) -url_map = { - 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', - 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', - 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', - 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', - 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', - 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', - 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', - 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', -} - -# train with Adversarial Examples(AdvProp) -# check more details in paper(Adversarial Examples Improve Image Recognition) -url_map_advprop = { - 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', - 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', - 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', - 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', - 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', - 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', - 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', - 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', - 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', -} - -# TODO: add the petrained weights url map of 'efficientnet-l2' - - -def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True): - """Loads pretrained weights from weights path or download using url. - Args: - model (Module): The whole model of efficientnet. - model_name (str): Model name of efficientnet. - weights_path (None or str): - str: path to pretrained weights file on the local disk. - None: use pretrained weights downloaded from the Internet. - load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. - advprop (bool): Whether to load pretrained weights - trained with advprop (valid when weights_path is None). - """ - if isinstance(weights_path, str): - state_dict = torch.load(weights_path) - else: - # AutoAugment or Advprop (different preprocessing) - url_map_ = url_map_advprop if advprop else url_map - state_dict = model_zoo.load_url(url_map_[model_name]) - - if load_fc: - ret = model.load_state_dict(state_dict, strict=False) - assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) - else: - state_dict.pop('_fc.weight') - state_dict.pop('_fc.bias') - ret = model.load_state_dict(state_dict, strict=False) - assert set(ret.missing_keys) == set( - ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) - assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) - - if verbose: - print('Loaded pretrained weights for {}'.format(model_name)) \ No newline at end of file diff --git a/handcraft/gpt3/test-1gpu.sh b/handcraft/gpt3/test-1gpu.sh deleted file mode 100755 index 3e89c256..00000000 --- a/handcraft/gpt3/test-1gpu.sh +++ /dev/null @@ -1,71 +0,0 @@ -#### -# Single Node Model Scaling Test -#### -evaldir=eval/gpt3-coshard-v100-32gb -mkdir -p ${evaldir} - -bs=4 - -test_naive() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing naive (recompute): ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 # > ${evaldir}/1dev-${arch}-naive.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_coshard() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing coshard: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 # > ${evaldir}/1dev-${arch}-coshard.txt - sleep 5 - killall python - sleep 5 - killall python -} - - -test_naive 48 5120 32 2048 -test_naive 48 5120 32 4096 -test_naive 48 5120 32 8192 -test_naive 48 5120 32 12288 - -# test_naive 24 2048 32 2048 -# test_naive 24 2048 32 4096 -# test_naive 24 2048 32 8192 -# # test_naive 24 2048 32 12288 # --# > OOM -# # test_naive 24 2048 32 16384 # --# > OOM -# -# test_coshard 24 2048 32 2048 -# test_coshard 24 2048 32 4096 -# test_coshard 24 2048 32 8192 -# test_coshard 24 2048 32 12288 -# test_coshard 24 2048 32 16384 -# test_coshard 24 2048 32 20480 -# test_coshard 24 2048 32 24576 diff --git a/handcraft/gpt3/test-1node.sh b/handcraft/gpt3/test-1node.sh deleted file mode 100755 index b68fc266..00000000 --- a/handcraft/gpt3/test-1node.sh +++ /dev/null @@ -1,284 +0,0 @@ -#### -# Single Node Model Scaling Test -#### -evaldir=eval/gpt3-coshard-v100-32gb -mkdir -p ${evaldir} - -bs=256 - -test_pp() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing pipeline 1f1b: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --pp-size ${gpus} --tp-size 1 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-pp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_pp_coshard() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing coshard: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --pp-size ${gpus} --tp-size 1 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-pp-coshard.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_tp() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing tp: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --pp-size ${gpus} --tp-size 1 \ - --seqlen ${seqlen} --bs 16 --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-tp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_hybrid() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - if [ ${gpus} == 4 ] - then - echo "testing hybrid: tp:pp=2:2 : ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --pp-size 2 --dp-size 2 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-dp2pp2.txt - sleep 5 - killall python - sleep 5 - killall python - fi - - if [ ${gpus} == 8 ] - then - # echo "testing hybrid: dp:pp=4:2 : ${arch}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=${gpus} \ - # --nnodes=1 \ - # handcraft/gpt3/train.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --dp-size 4 --pp-size 2 \ - # --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - # --fp16 > ${evaldir}/${gpus}dev-${arch}-dp4pp2.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - - echo "testing hybrid: dp:pp=2:4 : ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size 2 --pp-size 4 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-dp2pp4.txt - sleep 5 - killall python - sleep 5 - killall python - fi -} - - -test_dp() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing dp: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size ${gpus} --pp-size 1 --tp-size 1 \ - --seqlen ${seqlen} --bs 16 --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-dp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_dp_coshard() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing DP coshard: ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size ${gpus} --pp-size 1 --tp-size 1 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 --fp16 \ - --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp-coshard.txt - sleep 5 - killall python - sleep 5 - killall python -} - - -test_hybrid_coshard() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - if [ ${gpus} == 4 ] - then - echo "testing coshard hybrid: dp:pp=2:2 : ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size 2 --pp-size 2 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp2pp2-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - fi - - if [ ${gpus} == 8 ] - then - echo "testing hybrid: dp:pp=4:2 : ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size 4 --pp-size 2 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp4pp2-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - - # echo "testing hybrid: dp:pp=2:4 : ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size 2 --pp-size 4 \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp2pp4-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - fi -} - -# 2.6B -test_dp_coshard 32 2560 32 12288 4 -test_pp 32 2560 32 12288 4 -test_hybrid 32 2560 32 12288 4 - -# 6.7B -test_hybrid 32 4096 32 8192 8 # pp2dp4 OOM, pp4dp2: 26.06GB -test_hybrid_coshard 32 4096 32 8192 8 # pp2dp4: 20.4GB -test_hybrid_coshard 32 4096 32 12288 8 # pp2dp4: OOM, dp2pp4: 25.17GB - - -# =========================== - -# test_pp 24 8192 64 2048 8 # 15.45 GB -# test_pp 24 8192 64 4096 8 # 22.84 GB -# test_pp 24 8192 64 8192 8 # OOM -# test_tp 24 8192 64 8192 8 - -# 2.6B -# test_pp_coshard 32 2560 32 2048 1 # 12.24 GB -# test_pp 32 2560 32 2048 1 # can run -# test_pp 32 2560 32 4096 1 # 15.5GB -# test_pp 32 2560 32 8192 1 # 28.38 GB -# test_dp 32 2560 32 8192 4 # 28.38 GB - - -# 6.7B -# test_dp 32 4096 32 4096 8 # OOM -# test_hybrid 32 4096 32 4096 8 # 18.99GB -# test_hybrid 32 4096 32 8192 8 # pp2dp4 oom, pp4dp2: 26.06GB -# test_dp_coshard 32 4096 32 8192 8 # OOM -# test_hybrid_coshard 32 4096 32 8192 8 # pp2dp4: 20.4GB -# test_hybrid 32 4096 32 12288 8 # all OOM -# test_pp 32 4096 32 12288 8 # OOM -# test_pp_coshard 32 4096 32 12288 8 # 16.73GB -# test_hybrid_coshard 32 4096 32 12288 8 # dp4pp2 OOM, dp2pp4: 25.17GB - -# 15B diff --git a/handcraft/gpt3/test-2node.sh b/handcraft/gpt3/test-2node.sh deleted file mode 100755 index fff7cc8a..00000000 --- a/handcraft/gpt3/test-2node.sh +++ /dev/null @@ -1,115 +0,0 @@ -#### -# 2-Node Model Scaling Test -#### -evaldir=eval/gpt3-coshard-v100-16gb -mkdir -p ${evaldir} - -bs=256 - - -test_hybrid() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - dp=$6 - pp=$7 - tp=$8 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing hybrid: dp:pp:tp=${dp}:${pp}:${tp} : ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size ${dp} --pp-size ${pp} --tp-size ${tp} \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-dp${dp}pp${pp}tp${tp}.txt - sleep 5 - killall python - sleep 5 - killall python -} - - -test_hybrid_coshard() -{ - layers=$1 - hidden=$2 - heads=$3 - seqlen=$4 - gpus=$5 - dp=$6 - pp=$7 - tp=$8 - arch=L${layers}E${hidden}H${heads}-seq${seqlen} - - echo "testing coshard hybrid: dp:pp:tp=${dp}:${pp}:${tp} : ${arch}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/gpt3/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --dp-size ${dp} --pp-size ${pp} --tp-size ${tp} \ - --seqlen ${seqlen} --bs ${bs} --micro-bs 1 \ - --fp16 --use-coshard --coshard-num 8 > ${evaldir}/${gpus}dev-${arch}-dp${dp}pp${pp}tp${tp}-coshard.txt - sleep 5 - killall python - sleep 5 - killall python -} - -# 15B - -# test_hybrid 48 5120 32 2048 16 4 4 1 # dp8pp2 OOM dp4pp4 18.91GB -# test_hybrid_coshard 48 5120 32 2048 16 4 4 1 # dp4pp4 20.93 -# -# test_hybrid 48 5120 32 4096 16 4 4 1 # dp4pp4 15.62 -# test_hybrid_coshard 48 5120 32 4096 16 4 4 1 # dp4pp4 20.93 -# -# test_hybrid 48 5120 32 8192 16 1 8 2 # pp8tp2 # pp2tp2 17.17GB -# test_hybrid_coshard 48 5120 32 8192 16 4 4 1 # dp4pp4 # dp4pp4 26.73GB -# -# test_hybrid 48 5120 32 12288 16 1 4 4 # pp8tp2 OOM pp4tp4 20.29GB -# test_hybrid_coshard 48 5120 32 12288 16 2 8 1 # dp4pp4 OOM dp2pp8 26.88GB - -# 6.7B 251.35 TFLOPS -test_hybrid 32 4096 32 2048 16 4 4 1 # dp8pp2 OOM dp4pp4 9.29GB -test_hybrid_coshard 32 4096 32 2048 16 4 4 1 # dp4pp4 -test_hybrid 32 4096 32 4096 16 4 4 1 # dp8pp2 OOM, dp4pp4 13.05GB -test_hybrid_coshard 32 4096 32 4096 16 4 4 1 # dp4pp4 10.45, dp8pp2 OOM -test_hybrid 32 4096 32 8192 16 1 8 2 # dp4pp4 OOM dp2pp8 OOM pp16 OOM pp8tp2 13.46GB -test_hybrid_coshard 32 4096 32 8192 16 4 4 1 # dp4pp4 14.38 -# test_hybrid 32 4096 32 12288 16 1 1 16 # pp8tp2 OOM pp4tp4 OOM pp2tp8 OOM pp1tp16 OOM -test_hybrid_coshard 32 4096 32 12288 16 1 4 4 # dp2pp8 OOM dp2pp4tp2 13.31GB - - -# =========================== - -# 15B -# test_hybrid 48 5120 32 2048 16 4 4 1 # dp8pp2 OOM dp4pp4 18.91GB - -# test_pp 48 5120 32 4096 16 # 12.42GB -# test_hybrid 48 5120 32 4096 16 2 8 1 # dp2pp8 15.62GB -# test_hybrid 48 5120 32 4096 16 4 4 1 # dp4pp4 15.62 -# test_hybrid 48 5120 32 4096 16 8 2 1 # dp8pp2 OOM -# test_hybrid_coshard 48 5120 32 4096 16 4 4 1 # dp16 OOM dp8pp2 OOM dp4pp4 can run - -# test_hybrid 48 5120 32 8192 16 # pp-dp OOM, pp8tp2: can run -# test_pp 48 5120 32 8192 16 # OOM -# test_hybrid_coshard 48 5120 32 8192 16 4 4 1 # dp4pp4 - -# test_hybrid 48 5120 32 12288 16 1 4 4 # pp8tp2 OOM pp4tp4 20.29GB -# test_hybrid_coshard 48 5120 32 12288 16 2 8 1 # dp4pp4 OOM dp2pp8 26.88GB - -python scripts/keep.py --gpus 8 -killall python \ No newline at end of file diff --git a/handcraft/gpt3/train.py b/handcraft/gpt3/train.py deleted file mode 100644 index 12aedf1c..00000000 --- a/handcraft/gpt3/train.py +++ /dev/null @@ -1,745 +0,0 @@ -""" -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers 24 --hidden-size 2048 --heads 32 \ - --dp-size 1 --tp-size 1 --pp-size 1 \ - --seqlen 8192 --bs 8 --micro-bs 1 --fp16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/gpt3/train.py \ - --layers 32 --hidden-size 4096 --heads 32 \ - --dp-size 1 --tp-size 1 --pp-size 4 \ - --seqlen 1024 --bs 8 --micro-bs 1 --fp16 - -350M: --layers 24 --hidden-size 1024 --heads 16 \ -1.3B: --layers 24 --hidden-size 2048 --heads 32 \ -2.6B: --layers 32 --hidden-size 2560 --heads 32 \ -6.7B: --layers 32 --hidden-size 4096 --heads 32 \ -15 B: --layers 48 --hidden-size 5120 --heads 32 \ -39 B: --layers 48 --hidden-size 8192 --heads 64 \ -""" - -import torch -import torch.utils.checkpoint as checkpoint -import cube -import math -import numpy as np - -from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.reducer import Reducer -from handcraft.module.distnn import AllReduceIdentity, IdentityAllreduce, AllGatherSplit - -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - -from handcraft.module.schedule import schedule_1f1b -from handcraft.module.stage import PipeStage, layer_division - -import argparse - -torch.manual_seed(0) -np.random.seed(0) - -parser = argparse.ArgumentParser(description='gpt3') -# model arch -parser.add_argument('--layers', type=int, default=12, - help='number encoder/decoder of layers') -parser.add_argument('--hidden-size', type=int, default=1024, - help='hidden size') -parser.add_argument('--heads', type=int, default=16, - help='number of heads') -parser.add_argument('--seqlen', type=int, default=1024, - help='sequence length') -# training config -parser.add_argument('--bs', type=int, default=256, - help='num of micro batch') -parser.add_argument('--micro-bs', type=int, default=1, - help='micro batch size') -parser.add_argument('--fp16', action='store_true', default=False) -# parallelism -parser.add_argument('--pp-size', type=int, default=1, - help='pipeline parallelism size') -parser.add_argument('--tp-size', type=int, default=1, - help='tensor parallelism size') -parser.add_argument('--dp-size', type=int, default=1, - help='data parallelism size') -parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b'], - help='scheduling algorithm') -parser.add_argument('--use-coshard', action='store_true', default=False) -parser.add_argument('--coshard-num', type=int, default=4, - help='if use coshard, the coshard number') - -args = parser.parse_args() -print(args) - -_tp_group = -1 - -_dp_group = -1 -_dp_reducer = None - -_pp_group = -1 -_pp_global_ranks = () -_layer_divisions = [] - -_schedule = schedule_1f1b - -_pp_embed_group = -1 -_pp_embed_reducer = None -cube.init() -# print_each_rank('setting memory constraints to 16GB') -# torch.cuda.set_per_process_memory_fraction(0.5) - -dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( - [args.dp_size, args.pp_size, args.tp_size] -) - -if len(dp_ranks) != 1: - print_each_rank(f'initializing dp ranks: {dp_ranks}') - _dp_group = DeviceGroup().get_group(dp_ranks) - _dp_reducer = Reducer(dp_ranks) - -if len(tp_ranks) != 1: - print_each_rank(f'initializing tp ranks: {tp_ranks}') - _tp_group = DeviceGroup().get_group(tp_ranks) - assert args.heads % args.tp_size == 0, "cannot be divided by tp-size" - -if len(pp_ranks) != 1: - print_each_rank(f'initializing pp ranks: {pp_ranks}') - _pp_group = DeviceGroup().get_group(pp_ranks) - _pp_global_ranks = tuple(pp_ranks) - _layer_divisions = layer_division([1] * args.layers, args.pp_size) -else: - _layer_divisions = [(0, args.layers)] -print_each_rank(f'layer divisions: {_layer_divisions}') - -if args.schedule == '1f1b' and args.pp_size > 1: - grid = np.arange(args.dp_size * args.pp_size * args.tp_size).reshape( - (args.dp_size, args.pp_size, args.tp_size)) - for dp_rank in range(args.dp_size): - embed_ranks = np.vstack((grid[dp_rank, 0, :], grid[dp_rank, -1, :])) - grank = torch.distributed.get_rank() - for gid in range(args.tp_size): - embed_rank = embed_ranks[:,gid] - embed_rank = np.squeeze(embed_rank).tolist() - print_each_rank(f'creating embed group: {embed_rank}') - group = DeviceGroup().get_group(embed_rank) - if grank in embed_rank: - print(f'rank [{grank}]: embedding group: {embed_rank}') - _pp_embed_group = group - _pp_embed_reducer = Reducer(embed_rank) - - -class Config: - vocab_size = 50432 - seqlen = args.seqlen - layers = args.layers - heads = args.heads - hidden_size = args.hidden_size - -config = Config() - - -class MLP(torch.nn.Module): - - def __init__(self, hidden_dim: int = None): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) - - hidden_dim = config.hidden_size * 4 if hidden_dim is None else hidden_dim - self.dense_h_to_4h = torch.nn.Linear( - config.hidden_size, hidden_dim // self.tp_size - ) - - self.dense_4h_to_h = torch.nn.Linear( - hidden_dim // self.tp_size, config.hidden_size - ) - - def forward_(self, hidden_states): - if self.tp_size > 1: - hidden_states = IdentityAllreduce.apply(hidden_states, self.tp_group) - x = self.dense_h_to_4h(hidden_states) - x = torch.nn.functional.gelu(x) - x = self.dense_4h_to_h(x) - if self.tp_size > 1: - x = AllReduceIdentity.apply(x, self.tp_group) - return x - - def forward(self, hidden_states, recompute=False): - if recompute: - x = checkpoint.checkpoint(self.forward_, hidden_states) - else: - x = self.forward_(hidden_states) - return x - - def flops(self): - mlp_flops = dict( - fc1=config.seqlen * config.hidden_size * config.hidden_size * 4 // self.tp_size, - gelu=8 * config.seqlen * config.hidden_size * 4 // self.tp_size, - fc2=config.seqlen * (config.hidden_size * 4 // self.tp_size) * config.hidden_size, - ) - return sum(mlp_flops.values()) - - -class SeqMLP(torch.nn.Module): - - def __init__(self): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) - - coshard = args.coshard_num - assert (config.hidden_size * 4) % (self.tp_size * coshard) == 0 - hidden_dim = config.hidden_size * 4 // coshard - self.mlps = torch.nn.ModuleList([MLP(hidden_dim) for _ in range(coshard)]) - for mlp in self.mlps: - mlp.tp_size = 1 - - def forward(self, x, recompute=True): - if self.tp_size > 1: - x = IdentityAllreduce.apply(x, self.tp_group) - - outs = None - for mlp in self.mlps: - x_out = mlp(x, recompute=recompute) - outs = x_out if outs is None else outs + x_out - - if self.tp_size > 1: - outs = AllReduceIdentity.apply(outs, self.tp_group) - return outs - - def flops(self): - return sum([mlp.flops() for mlp in self.mlps]) - - -class Attention(torch.nn.Module): - - def __init__(self, num_heads: int = None): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) - - self.num_heads = (config.heads if num_heads is None else num_heads) // self.tp_size - self.head_dim = config.hidden_size // config.heads - projection_size = self.num_heads * self.head_dim - - self.query_key_value = torch.nn.Linear( - config.hidden_size, - 3 * projection_size, - ) - self.softmax = torch.nn.Softmax(dim=-1) - self.norm_factor = math.sqrt(self.head_dim) - self.dense = torch.nn.Linear( - projection_size, config.hidden_size - ) - - def forward_(self, x, mask): - # to test attention memory consumpiton: enable this - # start_mem = torch.cuda.memory_allocated() - - # x: [seqlen, bs, hidden], np: head num | hn: head dim - if self.tp_size > 1: - x = IdentityAllreduce.apply(x, self.tp_group) - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(x) - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_heads, 3 * self.head_dim) - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - query_layer, key_layer, value_layer = \ - torch.chunk(mixed_x_layer, 3, dim=-1) - - # [b, np, seqlen, seqlen] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) - - # [seqlen, b, np, hn] -> [seqlen, b * np, hn] - query_layer = query_layer.view(output_size[2], - output_size[0] * output_size[1], -1) - - # [seqlen, b, np, hn] -> [seqlen, b * np, hn] - key_layer = key_layer.view(output_size[3], - output_size[0] * output_size[1], -1) - - matmul_result = torch.empty( - output_size[0]*output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device()) - - # Raw attention scores. [b * np, seqlen, seqlen] - matmul_result = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * np, seqlen, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, seqlen] - beta=0.0, alpha=(1.0/self.norm_factor)) - - # change view to [b, np, seqlen, seqlen] - attention_scores = matmul_result.view(*output_size) - - # attention scores and attention mask [b, np, seqlen, seqlen] - if mask is not None: - attention_scores.masked_fill_(mask, -10000.0) - attention_probs = self.softmax(attention_scores) - attention_probs = torch.nn.functional.dropout(attention_probs, 0.0) - - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) - - # change view [seqlen, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), - output_size[0] * output_size[1], -1) - - # change view [b * np, seqlen, seqlen] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - output_size[2], -1) - - # matmul: [b * np, seqlen, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, seqlen, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, seqlen, hn] --> [seqlen, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [seqlen, b, np, hn] --> [seqlen, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.head_dim * self.num_heads,) - context_layer = context_layer.view(*new_context_layer_shape) - - # ================= - # Output. [seqlen, b, h] - # ================= - output = self.dense(context_layer) - if self.tp_size > 1: - output = AllReduceIdentity.apply(output, self.tp_group) - - # to test attention memory consumpiton: enable this - # end_mem = torch.cuda.memory_allocated() - # print(f'mem: attention memory: {(end_mem - start_mem) / 1024 / 1024} MB') - - return output - - def forward(self, x, mask, recompute=False): - if recompute: - x = checkpoint.checkpoint(self.forward_, x, mask) - else: - x = self.forward_(x, mask) - return x - - def flops(self): - seqlen = config.seqlen - attn_flops = dict( - kqv=3 * seqlen * config.hidden_size * self.head_dim * self.num_heads, - kqv_bias=3 * seqlen * self.head_dim * self.num_heads, - q_scale=seqlen * self.num_heads * self.head_dim, # (N h) L d, 1 -> (N h) L d - attn_score=self.num_heads * seqlen * self.head_dim * seqlen, # (N h) L d, (N h) d L -> (N h) L L - attn_softmax=5 * self.num_heads * seqlen * seqlen, # (N h) L L - attn_output=self.num_heads * seqlen * seqlen * self.head_dim, # (N h) L L, (N h) L d -> (N h) L d - out_proj=seqlen * self.num_heads * self.head_dim * config.hidden_size, # L N (h d), E (h d) -> L N E - ) - return sum(attn_flops.values()) - - -class SeqAttention(torch.nn.Module): - - def __init__(self): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - coshard = args.coshard_num - assert config.heads % (coshard * self.tp_size) == 0 - self.shard_num_heads = config.heads // coshard - self.attns = torch.nn.ModuleList( - [Attention(self.shard_num_heads) for _ in range(coshard)] - ) - for attn in self.attns: - attn.tp_size = 1 - - def forward(self, x, mask, recompute=True): - if self.tp_size > 1: - x = IdentityAllreduce.apply(x, self.tp_group) - - outs = None - for attn in self.attns: - x_out = attn(x, mask, recompute) - outs = x_out if outs is None else outs + x_out - - if self.tp_size > 1: - outs = AllReduceIdentity.apply(outs, self.tp_group) - return outs - - def flops(self): - return sum([attn.flops() for attn in self.attns]) - - -class Embedding(torch.nn.Module): - - def __init__(self, num_embeddings: int, embedding_dim: int): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) - self.tp_id = 0 if self.tp_group == -1 else torch.distributed.get_rank(self.tp_group) - - self.vocab_start_index = num_embeddings // self.tp_size * self.tp_id - self.vocab_end_index = num_embeddings // self.tp_size * (self.tp_id + 1) - self.weight = torch.nn.Parameter( - torch.ones((num_embeddings // self.tp_size, embedding_dim), requires_grad=True) - ) - - def forward(self, tokens): - """ - Embedding lookup - if dst is None, use all - """ - if self.tp_size > 1: - mask = (tokens < self.vocab_start_index) | \ - (tokens >= self.vocab_end_index) - tokens = tokens.clone() - self.vocab_start_index - tokens[mask] = 0 - embed = torch.nn.functional.embedding(tokens, self.weight) - embed[mask, :] = 0.0 - embed = AllReduceIdentity.apply(embed, self.tp_group) - else: - embed = torch.nn.functional.embedding(tokens, self.weight) - return embed - - -class TransformerLayer(PipeStage): - - def __init__(self): - super().__init__() - self.input_layernorm = torch.nn.LayerNorm(config.hidden_size) - if args.use_coshard: - # print('use cosharding attention...') - self.self_attention = SeqAttention() - else: - self.self_attention = Attention() - - self.hidden_dropout = 0.0 - self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size) - if args.use_coshard: - # print('use cosharding mlp...') - self.mlp = SeqMLP() - else: - self.mlp = MLP() - - # seqlen, b, h - self.inputs_info = ( - ((config.seqlen, args.micro_bs, config.hidden_size),), - (torch.float16 if args.fp16 else torch.float32,) - ) - self.outputs_info = ( - ((config.seqlen, args.micro_bs, config.hidden_size),), - (torch.float16 if args.fp16 else torch.float32,) - ) - - - def forward(self, hidden_states, attention_mask): - - layernorm_output = self.input_layernorm(hidden_states) - - attention_output = self.self_attention(layernorm_output, attention_mask) - - residual = hidden_states - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout) - layernorm_input = layernorm_input + residual - layernorm_output = self.post_attention_layernorm(layernorm_input) - - mlp_output = self.mlp(layernorm_output) - - residual = layernorm_input - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout) - output = output + residual - return output - - def flops(self): - seqlen = config.seqlen - transformer_flops = dict( - attn_layer_norm=5 * seqlen * config.hidden_size, # (L, N, E) - attn=self.self_attention.flops(), - dropout=seqlen * config.hidden_size, # (L, N, E) - attn_residual=seqlen * config.hidden_size, - fc_layer_norm=5 * seqlen * config.hidden_size, # (L, N, E) - mlp=self.mlp.flops(), - fc_dropout=seqlen * config.hidden_size, - fc_residual=seqlen * config.hidden_size, - ) - return sum(transformer_flops.values()) - - -class Pooler(torch.nn.Module): - - def __init__(self): - super().__init__() - self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size) - - def forward(self, hidden_states, sequence_index=0): - pooled = hidden_states[:, sequence_index, :] - pooled = self.dense(pooled) - pooled = torch.tanh(pooled) - return pooled - - -class GPT3(PipeStage): - - def __init__(self): - super().__init__() - self.set_pipeline(pp_ranks) - self.tp_group = _tp_group - self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) - - inputs_info = None - outputs_info = None - - self.word_embeddings = None - if self.is_first_stage: - print(f'rank [{torch.distributed.get_rank()}]: initializing preprocess...') - self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = torch.nn.Embedding( - config.seqlen, config.hidden_size - ) - self.embedding_dropout = torch.nn.Dropout(0.0) - - inputs_info = ((), ()) if inputs_info is None else inputs_info - - start, end = _layer_divisions[self.stage_local_rank] - print_each_rank(f'initializing layers [{start}, {end})...') - layers = [TransformerLayer() for _ in range(end - start)] - self.layers = torch.nn.ModuleList(layers) - - inputs_info = self.layers[0].inputs_info if inputs_info is None else inputs_info - outputs_info = self.layers[-1].outputs_info - - if self.is_last_stage: - print(f'rank [{torch.distributed.get_rank()}]: initializing postprocess...') - self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) if self.word_embeddings is None else self.word_embeddings - self.final_layernorm = torch.nn.LayerNorm(config.hidden_size) - outputs_info = ((1,), (torch.float32,)) - - assert inputs_info is not None - assert outputs_info is not None - self.inputs_info = inputs_info - self.outputs_info = outputs_info - print_each_rank(f'stage: inputs: {inputs_info} | outputs: {outputs_info}') - - def forward(self, hidden_states = None): - # data - # input_ids, position_ids, atten_mask, loss_mask - - # preprocess - if self.is_first_stage: - input_ids, position_ids, _, _ = self.data - word_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - embeddings = word_embeddings + position_embeddings - embeddings = self.embedding_dropout(embeddings) - hidden_states = embeddings - # [seqlen, bs, hidden] - hidden_states = hidden_states.transpose(0, 1).contiguous() - - - assert hidden_states is not None - _, _, attention_mask, _ = self.data - for layer in self.layers: - if args.use_coshard: - # inner recompute - hidden_states = layer(hidden_states, attention_mask) - else: - # block recompute - hidden_states = checkpoint.checkpoint(layer, hidden_states, attention_mask) - outputs = hidden_states - - # postprocess - if self.is_last_stage: - labels, _, _, loss_mask = self.data - - hidden_states = hidden_states.transpose(0, 1).contiguous() - hidden_states = self.final_layernorm(hidden_states) - - if self.tp_size > 1: - hidden_states = IdentityAllreduce.apply(hidden_states, self.tp_group) - logits = torch.nn.functional.linear(hidden_states, self.word_embeddings.weight) - if self.tp_size > 1: - logits = AllGatherSplit.apply(logits, -1, self.tp_group) - - # minor changes from - # https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/pretrain_gpt.py#L75 - logits = logits.float() - logits = logits.view(args.micro_bs * config.seqlen, -1) - labels = labels.view(-1) - loss = torch.nn.functional.cross_entropy(logits, labels) - outputs = loss - - return outputs - - def flops(self): - flops = 0 - if self.is_first_stage: - # ignore - flops += 0 - # transformer layers - flops += sum([t.flops() for t in self.layers]) - if self.is_last_stage: - # logits - flops += config.seqlen * config.hidden_size * config.vocab_size - return flops - - -class GPT3DataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int): - self.bs = batch_size - super().__init__( - shapes=( - [batch_size, config.seqlen,], - [batch_size, config.seqlen,], - [batch_size, config.seqlen,], - [batch_size, config.seqlen,], - ), - dtypes=( - torch.int64, - torch.int64, - torch.float16 if args.fp16 else torch.float, - torch.float16 if args.fp16 else torch.float, - ), - batch_dims=(0, 0, 0, 0) - ) - self.samples = [self.random_sample()] - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] - - def random_sample(self): - input_ids = torch.randint( - 0, 25000, - size=(self.bs, config.seqlen,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - attention_mask, loss_mask, position_ids = self.get_ltor_masks_and_position_ids(input_ids) - return (input_ids, position_ids, attention_mask, loss_mask) - - def get_ltor_masks_and_position_ids(self, input_ids): - """ - Build masks and position id for left to right model. - https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/utils.py#L81 - """ - # Extract batch size and sequence length. - seq_length = config.seqlen - # Attention mask (lower triangular). - mask_dtype = torch.float16 if args.fp16 else torch.float32 - attention_mask = torch.tril( - torch.ones((args.micro_bs, seq_length, seq_length), dtype=mask_dtype, device=torch.cuda.current_device()) - ).view(args.micro_bs, 1, seq_length, seq_length) - - # Loss mask. - loss_mask = torch.ones(input_ids.size(), device=input_ids.device) - eod_token = 2 - loss_mask[input_ids == eod_token] = 0.0 - - # Position ids. - position_ids = torch.arange(seq_length, dtype=torch.long, - device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - # Convert attention mask to binary: - attention_mask = (attention_mask < 0.5) - return attention_mask, loss_mask, position_ids - - -def get_alpa_tflops(): - batch_size = 1 - seq_len = config.seqlen - hidden_size = config.hidden_size - num_layers = config.layers - vocab_size = config.vocab_size - factor = 96 # if checkpoint_activations else 72 - total_flop = factor * batch_size * seq_len * (hidden_size ** 2) * num_layers * \ - (1 + seq_len / (6 * hidden_size)) \ - + 6 * batch_size * seq_len * hidden_size * vocab_size - # Note: if we use dot to compute forward embedding - # then the last term in total_flops should be - # "+ 10 * batch_size * seq_len * hidden_size * vocab_size". - tflops = total_flop / 1e12 # total_flop / latency / num_gpus / 1e12 - return tflops - - -if __name__ == '__main__': - - print_each_rank(f'alpa calculated TFLOPs: {get_alpa_tflops()}', rank_only=0) - model = GPT3() - nparams = sum([param.numel() for param in model.parameters()]) - print_each_rank(f'model params (M): {nparams / 1e6}. Launching model...') - model = model.half().cuda() if args.fp16 else model.cuda() - - dataloader = GPT3DataLoader(args.micro_bs) - optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.02, lr=3e-05, betas=(0.9, 0.98)) - if _pp_embed_reducer is not None: - _pp_embed_reducer.add_param(model.word_embeddings.weight) - if _dp_reducer is not None: - for param in model.parameters(): - _dp_reducer.add_param(param) - - print_each_rank('model weight consumpition:') - memory_summary() - - CudaTimer(enable=False) - torch.distributed.barrier() - iter_num = 6 - for step in range(iter_num): - if step >= 2: - CudaTimer(enable=True).start('e2e') - - # train 1 step - num_microbatch = args.bs // (args.micro_bs * args.dp_size) - if args.pp_size > 1: - _schedule(model, dataloader, num_microbatch) - else: - for _ in range(num_microbatch): - model.data = next(dataloader) - loss = model() - loss.backward() - - if _pp_embed_reducer is not None: - _pp_embed_reducer.allreduce() - - if _dp_reducer is not None: - _dp_reducer.allreduce() - - optimizer.step() - optimizer.zero_grad() - - if step >= 2: - CudaTimer().stop('e2e') - - torch.cuda.empty_cache() - torch.distributed.barrier() - - if step == 0: - print_each_rank('memory after optimizer:', rank_only=0) - memory_summary() - - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-2, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-2) - memory_summary() diff --git a/handcraft/mbart/swap.py b/handcraft/mbart/swap.py deleted file mode 100644 index e41ef9e9..00000000 --- a/handcraft/mbart/swap.py +++ /dev/null @@ -1,146 +0,0 @@ -from typing import List -import torch - -_param_map = dict() - - -def get_swap_parameters() -> List[torch.nn.Parameter]: - global _param_map - return list(_param_map.values()) - - -class _SwapEmbed(torch.autograd.Function): - - @staticmethod - def forward(ctx, input: torch.Tensor, weight_id: int, fake: torch.nn.Parameter): - # the fake parameter is preventing no grad fn - ctx.save_for_backward(input, fake) - ctx.weight_id = weight_id - - global _param_map - weight = _param_map[weight_id] - ctx.num_embeddings, ctx.embedding_dim = weight.size() - ctx.weight_dtype = weight.dtype - - with torch.no_grad(): - # swap in - weight_gpu = torch.empty( - weight.size(), dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=True - ) - weight_gpu.copy_(weight) - # compute - output = torch.nn.functional.embedding(input, weight_gpu) - # swap out - del weight_gpu - - return output - - @staticmethod - def backward(ctx, grad_output): - # print(f'debug: >> {torch.distributed.get_rank()} embed backward here') - (input, fake) = ctx.saved_tensors - - global _param_map - weight = _param_map[ctx.weight_id] - - # swap in - with torch.no_grad(): - weight_gpu = torch.empty( - weight.size(), dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=True - ) - weight_gpu.copy_(weight) - # compute - with torch.enable_grad(): - output = torch.nn.functional.embedding(input, weight_gpu) - torch.autograd.backward((output,), (grad_output,)) - # swap out - assert weight_gpu.grad is not None - with torch.no_grad(): - weight.grad.copy_(weight_gpu.grad) - del weight_gpu - - fake_grad = torch.zeros_like(fake) - return None, None, fake_grad - - -class SwapEmbed(torch.nn.Module): - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx=None): - super().__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - assert padding_idx >= 0 - self.padding_idx = self.num_embeddings + padding_idx - else: - self.padding_idx = padding_idx - - _weight = torch.nn.Parameter( - torch.empty(num_embeddings, embedding_dim, requires_grad=True, pin_memory=True) - ) - _weight.grad = torch.zeros_like(_weight, requires_grad=False, pin_memory=True) - self.weight_id = id(_weight) - # the fake parameter is preventing no grad fn - self.fake = torch.nn.Parameter(torch.empty((1,), requires_grad=True)) - global _param_map - _param_map[self.weight_id] = _weight - - def forward(self, input): - return _SwapEmbed.apply(input, self.weight_id, self.fake) - - @property - def weight(self): - global _param_map - return _param_map[self.weight_id] - - -if __name__ == '__main__': - - import cube - from cube.profiler.memory import model_summary - cube.init() - - class Model(torch.nn.Module): - - def __init__(self): - super().__init__() - # self.model1 = torch.nn.Embedding(250000, 1024) - self.model1 = SwapEmbed(250000, 1024) - self.model2 = SwapEmbed(250000, 1024) - # self.model2 = torch.nn.Embedding(250000, 1024) - self.model3 = torch.nn.Embedding(250000, 1024) - - def forward(self, input_ids): - out1 = self.model1(input_ids) - # assert out1.grad_fn is not None - out1 = out1 * 10 - # out2 = checkpoint.checkpoint(self.model2, input_ids) - out2 = self.model2(input_ids) - out2 = out2 / 10 - out3 = self.model3(input_ids) - out3 = -out3 - return torch.sum(out1 + out2 + out3) - - model = Model().cuda() - model.train() - - input_ids = torch.randint( - 0, 25000, (128, 1024), - dtype=torch.int, - device=torch.cuda.current_device(), - ) - - model_summary(model, (input_ids,)) - - loss = model(input_ids) - print(loss) - loss.backward() - - print(model.model1.weight.grad) diff --git a/handcraft/mbart/test-2node-fp32.sh b/handcraft/mbart/test-2node-fp32.sh deleted file mode 100755 index 70d386be..00000000 --- a/handcraft/mbart/test-2node-fp32.sh +++ /dev/null @@ -1,193 +0,0 @@ -evaldir=eval/mbart-fp32-v100-32gb -mkdir -p ${evaldir} - -rm -f notify.py -wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py - -bs=256 - -test_mix_tp_1f1b() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - arch=L${layers}E${hidden}H${heads} - - echo "testing ${gpus}-dev mixture-1f1b: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule tp1f1b > ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results MBart Mixture-1f1b | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt" \ - --file ${evaldir}/${gpus}dev-${arch}-tp1f1b.txt -} - -test_tp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - arch=L${layers}E${hidden}H${heads} - - echo "testing ${gpus}-dev pure tp: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 16 --micro-bs 1 \ - --pp-size 1 --tp-size ${gpus} \ - --schedule 1f1b > ${evaldir}/${gpus}dev-${arch}-tp.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results MBart TP | Node Rank ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp.txt" \ - --file ${evaldir}/${gpus}dev-${arch}-tp.txt -} - -test_pp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev pure pp: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_pp_swap() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev pure pp swap: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule 1f1b --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp-swap.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_hybrid_tp_pp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - - if [ ${gpus} == 16 ] - then - echo "testing ${gpus}-dev tp:pp=8:2 | L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size 2 --tp-size 8 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt - sleep 5 - killall python - sleep 5 - killall python - - # echo "testing ${gpus}-dev tp:pp=4:4 | L${layers}E${hidden}H${heads}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=2 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/mbart/train.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --bs ${bs} --micro-bs 1 \ - # --pp-size 4 --tp-size 4 \ - # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp4.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - # - # echo "testing ${gpus}-dev tp:pp=2:8 | L${layers}E${hidden}H${heads}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=2 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/mbart/train.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --bs ${bs} --micro-bs 1 \ - # --pp-size 8 --tp-size 2 \ - # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp2.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - fi -} - - -# ================================================= -# selected experiments -# ================================================= - -# strong scalability test -# test_mix_tp_1f1b 16 3072 24 16 -# test_tp 16 3072 24 16 - -# model scaling test -test_mix_tp_1f1b 36 5120 32 16 -test_tp 36 5120 32 16 -# test_hybrid_tp_pp 36 5120 32 16 # --> OOM - -# test_mix_tp_1f1b 40 5120 40 16 - -python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/test-4node-fp32.sh b/handcraft/mbart/test-4node-fp32.sh deleted file mode 100755 index 83ece579..00000000 --- a/handcraft/mbart/test-4node-fp32.sh +++ /dev/null @@ -1,216 +0,0 @@ -evaldir=eval/mbart-fp32-v100-32gb -mkdir -p ${evaldir} - -bs=256 - -rm -f notify.py -wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py - -test_mix_tp_1f1b() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev mixture-1f1b: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=4 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule tp1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results MBart Mixture-1f1b | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt" \ - --file ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt -} - -test_tp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev pure tp: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=4 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 8 --micro-bs 1 \ - --pp-size 1 --tp-size ${gpus} \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results MBart Pure TP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt" \ - --file ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt -} - -test_pp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev pure pp: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=4 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_pp_swap() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev pure pp swap: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=4 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule 1f1b --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp-swap.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_hybrid_tp_pp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - - if [ ${gpus} == 32 ] - then - echo "testing ${gpus}-dev tp:pp=16:2 | L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=4 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size 2 --tp-size 16 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp16pp2.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results MBart TP16-PP2 | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp16pp2.txt" \ - --file ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp16pp2.txt - - # echo "testing ${gpus}-dev tp:pp=8:4 | L${layers}E${hidden}H${heads}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=4 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/mbart/mbart_hybrid.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --bs ${bs} --micro-bs 1 \ - # --pp-size 4 --tp-size 8 \ - # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp8pp4.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - # - # echo "testing ${gpus}-dev tp:pp=4:8 | L${layers}E${hidden}H${heads}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=4 \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/mbart/mbart_hybrid.py \ - # --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - # --bs ${bs} --micro-bs 1 \ - # --pp-size 8 --tp-size 4 \ - # --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp8.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - fi -} - - -# ================================================= -# selected experiments -# ================================================= - - -test_mix_tp_1f1b 48 6144 32 32 -test_tp 48 6144 32 32 -test_hybrid_tp_pp 48 6144 32 32 - -python scripts/keep.py --gpus 8 - -# OOM: --layers 64 --hidden-size 6144 --heads 32 -# OOM: --layers 52 --hidden-size 6144 --heads 32 -- 29.64GB -# SUC: --layers 48 --hidden-size 6144 --heads 32 -- 29.64GB -# SUC: --layers 48 --hidden-size 5120 --heads 32 - -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=8 \ -# --nnodes=4 \ -# --node_rank=${NODE_RANK} \ -# --master_addr="${MASTER_IP}" \ -# --master_port=${MASTER_PORT} \ -# handcraft/mbart/train.py \ -# --layers 48 --hidden-size 6144 --heads 32 \ -# --bs 32 --micro-bs 1 \ -# --pp-size 32 --tp-size 1 \ -# --schedule tp1f1b -# -# -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=8 \ -# --nnodes=4 \ -# --node_rank=${NODE_RANK} \ -# --master_addr="${MASTER_IP}" \ -# --master_port=${MASTER_PORT} \ -# handcraft/mbart/train.py \ -# --layers 52 --hidden-size 6144 --heads 32 \ -# --bs 4 --micro-bs 1 \ -# --pp-size 1 --tp-size 32 \ -# --schedule 1f1b \ No newline at end of file diff --git a/handcraft/mbart/test-fp32.sh b/handcraft/mbart/test-fp32.sh deleted file mode 100755 index 27f20ed3..00000000 --- a/handcraft/mbart/test-fp32.sh +++ /dev/null @@ -1,153 +0,0 @@ -evaldir=eval/mbart-fp32-v100-32gb -mkdir -p ${evaldir} - -bs=256 - -test_mix_tp_1f1b() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev mixture-1f1b: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule tp1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp1f1b.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_tp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev pure tp: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs 16 --micro-bs 1 \ - --pp-size 1 --tp-size ${gpus} \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_pp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev pure pp: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_pp_swap() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - echo "testing ${gpus}-dev pure pp swap: L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/train.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size ${gpus} --tp-size 1 \ - --schedule 1f1b --use-swap > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-pp-swap.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_hybrid_tp_pp() -{ - layers=$1 - hidden=$2 - heads=$3 - gpus=$4 - - if [ ${gpus} == 4 ] - then - echo "testing ${gpus}-dev tp:pp=2:2 | L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size 2 --tp-size 2 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp2.txt - sleep 5 - killall python - sleep 5 - killall python - fi - - if [ ${gpus} == 8 ] - then - echo "testing ${gpus}-dev tp:pp=4:2 | L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size 2 --tp-size 4 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp4pp2.txt - sleep 5 - killall python - sleep 5 - killall python - - echo "testing ${gpus}-dev tp:pp=2:4 | L${layers}E${hidden}H${heads}" - OMP_NUM_THREADS=4 torchrun --nproc_per_node=${gpus} --nnodes=1 \ - handcraft/mbart/mbart_hybrid.py \ - --layers ${layers} --hidden-size ${hidden} --heads ${heads} \ - --bs ${bs} --micro-bs 1 \ - --pp-size 2 --tp-size 4 \ - --schedule 1f1b > ${evaldir}/${gpus}dev-L${layers}E${hidden}H${heads}-tp2pp4.txt - sleep 5 - killall python - sleep 5 - killall python - fi -} - - -# ================================================= -# selected experiments -# ================================================= -test_tp 8 2048 16 2 -test_mix_tp_1f1b 8 2048 16 2 -test_hybrid_tp_pp 8 2048 16 2 - -test_mix_tp_1f1b 16 3072 24 4 -test_tp 16 3072 24 4 -test_mix_tp_1f1b 16 3072 24 8 -test_tp 16 3072 24 8 -# test_mix_tp_1f1b 16 3072 24 16 -# test_tp 16 3072 24 16 - -test_mix_tp_1f1b 16 3072 24 4 -test_tp 16 3072 24 4 - -test_mix_tp_1f1b 24 4096 32 8 -test_tp 24 4096 32 8 - -python scripts/keep.py --gpus 8 diff --git a/handcraft/mbart/train.py b/handcraft/mbart/train.py deleted file mode 100644 index ea9fddfc..00000000 --- a/handcraft/mbart/train.py +++ /dev/null @@ -1,861 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/mbart/train.py \ - --layers 8 --hidden-size 2048 --heads 16 \ - --bs 1 --micro-bs 1 --schedule 1f1b -""" - -from typing import Optional -import argparse -import math -import numpy as np -import torch -import torch.utils.checkpoint as checkpoint - -import cube -from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.reducer import Reducer -from handcraft.module.distnn import ReduceBroadcast, AllReduceIdentity, IdentityAllreduce - -from cube.profiler import CudaTimer -from cube.profiler.memory import memory_summary -from cube.profiler.timer import print_each_rank - - -from handcraft.module.schedule import schedule_1f1b, schedule_tp1f1b -from handcraft.module.stage import PipeStage -from handcraft.mbart.swap import SwapEmbed, get_swap_parameters - -torch.manual_seed(0) -np.random.seed(0) - - -parser = argparse.ArgumentParser(description='mbart') -# model arch -parser.add_argument('--layers', type=int, default=12, - help='number encoder/decoder of layers') -parser.add_argument('--hidden-size', type=int, default=1024, - help='hidden size') -parser.add_argument('--heads', type=int, default=16, - help='number of heads') -# training config -parser.add_argument('--bs', type=int, default=256, - help='num of micro batch') -parser.add_argument('--micro-bs', type=int, default=1, - help='micro batch size') -parser.add_argument('--fp16', action='store_true', default=False) -# parallelism -parser.add_argument('--pp-size', type=int, default=1, - help='pipeline parallelism size') -parser.add_argument('--tp-size', type=int, default=1, - help='tensor parallelism size') -parser.add_argument('--dp-size', type=int, default=1, - help='data parallelism size') -parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b', 'tp1f1b'], - help='scheduling algorithm') -parser.add_argument('--use-swap', action='store_true', default=False, - help='swap on embedding weight') - -args = parser.parse_args() -print(args) - -_tp_group = -1 - -_dp_group = -1 -_dp_reducer = None - -_pp_group = -1 -_pp_global_ranks = () -_first_encoder_stage = 0 -_first_decoder_stage = 0 -_layer_divisions = [] - -_schedule = schedule_1f1b if args.schedule == '1f1b' else schedule_tp1f1b -if args.schedule == 'tp1f1b': - assert args.tp_size == 1 and args.dp_size == 1, "tp1f1b only supports pure pipeline" - -_pp_embed_group = -1 - -cube.init() -dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( - [args.dp_size, args.pp_size, args.tp_size] -) - -if len(dp_ranks) != 1: - assert False, "DP is not supported yet" - print_each_rank(f'initializing dp ranks: {dp_ranks}') - _dp_group = DeviceGroup().get_group(dp_ranks) - _dp_reducer = Reducer(dp_ranks) - -if len(tp_ranks) != 1: - print_each_rank(f'initializing tp ranks: {tp_ranks}') - _tp_group = DeviceGroup().get_group(tp_ranks) - assert args.heads % args.tp_size == 0, "cannot be divided by tp-size" - -if len(pp_ranks) != 1: - print_each_rank(f'initializing pp ranks: {pp_ranks}') - _pp_group = DeviceGroup().get_group(pp_ranks) - _pp_global_ranks = tuple(pp_ranks) - - # layer division - chunk_num = args.layers // (args.pp_size // 2) - layers = [chunk_num] * (args.pp_size // 2) - for idx in range(args.layers % (args.pp_size // 2)): - layers[-2-idx] += 1 - layer_num_per_dev = layers + layers - start = 0 - layer_scopes, start = [], 0 - for sid in range(args.pp_size): - end = start + layer_num_per_dev[sid] - layer_scopes.append((start, end)) - if start <= args.layers and end > args.layers: - _first_decoder_stage = sid - start = end - _layer_divisions = layer_scopes - assert _first_decoder_stage != _first_encoder_stage, "Not supported yet" -else: - _layer_divisions = [(0, args.layers * 2)] -print_each_rank( - f"layer divisions: {_layer_divisions} | " - f"first encoder stage: {_first_encoder_stage} | " - f"first decoder stage: {_first_decoder_stage}", rank_only=0 -) - - -# create embed group: first encoder, first decoder -if args.schedule == '1f1b' and args.pp_size > 1: - grid = np.arange( - args.pp_size * args.tp_size).reshape((args.pp_size, args.tp_size)) - encoder_preprocess = grid[_first_encoder_stage,:] - decoder_preprocess = grid[_first_decoder_stage,:] - embed_ranks = np.vstack((encoder_preprocess, decoder_preprocess)) - grank = torch.distributed.get_rank() - for gid in range(args.tp_size): - embed_rank = embed_ranks[:,gid] - embed_rank = np.squeeze(embed_rank).tolist() - print_each_rank(f'creating embed group: {embed_rank}') - group = DeviceGroup().get_group(embed_rank) - if grank in embed_rank: - print(f'rank [{grank}]: embedding group: {embed_rank}') - _pp_embed_group = group - - -class Config: - - num_embeddings = 500000 - decoder_layers = args.layers - encoder_layers = args.layers - embed_dim = args.hidden_size - attention_heads = args.heads - - attention_inner_dim = attention_heads * 64 - ffn_dim = 4 * embed_dim - - attention_dropout = 0.0 # for correctness veirfication - activation_dropout = 0.0 # for correctness veirfication - dropout = 0.0 # for correctness veirfication - - max_target_positions = 1024 - max_source_positions = 1024 - - # classification task - pooler_dropout = 0.0 - num_classes = 3 - - -def attn_fn(query: torch.Tensor, key: torch.Tensor, - wq: torch.Tensor, wq_bias: Optional[torch.Tensor], - wk: torch.Tensor, wk_bias: Optional[torch.Tensor], - wv: torch.Tensor, wv_bias: Optional[torch.Tensor], - wout: torch.Tensor, wout_bias: Optional[torch.Tensor], - h: int, scale: float, dropout: float, mask=True): - """ - query, key: (L, N, E) = (seqlen, batch size, embed_dim) - wq, wk, wv weight: [(num_head * dim_head), E] - dropout: float - h: int: number of heads - """ - num_head = h - L, N = query.size(0), query.size(1) - dim_head = wq.size(0) // num_head - - q = torch.nn.functional.linear(query, wq, wq_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(key, wk, wk_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(key, wv, wv_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # attention mask - if mask: # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout, True, False) # (N h) L L -> (N h) L L - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, wout, wout_bias) # L N (h d), E E -> L N E - return output - - -class PositionalEmbedding(torch.nn.Embedding): - - def __init__(self, num_embeddings: int, embedding_dim: int): - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward(self, seq_len: int): - positions = torch.arange( - 0, seq_len, dtype=torch.long, device=torch.cuda.current_device() - ) - return super().forward(positions + self.offset) - - -class MultiheadAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout=0.0, bias=True): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.embed_dim = embed_dim - self.inner_dim = inner_dim - self.head_dim = inner_dim // num_heads - self.num_heads = num_heads // self.tp_size - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # K - self.k_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None - # V - self.v_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None - # Q - self.q_proj = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(self.inner_dim // self.tp_size)) if bias else None - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, self.inner_dim // self.tp_size)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None - - def forward(self, query: torch.Tensor, key: torch.Tensor): - if self.tp_size > 1: - if key is not query: - key = IdentityAllreduce.apply(key, self.tp_group) - query = IdentityAllreduce.apply(query, self.tp_group) - attn = attn_fn(query, key, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p) - if self.tp_size > 1: - attn = AllReduceIdentity.apply(attn, self.tp_group) - return attn - - def flops(self, seqlen: int): - """ - Get forward-pass FLOPs for 1 micro-batch - """ - attn_flops = dict( - kqv=3 * seqlen * self.embed_dim * self.head_dim * self.num_heads, - kqv_bias=3 * seqlen * self.head_dim * self.num_heads, - q_scale=seqlen * self.num_heads * self.head_dim, # (N h) L d, 1 -> (N h) L d - attn_score=self.num_heads * seqlen * self.head_dim * seqlen, # (N h) L d, (N h) d L -> (N h) L L - attn_softmax=5 * self.num_heads * seqlen * seqlen, # (N h) L L - attn_dropout=self.num_heads * seqlen * seqlen, # (N h) L L -> (N h) L L - attn_output=self.num_heads * seqlen * seqlen * self.head_dim, # (N h) L L, (N h) L d -> (N h) L d - out_proj=seqlen * self.num_heads * self.head_dim * self.embed_dim, # L N (h d), E (h d) -> L N E - ) - return sum(attn_flops.values()) - - -class EncoderLayer(PipeStage): - - def __init__(self, cfg: Config): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.cfg = cfg - self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - self.hidden_dim = cfg.ffn_dim // self.tp_size - self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) - self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - self.inputs_info = ( - ((self.cfg.max_source_positions, 1, self.cfg.embed_dim),), - (torch.float32 if not args.fp16 else torch.float16,) - ) - self.outputs_info = ( - ((self.cfg.max_source_positions, 1, self.cfg.embed_dim),), - (torch.float32 if not args.fp16 else torch.float16,) - ) - - def forward(self, x): - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x, x) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - - if self.tp_size > 1: - x = IdentityAllreduce.apply(x, self.tp_group) - x = self.fc1(x) - x = torch.nn.functional.gelu(x) - x = self.activation_dropout(x) - x = self.fc2(x) - if self.tp_size > 1: - x = AllReduceIdentity.apply(x, self.tp_group) - - x = self.dropout(x) - - x = x + residual - return x - - def flops(self): - seqlen = self.cfg.max_source_positions - enc_flops = dict( - attn_layer_norm=5 * seqlen * self.cfg.embed_dim, # (L, N, E) - attn=self.self_attn.flops(seqlen), - dropout=seqlen * self.cfg.embed_dim, # (L, N, E) - attn_residual=seqlen * self.cfg.embed_dim, - fc_layer_norm=5 * seqlen * self.cfg.embed_dim, # (L, N, E) - fc1=seqlen * self.cfg.embed_dim * self.hidden_dim, # L N E, E hidden -> L N hidden - gelu=8 * seqlen * self.hidden_dim, - fc_inner_dropout=seqlen * self.hidden_dim, - fc2=seqlen * self.cfg.embed_dim * self.hidden_dim, # L N hidden, hidden E -> L N E - fc_dropout=seqlen * self.cfg.embed_dim, - fc_residual=seqlen * self.cfg.embed_dim, - ) - return sum(enc_flops.values()) - - -class DecoderLayer(PipeStage): - - def __init__(self, cfg: Config): - super().__init__() - self.tp_group = _tp_group - self.tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.cfg = cfg - self.dropout = torch.nn.Dropout(p=cfg.dropout) - self.self_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.activation_dropout = torch.nn.Dropout(p=cfg.activation_dropout) - - self.self_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - # encoder atten - self.encoder_attn = MultiheadAttention(cfg.embed_dim, cfg.attention_heads, cfg.attention_inner_dim, cfg.attention_dropout) - self.encoder_attn_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - self.hidden_dim = cfg.ffn_dim // self.tp_size - self.fc1 = torch.nn.Linear(cfg.embed_dim, cfg.ffn_dim // self.tp_size) - self.fc2 = torch.nn.Linear(cfg.ffn_dim // self.tp_size, cfg.embed_dim) - self.final_layer_norm = torch.nn.LayerNorm(cfg.embed_dim) - - self.inputs_info = ( - ((self.cfg.max_target_positions, 1, self.cfg.embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.embed_dim)), - (torch.float32 if not args.fp16 else torch.float16, - torch.float32 if not args.fp16 else torch.float16,) - ) - self.outputs_info = ( - ((self.cfg.max_target_positions, 1, self.cfg.embed_dim), - (self.cfg.max_source_positions, 1, self.cfg.embed_dim)), - (torch.float32 if not args.fp16 else torch.float16, - torch.float32 if not args.fp16 else torch.float16,) - ) - - def forward(self, x, encoder_out): - # print(f'decoder layer: x: {x.size()}, encoder_out: {encoder_out.size()}') - residual = x - x = self.self_attn_layer_norm(x) - - # self attention - x = self.self_attn(x, x) - x = self.dropout(x) - x = residual + x - - # encoder attn - residual = x - x = self.encoder_attn_layer_norm(x) - x = self.encoder_attn(x, encoder_out) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - - if self.tp_size > 1: - x = IdentityAllreduce.apply(x, self.tp_group) - x = self.fc1(x) - x = torch.nn.functional.gelu(x) - x = self.activation_dropout(x) - x = self.fc2(x) - if self.tp_size > 1: - x = AllReduceIdentity.apply(x, self.tp_group) - - x = self.dropout(x) - x = x + residual - return x, encoder_out - - def flops(self): - seqlen = self.cfg.max_target_positions - dec_flops = dict( - attn_layer_norm=0, # ignore - attn=self.self_attn.flops(seqlen) * 2, # self attention + cross attention - dropout=seqlen * self.cfg.embed_dim, # (L, N, E) - attn_residual=seqlen * self.cfg.embed_dim, - fc_layer_norm=0, # ignore - fc1=seqlen * self.cfg.embed_dim * self.hidden_dim, # L N E, E hidden -> L N hidden - gelu=seqlen * self.hidden_dim, - fc_inner_dropout=seqlen * self.hidden_dim, - fc2=seqlen * self.cfg.embed_dim * self.hidden_dim, # L N hidden, hidden E -> L N E - fc_dropout=seqlen * self.cfg.embed_dim, - fc_residual=seqlen * self.cfg.embed_dim, - ) - return sum(dec_flops.values()) - - -class MBartClassificationHead(torch.nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - self.num_classes = num_classes - self.dense = torch.nn.Linear(input_dim, inner_dim) - self.dropout = torch.nn.Dropout(p=pooler_dropout) - self.out_proj = torch.nn.Linear(inner_dim, num_classes) - self.loss_fct = torch.nn.CrossEntropyLoss() - - def forward(self, dec: torch.Tensor, labels): - # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] - dec = dec.transpose(0, 1)[:,-1,:] - sentence_represent = dec - hidden_states = self.dropout(sentence_represent) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - logits = self.out_proj(hidden_states) - loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) - return loss - - def flops(self): - return 0 # ignore - - -class ShardEmbed(torch.nn.Module): - - def __init__(self, cfg: Config): - super().__init__() - self.tp_group = None if args.schedule == 'tp1f1b' else _tp_group - self.tp_size = 1 if self.tp_group == -1 else torch.distributed.get_world_size(self.tp_group) - self.tp_id = 0 if self.tp_group == -1 else torch.distributed.get_rank(self.tp_group) - - self.cfg = cfg - print(f'initialize sharding embed (x{self.tp_size})') - - self.swap = args.use_swap - if self.swap: - assert args.schedule == '1f1b', "only 1f1b can use swap" - self.embed = SwapEmbed(self.cfg.num_embeddings, self.cfg.embed_dim) - else: - self.vocab_start_index = self.cfg.num_embeddings // self.tp_size * self.tp_id - self.vocab_end_index = self.cfg.num_embeddings // self.tp_size * (self.tp_id + 1) - self.weight = torch.nn.Parameter( - torch.ones((self.cfg.num_embeddings // self.tp_size, self.cfg.embed_dim)) - ) - - # encoder-preprocess - self.embed_positions_encoder = PositionalEmbedding(cfg.max_source_positions, cfg.embed_dim) - self.embed_scale_encoder = math.sqrt(cfg.embed_dim) - self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.embed_dim) - - # decoder-preprocess - self.embed_scale_decoder = math.sqrt(cfg.embed_dim) - self.embed_positions_decoder = PositionalEmbedding(cfg.max_target_positions, cfg.embed_dim) - self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.embed_dim) - - self.init_weight() - - def embed_lookup(self, tokens, dst: Optional[int] = None): - """ - Embedding lookup - if dst is None, use all - """ - if self.swap: - embed = self.embed(tokens) - elif self.tp_size > 1: - mask = (tokens < self.vocab_start_index) | \ - (tokens >= self.vocab_end_index) - tokens = tokens.clone() - self.vocab_start_index - tokens[mask] = 0 - embed = torch.nn.functional.embedding(tokens, self.weight) - embed[mask, :] = 0.0 - if dst is None: - assert _tp_group != -1 - embed = AllReduceIdentity.apply(embed, self.tp_group) - else: - assert self.tp_group is None # args.sharding = True - embed = ReduceBroadcast.apply(embed, dst, None) - else: - embed = torch.nn.functional.embedding(tokens, self.weight) - return embed - - def forward(self, tokens, encoder=False, decoder=False, dst: Optional[int] = None): - """ - If dst is not None: the embedding is sharded across all devices - using tp1f1b, and hence requires a Reduce on the target rank. - """ - assert encoder ^ decoder, "can only be either encoder or decoder" - embed = self.embed_lookup(tokens, dst) - x = embed + self.embed_positions_encoder(embed.size(1)) - if encoder: - x = self.layernorm_embedding_encoder(x) - if decoder: - x = self.layernorm_embedding_decoder(x) - x = torch.nn.functional.dropout(x, p=0.0) - x = x.transpose(0, 1) - return x - - def flops(self): - # ignore - return 0 - - def init_weight(self): - for param in self.parameters(): - torch.nn.init.constant_(param, 0.1) - - -class MBart(PipeStage): - - def __init__(self, cfg: Config): - super().__init__() - self.set_pipeline(_pp_global_ranks) - self.first_encoder_stage = _first_encoder_stage - self.first_decoder_stage = _first_decoder_stage - - self.cfg = cfg - - start, end = _layer_divisions[self.stage_local_rank] - print_each_rank(f'initializing layer ranging from [{start}, {end})') - - self.encoder_preprocess = self.is_first_stage - self.encoder_forward = start < cfg.encoder_layers - self.decoder_preprocess = start <= cfg.encoder_layers and end > cfg.encoder_layers - self.decoder_forward = end > cfg.encoder_layers - self.sharding = args.schedule == 'tp1f1b' - print_each_rank( - f"encoder: (pre: {self.encoder_preprocess}) {self.encoder_forward} | " - f"decoder (pre: {self.decoder_preprocess}) {self.decoder_forward} | " - f"post-process: {self.is_last_stage} | sharding {self.sharding}" - ) - - inputs_info = None - outputs_info = None - - self.embed: ShardEmbed = None - if self.encoder_preprocess: - self.embed = ShardEmbed(cfg) if self.embed is None else self.embed - inputs_info = ((), ()) if inputs_info is None else inputs_info - - self.encoders = [] - self.layer_norm_encoder = None - if self.encoder_forward: - encoders = [EncoderLayer(cfg) for _ in range(min(end, cfg.encoder_layers) - start)] - self.encoders = torch.nn.ModuleList(encoders) - if self.decoder_preprocess or self.decoder_forward: - self.layer_norm_encoder = torch.nn.LayerNorm(cfg.embed_dim) - - inputs_info = self.encoders[0].inputs_info if inputs_info is None else inputs_info - outputs_info = self.encoders[-1].outputs_info - - if self.decoder_preprocess: - _encoder = EncoderLayer(cfg) - self.embed = ShardEmbed(cfg) if self.embed is None else self.embed - inputs_info = _encoder.outputs_info if inputs_info is None else inputs_info - - self.decoders = [] - self.layer_norm_decoder = None - if self.decoder_forward: - decoders = [DecoderLayer(cfg) for _ in range(end - max(cfg.encoder_layers, start))] - self.decoders = torch.nn.ModuleList(decoders) - if self.is_last_stage: - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.embed_dim) - - inputs_info = self.decoders[0].inputs_info if inputs_info is None else inputs_info - outputs_info = self.decoders[-1].outputs_info - - if self.is_last_stage: - self.head = MBartClassificationHead(cfg.embed_dim, 1024, cfg.num_classes, 0.0) - outputs_info = ((1,), torch.float32 if not args.fp16 else torch.float16) - - if self.sharding: - self.embed = ShardEmbed(cfg) if self.embed is None else self.embed - - assert inputs_info is not None - assert outputs_info is not None - self.inputs_info = inputs_info - self.outputs_info = outputs_info - print_each_rank(f'stage: inputs: {inputs_info} | outputs: {outputs_info}') - self.init_weight() - - def init_weight(self): - for param in self.parameters(): - torch.nn.init.constant_(param, 0.01) - - def forward_encoder_shard(self): - """ - Return detached outputs with enabled gradient - """ - source_tokens, _, _ = self.data - enc = self.embed(source_tokens, encoder=True, dst=self.first_encoder_stage) - model.push(enc, 'encoder_sharding_output') - if self.stage_global_rank == self.first_encoder_stage: - enc = enc.detach().requires_grad_() - self.push(enc, 'encoder_preprocess') - return enc - - def forward_decoder_shard(self): - """ - Return detached outputs with enabled gradient - """ - _, prev_tokens, _ = self.data - dec = self.embed(prev_tokens, decoder=True, dst=self.first_decoder_stage) - model.push(dec, 'decoder_sharding_output') - if self.stage_global_rank == self.first_decoder_stage: - dec = dec.detach().requires_grad_() - self.push(dec, 'decoder_preprocess') - return dec - - def forward(self, enc=None, dec=None, recompute=False): - """ - enc: encoder input - dec: decoder input - recompute: outside control for tp1f1b - """ - if self.encoder_preprocess: - if self.sharding: - if recompute: - enc = self.get_last('encoder_preprocess') - else: - enc = self.pop('encoder_preprocess') - else: - source_tokens, _, _ = self.data - enc = self.embed(source_tokens, encoder=True) - - if self.encoder_forward: - if args.pp_size == 1: - def encoder_forward(enc): - for layer in self.encoders: - enc = layer(enc) - return enc - enc = checkpoint.checkpoint(encoder_forward, enc) - else: - for layer in self.encoders: - enc = layer(enc) - if self.layer_norm_encoder is not None: - enc = self.layer_norm_encoder(enc) - output = enc - - if self.decoder_preprocess: - if self.sharding: - if recompute: - dec = self.get_last('decoder_preprocess') - else: - dec = self.pop('decoder_preprocess') - else: - _, prev_tokens, _ = self.data - dec = self.embed(prev_tokens, decoder=True) - - if self.decoder_forward: - if args.pp_size == 1: - def decoder_forward(enc, dec): - assert enc is not None - for layer in self.decoders: - dec, enc = layer(dec, enc) - return enc, dec - enc, dec = checkpoint.checkpoint(decoder_forward, enc, dec) - else: - assert enc is not None - for layer in self.decoders: - dec, enc = layer(dec, enc) - if self.layer_norm_decoder is not None: - dec = self.layer_norm_decoder(dec) - output = (enc, dec) - - if self.is_last_stage: - _, _, label = self.data - loss = self.head(dec, label) - output = loss - - return output - - def flops(self): - enc_flops = sum([enc.flops() for enc in self.encoders]) - enc_layernorm = 5 * self.cfg.max_source_positions * self.cfg.embed_dim if self.layer_norm_decoder is None else 0 - dec_flops = sum([dec.flops() for dec in self.decoders]) - dec_layernorm = 5 * self.cfg.max_target_positions * self.cfg.embed_dim if self.layer_norm_decoder is None else 0 - return enc_flops + enc_layernorm + dec_flops + dec_layernorm - - -def reduce_embed(model: MBart, pp_embed_group): - """ - Embedding gradients needs to be reduced across pipeline stages - """ - if pp_embed_group == -1: - return - if isinstance(model.embed, torch.nn.Module): - if model.embed.swap: - with torch.no_grad(): - grad = model.embed.weight.grad - grad = grad.cuda() - else: - grad = model.embed.weight.grad - else: - grad = None - if grad is not None: - CudaTimer().start('comm') - torch.distributed.all_reduce(grad, group=pp_embed_group) - torch.cuda.synchronize() - CudaTimer().stop('comm') - if isinstance(model.embed, torch.nn.Module): - if model.embed.swap: - with torch.no_grad(): - model.embed.embed.weight.grad.copy_(grad) - torch.cuda.synchronize() - - -class MBartDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int, cfg: Config): - self.bs = batch_size - self.cfg = cfg - super().__init__( - shapes=( - [batch_size, cfg.max_source_positions,], - [batch_size, cfg.max_target_positions,], - [batch_size,] - ), - dtypes=( - torch.int64, - torch.int64, - torch.int, - ), - batch_dims=(0, 0, 0) - ) - self.samples = [self.random_sample()] - - def random_sample(self): - source_token = torch.randint( - 0, 25000, - size=(self.bs, cfg.max_source_positions,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - target_token = torch.randint( - 0, 25000, - size=(self.bs, cfg.max_target_positions,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - labels = torch.randint( - 0, self.cfg.num_classes, - size=(self.bs,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - return (source_token, target_token, labels) - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] - - -if __name__ == '__main__': - - - cfg = Config() - print_each_rank(f'enc/dec layer#: {cfg.encoder_layers}, embed_dim#: {cfg.embed_dim}, heads#: {cfg.attention_heads}, ffn_dim#: {cfg.ffn_dim}', rank_only=0) - - model = MBart(cfg) - nparams = sum([param.numel() for param in model.parameters()]) - forward_flops = model.flops() - tflops = forward_flops * 4 / 1e12 # forward + re-compute forward + backward (=2 forward flops) - print_each_rank(f'model params: {nparams} | TFLOPs: {tflops}. Launching model...') - model = model.half().cuda() if args.fp16 else model.cuda() - - dataloader = MBartDataLoader(args.micro_bs, cfg) - - parameters = get_swap_parameters() + list(model.parameters()) if args.use_swap else model.parameters() - optimizer = torch.optim.Adam(parameters, lr=3e-05, betas=(0.9, 0.98)) - - print_each_rank('model weight consumpition:') - memory_summary() - - CudaTimer(enable=False) - torch.distributed.barrier() - iter_num = 6 - for step in range(iter_num): - if step >= 2: - CudaTimer(enable=True).start('e2e') - - # train 1 step - num_microbatch = args.bs // args.micro_bs - if args.pp_size > 1: - _schedule(model, dataloader, num_microbatch, recompute=True) - reduce_embed(model, _pp_embed_group) - else: - for _ in range(num_microbatch): - model.data = next(dataloader) - loss = model() - loss.backward() - - optimizer.step() - optimizer.zero_grad() - - if step >= 2: - CudaTimer().stop('e2e') - - if step == 0: - print_each_rank('memory after optimizer:', rank_only=0) - memory_summary() - - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-2, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-2) - memory_summary() diff --git a/handcraft/module/distnn.py b/handcraft/module/distnn.py deleted file mode 100644 index bd0878c4..00000000 --- a/handcraft/module/distnn.py +++ /dev/null @@ -1,278 +0,0 @@ -from typing import List -import torch - -from cube.profiler.timer import CudaTimer -from cube.runtime.device import DeviceGroup - - -class SendRecv(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dst: int, group): - CudaTimer().start(field_name='comm') - ctx._tsize = input_.size() - ctx._tdtype = input_.dtype - ctx._src = dst - if not input_.is_contiguous(): - input_ = input_.contiguous() - sendop = torch.distributed.P2POp( - torch.distributed.isend, input_, dst - ) - reqs = torch.distributed.batch_isend_irecv([sendop]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, _grad: torch.Tensor): - CudaTimer().start(field_name='comm') - size = ctx._tsize - dtype = ctx._tdtype - src = ctx._src - grad = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) - recvop = torch.distributed.P2POp( - torch.distributed.irecv, grad, src - ) - reqs = torch.distributed.batch_isend_irecv([recvop]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class RecvSend(torch.autograd.Function): - - @staticmethod - def forward(ctx, size, dtype, src: int, ranks: List[int]): - CudaTimer().start(field_name='comm') - ctx._tsize = size - ctx._tdtype = dtype - ctx._dst = src - input_ = torch.empty( - size, dtype=dtype, device=torch.cuda.current_device(), - requires_grad=True) - recvop = torch.distributed.P2POp( - torch.distributed.irecv, input_, src - ) - reqs = torch.distributed.batch_isend_irecv([recvop]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad: torch.Tensor): - CudaTimer().start(field_name='comm') - dst = ctx._dst - if not grad.is_contiguous(): - grad = grad.contiguous() - sendop = torch.distributed.P2POp( - torch.distributed.isend, grad, dst - ) - reqs = torch.distributed.batch_isend_irecv([sendop]) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return None, None, None, None - - -class AllReduceIdentity(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, group): - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - torch.distributed.all_reduce(input_, group=group) - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -class IdentityAllreduce(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, group): - ctx._group = group - return input_ - - @staticmethod - def backward(ctx, grad_output): - world_size = torch.distributed.get_world_size(ctx._group) - if world_size == 1: - return grad_output, None - CudaTimer().start(field_name='comm') - torch.distributed.all_reduce(grad_output, group=ctx._group) - CudaTimer().stop(field_name='comm') - return grad_output, None - - -class ReduceScatterAllGather(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dim: int, group): - ctx._group = group - ctx._dim = dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_, - CudaTimer().start(field_name='comm') - input_tensors = input_.chunk(world_size, dim) - rank = torch.distributed.get_rank(group) - input_ = torch.empty_like(input_tensors[rank], requires_grad=True) - torch.distributed.reduce_scatter( - input_, input_tensors, group=group - ) - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output): - group = ctx._group - dim = ctx._dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output - CudaTimer().start(field_name='comm') - rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] - tensor_list[rank] = grad_output - torch.distributed.all_gather(tensor_list, grad_output, group=group) - grad = torch.cat(tensor_list, dim=dim).contiguous() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class AllGatherSplit(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dim: int, group): - ctx._group = group - ctx._dim = dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=group) - output = torch.cat(tensor_list, dim=dim).contiguous() - CudaTimer().stop(field_name='comm') - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - group = ctx._group - dim = ctx._dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output - CudaTimer().start(field_name='comm') - input_list = grad_output.chunk(world_size, dim=dim) - rank = torch.distributed.get_rank(group) - grad = input_list[rank].contiguous() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class SplitAllGather(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dim: int, group): - ctx._group = group - ctx._dim = dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - input_list = input_.chunk(world_size, dim=dim) - rank = torch.distributed.get_rank(group) - input_ = input_list[rank].contiguous() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - group = ctx._group - dim = ctx._dim - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output - CudaTimer().start(field_name='comm') - rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] - tensor_list[rank] = grad_output - torch.distributed.all_gather(tensor_list, grad_output, group=group) - grad = torch.cat(tensor_list, dim=dim).contiguous() - CudaTimer().stop(field_name='comm') - return grad, None, None - - -class ReduceBroadcast(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, dst: int, group): - ctx._dst = dst - ctx._group = group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - torch.distributed.reduce(input_, dst, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output): - src = ctx._dst - group = ctx._group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output, None, None - CudaTimer().start(field_name='comm') - torch.distributed.broadcast(grad_output, src, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return grad_output, None, None - - -class BroadcastReduce(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_: torch.Tensor, src: int, group): - ctx._src = src - ctx._group = group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return input_ - CudaTimer().start(field_name='comm') - torch.distributed.broadcast(input_, src, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return input_ - - @staticmethod - def backward(ctx, grad_output): - dst = ctx._src - group = ctx._group - world_size = torch.distributed.get_world_size(group) - if world_size == 1: - return grad_output, None, None - CudaTimer().start(field_name='comm') - if not grad_output.is_contiguous(): - grad_output = grad_output.contiguous() - torch.distributed.reduce(grad_output, dst, group=group) - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return grad_output, None, None diff --git a/handcraft/module/schedule.py b/handcraft/module/schedule.py deleted file mode 100644 index 5f2fc700..00000000 --- a/handcraft/module/schedule.py +++ /dev/null @@ -1,698 +0,0 @@ -from typing import List -import torch - -from cube.profiler.timer import CudaTimer, print_each_rank - -from handcraft.module.stage import PipeStage - -io_input = input - -def forward_step(model, *args, **kwargs): - """ - Forward pass - """ - CudaTimer().start("forward") - outputs = model(*args, **kwargs) - if not isinstance(outputs, tuple): - outputs = (outputs, ) - CudaTimer().stop("forward") - return outputs - - -def backward_step(input_tensors: List[torch.Tensor], - output_tensors: List[torch.Tensor], - output_tensor_grads: List[torch.Tensor]) -> List[torch.Tensor]: - """ - Backward pass - """ - for tensor in input_tensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - tensor.retain_grad() - CudaTimer().start("backward") - torch.autograd.backward(output_tensors, grad_tensors=output_tensor_grads) - CudaTimer().stop("backward") - input_tensor_grads = [] - for tensor in input_tensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - input_tensor_grads.append(tensor.grad) - else: - input_tensor_grads.append(None) - return input_tensor_grads - - -def recv_forward(model: PipeStage, prev_rank: int) -> List[torch.Tensor]: - shapes, dtypes = model.inputs_info - assert len(shapes) == len(dtypes) - assert isinstance(prev_rank, int), "Expected prev_rank to be int" - # print(f'rank {DeviceGroup().rank} recving forward: {shapes}, {dtypes}') - if len(shapes) == 0: return () - - CudaTimer().start(field_name='comm') - tensors = [ - torch.empty( - shape, requires_grad=True, dtype=dtype, - device=torch.cuda.current_device() - ) for shape, dtype in zip(shapes, dtypes) - ] - recv_ops = [ - torch.distributed.P2POp( - torch.distributed.irecv, tensor, prev_rank - ) for tensor in tensors - ] - reqs = torch.distributed.batch_isend_irecv(recv_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return tensors - - -def recv_backward(model: PipeStage, next_rank: int) -> List[torch.Tensor]: - shapes, dtypes = model.outputs_info - assert len(shapes) == len(dtypes) - assert isinstance(next_rank, int), "Expected next_rank to be int" - # print(f'rank {DeviceGroup().rank} recving backward: {shapes}') - if len(shapes) == 0: return () - - CudaTimer().start(field_name='comm') - tensors = [ - torch.empty( - shape, requires_grad=False, dtype=dtype, - device=torch.cuda.current_device() - ) for shape, dtype in zip(shapes, dtypes) - ] - recv_ops = [ - torch.distributed.P2POp( - torch.distributed.irecv, tensor, next_rank - ) for tensor in tensors - ] - reqs = torch.distributed.batch_isend_irecv(recv_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return tensors - - -def send_forward(outputs: List[torch.Tensor], next_rank: int): - assert all([torch.is_tensor(out) for out in outputs]), "Expected List[Tensor]" - assert isinstance(next_rank, int), "Expected next_rank to be int" - if len(outputs) == 0: return - # print(f'rank {DeviceGroup().rank} sending forward: {[tuple(t.size()) for t in outputs]}') - - CudaTimer().start(field_name='comm') - send_ops = [ - torch.distributed.P2POp( - torch.distributed.isend, tensor, next_rank - ) for tensor in outputs - ] - reqs = torch.distributed.batch_isend_irecv(send_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - - -def send_backward(grads: List[torch.Tensor], prev_rank: int): - assert all([torch.is_tensor(grad) for grad in grads]), "Expected List[Tensor]" - assert isinstance(prev_rank, int), "Expected prev_rank to be int" - if len(grads) == 0: return - CudaTimer().start(field_name='comm') - # print(f'rank {DeviceGroup().rank} sending backward: {[tuple(t.size()) for t in grads]}') - - send_ops = [ - torch.distributed.P2POp( - torch.distributed.isend, tensor, prev_rank - ) for tensor in grads - ] - reqs = torch.distributed.batch_isend_irecv(send_ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - - -def send_forward_recv_backward(outputs, model: PipeStage, next_rank: int) -> List[torch.Tensor]: - assert all([torch.is_tensor(out) for out in outputs]), "Expected List[Tensor]" - assert isinstance(next_rank, int), "Expected next_rank to be int" - shapes, dtypes = model.outputs_info - assert len(shapes) == len(dtypes) - # print(f'rank {DeviceGroup().rank} sending forward: {[tuple(t.size()) for t in outputs]} recving backward {shapes}') - - CudaTimer().start(field_name='comm') - ops = list() - # send forward outputs - send_ops = [ - torch.distributed.P2POp( - torch.distributed.isend, tensor, next_rank - ) for tensor in outputs - ] - ops += send_ops - # recv backward inputs - tensors = [ - torch.empty( - shape, requires_grad=True, dtype=dtype, - device=torch.cuda.current_device() - ) for shape, dtype in zip(shapes, dtypes) - ] - recv_ops = [ - torch.distributed.P2POp( - torch.distributed.irecv, tensor, next_rank - ) for tensor in tensors - ] - ops += recv_ops - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return tensors - - -def send_backward_recv_forward(grads, model: PipeStage, prev_rank: int) -> List[torch.Tensor]: - assert all([torch.is_tensor(grad) for grad in grads]), "Expected List[Tensor]" - assert isinstance(prev_rank, int), "Expected prev_rank to be int" - shapes, dtypes = model.inputs_info - assert len(shapes) == len(dtypes) - # print(f'rank {DeviceGroup().rank} sending backward: {[tuple(t.size()) for t in grads]} recving forward {shapes}') - - CudaTimer().start(field_name='comm') - ops = list() - # send backward gradients - send_ops = [ - torch.distributed.P2POp( - torch.distributed.isend, tensor, prev_rank - ) for tensor in grads - ] - ops += send_ops - # recv forward inputs - tensors = [ - torch.empty( - shape, requires_grad=True, dtype=dtype, - device=torch.cuda.current_device() - ) for shape, dtype in zip(shapes, dtypes) - ] - recv_ops = [ - torch.distributed.P2POp( - torch.distributed.irecv, tensor, prev_rank - ) for tensor in tensors - ] - ops += recv_ops - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - CudaTimer().stop(field_name='comm') - return tensors - - - -def schedule_naive(model: PipeStage, dataloader, num_microbatch: int): - """ - neighbors: (prev_rank: int, next_rank: int) - """ - prev_rank = model.prev_stage_global_grank - next_rank = model.next_stage_global_rank - - for _ in range(num_microbatch): - model.data = next(dataloader) - # print(f'rank {rank} recving forward input...') - inputs = () if model.is_first_stage else recv_forward(model, prev_rank) - # forward - outputs = forward_step(model, *inputs) - # send forward - if not model.is_last_stage: - # print(f'rank {rank} sending forward output...') - send_forward(outputs, next_rank) - # recv backward - # print(f'rank {rank} recving backward input...') - output_grads = (None,) if model.is_last_stage else recv_backward(model, next_rank) - # backward - input_grads = backward_step(inputs, outputs, output_grads) - # send backward - if not model.is_first_stage: - # print(f'rank {rank} sending backward output...') - send_backward(input_grads, prev_rank) - - -def schedule_1f1b(model: PipeStage, - dataloader, - num_microbatch: int, - recompute=False): - - num_stage = model.num_stages - prev_rank = model.prev_stage_global_grank - next_rank = model.next_stage_global_rank - - num_warmup_microbatches = num_stage - 1 - model.stage_local_rank - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatch) - num_warmup_remaining = num_microbatch - num_warmup_microbatches - - # warmup - for i in range(num_warmup_microbatches): - model.data = next(dataloader) - # recv forward - inputs = () if model.is_first_stage else recv_forward(model, prev_rank) - # forward - model.push(inputs, 'inputs') - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs) - model.push(None, 'outputs') - else: - outputs = forward_step(model, *inputs) - model.push(outputs, 'outputs') - # send forward - send_forward(outputs, next_rank) - - # before running 1f1b: need to recv first forward tensor - if num_warmup_remaining > 0: - model.data = next(dataloader) - inputs = () if model.is_first_stage else recv_forward(model, prev_rank) - - # run 1f1b - for i in range(num_warmup_remaining): - model.data = next(dataloader) - # forward - model.push(inputs, 'inputs') - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs) - model.push(None, 'outputs') - # correctness checkprint - # if model.is_last_stage: - # print(outputs) - else: - outputs = forward_step(model, *inputs) - model.push(outputs, 'outputs') - - # send forward recv backward - grads = (None,) - if not model.is_last_stage: - grads = send_forward_recv_backward(outputs, model, next_rank) - - # backward - inputs, outputs = model.pop('inputs'), model.pop('outputs') - if recompute: - assert outputs is None - outputs = forward_step(model, *inputs) - input_grads = backward_step(inputs, outputs, grads) - - # send backward - inputs = () - if not model.is_first_stage: - if i != (num_warmup_remaining-1): - # send backward recv forward - inputs = send_backward_recv_forward(input_grads, model, prev_rank) - else: - # send backward - send_backward(input_grads, prev_rank) - - # cooldown - for i in range(num_warmup_microbatches): - inputs, outputs = model.pop('inputs'), model.pop('outputs') - # recv backward - grads = (None,) if model.is_last_stage else recv_backward(model, next_rank) - # backward - if recompute: - assert outputs is None - outputs = forward_step(model, *inputs) - input_grads = backward_step(inputs, outputs, grads) - # send backward - if not model.is_first_stage: - send_backward(input_grads, prev_rank) - - model.assert_empty_cached() - - -def schedule_tp1f1b_pp2(model: PipeStage, - dataloader, - num_microbatch: int, - recompute=False): - def tp_encoder_preprocess(model: PipeStage) -> torch.Tensor: - model.data = next(dataloader) - enc = model.forward_encoder_shard() - return (enc,) - - def tp_decoder_preprocess(model: PipeStage) -> torch.Tensor: - model.data = next(dataloader) - dec = model.forward_decoder_shard() - return (dec,) - - def tp_encoder_backward(model: PipeStage): - enc = model.pop('encoder_sharding_output') - if model.stage_local_rank == model.first_encoder_stage: - grads = model.pop('encoder_sharding_grad') - else: - grads = (torch.empty_like(enc),) - backward_step((), (enc,), grads) - - def tp_decoder_backward(model: PipeStage): - dec = model.pop('decoder_sharding_output') - if model.stage_local_rank == model.first_decoder_stage: - grads = model.pop('decoder_sharding_grad') - else: - grads = (torch.empty_like(dec),) - backward_step((), (dec,), grads) - - num_stage = model.num_stages - rank = model.stage_local_rank - prev_rank = model.prev_stage_global_grank - next_rank = model.next_stage_global_rank - - output_grads = (None,) - inputs = () - for step in range(num_microbatch * 2 + 2): - - encoder_fmid = step // 2 - encoder_bmid = step - 2 - decoder_fmid = step - 1 - decoder_bmid = step - 3 - - # step1: forward sharding 0 - if step % 2 == 0: - encoder_fmid = step // 2 - encoder_inputs = None - if 0 <= encoder_fmid and encoder_fmid <= num_microbatch - 1: - encoder_inputs = tp_encoder_preprocess(model) - # step1: forward sharding 1 - if step % 2 == 1: - decoder_fmid = (step - 1) // 2 - decoder_inputs = None - if 0 <= decoder_fmid and decoder_fmid <= num_microbatch - 1: - decoder_inputs = tp_decoder_preprocess(model) - - if rank % 2 == 0: - # do forward - if step % 2 == 0: - fmid = step // 2 - do_forward = 0 <= fmid and fmid <= num_microbatch - 1 - if do_forward: - model.push(encoder_inputs, 'inputs') - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *(), recompute=True) - model.push(None, 'outputs') - else: - outputs = forward_step(model, *()) - model.push(outputs, 'outputs') - - # recompute - next_bmid = (step + 1 - 3) // 2 if step+1 >= 3 else -1 - do_next_backward = 0 <= next_bmid and next_bmid <= num_microbatch - 1 - if recompute and do_next_backward : - outputs_bp = model.pop('outputs') - assert outputs_bp is None - outputs_bp = forward_step(model, *()) - model.push_ahead(outputs_bp, 'outputs') - - # send forward recv backward - if do_forward and do_next_backward: - # print(f'rank {rank}: step {step}: send forward recv backward') - output_grads = send_forward_recv_backward(outputs, model, next_rank) - elif do_next_backward: - # print(f'rank {rank}: step {step}: recv backward') - output_grads = recv_backward(model, next_rank) - elif do_forward: - # print(f'rank {rank}: step {step}: send forward') - send_forward(outputs, next_rank) - - # do backward - else: - bmid = (step - 3) // 2 if step >= 3 else -1 - if 0 <= bmid and bmid <= num_microbatch - 1: - inputs, outputs = model.pop('inputs'), model.pop('outputs') - input_grads = backward_step(inputs, outputs, output_grads) - output_grads = (None,) - assert len(input_grads) == 1 - model.push(input_grads, 'encoder_sharding_grad') - - if rank % 2 == 1: - # do backward - if step % 2 == 0: - bmid = (step - 2) // 2 if step >= 2 else -1 - do_backward = 0 <= bmid and bmid <= num_microbatch - 1 - - # backward - if do_backward: - inputs, outputs = model.pop('inputs'), model.pop('outputs') - assert output_grads == (None,) - input_grads = backward_step(inputs, outputs, output_grads) - assert len(inputs) == 2 - model.push((input_grads[1],), 'decoder_sharding_grad') - input_grads = (input_grads[0],) - - # send backward recv forward - next_fmid = (step + 1 - 1) // 2 - do_next_forward = 0 <= next_fmid and next_fmid <= num_microbatch - 1 - if do_backward and do_next_forward: - # print(f'rank {rank}: step {step}: send backward recv forward') - inputs = send_backward_recv_forward(input_grads, model, prev_rank) - elif do_next_forward: - # print(f'rank {rank}: step {step}: recv forward') - inputs = recv_forward(model, prev_rank) - elif do_backward: - # print(f'rank {rank}: step {step}: send backward') - send_backward(input_grads, prev_rank) - # do forward - else: - # forward - fmid = (step - 1) // 2 - if 0 <= fmid and fmid <= num_microbatch - 1: - assert inputs != () - model.push((inputs[0], decoder_inputs[0]), 'inputs') - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs, recompute=True) - model.push(None, 'outputs') - else: - outputs = forward_step(model, *inputs) - model.push(outputs, 'outputs') - - # recompute - if recompute: - inputs, outputs = model.pop('inputs'), model.pop('outputs') - assert outputs is None - outputs = forward_step(model, *inputs) - model.push_ahead(inputs, 'inputs') - model.push_ahead(outputs, 'outputs') - - - # step3: backward sharding 1 - if step % 2 == 0: - decoder_bmid = (step - 2) // 2 - if 0 <= decoder_bmid and decoder_bmid <= num_microbatch - 1: - tp_decoder_backward(model) - - # step3: backward sharding 0 - if step % 2 == 1: - encoder_bmid = (step - 3) // 2 - if 0 <= encoder_bmid and encoder_bmid <= num_microbatch - 1: - tp_encoder_backward(model) - - model.assert_empty_cached() - - -def schedule_tp1f1b(model: PipeStage, - dataloader, - num_microbatch: int, - recompute=False): - # special cases for pipeline stage == 2 - if model.num_stages == 2: - return schedule_tp1f1b_pp2(model, dataloader, num_microbatch, recompute) - - def tp_encoder_preprocess(model: PipeStage) -> torch.Tensor: - model.data = next(dataloader) - enc = model.forward_encoder_shard() - return (enc,) - - def tp_decoder_preprocess(model: PipeStage) -> torch.Tensor: - model.data = next(dataloader) - dec = model.forward_decoder_shard() - return (dec,) - - def tp_encoder_backward(model: PipeStage): - enc = model.pop('encoder_sharding_output') - if model.stage_local_rank == model.first_encoder_stage: - grads = model.pop('encoder_sharding_grad') - else: - grads = (torch.empty_like(enc),) - backward_step((), (enc,), grads) - - def tp_decoder_backward(model: PipeStage): - dec = model.pop('decoder_sharding_output') - if model.stage_local_rank == model.first_decoder_stage: - grads = model.pop('decoder_sharding_grad') - else: - grads = (torch.empty_like(dec),) - backward_step((), (dec,), grads) - - num_stage = model.num_stages - rank = model.stage_local_rank - prev_rank = model.prev_stage_global_grank - next_rank = model.next_stage_global_rank - fofst = [-(step // 2) for step in range(num_stage)] - bofst = [-(num_stage - 1 - (step // 2)) for step in range(num_stage)] - - fofst = fofst[model.stage_local_rank] - bofst = bofst[model.stage_local_rank] - last_backward = (None,) - last_forward = (None,) - - for step in range(num_microbatch + num_stage - 1): - fmid, bmid = step + fofst, step + bofst - encoder_fmid = step - decoder_fmid = step - num_stage // 2 // 2 - encoder_bmid = step + 1 - num_stage // 2 * 2 - decoder_bmid = step + 1 - int(num_stage // 2 * 1.5) - do_backward = 0 <= bmid and bmid <= num_microbatch - 1 - do_forward = 0 <= fmid and fmid <= num_microbatch - 1 - - # step1: tp encoder forward - encoder_inputs = None - if 0 <= encoder_fmid and encoder_fmid <= num_microbatch - 1: - encoder_inputs = tp_encoder_preprocess(model) - # step2: tp decoder forward - decoder_inputs = None - if 0 <= decoder_fmid and decoder_fmid <= num_microbatch - 1: - decoder_inputs = tp_decoder_preprocess(model) - - # step 3: forward + backward - if rank % 2 == 0: - # inter-barrier - inputs = () - if not model.is_first_stage: - if do_forward and last_backward != (None,): - # print(f'rank {rank} send backward grad + recv forward output ') - inputs = send_backward_recv_forward(last_backward, model, prev_rank) - elif do_forward: - # print(f'rank {rank} recv forward output ') - inputs = recv_forward(model, prev_rank) - elif last_backward != (None,): - # print(f'rank {rank} send backward grad ') - send_backward(last_backward, prev_rank) - - # forward - if do_forward: - - if model.stage_local_rank == model.first_encoder_stage and encoder_inputs is not None: - model.push(encoder_inputs, 'inputs') - elif model.stage_local_rank == model.first_decoder_stage and decoder_inputs is not None: - assert len(inputs) == 1 and len(decoder_inputs) == 1 - model.push((inputs[0], decoder_inputs[0]), 'inputs') - else: - model.push(inputs, 'inputs') - - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs, recompute=True) - model.push(None, 'outputs') - else: - outputs = forward_step(model, *inputs) - model.push(outputs, 'outputs') - - # recompute if backward is needed - if do_backward: - inputs, outputs_bp = model.pop('inputs'), model.pop('outputs') - if recompute: - assert outputs_bp is None - outputs_bp = forward_step(model, *inputs) - - # intra-barrier send recv - output_grads = (None,) - if (do_forward and not model.is_last_stage) and (do_backward and not model.is_last_stage): - # send forward recv backward - # print(f'rank {rank} recv backward grad + send forward output ') - output_grads = send_forward_recv_backward(outputs, model, next_rank) - elif do_forward and not model.is_last_stage: - # print(f'rank {rank} send forward output ') - send_forward(outputs, next_rank) - elif do_backward and not model.is_last_stage: - # print(f'rank {rank} recv backward grad ') - output_grads = recv_backward(model, next_rank) - - # backward - last_backward = (None,) - if do_backward: - # inputs, outputs = input_tensors.pop(0), output_tensors.pop(0) - input_grads = backward_step(inputs, outputs_bp, output_grads) - - if model.stage_local_rank == model.first_encoder_stage: - assert len(input_grads) == 1 - model.push(input_grads, 'encoder_sharding_grad') - elif model.stage_local_rank == model.first_decoder_stage: - assert len(input_grads) == 2 - model.push((input_grads[1],), 'decoder_sharding_grad') - input_grads = (input_grads[0],) - last_backward = input_grads - - # step 3: backward + forward - if rank % 2 == 1: - # inter-barrier - if model.is_last_stage: - output_grads = (None,) - else: - if do_backward and last_forward != (None,): - # print(f'rank {rank} recv backward grad + send forward output ') - output_grads = send_forward_recv_backward(last_forward, model, next_rank) - elif do_backward: - # print(f'rank {rank} recv backward grad ') - output_grads = recv_backward(model, next_rank) - elif last_forward != (None,): - # print(f'rank {rank} send forward output ') - send_forward(last_forward, next_rank) - - # backward - last_backward = (None,) - if do_backward: - inputs, outputs_bp = model.pop('inputs'), model.pop('outputs') - # backward - input_grads = backward_step(inputs, outputs_bp, output_grads) - last_backward = input_grads - - # intra-barrier - if do_backward and do_forward: - # print(f'rank {rank} send backward grad + recv forward output ') - inputs = send_backward_recv_forward(input_grads, model, prev_rank) - elif do_backward: - # print(f'rank {rank} send backward grad ') - send_backward(input_grads, prev_rank) - elif do_forward: - # print(f'rank {rank} recv forward output ') - inputs = recv_forward(model, prev_rank) - - # forward - last_forward = (None,) - if do_forward: - # forward step - model.push(inputs, 'inputs') - if recompute: - with torch.no_grad(): - outputs = forward_step(model, *inputs, recompute=True) - model.push(None, 'outputs') - # correctness check print - # if model.is_last_stage: - # print(outputs) - else: - outputs = forward_step(model, *inputs) - model.push(outputs, 'outputs') - last_forward = outputs - - next_backward = 0 <= (bmid+1) and (bmid+1) <= num_microbatch - 1 - if next_backward: - if recompute: - inputs, outputs_bp = model.pop('inputs'), model.pop('outputs') - assert outputs_bp is None - outputs = forward_step(model, *inputs) - model.push_ahead(inputs, 'inputs') - model.push_ahead(outputs, 'outputs') - - # step 4: sharding decoder backward - if 0 <= decoder_bmid and decoder_bmid <= num_microbatch - 1: - tp_decoder_backward(model) - - # step 5: sharding encoder backward - if 0 <= encoder_bmid and encoder_bmid <= num_microbatch - 1: - tp_encoder_backward(model) - - model.assert_empty_cached() diff --git a/handcraft/module/stage.py b/handcraft/module/stage.py deleted file mode 100644 index 9341a086..00000000 --- a/handcraft/module/stage.py +++ /dev/null @@ -1,166 +0,0 @@ -from typing import Any, List, Tuple -import torch - - -class PipeStage(torch.nn.Module): - - def __init__(self): - super().__init__() - self._cached = dict() - self._data = () - self._input_shapes = () - self._input_dtypes = () - self._output_shapes = () - self._output_dtypes = () - - # pipeline information - self._num_stages = None - self._is_first_stage = None - self._is_last_stage = None - self._stage_grank = None # global rank - self._next_grank = None # global rank - self._prev_grank = None # global rank - self._stage_lrank = None # local rank - self._next_lrank = None # local rank - self._prev_lrank = None # local rank - - @property - def is_first_stage(self) -> bool: - return self._is_first_stage - - @property - def is_last_stage(self) -> bool: - return self._is_last_stage - - @property - def next_stage_global_rank(self) -> int: - return self._next_grank - - @property - def prev_stage_global_grank(self) -> int: - return self._prev_grank - - @property - def stage_global_rank(self) -> int: - return self._stage_grank - - @property - def next_stage_local_rank(self) -> int: - return self._next_lrank - - @property - def prev_stage_local_rank(self) -> int: - return self._prev_lrank - - @property - def stage_local_rank(self) -> int: - return self._stage_lrank - - @property - def num_stages(self): - return self._num_stages - - def set_pipeline(self, group_global_ranks: Tuple[int]): - """ - Setup pipeline information given global ranks. - Note NCCL group should be initialized outside - """ - if len(group_global_ranks) == 0: - group_global_ranks = (torch.distributed.get_rank(),) - self._num_stages = len(group_global_ranks) - self._stage_grank = torch.distributed.get_rank() - self._stage_lrank = group_global_ranks.index(self._stage_grank) - - self._next_grank = group_global_ranks[(self._stage_lrank+1) % self.num_stages] - self._prev_grank = group_global_ranks[(self._stage_lrank-1) % self.num_stages] - - self._next_lrank = (self._stage_lrank+1) % self.num_stages - self._prev_lrank = (self._stage_lrank-1) % self.num_stages - - self._is_first_stage = self._stage_lrank == 0 - self._is_last_stage = self._stage_lrank == self.num_stages - 1 - - def get_last(self, region: str = 'default') -> Any: - return self._cached[region][-1] - - def pop_last(self, region: str = 'default') -> Any: - return self._cached[region].pop(-1) - - def pop(self, region: str = 'default') -> Any: - return self._cached[region].pop(0) - - def push_ahead(self, val: Any, region: str = 'default') -> Any: - self._cached[region] = [val] + self._cached[region] - - def push(self, val: Any, region: str = 'default'): - if region not in self._cached: - self._cached[region] = [] - return self._cached[region].append(val) - - def assert_empty_cached(self): - for key, vals in self._cached.items(): - assert len(vals) == 0, f"key {key} still has {len(vals)} values" - - @property - def inputs_info(self) -> Tuple[Tuple, Tuple]: - """ - return input shapes and dtypes - """ - return self._input_shapes, self._input_dtypes - - @inputs_info.setter - def inputs_info(self, shapes_dtypes: Tuple[Tuple, Tuple]): - self._input_shapes, self._input_dtypes = shapes_dtypes - - @property - def outputs_info(self) -> Tuple[Tuple, Tuple]: - """ - return output shapes and dtypes - """ - return self._output_shapes, self._output_dtypes - - @outputs_info.setter - def outputs_info(self, shapes_dtypes: Tuple[Tuple, Tuple]): - self._output_shapes, self._output_dtypes = shapes_dtypes - - @property - def data(self) -> Tuple: - return self._data - - @data.setter - def data(self, datas: Tuple): - self._data = datas - - -def layer_division(times: List[int], num_stages: int, start_id: int = 0, limits: List[int] = None): - """ - Computation balance division - """ - divisions = [] - budget = sum(times) / num_stages - nlayers = len(times) - start, end = 0, 1 - if limits is None: - limits = [None] * num_stages - else: - assert len(limits) == num_stages - for idx in range(num_stages): - accum = times[start] - assert end <= nlayers - while end != nlayers: - if limits[idx] is not None and (end - start) == limits[idx]: - break - if times[end] > 0 and budget - accum < 0.5 * times[end]: - break - accum += times[end] - end += 1 - if idx == num_stages - 1: - end = nlayers - divisions.append((start, end)) - if idx != num_stages - 1: - budget = sum(times[end:]) / (num_stages - 1 - idx) - start, end = end, end+1 - for sid in range(num_stages): - start, end = divisions[sid] - divisions[sid] = (start+start_id, end+start_id) - return divisions diff --git a/handcraft/playground/dag/data_parallel_raw.py b/handcraft/playground/dag/data_parallel_raw.py deleted file mode 100644 index 0f8ab7ff..00000000 --- a/handcraft/playground/dag/data_parallel_raw.py +++ /dev/null @@ -1,130 +0,0 @@ -from graph_manipulation import * - -''' -[dataflow graph]: the logic of one training iteration -DFG input: samples, weight tensors, optimizer state tensors -DFG output: (updated) weight tensors, (updated) optimizer state tensors - -data tensors as edges, produced by one operator and consumed by one or more operator(s) -operators as nodes, consumes one or more tensor(s), (mostly) producing one tensor - -(assumption resizable batch) same DL model description for different batch-sizes (batch-size as variable) - ref 1 ONNX: https://github.com/onnx/onnx/issues/2182 - ref 2 ... - -/////////////// - -graph manipulation as data-parallel, 4 method options -option 1: manually decide manipulation for each node/tensor following an oracle that knows everything -option 2: deep copy graph and manually decide adjustment for each node/tensor -option 3: using node-role, e.g., DataNode/Fwd/Bwd split; Optimizer replicate; weight's gradient all-reduce before used by Optimizers -option 4: using tensor info. e.g., tensors with batch-dim will split, operators and other tensors adapt accordingly -''' -def data_parallel_raw(g: Graph, device_num: int, method: int): - def oracle_func(*args) -> bool: - pass - - if method == 'raw graph manipulation': #per node manipulation following oracle's instruction - # 1. multiply operators for ``parallel'' in data-parllelism - for node in g.nodes: - new_nodes = [] - for device_id in range(device_num): - new_node_inputs = [] - for ts in node.inputs: - # find corresponding input tensor, which is another new operator's (sliced/replicated...) output - new_input = oracle_func(node, ts, device_id, device_num).query("find_new_input") - new_node_inputs.append(new_input) - - new_node_outputs = [] - for ts in node.outputs: - # new out tensor of the same shape (if replicate) or 1/N (if slice on certain dim) - new_output_shape = oracle_func(node, ts, device_id, device_num).query("new_output_shape") - new_output = Tensor(new_output_shape) # create new tensor as output (will be another operator(s)'s input) - new_node_outputs.append(new_output) - - new_node_type = oracle_func(node).query("new_node_type") - # create new node, with device info - new_node = Node(type=new_node_type, inputs=new_node_inputs, outputs=new_node_outputs, - device=device_id) - new_nodes.append(new_node) - - g.replace(node, new_nodes) #replacing with new nodes - - # 2. inserting gradient averaging - for node in g.nodes: - new_allreduce_node = None - input_to_replace = None - for ts in node.inputs: - if oracle_func(ts).query('insert allreduce here'): - new_allreduce_node = Node(type='allreduce', inputs=ts) - input_to_replace = ts - break - - new_node = Node(type=node.type, inputs=node.inputs - input_to_replace + new_allreduce_node.output, - outputs=node.outputs) - g.replace(node, [new_allreduce_node, new_node]) - - elif method == 'replicate graph and adjust': #replicate entire graph and adjust, similar to approaches of Horovod and PyTorch DDP - # 1. deep copy graph - graphs = [g.deepcopy() for i in range(device_num)] - - # 2. reset batch size for each new graph, leveraging model description resizable batch (<-assumption) - # input or output shape inferred from shape_inference, representing split (1/N shape) or replicated (unchanged shape) - for index, graph in enumerate(graphs): - graph.arguments.batch_size = g.arguments.batch_size // device_num - graph.to_device(device=index) - - # 3. inserting gradient averaging - for graph in graphs: - for node in graph.nodes: - new_allreduce_node = None - input_to_replace = None - for ts in node.inputs: - if oracle_func(ts).query('insert allreduce here'): - new_allreduce_node = Node(type=allreduce, inputs=ts, outputs=Tensor(node.outputs.shape)) - input_to_replace = ts - break - - new_node = Node(type=node.type, inputs=node.inputs - input_to_replace + new_allreduce_node.output, - outputs=node.outputs) - graph.replace(node, [new_allreduce_node, new_node]) - - elif method == 3: #node role based manipulation - for node in g.nodes: - if isinstance(node, (NodeData)): - new_nodes = [ - Node(type=node.type, inputs=None, - config=node.config.reset_batch_size(node.config.batch_size // device_num), - outputs=node.outputs.shape[0] // device_num + node.outputs.shape[1:]) for - device_id in range(device_num)] - elif isinstance(node, (NodeFwd, NodeBwdA)): - new_nodes = [ - # assume inputs[0] as activation (for NodeFwd) or activation's gradient (for NodeBwdA) and inputs[1] as weight - Node(type=node.type, - inputs=[oracle_func(node, node.inputs[0], device_id, device_num).query("find_new_input"), #batch-split - oracle_func(node, node.inputs[1], device_id, device_num).query("find_new_input")], #replicated - outputs=Tensor(node.outputs.shape[0] // device_num + node.outputs.shape[1:])) #output batch-split - for device_id in range(device_num)] - elif isinstance(node, (NodeBwdW)): #backward that computing weight's gradient - # assume inputs[0] as activation's gradient and inputs[1] as activation, both with batch-dim - new_nodes = [ - Node(type=node.type, - inputs=[oracle_func(node, node.inputs[0], device_id, device_num).query("find_new_input"), #batch-split - oracle_func(node, node.inputs[1], device_id, device_num).query("find_new_input")], #batch-split - outputs=[Tensor(node.outputs.shape[0])]) #shape unchanged, but only 1/N value - for device_id in range(device_num)] - elif isinstance(node, (NodeOpt)): - new_nodes = trans(node, algo.replica, device_num) # replicated optimizers - [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] - - g.replace(node, new_nodes) - - #omit device assign and allreduce insertion - elif method == 4: #tensor dimention info based manipulation - pass - - - - - -data_parallel_raw(graph) \ No newline at end of file diff --git a/handcraft/playground/dag/graph_manipulation.py b/handcraft/playground/dag/graph_manipulation.py deleted file mode 100644 index 397b3657..00000000 --- a/handcraft/playground/dag/graph_manipulation.py +++ /dev/null @@ -1,550 +0,0 @@ -from enum import Enum -import sys - -# class NodeType(Enum): -# UNKNOWN = 0 -# DATALOADER = 1 -# FORWARD = 2 -# BACKWARD_A = 3 -# BACKWARD_W = 4 -# OPTIMIZER = 5 - -class bcolors: - HEADER = '\033[95m' - OKBLUE = '\033[94m' - OKCYAN = '\033[96m' - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' - - -nodeList = [] -global_node_id = -1 - - -def new_node_id(): - global global_node_id - global_node_id += 1 - return global_node_id - - -def last_node(last_step=1): - assert len(nodeList) >= last_step - return nodeList[-last_step] - - - - -class AlgorithmMgr: - batch_split: str - replica: str - - def __init__(self): - self.batch_split = 'batch_split' - self.replica = 'replica' - self.split = 'split' - self.tensor_split = 'tensor_split' - -class Operator: - algo: AlgorithmMgr - def __init__(self): - self.algo = AlgorithmMgr() - -class Node: - id: int - inputs: [] - outputs: [] - removed: bool - op: Operator - - def __init__(self): - super().__init__() - self.removed = False - self.id = new_node_id() - self.inputs = [] - self.outputs = [] - self.op = Operator() - nodeList.append(self) - - def spawn(self, portion: str=None): - node = self.__class__() #create same type - node.inputs = [t.spawn() for t in self.inputs] - node.outputs = [t.spawn() for t in self.outputs] - return node - - def __str__(self): - return "Node({}), {}\tinput:{}\toutput:{} ".format( - self.id, str(type(self)).lstrip(''), - '\t'.join([str(x) for x in self.inputs] if len(self.inputs) > 0 else ""), - '\t'.join([str(x) for x in self.outputs] if len(self.outputs) > 0 else "")) - - - def slim(self): - return "Node({}), {}".format( - self.id, str(type(self)).lstrip('')) - -class NodeData(Node): - def __init__(self): - super(NodeData, self).__init__() - # self.type = NodeType.DATALOADER - - -class NodeFwd(Node): - def __init__(self): - super(NodeFwd, self).__init__() - # self.type = NodeType.FORWARD - - -class NodeBwd(Node): - def __init__(self): - super(NodeBwd, self).__init__() - - -class NodeBwdA(NodeBwd): - def __init__(self): - super(NodeBwd, self).__init__() - # self.type = NodeType.BACKWARD_A - - -class NodeBwdW(NodeBwd): - def __init__(self): - super(NodeBwd, self).__init__() - # self.type = NodeType.BACKWARD_W - - -class NodeOpt(Node): - def __init__(self): - super(NodeOpt, self).__init__() - # self.type = NodeType.OPTIMIZER - - -# for logic tensor -class TensorType(Enum): - UNKNOWN = 0 - WEIGHT = 1 - WEIGHT_UPDATED = 2 - ACTIVATION = 3 - GRADIENT_A = 4 - GRADIENT_W = 5 - OPTIMIZER_STATE = 6 - LOSS = 7 - - -logicTensorList = [] -global_logic_tensor_id = -1 - - -def new_logic_tensor_id(): - global global_logic_tensor_id - global_logic_tensor_id += 1 - return global_logic_tensor_id - - -def last_logic_tensor(last_step=1): - assert len(logicTensorList) >= last_step - return logicTensorList[-last_step] - - -class LogicTensor: - id: int - type: TensorType - - def __init__(self, tensor_type=TensorType.UNKNOWN): - super().__init__() - self.id = new_logic_tensor_id() - self.type = tensor_type - - -tensorList = [] -global_tensor_id = -1 - - -def new_tensor_id(): - global global_tensor_id - global_tensor_id += 1 - return global_tensor_id - - -def last_tensor(last_step=1): - assert len(tensorList) >= last_step - return tensorList[-last_step] - - -class Tensor: - id: int - logic: LogicTensor - portion: str - - def new(self): - pass - - def __init__(self, tensor_type=TensorType.UNKNOWN, exist_tensor=None, portion=None): - super().__init__() - self.id = new_tensor_id() - if exist_tensor is None: - self.logic = LogicTensor(tensor_type) - self.portion = 'full' - else: - self.logic = exist_tensor.logic - self.portion = exist_tensor.portion - if portion is not None: - self.portion += '>' + portion - tensorList.append(self) - - def __getattr__(self, attr): - if(attr == 'type'): - return self.logic.type - else: - return self.attr - - def __str__(self): - return ("Tensor({}), {} of ({} {})".format( - self.id, - self.portion, - self.logic.id, - str(self.type).lstrip('TensorType.'))) - - def spawn(self, portion:str=None): - return Tensor(exist_tensor=self, portion=portion) - -class Graph: - nodes: [] - - def find_input(self, node: Node, tensor_type: TensorType): - ret = list(filter(lambda x: x.type == tensor_type, node.inputs)) - assert len(ret) > 0 - return ret[0] - - def create_sample_graph(self): - op_num = 2 - - for idx in range(1): # sample data loader - node = NodeData() # Node(NodeType.DATALOADER) - node.outputs.append(Tensor(TensorType.ACTIVATION)) - self.nodes.append(node) - - for idx in range(op_num): # forward ops - node = NodeFwd() # Node(NodeType.FORWARD) - node.inputs.append(last_tensor()) - node.inputs.append(Tensor(TensorType.WEIGHT)) - node.outputs.append(Tensor(TensorType.ACTIVATION)) - self.nodes.append(node) - - for idx in range(1): # label data loader - node = NodeData() # Node(NodeType.DATALOADER) - node.outputs.append(Tensor(TensorType.ACTIVATION)) - self.nodes.append(node) - - for idx in range(1): # loss - node = NodeFwd() # Node(NodeType.FORWARD) - node.inputs.append(last_tensor()) - node.outputs.append(Tensor(TensorType.LOSS)) - self.nodes.append(node) - - for fwd_node in list(filter(lambda x: type(x) is NodeFwd, self.nodes))[::-1]: # backward ops - out_gradient = last_tensor() - if len(fwd_node.inputs) == 2: - # computing weight's gradient - node = NodeBwdW() # Node(NodeType.BACKWARD_W) - node.inputs.append(out_gradient) # out_g_act - node.inputs.append(self.find_input(fwd_node, TensorType.WEIGHT)) - node.outputs.append(Tensor(TensorType.GRADIENT_W)) - self.nodes.append(node) - if len(fwd_node.inputs) >= 1: - # computing activation's gradient - node = NodeBwdA() # Node(NodeType.BACKWARD_A) - node.inputs.append(out_gradient) - node.inputs.append(self.find_input(fwd_node, TensorType.ACTIVATION)) - node.outputs.append(Tensor(TensorType.GRADIENT_A)) - self.nodes.append(node) - else: - assert False - - for bwd_w_node in list(filter(lambda x: type(x) is NodeBwdW, self.nodes)): # optimizer - node = NodeOpt() # Node(NodeType.OPTIMIZER) - node.inputs.append(self.find_input(bwd_w_node, TensorType.WEIGHT)) # WEIGHT - node.inputs.append(bwd_w_node.outputs[0]) - node.inputs.append(Tensor(TensorType.OPTIMIZER_STATE)) - node.outputs.append(Tensor(TensorType.WEIGHT_UPDATED)) - self.nodes.append(node) - - def __init__(self, create_sample=False): - super().__init__() - self.nodes = [] - - if create_sample: - self.create_sample_graph() - - def __str__(self): - # for node in self.nodes: - return '\n'.join([str(x) if not x.removed else "DEL "+str(x) for x in self.nodes]) - - -graph = Graph(create_sample=True) -# print('graph = \n{}'.format(graph)) -global_new_graph = Graph() - -# print('nodeList[{}] = \n{}'.format(len(nodeList), nodeList)) -# print('tensorList[{}] = \n{}'.format(len(tensorList), tensorList)) - - -class Config: - num: int - - -class Device(int): - def __init__(self, x, base=10): - super().__init__(x, base) - - -class Parallelizer: - def run(self, g: Graph, config: Config) -> Graph: - return None - - -def trans(node: Node, algo, num: int) -> [Node]: - node.removed = True - nodes = [node.spawn() for i in range(num)] - if algo == 'replica': - global_new_graph.nodes.extend(nodes) - return nodes - elif algo == 'batch_split': - for idx, nd in enumerate(nodes): - for ts in nd.inputs + nd.outputs: - ts.portion += '>batch-{}/{}'.format(idx, num) - global_new_graph.nodes.extend(nodes) - return nodes - elif algo == 'split': #elementwise split - for idx, nd in enumerate(nodes): - for ts in nd.inputs + nd.outputs: - ts.portion += '>flat-{}/{}'.format(idx, num) - global_new_graph.nodes.extend(nodes) - return nodes - elif algo == 'tensor_split': - for idx, nd in enumerate(nodes): - for ts in nd.inputs + nd.outputs: - ts.portion += '>tensor-{}/{}'.format(idx, num) - global_new_graph.nodes.extend(nodes) - return nodes - else: - assert False - - -def sched_s(node: Node, dev: Device) -> None: - print("{}sched_s...{} @ {}{}".format(bcolors.OKGREEN, node.slim(), dev, bcolors.ENDC)) - pass - - -def sched_t_pair(node_before: Node, node_after: Node) -> bool: - print("{}sched_t...{}-> {}{}".format(bcolors.OKBLUE, node_before.slim(), node_after.slim(), bcolors.ENDC)) - #TODO legal check - return True - - -def sched_t(nodes: [Node]) -> bool: - for i in range (len(nodes) - 1): - if not sched_t_pair(nodes[i], nodes[i+1]): - return False - return True - - -def set_affinity(producer_node, consumer_node): - print("{}affinity...{}-> {}{}".format(bcolors.OKCYAN, producer_node.slim(), consumer_node.slim(), bcolors.ENDC)) - pass - - -from collections import namedtuple -def index_enumerate(list: []): - Entry = namedtuple('Entry', ['idx', 'item']) - # return [{'idx': i, 'item': x} for i, x in enumerate(list)] - return [Entry(i, x) for i, x in enumerate(list)] - - -### TODO how about Tx in flexflow GSPMD etc.? - -# traditional data-parallel process: -# tx start -# 1. replicated graph g -> g' * N -# 2. change batch-size of g' -# 3. insert gradient allreduce (manually, can auto-gen in our sys) -# tx end - - -class DataParallelParallelizer(Parallelizer): - def run(self, g: Graph, config: Config) -> Graph: - global global_new_graph - global_new_graph.nodes.clear() - - # ---------------- - for node in g.nodes: - if isinstance(node, (NodeData, NodeFwd, NodeBwd)): - nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-split - [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] - elif isinstance(node, (NodeOpt)): - nodes = trans(node, node.op.algo.replica, config.num) # replicated optimizers - [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] - else: - assert False - # ---------------- - - global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] - return global_new_graph - - -class DataParallelZeROParallelizer(Parallelizer): - def run(self, g: Graph, config: Config) -> Graph: - global global_new_graph - global_new_graph.nodes.clear() - - # ---------------- - for node in g.nodes: - if isinstance(node, (NodeData, NodeFwd, NodeBwd)): - nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-split - [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] - elif isinstance(node, (NodeOpt)): - nodes = trans(node, node.op.algo.split, config.num) # split optimizers - [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] - else: - assert False - # ---------------- - - global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] - return global_new_graph - - -class GradientAccumulationParallelizer(Parallelizer): - def run(self, g: Graph, config: Config) -> Graph: - global global_new_graph - global_new_graph.nodes.clear() - - # ---------------- - for node in g.nodes: - if isinstance(node, (NodeData, NodeFwd, NodeBwd)): - nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit - sched_t(nodes) # sequential order - elif isinstance(node, (NodeOpt)): - pass - else: - assert False - # ---------------- - - global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] - return global_new_graph - - -def node_to_stage(g: Graph, config: Config) -> {}: # return node->stage mapping - ret = {} - nodes = g.nodes # TODO topo forward traversal - fwd_node = list(filter(lambda x: type(x) is NodeFwd, nodes)) - - per_stage_size = len(nodes) // config.stages - for node in nodes: - # TODO replace dummy assignment - ret[node] = 0 - - return ret - - -class GPipeParallelizer(Parallelizer): - def run(self, g: Graph, config: Config) -> Graph: - global global_new_graph - global_new_graph.nodes.clear() - - # ---------------- - n2stage = node_to_stage(g, config) - for node in g.nodes: - device = n2stage[node] - if isinstance(node, (NodeData, NodeFwd, NodeBwd)): - nodes = trans(node, node.op.algo.batch_split, config.num) # by batch-dim-slit - sched_t(nodes) # sequential order - [sched_s(node=x, dev=device) for x in nodes] # assign same stage device - elif isinstance(node, (NodeOpt)): - sched_s(node, device) - else: - assert False - # ---------------- - - global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] - return global_new_graph - - -class TensorParallelParallelizer(Parallelizer): - def run(self, g: Graph, config: Config) -> Graph: - global global_new_graph - global_new_graph.nodes.clear() - - # ---------------- - for node in g.nodes: - if isinstance(node, (NodeFwd, NodeBwd, NodeOpt)): - nodes = trans(node, node.op.algo.tensor_split, config.num) # by tensor-dim-slit - [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] - elif isinstance(node, (NodeData)): - nodes = trans(node, node.op.algo.replica, config.num) - [sched_s(node=x.item, dev=x.idx) for x in index_enumerate(nodes)] - else: - assert False - # ---------------- - - global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] - return global_new_graph - - -def find_consumers(graph: Graph, tensor: Tensor): - ret = [] - for node in graph.nodes: - if any([input_tensor.logic == tensor.logic for input_tensor in node.inputs]): - ret.append(node) - return ret - -def find_producers(graph: Graph, tensor: Tensor): - ret = [] - for node in graph.nodes: - if any([output_tensor.logic == tensor.logic for output_tensor in node.outputs]): - ret.append(node) - return ret - - -class Recompute(Parallelizer): - def run(self, g: Graph, config: Config) -> Graph: - global global_new_graph - global_new_graph.nodes.clear() - - # ---------------- - for node in g.nodes: - if isinstance(node, (NodeFwd)): - origin_fwd, recompute_fwd = trans(node, node.op.algo.replica, 2) - consumers = find_consumers(g, origin_fwd.outputs[0]) - for consumer in consumers: - if isinstance(consumer, NodeFwd): - set_affinity(origin_fwd, consumer) # break dependencies op0.fwd -> op1.fwd; - else: - set_affinity(recompute_fwd, consumer) # break dependencies op0.fwd' -> op0.bwd - producers = list(filter(lambda x: isinstance(x, NodeBwd), find_producers(g, consumer.inputs[0]))) - for producer in producers: - sched_t_pair(producer, recompute_fwd) - # ---------------- - - global_new_graph.nodes[:0] = [nd for nd in graph.nodes if not nd.removed] - return global_new_graph - -class ActivationSwap(Parallelizer): - def run(self, g: Graph, config: Config) -> Graph: - #TODO activate consuming NodeBwd -> Identity(CPU) + NodeBwd - pass - -para = DataParallelParallelizer() -# para = DataParallelZeROParallelizer() -# para = GradientAccumulationParallelizer() -# para = GPipeParallelizer() -# para = TensorParallelParallelizer() -# para = Recompute() - - -# config = Config() -# config.num = 2 -# config.stages = 2 -# global_new_graph = para.run(graph, config) -# print('new_graph = \n{}'.format(global_new_graph)) diff --git a/handcraft/playground/dag/graph_trans.py b/handcraft/playground/dag/graph_trans.py deleted file mode 100644 index 4134acec..00000000 --- a/handcraft/playground/dag/graph_trans.py +++ /dev/null @@ -1,46 +0,0 @@ -# general transformations -''' -Op := I -> Op (pre-identity) -Op := Op -> I (post-identity) -Op := Op, Op (replicate) -''' - -# batch transformation (due to DL operators are sample-wise) -''' -DataLoader - split (output)activation - -OperatorForward - split (input)activation - replica (input)weight - split (output)activation* - -OperatorBackward-(activation's gradient) - split (input)d-activation* - replica (input)weight - split (output)d-activation - -OperatorBackward-(weight's gradient) - split (input)d-activation* - split (input)activation - value-split (to-reduce) (output)d-weight -''' - -# non-batch transformation (operator semantic aware) -''' -elementwise operators (including optimizers) - arbitrary same split on inputs and outputs - -MatMul [M, K]*[K, N] => [M, N] - 1. split M or N (e.g., cases with M or N as batch-dim) - 2. split reducing dim K: [M, K/2]*[K/2, N] => value-split [M, N] - -Conv2D - 1. split (input) image H, W => halo exchange then local Conv2D, split (output) image - 2. split (input) filter out-channel-dim => Conv2D on replicated image with partial filter, value-split (output) image - -(more cases) ... -''' - - - diff --git a/handcraft/playground/test.sh b/handcraft/playground/test.sh deleted file mode 100755 index 43a95b88..00000000 --- a/handcraft/playground/test.sh +++ /dev/null @@ -1,144 +0,0 @@ -datadir=eval/sharding -mkdir -p ${datadir} - -# hidden=768 -# heads=12 -# -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=1 \ -# --nnodes=1 \ -# handcraft/playground/transformers.py \ -# --hidden-size ${hidden} --heads ${heads} \ -# --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt -# -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=1 \ -# --nnodes=1 \ -# handcraft/playground/transformers.py \ -# --hidden-size ${hidden} --heads ${heads} \ -# --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt - - -hidden=1024 -heads=16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt - - -hidden=1536 -heads=16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt - -hidden=2304 -heads=24 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt - - -hidden=2560 -heads=32 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt - - - -hidden=4096 -heads=32 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt - -hidden=5120 -heads=40 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt - - -hidden=12288 -heads=96 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 1 > ${datadir}/sharding-E${hidden}H${heads}-naive.txt - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/playground/transformers.py \ - --hidden-size ${hidden} --heads ${heads} \ - --seq 8 > ${datadir}/sharding-E${hidden}H${heads}-shard8.txt \ No newline at end of file diff --git a/handcraft/playground/transformers.py b/handcraft/playground/transformers.py deleted file mode 100644 index 9b46ad50..00000000 --- a/handcraft/playground/transformers.py +++ /dev/null @@ -1,326 +0,0 @@ -""" -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/dummy/transformers.py -""" -import torch -from torch.utils import checkpoint - -import cube -from cube.profiler.memory import memory_summary, model_summary -from cube.profiler.timer import CudaTimer, print_each_rank -import argparse - - -parser = argparse.ArgumentParser(description='transformer') -# model arch -parser.add_argument('--layers', type=int, default=4, - help='number encoder/decoder of layers') -parser.add_argument('--hidden-size', type=int, default=1024, - help='hidden size') -parser.add_argument('--heads', type=int, default=16, - help='number of heads') -parser.add_argument('--bs', type=int, default=8, - help='number of heads') -# parallelism -parser.add_argument('--seq', type=int, default=1, - help='sharding sequential execution') -args = parser.parse_args() -print(args) - - -cube.init() - - -class Config: - - layers = args.layers - embed_dim = args.hidden_size - num_heads = args.heads - ffn_dim = embed_dim * 4 - - seqlen = 1024 - - -def self_attention(query: torch.Tensor, - q_proj: torch.Tensor, q_bias: torch.Tensor, - k_proj: torch.Tensor, k_bias: torch.Tensor, - v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, mask=True): - num_head = h - L, N = query.size(0), query.size(1) - dim_head = q_proj.size(0) // num_head - - q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(query, k_proj, k_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(query, v_proj, v_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # attention mask - if mask: # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - mask = torch.tril(ones) - mask = mask.view(N, 1, L, L) - mask = (mask < 0.5) - attn = attn.masked_fill_(mask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj, out_bias) # L N (h d), E E -> L N E - return output - - -class MultiHeadSelfAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0, bias=True): - super().__init__() - self.inner_dim = inner_dim - self.num_heads = num_heads - self.head_dim = inner_dim // num_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # Q - self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None - # K - self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None - # V - self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) if bias else None - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) if bias else None - - def forward(self, query): - # x = self_attention( - # query, - # self.q_proj, self.q_bias, - # self.k_proj, self.k_bias, - # self.v_proj, self.v_bias, - # self.out_proj, self.out_bias, - # self.num_heads, self.scaling, self.dropout_p, mask=True - # ) - x = checkpoint.checkpoint( - self_attention, - query, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, self.out_bias, - self.num_heads, self.scaling, self.dropout_p, True - ) - return x - - -class SeqMHSA(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, - inner_dim: int, dropout: float = 0.0, - bias=True, shard_num=4): - super().__init__() - assert num_heads % shard_num == 0 - print_each_rank(f'using sequence MHSA: sharding size: {shard_num}') - self.layers = torch.nn.ModuleList( - MultiHeadSelfAttention( - embed_dim, - num_heads // shard_num, - inner_dim // shard_num, - dropout, - bias - ) for _ in range(shard_num) - ) - - def forward(self, x): - out_sum = None - for layer in self.layers: - out = layer(x) - out_sum = out if out_sum is None else out_sum + out - return out_sum - - -def feedforward(x: torch.Tensor, - proj1: torch.Tensor, proj1_bias: torch.Tensor, - proj2: torch.Tensor, proj2_bias: torch.Tensor, - dropout: float) -> torch.Tensor: - x = torch.nn.functional.linear(x, proj1, proj1_bias) - x = torch.nn.functional.gelu(x) - x = torch.nn.functional.dropout(x, dropout, True, False) - x = torch.nn.functional.linear(x, proj2, proj2_bias) - return x - - -class MLP(torch.nn.Module): - - def __init__(self, embed_dim, hidden_dim, dropout: float, bias=True): - super().__init__() - self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) - self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) - self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) - self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) - self.dropout = dropout - - def forward(self, x: torch.Tensor): - x = checkpoint.checkpoint( - feedforward, - x, self.proj1, self.proj1_bias, self.proj2, self.proj2_bias, self.dropout - ) - # x = feedforward(x, - # self.proj1, self.proj1_bias, - # self.proj2, self.proj2_bias, - # self.dropout) - return x - - -class SeqMLP(torch.nn.Module): - - def __init__(self, embed_dim, hidden_dim, dropout: float, - bias=True, shard_num = 4): - super().__init__() - print_each_rank(f'using sequence MLP: sharding size: {shard_num}') - assert hidden_dim % shard_num == 0 - self.layers = torch.nn.ModuleList( - [MLP(embed_dim, hidden_dim // shard_num, dropout, bias) for _ in range(shard_num)] - ) - - def forward(self, x: torch.Tensor): - out_sum = None - for layer in self.layers: - out = layer(x) - out_sum = out if out_sum is None else out_sum + out - return out_sum - - -class TransformerLayer(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, attn_inner_dim: int, ffn_embed_dim: int, - dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): - super().__init__() - - if args.seq > 1: - self.self_attn = SeqMHSA(embed_dim, num_heads, attn_inner_dim, atten_dropout, shard_num=args.seq) - else: - self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_inner_dim, atten_dropout) - - self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.dropout = torch.nn.Dropout(p=dropout) - - if args.seq > 1: - self.mlp = SeqMLP(embed_dim, ffn_embed_dim, activation_dropout, shard_num=args.seq) - else: - self.mlp = MLP(embed_dim, ffn_embed_dim, activation_dropout) - - self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - x = self.mlp(x) - x = self.dropout(x) - x = x + residual - return x - - -class Model(torch.nn.Module): - - def __init__(self): - super().__init__() - self.cfg = Config() - self.layers = torch.nn.ModuleList( - [TransformerLayer( - self.cfg.embed_dim, - self.cfg.num_heads, - self.cfg.embed_dim, - self.cfg.ffn_dim - ) for _ in range(self.cfg.layers)] - ) - - def forward(self, x: torch.Tensor): # L N E - - for layer in self.layers: - x = layer(x) - loss = torch.sum(x) - return loss - - -class DataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int): - - self.bs = batch_size - self.cfg = Config() - super().__init__( - shapes=([self.cfg.seqlen, batch_size, self.cfg.embed_dim],), - dtypes=(torch.float,), - batch_dims=(1,) - ) - self.samples = [self.random_sample()] - - def random_sample(self): - inputs = torch.randn( - *(self.cfg.seqlen, self.bs, self.cfg.embed_dim), - dtype=torch.float, - device=torch.cuda.current_device(), - requires_grad=True - ) - return (inputs,) - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] - - - -if __name__ == '__main__': - - dataloader = DataLoader(batch_size=args.bs) - model = Model().cuda() - model.train() - # optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - optimizer = torch.optim.SGD(model.parameters(), lr=3e-05) - - CudaTimer(enable=False).warmup() - torch.distributed.barrier() - iter_num = 10 - for step in range(iter_num): - dataloader = iter(dataloader) - if step >= 4: - CudaTimer(enable=True).start('e2e') - if step == 0: - model_summary(model, next(dataloader)) - loss = model(*next(dataloader)) - loss.backward() - optimizer.step() - optimizer.zero_grad() - - if step >= 4: - CudaTimer().stop('e2e') - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-4, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-4) - memory_summary() \ No newline at end of file diff --git a/handcraft/swin/test-2node.sh b/handcraft/swin/test-2node.sh deleted file mode 100755 index 47f3f433..00000000 --- a/handcraft/swin/test-2node.sh +++ /dev/null @@ -1,235 +0,0 @@ -# swin transformer constant head dim == 32 - -evaldir=eval/swin-coshard -mkdir -p ${evaldir} - -rm -f notify.py -wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py - -img_size=1536 -window_size=48 -bs=256 - - -test_naive_pp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - echo "testing ${gpus}-dev: Pure PP${coshard}: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-pp${gpus}.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results Swin PP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-pp${gpus}.txt" \ - --file ${evaldir}/${gpus}dev-${arch}-pp${gpus}.txt -} - -test_naive_tp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 16 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results Swin TP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt" \ - --file ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt -} - -test_naive_hybrid_tp_pp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - - # Hybrid TP-1F1B -- 16 GPU - if [ ${gpus} == 16 ] - then - echo "testing ${gpus}-dev: TP8-PP2: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 2 --tp-size 8 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp8pp2.txt - sleep 5 - killall python - sleep 5 - killall python - - echo "testing ${gpus}-dev: TP4-PP4: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 4 --tp-size 4 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp4.txt - sleep 5 - killall python - sleep 5 - killall python - fi -} - -test_coshard_pp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --use-coshard - --fp16 > ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_coshard_hybrid_tp_pp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - # Hybrid TP-1F1B -- 8 GPU - if [ ${gpus} == 16 ] - then - # echo "testing ${gpus}-dev: TP8-PP2: L${layers}E${dim}H${heads}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=${nodes} \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/swin/train.py \ - # --layers ${layers} --dim ${dim} --heads ${heads} \ - # --img-size ${img_size} --window-size ${window_size} \ - # --pp-size 2 --tp-size 8 --dp-size 1 \ - # --bs 64 --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp8pp2-coshard.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - - echo "testing coshard ${gpus}-dev: TP4-PP4: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 4 --tp-size 4 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --use-coshard --use-inner-coshard \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-tp4pp4-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results Swin TP4-PP4+Coshard | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp4pp4-coshard.txt" \ - --file ${evaldir}/${gpus}dev-${arch}-tp4pp4-coshard.txt - fi -} - -test_all() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - test_naive_pp $layers $dim $heads $gpus - test_naive_tp $layers $dim $heads $gpus - test_naive_hybrid_tp_pp $layers $dim $heads $gpus - test_coshard_pp $layers $dim $heads $gpus -} - - -# ================================================= -# selected experiments -# ================================================= - -test_naive_tp 50 1024 32 2 16 -test_coshard_hybrid_tp_pp 50 1024 32 2 16 -# test_naive_hybrid_tp_pp 50 1024 32 2 16 # -> OOM - -python scripts/keep.py --gpus 8 - -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=8 \ -# --nnodes=2 \ -# --node_rank=${NODE_RANK} \ -# --master_addr="${MASTER_IP}" \ -# --master_port=${MASTER_PORT} \ -# handcraft/swin/train.py \ -# --layers 50 --dim 1024 --heads 32 \ -# --img-size 1536 --window-size 48 \ -# --pp-size 4 --tp-size 4 --dp-size 1 \ -# --bs 256 --micro-bs 1 --use-coshard --use-inner-coshard \ -# --fp16 diff --git a/handcraft/swin/test-4node.sh b/handcraft/swin/test-4node.sh deleted file mode 100755 index ab132b77..00000000 --- a/handcraft/swin/test-4node.sh +++ /dev/null @@ -1,300 +0,0 @@ -# swin transformer constant head dim == 32 - -evaldir=eval/swin-coshard -mkdir -p ${evaldir} - - -img_size=1536 -window_size=48 -bs=256 - -rm -f notify.py -wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py - - -test_naive_pp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - echo "testing ${gpus}-dev: Pure PP${coshard}: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-pp${gpus}.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_naive_tp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 16 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results Swin TP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt" \ - --file ${evaldir}/${gpus}dev-${arch}-tp${gpus}.txt -} - -test_naive_hybrid_tp_pp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - # Hybrid TP-1F1B -- 16 GPU - if [ ${gpus} == 32 ] - then - echo "testing ${gpus}-dev: TP16-PP2: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 2 --tp-size 16 --dp-size 1 \ - --bs ${bs} --micro-bs 1 \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-tp8pp2.txt - sleep 5 - killall python - sleep 5 - killall python - - # echo "testing ${gpus}-dev: TP8-PP4: L${layers}E${dim}H${heads}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=${nodes} \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/swin/train.py \ - # --layers ${layers} --dim ${dim} --heads ${heads} \ - # --img-size ${img_size} --window-size ${window_size} \ - # --pp-size 4 --tp-size 8 --dp-size 1 \ - # --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp4.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - fi -} - -test_coshard_pp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - echo "testing ${gpus}-dev: Coshard PP: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --use-coshard - --fp16 > ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results Swin Coshard PP | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt" \ - --file ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt -} - -test_coshard_hybrid_tp_pp() -{ - layers=$1 - dim=$2 - heads=$3 - nodes=$4 - gpus=$5 - arch=L${layers}E${dim}H${heads}-${img_size} - - # Hybrid TP-1F1B -- 8 GPU - if [ ${gpus} == 32 ] - then - # echo "testing ${gpus}-dev: TP16-PP2: L${layers}E${dim}H${heads}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=${nodes} \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/swin/train.py \ - # --layers ${layers} --dim ${dim} --heads ${heads} \ - # --img-size ${img_size} --window-size ${window_size} \ - # --pp-size 2 --tp-size 16 --dp-size 1 \ - # --bs 64 --micro-bs 1 --use-coshard --use-inner-coshard \ - # --fp16 > ${evaldir}/${gpus}dev-${arch}-tp16pp2-coshard.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - - echo "testing coshard ${gpus}-dev: TP8-PP4: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=${nodes} \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 4 --tp-size 8 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --use-coshard --use-inner-coshard \ - --fp16 > ${evaldir}/${gpus}dev-${arch}-tp8pp4-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ - --msg "Test Results Swin Coshard TP8-PP4 | Node ${NODE_RANK} | ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt" \ - --file ${evaldir}/${gpus}dev-${arch}-pp${gpus}-coshard.txt - - # echo "testing coshard ${gpus}-dev: TP4-PP8: L${layers}E${dim}H${heads}" - # OMP_NUM_THREADS=4 torchrun \ - # --nproc_per_node=8 \ - # --nnodes=${nodes} \ - # --node_rank=${NODE_RANK} \ - # --master_addr="${MASTER_IP}" \ - # --master_port=${MASTER_PORT} \ - # handcraft/swin/train.py \ - # --layers ${layers} --dim ${dim} --heads ${heads} \ - # --img-size ${img_size} --window-size ${window_size} \ - # --pp-size 8 --tp-size 4 --dp-size 1 \ - # --bs ${bs} --micro-bs 1 --use-coshard --use-inner-coshard \ - # --fp16 > ${evaldir}/${gpus}dev-${arch}-tp8pp4-coshard.txt - # sleep 5 - # killall python - # sleep 5 - # killall python - fi -} - -test_all() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - test_naive_pp $layers $dim $heads $gpus - test_naive_tp $layers $dim $heads $gpus - test_naive_hybrid_tp_pp $layers $dim $heads $gpus - test_coshard_pp $layers $dim $heads $gpus -} - - -# ================================================= -# selected experiments -# ================================================= - -test_naive_tp 58 1536 32 4 32 -test_coshard_hybrid_tp_pp 58 1536 32 4 32 -# test_naive_hybrid_tp_pp 58 1536 32 4 32 # -> OOM - -python scripts/keep.py --gpus 8 - - - -# ============ exp -# Fail: 50 1280 32 | COSHARD-TP: TP4PP8 Fail TP: ? Hybrid-TP: ? -# TEST: 50 1536 32 | COSHARD-TP: ? TP4PP8 ALL Fail TP8PP4 SUC TP: ? Hybrid-TP: ? -# TEST: 58 1536 32 | COSHARD-TP: ? TP4PP8 ? Fail TP8PP4 SUC TP: ? Hybrid-TP: ? -# FAIL: 50 1536 64 | COSHARD-TP: ? TP8PP4 Fail TP: ? Hybrid-TP: ? -# FAIL: 50 2048 64 | COSHARD-TP: ? TP8PP4 Fail TP: ? Hybrid-TP: ? - - -# coshard -# layers=58 -# dim=1536 -# heads=32 -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=8 \ -# --nnodes=4 \ -# --node_rank=${NODE_RANK} \ -# --master_addr="${MASTER_IP}" \ -# --master_port=${MASTER_PORT} \ -# handcraft/swin/train.py \ -# --layers ${layers} --dim ${dim} --heads ${heads} \ -# --img-size 1536 --window-size 48 \ -# --pp-size 8 --tp-size 4 --dp-size 1 \ -# --bs 8 --micro-bs 1 --use-coshard --use-inner-coshard \ -# --fp16 -# -# # hybrid tp -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=8 \ -# --nnodes=4 \ -# --node_rank=${NODE_RANK} \ -# --master_addr="${MASTER_IP}" \ -# --master_port=${MASTER_PORT} \ -# handcraft/swin/train.py \ -# --layers ${layers} --dim ${dim} --heads ${heads} \ -# --img-size 1536 --window-size 48 \ -# --pp-size 2 --tp-size 16 --dp-size 1 \ -# --bs 4 --micro-bs 1 \ -# --fp16 -# -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=8 \ -# --nnodes=4 \ -# --node_rank=${NODE_RANK} \ -# --master_addr="${MASTER_IP}" \ -# --master_port=${MASTER_PORT} \ -# handcraft/swin/train.py \ -# --layers ${layers} --dim ${dim} --heads ${heads} \ -# --img-size 1536 --window-size 48 \ -# --pp-size 1 --tp-size 32 --dp-size 1 \ -# --bs 2 --micro-bs 1 --fp16 -# -# clear -# killall python diff --git a/handcraft/swin/test.sh b/handcraft/swin/test.sh deleted file mode 100755 index 57deff28..00000000 --- a/handcraft/swin/test.sh +++ /dev/null @@ -1,268 +0,0 @@ -# swin transformer constant head dim == 32 - -evaldir=eval/swin-coshard -mkdir -p ${evaldir} - -bs=256 -img_size=1536 -window_size=48 - - -test_naive_pp() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - - echo "testing ${gpus}-dev: Pure PP${coshard}: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_naive_tp() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - - echo "testing ${gpus}-dev: Pure TP: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 1 --tp-size ${gpus} --dp-size 1 \ - --bs 16 --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp${gpus}.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_naive_hybrid_tp_pp() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - - if [ ${gpus} == 4 ] - then - echo "testing ${gpus}-dev: TP2-PP2: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 2 --tp-size 2 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2.txt - sleep 5 - killall python - sleep 5 - killall python - fi - - # Hybrid TP-1F1B -- 8 GPU - if [ ${gpus} == 8 ] - then - echo "testing ${gpus}-dev: TP4-PP2: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 2 --tp-size 4 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2.txt - sleep 5 - killall python - sleep 5 - killall python - - echo "testing ${gpus}-dev: TP2-PP4: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 4 --tp-size 2 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4.txt - sleep 5 - killall python - sleep 5 - killall python - fi -} - -test_coshard_pp() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - - echo "testing ${gpus}-dev: Coshard PP: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size ${gpus} --tp-size 1 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --use-coshard --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-pp${gpus}-coshard.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_coshard_dp() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - - echo "testing ${gpus}-dev: Coshard DP: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 1 --tp-size 1 --dp-size ${gpus} \ - --bs ${bs} --micro-bs 1 --use-coshard \ - --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-dp${gpus}-coshard.txt - sleep 5 - killall python - sleep 5 - killall python -} - -test_coshard_hybrid_tp_pp() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - - if [ ${gpus} == 4 ] - then - echo "testing ${gpus}-dev: TP2-PP2: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 2 --tp-size 2 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --use-coshard \ - --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp2-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - fi - - # Hybrid TP-1F1B -- 8 GPU - if [ ${gpus} == 8 ] - then - echo "testing ${gpus}-dev: TP4-PP2: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 2 --tp-size 4 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --use-coshard \ - --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp4pp2-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - - echo "testing ${gpus}-dev: TP2-PP4: L${layers}E${dim}H${heads}" - OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=${gpus} \ - --nnodes=1 \ - handcraft/swin/train.py \ - --layers ${layers} --dim ${dim} --heads ${heads} \ - --img-size ${img_size} --window-size ${window_size} \ - --pp-size 4 --tp-size 2 --dp-size 1 \ - --bs ${bs} --micro-bs 1 --use-coshard \ - --fp16 > ${evaldir}/${gpus}dev-L${layers}E${dim}H${heads}-${img_size}-tp2pp4-coshard.txt - sleep 5 - killall python - sleep 5 - killall python - fi -} - -test_all() -{ - layers=$1 - dim=$2 - heads=$3 - gpus=$4 - test_naive_pp $layers $dim $heads $gpus - test_naive_tp $layers $dim $heads $gpus - test_naive_hybrid_tp_pp $layers $dim $heads $gpus - test_coshard_pp $layers $dim $heads $gpus -} - -# ================================================= -# selected experiments -# ================================================= -test_naive_tp 6 96 3 1 -test_naive_tp 10 128 4 1 -test_naive_tp 14 192 6 1 -# test_naive_tp 18 256 8 1 # --> OOM -# test_naive_tp 26 512 16 1 # --> OOM -test_coshard_pp 6 96 3 1 -test_coshard_pp 10 128 4 1 -test_coshard_pp 14 192 6 1 -test_coshard_pp 18 256 8 1 -test_coshard_pp 26 512 16 1 - -test_coshard_dp 18 256 8 2 -test_naive_tp 18 256 8 2 -test_naive_hybrid_tp_pp 18 256 8 2 - -test_coshard_pp 26 512 16 4 -test_naive_tp 26 512 16 4 -# test_naive_hybrid_tp_pp 26 512 16 4 # --> OOM - -test_coshard_pp 42 768 24 8 -test_naive_tp 42 768 24 8 -# test_naive_hybrid_tp_pp 42 768 24 8 # --> OOM - -python scripts/keep.py --gpus 8 - - -# for test -# coshard-pp -# gpus=4 -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=${gpus} \ -# --nnodes=1 \ -# handcraft/swin/train.py \ -# --layers 26 --dim 512 --heads 16 \ -# --img-size 1536 --window-size 48 \ -# --pp-size ${gpus} --tp-size 1 --dp-size 1 \ -# --bs 16 --micro-bs 1 --use-coshard --fp16 \ No newline at end of file diff --git a/handcraft/swin/train.py b/handcraft/swin/train.py deleted file mode 100644 index d7f70506..00000000 --- a/handcraft/swin/train.py +++ /dev/null @@ -1,1045 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/swin/train.py \ - --bs 1 --micro-bs 1 --fp16 \ - --dp-size 1 --pp-size 1 --tp-size 1 \ - --layers 10 --dim 128 --heads 4 -""" - -import torch -import torch.nn as nn -import torch.utils.checkpoint as checkpoint - -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary, model_summary -from cube.runtime.adapter.reducer import Reducer -from cube.runtime.device import DeviceGroup -from handcraft.module.distnn import IdentityAllreduce, AllReduceIdentity, AllGatherSplit -from handcraft.module.schedule import schedule_1f1b -from handcraft.module.stage import PipeStage, layer_division -from handcraft.swin.utils import create_position_bias, create_position_index, trunc_normal_, window_partition, window_reverse, DropPath - -import argparse - - -parser = argparse.ArgumentParser(description='swin') - -# model arch -parser.add_argument('--layers', type=int, default=18, - help='third stage layer depths. default large') -parser.add_argument('--dim', type=int, default=192, - help='input channel of first stage') -parser.add_argument('--heads', type=int, default=6, - help='head num of first stage') -# data -parser.add_argument('--img-size', type=int, default=1536, - help='image size, can be 224, 640, 1536') -parser.add_argument('--window-size', type=int, default=48, - help='image size, can be 7, 40, 48') -# training -parser.add_argument('--bs', type=int, default=256, - help='batch size') -parser.add_argument('--micro-bs', type=int, default=1, - help='micro batch size') -parser.add_argument('--pp-size', type=int, default=1, - help='pipeline parallelism size') -parser.add_argument('--tp-size', type=int, default=1, - help='tensor parallelism size') -parser.add_argument('--dp-size', type=int, default=1, - help='data parallelism size') -parser.add_argument('--schedule', type=str, default='1f1b', choices=['1f1b'], - help='scheduling algorithm') -parser.add_argument('--use-coshard', action='store_true', default=False, - help='enable this will split head but co-locate them with re-compute') -parser.add_argument('--use-inner-coshard', action='store_true', default=False, - help='enable this will shard bmm in attention of q @ k') -parser.add_argument('--fp16', action='store_true', default=False) - -args = parser.parse_args() -print(args) - -_tp_group = -1 - -_dp_group = -1 -_dp_reducer = None - -_pp_group = -1 -_pp_global_ranks = () -_schedule = schedule_1f1b -_layer_divisions = [] - -cube.init() -dp_ranks, pp_ranks, tp_ranks= DeviceGroup().create_hybrid( - [args.dp_size, args.pp_size, args.tp_size] -) - -if len(dp_ranks) != 1: - print_each_rank(f'initializing dp ranks: {dp_ranks}') - _dp_group = DeviceGroup().get_group(dp_ranks) - _dp_reducer = Reducer(dp_ranks) - -if len(tp_ranks) != 1: - print_each_rank(f'initializing tp ranks: {tp_ranks}') - _tp_group = DeviceGroup().get_group(tp_ranks) - -if len(pp_ranks) != 1: - print_each_rank(f'initializing pp ranks: {pp_ranks}') - _pp_group = DeviceGroup().get_group(pp_ranks) - _pp_global_ranks = tuple(pp_ranks) - - # layer division - nlayers = 2 + 2 + args.layers + 2 + 3 # 3 is patch merging layers - # metrics for V100-32GB-PCIe - if args.dim == 256: # OK! - times = ([109.93] * 2 + [0]) + \ - ([60.34] * 2 + [0]) + \ - ([43.18] * args.layers + [0]) + \ - ([27.51] * 2) - elif args.dim == 512: # OK! - times = ([255.10] * 2 + [0]) + \ - ([139.92] * 2 + [0]) + \ - ([90.98] * args.layers + [0]) + \ - ([63.78] * 2) - elif args.dim == 768: # OK! - times = ([440.5] * 2 + [0]) + \ - ([241.4] * 2 + [0]) + \ - ([145.7] * args.layers + [0]) + \ - ([108.9] * 2) - elif args.dim >= 1024: # TP needed - times = ([255.10] * 2 + [0]) + \ - ([139.92] * 2 + [0]) + \ - ([90.98] * args.layers + [0]) + \ - ([63.78] * 2) - else: - print_each_rank('WARNING: NO Metric Logged!!') - times = ([1] * 2 + [0]) + \ - ([1] * 2 + [0]) + \ - ([1] * args.layers + [0]) + \ - ([1] * 2) - num_stages = len(pp_ranks) - _layer_divisions = layer_division(times, num_stages) - # specific rules for stage division in order to fit in memory - if args.dim == 1024 and args.tp_size == 4: - # first stage - if _layer_divisions[0][1] > 8: - remain_times = times[8:] - _layer_divisions = [(0, 8)] + layer_division(remain_times, num_stages-1, start_id=8) - if args.dim == 1536 and args.tp_size == 8: - limits = [None] * args.pp_size - limits[0] = 4 - _layer_divisions = layer_division(times, num_stages, limits=limits) - # if args.dim == 1536 and args.tp_size == 4: - # limits = [None] * args.pp_size - # limits[0] = 1 - # limits[1] = 3 - # limits[2] = 9 - # _layer_divisions = layer_division(times, num_stages, limits=limits) - -else: - _layer_divisions = [(0, 2 + 2 + args.layers + 2 + 3)] -print_each_rank(f'layer divisions: {_layer_divisions}') - - -class Config: - - embed_dim = args.dim - depths = [2, 2, args.layers, 2] - num_heads = [args.heads, args.heads * 2, args.heads * 4, args.heads * 8] - - mlp_ratio = 4 - qkv_bias = True - qk_scale = None - drop_path_rate = 0.2 - drop_rate = 0.2 - - img_size = args.img_size - window_size = args.window_size - num_classes = 1000 - - -class Mlp(torch.nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - self._tp_group = _tp_group - self._tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.in_features = in_features - self.hidden_features = hidden_features // self._tp_size - self.fc1 = nn.Linear(in_features, hidden_features // self._tp_size) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features // self._tp_size, out_features) - self.drop = nn.Dropout(drop) - - def forward_(self, x): - if self._tp_size > 1: - x = IdentityAllreduce.apply(x, self._tp_group) - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - if self._tp_size > 1: - x = AllReduceIdentity.apply(x, self._tp_group) - return x - - def forward(self, x, recompute=True): - if recompute: - x = checkpoint.checkpoint(self.forward_, x) - else: - x = self.forward_(x) - return x - - def flops(self, seqlen: int): - mlp_flops = dict( - fc1=seqlen * self.in_features * self.hidden_features, - act=8 * seqlen * self.hidden_features, - drop=seqlen * self.hidden_features, - fc2=seqlen * self.hidden_features * self.in_features, - final_drop=seqlen * self.in_features, - ) - return sum(mlp_flops.values()) - - -class SeqMlp(torch.nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, - act_layer=nn.GELU, drop=0., - coshard=1): - super().__init__() - self._tp_group = _tp_group - self._tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.coshard = coshard - assert hidden_features is not None - assert hidden_features % coshard == 0 - self.mlps = torch.nn.ModuleList( - [Mlp(in_features, hidden_features // coshard, out_features, act_layer, drop) for _ in range(coshard)] - ) - # remove tp communication inside each mlp as it will be - # done outside here - for mlp in self.mlps: - mlp._tp_size = 1 - - def forward(self, x, recompute=True): - if self._tp_size > 1: - x = IdentityAllreduce.apply(x, self._tp_group) - - outs = None - for mlp in self.mlps: - x_out = mlp(x, recompute=recompute) - outs = x_out if outs is None else outs + x_out - - if self._tp_size > 1: - outs = AllReduceIdentity.apply(outs, self._tp_group) - return outs - - def flops(self, seqlen: int): - return sum([mlp.flops(seqlen) for mlp in self.mlps]) - - -class WindowAttention(torch.nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, inner_dim, window_size, num_heads, - qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., - position_index=True): - - super().__init__() - self._tp_group = _tp_group - self._tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - - self.dim = dim - self.window_size = window_size # Wh, Ww - self.head_dim = inner_dim // num_heads - assert num_heads % self._tp_size == 0 - self.num_heads = num_heads // self._tp_size - self.scale = qk_scale or self.head_dim ** -0.5 - - # define define a parameter table of relative position bias - table = create_position_bias(self.window_size, self.num_heads) - self.relative_position_bias_table = table - if position_index: - index = create_position_index(window_size, cuda=False) - self.register_buffer("relative_position_index", index) - else: - self.relative_position_index = None - - self.qkv = nn.Linear(dim, inner_dim // self._tp_size * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(inner_dim // self._tp_size, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) - - def forward_(self, x, mask=None, position_index=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - assert (self.relative_position_index is None) ^ (position_index is None) - if position_index is not None: - relative_position_index = position_index - else: - relative_position_index = self.relative_position_index - - if self._tp_size > 1: - x = IdentityAllreduce.apply(x, self._tp_group) - - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - - k = k.transpose(-2, -1) - # inner coshard by splitting windows - if args.use_inner_coshard and (B_ == 64 or B_ == 16): - chunk_num = B_ // 4 - attn = [] - for shard_q, shard_k in zip(torch.chunk(q, chunks=chunk_num, dim=0), torch.chunk(k, chunks=chunk_num, dim=0)): - attn.append(shard_q @ shard_k) - attn = torch.concat(tuple(attn), dim=0) - else: - attn = (q @ k) - - relative_position_bias = self.relative_position_bias_table[relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, -1) - x = self.proj(x) - x = self.proj_drop(x) - - if self._tp_size > 1: - x = AllReduceIdentity.apply(x, self._tp_group) - - return x - - def forward(self, x, mask=None, position_index=None, recompute=True): - if recompute: - x = checkpoint.checkpoint(self.forward_, x, mask, position_index) - else: - x = self.forward_(x, mask, position_index) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, seqlen: int): - # calculate flops for one window - # seqlen is window size * window size - attn_flops = dict( - kqv=3 * seqlen * self.dim * self.head_dim * self.num_heads, - kqv_bias= 3 * seqlen * self.head_dim * self.num_heads, - q_scale=seqlen * self.num_heads * self.head_dim, - attn_score=self.num_heads * seqlen * self.head_dim * seqlen, # q @ k - position_index=self.num_heads * seqlen * seqlen, - attn_softmax=5 * self.num_heads * seqlen * seqlen, - attn_dropout=self.num_heads * seqlen * seqlen, - attn_output=self.num_heads * seqlen * seqlen * self.head_dim, # attn @ v - out_proj=seqlen * self.num_heads * self.head_dim * self.dim # self.proj(x) - ) - return sum(attn_flops.values()) - - -class SeqWindowAttention(torch.nn.Module): - - def __init__(self, dim, inner_dim, window_size, num_heads, - qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., - coshard=1): - super().__init__() - self._tp_group = _tp_group - self._tp_size = 1 if _tp_group == -1 else torch.distributed.get_world_size(_tp_group) - assert (num_heads // args.tp_size) % coshard == 0 - # only coshard num heads of first two stages - self.coshard = coshard - self.attns = torch.nn.ModuleList( - [WindowAttention( - dim, inner_dim // self.coshard, window_size, num_heads // self.coshard, - qkv_bias, qk_scale, attn_drop, proj_drop, False) for _ in range(self.coshard)] - ) - # 1) remove communication inside each attention as it will be - # done outside here - # 2) share same relative position index - index = create_position_index(window_size, cuda=False) - self.register_buffer("relative_position_index", index) - for attn in self.attns: - attn._tp_size = 1 - - def forward(self, x, mask=None, recompute=True): - - # ===> sharding from both window and heads - # B = x.size(0) - # if B % 2 == 0: - # xs = torch.chunk(x, 2, dim=0) - # masks = torch.chunk(mask, 2, dim=0) if mask is not None else (None,) * 2 - # else: - # xs = (x,) - # masks = (mask,) - # outs = [] - # for bid, (cx, cmask) in enumerate(zip(xs, masks)): - # for attn in self.attns: - # cx_out = attn(cx, cmask, recompute) - # if len(outs) < bid + 1: - # outs.append(cx_out) - # else: - # outs[bid] = outs[bid] + cx_out - # outs = torch.concat(tuple(outs), dim=0) - # return outs - - # ===> sharding only from heads - if self._tp_size > 1: - x = IdentityAllreduce.apply(x, self._tp_group) - - outs = None - for attn in self.attns: - x_out = attn(x, mask, self.relative_position_index, recompute) - outs = x_out if outs is None else outs + x_out - - if self._tp_size > 1: - outs = AllReduceIdentity.apply(outs, self._tp_group) - return outs - - def flops(self, seqlen: int): - flops = 0 - for attn in self.attns: - flops += attn.flops(seqlen) - return flops - - -class SwinTransformerBlock(PipeStage): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_coshard=False, layer_id=None): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - if not use_coshard or layer_id in [2,3]: - self.attn = WindowAttention( - dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - else: - coshard = num_heads // args.tp_size - coshard = coshard // 2 if layer_id > 0 else coshard - print(f'rank [{torch.distributed.get_rank()}]: Swin-stage-{layer_id} using coshard {coshard}') - self.attn = SeqWindowAttention( - dim, dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, coshard=coshard) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - if not use_coshard or layer_id in [2,3]: - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - else: - coshard = num_heads // args.tp_size - coshard = coshard // 2 if layer_id > 0 else coshard - self.mlp = SeqMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, coshard=coshard) - - H, W = self.input_resolution - if self.shift_size > 0: - # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - assert args.bs // (args.micro_bs * args.dp_size) != 0 - self.inputs_info = ( - ((args.micro_bs, H * W, self.dim),), - (torch.float32 if not args.fp16 else torch.float16,) - ) - self.outputs_info = ( - ((args.micro_bs, H * W, self.dim),), - (torch.float32 if not args.fp16 else torch.float16,) - ) - self.layer_id = layer_id - self.inner_recompute = False if not use_coshard else layer_id in [0,1] - - def forward_(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask, recompute=self.inner_recompute) # nW*B, window_size*window_size, C - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x), recompute=self.inner_recompute)) - return x - - def forward(self, x): - CudaTimer().start(f'layer{self.layer_id}') - # layer-wise recompute - if not self.inner_recompute: - x = checkpoint.checkpoint(self.forward_, x) - # attention/mlp-wise recompute - else: - x = self.forward_(x) - CudaTimer().stop(f'layer{self.layer_id}') - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - H, W = self.input_resolution - num_windows = H * W / self.window_size / self.window_size - block_flops = dict( - norm1=5 * H * W * self.dim, - roll1=0, # ignore - window_partition=0, # ignore - attn=num_windows * self.attn.flops(self.window_size * self.window_size), - roll2=0, # ignore - attn_dropout=H * W * self.dim, - atnn_residual=H * W * self.dim, - norm2=5 * H * W * self.dim, - mlp=self.mlp.flops(H * W), - mlp_drop=H * W * self.dim, - mlp_residual=H * W * self.dim, - ) - return sum(block_flops.values()) - - -class PatchMerging(PipeStage): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - H, W = self.input_resolution - assert args.bs // (args.micro_bs * args.dp_size) != 0 - self.inputs_info = ( - ((args.micro_bs, H * W, self.dim),), - (torch.float32 if not args.fp16 else torch.float16,) - ) - self.outputs_info = ( - ((args.micro_bs, (H // 2) * (W // 2), self.dim * 2),), - (torch.float32 if not args.fp16 else torch.float16,) - ) - - def forward_(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def forward(self, x): - x = checkpoint.checkpoint(self.forward_, x) - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -def create_basic_layter(dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, - layer_id=None, start_id=0): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - # swin transformer layers - blocks = [SwinTransformerBlock( - dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if ((i + start_id) % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, use_coshard=args.use_coshard, layer_id=layer_id) - for i in range(depth)] - # patch merging layer - if downsample is not None: - downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - blocks.append(downsample) - return blocks - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(PipeStage): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, **kwargs): - super().__init__() - self.set_pipeline(_pp_global_ranks) - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - self.patches_resolution = (img_size // patch_size, img_size // patch_size) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - # build layers - total_layers = [3, 3, depths[2] + 1, 2] - # pipeline split layers - start, end = _layer_divisions[self.stage_local_rank] - layers = [] - for i_layer in range(self.num_layers): - layer_start = sum(total_layers[:i_layer]) - layer_end = sum(total_layers[:i_layer+1]) - if max(layer_start, start) >= min(layer_end, end): - continue - have_downsample = start < layer_end and layer_end <= end and i_layer < self.num_layers - 1 - layer_start_id = max(layer_start, start) - layer_start - layer_num = min(layer_end, end) - max(layer_start, start) - layer_num = layer_num if not have_downsample else layer_num - 1 - assert layer_num >= 0 - blocks = create_basic_layter(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(self.patches_resolution[0] // (2 ** i_layer), - self.patches_resolution[1] // (2 ** i_layer)), - depth=layer_num, - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if have_downsample else None, - layer_id=i_layer, start_id=layer_start_id) - layers += blocks - assert (end - start) == len(layers), f"layer num not equal, [{start}, {end}) != {len(layers)} " - torch.distributed.barrier() - self.layers = torch.nn.ModuleList(layers) - print_each_rank(f'initialized {len(self.layers)} layers ranging from [{start}, {end})') - - self.inputs_info = self.layers[0].inputs_info - self.outputs_info = self.layers[-1].outputs_info - - # preprocess - if self.is_first_stage: - print(f'rank [{torch.distributed.get_rank()}]: initializing pre-process...') - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - # dropout - self.pos_drop = nn.Dropout(p=drop_rate) - - self.inputs_info = ((), ()) - - # post-process - if self.is_last_stage: - print(f'rank [{torch.distributed.get_rank()}]: initializing post-process...') - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.criterion = nn.CrossEntropyLoss() - - self.outputs_info = ( - (1,), - torch.float32 if args.fp16 else torch.float16 - ) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def forward(self, x = None): - if self.is_first_stage: - CudaTimer().start('pre-process') - x, _ = self.data - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - CudaTimer().stop('pre-process') - - for layer in self.layers: - x = layer(x) - - if self.is_last_stage: - CudaTimer().start('post-process') - _, labels = self.data - - def _post_process(x): - x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - x = self.head(x) - return x - - x = checkpoint.checkpoint(_post_process, x) - x = self.criterion(x, labels) - CudaTimer().stop('post-process') - - return x - - def flops(self): - flops = 0 - if self.is_first_stage: - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - if self.is_last_stage: - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes - return flops - - -class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int, img_size: int, num_classes: int): - - self.bs = batch_size - self.img_size = img_size - self.num_classes = num_classes - super().__init__( - shapes=([batch_size, 3, img_size, img_size,], - [batch_size], - ), - dtypes=(torch.float if not args.fp16 else torch.float16, torch.int), - batch_dims=(0, 0) - ) - self.samples = [self.random_sample()] - - def random_sample(self): - img = torch.rand( - *(self.bs, 3, self.img_size, self.img_size), - dtype=torch.float if not args.fp16 else torch.float16, - device=torch.cuda.current_device() - ) - labels = torch.randint( - 0, self.num_classes, - size=(self.bs,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - return (img, labels) - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] - - -def train(): - - cfg = Config() - model = SwinTransformer(img_size=cfg.img_size, - patch_size=4, - in_chans=3, - num_classes=cfg.num_classes, - embed_dim=cfg.embed_dim, - depths=cfg.depths, - num_heads=cfg.num_heads, - window_size=cfg.window_size, - mlp_ratio=cfg.mlp_ratio, - qkv_bias=cfg.qkv_bias, - qk_scale=cfg.qk_scale, - drop_rate=cfg.drop_rate, - drop_path_rate=cfg.drop_path_rate, - ape=False, - patch_norm=True, - use_checkpoint=False) - nparams = sum([param.numel() for param in model.parameters()]) - forward_flops = model.flops() - tflops = forward_flops * 4 / (1e12) # forward + recompute-forward + backward (2x) - print_each_rank(f'Model Params#: {nparams} | TFlops: {tflops}') - if args.fp16: - model = model.half() - model = model.cuda() - dataloader = ImageDataLoader(args.micro_bs, cfg.img_size, cfg.num_classes) - if _dp_reducer is not None: - for param in model.parameters(): - _dp_reducer.add_param(param) - optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) - - print_each_rank('model weight consumpition:', rank_only=0) - memory_summary() - - def train_iter(model, dataloader): - num_microbatch = args.bs // (args.dp_size * args.micro_bs) - if _pp_group != -1: - _schedule(model, dataloader, num_microbatch) - else: - for _ in range(num_microbatch): - model.data = next(dataloader) - loss = model() - loss.backward() - if _dp_reducer is not None: - _dp_reducer.allreduce() - - CudaTimer(enable=False) - iter_num = 6 - for step in range(iter_num): - - if step >= 2: - CudaTimer(enable=True).start('e2e') - - # training - train_iter(model, dataloader) - - # if step == 0: - # print_each_rank('passed first iteration', rank_only=0) - # print_each_rank('memory consumption before optimizer:', rank_only=0) - # memory_summary() - - optimizer.step() - optimizer.zero_grad() - - if step >= 2: - CudaTimer().stop('e2e') - - torch.cuda.empty_cache() - torch.distributed.barrier() - - if step == 0: - print_each_rank('memory consumption after optimizer:', rank_only=0) - memory_summary() - - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-2, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-2) - memory_summary() - -train() diff --git a/handcraft/swin/utils.py b/handcraft/swin/utils.py deleted file mode 100644 index 2472b44d..00000000 --- a/handcraft/swin/utils.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Tuple -import warnings - -import torch -import math - - -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2) - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class DropPath(torch.nn.Module): - - def __init__(self, drop_prob: float): - super().__init__() - self.drop_prob = drop_prob - - def forward(self, x): - if self.drop_prob == 0. or not self.training: - return x - keep_prob = 1 - self.drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -def create_position_bias(window_size: Tuple[int, int], num_heads: int): - relative_position_bias_table = torch.nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - trunc_normal_(relative_position_bias_table, std=.02) - return relative_position_bias_table - - -def create_position_index(window_size: Tuple[int, int], cuda=False): - # get pair-wise relative position index for each token inside the window - with torch.no_grad(): - if cuda: - coords_h = torch.arange(window_size[0], device=torch.cuda.current_device()) - coords_w = torch.arange(window_size[1], device=torch.cuda.current_device()) - else: - coords_h = torch.arange(window_size[0]) - coords_w = torch.arange(window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - return relative_position_index diff --git a/handcraft/textnas/dataloader.py b/handcraft/textnas/dataloader.py deleted file mode 100644 index c724c192..00000000 --- a/handcraft/textnas/dataloader.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. -""" -For test: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - handcraft/textnas/dataloader.py -""" - - -import os -import numpy as np -import torch -from torch.utils import data -import threading -from transformers import BertModel, BertTokenizer -import collections -import time - -import cube -from cube.runtime.device import DeviceGroup -from cube.profiler import CudaTimer - - -def read_sst_2(data_path='./SST-2', max_input_length=64, min_count=1): - sentences, labels = [], [] - assert os.path.exists(data_path) - dataset_train = os.path.join(data_path, 'train.tsv') - with open(dataset_train, 'r') as f: - lines = f.readlines()[1:] # skip first - for line in lines: - sentence, label = line.split('\t') - sentence = sentence.strip() - label = int(label.strip()) - sentences.append(sentence) - labels.append(label) - return sentences, labels - - -class SSTDataset(data.Dataset): - def __init__(self): - self.sents, self.labels = read_sst_2() - print(f'> loaded SST dataset: train length: {len(self.sents)}') - - def __getitem__(self, index): - return self.sents[index], self.labels[index] - - def __len__(self): - return len(self.sents) - - -class SharedDataLoader(object): - def __init__(self, batch_size, replicate=True, **kwargs): - self.replicate = replicate - self.has_model = self.replicate or (DeviceGroup().rank == 0) - - dataset = SSTDataset() - dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, **kwargs) - self.dataloader = dataloader - - if self.has_model: - self.model = BertModel.from_pretrained('bert-base-uncased').cuda() - self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - else: - self.model = None - self.tokenizer = None - - self.max_queue = 32 - self.input_size = (batch_size, 64, 768) - self.batch_size = batch_size - self.length = len(dataset) // batch_size - - def __iter__(self): - self.counter = 0 - self.shared_queue = collections.deque() - self._dataloader_iter = iter(self.dataloader) - if self.has_model and (not self.replicate): - # sharing mode: all models share the same dataloader - print('starting pipeline to produce datas') - self.workers = threading.Thread(target=self._pipe).start() - return self - - def __len__(self): - return len(self.dataloader) - - def get_data(self): - if self.replicate: - CudaTimer().start('bert') - text, label = next(self._dataloader_iter) - text = torch.tensor([self.tokenizer.encode(t, max_length=64, padding='max_length') for t in text]).cuda() - mask = text > 0 - with torch.no_grad(): - output = self.model(text)['last_hidden_state'] - label = label.cuda() - if self.replicate: - CudaTimer().stop('bert') - return output, mask, label - - def _pipe(self): - while True: - while len(self.shared_queue) >= self.max_queue: - time.sleep(0.2) - # print('sample data...') - datas = self.get_data() - # print(datas) - self.shared_queue.append(datas) - - def __next__(self): - self.counter += 1 - if self.counter >= len(self): - raise StopIteration - if self.replicate: - # replicate mode: each gpu has a dataloader - text, masks, labels = self.get_data() - else: - # sharing mode: all models share the same dataloader - if self.has_model: - while not self.shared_queue: - time.sleep(0.1) - text, masks, labels = self.shared_queue.popleft() - assert torch.is_tensor(text) - masks = masks.float() - else: - text = torch.zeros(self.input_size, dtype=torch.float, device="cuda") - labels = torch.zeros(self.batch_size, dtype=torch.long, device="cuda") - masks = torch.zeros(self.input_size[:2], dtype=torch.float, device="cuda") - CudaTimer().start('get_data') - torch.distributed.broadcast(text, 0) - torch.distributed.broadcast(labels, 0) - torch.distributed.broadcast(masks, 0) - CudaTimer().stop('get_data') - masks = masks.bool() - return text, masks, labels - - -if __name__ == '__main__': - - cube.init() - dataloader = SharedDataLoader(32, replicate=True) - for datas in dataloader: - print(f'get data: {[data.size() for data in datas]}') - input('>>>') diff --git a/handcraft/textnas/dataset.sh b/handcraft/textnas/dataset.sh deleted file mode 100755 index 2c76689e..00000000 --- a/handcraft/textnas/dataset.sh +++ /dev/null @@ -1,6 +0,0 @@ - -echo 'downloading SST-2 dataset...' -wget https://dl.fbaipublicfiles.com/glue/data/SST-2.zip -unzip SST-2.zip -rm SST-2.zip - diff --git a/handcraft/textnas/ops.py b/handcraft/textnas/ops.py deleted file mode 100644 index 4e88943b..00000000 --- a/handcraft/textnas/ops.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import torch -import torch.nn.functional as F -from torch import nn - - -INF = 1E10 -EPS = 1E-12 - -def get_length(mask): - length = torch.sum(mask, 1) - length = length.long() - return length - - -class Mask(nn.Module): - - def forward(self, seq, mask): - # seq: (N, C, L) - # mask: (N, L) - seq_mask = torch.unsqueeze(mask, 2) - seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2) - return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq)) - - -class BatchNorm(nn.Module): - def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True): - super(BatchNorm, self).__init__() - self.mask_opt = Mask() - self.mask_opt1 = Mask() - self.pre_mask = pre_mask - self.post_mask = post_mask - self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine) - - def forward(self, seq, mask): - if self.pre_mask: - seq = self.mask_opt(seq, mask) - seq = self.bn(seq) - if self.post_mask: - seq = self.mask_opt1(seq, mask) - return seq - - -class ConvBN(nn.Module): - - def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob, - pre_mask, post_mask, with_bn=True, with_relu=True): - super(ConvBN, self).__init__() - self.mask_opt = Mask() - self.pre_mask = pre_mask - self.post_mask = post_mask - self.with_bn = with_bn - self.with_relu = with_relu - self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True, - padding=(kernal_size - 1) // 2) - self.dropout = nn.Dropout(p=(1 - cnn_keep_prob)) - - if with_bn: - self.bn = BatchNorm(out_channels, not post_mask, True) - - if with_relu: - self.relu = nn.ReLU() - - def forward(self, seq, mask): - if self.pre_mask: - seq = self.mask_opt(seq, mask) - seq = self.conv(seq) - if self.post_mask: - seq = self.mask_opt(seq, mask) - if self.with_bn: - seq = self.bn(seq, mask) - if self.with_relu: - seq = self.relu(seq) - seq = self.dropout(seq) - return seq - - -class AvgPool(nn.Module): - def __init__(self, kernal_size, pre_mask, post_mask): - super(AvgPool, self).__init__() - self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2) - self.pre_mask = pre_mask - self.post_mask = post_mask - self.mask_opt = Mask() - - def forward(self, seq, mask): - if self.pre_mask: - seq = self.mask_opt(seq, mask) - seq = self.avg_pool(seq) - if self.post_mask: - seq = self.mask_opt(seq, mask) - return seq - - -class MaxPool(nn.Module): - def __init__(self, kernal_size, pre_mask, post_mask): - super(MaxPool, self).__init__() - self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2) - self.pre_mask = pre_mask - self.post_mask = post_mask - self.mask_opt = Mask() - - def forward(self, seq, mask): - if self.pre_mask: - seq = self.mask_opt(seq, mask) - seq = self.max_pool(seq) - if self.post_mask: - seq = self.mask_opt(seq, mask) - return seq - - -class Attention(nn.Module): - def __init__(self, num_units, num_heads, keep_prob, is_mask): - super(Attention, self).__init__() - self.num_units = num_units - self.num_heads = num_heads - self.keep_prob = keep_prob - self.is_mask = is_mask - - self.linear_q = nn.Linear(num_units, num_units) - self.linear_k = nn.Linear(num_units, num_units) - self.linear_v = nn.Linear(num_units, num_units) - - self.bn = BatchNorm(num_units, True, is_mask) - self.dropout = nn.Dropout(p=1 - self.keep_prob) - - def forward(self, seq, mask): - in_c = seq.size()[1] - seq = torch.transpose(seq, 1, 2) # (N, L, C) - queries = seq - keys = seq - num_heads = self.num_heads - - # T_q = T_k = L - Q = F.relu(self.linear_q(seq)) # (N, T_q, C) - K = F.relu(self.linear_k(seq)) # (N, T_k, C) - V = F.relu(self.linear_v(seq)) # (N, T_k, C) - - # Split and concat - Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0) # (h*N, T_q, C/h) - K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h) - V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h) - - # Multiplication - outputs = torch.matmul(Q_, K_.transpose(1, 2)) # (h*N, T_q, T_k) - # Scale - outputs = outputs / (K_.size()[-1] ** 0.5) - # Key Masking - key_masks = mask.repeat(num_heads, 1) # (h*N, T_k) - key_masks = torch.unsqueeze(key_masks, 1) # (h*N, 1, T_k) - key_masks = key_masks.repeat(1, queries.size()[1], 1) # (h*N, T_q, T_k) - - paddings = torch.ones_like(outputs) * (-INF) # extremely small value - outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs) - - query_masks = mask.repeat(num_heads, 1) # (h*N, T_q) - query_masks = torch.unsqueeze(query_masks, -1) # (h*N, T_q, 1) - query_masks = query_masks.repeat(1, 1, keys.size()[1]).float() # (h*N, T_q, T_k) - - att_scores = F.softmax(outputs, dim=-1) * query_masks # (h*N, T_q, T_k) - att_scores = self.dropout(att_scores) - - # Weighted sum - x_outputs = torch.matmul(att_scores, V_) # (h*N, T_q, C/h) - # Restore shape - x_outputs = torch.cat( - torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0), - dim=2) # (N, T_q, C) - - x = torch.transpose(x_outputs, 1, 2) # (N, C, L) - x = self.bn(x, mask) - - return x - - -class RNN(nn.Module): - def __init__(self, hidden_size, output_keep_prob): - super(RNN, self).__init__() - self.hidden_size = hidden_size - self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) - self.output_keep_prob = output_keep_prob - - self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob)) - - def forward(self, seq, mask): - # seq: (N, C, L) - # mask: (N, L) - max_len = seq.size()[2] - length = get_length(mask) - seq = torch.transpose(seq, 1, 2) # to (N, L, C) - packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True, - enforce_sorted=False) - outputs, _ = self.bid_rnn(packed_seq) - outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, - total_length=max_len)[0] - outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2) # (N, L, C) - outputs = self.out_dropout(outputs) # output dropout - return torch.transpose(outputs, 1, 2) # back to: (N, C, L) - - -class LinearCombine(nn.Module): - def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False): - super(LinearCombine, self).__init__() - self.layers_num = layers_num - self.trainable = trainable - self.input_aware = input_aware - self.word_level = word_level - - if input_aware: - raise NotImplementedError("Input aware is not supported.") - self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num), - requires_grad=trainable) - - def forward(self, seq): - nw = F.softmax(self.w, dim=0) - seq = torch.mul(seq, nw) - seq = torch.sum(seq, dim=0) - return seq - - -class GlobalAvgPool(nn.Module): - def forward(self, x, mask): - x = torch.sum(x, 2) - length = torch.sum(mask, 1, keepdim=True).float() - length += torch.eq(length, 0.0).float() * EPS - length = length.repeat(1, x.size()[1]) - x /= length - return x - - -class GlobalMaxPool(nn.Module): - def forward(self, x, mask): - mask = torch.eq(mask.float(), 0.0).long() - mask = torch.unsqueeze(mask, dim=1).repeat(1, x.size()[1], 1) - mask *= -INF - x += mask - x, _ = torch.max(x + mask, 2) - return x \ No newline at end of file diff --git a/handcraft/textnas/train.py b/handcraft/textnas/train.py deleted file mode 100644 index 44b6e89d..00000000 --- a/handcraft/textnas/train.py +++ /dev/null @@ -1,280 +0,0 @@ -""" -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - handcraft/textnas/train.py \ - --bs 128 --models 12 --schedule pipe -""" - -import numpy as np -import torch -import torch.nn as nn -import argparse - -import cube -from cube.runtime.device import DeviceGroup -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.profiler.memory import memory_summary - -from handcraft.textnas.ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm, GlobalMaxPool, GlobalAvgPool -from handcraft.textnas.dataloader import SharedDataLoader -from handcraft.module.stage import layer_division - - -cube.init() - -parser = argparse.ArgumentParser(description='textnas') -parser.add_argument('--schedule', type=str, default='replicate', choices=['replicate', 'pipe'], - help='scheduling algorithm. model: train model with replicated dataloader. pipe: train model with shared dataloader') -parser.add_argument('--models', type=int, default=1, - help='number of models to be trained in total') -parser.add_argument('--bs', type=int, default=128, - help='number of micro batch (default: paper setting)') -parser.add_argument('--non-uniform', action='store_true', default=False, - help='use non-uniform partition that Bert-allocated GPU can also have models') -args = parser.parse_args() -print(args) - - -_model_divisions = [] -if args.schedule == 'replicate': - num_trainers = DeviceGroup().world_size - num_model_per_device = args.models // num_trainers - _model_divisions = [num_model_per_device] * num_trainers - for idx in range(args.models % num_trainers): - _model_divisions[-1-idx] += 1 -if args.schedule == 'pipe': - num_trainers = DeviceGroup().world_size - 1 - if args.non_uniform: - times = [160] + [80] * args.models - _model_divisions = layer_division(times, DeviceGroup().world_size) - _model_divisions = [end-start for start, end in _model_divisions] - _model_divisions[0] -= 1 - else: - num_model_per_device = args.models // num_trainers - _model_divisions = [0] + [num_model_per_device] * num_trainers - for idx in range(args.models % num_trainers): - _model_divisions[-1-idx] += 1 -print_each_rank(f'model number placements: {_model_divisions}') - - -class WrapperOp(nn.Module): - def __init__(self, op_choice, input_args): - super(WrapperOp, self).__init__() - self.op_choice = op_choice - self.input_args = input_args - self.op = None - - def conv_shortcut(kernel_size, hidden_units, cnn_keep_prob): - return ConvBN(kernel_size, hidden_units, hidden_units, - cnn_keep_prob, False, True) - - if op_choice == 'conv_shortcut1': - self.op = conv_shortcut(*input_args) - elif op_choice == 'conv_shortcut3': - self.op = conv_shortcut(*input_args) - elif op_choice == 'conv_shortcut5': - self.op = conv_shortcut(*input_args) - elif op_choice == 'conv_shortcut7': - self.op = conv_shortcut(*input_args) - elif op_choice == 'AvgPool': - self.op = AvgPool(3, False, True) - elif op_choice == 'MaxPool': - self.op = MaxPool(3, False, True) - elif op_choice == 'RNN': - self.op = RNN(*input_args) - elif op_choice == 'Attention': - self.op = Attention(*input_args) - else: - raise - - def forward(self, prec, mask): - return self.op(prec, mask) - - -class Layer(nn.Module): - def __init__(self, key, prev_keys, hidden_units, choose_from_k, cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask): - super(Layer, self).__init__() - - self.n_candidates = len(prev_keys) - if self.n_candidates: - #===self.prec = mutables.InputChoice(choose_from=prev_keys[-choose_from_k:], n_chosen=1) - self.prec = 1 - else: - # first layer, skip input choice - self.prec = None - '''self.op = mutables.LayerChoice([ - conv_shortcut(1), - conv_shortcut(3), - conv_shortcut(5), - conv_shortcut(7), - AvgPool(3, False, True), - MaxPool(3, False, True), - RNN(hidden_units, lstm_keep_prob), - Attention(hidden_units, 4, att_keep_prob, att_mask) - ])''' - #self.op = conv_shortcut(1) - #self.op = Attention(hidden_units, 4, att_keep_prob, att_mask) - #self.op = RNN(hidden_units, lstm_keep_prob) - #self.op = WrapperOp('RNN', [hidden_units, lstm_keep_prob]) - #self.op = WrapperOp('Attention', [hidden_units, 4, att_keep_prob, att_mask]) - #self.op = WrapperOp('MaxPool', [3, False, True]) - #self.op = WrapperOp('AvgPool', [3, False, True]) - #self.op = WrapperOp('conv_shortcut7', [7, hidden_units, cnn_keep_prob]) - #self.op = WrapperOp('conv_shortcut5', [5, hidden_units, cnn_keep_prob]) - #self.op = WrapperOp('conv_shortcut3', [3, hidden_units, cnn_keep_prob]) - self.op = WrapperOp('conv_shortcut1', [1, hidden_units, cnn_keep_prob]) - if self.n_candidates: - #===self.skipconnect = mutables.InputChoice(choose_from=prev_keys) - self.skipconnect = 1 - else: - self.skipconnect = None - self.bn = BatchNorm(hidden_units, False, True) - - self.prec_n_candidates = choose_from_k - self.skip_n_candidates = len(prev_keys) - - def forward(self, last_layer, prev_layers, mask): - # pass an extra last_layer to deal with layer 0 (prev_layers is empty) - if self.prec is None: - prec = last_layer - else: - #===prec = self.prec(prev_layers[-self.prec.n_candidates:]) # skip first - x = min(len(prev_layers), self.prec_n_candidates) - prec = prev_layers[-x] # skip first - out = self.op(prec, mask) - if self.skipconnect is not None: - #===connection = self.skipconnect(prev_layers[-self.skipconnect.n_candidates:]) - connection = prev_layers[-self.skip_n_candidates] - if connection is not None: - out = out + connection - out = self.bn(out, mask) - return out - - -class Model(nn.Module): - def __init__(self, embedding_dim=768, hidden_units=256, num_layers=24, num_classes=5, choose_from_k=5, - lstm_keep_prob=0.5, cnn_keep_prob=0.5, att_keep_prob=0.5, att_mask=True, - embed_keep_prob=0.5, final_output_keep_prob=1.0, global_pool="avg"): - super(Model, self).__init__() - - # self.embedding = nn.Embedding.from_pretrained(embedding, freeze=False) - self.hidden_units = hidden_units - self.num_layers = num_layers - self.num_classes = num_classes - - self.init_conv = ConvBN(1, embedding_dim, hidden_units, cnn_keep_prob, False, True) - - self.layers = nn.ModuleList() - candidate_keys_pool = [] - for layer_id in range(self.num_layers): - k = "layer_{}".format(layer_id) - self.layers.append(Layer(k, candidate_keys_pool, hidden_units, choose_from_k, - cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask)) - candidate_keys_pool.append(k) - - self.linear_combine = LinearCombine(self.num_layers) - self.linear_out = nn.Linear(self.hidden_units, self.num_classes) - - self.embed_dropout = nn.Dropout(p=1 - embed_keep_prob) - self.output_dropout = nn.Dropout(p=1 - final_output_keep_prob) - - assert global_pool in ["max", "avg"] - if global_pool == "max": - self.global_pool = GlobalMaxPool() - elif global_pool == "avg": - self.global_pool = GlobalAvgPool() - - self.criterion = torch.nn.CrossEntropyLoss() - - def forward(self, inputs, mask, labels): - # sent_ids, mask = inputs - # seq = self.embedding(sent_ids.long()) - seq = self.embed_dropout(inputs) - - seq = torch.transpose(seq, 1, 2) # from (N, L, C) -> (N, C, L) - - x = self.init_conv(seq, mask) - prev_layers = [] - - for layer in self.layers: - x = layer(x, prev_layers, mask) - prev_layers.append(x) - - x = self.linear_combine(torch.stack(prev_layers)) - x = self.global_pool(x, mask) - x = self.output_dropout(x) - x = self.linear_out(x) - loss = self.criterion(x, labels) - return loss - - -if __name__ == '__main__': - - # initialize models - num_model = _model_divisions[DeviceGroup().rank] - print_each_rank(f'initializing {num_model} models...') - models = [Model().cuda() for _ in range(num_model)] - - # initialize dataloaders - if args.schedule == 'replicate': - dataloader = SharedDataLoader(args.bs, replicate=True) - elif args.schedule == 'pipe': - dataloader = SharedDataLoader(args.bs, replicate=False) - else: - assert False - - # initialize optimizer - optimizers = [ - torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) for model in models - ] - - CudaTimer(enable=False) - torch.distributed.barrier() - dataloader = iter(dataloader) - iter_num = 64 - for step in range(iter_num): - if step >= 16: - CudaTimer(enable=True).start('e2e') - # if args.schedule == 'replicate': - # # retiarii baseline - # for _ in range(len(models)): - # text, masks, labels = next(dataloader) - # else: - # text, masks, labels = next(dataloader) - text, masks, labels = next(dataloader) - for model, optimizer in zip(models, optimizers): - CudaTimer().start('nas-model') - loss = model(text, masks, labels) - loss.backward() - optimizer.step() - optimizer.zero_grad() - CudaTimer().stop('nas-model') - - # CudaTimer().start('nas-model') - # losses = [] - # for model in models: - # losses.append(model(text, masks, labels)) - # for loss in losses: - # loss.backward() - # for optimizer in optimizers: - # optimizer.step() - # optimizer.zero_grad() - # CudaTimer().stop('nas-model') - - if step >= 16: - CudaTimer().stop('e2e') - - if step == 0: - torch.distributed.barrier() - print_each_rank('memory after optimizer:', rank_only=0) - memory_summary() - - if (step + 1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-16, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-16) - memory_summary() From ac99d7caff986599d113b0f344a5f68f1f2fa479 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 15:28:45 +0800 Subject: [PATCH 1425/1892] setup logging category --- cube/__init__.py | 62 ++++++++++++++++++++++++++++++++++++++++-- cube/compiler.py | 49 ++++++++++++++++----------------- cube/flags.py | 15 ++++++++-- cube/profiler/timer.py | 30 ++------------------ cube/program.py | 18 ++++-------- cube/utils.py | 39 ++++++++++++++++++++++---- 6 files changed, 138 insertions(+), 75 deletions(-) diff --git a/cube/__init__.py b/cube/__init__.py index ecce5e1c..4cc7c5ab 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,4 +1,5 @@ -import warnings +from typing import Optional +import logging from cube import runtime from cube import profiler @@ -9,13 +10,15 @@ from cube.utils import load_model, load_default_schedule, load_eval_schedule from cube.utils import accum_mode +from cube.flags import CompileFlag + def _check_torch_version(): import torch torch_version = str(torch.__version__).split('+')[0] torch_version = float('.'.join(torch_version.split('.')[:2])) - if torch_version < 1.11: - warnings.warn(f"Expected PyTorch version >= 1.11 but got {torch_version}") + if torch_version < 1.12: + logging.warn(f"expected PyTorch version >= 1.12 but got {torch_version}") def init(): @@ -23,4 +26,57 @@ def init(): _ = runtime.resource.EnvResource() +def _init_logger(): + logging.basicConfig(level=logging.WARN) + + level = lambda flag: logging.INFO if flag else logging.WARN + + logging.getLogger('cube.parser').setLevel( + level(CompileFlag.log_parser) + ) + logging.getLogger('cube.prim').setLevel( + level(CompileFlag.log_transform) + ) + logging.getLogger('cube.adapter').setLevel( + level(CompileFlag.log_adapter) + ) + logging.getLogger('cube.execplan').setLevel( + level(CompileFlag.log_execplan) + ) + logging.getLogger('cube.codegen').setLevel( + level(CompileFlag.log_codegen) + ) + logging.getLogger('cube.runtime').setLevel( + level(CompileFlag.log_runtime) + ) + logging.getLogger('cube.profiler').setLevel( + level(CompileFlag.log_profiler) + ) + logging.getLogger('cube.compiler').setLevel( + logging.INFO + ) + + +def set_logger_level(name: Optional[str], level): + """Set the logger level of cube. + + Args: + name (Optional[str]): the name of the logger, can be one of + 'cube.parser', 'cube.policy', 'cube.adapter', + 'cube.execplan', 'cube.compiler'. Or None to set all. + level (int): the level of the logger, can be one of + logging.DEBUG, logging.INFO, logging.WARN, logging.ERROR. + """ + + if name is None: + logger_names = list(logging.root.manager.loggerDict.keys()) + logger_names = [name for name in logger_names if name.startswith('cube')] + loggers = [logging.getLogger(name) for name in logger_names] + for logger in loggers: + logger.setLevel(level) + elif name in logging.root.manager.loggerDict: + logging.getLogger(name).setLevel(level) + + _check_torch_version() +_init_logger() \ No newline at end of file diff --git a/cube/compiler.py b/cube/compiler.py index 19a01783..f1e6b95f 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -2,17 +2,19 @@ import torch import time import os +import logging import cube +from cube.ir.cten import IRObject +from cube.ir.tensor import IRFullTensor +from cube.ir.unique import IDGenerator from cube.graph.gener.gen import IRAdapterGener from cube.graph.graph import IRGraph -from cube.ir.cten import IRObject from cube.graph.parser.dtype import DType2IRDType -from cube.ir.tensor import IRFullTensor from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.schedule.schedplan import SchedulePlan from cube.graph.function.pyfunc import IRPyFunc +from cube.graph.schedule.schedplan import SchedulePlan from cube.execplan import ExecutionPlan from cube.execplan.planpass.fusion import DiffFusion @@ -20,13 +22,12 @@ from cube.codegen import ModuleCodeGen, ScheduleCodeGen -from cube.profiler.timer import print_each_rank from cube.runtime.device import DeviceGroup from cube.runtime.syndata import CubeDataLoader from cube.program import Program, SemanticDataLoader, SemanticModel -from cube.ir.unique import IDGenerator from cube.flags import CompileFlag +from cube.utils import print_each_rank def compile(model: SemanticModel, *args, @@ -70,7 +71,7 @@ def train_iter(model, dataloader): Returns: Callable: compiled training iteration """ - + logger = logging.getLogger('cube.compiler') # clean global status Program().clear() IDGenerator().clear() @@ -107,12 +108,12 @@ def decorator(fn: Callable) -> Callable: if not override and os.path.exists(filename.format(myrank)): filename = filename.format(myrank) # TODO: set batch size - print('warning: dataloader batch size stay as default.') + logger.warning('dataloader batch size stay as default.') # load module code - print_each_rank(f'loading existed module from {filename} ...') + logger.info(f'loading existed module from {filename} ...') model.load_module(filename) # load schedule code - print_each_rank(f'loading existed schedule from {filename} ...') + logger.info(f'loading existed schedule from {filename} ...') return cube.load_default_schedule(filename) if DeviceGroup().local_rank == 0: @@ -141,7 +142,7 @@ def decorator(fn: Callable) -> Callable: # setup program output Program().set_output(outputs) span = time.time() - start - print('> finish parsing iteration: {:.2f} s'.format(span)) + logger.info('> finish parsing iteration: {:.2f} s'.format(span)) # run policy start = time.time() @@ -149,7 +150,7 @@ def decorator(fn: Callable) -> Callable: assert callable(PAS), f"Policy PAS is not callable" graph = PAS(graph, resource) span = time.time() - start - print('> finish policy expression: {:.2f} s'.format(span)) + logger.info('> finish policy expression: {:.2f} s'.format(span)) if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") @@ -168,16 +169,14 @@ def decorator(fn: Callable) -> Callable: start = time.time() graph = IRAdapterGener.gen(graph, cost_fn=comm_cost_fn) span = time.time() - start - print('> finish generating adapters: {:.2f} s'.format(span)) + logger.info('> finish generating adapters: {:.2f} s'.format(span)) if graph.sched is not None: start = time.time() graph.sched.apply() - # print(graph.sched)qq - if CompileFlag.log_schedule: - print(graph.sched) + logging.getLogger('cube.schedule').info(f'schedule:\n{graph.sched}') span = time.time() - start - print('> finish planpass on applying schedule strategy: {:.2f} s'.format(span)) + logger.info('> finish planpass on applying schedule strategy: {:.2f} s'.format(span)) # to execution plan start = time.time() @@ -188,13 +187,13 @@ def decorator(fn: Callable) -> Callable: if CompileFlag.visualize_plan: execplan.visualize('plan.png') span = time.time() - start - print('> finish lowering to execution plan: {:.2f} s'.format(span)) + logger.info('> finish lowering to execution plan: {:.2f} s'.format(span)) # plan pass for communication optimization start = time.time() execplan = DiffFusion.apply(execplan) span = time.time() - start - print('> finish planpass on diff-fusion operations: {:.2f} s'.format(span)) + logger.info('> finish planpass on diff-fusion operations: {:.2f} s'.format(span)) # execplan.visualize(outfile='plan.png') @@ -203,7 +202,7 @@ def decorator(fn: Callable) -> Callable: start = time.time() execplan = Grouping.apply(execplan) span = time.time() - start - print('> finish planpass on grouping operations: {:.2f} s'.format(span)) + logger.info('> finish planpass on grouping operations: {:.2f} s'.format(span)) # execplan.graph.reset_dependency() # execplan.analyze(outfile='execplan.png') @@ -228,21 +227,21 @@ def decorator(fn: Callable) -> Callable: attach=True ) span = time.time() - start - print('> finish generating code: {:.2f} seconds'.format(span)) + logger.info('> finish generating code: {:.2f} seconds'.format(span)) compile_end = time.time() compile_time = compile_end - compile_start - print('> compile time: {:.2f} seconds'.format(compile_time)) + logger.info('> compile time: {:.2f} seconds'.format(compile_time)) if torch.distributed.is_initialized(): if DeviceGroup().local_rank != 0 and CompileFlag.worker_sleep > 0: - print(f'rank [{DeviceGroup().rank}] starts sleeping {CompileFlag.worker_sleep} seconds...') + logger.info(f'rank [{DeviceGroup().rank}] starts sleeping {CompileFlag.worker_sleep} seconds...') time.sleep(CompileFlag.worker_sleep) torch.distributed.barrier() # load module filename = filename.format(myrank) - print_each_rank(f'loading generated module from {filename} ...') + print_each_rank(f'loading generated module from {filename} ...', logger_fn=logger.info) model.load_module(filename) if torch.distributed.is_initialized(): @@ -252,7 +251,7 @@ def decorator(fn: Callable) -> Callable: # set dataloder batch size (serialize output) if dataloader is not None: bs = model.get_gen_module().get_batch_size() - print_each_rank(f'> setting batch size to: {bs}') + print_each_rank(f'> setting batch size to: {bs}', logger_fn=logger.info) if torch.distributed.is_initialized(): for rank in range(torch.distributed.get_world_size()): if rank == torch.distributed.get_rank(): @@ -267,7 +266,7 @@ def decorator(fn: Callable) -> Callable: torch.distributed.barrier() # load temporal schedule - print_each_rank(f'loading generated schedule from {filename} ...') + print_each_rank(f'loading generated schedule from {filename} ...', logger_fn=logger.info) return cube.load_default_schedule(filename) return decorator diff --git a/cube/flags.py b/cube/flags.py index 18cdb3b7..b548bc14 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -18,9 +18,20 @@ def _to_int(s: str, default=0) -> int: class CompileFlag: # ============= loggings =================== - log_transform = _to_bool('LOG_TRANSFORM') - log_schedule = _to_bool('LOG_SCHEDULE') + # log the parser information log_parser = _to_bool('LOG_PARSER') + # log the primitives applied on the cube graph + log_prim = _to_bool('LOG_PRIM') + # log the adapter information during communication generation + log_adapter = _to_bool('LOG_ADAPTER') + # log the execution plan + log_execplan = _to_bool('LOG_EXECPLAN') + # log the code generation information + log_codegen = _to_bool('LOG_CODEGEN') + # log the runtime information + log_runtime = _to_bool('LOG_RUNTIME') + # log the profiling information + log_profiler = _to_bool('LOG_PROFILER') # ================ compiling ======================== use_torchfx = _to_bool('USE_TORCHFX') # using torch.fx or torchscript as frontend to capture dataflow graph diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index d731bee4..62b7a873 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -1,33 +1,9 @@ from typing import Optional import time -import sys -import warnings +import logging import torch - - -def print_each_rank(msg, rank_only=None, outfile=''): - import os - single_device_mode = os.environ.get('SINGLE_DEV_MODE') - if single_device_mode: - return - - myrank = torch.distributed.get_rank() - outfile = sys.stdout if outfile == '' else outfile - for rank in range(torch.distributed.get_world_size()): - if rank_only is None: - if myrank == rank: - f = open(outfile, 'a') if outfile != sys.stdout else sys.stdout - f.write('rank [{}]: {}\n'.format(rank, msg)) - if outfile != sys.stdout: - f.close() - else: - if myrank == rank_only and rank_only == rank: - f = open(outfile, 'a') if outfile != sys.stdout else sys.stdout - f.write('rank [{}]: {}\n'.format(rank, msg)) - if outfile != sys.stdout: - f.close() - torch.distributed.barrier() +from cube.utils import print_each_rank class CudaTimer: @@ -137,7 +113,7 @@ def duration(self, times: int, field_name: str = 'default') -> float: @return span float: wall clock in milliseconds. """ if field_name not in self.instance.field: - warnings.warn(f"CudaTimer: {field_name} doesn't record.") + logging.getLogger('profiler').warning(f"CudaTimer: {field_name} doesn't record.") return 0.0 if len(self.instance.field[field_name]) != 0: raise RuntimeError(f"timer for field {field_name} not stopped") diff --git a/cube/program.py b/cube/program.py index 504f741b..29e19777 100644 --- a/cube/program.py +++ b/cube/program.py @@ -1,5 +1,5 @@ from typing import List, Tuple, Optional, Any -import warnings +import logging from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -14,6 +14,8 @@ from cube.runtime.device import DeviceGroup from cube.profiler.timer import print_each_rank +from cube.utils import load_model, load_default_schedule + import torch @@ -194,18 +196,8 @@ def get_graph(self): return self.ir_graph def load_module(self, filename: str): - import importlib.util - spec = importlib.util.spec_from_file_location("GenModel", filename) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self._loaded_module: CubeModule = module.GenModel().cuda() - # load parameter content - if self.save_content: - print_each_rank("> loading parameter content...") - self._loaded_module.load_attr_content('./fullmodel.pt') - # initialize reducer - for reducer in self._loaded_module.reducers: - reducer.build_buckets() + """Load module from file.""" + self._loaded_module = load_model(filename, self.save_content) def get_gen_module(self) -> Optional[torch.nn.Module]: return self._loaded_module diff --git a/cube/utils.py b/cube/utils.py index d2c31f76..75a7fa06 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -1,11 +1,39 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Callable +import logging import cube -from cube.profiler.timer import print_each_rank from cube.runtime.device import DeviceGroup -from cube.flags import RuntimeFlag +from cube.flags import RuntimeFlag, CompileFlag -from cube.flags import RuntimeFlag +import torch + + +def print_each_rank(msg: str, rank_only: Optional[int] = None, logger_fn: Callable = print): + """Logging the message. + + Args: + msg (str): message to be logged. + rank_only (int, optional): + the rank to be logged. Defaults to None, which means all ranks. + logger_fn (Callable, optional): + the logger function. Defaults to print. + + Returns: + None + """ + if CompileFlag.dev_mode: + logger_fn(msg) + return + + myrank = torch.distributed.get_rank() + for rank in range(torch.distributed.get_world_size()): + if rank_only is None: + if myrank == rank: + logger_fn('rank [{}]: {}\n'.format(rank, msg)) + else: + if myrank == rank_only and rank_only == rank: + logger_fn('rank [{}]: {}\n'.format(rank, msg)) + torch.distributed.barrier() def _load_module_attr(filename: str, name: str): @@ -22,7 +50,8 @@ def load_model(filename: Optional[str] = None, load_content: bool = True): loaded_module: cube.runtime.module.CubeModule = module.GenModel().cuda() # load parameter content if load_content: - print_each_rank("> loading parameter content...") + print_each_rank("> loading parameter content...", + logger_fn=logging.getLogger('cube.codegen').info) loaded_module.load_attr_content('./fullmodel.pt') # initialize reducer for reducer in loaded_module.reducers: From 8ceda7857b05cbc9cc1ec703481da1648d26e21b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 17:10:07 +0800 Subject: [PATCH 1426/1892] refine parser with logging and better code structure --- cube/graph/parser/__init__.py | 4 +- cube/graph/parser/converter.py | 102 +++++++++--------- .../{ => fx}/concrete_trace_utils/__init__.py | 0 .../concrete_trace_utils/concrete_proxy.py | 0 .../concrete_trace_utils/concrete_tracer.py | 0 .../kwargs_shape_prop/__init__.py | 0 .../kwargs_shape_prop/kwargs_interpreter.py | 0 .../kwargs_shape_prop/kwargs_shape_prop.py | 0 .../concrete_trace_utils/operator_patcher.py | 0 .../{ => fx}/concrete_trace_utils/utils.py | 0 .../parser/{mappingfx.py => fx/mapping.py} | 0 .../parser/{parserfx.py => fx/parser.py} | 31 ++---- cube/graph/parser/register.py | 13 ++- cube/graph/parser/{ => script}/mapping.py | 0 cube/graph/parser/{ => script}/parser.py | 11 +- 15 files changed, 76 insertions(+), 85 deletions(-) rename cube/graph/parser/{ => fx}/concrete_trace_utils/__init__.py (100%) rename cube/graph/parser/{ => fx}/concrete_trace_utils/concrete_proxy.py (100%) rename cube/graph/parser/{ => fx}/concrete_trace_utils/concrete_tracer.py (100%) rename cube/graph/parser/{ => fx}/concrete_trace_utils/kwargs_shape_prop/__init__.py (100%) rename cube/graph/parser/{ => fx}/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py (100%) rename cube/graph/parser/{ => fx}/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py (100%) rename cube/graph/parser/{ => fx}/concrete_trace_utils/operator_patcher.py (100%) rename cube/graph/parser/{ => fx}/concrete_trace_utils/utils.py (100%) rename cube/graph/parser/{mappingfx.py => fx/mapping.py} (100%) rename cube/graph/parser/{parserfx.py => fx/parser.py} (95%) rename cube/graph/parser/{ => script}/mapping.py (100%) rename cube/graph/parser/{ => script}/parser.py (98%) diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index b3811616..d3642c72 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,4 +1,4 @@ -from cube.graph.parser.parser import ScriptModuleParser -from cube.graph.parser.parserfx import FxModuleParser, FxFuncOpTracer +from cube.graph.parser.script.parser import ScriptModuleParser +from cube.graph.parser.fx.parser import FxModuleParser, FxFuncOpTracer from cube.graph.parser.converter import convert_model from cube.graph.parser.register import register \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 191a5c5a..d6576fda 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,10 +1,8 @@ from typing import Optional, List -import warnings +import logging from cube.ir.tensor import IRFullTensor from cube.graph.parser import ScriptModuleParser -from cube.graph.parser import FxModuleParser, FxFuncOpTracer -from cube.graph.parser.concrete_trace_utils import concrete_trace, ExtraSEFPatcher from cube.graph.parser.register import CustomizedOps from cube.graph import IRGraph from cube.flags import CompileFlag @@ -23,71 +21,69 @@ def convert_model(model: torch.nn.Module, dummy_input = None, save_content: bool = True, dynamic_shape: bool = False) -> IRGraph: - """ - Convert torch.nn.Module based model into IRGraph + """Convert torch.nn.Module based model into IRGraph + + Args: + model (torch.nn.Module): single-device model description + input_shapes (Optional[ List[List[int],] ]): + input shapes of model, only required for torch.jit.script parser + dummy_input (Optional[Any]): + dummy input of model, only required for torch.fx parser + save_content (bool): + whether to save the content of model and load it into generated model. Default True. + dynamic_shape (bool): + whether to use dynamic shape. Default False. + + Returns: + IRGraph: IRGraph of model """ # get registered leaf function customized_funcs = CustomizedOps.kOpRuntime.values() leaf_functions = {func: ([], False, None) for func in customized_funcs} - try: - if CompileFlag.use_torchfx: - if CompileFlag.use_default_fx_tracer: - if CompileFlag.log_parser: - print('> use default torch.fx tracer') - # Symbolic tracing frontend - captures the semantics of the module - tracer = FxFuncOpTracer() - traced_graph: torch.fx.Graph = tracer.trace(model) - with ExtraSEFPatcher(): - traced_graph.eliminate_dead_code() - traced_model: torch.fx.GraphModule = torch.fx.GraphModule(model, traced_graph) - if CompileFlag.log_parser: - traced_model.graph.print_tabular() - else: - if CompileFlag.log_parser: - print('> use concrete torch.fx tracer') - if HAS_APEX: - leaf_module = ( - # torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, - apex.normalization.FusedLayerNorm, - # NOTE: the following modules also have different behavior depending on self.training. but currently in used. - # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, - # torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, - # torch.nn.LazyBatchNorm1d, torch.nn.LazyBatchNorm2d, torch.nn.LazyBatchNorm3d, torch.nn.SyncBatchNorm, - # torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, - # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, - ) - else: - print('WARNING: apex package is not installed') - leaf_module = None - traced_model = concrete_trace( - model, - dummy_input, - use_operator_patch=True, - leaf_module=leaf_module, - autowrap_leaf_function=leaf_functions, - cpu_offload=True, + logger = logging.getLogger('cube.parser') + + # step 1: trace model + if CompileFlag.use_torchfx: + logger.info('use concrete torch.fx tracer') + from cube.graph.parser.fx.concrete_trace_utils import concrete_trace + from cube.graph.parser.fx.parser import FxModuleParser + if HAS_APEX: + leaf_module = ( + # torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, + apex.normalization.FusedLayerNorm, + # NOTE: the following modules also have different behavior depending on self.training. but currently in used. + # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, + # torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, + # torch.nn.LazyBatchNorm1d, torch.nn.LazyBatchNorm2d, torch.nn.LazyBatchNorm3d, torch.nn.SyncBatchNorm, + # torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, + # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, ) else: - if CompileFlag.log_parser: - print('> use default torch.jit.script tracer') - traced_model = torch.jit.script(model) - - except Exception as ex: - print(ex) - raise RuntimeError("Cannot convert module into torchscript/torch.fx module.") + logger.warn('apex package is not installed') + leaf_module = None + traced_model = concrete_trace( + model, + dummy_input, + use_operator_patch=True, + leaf_module=leaf_module, + autowrap_leaf_function=leaf_functions, + cpu_offload=True, + ) + else: + logger.info('use torch.jit.script tracer') + traced_model = torch.jit.script(model) + # step 2: convert traced model into IRGraph if CompileFlag.use_torchfx: FxModuleParser.save_content = save_content FxModuleParser.dynamic_shape = dynamic_shape - if CompileFlag.log_parser: - print(f"> use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") + logger.info(f"use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") inputs, nodes, outputs = FxModuleParser.parse(traced_model, dummy_input) module_name = model.__class__.__name__ else: if dynamic_shape: - warnings.warn('dynamic shape is not supported in torch.jit.script', - category=RuntimeWarning) + logger.warn('dynamic shape is not supported in torch.jit.script') ScriptModuleParser.save_content = save_content inputs, nodes, outputs = ScriptModuleParser.parse_module(traced_model, input_shapes) module_name = traced_model.original_name diff --git a/cube/graph/parser/concrete_trace_utils/__init__.py b/cube/graph/parser/fx/concrete_trace_utils/__init__.py similarity index 100% rename from cube/graph/parser/concrete_trace_utils/__init__.py rename to cube/graph/parser/fx/concrete_trace_utils/__init__.py diff --git a/cube/graph/parser/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py similarity index 100% rename from cube/graph/parser/concrete_trace_utils/concrete_proxy.py rename to cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py diff --git a/cube/graph/parser/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py similarity index 100% rename from cube/graph/parser/concrete_trace_utils/concrete_tracer.py rename to cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/__init__.py b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/__init__.py similarity index 100% rename from cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/__init__.py rename to cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/__init__.py diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py similarity index 100% rename from cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py rename to cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py diff --git a/cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py similarity index 100% rename from cube/graph/parser/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py rename to cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py diff --git a/cube/graph/parser/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py similarity index 100% rename from cube/graph/parser/concrete_trace_utils/operator_patcher.py rename to cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py diff --git a/cube/graph/parser/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py similarity index 100% rename from cube/graph/parser/concrete_trace_utils/utils.py rename to cube/graph/parser/fx/concrete_trace_utils/utils.py diff --git a/cube/graph/parser/mappingfx.py b/cube/graph/parser/fx/mapping.py similarity index 100% rename from cube/graph/parser/mappingfx.py rename to cube/graph/parser/fx/mapping.py diff --git a/cube/graph/parser/parserfx.py b/cube/graph/parser/fx/parser.py similarity index 95% rename from cube/graph/parser/parserfx.py rename to cube/graph/parser/fx/parser.py index 7e30378d..dbb8c43f 100644 --- a/cube/graph/parser/parserfx.py +++ b/cube/graph/parser/fx/parser.py @@ -1,6 +1,6 @@ import torch import enum -import warnings +import logging from typing import Any, List, Tuple, Callable, Union, Dict, Type from cube.ir.operator import IRFwOperation @@ -8,12 +8,10 @@ from cube.ir.cten import IRObject, IRCell from cube.graph.parser.frame import Frame from cube.graph.parser.dtype import DType2IRDType -from cube.graph.parser.mappingfx import SignFx2Op +from cube.graph.parser.fx.mapping import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import IRDimops -from cube.flags import CompileFlag - import torch.fx class ErasedDevice: @@ -94,7 +92,7 @@ def parse(module: torch.fx.GraphModule, The overall entry to parse a torch.fx graph module """ - from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp + from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp frame = frame if frame is not None else Frame() frame.push_var() @@ -103,8 +101,8 @@ def parse(module: torch.fx.GraphModule, assert isinstance(dummy_inputs, dict), "Expected dummy inputs to parse module" inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] - if CompileFlag.log_parser: - print(f'> torch.fx parser: graph inputs: {inputs}') + logging.getLogger('cube.parser').info( + f'> torch.fx parser: graph inputs: {inputs}') # shape propagation ShapeProp(module).propagate(dummy_inputs) @@ -128,11 +126,9 @@ def parse(module: torch.fx.GraphModule, # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, # extend to other input types if hasattr(dummy_inputs, input.name): - # print(f'dummy_inputs has {input.name}') shape = getattr(dummy_inputs, input.name).size() else: # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name - # print(f'dummy_inputs does not have {input.name}') shape = None dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) @@ -142,7 +138,7 @@ def parse(module: torch.fx.GraphModule, # add activations to frame, including call_func/call_method output and final output # call_module corresponds to leaf torch.nn.module - from cube.graph.parser.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata + from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata activation_op_strs = {'call_function', 'output', 'call_method', 'call_module'} activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] def parse_complex_out(meta_out): @@ -186,8 +182,8 @@ def parse_complex_out(meta_out): all_ir_nodes: List[IRFwOperation] = list() total_node_num = len(module.graph.nodes) for nidx, node in enumerate(module.graph.nodes): - if CompileFlag.log_parser: - print(f'> torch.fx parser: [{nidx}/{total_node_num}] parsing node {node}...', flush=True) + logging.getLogger('cube.parser').info( + f'[{nidx}/{total_node_num}] parsing node {node}...') ir_nodes = FxModuleParser.parse_node(node, module, frame) if ir_nodes is not None: all_ir_nodes += ir_nodes @@ -321,14 +317,12 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): - warnings.warn(f'Find unknown pytorch operation: {fsig}', - category=RuntimeWarning) + logging.getLogger('cube.parser').warn(f'Find unknown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fsig ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) # case2: python runtime function else: - warnings.warn(f'Set python runtime function: {fsig}', - category=RuntimeWarning) + logging.getLogger('cube.parser').warn(f'Set python runtime function: {fsig}') ir_node = IRPyFunc(fsig, input_vals, [IRObject()], **kwargs) if isinstance(ir_node, IRCell): @@ -365,9 +359,7 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, else: frame.set_var(node.name, ir_node) - if CompileFlag.log_parser: - print(f'parsing result: {ir_node}', flush=True) - + logging.getLogger('cube.parser').info(f'parsing result: {ir_node}') return ir_nodes @staticmethod @@ -482,7 +474,6 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> in_type = node.args[0].meta['type'] assert node_target in in_type().__dir__(), f'node_target = {node_target}, in_type().__dir__() = {in_type().__dir__()}' sig = f'{in_type.__name__}.{node_target}' - print(f'The method is not torch or Tensor, but {sig}') return sig @staticmethod diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 8afbce32..861a36c4 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -5,7 +5,8 @@ from typing import Dict, Callable, List, Optional, Any from functools import partial import inspect -import warnings +import logging + import torch from cube.graph.function.dimops import IRDimops, OpAnno @@ -133,9 +134,10 @@ def decorator(fn: Callable): if code_impl_pattern == 'import': import_path = inspect.getmodule(fn).__name__ if import_path == '__main__': - warnings.warn(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' - f'This may cause error when the function has inner functions from other modules. ' - f'To solve this, define the function in another module and import into main', stacklevel=0) + logger = logging.getLogger('cube.parser') + logger.warn(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' + f'This may cause error when the function has inner functions from other modules. ' + f'To solve this, define the function in another module and import into main', stacklevel=0) code = inspect.getsource(fn) code = code[code.index('def'):] else: @@ -157,7 +159,8 @@ def udfop(*args, signature=None, **kwargs): kwargs[name] = val return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, transform_rules=rules, **kwargs) - print(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') + logging.getLogger('cube.parser').info( + f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') CustomizedOps.register(fsig, udfop, code, fn) return fn diff --git a/cube/graph/parser/mapping.py b/cube/graph/parser/script/mapping.py similarity index 100% rename from cube/graph/parser/mapping.py rename to cube/graph/parser/script/mapping.py diff --git a/cube/graph/parser/parser.py b/cube/graph/parser/script/parser.py similarity index 98% rename from cube/graph/parser/parser.py rename to cube/graph/parser/script/parser.py index 0bc8aad4..855118b1 100644 --- a/cube/graph/parser/parser.py +++ b/cube/graph/parser/script/parser.py @@ -1,6 +1,7 @@ import torch import enum import re +import logging from typing import Any, List, Tuple, Optional from cube.ir.cten import IRObject @@ -9,7 +10,7 @@ from cube.ir.tensor import IRFullTensor import cube.ir as ir from cube.graph.parser.frame import Frame -from cube.graph.parser.mapping import Sign2Op +from cube.graph.parser.script.mapping import Sign2Op from cube.graph.parser.dtype import DType2IRDType @@ -75,7 +76,7 @@ def parse_module(module, try: ret = ir_node.infer_shape() if not ret: - print(f'warning: {ir_node} cannot infer shape') + logging.getLogger('cube.parser').error(f'{ir_node} cannot infer shape') except Exception: raise RuntimeError( f"====== Shape Infer Error ====\n\n\n" @@ -126,7 +127,7 @@ def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): try: ret = ir_node.infer_shape() if not ret: - print(f'warning: {ir_node} cannot infer shape') + logging.getLogger('cube.parser').error(f'{ir_node} cannot infer shape') except Exception: raise RuntimeError( f"====== Shape Infer Error ====\n\n\n" @@ -552,7 +553,7 @@ def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: try: ret = ir_node.infer_shape() if not ret: - print(f'warning: {ir_node} cannot infer shape') + logging.getLogger('cube.parser').error(f'{ir_node} cannot infer shape') except Exception: raise RuntimeError( f"====== Shape Infer Error ====\n\n\n" @@ -666,7 +667,7 @@ def parse_prim_loop_node(node, module, frame: Frame) -> List[IRFwOperation]: try: ret = ir_node.infer_shape() if not ret: - print(f'warning: {ir_node} cannot infer shape') + logging.getLogger('cube.parser').error(f'{ir_node} cannot infer shape') except Exception: raise RuntimeError( f"====== Shape Infer Error ====\n\n\n" From 4437d68a6d928e5a10a6470e799a0517e1f69379 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 17:33:20 +0800 Subject: [PATCH 1427/1892] fix compile flag --- cube/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/__init__.py b/cube/__init__.py index 4cc7c5ab..1421b8a2 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -35,7 +35,7 @@ def _init_logger(): level(CompileFlag.log_parser) ) logging.getLogger('cube.prim').setLevel( - level(CompileFlag.log_transform) + level(CompileFlag.log_prim) ) logging.getLogger('cube.adapter').setLevel( level(CompileFlag.log_adapter) From 5524b1c7d01e28ec8aaa4465d53e317877ffd606 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 17:33:45 +0800 Subject: [PATCH 1428/1892] refine compiler logging --- cube/compiler.py | 20 ++++++++++---------- cube/utils.py | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index f1e6b95f..075f0204 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -142,7 +142,7 @@ def decorator(fn: Callable) -> Callable: # setup program output Program().set_output(outputs) span = time.time() - start - logger.info('> finish parsing iteration: {:.2f} s'.format(span)) + logger.info('finish parsing iteration: {:.2f} s'.format(span)) # run policy start = time.time() @@ -150,7 +150,7 @@ def decorator(fn: Callable) -> Callable: assert callable(PAS), f"Policy PAS is not callable" graph = PAS(graph, resource) span = time.time() - start - logger.info('> finish policy expression: {:.2f} s'.format(span)) + logger.info('finish policy expression: {:.2f} s'.format(span)) if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") @@ -169,14 +169,14 @@ def decorator(fn: Callable) -> Callable: start = time.time() graph = IRAdapterGener.gen(graph, cost_fn=comm_cost_fn) span = time.time() - start - logger.info('> finish generating adapters: {:.2f} s'.format(span)) + logger.info('finish generating adapters: {:.2f} s'.format(span)) if graph.sched is not None: start = time.time() graph.sched.apply() logging.getLogger('cube.schedule').info(f'schedule:\n{graph.sched}') span = time.time() - start - logger.info('> finish planpass on applying schedule strategy: {:.2f} s'.format(span)) + logger.info('finish planpass on applying schedule strategy: {:.2f} s'.format(span)) # to execution plan start = time.time() @@ -187,13 +187,13 @@ def decorator(fn: Callable) -> Callable: if CompileFlag.visualize_plan: execplan.visualize('plan.png') span = time.time() - start - logger.info('> finish lowering to execution plan: {:.2f} s'.format(span)) + logger.info('finish lowering to execution plan: {:.2f} s'.format(span)) # plan pass for communication optimization start = time.time() execplan = DiffFusion.apply(execplan) span = time.time() - start - logger.info('> finish planpass on diff-fusion operations: {:.2f} s'.format(span)) + logger.info('finish planpass on diff-fusion operations: {:.2f} s'.format(span)) # execplan.visualize(outfile='plan.png') @@ -202,7 +202,7 @@ def decorator(fn: Callable) -> Callable: start = time.time() execplan = Grouping.apply(execplan) span = time.time() - start - logger.info('> finish planpass on grouping operations: {:.2f} s'.format(span)) + logger.info('finish planpass on grouping operations: {:.2f} s'.format(span)) # execplan.graph.reset_dependency() # execplan.analyze(outfile='execplan.png') @@ -227,11 +227,11 @@ def decorator(fn: Callable) -> Callable: attach=True ) span = time.time() - start - logger.info('> finish generating code: {:.2f} seconds'.format(span)) + logger.info('finish generating code: {:.2f} seconds'.format(span)) compile_end = time.time() compile_time = compile_end - compile_start - logger.info('> compile time: {:.2f} seconds'.format(compile_time)) + logger.info('compile time: {:.2f} seconds'.format(compile_time)) if torch.distributed.is_initialized(): if DeviceGroup().local_rank != 0 and CompileFlag.worker_sleep > 0: @@ -251,7 +251,7 @@ def decorator(fn: Callable) -> Callable: # set dataloder batch size (serialize output) if dataloader is not None: bs = model.get_gen_module().get_batch_size() - print_each_rank(f'> setting batch size to: {bs}', logger_fn=logger.info) + print_each_rank(f'setting batch size to: {bs}', logger_fn=logger.info) if torch.distributed.is_initialized(): for rank in range(torch.distributed.get_world_size()): if rank == torch.distributed.get_rank(): diff --git a/cube/utils.py b/cube/utils.py index 75a7fa06..79b76656 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -29,10 +29,10 @@ def print_each_rank(msg: str, rank_only: Optional[int] = None, logger_fn: Callab for rank in range(torch.distributed.get_world_size()): if rank_only is None: if myrank == rank: - logger_fn('rank [{}]: {}\n'.format(rank, msg)) + logger_fn('rank [{}]: {}'.format(rank, msg)) else: if myrank == rank_only and rank_only == rank: - logger_fn('rank [{}]: {}\n'.format(rank, msg)) + logger_fn('rank [{}]: {}'.format(rank, msg)) torch.distributed.barrier() From 03b7b98dd1922cbd98206acc7b75cc697171c54a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 17:35:57 +0800 Subject: [PATCH 1429/1892] refine logging of applying graph primitives --- cube/algorithm/ops/dimops.py | 9 ++++---- cube/graph/graph.py | 42 ++++++++++++++++++++---------------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 32e1b8fe..640e0d35 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -1,4 +1,5 @@ from typing import List, Optional, Any, Dict, Union, Tuple +import logging from cube.algorithm.generics import GenericDistAlgo from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule @@ -7,8 +8,6 @@ from cube.ir.operator import IRFwOperation from collections import deque -from cube.flags import CompileFlag - class DimSplitEinops(GenericDistAlgo): """! @@ -124,9 +123,9 @@ def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List else: adim, reduce = 'Value', None - if CompileFlag.log_transform: - color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' - print(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") + logger = logging.getLogger('cube.prim') + color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' + logger.info(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") if not satisfy: return None rule: TransformRule = self.infer(idx, dim, num) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 0f08800f..7fb45ea3 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -8,7 +8,7 @@ """ from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any -import warnings +import logging import copy import dill import sys @@ -77,10 +77,14 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: # align graph with input tensors itensors: Tuple[IRObject, ...] = self.inputs() if len(args) != len(itensors): - print(f'ERROR(skipping) len(args) != len(itensors): {len(args)} != {len(itensors)}') + logger = logging.getLogger('cube.parser') + logger.error( + f'cube graph forward: skipping arguments due to len(args) != len(itensors): ' + f'{len(args)} != {len(itensors)}' + ) if len(args) > len(itensors): args = args[:len(itensors)] - print(f'WARNING: args shrinked into {args}') + logger.warn(f'cube graph forward: args shrinked into {args}') else: raise RuntimeError('len(args) < len(itensors)') @@ -161,7 +165,8 @@ def backward(self, loss: IRSubTensor): if ftensor.is_loss(): continue consumers = [n for n in self.consumers(ftensor) if isinstance(n, IRFwOperation)] if len(consumers) == 0 and ftensor.requires_grad: - print(f"warning: detected a dead ftensor which is not consumed by any nodes:\n\t{ftensor.name}: {ftensor}", file=sys.stderr) + logging.getLogger('cube.parser').warn( + f"detected a dead ftensor which is not consumed by any nodes:\n\t{ftensor.name}: {ftensor}") ftensor.requires_grad = False # infer gradient @@ -294,13 +299,12 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis raise TypeError("Expected op to be forward op or data op") if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") - if node.name == 'multiref': - warnings.warn( - 'Detected partition a multiref node. This will be skipped as system will automatically handle it.') + logger = logging.getLogger('cube.prim') + if node.name == 'multiref': + logger.warn(f'skip replicating multiref ({node.cid}), which will be handled by system.') return [node] if isinstance(node, IRPyFunc): - warnings.warn( - 'Detected partition a python runtime function. This will be skipped as system will automatically handle it') + logger.warn(f'skip replicating pyfunc ({node.cid}), which will be handled by system.') return [node] fsegment: IRSegment = self.segment(node) @@ -362,13 +366,12 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], f"The partition algorithm ({algo}) is not initialized for this node" assert isinstance(node, (IRFwOperation, IRDataOperation)), \ f"Only allow op to be forward op or data op, but got: {node}" + logger = logging.getLogger('cube.prim') if node.name == 'multiref': - warnings.warn( - 'Detected partition a multiref node. This will be skipped as system will automatically handle it.') + logger.warn(f'skip partitioning multiref ({node.cid}), which will be handled by system.') return [node] if isinstance(node, IRPyFunc): - warnings.warn( - 'Detected partition a python runtime function. This will be skipped as system will automatically handle it') + logger.warn(f'skip partitioning pyfunc ({node.cid}), which will be handled by system.') return [node] # get partitioned sub-nodes @@ -859,14 +862,13 @@ def staging(self, nodes: Tuple[IRFwOperation]): # adjust the start of the first stage to involve beginning operators for idx in range(starts[0]): node = self.node(idx) - if isinstance(node, IRDataOperation): - continue + if isinstance(node, IRDataOperation): continue assert isinstance(node, IRFwOperation), \ f"Expected nodes previous from the first stage are all IRFwOperation, but got {type(node)}" if node.name == 'multiref' or isinstance(node, IRPyFunc): pass else: - warnings.warn(f'Detect a node: {node} that is previous from the first stage. Will be included inside the first stage') + logging.getLogger('cube.prim').info(f'involve node {node.name}({node.cid}) into the first stage') starts[0] = idx break @@ -881,7 +883,8 @@ def staging(self, nodes: Tuple[IRFwOperation]): begin = starts[sid] end = starts[sid+1] if sid != len(starts) - 1 else last_fidx + 1 if begin >= end: - warnings.warn(f"Detected stage {sid} doesn't have operators: [begin({begin}): end({end})). Skipped") + logging.getLogger('cube.prim').warn( + f"skip stage {sid} which doesn't have operators: [begin({begin}): end({end})).") continue fnodes = self._nodes[begin:end] assert all(isinstance(node, IRFwOperation) for node in fnodes), \ @@ -1026,7 +1029,10 @@ def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: skip += nodes[end:] for node in skip: if isinstance(node, IRGraphAnchor): continue - print(f"skip recompute node: {node.name} ({node.cid}) as it doesn't require gradient and appears at head or tail.") + logging.getLogger('cube.prim').info( + f"skip recompute node: {node.name} ({node.cid}) as " + f"it doesn't require gradient and appears at head or tail." + ) nodes = nodes[:end] for fnode in nodes: fnode.recompute = recompute_group_id From 18fb3bab222338481859e436b6fb640b340f0f3f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 17:46:56 +0800 Subject: [PATCH 1430/1892] refine logging of adapter --- cube/graph/gener/concurrent.py | 12 +++++++----- cube/graph/gener/gen.py | 11 ++++++++--- cube/graph/gener/rvd/intra.py | 12 ++++-------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index d946842e..b4050ad6 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -4,7 +4,7 @@ from typing import List, Optional, Dict, Tuple, Callable import copy import numpy as np -import sys +import logging from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap from cube.ir.adapter.prim import IRAdapterPrim @@ -59,13 +59,14 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], except Exception as e: fadapter = None color, default = '\033[33m' , '\033[0m' - print( + msg = ( f"{color}========== Fail to use intra-RVD ==========\n" f"full tensor: {fptensors[0].parent} | is grad: {fptensors[0].parent.is_grad()}\n" f"Reason: {str(e)}\n" f"Switch to general P2P communication.\n" - f"===========================================\n{default}", file=sys.stderr + f"===========================================\n{default}" ) + logging.getLogger('cube.adapter').warn(f'intra-RVD:\n{msg}') # Case 2: sperating device (inter-rvd) if (not CompileFlag.disable_inter_rvd) and len(set(pdevs).intersection(cdevs)) == 0: @@ -74,13 +75,14 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], except Exception as e: fadapter = None color, default = '\033[33m' , '\033[0m' - print( + msg = ( f"{color}========== Fail to use inter-RVD ==========\n" f"full tensor: {fptensors[0].parent}\n" f"Reason: {str(e)}\n" f"Switch to general P2P communication.\n" - f"===========================================\n{default}", file=sys.stderr + f"===========================================\n{default}" ) + logging.getLogger('cube.adapter').warn(f'inter-RVD:\n{msg}') # Case 3: General cases # warnings.warn('The adapter is generated using P2P communication') diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 23cfc806..cb6296bb 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, Tuple, Callable, Set import numpy as np import itertools -import warnings +import logging from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener @@ -317,8 +317,13 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # FIXME: assume producers and consumers can run in parallel for ftensor in ftensors: - # print(graph.debug_tensor_map_str(ftensor)) - # print(graph.mirror.debug_tensor_map_str(ftensor.grad)) + logging.getLogger('cube.adapter').debug( + f'generate adapter for forward tenosrs:\n' + f'{graph.debug_tensor_map_str(ftensor)}') + if ftensor.grad is not None: + logging.getLogger('cube.adapter').debug( + f'generate adapter for backward tenosrs:\n' + f'{graph.mirror.debug_tensor_map_str(ftensor.grad)}') # producers can be operators and graph inputs fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index bfb85488..b2b23d86 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -3,7 +3,7 @@ import numpy as np import sys import copy -import warnings +import logging from cube.ir.dtype import IRDType from cube.ir.cten import IRCell @@ -211,7 +211,6 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD) -> List[Tuple[RVDLayout, Li for olayout in olayouts: if len(ilayouts) > 1: olayout = copy.copy(olayout) if len(olayouts) > 1: ilayout = copy.copy(ilayout) - # print(f'transition: {ilayout}{tuple(t.device[0] for t in ilayout.mat.flatten())} -> {olayout}') imat = RVDLayout.dim2last(ilayout.mat, decd, chunks) omat = RVDLayout.dim2last(olayout.mat, incd, chunks) for itensor, otensor in zip(imat.flatten(), omat.flatten()): @@ -264,7 +263,7 @@ def path(ilayout: RVDLayout, olayout: RVDLayout, f"Switch to a fixed plan: ilayout -> FullReplica -> olayout" ) color, default = '\033[33m' , '\033[0m' - print(color+warn_msg+default, file=sys.stderr) + logging.getLogger('cube.adapter').warn(f'intra-RVD:\n{color+warn_msg+default}') all_prims = IntraPathFinder.backup_path(ilayout, olayout, cost_fn) return all_prims @@ -391,7 +390,6 @@ def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, visited.add(visit) IntraPathFinder._cached_intra_paths[key][src_rvd] = paths - # print for debug # for idx, path in enumerate(paths): # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") @@ -558,9 +556,6 @@ def auto_place(graph: IRSegment, ftensor: IRFullTensor, assert len(producers) == len(consumers), \ f"Expect same number of producer and consumer, but got {len(producers)} producers and {len(consumers)} consumers" - if any(len(consumer.device) > 0 for consumer in consumers): - warnings.warn('Detected at least one consumer has been assigned to a device, which will be overrided by a new device placement.') - if len(producers) == 1: return [producers[0].device[0]] @@ -663,11 +658,12 @@ def advice(shape: TShape, # - if not find, keep forward one as optimal and adopt backup plan for backward one else: placement = list(fw_consumer_devices)[0] - print(f"================ forward-backward mis-aligned! ============== \n" + msg = (f"================ forward-backward mis-aligned! ============== \n" f"fw device choices: {fw_consumer_devices} | hops: {'->'.join(str(rvd) for rvd in fw_rvd_hops)}\n" f"bw hops: {'->'.join(str(rvd) for rvd in bw_rvd_hops)}\n" f"using placement: {placement}\n" f"=============================================================") + logging.getLogger('cube.adapter').warn(f'intra-RVD:\n{msg}') bw_rvd_hops = IntraPathFinder.get_backup_path(ftensor, bw_src_rvd, bw_dst_rvd, cost_fn) # estimate cost From 6c1ca4dcf0f97aaeb0a01f21aedc0b919410a00d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 18:04:02 +0800 Subject: [PATCH 1431/1892] refine logging of parser --- cube/graph/function/dimops.py | 10 ++++------ cube/graph/function/function.py | 20 +++++++++----------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index aa142a3d..cdac9fdd 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -64,16 +64,13 @@ from typing import Callable, Dict, Iterable, List, Union, Set, Tuple, Optional import enum -import importlib import re import string -import warnings +import logging from cube.ir.cten import IRTensor, IRObject -from cube.ir.dtype import DTypeInferRule from cube.ir.operator import IRFwOperation from cube.algorithm.factory import DistAlgorithmFactory -from cube.ir.tensor import IRSubTensor _kSpecialIdentifiers = ('*', '?') @@ -665,8 +662,9 @@ def infer_shape(self) -> bool: shape_anno = self.oanno(oidx) if str(shape_anno) == '?': assert isinstance(otensor, IRObject), f"expect IRObject for unknown shape, get {otensor}" - warnings.warn('detect IRObject output in a IRDimops, please ensure the annotation is' - 'correct w.r.t the partition policy.') + logging.getLogger('cube.parser').warn( + 'detect IRObject output in a IRDimops, please ensure the annotation is ' + 'correct w.r.t the partition policy.') continue shape = [] for odim in range(shape_anno.ndims): diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 64e11298..fdd36040 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -5,15 +5,14 @@ import operator import numpy as np import math -import warnings -import functools +import logging from collections.abc import Iterable from cube.ir.cten import IRTensor, IRObject from cube.ir.tensor import IRSubTensor, IRFullTensor from cube.ir.dtype import IRDType from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.function.dimops import DimAnno, DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule +from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D from cube.graph.function.anchor import IRGraphAnchor @@ -56,10 +55,8 @@ def Linear(input, weight, bias=None, signature = None): return IRDimops(Linear, 'linear', signature, annos, [input, weight], bias=None) else: annos = ['b * k^, n k^, n -> b * n'] - # rules = [TransformRule( - # [DimopSplit.D(-1), DimopSplit.D(1), DimopSplit.V()], [DimopSplit.V()] - # )] - warnings.warn('detected a linear operator has bias, the partition on reduction dimension is disabled.') + logging.getLogger('cube.parser').warn( + 'detected a linear operator has bias, the partition on reduction dimension is disabled.') return IRDimops(Linear, 'linear', signature, annos, [input, weight, bias]) @@ -305,7 +302,6 @@ def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: rhs_shape[dim-rofst] = '1' else: raise ValueError(f"cannot broadcast lhs: {lhs.shape} and rhs: {rhs.shape}") - # print(lhs.shape, rhs.shape, lhs_shape, rhs_shape, out_shape) return lhs_shape, rhs_shape, out_shape @@ -1665,8 +1661,9 @@ def ShapeAsTensor(input: IRTensor, signature = None): """ torch._shape_as_tensor """ - warnings.warn('shape_as_tensor is interpreted as an IRPyFunc' - ' and generate an IRObject instead of IRTensor') + logging.getLogger('cube.parser').warn( + 'shape_as_tensor is interpreted as an IRPyFunc ' + 'and generate an IRObject instead of IRTensor') signature = 'torch._shape_as_tensor' return IRPyFunc(signature, [input], [IRObject(name='shape', value=input.shape)]) edim_in = ShapeAnno.create_shape_str(input.shape) @@ -1760,7 +1757,8 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], return torch.device('cpu') if name == 'layout': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - warnings.warn('hack currently, please ensure the input tensor is in torch.strided layout') + logging.getLogger('cube.parser').warn( + "getattr of 'layout' will always return torch.strided") return torch.strided if isinstance(obj, torch.finfo): return getattr(obj, name) From bbb74f4fb697f98167c7fa01c11c8b3b61a8af34 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 18:20:58 +0800 Subject: [PATCH 1432/1892] refine logging of profiler --- cube/profiler/database.py | 7 ++++--- cube/profiler/estimator.py | 3 ++- cube/profiler/memory.py | 11 ++++++++--- cube/profiler/timer.py | 2 +- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 9dc22f01..b716fe64 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -7,6 +7,7 @@ import time import os import json +import logging import _operator import cube @@ -252,20 +253,20 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b residual_mem += t.byte_size() in_mem_info.append(t.byte_size()) else: - print(f'WARNING: input {t} is skipped.') + logging.getLogger('cube.profiler').warn('skip input {t}') # run profiling try: fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = \ CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) except: - print(f'WARNING: fail to profile {node}') + logging.getLogger('cube.profiler').error('fail to profile {node}') fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = 0, 0, 0, [], [] # log to database key = self._serialize(node) self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span,\ infer_memory, train_mem_info, residual_mem, train_mem2in_idx) - print( + logging.getLogger('cube.profiler').info( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | fw: {round(fw_span, 2)} ms | " f"bw: {round(bw_span, 2)} ms | infer mem: {infer_memory} | train mem info: {train_mem_info} | idx: {train_mem2in_idx}") diff --git a/cube/profiler/estimator.py b/cube/profiler/estimator.py index 3b8fcb89..29869be5 100644 --- a/cube/profiler/estimator.py +++ b/cube/profiler/estimator.py @@ -1,6 +1,7 @@ from typing import Union, Tuple import sys import os +import logging from cube.ir.operator import IRFwOperation from cube.graph.segment import IRSegment @@ -40,7 +41,7 @@ def __call__(self, nodes_or_segment: Union[Tuple[IRFwOperation], IRSegment], except Exception as e: color, default = '\033[31m', '\033[0m' error_msg = f'fail to run node: {node}\nerror: {e}' - print(f'{color}{error_msg}{default}', file=sys.stderr) + logging.getLogger('cube.profiler').error(f'{color}{error_msg}{default}') fw_span, bw_span, infer_mem, train_mem_info = 0, 0, 0, [0] if train: diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index 89719625..204b472e 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -1,6 +1,7 @@ from typing import Any, List +import logging +from cube.utils import print_each_rank import torch -from cube.profiler.timer import print_each_rank def memory_summary(): @@ -10,6 +11,7 @@ def memory_summary(): # mem = torch.cuda.max_memory_reserved() print_each_rank( '{:.2f} GB memory consumption'.format(mem / 1024 / 1024 / 1024), + logger_fn=logging.getLogger('cube.profiler').info ) return mem @@ -28,6 +30,7 @@ def model_summary(model: torch.nn.Module, inputs: List[Any], do_eval=False, max_ Make sure all of these attributes are not used in modules. """ + logger = logging.getLogger('cube.profiler') torch.cuda.empty_cache() static_memory = torch.cuda.memory_allocated() print_each_rank( @@ -45,7 +48,8 @@ def before_forward(module, input): name = module.__class__.__name__ module._summary_begin_end = True prefix = ' ' * module._summary_depth + '[Begin] > ' - print_each_rank(prefix + '{}:'.format(name), rank_only=0) + print_each_rank(prefix + '{}:'.format(name), rank_only=0, + logger_fn=logger.info) if module._summary_depth < max_depth: module._summary_memory_state = torch.cuda.memory_allocated() stat['depth'] += 1 @@ -67,7 +71,8 @@ def after_forward(module, input, output): prefix += '[End] > ' if module._summary_begin_end else '> ' print_each_rank( prefix + '{}: Mem {:,.2f} MB, Params: {:,} ({:,.2f} MB if fp32)'.format( - name, mem_consumption, n_params, n_params / 1024 / 1024 * 4), rank_only=0) + name, mem_consumption, n_params, n_params / 1024 / 1024 * 4), + rank_only=0, logger_fn=logger.info) handle_pre = torch.nn.modules.module.register_module_forward_pre_hook(before_forward) handle_after = torch.nn.modules.module.register_module_forward_hook(after_forward) diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index 62b7a873..a1dd1ada 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -113,7 +113,7 @@ def duration(self, times: int, field_name: str = 'default') -> float: @return span float: wall clock in milliseconds. """ if field_name not in self.instance.field: - logging.getLogger('profiler').warning(f"CudaTimer: {field_name} doesn't record.") + logging.getLogger('cube.profiler').warn(f"CudaTimer: {field_name} doesn't record.") return 0.0 if len(self.instance.field[field_name]) != 0: raise RuntimeError(f"timer for field {field_name} not stopped") From 41ac790e2e8a3aa7acb9120f0b2f8bf52ee0c1ea Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 18:21:56 +0800 Subject: [PATCH 1433/1892] refine logging of execplan --- cube/compiler.py | 2 +- cube/execplan/planpass/fusion.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 075f0204..a725be02 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -174,7 +174,7 @@ def decorator(fn: Callable) -> Callable: if graph.sched is not None: start = time.time() graph.sched.apply() - logging.getLogger('cube.schedule').info(f'schedule:\n{graph.sched}') + logging.getLogger('cube.execplan').info(f'schedule:\n{graph.sched}') span = time.time() - start logger.info('finish planpass on applying schedule strategy: {:.2f} s'.format(span)) diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index fefcb17b..2383e584 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -1,4 +1,5 @@ from typing import List, Union, Set +import logging from cube.graph.graph import IRSegment from cube.ir.adapter import IRAdapter @@ -44,7 +45,8 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: ret = DiffFusion.nnfuse(fadapter) cnt = cnt+1 if ret else cnt visited.add(node) - print(f'successfully generate {cnt} differentiable adapters') + logging.getLogger('cube.execplan').info( + f'adapter fusion: successfully fuse {cnt} differentiable adapters') return execplan @staticmethod From af6c971d749ecdd7dfdfda31c7ff49acb1cc68cd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 18:22:33 +0800 Subject: [PATCH 1434/1892] refine logging of profiler --- cube/__init__.py | 2 +- cube/flags.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/cube/__init__.py b/cube/__init__.py index 1421b8a2..9fe97b7e 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -50,7 +50,7 @@ def _init_logger(): level(CompileFlag.log_runtime) ) logging.getLogger('cube.profiler').setLevel( - level(CompileFlag.log_profiler) + logging.INFO ) logging.getLogger('cube.compiler').setLevel( logging.INFO diff --git a/cube/flags.py b/cube/flags.py index b548bc14..47d79762 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -30,8 +30,6 @@ class CompileFlag: log_codegen = _to_bool('LOG_CODEGEN') # log the runtime information log_runtime = _to_bool('LOG_RUNTIME') - # log the profiling information - log_profiler = _to_bool('LOG_PROFILER') # ================ compiling ======================== use_torchfx = _to_bool('USE_TORCHFX') # using torch.fx or torchscript as frontend to capture dataflow graph From b8719e164a9c16b243b778fddf156f067793f067 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 18:22:54 +0800 Subject: [PATCH 1435/1892] refine logging of runtime --- cube/runtime/adapter/reducer.py | 13 ++++++------- cube/runtime/device.py | 4 +++- cube/runtime/executor.py | 4 ++-- cube/runtime/function/function.py | 5 ++++- cube/runtime/module.py | 22 ++++++++++------------ cube/runtime/syndata.py | 4 ++-- 6 files changed, 27 insertions(+), 25 deletions(-) diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 9fa5e46a..4e498501 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -1,6 +1,6 @@ from typing import List, Dict, Tuple, Any, Callable, Optional, Set from functools import partial -import warnings +import logging import torch from torch.utils.hooks import RemovableHandle @@ -182,8 +182,8 @@ def sync_grads(self): # async if self._async: if CudaTimer().enabled and CudaTimer().predefined: - warnings.warn(f'CudaTimer: the communication time of async ' - f'reducer will not be recorded in `comm`') + logging.getLogger('cube.runtime').warn( + f'CudaTimer: the communication time of async reducer will not be recorded in `comm`') assert self._work is not None self._work.wait() else: @@ -351,9 +351,8 @@ def add_param(self, param: torch.nn.Parameter): @param param torch.nn.Parameter: the added parameter """ if param.data.data_ptr() in self._param_ids: - warnings.warn( - f'rank [{torch.distributed.get_rank()}]: detected duplicated or shared parameters, ignored.', - category=RuntimeWarning) + logging.getLogger('cube.runtime').warn( + f'rank [{torch.distributed.get_rank()}]: detected duplicated or shared parameters, ignored.') return self._params.append(param) self._param_ids.add(param.data.data_ptr()) @@ -381,7 +380,7 @@ def build_buckets(self): dtype2size[tp] = cur_byte_size else: if cur_byte_size > bucket_size: - warnings.warn(f'find one parameter {param.shape} ({cur_byte_size} bytes) larger than bucket size {self._bucket_size}') + logging.getLogger('cube.runtime').warn(f'find one parameter {param.shape} ({cur_byte_size} bytes) larger than bucket size {self._bucket_size}') buckets[tp].insert(0, [param]) elif dtype2size[tp] + cur_byte_size <= bucket_size: dtype2size[tp] = dtype2size[tp] + cur_byte_size diff --git a/cube/runtime/device.py b/cube/runtime/device.py index a8065400..39131006 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -5,6 +5,7 @@ import numpy as np import torch import os +import logging from cube.flags import CompileFlag @@ -15,7 +16,8 @@ class __DeviceGroup: def __init__(self): if CompileFlag.dev_mode: - print(f"DeviceGroup init using single device mode...") + logging.getLogger('cube.runtime').info( + f"DeviceGroup init using single device mode...") self.rank = 0 self.world_size = 1 self.local_world_size = 1 diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index c7dd8f0b..b3e8d0f5 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -5,7 +5,7 @@ from typing import Tuple, Any, Callable, List, Dict, Optional import torch -import warnings +import logging def debug_id(tensors, msg: str, rank: int): @@ -154,7 +154,7 @@ def backward(name: str, for t in input_tensors: if id(t) not in tensor_ids: import traceback - warnings.warn( + logging.getLogger('cube.runtime').warn( f"rank {torch.distributed.get_rank()}: input {name} doesn't match. " f"Make sure in scheduling, earlier forward perform earlier backward. " f"Remain {len(Executor._detach[name])} segments.\n" diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index eb41e226..da4df752 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -1,11 +1,14 @@ from typing import Optional, List, Tuple, Union import torch import torch.nn.functional as TorchF +import logging + +# TODO: move to registered function try: from apex.normalization.fused_layer_norm import fused_layer_norm_affine except: - print('WARNING: apex is not installed, skip it.') + logging.getLogger('cube.runtime').warn('skip apex ops as it is not installed.') def identity(tensor: torch.Tensor) -> torch.Tensor: diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 11f25a21..5ea79a9f 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -114,7 +114,8 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref filename_prefix = 'dist_checkpoint' if filename_prefix is None else filename_prefix filename = f"{filename_prefix}-{DeviceGroup().rank}.ckpt" state_dict, dist_param_map, param_area_map, optimizer_state_dict = self.get_checkpoint(optimizer) - print(f'> Saving distributed checkpoint to {filename}') + + logging.getLogger('cube.runtime').info(f'saving distributed checkpoint to {filename}') torch.save({ 'state_dict': state_dict, 'dist_param_map': dist_param_map, @@ -134,12 +135,13 @@ def merge_partial_states(state_dicts, zero_idx_maps=None): return state_dicts[0][0], state_dicts[0][1] plan_ngpus = -1 + # TODO: remove this flag if 'PLAN_NGPUS' in os.environ: plan_ngpus = int(os.environ['PLAN_NGPUS']) assert plan_ngpus >= 1, plan_ngpus assert plan_ngpus <= len(state_dicts), f'plan_ngpus = {plan_ngpus}, len(state_dicts) = {len(state_dicts)}' assert len(state_dicts) % plan_ngpus == 0, f'plan_ngpus = {plan_ngpus}, len(state_dicts) = {len(state_dicts)}' - logging.info(f'plan_ngpus = {plan_ngpus}') + logging.getLogger('cube.runtime').info(f'plan_ngpus = {plan_ngpus}') # at first, merge the partitioned optimizer states due to zero to the zero-disabled format if zero_idx_maps is not None: @@ -153,7 +155,6 @@ def _check_state_size(opt_state_keys, bucket_state): def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): assert bucket_size % len(bucket_states) == 0 opt_state_keys = list(bucket_states[0].keys()) - print(bucket_states[0], opt_state_keys) if 'step' in bucket_states[0]: opt_state_keys.remove('step') assert _check_state_size(opt_state_keys, bucket_states[0]), f'the keys {opt_state_keys} have different shape' @@ -213,7 +214,7 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' else: if plan_ngpus > 0: - logging.warning(f'plan_ngpus {plan_ngpus} not handled USE_ZERO == False') + logging.getLogger('cube.runtime').warn(f'plan_ngpus {plan_ngpus} not handled USE_ZERO == False') def _check_opt_state(opt_state): cnt = 0 sorted_opt_state = {} @@ -232,7 +233,6 @@ def _check_opt_state(opt_state): if len(zero_idx2model_idx) == 0: assert len(state_dicts[work_idx][1]['state']) == 0 for local_idx, val in state_dicts[work_idx][1]['state'].items(): # worker / last_optimizer_state / state - print(f'{work_idx}, {local_idx}') global_idx = zero_idx2model_idx[local_idx] assert global_idx not in opt_state opt_state[global_idx] = val @@ -265,7 +265,7 @@ def _check_opt_state(opt_state): raw_name = dist_param_map[local_name] slices = param_area[1][1] if param_area[1][2] != 1: - print(f'TODO: value-split on {raw_name}') + logging.getLogger('cube.runtime').error(f'value-split on {raw_name} is not supported') if raw_name in param_max_dimsize: param_max_dimsize[raw_name] = max(param_max_dimsize[raw_name], slices) else: @@ -278,11 +278,11 @@ def _check_opt_state(opt_state): for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: if len(optimizer_state_dict['state'].items()) > 0: optimizer_state_names = list(optimizer_state_dict['state'][0].keys()) - print(f'optimizer_state_names = {optimizer_state_names}') + logging.getLogger('cube.runtime').info(f'optimizer_state_names = {optimizer_state_names}') if 'step' in optimizer_state_names: sample_step = optimizer_state_dict['state'][0]['step'] optimizer_state_names.remove('step') - print(f'optimizer_state_names (without step) = {optimizer_state_names}') + logging.getLogger('cube.runtime').info(f'optimizer_state_names (without step) = {optimizer_state_names}') else: optimizer_state_names = [] @@ -311,9 +311,7 @@ def _check_opt_state(opt_state): optim_full_tensors[index] = {} optim_full_tensors[index][state_name] = torch.zeros(tuple(tensor_size)) else: - print(f'INFO: merge_checkpoint skips {local_name_with_id}\'s optimizer state') - # print(f'param_full_tensors = {param_full_tensors}') - # print(f'optim_full_tensors = {optim_full_tensors}') + logging.getLogger('cube.runtime').info(f'merge_checkpoint skips {local_name_with_id}\'s optimizer state') break # only create once # assign value @@ -349,7 +347,7 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): for rank in range(DeviceGroup().world_size): filename = f"{filename_prefix}-{rank}.ckpt" ckpts[rank] = torch.load(filename) - print(f'checkpoints = {ckpts}') + logging.getLogger('cube.runtime').info(f'checkpoints = {ckpts}') state_dicts = [] for ckpt in ckpts.values(): diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 50103dee..37978c69 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional, Tuple, Union import torch -import warnings +import logging class CubeDataLoader: @@ -192,6 +192,6 @@ def set_batch_size(self, batch_size: int): for shape, dim in zip(self.shapes, self.batch_dims): shape[dim] = batch_size rank = 0 if not torch.distributed.is_initialized() else torch.distributed.get_rank() - print(f'rank [{rank}]: > set batch size to {batch_size}. dataloader outputs change to: {self.shapes}') + logging.getLogger('cube.runtime').info(f'rank [{rank}]: set batch size to {batch_size}. dataloader outputs change to: {self.shapes}') datas = self.random_sample() self.set_output(datas) From c7d2d10feea6fa8c4bfe1b87f20b68780cbb605d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 13 Jul 2023 19:25:55 +0800 Subject: [PATCH 1436/1892] better logger format --- cube/__init__.py | 7 ++++++- cube/graph/gener/gen.py | 10 +++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/cube/__init__.py b/cube/__init__.py index 9fe97b7e..265beb93 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -27,7 +27,12 @@ def init(): def _init_logger(): - logging.basicConfig(level=logging.WARN) + + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" + ) level = lambda flag: logging.INFO if flag else logging.WARN diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index cb6296bb..2abe703d 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -317,13 +317,9 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # FIXME: assume producers and consumers can run in parallel for ftensor in ftensors: - logging.getLogger('cube.adapter').debug( - f'generate adapter for forward tenosrs:\n' - f'{graph.debug_tensor_map_str(ftensor)}') - if ftensor.grad is not None: - logging.getLogger('cube.adapter').debug( - f'generate adapter for backward tenosrs:\n' - f'{graph.mirror.debug_tensor_map_str(ftensor.grad)}') + # debug + # print(f'forward:\n{graph.debug_tensor_map_str(ftensor)}') + # print(f'backward:\n{graph.mirror.debug_tensor_map_str(ftensor.grad)}') # producers can be operators and graph inputs fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) From 8b474585bc2ef95644fba2dfc87713562afde7f2 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Fri, 14 Jul 2023 00:50:05 +0000 Subject: [PATCH 1437/1892] Merged PR 1656: support group ZeRO The allgather of the updated weights in ZeRO is executed in subgroups of a reducer rank group. env variable ZERO_FACTOR is used to decide the number of such subgroups. parameters (i.e., optimizer states) will be divided across the subgroups instead of the reducer rank group. --- cube/codegen/module/module.py | 21 +++++++++++-- cube/flags.py | 7 ++++- cube/runtime/adapter/reducer.py | 55 ++++++++++++++++++++++++++------- cube/runtime/device.py | 7 +++++ 4 files changed, 75 insertions(+), 15 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 35c5b7e3..2ec5d7f1 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -162,6 +162,15 @@ def get_comm_groups(self, scale_ndevs: Optional[int] = None): @param scale_ndevs Optional[int]: scale to number of devices """ + def _add_comm_for_group_zero(ranks): + zero_comm_groups = [] + for i in range(CompileFlag.zero_ngroups): + assert len(ranks) % CompileFlag.zero_ngroups == 0 + ranks_per_group = len(ranks) // CompileFlag.zero_ngroups + zero_subgroup = tuple(ranks[i * ranks_per_group : (i + 1) * ranks_per_group]) + if len(zero_subgroup) > 1 and len(zero_subgroup) < len(ranks): + zero_comm_groups.append(zero_subgroup) + return zero_comm_groups scale_ndevs = scale_ndevs if scale_ndevs is not None else len(self.devices) assert len(self.devices) == max(self.devices) + 1, f'device must be consecutive' assert scale_ndevs % len(self.devices) == 0, f'ngpus must be a multiple of {len(self.devices)}' @@ -176,11 +185,15 @@ def get_comm_groups(self, scale_ndevs: Optional[int] = None): for device in reducer.device) ranks = tuple(sorted(ranks)) comm_groups.append(ranks) + # add comm groups for group ZeRO + comm_groups.extend(_add_comm_for_group_zero(ranks)) # communication groups for parameters that are outside reducers for device in self.devices: ranks = list(range(device, scale_ndevs, len(self.devices))) if len(ranks) > 1: comm_groups.append(ranks) + # add comm groups for group ZeRO + comm_groups.extend(_add_comm_for_group_zero(ranks)) # communication groups for activations adapters = graph.select(ntype=IRAdapter) for adapter in adapters: @@ -425,12 +438,14 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: max_nbytes = CompileFlag.max_reducer_bucket async_op = CompileFlag.async_reducer zero = CompileFlag.use_zero + zero_ngroups = CompileFlag.zero_ngroups reduce_op = f"'{CompileFlag.reducer_op}'" # reducer init interface reducer_init = ( "{reducer} = cube.runtime.adapter.Reducer(" "ranks={ranks}, reduce_op={reduce_op}, " - "async_op={async_op}, zero={zero}, max_bucket_size_bytes={max_nbytes})" + "async_op={async_op}, zero={zero}, max_bucket_size_bytes={max_nbytes}, " + "zero_ngroups={zero_ngroups})" ) reducer_add = 'self.add_reducer({reducer})' add_param = '{reducer}.add_param({weight})' @@ -440,8 +455,8 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: self.model_init_statements.append('') ranks = list(sorted(node.device)) init_code = reducer_init.format( - reducer=reducer_name, ranks=ranks, reduce_op=reduce_op, - async_op=async_op, zero=zero, max_nbytes=max_nbytes) + reducer=reducer_name, ranks=ranks, reduce_op=reduce_op, + async_op=async_op, zero=zero, max_nbytes=max_nbytes, zero_ngroups=zero_ngroups) self.model_init_statements.append(init_code) weights = [ModuleCodeGen.tensor_name(t, prefix_attr='self.') for t in weights] for weight in weights: diff --git a/cube/flags.py b/cube/flags.py index 18cdb3b7..5fdba724 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -52,7 +52,12 @@ class CompileFlag: max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=137217728) # perform reducer op on gradients, can be sum, avg, mean, max, min. Default is sum reducer_op = os.environ.get('REDUCER_OP', default='sum') - + # zero_ngroups is the number of subgroups in each original ZeRO gruop (e.g., weights reducer) + # ZeRO subgroup is obtained by dividing the original ZeRO group by zero_ngroups + # it helps reduce communication cost of allgather weights in ZeRO, but increase the weights' + # optimization states on each GPU. + zero_ngroups = _to_int('ZERO_NUM_GROUPS', default=1) + # use automate mixture precision training, where weights, gradients # and optimizer status are kept in its original data type (can be float32), # but some of the forward operators will be converted to float16. diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 9fa5e46a..737782b8 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -38,7 +38,8 @@ class Bucket: def __init__(self, params: List[torch.nn.Parameter], param_buffer: torch.Tensor, grad_buffer: torch.Tensor, reduce_op: torch.distributed.ReduceOp, - group, async_op: bool, zero: bool): + group, async_op: bool, zero: bool, + zero_subgroup: torch.distributed.ProcessGroup = None): """ Create a communication unit for parameter allreduce. @@ -52,6 +53,7 @@ def __init__(self, params: List[torch.nn.Parameter], @param group: communication group @param async_op bool: whether to use asynchronous operation @param zero bool: whether to use zero optimization on gradients + @param zero_subgroup: the subgroup for zero optimization the current rank belongs to """ self._params: List[torch.nn.Parameter] = params @@ -75,6 +77,9 @@ def __init__(self, params: List[torch.nn.Parameter], self._numel: int = sum(p.numel() for p in self._params) self._padding: int = self._contiguous_grads.size(0) - self._numel + self._zero_subgroup = self._group if zero_subgroup is None else zero_subgroup + self._zgroup_sz: int = torch.distributed.get_world_size(group=self._zero_subgroup) + # pre and post hooks for gradient synchronization self._pre_hooks: List[Callable] = [] self._post_hooks: List[Callable] = [] @@ -113,9 +118,9 @@ def build(self): if not self._zero: opt = self._contiguous_params[:self._numel] else: - rank = torch.distributed.get_rank(group=self._group) - opt = self._contiguous_params.chunk(self._wsz)[rank] - if rank == self._wsz - 1 and self._padding != 0: + rank = torch.distributed.get_rank(group=self._zero_subgroup) + opt = self._contiguous_params.chunk(self._zgroup_sz)[rank] + if rank == self._zgroup_sz - 1 and self._padding != 0: opt = opt[:-self._padding] self._param_for_optimizer = torch.nn.Parameter(opt) @@ -207,8 +212,9 @@ def sync_grads(self): # setup gradient for optimizer parameters if self._zero: - grad = self._contiguous_grads.chunk(self._wsz, dim=0)[rank] - if rank == self._wsz - 1 and self._padding != 0: + rank = torch.distributed.get_rank(group=self._zero_subgroup) + grad = self._contiguous_grads.chunk(self._zgroup_sz, dim=0)[rank] + if rank == self._zgroup_sz - 1 and self._padding != 0: grad = grad[:-self._padding] self._param_for_optimizer.grad = grad else: @@ -222,9 +228,9 @@ def gather_params(self): All-gather parameters """ assert self._zero, "gathering paramters is only for zero optimization." - rank = torch.distributed.get_rank(group=self._group) - shards = list(self._contiguous_params.chunk(self._wsz, dim=0)) - torch.distributed.all_gather(shards, shards[rank], group=self._group) + rank = torch.distributed.get_rank(group=self._zero_subgroup) + shards = list(self._contiguous_params.chunk(self._zgroup_sz, dim=0)) + torch.distributed.all_gather(shards, shards[rank], group=self._zero_subgroup) def register_pre_hook(self, fn: Callable): """Register pre hooks to be applied before gradient synchronization. @@ -285,7 +291,8 @@ def reset(self): class Reducer: def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, - reduce_op: str = 'sum', async_op: bool = False, zero: bool = False): + reduce_op: str = 'sum', async_op: bool = False, + zero: bool = False, zero_ngroups: int = 1): """ Create a reducer applied on a set of weights for weight reduction @@ -296,7 +303,8 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, only work for asynchronous reducer. @param reduce_op str: reduce operation, can be 'sum', 'avg', 'max' or 'min' (default 'sum') @param async_op bool: whether to overlap with backward computation (default False) - @param zero bool: whether to apply zero optimization on gradients + @param zero bool: whether to apply ZeRO optimization on gradients + @param zero_ngroups int: number of ZeRO subgroups in the original ZeRO group """ self._params: List[torch.nn.Parameter] = list() self._param_ids: Set[int] = set() @@ -313,6 +321,30 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, self._contiguous_params: torch.Tensor = None self._contiguous_grads: torch.Tensor = None + # build the subgroup of zero the current rank belongs to. + # When zero_ngroups is larger than 1, the number of ranks + # will be divided by zero_ngroups into sub rank groups, + # allgather of weights will be done within each subgroup. + # For example, if the ranks are [0, 1, 2, 3, 4, 5, 6, 7] and zero_ngroups=2, + # the ranks will be divided into [0, 1, 2, 3] and [4, 5, 6, 7]. + # If the ranks are [0, 2, 4, 6], zero_ngroups=2, then the ranks + # will be divided into [0, 2] and [4, 6]. + if self._zero and Bucket.use_reduce_scatter_for_zero: + assert zero_ngroups == 1, f"zero_ngroups {zero_ngroups}, which is >1, does not support reduce scatter" + if zero_ngroups > 1: + assert self._zero, f"USE_ZERO must be set when ZERO_NUM_GROUPS is larger than 1" + assert len(ranks) % zero_ngroups == 0, f"length of ranks {ranks} must be divisible by zero factor {zero_ngroups}" + curr_rank = torch.distributed.get_rank(group=self._group) + zgroup_sz = len(ranks) // zero_ngroups + group_idx = curr_rank // zgroup_sz + sub_ranks = ranks[group_idx * zgroup_sz : (group_idx + 1) * zgroup_sz] + if len(sub_ranks) > 1: + assert DeviceGroup().group_exists(sub_ranks), f"zero subgroup {sub_ranks} does not exist in comm groups" + self._zero_subgroup = DeviceGroup().get_group(sub_ranks) + else: + assert zero_ngroups == 1, f"zero factor must be 1, but got {zero_ngroups}" + self._zero_subgroup = self._group + @property def params(self) -> Tuple[torch.nn.Parameter]: return tuple(self._params) @@ -437,6 +469,7 @@ def build_buckets(self): self._group, self._async, self._zero, + self._zero_subgroup, ) buckets.append(bucket) torch.cuda.empty_cache() diff --git a/cube/runtime/device.py b/cube/runtime/device.py index a8065400..a3df99c1 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -51,6 +51,13 @@ def __getattr__(self, name): def __len__(self, name): return DeviceGroup.instance.world_size + def group_exists(self, ranks): + """ + Check if group exists + """ + rank_bits = DeviceGroup.bitmap(ranks) + return rank_bits in self.instance.groups + def get_group(self, ranks): """ Create and return rank groups on-demand From 35d14322e3738f1cedc325f5ce7b3f945b2198be Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 14 Jul 2023 10:29:57 +0800 Subject: [PATCH 1438/1892] refine logging code --- cube/__init__.py | 27 --- cube/codegen/module/module.py | 8 +- cube/codegen/schedule/schedule.py | 12 +- cube/compiler.py | 39 +++-- cube/graph/function/function.py | 12 +- cube/graph/gener/concurrent.py | 14 +- cube/graph/gener/gen.py | 1 - cube/graph/graph.py | 29 ++-- cube/profiler/database.py | 10 +- cube/profiler/estimator.py | 4 +- cube/profiler/memory.py | 16 +- cube/profiler/timer.py | 5 +- cube/runtime/adapter/reducer.py | 12 +- cube/runtime/device.py | 6 +- cube/runtime/executor.py | 7 +- cube/runtime/function/__init__.py | 1 - cube/runtime/function/dist.py | 280 ------------------------------ cube/runtime/function/function.py | 7 +- cube/runtime/module.py | 26 +-- cube/runtime/syndata.py | 10 +- 20 files changed, 138 insertions(+), 388 deletions(-) delete mode 100644 cube/runtime/function/dist.py diff --git a/cube/__init__.py b/cube/__init__.py index 265beb93..a524ccaa 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -34,33 +34,6 @@ def _init_logger(): datefmt="%Y-%m-%d %H:%M:%S" ) - level = lambda flag: logging.INFO if flag else logging.WARN - - logging.getLogger('cube.parser').setLevel( - level(CompileFlag.log_parser) - ) - logging.getLogger('cube.prim').setLevel( - level(CompileFlag.log_prim) - ) - logging.getLogger('cube.adapter').setLevel( - level(CompileFlag.log_adapter) - ) - logging.getLogger('cube.execplan').setLevel( - level(CompileFlag.log_execplan) - ) - logging.getLogger('cube.codegen').setLevel( - level(CompileFlag.log_codegen) - ) - logging.getLogger('cube.runtime').setLevel( - level(CompileFlag.log_runtime) - ) - logging.getLogger('cube.profiler').setLevel( - logging.INFO - ) - logging.getLogger('cube.compiler').setLevel( - logging.INFO - ) - def set_logger_level(name: Optional[str], level): """Set the logger level of cube. diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 35c5b7e3..7e54acca 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional, Tuple import more_itertools -import warnings +import logging import copy import torch import numpy as np @@ -27,6 +27,10 @@ from cube.flags import CompileFlag +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_codegen else logging.WARN) + + class ModuleCodeGen(FuncEmission): """ Generate module code @@ -458,7 +462,7 @@ def init_batchsize(self, node: IRDataOperation): bs = [t.shape[dim] for t, dim in zip(node.outputs(), node.get_batch_dims()) if dim is not None] bs = set(bs) if len(bs) > 1: - warnings.warn(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') + _logger.warn(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') bs = list(bs)[0] if len(bs) == 1 else None assert self.batch_size is None or self.batch_size == bs, f"Not match for batch size: {self.batch_size} != {bs}" self.model_init_statements.append(signature.format(bs=bs)) diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index c50d34e5..c7c58275 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -1,7 +1,7 @@ from typing import List, Dict, Any, Optional, Tuple import copy -import warnings +import logging from cube.ir.cten import IRCell, IRTensor from cube.ir.operator import IRDataOperation, IRFwOperation @@ -18,6 +18,12 @@ from cube.codegen.lifecycle import LifeCycle from cube.codegen.syntax.blocks import FunctionBlock, ForBlock +from cube.flags import CompileFlag + + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_codegen else logging.WARN) + class ScheduleCodeGen(FuncEmission): @@ -88,8 +94,8 @@ def gen(self, device: int, outfile=None, attach=None) -> str: else: # legacy hardcode strategy if isinstance(self.execplan.graph.sched, IRScheduleStrategy): - warnings.warn('using legacy IRScheduleStrategy cannot generate inference code. ' - 'Switch to use scheduling without strategy') + _logger.warn('using legacy IRScheduleStrategy cannot generate inference code. ' + 'Switch to use scheduling without strategy') with FunctionBlock(func_name='_infer_step', args=args) as fb: fb.insert_body('_ = None') diff --git a/cube/compiler.py b/cube/compiler.py index a725be02..9c3f5435 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -30,6 +30,10 @@ from cube.utils import print_each_rank +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + def compile(model: SemanticModel, *args, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, model_dummy_inputs: Tuple[Any] = None, @@ -71,7 +75,6 @@ def train_iter(model, dataloader): Returns: Callable: compiled training iteration """ - logger = logging.getLogger('cube.compiler') # clean global status Program().clear() IDGenerator().clear() @@ -108,12 +111,12 @@ def decorator(fn: Callable) -> Callable: if not override and os.path.exists(filename.format(myrank)): filename = filename.format(myrank) # TODO: set batch size - logger.warning('dataloader batch size stay as default.') + _logger.warning('dataloader batch size stay as default.') # load module code - logger.info(f'loading existed module from {filename} ...') + _logger.info(f'loading existed module from {filename} ...') model.load_module(filename) # load schedule code - logger.info(f'loading existed schedule from {filename} ...') + _logger.info(f'loading existed schedule from {filename} ...') return cube.load_default_schedule(filename) if DeviceGroup().local_rank == 0: @@ -142,7 +145,7 @@ def decorator(fn: Callable) -> Callable: # setup program output Program().set_output(outputs) span = time.time() - start - logger.info('finish parsing iteration: {:.2f} s'.format(span)) + _logger.info('finish parsing iteration: {:.2f} s'.format(span)) # run policy start = time.time() @@ -150,7 +153,7 @@ def decorator(fn: Callable) -> Callable: assert callable(PAS), f"Policy PAS is not callable" graph = PAS(graph, resource) span = time.time() - start - logger.info('finish policy expression: {:.2f} s'.format(span)) + _logger.info('finish policy expression: {:.2f} s'.format(span)) if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") @@ -169,14 +172,14 @@ def decorator(fn: Callable) -> Callable: start = time.time() graph = IRAdapterGener.gen(graph, cost_fn=comm_cost_fn) span = time.time() - start - logger.info('finish generating adapters: {:.2f} s'.format(span)) + _logger.info('finish generating adapters: {:.2f} s'.format(span)) if graph.sched is not None: start = time.time() graph.sched.apply() - logging.getLogger('cube.execplan').info(f'schedule:\n{graph.sched}') + _logger.debug(f'schedule:\n{graph.sched}') span = time.time() - start - logger.info('finish planpass on applying schedule strategy: {:.2f} s'.format(span)) + _logger.info('finish planpass on applying schedule strategy: {:.2f} s'.format(span)) # to execution plan start = time.time() @@ -187,13 +190,13 @@ def decorator(fn: Callable) -> Callable: if CompileFlag.visualize_plan: execplan.visualize('plan.png') span = time.time() - start - logger.info('finish lowering to execution plan: {:.2f} s'.format(span)) + _logger.info('finish lowering to execution plan: {:.2f} s'.format(span)) # plan pass for communication optimization start = time.time() execplan = DiffFusion.apply(execplan) span = time.time() - start - logger.info('finish planpass on diff-fusion operations: {:.2f} s'.format(span)) + _logger.info('finish planpass on diff-fusion operations: {:.2f} s'.format(span)) # execplan.visualize(outfile='plan.png') @@ -202,7 +205,7 @@ def decorator(fn: Callable) -> Callable: start = time.time() execplan = Grouping.apply(execplan) span = time.time() - start - logger.info('finish planpass on grouping operations: {:.2f} s'.format(span)) + _logger.info('finish planpass on grouping operations: {:.2f} s'.format(span)) # execplan.graph.reset_dependency() # execplan.analyze(outfile='execplan.png') @@ -227,21 +230,21 @@ def decorator(fn: Callable) -> Callable: attach=True ) span = time.time() - start - logger.info('finish generating code: {:.2f} seconds'.format(span)) + _logger.info('finish generating code: {:.2f} seconds'.format(span)) compile_end = time.time() compile_time = compile_end - compile_start - logger.info('compile time: {:.2f} seconds'.format(compile_time)) + _logger.info('compile time: {:.2f} seconds'.format(compile_time)) if torch.distributed.is_initialized(): if DeviceGroup().local_rank != 0 and CompileFlag.worker_sleep > 0: - logger.info(f'rank [{DeviceGroup().rank}] starts sleeping {CompileFlag.worker_sleep} seconds...') + _logger.info(f'rank [{DeviceGroup().rank}] starts sleeping {CompileFlag.worker_sleep} seconds...') time.sleep(CompileFlag.worker_sleep) torch.distributed.barrier() # load module filename = filename.format(myrank) - print_each_rank(f'loading generated module from {filename} ...', logger_fn=logger.info) + print_each_rank(f'loading generated module from {filename} ...', logger_fn=_logger.info) model.load_module(filename) if torch.distributed.is_initialized(): @@ -251,7 +254,7 @@ def decorator(fn: Callable) -> Callable: # set dataloder batch size (serialize output) if dataloader is not None: bs = model.get_gen_module().get_batch_size() - print_each_rank(f'setting batch size to: {bs}', logger_fn=logger.info) + print_each_rank(f'setting batch size to: {bs}', logger_fn=_logger.info) if torch.distributed.is_initialized(): for rank in range(torch.distributed.get_world_size()): if rank == torch.distributed.get_rank(): @@ -266,7 +269,7 @@ def decorator(fn: Callable) -> Callable: torch.distributed.barrier() # load temporal schedule - print_each_rank(f'loading generated schedule from {filename} ...', logger_fn=logger.info) + print_each_rank(f'loading generated schedule from {filename} ...', logger_fn=_logger.info) return cube.load_default_schedule(filename) return decorator diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index fdd36040..e6d12150 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -16,6 +16,11 @@ from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D from cube.graph.function.anchor import IRGraphAnchor +from cube.flags import CompileFlag + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_parser else logging.WARN) + def Identity(tensor: IRObject, signature = None): signature = 'cube.runtime.function.identity' @@ -55,7 +60,7 @@ def Linear(input, weight, bias=None, signature = None): return IRDimops(Linear, 'linear', signature, annos, [input, weight], bias=None) else: annos = ['b * k^, n k^, n -> b * n'] - logging.getLogger('cube.parser').warn( + _logger.warning( 'detected a linear operator has bias, the partition on reduction dimension is disabled.') return IRDimops(Linear, 'linear', signature, annos, [input, weight, bias]) @@ -1661,7 +1666,7 @@ def ShapeAsTensor(input: IRTensor, signature = None): """ torch._shape_as_tensor """ - logging.getLogger('cube.parser').warn( + _logger.warning( 'shape_as_tensor is interpreted as an IRPyFunc ' 'and generate an IRObject instead of IRTensor') signature = 'torch._shape_as_tensor' @@ -1757,8 +1762,7 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], return torch.device('cpu') if name == 'layout': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - logging.getLogger('cube.parser').warn( - "getattr of 'layout' will always return torch.strided") + _logger.warn("getattr of 'layout' will always return torch.strided") return torch.strided if isinstance(obj, torch.finfo): return getattr(obj, name) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index b4050ad6..b97d2bcf 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -15,18 +15,18 @@ from cube.graph.gener.rvd.layout import RVDLayout from cube.graph.gener.rvd.intra import IntraPathFinder from cube.graph.gener.rvd.inter import InterPathFinder -from cube.graph.gener.utils import DummyInputOuput from cube.flags import CompileFlag -import warnings +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_adapter else logging.WARNING) if CompileFlag.disable_intra_rvd: - warnings.warn('Detected disabling intra-RVD collective generation, which may have big impact on performance.') + _logger.warn('Detected disabling intra-RVD collective generation, which may have big impact on performance.') if CompileFlag.disable_inter_rvd: - warnings.warn('Detected disabling inter-RVD collective generation, which may have big impact on performance.') + _logger.warn('Detected disabling inter-RVD collective generation, which may have big impact on performance.') if CompileFlag.disable_comm_fusion: - warnings.warn('Detected disabling general communication fusion, which may have big impact on performance in certain cases.') + _logger.warn('Detected disabling general communication fusion, which may have big impact on performance in certain cases.') class ConcurrentGener: @@ -66,7 +66,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], f"Switch to general P2P communication.\n" f"===========================================\n{default}" ) - logging.getLogger('cube.adapter').warn(f'intra-RVD:\n{msg}') + _logger.warning(f'intra-RVD:\n{msg}') # Case 2: sperating device (inter-rvd) if (not CompileFlag.disable_inter_rvd) and len(set(pdevs).intersection(cdevs)) == 0: @@ -82,7 +82,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], f"Switch to general P2P communication.\n" f"===========================================\n{default}" ) - logging.getLogger('cube.adapter').warn(f'inter-RVD:\n{msg}') + _logger.warning(f'inter-RVD:\n{msg}') # Case 3: General cases # warnings.warn('The adapter is generated using P2P communication') diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 2abe703d..23e17985 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,7 +1,6 @@ from typing import Dict, List, Optional, Tuple, Callable, Set import numpy as np import itertools -import logging from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 7fb45ea3..2cd22d9d 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -11,7 +11,6 @@ import logging import copy import dill -import sys from cube.ir.cten import IRTensor, IRCell, IRObject from cube.ir.unique import IDGenerator @@ -26,6 +25,11 @@ from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo +from cube.flags import CompileFlag + + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_prim else logging.WARNING) FOp = Union[IRFwOperation, IRDataOperation] @@ -77,14 +81,13 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: # align graph with input tensors itensors: Tuple[IRObject, ...] = self.inputs() if len(args) != len(itensors): - logger = logging.getLogger('cube.parser') - logger.error( + _logger.error( f'cube graph forward: skipping arguments due to len(args) != len(itensors): ' f'{len(args)} != {len(itensors)}' ) if len(args) > len(itensors): args = args[:len(itensors)] - logger.warn(f'cube graph forward: args shrinked into {args}') + _logger.warning(f'cube graph forward: args shrinked into {args}') else: raise RuntimeError('len(args) < len(itensors)') @@ -165,7 +168,7 @@ def backward(self, loss: IRSubTensor): if ftensor.is_loss(): continue consumers = [n for n in self.consumers(ftensor) if isinstance(n, IRFwOperation)] if len(consumers) == 0 and ftensor.requires_grad: - logging.getLogger('cube.parser').warn( + _logger.warning( f"detected a dead ftensor which is not consumed by any nodes:\n\t{ftensor.name}: {ftensor}") ftensor.requires_grad = False @@ -299,12 +302,11 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis raise TypeError("Expected op to be forward op or data op") if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") - logger = logging.getLogger('cube.prim') if node.name == 'multiref': - logger.warn(f'skip replicating multiref ({node.cid}), which will be handled by system.') + _logger.warning(f'skip replicating multiref ({node.cid}), which will be handled by system.') return [node] if isinstance(node, IRPyFunc): - logger.warn(f'skip replicating pyfunc ({node.cid}), which will be handled by system.') + _logger.warning(f'skip replicating pyfunc ({node.cid}), which will be handled by system.') return [node] fsegment: IRSegment = self.segment(node) @@ -366,12 +368,11 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], f"The partition algorithm ({algo}) is not initialized for this node" assert isinstance(node, (IRFwOperation, IRDataOperation)), \ f"Only allow op to be forward op or data op, but got: {node}" - logger = logging.getLogger('cube.prim') if node.name == 'multiref': - logger.warn(f'skip partitioning multiref ({node.cid}), which will be handled by system.') + _logger.warning(f'skip partitioning multiref ({node.cid}), which will be handled by system.') return [node] if isinstance(node, IRPyFunc): - logger.warn(f'skip partitioning pyfunc ({node.cid}), which will be handled by system.') + _logger.warning(f'skip partitioning pyfunc ({node.cid}), which will be handled by system.') return [node] # get partitioned sub-nodes @@ -868,7 +869,7 @@ def staging(self, nodes: Tuple[IRFwOperation]): if node.name == 'multiref' or isinstance(node, IRPyFunc): pass else: - logging.getLogger('cube.prim').info(f'involve node {node.name}({node.cid}) into the first stage') + _logger.info(f'involve node {node.name}({node.cid}) into the first stage') starts[0] = idx break @@ -883,7 +884,7 @@ def staging(self, nodes: Tuple[IRFwOperation]): begin = starts[sid] end = starts[sid+1] if sid != len(starts) - 1 else last_fidx + 1 if begin >= end: - logging.getLogger('cube.prim').warn( + _logger.warning( f"skip stage {sid} which doesn't have operators: [begin({begin}): end({end})).") continue fnodes = self._nodes[begin:end] @@ -1029,7 +1030,7 @@ def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: skip += nodes[end:] for node in skip: if isinstance(node, IRGraphAnchor): continue - logging.getLogger('cube.prim').info( + _logger.info( f"skip recompute node: {node.name} ({node.cid}) as " f"it doesn't require gradient and appears at head or tail." ) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index b716fe64..7fb32151 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -15,6 +15,7 @@ from cube.ir.operator import IRFwOperation from cube.graph.parser.dtype import IRDType2TorchDType from cube.graph.parser.register import CustomizedOps +from cube.flags import CompileFlag Shapes = NewType('Shapes', Tuple[Tuple[int]]) @@ -26,6 +27,9 @@ _train_module_ref: torch.nn.Module = torch.nn.Module().train() _eval_module_ref: torch.nn.Module = torch.nn.Module().eval() +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + class CompProfiler: @@ -253,20 +257,20 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b residual_mem += t.byte_size() in_mem_info.append(t.byte_size()) else: - logging.getLogger('cube.profiler').warn('skip input {t}') + _logger.warning('node {node}: skip input {t}') # run profiling try: fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = \ CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) except: - logging.getLogger('cube.profiler').error('fail to profile {node}') + _logger.error('fail to profile {node}') fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = 0, 0, 0, [], [] # log to database key = self._serialize(node) self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span,\ infer_memory, train_mem_info, residual_mem, train_mem2in_idx) - logging.getLogger('cube.profiler').info( + _logger.info( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | fw: {round(fw_span, 2)} ms | " f"bw: {round(bw_span, 2)} ms | infer mem: {infer_memory} | train mem info: {train_mem_info} | idx: {train_mem2in_idx}") diff --git a/cube/profiler/estimator.py b/cube/profiler/estimator.py index 29869be5..51f63202 100644 --- a/cube/profiler/estimator.py +++ b/cube/profiler/estimator.py @@ -8,6 +8,8 @@ from cube.graph.function import IRGraphAnchor from cube.profiler.database import ProfileDataBase +_logger = logging.getLogger(__name__) + class Estimator: """ @@ -41,7 +43,7 @@ def __call__(self, nodes_or_segment: Union[Tuple[IRFwOperation], IRSegment], except Exception as e: color, default = '\033[31m', '\033[0m' error_msg = f'fail to run node: {node}\nerror: {e}' - logging.getLogger('cube.profiler').error(f'{color}{error_msg}{default}') + _logger.error(f'{color}{error_msg}{default}') fw_span, bw_span, infer_mem, train_mem_info = 0, 0, 0, [0] if train: diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index 204b472e..2e4faf10 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -3,6 +3,9 @@ from cube.utils import print_each_rank import torch +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + def memory_summary(): torch.cuda.synchronize() @@ -11,7 +14,7 @@ def memory_summary(): # mem = torch.cuda.max_memory_reserved() print_each_rank( '{:.2f} GB memory consumption'.format(mem / 1024 / 1024 / 1024), - logger_fn=logging.getLogger('cube.profiler').info + logger_fn=_logger.info ) return mem @@ -30,14 +33,15 @@ def model_summary(model: torch.nn.Module, inputs: List[Any], do_eval=False, max_ Make sure all of these attributes are not used in modules. """ - logger = logging.getLogger('cube.profiler') torch.cuda.empty_cache() static_memory = torch.cuda.memory_allocated() print_each_rank( - 'static model: {:,.2f} MB'.format(static_memory / 1024 / 1024), rank_only=0) + 'static model: {:,.2f} MB'.format(static_memory / 1024 / 1024), + rank_only=0, logger_fn=_logger.info) nparams = sum([param.numel() for param in model.parameters()]) print_each_rank( - 'model paramters: {:,.2f} M'.format(nparams / 1000000), rank_only=0) + 'model paramters: {:,.2f} M'.format(nparams / 1000000), + rank_only=0, logger_fn=_logger.info) stat = dict(depth=0) def before_forward(module, input): @@ -49,7 +53,7 @@ def before_forward(module, input): module._summary_begin_end = True prefix = ' ' * module._summary_depth + '[Begin] > ' print_each_rank(prefix + '{}:'.format(name), rank_only=0, - logger_fn=logger.info) + logger_fn=_logger.info) if module._summary_depth < max_depth: module._summary_memory_state = torch.cuda.memory_allocated() stat['depth'] += 1 @@ -72,7 +76,7 @@ def after_forward(module, input, output): print_each_rank( prefix + '{}: Mem {:,.2f} MB, Params: {:,} ({:,.2f} MB if fp32)'.format( name, mem_consumption, n_params, n_params / 1024 / 1024 * 4), - rank_only=0, logger_fn=logger.info) + rank_only=0, logger_fn=_logger.info) handle_pre = torch.nn.modules.module.register_module_forward_pre_hook(before_forward) handle_after = torch.nn.modules.module.register_module_forward_hook(after_forward) diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index a1dd1ada..469024e3 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -5,6 +5,9 @@ import torch from cube.utils import print_each_rank +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + class CudaTimer: r""" @@ -113,7 +116,7 @@ def duration(self, times: int, field_name: str = 'default') -> float: @return span float: wall clock in milliseconds. """ if field_name not in self.instance.field: - logging.getLogger('cube.profiler').warn(f"CudaTimer: {field_name} doesn't record.") + _logger.warning(f"CudaTimer: {field_name} doesn't record.") return 0.0 if len(self.instance.field[field_name]) != 0: raise RuntimeError(f"timer for field {field_name} not stopped") diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 4e498501..33d8e887 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -6,7 +6,11 @@ from cube.runtime.device import DeviceGroup from cube.profiler.timer import CudaTimer -from cube.flags import RuntimeFlag +from cube.flags import RuntimeFlag, CompileFlag + + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: @@ -182,7 +186,7 @@ def sync_grads(self): # async if self._async: if CudaTimer().enabled and CudaTimer().predefined: - logging.getLogger('cube.runtime').warn( + _logger.warning( f'CudaTimer: the communication time of async reducer will not be recorded in `comm`') assert self._work is not None self._work.wait() @@ -351,7 +355,7 @@ def add_param(self, param: torch.nn.Parameter): @param param torch.nn.Parameter: the added parameter """ if param.data.data_ptr() in self._param_ids: - logging.getLogger('cube.runtime').warn( + _logger.warning( f'rank [{torch.distributed.get_rank()}]: detected duplicated or shared parameters, ignored.') return self._params.append(param) @@ -380,7 +384,7 @@ def build_buckets(self): dtype2size[tp] = cur_byte_size else: if cur_byte_size > bucket_size: - logging.getLogger('cube.runtime').warn(f'find one parameter {param.shape} ({cur_byte_size} bytes) larger than bucket size {self._bucket_size}') + _logger.warning(f'find one parameter {param.shape} ({cur_byte_size} bytes) larger than bucket size {self._bucket_size}') buckets[tp].insert(0, [param]) elif dtype2size[tp] + cur_byte_size <= bucket_size: dtype2size[tp] = dtype2size[tp] + cur_byte_size diff --git a/cube/runtime/device.py b/cube/runtime/device.py index 39131006..d746b944 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -9,6 +9,9 @@ from cube.flags import CompileFlag +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) + class DeviceGroup: @@ -16,8 +19,7 @@ class __DeviceGroup: def __init__(self): if CompileFlag.dev_mode: - logging.getLogger('cube.runtime').info( - f"DeviceGroup init using single device mode...") + _logger.info(f"DeviceGroup init using single device mode") self.rank = 0 self.world_size = 1 self.local_world_size = 1 diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index b3e8d0f5..cba09593 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -7,6 +7,11 @@ import torch import logging +from cube.flags import CompileFlag + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) + def debug_id(tensors, msg: str, rank: int): if torch.distributed.get_rank() == rank: @@ -154,7 +159,7 @@ def backward(name: str, for t in input_tensors: if id(t) not in tensor_ids: import traceback - logging.getLogger('cube.runtime').warn( + _logger.warning( f"rank {torch.distributed.get_rank()}: input {name} doesn't match. " f"Make sure in scheduling, earlier forward perform earlier backward. " f"Remain {len(Executor._detach[name])} segments.\n" diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py index bcd6790e..c5b9ae13 100644 --- a/cube/runtime/function/__init__.py +++ b/cube/runtime/function/__init__.py @@ -1,2 +1 @@ -from cube.runtime.function.dist import * from cube.runtime.function.function import * \ No newline at end of file diff --git a/cube/runtime/function/dist.py b/cube/runtime/function/dist.py deleted file mode 100644 index 400bc3ad..00000000 --- a/cube/runtime/function/dist.py +++ /dev/null @@ -1,280 +0,0 @@ -from typing import Tuple, List -import torch -from torch.distributed.distributed_c10d import _get_global_rank - -from cube.profiler.timer import print_each_rank - -from cube.profiler.timer import CudaTimer - - -def get_global_rank(group, group_rank): - if group is None: - return group_rank - else: - return _get_global_rank(group, group_rank) - - -def _roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, group): - """ - partition torch.roll at shifted dimension - - Inputs: - input: [B, H, W, C] - shift: int - dim: int - """ - world_size = len(dim_ranks) - if world_size == 1: - return torch.roll(input, (shift), (dim,)) - dim_rank = dim_ranks.index(torch.distributed.get_rank(group)) - # halo exchange at H dimension - if shift < 0: - shift = 0 - shift - if dim == 1: - local = input[:, shift:, :, :] - remote = input[:, slice(0, shift), :, :].contiguous() - elif dim == 2: - local = input[:, :, shift:, :] - remote = input[:, :, slice(0, shift), :].contiguous() - else: - raise NotImplementedError("Only support on dim 1 and dim 2") - recv_tensor = torch.empty_like(remote) - - # send to next rank and recv from prevous rank - send_local_rank = dim_ranks[(dim_rank - 1 + world_size) % world_size] - send_global_rank = get_global_rank(group, send_local_rank) - recv_local_rank = dim_ranks[(dim_rank + 1) % world_size] - recv_global_rank = get_global_rank(group, recv_local_rank) - # print_each_rank(f'send to {send_global_rank}, recv from {recv_global_rank}') - - send_op = torch.distributed.P2POp( - torch.distributed.isend, remote, - send_global_rank - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_tensor, - recv_global_rank - ) - ops = [send_op, recv_op] if dim_rank % 2 == 0 else [recv_op, send_op] - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - tensor = torch.cat((local, recv_tensor), dim=dim).contiguous() - return tensor - - elif shift > 0: - boundary = input.shape[dim] - shift - if dim == 1: - local = input[:, slice(0, boundary), :, :] - remote = input[:, slice(boundary, input.shape[dim]), :, :].contiguous() - elif dim == 2: - local = input[:, :, slice(0, boundary), :] - remote = input[:, :, slice(boundary, input.shape[dim]), :].contiguous() - else: - raise NotImplementedError("Only support on dim 1 and dim 2") - recv_tensor = torch.empty_like(remote) - - # to global rank - send_local_rank = dim_ranks[(dim_rank + 1) % world_size] - send_global_rank = get_global_rank(group, send_local_rank) - recv_local_rank = dim_ranks[(dim_rank - 1 + world_size) % world_size] - recv_global_rank = get_global_rank(group, recv_local_rank) - # print_each_rank(f'send to {send_global_rank}, recv from {recv_global_rank}') - - send_op = torch.distributed.P2POp( - torch.distributed.isend, remote, - send_global_rank - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_tensor, - recv_global_rank - ) - ops = [send_op, recv_op] if dim_rank % 2 == 0 else [recv_op, send_op] - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - tensor = torch.cat((recv_tensor, local), dim=dim).contiguous() - return tensor - else: - return input - - -def roll_dim_allgather(input: torch.Tensor, shift: int, dim: int, group, - full_input=False, full_output=False): - """ - partition torch.roll at shifted dimension - - Inputs: - input: [B, H, W, C] - shift: int - dim: int - """ - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) - # allgather to have all and select what each rank needed - tensor_list = [torch.empty_like(input) for _ in range(world_size)] - tensor_list[rank] = input - torch.distributed.all_gather(tensor_list, input, group=group) - full_tensor = torch.cat(tuple(tensor_list), dim=dim).contiguous() - full_tensor = torch.roll(full_tensor, shifts=(shift,), dims=(dim,)) - chunk_len = input.shape[dim] - if dim == 1: - mytensor = full_tensor[:, rank * chunk_len : (rank + 1) * chunk_len, :, :] - elif dim == 2: - mytensor = full_tensor[:, :, rank * chunk_len : (rank + 1) * chunk_len, :] - else: - raise NotImplementedError("Only supported on dim 1 and dim 2") - mytensor = mytensor.contiguous() - return mytensor - - -class RollDimParallel(torch.autograd.Function): - """ - Halo exchange implementation on partitioning torch.roll - at shift dimension - - """ - @staticmethod - def forward(ctx, input_, shift: int, dim: int, dim_ranks: List[int], group=None): - CudaTimer().start(field_name='roll parallel') - ctx.shift = shift - ctx.dim = dim - ctx.group = group - ctx.dim_ranks = dim_ranks - output = _roll_dim_parallel(input_, shift, dim, dim_ranks, group) - CudaTimer().stop(field_name='roll parallel') - return output - - @staticmethod - def backward(ctx, grad_output): - CudaTimer().start(field_name='roll parallel') - shift = ctx.shift - dim = ctx.dim - group = ctx.group - dim_ranks = ctx.dim_ranks - grad = _roll_dim_parallel(grad_output, 0-shift, dim, dim_ranks, group) - CudaTimer().stop(field_name='roll parallel') - return grad, None, None, None, None - - -def roll_dim_parallel(input: torch.Tensor, shift: int, dim: int, dim_ranks, group): - """ - partition torch.roll at shifted dimension - - Inputs: - input: [B, H, W, C] - shift: int - dim: int - """ - return RollDimParallel.apply(input, shift, dim, dim_ranks, group) - - -def roll_grid_parallel(input: torch.Tensor, - shifts: Tuple[int, int], dims: Tuple[int, int], - nh_group_ranks: List[int], nw_group_ranks: List[int], group): - input = roll_dim_parallel(input, shifts[0], 1, nh_group_ranks, group) - input = roll_dim_parallel(input, shifts[1], 2, nw_group_ranks, group) - return input - - -class GridPartition(torch.autograd.Function): - """ - Full input - """ - @staticmethod - def forward(ctx, input_, nrow: int, ncol: int, group=None): - """ - input: [B, H, W, C] - """ - CudaTimer().start(field_name='grid_partition') - ctx.group = group - world_size = torch.distributed.get_world_size(group) - ctx.nrow = nrow - ctx.ncol = ncol - assert nrow * ncol == world_size - rank = torch.distributed.get_rank(group) - myrow = rank // ncol - mycol = rank % ncol - - chunk = torch.chunk(input_, nrow, dim=1)[myrow] - chunk = torch.chunk(chunk, ncol, dim=2)[mycol].contiguous() - CudaTimer().stop(field_name='grid_partition') - return chunk - - @staticmethod - def backward(ctx, grad_output): - CudaTimer().start(field_name='grid_partition') - group = ctx.group - nrow = ctx.nrow - ncol = ctx.ncol - - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) - grad_output = grad_output.contiguous() - tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)] - tensor_list[rank] = grad_output - torch.distributed.all_gather(tensor_list, grad_output, group=group) - - rows = list() - for row in range(nrow): - row_slice = torch.cat(tuple(tensor_list[row*ncol:(row+1)*ncol]), dim=2) - rows.append(row_slice) - grad_output = torch.cat(tuple(rows), dim=1).contiguous() - CudaTimer().stop(field_name='grid_partition') - return grad_output, None, None, None - - -class GridCollection(torch.autograd.Function): - """ - Full input - """ - @staticmethod - def forward(ctx, input_, nrow: int, ncol: int, group=None): - """ - input: [B, H, W, C] - output: [B, nrow * H, ncol * W, C] - """ - CudaTimer().start(field_name='grid_collection') - ctx.group = group - world_size = torch.distributed.get_world_size(group) - ctx.nrow = nrow - ctx.ncol = ncol - assert nrow * ncol == world_size - - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=group) - - rows = list() - for row in range(nrow): - row_slice = torch.cat(tuple(tensor_list[row*ncol:(row+1)*ncol]), dim=2) - rows.append(row_slice) - output = torch.cat(tuple(rows), dim=1).contiguous() - CudaTimer().stop(field_name='grid_collection') - return output - - @staticmethod - def backward(ctx, grad_output): - CudaTimer().start(field_name='grid_collection') - group = ctx.group - nrow = ctx.nrow - ncol = ctx.ncol - - rank = torch.distributed.get_rank(group) - myrow = rank // ncol - mycol = rank % ncol - - chunk = torch.chunk(grad_output, nrow, dim=1)[myrow] - chunk = torch.chunk(chunk, ncol, dim=2)[mycol].contiguous() - CudaTimer().stop(field_name='grid_collection') - return chunk, None, None, None - - -def grid_partition(input_, nrow, ncol, group=None): - return GridPartition.apply(input_, nrow, ncol, group) - - -def grid_collection(input_, nrow, ncol, group=None): - return GridCollection.apply(input_, nrow, ncol, group) diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index da4df752..51f785dd 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -3,12 +3,17 @@ import torch.nn.functional as TorchF import logging +from cube.flags import CompileFlag + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) + # TODO: move to registered function try: from apex.normalization.fused_layer_norm import fused_layer_norm_affine except: - logging.getLogger('cube.runtime').warn('skip apex ops as it is not installed.') + _logger.warning('skip apex ops as it is not installed.') def identity(tensor: torch.Tensor) -> torch.Tensor: diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 5ea79a9f..035e1cda 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,9 +1,15 @@ -import logging from typing import List, Dict, Tuple, Optional +import logging +import os + import torch + from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer -import os +from cube.flags import CompileFlag + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) class CubeModule(torch.nn.Module): @@ -115,7 +121,7 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref filename = f"{filename_prefix}-{DeviceGroup().rank}.ckpt" state_dict, dist_param_map, param_area_map, optimizer_state_dict = self.get_checkpoint(optimizer) - logging.getLogger('cube.runtime').info(f'saving distributed checkpoint to {filename}') + _logger.info(f'saving distributed checkpoint to {filename}') torch.save({ 'state_dict': state_dict, 'dist_param_map': dist_param_map, @@ -141,7 +147,7 @@ def merge_partial_states(state_dicts, zero_idx_maps=None): assert plan_ngpus >= 1, plan_ngpus assert plan_ngpus <= len(state_dicts), f'plan_ngpus = {plan_ngpus}, len(state_dicts) = {len(state_dicts)}' assert len(state_dicts) % plan_ngpus == 0, f'plan_ngpus = {plan_ngpus}, len(state_dicts) = {len(state_dicts)}' - logging.getLogger('cube.runtime').info(f'plan_ngpus = {plan_ngpus}') + _logger.info(f'plan_ngpus = {plan_ngpus}') # at first, merge the partitioned optimizer states due to zero to the zero-disabled format if zero_idx_maps is not None: @@ -214,7 +220,7 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' else: if plan_ngpus > 0: - logging.getLogger('cube.runtime').warn(f'plan_ngpus {plan_ngpus} not handled USE_ZERO == False') + _logger.warning(f'plan_ngpus {plan_ngpus} not handled USE_ZERO == False') def _check_opt_state(opt_state): cnt = 0 sorted_opt_state = {} @@ -265,7 +271,7 @@ def _check_opt_state(opt_state): raw_name = dist_param_map[local_name] slices = param_area[1][1] if param_area[1][2] != 1: - logging.getLogger('cube.runtime').error(f'value-split on {raw_name} is not supported') + _logger.error(f'value-split on {raw_name} is not supported') if raw_name in param_max_dimsize: param_max_dimsize[raw_name] = max(param_max_dimsize[raw_name], slices) else: @@ -278,11 +284,11 @@ def _check_opt_state(opt_state): for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: if len(optimizer_state_dict['state'].items()) > 0: optimizer_state_names = list(optimizer_state_dict['state'][0].keys()) - logging.getLogger('cube.runtime').info(f'optimizer_state_names = {optimizer_state_names}') + _logger.info(f'optimizer_state_names = {optimizer_state_names}') if 'step' in optimizer_state_names: sample_step = optimizer_state_dict['state'][0]['step'] optimizer_state_names.remove('step') - logging.getLogger('cube.runtime').info(f'optimizer_state_names (without step) = {optimizer_state_names}') + _logger.info(f'optimizer_state_names (without step) = {optimizer_state_names}') else: optimizer_state_names = [] @@ -311,7 +317,7 @@ def _check_opt_state(opt_state): optim_full_tensors[index] = {} optim_full_tensors[index][state_name] = torch.zeros(tuple(tensor_size)) else: - logging.getLogger('cube.runtime').info(f'merge_checkpoint skips {local_name_with_id}\'s optimizer state') + _logger.info(f'merge_checkpoint skips {local_name_with_id}\'s optimizer state') break # only create once # assign value @@ -347,7 +353,7 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): for rank in range(DeviceGroup().world_size): filename = f"{filename_prefix}-{rank}.ckpt" ckpts[rank] = torch.load(filename) - logging.getLogger('cube.runtime').info(f'checkpoints = {ckpts}') + _logger.info(f'checkpoints = {ckpts}') state_dicts = [] for ckpt in ckpts.values(): diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 37978c69..3566c80e 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -3,9 +3,15 @@ """ from typing import Any, List, Optional, Tuple, Union -import torch import logging +import torch + +from cube.flags import CompileFlag + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) + class CubeDataLoader: r""" @@ -192,6 +198,6 @@ def set_batch_size(self, batch_size: int): for shape, dim in zip(self.shapes, self.batch_dims): shape[dim] = batch_size rank = 0 if not torch.distributed.is_initialized() else torch.distributed.get_rank() - logging.getLogger('cube.runtime').info(f'rank [{rank}]: set batch size to {batch_size}. dataloader outputs change to: {self.shapes}') + _logger.info(f'rank [{rank}]: set batch size to {batch_size}. dataloader outputs change to: {self.shapes}') datas = self.random_sample() self.set_output(datas) From 69cd3c2894c0c7420041fc75ccbeabd15009555f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 14 Jul 2023 11:25:54 +0800 Subject: [PATCH 1439/1892] refine again the logging format --- cube/__init__.py | 32 +++++++----------------------- cube/algorithm/ops/dimops.py | 5 +++-- cube/codegen/module/module.py | 3 +-- cube/codegen/schedule/schedule.py | 7 ++----- cube/execplan/planpass/fusion.py | 6 ++++-- cube/flags.py | 14 ------------- cube/graph/function/dimops.py | 3 ++- cube/graph/function/function.py | 5 +---- cube/graph/gener/concurrent.py | 7 +++---- cube/graph/gener/rvd/intra.py | 7 +++---- cube/graph/graph.py | 4 ---- cube/graph/parser/converter.py | 13 ++++++------ cube/graph/parser/fx/parser.py | 15 +++++++------- cube/graph/parser/register.py | 8 ++++---- cube/graph/parser/script/parser.py | 9 +++++---- cube/profiler/database.py | 5 +---- cube/profiler/timer.py | 1 - cube/program.py | 4 +--- cube/runtime/adapter/reducer.py | 4 +--- cube/runtime/device.py | 2 -- cube/runtime/executor.py | 3 --- cube/runtime/function/function.py | 2 -- cube/runtime/module.py | 2 -- cube/runtime/syndata.py | 3 --- cube/utils.py | 4 +++- 25 files changed, 56 insertions(+), 112 deletions(-) diff --git a/cube/__init__.py b/cube/__init__.py index a524ccaa..c5c32ea9 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -26,35 +26,17 @@ def init(): _ = runtime.resource.EnvResource() -def _init_logger(): - +def set_logger_level(level): + """Set the logger level with predefined logging format. + + Args: + level (int): the level of the logger. + """ logging.basicConfig( - level=logging.WARN, + level=level, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) -def set_logger_level(name: Optional[str], level): - """Set the logger level of cube. - - Args: - name (Optional[str]): the name of the logger, can be one of - 'cube.parser', 'cube.policy', 'cube.adapter', - 'cube.execplan', 'cube.compiler'. Or None to set all. - level (int): the level of the logger, can be one of - logging.DEBUG, logging.INFO, logging.WARN, logging.ERROR. - """ - - if name is None: - logger_names = list(logging.root.manager.loggerDict.keys()) - logger_names = [name for name in logger_names if name.startswith('cube')] - loggers = [logging.getLogger(name) for name in logger_names] - for logger in loggers: - logger.setLevel(level) - elif name in logging.root.manager.loggerDict: - logging.getLogger(name).setLevel(level) - - _check_torch_version() -_init_logger() \ No newline at end of file diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 640e0d35..c21e8892 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -8,6 +8,8 @@ from cube.ir.operator import IRFwOperation from collections import deque +_logger = logging.getLogger(__name__) + class DimSplitEinops(GenericDistAlgo): """! @@ -123,9 +125,8 @@ def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List else: adim, reduce = 'Value', None - logger = logging.getLogger('cube.prim') color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' - logger.info(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") + _logger.info(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") if not satisfy: return None rule: TransformRule = self.infer(idx, dim, num) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 7e54acca..9f971707 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -28,7 +28,6 @@ _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_codegen else logging.WARN) class ModuleCodeGen(FuncEmission): @@ -462,7 +461,7 @@ def init_batchsize(self, node: IRDataOperation): bs = [t.shape[dim] for t, dim in zip(node.outputs(), node.get_batch_dims()) if dim is not None] bs = set(bs) if len(bs) > 1: - _logger.warn(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') + _logger.warning(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') bs = list(bs)[0] if len(bs) == 1 else None assert self.batch_size is None or self.batch_size == bs, f"Not match for batch size: {self.batch_size} != {bs}" self.model_init_statements.append(signature.format(bs=bs)) diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index c7c58275..9ac9cc36 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -18,11 +18,8 @@ from cube.codegen.lifecycle import LifeCycle from cube.codegen.syntax.blocks import FunctionBlock, ForBlock -from cube.flags import CompileFlag - _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_codegen else logging.WARN) class ScheduleCodeGen(FuncEmission): @@ -94,8 +91,8 @@ def gen(self, device: int, outfile=None, attach=None) -> str: else: # legacy hardcode strategy if isinstance(self.execplan.graph.sched, IRScheduleStrategy): - _logger.warn('using legacy IRScheduleStrategy cannot generate inference code. ' - 'Switch to use scheduling without strategy') + _logger.warning('using legacy IRScheduleStrategy cannot generate inference code. ' + 'Switch to use scheduling without strategy') with FunctionBlock(func_name='_infer_step', args=args) as fb: fb.insert_body('_ = None') diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 2383e584..969abf7f 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -17,6 +17,9 @@ from cube.ir.adapter.prim import AllToAllAllToAllPrim +_logger = logging.getLogger(__name__) + + class DiffFusion(PlanPass): @staticmethod @@ -45,8 +48,7 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: ret = DiffFusion.nnfuse(fadapter) cnt = cnt+1 if ret else cnt visited.add(node) - logging.getLogger('cube.execplan').info( - f'adapter fusion: successfully fuse {cnt} differentiable adapters') + _logger.info(f'adapter fusion: successfully fuse {cnt} differentiable adapters') return execplan @staticmethod diff --git a/cube/flags.py b/cube/flags.py index 47d79762..361eeffd 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -17,20 +17,6 @@ def _to_int(s: str, default=0) -> int: class CompileFlag: - # ============= loggings =================== - # log the parser information - log_parser = _to_bool('LOG_PARSER') - # log the primitives applied on the cube graph - log_prim = _to_bool('LOG_PRIM') - # log the adapter information during communication generation - log_adapter = _to_bool('LOG_ADAPTER') - # log the execution plan - log_execplan = _to_bool('LOG_EXECPLAN') - # log the code generation information - log_codegen = _to_bool('LOG_CODEGEN') - # log the runtime information - log_runtime = _to_bool('LOG_RUNTIME') - # ================ compiling ======================== use_torchfx = _to_bool('USE_TORCHFX') # using torch.fx or torchscript as frontend to capture dataflow graph use_default_fx_tracer = _to_bool('USE_DEFAULT_FX_TRACER') # using default fx tracer or more powerful concrete_tracer diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index cdac9fdd..9228629c 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -74,6 +74,7 @@ _kSpecialIdentifiers = ('*', '?') +_logger = logging.getLogger(__name__) class DimAnno: @@ -662,7 +663,7 @@ def infer_shape(self) -> bool: shape_anno = self.oanno(oidx) if str(shape_anno) == '?': assert isinstance(otensor, IRObject), f"expect IRObject for unknown shape, get {otensor}" - logging.getLogger('cube.parser').warn( + _logger.warning( 'detect IRObject output in a IRDimops, please ensure the annotation is ' 'correct w.r.t the partition policy.') continue diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index e6d12150..2df900e4 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -16,10 +16,7 @@ from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D from cube.graph.function.anchor import IRGraphAnchor -from cube.flags import CompileFlag - _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_parser else logging.WARN) def Identity(tensor: IRObject, signature = None): @@ -1762,7 +1759,7 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], return torch.device('cpu') if name == 'layout': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - _logger.warn("getattr of 'layout' will always return torch.strided") + _logger.warning("getattr of 'layout' will always return torch.strided") return torch.strided if isinstance(obj, torch.finfo): return getattr(obj, name) diff --git a/cube/graph/gener/concurrent.py b/cube/graph/gener/concurrent.py index b97d2bcf..b16deb7f 100644 --- a/cube/graph/gener/concurrent.py +++ b/cube/graph/gener/concurrent.py @@ -19,14 +19,13 @@ _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_adapter else logging.WARNING) if CompileFlag.disable_intra_rvd: - _logger.warn('Detected disabling intra-RVD collective generation, which may have big impact on performance.') + _logger.warning('Detected disabling intra-RVD collective generation, which may have big impact on performance.') if CompileFlag.disable_inter_rvd: - _logger.warn('Detected disabling inter-RVD collective generation, which may have big impact on performance.') + _logger.warning('Detected disabling inter-RVD collective generation, which may have big impact on performance.') if CompileFlag.disable_comm_fusion: - _logger.warn('Detected disabling general communication fusion, which may have big impact on performance in certain cases.') + _logger.warning('Detected disabling general communication fusion, which may have big impact on performance in certain cases.') class ConcurrentGener: diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index b2b23d86..85222143 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -1,7 +1,6 @@ from typing import Callable, Dict, List, Tuple, Optional, Set from functools import partial import numpy as np -import sys import copy import logging @@ -23,7 +22,7 @@ from cube.graph.gener.utils import tensor_vd_repr - +_logger = logging.getLogger(__name__) TShape = Tuple[int, ...] TRVD = Tuple[int, ...] @@ -263,7 +262,7 @@ def path(ilayout: RVDLayout, olayout: RVDLayout, f"Switch to a fixed plan: ilayout -> FullReplica -> olayout" ) color, default = '\033[33m' , '\033[0m' - logging.getLogger('cube.adapter').warn(f'intra-RVD:\n{color+warn_msg+default}') + _logger.warning(f'intra-RVD:\n{color+warn_msg+default}') all_prims = IntraPathFinder.backup_path(ilayout, olayout, cost_fn) return all_prims @@ -663,7 +662,7 @@ def advice(shape: TShape, f"bw hops: {'->'.join(str(rvd) for rvd in bw_rvd_hops)}\n" f"using placement: {placement}\n" f"=============================================================") - logging.getLogger('cube.adapter').warn(f'intra-RVD:\n{msg}') + _logger.warning(f'intra-RVD:\n{msg}') bw_rvd_hops = IntraPathFinder.get_backup_path(ftensor, bw_src_rvd, bw_dst_rvd, cost_fn) # estimate cost diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 2cd22d9d..87c180ee 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -25,13 +25,9 @@ from cube.graph.segment import IRSegment from cube.algorithm.generics import GenericDistAlgo -from cube.flags import CompileFlag _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_prim else logging.WARNING) - - FOp = Union[IRFwOperation, IRDataOperation] diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index d6576fda..91d8ae73 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -10,6 +10,8 @@ import torch import torch.fx +_logger = logging.getLogger(__name__) + try: import apex HAS_APEX = True @@ -41,11 +43,10 @@ def convert_model(model: torch.nn.Module, customized_funcs = CustomizedOps.kOpRuntime.values() leaf_functions = {func: ([], False, None) for func in customized_funcs} - logger = logging.getLogger('cube.parser') # step 1: trace model if CompileFlag.use_torchfx: - logger.info('use concrete torch.fx tracer') + _logger.info('use concrete torch.fx tracer') from cube.graph.parser.fx.concrete_trace_utils import concrete_trace from cube.graph.parser.fx.parser import FxModuleParser if HAS_APEX: @@ -60,7 +61,7 @@ def convert_model(model: torch.nn.Module, # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, ) else: - logger.warn('apex package is not installed') + _logger.warning('apex package is not installed') leaf_module = None traced_model = concrete_trace( model, @@ -71,19 +72,19 @@ def convert_model(model: torch.nn.Module, cpu_offload=True, ) else: - logger.info('use torch.jit.script tracer') + _logger.info('use torch.jit.script tracer') traced_model = torch.jit.script(model) # step 2: convert traced model into IRGraph if CompileFlag.use_torchfx: FxModuleParser.save_content = save_content FxModuleParser.dynamic_shape = dynamic_shape - logger.info(f"use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") + _logger.info(f"use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") inputs, nodes, outputs = FxModuleParser.parse(traced_model, dummy_input) module_name = model.__class__.__name__ else: if dynamic_shape: - logger.warn('dynamic shape is not supported in torch.jit.script') + _logger.warning('dynamic shape is not supported in torch.jit.script') ScriptModuleParser.save_content = save_content inputs, nodes, outputs = ScriptModuleParser.parse_module(traced_model, input_shapes) module_name = traced_model.original_name diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index dbb8c43f..4b84966e 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -14,6 +14,9 @@ import torch.fx +_logger = logging.getLogger(__name__) + + class ErasedDevice: pass @@ -101,8 +104,7 @@ def parse(module: torch.fx.GraphModule, assert isinstance(dummy_inputs, dict), "Expected dummy inputs to parse module" inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] - logging.getLogger('cube.parser').info( - f'> torch.fx parser: graph inputs: {inputs}') + _logger.info(f'> torch.fx parser: graph inputs: {inputs}') # shape propagation ShapeProp(module).propagate(dummy_inputs) @@ -182,8 +184,7 @@ def parse_complex_out(meta_out): all_ir_nodes: List[IRFwOperation] = list() total_node_num = len(module.graph.nodes) for nidx, node in enumerate(module.graph.nodes): - logging.getLogger('cube.parser').info( - f'[{nidx}/{total_node_num}] parsing node {node}...') + _logger.info(f'[{nidx}/{total_node_num}] parsing node {node}...') ir_nodes = FxModuleParser.parse_node(node, module, frame) if ir_nodes is not None: all_ir_nodes += ir_nodes @@ -317,12 +318,12 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator if FxModuleParser._is_torch_autograd_op(node, frame, fsig): - logging.getLogger('cube.parser').warn(f'Find unknown pytorch operation: {fsig}') + _logger.warning(f'Find unknown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fsig ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) # case2: python runtime function else: - logging.getLogger('cube.parser').warn(f'Set python runtime function: {fsig}') + _logger.warning(f'Set python runtime function: {fsig}') ir_node = IRPyFunc(fsig, input_vals, [IRObject()], **kwargs) if isinstance(ir_node, IRCell): @@ -359,7 +360,7 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, else: frame.set_var(node.name, ir_node) - logging.getLogger('cube.parser').info(f'parsing result: {ir_node}') + _logger.info(f'parsing result: {ir_node}') return ir_nodes @staticmethod diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 861a36c4..9c9f9c29 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -11,6 +11,8 @@ from cube.graph.function.dimops import IRDimops, OpAnno +_logger = logging.getLogger(__name__) + class CustomizedOps: """Customized op registry.""" @@ -134,8 +136,7 @@ def decorator(fn: Callable): if code_impl_pattern == 'import': import_path = inspect.getmodule(fn).__name__ if import_path == '__main__': - logger = logging.getLogger('cube.parser') - logger.warn(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' + _logger.warning(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' f'This may cause error when the function has inner functions from other modules. ' f'To solve this, define the function in another module and import into main', stacklevel=0) code = inspect.getsource(fn) @@ -159,8 +160,7 @@ def udfop(*args, signature=None, **kwargs): kwargs[name] = val return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, transform_rules=rules, **kwargs) - logging.getLogger('cube.parser').info( - f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') + _logger.info(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') CustomizedOps.register(fsig, udfop, code, fn) return fn diff --git a/cube/graph/parser/script/parser.py b/cube/graph/parser/script/parser.py index 855118b1..6f2793f1 100644 --- a/cube/graph/parser/script/parser.py +++ b/cube/graph/parser/script/parser.py @@ -14,6 +14,7 @@ from cube.graph.parser.dtype import DType2IRDType +_logger = logging.getLogger(__name__) _refmodule = torch.nn.Module() class ErasedDevice: @@ -76,7 +77,7 @@ def parse_module(module, try: ret = ir_node.infer_shape() if not ret: - logging.getLogger('cube.parser').error(f'{ir_node} cannot infer shape') + _logger.error(f'{ir_node} cannot infer shape') except Exception: raise RuntimeError( f"====== Shape Infer Error ====\n\n\n" @@ -127,7 +128,7 @@ def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): try: ret = ir_node.infer_shape() if not ret: - logging.getLogger('cube.parser').error(f'{ir_node} cannot infer shape') + _logger.error(f'{ir_node} cannot infer shape') except Exception: raise RuntimeError( f"====== Shape Infer Error ====\n\n\n" @@ -553,7 +554,7 @@ def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: try: ret = ir_node.infer_shape() if not ret: - logging.getLogger('cube.parser').error(f'{ir_node} cannot infer shape') + _logger.error(f'{ir_node} cannot infer shape') except Exception: raise RuntimeError( f"====== Shape Infer Error ====\n\n\n" @@ -667,7 +668,7 @@ def parse_prim_loop_node(node, module, frame: Frame) -> List[IRFwOperation]: try: ret = ir_node.infer_shape() if not ret: - logging.getLogger('cube.parser').error(f'{ir_node} cannot infer shape') + _logger.error(f'{ir_node} cannot infer shape') except Exception: raise RuntimeError( f"====== Shape Infer Error ====\n\n\n" diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 7fb32151..d544a9bf 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -15,8 +15,8 @@ from cube.ir.operator import IRFwOperation from cube.graph.parser.dtype import IRDType2TorchDType from cube.graph.parser.register import CustomizedOps -from cube.flags import CompileFlag +_logger = logging.getLogger(__name__) Shapes = NewType('Shapes', Tuple[Tuple[int]]) DTypes = NewType('DTypes', Tuple[torch.dtype]) @@ -27,9 +27,6 @@ _train_module_ref: torch.nn.Module = torch.nn.Module().train() _eval_module_ref: torch.nn.Module = torch.nn.Module().eval() -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - class CompProfiler: diff --git a/cube/profiler/timer.py b/cube/profiler/timer.py index 469024e3..4fffa0e1 100644 --- a/cube/profiler/timer.py +++ b/cube/profiler/timer.py @@ -6,7 +6,6 @@ from cube.utils import print_each_rank _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) class CudaTimer: diff --git a/cube/program.py b/cube/program.py index 29e19777..9e18d6d1 100644 --- a/cube/program.py +++ b/cube/program.py @@ -1,5 +1,4 @@ from typing import List, Tuple, Optional, Any -import logging from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -12,9 +11,8 @@ from cube.runtime.syndata import CubeDataLoader from cube.runtime.module import CubeModule from cube.runtime.device import DeviceGroup -from cube.profiler.timer import print_each_rank -from cube.utils import load_model, load_default_schedule +from cube.utils import load_model import torch diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 33d8e887..124828ec 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -6,11 +6,9 @@ from cube.runtime.device import DeviceGroup from cube.profiler.timer import CudaTimer -from cube.flags import RuntimeFlag, CompileFlag - +from cube.flags import RuntimeFlag _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: diff --git a/cube/runtime/device.py b/cube/runtime/device.py index d746b944..5fa3476f 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -10,7 +10,6 @@ from cube.flags import CompileFlag _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) class DeviceGroup: @@ -19,7 +18,6 @@ class __DeviceGroup: def __init__(self): if CompileFlag.dev_mode: - _logger.info(f"DeviceGroup init using single device mode") self.rank = 0 self.world_size = 1 self.local_world_size = 1 diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index cba09593..1823f03e 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -7,10 +7,7 @@ import torch import logging -from cube.flags import CompileFlag - _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) def debug_id(tensors, msg: str, rank: int): diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 51f785dd..47c5146f 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -3,10 +3,8 @@ import torch.nn.functional as TorchF import logging -from cube.flags import CompileFlag _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) # TODO: move to registered function diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 035e1cda..04b0dccf 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -6,10 +6,8 @@ from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer -from cube.flags import CompileFlag _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) class CubeModule(torch.nn.Module): diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py index 3566c80e..0bd53343 100644 --- a/cube/runtime/syndata.py +++ b/cube/runtime/syndata.py @@ -7,10 +7,7 @@ import torch -from cube.flags import CompileFlag - _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO if CompileFlag.log_runtime else logging.WARNING) class CubeDataLoader: diff --git a/cube/utils.py b/cube/utils.py index 79b76656..53c1aa16 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -7,6 +7,8 @@ import torch +_logger = logging.getLogger(__name__) + def print_each_rank(msg: str, rank_only: Optional[int] = None, logger_fn: Callable = print): """Logging the message. @@ -51,7 +53,7 @@ def load_model(filename: Optional[str] = None, load_content: bool = True): # load parameter content if load_content: print_each_rank("> loading parameter content...", - logger_fn=logging.getLogger('cube.codegen').info) + logger_fn=_logger.info) loaded_module.load_attr_content('./fullmodel.pt') # initialize reducer for reducer in loaded_module.reducers: From 3c3451834821aabfa5fea986789b1ed7adb148a9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 14 Jul 2023 16:42:00 +0800 Subject: [PATCH 1440/1892] refine print_each_rank interface --- cube/compiler.py | 6 +++--- cube/profiler/memory.py | 10 +++++----- cube/utils.py | 9 +++++---- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 9c3f5435..a54900dd 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -244,7 +244,7 @@ def decorator(fn: Callable) -> Callable: # load module filename = filename.format(myrank) - print_each_rank(f'loading generated module from {filename} ...', logger_fn=_logger.info) + print_each_rank(f'loading generated module from {filename} ...', logger=_logger) model.load_module(filename) if torch.distributed.is_initialized(): @@ -254,7 +254,7 @@ def decorator(fn: Callable) -> Callable: # set dataloder batch size (serialize output) if dataloader is not None: bs = model.get_gen_module().get_batch_size() - print_each_rank(f'setting batch size to: {bs}', logger_fn=_logger.info) + print_each_rank(f'setting batch size to: {bs}', logger=_logger) if torch.distributed.is_initialized(): for rank in range(torch.distributed.get_world_size()): if rank == torch.distributed.get_rank(): @@ -269,7 +269,7 @@ def decorator(fn: Callable) -> Callable: torch.distributed.barrier() # load temporal schedule - print_each_rank(f'loading generated schedule from {filename} ...', logger_fn=_logger.info) + print_each_rank(f'loading generated schedule from {filename} ...', logger=_logger) return cube.load_default_schedule(filename) return decorator diff --git a/cube/profiler/memory.py b/cube/profiler/memory.py index 2e4faf10..5d536664 100644 --- a/cube/profiler/memory.py +++ b/cube/profiler/memory.py @@ -14,7 +14,7 @@ def memory_summary(): # mem = torch.cuda.max_memory_reserved() print_each_rank( '{:.2f} GB memory consumption'.format(mem / 1024 / 1024 / 1024), - logger_fn=_logger.info + logger=_logger ) return mem @@ -37,11 +37,11 @@ def model_summary(model: torch.nn.Module, inputs: List[Any], do_eval=False, max_ static_memory = torch.cuda.memory_allocated() print_each_rank( 'static model: {:,.2f} MB'.format(static_memory / 1024 / 1024), - rank_only=0, logger_fn=_logger.info) + rank_only=0, logger=_logger) nparams = sum([param.numel() for param in model.parameters()]) print_each_rank( 'model paramters: {:,.2f} M'.format(nparams / 1000000), - rank_only=0, logger_fn=_logger.info) + rank_only=0, logger=_logger) stat = dict(depth=0) def before_forward(module, input): @@ -53,7 +53,7 @@ def before_forward(module, input): module._summary_begin_end = True prefix = ' ' * module._summary_depth + '[Begin] > ' print_each_rank(prefix + '{}:'.format(name), rank_only=0, - logger_fn=_logger.info) + logger=_logger) if module._summary_depth < max_depth: module._summary_memory_state = torch.cuda.memory_allocated() stat['depth'] += 1 @@ -76,7 +76,7 @@ def after_forward(module, input, output): print_each_rank( prefix + '{}: Mem {:,.2f} MB, Params: {:,} ({:,.2f} MB if fp32)'.format( name, mem_consumption, n_params, n_params / 1024 / 1024 * 4), - rank_only=0, logger_fn=_logger.info) + rank_only=0, logger=_logger) handle_pre = torch.nn.modules.module.register_module_forward_pre_hook(before_forward) handle_after = torch.nn.modules.module.register_module_forward_hook(after_forward) diff --git a/cube/utils.py b/cube/utils.py index 53c1aa16..e7a9fd84 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -10,19 +10,20 @@ _logger = logging.getLogger(__name__) -def print_each_rank(msg: str, rank_only: Optional[int] = None, logger_fn: Callable = print): +def print_each_rank(msg: str, rank_only: Optional[int] = None, logger: Optional[logging.Logger] = None): """Logging the message. Args: msg (str): message to be logged. rank_only (int, optional): the rank to be logged. Defaults to None, which means all ranks. - logger_fn (Callable, optional): - the logger function. Defaults to print. + logger (logging.Logger, optional): + the logger to use. Defaults to print. Returns: None """ + logger_fn = print if logger is None else logger.info if CompileFlag.dev_mode: logger_fn(msg) return @@ -53,7 +54,7 @@ def load_model(filename: Optional[str] = None, load_content: bool = True): # load parameter content if load_content: print_each_rank("> loading parameter content...", - logger_fn=_logger.info) + logger=_logger) loaded_module.load_attr_content('./fullmodel.pt') # initialize reducer for reducer in loaded_module.reducers: From 40e36d6bb584f7431c5a6d9d55d62f61a0acea4a Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 14 Jul 2023 09:24:20 +0000 Subject: [PATCH 1441/1892] Merged PR 1664: Add a compile flag for gen code comments --- cube/flags.py | 1 + cube/graph/parser/converter.py | 1 + .../concrete_trace_utils/concrete_tracer.py | 42 +++++++++++-------- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/cube/flags.py b/cube/flags.py index 8f52978d..e1317f1c 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -31,6 +31,7 @@ class CompileFlag: # ============ code generation =============== use_nnfusion = _to_bool('USE_NNFUSION') use_jit = _to_bool('USE_JIT') + disable_code_line_info = _to_bool('DISABLE_CODE_LINE_INFO') # will add original code information in generated code, note that this will make trace slow # ============== runtime ==================== dev_mode = _to_bool('SINGLE_DEV_MODE') # allow to use python xx.py diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 91d8ae73..742b3651 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -70,6 +70,7 @@ def convert_model(model: torch.nn.Module, leaf_module=leaf_module, autowrap_leaf_function=leaf_functions, cpu_offload=True, + record_frames=not CompileFlag.disable_code_line_info, ) else: _logger.info('use torch.jit.script tracer') diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index b809fefc..b5af4de5 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -267,7 +267,7 @@ class ConcreteTracer(TracerBase): } @compatibility(is_backward_compatible=True) - def __init__(self, cpu_offload = False): + def __init__(self, cpu_offload = False, record_frames = False): """ similar to _symbolic_trace.Tracer.__init__. remove the 'param_shapes_constant' because we can get real shape when executing. @@ -277,6 +277,7 @@ def __init__(self, cpu_offload = False): self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} self.cpu_offload = cpu_offload + self.record_frames = record_frames @contextmanager def do_temp_disable(self, call=False, attr=False, agfunc_apply=False): @@ -381,6 +382,9 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] result = result.cpu() elif isinstance(result, (list, dict, tuple)): result = tree_map(to_cpu, result) + elif isinstance(result, (int, bool, torch.device, torch.dtype)): + # avoid too noisy warning + pass else: _logger.warning(f"result of target {target} is {type(result)}, which is not a common behavior.") @@ -447,20 +451,21 @@ def upwrapper(obj: Any): node = self.create_node(kind, target, args_, kwargs_, name, type_expr) - # record code frame, include filename, line number, and function name - frame_record = FrameRecord(None, None, None, None) - cube_cct_path = str(Path(__file__).parent) + '/' # the cube concrete tracer path - torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path - ignore_dirs = [cube_cct_path, torch_path] - for frame in traceback.extract_stack()[-2::-1]: - if any(p in frame.filename for p in ignore_dirs): - continue - frame_record.filename = frame.filename - frame_record.lineno = frame.lineno - frame_record.line = frame.line - frame_record.name = frame.name - break - node.meta['frame_record'] = frame_record + if self.record_frames: + # record code frame, include filename, line number, and function name + frame_record = FrameRecord(None, None, None, None) + cube_cct_path = str(Path(__file__).parent) + '/' # the cube concrete tracer path + torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path + ignore_dirs = [cube_cct_path, torch_path] + for frame in traceback.extract_stack()[-2::-1]: + if any(p in frame.filename for p in ignore_dirs): + continue + frame_record.filename = frame.filename + frame_record.lineno = frame.lineno + frame_record.line = frame.line + frame_record.name = frame.name + break + node.meta['frame_record'] = frame_record proxy = self.proxy(value_unwrapped, node) return proxy @@ -1497,6 +1502,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], dce = True, cpu_offload = False, trace_twice = False, + record_frames = False, ) -> GraphModule: """ Concrete tracing API @@ -1627,11 +1633,13 @@ def f(x, y): If set to False, there will be no offloading during tracing, but the traced code will be executed on default device. trace_twice (bool): If set to True, a second trace will be performed, and the two obtained graphs will be checked for consistency. - + + record_frames(bool): If set to True, will add frame information to node.meta['frame_record']. Note this will cost additional trace time. + Returns: fx.GraphModule: a Module created from the recorded operations from ``root``. """ - tracer = ConcreteTracer(cpu_offload = cpu_offload) + tracer = ConcreteTracer(cpu_offload = cpu_offload, record_frames = record_frames) is_training = root.training root.eval() From 9d048d0296af73786935ca214edfc0c96f9a4da5 Mon Sep 17 00:00:00 2001 From: yileiyang Date: Fri, 14 Jul 2023 09:36:39 +0000 Subject: [PATCH 1442/1892] add bf16 definition --- cube/graph/parser/dtype.py | 1 + cube/ir/dtype.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/cube/graph/parser/dtype.py b/cube/graph/parser/dtype.py index 9d74ee53..f2136e4a 100644 --- a/cube/graph/parser/dtype.py +++ b/cube/graph/parser/dtype.py @@ -18,6 +18,7 @@ def map(dtype: torch.dtype): torch.float : ir.float32, torch.float16: ir.float16, torch.half : ir.float16, + torch.bfloat16: ir.bfloat16, torch.uint8 : ir.uint8, torch.int8 : ir.int8, torch.int16 : ir.int16, diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index 89b8f6dd..6a5e64f0 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -6,6 +6,7 @@ class IRDType(Enum): float64 = 'float64' float16 = 'float16' float32 = 'float32' + bfloat16 = 'bfloat16' int64 = 'int64' int32 = 'int32' int16 = 'int16' @@ -20,6 +21,7 @@ def dtype2byte_size(dtype: IRDType) -> int: IRDType.float64: 8, IRDType.float32: 4, IRDType.float16: 2, + IRDType.bfloat16: 2, IRDType.int64: 8, IRDType.int32: 4, IRDType.int16: 2, @@ -61,6 +63,7 @@ def infer(node, dtypes: List[IRDType]) -> IRDType: float64 = IRDType.float64 float16 = IRDType.float16 float32 = IRDType.float32 +bfloat16 = IRDType.bfloat16 int64 = IRDType.int64 int32 = IRDType.int32 int16 = IRDType.int16 From 997fa5bac472b0f5135cb42d892a56bdfc1466ab Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 17 Jul 2023 21:02:29 +0800 Subject: [PATCH 1443/1892] refine code --- cube/algorithm/ops/dimops.py | 11 +++-------- cube/ir/tensor.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index a87a69b0..b034871e 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -138,11 +138,9 @@ def transform(tensor: Any, split: DimopSplit) -> List[Any]: if not isinstance(tensor, IRSubTensor): return [tensor] * num if split.isD(): - sub_tensors = [tensor] - for dim in split.dims: - for _ in range(len(sub_tensors)): - sub_tensor = sub_tensors.pop(0) - sub_tensors += sub_tensor.split_dim(dim, num) + # get sub-tensors with nested partition on dims + sub_tensors = tensor.split_dims(split.dims, (num,) * len(split.dims)) + # reshape to (num, num, ...) and select [i, i, ..., i] sub-tensor, i = 0 to num-1 sub_tensors = np.array(sub_tensors, dtype=IRSubTensor).reshape((num,) * len(split.dims)) sub_tensors = [sub_tensors[(i,) * len(split.dims)] for i in range(num)] return sub_tensors @@ -189,9 +187,6 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR if splits[idx].isD(): # make negative offset to be possitive ndims = len(node.input(idx).shape) - # rdim = (splits[idx].dims + ndims) % ndims - # if rdim == dim: - # return r rdims = tuple((d + ndims) % ndims for d in splits[idx].dims) if dim in rdims: return r diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index 99c90eb8..e2a93a83 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -720,6 +720,23 @@ def split_dim(self, dim: int, num: int) -> List[IRTensor]: sub_tensors.append(sub_tensor) return sub_tensors + def split_dims(self, dims: Tuple[int], nums: Tuple[int] ) -> List[IRTensor]: + """Uniformly and nestedly partition tensors alongside multiple dimensions. + + Args: + dims (Tuple[int]): the dimensions to get partitioned + nums (Tuple[int]): the number of sub-tensor generated + + Returns: + List[IRTensor]: the generated `\prod nums` sub-tensors + """ + sub_tensors = [self] + for dim, num in zip(dims, nums): + for _ in range(len(sub_tensors)): + sub_tensor = sub_tensors.pop(0) + sub_tensors += sub_tensor.split_dim(dim, num) + return sub_tensors + def split_val(self, num: int) -> List[IRTensor]: """! Partition primitive: From ac521990f02596e0f7d2774d2f1127feac4cf698 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 18 Jul 2023 14:58:40 +0800 Subject: [PATCH 1444/1892] fix group exist bug when query for all devices --- cube/runtime/device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/runtime/device.py b/cube/runtime/device.py index 27e2251b..0fbaeacd 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -34,7 +34,7 @@ def __init__(self): self.node_rank = int(os.environ.get('GROUP_RANK')) torch.cuda.set_device(self.local_rank) - self.groups: Dict = dict() + self.groups: Dict = { '1'*self.world_size: None } self.streams: Dict[str, torch.cuda.Stream] = { 'default': torch.cuda.default_stream()} From 0612ce56db849cbb8eee983312e2420d77cc8b30 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 19 Jul 2023 05:50:45 +0000 Subject: [PATCH 1445/1892] Merged PR 1675: Handle complex inputs and outputs in profiler handle complex inputs and outputs in profiler --- cube/profiler/database.py | 50 +++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index d544a9bf..737f1b34 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -79,13 +79,22 @@ def gen_torch_tensors(shape, dtype, requires_grad): eval_kwargs[name] = eval_val # run one sample outputs = func(*tensors, **train_kwargs) + # omit non-tensor outputs + ''' + only profile IRDimops currently, which has at least one tensor output and + may have non-tensor outputs (like list, tuple, dict, etc.). In additional, + we assume that non-tensor outputs will not be used in backward. + ''' outputs = (outputs,) if torch.is_tensor(outputs) else outputs + outputs = tuple(filter(lambda x: torch.is_tensor(x) and x.requires_grad, outputs)) assert all(torch.is_tensor(otensor) for otensor in outputs), \ f"{func.__name__}: require all the outputs to be tensors" grads = tuple(torch.zeros_like(otensor) for otensor in outputs) def run_step(func, tensors, kwargs, backward: bool): outputs = func(*tensors, **kwargs) + outputs = (outputs,) if torch.is_tensor(outputs) else outputs + outputs = tuple(filter(lambda x: torch.is_tensor(x) and x.requires_grad, outputs)) if backward: torch.autograd.backward(outputs, grads) return outputs @@ -114,6 +123,8 @@ def pack_hook(x): train_mem_info.append(byte_size) idx = -1 for i, t in enumerate(tensors): + if not isinstance(t, torch.Tensor): + continue if t.storage().data_ptr() == x.storage().data_ptr(): idx = i break @@ -187,37 +198,38 @@ def get_dep_names(sign: str): return ret if node.signature in CustomizedOps.kOpCodeDef: - dep_code_impl = '' - for dep_name in get_dep_names(node.signature): - dep_code_impl = dep_code_impl + CustomizedOps.kOpCodeDef[dep_name] code_impl: str = CustomizedOps.kOpCodeDef[node.signature] - def_end = code_impl.find(':\n') - assert def_end >= 0 - prev_code_lines = code_impl[:def_end+2] - succ_code_lines = code_impl[def_end+2:] - for line in dep_code_impl.split('\n'): - prev_code_lines = prev_code_lines + ' ' + line + '\n' - code_impl = prev_code_lines + succ_code_lines local = {} exec(code_impl, globals(), local) fn = list(local.values())[0] else: fn = eval(node.signature) shapes, dtypes, requires_grads, values = [], [], [], [] + + def extract_val(val: Union[IRObject, Any]) -> Any: + if isinstance(val, IRObject): + return extract_val(val.value) + elif isinstance(val, tuple): + return tuple([extract_val(v) for v in val]) + elif isinstance(val, dict): + return {k: extract_val(v) for k, v in val.items()} + elif isinstance(val, slice): + return slice(extract_val(val.start), extract_val(val.stop), extract_val(val.step)) + else: + return val + for t in node.inputs(): if isinstance(t, IRTensor): shapes.append(t.shape) dtypes.append(IRDType2TorchDType.map(t.dtype)) requires_grads.append(t.requires_grad) values.append(t) - elif isinstance(t, IRObject): - raise RuntimeError('IRObject has not been supported in profiling.') else: shapes.append(None) - dtypes.append(type(t).__name__) + dtypes.append(None) requires_grads.append(None) - values.append(t) - return fn, shapes, dtypes, requires_grads, values, node.kwargs + values.append(extract_val(t)) + return fn, shapes, dtypes, requires_grads, values, extract_val(node.kwargs) def profile(self, node: IRFwOperation, device: Optional[int] = None, override: bool = False): """ @@ -254,14 +266,14 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b residual_mem += t.byte_size() in_mem_info.append(t.byte_size()) else: - _logger.warning('node {node}: skip input {t}') + _logger.warning(f'node {node}: skip input {t}') # run profiling try: fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = \ CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) - except: - _logger.error('fail to profile {node}') + except Exception: + _logger.exception(f'fail to profile {node}, use default values') fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = 0, 0, 0, [], [] # log to database key = self._serialize(node) @@ -386,8 +398,6 @@ def _serialize(self, node: IRFwOperation) -> str: if isinstance(t, IRTensor): shapes.append(t.shape) dtypes.append(IRDType2TorchDType.map(t.dtype)) - elif isinstance(t, IRObject): - raise RuntimeError('IRObject has not been supported in _serialize') # else: # shapes.append(None) # dtypes.append(type(t)) From 5b28f9779ebfc7d7b332b16058d49c94d669a950 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 19 Jul 2023 08:25:35 +0000 Subject: [PATCH 1446/1892] Merged PR 1657: fix concrete trace proxy call fix item 1457 --- .../fx/concrete_trace_utils/concrete_proxy.py | 5 ++ .../concrete_trace_utils/concrete_tracer.py | 3 +- tests/parser/test_fx_zip.py | 79 +++++++++++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 tests/parser/test_fx_zip.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index e06bf93c..962f805f 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -79,6 +79,11 @@ def __getattr__(self, k) -> ConcreteProxy: return ConcreteAttrProxy(self, k) def __call__(self, *args, **kwargs) -> ConcreteProxy: + # If it is a module proxy, we should not create a `call_method` node for this case. + # What we need is to trace this module or the internals of this module, + # so here we directly call the `__call__` to trigger `create_proxy` inner the `__call__`. + if isinstance(self.value, torch.nn.Module): + return self.value.__call__(*args, **kwargs) return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) def __iter__(self) -> Union[Iterable, ConcreteProxy]: diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index b5af4de5..3012500e 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -251,7 +251,8 @@ class ConcreteTracer(TracerBase): default_autowrap_leaf_class: Dict[Type, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool]] = { # class _orig_bool: ([], False), - _orig_zip: ([], False), + # we don't want zip appear as a node in the graph + # _orig_zip: ([], False), _orig_int: ([], False), # iterable class diff --git a/tests/parser/test_fx_zip.py b/tests/parser/test_fx_zip.py new file mode 100644 index 00000000..c81d0d48 --- /dev/null +++ b/tests/parser/test_fx_zip.py @@ -0,0 +1,79 @@ +""" +USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_fx_zip.py +""" +import torch + +import cube +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops + +cube.init() + +class TestModel(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.fcs = torch.nn.Sequential(torch.nn.Linear(512, 10, bias=False), torch.nn.Linear(512, 10, bias=False)) + + def forward(self, x: torch.Tensor): + result = [] + xs = x.chunk(2, dim=1) + for x, fc in zip(xs, self.fcs): + result.append(fc(x)) + res = torch.cat(result) + return {'result': res, 'loss': torch.sum(res)} + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + def __init__(self, batch_size: int = 256) -> None: + self.sample = torch.rand( + [batch_size, 1024], + dtype=torch.float32, + device=torch.cuda.current_device() + ) + super().__init__(batch_size, (0,)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +def test_zip(): + model = TestModel() + dataloader = TestDataLoader() + + def policy(graph, resource): + print(graph.extra_repr()) + assert resource.ngpus == 1 + for node in graph.nodes(): + if isinstance(node, IRDimops): + print(f'# {node.anno}') + print(node) + elif isinstance(node, (IRFwOperation, IRDataOperation)): + print(node) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + model = cube.SemanticModel(model) + + @cube.compile(model, dataloader, PAS=policy, load_content=False, model_dummy_inputs={'x': next(dataloader)}) + def eval_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out['loss'].backward() + + model = model.get_gen_module() + + for idx in range(3): + eval_iter(model, dataloader) + print(f"iter {idx}/3") + + +if __name__ == '__main__': + # zip should not appear in graph + test_zip() From ccfaa5d5f87efbd374f3feff2ad325211c315fdf Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 21 Jul 2023 03:02:04 +0000 Subject: [PATCH 1447/1892] Merged PR 1671: keep no grad nodes in graph --- .../fx/concrete_trace_utils/concrete_proxy.py | 12 +++ .../concrete_trace_utils/concrete_tracer.py | 61 ++++++++++++- .../parser/fx/concrete_trace_utils/utils.py | 3 + cube/graph/parser/fx/parser.py | 24 ++++- tests/parser/test_no_grad.py | 88 +++++++++++++++++++ 5 files changed, 181 insertions(+), 7 deletions(-) create mode 100644 tests/parser/test_no_grad.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 962f805f..387f8fae 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -197,6 +197,18 @@ def __contains__(self, item) -> bool: # should only be in iterable return self.value.__contains__(item) + def __enter__(self): + if getattr(self.value.__class__.__enter__, "__fx_already_patched", False): + return self.value.__enter__() + else: + return self.value.__class__.__enter__(self) + + def __exit__(self, exc_type, exc_value, traceback): + if getattr(self.value.__class__.__exit__, "__fx_already_patched", False): + return self.value.__exit__(exc_type, exc_value, traceback) + else: + return self.value.__class__.__exit__(self, exc_type, exc_value, traceback) + @compatibility(is_backward_compatible=True) def keys(self): # to detect if in executing `**proxy` diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 3012500e..259c1ad5 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -89,6 +89,9 @@ def __exit__(self, *args): _orig_agfunc_apply, _orig_torch_assert, + _orig_torch_no_grad, + _orig_torch_no_grad_enter, + _orig_torch_no_grad_exit, _orig_type, _orig_isinstance, @@ -133,6 +136,9 @@ def __exit__(self, *args): extra_side_effectful_functions = { operator.setitem, builtins.next, + _orig_torch_no_grad, + _orig_torch_no_grad_enter, + _orig_torch_no_grad_exit, } _side_effectful_functions = _side_effectful_functions.union(extra_side_effectful_functions) @@ -962,6 +968,18 @@ def torch_assert_wrapper(condition, message): condition = condition.value return _orig_torch_assert(condition, message) + @functools.wraps(_orig_torch_no_grad) + def torch_no_grad_wrapper(): + return self.create_proxy('call_function', _orig_torch_no_grad, (), {}) + + @functools.wraps(_orig_torch_no_grad_enter) + def torch_no_grad_enter_wrapper(no_grad): + return self.create_proxy('call_function', _orig_torch_no_grad_enter, (no_grad,), {}) + + @functools.wraps(_orig_torch_no_grad_exit) + def torch_no_grad_exit_wrapper(no_grad, exc_type, exc_value, traceback): + return self.create_proxy('call_function', _orig_torch_no_grad_exit, (no_grad, exc_type, exc_value, traceback,), {}) + self.agfunc_dict: dict[Type, Any] = {} self.autowrap_leaf_pairs = { id(_orig_torch_assert): torch_assert_wrapper, @@ -1099,6 +1117,11 @@ def getattr_wrapper(obj, *args): self.patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False) self.patcher.patch_method(torch.autograd.Function, "apply", agfunc_apply_wrapper, deduplicate=False) self.patcher.patch_method(torch, "_assert", torch_assert_wrapper, deduplicate=False) + # if class member functions and the class need to be wrapped together, + # wrap the member functions before wrap the class. + self.patcher.patch_method(_orig_torch_no_grad, "__enter__", torch_no_grad_enter_wrapper, deduplicate=False) + self.patcher.patch_method(_orig_torch_no_grad, "__exit__", torch_no_grad_exit_wrapper, deduplicate=False) + self.patcher.patch_method(torch, "no_grad", torch_no_grad_wrapper, deduplicate=False) self.patcher.patch_method(builtins, "map", map_wrapper, deduplicate=False) self.patcher.patch_method(builtins, "enumerate", enumerate_wrapper, deduplicate=False) @@ -1282,14 +1305,31 @@ def copy_attr_new(from_module: torch.nn.Module, to_module: torch.nn.Module, targ @staticmethod def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: - name = orig_method.__name__ - module = orig_method.__module__ if is_autograd_apply(orig_method): # for torch.autograd.Function return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' + + name = orig_method.__name__ + module = orig_method.__module__ + # if hasattr(orig_method, '__qualname__') and isinstance(orig_method.__qualname__, str): + # # if there has '.' in '__qualname__', it means this function is in a nested structure, + # # + # # for example, it is a method / function in a class: + # # torch.nn.Linear.forward.__module__ = torch.nn + # # torch.nn.Linear.forward.__name__ = forward + # # torch.nn.Linear.forward.__qualname__ = Linear.forward + # # + # # And in fx.node qualified name creating rule, the module also should include the class name, + # # in this example, the returned module should be `torch.nn.Linear`. + # # It is not the original meaning of a obj's module, but we need this workaround to reuse fx node. + # splited_names = orig_method.__qualname__.split('.') + # class_name, name = splited_names[:-1], splited_names[-1] + # module = '.'.join([module] + class_name) + if module == 'torch.autograd.grad_mode' and name in ['__enter__', '__exit__']: + return 'torch.autograd.grad_mode.no_grad' if module is not None: return module - elif hasattr(orig_method, '__qualname__')\ + if hasattr(orig_method, '__qualname__')\ and isinstance(orig_method.__qualname__, str) and orig_method.__qualname__.startswith('_VariableFunctionsClass.'): return 'torch._C._VariableFunctions' for guess in [torch, getattr(torch.nn, 'functional')]: @@ -1468,6 +1508,9 @@ def _retain_weight_consistency(root: torch.nn.Module): @functools.wraps(_orig_node_is_impure) def node_is_impure_wrapper(node): + if is_useless_no_grad_node(node): + return False + if node.op in {"placeholder", "output"}: return True @@ -1489,6 +1532,18 @@ def node_is_impure_wrapper(node): return False +def is_useless_no_grad_node(node: Node): + # keep the no_gard related nodes, but except useless situation: no node between __enter__ and __exit__ + if node.op == 'call_function': + if node.target is _orig_torch_no_grad_exit: + if node.prev.target is _orig_torch_no_grad_enter and node.prev.prev.target is _orig_torch_no_grad: + setattr(node.prev, '_is_impure', False) + setattr(node.prev.prev, '_is_impure', False) + return True + if node.target is _orig_torch_no_grad_enter or node.target is _orig_torch_no_grad: + return not getattr(node, '_is_impure', True) + return False + def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Union[Dict[str, Any], Tuple], *, diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index 8265e6df..0fcaa9a2 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -17,6 +17,9 @@ _orig_agfunc_apply: Callable = torch.autograd.function.Function.apply _orig_torch_assert: Callable = torch._assert +_orig_torch_no_grad: Callable = torch.no_grad +_orig_torch_no_grad_enter: Callable = torch.no_grad.__enter__ +_orig_torch_no_grad_exit: Callable = torch.no_grad.__exit__ _orig_type: Callable = builtins.type _orig_isinstance: Callable = builtins.isinstance diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 4b84966e..15e9d5f8 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -479,15 +479,31 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> @staticmethod def _find_module_of_method(orig_method: Callable[..., Any]) -> str: - name = orig_method.__name__ - module = orig_method.__module__ if getattr(orig_method, '__name__', None) == 'apply' and isinstance(getattr(orig_method, '__self__', None), Type) \ and issubclass(orig_method.__self__, torch.autograd.Function): # for torch.autograd.Function return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' - elif module is not None: + + name = orig_method.__name__ + module = orig_method.__module__ + # if hasattr(orig_method, '__qualname__') and isinstance(orig_method.__qualname__, str): + # # if there has '.' in '__qualname__', it means this function is in a nested structure, + # # + # # for example, it is a method / function in a class: + # # torch.nn.Linear.forward.__module__ = torch.nn + # # torch.nn.Linear.forward.__name__ = forward + # # torch.nn.Linear.forward.__qualname__ = Linear.forward + # # + # # And in fx.node qualified name creating rule, the module also should include the class name, + # # in this example, the returned module should be `torch.nn.Linear`. + # splited_names = orig_method.__qualname__.split('.') + # class_name, name = splited_names[:-1], splited_names[-1] + # module = '.'.join([module] + class_name) + if module == 'torch.autograd.grad_mode' and name in ['__enter__', '__exit__']: + return 'torch.autograd.grad_mode.no_grad' + if module is not None: return module - elif hasattr(orig_method, '__qualname__')\ + if hasattr(orig_method, '__qualname__')\ and isinstance(orig_method.__qualname__, str) and orig_method.__qualname__.startswith('_VariableFunctionsClass.'): return 'torch._C._VariableFunctions' for guess in [torch, getattr(torch.nn, 'functional')]: diff --git a/tests/parser/test_no_grad.py b/tests/parser/test_no_grad.py new file mode 100644 index 00000000..64d771ed --- /dev/null +++ b/tests/parser/test_no_grad.py @@ -0,0 +1,88 @@ +""" +USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_no_grad.py +""" +from typing import List +import torch + +import cube +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops + +cube.init() + + +class TestModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(512, 10) + + def forward(self, x: torch.Tensor): + # this no grad will be dce + with torch.no_grad(): + pass + + # this no grad will not be dce + with torch.no_grad(): + res = self.fc(x) + + return {'res': res, 'loss': res.sum()} + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int = 256) -> None: + self.sample = torch.rand( + [batch_size, 512], + dtype=torch.float32, + device=torch.cuda.current_device() + ) + super().__init__(batch_size, (0,)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +def test_no_grad(): + + model = TestModel() + dataloader = TestDataLoader() + + def policy(graph, resource): + print(graph.extra_repr()) + assert resource.ngpus == 1 + for node in graph.nodes(): + if isinstance(node, IRDimops): + print(f'# {node.anno}') + print(node) + elif isinstance(node, (IRFwOperation, IRDataOperation)): + print(node) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + model = cube.SemanticModel(model) + + @cube.compile(model, dataloader, PAS=policy, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) + def eval_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out['loss'].backward() + # return out + + model = model.get_gen_module() + + for idx in range(3): + eval_iter(model, dataloader) + print(f"iter {idx}/3") + + +if __name__ == '__main__': + # consecutive no_grad __enter__ __exit__ sequences will be dce + test_no_grad() From 01e0220f1188a0d469b4d14a04ee755d8cc9841d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jul 2023 19:38:41 -0700 Subject: [PATCH 1448/1892] fix gradient setup for graph inputs --- cube/compiler.py | 1 + cube/graph/segment.py | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index a54900dd..08b4683a 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -99,6 +99,7 @@ def train_iter(model, dataloader): arg = IRFullTensor(arg.shape, name='tensor', requires_grad=arg.requires_grad, dtype=DType2IRDType.map(arg.dtype)).tosub() + arg.grad = arg.parent.grad.tosub() if arg.requires_grad else None else: arg= IRObject('obj') inputs.append(arg) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 48da8a69..bf5109d2 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1065,11 +1065,14 @@ def extra_repr(self) -> str: @staticmethod def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[IRObject]: - """ - Get objects from val of complex data type - Support complex of types: List, Tuple, Dict, torch.Tensor, object + """Get all IRObjects from a complex data structure + + Supported complex of types: List, Tuple, Dict, IRTensor, IRObject - @param val Any + Args: + val (Any): the complex data structure to be modified + _objects (List[IRObject] | None): + if provided, the objects will be appened into this @return _objects List[IRObject]: all IRObject """ @@ -1087,14 +1090,16 @@ def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[ @staticmethod def modify_objects_of_complex(val: Any, modifier: Callable) -> Any: - """ - Get objects from val of complex data type - Support complex of types: List, Tuple, Dict, torch.Tensor, object - - @param val Any - @param modifier Callable: modify IRObject to another one + """Return a complex data structure where its IRObjects are in-placemently modified + + Supported complex of types: List, Tuple, Dict, IRTensor, IRObject + + Args: + val (Any): the complex data structure to be modified + modifier (Callable): an inplacement modifier that takes an IRObject and return None - @return new_val List[IRObject]: all IRObject + Return: + new_val (Any): complex data structure with modified IRObjects """ rcall = IRSegment.modify_objects_of_complex if isinstance(val, tuple): From ebab2a63d95da345a8dbf5c12bb5a382b6361aaf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jul 2023 20:08:02 -0700 Subject: [PATCH 1449/1892] set graph input always to be true --- cube/compiler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cube/compiler.py b/cube/compiler.py index 08b4683a..b5dc5cc8 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -96,8 +96,12 @@ def train_iter(model, dataloader): dataloader = arg arg = SemanticDataLoader(dataloader) elif isinstance(arg, torch.Tensor): + # note: we will always set tensor to require gradient, which may + # generate backward communications in adapter. However, as long as + # the data doesn't require gradient in real runtime, the backward + # communication will not be triggered. arg = IRFullTensor(arg.shape, name='tensor', - requires_grad=arg.requires_grad, + requires_grad=True, dtype=DType2IRDType.map(arg.dtype)).tosub() arg.grad = arg.parent.grad.tosub() if arg.requires_grad else None else: From 5bd7941dd94ab3c825e0066dfc7c38523b94fafa Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jul 2023 22:25:51 -0700 Subject: [PATCH 1450/1892] output adapter generation --- cube/graph/gener/gen.py | 120 +++++++++++++++++++++------------------- 1 file changed, 64 insertions(+), 56 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index b9ba294d..58661963 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -345,8 +345,6 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # consumers can be operators and graph outputs fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) - if ftensor in output_consumer: - fctensors = fctensors + tuple(fwop.input(0) for fwop in output_consumer[ftensor]) fctensors = expand_devices(fctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" @@ -354,73 +352,83 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bconsumers, bctensors = [], [] if isinstance(ftensor.grad, IRFullTensor): bproducers, bptensors = bgraph.producers(ftensor.grad), bgraph.ptensors(ftensor.grad) - if ftensor in output_consumer: - bptensors = bptensors + tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) bptensors = expand_devices(bptensors, producer=True) - assert all(len(ptensor.device) == 1 for ptensor in bptensors), ( - f"Not support for multi-device:\n" - f"{[ptensor.device for ptensor in bptensors]}" - f"{[ptensor.cell for ptensor in bptensors]}" - ) bconsumers, bctensors = bgraph.consumers(ftensor.grad), bgraph.ctensors(ftensor.grad) if ftensor in input_producer: bctensors = bctensors + tuple(fwop.output(0).grad for fwop in input_producer[ftensor]) bctensors = expand_devices(bctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" - if skip(fptensors, fctensors) and skip(bptensors, bctensors): - continue + fadapters = [] - fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) - if fadapter is None: - continue + # activation -> activation generation + if (not skip(fptensors, fctensors)) or (not skip(bptensors, bctensors)): + fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) + if fadapter is not None: + fadapters.append(fadapter) - if not isinstance(graph, IRGraph): - if not (fadapter.differentiable or fadapter.mirror is None): - raise NotImplementedError( - "Require adapter to be differentiable for nested IRAdapter.\n" - "Condition to be differentiable: prodcuers have same device set with consumers\n" - f"Failed FullTensor: {ftensor}" - f"{graph.debug_tensor_map_str(ftensor)}" - f"Failed FullTensor.grad: {ftensor.grad}" - f"{bgraph.debug_tensor_map_str(ftensor.grad) if ftensor.grad is not None else None}" - ) + # activation -> output generation + if ftensor in output_consumer: + # TODO: dedup adapter if the output is same with activation + fctensors = tuple(fwop.input(0) for fwop in output_consumer[ftensor]) + fctensors = expand_devices(fctensors, consumer=True) + bptensors = [] + if isinstance(ftensor.grad, IRFullTensor): + bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) + bptensors = expand_devices(bptensors, producer=True) + + fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) + if fadapter is not None: + fadapters.append(fadapter) + + # insert adapters + for fadapter in fadapters: + if not isinstance(graph, IRGraph): + if not (fadapter.differentiable or fadapter.mirror is None): + raise NotImplementedError( + "Require adapter to be differentiable for nested IRAdapter.\n" + "Condition to be differentiable: prodcuers have same device set with consumers\n" + f"Failed FullTensor: {ftensor}" + f"{graph.debug_tensor_map_str(ftensor)}" + f"Failed FullTensor.grad: {ftensor.grad}" + f"{bgraph.debug_tensor_map_str(ftensor.grad) if ftensor.grad is not None else None}" + ) - badapter: Optional[IRAdapter] = fadapter.mirror + badapter: Optional[IRAdapter] = fadapter.mirror - if (badapter is not None and len(fadapter.prims) == 0 and len(badapter.prims) == 0) or \ - (badapter is None and len(fadapter.prims) == 0): - continue + if (badapter is not None and len(fadapter.prims) == 0 and len(badapter.prims) == 0) or \ + (badapter is None and len(fadapter.prims) == 0): + continue - # insert forward adapter - # graph.insert(fadapter, max(producers) + 1) - if len(fconsumers) > 0: - fidx = min(graph.nodes().index(c) for c in fconsumers) - else: - # no consumer: find the last forward node - for fidx, node in enumerate(graph.nodes()[::-1]): - if node.isfw(): - fidx = graph.nnodes - fidx - break - graph.insert(fadapter, fidx) - # setup recompute - if allow_recompute: - if fidx > 0: - prev_node = graph.node(fidx-1) - if isinstance(prev_node, (IRFwOperation, IRAdapter)): - fadapter.recompute = prev_node.recompute - - # insert backward adapter - if badapter is not None: - assert isinstance(badapter, IRAdapter) - assert isinstance(bgraph, IRSegment) - if len(bproducers) > 0: - bidx = max(bgraph.nodes().index(p) for p in bproducers) + 1 + # insert forward adapter + # graph.insert(fadapter, max(producers) + 1) + if len(fconsumers) > 0: + fidx = min(graph.nodes().index(c) for c in fconsumers) else: - # no producer: find the first backward node - for bidx, node in enumerate(bgraph.nodes()): - if not node.isfw(): break - bgraph.insert(badapter, bidx) + # no consumer: find the last forward node + for fidx, node in enumerate(graph.nodes()[::-1]): + if node.isfw(): + fidx = graph.nnodes - fidx + break + graph.insert(fadapter, fidx) + # setup recompute + if allow_recompute: + if fidx > 0: + prev_node = graph.node(fidx-1) + if isinstance(prev_node, (IRFwOperation, IRAdapter)): + fadapter.recompute = prev_node.recompute + + # insert backward adapter + if badapter is not None: + assert isinstance(badapter, IRAdapter) + assert isinstance(bgraph, IRSegment) + if len(bproducers) > 0: + bidx = max(bgraph.nodes().index(p) for p in bproducers) + 1 + else: + # no producer: find the first backward node + for bidx, node in enumerate(bgraph.nodes()): + if not node.isfw(): break + bgraph.insert(badapter, bidx) # generate adapter for each segment segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] From 0192da58ea832cbec366033fe727b5edbe9641b5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jul 2023 23:07:46 -0700 Subject: [PATCH 1451/1892] remove possible failed cases --- cube/codegen/emit.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 71e27f66..eaa67cd0 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -1,4 +1,5 @@ from typing import Generator, Iterable, List, Any, Optional, Tuple +import logging from cube.ir.cten import IRCell, IRTensor, IRObject from cube.ir.dtype import IRDType @@ -13,6 +14,8 @@ from cube.flags import CompileFlag +_logger = logging.getLogger(__name__) + class IRValue: @@ -263,7 +266,7 @@ def get_backward_callsite_io_tensors(bwop: IRCell) -> Tuple: input_grads = [t for t in bwop.outputs() if isinstance(t, IRSubTensor)] output_grads = [t for t in bwop.inputs() if isinstance(t, IRSubTensor)] - input_tensors = [grad2tensor[g] for g in input_grads] - output_tensors = [grad2tensor[g] for g in output_grads] + input_tensors = [grad2tensor[g] for g in input_grads if g in grad2tensor] + output_tensors = [grad2tensor[g] for g in output_grads if g in grad2tensor] return input_tensors, output_tensors, output_grads, input_grads From 1939a614b7933a6e3f4853b6ca43f4caeb370318 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 20 Jul 2023 23:53:23 -0700 Subject: [PATCH 1452/1892] skip adapter for output tensor if possible --- cube/graph/gener/gen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 58661963..38d438f9 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -376,10 +376,10 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: if isinstance(ftensor.grad, IRFullTensor): bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) bptensors = expand_devices(bptensors, producer=True) - - fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) - if fadapter is not None: - fadapters.append(fadapter) + if (not skip(fptensors, fctensors)) or (not skip(bptensors, bctensors)): + fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) + if fadapter is not None: + fadapters.append(fadapter) # insert adapters for fadapter in fadapters: From 5013e9f1f28cfa6ceb758e95cc95043f9b6ee9e7 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 24 Jul 2023 06:29:44 +0000 Subject: [PATCH 1453/1892] Merged PR 1672: support customize autograd function --- cube/graph/parser/converter.py | 4 +- .../concrete_trace_utils/concrete_tracer.py | 2 +- cube/graph/parser/fx/mapping.py | 2 +- cube/graph/parser/register.py | 83 ++++++++++--- tests/parser/test_cus_autograd.py | 111 ++++++++++++++++++ 5 files changed, 181 insertions(+), 21 deletions(-) create mode 100644 tests/parser/test_cus_autograd.py diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 742b3651..ddd3f646 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -40,8 +40,8 @@ def convert_model(model: torch.nn.Module, IRGraph: IRGraph of model """ # get registered leaf function - customized_funcs = CustomizedOps.kOpRuntime.values() - leaf_functions = {func: ([], False, None) for func in customized_funcs} + autowrap_funcs = [CustomizedOps.kOpRuntime.get(sign, None) for sign in CustomizedOps.kOpAutowrap] + leaf_functions = {func: ([], True, None) for func in autowrap_funcs if func is not None} # step 1: trace model diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 259c1ad5..d3606b66 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -389,7 +389,7 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] result = result.cpu() elif isinstance(result, (list, dict, tuple)): result = tree_map(to_cpu, result) - elif isinstance(result, (int, bool, torch.device, torch.dtype)): + elif isinstance(result, (int, bool, torch.device, torch.dtype, _orig_torch_no_grad)) or result is None: # avoid too noisy warning pass else: diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 8ba3b9e0..8069dc93 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -134,7 +134,7 @@ def exist(signature: str) -> bool: # # __ftemplate('embedding'): function.Embedding, # - # __ftemplate('cross_entropy'): function.CrossEntropy, + __ftemplate('cross_entropy'): function.CrossEntropy, # # # creators __ttemplate('empty'): function.Empty, diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index 9c9f9c29..b54d705d 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -23,6 +23,10 @@ class CustomizedOps: kOpRuntime: Dict[str, Callable] = {} # signature -> runtime function implementation code kOpCodeDef: Dict[str, str] = {} + # original signature (xxx.xxx.xxx) -> pure signature (xxx_xxx_xxx) + kOpSignMap: Dict[str, str] = {} + # the function in it will be autowrapped by tracer. + kOpAutowrap: List[str] = [] @staticmethod def map(signature: str) -> Callable: @@ -34,7 +38,7 @@ def map(signature: str) -> Callable: Returns: Callable: IRDimop creation function """ - signature = signature.split('.')[-1] + signature = CustomizedOps.pure_signature(signature) if signature in CustomizedOps.kOpMap: return partial(CustomizedOps.kOpMap[signature], signature=signature) else: @@ -43,11 +47,12 @@ def map(signature: str) -> Callable: @staticmethod def exist(signature: str) -> bool: """Check if the signature is registered""" - signature = signature.split('.')[-1] + signature = CustomizedOps.pure_signature(signature) return signature in CustomizedOps.kOpMap @staticmethod - def register(signature: str, op: Callable, code: str, runtime_fn: Callable): + def register(signature: str, op: Callable, code: str, runtime_fn: Callable, + keep_full_name: bool = False, trace_autowrap: bool = True): """Register an operator Args: @@ -55,6 +60,8 @@ def register(signature: str, op: Callable, code: str, runtime_fn: Callable): op (Callable): IRDimop creation function code (str): runtime function implementation code runtime_fn (Callable): runtime function + keep_full_name (bool): if set True, the full name will be kept, `.` in name will be replaced to `_` + trace_autowrap (bool): if set True, the function will be autowrapped by tracer. Returns: None @@ -62,11 +69,30 @@ def register(signature: str, op: Callable, code: str, runtime_fn: Callable): builtins = ['_operator', 'torch', 'cube.runtime.function'] if any(signature.startswith(builtin) for builtin in builtins): raise RuntimeError(f"Cannot register operators with signature starting from any of {builtins}") - signature = signature.split('.')[-1] + signature = CustomizedOps.create_pure_signature(signature, keep_full_name) assert signature not in CustomizedOps.kOpMap, f"function {signature} is already registered" CustomizedOps.kOpMap[signature] = op CustomizedOps.kOpRuntime[signature] = runtime_fn CustomizedOps.kOpCodeDef[signature] = code + if trace_autowrap and signature not in CustomizedOps.kOpAutowrap: + CustomizedOps.kOpAutowrap.append(signature) + elif not trace_autowrap and signature in CustomizedOps.kOpAutowrap: + CustomizedOps.kOpAutowrap.pop(signature) + + @staticmethod + def create_pure_signature(signature: str, keep_full_name: bool) -> str: + if keep_full_name: + pure_signature = signature.replace('__main__.', '', 1) if signature.startswith('__main__.') else signature + pure_signature = pure_signature.replace('.', '_') + CustomizedOps.kOpSignMap[signature] = pure_signature + return pure_signature + return signature.split('.')[-1] + + @staticmethod + def pure_signature(signature: str) -> str: + if signature in CustomizedOps.kOpSignMap: + return CustomizedOps.kOpSignMap[signature] + return signature.split('.')[-1] def register(anno: str, name: Optional[str] = None, @@ -108,13 +134,21 @@ def funcname(x: torch.Tensor, b: int = 4): Returns: fn (Callable): the runtime function """ + from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply + def decorator(fn: Callable): if not callable(fn): raise TypeError("Expected a function") - fsig = fn.__name__ - op_name = name if name is not None else fsig - args = inspect.signature(fn) - arg_names = list(args.parameters.keys()) + if is_autograd_apply(fn): + fsig = CustomizedOps.create_pure_signature(f'{fn.__self__.__module__}.{fn.__self__.__name__}.apply', True) + op_name = name if name is not None else fn.__name__ + args = inspect.signature(fn.__self__.forward) + arg_names = list(args.parameters.keys())[1:] + else: + fsig = fn.__name__ + op_name = name if name is not None else fsig + args = inspect.signature(fn) + arg_names = list(args.parameters.keys()) # get argument types arg_kinds = input_type_annos if input_type_annos is not None else \ [args.parameters[name].annotation for name in arg_names] @@ -134,16 +168,28 @@ def decorator(fn: Callable): # get customized op code if code_impl_pattern == 'import': - import_path = inspect.getmodule(fn).__name__ - if import_path == '__main__': - _logger.warning(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' - f'This may cause error when the function has inner functions from other modules. ' - f'To solve this, define the function in another module and import into main', stacklevel=0) - code = inspect.getsource(fn) - code = code[code.index('def'):] + if is_autograd_apply(fn): + import_path = inspect.getmodule(fn.__self__).__name__ + if import_path == '__main__': + _logger.warning(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' + f'This may cause error when the function has inner functions from other modules. ' + f'To solve this, define the function in another module and import into main', stacklevel=0) + code = inspect.getsource(fn.__self__) + code = code[code.index(f'class {fn.__self__.__name__}'):] + f'\n{fsig}={fn.__self__.__name__}.apply' + else: + code = f'from {import_path} import {fn.__self__.__name__}\n{fsig}={fn.__self__.__name__}.apply' else: - code = f'from {import_path} import {fsig}' + import_path = inspect.getmodule(fn).__name__ + if import_path == '__main__': + _logger.warning(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' + f'This may cause error when the function has inner functions from other modules. ' + f'To solve this, define the function in another module and import into main', stacklevel=0) + code = inspect.getsource(fn) + code = code[code.index('def'):] + else: + code = f'from {import_path} import {fsig}' elif code_impl_pattern == 'source': + assert not is_autograd_apply(fn), 'Only support code_impl_pattern="import" for autograd.Function.apply.' code = inspect.getsource(fn) code = code[code.index('def'):] else: @@ -161,7 +207,10 @@ def udfop(*args, signature=None, **kwargs): return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, transform_rules=rules, **kwargs) _logger.info(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') - CustomizedOps.register(fsig, udfop, code, fn) + if is_autograd_apply(fn): + CustomizedOps.register(f'{fn.__self__.__module__}.{fn.__self__.__name__}.apply', udfop, code, fn, True, False) + else: + CustomizedOps.register(fsig, udfop, code, fn) return fn return decorator diff --git a/tests/parser/test_cus_autograd.py b/tests/parser/test_cus_autograd.py new file mode 100644 index 00000000..d2570a65 --- /dev/null +++ b/tests/parser/test_cus_autograd.py @@ -0,0 +1,111 @@ +""" +USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_cus_autograd.py +""" +import torch + +import cube +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.dimops import IRDimops +from cube.graph.parser import register + +cube.init() + + +class GeLU(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input: torch.Tensor, bias: torch.Tensor): + ctx.save_for_backward(input, bias) + return GeLU.bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = GeLU.bias_gelu_back(grad_output, bias, input) + return tmp, tmp + + @staticmethod + def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + @staticmethod + def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff*g + + +class TestModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(512, 10) + self.bias = torch.nn.Parameter(torch.rand(10)) + + def forward(self, x: torch.Tensor): + res = GeLU.apply(self.fc(x), self.bias) + loss = res.sum() + return {'res': res, 'loss': loss} + + +class TestDataLoader(cube.runtime.syndata.CubeDataLoader): + + def __init__(self, batch_size: int = 256) -> None: + self.sample = torch.rand( + [batch_size, 512], + dtype=torch.float32, + device=torch.cuda.current_device() + ) + super().__init__(batch_size, (0,)) + + def __iter__(self): + return self + + def __next__(self): + return self.sample + + def set_batch_size(self, batch_size: int): + return True + + +def test_cus_autograd(): + register('* h, h -> * h')(GeLU.apply) + + model = TestModel() + dataloader = TestDataLoader() + + def policy(graph, resource): + print(graph.extra_repr()) + assert resource.ngpus == 1 + for node in graph.nodes(): + if isinstance(node, IRDimops): + print(f'# {node.anno}') + print(node) + elif isinstance(node, (IRFwOperation, IRDataOperation)): + print(node) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + model = cube.SemanticModel(model) + + @cube.compile(model, dataloader, PAS=policy, load_content=False, + model_dummy_inputs={'x': next(dataloader)}) + def eval_iter(model, dataloader): + data = next(dataloader) + out = model(data) + out['loss'].backward() + # return out + + model = model.get_gen_module() + + for idx in range(3): + eval_iter(model, dataloader) + print(f"iter {idx}/3") + + +if __name__ == '__main__': + test_cus_autograd() From cefcdc8a7c051d52a4458059384662225259cf0d Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 25 Jul 2023 03:57:27 +0000 Subject: [PATCH 1454/1892] Merged PR 1679: codegen: refine converter and init unit tests codegen: refine converter and init unit tests --- .gitignore | 5 + README.md | 46 +- cube/flags.py | 3 - cube/graph/parser/__init__.py | 1 - cube/graph/parser/converter.py | 153 ++-- .../kwargs_shape_prop/kwargs_interpreter.py | 3 +- cube/graph/parser/fx/parser.py | 26 +- cube/graph/parser/script/mapping.py | 201 ----- cube/graph/parser/script/parser.py | 845 ------------------ cube/program.py | 16 +- cube/runtime/adapter/collectives.py | 12 +- requirements-dev.txt | 8 + requirements.txt | 2 +- tests/test_parser.py | 61 -- tests/test_prim_loop.py | 156 ---- tox.ini | 18 + unit_tests/__init__.py | 0 unit_tests/graph/__init__.py | 0 unit_tests/graph/parser/__init__.py | 0 unit_tests/graph/parser/test_converter.py | 75 ++ unit_tests/launch_torchrun.py | 29 + unit_tests/runtime/__init__.py | 0 .../runtime/test_runtime_collectives.py | 233 +++++ unit_tests/test_utils.py | 11 + 24 files changed, 541 insertions(+), 1363 deletions(-) delete mode 100644 cube/graph/parser/script/mapping.py delete mode 100644 cube/graph/parser/script/parser.py create mode 100644 requirements-dev.txt delete mode 100644 tests/test_parser.py delete mode 100644 tests/test_prim_loop.py create mode 100644 tox.ini create mode 100644 unit_tests/__init__.py create mode 100644 unit_tests/graph/__init__.py create mode 100644 unit_tests/graph/parser/__init__.py create mode 100644 unit_tests/graph/parser/test_converter.py create mode 100644 unit_tests/launch_torchrun.py create mode 100644 unit_tests/runtime/__init__.py create mode 100644 unit_tests/runtime/test_runtime_collectives.py create mode 100644 unit_tests/test_utils.py diff --git a/.gitignore b/.gitignore index 715393f3..7bcd3abd 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,11 @@ __pycache__ .vs/ .vscode/ +.tox/ +.coverage +.coverage.* +htmlcov/ + benchmark/megatron/Megatron-LM benchmark/deepspeed/Megatron-DeepSpeed diff --git a/README.md b/README.md index b7a01fae..097cfe5c 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ * ### Debug for model parsing check on single Device ```shell -PYTHONPATH=.:$PYTHONPATH SINGLE_DEV_MODE=1 python examples/mlp/linears.py +PYTHONPATH=.:$PYTHONPATH SINGLE_DEV_MODE=1 python examples/mlp/linears.py ``` @@ -45,7 +45,7 @@ PYTHONPATH=.:$PYTHONPATH SINGLE_DEV_MODE=1 python examples/mlp/linears.py python setup.py develop ``` -* ### Run Example +* ### Run Example [Micro Benchmark] Run a mutiple MLP Model ```sh @@ -82,7 +82,7 @@ class SampleClass: def public_method(self, a, b): """Performs operation blah. - + Long description here. Args: @@ -94,4 +94,42 @@ class SampleClass: k (int): xxx """ # function implementation goes here -``` \ No newline at end of file +``` + +## Run unit tests + +We use `tox` to run unit tests. You should install `tox` in your development environemnt +``` +pip install tox +``` +Currently we only use python3.10 to run tests. If you don't have python3.10 in your system, you can use conda. After conda is installed, you should install tox conda plugin by running +``` +pip install tox-conda +``` +After tox is ready, you can run all the unit test by running +``` +tox +``` +Please note tox will reuse the same virtual environment which is initialized by installing all packages listed in `requirements.txt` and `requirements-dev.txt`. If any of above files are modified, you should re-create virtual environment by running +``` +tox -r +``` + +### Run unit tests in vscode + +VS Code has a great support to unit tests. You can run/debug every tests easily in VS Code. Please refer to this document to set up your environment https://code.visualstudio.com/docs/python/testing + +Another trick is, if you want to step into pakcage source code, you can add the following config to your .vscode/launch.json: +``` +{ + "name": "Debug Unit Test", + "type": "python", + "request": "test", + "justMyCode": false, +}, +``` + +### Write Unit Tests +1. If you need to use torchrun, please refer to `unit_test/launch_torchrun.py`, and you can find examples in `unit_tests/runtime/test_runtime_collectives.py`. Please note that `torchrun` is very slow, you should reduce its usage as possible. +2. If you want to mock up any functions/methods, please use pytest-mock. +3. **NOTE**: The name of test files and test functions must start with `test_` diff --git a/cube/flags.py b/cube/flags.py index e1317f1c..d93bcbdf 100644 --- a/cube/flags.py +++ b/cube/flags.py @@ -16,10 +16,7 @@ def _to_int(s: str, default=0) -> int: class CompileFlag: - # ================ compiling ======================== - use_torchfx = _to_bool('USE_TORCHFX') # using torch.fx or torchscript as frontend to capture dataflow graph - use_default_fx_tracer = _to_bool('USE_DEFAULT_FX_TRACER') # using default fx tracer or more powerful concrete_tracer # worker sleep in seconds worker_sleep = _to_int('WORKER_SLEEP') disable_intra_rvd = _to_bool('DISABLE_INTRA_RVD') diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index d3642c72..c9b1ba92 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,4 +1,3 @@ -from cube.graph.parser.script.parser import ScriptModuleParser from cube.graph.parser.fx.parser import FxModuleParser, FxFuncOpTracer from cube.graph.parser.converter import convert_model from cube.graph.parser.register import register \ No newline at end of file diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index ddd3f646..45ff6ef4 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,94 +1,96 @@ -from typing import Optional, List +from typing import Any, Dict, Optional, List, Union import logging +from pathlib import Path +import os from cube.ir.tensor import IRFullTensor -from cube.graph.parser import ScriptModuleParser from cube.graph.parser.register import CustomizedOps from cube.graph import IRGraph from cube.flags import CompileFlag +from cube.graph.parser.fx.parser import FxModuleParser +from cube.graph.parser.fx.concrete_trace_utils import concrete_trace + import torch import torch.fx _logger = logging.getLogger(__name__) -try: +try: import apex HAS_APEX = True except: HAS_APEX = False -def convert_model(model: torch.nn.Module, - input_shapes: Optional[ List[List[int],] ] = None, - dummy_input = None, - save_content: bool = True, - dynamic_shape: bool = False) -> IRGraph: - """Convert torch.nn.Module based model into IRGraph +def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: + """ + Convert torch.nn.Module based model into torch.fx.GraphModule Args: model (torch.nn.Module): single-device model description - input_shapes (Optional[ List[List[int],] ]): - input shapes of model, only required for torch.jit.script parser - dummy_input (Optional[Any]): - dummy input of model, only required for torch.fx parser - save_content (bool): - whether to save the content of model and load it into generated model. Default True. - dynamic_shape (bool): - whether to use dynamic shape. Default False. - + dummy_input (Dict[str, Any]): + dummy input of model, the keys are the names of forward arguments. Returns: - IRGraph: IRGraph of model + torch.fx.GraphModule representation of model """ # get registered leaf function autowrap_funcs = [CustomizedOps.kOpRuntime.get(sign, None) for sign in CustomizedOps.kOpAutowrap] leaf_functions = {func: ([], True, None) for func in autowrap_funcs if func is not None} - - # step 1: trace model - if CompileFlag.use_torchfx: - _logger.info('use concrete torch.fx tracer') - from cube.graph.parser.fx.concrete_trace_utils import concrete_trace - from cube.graph.parser.fx.parser import FxModuleParser - if HAS_APEX: - leaf_module = ( - # torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, - apex.normalization.FusedLayerNorm, - # NOTE: the following modules also have different behavior depending on self.training. but currently in used. - # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, - # torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, - # torch.nn.LazyBatchNorm1d, torch.nn.LazyBatchNorm2d, torch.nn.LazyBatchNorm3d, torch.nn.SyncBatchNorm, - # torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, - # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, - ) - else: - _logger.warning('apex package is not installed') - leaf_module = None - traced_model = concrete_trace( - model, - dummy_input, - use_operator_patch=True, - leaf_module=leaf_module, - autowrap_leaf_function=leaf_functions, - cpu_offload=True, - record_frames=not CompileFlag.disable_code_line_info, - ) - else: - _logger.info('use torch.jit.script tracer') - traced_model = torch.jit.script(model) - - # step 2: convert traced model into IRGraph - if CompileFlag.use_torchfx: - FxModuleParser.save_content = save_content - FxModuleParser.dynamic_shape = dynamic_shape - _logger.info(f"use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") - inputs, nodes, outputs = FxModuleParser.parse(traced_model, dummy_input) - module_name = model.__class__.__name__ + if HAS_APEX: + leaf_module = ( + # torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, + apex.normalization.FusedLayerNorm, + # NOTE: the following modules also have different behavior depending on self.training. but currently in used. + # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, + # torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, + # torch.nn.LazyBatchNorm1d, torch.nn.LazyBatchNorm2d, torch.nn.LazyBatchNorm3d, torch.nn.SyncBatchNorm, + # torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, + # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, + ) else: - if dynamic_shape: - _logger.warning('dynamic shape is not supported in torch.jit.script') - ScriptModuleParser.save_content = save_content - inputs, nodes, outputs = ScriptModuleParser.parse_module(traced_model, input_shapes) - module_name = traced_model.original_name + _logger.warning('apex package is not installed') + leaf_module = None + traced_model = concrete_trace( + model, + dummy_input, + use_operator_patch=True, + leaf_module=leaf_module, + autowrap_leaf_function=leaf_functions, + cpu_offload=True, + record_frames=not CompileFlag.disable_code_line_info, + ) + return traced_model + + +def to_ir_graph( + traced_model: torch.fx.GraphModule, + dummy_input: Dict[str, Any], + attr_save_dir: Union[str, Path], + dynamic_shape: bool = False, +) -> IRGraph: + """Convert torch.fx.GraphModule based model into IRGraph + + Args: + traced_model (torch.fx.GraphModule): single-device model description in fx format + dummy_input (Dict[str, Any]): + dummy input of model, the keys are the names of forward arguments. + dynamic_shape (bool): + whether to use dynamic shape. Default False. + attr_save_dir (Union[str, Path]): directory to save content (attribtes) + + Returns: + IRGraph: IRGraph of model + """ + FxModuleParser.save_content = True + FxModuleParser.dynamic_shape = dynamic_shape + _logger.info(f"use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") + + inputs, nodes, outputs = FxModuleParser.parse( + traced_model, dummy_input, + attr_save_dir=attr_save_dir + ) + module_name = traced_model.__class__.__name__ for input in inputs: if isinstance(input, IRFullTensor): @@ -97,3 +99,26 @@ def convert_model(model: torch.nn.Module, graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) return graph + +def convert_model( + model: torch.nn.Module, + dummy_input: Dict[str, Any], + attr_save_dir: Union[str, Path], + dynamic_shape: bool = False +) -> IRGraph: + """Convert torch.nn.Module based model into IRGraph + + Args: + model (torch.nn.Module): single-device model description + dummy_input (Dict[str, Any]): + dummy input of model, the keys are the names of forward arguments. + dynamic_shape (bool): + whether to use dynamic shape. Default False. + attr_save_dir (Union[str, Path]): directory to save content (attribtes) + + Returns: + IRGraph: IRGraph of model + """ + traced_model = to_fx_graph(model, dummy_input) + graph = to_ir_graph(traced_model, dummy_input, attr_save_dir, dynamic_shape) + return graph diff --git a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py index e8add705..a7ced04f 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py +++ b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py @@ -129,7 +129,8 @@ def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[s remaining_keys = [key for key in self.concrete_kwargs if key not in self.used_concrete_kwargs] return {key: self.concrete_kwargs[key] for key in remaining_keys} elif target.startswith('*'): - assert self.concrete_kwargs is None, 'unexpected positional args in kwargs mode' + if self.concrete_kwargs is not None: + raise RuntimeError('unexpected positional args in kwargs mode') return list(self.args_iter) else: if self.concrete_kwargs is not None: diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 15e9d5f8..f60c46b3 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -1,6 +1,7 @@ import torch import enum import logging +from pathlib import Path from typing import Any, List, Tuple, Callable, Union, Dict, Type from cube.ir.operator import IRFwOperation @@ -64,22 +65,23 @@ def get_complex_data(val: Any, frame: Frame) -> Any: class FxModuleParser: """torch.fx module parser - + Attributes: save_content (bool): whether to save the content of the module dynamic_shape (bool): whether to parse the module with dynamic shape """ save_content: bool = True dynamic_shape: bool = False - + ATTR_CONTENT_FILE = 'fullmodel.pt' + ATTR_MAP_FILE = 'dist_param_map.pt' @staticmethod def shape_refine(shape: torch.Size) -> torch.Size: """Replacing scale shape [] to [1] - + Args: shape (torch.Size): tensor shape - + Returns: torch.Size: refined shape """ @@ -89,7 +91,8 @@ def shape_refine(shape: torch.Size) -> torch.Size: @staticmethod def parse(module: torch.fx.GraphModule, dummy_inputs: Dict[str, Any], - frame: Frame = None) \ + frame: Frame = None, + attr_save_dir='./') \ -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: """Parse torch.fx module into cube IR @@ -105,7 +108,7 @@ def parse(module: torch.fx.GraphModule, inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] _logger.info(f'> torch.fx parser: graph inputs: {inputs}') - + # shape propagation ShapeProp(module).propagate(dummy_inputs) # handle graph inputs @@ -188,15 +191,16 @@ def parse_complex_out(meta_out): ir_nodes = FxModuleParser.parse_node(node, module, frame) if ir_nodes is not None: all_ir_nodes += ir_nodes - + # output_nodes = [node for node in module.graph.nodes if node.op == 'output'] # assert len(output_nodes) == 1, f"get mutiple {len(all_ir_nodes)} output nodes" # output_val = frame.get_var(output_nodes[0].name) output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] if FxModuleParser.save_content: - frame.save_attr_content() - frame.save_attr_map() + attr_save_dir = Path(attr_save_dir) + frame.save_attr_content(attr_save_dir / FxModuleParser.ATTR_CONTENT_FILE) + frame.save_attr_map(attr_save_dir / FxModuleParser.ATTR_MAP_FILE) frame.pop_var() frame.pop_attr() @@ -359,7 +363,7 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, ir_node.set_output(0, output_val) else: frame.set_var(node.name, ir_node) - + _logger.info(f'parsing result: {ir_node}') return ir_nodes @@ -401,7 +405,7 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 ir_nodes = [] - + def generate_outputs(val: Any) -> Any: """Support complex data type of List, Tuple, Dict, Tensor/Object""" if isinstance(val, list): diff --git a/cube/graph/parser/script/mapping.py b/cube/graph/parser/script/mapping.py deleted file mode 100644 index 126a18f4..00000000 --- a/cube/graph/parser/script/mapping.py +++ /dev/null @@ -1,201 +0,0 @@ - -import torch - -from typing import Callable, Union -from functools import partial - -import cube.graph.function as function -from cube.ir.operator import IRFwOperation -from cube.graph.parser.register import CustomizedOps - - -class Sign2Op: - - @staticmethod - def map(signature: str) -> Callable[..., Union[IRFwOperation, int, float]]: - """ - Map the signature to GenericLogicalOp - """ - if signature in Sign2Op.kOpMap: - return partial(Sign2Op.kOpMap[signature], signature=signature) - if CustomizedOps.exist(signature): - return CustomizedOps.map(signature) - raise KeyError(f"{signature} is not supported yet") - - @staticmethod - def exist(signature: str) -> bool: - if signature in Sign2Op.kOpMap: - return True - if CustomizedOps.exist(signature): - return True - return False - - # functional templates - __ftemplate = lambda name: f'torch.nn.functional.{name}' - - # tensor template - __ttemplate = lambda name: f'torch.{name}' - - # runtime template - __rtemplate = lambda name: f'cube.runtime.function.function.{name}' - - - kOpMap = { - - # torch nn functional - - __ftemplate('linear') : function.Linear, - - __ttemplate('matmul'): function.Matmul, - - __ftemplate('softmax') : function.Softmax, - - __ftemplate('dropout') : function.Dropout, - - __ftemplate('gelu') : function.GeLU, - __ttemplate('gelu') : function.GeLU, - - __ftemplate('silu') : function.SiLU, - __ttemplate('silu') : function.SiLU, - - __ftemplate('_pad'): function.Pad, - - __ftemplate('layer_norm'): function.LayerNorm, - - __ftemplate('embedding'): function.Embedding, - - __ftemplate('cross_entropy'): function.CrossEntropy, - - # torch aten - - # creators - __ttemplate('zeros'): function.Zeros, - __ttemplate('ones'): function.Ones, - __ttemplate('tensor'): function.NewTensor, - __ttemplate('rand'): function.Rand, - __ttemplate('clone'): function.Clone, - - __ttemplate('add') : function.Add, - - __ttemplate('sub') : function.Sub, - - __ttemplate('mul') : function.Mul, - - __ttemplate('div') : function.Div, - - __ttemplate('floordiv') : function.FloorDiv, - - __ttemplate('neg'): function.Neg, - - __ttemplate('gt'): function.CompareGT, - __ttemplate('lt'): function.CompareLT, - __ttemplate('ge'): function.CompareGE, - __ttemplate('le'): function.CompareLE, - - __ttemplate('pow'): function.Pow, - - __ttemplate('sin'): function.Sin, - - __ttemplate('cos'): function.Cos, - - __ttemplate('tanh'): function.Tanh, - - __ttemplate('bmm') : function.BatchLinear, - - __ttemplate('sum') : function.Sum, - __ttemplate('mean') : function.Mean, - - __ttemplate('transpose') : function.Transpose, - - __ttemplate('view'): function.View, - - __ttemplate('reshape'): function.Reshape, - - __ttemplate('conv2d'): function.Conv2D, - - __ttemplate('conv3d'): function.Conv3D, - - __ttemplate('pad'): function.Pad, - - __ttemplate('select'): function.Select, - - __ttemplate('slice'): function.Slice, - - #pytorch1.11 - __ttemplate('select_scatter'): function.SelectScatter, - - __ttemplate('repeat'): function.Repeat, - - #pytorch1.11 - __ttemplate('linear'): function.Linear, - - __ttemplate('cat'): function.Cat, - - __ttemplate('stack'): function.Stack, - - __ttemplate('chunk'): function.Chunk, - - __ttemplate('flatten'): function.Flatten, - - __ttemplate('roll'): function.Roll, - - __ttemplate('adaptive_avg_pool1d'): function.AdaptiveAvgPool1d, - - # runtime functions - __rtemplate('anchor'): function.GraphAnchor, - - __rtemplate('identity'): function.Identity, - - __rtemplate('multiref'): function.MultiRef, - - __rtemplate('accum'): function.Accum, - - } - - -# see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h -# -# ScalarType enum is totally a PyTorch-internal object. Neither itself nor its underlying ints -# are accessible from its Python frontend. -class TorchScalarTypeEnumMap: - - @staticmethod - def map(underlying: int) -> torch.dtype: - - assert isinstance(underlying, int), """ - This function is to convert an underlying 'int' for a Torch-internal 'at::ScalarType' enum - to its corresponding Python-frontend 'torch.dtype' enum. - """ - - dtype = TorchScalarTypeEnumMap._fields[underlying] - - assert dtype is not None, f""" - Referenced to an unsupported ScalarType with underlying int being {underlying} - """ - - return dtype - - # Less used dtypes are masked out because PyTorch keeps **exposing and hiding** them recently - # from a view of Python frontend. - _fields = [ - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.half, - torch.float32, - torch.float64, - None, #torch.complex32, # complexHalf - None, #torch.complex64, # complexFloat - None, #torch.complex128, # complexDouble - torch.bool, - None, #torch.qint8, - None, #torch.quint8, - None, #torch.qint32, - None, #torch.bfloat16, - None, #torch.quint4x2, - None, #torch.quint2x4, - ] - - assert len(_fields) == 18, "Do not remove any item, mask it out with None" diff --git a/cube/graph/parser/script/parser.py b/cube/graph/parser/script/parser.py deleted file mode 100644 index 6f2793f1..00000000 --- a/cube/graph/parser/script/parser.py +++ /dev/null @@ -1,845 +0,0 @@ -import torch -import enum -import re -import logging -from typing import Any, List, Tuple, Optional - -from cube.ir.cten import IRObject -from cube.ir.operator import IRFwOperation -from cube.graph.function.pyfunc import IRPyFunc -from cube.ir.tensor import IRFullTensor -import cube.ir as ir -from cube.graph.parser.frame import Frame -from cube.graph.parser.script.mapping import Sign2Op -from cube.graph.parser.dtype import DType2IRDType - - -_logger = logging.getLogger(__name__) -_refmodule = torch.nn.Module() - -class ErasedDevice: - pass - - -class ScriptNodeKind(enum.Enum): - PrimGetAttr = 1 - PrimCallMethod = 2 - PrimCallFunction = 3 # -> the parser may end here - PrimConstant = 4 - AtenOp = 5 # -> the parser may end here - PrimIf = 6 # dynamic - PrimListConstruct = 7 - PrimListUnpack = 8 - PrimTupleUnpack = 9 - PrimPythonOp = 10 - PrimDevice = 11 # erased - PrimLoop = 12 - PrimSetAttr = 13 - - -class ScriptModuleParser: - - save_content: bool = True - - @staticmethod - def parse_module(module, - input_shapes: Optional[ Tuple[List[int],] ] = None, - frame: Frame = None) \ - -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: - """ - The overall entry to parse a torchscript graph module - """ - frame = frame if frame is not None else Frame() - frame.push_var() - frame.push_attr() - - inputs = list(module.graph.inputs())[1:] - if input_shapes is not None and len(input_shapes) != len(inputs): - raise RuntimeError(f"Module {module.original_name} input shape mismatch (got {len(input_shapes)} != {len(inputs)})") - - # handle graph input -- Assuming all the inputs are tensors - # kDefaultType = DType2IRDType.map(torch.get_default_dtype()) - for idx, input in enumerate(inputs): - if isinstance(input.type(), torch._C.TensorType): - shape = None if input_shapes is None else input_shapes[idx] - dtype = ir.IRDType.unknown # kDefaultType - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.debugName()) - else: - val = IRObject(name=input.debugName()) - frame.add_var(input.debugName(), val, graph_arg=idx) - input_val = [frame.get_var(input.debugName()) for input in inputs] - - # handle nodes - all_ir_nodes: List[IRFwOperation] = list() - for node in module.graph.nodes(): - ir_nodes = ScriptModuleParser.parse_node(node, module, frame) - for ir_node in ir_nodes: - try: - ret = ir_node.infer_shape() - if not ret: - _logger.error(f'{ir_node} cannot infer shape') - except Exception: - raise RuntimeError( - f"====== Shape Infer Error ====\n\n\n" - f"IR Node: {ir_node}\n\n" - f"Node:\n{node}\n" - f"====== Shape Infer Error ====\n\n\n" - ) - all_ir_nodes += ir_nodes - - # handle outputs - output_var_name = [output.debugName() for output in module.graph.outputs()] - output_val = [frame.get_var(var_name) for var_name in output_var_name] - - # flatten output_val - outputs = list() - for val in output_val: - if isinstance(val, list): - outputs += val - else: - outputs.append(val) - output_val = outputs - - frame.pop_var() - frame.pop_attr() - if ScriptModuleParser.save_content: - frame.save_attr_content() - return input_val, all_ir_nodes, output_val - - @staticmethod - def parse_module_method(module, method: torch._C.ScriptMethod, frame: Frame): - """ - Parse module method - """ - frame.push_var() - - input_var_name = [input.debugName() for input in method.graph.inputs()] - kDefaultType = DType2IRDType.map(torch.get_default_dtype()) - - for index, var_name in enumerate(input_var_name[1:]): # omit self - frame.add_var(var_name, IRFullTensor(name=var_name, requires_grad=False, dtype=kDefaultType), graph_arg=index) - - input_val = [frame.get_var(var_name) for var_name in input_var_name[1:]] - - all_ir_nodes: List[IRFwOperation] = list() - for node in method.graph.nodes(): - ir_nodes = ScriptModuleParser.parse_node(node, module, frame) - for ir_node in ir_nodes: - try: - ret = ir_node.infer_shape() - if not ret: - _logger.error(f'{ir_node} cannot infer shape') - except Exception: - raise RuntimeError( - f"====== Shape Infer Error ====\n\n\n" - f"IR Node: {ir_node}\n\n" - f"Module:\n{module.code}\n\n" - f"Node:\n{node}\n" - f"====== Shape Infer Error ====\n\n\n" - ) - all_ir_nodes += ir_nodes - - # handle graph output - output_var_name = [output.debugName() for output in method.graph.outputs()] - output_val = [frame.get_var(var_name) for var_name in output_var_name] - - frame.pop_var() - return input_val, all_ir_nodes, output_val - - @staticmethod - def ntype(node: torch._C.Node): - if node.kind() == 'prim::GetAttr': - return ScriptNodeKind.PrimGetAttr - if node.kind() == 'prim::SetAttr': - return ScriptNodeKind.PrimSetAttr - if node.kind() == 'prim::CallMethod': - return ScriptNodeKind.PrimCallMethod - if node.kind() == 'prim::CallFunction': # the op call - return ScriptNodeKind.PrimCallFunction - if node.kind() == 'prim::Constant': - return ScriptNodeKind.PrimConstant - if node.kind().startswith('aten::'): - return ScriptNodeKind.AtenOp - if node.kind() == 'prim::If': - return ScriptNodeKind.PrimIf - if node.kind() == 'prim::Loop': - return ScriptNodeKind.PrimLoop - if node.kind() == 'prim::ListConstruct': - return ScriptNodeKind.PrimListConstruct - if node.kind() == 'prim::TupleConstruct': - return ScriptNodeKind.PrimListConstruct - if node.kind() == 'prim::ListUnpack': - return ScriptNodeKind.PrimListUnpack - if node.kind() == 'prim::TupleUnpack': - return ScriptNodeKind.PrimListUnpack - if node.kind() == 'prim::PythonOp': - return ScriptNodeKind.PrimPythonOp - if node.kind() == 'prim::device': - return ScriptNodeKind.PrimDevice - raise RuntimeError(f"Unkown node kind {node.kind()} from torchscript module") - - @staticmethod - def parse_node(node: torch._C.Node, module, frame: Frame) -> List[IRFwOperation]: - # print("### parse_node {}".format(node)) - """ - Parse the node and return the IRFwOperation nodes - """ - node_type = ScriptModuleParser.ntype(node) - try: - if node_type == ScriptNodeKind.PrimCallFunction: - return ScriptModuleParser.parse_prim_function_node(node, module, frame) - if node_type == ScriptNodeKind.AtenOp: - return ScriptModuleParser.parse_aten_node(node, module, frame) - if node_type == ScriptNodeKind.PrimCallMethod: - return ScriptModuleParser.parse_prim_method_node(node, module, frame) - if node_type == ScriptNodeKind.PrimGetAttr: - return ScriptModuleParser.parse_prim_attr_node(node, module, frame) - if node_type == ScriptNodeKind.PrimSetAttr: - return ScriptModuleParser.parse_prim_setattr_node(node, module, frame) - if node_type == ScriptNodeKind.PrimConstant: - return ScriptModuleParser.parse_prim_constant_node(node, module, frame) - if node_type == ScriptNodeKind.PrimListConstruct: - return ScriptModuleParser.parse_prim_list_construct_node(node, module, frame) - if node_type == ScriptNodeKind.PrimListUnpack: - return ScriptModuleParser.parse_prim_list_unpack_node(node, module, frame) - if node_type == ScriptNodeKind.PrimPythonOp: - return ScriptModuleParser.parse_prim_python_op_node(node, module, frame) - if node_type == ScriptNodeKind.PrimIf: - return ScriptModuleParser.parse_prim_if_node(node, module, frame) - if node_type == ScriptNodeKind.PrimLoop: - return ScriptModuleParser.parse_prim_loop_node(node, module, frame) - - # TODO bother assigning all ignored prim functions new NodeKinds? - if node_type == ScriptNodeKind.PrimDevice: - return ScriptModuleParser.parse_value_erased_node(node, module, frame, [ErasedDevice()]) - - raise NotImplementedError(f"Un-supported node type {node_type}") - except Exception: - raise RuntimeError(f"\n\nParsing error at node:\n\t{node}\n") - - @staticmethod - def parse_prim_function_node(node, module, frame: Frame) -> List[IRFwOperation]: - """ - parse node like: - Tensor = prim::CallFunction(%5, %input.1, %3, %4) - %5 : Function = prim::Constant[name="linear"]() - %12 : (Tensor, Tensor) = prim::CallFunction(%5, %x1.1, %x2.1) - """ - inputs = [input for input in node.inputs()] - - # get signature - fnode = node.inputsAt(0).node() - if not ScriptModuleParser.ntype(fnode) == ScriptNodeKind.PrimConstant: - raise RuntimeError(f"Found unexpected function call node: {fnode}") - fsig = frame.get_var(inputs[0].debugName()) - - # get inputs - input_vals = list() - for index, input in enumerate(inputs[1:]): - var_name = input.debugName() - val = frame.get_var(var_name) - input_vals.append(val) - - # map to IR operator - ir_node = Sign2Op.map(fsig)(*input_vals) - - # push output in the frame - # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) - # : >>> dir(a) - # : >>> a.elements() # [TensorType, TensorType] - cnt = 0 - for output in node.outputs(): - if isinstance(output.type(), torch._C.TupleType): - tuplen = len(output.type().elements()) - ir_output = [ir_node.output(idx) for idx in range(cnt, cnt+tuplen)] - cnt += tuplen - else: - ir_output = ir_node.output(cnt) - cnt += 1 - frame.add_var(output.debugName(), ir_output) - - if cnt != len(ir_node.outputs()): - raise RuntimeError( - f"Parse fail: {fsig} has {cnt} outputs != pre-defined {len(ir_node.outputs())}" - ) - - return [ir_node] - - @staticmethod - def parse_aten_node(node, module, frame: Frame) -> List[IRFwOperation]: - """ - Parse script module node like: - %13 : Tensor = aten::gt(%output1.1, %output2.1) - """ - fsig = node.kind() - fsig = re.sub('aten::', 'torch.', fsig) - inputs = [input for input in node.inputs()] - outputs = [output for output in node.outputs()] - - # handle inputs: - input_val = [frame.get_var(input.debugName()) for input in inputs] - - # special handling on aten::size(tensor: tensor, dim: int) - if fsig == 'torch.size': - if len(inputs) == 2: - tensor, dim = input_val - output: int = tensor.shape[dim] - else: - tensor = input_val[0] - output: List[int] = list(tensor.shape) - frame.add_var(outputs[0].debugName(), output) - return [] - - # aten::__getitem__.t(t[](a) list, int idx) -> t(*)" - # REMARK List-type only. '__getitem__' cannot serve as accessor to tensor element. - elif fsig == 'torch.__getitem__': - # NOTE there are other overloadings of '__getitem__' for 'str'(i.e. char list), 'Dict(t)' in TorchScript - container, index = input_val - frame.add_var(outputs[0].debugName(), container[index]) - return [] - - elif fsig == 'torch.__range_length': - lo, hi, step = input_val - rng_len = ScriptModuleParser.aten___range_length(lo, hi, step) - frame.add_var(outputs[0].debugName(), rng_len) - return [] - - elif fsig == 'torch.__derive_index': - index, start, step = input_val - derived = ScriptModuleParser.aten___derive_index(index, start, step) - frame.add_var(outputs[0].debugName(), derived) - return [] - - # May be a symbolic object i.e. IRFwOperation, - # or, occasionally this node can be statically evaluated, therefore a concrete value - result = Sign2Op.map(fsig)(*input_val) - - if isinstance(result, IRFwOperation): - # to create IR node - - ir_node = result - if len(ir_node.outputs()) != len(outputs): - assert len(outputs) == 1, ( - f"Farse Fail: torchscript has different output number of IR node: {len(outputs)} != {len(ir_node.outputs())}\n" - f"This can only be happend to have pre-defined output number of 1" - ) - node_outputs = (ir_node.outputs(),) - # raise RuntimeError( - # f"Parse fail: {fsig} has {len(outputs)} outputs != pre-defined {len(ir_node.outputs())}" - # ) - else: - node_outputs = ir_node.outputs() - - # handle outputs - for output, node_output in zip(outputs, node_outputs): - frame.add_var(output.debugName(), node_output) - - return [ir_node] - - else: - # concrete value. - assert len(outputs) == 1, "Cases with multiple outputs are only List/Tuple-Unpack and handled specially" - frame.add_var(outputs[0].debugName(), result) - return [] - - @staticmethod - def parse_prim_method_node(node, module, frame: Frame) -> List[IRFwOperation]: - """ - Parse script module node like: - %output.1 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) - - prim::CallMethod has a underlying submodule - """ - inputs = [input for input in node.inputs()] - outputs = [output for output in node.outputs()] - - # forward - label = node.s('name') - # handle inputs -- in stack with reverse order - for input in inputs[1:][::-1]: - var_name = input.debugName() - val = frame.get_var(var_name) - frame.push_param(var_name) - - # recursively parse the module - self_module = node.inputsAt(0).debugName() == 'self' - if self_module: - call_module = module - else: - call_module = frame.get_var(node.inputsAt(0).debugName()) - assert isinstance(call_module, torch.nn.Module), "the call module is not torch.nn.Module" - # call_module = getattr(module, node.inputsAt(0).debugName()) - frame.push_attr() - - call_method = getattr(call_module, label) - _, ir_nodes, outputs_val = ScriptModuleParser.parse_module_method(call_module, call_method, frame=frame) - - if not self_module: - frame.pop_attr() - - # pop out the frame - frame.pop_param(times=len(inputs)-1) - - # handle outputs - outputs = [output for output in node.outputs()] - for output, val in zip(outputs, outputs_val): - frame.add_var(output.debugName(), val) - - return ir_nodes - - @staticmethod - def parse_prim_attr_node(node, module, frame: Frame) -> List[None]: - """ - Parse script module node like: - %2 :__torch__.torch.nn.modules.linear.___torch_mangle_0.Linear = prim::GetAttr[name="linear1"](%self) - %3 : Tensor = prim::GetAttr[name="weight"](%self) - Or: - %embed.1 : __torch__.torch.nn.modules.sparse.Embedding = prim::GetAttr[name="embed"](%self) - %embed.3 : Tensor = prim::CallMethod[name="forward"](%embed.1, %input_ids.1) - - This will add frame with the variable name and it's value - - The value can be: - 1). (IRFullTensor) the tensor edge in graph - 2). (str code) symbolic value based on runtime info (e.g., self.training) - 3). (str) Function or torch.nn.moudles - - Returns: - Empty list - """ - global _refmodule - - module_name = node.inputsAt(0).debugName() - module = module if module_name == 'self' else frame.get_var(module_name) - assert isinstance(module, torch.nn.Module) - - label = node.s('name') - var_name = node.outputsAt(0).debugName() - dtype = node.outputsAt(0).type().str() - - if dtype == 'Tensor?': - tensor = getattr(module, label) - if torch.is_tensor(tensor): - dtype = 'Tensor' - - # this usually means weight (nn.Parameter in torch) - if dtype == 'Tensor': - tensor = getattr(module, label) - shape = list(tensor.shape) - if frame.has_attr(label): - ir_tensor = frame.get_attr(label) - else: - ir_tensor = IRFullTensor( - name=label, shape=shape, - requires_grad=tensor.requires_grad, - dtype=DType2IRDType.map(tensor.dtype) - ) - if isinstance(tensor, torch.nn.Parameter): - ir_tensor.as_param() - else: - ir_tensor.as_buffer() - frame.add_attr(label, ir_tensor) - frame.add_attr_content(ir_tensor.tid, tensor) - frame.add_var(var_name, ir_tensor) - # symbolic attributes - elif dtype in ['bool', 'int', 'float']: - if hasattr(_refmodule, label): - val = 'self.' + label - else: - val = getattr(module, label) - frame.add_var(var_name, val) - # NoneType - elif dtype == 'NoneType': - frame.add_var(var_name, None) - else: - if isinstance(module, torch.nn.ModuleList): - if str.isdecimal(label): - val = module[int(label)] - else: - val = getattr(module, label) - else: - val = getattr(module, label) - frame.add_var(var_name, val) - return list() - - @staticmethod - def parse_prim_setattr_node(node, module, frame) -> List[IRFwOperation]: - """ - = prim::SetAttr[name="past_k"](%self, %k.1) - """ - signature = 'setattr' - target = node.s('name') # past_k - module_name = node.inputsAt(0).debugName() - module = module if module_name == 'self' else frame.get_var(module_name) - - var = node.inputsAt(1).debugName() # %k.1 - dtype = node.inputsAt(1).type().str() # torch.Tensor - assert dtype == 'Tensor', "Only tensor can be set inside module" - var_tensor = frame.get_var(var) - # make sure of having same attribute name in graph - assert frame.has_attr(target), f"SetAttr currently only supports replace an existing tensor attribute" - target_tensor = frame.get_attr(target) # IRFullTensor - # target_name = f"{target_tensor.name}_{target_tensor.tid}" - func = IRPyFunc(signature, ('self', target_tensor, var_tensor), ()) - # setattr(module, target, var) -> This will have error - return [func] - - - @staticmethod - def parse_prim_constant_node(node, module, frame) -> List[None]: - """ - Parse script module node like: - %6 : Function = prim::Constant[name="dropout"]() - %5 : bool = prim::Constant[value=0]() - - This will add frame with the variable name and it's value - - Returns: - Empty list - """ - if len(list(node.inputs())) != 0: - raise RuntimeError(f"prim::Constant node: {node} has inputs") - var_name = node.outputsAt(0).debugName() - dtype = node.outputsAt(0).type().str() - - if dtype == 'Function': - signature = repr(node.outputsAt(0).type()) - if '__torch__.' in signature: - signature = re.sub('__torch__.', '', signature) - frame.add_var(var_name, signature) - else: - val = node.outputsAt(0).toIValue() - frame.add_var(var_name, val) - return list() - - @staticmethod - def parse_prim_if_node(node, module, frame: Frame) -> List[IRFwOperation]: - """ - Parse script module node like - %output1 : Tensor, %output2 : Tensor = prim::If(%15) # /tmp/ipykernel_27188/2459450745.py:13:8 - block0(): - -> (%1, %2) - block1(): - -> (%3, %4) - - and the only input (e.g. %15) must be of type bool. - """ - - inputs : List[torch._C.Value] = list(node.inputs()) - outputs : List[torch._C.Value] = list(node.outputs()) - - assert len(inputs) == 1 - in_val = frame.get_var(inputs[0].debugName()) - if not isinstance(in_val, bool): - raise RuntimeError("Dynamic Graph is not supported yet") - - # type: torch._C.Block - true_block, false_block = node.blocks() - chosen_block : torch._C.Block = true_block if in_val else false_block - body_out_vars = list(chosen_block.outputs()) - - all_ir_nodes : List[IRFwOperation] = [] - - # Evaluate the 'eval_block' in a new frame, to isolate within-block variables from - # polluting the current frame. And we'll manually bind all resultant variables later on. - frame.push_var(inherit_from_top=True) - - # prim::If's blocks do not have any subgraph parameters, directly evaluate the body - for subnode in chosen_block.nodes(): - subnode : torch._C.Node - - sub_ir_nodes : List[IRFwOperation] = ScriptModuleParser.parse_node(subnode, module, frame) - - for ir_node in sub_ir_nodes: - try: - ret = ir_node.infer_shape() - if not ret: - _logger.error(f'{ir_node} cannot infer shape') - except Exception: - raise RuntimeError( - f"====== Shape Infer Error ====\n\n\n" - f"IR Node: {ir_node}\n\n" - f"Module:\n{module.code}\n\n" - f"Node:\n{node}\n" - f"====== Shape Infer Error ====\n\n\n" - ) - - all_ir_nodes += sub_ir_nodes - - # retrieve the block's resultant values - result_vals = [frame.get_var(body_out_var.debugName()) for body_out_var in body_out_vars] - - # clean up - frame.pop_var() - - # bind the prim:If's resultant variables - assert len(result_vals) == len(outputs) - for output, out_val in zip(outputs, result_vals): - frame.add_var(output.debugName(), out_val) - - return all_ir_nodes - - @staticmethod - def parse_prim_loop_node(node, module, frame: Frame) -> List[IRFwOperation]: - """ - Inputs: - %max_iter_count : int - %init_condition : bool - %x_1 : T_1 - ... - %x_N : T_N - %dependencies : R - - Syntax: - %y_1 : T_1, ..., %y_N : T_N = prim::Loop(%max_iter_count, %init_condition, %x_1, ..., %x_N) - block0(%iter_step : int, %p_1 : T_1, ..., %p_N : T_N): - ... - %r_1 : T_1 = some_func(%x_1, %dependencies) - ... - %r_N : T_N = ... - %next_condition : bool = ... - -> (%next_condition, %r_1, ..., %r_N) - - REMARK: - - Outer variables (%dependencies) may be referenced in the Loop-body/subgraph, this is AKA _free variables_. - In contrast, a standalone TorchScript function/graph will have all variables, - including its parameters, defined within its scope. - - In other words, functions/graphs have no free variables. - - Semantics: - - The next step is evaluated if both (%iter_step < %max_iter_count) and (%next_condition == True). - - (%y_1, ..., %y_N) are bound to the last (%r_1, ..., %r_N) returned. - If no step is ever evaluated, they are (%x_1, ..., %x_N). - """ - inputs : List[torch._C.Value] = list(node.inputs()) - outputs : List[torch._C.Value] = list(node.outputs()) - - in_vals = [frame.get_var(input.debugName()) for input in inputs] - - max_iter_count, init_condition = in_vals[0:2] - if not isinstance(max_iter_count, int): - raise RuntimeError("The upper bound of the loop must be able to be statically evaluated") - if not isinstance(init_condition, bool): - raise RuntimeError("The init condition of the loop must be able to be statically evaluated") - - # type: Subgraph - loop_block : torch._C.Block = list(node.blocks())[0] - - body_in_vars : torch._C.Value = list(loop_block.inputs()) - iter_step_var = body_in_vars[0] - p_vars = body_in_vars[1:] - - body_out_vars = list(loop_block.outputs()) - - step = 0 - condition = init_condition - loop_carried_vals = in_vals[2:] - - all_ir_nodes : List[IRFwOperation] = [] - - while step < max_iter_count and condition: - - # create the context for evaluating the body, and bind loop variables %iter_step, %p_1, ... - - # Defensively we don't let variables defined in the Loop body subgraph pollute the outer graph. - # So we'd better duplicate all existing variables into a new frame (namely 'inherit_from_top'), - # and clean up this new frame after the interpretation of the whole loop execution. - frame.push_var(inherit_from_top=True) - - frame.add_var(iter_step_var.debugName(), step) - - # At the evaluation of each step, we cannot call Frame's 'push_param(var_name)' and 'add_var(var_name, val, graph_arg=N)' APIs, - # because all intermediate loop-carried values do not have syntactically static names. - # - # For the sake of isolation, we don't bind carried values onto {y_i}s variables and overwrite the binding - # during evaluation, either. - assert len(p_vars) == len(loop_carried_vals) - for p_var, carried_val in zip(p_vars, loop_carried_vals): - frame.add_var(p_var.debugName(), carried_val) - - # evaluate the body block - for subnode in loop_block.nodes(): - subnode : torch._C.Node - - sub_ir_nodes : List[IRFwOperation] = ScriptModuleParser.parse_node(subnode, module, frame) - - for ir_node in sub_ir_nodes: - try: - ret = ir_node.infer_shape() - if not ret: - _logger.error(f'{ir_node} cannot infer shape') - except Exception: - raise RuntimeError( - f"====== Shape Infer Error ====\n\n\n" - f"IR Node: {ir_node}\n\n" - f"Module:\n{module.code}\n\n" - f"Node:\n{node}\n" - f"====== Shape Infer Error ====\n\n\n" - ) - - all_ir_nodes += sub_ir_nodes - - # rebind for next step and clean-ups - step_result_vals = [frame.get_var(body_out_var.debugName()) for body_out_var in body_out_vars] - condition = step_result_vals[0] - loop_carried_vals = step_result_vals[1:] - step += 1 - - frame.pop_var() - - if not isinstance(condition, bool): - raise RuntimeError(f"At the {step}-th step the condition is not evaluated to a constant bool") - - assert len(outputs) == len(loop_carried_vals) - for output, y_val in zip(outputs, loop_carried_vals): - frame.add_var(output.debugName(), y_val) - - return all_ir_nodes - - - @staticmethod - def parse_prim_list_construct_node(node, module, frame: Frame) -> List[None]: - """ - Parse script module node like - %8 : int[] = prim::ListConstruct(%3) - """ - inputs = [input for input in node.inputs()] - outputs = [output for output in node.outputs()] - assert len(outputs) == 1 - output = outputs[0] - out_val = list() - for input in inputs: - out_val.append(frame.get_var(input.debugName())) - frame.add_var(output.debugName(), out_val) - return list() - - @staticmethod - def parse_prim_list_unpack_node(node, module, frame: Frame) -> List[None]: - """ - Parse script module node like: - %q.1 : Tensor, %k.1 : Tensor, %v.1 : Tensor = prim::TupleUnpack(%11) - """ - inputs = [input for input in node.inputs()] - outputs = [output for output in node.outputs()] - if len(inputs) != 1: - raise RuntimeError("Find UnpackTuple has more than one input") - if len(outputs) == 1: - raise RuntimeError("Find UnpackTuple has only one output") - tuple_inputs = frame.get_var(inputs[0].debugName()) - if len(tuple_inputs) != len(outputs): - raise RuntimeError("Expected unpacked tuple number have same length of tupled input") - for output, val in zip(outputs, tuple_inputs): - frame.add_var(output.debugName(), val) - return list() - - @staticmethod - def parse_prim_python_op_node(node, module, frame): - """ - parse node like: - %64 : Tensor = ^OuterProductMean()(%opm_left.1, %opm_right.1, %outer_out_proj) - """ - # get inputs - input_vals = list() - for input in node.inputs(): - var_name = input.debugName() - val = frame.get_var(var_name) - input_vals.append(val) - - fsig: str = str(node.pyname()) - - # map to IR operator - ir_node = Sign2Op.map(fsig)(*input_vals) - - # push output in the frame - # help: >>> a = torch._C.TupleType([torch._C.TensorType.getInferred()]) - # : >>> dir(a) - # : >>> a.elements() # [TensorType, TensorType] - cnt = 0 - for output in node.outputs(): - if isinstance(output.type(), torch._C.TupleType): - tuplen = len(output.type().elements()) - ir_output = [ir_node.output(idx) for idx in range(cnt, cnt+tuplen)] - cnt += tuplen - else: - ir_output = ir_node.output(cnt) - cnt += 1 - frame.add_var(output.debugName(), ir_output) - - if cnt != len(ir_node.outputs()): - raise RuntimeError( - f"Parse fail: {fsig} has {cnt} outputs != pre-defined {len(ir_node.outputs())}" - ) - - # print(input_vals) - # print(node.pyname()) - # print(dir(node)) - # print(tuple(node.inputs())) - # print(tuple(node.outputs())) - # raise NotImplementedError("Cannot support torch.jit.ignore") - return [ir_node] - - @staticmethod - def parse_value_erased_node(node, module, frame, erased_vals: List[Any]): - outputs = list(node.outputs()) - - assert len(outputs) == len(erased_vals) - for output, ev in zip(outputs, erased_vals): - frame.add_var(output.debugName(), ev) - return [] - - - - @staticmethod - def flatten(smodule, depth=0): - """ - Flatten the recursive script module to function and aten primitives - """ - # stashed_module = list() - inputs = [input for input in smodule.graph.inputs()] - print(' '*depth, f'graph inputs: {inputs}') - if len(list(smodule.children())) == 0: - for node in smodule.graph.nodes(): - print(' '*depth, node) - else: - for node in smodule.graph.nodes(): - print(' '*depth, node) - if node.kind() == 'prim::CallMethod': - label = node.inputsAt(0).node().s('name') - submodule = getattr(smodule, label) - ScriptModuleParser.flatten(submodule, depth+1) - - @staticmethod - def aten___range_length(lo, hi, step): - """ - aten::__range_length(int lo, int hi, int step) -> int - - Python loops - ``` - for i in range(L, H, S): - use(i) - ``` - will be translated to TorchScript - ``` - _c = aten::__range_length(L, H, S) - for _k < _c: - i = aten::__derive_index(k, L, S) - use(i) - ``` - """ - if not (isinstance(lo, int) and isinstance(hi, int) and isinstance(step, int)): - raise RuntimeError("All inputs to __range_length must be statically evaluated") - if step == 0: - raise RuntimeError("Step cannot be zero") - - return len(range(lo, hi, step)) - - @staticmethod - def aten___derive_index(index, start, step): - if not (isinstance(index, int) and isinstance(start, int) and isinstance(step, int)): - raise RuntimeError("All inputs to __derive_index must be statically evaluated") - - return start + index * step - - - diff --git a/cube/program.py b/cube/program.py index 9e18d6d1..63af7d8b 100644 --- a/cube/program.py +++ b/cube/program.py @@ -53,7 +53,7 @@ def set_output(self, outputs: Tuple[Any]): self.instance._graph.reset_outputs(len(outputs)) for idx, otensor in enumerate(outputs): self.instance._graph.set_output(idx, otensor) - + def finalize(self): """ Close the recording of program. @@ -93,7 +93,7 @@ def get_batch_size(self) -> int: def set_batch_size(self, bs: int): self.dataloader.set_batch_size(bs) return - + def get_runtime_sample(self): return next(self.dataloader) @@ -151,7 +151,7 @@ def __init__(self, model: Optional[torch.nn.Module], Args: model (Optional[torch.nn.Module]): single-device model description, only required for rank 0 - save_content (bool): + save_content (bool): whether to save the content of model and load it into generated model. Default True. dynamic_shape (bool): whether to use dynamic shape. Default False. @@ -170,10 +170,10 @@ def __init__(self, model: Optional[torch.nn.Module], def dummy_input(self) -> Any: """Get dummy real-tensor input from on CPU""" return self._dummy_input - + @dummy_input.setter def dummy_input(self, val): - + def complex(val: Any): """Complex to CPU""" if isinstance(val, tuple): @@ -191,7 +191,7 @@ def complex(val: Any): self._dummy_input = complex(val) def get_graph(self): - return self.ir_graph + return self._ir_graph def load_module(self, filename: str): """Load module from file.""" @@ -214,12 +214,10 @@ def __call__(self, *args): assert self._ir_graph is None, \ f"multiple forward on a semantic model is not allowed" if DeviceGroup().local_rank == 0: - input_shapes = [tuple(t.shape) if isinstance(t, IRTensor) else None for t in args] self._ir_graph = parser.convert_model( self.model, - input_shapes=input_shapes, dummy_input=self.dummy_input, - save_content=self.save_content, + attr_save_dir='./', dynamic_shape=self.dynamic_shape ) return self._ir_graph(*args) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index b946e16a..2520476c 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -25,7 +25,7 @@ def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, torch.distributed.send(tensor, dst) else: assert rank == dst - tensor = torch.empty(shape, dtype=dtype, + tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device() ) if async_op: @@ -55,7 +55,7 @@ def all_reduce(tensor: torch.Tensor, else: torch.distributed.all_reduce(tensor, group=group) if not async_op: - CudaTimer().stop(field_name='comm', predefined=True) + CudaTimer().stop(field_name='comm', predefined=True) return tensor @@ -82,7 +82,7 @@ def all_gather(tensor: torch.Tensor, dim: int, return otensor -def reduce_scatter(tensor: torch.Tensor, dim: int, +def reduce_scatter(tensor: torch.Tensor, dim: int, ranks: Tuple[int], async_op=False) -> torch.Tensor: """ ReduceScatter @@ -267,7 +267,7 @@ def broadcast(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, src: tensor = itensor.contiguous() if not itensor.is_contiguous() else itensor else: assert rank in ranks - tensor = torch.empty(shape, + tensor = torch.empty(shape, device=torch.cuda.current_device(), requires_grad=False, dtype=dtype) work = torch.distributed.broadcast(tensor, src, group=group, async_op=async_op) if work and rank != src: @@ -281,10 +281,10 @@ def exchange(tensor: torch.Tensor, ranks: List[int], async_op=False) -> torch.Te """ Exchange a same-shaped tensor between two ranks """ - + if not async_op: CudaTimer().start(field_name='comm', predefined=True) - + assert len(ranks) == 2 group = DeviceGroup().get_group(ranks) myrank = torch.distributed.get_rank(group) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..275ac755 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,8 @@ +pytest +mock +pytest-mock +tox +coverage +pytest-cov +tabulate +tox-conda diff --git a/requirements.txt b/requirements.txt index b1866504..d5b9475b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ matplotlib -pytest setuptools==60.7.0 more-itertools dill +torch diff --git a/tests/test_parser.py b/tests/test_parser.py deleted file mode 100644 index 8923261d..00000000 --- a/tests/test_parser.py +++ /dev/null @@ -1,61 +0,0 @@ -# run tests: -# pytest ./tests/test_parser.py - -import pytest -import torch -from cube.graph.function.creators import IROnes, IRZeros - -from cube.graph.parser.frame import Frame -from cube.graph.parser.parser import ScriptModuleParser -from cube.graph.torch_dtype_mapping import DType2IRDType -from cube.ir.dtype import IRDType -from cube.ir.tensor import IRFullTensor -from cube import ir - -@pytest.mark.parametrize( - "aten_op, ir_op_cls", - [("zeros", IRZeros), ("ones", IROnes)] -) -def test_optional_dtype_none(aten_op, ir_op_cls): - g = torch._C.parse_ir(f''' - graph(): - %d : int = prim::Constant[value=2]() - %shape : int[] = prim::ListConstruct(%d, %d, %d) - %none : NoneType = prim::Constant() - %z : Tensor = aten::{aten_op}(%shape, %none, %none, %none, %none) - return (%z) - ''') - frame = Frame() - frame.push_var() - frame.push_attr() - - for node in g.nodes(): - ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) - for node in ir_nodes: - if isinstance(node, ir_op_cls): - assert node.output(0).dtype == DType2IRDType.map(torch.get_default_dtype()) - -@pytest.mark.parametrize( - "aten_op, ir_op_cls", - [("zeros", IRZeros), ("ones", IROnes)] -) -def test_optional_dtype_underlying_int(aten_op, ir_op_cls): - # ScalarType(3) == torch.int32 - g = torch._C.parse_ir(f''' - graph(): - %d : int = prim::Constant[value=2]() - %shape : int[] = prim::ListConstruct(%d, %d, %d) - %none : NoneType = prim::Constant() - %scalarType : int = prim::Constant[value=3]() - %z : Tensor = aten::{aten_op}(%shape, %scalarType, %none, %none, %none) - return (%z) - ''') - frame = Frame() - frame.push_var() - frame.push_attr() - - for node in g.nodes(): - ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) - for node in ir_nodes: - if isinstance(node, ir_op_cls): - assert node.output(0).dtype == IRDType.int32 diff --git a/tests/test_prim_loop.py b/tests/test_prim_loop.py deleted file mode 100644 index 575d7342..00000000 --- a/tests/test_prim_loop.py +++ /dev/null @@ -1,156 +0,0 @@ -# run tests: -# pytest ./tests/test_prim_loop.py - -import pytest -import torch -import cube - -from cube.graph.parser.frame import Frame -from cube.graph.parser.parser import ScriptModuleParser -from cube.ir.tensor import IRFullTensor -from cube import ir - -# Stub objects: -# - A stub object for 'ScriptModule' should have members: -# -- entry_method_normally_forward: Stub[ScriptMethod] -# -- code: str # only to avoid AttributeError, could be empty -# and optionally: -# -- other_script_method: Stub[ScriptMethod] -# -# - A stub object for 'ScriptMethod' should have fields: -# -- graph: torch._C.Graph - -class StubScriptMethod(object): - def __init__(self, graph: torch._C.Graph) -> None: - self.graph = graph - -# REMARK: -# 'torch._C.parse_ir' will change local variable names into unique-number ID, e.g. -# graph(%p: int): -# %local = ... -# becomes: -# graph(%p: int): -# %1 = ... - -def out_var_name0(g): - return next(g.outputs()).debugName() - -def test_simple_unroll_evaluation(): - g = torch._C.parse_ir(''' - graph(%a : int): - %ub : int = prim::Constant[value=100]() - %truth : bool = prim::Constant[value=1]() - %z : int = prim::Loop(%ub, %truth, %a) - block0(%step : int, %p : int): - %r : int = aten::add(%step, %p) - -> (%truth, %r) - return (%z) - ''') - frame = Frame() - frame.push_var() - frame.add_var("a", 0) - - for node in g.nodes(): - ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) - assert len(ir_nodes) == 0 - - # %z becomes %3 - assert frame.get_var(out_var_name0(g)) == (0+99)*100//2 - -def test_unroll_with_structural_info(): - g = torch._C.parse_ir(''' - graph(%a : Tensor): - %ub : int = prim::Constant[value=3]() - %truth : bool = prim::Constant[value=1]() - %i0 : int = prim::Constant[value=0]() - %z : Tensor = prim::Loop(%ub, %truth, %a) - block0(%step : int, %p : Tensor): - %ts : Tensor[] = prim::ListConstruct(%p, %p) # at each step, double the 0-th dim - %r : Tensor = aten::cat(%ts, %i0) - -> (%truth, %r) - return (%z) - ''') - frame = Frame() - frame.push_var() - - t_a = IRFullTensor(shape=[2,3]) - frame.add_var("a", t_a) - - all_ir_nodes = [] - for node in g.nodes(): - ir_nodes = ScriptModuleParser.parse_node(node, {}, frame) - all_ir_nodes += ir_nodes - - assert len(all_ir_nodes) == 3 - - p = t_a - for i in range(3): - ir_node = all_ir_nodes[i] - in_val_1, in_val_2 = ir_node.inputs() - assert in_val_1 == p - assert in_val_2 == p - - out_val = ir_node.output(0) - assert out_val.shape == [ 2**(i+2) , 3] - - p = out_val - - -def test_nested_unroll(): - ''' - The outer loop has 3 steps, and the inner loop has 3 steps too. - ''' - - subp = torch._C.parse_ir(''' - graph(%self: int, %a : Tensor): - %ub : int = prim::Constant[value=3]() - %truth : bool = prim::Constant[value=1]() - %i0 : int = prim::Constant[value=0]() - %z : Tensor = prim::Loop(%ub, %truth, %a) - block0(%step : int, %p : Tensor): - %ts : Tensor[] = prim::ListConstruct(%p, %p) # at each step, double the 0-th dim - %r : Tensor = aten::cat(%ts, %i0) - -> (%truth, %r) - return (%z) - ''') - main = torch._C.parse_ir(''' - graph(%self: int, %a : Tensor): - %ub : int = prim::Constant[value=3]() - %truth : bool = prim::Constant[value=1]() - %z : Tensor = prim::Loop(%ub, %truth, %a) - block0(%step : int, %p : Tensor): - %r : Tensor = prim::CallMethod[name="subp"](%self, %p) - -> (%truth, %r) - return (%z) - ''') - - class StubScriptModule(object): - def __init__(self) -> None: - self.main = StubScriptMethod(main) - self.subp = StubScriptMethod(subp) - module = StubScriptModule() - - frame = Frame() - frame.push_var() - - t_a = IRFullTensor(shape=[2,3]) - frame.add_var("a", t_a) - - all_ir_nodes = [] - for node in main.nodes(): - ir_nodes = ScriptModuleParser.parse_node(node, module, frame) - all_ir_nodes += ir_nodes - - assert len(all_ir_nodes) == 9 - - p = t_a - for i in range(9): - ir_node = all_ir_nodes[i] - in_val_1, in_val_2 = ir_node.inputs() - assert in_val_1 == p - assert in_val_2 == p - - out_val = ir_node.output(0) - assert out_val.shape == [ 2**(i+2) , 3] - - p = out_val \ No newline at end of file diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..81fb0b46 --- /dev/null +++ b/tox.ini @@ -0,0 +1,18 @@ +# Tox (http://tox.testrun.org/) is a tool for running tests +# in multiple virtualenvs. This configuration file will run the +# test suite on all supported python versions. To use it, "pip install tox" +# and then run "tox" from this directory. + +[tox] +envlist = py310 +skipsdist = True + +[testenv] +passenv = * +install_command = pip install {opts} {packages} +deps = + -rrequirements.txt + -rrequirements-dev.txt +commands = coverage erase + py.test --cov={toxinidir}/cube -sx unit_tests + coverage html diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unit_tests/graph/__init__.py b/unit_tests/graph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unit_tests/graph/parser/__init__.py b/unit_tests/graph/parser/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unit_tests/graph/parser/test_converter.py b/unit_tests/graph/parser/test_converter.py new file mode 100644 index 00000000..c817daa7 --- /dev/null +++ b/unit_tests/graph/parser/test_converter.py @@ -0,0 +1,75 @@ +import tempfile +from pathlib import Path + +import torch +import pytest + +from cube.graph.parser.converter import to_fx_graph, to_ir_graph +from cube.graph.parser import FxModuleParser +from cube.ir.cten import IRObject, IRTensor + + +def test_to_graph(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, **kwargs): + return self.linear(x) + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + assert fx_graph is not None + nodes = list(fx_graph.graph.nodes) + + # starts with placeholder, and ends with output + assert nodes[0].op == 'placeholder' + assert nodes[0].target == 'x' + assert nodes[1].op == 'placeholder' + assert nodes[1].target == '**kwargs' # should keep the double stars + assert nodes[-1].op == 'output' + assert nodes[-1].target == 'output' + + # should have linear.weight, linear.bias, and linear(x) + assert any(node.op == 'get_attr' and node.target == 'linear.weight' for node in nodes) + assert any(node.op == 'get_attr' and node.target == 'linear.bias' for node in nodes) + assert any(node.op == 'call_function' and node.target == torch.nn.functional.linear for node in nodes) + + with tempfile.TemporaryDirectory() as tempdir: + to_ir_graph(fx_graph, dummy_input, attr_save_dir=tempdir, dynamic_shape=True) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_save_dir=tempdir, dynamic_shape=True) + assert ir_graph is not None + assert (Path(tempdir) / FxModuleParser.ATTR_MAP_FILE).exists() + assert (Path(tempdir) / FxModuleParser.ATTR_CONTENT_FILE).exists() + assert ir_graph.name == 'MyModule' + inputs = ir_graph.inputs() + assert len(inputs) == 2 + assert inputs[0].name == nodes[0].name + assert isinstance(inputs[0], IRTensor) + assert inputs[1].name == nodes[1].name + assert isinstance(inputs[1], IRObject) + + outputs = ir_graph.outputs() + assert len(outputs) == 1 + + nodes = list(ir_graph.nodes()) + assert any(node.signature == 'torch.nn.functional.linear' for node in nodes) + + +def test_to_ir_graph_args(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, *args): + return self.linear(x) + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + # currently we don't support *args + with pytest.raises(RuntimeError): + to_ir_graph(fx_graph, dummy_input, attr_save_dir=tempdir, dynamic_shape=True) diff --git a/unit_tests/launch_torchrun.py b/unit_tests/launch_torchrun.py new file mode 100644 index 00000000..e3b7a2f8 --- /dev/null +++ b/unit_tests/launch_torchrun.py @@ -0,0 +1,29 @@ +import uuid + +from torch.distributed.run import elastic_launch, LaunchConfig + + +def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): + launch_config = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=nproc_per_node, + rdzv_backend = "c10d", + rdzv_endpoint = "localhost:29400", + run_id = str(uuid.uuid4()), + monitor_interval=0.1, + max_restarts=0, + ) + outputs = elastic_launch(launch_config, worker_fn)(*args, **kwargs) + return outputs + + +def clone_to_cpu(tensor): + # when you use launch_torchrun + # you can't directly return a cuda tensor + # Error message: Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors] + # So you can use this function to clone a tensor to cpu + cloned_tensor = tensor.cpu().clone().detach().requires_grad_(tensor.requires_grad) + if tensor.grad is not None: + cloned_tensor.grad = tensor.grad.cpu().clone() + return cloned_tensor diff --git a/unit_tests/runtime/__init__.py b/unit_tests/runtime/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unit_tests/runtime/test_runtime_collectives.py b/unit_tests/runtime/test_runtime_collectives.py new file mode 100644 index 00000000..59580721 --- /dev/null +++ b/unit_tests/runtime/test_runtime_collectives.py @@ -0,0 +1,233 @@ +from typing import List + +import cube +import torch + +from ..launch_torchrun import launch_torchrun, clone_to_cpu + + +def _init_distributed(expected_devices=2, backend=None): + if torch.cuda.is_available() and torch.cuda.device_count() >= expected_devices: + torch.distributed.init_process_group(backend='nccl') + rank = torch.distributed.get_rank() + torch.cuda.set_device(rank) + torch.set_default_device(f'cuda:{rank}') + else: + torch.distributed.init_process_group(backend='gloo') + torch.set_default_device('cpu') + + +def _get_tensor(shape: List[int], device=None, dtype: torch.dtype = torch.float32) -> torch.Tensor: + tensor = torch.randn(shape, dtype=dtype, device=device or torch.cuda.current_device()) + return tensor + + +def _move_worker(async_op: bool): + shape = [128, 256] + + tensor = _get_tensor(shape) + tensor = cube.runtime.adapter.move(tensor, shape, torch.float32, 0, 1, async_op=async_op) + + if async_op: + tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + return clone_to_cpu(tensor) + + +def _allreduce_worker(async_op: bool): + shape = [128, 256] + + tensor = _get_tensor(shape) + sum_tensor = tensor.clone().detach() + sum_tensor = cube.runtime.adapter.all_reduce(sum_tensor, [0, 1], async_op=async_op) + + if async_op: + sum_tensor = cube.runtime.executor.AsyncCommHandler().wait(sum_tensor) + return (clone_to_cpu(tensor), clone_to_cpu(sum_tensor)) + + +def _allgather_worker(async_op: bool): + shape = [128, 256] + + tensor = _get_tensor(shape) + otensor = cube.runtime.adapter.all_gather(tensor, 0, [0, 1], async_op=async_op) + + if async_op: + otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + return (clone_to_cpu(tensor), clone_to_cpu(otensor)) + + +def _reduce_scatter_worker(async_op: bool): + shape = [128, 256] + + tensor = _get_tensor(shape) + otensor = cube.runtime.adapter.reduce_scatter(tensor, 0, [0, 1], async_op=async_op) + + if async_op: + otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + + return (clone_to_cpu(tensor), clone_to_cpu(otensor)) + + +def _all2all_worker(async_op): + shape = [128, 256] + + tensor = _get_tensor(shape) + + # # synchronize + otensor = cube.runtime.adapter.all_to_all(tensor, 0, 1, [0, 1], async_op=async_op) + + if async_op: + otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + + return (clone_to_cpu(tensor), clone_to_cpu(otensor)) + + +def _exchange_worker(async_op): + shape = [128, 256] + + tensor = _get_tensor(shape) + otensor = cube.runtime.adapter.exchange(tensor, [0, 1], async_op=async_op) + + if async_op: + otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + + return (clone_to_cpu(tensor), clone_to_cpu(otensor)) + + +def _2gpu_worker(): + _init_distributed(2) + result = {} + result['move'] = _move_worker(False) + result['move_async'] = _move_worker(True) + result['allreduce'] = _allreduce_worker(False) + result['allreduce_async'] = _allreduce_worker(True) + result['allgather'] = _allgather_worker(False) + result['allgather_async'] = _allgather_worker(True) + result['reduce_scatter'] = _reduce_scatter_worker(False) + result['reduce_scatter_async'] = _reduce_scatter_worker(True) + result['all2all'] = _all2all_worker(False) + result['all2all_async'] = _all2all_worker(True) + result['exchange'] = _exchange_worker(False) + result['exchange_async'] = _exchange_worker(True) + + return result + + +def test_2gpu(): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + print('skip test_2gpu due to lack of cuda devices') + return + results = launch_torchrun(2, _2gpu_worker) + + for op in ['', '_async']: + # check move + outputs = results[0][f'move{op}'], results[1][f'move{op}'] + assert torch.equal(outputs[0], outputs[1]) + + # check allreduce + outputs = results[0][f'allreduce{op}'], results[1][f'allreduce{op}'] + assert torch.equal(outputs[0][1], outputs[1][1]) + assert torch.equal(outputs[0][0] + outputs[1][0], outputs[0][1]) + + # check allgather + outputs = results[0][f'allgather{op}'], results[1][f'allgather{op}'] + result = torch.concat([outputs[0][0], outputs[1][0]], dim=0) + assert torch.equal(outputs[0][1], outputs[1][1]) + assert torch.equal(result, outputs[0][1]) + + # check reduce_scatter + outputs = results[0][f'reduce_scatter{op}'], results[1][f'reduce_scatter{op}'] + result = (outputs[0][0] + outputs[1][0]).chunk(2, dim=0) + assert torch.equal(outputs[0][1], result[0]) + assert torch.equal(outputs[1][1], result[1]) + + # check all2all + outputs = results[0][f'all2all{op}'], results[1][f'all2all{op}'] + in0 = outputs[0][0].chunk(2, dim=1) + in1 = outputs[1][0].chunk(2, dim=1) + out0 = torch.concat([in0[0], in1[0]], dim=0) + out1 = torch.concat([in0[1], in1[1]], dim=0) + assert torch.equal(outputs[0][1], out0) + assert torch.equal(outputs[1][1], out1) + + # check exchange + outputs = results[0][f'exchange{op}'], results[1][f'exchange{op}'] + assert torch.equal(outputs[0][1], outputs[1][0]) + assert torch.equal(outputs[1][1], outputs[0][0]) + + +def _rdscatter_worker(async_op): + shape = [128, 256] + + tensor = _get_tensor(shape) + otensor = cube.runtime.adapter.rdscatter( + tensor, shape, torch.float32, dim=0, src=0, dsts=[1,2], async_op=async_op) + + if async_op: + otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + + return (clone_to_cpu(tensor), clone_to_cpu(otensor)) + + +def _rdgather_worker(async_op): + shape = [128, 256] + + tensor = _get_tensor(shape) + otensor = cube.runtime.adapter.rdgather( + tensor, shape, torch.float32, dim=0, srcs=[1,2], dst=0) + + if async_op: + otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + + return (clone_to_cpu(tensor), clone_to_cpu(otensor)) + + +def _broadcast_worker(async_op): + shape = [128, 256] + + tensor = _get_tensor(shape) + + # synchronize + otensor = cube.runtime.adapter.broadcast( + tensor, shape, torch.float32, src=0, ranks=[0,1,2]) + if async_op: + otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + + return (clone_to_cpu(tensor), clone_to_cpu(otensor)) + + +def _3gpu_worker(): + _init_distributed(3) + result = {} + result['rdscatter'] = _rdscatter_worker(False) + result['rdscatter_async'] = _rdscatter_worker(True) + result['rdgather'] = _rdgather_worker(False) + result['rdgather_async'] = _rdgather_worker(True) + result['broadcast'] = _broadcast_worker(False) + result['broadcast_async'] = _broadcast_worker(True) + return result + + +def test_3gpu(): + if not torch.cuda.is_available() or torch.cuda.device_count() < 3: + print('skip test_3gpu due to lack of cuda devices') + return + results = launch_torchrun(3, _3gpu_worker) + + for op in ['', '_async']: + # check rdscatter + outputs = results[0][f'rdscatter{op}'], results[1][f'rdscatter{op}'], results[2][f'rdscatter{op}'] + result = outputs[0][0].chunk(2, dim=0) + assert torch.equal(outputs[1][1], result[0]) + assert torch.equal(outputs[2][1], result[1]) + + # check rdgather + outputs = results[0][f'rdgather{op}'], results[1][f'rdgather{op}'], results[2][f'rdgather{op}'] + result = torch.cat((outputs[1][0], outputs[2][0]), dim=0) + assert torch.equal(outputs[0][1], result) + + # check broadcast + outputs = results[0][f'broadcast{op}'], results[1][f'broadcast{op}'], results[2][f'broadcast{op}'] + assert torch.equal(outputs[0][0], outputs[0][1]) + assert torch.equal(outputs[0][0], outputs[1][1]) + assert torch.equal(outputs[0][0], outputs[2][1]) diff --git a/unit_tests/test_utils.py b/unit_tests/test_utils.py new file mode 100644 index 00000000..3d09c952 --- /dev/null +++ b/unit_tests/test_utils.py @@ -0,0 +1,11 @@ +import os +from .launch_torchrun import launch_torchrun + +def worker_fn(): + rank = int(os.environ["RANK"]) + return rank + + +def test_torchrun(): + outputs = launch_torchrun(2, worker_fn) + assert outputs == {0: 0, 1: 1} From ac5dfa183b1c35b57b3d1f1f6ef692d491ff2c06 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 26 Jul 2023 06:49:56 +0000 Subject: [PATCH 1455/1892] Merged PR 1687: Support distributed MoE --- cube/algorithm/ops/dimops.py | 4 ++-- cube/graph/function/function.py | 15 +++++++++++---- cube/graph/parser/fx/parser.py | 2 +- cube/profiler/database.py | 1 - 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 0d3ebd0f..9e4581ea 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -208,7 +208,7 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR _logger.warning( f'node ({self.node.name}-{self.node.cid}): detected an input tensor ' f'is split on {len(dims)} dimensions, this will cause data loss.', - category=RuntimeWarning, stacklevel=0, + stacklevel=0, ) itransform.append(DimopSplit.D(dims)) # output @@ -223,7 +223,7 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR _logger.warning( f'node ({self.node.name}-{self.node.cid}): detected an output tensor ' f'is split on {len(dims)} dimensions, this will cause data loss.', - category=RuntimeWarning, stacklevel=0, + stacklevel=0, ) otransform.append(DimopSplit.D(dims)) # modifier diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 2df900e4..6faa293b 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1026,9 +1026,13 @@ def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: kwargs = dict(**kwargs) identifier = ifirst[dim] oidx = ofirst.index(identifier) - size = list(kwargs[kwarg_name]) - assert isinstance(size[oidx], int), \ - f'dynamic size cannot be partitioned but got: {size}' + if isinstance(kwargs[kwarg_name], IRObject): + size = list(kwargs[kwarg_name].value) + else: + size = list(kwargs[kwarg_name]) + if isinstance(size[oidx], IRObject): + _logger.warning(f'partition dim size in IRObject: {size[oidx]}') + size[oidx] = size[oidx].value size[oidx] = size[oidx] // num kwargs[kwarg_name] = tuple(size) return kwargs @@ -1731,7 +1735,10 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: # object slice if isinstance(obj, IRObject): assert obj.value is not None - out = IRObject(name='getitem', value=obj.value[index]) + if isinstance(obj.value[index], IRTensor): + out = obj.value[index] + else: + out = IRObject(name='getitem', value=obj.value[index]) return IRPyFunc(signature, [obj, index], [out]) return obj[index] diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index f60c46b3..33768318 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -150,7 +150,7 @@ def parse_complex_out(meta_out): if isinstance(meta_out, TensorMetadata): shape = meta_out.shape assert shape == torch.Size([]), f'{meta_out}' - return torch.zeros(shape, dtype=meta_out.dtype, requires_grad=meta_out.requires_grad) + return IRFullTensor(shape=shape, requires_grad=meta_out.requires_grad, dtype=DType2IRDType.map(meta_out.dtype)) elif isinstance(meta_out, dict): ret = {} for k, v in meta_out.items(): diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 737f1b34..2f006e0b 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -79,7 +79,6 @@ def gen_torch_tensors(shape, dtype, requires_grad): eval_kwargs[name] = eval_val # run one sample outputs = func(*tensors, **train_kwargs) - # omit non-tensor outputs ''' only profile IRDimops currently, which has at least one tensor output and may have non-tensor outputs (like list, tuple, dict, etc.). In additional, From d38559bb5e0718ee69e33d580dfba1da45bde467 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 31 Jul 2023 07:54:03 +0000 Subject: [PATCH 1456/1892] Merged PR 1694: dedup adapter for output tensor dedup adapter for output tensor --- cube/graph/gener/gen.py | 48 ++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 38d438f9..493f6d79 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -74,6 +74,7 @@ def expand_devices(tensors: List[Optional[IRSubTensor]], dtensors: Dict[int, List[IRSubTensor]] = {} for tensor in tensors: if tensor is None: continue + assert len(tensor.device) > 0, f"find the tensor {tensor} is not assigned by devices" for devid in tensor.device: if tensor in dtensors.setdefault(devid, []): continue @@ -293,11 +294,15 @@ def gen_activation(graph: IRSegment, allow_recompute: bool = True, cost_fn: Opti Generate adapter for activation tensors. The forward/backward adapter is inserted before the first consumers of its full tensor. - @param graph IRGraph: the graph the requires for adapter. - @param allow_recompute bool: Allow adapter recomputes. If this enables, all adapters will be - set to the same recompute group with its consumed node. + Args: + graph (IRGraph): the graph the requires for adapter. + allow_recompute (bool): Allow adapter recomputes. If this enables, all adapters will be + set to the same recompute group with its consumed node. + cost_fn (Callable | None): takes an IRAdapterPrim and outputs a cost in float. + default to be None, which will use communication volume. - @return graph IRGraph: the (inplace) modified graph with activation adapters. + Returns: + graph (IRGraph): the (inplace) modified graph with activation adapters. """ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # e.g., loss or parameter/buffer @@ -361,25 +366,34 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: fadapters = [] - # activation -> activation generation + # (activation -> activation) generation: generate communication adapters between producer operators + # and consumer adapters. if (not skip(fptensors, fctensors)) or (not skip(bptensors, bctensors)): fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) if fadapter is not None: fadapters.append(fadapter) - # activation -> output generation + # (activation -> graph/segment output) generation: generate communication adapters between + # producer operatiors and graph/segment output tensors. Note graph/segment output tensors + # always require for full-shape/value for output, while consumers may partition them. Therefore, + # we need to additionally generate adapters for this case. if ftensor in output_consumer: - # TODO: dedup adapter if the output is same with activation - fctensors = tuple(fwop.input(0) for fwop in output_consumer[ftensor]) - fctensors = expand_devices(fctensors, consumer=True) - bptensors = [] - if isinstance(ftensor.grad, IRFullTensor): - bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) - bptensors = expand_devices(bptensors, producer=True) - if (not skip(fptensors, fctensors)) or (not skip(bptensors, bctensors)): - fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) - if fadapter is not None: - fadapters.append(fadapter) + out_fctensors = tuple(fwop.input(0) for fwop in output_consumer[ftensor]) + out_fctensors = expand_devices(out_fctensors, consumer=True) + # dedup adapter if the output is same with activation tensor + if set(out_fctensors) == set(fctensors) and \ + set(t.device[0] for t in out_fctensors) == set(t.device[0] for t in fctensors): + pass + else: + fctensors = out_fctensors + bptensors = [] + if isinstance(ftensor.grad, IRFullTensor): + bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) + bptensors = expand_devices(bptensors, producer=True) + if (not skip(fptensors, fctensors)) or (not skip(bptensors, bctensors)): + fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) + if fadapter is not None: + fadapters.append(fadapter) # insert adapters for fadapter in fadapters: From de52bfb587b307cc10e845604d878f75cdba90c4 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 7 Aug 2023 02:21:55 +0000 Subject: [PATCH 1457/1892] Merged PR 1691: add parallel module support An example code: [gencode0.py.txt](https://dev.azure.com/msrasrg/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/1691/attachments/gencode0.py.txt) Please note: 1. Base class is changed. 2. Add Forward function 3. More logic in constructor Original Module: ``` class FcRelu(nn.Module): def __init__(self, in_features, out_features, bias=True): super().__init__() self.fc1 = CubeLinear(in_features, in_features, bias=bias) self.relu1 = nn.ReLU() self.fc2 = CubeLinear(in_features, out_features, bias=bias) self.relu2 = nn.ReLU() self.fc3 = CubeLinear(out_features, out_features, bias=bias) self.relu3 = nn.ReLU() def forward(self, x): return self.relu3(self.fc3(self.relu2(self.fc2(self.relu1(self.fc1(x)))))) class FcRelu_4_4(FcRelu): def __init__(self): super().__init__(4, 4) ``` --- cube/codegen/module/module.py | 108 ++++- cube/cube.py | 402 +++++++++++++++++++ cube/graph/function/dimops.py | 18 +- cube/graph/graph.py | 88 ++-- cube/graph/parser/__init__.py | 4 +- cube/ir/tensor.py | 28 +- cube/runtime/module.py | 80 +++- unit_tests/launch_torchrun.py | 18 +- unit_tests/parallel_module/__init__.py | 0 unit_tests/parallel_module/common.py | 172 ++++++++ unit_tests/parallel_module/test_override.py | 103 +++++ unit_tests/parallel_module/test_submodule.py | 271 +++++++++++++ 12 files changed, 1198 insertions(+), 94 deletions(-) create mode 100644 cube/cube.py create mode 100644 unit_tests/parallel_module/__init__.py create mode 100644 unit_tests/parallel_module/common.py create mode 100644 unit_tests/parallel_module/test_override.py create mode 100644 unit_tests/parallel_module/test_submodule.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index b735259b..df320b37 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -13,12 +13,14 @@ from cube.graph.graph import IRSegment from cube.graph.parser.register import CustomizedOps +from cube.graph.parser.fx.parser import FxModuleParser from cube.execplan import ExecutionPlan from cube.execplan.execplan import ExeReuseCell from cube.codegen.syntax.symtable import SymbolTable -from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock +from cube.codegen.syntax.blocks import ClassBlock, ForBlock, FunctionBlock +from cube.codegen.schedule.schedule import ScheduleCodeGen from cube.codegen.emit import FuncEmission from cube.codegen.module.autograd import AutogradAdapterCodeGen @@ -35,11 +37,11 @@ class ModuleCodeGen(FuncEmission): Generate module code `ModuleCodeGen` traverses all IR nodes and categorizes their intermediately generated - codes into different parts, + codes into different parts, then reorders and concatenates these parts into the final code for PyTorch to run. These parts are progressively stored into fields of `ModelCodeGen` - + - `init_code : List[str]` Statements like `import torch` @@ -69,7 +71,7 @@ class ModuleCodeGen(FuncEmission): [ 'tensor_3333 = torch.view(tensor_2222, [1,2,3,4])' ] - + # intermediate codes for 'adapter456(self, tensor_4444)' [ 'tensor_5555 = cube.runtime.adapter.all_reduce(tensor_4444, ranks=[0,1,2,3])' @@ -93,9 +95,10 @@ def __init__(self, execplan: ExecutionPlan, scale_ndevs: Optional[int] = None) - self.init_code: List[str] = [ '\n\n########## Generated Model Code ###########', 'from typing import *', + 'from pathlib import Path', 'import torch', 'import torch.utils.checkpoint as ckpt', 'import cube', 'import _operator', 'from numpy import inf', 'import builtins', '', ''] - + if CompileFlag.use_nnfusion: self.init_code.extend(['import nnfusion', '']) @@ -146,7 +149,7 @@ def add_scale_reducers(self): if device not in ctensor.device: continue if ctensor not in all_params: # a same parameter can be consumed multiple times by different operators - if ctensor not in rest_params: + if ctensor not in rest_params: rest_params.append(ctensor) if len(rest_params) == 0: continue @@ -157,12 +160,12 @@ def add_scale_reducers(self): def get_comm_groups(self, scale_ndevs: Optional[int] = None): """ - Scale the communication groups to multiple devices + Scale the communication groups to multiple devices using data parallelism. @warn this requires user side to setup dataloader for different GPUs - + @param scale_ndevs Optional[int]: scale to number of devices """ def _add_comm_for_group_zero(ranks): @@ -254,9 +257,31 @@ def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: return segment return node - def gen(self, device: int, outfile=None, attach=False) -> str: + def gen( + self, + device: int, + outfile=None, + attach=False, + *, + as_parallel_module=False, + forward_arg_names=None + ) -> str: """ Generate model implementation code based on the given graph. + + Args: + device (int): device id + outfile (str): output file path + attach (bool): whether to append to the file + as_parallel_module (bool): whether to generate parallel module, which will + 1. Inherit from ParallelModule + 2. Has forward method + 3. Add more content to constructor + forward_arg_names (List[str]): argument names of forward function, if None, use node inputs. + This is used only in parallel module + + Returns: + generated code """ gencode = copy.copy(self.init_code) node_args: List[List[str]] = list() @@ -272,7 +297,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: unrolled_seqs.append(node) # we use ordered dict as ordered set sequence = tuple(dict.fromkeys(unrolled_seqs)) - + # scale to multiple devices if self._scale_to_ndevs is not None: sequence = [self.scale(node, self._scale_to_ndevs, device) \ @@ -328,20 +353,36 @@ def gen(self, device: int, outfile=None, attach=False) -> str: node_args.append(args) # generate full code - with ClassBlock(class_name='GenModel', derived=['cube.runtime.module.CubeModule']) as cb: + with ClassBlock( + class_name='GenModel', + derived=[f'cube.runtime.module.{"ParallelModule" if as_parallel_module else "CubeModule"}'] + ) as cb: with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.model_init_statements) - # switch to training or inference mode - if self.execplan.inference: - ib.insert_body('self.eval()') + if as_parallel_module: + ib.insert_body('') + ib.insert_body(f'self.load_attr_content(Path(__file__).with_name("{FxModuleParser.ATTR_CONTENT_FILE}"))') + ib.insert_body(f'self.load_dist_param_map(Path(__file__).with_name("{FxModuleParser.ATTR_MAP_FILE}"))') + ib.insert_body('') + with ForBlock('reducer', f'self.reducers') as for_block: + for_block.insert_body(f'reducer.build_buckets()') + ib.insert_body('') + ib.insert_body(for_block.code) else: - ib.insert_body('self.train()') + # switch to training or inference mode + if self.execplan.inference: + ib.insert_body('self.eval()') + else: + ib.insert_body('self.train()') cb.insert_body('') cb.insert_body(ib.code) + segment_idxs =[] for idx, node in enumerate(gen_nodes): name = ModuleCodeGen.node_name(node) input_args = ['self'] + node_args[idx] forward_code = self.model_methods_bodies[idx] + if isinstance(node, IRSegment): + segment_idxs.append(idx) with FunctionBlock(func_name=name, args=input_args) as fb: fb.insert_body(forward_code) @@ -356,6 +397,33 @@ def gen(self, device: int, outfile=None, attach=False) -> str: cb.insert_body('@torch.jit.script_method') cb.insert_body(fb.code) + if as_parallel_module: + if len(segment_idxs) > 1: + raise RuntimeError("The graph has more than one segment, forward code cannot be generated.") + elif not segment_idxs: + raise RuntimeError("The graph has no segment, forward code cannot be generated.") + segment_idx = segment_idxs[0] + node = gen_nodes[segment_idx] + # will use the orignal names of inputs + inputs = [t.name for t in node.inputs() if not isinstance(t, IRSubTensor) or not t.is_attr()] + # ensure forward args are valid + if forward_arg_names: + for i in range(len(inputs)): + if inputs[i] != forward_arg_names[i]: + raise ValueError(f"Forward args mismatch: {inputs[i]} != {forward_arg_names[i]}") + for i in range(len(inputs), len(forward_arg_names)): + if not forward_arg_names[i].startswith('*'): + raise ValueError(f"Invalid extra forward args: only *args & **kwargs are allowed") + + with FunctionBlock(func_name='_forward_impl', args=['self'] + (forward_arg_names or inputs)) as fb: + outputs = ScheduleCodeGen.return_name(node.outputs(), skip_attr=True) + call_code = f'{outputs} = self.{ScheduleCodeGen.node_name(node)}({", ".join(inputs)})' + fb.insert_body(call_code) + return_code = f'return {ScheduleCodeGen.return_name_complex(self.execplan.graph.outputs())}' + fb.insert_body(return_code) + cb.insert_body('') + cb.insert_body(fb.code) + gencode += cb.code gencode += [''] @@ -364,7 +432,7 @@ def gen(self, device: int, outfile=None, attach=False) -> str: if outfile: with open(outfile, 'a' if attach else 'w') as f: f.write(code) - + # clear used buffer self.clear() return code @@ -545,7 +613,7 @@ def emit_segment(segment: IRSegment) -> List[str]: @staticmethod def _emit_nodes(nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: """ - Emit code to invoke operations and adapter, + Emit code to invoke operations and adapter, e.g. (the lines are split into `List[str]`) ``` @@ -589,7 +657,7 @@ def recompute(tensor_2222): tensor_2222 = None # no more reference return tensor_3333 # in the beginning we have `import torch.utils.checkpoint as ckpt` - tensor_4444 = ckpt.checkpoint(recompute, tensor_1111) + tensor_4444 = ckpt.checkpoint(recompute, tensor_1111) ``` REMARK: @@ -613,7 +681,7 @@ def recompute(tensor_2222): output_names = [FuncEmission.tensor_name(t) for t in outputs] output_names_tuple = ', '.join(output_names) - # 'graph.segment(nodes)' ensures that if a tensor is no longer used (in RC group or in later code), + # 'graph.segment(nodes)' ensures that if a tensor is no longer used (in RC group or in later code), # it's not included in 'outputs'. # And we will not generate 'return' statement for it, since it will cause the error # that the variable is not defined (because it has been 'del'-ed). @@ -623,7 +691,7 @@ def recompute(tensor_2222): # e.g. those ids in subgraphs are not 0-based, and incremented after the preceding non-rc nodes and so on. # # So within the recomputing subgraph, tensors can be released if they are no longer used - # i.e. not returned by the 'def recompute(...)' + # i.e. not returned by the 'def recompute(...)' # since 'execplan.graph.segment(nodes)' will make all "free variables" as explicit inputs/outputs # to that subgraph. diff --git a/cube/cube.py b/cube/cube.py new file mode 100644 index 00000000..679971ce --- /dev/null +++ b/cube/cube.py @@ -0,0 +1,402 @@ +from typing import Callable, Any, Optional, Type, Union +from pathlib import Path +import inspect +import sys +import importlib +from dataclasses import dataclass + +import torch +from cube.graph.parser.fx.parser import FxModuleParser + +from cube.ir.cten import IRObject +from cube.ir.tensor import IRFullTensor + +from cube.graph import IRGraph +from cube.graph import parser +from cube.graph.parser.dtype import DType2IRDType +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.pyfunc import IRPyFunc +from cube.graph.schedule.schedplan import SchedulePlan +from cube.graph.gener.gen import IRAdapterGener + +from cube.codegen import ModuleCodeGen +from cube.execplan import ExecutionPlan +from cube.execplan.planpass.grouping import Grouping +from cube.execplan.planpass.fusion import DiffFusion +from cube.ir.unique import IDGenerator +from cube.program import Program +from cube.runtime.module import CubeModule + + +@dataclass +class ComputeConfig: + plan_ngpus: int + runtime_ngpus: int + + +def _complex(val: Any): + """Complex to CPU""" + if isinstance(val, tuple): + return tuple(_complex(t) for t in val) + if isinstance(val, list): + return list(_complex(t) for t in val) + if isinstance(val, dict): + return {_complex(key):_complex(val) for key, val in val.items()} + if isinstance(val, set): + return {_complex(t) for t in val} + if isinstance(val, torch.Tensor): + return val.cpu() + return val + + +def _get_full_qualified_name(obj: Any) -> str: + """Get full qualified name of an object""" + if inspect.isclass(obj): + return obj.__module__ + '.' + obj.__qualname__ + return obj.__module__ + '.' + obj.__class__.__qualname__ + + +def _add_cube_savedir_to_syspath(cube_savedir: str) -> Path: + cube_savedir = Path(cube_savedir).resolve() + cube_savedir.mkdir(parents=True, exist_ok=True) + if str(cube_savedir) not in sys.path: + sys.path.append(str(cube_savedir)) + return cube_savedir + + +def _is_any_gencode_loaded(namespace: str) -> bool: + """Check if a module is loaded""" + for m in sys.modules.values(): + if m.__name__.startswith(namespace + '.' + _GENCODE_FILE_PREFIX): + return True + return False + + +_GENCODE_FILE_PREFIX = 'gencode' +_GENCODE_FILE_TEMPLATE = _GENCODE_FILE_PREFIX + '{}.py' # 'gencode{}.py' +_CUBE_MODULE_NAMESPACE = '_cube_modules' + + +def _gencode( + module: torch.nn.Module, + dummy_input: dict, + pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], + compute_config: ComputeConfig, + *, + dynamic_shape: bool = True, + cube_savedir: Union[str, Path] = './.cube', + override: bool = False, + instance_name: Optional[str] = None + ) -> None: + """ + Generate cube module source code from a torch module, and save it to file. + Generated module will be save according to its full qualified name. + + If you want to save multiple instances of the same module, + you can specify the instance_name to distingish them. + + For example, if the module is `torchscale.x.y`, then the generated module will be save to + `cube_savedir/_cube_modules/torchscale/x/y/instance_name`. + + Args: + module (torch.nn.Module): the module to be compiled + dummy_input (dict): the dummy input for the module + pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy + compute_config (ComputeConfig): the environment resource + dynamic_shape (bool): whether to use dynamic shape + override (bool): If true, source code will be regenerated even if generated code exists. + cube_savedir (Union[str, Path]): the directory to save generated code + instance_name (Optional[str]): the instance name of the generated module. + + Returns: + None + """ + # put cube_savedir into sys.path + # so we can import the generated module with its namespace later + cube_savedir = _add_cube_savedir_to_syspath(cube_savedir) + + instance_name = instance_name.strip('.') if instance_name else '' + instance_namespace = f'.{instance_name}' if instance_name else '' + namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module)}{instance_namespace}' + outdir = cube_savedir / Path(namespace.replace('.', '/').strip('/')) + outdir.mkdir(parents=True, exist_ok=True) + + # decision matrix for code generation + # override flag | dir condition(imported, empty, match, unmatched) | action + # --------------------------------------------------------- + # True | empty | generate + # True | imported | raise error + # True | match | generate + # True | unmatch | generate + # False | empty | generate + # False | match | do nothing + # False | unmatch | raise error + # False | imported | doesn't matter + if not override: + # check if the module is already generated + expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.plan_ngpus)] + expected_output_files.append(outdir / FxModuleParser.ATTR_CONTENT_FILE) + expected_output_files.append(outdir / FxModuleParser.ATTR_MAP_FILE) + existing_output_files = [f for f in outdir.glob('*') if f.is_file()] + if existing_output_files: + if all([output_file.exists() for output_file in expected_output_files]) \ + and len(existing_output_files) == len(expected_output_files): + return + else: + raise RuntimeError(f'Output directory {outdir} is not empty. ' + f'And the existing files do not match with current config.') + else: + # check if the module is already loaded + if _is_any_gencode_loaded(namespace): + raise RuntimeError(f'Output directory {outdir} is already loaded. ' + f'You can not override a loaded module.') + # clear existing generated files + for f in outdir.glob('*'): + if f.is_file(): + f.unlink() + + # reset environment + program = Program() + program.clear() + IDGenerator().clear() + + module = module.to(device=torch.device("cpu")) + module.train() + + # generate fx graph + dummy_input = _complex(dummy_input) + fx_graph = parser.to_fx_graph(module, dummy_input) + + # generate ir logic graph + ir_graph = parser.to_ir_graph( + fx_graph, dummy_input, outdir, dynamic_shape + ) + + # generate dummy inputs for logic graph + # that is, generate IRObject/IRFullTensor for fx graph dummpy input + fx_input_nodes = [node for node in fx_graph.graph.nodes if node.op == 'placeholder'] + # the inputs of graph is different with original forward args + # so we get the real forward args from fx inputs + forward_args = [node.target for node in fx_input_nodes] + ir_dummy_inputs = [] + for node in fx_input_nodes: + if node.target.startswith('*'): # *args or **kwargs + if node.target.strip('*') in dummy_input: + raise ValueError(f"Input {node.target}: *args or **kwargs is not suppported") + ir_dummy_inputs.append(None) # always set None to *args/**kwargs + elif node.target in dummy_input: + ir_dummy_inputs.append(dummy_input[node.target]) + else: + raise ValueError(f"Input {node.target} not in dummy input. Default value is not supported.") + for i in range(len(ir_dummy_inputs)): + if isinstance(ir_dummy_inputs[i], torch.Tensor): + # note: we will always set tensor to require gradient, which may + # generate backward communications in adapter. However, as long as + # the data doesn't require gradient in real runtime, the backward + # communication will not be triggered. + ir_dummy_inputs[i] = IRFullTensor( + shape=ir_dummy_inputs[i].size(), + name=fx_input_nodes[i].target, + requires_grad=True, + dtype=DType2IRDType.map(ir_dummy_inputs[i].dtype)).tosub() + ir_dummy_inputs[i].grad = ir_dummy_inputs[i].parent.grad.tosub() + else: + ir_dummy_inputs[i] = IRObject( + name=fx_input_nodes[i].target, + value=ir_dummy_inputs[i] + ) + # generate complete ir graph + ir_dummy_outputs = ir_graph(*ir_dummy_inputs) + + graph = program.get_graph() + graph.backward() + program.finalize() + program.set_input(ir_dummy_inputs) + if ir_dummy_outputs is None: ir_dummy_outputs = [] + elif not (isinstance(ir_dummy_outputs, tuple) or isinstance(ir_dummy_outputs, list)): + ir_dummy_outputs = [ir_dummy_outputs] + program.set_output(ir_dummy_outputs) + + graph = pas_policy(graph, compute_config) + if not isinstance(graph, IRGraph): + raise RuntimeError("Expected policy return IRGraph") + + # check assignment and remove anchor node + for node in graph.nodes(flatten=True): + # skip graph anchor and multiref: they will be removed or replaced by system + if isinstance(node, IRGraphAnchor) or node.name == 'multiref': + graph.assign(node, 0) + if isinstance(node, IRPyFunc): + graph.assign(node, 0) + if len(node.device) == 0: + raise RuntimeError(f"Node {node} device is not set") + graph = IRAdapterGener.gen(graph, cost_fn=None) + if graph.sched is not None: + graph.sched.apply() + + if isinstance(graph.sched, SchedulePlan): + execplan = ExecutionPlan.from_schedplan(graph.sched) + else: + execplan = ExecutionPlan.from_graph(graph) + + execplan = DiffFusion.apply(execplan) + # plan pass for computation grouping + if not graph.sched: + execplan = Grouping.apply(execplan) + + # code generation + runtime_ngpus = None if compute_config.plan_ngpus == compute_config.runtime_ngpus else compute_config.runtime_ngpus + assert len(execplan.graph.device) == compute_config.plan_ngpus, f"{execplan.graph.device}" + mgener = ModuleCodeGen(execplan, scale_ndevs=runtime_ngpus) + for rank in range(compute_config.plan_ngpus): + filename = _GENCODE_FILE_TEMPLATE.format(rank) + mgener.gen(rank, forward_arg_names=forward_args, outfile=outdir / filename, attach=False, as_parallel_module=True) + + +def _load_cube_module_class( + module_class: Type[torch.nn.Module], + *, + cube_savedir: Union[str, Path] = './.cube', + instance_name: Optional[str] = None, +): + """ + Load the generated cube module class. + + Please note that the cube module class should be generated beforehand by _gencode(). + + Args: + module_class (Type[torch.nn.Module]): the original module class + cube_savedir (Union[str, Path]): the directory to load generated code + instance_name (Optional[str]): the instance name of the generated module. + """ + _add_cube_savedir_to_syspath(cube_savedir) + rank = torch.distributed.get_rank() + instance_name = instance_name.strip('.') if instance_name else '' + instance_namespace = f'.{instance_name}' if instance_name else '' + gen_imported = importlib.import_module( + f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}{instance_namespace}.{Path(_GENCODE_FILE_TEMPLATE.format(rank)).stem}' + ) + cube_module_class = gen_imported.GenModel + # rewrite class name and module name + cube_module_class.__name__ = module_class.__name__ + cube_module_class.__qualname__ = module_class.__qualname__ + # cube_module_class.__module__ = module_class.__module__ + cube_module_class.__orig_module_class__ = module_class # save the original module class + return cube_module_class + + +def as_cube( + module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]], + dummy_input: dict, + pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], + compute_config: ComputeConfig, + *, + dynamic_shape: bool = True, + cube_savedir: Union[str, Path] = './.cube', + override: bool = False, + instance_name: Optional[str] = None, +) -> Union[CubeModule, Type[CubeModule]]: + """ + Convert a torch.nn.Module object or class to CubeModule object or class. + + If you want to save multiple instances of the same module, + you can specify the instance_name to distingish them. + + Args: + module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled + dummy_input (dict): the dummy input for the module + pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy + compute_config (ComputeConfig): the environment resource + dynamic_shape (bool): whether to use dynamic shape + override (bool): If true, source code will be regenerated even if generated code exists. + cube_savedir (Union[str, Path]): the directory to save generated code + instance_name (Optional[str]): the instance name of the generated module. + + Returns: + Union[CubeModule, Type[CubeModule]]: the converted CubeModule object or class + """ + if ( + isinstance(module_or_module_class, CubeModule) or + (inspect.isclass(module_or_module_class) and issubclass(module_or_module_class, CubeModule)) + ): + return module_or_module_class + + if not torch.distributed.is_initialized(): # we only support distributed training + raise RuntimeError("Distributed training is not initialized.") + + rank = torch.distributed.get_rank() + is_module_class = inspect.isclass(module_or_module_class) + module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ + + if rank == 0: + if is_module_class: + # it should only have 1 `self` parameter + if len(inspect.signature(module_or_module_class.__init__).parameters) > 1: + raise ValueError("Module class __init__ should be parameter-free.") + try: + module = module_or_module_class() + except Exception as e: + raise RuntimeError(f"Error when create module instance.") from e + else: + module = module_or_module_class + + # TODO: copy generated files to other nodes + # Currently you must use a shared file system to share the generated files (like mounted Azure Blob) + # Or you can manually copy the generated files to other nodes + _gencode( + module, + dummy_input, + pas_policy, + compute_config, + dynamic_shape=dynamic_shape, + override=override, + cube_savedir=cube_savedir, + instance_name=instance_name, + ) + if is_module_class: + del module + torch.distributed.barrier() + cube_module_class = _load_cube_module_class( + module_class, + cube_savedir=cube_savedir, + instance_name=instance_name, + ) + return cube_module_class if is_module_class else cube_module_class() + + +def cube( + dummy_input: dict, + pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], + compute_config: ComputeConfig, + *, + dynamic_shape: bool = True, + cube_savedir: Union[str, Path] = './.cube' +) -> Callable[[Union[torch.nn.Module, Type[torch.nn.Module]]], Union[CubeModule, Type[CubeModule]]]: + """ + Work as a class decorator to convert a torch.nn.Module to CubeModule. + + Please make sure the Module's __init__ is paremeter-free. + Please note that + 1. Returned CubeModule will replace the torch.nn.Module in-place. + And all member functions/variables of original torch.nn.Module will be gone. + 2. The parameters of CubeModule will be fixed, + which means all instances of CubeModule will use the same parameters (which are from the tracing). + + Args: + dummy_input (dict): the dummy input for the module + pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy + compute_config (ComputeConfig): the environment resource + dynamic_shape (bool): whether to use dynamic shape + cube_savedir (Union[str, Path]): the directory to save generated code + """ + def wrap(module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]]) -> Union[CubeModule, Type[CubeModule]]: + return as_cube( + module_or_module_class, + dummy_input, + pas_policy, + compute_config, + dynamic_shape=dynamic_shape, + override=False, + cube_savedir=cube_savedir + ) + return wrap diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index f26237c1..13242a15 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -226,16 +226,16 @@ def parse(shape_anno: str) -> Tuple[DimAnno]: a (b+ dim) d^ @param shape str: shape annotation - + @return dim_annos Tuple[DimAnno]: tuple of dimension annotations """ # => ['a', '(', 'b+', 'dim', ')', 'd^'] shapes = list() - for group in re.split('\ +', shape_anno): + for group in re.split(r'\ +', shape_anno): if len(group) == 0: continue if '(' in group or ')' in group: - for group in re.split('([\(\)])', group): + for group in re.split(r'([\(\)])', group): if len(group) != 0: shapes.append(group) else: @@ -391,7 +391,7 @@ def getlen(self, identifier: str) -> Optional[int]: Get identifier length @param identifier str: identifier name - + @return length Optional[int]: the length of identifier """ assert identifier in self._identifiers, f"{identifier} not exists {set(self._identifiers.keys())}" @@ -441,7 +441,7 @@ def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], ous: Tuple[Tuple[Union[str, Tuple[str]]]]) -> str: """! Create operator annotation string - e.g., + e.g., ins = [ ['a', 'b', 'c+'], ['c+', ['d', 'e']] ] ous = [ ['a', 'b', 'd', 'e'] ] => @@ -449,7 +449,7 @@ def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], @param ins Tuple[Tuple[Union[str, Tuple[str]]]: input identifier list @param ous Tuple[Tuple[Union[str, Tuple[str]]]: output identifier list - + @return anno str: operator annotation """ in_annos = list() @@ -488,7 +488,7 @@ class DimopSplit: """ def __init__(self, dims: Optional[Union[int, List[int]]] = None, r = False, v = False) -> None: """Dimension split config - + Args: dims (Optional[Union[int, List[int]]], optional): [description]. Defaults to None. """ @@ -571,7 +571,7 @@ def outputs(self) -> Tuple[DimopSplit]: def output(self, idx: int) -> DimopSplit: return self._outputs[idx] - + def modifier(self) -> Optional[Callable]: return self._modifier[0] @@ -604,7 +604,7 @@ def __init__(self, create_fn: Callable, name: str, @param transform_rules: the special rules to partition the operator. Default None. @param kwargs: the kwarg non-tensor parameters """ - assert all(isinstance(anno, str) for anno in annos), "Expect annos to be List[str]" + assert all(isinstance(anno, str) for anno in annos), "Expect annos to be List[str]" self._annos_candidates: List[str] = tuple(annos) self._anno: OpAnno = None self._iannos: List[ShapeAnno] = None diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 87c180ee..5a15efb7 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -38,7 +38,7 @@ class IRGraph(IRSegment): IRGraph is used for reprensting a distributed training iteration. """ - def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRTensor], + def __init__(self, nodes: List[IRCell], inputs: List[IRTensor], outputs: List[IRTensor], module_name: str): super().__init__(nodes, inputs, outputs, module_name) @@ -61,7 +61,7 @@ def __call__(self, *args): Register forward action """ return self.forward(*args) - + def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: """ forward will divide the graph into Actions according to @@ -135,10 +135,10 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: else: return self.outputs() - def backward(self, loss: IRSubTensor): + def backward(self, loss: Optional[IRSubTensor] = None): """ - Backward the graph from the entry tensor of loss. - + Backward the graph from the entry tensor of loss to complete the graph with backward operators. + This will infer tensors' gradients by following rules: Conditions must satisfy for an forward op having its backward op: @@ -146,7 +146,7 @@ def backward(self, loss: IRSubTensor): * one of its output tensors is consumed by other forward ops For operators that doesn't need backward, all gradients of their - input/output tensors will make to None (despite require_grad is True) + input/output tensors will make to None (despite require_grad is True) @param loss IRSubTensor: the loss tensor, must be in the output of current graph. The loss shape should be (1,) @@ -155,18 +155,20 @@ def backward(self, loss: IRSubTensor): """ # set mirror as self self._mirror = self - # set loss gradient - loss.parent.to_loss() - # update require gradient: for tensors that have no consumers, - # make their gradient to be False - for ftensor in self.full_tensors(): - if ftensor.is_loss(): continue - consumers = [n for n in self.consumers(ftensor) if isinstance(n, IRFwOperation)] - if len(consumers) == 0 and ftensor.requires_grad: - _logger.warning( - f"detected a dead ftensor which is not consumed by any nodes:\n\t{ftensor.name}: {ftensor}") - ftensor.requires_grad = False + if loss is not None: # optimize graph with loss + # set loss gradient + loss.parent.to_loss() + + # update require gradient: for tensors that have no consumers, + # make their gradient to be False + for ftensor in self.full_tensors(): + if ftensor.is_loss(): continue + consumers = [n for n in self.consumers(ftensor) if isinstance(n, IRFwOperation)] + if len(consumers) == 0 and ftensor.requires_grad: + _logger.warning( + f"detected a dead ftensor which is not consumed by any nodes:\n\t{ftensor.name}: {ftensor}") + ftensor.requires_grad = False # infer gradient for ftensor in self.full_tensors(): @@ -195,7 +197,7 @@ def group(self, nodes: List[IRCell]) -> IRSegment: Note nodes should not have applied by any transformation. @param nodes List[IRCell]: consecutive nodes in forward procedure - + @return segment IRSegment: the grouped segment """ assert all(node.isfw() for node in nodes), f"Expected all nodes in forward procedure" @@ -285,7 +287,7 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis """ Partition Primitive: - replicate: replicate a forward or data operation multiple times. - + Each input and output will be replicated with no gradient accumulation. The backward of the forward operation will automatically be replicated. @@ -298,7 +300,7 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis raise TypeError("Expected op to be forward op or data op") if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") - if node.name == 'multiref': + if node.name == 'multiref': _logger.warning(f'skip replicating multiref ({node.cid}), which will be handled by system.') return [node] if isinstance(node, IRPyFunc): @@ -339,7 +341,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], """ Partition Primitive: - partition: partition a forward or data operation using algorithms. - + The comment in the node will be inherited to partitioned nodes. The backward of the forward operation will be automatically partitioned. @@ -353,7 +355,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], Both primitive may replicate the tensors, but `replicate` will not do gradient accumulation while `partition` will always require gradient accumulation on replicated tensors. - + @param node Union[IRFwOperation, IRDataOperation]: the node to partition @param algo GenericDistAlgo: the partition algorithm related to the node @param config Dict[str, Any]: the algorithm configuration, e.g., partition number @@ -391,7 +393,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], for t in node.inputs() + node.outputs(): if isinstance(t, IRSubTensor): valmaps[t.parent] = None if t.grad is None else ValueMap(t.grad.valmap) - + # gather consumers ctensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() consumers: Dict[IRFullTensor, List[IRCell]] = dict() @@ -519,7 +521,7 @@ def fuse(self, nodes: List[IRFwOperation], def fuse_op_fn(*args, **kwargs) -> IRDimops: return IRDimops(fuse_op_fn, fuse_op_name, signature, [fuse_op_anno], args, **kwargs) - + if make_customized_op: from cube.graph.parser.register import CustomizedOps @@ -544,18 +546,18 @@ def to_name(t: Any) -> str: code.append(f'\treturn {func_outputs}') code = '\n'.join(code) CustomizedOps.register( - signature, fuse_op_fn, code, + signature, fuse_op_fn, code, lambda *args : NotImplementedError("a fused operator doesn't have runtime call") ) - + fuse_op = fuse_op_fn(*inputs, **kwargs) for idx, output in enumerate(outputs): fuse_op.set_output(idx, output) - + # setup device if len(nodes[0].device) != 0: fuse_op.device = nodes[0].device - + # replace nodes with the fused operator # remove forward operators segment = self.segment(nodes[0]) @@ -580,7 +582,7 @@ def assign(self, node: Union[IRFwOperation, IRDataOperation], device: int) -> bo """ Assign an operator (subgraph) to (multiple) rank(s). - Corresponding backward operators (if have) will also be + Corresponding backward operators (if have) will also be assigned to the same device. @param node Union[IRFwOperation, IRBpOperation, IRSegment]: operator @@ -622,7 +624,7 @@ def sequential(self, nodes: Sequence[Union[FOp, Set[FOp]]]): Currently only support node (set) from a same device. @param nodes Sequence[Set[FOp]]: a sequence of operators or - a sequence of concurrent operators. Note there should be no + a sequence of concurrent operators. Note there should be no """ assert len(nodes) > 0 concurrent_groups = [[node] if isinstance(node, IRCell) else node for node in nodes] @@ -707,7 +709,7 @@ def schedule(self, node1: IRCell, action: str, node2: IRCell) -> bool: @param action str: 'after': fixed node2 and schedule node1 after node2 in the sequence. 'before': fixed node2 and schedule node1 before node2 in the sequence. - + @return success bool: True if the scheduling success otherwise False. """ idx1 = self._nodes.index(node1) @@ -765,13 +767,13 @@ def legal_schedule(seq: List[IRCell], integrity_check=False): @note: this functionality is not enabled due to predecessor and succesor functionality. - + @param seq List[IRCell]: the nodes in scheudled order @param integrity_check bool: If true, performs additional integrity check that requires all the nodes in predecessor and successor of a node should appear in the sequence. - + @return valid bool: True for satisfying topo order, otherwise False. """ for index, node in enumerate(seq): @@ -812,15 +814,15 @@ def staging(self, nodes: Tuple[IRFwOperation]): This should be called before any operator partition. The transformation and temporal scheduling can only be applied within each stage. - For example, after staging, user cannot schedule a (transformed) node + For example, after staging, user cannot schedule a (transformed) node from one stage to another stage. Changes will be made: 1). Identity creation: If a non-attribute tensor is produced / consumed not in - neighbor stages, - e.g., + neighbor stages, + e.g., stage 1: t1 = producer() stage 2: ... stage 3: xx = consume(t1) @@ -865,7 +867,7 @@ def staging(self, nodes: Tuple[IRFwOperation]): if node.name == 'multiref' or isinstance(node, IRPyFunc): pass else: - _logger.info(f'involve node {node.name}({node.cid}) into the first stage') + _logger.info(f'involve node {node.name}({node.cid}) into the first stage') starts[0] = idx break @@ -873,7 +875,7 @@ def staging(self, nodes: Tuple[IRFwOperation]): for idx, node in enumerate(self._nodes): if not isinstance(node, IRBpOperation): last_fidx = idx - + fstages: List[List[IRCell]] = [] bstages: List[List[IRCell]] = [] for sid in range(len(starts)): @@ -937,7 +939,7 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: producer, ptensor = self.producers(ftensor)[0], self.ptensors(ftensor)[0] psid = get_sid(producer) # outside of stages, not consider - if psid is None: continue + if psid is None: continue # group consumers into stages consumers = self.consumers(ftensor) @@ -1001,7 +1003,7 @@ def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: if isinstance(nodes, IRSegment): assert nodes.isfw() and (not nodes.isbw()), "Only forward IRSegment can recompute" return self.recompute(nodes.nodes()) - + else: segments = [self.segment(node) for node in nodes] assert all(segment == segments[0] for segment in segments), \ @@ -1067,7 +1069,7 @@ def load(filename: str): """ with open(filename, 'rb') as f: id_state, graph = dill.load(f) - + # recover IRGenerator IDGenerator().load_states(id_state) # recover cell @@ -1079,7 +1081,7 @@ def reset_node(segment: IRSegment): # nodes for node in segment.nodes(): for t in node.inputs() + node.outputs(): - if isinstance(t, IRObject): + if isinstance(t, IRObject): t.cell = node # recursively recover segments if isinstance(node, IRSegment): @@ -1087,6 +1089,6 @@ def reset_node(segment: IRSegment): # output for t in IRSegment.get_objects_from_complex(segment.outputs()): t.cell = segment - + reset_node(graph) return graph diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index c9b1ba92..a64c308e 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,3 +1,3 @@ from cube.graph.parser.fx.parser import FxModuleParser, FxFuncOpTracer -from cube.graph.parser.converter import convert_model -from cube.graph.parser.register import register \ No newline at end of file +from cube.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph +from cube.graph.parser.register import register diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index e2a93a83..bbffede4 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -2,7 +2,7 @@ SubTensor Gradient rule: SubTensor's logical grad = SubTensor.parent.grad.select( - indmap = SubTensor.indmap, + indmap = SubTensor.indmap, valmap = SubTensor.valmap, shape = SubTensor.shape ) @@ -18,7 +18,7 @@ val is always (0/1) Tensor can be graph attributes. In deep learning, these graph attribute tensors -can be +can be 1) parameters (require gradient), 2) buffers (not require gradient) 3) gradient of parameters @@ -40,7 +40,7 @@ def __init__(self, indmap: Tuple[StartEnd]): Create an index map. @param indmap Union[Tuple[StartEnd], IndexMap]: index range [start, end) for each dimension - + @return indmap IndexMap: the created new instance of index map. """ if isinstance(indmap, IndexMap): @@ -110,7 +110,7 @@ def overlap(self, other) -> bool: """ if not isinstance(other, IndexMap): raise TypeError("Expected IndexMap") - + if other.ndims != self.ndims: raise TypeError("Expected same dimension") @@ -126,7 +126,7 @@ def __and__(self, other): Get the common part @param other IndexMap: the other one - + @return indexmap IndexMap: index map for the common part """ if not self.overlap(other): @@ -175,11 +175,11 @@ def weight(self) -> IdxChunk: Get value partitioned chunks in tha accumulcated group """ return self._weight - + def overlap(self, other) -> bool: """! Check on value overlapping. - Note the overlap can only be within a same accumulation group and + Note the overlap can only be within a same accumulation group and a same replication group. """ if not isinstance(other, ValueMap): @@ -332,7 +332,7 @@ def requires_grad(self, req_grad: bool): self._requires_grad = True if self._grad is None: grad = IRFullTensor( - self.shape, 'g' + self.name, + self.shape, 'g' + self.name, requires_grad=False, dtype=self.dtype ).as_grad(self.is_attr()) self._grad = grad @@ -445,7 +445,7 @@ def __init__(self, ftensor: IRFullTensor, """ indmap, valmap = IndexMap(indmap), ValueMap(valmap) assert isinstance(ftensor, IRFullTensor), "Expcted ftensor to be IRFullTensor" - assert 'dtype' not in kwargs, "IRSubTensor is not allowed to initialize with a dtype" + assert 'dtype' not in kwargs, "IRSubTensor is not allowed to initialize with a dtype" super().__init__(shape=indmap.shape, name=ftensor.name, **kwargs) for attr in IRFullTensor._meta: setattr(self, attr, getattr(ftensor, attr)) @@ -537,7 +537,7 @@ def catdim(self, other: IRTensor) -> Optional[int]: else: return None return cat_dim - + def concat(self, other: IRTensor, dim: int) -> IRTensor: """! concat dimension with other IRSubTensor. The concatenate @@ -590,7 +590,7 @@ def accumable(self, tensors: Union[IRTensor, List[IRTensor]]) -> bool: def accum(self, tensors: Union[IRTensor, List[IRTensor]]) -> IRTensor: """! Accumulate tensor on value dimension. - The replica id will be + The replica id will be @param: tensors Union[IRTensor, List[IRTensor]] @return tensor IRSubTensor: accumulated tensor @@ -721,12 +721,12 @@ def split_dim(self, dim: int, num: int) -> List[IRTensor]: return sub_tensors def split_dims(self, dims: Tuple[int], nums: Tuple[int] ) -> List[IRTensor]: - """Uniformly and nestedly partition tensors alongside multiple dimensions. - + r"""Uniformly and nestedly partition tensors alongside multiple dimensions. + Args: dims (Tuple[int]): the dimensions to get partitioned nums (Tuple[int]): the number of sub-tensor generated - + Returns: List[IRTensor]: the generated `\prod nums` sub-tensors """ diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 04b0dccf..f9495e36 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,6 +1,7 @@ from typing import List, Dict, Tuple, Optional import logging import os +import sys import torch @@ -72,7 +73,7 @@ def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: i Add an attribute map. The mapping includes current attribute name (str) to logical tensor id, and the mapping of logical tensor id including spatial (slice) and val chunks - + @param attr str: attribute name of this moudle @param tid int: full tensor id @param slicers Tuple[slice]: indexing from full tensor @@ -108,8 +109,13 @@ def init_group(self, ranks: List[int]): def get_checkpoint(self, optimizer: torch.optim.Optimizer = None): state_dict = super().state_dict() - assert os.path.isfile('dist_param_map.pt'), 'Cannot open distributed parameter mapping file: dist_param_map.pt' - dist_param_map = torch.load('dist_param_map.pt') + # backward compatibility + # in old version, dist_param_map is not loaded in constructor + # so we will try to load it from file on the fly. + dist_param_map = getattr(self, '_dist_param_map', None) + if not dist_param_map: + assert os.path.isfile('dist_param_map.pt'), 'Cannot open distributed parameter mapping file: dist_param_map.pt' + dist_param_map = torch.load('dist_param_map.pt') param_area_map = self._fullmap optimizer_state_dict = optimizer.state_dict() if optimizer is not None else None return state_dict, dist_param_map, param_area_map, optimizer_state_dict @@ -118,7 +124,7 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref filename_prefix = 'dist_checkpoint' if filename_prefix is None else filename_prefix filename = f"{filename_prefix}-{DeviceGroup().rank}.ckpt" state_dict, dist_param_map, param_area_map, optimizer_state_dict = self.get_checkpoint(optimizer) - + _logger.info(f'saving distributed checkpoint to {filename}') torch.save({ 'state_dict': state_dict, @@ -367,3 +373,69 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): torch.save({'state_dict': merged_model_state_dict, 'optim_state_dict': merged_optimizer_state_dict }, filename_prefix + '.full.ckpt') + + +class _AddGradModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args): + new_args = [] + found_tensor = False + for arg in args: + if isinstance(arg, torch.Tensor): + found_tensor = True + new_arg = arg + if not arg.requires_grad: + new_arg = arg.clone().requires_grad_(True) + new_args.append(new_arg) + else: + new_args.append(arg) + if not found_tensor: + raise RuntimeError("Failed to setup module backward hook: no input Tensors.") + return tuple(new_args) + + +class ParallelModule(CubeModule): + def __init__(self): + super().__init__() + self._dist_param_map = None # should fill in sub classes. + + # register_full_backward_pre_hook requires the input tensor to be requires_grad + # so we add a module to make sure the input tensor requires grad + self._add_grad_module = _AddGradModule() + self._add_grad_module.register_full_backward_pre_hook(self.backward_hook) + + # if _grad_sentry.grad becomes None or zero + # we should zero grad in the next forward + # NOTE: this is a hacky way to detect whether the backward is called + # And it will add an extra parameter to the module + self._grad_sentry = torch.nn.Parameter(torch.tensor([0.0], requires_grad=True)) + self._grad_sentry.grad = torch.tensor([1e-7]) + + def forward(self, *args, **kwargs): + if self._grad_sentry.grad is None or self._grad_sentry.grad.item() == 0: + self.zero_grad() + self._grad_sentry.grad = torch.tensor([1.0]) + + new_args = self._add_grad_module(*args) + return self._forward_impl(*new_args, **kwargs) + + def _forward_impl(self, *args, **kwargs): + """ + forward implementation. Should be implemented by subclass + """ + raise NotImplementedError + + def backward_hook(self, module, grad_output): + """ + backward hook for gradient synchronization + """ + for reducer in self.reducers: + reducer.sync_grads() + + def get_dist_param_map(self): + return self._dist_param_map + + def load_dist_param_map(self, filename: str): + self._dist_param_map = torch.load(filename) diff --git a/unit_tests/launch_torchrun.py b/unit_tests/launch_torchrun.py index e3b7a2f8..fe7fde30 100644 --- a/unit_tests/launch_torchrun.py +++ b/unit_tests/launch_torchrun.py @@ -1,4 +1,5 @@ import uuid +import torch from torch.distributed.run import elastic_launch, LaunchConfig @@ -18,12 +19,25 @@ def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): return outputs -def clone_to_cpu(tensor): +def clone_to_cpu(tensor: torch.Tensor): # when you use launch_torchrun # you can't directly return a cuda tensor # Error message: Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors] # So you can use this function to clone a tensor to cpu cloned_tensor = tensor.cpu().clone().detach().requires_grad_(tensor.requires_grad) - if tensor.grad is not None: + if tensor.is_leaf and tensor.grad is not None: cloned_tensor.grad = tensor.grad.cpu().clone() return cloned_tensor + + +def clone_to_cpu_recursively(data): + if isinstance(data, torch.Tensor): + return clone_to_cpu(data) + elif isinstance(data, dict): + return {k: clone_to_cpu_recursively(v) for k, v in data.items()} + elif isinstance(data, list): + return [clone_to_cpu_recursively(v) for v in data] + elif isinstance(data, tuple): + return tuple(clone_to_cpu_recursively(v) for v in data) + else: + return data diff --git a/unit_tests/parallel_module/__init__.py b/unit_tests/parallel_module/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unit_tests/parallel_module/common.py b/unit_tests/parallel_module/common.py new file mode 100644 index 00000000..25b78f39 --- /dev/null +++ b/unit_tests/parallel_module/common.py @@ -0,0 +1,172 @@ +from datetime import datetime +import math +import random +from typing import List, Optional + +import torch +from torch import nn +import numpy as np + +from cube.cube import ComputeConfig +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.function.dimops import IRDimops +from cube.graph.graph import IRGraph +from cube.graph.segment import IRSegment +from cube.ir.operator import IRDataOperation, IRFwOperation + + +def _tp(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int): + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def _replica(graph: IRGraph, node, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def PASRandomSPMD(graph: IRGraph, env_resource: ComputeConfig): + """ + Random SPMD policy + """ + ngpus = env_resource.plan_ngpus + # get the current random state + state = random.getstate() + + seed = 1 + # print(f'> set random SPDM policy seed to {seed}') + random.seed(seed) + devs = list(range(ngpus)) + + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): continue + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor) + + graph_inputs = IRSegment.get_objects_from_complex(graph.inputs()) + graph_outputs = IRSegment.get_objects_from_complex(graph.outputs()) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'multiref' or isinstance(node, IRGraphAnchor): + continue + # Currently cube only support replicate if node's input or input is part of the graph output + # workaround for now + # will fix later. + if any(output in graph_outputs for output in node.outputs()) \ + or any(input in graph_outputs for input in node.inputs()) \ + or any(input in graph_inputs for input in node.inputs()): + _replica(graph, node, devs) + continue + if isinstance(node, IRDimops): + configs = node.transform_space() + if len(configs) == 0: + _replica(graph, node, devs) + else: + configs = sorted(configs, reverse=True, + key=lambda config: node.input(config[0]).shape[config[1]]) + random.shuffle(configs) + for (idx, dim) in configs: + if node.input(idx).shape[dim] % len(devs) != 0: continue + if node.algorithms('dim').satisfy(idx=idx, dim=dim, num=len(devs)): + # print(f'> partition node {node.name} ({node.cid}) with config idx={idx}, dim={dim}') + _tp(graph, node, devs, idx, dim) + break + else: + _replica(graph, node, devs) + else: + _replica(graph, node, devs) + + # restore the random state + random.setstate(state) + # print(graph.extra_repr()) + return graph + + +def PASData(graph: IRGraph, env_resource: ComputeConfig): + """ + Data Parallel + """ + ngpus = env_resource.plan_ngpus + # auto multi-ref + for ftensor in graph.full_tensors(): + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) + + batch_dim = None + for node in graph.nodes(): + if isinstance(node, IRDataOperation): + algo = node.algorithms('data') + sub_nodes = graph.partition(node, algo, num=ngpus) + for idx, subnode in enumerate(sub_nodes): + graph.assign(subnode, idx) + batch_dim = node.get_batch_dims()[0] + if batch_dim is None: batch_dim = 0 + graph_inputs = IRSegment.get_objects_from_complex(graph.inputs()) + graph_outputs = IRSegment.get_objects_from_complex(graph.outputs()) + for node in graph.nodes(): + # print(node) + if isinstance(node, IRFwOperation): + # Currently cube only support replicate if node's input or input is part of the graph output + # workaround for now + # will fix later. + if any(output in graph_outputs for output in node.outputs()) \ + or any(input in graph_outputs for input in node.inputs()) \ + or any(input in graph_inputs for input in node.inputs()): + sub_nodes = graph.replicate(node, ngpus) + else: + try: + algo = node.algorithms('dim') + idx = 0 + sub_nodes = graph.partition( + node, algo, idx=idx, dim=batch_dim, num=ngpus) + # except AssertionError: + except: + # print(f'WARNING: {node} cannot find dim algo, using replicate instead') + sub_nodes = graph.replicate(node, ngpus) + + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + # print(graph.extra_repr()) + return graph + + +class CubeLinear(nn.Module): + def __init__(self, in_features, out_features, bias=False): + super().__init__() + self.fc = nn.Linear(in_features, out_features, bias=False) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_(self.fc.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.fc.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + torch.nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, x): + x = self.fc(x) + if self.bias is not None: + x = x + self.bias + return x + + +def init_distributed(): + torch.distributed.init_process_group(backend='nccl') + rank = torch.distributed.get_rank() + torch.cuda.set_device(rank) + torch.set_default_device(f'cuda:{rank}') + + +def init_random(): + np.random.seed(1) + torch.manual_seed(1) + if torch.cuda.is_available(): + torch.cuda.manual_seed(1) diff --git a/unit_tests/parallel_module/test_override.py b/unit_tests/parallel_module/test_override.py new file mode 100644 index 00000000..c9f944ed --- /dev/null +++ b/unit_tests/parallel_module/test_override.py @@ -0,0 +1,103 @@ +from pathlib import Path +import sys +import tempfile +import pytest +import torch +import shutil + +from cube.cube import as_cube, ComputeConfig + +from .common import PASData, init_distributed +from ..launch_torchrun import launch_torchrun + + +def _to_cube_model(module, compute_config, cube_savedir, override, instance_name): + return as_cube( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + compute_config, + dynamic_shape=True, + override=override, + cube_savedir=cube_savedir, + instance_name=instance_name, + ) + + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + return self.linear(x) + + +def _worker(): + init_distributed() + + with tempfile.TemporaryDirectory() as tempdir: + # False | empty | generate + cmodule1 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, False, None) + # False | match | do nothing + cmodule2 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, False, None) + # true + for (n1, v1), (n2, v2) in zip(cmodule1.named_parameters(), cmodule2.named_parameters()): + assert n1 == n2 + assert torch.equal(v1, v2) + + cmodule3 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, False, 'test') + cmodule4 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, False, 'test') + + for (n1, v1), (n2, v2) in zip(cmodule3.named_parameters(), cmodule4.named_parameters()): + assert n1 == n2 + assert torch.equal(v1, v2) + + cmodule2_p = dict(cmodule2.named_parameters()) + cmodule3_p = dict(cmodule3.named_parameters()) + keys = cmodule3_p.keys() + assert any(not torch.equal(cmodule2_p[key], cmodule3_p[key]) for key in keys) + + # True | imported | raise error + with pytest.raises(RuntimeError): + _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, True, None) + + with pytest.raises(RuntimeError): + _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, True, 'test') + + # False | unmatch | raise error + with pytest.raises(RuntimeError): + _to_cube_model(MyModule(), ComputeConfig(2, 2),tempdir, False, 'test') + + # True | empty | generate + cmodule1 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, True, 'test2') + module_path = Path(sys.modules[cmodule1.__module__].__file__).parent + test3_module_path = module_path.with_name('test3') + test3_module_path.mkdir(exist_ok=True, parents=True) + test4_module_path = module_path.with_name('test4') + test4_module_path.mkdir(exist_ok=True, parents=True) + for f in module_path.glob('*'): + if f.is_file(): + shutil.copy(f, test3_module_path / f.name) + shutil.copy(f, test4_module_path / f.name) + # fake two gpus + shutil.copy(test4_module_path / 'gencode0.py', test4_module_path / 'gencode1.py') + + # True | match | generate + cmodule2 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, True, 'test3') + cmodule2_p = dict(cmodule2.named_parameters()) + cmodule1_p = dict(cmodule1.named_parameters()) + keys = cmodule2_p.keys() + assert any(not torch.equal(cmodule2_p[key], cmodule1_p[key]) for key in keys) + + # True | unmatch | generate + assert (test4_module_path / 'gencode1.py').exists() + cmodule3 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, True, 'test4') + assert not (test4_module_path / 'gencode1.py').exists() + +def test_override(): + if not torch.cuda.is_available(): + print('skip test_submodules_tp_gpu1 due to lack of cuda devices') + return + launch_torchrun(1, _worker) + diff --git a/unit_tests/parallel_module/test_submodule.py b/unit_tests/parallel_module/test_submodule.py new file mode 100644 index 00000000..569b3208 --- /dev/null +++ b/unit_tests/parallel_module/test_submodule.py @@ -0,0 +1,271 @@ +import tempfile +import itertools +import re +from pathlib import Path +import shutil + +import torch +from torch import nn +import numpy as np + +from cube.cube import ComputeConfig, as_cube +from cube.runtime.module import ParallelModule + +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed +from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively + + +class FcRelu(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, out_features, bias=bias) + self.relu2 = nn.ReLU() + self.fc3 = CubeLinear(out_features, out_features, bias=bias) + self.relu3 = nn.ReLU() + + + def forward(self, x): + return self.relu3(self.fc3(self.relu2(self.fc2(self.relu1(self.fc1(x)))))) + + +class FcRelu_4_4(FcRelu): + def __init__(self): + super().__init__(4, 4) + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return as_cube( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + dynamic_shape=True, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + +def _create_modules(pas, compute_config, cube_savedir): + class OrigModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc_relu1 = FcRelu_4_4() + self.fc_relu2 = FcRelu_4_4() + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.fc_relu1(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc_relu1 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu1') + self.fc_relu2 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu2') + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.fc_relu1(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + orig_module = OrigModule().cuda() + init_random() + compiled_module = CompiledModule().cuda() + return orig_module, compiled_module + + +def _train(model): + init_random() + + loss_fn = nn.BCELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + data = [] + DATA_SIZE = 20 + UPDATE_FREQ = 1 # TODO: update_freq support + for _ in range(DATA_SIZE): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + results = [] + for i, (x, y) in enumerate(data): + model.train() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + if i % UPDATE_FREQ == UPDATE_FREQ - 1: + grads = {n: p.grad for n, p in model.named_parameters()} + results.append(clone_to_cpu_recursively([y_pred, loss, grads])) + optimizer.step() + optimizer.zero_grad() + weights = {n: p.data for n, p in model.named_parameters()} + results[-1].append(clone_to_cpu_recursively(weights)) + return results + + +def _gpu_worker(pas, ngpus): + init_distributed() + tempdir = Path(tempfile.gettempdir()) / 'cube_test' + if torch.distributed.get_rank() == 0 and tempdir.exists(): + shutil.rmtree(tempdir) + torch.distributed.barrier() + orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) + orig_results = _train(orig_module) + compiled_results = _train(compiled_module) + return ( + orig_results, + compiled_results, + compiled_module.fc_relu1.get_full_map(), + compiled_module.fc_relu1.get_dist_param_map(), + compiled_module.fc_relu2.get_full_map(), + compiled_module.fc_relu2.get_dist_param_map(), + ) + + +def test_submodules_tp_gpu1(): + if not torch.cuda.is_available(): + print('skip test_submodules_tp_gpu1 due to lack of cuda devices') + return + results = launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) + orig_results, compiled_results, _, _, _, _ = results[0] + for orig, compiled in zip(orig_results, compiled_results): + assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred + assert torch.allclose(orig[1], compiled[1], rtol=1e-6, atol=1e-6) # loss + + # grad + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items() if not k.endswith('_grad_sentry')} + assert len(orig[2]) == len(compiled_cleaned) + for k in orig[2].keys(): + assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + + # weights + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items() if not k.endswith('_grad_sentry')} + assert len(orig[3]) == len(compiled_cleaned) + for k in orig[3].keys(): + assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + + +def _get_fc_weights(state_dict: dict, prefix): + result = {} + new_state_dict = {} + for k, v in state_dict.items(): + if k.endswith('_grad_sentry'): + continue + if k.startswith(prefix): + result[k[len(prefix):]] = v + else: + new_state_dict[k] = v + state_dict.clear() + state_dict.update(new_state_dict) + return result + + +def _compare_weights(orig0, orig1, compiled0, compiled1, fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map): + fc1_weights0 = _get_fc_weights(compiled0, 'fc_relu1.') + fc2_weights0 = _get_fc_weights(compiled0, 'fc_relu2.') + fc1_weights1 = _get_fc_weights(compiled1, 'fc_relu1.') + fc2_weights1 = _get_fc_weights(compiled1, 'fc_relu2.') + + cube_state_fc1 = [(fc1_weights0, {'state':{}}, fc1_dist_param_map[0], fc1_fullmap[0]), (fc1_weights1, {'state':{}}, fc1_dist_param_map[1], fc1_fullmap[1])] + cube_state_fc2 = [(fc2_weights0, {'state':{}}, fc2_dist_param_map[0], fc2_fullmap[0]), (fc2_weights1, {'state':{}}, fc2_dist_param_map[1], fc2_fullmap[1])] + merged_fc1, _ = ParallelModule.merge_partial_states(cube_state_fc1) + merged_fc1_fixed = {} + for k, v in merged_fc1.items(): + merged_fc1_fixed['fc_relu1.' + k] = v + merged_fc2, _ = ParallelModule.merge_partial_states(cube_state_fc2) + merged_fc2_fixed = {} + for k, v in merged_fc2.items(): + merged_fc2_fixed['fc_relu2.' + k] = v + assert len(merged_fc1_fixed) + len(merged_fc2_fixed) + len(compiled0) == len(orig0) + assert len(compiled1) == len(compiled0) + for k, v in compiled0.items(): + assert torch.allclose(compiled0[k], compiled1[k], rtol=1e-4, atol=1e-4) + for k, v in itertools.chain(merged_fc1_fixed.items(), merged_fc2_fixed.items(), compiled0.items()): + assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) + + +def test_submodules_tp_gpu2(): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + print('skip test_submodules_tp_gpu2 due to lack of cuda devices') + return + results = launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) + results0, results1 = results[0], results[1] + eps = 1e-4 + + fc1_fullmap = results0[2], results1[2] + fc1_dist_param_map = results0[3], results1[3] + + fc2_fullmap = results0[4], results1[4] + fc2_dist_param_map = results0[5],results1[5] + + for orig0, compiled0, orig1, compiled1 in zip(results0[0], results0[1], results1[0], results1[1]): + assert torch.allclose(orig0[0], orig1[0], rtol=eps, atol=eps) # pred + assert torch.allclose(orig0[0], compiled0[0], rtol=eps, atol=eps) # pred + assert torch.allclose(orig1[0], compiled1[0], rtol=eps, atol=eps) # pred + + assert torch.allclose(orig0[1], orig1[1], rtol=eps, atol=eps) # loss + assert torch.allclose(orig0[1], compiled0[1], rtol=eps, atol=eps) # loss + assert torch.allclose(orig1[1], compiled1[1], rtol=eps, atol=eps) # loss + + # grad + for k in orig0[2].keys(): + assert torch.allclose(orig0[2][k], orig1[2][k], rtol=eps, atol=eps) + _compare_weights(orig0[2], orig1[2], compiled0[2], compiled1[2], fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map) + + # weights + for k in orig0[3].keys(): + assert torch.allclose(orig0[3][k], orig1[3][k], rtol=eps, atol=eps) + _compare_weights(orig0[3], orig1[3], compiled0[3], compiled1[3], fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map) + + +def test_submodules_dp_gpu1(): + if not torch.cuda.is_available(): + print('skip test_submodules_dp_gpu1 due to lack of cuda devices') + return + results = launch_torchrun(1, _gpu_worker, PASData, 1) + orig_results, compiled_results, _, _, _, _ = results[0] + for orig, compiled in zip(orig_results, compiled_results): + assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred + assert torch.allclose(orig[1], compiled[1], rtol=1e-6, atol=1e-6) # loss + + # grad + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items() if not k.endswith('_grad_sentry')} + assert len(orig[2]) == len(compiled_cleaned) + for k in orig[2].keys(): + assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + + # weights + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items() if not k.endswith('_grad_sentry')} + assert len(orig[3]) == len(compiled_cleaned) + for k in orig[3].keys(): + assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + + +def test_submodules_dp_gpu2(): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + print('skip test_submodules_dp_gpu2 due to lack of cuda devices') + return + results = launch_torchrun(2, _gpu_worker, PASData, 2) + for r in results.values(): + orig_results, compiled_results, _, _, _, _ = r + for orig, compiled in zip(orig_results, compiled_results): + assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred + assert torch.allclose(orig[1], compiled[1], rtol=1e-6, atol=1e-6) # loss + + # grad + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items() if not k.endswith('_grad_sentry')} + assert len(orig[2]) == len(compiled_cleaned) + for k in orig[2].keys(): + assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + + # weights + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items() if not k.endswith('_grad_sentry')} + assert len(orig[3]) == len(compiled_cleaned) + for k in orig[3].keys(): + assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) From 28935485b7dc08c72a73d064a132d6c7877f9c42 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 8 Aug 2023 02:22:48 +0000 Subject: [PATCH 1458/1892] Merged PR 1703: rename cube to parallel --- cube/graph/parser/converter.py | 14 +++--- .../kwargs_shape_prop/kwargs_shape_prop.py | 11 +++-- cube/graph/parser/fx/parser.py | 43 ++++++++++--------- cube/{cube.py => parallel.py} | 8 ++-- cube/program.py | 2 +- cube/runtime/module.py | 5 ++- unit_tests/graph/parser/test_converter.py | 6 +-- unit_tests/parallel_module/common.py | 2 +- unit_tests/parallel_module/test_decorator.py | 41 ++++++++++++++++++ unit_tests/parallel_module/test_override.py | 4 +- unit_tests/parallel_module/test_submodule.py | 4 +- 11 files changed, 92 insertions(+), 48 deletions(-) rename cube/{cube.py => parallel.py} (99%) create mode 100644 unit_tests/parallel_module/test_decorator.py diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 45ff6ef4..1e079806 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -66,7 +66,7 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: def to_ir_graph( traced_model: torch.fx.GraphModule, dummy_input: Dict[str, Any], - attr_save_dir: Union[str, Path], + attr_savedir: Union[str, Path], dynamic_shape: bool = False, ) -> IRGraph: """Convert torch.fx.GraphModule based model into IRGraph @@ -77,18 +77,18 @@ def to_ir_graph( dummy input of model, the keys are the names of forward arguments. dynamic_shape (bool): whether to use dynamic shape. Default False. - attr_save_dir (Union[str, Path]): directory to save content (attribtes) + attr_savedir (Union[str, Path]): directory to save content (attribtes) Returns: IRGraph: IRGraph of model """ - FxModuleParser.save_content = True - FxModuleParser.dynamic_shape = dynamic_shape _logger.info(f"use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") inputs, nodes, outputs = FxModuleParser.parse( traced_model, dummy_input, - attr_save_dir=attr_save_dir + attr_savedir=attr_savedir, + dynamic_shape=dynamic_shape, + save_content=True, ) module_name = traced_model.__class__.__name__ @@ -103,7 +103,7 @@ def to_ir_graph( def convert_model( model: torch.nn.Module, dummy_input: Dict[str, Any], - attr_save_dir: Union[str, Path], + attr_savedir: Union[str, Path], dynamic_shape: bool = False ) -> IRGraph: """Convert torch.nn.Module based model into IRGraph @@ -120,5 +120,5 @@ def convert_model( IRGraph: IRGraph of model """ traced_model = to_fx_graph(model, dummy_input) - graph = to_ir_graph(traced_model, dummy_input, attr_save_dir, dynamic_shape) + graph = to_ir_graph(traced_model, dummy_input, attr_savedir, dynamic_shape) return graph diff --git a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py index 0a63e95c..d6a30779 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py +++ b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py @@ -70,13 +70,12 @@ class KwargsShapeProp(KwargsInterpreter): def run_node(self, n: Node): try: result = super().run_node(n) - except Exception: - traceback.print_exc() + except Exception as e: raise RuntimeError( f"ShapeProp error for: node={n.format_node()} with " f"meta={n.meta}" - ) - + ) from e + found_tensor = False def extract_tensor_meta(obj): @@ -86,7 +85,7 @@ def extract_tensor_meta(obj): return _extract_tensor_metadata(obj) else: return obj - + # if the obj is a tensor, then wrap it into a TensorMetaData # else recursively descend and wrap meta = map_aggregate(result, extract_tensor_meta) @@ -94,6 +93,6 @@ def extract_tensor_meta(obj): n.meta['tensor_meta'] = meta n.meta['type'] = type(result) return result - + def propagate(self, concrete_args: Union[Dict[str, Any], Tuple]): return super().run(concrete_args) diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 33768318..1c80db86 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -64,14 +64,10 @@ def get_complex_data(val: Any, frame: Frame) -> Any: class FxModuleParser: - """torch.fx module parser - - Attributes: - save_content (bool): whether to save the content of the module - dynamic_shape (bool): whether to parse the module with dynamic shape """ - save_content: bool = True - dynamic_shape: bool = False + torch.fx module parser + """ + ATTR_CONTENT_FILE = 'fullmodel.pt' ATTR_MAP_FILE = 'dist_param_map.pt' @@ -92,11 +88,18 @@ def shape_refine(shape: torch.Size) -> torch.Size: def parse(module: torch.fx.GraphModule, dummy_inputs: Dict[str, Any], frame: Frame = None, - attr_save_dir='./') \ - -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: + attr_savedir='./', + *, + save_content: bool = True, + dynamic_shape: bool = True + ) -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: """Parse torch.fx module into cube IR The overall entry to parse a torch.fx graph module + + Args: + save_content (bool): whether to save the content of the module + dynamic_shape (bool): whether to parse the module with dynamic shape """ from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp @@ -188,7 +191,7 @@ def parse_complex_out(meta_out): total_node_num = len(module.graph.nodes) for nidx, node in enumerate(module.graph.nodes): _logger.info(f'[{nidx}/{total_node_num}] parsing node {node}...') - ir_nodes = FxModuleParser.parse_node(node, module, frame) + ir_nodes = FxModuleParser.parse_node(node, module, dynamic_shape, frame) if ir_nodes is not None: all_ir_nodes += ir_nodes @@ -197,10 +200,10 @@ def parse_complex_out(meta_out): # output_val = frame.get_var(output_nodes[0].name) output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] - if FxModuleParser.save_content: - attr_save_dir = Path(attr_save_dir) - frame.save_attr_content(attr_save_dir / FxModuleParser.ATTR_CONTENT_FILE) - frame.save_attr_map(attr_save_dir / FxModuleParser.ATTR_MAP_FILE) + if save_content: + attr_savedir = Path(attr_savedir) + frame.save_attr_content(attr_savedir / FxModuleParser.ATTR_CONTENT_FILE) + frame.save_attr_map(attr_savedir / FxModuleParser.ATTR_MAP_FILE) frame.pop_var() frame.pop_attr() @@ -225,7 +228,7 @@ def ntype(node: torch.fx.Node): raise RuntimeError(f"Unknown node kind {node.kind()} from torchscript module") @staticmethod - def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation]: + def parse_node(node: torch.fx.Node, module, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: """ Parse the node and return the IRFwOperation nodes """ @@ -236,7 +239,7 @@ def parse_node(node: torch.fx.Node, module, frame: Frame) -> List[IRFwOperation] if node_type == FxNodeKind.Output: return FxModuleParser.parse_prim_output_node(node, module, frame) if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): - return FxModuleParser.parse_prim_function_method(node, module, frame) + return FxModuleParser.parse_prim_function_method(node, module, dynamic_shape, frame) if node_type == FxNodeKind.PrimGetAttr: return FxModuleParser.parse_prim_attr_node(node, module, frame) if node_type == FxNodeKind.PrimCallModule: @@ -303,7 +306,7 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: raise RuntimeError(f'unknown module: {prim_module.__class__.__module__}') @staticmethod - def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: # get signature fsig = FxModuleParser._get_qualified_name(node.target, node) @@ -311,10 +314,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule input_vals = [get_complex_data(val, frame) for val in node.args] kwargs = {key: get_complex_data(val, frame) for key, val in node.kwargs.items()} - return FxModuleParser._parse_node(fsig, node, input_vals, kwargs, frame) + return FxModuleParser._parse_node(fsig, node, input_vals, kwargs, dynamic_shape, frame) @staticmethod - def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, frame: Frame) -> List[IRFwOperation]: + def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: # map to IR operator if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) @@ -344,7 +347,7 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, for i in range(len(vals)): ir_node.set_output(i, vals[i]) elif ir_node.output(0).value is not None: - if FxModuleParser.dynamic_shape: + if dynamic_shape: frame.set_var(node.name, ir_node.output(0)) ir_node.output(0).name = node.name else: diff --git a/cube/cube.py b/cube/parallel.py similarity index 99% rename from cube/cube.py rename to cube/parallel.py index 679971ce..cac2207a 100644 --- a/cube/cube.py +++ b/cube/parallel.py @@ -25,7 +25,7 @@ from cube.execplan.planpass.fusion import DiffFusion from cube.ir.unique import IDGenerator from cube.program import Program -from cube.runtime.module import CubeModule +from cube.runtime.module import CubeModule, ParallelModule @dataclass @@ -285,7 +285,7 @@ def _load_cube_module_class( return cube_module_class -def as_cube( +def parallelize( module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]], dummy_input: dict, pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], @@ -364,7 +364,7 @@ def as_cube( return cube_module_class if is_module_class else cube_module_class() -def cube( +def parallel_module( dummy_input: dict, pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], compute_config: ComputeConfig, @@ -390,7 +390,7 @@ def cube( cube_savedir (Union[str, Path]): the directory to save generated code """ def wrap(module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]]) -> Union[CubeModule, Type[CubeModule]]: - return as_cube( + return parallelize( module_or_module_class, dummy_input, pas_policy, diff --git a/cube/program.py b/cube/program.py index 63af7d8b..56c8789e 100644 --- a/cube/program.py +++ b/cube/program.py @@ -217,7 +217,7 @@ def __call__(self, *args): self._ir_graph = parser.convert_model( self.model, dummy_input=self.dummy_input, - attr_save_dir='./', + attr_savedir='./', dynamic_shape=self.dynamic_shape ) return self._ir_graph(*args) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index f9495e36..c2662c5b 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -397,6 +397,7 @@ def forward(self, *args): class ParallelModule(CubeModule): + _EPSILON = 1e-7 # A small constant that can be represented by fp16 def __init__(self): super().__init__() self._dist_param_map = None # should fill in sub classes. @@ -411,12 +412,12 @@ def __init__(self): # NOTE: this is a hacky way to detect whether the backward is called # And it will add an extra parameter to the module self._grad_sentry = torch.nn.Parameter(torch.tensor([0.0], requires_grad=True)) - self._grad_sentry.grad = torch.tensor([1e-7]) + self._grad_sentry.grad = torch.tensor([self._EPSILON]) def forward(self, *args, **kwargs): if self._grad_sentry.grad is None or self._grad_sentry.grad.item() == 0: self.zero_grad() - self._grad_sentry.grad = torch.tensor([1.0]) + self._grad_sentry.grad = torch.tensor([self._EPSILON]) new_args = self._add_grad_module(*args) return self._forward_impl(*new_args, **kwargs) diff --git a/unit_tests/graph/parser/test_converter.py b/unit_tests/graph/parser/test_converter.py index c817daa7..7e010e79 100644 --- a/unit_tests/graph/parser/test_converter.py +++ b/unit_tests/graph/parser/test_converter.py @@ -37,8 +37,8 @@ def forward(self, x, **kwargs): assert any(node.op == 'call_function' and node.target == torch.nn.functional.linear for node in nodes) with tempfile.TemporaryDirectory() as tempdir: - to_ir_graph(fx_graph, dummy_input, attr_save_dir=tempdir, dynamic_shape=True) - ir_graph = to_ir_graph(fx_graph, dummy_input, attr_save_dir=tempdir, dynamic_shape=True) + to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) assert ir_graph is not None assert (Path(tempdir) / FxModuleParser.ATTR_MAP_FILE).exists() assert (Path(tempdir) / FxModuleParser.ATTR_CONTENT_FILE).exists() @@ -72,4 +72,4 @@ def forward(self, x, *args): with tempfile.TemporaryDirectory() as tempdir: # currently we don't support *args with pytest.raises(RuntimeError): - to_ir_graph(fx_graph, dummy_input, attr_save_dir=tempdir, dynamic_shape=True) + to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) diff --git a/unit_tests/parallel_module/common.py b/unit_tests/parallel_module/common.py index 25b78f39..b75a765d 100644 --- a/unit_tests/parallel_module/common.py +++ b/unit_tests/parallel_module/common.py @@ -7,7 +7,7 @@ from torch import nn import numpy as np -from cube.cube import ComputeConfig +from cube.parallel import ComputeConfig from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.dimops import IRDimops from cube.graph.graph import IRGraph diff --git a/unit_tests/parallel_module/test_decorator.py b/unit_tests/parallel_module/test_decorator.py new file mode 100644 index 00000000..a3b67be8 --- /dev/null +++ b/unit_tests/parallel_module/test_decorator.py @@ -0,0 +1,41 @@ +import tempfile + +import torch +from cube.parallel import ComputeConfig, parallel_module, CubeModule + +from .common import PASData, init_distributed +from ..launch_torchrun import launch_torchrun + + +def _decorator_worker(): + init_distributed() + with tempfile.TemporaryDirectory() as tempdir: + @parallel_module( + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + ) + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + return self.linear(x) + assert issubclass(MyModule, CubeModule) + x = MyModule() + y = MyModule() + + MyModule()(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) + # parameters from different instances will have the same value. + for p, q in zip(x.parameters(),y.parameters()): + assert torch.equal(p, q) + + +def test_decorator(): + if not torch.cuda.is_available(): + print('skip test_submodules_tp_gpu1 due to lack of cuda devices') + return + launch_torchrun(1, _decorator_worker) diff --git a/unit_tests/parallel_module/test_override.py b/unit_tests/parallel_module/test_override.py index c9f944ed..9a6ffcd0 100644 --- a/unit_tests/parallel_module/test_override.py +++ b/unit_tests/parallel_module/test_override.py @@ -5,14 +5,14 @@ import torch import shutil -from cube.cube import as_cube, ComputeConfig +from cube.parallel import parallelize, ComputeConfig from .common import PASData, init_distributed from ..launch_torchrun import launch_torchrun def _to_cube_model(module, compute_config, cube_savedir, override, instance_name): - return as_cube( + return parallelize( module, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, diff --git a/unit_tests/parallel_module/test_submodule.py b/unit_tests/parallel_module/test_submodule.py index 569b3208..2ff97e1d 100644 --- a/unit_tests/parallel_module/test_submodule.py +++ b/unit_tests/parallel_module/test_submodule.py @@ -8,7 +8,7 @@ from torch import nn import numpy as np -from cube.cube import ComputeConfig, as_cube +from cube.parallel import ComputeConfig, parallelize from cube.runtime.module import ParallelModule from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed @@ -36,7 +36,7 @@ def __init__(self): def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): - return as_cube( + return parallelize( module, {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, From 7c6d29fb4f430d5dc5858bddea6812570d44a17f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Aug 2023 02:45:40 +0000 Subject: [PATCH 1459/1892] Merged PR 1700: Remove model dummy input argument requests from compile interface --- cube/compiler.py | 8 ++++---- cube/program.py | 44 +++++++++++++++++++++++++++++--------------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index b5dc5cc8..dedacc69 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -1,4 +1,4 @@ -from typing import Callable, Tuple, Union, Optional, Any +from typing import Callable, Tuple, Union, Optional import torch import time import os @@ -36,7 +36,6 @@ def compile(model: SemanticModel, *args, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, - model_dummy_inputs: Tuple[Any] = None, model_dynamic_shape: bool = False, comm_cost_fn: Optional[Callable] = None, override = True, @@ -85,7 +84,6 @@ def train_iter(model, dataloader): assert isinstance(model, SemanticModel), f'Require cube.SemanticModel or torch.nn.Module, but got model: {type(model)}' model.save_content = load_content model.dynamic_shape = model_dynamic_shape - model.dummy_input = model_dummy_inputs dataloader = None inputs = [model] @@ -100,12 +98,14 @@ def train_iter(model, dataloader): # generate backward communications in adapter. However, as long as # the data doesn't require gradient in real runtime, the backward # communication will not be triggered. + tensor = arg arg = IRFullTensor(arg.shape, name='tensor', requires_grad=True, dtype=DType2IRDType.map(arg.dtype)).tosub() + arg._value = tensor arg.grad = arg.parent.grad.tosub() if arg.requires_grad else None else: - arg= IRObject('obj') + arg = IRObject('obj', value=arg) inputs.append(arg) myrank = DeviceGroup().rank diff --git a/cube/program.py b/cube/program.py index 56c8789e..df588204 100644 --- a/cube/program.py +++ b/cube/program.py @@ -1,7 +1,8 @@ -from typing import List, Tuple, Optional, Any +from typing import List, Tuple, Optional, Any, Dict +import inspect -from cube.ir.cten import IRCell, IRTensor, IRObject -from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.cten import IRCell, IRObject +from cube.ir.tensor import IRFullTensor from cube.ir.operator import IRBpOperation, IRDataOperation from cube.graph import IRGraph @@ -22,7 +23,6 @@ class Program: class __Program: def __init__(self): - self._graph = IRGraph([], [], [], 'program') instance = None @@ -64,14 +64,8 @@ def finalize(self): for ftensor in graph.full_tensors(): ftensor.requires_grad = False - def mirror_as_self(self): - """ - Set mirror as self. This is called when a backward is triggered. - """ - IRCell.make_pair(self.instance._graph, self.instance._graph) - def clear(self): - self.instance._graph = IRGraph([], [], [], 'program') + Program.instance._graph = IRGraph([], [], [], 'program') def __repr__(self): return repr(self.instance._graph) @@ -115,9 +109,11 @@ def generate_output(sample): return {generate_output(t) for t in sample} if isinstance(sample, torch.Tensor): shape, dtype = list(sample.shape), dtype_map.map(sample.dtype) - return IRFullTensor(shape, 'data', dtype=dtype).tosub() + tensor = IRFullTensor(shape, 'data', dtype=dtype).tosub() + tensor._value = sample + return tensor else: - return IRObject('data') + return IRObject('data', value=sample) sample = next(self.dataloader) outputs = generate_output(sample) @@ -159,7 +155,7 @@ def __init__(self, model: Optional[torch.nn.Module], if DeviceGroup().local_rank == 0: assert isinstance(model, torch.nn.Module), f"device of local_rank == 0 must provide model" self.model = model - self._dummy_input = None + self._dummy_input: Dict[str, Any] = None self._ir_graph = None self._loaded_module: CubeModule = None # parser configuration @@ -206,14 +202,32 @@ def clear_module(self): def __call__(self, *args): """Forward the semantic model. - This will trigger torch.jit.script to parse the model. + This will parse the model into cube graph. Args: *args: input IRObjects + + Returns: + graph outputs with IRObjects """ assert self._ir_graph is None, \ f"multiple forward on a semantic model is not allowed" if DeviceGroup().local_rank == 0: + # collect dummy input + if self.dummy_input is None: + dummy_input = {} + sig = inspect.signature(self.model.forward) + # note: we don't support model forward arguments having complex data stucture + # that contains tensor + for name, arg in zip(sig.parameters.keys(), args): + if isinstance(arg, IRObject): + value = arg.value + arg._value = None # remove tensor reference to release memory + else: + value = arg + dummy_input[str(name)] = value + self.dummy_input = dummy_input + # parse graph self._ir_graph = parser.convert_model( self.model, dummy_input=self.dummy_input, From ec35f5c26dcf138cfa963b7b41b1d1375f0a55ad Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 8 Aug 2023 06:41:04 +0000 Subject: [PATCH 1460/1892] Merged PR 1705: remove ir dtype --- cube/codegen/emit.py | 7 -- cube/codegen/frontend_mapping.py | 22 ------ cube/codegen/module/module.py | 2 +- cube/compiler.py | 5 +- cube/graph/function/function.py | 41 ++++------ cube/graph/gener/rvd/inter.py | 1 - cube/graph/gener/rvd/intra.py | 4 +- cube/graph/graph.py | 16 ---- cube/graph/parser/dtype.py | 44 ----------- cube/graph/parser/fx/parser.py | 15 ++-- cube/ir/cten.py | 27 +++---- cube/ir/dtype.py | 117 +++++++++++++--------------- cube/ir/operator.py | 6 +- cube/ir/tensor.py | 28 +------ cube/parallel.py | 3 +- cube/profiler/database.py | 5 +- cube/program.py | 5 +- examples/policies/alpa/estimator.py | 7 +- 18 files changed, 110 insertions(+), 245 deletions(-) delete mode 100644 cube/graph/parser/dtype.py diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index eaa67cd0..09284ec2 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -2,7 +2,6 @@ import logging from cube.ir.cten import IRCell, IRTensor, IRObject -from cube.ir.dtype import IRDType from cube.ir.tensor import IRSubTensor from cube.ir.operator import IRDataOperation, IRFwOperation from cube.ir.adapter import IRWeightReducer, IRAdapter @@ -31,12 +30,6 @@ class CodeEmission: Basic emission """ - @staticmethod - def dtype_map(dtype: IRDType) -> str: - if not isinstance(dtype, IRDType): - raise TypeError("Expected IRDType") - return 'torch.' + dtype.value - @staticmethod def node_name(node: IRCell) -> str: return f"{node.name}{node.cid}" diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index 32b40029..df80c58a 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -5,7 +5,6 @@ from cube import ir from cube.ir.cten import IRTensor -from cube.ir.dtype import IRDType from cube.ir.operator import IRFwOperation import torch @@ -105,24 +104,3 @@ def emit_getitem(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: '_operator.getitem': Sign2EmitRule.emit_getitem, } - -class IRDType2DType: - """ - The reverse mapping of DType2IRDType in /graph/parser/mapping.py - """ - - @staticmethod - def map(ir_dtype:IRDType) -> torch.dtype: - return IRDType2DType._map[ir_dtype] # subscript/[]-access will throw if not found - - _map = { - ir.float64: torch.float64, - ir.float32: torch.float32, - ir.float16: torch.float16, - ir.uint8: torch.uint8, - ir.int8: torch.int8, - ir.int16: torch.int16, - ir.int32: torch.int32, - ir.int64: torch.int64, - ir.boolean: torch.bool - } diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index df320b37..e52e0393 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -476,7 +476,7 @@ def init_attributes(self, node: IRCell): code = sign.format( name=ModuleCodeGen.tensor_name(itensor), shape=tuple(itensor.shape), - dtype=self.dtype_map(itensor.dtype) + dtype=itensor.dtype ) self.model_init_statements.append(code) tid = itensor.parent.tid diff --git a/cube/compiler.py b/cube/compiler.py index dedacc69..91cbdf08 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -11,7 +11,8 @@ from cube.ir.unique import IDGenerator from cube.graph.gener.gen import IRAdapterGener from cube.graph.graph import IRGraph -from cube.graph.parser.dtype import DType2IRDType +from cube.ir.cten import IRObject +from cube.ir.tensor import IRFullTensor from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.pyfunc import IRPyFunc from cube.graph.schedule.schedplan import SchedulePlan @@ -101,7 +102,7 @@ def train_iter(model, dataloader): tensor = arg arg = IRFullTensor(arg.shape, name='tensor', requires_grad=True, - dtype=DType2IRDType.map(arg.dtype)).tosub() + dtype=arg.dtype).tosub() arg._value = tensor arg.grad = arg.parent.grad.tosub() if arg.requires_grad else None else: diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 6faa293b..aeee818a 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -10,7 +10,6 @@ from cube.ir.cten import IRTensor, IRObject from cube.ir.tensor import IRSubTensor, IRFullTensor -from cube.ir.dtype import IRDType from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D @@ -153,8 +152,7 @@ def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Uni anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), False) dimop = IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.dtype import DType2IRDType - dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + dimop.output(0).parent.dtype = dtype return dimop @@ -189,8 +187,7 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.dtype import DType2IRDType - dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + dimop.output(0).parent.dtype = dtype return dimop @@ -207,8 +204,7 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.dtype import DType2IRDType - dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + dimop.output(0).parent.dtype = dtype return dimop @@ -225,8 +221,7 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.dtype import DType2IRDType - dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + dimop.output(0).parent.dtype = dtype return dimop @@ -244,8 +239,7 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.dtype import DType2IRDType - dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + dimop.output(0).parent.dtype = dtype return dimop @@ -259,8 +253,7 @@ def NewTensor(data, *, dtype=None, device=None, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(size, True) dimop = IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) - from cube.graph.parser.dtype import DType2IRDType - dimop.output(0).parent.dtype = DType2IRDType.map(dtype) + dimop.output(0).parent.dtype = dtype return dimop @@ -1608,13 +1601,13 @@ def _comparison(creator: Callable, f: Callable, name: str, signature: str, lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] dimop = IRDimops(creator, name, signature, annos, [input, other]) - dimop.output(0).parent.dtype = IRDType.boolean + dimop.output(0).parent.dtype = torch.bool return dimop # case2: torch.equal(tensor1, obj2) / torch.equal(obj1, tensor2) if isinstance(input, IRTensor) or isinstance(other, IRTensor): annos = ['*, ? -> *', '?, * -> *',] dimop = IRDimops(creator, name, signature, annos, [input, other]) - dimop.output(0).parent.dtype = IRDType.boolean + dimop.output(0).parent.dtype = torch.bool return dimop # case3: torch.equal(obj1, obj2) else: @@ -1713,11 +1706,10 @@ def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): annos = ['* -> *'] if isinstance(dtype_or_device, torch.device): return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device) - elif isinstance(dtype_or_device, (IRDType, torch.dtype)): - dtype = dtype_or_device if isinstance(dtype_or_device, torch.dtype) else eval('torch.'+dtype_or_device.value) - return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) - elif isinstance(dtype_or_device, IRFullTensor): - dtype = eval('torch.'+dtype_or_device.dtype.value) + elif isinstance(dtype_or_device, torch.dtype): + return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device) + elif isinstance(dtype_or_device, IRTensor): + dtype = dtype_or_device.dtype return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) else: raise RuntimeError(f'function.To with unknown arg: {dtype_or_device}') @@ -1756,10 +1748,9 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], shape = IRObject('shape', value=obj.shape) return IRPyFunc(signature, [instance, field], [shape]) if name == 'dtype': - from cube.graph.parser.dtype import IRDType2TorchDType assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" assert hasattr(obj, name), f"attr {name} is not existed in {obj}" - return IRDType2TorchDType.map(getattr(obj, name)) + return getattr(obj, name) if name == 'device': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" # FIXME: this is hack, IRFullTensor does not have attribute "device" @@ -1773,9 +1764,9 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], return IRPyFunc(signature, [instance, field], [IRObject()]) -def FInfo(dtype: IRDType, signature = None) -> torch.finfo: - assert isinstance(dtype, IRDType) - return torch.finfo(eval('torch.' + dtype.value)) +def FInfo(dtype: torch.dtype, signature = None) -> torch.finfo: + assert isinstance(dtype, torch.dtype) + return torch.finfo(dtype) def NLLLoss(input, target, weight=None, size_average=None, diff --git a/cube/graph/gener/rvd/inter.py b/cube/graph/gener/rvd/inter.py index bc3eefcb..41e35589 100644 --- a/cube/graph/gener/rvd/inter.py +++ b/cube/graph/gener/rvd/inter.py @@ -4,7 +4,6 @@ import sys import copy -from cube.ir.dtype import IRDType from cube.ir.tensor import IRFullTensor from cube.ir.adapter.prim import IRAdapterPrim diff --git a/cube/graph/gener/rvd/intra.py b/cube/graph/gener/rvd/intra.py index 85222143..2a705a54 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/cube/graph/gener/rvd/intra.py @@ -3,8 +3,8 @@ import numpy as np import copy import logging +import torch -from cube.ir.dtype import IRDType from cube.ir.cten import IRCell from cube.ir.tensor import IRFullTensor, IRSubTensor @@ -624,7 +624,7 @@ def advice(shape: TShape, @return cost float: Cost of communication plan """ src_placement = tuple(src_placement) - ftensor = IRFullTensor(shape, dtype=IRDType.float16) + ftensor = IRFullTensor(shape, dtype=torch.float16) cost_fn = IntraPathFinder.default_cost_fn if cost_fn is None else cost_fn # forward pass diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5a15efb7..23b6e62c 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -16,7 +16,6 @@ from cube.ir.unique import IDGenerator from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap -from cube.ir.dtype import IRDType, DTypeInferRule from cube.graph.function.function import Identity from cube.graph.function.anchor import IRGraphAnchor @@ -111,21 +110,6 @@ def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: self.output(oidx), lambda t: t if t != itensor else arg) self.set_output(oidx, output) - # dtype inference - for node in self._nodes: - # reset input - itensors: List[IRTensor] = [t for t in node.inputs() if isinstance(t, IRSubTensor)] - for itensor in itensors: - itensor.parent.dtype = itensor.dtype - # infer output dtype with default dtype promotion rules - if len(itensors) == 0: continue - default_dtype = DTypeInferRule.infer(node, [t.dtype for t in itensors]) - # set output tensors if it has unkown tensor dtype - otensors = [t for t in node.outputs() if isinstance(t, IRSubTensor)] - for otensor in otensors: - if otensor.dtype == IRDType.unknown: - otensor.parent.dtype = default_dtype - from cube.program import Program Program().add_nodes(self.nodes()) diff --git a/cube/graph/parser/dtype.py b/cube/graph/parser/dtype.py deleted file mode 100644 index f2136e4a..00000000 --- a/cube/graph/parser/dtype.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import cube.ir as ir - - -class DType2IRDType: - - @staticmethod - def map(dtype: torch.dtype): - """ - Map the torch dtype to IRDType - """ - return DType2IRDType.kDtypeMap[dtype] - - kDtypeMap = { - torch.double: ir.float64, - torch.float64: ir.float64, - torch.float32: ir.float32, - torch.float : ir.float32, - torch.float16: ir.float16, - torch.half : ir.float16, - torch.bfloat16: ir.bfloat16, - torch.uint8 : ir.uint8, - torch.int8 : ir.int8, - torch.int16 : ir.int16, - torch.short : ir.int16, - torch.int32 : ir.int32, - torch.int : ir.int32, - torch.int64 : ir.int64, - torch.long : ir.int64, - torch.bool : ir.boolean - } - - -class IRDType2TorchDType: - - @staticmethod - def map(ir_dtype: ir.IRDType): - """ - Map the IRDtype to torch dtype - """ - assert ir_dtype in IRDType2TorchDType.kDtypeMap, f'unexpected ir_dtype {ir_dtype}' - return IRDType2TorchDType.kDtypeMap[ir_dtype] - - kDtypeMap = {val: key for key, val in DType2IRDType.kDtypeMap.items()} \ No newline at end of file diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 1c80db86..7256765b 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -8,7 +8,6 @@ from cube.ir.tensor import IRFullTensor from cube.ir.cten import IRObject, IRCell from cube.graph.parser.frame import Frame -from cube.graph.parser.dtype import DType2IRDType from cube.graph.parser.fx.mapping import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import IRDimops @@ -126,7 +125,7 @@ def parse(module: torch.fx.GraphModule, shape = input.meta['tensor_meta'].shape if len(shape) == 0: shape = [1] - dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) + dtype = input.meta['tensor_meta'].dtype val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) else: val = IRObject(input.name) @@ -138,7 +137,7 @@ def parse(module: torch.fx.GraphModule, else: # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name shape = None - dtype = DType2IRDType.map(input.meta['tensor_meta'].dtype) + dtype = input.meta['tensor_meta'].dtype val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) frame.add_var(input.name, val, graph_arg=idx) @@ -153,7 +152,7 @@ def parse_complex_out(meta_out): if isinstance(meta_out, TensorMetadata): shape = meta_out.shape assert shape == torch.Size([]), f'{meta_out}' - return IRFullTensor(shape=shape, requires_grad=meta_out.requires_grad, dtype=DType2IRDType.map(meta_out.dtype)) + return IRFullTensor(shape=shape, requires_grad=meta_out.requires_grad, dtype=meta_out.dtype) elif isinstance(meta_out, dict): ret = {} for k, v in meta_out.items(): @@ -173,7 +172,7 @@ def parse_complex_out(meta_out): if isinstance(meta_out, TensorMetadata): shape = meta_out.shape shape = FxModuleParser.shape_refine(shape) - dtype = DType2IRDType.map(meta_out.dtype) + dtype = meta_out.dtype requires_grad = meta_out.requires_grad val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=node.name) else: @@ -276,7 +275,7 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: assert len(kwargs) == 0 # add var of weight and bias into frame shape = FxModuleParser.shape_refine(prim_module.weight.size()) - dtype = DType2IRDType.map(prim_module.weight.dtype) + dtype = prim_module.weight.dtype requires_grad = prim_module.weight.requires_grad ir_weight_val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=f'{node.name}_weight') ir_weight_val.as_param() @@ -284,7 +283,7 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: frame.add_attr_content(ir_weight_val.tid, prim_module.weight) frame.add_attr_map(ir_weight_val.name, node.target+'.weight') shape = FxModuleParser.shape_refine(prim_module.bias.size()) - dtype = DType2IRDType.map(prim_module.bias.dtype) + dtype = prim_module.bias.dtype requires_grad = prim_module.bias.requires_grad ir_bias_val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=f'{node.name}_bias') ir_bias_val.as_param() @@ -381,7 +380,7 @@ def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, fram tensor_name = node.name if 'tensor_meta' in node.meta: tensor_shape = node.meta['tensor_meta'].shape - dtype = DType2IRDType.map(node.meta['tensor_meta'].dtype) + dtype = node.meta['tensor_meta'].dtype requires_grad = node.meta['tensor_meta'].requires_grad # check if existing param diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 3c326461..dfb99bfd 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -17,9 +17,10 @@ from functools import lru_cache from typing import Iterable, List, Tuple, Union, Optional, Any import copy +import torch from cube.ir.unique import IDGenerator -from cube.ir.dtype import IRDType, dtype2byte_size +from cube.ir.dtype import DTypeInfo class IRCell: @@ -547,32 +548,28 @@ class IRTensor(IRObject): _meta = ['name', '_is_attr', '_is_grad', '_requires_grad', '_dtype'] - def __init__(self, shape=None, name='tensor', dtype=IRDType.unknown, tid=None): + def __init__(self, shape=None, name='tensor', dtype=None, tid=None): super().__init__(name, tid) self._shape: Tuple[int] = () if shape is None else tuple(shape) self._cell: Optional[IRCell] = None - assert isinstance(dtype, IRDType), f'expect IRDType, get {dtype} with type {type(dtype)}' - self._dtype: IRDType = dtype + self._dtype: Optional[torch.dtype] = dtype # tensor gradient self._is_grad: bool = False self._requires_grad: bool = False self._grad: Optional[Union[IRTensor, float]] = None @property - def dtype(self) -> IRDType: - """ - Tensor data type - """ + def dtype(self) -> Optional[torch.dtype]: + """Tensor data type""" return self._dtype @dtype.setter - def dtype(self, val: IRDType): - """ - Set data type - """ - if not isinstance(val, IRDType): - raise TypeError(f"Expected IRDType but got {val}") + def dtype(self, val: Optional[torch.dtype]): + """Set data type""" + if not isinstance(val, torch.dtype): + raise NotImplementedError( + "Only support setting IRTensor with dtype of torch.dtype") self._dtype = val if isinstance(self._grad, IRTensor): self._grad._dtype = val @@ -673,7 +670,7 @@ def nelement(self) -> int: return cnt def byte_size(self) -> int: - return self.nelement() * dtype2byte_size(self.dtype) + return self.nelement() * DTypeInfo.get_byte_size(self.dtype) def backward(self) -> None: """ diff --git a/cube/ir/dtype.py b/cube/ir/dtype.py index 6a5e64f0..8040b6d3 100644 --- a/cube/ir/dtype.py +++ b/cube/ir/dtype.py @@ -1,72 +1,65 @@ -from typing import List -from enum import Enum +from typing import List, Any +import torch -class IRDType(Enum): - float64 = 'float64' - float16 = 'float16' - float32 = 'float32' - bfloat16 = 'bfloat16' - int64 = 'int64' - int32 = 'int32' - int16 = 'int16' - int8 = 'int8' - uint8 = 'uint8' - boolean = 'bool' - unknown = 'unknown' +class DTypeInfo: + """Tensor dtype information + Attributes: + bytes (Dict[Any, int]): data type -> btye size. + priority (List[torch.dtype]): the priority of dtypes for promotion + """ + bytes = { + torch.complex128: 128, + torch.complex64: 64, + torch.complex32: 32, + torch.float64: 8, + torch.float32: 4, + torch.bfloat16: 2, + torch.float16: 2, + torch.int64: 8, + torch.int32: 4, + torch.int16: 2, + torch.int8: 1, + torch.uint8: 1, + torch.bool: 1, + } -def dtype2byte_size(dtype: IRDType) -> int: - return { - IRDType.float64: 8, - IRDType.float32: 4, - IRDType.float16: 2, - IRDType.bfloat16: 2, - IRDType.int64: 8, - IRDType.int32: 4, - IRDType.int16: 2, - IRDType.int8: 1, - IRDType.uint8: 1, - IRDType.boolean: 1, - }.get(dtype, 0) + priority = [ + torch.float64, torch.float32, torch.bfloat16, torch.float16, + torch.int64, torch.int32, torch.int16, torch.int8, torch.bool + ] + @staticmethod + def get_byte_size(dtype: Any) -> int: + """Get dtype btye size""" + if dtype not in DTypeInfo.bytes: + raise NotImplementedError(f'Unknown dtype {dtype}') + return DTypeInfo.bytes[dtype] + + @staticmethod + def promote(dtypes: List[torch.dtype]) -> torch.dtype: + """Infer the promoted dtype according to dtypes. -class DTypeInferRule: - """ - Infer the output shape according to given input shapes. - This will follow the dtype promotion rule, which is same with PyTorch. + This will follow the dtype promotion rule, which is same with PyTorch. - Reference: - https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc + Reference: + https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc - complex > floating > integral > boolean - """ - @staticmethod - def infer(node, dtypes: List[IRDType]) -> IRDType: - dtypes = [dtype for dtype in dtypes if dtype != IRDType.unknown] - if IRDType.unknown in dtypes: - raise RuntimeError(f"Find an unkown dtype") - if IRDType.float32 in dtypes and IRDType.float16 in dtypes: - raise RuntimeError(f"Find node has both fp32 and fp16 inputs {node}") - # in priority: fp32 > fp16 > bool > int64 > int16 > - priority = [ - IRDType.float64, IRDType.float32, IRDType.float16, - IRDType.int64, IRDType.int32, IRDType.int16, IRDType.int8, - IRDType.boolean - ] - for dtype in priority: - if dtype in dtypes: - return dtype - return IRDType.unknown + priority: torch.float64 > torch.float32 > torch.bfloat16 > torch.float16 > + torch.int64 > torch.int32 > torch.int16 > torch.int8 > torch.bool + Args: + dtypes List[torch.dtype]: a list of dtypes -float64 = IRDType.float64 -float16 = IRDType.float16 -float32 = IRDType.float32 -bfloat16 = IRDType.bfloat16 -int64 = IRDType.int64 -int32 = IRDType.int32 -int16 = IRDType.int16 -int8 = IRDType.int8 -uint8 = IRDType.uint8 -boolean = IRDType.boolean + Returns: + the promoted dtype + """ + if not all(dtype in DTypeInfo.priority for dtype in dtypes): + raise NotImplementedError( + f"Fail to promote dtypes because one dtype " + f"in {dtypes} doesn't appear in priority list.") + dtype = None + for dtype in DTypeInfo.priority: + if dtype in dtypes: break + return dtype diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 995450b2..fc18bafc 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -5,7 +5,7 @@ from cube.ir.tensor import IRFullTensor from cube.algorithm.factory import DistAlgorithmFactory from cube.algorithm.generics import GenericDistAlgo -from cube.ir.dtype import IRDType, DTypeInferRule +from cube.ir.dtype import DTypeInfo class IRFwOperation(IRCell): @@ -51,10 +51,10 @@ def infer_dtype(self): """ itensors = [t for t in self.inputs() if isinstance(t, IRTensor)] otensors = [t for t in self.outputs() if isinstance(t, IRTensor)] - odtype = DTypeInferRule.infer(self, [t.dtype for t in itensors]) + odtype = DTypeInfo.promote([t.dtype for t in itensors]) for tensor in otensors: # in case of setting manually due to special rules - if tensor.dtype == IRDType.unknown: + if tensor.dtype is None: if isinstance(tensor, IRFullTensor): tensor.dtype = odtype else: diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index bbffede4..b6de94be 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -24,10 +24,9 @@ 3) gradient of parameters """ -from typing import List, Optional, Union, Tuple, NewType, Dict +from typing import List, Optional, Union, Tuple, NewType, Dict, Any from cube.ir.cten import IRTensor -from cube.ir.dtype import IRDType StartEnd = NewType('[start:end)', Tuple[int, int]) IdxChunk = NewType('(index, chunks)', Tuple[int, int]) @@ -256,7 +255,7 @@ class IRFullTensor(IRTensor): the sequentail execution order by its graph. """ - def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=IRDType.unknown): + def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=None): super().__init__(shape, name, dtype) @@ -340,25 +339,6 @@ def requires_grad(self, req_grad: bool): self._requires_grad = False self._grad = None - @property - def dtype(self) -> IRDType: - """ - Tensor data type - """ - return self._dtype - - @dtype.setter - def dtype(self, val: IRDType): - """ - Set data type. - It's gradient data type will also be set. - """ - if not isinstance(val, IRDType): - raise TypeError(f"Expected IRDType but got {val}") - self._dtype = val - if isinstance(self.grad, IRTensor): - self.grad.dtype = val - def as_param(self): """ Set the tensor as trainable parameter @@ -494,11 +474,11 @@ def ndims(self) -> int: return len(self.shape) @property - def dtype(self) -> IRDType: + def dtype(self) -> Any: return self.parent.dtype @dtype.setter - def dtype(self, val: IRDType): + def dtype(self, val): raise RuntimeError( f"IRSubTensor dtype must follow IRFullTensor dtype. " f"Please set it by subtensor.parent.dtype = {val}" diff --git a/cube/parallel.py b/cube/parallel.py index cac2207a..a7114231 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -13,7 +13,6 @@ from cube.graph import IRGraph from cube.graph import parser -from cube.graph.parser.dtype import DType2IRDType from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.pyfunc import IRPyFunc from cube.graph.schedule.schedplan import SchedulePlan @@ -198,7 +197,7 @@ def _gencode( shape=ir_dummy_inputs[i].size(), name=fx_input_nodes[i].target, requires_grad=True, - dtype=DType2IRDType.map(ir_dummy_inputs[i].dtype)).tosub() + dtype=ir_dummy_inputs[i].dtype).tosub() ir_dummy_inputs[i].grad = ir_dummy_inputs[i].parent.grad.tosub() else: ir_dummy_inputs[i] = IRObject( diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 2f006e0b..e1de7124 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -13,7 +13,6 @@ import cube from cube.ir.cten import IRTensor, IRObject from cube.ir.operator import IRFwOperation -from cube.graph.parser.dtype import IRDType2TorchDType from cube.graph.parser.register import CustomizedOps _logger = logging.getLogger(__name__) @@ -220,7 +219,7 @@ def extract_val(val: Union[IRObject, Any]) -> Any: for t in node.inputs(): if isinstance(t, IRTensor): shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) + dtypes.append(t.dtype) requires_grads.append(t.requires_grad) values.append(t) else: @@ -396,7 +395,7 @@ def _serialize(self, node: IRFwOperation) -> str: for t in node.inputs(): if isinstance(t, IRTensor): shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) + dtypes.append(t.dtype) # else: # shapes.append(None) # dtypes.append(type(t)) diff --git a/cube/program.py b/cube/program.py index df588204..c44be527 100644 --- a/cube/program.py +++ b/cube/program.py @@ -7,7 +7,6 @@ from cube.graph import IRGraph from cube.graph import parser -from cube.graph.parser.dtype import DType2IRDType from cube.runtime.syndata import CubeDataLoader from cube.runtime.module import CubeModule @@ -95,7 +94,6 @@ def __iter__(self): return self def __next__(self): - dtype_map = DType2IRDType def generate_output(sample): """Support complex of types: List, Tuple, torch.Tensor, object""" if isinstance(sample, tuple): @@ -108,8 +106,7 @@ def generate_output(sample): if isinstance(sample, set): return {generate_output(t) for t in sample} if isinstance(sample, torch.Tensor): - shape, dtype = list(sample.shape), dtype_map.map(sample.dtype) - tensor = IRFullTensor(shape, 'data', dtype=dtype).tosub() + tensor = IRFullTensor(list(sample.shape), 'data', dtype=sample.dtype).tosub() tensor._value = sample return tensor else: diff --git a/examples/policies/alpa/estimator.py b/examples/policies/alpa/estimator.py index 2f6b67c7..14cee328 100644 --- a/examples/policies/alpa/estimator.py +++ b/examples/policies/alpa/estimator.py @@ -10,7 +10,6 @@ from cube.ir.cten import IRTensor, IRObject, IRCell from cube.ir.operator import IRFwOperation -from cube.graph.parser.dtype import IRDType2TorchDType from cube.graph.parser.register import CustomizedOps from cube.graph.segment import IRSegment from cube.graph.function.dimops import IRDimops @@ -115,7 +114,7 @@ def get_inputs(node: IRFwOperation, train: bool) -> Tuple[List, Dict]: # create data def dummy_torch_tensor(tensor: IRTensor): """Generate dummy input tenosrs""" - dtype = IRDType2TorchDType.map(tensor.dtype) + dtype = tensor.dtype constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand return constructor(tuple(tensor.shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=tensor.requires_grad) @@ -208,7 +207,7 @@ def profile(self, node: IRFwOperation, train: bool = True, device: Optional[int] latency, memory = e, e shapes = tuple(t.shape if isinstance(t, IRTensor) else None for t in node.inputs()) - dtypes = tuple(IRDType2TorchDType.map(t.dtype) if isinstance(t, IRTensor) else None for t in node.inputs()) + dtypes = tuple(t.dtype if isinstance(t, IRTensor) else None for t in node.inputs()) error = f'{color}None{default}' print( f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} | train {train} => " @@ -285,7 +284,7 @@ def _serialize(self, node: IRFwOperation) -> str: for t in node.inputs(): if isinstance(t, IRTensor): shapes.append(t.shape) - dtypes.append(IRDType2TorchDType.map(t.dtype)) + dtypes.append(t.dtype) elif isinstance(t, IRObject): raise RuntimeError('IRObject has not been supported in _serialize') else: From df7f5b9694a860166d9b8d350fa27d8b08ba2567 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 8 Aug 2023 08:22:43 +0000 Subject: [PATCH 1461/1892] Merged PR 1710: parallel module: remove class decorator as it is confusing and not very useful parallel module: remove class decorator as it is confusing and not very useful --- cube/parallel.py | 40 +------------------ unit_tests/parallel_module/test_decorator.py | 41 -------------------- 2 files changed, 1 insertion(+), 80 deletions(-) delete mode 100644 unit_tests/parallel_module/test_decorator.py diff --git a/cube/parallel.py b/cube/parallel.py index a7114231..afacd542 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -1,4 +1,4 @@ -from typing import Callable, Any, Optional, Type, Union +from typing import Callable, Any, Dict, Optional, Type, Union from pathlib import Path import inspect import sys @@ -361,41 +361,3 @@ def parallelize( instance_name=instance_name, ) return cube_module_class if is_module_class else cube_module_class() - - -def parallel_module( - dummy_input: dict, - pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], - compute_config: ComputeConfig, - *, - dynamic_shape: bool = True, - cube_savedir: Union[str, Path] = './.cube' -) -> Callable[[Union[torch.nn.Module, Type[torch.nn.Module]]], Union[CubeModule, Type[CubeModule]]]: - """ - Work as a class decorator to convert a torch.nn.Module to CubeModule. - - Please make sure the Module's __init__ is paremeter-free. - Please note that - 1. Returned CubeModule will replace the torch.nn.Module in-place. - And all member functions/variables of original torch.nn.Module will be gone. - 2. The parameters of CubeModule will be fixed, - which means all instances of CubeModule will use the same parameters (which are from the tracing). - - Args: - dummy_input (dict): the dummy input for the module - pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy - compute_config (ComputeConfig): the environment resource - dynamic_shape (bool): whether to use dynamic shape - cube_savedir (Union[str, Path]): the directory to save generated code - """ - def wrap(module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]]) -> Union[CubeModule, Type[CubeModule]]: - return parallelize( - module_or_module_class, - dummy_input, - pas_policy, - compute_config, - dynamic_shape=dynamic_shape, - override=False, - cube_savedir=cube_savedir - ) - return wrap diff --git a/unit_tests/parallel_module/test_decorator.py b/unit_tests/parallel_module/test_decorator.py deleted file mode 100644 index a3b67be8..00000000 --- a/unit_tests/parallel_module/test_decorator.py +++ /dev/null @@ -1,41 +0,0 @@ -import tempfile - -import torch -from cube.parallel import ComputeConfig, parallel_module, CubeModule - -from .common import PASData, init_distributed -from ..launch_torchrun import launch_torchrun - - -def _decorator_worker(): - init_distributed() - with tempfile.TemporaryDirectory() as tempdir: - @parallel_module( - {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, - ComputeConfig(1, 1), - dynamic_shape=True, - cube_savedir=tempdir, - ) - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 5) - - def forward(self, x): - return self.linear(x) - assert issubclass(MyModule, CubeModule) - x = MyModule() - y = MyModule() - - MyModule()(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - # parameters from different instances will have the same value. - for p, q in zip(x.parameters(),y.parameters()): - assert torch.equal(p, q) - - -def test_decorator(): - if not torch.cuda.is_available(): - print('skip test_submodules_tp_gpu1 due to lack of cuda devices') - return - launch_torchrun(1, _decorator_worker) From 924434212dedcda79549e55c63b1b048c91f367c Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 9 Aug 2023 02:34:06 +0000 Subject: [PATCH 1462/1892] Merged PR 1711: parallel module: raise error on nested CubeModule parallel module: raise error on nested CubeModule --- cube/codegen/module/module.py | 1 - cube/parallel.py | 3 ++ unit_tests/parallel_module/test_nested.py | 56 +++++++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 unit_tests/parallel_module/test_nested.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index e52e0393..06bd8e25 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -360,7 +360,6 @@ def gen( with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.model_init_statements) if as_parallel_module: - ib.insert_body('') ib.insert_body(f'self.load_attr_content(Path(__file__).with_name("{FxModuleParser.ATTR_CONTENT_FILE}"))') ib.insert_body(f'self.load_dist_param_map(Path(__file__).with_name("{FxModuleParser.ATTR_MAP_FILE}"))') ib.insert_body('') diff --git a/cube/parallel.py b/cube/parallel.py index afacd542..9ad10de3 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -339,6 +339,9 @@ def parallelize( else: module = module_or_module_class + if any(isinstance(m, CubeModule) for m in module.modules()): + raise RuntimeError('CubeModule can not be nested.') + # TODO: copy generated files to other nodes # Currently you must use a shared file system to share the generated files (like mounted Azure Blob) # Or you can manually copy the generated files to other nodes diff --git a/unit_tests/parallel_module/test_nested.py b/unit_tests/parallel_module/test_nested.py new file mode 100644 index 00000000..30fb3b26 --- /dev/null +++ b/unit_tests/parallel_module/test_nested.py @@ -0,0 +1,56 @@ +import tempfile + +import torch +import pytest + +from cube.parallel import parallelize, ComputeConfig + +from .common import PASData, init_distributed +from ..launch_torchrun import launch_torchrun + +def _to_cube_model(module, pas, compute_config, cube_savedir): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + pas, + compute_config, + dynamic_shape=True, + cube_savedir=cube_savedir + ) + +class Module0(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + return self.linear(x) + + +def _nested_module_worker(): + init_distributed() + with tempfile.TemporaryDirectory() as tempdir: + class Module1(torch.nn.Module): + def __init__(self): + super().__init__() + self.module0 = _to_cube_model(Module0(), PASData, ComputeConfig(1, 1), cube_savedir=tempdir) + + def forward(self, x): + return self.module0(x) + + class Module2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.module1 = Module1() + def forward(self, x): + return self.module1(x) + + with pytest.raises(RuntimeError, match='CubeModule can not be nested.'): + _to_cube_model(Module2(), PASData, ComputeConfig(1, 1), cube_savedir=tempdir) + + +def test_nested_module(): + if not torch.cuda.is_available(): + print('skip test_nested_module due to lack of cuda devices') + return + launch_torchrun(1, _nested_module_worker) From d668ff9186b25be127cdaf291b4bda0240920466 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Aug 2023 07:10:18 +0000 Subject: [PATCH 1463/1892] Merged PR 1701: fix segment io bug fix segment io bug --- cube/graph/segment.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index bf5109d2..26caab7b 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -943,11 +943,13 @@ def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> I ad_producers: Dict[Tuple[IRObject,int], Set[int]] = dict() for adapter in self.select(ntype=IRAdapter): for itensor in adapter.inputs(): - if not isinstance(itensor, IRObject): continue + assert len(itensor.device) == 1 ad_consumers.setdefault((itensor, itensor.device[0]), set()).add(adapter.cid) for otensor in adapter.outputs(): - if not isinstance(otensor, IRObject): continue - ad_producers.setdefault((otensor, otensor.device[0]), set()).add(adapter.cid) + assert len(otensor.device) == 1 + # for identity adapters, we remove it from producer side + if (otensor, otensor.device[0]) not in ad_consumers: + ad_producers.setdefault((otensor, otensor.device[0]), set()).add(adapter.cid) # tensor and its device match dmatch = lambda t1, t2: t1 == t2 and t1.device == t2.device From 5e02cf07d641348fce40dfcab6ee5194f3e15a69 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 9 Aug 2023 11:49:53 +0000 Subject: [PATCH 1464/1892] Merged PR 1713: dispatch to (device) using identity dispatch to (device) using identity --- cube/graph/function/function.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index aeee818a..fb198400 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1705,7 +1705,8 @@ def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): signature = 'cube.runtime.function.to' annos = ['* -> *'] if isinstance(dtype_or_device, torch.device): - return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device) + # skip device movement as policy can determine device for the tensor. + return Identity(tensor) elif isinstance(dtype_or_device, torch.dtype): return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device) elif isinstance(dtype_or_device, IRTensor): From 5b3a075f0d7d3f79e9296cf4ba35131a08c9537d Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 10 Aug 2023 03:22:06 +0000 Subject: [PATCH 1465/1892] Merged PR 1709: quick fix small bugs 1. add slice support in GetItem 2. fix record frame catch unexpected cube path 3. add apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply mapping 4. add temp disable trace for record frame --- cube/graph/function/function.py | 2 +- .../concrete_trace_utils/concrete_tracer.py | 37 ++++++++++--------- cube/graph/parser/fx/mapping.py | 1 + unit_tests/graph/parser/test_converter.py | 21 +++++++++++ 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index fb198400..a2e0a3d0 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1723,7 +1723,7 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: # tensor slice if isinstance(obj, IRTensor): # note `None` will always - index = (index,) if isinstance(index, int) else tuple(index) + index = (index,) if isinstance(index, (int, slice)) else tuple(index) return FullSlice(obj, index) # object slice if isinstance(obj, IRObject): diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index d3606b66..d55b3d8f 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -353,14 +353,14 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] if kind == 'call_function': assert isinstance(target, Callable) fn = target - if _orig_getattr(fn, '__module__', None) != 'cube.graph.parser.concrete_trace_utils.concrete_tracer' \ + if _orig_getattr(fn, '__module__', None) != self.__module__ \ and hasattr(fn, '__globals__'): _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) return OperatorPatcherContext.patch_run(fn, *args, **kwargs) elif kind == 'call_method': self_obj, *args_tail = args fn = _orig_getattr(self_obj, target) - if _orig_getattr(fn, '__module__', None) != 'cube.graph.parser.concrete_trace_utils.concrete_tracer' \ + if _orig_getattr(fn, '__module__', None) != self.__module__ \ and hasattr(fn, '__globals__'): _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) result = fn(*args_tail, **kwargs) @@ -369,7 +369,7 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] mod = self.fetch_attr(target) if self.cpu_offload: mod.cuda() # how it works in ddp? - if _orig_getattr(mod, '__module__', None) != 'cube.graph.parser.concrete_trace_utils.concrete_tracer' \ + if _orig_getattr(mod, '__module__', None) != self.__module__ \ and hasattr(mod, '__globals__'): _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) result = OperatorPatcherContext.patch_run(mod, *args, **kwargs) @@ -458,21 +458,22 @@ def upwrapper(obj: Any): node = self.create_node(kind, target, args_, kwargs_, name, type_expr) - if self.record_frames: - # record code frame, include filename, line number, and function name - frame_record = FrameRecord(None, None, None, None) - cube_cct_path = str(Path(__file__).parent) + '/' # the cube concrete tracer path - torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path - ignore_dirs = [cube_cct_path, torch_path] - for frame in traceback.extract_stack()[-2::-1]: - if any(p in frame.filename for p in ignore_dirs): - continue - frame_record.filename = frame.filename - frame_record.lineno = frame.lineno - frame_record.line = frame.line - frame_record.name = frame.name - break - node.meta['frame_record'] = frame_record + if self.record_frames and kind != 'placeholder': + with self.do_temp_disable(True, True, True): + # record code frame, include filename, line number, and function name + frame_record = FrameRecord(None, None, None, None) + cube_path = str(Path(importlib.util.find_spec('cube').origin).parent) + '/' # the cube path + torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path + ignore_dirs = [cube_path, torch_path] + for frame in traceback.extract_stack()[-2::-1]: + if any(p in frame.filename for p in ignore_dirs): + continue + frame_record.filename = frame.filename + frame_record.lineno = frame.lineno + frame_record.line = frame.line + frame_record.name = frame.name + break + node.meta['frame_record'] = frame_record proxy = self.proxy(value_unwrapped, node) return proxy diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 8069dc93..a7df13d1 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -108,6 +108,7 @@ def exist(signature: str) -> bool: 'torch.functional.norm': function.Norm, __ftemplate('layer_norm'): function.LayerNorm, 'apex.normalization.fused_layer_norm.FusedLayerNorm': function.FusedLayerNorm, + 'apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply': function.FusedLayerNorm, # ============== runtime function ================= __tttemplate('size'): function.Size, diff --git a/unit_tests/graph/parser/test_converter.py b/unit_tests/graph/parser/test_converter.py index 7e010e79..15206922 100644 --- a/unit_tests/graph/parser/test_converter.py +++ b/unit_tests/graph/parser/test_converter.py @@ -1,4 +1,5 @@ import tempfile +import importlib from pathlib import Path import torch @@ -73,3 +74,23 @@ def forward(self, x, *args): # currently we don't support *args with pytest.raises(RuntimeError): to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + + +def test_record_codeline(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, *args): + return self.linear(x) + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + cube_path = str(Path(importlib.util.find_spec('cube').origin).parent) + '/' + + for node in fx_graph.graph.nodes: + if 'frame_record' in node.meta and cube_path in str(node.meta['frame_record']): + err_msg = f"Cube root path should not in node comment {node.meta['frame_record']}" + raise RuntimeError(err_msg) From ef6fd079199e739b683c9a999d998716580085e3 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 11 Aug 2023 07:50:56 +0000 Subject: [PATCH 1466/1892] Merged PR 1721: support more functions support more functions --- cube/graph/function/function.py | 50 ++++++++++++++++++++++++++----- cube/graph/parser/fx/mapping.py | 8 ++--- cube/runtime/function/function.py | 7 +++++ 3 files changed, 52 insertions(+), 13 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index a2e0a3d0..92383e9f 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -243,6 +243,23 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir return dimop +def Full(size, fill_value, *, out=None, dtype=None, layout=None, + device=None, requires_grad=False, signature=None): + """ + torch.full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) + """ + assert layout in (None, torch.strided), f"Not support for non-default layout" + dtype = dtype if dtype is not None else torch.get_default_dtype() + signature = 'cube.runtime.function.full' + size = tuple(size) + anno, rules = _get_creator_anno_rules( + tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) + dimop = IRDimops(Full, 'full', signature, [anno], [], rules, + size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad) + dimop.output(0).parent.dtype = dtype + return dimop + + def NewTensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False, signature=None): # note: device is ignored @@ -450,7 +467,7 @@ def Sqrt(input, *, out=None, signature=None): torch.sqrt(input, *, out=None) """ assert out is None - if not isinstance(input, IRTensor): + if not isinstance(input, IRObject): return torch.sqrt(input) if not isinstance(input, IRTensor): iv = input.value if isinstance(input, IRObject) else input @@ -460,6 +477,18 @@ def Sqrt(input, *, out=None, signature=None): return IRDimops(Sqrt, 'sqrt', signature, annos, [input]) +def RSqrt(input, *, out=None, signature=None): + assert out is None + if not isinstance(input, IRObject): + return torch.rsqrt(input) + if not isinstance(input, IRTensor): + iv = input.value if isinstance(input, IRObject) else input + return IRPyFunc(signature, [input], [IRObject(name='rsqrt', value=torch.rsqrt(iv))]) + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(RSqrt, 'rsqrt', signature, annos, [input]) + + def FloorDiv(input, other, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): @@ -811,18 +840,23 @@ def Sum(input, dim=None, keepdim=False, *, dtype=None, signature = None): def Mean(input, dim=None, keepdim=False, *, dtype=None, signature = None): """ - torch.mean(input, *, dtype=None) -> Tensor - torch.mean(input, dim, keepdim=False, *, dtype=None) -> Tensor + torch.mean(input, *, dtype=None) + torch.mean(input, dim=None, keepdim=False, *, dtype=None) + torch.Tensor.mean(input, dim=None, keepdim=False, *, dtype=None) """ assert dtype is None einput = ShapeAnno.create_shape_str(input.shape) eoutput = copy.copy(einput) - dim = (dim,) if isinstance(dim, int) else dim if dim is not None: - sort_dim = sorted(dim) - for dimidx in sort_dim[::-1]: - eoutput.pop(dimidx) - einput[dimidx] = einput[dimidx] + '^' + dims = (dim,) if isinstance(dim, int) else tuple(dim) + dims = tuple(dim % len(input.shape) for dim in dims) + for dim in sorted(dims, reverse=True): + einput[dim] += '^' + if keepdim: + eoutput[dim] = '1' + else: + eoutput.pop(dim) + dim = dims else: eoutput = ['1'] einput = [edim + '^' for edim in einput] diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index a7df13d1..f02436e3 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -56,9 +56,11 @@ def exist(signature: str) -> bool: __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, + __ttemplate('mean') : function.Mean, __ttemplate('abs'): function.Abs, __ttemplate('exp'): function.Exp, __ttemplate('sqrt'): function.Sqrt, + __ttemplate('rsqrt'): function.RSqrt, __ttemplate('clamp'): function.Clamp, __ttemplate('clamp_min'): function.ClampMin, __ttemplate('squeeze'): function.Squeeze, @@ -142,7 +144,7 @@ def exist(signature: str) -> bool: __ttemplate('zeros'): function.Zeros, __ttemplate('ones'): function.Ones, __ttemplate('tensor'): function.NewTensor, - # __ttemplate('to'): function.ToTensor, + __ttemplate('full'): function.Full, __ttemplate('rand'): function.Rand, # __ttemplate('clone'): function.Clone, @@ -178,10 +180,6 @@ def exist(signature: str) -> bool: # __ttemplate('cos'): function.Cos, # - # __ttemplate('sum') : function.Sum, - # __ttemplate('mean') : function.Mean, - # - # __ttemplate('view'): function.View, __tttemplate('view'): function.View, __tttemplate('contiguous'): function.Contiguous, diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 47c5146f..99c26beb 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -190,6 +190,13 @@ def rand(size: Tuple[int], dtype=None, requires_grad=False): requires_grad=requires_grad ) +def full(size: Tuple[int], fill_value, dtype=None, requires_grad=False): + return torch.full( + size, fill_value, dtype=dtype, requires_grad=requires_grad, + device=torch.cuda.current_device() + ) + + def arange(start: int, end: int, step: int, dtype: torch.dtype, requires_grad=False): return torch.arange(start=start, end=end, step=step, dtype=dtype, requires_grad=requires_grad, From c18ecca6c1eeb81f1a546f0e220e5ffba0729dcc Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 11 Aug 2023 07:57:38 +0000 Subject: [PATCH 1467/1892] Merged PR 1712: set all tensor gradient to none for inference set all tensor gradient to none for inference --- cube/codegen/module/module.py | 6 ------ cube/compiler.py | 8 ++------ cube/program.py | 2 ++ 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 06bd8e25..a91a2eed 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -367,12 +367,6 @@ def gen( for_block.insert_body(f'reducer.build_buckets()') ib.insert_body('') ib.insert_body(for_block.code) - else: - # switch to training or inference mode - if self.execplan.inference: - ib.insert_body('self.eval()') - else: - ib.insert_body('self.train()') cb.insert_body('') cb.insert_body(ib.code) segment_idxs =[] diff --git a/cube/compiler.py b/cube/compiler.py index 91cbdf08..48c9dbee 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -95,13 +95,9 @@ def train_iter(model, dataloader): dataloader = arg arg = SemanticDataLoader(dataloader) elif isinstance(arg, torch.Tensor): - # note: we will always set tensor to require gradient, which may - # generate backward communications in adapter. However, as long as - # the data doesn't require gradient in real runtime, the backward - # communication will not be triggered. tensor = arg arg = IRFullTensor(arg.shape, name='tensor', - requires_grad=True, + requires_grad=arg.requires_grad, dtype=arg.dtype).tosub() arg._value = tensor arg.grad = arg.parent.grad.tosub() if arg.requires_grad else None @@ -133,7 +129,6 @@ def decorator(fn: Callable) -> Callable: # run once to get model structure and tensor shape start = time.time() outputs = fn(*inputs) - Program().finalize() if outputs is None: outputs = [] elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): @@ -150,6 +145,7 @@ def decorator(fn: Callable) -> Callable: Program().set_input(pinputs) # setup program output Program().set_output(outputs) + Program().finalize() span = time.time() - start _logger.info('finish parsing iteration: {:.2f} s'.format(span)) diff --git a/cube/program.py b/cube/program.py index c44be527..49fa534f 100644 --- a/cube/program.py +++ b/cube/program.py @@ -59,7 +59,9 @@ def finalize(self): If the program doesn't do backward, set all tensors with requires_grad=False. """ graph = self.get_graph() + # inference scenario, set all gradients to none. if not any(isinstance(node, IRBpOperation) for node in graph.nodes()): + # set gradients of activation tensors to none for ftensor in graph.full_tensors(): ftensor.requires_grad = False From 527d60f52b5f8a926b1bd360ec74358455bdc52b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 11 Aug 2023 12:59:07 +0000 Subject: [PATCH 1468/1892] Merged PR 1717: refactor dataloader to disable partition refactor dataloader --- cube/algorithm/factory.py | 3 - cube/algorithm/ops/dataloader.py | 53 -------- cube/codegen/emit.py | 3 +- cube/codegen/module/module.py | 16 --- cube/codegen/schedule/schedule.py | 2 +- cube/compiler.py | 29 +---- cube/ir/operator.py | 47 +++---- cube/program.py | 58 +++------ cube/runtime/__init__.py | 1 - cube/runtime/module.py | 8 -- cube/runtime/syndata.py | 200 ------------------------------ cube/runtime/utils.py | 46 +++++++ 12 files changed, 84 insertions(+), 382 deletions(-) delete mode 100644 cube/algorithm/ops/dataloader.py delete mode 100644 cube/runtime/syndata.py create mode 100644 cube/runtime/utils.py diff --git a/cube/algorithm/factory.py b/cube/algorithm/factory.py index 602df88e..37c8a12d 100644 --- a/cube/algorithm/factory.py +++ b/cube/algorithm/factory.py @@ -59,9 +59,6 @@ def algorithms(self, op, tag = None): def _load_predefined_algos(self): - import cube.algorithm.ops.dataloader as dataloader - self.register(dataloader.IRDataOperation, dataloader.DPDataLoader, tag='data') - import cube.algorithm.ops.dimops as dimops self.register(dimops.IRDimops, dimops.DimSplitEinops, tag='dim') diff --git a/cube/algorithm/ops/dataloader.py b/cube/algorithm/ops/dataloader.py deleted file mode 100644 index 4914b381..00000000 --- a/cube/algorithm/ops/dataloader.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import List -import copy - -from cube.algorithm.generics import GenericDistAlgo -from cube.ir.operator import IRDataOperation - - -class DPDataLoader(GenericDistAlgo): - - def __init__(self, node: IRDataOperation): - - if not isinstance(node, IRDataOperation): - raise TypeError(f"f{type(node)} can not be transformed to {type(self)}") - super().__init__(node) - - def satisfy(self, num: int): - """ - Check whether the condition satisfies. - - @param num int: number of chunks to partition - """ - - node: IRDataOperation = self.node - dims: List[int] = node.get_batch_dims() - # check batch size - all_batch_size = set([output.shape[dim] for dim, output in zip(dims, node.outputs())]) - # batch size not same -- indicate a scientific model - if len(all_batch_size) != 1: - return False - for dim, output in zip(dims, node.outputs()): - if output.shape[dim] % num != 0: - return False - return True - - def instantiate(self, num: int): - if not self.satisfy(num): - return None - node: IRDataOperation = self.node - dims: List[int] = node.get_batch_dims() - - outputs = list() - for dim, output in zip(dims, node.outputs()): - output = output.split_dim(dim, num) - outputs.append(output) - - nodes = list() - for outs in zip(*outputs): - node = IRDataOperation( - data_num=len(outs), batch_dims=copy.copy(dims)) - for idx, out in enumerate(outs): - node.set_output(idx, out) - nodes.append(node) - return nodes diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 09284ec2..6b6d2886 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -130,7 +130,8 @@ class FuncEmission(CodeEmission): @staticmethod def emit_dataloader(node: IRDataOperation) -> List[str]: - return ['next(dataloader)'] + outputs = FuncEmission.return_name(node.outputs()) + return [f'{outputs} = next({FuncEmission.tensor_name(node.input(0))})'] @staticmethod def emit_fnode(node: IRFwOperation, prefix_attr: str = None) -> List[str]: diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index a91a2eed..c03343bf 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -329,7 +329,6 @@ def gen( elif isinstance(node, IRBpOperation): continue elif isinstance(node, IRDataOperation): - self.init_batchsize(node) continue else: raise RuntimeError(f"Un-recognized IRCell type: {type(node)}") @@ -529,21 +528,6 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: add_code = reducer_add.format(reducer=reducer_name) self.model_init_statements.append(add_code) - def init_batchsize(self, node: IRDataOperation): - """ - Emit batch size declare - """ - signature = 'self.set_batch_size({bs})' - bs = [t.shape[dim] for t, dim in zip(node.outputs(), node.get_batch_dims()) if dim is not None] - bs = set(bs) - if len(bs) > 1: - _logger.warning(f'Find Heterogenous batch size {bs}. Keep output to be same with semantic dataloder.') - bs = list(bs)[0] if len(bs) == 1 else None - assert self.batch_size is None or self.batch_size == bs, f"Not match for batch size: {self.batch_size} != {bs}" - self.model_init_statements.append(signature.format(bs=bs)) - self.model_init_statements.append('') - self.batch_size = bs - @staticmethod def emit_segment(segment: IRSegment) -> List[str]: """ diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index 9ac9cc36..5194a2c0 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -170,7 +170,7 @@ def emit_node(node: IRCell, force_no_grad: bool = False) -> List[str]: ) elif isinstance(unwrap_node, IRDataOperation): - code = f'{outputs} = next(dataloader)' + code = ScheduleCodeGen.emit_dataloader(unwrap_node)[0] elif isinstance(unwrap_node, IRAdapter): code = asign.format( diff --git a/cube/compiler.py b/cube/compiler.py index 48c9dbee..1fcd3a92 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -24,7 +24,6 @@ from cube.codegen import ModuleCodeGen, ScheduleCodeGen from cube.runtime.device import DeviceGroup -from cube.runtime.syndata import CubeDataLoader from cube.program import Program, SemanticDataLoader, SemanticModel from cube.flags import CompileFlag @@ -86,14 +85,11 @@ def train_iter(model, dataloader): model.save_content = load_content model.dynamic_shape = model_dynamic_shape - dataloader = None inputs = [model] for arg in args: assert not isinstance(arg, (torch.nn.Module, SemanticModel)), f"Only one model can be input for compile" - if isinstance(arg, (torch.utils.data.Dataset, CubeDataLoader)): - assert dataloader is None - dataloader = arg - arg = SemanticDataLoader(dataloader) + if isinstance(arg, torch.utils.data.DataLoader): + arg = SemanticDataLoader(arg) elif isinstance(arg, torch.Tensor): tensor = arg arg = IRFullTensor(arg.shape, name='tensor', @@ -112,8 +108,6 @@ def decorator(fn: Callable) -> Callable: if not override and os.path.exists(filename.format(myrank)): filename = filename.format(myrank) - # TODO: set batch size - _logger.warning('dataloader batch size stay as default.') # load module code _logger.info(f'loading existed module from {filename} ...') model.load_module(filename) @@ -139,7 +133,7 @@ def decorator(fn: Callable) -> Callable: if isinstance(input, SemanticModel): pinputs.append('model') elif isinstance(input, SemanticDataLoader): - pinputs.append('dataloader') + pinputs.append(input.object) else: pinputs.append(input) Program().set_input(pinputs) @@ -253,23 +247,6 @@ def decorator(fn: Callable) -> Callable: torch.distributed.barrier() model.dummy_input = None - # set dataloder batch size (serialize output) - if dataloader is not None: - bs = model.get_gen_module().get_batch_size() - print_each_rank(f'setting batch size to: {bs}', logger=_logger) - if torch.distributed.is_initialized(): - for rank in range(torch.distributed.get_world_size()): - if rank == torch.distributed.get_rank(): - if bs is not None and dataloader is not None: - dataloader.set_batch_size(bs) - torch.distributed.barrier() - else: - if bs is not None and dataloader is not None: - dataloader.set_batch_size(bs) - - if torch.distributed.is_initialized(): - torch.distributed.barrier() - # load temporal schedule print_each_rank(f'loading generated schedule from {filename} ...', logger=_logger) return cube.load_default_schedule(filename) diff --git a/cube/ir/operator.py b/cube/ir/operator.py index fc18bafc..ff8faa43 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -202,13 +202,21 @@ def __repr__(self) -> str: class IRDataOperation(IRCell): + """Dataloader operator + + The output of a dataloader operator is a tuple of (IRObject,). + """ - def __init__(self, data_num: int, batch_dims: Tuple[int], name='dataloader'): - if len(batch_dims) != data_num: - raise RuntimeError("Expected each output data has a specified batch dim") - signature = 'dataloader.__next__' - super().__init__(name, signature, 0, data_num) - self.batch_dims = tuple(batch_dims) + def __init__(self, input: IRObject, outputs: Tuple[IRObject], name='dataloader'): + signature = 'next' + super().__init__(name, signature, 1, len(outputs)) + if not isinstance(input, IRObject): + raise TypeError(f"input should be an IRObject, but got {type(output)}") + self.set_input(0, input) + for idx, output in enumerate(outputs): + if not isinstance(output, IRObject): + raise TypeError(f"output should be an IRObject, but got {type(output)}") + self.set_output(idx, output) def replicate(self): """ @@ -229,38 +237,11 @@ def replicate(self): cpy.clear_successor() return cpy - def get_batch_dims(self): - return copy.copy(self.batch_dims) - def infer_shape(self): """ Infer output value shape """ return True - - def algorithms(self, tag: Optional[str] = None) -> Union[Tuple[GenericDistAlgo], GenericDistAlgo]: - """ - Get algorithm from algorithm factory - - @param tag Optional[str]: the queried tag (default None for all) - - @return algorithm(s) Union[Tuple[GenericDistAlgo], GenericDistAlgo]: - If None (default), return all possible algorithms. - Otherwise, return the specified one. - """ - factory = DistAlgorithmFactory() - if tag is None: - templates = list() - if factory.exist(type(self)): - templates = factory.algorithms(type(self)) - algos = list() - for template in templates: - algos.append(template(self)) - return algos - else: - assert factory.exist(type(self), tag), f"Node {self} doesn't have transformation algorithm tag: {tag}" - template = factory.algorithms(type(self), tag) - return template(self) def __repr__(self): dscp = (f"DataLoader{self._id}-{self.device}(outputs={self.outputs()})") diff --git a/cube/program.py b/cube/program.py index 49fa534f..bdd16d4b 100644 --- a/cube/program.py +++ b/cube/program.py @@ -8,13 +8,13 @@ from cube.graph import IRGraph from cube.graph import parser -from cube.runtime.syndata import CubeDataLoader from cube.runtime.module import CubeModule from cube.runtime.device import DeviceGroup from cube.utils import load_model import torch +import torch.utils.data as data class Program: @@ -74,23 +74,18 @@ def __repr__(self): class SemanticDataLoader: - def __init__(self, dataloader: CubeDataLoader): - if not isinstance(dataloader, CubeDataLoader): - raise TypeError("Expected data loader derived from CubeDataLoader") - self.dataloader: CubeDataLoader = iter(dataloader) - - def get_batch_dims(self) -> Tuple[Optional[int]]: - return tuple(self.dataloader.get_batch_dims()) - - def get_batch_size(self) -> int: - return self.dataloader.get_batch_size() - - def set_batch_size(self, bs: int): - self.dataloader.set_batch_size(bs) - return + def __init__(self, dataloader: data.DataLoader): + """ + Create semantic dataloader which will produces IRDataOperation + when calling `next`. - def get_runtime_sample(self): - return next(self.dataloader) + Args: + dataloader (torch.utils.data.DataLoader): torch dataloader + """ + if not isinstance(dataloader, data.DataLoader): + raise TypeError("Expected data loader derived from torch.utils.data.DataLoader") + self.dataloader: data.DataLoader = iter(dataloader) + self.object = IRObject(name='dataloader', value=self.dataloader) def __iter__(self): return self @@ -102,35 +97,18 @@ def generate_output(sample): return tuple(generate_output(t) for t in sample) if isinstance(sample, list): return list(generate_output(t) for t in sample) - if isinstance(sample, dict): - assert all(isinstance(key, (str, int)) for key in sample.keys()) - return {key:generate_output(val) for key, val in sample.items()} - if isinstance(sample, set): - return {generate_output(t) for t in sample} if isinstance(sample, torch.Tensor): tensor = IRFullTensor(list(sample.shape), 'data', dtype=sample.dtype).tosub() tensor._value = sample return tensor - else: - return IRObject('data', value=sample) - + return IRObject('data', value=sample) + # get dataloader sample sample = next(self.dataloader) + # turn sample into IRObjects outputs = generate_output(sample) - - # create dataloader - if isinstance(outputs, (tuple, list)): - data_num = len(outputs) - elif isinstance(outputs, dict): - data_num = len(outputs.keys()) - else: - data_num = 1 - - data_op = IRDataOperation(data_num=data_num, batch_dims=self.get_batch_dims()) - if not isinstance(outputs, tuple): - data_op.set_output(0, outputs) - else: - for idx, t in enumerate(outputs): - data_op.set_output(idx, t) + # create dataloader operation + outputs = outputs if isinstance(outputs, (tuple, list)) else (outputs,) + data_op = IRDataOperation(self.object, outputs) Program().add_node(data_op) return outputs diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index e883bb2c..898c5fec 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -1,7 +1,6 @@ from cube.runtime import executor from cube.runtime import device from cube.runtime import adapter -from cube.runtime import syndata from cube.runtime import resource from cube.runtime import module from cube.runtime import function diff --git a/cube/runtime/module.py b/cube/runtime/module.py index c2662c5b..4f8f9ba1 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -21,7 +21,6 @@ def __init__(self): super().__init__() self._reducers: List[Reducer] = list() self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() - self._batch_size: Optional[int] = None @property def reducers(self): @@ -85,13 +84,6 @@ def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: i def get_full_map(self): return self._fullmap - def set_batch_size(self, bs: Optional[int]): - assert (bs is None) or (isinstance(bs, int) and bs > 0) - self._batch_size = bs - - def get_batch_size(self) -> Optional[int]: - return self._batch_size - def load_attr_content(self, filename: str): with torch.no_grad(): full = torch.load(filename) diff --git a/cube/runtime/syndata.py b/cube/runtime/syndata.py deleted file mode 100644 index 0bd53343..00000000 --- a/cube/runtime/syndata.py +++ /dev/null @@ -1,200 +0,0 @@ -r""" -Synthetic Data Loader -""" - -from typing import Any, List, Optional, Tuple, Union -import logging - -import torch - -_logger = logging.getLogger(__name__) - - -class CubeDataLoader: - r""" - Cube Dataloader. - User should provide a dataloader to runtime with at least these functionalities: - - 1) `__iter__()`: get the dataloder iterator - 2) `__next__()` get the next batch of data - 3) `get_batch_size()` return the batch size (int) - 4) `set_batch_size(bs)` reset the batch size (int) - 5) `get_batch_dims(self)` get the batch dimension of each output data - """ - def __init__(self, batch_size: int, batch_dims: Tuple[Optional[int]]): - """ - Create a dataloader for cube runtime - - @param batch_size int: dataloader batch size - @param batch_dims Tuple[Optional[int]]: the batch dimension of each output data, - None indicates the output (tensor or non-tensor) doesn't have the batch dimension. - """ - self.batch_size: int = batch_size - self.batch_dims: Tuple[Optional[int]] = batch_dims - - def __iter__(self): - raise NotImplementedError("Required implementation for derived class") - - def __next__(self): - raise NotImplementedError("Required implementation for derived class") - - def get_batch_size(self) -> int: - """ - get batch size - """ - return self.batch_size - - def set_batch_size(self, batch_size: int): - """ - set batch size - """ - raise NotImplementedError("Required implementation for derived class") - - def get_batch_dims(self) -> Tuple[Optional[int]]: - return tuple(self.batch_dims) - - -class SciLoopVariables(CubeDataLoader): - r"""Scientific loop variable loader - """ - def __init__(self, variables: List[Any], constants: List[Any]): - shapes = [] - dtypes = [] - for var in variables + constants: - if torch.is_tensor(var): - shapes.append(list(var.size()) if len(var.size()) != 0 else [1,]) - dtypes.append(var.dtype) - else: - shapes.append([1,]) - dtypes.append(type(var)) - super().__init__(0, [None] * len(shapes)) - self.variables = list() - self.constants = list() - for var in variables: - if torch.is_tensor(var) and var.device != torch.cuda.current_device(): - var = var.cuda() - self.variables.append(var) - for const in constants: - if torch.is_tensor(const) and const.device != torch.cuda.current_device(): - const = const.cuda() - self.constants.append(const) - - def get_batch_size(self) -> int: - return 0 - - def set_batch_size(self, batch_size: int): - return - - def __iter__(self): - return self - - def __next__(self): - if len(self.variables) + len(self.constants) == 1: - return (self.variables + self.constants)[0] - return tuple(self.variables + self.constants) - - def update(self, variables: Optional[List[Any]] = None, constants: Optional[List[Any]] = None): - """ - Update variables and constants - """ - if variables is not None: - if len(variables) != len(self.variables): - raise ValueError(f"Expected {len(self.shapes)} but only got {len(variables)} varaibales to update") - for var, expected_shape in zip(variables, self.shapes): - expected_shape = tuple(expected_shape) - if not torch.is_tensor(var) and expected_shape != (1,): - raise ValueError(f"Non-tensor variable: Expected shape is (1,)") - if torch.is_tensor(var) and tuple(var.size()) != expected_shape: - raise ValueError(f"Shape update mismatch: var: {var.size()} != expected: {expected_shape}") - self.variables = variables - if constants is not None: - if len(constants) != len(self.constants): - raise ValueError(f"Expected {len(self.shapes)} but only got {len(constants)} varaibales to update") - for const, expected_shape in zip(constants, self.shapes): - expected_shape = tuple(expected_shape) - if not torch.is_tensor(const) and expected_shape != (1,): - raise ValueError(f"Non-tensor constant: Expected shape is (1,)") - if torch.is_tensor(const) and tuple(const.size()) != expected_shape: - raise ValueError(f"Shape update mismatch: const: {const.size()} != expected: {expected_shape}") - self.constants = constants - - -class SynDataLoader(CubeDataLoader): - r""" - Synthetic dataloader to produce tensors - for given shapes, dtypes. - """ - def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype] = None, - batch_dims: Tuple[int] = None, names: Tuple[str] = None, append_args=None, device=None): - """ - shapes Tuple[Tuple[int]]: - The shape for each data - dtypes Tuple[torch.dtype]: - The dtype for each data (Default None: use torch.float32) - batch_dims Tuple[int]: - The batch dimension of each data (Default None: dimension 0 is the batch dim) - """ - if batch_dims is None: - batch_dims = tuple([0] * len(shapes)) - if dtypes is None: - dtypes = tuple([torch.float] * len(shapes)) - self.shapes = tuple([list(shape) for shape in shapes]) - self.dtypes = dtypes - batch_size = shapes[0][batch_dims[0]] - super().__init__(batch_size, batch_dims) - self.names = names - self.append_args=append_args - self.device = device if device else torch.cuda.current_device() - self.buffer: Union[torch.Tensor, Tuple[torch.Tensor]] = None - datas = self.random_sample() - self.set_output(datas) - - def __iter__(self): - return self - - def __next__(self): - if self.names is not None: - assert len(self.names) == len(self.buffer) - ret_dict = dict(zip(self.names, self.buffer)) - if self.append_args is not None: - ret_dict = ret_dict.update(self.append_args) - return ret_dict - else: - return self.buffer - - def random_sample(self) -> Tuple[torch.Tensor]: - torch.manual_seed(0) - datas = [] - for shape, dtype in zip(self.shapes, self.dtypes): - if shape and all(isinstance(dim, int) for dim in list(shape)): - datas.append( - torch.rand( - shape, - device=self.device, - requires_grad=False).to(dtype) - if torch.is_floating_point(torch.zeros([1], dtype=dtype)) else - torch.ones( - shape, - device=self.device, - requires_grad=False - ).to(dtype) - ) - else: - datas.append(dtype()) - return tuple(datas) - - def set_output(self, datas: Union[torch.Tensor, Tuple[torch.Tensor]]): - datas = (datas,) if torch.is_tensor(datas) else tuple(datas) - if len(datas) == 0: - self.buffer = None - else: - self.buffer = datas #will not convert like: datas[0] if len(datas) == 1 else datas - - def set_batch_size(self, batch_size: int): - self.batch_size = batch_size - for shape, dim in zip(self.shapes, self.batch_dims): - shape[dim] = batch_size - rank = 0 if not torch.distributed.is_initialized() else torch.distributed.get_rank() - _logger.info(f'rank [{rank}]: set batch size to {batch_size}. dataloader outputs change to: {self.shapes}') - datas = self.random_sample() - self.set_output(datas) diff --git a/cube/runtime/utils.py b/cube/runtime/utils.py new file mode 100644 index 00000000..828ffa2a --- /dev/null +++ b/cube/runtime/utils.py @@ -0,0 +1,46 @@ +r"""Runtime Utilities""" + +from typing import Any +import logging + +import torch.utils.data as data + +_logger = logging.getLogger(__name__) + + +def create_dummy_dataloader(sample: Any, + batch_size: int, drop_last=True, + **dataloader_config) -> data.DataLoader: + """Create a dummy dataloader + + The function is mainly used for performance test. + + Args: + sample (Any): a data sample without batch size dimension. + The sample can be a single tensor/object or tuple/list of tensors/objects + batch_size (int): batch size + drop_last (bool): whether to drop last batch to make batch size consistent. + dataloader_config (dict): kwargs for dataloader initialization. + + Returns: + dataloader (torch.utils.data.DataLoader): + returns + """ + + class DummyDataset(data.Dataset): + + def __init__(self, sample: Any): + + self.sample = sample + + def __len__(self): + return 1024000 + + def __getitem__(self, key: int): + return self.sample + + dataset = DummyDataset(sample) + dataloader = data.DataLoader( + dataset, batch_size=batch_size, drop_last=drop_last, + **dataloader_config) + return dataloader From 25bf019f76b946a371d930d58228220908d6bb3b Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 14 Aug 2023 06:09:20 +0000 Subject: [PATCH 1469/1892] Merged PR 1727: add megatron gpt example --- examples/megatron_gpt/.gitignore | 3 + examples/megatron_gpt/README.md | 27 ++++++++ examples/megatron_gpt/convert.py | 76 ++++++++++++++++++++ examples/megatron_gpt/gpt_model.py | 34 +++++++++ examples/megatron_gpt/parallel.py | 64 +++++++++++++++++ examples/megatron_gpt/run.sh | 108 +++++++++++++++++++++++++++++ 6 files changed, 312 insertions(+) create mode 100644 examples/megatron_gpt/.gitignore create mode 100644 examples/megatron_gpt/README.md create mode 100644 examples/megatron_gpt/convert.py create mode 100644 examples/megatron_gpt/gpt_model.py create mode 100644 examples/megatron_gpt/parallel.py create mode 100644 examples/megatron_gpt/run.sh diff --git a/examples/megatron_gpt/.gitignore b/examples/megatron_gpt/.gitignore new file mode 100644 index 00000000..cf862087 --- /dev/null +++ b/examples/megatron_gpt/.gitignore @@ -0,0 +1,3 @@ +*log.txt +*.pt +*.cube \ No newline at end of file diff --git a/examples/megatron_gpt/README.md b/examples/megatron_gpt/README.md new file mode 100644 index 00000000..4a20fdfc --- /dev/null +++ b/examples/megatron_gpt/README.md @@ -0,0 +1,27 @@ +# Train Megatron-GPT with Cube + +This example demonstrates how to train a GPT model from Megatron-ML using Cube. The process consists of three main steps: +1. Instantiate the model and trace it to an fx.Graph. Then, convert the fx.Graph to a Cube graph. +2. Compile the Cube graph into Python code by **data parallel** on 2 devices. +3. Train the GPT model using the compiled code in Fairseq. + +At first, clone the Megatron-LM and checkpoint to the devcube branch, gpt model in this branch is a single device version. + +```console +git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Megatron-LM +cd Megatron-LM +git checkout devcube +# cd MagicCube dir +cd ../MagicCube/examples/megatron_gpt +# download gpt2-vocab.json and gpt2-merges.txt +wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json +wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt +``` + +The following three commands correspond to the above three steps: + +```console +bash run.sh trace +bash run.sh compile +bash run.sh run +``` diff --git a/examples/megatron_gpt/convert.py b/examples/megatron_gpt/convert.py new file mode 100644 index 00000000..dc0f3dfb --- /dev/null +++ b/examples/megatron_gpt/convert.py @@ -0,0 +1,76 @@ +# 1. build model +from gpt_model import build_model, GeLUFunction +model = build_model() + +# 2. register customized op +from cube.graph.parser.register import register +register('* h, h -> * h')(GeLUFunction.apply) + +# 3. build semantic model +from cube import SemanticModel +smodel = SemanticModel(model) + +# 4. set dummy input +import torch +batch_size = 16 +seq_len = 128 +dict_len = 50000 +smodel.dummy_input={ + 'src_tokens': torch.randint(0, dict_len, (batch_size, seq_len)), + 'target': torch.randint(0, dict_len, (batch_size, seq_len)), + 'ntokens': 128, +} + +from cube.graph.function import IRObject +from cube.ir import IRFullTensor + +src_tokens = IRFullTensor(shape=[batch_size, seq_len], + name='src_tokens', + dtype=torch.int).tosub() + +target = IRFullTensor(shape=[batch_size, seq_len], + name='target', + dtype=torch.int).tosub() + +ntokens = IRObject(name='ntokens') + +# 5. convert to graph +from cube.graph.segment import IRSegment +from cube.program import Program + +from torch.autograd.graph import saved_tensors_hooks + +class no_save_tensor_hook(saved_tensors_hooks): + def __init__(self): + + def pack(x): + return None + + def unpack(x): + raise RuntimeError("not expecting backward to be called on this tensor") + + super().__init__(pack, unpack) + +Program().clear() + +with no_save_tensor_hook(): + outputs = smodel(src_tokens, target, ntokens) +outputs[0].backward() + +Program().finalize() +Program().set_input([src_tokens, target, ntokens]) + +if outputs is None: + outputs = [] +elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = [outputs] +Program().set_output(outputs) + +graph = Program().get_graph() + +# 6. save graph +graph.dump('megatron_gpt2.cube') + +for node in graph._nodes: + if isinstance(node, IRSegment): + print(node.debug_tensor_map_str()) diff --git a/examples/megatron_gpt/gpt_model.py b/examples/megatron_gpt/gpt_model.py new file mode 100644 index 00000000..aea1f412 --- /dev/null +++ b/examples/megatron_gpt/gpt_model.py @@ -0,0 +1,34 @@ +import torch + +from megatron import initialize_megatron +from megatron.training import get_args, ModelType +from megatron.arguments import core_transformer_config_from_args +from megatron.model import GPTModel +from megatron.model.fused_bias_gelu import GeLUFunction +from megatron.core.tensor_parallel import VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear + +class GPT2Model(GPTModel): + def __init__(self, config, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True): + super().__init__(config, num_tokentypes, parallel_output, pre_process, post_process) + + def forward(self, src_tokens, target, ntokens): + position_ids = torch.arange(0, src_tokens.shape[1], 1).unsqueeze(0).expand_as(src_tokens) + attention_mask = (torch.tril(torch.ones(1, 1, src_tokens.shape[1], src_tokens.shape[1])) < 0.5).bool() + res = super().forward(src_tokens, position_ids, attention_mask, labels=target) + return res, ntokens, {'loss': res, 'ntokens': ntokens, 'nsentences': src_tokens.shape[0], 'sample_size': ntokens} + + +def build_model() -> GPT2Model: + initialize_megatron(extra_args_provider=None, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) + get_args().model_type = ModelType.encoder_or_decoder + config = core_transformer_config_from_args(get_args()) + model = GPT2Model( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True + ) + + return model diff --git a/examples/megatron_gpt/parallel.py b/examples/megatron_gpt/parallel.py new file mode 100644 index 00000000..adae5b09 --- /dev/null +++ b/examples/megatron_gpt/parallel.py @@ -0,0 +1,64 @@ +import os +plan_ngpus = int(os.environ['PLAN_NGPUS']) +runtime_ngpus = int(os.environ['CUBE_SCALING_FACTOR']) + +# 1. load graph +from cube.graph import IRGraph +graph = IRGraph.load('megatron_gpt2.cube') + +# 2. register customized op +from gpt_model import GeLUFunction +from cube.graph.parser.register import register +register('* h, h -> * h')(GeLUFunction.apply) + +# 3. parallel model +from fairseq.cube.pas_policies import PASData, PASRandomSPMD +graph = PASData(graph, plan_ngpus) + +for node in graph.nodes(flatten=True): + from cube.graph.function.anchor import IRGraphAnchor + from cube.graph.function.pyfunc import IRPyFunc + # skip graph anchor and multiref: they will be removed or replaced by system + if isinstance(node, IRGraphAnchor) or node.name == 'multiref': + graph.assign(node, 0) + if isinstance(node, IRPyFunc): + graph.assign(node, 0) + if len(node.device) == 0: + raise RuntimeError(f"Node {node} device is not set") +from cube.graph.gener.gen import IRAdapterGener +graph = IRAdapterGener.gen(graph, cost_fn=None) +if graph.sched is not None: + graph.sched.apply() + print(graph.sched) + +from cube.graph.schedule.schedplan import SchedulePlan +from cube.execplan import ExecutionPlan +if isinstance(graph.sched, SchedulePlan): + execplan = ExecutionPlan.from_schedplan(graph.sched) +else: + execplan = ExecutionPlan.from_graph(graph) +# execplan.visualize('plan.png') +from cube.execplan.planpass.fusion import DiffFusion +execplan = DiffFusion.apply(execplan) +# plan pass for computation grouping +from cube.execplan.planpass.grouping import Grouping +if not graph.sched: + execplan = Grouping.apply(execplan) + +# 4. generate code +from cube.codegen import ModuleCodeGen, ScheduleCodeGen +filename = 'gencode{}.py' +_runtime_ngpus = None if plan_ngpus == runtime_ngpus else runtime_ngpus +assert len(execplan.graph.device) == plan_ngpus, f"{execplan.graph.device}" +mgener = ModuleCodeGen(execplan, scale_ndevs=_runtime_ngpus) +sgener = ScheduleCodeGen(execplan, scale_ndevs=_runtime_ngpus) +for rank in range(runtime_ngpus): + fname = filename.format(rank) + # generate spatial module code + mgener.gen(rank, outfile=fname, attach=False) + # generate temporal schedule code + sgener.gen( + device = rank, + outfile = fname, + attach=True + ) diff --git a/examples/megatron_gpt/run.sh b/examples/megatron_gpt/run.sh new file mode 100644 index 00000000..a8b0f6af --- /dev/null +++ b/examples/megatron_gpt/run.sh @@ -0,0 +1,108 @@ +# Usage: bash run.sh mode = {trace, compile, run, all} + +MEGATRON_PATH=/home/ningshang/Megatron-LM +TENSORBOARD_DIR=/data/ningshang/megatron_gpt +DATA_PATH=/data/ningshang/torchscale_data +TORCHSCALE_PATH=/home/ningshang/anaconda3/envs/cube/lib/python3.10/site-packages/examples/fairseq +FAIRSEQ_PATH=/home/ningshang/Fairseq + +export USE_TORCHFX=1 +export LOG_PARSER=1 +export DISABLE_CODE_LINE_INFO=0 + +PLAN_NGPUS=1 +CUBE_SCALING_FACTOR=2 + + +# check arg num +if [ $# -ne 1 ] +then + echo "Usage: bash run.sh mode = {trace, compile, run}" + exit 1 +fi + +MODE=$1 + +if [ $MODE = "trace" ] +then + VOCAB_FILE=./gpt2-vocab.json + MERGE_FILE=./gpt2-merges.txt + GPT_ARGS=" + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --seq-length 128 \ + --max-position-embeddings 128 + " + USELESS_ARGS=" + --micro-batch-size 4 \ + --global-batch-size 8 \ + --lr 0.00015 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --lr-warmup-fraction .01 \ + --clip-grad 1.0 + " + DATA_ARGS=" + --data-path $DATA_PATH/train \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --data-impl mmap \ + --split 949,50,1 + " + OUTPUT_ARGS=" + --log-interval 100 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 + " + PYTHONPATH=.:PYTHONPATH:$TORCHSCALE_PATH:$MEGATRON_PATH CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --nnodes 1 convert.py $GPT_ARGS $DATA_ARGS $OUTPUT_ARGS $USELESS_ARGS >trace_log.txt 2>&1 +elif [ $MODE = "compile" ] +then + PLAN_NGPUS=$PLAN_NGPUS CUBE_SCALING_FACTOR=$CUBE_SCALING_FACTOR PYTHONPATH=.:PYTHONPATH:$TORCHSCALE_PATH:$MEGATRON_PATH:$FAIRSEQ_PATH python parallel.py >compile_log.txt 2>&1 +elif [ $MODE = "run" ] +then + PLAN_NGPUS=$PLAN_NGPUS PYTHONPATH=.:PYTHONPATH:$TORCHSCALE_PATH:$MEGATRON_PATH torchrun \ + --nproc_per_node=2 \ + --nnodes=1 \ + $TORCHSCALE_PATH/train.py $DATA_PATH \ + --num-workers 2 \ + --activation-fn gelu \ + --share-decoder-input-output-embed \ + --arch lm_base_125M \ + --validate-interval-updates 1000 \ + --save-interval-updates 1000 \ + --log-interval 1 \ + --task language_modeling \ + --sample-break-mode none \ + --tokens-per-sample 128 \ + --optimizer adam \ + --adam-betas "(0.9,0.999)" \ + --adam-eps 1e-08 \ + --clip-norm 1.0 \ + --lr 6.0e-4 \ + --lr-scheduler polynomial_decay \ + --warmup-updates 230 \ + --dropout 0.0 \ + --attention-dropout 0.0 \ + --weight-decay 0.01 \ + --batch-size 16 \ + --update-freq 1 \ + --required-batch-size-multiple 1 \ + --total-num-update 5000 \ + --max-update 5000 \ + --seed 1234 \ + --ddp-backend=legacy_ddp \ + --cube-scaling-factor $CUBE_SCALING_FACTOR \ + --subln --xpos-rel-pos \ + --parallel-backend=cube \ + --compile=run_only \ + --tensorboard-logdir $TENSORBOARD_DIR \ + --save-dir=/data/ningshang/checkpoint >run_log.txt 2>&1 +else + echo "Usage: bash run.sh mode = {trace, compile, run}" + exit 1 +fi From 6b5a187ef36b7ad4e7a2c0fa64d2eae23fff9ee0 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 14 Aug 2023 06:09:30 +0000 Subject: [PATCH 1470/1892] Merged PR 1725: move inputs back to cpu after call if cpu offload move inputs back to cpu after call if cpu offload --- cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index d55b3d8f..19e1ca4a 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -385,6 +385,9 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] with self.do_temp_disable(call=True): result = run(kind, target, args, kwargs) if self.cpu_offload: + # move back arguments to cpu if cpu_offload + args = tree_map(to_cpu, args) + kwargs = tree_map(to_cpu, kwargs) if isinstance(result, torch.Tensor): result = result.cpu() elif isinstance(result, (list, dict, tuple)): From c2628d993ed0c1832f91e92e9a1474eee31cb0d5 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 15 Aug 2023 02:55:45 +0000 Subject: [PATCH 1471/1892] Merged PR 1724: add dce ignore function api add dce ignore api --- cube/graph/parser/converter.py | 13 ++++- .../concrete_trace_utils/concrete_tracer.py | 48 ++++++++++--------- .../parser/fx/concrete_trace_utils/utils.py | 15 +++--- 3 files changed, 45 insertions(+), 31 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 1e079806..a5589f18 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, Optional, List, Union +import inspect import logging from pathlib import Path -import os +from typing import Any, Dict, Union + from cube.ir.tensor import IRFullTensor from cube.graph.parser.register import CustomizedOps @@ -11,6 +12,8 @@ from cube.graph.parser.fx.parser import FxModuleParser from cube.graph.parser.fx.concrete_trace_utils import concrete_trace +import cube.runtime.function as cube_rt_function + import torch import torch.fx @@ -36,6 +39,11 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: # get registered leaf function autowrap_funcs = [CustomizedOps.kOpRuntime.get(sign, None) for sign in CustomizedOps.kOpAutowrap] leaf_functions = {func: ([], True, None) for func in autowrap_funcs if func is not None} + + # get cube runtime functions + cube_rt_funcs = [func for _, func in inspect.getmembers(cube_rt_function, inspect.isfunction)] + leaf_functions.update({func: ([], True, None) for func in cube_rt_funcs}) + dce_ignored_funcs = set(cube_rt_funcs) if HAS_APEX: leaf_module = ( @@ -57,6 +65,7 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: use_operator_patch=True, leaf_module=leaf_module, autowrap_leaf_function=leaf_functions, + dce_ignored_function=dce_ignored_funcs, cpu_offload=True, record_frames=not CompileFlag.disable_code_line_info, ) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 19e1ca4a..4d8f7159 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -15,7 +15,7 @@ import importlib.util from itertools import chain -from types import BuiltinMethodType, FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType +from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType from typing import Any, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, List, Callable, Union from contextlib import contextmanager from pathlib import Path @@ -127,20 +127,8 @@ def __exit__(self, *args): _orig_max, _orig_node_is_impure, - - FrameRecord, ) - -# some side effectful functions that should not be deleted during dead code elimination -# there may be more than listed here -extra_side_effectful_functions = { - operator.setitem, - builtins.next, - _orig_torch_no_grad, - _orig_torch_no_grad_enter, - _orig_torch_no_grad_exit, -} -_side_effectful_functions = _side_effectful_functions.union(extra_side_effectful_functions) +from .utils import FrameRecord, ExtraSEFPatcher # pyright: reportGeneralTypeIssues=false _logger = logging.getLogger(__name__) @@ -1558,11 +1546,12 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], autowrap_leaf_function = None, autowrap_leaf_class = None, leaf_module: Tuple | None = None, - fake_middle_class = None, - dce = True, - cpu_offload = False, - trace_twice = False, - record_frames = False, + fake_middle_class: Tuple | None = None, + dce: bool = True, + dce_ignored_function: Set[Callable] | None = None, + cpu_offload: bool = False, + trace_twice: bool = False, + record_frames: bool = False, ) -> GraphModule: """ Concrete tracing API @@ -1687,18 +1676,23 @@ def f(x, y): is_iterator_class: Is the class init from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True. dce (bool): If set to True, dead code eliminatation will be applied on the graph. - + + dce_ignored_function (Set[Callable]): The node that its target in this set will not be removed from the graph during dce. + cpu_offload (bool): Whether to offload the module to CPU during tracing. If set to True, the traced code will be executed on GPU, but is offloaded to CPU afterward. This is useful for reducing memory usage during tracing, but may cause performance issues. If set to False, there will be no offloading during tracing, but the traced code will be executed on default device. trace_twice (bool): If set to True, a second trace will be performed, and the two obtained graphs will be checked for consistency. - record_frames(bool): If set to True, will add frame information to node.meta['frame_record']. Note this will cost additional trace time. + record_frames (bool): If set to True, will add frame information to node.meta['frame_record']. Note this will cost additional trace time. Returns: fx.GraphModule: a Module created from the recorded operations from ``root``. """ + dce_ignored_function = dce_ignored_function if isinstance(dce_ignored_function, set) else set() + assert all(callable(ignore_func) for ignore_func in dce_ignored_function) + tracer = ConcreteTracer(cpu_offload = cpu_offload, record_frames = record_frames) is_training = root.training root.eval() @@ -1745,7 +1739,17 @@ def f(x, y): traced = GraphModule(tracer.root, graph, name) if dce: - with _Patcher() as patcher: + # some side effectful functions that should not be deleted during dead code elimination + # there may be more than listed here + default_extra_side_effectful_functions = { + operator.setitem, + builtins.next, + _orig_torch_no_grad, + _orig_torch_no_grad_enter, + _orig_torch_no_grad_exit, + } + extra_side_effectful_functions = default_extra_side_effectful_functions | dce_ignored_function + with _Patcher() as patcher, ExtraSEFPatcher(extra_side_effectful_functions): patcher.patch_method(Node, 'is_impure', node_is_impure_wrapper, deduplicate=False) traced.graph.eliminate_dead_code() traced.recompile() # this need to be done in MagicMethodPatcher context diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index 0fcaa9a2..db244c6b 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -4,11 +4,12 @@ import builtins from dataclasses import dataclass import operator -from typing import Any, Callable, Type +from typing import Any, Callable, Set, Type import functools import torch from torch.fx import Node +from torch.fx.node import _side_effectful_functions # These need to run in global scope to handle nested calls correctly _orig_module_call: Callable = torch.nn.Module.__call__ @@ -128,13 +129,13 @@ def __repr__(self) -> str: return '' class ExtraSEFPatcher: - from torch.fx.node import _side_effectful_functions - # some side effectful functions that should not be deleted during dead code elimination - # there may be more than listed here - extra_funcs = {operator.setitem, builtins.next} - _side_effectful_functions + def __init__(self, extra_side_effectful_functions: Set[Callable]): + self.extra_side_effectful_functions = extra_side_effectful_functions + self.incontext_funcs = set() def __enter__(self): - self._side_effectful_functions.update(self.extra_funcs) + self.incontext_funcs = self.extra_side_effectful_functions - _side_effectful_functions + _side_effectful_functions.update(self.incontext_funcs) def __exit__(self, exc_type, exc_val, exc_tb): - self._side_effectful_functions.difference_update(self.extra_funcs) + _side_effectful_functions.difference_update(self.incontext_funcs) From f1a937a865c665d6d911a7cb23cbe54a69b752a0 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 15 Aug 2023 07:46:43 +0000 Subject: [PATCH 1472/1892] Merged PR 1742: quick fix trace cube runtime func bug quick fix trace cube runtime func bug --- cube/graph/parser/converter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index a5589f18..7b5a7124 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,4 +1,3 @@ -import inspect import logging from pathlib import Path from typing import Any, Dict, Union @@ -39,9 +38,9 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: # get registered leaf function autowrap_funcs = [CustomizedOps.kOpRuntime.get(sign, None) for sign in CustomizedOps.kOpAutowrap] leaf_functions = {func: ([], True, None) for func in autowrap_funcs if func is not None} - + # get cube runtime functions - cube_rt_funcs = [func for _, func in inspect.getmembers(cube_rt_function, inspect.isfunction)] + cube_rt_funcs = [cube_rt_function.anchor] leaf_functions.update({func: ([], True, None) for func in cube_rt_funcs}) dce_ignored_funcs = set(cube_rt_funcs) From 7a413fcd5eb5974891ce1713ec89912505f66f57 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 15 Aug 2023 09:07:12 +0000 Subject: [PATCH 1473/1892] Merged PR 1729: fix ir dtype fix ir dtype --- cube/ir/adapter/prim.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cube/ir/adapter/prim.py b/cube/ir/adapter/prim.py index eab15dbd..9eae0c9e 100644 --- a/cube/ir/adapter/prim.py +++ b/cube/ir/adapter/prim.py @@ -181,7 +181,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k if len(kwargs) == 0: assert len(itensors) == 1 and len(otensors) == 1 kwargs['shape'] = itensors[0].shape - kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None shape, dtype, src, dst = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dst'] @@ -223,7 +223,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim if len(kwargs) == 0: assert len(itensors) == 1 kwargs['shape'] = tuple(itensors[0].shape) - kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None kwargs['dsts'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) shape, dtype, src, dsts = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dsts'] @@ -254,7 +254,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k if len(kwargs) == 0: assert len(itensors) == 1 kwargs['shape'] = tuple(itensors[0].shape) - kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None kwargs['dsts'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) shape, dtype, src, dsts = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dsts'] @@ -279,7 +279,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim if len(kwargs) == 0: assert len(otensors) == 1 kwargs['shape'] = tuple(itensors[0].shape) # the input tensor shape - kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['dtype'] = str(itensors[0].dtype) kwargs['srcs'] = tuple(itensor.device[0] if len(itensor.device) > 0 else None for itensor in itensors) kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None shape, dtype, srcs, dst = kwargs['shape'], kwargs['dtype'], kwargs['srcs'], kwargs['dst'] @@ -304,7 +304,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k if len(kwargs) == 0: assert len(otensors) == 1 kwargs['shape'] = tuple(itensors[0].shape) - kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['dtype'] = str(itensors[0].dtype) kwargs['srcs'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None shape, dtype, srcs, dst = kwargs['shape'], kwargs['dtype'], kwargs['srcs'], kwargs['dst'] @@ -328,7 +328,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k if len(kwargs) == 0: assert len(itensors) == 1 kwargs['shape'] = tuple(itensors[0].shape) - kwargs['dtype'] = 'torch.' + itensors[0].dtype.value + kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None super().__init__(itensors, otensors, **kwargs) self.signature = 'cube.runtime.adapter.broadcast' From f486ea955a7b9686d26d05ea19837e97c5b3abbb Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 16 Aug 2023 02:42:38 +0000 Subject: [PATCH 1474/1892] Merged PR 1733: Refine runtime collectives implementations - all_to_all -> all_to_all_single - all_gather -> all_gather_into_tensor --- cube/runtime/adapter/nn.py | 35 ++++++++++++++++++++++++++++++++- cube/runtime/adapter/reducer.py | 6 ++++-- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index 919d75f4..7619ba68 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -56,6 +56,22 @@ def _alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> return otensor +def _alltoallsingle(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: + CudaTimer().start(field_name='comm', predefined=True) + if odim != 0: + itensor = itensor.transpose(0, odim) + if not itensor.is_contiguous(): + itensor = itensor.contiguous() + group = DeviceGroup().get_group(ranks) + otensor = torch.empty_like(itensor) + torch.distributed.all_to_all_single(otensor, itensor, group=group) + if odim != 0: + otensor = otensor.transpose(0, odim) + otensor = torch.concat(tuple(otensor.chunk(len(ranks), dim=odim)), dim=idim) + CudaTimer().stop(field_name='comm', predefined=True) + return otensor + + def _chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: """ split dimension in n chunks and take idx-th chunk @@ -224,9 +240,26 @@ def backward(ctx, grad: torch.Tensor): return grad, None, None, None +class AllToAllAllToAllSingle(torch.autograd.Function): + + @staticmethod + def forward(ctx, itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]): + ctx._ranks = ranks + ctx._idim = idim + ctx._odim = odim + return _alltoallsingle(itensor, idim, odim, ranks) + + @staticmethod + def backward(ctx, grad: torch.Tensor): + ranks = ctx._ranks + idim, odim = ctx._idim, ctx._odim + grad = _alltoallsingle(grad, odim, idim, ranks) + return grad, None, None, None + + @torch.jit.ignore def alltoall_alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: - return AllToAllAllToAll.apply(itensor, idim, odim, ranks) + return AllToAllAllToAllSingle.apply(itensor, idim, odim, ranks) class ReduceBroadcast(torch.autograd.Function): diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index fb755110..b0b9ab33 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -231,8 +231,10 @@ def gather_params(self): """ assert self._zero, "gathering paramters is only for zero optimization." rank = torch.distributed.get_rank(group=self._zero_subgroup) - shards = list(self._contiguous_params.chunk(self._zgroup_sz, dim=0)) - torch.distributed.all_gather(shards, shards[rank], group=self._zero_subgroup) + CudaTimer().start(field_name='comm', predefined=True) + src_tensor = self._contiguous_params.chunk(self._zgroup_sz, dim=0)[rank] + torch.distributed.all_gather_into_tensor(self._contiguous_params, src_tensor, group=self._zero_subgroup) + CudaTimer().stop(field_name='comm', predefined=True) def register_pre_hook(self, fn: Callable): """Register pre hooks to be applied before gradient synchronization. From 94fb55aee3baf02c949e21f4d1c9f2191a0e4180 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 16 Aug 2023 02:57:41 +0000 Subject: [PATCH 1475/1892] Merged PR 1714: parallel module: support two-step parallelization parallel module: support two-step parallelization --- cube/parallel.py | 43 +++++++++++--------- unit_tests/parallel_module/test_gencode.py | 46 ++++++++++++++++++++++ 2 files changed, 71 insertions(+), 18 deletions(-) create mode 100644 unit_tests/parallel_module/test_gencode.py diff --git a/cube/parallel.py b/cube/parallel.py index 9ad10de3..6200cc5d 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -294,13 +294,18 @@ def parallelize( cube_savedir: Union[str, Path] = './.cube', override: bool = False, instance_name: Optional[str] = None, -) -> Union[CubeModule, Type[CubeModule]]: + load_module: bool = True, +) -> Union[None, CubeModule, Type[CubeModule]]: """ Convert a torch.nn.Module object or class to CubeModule object or class. If you want to save multiple instances of the same module, you can specify the instance_name to distingish them. + Currently you must use a shared file system to share the generated files (like mounted Azure Blob) + Or you can unset load_module flag, and manually copy the generated files to other nodes. + After all nodes have the generated files, you can call parallelize() again with load_module flag set. + Args: module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled dummy_input (dict): the dummy input for the module @@ -310,24 +315,25 @@ def parallelize( override (bool): If true, source code will be regenerated even if generated code exists. cube_savedir (Union[str, Path]): the directory to save generated code instance_name (Optional[str]): the instance name of the generated module. + load_module (bool): whether to load the generated module after done. Returns: - Union[CubeModule, Type[CubeModule]]: the converted CubeModule object or class + Union[CubeModule, Type[CubeModule], None]: + if load_module flag is set, return the converted CubeModule object or class + if load_module flag is not set, return None """ if ( isinstance(module_or_module_class, CubeModule) or (inspect.isclass(module_or_module_class) and issubclass(module_or_module_class, CubeModule)) ): - return module_or_module_class - - if not torch.distributed.is_initialized(): # we only support distributed training - raise RuntimeError("Distributed training is not initialized.") + return module_or_module_class if load_module else None - rank = torch.distributed.get_rank() is_module_class = inspect.isclass(module_or_module_class) module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ - if rank == 0: + # genereate code only in node0 + # if it is not in a torchrun environment, just generate. + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: if is_module_class: # it should only have 1 `self` parameter if len(inspect.signature(module_or_module_class.__init__).parameters) > 1: @@ -342,9 +348,6 @@ def parallelize( if any(isinstance(m, CubeModule) for m in module.modules()): raise RuntimeError('CubeModule can not be nested.') - # TODO: copy generated files to other nodes - # Currently you must use a shared file system to share the generated files (like mounted Azure Blob) - # Or you can manually copy the generated files to other nodes _gencode( module, dummy_input, @@ -357,10 +360,14 @@ def parallelize( ) if is_module_class: del module - torch.distributed.barrier() - cube_module_class = _load_cube_module_class( - module_class, - cube_savedir=cube_savedir, - instance_name=instance_name, - ) - return cube_module_class if is_module_class else cube_module_class() + + if load_module: + if not torch.distributed.is_initialized(): # we only support distributed training + raise RuntimeError("Load CubeModule failed: torch.distributed is not initialized.") + torch.distributed.barrier() + cube_module_class = _load_cube_module_class( + module_class, + cube_savedir=cube_savedir, + instance_name=instance_name, + ) + return cube_module_class if is_module_class else cube_module_class() diff --git a/unit_tests/parallel_module/test_gencode.py b/unit_tests/parallel_module/test_gencode.py new file mode 100644 index 00000000..3f6b53ac --- /dev/null +++ b/unit_tests/parallel_module/test_gencode.py @@ -0,0 +1,46 @@ +import tempfile + +import torch +import pytest + +from cube.parallel import parallelize, ComputeConfig, CubeModule + +from .common import PASData, init_distributed +from ..launch_torchrun import launch_torchrun + +def _to_cube_model(module, compute_config, cube_savedir, load_module): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + compute_config, + dynamic_shape=True, + cube_savedir=cube_savedir, + load_module=load_module + ) + +class Module0(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + return self.linear(x) + + +def _gencode_worker(tempdir): + init_distributed() + m = Module0() + with pytest.raises(RuntimeError): # config mismatch + _to_cube_model(m, ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True) + + +def test_codegen(): + if not torch.cuda.is_available(): + print('skip test_codegen due to lack of cuda devices') + return + with tempfile.TemporaryDirectory() as tempdir: + m = Module0() + m_new = _to_cube_model(m, ComputeConfig(2, 4), cube_savedir=tempdir, load_module=False) + assert m_new is None + launch_torchrun(1, _gencode_worker, tempdir) From 406e9a15b531d9e13af135ba787353cfc1bda954 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 16 Aug 2023 06:53:28 +0000 Subject: [PATCH 1476/1892] Merged PR 1734: Add llama example --- cube/graph/parser/fx/parser.py | 2 +- examples/llama/__init__.py | 0 examples/llama/chat.py | 66 +++++ examples/llama/generation.py | 342 +++++++++++++++++++++++++ examples/llama/model.py | 307 ++++++++++++++++++++++ examples/llama/test_chat_completion.py | 107 ++++++++ examples/llama/tokenizer.py | 41 +++ 7 files changed, 864 insertions(+), 1 deletion(-) create mode 100644 examples/llama/__init__.py create mode 100644 examples/llama/chat.py create mode 100644 examples/llama/generation.py create mode 100644 examples/llama/model.py create mode 100644 examples/llama/test_chat_completion.py create mode 100644 examples/llama/tokenizer.py diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 7256765b..69d30a22 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -128,7 +128,7 @@ def parse(module: torch.fx.GraphModule, dtype = input.meta['tensor_meta'].dtype val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) else: - val = IRObject(input.name) + val = IRObject(input.name, value=dummy_inputs[input.name]) else: # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, # extend to other input types diff --git a/examples/llama/__init__.py b/examples/llama/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/llama/chat.py b/examples/llama/chat.py new file mode 100644 index 00000000..e89b6493 --- /dev/null +++ b/examples/llama/chat.py @@ -0,0 +1,66 @@ +""" +pip install fire sentencepiece + +PYTHONPATH=.:$PYTHONPATH torchrun \ + --nproc_per_node=1 \ + examples/llama/chat.py \ + --ckpt_dir=/home/t-zhiqilin/llama/llama-2-7b-chat \ + --tokenizer_path=/home/t-zhiqilin/llama/tokenizer.model \ + --max_seq_len 512 --max_batch_size 8 --temperature 0 \ + --use-cube +""" + +from typing import Optional + +import fire +import logging + +from examples.llama.generation import Llama + +import cube + +cube.init() +cube.set_logger_level(level=logging.WARNING) +logging.getLogger('cube.compiler').setLevel(logging.INFO) + + +def main( + ckpt_dir: str, + tokenizer_path: str, + temperature: float = 0.6, + top_p: float = 0.9, + max_seq_len: int = 512, + max_batch_size: int = 8, + max_gen_len: Optional[int] = None, + use_cube: bool = False, +): + generator = Llama.build( + ckpt_dir=ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + use_cube=use_cube, + ) + + dialog = [ + {"role": "system", "content": + "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature."}, + ] + + print('Assistant: Hello, this is Llama 2') + while True: + user_content = input("Prompt >> ") + dialog.append({"role": "user", "content": user_content}) + result = generator.chat_completion( + [dialog], + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + )[0] + assit_content = result['generation']['content'] + print(f"{result['generation']['role'].capitalize()}: {assit_content}") + dialog.append({"role": "assistant", "content": assit_content}) + + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file diff --git a/examples/llama/generation.py b/examples/llama/generation.py new file mode 100644 index 00000000..3594ebd4 --- /dev/null +++ b/examples/llama/generation.py @@ -0,0 +1,342 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import json +import os +import sys +import time +from pathlib import Path +from typing import List, Literal, Optional, Tuple, TypedDict + +import torch +import torch.nn.functional as F + +from examples.llama.model import ModelArgs, Transformer +from examples.llama.tokenizer import Tokenizer + +Role = Literal["system", "user", "assistant"] + +import cube +from cube.flags import CompileFlag + + +class Message(TypedDict): + role: Role + content: str + + +class CompletionPrediction(TypedDict, total=False): + generation: str + tokens: List[str] # not required + logprobs: List[float] # not required + + +class ChatPrediction(TypedDict, total=False): + generation: Message + tokens: List[str] # not required + logprobs: List[float] # not required + + +Dialog = List[Message] + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +SPECIAL_TAGS = [B_INST, E_INST, "<>", "<>"] +UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt." + + +class Llama: + @staticmethod + def build( + ckpt_dir: str, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + use_cube: bool, + ) -> "Llama": + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(1) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + start_time = time.time() + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + # assert model_parallel_size == len( + # checkpoints + # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" + ckpt_path = checkpoints[0] + # ckpt_path = checkpoints[get_model_parallel_rank()] + checkpoint = torch.load(ckpt_path, map_location="cpu") + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + model_args.vocab_size = tokenizer.n_words + torch.set_default_tensor_type(torch.cuda.HalfTensor) + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=False) + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama(model, tokenizer, use_cube) + + def __init__(self, model: Transformer, tokenizer: Tokenizer, use_cube: bool): + self.model = model + self.tokenizer = tokenizer + + # ======================= cube initilizer ================= + self.use_cube = use_cube + if use_cube: + print(f"Build using cube engine") + CompileFlag.disable_code_line_info = False + self.build_inference() + + def build_inference(self): + + sample_tokens = torch.randint( + 1, 1000, size=(4, 38), dtype=torch.int64) + + def policy(graph, resource): + from cube.ir.operator import IRFwOperation + for fwop in graph.select(ntype=IRFwOperation): + graph.assign(fwop, 0) + return graph + + @cube.compile(self.model, sample_tokens, 0, + PAS=policy, model_dynamic_shape=True) + def infer(model: torch.nn.Module, tokens: torch.Tensor, prev_pos: int): + logits = model(tokens, prev_pos) + return logits + + params = self.model.params + vocab_size, n_layers = params.vocab_size, params.n_layers + + del self.model + self.model = cube.load_model() + + # TODO: support auto reset non-parameter attributes for llama model + self.model.params = params + self.model.vocab_size = vocab_size + self.model.n_layers = n_layers + self.infer_fn = (infer,) + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + params = self.model.params + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + for cur_pos in range(min_prompt_len, total_len): + if self.use_cube: + logits = self.infer_fn[0](self.model, tokens[:, prev_pos:cur_pos], prev_pos) + else: + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + next_token == self.tokenizer.eos_id + ) + prev_pos = cur_pos + if all(eos_reached): + break + + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to eos tok if any + if self.tokenizer.eos_id in toks: + eos_idx = toks.index(self.tokenizer.eos_id) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + out_tokens.append(toks) + out_logprobs.append(probs) + return (out_tokens, out_logprobs if logprobs else None) + + def text_completion( + self, + prompts: List[str], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> List[CompletionPrediction]: + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ) + if logprobs: + return [ + { + "generation": self.tokenizer.decode(t), + "tokens": [self.tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] + + def chat_completion( + self, + dialogs: List[Dialog], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + ) -> List[ChatPrediction]: + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [] + unsafe_requests = [] + for dialog in dialogs: + unsafe_requests.append( + any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) + ) + if dialog[0]["role"] == "system": + dialog = [ + { + "role": dialog[1]["role"], + "content": B_SYS + + dialog[0]["content"] + + E_SYS + + dialog[1]["content"], + } + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system', 'user' and 'assistant' roles, " + "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" + ) + dialog_tokens: List[int] = sum( + [ + self.tokenizer.encode( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + bos=True, + eos=True, + ) + for prompt, answer in zip( + dialog[::2], + dialog[1::2], + ) + ], + [], + ) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += self.tokenizer.encode( + f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", + bos=True, + eos=False, + ) + prompt_tokens.append(dialog_tokens) + + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + ) + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t) + if not unsafe + else UNSAFE_ERROR, + }, + "tokens": [self.tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i, unsafe in zip( + generation_tokens, generation_logprobs, unsafe_requests + ) + ] + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR, + } + } + for t, unsafe in zip(generation_tokens, unsafe_requests) + ] + + +def sample_top_p(probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token \ No newline at end of file diff --git a/examples/llama/model.py b/examples/llama/model.py new file mode 100644 index 00000000..646d13be --- /dev/null +++ b/examples/llama/model.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import math +from dataclasses import dataclass +from typing import Any, Optional, Tuple + + +import torch +import torch.nn.functional as F +from torch import nn + +import cube + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +# TODO: fix annotation +@cube.graph.parser.register('*, *, 38^ 64^ -> *, *') +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +@cube.graph.parser.register('N seqlen^, N seqlen^ H^ -> 1 1 seqlen^ seqlen^') +def create_mask(tokens: torch.Tensor, h: torch.Tensor, start_pos: int): + seqlen = tokens.shape[1] + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + return mask + + +@cube.graph.parser.register('N seqlen *, 1 1 * -> N seqlen *') +def apply_mask(x: torch.Tensor, mask: torch.Tensor): + return x if mask is None else x + mask + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + self.n_local_heads = args.n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = torch.nn.Linear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + ) + self.wk = torch.nn.Linear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wv = torch.nn.Linear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wo = torch.nn.Linear( + args.n_heads * self.head_dim, + args.dim, + bias=False, + ) + + self.cache_k = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + self.cache_v = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + # TODO: support register function with kwargs on tensor + # xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis) + + # modification: move `.to(xq)` to the belowing + # self.cache_k = self.cache_k.to(xq) + # self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = self.cache_k.to(xq)[:bsz, : start_pos + seqlen] + values = self.cache_v.to(xq)[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + + # NOTE: cube doesn't support dynamic graph + # if mask is not None: + # scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = apply_mask(scores, mask) + + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = torch.nn.Linear( + dim, hidden_dim, bias=False + ) + self.w2 = torch.nn.Linear( + hidden_dim, dim, bias=False + ) + self.w3 = torch.nn.Linear( + dim, hidden_dim, bias=False + ) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + h = x + self.attention.forward( + self.attention_norm(x), start_pos, freqs_cis, mask + ) + out = h + self.feed_forward.forward(self.ffn_norm(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = torch.nn.Embedding( + params.vocab_size, params.dim + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = torch.nn.Linear( + params.dim, params.vocab_size, bias=False + ) + + self.freqs_cis = precompute_freqs_cis( + self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 + ) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + + # TODO: support tracking dependency on kwarg IRObject + start_pos = start_pos + 0 + + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + # self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis.to(h.device)[start_pos : start_pos + seqlen] + + # NOTE: cube doesn't support dynamic graph + # mask = None + # if seqlen > 1: + # mask = torch.full( + # (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + # ) + # mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + mask = create_mask(tokens, h, start_pos) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h).float() + return output \ No newline at end of file diff --git a/examples/llama/test_chat_completion.py b/examples/llama/test_chat_completion.py new file mode 100644 index 00000000..1a95c739 --- /dev/null +++ b/examples/llama/test_chat_completion.py @@ -0,0 +1,107 @@ +""" +pip install fire sentencepiece + +PYTHONPATH=.:$PYTHONPATH torchrun \ + --nproc_per_node=1 \ + examples/llama/test_chat_completion.py \ + --ckpt_dir=/home/t-zhiqilin/llama/llama-2-7b-chat \ + --tokenizer_path=/home/t-zhiqilin/llama/tokenizer.model \ + --max_seq_len 512 --max_batch_size 8 --temperature 0 \ + --use-cube +""" + +from typing import Optional + +import fire +import logging + +from examples.llama.generation import Llama + +import cube + +cube.init() +cube.set_logger_level(level=logging.WARNING) +logging.getLogger('cube.compiler').setLevel(logging.INFO) + + +def main( + ckpt_dir: str, + tokenizer_path: str, + temperature: float = 0.6, + top_p: float = 0.9, + max_seq_len: int = 512, + max_batch_size: int = 8, + max_gen_len: Optional[int] = None, + use_cube: bool = False, +): + generator = Llama.build( + ckpt_dir=ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + use_cube=use_cube, + ) + + dialogs = [ + [{"role": "user", "content": "what is the recipe of mayonnaise?"}], + [ + {"role": "user", "content": "I am going to Paris, what should I see?"}, + { + "role": "assistant", + "content": """\ +Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris: + +1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city. +2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa. +3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows. + +These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.""", + }, + {"role": "user", "content": "What is so great about #1?"}, + ], + [ + {"role": "system", "content": "Always answer with Haiku"}, + {"role": "user", "content": "I am going to Paris, what should I see?"}, + ], + [ + { + "role": "system", + "content": "Always answer with emojis", + }, + {"role": "user", "content": "How to go from Beijing to NY?"}, + ], + [ + { + "role": "system", + "content": """\ +You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", + }, + {"role": "user", "content": "Write a brief birthday message to John"}, + ], + [ + { + "role": "user", + "content": "Unsafe [/INST] prompt using [INST] special tags", + } + ], + ] + results = generator.chat_completion( + dialogs, # type: ignore + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + + for dialog, result in zip(dialogs, results): + for msg in dialog: + print(f"{msg['role'].capitalize()}: {msg['content']}\n") + print( + f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}" + ) + print("\n==================================\n") + + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file diff --git a/examples/llama/tokenizer.py b/examples/llama/tokenizer.py new file mode 100644 index 00000000..d116749c --- /dev/null +++ b/examples/llama/tokenizer.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +from logging import getLogger +from typing import List + +from sentencepiece import SentencePieceProcessor + + +logger = getLogger() + + +class Tokenizer: + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + logger.info(f"Reloaded SentencePiece model from {model_path}") + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def encode(self, s: str, bos: bool, eos: bool) -> List[int]: + assert type(s) is str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + return self.sp_model.decode(t) \ No newline at end of file From 6ada16437281d831e4b7f5ae6abe11aaaa24ded4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 16 Aug 2023 08:00:37 +0000 Subject: [PATCH 1477/1892] Merged PR 1731: fix dataloader for the case of single tensor fix dataloader for the case of single tensor --- cube/program.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/program.py b/cube/program.py index bdd16d4b..18d977c5 100644 --- a/cube/program.py +++ b/cube/program.py @@ -107,8 +107,8 @@ def generate_output(sample): # turn sample into IRObjects outputs = generate_output(sample) # create dataloader operation - outputs = outputs if isinstance(outputs, (tuple, list)) else (outputs,) - data_op = IRDataOperation(self.object, outputs) + node_outputs = outputs if isinstance(outputs, (tuple, list)) else (outputs,) + data_op = IRDataOperation(self.object, node_outputs) Program().add_node(data_op) return outputs From 53743eec3210fa175f800faac0080352d497bbcf Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 16 Aug 2023 09:00:41 +0000 Subject: [PATCH 1478/1892] Merged PR 1728: Restructure Codebase - 1: Merge tests into unit tests --- README.md | 6 + benchmark/deepspeed/benchmark_gpt.sh | 136 ----------- benchmark/deepspeed/gpt_bench.py | 186 -------------- benchmark/deepspeed/pretrain_gpt_synthetic.py | 222 ----------------- benchmark/megatron/benchmark_gpt.sh | 69 ------ benchmark/megatron/pretrain_gpt_synthetic.py | 112 --------- scripts/aggregate.sh | 23 -- scripts/env-setup.sh | 49 ---- scripts/sync.sh | 40 --- tests/codegen/test_scale.py | 202 --------------- tests/gpt_memory_profile.md | 9 - tests/graph/test_dump_load.py | 157 ------------ tests/graph/test_fusion.py | 99 -------- tests/graph/test_infer_grad.py | 134 ---------- tests/graph/test_multiref.py | 115 --------- tests/graph/test_segment.py | 87 ------- tests/parser/test_bloom.py | 180 -------------- tests/parser/test_compile.py | 165 ------------- tests/parser/test_cus_autograd.py | 111 --------- tests/parser/test_fx_ops.py | 162 ------------ tests/parser/test_fx_zip.py | 79 ------ tests/parser/test_jit_ops.py | 87 ------- tests/parser/test_no_grad.py | 88 ------- tests/parser/test_torchscale_basic.py | 84 ------- tests/runtime/test_reducer.py | 218 ----------------- tests/runtime/test_runtime_collectives.py | 231 ------------------ tests/runtime/test_runtime_flag.py | 210 ---------------- tests/test_codegen.py | 185 -------------- tests/test_examples.sh | 122 --------- tests/test_execplan_grouping.py | 66 ----- tests/test_nccl.py | 103 -------- tests/test_profile_gpt.py | 107 -------- tox.ini | 2 +- unit_tests/compiler/__init__.py | 0 unit_tests/compiler/test_compile.py | 137 +++++++++++ unit_tests/graph/function/test_dataloader.py | 45 ++++ .../graph/function/test_dimops.py | 14 +- .../graph/gener/check_inter_rvd.py | 2 + .../graph/gener/check_intra_rvd.py | 4 +- unit_tests/graph/test_multiref.py | 124 ++++++++++ unit_tests/launch_torchrun.py | 32 +++ unit_tests/parallel_module/common.py | 13 +- unit_tests/runtime/test_reducer.py | 132 ++++++++++ .../{test_utils.py => test_torchrun.py} | 0 unit_tests/utils.py | 73 ++++++ utility/aggregate.sh | 15 ++ utility/broadcast.sh | 15 ++ {scripts => utility}/dgx1_reorder_gpu.py | 0 {scripts => utility}/keep.py | 0 {tests => utility}/test_rvd_prim.py | 9 +- 50 files changed, 597 insertions(+), 3864 deletions(-) delete mode 100755 benchmark/deepspeed/benchmark_gpt.sh delete mode 100644 benchmark/deepspeed/gpt_bench.py delete mode 100644 benchmark/deepspeed/pretrain_gpt_synthetic.py delete mode 100755 benchmark/megatron/benchmark_gpt.sh delete mode 100644 benchmark/megatron/pretrain_gpt_synthetic.py delete mode 100755 scripts/aggregate.sh delete mode 100755 scripts/env-setup.sh delete mode 100755 scripts/sync.sh delete mode 100644 tests/codegen/test_scale.py delete mode 100644 tests/gpt_memory_profile.md delete mode 100644 tests/graph/test_dump_load.py delete mode 100644 tests/graph/test_fusion.py delete mode 100644 tests/graph/test_infer_grad.py delete mode 100644 tests/graph/test_multiref.py delete mode 100644 tests/graph/test_segment.py delete mode 100644 tests/parser/test_bloom.py delete mode 100644 tests/parser/test_compile.py delete mode 100644 tests/parser/test_cus_autograd.py delete mode 100644 tests/parser/test_fx_ops.py delete mode 100644 tests/parser/test_fx_zip.py delete mode 100644 tests/parser/test_jit_ops.py delete mode 100644 tests/parser/test_no_grad.py delete mode 100644 tests/parser/test_torchscale_basic.py delete mode 100644 tests/runtime/test_reducer.py delete mode 100644 tests/runtime/test_runtime_collectives.py delete mode 100644 tests/runtime/test_runtime_flag.py delete mode 100644 tests/test_codegen.py delete mode 100755 tests/test_examples.sh delete mode 100644 tests/test_execplan_grouping.py delete mode 100644 tests/test_nccl.py delete mode 100644 tests/test_profile_gpt.py create mode 100644 unit_tests/compiler/__init__.py create mode 100644 unit_tests/compiler/test_compile.py create mode 100644 unit_tests/graph/function/test_dataloader.py rename tests/algorithm/test_op_algorithm.py => unit_tests/graph/function/test_dimops.py (86%) rename tests/adapter/test_inter_rvd.py => unit_tests/graph/gener/check_inter_rvd.py (98%) rename tests/adapter/test_intra_rvd.py => unit_tests/graph/gener/check_intra_rvd.py (99%) create mode 100644 unit_tests/graph/test_multiref.py create mode 100644 unit_tests/runtime/test_reducer.py rename unit_tests/{test_utils.py => test_torchrun.py} (100%) create mode 100644 unit_tests/utils.py create mode 100644 utility/aggregate.sh create mode 100644 utility/broadcast.sh rename {scripts => utility}/dgx1_reorder_gpu.py (100%) rename {scripts => utility}/keep.py (100%) rename {tests => utility}/test_rvd_prim.py (94%) diff --git a/README.md b/README.md index 097cfe5c..4ff20825 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,12 @@ Please note tox will reuse the same virtual environment which is initialized by tox -r ``` +To run a single unit test task during development, you can run + +``` +pytest unit_tests/your_test_file.py +``` + ### Run unit tests in vscode VS Code has a great support to unit tests. You can run/debug every tests easily in VS Code. Please refer to this document to set up your environment https://code.visualstudio.com/docs/python/testing diff --git a/benchmark/deepspeed/benchmark_gpt.sh b/benchmark/deepspeed/benchmark_gpt.sh deleted file mode 100755 index a73e45ae..00000000 --- a/benchmark/deepspeed/benchmark_gpt.sh +++ /dev/null @@ -1,136 +0,0 @@ -#!/bin/bash -# run at MagicCube/ -# ./benchmark/deepspeed/benchmark_gpt.sh - -# get commit ID: -# git rev-parse --short HEAD - -# installation -# pip install deepspeed==0.7.4 -# git clone https://github.com/microsoft/Megatron-DeepSpeed -# git checkout 54f1cb7 - -# note DeepSpeed can do: -# 1) PP > 1 with constraints of Zero-Stage=1 -# 2) TP > 1 with constraints of Zero-Stage < 3 - -cp benchmark/deepspeed/pretrain_gpt_synthetic.py \ - benchmark/deepspeed/Megatron-DeepSpeed/ - -Nnodes=1 -TP=2 -PP=2 - -# Model arch -Layers=12 -Hidden=2048 -Heads=32 -Seqlen=2048 - -# batch size -Gbs=8 -Mbs=1 -Accum=$(( ${Gbs} / ( ${Nnodes} * 8 / ${TP} / ${PP} * ${Mbs} ) )) -echo "Accumulated steps: ${Accum}" - -# zero stage config -Zero=1 -OFFLOAD_DEVICE="none" -CPU_OPTIM=" " -#OFFLOAD_DEVICE="cpu" -#CPU_OPTIM=" --cpu-optimizer" - -cd benchmark/deepspeed/Megatron-DeepSpeed - -DS_CONFIG=ds_config.json - -cat < $DS_CONFIG -{ - "train_batch_size" : $Gbs, - "train_micro_batch_size_per_gpu": $Mbs, - "steps_per_print": 1, - "gradient_accumulation_steps": ${Accum}, - "zero_optimization": { - "stage": $Zero, - "stage3_max_live_parameters": 3e9, - "stage3_max_reuse_distance": 3e9, - "stage3_param_persistence_threshold": 1e5, - "stage3_prefetch_bucket_size": 5e7, - "contiguous_gradients": true, - "overlap_comm": true, - "reduce_bucket_size": 90000000, - "sub_group_size": 1e9, - "offload_optimizer": { - "device": "$OFFLOAD_DEVICE", - "buffer_count": 4, - "pipeline_read": false, - "pipeline_write": false, - "pin_memory": true - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": true, - "initial_scale_power" : 15, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "wall_clock_breakdown": true, - "zero_allow_untested_optimizer": false, - "aio": { - "block_size": 1048576, - "queue_depth": 16, - "single_submit": false, - "overlap_events": true, - "thread_count": 2 - } -} -EOT - -# export NCCL_DEBUG=warn - -ds_args=" " -ds_args=" --deepspeed ${ds_args}" -# ds_args=" --no-pipeline-parallel ${ds_args}" -ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}" -ds_args=" --zero-stage=$Zero ${ds_args}" -ds_args=" --deepspeed-activation-checkpointing ${ds_args}" - - -GPT_ARGS="--num-layers $Layers \ - --hidden-size $Hidden \ - --num-attention-heads $Heads \ - --seq-length $Seqlen \ - --loss-scale 15 \ - --max-position-embeddings $Seqlen \ - --train-iters 3 \ - --lr 6.0e-5 \ - --min-lr 6.0e-6 \ - --lr-decay-style cosine \ - --fp16 \ - --fp16-lm-cross-entropy \ - --no-query-key-layer-scaling \ - --no-masked-softmax-fusion \ - --no-bias-gelu-fusion \ - --no-bias-dropout-fusion \ - --checkpoint-activations \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --weight-decay 0.1 \ - --clip-grad 1.0 \ - --init-method-std 0.006 \ - --log-interval 1 \ - --num-workers 0" - -# deepspeed --force_multi --num_nodes -deepspeed --num_nodes=$Nnodes --num_gpus 8 \ - --master_addr localhost --master_port 6144 \ - pretrain_gpt_synthetic.py \ - $GPT_ARGS $CPU_OPTIM $ds_args \ - --global-batch-size $Gbs \ - --micro-batch-size $Mbs \ - --tensor-model-parallel-size $TP \ - --pipeline-model-parallel-size $PP - -cd ../../.. diff --git a/benchmark/deepspeed/gpt_bench.py b/benchmark/deepspeed/gpt_bench.py deleted file mode 100644 index 9980f7a4..00000000 --- a/benchmark/deepspeed/gpt_bench.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -Following - -https://github.com/microsoft/DeepSpeedExamples/blob/master/HelloDeepSpeed/train_bert_ds.py - -Config file: -https://www.deepspeed.ai/docs/config-json/ - -deepspeed --num_nodes 1 --num_gpus 8 \ - benchmark/deepspeed/gpt_bench.py \ - --fp16 --mbs 1 --gbs 4 \ - --zero 2 \ - --layers 24 --heads 32 --hidden 2048 --seqlen 2048 - -""" - -from typing import List, Tuple -import torch -import time -import numpy as np -import os -import logging - -from examples.nlp.gpt.model import GPT, Config -from examples.nlp.gpt.model import GPTDataLoader - -import argparse -import deepspeed - -logging.getLogger().setLevel(logging.WARN) - - -parser = argparse.ArgumentParser(description='GPT Train') - -parser.add_argument('--fp16', action='store_true', default=False, - help='use fp16 for the training') -parser.add_argument('--mbs', type=int, default=1, - help='micro-batch size') -parser.add_argument('--gbs', type=int, default=256, - help='global batch size') -parser.add_argument('--zero', type=int, required=True, - help='zero stage, 2 or 3') -parser.add_argument('--layers', type=int, required=True) -parser.add_argument('--heads', type=int, required=True) -parser.add_argument('--seqlen', type=int, required=True) -parser.add_argument('--hidden', type=int, required=True) - -parser.add_argument('--local_rank', type=int) -args = parser.parse_args() - -print(args) -torch.cuda.set_device(args.local_rank) - -ds_zero3_config = { - "train_micro_batch_size_per_gpu": args.mbs, - "gradient_accumulation_steps": args.gbs // args.mbs, - "zero_optimization": { - "stage": 3, - "offload_param": { # Zero-3 - "device": "cpu" - }, - "offload_optimizer": { # Zero-2 - "device": "cpu" - }, - "contiguous_gradients": True, - "overlap_comm": True, - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015, - "betas": [0.9, 0.95] - } - }, - "fp16": { - "enabled": True, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "wall_clock_breakdown": True, - "steps_per_print": 1, -} - - -ds_zero2_config = { - "train_micro_batch_size_per_gpu": args.mbs, - "gradient_accumulation_steps": args.gbs // args.mbs, - "zero_optimization": { - "stage": 2, - "offload_optimizer": { # Zero-2 - "device": "cpu" - }, - "contiguous_gradients": True, - "overlap_comm": True, - }, - "mp_size": 2, - "activation_checkpointing": { - "partition_activations": True, - "cpu_checkpointing": True, - "contiguous_memory_optimization": True, - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015, - "betas": [0.9, 0.95] - } - }, - "fp16": { - "enabled": True, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "wall_clock_breakdown": True, - "steps_per_print": 1, -} - -assert args.zero in [2, 3], f"Zero stage can only be 2 or 3" -zero_config = ds_zero2_config if args.zero == 2 else ds_zero3_config - -def log_dist(message: str, ranks: List[int] = None) -> None: - my_rank = int(os.environ.get("RANK", "0")) - if my_rank in ranks: - print(f"rank [{my_rank}] {message}") - - -def train(): - - batch_size = args.mbs - Config.seqlen = args.seqlen - Config.layers = args.layers - Config.embed_dim = args.hidden - Config.attention_heads = args.heads - - model = GPT() - model = model if not args.fp16 else model.half() - - nparams = 0 - param: torch.Tensor - for param in model.parameters(): - nparams += param.nelement() - log_dist(f'parameter before zero: {nparams}', [0]) - - model, _, _, _ = deepspeed.initialize( - model=model, - model_parameters=model.parameters(), - config=zero_config) - model.train() - log_dist("DeepSpeed engine created", ranks=[0]) - - nparams = 0 - param: torch.Tensor - for param in model.parameters(): - nparams += param.nelement() - log_dist(f'parameter after zero: {nparams}', [0]) - - dataloader = GPTDataLoader(batch_size) - - - iter_num = 3 - warmup = 1 - for step in range(iter_num): - if step == warmup: - torch.cuda.synchronize() - tic = time.time() - - data = next(dataloader) - loss = model(*data) - model.backward(loss) - model.step() - - if step == 0: - log_dist('passed first iteration', ranks=[0]) - if (step + 1) % 2 == 0: - log_dist(f'iter [{step + 1}/{iter_num}]', ranks=[0]) - torch.cuda.synchronize() - toc = time.time() - log_dist(f"iteration time: {(toc-tic) / (iter_num - warmup) * 1000} ms", ranks=[0]) - log_dist(f"Max allocated memory: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024} GB", [0]) - -if __name__ == '__main__': - train() \ No newline at end of file diff --git a/benchmark/deepspeed/pretrain_gpt_synthetic.py b/benchmark/deepspeed/pretrain_gpt_synthetic.py deleted file mode 100644 index 32de2eb2..00000000 --- a/benchmark/deepspeed/pretrain_gpt_synthetic.py +++ /dev/null @@ -1,222 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Pretrain GPT""" - -import torch -from functools import partial -from megatron import get_args -from megatron import print_rank_0 -from megatron import get_timers -from megatron import get_tokenizer -from megatron import mpu -from megatron.data.gpt_dataset import build_train_valid_test_datasets -from megatron.model import GPTModel, GPTModelPipe -from megatron.training import pretrain -from megatron.utils import get_ltor_masks_and_position_ids -from megatron.utils import average_losses_across_data_parallel_group - -import deepspeed -from deepspeed.runtime.utils import see_memory_usage -import os -import subprocess - -from torch import nn -import torch.nn.functional as F - -def model_provider(pre_process=True, post_process=True): - """Build the model.""" - - print_rank_0('building GPT model ...') - see_memory_usage(f"Before Building Model", force=True) - - args = get_args() - vocab_size = 50257 - after = vocab_size - multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size - while after % multiple != 0: - after += 1 - args.padded_vocab_size = after - - with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), - remote_device=None if args.remote_device == 'none' else args.remote_device, - config_dict_or_path=args.deepspeed_config, - enabled=args.zero_stage == 3, - mpu=mpu): - if args.deepspeed and not args.no_pipeline_parallel: - print_rank_0('building GPT model using DeepSpeed ...') - model = GPTModelPipe( - num_tokentypes=0, - parallel_output=True - ) - # This is a hack to give us a reference to get_batch_pipe from within training.py - # We need to call model.set_batch_fn after deepspeed.initialize - model._megatron_batch_fn = get_batch_pipe - - # Predompute the attention mask and store it in args. This avoids having to - # pipeline it as an activation during training. The mask is constant, and thus - # we can reuse it. - attention_mask = torch.tril(torch.ones( - (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view( - 1, 1, args.seq_length, args.seq_length) - - # Convert attention mask to binary: - attention_mask = (attention_mask < 0.5) - if args.fp16: - attention_mask = attention_mask.half() - elif args.bf16: - attention_mask = attention_mask.bfloat16() - - # Attention mask must be bool. - args.attn_mask = attention_mask.to(torch.bool) - else: - model = GPTModel( - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - return_moe_loss=False - ) - - see_memory_usage(f"After Building Model", force=True) - return model - - -def get_batch(data_iterator): - """Generate a batch""" - args = get_args() - vocab_size = 50257 - tokens = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size - labels = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size - loss_mask = torch.ones(tokens.size(), dtype=torch.float, device=torch.cuda.current_device()) - attention_mask = torch.tril(torch.ones( - (args.micro_batch_size, args.seq_length, args.seq_length), device=torch.cuda.current_device() - )).view(args.micro_batch_size, 1, args.seq_length, args.seq_length) - attention_mask = (attention_mask < 0.5) - position_ids = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * args.seq_length - - return tokens, labels, loss_mask, attention_mask, position_ids - - -def get_batch_pipe(data): - """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" - # args = get_args() - # tokenizer = get_tokenizer() - - # # Items and their type. - # keys = ['text'] - # datatype = torch.int64 - # - # # Broadcast data. - # data_b = mpu.broadcast_data(keys, data, datatype) - # - # # Unpack. - # tokens_ = data_b['text'].long() - # labels = tokens_[:, 1:].contiguous() - # tokens = tokens_[:, :-1].contiguous() - # - # # Get the masks and postition ids. - # attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( - # tokens, - # tokenizer.eod, - # args.reset_position_ids, - # args.reset_attention_mask, - # args.eod_mask_loss) - # if args.curriculum_learning and args.curriculum_seqlen < tokens.size()[1]: - # # seqlen-based curriculum learning - # # tokens, position_ids, labels, loss_mask have size [batch size, seqlen] - # tokens = tokens[:, :args.curriculum_seqlen].contiguous() - # position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() - # if labels is not None: - # labels = labels[:, :args.curriculum_seqlen].contiguous() - # loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()\ - tokens, labels, loss_mask, attention_mask, position_ids = get_batch(None) - - return (tokens, position_ids, attention_mask), (labels, loss_mask) - - - -def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): - args = get_args() - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - - # Reduce loss for logging. - averaged_loss = average_losses_across_data_parallel_group([loss]) - - return loss, {'lm loss': averaged_loss[0]} - - -def forward_step(data_iterator, model): - """Forward step.""" - args = get_args() - timers = get_timers() - - # Get the batch. - timers('batch-generator').start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - data_iterator) - timers('batch-generator').stop() - - output_tensor = model(tokens, position_ids, attention_mask, - labels=labels) - - moe_loss = 0 - mos_loss = 0 - return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss) - - -def train_valid_test_datasets_provider(train_val_test_num_samples): - """Build train, valid, and test datasets.""" - return [1]*10000, None, None - - -def command_exists(cmd): - result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) - return result.wait() == 0 - - -def git_ds_info(): - from deepspeed.env_report import main as ds_report - ds_report() - - # Write out version/git info - git_hash_cmd = "git rev-parse --short HEAD" - git_branch_cmd = "git rev-parse --abbrev-ref HEAD" - if command_exists('git'): - try: - result = subprocess.check_output(git_hash_cmd, shell=True) - git_hash = result.decode('utf-8').strip() - result = subprocess.check_output(git_branch_cmd, shell=True) - git_branch = result.decode('utf-8').strip() - except subprocess.CalledProcessError: - git_hash = "unknown" - git_branch = "unknown" - else: - git_hash = "unknown" - git_branch = "unknown" - print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****') - - -if __name__ == "__main__": - git_ds_info() - pretrain(train_valid_test_datasets_provider, model_provider, forward_step, - args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) - mem = torch.cuda.max_memory_allocated() - for rank in range(torch.distributed.get_world_size()): - if rank == torch.distributed.get_rank(): - print(f'rank[{rank}]: memory consumption: {round(mem / 1024 / 1024 / 1024 * 100) / 100} GBs') - torch.distributed.barrier() \ No newline at end of file diff --git a/benchmark/megatron/benchmark_gpt.sh b/benchmark/megatron/benchmark_gpt.sh deleted file mode 100755 index 7e973e1c..00000000 --- a/benchmark/megatron/benchmark_gpt.sh +++ /dev/null @@ -1,69 +0,0 @@ -# setup megatron -# git clone https://github.com/NVIDIA/Megatron-LM.git -# pip install regex - -# setup apex -# git clone https://github.com/NVIDIA/apex -# cd apex -# pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . -# cd .. - -cp pretrain_gpt_synthetic.py ./Megatron-LM/ - -NODE_GPUS=8 -PP=4 -TP=4 - -GPT_ARGS="--num-layers 32 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --lr 0.00015 \ - --train-iters 10 \ - --lr-decay-iters 320000 \ - --lr-decay-style cosine \ - --lr-warmup-fraction .01 \ - --fp16 \ - --fp16-lm-cross-entropy \ - --no-query-key-layer-scaling \ - --no-masked-softmax-fusion \ - --no-bias-gelu-fusion \ - --no-bias-dropout-fusion \ - --no-async-tensor-model-parallel-allreduce \ - --no-gradient-accumulation-fusion \ - --checkpoint-activations \ - --log-interval 1 \ - --num-workers 0" - -# --checkpoint-activations - -SINGLE_NODE="--nproc_per_node $NODE_GPUS \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr localhost \ - --master_port 6000" - -MULTI_NODE="--nproc_per_node $NODE_GPUS \ - --nnodes 2 \ - --node_rank ${NODE_RANK} \ - --master_addr worker-0 \ - --master_port 6012" - - -cd Megatron-LM - -OMP_NUM_THREADS=4 python -m torch.distributed.launch $MULTI_NODE \ - pretrain_gpt_synthetic.py $GPT_ARGS \ - --global-batch-size 128 \ - --micro-batch-size 4 \ - --tensor-model-parallel-size $TP \ - --pipeline-model-parallel-size $PP \ - --DDP-impl local - - -# OMP_NUM_THREADS=4 python -m torch.distributed.launch \ -# --nproc_per_node 1 --master_addr localhost --master_port 6112 \ -# pretrain_gpt_synthetic.py -h - -cd .. \ No newline at end of file diff --git a/benchmark/megatron/pretrain_gpt_synthetic.py b/benchmark/megatron/pretrain_gpt_synthetic.py deleted file mode 100644 index 48345eff..00000000 --- a/benchmark/megatron/pretrain_gpt_synthetic.py +++ /dev/null @@ -1,112 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Pretrain GPT""" - -import torch -from functools import partial -from megatron import get_args -from megatron import print_rank_0 -from megatron import get_timers -from megatron import get_tokenizer -from megatron import mpu -from megatron.data.gpt_dataset import build_train_valid_test_datasets -from megatron.model import GPTModel, ModelType -from megatron.training import pretrain -from megatron.utils import get_ltor_masks_and_position_ids -from megatron.utils import average_losses_across_data_parallel_group - -def model_provider(pre_process=True, post_process=True): - """Build the model.""" - args = get_args() - - vocab_size = 50257 - after = vocab_size - multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size - while after % multiple != 0: - after += 1 - args.padded_vocab_size = after - - print_rank_0('building GPT model ...') - model = GPTModel( - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process - ) - return model - - -def get_batch(data_iterator): - """Generate a batch""" - args = get_args() - vocab_size = 50257 - tokens = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size - labels = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * vocab_size - loss_mask = torch.ones(tokens.size(), dtype=torch.float, device=torch.cuda.current_device()) - attention_mask = torch.tril(torch.ones( - (args.micro_batch_size, args.seq_length, args.seq_length), device=torch.cuda.current_device() - )).view(args.micro_batch_size, 1, args.seq_length, args.seq_length) - attention_mask = (attention_mask < 0.5) - position_ids = torch.rand((args.micro_batch_size, args.seq_length), requires_grad=False, device=torch.cuda.current_device()).long() * args.seq_length - - return tokens, labels, loss_mask, attention_mask, position_ids - -def loss_func(loss_mask, output_tensor): - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - - # Reduce loss for logging. - averaged_loss = average_losses_across_data_parallel_group([loss]) - - return loss, {'lm loss': averaged_loss[0]} - - -def forward_step(data_iterator, model): - """Forward step.""" - args = get_args() - timers = get_timers() - - # Get the batch. - timers('batch-generator').start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - data_iterator) - timers('batch-generator').stop() - - output_tensor = model(tokens, position_ids, attention_mask, - labels=labels) - - return output_tensor, partial(loss_func, loss_mask) - - -def train_valid_test_datasets_provider(train_val_test_num_samples): - """Build train, valid, and test datasets.""" - return [1]*10000, None, None - - -if __name__ == "__main__": - - - - pretrain(train_valid_test_datasets_provider, model_provider, - ModelType.encoder_or_decoder, - forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) - - mem = torch.cuda.max_memory_allocated() - for rank in range(torch.distributed.get_world_size()): - if rank == torch.distributed.get_rank(): - print(f'rank[{rank}]: memory consumption: {round(mem / 1024 / 1024 / 1024 * 100) / 100} GBs') - torch.distributed.barrier() diff --git a/scripts/aggregate.sh b/scripts/aggregate.sh deleted file mode 100755 index a5d03ab5..00000000 --- a/scripts/aggregate.sh +++ /dev/null @@ -1,23 +0,0 @@ -# ============= ITP Variables ============ -# NODE_RANK -# MASTER_IP -# MASTER_PORT -# ============= ITP Variables ============ - -node_num=$1 - -if [ ${node_num} == 4 ] -then - mkdir -p /workspace/MagicCube/eval/worker-1 - scp -r worker-1:/workspace/MagicCube/eval/ /workspace/MagicCube/eval/worker-1 - mkdir -p /workspace/MagicCube/eval/worker-2 - scp -r worker-2:/workspace/MagicCube/eval/ /workspace/MagicCube/eval/worker-2 - mkdir -p /workspace/MagicCube/eval/worker-3 - scp -r worker-3:/workspace/MagicCube/eval/ /workspace/MagicCube/eval/worker-3 -fi - -if [ ${node_num} == 2 ] -then - mkdir -p /workspace/MagicCube/eval/worker-1 - scp -r worker-1:/workspace/MagicCube/eval/ workspace/MagicCube/eval/worker-1 -fi diff --git a/scripts/env-setup.sh b/scripts/env-setup.sh deleted file mode 100755 index c9ecf29a..00000000 --- a/scripts/env-setup.sh +++ /dev/null @@ -1,49 +0,0 @@ - -echo using docker image nvcr.io/pytorch:pytorch-21.12-py3 - -git config --global core.editor "vim" -git config --global user.name "Zhiqi Lin" -git config --global user.email "v-zhiql@microsoft.com" - -sudo git config --global core.editor "vim" -sudo git config --global user.name "Zhiqi Lin" -sudo git config --global user.email "v-zhiql@microsoft.com" -sudo chmod -R a+w /opt/conda - -sudo apt-get update -sudo apt-get install htop -y -sudo apt-get install tmux -y -sudo apt-get install psmisc -y -sudo apt-get install lsof -y -sudo apt-get install infiniband-diags -y # ibstatus => check ib link -sudo apt-get install net-tools -y # ifconfig - -# install blob -# sudo apt-get install lsb-release -y -# wget https://packages.microsoft.com/config/ubuntu/20.04/packages-microsoft-prod.deb -# sudo dpkg -i packages-microsoft-prod.deb -# sudo apt-get update -# sudo apt-get install blobfuse -y -# sudo rm packages-microsoft-prod.deb - -# install azcopy -# wget https://azcopyvnext.azureedge.net/release20210616/azcopy_linux_amd64_10.11.0.tar.gz -O azcopy.tar.gz -# tar -zxvf azcopy.tar.gz -# sudo mv azcopy_linux_amd64_10.11.0/azcopy /usr/bin/ -# rm -rf azcopy_linux_amd64_10.11.0 azcopy.tar.gz - -wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.tmux.conf -O ~/.tmux.conf -wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/.vimrc -O ~/.vimrc - -echo 'export PATH=/opt/conda/bin:$PATH' >> ~/.bashrc -echo 'export PATH=/usr/local/cuda/bin:$PATH' >> ~/.bashrc -echo 'export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH' >> ~/.bashrc -echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc - -# cmd for count code lines -# find cube/ -name "*.py" -print0 | xargs -0 wc -l - -# training_daemon will disable torch.jit.script -pip uninstall training_daemon -y -python setup.py develop -pip install -r requirements.txt diff --git a/scripts/sync.sh b/scripts/sync.sh deleted file mode 100755 index 928958d5..00000000 --- a/scripts/sync.sh +++ /dev/null @@ -1,40 +0,0 @@ -# ============= ITP Variables ============ -# NODE_RANK -# MASTER_IP -# MASTER_PORT -# ============= ITP Variables ============ - -node_num=$1 -folder=$2 - -host=worker - -if [ ${node_num} == 8 ] -then - scp -r /workspace/MagicCube/$folder $host-1:/workspace/MagicCube/ - scp -r /workspace/MagicCube/$folder $host-2:/workspace/MagicCube/ - scp -r /workspace/MagicCube/$folder $host-3:/workspace/MagicCube/ - scp -r /workspace/MagicCube/$folder $host-4:/workspace/MagicCube/ - scp -r /workspace/MagicCube/$folder $host-5:/workspace/MagicCube/ - scp -r /workspace/MagicCube/$folder $host-6:/workspace/MagicCube/ - scp -r /workspace/MagicCube/$folder $host-7:/workspace/MagicCube/ -fi - -if [ ${node_num} == 4 ] -then - scp -r /workspace/MagicCube/$folder $host-1:/workspace/MagicCube/ - scp -r /workspace/MagicCube/$folder $host-2:/workspace/MagicCube/ - scp -r /workspace/MagicCube/$folder $host-3:/workspace/MagicCube/ -fi - -if [ ${node_num} == 2 ] -then - scp -r /workspace/MagicCube/$folder $host-1:/workspace/MagicCube/ -fi - - -# rm -f notify.py -# wget https://raw.githubusercontent.com/zhiqi-0/EnvDeployment/master/email/notify.py -# python notify.py --sender zhiqi.0@qq.com --code uyakwgslumknbfgg --recver zhiqi.0@outlook.com \ -# --msg "Test Results Swin Coshard | 32 GPU" \ -# --file logs/e2e-swin-32gpu-coshard-${NODE_RANK}.txt \ No newline at end of file diff --git a/tests/codegen/test_scale.py b/tests/codegen/test_scale.py deleted file mode 100644 index 93dbf19b..00000000 --- a/tests/codegen/test_scale.py +++ /dev/null @@ -1,202 +0,0 @@ - - -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - tests/codegen/test_scale.py - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=1 \ - tests/codegen/test_scale.py -""" - -from typing import List -import torch -from torch import nn - -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - - -cube.init() - - -class MLP(nn.Module): - def __init__(self, dim, mult=1, nlayers=16): - super().__init__() - self.layers = torch.nn.ModuleList([]) - for lid in range(nlayers): - if lid % 2 == 0: - self.layers.append(nn.Linear(dim, dim * mult, bias=False)) - else: - self.layers.append(nn.Linear(dim * mult, dim, bias=False)) - - def forward(self, data): - x = data - for layer in self.layers: - x = layer(x) - loss = torch.sum(x) - return loss - - -class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, bs: int, dim: int): - super().__init__(bs, [0]) - self.sample = None - self.dim = dim - self.set_batch_size(bs) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - self.batch_size = batch_size - self.sample = torch.rand( - [batch_size, self.dim], dtype=torch.float32, - device=torch.cuda.current_device() - ) - - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], - idx: int, dim: int, tag='dim'): - algo = node.algorithms(tag) - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def _run(train_iter, model, dataloader, optimizer): - iter_num, warmup = 5, 2 - for step in range(iter_num): - if step >= warmup: - CudaTimer(enable=True).start('e2e') - loss = train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - # model.zero_grad() - # model.gather_params() - if step >= warmup: - CudaTimer().stop('e2e') - print_each_rank(f'loss: {loss.item()}', rank_only=0) - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) - - -def test_scale_full_dp(): - - model = MLP(dim=4096) - dataloader = MLPDataLoader(bs=8, dim=4096) - - def policy(graph: IRGraph, resource): - assert resource.ngpus > 2 - ngpus = 2 - for dl in graph.select(ntype=IRDataOperation): - _replica(graph, dl, list(range(ngpus))) - for node in graph.select(ntype=IRFwOperation): - if node.name == 'linear': - _tp(graph, node, list(range(ngpus)), idx=0, dim=0, tag='dim') - else: - _replica(graph, node, list(range(ngpus))) - return graph - - @cube.compile(model, dataloader, PAS=policy, scale=True) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - model = cube.load_model() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - _run(train_iter, model, dataloader, optimizer) - - -def test_scale_partial_dp(): - - model = MLP(dim=4096) - dataloader = MLPDataLoader(bs=8, dim=4096) - - def policy(graph: IRGraph, resource): - assert resource.ngpus > 2 - ngpus = 2 - for dl in graph.select(ntype=IRDataOperation): - _replica(graph, dl, list(range(ngpus))) - for idx, node in enumerate(graph.select(ntype=IRFwOperation)): - if node.name == 'linear': - if idx % 4 == 0: - _tp(graph, node, list(range(ngpus)), idx=0, dim=0, tag='dim') - if idx % 4 == 1: # partition weight, partition input (reduction) - _tp(graph, node, list(range(ngpus)), idx=0, dim=1, tag='dim') - if idx % 4 == 2: # partition weight, replicate input - _tp(graph, node, list(range(ngpus)), idx=1, dim=0, tag='dim') - if idx % 4 == 3: # replicate - _replica(graph, node, list(range(ngpus))) - else: - _replica(graph, node, list(range(ngpus))) - return graph - - @cube.compile(model, dataloader, PAS=policy, scale=True) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - model = cube.load_model() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - _run(train_iter, model, dataloader, optimizer) - - -def test_scale_no_dp(): - - model = MLP(dim=4096) - dataloader = MLPDataLoader(bs=8, dim=4096) - - def policy(graph: IRGraph, resource): - assert resource.ngpus > 2 - ngpus = 2 - for dl in graph.select(ntype=IRDataOperation): - _replica(graph, dl, list(range(ngpus))) - for node in graph.select(ntype=IRFwOperation): - _replica(graph, node, list(range(ngpus))) - return graph - - @cube.compile(model, dataloader, PAS=policy, scale=True) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - model = cube.load_model() - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - _run(train_iter, model, dataloader, optimizer) - - -if __name__ == '__main__': - - # test_scale_full_dp() - # test_scale_partial_dp() - test_scale_no_dp() \ No newline at end of file diff --git a/tests/gpt_memory_profile.md b/tests/gpt_memory_profile.md deleted file mode 100644 index c828c185..00000000 --- a/tests/gpt_memory_profile.md +++ /dev/null @@ -1,9 +0,0 @@ -# GPT-3 toy model memory profiling result - -| layer | end2end | activation | param | e2e - 3 * p - activation | -|:------|:--------|:-----------|:------|:-------------------------| -| 1 | 1.59 | 0.47 | 0.24 | 0.40 | -| 2 | 1.98 | 0.73 | 0.29 | 0.38 | -| 4 | 2.78 | 1.24 | 0.38 | 0.40 | -| 8 | 4.37 | 2.26 | 0.57 | 0.40 | -| 16 | 7.55 | 4.30 | 0.95 | 0.40 | \ No newline at end of file diff --git a/tests/graph/test_dump_load.py b/tests/graph/test_dump_load.py deleted file mode 100644 index df036df3..00000000 --- a/tests/graph/test_dump_load.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -USE_TORCHFX=1 torchrun --nproc_per_node=2 tests/graph/test_dump_load.py -""" -from typing import List -import torch -from cube.ir.cten import IRObject - -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops - - -cube.init() - - -def _param(size, dtype=torch.float32): - return torch.nn.Parameter(torch.empty(size, dtype=dtype)) - - -class TestOpModule(torch.nn.Module): - - def __init__(self): - super().__init__() - self.param1 = _param([512, 256]) - self.param2 = _param([512, 256]) - self.ints = [1, 2, 3] - - def forward(self, x: torch.Tensor): - # matmul: [bs, 512], [512, 256] -> [bs, 256] - x1 = torch.matmul(x, self.param1) - # [bs, 256] -> [bs, 256] - x1 = x1 + x1.size(0) + x1.size()[0] - # [bs, 256] -> [bs, 128], [bs, 128] - x2 = torch.chunk(x1, 2, dim=1)[0] - # [bs, 128] -> [bs, 128] - x3 = x2 + x2.size(0) - x4 = x3 + self.ints[0] - # [bs, 128] -> [1] - loss = torch.sum(x4) - return {'x': x4, 'loss': loss} # , [x3,] - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int = 256) -> None: - self.sample = torch.rand( - [batch_size, 512], - dtype=torch.float32, - device=torch.cuda.current_device() - ) - super().__init__(batch_size, (0,)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def test_graph_dump_load_single(): - - model = TestOpModule() - dataloader = TestDataLoader() - - def policy(graph: IRGraph, resource): - print('================ original one:') - print(graph.extra_repr()) - - graph.dump('graph.pickle') - new_graph = IRGraph.load('graph.pickle') - - print('================ loaded from pickled one:') - print(graph.extra_repr()) - - for node in graph.nodes(): - for t in node.inputs(): - if isinstance(t, IRObject): - assert t.cell is not None - - assert graph.extra_repr() == new_graph.extra_repr() - - assert resource.ngpus == 1 - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if node.name == 'add': - sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=0, num=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return graph - - - @cube.compile(model, dataloader, PAS=policy, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def train_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out['loss'].backward() - - model = cube.load_model(load_content=False) - - for idx in range(3): - train_iter(model, dataloader) - print(f"iter {idx}/3") - - -def test_graph_dump_load_with_transform(): - - model = TestOpModule() - dataloader = TestDataLoader() - - def policy(graph: IRGraph, resource): - print('================ original one:') - print(graph.extra_repr()) - old_repr = graph.extra_repr() - - graph.dump('graph.pickle') - graph = IRGraph.load('graph.pickle') - - print('================ loaded from pickled one:') - print(graph.extra_repr()) - new_repr = graph.extra_repr() - - for node in graph.nodes(): - for t in node.inputs(): - if isinstance(t, IRObject): - assert t.cell is not None - - assert new_repr == old_repr - - assert resource.ngpus == 2 - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - - @cube.compile(model, dataloader, PAS=policy, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def train_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out['loss'].backward() - - model = cube.load_model(load_content=False) - - for idx in range(3): - train_iter(model, dataloader) - print(f"iter {idx}/3") - - -if __name__ == '__main__': - # test_graph_dump_load_single() - test_graph_dump_load_with_transform() diff --git a/tests/graph/test_fusion.py b/tests/graph/test_fusion.py deleted file mode 100644 index 6b201f41..00000000 --- a/tests/graph/test_fusion.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/graph/test_fusion.py -""" -from typing import List -import torch - -import cube -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops - -cube.init() - - -def _param(size, dtype=torch.float32): - return torch.nn.Parameter(torch.empty(size, dtype=dtype)) - - -class TestModuleForFusedOp(torch.nn.Module): - - def __init__(self): - super().__init__() - self.param1 = _param([512, 256]) - self.param2 = _param([512, 256]) - self.ints = [1, 2, 3] - - def forward(self, x: torch.Tensor): - # matmul: [bs, 512], [512, 256] -> [bs, 256] - x1 = torch.matmul(x, self.param1) - # [bs, 256] -> [bs, 256] - x2 = x1.clone() - x3 = x2 + 1 - loss = torch.sum(x3) - return {'x': x3, 'loss': loss} # , [x3,] - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int = 256) -> None: - self.sample = torch.rand( - [batch_size, 512], - dtype=torch.float32, - device=torch.cuda.current_device() - ) - super().__init__(batch_size, (0,)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - - -def test_fused_op(): - - model = TestModuleForFusedOp() - dataloader = TestDataLoader() - - def policy(graph, resource): - assert resource.ngpus == 1 - print(graph.extra_repr()) - - clone = graph.select(name='clone')[0] - idx = graph.index(clone) - clonse_add = [clone, graph.node(idx+1)] - graph.fuse(clonse_add) - - for node in graph.nodes(): - if isinstance(node, IRDimops): - print(f'# {node.anno}') - print(node) - elif isinstance(node, (IRFwOperation, IRDataOperation)): - print(node) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - model = cube.SemanticModel(model) - - @cube.compile(model, dataloader, PAS=policy, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def eval_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out['loss'].backward() - # return out - - model = model.get_gen_module() - - for idx in range(3): - eval_iter(model, dataloader) - print(f"iter {idx}/3") - - -if __name__ == '__main__': - test_fused_op() diff --git a/tests/graph/test_infer_grad.py b/tests/graph/test_infer_grad.py deleted file mode 100644 index f185dd36..00000000 --- a/tests/graph/test_infer_grad.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/graph/test_infer_grad.py -USE_TORCHFX=1 torchrun --nproc_per_node=2 tests/graph/test_infer_grad.py -""" -from typing import List -import torch - -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation - -cube.init() - - -def _param(size, dtype=torch.float32): - return torch.nn.Parameter(torch.empty(size, dtype=dtype)) - -def _rand(size, dtype=torch.float32): - return torch.rand(size, dtype=dtype, device=torch.cuda.current_device()) - - -class TestOpModule(torch.nn.Module): - - def __init__(self): - super().__init__() - self.param1 = _param([256, 512]) - self.param2 = _param([256, 512]) - self.param3 = _param([256, 512]) - - def forward(self, x: torch.Tensor): - x1 = x * self.param1 - x2 = x1 * self.param2 # no grad - - cube.runtime.function.anchor('residual') - x3 = x1 + 2 - x4 = x3 * self.param3 - - loss = torch.sum(x4) - return {'intermediate': [x3, x2], 'loss': loss}, loss.data - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self) -> None: - self.sample = _rand([256, 512]) - super().__init__(256, (0,)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def policy_test_single_device(graph: IRGraph, resource): - print(graph.extra_repr()) - for idx, node in enumerate(graph.select(name='mul')): - if idx == 1: - assert node.mirror is None - for t in node.inputs() + node.outputs(): - assert t.grad is None - elif idx == 2: - assert node.mirror is not None - for t in node.inputs() + node.outputs(): - assert t.grad is not None - for node in graph.select(ntype=(IRDataOperation, IRFwOperation)): - graph.assign(node, 0) - return graph - - -def policy_test_multi_device(graph: IRGraph, resource): - # multiref - for ftensor in graph.full_tensors(): - if ftensor.is_attr(): continue - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) - - print(graph.extra_repr()) - assert resource.ngpus == 2 - for idx, node in enumerate(graph.select(ntype=(IRFwOperation, IRDataOperation))): - devid = 0 if idx < 4 else 1 - graph.assign(node, devid) - print(graph.extra_repr()) - return graph - - -def test_single_no_backward_ops(): - - model = TestOpModule() - dataloader = TestDataLoader() - - @cube.compile(model, dataloader, PAS=policy_test_single_device, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def train_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out[0]['loss'].backward() - return out - - model = cube.load_model(load_content=False) - - for idx in range(3): - train_iter(model, dataloader) - print(f"single device: iter {idx}/3") - - -def test_multidev_residual(): - - model = TestOpModule() - dataloader = TestDataLoader() - - @cube.compile(model, dataloader, PAS=policy_test_multi_device, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def train_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out[0]['loss'].backward() - return out - - model = cube.load_model(load_content=False) - - for idx in range(3): - train_iter(model, dataloader) - print(f"multi device: iter {idx}/3") - - -if __name__ == '__main__': - if torch.distributed.get_world_size() == 1: - test_single_no_backward_ops() - if torch.distributed.get_world_size() == 2: - test_multidev_residual() diff --git a/tests/graph/test_multiref.py b/tests/graph/test_multiref.py deleted file mode 100644 index 76e384a1..00000000 --- a/tests/graph/test_multiref.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -torchrun --nproc_per_node=2 tests/graph/test_multiref.py -""" -import torch - -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops - -cube.init() - - -def _param(shape, dtype=torch.float32): - return torch.nn.Parameter(torch.empty(shape, dtype=dtype)) - - -class TestOpModule(torch.nn.Module): - - def __init__(self, shape=[256, 512]): - super().__init__() - self.param = _param(shape) - - def forward(self, x: torch.Tensor, y: torch.Tensor): - x = x * self.param - x = torch.sum(x) - - y = y * self.param - y = torch.sum(y) - - loss = x + y - return loss - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int = 256) -> None: - self.sample = ( - torch.rand([batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()), - torch.rand([batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()), - ) - super().__init__(batch_size, (0, 0)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def _tp(graph, node, devs, idx, dim): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - for node, devid in zip(sub_nodes, devs): - graph.assign(node, devid) - return sub_nodes - - -def _replica(graph, node, devs): - rnodes = graph.replicate(node, times=len(devs)) - for rnode, devid in zip(rnodes, devs): - graph.assign(rnode, devid) - return rnodes - - -def test_multiref_param(): - - cube.init() - - model = TestOpModule() - dataloader = TestDataLoader() - - def policy(graph: IRGraph, resource): - - # multiref - for t in graph.full_tensors(): - if len(graph.consumers(t)) > 1: - graph.multiref(t) - - devs = list(range(resource.ngpus)) - - muls = graph.select(name='mul') - _tp(graph, muls[0], devs, idx=1, dim=0) - _tp(graph, muls[1], devs, idx=1, dim=1) - - for node in graph.select(ntype=(IRDataOperation, IRFwOperation)): - if node.name == 'multiref': continue - if node.name == 'mul': continue - _replica(graph, node, devs) - - return graph - - sample_x, sample_y = next(dataloader) - - @cube.compile(model, dataloader, PAS=policy, load_content=True, - model_dummy_inputs={'x': sample_x, 'y': sample_y}) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(*data) - loss.backward() - - model = cube.load_model() - - for idx in range(3): - train_iter(model, dataloader) - print(f"iter {idx}/3") - print('Done') - - -if __name__ == '__main__': - test_multiref_param() - exit(0) diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py deleted file mode 100644 index be8b245d..00000000 --- a/tests/graph/test_segment.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -PYTHONPATH=.:$PYTHONPATH torchrun --nproc_per_node=1 \ - tests/graph/test_segment.py -""" -import torch - -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops - -cube.init() - - -def _param(shape, dtype=torch.float32): - return torch.nn.Parameter(torch.empty(shape, dtype=dtype)) - - -class TestOpModule(torch.nn.Module): - - def __init__(self, shape=[256, 512]): - super().__init__() - self.param = _param(shape) - - def forward(self, x: torch.Tensor, y: int): - x = x * self.param - x = x + y - loss = torch.sum(x) - return loss - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int = 256) -> None: - self.sample = ( - torch.rand([batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()), - 4, - ) - super().__init__(batch_size, (0, None)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def test_segment_creation(): - - cube.init() - - model = TestOpModule() - dataloader = TestDataLoader() - - def policy(graph: IRGraph, resource): - assert resource.ngpus == 1 - fwops = graph.select(ntype=IRFwOperation) - graph.staging([fwops[0]]) - print(graph.extra_repr()) - for node in fwops: - graph.assign(node, 0) - for dl in graph.select(ntype=IRDataOperation): - graph.assign(dl, 0) - return graph - - sample_x, sample_y = next(dataloader) - - @cube.compile(model, dataloader, PAS=policy, load_content=True, - model_dummy_inputs={'x': sample_x, 'y': sample_y}) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(*data) - loss.backward() - - model = cube.load_model() - - for idx in range(3): - train_iter(model, dataloader) - print(f"iter {idx}/3") - print('Done') - - -if __name__ == '__main__': - test_segment_creation() diff --git a/tests/parser/test_bloom.py b/tests/parser/test_bloom.py deleted file mode 100644 index 934fc1cd..00000000 --- a/tests/parser/test_bloom.py +++ /dev/null @@ -1,180 +0,0 @@ -from pathlib import Path -import torch -import torch.nn as nn -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -model_name = "bigscience/bloom-560m" -model_path = "/home/quzha/bloom560m" - -print("Loading model...") -model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, cache_dir=model_path) -print(type(model), '; is nn.Module? ', isinstance(model, nn.Module)) -print("Model's generation config which does not list default values: ", model.generation_config) -print("Loading tokenizer...") -tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) -print("Loading Done!") -prompt = "If I want to travel to a new city, I should plan my trip as follows:" -#input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() -inputs = tokenizer(prompt, return_tensors="pt") - -# Cube -# from cube.graph import parser -# ir_graph = parser.convert_model(model, input_shapes=[1, 17], save_content=False) - -print("concrete tracing model...") -from nni.common.concrete_trace_utils import concrete_trace -traced_graph = concrete_trace(model, inputs, use_operator_patch=True, - autowrap_leaf_class={torch.finfo: ((), False)}) -print("tracing model done.") - -print("parsing fx graph to cube graph...") -from cube.graph.parser import FxModuleParser -inputs, nodes, outputs = FxModuleParser.parse(traced_graph, dummy_inputs=inputs) -print("parsing done.") -from cube.graph import IRGraph -module_name = model.__class__.__name__ -cube_graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) - -# AutoDist -# # profile communication cost -# import os -# comm_gpu_num = (2, 4) -# for gpu_num in comm_gpu_num: -# os.system(f'torchrun --nproc_per_node={gpu_num} /home/quzha/AutoDist/comm_profile.py --connect_type=NV') -# profile computation cost -class dotdict(dict): - """dot.notation access to dictionary attributes""" - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ -config = dotdict({'profile_dir': str(Path.home())+'/.autodist/', 'task_name': 'bloom'}) -config.autodist_config = dotdict({'ngpus': 2}) -# NOTE add SINGLE_DEV_MODE=1 before the running command -from autodist.cost_model.cost_database import CostDatabase -cost_database = CostDatabase(cube_graph, config) -# find the best partition plan -from autodist.task_config import TaskConfig -class BloomTaskConfig(TaskConfig): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.model = 'Bloom' - # self.Bloom_setting = kwargs['Bloom_setting'] - # self.fine_grained_Bloom = kwargs['fine_grained_Bloom'] - # self.bloom_config = build_bloom_config(self.Bloom_setting) - self.task_name = f'bloom-{self.autodist_config.ngpus}gpu-'\ - f'{self.autodist_config.micro_batch_size}batch_size' - self.estimated_fname, self.backup_fname, self.runtime_fname = self._build_file_name( - self.task_name) - self.allow_recom_ops = [] - self.del_dim = [] -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Bloom benchmark') - parser.add_argument('--fp16', - action='store_true', - help='use fp16 for the training') - parser.add_argument('--fine_grained_GPT', - action='store_true', - help='model = GPTFineGrained') - parser.add_argument('--GPT_setting', - type=str, - default='6.7B', - help='set GPT model type') - parser.add_argument('--save_folder', - type=str, - default='exp_data', - help='set the save folder for experiment data') - parser.add_argument('--micro_batch_size', - type=int, - default=8, - help='set micro batch size') - parser.add_argument('--global_batch_size', - type=int, - default=8, - help='set the global batch size') - parser.add_argument('--iter_num', - type=int, - default=2, - help='set the number of all iterations') - parser.add_argument('--warm_num', - type=int, - default=1, - help='set the number of warmup iterations') - parser.add_argument('--recompute', - action='store_true', - help='set recompute flag') - parser.add_argument('--memory_constraint', - type=float, - default=32, - help='memory constraint for program') - parser.add_argument('--memory_granularity', - type=int, - default=1, - help='memory granularity in byte') - parser.add_argument('--profile_dir', - type=str, - default=str(Path.home()) + '/.autodist', - help='profile dir') - parser.add_argument('--connect_type', - type=str, - default='NV2', - help='connect type from nvidia-smi topo -m') - parser.add_argument('--use_prev_plan', - action='store_true', - help='run from previous plan') - parser.add_argument('--is_train', - action='store_true', - help='True: train, False: inference') - parser.add_argument('--topk', - type=int, - default=20, - help='generate multiple plans for robustness') - parser.add_argument('--mesh_row', type=int, default=1, help='node num') - parser.add_argument('--mesh_col', - type=int, - default=2, - help='dev num in a node') - parser.add_argument('--compile', - action='store_true', - help='compile stage: true, runtime stage: false') - parser.add_argument('--pipeline', - action='store_true', - help='pipeline: true, tensor parallel: false') - parser.add_argument('--nproc', - type=int, - default=12, - help='multiprocess deg in pipeline') - parser.add_argument('--adaptive_recom', - action='store_true', - help='allow adaptive recompute') - parser.add_argument('--plan_idx', - type=int, - default=0, - help='runtime plan idx') - parser.add_argument('--verbose', action='store_true', help='verbose mode') - parser.add_argument('--ignore_small_tensor_threshold', - type=int, - default=0, - help='set the tensor size threshold to ignore') - parser.add_argument('--parse_plan', - action='store_true', - help='parse plan to user-friendly format') - parser.add_argument('--alphafold', - action='store_true', - help='use alphafold2') - parser.add_argument('--alphafold_setting', - type=int, - default=1, - help='1: bs, s, r = 1, 128, 256'\ - '2: bs, s, r = 1, 512, 256'\ - '3: bs, s, r = 1, 512, 384') - args = parser.parse_args() - - # if args.compile: - # assert args.ignore_small_tensor_threshold >= 64, 'suggest ignore_small_tensor_threshold >= 64' - - task_config = BloomTaskConfig(**vars(args)) - from autodist.apis import calc_parallel_plan - topk_plans = calc_parallel_plan(cube_graph, cost_database, task_config) - # from autodist.apis import compile - # compile(cube_graph, None, task_config) diff --git a/tests/parser/test_compile.py b/tests/parser/test_compile.py deleted file mode 100644 index fbcdc121..00000000 --- a/tests/parser/test_compile.py +++ /dev/null @@ -1,165 +0,0 @@ -""" -USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_compile.py -""" -import torch - -import cube -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.ir.tensor import IRFullTensor -from cube.graph.function.dimops import IRDimops - - -cube.init() - - -class TestOpModule(torch.nn.Module): - - def __init__(self, shape=[256, 512]): - super().__init__() - self.param = torch.nn.Parameter(torch.empty(shape, dtype=torch.float32)) - - def forward(self, x: torch.Tensor, cache: torch.Tensor): - x = x + cache - # [256, 512], [256, 512] -> [256, 512] - x = x * self.param - # [256, 512] -> [512] - x1 = x.select(0, 6) - # [256, 512], [512] -> [256, 512] - x2 = x.select_scatter(x1, 0, 7) - # [256, 512] -> [512, 512] - x3 = x2.repeat(2, 1) - # [512, 512] -> [256, 512]: this will be parsed to 2 slice operations - x4 = x3[:256,:] - loss = x4.sum() - return loss - - -class TestDataLoader1(cube.runtime.syndata.CubeDataLoader): - - def __init__(self) -> None: - self.sample = ( - torch.rand([256, 512], dtype=torch.float32, device=torch.cuda.current_device()), - torch.rand([256, 512], dtype=torch.float32, device=torch.cuda.current_device()), - ) - batch_size = self.sample[0][0] - super().__init__(batch_size, (0, 0)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -class TestDataLoader2(cube.runtime.syndata.CubeDataLoader): - - def __init__(self) -> None: - self.sample = torch.rand( - [256, 512], dtype=torch.float32, device=torch.cuda.current_device()) - batch_size = self.sample[0] - super().__init__(batch_size, (0,)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -model = TestOpModule() -dataloader1 = TestDataLoader1() -dataloader2 = TestDataLoader2() - - -def graph_check(graph): - for t in graph.inputs(): - assert not isinstance(t, IRFullTensor) - for node in graph.nodes(): - for t in node.inputs() + node.outputs(): - assert not isinstance(t, IRFullTensor) - for t in graph.outputs(): - assert not isinstance(t, IRFullTensor) - - -def policy(graph, resource): - graph_check(graph) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - -def test_compile_with_dataloader(): - global model - - sample, cache = next(dataloader1) - - @cube.compile(model, dataloader1, PAS=policy, - model_dummy_inputs={'x': sample, 'cache': cache}) - def train_step(model, dataloader): - data = next(dataloader) - print(data) - loss = model(*data) - loss.backward() - - gmodel = cube.load_model() - - for step in range(4): - train_step(gmodel, dataloader1) - print(f'step [{step}/4]') - - -def test_compile_without_dataloader(): - global model - - dummy_args = next(dataloader1) - sample, cache = dummy_args - - @cube.compile(model, sample, cache, PAS=policy, - model_dummy_inputs={'x': sample, 'cache': cache}) - def train_step(model, x, cache): - loss = model(x, cache) - loss.backward() - - gmodel = cube.load_model() - - for step in range(4): - x, cache = next(dataloader1) - train_step(gmodel, x, cache) - print(f'step [{step}/4]') - - - -def test_compile_with_complex(): - global model - - sample = next(dataloader2) - cache = torch.rand([256, 512], dtype=torch.float32, device=torch.cuda.current_device()) - - # @cube.compile(model, dataloader2, cache, PAS=policy) - # print(sample.size(), cache.size()) - - @cube.compile(model, dataloader2, cache, PAS=policy, - model_dummy_inputs={'x': sample, 'cache': cache}) - def train_step(model, dataloader, cache): - sample = next(dataloader) - loss = model(sample, cache) - loss.backward() - - gmodel = cube.load_model() - - for step in range(4): - train_step(gmodel, dataloader2, step) - print(f'step [{step}/4]') - - - -if __name__ == '__main__': - test_compile_with_dataloader() - test_compile_without_dataloader() - test_compile_with_complex() \ No newline at end of file diff --git a/tests/parser/test_cus_autograd.py b/tests/parser/test_cus_autograd.py deleted file mode 100644 index d2570a65..00000000 --- a/tests/parser/test_cus_autograd.py +++ /dev/null @@ -1,111 +0,0 @@ -""" -USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_cus_autograd.py -""" -import torch - -import cube -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops -from cube.graph.parser import register - -cube.init() - - -class GeLU(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input: torch.Tensor, bias: torch.Tensor): - ctx.save_for_backward(input, bias) - return GeLU.bias_gelu(bias, input) - - @staticmethod - def backward(ctx, grad_output): - input, bias = ctx.saved_tensors - tmp = GeLU.bias_gelu_back(grad_output, bias, input) - return tmp, tmp - - @staticmethod - def bias_gelu(bias, y): - x = bias + y - return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - - @staticmethod - def bias_gelu_back(g, bias, y): - x = bias + y - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff*g - - -class TestModel(torch.nn.Module): - - def __init__(self): - super().__init__() - self.fc = torch.nn.Linear(512, 10) - self.bias = torch.nn.Parameter(torch.rand(10)) - - def forward(self, x: torch.Tensor): - res = GeLU.apply(self.fc(x), self.bias) - loss = res.sum() - return {'res': res, 'loss': loss} - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int = 256) -> None: - self.sample = torch.rand( - [batch_size, 512], - dtype=torch.float32, - device=torch.cuda.current_device() - ) - super().__init__(batch_size, (0,)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def test_cus_autograd(): - register('* h, h -> * h')(GeLU.apply) - - model = TestModel() - dataloader = TestDataLoader() - - def policy(graph, resource): - print(graph.extra_repr()) - assert resource.ngpus == 1 - for node in graph.nodes(): - if isinstance(node, IRDimops): - print(f'# {node.anno}') - print(node) - elif isinstance(node, (IRFwOperation, IRDataOperation)): - print(node) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - model = cube.SemanticModel(model) - - @cube.compile(model, dataloader, PAS=policy, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def eval_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out['loss'].backward() - # return out - - model = model.get_gen_module() - - for idx in range(3): - eval_iter(model, dataloader) - print(f"iter {idx}/3") - - -if __name__ == '__main__': - test_cus_autograd() diff --git a/tests/parser/test_fx_ops.py b/tests/parser/test_fx_ops.py deleted file mode 100644 index 41f5827d..00000000 --- a/tests/parser/test_fx_ops.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_fx_ops.py -""" -from typing import List -import torch - -import cube -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops - -cube.init() - - -@cube.graph.parser.register('a b -> a b', name='test_op1') -def test_op1(a: torch.Tensor): - return a.clone() - - -@cube.graph.parser.register('a b -> a b', name='test_op2', - input_type_annos=[torch.Tensor, int]) -def test_op2(a, b): - return a + b - - -def _param(size, dtype=torch.float32): - return torch.nn.Parameter(torch.empty(size, dtype=dtype)) - - -class TestOpModule(torch.nn.Module): - - def __init__(self): - super().__init__() - self.param1 = _param([512, 256]) - self.param2 = _param([512, 256]) - self.ints = [1, 2, 3] - - def forward(self, x: torch.Tensor): - # matmul: [bs, 512], [512, 256] -> [bs, 256] - x1 = torch.matmul(x, self.param1) - # [bs, 256] -> [bs, 256] - x1 = x1 + x1.size(0) + x1.size()[0] - # [bs, 256] -> [bs, 128], [bs, 128] - x2 = torch.chunk(x1, 2, dim=1)[0] - # [bs, 128] -> [bs, 128] - x3 = x2 + x2.size(0) - x4 = x3 + self.ints[0] - # [bs, 128] -> [1] - loss = torch.sum(x4) - return {'x': x4, 'loss': loss} # , [x3,] - - -class TestOpModuleForCustomizeOp(torch.nn.Module): - - def __init__(self): - super().__init__() - self.param1 = _param([512, 256]) - self.param2 = _param([512, 256]) - self.ints = [1, 2, 3] - - def forward(self, x: torch.Tensor): - # matmul: [bs, 512], [512, 256] -> [bs, 256] - x1 = torch.matmul(x, self.param1) - # [bs, 256] -> [bs, 256] - x2 = test_op1(x1) - x3 = test_op2(x2, 1) - loss = torch.sum(x3) - return {'x': x3, 'loss': loss} # , [x3,] - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int = 256) -> None: - self.sample = torch.rand( - [batch_size, 512], - dtype=torch.float32, - device=torch.cuda.current_device() - ) - super().__init__(batch_size, (0,)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def test_parse_ops(): - - model = TestOpModule() - dataloader = TestDataLoader() - - def policy(graph, resource): - print(graph.extra_repr()) - assert resource.ngpus == 1 - for node in graph.nodes(): - if isinstance(node, IRDimops): - print(f'# {node.anno}') - print(node) - elif isinstance(node, (IRFwOperation, IRDataOperation)): - print(node) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - model = cube.SemanticModel(model) - - @cube.compile(model, dataloader, PAS=policy, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def eval_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out['loss'].backward() - # return out - - model = model.get_gen_module() - - for idx in range(3): - eval_iter(model, dataloader) - print(f"iter {idx}/3") - - -def test_registered_ops(): - - model = TestOpModuleForCustomizeOp() - dataloader = TestDataLoader() - - def policy(graph, resource): - print(graph.extra_repr()) - assert resource.ngpus == 1 - for node in graph.nodes(): - if isinstance(node, IRDimops): - print(f'# {node.anno}') - print(node) - elif isinstance(node, (IRFwOperation, IRDataOperation)): - print(node) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - model = cube.SemanticModel(model) - - @cube.compile(model, dataloader, PAS=policy, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def eval_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out['loss'].backward() - # return out - - model = model.get_gen_module() - - for idx in range(3): - eval_iter(model, dataloader) - print(f"iter {idx}/3") - - -if __name__ == '__main__': - test_parse_ops() - test_registered_ops() diff --git a/tests/parser/test_fx_zip.py b/tests/parser/test_fx_zip.py deleted file mode 100644 index c81d0d48..00000000 --- a/tests/parser/test_fx_zip.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_fx_zip.py -""" -import torch - -import cube -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops - -cube.init() - -class TestModel(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.fcs = torch.nn.Sequential(torch.nn.Linear(512, 10, bias=False), torch.nn.Linear(512, 10, bias=False)) - - def forward(self, x: torch.Tensor): - result = [] - xs = x.chunk(2, dim=1) - for x, fc in zip(xs, self.fcs): - result.append(fc(x)) - res = torch.cat(result) - return {'result': res, 'loss': torch.sum(res)} - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - def __init__(self, batch_size: int = 256) -> None: - self.sample = torch.rand( - [batch_size, 1024], - dtype=torch.float32, - device=torch.cuda.current_device() - ) - super().__init__(batch_size, (0,)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def test_zip(): - model = TestModel() - dataloader = TestDataLoader() - - def policy(graph, resource): - print(graph.extra_repr()) - assert resource.ngpus == 1 - for node in graph.nodes(): - if isinstance(node, IRDimops): - print(f'# {node.anno}') - print(node) - elif isinstance(node, (IRFwOperation, IRDataOperation)): - print(node) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - model = cube.SemanticModel(model) - - @cube.compile(model, dataloader, PAS=policy, load_content=False, model_dummy_inputs={'x': next(dataloader)}) - def eval_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out['loss'].backward() - - model = model.get_gen_module() - - for idx in range(3): - eval_iter(model, dataloader) - print(f"iter {idx}/3") - - -if __name__ == '__main__': - # zip should not appear in graph - test_zip() diff --git a/tests/parser/test_jit_ops.py b/tests/parser/test_jit_ops.py deleted file mode 100644 index 7f489540..00000000 --- a/tests/parser/test_jit_ops.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -torchrun --nproc_per_node=1 tests/parser/test_jit_ops.py -""" -import torch - -import cube -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops - - -class TestOpModule(torch.nn.Module): - - def __init__(self, shape=[256, 512]): - super().__init__() - self.param = torch.nn.Parameter(torch.empty(shape, dtype=torch.float32)) - - def forward(self, x: torch.Tensor, cache: int): - x = x + cache - # [256, 512], [256, 512] -> [256, 512] - x = x * self.param - # [256, 512] -> [512] - x1 = x.select(0, 6) - # [256, 512], [512] -> [256, 512] - x2 = x.select_scatter(x1, 0, 7) - # [256, 512] -> [512, 512] - x3 = x2.repeat(2, 1) - # [512, 512] -> [256, 512]: this will be parsed to 2 slice operations - x4 = x3[:256,:] - return x4 - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self) -> None: - self.sample = ( - torch.rand([256, 512], dtype=torch.float32, device=torch.cuda.current_device()), - 4 - ) - batch_size = self.sample[0][0] - super().__init__(batch_size, (0, None)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def test_parse_ops(): - - cube.init() - - model = TestOpModule() - dataloader = TestDataLoader() - - def policy(graph, resource): - assert resource.ngpus == 1 - for node in graph.nodes(): - if isinstance(node, IRDimops): - print(f'# {node.anno}') - print(node) - elif isinstance(node, (IRFwOperation, IRDataOperation)): - print(node) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - model = cube.SemanticModel(model) - - @cube.compile(model, dataloader, policy, load_content=False) - def eval_iter(model, dataloader): - data1, data2 = next(dataloader) - out = model(data1, data2) - - model = model.get_gen_module() - - for idx in range(3): - eval_iter(model, dataloader) - print(f"iter {idx}/3") - - -if __name__ == '__main__': - test_parse_ops() - diff --git a/tests/parser/test_no_grad.py b/tests/parser/test_no_grad.py deleted file mode 100644 index 64d771ed..00000000 --- a/tests/parser/test_no_grad.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -USE_TORCHFX=1 torchrun --nproc_per_node=1 tests/parser/test_no_grad.py -""" -from typing import List -import torch - -import cube -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.dimops import IRDimops - -cube.init() - - -class TestModel(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc = torch.nn.Linear(512, 10) - - def forward(self, x: torch.Tensor): - # this no grad will be dce - with torch.no_grad(): - pass - - # this no grad will not be dce - with torch.no_grad(): - res = self.fc(x) - - return {'res': res, 'loss': res.sum()} - - -class TestDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int = 256) -> None: - self.sample = torch.rand( - [batch_size, 512], - dtype=torch.float32, - device=torch.cuda.current_device() - ) - super().__init__(batch_size, (0,)) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - return True - - -def test_no_grad(): - - model = TestModel() - dataloader = TestDataLoader() - - def policy(graph, resource): - print(graph.extra_repr()) - assert resource.ngpus == 1 - for node in graph.nodes(): - if isinstance(node, IRDimops): - print(f'# {node.anno}') - print(node) - elif isinstance(node, (IRFwOperation, IRDataOperation)): - print(node) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - model = cube.SemanticModel(model) - - @cube.compile(model, dataloader, PAS=policy, load_content=False, - model_dummy_inputs={'x': next(dataloader)}) - def eval_iter(model, dataloader): - data = next(dataloader) - out = model(data) - out['loss'].backward() - # return out - - model = model.get_gen_module() - - for idx in range(3): - eval_iter(model, dataloader) - print(f"iter {idx}/3") - - -if __name__ == '__main__': - # consecutive no_grad __enter__ __exit__ sequences will be dce - test_no_grad() diff --git a/tests/parser/test_torchscale_basic.py b/tests/parser/test_torchscale_basic.py deleted file mode 100644 index 6b91cd4d..00000000 --- a/tests/parser/test_torchscale_basic.py +++ /dev/null @@ -1,84 +0,0 @@ -# OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 --master_port=25648 tests/parser/test_torchscale_basic.py --policy PASData - -import torch -from torch import nn - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - -import examples.mlp.policy.spmd as spmd -import examples.mlp.policy.mpmd as mpmd - -import argparse - -parser = argparse.ArgumentParser(description='comm primitive') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -parser.add_argument('--local_rank', type=int, default=0) -args = parser.parse_args() - -cube.init() - -# set up policy -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - -class SimpleNLP(nn.Module): - def __init__(self): - super().__init__() - self._tensor_constant0 = 1 - self.linear = torch.nn.Linear(2, 3) - - def forward(self, src_tokens, num): - _shape_as_tensor = torch._shape_as_tensor(src_tokens) - getitem_1 = _shape_as_tensor[1] - add = 2 + getitem_1 - arange = torch.arange(add, dtype=torch.float32) - unsqueeze = arange.unsqueeze(1) - _tensor_constant0 = self._tensor_constant0 - mul = unsqueeze * _tensor_constant0 - sin = torch.sin(mul) - cos = torch.cos(mul) - cat = torch.cat([sin, cos], dim=1) - view = cat.view(add, -1) - linear = self.linear(view) - return linear - -def run(): - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([4, 16], [2],), - dtypes=(torch.int64, torch.int64,), - batch_dims=(0, 0,) - ) - - sample_input = next(dataloader) - print(f'next(dataloader) = {sample_input}') - - model = SimpleNLP() - output = model(*sample_input) - print(f'output = {output}') - - device = next(model.parameters()).device - sample_input = next(dataloader) - sample_input_cpu = tuple([input.to(device) for input in sample_input]) - model = cube.SemanticModel( - model, dummy_input=sample_input_cpu, - ) - - # @cube.compile(model, dataloader, PAS=PAS, load_content=False) - def train_iter(model, dataloader): - data = next(dataloader) - out = model(*data) - return out - - train_iter(model, dataloader) - -run() \ No newline at end of file diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py deleted file mode 100644 index 0ce9486a..00000000 --- a/tests/runtime/test_reducer.py +++ /dev/null @@ -1,218 +0,0 @@ -""" -example: - -ASYNC_REDUCER=0 USE_ZERO=1 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ - tests/runtime/test_reducer.py - -ASYNC_REDUCER=0 USE_ZERO=0 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ - tests/runtime/test_reducer.py - -ASYNC_REDUCER=1 USE_ZERO=1 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ - tests/runtime/test_reducer.py - -ASYNC_REDUCER=1 USE_ZERO=0 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ - tests/runtime/test_reducer.py -""" -from typing import List - -import torch -import random -from torch import nn - -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.profiler.timer import print_each_rank - -cube.init() - - -class MLP(nn.Module): - def __init__(self, dim, nlayers=16): - super().__init__() - self.layers = torch.nn.ModuleList([]) - for _ in range(nlayers): - self.layers.append(nn.Linear(dim, dim, bias=False)) - self.param = torch.nn.Parameter(torch.ones([1])) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - x = x * self.param # for padding test - loss = torch.sum(x) - return loss - - -class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, bs: int, dim: int): - super().__init__(bs, [0]) - self.sample = None - self.dim = dim - self.set_batch_size(bs) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - torch.random.manual_seed(0) - self.batch_size = batch_size - self.sample = torch.randn( - [batch_size, self.dim], dtype=torch.float32, - device=torch.cuda.current_device() - ) - self.sample = (self.sample - 1) * 1e3 - - -def init_model_dataloader(): - batch_size = 4 - dim = 4096 - torch.random.manual_seed(0) - random.seed(0) - model = MLP(dim=dim) - # torch.random.manual_seed(0) - dataloader = MLPDataLoader(batch_size, dim) - return model, dataloader - - -def policy(graph: IRGraph, resource): - - # tensor parallelism - def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - # replicate - def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - devs = list(range(resource.ngpus)) - for node in graph.select(ntype=IRDataOperation): - _replica(graph, node, devs) - for node in graph.select(ntype=IRFwOperation): - _tp(graph, node, devs, idx=0, dim=0, num=resource.ngpus) - # if node.name == 'linear': - # _tp(graph, node, devs, idx=0, dim=0, num=resource.ngpus) - # else: - # _replica(graph, node, devs) - return graph - - -def cal_gnorms(model): - """Calculate gradient normalization for gradients""" - gnorms = [] - for p in model.parameters(): - if p.grad is None: - continue - gnorms.append(p.grad.norm().item()) - return sum(gnorms) - - -def get_baseline(): - - model, dataloader = init_model_dataloader() - model = model.cuda() - optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - # optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - - niters = 4 - losses, gnorms = [], [] - for _ in range(niters): - loss = train_iter(model, dataloader) - gnorms.append(cal_gnorms(model)) - optimizer.step() - optimizer.zero_grad(set_to_none=True) - losses.append(loss.item()) - return losses, gnorms - - -def test_reducer(): - - losses, gnorms = get_baseline() - for idx, (loss, gnorm) in enumerate(zip(losses, gnorms)): - print_each_rank(f'baseline step [{idx}]: loss: {loss} | gnorm: {gnorm}', rank_only=0) - - model, dataloader = init_model_dataloader() - - @cube.compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - - model = cube.load_model() - optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-5) - - niters = 4 - losses, gnorms = [], [] - for idx in range(niters): - loss = train_iter(model, dataloader) - gnorms.append(cal_gnorms(model)) - optimizer.step() - optimizer.zero_grad() - model.gather_params() - losses.append(loss.item()) - - for idx, (loss, gnorm) in enumerate(zip(losses, gnorms)): - print_each_rank(f'reducer step [{idx}]: loss: {loss} | gnorm: {gnorm}', rank_only=0) - - -def test_reducer_hooks(): - - losses, gnorms = get_baseline() - for idx, (loss, gnorm) in enumerate(zip(losses, gnorms)): - print_each_rank(f'baseline step [{idx}]: loss: {loss} | gnorm: {gnorm}', rank_only=0) - - model, dataloader = init_model_dataloader() - - @cube.compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - - model = cube.load_model() - optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-5) - - for reducer in model.reducers: - pre_hook = lambda grad: grad.div_(len(reducer.ranks)) - post_hook = lambda grad: grad.mul_(len(reducer.ranks)) - reducer.register_pre_hook(pre_hook) - reducer.register_post_hook(post_hook) - - niters = 4 - losses, gnorms = [], [] - for idx in range(niters): - loss = train_iter(model, dataloader) - gnorms.append(cal_gnorms(model)) - optimizer.step() - optimizer.zero_grad() - model.gather_params() - losses.append(loss.item()) - - for idx, (loss, gnorm) in enumerate(zip(losses, gnorms)): - print_each_rank(f'reducer step [{idx}]: loss: {loss} | gnorm: {gnorm}', rank_only=0) - - -if __name__ == '__main__': - - # test_reducer() - test_reducer_hooks() \ No newline at end of file diff --git a/tests/runtime/test_runtime_collectives.py b/tests/runtime/test_runtime_collectives.py deleted file mode 100644 index c90cd773..00000000 --- a/tests/runtime/test_runtime_collectives.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -intra-primitives: - torchrun --nproc_per_node=2 tests/runtime/test_runtime_collectives.py - -inter-primitives: - torchrun --nproc_per_node=3 tests/runtime/test_runtime_collectives.py -""" - -from typing import List - -import cube -import torch - - - -cube.init() - -mydevice = torch.cuda.current_device() -myrank = torch.distributed.get_rank() -ndevices = torch.distributed.get_world_size() - - -def _get_tensor(shape: List[int], dtype: torch.dtype = torch.float32, rank=myrank) -> torch.Tensor: - global mydevice, myrank - tensor = torch.ones(shape, dtype=dtype, device=mydevice) - tensor = tensor * rank - return tensor - - -def test_runtime_move(): - assert ndevices == 2 - shape = [128, 256] - - # synchronize - tensor = _get_tensor(shape) - res = _get_tensor(shape, rank=0) - tensor = cube.runtime.adapter.move(tensor, shape, torch.float32, 0, 1) - - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - # asynchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.move(tensor, shape, torch.float32, 0, 1, async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass move') - - -def test_runtime_allreduce(): - assert ndevices == 2 - shape = [128, 256] - - # synchronize - tensor = _get_tensor(shape) - cube.runtime.adapter.all_reduce(tensor, [0, 1]) - res = _get_tensor(shape, rank=0) + _get_tensor(shape, rank=1) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - # asynchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.all_reduce(tensor, [0, 1], async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass allreduce') - - -def test_runtime_allgather(): - assert ndevices == 2 - shape = [128, 256] - - # synchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.all_gather(tensor, 0, [0, 1]) - res = torch.concat([_get_tensor(shape, rank=0), _get_tensor(shape, rank=1)], dim=0) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - # asynchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.all_gather(tensor, 0, [0, 1], async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass allgather') - - -def test_runtime_reduce_scatter(): - assert ndevices == 2 - shape = [128, 256] - - tensor = _get_tensor(shape) - res = _get_tensor(shape, rank=0) + _get_tensor(shape, rank=1) - res = res.chunk(2, dim=0)[myrank] - - # synchronize - tensor = cube.runtime.adapter.reduce_scatter(tensor, 0, [0, 1]) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - # asynchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.reduce_scatter(tensor, 0, [0, 1], async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass reduce scatter') - - -def test_runtime_all2all(): - assert ndevices == 2 - shape = [128, 256] - - tensor = _get_tensor(shape) - res = torch.concat([_get_tensor(shape, rank=0), _get_tensor(shape, rank=1)], dim=0) - res = res.chunk(2, dim=1)[myrank] - - # synchronize - tensor = cube.runtime.adapter.all_to_all(tensor, 0, 1, [0, 1]) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - # asynchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.all_to_all(tensor, 0, 1, [0, 1], async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass all2all') - - -def test_runtime_exchange(): - assert ndevices == 2 - shape = [128, 256] - - tensor = _get_tensor(shape) - res = _get_tensor(shape, rank=(myrank + 1) % 2) - - tensor = cube.runtime.adapter.exchange(tensor, [0, 1]) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.exchange(tensor, [0, 1], async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass exchange') - - -def test_runtime_rdscatter(): - assert ndevices == 3 - shape = [128, 256] - - tensor = _get_tensor(shape) - res = _get_tensor(shape, rank=0).chunk(ndevices-1, dim=0)[myrank-1] - - # synchronize - tensor = cube.runtime.adapter.rdscatter( - tensor, shape, torch.float32, dim=0, src=0, dsts=[1,2]) - if myrank > 0: - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - # synchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.rdscatter( - tensor, shape, torch.float32, dim=0, src=0, dsts=[1,2], async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - if myrank > 0: - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass rdscatter') - - -def test_runtime_rdgather(): - assert ndevices == 3 - shape = [128, 256] - - tensor = _get_tensor(shape) - res = torch.cat((_get_tensor(shape, rank=1), _get_tensor(shape, rank=2)), dim=0) - - # synchronize - tensor = cube.runtime.adapter.rdgather( - tensor, shape, torch.float32, dim=0, srcs=[1,2], dst=0) - if myrank == 0: - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - # asynchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.rdgather( - tensor, shape, torch.float32, dim=0, srcs=[1,2], dst=0, async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - if myrank == 0: - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass rdgather') - - -def test_runtime_broadcast(): - assert ndevices == 3 - shape = [128, 256] - - tensor = _get_tensor(shape) - res = _get_tensor(shape, rank=0) - - # synchronize - tensor = cube.runtime.adapter.broadcast( - tensor, shape, torch.float32, src=0, ranks=[0,1,2]) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - # asynchronize - tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.broadcast( - tensor, shape, torch.float32, src=0, ranks=[0,1,2], async_op=True) - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) - assert torch.allclose(tensor, res), f"mismatch rank{myrank}: {tensor[0,0]} vs. {res[0, 0]}" - - print(f'rank[{myrank}]: pass broadcast') - - -if __name__ == '__main__': - - if ndevices == 2: - test_runtime_move() - test_runtime_allreduce() - test_runtime_allgather() - test_runtime_reduce_scatter() - test_runtime_all2all() - test_runtime_exchange() - - if ndevices == 3: - test_runtime_rdscatter() - test_runtime_rdgather() - test_runtime_broadcast() diff --git a/tests/runtime/test_runtime_flag.py b/tests/runtime/test_runtime_flag.py deleted file mode 100644 index cba0a7ca..00000000 --- a/tests/runtime/test_runtime_flag.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -example: - -ASYNC_REDUCER=0 USE_ZERO=1 OMP_NUM_THREADS=4 torchrun --nproc_per_node=4 \ - tests/runtime/test_runtime_flag.py -""" - -from typing import List -from functools import partial - -import torch -import random -from torch import nn - -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.profiler.timer import print_each_rank -# from cube.tools.debug import DebugTool - -cube.init() - - -class MLP(nn.Module): - def __init__(self, dim, nlayers=16): - super().__init__() - self.layers = torch.nn.ModuleList([]) - for _ in range(nlayers): - self.layers.append(nn.Linear(dim, dim, bias=False)) - self.param = torch.nn.Parameter(torch.ones([1])) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - x = x * self.param # for padding test - loss = torch.sum(x) - return loss - - -class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, bs: int, dim: int): - super().__init__(bs, [0]) - self.sample = None - self.dim = dim - self.set_batch_size(bs) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - torch.random.manual_seed(0) - self.batch_size = batch_size - self.sample = torch.randn( - [batch_size, self.dim], dtype=torch.float32, - device=torch.cuda.current_device() - ) - self.sample = (self.sample - 1) * 1e3 - - -def init_model_dataloader(): - batch_size = 4 - dim = 4096 - torch.random.manual_seed(0) - random.seed(0) - model = MLP(dim=dim) - # torch.random.manual_seed(0) - dataloader = MLPDataLoader(batch_size, dim) - return model, dataloader - - -def policy(graph: IRGraph, resource): - - # tensor parallelism - def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - # replicate - def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - devs = list(range(resource.ngpus)) - for node in graph.select(ntype=IRDataOperation): - _replica(graph, node, devs) - for node in graph.select(ntype=IRFwOperation): - _tp(graph, node, devs, idx=0, dim=0, num=resource.ngpus) - # if node.name == 'linear': - # _tp(graph, node, devs, idx=0, dim=0, num=resource.ngpus) - # else: - # _replica(graph, node, devs) - return graph - - -def get_baseline(): - - model, dataloader = init_model_dataloader() - model = model.cuda() - # optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - # optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) - optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - - wsz = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - accum_steps = 4 - - niters = 4 - losses = [] - for idx in range(niters): - for _ in range(accum_steps): - loss = train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad(set_to_none=True) - losses.append(loss.item()) - - for idx, loss in enumerate(losses): - print_each_rank(f'baseline loss[{idx}]: {loss}', rank_only=0) - - -def test_runtime_accum_mode_v1(): - - model, dataloader = init_model_dataloader() - model = model.cuda() - - @cube.compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - - model = cube.load_model() - # optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) - # optimizer = torch.optim.SGD(model.parameters_for_optimizer(), lr=1e-4) - optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-5) - - accum_steps = 4 - - niters = 4 - losses = [] - for idx in range(niters): - for step in cube.accum_mode.steps(accum_steps): - # print(f'enter step {step}') - loss = train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - model.gather_params() - losses.append(loss.item()) - - for idx, loss in enumerate(losses): - print_each_rank(f'reducer loss[{idx}]: {loss}', rank_only=0) - - -def test_runtime_accum_mode_v2(): - - model, dataloader = init_model_dataloader() - model = model.cuda() - - @cube.compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - return loss - - model = cube.load_model() - # optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) - # optimizer = torch.optim.SGD(model.parameters_for_optimizer(), lr=1e-4) - optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=1e-5) - - accum_steps = 4 - - niters = 4 - losses = [] - for idx in range(niters): - for step in range(accum_steps): - # print(f'enter step {step}') - with cube.accum_mode(start=(step==0), end=(step==accum_steps-1)): - loss = train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - model.gather_params() - losses.append(loss.item()) - - for idx, loss in enumerate(losses): - print_each_rank(f'reducer loss[{idx}]: {loss}', rank_only=0) - - -if __name__ == '__main__': - - get_baseline() - test_runtime_accum_mode_v1() - # test_runtime_accum_mode_v2() \ No newline at end of file diff --git a/tests/test_codegen.py b/tests/test_codegen.py deleted file mode 100644 index a525a2b8..00000000 --- a/tests/test_codegen.py +++ /dev/null @@ -1,185 +0,0 @@ -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union -import pytest -from cube.codegen.codegen import ModelCodeGen -from cube.execplan.execplan import ExecutionPlan - -from cube.graph.graph import IRGraph, IRSegment -from cube.ir.cten import IRCell, IRTensor -from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRFullTensor, IRSubTensor - -# Override tensor naming to omit TensorID since the ID assignment is too hard to predict. -class FakeModelCodeGen(ModelCodeGen): - def tensor_naming(self, tensor: Any, prefix_attr: Optional[str] = None) -> str: - if isinstance(tensor, IRTensor): - name = tensor.name.replace(".", "_") - if prefix_attr is not None and tensor.is_param(): - name = prefix_attr + name - return name - else: - return super().tensor_naming(tensor, prefix_attr) - -def make_nodes(args_list: list, input_vars:List[str], output_vars:List[str]) \ - -> Tuple[List[IRFwOperation], List[IRTensor], List[IRTensor]]: - """ - Each element of `args_list` is in a form of: - - (RCGID:Optional[int], OutputNames:str|List[str], Signature:str, OpArg...:OpArgType...) - - If any `OpArg` is string, it's automatically mapped to a `IRTensor` with the same name. - - E.g. - ``` - [ - ("sum_res", "sum_fn", "a", IRFullTensor(name="b")) - (1, ["prod_res"], "prod_fn", "sum", "a") - ] - ``` - ... results in - ``` - sum_res = sum_fn(a, b) - def recompute(sum_res, a): - prod_res = prod_fn(sum_res, a) - return prod_res - prod_res = checkpoint(recompute, sum_res, a) - ``` - - REMARK: - `signature:str` will affect how the call is dumped, see also `cube/codegen/frontend_mapping.py`. - Generally if it's not a 'torch.some_fn' operator, it's dumped as-it-is. - """ - var_tensor_map = dict() - - def _convert(output_names:Union[str, List[str]], signature:str, op_args, rc_gid:Optional[int]): - if type(output_names) is str: - output_names = [output_names] - - op_kwargs = {} - if type(op_args[-1]) is dict: - op_args, op_kwargs = op_args[:-1], op_args[-1] - - mapped_inputs = [var_tensor_map.setdefault(arg, IRFullTensor(name=arg).tosub()) if type(arg) is str else arg for arg in op_args] - mapped_outputs = [var_tensor_map.setdefault(oname, IRFullTensor(name=oname).tosub()) for oname in output_names] - - op = IRFwOperation("not_matter_name", signature, len(mapped_inputs), len(output_names)) - for i, input in enumerate(mapped_inputs): - op.set_input(i, input) - for i, output in enumerate(mapped_outputs): - op.set_output(i, output) - op.kwargs.update(op_kwargs) - - # All devices are the same - op.device = 0 - - op.recompute = rc_gid - - return op - - def convert(args): - rc_gid = None - if type(args[0]) is int: - rc_gid, args = args[0], args[1:] - return _convert(args[0], args[1], args[2:], rc_gid) - - nodes = [convert(args) for args in args_list] - inputs = [var_tensor_map[n] for n in input_vars] - outputs = [var_tensor_map[n] for n in output_vars] - return nodes, inputs, outputs - -def gen(node_defs, invars, outvars): - nodes, inputs, outputs = make_nodes(node_defs, invars, outvars) - # REMARK - # Do not directly create a 'IRSegment' from 'nodes', instead, re-retrieve the IRSegment - # using 'graph.segment(nodes)' - # Because we rely on proper dataflow analysis when segmentation, which requires all nodes - # are properly registered/'attach'-ed into the graph. - graph = IRGraph(nodes, inputs, outputs, "module_name_not_matter") - segment = graph.segment(nodes) - assert list(segment.inputs()) == inputs - assert list(segment.outputs()) == outputs - - codegen = FakeModelCodeGen(ExecutionPlan(graph)) - code : list = codegen.emit_segment_code(segment) - return str.join("\n", code) - - -def test_codegen_segment_recompute__simple(): - code = gen([ - ("c", "add", "a", "b"), - ("d", "add", "a", "c"), - ], invars=["a","b"], outvars=["d"]) - assert code == """\ -c = add(a, b) -del b -d = add(a, c)""" - - -def test_codegen_segment_recompute_rc__simple(): - code = gen([ - ("c", "add", "a", "b"), - (1, "d", "add", "a", "c"), - ], invars=["a","b"], outvars=["d"]) - - assert code == """\ -c = add(a, b) -del b - -def recompute(a, c): - d = add(a, c) - return d - -d = ckpt.checkpoint(recompute, a, c)""" - - -def test_codegen_segment_recompute_rc__del_args(): - code = gen([ - ("c", "add", "a", "b"), - (1, "d", "add", "a", "c"), - ("e", "sub", "d", "a"), - ], invars=["a","b"], outvars=["e"]) - - assert code == """\ -c = add(a, b) -del b - -def recompute(a, c): - d = add(a, c) - return d - -d = ckpt.checkpoint(recompute, a, c) -del c -e = sub(d, a)""" - - -def test_codegen_segment_recompute_rc__multi_rc(): - code = gen([ - ("c", "add", "a", "b"), - (1, "d", "add", "a", "c"), - (1, "e", "sub", "d", "a"), - ("f", "mul", "a", "d"), - (2, "g", "div", "f", "e"), - (2, "h", "pow", "f", "g"), - ], invars=["a","b"], outvars=["g","h"]) - - assert code == """\ -c = add(a, b) -del b - -def recompute(a, c): - d = add(a, c) - del c - e = sub(d, a) - return d, e - -d, e = ckpt.checkpoint(recompute, a, c) -del c -f = mul(a, d) -del a, d - -def recompute(f, e): - g = div(f, e) - del e - h = pow(f, g) - return g, h - -g, h = ckpt.checkpoint(recompute, f, e)""" # e,f will be auto rel-ed when it returns diff --git a/tests/test_examples.sh b/tests/test_examples.sh deleted file mode 100755 index 87e366d1..00000000 --- a/tests/test_examples.sh +++ /dev/null @@ -1,122 +0,0 @@ -# NOTE: This test should run in the root directory. -# Before running this test, you should run `export PYTHONPATH=.:$PYTHONPATH` first. - -set -e - -# test torch.fx -# working path -OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:$PYTHONPATH \ - python -m torch.distributed.launch \ - --nproc_per_node=1 \ - examples/mlp/linearsfx.py --policy PASData - -# test MLP - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASSingle - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASData - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASCol - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASRow - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASHybrid - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASMegatronTP - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASOptimal - -ASYNC_COMM=1 OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/infer.py --policy PASMegatron - - -# test GPT model - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMegatronTP --fp16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASRoundRobin --fp16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/nlp/gpt/train.py --policy PAS1F1B --fp16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMegatron --fp16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMeshShard --fp16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=2 \ - --nnodes=1 \ - examples/nlp/gpt/infer.py --policy PASDP --fp16 - - -# test Swin model - -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# examples/vision/swin/train.py --policy PASData --fp16 - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/vision/swin/train.py --policy PASMegatronTP --fp16 - -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# examples/vision/swin/train.py --policy PASMegatron --fp16 - - -# test scientific model - -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# examples/poisson/sci.py -# -# OMP_NUM_THREADS=4 torchrun \ -# --nproc_per_node=1 \ -# --nnodes=1 \ -# examples/wrf/wrf2.py --policy PAS -# -# OMP_NUM_THREADS=1 torchrun \ -# --nproc_per_node=4 \ -# --nnodes=1 \ -# examples/wrf/wrf2.py --policy PAS_ALL_Y diff --git a/tests/test_execplan_grouping.py b/tests/test_execplan_grouping.py deleted file mode 100644 index 9323a6aa..00000000 --- a/tests/test_execplan_grouping.py +++ /dev/null @@ -1,66 +0,0 @@ -# run tests: -# pytest ./tests/test_execplan_grouping.py - -from typing import Dict, List - -import pytest -from cube.execplan.planpass import grouping -from cube.execplan.planpass.grouping import Grouping -from cube.ir.cten import IRCell -from cube.ir.operator import IRDataOperation, IRFwOperation - -# Stub object for 'cube.execplan.ExecPlan' -# Commonly the devices are like [0,1,2,...] -class StubExecPlan(): - def __init__(self, devices:List[int], seq:Dict[int, List[IRCell]]) -> None: - assert all(devid in seq for devid in devices) - self._devices = devices - self._seq = seq - - def devices(self): - return self._devices - def seq(self, devid:int): - return self._seq[devid] - -# With these settings, all tests here are run twice, with 'grouping._get_new...algo' returning True or False, respectively. -# And all the setting ups and the recovery of this flag happen in the background. -# -# By runninng tests in both environments, we can check the consistency of the old and new algorithms. -@pytest.fixture(params=[True, False], autouse=True) -def setup_and_cleanup(request:pytest.FixtureRequest) -> None: - flag = grouping._get_use_new_grouping_algo() - grouping._set_use_new_grouping_algo(request.param) - yield - grouping._set_use_new_grouping_algo(flag) - - -def test_grouping_forward_single_group(): - execplan = StubExecPlan([0], {0: [IRFwOperation(f"op{i}", f"sign{i}", i, i) for i in range(1, 10)] }) - # each type: Dict[DeviceIdInt, List[List[IRCell]] ] - fwgroups, bpgroups = Grouping.group(execplan) - - assert len(fwgroups) == 1 # one device - assert len(fwgroups[0]) == 1 # one group - assert all(fnode.name == f"op{i+1}" for i, fnode in enumerate(fwgroups[0][0])) - - assert len(bpgroups) == 1 - assert len(bpgroups[0]) == 1 - assert bpgroups[0][0] is None - - -def test_grouping_forward_interleaving_excluded_nodes(): - execplan = StubExecPlan([0], {0: [ - IRFwOperation(f"op{i}", f"sign{i}", i, i) if i % 2 == 0 - else IRDataOperation(i, (2,)*i) # IRDataOperation is the IRCell to exclude from the group - for i in range(1, 9) # [1,2,...,8] - ] }) - # each type: Dict[DeviceIdInt, List[List[IRCell]] ] - fwgroups, bpgroups = Grouping.group(execplan) - - assert len(fwgroups) == 1 - assert len(fwgroups[0]) == 4 - assert all(len(fwgroup) == 1 and fwgroup[0].name == f"op{i}" for fwgroup, i in zip(fwgroups[0], [2,4,6,8])) - - assert len(bpgroups) == 1 - assert len(bpgroups[0]) == 4 - assert all(bpgroup is None for bpgroup in bpgroups[0]) \ No newline at end of file diff --git a/tests/test_nccl.py b/tests/test_nccl.py deleted file mode 100644 index d3000b2f..00000000 --- a/tests/test_nccl.py +++ /dev/null @@ -1,103 +0,0 @@ - -""" - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=1 \ - tests/test_nccl.py - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - tests/test_nccl.py - -OMP_NUM_THREADS=4 python -m torch.distributed.launch \ - --nproc_per_node=8 \ - --nnodes=2 \ - --node_rank=${NODE_RANK} \ - --master_addr="${MASTER_IP}" \ - --master_port=${MASTER_PORT} \ - tests/test_nccl.py - -""" - -import torch -import time -import sys -import os -import argparse - - -def print_each_rank(msg, select=True, outfile=''): - myrank = torch.distributed.get_rank() - outfile = sys.stdout if outfile == '' else outfile - for rank in range(torch.distributed.get_world_size()): - if select: - if myrank == rank: - f = open(outfile, 'a') if outfile != sys.stdout else sys.stdout - f.write('rank [{}]: {}\n'.format(rank, msg)) - if outfile != sys.stdout: - f.close() - torch.distributed.barrier() - - -def test_nccl(size, local_rank): - msg = torch.ones((size,)).cuda() - # warm up - for _ in range(20): - out = torch.distributed.all_reduce(msg) - torch.cuda.synchronize() - # profile - tic = time.perf_counter() - for _ in range(100): - out = torch.distributed.all_reduce(msg) - torch.cuda.synchronize() - toc = time.perf_counter() - - span = (toc - tic) * 1000 / 100 # in ms - bandwidth = size / span / 1e6 # in GB/s - print_each_rank( - 'NCCL Allreduce | Msg Size: {:.0f} MB | Algo Bandwidth: {:.2f} GB/s'.format( - size / 1024 / 1024, bandwidth), - select=(local_rank==0), - ) - -def test_allgather(size, local_rank): - msg = torch.ones((size,)).cuda() - tensor_list = [torch.empty_like(msg) for _ in range(torch.distributed.get_world_size())] - - tic = time.perf_counter() - for _ in range(100): - out = torch.distributed.all_gather(tensor_list, msg) - torch.cuda.synchronize() - print_each_rank('Passed all-gather') - toc = time.perf_counter() - - -def benchmark(args): - size = args.begin - while size <= args.end: - # test_allgather(size * 1024 * 1024, args.local_rank) - test_nccl(size * 1024 * 1024, args.local_rank) # MB to B - size *= 2 - print_each_rank('test on nccl is done') - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--begin', type=int, default=4, - help='start message size in MB') - parser.add_argument('--end', type=int, default=64, - help='end message size in MB') - args = parser.parse_args() - - torch.distributed.init_process_group(backend='nccl') - print(f'{torch.distributed.get_rank()} launches') - - args.local_rank = int(os.environ.get('LOCAL_RANK')) - torch.cuda.set_device(args.local_rank) - benchmark(args) diff --git a/tests/test_profile_gpt.py b/tests/test_profile_gpt.py deleted file mode 100644 index eb9a391f..00000000 --- a/tests/test_profile_gpt.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -torchrun --nproc_per_node=1 test/test_profile_gpt.py -""" - - -import torch -import time - -from examples.nlp.gpt.model import GPT -from examples.nlp.gpt.model import GPTDataLoader - -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary, model_summary - -from examples.nlp.gpt.policy.mpmd import PASMegatron as PAS -import examples.nlp.gpt.policy.spmd as spmd -import examples.nlp.gpt.policy.mpmd as mpmd - -import argparse - -from cube.ir.operator import IRFwOperation, IRBpOperation -from cube.profiler.database import ProfileDataBase -from cube.algorithm.ops.dimops import gen_partitions -from cube.graph.function.anchor import IRGraphAnchor - -parser = argparse.ArgumentParser(description='GPT Train') -parser.add_argument('--fp16', action='store_true', default=False, - help='use fp16 for the training') -args = parser.parse_args() - -cube.init() - - -def train(): - batch_size = 1 - - model = GPT() - model = model if not args.fp16 else model.half() - dataloader = GPTDataLoader(batch_size) - - model = cube.SemanticModel(model, dataloader.shapes) - - def profile(graph, resource): - db = ProfileDataBase() - mem_sum = 0 - for node in graph.select(ntype=IRFwOperation): - if isinstance(node, IRGraphAnchor): - continue - partition_nodes = gen_partitions(node, 1) - for partition_node in partition_nodes: - in_mem, param_mem, fw_span, bw_span, infer_mem, train_mem = db.profile(partition_node) - print(node.signature, in_mem, param_mem) - mem_sum = mem_sum + train_mem - db.dump('db.json', override=True) - print('estimated train mem: ', mem_sum / 1024 / 1024 / 1024) - - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - return graph - - @cube.compile(model, dataloader, PAS=profile, override=True) - def train_iter(model, dataloader): - input_ids, position_ids = next(dataloader) - loss = model(input_ids, position_ids) - loss.backward() - model = model.get_gen_module() - - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - if torch.distributed.is_initialized(): - torch.distributed.barrier() - print_each_rank('model weight consumpition:', rank_only=0) - memory_summary() - - # CudaTimer(enable=False).warmup() - iter_num = 4 - warmup = 2 - for step in range(iter_num): - if step == warmup: - CudaTimer(enable=True).start('e2e') - - train_iter(model, dataloader) - # memory_summary() - optimizer.step() - # memory_summary() - optimizer.zero_grad() - # memory_summary() - - if step == 0: - print_each_rank('passed first iteration') - if (step + 1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - CudaTimer().stop('e2e') - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) - - memory_summary() - - -if __name__ == '__main__': - - cube.init() - train() \ No newline at end of file diff --git a/tox.ini b/tox.ini index 81fb0b46..ed160849 100644 --- a/tox.ini +++ b/tox.ini @@ -14,5 +14,5 @@ deps = -rrequirements.txt -rrequirements-dev.txt commands = coverage erase - py.test --cov={toxinidir}/cube -sx unit_tests + pytest --cov={toxinidir}/cube -x unit_tests coverage html diff --git a/unit_tests/compiler/__init__.py b/unit_tests/compiler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unit_tests/compiler/test_compile.py b/unit_tests/compiler/test_compile.py new file mode 100644 index 00000000..49234484 --- /dev/null +++ b/unit_tests/compiler/test_compile.py @@ -0,0 +1,137 @@ +""" +pytest unit_tests/compiler/test_compile.py +""" +import torch +import logging +from functools import partial + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation +from cube.flags import CompileFlag +from ..launch_torchrun import torchrun +from ..utils import init_parameter, assert_parity + + +class MLP(torch.nn.Module): + def __init__(self, dim=512, nlayers=4): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + def forward(self, data): + x = data + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) + return loss + + +def get_dummy_data(batch_size: int = 512): + torch.random.manual_seed(0) + return torch.randn( + [128, 512], dtype=torch.float32, + device=torch.cuda.current_device()).repeat([batch_size // 128, 1]) + + +def baseline(): + + model = MLP() + init_parameter(model) + model.cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + x = get_dummy_data() + loss = model(x) + loss.backward() + optimizer.step() + optimizer.zero_grad() + loss = loss.item() + while abs(loss) > 10.0: + loss /= 10.0 # scale for comparison + losses.append(loss) + + return losses + + +def scale(ngpus_per_unit: int): + + model = MLP() + init_parameter(model) + + def policy(graph: IRGraph, resource): + + ngpus = min(ngpus_per_unit, resource.ngpus) + + def tensor_parallelism(node, idx, dim, num): + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=idx, dim=dim, num=num) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return sub_nodes + + l1, l2, l3, l4 = graph.select(name='linear') + + # l1 tensor parallelism + tensor_parallelism(l1, idx=1, dim=0, num=ngpus) + # l2 data parallelism + tensor_parallelism(l2, idx=0, dim=0, num=ngpus) + # l3 tensor parallelism + tensor_parallelism(l3, idx=1, dim=1, num=ngpus) + # l4 replicate + + for node in graph.select(ntype=IRFwOperation): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph + + ngpus_per_unit = min(ngpus_per_unit, torch.distributed.get_world_size()) + nreplicas = torch.distributed.get_world_size() // ngpus_per_unit + batch_size = 512 // nreplicas + print('>> set batch size to', batch_size) + x = get_dummy_data(batch_size=batch_size) + + @cube.compile(model, x, PAS=policy, scale=True) + def train_iter(model, x): + loss = model(x) + loss.backward() + return loss + + model = cube.load_model() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + x = get_dummy_data(batch_size=batch_size) + loss = train_iter(model, x) + loss = loss * nreplicas + optimizer.step() + optimizer.zero_grad() + loss = loss.item() + while abs(loss) > 10.0: + loss /= 10.0 # scale for comparison + losses.append(loss) + + return losses + + +def scale_test(): + cube.init() + CompileFlag.disable_code_line_info = True # speedup parse + assert_parity(baseline, partial(scale, 2)) + + +def scale_test_dp(): + cube.init() + CompileFlag.disable_code_line_info = True # speedup parse + assert_parity(baseline, partial(scale, 1)) + + +test_scale_2gpu = partial(torchrun, 2, scale_test) +test_scale_2gpu_dp = partial(torchrun, 2, scale_test_dp) +test_scale_4gpu = partial(torchrun, 4, scale_test) diff --git a/unit_tests/graph/function/test_dataloader.py b/unit_tests/graph/function/test_dataloader.py new file mode 100644 index 00000000..b2d4850e --- /dev/null +++ b/unit_tests/graph/function/test_dataloader.py @@ -0,0 +1,45 @@ +""" +pytest unit_tests/graph/function/test_dataloader.py +""" + +import torch + +from cube.ir.cten import IRObject +from cube.ir.tensor import IRFullTensor +from cube.ir.operator import IRDataOperation +from cube.runtime.utils import create_dummy_dataloader + + +def test_dummy_dataloader(): + samples = ( + torch.rand([256, 512], dtype=torch.float32), + torch.rand([128, 224], dtype=torch.float16), + 4, + ) + dataloader = create_dummy_dataloader(samples, batch_size=32) + for idx, samples in enumerate(dataloader): + assert samples[0].shape == torch.Size([32, 256, 512]) + assert samples[1].shape == torch.Size([32, 128, 224]) + assert torch.allclose(samples[2], torch.tensor([4] * 32, dtype=torch.int64)) + if idx == 4: + break + + +def test_data_operation(): + + data_op = IRDataOperation( + IRObject('dataloader'), + [IRFullTensor(shape=[32, 256, 512]).tosub(), + IRFullTensor(shape=[32, 128, 224]).tosub(),]) + + # cannot be partitioned + assert not hasattr(data_op, 'algorithms') + # test input / output + assert all(isinstance(out, IRObject) for out in data_op.outputs()) + assert all(isinstance(inp, IRObject) for inp in data_op.inputs()) + # can be replicated + data_op_replica = data_op.replicate() + assert data_op_replica.input(0) == data_op.input(0) + assert data_op_replica.output(0) == data_op.output(0) + assert data_op_replica.output(1) == data_op.output(1) + assert data_op_replica.cid == data_op.cid diff --git a/tests/algorithm/test_op_algorithm.py b/unit_tests/graph/function/test_dimops.py similarity index 86% rename from tests/algorithm/test_op_algorithm.py rename to unit_tests/graph/function/test_dimops.py index 95f51545..7f68f37a 100644 --- a/tests/algorithm/test_op_algorithm.py +++ b/unit_tests/graph/function/test_dimops.py @@ -1,20 +1,15 @@ """ -python tests/algorithm/test_op_algorithm.py -pytest tests/algorithm/test_op_algorithm.py +pytest unit_tests/graph/function/test_dimops.py """ from typing import Callable, Tuple, List from functools import partial -import cube import cube.graph.function as F from cube.graph.function.dimops import IRDimops from cube.ir.tensor import IRFullTensor -Shape=Tuple[int] - - def create_op(creator: Callable, input_shapes: List[Tuple[int]], *args, **kwargs): inputs = tuple(IRFullTensor(shape=shape).tosub() for shape in input_shapes) @@ -48,9 +43,4 @@ def create_udf_op1(input, weight, signature='test_udf_op1'): test_multi_dim_partition = partial(partitionable, create_op(create_udf_op1, [(2048, 8, 4096), (2048, 4096)]), idx=0, dim=0, num=2, -) - -if __name__ == '__main__': - test_view1() - test_view2() - test_multi_dim_partition() \ No newline at end of file +) \ No newline at end of file diff --git a/tests/adapter/test_inter_rvd.py b/unit_tests/graph/gener/check_inter_rvd.py similarity index 98% rename from tests/adapter/test_inter_rvd.py rename to unit_tests/graph/gener/check_inter_rvd.py index 497a2930..6578b850 100644 --- a/tests/adapter/test_inter_rvd.py +++ b/unit_tests/graph/gener/check_inter_rvd.py @@ -1,4 +1,6 @@ """ +Note this is not for test. + OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=1 \ tests/adapter/test_inter_rvd.py diff --git a/tests/adapter/test_intra_rvd.py b/unit_tests/graph/gener/check_intra_rvd.py similarity index 99% rename from tests/adapter/test_intra_rvd.py rename to unit_tests/graph/gener/check_intra_rvd.py index 6286d2b4..9bb970e1 100644 --- a/tests/adapter/test_intra_rvd.py +++ b/unit_tests/graph/gener/check_intra_rvd.py @@ -1,7 +1,9 @@ """ +Note this is not for test. + OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=1 \ - tests/adapter/test_intra_rvd.py + unit_test/graph/gener/test_intra_rvd.py """ from typing import List, Tuple diff --git a/unit_tests/graph/test_multiref.py b/unit_tests/graph/test_multiref.py new file mode 100644 index 00000000..935cf9fd --- /dev/null +++ b/unit_tests/graph/test_multiref.py @@ -0,0 +1,124 @@ +""" +pytest unit_tests/graph/test_multiref.py +""" +import torch +import logging +from functools import partial + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation +from ..launch_torchrun import torchrun +from ..utils import init_parameter, assert_parity + + +def _param(shape, dtype=torch.float32): + return torch.nn.Parameter(torch.empty(shape, dtype=dtype)) + + +class OpModule(torch.nn.Module): + def __init__(self, shape=[256, 512]): + super().__init__() + self.param = _param(shape) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + """ + residual on x and self.param + """ + residual = x + x = residual * self.param + y = residual + y + y = y * self.param + loss = torch.sum(y) + return loss + + +def get_dummy_data(batch_size: int = 256): + torch.random.manual_seed(0) + return ( + torch.rand([batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()), + torch.rand([batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()), + ) + + +def baseline(): + + model = OpModule() + init_parameter(model) + model.cuda() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + x, y = get_dummy_data() + loss = model(x, y) + loss.backward() + optimizer.step() + optimizer.zero_grad() + losses.append(loss.item()) + + return losses + + +def multiref(): + + model = OpModule() + init_parameter(model) + x, y = get_dummy_data() + + def policy(graph: IRGraph, resource): + + first_mul = graph.select('mul')[0] + first_add = graph.select('add')[0] + + sub_muls = graph.partition( + first_mul, first_mul.algorithms('dim'), + idx=0, dim=0, num=resource.ngpus + ) + for idx, sub_node in enumerate(sub_muls): + graph.assign(sub_node, idx) + + sub_adds = graph.partition( + first_add, first_add.algorithms('dim'), + idx=0, dim=0, num=resource.ngpus + ) + for idx, sub_node in enumerate(sub_adds): + graph.assign(sub_node, idx) + + for node in graph.select(ntype=IRFwOperation): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph + + x, y = get_dummy_data() + + @cube.compile(model, x, y, PAS=policy) + def train_iter(model, x, y): + loss = model(x, y) + loss.backward() + return loss + + model = cube.load_model() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + x, y = get_dummy_data() + loss = train_iter(model, x, y) + optimizer.step() + optimizer.zero_grad() + losses.append(loss.item()) + + return losses + + +def multiref_test(): + cube.init() + cube.set_logger_level(logging.INFO) + assert_parity(baseline, multiref) + + +test_multiref_1gpu = partial(torchrun, 1, multiref_test) +test_multiref_2gpu = partial(torchrun, 2, multiref_test) diff --git a/unit_tests/launch_torchrun.py b/unit_tests/launch_torchrun.py index fe7fde30..65d93e7a 100644 --- a/unit_tests/launch_torchrun.py +++ b/unit_tests/launch_torchrun.py @@ -1,8 +1,14 @@ +from typing import Callable import uuid import torch +import logging +import time +import random from torch.distributed.run import elastic_launch, LaunchConfig +_logger = logging.getLogger(__name__) + def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): launch_config = LaunchConfig( @@ -19,6 +25,32 @@ def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): return outputs +def torchrun(nproc_per_node: int, test_fn: Callable, *args, **kwargs): + """Test utility for torchrun + + Example usage: + + ```python + from functools import partial + test_function_name = partial(torchrun, 2, function_to_test) + ``` + + Args: + nproc_per_node (int): number of gpus + test_fn (function): test function, which should return None + *args: args for worker_fn + **kwargs: kwargs for worker_fn + + Returns: + None + """ + + if not torch.cuda.is_available() or torch.cuda.device_count() < nproc_per_node: + _logger.warning(f"skip test on {nproc_per_node} gpus due to lack of cuda devices") + return + launch_torchrun(nproc_per_node, test_fn, *args, **kwargs) + + def clone_to_cpu(tensor: torch.Tensor): # when you use launch_torchrun # you can't directly return a cuda tensor diff --git a/unit_tests/parallel_module/common.py b/unit_tests/parallel_module/common.py index b75a765d..677230f6 100644 --- a/unit_tests/parallel_module/common.py +++ b/unit_tests/parallel_module/common.py @@ -96,15 +96,10 @@ def PASData(graph: IRGraph, env_resource: ComputeConfig): if len(graph.consumers(ftensor)) > 1: graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) - batch_dim = None - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=ngpus) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - batch_dim = node.get_batch_dims()[0] - if batch_dim is None: batch_dim = 0 + batch_dim = 0 + for dl in graph.select(ntype=IRDataOperation): + _replica(dl, list(range(ngpus))) + graph_inputs = IRSegment.get_objects_from_complex(graph.inputs()) graph_outputs = IRSegment.get_objects_from_complex(graph.outputs()) for node in graph.nodes(): diff --git a/unit_tests/runtime/test_reducer.py b/unit_tests/runtime/test_reducer.py new file mode 100644 index 00000000..5a3281f1 --- /dev/null +++ b/unit_tests/runtime/test_reducer.py @@ -0,0 +1,132 @@ +""" +pytest unit_tests/runtime/test_reducer.py +""" +import torch +import logging +from functools import partial + +import cube +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation +from cube.flags import CompileFlag +from ..launch_torchrun import torchrun +from ..utils import init_parameter, assert_parity + + +class MLP(torch.nn.Module): + def __init__(self, dim=512, nlayers=4): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + def forward(self, data): + x = data + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) + return loss + + +def get_dummy_data(batch_size: int = 256): + torch.random.manual_seed(0) + return torch.randn( + [batch_size, 512], dtype=torch.float32, + device=torch.cuda.current_device()) + + +def baseline(): + + model = MLP() + init_parameter(model) + model.cuda() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + x = get_dummy_data() + loss = model(x) + loss.backward() + optimizer.step() + optimizer.zero_grad() + loss = loss.item() + while abs(loss) > 10.0: + loss /= 10.0 + losses.append(loss) + + return losses + + +def reducer(use_zero: bool, async_reducer: bool): + + CompileFlag.use_zero = use_zero + CompileFlag.async_reducer = async_reducer + + model = MLP() + init_parameter(model) + + def policy(graph: IRGraph, resource): + + def tensor_parallelism(node, idx, dim, num): + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=idx, dim=dim, num=num) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return sub_nodes + + l1, l2, l3, l4 = graph.select(name='linear') + + # l1 data parallelism + tensor_parallelism(l1, idx=0, dim=0, num=resource.ngpus) + # l2 data parallelism + tensor_parallelism(l2, idx=0, dim=0, num=resource.ngpus) + # l3, l4 replicate + + for node in graph.select(ntype=IRFwOperation): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, times=resource.ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph + + x = get_dummy_data() + + @cube.compile(model, x, PAS=policy) + def train_iter(model, x): + loss = model(x) + loss.backward() + return loss + + model = cube.load_model() + optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=0.01) + + losses = [] + for _ in range(3): + x = get_dummy_data() + loss = train_iter(model, x) + optimizer.step() + optimizer.zero_grad() + ## === neccessary for zero === + model.gather_params() + ## =========================== + loss = loss.item() + while abs(loss) > 10.0: + loss /= 10.0 + losses.append(loss) + + return losses + + +def reducer_test(): + cube.init() + CompileFlag.disable_code_line_info = True # speedup parse + print('starting zero=True, async=True') + assert_parity(baseline, partial(reducer, True, True)) + print('starting zero=True, async=False') + assert_parity(baseline, partial(reducer, True, False)) + print('starting zero=False, async=True') + assert_parity(baseline, partial(reducer, False, True)) + print('starting zero=False, async=False') + assert_parity(baseline, partial(reducer, False, False)) + +test_reducer_2gpu = partial(torchrun, 2, reducer_test) diff --git a/unit_tests/test_utils.py b/unit_tests/test_torchrun.py similarity index 100% rename from unit_tests/test_utils.py rename to unit_tests/test_torchrun.py diff --git a/unit_tests/utils.py b/unit_tests/utils.py new file mode 100644 index 00000000..0f9ae486 --- /dev/null +++ b/unit_tests/utils.py @@ -0,0 +1,73 @@ +from typing import Callable +import torch +import math +import random + + +def init_parameter(model: torch.nn.Module, seed: int = 0): + """ + Initialize a model's parameters with truncated normal distribution. + """ + def trunc_normal_(tensor: torch.Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.): + with torch.no_grad(): + l = (1. + math.erf((a - mean) / std / math.sqrt(2.))) / 2. + u = (1. + math.erf((b - mean) / std / math.sqrt(2.))) / 2. + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + torch.random.manual_seed(seed) + random.seed(seed) + + for param in list(model.parameters()) + list(model.buffers()): + if len(param.size()) > 1: + trunc_normal_(param, std=.02) + else: + torch.nn.init.constant_(param, 0) + + +def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-4) -> bool: + """Compare the output of baseline_fn and compile_fn + + Error will raise if the output of two functions are not the same. + + Args: + baseline_fn (Callable): a function that returns the output of baseline + compile_fn (Callable): a function that returns the output of compile (cube) + atol (Callable): absolute tolerance when comparing two torch tensors + + Returns: + result (bool): True if the output of two functions are the same else raise Error + """ + baseline_outputs = baseline_fn() + compile_outputs = compile_fn() + + print(f'comparing\nGT:\t{baseline_outputs}\nOUT:\t{compile_outputs}') + + def assert_same_complex(gt, out): + if isinstance(gt, tuple): + assert isinstance(out, tuple) + for ele_gt, ele_out in zip(gt, out): + assert_same_complex(ele_gt, ele_out) + elif isinstance(gt, list): + assert isinstance(out, list) + for ele_gt, ele_out in zip(gt, out): + assert_same_complex(ele_gt, ele_out) + elif isinstance(gt, dict): + assert isinstance(out, dict) + assert set(gt.keys()) == set(out.keys()) + for key in gt: + assert_same_complex(gt[key], out[key]) + elif isinstance(gt, torch.Tensor): + assert isinstance(out, torch.Tensor) + assert torch.allclose(gt, out, atol=atol), f'mismatched: {gt} != {out}' + elif isinstance(gt, float): + assert isinstance(out, float) + assert math.isclose(gt, out, abs_tol=atol), f'mismatched: {gt} != {out}' + else: + assert gt == out, f'mismatched: {gt} != {out}' + assert_same_complex(baseline_outputs, compile_outputs) + return None diff --git a/utility/aggregate.sh b/utility/aggregate.sh new file mode 100644 index 00000000..19185564 --- /dev/null +++ b/utility/aggregate.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# gather the folder to all workers to node-0 under the same workspace + +set -ex + +WORKSPACE=/workspace +FOLDER=MagicCube + +WORKER_PREFIX=node- +WORKER_NUM=2 + +for ((i=1; i<${WORKER_NUM}; i++)); do + WORKER=${WORKER_PREFIX}${i} + scp -r ${WORKER}:${WORKSPACE}/${FOLDER} ${WORKSPACE}/${FOLDER}-${WORKER} +done diff --git a/utility/broadcast.sh b/utility/broadcast.sh new file mode 100644 index 00000000..dbb77c7a --- /dev/null +++ b/utility/broadcast.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# broadcast the folder to all workers under the same workspace + +set -ex + +WORKSPACE=/workspace +FOLDER=MagicCube + +WORKER_PREFIX=node- +WORKER_NUM=2 + +for ((i=1; i<=${WORKER_NUM}; i++)); do + WORKER=${WORKER_PREFIX}${i} + scp -r ${WORKSPACE}/${SYNC_FOLDER} ${WORKER}:${WORKSPACE} +done diff --git a/scripts/dgx1_reorder_gpu.py b/utility/dgx1_reorder_gpu.py similarity index 100% rename from scripts/dgx1_reorder_gpu.py rename to utility/dgx1_reorder_gpu.py diff --git a/scripts/keep.py b/utility/keep.py similarity index 100% rename from scripts/keep.py rename to utility/keep.py diff --git a/tests/test_rvd_prim.py b/utility/test_rvd_prim.py similarity index 94% rename from tests/test_rvd_prim.py rename to utility/test_rvd_prim.py index 9442c8cb..739c8ca3 100644 --- a/tests/test_rvd_prim.py +++ b/utility/test_rvd_prim.py @@ -1,9 +1,12 @@ """ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=8 \ - --nnodes=1 \ - tests/test_rvd_prim.py --prims all + utility/test_rvd_prim.py --prims allreduce +OMP_NUM_THREADS=4 torchrun \ + --nnode=2 --node_rank=$NODE_RANK --master_addr=node-0 \ + --nproc_per_node=8 \ + utility/test_rvd_prim.py --prims all """ from typing import Callable @@ -110,7 +113,7 @@ def prim_bw(prim: Callable, bandwidth: Callable, ranks, size, warmup=100, profil args.prims = args.prims[0] prims, bws = [], [] - if 'allrecuce' in args.prims or 'all' in args.prims: + if 'allreduce' in args.prims or 'all' in args.prims: prims.append(prim_allreduce) bws.append(bw_allreduce) if 'allgather' in args.prims or 'all' in args.prims: From dbb2d03405004bf18ad1f93de6923e8a4eb380b5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 17 Aug 2023 01:54:34 +0000 Subject: [PATCH 1479/1892] Merged PR 1741: Fix IRGraphAnchor --- cube/graph/function/anchor.py | 36 ++++++++++++++++++++++++++++++--- cube/graph/parser/converter.py | 2 +- cube/graph/parser/fx/mapping.py | 32 ++++++----------------------- 3 files changed, 40 insertions(+), 30 deletions(-) diff --git a/cube/graph/function/anchor.py b/cube/graph/function/anchor.py index 0b4a3e50..acaeca14 100644 --- a/cube/graph/function/anchor.py +++ b/cube/graph/function/anchor.py @@ -1,20 +1,50 @@ from cube.ir.operator import IRFwOperation +from cube.ir.cten import IRObject class IRGraphAnchor(IRFwOperation): """ The anchor function serves for - 1) navigation inside the graph - 2) staging boundary inside the graph + 1) navigation inside the graph + 2) user hints of staging boundary inside the graph This operator will eventually be removed from graph, user doesn't need to manipulate it. + + To add anchor node in the graph, a user can simply insert anchor + function in model forward like following: + + ```python + class Model(torch.nn.Module): + + def __init__(self): + xxx + + def forward(self, x): + for layer in self.layers: + cube.runtime.function.anchor('layer start') + x = layer(x) + return x + ``` + + Then there will be anchor nodes named `layer start` inserted inside graph. + Policy maker can quickly access them by + + ```python + graph.select(name='layer start')` + ``` + + Or quickly find all anchor nodes through + + ```python + anchors = graph.select(ntype=IRGraphAnchor) + ``` """ def __init__(self, signature: str, name: str): super().__init__(name, signature, [], 1) self.kwargs['name'] = name - self.set_output(0, None) + self.set_output(0, IRObject('anchor', value=None)) def infer_dtype(self): return diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 7b5a7124..79246169 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -41,7 +41,7 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: # get cube runtime functions cube_rt_funcs = [cube_rt_function.anchor] - leaf_functions.update({func: ([], True, None) for func in cube_rt_funcs}) + leaf_functions.update({func: ([(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs}) dce_ignored_funcs = set(cube_rt_funcs) if HAS_APEX: diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index f02436e3..24eef67d 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -127,16 +127,7 @@ def exist(signature: str) -> bool: 'torch.mm': function.Matmul, __ttemplate('matmul'): function.Matmul, # - # __ftemplate('gelu') : function.GeLU, - # __ttemplate('gelu') : function.GeLU, - # - # __ftemplate('silu') : function.SiLU, - # __ttemplate('silu') : function.SiLU, - # # __ftemplate('_pad'): function.Pad, - # - # __ftemplate('embedding'): function.Embedding, - # __ftemplate('cross_entropy'): function.CrossEntropy, # # # creators @@ -199,31 +190,20 @@ def exist(signature: str) -> bool: # __ttemplate('select_scatter'): function.SelectScatter, # __tttemplate('repeat'): function.Repeat, - # - # #pytorch1.11 - # __ttemplate('linear'): function.Linear, - # __ttemplate('cat'): function.Cat, - __ttemplate('stack'): function.Stack, - # __ttemplate('chunk'): function.Chunk, - __ttemplate('flatten'): function.Flatten, - # # __ttemplate('roll'): function.Roll, # # __ttemplate('adaptive_avg_pool1d'): function.AdaptiveAvgPool1d, # - # # runtime functions - # __rtemplate('anchor'): function.GraphAnchor, - # - # __rtemplate('identity'): function.Identity, - # - # __rtemplate('multiref'): function.MultiRef, - # - # __rtemplate('accum'): function.Accum, - # + # runtime functions + __rtemplate('anchor'): function.GraphAnchor, + __rtemplate('identity'): function.Identity, + __rtemplate('multiref'): function.MultiRef, + __rtemplate('accum'): function.Accum, + # #einops # __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, From 4007da21f2af41af51ab6cafd8a3b6a04c905861 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 17 Aug 2023 01:58:44 +0000 Subject: [PATCH 1480/1892] Merged PR 1719: trace inner apex FusedLayerNorm module trace inner apex FusedLayerNorm module --- cube/graph/function/function.py | 35 +------- cube/graph/parser/__init__.py | 1 + cube/graph/parser/converter.py | 24 +----- cube/graph/parser/external/__init__.py | 1 + cube/graph/parser/external/apex.py | 81 +++++++++++++++++++ .../fx/concrete_trace_utils/concrete_proxy.py | 5 +- cube/graph/parser/fx/mapping.py | 2 - cube/graph/parser/fx/parser.py | 37 +-------- cube/runtime/function/function.py | 17 ---- 9 files changed, 90 insertions(+), 113 deletions(-) create mode 100644 cube/graph/parser/external/__init__.py create mode 100644 cube/graph/parser/external/apex.py diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 92383e9f..604e6b8f 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Tuple, Dict, Union, Iterable +from typing import Any, Callable, List, Tuple, Dict, Union, Iterable import string import copy import torch @@ -743,39 +743,6 @@ def LayerNorm(input, normalized_shape, weight=None, bias=None, eps=1e-05, signat return CubeLayerNorm(input, weight, bias, normalized_shape, eps, signature=signature) -def FusedLayerNorm(input, weight, bias, normalized_shape, eps=1e-5, signature = None): - """ - apex.normalization.fused_layer_norm.FusedLayerNorm - """ - signature = 'cube.runtime.function.fused_layer_norm' - assert not (weight is None and bias is not None), f"Not support for None of weight and parameter of bias" - letters = iter(string.ascii_lowercase) - einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) - eoutput = copy.copy(einput) - ndims = len(input.shape) - for dim in range(len(normalized_shape)): - einput[ndims-1-dim] += '^' - eoutput[ndims-1-dim] += '^' - einputs, inputs = [einput], [input] - kwargs = {} - if weight is not None: - eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) - einputs.append(eweight) - inputs.append(weight) - else: - kwargs['weight'] = weight - if bias is not None: - ebias = ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) - einputs.append(ebias) - inputs.append(bias) - else: - kwargs['bias'] = bias - anno = OpAnno.create_op_str(einputs, [eoutput]) - kwargs['normalized_shape'] = normalized_shape - kwargs['eps'] = eps - return IRDimops(FusedLayerNorm, 'fusedlayernorm', signature, [anno], inputs, **kwargs) - - def Norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None, signature=None): assert dtype is None, "Currently Norm only support dtype=None" einput = ShapeAnno.create_shape_str(input.shape) diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index a64c308e..21838a42 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,3 +1,4 @@ from cube.graph.parser.fx.parser import FxModuleParser, FxFuncOpTracer from cube.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph from cube.graph.parser.register import register +from cube.graph.parser.external import * diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 79246169..bc1d11af 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,7 +1,6 @@ +from typing import Any, Dict, Union import logging from pathlib import Path -from typing import Any, Dict, Union - from cube.ir.tensor import IRFullTensor from cube.graph.parser.register import CustomizedOps @@ -18,12 +17,6 @@ _logger = logging.getLogger(__name__) -try: - import apex - HAS_APEX = True -except: - HAS_APEX = False - def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: """ @@ -44,25 +37,10 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: leaf_functions.update({func: ([(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs}) dce_ignored_funcs = set(cube_rt_funcs) - if HAS_APEX: - leaf_module = ( - # torch.nn.Dropout, #torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d, - apex.normalization.FusedLayerNorm, - # NOTE: the following modules also have different behavior depending on self.training. but currently in used. - # torch.nn.AlphaDropout, torch.nn.FeatureAlphaDropout, - # torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, - # torch.nn.LazyBatchNorm1d, torch.nn.LazyBatchNorm2d, torch.nn.LazyBatchNorm3d, torch.nn.SyncBatchNorm, - # torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, - # torch.nn.LazyInstanceNorm1d, torch.nn.LazyInstanceNorm2d, torch.nn.LazyInstanceNorm3d, - ) - else: - _logger.warning('apex package is not installed') - leaf_module = None traced_model = concrete_trace( model, dummy_input, use_operator_patch=True, - leaf_module=leaf_module, autowrap_leaf_function=leaf_functions, dce_ignored_function=dce_ignored_funcs, cpu_offload=True, diff --git a/cube/graph/parser/external/__init__.py b/cube/graph/parser/external/__init__.py new file mode 100644 index 00000000..8574c302 --- /dev/null +++ b/cube/graph/parser/external/__init__.py @@ -0,0 +1 @@ +from .apex import * \ No newline at end of file diff --git a/cube/graph/parser/external/apex.py b/cube/graph/parser/external/apex.py new file mode 100644 index 00000000..8b8b099e --- /dev/null +++ b/cube/graph/parser/external/apex.py @@ -0,0 +1,81 @@ +import copy +import logging +import string + +from cube.graph.function.dimops import ShapeAnno, OpAnno, IRDimops +from cube.graph.parser.register import CustomizedOps + +_logger = logging.getLogger(__name__) + + +try: + from apex.normalization.fused_layer_norm import FusedLayerNormFunction, FusedLayerNormAffineFunction + + def ApexFusedLayerNormFunction(input, normalized_shape, eps=1e-6, signature = None): + """ + apex.normalization.fused_layer_norm.FusedLayerNormFunction + """ + signature = 'apex_fused_layer_norm' + letters = iter(string.ascii_lowercase) + einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) + eoutput = copy.copy(einput) + ndims = len(input.shape) + for dim in range(len(normalized_shape)): + einput[ndims-1-dim] += '^' + eoutput[ndims-1-dim] += '^' + einputs, inputs = [einput], [input] + kwargs = {} + anno = OpAnno.create_op_str(einputs, [eoutput]) + kwargs['normalized_shape'] = normalized_shape + kwargs['eps'] = eps + return IRDimops(FusedLayerNormFunction, 'fusedlayernorm', signature, [anno], inputs, **kwargs) + + def ApexFusedLayerNormAffineFunction(input, weight, bias, normalized_shape, eps=1e-6, signature = None): + """ + apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction + """ + signature = 'apex_fused_layer_norm_affine' + assert not (weight is None and bias is not None), f"Not support for None of weight and parameter of bias" + letters = iter(string.ascii_lowercase) + einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) + eoutput = copy.copy(einput) + ndims = len(input.shape) + for dim in range(len(normalized_shape)): + einput[ndims-1-dim] += '^' + eoutput[ndims-1-dim] += '^' + einputs, inputs = [einput], [input] + kwargs = {} + if weight is not None: + eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) + einputs.append(eweight) + inputs.append(weight) + else: + kwargs['weight'] = weight + if bias is not None: + ebias = ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) + einputs.append(ebias) + inputs.append(bias) + else: + kwargs['bias'] = bias + anno = OpAnno.create_op_str(einputs, [eoutput]) + kwargs['normalized_shape'] = normalized_shape + kwargs['eps'] = eps + return IRDimops(FusedLayerNormAffineFunction, 'fusedlayernormaffine', signature, [anno], inputs, **kwargs) + + CustomizedOps.register('apex.normalization.fused_layer_norm.FusedLayerNormFunction.apply', + ApexFusedLayerNormFunction, + 'from apex.normalization.fused_layer_norm import fused_layer_norm as apex_fused_layer_norm', + FusedLayerNormFunction.apply, + keep_full_name=True, + trace_autowrap=False) + + CustomizedOps.register('apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply', + ApexFusedLayerNormAffineFunction, + 'from apex.normalization.fused_layer_norm import fused_layer_norm_affine as apex_fused_layer_norm_affine', + FusedLayerNormAffineFunction.apply, + keep_full_name=True, + trace_autowrap=False) +except: + _logger.warning('skip apex ops as it is not installed.') + + diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 387f8fae..ecc10327 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -279,7 +279,10 @@ def __init__(self, root: ConcreteProxy, attr: str): self.attr = attr self.tracer = root.tracer self._node: Optional[Node] = None - self.value = _orig_getattr(root.value, attr) + if _orig_isinstance(root.value, torch.Tensor) and attr == 'is_cuda' and self.tracer.cpu_offload: + self.value = True + else: + self.value = _orig_getattr(root.value, attr) def __repr__(self) -> str: calling_frame_name = inspect.stack()[1][1] diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 24eef67d..d3b6e455 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -109,8 +109,6 @@ def exist(signature: str) -> bool: __ftemplate('nll_loss') : function.NLLLoss, 'torch.functional.norm': function.Norm, __ftemplate('layer_norm'): function.LayerNorm, - 'apex.normalization.fused_layer_norm.FusedLayerNorm': function.FusedLayerNorm, - 'apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply': function.FusedLayerNorm, # ============== runtime function ================= __tttemplate('size'): function.Size, diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 69d30a22..1fe2a293 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -264,43 +264,8 @@ def fetch_attr(mod: torch.fx.GraphModule, target: str): @staticmethod def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: prim_module = FxModuleParser.fetch_attr(module, node.target) - input_vals = [get_complex_data(val, frame) for val in node.args] - kwargs = {key: get_complex_data(val, frame) for key, val in node.kwargs.items()} if prim_module.__class__.__module__.startswith('torch.nn.modules'): - assert False, 'torch.nn.modules can not be parsed as leaf nodes' - elif prim_module.__class__.__module__ == 'apex.normalization.fused_layer_norm': - fsig = '{}.{}'.format(prim_module.__class__.__module__, prim_module.__class__.__name__) - assert prim_module.elementwise_affine is True - assert SignFx2Op.exist(fsig) - assert len(kwargs) == 0 - # add var of weight and bias into frame - shape = FxModuleParser.shape_refine(prim_module.weight.size()) - dtype = prim_module.weight.dtype - requires_grad = prim_module.weight.requires_grad - ir_weight_val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=f'{node.name}_weight') - ir_weight_val.as_param() - frame.add_var(ir_weight_val.name, ir_weight_val) - frame.add_attr_content(ir_weight_val.tid, prim_module.weight) - frame.add_attr_map(ir_weight_val.name, node.target+'.weight') - shape = FxModuleParser.shape_refine(prim_module.bias.size()) - dtype = prim_module.bias.dtype - requires_grad = prim_module.bias.requires_grad - ir_bias_val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=f'{node.name}_bias') - ir_bias_val.as_param() - frame.add_var(ir_bias_val.name, ir_bias_val) - frame.add_attr_content(ir_bias_val.tid, prim_module.bias) - frame.add_attr_map(ir_bias_val.name, node.target+'.bias') - input_vals.extend([ir_weight_val, ir_bias_val]) - kwargs.update({'normalized_shape': prim_module.normalized_shape, 'eps': prim_module.eps}) - ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) - assert isinstance(ir_node, IRCell) - assert len(ir_node.outputs()) == 1 - output_val = frame.get_var(node.name) - ir_node.set_output(0, output_val) - comment = str(node.meta.get('frame_record', '')) - if comment: - ir_node.comment = comment - return [ir_node] + raise RuntimeError(f'{prim_module.__class__.__module__} can not be parsed as leaf nodes') else: raise RuntimeError(f'unknown module: {prim_module.__class__.__module__}') diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 99c26beb..9bd01170 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -1,17 +1,6 @@ from typing import Optional, List, Tuple, Union import torch import torch.nn.functional as TorchF -import logging - - -_logger = logging.getLogger(__name__) - - -# TODO: move to registered function -try: - from apex.normalization.fused_layer_norm import fused_layer_norm_affine -except: - _logger.warning('skip apex ops as it is not installed.') def identity(tensor: torch.Tensor) -> torch.Tensor: @@ -138,12 +127,6 @@ def layer_norm(input: torch.Tensor, return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps) -def fused_layer_norm(input: torch.Tensor, - weight: torch.Tensor, bias: torch.Tensor, - normalized_shape: List[int], eps: float = 1e-05) -> torch.Tensor: - return fused_layer_norm_affine(input, weight, bias, normalized_shape, eps) - - # 'torch.select_scatter' isn't supported by Torch2ONNX yet. # Implement it with 'torch.masked_scatter' which is supported with ONNX opset=11. def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): From 58d471115be11f7012358575b881b875fcf766ed Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 17 Aug 2023 02:34:00 +0000 Subject: [PATCH 1481/1892] Merged PR 1747: rename unit_tests folder to tests --- {unit_tests => tests}/__init__.py | 0 {unit_tests => tests}/compiler/__init__.py | 0 {unit_tests => tests}/compiler/test_compile.py | 0 {unit_tests => tests}/graph/__init__.py | 0 {unit_tests => tests}/graph/function/test_dataloader.py | 0 {unit_tests => tests}/graph/function/test_dimops.py | 0 {unit_tests => tests}/graph/gener/check_inter_rvd.py | 0 {unit_tests => tests}/graph/gener/check_intra_rvd.py | 0 {unit_tests => tests}/graph/parser/__init__.py | 0 {unit_tests => tests}/graph/parser/test_converter.py | 0 {unit_tests => tests}/graph/test_multiref.py | 0 {unit_tests => tests}/launch_torchrun.py | 0 {unit_tests => tests}/parallel_module/__init__.py | 0 {unit_tests => tests}/parallel_module/common.py | 8 ++++---- {unit_tests => tests}/parallel_module/test_gencode.py | 0 {unit_tests => tests}/parallel_module/test_nested.py | 0 {unit_tests => tests}/parallel_module/test_override.py | 0 {unit_tests => tests}/parallel_module/test_submodule.py | 9 +++++---- {unit_tests => tests}/runtime/__init__.py | 0 {unit_tests => tests}/runtime/test_reducer.py | 0 .../runtime/test_runtime_collectives.py | 0 {unit_tests => tests}/test_torchrun.py | 0 {unit_tests => tests}/utils.py | 0 tox.ini | 2 +- 24 files changed, 10 insertions(+), 9 deletions(-) rename {unit_tests => tests}/__init__.py (100%) rename {unit_tests => tests}/compiler/__init__.py (100%) rename {unit_tests => tests}/compiler/test_compile.py (100%) rename {unit_tests => tests}/graph/__init__.py (100%) rename {unit_tests => tests}/graph/function/test_dataloader.py (100%) rename {unit_tests => tests}/graph/function/test_dimops.py (100%) rename {unit_tests => tests}/graph/gener/check_inter_rvd.py (100%) rename {unit_tests => tests}/graph/gener/check_intra_rvd.py (100%) rename {unit_tests => tests}/graph/parser/__init__.py (100%) rename {unit_tests => tests}/graph/parser/test_converter.py (100%) rename {unit_tests => tests}/graph/test_multiref.py (100%) rename {unit_tests => tests}/launch_torchrun.py (100%) rename {unit_tests => tests}/parallel_module/__init__.py (100%) rename {unit_tests => tests}/parallel_module/common.py (97%) rename {unit_tests => tests}/parallel_module/test_gencode.py (100%) rename {unit_tests => tests}/parallel_module/test_nested.py (100%) rename {unit_tests => tests}/parallel_module/test_override.py (100%) rename {unit_tests => tests}/parallel_module/test_submodule.py (97%) rename {unit_tests => tests}/runtime/__init__.py (100%) rename {unit_tests => tests}/runtime/test_reducer.py (100%) rename {unit_tests => tests}/runtime/test_runtime_collectives.py (100%) rename {unit_tests => tests}/test_torchrun.py (100%) rename {unit_tests => tests}/utils.py (100%) diff --git a/unit_tests/__init__.py b/tests/__init__.py similarity index 100% rename from unit_tests/__init__.py rename to tests/__init__.py diff --git a/unit_tests/compiler/__init__.py b/tests/compiler/__init__.py similarity index 100% rename from unit_tests/compiler/__init__.py rename to tests/compiler/__init__.py diff --git a/unit_tests/compiler/test_compile.py b/tests/compiler/test_compile.py similarity index 100% rename from unit_tests/compiler/test_compile.py rename to tests/compiler/test_compile.py diff --git a/unit_tests/graph/__init__.py b/tests/graph/__init__.py similarity index 100% rename from unit_tests/graph/__init__.py rename to tests/graph/__init__.py diff --git a/unit_tests/graph/function/test_dataloader.py b/tests/graph/function/test_dataloader.py similarity index 100% rename from unit_tests/graph/function/test_dataloader.py rename to tests/graph/function/test_dataloader.py diff --git a/unit_tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py similarity index 100% rename from unit_tests/graph/function/test_dimops.py rename to tests/graph/function/test_dimops.py diff --git a/unit_tests/graph/gener/check_inter_rvd.py b/tests/graph/gener/check_inter_rvd.py similarity index 100% rename from unit_tests/graph/gener/check_inter_rvd.py rename to tests/graph/gener/check_inter_rvd.py diff --git a/unit_tests/graph/gener/check_intra_rvd.py b/tests/graph/gener/check_intra_rvd.py similarity index 100% rename from unit_tests/graph/gener/check_intra_rvd.py rename to tests/graph/gener/check_intra_rvd.py diff --git a/unit_tests/graph/parser/__init__.py b/tests/graph/parser/__init__.py similarity index 100% rename from unit_tests/graph/parser/__init__.py rename to tests/graph/parser/__init__.py diff --git a/unit_tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py similarity index 100% rename from unit_tests/graph/parser/test_converter.py rename to tests/graph/parser/test_converter.py diff --git a/unit_tests/graph/test_multiref.py b/tests/graph/test_multiref.py similarity index 100% rename from unit_tests/graph/test_multiref.py rename to tests/graph/test_multiref.py diff --git a/unit_tests/launch_torchrun.py b/tests/launch_torchrun.py similarity index 100% rename from unit_tests/launch_torchrun.py rename to tests/launch_torchrun.py diff --git a/unit_tests/parallel_module/__init__.py b/tests/parallel_module/__init__.py similarity index 100% rename from unit_tests/parallel_module/__init__.py rename to tests/parallel_module/__init__.py diff --git a/unit_tests/parallel_module/common.py b/tests/parallel_module/common.py similarity index 97% rename from unit_tests/parallel_module/common.py rename to tests/parallel_module/common.py index 677230f6..47d13730 100644 --- a/unit_tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -57,8 +57,8 @@ def PASRandomSPMD(graph: IRGraph, env_resource: ComputeConfig): # workaround for now # will fix later. if any(output in graph_outputs for output in node.outputs()) \ - or any(input in graph_outputs for input in node.inputs()) \ - or any(input in graph_inputs for input in node.inputs()): + or any(input in graph_outputs for input in node.inputs()): + # or any(input in graph_inputs for input in node.inputs()): _replica(graph, node, devs) continue if isinstance(node, IRDimops): @@ -109,8 +109,8 @@ def PASData(graph: IRGraph, env_resource: ComputeConfig): # workaround for now # will fix later. if any(output in graph_outputs for output in node.outputs()) \ - or any(input in graph_outputs for input in node.inputs()) \ - or any(input in graph_inputs for input in node.inputs()): + or any(input in graph_outputs for input in node.inputs()): + # or any(input in graph_inputs for input in node.inputs()): sub_nodes = graph.replicate(node, ngpus) else: try: diff --git a/unit_tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py similarity index 100% rename from unit_tests/parallel_module/test_gencode.py rename to tests/parallel_module/test_gencode.py diff --git a/unit_tests/parallel_module/test_nested.py b/tests/parallel_module/test_nested.py similarity index 100% rename from unit_tests/parallel_module/test_nested.py rename to tests/parallel_module/test_nested.py diff --git a/unit_tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py similarity index 100% rename from unit_tests/parallel_module/test_override.py rename to tests/parallel_module/test_override.py diff --git a/unit_tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py similarity index 97% rename from unit_tests/parallel_module/test_submodule.py rename to tests/parallel_module/test_submodule.py index 2ff97e1d..8459024c 100644 --- a/unit_tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -251,21 +251,22 @@ def test_submodules_dp_gpu2(): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: print('skip test_submodules_dp_gpu2 due to lack of cuda devices') return + eps = 1e-4 results = launch_torchrun(2, _gpu_worker, PASData, 2) for r in results.values(): orig_results, compiled_results, _, _, _, _ = r for orig, compiled in zip(orig_results, compiled_results): - assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred - assert torch.allclose(orig[1], compiled[1], rtol=1e-6, atol=1e-6) # loss + assert torch.allclose(orig[0], compiled[0], rtol=eps, atol=eps) # pred + assert torch.allclose(orig[1], compiled[1], rtol=eps, atol=eps) # loss # grad compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items() if not k.endswith('_grad_sentry')} assert len(orig[2]) == len(compiled_cleaned) for k in orig[2].keys(): - assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=eps, atol=eps) # weights compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items() if not k.endswith('_grad_sentry')} assert len(orig[3]) == len(compiled_cleaned) for k in orig[3].keys(): - assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=eps, atol=eps) diff --git a/unit_tests/runtime/__init__.py b/tests/runtime/__init__.py similarity index 100% rename from unit_tests/runtime/__init__.py rename to tests/runtime/__init__.py diff --git a/unit_tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py similarity index 100% rename from unit_tests/runtime/test_reducer.py rename to tests/runtime/test_reducer.py diff --git a/unit_tests/runtime/test_runtime_collectives.py b/tests/runtime/test_runtime_collectives.py similarity index 100% rename from unit_tests/runtime/test_runtime_collectives.py rename to tests/runtime/test_runtime_collectives.py diff --git a/unit_tests/test_torchrun.py b/tests/test_torchrun.py similarity index 100% rename from unit_tests/test_torchrun.py rename to tests/test_torchrun.py diff --git a/unit_tests/utils.py b/tests/utils.py similarity index 100% rename from unit_tests/utils.py rename to tests/utils.py diff --git a/tox.ini b/tox.ini index ed160849..7598dac7 100644 --- a/tox.ini +++ b/tox.ini @@ -14,5 +14,5 @@ deps = -rrequirements.txt -rrequirements-dev.txt commands = coverage erase - pytest --cov={toxinidir}/cube -x unit_tests + pytest --cov={toxinidir}/cube -x tests coverage html From b66068fb0874e741e5a9da609dd40fa40139c739 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 18 Aug 2023 08:08:08 +0000 Subject: [PATCH 1482/1892] Merged PR 1749: parallel module: test inference support --- cube/codegen/module/module.py | 2 +- cube/parallel.py | 18 ++++-- cube/runtime/module.py | 15 +++-- tests/launch_torchrun.py | 9 +-- tests/parallel_module/test_inference.py | 84 +++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 18 deletions(-) create mode 100644 tests/parallel_module/test_inference.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index c03343bf..e4f9ff3b 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -359,12 +359,12 @@ def gen( with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.model_init_statements) if as_parallel_module: + ib.insert_body('') ib.insert_body(f'self.load_attr_content(Path(__file__).with_name("{FxModuleParser.ATTR_CONTENT_FILE}"))') ib.insert_body(f'self.load_dist_param_map(Path(__file__).with_name("{FxModuleParser.ATTR_MAP_FILE}"))') ib.insert_body('') with ForBlock('reducer', f'self.reducers') as for_block: for_block.insert_body(f'reducer.build_buckets()') - ib.insert_body('') ib.insert_body(for_block.code) cb.insert_body('') cb.insert_body(ib.code) diff --git a/cube/parallel.py b/cube/parallel.py index 6200cc5d..982ff2d1 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -159,8 +159,7 @@ def _gencode( program.clear() IDGenerator().clear() - module = module.to(device=torch.device("cpu")) - module.train() + module.cpu() # generate fx graph dummy_input = _complex(dummy_input) @@ -257,7 +256,7 @@ def _load_cube_module_class( *, cube_savedir: Union[str, Path] = './.cube', instance_name: Optional[str] = None, -): +) -> Type[ParallelModule]: """ Load the generated cube module class. @@ -306,6 +305,10 @@ def parallelize( Or you can unset load_module flag, and manually copy the generated files to other nodes. After all nodes have the generated files, you can call parallelize() again with load_module flag set. + if the input is a module object. + The module object will be copied to cpu to handle possible insufficient gpu memory. + The training flag will be the same as the original module + Args: module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled dummy_input (dict): the dummy input for the module @@ -362,7 +365,7 @@ def parallelize( del module if load_module: - if not torch.distributed.is_initialized(): # we only support distributed training + if not torch.distributed.is_initialized(): # we only support loading in torchrun environment raise RuntimeError("Load CubeModule failed: torch.distributed is not initialized.") torch.distributed.barrier() cube_module_class = _load_cube_module_class( @@ -370,4 +373,9 @@ def parallelize( cube_savedir=cube_savedir, instance_name=instance_name, ) - return cube_module_class if is_module_class else cube_module_class() + if is_module_class: + return cube_module_class + else: + cube_module = cube_module_class() + cube_module.train(module_or_module_class.training) # set training state to the same as original module + return cube_module diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 4f8f9ba1..4327788e 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -372,6 +372,9 @@ def __init__(self): super().__init__() def forward(self, *args): + if not self.training: + return args + new_args = [] found_tensor = False for arg in args: @@ -407,11 +410,13 @@ def __init__(self): self._grad_sentry.grad = torch.tensor([self._EPSILON]) def forward(self, *args, **kwargs): - if self._grad_sentry.grad is None or self._grad_sentry.grad.item() == 0: - self.zero_grad() - self._grad_sentry.grad = torch.tensor([self._EPSILON]) - - new_args = self._add_grad_module(*args) + if self.training: + if self._grad_sentry.grad is None or self._grad_sentry.grad.item() == 0: + self.zero_grad() + self._grad_sentry.grad = torch.tensor([self._EPSILON]) + new_args = self._add_grad_module(*args) + else: + new_args = args return self._forward_impl(*new_args, **kwargs) def _forward_impl(self, *args, **kwargs): diff --git a/tests/launch_torchrun.py b/tests/launch_torchrun.py index 65d93e7a..6267aa19 100644 --- a/tests/launch_torchrun.py +++ b/tests/launch_torchrun.py @@ -1,14 +1,9 @@ from typing import Callable import uuid import torch -import logging -import time -import random from torch.distributed.run import elastic_launch, LaunchConfig -_logger = logging.getLogger(__name__) - def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): launch_config = LaunchConfig( @@ -27,7 +22,7 @@ def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): def torchrun(nproc_per_node: int, test_fn: Callable, *args, **kwargs): """Test utility for torchrun - + Example usage: ```python @@ -46,7 +41,7 @@ def torchrun(nproc_per_node: int, test_fn: Callable, *args, **kwargs): """ if not torch.cuda.is_available() or torch.cuda.device_count() < nproc_per_node: - _logger.warning(f"skip test on {nproc_per_node} gpus due to lack of cuda devices") + print(f"skip test on {nproc_per_node} gpus due to lack of cuda devices") return launch_torchrun(nproc_per_node, test_fn, *args, **kwargs) diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py new file mode 100644 index 00000000..6ecc7138 --- /dev/null +++ b/tests/parallel_module/test_inference.py @@ -0,0 +1,84 @@ +from pathlib import Path +import shutil +import tempfile +import torch +from torch import nn + +from cube.parallel import ComputeConfig, parallelize + +from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD +from ..launch_torchrun import torchrun + + +class FcRelu(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, out_features, bias=bias) + self.relu2 = nn.ReLU() + self.fc3 = CubeLinear(out_features, out_features, bias=bias) + self.relu3 = nn.ReLU() + + + def forward(self, x): + return self.relu3(self.fc3(self.relu2(self.fc2(self.relu1(self.fc1(x)))))) + + +class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc_relu1 = FcRelu(4, 4) + self.fc_relu2 = FcRelu(4, 4) + self.dropout = nn.Dropout(0.5) + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.fc_relu1(x) + x = self.fc_relu2(x) + x = self.dropout(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + dynamic_shape=True, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + +def _inference_worker(ngpus): + init_distributed() + init_random() + + tempdir = Path(tempfile.gettempdir()) / 'cube_inference_test' + if torch.distributed.get_rank() == 0 and tempdir.exists(): + shutil.rmtree(tempdir) + torch.distributed.barrier() + + model = Module() + model.eval() + cube_model = _to_cube_model(model, PASRandomSPMD, ComputeConfig(ngpus, ngpus), tempdir, 'test_inference') + + data = torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]) + assert not model.training + assert not cube_model.training + model.cuda() + + with torch.inference_mode(): + result = model(data) + cube_result = cube_model(data) + assert torch.allclose(result, cube_result, atol=1e-4) + +def test_inference1(): + torchrun(1, _inference_worker, 1) + + +def test_inference2(): + torchrun(2, _inference_worker, 2) From ebb325e2872aa37cd7e6f85174f3f6da36f17ce4 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 21 Aug 2023 03:07:33 +0000 Subject: [PATCH 1483/1892] Merged PR 1757: tests: remove tempdir on exit tests: remove tempdir on exit --- .gitignore | 1 + tests/parallel_module/common.py | 12 ++++++++++ tests/parallel_module/test_inference.py | 30 +++++++++++-------------- tests/parallel_module/test_submodule.py | 29 +++++++++++------------- 4 files changed, 39 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 7bcd3abd..b6267036 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ benchmark/deepspeed/Megatron-DeepSpeed gencode*.py fullmodel.pt +dist_param_map.pt diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 47d13730..17385b36 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -1,7 +1,9 @@ from datetime import datetime import math import random +import shutil from typing import List, Optional +import contextlib import torch from torch import nn @@ -165,3 +167,13 @@ def init_random(): torch.manual_seed(1) if torch.cuda.is_available(): torch.cuda.manual_seed(1) + + + +@contextlib.contextmanager +def clear_dir_on_rank0(tempdir): + if torch.distributed.get_rank() == 0 and tempdir.exists(): + shutil.rmtree(tempdir) + yield tempdir + if torch.distributed.get_rank() == 0 and tempdir.exists(): + shutil.rmtree(tempdir) diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 6ecc7138..255c6817 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -6,7 +6,7 @@ from cube.parallel import ComputeConfig, parallelize -from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD +from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD, clear_dir_on_rank0 from ..launch_torchrun import torchrun @@ -57,24 +57,20 @@ def _inference_worker(ngpus): init_distributed() init_random() - tempdir = Path(tempfile.gettempdir()) / 'cube_inference_test' - if torch.distributed.get_rank() == 0 and tempdir.exists(): - shutil.rmtree(tempdir) - torch.distributed.barrier() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_inference_test') as tempdir: + model = Module() + model.eval() + cube_model = _to_cube_model(model, PASRandomSPMD, ComputeConfig(ngpus, ngpus), tempdir, 'test_inference') - model = Module() - model.eval() - cube_model = _to_cube_model(model, PASRandomSPMD, ComputeConfig(ngpus, ngpus), tempdir, 'test_inference') + data = torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]) + assert not model.training + assert not cube_model.training + model.cuda() - data = torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]) - assert not model.training - assert not cube_model.training - model.cuda() - - with torch.inference_mode(): - result = model(data) - cube_result = cube_model(data) - assert torch.allclose(result, cube_result, atol=1e-4) + with torch.inference_mode(): + result = model(data) + cube_result = cube_model(data) + assert torch.allclose(result, cube_result, atol=1e-4) def test_inference1(): torchrun(1, _inference_worker, 1) diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 8459024c..8a08f725 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -11,7 +11,7 @@ from cube.parallel import ComputeConfig, parallelize from cube.runtime.module import ParallelModule -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively @@ -111,21 +111,18 @@ def _train(model): def _gpu_worker(pas, ngpus): init_distributed() - tempdir = Path(tempfile.gettempdir()) / 'cube_test' - if torch.distributed.get_rank() == 0 and tempdir.exists(): - shutil.rmtree(tempdir) - torch.distributed.barrier() - orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) - orig_results = _train(orig_module) - compiled_results = _train(compiled_module) - return ( - orig_results, - compiled_results, - compiled_module.fc_relu1.get_full_map(), - compiled_module.fc_relu1.get_dist_param_map(), - compiled_module.fc_relu2.get_full_map(), - compiled_module.fc_relu2.get_dist_param_map(), - ) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) + orig_results = _train(orig_module) + compiled_results = _train(compiled_module) + return ( + orig_results, + compiled_results, + compiled_module.fc_relu1.get_full_map(), + compiled_module.fc_relu1.get_dist_param_map(), + compiled_module.fc_relu2.get_full_map(), + compiled_module.fc_relu2.get_dist_param_map(), + ) def test_submodules_tp_gpu1(): From 6281ac8a1652c4f10c8eaba356340caf280580ed Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 21 Aug 2023 06:47:54 +0000 Subject: [PATCH 1484/1892] Merged PR 1763: parallel module: move some init code to ParallelModule parallel module: move some init code to ParallelModule --- cube/codegen/module/module.py | 9 ++------- cube/parallel.py | 5 ++++- cube/runtime/module.py | 21 ++++++++++++++++++--- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index e4f9ff3b..c81e112f 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -359,13 +359,8 @@ def gen( with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.model_init_statements) if as_parallel_module: - ib.insert_body('') - ib.insert_body(f'self.load_attr_content(Path(__file__).with_name("{FxModuleParser.ATTR_CONTENT_FILE}"))') - ib.insert_body(f'self.load_dist_param_map(Path(__file__).with_name("{FxModuleParser.ATTR_MAP_FILE}"))') - ib.insert_body('') - with ForBlock('reducer', f'self.reducers') as for_block: - for_block.insert_body(f'reducer.build_buckets()') - ib.insert_body(for_block.code) + cb.insert_body('') + ib.insert_body('self._post_init()') cb.insert_body('') cb.insert_body(ib.code) segment_idxs =[] diff --git a/cube/parallel.py b/cube/parallel.py index 982ff2d1..297e930f 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -136,10 +136,12 @@ def _gencode( expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.plan_ngpus)] expected_output_files.append(outdir / FxModuleParser.ATTR_CONTENT_FILE) expected_output_files.append(outdir / FxModuleParser.ATTR_MAP_FILE) + expected_output_files.append(outdir / ParallelModule.COMPUTE_CONFIG_FILE) existing_output_files = [f for f in outdir.glob('*') if f.is_file()] if existing_output_files: if all([output_file.exists() for output_file in expected_output_files]) \ - and len(existing_output_files) == len(expected_output_files): + and len(existing_output_files) == len(expected_output_files) \ + and torch.load(outdir / ParallelModule.COMPUTE_CONFIG_FILE) == compute_config: return else: raise RuntimeError(f'Output directory {outdir} is not empty. ' @@ -243,6 +245,7 @@ def _gencode( execplan = Grouping.apply(execplan) # code generation + torch.save(compute_config, outdir / ParallelModule.COMPUTE_CONFIG_FILE) runtime_ngpus = None if compute_config.plan_ngpus == compute_config.runtime_ngpus else compute_config.runtime_ngpus assert len(execplan.graph.device) == compute_config.plan_ngpus, f"{execplan.graph.device}" mgener = ModuleCodeGen(execplan, scale_ndevs=runtime_ngpus) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 4327788e..6f4f5efb 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -2,8 +2,10 @@ import logging import os import sys +from pathlib import Path import torch +from cube.graph.parser.fx.parser import FxModuleParser from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer @@ -393,9 +395,13 @@ def forward(self, *args): class ParallelModule(CubeModule): _EPSILON = 1e-7 # A small constant that can be represented by fp16 + COMPUTE_CONFIG_FILE = 'compute_config.pt' + def __init__(self): + if self.__class__ == ParallelModule: # not init via super().__init__() + raise RuntimeError(f"ParallelModule should not be initialized directly. Please derive it first") + super().__init__() - self._dist_param_map = None # should fill in sub classes. # register_full_backward_pre_hook requires the input tensor to be requires_grad # so we add a module to make sure the input tensor requires grad @@ -409,6 +415,15 @@ def __init__(self): self._grad_sentry = torch.nn.Parameter(torch.tensor([0.0], requires_grad=True)) self._grad_sentry.grad = torch.tensor([self._EPSILON]) + def _post_init(self): + module_file = Path(sys.modules[self.__module__].__file__) + self.load_attr_content(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE}")) + self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}")) + self._compute_config = torch.load(module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}")) + + for reducer in self.reducers: + reducer.build_buckets() + def forward(self, *args, **kwargs): if self.training: if self._grad_sentry.grad is None or self._grad_sentry.grad.item() == 0: @@ -435,5 +450,5 @@ def backward_hook(self, module, grad_output): def get_dist_param_map(self): return self._dist_param_map - def load_dist_param_map(self, filename: str): - self._dist_param_map = torch.load(filename) + def get_compute_config(self): + return self._compute_config From 1603249f7ab93948b618fc1fba9a340f7492db58 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 23 Aug 2023 02:43:43 +0000 Subject: [PATCH 1485/1892] Merged PR 1764: parallel module: refine training process parallel module: refine training process --- cube/parallel.py | 66 ++++++++- cube/runtime/module.py | 13 +- tests/parallel_module/test_submodule.py | 18 +-- tests/parallel_module/test_wholemodule.py | 173 ++++++++++++++++++++++ 4 files changed, 247 insertions(+), 23 deletions(-) create mode 100644 tests/parallel_module/test_wholemodule.py diff --git a/cube/parallel.py b/cube/parallel.py index 297e930f..9612acb0 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -1,3 +1,4 @@ +import types from typing import Callable, Any, Dict, Optional, Type, Union from pathlib import Path import inspect @@ -297,7 +298,7 @@ def parallelize( override: bool = False, instance_name: Optional[str] = None, load_module: bool = True, -) -> Union[None, CubeModule, Type[CubeModule]]: +) -> Union[None, ParallelModule, Type[ParallelModule]]: """ Convert a torch.nn.Module object or class to CubeModule object or class. @@ -382,3 +383,66 @@ def parallelize( cube_module = cube_module_class() cube_module.train(module_or_module_class.training) # set training state to the same as original module return cube_module + + +def build_optimizer( + module: torch.nn.Module, + optimizer_fn: Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]], + *args, + **kwargs, +) -> torch.optim.Optimizer: + """ + Build an optimizer for a module. + + To support parallelized module (CubeModule), we need to hook 4 places: + 1. optimizer constructor: + the parameters of optimizer will not be the same with the parameters of the module if we use zero + so we need to replace the parameters of optimizer with CubeModule.parameters_for_optimizer + It is impossible to make this change transparent to end users. + 2. optimizer.step(): + In zero mode, we have to call CubeModule.gather_params() after optimizer.step() + 3. optimizer.zero_grad(): + We need to call CubeModule.zero_grad() after optimizer.zero_grad() + 4. backward(): + we need to call CubeModule.sync_grads() after each CubeModule backward. + This is done with _AddGradModule and its hook in ParallelModule. + + Please note this DOES NOT work in end2end mode. + + Args: + module (torch.nn.Module): the module to be optimized + optimizer_fn (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): + It can be the optimizer class or optimizer factory function. + If it is a factory function, the signature should be the same with optimizer class constructor. + *args: the args for optimizer constructor + **kwargs: the kwargs for optimizer constructor + + Returns: + torch.optim.Optimizer: the optimizer you should use to train the module + """ + + if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): + raise RuntimeError("End2End mode is not supported") + + def _local_parameters(module: torch.nn.Module): + gen = module._named_members(lambda m: m._parameters.items()) + for _, param in gen: + yield param + + optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), *args, **kwargs) + + def _step_hook(opt, *args, **kwargs): + for m in module.modules(): + if isinstance(m, CubeModule): + m.gather_params() + optimizer.register_step_post_hook(_step_hook) + + orig_zero_grad = optimizer.zero_grad + def _patched_zero_grad_hook(self, set_to_none: bool = True): + orig_zero_grad(set_to_none) + for m in module.modules(): + if isinstance(m, CubeModule): + m.zero_grad() + optimizer.zero_grad = types.MethodType(_patched_zero_grad_hook, optimizer) + + return optimizer diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 6f4f5efb..16d898fb 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -5,8 +5,8 @@ from pathlib import Path import torch -from cube.graph.parser.fx.parser import FxModuleParser +from cube.graph.parser.fx.parser import FxModuleParser from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer @@ -394,7 +394,6 @@ def forward(self, *args): class ParallelModule(CubeModule): - _EPSILON = 1e-7 # A small constant that can be represented by fp16 COMPUTE_CONFIG_FILE = 'compute_config.pt' def __init__(self): @@ -408,13 +407,6 @@ def __init__(self): self._add_grad_module = _AddGradModule() self._add_grad_module.register_full_backward_pre_hook(self.backward_hook) - # if _grad_sentry.grad becomes None or zero - # we should zero grad in the next forward - # NOTE: this is a hacky way to detect whether the backward is called - # And it will add an extra parameter to the module - self._grad_sentry = torch.nn.Parameter(torch.tensor([0.0], requires_grad=True)) - self._grad_sentry.grad = torch.tensor([self._EPSILON]) - def _post_init(self): module_file = Path(sys.modules[self.__module__].__file__) self.load_attr_content(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE}")) @@ -426,9 +418,6 @@ def _post_init(self): def forward(self, *args, **kwargs): if self.training: - if self._grad_sentry.grad is None or self._grad_sentry.grad.item() == 0: - self.zero_grad() - self._grad_sentry.grad = torch.tensor([self._EPSILON]) new_args = self._add_grad_module(*args) else: new_args = args diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 8a08f725..d57f0c56 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -8,7 +8,7 @@ from torch import nn import numpy as np -from cube.parallel import ComputeConfig, parallelize +from cube.parallel import ComputeConfig, parallelize, build_optimizer from cube.runtime.module import ParallelModule from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 @@ -84,7 +84,7 @@ def _train(model): init_random() loss_fn = nn.BCELoss() - optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) data = [] DATA_SIZE = 20 UPDATE_FREQ = 1 # TODO: update_freq support @@ -136,13 +136,13 @@ def test_submodules_tp_gpu1(): assert torch.allclose(orig[1], compiled[1], rtol=1e-6, atol=1e-6) # loss # grad - compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items() if not k.endswith('_grad_sentry')} + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items()} assert len(orig[2]) == len(compiled_cleaned) for k in orig[2].keys(): assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) # weights - compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items() if not k.endswith('_grad_sentry')} + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items()} assert len(orig[3]) == len(compiled_cleaned) for k in orig[3].keys(): assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) @@ -152,8 +152,6 @@ def _get_fc_weights(state_dict: dict, prefix): result = {} new_state_dict = {} for k, v in state_dict.items(): - if k.endswith('_grad_sentry'): - continue if k.startswith(prefix): result[k[len(prefix):]] = v else: @@ -232,13 +230,13 @@ def test_submodules_dp_gpu1(): assert torch.allclose(orig[1], compiled[1], rtol=1e-6, atol=1e-6) # loss # grad - compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items() if not k.endswith('_grad_sentry')} + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items()} assert len(orig[2]) == len(compiled_cleaned) for k in orig[2].keys(): assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) # weights - compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items() if not k.endswith('_grad_sentry')} + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items()} assert len(orig[3]) == len(compiled_cleaned) for k in orig[3].keys(): assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) @@ -257,13 +255,13 @@ def test_submodules_dp_gpu2(): assert torch.allclose(orig[1], compiled[1], rtol=eps, atol=eps) # loss # grad - compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items() if not k.endswith('_grad_sentry')} + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items()} assert len(orig[2]) == len(compiled_cleaned) for k in orig[2].keys(): assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=eps, atol=eps) # weights - compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items() if not k.endswith('_grad_sentry')} + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items()} assert len(orig[3]) == len(compiled_cleaned) for k in orig[3].keys(): assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=eps, atol=eps) diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py new file mode 100644 index 00000000..4fa7cfd1 --- /dev/null +++ b/tests/parallel_module/test_wholemodule.py @@ -0,0 +1,173 @@ +import tempfile +import itertools +import re +from pathlib import Path +import shutil + +import torch +from torch import nn +import numpy as np + +from cube.parallel import ComputeConfig, parallelize, build_optimizer +from cube.runtime.module import ParallelModule + +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively + + +class FcRelu(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, out_features, bias=bias) + self.relu2 = nn.ReLU() + self.fc3 = CubeLinear(out_features, out_features, bias=bias) + self.relu3 = nn.ReLU() + + + def forward(self, x): + return self.relu3(self.fc3(self.relu2(self.fc2(self.relu1(self.fc1(x)))))) + + +class FcRelu_4_4(FcRelu): + def __init__(self): + super().__init__(4, 4) + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + dynamic_shape=True, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + + +def _create_modules(pas, compute_config, cube_savedir): + class OrigModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc_relu1 = FcRelu_4_4() + self.fc_relu2 = FcRelu_4_4() + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.fc_relu1(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + orig_module = OrigModule().cuda() + init_random() + compiled_module = _to_cube_model(OrigModule(), pas, compute_config, cube_savedir, 'orig_module_whole').cuda() + return orig_module, compiled_module + + +def _train(model): + init_random() + + loss_fn = nn.BCELoss() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + data = [] + DATA_SIZE = 20 + UPDATE_FREQ = 1 # TODO: update_freq support + for _ in range(DATA_SIZE): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + results = [] + for i, (x, y) in enumerate(data): + model.train() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + if i % UPDATE_FREQ == UPDATE_FREQ - 1: + grads = {n: p.grad for n, p in model.named_parameters()} + results.append(clone_to_cpu_recursively([y_pred, loss, grads])) + optimizer.step() + optimizer.zero_grad() + weights = {n: p.data for n, p in model.named_parameters()} + results[-1].append(clone_to_cpu_recursively(weights)) + return results + + +def _gpu_worker(pas, ngpus): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) + orig_results = _train(orig_module) + compiled_results = _train(compiled_module) + return ( + orig_results, + compiled_results, + compiled_module.get_full_map(), + compiled_module.get_dist_param_map(), + ) + + +def test_module_tp_gpu1(): + if not torch.cuda.is_available(): + print('skip test_submodules_tp_gpu1 due to lack of cuda devices') + return + results = launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) + orig_results, compiled_results, _, _ = results[0] + for orig, compiled in zip(orig_results, compiled_results): + assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred + assert torch.allclose(orig[1], compiled[1], rtol=1e-6, atol=1e-6) # loss + + # grad + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[2].items()} + assert len(orig[2]) == len(compiled_cleaned) + for k in orig[2].keys(): + assert torch.allclose(orig[2][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + + # weights + compiled_cleaned = {re.sub(r"_[0-9]+", '', k).replace('.', '_'): v for k, v in compiled[3].items()} + assert len(orig[3]) == len(compiled_cleaned) + for k in orig[3].keys(): + assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) + + +def _compare_weights(orig0, orig1, compiled0, compiled1, module_fullmap, module_dist_param_map): + cube_state = [(compiled0, {'state':{}}, module_dist_param_map[0], module_fullmap[0]), (compiled1, {'state':{}}, module_dist_param_map[1], module_fullmap[1])] + merged_state, _ = ParallelModule.merge_partial_states(cube_state) + assert len(compiled1) == len(compiled0) == len(orig0) + for k, v in merged_state.items(): + assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) + + +def test_module_tp_gpu2(): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + print('skip test_submodules_tp_gpu2 due to lack of cuda devices') + return + results = launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) + results0, results1 = results[0], results[1] + eps = 1e-4 + + module_fullmap = results0[2], results1[2] + module_dist_param_map = results0[3], results1[3] + + for orig0, compiled0, orig1, compiled1 in zip(results0[0], results0[1], results1[0], results1[1]): + assert torch.allclose(orig0[0], orig1[0], rtol=eps, atol=eps) # pred + assert torch.allclose(orig0[0], compiled0[0], rtol=eps, atol=eps) # pred + assert torch.allclose(orig1[0], compiled1[0], rtol=eps, atol=eps) # pred + + assert torch.allclose(orig0[1], orig1[1], rtol=eps, atol=eps) # loss + assert torch.allclose(orig0[1], compiled0[1], rtol=eps, atol=eps) # loss + assert torch.allclose(orig1[1], compiled1[1], rtol=eps, atol=eps) # loss + + # grad + for k in orig0[2].keys(): + assert torch.allclose(orig0[2][k], orig1[2][k], rtol=eps, atol=eps) + _compare_weights(orig0[2], orig1[2], compiled0[2], compiled1[2], module_fullmap, module_dist_param_map) + + # weights + for k in orig0[3].keys(): + assert torch.allclose(orig0[3][k], orig1[3][k], rtol=eps, atol=eps) + _compare_weights(orig0[3], orig1[3], compiled0[3], compiled1[3], module_fullmap, module_dist_param_map) From 90ba4b105e102678a2852f0ca30bf454922e396e Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 23 Aug 2023 08:35:20 +0000 Subject: [PATCH 1486/1892] Merged PR 1771: codegen: replace staticmethod with method codegen: replace staticmethod with method The goal is to make global CompileFlag to be an argument to codegen. --- cube/codegen/__init__.py | 3 +- cube/codegen/emit.py | 107 ++++++++++++++---------------- cube/codegen/frontend_mapping.py | 45 +++++-------- cube/codegen/lifecycle.py | 17 ++--- cube/codegen/module/autograd.py | 22 +++--- cube/codegen/module/module.py | 56 ++++++++-------- cube/codegen/schedule/schedule.py | 63 +++++++++--------- tox.ini | 2 +- 8 files changed, 143 insertions(+), 172 deletions(-) diff --git a/cube/codegen/__init__.py b/cube/codegen/__init__.py index b9af357b..84f4bcd6 100644 --- a/cube/codegen/__init__.py +++ b/cube/codegen/__init__.py @@ -1,3 +1,2 @@ - from cube.codegen.module.module import ModuleCodeGen -from cube.codegen.schedule.schedule import ScheduleCodeGen \ No newline at end of file +from cube.codegen.schedule.schedule import ScheduleCodeGen diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 6b6d2886..c91ad547 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -20,7 +20,7 @@ class IRValue: def __init__(self, name: str): self.name = name - + def __repr__(self): return self.name @@ -29,13 +29,10 @@ class CodeEmission: """ Basic emission """ - - @staticmethod - def node_name(node: IRCell) -> str: + def node_name(self, node: IRCell) -> str: return f"{node.name}{node.cid}" - - @staticmethod - def tensor_name(tensor: Any, prefix_attr: Optional[str] = None) -> str: + + def tensor_name(self, tensor: Any, prefix_attr: Optional[str] = None) -> str: """ Return the var name. For tensor, return the {prefix}{tensor.name}_{tensor.tid} @@ -54,21 +51,19 @@ def tensor_name(tensor: Any, prefix_attr: Optional[str] = None) -> str: if prefix_attr is not None and tensor.is_attr(): name = prefix_attr + name else: - name = str(IRSegment.modify_objects_of_complex(tensor, CodeEmission.tensor_name)).replace('\'', '') + name = str(IRSegment.modify_objects_of_complex(tensor, self.tensor_name)).replace('\'', '') return name - @staticmethod - def complex_name(val: Any, prefix_attr: Optional[str]=None) -> str: + def complex_name(self, val: Any, prefix_attr: Optional[str]=None) -> str: """ Return the val name with complex data type over IRObject Currently support complex data type of Dict, List, Tuple, IRObject """ - modifier = lambda t: IRValue(CodeEmission.tensor_name(t, prefix_attr)) + modifier = lambda t: IRValue(self.tensor_name(t, prefix_attr)) val = IRSegment.modify_objects_of_complex(val, modifier) return str(val) - @staticmethod - def tuple_name(tensors: List[Any], + def tuple_name(self, tensors: List[Any], skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: """ Return the tupled tensor name. @@ -83,34 +78,31 @@ def tuple_name(tensors: List[Any], for t in tensors: if isinstance(t, IRTensor) and skip_attr and t.is_attr(): continue - names.append(CodeEmission.tensor_name(t, prefix_attr)) + names.append(self.tensor_name(t, prefix_attr)) name = '(' + ', '.join(names + ['']) + ')' return name - - @staticmethod - def return_name(tensors: List[Any], + + def return_name(self, tensors: List[Any], skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: names = [] for t in tensors: if isinstance(t, IRTensor) and skip_attr and t.is_attr(): continue - names.append(CodeEmission.tensor_name(t, prefix_attr)) + names.append(self.tensor_name(t, prefix_attr)) names = '_' if len(names) == 0 else ', '.join(names) return names - - @staticmethod - def return_name_complex(vals: List[Any], + + def return_name_complex(self, vals: List[Any], skip_attr: bool = False, prefix_attr: Optional[str] = None) -> str: names = [] for t in vals: if isinstance(t, IRObject) and skip_attr and t.is_attr(): continue - names.append(CodeEmission.complex_name(t, prefix_attr)) + names.append(self.complex_name(t, prefix_attr)) names = '_' if len(names) == 0 else ', '.join(names) return names - - @staticmethod - def kwargs_name(**kwargs) -> str: + + def kwargs_name(self, **kwargs) -> str: """Get kwarg name""" names = [] # FIXME make the str include `""` @@ -118,23 +110,24 @@ def kwargs_name(**kwargs) -> str: # if isinstance(val, str) and not val.startswith('self.'): # kwargs[name] = '"' + val + '"' # turn object into name - modifier = lambda t: IRValue(CodeEmission.tensor_name(t)) + modifier = lambda t: IRValue(self.tensor_name(t)) kwargs = IRSegment.modify_objects_of_complex(kwargs, modifier) for name, val in kwargs.items(): names.append(f'{name}={val}') name = ', '.join(names) return name - + class FuncEmission(CodeEmission): + def __init__(self): + super().__init__() + self._emit_rules = Sign2EmitRule() + + def emit_dataloader(self, node: IRDataOperation) -> List[str]: + outputs = self.return_name(node.outputs()) + return [f'{outputs} = next({self.tensor_name(node.input(0))})'] - @staticmethod - def emit_dataloader(node: IRDataOperation) -> List[str]: - outputs = FuncEmission.return_name(node.outputs()) - return [f'{outputs} = next({FuncEmission.tensor_name(node.input(0))})'] - - @staticmethod - def emit_fnode(node: IRFwOperation, prefix_attr: str = None) -> List[str]: + def emit_fnode(self, node: IRFwOperation, prefix_attr: str = None) -> List[str]: """Emit forward node code The result will look like (the lines are split into `List[str]`) @@ -151,38 +144,37 @@ def emit_fnode(node: IRFwOperation, prefix_attr: str = None) -> List[str]: # insert comment if node.comment is not None: codes.append(f'# {node.comment}') - + signature = node.signature # setup arg string - inputs = [FuncEmission.tensor_name(t, prefix_attr=prefix_attr) for t in node.inputs()] + inputs = [self.tensor_name(t, prefix_attr=prefix_attr) for t in node.inputs()] # setup kwarg string kwargs = dict(**node.kwargs) for name, val in kwargs.items(): if isinstance(val, str) and not val.startswith('self.'): kwargs[name] = '"' + val + '"' # turn IRObject into name - modifier = lambda t: IRValue(CodeEmission.tensor_name(t)) + modifier = lambda t: IRValue(self.tensor_name(t)) kwargs = IRSegment.modify_objects_of_complex(kwargs, modifier) - emit_rule = Sign2EmitRule.map(signature) + emit_rule = self._emit_rules.map(signature) body = emit_rule(node, inputs, kwargs) if len(node.outputs()) == 0: code = body else: - outputs = [FuncEmission.tensor_name(t) for t in node.outputs()] + outputs = [self.tensor_name(t) for t in node.outputs()] outputs = ', '.join(outputs) code = f'{outputs} = {body}' codes.append(code) return codes - - @staticmethod - def emit_adapter(node: IRAdapter, prefix_attr: Optional[str] = None) -> List[str]: + + def emit_adapter(self, node: IRAdapter, prefix_attr: Optional[str] = None) -> List[str]: """ Emit the statment of the adapter call - + The resultant `List[str]` will be lines of the statements of the final - Python method for the targeted Segment, + Python method for the targeted Segment, without the method signature and the return statement. The fields storing intermediate codes that are populated by this method: @@ -204,23 +196,22 @@ def emit_adapter(node: IRAdapter, prefix_attr: Optional[str] = None) -> List[str for prim in prims: if len(prim.inputs()) == 1: - itensors = FuncEmission.tensor_name(prim.inputs()[0], prefix_attr=prefix_attr) + itensors = self.tensor_name(prim.inputs()[0], prefix_attr=prefix_attr) else: - itensors = FuncEmission.tuple_name(prim.inputs(), prefix_attr=prefix_attr) + itensors = self.tuple_name(prim.inputs(), prefix_attr=prefix_attr) prim_kwargs = dict(prim.kwargs) if async_op: prim_kwargs['async_op'] = True - kwargs = FuncEmission.kwargs_name(**prim_kwargs) - outputs = FuncEmission.return_name(prim.outputs()) + kwargs = self.kwargs_name(**prim_kwargs) + outputs = self.return_name(prim.outputs()) code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' codes.append(code) return codes - - @staticmethod - def emit_reducer(node: IRWeightReducer) -> List[str]: + + def emit_reducer(self, node: IRWeightReducer) -> List[str]: """ Emit the statment to invoke a reducer object. - + The fields storing intermediate codes that are populated by this method: - NONE """ @@ -228,13 +219,11 @@ def emit_reducer(node: IRWeightReducer) -> List[str]: code = f'{reducer_name}.sync_grads()' return [code] - @staticmethod - def emit_release(tensors: Iterable[IRTensor]) -> str: - tnames : Generator = (FuncEmission.tensor_name(t) for t in tensors) + def emit_release(self, tensors: Iterable[IRTensor]) -> str: + tnames : Generator = (self.tensor_name(t) for t in tensors) return 'del ' + ', '.join(tnames) - - @staticmethod - def get_backward_callsite_io_tensors(bwop: IRCell) -> Tuple: + + def get_backward_callsite_io_tensors(self, bwop: IRCell) -> Tuple: """ Get backward inputs and outputs ``` @@ -242,7 +231,7 @@ def get_backward_callsite_io_tensors(bwop: IRCell) -> Tuple: #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~ #inputs to 'backward' outputs of 'backward' ``` - + @return input_tensors List[IRSubTensor]: forward input tensors (backward input) @return output_tensors List[IRSubTensor]: forward output tensors (backward output) @return output_grads List[IRSubTensor]: gradient of forward output tensors diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index df80c58a..d467dbf4 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -7,28 +7,31 @@ from cube.ir.cten import IRTensor from cube.ir.operator import IRFwOperation -import torch - class Sign2EmitRule: """Emit rule for frontend PyTorch codegen""" - _sign2rule = {} + def __init__(self) -> None: + # the registered emit rules + self._sign2rule = { + 'torch.slice': self.emit_slice, + 'setattr': self.emit_setattr, + 'builtins.getattr': self.emit_getattr, + '_operator.getitem': self.emit_getitem, + } - @staticmethod - def map(signature: str) -> Callable: + def map(self, signature: str) -> Callable: """Get the emit rule for the given signature - + Args: signature (str): signature of the operator Returns: Callable: emit rule that takes the node, args (List[str]) and kwargs (Dict[str, str]) as input """ - return Sign2EmitRule._sign2rule.get(signature, Sign2EmitRule.emit_common) + return self._sign2rule.get(signature, self.emit_common) - @staticmethod - def emit_common(node: IRFwOperation, args: List[str], kwargs: Dict[str, str]) -> str: + def emit_common(self, node: IRFwOperation, args: List[str], kwargs: Dict[str, str]) -> str: """Default rule to join all args and kwargs""" signature = node.signature @@ -41,8 +44,7 @@ def emit_common(node: IRFwOperation, args: List[str], kwargs: Dict[str, str]) -> args = ", ".join(list(args) + kw_pairs) return f"{signature}({args})" - @staticmethod - def emit_slice(node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + def emit_slice(self, node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: """Special rule for generating slice node The op is: @@ -71,8 +73,7 @@ def emit_slice(node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[str, str return f"{in_tensor_var}[{', '.join(subscript_components)}]" - @staticmethod - def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + def emit_setattr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: """Special rule for generating setattr node """ @@ -80,27 +81,15 @@ def emit_setattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: member = f'"{arg_vars[1][5:]}"' return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" - @staticmethod - def emit_getattr(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + def emit_getattr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: """Special rule for generating getattr node """ return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" - @staticmethod - def emit_getitem(node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + def emit_getitem(self, node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: """Special rule for generating getitem node """ if len(arg_vars) == 2 and len(kw_pairs) == 0 and not arg_vars[1].replace('_', '').isdigit(): return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" else: - return Sign2EmitRule.emit_common(node, arg_vars, kw_pairs) - - -# the registered emit rules -Sign2EmitRule._sign2rule = { - 'torch.slice': Sign2EmitRule.emit_slice, - 'setattr': Sign2EmitRule.emit_setattr, - 'builtins.getattr': Sign2EmitRule.emit_getattr, - '_operator.getitem': Sign2EmitRule.emit_getitem, -} - + return self.emit_common(node, arg_vars, kw_pairs) diff --git a/cube/codegen/lifecycle.py b/cube/codegen/lifecycle.py index 1767bb38..0be80895 100644 --- a/cube/codegen/lifecycle.py +++ b/cube/codegen/lifecycle.py @@ -15,6 +15,7 @@ def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: graph_inputs = IRSegment.get_objects_from_complex(graph_inputs) graph_outputs = IRSegment.get_objects_from_complex(graph_outputs) + func_emission = FuncEmission() self.nodes: Dict[int] = {node: lid for lid, node in enumerate(nodes)} # the last line id of consuming or producing a tensor @@ -25,9 +26,9 @@ def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: # FIXME: consider the case of IRObject in the kwargs of IRFwOperation # is_activation = lambda t: isinstance(t, IRObject) and not t.is_attr() is_activation = lambda t: isinstance(t, IRSubTensor) and not t.is_attr() - + self.lifetime.update((tsin, 0) for tsin in graph_inputs if is_activation(tsin)) - + for i, node in enumerate(nodes): outputs : Iterable[IRObject] @@ -41,7 +42,7 @@ def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: # backward segment else: fw_inputs, fw_outputs, output_grads, input_grads = \ - FuncEmission.get_backward_callsite_io_tensors(node) + func_emission.get_backward_callsite_io_tensors(node) # remove loss gradient output_grads = [t for t in output_grads if not t.is_loss()] @@ -73,23 +74,23 @@ def release_tensors_after_line(self, line_id: int) -> List[IRSubTensor]: Get the releasable IRSubTensors after finish of executing of `line_id`. @param line_id int - + @return tensors List[IRSubTensors]: tensors that can be released. """ return self.release.get(line_id, []) - + def release_tensors_after_node(self, node: IRCell) -> List[IRSubTensor]: """ Get the releasable IRSubTensors after finish of executing of the node. @param line_id int - + @return tensors List[IRSubTensors]: tensors that can be released. """ assert node in self.nodes line_id = self.nodes[node] return self.release.get(line_id, []) - + def releasable_after_node(self, tensor: IRSubTensor, node: IRCell) -> bool: """ Check if the tensor is releasable after executing the node @@ -103,7 +104,7 @@ def releasable_after_node(self, tensor: IRSubTensor, node: IRCell) -> bool: assert tensor in self.lifetime[tensor] line_id = self.nodes[node] return self.lifetime[tensor] < line_id - + def releasable_after_line(self, tensor: IRSubTensor, line: int) -> bool: """ Check if the tensor is releasable after executing the node diff --git a/cube/codegen/module/autograd.py b/cube/codegen/module/autograd.py index dd17790c..25fa2782 100644 --- a/cube/codegen/module/autograd.py +++ b/cube/codegen/module/autograd.py @@ -1,5 +1,4 @@ from typing import List -from cube.codegen.emit import FuncEmission from cube.ir.tensor import IRSubTensor from cube.ir.adapter import IRAdapter @@ -26,43 +25,42 @@ def __init__(self): def emit_prim(self, prim: IRAdapterPrim) -> str: if len(prim.inputs()) == 1: - itensors = FuncEmission.tensor_name(prim.inputs()[0]) + itensors = self.tensor_name(prim.inputs()[0]) else: - itensors = FuncEmission.tuple_name(prim.inputs()) + itensors = self.tuple_name(prim.inputs()) kwargs = list() for name, val in prim.kwargs.items(): kwargs.append(f'{name}={val}') kwargs = ', '.join(kwargs) - outputs = FuncEmission.return_name(prim.outputs()) + outputs = self.return_name(prim.outputs()) code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' return code def gen(self, fadapter: IRAdapter) -> List[str]: assert fadapter.isfw() and fadapter.differentiable and fadapter.custom, "generate autograd for a non-differentiable adapter" assert fadapter.mirror is not None - name = AutogradAdapterCodeGen.name(fadapter) + name = self.name(fadapter) with ClassBlock(class_name=name, derived=['torch.autograd.Function']) as cb: # forward cb.insert_body('@staticmethod') - finputs = [FuncEmission.tensor_name(t) for t in fadapter.inputs()] + finputs = [self.tensor_name(t) for t in fadapter.inputs()] with FunctionBlock(func_name='forward', args=['ctx']+finputs) as fw: for prim in fadapter.prims: fw.insert_body(self.emit_prim(prim)) - outputs = FuncEmission.return_name(fadapter.outputs()) + outputs = self.return_name(fadapter.outputs()) fw.insert_body(f'return {outputs}') cb.insert_body(fw.code) # backward cb.insert_body('@staticmethod') badapter: IRAdapter = fadapter.mirror - binputs = [FuncEmission.tensor_name(t) for t in badapter.inputs()] + binputs = [self.tensor_name(t) for t in badapter.inputs()] with FunctionBlock(func_name='backward', args=['ctx']+binputs) as bw: for prim in badapter.prims: bw.insert_body(self.emit_prim(prim)) - outputs = FuncEmission.return_name(badapter.outputs()) + outputs = self.return_name(badapter.outputs()) bw.insert_body(f'return {outputs}') cb.insert_body(bw.code) return cb.code - - @staticmethod - def name(adapter: IRAdapter) -> str: + + def name(self, adapter: IRAdapter) -> str: return f'Adapter{adapter.cid}' diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index c81e112f..a5ed0192 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -305,11 +305,12 @@ def gen( # init customized adapter fsegments = [node for node in sequence if isinstance(node, IRSegment) and node.isfw()] + autograd_adapter_gen = AutogradAdapterCodeGen() for seg in fsegments: for adapter in seg.select(ntype=IRAdapter): if adapter.differentiable and adapter.custom: - gencode += AutogradAdapterCodeGen().gen(adapter) + ['', ''] - adapter.signature = AutogradAdapterCodeGen.name(adapter) + '.apply' + gencode += autograd_adapter_gen.gen(adapter) + ['', ''] + adapter.signature = autograd_adapter_gen.name(adapter) + '.apply' # initialize communication groups self.emit_comm_groups() @@ -346,9 +347,9 @@ def gen( for t in node.inputs(): if isinstance(t, IRSubTensor): if not t.is_attr(): - args.append(ModuleCodeGen.tensor_name(t)) + args.append(self.tensor_name(t)) else: - args.append(ModuleCodeGen.tensor_name(t)) + args.append(self.tensor_name(t)) node_args.append(args) # generate full code @@ -365,7 +366,7 @@ def gen( cb.insert_body(ib.code) segment_idxs =[] for idx, node in enumerate(gen_nodes): - name = ModuleCodeGen.node_name(node) + name = self.node_name(node) input_args = ['self'] + node_args[idx] forward_code = self.model_methods_bodies[idx] if isinstance(node, IRSegment): @@ -374,7 +375,7 @@ def gen( with FunctionBlock(func_name=name, args=input_args) as fb: fb.insert_body(forward_code) # generate output - outputs = [ModuleCodeGen.tensor_name(t) for t in node.outputs()] + outputs = [self.tensor_name(t) for t in node.outputs()] return_code = f"return {', '.join(outputs)}" fb.insert_body(return_code) cb.insert_body('') @@ -403,10 +404,10 @@ def gen( raise ValueError(f"Invalid extra forward args: only *args & **kwargs are allowed") with FunctionBlock(func_name='_forward_impl', args=['self'] + (forward_arg_names or inputs)) as fb: - outputs = ScheduleCodeGen.return_name(node.outputs(), skip_attr=True) - call_code = f'{outputs} = self.{ScheduleCodeGen.node_name(node)}({", ".join(inputs)})' + outputs = self.return_name(node.outputs(), skip_attr=True) + call_code = f'{outputs} = self.{self.node_name(node)}({", ".join(inputs)})' fb.insert_body(call_code) - return_code = f'return {ScheduleCodeGen.return_name_complex(self.execplan.graph.outputs())}' + return_code = f'return {self.return_name_complex(self.execplan.graph.outputs())}' fb.insert_body(return_code) cb.insert_body('') cb.insert_body(fb.code) @@ -455,13 +456,13 @@ def init_attributes(self, node: IRCell): map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" if not isinstance(node, IRSegment): for itensor in node.inputs(): - name = ModuleCodeGen.tensor_name(itensor, prefix_attr='self.') + name = self.tensor_name(itensor, prefix_attr='self.') if isinstance(itensor, IRSubTensor): if itensor.is_attr() and not self.symbols.exist(name): self.symbols.create(name) sign = psign if itensor.is_param() else bsign code = sign.format( - name=ModuleCodeGen.tensor_name(itensor), + name=self.tensor_name(itensor), shape=tuple(itensor.shape), dtype=itensor.dtype ) @@ -470,7 +471,7 @@ def init_attributes(self, node: IRCell): slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) val_chunks = itensor.valmap[1] code = map_sign.format( - attr=ModuleCodeGen.tensor_name(itensor), tid=tid, + attr=self.tensor_name(itensor), tid=tid, slicers=str(slicers), val_chunks=val_chunks ) self.model_init_statements.append(code) @@ -480,7 +481,7 @@ def init_attributes(self, node: IRCell): if not hasattr(self._ref_module, name[5:]): raise NotImplementedError("member attribute is not added") for output in node.outputs(): - self.symbols.create(ModuleCodeGen.tensor_name(output, prefix_attr='self.')) + self.symbols.create(self.tensor_name(output, prefix_attr='self.')) else: for sub_node in node.nodes(): self.init_attributes(sub_node) @@ -516,15 +517,14 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: reducer=reducer_name, ranks=ranks, reduce_op=reduce_op, async_op=async_op, zero=zero, max_nbytes=max_nbytes, zero_ngroups=zero_ngroups) self.model_init_statements.append(init_code) - weights = [ModuleCodeGen.tensor_name(t, prefix_attr='self.') for t in weights] + weights = [self.tensor_name(t, prefix_attr='self.') for t in weights] for weight in weights: add_param_code = add_param.format(reducer=reducer_name, weight=weight) self.model_init_statements.append(add_param_code) add_code = reducer_add.format(reducer=reducer_name) self.model_init_statements.append(add_code) - @staticmethod - def emit_segment(segment: IRSegment) -> List[str]: + def emit_segment(self, segment: IRSegment) -> List[str]: """ Emit IRSegment code. @@ -564,11 +564,11 @@ def emit_segment(segment: IRSegment) -> List[str]: assert len(rc_group) > 0 gid: Optional[int] = rc_group[0].recompute if gid is None: - codes += ModuleCodeGen._emit_nodes(rc_group, lifetime) + codes += self._emit_nodes(rc_group, lifetime) else: # get recompute excution code rc_segment = segment.create_segment(rc_group) - rc_codes = ModuleCodeGen._emit_recompute(rc_group, + rc_codes = self._emit_recompute(rc_group, rc_segment.inputs(), rc_segment.outputs(), lifetime) codes += rc_codes # release input tensors after exiting a RC group: @@ -577,13 +577,12 @@ def emit_segment(segment: IRSegment) -> List[str]: if last_node != nodes[-1]: # skip if it is the last node inputs_to_rel = [t for t in rc_segment.inputs() if lifetime.releasable_after_line(t, line)] if len(inputs_to_rel) > 0: - del_stmt = ModuleCodeGen.emit_release(inputs_to_rel) + del_stmt = self.emit_release(inputs_to_rel) codes.append(del_stmt) return codes - @staticmethod - def _emit_nodes(nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: + def _emit_nodes(self, nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: """ Emit code to invoke operations and adapter, e.g. (the lines are split into `List[str]`) @@ -602,22 +601,21 @@ def _emit_nodes(nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: for node in nodes: # execute if isinstance(node, IRFwOperation): - code = ModuleCodeGen.emit_fnode(node, prefix_attr='self.') + code = self.emit_fnode(node, prefix_attr='self.') node_codes += code elif isinstance(node, IRAdapter): - code = ModuleCodeGen.emit_adapter(node) + code = self.emit_adapter(node) node_codes += code else: raise RuntimeError(f"unexpected type {type(node)} in IRSegment") # release tensors_to_del = lifecycle.release_tensors_after_node(node) if len(tensors_to_del) > 0: - node_codes.append(FuncEmission.emit_release(tensors_to_del)) + node_codes.append(self.emit_release(tensors_to_del)) return node_codes - @staticmethod - def _emit_recompute(nodes: Tuple[IRCell], inputs: List[IRSubTensor], outputs: List[IRSubTensor], + def _emit_recompute(self, nodes: Tuple[IRCell], inputs: List[IRSubTensor], outputs: List[IRSubTensor], lifecycle: LifeCycle) -> List[str]: """ Emit code to define a Python function for Recomputing and invoke it @@ -648,9 +646,9 @@ def recompute(tensor_2222): assert len(nodes) > 0 inputs = [t for t in inputs if not t.is_attr()] - input_names = [FuncEmission.tensor_name(t) for t in inputs] + input_names = [self.tensor_name(t) for t in inputs] input_names_tuple = ', '.join(input_names) - output_names = [FuncEmission.tensor_name(t) for t in outputs] + output_names = [self.tensor_name(t) for t in outputs] output_names_tuple = ', '.join(output_names) # 'graph.segment(nodes)' ensures that if a tensor is no longer used (in RC group or in later code), @@ -669,7 +667,7 @@ def recompute(tensor_2222): # for ncode in ModuleCodeGen._emit_nodes(nodes, lifecycle): # fb.insert_body(ncode) - fb.insert_body(ModuleCodeGen._emit_nodes(nodes, lifecycle)) + fb.insert_body(self._emit_nodes(nodes, lifecycle)) fb.insert_body(f'return {output_names_tuple}') codes = [''] + fb.code + [''] codes.append( diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index 5194a2c0..a62765df 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -55,9 +55,9 @@ def gen(self, device: int, outfile=None, attach=None) -> str: lifetime = LifeCycle(device_nodes, [], self.execplan.graph.outputs()) - args = ['model'] + [ScheduleCodeGen.tensor_name(t) for t in self.execplan.graph.inputs()] + args = ['model'] + [self.tensor_name(t) for t in self.execplan.graph.inputs()] - with FunctionBlock(func_name='_train_step', + with FunctionBlock(func_name='_train_step', args=args) as fb: fb.insert_body('_ = None') fb.insert_body('model.zero_grad()') @@ -71,14 +71,14 @@ def gen(self, device: int, outfile=None, attach=None) -> str: else: for line, node in enumerate(device_nodes): # execute - codes = ScheduleCodeGen.emit_node(node) + codes = self.emit_node(node) fb.insert_body(codes) # release tensors = lifetime.release_tensors_after_line(line) if len(tensors) > 0 : # not necessarily to have one after each line - fb.insert_body(ScheduleCodeGen.emit_release(tensors)) + fb.insert_body(self.emit_release(tensors)) # return code - outputs = ScheduleCodeGen.return_name_complex(self.execplan.graph.outputs()) + outputs = self.return_name_complex(self.execplan.graph.outputs()) code = f'return {outputs}' fb.insert_body(code) gencode += fb.code @@ -93,7 +93,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: if isinstance(self.execplan.graph.sched, IRScheduleStrategy): _logger.warning('using legacy IRScheduleStrategy cannot generate inference code. ' 'Switch to use scheduling without strategy') - with FunctionBlock(func_name='_infer_step', + with FunctionBlock(func_name='_infer_step', args=args) as fb: fb.insert_body('_ = None') # body code @@ -102,15 +102,15 @@ def gen(self, device: int, outfile=None, attach=None) -> str: for line, node in enumerate(device_nodes): if not node.isfw(): continue # skip backward segments and adapters # execute - codes = ScheduleCodeGen.emit_node(node, force_no_grad=True) + codes = self.emit_node(node, force_no_grad=True) fb.insert_body(codes) # release tensors = lifetime.release_tensors_after_line(line) tensors = [t for t in tensors if isinstance(t, IRTensor) and not t.is_grad()] if len(tensors) > 0 : # not necessarily to have one after each line - fb.insert_body(ScheduleCodeGen.emit_release(tensors)) + fb.insert_body(self.emit_release(tensors)) # return code - outputs = ScheduleCodeGen.return_name_complex(self.execplan.graph.outputs()) + outputs = self.return_name_complex(self.execplan.graph.outputs()) code = f'return {outputs}' fb.insert_body(code) gencode += fb.code @@ -123,8 +123,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: f.write(code) return code - @staticmethod - def emit_node(node: IRCell, force_no_grad: bool = False) -> List[str]: + def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: """ Emit node / subgraph code """ @@ -137,12 +136,12 @@ def emit_node(node: IRCell, force_no_grad: bool = False) -> List[str]: req_grad = False if force_no_grad else req_grad # handle for forward - inputs = ScheduleCodeGen.tuple_name(node_inputs, skip_attr=True, prefix_attr='model.') - outputs = ScheduleCodeGen.return_name(node_outputs, skip_attr=True, prefix_attr='model.') - + inputs = self.tuple_name(node_inputs, skip_attr=True, prefix_attr='model.') + outputs = self.return_name(node_outputs, skip_attr=True, prefix_attr='model.') + unwrap_node = node.cell if isinstance(node, ExeReuseCell) else node - name = ScheduleCodeGen.node_name(unwrap_node) - + name = self.node_name(unwrap_node) + if isinstance(unwrap_node, IRSegment): # emit forward segment if node.isfw(): @@ -156,21 +155,21 @@ def emit_node(node: IRCell, force_no_grad: bool = False) -> List[str]: else: # get gradient computation arguments input_tensors, output_tensors, output_grads, input_grads = \ - ScheduleCodeGen.get_backward_callsite_io_tensors(node) + self.get_backward_callsite_io_tensors(node) # special handle for loss for idx, tensor in enumerate(output_grads): if isinstance(tensor, IRSubTensor) and tensor.is_loss(): output_grads[idx] = None code = bsign.format( - name = f"'{ScheduleCodeGen.node_name(unwrap_node.mirror)}'", - input_grads = ScheduleCodeGen.return_name(input_grads), - input_tensors = ScheduleCodeGen.tuple_name(input_tensors, skip_attr=True, prefix_attr='model.'), - output_tensors = ScheduleCodeGen.tuple_name(output_tensors, skip_attr=True, prefix_attr='model.'), - output_grads = ScheduleCodeGen.tuple_name(output_grads, skip_attr=True, prefix_attr='model.') + name = f"'{self.node_name(unwrap_node.mirror)}'", + input_grads = self.return_name(input_grads), + input_tensors = self.tuple_name(input_tensors, skip_attr=True, prefix_attr='model.'), + output_tensors = self.tuple_name(output_tensors, skip_attr=True, prefix_attr='model.'), + output_grads = self.tuple_name(output_grads, skip_attr=True, prefix_attr='model.') ) elif isinstance(unwrap_node, IRDataOperation): - code = ScheduleCodeGen.emit_dataloader(unwrap_node)[0] + code = self.emit_dataloader(unwrap_node)[0] elif isinstance(unwrap_node, IRAdapter): code = asign.format( @@ -190,22 +189,20 @@ def emit_node(node: IRCell, force_no_grad: bool = False) -> List[str]: else: raise RuntimeError(f"Unspported node type: {type(unwrap_node)}") - + return [code] - - @staticmethod - def emit_repetend(repetend: ExeRepetend) -> List[str]: + + def emit_repetend(self, repetend: ExeRepetend) -> List[str]: """ Emit code for executing a repetend """ with ForBlock(var=None, iters=f'range({repetend.repeat})') as fb: for node in repetend.nodes(): - ncode = ScheduleCodeGen.emit_node(node) + ncode = self.emit_node(node) fb.insert_body(ncode) return fb.code - @staticmethod - def emit_legacy_schedplan(schedplan: IRScheduleStrategy, devid: int) -> List[str]: + def emit_legacy_schedplan(self, schedplan: IRScheduleStrategy, devid: int) -> List[str]: """ Lagecy code """ @@ -214,13 +211,13 @@ def emit_legacy_schedplan(schedplan: IRScheduleStrategy, devid: int) -> List[str strkwargs = dict() for kwarg, val in kwargs.items(): if isinstance(val, IRCell): - name = 'model.' + ScheduleCodeGen.node_name(val) + name = 'model.' + self.node_name(val) elif isinstance(val, (tuple, list)): brackets = ')' if len(val) != 1 else ',)' - name = '(' + ', '.join('model.' + ScheduleCodeGen.node_name(n) \ + name = '(' + ', '.join('model.' + self.node_name(n) \ if isinstance(n, IRCell) else str(n) for n in val) + brackets else: - name = str(val) + name = str(val) strkwargs[kwarg] = name code = ', '.join(f'{kwarg}={name}' for kwarg, name in strkwargs.items()) code = f'{signature}({code})' diff --git a/tox.ini b/tox.ini index 7598dac7..4411b3ac 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = py310 +envlist = py38,py310 skipsdist = True [testenv] From a56107d1f336681bf5cf746d91339ee0ed4505dd Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 23 Aug 2023 08:41:15 +0000 Subject: [PATCH 1487/1892] Merged PR 1768: Add load and save interface for compile --- cube/compiler.py | 66 ++++++++++++++++++++++++++++++------------------ cube/program.py | 6 ++--- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 1fcd3a92..ab2541ea 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -34,9 +34,11 @@ _logger.setLevel(logging.INFO) -def compile(model: SemanticModel, *args, +def compile(model: Union[torch.nn.Module, SemanticModel], *args, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, model_dynamic_shape: bool = False, + load_graph_file: Optional[str] = None, + save_graph_file: Optional[str] = None, comm_cost_fn: Optional[Callable] = None, override = True, load_content = True, @@ -57,8 +59,11 @@ def train_iter(model, dataloader): model (SemanticModel | torch.nn.Module): single-device model args (Tuple[Any]): compile function example inputs PAS (Callable | Tuple[Callable, Callable, Callable]): policy to transform and schedule graph - model_dummy_inputs (Tuple[Any]): model example inputs when using torch.fx parser model_dynamic_shape (bool): whether to compile model with dynamic shape + load_graph_file (str | None): + load cached graph. This will skip parsing the function and model. + Note the user should keep correct `fullmodel.pt` if load_content is True. + save_graph_file (str | None): save parsed graph before applying policy. comm_cost_fn (Optional[Callable]): communication cost function, which takes in an IRAdapterPrim, and outputs a cost in float. By default (None) use communication volume. @@ -104,6 +109,7 @@ def train_iter(model, dataloader): myrank = DeviceGroup().rank def decorator(fn: Callable) -> Callable: + filename = 'gencode{}.py' if not override and os.path.exists(filename.format(myrank)): @@ -121,31 +127,43 @@ def decorator(fn: Callable) -> Callable: resource = cube.runtime.resource.EnvResource() # run once to get model structure and tensor shape - start = time.time() - outputs = fn(*inputs) - if outputs is None: - outputs = [] - elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = [outputs] - # setup program input - pinputs = [] - for input in inputs[1:]: # we don't consider `model` as inputs - if isinstance(input, SemanticModel): - pinputs.append('model') - elif isinstance(input, SemanticDataLoader): - pinputs.append(input.object) - else: - pinputs.append(input) - Program().set_input(pinputs) - # setup program output - Program().set_output(outputs) - Program().finalize() - span = time.time() - start - _logger.info('finish parsing iteration: {:.2f} s'.format(span)) + graph = None + if load_graph_file is None: + start = time.time() + outputs = fn(*inputs) + if outputs is None: + outputs = [] + elif not isinstance(outputs, (tuple, list)): + outputs = [outputs] + # setup program input + pinputs = [] + for input in inputs[1:]: # we don't consider `model` as inputs + if isinstance(input, SemanticModel): + pinputs.append('model') + elif isinstance(input, SemanticDataLoader): + pinputs.append(input.object) + else: + pinputs.append(input) + Program().set_input(pinputs) + # setup program output + Program().set_output(outputs) + Program().finalize() + span = time.time() - start + graph = Program().get_graph() + _logger.info('finish parsing iteration: {:.2f} s'.format(span)) + else: + # get cube graph of a iteration from tracer or cached file + start = time.time() + graph = IRGraph.load(load_graph_file) + span = time.time() - start + _logger.info('finish loading graph from {}: {:.2f} s'.format(load_graph_file, span)) + + if save_graph_file is not None and save_graph_file != load_graph_file: + _logger.info(f'saving graph to {save_graph_file}') + graph.dump(save_graph_file) # run policy start = time.time() - graph = Program().get_graph() assert callable(PAS), f"Policy PAS is not callable" graph = PAS(graph, resource) span = time.time() - start diff --git a/cube/program.py b/cube/program.py index 18d977c5..7b81aded 100644 --- a/cube/program.py +++ b/cube/program.py @@ -84,8 +84,8 @@ def __init__(self, dataloader: data.DataLoader): """ if not isinstance(dataloader, data.DataLoader): raise TypeError("Expected data loader derived from torch.utils.data.DataLoader") - self.dataloader: data.DataLoader = iter(dataloader) - self.object = IRObject(name='dataloader', value=self.dataloader) + self.dataloader: data.DataLoader = dataloader + self.object = IRObject(name='dataloader', value=None) def __iter__(self): return self @@ -103,7 +103,7 @@ def generate_output(sample): return tensor return IRObject('data', value=sample) # get dataloader sample - sample = next(self.dataloader) + sample = next(iter(self.dataloader)) # turn sample into IRObjects outputs = generate_output(sample) # create dataloader operation From 9505f26ac93b74f27ae0e743c647f92ef1836f82 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 23 Aug 2023 08:55:16 +0000 Subject: [PATCH 1488/1892] Merged PR 1732: Restructure Codebase - 0: Clean up and make examples re-work 1) examples are updated to meet with concrete trace parser. 2) remove examples that are no more used in cube due to previous --- cube/__init__.py | 1 + examples/atmosphere/algo_test.py | 139 ----- examples/atmosphere/policy/naive.py | 11 - examples/atmosphere/policy/replicate.py | 13 - examples/atmosphere/policy/split.py | 43 -- examples/atmosphere/weather.py | 457 --------------- examples/mlp/__init__.py | 0 examples/mlp/infer.py | 119 ---- examples/mlp/linearsfx.py | 161 ------ examples/mlp/policy/__init__.py | 0 examples/mlp/policy/gallery.py | 89 +++ examples/mlp/policy/mpmd.py | 103 ---- examples/mlp/policy/spmd.py | 231 -------- examples/mlp/train.py | 97 ++-- examples/nlp/__init__.py | 0 examples/nlp/blocks/__init__.py | 0 examples/nlp/blocks/attention.py | 204 ------- examples/nlp/blocks/encoder.py | 101 ---- examples/nlp/blocks/mlp_moe.py | 212 ------- .../nlp/blocks/{decoder.py => transformer.py} | 39 +- examples/nlp/gpt/__init__.py | 0 examples/nlp/gpt/infer.py | 122 ---- examples/nlp/gpt/model.py | 215 ++----- examples/nlp/gpt/policy/__init__.py | 0 examples/nlp/gpt/policy/mpmd.py | 248 +-------- examples/nlp/gpt/policy/spmd.py | 303 ++-------- examples/nlp/gpt/train.py | 85 +-- examples/nlp/mbart/__init__.py | 0 examples/nlp/mbart/model.py | 246 +++----- examples/nlp/mbart/policy/__init__.py | 0 examples/nlp/mbart/policy/gallery.py | 186 +++++++ examples/nlp/mbart/policy/mpmd.py | 312 ----------- examples/nlp/mbart/train.py | 110 ++-- examples/nlp/palm/module_profiler.py | 56 -- examples/nlp/palm/palm.py | 334 ----------- examples/nlp/palm/policy/mpmd.py | 150 ----- examples/nlp/palm/policy/spmd.py | 91 --- examples/nlp/torchscale/policy/mpmd.py | 103 ---- examples/nlp/torchscale/policy/spmd.py | 252 --------- examples/nlp/torchscale/run_torchscale_lm.py | 149 ----- examples/nlp/torchscale/run_torchscale_tl.py | 204 ------- examples/poisson/policy/spmd.py | 27 - examples/poisson/sci.py | 82 --- examples/policies/alpa/estimator.py | 10 - examples/utils.py | 145 +++++ examples/vision/resnet/model.py | 281 ---------- examples/vision/resnet/model_alpa.py | 152 ----- examples/vision/resnet/train.py | 76 --- examples/vision/swin/model.py | 32 +- examples/vision/swin/policy/__init__.py | 0 examples/vision/swin/policy/gallery.py | 106 ++++ examples/vision/swin/policy/mpmd.py | 219 -------- examples/vision/swin/policy/spmd.py | 164 ------ examples/vision/swin/train.py | 66 ++- examples/wrf/policy/h_halo.py | 18 - examples/wrf/policy/hw_halo.py | 22 - examples/wrf/policy/naive.py | 8 - examples/wrf/policy/onedim.py | 165 ------ examples/wrf/wrf.py | 385 ------------- examples/wrf/wrf2.py | 527 ------------------ 60 files changed, 908 insertions(+), 6763 deletions(-) delete mode 100644 examples/atmosphere/algo_test.py delete mode 100644 examples/atmosphere/policy/naive.py delete mode 100644 examples/atmosphere/policy/replicate.py delete mode 100644 examples/atmosphere/policy/split.py delete mode 100644 examples/atmosphere/weather.py create mode 100644 examples/mlp/__init__.py delete mode 100644 examples/mlp/infer.py delete mode 100644 examples/mlp/linearsfx.py create mode 100644 examples/mlp/policy/__init__.py create mode 100644 examples/mlp/policy/gallery.py delete mode 100644 examples/mlp/policy/mpmd.py delete mode 100644 examples/mlp/policy/spmd.py create mode 100644 examples/nlp/__init__.py create mode 100644 examples/nlp/blocks/__init__.py delete mode 100644 examples/nlp/blocks/encoder.py delete mode 100644 examples/nlp/blocks/mlp_moe.py rename examples/nlp/blocks/{decoder.py => transformer.py} (50%) create mode 100644 examples/nlp/gpt/__init__.py delete mode 100644 examples/nlp/gpt/infer.py create mode 100644 examples/nlp/gpt/policy/__init__.py create mode 100644 examples/nlp/mbart/__init__.py create mode 100644 examples/nlp/mbart/policy/__init__.py create mode 100644 examples/nlp/mbart/policy/gallery.py delete mode 100644 examples/nlp/mbart/policy/mpmd.py delete mode 100644 examples/nlp/palm/module_profiler.py delete mode 100644 examples/nlp/palm/palm.py delete mode 100644 examples/nlp/palm/policy/mpmd.py delete mode 100644 examples/nlp/palm/policy/spmd.py delete mode 100644 examples/nlp/torchscale/policy/mpmd.py delete mode 100644 examples/nlp/torchscale/policy/spmd.py delete mode 100644 examples/nlp/torchscale/run_torchscale_lm.py delete mode 100644 examples/nlp/torchscale/run_torchscale_tl.py delete mode 100644 examples/poisson/policy/spmd.py delete mode 100644 examples/poisson/sci.py create mode 100644 examples/utils.py delete mode 100644 examples/vision/resnet/model.py delete mode 100644 examples/vision/resnet/model_alpa.py delete mode 100644 examples/vision/resnet/train.py create mode 100644 examples/vision/swin/policy/__init__.py create mode 100644 examples/vision/swin/policy/gallery.py delete mode 100644 examples/vision/swin/policy/mpmd.py delete mode 100644 examples/vision/swin/policy/spmd.py delete mode 100644 examples/wrf/policy/h_halo.py delete mode 100644 examples/wrf/policy/hw_halo.py delete mode 100644 examples/wrf/policy/naive.py delete mode 100644 examples/wrf/policy/onedim.py delete mode 100644 examples/wrf/wrf.py delete mode 100644 examples/wrf/wrf2.py diff --git a/cube/__init__.py b/cube/__init__.py index c5c32ea9..647c6a24 100644 --- a/cube/__init__.py +++ b/cube/__init__.py @@ -1,6 +1,7 @@ from typing import Optional import logging from cube import runtime +from cube import utils from cube import profiler from cube.profiler.timer import CudaTimer diff --git a/examples/atmosphere/algo_test.py b/examples/atmosphere/algo_test.py deleted file mode 100644 index 603321f9..00000000 --- a/examples/atmosphere/algo_test.py +++ /dev/null @@ -1,139 +0,0 @@ -""" -example: - -python -m torch.distributed.launch \ - --nproc_per_node=4 \ - --nnodes=1 \ - --node_rank=0 \ - --master_addr=127.0.0.1 \ - --master_port=8004 \ - --use_env \ - examples/mlp/linears.py - -OMP_NUM_THREADS=4 torchrun --standalone \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/linears.py - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=2 \ - --rdzv_id=888 \ - --rdzv_backend=c10d \ - --rdzv_endpoint=worker0:8004 \ - examples/mlp/linears.py -""" - -import torch -from torch import nn - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from examples.atmosphere.policy.split import PAS -import torch.nn.functional as F - - -# from examples.mlp.policy.col_parallel import P, A, S -# PAS = (P, A, S) - -# =================== Semantic Model Description ==================== - -class MLP(nn.Module): - def __init__(self, dim, mult=1, filter=None): - super().__init__() - self.linear1 = nn.Linear(dim, dim * mult) - - def forward(self, data): - a = self.linear1(data) - paded = F.pad(a, (1, 1), "constant", 8.8) - output = paded + 0 - # loss = torch.sum(output) - # return loss - return output - -class ConvModel(nn.Module): - def __init__(self, dim, mult=1, filter=None): - super().__init__() - # self.linear1 = nn.Linear(dim, dim * mult) - self.filter = filter - - def forward(self, data): - # a = self.linear1(data) - # paded = F.pad(a, (1, 1), "constant", 8.8) - # output = paded + 0 - added = data + 1.0 - convd = torch.nn.functional.conv3d(added, self.filter, padding=[1,1,1]) - output = convd + 0 - - # loss = torch.sum(output) - # return loss - return output - -def train(): - batch_size = 2 - dim = 4 - in_channel, out_channel = 2, 2 - dimT, dimH, dimW = 2, 4, 4 - kT, kH, kW = 1, 3, 3 - - to_test = "MLP" - to_test = "Conv3d" - if to_test == "MLP": - model = MLP(dim=dim) - model = cube.SemanticModel( - model, input_shapes=([batch_size, dim],), - ) - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, dim],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) - elif to_test == "Conv3d": - filter = torch.randn(out_channel, in_channel, kT, kH, kW) - model = ConvModel(dim=dim, filter=filter) - model = cube.SemanticModel( - model, input_shapes=([batch_size, in_channel, dimT, dimH, dimW],), - ) - - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, in_channel, dimT, dimH, dimW],), - dtypes=(torch.float32,), - batch_dims=(0,) - ) - - @cube.compile(model, dataloader, PAS=PAS, override=True) - def train_iter(model, dataloader): - data = next(dataloader) - # loss = model(data) - # loss.backward() - output = model(data) - return output - - model = model.get_gen_module() - - # optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - # CudaTimer(enable=False).warmup() - torch.distributed.barrier() - iter_num = 1 - for step in range(iter_num): - # if step >= 40: - # CudaTimer(enable=True).start('e2e') - output = train_iter(model, dataloader) - # optimizer.step() - # optimizer.zero_grad() - # if step >= 40: - # CudaTimer().stop('e2e') - # if (step + 1) % 20 == 0: - # print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - print(f'output = {output}') - - # print_each_rank('e2e time (ms) per iteration: {} ms'.format( - # CudaTimer().duration(iter_num - 40, field_name='e2e'))) - # CudaTimer().print_all(times=iter_num - 40) - - -if __name__ == '__main__': - cube.init() - train() \ No newline at end of file diff --git a/examples/atmosphere/policy/naive.py b/examples/atmosphere/policy/naive.py deleted file mode 100644 index 535c0de1..00000000 --- a/examples/atmosphere/policy/naive.py +++ /dev/null @@ -1,11 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.adapter import IRAdapter -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation - - -def PAS(graph: IRGraph, resource): - print(graph.extra_repr()) - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - return graph diff --git a/examples/atmosphere/policy/replicate.py b/examples/atmosphere/policy/replicate.py deleted file mode 100644 index e2099fa6..00000000 --- a/examples/atmosphere/policy/replicate.py +++ /dev/null @@ -1,13 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.adapter.adapter import IRAdapter -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation - - -def PAS(graph: IRGraph, resource): - print(graph.extra_repr()) - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus, reset_dependency=False) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - return graph diff --git a/examples/atmosphere/policy/split.py b/examples/atmosphere/policy/split.py deleted file mode 100644 index cf4d8eb8..00000000 --- a/examples/atmosphere/policy/split.py +++ /dev/null @@ -1,43 +0,0 @@ -from cube.graph import IRGraph -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.graph.function.conv import IRConv3D -from cube.graph.function.pad import IRPad - -def PAS(graph: IRGraph, resource): - print(graph.extra_repr()) - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - if isinstance(node, IRDataOperation): - print(f'### IRDataOperation = {node}') - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, config=dict(num=resource.ngpus)) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - elif isinstance(node, IRPad): - print(f'### IRPad = {node}') - sub_nodes = list() - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, config=dict(dim=1, num=min(2, resource.ngpus))) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif isinstance(node, IRConv3D): - print(f'### IRConv3D = {node}') - sub_nodes = list() - algo = node.algorithms('halo') - # Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus // 2)) - Wnodes = graph.partition(node, algo, config=dict(idx=0, dim=3, num=resource.ngpus)) - # for Wnode in Wnodes: - # algo = Wnode.algorithms('halo') - # Hnodes = graph.partition(Wnode, algo, config=dict(idx=0, dim=2, num=2)) - # sub_nodes += Hnodes - sub_nodes += Wnodes #TODO remove temp - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - else: - print(f'### to-replicate = {node}') - sub_nodes = graph.replicate(node, times=resource.ngpus, reset_dependency=False) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - else: - print(f'### non-IRBpOperation = {node}') - return graph diff --git a/examples/atmosphere/weather.py b/examples/atmosphere/weather.py deleted file mode 100644 index 12dc8261..00000000 --- a/examples/atmosphere/weather.py +++ /dev/null @@ -1,457 +0,0 @@ -import torch -import torch.nn.functional as F - -torch.set_default_tensor_type(torch.DoubleTensor) - -import cube -from cube.runtime.syndata import SciLoopVariables -from examples.atmosphere.policy.naive import PAS - -from einops.layers.torch import Rearrange - -#custom ops -import examples.custom_ops as custom_ops - -class Atmoshpere(torch.nn.Module): - def __init__(self, - nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, - bar_x_filter, bar_y_filter, bar_z_filter, - bar_x2_filter, bar_y2_filter, bar_xy_filter, - delta_x_filter, delta_y_filter, delta_z_filter, delta_D_filter, delta_E_filter, laplas_filter, - device='cuda'): - super().__init__() - #self.device = torch.device(device) - - # physics constant - self.g = 9.8 # acceleration of gravity, unit in m/s^2 - self.PSEA = 101325. # sea level pressure, unit in Pa - self.KAPPA = 0.286 # dimensionless - self.RE = 6.4e6 # radius of earth, unit in m - self.CPD = 1004.67 # specific heat of dry air at constant pressure J*kg^-1*K^-1 - # self.OMEGA = 7.292e-5 # angular speed of the Earth s^-1 - self.OMEGA = 1e-1 # angular speed of the Earth s^-1 - - # simulation domain - self.nx = nx - self.ny = ny - self.nz = nz - self.dx = dx - self.dy = dy - self.dz = 1. / nz - self.x0 = x0 - self.y0 = y0 - - self.deltaA = deltaA - self.Y = Y - self.f = f - self.sigma = sigma - self.P_ = P_ - self.P = P - self.phi = phi - self.zs = zs - self.w = w - - self.bar_x_filter = bar_x_filter - self.bar_y_filter = bar_y_filter - self.delta_x_filter = delta_x_filter - self.delta_y_filter = delta_y_filter - self.bar_y2_filter = bar_y2_filter - self.bar_x2_filter = bar_x2_filter - self.bar_z_filter = bar_z_filter - self.delta_z_filter = delta_z_filter - self.delta_E_filter = delta_E_filter - self.laplas_filter = laplas_filter - self.bar_xy_filter = bar_xy_filter - self.delta_D_filter = delta_D_filter - - self.pre_conv3d_reshape = Rearrange('(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # x = X.view(1, 1, Nz, Ny, Nx) - self.post_conv3d_reshape = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') - - self.pre_pady_reshape = Rearrange('(b0 Nz) Ny Nx -> b0 Nz Ny Nx', b0=1) - self.post_pady_reshape = Rearrange('b0 Nz Ny Nx -> (b0 Nz) Ny Nx') - - - def step(self, dt, pi, theta, u, v, pi0, theta0, u0, v0): - # flux - F = self.bar_x(self.pad_x(pi)) * 0.5 * u * self.RE * self.dy # (nz, ny, nx + 1) - G = self.bar_y(self.pad_y(pi)) * 0.5 * v * self.RE * self.dx * torch.cos(self.Y) # (nz, ny + 1, nx) - B = self.bar_y2(self.bar_x(self.pad_y(F))) / 12. # (nz, ny, nx) - C = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(G)))) / 12. # (nz, ny + 1, nx + 1) - D = self.bar_y2(self.pad_y(G)) + self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) - E = self.bar_y2(self.pad_y(G)) - self.bar_y(self.bar_x(self.pad_y(F))) / 24. # (nx, ny + 1, nx) - Q = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(F)))) / 12. # (nz, ny + 1, nx + 1) - R = self.bar_x2(self.bar_y(self.pad_x(G))) / 12. # (nz, ny, nx) - S = self.bar_y(self.bar_x(self.pad_x(G))) + self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) - T = self.bar_y(self.bar_x(self.pad_x(G))) - self.bar_x2(self.pad_x(F)) / 24. # (nz, ny, nx + 1) - - pi1 = pi0 - dt / self.deltaA * ((self.delta_x(F) + self.delta_y(G)) * self.dz).sum(dim=0) #sum(axis=0) # (nz, ny, nx) - # print('pi:', pi1.mean()) - - - # # update diagnostic variable w (nz + 1, ny, nx) - # for i in range(1, self.nz + 1): - # self.w[i] = - ((self.delta_x(F[:i]) + self.delta_y(G[:i])) * self.dz).sum(dim=0) / self.deltaA / pi1 \ - # - self.sigma[i] * (pi1 - pi0) / dt / pi1 - # TODO fix SetAttr for "self.w =" - custom_ops.update_diag_(self.w, F, G, self.delta_x_filter, self.delta_y_filter, self.deltaA, - pi0, pi1, self.sigma, dt, self.dz) - # print('w:', self.w.mean()) - - # update potential temperature theta (nz, ny, nx) - theta_ = self.pad_z( - (self.bar_z(self.P * theta) - self.delta_z(theta) * custom_ops.strip_2_borders(self.P_)) / self.delta_z(self.P) - ) # (nz + 1, ny, nx) - - theta1 = pi0 / pi1 * theta0 + dt / self.deltaA / pi1 * ( - (self.delta_x(F * self.bar_x(self.pad_x(theta))) + self.delta_y(G * self.bar_y(self.pad_y(theta)))) / 2. + - pi * self.deltaA * self.delta_z(self.w * theta_) / self.dz + - 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(theta)))) - ) - - # print('theta:', theta1.mean()) - - # TODO a custom Op needed - # # update geopotential - # self.phi[-1] = self.g * self.zs - self.CPD * (self.P[-1] - self.P_[-1]) * theta[-1] - # for i in range(1, self.nz): - # tmp = self.phi[-i] - self.CPD * (self.P_[-i - 1] - self.P[-i]) * theta[-i] - # self.phi[-1 - i] = tmp - self.CPD * (self.P[-1 - i] - self.P_[-1 - i]) * theta[-1 - i] - custom_ops.update_geopotential_(self.phi, self.zs, self.P, self.P_, theta, self.g, self.CPD, self.nz) - # print('phi:', self.phi.mean()) - - # update u (nz, ny, nx + 1) - pi0_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi0 * self.deltaA)))) / 8. # (nz, ny, nx + 1) - pi1_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi1 * self.deltaA)))) / 8. # (nz, ny, nx + 1) - pi_w_deltaA = self.bar_y2(self.bar_x(self.pad_y(self.pad_x(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny, nx + 1) - advec = ( - - self.delta_x(self.pad_x(B * self.bar_x(u))) - - self.delta_y(C * self.bar_y(self.pad_y(u))) - + self.delta_D(self.pad_x(D * self.bar_xy(self.pad_y(u)))) - + self.delta_E(self.pad_x(E * self.bar_xy(self.pad_y(u)))) - ) / 2. - trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(u)) * 0.5) / self.dz - #TODO fixme press = - self.RE * self.dy * ( - press = self.dy * ( - self.delta_x(self.pad_x(self.phi)) * self.delta_x(self.pad_x(pi)) / 2. + - self.delta_x(self.pad_x(pi)) * 0.5 * self.CPD * self.bar_x(self.pad_x( - theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) - )) - ) - diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(u)))) - #TODO fixme cori = self.RE * self.dx * self.dy * 0.25 * ( - cori = self.dy * ( - self.bar_x(self.pad_x(pi * self.bar_y(v) * (self.f + 0. * self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) - ) * 0.0 - u1 = (pi0_deltaA * u0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA - # print('u1:', u1.mean()) - - # # update v (nz, ny + 1, nx) - pi0_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi0 * self.deltaA)))) / 8. # (nz, ny + 1, nx) - pi1_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi1 * self.deltaA)))) / 8. # (nz, ny + 1, nx) - pi_w_deltaA = self.bar_x2(self.bar_y(self.pad_x(self.pad_y(pi * self.deltaA * self.w)))) / 8. # (nz + 1, ny + 1, nx) - advec = ( - - self.delta_x(Q * self.bar_x(self.pad_x(v))) - - self.delta_y(self.pad_y(R * self.bar_y(v))) - + self.delta_D(self.pad_y(S * self.bar_xy(self.pad_x(v)))) - + self.delta_E(self.pad_y(T * self.bar_xy(self.pad_x(v)))) - ) / 2. - trans = - self.delta_z(pi_w_deltaA * self.bar_z(self.pad_z(v)) * 0.5) / self.dz - #TODO fixme press = - self.RE * self.dx * ( - press = self.dx * ( - self.delta_y(self.pad_y(self.phi)) * self.delta_y(self.pad_y(pi)) / 2. + - self.delta_y(self.pad_y(pi)) * 0.5 * self.CPD * self.bar_y(self.pad_y( - theta / self.dz * (self.delta_z(self.sigma * self.P_) - self.P * self.delta_z(self.sigma)) - )) - ) - diff = 10e8 * self.laplas(self.pad_z(self.pad_y(self.pad_x(v)))) - #TODO fixme cori = - self.RE * self.dx * self.dy * 0.25 * ( - cori = self.dy * ( - self.bar_y(self.pad_y(pi * self.bar_x(u) * (self.f + self.bar_x(u) * torch.sin(self.bar_y(self.Y))))) - ) * 0.0 - v1 = (pi0_deltaA * v0 + dt * (advec + trans + press + diff + cori)) / pi1_deltaA - # # print('v1:', v1.mean()) - - return pi1, theta1, u1, v1 - - - def forward(self, pi, theta, u, v, dt): - pi_, theta_, u_, v_ = self.step(dt / 2., pi, theta, u, v, pi, theta, u, v) - pi1, theta1, u1, v1 = self.step(dt, pi_, theta_, u_, v_, pi, theta, u, v) - return pi1, theta1, u1, v1 - - - def pad_x(self, X): - return F.pad(X, (1, 1), "circular") - - - def bar_x(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_x_filter)) - - - def bar_x2(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_x2_filter)) - - - def delta_x(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_x_filter)) - - - def pad_y(self, X): - #TODO check return F.pad(X.view(1, nz, ny, nx), (0, 0, 1, 1), "circular").view(nz, ny + 2, nx) - return self.post_pady_reshape(F.pad(self.pre_pady_reshape(X), (0, 0, 1, 1), "circular")) - - - def bar_y(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_y_filter)) - - - def bar_y2(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_y2_filter)) - - - def delta_y(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_y_filter)) - - - def bar_z(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_z_filter)) - - - def pad_z(self, X): - # return F.pad(X.view(1, 1, nz, ny, nx), (0, 0, 0, 0, 1, 1)).view(nz + 2, ny, nx) - return F.pad(X, (0, 0, 0, 0, 1, 1)) - - - def delta_z(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_z_filter)) - - - def delta_D(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_D_filter)) - - - def delta_E(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.delta_E_filter)) - - - def bar_xy(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.bar_xy_filter)) - - - def laplas(self, X): - return self.post_conv3d_reshape(F.conv3d(self.pre_conv3d_reshape(X), self.laplas_filter)) - - -if __name__ == "__main__": - cube.init() - - nz = 15 - ny = 100 - nx = 100 - dy = 1e-4 - dx = 1e-4 - x0 = 0.0 - y0 = 0.2 - - PSEA = 101325. # sea level pressure, unit in Pa - RE = 6.4e6 # radius of earth, unit in m - - xc = nx * dx / 2 + x0 - yc = ny * dy / 2 + y0 - X = torch.linspace(0, nx - 1, nx).view(1, 1, nx) * dx + x0 - Y = torch.linspace(0, ny - 1, ny).view(1, ny, 1) * dy + y0 - ps = torch.ones((1, ny, nx)) * PSEA - 300 * torch.exp( - - 1e-6 * ((RE * torch.cos((Y + yc) / 2)) * (X - xc))**2 - - 1e-6 * (RE * (Y - yc))**2) - pt = 250e2 - zs = torch.zeros((ny, nx)) + 10000 * torch.exp( - - 1e-6 * (RE * (X - nx * dx / 3 - x0))**2 - - 1e-6 * (RE * (Y - yc))**2) - - u = torch.zeros((nz, ny, nx + 1)) - v = torch.zeros((nz, ny + 1, nx)) - - dt = torch.tensor(1.) - - # physics constant - g = 9.8 # acceleration of gravity, unit in m/s^2 - PSEA = 101325. # sea level pressure, unit in Pa - KAPPA = 0.286 # dimensionless - RE = 6.4e6 # radius of earth, unit in m - CPD = 1004.67 # specific heat of dry air at constant pressure J*kg^-1*K^-1 - # OMEGA = 7.292e-5 # angular speed of the Earth s^-1 - OMEGA = 1e-1 # angular speed of the Earth s^-1 - - # atmoshpere verticle profile - hight_profile = torch.tensor([ - 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, - 8.5, 9, 9.5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, - 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100 - ]) * 1e3 - - pressure_profile = torch.tensor([ - 1013.25, 1001.20, 989.45, 977.72, 966.11, 954.61, 943.22, 931.94, 920.77, 909.71, 898.80, 845.59, 795.0, - 746.9, 701.2, 657.8, 616.6, 577.5, 540.5, 505.4, 472.2, 440.7, 411.1, 383.0, 356.5, 331.5, 308.0, 285.8, - 265.0, 227.0, 194.0, 165.8, 141.7, 121.1, 103.5, 88.5, 75.7, 64.7, 55.3, 47.3, 40.5, 34.7, 29.7, 25.5, 21.9, - 18.8, 16.2, 13.9, 12.0, 10.3, 8.89, 7.67, 6.63, 5.75, 4.99, 4.33, 3.77, 3.29, 2.87, 2.51, 2.20, 1.93, 1.69, - 1.49, 1.31, 1.16, 1.02, 0.903, 0.903, 0.425, 0.220, 0.109, 0.0522, 0.0239, 0.0105, 0.0045, 0.0018, 0.00076, - 0.00032 - ]) * 1e2 - - temperature_profile = torch.tensor([ - 288.15, 287.50, 286.85, 286.20, 285.55, 284.90, 284.25, 283.60, 282.95, 282.30, 281.65, 278.40, 275.15, - 271.91, 268.66, 265.41, 262.17, 258.92, 255.68, 252.43, 249.19, 245.94, 242.70, 239.46, 236.22, 232.97, - 229.73, 226.49, 223.25, 216.78, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, 216.65, - 217.58, 218.57, 219.57, 220.56, 221.55, 222.54, 223.54, 224.53, 225.52, 226.51, 227.50, 228.49, 230.97, - 233.74, 236.51, 239.28, 242.05, 244.82, 247.58, 250.35, 253.11, 255.88, 258.64, 261.40, 264.16, 266.93, - 269.68, 270.65, 270.65, 270.65, 260.77, 247.02, 233.29, 219.59, 208.40, 198.64, 188.89, 186.87, 188.42, - 195.08 - ]) - - density_profile = torch.tensor([ - 1.225, 1.213, 1.202, 1.190, 1.179, 1.167, 1.156, 1.145, 1.134, 1.123, 1.112, 1.058, 1.007, 0.957, 0.909, - 0.863, 0.819, 0.777, 0.736, 0.697, 0.660, 0.624, 0.590, 0.557, 0.526, 0.496, 0.467, 0.440, 0.414, 0.365, - 0.312, 0.267, 0.228, 0.195, 0.166, 0.142, 0.122, 0.104, 0.0889, 0.0757, 0.0645, 0.0550, 0.0469, 0.0401, - 0.0343, 0.0293, 0.0251, 0.0215, 0.0184, 0.0158, 0.0136, 0.0116, 0.00989, 0.00846, 0.00726, 0.00624, 0.00537, - 0.00463, 0.00400, 0.00346, 0.00299, 0.00260, 0.00226, 0.00197, 0.00171, 0.0015, 0.00132, 0.00116, 0.00103, - 5.7e-4, 3.1e-4, 1.6e-4, 8.3e-4, 4.0e-5, 1.8e-5, 8.2e-6, 3.4e-6, 7.5e-7, 5.6e-7 - ]) - - def hight_from_pressure(p): - ind0 = torch.abs((p[None] - pressure_profile[(..., ) + (None, ) * len(p.shape)])).argmin(axis=0) - ind1 = (p > pressure_profile[ind0]) * 2 - 1 + ind0 - hight = (hight_profile[ind1] - hight_profile[ind0]) * (p - pressure_profile[ind0]) / ( - pressure_profile[ind1] - pressure_profile[ind0]) + hight_profile[ind0] - return hight - - def pressure_from_hight(z): - ind0 = torch.abs((z[None] - hight_profile[(..., ) + (None, ) * len(z.shape)])).argmin(axis=0) - ind1 = (hight_profile[ind0] > z) * 2 - 1 + ind0 - p = (pressure_profile[ind1] - pressure_profile[ind0]) * (z - hight_profile[ind0]) / \ - (hight_profile[ind1] - hight_profile[ind0]) + pressure_profile[ind0] - return p - - def temperature_from_pressure(p): - ind0 = torch.abs((p[None] - pressure_profile[(..., ) + (None, ) * len(p.shape)])).argmin(axis=0) - ind1 = (p > pressure_profile[ind0]) * 2 - 1 + ind0 - T = (temperature_profile[ind1] - temperature_profile[ind0]) * (p - pressure_profile[ind0]) / ( - pressure_profile[ind1] - pressure_profile[ind0]) + temperature_profile[ind0] - return T - - def bar_y(X): - nz, ny, nx = X.shape - filter = torch.tensor([1., 1.]).view(1, 1, 1, 2, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) - - def delta_z(X): - nz, ny, nx = X.shape - filter = torch.tensor([-1., 1.]).view(1, 1, 2, 1, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) - - def init(ps, pt, zs): - Y = ( - torch.linspace(0, ny, ny + 1) * dy + y0 - ).view(1, ny + 1, 1) - deltaA = RE**2 * torch.cos(bar_y(Y) * 0.5) * dx * dy # (1, ny, 1) - f = 2 * OMEGA * torch.sin(bar_y(Y)) * torch.cos(bar_y(Y)) * RE # (nz, ny, nx) - - # vertical grids - pt = torch.tensor([pt]).view(1, 1, 1) - zt = hight_from_pressure(pt) - z = torch.linspace(1, 0, nz + 1).view(-1, 1, 1) * zt - p_ = pressure_from_hight(z) - sigma = (p_ - pt) / (p_[-1] - pt) # (nz + 1, 1, 1) - - # column pressure, with shape (1, ny, nx) - pi = (ps - pt).view(1, ny, nx) - - # potential temperature factor - p_ = pt + sigma * pi # (nz + 1, ny, nx) - P_ = (p_ / PSEA)**KAPPA # (nz + 1, ny, nx) - P = delta_z(p_ * P_) / delta_z(p_) / (1 + KAPPA) # (nz, ny, nx) - - # potential temperature - p = PSEA * P**(1 / KAPPA) - T = temperature_from_pressure(p) - theta = T / P - - # geopotential (nz, ny, nx) - phi = torch.zeros((nz, ny, nx)) - zs = zs - - # vertical velocity - w = torch.zeros((nz + 1, ny, nx)) - - return pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w - - pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w = init(ps, pt, zs) - print("[pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w]") - for var in [pi, theta, deltaA, Y, f, sigma, P_, P, phi, zs, w]: - print(f'shape {var.shape}') - - bar_x_filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) - bar_y_filter = torch.tensor([1., 1.]).view(1, 1, 1, 2, 1) - delta_x_filter = torch.tensor([-1., 1.]).view(1, 1, 1, 1, 2) - delta_y_filter = torch.tensor([-1., 1.]).view(1, 1, 1, 2, 1) - bar_y2_filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 3, 1) - bar_x2_filter = torch.tensor([1., 2., 1.]).view(1, 1, 1, 1, 3) - bar_z_filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) - delta_z_filter = torch.tensor([-1., 1.]).view(1, 1, 2, 1, 1) - delta_E_filter = torch.tensor( - [[0., 1.], - [-1., 0.]] - ).view(1, 1, 1, 2, 2) - laplas_filter = torch.tensor( - [[[0., 0., 0.], - [0., 1., 0.], - [0., 0., 0.]], - [[0., 1., 0.], - [1., -6, 1.], - [0., 1., 0.]], - [[0., 0., 0.], - [0., 1., 0.], - [0., 0., 0.]]], - ).view(1, 1, 3, 3, 3) - bar_xy_filter = torch.tensor( - [[1., 0.], - [0., 1.]] - ).view(1, 1, 1, 2, 2) - delta_D_filter = torch.tensor( - [[1., 0.], - [0., -1.]] - ).view(1, 1, 1, 2, 2) - - model = Atmoshpere(nz, ny, nx, dy, dx, x0, y0, deltaA, Y, f, sigma, P_, P, phi, zs, w, - bar_x_filter, bar_y_filter, bar_z_filter, - bar_x2_filter, bar_y2_filter, bar_xy_filter, - delta_x_filter, delta_y_filter, delta_z_filter, delta_D_filter, delta_E_filter, laplas_filter) - - print("[pi, theta, u, v, dt]") - for var in [pi, theta, u, v, dt]: - print(f'shape {var.shape}') - - varloader = SciLoopVariables(variables=[pi, theta, u, v], constants=[dt]) - model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes)) - - @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) - def train_iter(model, dataloader): - pi, theta, u, v, dt = next(dataloader) - pi, theta, u, v = model(pi, theta, u, v, dt) - return pi, theta, u, v - model = model.get_gen_module() - - for i in range(3): - print("iter-{}...".format(i)) - pi, theta, u, v = train_iter(model, varloader) - - # # ctf = plt.contourf(pi.view(ny, nx).numpy(), levels=50, cmap='jet') - # plt.cla() - # ct = plt.contour(zs.view(ny, nx).cpu().numpy(), levels=[7000]) - # ctf = plt.contourf(u[3].cpu().numpy(), levels=50, cmap='jet') - # plt.colorbar(ctf) - # # plt.grid(True) - # plt.tight_layout() - # plt.savefig(f'res2/res{i}.jpeg', dpi=300) - # plt.clf() - - print(f'pi = {pi}; theta = {theta}; u = {u}; v = {v}') diff --git a/examples/mlp/__init__.py b/examples/mlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/mlp/infer.py b/examples/mlp/infer.py deleted file mode 100644 index fe63a445..00000000 --- a/examples/mlp/infer.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -example: - -ASYNC_COMM=1 OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/infer.py --policy PASMegatron -""" - -import torch -from torch import nn -import time - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - -import examples.mlp.policy.spmd as spmd -import examples.mlp.policy.mpmd as mpmd - -import argparse - -parser = argparse.ArgumentParser(description='MLP example') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -args = parser.parse_args() - -cube.init() - -# set up policy -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - - -# =================== Semantic Model Description ==================== - -class MLP(nn.Module): - def __init__(self, dim, mult=1, nlayers=4): - super().__init__() - self.layers = torch.nn.ModuleList([]) - for lid in range(nlayers): - if lid % 2 == 0: - self.layers.append(nn.Linear(dim, dim * mult, bias=False)) - else: - self.layers.append(nn.Linear(dim * mult, dim, bias=False)) - - def forward(self, data): - x = data - for layer in self.layers: - x = layer(x) - loss = torch.sum(x) - return loss - - -class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, bs: int, dim: int): - super().__init__(bs, [0]) - self.sample = None - self.dim = dim - self.set_batch_size(bs) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - self.batch_size = batch_size - self.sample = torch.rand( - [batch_size, self.dim], dtype=torch.float32, - device=torch.cuda.current_device() - ) - - -def infer(): - batch_size = 128 - dim = 4096 - - model = MLP(dim=dim) - model = cube.SemanticModel(model) - dataloader = MLPDataLoader(batch_size, dim) - - @cube.compile(model, dataloader, PAS=PAS) - def infer_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - model = model.get_gen_module() - - CudaTimer(enable=False).warmup() - if torch.distributed.is_initialized(): - torch.distributed.barrier() - iter_num = 16 - warmup = 4 - for step in range(iter_num): - if step >= warmup: - CudaTimer(enable=True, predefined=True).start('e2e') - infer_iter(model, dataloader) - if step >= warmup: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - torch.distributed.barrier() - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) - - -infer() \ No newline at end of file diff --git a/examples/mlp/linearsfx.py b/examples/mlp/linearsfx.py deleted file mode 100644 index f51b21e5..00000000 --- a/examples/mlp/linearsfx.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -example: - -//torchscript based DAG capture -PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/mlp/linearsfx.py --policy PASData -//torch.fx based DAG capture -USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/mlp/linearsfx.py --policy PASData -""" - -import torch -from torch import nn - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank - -import examples.mlp.policy.spmd as spmd -import examples.mlp.policy.mpmd as mpmd - -import argparse - -parser = argparse.ArgumentParser(description='comm primitive') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -parser.add_argument('--local_rank', type=int, default=0) -args = parser.parse_args() - -cube.init() - -# set up policy -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - - -# =================== Semantic Model Description ==================== - -class MLP(nn.Module): - def __init__(self, dim, mult=1, nlayers=4): - super().__init__() - self.layers = torch.nn.ModuleList([]) - for lid in range(nlayers): - if lid % 2 == 0: - self.layers.append(nn.Linear(dim, dim * mult, bias=False)) - last_dim = dim * mult - else: - self.layers.append(nn.Linear(dim * mult, dim, bias=False)) - last_dim = dim - - self.layer_norm = nn.LayerNorm(last_dim) #TODO CHECK torch.fx ignores LayerNorm - # self.p = 0.5 - self.drop_out = nn.Dropout() - # self.y = torch.nn.Parameter(torch.empty(128, last_dim)) - - def forward(self, data, mask): - x = data.masked_fill(mask, 0.0) - y = torch._shape_as_tensor(x) - z = torch.gt(x, x) - x = x.fill_(0.0) - x = torch.nn.functional.softmax(x, dim=-1) - x = torch.bmm(x, x) - adder = torch.sum(x, dim=2, keepdim=True) - x = torch.baddbmm(adder, batch1=x, batch2=x, alpha=0.125, beta=1.0) - x = torch.tanh(x) - x = torch.pow(x, x) - for layer in self.layers: - x = layer(x) - x = torch.nn.functional.relu(x) - x = torch.nn.functional.gelu(x) - x = self.layer_norm(x) - type_x = torch.pow(x, 1.0) - x = x.type_as(type_x) - x = x.unsqueeze(1) - x = self.drop_out(x) - x = x.squeeze() - x = torch.triu(x, 1) - x = torch.nan_to_num(x) - ne_var = x.detach() - ne_var = torch.ne(ne_var, 1.0) - eq_var = x.detach() - eq_var = torch.eq(eq_var, 1.0) - long_var = x.detach() - long_var = long_var.long() - floor_div_var = x.detach() - floor_div_var = torch.floor_divide(floor_div_var, 2.0) - x = torch.true_divide(x, 1.0) - x = torch.cumsum(x, dim=-1) - x = x.permute(0, 2, 1) - x = x.transpose(1, 2) - x = torch.div(x, 1.0) - # concrete_trace not support - # x = torch.Tensor.view(x, [32 * 1024, 1024]) - x = x.view(32 * 1024, 1024) - x = x.reshape(32, 1024, 1024) - neg_x = torch.neg(x) - x = torch.einsum('a b c, a c d -> a b d', x, neg_x) - # TODO(yizhu1): uncomment and check - # bs = x.size(1) - # indices = torch.arange(bs, dtype=torch.int64) - # x = torch.index_select(x, 1, indices) - p = torch.div(x, 2.0) - x = torch.stack((x, p), dim=1) - x = torch.flatten(x, 2, 3) - x = x.repeat(1, 2, 1) - loss = torch.sum(x) - return loss - - -def train(): - batch_size = 32 - dim = 1024 - - dataloader = cube.runtime.syndata.SynDataLoader( - shapes=([batch_size, dim, dim], [batch_size, dim, dim],), - dtypes=(torch.float32, torch.bool,), - batch_dims=(0, 0,) - ) - - model = MLP(dim=dim) - model = cube.SemanticModel(model) - - data, mask = next(dataloader) - @cube.compile(model, dataloader, PAS=PAS, load_content=False, - model_dummy_inputs={'data': data, 'mask': mask}) - def train_iter(model, dataloader): - data, mask = next(dataloader) - loss = model(data, mask) - loss.backward() - - model = model.get_gen_module() - - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - CudaTimer(enable=False).warmup() - if torch.distributed.is_initialized(): - torch.distributed.barrier() - iter_num = 2 #32 - warmup = 0 #8 - for step in range(iter_num): - if step >= warmup: - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= warmup: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num - warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num - warmup) - - -train() \ No newline at end of file diff --git a/examples/mlp/policy/__init__.py b/examples/mlp/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/mlp/policy/gallery.py b/examples/mlp/policy/gallery.py new file mode 100644 index 00000000..06f91aa0 --- /dev/null +++ b/examples/mlp/policy/gallery.py @@ -0,0 +1,89 @@ +from typing import List +from cube.graph import IRGraph +from cube.graph.segment import IRSegment +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.graph.schedule.predefined import PredefinedSched + +from examples.utils import tensor_parallelism, replica, create_mesh + + +def PASSingle(graph: IRGraph, resource, **kwargs): + """Single device""" + assert resource.ngpus == 1, "only apply for single gpu case" + for node in graph.nodes(): + if isinstance(node, (IRDataOperation, IRFwOperation)): + graph.assign(node, 0) + return graph + + +def PASData(graph: IRGraph, resource, **kwargs): + """Data Parallellism""" + devs = list(range(resource.ngpus)) + for node in graph.select(ntype=IRFwOperation): + tensor_parallelism(graph, node, idx=0, dim=0, devs=devs) + for node in graph.select(ntype=IRDataOperation): + replica(graph, node, devs=devs) + return graph + + +def PASCol(graph: IRGraph, resource, **kwargs): + """Linear Column Parallel""" + devs = list(range(resource.ngpus)) + for node in graph.select(name='linear'): + tensor_parallelism(graph, node, idx=1, dim=0, devs=devs) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + replica(graph, node, devs=devs) + return graph + + +def PASRow(graph: IRGraph, resource, **kwargs): + """Linear Row Parallel""" + devs = list(range(resource.ngpus)) + for node in graph.select(name='linear'): + tensor_parallelism(graph, node, idx=0, dim=1, devs=devs) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + replica(graph, node, devs=devs) + return graph + + +def PASMegatronTP(graph: IRGraph, resource, **kwargs): + """Linear Hybrid Parallelism (Megatron)""" + devs = list(range(resource.ngpus)) + for idx, node in enumerate(graph.select(name='linear')): + tensor_parallelism(graph, node, idx=1, dim=idx%2, devs=devs) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + replica(graph, node, devs=devs) + return graph + + +def PASMegatron(graph: IRGraph, resource, nmicros: int, tp_size: int, **kwargs): + + num_stages = resource.ngpus // tp_size + _, tp_mesh = create_mesh(resource.ngpus, (num_stages, tp_size)) + + # group to sub-graphs + linears = graph.select(name='linear') + stage_start_nodes = linears[::len(linears) // num_stages][:num_stages] + graph.staging(stage_start_nodes) + + segments = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + + for sid, segment in enumerate(fsegs): + # get tensor parallel group + tp_group = tp_mesh[sid] + for idx, node in enumerate(segment.nodes()): + if node.name == 'linear': + tensor_parallelism(graph, node, idx=1, dim=idx%2, devs=tp_group) + else: + replica(node, devs=tp_group) + + for dl in graph.select(ntype=IRDataOperation): + replica(dl, devs=list(range(resource.ngpus))) + + PredefinedSched.sched_1f1b(graph, nmicros, num_stages) + return graph + diff --git a/examples/mlp/policy/mpmd.py b/examples/mlp/policy/mpmd.py deleted file mode 100644 index bf00f3ee..00000000 --- a/examples/mlp/policy/mpmd.py +++ /dev/null @@ -1,103 +0,0 @@ -import random -from typing import Tuple -import numpy as np - -from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.schedule.predefined import PredefinedSched - - -def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: - """ - Create hybrid (nested) groups given the each group number. - - The product of group_num should be same with total devices. - """ - group_num = np.array(group_num) - cnt = np.prod(group_num) - assert cnt == ngpus, 'total device not match' - grid = np.arange(cnt).reshape(tuple(group_num)) - dims = list(range(len(group_num))) - outputs = [] - for dim, num in enumerate(group_num): - remain = ngpus // num - order = tuple(dims[:dim] + dims[dim+1:] + [dim]) - grid_dim = np.transpose(grid, order).reshape((remain,num)) - grid_dim = grid_dim.tolist() - outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) - assert len(outputs) == len(group_num) - return tuple(outputs) - - -def PASRandom(graph, resource): - """ - Random pipeline - """ - assert len(graph.nodes()) // 2 >= resource.ngpus, "not enough operator number." - remain_device = set(range(resource.ngpus)) - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - if len(remain_device) != 0: - idx = random.randint(0, len(remain_device) - 1) - device = list(remain_device)[idx] - remain_device.remove(device) - else: - device = random.randint(0, resource.ngpus - 1) - graph.assign(node, device) - elif isinstance(node, IRDataOperation): - device = random.randint(0, resource.ngpus - 1) - graph.assign(node, device) - print(graph.extra_repr()) - return graph - - -def PASMegatron(graph: IRGraph, resource): - - # assert resource.ngpus == 8, "should apply on 8 gpus" - num_stage = 4 - num_tp = resource.ngpus // num_stage - num_microbatch = resource.ngpus * 8 - - _, tp_mesh = _create_mesh(resource.ngpus, (num_stage, num_tp)) - print(f'> pipeline-tensor parallel group: {tp_mesh}') - assert len(tp_mesh) == num_stage - - linears = graph.select('linear') - stage_start_nodes = linears[::len(linears) // num_stage] - stage_start_nodes = stage_start_nodes[:num_stage] - assert len(stage_start_nodes) == num_stage, f"{len(stage_start_nodes)} != {num_stage}" - graph.staging(stage_start_nodes) - - segments = graph.select(ntype=IRSegment, flatten=False) - fsegs = [seg for seg in segments if seg.isfw()] - assert len(fsegs) == num_stage - - for sid, segment in enumerate(fsegs): - # get tensor parallel group - tp_group = tp_mesh[sid] - for idx, node in enumerate(segment.nodes()): - # partition - if node.name == 'linear': - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx % 2, num=num_tp) - else: - tp_nodes = graph.replicate(node, times=num_tp) - # assign - for devid, node in zip(tp_group, tp_nodes): - graph.assign(node, devid) - - for dl in graph.select(ntype=IRDataOperation): - mesh = tp_mesh[0] - dls = graph.replicate(dl, times=num_tp) - for devid, dl in zip(mesh, dls): - graph.assign(dl, devid) - - # setup schedule to 1F1B - # schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) - # graph.schedule_plan = schedule - if graph.train: - schedule = PredefinedSched.sched_1f1b(graph, num_microbatch, num_stage) - else: - schedule = PredefinedSched.sched_infer_pipe(graph, num_microbatch, num_stage) - return graph diff --git a/examples/mlp/policy/spmd.py b/examples/mlp/policy/spmd.py deleted file mode 100644 index e3bae1a4..00000000 --- a/examples/mlp/policy/spmd.py +++ /dev/null @@ -1,231 +0,0 @@ -from typing import List -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.gener.rvd.intra import IntraAutoPlacer - - -# tensor parallelism with auto-placer -# This is an implementation example of SPMD auto placer usage -def _tp_autoplace(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): - - if len(devs) == 1: - graph.assign(node, devs[0]) - return [node] - - segment: IRSegment = graph.segment(node) - ftensor = node.input(configs['idx']).parent - - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - producers = segment.producers(ftensor) - if ftensor.is_param() or len(producers) != len(sub_nodes): - print(f"> skip auto placer due to condition not matched: " - f"nproducers: {len(producers)}, nconsumers: {len(sub_nodes)}, " f"producer name: {producers[0].name if len(producers) > 0 else None}") - devs = sorted(list(devs)) - for devid, node in zip(devs, sub_nodes): - graph.assign(node, devid) - else: - devices = IntraAutoPlacer.auto_place( - segment, ftensor, producers, sub_nodes) - for devid, subnode in zip(devices, sub_nodes): - graph.assign(subnode, devid) - return sub_nodes - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def PASSingle(graph: IRGraph, resource): - """ - Single device - """ - assert resource.ngpus == 1, "only apply for single gpu case" - for node in graph.nodes(): - if isinstance(node, (IRDataOperation, IRFwOperation)): - graph.assign(node, 0) - return graph - - -def PASData(graph: IRGraph, resource): - """ - Data Parallel - """ - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - batch_dim = node.get_batch_dims()[0] - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=0, dim=batch_dim, num=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - return graph - - -def PASCol(graph: IRGraph, resource): - """ - Linear Column Parallel - """ - linears = [node for node in graph.nodes() if node.name == 'linear'] - for idx, node in enumerate(linears): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=1, dim=0, num=resource.ngpus - ) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - sub_nodes = graph.replicate(node, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - return graph - - -def PASRow(graph: IRGraph, resource): - """ - Linear Column Parallel - """ - devs = list(range(resource.ngpus)) - - for dl in graph.select(ntype=IRDataOperation): - sub_nodes = graph.replicate(dl, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - - for node in graph.select(ntype=IRFwOperation): - if node.name == 'linear': - _tp(graph, node, devs, idx=0, dim=1, num=len(devs)) - else: - _replica(graph, node, devs) - - return graph - - -def PASHybrid(graph: IRGraph, resource): - """ - Linear Hybrid Parallelism (Megatron) - """ - linears = [node for node in graph.nodes() if node.name == 'linear'] - for idx, node in enumerate(linears): - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=resource.ngpus) - for idx, node in enumerate(tp_nodes): - graph.assign(node, idx) - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - sub_nodes = graph.replicate(node, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - print(graph.extra_repr()) - return graph - - -def PASMegatronTP(graph: IRGraph, resource): - """ - Tensor + Data Parallelism - """ - tp = min(2, resource.ngpus) - dp = resource.ngpus // tp - linears = [node for node in graph.nodes() if node.name == 'linear'] - for idx, node in enumerate(linears): - sub_nodes = [] - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=tp) - for tp_node in tp_nodes: - algo = tp_node.algorithms('dim') - dp_nodes = graph.partition(tp_node, algo, idx=0, dim=0, num=dp) - sub_nodes += dp_nodes - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - sub_nodes = graph.replicate(node, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - # print(graph.extra_repr()) - return graph - - -def PASOptimal(graph: IRGraph, resource): - """ - Square Linear optimal parallelism (4GPU) - """ - assert resource.ngpus == 4, "only apply to 4 GPU case" - - # replicate data operation - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - - # replicate loss operation - fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] - loss = fnodes[-1] - sub_nodes = graph.replicate(loss, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - - fnodes = fnodes[:-1] - # linear0 config - config0 = [ - None, - dict(idx=1, dim=0, num=4) # col - ] - # linear1 config - config1 = [ - dict(idx=0, dim=1, num=2), # row - dict(idx=1, dim=0, num=2), # col - ] - # linear2 config - config2 = [ - dict(idx=0, dim=0, num=2), # dat - dict(idx=0, dim=1, num=2), # row - ] - # linear3 config - config3 = [ - dict(idx=0, dim=0, num=2), # dat - dict(idx=0, dim=1, num=2), # row - ] - configs = [config0, config1, config2, config3] - assert len(fnodes) == len(configs) - for fnode, config in zip(fnodes, configs): - all_nodes = [fnode] - for conf in config: - if conf is None: - continue - sub_nodes = list() - for node in all_nodes: - algo = node.algorithms('dim') - nodes = graph.partition(node, algo, **conf) - sub_nodes += nodes - all_nodes = sub_nodes - assert len(all_nodes) == 4 - for idx, node in enumerate(all_nodes): - graph.assign(node, idx) - return graph - diff --git a/examples/mlp/train.py b/examples/mlp/train.py index 36cf7d01..b4073ed3 100644 --- a/examples/mlp/train.py +++ b/examples/mlp/train.py @@ -1,54 +1,47 @@ """ -example: - -OMP_NUM_THREADS=4 torchrun \ +PYTHONPATH=.:$PYTHONPATH torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/train.py --policy PASMegatron + examples/mlp/train.py --policy PASMegatronTP """ import torch from torch import nn +from functools import partial import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank +from cube.runtime.utils import create_dummy_dataloader -import examples.mlp.policy.spmd as spmd -import examples.mlp.policy.mpmd as mpmd +import examples.mlp.policy.gallery as gallery +from examples.utils import get_policy import argparse parser = argparse.ArgumentParser(description='MLP example') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') +parser.add_argument('--policy', type=str, help='policy choice, starting with "PAS"') +parser.add_argument('--dim', type=int, default=1024, help='model hidden size') +parser.add_argument('--layers', type=int, default=16, help='number of linear layers') +parser.add_argument('--gbs', type=int, default=64, help='global batch size') +parser.add_argument('--mbs', type=int, default=64, help='micro batch size') +parser.add_argument('--tp-size', type=int, default=2, help='tensor parallelism size only for Megatron policy') args = parser.parse_args() cube.init() -# set up policy -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - +# get policy +policy = get_policy([gallery], args.policy) +policy = partial(policy, nmicros=args.gbs//args.mbs, tp_size=args.tp_size) # =================== Semantic Model Description ==================== class MLP(nn.Module): - def __init__(self, dim, mult=1, nlayers=4): + def __init__(self, dim: int, nlayers: int): super().__init__() self.layers = torch.nn.ModuleList([]) - for lid in range(nlayers): - if lid % 2 == 0: - self.layers.append(nn.Linear(dim, dim * mult, bias=False)) - else: - self.layers.append(nn.Linear(dim * mult, dim, bias=False)) + for _ in range(nlayers): + self.layers.append(nn.Linear(dim, dim, bias=False)) def forward(self, data): x = data @@ -58,64 +51,42 @@ def forward(self, data): return loss -class MLPDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, bs: int, dim: int): - super().__init__(bs, [0]) - self.sample = None - self.dim = dim - self.set_batch_size(bs) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - self.batch_size = batch_size - self.sample = torch.rand( - [batch_size, self.dim], dtype=torch.float32, - device=torch.cuda.current_device() - ) - - def train(): - batch_size = 128 - dim = 4096 - model = MLP(dim=dim) - model = cube.SemanticModel(model) - dataloader = MLPDataLoader(batch_size, dim) + model = MLP(dim=args.dim, nlayers=args.layers) + dataloader = create_dummy_dataloader( + torch.randn(args.dim, device=torch.cuda.current_device()), + args.mbs, + ) - @cube.compile(model, dataloader, PAS=PAS) + # compile a training iteration + @cube.compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) loss.backward() - model = model.get_gen_module() + # load generated model + model = cube.utils.load_model() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) CudaTimer(enable=False).warmup() - if torch.distributed.is_initialized(): - torch.distributed.barrier() - iter_num = 16 - warmup = 4 + dataloader = iter(dataloader) + iter_num, warmup = 5, 2 for step in range(iter_num): - if step >= warmup: + if step == warmup: CudaTimer(enable=True).start('e2e') train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() - if step >= warmup: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: + if (step + 1) % 2 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + CudaTimer().stop('e2e') print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-warmup, field_name='e2e'))) CudaTimer().print_all(times=iter_num-warmup) -train() \ No newline at end of file +if __name__ == '__main__': + train() \ No newline at end of file diff --git a/examples/nlp/__init__.py b/examples/nlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/nlp/blocks/__init__.py b/examples/nlp/blocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index aa90baf9..2a15b28a 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -1,11 +1,6 @@ import torch import cube -@cube.graph.parser.register('* -> *') -@torch.jit.ignore -def func_print_shape(x: torch.Tensor, msg: str): - print(msg, x.size()) - return x @cube.graph.parser.register('L^ N E^, (h+ d^ 3) E^, (h+ d^ 3), E^ (h+ d^) -> L^ N E^', name='self_attention') def self_attention(query: torch.Tensor, @@ -60,70 +55,6 @@ def self_attention(query: torch.Tensor, output = torch.nn.functional.linear(output, out_proj) # L N (h d), E E -> L N E return output -@cube.graph.parser.register('L^ N E^, (h d^ 3) E^, (h d^ 3) -> L^ N (h d^) 3', name='qkv_combined') -def qvk_combined(query: torch.Tensor, - qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, - #out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = False): - num_head = h - L, N = query.size(0), query.size(1) - dim_head = qkv_proj.size(0) // num_head // 3 - - qkv = torch.nn.functional.linear(query, qkv_proj, qkv_bias) # L N E, (h d 3) E -> L N (h d 3) - output = qkv.view(L, N, num_head * dim_head, 3) # L N (h d 3) -> L N (h d) 3 - - return output - -@cube.graph.parser.register('L^ N (h d^) 3 -> L^ N (h d^)', name='attention_mask') -def attention_mask(qkv: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = False): - - L, N = qkv.size(0), qkv.size(1) - num_head = h - dim_head = qkv.size(2) // num_head - - q, k, v = qkv.chunk(3, dim=-1) # L N (h d) 3 -> L N (h d), L N (h d), L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - - # preallocating input tensor: (N h) L L - matmul_input_buffer = torch.empty([N * h, L, L], dtype=q.dtype, device=q.device) - # L (N h) d, L (N h) d -> (N h) L L - attn = torch.baddbmm( - matmul_input_buffer, - q.transpose(0, 1), # (N h) L d - k.transpose(0, 1).transpose(1, 2), # (N h) d L - beta=0.0, alpha=scale - ) - # ======== replace the semantic into more efficient implementation ============ - - # attention mask - if mask: # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - amask = torch.tril(ones) - amask = amask.view(N, 1, L, L) - amask = (amask < 0.5) - attn = attn.masked_fill_(amask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - return output - -@cube.graph.parser.register('L^ N (h+ d^), E^ (h+ d^) -> L^ N E^', name='attention_mask') -def attention_out_linear(lin_input: torch.Tensor, - out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = False): - - output = torch.nn.functional.linear(lin_input, out_proj) # L N (h d), E E -> L N E - return output - @cube.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') def cross_attention(query: torch.Tensor, key: torch.Tensor, @@ -168,101 +99,6 @@ def cross_attention(query: torch.Tensor, key: torch.Tensor, return output -@cube.graph.parser.register('l N E^, L^ N (h+ d), L^ N (h+ d), (h+ d 3) E^, (h+ d 3), E^ (h+ d) -> l N E^, L^ N (h+ d), L^ N (h+ d)', name='one_attention') -def one_attention(hidden_states: torch.Tensor, - past_embed_key: torch.Tensor, - past_embed_value: torch.Tensor, - # q_proj: torch.Tensor, q_bias: torch.Tensor, - # k_proj: torch.Tensor, k_bias: torch.Tensor, - # v_proj: torch.Tensor, v_bias: torch.Tensor, - qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, - out_proj: torch.Tensor, #out_bias: torch.Tensor, - h: int, scale: float, dropout_p: float, is_training: bool = True, mask: bool = False): - num_head = h - l, N = hidden_states.size(0), hidden_states.size(1) - # dim_head = q_proj.size(0) // num_head - dim_head = qkv_proj.size(0) // num_head // 3 - - # q = torch.nn.functional.linear(hidden_states, q_proj, q_bias) # l N E, (h d) E -> l N (h d) - # k = torch.nn.functional.linear(hidden_states, k_proj, k_bias) # l N E, (h d) E -> l N (h d) - # v = torch.nn.functional.linear(hidden_states, v_proj, v_bias) # l N E, (h d) E -> l N (h d) - qkv = torch.nn.functional.linear(hidden_states, qkv_proj, qkv_bias) # l N E, (h d 3) E -> l N (h d) 3 - q, k, v = qkv.chunk(3, dim=-1) - - if past_embed_key is not None and past_embed_value is not None: - k = torch.cat((past_embed_key, k), dim=-3) - v = torch.cat((past_embed_value, v), dim=-3) - - q_N = hidden_states.size(1) - - k_L = k.size(0) - v_L = v.size(0) - - q = q.contiguous().view(l, (N * num_head), dim_head) # l N (h d) -> l (N h) d - k = k.contiguous().view(k_L, (N * num_head), dim_head) # (L+l) N (h d) -> (L+l) (N h) d - v = v.contiguous().view(v_L, (N * num_head), dim_head) # (L+l) N (h d) -> (L+l) (N h) d - - # q = q.transpose(0, 1) # l (N h) d -> (N h) l d - # k = k.transpose(0, 1) # (L+l) (N h) d -> (N h) (L+l) d - # v = v.transpose(0, 1) # (L+l) (N h) d -> (N h) (L+l) d - # q = q * scale # (N h) L d, 1 -> (N h) L d - # k = k.transpose(1, 2) # (N h) (L+l) d -> (N h) d (L+l) - # attn = torch.bmm(q, k) # (N h) l d, (N h) d (L+l) -> (N h) l (L+l) - - # preallocating input tensor: (N h) L L - matmul_input_buffer = torch.empty([N * h, l, k_L], dtype=hidden_states.dtype, device=hidden_states.device) - # L (N h) d, L (N h) d -> (N h) L L - attn = torch.baddbmm( - matmul_input_buffer, - q.transpose(0, 1), # (N h) l d - k.transpose(0, 1).transpose(1, 2), # (N h) d (L+l) - beta=0.0, alpha=scale - ) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) l (L+l) -> (N h) l (L+l) - #no dropout in inference attn - attn = torch.nn.functional.dropout(attn, dropout_p, is_training, False) # (N h) l (L+l) -> (N h) l (L+l) - v_t = v.transpose(0, 1) - output = torch.bmm(attn, v_t) # (N h) l (L+l), (N h) (L+l) d -> (N h) l d - output = output.transpose(0, 1).contiguous() # (N h) l d -> l (N h) d - output = output.view(l, N, num_head * dim_head) # l (N h) d -> l N (h d) - output = torch.nn.functional.linear(output, out_proj, None) # l N (h d), E E -> l N E - return output, k.view(k_L, N, -1), v.view(v_L, N, -1) - -class MultiHeadSelfAttentionFineGrained(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): - super().__init__() - self.inner_dim = inner_dim - self.num_heads = num_heads - self.head_dim = inner_dim // num_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # QKV [(h d 3), E] - self.qkv_proj = torch.nn.Parameter(torch.empty(3 * inner_dim, embed_dim)) - self.qkv_bias = torch.nn.Parameter(torch.empty(3 * inner_dim)) - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) - - def forward(self, query): - - qkv = qvk_combined( - query, self.qkv_proj, self.qkv_bias, - self.num_heads, self.scaling, self.dropout_p, mask=False - ) - lin_input = attention_mask( - qkv, - self.num_heads, self.scaling, self.dropout_p, mask=False - ) - attn = attention_out_linear( - lin_input, - self.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=False - ) - attn = attn + self.out_bias - return attn - class MultiHeadSelfAttention(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): @@ -322,43 +158,3 @@ def forward(self, query: torch.Tensor, key: torch.Tensor): ) attn = attn + self.out_bias return attn - - -class MultiHeadOneAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): - super().__init__() - self.inner_dim = inner_dim - self.num_heads = num_heads - self.head_dim = inner_dim // num_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # # Q - # self.q_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) - # self.q_bias = torch.nn.Parameter(torch.rand(inner_dim)) - # # K - # self.k_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) - # self.k_bias = torch.nn.Parameter(torch.rand(inner_dim)) - # # V - # self.v_proj = torch.nn.Parameter(torch.rand(inner_dim, embed_dim)) - # self.v_bias = torch.nn.Parameter(torch.rand(inner_dim)) - # QKV - self.qkv_proj = torch.nn.Parameter(torch.rand(3 * inner_dim, embed_dim)) - self.qkv_bias = torch.nn.Parameter(torch.rand(3 * inner_dim)) - # Out - self.out_proj = torch.nn.Parameter(torch.rand(embed_dim, inner_dim)) - self.out_bias = torch.nn.Parameter(torch.rand(embed_dim)) - - - def forward(self, query: torch.Tensor, past_embed_key: torch.Tensor, past_embed_value: torch.Tensor): - attn, past_k, past_v = one_attention( - query, past_embed_key, past_embed_value, - # self.q_proj, self.q_bias, - # self.k_proj, self.k_bias, - # self.v_proj, self.v_bias, - self.qkv_proj, self.qkv_bias, - self.out_proj, #self.out_bias, - self.num_heads, self.scaling, self.dropout_p, self.training, mask=True - ) - attn = attn + self.out_bias - return attn, past_k, past_v \ No newline at end of file diff --git a/examples/nlp/blocks/encoder.py b/examples/nlp/blocks/encoder.py deleted file mode 100644 index 4609cf96..00000000 --- a/examples/nlp/blocks/encoder.py +++ /dev/null @@ -1,101 +0,0 @@ -import torch -from examples.nlp.blocks.attention import MultiHeadSelfAttention, MultiHeadOneAttention, func_print_shape, MultiHeadSelfAttentionFineGrained -from examples.nlp.blocks.mlp import MLP - -class EncoderLayerFineGrained(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, - attn_hidden_dim: int, ffn_hidden_dim: int, - dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): - super().__init__() - self.self_attn = MultiHeadSelfAttentionFineGrained( - embed_dim, num_heads, attn_hidden_dim, atten_dropout - ) - self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) - self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - x = self.mlp(x) - x = self.dropout(x) - x = x + residual - return x - -class EncoderLayer(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, - attn_hidden_dim: int, ffn_hidden_dim: int, - dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): - super().__init__() - self.self_attn = MultiHeadSelfAttention( - embed_dim, num_heads, attn_hidden_dim, atten_dropout - ) - self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) - self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x) - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - x = self.mlp(x) - x = self.dropout(x) - x = x + residual - return x - - -class EncoderInferLayer(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, - attn_hidden_dim: int, ffn_hidden_dim: int, seqlen: int = -1, - batch_size: int = 1, - dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0, - moe_size: int = 1): - super().__init__() - self.self_attn_partial = MultiHeadOneAttention( - embed_dim, num_heads, attn_hidden_dim, atten_dropout - ) - self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) if moe_size == 1 else MoEMLP(embed_dim, ffn_hidden_dim, activation_dropout, moe_size) - self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - - # id-embed + pos-embed - tmp_batch_size = batch_size - self.past_embed_key = torch.nn.Parameter(torch.rand(seqlen, tmp_batch_size, embed_dim)) - self.past_embed_value = torch.nn.Parameter(torch.rand(seqlen, tmp_batch_size, embed_dim)) - - # def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor: - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - x = self.self_attn_layer_norm(x) - x, past_k, past_v = self.self_attn_partial(x, self.past_embed_key, self.past_embed_value) - self.past_embed_key = past_k - self.past_embed_value = past_v - x = self.dropout(x) - x = x + residual - - residual = x - x = self.final_layer_norm(x) - x = self.mlp(x) - x = self.dropout(x) - x = x + residual - # func_print_shape(self.past_embed_key, 'past_k: ') - # func_print_shape(self.past_embed_value, 'past_v: ') - return x \ No newline at end of file diff --git a/examples/nlp/blocks/mlp_moe.py b/examples/nlp/blocks/mlp_moe.py deleted file mode 100644 index bc936cbf..00000000 --- a/examples/nlp/blocks/mlp_moe.py +++ /dev/null @@ -1,212 +0,0 @@ -import torch -import cube -import torch.distributed -from typing import Tuple - -# (N L) emb; emb * exp_num -> (N L), 1(part_idx) -@cube.graph.parser.register('*, * -> *') -@torch.jit.ignore -def gating_func(x, gate_w) -> torch.Tensor: - # assert top_k == 1 - affinities = torch.matmul(x, gate_w) - # print(f'affinities = {affinities}') - dst_pid_list = torch.argmax(affinities, -1) - # print(f'dst_pid_list = {dst_pid_list}') - return dst_pid_list - -# split tokens into groups by target expert -@cube.graph.parser.register('* -> *') -@torch.jit.ignore -def split_tokens_by_eid(tokens, eids, expert_num): - print(f"tokens = {tokens}, shape {tokens.size()}") - print(f"eids = {eids}, shape {eids.size()}") - reshape_needed = list(tokens.size()) != list(eids.size()) - reshape_feat_dim = list(tokens.size())[-1] - print("##### reshape_feat_dim = " + str(reshape_feat_dim)) - if reshape_needed: - vid_part_extend = torch.unsqueeze(eids, 2).repeat(1, 1, reshape_feat_dim) - print("vid_part_extend = " + str(vid_part_extend)) - else: - vid_part_extend = eids - - token_lists = [] - for exp_id in range(0, expert_num): - print("exp_id = " + str(exp_id)) - mask = (vid_part_extend == exp_id) - print("mask = " + str(mask)) - parted_tokens = torch.masked_select(tokens, mask) - if reshape_needed: - parted_tokens = parted_tokens.reshape(-1, reshape_feat_dim) - print("parted_tokens = " + str(parted_tokens)) - token_lists.append(parted_tokens) - return token_lists - - -@cube.graph.parser.register('* -> *') -@torch.jit.ignore -def samesize_all_gather(tensor: torch.Tensor): - tensor_list = [torch.zeros_like(tensor) for _ in - range(torch.distributed.get_world_size())] - torch.distributed.all_gather(tensor_list, tensor) - return torch.stack(tensor_list) - - -@cube.graph.parser.register('* -> *') -@torch.jit.ignore -def nonvarsize_gather(tensor: torch.Tensor, dst): - tensor_list = [torch.zeros_like(tensor) for _ in - range(torch.distributed.get_world_size())] if torch.distributed.get_rank() == dst else None - torch.distributed.gather(tensor, tensor_list, dst) - - return torch.cat(tensor_list) if torch.distributed.get_rank() == dst else None - - -@cube.graph.parser.register('* -> *') -@torch.jit.ignore -def varsize_tensor_gather(tensor: torch.Tensor, dst): - tensor = tensor.contiguous() - # cuda_device = f'cuda:{torch.distributed.get_rank()}' - print(f'tensor.get_device() = {tensor.get_device()}') - size_tens = torch.tensor([tensor.shape[0]], dtype=tensor.dtype, device=f'cuda:{tensor.get_device()}') - print(f'size_tens.get_device() = {size_tens.get_device()}') - size_tens = samesize_all_gather(size_tens) - print(f"size_tens = {size_tens}, tensor.shape[1:] = {tensor.shape[1:]}") - - max_size = size_tens.max().int().item() - padded = torch.empty(max_size, *tensor.shape[1:], dtype=tensor.dtype, device=f'cuda:{tensor.get_device()}') - padded[:tensor.shape[0]] = tensor - - ga = nonvarsize_gather(padded, dst) - print(f" tensor = {tensor}; padded = {padded}; ga = {ga}") - - if torch.distributed.get_rank() != dst: # not this rank as dst - return [] - - slices = [] - for i, sz in enumerate(size_tens): - start_idx = i * max_size - end_idx = start_idx + sz.int().item() - print("start_idx = " + str(start_idx)) - print("end_idx = " + str(end_idx)) - - if end_idx > start_idx: - print("ga[start_idx:end_idx] = " + str(ga[start_idx:end_idx])) - slices.append(ga[start_idx:end_idx]) - # print("slices = " + str(slices)) - else: - slices.append(torch.empty((0, *tensor.shape[1:]), dtype=tensor.dtype, device=f'cuda:{tensor.get_device()}')) - # slices.append(torch.tensor([], dtype=tensor.dtype).resize(0, 3)) - return slices - - -@cube.graph.parser.register('* -> *') -@torch.jit.ignore -def all_to_all_token(input_list): - print(f'***** all_to_all_token.input_list = {input_list}') - data_type = input_list[0].dtype - print(data_type) - ret = [] - for i in range(len(input_list)): - gather_list = varsize_tensor_gather(input_list[i], i) # new replacement - if i == torch.distributed.get_rank(): #TODO check local_rank - ret = gather_list - print(f'***** all_to_all_token.output_list = {ret}') - return ret - - -# N * 1, N * emb -> M * 1, M * emd -@cube.graph.parser.register('*, * -> *') -@torch.jit.ignore -def send_to_experts(dst_pid_list, x, expert_num: int) -> Tuple[torch.Tensor]: - # send to remote and recv from remote - token_lists = split_tokens_by_eid(x, dst_pid_list, expert_num) - print(f'### token_lists = {token_lists}') - local_token_lists = all_to_all_token(token_lists) # exchange idx - print(f'### local_token_lists = {local_token_lists}') - return local_token_lists - - -# M * 1, M * emd -> N * 1, N * emb -@cube.graph.parser.register('*, * -> *') -@torch.jit.ignore -def recv_from_experts(dst_pid_list: torch.Tensor, new_local_token_lists: torch.Tensor, expert_num: int) -> Tuple[torch.Tensor]: - local_token_lists = all_to_all_token(new_local_token_lists) - print(f'### [return] local_token_lists = {local_token_lists}') - - vid_part_np = dst_pid_list.detach().flatten().cpu().tolist() #TODO vid_part_np = dst_pid_list.detach().cpu().numpy() - print("vid_part_np = " + str(vid_part_np)) - # part_count = {} - # for i in range(expert_num): - # part_count[i] = 0 - part_count = [0 for i in range(expert_num)] - print(f'part_count = {part_count}') - - embed_list = [] - for i in range(len(vid_part_np)): - pid = vid_part_np[i] - print(f'pid = {pid}') - offset = part_count[pid] - part_count[pid] += 1 - embed_list.append(local_token_lists[pid][offset]) - - # print("### embed_list = " + str(embed_list)) - embed = torch.stack(embed_list) - print("### final embed = " + str(embed)) - - return embed - - -# @cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward_moe') -@cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+, E^ K-> L^ N E^', name='feedforward_moe') -@torch.jit.ignore -def feedforward_moe(x: torch.Tensor, - proj1: torch.Tensor, proj1_bias: torch.Tensor, - proj2: torch.Tensor, - gate_w: torch.Tensor, - dropout: float, - is_training: bool = True, - expert_num: int = 1) -> torch.Tensor: - #gating - dst_pid_list = gating_func(x, gate_w) - #shuffle tokens - # src_pid_list, x_local - local_token_lists = send_to_experts(dst_pid_list, x, expert_num) - - new_local_token_lists = [] - for x_local in local_token_lists: - #local expert - with torch.no_grad(): - print(f'#### checking ####', x_local, proj1, proj1_bias) - x_local = torch.nn.functional.linear(x_local, proj1, proj1_bias) - x_local = torch.nn.functional.gelu(x_local) - #TODO FIXME x_local = torch.nn.functional.dropout(x_local, dropout, is_training, False) - x_local = torch.nn.functional.linear(x_local, proj2, None) - new_local_token_lists.append(x_local) - - #shuffle back tokens - print(f'### new_local_token_lists = {new_local_token_lists}') - x = recv_from_experts(dst_pid_list, new_local_token_lists, expert_num) - return x - - -class MoEMLP(torch.nn.Module): - def __init__(self, embed_dim: int, hidden_dim: int, dropout: float, expert_num: int = 1): - super().__init__() - # self.proj1 = torch.nn.Parameter(torch.ones((hidden_dim // expert_num, embed_dim))) # TODO fix me empty - # self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim // expert_num,))) - # self.proj2 = torch.nn.Parameter(torch.ones((embed_dim, hidden_dim // expert_num))) # TODO fix me empty - # self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) - # self.gate_w = torch.nn.Parameter(torch.rand((embed_dim, expert_num))) - self.proj1 = torch.nn.Parameter(torch.rand((hidden_dim, embed_dim))) # TODO fix me empty - self.proj1_bias = torch.nn.Parameter(torch.rand((hidden_dim,))) - self.proj2 = torch.nn.Parameter(torch.rand((embed_dim, hidden_dim))) # TODO fix me empty - self.proj2_bias = torch.nn.Parameter(torch.rand((embed_dim,))) - self.gate_w = torch.nn.Parameter(torch.rand((embed_dim, expert_num))) - self.dropout = dropout - self.expert_num = expert_num - - def forward(self, x: torch.Tensor): - x = feedforward_moe(x, self.proj1, self.proj1_bias, - self.proj2, self.gate_w, self.dropout, self.training, self.expert_num) - x = x + self.proj2_bias - return x \ No newline at end of file diff --git a/examples/nlp/blocks/decoder.py b/examples/nlp/blocks/transformer.py similarity index 50% rename from examples/nlp/blocks/decoder.py rename to examples/nlp/blocks/transformer.py index ea4b57de..069a370e 100644 --- a/examples/nlp/blocks/decoder.py +++ b/examples/nlp/blocks/transformer.py @@ -1,48 +1,53 @@ import torch -from examples.nlp.blocks.attention import MultiHeadCrossAttention, MultiHeadSelfAttention +from examples.nlp.blocks.attention import MultiHeadSelfAttention +from examples.nlp.blocks.attention import MultiHeadCrossAttention from examples.nlp.blocks.mlp import MLP -class DecoderLayer(torch.nn.Module): +class TransformerLayer(torch.nn.Module): def __init__(self, embed_dim: int, num_heads: int, attn_hidden_dim: int, ffn_hidden_dim: int, - dropout: float = 0.0, atten_dropout: float = 0.0, activation_dropout: float = 0.0): + dropout: float = 0.2, atten_dropout: float = 0.2, activation_dropout: float = 0.2, + use_cross_attention: bool = False): super().__init__() self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) self.self_attn = MultiHeadSelfAttention( embed_dim, num_heads, attn_hidden_dim, atten_dropout ) - self.cross_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.cross_attn = MultiHeadCrossAttention( - embed_dim, num_heads, attn_hidden_dim, atten_dropout - ) + self.use_cross_attention = use_cross_attention + if use_cross_attention: + self.cross_attn_layer_norm = torch.nn.LayerNorm(embed_dim) + self.cross_attn = MultiHeadCrossAttention( + embed_dim, num_heads, attn_hidden_dim, atten_dropout + ) self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, encoder_output = None) -> torch.Tensor: + # self attention residual = x x = self.self_attn_layer_norm(x) x = self.self_attn(x) - x = self.dropout(x) x = x + residual - residual = x - x = self.cross_attn_layer_norm(x) - x = self.cross_attn(x, encoder_output) - - x = self.dropout(x) - x = x + residual + # cross attention + if self.use_cross_attention: + residual = x + x = self.cross_attn_layer_norm(x) + x = self.cross_attn(x, encoder_output) + x = self.dropout(x) + x = x + residual + # mlp residual = x x = self.final_layer_norm(x) x = self.mlp(x) - x = self.dropout(x) x = x + residual + return x diff --git a/examples/nlp/gpt/__init__.py b/examples/nlp/gpt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/nlp/gpt/infer.py b/examples/nlp/gpt/infer.py deleted file mode 100644 index c6e8b646..00000000 --- a/examples/nlp/gpt/infer.py +++ /dev/null @@ -1,122 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/nlp/gpt/infer.py --policy PASMeshShard --fp16 - -PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 examples/nlp/gpt/infer.py --policy PASSingle --fp16 - -PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 examples/nlp/gpt/infer.py --policy PASMegatronInferTP --fp16 - -PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 examples/nlp/gpt/infer.py --policy PASDP --fp16 --moe_size 2 -""" - - -import torch - -from examples.nlp.gpt.model import GPTInfer, GPTInferDataLoader -from examples.nlp.gpt.model import GPTDataLoader -from examples.nlp.gpt.model import build_gpt_config - -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary, model_summary - -from examples.nlp.gpt.policy.mpmd import PASMegatron as PAS -import examples.nlp.gpt.policy.spmd as spmd -import examples.nlp.gpt.policy.mpmd as mpmd - -import argparse - -parser = argparse.ArgumentParser(description='GPT Inference') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') -parser.add_argument('--fp16', action='store_true', default=False, - help='use fp16 for the inference') -parser.add_argument('--local_rank', type=int, default=0) -parser.add_argument('--moe_size', type=int, default=1, - help='number of experts, use MoE for the inference if moe_size > 1') -args = parser.parse_args() - -cube.init() - -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -policies = [policy for policy in policies if policy.startswith('PAS')] -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - -def inter(): - print(f'torch.cuda.is_available() = {torch.cuda.is_available()}') - - batch_size = 8 - - cfg = build_gpt_config('toy') - cfg.moe_size = args.moe_size - model = GPTInfer(batch_size=batch_size, cfg=cfg) - model = model if not args.fp16 else model.half() - # model = model.cuda() #only for PyTorch run - model.eval() - dataloader = GPTInferDataLoader(batch_size) - - ################## SuperScaler run - model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=PAS, override=True) - def train_iter(model, dataloader): - input_ids, position_ids = next(dataloader) - loss = model(input_ids, position_ids) - return loss - model = model.get_gen_module() - - torch.distributed.barrier() - print_each_rank('model weight consumption:', rank_only=0) - memory_summary() - - CudaTimer(enable=False).warmup() - iter_num = 4 - warmup = 2 - for step in range(iter_num): - # if step == 0: - # model_summary(model, next(dataloader)) - - if step >= warmup: - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - if step >= warmup: - CudaTimer().stop('e2e') - - if step == 0: - print_each_rank('passed first iteration') - if (step + 1) % 10 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num - warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num - warmup) - memory_summary() - - # iter_num = 2 - # for step in range(iter_num): - # output = train_iter(model, dataloader) - # print(f'output = {output}') - - ################## PyTorch run - # output = None - # for i in range(10): - # input_ids, position_ids = next(dataloader) - # print(f'input_ids = {input_ids} [{input_ids.size()}], position_ids = {position_ids} [{position_ids.size()}]') - # output = model(input_ids, position_ids) - # print(f'output = {output}') - - -if __name__ == '__main__': - - cube.init() - inter() \ No newline at end of file diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 13cbcd56..8e5e9dc9 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -1,103 +1,68 @@ import torch +from dataclasses import dataclass -from examples.nlp.blocks.encoder import EncoderLayer, EncoderLayerFineGrained, EncoderInferLayer import cube -from dataclasses import dataclass +from cube.runtime.utils import create_dummy_dataloader + +from examples.nlp.blocks.transformer import TransformerLayer + @dataclass class Config: - embed_dim: int = 1024 + hidden: int = 1024 layers: int = 8 - attention_heads: int = 16 - attn_hidden_dim: int = 1024 + heads: int = 16 ffn_hidden_dim: int = 4096 num_embeddings: int = 51200 seqlen: int = 1024 dropout: float = 0.2 attn_dropout: float = 0.2 activation_dropout: float = 0.2 - moe_size: int = 1 def build_gpt_config(name: str) -> Config: if name == 'toy': - embed_dim, layers, attention_heads = 32, 4, 16 + hidden, layers, heads = 1024, 4, 16 elif name == '350M': - embed_dim, layers, attention_heads = 1024, 24, 16 + hidden, layers, heads = 1024, 24, 16 elif name == '760M': - embed_dim, layers, attention_heads = 1536, 24, 16 + hidden, layers, heads = 1536, 24, 16 elif name == '1.3B': - embed_dim, layers, attention_heads = 2048, 24, 32 + hidden, layers, heads = 2048, 24, 32 elif name == '2.6B': - embed_dim, layers, attention_heads = 2560, 32, 32 + hidden, layers, heads = 2560, 32, 32 elif name == '6.7B': - embed_dim, layers, attention_heads = 4096, 32, 32 + hidden, layers, heads = 4096, 32, 32 elif name == '15B': - embed_dim, layers, attention_heads = 5120, 48, 40 + hidden, layers, heads = 5120, 48, 40 elif name == '39B': - embed_dim, layers, attention_heads = 8192, 48, 64 + hidden, layers, heads = 8192, 48, 64 elif name == '175B': - embed_dim, layers, attention_heads = 12288, 96, 96 + hidden, layers, heads = 12288, 96, 96 else: assert False, f'unrecognized name: {name}' - return Config(embed_dim, layers, attention_heads, embed_dim, 4 * embed_dim) + return Config(hidden, layers, heads, hidden, 4 * hidden) -class GPTFineGrained(torch.nn.Module): - - def __init__(self, cfg=Config()): - super().__init__() - - self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.embed_dim)) - self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) - self.embed_dropout = torch.nn.Dropout() - - self.layers = torch.nn.ModuleList( - [EncoderLayerFineGrained( - cfg.embed_dim, cfg.attention_heads, - cfg.attn_hidden_dim, cfg.ffn_hidden_dim, - cfg.dropout, cfg.attn_dropout, cfg.activation_dropout - ) for _ in range(cfg.layers)] - ) - self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) - - def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): - embed = torch.nn.functional.embedding( - input_ids, self.embedw, padding_idx=None, - max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False - ) - pos_embed = self.position(position_ids) - embed = embed + pos_embed - embed = self.embed_dropout(embed) - enc = embed.transpose(0, 1) - - for layer in self.layers: - cube.runtime.function.anchor('transformer start') - enc = layer(enc) - enc = self.final_layernorm(enc) - - logits = torch.nn.functional.linear(enc, self.embedw) - # simplified - loss = torch.sum(logits) - return loss class GPT(torch.nn.Module): - def __init__(self, cfg=Config()): + def __init__(self, cfg: Config): super().__init__() - # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) - self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.embed_dim)) - self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) + # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.hidden) + self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.hidden)) + self.position = torch.nn.Embedding(cfg.seqlen, cfg.hidden) self.embed_dropout = torch.nn.Dropout() self.layers = torch.nn.ModuleList( - [EncoderLayer( - cfg.embed_dim, cfg.attention_heads, - cfg.attn_hidden_dim, cfg.ffn_hidden_dim, - cfg.dropout, cfg.attn_dropout, cfg.activation_dropout + [TransformerLayer( + cfg.hidden, cfg.heads, + cfg.hidden, cfg.ffn_hidden_dim, + cfg.dropout, cfg.attn_dropout, cfg.activation_dropout, + use_cross_attention=False, ) for _ in range(cfg.layers)] ) - self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) + self.final_layernorm = torch.nn.LayerNorm(cfg.hidden) def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): @@ -123,118 +88,18 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): return loss -class GPTInfer(torch.nn.Module): - - def __init__(self, batch_size: int = 1, cfg: Config = Config()): - super().__init__() - # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.embed_dim) - self.embedw = torch.nn.Parameter(torch.rand(cfg.num_embeddings, cfg.embed_dim)) - self.position = torch.nn.Embedding(cfg.seqlen, cfg.embed_dim) - self.embed_dropout = torch.nn.Dropout() +def get_gpt_dummy_dataloader(batch_size: int, cfg: Config): - if cfg.moe_size == 1: - self.layers = torch.nn.ModuleList( - [EncoderInferLayer( - cfg.embed_dim, cfg.attention_heads, - cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.seqlen, - batch_size, - cfg.dropout, cfg.attn_dropout, cfg.activation_dropout - ) for _ in range(cfg.layers)] - ) - else: - assert cfg.moe_size > 1 - self.layers = torch.nn.ModuleList() - for layer_id in range(cfg.layers): - self.layers.append( - EncoderInferLayer( - cfg.embed_dim, cfg.attention_heads, - cfg.attn_hidden_dim, cfg.ffn_hidden_dim, cfg.seqlen, - batch_size, - cfg.dropout, cfg.attn_dropout, cfg.activation_dropout, - 1 if (layer_id % 2) == 0 else cfg.moe_size - ) - ) - - self.final_layernorm = torch.nn.LayerNorm(cfg.embed_dim) + input_ids = torch.randint( + 0, cfg.num_embeddings, + size=(cfg.seqlen,), + dtype=torch.int64, + device=torch.cuda.current_device() + ) + position_ids = torch.arange( + 0, cfg.seqlen, dtype=torch.int64, + device=torch.cuda.current_device() + ).view(cfg.seqlen,) - - def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): - - # embed = self.embed(input_ids) - cube.runtime.function.anchor('first_embed') - embed = torch.nn.functional.embedding( - input_ids, self.embedw, padding_idx=None, - max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False - ) - pos_embed = self.position(position_ids) - embed = embed + pos_embed - embed = self.embed_dropout(embed) - enc = embed.transpose(0, 1) - - for layer in self.layers: - cube.runtime.function.anchor('transformer start') - enc = layer(enc) - enc = self.final_layernorm(enc) - - # logits = torch.nn.functional.linear(enc, self.embed.weight) - cube.runtime.function.anchor('last_embed') - logits = torch.nn.functional.linear(enc, self.embedw) - # simplified - # loss = torch.sum(logits) - return logits - - -class GPTDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, bs: int, cfg: Config = None): - self.cfg = Config() if cfg is None else cfg - super().__init__(bs, [0, 0]) - self.sample = None - self.set_batch_size(bs) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - self.batch_size = batch_size - input_ids = torch.randint( - 0, self.cfg.num_embeddings, - size=(self.batch_size, self.cfg.seqlen), - dtype=torch.int64, device=torch.cuda.current_device() - ) - position_ids = torch.arange( - 0, self.cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() - ).repeat(self.batch_size).view(self.batch_size, -1) - self.sample = (input_ids, position_ids) - - -class GPTInferDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, bs: int, cfg: Config = None): - self.cfg = Config() if cfg is None else cfg - super().__init__(bs, [0, 0]) - self.sample = None - self.set_batch_size(bs) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - self.batch_size = batch_size - input_ids = torch.randint( - 0, self.cfg.num_embeddings, - size=(self.batch_size, 1), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - position_ids = torch.arange( - 0, 1, dtype=torch.int64, - device=torch.cuda.current_device() - ).repeat(self.batch_size).view(self.batch_size, -1) - self.sample = (input_ids, position_ids) \ No newline at end of file + return create_dummy_dataloader( + (input_ids, position_ids), batch_size) diff --git a/examples/nlp/gpt/policy/__init__.py b/examples/nlp/gpt/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index 6c67b2bf..279645a8 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -1,87 +1,20 @@ -from typing import List, Tuple -import numpy as np - +"""GPT policy gallery for MPMD Parallelism""" from cube.graph import IRGraph from cube.graph.segment import IRSegment -from cube.graph.function.anchor import IRGraphAnchor -from cube.ir.cten import IRCell from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.schedule.sched1f1b import IRSchedule1F1B -from cube.graph.schedule.schedinfer import IRScheduleInfer - - -def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: - """ - Create hybrid (nested) groups given the each group number. - - The product of group_num should be same with total devices. - - e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: - ( - ( (0,1,2), (3,4,5) ), - ( (0,3), (2,5), (3,6) ), - ) - """ - group_num = np.array(group_num) - cnt = np.prod(group_num) - assert cnt == ngpus, 'total device not match' - grid = np.arange(cnt).reshape(tuple(group_num)) - dims = list(range(len(group_num))) - outputs = [] - for dim, num in enumerate(group_num): - remain = ngpus // num - order = tuple(dims[:dim] + dims[dim+1:] + [dim]) - grid_dim = np.transpose(grid, order).reshape((remain,num)) - grid_dim = grid_dim.tolist() - outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) - assert len(outputs) == len(group_num) - return tuple(outputs) - - -def _group_to_transformers(fnodes) -> List[List[IRCell]]: - # group to transformer layers - transformers: List[List[IRFwOperation]] = [] - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - fnodes[idx+1].comment = f'===> start of transformer layer {lid}' - start = idx if lid != 0 else 0 - end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) - transformers.append(fnodes[start:end]) - for lid in range(len(transformers) - 1): - if transformers[lid][-1].name == 'multiref': - node = transformers[lid].pop() - transformers[lid+1].insert(0, node) - return transformers - -# ========================= parallelisms ================================= - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes +from cube.graph.schedule.predefined import PredefinedSched -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes +from examples.utils import create_mesh, tensor_parallelism, replica, group_to_layers -# ========================= parallelisms ================================= -def PASRoundRobin(graph: IRGraph, resource): +def PASRoundRobin(graph: IRGraph, resource, **kwargs): """ roundrobin scheduling """ fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] # group to transformer layers - transformers = _group_to_transformers(fnodes) + transformers = group_to_layers(fnodes) for lid, transformer in enumerate(transformers): stage_id = lid % resource.ngpus @@ -96,16 +29,15 @@ def PASRoundRobin(graph: IRGraph, resource): return graph -def PAS1F1B(graph: IRGraph, resource): - """ - 1F1B scheduling - """ +def PAS1F1B(graph: IRGraph, resource, nmicros: int = 16, **kwargs): + """1F1B schedule""" num_stages = resource.ngpus - num_microbatch = 16 + num_microbatch = nmicros # group to transformer layers fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - transformers = _group_to_transformers(fnodes) + transformers = group_to_layers(fnodes) + assert len(transformers) >= num_stages # staging fstages = [[] for _ in range(num_stages)] @@ -125,57 +57,23 @@ def PAS1F1B(graph: IRGraph, resource): if isinstance(node, IRDataOperation): graph.assign(node, 0) - strategy = IRSchedule1F1B(graph, num_microbatch) - graph.predef_sched(strategy) + if graph.train(): + PredefinedSched.sched_1f1b(graph, num_microbatch, num_stages) + else: + PredefinedSched.sched_infer_pipe(graph, num_microbatch, num_stages) return graph -def PAS1F(graph: IRGraph, resource): - """ - 1F1B scheduling - """ - num_stages = resource.ngpus - num_microbatch = 16 - - # group to transformer layers - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - transformers = _group_to_transformers(fnodes) - - # staging - fstages = [[] for _ in range(num_stages)] - nlayer_per_stage = (len(transformers) // resource.ngpus) - for lid, fnodes in enumerate(transformers): - stage_id = min(lid // nlayer_per_stage, num_stages - 1) - fstages[stage_id] += fnodes - graph.staging(tuple(stages[0] for stages in fstages)) - - # stage to device - fsegments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] - assert len(fsegments) == num_stages - for devid, segment in enumerate(fsegments): - graph.assign(segment, devid) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - graph.assign(node, 0) - - strategy = IRScheduleInfer(graph, num_microbatch) - graph.predef_sched(strategy) - return graph - - -def PASMegatron(graph: IRGraph, resource): - """ - 1F1B scheduling - """ - dp_size = 1 - tp_size = 2 +def PASMegatron(graph: IRGraph, resource, + tp_size: int = 2, dp_size: int = 1, + nmicros: int = 16, **kwargs ): + """Megatron policy for hybrid data-tensor-pipeline parallelism""" pp_size = resource.ngpus // (dp_size * tp_size) - num_microbatch = 16 + num_microbatch = nmicros # device mesh dp_groups, pp_groups, tp_groups = \ - _create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) + create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) print(f'dp groups: {dp_groups}') print(f'pp groups: {pp_groups}') print(f'tp groups: {tp_groups}') @@ -184,7 +82,7 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] # group to transformer layers - transformers = _group_to_transformers(graph.select(ntype=IRFwOperation)) + transformers = group_to_layers(graph.select(ntype=IRFwOperation)) # group to stage: set each stage operators fstages = [[] for _ in range(pp_size)] @@ -198,11 +96,11 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: bs = dataloader.output(0).shape[0] # partition dataloader - dls = _replica(graph, dataloader, [0]*dp_size) # graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) + dls = replica(graph, dataloader, [0]*dp_size) # graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) for dp_idx, dl in enumerate(dls): # only stage 0 needs dataloader devices = [get_device(dp_idx, 0, tp_idx) for tp_idx in range(tp_size)] - _replica(graph, dl, devices) + replica(graph, dl, devices) fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] assert len(fstages) > 0 @@ -210,105 +108,19 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: for fnode in fstage.nodes(): if len(fnode.inputs()) == 0: continue # anchor if fnode.name == 'self_attention' or fnode.name == 'feedforward': - fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + fnodes = tensor_parallelism(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) elif fnode.name == 'embedding': - fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + fnodes = tensor_parallelism(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) elif fnode.name == 'linear': # the last embeding linear - fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + fnodes = tensor_parallelism(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) elif fnode.name == 'sum': - fnodes = _tp(graph, fnode, [0]*tp_size, idx=0, dim=2, num=tp_size) + fnodes = tensor_parallelism(graph, fnode, [0]*tp_size, idx=0, dim=2, num=tp_size) else: - fnodes = _replica(graph, fnode, [0]*tp_size) + fnodes = replica(graph, fnode, [0]*tp_size) # data parallel for tp_idx, fnode in enumerate(fnodes): dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] batch_dim = fnode.input(0).shape.index(bs) - _tp(graph, fnode, devs=dp_devices, idx=0, dim=batch_dim, num=dp_size) - - strategy = IRSchedule1F1B(graph, num_microbatch) - graph.predef_sched(strategy) - # print(graph.extra_repr()) - return graph - - -def PASPiperSpace(graph: IRGraph, resource): - - # ================= Policy hyper parameter =================== - num_microbatch = 16 - # num_stages = 3 - # sub_meshes = [(1,4), (1,4), (1,8)] # (dp_size, tp_size) - # stage_layers = [(0,6), (6,12), (12,24)] # (start, end) - assert resource.ngpus == 8 - num_stages = 3 - sub_meshes = [(1,2), (1,2), (1,4)] # (dp_size, tp_size) - stage_layers = [(0,6), (6,12), (12,24)] # (start, end) - # ============================================================ - - # checking - transformers = _group_to_transformers(graph.select(ntype=IRFwOperation)) - assert len(stage_layers) == num_stages, f"Expect {num_stages} pipeline stages but got {len(stage_layers)} stage layer assignment." - nlayers = 0 - for sid, (start, end) in enumerate(stage_layers): - prev_end = stage_layers[sid-1][1] if sid > 0 else 0 - assert start == prev_end, f"Layers are not contiguous" - nlayers += end-start - assert nlayers == len(transformers), f"Total layer number {nlayers} != model layers {len(transformers)}" - # check gpus allocation - device_allocation = [] - devices = 0 - for mesh in sub_meshes: - dp_size, tp_size = mesh - assert dp_size >= 1 and tp_size >= 1 - stage_ngpus = dp_size * tp_size - device_allocation.append(stage_ngpus) - devices += stage_ngpus - assert devices <= resource.ngpus, f"Total GPUs in policy ({devices}) > resource capacity ({resource.ngpus})" - - # pipeline staging - fstages = [[] for _ in range(num_stages)] - for sid, (start, end) in enumerate(stage_layers): - for lid in range(start, end): - fstages[sid] += transformers[lid] - graph.staging(tuple(stages[0] for stages in fstages)) - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - - # setup data loader - dataloader = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[0] - - - # sub mesh of (dp_size, tp_size) - for sid, fstage in enumerate(fstages): - dp_size, tp_size = sub_meshes[sid] - devices = np.arange(dp_size * tp_size, dtype=int) + sum(device_allocation[:sid]) - devices = devices.reshape((dp_size, tp_size)) - # setup dataloader - if sid == 0: - dls = _replica(graph, dataloader, [0] * tp_size) - for tp_idx, dl in enumerate(dls): - dp_devices = list(int(devid) for devid in devices[:,tp_idx].flatten()) - dp_dls = graph.partition(dl, dl.algorithms('data'), num=dp_size) - for devid, dp_dl in zip(dp_devices, dp_dls): - graph.assign(dp_dl, devid) - for fnode in fstage.nodes(): - if len(fnode.inputs()) == 0: continue # anchor - # tensor parallel -- FIXME: current restriction needs replica happen before partition - if fnode.name == 'self_attention' or fnode.name == 'feedforward': - fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) - elif fnode.name == 'embedding': - fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) - elif fnode.name == 'linear': # the last embeding linear - fnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) - elif fnode.name == 'sum': - fnodes = _tp(graph, fnode, [0]*tp_size, idx=0, dim=2, num=tp_size) - else: - fnodes = _replica(graph, fnode, [0]*tp_size) - # data parallel - for tp_idx, fnode in enumerate(fnodes): - dp_devices = list(int(devid) for devid in devices[:,tp_idx].flatten()) - batch_dim = fnode.input(0).shape.index(bs) - _tp(graph, fnode, devs=dp_devices, idx=0, dim=batch_dim, num=dp_size) - - strategy = IRSchedule1F1B(graph, num_microbatch) - graph.predef_sched(strategy) + tensor_parallelism(graph, fnode, idx=0, dim=batch_dim, num=dp_size, devs=dp_devices) + PredefinedSched.sched_1f1b(graph, num_microbatch, pp_size) return graph diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index f513f55c..f84f8cc0 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -1,32 +1,16 @@ +"""GPT policy gallery for MPMD Parallelism""" + from typing import List from cube.graph import IRGraph -from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.pyfunc import IRPyFunc from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from examples.utils import tensor_parallelism, replica -# ========================= parallelisms ================================= - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], - idx: int, dim: int, tag='dim'): - algo = node.algorithms(tag) - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes # coshard -def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, +def coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, idx: int, dim: int): algo = node.algorithms('dim') sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) @@ -39,274 +23,71 @@ def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int return sub_nodes -# ========================= parallelisms ================================= - - -def PASSingle(graph: IRGraph, resource): +def PASSingle(graph: IRGraph, resource, **kwargs): + """Single-device execution""" assert resource.ngpus == 1 - # print(graph.extra_repr()) for node in graph.nodes(): if not isinstance(node, IRBpOperation): graph.assign(node, 0) return graph -def PASDP(graph: IRGraph, resource): - dp_size = resource.ngpus - dp_devs = list(range(dp_size)) - +def PASDP(graph: IRGraph, resource, **kwargs): + """Data parallelism""" + devs = list(range(resource.ngpus)) dataloader = graph.select(ntype=IRDataOperation)[0] bs = dataloader.output(0).shape[0] - - # partition dataloader - dls = graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) - for devid, dl in enumerate(dls): - graph.assign(dl, devid) - + # replicate dataloader + replica(graph, dataloader, devs) # partition forward operators for node in graph.select(ntype=IRFwOperation): if isinstance(node, IRPyFunc): graph.assign(node, 0) continue if len(node.inputs()) == 0: continue - #FIXME: a workaround to find batch dimension batch_dim = node.input(0).shape.index(bs) - _tp(graph, node, dp_devs, idx=0, dim=batch_dim) - + tensor_parallelism(graph, node, idx=0, dim=batch_dim, devs=devs) return graph -def PASMegatronTP(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # annotating code structure -- not consider multiref on embedding weight - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - fnodes[idx+1].comment = f'===> start of transformer layer {lid}' - +def PASMegatronTP(graph: IRGraph, resource, **kwargs): + """Megatron-way tensor parallelism""" + devs = list(range(resource.ngpus)) # attention - attns = [node for node in fnodes if node.name == 'self_attention'] - for attn in attns: - _tp(graph, attn, tp_devs, idx=1, dim=0) - + for attn in graph.select(name='self_attention'): + tensor_parallelism(graph, attn, idx=1, dim=0, devs=devs) # feedforward - ffns = [node for node in fnodes if node.name == 'feedforward'] - for ffn in ffns: - _tp(graph, ffn, tp_devs, idx=1, dim=0) - + for ffn in graph.select(name='feedforward'): + tensor_parallelism(graph, ffn, idx=1, dim=0, devs=devs) # partition embed - embeds = [node for node in fnodes if node.name == 'embedding'] - for embed in embeds: - _tp(graph, embed, tp_devs, idx=1, dim=0) - + for embed in graph.select(name='embedding'): + tensor_parallelism(graph, embed, idx=1, dim=0, devs=devs) # partition last linear - linears = [node for node in fnodes if node.name == 'linear'] - _tp(graph, linears[-1], tp_devs, idx=1, dim=0) - + linears = graph.select(name='linear') + tensor_parallelism(graph, linears[-1], idx=1, dim=0, devs=devs) # partition loss - sums = [node for node in fnodes if node.name == 'sum'] - assert len(sums) == 1 - _tp(graph, sums[0], tp_devs, idx=0, dim=2) - - # partition add - # adds = [node for node in fnodes if node.name == 'add'] - # for add in adds: - # # subnodes = _replica(graph, add, [0] * 2) - # # for idx, sub_node in enumerate(subnodes): - # # _tp(graph, sub_node, [0,1] if idx == 0 else [2,3], idx=0, dim=1) - # # _tp(graph, add, tp_devs, idx=0, dim=1) - # subnodes = _tp(graph, add, [0] * 2, idx=0, dim=1) - # for idx, sub_node in enumerate(subnodes): - # _replica(graph, sub_node, [0,1] if idx == 0 else [2,3]) - # - # # partition layernorm - # lns = [node for node in fnodes if node.name == 'layernorm'] - # assert len(lns) > 0 - # for ln in lns: - # # _tp(graph, ln, tp_devs, idx=0, dim=1) - # # subnodes = _replica(graph, ln, [0] * 2) - # # for idx, sub_node in enumerate(subnodes): - # # _tp(graph, sub_node, [0,1] if idx == 0 else [2,3], idx=0, dim=1) - # subnodes = _tp(graph, ln, [0] * 2, idx=0, dim=1) - # for idx, sub_node in enumerate(subnodes): - # _replica(graph, sub_node, [0,1] if idx == 0 else [2,3]) - - - # replicate other nodes - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: - _replica(graph, node, tp_devs) - + sums = graph.select(name='sum') + tensor_parallelism(graph, sums[0], idx=0, dim=2, devs=devs) + # replica other nodes + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + replica(graph, node, devs) return graph -def PASMegatronInferTP(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # annotating code structure -- not consider multiref on embedding weight - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - # why -1: multiref - fnodes[idx - 1].comment = f'===> start of transformer layer {lid}' - +def PASMeshShard(graph: IRGraph, resource, **kwargs): + """Coshard policy for long sequence""" + devs = list(range(resource.ngpus)) # attention - attns = [node for node in fnodes if node.name == 'one_attention'] - for attn in attns: - _tp(graph, attn, tp_devs, idx=3, dim=0) - + for attn in graph.select(name='self_attention'): + # tensor_parallelism(graph, attn, idx=1, dim=0, devs) + coshard(graph, attn, devs, colocate=2, idx=1, dim=0) # feedforward - ffns = [node for node in fnodes if node.name == 'feedforward'] - for ffn in ffns: - _tp(graph, ffn, tp_devs, idx=1, dim=0) - - # func_print_shape - prts = [node for node in fnodes if node.name == 'func_print_shape'] - for prt in prts: - _tp(graph, prt, tp_devs, idx=0, dim=2) - - # first embedding linear - first_emb_anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor) and node.name == 'first_embed'] - print(f'last_emd_anchors = {first_emb_anchors}') - indices = [fnodes.index(anchor) for anchor in first_emb_anchors] - for lid, idx in enumerate(indices): - print(f'fnodes[idx+1].name = {fnodes[idx+1].name}') - print(f'fnodes[idx+1] = {fnodes[idx + 1]}') - first_emb_node = fnodes[idx+1] - _tp(graph, first_emb_node, tp_devs, idx=1, dim=0) - - # last embedding linear - last_emb_anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor) and node.name == 'last_embed'] - print(f'last_emd_anchors = {last_emb_anchors}') - indices = [fnodes.index(anchor) for anchor in last_emb_anchors] - for lid, idx in enumerate(indices): - print(f'fnodes[idx+1].name = {fnodes[idx+1].name}') - print(f'fnodes[idx+1] = {fnodes[idx + 1]}') - last_emb_node = fnodes[idx+1] - _tp(graph, last_emb_node, tp_devs, idx=1, dim=0) - - # replicate other nodes - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: - _replica(graph, node, tp_devs) - - return graph - - -def PASMeshShard(graph: IRGraph, resource): - - # print(graph.extra_repr()) - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # annotating code structure -- not consider multiref on embedding weight - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - # why -1: multiref - fnodes[idx-1].comment = f'===> start of transformer layer {lid}' - - # attention - attns = [node for node in fnodes if node.name == 'self_attention'] - for attn in attns: - # _tp(graph, attn, tp_devs, idx=1, dim=0) - _coshard(graph, attn, tp_devs, colocate=2, idx=1, dim=0) - - # feedforward - ffns = [node for node in fnodes if node.name == 'feedforward'] - for ffn in ffns: - # _tp(graph, ffn, tp_devs, idx=1, dim=0) - _coshard(graph, ffn, tp_devs, colocate=4, idx=1, dim=0) - - # replicate other nodes - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: - _replica(graph, node, tp_devs) - - # print(graph.extra_repr()) + for ffn in graph.select(name='feedforward'): + # tensor_parallelism(graph, ffn, idx=1, dim=0, devs) + coshard(graph, ffn, devs, colocate=4, idx=1, dim=0) + # replica other nodes + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + replica(graph, node, devs) return graph - -def PASMegatronWSRTP(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # annotating code structure -- not consider multiref on embedding weight - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - # why -1: multiref - fnodes[idx-1].comment = f'===> start of transformer layer {lid}' - - qkvs = [node for node in fnodes if node.name == 'qkv_combined'] - for qkv in qkvs: - _tp(graph, qkv, tp_devs, idx=1, dim=0) - - # implement selective recompute - attns = [node for node in fnodes if node.name == 'attention_mask'] - graph.recompute(attns) - for attn in attns: - _tp(graph, attn, tp_devs, idx=0, dim=2) - - lins = [node for node in fnodes if node.name == 'lin'] - for lin in lins: - _tp(graph, lin, tp_devs, idx=1, dim=0) - - # feedforward - ffns = [node for node in fnodes if node.name == 'feedforward'] - for ffn in ffns: - _tp(graph, ffn, tp_devs, idx=1, dim=0) - - # partition embed - embeds = [node for node in fnodes if node.name == 'embedding'] - for embed in embeds: - _tp(graph, embed, tp_devs, idx=1, dim=0) - - # partition last linear - linears = [node for node in fnodes if node.name == 'linear'] - _tp(graph, linears[-1], tp_devs, idx=1, dim=0) - - # partition loss - sums = [node for node in fnodes if node.name == 'sum'] - assert len(sums) == 1 - _tp(graph, sums[0], tp_devs, idx=0, dim=2) - - # tp - def GenerateNodesForSP(nodes): - output=[] - count = 0 - for node in nodes: - if isinstance(node, (IRFwOperation)) and not isinstance(node, (IRGraphAnchor)): - sign = node.signature.split('.')[-1] - cid = node.cid - if len(output) == 0: - if sign == 'layer_norm': - output.append(node) - elif sign == 'dropout': - count = 0 - output.append(node) - count += 1 - elif sign == 'add' and count == 1: - output.append(node) - count += 1 - elif sign == 'layer_norm' and count == 2: - output.append(node) - elif sign == 'add': - output.append(node) - return output - - for node in GenerateNodesForSP(graph.nodes()): - _tp(graph, node, tp_devs, idx=0, dim=0) - - # replicate other nodes - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: - _replica(graph, node, tp_devs) - - return graph \ No newline at end of file diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index a6ee360f..ca45d83f 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -1,69 +1,89 @@ """ example: -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=1 \ - examples/nlp/gpt/train.py --policy PASMegatron --fp16 +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/nlp/gpt/train.py --policy PASMegatronTP --fp16 """ import torch -import time +import logging +from functools import partial -from examples.nlp.gpt.model import GPT, GPTFineGrained, build_gpt_config -from examples.nlp.gpt.model import GPTDataLoader +from model import GPT, Config +from model import get_gpt_dummy_dataloader import cube from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary, model_summary +from cube.profiler.memory import memory_summary -from examples.nlp.gpt.policy.mpmd import PASMegatron as PAS import examples.nlp.gpt.policy.spmd as spmd import examples.nlp.gpt.policy.mpmd as mpmd +from examples.utils import get_policy + import argparse parser = argparse.ArgumentParser(description='GPT Train') + parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') +parser.add_argument('--mbs', type=int, default=8, + help='micro-batch size') +parser.add_argument('--gbs', type=int, default=8, + help='global batch size') +parser.add_argument('--dp', type=int, default=1, + help='data parallel size, only for megatron') +parser.add_argument('--tp', type=int, default=1, + help='tensor parallel size, only for megatron') + +# arch +parser.add_argument('--layers', type=int, default=4, + help='number of transformer layers') +parser.add_argument('--hidden', type=int, default=1024, + help='hidden size') +parser.add_argument('--heads', type=int, default=16, + help='number of attention heads') +parser.add_argument('--seqlen', type=int, default=1024, + help='sequence length') args = parser.parse_args() -cube.init() -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -policies = [policy for policy in policies if policy.startswith('PAS')] -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") +cube.init() +cube.set_logger_level(logging.WARN) +logging.getLogger('cube.compiler').setLevel(logging.INFO) +# get policy +policy = get_policy([spmd, mpmd], args.policy) +policy = partial(policy, + nmicros=args.gbs//args.mbs, + dp_size=args.dp, + tp_size=args.tp +) def train(): - batch_size = 8 - Config=build_gpt_config('760M') - if args.policy == 'PASMegatronWSRTP': - model = GPTFineGrained(Config) - else: - model = GPT(Config) + config = Config( + hidden=args.hidden, + layers=args.layers, + heads=args.heads, + ffn_hidden_dim=4*args.hidden, + num_embeddings=51200, + seqlen=args.seqlen, + ) + model = GPT(config) model = model if not args.fp16 else model.half() - dataloader = GPTDataLoader(batch_size) + dataloader = get_gpt_dummy_dataloader(args.mbs, Config) - model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=PAS, override=True, load_content=True) + @cube.compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) loss.backward() - model = model.get_gen_module() + model = cube.utils.load_model() optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) @@ -71,7 +91,8 @@ def train_iter(model, dataloader): print_each_rank('model weight consumpition:', rank_only=0) memory_summary() - CudaTimer(enable=False).warmup() + CudaTimer().warmup() + dataloader = iter(dataloader) iter_num, warmup = 5, 2 for step in range(iter_num): if step == warmup: diff --git a/examples/nlp/mbart/__init__.py b/examples/nlp/mbart/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/nlp/mbart/model.py b/examples/nlp/mbart/model.py index 302724a5..141084fb 100644 --- a/examples/nlp/mbart/model.py +++ b/examples/nlp/mbart/model.py @@ -1,83 +1,37 @@ import torch import math +from dataclasses import dataclass -from examples.nlp.blocks.encoder import EncoderLayer -from examples.nlp.blocks.decoder import DecoderLayer +from examples.nlp.blocks.transformer import TransformerLayer import cube +from cube.runtime.utils import create_dummy_dataloader -@cube.graph.parser.register('* -> *, *', name='multi2ref') -def multi2ref(tensor: torch.Tensor): - return tensor, tensor - - +@dataclass class Config: - TBD = None # to be decided - # source and target - num_embeddings = 2500 - hidden = 1024 - heads = 16 - layers = 4 - seqlen = 2048 - - max_source_positions = None - max_target_positions = None - - encoder_embed_dim = TBD - encoder_ffn_embed_dim = TBD - encoder_layers = TBD - encoder_attention_heads = TBD - - decoder_embed_dim = TBD - decoder_ffn_embed_dim = TBD - decoder_layers = TBD - decoder_attention_heads = TBD - - attention_dropout = TBD - dropout = TBD - activation_dropout = TBD - - pad_token_id = TBD - eos_token_id = TBD - - # classification task - num_classes = TBD - - def __init__(self) -> None: + hidden: int = 1024 + heads: int = 16 + layers: int = 4 # for encoder and decoder layers separately + seqlen: int = 2048 + ffn_hidden_dim: int = 4096 + vocab: int = 2500 - Config.max_source_positions = Config.seqlen - Config.max_target_positions = Config.seqlen + attention_dropout: float = 0.2 + dropout: float = 0.2 + activation_dropout: float = 0.2 - Config.encoder_embed_dim = Config.hidden - Config.encoder_ffn_embed_dim = 4 * Config.hidden - Config.encoder_layers = Config.layers - Config.encoder_attention_heads = Config.heads - - Config.decoder_embed_dim = Config.hidden - Config.decoder_ffn_embed_dim = 4 * Config.hidden - Config.decoder_layers = Config.layers - Config.decoder_attention_heads = Config.heads - - Config.attention_dropout = 0.1 - Config.dropout = 0.1 - Config.activation_dropout = 0.1 - - Config.pad_token_id = 1 - Config.eos_token_id = 2 - - Config.num_classes = 3 - - def __repr__(self) -> str: - return f'Config(num_embeddings={Config.num_embeddings}, hidden={Config.hidden}, heads={Config.heads}, layers={Config.layers}, seqlen={Config.seqlen})' + pad_token_id: int = 1 + eos_token_id: int = 1 + num_classes: int = 3 class PositionalEmbedding(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int): + def __init__(self, vocab: int, embedding_dim: int): self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) + super().__init__(vocab + self.offset, embedding_dim) def forward(self, seq_len: int): positions = torch.arange( @@ -122,57 +76,49 @@ def forward(self, dec: torch.Tensor): class MBartForSentenceClassification(torch.nn.Module): - def __init__(self, batch_size: int): + def __init__(self, batch_size: int, cfg: Config): super().__init__() - cfg = Config() - self.vocab_size = cfg.num_embeddings - print("Model Arch:", cfg) + self.vocab_size = cfg.vocab # embedding self.vocab = torch.nn.Parameter(torch.empty( - cfg.num_embeddings, cfg.encoder_embed_dim)) + cfg.vocab, cfg.hidden)) # encoder embedding self.embed_offset = 2 self.encoder_position = torch.nn.Parameter(torch.empty( - cfg.max_source_positions, cfg.encoder_embed_dim)) - self.embed_scale_encoder = math.sqrt(cfg.encoder_embed_dim) - self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + cfg.seqlen, cfg.hidden)) + self.embed_scale_encoder = math.sqrt(cfg.hidden) + self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.hidden) # encoder layers self.encoders = torch.nn.ModuleList( - [EncoderLayer( - cfg.encoder_embed_dim, cfg.encoder_attention_heads, - cfg.encoder_embed_dim, cfg.encoder_ffn_embed_dim, - cfg.dropout, cfg.attention_dropout, cfg.activation_dropout - ) for _ in range(cfg.decoder_layers)] + [TransformerLayer( + cfg.hidden, cfg.heads, + cfg.hidden, cfg.ffn_hidden_dim, + cfg.dropout, cfg.attention_dropout, cfg.activation_dropout, + use_cross_attention=False, + ) for _ in range(cfg.layers)] ) - self.layer_norm_encoder = torch.nn.LayerNorm(cfg.encoder_embed_dim) + self.layer_norm_encoder = torch.nn.LayerNorm(cfg.hidden) # decoder embedding self.decoder_position = torch.nn.Parameter(torch.empty( - cfg.max_target_positions, cfg.decoder_embed_dim)) - self.embed_scale_decoder = math.sqrt(cfg.decoder_embed_dim) - self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) + cfg.seqlen, cfg.hidden)) + self.embed_scale_decoder = math.sqrt(cfg.hidden) + self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.hidden) # decoder layers self.decoders = torch.nn.ModuleList( - [DecoderLayer( - cfg.decoder_embed_dim, cfg.decoder_attention_heads, - cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim, - cfg.dropout, cfg.attention_dropout, cfg.activation_dropout - ) for _ in range(cfg.decoder_layers)] + [TransformerLayer( + cfg.hidden, cfg.heads, + cfg.hidden, cfg.ffn_hidden_dim, + cfg.dropout, cfg.attention_dropout, cfg.activation_dropout, + use_cross_attention=True, + ) for _ in range(cfg.layers)] ) - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.decoder_embed_dim) - self.head = MBartClassificationHead(cfg.decoder_embed_dim, 1024, cfg.num_classes, 0.0) + self.layer_norm_decoder = torch.nn.LayerNorm(cfg.hidden) + self.head = MBartClassificationHead(cfg.hidden, 1024, cfg.num_classes, 0.0) - # FIXME: cube now is not safe for multiple - # tensor transmissions between stages. - decoder_input_ids = torch.randint( - 0, self.vocab_size, (batch_size, cfg.seqlen), dtype=torch.int64, device=torch.device('cpu'), - ) - self.register_buffer('decoder_input_ids', decoder_input_ids) - - - def forward(self, input_ids: torch.Tensor): + def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor): """ The forward is only for benchmark performance, the original input of input_ids, decoder_input_ids and labels are @@ -180,7 +126,6 @@ def forward(self, input_ids: torch.Tensor): The loss computation is also simplified by using sum. """ - # decoder_input_ids = torch.clone(input_ids) # encoder embedding cube.runtime.function.anchor('encoder embedding') enc_emb = torch.nn.functional.embedding(input_ids, self.vocab) @@ -198,105 +143,42 @@ def forward(self, input_ids: torch.Tensor): # decoder embedding cube.runtime.function.anchor('decoder embedding') - dec_emb = torch.nn.functional.embedding(self.decoder_input_ids, self.vocab) + dec_emb = torch.nn.functional.embedding(decoder_input_ids, self.vocab) dec_emb = dec_emb * self.embed_scale_decoder dec_emb = dec_emb + self.decoder_position dec_emb = self.layernorm_embedding_decoder(dec_emb) dec_emb = torch.nn.functional.dropout(dec_emb, p=0.1) dec = dec_emb.transpose(0, 1) - # FIXME: need to cat and chunk because cube now is not safe - # for multiple tensor transformation between stages. - encdec = torch.cat((enc, dec), dim=-1) - # decoder layers for layer in self.decoders: - cube.runtime.function.anchor('decoder layer') - enc, dec = torch.chunk(encdec, 2, dim=-1) - - enc, next_enc = multi2ref(enc) - + cube.runtime.function.anchor('decoder layer') dec = layer(dec, enc) - encdec = torch.cat((next_enc, dec), dim=-1) - enc, dec = torch.chunk(encdec, 2, dim=-1) dec = self.layer_norm_decoder(dec) dec = dec.transpose(0, 1) # head - # loss = self.head(dec, labels) loss = self.head(dec) return loss -class MBartSyntheticDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, bs: int, cfg: Config = None): - self.cfg = Config() if cfg is None else cfg - super().__init__(bs, [0, 0]) - self.sample = None - self.set_batch_size(bs) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - self.batch_size = batch_size - input_ids = torch.randint( - 0, self.cfg.num_embeddings, - size=(self.bs, self.cfg.max_source_positions), - dtype=torch.int64, device=torch.cuda.current_device() - ) - self.sample = input_ids - - -class MBartDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int): - - self.bs = batch_size - self.cfg = Config() - super().__init__( - shapes=([batch_size, self.cfg.max_source_positions,], - [batch_size, self.cfg.max_target_positions], - [batch_size] - ), - dtypes=(torch.int64, torch.int64, torch.int64), - batch_dims=(0, 0, 0) - ) - self.samples = [self.random_sample()] - - def random_sample(self): - input_ids = torch.randint( - 0, self.cfg.num_embeddings, - size=(self.bs, self.cfg.max_source_positions), - dtype=torch.int64, device=torch.cuda.current_device() - ) - decoder_input_ids = MBartDataLoader.shift_tokens_right(input_ids, self.cfg.pad_token_id) - labels = torch.randint( - 0, self.cfg.num_classes, - size=(self.bs,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - return (input_ids, decoder_input_ids, labels) - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] - - @staticmethod - def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): - prev_output_tokens = input_ids.clone() - prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) - index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) - decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() - prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() - prev_output_tokens[:, 0] = decoder_start_tokens - return prev_output_tokens - +def get_mbart_dummy_dataloader(batch_size: int, config: Config): + + input_ids = torch.randint( + 0, config.vocab, + size=(config.seqlen,), + dtype=torch.int64, device=torch.cuda.current_device() + ) + decoder_input_ids = torch.randint( + 0, config.vocab, + size=(config.seqlen,), + dtype=torch.int64, device=torch.cuda.current_device() + ) + labels = torch.randint( + 0, config.num_classes, + size=(), # scalar + dtype=torch.int64, + device=torch.cuda.current_device() + ) + return create_dummy_dataloader((input_ids, decoder_input_ids,), batch_size) diff --git a/examples/nlp/mbart/policy/__init__.py b/examples/nlp/mbart/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/nlp/mbart/policy/gallery.py b/examples/nlp/mbart/policy/gallery.py new file mode 100644 index 00000000..ab45bd8a --- /dev/null +++ b/examples/nlp/mbart/policy/gallery.py @@ -0,0 +1,186 @@ +from typing import List + +from cube.graph import IRGraph +from cube.ir.operator import IRFwOperation, IRDataOperation +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.schedule.predefined import PredefinedSched +from cube.graph.segment import IRSegment +from cube.ir.cten import IRCell + +from examples.utils import create_mesh, tensor_parallelism, replica + + +def _group_to_blocks(fnodes) -> List[List[IRCell]]: + """ + Grouping to [ + [Encoder Embed], + [Encoder Layer], [Encoder Layer], ..., + [Decoder Embed], + [Decoder Layer], [Decoder Layer], ... + ] + """ + blocks = [] + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + # encoder embedding + fnodes[indices[0] + 1].comment = f'==> start of encoder embedding' + assert anchors[0].name == 'encoder embedding' + blocks.append(fnodes[0:indices[1]]) + indices.pop(0) + anchors.pop(0) + # encoder layers + lid = 0 + while anchors[0].name == 'encoder layer': + start, end = indices[0], indices[1] + fnodes[start + 1].comment = f'==> start of encoder layer {lid}' + blocks.append(fnodes[start:end]) + indices.pop(0) + anchors.pop(0) + lid += 1 + # decoder embedding + assert anchors[0].name == 'decoder embedding' + blocks.append(fnodes[indices[0]:indices[1]]) + indices.pop(0) + anchors.pop(0) + # decoder layers + lid = 0 + while len(indices) != 0: + assert anchors[0].name == 'decoder layer' + start, end = indices[0], indices[1] if len(indices) > 1 else len(fnodes) + fnodes[start + 1].comment = f'==> start of decoder layer {lid}' + blocks.append(fnodes[indices[0]:end]) + indices.pop(0) + anchors.pop(0) + lid += 1 + return blocks + + + +def PASSingle(graph: IRGraph, resource, **kwargs): + assert resource.ngpus == 1 + _ = _group_to_blocks(graph.select(ntype=IRFwOperation)) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + graph.assign(node, 0) + return graph + + +def PAS1F1B(graph: IRGraph, resource, nmicros: int = 16, **kwargs): + + num_stages = resource.ngpus + recompute: bool = True + + blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) + enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] + dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] + if recompute: + for block in blocks: + graph.recompute(block) + + # staging + fstages = [[] for _ in range(num_stages)] + nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // num_stages + for lid, fnodes in enumerate(enc_layers + dec_layers): + if lid == 0: + fstages[0] += enc_emb + elif lid == len(enc_layers): + fstages[num_stages // 2] += dec_emb + stage_id = min(lid // nlayer_per_stage, num_stages - 1) + fstages[stage_id] += fnodes + graph.staging(tuple(stage[0] for stage in fstages)) + + dataloader = graph.select(ntype=IRDataOperation)[0] + replica(graph, dataloader, [0, num_stages // 2]) + + fsegments = [seg for seg in graph.select(ntype=IRSegment, flatten=False) if seg.isfw()] + assert len(fsegments) == num_stages, f"Not match: {len(fsegments)} != {num_stages}" + for devid, segment in enumerate(fsegments): + graph.assign(segment, devid) + + strategy = PredefinedSched(graph, nmicros, num_stages) + graph.predef_sched(strategy) + + return graph + + +def PASMegatronTP(graph: IRGraph, resource, **kwargs): + """Megatron-way tensor parallelism""" + devs = list(range(resource.ngpus)) + for node in graph.select(ntype=(IRDataOperation, IRFwOperation)): + if node.name == 'embedding': + tensor_parallelism(graph, node, idx=1, dim=0, devs=devs) + elif node.name == 'self_attention' or node.name == 'feedforward': + tensor_parallelism(graph, node, idx=1, dim=0, devs=devs) + elif node.name == 'cross_attention': + tensor_parallelism(graph, node, idx=2, dim=0, devs=devs) + else: + replica(graph, node, devs) + return graph + + +def PASMegatron(graph: IRGraph, resource, + tp_size: int = 2, dp_size: int = 1, + nmicros: int = 16, **kwargs): + """Megatron policy for hybrid data-tensor-pipeline parallelism""" + dp_size = 2 + tp_size = 2 + pp_size = resource.ngpus // (dp_size * tp_size) + recompute: bool = True + num_microbatch = nmicros + + # device mesh + dp_groups, pp_groups, tp_groups = \ + create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) + print(f'dp groups: {dp_groups}') + print(f'pp groups: {pp_groups}') + print(f'tp groups: {tp_groups}') + + def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: + return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] + + blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) + enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] + dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] + if recompute: + for block in blocks: + graph.recompute(block) + + # pipelien stage + fstages = [[] for _ in range(pp_size)] + nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // pp_size + for lid, fnodes in enumerate(enc_layers + dec_layers): + if lid == 0: + fstages[0] += enc_emb + elif lid == len(enc_layers): + fstages[pp_size // 2] += dec_emb + stage_id = min(lid // nlayer_per_stage, pp_size - 1) + fstages[stage_id] += fnodes + graph.staging(tuple(stage[0] for stage in fstages)) + + # partition dataloader + dataloader = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[0] + replica(graph, dataloader, list(range(resource.ngpus))) + + # tp-dp partition + fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] + assert len(fstages) == pp_size + for pp_idx, fstage in enumerate(fstages): + for node in fstage.nodes(): + if len(node.inputs()) == 0: continue # anchor + if node.name == 'embedding': + nodes = tensor_parallelism(graph, node, idx=1, dim=0, devs=[0]*tp_size) + elif node.name == 'self_attention' or node.name == 'feedforward': + nodes = tensor_parallelism(graph, node, idx=1, dim=0, devs=[0]*tp_size) + elif node.name == 'cross_attention': + nodes = tensor_parallelism(graph, node, idx=2, dim=0, devs=[0]*tp_size) + else: + nodes = replica(graph, node, [0]*tp_size) + # data parallel + for tp_idx, node in enumerate(nodes): + dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] + batch_dim = node.input(0).shape.index(bs) + tensor_parallelism(graph, node, dp_devices, idx=0, dim=batch_dim) + + strategy = PredefinedSched.sched_1f1b(graph, num_microbatch) + graph.predef_sched(strategy) + return graph diff --git a/examples/nlp/mbart/policy/mpmd.py b/examples/nlp/mbart/policy/mpmd.py deleted file mode 100644 index 48299a8b..00000000 --- a/examples/nlp/mbart/policy/mpmd.py +++ /dev/null @@ -1,312 +0,0 @@ -from typing import List, Tuple -import numpy as np - -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.segment import IRSegment -from cube.ir.cten import IRCell -from cube.graph.schedule.sched1f1b import IRSchedule1F1B -from cube.graph.schedule.schedmix import IRScheduleMix - - -def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: - """ - Create hybrid (nested) groups given the each group number. - - The product of group_num should be same with total devices. - - e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: - ( - ( (0,1,2), (3,4,5) ), - ( (0,3), (2,5), (3,6) ), - ) - """ - group_num = np.array(group_num) - cnt = np.prod(group_num) - assert cnt == ngpus, 'total device not match' - grid = np.arange(cnt).reshape(tuple(group_num)) - dims = list(range(len(group_num))) - outputs = [] - for dim, num in enumerate(group_num): - remain = ngpus // num - order = tuple(dims[:dim] + dims[dim+1:] + [dim]) - grid_dim = np.transpose(grid, order).reshape((remain,num)) - grid_dim = grid_dim.tolist() - outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) - assert len(outputs) == len(group_num) - return tuple(outputs) - - -def _group_to_blocks(fnodes) -> List[List[IRCell]]: - """ - Grouping to [ - [Encoder Embed], - [Encoder Layer], [Encoder Layer], ..., - [Decoder Embed], - [Decoder Layer], [Decoder Layer], ... - ] - """ - blocks = [] - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - # encoder embedding - fnodes[indices[0] + 1].comment = f'==> start of encoder embedding' - assert anchors[0].name == 'encoder embedding' - blocks.append(fnodes[0:indices[1]]) - indices.pop(0) - anchors.pop(0) - # encoder layers - lid = 0 - while anchors[0].name == 'encoder layer': - start, end = indices[0], indices[1] - fnodes[start + 1].comment = f'==> start of encoder layer {lid}' - blocks.append(fnodes[start:end]) - indices.pop(0) - anchors.pop(0) - lid += 1 - # decoder embedding - assert anchors[0].name == 'decoder embedding' - blocks.append(fnodes[indices[0]:indices[1]]) - indices.pop(0) - anchors.pop(0) - # decoder layers - lid = 0 - while len(indices) != 0: - assert anchors[0].name == 'decoder layer' - start, end = indices[0], indices[1] if len(indices) > 1 else len(fnodes) - fnodes[start + 1].comment = f'==> start of decoder layer {lid}' - blocks.append(fnodes[indices[0]:end]) - indices.pop(0) - anchors.pop(0) - lid += 1 - return blocks - - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): - if len(devs) == 1: - graph.assign(node, devs[0]) - sub_nodes = [node] - else: - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - if len(devs) == 1: - graph.assign(node, devs[0]) - sub_nodes = [node] - else: - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def PASSingle(graph: IRGraph, resource): - assert resource.ngpus == 1 - _ = _group_to_blocks(graph.select(ntype=IRFwOperation)) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - -def PAS1F1B(graph: IRGraph, resource): - - num_stages = resource.ngpus - num_microbatch = 4 - recompute: bool = True - - blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) - enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] - dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] - if recompute: - for block in blocks: - graph.recompute(block) - - # staging - fstages = [[] for _ in range(num_stages)] - nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // num_stages - for lid, fnodes in enumerate(enc_layers + dec_layers): - if lid == 0: - fstages[0] += enc_emb - elif lid == len(enc_layers): - fstages[num_stages // 2] += dec_emb - stage_id = min(lid // nlayer_per_stage, num_stages - 1) - fstages[stage_id] += fnodes - graph.staging(tuple(stage[0] for stage in fstages)) - - dataloader = graph.select(ntype=IRDataOperation)[0] - _replica(graph, dataloader, [0, num_stages // 2]) - - fsegments = [seg for seg in graph.select(ntype=IRSegment, flatten=False) if seg.isfw()] - assert len(fsegments) == num_stages, f"Not match: {len(fsegments)} != {num_stages}" - for devid, segment in enumerate(fsegments): - graph.assign(segment, devid) - - strategy = IRSchedule1F1B(graph, num_microbatch) - graph.predef_sched(strategy) - - return graph - - -def PASMegatronTP(graph: IRGraph, resource): - - tp_size = resource.ngpus - recompute: bool = True - devs = list(range(tp_size)) - - blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) - if recompute: - for block in blocks: - graph.recompute(block) - - for node in graph.select(ntype=IRFwOperation): - if node.name == 'embedding': - _tp(graph, node, devs, idx=1, dim=0) - elif node.name == 'self_attention' or node.name == 'feedforward': - _tp(graph, node, devs, idx=1, dim=0) - elif node.name == 'cross_attention': - _tp(graph, node, devs, idx=2, dim=0) - else: - _replica(graph, node, devs) - - dataloader = graph.select(ntype=IRDataOperation)[0] - _replica(graph, dataloader, devs) - - return graph - - -def PASMegatron(graph: IRGraph, resource): - - dp_size = 2 - tp_size = 2 - pp_size = resource.ngpus // (dp_size * tp_size) - recompute: bool = True - num_microbatch = 16 - - # device mesh - dp_groups, pp_groups, tp_groups = \ - _create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) - print(f'dp groups: {dp_groups}') - print(f'pp groups: {pp_groups}') - print(f'tp groups: {tp_groups}') - - def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: - return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] - - blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) - enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] - dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] - if recompute: - for block in blocks: - graph.recompute(block) - - # pipelien stage - fstages = [[] for _ in range(pp_size)] - nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // pp_size - for lid, fnodes in enumerate(enc_layers + dec_layers): - if lid == 0: - fstages[0] += enc_emb - elif lid == len(enc_layers): - fstages[pp_size // 2] += dec_emb - stage_id = min(lid // nlayer_per_stage, pp_size - 1) - fstages[stage_id] += fnodes - graph.staging(tuple(stage[0] for stage in fstages)) - - # partition dataloader - dataloader = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[0] - dls = _replica(graph, dataloader, [0]*dp_size) # graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) - for dp_idx, dl in enumerate(dls): - devices = [get_device(dp_idx, 0, tp_idx) for tp_idx in range(tp_size)] - _replica(graph, dl, devices) - - # tp-dp partition - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - assert len(fstages) == pp_size - for pp_idx, fstage in enumerate(fstages): - for node in fstage.nodes(): - if len(node.inputs()) == 0: continue # anchor - if node.name == 'embedding': - nodes = _tp(graph, node, [0]*tp_size, idx=1, dim=0) - elif node.name == 'self_attention' or node.name == 'feedforward': - nodes = _tp(graph, node, [0]*tp_size, idx=1, dim=0) - elif node.name == 'cross_attention': - nodes = _tp(graph, node, [0]*tp_size, idx=2, dim=0) - else: - nodes = _replica(graph, node, [0]*tp_size) - # data parallel - for tp_idx, node in enumerate(nodes): - dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] - batch_dim = node.input(0).shape.index(bs) - _tp(graph, node, dp_devices, idx=0, dim=batch_dim) - - strategy = IRSchedule1F1B(graph, num_microbatch) - graph.predef_sched(strategy) - return graph - - -def PASMixPipe(graph: IRGraph, resource): - - tp_size = 2 - pp_size = resource.ngpus // tp_size - - blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) - enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] - dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] - - num_microbatch = 4 - - # pipelien stage - embed_sid = [0, pp_size // 2 + 1] - fstages = [[] for _ in range(pp_size)] - nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // pp_size - for lid, fnodes in enumerate(enc_layers + dec_layers): - stage_id = min(lid // nlayer_per_stage, pp_size - 1) - fstages[stage_id] += fnodes - fstages.insert(embed_sid[0], enc_emb) - fstages.insert(embed_sid[1], dec_emb) - graph.staging(tuple(stage[0] for stage in fstages)) - - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - assert len(fstages) == pp_size + 2 - - # fully shard enmbedding - enc_emb, dec_emb = fstages[embed_sid[0]], fstages[embed_sid[1]] - tp_device = list(range(resource.ngpus)) - for node in enc_emb.nodes() + dec_emb.nodes(): - # skip anchor nodes - if isinstance(node, IRGraphAnchor): continue - # shard embedding layer to all devices - if node.name == 'embedding': - _tp(graph, node, tp_device, idx=1, dim=0) - else: - _replica(graph, node, tp_device) - - dataloader = graph.select(ntype=IRDataOperation)[0] - _replica(graph, dataloader, tp_device) - - # pipeline stage to devices - pipe_stages = [stage for sid, stage in enumerate(fstages) if sid not in embed_sid] - assert len(pipe_stages) == pp_size - for sid, stage in enumerate(pipe_stages): - tp_devs = [idx for idx in range(tp_size * sid, tp_size * sid + tp_size)] - for node in stage.nodes(): - if len(node.inputs()) == 0: continue # anchor - if node.name == 'self_attention' or node.name == 'feedforward': - _tp(graph, node, tp_devs, idx=1, dim=0) - elif node.name == 'cross_attention': - _tp(graph, node, tp_devs, idx=2, dim=0) - else: - _replica(graph, node, tp_devs) - - strategy = IRScheduleMix(graph, num_microbatch) - graph.predef_sched(strategy) - - return graph diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index 04ee0e3c..f6b49deb 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -1,33 +1,39 @@ """ example: -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=8 \ - --nnodes=1 \ - examples/nlp/mbart/train.py --policy PASMixPipe +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/nlp/mbart/train.py --policy PASMegatronTP --fp16 """ import torch +import logging +import argparse +import math +from functools import partial from examples.nlp.mbart.model import MBartForSentenceClassification, Config -from examples.nlp.mbart.model import MBartSyntheticDataLoader +from examples.nlp.mbart.model import get_mbart_dummy_dataloader import cube from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary, model_summary -import examples.nlp.mbart.policy.mpmd as mpmd +from cube.profiler.memory import memory_summary +import examples.nlp.mbart.policy.gallery as gallery -import argparse -import math +from examples.utils import get_policy parser = argparse.ArgumentParser(description='GPT Train') parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') +parser.add_argument('--dp', type=int, default=1, + help='data parallel size, only for megatron') +parser.add_argument('--tp', type=int, default=1, + help='tensor parallel size, only for megatron') # training -parser.add_argument('--gbs', type=int, default=1, help='global batch size') -parser.add_argument('--mbs', type=int, default=2, help='micro batch size') +parser.add_argument('--gbs', type=int, default=4, help='global batch size') +parser.add_argument('--mbs', type=int, default=4, help='micro batch size') # arch parser.add_argument('--vocab', type=int, default=2500, help='used vocabulary size') @@ -45,14 +51,18 @@ cube.init() print(args) -PAS = None -policies = list(mpmd.__dict__.keys()) -policies = [policy for policy in policies if policy.startswith('PAS')] -if args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") + +cube.init() +cube.set_logger_level(logging.WARN) +logging.getLogger('cube.compiler').setLevel(logging.INFO) + +# get policy +policy = get_policy([gallery], args.policy) +policy = partial(policy, + nmicros=args.gbs//args.mbs, + dp_size=args.dp, + tp_size=args.tp +) def trunc_normal_(tensor: torch.Tensor, mean=0., std=1., a=-2., b=2.): @@ -73,46 +83,39 @@ def norm_cdf(x): def train(): batch_size = args.mbs - Config.num_embeddings = args.vocab - Config.layers = args.layers - Config.hidden = args.hidden - Config.heads = args.heads - Config.seqlen = args.seqlen - - if cube.runtime.device.DeviceGroup().local_rank == 0: - model = MBartForSentenceClassification(batch_size) - model = model.half() if args.fp16 else model - else: - model = None - dataloader = MBartSyntheticDataLoader(batch_size) - - model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=PAS, override=True, load_content=False) - def train_iter(model, dataloader): - input_ids = next(dataloader) - loss = model(input_ids) - loss.backward() - model = model.get_gen_module() - - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - for name, buffer in model.named_buffers(): - torch.manual_seed(0) - if name.startswith('decoder_input_ids'): - inputs = torch.randint( - 0, args.vocab, buffer.size(), - dtype=torch.int64, device=torch.cuda.current_device(), - ) - buffer.copy_(inputs) - + + config = Config( + hidden=args.hidden, + heads=args.heads, + layers=args.layers, + seqlen=args.seqlen, + ffn_hidden_dim=args.hidden * 4, + vocab=args.vocab, + ) + print_each_rank(config) + + model = MBartForSentenceClassification(batch_size, config) torch.manual_seed(0) for param in model.parameters(): trunc_normal_(param) + model = model.half() if args.fp16 else model + + dataloader = get_mbart_dummy_dataloader(batch_size, config) + + @cube.compile(model, dataloader, PAS=policy) + def train_iter(model, dataloader): + input_ids, decoder_input_ids = next(dataloader) + loss = model(input_ids, decoder_input_ids) + loss.backward() + model = cube.load_model() - CudaTimer(enable=False).warmup() + optimizer = torch.optim.Adam( + model.parameters(), lr=3e-05, betas=(0.9, 0.98)) + + CudaTimer().warmup() + dataloader = iter(dataloader) iter_num, warmup = 5, 2 for step in range(iter_num): - if step == warmup: CudaTimer(enable=True).start('e2e') @@ -123,7 +126,6 @@ def train_iter(model, dataloader): if step == 0: print_each_rank('passed first iteration') - if (step + 1) % 2 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) diff --git a/examples/nlp/palm/module_profiler.py b/examples/nlp/palm/module_profiler.py deleted file mode 100644 index ab942023..00000000 --- a/examples/nlp/palm/module_profiler.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -from cube.profiler import CudaTimer - -bs, n, dim, heads, dim_head = 10, 2048, 4096, 16, 256 -scale = 0.125 - -dev = torch.device('cuda:0') - -def multi_head_attention(x: torch.Tensor, qkv_proj: torch.Tensor, - out_proj: torch.Tensor): - - q, kv = torch.matmul(x, qkv_proj).split((dim, dim_head), dim=-1) - q = q.view(bs, n, heads, dim_head).transpose(1, 2) - q = q.reshape(bs, heads * n, dim_head) - trans_kv = kv.transpose(1, 2) - sim = torch.bmm(q, trans_kv).view(bs, heads, n, n) - attn = torch.nn.functional.softmax(sim, dim=-1) - attn = attn.view(bs, heads * n, n) - out = torch.bmm(attn, kv).view(bs, heads, n, dim_head) - out = torch.transpose(out, 1, 2).reshape(bs, n, dim) - out = torch.matmul(out, out_proj) - return out - -def ffn(x: torch.Tensor, xx: torch.Tensor, y: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor): - return torch.matmul(x, w1), torch.matmul(xx * y, w2) - -x = torch.randn(bs, n, dim).to(dev) -xx = torch.randn(bs, n, dim).to(dev) -y = torch.randn(bs, n, dim).to(dev) -qkv_proj = torch.randn(dim, dim+dim_head).to(dev) -q_proj = torch.randn(dim, dim).to(dev) -kv_proj = torch.randn(dim, dim_head).to(dev) -out_proj = torch.randn(dim, dim).to(dev) -w1 = torch.randn(dim, 2 * dim).to(dev) -w2 = torch.randn(dim, dim).to(dev) -score = torch.randn([bs * heads * n, n], requires_grad=True).to(dev) - -CudaTimer(enable=False).warmup() - -iter_num = 64 -warmup = 20 - -for step in range(iter_num): - softmax_score = torch.nn.functional.softmax(score, dim=-1) - if step >= warmup: - CudaTimer(enable=True).start('e2e') - # out = multi_head_attention(x, qkv_proj, out_proj) - # out = ffn(x, xx, y, w1, w2) - out = torch.autograd.grad(outputs=softmax_score, inputs=score, grad_outputs=softmax_score) - if step >= warmup: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print(f'iter [{step + 1}/{iter_num}]') - -print('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num - warmup, field_name='e2e'))) \ No newline at end of file diff --git a/examples/nlp/palm/palm.py b/examples/nlp/palm/palm.py deleted file mode 100644 index a75c370d..00000000 --- a/examples/nlp/palm/palm.py +++ /dev/null @@ -1,334 +0,0 @@ -from typing import List - -import torch -import torch.nn.functional as F -from torch import nn, einsum - -from math import log2, floor - -# from einops import rearrange, repeat - -from cube.graph import IRGraph - -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.ir.operator import IRDataOperation, IRFwOperation - -import examples.nlp.palm.policy.spmd as spmd -import examples.nlp.palm.policy.mpmd as mpmd - -import argparse - -cube.init() - -# =================== Semantic Model Description ==================== - -# normalization - - -class RMSNorm(nn.Module): - - def __init__(self, dim, eps=1e-8): - super().__init__() - self.scale = dim**-0.5 - self.eps = eps - self.g = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - norm = torch.norm(x, dim=-1, keepdim=True) * self.scale - return x / norm.clamp(min=self.eps) * self.g - - -def exists(val): - return val is not None - - -# AliBi - - -class AlibiPositionalBias(nn.Module): - - def __init__(self, heads, **kwargs): - super().__init__() - self.heads = heads - slopes = torch.Tensor(self._get_slopes(heads)) - slopes = rearrange(slopes, 'h -> h 1 1') - self.register_buffer('slopes', slopes, persistent=False) - self.register_buffer('bias', None, persistent=False) - - def get_bias(self, i, j, device): - i_arange = torch.arange(i, device=device) - j_arange = torch.arange(j, device=device) - bias = -torch.abs( - rearrange(j_arange, 'j -> 1 1 j') - - rearrange(i_arange, 'i -> 1 i 1')) - return bias - - @staticmethod - def _get_slopes(heads): - - def get_slopes_power_of_2(n): - start = (2**(-2**-(log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if log2(heads).is_integer(): - return get_slopes_power_of_2(heads) - - closest_power_of_2 = 2**floor(log2(heads)) - return get_slopes_power_of_2( - closest_power_of_2) + get_slopes_power_of_2( - 2 * closest_power_of_2)[0::2][:heads - closest_power_of_2] - - def forward(self, qk_sim): - h, i, j, device = *qk_sim.shape[-3:], qk_sim.device - - if exists(self.bias) and self.bias.shape[-1] >= j: - return self.bias[..., :i, :j] - - bias = self.get_bias(i, j, device) - bias = bias * self.slopes - - num_heads_unalibied = h - bias.shape[0] - bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) - self.register_buffer('bias', bias, persistent=False) - - return bias - - -@cube.graph.parser.register('N L^ E^, E^ F^, E^ E^ -> N L^ E^', - name='multi_head_attention') -def multi_head_attention(x: torch.Tensor, qkv_proj: torch.Tensor, - out_proj: torch.Tensor, heads: int, scale: float): - ''' - x: [bs, len, dim] - qkv_proj: [dim, dim + dim_head] - out_proj: [dim, dim] - ''' - bs, n, dim = x.size() - dim_head = dim // heads - - q, kv = torch.matmul(x, qkv_proj).split((dim, dim_head), dim=-1) - q = q.view(bs, n, heads, dim_head).transpose(1, 2) * scale - q = q.reshape(bs, heads * n, dim_head) - trans_kv = kv.transpose(1, 2) - sim = torch.bmm(q, trans_kv).view(bs, heads, n, n) - attn = torch.nn.functional.softmax(sim, dim=-1) - attn = attn.view(bs, heads * n, n) - out = torch.bmm(attn, kv).view(bs, heads, n, dim_head) - out = torch.transpose(out, 1, 2).reshape(bs, n, dim) - out = torch.matmul(out, out_proj) - return out - - -@cube.graph.parser.register('N L^ E^, E^ F^, G^ H^ -> N L^ H^', - name='feedforward') -def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): - ''' - x: [bs, len, dim] - proj1: [dim, 2 * ff_mult * dim] - proj2: [ff_mult * dim, dim] - ''' - x = torch.matmul(x, proj1) - x, gate = x.chunk(2, dim=-1) - x = torch.nn.functional.silu(gate) * x - x = torch.matmul(x, proj2) - return x - - -@cube.graph.parser.register('N L^ E+, E+ F -> N L^ F', name='feedforward1') -def feedforward1(x: torch.Tensor, proj: torch.Tensor): - return torch.nn.functional.silu(torch.matmul(x, proj)) - - -@cube.graph.parser.register('N L^ E+, E+ F -> N L^ F', name='feedforward2') -def feedforward2(x: torch.Tensor, proj: torch.Tensor): - return torch.matmul(x, proj) - - -@cube.graph.parser.register('N L^ E+, N L^ E+, E+ F -> N L^ F', - name='feedforward3') -def feedforward3(x: torch.Tensor, y: torch.Tensor, proj: torch.Tensor): - return torch.matmul(x * y, proj) - -@cube.graph.parser.register('* -> *, *', name='multi2ref') -def multi2ref(x: torch.Tensor): - return (x, x) - - -class PaLMLayer(nn.Module): - - def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): - super().__init__() - - self.dim, self.dim_head, self.heads, self.scale = dim, dim_head, heads, dim_head**-0.5 - - # TODO - # self.alibi_pos_biases = AlibiPositionalBias(heads=self.heads) - # self.norm = RMSNorm(dim) - self.norm = torch.nn.LayerNorm(self.dim) - - self.qkv_proj = torch.nn.Parameter(torch.randn(dim, dim + dim_head)) - self.attn_out_proj = torch.nn.Parameter(torch.randn(dim, dim)) - - self.ff_proj1 = torch.nn.Parameter(torch.randn(dim, 2 * ff_mult * dim)) - self.ff_proj2 = torch.nn.Parameter(torch.randn(ff_mult * dim, dim)) - - # self.register_buffer("mask", None, persistent=False) - - def get_mask(self, n, device): - if self.mask is not None and self.mask.shape[-1] >= n: - return self.mask[:n, :n] - - mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), - 1) - self.register_buffer("mask", mask, persistent=False) - return mask - - def forward(self, in_x): - bs, n, device = in_x.shape[0], in_x.shape[1], in_x.device - - # pre layernorm - x = self.norm(in_x) - - attn_out = multi_head_attention(x, self.qkv_proj, self.attn_out_proj, - self.heads, self.scale) - - ff_out = feedforward(x, self.ff_proj1, self.ff_proj2) - - return in_x + attn_out + ff_out - - -class PaLMLayerV2(nn.Module): - - def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): - super().__init__() - - self.dim, self.dim_head, self.heads, self.scale = dim, dim_head, heads, dim_head**-0.5 - - # TODO - # self.alibi_pos_biases = AlibiPositionalBias(heads=self.heads) - # self.norm = RMSNorm(dim) - self.norm = torch.nn.LayerNorm(self.dim) - - self.qkv_proj = torch.nn.Parameter(torch.randn(dim, dim + dim_head)) - self.attn_out_proj = torch.nn.Parameter(torch.randn(dim, dim)) - - self.ff_proj1 = torch.nn.Parameter(torch.randn(dim, ff_mult * dim)) - self.ff_proj2 = torch.nn.Parameter(torch.randn(dim, ff_mult * dim)) - self.ff_proj3 = torch.nn.Parameter(torch.randn(ff_mult * dim, dim)) - - # self.register_buffer("mask", None, persistent=False) - - def get_mask(self, n, device): - if self.mask is not None and self.mask.shape[-1] >= n: - return self.mask[:n, :n] - - mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), - 1) - self.register_buffer("mask", mask, persistent=False) - return mask - - def forward(self, in_x): - - in_x = cube.runtime.function.identity(in_x) - residual = in_x - # pre layernorm - x = self.norm(in_x) - - attn_out = multi_head_attention(x, self.qkv_proj, self.attn_out_proj, - self.heads, self.scale) - - ff1 = feedforward1(x, self.ff_proj1) - ff2 = feedforward2(x, self.ff_proj2) - ff_out = feedforward3(ff1, ff2, self.ff_proj3) - - return attn_out + ff_out + residual - -class PaLM(nn.Module): - - def __init__(self, - dim, - num_tokens, - depth, - dim_head=64, - heads=8, - ff_mult=4): - super().__init__() - - self.net = nn.Sequential( - nn.Embedding(num_tokens, dim), - # *[PaLMLayer(dim, dim_head, heads, ff_mult) for _ in range(depth)], - *[ - PaLMLayerV2(dim, dim_head, heads, ff_mult) - for _ in range(depth) - ], - torch.nn.LayerNorm(dim), - nn.Linear(dim, num_tokens, bias=False), - ) - - self.net[-1].weight = self.net[0].weight - nn.init.normal_(self.net[0].weight, std=0.02) - - def forward(self, x): - return self.net(x).mean() - - -def train(): - bs, n, dim = 5, 2048, 4096 - num_tokens, depth, heads, dim_head = 20000, 1, 16, 256 - - model = PaLM(dim, num_tokens, depth, heads=heads, dim_head=dim_head) - - # for debug - # tokens = torch.randint(0, num_tokens, (bs, n)) - # print(model(tokens)) - # return - - model = cube.SemanticModel( - model, - input_shapes=([bs, n], ), - ) - - dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, n], ), - dtypes=(torch.float32, ), - batch_dims=(0, )) - - # @cube.compile(model, dataloader, PAS=PASSingle) - # @cube.compile(model, dataloader, PAS=PASBranch) - # @cube.compile(model, dataloader, PAS=PASData) - # @cube.compile(model, dataloader, PAS=PASBranch3) - # @cube.compile(model, dataloader, PAS=spmd.PASMegatron) - @cube.compile(model, dataloader, PAS=mpmd.PASBranch5) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - - model = model.get_gen_module() - - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - CudaTimer(enable=False).warmup() - if torch.distributed.is_initialized(): - torch.distributed.barrier() - iter_num = 64 - warmup = 20 - for step in range(iter_num): - if step >= warmup: - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - if step >= warmup: - CudaTimer().stop('e2e') - if (step + 1) % 20 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num - warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num - warmup) - - -train() \ No newline at end of file diff --git a/examples/nlp/palm/policy/mpmd.py b/examples/nlp/palm/policy/mpmd.py deleted file mode 100644 index d90c8522..00000000 --- a/examples/nlp/palm/policy/mpmd.py +++ /dev/null @@ -1,150 +0,0 @@ -from typing import List -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.ir.tensor import IRSubTensor, IRFullTensor - -def PASBranch3(graph: IRGraph, resource): - ''' - 3 way branch - ''' - assert resource.ngpus == 3 - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] - - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - if node.name == 'embedding' or node.name == 'linear': - # data parallel - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=0, - dim=batch_dim, - num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif node.name == 'layernorm' or node.name == 'multiref' or node.name == 'add' or node.name == 'mean': - # replicate - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif node.name == 'feedforward1': - graph.assign(node, 0) - elif node.name == 'feedforward2': - graph.assign(node, 1) - elif node.name == 'feedforward3': - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=2, dim=0, num=2) - graph.assign(sub_nodes[0], 0) - graph.assign(sub_nodes[1], 1) - elif node.name == 'multi_head_attention': - graph.assign(node, 2) - else: - assert False, node.name - - return graph - - -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def convert_add_to_valmap(graph: IRGraph, add_node: IRFwOperation): - """ - Remove add node by replacing with tensor valmap - """ - assert add_node.name == 'add' - ptensors, producers = [], [] - for itensor in add_node.inputs(): - iptensors = graph.ptensors(itensor.parent) - assert len(set(t.valmap for t in iptensors)) == len(iptensors) - ptensors += iptensors - producers += graph.producers(itensor.parent) - ftensor = add_node.output(0).parent - for idx, (ptensor, producer) in enumerate(zip(ptensors, producers)): - fidx = producer.outputs().index(ptensor) - bidx = producer.mirror.inputs().index(ptensor.grad) - ptensor = ftensor.select(ptensor.indmap, (idx, len(producers))) - ptensor.grad = ftensor.grad.select(ptensor.indmap, (0,1)) - with graph.update(producer): - producer.set_output(fidx, ptensor) - with graph.mirror.update(producer.mirror) as bnode: - bnode.set_input(bidx, ptensor.grad) - graph.remove(add_node) - graph.mirror.remove(add_node.mirror) - - -def flatten_branch_grad(graph: IRGraph, ftensor: IRFullTensor): - """ - Flatten valmap for different branches. - """ - assert ftensor.requires_grad - ctensors = graph.ctensors(ftensor) - consumers = graph.consumers(ftensor) - # same tinput ensor - assert all(ctensor == ctensors[0] for ctensor in ctensors) - # different gradient (no replicate) - assert len(set(ctensor.grad.valmap for ctensor in ctensors)) == len(ctensors) - for idx, (consumer, ctensor) in enumerate(zip(consumers, ctensors)): - with graph.mirror.update(consumer.mirror) as bnode: - tidx = bnode.outputs().index(ctensor.grad) - ctensor.grad = ftensor.grad.select(ctensor.indmap, (idx, len(ctensors))) - bnode.set_output(tidx, ctensor.grad) - - -def PASBranch5(graph: IRGraph, resource): - ''' - 5 way branch - ''' - assert resource.ngpus == 5 - devs = list(range(resource.ngpus)) - for node in graph.select(ntype=IRDataOperation): - _replica(graph, node, devs) - for node in graph.select(name='embedding'): - _tp(graph, node, devs, idx=1, dim=0) - for node in graph.select(name='linear'): - _tp(graph, node, devs, idx=1, dim=0) - for node in graph.select(name='mean'): - _tp(graph, node, devs, idx=0, dim=2) - for node in graph.select(name='layernorm'): - _replica(graph, node, devs) - for node in graph.select(name='feedforward1'): - _tp(graph, node, [0, 1], idx=1, dim=1) - for node in graph.select(name='feedforward2'): - _tp(graph, node, [2, 3], idx=1, dim=1) - for node in graph.select(name='feedforward3'): - _tp(graph, node, [0, 1, 2, 3], idx=2, dim=0) - for node in graph.select(name='multi_head_attention'): - graph.assign(node, 4) - for node in graph.select(name='identity'): - _replica(graph, node, devs) - adds = tuple(graph.select(name='add')) - assert len(adds) == 2 - # graph.assign(adds[0], 4) - convert_add_to_valmap(graph, adds[0]) - _replica(graph, adds[1], devs) - # convert_add_to_valmap(graph, adds[1]) - for node in graph.select('feedforward1'): - ftensor = node.input(0).parent - break - flatten_branch_grad(graph, ftensor) - print(graph.extra_repr()) - return graph \ No newline at end of file diff --git a/examples/nlp/palm/policy/spmd.py b/examples/nlp/palm/policy/spmd.py deleted file mode 100644 index ba20e0a6..00000000 --- a/examples/nlp/palm/policy/spmd.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import List -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation - -def PASSingle(graph: IRGraph, resource): - assert resource.ngpus == 1 - - for node in graph.nodes(): - if isinstance(node, (IRDataOperation, IRFwOperation)): - graph.assign(node, 0) - - return graph - - -def PASData(graph: IRGraph, resource): - ''' - 2 way Data Parallel - ''' - # assert resource.ngpus == 2 - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] - - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=0, - dim=batch_dim, - num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return graph - - -def PASMegatron(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - - def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for dev_id, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, dev_id) - return sub_nodes - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - _replica(graph, node, tp_devs) - batch_dim = node.get_batch_dims()[0] - - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - if node.name == 'embedding': - _tp(graph, node, tp_devs, idx=1, dim=0) - elif node.name == "linear": - _tp(graph, node, tp_devs, idx=1, dim=0) - elif node.name == 'multi_head_attention': - # TODO: data parallel current - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=0, - dim=batch_dim, - num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - elif node.name == 'feedforward1': - _tp(graph, node, tp_devs, idx=1, dim=1) - elif node.name == 'feedforward2': - _tp(graph, node, tp_devs, idx=1, dim=1) - elif node.name == 'feedforward3': - _tp(graph, node, tp_devs, idx=2, dim=0) - elif node.name == 'mean': - _tp(graph, node, tp_devs, idx=0, dim=2) - else: - _replica(graph, node, tp_devs) - return graph diff --git a/examples/nlp/torchscale/policy/mpmd.py b/examples/nlp/torchscale/policy/mpmd.py deleted file mode 100644 index bf00f3ee..00000000 --- a/examples/nlp/torchscale/policy/mpmd.py +++ /dev/null @@ -1,103 +0,0 @@ -import random -from typing import Tuple -import numpy as np - -from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.schedule.predefined import PredefinedSched - - -def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: - """ - Create hybrid (nested) groups given the each group number. - - The product of group_num should be same with total devices. - """ - group_num = np.array(group_num) - cnt = np.prod(group_num) - assert cnt == ngpus, 'total device not match' - grid = np.arange(cnt).reshape(tuple(group_num)) - dims = list(range(len(group_num))) - outputs = [] - for dim, num in enumerate(group_num): - remain = ngpus // num - order = tuple(dims[:dim] + dims[dim+1:] + [dim]) - grid_dim = np.transpose(grid, order).reshape((remain,num)) - grid_dim = grid_dim.tolist() - outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) - assert len(outputs) == len(group_num) - return tuple(outputs) - - -def PASRandom(graph, resource): - """ - Random pipeline - """ - assert len(graph.nodes()) // 2 >= resource.ngpus, "not enough operator number." - remain_device = set(range(resource.ngpus)) - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - if len(remain_device) != 0: - idx = random.randint(0, len(remain_device) - 1) - device = list(remain_device)[idx] - remain_device.remove(device) - else: - device = random.randint(0, resource.ngpus - 1) - graph.assign(node, device) - elif isinstance(node, IRDataOperation): - device = random.randint(0, resource.ngpus - 1) - graph.assign(node, device) - print(graph.extra_repr()) - return graph - - -def PASMegatron(graph: IRGraph, resource): - - # assert resource.ngpus == 8, "should apply on 8 gpus" - num_stage = 4 - num_tp = resource.ngpus // num_stage - num_microbatch = resource.ngpus * 8 - - _, tp_mesh = _create_mesh(resource.ngpus, (num_stage, num_tp)) - print(f'> pipeline-tensor parallel group: {tp_mesh}') - assert len(tp_mesh) == num_stage - - linears = graph.select('linear') - stage_start_nodes = linears[::len(linears) // num_stage] - stage_start_nodes = stage_start_nodes[:num_stage] - assert len(stage_start_nodes) == num_stage, f"{len(stage_start_nodes)} != {num_stage}" - graph.staging(stage_start_nodes) - - segments = graph.select(ntype=IRSegment, flatten=False) - fsegs = [seg for seg in segments if seg.isfw()] - assert len(fsegs) == num_stage - - for sid, segment in enumerate(fsegs): - # get tensor parallel group - tp_group = tp_mesh[sid] - for idx, node in enumerate(segment.nodes()): - # partition - if node.name == 'linear': - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx % 2, num=num_tp) - else: - tp_nodes = graph.replicate(node, times=num_tp) - # assign - for devid, node in zip(tp_group, tp_nodes): - graph.assign(node, devid) - - for dl in graph.select(ntype=IRDataOperation): - mesh = tp_mesh[0] - dls = graph.replicate(dl, times=num_tp) - for devid, dl in zip(mesh, dls): - graph.assign(dl, devid) - - # setup schedule to 1F1B - # schedule = IRSchedule1F1B(num_microbatch, tp_mesh, recompute=False) - # graph.schedule_plan = schedule - if graph.train: - schedule = PredefinedSched.sched_1f1b(graph, num_microbatch, num_stage) - else: - schedule = PredefinedSched.sched_infer_pipe(graph, num_microbatch, num_stage) - return graph diff --git a/examples/nlp/torchscale/policy/spmd.py b/examples/nlp/torchscale/policy/spmd.py deleted file mode 100644 index 15e39c53..00000000 --- a/examples/nlp/torchscale/policy/spmd.py +++ /dev/null @@ -1,252 +0,0 @@ -from typing import List -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.gener.rvd.intra import IntraAutoPlacer -from cube.graph.function import IRTensor - - -# tensor parallelism with auto-placer -# This is an implementation example of SPMD auto placer usage -def _tp_autoplace(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): - - if len(devs) == 1: - graph.assign(node, devs[0]) - return [node] - - segment: IRSegment = graph.segment(node) - ftensor = node.input(configs['idx']).parent - - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - producers = segment.producers(ftensor) - if ftensor.is_param() or len(producers) != len(sub_nodes): - print(f"> skip auto placer due to condition not matched: " - f"nproducers: {len(producers)}, nconsumers: {len(sub_nodes)}, " f"producer name: {producers[0].name if len(producers) > 0 else None}") - devs = sorted(list(devs)) - for devid, node in zip(devs, sub_nodes): - graph.assign(node, devid) - else: - devices = IntraAutoPlacer.auto_place( - segment, ftensor, producers, sub_nodes) - for devid, subnode in zip(devices, sub_nodes): - graph.assign(subnode, devid) - return sub_nodes - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def PASSingle(graph: IRGraph, resource): - """ - Single device - """ - assert resource.ngpus == 1, "only apply for single gpu case" - for node in graph.nodes(): - if isinstance(node, (IRDataOperation, IRFwOperation)): - graph.assign(node, 0) - return graph - - -def PASData(graph: IRGraph, resource): - """ - Data Parallel - """ - # auto multi-ref - for ftensor in graph.full_tensors(): - if len(graph.consumers(ftensor)) > 1: - if ftensor.is_attr(): - continue - graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, subnode in enumerate(sub_nodes): - graph.assign(subnode, idx) - batch_dim = node.get_batch_dims()[0] - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - try: - algo = node.algorithms('dim') - idx = 0 - sub_nodes = graph.partition( - node, algo, idx=idx, dim=batch_dim, num=resource.ngpus) - except AssertionError: - print(f'WARNING: {node} cannot find dim algo, using replicate instead') - sub_nodes = graph.replicate(node, resource.ngpus) - - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - return graph - - -def PASCol(graph: IRGraph, resource): - """ - Linear Column Parallel - """ - linears = [node for node in graph.nodes() if node.name == 'linear'] - for idx, node in enumerate(linears): - algo = node.algorithms('dim') - sub_nodes = graph.partition( - node, algo, idx=1, dim=0, num=resource.ngpus - ) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - sub_nodes = graph.replicate(node, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - return graph - - -def PASRow(graph: IRGraph, resource): - """ - Linear Column Parallel - """ - devs = list(range(resource.ngpus)) - - for dl in graph.select(ntype=IRDataOperation): - sub_nodes = graph.replicate(dl, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - - for node in graph.select(ntype=IRFwOperation): - if node.name == 'linear': - _tp(graph, node, devs, idx=0, dim=1, num=len(devs)) - else: - _replica(graph, node, devs) - - return graph - - -def PASHybrid(graph: IRGraph, resource): - """ - Linear Hybrid Parallelism (Megatron) - """ - linears = [node for node in graph.nodes() if node.name == 'linear'] - for idx, node in enumerate(linears): - try: - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=resource.ngpus) - for idx, node in enumerate(tp_nodes): - graph.assign(node, idx) - except AssertionError: - print(f'WARNING: {node} cannot find dim algo, using replicate instead') - sub_nodes = graph.replicate(node, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - sub_nodes = graph.replicate(node, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - print(graph.extra_repr()) - return graph - - -def PASMegatronTP(graph: IRGraph, resource): - """ - Tensor + Data Parallelism - """ - tp = min(2, resource.ngpus) - dp = resource.ngpus // tp - linears = [node for node in graph.nodes() if node.name == 'linear'] - for idx, node in enumerate(linears): - sub_nodes = [] - algo = node.algorithms('dim') - tp_nodes = graph.partition(node, algo, idx=1, dim=idx%2, num=tp) - for tp_node in tp_nodes: - algo = tp_node.algorithms('dim') - dp_nodes = graph.partition(tp_node, algo, idx=0, dim=0, num=dp) - sub_nodes += dp_nodes - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - sub_nodes = graph.replicate(node, resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - # print(graph.extra_repr()) - return graph - - -def PASOptimal(graph: IRGraph, resource): - """ - Square Linear optimal parallelism (4GPU) - """ - assert resource.ngpus == 4, "only apply to 4 GPU case" - - # replicate data operation - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - - # replicate loss operation - fnodes = [fnode for fnode in graph.nodes() if isinstance(fnode, IRFwOperation)] - loss = fnodes[-1] - sub_nodes = graph.replicate(loss, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - - fnodes = fnodes[:-1] - # linear0 config - config0 = [ - None, - dict(idx=1, dim=0, num=4) # col - ] - # linear1 config - config1 = [ - dict(idx=0, dim=1, num=2), # row - dict(idx=1, dim=0, num=2), # col - ] - # linear2 config - config2 = [ - dict(idx=0, dim=0, num=2), # dat - dict(idx=0, dim=1, num=2), # row - ] - # linear3 config - config3 = [ - dict(idx=0, dim=0, num=2), # dat - dict(idx=0, dim=1, num=2), # row - ] - configs = [config0, config1, config2, config3] - assert len(fnodes) == len(configs) - for fnode, config in zip(fnodes, configs): - all_nodes = [fnode] - for conf in config: - if conf is None: - continue - sub_nodes = list() - for node in all_nodes: - algo = node.algorithms('dim') - nodes = graph.partition(node, algo, **conf) - sub_nodes += nodes - all_nodes = sub_nodes - assert len(all_nodes) == 4 - for idx, node in enumerate(all_nodes): - graph.assign(node, idx) - return graph - diff --git a/examples/nlp/torchscale/run_torchscale_lm.py b/examples/nlp/torchscale/run_torchscale_lm.py deleted file mode 100644 index 90c679d4..00000000 --- a/examples/nlp/torchscale/run_torchscale_lm.py +++ /dev/null @@ -1,149 +0,0 @@ -# single GPU inference debug -# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --ddp-backend=c10d --subln --xpos-rel-pos --fp16 --policy PASData -# multi-GPU inference test -# OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 --master_port=25642 examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --subln --xpos-rel-pos --fp16 --policy PASData -# single-GPU training test -# OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=1 --master_port=25642 examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --subln --xpos-rel-pos --fp16 --policy PASData --do_train -# multi-GPU training test -# OMP_NUM_THREADS=12 USE_TORCHFX=1 PYTHONPATH=.:..:$PYTHONPATH python -m torch.distributed.launch --nproc_per_node=2 --master_port=25642 examples/nlp/torchscale/run_torchscale_lm.py examples/nlp/torchscale/lm_input --activation-fn gelu --share-decoder-input-output-embed --validate-interval-updates 1000 --save-interval-updates 1000 --no-epoch-checkpoints --memory-efficient-fp16 --fp16-init-scale 4 --arch lm_base --task language_modeling --sample-break-mode none --tokens-per-sample 128 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 --lr 5e-4 --lr-scheduler polynomial_decay --warmup-updates 750 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --batch-size 4 --update-freq 1 --required-batch-size-multiple 1 --total-num-update 50000 --max-update 50000 --seed 1 --subln --xpos-rel-pos --fp16 --policy PASData --do_train - -import torch -import pickle -from fairseq import ( - tasks, - options, - checkpoint_utils -) -from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from fairseq.trainer import Trainer -from fairseq.data import iterators - -import sys -import os - -# https://github.com/microsoft/torchscale/tree/main/examples/fairseq -sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') -sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale') -print(f'sys.path = {sys.path}') -import models - -#:torchscaletest/torchscale -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -sys.path.append('.') -from policy import mpmd, spmd -# import examples.nlp.torchscale.policy.spmd as spmd - -# import argparse - -# parser = argparse.ArgumentParser(description='comm primitive') -# parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -# parser.add_argument('--local_rank', type=int, default=0) -# args = parser.parse_args() - -# build model -parser = options.get_training_parser() -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -parser.add_argument('--do_train', action='store_true', default=False) -# parser.add_argument('--local_rank', type=int, default=0) - -args = options.parse_args_and_arch(parser) -print(f"Running mode: {'TRAIN' if args.do_train else 'EVAL'}") - -cube.init() -# set up policy -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - -cfg = convert_namespace_to_omegaconf(args) -task = tasks.setup_task(cfg.task) -model = task.build_model(cfg.model) -if args.do_train: - model.train() -else: - model.eval() -print("building model succeed: ", type(model)) - -# create dummy input -with open('examples/nlp/torchscale/input_lm', 'rb') as f: - dummy_input = pickle.load(f) -device = next(model.parameters()).device -print(f'device = {device}') -for key in dummy_input.keys(): - dummy_input[key] = dummy_input[key].to(device) -print(f'dummy_input <{type(dummy_input)}> = {dummy_input}') - -# create input as list of tensors/objects -dummy_input_list = [val for key, val in dict(dummy_input).items()] -# print(f'dummy_input_list = {dummy_input_list}, len = {len(dummy_input_list)}') - -with torch.no_grad(): - output_origin = model(**dummy_input) - # output_origin = model(*dummy_input_list) - # print(f'output_origin = {output_origin}') - - -input_shapes = [list(dummy_input[input].size()) for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] -input_dtypes = [dummy_input[input].dtype for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] -input_names = tuple([input for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)]) - -# input_shapes += [[None], [None]] -# input_dtypes += [bool, bool] - -print(f'input_shapes = {input_shapes}') -print(f'input_dtypes = {input_dtypes}') - -dataloader = cube.runtime.syndata.SynDataLoader( - # names=('src_tokens',), - shapes=(input_shapes), - dtypes=input_dtypes, - batch_dims=(0, 0), -) - -sample_input = next(dataloader) -print(f'next(dataloader) = {sample_input}') -if isinstance(sample_input, tuple): - sample_input_cpu = tuple([val.to(device) if isinstance(val, torch.Tensor) else val for val in sample_input]) -elif isinstance(sample_input, dict): - sample_input_cpu = sample_input - for key in sample_input_cpu.keys(): - sample_input_cpu[key] = sample_input_cpu[key].to(device) -else: - raise RuntimeError(f'To fix sample_input with type{type(sample_input)}') - - -model = cube.SemanticModel( - model, dummy_input=dummy_input, -) - -if args.do_train: - @cube.compile(model, dataloader, PAS=PAS, load_content=False, override=True) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(*data) - loss.backward() - # TODO fix loss.mirror DummyInputOutput issue - - model = model.get_gen_module() - train_iter(model, dataloader) -else: # do_eval - @cube.compile(model, dataloader, PAS=PAS, load_content=False, override=True) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(*data) - return loss - - model = model.get_gen_module() - iter_ret = train_iter(model, dataloader) - print(f'iter_ret = {iter_ret}') - -print('DONE') \ No newline at end of file diff --git a/examples/nlp/torchscale/run_torchscale_tl.py b/examples/nlp/torchscale/run_torchscale_tl.py deleted file mode 100644 index f7f58ff3..00000000 --- a/examples/nlp/torchscale/run_torchscale_tl.py +++ /dev/null @@ -1,204 +0,0 @@ -# USE_TORCHFX=1 SINGLE_DEV_MODE=1 PYTHONPATH=.:$PYTHONPATH:torchscaletest/torchscale python examples/nlp/torchscale/run_torchscale_tl.py examples/nlp/torchscale/input --arch mt_base --share-decoder-input-output-embed --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --dropout 0.3 --weight-decay 0.0001 --max-tokens 4096 --fp16 --policy PASData - -import torch -import pickle -from fairseq import ( - tasks, - options, - checkpoint_utils -) -from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from fairseq.trainer import Trainer -from fairseq.data import iterators - -import sys - -import os -print(f'os.getcwd() = {os.getcwd()}') - - -# https://github.com/microsoft/torchscale/tree/main/examples/fairseq -# sys.path.append('/home/v-junliang/torchscaletest/torchscale/examples/fairseq') -# sys.path.append('./torchscaletest/torchscale/examples/fairseq') -sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale/examples/fairseq') -sys.path.append('examples/nlp/torchscale/torchscaletest/torchscale') -print(f'sys.path = {sys.path}') -import models - -#:torchscaletest/torchscale -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -sys.path.append('.') -from policy import mpmd, spmd -# import examples.nlp.torchscale.policy.spmd as spmd - -# import argparse - -# parser = argparse.ArgumentParser(description='comm primitive') -# parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -# parser.add_argument('--local_rank', type=int, default=0) -# args = parser.parse_args() - -# build model -parser = options.get_training_parser() -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with "PAS"') -# parser.add_argument('--local_rank', type=int, default=0) - -args = options.parse_args_and_arch(parser) - -cube.init() -# set up policy -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - -cfg = convert_namespace_to_omegaconf(args) -task = tasks.setup_task(cfg.task) -model = task.build_model(cfg.model) -model.eval() -print("building model succeed: ", type(model)) - -# create dummy input -with open('examples/nlp/torchscale/input_tl', 'rb') as f: - dummy_input = pickle.load(f) -device = next(model.parameters()).device -print(f'device = {device}') -for key in dummy_input.keys(): - dummy_input[key] = dummy_input[key].to(device) -print("creating dummy input succeed") -dummy_input['features_only'] = False -dummy_input['return_all_hiddens'] = False -print(f'dummy_input = {dummy_input}, {type(dummy_input)}') - -# create input as list of tensors/objects -dummy_input_list = [val for key, val in dict(dummy_input).items()] -print(f'dummy_input_list = {dummy_input_list}') - -with torch.no_grad(): - # output_origin = model(**dummy_input) - output_origin = model(*dummy_input_list) - # print(f'output_origin = {output_origin}') - - -input_shapes = [list(dummy_input[input].size()) for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] -input_dtypes = [dummy_input[input].dtype for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)] -input_names = tuple([input for input in dummy_input if isinstance(dummy_input[input], torch.Tensor)]) - -input_shapes += [[None], [None]] -input_dtypes += [bool, bool] - -print(f'input_shapes = {input_shapes}') -print(f'input_dtypes = {input_dtypes}') - -dataloader = cube.runtime.syndata.SynDataLoader( - shapes=(input_shapes), - dtypes=input_dtypes, - batch_dims=(0,0,0, None, None), -) -sample_input = next(dataloader) -print(f'next(dataloader) = {sample_input}') -sample_input_cpu = tuple([val.to(device) if isinstance(val, torch.Tensor) else val for val in sample_input]) - -model = cube.SemanticModel( - model, dummy_input=sample_input_cpu, -) - -@cube.compile(model, dataloader, PAS=PAS, load_content=False) -def train_iter(model, dataloader): - data = next(dataloader) - loss = model(*data) - loss.backward() - -train_iter(model, dataloader) - -# Conduct concrete trace below -# sys.path.append('/home/v-junliang/torchscaletest/nni') -# sys.path.append('./torchscaletest/nni') -# from nni.common.concrete_trace_utils import concrete_trace -# from concrete_trace_utils import concrete_trace -from examples.nlp.torchscale.concrete_trace_utils import concrete_trace -import examples.nlp.torchscale.torchscaletest.torchscale - - -def check_equal(a, b): - if type(a) != type(b): - return False - if isinstance(a, (list, tuple, set)): - if len(a) != len(b): - return False - for sub_a, sub_b in zip(a, b): - if not check_equal(sub_a, sub_b): - return False - return True - elif isinstance(a, dict): - keys_a, kes_b = set(a.keys()), set(b.keys()) - if keys_a != kes_b: - return False - for key in keys_a: - if not check_equal(a[key], b[key]): - return False - return True - elif isinstance(a, torch.Tensor): - return torch.equal(a, b) - else: - return a == b - - -print("start tracing...") -traced_model, _ = concrete_trace( - model, - dummy_input, - use_operator_patch=True, - autowrap_leaf_class={ - torch.finfo: ((), False), - type(output_origin): ((), False), - }, -) -print("trace succeed") -print("checking equal...") -with torch.no_grad(): - output_traced = traced_model(**dummy_input) -assert check_equal(output_origin, output_traced), "check equal failed" -print("checked") - -# check graph -traced_model.graph.print_tabular() - -# with open('input_tl', 'wb') as f: -# pickle.dump(dummy_input, f) - -# try to save traced model with pickle -# from concrete_trace_utils.concrete_tracer import MagicMethodPatcher -# from pickle import _Pickler, _Unpickler - -# with open("save/through_nn_Module/tl_traced_v2.model", "wb") as f: -# # pickle.dump(traced_model, f) -# with MagicMethodPatcher(): -# _Pickler(f).dump(traced_model) - -# with open("save/through_nn_Module/tl_traced.model", "rb") as f: -# with MagicMethodPatcher(): -# reload_model = _Unpickler(f).load() - - -# with torch.no_grad(): -# output_reload = reload_model(**dummy_input) -# assert check_equal(output_origin, output_reload), "reload check equal failed" -# print("reload is good!") - -# with open("save/through_nn_Module/tl_origin_v2.model", "wb") as f: -# with MagicMethodPatcher(): -# _Pickler(f).dump(model) - -# with open("save/through_nn_Module/tl_input_v2.pkl", "wb") as f: -# with MagicMethodPatcher(): -# _Pickler(f).dump(dummy_input) - diff --git a/examples/poisson/policy/spmd.py b/examples/poisson/policy/spmd.py deleted file mode 100644 index 03720306..00000000 --- a/examples/poisson/policy/spmd.py +++ /dev/null @@ -1,27 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.function import IRConv2D - - -def PASReplica(graph: IRGraph, resource) -> IRGraph: - for node in graph.nodes(): - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return graph - - -def PASHaloConv(graph: IRGraph, resource) -> IRGraph: - for node in graph.nodes(): - if isinstance(node, IRConv2D): - sub_nodes = list() - algo = node.algorithms('halo') - Wnodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus // 2) - for Wnode in Wnodes: - algo = Wnode.algorithms('halo') - Hnodes = graph.partition(Wnode, algo, idx=0, dim=2, num=2) - sub_nodes += Hnodes - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return graph diff --git a/examples/poisson/sci.py b/examples/poisson/sci.py deleted file mode 100644 index e6415f0e..00000000 --- a/examples/poisson/sci.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/poisson/sci.py -""" - -import torch -import torch.nn.functional as F -import time - -torch.set_default_tensor_type(torch.DoubleTensor) - -from cube.runtime.syndata import SciLoopVariables -import cube -from examples.poisson.policy.spmd import PASHaloConv as PAS - - -class ScientificModel(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, r0: torch.Tensor, p: torch.Tensor, phi: torch.Tensor, - filter: torch.Tensor): - conv_out = F.conv2d(p, filter, padding=1) - alpha = torch.mul(r0, r0).sum() / torch.mul(p, conv_out).sum() - r1 = r0 - alpha * conv_out - # update - phi = phi + alpha * p - r1_sum = torch.mul(r1, r1).sum() - beta = r1_sum / torch.mul(r0, r0).sum() - p = r1 + beta * p - return r1, p, phi, r1_sum - - -def train_loop(): - # initialize - N = 1024 * 2 - filter = torch.tensor( - [[0., 1., 0.], - [1., -4., 1.], - [0., 1., 0.]] - ).view(1, 1, 3, 3) - rho = F.conv2d(torch.ones((1, 1, N, N)), filter, padding=1) - phi = torch.zeros((1, 1, N, N)) - r0 = rho - F.conv2d(phi, filter, padding=1) - p = r0 - - varloader = SciLoopVariables(variables=[r0, p, phi], constants=[filter]) - model = ScientificModel() - model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes),) - - @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) - def train_iter(model, dataloader): - r0, p, phi, filter = next(dataloader) - r0, p, phi, r1_sum = model(r0, p, phi, filter) - return r0, p, phi, r1_sum - model = model.get_gen_module() - - start = time.time() - - counter = 0 - while True: - counter += 1 - r0, p, phi, r1_sum = train_iter(model, varloader) - varloader.update(variables=[r0, p, phi]) - if counter % 100 == 0: - print('iters:\t', counter) - print('rnorm:\t', torch.sqrt(r1_sum)) - if torch.sqrt(r1_sum) < 1e-10: - print('**************** Converged ****************') - print('iters:\t', counter) - torch.cuda.synchronize() - print('time:\t', time.time() - start) - print('error:\t', torch.norm(phi - torch.ones((1, 1, N, N)).cuda())) - break - - -if __name__ == '__main__': - cube.init() - train_loop() \ No newline at end of file diff --git a/examples/policies/alpa/estimator.py b/examples/policies/alpa/estimator.py index 14cee328..a658c475 100644 --- a/examples/policies/alpa/estimator.py +++ b/examples/policies/alpa/estimator.py @@ -148,17 +148,7 @@ def get_dep_names(sign: str): return ret if node.signature in CustomizedOps.kOpCodeDef: - dep_code_impl = '' - for dep_name in get_dep_names(node.signature): - dep_code_impl = dep_code_impl + CustomizedOps.kOpCodeDef[dep_name] code_impl: str = CustomizedOps.kOpCodeDef[node.signature] - def_end = code_impl.find(':\n') - assert def_end >= 0 - prev_code_lines = code_impl[:def_end+2] - succ_code_lines = code_impl[def_end+2:] - for line in dep_code_impl.split('\n'): - prev_code_lines = prev_code_lines + ' ' + line + '\n' - code_impl = prev_code_lines + succ_code_lines local = {} exec(code_impl, globals(), local) fn = list(local.values())[0] diff --git a/examples/utils.py b/examples/utils.py new file mode 100644 index 00000000..c1488ef1 --- /dev/null +++ b/examples/utils.py @@ -0,0 +1,145 @@ +from typing import List, Union, Callable, Optional, Tuple +import logging + +from cube.graph import IRGraph +from cube.graph.segment import IRSegment +from cube.graph.function.dimops import IRDimops +from cube.graph.gener.rvd.intra import IntraAutoPlacer +from cube.ir.operator import IRDataOperation, IRFwOperation +from cube.ir.cten import IRCell +from cube.ir.tensor import IRFullTensor +from cube.graph.function.anchor import IRGraphAnchor +from cube.utils import print_each_rank + +import numpy as np + + +_logger = logging.getLogger(__name__) + + +def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + + e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: + ( + ( (0,1,2), (3,4,5) ), + ( (0,3), (2,5), (3,6) ), + ) + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + + +def group_to_layers(fnodes) -> List[List[IRCell]]: + # group to layers + transformers: List[List[IRFwOperation]] = [] + anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] + indices = [fnodes.index(anchor) for anchor in anchors] + for lid, idx in enumerate(indices): + fnodes[idx+1].comment = f'===> start of layer {lid}' + start = idx if lid != 0 else 0 + end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) + transformers.append(fnodes[start:end]) + for lid in range(len(transformers) - 1): + if transformers[lid][-1].name == 'multiref': + node = transformers[lid].pop() + transformers[lid+1].insert(0, node) + return transformers + + +def _tp_autoplace(segment: IRSegment, ftensor: IRFullTensor, + producers: List[IRFwOperation], devs: List[int], + sub_nodes: List[IRFwOperation]) -> List[int]: + """decide the devices of the partitioned `sub-nodes` to achieve optimal communication + + Args: + segment (IRSegment): segment of the ftensor + ftensor (IRFullTensor): the tensor to be partitioned + producers (List[IRFwOperation]): producers of the ftensor + devs (List[int]): devices to be placed + sub_nodes (List[IRFwOperation]): partitioned nodes + + Returns: + List[int]: devices of the partitioned `sub-nodes` + """ + if ftensor.is_param() or len(producers) != len(sub_nodes): + _logger.warning(f"skip auto placer due to condition not matched: " + f"nproducers: {len(producers)}, nconsumers: {len(sub_nodes)}, " + f"producer name: {producers[0].name if len(producers) > 0 else None}") + devs = sorted(list(devs)) + else: + devs = IntraAutoPlacer.auto_place(segment, ftensor, producers, sub_nodes) + return devs + +# tensor parallelism +def tensor_parallelism(graph: IRGraph, node: IRDimops, + idx: int, dim: int, devs: List[int], + autoplace: bool = False) -> List[IRDimops]: + """Apply tensor parallelism of a node to devs""" + if len(devs) == 1: + graph.assign(node, devs[0]) + return [node] + # transformation + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + + if autoplace: + segment = graph.segment(node) + devs = _tp_autoplace(segment, node.input(idx).parent, + segment.producers(node.input(idx).parent), + devs, sub_nodes) + # assign + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +# replica +def replica(graph: IRGraph, node: Union[IRFwOperation, IRDataOperation], + devs: List[int]) -> List[Union[IRFwOperation, IRDataOperation]]: + """Replicate a forward node or dataloader to devs""" + if len(devs) == 1: + graph.assign(node, devs[0]) + return [node] + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def get_policy(modules: List, name: str) -> Callable: + """Get policy from modules + + Note every rank should enter this function simutaneously. + + Args: + modules (List): list of modules + name (str): name of policy + + Returns: + Callable: policy + """ + for module in modules: + if name in module.__dict__: + print_each_rank(f'using policy from {module.__name__}.{name}') + return module.__dict__[name] + policies = [] + for module in modules: + policies += list(policy for policy in module.__dict__.keys() if policy.startswith('PAS')) + raise ValueError(f"policy {name} not found. Candidates: {policies}") \ No newline at end of file diff --git a/examples/vision/resnet/model.py b/examples/vision/resnet/model.py deleted file mode 100644 index 09753c73..00000000 --- a/examples/vision/resnet/model.py +++ /dev/null @@ -1,281 +0,0 @@ -from typing import List, Optional, Callable -import torch -import torch.nn as nn -import cube - - -class Config: - - width_factor = 1 # for scaling default 1 - inplanes = 160 # 64 - # setting for wide-resnet 50 - layers : List[int] = [3, 4, 6, 3] - - # setting for wide-resnet 101 - layers : List[int] = [3, 4, 23, 3] - - width_per_group = 128 * width_factor - # conv2d: - # in_channel: 128 | out_channel: 128 | stride: 1 | groups: 1 | dilation: 1 - # in_channel: 256 | out_channel: 256 | stride: 2 | groups: 1 | dilation: 1 - # in_channel: 512 | out_channel: 512 | stride: 1 | groups: 1 | dilation: 1 - # in_channel: 1024 | out_channel: 1024 | stride: 2 | groups: 1 | dilation: 1 - # conv2d inputs: - # torch.Size([1, 128, 128, 128]) - # torch.Size([1, 256, 64, 64]) - # torch.Size([1, 512, 32, 32]) - # torch.Size([1, 1024, 16, 16]) - - # input - img_size = 224 - num_classes = 1024 # 1000 - - -def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: - """3x3 convolution with padding""" - # print(f'conv2d: in_channel: {in_planes} | out_channel: {out_planes} | stride: {stride} | groups: {groups} | dilation: {dilation}') - return nn.Conv2d( - in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation, - ) - - -def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -class BasicBlock(nn.Module): - expansion: int = 1 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ) -> None: - super().__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError("BasicBlock only supports groups=1 and base_width=64") - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x: torch.Tensor) -> torch.Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - - expansion: int = 4 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ) -> None: - super().__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.0)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - print(f'adding conv2d channel: {width}, stride: {stride}, padding: {dilation}') - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x: torch.Tensor) -> torch.Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - # print(f'conv2d input shape: {out.size()}') - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - print(identity.size(), out.size()) - out += identity - out = self.relu(out) - - return out - - -class WideResNet(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.layers = Config.layers - self.num_classes = 1000 - self._norm_layer = nn.BatchNorm2d - self.block = Bottleneck - self.inplanes = 64 - self.dilation = 1 - self.replace_stride_with_dilation = [False, False, False] - self.groups = 1 - self.base_width = Config.width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = self._norm_layer(self.inplanes) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(64, self.layers[0]) - self.layer2 = self._make_layer(128, self.layers[1], stride=2, dilate=self.replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(256, self.layers[2], stride=2, dilate=self.replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(512, self.layers[3], stride=2, dilate=self.replace_stride_with_dilation[2]) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * self.block.expansion, self.num_classes) - self.loss_func = nn.CrossEntropyLoss() - - def _make_layer(self, planes: int, blocks: int, stride: int = 1, dilate = False): - block = Bottleneck - - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - norm_layer(planes * block.expansion), - ) - - layers = [] - layers.append( - block( - self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer - ) - ) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append( - block( - self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation, - norm_layer=norm_layer, - ) - ) - return torch.nn.ModuleList(layers) - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - for layer in self.layer1: - x = layer(x) - for layer in self.layer2: - x = layer(x) - for layer in self.layer3: - x = layer(x) - for layer in self.layer4: - x = layer(x) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - x = self.fc(x) - - loss = self.loss_func(x, target) - return loss - - -class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int): - - self.bs = batch_size - self.img_size = Config.img_size - self.num_classes = Config.num_classes - super().__init__( - shapes=([batch_size, 3, self.img_size, self.img_size,], - [batch_size], - ), - dtypes=(torch.float, torch.int), - batch_dims=(0, 0) - ) - self.samples = [self.random_sample()] - - def random_sample(self): - img = torch.rand( - *(self.bs, 3, self.img_size, self.img_size), - dtype=torch.float, - device=torch.cuda.current_device() - ) - labels = torch.randint( - 0, self.num_classes, - size=(self.bs,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - return (img, labels) - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] \ No newline at end of file diff --git a/examples/vision/resnet/model_alpa.py b/examples/vision/resnet/model_alpa.py deleted file mode 100644 index 362d6432..00000000 --- a/examples/vision/resnet/model_alpa.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -import torch.nn as nn -import cube - - -class Config: - - stages = { - 50: [3, 4, 6, 3], - 101: [3, 4, 23, 3] - } - - num_layers = 50 - width_factor = 2 - num_filters = 160 - - img_size = 224 - num_classes = 1024 - - -class Bottleneck(nn.Module): - - def __init__(self, in_channels: int, out_channels: int, width_factor: int, stride: int) -> None: - super().__init__() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) - self.norm1 = nn.BatchNorm2d(out_channels) - self.act1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_channels, out_channels * width_factor, kernel_size=3, stride=stride, padding=1, bias=False) - self.norm2 = nn.BatchNorm2d(out_channels * width_factor) - self.act2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(out_channels * width_factor, out_channels * 4, kernel_size=1, bias=False) - self.norm3 = nn.BatchNorm2d(out_channels * 4) - - # down sample - self.downsample = None if in_channels == out_channels * 4 else torch.nn.ModuleList([ - nn.Conv2d(in_channels, out_channels * 4, 1, stride, bias=False), - nn.BatchNorm2d(out_channels * 4) - ]) - - self.act3 = nn.ReLU(inplace=True) - - def forward(self, x: torch.Tensor): - - residual = x - - y = self.conv1(x) - y = self.norm1(y) - y = self.act1(y) - - y = self.conv2(y) - y = self.norm2(y) - y = self.act2(y) - - y = self.conv3(y) - y = self.norm3(y) - - if self.downsample is not None: - for layer in self.downsample: - residual = layer(residual) - - # print(residual.size(), y.size()) - y = self.act3(residual + y) - return y - - -class WideResNet(nn.Module): - - def __init__(self): - super().__init__() - config = Config() - - # preprocess - self.conv1 = nn.Conv2d(3, config.num_filters, kernel_size=7, stride=2, padding=3, bias=False) - self.norm1 = nn.BatchNorm2d(config.num_filters) - self.act1 = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # 'padding=SAME' - - self.layers = torch.nn.ModuleList([]) - - stages = config.stages[config.num_layers] - for i, block_size in enumerate(stages): - channel = config.num_filters * (2 ** i) - for j in range(block_size): - if i == 0 and j == 0: - in_channels = channel - elif i > 0 and j == 0: - in_channels = channel // 2 * 4 - else: - in_channels = channel * 4 - stride = 2 if i > 0 and j == 0 else 1 - print(f'add in_channel: {in_channels} | out_channel: {channel * 4}') - block = Bottleneck( - in_channels, channel, config.width_factor, stride - ) - self.layers.append(block) - - # postprocess - self.fc = nn.Linear(channel * 4, config.num_classes, bias=False) - self.criterion = nn.CrossEntropyLoss() - - def forward(self, img: torch.Tensor, label: torch.Tensor): - x = self.conv1(img) - x = self.norm1(x) - x = self.act1(x) - x = self.maxpool(x) - # print(x.size()) - - for block in self.layers: - x = block(x) - - # N C H W -> N C - x = torch.mean(x, dim=(2,3)) - x = self.fc(x) - loss = self.criterion(x, label) - return loss - - -class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int): - - self.bs = batch_size - self.img_size = Config.img_size - self.num_classes = Config.num_classes - super().__init__( - shapes=([batch_size, 3, self.img_size, self.img_size,], - [batch_size], - ), - dtypes=(torch.float, torch.int), - batch_dims=(0, 0) - ) - self.samples = [self.random_sample()] - - def random_sample(self): - img = torch.rand( - *(self.bs, 3, self.img_size, self.img_size), - dtype=torch.float, - device=torch.cuda.current_device() - ) - labels = torch.randint( - 0, self.num_classes, - size=(self.bs,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - return (img, labels) - - def __iter__(self): - return self - - def __next__(self): - return self.samples[0] diff --git a/examples/vision/resnet/train.py b/examples/vision/resnet/train.py deleted file mode 100644 index 20ef08c7..00000000 --- a/examples/vision/resnet/train.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -example: - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - --nnodes=1 \ - examples/vision/resnet/train.py -""" - -import torch -from examples.vision.resnet.model_alpa import WideResNet, ImageDataLoader - -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary, model_summary - - - -def train(): - - batch_size = 32 - nmicros = 1536 // batch_size - - - model = WideResNet() - model = model.cuda() - - cnt = 0 - for param in model.parameters(): - cnt += param.nelement() - print(f'param#: {cnt / 1e6} M') - - dataloader = ImageDataLoader(batch_size) - optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) - - print_each_rank('model weight consumpition:') - memory_summary() - - def train_iter(model, dataloader): - imgs, labels = next(dataloader) - loss = model(imgs, labels) - loss.backward() - - CudaTimer(enable=False).warmup() - iter_num = 10 - for step in range(iter_num): - - # if step == 0: - # model_summary(model, next(dataloader)) - - if step >= 4: - CudaTimer(enable=True).start('e2e') - - # training - for _ in range(nmicros): - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - - if step >= 4: - CudaTimer().stop('e2e') - - if step == 0: - print_each_rank('passed first iteration') - - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-4, field_name='e2e'))) - memory_summary() - -if __name__ == '__main__': - - cube.init() - train() diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index 10e29dd9..edae0a44 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -6,6 +6,7 @@ from examples.vision.swin.blocks.patch import PatchEmbed, PatchMerging import cube +from cube.runtime.utils import create_dummy_dataloader class Config: @@ -221,27 +222,10 @@ def flops(self): # =========================== Data Loader ======================= -class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): - - def __init__(self, batch_size: int, img_size: int, num_classes: int, dtype=torch.float32): - super().__init__(batch_size, [0]) - self.img_size = img_size - self.num_classes = num_classes - self.dtype = dtype - - self.sample = None - self.set_batch_size(batch_size) - - def __iter__(self): - return self - - def __next__(self): - return self.sample - - def set_batch_size(self, batch_size: int): - self.batch_size = batch_size - input_ids = torch.rand( - [self.batch_size, 3, self.img_size, self.img_size], - dtype=self.dtype, device=torch.cuda.current_device() - ) - self.sample = input_ids +def get_swin_dummy_dataloader(batch_size: int, + dtype: torch.dtype, cfg: Config): + input_ids = torch.randn( + [3, cfg.img_size, cfg.img_size], + dtype=dtype, device=torch.cuda.current_device() + ) + return create_dummy_dataloader(input_ids, batch_size) diff --git a/examples/vision/swin/policy/__init__.py b/examples/vision/swin/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/vision/swin/policy/gallery.py b/examples/vision/swin/policy/gallery.py new file mode 100644 index 00000000..79a6956b --- /dev/null +++ b/examples/vision/swin/policy/gallery.py @@ -0,0 +1,106 @@ +from typing import List + +from cube.graph import IRGraph +from cube.graph.function.anchor import IRGraphAnchor +from cube.graph.schedule.predefined import PredefinedSched +from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation + +from examples.utils import tensor_parallelism, replica, group_to_layers + +import logging +_logger = logging.getLogger(__name__) + + +def coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, + idx: int, dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) + assert sub_nodes is not None + graph.recompute(sub_nodes) + for devid in devs: + for coid in range(colocate): + sub_node = sub_nodes[devid * colocate + coid] + graph.assign(sub_node, devid) + return sub_nodes + + +def PASSingle(graph: IRGraph, resource, **kwargs): + assert resource.ngpus == 1 + # print(graph.extra_repr()) + for node in graph.nodes(): + if not isinstance(node, IRBpOperation): + graph.assign(node, 0) + return graph + + +def PASData(graph: IRGraph, resource, **kwargs): + """Data parallelism""" + devs = list(range(resource.ngpus)) + dataloader = graph.select(ntype=IRDataOperation)[0] + bs = dataloader.output(0).shape[0] + # replicate dataloader + replica(graph, dataloader, devs) + # partition forward operators + for node in graph.select(ntype=IRFwOperation): + if isinstance(node, IRGraphAnchor): continue + try: + tensor_parallelism(graph, node, idx=0, dim=0, devs=devs) + except Exception as e: + _logger.warning(f'fail to partition node {node.name} at idx=0, using replica') + replica(graph, node, devs) + return graph + + +def PASMegatronTP(graph: IRGraph, resource, **kwargs): + """Megatron-way tensor parallelism""" + devs = list(range(resource.ngpus)) + # attention + for attn in graph.select(name='window_attn'): + tensor_parallelism(graph, attn, idx=1, dim=0, devs=devs) + # feedforward + for ffn in graph.select(name='feedforward'): + tensor_parallelism(graph, ffn, idx=1, dim=0, devs=devs) + # replicate other nodes + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + replica(graph, node, devs) + return graph + + +def PASMeshShard(graph: IRGraph, resource, **kwargs): + """Coshard policy example""" + devs = list(range(resource.ngpus)) + # attention + for attn in graph.select(name='window_attn'): + # _tp(graph, attn, tp_devs, idx=1, dim=0) + coshard(graph, attn, devs, colocate=2, idx=1, dim=0) + # feedforward + for ffn in graph.select(name='feedforward'): + # _tp(graph, ffn, tp_devs, idx=1, dim=0) + coshard(graph, ffn, devs, colocate=4, idx=1, dim=0) + # replicate other nodes + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + replica(graph, node, devs) + return graph + + +def PAS1F1B(graph: IRGraph, resource, nmicros: int, **kwargs): + """1F1B schedule""" + num_stages = resource.ngpus + num_microbatch = nmicros + # group to transformer layers + transformers = group_to_layers(graph.select(ntype=IRFwOperation)) + # staging + nlayer_per_stage = (len(transformers) // resource.ngpus) + for lid, fnodes in enumerate(transformers): + stage_id = min(lid // nlayer_per_stage, num_stages-1) + _logger.info(f'assigning {lid}-th transformer layter to stage {stage_id}') + for fnode in fnodes: + graph.assign(fnode, stage_id) + # replicate dataloader + for node in graph.select(ntype=IRDataOperation): + replica(graph, node, list(range(resource.ngpus))) + # apply 1f1b schedule + PredefinedSched.sched_1f1b(graph, num_microbatch, num_stages) + return graph diff --git a/examples/vision/swin/policy/mpmd.py b/examples/vision/swin/policy/mpmd.py deleted file mode 100644 index 74850138..00000000 --- a/examples/vision/swin/policy/mpmd.py +++ /dev/null @@ -1,219 +0,0 @@ -from typing import List, Tuple -import numpy as np - -from cube.graph import IRGraph -from cube.graph.function.anchor import IRGraphAnchor -from cube.ir.cten import IRCell -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.segment import IRSegment -from cube.graph.schedule.sched1f1b import IRSchedule1F1B - - -def _create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: - """ - Create hybrid (nested) groups given the each group number. - - The product of group_num should be same with total devices. - - e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: - ( - ( (0,1,2), (3,4,5) ), - ( (0,3), (2,5), (3,6) ), - ) - """ - group_num = np.array(group_num) - cnt = np.prod(group_num) - assert cnt == ngpus, 'total device not match' - grid = np.arange(cnt).reshape(tuple(group_num)) - dims = list(range(len(group_num))) - outputs = [] - for dim, num in enumerate(group_num): - remain = ngpus // num - order = tuple(dims[:dim] + dims[dim+1:] + [dim]) - grid_dim = np.transpose(grid, order).reshape((remain,num)) - grid_dim = grid_dim.tolist() - outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) - assert len(outputs) == len(group_num) - return tuple(outputs) - - -def _group_to_transformers(fnodes) -> List[List[IRCell]]: - # group to transformer layers - transformers: List[List[IRFwOperation]] = [] - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - fnodes[idx+1].comment = f'===> start of transformer layer {lid}' - start = idx if lid != 0 else 0 - end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) - transformers.append(fnodes[start:end]) - for lid in range(len(transformers) - 1): - if transformers[lid][-1].name == 'multiref': - node = transformers[lid].pop() - transformers[lid+1].insert(0, node) - return transformers - -# ========================= parallelisms ================================= - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], **configs): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -def _coshard(graph: IRGraph, node: IRFwOperation, devid: int, **configs): - algo = node.algorithms('dim') - if node.recompute is None: - graph.recompute([node]) - sub_nodes = graph.partition(node, algo, **configs) - assert sub_nodes is not None - for sub_node in sub_nodes: - graph.assign(sub_node, devid) - return sub_nodes - -# ========================= parallelisms ================================= - -def PASRoundRobin(graph: IRGraph, resource): - """ - roundrobin scheduling - """ - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # group to transformer layers - transformers = _group_to_transformers(fnodes) - - for lid, transformer in enumerate(transformers): - stage_id = lid % resource.ngpus - print(f'assigning {lid} transformer to stage {stage_id}') - for node in transformer: - graph.assign(node, stage_id) - - for node in graph.nodes(): - if len(node.device) == 0: - _replica(graph, node, list(range(resource.ngpus))) - - return graph - - -def PAS1F1B(graph: IRGraph, resource): - """ - 1F1B scheduling - """ - num_stage = resource.ngpus - num_microbatch = resource.ngpus * 8 - _, stage_mesh = _create_mesh(resource.ngpus, (num_stage, 1)) - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # group to transformer layers - transformers = _group_to_transformers(fnodes) - - # staging - nlayer_per_stage = (len(transformers) // resource.ngpus) - for lid, fnodes in enumerate(transformers): - stage_id = min(lid // nlayer_per_stage, num_stage-1) - print(f'assigning {lid}-th transformer layter to stage {stage_id}') - for fnode in fnodes: - graph.assign(fnode, stage_id) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - _replica(graph, node, list(range(resource.ngpus))) - - strategy = IRSchedule1F1B(graph, num_microbatch, stage_mesh) - graph.sched = strategy - return graph - - -def PASMegatron(graph: IRGraph, resource): - """ - Megatron policy with Data, Tensor, Pipeline Parallelism. - """ - dp_size = 1 - tp_size = 2 - pp_size = resource.ngpus // (dp_size * tp_size) - # note coshard will only apply to first 4 tranformer blocks - coshard = 2 - recompute: bool = False - num_microbatch = 8 - - # device mesh - dp_groups, pp_groups, tp_groups = \ - _create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) - print(f'dp groups: {dp_groups}') - print(f'pp groups: {pp_groups}') - print(f'tp groups: {tp_groups}') - - def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: - return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] - - # group to transformer layers - transformers = _group_to_transformers(graph.select(ntype=IRFwOperation)) - if recompute: - for transformer in transformers: - graph.recompute(transformer) - - # group to stage: set each stage operators - fstages = [[] for _ in range(pp_size)] - nlayer_per_stage = (len(transformers) // pp_size) - for lid, fnodes in enumerate(transformers): - stage_id = min(lid // nlayer_per_stage, pp_size - 1) - fstages[stage_id] += fnodes - graph.staging(tuple(stages[0] for stages in fstages)) - - dataloader = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[0] - - # partition dataloader - dls = _replica(graph, dataloader, [0]*dp_size) # graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) - for dp_idx, dl in enumerate(dls): - # only stage 0 needs dataloader - devices = [get_device(dp_idx, 0, tp_idx) for tp_idx in range(tp_size)] - _replica(graph, dl, devices) - - tid = 0 - - # staging - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - assert len(fstages) == pp_size - nlayer_per_stage = (len(transformers) // pp_size) - for pp_idx, fstage in enumerate(fstages): - for fnode in fstage.nodes(): - subnodes = [fnode] - if len(fnode.inputs()) == 0: continue # anchor - # tensor parallel -- FIXME: current restriction needs replica happen before partition - if fnode.name == 'window_attn' or fnode.name == 'feedforward': - subnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) - elif fnode.name == 'linear': # the last embeding linear - subnodes = _tp(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) - else: - subnodes = _replica(graph, fnode, [0]*tp_size) - # data parallel - pnodes = [] - for tp_idx, subnode in enumerate(subnodes): - dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] - batch_dim = 0 if bs not in subnode.input(0).shape else subnode.input(0).shape.index(bs) - nodes = _tp(graph, subnode, devs=dp_devices, idx=0, dim=batch_dim, num=dp_size) - pnodes += nodes - subnodes = pnodes - # coshard - if fnode.name in ['window_attn', 'feedforward']: - if coshard > 1 and tid < 4: - for subnode in subnodes: - devid = subnode.device[0] - _coshard(graph, subnode, devid, idx=1, dim=0, num=coshard) - tid = tid + 1 if fnode.name == 'window_attn' else tid - - strategy = IRSchedule1F1B(graph, num_microbatch) - graph.predef_sched(strategy) - return graph diff --git a/examples/vision/swin/policy/spmd.py b/examples/vision/swin/policy/spmd.py deleted file mode 100644 index 815d43ca..00000000 --- a/examples/vision/swin/policy/spmd.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import Dict, List - -from cube.graph import IRGraph -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.dimops import DimopSplit, TransformRule -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.ir.tensor import IRFullTensor, IRSubTensor - - -# ========================= parallelisms ================================= - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], - idx: int, dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - -# coshard -def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, - idx: int, dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) - assert sub_nodes is not None - graph.recompute(sub_nodes) - for devid in devs: - for coid in range(colocate): - sub_node = sub_nodes[devid * colocate + coid] - graph.assign(sub_node, devid) - return sub_nodes - -# ========================= parallelisms ================================= - - - -def PASSingle(graph: IRGraph, resource): - assert resource.ngpus == 1 - # print(graph.extra_repr()) - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - return graph - - -def PASData(graph: IRGraph, resource): - dp_size = resource.ngpus - dp_devs = list(range(dp_size)) - - ftensors: Dict[IRFullTensor, DimopSplit] = dict() # ftensor: producer partition index - - dataloaders = [node for node in graph.nodes() if isinstance(node, IRDataOperation)] - for dataloader in dataloaders: - algo = dataloader.algorithms('data') - subnodes = graph.partition(dataloader, algo, num=dp_size) - for idx, sub_node in enumerate(subnodes): - graph.assign(sub_node, idx) - for oidx, output in enumerate(dataloader.outputs()): - if not isinstance(output, IRSubTensor): - continue - if output.parent not in ftensors: - bdim = dataloader.get_batch_dims()[oidx] - ftensors[output.parent] = DimopSplit.D(bdim) - - - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - for node in fnodes: - if isinstance(node, IRGraphAnchor): - continue - partitioned = False - for iidx, itensor in enumerate(node.inputs()): - if not isinstance(itensor, IRSubTensor): - continue - if itensor.parent in ftensors: - dim = ftensors[itensor.parent] - assert dim.isD(), f"on partitioning node: {node}:\nexpected input to be partitioned on dimensions but found {dim}" - rule: TransformRule = node.algorithms('dim').infer(idx=iidx, dim=dim.dim, num=len(dp_devs)) - # print(rule) - assert rule is not None, f"fail to infer node: {node}, idx={iidx}" - for odim, output in zip(rule.outputs(), node.outputs()): - ftensors[output.parent] = odim - # print(f'==> setting next dim: {odim}') - _tp(graph, node, dp_devs, idx=iidx, dim=dim.dim) - partitioned = True - break - if not partitioned: - print(f'warning: cannot partition of node using dim propagation, use replica instead: {node}') - _replica(graph, node, dp_devs) - - return graph - - -def PASMegatronTP(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # annotating code structure -- not consider multiref on embedding weight - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - # why -1: multiref - fnodes[idx-1].comment = f'===> start of transformer layer {lid}' - - # attention - attns = [node for node in fnodes if node.name == 'window_attn'] - for attn in attns: - _tp(graph, attn, tp_devs, idx=1, dim=0) - - # feedforward - ffns = [node for node in fnodes if node.name == 'feedforward'] - for ffn in ffns: - _tp(graph, ffn, tp_devs, idx=1, dim=0) - - # replicate other nodes - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: - _replica(graph, node, tp_devs) - - return graph - - -def PASMeshShard(graph: IRGraph, resource): - - # print(graph.extra_repr()) - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # annotating code structure -- not consider multiref on embedding weight - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - # why -1: multiref - fnodes[idx-1].comment = f'===> start of transformer layer {lid}' - - # attention - attns = [node for node in fnodes if node.name == 'window_attn'] - for attn in attns: - # _tp(graph, attn, tp_devs, idx=1, dim=0) - _coshard(graph, attn, tp_devs, colocate=2, idx=1, dim=0) - - # feedforward - ffns = [node for node in fnodes if node.name == 'feedforward'] - for ffn in ffns: - # _tp(graph, ffn, tp_devs, idx=1, dim=0) - _coshard(graph, ffn, tp_devs, colocate=4, idx=1, dim=0) - - # replicate other nodes - for node in graph.nodes(): - if isinstance(node, (IRFwOperation, IRDataOperation)) and len(node.device) == 0: - _replica(graph, node, tp_devs) - - # print(graph.extra_repr()) - return graph diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 3b720d52..b293b680 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -1,23 +1,23 @@ """ example: -OMP_NUM_THREADS=4 torchrun \ +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=4 \ - --nnodes=1 \ - examples/vision/swin/train.py --policy PASMegatron --fp16 + examples/vision/swin/train.py --policy PASMegatronTP --fp16 """ import math import torch +from functools import partial from examples.vision.swin.blocks.attention import init_relative_position_index -from examples.vision.swin.model import Config, SwinTransformer, ImageDataLoader +from examples.vision.swin.model import Config, SwinTransformer, get_swin_dummy_dataloader import cube from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary, model_summary +from cube.profiler.memory import memory_summary -import examples.vision.swin.policy.spmd as spmd -import examples.vision.swin.policy.mpmd as mpmd +import examples.vision.swin.policy.gallery as gallery +from examples.utils import get_policy import argparse @@ -25,26 +25,30 @@ parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') +parser.add_argument('--dp', type=int, default=1, + help='data parallel size, only for megatron') +parser.add_argument('--tp', type=int, default=1, + help='tensor parallel size, only for megatron') +# training +parser.add_argument('--gbs', type=int, default=4, help='global batch size') +parser.add_argument('--mbs', type=int, default=4, help='micro batch size') + args = parser.parse_args() cube.init() -PAS = None -policies = list(spmd.__dict__.keys()) + list(mpmd.__dict__.keys()) -policies = [policy for policy in policies if policy.startswith('PAS')] -if args.policy in spmd.__dict__: - PAS = spmd.__dict__[args.policy] - print_each_rank(f'using policy from spmd.{args.policy}') -elif args.policy in mpmd.__dict__: - PAS = mpmd.__dict__[args.policy] - print_each_rank(f'using policy from mpmd.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") +# get policy +policy = get_policy([gallery], args.policy) +policy = partial(policy, + nmicros=args.gbs//args.mbs, + dp_size=args.dp, + tp_size=args.tp +) def train(): - batch_size = 4 + batch_size = args.mbs load_content: bool = False cfg = Config() @@ -52,16 +56,14 @@ def train(): model = model.half() if args.fp16 else model dtype = torch.float16 if args.fp16 else torch.float32 - dataloader = ImageDataLoader(batch_size, cfg.img_size, cfg.num_classes, dtype=dtype) + dataloader = get_swin_dummy_dataloader(batch_size, dtype, cfg) - model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=PAS, override=True, load_content=load_content) + @cube.compile(model, dataloader, PAS=policy, load_content=load_content) def train_iter(model, dataloader): imgs = next(dataloader) loss = model(imgs) loss.backward() - # return loss - model: torch.nn.Module = model.get_gen_module() + model = cube.utils.load_model() if not load_content: for name, buffer in model.named_buffers(): @@ -79,27 +81,23 @@ def train_iter(model, dataloader): nparams += param.nelement() print_each_rank(f'model parameter: {nparams}') - CudaTimer(enable=False).warmup() + CudaTimer().warmup() + dataloader = iter(dataloader) iter_num, warmup = 5, 2 for step in range(iter_num): - - if step >= warmup: + if step == warmup: CudaTimer(enable=True).start('e2e') - # training - loss = train_iter(model, dataloader) - # print(loss) + train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() - if step >= warmup: - CudaTimer().stop('e2e') - if step == 0: print_each_rank('passed first iteration') - if (step + 1) % 4 == 0: + if (step + 1) % 2 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) + CudaTimer().stop('e2e') print_each_rank('e2e time (ms) per iteration: {} ms'.format( CudaTimer().duration(iter_num-warmup, field_name='e2e'))) CudaTimer().print_all(times=iter_num-warmup) diff --git a/examples/wrf/policy/h_halo.py b/examples/wrf/policy/h_halo.py deleted file mode 100644 index a615ce11..00000000 --- a/examples/wrf/policy/h_halo.py +++ /dev/null @@ -1,18 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.function import IRConv2D, IRConv3D - -def PAS(graph: IRGraph, resource): - for node in graph.nodes(): -# graph.assign(node, 0) - if isinstance(node, IRConv3D): - sub_nodes = list() - algo = node.algorithms('halo') - sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - # sub_nodes = graph.replicate(node, times=resource.ngpus) - - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - print(graph.extra_repr()) - return graph diff --git a/examples/wrf/policy/hw_halo.py b/examples/wrf/policy/hw_halo.py deleted file mode 100644 index 48aedc92..00000000 --- a/examples/wrf/policy/hw_halo.py +++ /dev/null @@ -1,22 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.function import IRConv2D, IRConv3D - -def PAS(graph: IRGraph, resource): - for node in graph.nodes(): -# graph.assign(node, 0) - if isinstance(node, IRConv3D): - sub_nodes = list() - algo = node.algorithms('halo') - Wnodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus // 2) - for Wnode in Wnodes: - algo = Wnode.algorithms('halo') - Hnodes = graph.partition(Wnode, algo, idx=0, dim=2, num=2) - sub_nodes += Hnodes - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - # sub_nodes = graph.replicate(node, times=resource.ngpus) - - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - print(graph.extra_repr()) - return graph diff --git a/examples/wrf/policy/naive.py b/examples/wrf/policy/naive.py deleted file mode 100644 index 9d1b7d97..00000000 --- a/examples/wrf/policy/naive.py +++ /dev/null @@ -1,8 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.function import IRConv2D - -def PAS(graph: IRGraph, resource): - for node in graph.nodes(): - graph.assign(node, 0) - print(graph.extra_repr()) - return graph diff --git a/examples/wrf/policy/onedim.py b/examples/wrf/policy/onedim.py deleted file mode 100644 index 0d0ea297..00000000 --- a/examples/wrf/policy/onedim.py +++ /dev/null @@ -1,165 +0,0 @@ -from cube.graph import IRGraph -from cube.graph.function import IRConv2D, IRConv3D -from cube.graph.function import IRDimops, IRPad -from cube.ir.cten import IRTensor, IRCell -from cube.graph.function import IRSelect, IRSelectScatter, IRSlice, IRToTensor, IROnes, IRRand, IRZeros - -def PAS(graph: IRGraph, resource): - for node in graph.nodes(): - graph.assign(node, 0) - print(graph.extra_repr()) - return graph - -global opSigns - -opSigns = [] - -def append_sign(sign: str): - global opSigns - if not sign in opSigns: - opSigns.append(sign) - -def PAS_ALL_TEST(graph: IRGraph, resource): - for node in graph.nodes(): - sign = node.signature.split('.')[-1] - append_sign(sign) - if isinstance(node, IRConv3D): - sub_nodes = list() - algo = node.algorithms('halo') - sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) - elif isinstance(node, IRDimops): - sign = node.signature.split('.')[-1] - if (sign == 'mul' or sign == 'add' or sign == 'sub' or sign == 'div') and (len(node.input(0).shape) == 5 or len(node.input(0).shape) == 3): - algo = node.algorithms('dim') - if len(node.input(0).shape) == 3: - sub_nodes = graph.partition(node, algo, idx=0, dim=1, num=resource.ngpus) - if sub_nodes == None: - sub_nodes = graph.replicate(node, times=resource.ngpus) - elif len(node.input(0).shape) == 5: - sub_nodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus) - if sub_nodes == None: - sub_nodes = graph.replicate(node, times=resource.ngpus) - elif sign == 'view': - print('partition view') - print(node) - algo = node.algorithms('view_simp') - sub_nodes = graph.partition(node, algo, idx=0, dimi=node.input(0).ndims-2, dimo=node.output(0).ndims-2, num=resource.ngpus) - print(sub_nodes) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - elif isinstance(node, IRPad): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.input(0).ndims-2, num=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - print(graph.extra_repr()) - return graph - - -def PAS_ALL_X(graph: IRGraph, resource): - elewise_sign = ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat', 'stack', 'sum', 'sin', 'gt'] - # elewise_sign = ['mul', 'div', 'add', 'sub'] - for node in graph.nodes(): - sign = node.signature.split('.')[-1] - if isinstance(node, IRConv3D): - sub_nodes = list() - algo = node.algorithms('halo') - sub_nodes = graph.partition(node, algo, idx=0, dim=3, num=resource.ngpus) - elif isinstance(node, IRDimops): - if sign in elewise_sign: - ndims = node.input(0).ndims - algo = node.algorithms('dim') - append_sign(ndims) - if ndims == 3 or ndims == 5 or ndims == 2 or ndims == 4: - sub_nodes = graph.partition(node, algo, idx=0, dim=ndims-1, num=resource.ngpus) - if sub_nodes == None: - sub_nodes = graph.replicate(node, times=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - elif sign == 'view': - algo = node.algorithms('view_simp') - if node.input(0).ndims >= 2 and node.output(0).ndims >= 3: - sub_nodes = graph.partition(node, algo, idx=0, dimi=node.input(0).ndims-1, dimo=node.output(0).ndims-1, num=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - # FIXME: Check 'circular' padding, should not be splitted easily - elif isinstance(node, IRSelect) or isinstance(node, IRPad) or isinstance(node, IRSlice) or isinstance(node, IRToTensor): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.input(0).ndims-1, num=resource.ngpus) - elif isinstance(node, IRSelectScatter): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, diml=node.input(0).ndims-1, dimr=node.input(1).ndims-1, num=resource.ngpus) - elif isinstance(node, IROnes) and node.output(0).ndims >= 3: - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-1, num=resource.ngpus) - # elif isinstance(node, IRRand) and node.output(0).ndims >= 3: - # algo = node.algorithms('dim') - # sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-1, num=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - print(graph.extra_repr()) - print(opSigns) - return graph - -def PAS_ALL_Y(graph: IRGraph, resource): - elewise_sign = ['mul', 'div', 'add', 'sub', 'multiref', 'neg', 'pow', 'cat', 'stack', 'sum', 'sin', 'gt'] - # elewise_sign = ['mul', 'div', 'add', 'sub'] - for node in graph.nodes(): - sign = node.signature.split('.')[-1] - if isinstance(node, IRConv3D): - sub_nodes = list() - algo = node.algorithms('halo') - sub_nodes = graph.partition(node, algo, idx=0, dim=2, num=resource.ngpus) - assert sub_nodes != None - elif isinstance(node, IRDimops): - if sign in elewise_sign: - ndims = node.input(0).ndims - algo = node.algorithms('dim') - append_sign(ndims) - if ndims == 3 or ndims == 5 or ndims == 2 or ndims == 4: - sub_nodes = graph.partition(node, algo, idx=0, dim=ndims-2, num=resource.ngpus) - if sub_nodes == None: - sub_nodes = graph.replicate(node, times=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - elif sign == 'view': - algo = node.algorithms('view_simp') - if node.input(0).ndims >= 2 and node.output(0).ndims >= 3: - print(node.input(0).shape, node.output(0).shape) - sub_nodes = graph.partition(node, algo, idx=0, dimi=node.input(0).ndims-2, dimo=node.output(0).ndims-2, num=resource.ngpus) - assert sub_nodes != None - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - elif isinstance(node, IRSelect) or isinstance(node, IRPad) or isinstance(node, IRSlice) or isinstance(node, IRToTensor): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.input(0).ndims-2, num=resource.ngpus) - assert sub_nodes != None - elif isinstance(node, IRSelectScatter): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, diml=node.input(0).ndims-2, dimr=node.input(1).ndims-2, num=resource.ngpus) - assert sub_nodes != None - elif isinstance(node, IROnes) and node.output(0).ndims >= 3: - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-2, num=resource.ngpus) - assert sub_nodes != None - elif isinstance(node, IRZeros) and node.output(0).ndims >= 3: - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-2, num=resource.ngpus) - # elif isinstance(node, IRRand) and node.output(0).ndims >= 3: - # algo = node.algorithms('dim') - # sub_nodes = graph.partition(node, algo, dim=node.output(0).ndims-1, num=resource.ngpus) - else: - sub_nodes = graph.replicate(node, times=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - print(graph.extra_repr()) - print(opSigns) - return graph diff --git a/examples/wrf/wrf.py b/examples/wrf/wrf.py deleted file mode 100644 index 47acdad9..00000000 --- a/examples/wrf/wrf.py +++ /dev/null @@ -1,385 +0,0 @@ -from typing import List - -import torch -torch.set_default_tensor_type(torch.DoubleTensor) -from torch import nn -import torch.nn.functional as F -# from linalg import tridiagonal - -from cube.runtime.syndata import SciLoopVariables -from einops import rearrange -from einops.layers.torch import Rearrange - -torch.jit.script(Rearrange('b c h w -> b h w c')) -print("torch einops 0") -torch.jit.script(Rearrange('(b0 b1 b2) c h w -> b0 b1 b2 h w c', b0=1, b1=1)) -print("torch einops 1") - -import cube -from examples.wrf.policy.hw_halo import PAS - - -device = 'cuda' # - - -def namestr(obj, namespace): - return [name for name in namespace if namespace[name] is obj] - -def init(nx, ny, nz, dx, dy, dz, theta, PREF, Rd, g, p_t=None, p_s=None, u0=None, v0=None, w0=None, device='cuda'): - # spatial discretization - # dx, dy, dz, nx, ny, nz = dx, dy, 1. / (nz + 1), nx, ny, nz - # agnostic variables - P_t = p_t if p_t else torch.ones((1, ny, nx), device=device) * PREF * 0.0 - P_s = p_s if p_s else torch.ones((1, ny, nx), device=device) * PREF - # pressure (nz, ny, nx) - P = torch.linspace(dz, 1 - dz, nz, device=device).view(nz, 1, 1) * \ - (P_s - P_t).view(1, ny, nx) + P_t - # Alpha (nz, ny, nx) - Alpha = Rd / PREF * theta[1:-1] * (P / PREF) ** (-1 / 1.4) - # prognostic variables - # Mu (nz, ny, nx) - Mu = torch.ones((nz, 1, 1), device=device) * (P_s - P_t).view(1, ny, nx) - # Mu_t = (P_s - P_t).view(1, ny, nx) - # Mu_s = (P_s - P_t).view(1, ny, nx) - # Phi (nz - 1, ny, nx) - Phi = torch.zeros((nz + 1, ny, nx), device=device) - Phi[:-1] = Mu * Alpha * dz - for i in range(nz - 1, -1, -1): - Phi[i] += Phi[i + 1] - Phi_t = Phi[0].view(1, ny, nx) - Phi_s = Phi[-1].view(1, ny, nx) - Phi = Phi[1:-1] - # Theta (nz, ny, nx) - theta_t = theta[0].view(1, ny, nx) - theta_s = theta[-1].view(1, ny, nx) - Theta = theta[1:-1] * Mu - # U (nz, ny, nx + 1) - U = u0 if u0 is not None else torch.zeros((nz, ny, nx + 1), device=device) - # V (nz, ny + 1, nx) - V = v0 if v0 is not None else torch.zeros((nz, ny + 1, nx), device=device) - # W (nz - 1, ny, nx) - W = w0 if w0 is not None else torch.zeros((nz - 1, ny, nx), device=device) - - # for var in [U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s]: - # print("### {} shape {}".format(namestr(var, globals()), var.shape)) - return U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s - -class WRF(torch.nn.Module): - - def __init__(self): - super().__init__() - self.bar_x_pre = Rearrange('(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # x = X.view(1, 1, Nz, Ny, Nx) - self.bar_x_post = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') - - self.pad_y_pre = Rearrange('(b0 Nz) Ny Nx -> b0 Nz Ny Nx', b0=1) #X.view(1, Nz, Ny, Nx) - self.pad_y_post = Rearrange('b0 Nz Ny Nx -> (b0 Nz) Ny Nx') #.view(Nz, Ny + 2, Nx) - - self.delta_z_pre = Rearrange('(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / dz - self.delta_z_post = Rearrange('b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') - - self.delta_x_pre = self.delta_z_pre - self.delta_x_post = self.delta_z_post - - - # def forward(self, dt): - # self.U, self.V, self.W, self.Theta, self.Mu, self.Phi = \ - # self.RK3_step(self.U, self.V, self.W, self.Theta, self.Mu, self.Phi, dt) - def forward(self, - U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, - dx, dy, dz, - dt, PREF, Rd, g, - bar_x_filter, delta_z_filter): - - # U, V, W, Theta, Mu, Phi = self.RK3_step(U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter) - U = self.RK3_step(U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, - Rd, g, bar_x_filter, delta_z_filter) - return U, V, W, Theta, Mu, Phi - - def RHS(self, - U, V, W, Theta, Mu, Phi, - Phi_t, Phi_s, - P_t, P_s, dx, dy, dz, - PREF, Rd, g, - bar_x_filter, - delta_z_filter): - - delta_x_filter = bar_x_filter - - # volecity - u = U / self.bar_x(self.pad_x(Mu), bar_x_filter) - # v = V / self.bar_y(self.pad_y(Mu), bar_y_filter) - # w = W / self.bar_z(Mu) - alpha = self.delta_z(self.pad_z(Phi, Phi_t, Phi_s), dz, delta_z_filter) / Mu - #TODO recover me alpha = -self.delta_z(self.pad_z(Phi, Phi_t, Phi_s), dz, delta_z_filter) / Mu - Alpha = alpha - theta = Theta / Mu - p = theta #TODO p = PREF * (Rd * theta / PREF / alpha)**1.4 - # omega = -w * g / self.bar_z(alpha) / self.bar_z(Mu) - # Omega = omega * self.bar_z(Mu) - #Omega = Omega - - # advection term - R_U = self.delta_x(self.bar_x(self.pad_x(U), bar_x_filter) * self.bar_x(self.pad_x(u), bar_x_filter), dx, delta_x_filter) - #- self.delta_x(self.bar_x(self.pad_x(U), bar_x_filter) * self.bar_x(self.pad_x(u), bar_x_filter), dx, delta_x_filter) - # \ - # - self.delta_y(self.bar_x(self.pad_x(V), bar_x_filter) * self.bar_y(self.pad_y(u)), dy) \ - # - self.delta_z(self.bar_x(self.pad_x(self.pad_z(Omega)), bar_x_filter) * self.bar_z(self.pad_z(u)), dz) - - # R_V = - self.delta_x(self.bar_y(self.pad_y(U)) * self.bar_x(self.pad_x(v)), dx) \ - # - self.delta_y(self.bar_y(self.pad_y(V)) * self.bar_y(self.pad_y(v)), dy) \ - # - self.delta_z(self.bar_y(self.pad_y(self.pad_z(Omega))) * self.bar_z(self.pad_z(v)), dz) - # - # R_W = - self.delta_x(self.bar_z(U) * self.bar_x(self.pad_x(w)), dx) \ - # - self.delta_y(self.bar_z(V) * self.bar_y(self.pad_y(w)), dy) \ - # - self.delta_z(self.bar_z(self.pad_z(Omega)) * self.bar_z(self.pad_z(w)), dz) - - # R_Theta = - self.delta_x(U * self.bar_x(self.pad_x(theta)), dx) \ - # - self.delta_y(V * self.bar_y(self.pad_y(theta)), dy) \ - # - self.delta_z(self.pad_z(Omega) * self.bar_z(self.pad_z(theta)), dz) - # - # R_Phi = - self.bar_z(self.bar_x(u)) * self.delta_x(self.bar_x(self.pad_x(Phi)), dx) \ - # - self.bar_z(self.bar_y(v)) * self.delta_y(self.bar_y(self.pad_y(Phi)), dy) \ - # - Omega * self.delta_z(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s)), dz) - # - # R_Mu = - self.delta_x(U, dx) - self.delta_y(V, dy) - self.delta_z(self.pad_z(Omega), dz) - - # pressure term - R_U = R_U - self.bar_x(self.pad_x(Mu), bar_x_filter) * self.bar_x(self.pad_x(alpha), bar_x_filter) * self.delta_x(self.pad_x(p), dx, delta_x_filter) - # \ - # - (self.delta_z(self.bar_x(self.bar_z(self.pad_x(self.pad_z(p, P_t, P_s))), bar_x_filter), dz) * - # self.delta_x(self.pad_x(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s))), dx)) - - # R_V = R_V - self.bar_y(self.pad_y(Mu)) * self.bar_y(self.pad_y(alpha)) * self.delta_y(self.pad_y(p), dy) \ - # - (self.delta_z(self.bar_y(self.bar_z(self.pad_y(self.pad_z(p, P_t, P_s)))), dz) * - # self.delta_y(self.pad_y(self.bar_z(self.pad_z(Phi, Phi_t, Phi_s))), dy)) - # - # R_W = R_W + g * (self.delta_z(p, dz) - self.bar_z(Mu)) - - # # gravity term - # R_Phi = R_Phi + g * w - - # Coriolis term - # R_U += + 100 * self.bar_x(self.bar_y(self.pad_x(V))) \ - # - 100 * self.bar_x(self.bar_z(self.pad_x(self.pad_z(W)))) \ - # - u * self.bar_x(self.bar_z(self.pad_x(self.pad_z(W)))) / 6400. / 1000. - # R_V += - 100 * self.bar_x(self.bar_y(self.pad_y(U))) \ - # + 100 * self.bar_y(self.bar_z(self.pad_y(self.pad_z(W)))) \ - # - v * self.bar_y(self.bar_z(self.pad_y(self.pad_z(W)))) / 6400. / 1000. - - return R_U #, R_V, R_W, R_Theta, R_Mu, R_Phi #, Alpha, Omega - - # def RK3_step(self, U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, dt): - def RK3_step(self, U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter): - r"""One RK3 Step""" - # R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter) - R_U = self.RHS(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter) - U_ = U + dt * R_U / 3 - # V_ = V + dt * R_V / 3 - # W_ = W + dt * R_W / 3 - # Theta_ = Theta + dt * R_Theta / 3 - # Mu_ = Mu + dt * R_Mu / 3 - # Phi_ = Phi + dt * R_Phi / 3 - - # R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) - # U_ = U + dt * R_U / 2 - # V_ = V + dt * R_V / 2 - # W_ = W + dt * R_W / 2 - # Theta_ = Theta + dt * R_Theta / 2 - # Mu_ = Mu + dt * R_Mu / 2 - # Phi_ = Phi + dt * R_Phi / 2 - # - # R_U, R_V, R_W, R_Theta, R_Mu, R_Phi = self.RHS(U_, V_, W_, Theta_, Mu_, Phi_, Phi_t, Phi_s, P_t, P_s, dx, dy, dz, PREF, Rd, g) - U = U + R_U #TODO U = U + dt * R_U - # V = V + dt * R_V - # W = W + dt * R_W - # Theta = Theta + dt * R_Theta - # Mu = Mu + dt * R_Mu - # Phi = Phi + dt * R_Phi - - return U #TODO , V, W, Theta, Mu, Phi - - def pad_x(self, X): - r"""Periodic boundary condition in x axis""" - return F.pad(X, (1, 1), "circular") - - def pad_y(self, X): - r"""Periodic boundary condition in y axis""" - # Nz, Ny, Nx = X.shape - # return F.pad(X.view(1, Nz, Ny, Nx), (0, 0, 1, 1), "circular").view(Nz, Ny + 2, Nx) - x_ext = self.pad_y_pre(X) - x_pad = F.pad(x_ext, (0, 0, 1, 1), "circular") - x_unext = self.pad_y_post(x_pad) - return x_unext - - # TODO def pad_z(self, X, top=None, surface=None): - def pad_z(self, X, top=torch.Tensor(), surface=torch.Tensor()): - r"""Dirichlet boundary condition in z axis""" - # _, ny, nx = X.shape - # top = torch.zeros((1, ny, nx), device=X.device) #TODO top = top if top is not None else torch.zeros((1, ny, nx), device=X.device) - # surface = torch.zeros((1, ny, nx), device=X.device) #TODO surface = surface if surface is not None else torch.zeros((1, ny, nx), device=X.device) - # return torch.cat((top, X, surface), dim=0) - return F.pad(X, (0, 0, 0, 0, 1, 1), "constant", 0.) - - def bar_x(self, X, filter): - r"""Numerical scheme for X\bar^x - - Args: - X (Tensor): shape (Nz, Ny, Nx) - - Returns: - Tensor: X\bar^x with shape (Nz, Ny, Nx-1) - """ - # Nz, Ny, Nx = X.shape - # filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) # filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) - #TODO return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / 2. - x = self.bar_x_pre(X) #x = rearrange(X, '(b0 b1 Nz) Ny Nx -> b0 b1 Nz Ny Nx', b0=1, b1=1) # x = X.view(1, 1, Nz, Ny, Nx) - convx = F.conv3d(x, filter) - convx2 = self.bar_x_post(convx) #rearrange(convx, 'b0 b1 Nz Ny Nx -> (b0 b1 Nz) Ny Nx') # convx2 = convx.view(Nz, Ny, Nx - 1) - convx3 = convx2 #TODO recover me / 2. - # convx3 = X # - return convx3 - - def delta_x(self, X, dx, filter): - r"""Numerical scheme for \delta_x X - - Args: - X (Tensor): shape (Nz, Ny, Nx) - - Returns: - Tensor: \delta_x X with shape (Nz, Ny, Nx-1) - """ - # Nz, Ny, Nx = X.shape - # filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) - # return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny, Nx - 1) / dx - x_ext = self.delta_x_pre(X) - x_conv = F.conv3d(x_ext, filter) - x_unext = self.delta_x_post(x_conv) - return x_unext #TODO / dx - - - - def bar_y(self, X, filter): - r"""Numerical scheme for X\bar^y - - Args: - X (Tensor): shape (Nz, Ny, Nx) - - Returns: - Tensor: X\bar^y with shape (Nz, Ny-1, Nx) - """ - # Nz, Ny, Nx = X.shape - # filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) - # return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / 2. - - def delta_y(self, X, dy): - r"""Numerical scheme for \delta_y X - - Args: - X (Tensor): shape (Nz, Ny, Nx) - - Returns: - Tensor: \delta_y X with shape (Nz, Ny-1, Nx) - """ - Nz, Ny, Nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 2, 1) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz, Ny - 1, Nx) / dy - - def bar_z(self, X): - r"""Numerical scheme for X\bar^z - - Args: - X (Tensor): shape (Nz, Ny, Nx) - - Returns: - Tensor: X\bar^z with shape (Nz-1, Ny, Nx) - """ - Nz, Ny, Nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) - return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / 2. - - def delta_z(self, X, dz, filter): - r"""Numerical scheme for \delta_z X - - Args: - X (Tensor): shape (Nz, Ny, Nx) - - Returns: - Tensor: \delta_z X with shape (Nz-1, Ny, Nx) - """ - # Nz, Ny, Nx = X.shape - # filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) - # return F.conv3d(X.view(1, 1, Nz, Ny, Nx), filter).view(Nz - 1, Ny, Nx) / dz - x_ext = self.delta_z_pre(X) - x_conv = F.conv3d(x_ext, filter) - x_unext = self.delta_z_post(x_conv) - return x_unext #TODO / dz - - def _acoustic_step(self, ): - r"""One acustic step""" - pass - - -if __name__ == "__main__": - cube.init() - - # simulation settings - nx = 201 - ny = 201 - nz = 201 - dx = 1e3 # m - dy = 1e3 # m - dz = 1. / (nz + 1) - # constants - PREF = torch.tensor(1e5) # reference pressure, usually sea level pressure, Pa - Rd = torch.tensor(287) # gas constant for dry air, J/(kg*K) - g = torch.tensor(9.81) # the acceleration of gravity, m/s**2 - - x0 = 100e3 - y0 = 100e3 - grid_x, grid_y = torch.meshgrid(torch.linspace(0, 200e3, 201), torch.linspace(0, 200e3, 201)) - # 100K - theta = torch.linspace(0, 1, nz + 2).view(nz + 2, 1, 1) * torch.ones((1, ny, nx)) * 600. + 300 - theta += torch.linspace(1, 0, nz + 2).view(nz + 2, 1, 1) * \ - -100. * torch.exp(-0.5 * ((grid_x - x0)**2 + (grid_y - y0)**2) / 400e6).view(1, ny, nx) - # u0 = torch.ones((nz, ny, nx + 1)).cuda() - # wrf = WRF(dx, dy, nx, ny, nz, theta.cuda()) - theta = theta.cuda() - - dt = torch.tensor(0.1) - # nx = torch.tensor(nx) - # ny = torch.tensor(ny) - # nz = torch.tensor(nz) - dx = torch.tensor(dx) - dy = torch.tensor(dy) - dz = torch.tensor(dz) - - U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s = init(nx, ny, nz, dx, dy, dz, theta, PREF, Rd, g) - bar_x_filter = torch.tensor([1., 1.]).view(1, 1, 1, 1, 2) - delta_z_filter = torch.tensor([1., 1.]).view(1, 1, 2, 1, 1) - - varloader = SciLoopVariables(variables=[U, V, W, Theta, Mu, Phi, dt], constants=[Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter, delta_z_filter]) - model = WRF() - model = cube.SemanticModel(model, input_shapes=tuple(varloader.shapes), ) - - @cube.compile(model=model, dataloader=varloader, PAS=PAS, override=True) - def train_iter(model, dataloader): - U, V, W, Theta, Mu, Phi, dt, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, PREF, Rd, g, bar_x_filter , delta_z_filter= next(dataloader) - U, V, W, Theta, Mu, Phi = model(U, V, W, Theta, Mu, Phi, Phi_t, Phi_s, theta_t, theta_s, P_t, P_s, dx, dy, dz, dt, PREF, Rd, g, bar_x_filter, delta_z_filter) - return U, V, W, Theta, Mu, Phi - model = model.get_gen_module() - - import matplotlib.pyplot as plt - import numpy as np - - for iter in range (3): # while True: - # plt.cla() - # cf = plt.contourf(wrf.Theta[:, 100, :].cpu().numpy(), levels=50, cmap='jet') - # cb = plt.colorbar(cf) - # plt.savefig('res.jpeg', dpi=300) - # plt.clf() - # input('stop') - - # for i in range(1): - print("iter-{}...".format(iter)) - U, V, W, Theta, Mu, Phi = train_iter(model, varloader) # Phi_t, Phi_s, theta_t, theta_s - diff --git a/examples/wrf/wrf2.py b/examples/wrf/wrf2.py deleted file mode 100644 index 044d81c2..00000000 --- a/examples/wrf/wrf2.py +++ /dev/null @@ -1,527 +0,0 @@ -import torch -import torch.nn.functional as F - -from cube.runtime.syndata import SciLoopVariables -from cube.profiler.timer import CudaTimer, print_each_rank -from examples.wrf.policy.naive import PAS -import examples.wrf.policy.onedim as onedim - -torch.set_default_tensor_type(torch.DoubleTensor) - -import cube - -import argparse - -parser = argparse.ArgumentParser(description='') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') -args = parser.parse_args() - - -cube.init() -chosenPAS = PAS -policies = list(onedim.__dict__.keys()) -policies = [policy for policy in policies if policy.startswith('PAS')] -if args.policy in onedim.__dict__: - chosenPAS = onedim.__dict__[args.policy] - print_each_rank(f'using policy from onedim.{args.policy}') -else: - raise ValueError(f"policy {args.policy} not found. Candidates: {policies}") - -class WRF(torch.nn.Module): - def __init__(self, dt, ntau, nz, ny, nx, dz, dy, dx, device): - super().__init__() - # simulation domain settings - self.dt = dt - self.ntau = ntau - self.nx = nx - self.ny = ny - self.nz = nz - self.delta_x = dx - self.delta_y = dy - self.delta_z = dz - - # physics constant - self.g = 9.8 # acceleration of gravity, unit in m/s^2 - self.GAMMA = 1.4 # the ratio of heat capacities for dry air - self.PREF = 101325. # sea level pressure, unit in Pa - self.RD = 287. # gas constant for dry air J*kg^-1*K^-1 - self.RE = 6.4e6 # radius of earth, unit in m - self.OMEGA = 7.292e-5 # angular speed of the Earth s^-1 - - self.device = torch.device(device) - - # These three fields are to control the size of the unrolled graph - # by faking the loop upper bounds (UB), - # and they are related to the three layers of the nested loops, respectively. - # - # By setting to -1 we can recover the original loop upper bound. - # - # NOTE The magnitude is almost decided by `ntau` only. The final graph size may vary - # from ~4k (all fake UBs are 1) to ~23k (all fake UBs are -1, i.e. the original) - self._step_fake_ntau = 1 - self._ac_step_fake_ub = 1 - self._solver_fake_ub = 1 - - - def init(self, theta, Ptop=250e2): - eta = torch.linspace(0, 1, self.nz + 1, device=self.device) - pi = self.PREF - Ptop - p0 = Ptop + pi * eta - self.p0 = ((p0[:-1] + p0[1:]) / 2).view(self.nz, 1, 1) * torch.ones((1, self.ny, self.nx), device=self.device) - - self.mu0 = torch.ones((self.nz, self.ny, self.nx), device=self.device) * pi - mu1 = torch.zeros((self.nz, self.ny, self.nx), device=self.device) - - self.alpha0 = (self.RD * theta) / self.PREF * (self.p0 / self.PREF)**(-1. / self.GAMMA) - - phi0 = torch.zeros((self.nz + 1, self.ny, self.nx), device=self.device) - phi1 = torch.zeros((self.nz - 1, self.ny, self.nx), device=self.device) - phi0[-1] = self.alpha0[-1] * self.mu0[-1] - for i in range(self.nz - 1, -1, -1): - phi0[i] = self.alpha0[i] * self.mu0[i] * self.delta_z + phi0[i + 1] - self.phi0 = phi0[1:-1] # phi0 with shape (nz - 1, ny, nx) - self.phit = phi0[0].view(1, self.ny, self.nx) # model top hight - self.phis = phi0[-1].view(1, self.ny, self.nx) # earth surface hight - - self.ztop = (self.phit / self.g).view(1, self.ny, self.nx) - - Theta = theta * self.mu0 - - return Theta, phi1, mu1 - - def forward(self, U, V, W, O, Theta, phi1, mu1): - r""" - Args: - U (Tensor): (nz, ny, nx - 1) - V (Tensor): (nz, ny - 1, nx) - W (Tensor): (nz - 1, ny, nx) - O (Tensor): (nz - 1, ny, nx) - Theta (Tensor): (nz, ny, nx) - phi1 (Tensor): (nz - 1, ny, nx) - mu1 (Tensor): (nz, ny, nx) - """ - R_U, R_V, R_W, R_Theta, R_phi, R_mu = self.RHS(U, V, W, O, Theta, phi1, mu1) - U_, V_, W_, O_, Theta_, phi1_, mu1_ = \ - self.step(self.dt / 3, 1, - U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) - - R_U, R_V, R_W, R_Theta, R_phi, R_mu = self.RHS(U_, V_, W_, O_, Theta_, phi1_, mu1_) - # U_, V_, W_, O_, Theta_, phi1_, mu1_ = \ - # self._step(self.dt / 2, self.ntau // 2, - # U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) - U_, V_, W_, O_, Theta_, phi1_, mu1_ = \ - self.step(self.dt / self.ntau, self.ntau // 2, - U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) - - R_U, R_V, R_W, R_Theta, R_phi, R_mu = self.RHS(U_, V_, W_, O_, Theta_, phi1_, mu1_) - U, V, W, O, Theta, phi1, mu1 = \ - self.step(self.dt / self.ntau, self.ntau, - U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) - # U, V, W, O, Theta, phi1, mu1 = \ - # self._step(self.dt, self.ntau, - # U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu) - - # print() - # print('R_U\t', R_U.abs().min(), R_U.abs().max()) - # print('R_V\t', R_V.abs().min(), R_V.abs().max()) - # print('R_W\t', R_W.abs().min(), R_W.abs().max()) - # print('R_phi\t', R_phi.abs().min(), R_phi.abs().max()) - # print('R_mu\t', R_mu.abs().min(), R_mu.abs().max()) - # print('R_Theta\t', R_Theta.abs().min(), R_Theta.abs().max()) - - return U, V, W, O, Theta, phi1, mu1 - - def _step(self, dtau, ntau, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu): - U += R_U * dtau - V += R_V * dtau - W += R_W * dtau - Theta += R_Theta * dtau - phi1 += R_phi * dtau - mu1 += R_mu * dtau - - phi = phi1 + self.phi0 - O = self.g * W / self.dz(self.bz(self.pzphi(phi))) - - return U, V, W, O, Theta, phi1, mu1 - - def step(self, dtau:float, ntau:int, U, V, W, O, Theta, phi1, mu1, R_U, R_V, R_W, R_Theta, R_phi, R_mu): - # initialize perturbed varibles - U2 = torch.zeros(U.shape, device=self.device) - V2 = torch.zeros(V.shape, device=self.device) - W2 = torch.zeros(W.shape, device=self.device) - O2 = torch.zeros(O.shape, device=self.device) - Theta2 = torch.zeros(Theta.shape, device=self.device) - phi2 = torch.zeros(phi1.shape, device=self.device) - mu2 = torch.zeros(mu1.shape, device=self.device) - pi2 = torch.zeros((self.ny, self.nx), device=self.device) - - phi = self.phi0 + phi1 - mu = self.mu0 + mu1 - alpha = - self.dz(self.pzphi(phi)) / mu - p = self.PREF * (self.RD * Theta / mu / self.PREF / alpha)**self.GAMMA - - ntau = self._step_fake_ntau if self._step_fake_ntau >= 0 else ntau - - for i in range(ntau): - U2, V2, W2, O2, Theta2, phi2, mu2, pi2 = \ - self.ac_step(dtau, - U2, V2, W2, O2, Theta2, phi2, mu2, pi2, - R_U, R_V, R_W, R_Theta, R_phi, R_mu, - U, V, Theta, phi, mu, alpha, p) - - return U + U2, V + V2, W + W2, O + O2, Theta + Theta2, phi1 + phi2, mu1 + mu2 - - def ac_step(self, dtau:float, - U2, V2, W2, O2, Theta2, phi2, mu2, pi2, - R_U, R_V, R_W, R_Theta, R_phi, R_mu, - U, V, Theta, phi, mu, alpha, p): - r"""one acoustic step""" - # diagnostic variables - alpha2 = - (self.dz(self.pz(phi2)) + alpha * mu2) / mu - cs2 = self.GAMMA * p * alpha # square of sound speed - C = cs2 / mu / alpha**2 - p2 = self.GAMMA * p * (Theta2 / Theta - alpha2 / alpha - mu2 / mu) - theta = Theta / mu - - # prognostic variables - U2_ = U2 + dtau * ( - R_U - self.bx(mu) * ( - self.bx(alpha) * self.dx(p2) + - self.bx(alpha2) * self.dx(self.p0) + - self.dx(self.bz(self.pz(phi2)))) - - self.dx(self.bz(self.pzphi(phi))) * (self.dz(self.bx(self.pzp1(self.bz(p2)))) - self.bx(mu2)) - ) - V2_ = V2 + dtau * ( - R_V - self.by(mu) * ( - self.by(alpha) * self.dy(p2) + - self.by(alpha2) * self.dy(self.p0) + - self.dy(self.bz(self.pz(phi2)))) - - self.dy(self.bz(self.pzphi(phi))) * (self.dz(self.by(self.pzp1(self.bz(p2)))) - self.by(mu2)) - ) - - # W2_ = W2 + dtau * R_W - # O2_ = self.g * W2_ / self.dz(self.bz(self.pzphi(phi))) - # mu2_ = mu2 + dtau * R_mu - - dpi2 = - (self.dx(self.px(U2_ + U)) + self.dy(self.py(V2_ + V))).sum(0) * self.delta_z - pi2 = pi2 + dpi2 * dtau - - O2_ = torch.zeros(O2.shape, device=O2.device) - mu2_ = torch.zeros(mu2.shape, device=mu2.device) - - _ctrl_O2_ub = self._ac_step_fake_ub + 1 if self._ac_step_fake_ub >= 0 else O2.shape[0] + 1 - for i in range(1, _ctrl_O2_ub): - sub = i * self.delta_z * dpi2 + \ - (self.dx(self.px(U2_)) + self.dy(self.py(V2_)) - R_mu)[-i:].view( - -1, self.ny, self.nx).sum(0) * self.delta_z - O2_ = O2_.select_scatter(sub, dim=0, index=-i) - - _ctrl_mu2_ub = self._ac_step_fake_ub if self._ac_step_fake_ub >= 0 else mu2.shape[0] - for i in range(_ctrl_mu2_ub): - mu2_ = mu2_.select_scatter(pi2, dim=0, index=i) - - # self.O2_ = O2_ - - Theta2_ = Theta2 + dtau * ( - R_Theta - - self.dx(self.px(U2_ * self.bx(theta))) - - self.dy(self.py(V2_ * self.by(theta))) - - self.dz(self.pz(O2_ * self.bz(theta))) - ) - # print('Theta2_:\t', Theta2_.min(), Theta2_.max()) - # Theta2_ = torch.zeros(Theta2_.shape, device=Theta2_.device) - - W2_ = self.solve_tridiagonal_( - phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, - O2_, C, Theta2_, mu2_) - - #if torch.abs(f(W2_)).max() > 1e-6: - # print("Triangular solver warning:\t", torch.abs(f(W2_)).max()) - W2_ = W2_ / (1 + self.damping(phi, 0.2, self.ztop * 0.75) * dtau) - # print((1 + self.damping(phi, 0.8, self.ztop * 0.75) * dtau)[:, 64, 64]) - - # W2_ = W2 + dtau * R_W - - phi2_ = phi2 + dtau * ( - R_phi - (O2_ * self.dz(self.bz(self.pzphi(phi))) - self.g * (W2_ + W2) * 0.5) / self.bz(mu)) - # phi2_ = phi2 + dtau * R_phi - - return U2_, V2_, W2_, O2_, Theta2_, phi2_, mu2_, pi2 - - def damping(self, phi, gamma:float, zd): - z = phi / self.g - res = gamma * torch.sin(torch.pi / 2 * (1 - (self.ztop - z) / (self.ztop - zd)))**2 - return res * z.gt(zd).double() - - def RHS(self, U, V, W, O, Theta, phi1, mu1): - mu = self.mu0 + mu1 - phi = self.phi0 + phi1 - alpha = - self.dz(self.pzphi(phi)) / mu - alpha1 = alpha - self.alpha0 - theta = Theta / mu - p = self.PREF * (self.RD * theta / self.PREF / alpha)**self.GAMMA - p1 = p - self.p0 - - u = U / self.bx(mu) - v = V / self.by(mu) - w = W / self.bz(mu) - - R_U = ( - # pressure term - - self.bx(mu) * ( - self.dx(self.bz(self.pz(phi1))) - + self.bx(alpha) * self.dx(p1) - + self.bx(alpha1) * self.dx(self.p0)) - - self.dx(self.bz(self.pzphi(phi))) * (self.dz(self.bx(self.pzp1(self.bz(p1)))) - self.bx(mu1)) - # advection term - - self.dx(self.bx(self.px(U * u))) - - self.dy(self.bx(self.py(V * self.by(self.bx(self.px(u)))))) - - self.dz(self.bx(self.pz(O * self.bz(self.bx(self.px(u)))))) - ) - R_V = ( - # pressure term - - self.by(mu) * ( - self.dy(self.bz(self.pz(phi1))) - + self.by(alpha) * self.dy(p1) - + self.by(alpha1) * self.dy(self.p0)) - - self.dy(self.bz(self.pzphi(phi))) * (self.dz(self.by(self.pzp1(self.bz(p1)))) - self.by(mu1)) - # advection term - - self.dx(self.by(self.px(U * self.bx(self.by(self.py(v)))))) - - self.dy(self.by(self.py(V * v))) - - self.dz(self.by(self.pz(O * self.bz(self.by(self.py(v)))))) - ) - R_W = ( - # pressure term - #+ self.g * (self.dz(p1) - self.bz(self.mu0) * 0.0) - self.bz(mu1) * self.g - self.g * (self.dz(p1) - self.bz(self.mu0) * 0.0) - self.bz(mu1) * self.g - # advection term - - self.dx(self.px(self.bz(U) * self.bx(w))) - - self.dy(self.py(self.bz(V) * self.by(w))) - - self.dz(self.bz(self.pz(O * w))) - ) - R_Theta = ( - - self.dx(self.px(U * self.bx(theta))) - - self.dy(self.py(V * self.by(theta))) - - self.dz(self.pz(O * self.bz(theta))) - ) - R_phi = ( - # advection term - - self.bx(self.px(self.bz(U) * self.dx(phi))) - - self.by(self.py(self.bz(V) * self.dy(phi))) - - O * self.dz(self.bz(self.pzphi(phi))) - # gravity term - + self.g * W - ) / self.bz(mu) - R_mu = ( - # advection term - - self.dx(self.px(U)) - - self.dy(self.py(V)) - - self.dz(self.pz(O)) - ) - - return R_U, R_V, R_W, R_Theta, R_phi, R_mu - - def dx(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 1, 2) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) / self.delta_x - - def dy(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 1, 2, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) / self.delta_y - - def dz(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([-1., 1.], device=X.device).view(1, 1, 2, 1, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) / self.delta_z - - def px(self, X): - return F.pad(X, (1, 1), "circular") - - def py(self, X): - nz, ny, nx = X.shape - return F.pad(X.view(1, nz, ny, nx), (0, 0, 1, 1), "constant").view(nz, ny + 2, nx) - - def pz(self, X): - nz, ny, nx = X.shape - return F.pad(X.view(1, 1, nz, ny, nx), (0, 0, 0, 0, 1, 1), "constant").view(nz + 2, ny, nx) - - def pzphi(self, X): - """pad phi in z axis""" - return torch.cat((self.phit, X, self.phis), 0) - - def pzp1(self, X): - """pad p1 in z axis""" - nz, ny, nx = X.shape - p1t = torch.zeros((1, ny, nx), device=X.device) - p1s = X[-1].view(1, ny, nx) - return torch.cat((p1t, X, p1s), 0) - - def bx(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 1, 2) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny, nx - 1) / 2. - - def by(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 1, 2, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz, ny - 1, nx) / 2. - - def bz(self, X): - nz, ny, nx = X.shape - filter = torch.tensor([1., 1.], device=X.device).view(1, 1, 2, 1, 1) - return F.conv3d(X.view(1, 1, nz, ny, nx), filter).view(nz - 1, ny, nx) / 2. - - def tridiagonal_system(self, - phi2, dtau:float, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, - O2_, C, Theta2_, mu2_, - x): - phi2_ = phi2 + dtau * ( - R_phi - (O2_ * self.dz(self.bz(self.pzphi(phi))) - self.g * (x + W2) * 0.5) / self.bz(mu)) - return ( - R_W + ( - self.dz(C * self.dz(self.pz(phi2_))) + self.dz(self.GAMMA * p * Theta2_ / Theta) - self.bz(mu2_) + - self.dz(C * self.dz(self.pz(phi2))) + self.dz(self.GAMMA * p * Theta2 / Theta) - self.bz(mu2) - ) * 0.5 * self.g - ) * dtau + W2 - x - - def solve_tridiagonal_(self, - phi2, dtau:float, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, - O2_, C, Theta2_, mu2_): - b = - self.tridiagonal_system( - phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, - O2_, C, Theta2_, mu2_, - torch.zeros((self.nz - 1, self.ny, self.nx), device=self.device)) - - idx0 = torch.tensor([1., 0., 0.], device=self.device).view(3, 1, 1) - idx0 = idx0.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] - r0 = self.tridiagonal_system( - phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, - O2_, C, Theta2_, mu2_, - idx0) + b - - idx1 = torch.tensor([0., 1., 0.], device=self.device).view(3, 1, 1) - idx1 = idx1.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] - r1 = self.tridiagonal_system( - phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, - O2_, C, Theta2_, mu2_, - idx1) + b - - idx2 = torch.tensor([0., 0., 1.], device=self.device).view(3, 1, 1) - idx2 = idx2.repeat((self.nz - 1) // 3 + 1, self.ny, self.nx)[:self.nz - 1] - r2 = self.tridiagonal_system( - phi2, dtau, R_phi, phi, W2, mu, R_W, p, Theta, Theta2, mu2, - O2_, C, Theta2_, mu2_, - idx2) + b - - d = (torch.stack([r0, r1, r2], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1) - l = (torch.stack([r2, r0, r1], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1)[1:] - u = (torch.stack([r1, r2, r0], 1) * torch.stack([idx0, idx1, idx2], 1)).sum(1)[:-1] - - # forward sweep - - _ctrl_d_ub = self._solver_fake_ub + 1 if self._solver_fake_ub >= 0 else d.shape[0] - for i in range(1, _ctrl_d_ub): - w = l[i - 1] / d[i - 1] - - d_i = d[i] - w * u[i - 1] - b_i = b[i] - w * b[i - 1] - - d = d.select_scatter(d_i, dim=0, index=i) - b = b.select_scatter(b_i, dim=0, index=i) - - # backward substitution - x = torch.zeros(b.shape, device=b.device) - x = x.select_scatter(b[-1] / d[-1], dim=0, index=-1) - - _ctrl_x_range_start = self._solver_fake_ub - 1 if self._solver_fake_ub >= 0 else x.shape[0] - 2 - for i in range(_ctrl_x_range_start, -1, -1): - x = x.select_scatter( (b[i] - u[i] * x[i + 1]) / d[i], dim=0, index=i) - - return x - - -if __name__ == "__main__": - import matplotlib.pyplot as plt - from matplotlib.ticker import ScalarFormatter - - cube.init() - - nz = 16 - dz = 1. / 16 - ny = 128 - dy = 1e3 - nx = 128 - dx = 1e3 - - dt = 1. - - x = torch.linspace(-1., 1., 128).cuda() - y = torch.linspace(-1., 1., 128).cuda() - theta0 = (torch.linspace(1, 0, nz).cuda() * 500 + 300).view(nz, 1, 1) * torch.ones((1, ny, nx)).cuda() - theta1 = torch.exp(-0.5 * (x / 0.1)**2).view(1, 1, nx) * torch.exp(-0.5 * (y / 0.1)**2).view(1, ny, 1) * 0.01 - wrf = WRF(dt, 10, nz, ny, nx, dz, dy, dx, 'cuda') - Theta, phi1, mu1 = wrf.init(theta0) - Theta += theta1 * wrf.mu0 - - U = torch.zeros((nz, ny, nx - 1)).cuda() - V = torch.zeros((nz, ny - 1, nx)).cuda() - W = torch.zeros((nz - 1, ny, nx)).cuda() - O = torch.zeros((nz - 1, ny, nx)).cuda() - - varloader = SciLoopVariables(variables=[U, V, W, O, Theta, phi1, mu1], constants=[]) - model = cube.SemanticModel(wrf, input_shapes=tuple(varloader.shapes)) - - @cube.compile(model=model, dataloader=varloader, PAS=chosenPAS) - def train_iter(model, dataloader): - U, V, W, O, Theta, phi1, mu1 = next(dataloader) - U, V, W, O, Theta, phi1, mu1 = model(U, V, W, O, Theta, phi1, mu1) - return U, V, W, O, Theta, phi1, mu1 - model = model.get_gen_module() - - for i in range(10): - U, V, W, O, Theta, phi1, mu1 = train_iter(model, varloader) - mu = wrf.mu0 + mu1 - u = U / wrf.bx(mu) - v = V / wrf.by(mu) - w = W / wrf.bz(mu) - o = O / wrf.bz(mu) - theta = Theta / mu - - interval = 1 - if i % interval == 0: - # plt.cla() - # fig, ax = plt.subplots(2, 3, figsize=(12, 6)) - # - # ctf = ax[0, 0].contourf(u[nz // 2, :, :].cpu().numpy(), levels=50, cmap='jet') - # ax[0, 0].set_title('u') - # plt.colorbar(ctf, ax=ax[0, 0], format='%.1e') - # - # ctf = ax[0, 1].contourf(v[nz // 2, :, :].cpu().numpy(), levels=50, cmap='jet') - # ax[0, 1].set_title('v') - # plt.colorbar(ctf, ax=ax[0, 1], format='%.1e') - # - # ctf = ax[0, 2].contourf(w[:, 32, :].cpu().numpy(), levels=50, cmap='jet') - # ax[0, 2].set_title('w') - # plt.colorbar(ctf, ax=ax[0, 2], format='%.1e') - # - # ctf = ax[1, 0].contourf(mu1[nz // 2, :, :].cpu().numpy(), levels=50, cmap='jet') - # ax[1, 0].set_title(r'$\mu^\prime$') - # plt.colorbar(ctf, ax=ax[1, 0], format='%.1e') - # - # ctf = ax[1, 1].contourf(phi1[:, 32, :].cpu().numpy(), levels=50, cmap='jet') - # ax[1, 1].set_title(r'$\phi^\prime$') - # plt.colorbar(ctf, ax=ax[1, 1], format='%.1e') - # - # ctf = ax[1, 2].contourf(o[:, 32, :].cpu().numpy(), levels=50, cmap='jet') - # ax[1, 2].set_title(r'$\omega$') - # plt.colorbar(ctf, ax=ax[1, 2], format='%.1e') - # - # fig.text(0.01, 0.95, f't={i * dt}s', size=18) - # plt.tight_layout() - # plt.savefig(f'res/res{i // interval}.jpeg', dpi=300) - # plt.close() - # plt.clf() - - print(i) From e1d3c971887d34c9851a210dc4ffd1116abe4d20 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 23 Aug 2023 11:24:55 +0000 Subject: [PATCH 1489/1892] Merged PR 1752: update setup tools update setup tools --- README.md | 50 +++++++++++------------------------------------- requirements.txt | 4 ++-- setup.py | 21 ++++++++++++-------- 3 files changed, 26 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 4ff20825..086a53ae 100644 --- a/README.md +++ b/README.md @@ -4,55 +4,27 @@ AI System Compiler to map a semantic (single-device) model into distributed exec ## Prerequisite -* Python >= 3.7 +Install the following packages before the installation of cube: -> Install Python 3.7 in the development environment for widest compatibility. +* Python >= 3.8 -Install dependent packages -```sh -pip install -r requirements.txt - -# require pytorch version >= 1.11 -pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch/ -# pip install torch==1.11.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html -``` - -## Option 1: Quick Start without Installation - -* ### Run on repo root path: -```sh -PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/linears.py --policy PASCol -``` +* PyTorch >= 1.13 -[comment]: <> (UDA_VISIBLE_DEVICES=7 PYTHONPATH=.:$PYTHONPATH python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 ./examples/wrf/wrf2.py) +## Install -* ### Debug for model parsing check on single Device -```shell -PYTHONPATH=.:$PYTHONPATH SINGLE_DEV_MODE=1 python examples/mlp/linears.py +```bash +pip install -e . ``` +## Run Example ---- - -## Option 2: Install for Run - -* ### Install - -```python -python setup.py develop -``` - -* ### Run Example -[Micro Benchmark] Run a mutiple MLP Model +Run an MLP Model on 4 GPUs: ```sh -OMP_NUM_THREADS=4 torchrun \ +PYTHONPATH=:.$PYTHONPATH torchrun \ --nproc_per_node=4 \ --nnodes=1 \ - examples/mlp/linears.py --policy PASCol + examples/mlp/train.py --policy PASCol ``` @@ -118,7 +90,7 @@ tox -r To run a single unit test task during development, you can run ``` -pytest unit_tests/your_test_file.py +pytest tests/your_test_file.py ``` ### Run unit tests in vscode diff --git a/requirements.txt b/requirements.txt index d5b9475b..96ee7e4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ +numpy>=1.23.0 matplotlib -setuptools==60.7.0 more-itertools dill -torch +torch>=1.13 diff --git a/setup.py b/setup.py index 3824b357..20c504db 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,17 @@ import setuptools +with open("requirements.txt") as f: + install_requires = [ + line.split("#")[0].strip() for line in f if not line.startswith("#") and line.split("#")[0].strip() != "" + ] + setuptools.setup( - name= 'cube', - version= '0.2', - author= 'Zhiqi Lin', - author_email= 'v-zhiql@microsoft.com', - description= 'Parallelize DNN Traning from A Systematic Way', - long_description= 'Parallelize DNN Traning from A Systematic Way', - packages= ['cube'], - python_requires= '>=3.6', + name= 'cube', + version= '0.2', + author= 'Cube Team', + description= 'Parallelize DNN Traning from A Systematic Way', + long_description= 'Parallelize DNN Traning from A Systematic Way', + packages= ['cube'], + python_requires= '>=3.8', + install_requires= install_requires, ) From f5cd17b116d10764df29c13cf65a9f161191f3cc Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 25 Aug 2023 07:08:22 +0000 Subject: [PATCH 1490/1892] Merged PR 1751: quick fix fx cube signature not align --- cube/graph/parser/external/apex.py | 78 +++++++++++++++++++++---- cube/profiler/database.py | 2 +- tests/graph/parser/test_register.py | 90 +++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 11 deletions(-) create mode 100644 tests/graph/parser/test_register.py diff --git a/cube/graph/parser/external/apex.py b/cube/graph/parser/external/apex.py index 8b8b099e..a3e67418 100644 --- a/cube/graph/parser/external/apex.py +++ b/cube/graph/parser/external/apex.py @@ -9,13 +9,17 @@ try: - from apex.normalization.fused_layer_norm import FusedLayerNormFunction, FusedLayerNormAffineFunction + from apex.normalization.fused_layer_norm import FusedLayerNormFunction, FusedLayerNormAffineFunction, FusedRMSNormFunction, FusedRMSNormAffineFunction + + pure_sign_fused_layer_norm = CustomizedOps.create_pure_signature('apex.normalization.fused_layer_norm.FusedLayerNormFunction.apply', True) + pure_sign_fused_layer_norm_affine = CustomizedOps.create_pure_signature('apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply', True) + pure_sign_fused_rms_norm = CustomizedOps.create_pure_signature('apex.normalization.fused_layer_norm.FusedRMSNormFunction.apply', True) + pure_sign_fused_rms_norm_affine = CustomizedOps.create_pure_signature('apex.normalization.fused_layer_norm.FusedRMSNormAffineFunction.apply', True) def ApexFusedLayerNormFunction(input, normalized_shape, eps=1e-6, signature = None): """ apex.normalization.fused_layer_norm.FusedLayerNormFunction """ - signature = 'apex_fused_layer_norm' letters = iter(string.ascii_lowercase) einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) eoutput = copy.copy(einput) @@ -28,13 +32,12 @@ def ApexFusedLayerNormFunction(input, normalized_shape, eps=1e-6, signature = No anno = OpAnno.create_op_str(einputs, [eoutput]) kwargs['normalized_shape'] = normalized_shape kwargs['eps'] = eps - return IRDimops(FusedLayerNormFunction, 'fusedlayernorm', signature, [anno], inputs, **kwargs) + return IRDimops(ApexFusedLayerNormFunction, 'fusedlayernorm', signature, [anno], inputs, **kwargs) def ApexFusedLayerNormAffineFunction(input, weight, bias, normalized_shape, eps=1e-6, signature = None): """ apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction """ - signature = 'apex_fused_layer_norm_affine' assert not (weight is None and bias is not None), f"Not support for None of weight and parameter of bias" letters = iter(string.ascii_lowercase) einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) @@ -60,22 +63,77 @@ def ApexFusedLayerNormAffineFunction(input, weight, bias, normalized_shape, eps= anno = OpAnno.create_op_str(einputs, [eoutput]) kwargs['normalized_shape'] = normalized_shape kwargs['eps'] = eps - return IRDimops(FusedLayerNormAffineFunction, 'fusedlayernormaffine', signature, [anno], inputs, **kwargs) - + return IRDimops(ApexFusedLayerNormAffineFunction, 'fusedlayernormaffine', signature, [anno], inputs, **kwargs) + + def ApexFusedRMSNormFunction(input, normalized_shape, eps=1e-6, signature = None): + """ + apex.normalization.fused_layer_norm.FusedRMSNormFunction + """ + letters = iter(string.ascii_lowercase) + einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) + eoutput = copy.copy(einput) + ndims = len(input.shape) + for dim in range(len(normalized_shape)): + einput[ndims-1-dim] += '^' + eoutput[ndims-1-dim] += '^' + einputs, inputs = [einput], [input] + kwargs = {} + anno = OpAnno.create_op_str(einputs, [eoutput]) + kwargs['normalized_shape'] = normalized_shape + kwargs['eps'] = eps + return IRDimops(ApexFusedRMSNormFunction, 'fusedrmsnorm', signature, [anno], inputs, **kwargs) + + def ApexFusedRMSNormAffineFunction(input, weight, normalized_shape, eps=1e-6, signature = None): + """ + apex.normalization.fused_layer_norm.FusedRMSNormAffineFunction + """ + letters = iter(string.ascii_lowercase) + einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) + eoutput = copy.copy(einput) + ndims = len(input.shape) + for dim in range(len(normalized_shape)): + einput[ndims-1-dim] += '^' + eoutput[ndims-1-dim] += '^' + einputs, inputs = [einput], [input] + kwargs = {} + if weight is not None: + eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) + einputs.append(eweight) + inputs.append(weight) + else: + kwargs['weight'] = weight + anno = OpAnno.create_op_str(einputs, [eoutput]) + kwargs['normalized_shape'] = normalized_shape + kwargs['eps'] = eps + return IRDimops(ApexFusedRMSNormAffineFunction, 'fusedrmsnormaffine', signature, [anno], inputs, **kwargs) + CustomizedOps.register('apex.normalization.fused_layer_norm.FusedLayerNormFunction.apply', ApexFusedLayerNormFunction, - 'from apex.normalization.fused_layer_norm import fused_layer_norm as apex_fused_layer_norm', + f'from apex.normalization.fused_layer_norm import fused_layer_norm as {pure_sign_fused_layer_norm}', FusedLayerNormFunction.apply, keep_full_name=True, trace_autowrap=False) CustomizedOps.register('apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply', ApexFusedLayerNormAffineFunction, - 'from apex.normalization.fused_layer_norm import fused_layer_norm_affine as apex_fused_layer_norm_affine', + f'from apex.normalization.fused_layer_norm import fused_layer_norm_affine as {pure_sign_fused_layer_norm_affine}', FusedLayerNormAffineFunction.apply, keep_full_name=True, trace_autowrap=False) -except: - _logger.warning('skip apex ops as it is not installed.') + CustomizedOps.register('apex.normalization.fused_layer_norm.FusedRMSNormFunction.apply', + ApexFusedRMSNormFunction, + f'from apex.normalization.fused_layer_norm import fused_rms_norm as {pure_sign_fused_rms_norm}', + FusedRMSNormFunction.apply, + keep_full_name=True, + trace_autowrap=False) + + CustomizedOps.register('apex.normalization.fused_layer_norm.FusedRMSNormAffineFunction.apply', + ApexFusedRMSNormAffineFunction, + f'from apex.normalization.fused_layer_norm import fused_rms_norm_affine as {pure_sign_fused_rms_norm_affine}', + FusedRMSNormAffineFunction.apply, + keep_full_name=True, + trace_autowrap=False) +except: + _logger.warning('skip apex ops as it is not installed.') diff --git a/cube/profiler/database.py b/cube/profiler/database.py index e1de7124..3525e9be 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -199,7 +199,7 @@ def get_dep_names(sign: str): code_impl: str = CustomizedOps.kOpCodeDef[node.signature] local = {} exec(code_impl, globals(), local) - fn = list(local.values())[0] + fn = list(local.values())[-1] else: fn = eval(node.signature) shapes, dtypes, requires_grads, values = [], [], [], [] diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py new file mode 100644 index 00000000..6d8167a4 --- /dev/null +++ b/tests/graph/parser/test_register.py @@ -0,0 +1,90 @@ +import cube +from cube.graph.parser.converter import convert_model +from cube.profiler.database import ProfileDataBase +import tempfile +import torch + + +def mock_add(x: torch.Tensor, y: torch.Tensor): + return x + y + +cube.graph.parser.register('*, * -> *')(mock_add) + + +@cube.graph.parser.register('*, * -> *') +def mock_add2(x: torch.Tensor, y: torch.Tensor): + return x + y + + +class MockAGF(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, y: torch.Tensor): + return x + y + + @staticmethod + def backward(ctx, grad): + return grad, grad + +cube.graph.parser.register('*, * -> *')(MockAGF.apply) + + +class TestModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x, y): + x, y = self.fc(x), self.fc(y) + return mock_add(x, y) + +class TestModel2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x, y): + x, y = self.fc(x), self.fc(y) + return mock_add2(x, y) + +class TestModel3(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x, y): + x, y = self.fc(x), self.fc(y) + return MockAGF.apply(x, y) + + +# passed test +def test_common_register(): + model = TestModel() + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) + + # test profiler.database + for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'mock_add']): + profile_name = ProfileDataBase.get_func(node)[0].__qualname__ + assert profile_name == p_name, f'{profile_name} should be {p_name}' + + +def test_common_register2(): + model = TestModel2() + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) + + # test profiler.database + for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'mock_add2']): + profile_name = ProfileDataBase.get_func(node)[0].__qualname__ + assert profile_name == p_name, f'{profile_name} should be {p_name}' + + +def test_autograd_register(): + model = TestModel3() + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) + + # test profiler.database + for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'MockAGF.apply']): + profile_name = ProfileDataBase.get_func(node)[0].__qualname__ + assert profile_name == p_name, f'{profile_name} should be {p_name}' From 71fbce5ccdd8f1505a1ea78991348e4dfb92864e Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 1 Sep 2023 15:01:30 +0000 Subject: [PATCH 1491/1892] Merged PR 1784: refine register interface refine register interface registering a customized operator are merged into one single API: `cube.graph.parser.register` Changes have made: - align signatures of - interface of `cube.graph.parser.register` changes with `node_repr` to accept annotation of type str or callable function that generates str (details in function docstring) - remove constraint of type hints: No type hints required for the registered functions, and the system will check the input type according to the specified IRDimOps annotation during runtime parse. - remove constraint of prioritized tensor-type inputs ordering: developers can simply specify '?' for non-tensor type inputs, which will make non-tensor inputs be included in inputs and replicated during operator partitioning. Generated code pattern: * Registering a python function gencode*.py will have: ``` import examples.nlp.blocks.attention class Model: def segmentxx(self, x): ... x = examples.nlp.blocks.attention.self_attention(x, xxx) ``` * Registering a pytorch autograd.Function gencode*.py will have: ```python import apex.normalization.fused_layer_norm class Model: def segmentxx(self, x): ... x = apex.normalization.fused_layer_norm.FusedLayerNormFunction.apply(x) ... ``` --- cube/graph/parser/converter.py | 5 +- cube/graph/parser/external/apex.py | 137 +++-------- cube/graph/parser/register.py | 230 +++++++++---------- cube/ir/operator.py | 8 +- cube/profiler/database.py | 19 +- tests/graph/parser/test_register.py | 14 +- tests/graph/parser/test_register_external.py | 52 +++++ 7 files changed, 218 insertions(+), 247 deletions(-) create mode 100644 tests/graph/parser/test_register_external.py diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index bc1d11af..fcb8fe70 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -9,6 +9,7 @@ from cube.graph.parser.fx.parser import FxModuleParser from cube.graph.parser.fx.concrete_trace_utils import concrete_trace +from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply import cube.runtime.function as cube_rt_function @@ -29,7 +30,9 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: torch.fx.GraphModule representation of model """ # get registered leaf function - autowrap_funcs = [CustomizedOps.kOpRuntime.get(sign, None) for sign in CustomizedOps.kOpAutowrap] + autowrap_funcs = [CustomizedOps.kOpRuntime[sign] for sign in CustomizedOps.kOpMap] + # filter out torch.autograd.Function.apply as concrete trace already treats them as leaf function + autowrap_funcs = [fn for fn in autowrap_funcs if not is_autograd_apply(fn)] leaf_functions = {func: ([], True, None) for func in autowrap_funcs if func is not None} # get cube runtime functions diff --git a/cube/graph/parser/external/apex.py b/cube/graph/parser/external/apex.py index a3e67418..08cd9461 100644 --- a/cube/graph/parser/external/apex.py +++ b/cube/graph/parser/external/apex.py @@ -2,138 +2,73 @@ import logging import string -from cube.graph.function.dimops import ShapeAnno, OpAnno, IRDimops -from cube.graph.parser.register import CustomizedOps +from cube.graph.function.dimops import ShapeAnno, OpAnno +from cube.graph import parser _logger = logging.getLogger(__name__) try: + from apex.normalization.fused_layer_norm import FusedLayerNormFunction, FusedLayerNormAffineFunction, FusedRMSNormFunction, FusedRMSNormAffineFunction - - pure_sign_fused_layer_norm = CustomizedOps.create_pure_signature('apex.normalization.fused_layer_norm.FusedLayerNormFunction.apply', True) - pure_sign_fused_layer_norm_affine = CustomizedOps.create_pure_signature('apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply', True) - pure_sign_fused_rms_norm = CustomizedOps.create_pure_signature('apex.normalization.fused_layer_norm.FusedRMSNormFunction.apply', True) - pure_sign_fused_rms_norm_affine = CustomizedOps.create_pure_signature('apex.normalization.fused_layer_norm.FusedRMSNormAffineFunction.apply', True) - def ApexFusedLayerNormFunction(input, normalized_shape, eps=1e-6, signature = None): + def apex_fused_layer_norm_anno(input, normalized_shape, *args, **kwargs): """ apex.normalization.fused_layer_norm.FusedLayerNormFunction """ letters = iter(string.ascii_lowercase) - einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) - eoutput = copy.copy(einput) + input_anno = ShapeAnno.create_shape_str(input.shape, iterator=letters) ndims = len(input.shape) for dim in range(len(normalized_shape)): - einput[ndims-1-dim] += '^' - eoutput[ndims-1-dim] += '^' - einputs, inputs = [einput], [input] - kwargs = {} - anno = OpAnno.create_op_str(einputs, [eoutput]) - kwargs['normalized_shape'] = normalized_shape - kwargs['eps'] = eps - return IRDimops(ApexFusedLayerNormFunction, 'fusedlayernorm', signature, [anno], inputs, **kwargs) + input_anno[ndims-1-dim] += '^' + inputs = [input_anno, '?'] + ['?' for _ in args] + outputs = [copy.copy(input_anno),] + assert len(kwargs) == 0, f'torch.autgrad.Function receives unexpected kwargs ({kwargs}) for apply.' + return OpAnno.create_op_str(inputs, outputs) + + + # apex.normalization.fused_layer_norm.FusedRMSNormFunction + apex_fused_rms_norm_anno = apex_fused_layer_norm_anno - def ApexFusedLayerNormAffineFunction(input, weight, bias, normalized_shape, eps=1e-6, signature = None): + + def apex_fused_layer_norm_affine_anno(input, weight, bias, normalized_shape, eps) -> str: """ apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction """ assert not (weight is None and bias is not None), f"Not support for None of weight and parameter of bias" letters = iter(string.ascii_lowercase) - einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) - eoutput = copy.copy(einput) + anno_input = ShapeAnno.create_shape_str(input.shape, iterator=letters) ndims = len(input.shape) for dim in range(len(normalized_shape)): - einput[ndims-1-dim] += '^' - eoutput[ndims-1-dim] += '^' - einputs, inputs = [einput], [input] - kwargs = {} - if weight is not None: - eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) - einputs.append(eweight) - inputs.append(weight) - else: - kwargs['weight'] = weight - if bias is not None: - ebias = ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) - einputs.append(ebias) - inputs.append(bias) - else: - kwargs['bias'] = bias - anno = OpAnno.create_op_str(einputs, [eoutput]) - kwargs['normalized_shape'] = normalized_shape - kwargs['eps'] = eps - return IRDimops(ApexFusedLayerNormAffineFunction, 'fusedlayernormaffine', signature, [anno], inputs, **kwargs) + anno_input[ndims-1-dim] += '^' + outputs = [copy.copy(anno_input),] + inputs = [anno_input] + inputs.append(ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) if weight is not None else '?') + inputs.append(ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) if bias is not None else '?') + inputs += ['?', '?'] + return OpAnno.create_op_str(inputs, outputs) - def ApexFusedRMSNormFunction(input, normalized_shape, eps=1e-6, signature = None): - """ - apex.normalization.fused_layer_norm.FusedRMSNormFunction - """ - letters = iter(string.ascii_lowercase) - einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) - eoutput = copy.copy(einput) - ndims = len(input.shape) - for dim in range(len(normalized_shape)): - einput[ndims-1-dim] += '^' - eoutput[ndims-1-dim] += '^' - einputs, inputs = [einput], [input] - kwargs = {} - anno = OpAnno.create_op_str(einputs, [eoutput]) - kwargs['normalized_shape'] = normalized_shape - kwargs['eps'] = eps - return IRDimops(ApexFusedRMSNormFunction, 'fusedrmsnorm', signature, [anno], inputs, **kwargs) - def ApexFusedRMSNormAffineFunction(input, weight, normalized_shape, eps=1e-6, signature = None): + def apex_fused_rms_norm_affine_anno(input, weight, normalized_shape, eps) -> str: """ apex.normalization.fused_layer_norm.FusedRMSNormAffineFunction """ letters = iter(string.ascii_lowercase) - einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) - eoutput = copy.copy(einput) + input_anno = ShapeAnno.create_shape_str(input.shape, iterator=letters) ndims = len(input.shape) for dim in range(len(normalized_shape)): - einput[ndims-1-dim] += '^' - eoutput[ndims-1-dim] += '^' - einputs, inputs = [einput], [input] - kwargs = {} - if weight is not None: - eweight = ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) - einputs.append(eweight) - inputs.append(weight) - else: - kwargs['weight'] = weight - anno = OpAnno.create_op_str(einputs, [eoutput]) - kwargs['normalized_shape'] = normalized_shape - kwargs['eps'] = eps - return IRDimops(ApexFusedRMSNormAffineFunction, 'fusedrmsnormaffine', signature, [anno], inputs, **kwargs) - - CustomizedOps.register('apex.normalization.fused_layer_norm.FusedLayerNormFunction.apply', - ApexFusedLayerNormFunction, - f'from apex.normalization.fused_layer_norm import fused_layer_norm as {pure_sign_fused_layer_norm}', - FusedLayerNormFunction.apply, - keep_full_name=True, - trace_autowrap=False) - - CustomizedOps.register('apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply', - ApexFusedLayerNormAffineFunction, - f'from apex.normalization.fused_layer_norm import fused_layer_norm_affine as {pure_sign_fused_layer_norm_affine}', - FusedLayerNormAffineFunction.apply, - keep_full_name=True, - trace_autowrap=False) + input_anno[ndims-1-dim] += '^' + outputs = [copy.copy(input_anno),] + inputs = [input_anno] + inputs.append(ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) if weight is not None else '?') + inputs += ['?', '?'] + return OpAnno.create_op_str(inputs, outputs) - CustomizedOps.register('apex.normalization.fused_layer_norm.FusedRMSNormFunction.apply', - ApexFusedRMSNormFunction, - f'from apex.normalization.fused_layer_norm import fused_rms_norm as {pure_sign_fused_rms_norm}', - FusedRMSNormFunction.apply, - keep_full_name=True, - trace_autowrap=False) - CustomizedOps.register('apex.normalization.fused_layer_norm.FusedRMSNormAffineFunction.apply', - ApexFusedRMSNormAffineFunction, - f'from apex.normalization.fused_layer_norm import fused_rms_norm_affine as {pure_sign_fused_rms_norm_affine}', - FusedRMSNormAffineFunction.apply, - keep_full_name=True, - trace_autowrap=False) + parser.register(apex_fused_layer_norm_anno)(FusedLayerNormFunction.apply) + parser.register(apex_fused_layer_norm_affine_anno)(FusedLayerNormAffineFunction.apply) + parser.register(apex_fused_rms_norm_anno)(FusedRMSNormFunction.apply) + parser.register(apex_fused_rms_norm_affine_anno)(FusedRMSNormAffineFunction.apply) except: _logger.warning('skip apex ops as it is not installed.') diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index b54d705d..f2779703 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -2,14 +2,14 @@ Register cutomized function """ -from typing import Dict, Callable, List, Optional, Any +from typing import Dict, Callable, Optional, Union from functools import partial import inspect import logging -import torch - from cube.graph.function.dimops import IRDimops, OpAnno +from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply +from cube.ir.operator import IRTensor _logger = logging.getLogger(__name__) @@ -23,10 +23,6 @@ class CustomizedOps: kOpRuntime: Dict[str, Callable] = {} # signature -> runtime function implementation code kOpCodeDef: Dict[str, str] = {} - # original signature (xxx.xxx.xxx) -> pure signature (xxx_xxx_xxx) - kOpSignMap: Dict[str, str] = {} - # the function in it will be autowrapped by tracer. - kOpAutowrap: List[str] = [] @staticmethod def map(signature: str) -> Callable: @@ -38,30 +34,24 @@ def map(signature: str) -> Callable: Returns: Callable: IRDimop creation function """ - signature = CustomizedOps.pure_signature(signature) - if signature in CustomizedOps.kOpMap: - return partial(CustomizedOps.kOpMap[signature], signature=signature) - else: + if signature not in CustomizedOps.kOpMap: raise KeyError(f"{signature} is not found in registered ops") + return partial(CustomizedOps.kOpMap[signature], signature=signature) @staticmethod def exist(signature: str) -> bool: """Check if the signature is registered""" - signature = CustomizedOps.pure_signature(signature) return signature in CustomizedOps.kOpMap @staticmethod - def register(signature: str, op: Callable, code: str, runtime_fn: Callable, - keep_full_name: bool = False, trace_autowrap: bool = True): + def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Callable): """Register an operator Args: signature (str): operator signature - op (Callable): IRDimop creation function + op_create_fn (Callable): IRDimops creation function code (str): runtime function implementation code runtime_fn (Callable): runtime function - keep_full_name (bool): if set True, the full name will be kept, `.` in name will be replaced to `_` - trace_autowrap (bool): if set True, the function will be autowrapped by tracer. Returns: None @@ -69,63 +59,64 @@ def register(signature: str, op: Callable, code: str, runtime_fn: Callable, builtins = ['_operator', 'torch', 'cube.runtime.function'] if any(signature.startswith(builtin) for builtin in builtins): raise RuntimeError(f"Cannot register operators with signature starting from any of {builtins}") - signature = CustomizedOps.create_pure_signature(signature, keep_full_name) assert signature not in CustomizedOps.kOpMap, f"function {signature} is already registered" - CustomizedOps.kOpMap[signature] = op + CustomizedOps.kOpMap[signature] = op_create_fn CustomizedOps.kOpRuntime[signature] = runtime_fn CustomizedOps.kOpCodeDef[signature] = code - if trace_autowrap and signature not in CustomizedOps.kOpAutowrap: - CustomizedOps.kOpAutowrap.append(signature) - elif not trace_autowrap and signature in CustomizedOps.kOpAutowrap: - CustomizedOps.kOpAutowrap.pop(signature) - @staticmethod - def create_pure_signature(signature: str, keep_full_name: bool) -> str: - if keep_full_name: - pure_signature = signature.replace('__main__.', '', 1) if signature.startswith('__main__.') else signature - pure_signature = pure_signature.replace('.', '_') - CustomizedOps.kOpSignMap[signature] = pure_signature - return pure_signature - return signature.split('.')[-1] - - @staticmethod - def pure_signature(signature: str) -> str: - if signature in CustomizedOps.kOpSignMap: - return CustomizedOps.kOpSignMap[signature] - return signature.split('.')[-1] - -def register(anno: str, name: Optional[str] = None, - rules: Optional[List] = None, - input_type_annos: Optional[List[Any]] = None, +def register(node_repr: Union[str, Callable], name: Optional[str] = None, code_impl_pattern: str = 'import') -> Callable: """ - Register a function with einop annotations. + Register a function with IRDimops annotations. - This function is cooperated with IRDimops. - User needs to define a python function that satisfies - 1). Has type annotations for each input - 2). Tensor inputs goes first then other inputs + This function is cooperated with IRDimops. Users can only register functions defined under a module, instead of + ones defined inside a function / class or __main__ scope. - For DimAnnos containing brackets (e.g., (3 h d)) that can not be - inferred by system, user should have same argument name in the - function definition to help system infer each dim length, e.g., + The annotation (`node_repr`) specifies the number of inputs as *args, + and treat all the rest inputs as **kwargs. - @cube.register('a (b c) -> (a b) c') - def funcname(x: torch.Tensor, b: int = 4): + For tensor-type inputs, the annotation should be a string of identifiers separated by space, e.g., `'a b'`; + For non-tensor-type inputs, the annotation should be specified '?'. + + Examples: + + ```python + import cube + from third_party import func + + cube.graph.parser.register('a (b c) -> (a b) c')(func) + ``` + + or, + + ```python + import cube + from third_party import func + + @cube.graph.parser.register('a (b c) -> (a b) c') + def func(x, b = 4): xxx + ``` + + or, + + ```python + import cube + from third_party import func - Note: for Optional[torch.Tensor] type, user should annotate the - dimension when the input is not None. + def anno_fn(*inputs, **kwargs): + return 'a (b c) -> (a b) c' + + cube.graph.parser.register(anno_fn)(func) + ``` Args: - anno (str): operator annotation - name (str): operator name - rules (Optional[List[TransformRule]]): - additional transformation rules. - input_type_annos (Optional[List[Any]]): - type annotations for inputs. If not provided, the function - should be annotated with types. + node_repr (str | Callable): operator annotation of IRDimops or callable function that generates IRFwOperation. + - op annotation: e.g., 'a (b c) -> (a b) c' + - a callable function that generates op annotation (str). The function + taks inputs and kwargs as arguments and returns the operator annotation. + name (str | None): operator name. Only usable when node_repr is a string. code_impl_pattern (str): can only be 'import' or 'source'. If 'import', will generate code with import statement. If 'source', will take the source code directly. @@ -134,83 +125,90 @@ def funcname(x: torch.Tensor, b: int = 4): Returns: fn (Callable): the runtime function """ - from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply def decorator(fn: Callable): + nonlocal code_impl_pattern + if not callable(fn): - raise TypeError("Expected a function") + raise TypeError("Expected a runtime function") + + # step 1. get function signature and inputs + def get_import_path(fn: Callable) -> str: + if is_autograd_apply(fn): + import_path = inspect.getmodule(fn.__self__).__name__ + else: + import_path = inspect.getmodule(fn).__name__ + return import_path + + import_path = get_import_path(fn) + if import_path == '__main__': + raise NotImplementedError( + f"Cannot register function {fsig} in __main__ module. " + f"Try to define it in another module and import into main") + if is_autograd_apply(fn): - fsig = CustomizedOps.create_pure_signature(f'{fn.__self__.__module__}.{fn.__self__.__name__}.apply', True) - op_name = name if name is not None else fn.__name__ + fsig = f'{import_path}.{fn.__self__.__name__}.apply' + op_name = name if name is not None else fn.__self__.__name__ args = inspect.signature(fn.__self__.forward) arg_names = list(args.parameters.keys())[1:] else: - fsig = fn.__name__ - op_name = name if name is not None else fsig + fsig = f'{import_path}.{fn.__name__}' + op_name = name if name is not None else fn.__name__ args = inspect.signature(fn) arg_names = list(args.parameters.keys()) - # get argument types - arg_kinds = input_type_annos if input_type_annos is not None else \ - [args.parameters[name].annotation for name in arg_names] - assert len(arg_kinds) == len(arg_names), \ - "Number of annotations should match with number of arguments" - # parse for number of inputs and kwargs - allow_types = (torch.Tensor, Optional[torch.Tensor]) - for ninputs, kind in enumerate(arg_kinds): - if kind in allow_types: - ninputs += 1 - continue - assert not any(k in allow_types for k in arg_kinds[ninputs:]), \ - f"Type of {allow_types} should be consecutive in parameter order." - break - nkwargs = len(arg_names) - ninputs - kwarg_names = [name for name in arg_names[ninputs:]] - - # get customized op code - if code_impl_pattern == 'import': + + # step 2. get customized op code + def get_source_code(fn: Callable) -> str: if is_autograd_apply(fn): - import_path = inspect.getmodule(fn.__self__).__name__ - if import_path == '__main__': - _logger.warning(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' - f'This may cause error when the function has inner functions from other modules. ' - f'To solve this, define the function in another module and import into main', stacklevel=0) - code = inspect.getsource(fn.__self__) - code = code[code.index(f'class {fn.__self__.__name__}'):] + f'\n{fsig}={fn.__self__.__name__}.apply' - else: - code = f'from {import_path} import {fn.__self__.__name__}\n{fsig}={fn.__self__.__name__}.apply' + code = inspect.getsource(fn.__self__) + code = code[code.index(f'class {fn.__self__.__name__}'):] else: - import_path = inspect.getmodule(fn).__name__ - if import_path == '__main__': - _logger.warning(f'Find the function {fsig} is defined in __main__ module, will take the source code directly. ' - f'This may cause error when the function has inner functions from other modules. ' - f'To solve this, define the function in another module and import into main', stacklevel=0) - code = inspect.getsource(fn) - code = code[code.index('def'):] - else: - code = f'from {import_path} import {fsig}' + code = inspect.getsource(fn) + code = code[code.index('def'):] + return code + + def get_import_code(fn: Callable) -> str: + import_path = get_import_path(fn) + code = f'import {import_path}' + return code + + if code_impl_pattern == 'import': + code = get_import_code(fn) elif code_impl_pattern == 'source': - assert not is_autograd_apply(fn), 'Only support code_impl_pattern="import" for autograd.Function.apply.' - code = inspect.getsource(fn) - code = code[code.index('def'):] + code = get_source_code(fn) else: raise ValueError(f'code_impl_pattern should be either "import" or "source", got {code_impl_pattern}') + # step 3. define customized IRDimops creation function + if not (isinstance(node_repr, str) or callable(node_repr)): + raise TypeError(f"node_repr should be either str or callable, got {type(node_repr)}") + def udfop(*args, signature=None, **kwargs): - manno = OpAnno(anno) + anno = node_repr if isinstance(node_repr, str) else node_repr(*args, **kwargs) + if not isinstance(anno, str): + raise TypeError(f"node_repr should return a string, but got {type(anno)}: {anno}") + anno = OpAnno(anno) + ninputs = len(anno.inputs()) + if len(args) < ninputs: + raise ValueError(f"calling function {signature} should include at least {ninputs} *args") tensors = args[:ninputs] - for idx in range(ninputs): - if arg_kinds[idx] == Optional[torch.Tensor] and tensors[idx] is None: - manno.set_input(idx, '?') + for idx, t in enumerate(tensors): + # argument check + if str(anno.input(idx)) != '?': + if not isinstance(t, IRTensor): + raise ValueError( + f"{idx}-th input needs IRTensor, but got {type(t)}: {t}\n" + f"signature: {signature}\n" + f"annotation: {anno}") + kwarg_names = [name for name in arg_names[ninputs:]] kwarg_vals = args[ninputs:] for name, val in zip(kwarg_names, kwarg_vals): kwargs[name] = val - return IRDimops(udfop, op_name, signature, [repr(manno)], tensors, transform_rules=rules, **kwargs) + return IRDimops(udfop, op_name, signature, [repr(anno)], tensors, **kwargs) - _logger.info(f'registering op {fsig} with {ninputs} inputs and {nkwargs} kwargs...') - if is_autograd_apply(fn): - CustomizedOps.register(f'{fn.__self__.__module__}.{fn.__self__.__name__}.apply', udfop, code, fn, True, False) - else: - CustomizedOps.register(fsig, udfop, code, fn) + # step 4. register in CustomizedOps + _logger.info(f'registering op {fsig}...') + CustomizedOps.register(fsig, udfop, code, fn) return fn return decorator diff --git a/cube/ir/operator.py b/cube/ir/operator.py index ff8faa43..7fa43f7a 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -137,16 +137,14 @@ def replicate(self): return cpy def __repr__(self) -> str: - sign = self.signature.split('.')[-1] - dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " + dscp = (f"FwOp{self._id}-{self.device}(name={self.name}, " f"inputs={self.inputs()}, " f"outputs={self.outputs()})") return dscp def extra_repr(self) -> str: - sign = self.signature.split('.')[-1] - # ins = [t for t in self.inputs()] - dscp = (f"FwOp{self._id}-{self.device}(sign={sign}, " + dscp = (f"FwOp{self._id}-{self.device}(name={self.name}, " + f"sign={self.signature}, " f"inputs={self.inputs()}, " f"outputs={self.outputs()})") return dscp diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 3525e9be..c3b7fabb 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -183,23 +183,8 @@ def get_func(node: IRFwOperation) -> Tuple[Callable, Shapes, DTypes, Dict]: """ assert isinstance(node, IRFwOperation), f"Only support profiling forward operation but got {type(node)}" - def get_dep_names(sign: str): - ret = [] - code_impl = CustomizedOps.kOpCodeDef[sign] - for code_line in code_impl.split('\n'): - idx = code_line.find('# call: ') - if idx != -1: - dep_name = code_line[idx + 8:] - assert dep_name in CustomizedOps.kOpCodeDef, dep_name - ret = ret + get_dep_names(dep_name) - ret.append(dep_name) - return ret - - if node.signature in CustomizedOps.kOpCodeDef: - code_impl: str = CustomizedOps.kOpCodeDef[node.signature] - local = {} - exec(code_impl, globals(), local) - fn = list(local.values())[-1] + if node.signature in CustomizedOps.kOpRuntime: + fn = CustomizedOps.kOpRuntime[node.signature] else: fn = eval(node.signature) shapes, dtypes, requires_grads, values = [], [], [], [] diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index 6d8167a4..a08e6699 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -28,7 +28,7 @@ def backward(ctx, grad): cube.graph.parser.register('*, * -> *')(MockAGF.apply) -class TestModel(torch.nn.Module): +class MockModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc = torch.nn.Linear(10, 10) @@ -37,7 +37,7 @@ def forward(self, x, y): x, y = self.fc(x), self.fc(y) return mock_add(x, y) -class TestModel2(torch.nn.Module): +class MockModel2(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc = torch.nn.Linear(10, 10) @@ -46,7 +46,7 @@ def forward(self, x, y): x, y = self.fc(x), self.fc(y) return mock_add2(x, y) -class TestModel3(torch.nn.Module): +class MockModel3(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc = torch.nn.Linear(10, 10) @@ -58,7 +58,7 @@ def forward(self, x, y): # passed test def test_common_register(): - model = TestModel() + model = MockModel() with tempfile.TemporaryDirectory() as tempdir: ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) @@ -69,7 +69,7 @@ def test_common_register(): def test_common_register2(): - model = TestModel2() + model = MockModel2() with tempfile.TemporaryDirectory() as tempdir: ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) @@ -80,11 +80,11 @@ def test_common_register2(): def test_autograd_register(): - model = TestModel3() + model = MockModel3() with tempfile.TemporaryDirectory() as tempdir: ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) # test profiler.database - for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'MockAGF.apply']): + for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'Function.apply']): profile_name = ProfileDataBase.get_func(node)[0].__qualname__ assert profile_name == p_name, f'{profile_name} should be {p_name}' diff --git a/tests/graph/parser/test_register_external.py b/tests/graph/parser/test_register_external.py new file mode 100644 index 00000000..842586c6 --- /dev/null +++ b/tests/graph/parser/test_register_external.py @@ -0,0 +1,52 @@ + +import torch +import logging +import tempfile +from cube.graph.parser.converter import convert_model +from cube.ir.operator import IRFwOperation +from cube.graph.function.dimops import IRDimops + +_logger = logging.getLogger(__name__) + +def test_register_apex_fused_op(): + + have_apex = True + + try: + from apex.normalization.fused_layer_norm import FusedLayerNorm + from apex.normalization.fused_layer_norm import FusedRMSNorm + + except Exception as e: + _logger.warning(f'skip op registering test on external apex due to lack of apex installation.') + have_apex = False + + if not have_apex: + return + + class ApexModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.empty(128, dtype=torch.float16)) + # fused layer norm + self.fused_layer_norm = FusedLayerNorm((128,), eps=1e-5, elementwise_affine=False) + self.fused_layer_norm_affine = FusedLayerNorm((128,), eps=1e-5, elementwise_affine=True) + # fused rms norm + self.fused_rms_norm = FusedRMSNorm((128,), eps=1e-5, elementwise_affine=False) + self.fused_rms_norm_affine = FusedRMSNorm((128,), eps=1e-5, elementwise_affine=True) + + def forward(self, x): + x = self.param + x + x = self.fused_layer_norm(x) + x = self.fused_layer_norm_affine(x) + x = self.fused_rms_norm(x) + x = self.fused_rms_norm_affine(x) + return x + + sample = torch.randn((4, 128), dtype=torch.float16, device=torch.cuda.current_device()) + model = ApexModel().half() + with tempfile.TemporaryDirectory() as tempdir: + graph = convert_model(model, dummy_input={'x': sample}, attr_savedir=tempdir) + print(graph.extra_repr()) + apex_nodes = [n for n in graph.select(ntype=IRFwOperation) if 'apex' in n.signature] + assert len(apex_nodes) == 4, graph.extra_repr() + assert all(isinstance(n, IRDimops) for n in apex_nodes), graph.extra_repr() From 40ea61a97c0e68ff6881de66f387534dda185983 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 5 Sep 2023 08:59:38 +0000 Subject: [PATCH 1492/1892] Merged PR 1794: parallel module: add gradient accumulation support 1. remove _AddGradModule as it doesn't work for non-float input. Instead, we will hook optimizer.step to call reducer.sync_grads to do the same thing 2. Add register_reducer_pre[post]_hook to optimizer. 3. Add sync_shard_grad to optimizer, so user can manually sync grads before optimizer.step() --- cube/parallel.py | 154 ++++++++++++++++++--- cube/runtime/adapter/reducer.py | 28 ++-- cube/runtime/module.py | 47 ++----- tests/parallel_module/test_reducer_hook.py | 145 +++++++++++++++++++ tests/parallel_module/test_submodule.py | 33 +++-- tests/parallel_module/test_wholemodule.py | 2 +- 6 files changed, 322 insertions(+), 87 deletions(-) create mode 100644 tests/parallel_module/test_reducer_hook.py diff --git a/cube/parallel.py b/cube/parallel.py index 9612acb0..9a6a2d56 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -1,10 +1,13 @@ +from functools import partial import types -from typing import Callable, Any, Dict, Optional, Type, Union +from typing import Callable, Any, Dict, Optional, Type, Union, TypeVar from pathlib import Path import inspect import sys import importlib from dataclasses import dataclass +from contextlib import contextmanager +import logging import torch from cube.graph.parser.fx.parser import FxModuleParser @@ -12,6 +15,8 @@ from cube.ir.cten import IRObject from cube.ir.tensor import IRFullTensor +from cube.flags import CompileFlag, RuntimeFlag + from cube.graph import IRGraph from cube.graph import parser from cube.graph.function.anchor import IRGraphAnchor @@ -25,15 +30,43 @@ from cube.execplan.planpass.fusion import DiffFusion from cube.ir.unique import IDGenerator from cube.program import Program +from cube.runtime.adapter.reducer import Reducer from cube.runtime.module import CubeModule, ParallelModule +logger = logging.getLogger(__name__) + + @dataclass class ComputeConfig: plan_ngpus: int runtime_ngpus: int +@contextmanager +def _flags(flags, warning_on_override=True, **kwargs): + old_flags = {} + for k, v in kwargs.items(): + old_flags[k] = getattr(flags, k) + if old_flags[k] != v: + if warning_on_override: + logger.warning(f"{flags}.{k}={old_flags[k]} is not supported. Changed to {v}.") + setattr(flags, k, v) + try: + yield + finally: + for k, v in old_flags.items(): + setattr(flags, k, v) + + +def _compile_flags(): + return _flags(CompileFlag, use_zero=False, async_reducer=False, reducer_op='sum', async_comm=False) + + +def _runtime_flags(**kwargs): + return _flags(RuntimeFlag, warning_on_override=False, **kwargs) + + def _complex(val: Any): """Complex to CPU""" if isinstance(val, tuple): @@ -250,7 +283,7 @@ def _gencode( runtime_ngpus = None if compute_config.plan_ngpus == compute_config.runtime_ngpus else compute_config.runtime_ngpus assert len(execplan.graph.device) == compute_config.plan_ngpus, f"{execplan.graph.device}" mgener = ModuleCodeGen(execplan, scale_ndevs=runtime_ngpus) - for rank in range(compute_config.plan_ngpus): + for rank in range(compute_config.runtime_ngpus): filename = _GENCODE_FILE_TEMPLATE.format(rank) mgener.gen(rank, forward_arg_names=forward_args, outfile=outdir / filename, attach=False, as_parallel_module=True) @@ -354,17 +387,17 @@ def parallelize( if any(isinstance(m, CubeModule) for m in module.modules()): raise RuntimeError('CubeModule can not be nested.') - - _gencode( - module, - dummy_input, - pas_policy, - compute_config, - dynamic_shape=dynamic_shape, - override=override, - cube_savedir=cube_savedir, - instance_name=instance_name, - ) + with _compile_flags(): + _gencode( + module, + dummy_input, + pas_policy, + compute_config, + dynamic_shape=dynamic_shape, + override=override, + cube_savedir=cube_savedir, + instance_name=instance_name, + ) if is_module_class: del module @@ -385,12 +418,51 @@ def parallelize( return cube_module +class ParallelOptimizer(torch.optim.Optimizer): + """ + A optimizer stub to support parallelized module. + The returned optimizer of build_optimizer() will have the same methods in this class. + """ + def sync_shard_grad(self): + """ + Sync the shard gradients of the module from nodes with same shard to the optimizer. + Please note this is called automatically in optimizer.step(). + But If you want to access the gradients before optimizer.step(), + you need to call this function manually. + """ + ... + + def register_reducer_pre_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): + """ + Register pre hooks to reducers which will be applied before gradient synchronization. + + The pre-hooks will be applied one by one following the order of registration. + + Args: + fn (Callable[[Reducer, torch.Tensor], None]): a callable function that takes a reducer and a gradient as input and optionally updates the gradient. + """ + ... + + def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): + """ + Register post hooks to reducers which will be applied after gradient synchronization. + + The post-hooks will be applied one by one following the order of registration. + + Args: + fn (Callable[[Reducer, torch.Tensor], None]): a callable function that takes a reducer and a gradient as input and optionally updates the gradient. + """ + ... + +OptimizerT = TypeVar('OptimizerT', bound=torch.optim.Optimizer) + + def build_optimizer( module: torch.nn.Module, - optimizer_fn: Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]], + optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], *args, **kwargs, -) -> torch.optim.Optimizer: +) -> OptimizerT: """ Build an optimizer for a module. @@ -400,12 +472,12 @@ def build_optimizer( so we need to replace the parameters of optimizer with CubeModule.parameters_for_optimizer It is impossible to make this change transparent to end users. 2. optimizer.step(): + we need to call optimier.sync_shard_grad() to sync the gradients of the module before optimizer.step(). In zero mode, we have to call CubeModule.gather_params() after optimizer.step() 3. optimizer.zero_grad(): We need to call CubeModule.zero_grad() after optimizer.zero_grad() 4. backward(): - we need to call CubeModule.sync_grads() after each CubeModule backward. - This is done with _AddGradModule and its hook in ParallelModule. + you need to call optimizer.sync_shard_grad() manually if you want to read the gradients of the module before optimizer.step(). Please note this DOES NOT work in end2end mode. @@ -419,11 +491,16 @@ def build_optimizer( Returns: torch.optim.Optimizer: the optimizer you should use to train the module + The optimizer is created by optimizer_fn, + and will be patched with the methods in ParallelModule class to support parallelized module. """ if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): raise RuntimeError("End2End mode is not supported") + RuntimeFlag.skip_reducer = True + RuntimeFlag.skip_zero_grad = False + def _local_parameters(module: torch.nn.Module): gen = module._named_members(lambda m: m._parameters.items()) for _, param in gen: @@ -431,18 +508,53 @@ def _local_parameters(module: torch.nn.Module): optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), *args, **kwargs) - def _step_hook(opt, *args, **kwargs): + def _step_pre_hook(opt, *args, **kwargs): + opt.sync_shard_grad() + def _step_post_hook(opt, *args, **kwargs): for m in module.modules(): - if isinstance(m, CubeModule): + if isinstance(m, ParallelModule): m.gather_params() - optimizer.register_step_post_hook(_step_hook) + else: + assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + optimizer.register_step_pre_hook(_step_pre_hook) + optimizer.register_step_post_hook(_step_post_hook) orig_zero_grad = optimizer.zero_grad def _patched_zero_grad_hook(self, set_to_none: bool = True): orig_zero_grad(set_to_none) for m in module.modules(): - if isinstance(m, CubeModule): + if isinstance(m, ParallelModule): m.zero_grad() + else: + assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" optimizer.zero_grad = types.MethodType(_patched_zero_grad_hook, optimizer) + def _sync_shard_grad(self): + with _runtime_flags(skip_reducer=False): + for m in module.modules(): + if isinstance(m, ParallelModule): + m.sync_grad() + else: + assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + optimizer.sync_shard_grad = types.MethodType(_sync_shard_grad, optimizer) + + def _register_reducer_pre_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): + for m in module.modules(): + if isinstance(m, ParallelModule): + for reducer in m.reducers: + reducer.register_pre_hook(partial(fn, reducer)) + else: + assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + + def _register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): + for m in module.modules(): + if isinstance(m, ParallelModule): + for reducer in m.reducers: + reducer.register_post_hook(partial(fn, reducer)) + else: + assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + + optimizer.register_reducer_pre_hook = types.MethodType(_register_reducer_pre_hook, optimizer) + optimizer.register_reducer_post_hook = types.MethodType(_register_reducer_post_hook, optimizer) + return optimizer diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index b0b9ab33..da749ee0 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -39,12 +39,12 @@ class Bucket: def __init__(self, params: List[torch.nn.Parameter], param_buffer: torch.Tensor, grad_buffer: torch.Tensor, - reduce_op: torch.distributed.ReduceOp, + reduce_op: torch.distributed.ReduceOp, group, async_op: bool, zero: bool, zero_subgroup: torch.distributed.ProcessGroup = None): """ Create a communication unit for parameter allreduce. - + One allreduce will be called for all gradients associated to the parameters. The parameters are assumed to participate in backward and generate gradient. @@ -104,7 +104,7 @@ def params(self) -> Tuple: def zero(self) -> bool: """Whether enable zero for this bucket""" return self._zero - + def build(self): """ Build offset for each parameter @@ -181,8 +181,8 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): def sync_grads(self): """ Wait until allreduce finished (async), or perform allreduce (sync). - - The `.grad` attribute for each parameter will also be set after + + The `.grad` attribute for each parameter will also be set after the completion of allreduce. """ rank = torch.distributed.get_rank(group=self._group) @@ -270,7 +270,7 @@ def _apply_pre_hooks(self): def _apply_post_hooks(self): """Apply post hooks after gradient synchronization. - + The post-hooks will be applied one by one following the order of registration. """ if len(self._post_hooks) == 0: return @@ -281,7 +281,7 @@ def _apply_post_hooks(self): def clear_pre_hooks(self): """Clear all pre hooks.""" self._pre_hooks = [] - + def clear_post_hooks(self): """Clear all post hooks.""" self._post_hooks = [] @@ -442,7 +442,7 @@ def build_buckets(self): padding = len(self._ranks) - numel % len(self._ranks) buffer_length += numel + padding stops.append(buffer_length) - + # step3: allocate memory # gradient buffer self._contiguous_grads: torch.Tensor = torch.zeros( @@ -465,8 +465,8 @@ def build_buckets(self): ofst += param.numel() # initialize buckets bucket = Bucket( - params, - self._contiguous_params[start:stop], + params, + self._contiguous_params[start:stop], self._contiguous_grads[start:stop], self._reduce_op, self._group, @@ -484,7 +484,7 @@ def build_buckets(self): def sync_grads(self): """ - synchronize gradients using allreuce (non-zero) or reduce-scatter (zero) + synchronize gradients using allreduce (non-zero) or reduce-scatter (zero) """ if RuntimeFlag.skip_reducer: return for bucket in self._buckets: @@ -501,7 +501,7 @@ def gather_params(self): def zero_grad(self): """Make gradient to be zero. - + This needs to be called at the beginning of every training iteration. """ if RuntimeFlag.skip_zero_grad: return @@ -535,7 +535,7 @@ def register_pre_hook(self, fn: Callable): A reducer can be registered by multiple hooks and the hooks will be applied in the order of registration. - + The hook function takes a contiguous buffer of local computed gradient and can optionally apply in-place operations on it. @@ -560,7 +560,7 @@ def register_post_hook(self, fn: Callable): A reducer can be registered by multiple hooks and the hooks will be applied in the order of registration. - + The hook function takes a contiguous buffer of updated gradient and can only apply in-place operations on it. diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 16d898fb..21fa9476 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -369,30 +369,6 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): }, filename_prefix + '.full.ckpt') -class _AddGradModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, *args): - if not self.training: - return args - - new_args = [] - found_tensor = False - for arg in args: - if isinstance(arg, torch.Tensor): - found_tensor = True - new_arg = arg - if not arg.requires_grad: - new_arg = arg.clone().requires_grad_(True) - new_args.append(new_arg) - else: - new_args.append(arg) - if not found_tensor: - raise RuntimeError("Failed to setup module backward hook: no input Tensors.") - return tuple(new_args) - - class ParallelModule(CubeModule): COMPUTE_CONFIG_FILE = 'compute_config.pt' @@ -401,11 +377,8 @@ def __init__(self): raise RuntimeError(f"ParallelModule should not be initialized directly. Please derive it first") super().__init__() - - # register_full_backward_pre_hook requires the input tensor to be requires_grad - # so we add a module to make sure the input tensor requires grad - self._add_grad_module = _AddGradModule() - self._add_grad_module.register_full_backward_pre_hook(self.backward_hook) + # this is used to allow multiple sync_grad() calls + self._sync_grad_required = False def _post_init(self): module_file = Path(sys.modules[self.__module__].__file__) @@ -418,10 +391,8 @@ def _post_init(self): def forward(self, *args, **kwargs): if self.training: - new_args = self._add_grad_module(*args) - else: - new_args = args - return self._forward_impl(*new_args, **kwargs) + self._sync_grad_required = True # mark sync_grad() can be called again + return self._forward_impl(*args, **kwargs) def _forward_impl(self, *args, **kwargs): """ @@ -429,12 +400,14 @@ def _forward_impl(self, *args, **kwargs): """ raise NotImplementedError - def backward_hook(self, module, grad_output): + def sync_grad(self): """ - backward hook for gradient synchronization + synchronize gradients using allreduce (non-zero) or reduce-scatter (zero) """ - for reducer in self.reducers: - reducer.sync_grads() + if self._sync_grad_required: + self._sync_grad_required = False # mark sync_grad() has been called + for reducer in self._reducers: + reducer.sync_grads() def get_dist_param_map(self): return self._dist_param_map diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py new file mode 100644 index 00000000..14528467 --- /dev/null +++ b/tests/parallel_module/test_reducer_hook.py @@ -0,0 +1,145 @@ +import tempfile +from pathlib import Path +from collections import defaultdict + +import torch +from torch import nn + +from cube.parallel import ComputeConfig, parallelize, build_optimizer +from cube.runtime.module import ParallelModule + +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively + + +class FcRelu(nn.Module): + def __init__(self, in_features=4, out_features=4, bias=True): + super().__init__() + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, out_features, bias=bias) + self.relu2 = nn.ReLU() + self.fc3 = CubeLinear(out_features, out_features, bias=bias) + self.relu3 = nn.ReLU() + + + def forward(self, x): + return self.relu3(self.fc3(self.relu2(self.fc2(self.relu1(self.fc1(x)))))) + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + dynamic_shape=True, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + +def _create_module(pas, compute_config, cube_savedir): + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc_relu1 = _to_cube_model(FcRelu(), pas, compute_config, cube_savedir, 'fc_relu1') + self.fc_relu2 = _to_cube_model(FcRelu(), pas, compute_config, cube_savedir, 'fc_relu2') + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.fc_relu1(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + compiled_module = CompiledModule().cuda() + return compiled_module + + +def _train(model): + init_random() + + pre_called = defaultdict(int) + post_called = defaultdict(int) + def pre_hook(reducer, grad): + pre_called[reducer] += 1 + + def post_hook(reducer, grad): + post_called[reducer] += 1 + + loss_fn = nn.BCELoss() + + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + optimizer.register_reducer_pre_hook(pre_hook) + optimizer.register_reducer_post_hook(post_hook) + + reducers = [] + for m in model.modules(): + if isinstance(m, ParallelModule): + reducers.extend(m.reducers) + + if not reducers: + print('No reducer found, skip test_hook') + return + + data = [] + DATA_SIZE = 20 + UPDATE_FREQ = 2 + for _ in range(DATA_SIZE): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + results = [] + for i, (x, y) in enumerate(data): + model.train() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + if i % UPDATE_FREQ == UPDATE_FREQ - 1: + optimizer.step() + grads = {n: p.grad for n, p in model.named_parameters()} + results.append(clone_to_cpu_recursively([y_pred, loss, grads])) + optimizer.zero_grad() + weights = {n: p.data for n, p in model.named_parameters()} + results[-1].append(clone_to_cpu_recursively(weights)) + assert pre_called == post_called + assert set(pre_called.keys()) == set(reducers) + assert all(v == (i + 1) // UPDATE_FREQ for v in pre_called.values()) + return results + + +def _gpu_worker(pas, ngpus): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_hook') as tempdir: + compiled_module = _create_module(pas, ComputeConfig(ngpus, ngpus), tempdir) + _train(compiled_module) + + +def test_hook_tp_gpu1(): + if not torch.cuda.is_available(): + print('skip test_submodules_tp_gpu1 due to lack of cuda devices') + return + launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) + + +def test_hook_tp_gpu2(): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + print('skip test_submodules_tp_gpu2 due to lack of cuda devices') + return + launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) + + +def test_submodules_dp_gpu1(): + if not torch.cuda.is_available(): + print('skip test_submodules_dp_gpu1 due to lack of cuda devices') + return + launch_torchrun(1, _gpu_worker, PASData, 1) + + +def test_submodules_dp_gpu2(): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + print('skip test_submodules_dp_gpu2 due to lack of cuda devices') + return + launch_torchrun(2, _gpu_worker, PASData, 2) diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index d57f0c56..32c372e5 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -3,6 +3,7 @@ import re from pathlib import Path import shutil +import pytest import torch from torch import nn @@ -80,14 +81,14 @@ def forward(self, x): return orig_module, compiled_module -def _train(model): +def _train(model, update_freq): init_random() loss_fn = nn.BCELoss() optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) data = [] DATA_SIZE = 20 - UPDATE_FREQ = 1 # TODO: update_freq support + UPDATE_FREQ = update_freq for _ in range(DATA_SIZE): data.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), @@ -100,21 +101,21 @@ def _train(model): loss = loss_fn(y_pred, y) loss.backward() if i % UPDATE_FREQ == UPDATE_FREQ - 1: + optimizer.step() grads = {n: p.grad for n, p in model.named_parameters()} results.append(clone_to_cpu_recursively([y_pred, loss, grads])) - optimizer.step() optimizer.zero_grad() weights = {n: p.data for n, p in model.named_parameters()} results[-1].append(clone_to_cpu_recursively(weights)) return results -def _gpu_worker(pas, ngpus): +def _gpu_worker(pas, ngpus, update_freq): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) - orig_results = _train(orig_module) - compiled_results = _train(compiled_module) + orig_results = _train(orig_module, update_freq) + compiled_results = _train(compiled_module, update_freq) return ( orig_results, compiled_results, @@ -125,11 +126,12 @@ def _gpu_worker(pas, ngpus): ) -def test_submodules_tp_gpu1(): +@pytest.mark.parametrize('update_freq', [1, 2, 4]) +def test_submodules_tp_gpu1(update_freq): if not torch.cuda.is_available(): print('skip test_submodules_tp_gpu1 due to lack of cuda devices') return - results = launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) + results = launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1, update_freq) orig_results, compiled_results, _, _, _, _ = results[0] for orig, compiled in zip(orig_results, compiled_results): assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred @@ -185,11 +187,12 @@ def _compare_weights(orig0, orig1, compiled0, compiled1, fc1_fullmap, fc2_fullma assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) -def test_submodules_tp_gpu2(): +@pytest.mark.parametrize('update_freq', [1, 2, 4]) +def test_submodules_tp_gpu2(update_freq): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: print('skip test_submodules_tp_gpu2 due to lack of cuda devices') return - results = launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) + results = launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2, update_freq) results0, results1 = results[0], results[1] eps = 1e-4 @@ -219,11 +222,12 @@ def test_submodules_tp_gpu2(): _compare_weights(orig0[3], orig1[3], compiled0[3], compiled1[3], fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map) -def test_submodules_dp_gpu1(): +@pytest.mark.parametrize('update_freq', [1, 2, 4]) +def test_submodules_dp_gpu1(update_freq): if not torch.cuda.is_available(): print('skip test_submodules_dp_gpu1 due to lack of cuda devices') return - results = launch_torchrun(1, _gpu_worker, PASData, 1) + results = launch_torchrun(1, _gpu_worker, PASData, 1, update_freq) orig_results, compiled_results, _, _, _, _ = results[0] for orig, compiled in zip(orig_results, compiled_results): assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred @@ -242,12 +246,13 @@ def test_submodules_dp_gpu1(): assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) -def test_submodules_dp_gpu2(): +@pytest.mark.parametrize('update_freq', [1, 2, 4]) +def test_submodules_dp_gpu2(update_freq): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: print('skip test_submodules_dp_gpu2 due to lack of cuda devices') return eps = 1e-4 - results = launch_torchrun(2, _gpu_worker, PASData, 2) + results = launch_torchrun(2, _gpu_worker, PASData, 2, update_freq) for r in results.values(): orig_results, compiled_results, _, _, _, _ = r for orig, compiled in zip(orig_results, compiled_results): diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 4fa7cfd1..aa88edfb 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -88,9 +88,9 @@ def _train(model): loss = loss_fn(y_pred, y) loss.backward() if i % UPDATE_FREQ == UPDATE_FREQ - 1: + optimizer.step() grads = {n: p.grad for n, p in model.named_parameters()} results.append(clone_to_cpu_recursively([y_pred, loss, grads])) - optimizer.step() optimizer.zero_grad() weights = {n: p.data for n, p in model.named_parameters()} results[-1].append(clone_to_cpu_recursively(weights)) From 994c7e23a6e9fd1d7a41caaa885ed15f70d53ee4 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 8 Sep 2023 09:49:34 +0000 Subject: [PATCH 1493/1892] Merged PR 1762: fix runtime_ngpus in gpt example fix runtime_ngpus --- examples/megatron_gpt/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/megatron_gpt/parallel.py b/examples/megatron_gpt/parallel.py index adae5b09..e41df6f8 100644 --- a/examples/megatron_gpt/parallel.py +++ b/examples/megatron_gpt/parallel.py @@ -1,6 +1,6 @@ import os plan_ngpus = int(os.environ['PLAN_NGPUS']) -runtime_ngpus = int(os.environ['CUBE_SCALING_FACTOR']) +runtime_ngpus = int(os.environ['CUBE_SCALING_FACTOR']) * plan_ngpus # 1. load graph from cube.graph import IRGraph From 93942e90c20b15ab7a57a4d3a4329c6ffcd89a03 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 11 Sep 2023 04:55:55 +0000 Subject: [PATCH 1494/1892] Merged PR 1787: fix parsing nested struct with tensor in model inputs --- cube/graph/parser/fx/parser.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 1fe2a293..bdb874a5 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -101,6 +101,7 @@ def parse(module: torch.fx.GraphModule, dynamic_shape (bool): whether to parse the module with dynamic shape """ from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp + from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata frame = frame if frame is not None else Frame() frame.push_var() @@ -117,35 +118,25 @@ def parse(module: torch.fx.GraphModule, for idx, input in enumerate(inputs): assert isinstance(input, torch.fx.Node) # dealing with different types of dummy_inputs - if isinstance(dummy_inputs, dict): - if input.name not in dummy_inputs: - val = IRObject(input.name) - else: - if 'tensor_meta' in input.meta: - shape = input.meta['tensor_meta'].shape - if len(shape) == 0: - shape = [1] - dtype = input.meta['tensor_meta'].dtype - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) - else: - val = IRObject(input.name, value=dummy_inputs[input.name]) + if not isinstance(dummy_inputs, dict): + raise RuntimeError('dummy_inputs should be a dict.') + if input.name not in dummy_inputs: + val = IRObject(input.name) else: - # FIXME: this part is only for transformers.tokenization_utils_base.BatchEncoding, - # extend to other input types - if hasattr(dummy_inputs, input.name): - shape = getattr(dummy_inputs, input.name).size() + if 'tensor_meta' in input.meta and isinstance(input.meta['tensor_meta'], TensorMetadata): + shape = input.meta['tensor_meta'].shape + if len(shape) == 0: + shape = [1] + dtype = input.meta['tensor_meta'].dtype + val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) else: - # FIXME: seems the kwargs name (e.g., _deprecated_arguments) is not aligned with input.name - shape = None - dtype = input.meta['tensor_meta'].dtype - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) + val = IRObject(input.name, value=dummy_inputs[input.name]) frame.add_var(input.name, val, graph_arg=idx) input_val = [frame.get_var(input.name) for input in inputs] # add activations to frame, including call_func/call_method output and final output # call_module corresponds to leaf torch.nn.module - from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata activation_op_strs = {'call_function', 'output', 'call_method', 'call_module'} activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] def parse_complex_out(meta_out): From 9bbac579028fa1a4884a1ae4f7e4c9ecf00e1380 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 11 Sep 2023 05:02:07 +0000 Subject: [PATCH 1495/1892] Merged PR 1809: skip saving tensor activation during graph parse skip saving tensor activation for during graph parse. --- cube/graph/parser/converter.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index fcb8fe70..7c4aa6ac 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -15,6 +15,7 @@ import torch import torch.fx +from torch.autograd.graph import saved_tensors_hooks _logger = logging.getLogger(__name__) @@ -40,15 +41,25 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: leaf_functions.update({func: ([(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs}) dce_ignored_funcs = set(cube_rt_funcs) - traced_model = concrete_trace( - model, - dummy_input, - use_operator_patch=True, - autowrap_leaf_function=leaf_functions, - dce_ignored_function=dce_ignored_funcs, - cpu_offload=True, - record_frames=not CompileFlag.disable_code_line_info, - ) + class no_save_tensor_hook(saved_tensors_hooks): + """skip saving tensors for backward since tracer only traces forward""" + def __init__(self): + def pack(x): + return None + def unpack(x): + raise RuntimeError("not expecting backward to be called on this tensor") + super().__init__(pack, unpack) + + with no_save_tensor_hook(): + traced_model = concrete_trace( + model, + dummy_input, + use_operator_patch=True, + autowrap_leaf_function=leaf_functions, + dce_ignored_function=dce_ignored_funcs, + cpu_offload=True, + record_frames=not CompileFlag.disable_code_line_info, + ) return traced_model From 7546b7617826fa183f3913a591348286d9dbfad9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 12 Sep 2023 09:59:03 +0000 Subject: [PATCH 1496/1892] Merged PR 1817: convenient gradient accumulation for synchronous reducer convenient gradient accumulation for synchronous reducer * For `ASYNC_REDUCER=0` (default): no more required for `cube.accum_mode` for gradient accumulation. Users can repeatly call `loss = model(data); loss.backward()` without any context manager in this case. * For `ASYNC_REDUCER=1`, `cube.accum_mode` is still required. RuntimeError will raise once user forgot calling `cube.accum_mode` at second-time backward in this case. --- cube/runtime/adapter/reducer.py | 50 +++++---- cube/utils.py | 2 + tests/runtime/test_grad_accum.py | 184 +++++++++++++++++++++++++++++++ 3 files changed, 212 insertions(+), 24 deletions(-) create mode 100644 tests/runtime/test_grad_accum.py diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index da749ee0..7f50c74c 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -63,8 +63,8 @@ def __init__(self, params: List[torch.nn.Parameter], self._reduce_op = reduce_op self._group = group self._wsz: int = torch.distributed.get_world_size(group=self._group) - self._cnt = 0 - self._work = None # communication handle + self._async_param_cnt: int = 0 # flag for triggering async communication + self._async_handle = None # asynchrounous communication handler self._hooks: List[Tuple[Any, RemovableHandle]] = [] self._async: bool = async_op @@ -147,26 +147,28 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): param.grad = None if RuntimeFlag.skip_reducer: return - - self._cnt += 1 - assert self._cnt <= len(self._params), \ - "detected double backward for a weight (not supported), or not use `model.zero_grad()` after optimizer" + self._async_param_cnt += 1 # perform all-reduce - if self._async and self._cnt == len(self._params): - # apply pre hooks - self._apply_pre_hooks() - # communication - if self._zero and Bucket.use_reduce_scatter_for_zero: - rank = torch.distributed.get_rank(group=self._group) - shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) - self._work = torch.distributed.reduce_scatter( - shards[rank], shards, op=self._reduce_op, - group=self._group, async_op=True) - else: - self._work = torch.distributed.all_reduce( - self._contiguous_grads, op=self._reduce_op, - group=self._group, async_op=True) + if self._async: + if self._async_param_cnt > len(self._params): + raise RuntimeError( + "Detected gradient accumulation with asynchronous Reducer. " + "Users should run with `cube.accum_mode` to manage gradient synchronization.") + if self._async_param_cnt == len(self._params): + # apply pre hooks + self._apply_pre_hooks() + # communication + if self._zero and Bucket.use_reduce_scatter_for_zero: + rank = torch.distributed.get_rank(group=self._group) + shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) + self._async_handle = torch.distributed.reduce_scatter( + shards[rank], shards, op=self._reduce_op, + group=self._group, async_op=True) + else: + self._async_handle = torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, + group=self._group, async_op=True) for param in self._params: # same trick with FSDP and Megatron @@ -191,8 +193,8 @@ def sync_grads(self): if CudaTimer().enabled and CudaTimer().predefined: _logger.warning( f'CudaTimer: the communication time of async reducer will not be recorded in `comm`') - assert self._work is not None - self._work.wait() + assert self._async_handle is not None + self._async_handle.wait() else: CudaTimer().start('comm', predefined=True) # apply pre-hooks @@ -288,8 +290,8 @@ def clear_post_hooks(self): def reset(self): """Reset status.""" - self._cnt = 0 - self._work = None + self._async_param_cnt = 0 + self._async_handle = None class Reducer: diff --git a/cube/utils.py b/cube/utils.py index e7a9fd84..79679732 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -77,6 +77,8 @@ def load_eval_schedule(filename: Optional[str] = None): class accum_mode: """Make cube execution in gradient accumulation mode. + This is only required when `ASYNC_REDUCER=1`. + A typical usage is: ``` diff --git a/tests/runtime/test_grad_accum.py b/tests/runtime/test_grad_accum.py new file mode 100644 index 00000000..6b8875b9 --- /dev/null +++ b/tests/runtime/test_grad_accum.py @@ -0,0 +1,184 @@ +import torch +import pytest +from functools import partial + +import cube +from cube.runtime.module import CubeModule +from ..launch_torchrun import torchrun +from ..utils import init_parameter, assert_parity + + +class MLP(CubeModule): + def __init__(self, ngpus, async_op, dim=512, nlayers=4,): + super().__init__() + ranks = list(range(ngpus)) + self.init_group(ranks=ranks) + + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + self.wreducer1 = cube.runtime.adapter.Reducer(ranks=ranks, reduce_op='sum', async_op=async_op, zero=False, + max_bucket_size_bytes=137217728, zero_ngroups=1) + for param in self.parameters(): + self.wreducer1.add_param(param) + self.add_reducer(self.wreducer1) + + def forward(self, data): + x = data + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) + return loss + + +class BaseMLP(torch.nn.Module): + def __init__(self, dim=512, nlayers=4,): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + def forward(self, data): + x = data + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) + return loss + + +def get_dummy_data(batch_size: int = 256): + torch.random.manual_seed(0) + return torch.randn( + [batch_size, 512], dtype=torch.float32, + device=torch.cuda.current_device()) + + +def baseline(accum_times: int = 4): + model = BaseMLP() + init_parameter(model) + model = model.cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + for _ in range(accum_times): + x = get_dummy_data() + loss = model(x) + loss.backward() + optimizer.step() + optimizer.zero_grad() + loss = loss.item() + while abs(loss) > 10.0: + loss /= 10.0 + losses.append(loss) + return losses + + +def reducer_sync_test(accum_times: int = 4): + ngpus = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + model = MLP(ngpus, async_op=False) + init_parameter(model) + model = model.cuda() + for reducer in model.reducers: + reducer.build_buckets() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + model.zero_grad() + for _ in range(accum_times): + x = get_dummy_data() + x = x.chunk(ngpus, dim=0)[rank] + loss = model(x) + loss.backward() + + torch.distributed.all_reduce(loss) + for reducer in model.reducers: + reducer.sync_grads() + optimizer.step() + optimizer.zero_grad() + loss = loss.item() + while abs(loss) > 10.0: + loss /= 10.0 + losses.append(loss) + return losses + + +def reducer_async_test_wrong(accum_times: int = 4): + ngpus = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + model = MLP(ngpus, async_op=True) + init_parameter(model) + model = model.cuda() + for reducer in model.reducers: + reducer.build_buckets() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + model.zero_grad() + for _ in range(accum_times): + x = get_dummy_data() + x = x.chunk(ngpus, dim=0)[rank] + loss = model(x) + loss.backward() + + torch.distributed.all_reduce(loss) + for reducer in model.reducers: + reducer.sync_grads() + optimizer.step() + optimizer.zero_grad() + loss = loss.item() + while abs(loss) > 10.0: + loss /= 10.0 + losses.append(loss) + return losses + + +def reducer_async_test_correct(accum_times: int = 4): + ngpus = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + model = MLP(ngpus, async_op=True) + init_parameter(model) + model = model.cuda() + for reducer in model.reducers: + reducer.build_buckets() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + losses = [] + for _ in range(3): + model.zero_grad() + for step in range(accum_times): + with cube.accum_mode(begin=(step == 0), end=(step == accum_times - 1)): + x = get_dummy_data() + x = x.chunk(ngpus, dim=0)[rank] + loss = model(x) + loss.backward() + + torch.distributed.all_reduce(loss) + for reducer in model.reducers: + reducer.sync_grads() + optimizer.step() + optimizer.zero_grad() + loss = loss.item() + while abs(loss) > 10.0: + loss /= 10.0 + losses.append(loss) + return losses + + +def accum_test(): + cube.init() + print('starting reducer sync') + assert_parity(baseline, partial(reducer_sync_test, 4)) + print('starting reducer async') + assert_parity(baseline, partial(reducer_async_test_correct, 4)) + # FIXME: this will hang: + # print('starting reducer async wrong') + # with pytest.raises(RuntimeError): + # assert_parity(baseline, partial(reducer_async_test_wrong, 4)) + +test_accum_2gpu = partial(torchrun, 2, accum_test) + From b01dc0a81e32eb3ceb2aec01b3341e2380357964 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 13 Sep 2023 09:02:46 +0000 Subject: [PATCH 1497/1892] Merged PR 1811: Fix merging ckpt bugs 1. fix bug when merging only one shard 2. fix bug that ckpt size is larger than expected when merging from zero ckpts --- cube/graph/function/function.py | 1 + cube/runtime/module.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 604e6b8f..7a17ae9c 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1021,6 +1021,7 @@ def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: identifier = ifirst[dim] oidx = ofirst.index(identifier) if isinstance(kwargs[kwarg_name], IRObject): + _logger.warning(f'partition size in IRObject: {kwargs[kwarg_name]}') size = list(kwargs[kwarg_name].value) else: size = list(kwargs[kwarg_name]) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 21fa9476..6853865b 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -135,8 +135,6 @@ def merge_partial_states(state_dicts, zero_idx_maps=None): :return: merged state_dict(model_state_dict, optimizer_state_dict,) """ assert len(state_dicts) > 0 - if len(state_dicts) == 1: - return state_dicts[0][0], state_dicts[0][1] plan_ngpus = -1 # TODO: remove this flag @@ -306,7 +304,8 @@ def _check_opt_state(opt_state): tensor_size = [] for dim_slice in tensor_size_slice: tensor_size.append(dim_slice.stop) - param_full_tensors[raw_name] = torch.zeros(tuple(tensor_size)) + partial_tensor = model_state_dict[local_name_with_id] + param_full_tensors[raw_name] = torch.zeros(tuple(tensor_size), dtype=partial_tensor.dtype) index = model_state_dict_keys.index(local_name_with_id) if index in optimizer_state_dict['state']: From f057fa513338e9d625d615b84c4bef41f77b011b Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 13 Sep 2023 09:26:10 +0000 Subject: [PATCH 1498/1892] Merged PR 1827: Force clone in nn.allreduce temporary fix for multiref + inplace operation error --- cube/runtime/adapter/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index 7619ba68..7089ae8a 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -9,6 +9,8 @@ def _allreduce(itensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: CudaTimer().start(field_name='comm', predefined=True) if not itensor.is_contiguous(): itensor = itensor.contiguous() + # force allreduce not to be in-place + itensor = itensor.detach().clone() group = DeviceGroup().get_group(ranks) torch.distributed.all_reduce(itensor, group=group) CudaTimer().stop(field_name='comm', predefined=True) From c098c08cca0e4de3686c72ac8b7a9403bbf575a5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 15 Sep 2023 06:58:34 +0000 Subject: [PATCH 1499/1892] Merged PR 1720: flexible schedule support for all kinds of operator placement Flexible pipeline schedule description and runtime execution support this should be co-merged with https://msrasrg.visualstudio.com/SuperScaler/_git/AutoDist/pullrequest/1778 --- cube/codegen/emit.py | 34 +- cube/codegen/module/module.py | 12 +- cube/codegen/schedule/schedule.py | 52 +-- cube/compiler.py | 1 - cube/execplan/execplan.py | 77 +--- cube/execplan/planpass/fusion.py | 8 +- cube/graph/gener/gen.py | 2 +- cube/graph/graph.py | 447 +++++++++++----------- cube/graph/schedule/__init__.py | 2 +- cube/graph/schedule/predefined.py | 134 ++++++- cube/graph/schedule/sched1f1b.py | 125 ------ cube/graph/schedule/schedinfer.py | 93 ----- cube/graph/schedule/schedmix.py | 189 --------- cube/graph/schedule/schednf1b.py | 77 ---- cube/graph/schedule/schedplan.py | 362 +++++++++--------- cube/graph/schedule/strategy.py | 50 --- cube/runtime/__init__.py | 1 - cube/runtime/adapter/collectives.py | 31 +- cube/runtime/schedule/__init__.py | 4 - cube/runtime/schedule/sched1f1b.py | 118 ------ cube/runtime/schedule/schedinfer.py | 25 -- cube/runtime/schedule/schedmix.py | 213 ----------- cube/runtime/schedule/schednf1b.py | 292 -------------- cube/runtime/schedule/strategy.py | 133 ------- tests/runtime/test_runtime_collectives.py | 19 - 25 files changed, 554 insertions(+), 1947 deletions(-) delete mode 100644 cube/graph/schedule/sched1f1b.py delete mode 100644 cube/graph/schedule/schedinfer.py delete mode 100644 cube/graph/schedule/schedmix.py delete mode 100644 cube/graph/schedule/schednf1b.py delete mode 100644 cube/graph/schedule/strategy.py delete mode 100644 cube/runtime/schedule/__init__.py delete mode 100644 cube/runtime/schedule/sched1f1b.py delete mode 100644 cube/runtime/schedule/schedinfer.py delete mode 100644 cube/runtime/schedule/schedmix.py delete mode 100644 cube/runtime/schedule/schednf1b.py delete mode 100644 cube/runtime/schedule/strategy.py diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index c91ad547..39c89628 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -5,7 +5,7 @@ from cube.ir.tensor import IRSubTensor from cube.ir.operator import IRDataOperation, IRFwOperation from cube.ir.adapter import IRWeightReducer, IRAdapter -from cube.ir.adapter.prim import IRAdapterPrim +from cube.ir.adapter.prim import CommPrim from cube.graph.segment import IRSegment @@ -169,7 +169,8 @@ def emit_fnode(self, node: IRFwOperation, prefix_attr: str = None) -> List[str]: codes.append(code) return codes - def emit_adapter(self, node: IRAdapter, prefix_attr: Optional[str] = None) -> List[str]: + def emit_adapter(self, node: IRAdapter, prefix_attr: Optional[str] = None, + async_op: bool = False) -> List[str]: """ Emit the statment of the adapter call @@ -177,22 +178,29 @@ def emit_adapter(self, node: IRAdapter, prefix_attr: Optional[str] = None) -> Li Python method for the targeted Segment, without the method signature and the return statement. - The fields storing intermediate codes that are populated by this method: - - NONE + Args: + node (IRAdapter) + prefix_attr (str | None): prefix to the tensor name + async_op (bool): whether to enable async communication """ codes = [] assert len(node.device) == 1, f"Expected adapter to be dispatched:\n{node.extra_repr()}" prims = [node] if node.differentiable and node.custom else [prim for prim in node.prims] - # only adapter that is non-differentiable can be executed as async - async_op = CompileFlag.async_comm and (not node.differentiable) if async_op: - for idx, prim in enumerate(prims): - if isinstance(prim, IRAdapterPrim) and prim.volume() == 0: - continue - break - #TODO: support more general cases: independent same-group primitives - async_op = False if len(prims[idx:]) != 1 else async_op + # note async_op can only be applied when primitives satisfy: + # 1) non-collective primitives perform before collective primitives. + # 2) collectives running on same nccl stream (i.e., same device group) + non_colls = [p for p in prims if not isinstance(p, CommPrim)] + colls = [p for p in prims if isinstance(p, CommPrim)] + # check condition 1) + if len(non_colls) > 1: + if max(prims.index(p) for p in non_colls) + 1 != len(non_colls): + async_op = False + # check condition 2) + devices = [set(p.device) for p in colls] + if len(colls) > 1 and not all(devs == devices[0] for devs in devices[1:]): + async_op = False for prim in prims: if len(prim.inputs()) == 1: @@ -200,7 +208,7 @@ def emit_adapter(self, node: IRAdapter, prefix_attr: Optional[str] = None) -> Li else: itensors = self.tuple_name(prim.inputs(), prefix_attr=prefix_attr) prim_kwargs = dict(prim.kwargs) - if async_op: + if async_op and isinstance(prim, CommPrim): prim_kwargs['async_op'] = True kwargs = self.kwargs_name(**prim_kwargs) outputs = self.return_name(prim.outputs()) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index a5ed0192..2f638a30 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import more_itertools import logging import copy @@ -13,14 +13,12 @@ from cube.graph.graph import IRSegment from cube.graph.parser.register import CustomizedOps -from cube.graph.parser.fx.parser import FxModuleParser from cube.execplan import ExecutionPlan from cube.execplan.execplan import ExeReuseCell from cube.codegen.syntax.symtable import SymbolTable -from cube.codegen.syntax.blocks import ClassBlock, ForBlock, FunctionBlock -from cube.codegen.schedule.schedule import ScheduleCodeGen +from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock from cube.codegen.emit import FuncEmission from cube.codegen.module.autograd import AutogradAdapterCodeGen @@ -323,7 +321,7 @@ def gen( elif isinstance(node, IRFwOperation): raise RuntimeError(f"Unexcepted global-level op call: {node}") elif isinstance(node, IRAdapter): - codes = self.emit_adapter(node, prefix_attr='self.') + codes = self.emit_adapter(node, prefix_attr='self.', async_op=CompileFlag.async_comm) elif isinstance(node, IRWeightReducer): self.init_reducer(node, device) codes = self.emit_reducer(node) @@ -604,7 +602,9 @@ def _emit_nodes(self, nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: code = self.emit_fnode(node, prefix_attr='self.') node_codes += code elif isinstance(node, IRAdapter): - code = self.emit_adapter(node) + # for adapters inside an IRSegment, we don't apply async communication to it + # as it is mostly in critical path. + code = self.emit_adapter(node, async_op=False) node_codes += code else: raise RuntimeError(f"unexpected type {type(node)} in IRSegment") diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index a62765df..8c3a2e42 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -1,5 +1,5 @@ -from typing import List, Dict, Any, Optional, Tuple +from typing import List, Optional, Tuple import copy import logging @@ -9,14 +9,12 @@ from cube.ir.adapter import IRWeightReducer, IRAdapter from cube.graph.graph import IRSegment -from cube.graph.schedule import IRScheduleStrategy - -from cube.execplan.execplan import ExecutionPlan, ExeRepetend, ExeReuseCell +from cube.execplan.execplan import ExecutionPlan, ExeReuseCell from cube.codegen.emit import FuncEmission from cube.codegen.syntax.symtable import SymbolTable from cube.codegen.lifecycle import LifeCycle -from cube.codegen.syntax.blocks import FunctionBlock, ForBlock +from cube.codegen.syntax.blocks import FunctionBlock _logger = logging.getLogger(__name__) @@ -64,10 +62,6 @@ def gen(self, device: int, outfile=None, attach=None) -> str: # body code if len(device_nodes) == 0: fb.insert_body('pass') - # legacy hardcode strategy - elif isinstance(self.execplan.graph.sched, IRScheduleStrategy): - code = self.emit_legacy_schedplan(self.execplan.graph.sched, device) - fb.insert_body(code) else: for line, node in enumerate(device_nodes): # execute @@ -89,12 +83,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: if not any(not node.isfw() for node in device_nodes): gencode += ['_infer_step = _train_step'] else: - # legacy hardcode strategy - if isinstance(self.execplan.graph.sched, IRScheduleStrategy): - _logger.warning('using legacy IRScheduleStrategy cannot generate inference code. ' - 'Switch to use scheduling without strategy') - with FunctionBlock(func_name='_infer_step', - args=args) as fb: + with FunctionBlock(func_name='_infer_step', args=args) as fb: fb.insert_body('_ = None') # body code if len(device_nodes) == 0: @@ -169,7 +158,7 @@ def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: ) elif isinstance(unwrap_node, IRDataOperation): - code = self.emit_dataloader(unwrap_node)[0] + code = f'{outputs} = {unwrap_node.signature}(*{inputs})' elif isinstance(unwrap_node, IRAdapter): code = asign.format( @@ -191,34 +180,3 @@ def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: raise RuntimeError(f"Unspported node type: {type(unwrap_node)}") return [code] - - def emit_repetend(self, repetend: ExeRepetend) -> List[str]: - """ - Emit code for executing a repetend - """ - with ForBlock(var=None, iters=f'range({repetend.repeat})') as fb: - for node in repetend.nodes(): - ncode = self.emit_node(node) - fb.insert_body(ncode) - return fb.code - - def emit_legacy_schedplan(self, schedplan: IRScheduleStrategy, devid: int) -> List[str]: - """ - Lagecy code - """ - signature = schedplan.signature - kwargs: Dict[str, Any] = schedplan.kwargs(devid) - strkwargs = dict() - for kwarg, val in kwargs.items(): - if isinstance(val, IRCell): - name = 'model.' + self.node_name(val) - elif isinstance(val, (tuple, list)): - brackets = ')' if len(val) != 1 else ',)' - name = '(' + ', '.join('model.' + self.node_name(n) \ - if isinstance(n, IRCell) else str(n) for n in val) + brackets - else: - name = str(val) - strkwargs[kwarg] = name - code = ', '.join(f'{kwarg}={name}' for kwarg, name in strkwargs.items()) - code = f'{signature}({code})' - return [code] diff --git a/cube/compiler.py b/cube/compiler.py index ab2541ea..dafe120d 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -191,7 +191,6 @@ def decorator(fn: Callable) -> Callable: if graph.sched is not None: start = time.time() graph.sched.apply() - _logger.debug(f'schedule:\n{graph.sched}') span = time.time() - start _logger.info('finish planpass on applying schedule strategy: {:.2f} s'.format(span)) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 5935b409..34d34bde 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -8,7 +8,7 @@ from cube.ir.adapter import IRAdapter, IRWeightReducer from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation from cube.graph.graph import IRGraph, IRSegment -from cube.graph.schedule.schedplan import SchedulePlan, Block, Repetend +from cube.graph.schedule.schedplan import SchedulePlan, Block class ExeReuseCell(IRCell): @@ -72,61 +72,6 @@ def __repr__(self) -> str: return f'ReuseCell-{self.device}(name={self._cell.name}{self._cell.cid}, inputs={self.inputs()}, outputs={self.outputs()})' -class ExeRepetend(IRCell): - """ - A cell that will be repeatedly executed for multiple times - on a sequence of nodes - """ - - def __init__(self, nodes: List[IRCell], repeat: int = 1): - super().__init__('repetend', 'None', 0, 0, init_outputs=False) - self._nodes: List[IRCell] = nodes - self._repeat = repeat - - @property - def repeat(self) -> int: - return self._repeat - - @property - def device(self) -> Tuple[int]: - device = set() - for node in self._nodes: - device.update(node.device) - return tuple(device) - - def nodes(self) -> Tuple[IRCell]: - return tuple(self._nodes) - - def isfw(self) -> bool: - return all(n.isfw() for n in self._nodes) - - def dispatch(self, devid: int) -> IRCell: - nodes = [] - for n in self._nodes: - if devid in n.device: - nodes.append(n.dispatch(devid)) - repetend = ExeRepetend(nodes, self.repeat) - repetend._id = self._id - - def add(self, node: IRCell): - """ - Append a node - """ - self._nodes.append(node) - - def pop(self, index: int) -> IRCell: - return self._nodes.pop(index) - - def remove(self, node: IRCell): - return self._nodes.remove(node) - - def __repr__(self) -> str: - dscp = f'Repetend{self.cid}-{self.device}(repeat={self.repeat}\n' - for n in self._nodes: - dscp += ' ' + str(n) + '\n' - dscp += ')' - return dscp - class ExecutionPlan: """ @@ -181,15 +126,7 @@ def block2reuse(node: Block) -> ExeReuseCell: topo_seqs: List[IRCell] = [] for block in schedplan.nodes(): - # convert repetends and blocks - if isinstance(block, Repetend): - nodes: List[ExeReuseCell] = [] - for node in block.nodes(): - if isinstance(node, Block): - node = block2reuse(node) - nodes.append(node) - block = ExeRepetend(nodes, repeat=block.span) - elif isinstance(block, Block): + if isinstance(block, Block): block = block2reuse(block) assert isinstance(block, IRCell) topo_seqs.append(block) @@ -207,9 +144,6 @@ def __init__(self, graph: IRGraph, topo_seqs: List[IRCell]): for device in node.device: self._seq.setdefault(device, []).append(node) - # due to repetends, a same node could appear multiple times - # in the execution sequence. For this case, all of them - # will be replaced by a same dispatched one. def cached_dispatch(node: IRCell, devid: int, dispatched: Dict[IRCell, IRCell]) -> IRCell: """Cached dispatch""" @@ -228,12 +162,7 @@ def cached_dispatch(node: IRCell, devid: int, node = nodes[idx] # print(f'handling {node}') if len(node.device) == 1: continue # no need for dispatch - if isinstance(node, ExeRepetend): - rnodes = [cached_dispatch(n, devid, dispatched) \ - for n in node.nodes() if devid in n.device] - dnode = ExeRepetend(rnodes, node.repeat) - else: - dnode = cached_dispatch(node, devid, dispatched) + dnode = cached_dispatch(node, devid, dispatched) nodes[idx] = dnode @property diff --git a/cube/execplan/planpass/fusion.py b/cube/execplan/planpass/fusion.py index 969abf7f..ebe5024e 100644 --- a/cube/execplan/planpass/fusion.py +++ b/cube/execplan/planpass/fusion.py @@ -5,7 +5,7 @@ from cube.ir.adapter import IRAdapter from cube.execplan import ExecutionPlan -from cube.execplan.execplan import ExeRepetend, ExeReuseCell +from cube.execplan.execplan import ExeReuseCell from cube.execplan.planpass.planpass import PlanPass from cube.ir.adapter.prim import IRAdapterPrim @@ -43,16 +43,12 @@ def apply(execplan: ExecutionPlan) -> ExecutionPlan: for fadapter in node.select(ntype=IRAdapter): ret = DiffFusion.nnfuse(fadapter) cnt = cnt+1 if ret else cnt - elif isinstance(node, ExeRepetend) and node.isfw(): - for fadapter in [n for n in node.nodes() if isinstance(n, IRAdapter)]: - ret = DiffFusion.nnfuse(fadapter) - cnt = cnt+1 if ret else cnt visited.add(node) _logger.info(f'adapter fusion: successfully fuse {cnt} differentiable adapters') return execplan @staticmethod - def _apply(cell: Union[IRSegment, ExeRepetend]) -> int: + def _apply(cell: IRSegment) -> int: cnt = 0 for node in cell.nodes(): if isinstance(node, IRAdapter) and node.isfw(): diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 493f6d79..99631a87 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -119,7 +119,7 @@ def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: # generate weight reducer graph = IRAdapterGener.gen_weight(graph) # fuse consecutive non-differentiable adapters into one - graph = IRAdapterGener.fusion(graph) + # graph = IRAdapterGener.fusion(graph) # print(graph.extra_repr()) return graph diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 23b6e62c..690a0d6c 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -7,7 +7,7 @@ will be inserted at scheduling time. """ -from typing import Sequence, Set, Union, Tuple, List, Optional, Dict, Any +from typing import Union, Tuple, List, Optional, Dict, Any import logging import copy import dill @@ -172,66 +172,6 @@ def backward(self, loss: Optional[IRSubTensor] = None): return self - - # ========================= Graph Manipulation ======================== - - def group(self, nodes: List[IRCell]) -> IRSegment: - """! - Group consecutive nodes into IRSegment. - Note nodes should not have applied by any transformation. - - @param nodes List[IRCell]: consecutive nodes in forward procedure - - @return segment IRSegment: the grouped segment - """ - assert all(node.isfw() for node in nodes), f"Expected all nodes in forward procedure" - fgraphs = [self.segment(fnode) for fnode in nodes] - assert len(set(fgraphs)) == 1, "cross-segment grouping is not allowed yet." - - fgraph: IRSegment = fgraphs[0] - findices: Tuple[int] = tuple(fgraph.index(node)[0] for node in nodes) - min_fidx, max_fidx = min(findices), max(findices) - assert max_fidx - min_fidx + 1 == len(nodes), "nodes should be in consecutive order" - - fsegment: IRSegment = fgraph.create_segment(nodes) - for node in nodes: - idx = fgraph.remove(node) - fgraph.insert(fsegment, idx) - - # group for mirror nodes - bnodes = [node.mirror for node in nodes if node.mirror is not None] - if len(bnodes) == 0: return fsegment - - # check consecutive - bgraph: IRSegment = fgraph.mirror - bindices = [bgraph.index(bnode)[0] for bnode in bnodes] - min_bidx, max_bidx = min(bindices), max(bindices) - assert max_bidx - min_bidx + 1 == len(bnodes), \ - f"backward nodes are not consecutive. minbidx: {min_bidx}, maxbidx: {max_bidx}" - - # update gradient for fgraph - for itensor in fsegment.inputs(): - if not isinstance(itensor, IRTensor): continue - fgraph.infer_grad(itensor.parent) - # update gradient inside segment - for ftensor in fsegment.full_tensors(): - fsegment.infer_grad(ftensor) - - # create backward segment - for bnode in bnodes: - bidx = bgraph.remove(bnode) - bnodes = [fsegment.create_bwop(fnode) for fnode in nodes[::-1] if fnode.mirror is not None] - # get backward graph inputs - output_grads = [t.grad for t in fsegment.outputs() if isinstance(t, IRSubTensor) and t.grad is not None] - # get backward graph outputs - input_grads = [t.grad for t in fsegment.inputs() if \ - isinstance(t, IRSubTensor) and t.grad is not None] - bsegment = IRSegment(bnodes, output_grads, input_grads) - - bgraph.insert(bsegment, bidx) - IRCell.make_pair(fsegment, bsegment) - return fsegment - # ========================== Graph Creation ======================== @staticmethod @@ -596,84 +536,95 @@ def reside(self, tensor: IRSubTensor, devices: Union[int, List[int]]): ## Schedule Policy Primitives ## - def sequential(self, nodes: Sequence[Union[FOp, Set[FOp]]]): - """ - Scheduling Primitive: sequentially execute a list of nodes, - or a list of concurrent nodes. - - Note there should be no dependency from a later node (set) to a previous node (set). - - Note in current implementation we don't check correctness + def sequential(self, prev_nodes: Tuple[IRFwOperation], succ_nodes: Tuple[IRFwOperation]): + """Schedule primitive: schedule prev_nodes right before the succ_nodes + + The position of `succ_nodes` will keep unchanged in the sequence + while the `prev_nodes` will be scheduled right before the `succ_nodes`. + Corresponding backward operators will also be re-ordered. - Currently only support node (set) from a same device. + The `prev_nodes` should be consecutive in the sequence. + The `succ_nodes` should be consecutive in the sequence. - @param nodes Sequence[Set[FOp]]: a sequence of operators or - a sequence of concurrent operators. Note there should be no - """ - assert len(nodes) > 0 - concurrent_groups = [[node] if isinstance(node, IRCell) else node for node in nodes] - segment: IRSegment = self.segment(concurrent_groups[0][0]) - idx = segment.index(nodes[0]) - for group in concurrent_groups[1:]: - for node in group: - assert segment.exist(node, flatten=False), "All nodes should in a same segment" - # TODO: should check every node to see if they can be gathered based on that node - segment.reorder(node, idx) - - def concurrent(self, nodes: Set[Union[FOp, Sequence[FOp]]]): - """ - Scheduling Primitive: concurrently execut a list of nodes, - or a list of sequential nodes. - - Note there should be no dependency from a node (set) to another node (set). - - Currently only suuport node (set) from different devices. - - @param nodes Set[Sequence[Fop]]: a set of operators or - a set of sequential operators. - """ - assert len(nodes) > 0 - seq_groups = [[node] if isinstance(node, IRCell) else node for node in nodes] - segment: IRSegment = self.segment(seq_groups[0][0]) - idx = segment.index(nodes[0]) - for group in seq_groups[1:]: - for node in group: - assert segment.exist(node, flatten=False), "All nodes should in a same segment" - # TODO: should check every node to see if they can be gathered based on that node - segment.reorder(node, idx) - - def happen_before(self, node1: IRCell, node2: IRCell, skip=None) -> bool: - """ - Check node1 -> (happen before) node2 + Args: + prev_nodes (Tuple[IRFwOperation]): the nodes to be scheduled right before `succ_nodes` + succ_nodes (Tuple[IRFwOperation]): the nodes to be executed right after `prev_nodes` Returns: - Boolean + None """ - raise NotImplementedError("dependency is not supported yet") - skip = list() if skip is None else skip - if node1 in skip: - return False - if not isinstance(node1, IRCell) or not isinstance(node2, IRCell): - raise TypeError("Expected node to be IRCell") - if node2 in node1.successors(): - return True - else: - for succ_node in node1.successors(): - if self.happen_before(succ_node, node2, skip): - return True - return False - - def depends(self, pre_node: IRCell, post_node: IRCell) -> bool: - """! - Check whether pre_node has dataflow dependency on post_node: - pre_node -> post_node + prev_indices = [self._nodes.index(n) for n in prev_nodes] + succ_indices = [self._nodes.index(n) for n in succ_nodes] + if len(prev_nodes) != max(prev_indices) - min(prev_indices) + 1: + raise ValueError( + f'prev_nodes are expected to be consecutive in node sequence: ' + f'{len(prev_nodes)} != {max(prev_indices) - min(prev_indices) + 1}' + ) + if len(succ_nodes) != max(succ_indices) - min(succ_indices) + 1: + raise ValueError( + f'succ_nodes are expected to be consecutive in node sequence: ' + f'{len(succ_nodes)} != {max(succ_indices) - min(succ_indices) + 1}' + ) + # check duplication + if len(set(prev_indices)) != len(prev_indices): + raise ValueError(f'find duplicated node in prev nodes') + if len(set(succ_indices)) != len(succ_indices): + raise ValueError(f'find duplicated node in succ nodes') + if len(set(prev_indices).intersection(set(succ_indices))) != 0: + raise ValueError(f'find duplicated node in both succ_nodes and prev_nodes') + # TODO: check dependency + + seq = list(self._nodes) + # cut out prev_nodes + fstart, fend = min(prev_indices), max(prev_indices) + 1 + fnodes = seq[fstart:fend] + seq = seq[:fstart] + seq[fend:] + # insert prev_nodes + ofst = min(succ_indices) + if max(prev_indices) < min(succ_indices): + ofst = ofst - len(fnodes) + seq = seq[:ofst] + fnodes + seq[ofst:] + + # update order of backward node + prev_bnodes = [n.mirror for n in prev_nodes[::-1] if n.mirror is not None] + succ_bnodes = [n.mirror for n in succ_nodes[::-1] if n.mirror is not None] + prev_bindx = [seq.index(n) for n in prev_bnodes] + succ_bindx = [seq.index(n) for n in succ_bnodes] + if len(prev_bnodes) > 0: + # TODO: extend succ_nodes to find at least one forward op that has backward + if len(succ_bnodes) == 0: + raise NotImplementedError(f'backward of succ_nodes are expected') + # cut out prev_backward_nodes + bstart, bend = min(prev_bindx), max(prev_bindx) + 1 + bnodes = seq[bstart:bend] + seq = seq[:bstart] + seq[bend:] + # insert prev_backward_nodes + ofst = max(succ_bindx) + 1 + if max(prev_bindx) < min(succ_bindx): + ofst = ofst - len(bnodes) + seq = seq[:ofst] + bnodes + seq[ofst:] + # update sequence + self._nodes = seq + + def depends(self, pre_node: IRCell, succ_node: IRCell) -> bool: + """Check direct data dependency between two nodes. + + Check dependency of pre_node -> post_node. + + Note this function only checks direct data dependency that whether + the outputs in `prev_node` and inputs in `post_node` have data dependency. + + The function cannot detect data dependency in graph like: + pre_node -> (some nodes) ... -> post_node - @param pre_node: the happen before node - @param post_node: the happen after node + Args: + pre_node (IRCell): the happen before node + post_node (IRCell): the happen after node - @return ret bool: True if post_node depends on pre_node on dataflow, otherwise False. + Returns: + ret (bool): True if post_node depends on pre_node on dataflow, otherwise False. """ - itensors = [t for t in post_node.inputs() if isinstance(t, IRSubTensor)] + itensors = [t for t in succ_node.inputs() if isinstance(t, IRSubTensor)] for otensor in pre_node.outputs(): if not isinstance(otensor, IRSubTensor): continue for itensor in itensors: @@ -681,118 +632,160 @@ def depends(self, pre_node: IRCell, post_node: IRCell) -> bool: return True return False - def schedule(self, node1: IRCell, action: str, node2: IRCell) -> bool: - """! - Schedule node1 and node2 based on the action - - The node2 will keep unchanged in the sequence and schedule will perform - on node1. - - @param node1 IRCell - @param node2 IRCell - @param action str: - 'after': fixed node2 and schedule node1 after node2 in the sequence. - 'before': fixed node2 and schedule node1 before node2 in the sequence. - - @return success bool: True if the scheduling success otherwise False. - """ - idx1 = self._nodes.index(node1) - idx2 = self._nodes.index(node2) - # node2 -> node1 - if action == 'after': - if idx2 < idx1: - return True - for idx in range(idx1+1, idx2+1): - if self.depends(node1, self._nodes[idx]): - return False - self.remove(node1) - self.insert(node1, idx2) - return True - # node1 -> node2 - if action == 'before': - if idx1 < idx2: - return True - for idx in range(idx2, idx1): - if self.depends(self._nodes[idx], node1): - return False - self.remove(node1) - self.insert(node1, idx2) - return True - raise KeyError(f"Unknown scheduling action {action}") - @property def sched(self): - """! - Return schedule plan for the execution. + """ Get bound schedule plan + + Returns: + sched (SchedulePlan | None): bound schedule plan """ return self._sched - def predef_sched(self, strategy): - """! - Set schedule plan for the execution. + def _bind_schedule(self, schedplan): + """Set schedule plan for the execution - @param strategy IRScheduleStrategy: the schedule strategy instance - """ - self._sched = strategy + This will be called when initiating a schedule plan for the graph. - def _bind_schedule(self, schedplan): - """ - Set schedule plan for the execution + Args: + schedplan (SchedulePlan) - @param schedplan SchedulePlan + Returns: + None """ - assert self._sched is None, "The graph is already binded with one schedule plan." + from cube.graph.schedule import SchedulePlan + if not isinstance(schedplan, SchedulePlan): + raise TypeError(f"Expect a SchedulePlan but got: {type(schedplan)}") + assert self._sched is None, "The graph is already bound with one schedule plan." self._sched = schedplan - @staticmethod - def legal_schedule(seq: List[IRCell], integrity_check=False): - """ - Check whether seq satisfies topological order. + # ================= staging primitives ================== - @note: this functionality is not enabled due to predecessor and succesor - functionality. + def group(self, nodes: List[IRCell]) -> IRSegment: + """Group consecutive nodes into IRSegment. - @param seq List[IRCell]: the nodes in scheudled order - @param integrity_check bool: - If true, performs additional integrity check that requires - all the nodes in predecessor and successor of a node should - appear in the sequence. + Note nodes should not have applied by any transformation. - @return valid bool: True for satisfying topo order, otherwise False. - """ - for index, node in enumerate(seq): - for pre in node.predecessors(): - if pre in seq: - pre_idx = seq.index(pre) - if pre_idx >= index: - return False - elif integrity_check: - return False - return True + Args: + nodes List[IRCell]: consecutive nodes in forward procedure - def add_schedule(self, nodes: List[IRCell]) -> bool: + Returns: + segment IRSegment: the grouped segment """ - Add node happen before dependencies according to nodes list order + assert all(node.isfw() for node in nodes), f"Expected all nodes in forward procedure" + fgraphs = [self.segment(fnode) for fnode in nodes] + assert len(set(fgraphs)) == 1, "cross-segment grouping is not allowed yet." + + fgraph: IRSegment = fgraphs[0] + findices: Tuple[int] = tuple(fgraph.index(node)[0] for node in nodes) + min_fidx, max_fidx = min(findices), max(findices) + assert max_fidx - min_fidx + 1 == len(nodes), "nodes should be in consecutive order" + + fsegment: IRSegment = fgraph.create_segment(nodes) + for node in nodes: + idx = fgraph.remove(node) + fgraph.insert(fsegment, idx) + + # group for mirror nodes + bnodes = [node.mirror for node in nodes if node.mirror is not None] + if len(bnodes) == 0: return fsegment + + # check consecutive + bgraph: IRSegment = fgraph.mirror + bindices = [bgraph.index(bnode)[0] for bnode in bnodes] + min_bidx, max_bidx = min(bindices), max(bindices) + assert max_bidx - min_bidx + 1 == len(bnodes), \ + f"backward nodes are not consecutive. minbidx: {min_bidx}, maxbidx: {max_bidx}" + + # update gradient for fgraph + for itensor in fsegment.inputs(): + if not isinstance(itensor, IRTensor): continue + fgraph.infer_grad(itensor.parent) + # update gradient inside segment + for ftensor in fsegment.full_tensors(): + fsegment.infer_grad(ftensor) + + # create backward segment + for bnode in bnodes: + bidx = bgraph.remove(bnode) + bnodes = [fsegment.create_bwop(fnode) for fnode in nodes[::-1] if fnode.mirror is not None] + # get backward graph inputs + output_grads = [t.grad for t in fsegment.outputs() if isinstance(t, IRSubTensor) and t.grad is not None] + # get backward graph outputs + input_grads = [t.grad for t in fsegment.inputs() if \ + isinstance(t, IRSubTensor) and t.grad is not None] + bsegment = IRSegment(bnodes, output_grads, input_grads) + + bgraph.insert(bsegment, bidx) + IRCell.make_pair(fsegment, bsegment) + return fsegment + + def blocking(self, nodes: Tuple[IRFwOperation]): + """Group forward operators into blocks. + + The corresponding backward operators (if have) will also be grouped into stages + Cross-stage dataflow will be limited to neighbor stages. + This should be called before any operator partition. + + Args: + nodes Tuple[IRFwOperations]: the start forward node of each stage. + + Returns: + None """ - if not all([isinstance(node, IRCell) for node in nodes]): - raise TypeError("Expected List[IRCell") - for idx in range(len(nodes) - 1): - prev = nodes[idx] - post = nodes[idx + 1] - if self.happen_before(post, prev): - return False - for idx in range(len(nodes) - 1): - prev = nodes[idx] - post = nodes[idx + 1] - prev.add_successor(output_index=-1, cell=post) - post.add_predecessor(input_index=-1, cell=prev) - return True + assert all(isinstance(node, IRFwOperation) for node in nodes), \ + f"Find node is not IRFwOperation or IRDataOperation: {node}" + assert all(node in self._nodes for node in nodes), \ + f"Exist node is not in graph nodes" + starts = list(self._nodes.index(node) for node in nodes) + assert len(starts) > 0 - # ================= staging primitives ================== + # multiref (created by graph.multiref) will be moved to the next stage (if possible) for optimization + for sid in range(len(starts)): + while starts[sid] > 0: + node = self.node(starts[sid]-1) + if node.name == 'multiref' or isinstance(node, IRGraphAnchor): + starts[sid] -= 1 + continue + break + + # adjust the start of the first stage to involve beginning operators + for idx in range(starts[0]): + node = self.node(idx) + if isinstance(node, IRDataOperation): + continue + assert isinstance(node, IRFwOperation), \ + f"Expected nodes previous from the first stage are all IRFwOperation, but got {type(node)}" + if node.name == 'multiref' or isinstance(node, IRPyFunc): + pass + else: + _logger.warning(f'Detect a node: {node} that is previous from the first stage. Will be included inside the first stage') + starts[0] = idx + break + + last_fidx = 0 + for idx, node in enumerate(self._nodes): + if not isinstance(node, IRBpOperation): + last_fidx = idx + + fstages: List[List[IRCell]] = [] + for sid in range(len(starts)): + begin = starts[sid] + end = starts[sid+1] if sid != len(starts) - 1 else last_fidx + 1 + if begin >= end: + _logger.warning(f"Detected stage {sid} doesn't have operators: [begin({begin}): end({end})). Skipped") + continue + fnodes = self._nodes[begin:end] + assert all(isinstance(node, IRFwOperation) for node in fnodes), \ + f"find at least one nodes are not of IRFwOperation in the stage {sid}. They should be moved to the front" + fstages.append(fnodes) + + # grouping into segment + for sid in range(len(fstages)): + self.group(fstages[sid]) def staging(self, nodes: Tuple[IRFwOperation]): - """! - Group forward operators into sequential stages. + """Group forward operators into sequential stages. + The corresponding backward operators (if have) will also be grouped into stages Cross-stage dataflow will be limited to neighbor stages. This should be called before any operator partition. @@ -823,8 +816,11 @@ def staging(self, nodes: Tuple[IRFwOperation]): stage 5: t5 = identity(t4) xx = consume(t5) - @param nodes Tuple[IRFwOperations]: the start forward node of each stage. - @return None + Args: + nodes Tuple[IRFwOperations]: the start forward node of each stage. + + Returns: + None """ assert all(isinstance(node, IRFwOperation) for node in nodes), \ f"Find node is not IRFwOperation or IRDataOperation: {node}" @@ -885,6 +881,7 @@ def get_sid(fnode: IRCell) -> Optional[int]: def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: fwop = Identity(tensor) fwop.infer_shape() + fwop.output(0).parent.dtype = tensor.dtype fwop.set_output(0, fwop.output(0).tosub()) if tensor.requires_grad: fwop.output(0).parent.requires_grad = True diff --git a/cube/graph/schedule/__init__.py b/cube/graph/schedule/__init__.py index 3712eea3..4c5f2f80 100644 --- a/cube/graph/schedule/__init__.py +++ b/cube/graph/schedule/__init__.py @@ -1 +1 @@ -from cube.graph.schedule.strategy import IRScheduleStrategy \ No newline at end of file +from cube.graph.schedule.schedplan import SchedulePlan diff --git a/cube/graph/schedule/predefined.py b/cube/graph/schedule/predefined.py index a2e81503..7a197b12 100644 --- a/cube/graph/schedule/predefined.py +++ b/cube/graph/schedule/predefined.py @@ -11,22 +11,6 @@ class PredefinedSched: - @staticmethod - def grad_accum(graph: IRGraph, num_microbatches: int) -> SchedulePlan: - """ - Gradient accumulation for SPMD scenario. - """ - segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) - # describe schedule - sched = SchedulePlan(graph, num_microbatches) - step = 0 - for midx in range(num_microbatches): - for seg in segments: - sched.add_segment(seg, midx, step) - step += 1 - sched.finish() - return sched - @staticmethod def sched_1f1b(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: """ @@ -65,6 +49,74 @@ def sched_1f1b(graph: IRGraph, num_microbatches: int, num_stages: int) -> Schedu sched.finish() return sched + @staticmethod + def sched_1f1b_plus(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: + """1F1B Plus Scheduling. + + f0 f0 f1 f1 f2 f2 | f3 f3 b0 | b1 b2 b3 + f0 f0 f1 f1 f2 f2 | f3 b0 f3 | b1 b2 b3 + f0 f1 f0 f2 f1 b0 | f3 f2 b1 | f3 b2 b3 + f0 f1 f0 f2 b0 f1 | f3 b1 f2 | b2 f3 b + """ + segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + tp_fsegs = [seg for seg in fsegs if len(seg.device) == len(graph.device)] + fb_fsegs = [seg for seg in fsegs if seg not in tp_fsegs] + assert len(fb_fsegs) == num_stages, f"got only {len(fb_fsegs)} stages but need {num_stages} stages" + assert all(tuple(seg.device) == tuple(graph.device) for seg in tp_fsegs) + + # describe schedule + sched = SchedulePlan(graph, num_microbatches) + + wait_steps = [sid for sid in range(num_stages)] + bw_ofst = [num_stages - 1 - sid for sid in range(num_stages)] + total_steps = num_microbatches * 2 + (num_stages - 1) * 2 + + # 1f1b schedule + for step in range(total_steps): + for sid in range(num_stages): + ofst = wait_steps[sid] + if step < ofst: continue + fw_idx = (step - ofst) // 2 + # forward or backward segment + segment = fb_fsegs[sid] if (step - ofst) % 2 == 0 else fb_fsegs[sid].mirror + mb_idx = fw_idx if (step - ofst) % 2 == 0 else fw_idx - bw_ofst[sid] + # append for execution + if mb_idx < 0 or mb_idx >= num_microbatches: continue + sched.add_segment(segment, mb_idx, step) + + # insert + for mid in range(num_microbatches): + for tp_seg in tp_fsegs: + # TODO: not work case: tp_seg at tail fsegs + next_seg = fsegs[fsegs.index(tp_seg)+1] + assert next_seg in fsegs + insert_fw, insert_bw = False, tp_seg.mirror is None + if tp_seg.mirror is not None: + assert next_seg.mirror is not None + + for step in range(sched.nsteps-1, -1, -1): + segments = [blk.content for blk in sched.segments(step) if blk.mid == mid] + # insert forward + if next_seg in segments: + sched.insert_step(step, tp_seg, mid, 1) + assert not insert_fw + insert_fw = True + # insert backward + if next_seg.mirror in segments: + sched.insert_step(step+1, tp_seg.mirror, mid, 1) + assert not insert_bw + insert_bw = True + if insert_fw and insert_bw: break + + assert insert_fw and insert_bw, ( + f'find one segment cannot be inserted in schedplan: ', + f'mid: {mid}, fw: {insert_fw}, bs: {insert_bw}') + + sched.finish() + # print(sched) + return sched + @staticmethod def sched_gpipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: """ @@ -98,6 +150,56 @@ def sched_gpipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> Sched sched.finish() return sched + @staticmethod + def sched_chimera_direct(graph: IRGraph, num_microbatches: int, num_stages: int): + """Chimera-direct scheduling. + + The graph should be staged into segments. + + An illustration of scheduling schema (the number is micro-batch index): + ``` + f0 f1 f2 b2-b2 f3 b3-b3 b0-b0 b1-b1 + f0 f2 f1 f3 b2-b2 b0-b0 b3-b3 b1-b1 + f2 f0 f3 f1 b0-b0 b2-b2 b1-b1 b3-b3 + f2 f3 f0 b0-b0 f1 b1-b1 b2-b2 b3-b3 + + 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 (-> steps) + ``` + + Note the f0 and f2 (step 0) should be considered to be one segment in graph. + """ + segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + assert len(fsegs) == 4, f"Chimera-direct scheduling only applies for 4 segments, but {len(segments)} detected" + sched = SchedulePlan(graph, num_microbatches) + assert num_microbatches % 2 == 0 + mid = 0 + while mid < num_microbatches: + ofst = 16 * (mid // 2) + # first micro-batch + sched.add_segment(fsegs[0], micro_batch_id=mid, step=max(0, 0+ofst-3)) # tight compact + sched.add_segment(fsegs[1], micro_batch_id=mid, step=max(1, 1+ofst-3)) # tight compact + sched.add_segment(fsegs[2], micro_batch_id=mid, step=2+ofst) + sched.add_segment(fsegs[3], micro_batch_id=mid, step=3+ofst) + sched.add_segment(fsegs[3].mirror, micro_batch_id=mid, step=4+ofst, span=2) + sched.add_segment(fsegs[2].mirror, micro_batch_id=mid, step=6+ofst, span=2) + sched.add_segment(fsegs[1].mirror, micro_batch_id=mid, step=8+ofst, span=2) + sched.add_segment(fsegs[0].mirror, micro_batch_id=mid, step=10+ofst, span=2) + # second micro-batch + sched.add_segment(fsegs[0], micro_batch_id=mid+1, step=2+ofst) + sched.add_segment(fsegs[1], micro_batch_id=mid+1, step=3+ofst) + sched.add_segment(fsegs[2], micro_batch_id=mid+1, step=4+ofst) + sched.add_segment(fsegs[3], micro_batch_id=mid+1, step=6+ofst) + sched.add_segment(fsegs[3].mirror, micro_batch_id=mid+1, step=8+ofst, span=2) + sched.add_segment(fsegs[2].mirror, micro_batch_id=mid+1, step=10+ofst, span=2) + sched.add_segment(fsegs[1].mirror, micro_batch_id=mid+1, step=12+ofst, span=2) + sched.add_segment(fsegs[0].mirror, micro_batch_id=mid+1, step=14+ofst, span=2) + # update + mid += 2 + sched.finish() + return sched + + @staticmethod def sched_infer_pipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: """ diff --git a/cube/graph/schedule/sched1f1b.py b/cube/graph/schedule/sched1f1b.py deleted file mode 100644 index 095df847..00000000 --- a/cube/graph/schedule/sched1f1b.py +++ /dev/null @@ -1,125 +0,0 @@ - -from typing import Dict, Optional, List -import warnings - -from cube.ir.cten import IRCell -from cube.ir.adapter.adapter import IRAdapter -from cube.ir.adapter.adapter import IRWeightReducer - -from cube.graph.graph import IRGraph, IRSegment -from cube.graph.schedule import IRScheduleStrategy - - -class IRSchedule1F1B(IRScheduleStrategy): - """ - 1F1B Scheduling - - This treats model as a linear graph which can be - grouped into continous stages. - - [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] - [Recv-Backward] Backward-Segment [Send-Backward] - """ - - def __init__(self, graph, nmicros: int): - super().__init__(graph, nmicros) - self.signature = 'cube.runtime.schedule.Schedule1F1B.run' - # forward body - self.fsegments: Dict[int, IRSegment] = dict() - # forward send - self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() - # forward recv - self.rfadapter: Dict[int, Optional[IRAdapter]] = dict() - # backard send - self.sbadapter: Dict[int, Optional[IRAdapter]] = dict() - # backward recv - self.rbadapter: Dict[int, Optional[IRAdapter]] = dict() - # num_stage - self.num_stages: int = -1 - # stage id - self.stage_id: Dict[int, int] = dict() - # reducers - self.dev_reducers: Dict[int, List[IRWeightReducer]] = dict() - # recompute - self.recompute = False - - - def apply(self) -> IRGraph: - self.mesh() - for node in self.graph.nodes(): - if isinstance(node, IRAdapter) and node.forward: - if len(set(node.outputs())) > 1 or len(set(node.inputs())) > 1: - warnings.warn( - "Detected one adapter has more than one input/output in stage transmission, " - "which is not safe for current scheduling implementation due to potential " - "mis-ordering of arguments. Better to use torch.cat and torch.chunk to " - "merge multiple tensors into one and unpack it at next stage." - ) - # each forward has corresponding backward - assert all(fseg.mirror in self.segments for fseg in self.segments if fseg.isfw()), \ - "Require backward of each forward stage" - # stage doesn't share devices - fsegments: List[IRSegment] = [fseg for fseg in self.segments if fseg.isfw()] - self.num_stages = len(fsegments) - for sid, fseg in enumerate(fsegments): - for devid in fseg.device: - # forward body - assert devid not in self.fsegments, "One device cannot have multiple forward stages" - self.fsegments[devid] = fseg - # forward recv / backward send - assert len(self.recvers[fseg]) <= 1, "Corss-stage adapter can only be one" - if sid == 0: - assert len(self.recvers[fseg]) == 0, "Expect no forward send at first stage" - assert len(self.senders[fseg.mirror]) == 0, "Expect no backward send at first stage" - else: - assert len(self.recvers[fseg]) == 1, "Expect one forward recv at non-first stage" - assert len(self.senders[fseg.mirror]) == 1, "Expect one backward send at non-first stage" - self.rfadapter[devid] = None if sid == 0 else self.recvers[fseg][0] - self.sbadapter[devid] = None if sid == 0 else self.senders[fseg.mirror][0] - # forward send / backward recv - if sid == self.num_stages - 1: - assert len(self.senders[fseg]) == 0, "Expect no forward send at last stage" - assert len(self.recvers[fseg.mirror]) == 0, "Expect no backward recv at last stage" - else: - assert len(self.senders[fseg]) == 1, "Expect no forward send at last stage" - assert len(self.recvers[fseg.mirror]) == 1, "Expect no forward send at last stage" - self.sfadapter[devid] = None if sid == self.num_stages - 1 else self.senders[fseg][0] - self.rbadapter[devid] = None if sid == self.num_stages - 1 else self.recvers[fseg.mirror][0] - # weight reducer - self.dev_reducers[devid] = [reducer for reducer in self.reducers if devid in reducer.device] - # stage id - self.stage_id[devid] = sid - - return self.graph - - def kwargs(self, devid: int) -> Dict[str, IRCell]: - """ - return kwargs for runtime caller - """ - return dict( - segment = self.fsegments[devid], - sfadapter = self.sfadapter[devid], - rfadapter = self.rfadapter[devid], - sbadapter = self.sbadapter[devid], - rbadapter = self.rbadapter[devid], - dataloader = 'dataloader', - stage_id = self.stage_id[devid], - num_stages = self.num_stages, - num_microbatch = self.nmicros, - reducers = self.dev_reducers[devid], - recompute = self.recompute - ) - - def __repr__(self) -> str: - dscp = '' - for mesh in self.devmesh: - devid = mesh[0] - # segment = self.segments[devid].to_str(skip_attr=True) if self.segment[mesh[0]] else None - dscp += (f"1F1B Schedule: Stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" - f" segment = {self.segments[devid]}\n" - f" send-fw = {self.sfadapter[mesh[0]]}\n" - f" recv-fw = {self.rfadapter[mesh[0]]}\n" - f" send-bw = {self.sbadapter[mesh[0]]}\n" - f" recv-bw = {self.rbadapter[mesh[0]]}\n" - f")\n") - return dscp diff --git a/cube/graph/schedule/schedinfer.py b/cube/graph/schedule/schedinfer.py deleted file mode 100644 index b0c2bf6a..00000000 --- a/cube/graph/schedule/schedinfer.py +++ /dev/null @@ -1,93 +0,0 @@ - -from typing import Dict, Optional, List -import warnings - -from cube.ir.cten import IRCell -from cube.ir.adapter.adapter import IRAdapter - -from cube.graph.graph import IRGraph, IRSegment -from cube.graph.schedule import IRScheduleStrategy - - -class IRScheduleInfer(IRScheduleStrategy): - """ - 1F1B Scheduling - - This treats model as a linear graph which can be - grouped into continous stages. - - [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] - [Recv-Backward] Backward-Segment [Send-Backward] - """ - - def __init__(self, graph, nmicros: int): - super().__init__(graph, nmicros) - self.signature = 'cube.runtime.schedule.ScheduleInfer.run' - # forward body - self.fsegments: Dict[int, IRSegment] = dict() - # forward send - self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() - # forward recv - self.rfadapter: Dict[int, Optional[IRAdapter]] = dict() - # num_stage - self.num_stages: int = -1 - - def apply(self) -> IRGraph: - self.mesh() - for node in self.graph.nodes(): - if isinstance(node, IRAdapter) and node.forward: - if len(set(node.outputs())) > 1 or len(set(node.inputs())) > 1: - warnings.warn( - "Detected one adapter has more than one input/output in stage transmission, " - "which is not safe for current scheduling implementation due to potential " - "mis-ordering of arguments. Better to use torch.cat and torch.chunk to " - "merge multiple tensors into one and unpack it at next stage." - ) - # no backward - for seg in self.graph.select(ntype=IRSegment): - assert seg.isfw(), "Detected backward, which should not exist in inference" - # stage doesn't share devices - fsegments: List[IRSegment] = [fseg for fseg in self.segments if fseg.isfw()] - self.num_stages = len(fsegments) - for sid, fseg in enumerate(fsegments): - for devid in fseg.device: - # forward body - assert devid not in self.fsegments, "One device cannot have multiple forward stages" - self.fsegments[devid] = fseg - if sid == 0: - assert len(self.recvers[fseg]) == 0, "Expect no forward send at first stage" - else: - assert len(self.recvers[fseg]) == 1, "Expect one forward recv at non-first stage" - self.rfadapter[devid] = None if sid == 0 else self.recvers[fseg][0] - # forward send - if sid == self.num_stages - 1: - assert len(self.senders[fseg]) == 0, "Expect no forward send at last stage" - else: - assert len(self.senders[fseg]) == 1, "Expect no forward send at last stage" - self.sfadapter[devid] = None if sid == self.num_stages - 1 else self.senders[fseg][0] - - return self.graph - - def kwargs(self, devid: int) -> Dict[str, IRCell]: - """ - return kwargs for runtime caller - """ - return dict( - segment = self.fsegments[devid], - sfadapter = self.sfadapter[devid], - rfadapter = self.rfadapter[devid], - dataloader = 'dataloader', - num_microbatch = self.nmicros, - ) - - def __repr__(self) -> str: - dscp = '' - for mesh in self.devmesh: - devid = mesh[0] - # segment = self.segments[devid].to_str(skip_attr=True) if self.segment[mesh[0]] else None - dscp += (f"GPipe Infer Schedule: Stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" - f" segment = {self.segments[devid]}\n" - f" send-fw = {self.sfadapter[mesh[0]]}\n" - f" recv-fw = {self.rfadapter[mesh[0]]}\n" - f")\n") - return dscp diff --git a/cube/graph/schedule/schedmix.py b/cube/graph/schedule/schedmix.py deleted file mode 100644 index 6267b5a8..00000000 --- a/cube/graph/schedule/schedmix.py +++ /dev/null @@ -1,189 +0,0 @@ - -from typing import Dict, Optional, List -import warnings - -from cube.ir.cten import IRCell -from cube.ir.adapter.adapter import IRAdapter -from cube.ir.adapter.adapter import IRWeightReducer - -from cube.graph.graph import IRGraph, IRSegment -from cube.graph.schedule import IRScheduleStrategy -from cube.ir.adapter.prim import IdentityPrim - - -class IRScheduleMix(IRScheduleStrategy): - """ - 1F1B Scheduling - - This treats model as a linear graph which can be - grouped into continous stages. - - [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] - [Recv-Backward] Backward-Segment [Send-Backward] - """ - - def __init__(self, graph, nmicros: int): - super().__init__(graph, nmicros) - self.signature = 'cube.runtime.schedule.ScheduleMix.run' - # forward body - self.encoder_barriers: Dict[int, IRSegment] = dict() - self.decoder_barriers: Dict[int, IRSegment] = dict() - self.fsegments: Dict[int, IRSegment] = dict() - # body forward recv adapter - self.rfadapter: Dict[int, Optional[IRAdapter]] = dict() - # body forward send adapter - self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() - # body backward recv adapter - self.rbadapter: Dict[int, Optional[IRAdapter]] = dict() - # body backward send adapter - self.sbadapter: Dict[int, Optional[IRAdapter]] = dict() - # encoder barrier backward prepare adapter - self.enc_badapter: Dict[int, IRAdapter] = dict() - # decoder barrier forward input prepare adapter - self.dec_fadapter: Dict[int, IRAdapter] = dict() - # decoder barrier backward input prepare adapter - self.dec_badapter: Dict[int, IRAdapter] = dict() - # num_stage - self.num_stages: int = -1 - # stage id - self.stage_id: Dict[int, int] = dict() - # reducers - self.dev_reducers: Dict[int, List[IRWeightReducer]] = dict() - # recompute - self.recompute = False - - - def apply(self) -> IRGraph: - self.mesh() - # each forward adapter has only one input and one output for each device - for node in self.graph.nodes(): - if isinstance(node, IRAdapter) and node.forward: - if len(set(node.outputs())) > 1 or len(set(node.inputs())) > 1: - warnings.warn( - "Detected one adapter has more than one input/output in stage transmission, " - "which is not safe for current scheduling implementation due to potential " - "mis-ordering of arguments. Better to use torch.cat and torch.chunk to " - "merge multiple tensors into one and unpack it at next stage." - ) - # each forward has corresponding backward - assert all(fseg.mirror in self.segments for fseg in self.segments if fseg.isfw()), \ - "Require backward of each forward stage" - - fsegments: List[IRSegment] = [fseg for fseg in self.segments if fseg.isfw()] - self.num_stages = len(fsegments) - 2 - - shard_enc_sid, shard_dec_sid = (0, self.num_stages // 2) - print(f'> shard encoder stage id: {shard_enc_sid} | shard decoder stage id: {shard_dec_sid} | num stages: {self.num_stages}') - - shard_enc, shard_dec = fsegments[0], fsegments[shard_dec_sid + 1] - assert len(shard_enc.device) == len(shard_dec.device) and len(shard_enc.device) >= 4, ( - f"This scheduling can only be applied to number of devices >= 4" - ) - pipe_stages = [seg for lid, seg in enumerate(fsegments) if lid not in (shard_enc_sid, shard_dec_sid + 1)] - - # setup shard encoder embedding - assert len(self.recvers[shard_enc.mirror]) == 1 - for devid in shard_enc.device: - self.encoder_barriers[devid] = shard_enc - self.enc_badapter[devid] = self.recvers[shard_enc.mirror][0] - # setup shard decoder embedding - assert len(self.recvers[shard_dec]) == 1 - assert len(self.recvers[shard_dec.mirror]) == 1 - for devid in shard_dec.device: - self.decoder_barriers[devid] = shard_dec - self.dec_fadapter[devid] = self.recvers[shard_dec][0] - self.dec_badapter[devid] = self.recvers[shard_dec.mirror][0] - # pipeline stages - for sid, stage in enumerate(pipe_stages): - for devid in stage.device: - assert devid not in self.fsegments, f"Pipeline stage cannot be overlapped" - # forward body - self.fsegments[devid] = stage - # forward recv - if sid in (shard_enc_sid, shard_dec_sid): - for adapter in self.recvers[stage]: - assert all(isinstance(prim, IdentityPrim) for prim in adapter.prims), ( - f"stage {sid} got unexpected forward recv adapters: {self.recvers[stage]}" - ) - self.rfadapter[devid] = None - else: - assert len(self.recvers[stage]) == 1 - self.rfadapter[devid] = self.recvers[stage][0] - # forward send - if sid == shard_dec_sid - 1: # decoder recv broadcast - assert len(self.senders[stage]) == 1 - self.sfadapter[devid] = None - elif sid == self.num_stages - 1: - assert len(self.senders[stage]) == 0 - self.sfadapter[devid] = None - else: - assert len(self.senders[stage]) == 1 - self.sfadapter[devid] = self.senders[stage][0] - # backward recv - if sid in (shard_dec_sid - 1, self.num_stages - 1): - for adapter in self.recvers[stage.mirror]: - assert all(isinstance(prim, IdentityPrim) for prim in adapter.prims), ( - f"stage {sid} got unexpected backward recv adapters: {self.recvers[stage]}" - ) - self.rbadapter[devid] = None - else: - assert len(self.recvers[stage.mirror]) == 1, \ - f"stage {sid} got unexpected backward recv adapters: {self.recvers[stage.mirror]}" - self.rbadapter[devid] = self.recvers[stage.mirror][0] - # backward send: - if sid == shard_dec_sid: # decoder broadcast - assert len(self.senders[stage.mirror]) == 1 - self.sbadapter[devid] = None - elif sid == shard_enc_sid: # encoder broadcast - assert len(self.senders[stage.mirror]) == 1 - self.sbadapter[devid] = None - else: - self.sbadapter[devid] = self.senders[stage.mirror][0] - - # weight reducer - self.dev_reducers[devid] = [reducer for reducer in self.reducers if devid in reducer.device] - # stage id - self.stage_id[devid] = sid - - return self.graph - - def kwargs(self, devid: int) -> Dict[str, IRCell]: - """ - return kwargs for runtime caller - """ - return dict( - encoder_barrier = self.encoder_barriers[devid], - decoder_barrier = self.decoder_barriers[devid], - segment = self.fsegments[devid], - sfadapter = self.sfadapter[devid], - rfadapter = self.rfadapter[devid], - sbadapter = self.sbadapter[devid], - rbadapter = self.rbadapter[devid], - enc_badapter = self.enc_badapter[devid], - dec_fadapter = self.dec_fadapter[devid], - dec_badapter = self.dec_badapter[devid], - dataloader = 'dataloader', - stage_id = self.stage_id[devid], - num_stages = self.num_stages, - num_microbatch = self.nmicros, - reducers = self.dev_reducers[devid], - recompute = self.recompute - ) - - def __repr__(self) -> str: - dscp = '' - devices = self.devmesh[0] - for devid in devices: - dscp += (f"Interplaced Schedule: Stage[{self.stage_id[devid]}](dev {devid})(\n" - f" encoder_barrier = {self.encoder_barriers[devid]}\n" - f" decoder_barrier = {self.decoder_barriers[devid]}\n" - f" segment = {self.fsegments[devid]}\n" - f" send-fw = {self.sfadapter[devid]}\n" - f" recv-fw = {self.rfadapter[devid]}\n" - f" send-bw = {self.sbadapter[devid]}\n" - f" recv-bw = {self.rbadapter[devid]}\n" - f" enc_badapter = {self.enc_badapter[devid]}\n" - f" dec_fadapter = {self.dec_fadapter[devid]}\n" - f" dec_badapter = {self.dec_badapter[devid]}\n" - f")\n") - return dscp diff --git a/cube/graph/schedule/schednf1b.py b/cube/graph/schedule/schednf1b.py deleted file mode 100644 index a900c406..00000000 --- a/cube/graph/schedule/schednf1b.py +++ /dev/null @@ -1,77 +0,0 @@ - -from typing import Dict, Optional, List -import warnings - -from cube.ir.cten import IRCell -from cube.ir.adapter.adapter import IRAdapter -from cube.ir.adapter.adapter import IRWeightReducer - -from cube.graph.graph import IRGraph, IRSegment -from cube.graph.schedule.sched1f1b import IRSchedule1F1B - - -class IRScheduleNF1B(IRSchedule1F1B): - """ - NF1B Scheduling - - This treats model as a linear graph which can be - grouped into continous stages. - - [Recv-Forward/Dataloader] Forward-Segment [Send-Forward] - [Recv-Backward] Backward-Segment [Send-Backward] - """ - - def __init__(self, graph, nmicros: int, recycle: int): - super().__init__(graph, nmicros) - self.signature = 'cube.runtime.schedule.ScheduleNF1B.run' - # forward body - self.fsegments: Dict[int, IRSegment] = dict() - # forward send - self.sfadapter: Dict[int, Optional[IRAdapter]] = dict() - # forward recv - self.rfadapter: Dict[int, Optional[IRAdapter]] = dict() - # backard send - self.sbadapter: Dict[int, Optional[IRAdapter]] = dict() - # backward recv - self.rbadapter: Dict[int, Optional[IRAdapter]] = dict() - # num_stage - self.num_stages: int = -1 - # stage id - self.stage_id: Dict[int, int] = dict() - # reducers - self.dev_reducers: Dict[int, List[IRWeightReducer]] = dict() - # recycle - self.recycle = recycle - - def kwargs(self, devid: int) -> Dict[str, IRCell]: - """ - return kwargs for runtime caller - """ - return dict( - segment = self.fsegments[devid], - sfadapter = self.sfadapter[devid], - rfadapter = self.rfadapter[devid], - sbadapter = self.sbadapter[devid], - rbadapter = self.rbadapter[devid], - dataloader = 'dataloader', - stage_id = self.stage_id[devid], - num_stages = self.num_stages, - num_microbatch = self.nmicros, - recycle = self.recycle, - reducers = self.dev_reducers[devid], - ) - - def __repr__(self) -> str: - dscp = '' - for mesh in self.devmesh: - devid = mesh[0] - # segment = self.segments[devid].to_str(skip_attr=True) if self.segment[mesh[0]] else None - dscp += (f"NF1B Schedule: Stage[{self.stage_id[mesh[0]]}](dev {mesh})(\n" - f" segment = {self.segments[devid]}\n" - f" send-fw = {self.sfadapter[mesh[0]]}\n" - f" recv-fw = {self.rfadapter[mesh[0]]}\n" - f" send-bw = {self.sbadapter[mesh[0]]}\n" - f" recv-bw = {self.rbadapter[mesh[0]]}\n" - f" recycle = {self.recycle}\n" - f")\n") - return dscp diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py index 0f45731a..1c5df7fb 100644 --- a/cube/graph/schedule/schedplan.py +++ b/cube/graph/schedule/schedplan.py @@ -7,7 +7,7 @@ from cube.graph.graph import IRGraph from cube.graph.segment import IRSegment - +from cube.flags import CompileFlag class Block: @@ -93,7 +93,7 @@ def build(self): # get all weight reducers self.reducers = self.graph.select(ntype=IRWeightReducer, flatten=False) - def depend(self, prev: Block, next: Block) -> bool: + def depends(self, prev: Block, next: Block) -> bool: return prev.mid == next.mid and self.graph.depends(prev.content, next.content) @@ -101,22 +101,28 @@ class PlanBase: def __init__(self, graph: IRGraph, _dependency: Optional[ScheduleDependency] = None): self._graph: IRGraph = graph - self._segments: List[Block] = [] - self._step_segments: List[List[Block]] = [] + self._blocks: List[Block] = [] + + # execution time table + # 1) the blocks in execution at each step + self._step_blocks: List[List[Block]] = [] + # 2) the devices in execution at each step self._step_devices: List[Set[int]] = [] - # adapters executed *after* the segments on that step + # 3) the adapters executed *after* the segments on their steps self._step_adapters: List[List[Block]] = [] - self._block_step: Dict[Block, int] = {} + # the start time step of block + self._block_start_step: Dict[Block, int] = {} + # dependency table self._dependency = _dependency if _dependency is not None \ else ScheduleDependency(graph) - + # topological sequence self._seqs: List[IRCell] = [] @property def nsteps(self) -> int: - return len(self._step_segments) + return len(self._step_blocks) @property def graph(self) -> IRGraph: @@ -131,51 +137,144 @@ def device(self) -> Tuple[int]: def nodes(self) -> Tuple[Block]: return tuple(self._seqs) + + def add_block(self, block: Block, step: int): + """Add a block to start executing from step""" + self._extend_step(step + block.span - 1) + # check + for t in range(block.span): + for devid in block.device: + if devid in self._step_devices[step+t]: + raise RuntimeError( + f"inserting confict at device {devid} of time step {step+t}: " + f"cannot execute multiple blocks at a same time step") + for t in range(block.span): + self._step_blocks[step+t].append(block) + self._step_devices[step+t].update(block.device) + self._block_start_step[block] = step + self._blocks.append(block) + return block - def add_segment(self, seg: IRSegment, micro_batch_id: int, step: int, span: Optional[int] = 1) -> Block: - """ - Add a segment `seg` to be executed with `micro-batch-id` data at step `step`. + def add_segment(self, seg: IRSegment, micro_batch_id: int, + step: int, span: Optional[int] = 1) -> Block: + """Add a segment to be executed with micro_batch_id data at step. + + The segments after `step` will keep unchanged. + + Args: + seg (IRSegment): the segment to add to the plan + micro_batch_id (int): the micro-batch id to execute the segment + step (int): the step to execute the segment + span (int): the time step costs to execute the segment + + Returns: + block (Block): the block representing the segment """ - self._extend_step(step + span - 1) - if len(self._step_segments[step]) == 1 and isinstance(self._step_segments[0], PlanBase): - assert False, "Cannot add an IRSegment into a step that already has Repetend." - assert all(devid not in self._step_devices for devid in seg.device), \ - f"A device cannot execute multiple segments on a same step" block = Block(seg, micro_batch_id, span) - for t in range(span): - self._step_segments[step+t].append(block) - self._step_devices[step+t].update(seg.device) - self._block_step[block] = step - self._segments.append(block) + self.add_block(block, step) + return block + + def insert_step(self, step: int, seg: IRSegment, micro_batch_id: int, span: Optional[int] = 1) -> Block: + """Insert `span` steps at current `step`. + + The segments after `step` will be pushed `span` time step for execution。 + + Args: + step (int): the step to insert + seg (IRSegment): the segment to insert + micro_batch_id (int): the micro-batch id to execute the segment + span (int): the time step costs to execute the segment + + Returns: + block (Block): the block representing the segment + """ + # shift + assert all(len(adapters) == 0 for adapters in self._step_adapters) + for block in self._blocks: + start = self.start(block) + if start >= step: + self._block_start_step[block] += span + elif start + block.span > step: + raise NotImplementedError( + f"Cannot shift the block {block} that is in execution on step {step}") + # insert + block = Block(seg, micro_batch_id, span) + for _ in range(span): + self._step_blocks.insert(step, [block]) + self._step_devices.insert(step, set(seg.device)) + self._step_adapters.insert(step, []) + self._block_start_step[block] = step + self._blocks.append(block) return block - def segments(self, step: int) -> Tuple[Block]: - """ - Get segment blocks at step - """ - assert step < self.nsteps - blocks = self._step_segments[step] - blocks = tuple(blk for blk in blocks if self.step(blk) == step) + def remove_step(self, step: int): + """Remove the step if there are no blocks in execution. + + All the blocks after the `step` will be shifted earlier. + This can only apply when no adapters are placed. + + Args: + step (int): the step to remove + + Returns: + None + """ + if len(self._step_blocks[step]) > 0: + raise RuntimeError(f"Cannot remove step {step} with blocks in execution") + if len(self._step_adapters[step]) > 0: + raise RuntimeError(f"Cannot remove step {step} with adapters in execution") + # shift + for block in self._blocks: + if self.start(block) > step: + self._block_start_step[block] -= 1 + self._step_blocks.pop(step) + self._step_devices.pop(step) + self._step_adapters.pop(step) + + def shrink(self): + """Remove steps that have no blocks in execution + + Note the implementation is costly. Users should avoid + calling it many times. + """ + for step in range(self.nsteps-1, -1, -1): + if len(self._step_blocks[step]) == 0: + self.remove_step(step) + + def blocks(self, step: int) -> Tuple[Block]: + """Get blocks in execution at the step""" + if step >= self.nsteps: + return () + blocks = self._step_blocks[step] + return tuple(blocks) + + def start_blocks(self, step: int) -> Tuple[Block]: + """Get blocks starting at the step""" + if step >= self.nsteps: + return () + blocks = self._step_blocks[step] + blocks = tuple(blk for blk in blocks if self.start(blk) == step) return blocks - def step(self, block: Block) -> int: - """Get the step of the block - """ - return self._block_step[block] + def start(self, block: Block) -> int: + """Get the start step of the block""" + return self._block_start_step[block] - def all_segments(self) -> Tuple[Block]: + def all_blocks(self) -> Tuple[Block]: """ Get all segment blocks """ - return tuple(self._segments) + return tuple(self._blocks) + + def depends(self, prev: Block, succ: Block) -> bool: + """Check whether prev block directly depends on succ block""" + return self._dependency.depends(prev, succ) def _extend_step(self, step: int): - """ - Extend the maximize plan with `step`. - """ - if len(self._step_segments) <= step: - nextend = step - len(self._step_segments) + 1 - self._step_segments += [[] for _ in range(nextend)] + """Extend the maximal accessible steps of plan to `step` index""" + if len(self._step_blocks) <= step: + nextend = step - len(self._step_blocks) + 1 + self._step_blocks += [[] for _ in range(nextend)] self._step_devices += [set() for _ in range(nextend)] self._step_adapters += [[] for _ in range(nextend)] @@ -187,7 +286,7 @@ def _place_dataloader(self): for dl in self._dependency.dataloaders: inserted_mids = set() for step in range(self.nsteps): - blocks = self.segments(step) + blocks = self.start_blocks(step) for block in blocks: segment, mid = block.content, block.mid if mid in inserted_mids: continue @@ -195,8 +294,9 @@ def _place_dataloader(self): if self.graph.depends(dl, segment): dl_block = Block(dl, mid, 1) # print(f'inserting microbatch {mid} at step {step} before {segment.name}{segment.cid}') - self._step_segments[step+block.span-1].insert(0, dl_block) - self._block_step[dl_block] = step+block.span-1 + self._blocks.append(dl_block) + self._step_blocks[step+block.span-1].insert(0, dl_block) + self._block_start_step[dl_block] = step+block.span-1 inserted_mids.add(mid) break @@ -207,102 +307,10 @@ def topo_sort(self): """ self._seqs = [] for step in range(self.nsteps): - self._seqs += self.segments(step) + self._seqs += self.start_blocks(step) self._seqs += self._step_adapters[step] -class Repetend(PlanBase): - """ - A repetend is a node in SchedulePlan, representing its nodes - will be repeatedly executed by `span` times witn growing - micro-batch index. - """ - - def __init__(self, graph: IRGraph, dependency: ScheduleDependency, - span: int, step_segments: List[List[Block]], ): - """ - @param graph IRGraph - @param dependency: ScheduleDependency - @param span int: the repeated execution time - @param step_segments List[List[Block]] - """ - super().__init__(graph, dependency) - self._span = span - self._extend_step(len(step_segments)) - self._step_segments = step_segments - for step, blocks in enumerate(step_segments): - devices = set() - for block in blocks: - devices.update(block.device) - self._step_devices[step] = devices - # the adapters that will be performed outside the repetend - self._post_adapters: List[Block] = [] - - @property - def span(self) -> int: - return self._span - - def nodes(self) -> Tuple[Block]: - return tuple(self._seqs) - - def apply(self): - self._place_adapters() - self._place_dataloader() - self.topo_sort() - - def _place_adapters(self): - """ - Place adapters - """ - # step1: unrolling repetend for one step - cnts: Dict[IRSegment, int] = {} - for step in range(self.nsteps): - for blk in self.segments(step): - cnts.setdefault(blk.content, 0) - cnts[blk.content] += 1 - extended_blocks = [] - for step in range(self.nsteps): - for blk in self.segments(step): - extend_blk = Block(blk.content, blk.mid + cnts[blk.content], blk.span) - extended_blocks.append(extend_blk) - # step2: generate adapters for each step - all_blocks = self.all_segments() - for adapter, sender in self._dependency.senders.items(): - for step in range(self.nsteps): - for block in self.segments(step): - if block.content != sender: continue - # sender adapter can be classified into three categories - # 1) its recver are in the same repetend - # 2) its recver are in neighbored repetend - # - we don't allow send and recver in un-neighbored repetend - # 3) its recver are outside the repetend - recver = self._dependency.recvers[adapter] - rblock = Block(recver, block.mid, block.span) - ablock = Block(adapter, block.mid, 1) - # case 1) - if rblock in all_blocks: - self._step_adapters[step+block.span-1].append(ablock) - # case 2) - elif rblock in extended_blocks: - self._step_adapters[self.nsteps-1].append(Block(adapter, block.mid - cnts[blk.content], 1)) - self._post_adapters.append(ablock) - # case 3) - else: - self._post_adapters.append(ablock) - - def get_post_adapters(self) -> List[Block]: - return tuple(self._post_adapters) - - def __repr__(self): - dscp = f'Repetend-{self.device}(span={self._span}\n' - for step, blks in enumerate(self._step_segments): - dscp += f'\n Substep {step}:\n' - for blk in blks: - dscp += ' ' + repr(blk) + '\n' - dscp += ')' - return dscp - - class SchedulePlan(PlanBase): """ A schedule plan leverages the fact no data dependency across different @@ -321,6 +329,8 @@ class SchedulePlan(PlanBase): def __init__(self, graph: IRGraph, num_microbatches: int): super().__init__(graph) + if CompileFlag.async_reducer: + raise NotImplementedError("Async reducer is not supported for schedule plan yet.") # execution sequence self._num_microbatches = num_microbatches # bind to the graph @@ -337,32 +347,6 @@ def nmicros(self) -> int: def graph(self) -> IRGraph: return self._graph - def repeat(self, from_step: int, to_step: int, span: int) -> Repetend: - """ - Create a repetend where the nodes inside the step ranges will - be repeatedly executed by `span` time, with the increasing micro-batch - index. The microbatch index among same segment must be - consecutive. - - Note: calling this will shrink self.nsteps and the blocks begin from - to_step will be shifted to the front of total steps by `to_step - from_step - - @param from_step int: starting (included) step - @param to_step int: stopping (excluded) step - @param span int: repeat time, i.e., number of increasing micro-batch index - - @return repetend Repetend - """ - raise NotImplementedError("repeat is not supported.") - assert 0 < from_step and from_step < self.nsteps - assert 0 < to_step and to_step <= self.nsteps - segment_blocks: List[List[Block]] = self._step_segments[from_step:to_step] - repetend = Repetend(self._graph, self._dependency, span, segment_blocks) - self._step_segments = self._step_segments[:from_step] + [[repetend]] + self._step_segments[to_step:] - self._step_adapters = self._step_adapters[:from_step] + [[]] + self._step_adapters[to_step:] - self._step_devs = self._step_devs[:from_step] + [set(repetend.device)] + self._step_devs[to_step:] - return repetend - def finish(self): """ Check whether the description contains full micro-batches @@ -377,14 +361,10 @@ def apply(self): """ # step 1: build dependency for scheduling self._dependency.build() - # step 2: apply repetends - for blocks in self._step_segments: - if len(blocks) == 1 and isinstance(blocks[0], Repetend): - blocks[0].apply() - # step 3: apply this scheduling + # step 2: apply this scheduling self._place_adapters() self._place_dataloader() - # step 4: generate topological sequence + # step 3: generate topological sequence self.topo_sort() def validate(self) -> bool: @@ -393,10 +373,10 @@ def validate(self) -> bool: @return valid bool """ - for block1 in self._segments: - for block2 in self._segments: - if self._dependency.depend(block1, block2): - if self.step(block1) >= self.step(block2): + for block1 in self._blocks: + for block2 in self._blocks: + if self._dependency.depends(block1, block2): + if self.start(block1) >= self.start(block2): return False return True @@ -413,28 +393,30 @@ def _place_adapters(self): sender: IRSegment = self._dependency.senders[adapter] # find sender step and insert adapter for step in range(self.nsteps): - blocks = self.segments(step) - if len(blocks) == 1 and isinstance(blocks[0], Repetend): - self._step_adapters[step] += list(blocks[0].get_post_adapters()) - else: - assert all(isinstance(blk, Block) for blk in blocks) - segments = [block.content for block in blocks] - mids = [block.mid for block in blocks] - if sender in segments: - span = blocks[segments.index(sender)].span - mid = mids[segments.index(sender)] - self._step_adapters[step+span-1].append(Block(adapter, mid, 1)) + blocks = self.start_blocks(step) + assert all(isinstance(blk, Block) for blk in blocks) + segments = [block.content for block in blocks] + mids = [block.mid for block in blocks] + if sender in segments: + span = blocks[segments.index(sender)].span + mid = mids[segments.index(sender)] + self._step_adapters[step+span-1].append(Block(adapter, mid, 1)) def topo_sort(self): super().topo_sort() for reducer in self._dependency.reducers: self._seqs.append(reducer) - def __repr__(self) -> str: + def str(self, show_max_steps: Optional[int] = None) -> str: + if show_max_steps is None: + show_max_steps = self.nsteps + dscp = f"SchedulePlan:\n" + if show_max_steps < self.nsteps: + dscp += f"only show the first {show_max_steps} steps\n" sids: Dict[IRCell, int] = {} - for block in self._segments: + for block in self._blocks: if block.content not in sids: sids[block.content] = len(sids) @@ -444,14 +426,13 @@ def __repr__(self) -> str: dscp += '\n' dscp += '\nAnnotation: i(f/b)j = segment i on executing (forward/backward) microbatch j' - for devid in sorted(self.device): timeline = '\n' step = 0 - while step < self.nsteps: + while step < min(self.nsteps, show_max_steps): # segment have_block = False - for block in self._step_segments[step]: + for block in self._step_blocks[step]: if devid in block.device: have_block = True break @@ -472,5 +453,10 @@ def __repr__(self) -> str: # timeline += ' {0: <5}'.format('adapt') # else: # timeline += ' {0: <5}'.format('') + if show_max_steps < self.nsteps: + timeline += f" ... (remaining {self.nsteps-show_max_steps} steps)" dscp += timeline return dscp + + def __repr__(self): + return self.str(show_max_steps=20) diff --git a/cube/graph/schedule/strategy.py b/cube/graph/schedule/strategy.py deleted file mode 100644 index 323b4e08..00000000 --- a/cube/graph/schedule/strategy.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Tuple, Dict, Any, List -from cube.graph.graph import IRGraph, IRSegment -from cube.ir.adapter.adapter import IRAdapter, IRWeightReducer -from cube.ir.cten import IRCell - - -class IRScheduleStrategy: - - def __init__(self, graph: IRGraph, nmicros: int) -> None: - self.graph : IRGraph = graph - self.nmicros : int = nmicros - self.devmesh: List[Tuple[int]] = [] - # preprocess before segments - self.pre_process: List[IRCell] = [] - self.segments: List[IRSegment] = [] - # the recver adapters for this segment - self.recvers: Dict[IRSegment, List[IRAdapter]] = dict() - # the sender adapters for this segment - self.senders: Dict[IRSegment, List[IRAdapter]] = dict() - # postprocess of weight reducers - self.reducers: List[IRWeightReducer] = [] - self.signature: str = '' - - def apply(self, graph: IRGraph) -> IRGraph: - raise NotImplementedError - - def kwargs(self, device: int) -> Dict[str, Any]: - raise NotImplementedError - - def mesh(self) -> List[List[int]]: - """! - Group operators into segments corresponding to graph stage. - Reorder adapter output to match with segment input order - """ - for segment in self.graph.nodes(): - if isinstance(segment, IRSegment): - self.segments.append(segment) - self.devmesh.append(segment.device) - self.recvers[segment] = [] - self.senders[segment] = [] - - for adapter in self.graph.nodes(): - if isinstance(adapter, IRAdapter): - for segment in self.segments: - if self.graph.depends(adapter, segment): - self.recvers[segment].append(adapter) - elif self.graph.depends(segment, adapter): - self.senders[segment].append(adapter) - if isinstance(adapter, IRWeightReducer): - self.reducers.append(adapter) diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py index 898c5fec..4fea0597 100644 --- a/cube/runtime/__init__.py +++ b/cube/runtime/__init__.py @@ -4,4 +4,3 @@ from cube.runtime import resource from cube.runtime import module from cube.runtime import function -from cube.runtime import schedule diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 2520476c..750b81c8 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -2,7 +2,7 @@ import torch from cube.runtime.device import DeviceGroup -from cube.profiler.timer import CudaTimer, print_each_rank +from cube.profiler.timer import CudaTimer from cube.runtime.executor import AsyncCommHandler @@ -275,32 +275,3 @@ def broadcast(itensor: torch.Tensor, shape: Tuple[int], dtype: torch.dtype, src: if not async_op: CudaTimer().stop(field_name='comm', predefined=True) return tensor - - -def exchange(tensor: torch.Tensor, ranks: List[int], async_op=False) -> torch.Tensor: - """ - Exchange a same-shaped tensor between two ranks - """ - - if not async_op: - CudaTimer().start(field_name='comm', predefined=True) - - assert len(ranks) == 2 - group = DeviceGroup().get_group(ranks) - myrank = torch.distributed.get_rank(group) - - tensor_list = [tensor, torch.empty_like(tensor)] if myrank == 0 \ - else [torch.empty_like(tensor), tensor.data] - - work = torch.distributed.all_gather(tensor_list, tensor, group=group, async_op=async_op) - if work: - exchange_callback = lambda t: tensor_list[(myrank + 1) % 2] - AsyncCommHandler().submit(tensor, [work], exchange_callback) - otensor = tensor - else: - otensor = tensor_list[(myrank + 1) % 2] - - if not async_op: - CudaTimer().stop(field_name='comm', predefined=True) - - return otensor diff --git a/cube/runtime/schedule/__init__.py b/cube/runtime/schedule/__init__.py deleted file mode 100644 index 34b7d0c7..00000000 --- a/cube/runtime/schedule/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from cube.runtime.schedule.sched1f1b import Schedule1F1B -from cube.runtime.schedule.schedmix import ScheduleMix -from cube.runtime.schedule.schednf1b import ScheduleNF1B -from cube.runtime.schedule.schedinfer import ScheduleInfer \ No newline at end of file diff --git a/cube/runtime/schedule/sched1f1b.py b/cube/runtime/schedule/sched1f1b.py deleted file mode 100644 index c8d4bdde..00000000 --- a/cube/runtime/schedule/sched1f1b.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Callable, Iterable, List -import torch - -from cube.runtime.schedule.strategy import ScheduleABC - - -class Schedule1F1B(ScheduleABC): - - @staticmethod - def run(segment: Callable, # forward body - rfadapter: Callable, # recv_forward adapter - sfadapter: Callable, # send_forward adapter - rbadapter: Callable, # recv_backward adapter - sbadapter: Callable, # send_backward adapter - dataloader: Iterable, - stage_id: int, - num_stages: int, - num_microbatch: int, - reducers: List[Callable], # weight reducers - recompute=False): - - # special case: num_stages == 1: use gradient accum - if num_stages == 1: - for _ in range(num_microbatch): - inputs = Schedule1F1B.dataloader_step(dataloader) - outputs = Schedule1F1B.forward_step(segment, *inputs) - input_grads = Schedule1F1B.backward_step(inputs, outputs, (None,)) - for reducer in reducers: - reducer() - return - - num_warmup_microbatches = num_stages - 1 - stage_id - num_warmup_remaining = num_microbatch - num_warmup_microbatches - - # warmup - for _ in range(num_warmup_microbatches): - # recv forward - # print(f'rank[{torch.distributed.get_rank()}]: line26: recving forward') - inputs = Schedule1F1B.adapter_step(rfadapter, True) - inputs = Schedule1F1B.dataloader_step(dataloader) if inputs == (None,) else inputs - # forward - Schedule1F1B.push_tail('inputs', inputs) - if recompute: - with torch.no_grad(): - outputs = Schedule1F1B.forward_step(segment, *inputs) - Schedule1F1B.push_tail('outputs', None) - else: - # print(f'rank[{torch.distributed.get_rank()}]: line36: forward') - outputs = Schedule1F1B.forward_step(segment, *inputs) - Schedule1F1B.push_tail('outputs', outputs) - # send forward - # print(f'rank[{torch.distributed.get_rank()}]: line40 send forward') - Schedule1F1B.adapter_step(sfadapter, True, *outputs) - - if num_warmup_remaining > 0: - # print(f'rank[{torch.distributed.get_rank()}]: line44 recv forward') - inputs = Schedule1F1B.adapter_step(rfadapter, True) - inputs = Schedule1F1B.dataloader_step(dataloader) if inputs == (None,) else inputs - - # steady - for i in range(num_warmup_remaining): - # forward - Schedule1F1B.push_tail('inputs', inputs) - if recompute: - with torch.no_grad(): - outputs = Schedule1F1B.forward_step(segment, *inputs) - Schedule1F1B.push_tail('outputs', None) - else: - # print(f'rank[{torch.distributed.get_rank()}]: line 57 forward') - outputs = Schedule1F1B.forward_step(segment, *inputs) - Schedule1F1B.push_tail('outputs', outputs) - - # send forward recv backward - # print(f'rank[{torch.distributed.get_rank()}]: line62 send forward recv backward') - grads = Schedule1F1B.exchange(sfadapter, rbadapter, stage_id, (True, False), *outputs) - grads = (None,) if len(grads) == 0 else grads - - # backward - inputs, outputs = Schedule1F1B.pop_head('inputs'), Schedule1F1B.pop_head('outputs') - if recompute: - assert outputs is None - outputs = Schedule1F1B.forward_step(segment, *inputs) - # print(f'rank[{torch.distributed.get_rank()}]: line71 backward') - input_grads = Schedule1F1B.backward_step(inputs, outputs, grads) - - # send backward recv forward - if i != num_warmup_remaining - 1: - # print(f'rank[{torch.distributed.get_rank()}]: line77 send backward recv forward') - inputs = Schedule1F1B.exchange(sbadapter, rfadapter, stage_id, (False, True), *input_grads) - inputs = Schedule1F1B.dataloader_step(dataloader) if inputs == (None,) else inputs - else: - # send backward - # print(f'rank[{torch.distributed.get_rank()}]: line82 send backward') - Schedule1F1B.adapter_step(sbadapter, False, *input_grads) - - # cooldown - for i in range(num_warmup_microbatches): - inputs, outputs = Schedule1F1B.pop_head('inputs'), Schedule1F1B.pop_head('outputs') - # recv backward - # print(f'rank[{torch.distributed.get_rank()}]: line89 recv backward') - grads = Schedule1F1B.adapter_step(rbadapter, False) - grads = (None,) if len(grads) == 0 else grads - # backward - if recompute: - assert outputs is None - outputs = Schedule1F1B.forward_step(segment, *inputs) - # print(f'rank[{torch.distributed.get_rank()}]: line96 backward') - input_grads = Schedule1F1B.backward_step(inputs, outputs, grads) - # send backward - # print(f'rank[{torch.distributed.get_rank()}]: line99 send backward') - Schedule1F1B.adapter_step(sbadapter, False, *input_grads) - - # allreduce gradient - for reducer in reducers: - reducer() - - Schedule1F1B.assert_empty() - # print(f'rank[{torch.distributed.get_rank()}]: ok here') diff --git a/cube/runtime/schedule/schedinfer.py b/cube/runtime/schedule/schedinfer.py deleted file mode 100644 index 0ba9a498..00000000 --- a/cube/runtime/schedule/schedinfer.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Callable, Iterable, List, Optional -import torch - -from cube.runtime.schedule.strategy import ScheduleABC - - -class ScheduleInfer(ScheduleABC): - - @staticmethod - def run(segment: Callable, # forward body - rfadapter: Optional[Callable], # recv forward adapter - sfadapter: Optional[Callable], # send forward adapter - dataloader: Iterable, - num_microbatch: int): - - for _ in range(num_microbatch): - # recv forward - inputs = ScheduleInfer.adapter_step(rfadapter, False) - inputs = ScheduleInfer.dataloader_step(dataloader) if inputs == (None,) else inputs - # forward - outputs = ScheduleInfer.forward_step(segment, *inputs) - # send forward - ScheduleInfer.adapter_step(sfadapter, True, *outputs) - - ScheduleInfer.assert_empty() diff --git a/cube/runtime/schedule/schedmix.py b/cube/runtime/schedule/schedmix.py deleted file mode 100644 index 425c5041..00000000 --- a/cube/runtime/schedule/schedmix.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Schedule Plan designed for Interplaced Pipeline -""" - -from typing import Callable, Iterable, List, Optional -import torch - -from cube.runtime.schedule.strategy import ScheduleABC - -def debug_msg(msg: str, ranks): - myrank = torch.distributed.get_rank() - if myrank in ranks: - print(f'rank [{myrank}]: {msg}') - - -class ScheduleMix(ScheduleABC): - """ - Emb -> Encoder -> Demb -> Decoder - - All communication will start at begining of each step and - finish at the end of step. No communication will happen cross - step, i.e., send from the previous step and recv at the next step. - """ - @staticmethod - def run(encoder_barrier: Callable, - decoder_barrier: Callable, - segment: Callable, - rfadapter: Optional[Callable], # segment adapter - sfadapter: Optional[Callable], # segment adapter - rbadapter: Optional[Callable], # segment adapter - sbadapter: Optional[Callable], # segment adapter - enc_badapter: Optional[Callable], # sharding encoder gradient input prepare adapter - dec_fadapter: Optional[Callable], # sharding decoder input prepare adapter - dec_badapter: Optional[Callable], # sharding decoder gradient input prepare adapter - dataloader: Iterable, - stage_id: int, - num_stages: int, - num_microbatch: int, - reducers: List[Callable], - recompute: bool = False): - - assert num_stages >= 4, f"Only support for stage number >= 4." - - enc_emb_stage = 0 - dec_emb_stage = num_stages // 2 - - fw_ofst = -(stage_id // 2) - bw_ofst = -(num_stages - 1 - (stage_id // 2)) - - # sharding encoder embed inputs / outputs - shard_enc_inputs, shard_enc_outputs = (None,), (None,) - shard_enc_input_grads, shard_enc_output_grads = (None,), (None,) - # sharding decoder embed inputs / outputs - shard_dec_inputs, shard_dec_outputs = (None,), (None,) - shard_dec_input_grads, shard_dec_output_grads = (None,), (None,) - # segement inputs / outputs - segment_inputs, segment_outputs = (None,), (None,) - segment_input_grads, segment_output_grads = (None,), (None,) - - for step in range(num_microbatch + num_stages - 1): - fmid, bmid = step + fw_ofst, step + bw_ofst - encoder_fw_mid = step - decoder_fw_mid = step - num_stages // 2 // 2 - encoder_bw_mid = step + 1 - num_stages // 2 * 2 - decoder_bw_mid = step + 1 - int(num_stages // 2 * 1.5) - do_forward = 0 <= fmid and fmid < num_microbatch - do_backward = 0 <= bmid and bmid < num_microbatch - - # step1: sharding encoder forward - if 0 <= encoder_fw_mid and encoder_fw_mid < num_microbatch: - data = ScheduleMix.dataloader_step(dataloader) - shard_enc_outputs = ScheduleMix.forward_step(encoder_barrier, *data) - ScheduleMix.push_tail('shard_enc_inputs', data) - ScheduleMix.push_tail('shard_enc_outputs', shard_enc_outputs) - shard_enc_outputs = tuple(t.detach().requires_grad_() for t in shard_enc_outputs) - - # step2: sharding decoder forward - if 0 <= decoder_fw_mid and decoder_fw_mid < num_microbatch: - if stage_id == dec_emb_stage - 1: - shard_dec_inputs = tuple(t.detach().requires_grad_() for t in segment_outputs) - ScheduleMix.adapter_step(dec_fadapter, True, *shard_dec_inputs) - else: - shard_dec_inputs = ScheduleMix.adapter_step(dec_fadapter, True) - shard_dec_outputs = ScheduleMix.forward_step(decoder_barrier, *shard_dec_inputs) - ScheduleMix.push_tail('shard_dec_inputs', shard_dec_inputs) - ScheduleMix.push_tail('shard_dec_outputs', shard_dec_outputs) - shard_dec_outputs = tuple(t.detach().requires_grad_() for t in shard_dec_outputs) - - # step3: forward then backward - if stage_id % 2 == 0: - - # After barrier communication: send backward recv forward =========> - if segment_input_grads != (None,): - ScheduleMix.adapter_step(sbadapter, False, *segment_input_grads) - segment_input_grads = (None,) - if do_forward: - if stage_id == enc_emb_stage: - segment_inputs = shard_enc_outputs - elif stage_id == dec_emb_stage: - segment_inputs = shard_dec_outputs - else: - segment_inputs = ScheduleMix.adapter_step(rfadapter, True) - # <=============================================================== - - segment_outputs = (None,) - if do_forward: - ScheduleMix.push_tail('segment_inputs', segment_inputs) - if recompute: - with torch.no_grad(): - segment_outputs = ScheduleMix.forward_step(segment, *segment_inputs) - ScheduleMix.push_tail('segment_outputs', None) - else: - segment_outputs = ScheduleMix.forward_step(segment, *segment_inputs) - ScheduleMix.push_tail('segment_outputs', segment_outputs) - - # recompute - if recompute and do_backward: - inputs = ScheduleMix.pop_head('segment_inputs') - ScheduleMix.pop_head('segment_outputs') - outputs = ScheduleMix.forward_step(segment, *inputs) - ScheduleMix.push_head('segment_inputs', inputs) - ScheduleMix.push_head('segment_outputs', outputs) - - # Inter barrier communication: recv backward send forward ======> - if do_backward: - segment_output_grads = ScheduleMix.adapter_step(rbadapter, False) - if segment_outputs != (None,): - ScheduleMix.adapter_step(sfadapter, True, *segment_outputs) - # <=============================================================== - - segment_input_grads = (None,) - if do_backward: - inputs = ScheduleMix.pop_head('segment_inputs') - outputs = ScheduleMix.pop_head('segment_outputs') - segment_input_grads = ScheduleMix.backward_step(inputs, outputs, segment_output_grads) - - # step3: backward then forward - if stage_id % 2 == 1: - - # After barrier communication: recv backward send forward =========> - if do_backward: - if stage_id == dec_emb_stage - 1: - segment_output_grads = shard_dec_input_grads - else: - segment_output_grads = ScheduleMix.adapter_step(rbadapter, False) - if segment_outputs != (None,): - segment_input_grads = ScheduleMix.adapter_step(sfadapter, True, *segment_outputs) - # <=============================================================== - - segment_input_grads = (None,) - if do_backward: - inputs = ScheduleMix.pop_head('segment_inputs') - outputs = ScheduleMix.pop_head('segment_outputs') - segment_input_grads = ScheduleMix.backward_step(inputs, outputs, segment_output_grads) - - # Inter barrier communication: send backward recv forward ========> - if segment_input_grads != (None,): - ScheduleMix.adapter_step(sbadapter, False, *segment_input_grads) - if do_forward: - segment_inputs = ScheduleMix.adapter_step(rfadapter, True) - # <=============================================================== - - segment_outputs = (None,) - if do_forward: - ScheduleMix.push_tail('segment_inputs', segment_inputs) - if recompute: - with torch.no_grad(): - segment_outputs = ScheduleMix.forward_step(segment, *segment_inputs) - ScheduleMix.push_tail('segment_outputs', None) - else: - segment_outputs = ScheduleMix.forward_step(segment, *segment_inputs) - ScheduleMix.push_tail('segment_outputs', segment_outputs) - - # recompute - if recompute and (0 <= bmid + 1 and bmid + 1 < num_microbatch): - inputs = ScheduleMix.pop_head('segment_inputs') - ScheduleMix.pop_head('segment_outputs') - outputs = ScheduleMix.forward_step(segment, *inputs) - ScheduleMix.push_head('segment_inputs', inputs) - ScheduleMix.push_head('segment_outputs', outputs) - - # step 4: sharding decoder backward - if 0 <= decoder_bw_mid and decoder_bw_mid < num_microbatch: - if stage_id == dec_emb_stage: - assert segment_input_grads != (None,) - shard_dec_output_grads = segment_input_grads - ScheduleMix.adapter_step(dec_badapter, False, *shard_dec_output_grads) - else: - shard_dec_output_grads = ScheduleMix.adapter_step(dec_badapter, False) - - inputs = ScheduleMix.pop_head('shard_dec_inputs') - outputs = ScheduleMix.pop_head('shard_dec_outputs') - shard_dec_input_grads = ScheduleMix.backward_step( - inputs, outputs, shard_dec_output_grads) - - # step 5: sharding encoder backward - if 0 <= encoder_bw_mid and encoder_bw_mid < num_microbatch: - if stage_id == enc_emb_stage: - assert segment_input_grads != (None,) - shard_enc_output_grads = segment_input_grads - ScheduleMix.adapter_step(enc_badapter, False, *shard_enc_output_grads) - else: - shard_enc_output_grads = ScheduleMix.adapter_step(enc_badapter, False) - - inputs = ScheduleMix.pop_head('shard_enc_inputs') - outputs = ScheduleMix.pop_head('shard_enc_outputs') - shard_enc_input_grads = ScheduleMix.backward_step( - inputs, outputs, shard_enc_output_grads) - - for reducer in reducers: - reducer() - - ScheduleMix.assert_empty() diff --git a/cube/runtime/schedule/schednf1b.py b/cube/runtime/schedule/schednf1b.py deleted file mode 100644 index 23cb8d75..00000000 --- a/cube/runtime/schedule/schednf1b.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -Schedule Plan tailored for AlphaFold - -The scheduling follows forward-backward pattern. -In steady phase, each forward will perform a single forward at `recycle+1` -micro-batches, with one keeping activation while others no activation. - -""" -from typing import Callable, Iterable, List, Tuple -from functools import partial -import torch - -from cube.runtime.schedule.strategy import ScheduleABC - - -def first_stage_rfadapter(shapes: Tuple[List[int]], dtypes: Tuple[List[torch.dtype]], dataloader): - outputs = next(dataloader) - outputs = tuple(t.clone() for t in outputs) - return outputs - -def last_stage_sfadapter(*msa_repr_and_pair_repr: torch.Tensor): - pass - - -class ScheduleNF1B(ScheduleABC): - - @staticmethod - def _deprecate_run(segment: Callable, # forward body - rfadapter: Callable, # recv_forward adapter - sfadapter: Callable, # send_forward adapter - rbadapter: Callable, # recv_backward adapter - sbadapter: Callable, # send_backward adapter - dataloader: Iterable, - stage_id: int, - num_stages: int, - num_microbatch: int, - recycle: int, - reducers: List[Callable]): - - assert num_microbatch >= num_stages - - # special case: num_stages == 1: use gradient accum - if num_stages == 1: - for _ in range(num_microbatch): - inputs = ScheduleNF1B.dataloader_step(dataloader) - for _ in range(recycle): - # FIXME: a simulation as output will be loss - with torch.no_grad(): - _ = ScheduleNF1B.forward_step(segment, *inputs) - outputs = ScheduleNF1B.forward_step(segment, *inputs) - input_grads = ScheduleNF1B.backward_step(inputs, outputs, (None,)) - for reducer in reducers: - reducer() - print(f'> rank [{torch.distributed.get_rank()}]: {ScheduleNF1B._fw_cnt}') - return - - # =============================== recycle ==================================== - if stage_id == 0: - assert rfadapter is None - shapes, dtypes = [], [] - for data in ScheduleNF1B.dataloader_step(dataloader): - shapes.append(list(data.size())) - dtypes.append(data.dtype) - rfadapter = partial(first_stage_rfadapter, shapes=shapes, dtypes=dtypes, dataloader=dataloader) - # if stage_id == num_stages - 1: - # assert sfadapter is None - # sfadapter = last_stage_sfadapter - - for rid in range(recycle): - for mid in range(num_microbatch): - # recv forward - if stage_id == 0 and rid == 0: - inputs = ScheduleNF1B.dataloader_step(dataloader) - else: - inputs = ScheduleNF1B.adapter_step(rfadapter, require_grad=(rid == recycle-1)) - # forward - with torch.no_grad(): - outputs = ScheduleNF1B.forward_step(segment, *inputs) - # FIXME: a simulation - if stage_id == num_stages - 1: - outputs = ScheduleNF1B.dataloader_step(dataloader) - # send forward - ScheduleNF1B.adapter_step(sfadapter, False, *outputs) - # recv forward batches TODO: optmize with async - datas = [] - if stage_id == 0: - for mid in range(num_microbatch): - inputs = ScheduleNF1B.adapter_step(rfadapter, require_grad=False) - inputs = (t.cpu() for t in inputs) - datas.append(inputs) - # ========================================================================== - - # 1F1B schedule - if stage_id == 0: rfadapter = None - if stage_id == num_stages - 1: sfadapter = None - num_warmup_microbatches = num_stages - 1 - stage_id - num_warmup_remaining = num_microbatch - num_warmup_microbatches - - # warmup - for _ in range(num_warmup_microbatches): - # recv forward - # print(f'rank[{torch.distributed.get_rank()}]: line26: recving forward') - inputs = ScheduleNF1B.adapter_step(rfadapter, True) - inputs = datas.pop(0) if inputs == (None,) else inputs - inputs = tuple(t.cuda() for t in inputs) - # forward - ScheduleNF1B.push_tail('inputs', inputs) - outputs = ScheduleNF1B.forward_step(segment, *inputs) - ScheduleNF1B.push_tail('outputs', outputs) - # send forward - # print(f'rank[{torch.distributed.get_rank()}]: line40 send forward') - ScheduleNF1B.adapter_step(sfadapter, True, *outputs) - - if num_warmup_remaining > 0: - # print(f'rank[{torch.distributed.get_rank()}]: line44 recv forward') - inputs = ScheduleNF1B.adapter_step(rfadapter, True) - inputs = datas.pop(0) if inputs == (None,) else inputs - - # steady - for i in range(num_warmup_remaining): - # forward - ScheduleNF1B.push_tail('inputs', inputs) - # print(f'rank[{torch.distributed.get_rank()}]: line 57 forward') - outputs = ScheduleNF1B.forward_step(segment, *inputs) - ScheduleNF1B.push_tail('outputs', outputs) - - # send forward recv backward - # print(f'rank[{torch.distributed.get_rank()}]: line62 send forward recv backward') - grads = ScheduleNF1B.exchange(sfadapter, rbadapter, stage_id, (True, False), *outputs) - grads = (None,) if len(grads) == 0 else grads - - # backward - inputs, outputs = ScheduleNF1B.pop_head('inputs'), ScheduleNF1B.pop_head('outputs') - # print(f'rank[{torch.distributed.get_rank()}]: line71 backward') - input_grads = ScheduleNF1B.backward_step(inputs, outputs, grads) - - # send backward recv forward - if i != num_warmup_remaining - 1: - # print(f'rank[{torch.distributed.get_rank()}]: line77 send backward recv forward') - inputs = ScheduleNF1B.exchange(sbadapter, rfadapter, stage_id, (False, True), *input_grads) - inputs = datas.pop(0) if inputs == (None,) else inputs - inputs = tuple(t.cuda() for t in inputs) - else: - # send backward - # print(f'rank[{torch.distributed.get_rank()}]: line82 send backward') - ScheduleNF1B.adapter_step(sbadapter, False, *input_grads) - - # cooldown - for i in range(num_warmup_microbatches): - inputs, outputs = ScheduleNF1B.pop_head('inputs'), ScheduleNF1B.pop_head('outputs') - # recv backward - # print(f'rank[{torch.distributed.get_rank()}]: line89 recv backward') - grads = ScheduleNF1B.adapter_step(rbadapter, False) - grads = (None,) if len(grads) == 0 else grads - # backward - # print(f'rank[{torch.distributed.get_rank()}]: line96 backward') - input_grads = ScheduleNF1B.backward_step(inputs, outputs, grads) - # send backward - # print(f'rank[{torch.distributed.get_rank()}]: line99 send backward') - ScheduleNF1B.adapter_step(sbadapter, False, *input_grads) - - # allreduce gradient - for reducer in reducers: - reducer() - - assert len(datas) == 0 - ScheduleNF1B.assert_empty() - - - @staticmethod - def run(segment: Callable, # forward body - rfadapter: Callable, # recv_forward adapter - sfadapter: Callable, # send_forward adapter - rbadapter: Callable, # recv_backward adapter - sbadapter: Callable, # send_backward adapter - dataloader: Iterable, - stage_id: int, - num_stages: int, - num_microbatch: int, - recycle: int, - reducers: List[Callable]): - - assert num_microbatch >= num_stages - - # special case: num_stages == 1: use gradient accum - if num_stages == 1: - for _ in range(num_microbatch): - inputs = ScheduleNF1B.dataloader_step(dataloader) - for _ in range(recycle): - # FIXME: a simulation as output will be loss - with torch.no_grad(): - _ = ScheduleNF1B.forward_step(segment, *inputs) - outputs = ScheduleNF1B.forward_step(segment, *inputs) - input_grads = ScheduleNF1B.backward_step(inputs, outputs, (None,)) - for reducer in reducers: - reducer() - # print(f'> rank [{torch.distributed.get_rank()}]: {ScheduleNF1B._fw_cnt}') - return - - # setup dummpy adapter - if stage_id == 0: - assert rfadapter is None - shapes, dtypes = [], [] - for data in ScheduleNF1B.dataloader_step(dataloader): - shapes.append(list(data.size())) - dtypes.append(data.dtype) - rfadapter = partial(first_stage_rfadapter, shapes=shapes, dtypes=dtypes, dataloader=dataloader) - if stage_id == num_stages - 1: - assert sfadapter is None - sfadapter = last_stage_sfadapter - - # =============================== warmup ======================== - for rid in range(recycle): - # forward rid micro-batches - for t in range(rid+1): - inputs = ScheduleNF1B.adapter_step(rfadapter, False) - inputs = ScheduleNF1B.dataloader_step(dataloader) if inputs == (None,) else inputs - with torch.no_grad(): - outputs = ScheduleNF1B.forward_step(segment, *inputs) - ScheduleNF1B.adapter_step(sfadapter, False, *outputs) - - # print(f'> rank [{torch.distributed.get_rank()}]: OK here') - - # recv inputs - inputs = ScheduleNF1B.adapter_step(rfadapter, stage_id != 0) - - # steady pattern - for fmid in range(num_microbatch + num_stages - 1 - stage_id): - - # ======================= forward region ==================== - if fmid + 1 < num_microbatch: - with torch.no_grad(): - outputs = ScheduleNF1B.forward_step(segment, *inputs) - - # ================== send forward recv backward ================== - send_fw = fmid + 1 < num_microbatch - bmid = fmid - (num_stages - 1 - stage_id) - recv_bw = 0 <= bmid and bmid < num_microbatch - if send_fw and recv_bw: - grads = ScheduleNF1B.exchange(sfadapter, rbadapter, stage_id, (False, False), *outputs) - elif send_fw: - ScheduleNF1B.adapter_step(sfadapter, False, *outputs) - elif recv_bw: - grads = ScheduleNF1B.adapter_step(rbadapter, False) - else: - assert False, f"> rank [{torch.distributed.get_rank()}]: Fail at fmid: {fmid}" - - # ===================== backward region ================== - - # recycle inference - for idx in range(recycle - 1): - if fmid + 2 + idx < num_microbatch: - # recv forward - inputs = ScheduleNF1B.adapter_step(rfadapter, False) - # forward - with torch.no_grad(): - outputs = ScheduleNF1B.forward_step(segment, *inputs) - # send forward - ScheduleNF1B.adapter_step(sfadapter, False, *outputs) - - # train forward - if fmid < num_microbatch: - # recv forward - inputs = ScheduleNF1B.adapter_step(rfadapter, stage_id != 0) - # forward - ScheduleNF1B.push_tail('inputs', inputs) - outputs = ScheduleNF1B.forward_step(segment, *inputs) - ScheduleNF1B.push_tail('outputs', outputs) - # send forward - ScheduleNF1B.adapter_step(sfadapter, True, *outputs) - - # train backward - bmid = fmid - (num_stages - 1 - stage_id) - if 0 <= bmid and bmid < num_microbatch: - inputs, outputs = ScheduleNF1B.pop_head('inputs'), ScheduleNF1B.pop_head('outputs') - input_grads = ScheduleNF1B.backward_step(inputs, outputs, grads) - - # =============== send backward recv forward ===================== - send_bw = 0 <= bmid and bmid < num_microbatch - recv_fw = fmid + 2 < num_microbatch - if send_bw and recv_fw: - ScheduleNF1B.exchange(sbadapter, rfadapter, stage_id, (False, False), *input_grads) - elif send_bw: - ScheduleNF1B.adapter_step(sbadapter, False, *input_grads) - elif recv_fw: - inputs = ScheduleNF1B.adapter_step(rfadapter, False) - - for reducer in reducers: - reducer() - - ScheduleNF1B.assert_empty() - # print(f'> rank [{torch.distributed.get_rank()}]: {ScheduleNF1B._fw_cnt}') diff --git a/cube/runtime/schedule/strategy.py b/cube/runtime/schedule/strategy.py deleted file mode 100644 index 581c702c..00000000 --- a/cube/runtime/schedule/strategy.py +++ /dev/null @@ -1,133 +0,0 @@ -from typing import Any, Callable, Dict, Iterable, List -import torch - -from cube.runtime.executor import AsyncCommHandler -from cube.flags import CompileFlag -from cube.profiler.timer import CudaTimer - - -class ScheduleABC: - - status: Dict[str, List[torch.Tensor]] = dict() - - @staticmethod - def forward_step(segment: Callable, *args, **kwargs): - """ - forward pass - """ - args = ScheduleABC.sync_tensors(args) - if not CompileFlag.async_comm: - CudaTimer().start('forward') - outputs = segment(*args, **kwargs) - if not CompileFlag.async_comm: - CudaTimer().stop('forward') - if not isinstance(outputs, tuple): - outputs = (outputs,) - return outputs - - @staticmethod - def backward_step(itensors: List[torch.Tensor], - otensors: List[torch.Tensor], - otensor_grads: List[torch.Tensor]) -> List[torch.Tensor]: - """ - backward pass - """ - otensor_grads = ScheduleABC.sync_tensors(otensor_grads) - for tensor in itensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - tensor.retain_grad() - if not CompileFlag.async_comm: - CudaTimer().start("backward") - otensors = [t for t in otensors if t.requires_grad] - assert len(otensors) == len(otensor_grads), f"output tensor mismatches with gradient number" - torch.autograd.backward(otensors, grad_tensors=otensor_grads) - if not CompileFlag.async_comm: - CudaTimer().stop("backward") - itensor_grads = [] - for tensor in itensors: - if torch.is_tensor(tensor) and tensor.requires_grad: - itensor_grads.append(tensor.grad) - else: - itensor_grads.append(None) - return tuple(itensor_grads) - - @staticmethod - def dataloader_step(dataloader: Iterable): - data = next(dataloader) - if not isinstance(data, tuple): - data = (data,) - return data - - @staticmethod - def adapter_step(adapter: Callable, require_grad : bool = True, *args): - """ - Adapter pass. - If the adapter is None, will return (None,) - """ - if adapter is None: return (None,) - # if adapter is None: return () - args = tuple(t for t in args if torch.is_tensor(t)) - if not CompileFlag.async_comm: - CudaTimer().start('adapter') - outputs = adapter(*args) - if not CompileFlag.async_comm: - CudaTimer().stop('adapter') - if not isinstance(outputs, tuple): - outputs = (outputs,) - if require_grad: - grad_dtypes = (torch.float16, torch.float32) - outputs = tuple(t.requires_grad_() if torch.is_tensor(t) and t.dtype in grad_dtypes else t for t in outputs) - return outputs - - @staticmethod - def exchange(sadapter: Callable, radapter: Callable, stage_id: int, require_grads: bool, *args): - """ - send adapter and recv adapter - """ - # TODO: optimize with batch operators - if stage_id % 2 == 0: - ScheduleABC.adapter_step(sadapter, require_grads[0], *args) - outs = ScheduleABC.adapter_step(radapter, require_grads[1]) - else: - outs = ScheduleABC.adapter_step(radapter, require_grads[1]) - ScheduleABC.adapter_step(sadapter, require_grads[0], *args) - return outs - - @staticmethod - def push_tail(name: str, val: Any): - if name not in ScheduleABC.status: - ScheduleABC.status[name] = [] - ScheduleABC.status[name].append(val) - - @staticmethod - def push_head(name: str, val: Any): - if name not in ScheduleABC.status: - ScheduleABC.status[name] = [] - ScheduleABC.status[name].insert(0, val) - - @staticmethod - def pop_head(name: str): - assert name in ScheduleABC.status, f"{name} is empty" - out = ScheduleABC.status[name].pop(0) - if len(ScheduleABC.status[name]) == 0: - del ScheduleABC.status[name] - return out - - @staticmethod - def pop_tail(name: str): - assert name in ScheduleABC.status, f"{name} is empty" - out = ScheduleABC.status[name].pop(-1) - if len(ScheduleABC.status[name]) == 0: - del ScheduleABC.status - return out - - @staticmethod - def sync_tensors(tensors: List[Any]) -> List[Any]: - """ - Wait until the finish of synchornized tensors - """ - return [AsyncCommHandler().wait(t) if torch.is_tensor(t) else t for t in tensors] - - @staticmethod - def assert_empty(): - assert len(ScheduleABC.status) == 0, f"status is not empty. Got field {list(ScheduleABC.status.keys())}" diff --git a/tests/runtime/test_runtime_collectives.py b/tests/runtime/test_runtime_collectives.py index 59580721..ff72261e 100644 --- a/tests/runtime/test_runtime_collectives.py +++ b/tests/runtime/test_runtime_collectives.py @@ -82,18 +82,6 @@ def _all2all_worker(async_op): return (clone_to_cpu(tensor), clone_to_cpu(otensor)) -def _exchange_worker(async_op): - shape = [128, 256] - - tensor = _get_tensor(shape) - otensor = cube.runtime.adapter.exchange(tensor, [0, 1], async_op=async_op) - - if async_op: - otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) - - return (clone_to_cpu(tensor), clone_to_cpu(otensor)) - - def _2gpu_worker(): _init_distributed(2) result = {} @@ -107,8 +95,6 @@ def _2gpu_worker(): result['reduce_scatter_async'] = _reduce_scatter_worker(True) result['all2all'] = _all2all_worker(False) result['all2all_async'] = _all2all_worker(True) - result['exchange'] = _exchange_worker(False) - result['exchange_async'] = _exchange_worker(True) return result @@ -150,11 +136,6 @@ def test_2gpu(): assert torch.equal(outputs[0][1], out0) assert torch.equal(outputs[1][1], out1) - # check exchange - outputs = results[0][f'exchange{op}'], results[1][f'exchange{op}'] - assert torch.equal(outputs[0][1], outputs[1][0]) - assert torch.equal(outputs[1][1], outputs[0][0]) - def _rdscatter_worker(async_op): shape = [128, 256] From 7c34c36f78162259b06435a9dbbf2708e9022a6d Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 25 Sep 2023 02:32:52 +0000 Subject: [PATCH 1500/1892] Merged PR 1841: refine code/fix version check refine code/fix version check --- .../fx/concrete_trace_utils/operator_patcher.py | 8 +++++--- cube/parallel.py | 4 ++-- tests/parallel_module/test_reducer_hook.py | 12 ++++++------ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index 7f2b109c..22405b7e 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -5,11 +5,11 @@ if TYPE_CHECKING: from .concrete_tracer import ConcreteTracer +import sys import ast import builtins import inspect import logging -import platform from textwrap import dedent from types import MethodType, FunctionType @@ -28,6 +28,7 @@ _logger = logging.getLogger(__name__) + class TransformerOp(ast.NodeTransformer): """ An ast transformer, to check and replace the python ops 'not/is/is not/in/not in' to functions in 'operator' module. @@ -242,7 +243,7 @@ def patch_inner_helper(self, func): tuple_wrapped = tuple try: - if platform.python_version_tuple() < ('3', '9'): + if sys.version_info < (3, 9): setattr(builtins, 'tuple', _orig_tuple) var_dict = {} exec( @@ -259,9 +260,10 @@ def patch_inner_helper(self, func): else: return var_dict['new_func'] finally: - if platform.python_version_tuple() < ('3', '9'): + if sys.version_info < (3, 9): setattr(builtins, 'tuple', tuple_wrapped) + class OperatorPatcherContext: ctx_tracer: Optional['ConcreteTracer'] = None ctx_patcher: Optional[OperatorPatcher] = None diff --git a/cube/parallel.py b/cube/parallel.py index 9a6a2d56..e1af8854 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -44,7 +44,7 @@ class ComputeConfig: @contextmanager -def _flags(flags, warning_on_override=True, **kwargs): +def _flags(flags, warning_on_override=True, /, **kwargs): old_flags = {} for k, v in kwargs.items(): old_flags[k] = getattr(flags, k) @@ -64,7 +64,7 @@ def _compile_flags(): def _runtime_flags(**kwargs): - return _flags(RuntimeFlag, warning_on_override=False, **kwargs) + return _flags(RuntimeFlag, False, **kwargs) def _complex(val: Any): diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index 14528467..63f1af2d 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -119,27 +119,27 @@ def _gpu_worker(pas, ngpus): def test_hook_tp_gpu1(): if not torch.cuda.is_available(): - print('skip test_submodules_tp_gpu1 due to lack of cuda devices') + print('skip test_hook_tp_gpu1 due to lack of cuda devices') return launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) def test_hook_tp_gpu2(): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - print('skip test_submodules_tp_gpu2 due to lack of cuda devices') + print('skip test_hook_tp_gpu2 due to lack of cuda devices') return launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) -def test_submodules_dp_gpu1(): +def test_hook_dp_gpu1(): if not torch.cuda.is_available(): - print('skip test_submodules_dp_gpu1 due to lack of cuda devices') + print('skip test_hook_dp_gpu1 due to lack of cuda devices') return launch_torchrun(1, _gpu_worker, PASData, 1) -def test_submodules_dp_gpu2(): +def test_hook_dp_gpu2(): if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - print('skip test_submodules_dp_gpu2 due to lack of cuda devices') + print('skip test_hook_dp_gpu2 due to lack of cuda devices') return launch_torchrun(2, _gpu_worker, PASData, 2) From 6ab20172588ee074264e9d2b5ce0ad6c86c5c24c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 11 Oct 2023 09:06:31 +0000 Subject: [PATCH 1501/1892] Merged PR 1846: fix example 1f1b and megatron policy this pr fixes the policy of PAS1F1B and PASMegatron, which had bugs of mis-alignment in API calls. --- examples/nlp/gpt/policy/mpmd.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index 279645a8..8d9978b7 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -57,7 +57,7 @@ def PAS1F1B(graph: IRGraph, resource, nmicros: int = 16, **kwargs): if isinstance(node, IRDataOperation): graph.assign(node, 0) - if graph.train(): + if graph.train: PredefinedSched.sched_1f1b(graph, num_microbatch, num_stages) else: PredefinedSched.sched_infer_pipe(graph, num_microbatch, num_stages) @@ -108,19 +108,19 @@ def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: for fnode in fstage.nodes(): if len(fnode.inputs()) == 0: continue # anchor if fnode.name == 'self_attention' or fnode.name == 'feedforward': - fnodes = tensor_parallelism(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + fnodes = tensor_parallelism(graph, fnode, idx=1, dim=0, devs=[0]*tp_size) elif fnode.name == 'embedding': - fnodes = tensor_parallelism(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + fnodes = tensor_parallelism(graph, fnode, idx=1, dim=0, devs=[0]*tp_size) elif fnode.name == 'linear': # the last embeding linear - fnodes = tensor_parallelism(graph, fnode, [0]*tp_size, idx=1, dim=0, num=tp_size) + fnodes = tensor_parallelism(graph, fnode, idx=1, dim=0, devs=[0]*tp_size) elif fnode.name == 'sum': - fnodes = tensor_parallelism(graph, fnode, [0]*tp_size, idx=0, dim=2, num=tp_size) + fnodes = tensor_parallelism(graph, fnode, idx=0, dim=2, devs=[0]*tp_size) else: fnodes = replica(graph, fnode, [0]*tp_size) # data parallel for tp_idx, fnode in enumerate(fnodes): dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] batch_dim = fnode.input(0).shape.index(bs) - tensor_parallelism(graph, fnode, idx=0, dim=batch_dim, num=dp_size, devs=dp_devices) + tensor_parallelism(graph, fnode, idx=0, dim=batch_dim, devs=dp_devices) PredefinedSched.sched_1f1b(graph, num_microbatch, pp_size) return graph From da68536a09dd861d872a5060bd166b3b29581af1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 19 Oct 2023 03:22:16 +0000 Subject: [PATCH 1502/1892] Merged PR 1844: Fix no-partition restriction of graph outputs in parallel module Fix no-partition limitation of graph outputs in parallel module. Related work items: #1511 --- cube/compiler.py | 1 - cube/parallel.py | 2 +- cube/program.py | 10 +++++++++- tests/parallel_module/common.py | 10 ---------- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index dafe120d..317edf72 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -101,7 +101,6 @@ def train_iter(model, dataloader): requires_grad=arg.requires_grad, dtype=arg.dtype).tosub() arg._value = tensor - arg.grad = arg.parent.grad.tosub() if arg.requires_grad else None else: arg = IRObject('obj', value=arg) inputs.append(arg) diff --git a/cube/parallel.py b/cube/parallel.py index e1af8854..0181f7ab 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -244,12 +244,12 @@ def _gencode( graph = program.get_graph() graph.backward() - program.finalize() program.set_input(ir_dummy_inputs) if ir_dummy_outputs is None: ir_dummy_outputs = [] elif not (isinstance(ir_dummy_outputs, tuple) or isinstance(ir_dummy_outputs, list)): ir_dummy_outputs = [ir_dummy_outputs] program.set_output(ir_dummy_outputs) + program.finalize() graph = pas_policy(graph, compute_config) if not isinstance(graph, IRGraph): diff --git a/cube/program.py b/cube/program.py index 7b81aded..d9b11757 100644 --- a/cube/program.py +++ b/cube/program.py @@ -2,7 +2,7 @@ import inspect from cube.ir.cten import IRCell, IRObject -from cube.ir.tensor import IRFullTensor +from cube.ir.tensor import IRFullTensor, IRSubTensor from cube.ir.operator import IRBpOperation, IRDataOperation from cube.graph import IRGraph @@ -47,11 +47,19 @@ def set_input(self, inputs: Tuple[Any]): self.instance._graph.reset_inputs(len(inputs)) for idx, obj in enumerate(inputs): self.instance._graph.set_input(idx, obj) + # update gradient + for t in IRGraph.get_objects_from_complex(self.instance._graph.inputs()): + if isinstance(t, IRSubTensor) and t.requires_grad: + t.grad = t.parent.grad.tosub() def set_output(self, outputs: Tuple[Any]): self.instance._graph.reset_outputs(len(outputs)) for idx, otensor in enumerate(outputs): self.instance._graph.set_output(idx, otensor) + # update gradient + for t in IRGraph.get_objects_from_complex(self.instance._graph.outputs()): + if isinstance(t, IRSubTensor) and t.requires_grad: + t.grad = t.parent.grad.tosub() def finalize(self): """ diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 17385b36..407955ca 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -50,19 +50,9 @@ def PASRandomSPMD(graph: IRGraph, env_resource: ComputeConfig): if len(graph.consumers(ftensor)) > 1: graph.multiref(ftensor) - graph_inputs = IRSegment.get_objects_from_complex(graph.inputs()) - graph_outputs = IRSegment.get_objects_from_complex(graph.outputs()) for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): if node.name == 'multiref' or isinstance(node, IRGraphAnchor): continue - # Currently cube only support replicate if node's input or input is part of the graph output - # workaround for now - # will fix later. - if any(output in graph_outputs for output in node.outputs()) \ - or any(input in graph_outputs for input in node.inputs()): - # or any(input in graph_inputs for input in node.inputs()): - _replica(graph, node, devs) - continue if isinstance(node, IRDimops): configs = node.transform_space() if len(configs) == 0: From c53d484313bf472fd399cff5710c4586e2d83e7b Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 20 Oct 2023 08:44:54 +0000 Subject: [PATCH 1503/1892] Merged PR 1861: Fix gradient alignment bug between IRSubTensor and IRFullTensor Previously after we change the `IRFullTensor.requires_grad = False`, if the `IRSubTensor` is already assigned with gradient tensor, it will not be removed and may raise error in `IRAdapterGen`. This PR fixes this issue by checking and assigning `None` if its parent requires_grad changes to False. --- cube/ir/tensor.py | 6 ++++++ tests/ir/tensor.py | 17 +++++++++++++++++ tests/parallel_module/test_gencode.py | 24 ++++++++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 tests/ir/tensor.py diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index b6de94be..b4bb6819 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -606,6 +606,12 @@ def requires_grad(self) -> bool: @property def grad(self) -> bool: + """Get the gradient of this tensor. + + The gradient is kept aligned with its parent IRFullTensor. + """ + if not self.requires_grad: + self._grad = None return self._grad @grad.setter diff --git a/tests/ir/tensor.py b/tests/ir/tensor.py new file mode 100644 index 00000000..4b9ba867 --- /dev/null +++ b/tests/ir/tensor.py @@ -0,0 +1,17 @@ +from cube.ir.tensor import IRSubTensor, IRFullTensor + + +def test_tensor_grad(): + + ftensor = IRFullTensor((128, 512), requires_grad=True) + subtensor = ftensor.tosub() + + assert isinstance(ftensor.grad, IRFullTensor) + subtensor.grad = ftensor.grad.tosub() + + assert isinstance(subtensor.grad, IRSubTensor) + + ftensor.requires_grad = False + assert ftensor.grad is None + assert subtensor.grad is None + assert subtensor.requires_grad is False diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 3f6b53ac..c5483969 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -44,3 +44,27 @@ def test_codegen(): m_new = _to_cube_model(m, ComputeConfig(2, 4), cube_savedir=tempdir, load_module=False) assert m_new is None launch_torchrun(1, _gencode_worker, tempdir) + + +class SliceModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:2] + +def test_codegen_slice(): + if not torch.cuda.is_available(): + print('skip test_codegen_slice due to lack of cuda devices') + return + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + SliceModule(), + {'x': torch.tensor([1.0, 2.0, 3.0, 6.0])}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False + ) + assert m_new is None From 5dadacc4ed3216bf8bc4446e0b48e403b9f20858 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 23 Oct 2023 07:33:01 +0000 Subject: [PATCH 1504/1892] Merged PR 1862: add no save tensor hook for to_ir_graph --- cube/graph/parser/converter.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 7c4aa6ac..6203c77d 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -20,6 +20,16 @@ _logger = logging.getLogger(__name__) +class no_save_tensor_hook(saved_tensors_hooks): + """skip saving tensors for backward since tracer only traces forward""" + def __init__(self): + def pack(x): + return None + def unpack(x): + raise RuntimeError("not expecting backward to be called on this tensor") + super().__init__(pack, unpack) + + def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: """ Convert torch.nn.Module based model into torch.fx.GraphModule @@ -41,15 +51,6 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: leaf_functions.update({func: ([(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs}) dce_ignored_funcs = set(cube_rt_funcs) - class no_save_tensor_hook(saved_tensors_hooks): - """skip saving tensors for backward since tracer only traces forward""" - def __init__(self): - def pack(x): - return None - def unpack(x): - raise RuntimeError("not expecting backward to be called on this tensor") - super().__init__(pack, unpack) - with no_save_tensor_hook(): traced_model = concrete_trace( model, @@ -84,12 +85,13 @@ def to_ir_graph( """ _logger.info(f"use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") - inputs, nodes, outputs = FxModuleParser.parse( - traced_model, dummy_input, - attr_savedir=attr_savedir, - dynamic_shape=dynamic_shape, - save_content=True, - ) + with no_save_tensor_hook(): + inputs, nodes, outputs = FxModuleParser.parse( + traced_model, dummy_input, + attr_savedir=attr_savedir, + dynamic_shape=dynamic_shape, + save_content=True, + ) module_name = traced_model.__class__.__name__ for input in inputs: From e1d969fd5590d0afda46de2eb2a165e81acd8a1a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 24 Oct 2023 05:54:22 +0000 Subject: [PATCH 1505/1892] Merged PR 1866: Cell-Input-Output-0: Refine parser to accept complex type --- cube/graph/parser/__init__.py | 2 +- cube/graph/parser/frame.py | 87 ++----- cube/graph/parser/fx/parser.py | 373 ++++++++++-------------------- tests/graph/parser/test_parser.py | 33 +++ 4 files changed, 184 insertions(+), 311 deletions(-) create mode 100644 tests/graph/parser/test_parser.py diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py index 21838a42..ba172d07 100644 --- a/cube/graph/parser/__init__.py +++ b/cube/graph/parser/__init__.py @@ -1,4 +1,4 @@ -from cube.graph.parser.fx.parser import FxModuleParser, FxFuncOpTracer +from cube.graph.parser.fx.parser import FxModuleParser from cube.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph from cube.graph.parser.register import register from cube.graph.parser.external import * diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index bfd43527..9e5217f6 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -1,5 +1,6 @@ from collections import OrderedDict -from typing import List, Any, Dict +from typing import List, Any, Dict, Tuple, Optional +from cube.ir.cten import IRTensor import torch @@ -12,10 +13,8 @@ def __init__(self): # var name -> value (IRTesnor, deterministic) self._vars: List[dict[str, Any]] = list() self._var_stack: List[str] = list() - # module attributes - self._attributes: List[dict[str, Any]] = list() - self._attr_vals: Dict[int, Any] = dict() # tensor tid to real value mapping - self._name_map: Dict[Any, Any] = dict() # tensor name to real tensor name + # IRTensor -> (module param name, concrete value) + self._attr_map: Dict[IRTensor, Tuple[str, torch.Tensor]] = dict() def push_var(self, inherit_from_top=False): """ @@ -104,79 +103,41 @@ def get_var(self, var_name: str) -> Any: return self._vars[-1][var_name] raise KeyError(f"Cannot find var name {var_name} in {self._vars}") - def push_attr(self): - """ - Push a new module attribut frame as current frame. - This should only be called when stepping in the graph. - """ - self._attributes.append(OrderedDict()) - - def pop_attr(self): - """ - Pop the current module attribute frame. - This should only be called when stepping out the graph. - """ - self._attributes.pop() + def add_attr(self, tensor: IRTensor, concrete_value: torch.Tensor, name: str): + """Add module attribute content - def add_attr(self, name: str, val: Any): - """ - Add module attribute + Args: + tensor (IRTensor): the tensor represents the value + value (torch.Tensor or Any): concrete value + name (str): attributed name of its original module """ - if name in self._attributes[-1]: - raise KeyError("Try to add an already existed attributed") - self._attributes[-1][name] = val + assert isinstance(concrete_value, torch.Tensor) + self._attr_map[tensor] = (name, concrete_value) - def get_attr(self, name: str) -> Any: - """ - Get module attribute by name - """ - if name not in self._attributes[-1]: - raise KeyError(f"Cannot find var name {name}") - return self._attributes[-1][name] + def get_attr_var(self, concrete_value: torch.Tensor) -> Optional[IRTensor]: + """Get IRTensor from attribute concrete value - def has_attr(self, name: str) -> bool: - """ - Return if `name` exists in current attributes + If the concrete value is not found, return None """ - return name in self._attributes[-1] - - def add_attr_content(self, tid: int, val: torch.Tensor): - """ - Add module attribute content - """ - if torch.is_tensor(val): - val = val.cpu() - self._attr_vals[tid] = val + assert isinstance(concrete_value, torch.Tensor) + for tensor, (_, value) in self._attr_map.items(): + if value is concrete_value: + return tensor + return None def save_attr_content(self, save_file: str = 'fullmodel.pt'): """ Save attribute content into file. """ - torch.save(self._attr_vals, save_file) - - def add_attr_map(self, key, value): - """ - Add names map to connect internal parameter name and original parameter - """ - self._name_map[str(key)] = value - - def has_attr_value(self, value): - return value in self._name_map.values() - - def get_attr_key(self, value): - ret = None - for key, val in self._name_map.items(): - if val == value: - ret = key - break - return ret + tid2value = {t.tid: val.cpu() for t, (_, val) in self._attr_map.items()} + torch.save(tid2value, save_file) def save_attr_map(self, save_file: str = 'dist_param_map.pt'): """ Save local_param -> origin_param name map. """ - torch.save(self._name_map, save_file) - + ir_name_to_orig_name = {str(t.name): name for t, (name, _) in self._attr_map.items()} + torch.save(ir_name_to_orig_name, save_file) def push_param(self, var_name): """ diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index bdb874a5..620b4fa6 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -1,8 +1,7 @@ import torch -import enum import logging from pathlib import Path -from typing import Any, List, Tuple, Callable, Union, Dict, Type +from typing import Any, List, Tuple, Callable, Union, Dict, Type, Optional from cube.ir.operator import IRFwOperation from cube.ir.tensor import IRFullTensor @@ -17,51 +16,6 @@ _logger = logging.getLogger(__name__) -class ErasedDevice: - pass - -class FxNodeKind(enum.Enum): - PrimGetAttr = 1 - PrimCallMethod = 2 - PrimCallFunction = 3 # -> the parser may end here - PrimConstant = 4 - AtenOp = 5 # -> the parser may end here - PrimIf = 6 # dynamic - PrimListConstruct = 7 - PrimListUnpack = 8 - PrimTupleUnpack = 9 - PrimPythonOp = 10 - PrimDevice = 11 # erased - PrimLoop = 12 - PrimCallModule = 13 - # for torch.fx - Placeholder = 14 - Output = 15 - - -class FxFuncOpTracer(torch.fx.Tracer): - def __init__(self, *args, customed_leaf_module=None, **kwargs): - super().__init__(*args, **kwargs) - self.customed_leaf_module = customed_leaf_module - - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: - if self.customed_leaf_module and isinstance(m, self.customed_leaf_module): - return True - # capture torch.nn.functional return - return m.__module__.startswith('torch.nn.functional') and not isinstance(m, torch.nn.Sequential) - - -def get_complex_data(val: Any, frame: Frame) -> Any: - """Change inner fx.Node into IRObject""" - if isinstance(val, tuple): - return tuple(get_complex_data(t, frame) for t in val) - if isinstance(val, list): - return list(get_complex_data(t, frame) for t in val) - if isinstance(val, torch.fx.Node): - return frame.get_var(val.name) - return val - - class FxModuleParser: """ torch.fx module parser @@ -70,125 +24,56 @@ class FxModuleParser: ATTR_CONTENT_FILE = 'fullmodel.pt' ATTR_MAP_FILE = 'dist_param_map.pt' - @staticmethod - def shape_refine(shape: torch.Size) -> torch.Size: - """Replacing scale shape [] to [1] - - Args: - shape (torch.Size): tensor shape - - Returns: - torch.Size: refined shape - """ - return torch.Size([1]) if shape == torch.Size([]) else shape - - @staticmethod def parse(module: torch.fx.GraphModule, dummy_inputs: Dict[str, Any], - frame: Frame = None, attr_savedir='./', *, save_content: bool = True, dynamic_shape: bool = True - ) -> Tuple[List[IRFullTensor], List[IRFwOperation], List[IRFullTensor]]: + ) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: """Parse torch.fx module into cube IR The overall entry to parse a torch.fx graph module Args: + module (torch.fx.GraphModule): the torch.fx module + dummy_inputs (Dict[str, Any]): the dummy inputs to run the module + attr_savedir (str): the directory to save the attribute content save_content (bool): whether to save the content of the module dynamic_shape (bool): whether to parse the module with dynamic shape + + Returns: + inputs (List[IRObject]): the input IRObjects + all_ir_nodes (List[IRFwOperation]): the IRFwOperation nodes + outputs (List[IRObject]): the output IRObjects """ from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata - frame = frame if frame is not None else Frame() + frame = Frame() frame.push_var() - frame.push_attr() + # shape propagation assert isinstance(dummy_inputs, dict), "Expected dummy inputs to parse module" + ShapeProp(module).propagate(dummy_inputs) - inputs = [node for node in module.graph.nodes if node.op == 'placeholder'] - _logger.info(f'> torch.fx parser: graph inputs: {inputs}') + # create IRObjects and IRTensors + for node in module.graph.nodes: + concrete_value = dummy_inputs.get(node.name) if node.op == 'placeholder' else None + FxModuleParser.init_objects(node, module, frame, concrete_value) - # shape propagation - ShapeProp(module).propagate(dummy_inputs) - # handle graph inputs - for idx, input in enumerate(inputs): - assert isinstance(input, torch.fx.Node) - # dealing with different types of dummy_inputs - if not isinstance(dummy_inputs, dict): - raise RuntimeError('dummy_inputs should be a dict.') - if input.name not in dummy_inputs: - val = IRObject(input.name) - else: - if 'tensor_meta' in input.meta and isinstance(input.meta['tensor_meta'], TensorMetadata): - shape = input.meta['tensor_meta'].shape - if len(shape) == 0: - shape = [1] - dtype = input.meta['tensor_meta'].dtype - val = IRFullTensor(shape=shape, requires_grad=False, dtype=dtype, name=input.name) - else: - val = IRObject(input.name, value=dummy_inputs[input.name]) - frame.add_var(input.name, val, graph_arg=idx) - - input_val = [frame.get_var(input.name) for input in inputs] - - # add activations to frame, including call_func/call_method output and final output - # call_module corresponds to leaf torch.nn.module - activation_op_strs = {'call_function', 'output', 'call_method', 'call_module'} - activation_nodes = [node for node in module.graph.nodes if node.op in activation_op_strs] - def parse_complex_out(meta_out): - if isinstance(meta_out, TensorMetadata): - shape = meta_out.shape - assert shape == torch.Size([]), f'{meta_out}' - return IRFullTensor(shape=shape, requires_grad=meta_out.requires_grad, dtype=meta_out.dtype) - elif isinstance(meta_out, dict): - ret = {} - for k, v in meta_out.items(): - ret[k] = parse_complex_out(v) - return ret - else: - return meta_out - for node in activation_nodes: - if hasattr(node, 'meta') and node.meta.get('tensor_meta'): - assert isinstance(node, torch.fx.Node) - if isinstance(node.meta['tensor_meta'], TensorMetadata): - meta_outs = (node.meta['tensor_meta'],) - else: - meta_outs = node.meta['tensor_meta'] - vals = list() - for meta_out in meta_outs: - if isinstance(meta_out, TensorMetadata): - shape = meta_out.shape - shape = FxModuleParser.shape_refine(shape) - dtype = meta_out.dtype - requires_grad = meta_out.requires_grad - val = IRFullTensor(shape=shape, requires_grad=requires_grad, dtype=dtype, name=node.name) - else: - val = IRObject(value=parse_complex_out(meta_out)) - vals.append(val) - if len(vals) == 1: - frame.add_var(node.name, vals[0]) - else: - frame.add_var(node.name, vals) - else: - frame.add_var(node.name, IRObject()) + # get graph inputs + inputs = [frame.get_var(n.name) for n in module.graph.nodes if n.op == 'placeholder'] - # handle nodes - all_ir_nodes: List[IRFwOperation] = list() - total_node_num = len(module.graph.nodes) - for nidx, node in enumerate(module.graph.nodes): - _logger.info(f'[{nidx}/{total_node_num}] parsing node {node}...') + # parse graph nodes + all_ir_nodes = [] + for node in module.graph.nodes: ir_nodes = FxModuleParser.parse_node(node, module, dynamic_shape, frame) - if ir_nodes is not None: - all_ir_nodes += ir_nodes - - # output_nodes = [node for node in module.graph.nodes if node.op == 'output'] - # assert len(output_nodes) == 1, f"get mutiple {len(all_ir_nodes)} output nodes" - # output_val = frame.get_var(output_nodes[0].name) - output_val = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + all_ir_nodes += ir_nodes + + # get graph outputs + outputs = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] if save_content: attr_savedir = Path(attr_savedir) @@ -196,51 +81,89 @@ def parse_complex_out(meta_out): frame.save_attr_map(attr_savedir / FxModuleParser.ATTR_MAP_FILE) frame.pop_var() - frame.pop_attr() - - return input_val, all_ir_nodes, output_val - + return inputs, all_ir_nodes, outputs @staticmethod - def ntype(node: torch.fx.Node): - if node.op == 'call_module': - return FxNodeKind.PrimCallModule - if node.op == 'call_function': - return FxNodeKind.PrimCallFunction - if node.op == 'get_attr': - return FxNodeKind.PrimGetAttr + def parse_node(node: torch.fx.Node, module, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: + """ + Parse the node and return the IRFwOperation nodes + """ if node.op == 'placeholder': - return FxNodeKind.Placeholder + return [] if node.op == 'output': - return FxNodeKind.Output - if node.op == 'call_method': - return FxNodeKind.PrimCallMethod - raise RuntimeError(f"Unknown node kind {node.kind()} from torchscript module") + return FxModuleParser.parse_prim_output_node(node, module, frame) + if node.op in ('call_function', 'call_method'): + return FxModuleParser.parse_prim_function_method(node, module, dynamic_shape, frame) + if node.op == 'get_attr': + return FxModuleParser.parse_prim_get_attr_node(node, module, frame) + if node.op == 'call_module': + return FxModuleParser.parse_prim_module(node, module, frame) + else: + raise TypeError(f"Unknown node kind {node.op}") @staticmethod - def parse_node(node: torch.fx.Node, module, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: - """ - Parse the node and return the IRFwOperation nodes + def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, + frame: Frame, concrete_value: Optional[Any] = None): + + from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata + assert isinstance(node, torch.fx.Node) + + def meta2var(meta: Any) -> Any: + """Support complex data type of List, Tuple, Dict, Tensor/Object""" + if isinstance(meta, TensorMetadata): + shape = meta.shape + # TODO: support scalar type + shape = torch.Size([1]) if shape == torch.Size([]) else shape + dtype = meta.dtype + requires_grad = meta.requires_grad + return IRFullTensor(shape=shape, name=node.name, + requires_grad=requires_grad, dtype=dtype) + if isinstance(meta, list): + return list(meta2var(item) for item in meta) + if isinstance(meta, tuple): + return tuple(meta2var(item) for item in meta) + if isinstance(meta, dict): + if not all(isinstance(key, str) for key in meta.keys()): + raise TypeError(f"only support dict type with str key, but got {meta.keys()}.\n{node}") + return {key : meta2var(value) for key, value in meta.items()} + # TODO: data type check, with cases like {'a': 1.2, 'b': torch.Tensor} + return meta + + if hasattr(node, 'meta') and node.meta.get('tensor_meta'): + meta = node.meta['tensor_meta'] + val = meta2var(meta) + else: + # FIXME: double check: there should be a concrete value as example, + # otherwise, it may fail in parsing node like getattr + val = IRObject(name=node.name, value=concrete_value) + + frame.add_var(node.name, val) + + @staticmethod + def parse_complex(val: Any, frame: Frame) -> Any: + """parse complex fx.Node into IRObject + + The val is usually from a node's input or output, can be fx.Node nested + by tuple/list/dict type, or a fx.Node itself. + + Args: + val (Any): fx.Node nested by tuple/list/dict + frame (Frame): the frame to get the fx.Node + + Returns: + the copied strcuture where the fx.Node is replaced by IRObjects/IRTensors """ - node_type = FxModuleParser.ntype(node) - try: - if node_type == FxNodeKind.Placeholder: - return [] - if node_type == FxNodeKind.Output: - return FxModuleParser.parse_prim_output_node(node, module, frame) - if node_type in (FxNodeKind.PrimCallFunction, FxNodeKind.PrimCallMethod): - return FxModuleParser.parse_prim_function_method(node, module, dynamic_shape, frame) - if node_type == FxNodeKind.PrimGetAttr: - return FxModuleParser.parse_prim_attr_node(node, module, frame) - if node_type == FxNodeKind.PrimCallModule: - return FxModuleParser.parse_prim_module(node, module, frame) - - # TODO bother assigning all ignored prim functions new NodeKinds? - if node_type == FxNodeKind.PrimDevice: - return FxModuleParser.parse_value_erased_node(node, module, frame, [ErasedDevice()]) - raise NotImplementedError(f"Un-supported node type {node_type}") - except Exception: - raise RuntimeError(f"\n\nParsing error at node:\n\t{node}\n") + # to support more nested types, we can refer to the implementation of + # https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py + if isinstance(val, tuple): + return tuple(FxModuleParser.parse_complex(t, frame) for t in val) + if isinstance(val, list): + return list(FxModuleParser.parse_complex(t, frame) for t in val) + if isinstance(val, dict): + return {key: FxModuleParser.parse_complex(val, frame) for key, val in val.items()} + if isinstance(val, torch.fx.Node): + return frame.get_var(val.name) + return val @staticmethod def fetch_attr(mod: torch.fx.GraphModule, target: str): @@ -264,16 +187,10 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: # get signature fsig = FxModuleParser._get_qualified_name(node.target, node) - # get inputs - input_vals = [get_complex_data(val, frame) for val in node.args] - kwargs = {key: get_complex_data(val, frame) for key, val in node.kwargs.items()} + input_vals = FxModuleParser.parse_complex(list(node.args), frame) + kwargs = FxModuleParser.parse_complex(node.kwargs, frame) - return FxModuleParser._parse_node(fsig, node, input_vals, kwargs, dynamic_shape, frame) - - @staticmethod - def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: - # map to IR operator if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) else: @@ -326,77 +243,39 @@ def _parse_node(fsig: str, node: torch.fx.Node, input_vals: list, kwargs: dict, return ir_nodes @staticmethod - def parse_prim_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: """ There are two types of get_attr, one is `FxNodeKind.PrimGetAttr` which is dealt with in this function. The other is `FxNodeKind.PrimCallFunction ` (i.e., ) which is dealt with by parse_prim_function_method. """ - assert node is not None - tensor_name = node.name - if 'tensor_meta' in node.meta: - tensor_shape = node.meta['tensor_meta'].shape - dtype = node.meta['tensor_meta'].dtype - requires_grad = node.meta['tensor_meta'].requires_grad - - # check if existing param - if requires_grad and frame.has_attr_value(node.target): # existing param - prev_tensor_name = frame.get_attr_key(node.target) - frame.add_var(tensor_name, frame.get_var(prev_tensor_name)) - else: # new param / activation - ir_tensor = IRFullTensor(tensor_shape, tensor_name, requires_grad=requires_grad, dtype=dtype) - if requires_grad: # case for registered parameters - ir_tensor.as_param() - else: # case for registered buffers - ir_tensor.as_buffer() - frame.add_var(tensor_name, ir_tensor) - value = FxModuleParser.fetch_attr(module, node.target) - frame.add_attr_content(ir_tensor.tid, value) - frame.add_attr_map(ir_tensor.name, node.target) + concrete_value = FxModuleParser.fetch_attr(module, node.target) + if isinstance(concrete_value, torch.Tensor): + assert isinstance(concrete_value, torch.Tensor), \ + f"GetAttrPrim: expect tensor but got {type(concrete_value)}" + exist_tensor = frame.get_attr_var(concrete_value) + # the cath that the parameter is the first time used by getattr + if not exist_tensor: + tensor = frame.get_var(node.name) + if tensor.requires_grad: + tensor.as_param() + else: + tensor.as_buffer() + frame.add_attr(tensor, concrete_value, node.target) + # the case that the parameter is consumed multiple times and regisetered previously + else: + frame.set_var(node.name, exist_tensor) else: - var = FxModuleParser.fetch_attr(module, node.target) - frame.add_var(tensor_name, var) - - return None + assert not isinstance(concrete_value, torch.Tensor), f"GetAttrPrim: unexpected parameter" + frame.set_var(node.name, concrete_value) + return [] @staticmethod def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 - ir_nodes = [] - - def generate_outputs(val: Any) -> Any: - """Support complex data type of List, Tuple, Dict, Tensor/Object""" - if isinstance(val, list): - return list(generate_outputs(item) for item in val) - if isinstance(val, tuple): - return tuple(generate_outputs(item) for item in val) - if isinstance(val, dict): - return {generate_outputs(key) : generate_outputs(value) for key, value in val.items()} - if isinstance(val, torch.fx.Node): - return frame.get_var(val.name) - # for other types like int, float, ... - return val - output = generate_outputs(node.args[0]) - + output = FxModuleParser.parse_complex(node.args[0], frame) frame.set_var(node.name, output) - return ir_nodes - - # # NOTE: this is a function in torch.fx - # @staticmethod - # def _get_qualified_name(func: Callable[..., Any]) -> str: - # # things like getattr just appear in builtins - # if getattr(builtins, func.__name__, None) is func: - # return func.__name__ - # # torch.Tensor.{fn} - # if isinstance(func, types.MethodDescriptorType) and func is getattr(torch.Tensor, func.__name__, None): - # return f"torch.Tensor.{func.__name__}" - # name = func.__name__ - # module = FxModuleParser._find_module_of_method(func) - # module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module - # # Fixup segment_reduce mismatch - # if module == "torch" and name == "segment_reduce": - # name = "_" + name - # return f'{module}.{name}' + return [] @staticmethod def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py new file mode 100644 index 00000000..33f5df19 --- /dev/null +++ b/tests/graph/parser/test_parser.py @@ -0,0 +1,33 @@ +import tempfile +import torch +from cube.graph.parser.converter import to_fx_graph, to_ir_graph + + +def test_multi_consume(): + + class MyModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.param1 = torch.nn.Parameter(torch.empty(4, 4)) + self.param2 = torch.nn.Parameter(torch.empty(4, 4)) + + def forward(self, x): + shortcut = x + x = torch.matmul(x, self.param1) + x = x + self.param2 + x = x + shortcut + x = x + self.param1 + return torch.sum(x) + + dummy_input = {'x': torch.randn(4, 4)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + assert ir_graph is not None + assert len(ir_graph.attributes()) == 2 # param1 and param2 + assert len(ir_graph.full_tensors()) == 8 From 65984a4bdb8cfbe7bdf00743a72976416e32a159 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 24 Oct 2023 06:55:30 +0000 Subject: [PATCH 1506/1892] Merged PR 1867: parallel module: fix forward function signature handle the case when some of the forward arguments are not used in tracing. parity-check has passed. --- cube/codegen/module/module.py | 32 ++++++++-- .../concrete_trace_utils/concrete_tracer.py | 1 + cube/parallel.py | 7 ++- tests/conftest.py | 18 ++++++ tests/parallel_module/common.py | 1 + tests/parallel_module/test_gencode.py | 62 +++++++++++++++++++ 6 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 tests/conftest.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 2f638a30..71ed7063 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -393,17 +393,39 @@ def gen( # will use the orignal names of inputs inputs = [t.name for t in node.inputs() if not isinstance(t, IRSubTensor) or not t.is_attr()] # ensure forward args are valid + unused_args = [] if forward_arg_names: for i in range(len(inputs)): - if inputs[i] != forward_arg_names[i]: - raise ValueError(f"Forward args mismatch: {inputs[i]} != {forward_arg_names[i]}") - for i in range(len(inputs), len(forward_arg_names)): + if inputs[i] not in forward_arg_names: + raise ValueError(f"Forward args mismatch: {inputs[i]} arg needed") + + forward_args = [] + # find the first mismatch + for i in range(len(inputs)): + if inputs[i] == forward_arg_names[i]: + forward_args.append(inputs[i]) + else: + break + + for i in range(len(forward_args), len(forward_arg_names)): if not forward_arg_names[i].startswith('*'): - raise ValueError(f"Invalid extra forward args: only *args & **kwargs are allowed") + forward_args.append(f'{forward_arg_names[i]}=None') + if forward_arg_names[i] not in inputs: + unused_args.append(forward_arg_names[i]) + _logger.warning(f'Unused forward argument `{forward_arg_names[i]}`.' + f'The argument value will be ignored when you call module forward') + else: + forward_args.append(forward_arg_names[i]) + forward_arg_names = forward_args + else: + forward_arg_names = inputs - with FunctionBlock(func_name='_forward_impl', args=['self'] + (forward_arg_names or inputs)) as fb: + with FunctionBlock(func_name='_forward_impl', args=['self'] + forward_arg_names) as fb: outputs = self.return_name(node.outputs(), skip_attr=True) call_code = f'{outputs} = self.{self.node_name(node)}({", ".join(inputs)})' + # be sure the user doesn't specify unused args. + for unused_arg in unused_args: + fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') fb.insert_body(call_code) return_code = f'return {self.return_name_complex(self.execplan.graph.outputs())}' fb.insert_body(return_code) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 4d8f7159..6528061e 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -607,6 +607,7 @@ def create_args_for_root(self, root_fn, is_module, concrete_args: Union[Dict[str # defined via ``functools.wraps``. In this case, the outer code object # will likely not contain the actual parameters we care about, so unwrap # the function to get to the innermost callable. + # TODO: keyward-only arguments are not supported now fn_for_analysis = inspect.unwrap(root_fn) default_value_list = fn_for_analysis.__defaults__ if default_value_list is None: diff --git a/cube/parallel.py b/cube/parallel.py index 0181f7ab..66cfd971 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -167,7 +167,7 @@ def _gencode( # False | imported | doesn't matter if not override: # check if the module is already generated - expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.plan_ngpus)] + expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] expected_output_files.append(outdir / FxModuleParser.ATTR_CONTENT_FILE) expected_output_files.append(outdir / FxModuleParser.ATTR_MAP_FILE) expected_output_files.append(outdir / ParallelModule.COMPUTE_CONFIG_FILE) @@ -177,6 +177,11 @@ def _gencode( and len(existing_output_files) == len(expected_output_files) \ and torch.load(outdir / ParallelModule.COMPUTE_CONFIG_FILE) == compute_config: return + elif all(f.suffix != '.py' for f in existing_output_files): + # No python source code is generated. + # which means its last generation failed. + # in this case, we can reuse the same directory safely. + pass else: raise RuntimeError(f'Output directory {outdir} is not empty. ' f'And the existing files do not match with current config.') diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..1340efc1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,18 @@ +import pytest +from pathlib import Path + +from cube.graph.parser.fx.parser import FxModuleParser + +@pytest.fixture(autouse=True) +def clean_generated_files(): + print('hello') + yield + # try to clean generated files after each test run. + basedir = Path('./').resolve() + generated_files = [FxModuleParser.ATTR_CONTENT_FILE, FxModuleParser.ATTR_MAP_FILE] + for f in generated_files: + f = basedir / f + if f.exists(): + f.unlink() + for f in basedir.glob('gencode*.py'): + f.unlink() diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 407955ca..c5492eb2 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -165,5 +165,6 @@ def clear_dir_on_rank0(tempdir): if torch.distributed.get_rank() == 0 and tempdir.exists(): shutil.rmtree(tempdir) yield tempdir + torch.distributed.barrier() if torch.distributed.get_rank() == 0 and tempdir.exists(): shutil.rmtree(tempdir) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index c5483969..0b22ed64 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -1,3 +1,4 @@ +import inspect import tempfile import torch @@ -54,6 +55,9 @@ def forward(self, x): return x[:2] def test_codegen_slice(): + """ + Test it can support modules without parameters + """ if not torch.cuda.is_available(): print('skip test_codegen_slice due to lack of cuda devices') return @@ -68,3 +72,61 @@ def test_codegen_slice(): load_module=False ) assert m_new is None + + +class UnusedArgsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, y, z=None, m=None, n=None, **kwargs): + return self.linear(x) + m + + +def _gencode_unused_args_worker(tempdir): + init_distributed() + m_new = parallelize( + UnusedArgsModule(), + { + 'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + 'y': torch.tensor([1, 2, 3]), + 'z': None, + 'm': 0, + 'n': None + }, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=True + ) + assert m_new is not None + args = inspect.signature(m_new._forward_impl) + assert len(args.parameters) == 6 + assert args.parameters['x'].default is inspect.Parameter.empty + assert args.parameters['y'].default is None + assert args.parameters['z'].default is None + assert args.parameters['m'].default is None + assert args.parameters['n'].default is None + + with pytest.raises(TypeError): + # m can't be None + # TypeError is raised by torch.add + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) + + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), m=1) + + with pytest.raises(ValueError): + # y must be None + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) + +def test_codegen_unused_args(): + """ + Verify that unused args are supported by parallalize + """ + if not torch.cuda.is_available(): + print('skip test_unused_input due to lack of cuda devices') + return + + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, _gencode_unused_args_worker, tempdir) From 639f1aba628748facc6907c782e0479bc9a67df6 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 24 Oct 2023 09:35:54 +0000 Subject: [PATCH 1507/1892] Merged PR 1848: support some functions for new models --- cube/graph/function/function.py | 210 +++++++++++++----- .../concrete_trace_utils/operator_patcher.py | 4 +- cube/graph/parser/fx/mapping.py | 9 +- cube/runtime/function/function.py | 14 +- tests/graph/function/test_functions.py | 78 +++++++ 5 files changed, 258 insertions(+), 57 deletions(-) create mode 100644 tests/graph/function/test_functions.py diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 7a17ae9c..c09d1d7d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -251,7 +251,8 @@ def Full(size, fill_value, *, out=None, dtype=None, layout=None, assert layout in (None, torch.strided), f"Not support for non-default layout" dtype = dtype if dtype is not None else torch.get_default_dtype() signature = 'cube.runtime.function.full' - size = tuple(size) + # cube treat scalar as size (1,) tensor now, scalar support will in another pr if necessary + size = tuple(size) if size else (1,) anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) dimop = IRDimops(Full, 'full', signature, [anno], [], rules, @@ -275,66 +276,100 @@ def NewTensor(data, *, dtype=None, device=None, def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: - """! - Create shape annotations for element wise operator following broadcastable rules: + """Create shape annotations for element wise operator following broadcastable rules: https://pytorch.org/docs/stable/notes/broadcasting.html - @param lhs IRTensor: the lhs input tensor - @param rhs IRTensor: the rhs input tensor + Args: + lhs IRTensor: the lhs input tensor + rhs IRTensor: the rhs input tensor - @return lhs_shape, rhs_shape, out_shape: the lhs, rhs and output shape annotation + Returns: + lhs_anno List[str]: lhs shape annotation + rhs_anno List[str]: rhs shape annotation + out_anno List[str]: output shape annotation """ - lndims, rndims = len(lhs.shape), len(rhs.shape) - # init lhs_shape and rhs_shape annotation string - shape_anno = ShapeAnno.create_shape_str(lhs.shape if lndims > rndims else rhs.shape) - lhs_shape = shape_anno[0-lndims:] - rhs_shape = shape_anno[0-rndims:] + ins_anno, out_anno = _handle_broadcast_multi([lhs, rhs]) + assert len(ins_anno) == 2 + return ins_anno[0], ins_anno[1], out_anno + + +def _handle_broadcast_multi(ins_list: List[IRTensor]) -> Tuple[Tuple[List[str]], List[str]]: + """Similar to ``_handle_broadcast``, handle broadcast for more than two input tensors. + + Create shape annotations for element wise operator following broadcastable rules: + https://pytorch.org/docs/stable/notes/broadcasting.html + + Args: + ins_list List[IRTensor]: the list of input tensors + + Returns: + ins_anno (Tuple[List[str]]): a list of input tensors annotation + out_anno (List[str]): output shape annotation + """ + assert len(ins_list) >= 2, 'at least two tensor require for broadcast' + ins_ndims = [len(inp.shape) for inp in ins_list] + # init annotation string + maxlen_shape = ins_list[ins_ndims.index(max(ins_ndims))].shape + shape_anno = ShapeAnno.create_shape_str(maxlen_shape) + ins_anno = [shape_anno[-ndims:] for ndims in ins_ndims] # expand dimensions for empty dimensions - lofst = max(lndims, rndims) - lndims - lshape = [1] * lofst + list(lhs.shape) - rofst = max(lndims, rndims) - rndims - rshape = [1] * rofst + list(rhs.shape) + ins_ofst = [max(ins_ndims) - ndims for ndims in ins_ndims] + ins_shape = [[1] * ins_ofst[idx] + list(inp.shape) for idx, inp in enumerate(ins_list)] # init out_shape - out_shape = [] - for dim in range(len(lshape)): - ldim_anno = None if dim - lofst < 0 else lhs_shape[dim-lofst] - rdim_anno = None if dim - rofst < 0 else rhs_shape[dim-rofst] - if lshape[dim] == rshape[dim]: - assert rdim_anno is not None or ldim_anno is not None - out_shape.append(rdim_anno if rdim_anno is not None else ldim_anno) - elif lshape[dim] == 1: - assert rdim_anno is not None - out_shape.append(rdim_anno) - if ldim_anno is not None: - lhs_shape[dim-lofst] = '1' - elif rshape[dim] == 1: - assert ldim_anno is not None - out_shape.append(ldim_anno) - if rdim_anno is not None: - rhs_shape[dim-rofst] = '1' + out_anno = [] + for dim in range(len(maxlen_shape)): + dim_annos = [None if dim - ins_ofst[idx] < 0 else anno[dim-ins_ofst[idx]] for idx, anno in enumerate(ins_anno)] + not_none_annos = [anno for anno in dim_annos if anno is not None] + assert len(not_none_annos) > 0 + not_none_anno = not_none_annos[0] + if all(x[dim] == ins_shape[0][dim] for x in ins_shape): + out_anno.append(not_none_anno) else: - raise ValueError(f"cannot broadcast lhs: {lhs.shape} and rhs: {rhs.shape}") - return lhs_shape, rhs_shape, out_shape + not_one_shape = [shape[dim] for shape in ins_shape if shape[dim] != 1] + if len(not_one_shape) > 0 and not all(s == not_one_shape[0] for s in not_one_shape): + raise ValueError(f"cannot broadcast tensor list: {ins_list}") + else: + out_anno.append(not_none_anno) + for idx, anno in enumerate(dim_annos): + if anno is not None and ins_shape[idx][dim] == 1: + ins_anno[idx][dim-ins_ofst[idx]] = '1' + return ins_anno, out_anno -def Expand(input, *sizes, signature = None): +def Expand(input, *sizes, size = None, signature = None): """ torch.Tensor.expand(*sizes) - """ - signature = 'cube.runtime.function.expand' + + The reason of add ``size`` to this function argument is: + 1. ``sizes`` need to reuse in IRDimops.new(), but it is a ``non-keyword arguments``, + and can not put it into keyword arguments (something like Expand(input, sizes=[1, 2, 3])) is not work, + to support IRDimops.new API, here add a ``size`` to workaround. + + 2. in torch._C.expand API, it has: + def expand(self, size: Sequence[Union[_int, SymInt]], *, implicit: _bool=False) -> Tensor: ... + so add ``size`` can also solve user using something like: + torch.rand(3, 1).expand(size=(3, 3)) + """ + signature = 'torch.Tensor.expand' + if size is not None: + assert len(sizes) == 0 + sizes = size ori_len, exp_len = len(input.shape), len(sizes) assert ori_len <= exp_len assert all(dim == expand_dim or dim == 1 or expand_dim == -1 for dim, expand_dim in zip(input.shape, sizes[-ori_len:])) edim_ou = ShapeAnno.create_shape_str(sizes) edim_in = copy.copy(edim_ou[-ori_len:]) + new_size = [-1] * len(sizes) for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes[-len(input.shape):])): if dim == 1 and dim != expand_dim and expand_dim != -1: edim_in[idx] += '^' edim_ou[exp_len - ori_len + idx] = str(expand_dim) + new_size[exp_len - ori_len + idx] = expand_dim for idx in range(exp_len - ori_len): edim_ou[idx] = str(sizes[idx]) + new_size[idx] = sizes[idx] anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Expand, 'expand', signature, [anno], [input], sizes=sizes) + return IRDimops(Expand, 'expand', signature, [anno], [input], size=new_size) def ExpandAs(input, other, signature = None): @@ -703,6 +738,28 @@ def MaskedFill(input, mask, value, signature = None): return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input, mask], value=value) +def Where(condition, input, other, *, out=None, signature = None): + """ + torch.where + """ + assert isinstance(condition, IRTensor) + if isinstance(input, IRTensor) and isinstance(other, IRTensor): + (edim_in0, edim_in1, edim_in2), edim_out = _handle_broadcast_multi([condition, input, other]) + elif isinstance(input, IRTensor) and len(input.shape) > 0 and not (len(input.shape) == 1 and input.shape[0] == 1): + edim_in0, edim_in1, edim_out = _handle_broadcast(condition, input) + edim_in2 = ['?'] + elif isinstance(other, IRTensor) and len(other.shape) > 0 and not (len(other.shape) == 1 and other.shape[0] == 1): + edim_in0, edim_in2, edim_out = _handle_broadcast(condition, other) + edim_in1 = ['?'] + else: + edim_in0 = ShapeAnno.create_shape_str(condition.shape) + edim_in1, edim_in2 = ['?'], ['?'] + edim_out = copy.copy(edim_in0) + + annos = [OpAnno.create_op_str([edim_in0, edim_in1, edim_in2], [edim_out])] + dimop = IRDimops(Where, 'where', signature, annos, [condition, input, other]) + return dimop + def CubeLayerNorm(input, weight=None, bias=None, normalized_shape=None, eps=1e-05, signature = None): """ cube.runtime.function.layer_norm(input, weight, bias, normliazed_shape, eps) @@ -1109,12 +1166,12 @@ def Squeeze(input, dim=None, signature = None): """ out = torch.squeeze(tensor) """ - assert dim is None, "got dim: {dim} != None, which is not supported" + dim = (dim,) if isinstance(dim, int) else dim edim_in = ShapeAnno.create_shape_str(input.shape) assert len(edim_in) == len(input.shape) edim_ou = [] - for dim_anno, dim_size in zip(edim_in, input.shape): - if dim_size > 1: + for idx, (dim_anno, dim_size) in enumerate(zip(edim_in, input.shape)): + if dim_size > 1 or (dim is not None and idx not in dim): edim_ou.append(copy.copy(dim_anno)) anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(Squeeze, 'squeeze', signature, [anno], [input]) @@ -1379,6 +1436,21 @@ def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice, int]], signatu """ signature = 'cube.runtime.function.fullslice' slicers = tuple(slicers) + + # deal with ... in slice + if any(slicer is Ellipsis for slicer in slicers): + front_slicers, back_slicers, ellipsis_flag = [], [], False + for slicer in slicers: + if not slicer is Ellipsis: + front_slicers.append(slicer) if not ellipsis_flag else back_slicers.append(slicer) + else: + ellipsis_flag = True + front_count = len([slicer for slicer in front_slicers if slicer is not None]) + back_count = len([slicer for slicer in back_slicers if slicer is not None]) + assert front_count + back_count <= len(tensor.shape) + mid_slicers = [slice(None, None, None) for _ in range(len(tensor.shape) - front_count - back_count)] + slicers = tuple(front_slicers + mid_slicers + back_slicers) + edim_in = ShapeAnno.create_shape_str(tensor.shape) edim_ou = [] in_idx = 0 @@ -1397,8 +1469,9 @@ def obj_helper(obj): if slicer != slice(None, None, None): edim_in[in_idx] += '^' _start, _stop, _step = obj_helper(slicer.start), obj_helper(slicer.stop), obj_helper(slicer.step) - start = 0 if _start is None else _start - stop = tensor.shape[in_idx] if _stop is None else _stop + start = 0 if _start is None else _start + tensor.shape[in_idx] if _start < 0 else _start + stop = tensor.shape[in_idx] if _stop is None else _stop + tensor.shape[in_idx] if _stop < 0 else _stop + start, stop = min(start, tensor.shape[in_idx]), min(stop, tensor.shape[in_idx]) step = 1 if _step is None else _step dimlen = len(range(start, stop, step)) if dimlen == tensor.shape[in_idx]: @@ -1563,13 +1636,18 @@ def CrossEntropy(input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0) """ - # FIXME: reduction is by default 'mean', in this way it cannot be partitioned - # no N dimension. - annos = [ - 'C^, N -> 1', - 'N+ C, N+ -> 1', - 'N+ C *, N+ * -> 1' - ] + if reduction == 'sum': + annos = [ + 'C^, N -> 1', + 'N+ C^, N+ -> 1', + 'N+ C^ *, N+ * -> 1' + ] + else: + annos = [ + 'C^, N -> 1', + 'N^ C^, N^ -> 1', + 'N^ C^ *, N^ * -> 1' + ] return IRDimops( CrossEntropy, 'cross_entropy', signature, annos, [input, target], @@ -1658,6 +1736,34 @@ def CompareNE(input, other, *, out=None, signature = None): return _comparison(CompareNE, operator.eq, 'ne', signature, input, other) +def Max(input, other_or_dim=None, out_or_keepdim=None, *, out=None, signature = None): + """ + torch.max(input) + torch.max(input, dim, keepdim=False, *, out=None) + torch.max(input, other, *, out=None) + """ + signature = 'cube.runtime.function.max_' + if other_or_dim is None: + edim_in = [s + '^' for s in ShapeAnno.create_shape_str(input.shape)] + annos = [OpAnno.create_op_str([edim_in], ['1'])] + return IRDimops(Max, 'max', signature, annos, [input]) + elif isinstance(other_or_dim, IRTensor): + lshape, rshape, oshape = _handle_broadcast(input, other_or_dim) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(Max, 'max', signature, annos, [input, other_or_dim]) + else: + assert isinstance(other_or_dim, int) and isinstance(out_or_keepdim, bool) + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_in[other_or_dim] += '^' + edim_out = copy.copy(edim_in) + if out_or_keepdim: + edim_out[other_or_dim] = '1' + else: + edim_out.pop(other_or_dim) + annos = [OpAnno.create_op_str([edim_in], [edim_out, edim_out])] + return IRDimops(Max, 'max', signature, annos, [input], other_or_dim=other_or_dim, out_or_keepdim=out_or_keepdim) + + def ShapeAsTensor(input: IRTensor, signature = None): """ torch._shape_as_tensor @@ -1697,12 +1803,14 @@ def Dim(tensor, signature=None) -> Union[List[int], IRPyFunc]: return len(tensor.shape) -def To(tensor: IRTensor, dtype_or_device, *, out=None, signature = None): +def To(tensor: IRTensor, dtype_or_device=None, *, device=None, dtype=None, out=None, signature = None): """ torch.Tensor.to(*args, **kwargs) → Tensor """ assert out is None # FIXME: support full version of torch.Tensor.to + dtype_or_device = dtype if dtype is not None else dtype_or_device + dtype_or_device = device if dtype_or_device is None else dtype_or_device # create "to" in cube runtime functions because dtype if not kwarg in torch.Tensor.to signature = 'cube.runtime.function.to' annos = ['* -> *'] diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index 22405b7e..f7a95a5b 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -227,10 +227,12 @@ def patch_inner_helper(self, func): ] body0.name = 'new_func' # for deleting some annotations like 'add_start_docstrings_to_model_forward' or 'add_code_sample_docstrings' + # these decorators are used for tranformers model docstrings generation, can be removed in trace + transform_useless_decorators = ('add_start_docstrings_to_model_forward', 'add_code_sample_docstrings', 'replace_return_docstrings') body0.decorator_list = [i for i in body0.decorator_list if isinstance(i, ast.Call) and isinstance(i.func, ast.Name) and i.func.id == 'patch_run' and isinstance(i.args[0], ast.Name) and - i.args[0].id not in ('add_start_docstrings_to_model_forward', 'add_code_sample_docstrings')] + i.args[0].id not in transform_useless_decorators] ast.fix_missing_locations(new_tree) # closure info diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index d3b6e455..5fd31dc2 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -74,6 +74,9 @@ def exist(signature: str) -> bool: __ttemplate('eq') : function.CompareEQ, '_operator.eq': function.CompareEQ, __ttemplate('ne') : function.CompareNE, + '_operator.ne': function.CompareNE, + __ttemplate('max'): function.Max, + __ttemplate('where'): function.Where, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, __tttemplate('int'): function.Int, @@ -161,7 +164,9 @@ def exist(signature: str) -> bool: __ttemplate('gt'): function.CompareGT, '_operator.gt': function.CompareGT, __ttemplate('lt'): function.CompareLT, + '_operator.lt': function.CompareLT, __ttemplate('ge'): function.CompareGE, + '_operator.ge': function.CompareGE, __ttemplate('le'): function.CompareLE, '_operator.le': function.CompareLE, # @@ -205,5 +210,7 @@ def exist(signature: str) -> bool: # #einops # __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, - 'torch.functional.split': function.Split + 'torch.functional.split': function.Split, + __ttemplate('split'): function.Split, + __tttemplate('split'): function.Split, } diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 9bd01170..ac275501 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -28,6 +28,16 @@ def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) return tensor.to(dtype_or_device) +def max_(input: torch.Tensor, other_or_dim: Union[torch.Tensor, int, None]=None, out_or_keepdim: Optional[bool]=None) -> torch.Tensor: + if other_or_dim is None: + return torch.max(input) + elif isinstance(other_or_dim, int): + return torch.max(input, other_or_dim, out_or_keepdim) + else: + assert isinstance(other_or_dim, torch.Tensor) + return torch.max(input, other_or_dim) + + def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: """ accumulate tensors in to one tensor @@ -38,10 +48,6 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: return torch.sum(torch.stack(tensors, dim=0), dim=0) -def expand(input: torch.Tensor, sizes: Union[torch.Size, List[int]]) -> torch.Tensor: - return input.expand(*sizes) - - def fullslice(input: torch.Tensor, slicers: Tuple[Union[None, slice, int]]): """Slice tensors diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py new file mode 100644 index 00000000..01a8b8e5 --- /dev/null +++ b/tests/graph/function/test_functions.py @@ -0,0 +1,78 @@ +### Only test the anno creation in these tests + +import cube.graph.function.function as F +from cube.ir.cten import IRTensor + + +def test_handle_broadcast_multi(): + ins_anno, out_anno = F._handle_broadcast_multi([IRTensor([4]), IRTensor([3, 4]), IRTensor([2, 3, 4])]) + assert ins_anno[0] == ['c'] + assert ins_anno[1] == ['b', 'c'] + assert ins_anno[2] == ['a', 'b', 'c'] + assert out_anno == ['a', 'b', 'c'] + + ins_anno, out_anno = F._handle_broadcast_multi([IRTensor([1]), IRTensor([2, 1, 4]), IRTensor([2, 3, 4])]) + assert ins_anno[0] == ['1'] + assert ins_anno[1] == ['a', '1', 'c'] + assert ins_anno[2] == ['a', 'b', 'c'] + assert out_anno == ['a', 'b', 'c'] + +def test_Full(): + op = F.Full([1, 2, 3], 1.) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 1 2 3' + + op = F.Full([], 1.) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 1' + +def test_Expand(): + inp = IRTensor([10, 1]) + out = IRTensor([10, 2]) + op = F.Expand(inp, 10, 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ -> a 2' + + op.new([inp], [out], size=[10, 2]) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ -> a 2' + +def test_Where(): + op = F.Where(IRTensor([3, 4]), IRTensor([3, 4]), IRTensor([3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, a b, a b -> a b' + op = F.Where(IRTensor([3, 4]), IRTensor([4]), IRTensor([3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, b, a b -> a b' + op = F.Where(IRTensor([3, 4]), IRTensor([1]), IRTensor([3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, 1, a b -> a b' + op = F.Where(IRTensor([3, 4]), 1, IRTensor([3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, ?, a b -> a b' + op = F.Where(IRTensor([3, 4]), IRTensor([3, 4]), IRTensor([4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, a b, b -> a b' + op = F.Where(IRTensor([3, 4]), IRTensor([3, 4]), IRTensor([1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, a b, 1 -> a b' + op = F.Where(IRTensor([3, 4]), IRTensor([3, 4]), 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, a b, ? -> a b' + op = F.Where(IRTensor([3, 4]), 1, 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, ?, ? -> a b' + +def test_FullSlice(): + op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, 3)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' + op = F.FullSlice(IRTensor([2, 3, 4]), (..., 2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b' + op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, slice(0, 3, 2))) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 2' + op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, slice(1, 10, 1))) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 3' + +def test_Max(): + op = F.Max(IRTensor([2, 3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' + op = F.Max(IRTensor([2, 3, 4]), IRTensor([4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c, c -> a b c' + op = F.Max(IRTensor([2, 3, 4]), 1, True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a 1 c, a 1 c' + op = F.Max(IRTensor([2, 3, 4]), 1, False) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a c, a c' + +def test_Squeeze(): + op = F.Squeeze(IRTensor([2, 1, 4, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a c' + op = F.Squeeze(IRTensor([2, 1, 4, 1]), 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a c d' From a1a2b7eee3dab70b361d9fdb915de4faad8a1ff9 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 25 Oct 2023 06:00:09 +0000 Subject: [PATCH 1508/1892] Merged PR 1869: parallel module: refine forward generation make the whole design more clear. --- cube/codegen/module/module.py | 194 +++++++++++++++++++------- cube/parallel.py | 16 ++- tests/parallel_module/test_gencode.py | 102 ++++++++++++-- 3 files changed, 253 insertions(+), 59 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 71ed7063..98726a0d 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -1,9 +1,10 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Any import more_itertools import logging import copy import torch import numpy as np +import inspect from cube.ir.cten import IRCell from cube.ir.tensor import IRSubTensor @@ -258,15 +259,85 @@ def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: def gen( self, device: int, - outfile=None, - attach=False, + outfile: str = None, + attach: bool = False, *, - as_parallel_module=False, - forward_arg_names=None + as_parallel_module: bool = False, + forward_args: Optional[Dict[str, Any]] = None ) -> str: """ Generate model implementation code based on the given graph. - + if as_parallel_module is True, we will create a forward method for the module. + The arguments of the forward method will be same with original forward method with some exceptions: + 1. No positional only argument/keyword only argument support. + For example + ```python + def forward(self, x, y, /, z=None, *, m=1, n=2): + ... + ``` + the bevaior of the forward method will be undefined, and should be avoided. + Also the generated forward method will not have positional only argument/keyword only argument. + 2. *args is not supported, and will trigger runtime error. + For example: + ```python + def forward(self, x, y, *args): + ... + ``` + will fail to generate forward method. + 3. **kwargs will be kept as it is. + For example: + ```python + def forward(self, x, y, **kwargs): + ... + ``` + the generated forward method will be: + ```python + def forward(self, x, y, **kwargs): + ... + ``` + But you should not specify any argument in **kwargs when tracing the forward method. + The behavior will be undefined if you rely on the argument in **kwargs. + 4. If an argument is found not in the traced graph, + it will have default value None(no matter what the default value is in the original forward method), + And when calling the generated forward method, you should not specify the argument. + Otherwise, ValueError will be raised. + For example: + ```python + def forward(self, x, y, z=1): + ... + ``` + If y and z are not in the traced graph, the generated forward method will be: + ```python + def forward(self, x, y=None, z=None): + if y is not None: raise ValueError + if z is not None: raise ValueError + ... + ``` + 5. If an argument is used in the traced graph, the default value will be kept as it is. + For example: + ```python + def forward(self, x, y, z=1): + ... + ``` + if z is used in the traced graph, the generated forward method will be: + ```python + def forward(self, x, y, z=1): + ... + ``` + 6. A special case is, if an argument is after an unused argument, but doesn't have a default value. + To make python happy, we have to give it a default value None. + For example: + ```python + def forward(self, x, y, z): + ... + ``` + if y is not used in the traced graph, the generated forward method will be: + ```python + def forward(self, x, y=None, z=None): + if y is not None: raise ValueError + ... + ``` + Please note z has to have default value None, otherwise, python will complain. Args: device (int): device id outfile (str): output file path @@ -275,8 +346,8 @@ def gen( 1. Inherit from ParallelModule 2. Has forward method 3. Add more content to constructor - forward_arg_names (List[str]): argument names of forward function, if None, use node inputs. - This is used only in parallel module + forward_args (Dict[str, Any]): argument names and their default values of forward function, if None, use node inputs. + This is used only in parallel module. Returns: generated code @@ -390,47 +461,9 @@ def gen( raise RuntimeError("The graph has no segment, forward code cannot be generated.") segment_idx = segment_idxs[0] node = gen_nodes[segment_idx] - # will use the orignal names of inputs - inputs = [t.name for t in node.inputs() if not isinstance(t, IRSubTensor) or not t.is_attr()] - # ensure forward args are valid - unused_args = [] - if forward_arg_names: - for i in range(len(inputs)): - if inputs[i] not in forward_arg_names: - raise ValueError(f"Forward args mismatch: {inputs[i]} arg needed") - - forward_args = [] - # find the first mismatch - for i in range(len(inputs)): - if inputs[i] == forward_arg_names[i]: - forward_args.append(inputs[i]) - else: - break - - for i in range(len(forward_args), len(forward_arg_names)): - if not forward_arg_names[i].startswith('*'): - forward_args.append(f'{forward_arg_names[i]}=None') - if forward_arg_names[i] not in inputs: - unused_args.append(forward_arg_names[i]) - _logger.warning(f'Unused forward argument `{forward_arg_names[i]}`.' - f'The argument value will be ignored when you call module forward') - else: - forward_args.append(forward_arg_names[i]) - forward_arg_names = forward_args - else: - forward_arg_names = inputs - - with FunctionBlock(func_name='_forward_impl', args=['self'] + forward_arg_names) as fb: - outputs = self.return_name(node.outputs(), skip_attr=True) - call_code = f'{outputs} = self.{self.node_name(node)}({", ".join(inputs)})' - # be sure the user doesn't specify unused args. - for unused_arg in unused_args: - fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') - fb.insert_body(call_code) - return_code = f'return {self.return_name_complex(self.execplan.graph.outputs())}' - fb.insert_body(return_code) + cb.insert_body('') - cb.insert_body(fb.code) + cb.insert_body(self._generate_forward(node, forward_args)) gencode += cb.code gencode += [''] @@ -445,6 +478,71 @@ def gen( self.clear() return code + def _generate_forward(self, node, forward_args): + # the orignal names of inputs + inputs = [t.name for t in node.inputs() if not isinstance(t, IRSubTensor) or not t.is_attr()] + + unused_args = [] + forward_arg_resolved = [] + if forward_args: + # check all inputs are in forward args + for i in range(len(inputs)): + if inputs[i] not in forward_args: + raise ValueError(f"Forward args mismatch: {inputs[i]} arg needed") + + forward_arg_names = list(forward_args.keys()) + def _get_resolved_arg(arg_name, default_value): + if default_value is inspect.Parameter.empty: + return arg_name + else: + return f'{arg_name}={repr(default_value)}' + + # find the first mismatch + # here, we will keep the default values of the forward args + for i in range(len(inputs)): + if inputs[i] == forward_arg_names[i]: + default_value = forward_args[inputs[i]] + forward_arg_resolved.append(_get_resolved_arg(inputs[i], default_value)) + else: + break + + # check the rest of the forward args + # if the arg is not in inputs, we will set the default value to None + # in runtime, we will make sure the user doesn't specify it. + # if the arg is in inputs, we will keep the default value + # Also *args and **kwargs are kept as it is. + for i in range(len(forward_arg_resolved), len(forward_arg_names)): + if not forward_arg_names[i].startswith('*'): + default_value = forward_args[forward_arg_names[i]] + if forward_arg_names[i] not in inputs: + unused_args.append(forward_arg_names[i]) + forward_arg_resolved.append(f'{forward_arg_names[i]}=None') + _logger.warning(f'Unused forward argument `{forward_arg_names[i]}`.' + f'The argument value will be ignored when you call module forward') + else: + forward_arg_resolved.append( + _get_resolved_arg( + forward_arg_names[i], + None if default_value is inspect.Parameter.empty else default_value + ) + ) + else: + forward_arg_resolved.append(forward_arg_names[i]) + else: + forward_arg_resolved = inputs + + with FunctionBlock(func_name='_forward_impl', args=['self'] + forward_arg_resolved) as fb: + outputs = self.return_name(node.outputs(), skip_attr=True) + call_code = f'{outputs} = self.{self.node_name(node)}({", ".join(inputs)})' + # be sure the user doesn't specify unused args. + for unused_arg in unused_args: + fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') + fb.insert_body(call_code) + return_code = f'return {self.return_name_complex(self.execplan.graph.outputs())}' + fb.insert_body(return_code) + + return fb.code + def emit_comm_groups(self): """ Creating communication group requires all the devices diff --git a/cube/parallel.py b/cube/parallel.py index 66cfd971..7703c82c 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -105,6 +105,11 @@ def _is_any_gencode_loaded(namespace: str) -> bool: return False +def _get_arg_default_values(fn) -> Dict[str, Any]: + args = inspect.signature(inspect.unwrap(fn)) + return {k: v.default for k, v in args.parameters.items()} + + _GENCODE_FILE_PREFIX = 'gencode' _GENCODE_FILE_TEMPLATE = _GENCODE_FILE_PREFIX + '{}.py' # 'gencode{}.py' _CUBE_MODULE_NAMESPACE = '_cube_modules' @@ -201,6 +206,10 @@ def _gencode( IDGenerator().clear() module.cpu() + forward_args_default = _get_arg_default_values(module.forward) + for v in forward_args_default.values(): + if v is not inspect.Parameter.empty and not isinstance(v, (int, str, float, bool, type(None))): + raise ValueError(f"Default value type {type(v)} of forward args is not supported.") # generate fx graph dummy_input = _complex(dummy_input) @@ -216,7 +225,10 @@ def _gencode( fx_input_nodes = [node for node in fx_graph.graph.nodes if node.op == 'placeholder'] # the inputs of graph is different with original forward args # so we get the real forward args from fx inputs - forward_args = [node.target for node in fx_input_nodes] + forward_args = { + node.target: forward_args_default.get(node.target, inspect.Parameter.empty) + for node in fx_input_nodes + } ir_dummy_inputs = [] for node in fx_input_nodes: if node.target.startswith('*'): # *args or **kwargs @@ -290,7 +302,7 @@ def _gencode( mgener = ModuleCodeGen(execplan, scale_ndevs=runtime_ngpus) for rank in range(compute_config.runtime_ngpus): filename = _GENCODE_FILE_TEMPLATE.format(rank) - mgener.gen(rank, forward_arg_names=forward_args, outfile=outdir / filename, attach=False, as_parallel_module=True) + mgener.gen(rank, forward_args=forward_args, outfile=outdir / filename, attach=False, as_parallel_module=True) def _load_cube_module_class( diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 0b22ed64..d79a5f8a 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -74,12 +74,45 @@ def test_codegen_slice(): assert m_new is None +class ArgsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, y, *args): + return self.linear(x) + y + +def test_codegen_args(): + """ + Verify that unused args are supported by parallalize + """ + if not torch.cuda.is_available(): + print('skip test_codegen_args due to lack of cuda devices') + return + + with tempfile.TemporaryDirectory() as tempdir: + # *args is not supported. + with pytest.raises(RuntimeError): + parallelize( + ArgsModule(), + { + 'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + 'y': 1.0, + }, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=True + ) + + class UnusedArgsModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(3, 5) - def forward(self, x, y, z=None, m=None, n=None, **kwargs): + def forward(self, x, y, z=None, m=1, n=2, **kwargs): return self.linear(x) + m @@ -92,7 +125,7 @@ def _gencode_unused_args_worker(tempdir): 'y': torch.tensor([1, 2, 3]), 'z': None, 'm': 0, - 'n': None + 'n': None, }, PASData, ComputeConfig(1, 1), @@ -106,15 +139,14 @@ def _gencode_unused_args_worker(tempdir): assert args.parameters['x'].default is inspect.Parameter.empty assert args.parameters['y'].default is None assert args.parameters['z'].default is None - assert args.parameters['m'].default is None + assert args.parameters['m'].default == 1 assert args.parameters['n'].default is None + assert args.parameters['kwargs'].default is inspect.Parameter.empty - with pytest.raises(TypeError): - # m can't be None - # TypeError is raised by torch.add - m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - - m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), m=1) + assert torch.equal( + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])), + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), m=1) + ) with pytest.raises(ValueError): # y must be None @@ -130,3 +162,55 @@ def test_codegen_unused_args(): with tempfile.TemporaryDirectory() as tempdir: launch_torchrun(1, _gencode_unused_args_worker, tempdir) + + +class UnusedArgs2Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, y, m): + return self.linear(x) + m + + +def _gencode_unused_args_worker2(tempdir): + init_distributed() + m_new = parallelize( + UnusedArgs2Module(), + { + 'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + 'y': torch.tensor([1, 2, 3]), + 'm': 0 + }, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=True + ) + assert m_new is not None + args = inspect.signature(m_new._forward_impl) + assert len(args.parameters) == 3 + assert args.parameters['x'].default is inspect.Parameter.empty + assert args.parameters['y'].default is None + assert args.parameters['m'].default is None + + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), m=1) + with pytest.raises(TypeError, match='.*must be Tensor, not NoneType.*'): + # raise by torch.add, as m is None + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) + with pytest.raises(ValueError): + # y must be None + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) + + +def test_codegen_unused_args2(): + """ + Verify that unused args are supported by parallalize + """ + if not torch.cuda.is_available(): + print('skip test_codegen_unused_args2 due to lack of cuda devices') + return + + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, _gencode_unused_args_worker2, tempdir) From 6d4e9884d4a7de89f063db5ff6a2a9ebf414b66d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 25 Oct 2023 09:11:00 +0000 Subject: [PATCH 1509/1892] Merged PR 1871: temp fix on pyops: replicate to all devices within its segment temp fix on pyops: replicate to all devices within its segment This change is made because cube currently doesn't support tracking data dependencies within op.kwargs. To make the generated code safe, we replicate each IRPyFunc to all devices within its segment. Note this will not introduce additional cost when using tensor parallelism inside a segment. --- cube/graph/gener/gen.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 99631a87..2c817c8e 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -139,13 +139,11 @@ def remove_anchor(graph: IRSegment): def auto_pyfunc(graph: IRGraph): """Transform and assign IRPyFunc. - IRPyFunc will be replicated to devices with its producers output - - Note if an IRPyFunc has no input, indicating its device can not - be indicated from any other operators. In this case, the pyfunc - will be replicated to all devices in its segment. To restrict - the replicaed devices in pipeline-like scenarios, use `graph.staging` - to group the operators into segments. + Warning: + Each IRPyFunc will be replicated to all devices of its segment. + + To restrict the replicaed devices in pipeline-like scenarios, use `graph.staging` + to group the operators into segments. Args: graph (IRGraph): the graph to be transformed @@ -156,17 +154,21 @@ def auto_pyfunc(graph: IRGraph): for func in graph.select(ntype=IRPyFunc, flatten=True): # get devices it will lowered to segment: IRSegment = graph.segment(func) - segment_outputs = IRSegment.get_objects_from_complex(segment.outputs()) devices = set() - for t in func.inputs(): - if not isinstance(t, IRObject): continue - cells = segment.consumers(t.parent) if t.is_attr() else segment.producers(t.parent) - for cell in cells: - devices.update(cell.device) - for t in func.outputs(): - if not isinstance(t, IRObject): continue - if t in segment_outputs: - devices.update(segment.device) + + # FIXME: this is temporally disabled as we don't track data dependencies inside + # operator kwargs. This will be fixed in the future. + # segment_outputs = IRSegment.get_objects_from_complex(segment.outputs()) + # for t in func.inputs(): + # if not isinstance(t, IRObject): continue + # cells = segment.consumers(t.parent) if t.is_attr() else segment.producers(t.parent) + # for cell in cells: + # devices.update(cell.device) + # for t in func.outputs(): + # if not isinstance(t, IRObject): continue + # if t in segment_outputs: + # devices.update(segment.device) + # if a pyfunc doesn't have input, it will be replicated # to all devices in its segment. if len(devices) == 0: From a88897e6ed47f2a9a72a0ef844462a0f1ddb807f Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 25 Oct 2023 10:06:13 +0000 Subject: [PATCH 1510/1892] Merged PR 1870: Cell-Input-Output-1: Support graph inputs with nested data structure --- cube/graph/graph.py | 52 ++++++++++++++----------------- cube/graph/parser/fx/parser.py | 16 +++++++--- cube/program.py | 6 +--- tests/graph/parser/test_parser.py | 35 +++++++++++++++++++++ tests/test_program.py | 40 ++++++++++++++++++++++++ 5 files changed, 111 insertions(+), 38 deletions(-) create mode 100644 tests/test_program.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 690a0d6c..cb6362f5 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -61,53 +61,49 @@ def __call__(self, *args): """ return self.forward(*args) - def forward(self, *args: Tuple[Any]) -> Union[IRTensor, Tuple[IRTensor]]: - """ - forward will divide the graph into Actions according to - node device assignment - - Currently each forward call will result in a new flow - even if the input is same + def forward(self, *args: Tuple[IRObject]) -> Union[IRTensor, Tuple[IRTensor]]: + """Forward the IRGraph to add model nodes into program. - @param args Tuple[Any] + Args: + args (Tuple[IRObject]): input IRObjects - @return outputs Union[IRSubTensor, Tuple[IRSubTensor]] + Returns: + Any: output that can be nested structure of IRObjects """ + if not all(isinstance(arg, IRObject) for arg in args): + raise TypeError("Expected input arguments to be IRObject") + # align graph with input tensors - itensors: Tuple[IRObject, ...] = self.inputs() - if len(args) != len(itensors): + iobjs: Tuple[IRObject, ...] = self.inputs() + if len(args) != len(iobjs): _logger.error( f'cube graph forward: skipping arguments due to len(args) != len(itensors): ' - f'{len(args)} != {len(itensors)}' + f'{len(args)} != {len(iobjs)}' ) - if len(args) > len(itensors): - args = args[:len(itensors)] + if len(args) > len(iobjs): + args = args[:len(iobjs)] _logger.warning(f'cube graph forward: args shrinked into {args}') else: raise RuntimeError('len(args) < len(itensors)') - arg_objs = IRGraph.get_objects_from_complex(args) - graph_objs = IRGraph.get_objects_from_complex(self.inputs()) - assert len(arg_objs) == len(graph_objs), f"input object number not match: {len(arg_objs)} != {len(graph_objs)}" - - for idx, (itensor, arg) in enumerate(zip(itensors, args)): + for idx, (iobj, arg) in enumerate(zip(iobjs, args)): + # reset input self.set_input(idx, arg) - - for arg, itensor in zip(arg_objs, graph_objs): - for producer in self.producers(itensor.parent): + # replace node inputs + for producer in self.producers(iobj.parent): with self.update(producer): - while itensor in producer.outputs(): - oidx = producer.outputs().index(itensor) + while iobj in producer.outputs(): + oidx = producer.outputs().index(iobj) producer.set_output(oidx, arg) - for consumer in self.consumers(itensor.parent): + for consumer in self.consumers(iobj.parent): with self.update(consumer): - while itensor in consumer.inputs(): - iidx = consumer.inputs().index(itensor) + while iobj in consumer.inputs(): + iidx = consumer.inputs().index(iobj) consumer.set_input(iidx, arg) # reset output for oidx, output in enumerate(self.outputs()): output = IRGraph.modify_objects_of_complex( - self.output(oidx), lambda t: t if t != itensor else arg) + self.output(oidx), lambda t: t if t != iobj else arg) self.set_output(oidx, output) from cube.program import Program diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 620b4fa6..0a1c6105 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -12,6 +12,7 @@ from cube.graph.function.dimops import IRDimops import torch.fx +from .concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp _logger = logging.getLogger(__name__) @@ -48,9 +49,6 @@ def parse(module: torch.fx.GraphModule, all_ir_nodes (List[IRFwOperation]): the IRFwOperation nodes outputs (List[IRObject]): the output IRObjects """ - from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp - from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata - frame = Frame() frame.push_var() @@ -64,7 +62,15 @@ def parse(module: torch.fx.GraphModule, FxModuleParser.init_objects(node, module, frame, concrete_value) # get graph inputs - inputs = [frame.get_var(n.name) for n in module.graph.nodes if n.op == 'placeholder'] + placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] + inputs = [frame.get_var(n.name) for n in placeholders] + # - if the graph inputs contain nested strcuture, + # it should be wrapped into an IRObject + for idx, placeholder in enumerate(placeholders): + if not isinstance(inputs[idx], IRObject): + obj = IRObject(name=placeholder.name, value=inputs[idx]) + inputs[idx] = obj + frame.set_var(placeholder.name, obj) # parse graph nodes all_ir_nodes = [] @@ -254,7 +260,7 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, assert isinstance(concrete_value, torch.Tensor), \ f"GetAttrPrim: expect tensor but got {type(concrete_value)}" exist_tensor = frame.get_attr_var(concrete_value) - # the cath that the parameter is the first time used by getattr + # the case that the parameter is the first time used by getattr if not exist_tensor: tensor = frame.get_var(node.name) if tensor.requires_grad: diff --git a/cube/program.py b/cube/program.py index d9b11757..ff8ebc4c 100644 --- a/cube/program.py +++ b/cube/program.py @@ -163,8 +163,6 @@ def complex(val: Any): return list(complex(t) for t in val) if isinstance(val, dict): return {complex(key):complex(val) for key, val in val.items()} - if isinstance(val, set): - return {complex(t) for t in val} if isinstance(val, torch.Tensor): return val.cpu() return val @@ -202,12 +200,10 @@ def __call__(self, *args): if self.dummy_input is None: dummy_input = {} sig = inspect.signature(self.model.forward) - # note: we don't support model forward arguments having complex data stucture - # that contains tensor for name, arg in zip(sig.parameters.keys(), args): if isinstance(arg, IRObject): value = arg.value - arg._value = None # remove tensor reference to release memory + arg._value = None # remove value to release memory else: value = arg dummy_input[str(name)] = value diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 33f5df19..75565925 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -1,5 +1,6 @@ import tempfile import torch +from cube.ir.cten import IRObject, IRTensor from cube.graph.parser.converter import to_fx_graph, to_ir_graph @@ -31,3 +32,37 @@ def forward(self, x): assert ir_graph is not None assert len(ir_graph.attributes()) == 2 # param1 and param2 assert len(ir_graph.full_tensors()) == 8 + + +def test_parser_nested_inputs(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param1 = torch.nn.Parameter(torch.empty(4, 4)) + self.param2 = torch.nn.Parameter(torch.empty(4, 4)) + + def forward(self, x: dict): + shortcut = x['data'] + x = torch.matmul(x['data'], self.param1) + x = x + self.param2 + x = x + shortcut + x = x + self.param1 + return {'loss': torch.sum(x)} + + dummy_input = {'x': {'data': torch.randn(4, 4)}} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + print(ir_graph.extra_repr()) + + assert len(ir_graph.inputs()) == 1 + assert isinstance(ir_graph.input(0), IRObject) + assert isinstance(ir_graph.input(0).value, dict) + assert isinstance(ir_graph.input(0).value['data'], IRTensor) + assert len(ir_graph.outputs()) == 1 + assert isinstance(ir_graph.output(0), dict) + assert isinstance(ir_graph.output(0)['loss'], IRTensor) diff --git a/tests/test_program.py b/tests/test_program.py new file mode 100644 index 00000000..95e49a7f --- /dev/null +++ b/tests/test_program.py @@ -0,0 +1,40 @@ + +import torch +from cube.program import SemanticModel, Program +from cube.flags import CompileFlag +from cube.ir.cten import IRObject + + +def test_program_model_nested_input(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param1 = torch.nn.Parameter(torch.empty(4, 4)) + self.param2 = torch.nn.Parameter(torch.empty(4, 4)) + + def forward(self, x: dict): + shortcut = x['data'] + x = torch.matmul(x['data'], self.param1) + x = x + self.param2 + x = x + shortcut + x = x + self.param1 + return {'loss': torch.sum(x)} + + CompileFlag.dev_mode = True + Program().clear() + + dummy_input = {'x': {'data': torch.randn(4, 4)}} + module = MyModule() + model = SemanticModel(module, save_content=False, dynamic_shape=False) + + obj = IRObject(value=dummy_input['x']) + model(obj) + graph = model.get_graph() + print(graph.extra_repr()) + + assert graph.input(0) == obj + # getitem + assert graph.node(0).input(0) == obj + # getitem + assert graph.node(1).input(0) == obj From 1ee990031c7fe1cd15aa9425baa434e62fb2c1b4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 26 Oct 2023 06:56:29 +0000 Subject: [PATCH 1511/1892] Merged PR 1872: Cell-Input-Output-2: Refine cell interface - Remove predecessor and successor interface, as dependency can be inferred by `graph.depends`. - IRCell now supports `.kwargs`, which was originally supported only in IRFwOperation. This is the prepare work for the general signature in the future. - Remove IRCell `init_output=True`, by default it will now initialize output with `[None,] * num_outputs` instead of IRTensor. This is a redundant design. passed parity check. --- cube/execplan/execplan.py | 2 +- cube/graph/gener/rvd/layout.py | 2 +- cube/graph/segment.py | 32 +--- cube/ir/adapter/adapter.py | 2 - cube/ir/cten.py | 261 +++++------------------------ cube/ir/operator.py | 13 +- examples/policies/alpa/layer_op.py | 2 +- 7 files changed, 50 insertions(+), 264 deletions(-) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 34d34bde..95cd086d 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -26,7 +26,7 @@ def __init__(self, cell: IRCell, f"output length mismatch: {cell}\n" f"cell outputs: {cell.outputs()}\noutputs: {outputs}") super().__init__(cell.name, cell.signature, - len(inputs), len(outputs), init_outputs=False) + len(inputs), len(outputs)) for idx, t in enumerate(inputs): self.set_input(idx, t) for idx, t in enumerate(outputs): diff --git a/cube/graph/gener/rvd/layout.py b/cube/graph/gener/rvd/layout.py index f9b5e89e..656b5261 100644 --- a/cube/graph/gener/rvd/layout.py +++ b/cube/graph/gener/rvd/layout.py @@ -138,7 +138,7 @@ def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int], devices: Optio """ dims = tuple(dims) def dummy_assign(tensor: IRSubTensor, devid: int): - tensor.cell = IRCell('dummy', '', 0, 0, init_outputs=False) + tensor.cell = IRCell('dummy', '', 0, 0) tensor.cell.device = devid mats = np.empty((r, v) + dims, dtype=IRSubTensor) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 26caab7b..6f3dba4f 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -77,7 +77,7 @@ class IRSegment(IRCell): """ def __init__(self, nodes: List[IRCell], inputs: List[IRObject], outputs: List[Any], name='segment'): - super().__init__(name, '', len(inputs), len(outputs), init_outputs=False) + super().__init__(name, '', len(inputs), len(outputs)) self._nodes: List[IRCell] = [] @@ -104,8 +104,6 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRObject], outputs: List[An self._dispatch_cached: Dict[int, IRSegment] = {} - # self.reset_dependency() - def set_input(self, idx: int, val: Any): for t in IRSegment.get_objects_from_complex(val): self._add_ftensor(t.parent) @@ -138,34 +136,6 @@ def attributes(self) -> Tuple[IRFullTensor]: """ return tuple(self._attributes) - def reset_dependency(self): - """ - Reset the node dataflow dependency - - Note all the predefined control dependencies will be removed. - TODO: adapter dependency is not set - """ - for node in self._nodes: - node.clear_predecessor() - node.clear_successor() - # TODO: adapter dependency not set - for ftensor in self._ftensors: - for ptensor, producer in zip(self.ptensors(ftensor), self.producers(ftensor)): - for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): - if ptensor.overlap(ctensor): - pidx = producer.outputs().index(ptensor) - cidx = consumer.inputs().index(ctensor) - producer.add_successor(pidx, consumer) - consumer.add_predecessor(cidx, producer) - # set mirror as control dependency - if producer.mirror is not None and isinstance(producer, IRFwOperation): - producer.add_successor(-1, producer.mirror) - producer.mirror.add_predecessor(-1, producer) - # sub segments - for segment in self._nodes: - if isinstance(segment, IRSegment): - segment.reset_dependency() - # ========================= Basic Graph access ======================= @property diff --git a/cube/ir/adapter/adapter.py b/cube/ir/adapter/adapter.py index 8a57ed1b..7cf9f36f 100644 --- a/cube/ir/adapter/adapter.py +++ b/cube/ir/adapter/adapter.py @@ -13,9 +13,7 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor]): name='adapter', signature='adapter', input_length=len(inputs), output_length=len(outputs), - init_outputs=False ) - self.kwargs = dict() # we don't use input and output setter as this will # change tensor device info self._inputs = inputs diff --git a/cube/ir/cten.py b/cube/ir/cten.py index dfb99bfd..f66c6eb9 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -13,9 +13,10 @@ If an IRTensor is the output of Cell, then Cell.device == IRTensor.device """ +from __future__ import annotations from functools import lru_cache -from typing import Iterable, List, Tuple, Union, Optional, Any +from typing import List, Tuple, Union, Optional, Any, Dict import copy import torch @@ -23,6 +24,9 @@ from cube.ir.dtype import DTypeInfo +NestedVarOrStatic = Any + + class IRCell: r""" IRCell serves as a general node for different purpose @@ -32,8 +36,7 @@ def __init__(self, name: str, signature: str, input_length: int, - output_length: int, - init_outputs = True): + output_length: int): """ Create a node with name (variable name) and module type (module_name) @@ -51,23 +54,13 @@ def __init__(self, self._device: Tuple[int] = () - # source tensors - self._inputs: List[Optional[IRTensor]] = [None,] * input_length - - # destination tensors - self._outputs: List[Optional[IRTensor]] = [None,] * output_length - if init_outputs: - self._outputs = [IRTensor() for _ in range(output_length)] - for tensor in self._outputs: - tensor.cell = self - - # destination cells. [-1] for control dependency - self._successors: List[List[IRCell]] = [list() for _ in range(output_length+1)] - # source cells. [-1] for control dependency - self._predecessors: List[List[IRCell]] = [list() for _ in range(input_length+1)] + # input tensors + self._inputs: List[NestedVarOrStatic] = [None,] * input_length + self._kwargs: Dict[str, NestedVarOrStatic] = {} + # output tensors + self._outputs: List[NestedVarOrStatic] = [None,] * output_length self._mirror: Optional[IRCell] = None - # the comment for code generation self._comment: Optional[str] = None @@ -144,105 +137,52 @@ def isfw(self) -> bool: """ return True - def input(self, index:int): - # type: (int) -> Optional[IRTensor] - """ - Get the input tensor at input index + @property + def kwargs(self) -> Dict[str, NestedVarOrStatic]: + return self._kwargs + + def input(self, index: int) -> NestedVarOrStatic: + """Get the index-th input Args: - index (int): - index of the inputs + index (int): index of the inputs Returns: - values: Optional[IRTensor] + NestedVarOrStatic: (nested) IRObject or any static value (int, bool, str, etc) """ return self._inputs[index] # 'maxsize=None' set no limit on cache growth, but it's ok since we have no args @lru_cache(maxsize=None) - def inputs(self): - # type: () -> Tuple[Optional[IRTensor], ...] - """ - Get all input tensors + def inputs(self) -> Tuple[NestedVarOrStatic]: + """Get all input values Returns: - values: Tuple[Optional[IRTensor], ...] + Tuple[NestedVarOrStatic] """ - return tuple(self._inputs) - def predecessors(self, index: Optional[int] = None) -> List: - """ - Get input operator at input index - (or index = -1 for control dependency) - - Returns: - cell(s): Union[List[IRCell], IRCell] - """ - if isinstance(index, int): - if index >= len(self._inputs): - raise RuntimeError( - f"Get the input out of range ({index} >= {len(self._inputs)}" - ) - return copy.copy(self._predecessors[index]) - elif index is None: - predecessors = list() - for pre_cells in self._predecessors: - predecessors += pre_cells - return predecessors - else: - raise TypeError("Expected index to be None or int") - - def output(self, index:int): - # type: (int) -> Optional[IRTensor] - """ - Get the output tensor at output index + def output(self, index: int) -> NestedVarOrStatic: + """Get the index-th output value Args: - index (int): - index of the outputs + index (int): index of the outputs Returns: - values: Optional[IRTensor] + NestedVarOrStatic: (nested) IRObject or any static value (int, bool, str, etc) """ return self._outputs[index] # 'maxsize=None' set no limit on cache growth, but it's ok since we have no args @lru_cache(maxsize=None) - def outputs(self): - # type: () -> Tuple[Optional[IRTensor], ...] - """ - Get all output tensors + def outputs(self) -> Tuple[NestedVarOrStatic]: + """Get all output values Returns: - values: Tuple[Optional[IRTensor], ...] + Tuple[NestedVarOrStatic] """ - return tuple(self._outputs) - def successors(self, index: Optional[int] = None) -> List: - """ - Get output operator at output index - - Args: - index (int or None): - index of the outputs (or -1 for control dependency), - None will return the nodes for all the outputs - """ - if isinstance(index, int): - if index >= len(self._outputs): - raise RuntimeError( - f"Get the output out of range ({index} >= {len(self._outputs)}" - ) - return copy.copy(self._successors[index]) - elif index is None: - successors = list() - for post_cells in self._successors: - successors += post_cells - return successors - else: - raise TypeError("Expected index to be None or int") - def reset_inputs(self, length:int) -> None: """ Resize the inputs list to the new length and reset all input items to None. @@ -250,32 +190,21 @@ def reset_inputs(self, length:int) -> None: self._inputs = [None] * length self.inputs.cache_clear() - def set_input(self, input_index: int, val): - # type: (int, Optional[IRTensor]) -> Optional[IRTensor] - """ - Set the node inputs[input_index] with the tensor + def set_input(self, index: int, val: NestedVarOrStatic) -> NestedVarOrStatic: + """Set the index-th input Args: - val: Optional[IRTensor] + val (NestedVarOrStatic): (nested) IRObject or any deterministic value (int, bool, str, etc) - Return: - the set tensor + Returns: + NestedVarOrStatic: copied value """ - c = len(self._inputs) - if input_index >= c or input_index < -c: - raise RuntimeError( - f"Set the input out of range ({input_index} >= {c} or {input_index} < {-c})" - ) if isinstance(val, IRObject): # copy the val val = copy.copy(val) - # set tensor dst val.cell = self - - self._inputs[input_index] = val - + self._inputs[index] = val self.inputs.cache_clear() - return val def reset_outputs(self, length:int) -> None: @@ -285,129 +214,25 @@ def reset_outputs(self, length:int) -> None: self._outputs = [None] * length self.outputs.cache_clear() - def set_output(self, output_index: int, val): - # type: (int, Optional[IRTensor]) -> Optional[IRTensor] + def set_output(self, index: int, val: NestedVarOrStatic): """ Set the node inputs[output_index] with the tensor Args: - val: Optional[IRTensor] - IRTensor or any deterministic value (int, bool, str, etc) - """ - c = len(self._outputs) - if output_index >= c or output_index < -c: - raise RuntimeError( - f"Set the input out of range ({output_index} >= {c} or {output_index} < {-c})" - ) + val (NestedVarOrStatic): (nested) IRObject or any deterministic value (int, bool, str, etc) + + Returns: + NestedVarOrStatic: copied value + """ if isinstance(val, IRObject): val = copy.copy(val) val.cell = self - - self._outputs[output_index] = val + self._outputs[index] = val self.outputs.cache_clear() - return val - def add_predecessor(self, input_index: int, cell): - """ - Add a predecessor cell in the input_index slot. - - Note this won't add successor if caller cell to the node - - To add control dependency, use `input_index=-1` - """ - if not isinstance(cell, IRCell): - raise TypeError("Expected node to be IRCell") - if input_index >= len(self.inputs()): - raise RuntimeError( - f"Set the input out of range ({input_index} >= {len(self._inputs)})" - ) - if cell not in self._predecessors[input_index]: - self._predecessors[input_index].append(cell) - - def clear_predecessor(self): - """ - Clear all predecessors - """ - self._predecessors = [ - list() for _ in range(len(self.inputs()) + 1) - ] - - def add_successor(self, output_index: int, cell): - """ - Set self node the output index node. - `node` will take the self.output(index) as the input - - To add control dependency, use `output_index=-1` - """ - if not isinstance(cell, IRCell): - raise TypeError("Expected node to be IRCell") - if cell not in self._successors[output_index]: - self._successors[output_index].append(cell) - - def clear_successor(self): - """ - Clear all successors - """ - self._successors = [ - list() for _ in range(len(self.outputs()) + 1) - ] - - def make_empty(self): - """ - Clear all inputs, outputs of this Cell - """ - for idx in range(len(self.inputs())): - self.set_input(idx, None) - for idx in range(len(self.outputs())): - self.set_output(idx, None) - - @staticmethod - def get_inputs(cells): - # type: (Iterable[IRCell]) -> list[IRCell] - """ - Get all the input tensors the is not generated by nodes - - Inputs - - Returns: - List[IRTensor] - """ - all_outputs = list() - for cell in cells: - all_outputs.extend(cell.outputs()) - inputs = list() - for cell in cells: - for input in cell.inputs(): - if isinstance(input, IRTensor): - if input not in all_outputs: - if input not in inputs: - inputs.append(input) - return inputs - - @staticmethod - def get_outputs(cells): - # type: (Iterable[IRCell]) -> list[IRCell] - """ - Get all the input tensors the is not generated by nodes - - Returns: - List[IRTensor] - """ - all_inputs = list() - for node in cells: - all_inputs.extend(node.inputs()) - outputs = list() - for node in cells: - for output in node.outputs(): - if isinstance(output, IRTensor): - if output not in all_inputs: - if output not in outputs: - outputs.append(output) - return outputs - @property - def comment(self) -> Any: + def comment(self) -> Optional[str]: return self._comment @comment.setter diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 7fa43f7a..9d61bfec 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -25,15 +25,14 @@ def __init__(self, name: str, signature: str, """ # recompute schedule self._recompute = None - super().__init__(name, signature, len(inputs), - num_outputs, init_outputs=False) + super().__init__(name, signature, len(inputs), num_outputs) # setup input for idx, input in enumerate(inputs): self.set_input(idx, input) # additional argument - self.kwargs = kwargs + self.kwargs.update(kwargs) # default infer rule requires_grad = any( @@ -132,8 +131,6 @@ def replicate(self): cpy.set_output(idx, output) cpy._mirror = None cpy.recompute = self.recompute - cpy.clear_predecessor() - cpy.clear_successor() return cpy def __repr__(self) -> str: @@ -163,7 +160,7 @@ def __init__(self, ograds: Tuple[Any], igrads: Tuple[Any]): """ super().__init__( 'backward', 'torch.autograd.grad', - len(ograds), len(igrads), init_outputs=False + len(ograds), len(igrads) ) for idx, ograd in enumerate(ograds): self.set_input(idx, ograd) @@ -188,8 +185,6 @@ def replicate(self): for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None - cpy.clear_predecessor() - cpy.clear_successor() return cpy def __repr__(self) -> str: @@ -231,8 +226,6 @@ def replicate(self): for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) cpy._mirror = None - cpy.clear_predecessor() - cpy.clear_successor() return cpy def infer_shape(self): diff --git a/examples/policies/alpa/layer_op.py b/examples/policies/alpa/layer_op.py index d87504af..1cd70f3a 100644 --- a/examples/policies/alpa/layer_op.py +++ b/examples/policies/alpa/layer_op.py @@ -10,7 +10,7 @@ class IRLayerOp(IRCell): def __init__(self, nodes: List[IRCell], layer_id: int = None): - super().__init__('layer_op', 'layer_op', 0, 0, init_outputs=False) + super().__init__('layer_op', 'layer_op', 0, 0) self.nodes = nodes self.layer_id : int = layer_id From accc34ab01c1f610bb3751bdfd0976ec5ac472d6 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 27 Oct 2023 06:54:12 +0000 Subject: [PATCH 1512/1892] Merged PR 1878: emit tensor_name fix Will return tensor_name as `repr` (instead of `str`) for non-tensor names. --- cube/algorithm/ops/dimops.py | 26 +++--- cube/codegen/emit.py | 67 ++++++++++---- cube/codegen/frontend_mapping.py | 20 +---- cube/graph/function/function.py | 71 +++++++++++---- tests/codegen/__init__.py | 0 tests/codegen/test_emit.py | 33 +++++++ tests/graph/function/test_functions.py | 98 ++++++++++++++++++++- tests/parallel_module/test_gencode.py | 115 ++++++++++++++++++++++++- 8 files changed, 359 insertions(+), 71 deletions(-) create mode 100644 tests/codegen/__init__.py create mode 100644 tests/codegen/test_emit.py diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 9e4581ea..1f0c9d8f 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -34,9 +34,9 @@ class DimSplitEinops(GenericDistAlgo): If the identifier appears as the same name in argument name, the argument will also be uniformly partitioned. - + Non-tensor will always be replicated. - + Note the default rule isn't always expressive for all possible partition algorithms. E.g., linear xw + b to partition on reduction dimension, whitch requires b to be value split but actually according to the default rule, will be replicated. @@ -52,7 +52,7 @@ def get_identifier_reduce(self, idx: int, dim: int, num: int) -> Tuple[str, DimA """ Get the partitioned identifier and reduction type. If the partitioned number is 1, return the first hidden identitifer - Otherwise, return the first hidden identifier whose length > 1 + Otherwise, return the first hidden identifier whose length > 1 @param idx int: input/output index. Take the idx-th input tensor or (idx-ninputs)-th output @param dim int: input dimension @@ -75,7 +75,7 @@ def get_identifier_reduce(self, idx: int, dim: int, num: int) -> Tuple[str, DimA def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: """ Check whether the condition satisfies. - + @param idx int: input/output index. Take the idx-th input tensor or (idx-ninputs)-th output tensor @param dim Union[int, str]: tensor dimension or 'v', i.e., partition at value dimension. @param num int: chunks to partition the dimension @@ -85,7 +85,7 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: assert all(isinstance(cond, int) for cond in [idx, num]), "expect int condition" assert isinstance(dim, int) or dim == 'v', f"expect dim to be int or 'v'" node: IRDimops = self.node - + tensors = node.inputs() + node.outputs() assert isinstance(tensors[idx], IRSubTensor), f"partition on a non-tensor input/output" assert 0 <= idx and idx < len(tensors), f"index out of boundary: {idx} >= {len(tensors)}" @@ -94,7 +94,7 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: if isinstance(dim, int): dim = dim if dim >= 0 else dim + tensors[idx].ndims assert dim < tensors[idx].ndims, f"dimension output of boundary: {dim} >= {node.input(idx).ndims}" - + # try split at tensor spatial dimension if isinstance(dim, int): adim, reduce = self.get_identifier_reduce(idx, dim, num) @@ -104,11 +104,11 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: for rule in node.transform_rules: splits = rule.inputs() + rule.outputs() if splits[idx] == DimopSplit.D(dim): - return dimlen >= num + return dimlen % num == 0 # then check default rules if reduce == DimAnno.ReduceType.Freeze: return False - return dimlen >= num + return dimlen % num == 0 else: for rule in node.transform_rules: splits = rule.inputs() + rule.outputs() @@ -125,13 +125,13 @@ def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List adim, reduce = self.get_identifier_reduce(idx, dim, num) else: adim, reduce = 'Value', None - + color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' _logger.info(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") if not satisfy: return None rule: TransformRule = self.infer(idx, dim, num) - + # transform def transform(tensor: Any, split: DimopSplit) -> List[Any]: if not isinstance(tensor, IRSubTensor): @@ -296,8 +296,8 @@ def gen_hash(node: IRFwOperation) -> str: if cur_key in visited: continue - + dq.append((new_node, new_ngpus)) visited.add(cur_key) - - return gen_nodes \ No newline at end of file + + return gen_nodes \ No newline at end of file diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 39c89628..08427ff9 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -17,7 +17,12 @@ class IRValue: - + """ + A wrapper of the tensor name (as a variable name). + This is used to avoid the tensor name to be quoted in repr. + repr('name') => "'name'" + repr(IRValue('name')) => "name" + """ def __init__(self, name: str): self.name = name @@ -25,6 +30,40 @@ def __repr__(self): return self.name +def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: + """ + Return repr-able value of a tensor or value. + For tensor, return IRValue({prefix}{tensor.name}_{tensor.tid}) + For non-tensor, return as it is + + Args: + val (Any): tensor or non-tensor value + prefix_attr (str): prefix to the tensor name if the tensor is an attribute + Returns: + the val that can be repr safely + """ + if isinstance(val, IRObject): + tensor_name = val.name + if '.' in tensor_name: + tensor_name = tensor_name.split('.')[0] + name = '_'.join([tensor_name, str(val.tid)]) + if prefix_attr is not None and val.is_attr(): + name = prefix_attr + name + return IRValue(name) + elif isinstance(val, slice): + return slice(_safe_repr_value(val.start, prefix_attr), _safe_repr_value(val.stop, prefix_attr), _safe_repr_value(val.step, prefix_attr)) + elif isinstance(val, dict): + return {_safe_repr_value(k, prefix_attr): _safe_repr_value(v, prefix_attr) for k, v in val.items()} + elif isinstance(val, list): + return [_safe_repr_value(v, prefix_attr) for v in val] + elif isinstance(val, tuple): + # TODO: support subclasses of tuple, like torch.Size? + return tuple(_safe_repr_value(v, prefix_attr) for v in val) + elif isinstance(val, (int, str, bool, float, type(None), bytes)): # only primitive type supported + return val + raise ValueError(f'Unsupported data type: {type(val)}') + + class CodeEmission: """ Basic emission @@ -32,27 +71,19 @@ class CodeEmission: def node_name(self, node: IRCell) -> str: return f"{node.name}{node.cid}" - def tensor_name(self, tensor: Any, prefix_attr: Optional[str] = None) -> str: + def tensor_name(self, val: Any, prefix_attr: Optional[str] = None) -> str: """ - Return the var name. + Return representation of a value or a tensor. For tensor, return the {prefix}{tensor.name}_{tensor.tid} - For non-tensor, return its string - - @param tensor Any: any value - @attr_prefix Optional[str]: prefix for a attributed tensor + For non-tensor, return its repr - @return str + Args: + val (Any): tensor or non-tensor value + prefix_attr (Optional[str]): prefix to the tensor name if the tensor is an attribute + Returns: + representation of the val in str """ - if isinstance(tensor, IRObject): - tensor_name = tensor.name - if '.' in tensor_name: - tensor_name = tensor_name.split('.')[0] - name = '_'.join([tensor_name, str(tensor.tid)]) - if prefix_attr is not None and tensor.is_attr(): - name = prefix_attr + name - else: - name = str(IRSegment.modify_objects_of_complex(tensor, self.tensor_name)).replace('\'', '') - return name + return repr(_safe_repr_value(val, prefix_attr)) def complex_name(self, val: Any, prefix_attr: Optional[str]=None) -> str: """ diff --git a/cube/codegen/frontend_mapping.py b/cube/codegen/frontend_mapping.py index d467dbf4..a09bd171 100644 --- a/cube/codegen/frontend_mapping.py +++ b/cube/codegen/frontend_mapping.py @@ -16,8 +16,6 @@ def __init__(self) -> None: self._sign2rule = { 'torch.slice': self.emit_slice, 'setattr': self.emit_setattr, - 'builtins.getattr': self.emit_getattr, - '_operator.getitem': self.emit_getitem, } def map(self, signature: str) -> Callable: @@ -76,20 +74,4 @@ def emit_slice(self, node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[st def emit_setattr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: """Special rule for generating setattr node """ - - assert arg_vars[1].startswith('self.') - member = f'"{arg_vars[1][5:]}"' - return f"{node.signature}({arg_vars[0]}, {member}, {arg_vars[2]})" - - def emit_getattr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: - """Special rule for generating getattr node - """ - return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" - - def emit_getitem(self, node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: - """Special rule for generating getitem node - """ - if len(arg_vars) == 2 and len(kw_pairs) == 0 and not arg_vars[1].replace('_', '').isdigit(): - return f"{node.signature}({arg_vars[0]}, '{arg_vars[1]}')" - else: - return self.emit_common(node, arg_vars, kw_pairs) + assert False, f"This emit rule is deprecated, please report if you reach here" diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index c09d1d7d..8a111c5e 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -156,7 +156,7 @@ def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Uni return dimop -def Arange(*args, out=None, dtype=None, layout=None, +def Arange(*args, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): """ torch.arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor @@ -261,13 +261,13 @@ def Full(size, fill_value, *, out=None, dtype=None, layout=None, return dimop -def NewTensor(data, *, dtype=None, device=None, +def NewTensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False, signature=None): # note: device is ignored dtype = dtype if dtype is not None else torch.get_default_dtype() signature = 'cube.runtime.function.tensor' size = tuple(np.array(data).shape) if np.array(data).shape else (1,) # (1,) means it is a scalar - kwargs = {'size': size, 'requires_grad': requires_grad, + kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(size, True) dimop = IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) @@ -339,12 +339,12 @@ def _handle_broadcast_multi(ins_list: List[IRTensor]) -> Tuple[Tuple[List[str]], def Expand(input, *sizes, size = None, signature = None): """ torch.Tensor.expand(*sizes) - + The reason of add ``size`` to this function argument is: 1. ``sizes`` need to reuse in IRDimops.new(), but it is a ``non-keyword arguments``, and can not put it into keyword arguments (something like Expand(input, sizes=[1, 2, 3])) is not work, to support IRDimops.new API, here add a ``size`` to workaround. - + 2. in torch._C.expand API, it has: def expand(self, size: Sequence[Union[_int, SymInt]], *, implicit: _bool=False) -> Tensor: ... so add ``size`` can also solve user using something like: @@ -941,7 +941,7 @@ def _reshape_anno(in_shape: List[int], ou_shape: List[int], kwarg_name: str) -> in_shape List[int]: input shape ou_shape List[int]: output shape kwarg_name str: kwarg name of reshape / view op - + Returns: str: annotation string List[TransformRule]: transformation rules @@ -1063,7 +1063,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s if sdim is not None: ospatial.add(sdim) ofirst.append(sdim) - + # intersection for spatial partitioned dimensions spatial = ispatial.intersection(ospatial) @@ -1497,7 +1497,7 @@ def Slice(tensor: torch.Tensor, dim, start, end, step, signature = None): ianno = ShapeAnno.create_shape_str(tensor.shape) oanno = copy.copy(ianno) ianno[dim] = str(tensor.shape[dim]) - + def clip(ofst): ofst = ofst + tensor.shape[dim] if ofst < 0 else ofst return min(tensor.shape[dim], max(0, ofst)) @@ -1528,21 +1528,51 @@ def SelectScatter(self: torch.Tensor, input: torch.Tensor, dim: int, index: int, in1_anno[dim] = str(self.shape[dim]) out_anno = in1_anno.copy() anno = OpAnno.create_op_str([in1_anno, in2_anno], [out_anno]) - return IRDimops(SelectScatter, 'select_scatter', signature, + return IRDimops(SelectScatter, 'select_scatter', signature, [anno], [self, input], dim=dim, index=index) -def Repeat(tensor, repeats: Tuple[int], *arg_repeats, signature = None): +# If the type is IROject, then value should be type of int, Tuple[int], List[int] +# If the type is Tuple[IROject] or List[IRObject], then the value of each element should be type of int +_VariadicInt = Union[int, Tuple[int, ...], List[int], IRObject, Tuple[IRObject, ...], List[IRObject]] + +def extract_variadic(v: _VariadicInt) -> Tuple[List[int], List[bool]]: + if isinstance(v, int): + if isinstance(v, bool): + raise ValueError("Unsupported type: bool") + return [v], [False] + elif isinstance(v, IRObject): + r = extract_variadic(v.value) + return r[0], [True] * len(r[0]) # because all elements are from IRObject + elif isinstance(v, (tuple, list)): + r = [extract_variadic(e) for e in v] + if any(len(x[0]) != 1 for x in r): + raise ValueError("tuple/list can't be nested") + return [x[0][0] for x in r], [x[1][0] for x in r] + else: + raise ValueError(f"Unsupported type: {type(v)}") + + +def Repeat(tensor, repeats: _VariadicInt, *arg_repeats, signature = None): """ torch.Tensor.repeat(*sizes) """ signature = 'torch.ops.aten.repeat' - repeats = (repeats,) if isinstance(repeats, int) else tuple(repeats) - repeats = repeats + arg_repeats + if isinstance(repeats, (list, tuple)) or ( + isinstance(repeats, IRObject) and isinstance(repeats.value, (list, tuple)) + ): + # follow the behavior of torch.Tensor.repeat, + # ignore arg_repeats in this case + complete_repeats = repeats + else: + complete_repeats = (repeats,) + arg_repeats + repeats, repeats_is_ir = extract_variadic(complete_repeats) + in_shape = list(tensor.shape) - assert len(in_shape) <= len(repeats), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor" + if len(in_shape) > len(repeats): + raise ValueError("Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor") expand = len(repeats) - len(tensor.shape) - in_shape += [1] * expand + in_shape = [1] * expand + in_shape ou_shape = [dimlen * repeat for dimlen, repeat in zip(in_shape, repeats)] ianno, oanno = ShapeAnno.create_shape_str(in_shape), [] for dim, dimlen in enumerate(ou_shape): @@ -1552,11 +1582,14 @@ def Repeat(tensor, repeats: Tuple[int], *arg_repeats, signature = None): if repeats[dim] != 1: ianno[dim] += '^' dim_anno = [str(repeats[dim]), ianno[dim]] + elif repeats_is_ir[dim]: # for dynamic repeat, don't split the dimension + ianno[dim] += '^' + dim_anno = ianno[dim] else: dim_anno = ianno[dim] oanno.append(dim_anno) anno = OpAnno.create_op_str([ianno[expand:]], [oanno]) - return IRDimops(Repeat, 'repeat', signature, [anno], [tensor], repeats=repeats) + return IRDimops(Repeat, 'repeat', signature, [anno], [tensor], repeats=complete_repeats) def CubeEmbedding(input, weight, padding_idx, signature = None, **kwargs): @@ -1627,12 +1660,12 @@ def AdaptiveAvgPool1d(input, output_size, signature = None): return IRDimops(AdaptiveAvgPool1d, 'adaptive_avg_pool1d', signature, [anno], [input], output_size=output_size) -def CrossEntropy(input, target, weight=None, +def CrossEntropy(input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0, signature = None): """ torch.nn.functional.cross_entropy( - input, target, weight=None, + input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0) """ @@ -1664,12 +1697,12 @@ def GraphAnchor(name: str, signature = None): return node -def _comparison(creator: Callable, f: Callable, name: str, signature: str, +def _comparison(creator: Callable, f: Callable, name: str, signature: str, input, other): """ if both operands are scalars, returns bool. if one operand is a tensor, returns a broadcasted tensor with dtype being bool. - + @param creator Callable: the outside creation function @param f Callable: (Scalar, Scalar) -> bools """ diff --git a/tests/codegen/__init__.py b/tests/codegen/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/codegen/test_emit.py b/tests/codegen/test_emit.py new file mode 100644 index 00000000..5ba94634 --- /dev/null +++ b/tests/codegen/test_emit.py @@ -0,0 +1,33 @@ +import pytest +from cube.codegen.emit import CodeEmission +from cube.ir.cten import IRObject +from cube.codegen.emit import FuncEmission +from cube.graph.function import Dropout +from cube.ir.tensor import IRFullTensor + + +def test_tensor_name(): + repr_expr = CodeEmission().tensor_name + assert repr_expr(1, 'model.') == '1' + assert repr_expr('1') == "'1'" + + assert repr_expr(IRObject('name', 111, 'value'), 'model.') == 'name_111' + assert repr_expr(IRObject('name', 111, 'value').as_attr(), 'model.') == 'model.name_111' + assert repr_expr((IRObject('name', 111, 'value').as_attr(),), 'model.') == '(model.name_111,)' + + assert repr_expr(slice(1, None, IRObject('name', 111, 'value').as_attr()), 'model.') == 'slice(1, None, model.name_111)' + assert repr_expr({'a': 1, 'b': IRObject('name', 111, 'value')}, 'model.') == "{'a': 1, 'b': name_111}" + assert repr_expr([1], 'model.') == '[1]' + assert repr_expr((1,), 'model.') == '(1,)' + + with pytest.raises(ValueError): + from datetime import datetime + repr_expr(datetime.now()) + + + +def test_emit_module_attr(): + dropout = Dropout(IRFullTensor([1024, 1024], requires_grad=True), p=0.5, training='self.training', signature='torch.nn.functional.dropout') + code = FuncEmission().emit_fnode(dropout) + print(code) + assert 'training=self.training' in code[0] diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 01a8b8e5..8a3f2b26 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -1,7 +1,9 @@ ### Only test the anno creation in these tests import cube.graph.function.function as F -from cube.ir.cten import IRTensor +from cube.ir.cten import IRObject, IRTensor + +import pytest def test_handle_broadcast_multi(): @@ -33,6 +35,100 @@ def test_Expand(): op.new([inp], [out], size=[10, 2]) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ -> a 2' + +def test_variadic_extraction(): + def o(value): + return IRObject(value=value) + assert F.extract_variadic([]) == ([], []) + assert F.extract_variadic(1) == ([1], [False]) + assert F.extract_variadic([1]) == ([1], [False]) + assert F.extract_variadic((1,)) == ([1], [False]) + assert F.extract_variadic([1, 2]) == ([1, 2], [False, False]) + + assert F.extract_variadic(o([])) == ([], []) + assert F.extract_variadic(o(1)) == ([1], [True]) + assert F.extract_variadic(o([1])) == ([1], [True]) + assert F.extract_variadic(o((1,))) == ([1], [True]) + assert F.extract_variadic(o([1, 2])) == ([1, 2], [True, True]) + + assert F.extract_variadic([1, o(2)]) == ([1, 2], [False, True]) + assert F.extract_variadic([1, o(2), 3, o(4)]) == ([1, 2, 3, 4], [False, True, False, True]) + + with pytest.raises(ValueError, match='.*nested.*'): + F.extract_variadic([1, o([2, 3])]) + with pytest.raises(ValueError, match='.*nested.*'): + F.extract_variadic([1, [2, 3]]) + with pytest.raises(ValueError, match='Unsupported type.*'): + F.extract_variadic(True) + with pytest.raises(ValueError, match='Unsupported type.*'): + F.extract_variadic([1, True]) + with pytest.raises(ValueError, match='Unsupported type.*'): + F.extract_variadic(o([1, True])) + + +def test_Repeat(): + def o(value): + return IRObject(value=value) + inp = IRTensor([3]) + op = F.Repeat(inp, (4, 2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b^ -> 4 (2 b^)' + + inp = IRTensor([3]) + op = F.Repeat(inp, (4, 1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b -> 4 b' + + inp = IRTensor([3]) + op = F.Repeat(inp, 4, 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b -> 4 b' + + inp = IRTensor([3, 2]) + op = F.Repeat(inp, (4, 2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ -> (4 a^) (2 b^)' + + inp = IRTensor([3, 2]) + op = F.Repeat(inp, 4, 2, 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b^ c -> 4 (2 b^) c' + + inp = IRTensor([3, 2]) + op = F.Repeat(inp, (4, 2), 1) # the args(1) is ignored + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ -> (4 a^) (2 b^)' + + inp = IRTensor([3, 2]) + op = F.Repeat(inp, (4, 1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b -> (4 a^) b' + + inp = IRTensor([3, 2]) + with pytest.raises(ValueError): + op = F.Repeat(inp, (2)) + + with pytest.raises(ValueError, match='.*nested.*'): + op = F.Repeat(inp, 4, (4, 2)) + + inp = IRTensor([3]) + op = F.Repeat(inp, o((4, 2))) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b^ -> 4 (2 b^)' + + inp = IRTensor([3]) + op = F.Repeat(inp, (4, o(1))) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b^ -> 4 b^' + + inp = IRTensor([3, 2]) + op = F.Repeat(inp, (4, o(2))) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ -> (4 a^) (2 b^)' + + inp = IRTensor([3, 2]) + op = F.Repeat(inp, (o(4), 1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b -> (4 a^) b' + + inp = IRTensor([3, 2]) + op = F.Repeat(inp, 4, 2, o(1), 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'c^ d^ -> 4 2 c^ (2 d^)' + + inp = IRTensor([3, 2]) + with pytest.raises(ValueError): + op = F.Repeat(inp, o(2)) + + def test_Where(): op = F.Where(IRTensor([3, 4]), IRTensor([3, 4]), IRTensor([3, 4])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, a b, a b -> a b' diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index d79a5f8a..22004364 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -6,7 +6,7 @@ from cube.parallel import parallelize, ComputeConfig, CubeModule -from .common import PASData, init_distributed +from .common import PASData, init_distributed, PASRandomSPMD from ..launch_torchrun import launch_torchrun def _to_cube_model(module, compute_config, cube_savedir, load_module): @@ -214,3 +214,116 @@ def test_codegen_unused_args2(): with tempfile.TemporaryDirectory() as tempdir: launch_torchrun(1, _gencode_unused_args_worker2, tempdir) + + +class AttrModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, attr): + return x + getattr(attr, 'a') + + +def _gencode_contains(cubesave_dir, module_class, index, search_re): + from cube.parallel import _CUBE_MODULE_NAMESPACE, _get_full_qualified_name + from pathlib import Path + import re + namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}' + outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) + filecontent = (outdir /f'gencode{index}.py').read_text() + matches = re.findall(search_re, filecontent) + return bool(matches) + + +def test_codegen_attr(): + if not torch.cuda.is_available(): + print('skip test_codegen_attr due to lack of cuda devices') + return + class AttrHelper: + def __init__(self) -> None: + self.a = 2.0 + + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + AttrModule(), + {'x': torch.tensor([1.0, 2.0, 3.0, 6.0]), 'attr': AttrHelper()}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False + ) + assert _gencode_contains(tempdir, AttrModule, 0, r'builtins.getattr\(.*, \'a\'\)') + assert m_new is None + + +class GetItemModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, batched_data): + data_x = batched_data["x"] + n_graph, n_node = data_x.size()[:2] + padding_mask = (data_x[:, :, 0]).eq(0) # B x T x 1 + padding_mask_cls = torch.zeros( + n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype + ) + padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1) + return padding_mask + + +def test_codegen_getitem(): + if not torch.cuda.is_available(): + print('skip test_codegen_getitem due to lack of cuda devices') + return + + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + GetItemModule(), + {'batched_data': {'x': torch.tensor([[[1.0], [2.0], [3.0], [6.0]]])}}, + PASRandomSPMD, + ComputeConfig(2, 2), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False, + override=True, + ) + assert _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') + assert _gencode_contains(tempdir, GetItemModule, 1, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') + assert m_new is None + + +class TrainingModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + if self.training: + return self.linear(x) + else: + return self.linear(x) + 1 + + +def test_codegen_training_flag(): + """ + Test it can support modules without parameters + """ + if not torch.cuda.is_available(): + print('skip test_codegen_training_flag due to lack of cuda devices') + return + with tempfile.TemporaryDirectory() as tempdir: + m = TrainingModule() + m.train() + + # self.training isn't supported in concrete_trace + with pytest.raises(RuntimeError, match='Node referenced nonexistent target.*'): + parallelize( + m, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False + ) From 8aa1dc447efafc59ee17878e22c0605cbafaa33d Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 30 Oct 2023 02:31:50 +0000 Subject: [PATCH 1513/1892] Merged PR 1881: hot fix: fix kwargs error in scale `.kwargs` now should use `.update` to set its content. --- cube/codegen/module/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 98726a0d..bd30411d 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -219,7 +219,7 @@ def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: if isinstance(node, IRAdapter): adapter = copy.copy(node) adapter._id = node.cid - adapter.kwargs = dict(**node.kwargs) + adapter.kwargs.update(node.kwargs) prims = [] for prim in adapter.prims: p = copy.copy(prim) From d2e7069a6b7699cc77b3d4728d997b8e658e461f Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 31 Oct 2023 05:25:06 +0000 Subject: [PATCH 1514/1892] Merged PR 1877: merge concrete trace and shape prop parity check passed --- .../fx/concrete_trace_utils/__init__.py | 2 +- .../fx/concrete_trace_utils/concrete_proxy.py | 5 + .../concrete_trace_utils/concrete_tracer.py | 32 ++- .../kwargs_shape_prop/__init__.py | 0 .../kwargs_shape_prop/kwargs_interpreter.py | 203 ------------------ .../kwargs_shape_prop/kwargs_shape_prop.py | 98 --------- .../parser/fx/concrete_trace_utils/utils.py | 84 +++++++- cube/graph/parser/fx/parser.py | 5 +- 8 files changed, 102 insertions(+), 327 deletions(-) delete mode 100644 cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/__init__.py delete mode 100644 cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py delete mode 100644 cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/__init__.py b/cube/graph/parser/fx/concrete_trace_utils/__init__.py index e4a574ff..630bfd03 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/__init__.py +++ b/cube/graph/parser/fx/concrete_trace_utils/__init__.py @@ -12,4 +12,4 @@ More information about concrete tracing can be found in the :func:`concrete_trace` documentation. """ from .concrete_tracer import ConcreteTracer, concrete_trace -from .utils import ExtraSEFPatcher +from .utils import ExtraSEFPatcher, TensorMetadata diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index ecc10327..105946fc 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -281,6 +281,11 @@ def __init__(self, root: ConcreteProxy, attr: str): self._node: Optional[Node] = None if _orig_isinstance(root.value, torch.Tensor) and attr == 'is_cuda' and self.tracer.cpu_offload: self.value = True + elif _orig_isinstance(root.value, torch.Tensor) and attr == 'device' and self.tracer.cpu_offload: + self.value = torch.device('cuda') + warning_msg = "operation .device is detected, it will always return torch.device('cuda') during trace, " + \ + "please make sure don't manually change the tensor device in the code." + _logger.warning(warning_msg) else: self.value = _orig_getattr(root.value, attr) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 6528061e..1479e992 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -128,7 +128,7 @@ def __exit__(self, *args): _orig_node_is_impure, ) -from .utils import FrameRecord, ExtraSEFPatcher +from .utils import FrameRecord, ExtraSEFPatcher, extract_results_metadata, EmptyResult # pyright: reportGeneralTypeIssues=false _logger = logging.getLogger(__name__) @@ -394,7 +394,7 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] @compatibility(is_backward_compatible=True) def create_node(self, kind : str, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: + type_expr : Optional[Any] = None, node_result: Any = EmptyResult) -> Node: """ This method is almost the same as the one in `TracerBase` class of Pytorch2.0. Add it here because this method of Pytorch1.13 and older version @@ -415,6 +415,7 @@ def create_node(self, kind : str, target : Target, node.meta['nn_module_stack'] = copy.copy(self.module_stack) else: node.meta['nn_module_stack'] = collections.OrderedDict() + extract_results_metadata(node_result, node) return node @compatibility(is_backward_compatible=True) @@ -447,7 +448,7 @@ def upwrapper(obj: Any): assert isinstance(args_, tuple) assert isinstance(kwargs_, dict) - node = self.create_node(kind, target, args_, kwargs_, name, type_expr) + node = self.create_node(kind, target, args_, kwargs_, name, type_expr, value_unwrapped) if self.record_frames and kind != 'placeholder': with self.do_temp_disable(True, True, True): @@ -482,16 +483,16 @@ def create_arg(self, a: Any) -> Union[Node, Any]: if isinstance(a, torch.nn.Parameter): for n, p in self.root.named_parameters(): if a is p: - return self.create_node('get_attr', n, (), {}) + return self.create_node('get_attr', n, (), {}, node_result=a) raise NameError('parameter is not a member of this module') elif isinstance(a, torch.Tensor): for n_, p_ in self.root.named_buffers(): if a is p_: - return self.create_node('get_attr', n_, (), {}) + return self.create_node('get_attr', n_, (), {}, node_result=a) elif isinstance(a, torch.nn.Module): for n_, p_ in self.root.named_modules(): if a is p_: - return self.create_node('get_attr', n_, (), {}) + return self.create_node('get_attr', n_, (), {}, node_result=a) # for slice if isinstance(a, slice): start = self.create_arg(a.start) @@ -500,14 +501,14 @@ def create_arg(self, a: Any) -> Union[Node, Any]: if _orig_isinstance(start, Node)\ or _orig_isinstance(stop, Node)\ or _orig_isinstance(step, Node): - return self.create_node('call_function', _orig_slice, (start, stop, step), {}) + return self.create_node('call_function', _orig_slice, (start, stop, step), {}, node_result=a) else: return a # For NamedTuple instances that appear literally as args, we emit # a node to construct the NamedTuple and use that Node as the argument. if isinstance(a, tuple) and hasattr(a, '_fields'): args = tuple(self.create_arg(elem) for elem in a) - return self.create_node('call_function', a.__class__, args, {}) + return self.create_node('call_function', a.__class__, args, {}, node_result=a) # Tensors do not have a reliable string repr() from which they can be # constructed (and we probably don't want to rely on that, either), so @@ -531,7 +532,7 @@ def create_arg(self, a: Any) -> Union[Node, Any]: self.tensor_attrs[a] = qualname setattr(self.root, qualname, a) - return self.create_node('get_attr', qualname, (), {}) + return self.create_node('get_attr', qualname, (), {}, node_result=a) if _orig_type(a) in _proxyable_classes: # This is an instance of a proxyable class for which we did not @@ -546,7 +547,7 @@ def create_arg(self, a: Any) -> Union[Node, Any]: i += 1 setattr(self.root, qualname, a) - return self.create_node('get_attr', qualname, (), {}) + return self.create_node('get_attr', qualname, (), {}, node_result=a) if isinstance(a, (torch.autograd.function.Function, torch.autograd.function.FunctionMeta)): return a @@ -1134,9 +1135,9 @@ def getattr_wrapper(obj, *args): for module in self._autowrap_search: _autowrap_check(self, module.__dict__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) with OperatorPatcherContext(self, use_operator_patch, operator_patch_backlist): - self.create_node('output', 'output', - (self.create_arg(OperatorPatcherContext.patch_run(fn, *args, *more_args, **kwargs)),), - {}, type_expr=fn.__annotations__.get('return', None)) + results = OperatorPatcherContext.patch_run(fn, *args, *more_args, **kwargs) + self.create_node('output', 'output', (self.create_arg(results),), + {}, type_expr=fn.__annotations__.get('return', None), node_result=results) finally: # for cuda versions of pytorch, autograd.Function.apply should be reverted manually delattr(torch.autograd.Function, 'apply') @@ -1695,8 +1696,6 @@ def f(x, y): assert all(callable(ignore_func) for ignore_func in dce_ignored_function) tracer = ConcreteTracer(cpu_offload = cpu_offload, record_frames = record_frames) - is_training = root.training - root.eval() graph = tracer.trace(root, autowrap_leaf_function = autowrap_leaf_function, @@ -1759,7 +1758,4 @@ def f(x, y): if check_args is not None: assert root(**check_args) == traced(**check_args) - if is_training: - root.train() - return traced diff --git a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/__init__.py b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py deleted file mode 100644 index a7ced04f..00000000 --- a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_interpreter.py +++ /dev/null @@ -1,203 +0,0 @@ -import torch -import torch.fx -from torch.fx import Interpreter, Node, GraphModule -from typing import Optional, Union, Tuple, Dict, List, Any, Iterator, Callable, MutableMapping, Mapping -from torch.utils._pytree import tree_map - -Target = Union[Callable[..., Any], str] - -BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, - torch.Tensor, torch.device, torch.memory_format, torch.layout] - -Argument = Optional[Union[ - Tuple[Any, ...], - List[Any], - Dict[str, Any], - slice, - Node, - BaseArgumentTypes -]] - - -class KwargsInterpreter(Interpreter): - def __init__(self, module : GraphModule, garbage_collect_values : bool = True, fake_device_type='cpu'): - super().__init__(module, garbage_collect_values) - assert fake_device_type in ('cpu', 'cuda') - self.fake_device_type = fake_device_type - - def run(self, - concrete_args: Union[Dict[str, Any], Tuple, MutableMapping[str, Any], Mapping[str, Any]] = None, - initial_env: Optional[Dict[Node, Any]] = None, - enable_io_preocessing: bool = True) -> Any: - - self.env = initial_env if initial_env else {} - - if isinstance(concrete_args, tuple): - # if concrete_args is a tuple, then they are positional args - # then they are consumed left-to-right by `placeholder` nodes. - # Use an iterator to keep track of position and extract those values - if enable_io_preocessing: - args = self.module.graph.process_inputs(*concrete_args) - self.args_iter: Iterator[Any] = iter(args) - self.concrete_kwargs = None - else: - try: - # concrete_args is a kwargs dict/mapping - self.args_iter = None - self.concrete_kwargs = concrete_args - self.used_concrete_kwargs = [] - # get default values of parameters in `forward()` method - import inspect - fw = inspect.unwrap(self.module.forward) - args_default_values = fw.__defaults__ - if args_default_values is not None: - fw_code = fw.__code__ - n_args = fw_code.co_argcount + fw_code.co_kwonlyargcount - names_iter = iter(fw_code.co_varnames) - start_idx = 0 - if fw_code.co_varnames[0] == 'self': - _ = next(names_iter) # skip self - start_idx = 1 - args_names = [next(names_iter) for idx in range(start_idx, n_args)] - diff_len = len(args_names) - len(args_default_values) - self.default_args = {args_names[idx + diff_len]: args_default_values[idx] for idx in - range(len(args_default_values))} - else: - self.default_args = {} - except: - raise RuntimeError(f'invalid concrete_args type: {type(concrete_args)}') - - assert ( - self.args_iter is None or self.concrete_kwargs is None), 'can not use positional args and keyword args at the same time' - - for node in self.module.graph.nodes: - if node in self.env: - continue - try: - self.env[node] = self.run_node(node) - except Exception as e: - print(node.name, node.op, node.target) - msg = f'While executing {node.format_node()}' - msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg) - msg += f"\nOriginal traceback:\n{node.stack_trace}" - e.args = (msg,) + e.args[1:] - if isinstance(e, KeyError): - raise RuntimeError(*e.args) - raise - - if self.garbage_collect_values: - for to_delete in self.user_to_last_uses.get(node, []): - del self.env[to_delete] - - if node.op == 'output': - output_val = self.env[node] - return self.module.graph.process_outputs(output_val) if enable_io_preocessing else output_val - - def run_node(self, n: Node) -> Any: - """ - Run a specific node ``n`` and return the result. - Calls into placeholder, get_attr, call_function, - call_method, call_module, or output depending - on ``node.op`` - - Args: - n (Node): The Node to execute - - Returns: - Any: The result of executing ``n`` - """ - args, kwargs = self.fetch_args_kwargs_from_env(n) - return getattr(self, n.op)(n.target, args, kwargs) - - def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: - """ - Execute a `placeholder` node. - - Args: - target(Target): The call target for this node, - exactly the argument name of the forward function - args(Tuple): Tuple of positional args for this invocation - kwargs(Dict): Dict of keyword arguments for this invocation - - Returns: - Any: The argument value that was retrieved. - """ - assert isinstance(target, str) - if target.startswith('**'): - # For a douvle-starred parameter, e.g., `**kwargs`, - # retrieve all the remaining values from the concrete kwargs dict - remaining_keys = [key for key in self.concrete_kwargs if key not in self.used_concrete_kwargs] - return {key: self.concrete_kwargs[key] for key in remaining_keys} - elif target.startswith('*'): - if self.concrete_kwargs is not None: - raise RuntimeError('unexpected positional args in kwargs mode') - return list(self.args_iter) - else: - if self.concrete_kwargs is not None: - try: - ret_arg = self.concrete_kwargs[target] - except KeyError: - return self.default_args[target] - else: - self.used_concrete_kwargs.append(target) - return ret_arg - else: - try: - return next(self.args_iter) - except StopAsyncIteration: - if len(args) > 0: - return args[0] - else: - raise RuntimeError( - f'Expected positional argument for parameter {target}, but one was not passed in!') - - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: - assert not isinstance(target, str) - if self.fake_device_type == 'cpu': - to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - result = target(*args, **kwargs) - if isinstance(result, torch.Tensor): - return result.cpu() - else: - to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t - return tree_map(to_cpu, result) - else: - return target(*args, **kwargs) - - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - # args[0] is the `self` object for this method call - self_obj, *args_tail = args - assert isinstance(target, str) - if self.fake_device_type == 'cpu': - self_obj = self_obj.cuda() - to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t - args_tail = tree_map(to_cuda, args_tail) - kwargs = tree_map(to_cuda, kwargs) - result = getattr(self_obj, target)(*args_tail, **kwargs) - self_obj = self_obj.cpu() - if isinstance(result, torch.Tensor): - return result.cpu() - else: - to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t - return tree_map(to_cpu, result) - else: - return getattr(self_obj, target)(*args_tail, **kwargs) - - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: - assert isinstance(target, str) - mod = self.fetch_attr(target) - if self.fake_device_type == 'cpu': - mod = mod.cuda() - to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - result = mod(*args, **kwargs) - if isinstance(result, torch.Tensor): - return result.cpu() - else: - to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t - return tree_map(to_cpu, result) - else: - return mod(*args, **kwargs) diff --git a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py b/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py deleted file mode 100644 index d6a30779..00000000 --- a/cube/graph/parser/fx/concrete_trace_utils/kwargs_shape_prop/kwargs_shape_prop.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch -import traceback -from torch.fx.node import Node, map_aggregate -from typing import Optional, Union, NamedTuple, Tuple, Any, Dict -from .kwargs_interpreter import KwargsInterpreter - - -__all__ = ['TensorMetadata', 'KwargsShapeProp'] - - -class TensorMetadata(NamedTuple): - # TensorMetadata is a structure containing pertinent information - # about a tensor within a PyTorch program. - - # General Tensor metadata - shape : torch.Size - dtype : torch.dtype - requires_grad : bool - stride : Tuple[int] - memory_format : Optional[torch.memory_format] - - # Quantization metadata - is_quantized : bool - qparams: Dict[str, Any] - - -def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata: - """ - Extract a TensorMetadata NamedTuple describing `result`. - """ - shape = result.shape - dtype = result.dtype - requires_grad = result.requires_grad - stride = result.stride() - - memory_formats = { - torch.contiguous_format, - torch.channels_last, - torch.channels_last_3d, - } - - memory_format = None - - for query_format in memory_formats: - if result.is_contiguous(memory_format=query_format): - memory_format = query_format - break - - is_quantized = result.is_quantized - qparams: Dict[str, Any] = {} - if is_quantized: - qscheme = result.qscheme() - qparams["qscheme"] = qscheme - if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: - qparams["scale"] = result.q_scale() # type: ignore[assignment] - qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] - elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: - # In this branch, scale and zero_point are expected to be tensors, - # we store the values as immutable_list in TensorMetadata for - # easier serialization downstream - qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] - qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] - qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] - - return TensorMetadata( - shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) - - -class KwargsShapeProp(KwargsInterpreter): - def run_node(self, n: Node): - try: - result = super().run_node(n) - except Exception as e: - raise RuntimeError( - f"ShapeProp error for: node={n.format_node()} with " - f"meta={n.meta}" - ) from e - - found_tensor = False - - def extract_tensor_meta(obj): - if isinstance(obj, torch.Tensor): - nonlocal found_tensor - found_tensor = True - return _extract_tensor_metadata(obj) - else: - return obj - - # if the obj is a tensor, then wrap it into a TensorMetaData - # else recursively descend and wrap - meta = map_aggregate(result, extract_tensor_meta) - if found_tensor: - n.meta['tensor_meta'] = meta - n.meta['type'] = type(result) - return result - - def propagate(self, concrete_args: Union[Dict[str, Any], Tuple]): - return super().run(concrete_args) diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index db244c6b..3b792870 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -4,12 +4,11 @@ import builtins from dataclasses import dataclass import operator -from typing import Any, Callable, Set, Type +from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Type import functools import torch -from torch.fx import Node -from torch.fx.node import _side_effectful_functions +from torch.fx.node import Node, map_aggregate, _side_effectful_functions # These need to run in global scope to handle nested calls correctly _orig_module_call: Callable = torch.nn.Module.__call__ @@ -128,6 +127,7 @@ def __repr__(self) -> str: else: return '' + class ExtraSEFPatcher: def __init__(self, extra_side_effectful_functions: Set[Callable]): self.extra_side_effectful_functions = extra_side_effectful_functions @@ -139,3 +139,81 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): _side_effectful_functions.difference_update(self.incontext_funcs) + + +class TensorMetadata(NamedTuple): + # TensorMetadata is a structure containing pertinent information + # about a tensor within a PyTorch program. + + # General Tensor metadata + shape : torch.Size + dtype : torch.dtype + requires_grad : bool + stride : Tuple[int] + memory_format : Optional[torch.memory_format] + + # Quantization metadata + is_quantized : bool + qparams: Dict[str, Any] + + +def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: + """ + Extract a TensorMetadata NamedTuple describing `result`. + """ + shape = result.shape + dtype = result.dtype + requires_grad = result.requires_grad + stride = result.stride() + + memory_formats = { + torch.contiguous_format, + torch.channels_last, + torch.channels_last_3d, + } + + memory_format = None + + for query_format in memory_formats: + if result.is_contiguous(memory_format=query_format): + memory_format = query_format + break + + is_quantized = result.is_quantized + qparams: Dict[str, Any] = {} + if is_quantized: + qscheme = result.qscheme() + qparams["qscheme"] = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + qparams["scale"] = result.q_scale() # type: ignore[assignment] + qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] + elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: + # In this branch, scale and zero_point are expected to be tensors, + # we store the values as immutable_list in TensorMetadata for + # easier serialization downstream + qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] + qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] + qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] + + return TensorMetadata( + shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + + +def extract_tensor_metadata(obj: Any): + if isinstance(obj, torch.Tensor): + return _extract_tensor_metadata(obj) + else: + return obj + + +def extract_results_metadata(results: Any, node: Node): + if results is not EmptyResult: + meta = map_aggregate(results, extract_tensor_metadata) + node.meta['tensor_meta'] = meta + node.meta['type'] = type(results) + + +class EmptyResult: + """Used for identification no results. + """ + pass diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 0a1c6105..0540b4fc 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -12,7 +12,7 @@ from cube.graph.function.dimops import IRDimops import torch.fx -from .concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import KwargsShapeProp as ShapeProp +from .concrete_trace_utils import TensorMetadata _logger = logging.getLogger(__name__) @@ -54,7 +54,6 @@ def parse(module: torch.fx.GraphModule, # shape propagation assert isinstance(dummy_inputs, dict), "Expected dummy inputs to parse module" - ShapeProp(module).propagate(dummy_inputs) # create IRObjects and IRTensors for node in module.graph.nodes: @@ -110,8 +109,6 @@ def parse_node(node: torch.fx.Node, module, dynamic_shape: bool, frame: Frame) - @staticmethod def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame, concrete_value: Optional[Any] = None): - - from cube.graph.parser.fx.concrete_trace_utils.kwargs_shape_prop.kwargs_shape_prop import TensorMetadata assert isinstance(node, torch.fx.Node) def meta2var(meta: Any) -> Any: From f6092d96dee935bce5ffb3f387ca6318198e6d09 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 1 Nov 2023 07:29:58 +0000 Subject: [PATCH 1515/1892] Merged PR 1886: fix output metadata unwrapped --- .../concrete_trace_utils/concrete_tracer.py | 13 ++++++--- tests/graph/parser/test_converter.py | 27 +++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 1479e992..6e3cd7e0 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -433,12 +433,12 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: similar to _symbolic_trace.Tracer.create_proxy. use the 'run_target' to actually execute the code, and store the value in 'value' field. """ - def upwrapper(obj: Any): + def unwrap(obj: Any): while _orig_isinstance(obj, ep.ConcreteProxy): obj = obj.value return obj - args_unwrapped = ep.map_aggregate_not_proxy(args, upwrapper) - kwargs_unwrapped = ep.map_aggregate_not_proxy(kwargs, upwrapper) + args_unwrapped = ep.map_aggregate_not_proxy(args, unwrap) + kwargs_unwrapped = ep.map_aggregate_not_proxy(kwargs, unwrap) # real value by execution value_unwrapped = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) @@ -1136,8 +1136,13 @@ def getattr_wrapper(obj, *args): _autowrap_check(self, module.__dict__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) with OperatorPatcherContext(self, use_operator_patch, operator_patch_backlist): results = OperatorPatcherContext.patch_run(fn, *args, *more_args, **kwargs) + # we should unwrap proxy to the original value in the results when we record it to node.meta['tensor_meta'] + def unwrap(obj: Any): + while _orig_isinstance(obj, ep.ConcreteProxy): + obj = obj.value + return obj self.create_node('output', 'output', (self.create_arg(results),), - {}, type_expr=fn.__annotations__.get('return', None), node_result=results) + {}, type_expr=fn.__annotations__.get('return', None), node_result=ep.map_aggregate_not_proxy(results, unwrap)) finally: # for cuda versions of pytorch, autograd.Function.apply should be reverted manually delattr(torch.autograd.Function, 'apply') diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 15206922..9eef910e 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -94,3 +94,30 @@ def forward(self, x, *args): if 'frame_record' in node.meta and cube_path in str(node.meta['frame_record']): err_msg = f"Cube root path should not in node comment {node.meta['frame_record']}" raise RuntimeError(err_msg) + + +def test_record_metadata(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + return self.linear(x) + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + from cube.graph.parser.fx.concrete_trace_utils.concrete_proxy import ConcreteProxy + from cube.graph.parser.fx.concrete_trace_utils import TensorMetadata + + for node in fx_graph.graph.nodes: + # this assert is only for this simple model, all node should have TensorMetadata type 'tensor_meta' + # other complex model nodes may not have 'tensor_meta' or a TensorMetadata type 'tensor_meta' + assert 'tensor_meta' in node.meta and isinstance(node.meta['tensor_meta'], TensorMetadata) + tm = node.meta['tensor_meta'] + assert not isinstance(tm.shape, ConcreteProxy) + assert not isinstance(tm.dtype, ConcreteProxy) + assert not isinstance(tm.requires_grad, ConcreteProxy) + assert not isinstance(tm.stride, ConcreteProxy) + assert not isinstance(tm.memory_format, ConcreteProxy) From 095437f7430db738315d0b99d4f42dd0bf8c04f4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 1 Nov 2023 11:17:36 +0000 Subject: [PATCH 1516/1892] Merged PR 1887: merge adapter and change all adapters into out-of-place forms 1) both cube.runtime.adapter.nn and cube.runtime.adapter.collective use same collective implementation; 2) all collectives in cube.runtime.adapter.collectives change to adopt out-of-place implementation --- cube/runtime/adapter/collectives.py | 51 ++++++++--- cube/runtime/adapter/nn.py | 135 +++++++--------------------- 2 files changed, 69 insertions(+), 117 deletions(-) diff --git a/cube/runtime/adapter/collectives.py b/cube/runtime/adapter/collectives.py index 750b81c8..46fe6618 100644 --- a/cube/runtime/adapter/collectives.py +++ b/cube/runtime/adapter/collectives.py @@ -1,3 +1,11 @@ +""" +This module offers the wrap of communication primitives +based on `torch.distributed`. The use of these primitives standalone is typically +for non-autograd (e.g., inference) scenarios. + +Every collective is implemented using out-of-place semantics. +""" + from typing import List, Tuple, Optional import torch @@ -40,13 +48,11 @@ def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, def all_reduce(tensor: torch.Tensor, ranks: List[int], async_op=False) -> torch.Tensor: - """ - Allreduce - """ + """Allreduce""" if not async_op: CudaTimer().start(field_name='comm', predefined=True) tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor - tensor = tensor.detach() + tensor = tensor.detach().clone() group = DeviceGroup().get_group(ranks) if async_op: @@ -61,9 +67,7 @@ def all_reduce(tensor: torch.Tensor, def all_gather(tensor: torch.Tensor, dim: int, ranks: Tuple[int], async_op=False) -> torch.Tensor: - """ - Allgather - """ + """Allgather""" if not async_op: CudaTimer().start(field_name='comm', predefined=True) tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor @@ -84,9 +88,7 @@ def all_gather(tensor: torch.Tensor, dim: int, def reduce_scatter(tensor: torch.Tensor, dim: int, ranks: Tuple[int], async_op=False) -> torch.Tensor: - """ - ReduceScatter - """ + """ReduceScatter""" if not async_op: CudaTimer().start(field_name='comm', predefined=True) itensors = list(tensor.chunk(len(ranks), dim)) @@ -104,9 +106,7 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, def all_to_all(tensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int], async_op=False) -> torch.Tensor: - """ - All-to-all - """ + """All-to-all""" if not async_op: CudaTimer().start(field_name='comm', predefined=True) itensors = list(tensor.chunk(len(ranks), dim=odim)) @@ -126,6 +126,31 @@ def all_to_all(tensor: torch.Tensor, idim: int, odim: int, return otensor +def all_to_all_single(tensor: torch.Tensor, idim: int, odim: int, + ranks: Tuple[int], async_op: bool = False) -> torch.Tensor: + """All-to-all for single tensor""" + if not async_op: + CudaTimer().start(field_name='comm', predefined=True) + tensor = tensor.transpose(0, odim) if odim != 0 else tensor + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + group = DeviceGroup().get_group(ranks) + otensor = torch.empty_like(tensor) + work = torch.distributed.all_to_all_single(otensor, tensor, group=group, async_op=async_op) + + def all2all_callback(t): + t = t.transpose(0, odim) if odim != 0 else t + return torch.concat(tuple(t.chunk(len(ranks), dim=odim)), dim=idim) + + if work: + AsyncCommHandler().submit(tensor, [work], all2all_callback) + else: + otensor = all2all_callback(otensor) + + if not async_op: + CudaTimer().stop(field_name='comm', predefined=True) + return otensor + + def chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int], async_op=False) -> torch.Tensor: """ split dimension in n chunks and take idx-th chunk diff --git a/cube/runtime/adapter/nn.py b/cube/runtime/adapter/nn.py index 7089ae8a..8b6a025f 100644 --- a/cube/runtime/adapter/nn.py +++ b/cube/runtime/adapter/nn.py @@ -1,102 +1,35 @@ +""" +This module offers autograd functions for communication +primitives. This is typically used in the training with tensor +parallelism scenario. +""" + from typing import List, Tuple import torch from cube.profiler.timer import CudaTimer from cube.runtime.device import DeviceGroup - - -def _allreduce(itensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm', predefined=True) - if not itensor.is_contiguous(): - itensor = itensor.contiguous() - # force allreduce not to be in-place - itensor = itensor.detach().clone() - group = DeviceGroup().get_group(ranks) - torch.distributed.all_reduce(itensor, group=group) - CudaTimer().stop(field_name='comm', predefined=True) - return itensor - - -def _allgather(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm', predefined=True) - if not itensor.is_contiguous(): - itensor = itensor.contiguous() - group = DeviceGroup().get_group(ranks) - tensor_list = [torch.empty_like(itensor) for _ in ranks] - tensor_list[torch.distributed.get_rank(group)] = itensor.data - torch.distributed.all_gather(tensor_list, itensor, group=group) - # concat - otensor = torch.concat(tuple(tensor_list), dim=dim).requires_grad_() - CudaTimer().stop(field_name='comm', predefined=True) - return otensor - - -def _reducescatter(itensor: torch.Tensor, dim:int, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm', predefined=True) - itensors = list(itensor.chunk(len(ranks), dim)) - for idx, tensor in enumerate(itensors): - if not tensor.is_contiguous(): - itensors[idx] = tensor.contiguous() - group = DeviceGroup().get_group(ranks) - otensor = torch.empty_like(itensors[0]) - torch.distributed.reduce_scatter(otensor, itensors, group=group) - CudaTimer().stop(field_name='comm', predefined=True) - return otensor - - -def _alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm', predefined=True) - itensors = list(itensor.chunk(len(ranks), dim=odim)) - for idx, tensor in enumerate(itensors): - if not tensor.is_contiguous(): - itensors[idx] = tensor.contiguous() - otensors = [torch.empty_like(t) for t in itensors] - group = DeviceGroup().get_group(ranks) - torch.distributed.all_to_all(otensors, itensors, group=group) - otensor = torch.concat(tuple(otensors), dim=idim) - CudaTimer().stop(field_name='comm', predefined=True) - return otensor - - -def _alltoallsingle(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: - CudaTimer().start(field_name='comm', predefined=True) - if odim != 0: - itensor = itensor.transpose(0, odim) - if not itensor.is_contiguous(): - itensor = itensor.contiguous() - group = DeviceGroup().get_group(ranks) - otensor = torch.empty_like(itensor) - torch.distributed.all_to_all_single(otensor, itensor, group=group) - if odim != 0: - otensor = otensor.transpose(0, odim) - otensor = torch.concat(tuple(otensor.chunk(len(ranks), dim=odim)), dim=idim) - CudaTimer().stop(field_name='comm', predefined=True) - return otensor - - -def _chunk(itensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: - """ - split dimension in n chunks and take idx-th chunk - - ranks (Tuple[int]): the order of split tensor. - """ - group = DeviceGroup().get_group(ranks) - idx = torch.distributed.get_rank(group) - return itensor.chunk(len(ranks), dim)[idx] +from .collectives import ( + all_reduce, + all_gather, + reduce_scatter, + all_to_all, + all_to_all_single, + chunk +) class AllReduceIdentity(torch.autograd.Function): @staticmethod def forward(ctx, itensor: torch.Tensor, ranks: Tuple[int]): - return _allreduce(itensor, ranks) + return all_reduce(itensor, ranks) @staticmethod def backward(ctx, grad_output): return grad_output, None -@torch.jit.ignore def allreduce_identity(tensor: torch.Tensor, ranks: List[int]): return AllReduceIdentity.apply(tensor, ranks) @@ -111,10 +44,10 @@ def forward(ctx, itensor: torch.Tensor, ranks: Tuple[int]): @staticmethod def backward(ctx, grad: torch.Tensor): ranks = ctx._ranks - grad = _allreduce(grad, ranks) + grad = all_reduce(grad, ranks) return grad, None -@torch.jit.ignore + def identity_allreduce(tensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: return IdentityAllreduce.apply(tensor, ranks) @@ -124,17 +57,16 @@ class AllReduceAllReduce(torch.autograd.Function): @staticmethod def forward(ctx, itensor: torch.Tensor, ranks: Tuple[int]): ctx._ranks = ranks - otensor = _allreduce(itensor, ranks) + otensor = all_reduce(itensor, ranks) return otensor @staticmethod def backward(ctx, grad: torch.Tensor): ranks = ctx._ranks - grad = _allreduce(grad, ranks) + grad = all_reduce(grad, ranks) return grad, None -@torch.jit.ignore def allreduce_allreduce(tensor: torch.Tensor, ranks: Tuple[int]) -> torch.Tensor: return AllReduceAllReduce.apply(tensor, ranks) @@ -145,17 +77,16 @@ class ReduceScatterAllGather(torch.autograd.Function): def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int]): ctx._ranks = ranks ctx._dim = dim - return _reducescatter(itensor, dim, ranks) + return reduce_scatter(itensor, dim, ranks) @staticmethod def backward(ctx, grad: torch.Tensor): ranks = ctx._ranks dim = ctx._dim - grad = _allgather(grad, dim, ranks) + grad = all_gather(grad, dim, ranks) return grad, None, None -@torch.jit.ignore def reducescatter_allgather(tensor: torch.Tensor, dim: int, ranks: List[int]): return ReduceScatterAllGather.apply(tensor, dim, ranks) @@ -166,17 +97,16 @@ class AllGatherReduceScatter(torch.autograd.Function): def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int]): ctx._ranks = ranks ctx._dim = dim - return _allgather(itensor, dim, ranks) + return all_gather(itensor, dim, ranks) @staticmethod def backward(ctx, grad: torch.Tensor): ranks = ctx._ranks dim = ctx._dim - grad = _reducescatter(grad, dim, ranks) + grad = reduce_scatter(grad, dim, ranks) return grad, None, None -@torch.jit.ignore def allgather_reducescatter(tensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: return AllGatherReduceScatter.apply(tensor, dim, ranks) @@ -187,16 +117,15 @@ class AllGatherSplit(torch.autograd.Function): def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int]): ctx._ranks = ranks ctx._dim = dim - return _allgather(itensor, dim, ranks) + return all_gather(itensor, dim, ranks) @staticmethod def backward(ctx, grad: torch.Tensor): ranks = ctx._ranks dim = ctx._dim - return _chunk(grad, dim, ranks), None, None + return chunk(grad, dim, ranks), None, None -@torch.jit.ignore def allgather_split(tensor: torch.Tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: return AllGatherSplit.apply(tensor, dim, ranks) @@ -210,17 +139,16 @@ def forward(ctx, itensor: torch.Tensor, dim: int, ranks: Tuple[int]): """ ctx._ranks = ranks ctx._dim = dim - return _chunk(itensor, dim, ranks) + return chunk(itensor, dim, ranks) @staticmethod def backward(ctx, grad: torch.Tensor): ranks = ctx._ranks dim = ctx._dim - grad = _allgather(grad, dim, ranks) + grad = all_gather(grad, dim, ranks) return grad, None, None -@torch.jit.ignore def split_allgather(tensor, dim: int, ranks: Tuple[int]) -> torch.Tensor: return SplitAllGather.apply(tensor, dim, ranks) @@ -232,13 +160,13 @@ def forward(ctx, itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) ctx._ranks = ranks ctx._idim = idim ctx._odim = odim - return _alltoall(itensor, idim, odim, ranks) + return all_to_all(itensor, idim, odim, ranks) @staticmethod def backward(ctx, grad: torch.Tensor): ranks = ctx._ranks idim, odim = ctx._idim, ctx._odim - grad = _alltoall(grad, odim, idim, ranks) + grad = all_to_all(grad, odim, idim, ranks) return grad, None, None, None @@ -249,17 +177,16 @@ def forward(ctx, itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) ctx._ranks = ranks ctx._idim = idim ctx._odim = odim - return _alltoallsingle(itensor, idim, odim, ranks) + return all_to_all_single(itensor, idim, odim, ranks) @staticmethod def backward(ctx, grad: torch.Tensor): ranks = ctx._ranks idim, odim = ctx._idim, ctx._odim - grad = _alltoallsingle(grad, odim, idim, ranks) + grad = all_to_all_single(grad, odim, idim, ranks) return grad, None, None, None -@torch.jit.ignore def alltoall_alltoall(itensor: torch.Tensor, idim: int, odim: int, ranks: Tuple[int]) -> torch.Tensor: return AllToAllAllToAllSingle.apply(itensor, idim, odim, ranks) From 96505ad17d3f764c05a8505f7f8b4e36742249b9 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Nov 2023 03:27:50 +0000 Subject: [PATCH 1517/1892] Merged PR 1885: Support checksum on graph status between nodes --- cube/compiler.py | 28 ++++++++++++++++++++++++++++ cube/graph/graph.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/cube/compiler.py b/cube/compiler.py index 317edf72..afa075c7 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -120,6 +120,13 @@ def decorator(fn: Callable) -> Callable: _logger.info(f'loading existed schedule from {filename} ...') return cube.load_default_schedule(filename) + ndevices = DeviceGroup().world_size + local_ndevs = DeviceGroup().local_world_size + nnodes = ndevices // local_ndevs + if nnodes > 1: + compile_ranks = list(range(0, ndevices, local_ndevs)) + compile_group = DeviceGroup().get_group(compile_ranks) + if DeviceGroup().local_rank == 0: compile_start = time.time() @@ -161,6 +168,27 @@ def decorator(fn: Callable) -> Callable: _logger.info(f'saving graph to {save_graph_file}') graph.dump(save_graph_file) + # checking graph consistency between multiple nodes + if nnodes > 1: + checksum = graph.checksum(strict=True) + _logger.debug(f'checking graph consistency (local md5: {checksum}) ...') + state = torch.tensor([ord(c) for c in checksum], dtype=torch.int, + device=torch.cuda.current_device()) + gather_list = None + if DeviceGroup().node_rank == 0: + gather_list = [torch.empty_like(state) for _ in range(nnodes)] + torch.distributed.gather(state, gather_list, dst=0, group=compile_group) + if DeviceGroup().node_rank == 0: + inconsistent_nodes = [] + for node_rank, checksum in enumerate(gather_list): + if state.ne(checksum).any(): + inconsistent_nodes.append(node_rank) + if len(inconsistent_nodes) > 0: + raise RuntimeError( + f'graph status is inconsistent on node ranks: {inconsistent_nodes}. ' + f'Please check pytorch version or re-run the compilation.' + ) + # run policy start = time.time() assert callable(PAS), f"Policy PAS is not callable" diff --git a/cube/graph/graph.py b/cube/graph/graph.py index cb6362f5..895354dc 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -11,6 +11,7 @@ import logging import copy import dill +import hashlib from cube.ir.cten import IRTensor, IRCell, IRObject from cube.ir.unique import IDGenerator @@ -1069,3 +1070,33 @@ def reset_node(segment: IRSegment): reset_node(graph) return graph + + def checksum(self, strict: bool = True) -> str: + """Get the MD5 checksum of the graph. + + This is used to guarantee the consistency of the graph between + multiple nodes. + + Note: + The checksum considers the IDGenerator status. If the user modifies + the IDGenerator status (i.e., creating tensors or nodes), it will + have a different checksum. + + Args: + strict (bool): If True (by default), get the checksum of the whole graph status, + including tensor shapes, tensor ids and node ids; + Otherwise (i.e., False), only check the graph structure of node ids, + node signatures without tensor ids. + + Returns: + str: MD5 checksum (32-bit) of the graph status + """ + max_tensor_id, max_cell_id = IDGenerator().get_states() + if not strict: + node_ids = tuple(n.cid for n in self.nodes()) + signatures = tuple(n.signature for n in self.nodes()) + checksum = hashlib.md5(str((max_tensor_id, max_cell_id, signatures, node_ids)).encode()).hexdigest() + else: + states = str((max_tensor_id, max_cell_id, self.extra_repr())) + checksum = hashlib.md5(states.encode()).hexdigest() + return checksum From c8985d227afc92ec3c715e2ead2ca8119503b855 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 2 Nov 2023 06:06:35 +0000 Subject: [PATCH 1518/1892] Merged PR 1890: add warning messages for unused node results add warning messages for unused node results --- cube/graph/graph.py | 51 ++++++++++++++++++++++++++++--------------- cube/graph/segment.py | 31 +++++++++++++++++--------- 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 895354dc..a9b5f453 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -141,16 +141,6 @@ def backward(self, loss: Optional[IRSubTensor] = None): # set loss gradient loss.parent.to_loss() - # update require gradient: for tensors that have no consumers, - # make their gradient to be False - for ftensor in self.full_tensors(): - if ftensor.is_loss(): continue - consumers = [n for n in self.consumers(ftensor) if isinstance(n, IRFwOperation)] - if len(consumers) == 0 and ftensor.requires_grad: - _logger.warning( - f"detected a dead ftensor which is not consumed by any nodes:\n\t{ftensor.name}: {ftensor}") - ftensor.requires_grad = False - # infer gradient for ftensor in self.full_tensors(): self.infer_grad(ftensor) @@ -175,15 +165,16 @@ def backward(self, loss: Optional[IRSubTensor] = None): def from_logic_graph(nodes: List[IRCell], inputs: List[Any], outputs: List[Any], module_name: str): - """ - Generate IRGraph from logical graph (IRFullTensor) + """Generate IRGraph from logical graph (IRFullTensor) - @param nodes: nodes of the graph - @param inputs List[Any]: graph inputs - @param outputs List[Any]: graph outputs - @param module_name str: graph name + Args: + nodes (List[IRCell]): nodes of the graph + inputs (List[Any]): graph inputs + outputs (List[Any]): graph outputs + module_name (str): graph name - @return graph IRGraph + Returns: + IRGraph: the graph with each tensor is IRSubTensor. """ modifier = lambda t: t.tosub() if isinstance(t, IRFullTensor) else t # input / output @@ -200,6 +191,32 @@ def from_logic_graph(nodes: List[IRCell], subtensor = ftensor.tosub() if isinstance(ftensor, IRFullTensor) else ftensor node.set_output(idx, subtensor) graph = IRGraph(nodes, inputs, outputs, module_name) + + # check unused outputs + unused_obj_nodes: Dict[IRObject, List[IRCell]] = {} + graph_output_objects = [ + obj.parent for obj in IRSegment.get_objects_from_complex(graph.outputs())] + for obj in graph.full_objects(): + # loss tensor will always not used + if isinstance(obj, IRFullTensor) and obj.is_loss(): continue + # we don't need to show unused backward ops + if isinstance(obj, IRFullTensor) and obj.is_grad(): continue + consumers = graph.consumers(obj) + if len(consumers) == 0 and obj not in graph_output_objects: + if len(graph.producers(obj)) > 0: + unused_obj_nodes.setdefault(obj, []).extend(graph.producers(obj)) + if len(unused_obj_nodes) > 0: + dscp = (f'Following returns of nodes are not used by any other nodes.\n' + f'Please consider to remove them in the user defined model.\n') + for obj, unused_nodes in unused_obj_nodes.items(): + dscp += f'{obj}:\n' + for node in unused_nodes: + if node.comment is not None: + dscp += f'\t{node.comment}\n\t{node.name} (cid={node.cid})\n' + else: + dscp += f'\t{node.name} (cid={node.cid})\n' + _logger.warning(dscp) + return graph ##### Transformation Primitives ##### diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 6f3dba4f..4d67fc2b 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -81,8 +81,8 @@ def __init__(self, nodes: List[IRCell], inputs: List[IRObject], outputs: List[An self._nodes: List[IRCell] = [] - # full-tensor / sub-tensor mapping - self._ftensors: Set[IRFullTensor] = set() + # full objects + self._fobjects: Set[IRObject] = set() self._producers: Dict[IRFullTensor, List[IRCell]] = dict() self._consumers: Dict[IRFullTensor, List[IRCell]] = dict() self._ptensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() @@ -118,6 +118,17 @@ def isfw(self) -> bool: return all(n.isfw() for n in self._nodes) # return self._have_forward + def full_objects(self) -> Tuple[IRObject]: + """Get all full objects of this graph. + + Note: + The full tensor inside the node (e.g., IRSegment) will not be returned. + + Returns: + fobjects List[IRObject] + """ + return tuple(self._fobjects) + def full_tensors(self) -> Tuple[IRFullTensor]: """ Get all full tensors of this graph. @@ -125,7 +136,7 @@ def full_tensors(self) -> Tuple[IRFullTensor]: @return ftensors List[IRFullTensor] """ - return tuple(t for t in self._ftensors if isinstance(t, IRFullTensor)) + return tuple(t for t in self._fobjects if isinstance(t, IRFullTensor)) def attributes(self) -> Tuple[IRFullTensor]: """ @@ -332,7 +343,7 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: dscp : str = '' - ftensors = [ftensor] if ftensor is not None else self._ftensors + ftensors = [ftensor] if ftensor is not None else self._fobjects for ftensor in ftensors: dscp += f'====\nFull Tensor: {ftensor}\n' dscp += f'Producers:\n' @@ -373,8 +384,8 @@ def _add_ftensor(self, ftensor: IRObject): Add a full tensor in segment if the segment doesn't have the tensor. """ assert isinstance(ftensor, IRObject) - if ftensor not in self._ftensors: - self._ftensors.add(ftensor) + if ftensor not in self._fobjects: + self._fobjects.add(ftensor) self._producers[ftensor] = [] self._consumers[ftensor] = [] self._ptensors[ftensor] = [] @@ -387,8 +398,8 @@ def _remove_ftensor(self, ftensor: IRObject): Remove a full tensor in segment """ assert isinstance(ftensor, IRObject) - if ftensor in self._ftensors: - self._ftensors.remove(ftensor) + if ftensor in self._fobjects: + self._fobjects.remove(ftensor) del self._producers[ftensor] del self._consumers[ftensor] del self._ptensors[ftensor] @@ -404,7 +415,7 @@ def _reorder_producer_consumer(self): Note sub-segment will also be reordered. """ # clear up - self._ftensors, self._attributes = set(), set() + self._fobjects, self._attributes = set(), set() self._producers, self._ptensors = dict(), dict() self._consumers, self._ctensors = dict(), dict() @@ -656,7 +667,7 @@ def multiref(self, ftensor: IRFullTensor, *deprecated_args) -> IRFwOperation: @param tensor IRSubTensor: tensor. @return multiref IRFwOperation: the inserted multiref operator. """ - assert ftensor in self._ftensors, f"tensor: {ftensor} not in this graph." + assert ftensor in self._fobjects, f"tensor: {ftensor} not in this graph." if len(self.consumers(ftensor)) <= 1: return assert not ftensor.is_grad(), f"graph.multiref can only be applied on a non-gradient full tensor." # check no transformation From cde09cdf58ab7cd976009fb4236a3c9da8ca5b96 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Thu, 2 Nov 2023 08:35:43 +0000 Subject: [PATCH 1519/1892] Merged PR 1836: support reducescatter existing implementation for ZeRO is using allreduce to aggregate gradients, then using allgather to update weights. this pr supports reducescatter for aggregating gradients because only partial weights are updated by optimizer. --- cube/codegen/module/module.py | 10 ++ cube/runtime/adapter/reducer.py | 79 ++++++++++++---- cube/runtime/gnorm.py | 162 ++++++++++++++++++++++++++++++++ cube/runtime/module.py | 51 +++++++++- 4 files changed, 282 insertions(+), 20 deletions(-) create mode 100644 cube/runtime/gnorm.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index bd30411d..198b7af2 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -169,12 +169,22 @@ def get_comm_groups(self, scale_ndevs: Optional[int] = None): """ def _add_comm_for_group_zero(ranks): zero_comm_groups = [] + # Create communication group for each zero subgroup for i in range(CompileFlag.zero_ngroups): assert len(ranks) % CompileFlag.zero_ngroups == 0 ranks_per_group = len(ranks) // CompileFlag.zero_ngroups zero_subgroup = tuple(ranks[i * ranks_per_group : (i + 1) * ranks_per_group]) if len(zero_subgroup) > 1 and len(zero_subgroup) < len(ranks): zero_comm_groups.append(zero_subgroup) + # Create communication groups for cross group allreduce. + # Note that this is only for the enabled reduce scatter of ZeRO. + # For example, there are two ZeRO groups [0,1,2,3] and [4,5,6,7], + # then we will create communication groups (0,4), (1,5), (2,6), (3,7). + ranks_per_group = len(ranks) // CompileFlag.zero_ngroups + for i in range(ranks_per_group): + zero_crossgroup = tuple(ranks[i::ranks_per_group]) + if len(zero_crossgroup) > 1 and len(zero_crossgroup) < len(ranks): + zero_comm_groups.append(zero_crossgroup) return zero_comm_groups scale_ndevs = scale_ndevs if scale_ndevs is not None else len(self.devices) assert len(self.devices) == max(self.devices) + 1, f'device must be consecutive' diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 7f50c74c..25396266 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -41,21 +41,24 @@ def __init__(self, params: List[torch.nn.Parameter], param_buffer: torch.Tensor, grad_buffer: torch.Tensor, reduce_op: torch.distributed.ReduceOp, group, async_op: bool, zero: bool, - zero_subgroup: torch.distributed.ProcessGroup = None): + zero_subgroup: torch.distributed.ProcessGroup = None, + zero_crossgroup: torch.distributed.ProcessGroup = None): """ Create a communication unit for parameter allreduce. One allreduce will be called for all gradients associated to the parameters. The parameters are assumed to participate in backward and generate gradient. - @param params List[torch.nn.Parameter]: the parameters - @param param_buffer torch.Tensor: Paramter contiguous buffer - @param grad_buffer torch.Tensor: gradient contiguous buffer - @param reduce_op torch.distributed.ReduceOp: the reduce op used by collectives - @param group: communication group - @param async_op bool: whether to use asynchronous operation - @param zero bool: whether to use zero optimization on gradients - @param zero_subgroup: the subgroup for zero optimization the current rank belongs to + Args: + params List[torch.nn.Parameter]: the parameters + param_buffer torch.Tensor: Paramter contiguous buffer + grad_buffer torch.Tensor: gradient contiguous buffer + reduce_op torch.distributed.ReduceOp: the reduce op used by collectives + group: communication group + async_op bool: whether to use asynchronous operation + zero bool: whether to use zero optimization on gradients + zero_subgroup: the subgroup for zero optimization the current rank belongs to + zero_crossgroup: the communication group for cross zero group allreduce when reduce scatter is enabled """ self._params: List[torch.nn.Parameter] = params @@ -81,6 +84,7 @@ def __init__(self, params: List[torch.nn.Parameter], self._zero_subgroup = self._group if zero_subgroup is None else zero_subgroup self._zgroup_sz: int = torch.distributed.get_world_size(group=self._zero_subgroup) + self._zero_crossgroup = zero_crossgroup # pre and post hooks for gradient synchronization self._pre_hooks: List[Callable] = [] @@ -105,6 +109,26 @@ def zero(self) -> bool: """Whether enable zero for this bucket""" return self._zero + def _group_reduce_scatter(self): + """currently this function is only used in synchronous mode""" + rank = torch.distributed.get_rank(group=self._zero_subgroup) + partial_tensor = self._contiguous_grads.chunk(self._zgroup_sz, dim=0)[rank] + if self._zgroup_sz == self._wsz: + # number of zero groups is 1, thus only reduce scatter is enough + # in this case, self._group == self._zero_subgroup + torch.distributed.reduce_scatter_tensor( + partial_tensor, self._contiguous_grads, + op=self._reduce_op, group=self._zero_subgroup) + else: + # two steps for group reduce scatter + # step #1, allreduce across corresponding GPUs across groups + torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, group=self._zero_crossgroup) + # step #2, reduce scatter within each group + torch.distributed.reduce_scatter_tensor( + partial_tensor, self._contiguous_grads, + op=self._reduce_op, group=self._zero_subgroup) + def build(self): """ Build offset for each parameter @@ -160,11 +184,17 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): self._apply_pre_hooks() # communication if self._zero and Bucket.use_reduce_scatter_for_zero: - rank = torch.distributed.get_rank(group=self._group) - shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) - self._async_handle = torch.distributed.reduce_scatter( - shards[rank], shards, op=self._reduce_op, - group=self._group, async_op=True) + if self._zgroup_sz == self._wsz: + rank = torch.distributed.get_rank(group=self._group) + shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) + self._async_handle = torch.distributed.reduce_scatter( + shards[rank], shards, op=self._reduce_op, + group=self._group, async_op=True) + else: + assert False, "reducescatter is not supported in async mode, " \ + "because the two steps (allreduce, reducescatter) use " \ + "two communication groups, which may induce deadlock." + self._group_reduce_scatter() else: self._async_handle = torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, @@ -201,9 +231,7 @@ def sync_grads(self): self._apply_pre_hooks() # synchrnoize gradients if self._zero and Bucket.use_reduce_scatter_for_zero: - shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) - torch.distributed.reduce_scatter( - shards[rank], shards, op=self._reduce_op, group=self._group) + self._group_reduce_scatter() else: torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, group=self._group) @@ -336,7 +364,7 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, # If the ranks are [0, 2, 4, 6], zero_ngroups=2, then the ranks # will be divided into [0, 2] and [4, 6]. if self._zero and Bucket.use_reduce_scatter_for_zero: - assert zero_ngroups == 1, f"zero_ngroups {zero_ngroups}, which is >1, does not support reduce scatter" + _logger.info(f"Using reduce scatter for ZeRO optimization") if zero_ngroups > 1: assert self._zero, f"USE_ZERO must be set when ZERO_NUM_GROUPS is larger than 1" assert len(ranks) % zero_ngroups == 0, f"length of ranks {ranks} must be divisible by zero factor {zero_ngroups}" @@ -347,9 +375,21 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, if len(sub_ranks) > 1: assert DeviceGroup().group_exists(sub_ranks), f"zero subgroup {sub_ranks} does not exist in comm groups" self._zero_subgroup = DeviceGroup().get_group(sub_ranks) + # crossgroup is for the allreduce across zero subgroups, it is only used when + # reduce scatter is enabled and the number of zero subgroups is larger than 1. + start_rank = curr_rank % zgroup_sz + cross_ranks = ranks[start_rank::zgroup_sz] + assert len(cross_ranks) == zero_ngroups + self._zero_crossgroup = DeviceGroup().get_group(cross_ranks) else: - assert zero_ngroups == 1, f"zero factor must be 1, but got {zero_ngroups}" + assert zero_ngroups == 1, f"ZeRO number of groups must be 1, but got {zero_ngroups}" self._zero_subgroup = self._group + self._zero_crossgroup = None + self._zero_ngroups = zero_ngroups + + @property + def zero_ngroups(self) -> int: + return self._zero_ngroups @property def params(self) -> Tuple[torch.nn.Parameter]: @@ -475,6 +515,7 @@ def build_buckets(self): self._async, self._zero, self._zero_subgroup, + self._zero_crossgroup, ) buckets.append(bucket) torch.cuda.empty_cache() diff --git a/cube/runtime/gnorm.py b/cube/runtime/gnorm.py new file mode 100644 index 00000000..73e1e6ab --- /dev/null +++ b/cube/runtime/gnorm.py @@ -0,0 +1,162 @@ +from typing import List, Dict, Tuple, TYPE_CHECKING +from dataclasses import dataclass +from collections import defaultdict + +import torch +import torch.distributed as dist + +if TYPE_CHECKING: + from cube.runtime.module import CubeModule + +@dataclass +class ParamsInfo: + # An instance of ParamsInfo corresponds to a group of parameters in cube reducer, + # or a single parameter without cube reducer. + ranks: Tuple[int] + params: List[torch.nn.Parameter] + param_names: List[str] + zero_ngroups: int + +@dataclass +class TidReplicaInfo: + # the number of the replicas of the (partitioned) parameter with tid + nreplicated: int + # the number of all the involved ranks for this parameter with tid + nranks: int + +def _calc_grad_shape(slicers_list): + # caculate the shape of each full parameters/grads + tid2shape = {} + for rank_slicers in slicers_list: + for tid, slicers in rank_slicers.items(): + if tid not in tid2shape: + tid2shape[tid] = [0 for _ in slicers] + for i, slicer in enumerate(slicers): + # slicer: (start, end, step) + if slicer.stop > tid2shape[tid][i]: + tid2shape[tid][i] = slicer.stop + # caculate the number of replicas of each model parameter + tid2nreplicas = {} + for rank_slicers in slicers_list: + for tid, slicers in rank_slicers.items(): + if tid not in tid2nreplicas: + tid2nreplicas[tid] = 0 + factor = 1 + for i, slicer in enumerate(slicers): + factor *= (slicer.stop - slicer.start) / tid2shape[tid][i] + tid2nreplicas[tid] += factor + return tid2nreplicas + +def prepare_for_grad_clip_legacy(cube_model: 'CubeModule', curr_rank: int) -> Dict[int, List[torch.nn.Parameter]]: + assert curr_rank == dist.get_rank() + tid2param, tid2slicers = {}, {} + for name, param in cube_model.named_parameters(): + assert name in cube_model.fullmap + if param.requires_grad: + tid = cube_model.tid_of_param_name(name) + slicers = cube_model.fullmap[name][1] + tid2param[tid] = param + tid2slicers[tid] = slicers + # gather all parameters' slicers + tid2ranks_list = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(tid2ranks_list, tid2slicers) + tid2nreplicas = _calc_grad_shape(tid2ranks_list) + nreplicas2localparams = defaultdict(list) + for tid, param in tid2param.items(): + nreplicas = tid2nreplicas[tid] + nreplicas2localparams[nreplicas].append(param) + return nreplicas2localparams + +def _check_is_ordered(ranks: Tuple[int]) -> bool: + for i in range(len(ranks)-1): + if ranks[i] >= ranks[i+1]: + return False + return True + +def _check_no_intersection(ranks_set): + # ranks_set: set of tuple + # check intersection between any two tuples + ranks = set() + for r in ranks_set: + old_len = len(ranks) + ranks.update(r) + if len(ranks) - old_len != len(r): + return False + return True + +def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int, TidReplicaInfo]: + """This function is used to calculate the number of replicas of each model parameter. + Each parameter has a tuple of `len(ranksset)` (we call it nreplicated) and `nranks`, + because a parameter may be replicated (not data parallelism) which is supported by cube. + It affects the calculation of gnorm. So nreplicated represents the number of + non-data-parallelism replicas for this parameter, and nranks represents the number of + all the involved ranks for this parameter. + + Args: + tid2ranks_list: list of dict, each dict is tid2ranks + + Returns: + tid2nreplicas: dict, tid -> TidReplicaInfo + """ + # caculate the number of replicas of each model parameter + tid2nreplicas = {} + tid2ranksset = defaultdict(set) + for tid2ranks in tid2ranks_list: + for tid, ranks in tid2ranks.items(): + assert _check_is_ordered(ranks) + assert isinstance(ranks, tuple), f'ranks {ranks} should be tuple' + tid2ranksset[tid].add(ranks) + # the ranks have been deduplicated using set. + # so the number of ranks represents the number of replicas (pure replicate not data parallelism), + # where each ranks is the unit of ZeRO (or reducer). + for tid, ranksset in tid2ranksset.items(): + assert _check_no_intersection(ranksset) + nranks = sum([len(ranks) for ranks in ranksset]) + tid2nreplicas[tid] = TidReplicaInfo(len(ranksset), nranks) + return tid2nreplicas + +def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, List[torch.nn.Parameter]]: + params_info_for_gnorm = cube_model.parameters_for_calc_gnorm() + tid2ranks = {} + tid2info_list_seq = {} + for seq, params_info in enumerate(params_info_for_gnorm): + # params_info is ParamsInfo, which is defined in this file + assert isinstance(params_info.ranks, tuple), f'ranks {params_info.ranks} should be tuple' + for name, param in zip(params_info.param_names, params_info.params): + assert param.requires_grad + tid = cube_model.tid_of_param_name(name) + tid2ranks[tid] = params_info.ranks + tid2info_list_seq[tid] = seq + tid2ranks_list = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(tid2ranks_list, tid2ranks) + tid2nreplicas = _calc_grad_replicas(tid2ranks_list) + # populate nreplicas2localparams + nreplicas2localparams = defaultdict(list) + processed_seqs = {} + for tid, replicated_info in tid2nreplicas.items(): + if tid not in tid2info_list_seq: + # because tid2nreplicas is from all the ranks, + # if this parameter (tid) does not belong to this rank, + # it is safe to skip it. + continue + seq = tid2info_list_seq[tid] + params_info = params_info_for_gnorm[seq] + if seq in processed_seqs: + assert processed_seqs[seq] == replicated_info, \ + 'the params belonging to the same seq should have the same nreplicated and nranks' + continue + # If ZeRO is not used, the number of replicas of a parameter (partition) is its involved ranks, + # no matter it is pure replicated or data-parallelism replicated. For calculating gnorm, these + # two kinds of replicas are the same, because in data-parallelism, gradients are also allreduced + # before gnorm calculation. + # If ZeRO is used, the number of replicas of a parameter (partition) is the number of pure replicated + # multiplied by the number of ZeRO groups. Multiplying the number of pure replicated is easy + # to understand. Multiplying the number of ZeRO groups is because the gradients of each ZeRO group + # are full model gradients, so the number of ZeRO groups is the number of gradient replicas of the full model. + if not is_zero: + nreplicas = replicated_info.nranks + else: + nreplicas = replicated_info.nreplicated * params_info.zero_ngroups + nreplicas2localparams[nreplicas].extend(params_info.params) + processed_seqs[seq] = replicated_info + return nreplicas2localparams diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 6853865b..57e42969 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,14 +1,16 @@ -from typing import List, Dict, Tuple, Optional +from typing import List, Dict, Tuple import logging import os import sys from pathlib import Path import torch +import torch.distributed as dist from cube.graph.parser.fx.parser import FxModuleParser from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer +from cube.runtime.gnorm import ParamsInfo _logger = logging.getLogger(__name__) @@ -22,11 +24,27 @@ class CubeModule(torch.nn.Module): def __init__(self): super().__init__() self._reducers: List[Reducer] = list() + # Key: str, parameter name (from named_parameters) + # Value: Tuple[int, Tuple[slice], int]: + # full tensor tid, + # position of sub tensor in full tensor, + # position of value in value partition. self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() @property def reducers(self): return self._reducers + + @property + def fullmap(self): + return self._fullmap + + def tid_of_param_name(self, name: str) -> int: + # Return the tid of sub tensor with the parameter name + # It is the last field of the parameter name, which is hacky + if name not in self._fullmap: + raise RuntimeError(f"Cannot find {name} in fullmap") + return int(name.split('_')[-1]) def add_reducer(self, reducer: Reducer): if not isinstance(reducer, Reducer): @@ -59,6 +77,36 @@ def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: # print(f'> get out parameters: {sum(p.numel() for p in params)}') return params + def parameters_for_calc_gnorm(self) -> List[ParamsInfo]: + """Return the necessary information for calculating the gradient norm. + + Returns: + List[Tuple[Tuple[int], List[torch.nn.Parameter], List[str], int]]: + A list of tuples, each tuple contains the following information: + Tuple[int]: the ranks spanned by the parameters in the tuple + List[torch.nn.Parameter]: the contiguous parameters in the tuple + List[str]: the names of the original parameters in the tuple + int: the number of the ZeRO groups for the parameters + """ + paramid2name = {} + for name, param in self.named_parameters(): + paramid2name[id(param)] = name + + params_info_for_gnorm = [] + reducer_pids = set() + for reducer in self._reducers: + param_names = [paramid2name[id(p)] for p in reducer.params] + params_info = ParamsInfo(reducer.ranks, reducer.parameters_for_optimizer(), + param_names, reducer.zero_ngroups) + params_info_for_gnorm.append(params_info) + reducer_pids.update(id(p) for p in reducer.params) + for param in self.parameters(): + if id(param) not in reducer_pids: + # zero_ngroups is 1, since there is no reducer for it and multiplying 1 does not change the result. + params_info = ParamsInfo((dist.get_rank(),), [param], [paramid2name[id(param)]], 1) + params_info_for_gnorm.append(params_info) + return params_info_for_gnorm + def gather_params(self): """ Gather parameters @@ -83,6 +131,7 @@ def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: i assert hasattr(self, attr), f"{attr} is not in the module" self._fullmap[attr] = (tid, slicers, val_chunks) + # TODO: remove this function, use the property instead def get_full_map(self): return self._fullmap From e574fa32420f023e6fa2a333f2787aab04de7ac0 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Thu, 2 Nov 2023 11:29:19 +0000 Subject: [PATCH 1520/1892] Merged PR 1854: merge v0.3 to main --- cube/graph/function/dimops.py | 2 ++ cube/program.py | 16 +++++++++++++--- cube/runtime/module.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 13242a15..3fe2372a 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -755,6 +755,8 @@ def align(self, inputs: List[IRTensor], op_anno: OpAnno, **kwargs) -> bool: # check dimension consistency for ashape, itensor in zip(op_anno.inputs(), inputs): + if itensor is None: + continue if not (isinstance(itensor, IRTensor) ^ ashape.ignore): return False if not isinstance(itensor, IRTensor): diff --git a/cube/program.py b/cube/program.py index ff8ebc4c..6c4f48f9 100644 --- a/cube/program.py +++ b/cube/program.py @@ -172,9 +172,19 @@ def complex(val: Any): def get_graph(self): return self._ir_graph - def load_module(self, filename: str): - """Load module from file.""" - self._loaded_module = load_model(filename, self.save_content) + def load_module(self, filename: str, load_fullmodelpt: Optional[bool] = None): + """Load module from file. + + Args: + filename (str): file path + load_fullmodelpt (Optional[bool]): controls whether to load full model checkpoint. + If None, use the default value of the semantic model. + """ + if load_fullmodelpt is not None: + load_content = load_fullmodelpt + else: + load_content = self.save_content + self._loaded_module = load_model(filename, load_content) def get_gen_module(self) -> Optional[torch.nn.Module]: return self._loaded_module diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 57e42969..8b925ce6 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -77,6 +77,21 @@ def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: # print(f'> get out parameters: {sum(p.numel() for p in params)}') return params + def parameters_for_broadcast(self) -> List[torch.nn.Parameter]: + """ + This function is for broadcasting loaded weights from one scale unit to + all other scale units to resume from sharded checkpoints globally. + """ + params = [] + reducer_pids = set() + for reducer in self._reducers: + params.append(reducer._contiguous_params) + reducer_pids.update(id(p) for p in reducer.params) + for param in self.parameters(): + if id(param) not in reducer_pids: + params.append(param) + return params + def parameters_for_calc_gnorm(self) -> List[ParamsInfo]: """Return the necessary information for calculating the gradient norm. From ef136b1326008d5e69e3ee062cce2d06f633dfc3 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 8 Nov 2023 01:15:29 +0000 Subject: [PATCH 1521/1892] Merged PR 1899: bug fix: squeeze and unsqueeze --- cube/graph/function/function.py | 32 ++++++++++++++++++++------ tests/graph/function/test_functions.py | 32 +++++++++++++++++++++++++- tests/graph/parser/test_converter.py | 18 --------------- 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 8a111c5e..4e171046 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1166,25 +1166,43 @@ def Squeeze(input, dim=None, signature = None): """ out = torch.squeeze(tensor) """ - dim = (dim,) if isinstance(dim, int) else dim + if isinstance(dim, int): + dim = (dim,) + if dim is not None: + dim = tuple(d if d >= 0 else d + len(input.shape) for d in dim) edim_in = ShapeAnno.create_shape_str(input.shape) assert len(edim_in) == len(input.shape) edim_ou = [] - for idx, (dim_anno, dim_size) in enumerate(zip(edim_in, input.shape)): - if dim_size > 1 or (dim is not None and idx not in dim): - edim_ou.append(copy.copy(dim_anno)) + for idx in range(len(input.shape)): + if dim is None or idx in dim: + if input.shape[idx] != 1: + # If this dimension is not 1, then we should never partation it + # Otherwise, it could be squeezed mistakenly if the dimension after partition is 1 + # For example, a tensor with shape(2,4) + # 1. for single gpu, + # after calling squeeze(t, 0) the shape is still (2, 4) + # 2. for 2 gpus, if we partition dim 0, then the tensor shape in each gpu will be (1,4) + # after calling squeeze(t, 0), the shape becomes (4,) in each gpu + # which is not correct + edim_in[idx] += '^' + edim_ou.append(edim_in[idx]) + # else remove this dimension in edim_out + else: + edim_ou.append(edim_in[idx]) anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(Squeeze, 'squeeze', signature, [anno], [input]) + return IRDimops(Squeeze, 'squeeze', signature, [anno], [input], dim=dim) def Unsqueeze(input, dim, signature = None): """ out = torch.unsqueeze(tensor, dim) + A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. + Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1. """ edim_in = ShapeAnno.create_shape_str(input.shape) edim_ou = copy.copy(edim_in) - if dim == -1: - dim = len(edim_ou) + if dim < 0: + dim += len(edim_ou) + 1 edim_ou.insert(dim, '1') anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(Unsqueeze, 'unsqueeze', signature, [anno], [input],dim=dim) diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 8a3f2b26..ec5f229a 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -169,6 +169,36 @@ def test_Max(): def test_Squeeze(): op = F.Squeeze(IRTensor([2, 1, 4, 1])) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a c' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c^ d -> a^ c^' op = F.Squeeze(IRTensor([2, 1, 4, 1]), 1) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a c d' + op = F.Squeeze(IRTensor([2, 1, 4, 1]), 0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c d -> a^ b c d' + op = F.Squeeze(IRTensor([2, 1, 4, 1]), -1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a b c' + op = F.Squeeze(IRTensor([2, 1, 4, 1]), -2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ d -> a b c^ d' + + op = F.Squeeze(IRTensor([2, 1, 4, 1]), (-1, -2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ d -> a b c^' + op = F.Squeeze(IRTensor([2, 1, 4, 1]), (1, -1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a c' + op = F.Squeeze(IRTensor([2, 1, 4, 1]), (1, -2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ d -> a c^ d' + op = F.Squeeze(IRTensor([2, 1, 4, 1]), (0, -2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c^ d -> a^ b c^ d' + + +def test_Unsqueeze(): + op = F.Unsqueeze(IRTensor([2, 4]), 0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> 1 a b' + op = F.Unsqueeze(IRTensor([2, 4]), 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a 1 b' + op = F.Unsqueeze(IRTensor([2, 4]), 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b 1' + op = F.Unsqueeze(IRTensor([2, 4]), -3) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> 1 a b' + op = F.Unsqueeze(IRTensor([2, 4]), -2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a 1 b' + op = F.Unsqueeze(IRTensor([2, 4]), -1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b 1' diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 9eef910e..7e02abad 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -58,24 +58,6 @@ def forward(self, x, **kwargs): assert any(node.signature == 'torch.nn.functional.linear' for node in nodes) -def test_to_ir_graph_args(): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 5) - - def forward(self, x, *args): - return self.linear(x) - dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} - module = MyModule() - fx_graph = to_fx_graph(module, dummy_input) - - with tempfile.TemporaryDirectory() as tempdir: - # currently we don't support *args - with pytest.raises(RuntimeError): - to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) - - def test_record_codeline(): class MyModule(torch.nn.Module): def __init__(self): From 3831179a1a6c914a87e3ecb719102df54c6b6603 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 9 Nov 2023 05:27:07 +0000 Subject: [PATCH 1522/1892] Merged PR 1895: fix cannot getattr from root parity check & ut passed --- .../concrete_trace_utils/concrete_tracer.py | 5 ++++- requirements.txt | 2 +- tests/graph/tracer/__init__.py | 0 tests/graph/tracer/test_getattr.py | 17 +++++++++++++++ tests/parallel_module/test_gencode.py | 21 ++++++++----------- 5 files changed, 31 insertions(+), 14 deletions(-) create mode 100644 tests/graph/tracer/__init__.py create mode 100644 tests/graph/tracer/test_getattr.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 6e3cd7e0..8dbcc8d2 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -822,7 +822,10 @@ def module_getattribute_wrapper(mod, attr): else: return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) elif attr in self.default_module_getattr: - return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) + if self.the_path_of_middle_class[id(mod)] == '': + return self.create_proxy('get_attr', f'{attr}', (), {}) + else: + return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) elif id(attr_val) in self.the_path_of_parameter: return self.create_proxy('get_attr', self.the_path_of_parameter[id(attr_val)], (), {}) elif id(attr_val) in self.the_path_of_buffer: diff --git a/requirements.txt b/requirements.txt index 96ee7e4d..84d96390 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ numpy>=1.23.0 matplotlib more-itertools dill -torch>=1.13 +torch>=1.13,<2.1 diff --git a/tests/graph/tracer/__init__.py b/tests/graph/tracer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/graph/tracer/test_getattr.py b/tests/graph/tracer/test_getattr.py new file mode 100644 index 00000000..09e43e0a --- /dev/null +++ b/tests/graph/tracer/test_getattr.py @@ -0,0 +1,17 @@ +import torch + +from cube.graph.parser.converter import to_fx_graph + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.dropout(x, 0.1, self.training) + +def test_getattr_from_root(): + model = SimpleModel() + dummy_input = {'x': torch.rand(10)} + traced_graph = to_fx_graph(model, dummy_input) + traced_graph(**dummy_input) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 22004364..41efe869 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -315,15 +315,12 @@ def test_codegen_training_flag(): with tempfile.TemporaryDirectory() as tempdir: m = TrainingModule() m.train() - - # self.training isn't supported in concrete_trace - with pytest.raises(RuntimeError, match='Node referenced nonexistent target.*'): - parallelize( - m, - {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, - ComputeConfig(1, 1), - dynamic_shape=True, - cube_savedir=tempdir, - load_module=False - ) + parallelize( + m, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False + ) From 6be9d9c20999d0b4920c2dc71c54fadbeb086b24 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 9 Nov 2023 05:43:40 +0000 Subject: [PATCH 1523/1892] Merged PR 1891: Remove build model when run only --- cube/algorithm/ops/dimops.py | 8 ++++---- cube/graph/graph.py | 30 ++++++++++++++++++++++++------ cube/profiler/database.py | 4 ++-- cube/program.py | 2 +- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 1f0c9d8f..764e2d53 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -126,10 +126,10 @@ def instantiate(self, idx: int, dim: Union[int, str], num: int) -> Optional[List else: adim, reduce = 'Value', None - color, default = '\033[32m' if satisfy else '\033[31m', '\033[0m' - _logger.info(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Success' if satisfy else 'Failed!'}{default}") - - if not satisfy: return None + if not satisfy: + color, default = '\033[31m', '\033[0m' + _logger.info(f"split {node.name}: {node.anno} | dim: {adim} num: {num} reduce: {reduce} ... {color}{'Failed!'}{default}") + return None rule: TransformRule = self.infer(idx, dim, num) # transform diff --git a/cube/graph/graph.py b/cube/graph/graph.py index a9b5f453..5ac59dfd 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -1053,17 +1053,18 @@ def __exit__(self, exc_type, exc_value, traceback): dill.dump(save, f) @staticmethod - def load(filename: str): + def from_dill(id_state, graph): """ - Load the graph from pickled file. + build instance from id_state and graph Note IDGenerator will also be reset to match with graph status - @param filename str + Args: + id_state : read from dill + graph (IRGraph): read from dill - @return graph IRGraph + Returns: + IRGraph: the build graph """ - with open(filename, 'rb') as f: - id_state, graph = dill.load(f) # recover IRGenerator IDGenerator().load_states(id_state) @@ -1088,6 +1089,23 @@ def reset_node(segment: IRSegment): reset_node(graph) return graph + @staticmethod + def load(filename: str): + """ + Load the graph from pickled file. + Note IDGenerator will also be reset to match with graph status + + Args: + filename (str): the file to load + + Returns: + IRGraph: the built graph + """ + with open(filename, 'rb') as f: + id_state, graph = dill.load(f) + + return IRGraph.from_dill(id_state, graph) + def checksum(self, strict: bool = True) -> str: """Get the MD5 checksum of the graph. diff --git a/cube/profiler/database.py b/cube/profiler/database.py index c3b7fabb..138400f9 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -423,13 +423,13 @@ def dump_ops(self, file: str, override=False): for signature in self._data.keys(): file_n = os.path.join(file, signature +'.json') with open(file_n, 'w') as f: - json.dump(self._data[signature],f) + json.dump(self._data[signature], f, indent=2) def dump_op(self, file: str, signature, override=False): assert signature in self._data.keys(), f'this node not be profiled' file_n = os.path.join(file, signature +'.json') with open(file_n, 'w') as f: - json.dump(self._data[signature],f) + json.dump(self._data[signature], f, indent=2) def load(self, file: str): """! diff --git a/cube/program.py b/cube/program.py index 6c4f48f9..f1d7c2f6 100644 --- a/cube/program.py +++ b/cube/program.py @@ -137,7 +137,7 @@ def __init__(self, model: Optional[torch.nn.Module], dynamic_shape (bool): whether to use dynamic shape. Default False. """ - if DeviceGroup().local_rank == 0: + if DeviceGroup().local_rank == 0 and model is not None: assert isinstance(model, torch.nn.Module), f"device of local_rank == 0 must provide model" self.model = model self._dummy_input: Dict[str, Any] = None From c0b1e6150312273222b9c4d525503d008f366043 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 10 Nov 2023 08:51:53 +0000 Subject: [PATCH 1524/1892] Merged PR 1902: reverse auto scale order of multiref and pyop reverse auto scale order of multiref and pyop This was originally a bug when we firstly scale multiref and then pyop: if the multiref depends on pyop, the pyop is not scaled so the multiref will be kept as same. Since pyop now will always be replicated to all the devices, we can put pyop first because it depends nothing. Then the multiref can be correctly scaled after the scale of pyop. --- cube/graph/gener/gen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 2c817c8e..913c5185 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -110,10 +110,10 @@ def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: graph._reorder_producer_consumer() # remove anchor node graph = IRAdapterGener.remove_anchor(graph) - # automatic transform multiref - graph = IRAdapterGener.autoref(graph) # automatic replace pyfunc graph = IRAdapterGener.auto_pyfunc(graph) + # automatic transform multiref + graph = IRAdapterGener.autoref(graph) # generate adapters for activation graph = IRAdapterGener.gen_activation(graph, cost_fn=cost_fn) # generate weight reducer From 90c0c9eeaa41113fe91341203503d3ad8fa0ceac Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 13 Nov 2023 06:34:32 +0000 Subject: [PATCH 1525/1892] Merged PR 1907: support slice by long tensor support slice by long tensor --- cube/graph/function/function.py | 19 ++++++++++++++++++- tests/graph/function/test_functions.py | 11 +++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 4e171046..d2f0bfe1 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1879,11 +1879,28 @@ def To(tensor: IRTensor, dtype_or_device=None, *, device=None, dtype=None, out=N def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: """_operator.getitem(a, b): return a[b]""" - assert not isinstance(b, IRObject) obj, index = a, b # tensor slice if isinstance(obj, IRTensor): # note `None` will always + if isinstance(index, IRTensor): + # TODO: support general tensor slicing: https://pytorch.org/cppdocs/notes/tensor_indexing.html + # move to FullSlice when ready + """ + Examples: + >>> a = torch.randn((4,2)) + >>> b = torch.randn((3,5)).to(torch.int64) + >>> a[b] # shape [3,5,2] + """ + if index.dtype not in (torch.int64, torch.int32): + raise RuntimeError(f"index should be int64 or int32, but got {index.dtype}") + gener = iter(string.ascii_lowercase) + obj_shape = ShapeAnno.create_shape_str(obj.shape, iterator=gener) + obj_shape[0] = obj_shape[0] + '^' + index_shape = ShapeAnno.create_shape_str(index.shape, iterator=gener) + out_shape = index_shape + obj_shape[1:] + anno = OpAnno.create_op_str([obj_shape, index_shape], [out_shape]) + return IRDimops(GetItem, 'getitem', signature, [anno], [obj, index]) index = (index,) if isinstance(index, (int, slice)) else tuple(index) return FullSlice(obj, index) # object slice diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index ec5f229a..cdfee415 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -4,6 +4,7 @@ from cube.ir.cten import IRObject, IRTensor import pytest +import torch def test_handle_broadcast_multi(): @@ -157,6 +158,16 @@ def test_FullSlice(): op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, slice(1, 10, 1))) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 3' +def test_GetItem(): + op = F.GetItem(IRTensor([4, 2]), IRTensor([3, 5], dtype=torch.int64)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b, c d -> c d b' + op = F.GetItem(IRTensor([4, 2]), IRTensor([3], dtype=torch.int64)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b, c -> c b' + op = F.GetItem(IRTensor([3, 4, 2]), IRTensor([3], dtype=torch.int64)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, d -> d b c' + op = F.GetItem(IRTensor([3, 4, 2]), IRTensor([3, 5], dtype=torch.int64)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, d e -> d e b c' + def test_Max(): op = F.Max(IRTensor([2, 3, 4])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' From 3824b23ad9e06c7154df0c90117ec878acb15e18 Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Thu, 16 Nov 2023 03:16:47 +0000 Subject: [PATCH 1526/1892] Merged PR 1908: support buffer tensors for autodist support buffer tensors for autodist --- cube/algorithm/ops/dimops.py | 25 +++- cube/profiler/database.py | 254 +++++++++++++++++------------------ 2 files changed, 144 insertions(+), 135 deletions(-) diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 764e2d53..3c8fc06b 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -172,10 +172,13 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR Given the partition choice on `dim` dimension of idx-th input, return the partitioning of the output tensor. - @param idx int: the input index - @param dim int: the dimension to partition + Args: + idx int: the input index + dim int: the dimension to partition + num int: the number of partitions - @return rule TransformRule: the transformation rule + Returns: + rule TransformRule: the transformation rule """ node: IRDimops = self.node assert isinstance(dim, int) or dim == 'v', f"expect dim to be int or 'v'" @@ -239,6 +242,17 @@ def modify(kwargs: Dict, idx: int, dim: int, num: int): def collect_split_info(node: IRFwOperation): + """ + Collect the split information of the node. + Args: + node (IRFwOperation): the node to be analyzed + Returns: + split_info (Dict[str, Tuple[int, int, int]]): the split information. + The key is the identifier name, and the value is a tuple of (idx_shape, idx_dim, idx_id). + idx_shape: the index of the input (shape) + idx_dim: the index of the dimension in the input's shape + idx_id: the index of the identifier in the dimension + """ anno = node.anno split_info = {} @@ -256,6 +270,11 @@ def collect_split_info(node: IRFwOperation): return split_info def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: + """ + Returns: + List[IRFwOperation]: the partitioned nodes. Each element of the list represents the (identical) sub-operator + of one partition option. + """ def gen_hash(node: IRFwOperation) -> str: ret = node.signature diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 138400f9..429d5f3b 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -8,9 +8,10 @@ import os import json import logging -import _operator +from dataclasses import dataclass, asdict -import cube +import _operator # required by eval() +import cube # required by eval() from cube.ir.cten import IRTensor, IRObject from cube.ir.operator import IRFwOperation from cube.graph.parser.register import CustomizedOps @@ -19,35 +20,64 @@ Shapes = NewType('Shapes', Tuple[Tuple[int]]) DTypes = NewType('DTypes', Tuple[torch.dtype]) +RequiresGrad = NewType('RequiresGrad', Tuple[bool]) ShapesDTypes = NewType('ShapesDTypes', Tuple[Shapes, DTypes]) NameOrFunc = Union[str, Callable] - _train_module_ref: torch.nn.Module = torch.nn.Module().train() _eval_module_ref: torch.nn.Module = torch.nn.Module().eval() +@dataclass +class ProfiledMetrics: + """! + The profiling data of a function + """ + # the bytes of each input tensors (i.e., activation tensors) + # excluding parameter and buffer tensors for `node`, no matter the activation + # tensor requires gradient or not + in_mem_info: Tuple[int] + # the bytes of every parameter and buffer tensor of `node` + param_mem_info: Tuple[int] + buffer_mem_info: Tuple[int] + # the forward span time in milliseconds + fw_span: float + # the backward span time in milliseconds + bw_span: float + # the peak memory in bytes during inference of `node` + infer_memory: int + # the bytes of each activation tensor that is saved for backward + train_mem_info: Tuple[int] + # the index of the tensor saved for backward in `node.inputs()` list + train_mem2in_idx: Tuple[int] + + class CompProfiler: @staticmethod - def profile(func: Callable, shapes: Shapes, dtypes: DTypes, + def profile(node: IRFwOperation, func: Callable, shapes: Shapes, dtypes: DTypes, requires_grads: Tuple[bool], values: Tuple[Any], warmup_sec: float = 2, prof_times: int = 50, **kwargs) -> Tuple[float, float, int, Tuple[int]]: """ Profile a function - @param func Callable: the callable function, e.g., torch.nn.functional.linear - @param shapes Tuple[Tuple[int]]: the shapes of each input tensor - @param dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 - @param warmup_sec float: warmup seconds - @param prof_times int: profile times - @param kwargs Dict: other keyword argument for func call. - - @return fw_span float: the time in milliseconds for forward time - @return bw_span float: the time in milliseconds for backward time - @return infer_mem int: the peak memory in bytes after inference of the function - @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward + Args: + node IRFwOperation: the node in IRGraph + func Callable: the callable function, e.g., torch.nn.functional.linear + shapes Tuple[Tuple[int]]: the shapes of each input tensor + dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 + requires_grads Tuple[bool]: whether the input tensor requires gradient + values Tuple[Any]: the values of the inputs that are not IRTensor + warmup_sec float: warmup seconds + prof_times int: profile times + kwargs Dict: other keyword argument for func call. + + Returns: + fw_span float: the time in milliseconds for forward time + bw_span float: the time in milliseconds for backward time + infer_mem int: the peak memory in bytes after inference of the function + train_mem_info Tuple[int]: byte sizes of activation tensors saved for backward """ assert len(shapes) == len(dtypes), \ f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" @@ -76,13 +106,12 @@ def gen_torch_tensors(shape, dtype, requires_grad): train_val = eval_val = value train_kwargs[name] = train_val eval_kwargs[name] = eval_val + # run one sample outputs = func(*tensors, **train_kwargs) - ''' - only profile IRDimops currently, which has at least one tensor output and - may have non-tensor outputs (like list, tuple, dict, etc.). In additional, - we assume that non-tensor outputs will not be used in backward. - ''' + # only profile IRDimops currently, which has at least one tensor output and + # may have non-tensor outputs (like list, tuple, dict, etc.). In addition, + # we assume that non-tensor outputs will not be used in backward. outputs = (outputs,) if torch.is_tensor(outputs) else outputs outputs = tuple(filter(lambda x: torch.is_tensor(x) and x.requires_grad, outputs)) assert all(torch.is_tensor(otensor) for otensor in outputs), \ @@ -118,17 +147,21 @@ def pack_hook(x): byte_size = x.element_size() for dim in list(x.size()): byte_size = byte_size * dim - train_mem_info.append(byte_size) idx = -1 + is_attr = False for i, t in enumerate(tensors): if not isinstance(t, torch.Tensor): continue if t.storage().data_ptr() == x.storage().data_ptr(): + if node.inputs()[i].is_attr(): + is_attr = True idx = i break - train_mem2in_idx.append(idx) + if not is_attr: + train_mem_info.append(byte_size) + train_mem2in_idx.append(idx) return x - + def unpack_hook(x): return x @@ -214,20 +247,17 @@ def extract_val(val: Union[IRObject, Any]) -> Any: values.append(extract_val(t)) return fn, shapes, dtypes, requires_grads, values, extract_val(node.kwargs) - def profile(self, node: IRFwOperation, device: Optional[int] = None, override: bool = False): + def profile(self, node: IRFwOperation, device: Optional[int] = None, override: bool = False) -> ProfiledMetrics: """ Profile a forward node in IRGraph on a specific device (default current device) - - @param node IRFwOperation: node of IRGraph - @param device int: the device that the node will execute on - - @return in_mem_info Tuple[int]: byte sizes of input tensors - @return param_mem_info Tuple[int]: byte sizes of param tensors - @return fw_span float: the forward span time in milliseconds - @return bw_span float: the backward span time in milliseconds - @return infer_memory int: the peak memory in bytes after inference of the function - @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward - @return residual_mem: ?? + + Args: + node IRFwOperation: node of IRGraph + device int: the device that the node will execute on + override bool: True if the existed can be overrided else False + + Returns: + profiled_metrics ProfiledMetrics: the profiling data """ fn, shapes, dtypes, requires_grads, values, kwargs = ProfileDataBase.get_func(node) @@ -238,15 +268,13 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b orig_device = torch.cuda.current_device() torch.cuda.set_device(device) - in_mem_info, param_mem_info = [], [] - residual_mem, input_count = 0, 0 + in_mem_info, param_mem_info, buffer_mem_info = [], [], [] for t in node.inputs(): if isinstance(t, IRTensor) and t.is_param(): param_mem_info.append(t.byte_size()) + elif isinstance(t, IRTensor) and t.is_buffer(): + buffer_mem_info.append(t.byte_size()) elif hasattr(t, 'byte_size'): - input_count += 1 - if input_count == 1: - residual_mem += t.byte_size() in_mem_info.append(t.byte_size()) else: _logger.warning(f'node {node}: skip input {t}') @@ -254,43 +282,39 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b # run profiling try: fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = \ - CompProfiler.profile(fn, shapes, dtypes, requires_grads, values, **kwargs) + CompProfiler.profile(node, fn, shapes, dtypes, requires_grads, values, **kwargs) except Exception: _logger.exception(f'fail to profile {node}, use default values') fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = 0, 0, 0, [], [] # log to database key = self._serialize(node) - self.insert(node.signature, key, in_mem_info, param_mem_info, fw_span, bw_span,\ - infer_memory, train_mem_info, residual_mem, train_mem2in_idx) + profiled_metrics = ProfiledMetrics(in_mem_info, param_mem_info, buffer_mem_info, + fw_span, bw_span, infer_memory, + train_mem_info, train_mem2in_idx) + self.insert(node.signature, key, profiled_metrics) _logger.info( - f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} " - f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | fw: {round(fw_span, 2)} ms | " - f"bw: {round(bw_span, 2)} ms | infer mem: {infer_memory} | train mem info: {train_mem_info} | idx: {train_mem2in_idx}") + f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} | requires_grads: {requires_grads} | " + f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | " + f"buffer mem info: {buffer_mem_info} | fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} ms | " + f"infer mem: {infer_memory} | train mem info: {train_mem_info} | idx: {train_mem2in_idx}") if isinstance(device, int): torch.cuda.set_device(orig_device) - return tuple(in_mem_info), tuple(param_mem_info), fw_span, bw_span, infer_memory, \ - tuple(train_mem_info), residual_mem, tuple(train_mem2in_idx) + return profiled_metrics - def insert(self, name: str, key: str, in_mem_info: Tuple[int], param_mem_info: Tuple[int], - fw_span: float, bw_span: float, infer_memory: int, train_mem_info: Tuple[int], - residual_mem: int, train_mem2in_idx: Tuple[int]): + def insert(self, name: str, key: str, profiled_metrics: ProfiledMetrics): """ - log the span of a function name with key - - @param name str: the function signature - @param key str: the encoded shapes and dtypes of node inputs - @param in_mem_info Tuple[int]: byte sizes of input tensors - @param param_mem_info Tuple[int]: byte sizes of param tensors - @param fw_span float: the forward span time in milliseconds - @param bw_span float: the backward span time in milliseconds - @param infer_memory int: the peak memory in bytes after inference of the function - @param train_mem_info Tuple[int]: byte sizes of tensors saved for backward + Log the profiling numbers of a function name with key + + Args: + name str: the function signature + key str: the encoded shapes and dtypes of node inputs + profiled_metrics ProfiledMetrics: the profiling data """ assert isinstance(name, str) and isinstance(key, str) if name not in self._data: self._data[name] = dict() - self._data[name][key] = (in_mem_info, param_mem_info, fw_span, bw_span, infer_memory, train_mem_info, residual_mem, train_mem2in_idx) + self._data[name][key] = profiled_metrics def exist(self, node: IRFwOperation) -> bool: """ @@ -307,18 +331,15 @@ def exist(self, node: IRFwOperation) -> bool: return False return True - def query(self, node: IRFwOperation) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int], int, Tuple[int]]: + def query(self, node: IRFwOperation) -> Optional[ProfiledMetrics]: """! Get the performance number of a node in IRGraph - @param node IRFwOperation: node in IRGraph + Args: + node IRFwOperation: node in IRGraph - @return in_mem_info Tuple[int]: byte sizes of input tensors - @return param_mem_info Tuple[int]: byte sizes of param tensors - @return fw_span float: the forward span time in milliseconds - @return bw_span float: the backward span time in milliseconds - @return infer_memory int: the peak memory in bytes after inference of the function - @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward + Returns: + profiled_metrics ProfiledMetrics: the profiling data """ key = self._serialize(node) if node.signature not in self._data: @@ -327,41 +348,6 @@ def query(self, node: IRFwOperation) -> Tuple[Tuple[int], Tuple[int], float, flo return None return self._data[node.signature][key] - def query_func(self, signature, shapes, dtypes) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int], int, Tuple[int]]: - """ - Get performance number of given name (signature), shapes and dtypes - - @param signature str: function signature - @param shapes Tuple[Tuple[int]]: the shape of each input tensor - @param dtypes Tuple[torch.dtype]: the dtype of each tensor - - @return in_mem_info Tuple[int]: byte sizes of input tensors - @return param_mem_info Tuple[int]: byte sizes of param tensors - @return fw_span float: the forward span time in milliseconds - @return bw_span float: the backward span time in milliseconds - @return infer_memory int: the peak memory in bytes after inference of the function - @return train_mem_info Tuple[int]: byte sizes of tensors saved for backward - """ - key = self._serialize(shapes, dtypes) - if signature not in self._data: - return None - if key not in self._data[signature]: - return None - return self._data[signature][key] - - def query_args(self, signature: str) -> Tuple[List[Shapes], List[DTypes]]: - """ - Get the recorded shapes and dtypes of - """ - item_shapes, item_dtypes = [], [] - if signature not in self._data: - return item_shapes, item_dtypes - for shapes_dtypes_str in self._data[torch.signature].keys(): - shapes, dtypes = self._deserialize(shapes_dtypes_str) - item_shapes.append(shapes) - item_dtypes.append(dtypes) - return item_shapes, item_dtypes - def _serialize(self, node: IRFwOperation) -> str: """ Serialize the shapes, dtypes and kwargs into a string @@ -369,40 +355,51 @@ def _serialize(self, node: IRFwOperation) -> str: e.g., shapes: ((1024,), (1024,1024)) dtypes: (torch.float32, torch.float32) - => (1024,)-(1024,1024) : torch.float32-torch.float32 + requires_grads: (True, False) + => (1024,)-(1024,1024) : torch.float32-torch.float32 : True-False - @param shapes Tuple[Tuple[int]]: the shape of each tensor - @param dtypes Tuple[torch.dtype]: the dtype of each tensor + Args: + node IRFwOperation: node in IRGraph - @return key str: the serialized string + Returns: + key str: the serialized string """ - shapes, dtypes = [], [] + shapes, dtypes, requires_grads = [], [], [] for t in node.inputs(): if isinstance(t, IRTensor): shapes.append(t.shape) dtypes.append(t.dtype) + requires_grads.append(t.requires_grad) # else: # shapes.append(None) # dtypes.append(type(t)) shapes = '-'.join(str(tuple(shape)) if shape is not None else str(None) for shape in shapes) dtypes = '-'.join(str(dtype) for dtype in dtypes) - return shapes + ' : ' + dtypes + requires_grads = '-'.join(str(require_grad) for require_grad in requires_grads) + return shapes + ' : ' + dtypes + ' : ' + requires_grads - def _deserialize(self, key: str) -> ShapesDTypes: + def _deserialize(self, key: str) -> Tuple[Shapes, DTypes, RequiresGrad]: """ De-serialize the key string to shapes and dtypes - e.g., (1024,)-(1024,1024)=torch.float32-torch.float32 + e.g., (1024,)-(1024,1024) : torch.float32-torch.float32 : True-False => shapes: ((1024,), (1024,1024)) dtypes: (torch.float32, torch.float32) + requires_grads: (True, False) + + Args: + key str: the serialized string - @param key str: the serialized string - @return shapes_and_dtypes ShapesDTypes: shapes and dtypes + Returns: + shapes Shapes: the shapes of each input tensor + dtypes DTypes: the dtypes of each input tensor + requires_grads RequiresGrad: whether the input tensor requires gradient """ - shapes, dtypes = key.split(' : ') + shapes, dtypes, requires_grads = key.split(' : ') shapes = tuple(eval(shape) for shape in shapes.split('-')) dtypes = tuple(eval(dtype) for dtype in dtypes.split('-')) - return shapes, dtypes + requires_grads = tuple(eval(require_grad) for require_grad in requires_grads.split('-')) + return shapes, dtypes, requires_grads def dump(self, file: str, override=False): """! @@ -416,20 +413,12 @@ def dump(self, file: str, override=False): with open(file, 'w') as f: json.dump(self._data, f) - - def dump_ops(self, file: str, override=False): - if os.path.exists(file): - assert override, f"File {file} exists. Set override = True to force dump." - for signature in self._data.keys(): - file_n = os.path.join(file, signature +'.json') - with open(file_n, 'w') as f: - json.dump(self._data[signature], f, indent=2) - - def dump_op(self, file: str, signature, override=False): + def dump_op(self, file: str, signature, override=False): assert signature in self._data.keys(), f'this node not be profiled' file_n = os.path.join(file, signature +'.json') with open(file_n, 'w') as f: - json.dump(self._data[signature], f, indent=2) + to_dump = {key: asdict(value) for key, value in self._data[signature].items()} + json.dump(to_dump, f, indent=2) def load(self, file: str): """! @@ -446,14 +435,15 @@ def load_ops(self, file: str): if filename.endswith('.json'): with open(os.path.join(file, filename)) as f: signature = filename[:-len('.json')] - self._data[signature] = json.load(f) + loaded_json = json.load(f) + self._data[signature] = {key: ProfiledMetrics(**value) for key, value in loaded_json.items()} def __repr__(self) -> str: data = [] for signature in self._data: for key in self._data[signature]: - shapes, dtypes = self._deserialize(key) - in_mem_info, param_mem_info, fw_span, bw_span, infer_mem, train_mem = self._data[signature][key] - data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, in mem {in_mem_info} bytes, param mem {param_mem_info} bytes, fw span: {fw_span} ms, bw span: {bw_span} ms, infer mem {infer_mem} bytes, train mem {train_mem} bytes') + shapes, dtypes, requires_grads = self._deserialize(key) + pmetrics = self._data[signature][key] + data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, requires_grads={requires_grads}, profiled numbers: {pmetrics}.') data = '\n'.join(data) return data From d58c47da3b2b3667bc0be87fc18de43951d60a27 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 20 Nov 2023 03:42:36 +0000 Subject: [PATCH 1527/1892] Merged PR 1915: Reorg parallelize to support lazy module (only create module when necessary) Support graph dump/load Move the module create code from parallelize to _gencode Move outdir check from _gencode to parallelize --- cube/parallel.py | 256 +++++++++++++++++-------- tests/parallel_module/test_gencode.py | 71 ++++++- tests/parallel_module/test_override.py | 52 +++-- 3 files changed, 286 insertions(+), 93 deletions(-) diff --git a/cube/parallel.py b/cube/parallel.py index 7703c82c..e7f444b6 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -1,6 +1,7 @@ +from enum import Enum from functools import partial import types -from typing import Callable, Any, Dict, Optional, Type, Union, TypeVar +from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar from pathlib import Path import inspect import sys @@ -42,6 +43,9 @@ class ComputeConfig: plan_ngpus: int runtime_ngpus: int + use_zero: bool = False + zero_ngroups: int = 1 + @contextmanager def _flags(flags, warning_on_override=True, /, **kwargs): @@ -59,8 +63,13 @@ def _flags(flags, warning_on_override=True, /, **kwargs): setattr(flags, k, v) -def _compile_flags(): - return _flags(CompileFlag, use_zero=False, async_reducer=False, reducer_op='sum', async_comm=False) +def _compile_flags(compute_config: ComputeConfig): + return _flags( + CompileFlag, + async_reducer=False, reducer_op='sum', async_comm=False, + use_zero=compute_config.use_zero, + zero_ngroups=compute_config.zero_ngroups, + ) def _runtime_flags(**kwargs): @@ -100,7 +109,8 @@ def _add_cube_savedir_to_syspath(cube_savedir: str) -> Path: def _is_any_gencode_loaded(namespace: str) -> bool: """Check if a module is loaded""" for m in sys.modules.values(): - if m.__name__.startswith(namespace + '.' + _GENCODE_FILE_PREFIX): + # m.__name__ doesn't always work as some module doesn't have __name__ attribute. + if getattr(m, '__name__', '').startswith(namespace + '.' + _GENCODE_FILE_PREFIX): return True return False @@ -113,93 +123,119 @@ def _get_arg_default_values(fn) -> Dict[str, Any]: _GENCODE_FILE_PREFIX = 'gencode' _GENCODE_FILE_TEMPLATE = _GENCODE_FILE_PREFIX + '{}.py' # 'gencode{}.py' _CUBE_MODULE_NAMESPACE = '_cube_modules' +_GRAPH_DUMP_FILE = 'graph.ckp' +_FORWARD_ARGS_DUMP_FILE = 'forward_args.pkl' -def _gencode( - module: torch.nn.Module, - dummy_input: dict, - pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], - compute_config: ComputeConfig, - *, - dynamic_shape: bool = True, - cube_savedir: Union[str, Path] = './.cube', - override: bool = False, - instance_name: Optional[str] = None - ) -> None: - """ - Generate cube module source code from a torch module, and save it to file. - Generated module will be save according to its full qualified name. +class ReuseType(Enum): + """The reuse type""" + NONE = 'none' # no reuse, everything will be regenerated. + ALL = 'all' # try to reuse everything if possible + GRAPH = 'graph' # only graph will be reused (so we don't need to trace the graph again) - If you want to save multiple instances of the same module, - you can specify the instance_name to distingish them. - For example, if the module is `torchscale.x.y`, then the generated module will be save to - `cube_savedir/_cube_modules/torchscale/x/y/instance_name`. +def _prepare_and_check_reusable( + cube_savedir, + module_or_module_class: Union[Type[torch.nn.Module], torch.nn.Module], + compute_config, + instance_name, + reuse: ReuseType = ReuseType.ALL, + ) -> Tuple[str, bool]: + """ + Prepare the output directory for code generation, and also check if the existing code is reusable. Args: - module (torch.nn.Module): the module to be compiled - dummy_input (dict): the dummy input for the module - pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy + cube_savedir (str): the directory to save generated code + module_or_module_class (Union[Type[torch.nn.Module], torch.nn.Module]): the original module or module class compute_config (ComputeConfig): the environment resource - dynamic_shape (bool): whether to use dynamic shape - override (bool): If true, source code will be regenerated even if generated code exists. - cube_savedir (Union[str, Path]): the directory to save generated code instance_name (Optional[str]): the instance name of the generated module. + reuse (ReuseType): specify which part can be reused. Returns: - None + Tuple[str, bool]: the output directory and whether the existing code is reusable. + + Raises: + RuntimeError: if the existing code is not reusable, + will raise RuntimeError if the code is not reusable but the module is already loaded. """ - # put cube_savedir into sys.path - # so we can import the generated module with its namespace later + cube_savedir = _add_cube_savedir_to_syspath(cube_savedir) instance_name = instance_name.strip('.') if instance_name else '' instance_namespace = f'.{instance_name}' if instance_name else '' - namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module)}{instance_namespace}' + namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_or_module_class)}{instance_namespace}' outdir = cube_savedir / Path(namespace.replace('.', '/').strip('/')) outdir.mkdir(parents=True, exist_ok=True) # decision matrix for code generation - # override flag | dir condition(imported, empty, match, unmatched) | action + # reuse flag | dir condition(imported, empty, match, unmatched) | action # --------------------------------------------------------- - # True | empty | generate - # True | imported | raise error - # True | match | generate - # True | unmatch | generate - # False | empty | generate - # False | match | do nothing - # False | unmatch | raise error - # False | imported | doesn't matter - if not override: + # NONE/GRAPH | empty | generate + # NONE/GRAPH | imported | raise error + # NONE/GRAPH | match | generate + # NONE/GRAPH | unmatch | generate + # ALL | empty | generate + # ALL | match | do nothing + # ALL | unmatch | raise error + # ALL | imported | doesn't matter + reusable = False + if reuse == ReuseType.ALL: + module_meta_files = [ + outdir / FxModuleParser.ATTR_CONTENT_FILE, + outdir / FxModuleParser.ATTR_MAP_FILE, + ] # check if the module is already generated expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] - expected_output_files.append(outdir / FxModuleParser.ATTR_CONTENT_FILE) - expected_output_files.append(outdir / FxModuleParser.ATTR_MAP_FILE) + expected_output_files.extend(module_meta_files) expected_output_files.append(outdir / ParallelModule.COMPUTE_CONFIG_FILE) + expected_output_files.append(outdir / _GRAPH_DUMP_FILE) + expected_output_files.append(outdir / _FORWARD_ARGS_DUMP_FILE) existing_output_files = [f for f in outdir.glob('*') if f.is_file()] if existing_output_files: if all([output_file.exists() for output_file in expected_output_files]) \ and len(existing_output_files) == len(expected_output_files) \ and torch.load(outdir / ParallelModule.COMPUTE_CONFIG_FILE) == compute_config: - return + reusable = True # everything is matched. elif all(f.suffix != '.py' for f in existing_output_files): # No python source code is generated. # which means its last generation failed. # in this case, we can reuse the same directory safely. - pass + logger.info(f'Output directory {outdir} is not empty. ' + f'But no python source code is present. ' + f'Will reuse the directory and the graph dump if present.') + # we have to trace the graph again if not all meta files are present. + if not all([meta_file.exists() for meta_file in module_meta_files]): + for f in outdir.glob('*'): + if f.is_file(): + f.unlink() else: raise RuntimeError(f'Output directory {outdir} is not empty. ' - f'And the existing files do not match with current config.') + f'And the existing files do not match with current config. ' + f'You can remove the directory and try again, ' + f'or set reuse to ReuseType.NONE to regenerate the code.') else: # check if the module is already loaded if _is_any_gencode_loaded(namespace): raise RuntimeError(f'Output directory {outdir} is already loaded. ' f'You can not override a loaded module.') # clear existing generated files - for f in outdir.glob('*'): + if reuse == ReuseType.NONE: + glob_pattern = '*' + else: + glob_pattern = '*.py' # so we can keep graph dumps. + for f in outdir.glob(glob_pattern): if f.is_file(): f.unlink() + return outdir, reusable + + +def _gen_graph( + module: torch.nn.Module, + dummy_input: dict, + outdir: Path, + dynamic_shape: bool, +): # reset environment program = Program() program.clear() @@ -268,6 +304,79 @@ def _gencode( program.set_output(ir_dummy_outputs) program.finalize() + return graph, forward_args + + +def _gencode( + module_or_module_class: torch.nn.Module, + dummy_input: dict, + pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], + compute_config: ComputeConfig, + outdir: Path, + *, + dynamic_shape: bool = True, + module_dtype: Optional[torch.dtype] = None, + module_fn: Optional[Callable[[], torch.nn.Module]] = None, + ) -> None: + """ + Generate cube module source code from a torch module, and save it to file. + Generated module will be save according to its full qualified name. + + If you want to save multiple instances of the same module, + you can specify the instance_name to distingish them. + + For example, if the module is `torchscale.x.y`, then the generated module will be save to + `cube_savedir/_cube_modules/torchscale/x/y/instance_name`. + + Args: + module (torch.nn.Module): the module to be compiled + dummy_input (dict): the dummy input for the module + pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy + compute_config (ComputeConfig): the environment resource + outdir (Path): the directory to save generated code + dynamic_shape (bool): whether to use dynamic shape + module_dtype (Optional[torch.dtype]): the dtype of the module. Keep as it is when it is None. + module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. + + Returns: + None + """ + graph_ckp = outdir / _GRAPH_DUMP_FILE + forward_args_ckp = outdir / _FORWARD_ARGS_DUMP_FILE + if not graph_ckp.exists() or not forward_args_ckp.exists(): + is_module_class = inspect.isclass(module_or_module_class) + if is_module_class: + try: + if module_fn is None: + # it should only have 1 `self` parameter + if len(inspect.signature(module_or_module_class.__init__).parameters) > 1: + raise ValueError("Module class __init__ should be parameter-free.") + module = module_or_module_class() + else: + module = module_fn() + if type(module) != module_or_module_class: + raise ValueError(f"module_fn should return a {module_or_module_class} instance.") + except Exception as e: + raise RuntimeError(f"Error when creating module instance.") from e + else: + module = module_or_module_class + + if module_dtype is not None: + module = module.to(dtype=module_dtype) + + if any(isinstance(m, CubeModule) for m in module.modules()): + raise RuntimeError('CubeModule can not be nested.') + + graph, forward_args = _gen_graph(module, dummy_input, outdir, dynamic_shape) + graph.dump(graph_ckp) + torch.save(forward_args, forward_args_ckp) + + if is_module_class: + del module + else: + graph = IRGraph.load(graph_ckp) + forward_args = torch.load(forward_args_ckp) + graph = pas_policy(graph, compute_config) if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") @@ -345,9 +454,11 @@ def parallelize( *, dynamic_shape: bool = True, cube_savedir: Union[str, Path] = './.cube', - override: bool = False, + reuse: Union[ReuseType, str] = ReuseType.ALL, instance_name: Optional[str] = None, load_module: bool = True, + module_dtype: Optional[torch.dtype] = None, + module_fn: Optional[Callable[[], torch.nn.Module]] = None, ) -> Union[None, ParallelModule, Type[ParallelModule]]: """ Convert a torch.nn.Module object or class to CubeModule object or class. @@ -359,6 +470,9 @@ def parallelize( Or you can unset load_module flag, and manually copy the generated files to other nodes. After all nodes have the generated files, you can call parallelize() again with load_module flag set. + Note: if reuse is not set to ReuseType.ALL, + the generated code in outdir will be removed EVEN IF the code generetion fails in this call. + if the input is a module object. The module object will be copied to cpu to handle possible insufficient gpu memory. The training flag will be the same as the original module @@ -369,10 +483,12 @@ def parallelize( pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy compute_config (ComputeConfig): the environment resource dynamic_shape (bool): whether to use dynamic shape - override (bool): If true, source code will be regenerated even if generated code exists. + reuse (ReuseType): specify which part can be reused. cube_savedir (Union[str, Path]): the directory to save generated code instance_name (Optional[str]): the instance name of the generated module. load_module (bool): whether to load the generated module after done. + module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. + module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. Returns: Union[CubeModule, Type[CubeModule], None]: @@ -387,36 +503,24 @@ def parallelize( is_module_class = inspect.isclass(module_or_module_class) module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ + reuse = ReuseType(reuse) if isinstance(reuse, str) else reuse # genereate code only in node0 # if it is not in a torchrun environment, just generate. if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - if is_module_class: - # it should only have 1 `self` parameter - if len(inspect.signature(module_or_module_class.__init__).parameters) > 1: - raise ValueError("Module class __init__ should be parameter-free.") - try: - module = module_or_module_class() - except Exception as e: - raise RuntimeError(f"Error when create module instance.") from e - else: - module = module_or_module_class - - if any(isinstance(m, CubeModule) for m in module.modules()): - raise RuntimeError('CubeModule can not be nested.') - with _compile_flags(): - _gencode( - module, - dummy_input, - pas_policy, - compute_config, - dynamic_shape=dynamic_shape, - override=override, - cube_savedir=cube_savedir, - instance_name=instance_name, - ) - if is_module_class: - del module + outdir, reusable = _prepare_and_check_reusable(cube_savedir, module_class, compute_config, instance_name, reuse) + if not reusable: + with _compile_flags(compute_config): + _gencode( + module_or_module_class, + dummy_input, + pas_policy, + compute_config, + outdir, + dynamic_shape=dynamic_shape, + module_dtype=module_dtype, + module_fn=module_fn, + ) if load_module: if not torch.distributed.is_initialized(): # we only support loading in torchrun environment diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 41efe869..3fa464cc 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -234,14 +234,15 @@ def _gencode_contains(cubesave_dir, module_class, index, search_re): matches = re.findall(search_re, filecontent) return bool(matches) +class AttrHelper: + def __init__(self) -> None: + self.a = 2.0 + def test_codegen_attr(): if not torch.cuda.is_available(): print('skip test_codegen_attr due to lack of cuda devices') return - class AttrHelper: - def __init__(self) -> None: - self.a = 2.0 with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( @@ -286,7 +287,6 @@ def test_codegen_getitem(): dynamic_shape=True, cube_savedir=tempdir, load_module=False, - override=True, ) assert _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') assert _gencode_contains(tempdir, GetItemModule, 1, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') @@ -324,3 +324,66 @@ def test_codegen_training_flag(): cube_savedir=tempdir, load_module=False ) + + +# class IdentityModule(torch.nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# return x + + +# def test_codegen_identity(): +# """ +# Test it can support modules without parameters +# """ +# if not torch.cuda.is_available(): +# print('skip test_codegen_iter due to lack of cuda devices') +# return +# with tempfile.TemporaryDirectory() as tempdir: +# m = IdentityModule() +# m.train() +# parallelize( +# m, +# {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, +# PASData, +# ComputeConfig(1, 2), +# dynamic_shape=True, +# cube_savedir=tempdir, +# load_module=False +# ) +# assert False + + +# class IterModule(torch.nn.Module): +# def __init__(self): +# super().__init__() +# self.linear = torch.nn.Linear(3, 5) + +# def forward(self, x): +# x = self.linear(x) +# assert list(x.shape) == [2, 5] # will generate iter here. +# return x + + +# def test_codegen_iter(): +# """ +# Test it can support modules without parameters +# """ +# if not torch.cuda.is_available(): +# print('skip test_codegen_iter due to lack of cuda devices') +# return +# with tempfile.TemporaryDirectory() as tempdir: +# m = IterModule() +# m.train() +# parallelize( +# m, +# {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, +# PASData, +# ComputeConfig(1, 1), +# dynamic_shape=True, +# cube_savedir=tempdir, +# load_module=False +# ) +# assert False diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index 9a6ffcd0..dc753e94 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -1,26 +1,28 @@ from pathlib import Path +from time import sleep import sys import tempfile import pytest import torch import shutil -from cube.parallel import parallelize, ComputeConfig +from cube.parallel import ReuseType, parallelize, ComputeConfig from .common import PASData, init_distributed from ..launch_torchrun import launch_torchrun -def _to_cube_model(module, compute_config, cube_savedir, override, instance_name): +def _to_cube_model(module, compute_config, cube_savedir, reuse, instance_name, load_module=True): return parallelize( module, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, compute_config, dynamic_shape=True, - override=override, + reuse=reuse, cube_savedir=cube_savedir, instance_name=instance_name, + load_module=load_module, ) @@ -38,16 +40,16 @@ def _worker(): with tempfile.TemporaryDirectory() as tempdir: # False | empty | generate - cmodule1 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, False, None) + cmodule1 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.ALL, None) # False | match | do nothing - cmodule2 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, False, None) + cmodule2 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.ALL, None) # true for (n1, v1), (n2, v2) in zip(cmodule1.named_parameters(), cmodule2.named_parameters()): assert n1 == n2 assert torch.equal(v1, v2) - cmodule3 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, False, 'test') - cmodule4 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, False, 'test') + cmodule3 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.ALL, 'test') + cmodule4 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, 'all', 'test') for (n1, v1), (n2, v2) in zip(cmodule3.named_parameters(), cmodule4.named_parameters()): assert n1 == n2 @@ -60,31 +62,35 @@ def _worker(): # True | imported | raise error with pytest.raises(RuntimeError): - _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, True, None) + _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.NONE, None) with pytest.raises(RuntimeError): - _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, True, 'test') + _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.NONE, 'test') # False | unmatch | raise error with pytest.raises(RuntimeError): - _to_cube_model(MyModule(), ComputeConfig(2, 2),tempdir, False, 'test') + _to_cube_model(MyModule(), ComputeConfig(2, 2),tempdir, ReuseType.NONE, 'test') # True | empty | generate - cmodule1 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, True, 'test2') + cmodule1 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.NONE, 'test2') module_path = Path(sys.modules[cmodule1.__module__].__file__).parent test3_module_path = module_path.with_name('test3') test3_module_path.mkdir(exist_ok=True, parents=True) test4_module_path = module_path.with_name('test4') test4_module_path.mkdir(exist_ok=True, parents=True) + test5_module_path = module_path.with_name('test5') + test5_module_path.mkdir(exist_ok=True, parents=True) for f in module_path.glob('*'): if f.is_file(): shutil.copy(f, test3_module_path / f.name) shutil.copy(f, test4_module_path / f.name) + shutil.copy(f, test5_module_path / f.name) # fake two gpus shutil.copy(test4_module_path / 'gencode0.py', test4_module_path / 'gencode1.py') + shutil.copy(test5_module_path / 'gencode0.py', test5_module_path / 'gencode1.py') # True | match | generate - cmodule2 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, True, 'test3') + cmodule2 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, ReuseType.NONE, 'test3') cmodule2_p = dict(cmodule2.named_parameters()) cmodule1_p = dict(cmodule1.named_parameters()) keys = cmodule2_p.keys() @@ -92,9 +98,29 @@ def _worker(): # True | unmatch | generate assert (test4_module_path / 'gencode1.py').exists() - cmodule3 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, True, 'test4') + cmodule3 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, 'none', 'test4') assert not (test4_module_path / 'gencode1.py').exists() + # Graph | matched | generate + assert (test5_module_path / 'gencode1.py').exists() + code_stat = (test5_module_path / 'gencode0.py').stat() + graph_stat = (test5_module_path / 'graph.ckp').stat() + args_stat = (test5_module_path / 'forward_args.pkl').stat() + cmodule4 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, 'graph', 'test5', False) + assert not (test5_module_path / 'gencode1.py').exists() + assert (test5_module_path / 'gencode0.py').stat().st_mtime_ns != code_stat.st_mtime_ns + assert (test5_module_path / 'graph.ckp').stat().st_mtime_ns == graph_stat.st_mtime_ns + assert (test5_module_path / 'forward_args.pkl').stat().st_mtime_ns == args_stat.st_mtime_ns + + code_stat = (test5_module_path / 'gencode0.py').stat() + graph_stat = (test5_module_path / 'graph.ckp').stat() + (test5_module_path / 'forward_args.pkl').unlink() # remove foward_args.pkl will force to generate new code + cmodule5 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, 'graph', 'test5', False) + assert (test5_module_path / 'gencode0.py').stat().st_mtime_ns != code_stat.st_mtime_ns + assert (test5_module_path / 'graph.ckp').stat().st_mtime_ns != graph_stat.st_mtime_ns + assert (test5_module_path / 'forward_args.pkl').exists() + + def test_override(): if not torch.cuda.is_available(): print('skip test_submodules_tp_gpu1 due to lack of cuda devices') From da6e43d984eb82e56e1622ac04f430be9dc04be5 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 22 Nov 2023 02:57:32 +0000 Subject: [PATCH 1528/1892] Merged PR 1914: broadcast weight instead read from disk for rank >= PGPU --- cube/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cube/utils.py b/cube/utils.py index 79679732..305241dc 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -53,8 +53,7 @@ def load_model(filename: Optional[str] = None, load_content: bool = True): loaded_module: cube.runtime.module.CubeModule = module.GenModel().cuda() # load parameter content if load_content: - print_each_rank("> loading parameter content...", - logger=_logger) + _logger.info("loading parameter content...") loaded_module.load_attr_content('./fullmodel.pt') # initialize reducer for reducer in loaded_module.reducers: From 79612a9a407a98d8fadd1834e7d6a3cac28f4987 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Thu, 30 Nov 2023 09:22:53 +0000 Subject: [PATCH 1529/1892] Merged PR 1889: support partitioned fullmodel.pt to reduce peak CPU memory during initialization support partitioned fullmodel.pt to reduce peak CPU memory during initialization Related work items: #1635 --- cube/graph/parser/frame.py | 20 ++++++++++++++++++-- cube/graph/parser/fx/parser.py | 1 + cube/runtime/module.py | 28 +++++++++++++++++++++------- cube/utils.py | 1 + 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index 9e5217f6..ce63fa9f 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -125,12 +125,28 @@ def get_attr_var(self, concrete_value: torch.Tensor) -> Optional[IRTensor]: return tensor return None - def save_attr_content(self, save_file: str = 'fullmodel.pt'): + def save_attr_content(self, save_file: str): """ Save attribute content into file. """ + params_per_part = 1024 * 1024 * 1024 # 1 billion per part + total_size = sum([val.numel() for _, (_, val) in self._attr_map.items()]) + model_pt_part_num = (total_size + params_per_part - 1) // params_per_part + tid2value = {t.tid: val.cpu() for t, (_, val) in self._attr_map.items()} - torch.save(tid2value, save_file) + if model_pt_part_num == 1: + torch.save(tid2value, f'{save_file}.0') + else: + assert model_pt_part_num > 1 + sorted_keys = sorted(list(tid2value.keys())) + assert len(sorted_keys) > 0, "Empty attr map" + chunk_size = (len(sorted_keys) + model_pt_part_num - 1) // model_pt_part_num + chunks = [sorted_keys[i:min(i + chunk_size, len(sorted_keys))] for i in + range(0, len(sorted_keys), chunk_size)] + for idx, chunk in enumerate(chunks): + assert len(chunk) > 0, f"Empty chunk {idx}" + part = {k: tid2value[k] for k in chunk} + torch.save(part, f'{save_file}.{idx}') def save_attr_map(self, save_file: str = 'dist_param_map.pt'): """ diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 0540b4fc..07400812 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -1,3 +1,4 @@ +import os import torch import logging from pathlib import Path diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 8b925ce6..65a09e7e 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -151,14 +151,28 @@ def get_full_map(self): return self._fullmap def load_attr_content(self, filename: str): + partitioned_model_pt = 0 + while os.path.isfile(filename + f'.{partitioned_model_pt}'): + partitioned_model_pt += 1 + if partitioned_model_pt == 0: + raise RuntimeError(f"Cannot find file {filename}.0 in load_attr_content") with torch.no_grad(): - full = torch.load(filename) - for attr in self._fullmap.keys(): - tensor: torch.Tensor = getattr(self, attr) - tid, slicers, nchunks = self._fullmap[attr] - content = full[tid][slicers] / nchunks - tensor.copy_(content) - # print(f'attr {attr}:\n{getattr(self, attr)}') + _logger.info(f'load partitioned model from {filename}, partitioned_model_pt={partitioned_model_pt}') + fullmap2 = {tid: (attr, slicer, nchunks) for attr, (tid, slicer, nchunks) in self._fullmap.items()} + for file_idx in range(partitioned_model_pt): + full = torch.load(filename + f'.{file_idx}') + for tid in full.keys(): + if tid not in fullmap2: + _logger.warning(f'cannot find tid {tid} in fullmap2') + continue + fm = fullmap2[tid] + tensor: torch.Tensor = getattr(self, fm[0]) + content = full[tid][fm[1]] / fm[2] + tensor.copy_(content) + fullmap2.pop(tid, None) + + if len(fullmap2) != 0: + raise RuntimeError(f'cannot find tid {list(fullmap2.keys())} in partitioned model files') def init_group(self, ranks: List[int]): if not all([isinstance(rank, int) for rank in ranks]): diff --git a/cube/utils.py b/cube/utils.py index 305241dc..bfe3eeb2 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -1,3 +1,4 @@ +import os from typing import Optional, Tuple, Callable import logging From 61628aaac496346ddd5240a3969b77229a27e477 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 5 Dec 2023 01:36:43 +0000 Subject: [PATCH 1530/1892] Merged PR 1924: add options to skip init cube module --- cube/codegen/module/module.py | 14 +++-- cube/parallel.py | 91 ++++++++++++++++++++++++++---- cube/runtime/module.py | 33 ++++++++--- tests/parallel_module/test_init.py | 50 ++++++++++++++++ 4 files changed, 165 insertions(+), 23 deletions(-) create mode 100644 tests/parallel_module/test_init.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 198b7af2..0426b278 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -436,11 +436,15 @@ def forward(self, x, y=None, z=None): class_name='GenModel', derived=[f'cube.runtime.module.{"ParallelModule" if as_parallel_module else "CubeModule"}'] ) as cb: - with FunctionBlock(func_name='__init__', args=['self']) as ib: - ib.insert_body(self.model_init_statements) - if as_parallel_module: - cb.insert_body('') - ib.insert_body('self._post_init()') + if as_parallel_module: + cb.insert_body(f'rank = {device}') # save rank in class level + with FunctionBlock(func_name='__init__', args=['self', 'init_params=True']) as ib: + ib.insert_body(self.model_init_statements) + ib.insert_body('') + ib.insert_body('self._post_init(init_params)') + else: + with FunctionBlock(func_name='__init__', args=['self']) as ib: + ib.insert_body(self.model_init_statements) cb.insert_body('') cb.insert_body(ib.code) segment_idxs =[] diff --git a/cube/parallel.py b/cube/parallel.py index e7f444b6..df767fcb 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) -@dataclass +@dataclass(frozen=True) class ComputeConfig: plan_ngpus: int runtime_ngpus: int @@ -46,6 +46,32 @@ class ComputeConfig: use_zero: bool = False zero_ngroups: int = 1 + # which torch.distributed.ReduceOp is used when reduce gradients + # by torch.distributed.all_reduce or torch.distributed.reduce_scatter + # a special case for mean op + # In some cases, you may want to firstly divide the local gradients, and then use torch.distributed.ReduceOp.SUM + # to get the final the gradients + # example code to divide the local gradients: + #```python + # def _mean_hook(reducer, grad): + # if reducer.reduce_op == torch.distributed.ReduceOp.SUM: + # grad.div_(reducer.ranks) + # optimizer.register_reducer_pre_hook(_mean_hook) + # ``` + reducer_op: str = 'sum' + + def __post_init__(self): + if self.plan_ngpus <= 0: + raise ValueError(f"plan_ngpus {self.plan_ngpus} must be > 0") + if self.runtime_ngpus <= 0: + raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be > 0") + if self.runtime_ngpus % self.plan_ngpus != 0: + raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be a multiple of plan_ngpus {self.plan_ngpus}") + if self.use_zero and self.zero_ngroups < 0: + raise ValueError(f"zero_ngroups {self.zero_ngroups} must be >= 0") + if self.reducer_op not in ['sum', 'avg', 'mean', 'max', 'min']: + raise ValueError(f"reducer_op {self.reducer_op} is not supported.") + @contextmanager def _flags(flags, warning_on_override=True, /, **kwargs): @@ -66,7 +92,7 @@ def _flags(flags, warning_on_override=True, /, **kwargs): def _compile_flags(compute_config: ComputeConfig): return _flags( CompileFlag, - async_reducer=False, reducer_op='sum', async_comm=False, + async_reducer=False, reducer_op=compute_config.reducer_op, async_comm=False, use_zero=compute_config.use_zero, zero_ngroups=compute_config.zero_ngroups, ) @@ -180,13 +206,13 @@ def _prepare_and_check_reusable( # ALL | imported | doesn't matter reusable = False if reuse == ReuseType.ALL: - module_meta_files = [ + trace_meta_files = [ outdir / FxModuleParser.ATTR_CONTENT_FILE, outdir / FxModuleParser.ATTR_MAP_FILE, ] # check if the module is already generated expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] - expected_output_files.extend(module_meta_files) + expected_output_files.extend(trace_meta_files) expected_output_files.append(outdir / ParallelModule.COMPUTE_CONFIG_FILE) expected_output_files.append(outdir / _GRAPH_DUMP_FILE) expected_output_files.append(outdir / _FORWARD_ARGS_DUMP_FILE) @@ -204,7 +230,7 @@ def _prepare_and_check_reusable( f'But no python source code is present. ' f'Will reuse the directory and the graph dump if present.') # we have to trace the graph again if not all meta files are present. - if not all([meta_file.exists() for meta_file in module_meta_files]): + if not all([meta_file.exists() for meta_file in trace_meta_files]): for f in outdir.glob('*'): if f.is_file(): f.unlink() @@ -410,8 +436,12 @@ def _gencode( assert len(execplan.graph.device) == compute_config.plan_ngpus, f"{execplan.graph.device}" mgener = ModuleCodeGen(execplan, scale_ndevs=runtime_ngpus) for rank in range(compute_config.runtime_ngpus): - filename = _GENCODE_FILE_TEMPLATE.format(rank) - mgener.gen(rank, forward_args=forward_args, outfile=outdir / filename, attach=False, as_parallel_module=True) + mgener.gen(rank, + forward_args=forward_args, + outfile=outdir / _GENCODE_FILE_TEMPLATE.format(rank), + attach=False, + as_parallel_module=True, + ) def _load_cube_module_class( @@ -419,6 +449,7 @@ def _load_cube_module_class( *, cube_savedir: Union[str, Path] = './.cube', instance_name: Optional[str] = None, + rank: Optional[int] = None, ) -> Type[ParallelModule]: """ Load the generated cube module class. @@ -429,9 +460,12 @@ def _load_cube_module_class( module_class (Type[torch.nn.Module]): the original module class cube_savedir (Union[str, Path]): the directory to load generated code instance_name (Optional[str]): the instance name of the generated module. + rank (Optional[int]): the rank of the module. If it is None, will get the rank from torch.distributed.get_rank(). + This option is only useful for debugging or writing pre/post-processing tools. + when you need to load the generated module in a non-torchrun environment. """ _add_cube_savedir_to_syspath(cube_savedir) - rank = torch.distributed.get_rank() + rank = torch.distributed.get_rank() if rank is None else rank instance_name = instance_name.strip('.') if instance_name else '' instance_namespace = f'.{instance_name}' if instance_name else '' gen_imported = importlib.import_module( @@ -459,6 +493,7 @@ def parallelize( load_module: bool = True, module_dtype: Optional[torch.dtype] = None, module_fn: Optional[Callable[[], torch.nn.Module]] = None, + init_module_params: bool = True, ) -> Union[None, ParallelModule, Type[ParallelModule]]: """ Convert a torch.nn.Module object or class to CubeModule object or class. @@ -477,6 +512,38 @@ def parallelize( The module object will be copied to cpu to handle possible insufficient gpu memory. The training flag will be the same as the original module + This function can be used to convert both module object and module class to cube module or cube module class. + Among key-value arguments, + module_fn and module_dtype control how to create the module object. + whereas init_module_params controls how to load cube module object after conversion is done. + + 1. If the input is a module object, it will return a CubeModule object if load_module is True. + This is useful when the module is created by a factory function. + a. module_fn is ignored. + b. module_dtype is used to control the dtype of the input module. + c. init_module_params is used to control whether to initialize the cube module parameters when load it. + + 2. If the input is a module class, it will return a CubeModule class if load_module is True. + a. module_fn is used to create the module object, or module's__init__ if not prent. + b. module_dtype is used to control the dtype of the created module (by constructor or module_fn). + Of cousre, it can be merged into module_fn. + c. init_module_params is ignored. + + After the module is converted, you can use it to create module object by calling it like a module class. + The module class is defined like: + ``` + class GenModule(cube.runtime.module.ParallelModule): + def __init__(self, init_params=True): + super().__init__() + ... + ... + ``` + So you can use `init_params` in `__init__` to control whether to initialize the module parameters. + For example, if you don't want to intialize module params: + ``` + module = GenModule(init_params=False) + ``` + Args: module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled dummy_input (dict): the dummy input for the module @@ -486,7 +553,11 @@ def parallelize( reuse (ReuseType): specify which part can be reused. cube_savedir (Union[str, Path]): the directory to save generated code instance_name (Optional[str]): the instance name of the generated module. - load_module (bool): whether to load the generated module after done. + load_module (bool): whether to load the generated module or module class after conversion is done. + init_module_params (bool): If true, when we construct the module, all its parameters are initialized with the same value with when we traced. + Otherwise, they will be empty tensor. + This parameter will be passed to the module constructor, + so it is only used when module_or_module_class is a module object, and load_module is true. module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. @@ -534,7 +605,7 @@ def parallelize( if is_module_class: return cube_module_class else: - cube_module = cube_module_class() + cube_module = cube_module_class(init_module_params) cube_module.train(module_or_module_class.training) # set training state to the same as original module return cube_module diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 65a09e7e..e6c63c0c 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -25,16 +25,16 @@ def __init__(self): super().__init__() self._reducers: List[Reducer] = list() # Key: str, parameter name (from named_parameters) - # Value: Tuple[int, Tuple[slice], int]: - # full tensor tid, - # position of sub tensor in full tensor, + # Value: Tuple[int, Tuple[slice], int]: + # full tensor tid, + # position of sub tensor in full tensor, # position of value in value partition. self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() @property def reducers(self): return self._reducers - + @property def fullmap(self): return self._fullmap @@ -94,9 +94,9 @@ def parameters_for_broadcast(self) -> List[torch.nn.Parameter]: def parameters_for_calc_gnorm(self) -> List[ParamsInfo]: """Return the necessary information for calculating the gradient norm. - + Returns: - List[Tuple[Tuple[int], List[torch.nn.Parameter], List[str], int]]: + List[Tuple[Tuple[int], List[torch.nn.Parameter], List[str], int]]: A list of tuples, each tuple contains the following information: Tuple[int]: the ranks spanned by the parameters in the tuple List[torch.nn.Parameter]: the contiguous parameters in the tuple @@ -457,9 +457,23 @@ def __init__(self): # this is used to allow multiple sync_grad() calls self._sync_grad_required = False - def _post_init(self): + def _post_init(self, init_params=True): + """ + This is post init function to further initialize the model. Should be called by subclass's __init__(). + + Args: + init_params (bool): whether to load model init parameters. Default True. + """ + # Here we check the rank to load the module file name + # Current we don't check rank when we are not in distributed mode + # to facilitate local debugging + # TODO: re-enable this check + # if dist.is_initialized() and self.rank != dist.get_rank(): + # raise RuntimeError(f"The rank to load this module file name is expected to be {self._rank}, but got {dist.get_rank()}") + module_file = Path(sys.modules[self.__module__].__file__) - self.load_attr_content(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE}")) + if init_params: + self.load_attr_content(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE}")) self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}")) self._compute_config = torch.load(module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}")) @@ -491,3 +505,6 @@ def get_dist_param_map(self): def get_compute_config(self): return self._compute_config + + def get_rank(self): + return self.rank # rank is a class varible defined in gencode diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py new file mode 100644 index 00000000..f5c63943 --- /dev/null +++ b/tests/parallel_module/test_init.py @@ -0,0 +1,50 @@ +import tempfile + +import torch + +from cube.parallel import parallelize, ComputeConfig + +from ..launch_torchrun import launch_torchrun +from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD, clear_dir_on_rank0 + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return self.linear(x) + + +def _init_params_worker(): + init_distributed() + with tempfile.TemporaryDirectory() as tempdir: + cube_module = parallelize( + MyModule, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + PASRandomSPMD, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + reuse='all', + ) + module1 = cube_module() + module2 = cube_module() + module3 = cube_module(init_params=False) + assert module1.get_rank() == 0 + assert module2.get_rank() == 0 + assert module3.get_rank() == 0 + + for p1, p2 in zip(module1.parameters(), module2.parameters()): + assert torch.equal(p1, p2) + + for p1, p3 in zip(module1.parameters(), module3.parameters()): + assert not torch.equal(p1, p3) + assert torch.all(p3 == 0) + + +def test_init_params(): + if not torch.cuda.is_available(): + print('skip test_init_params due to lack of cuda devices') + return + launch_torchrun(1, _init_params_worker) From bce6f401fca7894075c1957eee0a531d7bd34fad Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 6 Dec 2023 01:49:31 +0000 Subject: [PATCH 1531/1892] Merged PR 1927: support torch.nn.functional.scaled_dot_product_attention support scaled_dot_product_attention to support memory efficient attention --- cube/graph/function/function.py | 26 ++++++++++++++++++++++++++ cube/graph/parser/fx/mapping.py | 2 ++ tests/graph/function/test_functions.py | 5 +++++ 3 files changed, 33 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index d2f0bfe1..7f4fe91d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1991,3 +1991,29 @@ def Is(input, other, signature=None): def IsNot(input, other, signature=None): assert not isinstance(input, IRObject) and not isinstance(other, IRObject) return input is not other + + +def ScaledDotProductAttention(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, signature = None, **kwargs): + """ + torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) + + For a common attention, the generated anno is like (a e d^, a b^ d^, a b^ c -> a e c). + """ + if not isinstance(query, IRTensor) or not isinstance(key, IRTensor) or not isinstance(value, IRTensor): + raise RuntimeError(f'query: {query}, key: {key}, value: {value} should be IRTensor, something went wrong.') + if attn_mask is not None: + raise RuntimeError(f'Only support attn_mask is None in scaled_dot_product_attention now.') + gener = iter(string.ascii_lowercase) + value_anno = ShapeAnno.create_shape_str(value.shape, iterator=gener) + value_anno[-2] += '^' + key_anno = copy.copy(value_anno) + key_anno[-1] = next(gener) + '^' + query_anno = copy.copy(key_anno) + query_anno[-2] = next(gener) + out_anno = copy.copy(query_anno) + out_anno[-1] = value_anno[-1] + + anno = OpAnno.create_op_str([query_anno, key_anno, value_anno], [out_anno]) + return IRDimops(ScaledDotProductAttention, 'scaled_dot_product_attention', signature, [anno], [query, key, value], + attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 5fd31dc2..12852d80 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -112,6 +112,8 @@ def exist(signature: str) -> bool: __ftemplate('nll_loss') : function.NLLLoss, 'torch.functional.norm': function.Norm, __ftemplate('layer_norm'): function.LayerNorm, + __ftemplate('scaled_dot_product_attention'): function.ScaledDotProductAttention, + __fcntemplate('scaled_dot_product_attention'): function.ScaledDotProductAttention, # ============== runtime function ================= __tttemplate('size'): function.Size, diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index cdfee415..92e8b64f 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -213,3 +213,8 @@ def test_Unsqueeze(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a 1 b' op = F.Unsqueeze(IRTensor([2, 4]), -1) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b 1' + + +def test_ScaledDotProductAttention(): + op = F.ScaledDotProductAttention(IRTensor([8, 128, 64]), IRTensor([8, 256, 64]), IRTensor([8, 256, 32]), None, 0.05) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a e d^, a b^ d^, a b^ c -> a e c' From b2281e9aa7b65f43e6e25b797ddeeb3c614a832b Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 6 Dec 2023 05:59:34 +0000 Subject: [PATCH 1532/1892] Merged PR 1929: fix regressions caused by partitioned fullmodel.pt --- cube/graph/parser/frame.py | 15 ++++++++------- cube/graph/parser/fx/parser.py | 14 ++++++++------ cube/parallel.py | 10 ++++++++-- cube/runtime/module.py | 2 +- tests/conftest.py | 2 +- tests/graph/parser/test_converter.py | 2 +- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index ce63fa9f..ce5f1c37 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -22,8 +22,8 @@ def push_var(self, inherit_from_top=False): This should only be called when stepping in a module or method. Args: - inherit_from_top (bool): - whether to make all already defined variables in the top frame + inherit_from_top (bool): + whether to make all already defined variables in the top frame accessible to the evaluation procedure (e.g. references to such variables won't cause VarNotFound exception). """ @@ -55,7 +55,7 @@ def add_var(self, var_name: str, val: Any, graph_arg: int = -1): and link the name of the argument name from the callee function to the names of the argument passed-in. """ - + if not isinstance(var_name, str): raise RuntimeError("Expected var_name is str") if var_name in self._vars[-1]: @@ -76,12 +76,12 @@ def add_var(self, var_name: str, val: Any, graph_arg: int = -1): self._vars[-1][var_name] = val else: raise ValueError("graph_arg (int) must be >= 0") - + def set_var(self, var_name: str, val: Any): """ Reset a variable with arbitrary value. If `var_name` doesn't exist, will create a new one - + @param var_name str: variable name @param val Any """ @@ -129,15 +129,16 @@ def save_attr_content(self, save_file: str): """ Save attribute content into file. """ + #TODO: use FxModuleParser.ATTR_CONTENT_FILE_FORMAT to name the files. params_per_part = 1024 * 1024 * 1024 # 1 billion per part total_size = sum([val.numel() for _, (_, val) in self._attr_map.items()]) model_pt_part_num = (total_size + params_per_part - 1) // params_per_part tid2value = {t.tid: val.cpu() for t, (_, val) in self._attr_map.items()} - if model_pt_part_num == 1: + # it can be zero if there is no param in the module (self._attr_map is empty) + if model_pt_part_num <= 1: torch.save(tid2value, f'{save_file}.0') else: - assert model_pt_part_num > 1 sorted_keys = sorted(list(tid2value.keys())) assert len(sorted_keys) > 0, "Empty attr map" chunk_size = (len(sorted_keys) + model_pt_part_num - 1) // model_pt_part_num diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 07400812..49beee5d 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -23,7 +23,9 @@ class FxModuleParser: torch.fx module parser """ - ATTR_CONTENT_FILE = 'fullmodel.pt' + ATTR_CONTENT_FILE_STEM = 'fullmodel.pt' + ATTR_CONTENT_FILE_0 = 'fullmodel.pt.0' + ATTR_CONTENT_FILE_FORMAT = '{stem}.{idx}' ATTR_MAP_FILE = 'dist_param_map.pt' @staticmethod @@ -77,13 +79,13 @@ def parse(module: torch.fx.GraphModule, for node in module.graph.nodes: ir_nodes = FxModuleParser.parse_node(node, module, dynamic_shape, frame) all_ir_nodes += ir_nodes - + # get graph outputs outputs = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] if save_content: attr_savedir = Path(attr_savedir) - frame.save_attr_content(attr_savedir / FxModuleParser.ATTR_CONTENT_FILE) + frame.save_attr_content(attr_savedir / FxModuleParser.ATTR_CONTENT_FILE_STEM) frame.save_attr_map(attr_savedir / FxModuleParser.ATTR_MAP_FILE) frame.pop_var() @@ -140,13 +142,13 @@ def meta2var(meta: Any) -> Any: # FIXME: double check: there should be a concrete value as example, # otherwise, it may fail in parsing node like getattr val = IRObject(name=node.name, value=concrete_value) - + frame.add_var(node.name, val) @staticmethod def parse_complex(val: Any, frame: Frame) -> Any: """parse complex fx.Node into IRObject - + The val is usually from a node's input or output, can be fx.Node nested by tuple/list/dict type, or a fx.Node itself. @@ -268,7 +270,7 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame.add_attr(tensor, concrete_value, node.target) # the case that the parameter is consumed multiple times and regisetered previously else: - frame.set_var(node.name, exist_tensor) + frame.set_var(node.name, exist_tensor) else: assert not isinstance(concrete_value, torch.Tensor), f"GetAttrPrim: unexpected parameter" frame.set_var(node.name, concrete_value) diff --git a/cube/parallel.py b/cube/parallel.py index df767fcb..99344dcd 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -207,7 +207,7 @@ def _prepare_and_check_reusable( reusable = False if reuse == ReuseType.ALL: trace_meta_files = [ - outdir / FxModuleParser.ATTR_CONTENT_FILE, + outdir / FxModuleParser.ATTR_CONTENT_FILE_0, # just check the first is good enough outdir / FxModuleParser.ATTR_MAP_FILE, ] # check if the module is already generated @@ -216,7 +216,13 @@ def _prepare_and_check_reusable( expected_output_files.append(outdir / ParallelModule.COMPUTE_CONFIG_FILE) expected_output_files.append(outdir / _GRAPH_DUMP_FILE) expected_output_files.append(outdir / _FORWARD_ARGS_DUMP_FILE) - existing_output_files = [f for f in outdir.glob('*') if f.is_file()] + existing_output_files = [ + f for f in outdir.glob('*') + if f.is_file() and ( # just take fullmap.pt.0 to compare + not f.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) + or f.name == FxModuleParser.ATTR_CONTENT_FILE_0 + ) + ] if existing_output_files: if all([output_file.exists() for output_file in expected_output_files]) \ and len(existing_output_files) == len(expected_output_files) \ diff --git a/cube/runtime/module.py b/cube/runtime/module.py index e6c63c0c..062bf4a9 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -473,7 +473,7 @@ def _post_init(self, init_params=True): module_file = Path(sys.modules[self.__module__].__file__) if init_params: - self.load_attr_content(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE}")) + self.load_attr_content(str(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE_STEM}"))) self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}")) self._compute_config = torch.load(module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}")) diff --git a/tests/conftest.py b/tests/conftest.py index 1340efc1..ab682345 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ def clean_generated_files(): yield # try to clean generated files after each test run. basedir = Path('./').resolve() - generated_files = [FxModuleParser.ATTR_CONTENT_FILE, FxModuleParser.ATTR_MAP_FILE] + generated_files = [FxModuleParser.ATTR_CONTENT_FILE_0, FxModuleParser.ATTR_MAP_FILE] for f in generated_files: f = basedir / f if f.exists(): diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 7e02abad..80e2fe8b 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -42,7 +42,7 @@ def forward(self, x, **kwargs): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) assert ir_graph is not None assert (Path(tempdir) / FxModuleParser.ATTR_MAP_FILE).exists() - assert (Path(tempdir) / FxModuleParser.ATTR_CONTENT_FILE).exists() + assert (Path(tempdir) / FxModuleParser.ATTR_CONTENT_FILE_0).exists() assert ir_graph.name == 'MyModule' inputs = ir_graph.inputs() assert len(inputs) == 2 From a96c0b60cd2571ac5a625986333954820956babb Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 8 Dec 2023 12:46:57 +0000 Subject: [PATCH 1533/1892] Merged PR 1932: rm patch function cache in tracer to avoid OOM --- .../parser/fx/concrete_trace_utils/operator_patcher.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index f7a95a5b..92c3f8a8 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -170,16 +170,9 @@ class OperatorPatcher: def __init__(self, use_operator_patch: bool, operator_patch_backlist: List[str]): self.use_operator_patch = use_operator_patch self.operator_patch_backlist = operator_patch_backlist - self.function_cache: Dict[int, Callable] = {} - self.function_cache_orig: Dict[int, Callable] = {} def patch_inner(self, func): - if _orig_isinstance(func, torch.nn.Module): - return self.patch_inner_helper(func) # better not cache this - if id(func) not in self.function_cache: - self.function_cache[id(func)] = self.patch_inner_helper(func) - self.function_cache_orig[id(func)] = func - return self.function_cache[id(func)] + return self.patch_inner_helper(func) def patch_inner_helper(self, func): if not hasattr(func, '__module__') or func.__module__ is None or func.__module__.startswith('torch'): From d76ca13fb8eb5bec57c157666d9cf55df9c54f06 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 11 Dec 2023 01:06:05 +0000 Subject: [PATCH 1534/1892] Merged PR 1933: bug fix: literal output will generate invalid code like `5 = builtin.int(getitem_1)` bug fix: literal output will generate invalid code (for example, 5 = builtin.int(getitem_111) ) unit test passed parity check passed. --- cube/graph/parser/fx/parser.py | 2 +- tests/parallel_module/test_gencode.py | 90 ++++++++++++++++++--------- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 49beee5d..83884da5 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -133,7 +133,7 @@ def meta2var(meta: Any) -> Any: raise TypeError(f"only support dict type with str key, but got {meta.keys()}.\n{node}") return {key : meta2var(value) for key, value in meta.items()} # TODO: data type check, with cases like {'a': 1.2, 'b': torch.Tensor} - return meta + return IRObject(name=node.name, value=meta) if hasattr(node, 'meta') and node.meta.get('tensor_meta'): meta = node.meta['tensor_meta'] diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 3fa464cc..bd81880c 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -356,34 +356,68 @@ def test_codegen_training_flag(): # assert False -# class IterModule(torch.nn.Module): -# def __init__(self): -# super().__init__() -# self.linear = torch.nn.Linear(3, 5) +class IterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) -# def forward(self, x): -# x = self.linear(x) -# assert list(x.shape) == [2, 5] # will generate iter here. -# return x + def forward(self, x): + x = self.linear(x) + assert list(x.shape) == [2, 5] # will generate iter here. + return x -# def test_codegen_iter(): -# """ -# Test it can support modules without parameters -# """ -# if not torch.cuda.is_available(): -# print('skip test_codegen_iter due to lack of cuda devices') -# return -# with tempfile.TemporaryDirectory() as tempdir: -# m = IterModule() -# m.train() -# parallelize( -# m, -# {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, -# PASData, -# ComputeConfig(1, 1), -# dynamic_shape=True, -# cube_savedir=tempdir, -# load_module=False -# ) -# assert False +def test_codegen_iter(): + """ + Test it can support modules without parameters + """ + if not torch.cuda.is_available(): + print('skip test_codegen_iter due to lack of cuda devices') + return + with tempfile.TemporaryDirectory() as tempdir: + m = IterModule() + m.train() + # assert no exception raised below + parallelize( + m, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False + ) + + +class ConstantModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + x = self.linear(x) + y = int(x.shape[-1]) + x = x[:, :y] + return x + + +def test_codegen_const(): + """ + Test it can support modules without parameters + """ + if not torch.cuda.is_available(): + print('skip test_codegen_iter due to lack of cuda devices') + return + with tempfile.TemporaryDirectory() as tempdir: + m = ConstantModule() + m.train() + parallelize( + m, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False + ) + assert not _gencode_contains(tempdir, ConstantModule, 0, r'\s+5 = builtins.int') From 251ecbc01897db5301491d1e949216088687822f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 11 Dec 2023 05:02:20 +0000 Subject: [PATCH 1535/1892] Merged PR 1936: fix patched function names fix patched function names unit tests/parity check have passed. old: ![image.png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/1936/attachments/image.png) new: ![image (2).png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/1936/attachments/image%20%282%29.png) --- .../parser/fx/concrete_trace_utils/operator_patcher.py | 7 ++++--- tests/parallel_module/test_gencode.py | 4 +++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index 92c3f8a8..e03ab357 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -193,6 +193,7 @@ def patch_inner_helper(self, func): return func lines, lnum = inspect.findsource(func_inner) + func_name = getattr(func, '__name__', 'new_func') # align with original source code source = ''.join(('\n' * lnum, *inspect.getblock(lines[lnum:]))) dedent_src = dedent(source) @@ -218,7 +219,7 @@ def patch_inner_helper(self, func): ), *body0.body ] - body0.name = 'new_func' + body0.name = func_name # for deleting some annotations like 'add_start_docstrings_to_model_forward' or 'add_code_sample_docstrings' # these decorators are used for tranformers model docstrings generation, can be removed in trace transform_useless_decorators = ('add_start_docstrings_to_model_forward', 'add_code_sample_docstrings', 'replace_return_docstrings') @@ -251,9 +252,9 @@ def patch_inner_helper(self, func): }, var_dict) if the_self is not None: - return var_dict['new_func'].__get__(the_self) + return var_dict[func_name].__get__(the_self) else: - return var_dict['new_func'] + return var_dict[func_name] finally: if sys.version_info < (3, 9): setattr(builtins, 'tuple', tuple_wrapped) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index bd81880c..c47f913b 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -254,6 +254,8 @@ def test_codegen_attr(): cube_savedir=tempdir, load_module=False ) + # in old version, all 'forward' functions will patched to a function named 'new_func' + assert not _gencode_contains(tempdir, AttrModule, 0, r'new_func') assert _gencode_contains(tempdir, AttrModule, 0, r'builtins.getattr\(.*, \'a\'\)') assert m_new is None @@ -406,7 +408,7 @@ def test_codegen_const(): Test it can support modules without parameters """ if not torch.cuda.is_available(): - print('skip test_codegen_iter due to lack of cuda devices') + print('skip test_codegen_const due to lack of cuda devices') return with tempfile.TemporaryDirectory() as tempdir: m = ConstantModule() From 2a0893608cc61241a9de280dbc523d5d1eb41c3e Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 12 Dec 2023 05:58:13 +0000 Subject: [PATCH 1536/1892] Merged PR 1940: Refine interface: gen_partitions Generate the partitioned nodes of the given node. Args: - node (IRFwOperation): the node to be partitioned - ngpus (int): the number of gpus - base (int): the base of the division for the partitioning - depth (int): the maximum depth of the search process Returns: - List[IRFwOperation]: the partitioned nodes. Each element of the list represents the (identical) sub-operator of one partition option. --- cube/algorithm/ops/dimops.py | 41 ++++++++++++++++++++++++------ tests/algorithm/ops/test_dimops.py | 31 ++++++++++++++++++++++ 2 files changed, 64 insertions(+), 8 deletions(-) create mode 100644 tests/algorithm/ops/test_dimops.py diff --git a/cube/algorithm/ops/dimops.py b/cube/algorithm/ops/dimops.py index 3c8fc06b..26b2b69c 100644 --- a/cube/algorithm/ops/dimops.py +++ b/cube/algorithm/ops/dimops.py @@ -269,12 +269,31 @@ def collect_split_info(node: IRFwOperation): return split_info -def gen_partitions(node: IRFwOperation, ngpus: int) -> List[IRFwOperation]: +def gen_partitions(node: IRFwOperation, ngpus: int, base: int = 2, depth: int = -1) -> List[IRFwOperation]: """ + Generate the partitioned nodes of the given node. Each node in the returned list is an + partition instance of a policy. For example, if the input node is a matmul with shape + (1024, 4096), (4096, 2048) -> (1024, 2048), the ngpus is 2, base is 2, then the returned + list will contain 4 instances: + 1. matmul with shape (1024, 4096), (4096, 2048) -> (1024, 2048) + 2. matmul with shape (1024, 2048), (2048, 2048) -> (1024, 2048) + 3. matmul with shape ( 512, 4096), (4096, 2048) -> ( 512, 2048) + 4. matmul with shape (1024, 4096), (4096, 1024) -> (1024, 1024) + + Args: + node (IRFwOperation): the node to be partitioned + ngpus (int): the number of gpus + base (int): the base of the division for the partitioning + depth (int): the maximum depth of the search process, -1 for no limit + Returns: List[IRFwOperation]: the partitioned nodes. Each element of the list represents the (identical) sub-operator of one partition option. """ + if base < 1: + raise ValueError(f"base must be positive, got {base}") + if base == 1: + return [node] def gen_hash(node: IRFwOperation) -> str: ret = node.signature @@ -285,38 +304,44 @@ def gen_hash(node: IRFwOperation) -> str: dq = deque() visited = set() - dq.append((node, ngpus)) + dq.append((node, ngpus, 0)) visited.add(gen_hash(node)) gen_nodes = [] while dq: - cur_node, cur_ngpus = dq.popleft() + cur_node, cur_ngpus, cur_depth = dq.popleft() gen_nodes.append(cur_node) + if depth != -1 and cur_depth >= depth: + continue split_info = collect_split_info(cur_node) for key, val in split_info.items(): idx_1st, dim_1st, _ = val dim_size = cur_node.anno.getlen(key) - # TODO(yizhu1): only consider powers of 2 currently - split_deg = 2 + split_deg = base while split_deg <= dim_size and split_deg <= cur_ngpus: if dim_size % split_deg != 0: break + if cur_ngpus % split_deg != 0: + break new_nodes = cur_node.algorithms('dim').instantiate(idx=idx_1st, dim=dim_1st, num=split_deg) + # instantiate may return None if the partition is not possible + if new_nodes is None: + break new_node = new_nodes[0] new_ngpus = cur_ngpus // split_deg cur_key = gen_hash(new_node) - split_deg = split_deg * 2 + split_deg = split_deg * base if cur_key in visited: continue - dq.append((new_node, new_ngpus)) + dq.append((new_node, new_ngpus, cur_depth + 1)) visited.add(cur_key) - return gen_nodes \ No newline at end of file + return gen_nodes diff --git a/tests/algorithm/ops/test_dimops.py b/tests/algorithm/ops/test_dimops.py new file mode 100644 index 00000000..232c29f1 --- /dev/null +++ b/tests/algorithm/ops/test_dimops.py @@ -0,0 +1,31 @@ +import tempfile +import torch +import os +from cube.parallel import _gen_graph +from cube.ir.operator import IRFwOperation +from cube.algorithm.ops.dimops import gen_partitions + +class NaiveFFN(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(1024, 4096, bias=False) + self.linear2 = torch.nn.Linear(4096, 1024, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + +def test_gen_partitions(): + with tempfile.TemporaryDirectory() as tempdir: + graph, _ = _gen_graph(NaiveFFN(), {'x': torch.randn(2, 128, 1024)}, tempdir, False) + fc1, relu, fc2 = graph.select(ntype=IRFwOperation) + assert len(gen_partitions(fc1, 1)) == 1 + # C(4, 1) + 1 = 5 + assert len(gen_partitions(fc1, 2)) == 5 + # C(4, 2) + 2 * C(4, 1) + 1 - 1 = 14 + assert len(gen_partitions(fc1, 4)) == 14 + # C(4, 1) + 1 - 1 = 4 + assert len(gen_partitions(fc1, 4, base=4, depth=1)) == 4 From 635cd86d8e285cf8ddac1b9f965e9ae8da79e292 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 12 Dec 2023 09:02:46 +0000 Subject: [PATCH 1537/1892] Merged PR 1925: fix bug: support keyword dim & keepdim in kwargs --- cube/graph/function/function.py | 10 +++++++++- tests/graph/parser/test_parser.py | 22 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 7f4fe91d..5bbd493b 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1787,13 +1787,21 @@ def CompareNE(input, other, *, out=None, signature = None): return _comparison(CompareNE, operator.eq, 'ne', signature, input, other) -def Max(input, other_or_dim=None, out_or_keepdim=None, *, out=None, signature = None): +def Max(input, other_or_dim=None, out_or_keepdim=None, *, out=None, signature = None, **kwargs): """ torch.max(input) torch.max(input, dim, keepdim=False, *, out=None) torch.max(input, other, *, out=None) """ signature = 'cube.runtime.function.max_' + if 'dim' in kwargs: + other_or_dim = kwargs['dim'] + if 'keepdim' in kwargs: + assert 'out' not in kwargs, f'out and keepdim cannot be both specified, get {kwargs}' + out_or_keepdim = kwargs['keepdim'] + if 'out' in kwargs: + assert 'keepdim' not in kwargs, f'out and keepdim cannot be both specified, get {kwargs}' + out_or_keepdim = kwargs['out'] if other_or_dim is None: edim_in = [s + '^' for s in ShapeAnno.create_shape_str(input.shape)] annos = [OpAnno.create_op_str([edim_in], ['1'])] diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 75565925..e9cb16fc 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -66,3 +66,25 @@ def forward(self, x: dict): assert len(ir_graph.outputs()) == 1 assert isinstance(ir_graph.output(0), dict) assert isinstance(ir_graph.output(0)['loss'], IRTensor) + + +def test_max(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.max(x, dim=1, keepdim=True)[0] + + dummy_input = {'x': torch.randn(4, 1024)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + print(ir_graph.extra_repr()) + + assert isinstance(ir_graph.output(0), IRTensor) + assert ir_graph.output(0).shape == (4, 1) From 5f2bbda98bd05b21f9902e0748db011df8477c97 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 13 Dec 2023 00:59:27 +0000 Subject: [PATCH 1538/1892] Merged PR 1926: fix bug: sort tensors generated by multi-ref Avoid warning messages in rvd like: ![image.png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/1926/attachments/image.png) --- cube/graph/gener/gen.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 913c5185..f294891f 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -726,6 +726,10 @@ def autoref(graph: IRSegment) -> IRGraph: # by default follow producer transformation strategy ptensors = graph.ptensors(ftensor) if len(ptensors) > 0: + # In order to generate correct adapters for multiref, we need to + # ensure Multirefs below is ordered by devices, which is aligned + # with consumer operators. As a result, we sort the ptensors here. + ptensors = sorted(ptensors, key=lambda t: t.device[0]) for tensor in ptensors: mr = MultiRef(tensor, len(multiref.outputs())) mr.input(0).grad = tensor.grad From f012c346b7afde90202140c41d0965aee96df88c Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 13 Dec 2023 08:50:20 +0000 Subject: [PATCH 1539/1892] Merged PR 1928: remove force trace torch.no_grad we don't need keep torch.no_grad in fx graph: 1. cube doesn't support this node now. 2. after merge concrete trace and get tensor meta, the tensor has correct requires_grad in cube graph now. 3. force trace no_grad function will case error if it is used in leaf function. --- .../concrete_trace_utils/concrete_tracer.py | 40 +------------------ .../parser/fx/concrete_trace_utils/utils.py | 3 -- cube/graph/parser/register.py | 2 +- tests/graph/parser/test_no_grad.py | 39 ++++++++++++++++++ 4 files changed, 41 insertions(+), 43 deletions(-) create mode 100644 tests/graph/parser/test_no_grad.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 8dbcc8d2..36098fae 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -89,9 +89,6 @@ def __exit__(self, *args): _orig_agfunc_apply, _orig_torch_assert, - _orig_torch_no_grad, - _orig_torch_no_grad_enter, - _orig_torch_no_grad_exit, _orig_type, _orig_isinstance, @@ -380,7 +377,7 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] result = result.cpu() elif isinstance(result, (list, dict, tuple)): result = tree_map(to_cpu, result) - elif isinstance(result, (int, bool, torch.device, torch.dtype, _orig_torch_no_grad)) or result is None: + elif isinstance(result, (int, bool, torch.device, torch.dtype)) or result is None: # avoid too noisy warning pass else: @@ -965,18 +962,6 @@ def torch_assert_wrapper(condition, message): condition = condition.value return _orig_torch_assert(condition, message) - @functools.wraps(_orig_torch_no_grad) - def torch_no_grad_wrapper(): - return self.create_proxy('call_function', _orig_torch_no_grad, (), {}) - - @functools.wraps(_orig_torch_no_grad_enter) - def torch_no_grad_enter_wrapper(no_grad): - return self.create_proxy('call_function', _orig_torch_no_grad_enter, (no_grad,), {}) - - @functools.wraps(_orig_torch_no_grad_exit) - def torch_no_grad_exit_wrapper(no_grad, exc_type, exc_value, traceback): - return self.create_proxy('call_function', _orig_torch_no_grad_exit, (no_grad, exc_type, exc_value, traceback,), {}) - self.agfunc_dict: dict[Type, Any] = {} self.autowrap_leaf_pairs = { id(_orig_torch_assert): torch_assert_wrapper, @@ -1114,11 +1099,6 @@ def getattr_wrapper(obj, *args): self.patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False) self.patcher.patch_method(torch.autograd.Function, "apply", agfunc_apply_wrapper, deduplicate=False) self.patcher.patch_method(torch, "_assert", torch_assert_wrapper, deduplicate=False) - # if class member functions and the class need to be wrapped together, - # wrap the member functions before wrap the class. - self.patcher.patch_method(_orig_torch_no_grad, "__enter__", torch_no_grad_enter_wrapper, deduplicate=False) - self.patcher.patch_method(_orig_torch_no_grad, "__exit__", torch_no_grad_exit_wrapper, deduplicate=False) - self.patcher.patch_method(torch, "no_grad", torch_no_grad_wrapper, deduplicate=False) self.patcher.patch_method(builtins, "map", map_wrapper, deduplicate=False) self.patcher.patch_method(builtins, "enumerate", enumerate_wrapper, deduplicate=False) @@ -1510,9 +1490,6 @@ def _retain_weight_consistency(root: torch.nn.Module): @functools.wraps(_orig_node_is_impure) def node_is_impure_wrapper(node): - if is_useless_no_grad_node(node): - return False - if node.op in {"placeholder", "output"}: return True @@ -1534,18 +1511,6 @@ def node_is_impure_wrapper(node): return False -def is_useless_no_grad_node(node: Node): - # keep the no_gard related nodes, but except useless situation: no node between __enter__ and __exit__ - if node.op == 'call_function': - if node.target is _orig_torch_no_grad_exit: - if node.prev.target is _orig_torch_no_grad_enter and node.prev.prev.target is _orig_torch_no_grad: - setattr(node.prev, '_is_impure', False) - setattr(node.prev.prev, '_is_impure', False) - return True - if node.target is _orig_torch_no_grad_enter or node.target is _orig_torch_no_grad: - return not getattr(node, '_is_impure', True) - return False - def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Union[Dict[str, Any], Tuple], *, @@ -1752,9 +1717,6 @@ def f(x, y): default_extra_side_effectful_functions = { operator.setitem, builtins.next, - _orig_torch_no_grad, - _orig_torch_no_grad_enter, - _orig_torch_no_grad_exit, } extra_side_effectful_functions = default_extra_side_effectful_functions | dce_ignored_function with _Patcher() as patcher, ExtraSEFPatcher(extra_side_effectful_functions): diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index 3b792870..b2ac345e 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -17,9 +17,6 @@ _orig_agfunc_apply: Callable = torch.autograd.function.Function.apply _orig_torch_assert: Callable = torch._assert -_orig_torch_no_grad: Callable = torch.no_grad -_orig_torch_no_grad_enter: Callable = torch.no_grad.__enter__ -_orig_torch_no_grad_exit: Callable = torch.no_grad.__exit__ _orig_type: Callable = builtins.type _orig_isinstance: Callable = builtins.isinstance diff --git a/cube/graph/parser/register.py b/cube/graph/parser/register.py index f2779703..5742316b 100644 --- a/cube/graph/parser/register.py +++ b/cube/graph/parser/register.py @@ -56,7 +56,7 @@ def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Call Returns: None """ - builtins = ['_operator', 'torch', 'cube.runtime.function'] + builtins = ['_operator.', 'torch.', 'cube.runtime.function.'] if any(signature.startswith(builtin) for builtin in builtins): raise RuntimeError(f"Cannot register operators with signature starting from any of {builtins}") assert signature not in CustomizedOps.kOpMap, f"function {signature} is already registered" diff --git a/tests/graph/parser/test_no_grad.py b/tests/graph/parser/test_no_grad.py new file mode 100644 index 00000000..a8d48943 --- /dev/null +++ b/tests/graph/parser/test_no_grad.py @@ -0,0 +1,39 @@ +import pytest +import torch + +from cube.graph.parser.converter import to_fx_graph + + +def test_no_grad(): + class SimpleModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 10) + self.fc2 = torch.nn.Linear(10, 10) + + def forward(self, x): + with torch.no_grad(): + x = self.fc1(x) + x = self.fc2(x) + return x + + traced_graph = to_fx_graph(SimpleModel(), {'x': torch.rand(4, 10)}) + + # The traced graph: + # + # def forward(self, x): + # fc1_weight = self.fc1.weight + # fc1_bias = self.fc1.bias + # linear = torch._C._nn.linear(x, fc1_weight, fc1_bias); x = fc1_weight = fc1_bias = None + # fc2_weight = self.fc2.weight + # fc2_bias = self.fc2.bias + # linear_1 = torch._C._nn.linear(linear, fc2_weight, fc2_bias); linear = fc2_weight = fc2_bias = None + # return linear_1 + + assume_no_requires_grad_nodes = set(['x', 'linear']) + actual_no_requires_grad_nodes = set() + for node in traced_graph.graph.nodes: + if node.meta['tensor_meta'].requires_grad == False: + actual_no_requires_grad_nodes.add(node.name) + assert assume_no_requires_grad_nodes == actual_no_requires_grad_nodes, \ + f'assume no require grad node: {assume_no_requires_grad_nodes}, actual no require grad node: {actual_no_requires_grad_nodes}' From cb0c0fad7dbd645502bb53a4ea60c9b7c0609509 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 13 Dec 2023 12:43:40 +0000 Subject: [PATCH 1540/1892] Merged PR 1938: report error when fullslice args contain tensor --- cube/graph/function/function.py | 10 ++++- tests/graph/function/test_functions.py | 10 +++++ tests/parallel_module/test_gencode.py | 54 ++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 5bbd493b..acbf36e9 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1472,7 +1472,13 @@ def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice, int]], signatu edim_in = ShapeAnno.create_shape_str(tensor.shape) edim_ou = [] in_idx = 0 + tensor_error_msg = ("Tensor is not supported in slice. " + + "If the tensor is scalar type, you can conver it to int by tensor.item() or int(), then use it to index. " + + "If the tensor is not scalar type, you may need to wrap related logic in a Customized Op." + ) def obj_helper(obj): + if isinstance(obj, IRTensor): + raise RuntimeError(tensor_error_msg) if isinstance(obj, IRObject): return obj.value else: @@ -1497,8 +1503,10 @@ def obj_helper(obj): else: edim_ou.append(str(dimlen)) in_idx += 1 + elif isinstance(slicer, IRTensor): + raise RuntimeError(tensor_error_msg) else: - raise RuntimeError(f"Unsupported slicer {slicer}") + raise RuntimeError(f"Unsupported slicer {slicer}. you may need to wrap related logic in a Customized Op.") edim_ou += edim_in[in_idx:] # special case for scalar = torch.Tensor([1,2,3])[0] if len(edim_ou) == 0: diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 92e8b64f..db15ad6d 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -20,6 +20,7 @@ def test_handle_broadcast_multi(): assert ins_anno[2] == ['a', 'b', 'c'] assert out_anno == ['a', 'b', 'c'] + def test_Full(): op = F.Full([1, 2, 3], 1.) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 1 2 3' @@ -27,6 +28,7 @@ def test_Full(): op = F.Full([], 1.) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 1' + def test_Expand(): inp = IRTensor([10, 1]) out = IRTensor([10, 2]) @@ -148,6 +150,7 @@ def test_Where(): op = F.Where(IRTensor([3, 4]), 1, 2) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, ?, ? -> a b' + def test_FullSlice(): op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, 3)) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' @@ -157,6 +160,11 @@ def test_FullSlice(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 2' op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, slice(1, 10, 1))) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 3' + with pytest.raises(RuntimeError): + op = F.FullSlice(IRTensor([2, 3, 4]), (IRTensor([1, 2, 3]),)) + with pytest.raises(RuntimeError): + op = F.FullSlice(IRTensor([2, 3, 4]),(slice(1, IRTensor([2]), 3),)) + def test_GetItem(): op = F.GetItem(IRTensor([4, 2]), IRTensor([3, 5], dtype=torch.int64)) @@ -168,6 +176,7 @@ def test_GetItem(): op = F.GetItem(IRTensor([3, 4, 2]), IRTensor([3, 5], dtype=torch.int64)) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, d e -> d e b c' + def test_Max(): op = F.Max(IRTensor([2, 3, 4])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' @@ -178,6 +187,7 @@ def test_Max(): op = F.Max(IRTensor([2, 3, 4]), 1, False) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a c, a c' + def test_Squeeze(): op = F.Squeeze(IRTensor([2, 1, 4, 1])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c^ d -> a^ c^' diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index c47f913b..dd879e91 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -423,3 +423,57 @@ def test_codegen_const(): load_module=False ) assert not _gencode_contains(tempdir, ConstantModule, 0, r'\s+5 = builtins.int') + + +class TensorSliceModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + x = self.linear(x) + padding = torch.count_nonzero(x) + return x[:, :padding] + + +class TensorSliceFixedModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + x = self.linear(x) + padding = torch.count_nonzero(x).item() + return x[:, :padding] + + +def test_codegen_tensor_slice(): + if not torch.cuda.is_available(): + print('skip test_codegen_tensor_slice due to lack of cuda devices') + return + with tempfile.TemporaryDirectory() as tempdir: + m = TensorSliceModule() + m.train() + with pytest.raises(RuntimeError, match='Tensor is not supported in slice.'): + parallelize( + m, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False, + reuse='none', + ) + m = TensorSliceFixedModule() + m.train() + parallelize( + m, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False, + reuse='none', + ) From 2f9cf78f3752efa9454a942746f3536befb735c4 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 14 Dec 2023 07:16:08 +0000 Subject: [PATCH 1541/1892] Merged PR 1946: refactor scale_ndevs logic in CodeGen deprecate scale_ndevs argument. add runtime_ndevs argument instead. So we don't need to care whether runtime_ngpus is the same with plan_ngpus. unit test pass parity check pass. --- cube/codegen/module/module.py | 73 ++++++++++++++++++------------- cube/codegen/schedule/schedule.py | 35 ++++++++++++--- cube/parallel.py | 3 +- tests/algorithm/__init__.py | 0 tests/algorithm/ops/__init__.py | 0 5 files changed, 72 insertions(+), 39 deletions(-) create mode 100644 tests/algorithm/__init__.py create mode 100644 tests/algorithm/ops/__init__.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 0426b278..dad58d81 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -79,17 +79,40 @@ class ModuleCodeGen(FuncEmission): ``` """ - def __init__(self, execplan: ExecutionPlan, scale_ndevs: Optional[int] = None) -> None: + def __init__( + self, + execplan: ExecutionPlan, + runtime_ndevs: Optional[int] = None, + *, + scale_ndevs: Optional[int] = None + ) -> None: """ Create Module code generator - @param execplan ExecutionPlan - @param scale_ndevs Optional[int]: scale to number of devices + Args: + execplan (ExecutionPlan): execution plan + runtime_ndevs (Optional[int]): the number of devices in runtime + scale_ndevs (Optional[int]): Deprecated. Use `runtime_ndevs` instead """ - super().__init__() self.execplan: ExecutionPlan = execplan self.devices: Tuple[int] = tuple(sorted(execplan.graph.device)) + if self.devices != tuple(range(len(self.devices))): + raise ValueError(f'device must be consecutive') + + if scale_ndevs is not None: + _logger.warning("scale_ndevs is deprecated, please use runtime_ndevs instead") + if runtime_ndevs is not None: + raise ValueError("You cannot use runtime_ndevs and scale_ndevs at the same time") + self.runtime_ndevs: int = runtime_ndevs or scale_ndevs or len(self.devices) + # we will scale the graph as data parallelism + # when we have more devices than the number of devices used in the graph + # we need to do two things: + # 1. update execplan with dp reducers (via add_scale_reducers) + # 2. update node devices when emitting code (via scale) + if self.runtime_ndevs % len(self.devices) != 0: + raise ValueError(f'runtime_ndevs must be a multiple of {len(self.devices)}') + self.enable_dp = self.runtime_ndevs > len(self.devices) self.init_code: List[str] = [ '\n\n########## Generated Model Code ###########', @@ -117,17 +140,14 @@ def __init__(self, execplan: ExecutionPlan, scale_ndevs: Optional[int] = None) - # batch size self.batch_size = None # communication groups - self.comm_groups: List[Tuple[int]] = self.get_comm_groups(scale_ndevs) - # whether to scale (with data parallelism) - self._scale_to_ndevs = scale_ndevs - if scale_ndevs is not None: - self.add_scale_reducers() + self.comm_groups: List[Tuple[int]] = self.get_comm_groups() + self.add_scale_reducers() def add_scale_reducers(self): """ Insert reducers to for scale scenario """ - if self._scale_to_ndevs is None: + if not self.enable_dp: return graph = self.execplan.graph # for each device, collect parameters in the all reducers and create a reducer for the rest @@ -157,15 +177,13 @@ def add_scale_reducers(self): reducer.device = device # will be scaled in `self.scale` self.execplan.at(device).append(reducer) - def get_comm_groups(self, scale_ndevs: Optional[int] = None): + def get_comm_groups(self): """ Scale the communication groups to multiple devices using data parallelism. @warn this requires user side to setup dataloader for different GPUs - - @param scale_ndevs Optional[int]: scale to number of devices """ def _add_comm_for_group_zero(ranks): zero_comm_groups = [] @@ -186,17 +204,15 @@ def _add_comm_for_group_zero(ranks): if len(zero_crossgroup) > 1 and len(zero_crossgroup) < len(ranks): zero_comm_groups.append(zero_crossgroup) return zero_comm_groups - scale_ndevs = scale_ndevs if scale_ndevs is not None else len(self.devices) - assert len(self.devices) == max(self.devices) + 1, f'device must be consecutive' - assert scale_ndevs % len(self.devices) == 0, f'ngpus must be a multiple of {len(self.devices)}' - nreplica = scale_ndevs // len(self.devices) + + nreplica = self.runtime_ndevs // len(self.devices) # scale communication groups graph = self.execplan.graph comm_groups = [] # communication groups for parameters that are in reducers reducers: List[IRWeightReducer] = graph.select(ntype=IRWeightReducer) for reducer in reducers: - ranks = more_itertools.flatten(list(range(device, scale_ndevs, len(self.devices))) \ + ranks = more_itertools.flatten(list(range(device, self.runtime_ndevs, len(self.devices))) \ for device in reducer.device) ranks = tuple(sorted(ranks)) comm_groups.append(ranks) @@ -204,7 +220,7 @@ def _add_comm_for_group_zero(ranks): comm_groups.extend(_add_comm_for_group_zero(ranks)) # communication groups for parameters that are outside reducers for device in self.devices: - ranks = list(range(device, scale_ndevs, len(self.devices))) + ranks = list(range(device, self.runtime_ndevs, len(self.devices))) if len(ranks) > 1: comm_groups.append(ranks) # add comm groups for group ZeRO @@ -222,9 +238,9 @@ def _add_comm_for_group_zero(ranks): comm_groups.append(shifted_ranks) return comm_groups - def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: - assert len(self.devices) == max(self.devices) + 1, f'device must be consecutive' - assert ndevs % len(self.devices) == 0, f'ngpus must be a multiple of {len(self.devices)}' + def scale(self, node: IRCell, device: int) -> IRCell: + if not self.enable_dp: + return node shift = (device // len(self.devices)) * len(self.devices) if isinstance(node, IRAdapter): adapter = copy.copy(node) @@ -247,7 +263,7 @@ def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: prims.append(p) adapter.prims = prims if node.isfw() and node.differentiable and node.custom: - badapter = self.scale(node.mirror, ndevs, device) + badapter = self.scale(node.mirror, device) IRCell.make_pair(adapter, badapter) return adapter if isinstance(node, IRWeightReducer): @@ -256,11 +272,11 @@ def scale(self, node: IRCell, ndevs: int, device: int) -> List[IRCell]: ranks = list(node.device) scale_ranks = [] for rank in ranks: - scale_ranks += list(range(rank, ndevs, len(self.devices))) + scale_ranks += list(range(rank, self.runtime_ndevs, len(self.devices))) reducer.device = sorted(scale_ranks) return reducer if isinstance(node, IRSegment) and node.isfw(): - nodes = [self.scale(n, ndevs, device) for n in node.nodes()] + nodes = [self.scale(n, device) for n in node.nodes()] segment = IRSegment(nodes, node.inputs(), node.outputs(), node.name) segment._id = node.cid return segment @@ -366,8 +382,7 @@ def forward(self, x, y=None, z=None): node_args: List[List[str]] = list() gen_nodes: List[IRCell] = list() - device_map = device if self._scale_to_ndevs is None else \ - device % len(self.devices) + device_map = device % len(self.devices) sequence = self.execplan.seq(device_map) unrolled_seqs = [] for node in sequence: @@ -378,9 +393,7 @@ def forward(self, x, y=None, z=None): sequence = tuple(dict.fromkeys(unrolled_seqs)) # scale to multiple devices - if self._scale_to_ndevs is not None: - sequence = [self.scale(node, self._scale_to_ndevs, device) \ - for node in sequence] + sequence = [self.scale(node, device) for node in sequence] # init customized adapter fsegments = [node for node in sequence if isinstance(node, IRSegment) and node.isfw()] diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index 8c3a2e42..d72e61e6 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -22,30 +22,51 @@ class ScheduleCodeGen(FuncEmission): - def __init__(self, execplan: ExecutionPlan, scale_ndevs: Optional[int] = None): + def __init__( + self, + execplan: ExecutionPlan, + runtime_ndevs: Optional[int] = None, + *, + scale_ndevs: Optional[int] = None, + ): """ - Create Module code generator + Create a schedule code generator - @param execplan ExecutionPlan - @param scale_ndevs Optional[int]: scale to number of devices + Args: + execplan (ExecutionPlan): execution plan + runtime_ndevs (Optional[int]): the number of devices in runtime + scale_ndevs (Optional[int]): Deprecated. Use `runtime_ndevs` instead """ self.execplan = execplan self.devices: Tuple[int] = tuple(sorted(execplan.graph.device)) + if self.devices != tuple(range(len(self.devices))): + raise ValueError(f'device must be consecutive') + + if scale_ndevs is not None: + _logger.warning("scale_ndevs is deprecated, please use runtime_ndevs instead") + if runtime_ndevs is not None: + raise ValueError("You cannot use runtime_ndevs and scale_ndevs at the same time") + self.runtime_ndevs: int = runtime_ndevs or scale_ndevs or len(self.devices) + # we will scale the graph as data parallelism + # when we have more devices than the number of devices used in the graph + # here we don't need to do anything as things are already done in ModuleCodeGen. + if self.runtime_ndevs % len(self.devices) != 0: + raise ValueError(f'runtime_ndevs must be a multiple of {len(self.devices)}') + self.enable_dp = self.runtime_ndevs > len(self.devices) + # model full code self.init_code: List[str] = [ '\n\n########## Generated Schedule Code ###########', 'import torch', 'import cube', ''] # module member name self.symbols = SymbolTable() - self._scale_to_ndevs = scale_ndevs def gen(self, device: int, outfile=None, attach=None) -> str: """ Generate scheduling code on device """ gencode = copy.copy(self.init_code) - device_map = device if self._scale_to_ndevs is None else \ - device % len(self.devices) + device_map = device % len(self.devices) device_nodes = self.execplan.seq(device_map) assert all(not isinstance(n, IRFwOperation) for n in device_nodes), \ diff --git a/cube/parallel.py b/cube/parallel.py index 99344dcd..151166c4 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -438,9 +438,8 @@ def _gencode( # code generation torch.save(compute_config, outdir / ParallelModule.COMPUTE_CONFIG_FILE) - runtime_ngpus = None if compute_config.plan_ngpus == compute_config.runtime_ngpus else compute_config.runtime_ngpus assert len(execplan.graph.device) == compute_config.plan_ngpus, f"{execplan.graph.device}" - mgener = ModuleCodeGen(execplan, scale_ndevs=runtime_ngpus) + mgener = ModuleCodeGen(execplan, compute_config.runtime_ngpus) for rank in range(compute_config.runtime_ngpus): mgener.gen(rank, forward_args=forward_args, diff --git a/tests/algorithm/__init__.py b/tests/algorithm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/algorithm/ops/__init__.py b/tests/algorithm/ops/__init__.py new file mode 100644 index 00000000..e69de29b From 2530a797de6ead9481e24e88273c7840ed4368ba Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 18 Dec 2023 06:00:07 +0000 Subject: [PATCH 1542/1892] Merged PR 1947: Refine logging & Speed up code gen - directly update node instead of remove -> update -> insert - introduce a new interface `multi_index` to save time - remove some logging to make the output clean parity check passed --- cube/graph/gener/gen.py | 30 ++++++++++++++++++++++++------ cube/graph/graph.py | 2 -- cube/graph/segment.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index f294891f..e471e57f 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -1,12 +1,13 @@ from typing import Dict, List, Optional, Tuple, Callable, Set import numpy as np import itertools +import logging from cube.graph.function.anchor import IRGraphAnchor from cube.graph.gener.concurrent import ConcurrentGener import cube.graph.gener.utils as utils from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment +from cube.graph.segment import IRSegment, CellPosition from cube.graph.function.pyfunc import IRPyFunc from cube.ir.cten import IRCell, IRObject @@ -19,6 +20,8 @@ DeviceID = int +_logger = logging.getLogger(__name__) + def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) -> List[IRFwOperation]: """ @@ -108,12 +111,16 @@ def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: """ # reorder producer and consumer ordering graph._reorder_producer_consumer() + _logger.info("finish reordering producer and consumer") # remove anchor node graph = IRAdapterGener.remove_anchor(graph) + _logger.info("finish removing anchor nodes") # automatic replace pyfunc graph = IRAdapterGener.auto_pyfunc(graph) + _logger.info("finish replacing auto pyfunc") # automatic transform multiref graph = IRAdapterGener.autoref(graph) + _logger.info("finish transforming multiref nodes") # generate adapters for activation graph = IRAdapterGener.gen_activation(graph, cost_fn=cost_fn) # generate weight reducer @@ -321,22 +328,29 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # local producer fusion and local consumer multiref ftensors = [] + _cnt = 0 for ftensor in graph.full_tensors(): # backward will gen in forward if ftensor.is_param() or ftensor.is_grad(): continue - # flatten gradient + # flatten gradient utils.flatten_grad(graph, ftensor) # optimization: local fusion / multiref on producer / consumer ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) IRAdapterGener.local_consumer_multiref(graph, ftensor) ftensors.append(ftensor) + _cnt = _cnt + 1 + if _cnt % 100 == 0: + _logger.info(f'processed local fusion & multiref for {_cnt} tensors') + _logger.info(f'finish local fusion & multiref for {_cnt} tensors') # reorder again since inserted multiref could be mis-ordered graph._reorder_producer_consumer() + _logger.info("finish reordering producer and consumer") # generate adapter for intra-segments # FIXME: assume producers and consumers can run in parallel + _cnt = 0 for ftensor in ftensors: # debug @@ -419,17 +433,17 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # insert forward adapter # graph.insert(fadapter, max(producers) + 1) if len(fconsumers) > 0: - fidx = min(graph.nodes().index(c) for c in fconsumers) + fidx = min(graph.multi_index(fconsumers)) else: # no consumer: find the last forward node for fidx, node in enumerate(graph.nodes()[::-1]): if node.isfw(): - fidx = graph.nnodes - fidx + fidx = CellPosition(tuple([graph.nnodes - fidx])) break graph.insert(fadapter, fidx) # setup recompute if allow_recompute: - if fidx > 0: + if fidx > CellPosition(tuple([0])): prev_node = graph.node(fidx-1) if isinstance(prev_node, (IRFwOperation, IRAdapter)): fadapter.recompute = prev_node.recompute @@ -439,12 +453,16 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: assert isinstance(badapter, IRAdapter) assert isinstance(bgraph, IRSegment) if len(bproducers) > 0: - bidx = max(bgraph.nodes().index(p) for p in bproducers) + 1 + bidx = max(bgraph.multi_index(bproducers)) + 1 else: # no producer: find the first backward node for bidx, node in enumerate(bgraph.nodes()): if not node.isfw(): break bgraph.insert(badapter, bidx) + _cnt = _cnt + 1 + if _cnt % 100 == 0: + _logger.info(f'generated {_cnt} activation adapters') + _logger.info(f'finish generating {_cnt} activation adapters') # generate adapter for each segment segments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5ac59dfd..3d5ce821 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -239,10 +239,8 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis if not isinstance(times, int) or times < 1: raise TypeError("Expected times to be int and >= 1") if node.name == 'multiref': - _logger.warning(f'skip replicating multiref ({node.cid}), which will be handled by system.') return [node] if isinstance(node, IRPyFunc): - _logger.warning(f'skip replicating pyfunc ({node.cid}), which will be handled by system.') return [node] fsegment: IRSegment = self.segment(node) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 4d67fc2b..5de0081a 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -220,6 +220,38 @@ def index(self, node: IRCell) -> CellPosition: return CellPosition((idx,) + index.indices) raise KeyError(f"The queried node: {node} not in the graph") + def multi_index(self, nodes: List[IRCell]) -> List[CellPosition]: + """ + Get multiple node indices, traversing the graph only once + to save time. + + Args: + nodes (List[IRCell]): nodes to be indexed + + Returns: + List[CellPosition]: indices of nodes + """ + visited = 0 + indices = [None] * len(nodes) + def dfs(seg: IRSegment, path: List[int]): + nonlocal visited, indices + for idx, node in enumerate(seg._nodes): + if isinstance(node, IRSegment): + dfs(node, path + [idx]) + elif node in nodes: + indices[nodes.index(node)] = CellPosition(tuple(path + [idx])) + visited += 1 + if visited == len(nodes): + return + dfs(self, []) + if visited != len(nodes): + unvisited = [] + for idx, node in zip(indices, nodes): + if idx is None: + unvisited.append(node) + raise RuntimeError(f"Some of the queried nodes: {unvisited} not in the graph") + return indices + def segment(self, node: IRCell) -> IRCell: """ Get the lowest segment that constains the node From 020eaccea57ba612d24900576763e408ac26fd4d Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 18 Dec 2023 08:38:02 +0000 Subject: [PATCH 1543/1892] Merged PR 1954: add dict.get/torch.clone support add dict.get support uncomment torch.clone --- cube/graph/parser/fx/mapping.py | 6 +-- cube/graph/parser/fx/parser.py | 7 ++-- tests/parallel_module/test_gencode.py | 59 ++++++++++++++++++++++++++- 3 files changed, 65 insertions(+), 7 deletions(-) diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 12852d80..0106188d 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -140,7 +140,7 @@ def exist(signature: str) -> bool: __ttemplate('tensor'): function.NewTensor, __ttemplate('full'): function.Full, __ttemplate('rand'): function.Rand, - # __ttemplate('clone'): function.Clone, + __ttemplate('clone'): function.Clone, '_operator.is_': function.Is, '_operator.is_not': function.IsNot, @@ -153,7 +153,7 @@ def exist(signature: str) -> bool: '_operator.mul': function.Mul, '_operator.imul': function.Mul, # FIXME: may waste memory '_operator.mod': function.Mod, - + __ttemplate('div') : function.Div, __ttemplate('true_divide'): function.Div, '_operator.truediv': function.Div, @@ -178,7 +178,7 @@ def exist(signature: str) -> bool: # __tttemplate('view'): function.View, __tttemplate('contiguous'): function.Contiguous, - + __ttemplate('reshape'): function.Reshape, # # __ttemplate('conv2d'): function.Conv2D, diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 83884da5..84a032ce 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -311,13 +311,14 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> """ The target field of call_method node must be a string. """ - assert isinstance(node_target, str) + if not isinstance(node_target, str): + raise ValueError(f'node_target must be a string, but got {type(node_target)} with value {node_target}') for module, module_name in [(torch, 'torch'), (torch.Tensor, 'torch.Tensor')]: lib_func = getattr(module, node_target, None) if lib_func is not None and callable(lib_func): return f'{module_name}.{node_target}' - assert len(node.args) == 1, f'invalid args {node.args} in {node.name}, {node.target}, {node.meta}' - assert len(node.kwargs) == 0, f'invalid kwargs {node.kwargs} in {node.name}, {node.target}, {node.meta}' + + assert len(node.args) > 0, 'Expect an object as the first argument of call_method' # example node.args[0].meta is {'type': } in_type = node.args[0].meta['type'] assert node_target in in_type().__dir__(), f'node_target = {node_target}, in_type().__dir__() = {in_type().__dir__()}' diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index dd879e91..d4eb9f66 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -4,7 +4,8 @@ import torch import pytest -from cube.parallel import parallelize, ComputeConfig, CubeModule +import cube.graph.function.dimops +from cube.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph from .common import PASData, init_distributed, PASRandomSPMD from ..launch_torchrun import launch_torchrun @@ -477,3 +478,59 @@ def test_codegen_tensor_slice(): load_module=False, reuse='none', ) + + +class DictGetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, batched_data: dict): + data_x = batched_data["x"] + data_y = batched_data.get("y", batched_data['z']) + return data_x + data_y + + +def test_codegen_dictget(): + if not torch.cuda.is_available(): + print('skip test_codegen_dictget due to lack of cuda devices') + return + + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + DictGetModule(), + {'batched_data': { + 'x': torch.tensor([[[1.0], [2.0], [3.0], [6.0]]]), + 'z': torch.tensor([[[1.0], [2.0], [3.0], [6.0]]]) + }}, + PASRandomSPMD, + ComputeConfig(2, 2), + dynamic_shape=True, + cube_savedir=tempdir, + load_module=False, + ) + assert _gencode_contains(tempdir, DictGetModule, 0, r"dict.get\(\w+, 'y', \w+\)") + assert _gencode_contains(tempdir, DictGetModule, 1, r"dict.get\(\w+, 'y', \w+\)") + assert m_new is None + + +class CloneModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.clone() + + +def test_codegen_clone(): + if not torch.cuda.is_available(): + print('skip test_codegen_clone due to lack of cuda devices') + return + + with tempfile.TemporaryDirectory() as tempdir: + g, _ = _gen_graph( + CloneModule(), + {'x': torch.tensor([1.0, 2.0, 3.0, 6.0])}, + tempdir, + True + ) + assert isinstance(g.nodes()[0], cube.graph.function.dimops.IRDimops) From d68d47fb0d8e4ffd393735ec36aba8e799b130b1 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Dec 2023 10:39:36 +0000 Subject: [PATCH 1544/1892] Merged PR 1953: Pipeline Support-0: bug fix on missing reducers for weights across IRSegments This PR fixed a bug for previously missed reducers on weights that are in different IRSegments. The root cause is that each IRSegment hides its weights within the segment, while IRGraph cannot inspect weights from IRSegments' inputs. This fix goes through every IRSegment and collects weights to get the full inspectation and creates reducers if necessary. Passed parity check and unit test. --- cube/graph/gener/gen.py | 198 +++++++++++++------------- cube/graph/segment.py | 30 ++-- tests/graph/gener/test_reducer_gen.py | 111 +++++++++++++++ 3 files changed, 233 insertions(+), 106 deletions(-) create mode 100644 tests/graph/gener/test_reducer_gen.py diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index e471e57f..99e9d2a3 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -192,109 +192,111 @@ def auto_pyfunc(graph: IRGraph): @staticmethod def gen_weight(graph: IRGraph) -> IRGraph: - """ - Generate gradient accumulation - - Only suuport cases that: + """Generate cross-device weight reducers for gradient accumulation. + + If a weight tensor is replicated across multiple devices by different / partitioned operators, + the weight tensor is required to accumulate gradients according to chain rules. + + However, if the weight tensor is replicated across devices by replicated operators, + the weight tensor doesn't need to accumulate gradients. - 1) each sub-tensor weight is consumed by different node cids (no replica) - 2) If the sub-tensor weight is consumed by same replicated node: - The consumers can be grouped by node cids and satisfy: - 1. same number of nodes per cid group - 2. same device set or no-overlapping device set per cid group + Warning: + 1) Each weight tensor's consumers can only be ALL partitioned or ALL replicated. + 2) Weight partitions cannot be partially overlapped. + 3) Limited support for shared weight of multiple operators: + - If operators are on different device group (e.g. pipeline), + operators can only be partitioned. + - If operators are on same device group, + operators can either be all partitioned or all replicated. """ - def check_consistent_local_partition(graph: IRSegment): - """each weight full tensor inside one device should in same format.""" - for ftensor in graph.full_tensors(): - if not ftensor.is_attr(): continue - device_tensors: Dict[int, Set[IRSubTensor]] = {} - for ctensor in graph.ctensors(ftensor): - for devid in ctensor.device: - local_tensors = device_tensors.setdefault(devid, set()) - for t in local_tensors: - assert t == ctensor or not t.overlap(ctensor), ( - f"Detected graph attribute is partitioned with shared part on device {devid}.\n" - f"To achieve this, need call graph.multiref at the front of sProgram.\n" - f"{graph.debug_tensor_map_str(ftensor)}" - ) - local_tensors.add(ctensor) + sub_weights : Dict[IRFullTensor, List[IRSubTensor]] = dict() + sub_weight_consumers: Dict[IRSubTensor, List[IRFwOperation]] = dict() + + def collect_sub_weight(graph: IRSegment): + nonlocal sub_weights, sub_weight_consumers + for ftensor in graph.attributes(): + if not ftensor.is_param(): continue + for ctensor, consumer in zip(graph.ctensors(ftensor), graph.consumers(ftensor)): + if ctensor.grad is None: continue + sub_weight_consumers.setdefault(ctensor, []).append(consumer) + sub_weights.setdefault(ftensor, []).append(ctensor) for segment in graph.select(ntype=IRSegment, flatten=False): if segment.isfw(): - check_consistent_local_partition(segment) - - check_consistent_local_partition(graph) - - # collect subtensor and consumer - fweights: Dict[IRFullTensor, List[IRSubTensor]] = dict() - fgrads: Dict[IRFullTensor, List[IRSubTensor]] = dict() - consumers: Dict[IRFullTensor, List[IRFwOperation]] = dict() - for fnode in graph.nodes(flatten=True): - if not isinstance(fnode, IRFwOperation): continue - assert len(fnode.device) == 1 - for wtensor in fnode.inputs(): - if isinstance(wtensor, IRSubTensor) and wtensor.is_param(): - if wtensor.grad is None: continue - fweight = wtensor.parent - if fweight not in fweights: - fweights[fweight] = [] - fgrads[fweight] = [] - consumers[fweight] = [] - fweights[fweight].append(wtensor) - fgrads[fweight].append(wtensor.grad) - consumers[fweight].append(fnode) - - nl = '\n' - weights: Dict[IRFullTensor, Dict[IRSubTensor, List[int]]] = dict() - for fweight in fweights.keys(): - weights[fweight] = {} - weight_grads: Dict[IRSubTensor, Dict[IRSubTensor, List[IRFwOperation]]] = {} - for weight, grad, consumer in zip(fweights[fweight], fgrads[fweight], consumers[fweight]): - if weight not in weight_grads: - weight_grads[weight] = {} - if grad not in weight_grads[weight]: - weight_grads[weight][grad] = [] - weight_grads[weight][grad].append(consumer) - - # assert all(sw.valmap[1] == len(weight_grads) for sw in weight_grads.keys()) - for sub_weight in weight_grads: - diff_grads = weight_grads[sub_weight] - diff_grads_len = [len(diff_grads[grads]) for grads in diff_grads] - assert all(n == diff_grads_len[0] for n in diff_grads_len), ( - f"If one of the weight consumers are replicated, " - f"other same-weight consumers should also replicated in same way." - f"FullTensor Weight: {fweight}\n" - f"Consumers:\n{nl.join([repr(node) for node in consumers[fweight]])}" + collect_sub_weight(segment) + + collect_sub_weight(graph) + + # check consistency in node replicate or node partition + replicated = [] + for sub_weight, consumers in sub_weight_consumers.items(): + # suppose a weight is originally shared by 2 operators op1 and op2, + # each operator is replicated on a same device group (e.g., rank 0 and rank 1). + # then the device 0 has (op1, op2) and device 1 also has (op1, op2). + # we don't need to accumulate gradients for the weight in this case. + # this case can be checked by whether each device has same consumer set. + dev_cids = dict() + for consumer in consumers: + dev_cids.setdefault(consumer.device[0], []).append(consumer.cid) + dev_cids = [tuple(sorted(cids)) for cids in dev_cids.values()] + cross_device_replicated = all(cids == dev_cids[0] for cids in dev_cids) + + # otherwise, we only support fully partitioned consumers, + # the weight's gradient should be accumulated. + fully_partitioned = len(set(c.cid for c in consumers)) == len(consumers) + + if not (cross_device_replicated or fully_partitioned): + nl = '\n' + raise RuntimeError( + f"The weight consumers can either be ALL replicated or ALL partitioned. " + f"Detected some consumers are replicated and some are partitioned.\n" + f"FullTensor weight: {sub_weight.parent}\n" + f"Consumers:\n{nl.join([repr(n) for n in consumers])}\n" ) - # get devices - devices = [] - for sub_grad in diff_grads: - sub_grad_devices = [node.device[0] for node in diff_grads[sub_grad]] - sub_grad_devices.sort() - devices.append(sub_grad_devices) - devices = np.array(devices, dtype=int).transpose((1, 0)) - for group_devices in devices: - group_devices = set(int(devid) for devid in group_devices) - group_devices = list(group_devices) - group_devices.sort() - weights[fweight][sub_weight] = group_devices - - reducers: Dict[Tuple[int], List[IRSubTensor]] = dict() - for subtensors in weights.values(): - for subw in subtensors: - if len(subtensors[subw]) == 1: - continue - devices = list(subtensors[subw]) - devices.sort() - devices = tuple(devices) - if devices not in reducers: - reducers[devices] = [] - reducers[devices].append(subw) - # generate reducer for each rank - for devices in reducers: - weights = reducers[devices] - opt_op = IRWeightReducer(weights) - opt_op.device = list(devices) - graph.insert(opt_op, graph.nnodes) + if cross_device_replicated == 1: # replicated weights + replicated.append(sub_weight) + # check consistency in weight partition + # note we don't support sub-weight tensors with partially shared part. + # This is because the shared part may require reducer to accumulate gradients only for the + # shared part, requiring a more fine-grained tensor granularity. + # However, we don't support such fine-grained accumulation for now, and we only support + # to either accumulate same sub-weight tensors or not accumulate non-overlapped sub-weight tensors. + for ftensor, sub_ws in sub_weights.items(): + # all the sub weights can only be + # 1) replicated (sw1 == sw2) or, + # 2) partitioned without overlapping (not sw1.overlap(sw2)) + for sw1, sw2 in itertools.combinations(sub_ws, 2): + if not (sw1 == sw2 or not sw1.overlap(sw2)): + nl = '\n' + raise RuntimeError( + f"Detected a weight is partitioned with partially shared part among its sub-tensors.\n" + f"To achieve this, users need to call `graph.multiref(weight)` inside the policy.\n" + f"FullTensor weight: {ftensor}\n" + f"Consumers:\n{nl.join([repr(w.cell) for w in sub_ws])}\n" + ) + + # only record sub-weight that is consumed by multiple devices + sub_weight_devices: Dict[IRSubTensor, Tuple[int,]] = dict() + # - pop out replicated sub weights as they will have full gradients, + # no need for reducer. + for sub_weight in replicated: + del sub_weight_consumers[sub_weight] + # - gather sub weights that are consumed by same device groups + for sub_weight, consumers in sub_weight_consumers.items(): + devices = set(consumer.device[0] for consumer in consumers) + if len(devices) > 1: + devices = tuple(sorted(devices)) + sub_weight_devices[sub_weight] = devices + + # create reducer + reducers: Dict[Tuple[int,], List[IRSubTensor]] = dict() + for subw, devices in sub_weight_devices.items(): + reducers.setdefault(devices, []).append(subw) + for devices, subws in reducers.items(): + reducer = IRWeightReducer(subws) + reducer.device = devices + # insert reducer to as the last node. + graph.insert(reducer, graph.nnodes) + return graph @staticmethod diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 5de0081a..c56ac087 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -288,22 +288,36 @@ def consumers(self, ftensor: IRFullTensor) -> Tuple[IRCell]: return tuple(self._consumers.get(ftensor, ())) def ptensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: - """ - Get consumed sub-tensors of ftensor in execution order in this graph + """Get produced sub-tensors of a full tensor (ftensor). - @param ftensor IRFullTensor: the queried full tensor. + A full tensor (ftensor) is originally produced by some operator(s). + These operators can be further partitioned into multiple sub-operators. + Each sub-operator potentially produces a smaller part of the ftensor (a.k.a. sub-tensor). + This function returns all the sub-tensors that are produced by operators + inside the segment. - @return subtensors Tuple[IRSubTensor]: the consumed subtensors. + Args: + ftensor (IRFullTensor): the queried full tensor. + + Returns: + Tuple[IRSubTensor]: the produced sub-tensors. """ return tuple(self._ptensors.get(ftensor, ())) def ctensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: - """ - Get consumed sub-tensors of ftensor in execution order in this graph + """Get consumed sub-tensors of a full tensor (ftensor) - @param ftensor IRFullTensor: the queried full tensor. + A full tensor (ftensor) is originally consumed by some operator(s). + These operators can be further partitioned into multiple sub-operators. + Each sub-operator potentially consumes a smaller part of the ftensor (a.k.a. sub-tensor). + This function returns all the sub-tensors that are consumed by operators + inside the segment. - @return subtensors Tuple[IRSubTensor]: the consumed subtensors. + Args: + ftensor (IRFullTensor): the queried full tensor. + + Returns: + Tuple[IRSubTensor]: the consumed sub-tensors. """ return tuple(self._ctensors.get(ftensor, ())) diff --git a/tests/graph/gener/test_reducer_gen.py b/tests/graph/gener/test_reducer_gen.py new file mode 100644 index 00000000..104d1271 --- /dev/null +++ b/tests/graph/gener/test_reducer_gen.py @@ -0,0 +1,111 @@ +import pytest +from cube.graph.gener.gen import IRAdapterGener + +from cube.graph import IRGraph +from cube.graph.segment import IRSegment +from cube.graph.parser.converter import convert_model +from cube.ir.operator import IRFwOperation +from cube.ir.tensor import IRFullTensor +from cube.ir.adapter import IRWeightReducer + +import torch +import tempfile + + +def make_param(shape, dtype) -> IRFullTensor: + param = IRFullTensor(shape=shape, dtype=dtype, requires_grad=True) + param.as_param() + return param + + +class ReducerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.param1 = torch.nn.Parameter(torch.zeros([128, 128], dtype=torch.float16)) + self.param2 = torch.nn.Parameter(torch.zeros([128, 128], dtype=torch.float16)) + + def forward(self, x): + x = torch.matmul(x, self.param1) + x = torch.matmul(x, self.param2) + x = x + self.param1 + x = torch.sum(x) + return x + + +def build_graph(): + # build graph + model = ReducerModule() + with tempfile.TemporaryDirectory() as tempdir: + graph = convert_model( + model, + {'x': torch.randn([128, 128], dtype=torch.float16)}, + attr_savedir=tempdir, + dynamic_shape=False + ) + graph.backward(graph.output(0)) + return graph + + +def test_cross_segment_weight_reducer(): + + graph = build_graph() + [matmul1, matmul2, add, sum] = graph.select(ntype=IRFwOperation) + graph.group([matmul1, matmul2]) + graph.group([add, sum]) + + for idx, segment in enumerate(graph.select(ntype=IRSegment, flatten=False)): + if not segment.isfw(): + continue + for node in segment.nodes(): + graph.assign(node, idx) + + print(graph.extra_repr()) + + # build reducer + graph = IRAdapterGener.gen_weight(graph) + print(graph.extra_repr()) + reducers = graph.select(ntype=IRWeightReducer) + assert len(reducers) == 1 + assert len(reducers[0].inputs()) == 1 + assert reducers[0].input(0) == matmul1.input(1) + assert reducers[0].device == (0, 1) + + +def test_replicate_shared_param(): + + graph = build_graph() + for node in graph.select(ntype=IRFwOperation): + sn1, sn2 = graph.replicate(node, 2) + graph.assign(sn1, 0) + graph.assign(sn2, 1) + + graph = IRAdapterGener.gen_weight(graph) + print(graph.extra_repr()) + + reducers = graph.select(ntype=IRWeightReducer) + assert len(reducers) == 0 + + +def test_reducer_partially_shared_part(): + graph = build_graph() + [matmul1, matmul2, add, sum] = graph.select(ntype=IRFwOperation) + + m1, m2 = graph.partition(matmul1, matmul1.algorithms('dim'), idx=0, dim=1, num=2) + graph.assign(m1, 0) + graph.assign(m2, 1) + + add1, add2 = graph.partition(add, add.algorithms('dim'), idx=0, dim=1, num=2) + graph.assign(add1, 0) + graph.assign(add2, 1) + + for node in [matmul2, sum]: + sn1, sn2 = graph.replicate(node, 2) + graph.assign(sn1, 0) + graph.assign(sn2, 1) + + print(graph.extra_repr()) + + with pytest.raises(RuntimeError): + graph = IRAdapterGener.gen_weight(graph) + print(graph.extra_repr()) From a42040a7918190007da39099030396994fc8a5fd Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 20 Dec 2023 06:30:00 +0000 Subject: [PATCH 1545/1892] Merged PR 1948: add IROject support to expand Really not easy to implement. --- cube/graph/function/function.py | 120 +++++++++++++++---------- tests/graph/function/test_functions.py | 49 +++++++++- 2 files changed, 121 insertions(+), 48 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index acbf36e9..6eb2d6b8 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -18,6 +18,33 @@ _logger = logging.getLogger(__name__) +# If the type is IROject, then value should be type of int, Tuple[int], List[int] +# If the type is Tuple[IROject] or List[IRObject], then the value of each element should be type of int +_VariadicInt = Union[int, Tuple[int, ...], List[int], IRObject, Tuple[IRObject, ...], List[IRObject]] + +def extract_variadic(v: _VariadicInt) -> Tuple[List[int], List[bool]]: + if isinstance(v, int): + if isinstance(v, bool): + raise ValueError("Unsupported type: bool") + return [v], [False] + elif isinstance(v, IRObject): + r = extract_variadic(v.value) + return r[0], [True] * len(r[0]) # because all elements are from IRObject + elif isinstance(v, (tuple, list)): + r = [extract_variadic(e) for e in v] + if any(len(x[0]) != 1 for x in r): + raise ValueError("tuple/list can't be nested") + return [x[0][0] for x in r], [x[1][0] for x in r] + else: + raise ValueError(f"Unsupported type: {type(v)}") + + +def is_list_or_tuple(v: Any) -> bool: + return isinstance(v, (list, tuple)) or ( + isinstance(v, IRObject) and isinstance(v.value, (list, tuple)) + ) + + def Identity(tensor: IRObject, signature = None): signature = 'cube.runtime.function.identity' eshape = ShapeAnno.create_shape_str(tensor.shape) @@ -336,38 +363,60 @@ def _handle_broadcast_multi(ins_list: List[IRTensor]) -> Tuple[Tuple[List[str]], return ins_anno, out_anno -def Expand(input, *sizes, size = None, signature = None): +def Expand(input, size, *arg_size, signature = None): """ torch.Tensor.expand(*sizes) - - The reason of add ``size`` to this function argument is: - 1. ``sizes`` need to reuse in IRDimops.new(), but it is a ``non-keyword arguments``, - and can not put it into keyword arguments (something like Expand(input, sizes=[1, 2, 3])) is not work, - to support IRDimops.new API, here add a ``size`` to workaround. - - 2. in torch._C.expand API, it has: - def expand(self, size: Sequence[Union[_int, SymInt]], *, implicit: _bool=False) -> Tensor: ... - so add ``size`` can also solve user using something like: - torch.rand(3, 1).expand(size=(3, 3)) """ signature = 'torch.Tensor.expand' - if size is not None: - assert len(sizes) == 0 - sizes = size - ori_len, exp_len = len(input.shape), len(sizes) - assert ori_len <= exp_len - assert all(dim == expand_dim or dim == 1 or expand_dim == -1 for dim, expand_dim in zip(input.shape, sizes[-ori_len:])) - edim_ou = ShapeAnno.create_shape_str(sizes) + if is_list_or_tuple(size): + # follow the behavior of torch.Tensor.Expand, + if arg_size: + raise ValueError(f"arg_size should not be provided when size is a list or tuple") + complete_size = size + else: + # follow the behavior of torch.Tensor.Expand, + if any(is_list_or_tuple(s) for s in arg_size): + raise ValueError(f"list or tuple should not be provided in arg_size") + complete_size = (size,) + arg_size + + size, size_is_ir = extract_variadic(complete_size) + + ori_len, exp_len = len(input.shape), len(size) + if ori_len > exp_len: + raise ValueError(f"Less dimensions than input is provided. input dims: {ori_len}, sizes: {exp_len}") + if not all(dim == expand_dim or dim == 1 or expand_dim == -1 for dim, expand_dim in zip(input.shape, size[-ori_len:])): + raise ValueError(f"The expanded size of the tensor ({size}) must match the existing size ({input.shape})") + edim_ou = ShapeAnno.create_shape_str(size) edim_in = copy.copy(edim_ou[-ori_len:]) - new_size = [-1] * len(sizes) - for idx, (dim, expand_dim) in enumerate(zip(input.shape, sizes[-len(input.shape):])): - if dim == 1 and dim != expand_dim and expand_dim != -1: + # we must use -1 to represent the dimension that will not be expanded + # Otherwise, splitting on that dimension will be wrong + new_size = [-1] * len(size) + for idx, (dim, expand_dim, expand_dim_is_ir) in enumerate(zip(input.shape, size[-ori_len:], size_is_ir[-ori_len:])): + # when dynamic shape is enable, the dim may change in runtime + # so we can't assume the dim is 1 for sure even if it is 1 in tracing + # If we assume the user code is correct + # 1. if expand_dim is from IRObject, for safety, we don't allow partition + # 2. if expand_dim is not from IRObject, and dim > 1, dimension is partitionable. + # 3. If it is 1 in tracing and exapnd_dim is not from IRObject + # 3.1 if expand_dim is -1, we allow partition on this dimension + # For example, in runtime, (dim, expand_dim) can be (2, -1) or (3, -1) or (4,-1), will not trigger error on partition + # 3.2 if expand_dim is fixed 1, then dim must be 1 to make it valid op. + # partition on this dimension is not useful (both is OK, here we disable partition for this case) + # 3.3 if expand_dim is fixed x > 1, then in runtime dim can be 1 or x + # For example, in runtime, (dim, expand_dim) can be (1, x) or (x, x), will trigger error on partition + if expand_dim_is_ir or (dim == 1 and expand_dim != -1): + new_dim = dim if expand_dim == -1 else expand_dim edim_in[idx] += '^' - edim_ou[exp_len - ori_len + idx] = str(expand_dim) - new_size[exp_len - ori_len + idx] = expand_dim + # keep anno id only if expand_dim == -1 + if expand_dim == -1: + edim_ou[exp_len - ori_len + idx] = edim_in[idx] + else: + edim_ou[exp_len - ori_len + idx] = str(new_dim) + # explicit set tid to -1 to avoid changing IDGenerator state. + new_size[exp_len - ori_len + idx] = IRObject(tid=-1, value=expand_dim) if expand_dim_is_ir else new_dim for idx in range(exp_len - ori_len): - edim_ou[idx] = str(sizes[idx]) - new_size[idx] = sizes[idx] + edim_ou[idx] = str(size[idx]) + new_size[idx] = size[idx] anno = OpAnno.create_op_str([edim_in], [edim_ou]) return IRDimops(Expand, 'expand', signature, [anno], [input], size=new_size) @@ -1558,27 +1607,6 @@ def SelectScatter(self: torch.Tensor, input: torch.Tensor, dim: int, index: int, [anno], [self, input], dim=dim, index=index) -# If the type is IROject, then value should be type of int, Tuple[int], List[int] -# If the type is Tuple[IROject] or List[IRObject], then the value of each element should be type of int -_VariadicInt = Union[int, Tuple[int, ...], List[int], IRObject, Tuple[IRObject, ...], List[IRObject]] - -def extract_variadic(v: _VariadicInt) -> Tuple[List[int], List[bool]]: - if isinstance(v, int): - if isinstance(v, bool): - raise ValueError("Unsupported type: bool") - return [v], [False] - elif isinstance(v, IRObject): - r = extract_variadic(v.value) - return r[0], [True] * len(r[0]) # because all elements are from IRObject - elif isinstance(v, (tuple, list)): - r = [extract_variadic(e) for e in v] - if any(len(x[0]) != 1 for x in r): - raise ValueError("tuple/list can't be nested") - return [x[0][0] for x in r], [x[1][0] for x in r] - else: - raise ValueError(f"Unsupported type: {type(v)}") - - def Repeat(tensor, repeats: _VariadicInt, *arg_repeats, signature = None): """ torch.Tensor.repeat(*sizes) diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index db15ad6d..3d59df47 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -7,6 +7,14 @@ import torch +def o(value): + return IRObject(value=value) + + +def assert_anno(op, expected): + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected + + def test_handle_broadcast_multi(): ins_anno, out_anno = F._handle_broadcast_multi([IRTensor([4]), IRTensor([3, 4]), IRTensor([2, 3, 4])]) assert ins_anno[0] == ['c'] @@ -34,10 +42,49 @@ def test_Expand(): out = IRTensor([10, 2]) op = F.Expand(inp, 10, 2) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ -> a 2' + assert op.kwargs['size'] == [-1, 2] op.new([inp], [out], size=[10, 2]) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ -> a 2' + with pytest.raises(ValueError): + F.Expand(inp, (10, 2), 1) + + with pytest.raises(ValueError): + F.Expand(inp, 1, (10, 2)) + + with pytest.raises(ValueError): + F.Expand(inp, 1) + + with pytest.raises(ValueError): + F.Expand(inp, (5, 2)) + + op = F.Expand(inp, -1, o(1)) + assert_anno(op, 'a b^ -> a 1') + assert op.kwargs['size'][0] == -1 + assert op.kwargs['size'][1].value == 1 + + op = F.Expand(inp, -1, 1) + assert_anno(op, 'a b^ -> a 1') + assert op.kwargs['size'] == [-1, 1] + + assert_anno(F.Expand(inp, -1, o(2)), 'a b^ -> a 2') + assert_anno(F.Expand(inp, o((10, 2))), 'a^ b^ -> 10 2') + + op = F.Expand(inp, o(10), o(2)) + assert_anno(op, 'a^ b^ -> 10 2') + assert op.kwargs['size'][0].value == 10 + assert op.kwargs['size'][1].value == 2 + + op = F.Expand(inp, o(10), o(-1)) + assert_anno(op, 'a^ b^ -> 10 b^') + assert op.kwargs['size'][0].value == 10 + assert op.kwargs['size'][1].value == -1 + + op = F.Expand(inp, 10, 10, 2) + assert_anno(op, 'b c^ -> 10 b 2') + assert op.kwargs['size'] == [10, -1, 2] + def test_variadic_extraction(): def o(value): @@ -70,8 +117,6 @@ def o(value): def test_Repeat(): - def o(value): - return IRObject(value=value) inp = IRTensor([3]) op = F.Repeat(inp, (4, 2)) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b^ -> 4 (2 b^)' From 0be2ac2a7f10cfa33ce551b6fb91cc86d0284036 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 21 Dec 2023 02:24:16 +0000 Subject: [PATCH 1546/1892] Merged PR 1955: test mock: test trace without gpu test mock: trace without gpu --- tests/parallel_module/test_gencode.py | 57 +++++-------------- tests/utils.py | 80 ++++++++++++++++++++++++++- 2 files changed, 92 insertions(+), 45 deletions(-) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index d4eb9f66..b0f9e597 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -9,6 +9,7 @@ from .common import PASData, init_distributed, PASRandomSPMD from ..launch_torchrun import launch_torchrun +from ..utils import replace_all_device_with def _to_cube_model(module, compute_config, cube_savedir, load_module): return parallelize( @@ -55,13 +56,8 @@ def __init__(self): def forward(self, x): return x[:2] +@replace_all_device_with('meta') def test_codegen_slice(): - """ - Test it can support modules without parameters - """ - if not torch.cuda.is_available(): - print('skip test_codegen_slice due to lack of cuda devices') - return with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( SliceModule(), @@ -83,14 +79,9 @@ def __init__(self): def forward(self, x, y, *args): return self.linear(x) + y -def test_codegen_args(): - """ - Verify that unused args are supported by parallalize - """ - if not torch.cuda.is_available(): - print('skip test_codegen_args due to lack of cuda devices') - return +@replace_all_device_with('meta') +def test_codegen_args(): with tempfile.TemporaryDirectory() as tempdir: # *args is not supported. with pytest.raises(RuntimeError): @@ -240,11 +231,8 @@ def __init__(self) -> None: self.a = 2.0 +@replace_all_device_with('meta') def test_codegen_attr(): - if not torch.cuda.is_available(): - print('skip test_codegen_attr due to lack of cuda devices') - return - with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( AttrModule(), @@ -276,11 +264,8 @@ def forward(self, batched_data): return padding_mask +@replace_all_device_with('meta') def test_codegen_getitem(): - if not torch.cuda.is_available(): - print('skip test_codegen_getitem due to lack of cuda devices') - return - with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( GetItemModule(), @@ -308,13 +293,8 @@ def forward(self, x): return self.linear(x) + 1 +@replace_all_device_with('meta') def test_codegen_training_flag(): - """ - Test it can support modules without parameters - """ - if not torch.cuda.is_available(): - print('skip test_codegen_training_flag due to lack of cuda devices') - return with tempfile.TemporaryDirectory() as tempdir: m = TrainingModule() m.train() @@ -370,13 +350,11 @@ def forward(self, x): return x +@replace_all_device_with('meta') def test_codegen_iter(): """ Test it can support modules without parameters """ - if not torch.cuda.is_available(): - print('skip test_codegen_iter due to lack of cuda devices') - return with tempfile.TemporaryDirectory() as tempdir: m = IterModule() m.train() @@ -404,13 +382,11 @@ def forward(self, x): return x +@replace_all_device_with('meta') def test_codegen_const(): """ Test it can support modules without parameters """ - if not torch.cuda.is_available(): - print('skip test_codegen_const due to lack of cuda devices') - return with tempfile.TemporaryDirectory() as tempdir: m = ConstantModule() m.train() @@ -448,10 +424,9 @@ def forward(self, x): return x[:, :padding] +# torch.count_nonzero is not supported on meta +@replace_all_device_with('cpu') def test_codegen_tensor_slice(): - if not torch.cuda.is_available(): - print('skip test_codegen_tensor_slice due to lack of cuda devices') - return with tempfile.TemporaryDirectory() as tempdir: m = TensorSliceModule() m.train() @@ -490,11 +465,8 @@ def forward(self, batched_data: dict): return data_x + data_y +@replace_all_device_with('meta') def test_codegen_dictget(): - if not torch.cuda.is_available(): - print('skip test_codegen_dictget due to lack of cuda devices') - return - with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( DictGetModule(), @@ -521,11 +493,8 @@ def forward(self, x): return x.clone() +@replace_all_device_with('meta') def test_codegen_clone(): - if not torch.cuda.is_available(): - print('skip test_codegen_clone due to lack of cuda devices') - return - with tempfile.TemporaryDirectory() as tempdir: g, _ = _gen_graph( CloneModule(), diff --git a/tests/utils.py b/tests/utils.py index 0f9ae486..26553254 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import Callable import torch import math @@ -18,7 +19,7 @@ def trunc_normal_(tensor: torch.Tensor, mean: float = 0., std: float = 1., a: fl tensor.add_(mean) tensor.clamp_(min=a, max=b) return tensor - + torch.random.manual_seed(seed) random.seed(seed) @@ -71,3 +72,80 @@ def assert_same_complex(gt, out): assert gt == out, f'mismatched: {gt} != {out}' assert_same_complex(baseline_outputs, compile_outputs) return None + + +@contextmanager +def replace_all_device_with(device='cpu'): + from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import ConcreteTracer + + orig_to = torch.Tensor.to + orig_cuda = torch.Tensor.cuda + orig_cpu = torch.Tensor.cpu + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] =device + return fn(*args, **kwargs) + wrapper.__name__ = fn.__name__ + wrapper.__qualname__ = fn.__qualname__ + return wrapper + # these constructors are enough for most cases + patched_tensor_constructors = [ + 'empty', 'zeros', 'ones', 'full', 'eye', + 'linspace', 'logspace', 'arange', + 'rand', 'randn', 'randint', 'randperm', + 'randn_like', 'rand_like', 'randint_like', + 'tensor' + ] + old_tensor_constructors = { + tf_name: getattr(torch, tf_name) + for tf_name in patched_tensor_constructors + } + patched_tensor_constructors = { + tf_name: patch_tensor_constructor(fn) + for tf_name, fn in old_tensor_constructors.items() + } + + def patched_to(self, *args, **kwargs): + if len(args) > 0 and isinstance(args[0], (torch.device, str)): + args[0] = device + return orig_to(self, *args, **kwargs) + if 'device' in kwargs: + kwargs['device'] = device + return orig_to(self, *args, **kwargs) + return orig_to(self, *args, **kwargs) + + def patched_cuda(self, *args, **kwargs): + return orig_to(self, device) + + def patched_cpu(self, *args, **kwargs): + return orig_to(self, device) + + try: + torch.Tensor.to = patched_to + torch.Tensor.cuda = patched_cuda + torch.Tensor.cpu = patched_cpu + # patch tensor constructors + for tf_name, fn in old_tensor_constructors.items(): + setattr(torch, tf_name, patched_tensor_constructors[tf_name]) + + # patch concrete tracer's autowrap leaf function + for tf_name, fn in old_tensor_constructors.items(): + leaf_info = ConcreteTracer.default_autowrap_leaf_function.pop(fn, None) + if leaf_info: + ConcreteTracer.default_autowrap_leaf_function[ + patched_tensor_constructors[tf_name] + ] = leaf_info + yield + finally: + for tf_name, fn in patched_tensor_constructors.items(): + leaf_info = ConcreteTracer.default_autowrap_leaf_function.pop(fn, None) + if leaf_info: + ConcreteTracer.default_autowrap_leaf_function[ + old_tensor_constructors[tf_name] + ] = leaf_info + for tf_name, fn in old_tensor_constructors.items(): + setattr(torch, tf_name, fn) + torch.Tensor.to = orig_to + torch.Tensor.cuda = orig_cuda + torch.Tensor.cpu = orig_cpu From 22d18177691e7b6a7b5189ca091a7608d41369be Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 22 Dec 2023 01:49:08 +0000 Subject: [PATCH 1547/1892] Merged PR 1961: hot fix unit test break for meta device when cuda is available. It is a Pytorch bug, and will trigger error sometimes: ``` RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "../torch/csrc/autograd/functions/utils.h":75, please report a bug to PyTorch. ``` if we use meta device when cuda is available. --- tests/parallel_module/test_gencode.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index b0f9e597..ad85c11f 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -56,7 +56,7 @@ def __init__(self): def forward(self, x): return x[:2] -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_slice(): with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( @@ -80,7 +80,7 @@ def forward(self, x, y, *args): return self.linear(x) + y -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_args(): with tempfile.TemporaryDirectory() as tempdir: # *args is not supported. @@ -231,7 +231,7 @@ def __init__(self) -> None: self.a = 2.0 -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_attr(): with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( @@ -264,7 +264,7 @@ def forward(self, batched_data): return padding_mask -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_getitem(): with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( @@ -293,7 +293,7 @@ def forward(self, x): return self.linear(x) + 1 -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_training_flag(): with tempfile.TemporaryDirectory() as tempdir: m = TrainingModule() @@ -350,7 +350,7 @@ def forward(self, x): return x -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_iter(): """ Test it can support modules without parameters @@ -382,7 +382,7 @@ def forward(self, x): return x -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_const(): """ Test it can support modules without parameters @@ -424,7 +424,6 @@ def forward(self, x): return x[:, :padding] -# torch.count_nonzero is not supported on meta @replace_all_device_with('cpu') def test_codegen_tensor_slice(): with tempfile.TemporaryDirectory() as tempdir: @@ -465,7 +464,7 @@ def forward(self, batched_data: dict): return data_x + data_y -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_dictget(): with tempfile.TemporaryDirectory() as tempdir: m_new = parallelize( @@ -493,7 +492,7 @@ def forward(self, x): return x.clone() -@replace_all_device_with('meta') +@replace_all_device_with('cpu') def test_codegen_clone(): with tempfile.TemporaryDirectory() as tempdir: g, _ = _gen_graph( From 0e29976e042777a83eb2cb4dbc9113aed97860ee Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 25 Dec 2023 04:30:27 +0000 Subject: [PATCH 1548/1892] Merged PR 1959: Pipeline Support-1: Dummy dataloader for micro-batches within a mini-batch The feature supports wrapping micro-batch samples into a mini-batch dataloader The code will be: * Compile Phase (1-microbatch view): ```python from cube.runtime.utils import microbatches model = Model() samples = [next(iter(dataloader)),]. # 1 sample minibatch_dataloader = microbatches(samples) @cube.compile(model, minibatch_dataloader, policy=xxx) def train_iter(model, dataloader): # 1-microbatch execution data1, data2 = next(dataloader) loss = model(data1, data2) loss. Backward() return loss ``` * Training Phase ```python def train_step(samples): # samples contain all micro-batch datas minibatch_dataloader = microbatches(samples) loss = train_iter(model, minibatch_dataloader) ``` Note the policy will compile a single micro-batch execution plan into multiple micro-batch execution plans (not implemented yet). The original interface (`train_iter(model, data1, data2, data3)` is still supported) Parity check and unit test passed --- cube/compiler.py | 3 +- cube/program.py | 11 ++- cube/runtime/utils.py | 117 ++++++++++++++++++------ examples/mlp/train.py | 20 ++-- examples/nlp/gpt/model.py | 10 +- examples/nlp/gpt/train.py | 14 ++- examples/nlp/mbart/model.py | 11 +-- examples/nlp/mbart/train.py | 11 ++- examples/vision/swin/model.py | 9 +- examples/vision/swin/train.py | 13 ++- tests/graph/function/test_dataloader.py | 16 ---- tests/runtime/test_dataloader.py | 50 ++++++++++ 12 files changed, 200 insertions(+), 85 deletions(-) create mode 100644 tests/runtime/test_dataloader.py diff --git a/cube/compiler.py b/cube/compiler.py index afa075c7..28d39207 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -24,6 +24,7 @@ from cube.codegen import ModuleCodeGen, ScheduleCodeGen from cube.runtime.device import DeviceGroup +from cube.runtime.utils import MicroBatchDataLoader from cube.program import Program, SemanticDataLoader, SemanticModel from cube.flags import CompileFlag @@ -93,7 +94,7 @@ def train_iter(model, dataloader): inputs = [model] for arg in args: assert not isinstance(arg, (torch.nn.Module, SemanticModel)), f"Only one model can be input for compile" - if isinstance(arg, torch.utils.data.DataLoader): + if isinstance(arg, MicroBatchDataLoader): arg = SemanticDataLoader(arg) elif isinstance(arg, torch.Tensor): tensor = arg diff --git a/cube/program.py b/cube/program.py index f1d7c2f6..b9f97ba4 100644 --- a/cube/program.py +++ b/cube/program.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Optional, Any, Dict +from typing import List, Tuple, Optional, Any, Dict, Union import inspect from cube.ir.cten import IRCell, IRObject @@ -10,6 +10,7 @@ from cube.runtime.module import CubeModule from cube.runtime.device import DeviceGroup +from cube.runtime.utils import MicroBatchDataLoader from cube.utils import load_model @@ -82,16 +83,16 @@ def __repr__(self): class SemanticDataLoader: - def __init__(self, dataloader: data.DataLoader): + def __init__(self, dataloader: MicroBatchDataLoader): """ Create semantic dataloader which will produces IRDataOperation when calling `next`. Args: - dataloader (torch.utils.data.DataLoader): torch dataloader + dataloader (MicroBatchDataLoader): torch dataloader """ - if not isinstance(dataloader, data.DataLoader): - raise TypeError("Expected data loader derived from torch.utils.data.DataLoader") + if not isinstance(dataloader, MicroBatchDataLoader): + raise TypeError("Expected data loader to be MicroBatchDataLoader") self.dataloader: data.DataLoader = dataloader self.object = IRObject(name='dataloader', value=None) diff --git a/cube/runtime/utils.py b/cube/runtime/utils.py index 828ffa2a..ca595ac3 100644 --- a/cube/runtime/utils.py +++ b/cube/runtime/utils.py @@ -1,46 +1,103 @@ r"""Runtime Utilities""" -from typing import Any +from typing import Any, List import logging -import torch.utils.data as data - _logger = logging.getLogger(__name__) -def create_dummy_dataloader(sample: Any, - batch_size: int, drop_last=True, - **dataloader_config) -> data.DataLoader: - """Create a dummy dataloader +class MicroBatchDataLoader: + """ + MicroBatchDataLoader is used for scenarios of gradient accumulation, + where a training iteration will have multiple data samples and perform + multiple forward and backward on each sample (i.e., each refers to + as a micro-batch). + + To support more flexible training patterns, e.g., pipeline parallelism, + MicroBatchDataLoader supports wrapping all data samples of a training iteration + into a light dataloader and passed as input for compilation. - The function is mainly used for performance test. + e.g., + + ```python + # compilation phase + dataloader = MicroBatchDataLoader([(input1,),]) # only need one micro-batch - Args: - sample (Any): a data sample without batch size dimension. - The sample can be a single tensor/object or tuple/list of tensors/objects - batch_size (int): batch size - drop_last (bool): whether to drop last batch to make batch size consistent. - dataloader_config (dict): kwargs for dataloader initialization. + @cube.compile(model, dataloader, ...) + def train_iter(model, dataloader): + input1 = next(dataloader) + loss = model(input1) + loss.backward() + return loss - Returns: - dataloader (torch.utils.data.DataLoader): - returns + ... + + # runtime phase + + for mini_batch_samples in iter(dataloader): + # mini_batch_samples are sample list for + # all micro-batches in one iteration. + dl = MicroBatchDataLoader(mini_batch_samples) + loss =train_iter(model, dl) + ... + ``` """ - class DummyDataset(data.Dataset): + def __init__(self, samples: List[Any], cycle: bool = False): + """Create a micro-batch data loader for a mini-batch. + + Args: + samples (List[Any]): a list of micro-batch samples. Each element + in the list is a micro-batch sample. + cycle (bool): whether to cycle the micro-batch samples. If True, + the micro-batch samples will be cycled infinitely. Note this + is only needed when the number of micro-batch samples is less + than expected micro-batch number during runtime. + """ - def __init__(self, sample: Any): + if not isinstance(samples, (tuple, list)): + raise TypeError("Samples must be a tuple or list of samples.") + self.samples = samples + self.nmicros = len(samples) + self.cycle = cycle + self._idx = 0 - self.sample = sample + def __iter__(self): + self._idx = 0 + return self + + def __next__(self): + if self._idx == self.nmicros: + raise StopIteration + batch = self.samples[self._idx] + self._idx += 1 + if self.cycle: + self._idx = self._idx % self.nmicros + return batch + + def __len__(self): + return self.nmicros + + def get_micro_batch(self, idx: int): + idx = idx % self.nmicros if self.cycle else idx + return self.samples[idx] + + +def microbatches(samples: List[Any], cycle: bool = False) -> MicroBatchDataLoader: + """Create a micro-batch data loader for a mini-batch. - def __len__(self): - return 1024000 - - def __getitem__(self, key: int): - return self.sample + This is for gradient accumulation scenarios. More details refer to + documents of MicroBatchDataLoader. - dataset = DummyDataset(sample) - dataloader = data.DataLoader( - dataset, batch_size=batch_size, drop_last=drop_last, - **dataloader_config) - return dataloader + Args: + samples (List[Any]): a list of micro-batch samples. Each element + in the list is a micro-batch sample. + cycle (bool): whether to cycle the micro-batch samples. If True, + the micro-batch samples will be cycled infinitely. Note this + is only needed when the number of micro-batch samples is less + than expected micro-batch number during runtime. + + Returns: + MicroBatchDataLoader: a micro-batch data loader. + """ + return MicroBatchDataLoader(samples, cycle=cycle) diff --git a/examples/mlp/train.py b/examples/mlp/train.py index b4073ed3..4df366a1 100644 --- a/examples/mlp/train.py +++ b/examples/mlp/train.py @@ -12,7 +12,8 @@ import cube from cube.profiler import CudaTimer from cube.profiler.timer import print_each_rank -from cube.runtime.utils import create_dummy_dataloader +from cube.runtime.utils import microbatches + import examples.mlp.policy.gallery as gallery from examples.utils import get_policy @@ -50,14 +51,16 @@ def forward(self, data): loss = torch.sum(x) return loss +def dummy_data(): + return torch.randn( + args.mbs, args.dim, device=torch.cuda.current_device()) + def train(): model = MLP(dim=args.dim, nlayers=args.layers) - dataloader = create_dummy_dataloader( - torch.randn(args.dim, device=torch.cuda.current_device()), - args.mbs, - ) + # create dummy data + dataloader = microbatches((dummy_data(),)) # compile a training iteration @cube.compile(model, dataloader, PAS=policy) @@ -71,12 +74,17 @@ def train_iter(model, dataloader): optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) CudaTimer(enable=False).warmup() - dataloader = iter(dataloader) iter_num, warmup = 5, 2 for step in range(iter_num): if step == warmup: CudaTimer(enable=True).start('e2e') + + # get data samples + samples = [dummy_data() for _ in range(args.gbs // args.mbs)] + dataloader = microbatches(samples) + # run training iteration train_iter(model, dataloader) + optimizer.step() optimizer.zero_grad() if (step + 1) % 2 == 0: diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 8e5e9dc9..25123183 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -2,7 +2,6 @@ from dataclasses import dataclass import cube -from cube.runtime.utils import create_dummy_dataloader from examples.nlp.blocks.transformer import TransformerLayer @@ -88,18 +87,17 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): return loss -def get_gpt_dummy_dataloader(batch_size: int, cfg: Config): +def dummy_data(batch_size: int, cfg: Config): input_ids = torch.randint( 0, cfg.num_embeddings, - size=(cfg.seqlen,), + size=(batch_size, cfg.seqlen,), dtype=torch.int64, device=torch.cuda.current_device() ) position_ids = torch.arange( 0, cfg.seqlen, dtype=torch.int64, device=torch.cuda.current_device() - ).view(cfg.seqlen,) + ).repeat(batch_size, 1).view(batch_size, cfg.seqlen,) - return create_dummy_dataloader( - (input_ids, position_ids), batch_size) + return input_ids, position_ids diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index ca45d83f..54f5f1f3 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -11,12 +11,12 @@ import logging from functools import partial -from model import GPT, Config -from model import get_gpt_dummy_dataloader +from model import GPT, Config, dummy_data import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary +from cube.runtime.utils import microbatches import examples.nlp.gpt.policy.spmd as spmd import examples.nlp.gpt.policy.mpmd as mpmd @@ -76,7 +76,9 @@ def train(): ) model = GPT(config) model = model if not args.fp16 else model.half() - dataloader = get_gpt_dummy_dataloader(args.mbs, Config) + + gen_data = partial(dummy_data, args.mbs, config) + dataloader = microbatches((gen_data(),), cycle=True) @cube.compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): @@ -92,12 +94,16 @@ def train_iter(model, dataloader): memory_summary() CudaTimer().warmup() - dataloader = iter(dataloader) iter_num, warmup = 5, 2 for step in range(iter_num): if step == warmup: CudaTimer(enable=True).start('e2e') + # collect dummy data + samples = [gen_data() for _ in range(args.gbs // args.mbs)] + dataloader = microbatches(samples) + + # train train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() diff --git a/examples/nlp/mbart/model.py b/examples/nlp/mbart/model.py index 141084fb..6c9764c1 100644 --- a/examples/nlp/mbart/model.py +++ b/examples/nlp/mbart/model.py @@ -5,7 +5,6 @@ from examples.nlp.blocks.transformer import TransformerLayer import cube -from cube.runtime.utils import create_dummy_dataloader @dataclass @@ -163,22 +162,22 @@ def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor): return loss -def get_mbart_dummy_dataloader(batch_size: int, config: Config): +def dummy_data(batch_size: int, config: Config): input_ids = torch.randint( 0, config.vocab, - size=(config.seqlen,), + size=(batch_size, config.seqlen,), dtype=torch.int64, device=torch.cuda.current_device() ) decoder_input_ids = torch.randint( 0, config.vocab, - size=(config.seqlen,), + size=(batch_size, config.seqlen,), dtype=torch.int64, device=torch.cuda.current_device() ) labels = torch.randint( 0, config.num_classes, - size=(), # scalar + size=(batch_size, ), # scalar dtype=torch.int64, device=torch.cuda.current_device() ) - return create_dummy_dataloader((input_ids, decoder_input_ids,), batch_size) + return input_ids, decoder_input_ids diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index f6b49deb..596bce91 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -14,11 +14,13 @@ from functools import partial from examples.nlp.mbart.model import MBartForSentenceClassification, Config -from examples.nlp.mbart.model import get_mbart_dummy_dataloader +from examples.nlp.mbart.model import dummy_data import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary +from cube.runtime.utils import microbatches + import examples.nlp.mbart.policy.gallery as gallery from examples.utils import get_policy @@ -100,7 +102,8 @@ def train(): trunc_normal_(param) model = model.half() if args.fp16 else model - dataloader = get_mbart_dummy_dataloader(batch_size, config) + gen_data = partial(dummy_data, batch_size, config) + dataloader = microbatches((gen_data(),), cycle=True) @cube.compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): @@ -113,11 +116,13 @@ def train_iter(model, dataloader): model.parameters(), lr=3e-05, betas=(0.9, 0.98)) CudaTimer().warmup() - dataloader = iter(dataloader) iter_num, warmup = 5, 2 for step in range(iter_num): if step == warmup: CudaTimer(enable=True).start('e2e') + # prepare input data + samples = [gen_data() for _ in range(args.gbs // args.mbs)] + dataloader = microbatches(samples) # training train_iter(model, dataloader) diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index edae0a44..5f4c8c44 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -6,7 +6,6 @@ from examples.vision.swin.blocks.patch import PatchEmbed, PatchMerging import cube -from cube.runtime.utils import create_dummy_dataloader class Config: @@ -222,10 +221,10 @@ def flops(self): # =========================== Data Loader ======================= -def get_swin_dummy_dataloader(batch_size: int, - dtype: torch.dtype, cfg: Config): +def dummy_data(batch_size: int, + dtype: torch.dtype, cfg: Config): input_ids = torch.randn( - [3, cfg.img_size, cfg.img_size], + [batch_size, 3, cfg.img_size, cfg.img_size], dtype=dtype, device=torch.cuda.current_device() ) - return create_dummy_dataloader(input_ids, batch_size) + return input_ids diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index b293b680..acedd297 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -10,11 +10,12 @@ import torch from functools import partial from examples.vision.swin.blocks.attention import init_relative_position_index -from examples.vision.swin.model import Config, SwinTransformer, get_swin_dummy_dataloader +from examples.vision.swin.model import Config, SwinTransformer, dummy_data import cube from cube.profiler.timer import CudaTimer, print_each_rank from cube.profiler.memory import memory_summary +from cube.runtime.utils import microbatches import examples.vision.swin.policy.gallery as gallery from examples.utils import get_policy @@ -56,7 +57,10 @@ def train(): model = model.half() if args.fp16 else model dtype = torch.float16 if args.fp16 else torch.float32 - dataloader = get_swin_dummy_dataloader(batch_size, dtype, cfg) + + + gen_data = partial(dummy_data, args.mbs, torch.float16, cfg) + dataloader = microbatches((gen_data(),)) @cube.compile(model, dataloader, PAS=policy, load_content=load_content) def train_iter(model, dataloader): @@ -82,12 +86,15 @@ def train_iter(model, dataloader): print_each_rank(f'model parameter: {nparams}') CudaTimer().warmup() - dataloader = iter(dataloader) iter_num, warmup = 5, 2 for step in range(iter_num): if step == warmup: CudaTimer(enable=True).start('e2e') + # collect data + samples = [gen_data() for _ in range(args.gbs // args.mbs)] + dataloader = microbatches(samples, dtype=dtype) + # train iteration train_iter(model, dataloader) optimizer.step() optimizer.zero_grad() diff --git a/tests/graph/function/test_dataloader.py b/tests/graph/function/test_dataloader.py index b2d4850e..d323cb5b 100644 --- a/tests/graph/function/test_dataloader.py +++ b/tests/graph/function/test_dataloader.py @@ -7,22 +7,6 @@ from cube.ir.cten import IRObject from cube.ir.tensor import IRFullTensor from cube.ir.operator import IRDataOperation -from cube.runtime.utils import create_dummy_dataloader - - -def test_dummy_dataloader(): - samples = ( - torch.rand([256, 512], dtype=torch.float32), - torch.rand([128, 224], dtype=torch.float16), - 4, - ) - dataloader = create_dummy_dataloader(samples, batch_size=32) - for idx, samples in enumerate(dataloader): - assert samples[0].shape == torch.Size([32, 256, 512]) - assert samples[1].shape == torch.Size([32, 128, 224]) - assert torch.allclose(samples[2], torch.tensor([4] * 32, dtype=torch.int64)) - if idx == 4: - break def test_data_operation(): diff --git a/tests/runtime/test_dataloader.py b/tests/runtime/test_dataloader.py new file mode 100644 index 00000000..ee4d44e3 --- /dev/null +++ b/tests/runtime/test_dataloader.py @@ -0,0 +1,50 @@ + +import torch +from cube.runtime.utils import MicroBatchDataLoader, microbatches + + +import pytest + +def mock_dataloader_sample(): + tokens = torch.randint(0, 1000, (2, 1024)) + labels = torch.randint(0, 1000, (2,)) + ntokens = 2048 + return tokens, labels, ntokens + + +def test_microbatch_dataloader(): + + samples = [mock_dataloader_sample() for _ in range(4)] + dataloader = microbatches(samples) + + assert isinstance(dataloader, MicroBatchDataLoader) + assert len(dataloader) == 4 + + sample = next(dataloader) + assert isinstance(sample, tuple) + assert isinstance(sample[0], torch.Tensor) + assert isinstance(sample[1], torch.Tensor) + assert isinstance(sample[2], int) + _ = next(dataloader) + _ = next(dataloader) + _ = next(dataloader) + with pytest.raises(StopIteration): + _ = next(dataloader) + + +def test_microbatch_dataloaser_with_cycle(): + + samples = [mock_dataloader_sample() for _ in range(4)] + dataloader = microbatches(samples, cycle=True) + + assert isinstance(dataloader, MicroBatchDataLoader) + assert len(dataloader) == 4 + + # no stop iteration should be raised + for _ in range(16): + sample = next(dataloader) + assert isinstance(sample, tuple) + assert isinstance(sample[0], torch.Tensor) + assert isinstance(sample[1], torch.Tensor) + assert isinstance(sample[2], int) + \ No newline at end of file From 8ea46f3de61438b5a4f8f8ac5a1b7eda8cddf31a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 26 Dec 2023 06:24:02 +0000 Subject: [PATCH 1549/1892] Merged PR 1967: fix bug in graph indexing: cannot index IRSegment node fix bug in graph indexing: cannot index IRSegment node --- cube/graph/segment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index c56ac087..98b0e939 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -236,13 +236,13 @@ def multi_index(self, nodes: List[IRCell]) -> List[CellPosition]: def dfs(seg: IRSegment, path: List[int]): nonlocal visited, indices for idx, node in enumerate(seg._nodes): - if isinstance(node, IRSegment): - dfs(node, path + [idx]) - elif node in nodes: + if node in nodes: indices[nodes.index(node)] = CellPosition(tuple(path + [idx])) visited += 1 if visited == len(nodes): return + if isinstance(node, IRSegment): + dfs(node, path + [idx]) dfs(self, []) if visited != len(nodes): unvisited = [] From 728f88543921a4075065b61b412406402e8f7857 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 26 Dec 2023 08:15:54 +0000 Subject: [PATCH 1550/1892] Merged PR 1958: rewrite inplace op rewrite inplace op unit test passed parity check passed. --- cube/graph/parser/converter.py | 28 ++++ .../concrete_trace_utils/concrete_tracer.py | 3 + .../parser/fx/concrete_trace_utils/utils.py | 8 ++ tests/graph/tracer/test_inplace.py | 131 ++++++++++++++++++ 4 files changed, 170 insertions(+) create mode 100644 tests/graph/tracer/test_inplace.py diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 6203c77d..d18d7dfa 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Union import logging from pathlib import Path +import operator from cube.ir.tensor import IRFullTensor from cube.graph.parser.register import CustomizedOps @@ -10,6 +11,7 @@ from cube.graph.parser.fx.parser import FxModuleParser from cube.graph.parser.fx.concrete_trace_utils import concrete_trace from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply +from cube.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops import cube.runtime.function as cube_rt_function @@ -30,6 +32,31 @@ def unpack(x): super().__init__(pack, unpack) +def _rewrite_inplace_ops(traced_model: torch.fx.GraphModule): + """Rewrite inplace ops to use its outputs so we can track them in IRGraph + + x.add_(y) => x = x.add_(y) + operator.iadd(x, y) => x = operator.iadd(x, y) + x += y => x += y # no change + + Args: + traced_model (torch.fx.GraphModule): fx graph to be modified + """ + done_nodes = set() + for n in traced_model.graph.nodes: + done_nodes.add(n) + if ( + (n.op == "call_method" and n.target.endswith("_") and not n.target.endswith("__")) + or (n.op == "call_function" and n.target in side_effectful_inplace_ops) + ) and n.args[0].meta.get('type', None) == torch.Tensor: + n.args[0].replace_all_uses_with(n, delete_user_cb=lambda node: not node in done_nodes) + # we can't recompile + # it will raise error if we have autograd, customized op, etc. + # The good part is we don't need to generate python code + # we will use the fx graph directly + # traced_model.recompile() + + def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: """ Convert torch.nn.Module based model into torch.fx.GraphModule @@ -61,6 +88,7 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: cpu_offload=True, record_frames=not CompileFlag.disable_code_line_info, ) + _rewrite_inplace_ops(traced_model) return traced_model diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 36098fae..1f60303d 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -124,6 +124,8 @@ def __exit__(self, *args): _orig_max, _orig_node_is_impure, + + side_effectful_inplace_ops, ) from .utils import FrameRecord, ExtraSEFPatcher, extract_results_metadata, EmptyResult @@ -1717,6 +1719,7 @@ def f(x, y): default_extra_side_effectful_functions = { operator.setitem, builtins.next, + *side_effectful_inplace_ops } extra_side_effectful_functions = default_extra_side_effectful_functions | dce_ignored_function with _Patcher() as patcher, ExtraSEFPatcher(extra_side_effectful_functions): diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index b2ac345e..24d23801 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -53,6 +53,14 @@ _orig_node_is_impure: Callable = Node.is_impure +side_effectful_inplace_ops = { + operator.iadd, operator.isub, operator.imul, operator.itruediv, operator.ifloordiv, + operator.iand, operator.ior, operator.ixor, operator.ilshift, operator.irshift, + operator.imod, operator.ipow, + # operator.imatmul is not implemented in torch + # so let's ignore it now +} + def run_onlyif_instance(cond_type: Type[Any], return_orig: bool = True, return_const: Any = None): def helper(fn): diff --git a/tests/graph/tracer/test_inplace.py b/tests/graph/tracer/test_inplace.py new file mode 100644 index 00000000..547bd06d --- /dev/null +++ b/tests/graph/tracer/test_inplace.py @@ -0,0 +1,131 @@ +import operator +import torch + +from cube.graph.parser.converter import to_fx_graph +from cube.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops + +from ...utils import replace_all_device_with + + +def test_side_effectful_inplace_ops(): + # this is for the validation of side_effectful_inplace_ops + int_inplace_ops = { + operator.iadd, operator.isub, operator.imul, operator.ifloordiv, + operator.ilshift, operator.irshift, + operator.imod, operator.ipow, + } + float_inplace_ops = { + operator.itruediv, + } + bool_inplace_ops = { + operator.iand, operator.ior, operator.ixor, + } + + assert int_inplace_ops.union(float_inplace_ops, bool_inplace_ops) == side_effectful_inplace_ops + + not_implemented_mat_inplace_ops = { + operator.imatmul + } + + for intop in int_inplace_ops: + x = torch.arange(1, 10, dtype=torch.int32) + y = torch.arange(1, 10, dtype=torch.int32) + orig_xid = id(x) + orig_x = x.clone() + intop(x, y) + assert id(x) == orig_xid + assert not torch.equal(x, orig_x) + + for floatop in float_inplace_ops: + x = torch.randn(3, 3) + y = torch.randn(3, 3) + orig_xid = id(x) + orig_x = x.clone() + floatop(x, y) + assert id(x) == orig_xid + assert not torch.equal(x, orig_x) + + for boolop in bool_inplace_ops: + x = torch.tensor([True, False, True]) + y = torch.tensor([True, True, False]) + orig_xid = id(x) + orig_x = x.clone() + boolop(x, y) + assert id(x) == orig_xid + assert not torch.equal(x, orig_x) + + for matop in not_implemented_mat_inplace_ops: + x = torch.randn(3, 3) + y = torch.randn(3, 3) + orig_xid = id(x) + orig_x = x.clone() + z = matop(x, y) + assert id(x) == orig_xid + assert not torch.equal(z, orig_x) + assert torch.equal(x, orig_x) # not updated + + +class InplaceOpModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + x += y + x.add_(y) + z = x + y + x.add_(z) + operator.iadd(x, y) + w = torch.add(z, x) + return w + +@replace_all_device_with('cpu') +def test_inplace_op(): + model = InplaceOpModule() + dummy_input = {'x': torch.tensor([1.0,2.0,3.0]), 'y': torch.tensor([4.0,5.0,6.0])} + traced_graph = to_fx_graph(model, dummy_input) + assert torch.equal( + model(torch.tensor([1.0,2.0,3.0]), torch.tensor([4.0,5.0,6.0])), + traced_graph(torch.tensor([1.0,2.0,3.0]), torch.tensor([4.0,5.0,6.0])) + ) + nodes = list(traced_graph.graph.nodes) + assert nodes[0].op == 'placeholder' + assert nodes[0].name == 'x' + assert nodes[1].op == 'placeholder' + assert nodes[1].name == 'y' + + # x += y + assert nodes[2].op == 'call_function' + assert nodes[2].target == operator.iadd + assert nodes[2].args[0] == nodes[0] + assert nodes[2].args[1] == nodes[1] + + # x.add_(y) + assert nodes[3].op == 'call_method' + assert nodes[3].target == 'add_' + assert nodes[3].args[0] == nodes[2] # return value of `x += y` + assert nodes[3].args[1] == nodes[1] + + #z = x + y + assert nodes[4].op == 'call_function' + assert nodes[4].target == operator.add + assert nodes[4].args[0] == nodes[3] # return value of `x.add_(y)` + assert nodes[4].args[1] == nodes[1] + + # x.add_(z) + assert nodes[5].op == 'call_method' + assert nodes[5].target == 'add_' + assert nodes[5].args[0] == nodes[3] # return value of `x.add_(y)` + assert nodes[5].args[1] == nodes[4] # return value of `z = x + y` + + # operator.iadd(x, y) + assert nodes[6].op == 'call_function' + assert nodes[6].target == operator.iadd + assert nodes[6].args[0] == nodes[5] # return value of `x.add_(z)` + assert nodes[6].args[1] == nodes[1] # return value of `y` + + #w = torch.add(z, x) + assert nodes[7].op == 'call_function' + assert nodes[7].target == torch.add + assert nodes[7].args[0] == nodes[4] # return value of `z = x + y` + assert nodes[7].args[1] == nodes[6] # return value of `operator.iadd(x, y)` + + assert nodes[8].op == 'output' + assert nodes[8].args[0] == nodes[7] # return value of `return w` + From a2695a86bbb1d0a780943cf6bbe92eb2e2588ed2 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 26 Dec 2023 08:52:45 +0000 Subject: [PATCH 1551/1892] Merged PR 1900: update tracer for lora support parity check & UT passed --- .../fx/concrete_trace_utils/concrete_proxy.py | 6 +++ .../concrete_trace_utils/concrete_tracer.py | 41 ++++++++++++++++++- tests/graph/parser/test_dce.py | 25 +++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 tests/graph/parser/test_dce.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 105946fc..0c4f02bd 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -76,6 +76,9 @@ def __repr__(self) -> str: return f'ConcreteProxy({self.node.name}, {self.value})' def __getattr__(self, k) -> ConcreteProxy: + # if the proxy is a wrapped module, forward this call to the torch.nn.Module.__getattribute__ + if _orig_isinstance(self.value, torch.nn.Module): + return torch.nn.Module.__getattribute__(self.value, k) return ConcreteAttrProxy(self, k) def __call__(self, *args, **kwargs) -> ConcreteProxy: @@ -174,6 +177,9 @@ def __bool__(self) -> Union[bool, ConcreteProxy]: elif insts[cur].opname == 'CONTAINS_OP': # in executing 'in' return _orig_bool(self.value) + elif insts[cur].opname == 'BINARY_SUBSCR': + # in executing slice or index, my_list[index] or my_dict[key] + return _orig_bool(self.value) elif insts[cur].opcode == self.op_call_ex: # in executing func(..., *proxy) return _orig_bool(self.value) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 1f60303d..bc357119 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -34,6 +34,8 @@ from torch.fx.proxy import TracerBase from torch.fx.operator_schemas import check_for_mutable_operation +dict_keys_type = type(dict().keys()) + try: # Scope is a new class to record module path in pytorch 2.0 from torch.fx.proxy import Scope @@ -240,6 +242,12 @@ class ConcreteTracer(TracerBase): to_func = getattr(torch.Tensor, name, None) to_func = None if to_func == attr else to_func default_autowrap_leaf_function[attr] = ([], False, to_func) + # find the multi position for default_autowrap_leaf_function in torch.__dir__() + for name in dir(torch): + attr = getattr(torch, name) + if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__') \ + and attr in default_autowrap_leaf_function: + default_autowrap_leaf_function[attr][0].append((torch, name)) default_autowrap_leaf_class: Dict[Type, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool]] = { # class @@ -551,6 +559,9 @@ def create_arg(self, a: Any) -> Union[Node, Any]: if isinstance(a, (torch.autograd.function.Function, torch.autograd.function.FunctionMeta)): return a + if isinstance(a, dict_keys_type): + return a + return super().create_arg(a) @compatibility(is_backward_compatible=True) @@ -982,7 +993,7 @@ def torch_assert_wrapper(condition, message): positions = (*positions, (torch.Tensor, func.__name__)) wrapped = _create_wrapped_leaf_method(self, getattr(torch.Tensor, func.__name__), func.__name__, to_func) elif func.__qualname__.startswith('_VariableFunctionsClass'): - if hasattr(torch, func.__name__): + if hasattr(torch, func.__name__) and getattr(torch, func.__name__) == func: # avoid bad attr like 'unique_dim' positions = (*positions, (torch, func.__name__)) if is_force_trace: @@ -1492,6 +1503,9 @@ def _retain_weight_consistency(root: torch.nn.Module): @functools.wraps(_orig_node_is_impure) def node_is_impure_wrapper(node): + if is_useless_iter(node): + return False + if node.op in {"placeholder", "output"}: return True @@ -1513,6 +1527,31 @@ def node_is_impure_wrapper(node): return False +def is_useless_iter(node: Node): + if node.op == 'call_function' and node.target is iter: + node_is_impure = False + for iter_user in node.users: + if not is_useless_next(iter_user): + node_is_impure = True + break + if not node_is_impure: + for iter_user in list(node.users.keys()): + setattr(iter_user, '_is_impure', False) + iter_user.graph.erase_node(iter_user) + if len(node.users) > 0: + raise RuntimeError('The user node of iter is not empty, something goning wrong.') + setattr(node, '_is_impure', False) + return True + else: + return False + +def is_useless_next(node: Node): + if node.op == "call_function" and node.target is next: + if len(node.users) == 0: + return True + else: + return False + def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Union[Dict[str, Any], Tuple], *, diff --git a/tests/graph/parser/test_dce.py b/tests/graph/parser/test_dce.py new file mode 100644 index 00000000..a7138e39 --- /dev/null +++ b/tests/graph/parser/test_dce.py @@ -0,0 +1,25 @@ +import pytest +import torch + +from cube.graph.parser.converter import to_fx_graph + + +def test_dce_useless_next(): + class SimpleModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m_dict = torch.nn.ModuleDict({'test': torch.nn.Linear(10, 8)}) + self.m_keys = ['test', 'default'] + + def forward(self, x): + for key in self.m_keys: + # TODO: (ning) This kind of code style will call instruction 'BINARY_SUBSCR' + # And we should not pass a proxy to the 'BINARY_SUBSCR', or it will trigger a type check error. + # This has already fixed in concrete_proxy.py, but it is hard to add a test for this, add it when we get a idea. + if key in self.m_dict.keys(): + x = self.m_dict[key](x) + return x + + traced_graph = to_fx_graph(SimpleModel(), {'x': torch.rand(4, 10)}) + for node in traced_graph.graph.nodes: + assert node.target is not next and node.target is not iter From 40825cfc5b549cc688132623be1b2507ef5eaec7 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 2 Jan 2024 06:51:23 +0000 Subject: [PATCH 1552/1892] Merged PR 1973: refine ast transform refine ast transform unit test pass parity check pass --- .../concrete_trace_utils/operator_patcher.py | 215 +++++++++--------- tests/graph/parser/test_ast_transformer.py | 133 +++++++++++ 2 files changed, 236 insertions(+), 112 deletions(-) create mode 100644 tests/graph/parser/test_ast_transformer.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index e03ab357..e289513e 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -13,7 +13,7 @@ from textwrap import dedent from types import MethodType, FunctionType -from typing import List, Optional, Callable, Dict +from typing import List, Optional, Callable, Dict, Tuple import torch @@ -29,31 +29,73 @@ _logger = logging.getLogger(__name__) -class TransformerOp(ast.NodeTransformer): - """ - An ast transformer, to check and replace the python ops 'not/is/is not/in/not in' to functions in 'operator' module. - """ +class TrackedTransformer(ast.NodeTransformer): + def __init__(self) -> None: + super().__init__() + self.modified = False - def visit_start(self, node): - # to mark if the ast is changed - self.is_transformed = False - # detect the expr now is in a branch test expr - # 0: not in a branch test expr. - # 1: in propagate if not in func 'visit', or not in a branch test expr in func 'visit' - # 2: in a branch test expr - self.is_incond_status = 0 - ret = super().visit(node) - return self.is_transformed, ret +class OperatorTransformer(TrackedTransformer): + func_map = { + ast.Not: 'not_', # operator.not_ + ast.Is: 'is_', # operator.is_ + ast.IsNot: 'is_not', # operator.is_not + ast.In: 'contains', # operator.contains + } + def visit_UnaryOp(self, node: ast.UnaryOp): + if _orig_isinstance(node.op, ast.Not): + self.modified = True + return self.generic_visit(ast.Call( + func=ast.Name(id=self.func_map[ast.Not], ctx=ast.Load()), + args=[node.operand], + keywords=[] + )) + else: + return self.generic_visit(node) - def visit(self, node): - if self.is_incond_status != 0: - # if the status is 'in branch test', - self.is_incond_status -= 1 - return super().visit(node) + def visit_Compare(self, node: ast.Compare): + if not any(_orig_isinstance(op, (ast.Is, ast.IsNot, ast.In, ast.NotIn)) for op in node.ops): + return self.generic_visit(node) + if _orig_len(node.ops) != 1: + raise RuntimeError('Chained Comparison is not supported') + self.modified = True + if _orig_isinstance(node.ops[0], (ast.In, ast.NotIn)): + args = [node.comparators[0], node.left] + else: + args = [node.left, node.comparators[0]] + if not _orig_isinstance(node.ops[0], ast.NotIn): + ret_node = ast.Call( + func=ast.Name(id=self.func_map[type(node.ops[0])], ctx=ast.Load()), + args=args, + keywords=[], + ) + else: + # not in => operator.not_(operator.contains()) + in_node = ast.Call( + func=ast.Name(id=self.func_map[ast.In], ctx=ast.Load()), + args=args, + keywords=[], + ) + ret_node = ast.Call( + func=ast.Name(id=self.func_map[ast.Not], ctx=ast.Load()), + args=[in_node], + keywords=[] + ) + + return self.generic_visit(ret_node) + + +class SuperTransformer(TrackedTransformer): + """ + Convert super() to super(self.__class__, self) + Because in Patcher, we only patch funtions (instead of class). + super() is not supported for a standalone function. + """ def visit_Call(self, node: ast.Call): - if isinstance(node.func, ast.Name) and node.func.id == 'super' and _orig_len(node.args) == 0: + if _orig_isinstance(node.func, ast.Name) and node.func.id == 'super' and _orig_len(node.args) == 0: + self.modified = True + # convert super() to super(self.__class__, self) return self.generic_visit(ast.Call( func=ast.Name(id='super', ctx=ast.Load()), args=[ @@ -62,102 +104,47 @@ def visit_Call(self, node: ast.Call): ], keywords=node.keywords, )) - elif not isinstance(node.func, ast.Name) or node.func.id != 'patch_run': - self.is_transformed = True - return self.generic_visit(ast.Call( - func=ast.Name(id='patch_run', ctx=ast.Load()), - args=[node.func, *node.args], - keywords=node.keywords, - )) else: return self.generic_visit(node) - def visit_While(self, node: ast.While): - self.is_incond_status = 2 - node.test = self.visit(node.test) - self.is_incond_status = 0 - node.body = [self.visit(item) for item in node.body] - node.orelse = [self.visit(item) for item in node.orelse] - return node - - def visit_If(self, node: ast.If): - self.is_incond_status = 2 - node.test = self.visit(node.test) - self.is_incond_status = 0 - node.body = [self.visit(item) for item in node.body] - node.orelse = [self.visit(item) for item in node.orelse] - return node - - def visit_IfExp(self, node: ast.IfExp): - node.body = self.visit(node.body) - self.visit(node.body) - self.is_incond_status = 2 - node.test = self.visit(node.test) - self.is_incond_status = 0 - node.orelse = self.visit(node.orelse) - return node - def visit_UnaryOp(self, node: ast.UnaryOp): - if self.is_incond_status != 0: - # in branch cond test expr, need no replacement - self.is_incond_status = 2 - return self.generic_visit(node) - elif _orig_isinstance(node.op, ast.Not): - self.is_transformed = True +class ProxyCallTransformer(TrackedTransformer): + def __init__(self, proxy_call_name: str, ignore_funcs: Optional[List[str]] = None) -> None: + """ + Args: + proxy_call_name: the name of the proxy function + ignore_funcs: a list of function names that should not be transformed + """ + super().__init__() + self.proxy_call_name = proxy_call_name + self.ignore_funcs = ignore_funcs or [] + + def visit_Call(self, node: ast.Call): + # will transform all function call to `proxy_call_name(func_name, *args, **kwargs)` + # node.func can be expression, in that case, node.func.id is undefined. + if not _orig_isinstance(node.func, ast.Name) or ( + node.func.id != self.proxy_call_name and node.func.id not in self.ignore_funcs + ): + self.modified = True return self.generic_visit(ast.Call( - func=ast.Name(id='not_', ctx=ast.Load()), - args=[node.operand], - keywords=[], + func=ast.Name(id=self.proxy_call_name, ctx=ast.Load()), + args=[node.func, *node.args], + keywords=node.keywords, )) else: return self.generic_visit(node) - def visit_BoolOp(self, node: ast.BoolOp): - if self.is_incond_status != 0: - # in branch cond test expr, need no replacement - self.is_incond_status = 2 - return self.generic_visit(node) - else: - if not _orig_isinstance(node.values[1], (ast.Call, ast.BoolOp)): - _logger.warning('warning: "and/or" will generate branch expr. The 2nd arg can\'t be traced if the 1st arg returns a True.' - ' Don\'t mix up "and/or" and "&/|"!') - return self.generic_visit(node) - def visit_Compare(self, node: ast.Compare): - should_replace = False - for op in node.ops: - if _orig_type(op) in (ast.Is, ast.IsNot, ast.In, ast.NotIn): - should_replace = True - break - if should_replace: - if _orig_len(node.ops) != 1: - raise RuntimeError( - 'not supported in "{} cmp_op {} cmp_op {}" when cmp_op contains "is/is not/in/not in"') - self.is_transformed = True - func_id = { - ast.Is: 'is_', - ast.IsNot: 'is_not', - ast.In: 'contains', - ast.NotIn: 'contains', - }[_orig_type(node.ops[0])] - if _orig_isinstance(node.ops[0], (ast.In, ast.NotIn)): - args = [node.comparators[0], node.left] - else: - args = [node.left, node.comparators[0]] - ret_node = ast.Call( - func=ast.Name(id=func_id, ctx=ast.Load()), - args=args, - keywords=[], - ) - if _orig_isinstance(node.ops[0], ast.NotIn): - ret_node = ast.Call( - func=ast.Name(id='not_', ctx=ast.Load()), - args=[ret_node], - keywords=[], - ) - return self.generic_visit(ret_node) - else: - return self.generic_visit(node) +def transform(node: ast.AST, transformers: List[TrackedTransformer]) -> Tuple[bool, ast.AST]: + modified = False + for transformer in transformers: + node = transformer.visit(node) + modified = modified or transformer.modified + + if modified: + return True, ast.fix_missing_locations(node) + else: + return False, node class OperatorPatcher: @@ -165,11 +152,10 @@ class OperatorPatcher: An function patcher, to patch the un-wrappable operator 'not/is/is not/in/not in' to wrappable functions. """ - transformer_op = TransformerOp() - def __init__(self, use_operator_patch: bool, operator_patch_backlist: List[str]): self.use_operator_patch = use_operator_patch self.operator_patch_backlist = operator_patch_backlist + self.proxy_call_name = OperatorPatcherContext.patch_run.__name__ def patch_inner(self, func): return self.patch_inner_helper(func) @@ -199,7 +185,12 @@ def patch_inner_helper(self, func): dedent_src = dedent(source) tree = ast.parse(dedent_src) - is_transformed, new_tree = OperatorPatcher.transformer_op.visit_start(tree) + # transformers have states, so we can't reuse them. + is_transformed, new_tree = transform(tree, [ + OperatorTransformer(), + SuperTransformer(), + ProxyCallTransformer(self.proxy_call_name) + ]) if not is_transformed: return func else: @@ -224,7 +215,7 @@ def patch_inner_helper(self, func): # these decorators are used for tranformers model docstrings generation, can be removed in trace transform_useless_decorators = ('add_start_docstrings_to_model_forward', 'add_code_sample_docstrings', 'replace_return_docstrings') body0.decorator_list = [i for i in body0.decorator_list - if isinstance(i, ast.Call) and isinstance(i.func, ast.Name) and i.func.id == 'patch_run' and + if isinstance(i, ast.Call) and isinstance(i.func, ast.Name) and i.func.id == self.proxy_call_name and isinstance(i.args[0], ast.Name) and i.args[0].id not in transform_useless_decorators] ast.fix_missing_locations(new_tree) @@ -246,7 +237,7 @@ def patch_inner_helper(self, func): # use func.__code__.co_filename to make the new function easily debuggable. compile(new_tree, func_inner.__code__.co_filename, 'exec'), { - 'patch_run': OperatorPatcherContext.patch_run, + self.proxy_call_name: OperatorPatcherContext.patch_run, **func_inner.__globals__, **closure_dict, }, diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py new file mode 100644 index 00000000..559e12a8 --- /dev/null +++ b/tests/graph/parser/test_ast_transformer.py @@ -0,0 +1,133 @@ +import ast +from textwrap import dedent + +from cube.graph.parser.fx.concrete_trace_utils.operator_patcher import ( + OperatorTransformer, + SuperTransformer, + ProxyCallTransformer, + transform +) + +def test_op_transfomer(): + tree = ast.parse(dedent(''' + x = True + y = not x + y1 = x and not y + y2 = x is None + y3 = x is not None + y4 = x in () + y5 = x not in () + if x and not y: + pass + y6 = y1 is None or y2 is not None + y7 = y1 if y2 is None else y3 + while x and not y: + pass + ''').strip()) + transformers = [OperatorTransformer()] + modified, new_ast = transform(tree, transformers) + assert modified + assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' + x = True + y = not_(x) + y1 = x and not_(y) + y2 = is_(x, None) + y3 = is_not(x, None) + y4 = contains((), x) + y5 = not_(contains((), x)) + if x and not_(y): + pass + y6 = is_(y1, None) or is_not(y2, None) + y7 = y1 if is_(y2, None) else y3 + while x and not_(y): + pass + ''').strip() + + +def test_super_transform(): + tree = ast.parse(dedent(''' + class A: + def __init__(self) -> None: + super().__init__() + def f(self): + super(A, self).f() + @staticmethod + def g(): + pass + @classmthod + def h(cls): + pass + def i(self): + pass + ''').strip()) + + transfomers = [SuperTransformer()] + modified, new_ast = transform(tree, transfomers) + assert modified + assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' + class A: + def __init__(self) -> None: + super(self.__class__, self).__init__() + def f(self): + super(A, self).f() + @staticmethod + def g(): + pass + @classmthod + def h(cls): + pass + def i(self): + pass + ''').strip() + + +def test_proxy_call_transform(): + # the `(x+y)(a, b)` statement below just demonstrates + # AST doesn't care about the real meaning of the expression. + # It looks like a function call, so it will be treated as function call in AST. + # And for us, we also patch it just like a function call. + tree = ast.parse(dedent(''' + def f(func_name, type: int, /, *args, **kwargs): + return func_name(type, *args, **kwargs) + def g(): + return (x + y)(a, b) + class A: + def f(self) -> None: + super().f() + ''').strip()) + + transfomers = [ProxyCallTransformer('patched_run')] + modified, new_ast = transform(tree, transfomers) + assert modified + assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' + def f(func_name, type: int, /, *args, **kwargs): + return patched_run(func_name, type, *args, **kwargs) + def g(): + return patched_run(x + y, a, b) + class A: + def f(self) -> None: + patched_run(patched_run(super).f) + ''').strip() + + +def test_transform_combine(): + tree = ast.parse(dedent(''' + x = not True + def f(func_name, type: int, /, *args, **kwargs): + return func_name(type, *args, **kwargs) + class A: + def __init__(self) -> None: + super().__init__() + ''').strip()) + + transfomers = [OperatorTransformer(), SuperTransformer(), ProxyCallTransformer('patched_run', ['super'])] + modified, new_ast = transform(tree, transfomers) + assert modified + assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' + x = patched_run(not_, True) + def f(func_name, type: int, /, *args, **kwargs): + return patched_run(func_name, type, *args, **kwargs) + class A: + def __init__(self) -> None: + patched_run(super(self.__class__, self).__init__) + ''').strip() From e4d3731201f2b2c93edf49641ac945e30e4a0af1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 2 Jan 2024 09:10:06 +0000 Subject: [PATCH 1553/1892] Merged PR 1972: move dynamic_shape option to compute config move dynamic_shape option to compute config --- cube/parallel.py | 10 ++++------ tests/parallel_module/test_gencode.py | 14 -------------- tests/parallel_module/test_inference.py | 1 - tests/parallel_module/test_init.py | 1 - tests/parallel_module/test_nested.py | 1 - tests/parallel_module/test_override.py | 1 - tests/parallel_module/test_reducer_hook.py | 1 - tests/parallel_module/test_submodule.py | 1 - tests/parallel_module/test_wholemodule.py | 1 - 9 files changed, 4 insertions(+), 27 deletions(-) diff --git a/cube/parallel.py b/cube/parallel.py index 151166c4..fd8e7fa9 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -43,6 +43,9 @@ class ComputeConfig: plan_ngpus: int runtime_ngpus: int + # whether to use dynamic shape to generate code + dynamic_shape: bool = True + use_zero: bool = False zero_ngroups: int = 1 @@ -346,7 +349,6 @@ def _gencode( compute_config: ComputeConfig, outdir: Path, *, - dynamic_shape: bool = True, module_dtype: Optional[torch.dtype] = None, module_fn: Optional[Callable[[], torch.nn.Module]] = None, ) -> None: @@ -366,7 +368,6 @@ def _gencode( pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy compute_config (ComputeConfig): the environment resource outdir (Path): the directory to save generated code - dynamic_shape (bool): whether to use dynamic shape module_dtype (Optional[torch.dtype]): the dtype of the module. Keep as it is when it is None. module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. @@ -399,7 +400,7 @@ def _gencode( if any(isinstance(m, CubeModule) for m in module.modules()): raise RuntimeError('CubeModule can not be nested.') - graph, forward_args = _gen_graph(module, dummy_input, outdir, dynamic_shape) + graph, forward_args = _gen_graph(module, dummy_input, outdir, compute_config.dynamic_shape) graph.dump(graph_ckp) torch.save(forward_args, forward_args_ckp) @@ -491,7 +492,6 @@ def parallelize( pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], compute_config: ComputeConfig, *, - dynamic_shape: bool = True, cube_savedir: Union[str, Path] = './.cube', reuse: Union[ReuseType, str] = ReuseType.ALL, instance_name: Optional[str] = None, @@ -554,7 +554,6 @@ def __init__(self, init_params=True): dummy_input (dict): the dummy input for the module pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy compute_config (ComputeConfig): the environment resource - dynamic_shape (bool): whether to use dynamic shape reuse (ReuseType): specify which part can be reused. cube_savedir (Union[str, Path]): the directory to save generated code instance_name (Optional[str]): the instance name of the generated module. @@ -593,7 +592,6 @@ def __init__(self, init_params=True): pas_policy, compute_config, outdir, - dynamic_shape=dynamic_shape, module_dtype=module_dtype, module_fn=module_fn, ) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index ad85c11f..4419b23e 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -17,7 +17,6 @@ def _to_cube_model(module, compute_config, cube_savedir, load_module): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, compute_config, - dynamic_shape=True, cube_savedir=cube_savedir, load_module=load_module ) @@ -64,7 +63,6 @@ def test_codegen_slice(): {'x': torch.tensor([1.0, 2.0, 3.0, 6.0])}, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=False ) @@ -93,7 +91,6 @@ def test_codegen_args(): }, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=True ) @@ -121,7 +118,6 @@ def _gencode_unused_args_worker(tempdir): }, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=True ) @@ -176,7 +172,6 @@ def _gencode_unused_args_worker2(tempdir): }, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=True ) @@ -239,7 +234,6 @@ def test_codegen_attr(): {'x': torch.tensor([1.0, 2.0, 3.0, 6.0]), 'attr': AttrHelper()}, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=False ) @@ -272,7 +266,6 @@ def test_codegen_getitem(): {'batched_data': {'x': torch.tensor([[[1.0], [2.0], [3.0], [6.0]]])}}, PASRandomSPMD, ComputeConfig(2, 2), - dynamic_shape=True, cube_savedir=tempdir, load_module=False, ) @@ -303,7 +296,6 @@ def test_codegen_training_flag(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=False ) @@ -332,7 +324,6 @@ def test_codegen_training_flag(): # {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, # PASData, # ComputeConfig(1, 2), -# dynamic_shape=True, # cube_savedir=tempdir, # load_module=False # ) @@ -364,7 +355,6 @@ def test_codegen_iter(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=False ) @@ -395,7 +385,6 @@ def test_codegen_const(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=False ) @@ -435,7 +424,6 @@ def test_codegen_tensor_slice(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=False, reuse='none', @@ -447,7 +435,6 @@ def test_codegen_tensor_slice(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, load_module=False, reuse='none', @@ -475,7 +462,6 @@ def test_codegen_dictget(): }}, PASRandomSPMD, ComputeConfig(2, 2), - dynamic_shape=True, cube_savedir=tempdir, load_module=False, ) diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 255c6817..eaf33c49 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -48,7 +48,6 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - dynamic_shape=True, cube_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index f5c63943..d327e89f 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -24,7 +24,6 @@ def _init_params_worker(): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, PASRandomSPMD, ComputeConfig(1, 1), - dynamic_shape=True, cube_savedir=tempdir, reuse='all', ) diff --git a/tests/parallel_module/test_nested.py b/tests/parallel_module/test_nested.py index 30fb3b26..080ee1d5 100644 --- a/tests/parallel_module/test_nested.py +++ b/tests/parallel_module/test_nested.py @@ -14,7 +14,6 @@ def _to_cube_model(module, pas, compute_config, cube_savedir): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, pas, compute_config, - dynamic_shape=True, cube_savedir=cube_savedir ) diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index dc753e94..dd6b0375 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -18,7 +18,6 @@ def _to_cube_model(module, compute_config, cube_savedir, reuse, instance_name, l {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, compute_config, - dynamic_shape=True, reuse=reuse, cube_savedir=cube_savedir, instance_name=instance_name, diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index 63f1af2d..a8bff918 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -33,7 +33,6 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - dynamic_shape=True, cube_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 32c372e5..6fc72b2c 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -42,7 +42,6 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - dynamic_shape=True, cube_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index aa88edfb..e9307846 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -41,7 +41,6 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - dynamic_shape=True, cube_savedir=cube_savedir, instance_name=instance_name ) From 3a27574d8197f96019bbb8877cdc2c41989b244a Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 3 Jan 2024 06:46:35 +0000 Subject: [PATCH 1554/1892] Merged PR 1969: add mock to dist With this mock, we can run test cases more easily in cpu-only machines. --- cube/parallel.py | 7 +- tests/parallel_module/test_init.py | 47 +++++++++- tests/test_program.py | 38 ++++---- tests/utils.py | 139 +++++++++++++++++++++++++++-- 4 files changed, 202 insertions(+), 29 deletions(-) diff --git a/cube/parallel.py b/cube/parallel.py index fd8e7fa9..7d4468e2 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -77,13 +77,10 @@ def __post_init__(self): @contextmanager -def _flags(flags, warning_on_override=True, /, **kwargs): +def _flags(flags, /, **kwargs): old_flags = {} for k, v in kwargs.items(): old_flags[k] = getattr(flags, k) - if old_flags[k] != v: - if warning_on_override: - logger.warning(f"{flags}.{k}={old_flags[k]} is not supported. Changed to {v}.") setattr(flags, k, v) try: yield @@ -102,7 +99,7 @@ def _compile_flags(compute_config: ComputeConfig): def _runtime_flags(**kwargs): - return _flags(RuntimeFlag, False, **kwargs) + return _flags(RuntimeFlag, **kwargs) def _complex(val: Any): diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index d327e89f..a6b0826f 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -1,11 +1,13 @@ import tempfile +import pytest import torch -from cube.parallel import parallelize, ComputeConfig +from cube.parallel import _load_cube_module_class, parallelize, ComputeConfig from ..launch_torchrun import launch_torchrun from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD, clear_dir_on_rank0 +from ..utils import new_empty, replace_all_device_with, mock_dist class MyModule(torch.nn.Module): def __init__(self): @@ -47,3 +49,46 @@ def test_init_params(): print('skip test_init_params due to lack of cuda devices') return launch_torchrun(1, _init_params_worker) + + +class MyModule2(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = CubeLinear(4, 4, bias=True) + + def forward(self, x): + return self.linear(x) + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('model_class,tp', [(MyModule2, True), (MyModule, False)]) +def test_empty_weights(model_class, tp): + # MyModule2 uses CubeLinear, so tp works + # MyModule uses torch.nn.Linear, so tp doesn't work + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + model_class, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + PASRandomSPMD, + ComputeConfig(2, 4, use_zero=True, zero_ngroups=2), + cube_savedir=tempdir, + reuse='all', + load_module=False, + ) + for i in range(4): + module_class = _load_cube_module_class(model_class, cube_savedir=tempdir, rank=i) + m = new_empty(module_class) + assert m.get_rank() == i + for p in m.parameters(): + assert p.device == torch.device('meta') + for r in m.reducers: + if tp: + assert r.ranks == ((0, 2) if i in (0, 2) else (1, 3)) + else: + assert r.ranks == (0, 1, 2, 3) + assert len(r.buckets) == 1 + assert r.zero + assert r.zero_ngroups == 2 + for b in r.buckets: + assert b._contiguous_grads.device == torch.device('meta') + assert b._contiguous_params.device == torch.device('meta') diff --git a/tests/test_program.py b/tests/test_program.py index 95e49a7f..eb893040 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -21,20 +21,24 @@ def forward(self, x: dict): x = x + self.param1 return {'loss': torch.sum(x)} - CompileFlag.dev_mode = True - Program().clear() - - dummy_input = {'x': {'data': torch.randn(4, 4)}} - module = MyModule() - model = SemanticModel(module, save_content=False, dynamic_shape=False) - - obj = IRObject(value=dummy_input['x']) - model(obj) - graph = model.get_graph() - print(graph.extra_repr()) - - assert graph.input(0) == obj - # getitem - assert graph.node(0).input(0) == obj - # getitem - assert graph.node(1).input(0) == obj + old_dev_mode = CompileFlag.dev_mode + try: + CompileFlag.dev_mode = True + Program().clear() + + dummy_input = {'x': {'data': torch.randn(4, 4)}} + module = MyModule() + model = SemanticModel(module, save_content=False, dynamic_shape=False) + + obj = IRObject(value=dummy_input['x']) + model(obj) + graph = model.get_graph() + print(graph.extra_repr()) + + assert graph.input(0) == obj + # getitem + assert graph.node(0).input(0) == obj + # getitem + assert graph.node(1).input(0) == obj + finally: + CompileFlag.dev_mode = old_dev_mode diff --git a/tests/utils.py b/tests/utils.py index 26553254..511b710b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,19 @@ +import os +import sys +from typing import Type from contextlib import contextmanager from typing import Callable -import torch import math import random +from datetime import timedelta +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d + +from cube.runtime.module import ParallelModule +from cube.runtime.device import DeviceGroup, CompileFlag def init_parameter(model: torch.nn.Module, seed: int = 0): @@ -83,14 +94,16 @@ def replace_all_device_with(device='cpu'): orig_cpu = torch.Tensor.cpu def patch_tensor_constructor(fn): + orig_func = getattr(fn, '__cube_orig_func__', fn) # to support nested patching def wrapper(*args, **kwargs): kwargs["device"] =device - return fn(*args, **kwargs) - wrapper.__name__ = fn.__name__ - wrapper.__qualname__ = fn.__qualname__ + return orig_func(*args, **kwargs) + wrapper.__name__ = orig_func.__name__ + wrapper.__qualname__ = orig_func.__qualname__ + wrapper.__cube_orig_func__ = orig_func return wrapper # these constructors are enough for most cases - patched_tensor_constructors = [ + patched_tensor_constructor_names = [ 'empty', 'zeros', 'ones', 'full', 'eye', 'linspace', 'logspace', 'arange', 'rand', 'randn', 'randint', 'randperm', @@ -99,13 +112,25 @@ def wrapper(*args, **kwargs): ] old_tensor_constructors = { tf_name: getattr(torch, tf_name) - for tf_name in patched_tensor_constructors + for tf_name in patched_tensor_constructor_names } patched_tensor_constructors = { tf_name: patch_tensor_constructor(fn) for tf_name, fn in old_tensor_constructors.items() } + patched_tensor_member_constructor_names = [ + 'new_empty', 'new_zeros', 'new_ones', 'new_full', 'new_tensor' + ] + old_tensor_member_constructors = { + tf_name: getattr(torch.Tensor, tf_name) + for tf_name in patched_tensor_member_constructor_names + } + patched_tensor_member_constructors = { + tf_name: patch_tensor_constructor(fn) + for tf_name, fn in old_tensor_member_constructors.items() + } + def patched_to(self, *args, **kwargs): if len(args) > 0 and isinstance(args[0], (torch.device, str)): args[0] = device @@ -129,6 +154,10 @@ def patched_cpu(self, *args, **kwargs): for tf_name, fn in old_tensor_constructors.items(): setattr(torch, tf_name, patched_tensor_constructors[tf_name]) + # patch tensor member constructors + for tf_name, fn in old_tensor_member_constructors.items(): + setattr(torch.Tensor, tf_name, patched_tensor_member_constructors[tf_name]) + # patch concrete tracer's autowrap leaf function for tf_name, fn in old_tensor_constructors.items(): leaf_info = ConcreteTracer.default_autowrap_leaf_function.pop(fn, None) @@ -144,8 +173,106 @@ def patched_cpu(self, *args, **kwargs): ConcreteTracer.default_autowrap_leaf_function[ old_tensor_constructors[tf_name] ] = leaf_info + for tf_name, fn in old_tensor_member_constructors.items(): + setattr(torch.Tensor, tf_name, fn) for tf_name, fn in old_tensor_constructors.items(): setattr(torch, tf_name, fn) torch.Tensor.to = orig_to torch.Tensor.cuda = orig_cuda torch.Tensor.cpu = orig_cpu + + +# mock process group is from pytorch testing code +# import torch.testing._internal.distributed.distributed_utils + +class MockProcessGroup(dist.ProcessGroup): + def __init__(self, rank, world): + super().__init__(rank, world) + + def getBackendName(self): + return "cube_mock_pg" + + +def create_mock_pg(prefix_store, rank, world_size, timeout): + return MockProcessGroup(rank, world_size) + + +dist.Backend.register_backend('cube_mock_pg', create_mock_pg) + + +def mock_init_dist(rank, world_size): + if dist.is_initialized(): + raise ValueError("dist is already initialized, cannot mock init") + + store = dist.HashStore() + + dist.init_process_group( + backend="cube_mock_pg", + rank=rank, + world_size=world_size, + store=store, + group_name="cube_fake", + timeout=timedelta(seconds=1)) + + +@contextmanager +def mock_dist(rank, world_size): + """ + Mock dist.init_process_group for testing + """ + + old_store_based_barrier = c10d._store_based_barrier + try: + c10d._store_based_barrier = lambda *args, **kwargs: None + mock_init_dist(rank, world_size) + yield + finally: + dist.destroy_process_group() + c10d._store_based_barrier = old_store_based_barrier + + +@contextmanager +def mock_cube_env(cube_module_cls: Type[ParallelModule], compute_config): + old_device_group = DeviceGroup.instance + old_dev_mode = CompileFlag.dev_mode + used_cuda_fns = ['set_device', 'current_device', 'default_stream'] + old_cuda_fns = { + fname: getattr(torch.cuda, fname) + for fname in used_cuda_fns + } + torchrun_envs = ['RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_WORLD_SIZE', 'GROUP_RANK'] + old_envs = { + env: os.environ.get(env, None) + for env in torchrun_envs + } + try: + DeviceGroup.instance = None + CompileFlag.dev_mode = False + for fname, fn in old_cuda_fns.items(): + setattr(torch.cuda, fname, lambda *args, **kwargs: None) + os.environ['RANK'] = os.environ['LOCAL_RANK'] = str(cube_module_cls.rank) + os.environ['WORLD_SIZE'] = os.environ['LOCAL_WORLD_SIZE'] = str(compute_config.runtime_ngpus) + os.environ['GROUP_RANK'] = '0' + yield + finally: + for env, val in old_envs.items(): + if val is None: + del os.environ[env] + else: + os.environ[env] = val + for fname, fn in old_cuda_fns.items(): + setattr(torch.cuda, fname, fn) + CompileFlag.dev_mode = old_dev_mode + DeviceGroup.instance = old_device_group + + +def new_empty(cube_module_cls: Type[ParallelModule]): + """ + Create a new instance with empty weights. + + This is useful when you want to get model information (e.g. fullmap/zero) without allocating memory. + """ + module_file = Path(sys.modules[cube_module_cls.__module__].__file__) + compute_config = torch.load(module_file.with_name(f"{cube_module_cls.COMPUTE_CONFIG_FILE}")) + with replace_all_device_with('meta'), mock_cube_env(cube_module_cls, compute_config), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): + return cube_module_cls(init_params=False) From 6107dbb3b7bebfe5d2fb0c0feac57e13ba889428 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 4 Jan 2024 02:51:23 +0000 Subject: [PATCH 1555/1892] Merged PR 1966: skip reporting warnings on graph anchor operators skip reporting warnings on graph anchor operators --- cube/graph/graph.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 3d5ce821..64adce04 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -203,7 +203,8 @@ def from_logic_graph(nodes: List[IRCell], if isinstance(obj, IRFullTensor) and obj.is_grad(): continue consumers = graph.consumers(obj) if len(consumers) == 0 and obj not in graph_output_objects: - if len(graph.producers(obj)) > 0: + producers = [n for n in graph.producers(obj) if not isinstance(n, IRGraphAnchor)] + if len(producers) > 0: unused_obj_nodes.setdefault(obj, []).extend(graph.producers(obj)) if len(unused_obj_nodes) > 0: dscp = (f'Following returns of nodes are not used by any other nodes.\n' From 0cca767ea98e595668d07a5828135237f1e8d6e0 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 4 Jan 2024 04:14:54 +0000 Subject: [PATCH 1556/1892] Merged PR 1979: disable ast test for python <= 3.8 disable ast test for python <= 3.8 --- tests/graph/parser/test_ast_transformer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 559e12a8..9e0aa18e 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -1,5 +1,8 @@ import ast from textwrap import dedent +import sys + +import pytest from cube.graph.parser.fx.concrete_trace_utils.operator_patcher import ( OperatorTransformer, @@ -8,6 +11,7 @@ transform ) +@pytest.mark.skipif(sys.version_info < (3, 9), reason='ast.unparse is not available in python3.8') def test_op_transfomer(): tree = ast.parse(dedent(''' x = True @@ -44,6 +48,7 @@ def test_op_transfomer(): ''').strip() +@pytest.mark.skipif(sys.version_info < (3, 9), reason='ast.unparse is not available in python3.8') def test_super_transform(): tree = ast.parse(dedent(''' class A: @@ -81,6 +86,7 @@ def i(self): ''').strip() +@pytest.mark.skipif(sys.version_info < (3, 9), reason='ast.unparse is not available in python3.8') def test_proxy_call_transform(): # the `(x+y)(a, b)` statement below just demonstrates # AST doesn't care about the real meaning of the expression. @@ -110,6 +116,7 @@ def f(self) -> None: ''').strip() +@pytest.mark.skipif(sys.version_info < (3, 9), reason='ast.unparse is not available in python3.8') def test_transform_combine(): tree = ast.parse(dedent(''' x = not True From 1d46d95543f510485b95aab53aa287e445bdc976 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 4 Jan 2024 10:28:50 +0000 Subject: [PATCH 1557/1892] Merged PR 1975: fix clz_wrapper_clz bug unit test passed parity check can not pass because of torch version is different, and the weight different is larger than 1e-6 As plan B of parity check, checked the gencode is same. --- .../concrete_trace_utils/concrete_tracer.py | 31 +++++++++++++++- .../parser/fx/concrete_trace_utils/utils.py | 1 + tests/graph/tracer/test_cls_wrapper.py | 37 +++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 tests/graph/tracer/test_cls_wrapper.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index bc357119..dc90c902 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -99,6 +99,7 @@ def __exit__(self, *args): _orig_range, _orig_int, + _orig_float, _orig_bool, _orig_tuple, _orig_list, @@ -255,6 +256,7 @@ class ConcreteTracer(TracerBase): # we don't want zip appear as a node in the graph # _orig_zip: ([], False), _orig_int: ([], False), + _orig_float: ([], False), # iterable class _orig_tuple: ([], True), @@ -1401,6 +1403,16 @@ def unwrap_detect_tracers(obj): return method_wrapper def _create_wrapped_leaf_class(tracer: ConcreteTracer, clz): + """ + Wrap a class as a tracable class, we usually wrap some classes that can be seen as creation functions. + For example, we can prevent the trace be interrupted by wrap ```tuple``` in the following case: + + ... + # x is a scalar + x_value = int(x) + new_x = torch.tensor([x_value, x_value]) + ... + """ class clz_wrapper_clz: @functools.wraps(clz) def __call__(self, *args, **kwargs): @@ -1422,9 +1434,26 @@ def __eq__(self, __o: object) -> bool: return id(__o) in (id(self), id(clz)) def __hash__(self): return id(self) - return clz_wrapper_clz() + clz_wrapper = clz_wrapper_clz() + for name in dir(clz): + attr = _orig_getattr(clz, name) + if not name.startswith('_'): + if _orig_isinstance(attr, Callable): + setattr(clz_wrapper, name, _create_wrapped_leaf_method(tracer, attr, name, None)) + else: + setattr(clz_wrapper, name, attr) + return clz_wrapper def _create_wrapped_leaf_iterable_class(tracer: ConcreteTracer, clz): + """ + Wrap a class as a tracable class, we usually wrap some classes that can be seen as creation functions. + For example, we can prevent the trace be interrupted by wrap ```tuple``` in the following case: + + ... + # x is a tensor + x_1st = tuple(x)[0] + ... + """ class clz_wrapper_clz: @functools.wraps(clz) def __call__(self, *args, **kwargs): diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index 24d23801..d2de8166 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -25,6 +25,7 @@ _orig_range: Type[Any] = builtins.range _orig_int: Type[Any] = builtins.int +_orig_float: Type[Any] = builtins.float _orig_bool: Type[Any] = builtins.bool _orig_tuple: Type[Any] = builtins.tuple _orig_list: Type[Any] = builtins.list diff --git a/tests/graph/tracer/test_cls_wrapper.py b/tests/graph/tracer/test_cls_wrapper.py new file mode 100644 index 00000000..7f4ef5dd --- /dev/null +++ b/tests/graph/tracer/test_cls_wrapper.py @@ -0,0 +1,37 @@ +import torch +import pytest + +from cube.graph.parser.converter import to_fx_graph + + +def test_cls_wrapper(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + # if we don't wrap tuple or float, the trace will raise error here + x_value = float(tuple(x)[0]) + x = torch.fill(torch.empty(1, 3), x_value) + return self.linear(x) + + dummy_input = {'x': torch.tensor([1.1, 1.2, 1.3])} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + # just check there is no error raised + assert True + + # The traced graph is: + # + # def forward(self, x): + # tuple_1 = tuple(x); x = None + # getitem = tuple_1[0]; tuple_1 = None + # float_1 = float(getitem); getitem = None + # empty = torch.empty(1, 3) + # fill = torch.fill(empty, float_1); empty = float_1 = None + # linear_weight = self.linear.weight + # linear_bias = self.linear.bias + # linear = torch._C._nn.linear(fill, linear_weight, linear_bias); fill = linear_weight = linear_bias = None + # return linear From 83bead8677d7f03620c41f2b44b069d20a773ef4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 5 Jan 2024 04:12:54 +0000 Subject: [PATCH 1558/1892] Merged PR 1982: support setitem with SSA format support setitem by adding a cube runtime opeartor returning the first argument. --- cube/graph/parser/converter.py | 5 ++ .../concrete_trace_utils/concrete_tracer.py | 1 - .../parser/fx/concrete_trace_utils/utils.py | 1 + cube/runtime/function/function.py | 6 ++ tests/graph/tracer/test_inplace.py | 65 ++++++++++++++++++- 5 files changed, 76 insertions(+), 2 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index d18d7dfa..c16e62ec 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -49,6 +49,11 @@ def _rewrite_inplace_ops(traced_model: torch.fx.GraphModule): (n.op == "call_method" and n.target.endswith("_") and not n.target.endswith("__")) or (n.op == "call_function" and n.target in side_effectful_inplace_ops) ) and n.args[0].meta.get('type', None) == torch.Tensor: + # setitem is a special inplace operator that returns None instead of the first modified argument, + # to make it align with SSA format, we use cube runtime function to return the first argument + if n.op == "call_function" and n.target == operator.setitem: + n.meta = n.args[0].meta + n.target = cube_rt_function.setitem n.args[0].replace_all_uses_with(n, delete_user_cb=lambda node: not node in done_nodes) # we can't recompile # it will raise error if we have autograd, customized op, etc. diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index dc90c902..f2feff02 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -1785,7 +1785,6 @@ def f(x, y): # some side effectful functions that should not be deleted during dead code elimination # there may be more than listed here default_extra_side_effectful_functions = { - operator.setitem, builtins.next, *side_effectful_inplace_ops } diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index d2de8166..fe1abc74 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -60,6 +60,7 @@ operator.imod, operator.ipow, # operator.imatmul is not implemented in torch # so let's ignore it now + operator.setitem, } diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index ac275501..d77378a6 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -1,6 +1,7 @@ from typing import Optional, List, Tuple, Union import torch import torch.nn.functional as TorchF +import operator def identity(tensor: torch.Tensor) -> torch.Tensor: @@ -210,3 +211,8 @@ def cat(*tensors, dim=0) -> torch.Tensor: def nndropout(input: torch.Tensor, p=0.5, inplace=False): return torch.nn.Dropout(p, inplace)(input) + + +def setitem(__a, __b, __c): + operator.setitem(__a, __b, __c) + return __a \ No newline at end of file diff --git a/tests/graph/tracer/test_inplace.py b/tests/graph/tracer/test_inplace.py index 547bd06d..13ca00f3 100644 --- a/tests/graph/tracer/test_inplace.py +++ b/tests/graph/tracer/test_inplace.py @@ -1,8 +1,10 @@ import operator +import _operator import torch from cube.graph.parser.converter import to_fx_graph from cube.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops +import cube.runtime.function as cube_rt_function from ...utils import replace_all_device_with @@ -20,8 +22,13 @@ def test_side_effectful_inplace_ops(): bool_inplace_ops = { operator.iand, operator.ior, operator.ixor, } + # add _operator versions to make it sure it still equals with operator version + # in the future + complex_inplace_ops = { + operator.setitem, _operator.setitem + } - assert int_inplace_ops.union(float_inplace_ops, bool_inplace_ops) == side_effectful_inplace_ops + assert int_inplace_ops.union(float_inplace_ops, bool_inplace_ops, complex_inplace_ops) == side_effectful_inplace_ops not_implemented_mat_inplace_ops = { operator.imatmul @@ -129,3 +136,59 @@ def test_inplace_op(): assert nodes[8].op == 'output' assert nodes[8].args[0] == nodes[7] # return value of `return w` + +class SetItemModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.num_heads = 4 + self.graph_token_virtual_distance = torch.nn.Embedding(1, self.num_heads) + + def forward(self, x): + t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1) + x[:, :, 1:, 0] = ( + x[:, :, 1:, 0] + t + ) + x[:, :, 0, :] = ( + x[:, :, 0, :] + t + ) + return x + + +@replace_all_device_with('cpu') +def test_inplace_setitem_op(): + model = SetItemModule() + dummy_input = {'x': torch.rand(4,4,4,4)} + + traced_graph = to_fx_graph(model, dummy_input) + + assert torch.equal( + model(dummy_input['x']), + traced_graph(dummy_input['x']) + ) + + nodes = list(traced_graph.graph.nodes) + for idx, node in enumerate(nodes): + target_name = node.target.__name__ if hasattr(node.target, '__name__') else node.target + print(f'{idx}: ({node.op}) {node} = {target_name}{node.args}') + + assert nodes[0].op == 'placeholder' + assert nodes[0].name == 'x' + + assert nodes[1].op == 'get_attr' + assert nodes[1].name == 'graph_token_virtual_distance_weight' + + assert nodes[5].op == 'call_function' + assert nodes[5].name == 'setitem' + assert nodes[5].target == cube_rt_function.setitem + + assert nodes[6].op == 'call_function' + assert nodes[6].name == 'getitem_1' + assert nodes[6].args[0] == nodes[5] + + assert nodes[8].op == 'call_function' + assert nodes[8].name == 'setitem_1' + assert nodes[8].target == cube_rt_function.setitem + assert nodes[8].args[0] == nodes[5] + + print(nodes[9].args[0]) + assert nodes[9].args[0] == nodes[8] From f0cfe09fafd78c34e1288344b415d93e20054a4d Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 5 Jan 2024 04:24:42 +0000 Subject: [PATCH 1559/1892] Merged PR 1980: Set up CI with Start Right add pipeline to cube. --- azure-pipelines.yml | 19 +++++++++++++++++++ tests/algorithm/ops/test_dimops.py | 4 ++++ tests/graph/gener/__init__.py | 0 tests/graph/gener/test_reducer_gen.py | 11 ++++++++--- tests/graph/parser/test_converter.py | 5 +++++ tests/graph/parser/test_dce.py | 5 ++++- tests/graph/parser/test_no_grad.py | 5 ++++- tests/graph/parser/test_parser.py | 7 ++++++- tests/graph/parser/test_register.py | 7 ++++++- tests/graph/tracer/test_getattr.py | 6 +++++- tests/parallel_module/test_gencode.py | 15 ++++----------- tests/parallel_module/test_inference.py | 5 +++++ tests/parallel_module/test_init.py | 4 +--- tests/parallel_module/test_nested.py | 4 +--- tests/parallel_module/test_override.py | 5 +---- tests/parallel_module/test_reducer_hook.py | 17 +++++------------ tests/parallel_module/test_submodule.py | 16 ++++------------ tests/parallel_module/test_wholemodule.py | 9 +++------ tests/runtime/test_runtime_collectives.py | 10 +++------- tests/test_program.py | 2 ++ tests/utils.py | 9 +++++++-- 21 files changed, 97 insertions(+), 68 deletions(-) create mode 100644 azure-pipelines.yml create mode 100644 tests/graph/gener/__init__.py diff --git a/azure-pipelines.yml b/azure-pipelines.yml new file mode 100644 index 00000000..b1769e99 --- /dev/null +++ b/azure-pipelines.yml @@ -0,0 +1,19 @@ +# Starter pipeline +# Start with a minimal pipeline that you can customize to build and deploy your code. +# Add steps that build, run tests, deploy, and more: +# https://aka.ms/yaml + +trigger: +- main + +pool: + vmImage: ubuntu-latest + +steps: +- script: | + pip install tox + pip install tox-conda + displayName: 'Install tox' +- script: | + tox + displayName: 'Run unit tests' diff --git a/tests/algorithm/ops/test_dimops.py b/tests/algorithm/ops/test_dimops.py index 232c29f1..6c3c8590 100644 --- a/tests/algorithm/ops/test_dimops.py +++ b/tests/algorithm/ops/test_dimops.py @@ -5,6 +5,8 @@ from cube.ir.operator import IRFwOperation from cube.algorithm.ops.dimops import gen_partitions +from ...utils import replace_all_device_with + class NaiveFFN(torch.nn.Module): def __init__(self): super().__init__() @@ -18,6 +20,8 @@ def forward(self, x): x = self.linear2(x) return x + +@replace_all_device_with('cpu') def test_gen_partitions(): with tempfile.TemporaryDirectory() as tempdir: graph, _ = _gen_graph(NaiveFFN(), {'x': torch.randn(2, 128, 1024)}, tempdir, False) diff --git a/tests/graph/gener/__init__.py b/tests/graph/gener/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/graph/gener/test_reducer_gen.py b/tests/graph/gener/test_reducer_gen.py index 104d1271..9e74df82 100644 --- a/tests/graph/gener/test_reducer_gen.py +++ b/tests/graph/gener/test_reducer_gen.py @@ -11,6 +11,8 @@ import torch import tempfile +from ...utils import replace_all_device_with + def make_param(shape, dtype) -> IRFullTensor: param = IRFullTensor(shape=shape, dtype=dtype, requires_grad=True) @@ -31,7 +33,7 @@ def forward(self, x): x = x + self.param1 x = torch.sum(x) return x - + def build_graph(): # build graph @@ -47,6 +49,7 @@ def build_graph(): return graph +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_cross_segment_weight_reducer(): graph = build_graph() @@ -59,7 +62,7 @@ def test_cross_segment_weight_reducer(): continue for node in segment.nodes(): graph.assign(node, idx) - + print(graph.extra_repr()) # build reducer @@ -72,6 +75,7 @@ def test_cross_segment_weight_reducer(): assert reducers[0].device == (0, 1) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_replicate_shared_param(): graph = build_graph() @@ -87,6 +91,7 @@ def test_replicate_shared_param(): assert len(reducers) == 0 +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_reducer_partially_shared_part(): graph = build_graph() [matmul1, matmul2, add, sum] = graph.select(ntype=IRFwOperation) @@ -103,7 +108,7 @@ def test_reducer_partially_shared_part(): sn1, sn2 = graph.replicate(node, 2) graph.assign(sn1, 0) graph.assign(sn2, 1) - + print(graph.extra_repr()) with pytest.raises(RuntimeError): diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 80e2fe8b..b933c3ab 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -9,7 +9,10 @@ from cube.graph.parser import FxModuleParser from cube.ir.cten import IRObject, IRTensor +from ...utils import replace_all_device_with + +@replace_all_device_with('cpu') def test_to_graph(): class MyModule(torch.nn.Module): def __init__(self): @@ -58,6 +61,7 @@ def forward(self, x, **kwargs): assert any(node.signature == 'torch.nn.functional.linear' for node in nodes) +@replace_all_device_with('cpu') def test_record_codeline(): class MyModule(torch.nn.Module): def __init__(self): @@ -78,6 +82,7 @@ def forward(self, x, *args): raise RuntimeError(err_msg) +@replace_all_device_with('cpu') def test_record_metadata(): class MyModule(torch.nn.Module): def __init__(self): diff --git a/tests/graph/parser/test_dce.py b/tests/graph/parser/test_dce.py index a7138e39..49156d4b 100644 --- a/tests/graph/parser/test_dce.py +++ b/tests/graph/parser/test_dce.py @@ -3,14 +3,17 @@ from cube.graph.parser.converter import to_fx_graph +from ...utils import replace_all_device_with + +@replace_all_device_with('cpu') def test_dce_useless_next(): class SimpleModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.m_dict = torch.nn.ModuleDict({'test': torch.nn.Linear(10, 8)}) self.m_keys = ['test', 'default'] - + def forward(self, x): for key in self.m_keys: # TODO: (ning) This kind of code style will call instruction 'BINARY_SUBSCR' diff --git a/tests/graph/parser/test_no_grad.py b/tests/graph/parser/test_no_grad.py index a8d48943..444d6094 100644 --- a/tests/graph/parser/test_no_grad.py +++ b/tests/graph/parser/test_no_grad.py @@ -3,14 +3,17 @@ from cube.graph.parser.converter import to_fx_graph +from ...utils import replace_all_device_with + +@replace_all_device_with('cpu') def test_no_grad(): class SimpleModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(10, 10) self.fc2 = torch.nn.Linear(10, 10) - + def forward(self, x): with torch.no_grad(): x = self.fc1(x) diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index e9cb16fc..8e261eec 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -3,7 +3,10 @@ from cube.ir.cten import IRObject, IRTensor from cube.graph.parser.converter import to_fx_graph, to_ir_graph +from ...utils import replace_all_device_with + +@replace_all_device_with('cpu') def test_multi_consume(): class MyModule(torch.nn.Module): @@ -34,6 +37,7 @@ def forward(self, x): assert len(ir_graph.full_tensors()) == 8 +@replace_all_device_with('cpu') def test_parser_nested_inputs(): class MyModule(torch.nn.Module): @@ -58,7 +62,7 @@ def forward(self, x: dict): with tempfile.TemporaryDirectory() as tempdir: ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) print(ir_graph.extra_repr()) - + assert len(ir_graph.inputs()) == 1 assert isinstance(ir_graph.input(0), IRObject) assert isinstance(ir_graph.input(0).value, dict) @@ -68,6 +72,7 @@ def forward(self, x: dict): assert isinstance(ir_graph.output(0)['loss'], IRTensor) +@replace_all_device_with('cpu') def test_max(): class MyModule(torch.nn.Module): diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index a08e6699..544bd5e4 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -4,6 +4,8 @@ import tempfile import torch +from ...utils import replace_all_device_with + def mock_add(x: torch.Tensor, y: torch.Tensor): return x + y @@ -57,6 +59,7 @@ def forward(self, x, y): # passed test +@replace_all_device_with('cpu') def test_common_register(): model = MockModel() with tempfile.TemporaryDirectory() as tempdir: @@ -68,6 +71,7 @@ def test_common_register(): assert profile_name == p_name, f'{profile_name} should be {p_name}' +@replace_all_device_with('cpu') def test_common_register2(): model = MockModel2() with tempfile.TemporaryDirectory() as tempdir: @@ -79,11 +83,12 @@ def test_common_register2(): assert profile_name == p_name, f'{profile_name} should be {p_name}' +@replace_all_device_with('cpu') def test_autograd_register(): model = MockModel3() with tempfile.TemporaryDirectory() as tempdir: ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) - + # test profiler.database for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'Function.apply']): profile_name = ProfileDataBase.get_func(node)[0].__qualname__ diff --git a/tests/graph/tracer/test_getattr.py b/tests/graph/tracer/test_getattr.py index 09e43e0a..7a6e22c3 100644 --- a/tests/graph/tracer/test_getattr.py +++ b/tests/graph/tracer/test_getattr.py @@ -2,14 +2,18 @@ from cube.graph.parser.converter import to_fx_graph +from ...utils import replace_all_device_with + class SimpleModel(torch.nn.Module): def __init__(self): super().__init__() - + def forward(self, x): return torch.nn.functional.dropout(x, 0.1, self.training) + +@replace_all_device_with('cpu') def test_getattr_from_root(): model = SimpleModel() dummy_input = {'x': torch.rand(10)} diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 4419b23e..0c76aad9 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -37,10 +37,8 @@ def _gencode_worker(tempdir): _to_cube_model(m, ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_codegen(): - if not torch.cuda.is_available(): - print('skip test_codegen due to lack of cuda devices') - return with tempfile.TemporaryDirectory() as tempdir: m = Module0() m_new = _to_cube_model(m, ComputeConfig(2, 4), cube_savedir=tempdir, load_module=False) @@ -140,14 +138,12 @@ def _gencode_unused_args_worker(tempdir): # y must be None m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_codegen_unused_args(): """ Verify that unused args are supported by parallalize """ - if not torch.cuda.is_available(): - print('skip test_unused_input due to lack of cuda devices') - return - with tempfile.TemporaryDirectory() as tempdir: launch_torchrun(1, _gencode_unused_args_worker, tempdir) @@ -191,14 +187,11 @@ def _gencode_unused_args_worker2(tempdir): m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_codegen_unused_args2(): """ Verify that unused args are supported by parallalize """ - if not torch.cuda.is_available(): - print('skip test_codegen_unused_args2 due to lack of cuda devices') - return - with tempfile.TemporaryDirectory() as tempdir: launch_torchrun(1, _gencode_unused_args_worker2, tempdir) diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index eaf33c49..2e5d1928 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -1,6 +1,8 @@ from pathlib import Path import shutil import tempfile + +import pytest import torch from torch import nn @@ -71,9 +73,12 @@ def _inference_worker(ngpus): cube_result = cube_model(data) assert torch.allclose(result, cube_result, atol=1e-4) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_inference1(): torchrun(1, _inference_worker, 1) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_inference2(): torchrun(2, _inference_worker, 2) diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index a6b0826f..ea94f3ab 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -44,10 +44,8 @@ def _init_params_worker(): assert torch.all(p3 == 0) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_init_params(): - if not torch.cuda.is_available(): - print('skip test_init_params due to lack of cuda devices') - return launch_torchrun(1, _init_params_worker) diff --git a/tests/parallel_module/test_nested.py b/tests/parallel_module/test_nested.py index 080ee1d5..4e64d23d 100644 --- a/tests/parallel_module/test_nested.py +++ b/tests/parallel_module/test_nested.py @@ -48,8 +48,6 @@ def forward(self, x): _to_cube_model(Module2(), PASData, ComputeConfig(1, 1), cube_savedir=tempdir) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_nested_module(): - if not torch.cuda.is_available(): - print('skip test_nested_module due to lack of cuda devices') - return launch_torchrun(1, _nested_module_worker) diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index dd6b0375..38a1d2c1 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -119,10 +119,7 @@ def _worker(): assert (test5_module_path / 'graph.ckp').stat().st_mtime_ns != graph_stat.st_mtime_ns assert (test5_module_path / 'forward_args.pkl').exists() - +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_override(): - if not torch.cuda.is_available(): - print('skip test_submodules_tp_gpu1 due to lack of cuda devices') - return launch_torchrun(1, _worker) diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index a8bff918..27d4a688 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -2,6 +2,7 @@ from pathlib import Path from collections import defaultdict +import pytest import torch from torch import nn @@ -116,29 +117,21 @@ def _gpu_worker(pas, ngpus): _train(compiled_module) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_hook_tp_gpu1(): - if not torch.cuda.is_available(): - print('skip test_hook_tp_gpu1 due to lack of cuda devices') - return launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_hook_tp_gpu2(): - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - print('skip test_hook_tp_gpu2 due to lack of cuda devices') - return launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_hook_dp_gpu1(): - if not torch.cuda.is_available(): - print('skip test_hook_dp_gpu1 due to lack of cuda devices') - return launch_torchrun(1, _gpu_worker, PASData, 1) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_hook_dp_gpu2(): - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - print('skip test_hook_dp_gpu2 due to lack of cuda devices') - return launch_torchrun(2, _gpu_worker, PASData, 2) diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 6fc72b2c..1ac54ea1 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -125,11 +125,9 @@ def _gpu_worker(pas, ngpus, update_freq): ) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @pytest.mark.parametrize('update_freq', [1, 2, 4]) def test_submodules_tp_gpu1(update_freq): - if not torch.cuda.is_available(): - print('skip test_submodules_tp_gpu1 due to lack of cuda devices') - return results = launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1, update_freq) orig_results, compiled_results, _, _, _, _ = results[0] for orig, compiled in zip(orig_results, compiled_results): @@ -186,11 +184,9 @@ def _compare_weights(orig0, orig1, compiled0, compiled1, fc1_fullmap, fc2_fullma assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') @pytest.mark.parametrize('update_freq', [1, 2, 4]) def test_submodules_tp_gpu2(update_freq): - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - print('skip test_submodules_tp_gpu2 due to lack of cuda devices') - return results = launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2, update_freq) results0, results1 = results[0], results[1] eps = 1e-4 @@ -221,11 +217,9 @@ def test_submodules_tp_gpu2(update_freq): _compare_weights(orig0[3], orig1[3], compiled0[3], compiled1[3], fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='cuda is not available') @pytest.mark.parametrize('update_freq', [1, 2, 4]) def test_submodules_dp_gpu1(update_freq): - if not torch.cuda.is_available(): - print('skip test_submodules_dp_gpu1 due to lack of cuda devices') - return results = launch_torchrun(1, _gpu_worker, PASData, 1, update_freq) orig_results, compiled_results, _, _, _, _ = results[0] for orig, compiled in zip(orig_results, compiled_results): @@ -245,11 +239,9 @@ def test_submodules_dp_gpu1(update_freq): assert torch.allclose(orig[3][k], compiled_cleaned[k.replace('.', '_')], rtol=1e-6, atol=1e-6) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') @pytest.mark.parametrize('update_freq', [1, 2, 4]) def test_submodules_dp_gpu2(update_freq): - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - print('skip test_submodules_dp_gpu2 due to lack of cuda devices') - return eps = 1e-4 results = launch_torchrun(2, _gpu_worker, PASData, 2, update_freq) for r in results.values(): diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index e9307846..4f787cfb 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -4,6 +4,7 @@ from pathlib import Path import shutil +import pytest import torch from torch import nn import numpy as np @@ -110,10 +111,8 @@ def _gpu_worker(pas, ngpus): ) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_module_tp_gpu1(): - if not torch.cuda.is_available(): - print('skip test_submodules_tp_gpu1 due to lack of cuda devices') - return results = launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) orig_results, compiled_results, _, _ = results[0] for orig, compiled in zip(orig_results, compiled_results): @@ -141,10 +140,8 @@ def _compare_weights(orig0, orig1, compiled0, compiled1, module_fullmap, module_ assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_module_tp_gpu2(): - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - print('skip test_submodules_tp_gpu2 due to lack of cuda devices') - return results = launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) results0, results1 = results[0], results[1] eps = 1e-4 diff --git a/tests/runtime/test_runtime_collectives.py b/tests/runtime/test_runtime_collectives.py index ff72261e..a31dc5aa 100644 --- a/tests/runtime/test_runtime_collectives.py +++ b/tests/runtime/test_runtime_collectives.py @@ -2,6 +2,7 @@ import cube import torch +import pytest from ..launch_torchrun import launch_torchrun, clone_to_cpu @@ -98,11 +99,8 @@ def _2gpu_worker(): return result - +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_2gpu(): - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - print('skip test_2gpu due to lack of cuda devices') - return results = launch_torchrun(2, _2gpu_worker) for op in ['', '_async']: @@ -189,10 +187,8 @@ def _3gpu_worker(): return result +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 3, reason='lack of gpu devices') def test_3gpu(): - if not torch.cuda.is_available() or torch.cuda.device_count() < 3: - print('skip test_3gpu due to lack of cuda devices') - return results = launch_torchrun(3, _3gpu_worker) for op in ['', '_async']: diff --git a/tests/test_program.py b/tests/test_program.py index eb893040..7dbc3a34 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -1,3 +1,4 @@ +import pytest import torch from cube.program import SemanticModel, Program @@ -5,6 +6,7 @@ from cube.ir.cten import IRObject +@pytest.mark.skipif(not torch.cuda.is_available(), reason='cuda is not available') def test_program_model_nested_input(): class MyModule(torch.nn.Module): diff --git a/tests/utils.py b/tests/utils.py index 511b710b..8ceeecf8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -86,7 +86,12 @@ def assert_same_complex(gt, out): @contextmanager -def replace_all_device_with(device='cpu'): +def replace_all_device_with(device='cpu', force=False): + if not force and torch.cuda.is_available(): + # do not replace device if cuda is available + yield + return + from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import ConcreteTracer orig_to = torch.Tensor.to @@ -274,5 +279,5 @@ def new_empty(cube_module_cls: Type[ParallelModule]): """ module_file = Path(sys.modules[cube_module_cls.__module__].__file__) compute_config = torch.load(module_file.with_name(f"{cube_module_cls.COMPUTE_CONFIG_FILE}")) - with replace_all_device_with('meta'), mock_cube_env(cube_module_cls, compute_config), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): + with replace_all_device_with('meta', True), mock_cube_env(cube_module_cls, compute_config), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): return cube_module_cls(init_params=False) From 326bc6d04994c46064dfbda351ce1690b26101b1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 5 Jan 2024 09:19:39 +0000 Subject: [PATCH 1560/1892] Merged PR 1984: fix build break fix build break --- tests/graph/tracer/test_cls_wrapper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/graph/tracer/test_cls_wrapper.py b/tests/graph/tracer/test_cls_wrapper.py index 7f4ef5dd..a88569f6 100644 --- a/tests/graph/tracer/test_cls_wrapper.py +++ b/tests/graph/tracer/test_cls_wrapper.py @@ -3,7 +3,10 @@ from cube.graph.parser.converter import to_fx_graph +from ...utils import replace_all_device_with + +@replace_all_device_with('cpu') def test_cls_wrapper(): class MyModule(torch.nn.Module): def __init__(self): From f2ae6e71b538c6fdb7fb571d1f2e31d3efb62666 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 5 Jan 2024 11:13:35 +0000 Subject: [PATCH 1561/1892] Merged PR 1978: Pipeline Support-2: remove useless output adapter in schedule This PR will remove useless backward adapter (no communications) caused by forward adapters of broadcasting graph outputs passed parity check and unit test --- cube/graph/gener/gen.py | 4 ++++ cube/graph/schedule/predefined.py | 8 ++++++++ examples/mlp/policy/gallery.py | 4 ++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 99e9d2a3..8732032f 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -428,6 +428,10 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: badapter: Optional[IRAdapter] = fadapter.mirror + # skip badapter if it doesn't contain any primitives + if not fadapter.differentiable and (badapter is not None and len(badapter.prims) == 0): + badapter = None + if (badapter is not None and len(fadapter.prims) == 0 and len(badapter.prims) == 0) or \ (badapter is None and len(fadapter.prims) == 0): continue diff --git a/cube/graph/schedule/predefined.py b/cube/graph/schedule/predefined.py index 7a197b12..bfbd3506 100644 --- a/cube/graph/schedule/predefined.py +++ b/cube/graph/schedule/predefined.py @@ -24,6 +24,8 @@ def sched_1f1b(graph: IRGraph, num_microbatches: int, num_stages: int) -> Schedu f0 b0 f1 | b1 f2 | b2 f3 b3 ``` """ + if num_microbatches <= 0: + raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) fsegs = [seg for seg in segments if seg.isfw()] assert len(fsegs) == num_stages, f"Mismatch of forward segement number ({len(fsegs)}) with num_stages ({len(num_stages)})" @@ -58,6 +60,8 @@ def sched_1f1b_plus(graph: IRGraph, num_microbatches: int, num_stages: int) -> S f0 f1 f0 f2 f1 b0 | f3 f2 b1 | f3 b2 b3 f0 f1 f0 f2 b0 f1 | f3 b1 f2 | b2 f3 b """ + if num_microbatches <= 0: + raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) fsegs = [seg for seg in segments if seg.isfw()] tp_fsegs = [seg for seg in fsegs if len(seg.device) == len(graph.device)] @@ -130,6 +134,8 @@ def sched_gpipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> Sched f0 f1 f2 f3 b0 b1 b2 b3 ``` """ + if num_microbatches <= 0: + raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) fsegs = [seg for seg in segments if seg.isfw()] assert len(fsegs) == num_stages, "Mismatch of forward segement number with num_stages" @@ -168,6 +174,8 @@ def sched_chimera_direct(graph: IRGraph, num_microbatches: int, num_stages: int) Note the f0 and f2 (step 0) should be considered to be one segment in graph. """ + if num_microbatches <= 0: + raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) fsegs = [seg for seg in segments if seg.isfw()] assert len(fsegs) == 4, f"Chimera-direct scheduling only applies for 4 segments, but {len(segments)} detected" diff --git a/examples/mlp/policy/gallery.py b/examples/mlp/policy/gallery.py index 06f91aa0..5df705ac 100644 --- a/examples/mlp/policy/gallery.py +++ b/examples/mlp/policy/gallery.py @@ -79,10 +79,10 @@ def PASMegatron(graph: IRGraph, resource, nmicros: int, tp_size: int, **kwargs) if node.name == 'linear': tensor_parallelism(graph, node, idx=1, dim=idx%2, devs=tp_group) else: - replica(node, devs=tp_group) + replica(graph, node, devs=tp_group) for dl in graph.select(ntype=IRDataOperation): - replica(dl, devs=list(range(resource.ngpus))) + replica(graph, dl, devs=list(range(resource.ngpus))) PredefinedSched.sched_1f1b(graph, nmicros, num_stages) return graph From c84b53b8e571966fac7f8e20ac61e34bf9f13e8f Mon Sep 17 00:00:00 2001 From: Quanlu Zhang Date: Mon, 8 Jan 2024 02:01:16 +0000 Subject: [PATCH 1562/1892] Merged PR 1942: improve profile speed of AutoDist improve profile speed by giving the upper bound of the profile time for each operator. --- cube/profiler/database.py | 30 ++++++++++++++++++----------- tests/profiler/__init__.py | 0 tests/profiler/test_op_profile.py | 32 +++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 11 deletions(-) create mode 100644 tests/profiler/__init__.py create mode 100644 tests/profiler/test_op_profile.py diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 429d5f3b..e30b5e7b 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -7,6 +7,7 @@ import time import os import json +import math import logging from dataclasses import dataclass, asdict @@ -57,7 +58,7 @@ class CompProfiler: @staticmethod def profile(node: IRFwOperation, func: Callable, shapes: Shapes, dtypes: DTypes, requires_grads: Tuple[bool], values: Tuple[Any], - warmup_sec: float = 2, prof_times: int = 50, + warmup_sec: float = 2, prof_times: int = 20, max_prof_sec: float = 20, **kwargs) -> Tuple[float, float, int, Tuple[int]]: """ Profile a function @@ -70,7 +71,8 @@ def profile(node: IRFwOperation, func: Callable, shapes: Shapes, dtypes: DTypes, requires_grads Tuple[bool]: whether the input tensor requires gradient values Tuple[Any]: the values of the inputs that are not IRTensor warmup_sec float: warmup seconds - prof_times int: profile times + prof_times int: number of execution for profiling an operator + max_prof_sec float: max seconds for profiling an operator's forward or backward kwargs Dict: other keyword argument for func call. Returns: @@ -142,8 +144,8 @@ def run_step(func, tensors, kwargs, backward: bool): # ref torch/utils/checkpoint.py/_checkpoint_without_reentrant def pack_hook(x): nonlocal train_mem_info, used_tensor - if x.storage().data_ptr() not in used_tensor: - used_tensor.add(x.storage().data_ptr()) + if x.untyped_storage().data_ptr() not in used_tensor: + used_tensor.add(x.untyped_storage().data_ptr()) byte_size = x.element_size() for dim in list(x.size()): byte_size = byte_size * dim @@ -152,7 +154,7 @@ def pack_hook(x): for i, t in enumerate(tensors): if not isinstance(t, torch.Tensor): continue - if t.storage().data_ptr() == x.storage().data_ptr(): + if t.untyped_storage().data_ptr() == x.untyped_storage().data_ptr(): if node.inputs()[i].is_attr(): is_attr = True idx = i @@ -171,28 +173,34 @@ def unpack_hook(x): outs = run_step(func, tensors, train_kwargs, backward=require_backward) # warmup - tic = time.time() - while time.time() - tic < warmup_sec: + warmup_cnt = 0 + tic = time.perf_counter() + while time.perf_counter() - tic < warmup_sec: run_step(func, tensors, train_kwargs, backward=require_backward) + torch.cuda.synchronize() + warmup_cnt += 1 + toc = time.perf_counter() + func_duration = (toc - tic) / warmup_cnt + real_prof_times = max(1, min(prof_times, math.ceil(max_prof_sec / func_duration))) # profile forward only torch.cuda.synchronize() tic = time.perf_counter() - for _ in range(prof_times): + for _ in range(real_prof_times): with torch.no_grad(): run_step(func, tensors, eval_kwargs, backward=False) torch.cuda.synchronize() toc = time.perf_counter() - fw_span = (toc - tic) / prof_times * 1000 # in milliseconds + fw_span = (toc - tic) / real_prof_times * 1000 # in milliseconds # profile forward + backward torch.cuda.synchronize() tic = time.perf_counter() - for _ in range(prof_times): + for _ in range(real_prof_times): run_step(func, tensors, train_kwargs, backward=require_backward) torch.cuda.synchronize() toc = time.perf_counter() - fwbw_span = (toc - tic) / prof_times * 1000 # in milliseconds + fwbw_span = (toc - tic) / real_prof_times * 1000 # in milliseconds bw_span = max(fwbw_span - fw_span, 0.0) return fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx diff --git a/tests/profiler/__init__.py b/tests/profiler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/profiler/test_op_profile.py b/tests/profiler/test_op_profile.py new file mode 100644 index 00000000..c5a90581 --- /dev/null +++ b/tests/profiler/test_op_profile.py @@ -0,0 +1,32 @@ +import time +import tempfile +import torch +from cube.parallel import _gen_graph +from cube.ir.operator import IRFwOperation +from cube.profiler.database import CompProfiler, ProfileDataBase +from ..utils import replace_all_device_with + +class NaiveFFN(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(1024, 4096, bias=False) + self.linear2 = torch.nn.Linear(4096, 1024, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + +@replace_all_device_with('cpu') +def test_op_profile_times(): + with tempfile.TemporaryDirectory() as tempdir: + graph, _ = _gen_graph(NaiveFFN(), {'x': torch.randn(2, 128, 1024)}, tempdir, False) + fc1, relu, fc2 = graph.select(ntype=IRFwOperation) + fn, shapes, dtypes, requires_grads, values, kwargs = ProfileDataBase.get_func(fc1) + tic = time.perf_counter() + CompProfiler.profile(fc1, fn, shapes, dtypes, requires_grads, values, **kwargs) + toc = time.perf_counter() + # this is always true because the op is very small. + assert toc - tic < 20, f'op profile time is too long {toc - tic}' From e123c38c8963a72e9572938d52b48352f3f0cace Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 8 Jan 2024 04:25:34 +0000 Subject: [PATCH 1563/1892] Merged PR 1985: fix patcher for torch 2.0.0 fix patcher for torch 2.0.0 --- .../parser/fx/concrete_trace_utils/operator_patcher.py | 6 +++++- tests/profiler/test_op_profile.py | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index e289513e..f7a4efe9 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -163,7 +163,11 @@ def patch_inner(self, func): def patch_inner_helper(self, func): if not hasattr(func, '__module__') or func.__module__ is None or func.__module__.startswith('torch'): return func - if hasattr(func, '_Patcher__fx_already_patched'): + # those flags are set by fx _Patcher when a method is patched + # we don't want to patch it again + # _Patcher__fx_already_patched is for torch 2.0.1+ + # __fx_already_patched is for torch 2.0.0 + if hasattr(func, '_Patcher__fx_already_patched') or hasattr(func, '__fx_already_patched'): return func if self.use_operator_patch == (func in self.operator_patch_backlist): return func diff --git a/tests/profiler/test_op_profile.py b/tests/profiler/test_op_profile.py index c5a90581..d2ef7252 100644 --- a/tests/profiler/test_op_profile.py +++ b/tests/profiler/test_op_profile.py @@ -1,10 +1,13 @@ import time import tempfile + +import pytest import torch + from cube.parallel import _gen_graph from cube.ir.operator import IRFwOperation from cube.profiler.database import CompProfiler, ProfileDataBase -from ..utils import replace_all_device_with + class NaiveFFN(torch.nn.Module): def __init__(self): @@ -19,7 +22,8 @@ def forward(self, x): x = self.linear2(x) return x -@replace_all_device_with('cpu') + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_op_profile_times(): with tempfile.TemporaryDirectory() as tempdir: graph, _ = _gen_graph(NaiveFFN(), {'x': torch.randn(2, 128, 1024)}, tempdir, False) From 84f852f86b80b73f9098038fc7632889ba987d72 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 8 Jan 2024 09:09:38 +0000 Subject: [PATCH 1564/1892] Merged PR 1986: update README for ci update README for ci --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index 086a53ae..8aa344bd 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,16 @@ To run a single unit test task during development, you can run pytest tests/your_test_file.py ``` +### Unit tests in AzureDevops pipeline + +We use AzureDevops to run unit tests before you can merge your PR to main branch. You can find the pipeline definition in `azure-pipelines.yml`. + +Please note that in AzureDevops pipeline agent, no gpu is available. So you must make sure your unit tests can run on cpu to pass the CI. Two options are available: +1. Use `@replace_all_device_with('cpu')` decorator to replace all devices with cpu. Please refer to other tests for example. +2. Mark your test case only work on gpu by using `@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices')` decorator. Please refer to existing tests for example. + +Before you push your code, please run tests at least on GPU machines to make sure all tests can pass. GPU test cases can't be run in AzureDevops pipeline. Of course, it would be better if you can run all tests on both GPU and CPU machines. + ### Run unit tests in vscode VS Code has a great support to unit tests. You can run/debug every tests easily in VS Code. Please refer to this document to set up your environment https://code.visualstudio.com/docs/python/testing From ba1920ec35b8ad58676bac5db4205571ba954fa6 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 9 Jan 2024 08:48:38 +0000 Subject: [PATCH 1565/1892] Merged PR 1981: support apex norm memory efficient option apex add an option `memory_efficient` to normalization from this [PR](https://github.com/NVIDIA/apex/commit/6ff45486f432f91eb86937a0def5eb5f2cf792ae). --- cube/graph/parser/external/apex.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cube/graph/parser/external/apex.py b/cube/graph/parser/external/apex.py index 08cd9461..ba50f074 100644 --- a/cube/graph/parser/external/apex.py +++ b/cube/graph/parser/external/apex.py @@ -31,7 +31,7 @@ def apex_fused_layer_norm_anno(input, normalized_shape, *args, **kwargs): apex_fused_rms_norm_anno = apex_fused_layer_norm_anno - def apex_fused_layer_norm_affine_anno(input, weight, bias, normalized_shape, eps) -> str: + def apex_fused_layer_norm_affine_anno(input, weight, bias, normalized_shape, eps, *args, **kwargs) -> str: """ apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction """ @@ -46,10 +46,11 @@ def apex_fused_layer_norm_affine_anno(input, weight, bias, normalized_shape, eps inputs.append(ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) if weight is not None else '?') inputs.append(ShapeAnno.create_shape_str(bias.shape, reduction='^', iterator=letters) if bias is not None else '?') inputs += ['?', '?'] + inputs += ['?' for _ in args] return OpAnno.create_op_str(inputs, outputs) - def apex_fused_rms_norm_affine_anno(input, weight, normalized_shape, eps) -> str: + def apex_fused_rms_norm_affine_anno(input, weight, normalized_shape, eps, *args, **kwargs) -> str: """ apex.normalization.fused_layer_norm.FusedRMSNormAffineFunction """ @@ -62,6 +63,7 @@ def apex_fused_rms_norm_affine_anno(input, weight, normalized_shape, eps) -> str inputs = [input_anno] inputs.append(ShapeAnno.create_shape_str(weight.shape, reduction='^', iterator=letters) if weight is not None else '?') inputs += ['?', '?'] + inputs += ['?' for _ in args] return OpAnno.create_op_str(inputs, outputs) From dd2cbcb0ad89e5449cab902f4cade5d48a91e69f Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 9 Jan 2024 09:29:00 +0000 Subject: [PATCH 1566/1892] Merged PR 1977: fix bug in operator patcher to support trace scope If we want to record the module stack during trace, the `__call__` of the module should be call, but previous we will directly call forward for the non-torch module. In this pr, change the call of module to `__call__`. --- .../concrete_trace_utils/concrete_tracer.py | 43 +++++++++++- .../concrete_trace_utils/operator_patcher.py | 35 ++++++++-- tests/graph/tracer/test_op_patcher.py | 68 +++++++++++++++++++ tests/graph/tracer/test_scope.py | 58 ++++++++++++++++ 4 files changed, 195 insertions(+), 9 deletions(-) create mode 100644 tests/graph/tracer/test_op_patcher.py create mode 100644 tests/graph/tracer/test_scope.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index f2feff02..f281747b 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -712,9 +712,13 @@ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, such as '__main__.FooModel' or '__main__.bar_func'. the namespace is always needed. """ + if not isinstance(root, torch.nn.Module): + # TODO: support trace any callable function by add the fill default values logic. + raise RuntimeError('Only support trace a torch.nn.Module instance now.') + # fill default values - args = inspect.getfullargspec(root.forward).args[1:] - defaults = inspect.getfullargspec(root.forward).defaults + args = inspect.getfullargspec(getattr(root, forward_function_name)).args[1:] + defaults = inspect.getfullargspec(getattr(root, forward_function_name)).defaults defaults = tuple() if defaults is None else defaults if isinstance(concrete_args, (tuple, list)): concrete_args = (*concrete_args, *defaults[len(concrete_args) + len(defaults) - len(args):]) @@ -1050,6 +1054,11 @@ def torch_assert_wrapper(condition, message): wrapped = _create_wrapped_attr_for_middle_class(self, clz, self.the_path_of_middle_class) self.wrapped_leaf[clz.__getattribute__] = (((clz, '__getattribute__'),), wrapped) + # wrap all forward in the submodule to trace the module stack + for mod in self.root.modules(): + wrapped = _create_wrapped_nn_module_func(self, mod, forward_function_name) + self.wrapped_leaf[mod.forward] = (((mod, forward_function_name),), wrapped) + @functools.wraps(_orig_isinstance) def isinstance_wrapper(instance, clz): if _orig_type(clz) in (slice, tuple, list, _orig_slice, _orig_tuple, _orig_list): @@ -1246,6 +1255,36 @@ def wrapped(*args, **kwargs): return wrapped +def _create_wrapped_nn_module_func(tracer: ConcreteTracer, mod: torch.nn.Module, name: str): + orig_fn = _orig_getattr(mod, name) + if not _orig_isinstance(orig_fn, MethodType): + raise RuntimeError(f'{tracer.path_of_module(mod)}.{name} is not a bound method, only support wrap bound method.') + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + module_qualified_name = tracer.path_of_module(mod) + with ScopeContextManager(tracer.scope, Scope(module_qualified_name, type(mod))) as _scope: + need_pop = False + if _scope.module_path not in tracer.module_stack: + need_pop = True + tracer.module_stack[_scope.module_path] = _scope.module_type + elif _scope.module_path != list(tracer.module_stack)[-1]: + raise RuntimeError(f'Scope not match: {_scope.module_path} vs {list(tracer.module_stack)[-1]}') + # has tracer means in tracing progress + if OperatorPatcherContext.ctx_tracer and OperatorPatcherContext.ctx_patcher: + # `patch_run` is needed because this function will be patched by fx patcher, + # which means it will have `__fx_already_patched` flag, and operator patcher will not patch it again, + # so directly call `patch_run` here to avoid the `orig_fn is not patched by the operator patcher. + result = OperatorPatcherContext.patch_run(orig_fn, *args, **kwargs) + else: + result = orig_fn(*args, **kwargs) + if need_pop: + key, _ = tracer.module_stack.popitem(last=True) + assert key == _scope.module_path, f" Unexpected key {key}" + return result + + return wrapped + @compatibility(is_backward_compatible=True) class GraphAppendingConcreteTracer(ConcreteTracer): diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index f7a4efe9..91c36553 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -157,11 +157,31 @@ def __init__(self, use_operator_patch: bool, operator_patch_backlist: List[str]) self.operator_patch_backlist = operator_patch_backlist self.proxy_call_name = OperatorPatcherContext.patch_run.__name__ - def patch_inner(self, func): - return self.patch_inner_helper(func) + def patch_func_or_module(self, func_or_module): + if _orig_isinstance(func_or_module, torch.nn.Module): + module, func = func_or_module, func_or_module.forward + new_func = self.patch_func_helper(func) + module.forward = new_func + return module + else: + return self.patch_func_helper(func_or_module) - def patch_inner_helper(self, func): - if not hasattr(func, '__module__') or func.__module__ is None or func.__module__.startswith('torch'): + def patch_func_helper(self, func): + """ + Patch a function here means we will modify the function source code and recompile to a new one. + The reason of patching function is some code style is not supported to trace, but these cases are common used, + we don't want users put effort on modify their source code (or even some widely used packages' code) for these cases. + + The following will be modify right now: + 1. not a -> operator.not_(a) + 2. a is b -> operator.is_(a, b) + 3. a is not b -> operator.is_not(a, b) + 4. a in b -> operator.contains(b, a) + 5. a not in b -> operator.not_(operator.contains(b, a)) + 6. super() -> super(self.__class__, self) + 7. func(a, b, c) -> patch_run(func, a, b, c) # for patch the functions called in the current function + """ + if not hasattr(func, '__module__') or func.__module__ is None or func.__module__.startswith('torch.'): return func # those flags are set by fx _Patcher when a method is patched # we don't want to patch it again @@ -171,14 +191,15 @@ def patch_inner_helper(self, func): return func if self.use_operator_patch == (func in self.operator_patch_backlist): return func - if _orig_isinstance(func, torch.nn.Module): - func = func.forward + if _orig_isinstance(func, MethodType): + # patch the function, not bound method, the function will be bound back after patch func_inner = func.__func__ the_self = func.__self__ else: func_inner = func the_self = None + # if it is not a function, or it has no code, then we can not patch it, directly return if not _orig_isinstance(func_inner, FunctionType) or not hasattr(func_inner, '__code__'): return func @@ -280,5 +301,5 @@ def patch_run(func, *args, **kwargs): assert OperatorPatcherContext.ctx_tracer is not None assert OperatorPatcherContext.ctx_patcher is not None with OperatorPatcherContext.ctx_tracer.do_temp_disable(True, True, True): - new_func = OperatorPatcherContext.ctx_patcher.patch_inner(func) + new_func = OperatorPatcherContext.ctx_patcher.patch_func_or_module(func) return new_func(*args, **kwargs) diff --git a/tests/graph/tracer/test_op_patcher.py b/tests/graph/tracer/test_op_patcher.py new file mode 100644 index 00000000..20ca94f0 --- /dev/null +++ b/tests/graph/tracer/test_op_patcher.py @@ -0,0 +1,68 @@ +import torch +from types import MethodType +from cube.graph.parser.fx.concrete_trace_utils.operator_patcher import OperatorPatcher + + +def test_patch_func_or_module(): + op_patcher = OperatorPatcher(True, []) + + # case 1: normal function + def normal_func(a, b): + return a + b + new_func = op_patcher.patch_func_or_module(normal_func) + assert normal_func == new_func + + def normal_func_with_compare(a, b): + return a is not b + new_func = op_patcher.patch_func_or_module(normal_func_with_compare) + assert normal_func_with_compare != new_func + + # case 2: bound function + obj = object() + bound_func = MethodType(normal_func, obj) + new_func = op_patcher.patch_func_or_module(bound_func) + assert bound_func == new_func + + obj = object() + bound_func_with_compare = MethodType(normal_func_with_compare, obj) + new_func = op_patcher.patch_func_or_module(bound_func) + assert bound_func_with_compare != new_func + + # case 3: module + class SimpleModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x + + model = SimpleModel() + orig_forward = model.forward + model_with_orig_forward = op_patcher.patch_func_or_module(model) + assert model == model_with_orig_forward and model_with_orig_forward.forward == orig_forward + + class SimpleModelWithCompare(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, cp1, cp2): + if cp1 is cp2: + return x + else: + return x * 2 + + model = SimpleModelWithCompare() + orig_forward = model.forward + model_with_new_forward = op_patcher.patch_func_or_module(model) + assert model == model_with_new_forward and model_with_new_forward.forward != orig_forward + + # case 4: module.forward + model = SimpleModel() + orig_forward = model.forward + new_forward = op_patcher.patch_func_or_module(model.forward) + assert new_forward == orig_forward and not isinstance(new_forward, torch.nn.Module) + + model = SimpleModelWithCompare() + orig_forward = model.forward + new_forward = op_patcher.patch_func_or_module(model.forward) + assert new_forward != orig_forward and not isinstance(new_forward, torch.nn.Module) diff --git a/tests/graph/tracer/test_scope.py b/tests/graph/tracer/test_scope.py new file mode 100644 index 00000000..cdbd2fa4 --- /dev/null +++ b/tests/graph/tracer/test_scope.py @@ -0,0 +1,58 @@ +import torch + +from cube.graph.parser.converter import to_fx_graph +from ...utils import replace_all_device_with + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(10, 5) + self.m = SimpleModel2() + + def forward(self, x): + # node add_2 + return self.fc(x) + self.m.forward(x) + +class SimpleModel2(torch.nn.Module): + def __init__(self): + super().__init__() + self.m2 = SimpleModel3() + self.ffn = torch.nn.Linear(10, 5) + + def forward(self, x): + # node add_1 + return self.m2.forward(x) + self.ffn(x) + +class SimpleModel3(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(10, 5) + + def forward(self, x): + # node add + return self.fc1(x) + self.fc2(x) + + +@replace_all_device_with('cpu') +def test_scope(): + model = SimpleModel() + dummy_input = {'x': torch.rand(10)} + traced_graph = to_fx_graph(model, dummy_input) + traced_graph(**dummy_input) + + name_map = { + 'add': 'm.m2', + 'add_1': 'm', + # 'add_2': None # add_2 is at root module, so it will have an empty stack + } + + viewed_nodes = set() + for node in traced_graph.graph.nodes: + if node.name in name_map: + viewed_nodes.add(node.name) + module_path = list(node.meta['nn_module_stack'])[-1] + assert module_path == name_map[node.name], f'{module_path} == {name_map[node.name]}' + + assert viewed_nodes == set(name_map.keys()) From 29f820531899f5e07209b72ef322486edc24ed93 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 10 Jan 2024 06:41:50 +0000 Subject: [PATCH 1567/1892] Merged PR 1991: make barrier in longer time wait make barrier in longer time wait This PR will allow ALL communications happened in a process group to wait longer time (6 hours currently) --- cube/runtime/device.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cube/runtime/device.py b/cube/runtime/device.py index 0fbaeacd..b6ce32a7 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -6,6 +6,7 @@ import torch import os import logging +import datetime from cube.flags import CompileFlag @@ -25,7 +26,8 @@ def __init__(self): self.node_rank = 0 else: if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl') + torch.distributed.init_process_group( + backend='nccl', timeout=datetime.timedelta(seconds=21600)) self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() # assume each node has the same device number @@ -70,7 +72,8 @@ def get_group(self, ranks): return None rank_bits = DeviceGroup.bitmap(ranks) if rank_bits not in self.instance.groups: - self.groups[rank_bits] = torch.distributed.new_group(list(ranks)) + self.groups[rank_bits] = torch.distributed.new_group( + list(ranks), timeout=datetime.timedelta(seconds=21600)) return self.groups[rank_bits] def get_stream(self, name: str) -> torch.cuda.Stream: From 13418ddb3ded523e19d6c207e8716b75c47dea09 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 10 Jan 2024 08:18:08 +0000 Subject: [PATCH 1568/1892] Merged PR 1992: rename reuse flags none -> override all -> match add moo to support the behavior between override and match --- cube/parallel.py | 171 ++++++++++++++++++------- tests/parallel_module/test_gencode.py | 8 +- tests/parallel_module/test_init.py | 4 +- tests/parallel_module/test_override.py | 119 ++++++++++++----- tests/utils.py | 6 +- 5 files changed, 217 insertions(+), 91 deletions(-) diff --git a/cube/parallel.py b/cube/parallel.py index 7d4468e2..d86e4cc7 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -63,6 +63,43 @@ class ComputeConfig: # ``` reducer_op: str = 'sum' + # you can put any configuration here + # *Note*: the assumption is different user_config should generate different code. + # Example 1: save module configuration + # ```python + # class MyModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # def forward(self, x): + # ... + # if module_config.use_3d: + # ... + # ``` + # here we can set `user_config={'use_3d': module_config.use_3d}`, + # and we can be sure different use_3d will never use the same generated code. + # Example 2: save file stats + # If you want to track all related file stats (just like traditional compilers do), + # you can do + # ```python + # user_config = { + # 'file_stats': { + # str(f): os.stat(f).st_mtime_ns for f in Path('./src').glob('**/*.py') # assume all source code is in ./src + # } + # } + # ``` + # Or you can save the md5 of the files to save some bytes: + # ```python + # import hashlib + # h = hashlib.md5() + # for f in Path('./src').glob('**/*.py'): + # with open(f, 'rb') as f: + # h.update(f.read()) + # user_config = { + # 'files_md5': h.hexdigest() + # } + # ``` + user_config: Optional[Dict[str, Any]] = None + def __post_init__(self): if self.plan_ngpus <= 0: raise ValueError(f"plan_ngpus {self.plan_ngpus} must be > 0") @@ -146,6 +183,16 @@ def _get_arg_default_values(fn) -> Dict[str, Any]: return {k: v.default for k, v in args.parameters.items()} +def _clean_files(_dir: Path, pattern = '*') -> None: + """ + Clean files of a directory. No directories will be removed. + """ + for f in _dir.glob(pattern): + if f.is_file(): + f.unlink() + + +_DEFAULT_INSTANCE_NAME = '_' _GENCODE_FILE_PREFIX = 'gencode' _GENCODE_FILE_TEMPLATE = _GENCODE_FILE_PREFIX + '{}.py' # 'gencode{}.py' _CUBE_MODULE_NAMESPACE = '_cube_modules' @@ -155,17 +202,32 @@ def _get_arg_default_values(fn) -> Dict[str, Any]: class ReuseType(Enum): """The reuse type""" - NONE = 'none' # no reuse, everything will be regenerated. - ALL = 'all' # try to reuse everything if possible - GRAPH = 'graph' # only graph will be reused (so we don't need to trace the graph again) + MATCH = 'match' # reuse if present and match, error if present but not match, generate if not present. + OVERRIDE = 'override' # no reuse, everything will be regenerated. + MOO = 'moo' # (short for match or override)reuse if present and match, generate if not match or not present. + GRAPH = 'graph' # reuse graph only if present and match, generate otherwise. + + +def _prepare_namespace( + cube_savedir: str, + module_or_module_class: Union[Type[torch.nn.Module], torch.nn.Module], + instance_name: Optional[str] = None, +): + cube_savedir = _add_cube_savedir_to_syspath(cube_savedir) + + instance_name = instance_name or _DEFAULT_INSTANCE_NAME + instance_name = instance_name.strip('.') if instance_name else '' + instance_namespace = f'.{instance_name}' if instance_name else '' + namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_or_module_class)}{instance_namespace}' + return namespace def _prepare_and_check_reusable( - cube_savedir, + cube_savedir: str, module_or_module_class: Union[Type[torch.nn.Module], torch.nn.Module], - compute_config, - instance_name, - reuse: ReuseType = ReuseType.ALL, + compute_config: ComputeConfig, + instance_name: Optional[str] = None, + reuse: ReuseType = ReuseType.MATCH, ) -> Tuple[str, bool]: """ Prepare the output directory for code generation, and also check if the existing code is reusable. @@ -174,7 +236,7 @@ def _prepare_and_check_reusable( cube_savedir (str): the directory to save generated code module_or_module_class (Union[Type[torch.nn.Module], torch.nn.Module]): the original module or module class compute_config (ComputeConfig): the environment resource - instance_name (Optional[str]): the instance name of the generated module. + instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. reuse (ReuseType): specify which part can be reused. Returns: @@ -184,36 +246,41 @@ def _prepare_and_check_reusable( RuntimeError: if the existing code is not reusable, will raise RuntimeError if the code is not reusable but the module is already loaded. """ - - cube_savedir = _add_cube_savedir_to_syspath(cube_savedir) - - instance_name = instance_name.strip('.') if instance_name else '' - instance_namespace = f'.{instance_name}' if instance_name else '' - namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_or_module_class)}{instance_namespace}' + namespace = _prepare_namespace(cube_savedir, module_or_module_class, instance_name) outdir = cube_savedir / Path(namespace.replace('.', '/').strip('/')) outdir.mkdir(parents=True, exist_ok=True) # decision matrix for code generation # reuse flag | dir condition(imported, empty, match, unmatched) | action # --------------------------------------------------------- - # NONE/GRAPH | empty | generate - # NONE/GRAPH | imported | raise error - # NONE/GRAPH | match | generate - # NONE/GRAPH | unmatch | generate - # ALL | empty | generate - # ALL | match | do nothing - # ALL | unmatch | raise error - # ALL | imported | doesn't matter + # OVERRIDE/GRAPH | empty | generate + # OVERRIDE/GRAPH | imported | raise error + # OVERRIDE/GRAPH | match | generate + # OVERRIDE/GRAPH | unmatch | generate + # MATCH | empty | generate + # MATCH | match | reuse(do nothing) + # MATCH* | unmatch | raise error (except when there's no python source code, see below) + # MATCH | imported | doesn't matter + # MOO | empty | generate + # MOO | match | reuse(do nothing) + # MOO* | unmatch | generate (specail case is when there's no python source code, see below) + # MOO | imported | raise error if unmatch + # *: The precondition for `except` part is the compute config should match. + # you can take it as a continous operation after a failed MATCH/OVERRIDE. reusable = False - if reuse == ReuseType.ALL: - trace_meta_files = [ - outdir / FxModuleParser.ATTR_CONTENT_FILE_0, # just check the first is good enough - outdir / FxModuleParser.ATTR_MAP_FILE, - ] + config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE + old_config = torch.load(config_file) if config_file.exists() else None + is_config_match = old_config == compute_config + trace_meta_files = [ + outdir / FxModuleParser.ATTR_CONTENT_FILE_0, # just check the first is good enough + outdir / FxModuleParser.ATTR_MAP_FILE, + ] + + if reuse == ReuseType.MATCH or reuse == ReuseType.MOO: # check if the module is already generated expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] expected_output_files.extend(trace_meta_files) - expected_output_files.append(outdir / ParallelModule.COMPUTE_CONFIG_FILE) + expected_output_files.append(config_file) expected_output_files.append(outdir / _GRAPH_DUMP_FILE) expected_output_files.append(outdir / _FORWARD_ARGS_DUMP_FILE) existing_output_files = [ @@ -224,11 +291,12 @@ def _prepare_and_check_reusable( ) ] if existing_output_files: - if all([output_file.exists() for output_file in expected_output_files]) \ - and len(existing_output_files) == len(expected_output_files) \ - and torch.load(outdir / ParallelModule.COMPUTE_CONFIG_FILE) == compute_config: + if is_config_match \ + and all([output_file.exists() for output_file in expected_output_files]) \ + and len(existing_output_files) == len(expected_output_files): reusable = True # everything is matched. - elif all(f.suffix != '.py' for f in existing_output_files): + elif is_config_match \ + and all(f.suffix != '.py' for f in existing_output_files): # No python source code is generated. # which means its last generation failed. # in this case, we can reuse the same directory safely. @@ -237,27 +305,32 @@ def _prepare_and_check_reusable( f'Will reuse the directory and the graph dump if present.') # we have to trace the graph again if not all meta files are present. if not all([meta_file.exists() for meta_file in trace_meta_files]): - for f in outdir.glob('*'): - if f.is_file(): - f.unlink() - else: + _clean_files(outdir) + elif reuse == ReuseType.MATCH: raise RuntimeError(f'Output directory {outdir} is not empty. ' f'And the existing files do not match with current config. ' f'You can remove the directory and try again, ' - f'or set reuse to ReuseType.NONE to regenerate the code.') + f'or set reuse to ReuseType.NONE/ReuseType.OVERRIDE to regenerate the code.') + else: + assert reuse == ReuseType.MOO + if _is_any_gencode_loaded(namespace): + raise RuntimeError(f'Output directory {outdir} is already loaded. ' + f'You can not override a loaded module.') + _clean_files(outdir) else: # check if the module is already loaded if _is_any_gencode_loaded(namespace): raise RuntimeError(f'Output directory {outdir} is already loaded. ' f'You can not override a loaded module.') # clear existing generated files - if reuse == ReuseType.NONE: + if reuse == ReuseType.OVERRIDE \ + or not is_config_match \ + or not all([meta_file.exists() for meta_file in trace_meta_files]): + # we have to trace the graph again if not all meta files are present even when reuse=graph. glob_pattern = '*' else: glob_pattern = '*.py' # so we can keep graph dumps. - for f in outdir.glob(glob_pattern): - if f.is_file(): - f.unlink() + _clean_files(outdir, glob_pattern) return outdir, reusable @@ -435,7 +508,6 @@ def _gencode( execplan = Grouping.apply(execplan) # code generation - torch.save(compute_config, outdir / ParallelModule.COMPUTE_CONFIG_FILE) assert len(execplan.graph.device) == compute_config.plan_ngpus, f"{execplan.graph.device}" mgener = ModuleCodeGen(execplan, compute_config.runtime_ngpus) for rank in range(compute_config.runtime_ngpus): @@ -462,17 +534,15 @@ def _load_cube_module_class( Args: module_class (Type[torch.nn.Module]): the original module class cube_savedir (Union[str, Path]): the directory to load generated code - instance_name (Optional[str]): the instance name of the generated module. + instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. rank (Optional[int]): the rank of the module. If it is None, will get the rank from torch.distributed.get_rank(). This option is only useful for debugging or writing pre/post-processing tools. when you need to load the generated module in a non-torchrun environment. """ - _add_cube_savedir_to_syspath(cube_savedir) rank = torch.distributed.get_rank() if rank is None else rank - instance_name = instance_name.strip('.') if instance_name else '' - instance_namespace = f'.{instance_name}' if instance_name else '' + namespace = _prepare_namespace(cube_savedir, module_class, instance_name) gen_imported = importlib.import_module( - f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}{instance_namespace}.{Path(_GENCODE_FILE_TEMPLATE.format(rank)).stem}' + f'{namespace}.{Path(_GENCODE_FILE_TEMPLATE.format(rank)).stem}' ) cube_module_class = gen_imported.GenModel # rewrite class name and module name @@ -490,7 +560,7 @@ def parallelize( compute_config: ComputeConfig, *, cube_savedir: Union[str, Path] = './.cube', - reuse: Union[ReuseType, str] = ReuseType.ALL, + reuse: Union[ReuseType, str] = ReuseType.MATCH, instance_name: Optional[str] = None, load_module: bool = True, module_dtype: Optional[torch.dtype] = None, @@ -553,7 +623,7 @@ def __init__(self, init_params=True): compute_config (ComputeConfig): the environment resource reuse (ReuseType): specify which part can be reused. cube_savedir (Union[str, Path]): the directory to save generated code - instance_name (Optional[str]): the instance name of the generated module. + instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. load_module (bool): whether to load the generated module or module class after conversion is done. init_module_params (bool): If true, when we construct the module, all its parameters are initialized with the same value with when we traced. Otherwise, they will be empty tensor. @@ -582,6 +652,9 @@ def __init__(self, init_params=True): if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: outdir, reusable = _prepare_and_check_reusable(cube_savedir, module_class, compute_config, instance_name, reuse) if not reusable: + config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE + if not config_file.exists(): + torch.save(compute_config, config_file) with _compile_flags(compute_config): _gencode( module_or_module_class, diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 0c76aad9..04c9c2ed 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -205,10 +205,10 @@ def forward(self, x, attr): def _gencode_contains(cubesave_dir, module_class, index, search_re): - from cube.parallel import _CUBE_MODULE_NAMESPACE, _get_full_qualified_name + from cube.parallel import _CUBE_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME from pathlib import Path import re - namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}' + namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) filecontent = (outdir /f'gencode{index}.py').read_text() matches = re.findall(search_re, filecontent) @@ -419,7 +419,7 @@ def test_codegen_tensor_slice(): ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False, - reuse='none', + reuse='override', ) m = TensorSliceFixedModule() m.train() @@ -430,7 +430,7 @@ def test_codegen_tensor_slice(): ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False, - reuse='none', + reuse='override', ) diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index ea94f3ab..755fdd58 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -27,7 +27,7 @@ def _init_params_worker(): PASRandomSPMD, ComputeConfig(1, 1), cube_savedir=tempdir, - reuse='all', + reuse='match', ) module1 = cube_module() module2 = cube_module() @@ -70,7 +70,7 @@ def test_empty_weights(model_class, tp): PASRandomSPMD, ComputeConfig(2, 4, use_zero=True, zero_ngroups=2), cube_savedir=tempdir, - reuse='all', + reuse='match', load_module=False, ) for i in range(4): diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index 38a1d2c1..348b89cb 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -6,23 +6,33 @@ import torch import shutil -from cube.parallel import ReuseType, parallelize, ComputeConfig +from cube.graph.parser.fx.parser import FxModuleParser +from cube.parallel import ReuseType, parallelize, ComputeConfig, _load_cube_module_class -from .common import PASData, init_distributed -from ..launch_torchrun import launch_torchrun +from ..utils import new_empty, replace_all_device_with +from .common import PASData -def _to_cube_model(module, compute_config, cube_savedir, reuse, instance_name, load_module=True): - return parallelize( - module, +def _to_cube_model(model_class, compute_config, cube_savedir, reuse, instance_name, load_module=True): + parallelize( + model_class, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, PASData, compute_config, reuse=reuse, cube_savedir=cube_savedir, instance_name=instance_name, - load_module=load_module, + load_module=False, ) + if load_module: + module_class = _load_cube_module_class( + model_class, + cube_savedir=cube_savedir, + instance_name=instance_name, + rank=0 + ) + m = new_empty(module_class, device='cpu', init_params=True) + return m class MyModule(torch.nn.Module): @@ -34,21 +44,20 @@ def forward(self, x): return self.linear(x) -def _worker(): - init_distributed() - +@replace_all_device_with('cpu') +def test_override(): with tempfile.TemporaryDirectory() as tempdir: - # False | empty | generate - cmodule1 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.ALL, None) - # False | match | do nothing - cmodule2 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.ALL, None) - # true + # MATCH | empty | generate + cmodule1 = _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MATCH, 'mm0') + # MATCH | match | do nothing + cmodule2 = _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MATCH, 'mm0') for (n1, v1), (n2, v2) in zip(cmodule1.named_parameters(), cmodule2.named_parameters()): assert n1 == n2 assert torch.equal(v1, v2) - cmodule3 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.ALL, 'test') - cmodule4 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, 'all', 'test') + # MATCH | match | do nothing + cmodule3 = _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MATCH, 'test') + cmodule4 = _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, 'match', 'test') for (n1, v1), (n2, v2) in zip(cmodule3.named_parameters(), cmodule4.named_parameters()): assert n1 == n2 @@ -59,19 +68,42 @@ def _worker(): keys = cmodule3_p.keys() assert any(not torch.equal(cmodule2_p[key], cmodule3_p[key]) for key in keys) - # True | imported | raise error + # MATCH | unmatch | raise error + _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MATCH, 'm0') + with pytest.raises(RuntimeError, match='.*not empty.*'): + _to_cube_model(MyModule, ComputeConfig(2, 2),tempdir, 'match', 'm0') + + # MOO | empty | generate + omodule1 = _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o0') + # MOO | match | do nothing + omodule2 = _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o0') + for (n1, v1), (n2, v2) in zip(omodule1.named_parameters(), omodule2.named_parameters()): + assert n1 == n2 + assert torch.equal(v1, v2) + + # MOO | unmatch | generate + _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o1', load_module=False) + _to_cube_model(MyModule, ComputeConfig(2, 2),tempdir, ReuseType.MOO, 'o1') + + # MOO | imported | raise error + _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o2', load_module=True) + with pytest.raises(RuntimeError): + _to_cube_model(MyModule, ComputeConfig(2, 2),tempdir, ReuseType.MOO, 'o2') + + # OVERRIDE | imported | raise error with pytest.raises(RuntimeError): - _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.NONE, None) + _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.OVERRIDE, 'mm0') + # OVERRIDE | imported | raise error with pytest.raises(RuntimeError): - _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.NONE, 'test') + _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.OVERRIDE, 'test') - # False | unmatch | raise error + # OVERRIDE | imported | raise error with pytest.raises(RuntimeError): - _to_cube_model(MyModule(), ComputeConfig(2, 2),tempdir, ReuseType.NONE, 'test') + _to_cube_model(MyModule, ComputeConfig(2, 2),tempdir, ReuseType.OVERRIDE, 'test') - # True | empty | generate - cmodule1 = _to_cube_model(MyModule(), ComputeConfig(1, 1),tempdir, ReuseType.NONE, 'test2') + # OVERRIDE | empty | generate + cmodule1 = _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.OVERRIDE, 'test2') module_path = Path(sys.modules[cmodule1.__module__].__file__).parent test3_module_path = module_path.with_name('test3') test3_module_path.mkdir(exist_ok=True, parents=True) @@ -88,16 +120,16 @@ def _worker(): shutil.copy(test4_module_path / 'gencode0.py', test4_module_path / 'gencode1.py') shutil.copy(test5_module_path / 'gencode0.py', test5_module_path / 'gencode1.py') - # True | match | generate - cmodule2 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, ReuseType.NONE, 'test3') + # OVERRIDE | match | generate + cmodule2 = _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, ReuseType.OVERRIDE, 'test3') cmodule2_p = dict(cmodule2.named_parameters()) cmodule1_p = dict(cmodule1.named_parameters()) keys = cmodule2_p.keys() assert any(not torch.equal(cmodule2_p[key], cmodule1_p[key]) for key in keys) - # True | unmatch | generate + # OVERRIDE | unmatch | generate assert (test4_module_path / 'gencode1.py').exists() - cmodule3 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, 'none', 'test4') + cmodule3 = _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'override', 'test4') assert not (test4_module_path / 'gencode1.py').exists() # Graph | matched | generate @@ -105,7 +137,7 @@ def _worker(): code_stat = (test5_module_path / 'gencode0.py').stat() graph_stat = (test5_module_path / 'graph.ckp').stat() args_stat = (test5_module_path / 'forward_args.pkl').stat() - cmodule4 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, 'graph', 'test5', False) + _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'test5', False) assert not (test5_module_path / 'gencode1.py').exists() assert (test5_module_path / 'gencode0.py').stat().st_mtime_ns != code_stat.st_mtime_ns assert (test5_module_path / 'graph.ckp').stat().st_mtime_ns == graph_stat.st_mtime_ns @@ -114,12 +146,33 @@ def _worker(): code_stat = (test5_module_path / 'gencode0.py').stat() graph_stat = (test5_module_path / 'graph.ckp').stat() (test5_module_path / 'forward_args.pkl').unlink() # remove foward_args.pkl will force to generate new code - cmodule5 = _to_cube_model(MyModule(), ComputeConfig(1, 1), tempdir, 'graph', 'test5', False) + _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'test5', False) assert (test5_module_path / 'gencode0.py').stat().st_mtime_ns != code_stat.st_mtime_ns assert (test5_module_path / 'graph.ckp').stat().st_mtime_ns != graph_stat.st_mtime_ns assert (test5_module_path / 'forward_args.pkl').exists() -@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') -def test_override(): - launch_torchrun(1, _worker) + code_stat = (test5_module_path / 'gencode0.py').stat() + graph_stat = (test5_module_path / 'graph.ckp').stat() + attrmap_stat = (test5_module_path / FxModuleParser.ATTR_MAP_FILE).stat() + (test5_module_path / FxModuleParser.ATTR_CONTENT_FILE_0).unlink() # remove fullmodel.pt.0 will force to generate new code + _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'test5', False) + assert (test5_module_path / 'gencode0.py').stat().st_mtime_ns != code_stat.st_mtime_ns + assert (test5_module_path / 'graph.ckp').stat().st_mtime_ns != graph_stat.st_mtime_ns + assert (test5_module_path / FxModuleParser.ATTR_MAP_FILE).stat().st_mtime_ns != attrmap_stat.st_mtime_ns + assert (test5_module_path / 'forward_args.pkl').exists() + # Graph | empty | generate + g6_module = _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'g6') + + # Graph | imported | raise error + with pytest.raises(RuntimeError): + _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'g6') + + # Graph | unmatch | generate + _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'g7', False) + g7_module_path = module_path.with_name('g7') + graph_stat = (g7_module_path / 'graph.ckp').stat() + args_stat = (g7_module_path / 'forward_args.pkl').stat() + _to_cube_model(MyModule, ComputeConfig(2, 2), tempdir, 'graph', 'g7', False) + assert (g7_module_path / 'graph.ckp').stat().st_mtime_ns != graph_stat.st_mtime_ns + assert (g7_module_path / 'forward_args.pkl').stat().st_mtime_ns != args_stat.st_mtime_ns diff --git a/tests/utils.py b/tests/utils.py index 8ceeecf8..7a4e4fa6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -271,7 +271,7 @@ def mock_cube_env(cube_module_cls: Type[ParallelModule], compute_config): DeviceGroup.instance = old_device_group -def new_empty(cube_module_cls: Type[ParallelModule]): +def new_empty(cube_module_cls: Type[ParallelModule], device='meta', init_params=False): """ Create a new instance with empty weights. @@ -279,5 +279,5 @@ def new_empty(cube_module_cls: Type[ParallelModule]): """ module_file = Path(sys.modules[cube_module_cls.__module__].__file__) compute_config = torch.load(module_file.with_name(f"{cube_module_cls.COMPUTE_CONFIG_FILE}")) - with replace_all_device_with('meta', True), mock_cube_env(cube_module_cls, compute_config), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): - return cube_module_cls(init_params=False) + with replace_all_device_with(device, True), mock_cube_env(cube_module_cls, compute_config), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): + return cube_module_cls(init_params=init_params) From 047ae96502a4d32434ca19143c3fc013da2f5fc0 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 11 Jan 2024 06:52:53 +0000 Subject: [PATCH 1569/1892] Merged PR 1990: Pipeline Support-3: Adapter schedule and return format Adapter schedule: - For adapters created of return tensors (typically broadcast primitives), they are inserted at the end of training / inference iteration Return format: - If the iteration has returned tensors, e.g., `return loss, x`, then the outputs of compiled function will return `[(loss_0, x_0), (loss_1, x_1), ..., (loss_nmicros, x_nmicros)]`, where the x_i means i-th micro-batch output of x. User example: ```python # samples: List[..] each item is a micro-batch's forward input # wrap multiple data samples into a dataloader dataloader = microbatch(samples) # multiple forward + backward outputs = train_iter(model, dataloader) # e.g., [(loss_0, loggings_0,), (loss_1, loggings_1), ...] for i, outs in enumerate(outputs): print('this is i-th micro-batch outputs:') ... # handle outs ``` --- cube/codegen/schedule/schedule.py | 6 +++--- cube/execplan/execplan.py | 27 +++++++++++++++++++++++++-- cube/graph/schedule/schedplan.py | 13 +++++++++++++ cube/graph/segment.py | 4 ++-- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/cube/codegen/schedule/schedule.py b/cube/codegen/schedule/schedule.py index d72e61e6..496d7563 100644 --- a/cube/codegen/schedule/schedule.py +++ b/cube/codegen/schedule/schedule.py @@ -72,7 +72,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: assert all(not isinstance(n, IRFwOperation) for n in device_nodes), \ "Expected all forward operators have been grouped into IRSegment" - lifetime = LifeCycle(device_nodes, [], self.execplan.graph.outputs()) + lifetime = LifeCycle(device_nodes, [], self.execplan.outputs()) args = ['model'] + [self.tensor_name(t) for t in self.execplan.graph.inputs()] @@ -93,7 +93,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: if len(tensors) > 0 : # not necessarily to have one after each line fb.insert_body(self.emit_release(tensors)) # return code - outputs = self.return_name_complex(self.execplan.graph.outputs()) + outputs = self.return_name_complex(self.execplan.outputs()) code = f'return {outputs}' fb.insert_body(code) gencode += fb.code @@ -120,7 +120,7 @@ def gen(self, device: int, outfile=None, attach=None) -> str: if len(tensors) > 0 : # not necessarily to have one after each line fb.insert_body(self.emit_release(tensors)) # return code - outputs = self.return_name_complex(self.execplan.graph.outputs()) + outputs = self.return_name_complex(self.execplan.outputs()) code = f'return {outputs}' fb.insert_body(code) gencode += fb.code diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 95cd086d..2b802b2a 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Any import copy import numpy as np import sys @@ -130,7 +130,20 @@ def block2reuse(node: Block) -> ExeReuseCell: block = block2reuse(block) assert isinstance(block, IRCell) topo_seqs.append(block) - return ExecutionPlan(schedplan.graph, topo_seqs) + + # set up returning outputs by packing output results from each micro-batch into a list + outputs = [] + for mid in range(schedplan.nmicros): + outs = [] + for output in schedplan.graph.outputs(): + outs.append(IRSegment.modify_objects_of_complex(output, lambda x: get(x, mid))) + if len(outs) > 0: + outputs.append(outs[0] if len(outs) == 1 else outs) + + execplan = ExecutionPlan(schedplan.graph, topo_seqs) + execplan.set_outputs(outputs) + + return execplan def __init__(self, graph: IRGraph, topo_seqs: List[IRCell]): @@ -138,6 +151,7 @@ def __init__(self, graph: IRGraph, topo_seqs: List[IRCell]): self._graph = graph self._topo_seqs = topo_seqs self._seq: Dict[int, List[IRCell]] = {} + self._outputs = list(graph.outputs()) for node in self._topo_seqs: assert len(node.device) > 0, f"Node device not set: {node}" @@ -173,6 +187,10 @@ def graph(self) -> IRGraph: def inference(self) -> bool: return not self._graph.train + def outputs(self) -> List[Any]: + """Get execution plan return outputs""" + return self._outputs + def devices(self) -> List[int]: """ Get device set @@ -222,6 +240,11 @@ def set(self, devid: int, seq: List[IRCell]): raise TypeError("Expected a list of Cell") self._seq[devid] = seq + def set_outputs(self, outputs: List[Any]): + if not isinstance(outputs, list): + raise TypeError("Expected a list of outputs") + self._outputs = outputs + def visualize(self, outfile: str, map2time: Optional[Callable] = None, map2mem: Optional[Callable] = None, diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py index 1c5df7fb..1679ced1 100644 --- a/cube/graph/schedule/schedplan.py +++ b/cube/graph/schedule/schedplan.py @@ -64,7 +64,9 @@ def __init__(self, graph: IRGraph) -> None: self.dataloaders : List[IRDataOperation] = [] self.segments: List[IRSegment] = [] self.adapters: List[IRAdapter] = [] + # the IRSegment that consumes the output of IRAdapter self.recvers: Dict[IRAdapter, IRSegment] = {} + # the IRSegment that produces the input of IRAdapter self.senders: Dict[IRAdapter, IRSegment] = {} self.reducers: List[IRWeightReducer] = [] @@ -391,6 +393,17 @@ def _place_adapters(self): f"This usually happens when its sender is dataloader or graph inputs." f"Please replicate dataloader to remove this adapter.") sender: IRSegment = self._dependency.senders[adapter] + + # since the schedule should return the same graph outputs on every device, + # there will be adapters created to broadcast outputs of each microbatch + # from the last-stage devices to all the devices. + # These adapters don't have any dependent recver segment, + # and will be placed at the end of the plan to not block the schedule execution. + if adapter not in self._dependency.recvers: + for mid in range(self._num_microbatches): + self._step_adapters[self.nsteps-1].append(Block(adapter, mid, 1)) + continue + # find sender step and insert adapter for step in range(self.nsteps): blocks = self.start_blocks(step) diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 98b0e939..a894c079 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1119,13 +1119,13 @@ def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[ @staticmethod def modify_objects_of_complex(val: Any, modifier: Callable) -> Any: - """Return a complex data structure where its IRObjects are in-placemently modified + """Return a complex data structure with modified IRObjects Supported complex of types: List, Tuple, Dict, IRTensor, IRObject Args: val (Any): the complex data structure to be modified - modifier (Callable): an inplacement modifier that takes an IRObject and return None + modifier (Callable): a modifier that takes an IRObject and return a new one. Return: new_val (Any): complex data structure with modified IRObjects From 5958efa533c347283144ab50a57dcb2de46234e9 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 11 Jan 2024 07:53:00 +0000 Subject: [PATCH 1570/1892] Merged PR 1989: Add module stack info to Cube nodes merge after [PR](https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube/pullrequest/1977) parity check passed ![image.png](https://dev.azure.com/msrasrg/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/1989/attachments/image.png) Related work items: #1745 --- cube/graph/function/function.py | 2 ++ cube/graph/graph.py | 2 ++ cube/graph/parser/fx/parser.py | 2 ++ cube/ir/cten.py | 14 ++++++++++++++ 4 files changed, 20 insertions(+) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 6eb2d6b8..d38630e1 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -133,10 +133,12 @@ def Matmul(input, other, *, out=None, signature=None): signature = 'torch.matmul' assert out is None annos = [ + 'k+, k+ -> 1', 'm k+, k+ n -> m n', 'k+, k+ n -> n', 'm k+, k+ -> m', '* m k+, k+ n -> * m n', + 'm k+, * k+ n -> * m n', '* m k+, * k+ n -> * m n' # TODO: broadcast ] if len(input.shape) > 2 and len(other.shape) > 2: diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 64adce04..ac8efa73 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -261,6 +261,7 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis fnode.recompute = node.recompute if isinstance(node.comment, str): fnode.comment = node.comment + fnode.module_stack = node.module_stack fnode.device = node.device fsegment.replace(node, fnodes) # insert backward @@ -321,6 +322,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], fnode.recompute = node.recompute if isinstance(node.comment, str): fnode.comment = node.comment + fnode.module_stack = node.module_stack fnode.device = node.device fsegment.replace(node, fnodes) diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 84a032ce..2d90dc67 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -212,6 +212,8 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule ir_node = IRPyFunc(fsig, input_vals, [IRObject()], **kwargs) if isinstance(ir_node, IRCell): + module_stack = node.meta.get('nn_module_stack') + ir_node.module_stack = module_stack comment = str(node.meta.get('frame_record', '')) if comment: ir_node.comment = comment diff --git a/cube/ir/cten.py b/cube/ir/cten.py index f66c6eb9..1a9442db 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -17,6 +17,7 @@ from functools import lru_cache from typing import List, Tuple, Union, Optional, Any, Dict +from collections import OrderedDict import copy import torch @@ -63,6 +64,8 @@ def __init__(self, self._mirror: Optional[IRCell] = None # the comment for code generation self._comment: Optional[str] = None + # the module stack that preserves the hierarchy information + self._module_stack: Optional[OrderedDict[str, Any]] = None @property def cid(self) -> int: @@ -243,6 +246,17 @@ def comment(self, info: str): assert isinstance(info, str), "comment only allowed to be string" self._comment = info + @property + def module_stack(self) -> Optional[OrderedDict[str, Any]]: + return self._module_stack + + @module_stack.setter + def module_stack(self, stack: OrderedDict[str, Any]): + """ + Set the module stack + """ + self._module_stack = stack + def __repr__(self) -> str: """ Cell string presentation From b0795ccad277ed9a042f5768f8493320eb256039 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 12 Jan 2024 01:48:42 +0000 Subject: [PATCH 1571/1892] Merged PR 1994: convert wrap related tuple to dataclass parity check: pass unit test: pass --- cube/graph/parser/converter.py | 9 +- .../concrete_trace_utils/concrete_tracer.py | 250 ++++++++++-------- cube/parallel.py | 2 +- 3 files changed, 145 insertions(+), 116 deletions(-) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index c16e62ec..3ae4e5b5 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -10,7 +10,7 @@ from cube.graph.parser.fx.parser import FxModuleParser from cube.graph.parser.fx.concrete_trace_utils import concrete_trace -from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply +from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import Location, is_autograd_apply, LeafFnWrapInfo from cube.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops import cube.runtime.function as cube_rt_function @@ -76,11 +76,14 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: autowrap_funcs = [CustomizedOps.kOpRuntime[sign] for sign in CustomizedOps.kOpMap] # filter out torch.autograd.Function.apply as concrete trace already treats them as leaf function autowrap_funcs = [fn for fn in autowrap_funcs if not is_autograd_apply(fn)] - leaf_functions = {func: ([], True, None) for func in autowrap_funcs if func is not None} + leaf_functions = {func: LeafFnWrapInfo([], True, None) for func in autowrap_funcs if func is not None} # get cube runtime functions cube_rt_funcs = [cube_rt_function.anchor] - leaf_functions.update({func: ([(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs}) + leaf_functions.update({ + func: LeafFnWrapInfo([Location(cube_rt_function, func.__name__)], True, None) + for func in cube_rt_funcs + }) dce_ignored_funcs = set(cube_rt_funcs) with no_save_tensor_hook(): diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index f281747b..0bc112a5 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -13,6 +13,7 @@ import builtins import traceback import importlib.util +from dataclasses import dataclass, field from itertools import chain from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType @@ -137,6 +138,41 @@ def __exit__(self, *args): HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS +@dataclass +class Location: + """ + The place a function/class locates. + Please note one function/class can be in multiple places. + Take `torch.meshgrid` for example, there are `torch.meshgrid`, 'torch.functional.meshgrid', 'torch._C._VariableFunctions.meshgrid', + """ + ns: Union[Type, ModuleType, Any] # the namespace of the name. It can be a class/module, etc. + name: str + + +@dataclass +class LeafFnWrapInfo: + """ + extra_locs: The place the function is imported. + is_force_trace: If set to false, the function will only be traced if input relates to concrete_args. + Such as 'torch.rand', we should trace it even if it doesn't relate to concrete_args. + replace_fn: If not `None`, we will use it to replace the original function in traced code. + Such as ModuleList.__getitem__, we can use operator.getitem to replace it. + """ + extra_locs: List[Location] = field(default_factory=list) + is_force_trace: bool = False + replace_fn: Optional[Callable] = None + + +@dataclass +class LeafClassWrapInfo: + """ + extra_locs: The place the class is imported. + is_iterable: Is the class init from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True. + """ + extra_locs: List[Location] = field(default_factory=list) + is_iterable: bool = False + + def is_autograd_apply(func) -> bool: return getattr(func, '__name__', None) == 'apply' \ and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) @@ -155,58 +191,58 @@ class ConcreteTracer(TracerBase): default_autowrap_modules = ( 'math', ) - default_autowrap_leaf_function: Dict[Any, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool, Optional[Callable]]] = { + default_autowrap_leaf_function: Dict[Any, LeafFnWrapInfo] = { # function - _orig_len: ([], False, None), - _orig_not: ([], False, None), - _orig_is: ([], False, None), - _orig_is_not: ([], False, None), - _orig_contains: ([], False, None), - _orig_index: ([], False, None), - _orig_all: ((), False, None), - _orig_min: ((), False, None), - _orig_max: ((), False, None), + _orig_len: LeafFnWrapInfo([], False, None), + _orig_not: LeafFnWrapInfo([], False, None), + _orig_is: LeafFnWrapInfo([], False, None), + _orig_is_not: LeafFnWrapInfo([], False, None), + _orig_contains: LeafFnWrapInfo([], False, None), + _orig_index: LeafFnWrapInfo([], False, None), + _orig_all: LeafFnWrapInfo([], False, None), + _orig_min: LeafFnWrapInfo([], False, None), + _orig_max: LeafFnWrapInfo([], False, None), # force-traced function (the factory functions of tensor creation) - torch.arange: ([], True, None), - torch.empty: ([], True, None), - torch.eye: ([], True, None), - torch.full: ([], True, None), - torch.linspace: ([], True, None), - torch.logspace: ([], True, None), - torch.ones: ([], True, None), - torch.rand: ([], True, None), - torch.randint: ([], True, None), - torch.randn: ([], True, None), - # torch.rand_like: ([], True, None), # seems that xxx_like will not directly call torch._TensorBase.xxx - # torch.randn_like: ([], True, None), - # torch.randint_like: ([], True, None), - torch.randperm: ([], True, None), - torch.tensor: ([], True, None), - torch.zeros: ([], True, None), + torch.arange: LeafFnWrapInfo([], True, None), + torch.empty: LeafFnWrapInfo([], True, None), + torch.eye: LeafFnWrapInfo([], True, None), + torch.full: LeafFnWrapInfo([], True, None), + torch.linspace: LeafFnWrapInfo([], True, None), + torch.logspace: LeafFnWrapInfo([], True, None), + torch.ones: LeafFnWrapInfo([], True, None), + torch.rand: LeafFnWrapInfo([], True, None), + torch.randint: LeafFnWrapInfo([], True, None), + torch.randn: LeafFnWrapInfo([], True, None), + # torch.rand_like: LeafFnWrapInfo([], True, None), # seems that xxx_like will not directly call torch._TensorBase.xxx + # torch.randn_like: LeafFnWrapInfo([], True, None), + # torch.randint_like: LeafFnWrapInfo([], True, None), + torch.randperm: LeafFnWrapInfo([], True, None), + torch.tensor: LeafFnWrapInfo([], True, None), + torch.zeros: LeafFnWrapInfo([], True, None), # method - Sequential.__getitem__: ([], False, operator.getitem), - Sequential.__len__: ([], False, _orig_len), - Sequential.__iter__: ([], False, iter), - - ModuleList.__getitem__: ([], False, operator.getitem), - ModuleList.__len__: ([], False, _orig_len), - ModuleList.__iter__: ([], False, iter), - - ModuleDict.__getitem__: ([], False, operator.getitem), - ModuleDict.__len__: ([], False, _orig_len), - ModuleDict.__iter__: ([], False, iter), - ModuleDict.__contains__: ([], False, _orig_contains), - - ParameterList.__getitem__: ([], False, operator.getitem), - ParameterList.__len__: ([], False, _orig_len), - ParameterList.__iter__: ([], False, iter), - - ParameterDict.__getitem__: ([], False, operator.getitem), - ParameterDict.__len__: ([], False, _orig_len), - ParameterDict.__iter__: ([], False, iter), - ParameterDict.__contains__: ([], False, _orig_contains), + Sequential.__getitem__: LeafFnWrapInfo([], False, operator.getitem), + Sequential.__len__: LeafFnWrapInfo([], False, _orig_len), + Sequential.__iter__: LeafFnWrapInfo([], False, iter), + + ModuleList.__getitem__: LeafFnWrapInfo([], False, operator.getitem), + ModuleList.__len__: LeafFnWrapInfo([], False, _orig_len), + ModuleList.__iter__: LeafFnWrapInfo([], False, iter), + + ModuleDict.__getitem__: LeafFnWrapInfo([], False, operator.getitem), + ModuleDict.__len__: LeafFnWrapInfo([], False, _orig_len), + ModuleDict.__iter__: LeafFnWrapInfo([], False, iter), + ModuleDict.__contains__: LeafFnWrapInfo([], False, _orig_contains), + + ParameterList.__getitem__: LeafFnWrapInfo([], False, operator.getitem), + ParameterList.__len__: LeafFnWrapInfo([], False, _orig_len), + ParameterList.__iter__: LeafFnWrapInfo([], False, iter), + + ParameterDict.__getitem__: LeafFnWrapInfo([], False, operator.getitem), + ParameterDict.__len__: LeafFnWrapInfo([], False, _orig_len), + ParameterDict.__iter__: LeafFnWrapInfo([], False, iter), + ParameterDict.__contains__: LeafFnWrapInfo([], False, _orig_contains), } # equals to `from torch.nn import functional as nn_functional` # to pass pyright check @@ -215,59 +251,59 @@ class ConcreteTracer(TracerBase): for name in torch.functional.__all__: attr = getattr(torch.functional, name) if attr not in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr] = ([], False, attr) + default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, attr) for name in dir(nn_functional): attr = getattr(nn_functional, name) if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__')\ and getattr(attr, '__module__', None) not in ('typing', 'torch.nn.modules.utils'): if attr not in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr] = ([], False, getattr(torch.functional, name, None)) + default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) if hasattr(attr, '__module__') and attr.__module__ != 'torch.nn.functional': - default_autowrap_leaf_function[attr][0].append((nn_functional, name)) + default_autowrap_leaf_function[attr].extra_locs.append(Location(nn_functional, name)) for name in dir(torch._C._VariableFunctions): attr = getattr(torch._C._VariableFunctions, name) if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): if attr not in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr] = ([], False, getattr(torch.functional, name, None)) + default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) for name in dir(torch._C._nn): attr = getattr(torch._C._nn, name) if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): if attr not in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr] = ([], False, getattr(torch.functional, name, None)) + default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) if hasattr(attr, '__module__') and attr.__module__ != 'torch._C._nn': - default_autowrap_leaf_function[attr][0].append((torch._C._nn, name)) + default_autowrap_leaf_function[attr].extra_locs.append(Location(torch._C._nn, name)) for name in dir(torch._C._TensorBase): attr = getattr(torch._C._TensorBase, name) if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): if attr not in default_autowrap_leaf_function: to_func = getattr(torch.Tensor, name, None) to_func = None if to_func == attr else to_func - default_autowrap_leaf_function[attr] = ([], False, to_func) + default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, to_func) # find the multi position for default_autowrap_leaf_function in torch.__dir__() for name in dir(torch): attr = getattr(torch, name) if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__') \ and attr in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr][0].append((torch, name)) + default_autowrap_leaf_function[attr].extra_locs.append(Location(torch, name)) - default_autowrap_leaf_class: Dict[Type, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool]] = { + default_autowrap_leaf_class: Dict[Type, LeafClassWrapInfo] = { # class - _orig_bool: ([], False), + _orig_bool: LeafClassWrapInfo([], False), # we don't want zip appear as a node in the graph - # _orig_zip: ([], False), - _orig_int: ([], False), - _orig_float: ([], False), + # _orig_zip: LeafClassWrapInfo([], False), + _orig_int: LeafClassWrapInfo([], False), + _orig_float: LeafClassWrapInfo([], False), # iterable class - _orig_tuple: ([], True), - _orig_list: ([], True), - _orig_set: ([], True), - _orig_frozenset: ([], True), - _orig_dict: ([], True), - _orig_reversed: ((), False), - - _orig_torch_size: ((), False), - _orig_torch_finfo: ((), False), + _orig_tuple: LeafClassWrapInfo([], True), + _orig_list: LeafClassWrapInfo([], True), + _orig_set: LeafClassWrapInfo([], True), + _orig_frozenset: LeafClassWrapInfo([], True), + _orig_dict: LeafClassWrapInfo([], True), + _orig_reversed: LeafClassWrapInfo([], False), + + _orig_torch_size: LeafClassWrapInfo([], False), + _orig_torch_finfo: LeafClassWrapInfo([], False), } @compatibility(is_backward_compatible=True) @@ -684,8 +720,8 @@ def proxy_placeholder(name: str): @compatibility(is_backward_compatible=True) def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, autowrap_modules: Tuple[str] | None = None, - autowrap_leaf_function = None, - autowrap_leaf_class = None, + autowrap_leaf_function: Optional[Dict[Any, LeafFnWrapInfo]] = None, + autowrap_leaf_class: Optional[Dict[Type, LeafClassWrapInfo]] = None, leaf_module = None, fake_middle_class = None, concrete_args: Union[Dict[str, Any], Tuple], @@ -985,29 +1021,30 @@ def torch_assert_wrapper(condition, message): self.autowrap_leaf_pairs = { id(_orig_torch_assert): torch_assert_wrapper, } - self.wrapped_leaf = dict() + self.wrapped_leaf: Dict[Any, Tuple[Tuple[Location,...], Any]] = dict() - for func, (positions, is_force_trace, to_func) in self.autowrap_leaf_function.items(): + for func, wrap_info in self.autowrap_leaf_function.items(): + locations = tuple(wrap_info.extra_locs) if is_autograd_apply(func): # torch.autograd.function - assert to_func == None, '.apply should set to_func to None!' + assert wrap_info.replace_fn == None, '.apply should set to_func to None!' if func.__self__ not in self.agfunc_dict: self.agfunc_dict[func.__self__] = _create_wrapped_leaf_func(self, func, func) wrapped = self.agfunc_dict[func.__self__] else: if func.__qualname__.startswith('_TensorBase'): - positions = (*positions, (torch.Tensor, func.__name__)) - wrapped = _create_wrapped_leaf_method(self, getattr(torch.Tensor, func.__name__), func.__name__, to_func) + locations = (*locations, Location(torch.Tensor, func.__name__)) + wrapped = _create_wrapped_leaf_method(self, getattr(torch.Tensor, func.__name__), func.__name__, wrap_info.replace_fn) elif func.__qualname__.startswith('_VariableFunctionsClass'): if hasattr(torch, func.__name__) and getattr(torch, func.__name__) == func: # avoid bad attr like 'unique_dim' - positions = (*positions, (torch, func.__name__)) - if is_force_trace: - wrapped = _create_wrapped_leaf_func(self, func, to_func, (self,)) + locations = (*locations, Location(torch, func.__name__)) + if wrap_info.is_force_trace: + wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn, (self,)) else: - wrapped = _create_wrapped_leaf_func(self, func, to_func) + wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn) elif _orig_isinstance(func, (MethodDescriptorType, MethodWrapperType)): - wrapped = _create_wrapped_leaf_method(self, func, func.__name__, to_func) + wrapped = _create_wrapped_leaf_method(self, func, func.__name__, wrap_info.replace_fn) elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn' \ and not func.__qualname__.startswith('PyCapsule'): # method @@ -1016,20 +1053,20 @@ def torch_assert_wrapper(condition, message): else: path = sys.modules[func.__module__] path = getattr(path, func.__qualname__.split('.')[0]) - positions = (*positions, (path, func.__name__)) - wrapped = _create_wrapped_leaf_method(self, func, func.__name__, to_func) + locations = (*locations, Location(path, func.__name__)) + wrapped = _create_wrapped_leaf_method(self, func, func.__name__, wrap_info.replace_fn) else: # common function if func.__module__.startswith('_') and func.__module__ != '__main__': path = sys.modules[func.__module__[1:]] else: path = sys.modules[func.__module__] - positions = (*positions, (path, func.__name__)) - if is_force_trace: - wrapped = _create_wrapped_leaf_func(self, func, to_func, (self,)) + locations = (*locations, Location(path, func.__name__)) + if wrap_info.is_force_trace: + wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn, (self,)) else: - wrapped = _create_wrapped_leaf_func(self, func, to_func) - self.wrapped_leaf[func] = (positions, wrapped) + wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn) + self.wrapped_leaf[func] = (locations, wrapped) self.clz_wrapper_map: Dict[Any, Type] = { map_wrapper: _orig_map, @@ -1037,27 +1074,27 @@ def torch_assert_wrapper(condition, message): range_wrapper: _orig_range, type_wrapper: _orig_type, } - for clz, (positions, is_iterable) in self.autowrap_leaf_class.items(): + for clz, wrap_info in self.autowrap_leaf_class.items(): if clz.__module__.startswith('_') and clz.__module__ != '__main__': path = sys.modules[clz.__module__[1:]] else: path = sys.modules[clz.__module__] - if is_iterable: + if wrap_info.is_iterable: wrapped = _create_wrapped_leaf_iterable_class(self, clz) else: wrapped = _create_wrapped_leaf_class(self, clz) - positions = (*positions, (path, clz.__name__)) - self.wrapped_leaf[clz] = (positions, wrapped) + locations = (*wrap_info.extra_locs, Location(path, clz.__name__)) + self.wrapped_leaf[clz] = (locations, wrapped) self.clz_wrapper_map[wrapped] = clz for clz in self.fake_middle_class: wrapped = _create_wrapped_attr_for_middle_class(self, clz, self.the_path_of_middle_class) - self.wrapped_leaf[clz.__getattribute__] = (((clz, '__getattribute__'),), wrapped) + self.wrapped_leaf[clz.__getattribute__] = ((Location(clz, '__getattribute__'),), wrapped) # wrap all forward in the submodule to trace the module stack for mod in self.root.modules(): wrapped = _create_wrapped_nn_module_func(self, mod, forward_function_name) - self.wrapped_leaf[mod.forward] = (((mod, forward_function_name),), wrapped) + self.wrapped_leaf[mod.forward] = ((Location(mod, forward_function_name),), wrapped) @functools.wraps(_orig_isinstance) def isinstance_wrapper(instance, clz): @@ -1133,8 +1170,8 @@ def getattr_wrapper(obj, *args): self.patcher.patch_method(builtins, "getattr", getattr_wrapper, deduplicate=False) for obj, (positions, wrapped) in self.wrapped_leaf.items(): - for path, name in positions: - self.patcher.patch_method(path, name, wrapped, deduplicate=False) + for loc in positions: + self.patcher.patch_method(loc.ns, loc.name, wrapped, deduplicate=False) self.autowrap_leaf_pairs[id(obj)] = wrapped _patch_wrapped_functions(self.patcher) @@ -1627,8 +1664,8 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], operator_patch_backlist: List[str] | None = None, forward_function_name: str = 'forward', check_args: Optional[Dict[str, Any]] = None, - autowrap_leaf_function = None, - autowrap_leaf_class = None, + autowrap_leaf_function: Optional[Dict[Any, LeafFnWrapInfo]] = None, + autowrap_leaf_class: Optional[Dict[Type, LeafClassWrapInfo]] = None, leaf_module: Tuple | None = None, fake_middle_class: Tuple | None = None, dce: bool = True, @@ -1742,22 +1779,11 @@ def f(x, y): operator_patch_backlist (List[str]): Blacklist of the operator patcher. - autowrap_leaf_function (Dict[Any, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool, Optional[Callable]]]): Leaf function dict, + autowrap_leaf_function (Dict[Any, LeafFnWrapInfo]): Leaf function dict, such as 'add' or 'torch.xxx'. You can add your own leaf functions. - The struct of dict is: leaf_function: ([(module_path, module_name)], force_to_trace, replace_to_function). - (module_path, module_name): The place the function exists. Such as torch.meshgrid, there are `torch.meshgrid`, - 'torch.functional.meshgrid', 'torch._C._VariableFunctions.meshgrid', we should wrap them all. - force_to_trace: If set to false, the function will only be traced if input relates to concrete_args. - Such as 'torch.rand', we should trace it even if it doesn't relate to concrete_args. - replace_to_function: If not `None`, we will use it to replace the original function in traced code. - Such as ModuleList.__getitem__, we can use operator.getitem to replace it. - - default_autowrap_leaf_class (Dict[Type, Tuple[List[Tuple[Union[ModuleType, Type], str]], bool]]): Leaf class dict, such as 'int', - 'range' or 'zip'. You can add your own leaf functions such as 'torch.finfo' or 'modeling_outputs.SequenceClassifierOutput'. - - The struct of dict is: leaf_class: ([(module_path, module_name)], is_iterator_class). - is_iterator_class: Is the class init from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True. + autowrap_leaf_class: (Dict[Type, LeafClassWrapInfo]): Leaf class dict, such as 'int', + 'range' or 'zip'. You can add your own leaf functions such as 'modeling_outputs.SequenceClassifierOutput'. dce (bool): If set to True, dead code eliminatation will be applied on the graph. diff --git a/cube/parallel.py b/cube/parallel.py index d86e4cc7..408b6b58 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -266,7 +266,7 @@ def _prepare_and_check_reusable( # MOO* | unmatch | generate (specail case is when there's no python source code, see below) # MOO | imported | raise error if unmatch # *: The precondition for `except` part is the compute config should match. - # you can take it as a continous operation after a failed MATCH/OVERRIDE. + # you can take it as a continous operation after a failed generation. reusable = False config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE old_config = torch.load(config_file) if config_file.exists() else None From 092f9b59a236c7021b9f20ba654a0d594cd6ce7a Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 12 Jan 2024 03:27:29 +0000 Subject: [PATCH 1572/1892] Merged PR 1983: support trace namedtuple & type parity check passed ![image.png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/1983/attachments/image.png) --- .../concrete_trace_utils/concrete_tracer.py | 81 ++++++++++++------- tests/graph/tracer/test_namedtuple.py | 26 ++++++ 2 files changed, 77 insertions(+), 30 deletions(-) create mode 100644 tests/graph/tracer/test_namedtuple.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 0bc112a5..a6068a77 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -891,7 +891,7 @@ def module_call_wrapper(mod, *args, **kwargs): else: # codes below corresponds to symbolic tracer's call_module module_qualified_name = self.path_of_module(mod) - with ScopeContextManager(self.scope, Scope(module_qualified_name, type(mod))) as _scope: + with ScopeContextManager(self.scope, Scope(module_qualified_name, _orig_type(mod))) as _scope: self.module_stack[_scope.module_path] = _scope.module_type if not self.is_leaf_module(mod, module_qualified_name): _autowrap_check(self, @@ -912,8 +912,10 @@ def module_call_wrapper(mod, *args, **kwargs): return ret_val class map_wrapper_clz: - @functools.wraps(_orig_map) - def __call__(self, the_func, *iterables: Any): + # used to track the original class + _fx_wrapped_ori_clz = _orig_map + + def __new__(cls, the_func, *iterables: Any): tracers = _orig_set() for one_iter in iterables: if _orig_isinstance(one_iter, ep.Proxy): @@ -946,11 +948,13 @@ def __eq__(self, __o: object) -> bool: return id(__o) in (id(self), id(_orig_map)) def __hash__(self): return id(self) - map_wrapper = map_wrapper_clz() + map_wrapper = map_wrapper_clz class range_wrapper_clz: - @functools.wraps(_orig_range) - def __call__(self, *args): + # used to track the original class + _fx_wrapped_ori_clz = _orig_range + + def __new__(cls, *args): # TODO: better infomation assert 1 <= _orig_len(args) <= 3 args = (arg.value if _orig_isinstance(arg, ep.ConcreteProxy) else arg for arg in args) @@ -959,11 +963,13 @@ def __eq__(self, __o: object) -> bool: return id(__o) in (id(self), id(_orig_range)) def __hash__(self): return id(self) - range_wrapper = range_wrapper_clz() + range_wrapper = range_wrapper_clz class enumerate_wrapper_clz: - @functools.wraps(_orig_enumerate) - def __call__(self, iterable, start=0): + # used to track the original class + _fx_wrapped_ori_clz = _orig_enumerate + + def __new__(cls, iterable, start=0): count = start for elem in iterable: if _orig_isinstance(elem, ep.ConcreteProxy) and _orig_isinstance(elem.value, (_orig_int, str)): @@ -975,21 +981,32 @@ def __eq__(self, __o: object) -> bool: return id(__o) in (id(self), id(_orig_enumerate)) def __hash__(self): return id(self) - enumerate_wrapper = enumerate_wrapper_clz() + enumerate_wrapper = enumerate_wrapper_clz class type_wrapper_clz: - @functools.wraps(_orig_type) - def __call__(self, instance): - orig_type = _orig_type(instance) - if orig_type in (ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): - return _orig_type(instance.value) + # used to track the original class + _fx_wrapped_ori_clz = _orig_type + + def __new__(cls, obj_or_name, *args): + # case 1: class type(name, bases, dict, **kwds) + if _orig_len(args) > 0: + assert _orig_len(args) == 2 + base_cls, cls_dict = args[0], args[1] + # if it is a wrapped class, replace it to the original one + base_cls = _orig_tuple(bs._fx_wrapped_ori_clz if hasattr(bs, '_fx_wrapped_ori_clz') else bs for bs in base_cls) + return _orig_type(obj_or_name, base_cls, cls_dict) + # case 2: class type(object) else: - return orig_type + orig_type = _orig_type(obj_or_name) + if orig_type in (ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): + return _orig_type(obj_or_name.value) + else: + return orig_type def __eq__(self, __o: object) -> bool: return id(__o) in (id(self), id(_orig_enumerate)) def __hash__(self): return id(self) - type_wrapper = type_wrapper_clz() + type_wrapper = type_wrapper_clz @classmethod @functools.wraps(_orig_agfunc_apply) @@ -1300,7 +1317,7 @@ def _create_wrapped_nn_module_func(tracer: ConcreteTracer, mod: torch.nn.Module, @functools.wraps(orig_fn) def wrapped(*args, **kwargs): module_qualified_name = tracer.path_of_module(mod) - with ScopeContextManager(tracer.scope, Scope(module_qualified_name, type(mod))) as _scope: + with ScopeContextManager(tracer.scope, Scope(module_qualified_name, _orig_type(mod))) as _scope: need_pop = False if _scope.module_path not in tracer.module_stack: need_pop = True @@ -1490,8 +1507,10 @@ def _create_wrapped_leaf_class(tracer: ConcreteTracer, clz): ... """ class clz_wrapper_clz: - @functools.wraps(clz) - def __call__(self, *args, **kwargs): + # used to track the original class + _fx_wrapped_ori_clz = clz + + def __new__(cls, *args, **kwargs): if tracer.temp_disable_call: return clz(*args, **kwargs) tracers = _orig_set() @@ -1510,15 +1529,15 @@ def __eq__(self, __o: object) -> bool: return id(__o) in (id(self), id(clz)) def __hash__(self): return id(self) - clz_wrapper = clz_wrapper_clz() + for name in dir(clz): attr = _orig_getattr(clz, name) if not name.startswith('_'): if _orig_isinstance(attr, Callable): - setattr(clz_wrapper, name, _create_wrapped_leaf_method(tracer, attr, name, None)) + setattr(clz_wrapper_clz, name, _create_wrapped_leaf_method(tracer, attr, name, None)) else: - setattr(clz_wrapper, name, attr) - return clz_wrapper + setattr(clz_wrapper_clz, name, attr) + return clz_wrapper_clz def _create_wrapped_leaf_iterable_class(tracer: ConcreteTracer, clz): """ @@ -1531,8 +1550,10 @@ def _create_wrapped_leaf_iterable_class(tracer: ConcreteTracer, clz): ... """ class clz_wrapper_clz: - @functools.wraps(clz) - def __call__(self, *args, **kwargs): + # used to track the original class + _fx_wrapped_ori_clz = clz + + def __new__(cls, *args, **kwargs): if tracer.temp_disable_call: return clz(*args, **kwargs) tracers = _orig_set() @@ -1556,15 +1577,15 @@ def __eq__(self, __o: object) -> bool: return id(__o) in (id(self), id(clz)) def __hash__(self): return id(self) - clz_wrapper = clz_wrapper_clz() + for name in dir(clz): attr = _orig_getattr(clz, name) if not name.startswith('_') or name in ('__getitem__', '__setitem__', '__iter__', '__len__'): if _orig_isinstance(attr, Callable): - setattr(clz_wrapper, name, _create_wrapped_leaf_method(tracer, attr, name, None)) + setattr(clz_wrapper_clz, name, _create_wrapped_leaf_method(tracer, attr, name, None)) else: - setattr(clz_wrapper, name, attr) - return clz_wrapper + setattr(clz_wrapper_clz, name, attr) + return clz_wrapper_clz def _create_wrapped_attr_for_middle_class(tracer: ConcreteTracer, clz, the_path_of_middle_class): _orig_clz_getattribute = clz.__getattribute__ diff --git a/tests/graph/tracer/test_namedtuple.py b/tests/graph/tracer/test_namedtuple.py new file mode 100644 index 00000000..4097fdaf --- /dev/null +++ b/tests/graph/tracer/test_namedtuple.py @@ -0,0 +1,26 @@ +from collections import namedtuple + +import torch +from cube.graph.parser.converter import to_fx_graph + +from ...utils import replace_all_device_with + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(10, 5) + + def forward(self, x): + Result = namedtuple('Result', ['r1', 'r2']) + return Result(self.fc1(x), self.fc2(x)) + +@replace_all_device_with('cpu') +def test_namedtuple(): + model = SimpleModel() + dummy_input = {'x': torch.rand(10)} + traced_graph = to_fx_graph(model, dummy_input) + + # just check if we can trace a model contains namedtuple + assert True From c65fdf8086f0e507d065c7dd4d570379de1e55b5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 12 Jan 2024 09:05:12 +0000 Subject: [PATCH 1573/1892] Merged PR 1993: Pipeline Support-4: Refine module loading - remove mis-leading warning of missing saved model chunk tids - re-organize loading using attr_name key instead of tid key to avoid potential bugs that multiple chunks of a parameter can be allocated on a same device - refine variables naming and logging info to make code more clear parity test and UT passed --- cube/runtime/module.py | 62 ++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 062bf4a9..7238de85 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -24,11 +24,12 @@ class CubeModule(torch.nn.Module): def __init__(self): super().__init__() self._reducers: List[Reducer] = list() - # Key: str, parameter name (from named_parameters) - # Value: Tuple[int, Tuple[slice], int]: - # full tensor tid, - # position of sub tensor in full tensor, - # position of value in value partition. + # self._fullmap contains the mapping of local attribute tensors to its fulltensor + # name (from named_parameters or named_buffers) -> ( + # fulltensor.tid, + # index position of its fulltensor, + # value partition num_chunks, + # ) self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() @property @@ -151,28 +152,41 @@ def get_full_map(self): return self._fullmap def load_attr_content(self, filename: str): - partitioned_model_pt = 0 - while os.path.isfile(filename + f'.{partitioned_model_pt}'): - partitioned_model_pt += 1 - if partitioned_model_pt == 0: + """Load module attribute (parameters and buffers) from file + + Args: + filename (str): base file name (without '.0', '.1', etc.) + that saved with model parameters + """ + npartitions = 0 + while os.path.isfile(filename + f'.{npartitions}'): + npartitions += 1 + if npartitions == 0: raise RuntimeError(f"Cannot find file {filename}.0 in load_attr_content") with torch.no_grad(): - _logger.info(f'load partitioned model from {filename}, partitioned_model_pt={partitioned_model_pt}') - fullmap2 = {tid: (attr, slicer, nchunks) for attr, (tid, slicer, nchunks) in self._fullmap.items()} - for file_idx in range(partitioned_model_pt): - full = torch.load(filename + f'.{file_idx}') - for tid in full.keys(): - if tid not in fullmap2: - _logger.warning(f'cannot find tid {tid} in fullmap2') + _logger.info(f'loading partitioned model from {filename}, number of model parameter chunks: {npartitions}') + # self._fullmap + attr_names = set(self._fullmap.keys()) + for file_idx in range(npartitions): + # part_model contains a subset of attributes, where each attribute is a fulltensor + # fulltensor.tid -> torch.Tensor + part_model: Dict[int, torch.Tensor] = torch.load(filename + f'.{file_idx}') + loaded_name = set() + for attr_name in attr_names: + tid, slicers, val_nchunks = self._fullmap[attr_name] + if tid not in part_model: continue - fm = fullmap2[tid] - tensor: torch.Tensor = getattr(self, fm[0]) - content = full[tid][fm[1]] / fm[2] - tensor.copy_(content) - fullmap2.pop(tid, None) - - if len(fullmap2) != 0: - raise RuntimeError(f'cannot find tid {list(fullmap2.keys())} in partitioned model files') + attr = getattr(self, attr_name) + content = part_model[tid][slicers] + if val_nchunks != 1: + content = content / val_nchunks + attr.copy_(content) + loaded_name.add(attr_name) + for name in loaded_name: + attr_names.remove(name) + if len(attr_names) != 0: + raise RuntimeError( + f'remaining graph parameters / buffers cannot find in model files: {list(attr_names)}') def init_group(self, ranks: List[int]): if not all([isinstance(rank, int) for rank in ranks]): From 45041941f1fae8165a385e66d05765aaef4c8757 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 15 Jan 2024 09:20:19 +0000 Subject: [PATCH 1574/1892] Merged PR 1998: support dict values/items --- cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index a6068a77..0fe06796 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -36,6 +36,8 @@ from torch.fx.operator_schemas import check_for_mutable_operation dict_keys_type = type(dict().keys()) +dict_values_type = type(dict().values()) +dict_items_type = type(dict().items()) try: # Scope is a new class to record module path in pytorch 2.0 @@ -597,7 +599,7 @@ def create_arg(self, a: Any) -> Union[Node, Any]: if isinstance(a, (torch.autograd.function.Function, torch.autograd.function.FunctionMeta)): return a - if isinstance(a, dict_keys_type): + if isinstance(a, (dict_keys_type, dict_values_type, dict_items_type)): return a return super().create_arg(a) From caecf7231ca73bdd8402db17d3f43fc54aa43157 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 16 Jan 2024 01:53:21 +0000 Subject: [PATCH 1575/1892] Merged PR 1995: add option to set attr savedir add option to set attr savedir parity check pass unit test pass --- cube/program.py | 9 +++++++-- cube/runtime/module.py | 5 +++-- cube/utils.py | 32 +++++++++++++++++++------------- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/cube/program.py b/cube/program.py index b9f97ba4..81f6a1df 100644 --- a/cube/program.py +++ b/cube/program.py @@ -126,7 +126,9 @@ class SemanticModel: def __init__(self, model: Optional[torch.nn.Module], save_content: bool = True, - dynamic_shape: bool = False): + dynamic_shape: bool = False, + attr_savedir: str = './', + ): """ Create semantic model based on AI Scientist description. @@ -137,6 +139,8 @@ def __init__(self, model: Optional[torch.nn.Module], whether to save the content of model and load it into generated model. Default True. dynamic_shape (bool): whether to use dynamic shape. Default False. + attr_savedir (str): + directory to save content (attribtes) """ if DeviceGroup().local_rank == 0 and model is not None: assert isinstance(model, torch.nn.Module), f"device of local_rank == 0 must provide model" @@ -147,6 +151,7 @@ def __init__(self, model: Optional[torch.nn.Module], # parser configuration self.save_content: bool = save_content self.dynamic_shape: bool = dynamic_shape + self.attr_savedir: str = attr_savedir @property def dummy_input(self) -> Any: @@ -223,7 +228,7 @@ def __call__(self, *args): self._ir_graph = parser.convert_model( self.model, dummy_input=self.dummy_input, - attr_savedir='./', + attr_savedir=self.attr_savedir, dynamic_shape=self.dynamic_shape ) return self._ir_graph(*args) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 7238de85..499a387e 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -200,8 +200,9 @@ def get_checkpoint(self, optimizer: torch.optim.Optimizer = None): # so we will try to load it from file on the fly. dist_param_map = getattr(self, '_dist_param_map', None) if not dist_param_map: - assert os.path.isfile('dist_param_map.pt'), 'Cannot open distributed parameter mapping file: dist_param_map.pt' - dist_param_map = torch.load('dist_param_map.pt') + module_file = Path(sys.modules[self.__module__].__file__) + # load from the same directory as the module file + dist_param_map = torch.load(module_file.with_name(FxModuleParser.ATTR_MAP_FILE)) param_area_map = self._fullmap optimizer_state_dict = optimizer.state_dict() if optimizer is not None else None return state_dict, dist_param_map, param_area_map, optimizer_state_dict diff --git a/cube/utils.py b/cube/utils.py index bfe3eeb2..9985dc02 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -1,6 +1,8 @@ import os from typing import Optional, Tuple, Callable import logging +from pathlib import Path +import sys import cube from cube.runtime.device import DeviceGroup @@ -13,7 +15,7 @@ def print_each_rank(msg: str, rank_only: Optional[int] = None, logger: Optional[logging.Logger] = None): """Logging the message. - + Args: msg (str): message to be logged. rank_only (int, optional): @@ -41,21 +43,25 @@ def print_each_rank(msg: str, rank_only: Optional[int] = None, logger: Optional[ def _load_module_attr(filename: str, name: str): + # TODO: use `importlib.import_module` instead import importlib.util spec = importlib.util.spec_from_file_location(name, filename) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) + sys.modules[name] = module # so you can find the loaded module in sys.modules return module -def load_model(filename: Optional[str] = None, load_content: bool = True): +def load_model(filename: Optional[str] = None, load_content: bool = True, fullmodel_filename: Optional[str] = None): filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename - module = _load_module_attr(filename, 'GenModel') + module = _load_module_attr(filename, Path(filename).stem) loaded_module: cube.runtime.module.CubeModule = module.GenModel().cuda() # load parameter content if load_content: _logger.info("loading parameter content...") - loaded_module.load_attr_content('./fullmodel.pt') + if not fullmodel_filename: + fullmodel_filename = str(Path(filename).with_name('fullmodel.pt')) + loaded_module.load_attr_content(fullmodel_filename) # initialize reducer for reducer in loaded_module.reducers: reducer.build_buckets() @@ -64,13 +70,13 @@ def load_model(filename: Optional[str] = None, load_content: bool = True): def load_default_schedule(filename: Optional[str] = None): filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename - module = _load_module_attr(filename, '_train_step') + module = _load_module_attr(filename, Path(filename).stem) return module._train_step def load_eval_schedule(filename: Optional[str] = None): filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename - module = _load_module_attr(filename, '_infer_step') + module = _load_module_attr(filename, Path(filename).stem) return module._infer_step @@ -91,7 +97,7 @@ class accum_mode: optimizer.zero_grad() ``` - Or, + Or, ``` for _ in range(num_iters): @@ -111,17 +117,17 @@ def __init__(self, begin: bool = True, end: bool = True): of the parameters in the reducer. end (bool): Whether the iteration is the last accumulation step. If True, the `model.reduce_grad()` will be enabled to reduce gradients at - the end of the iteration. + the end of the iteration. """ self.begin: bool = begin self.end: bool = end self.old: Tuple[bool, bool] = None - + def __enter__(self): """Enter the accumulation mode. Example usage: - + ``` for _ in range(num_iters): for step in range(accum_steps): @@ -131,7 +137,7 @@ def __enter__(self): optimizer.step() optimizer.zero_grad() ``` - + """ self.old = (RuntimeFlag.skip_zero_grad, RuntimeFlag.skip_reducer) RuntimeFlag.skip_zero_grad = (not self.begin) @@ -158,10 +164,10 @@ def steps(nsteps: int): optimizer.step() optimizer.zero_grad() ``` - + Args: nsteps (int): The number of accumulation steps. - + Yield: int: The current step index. """ From 264df5745989332a938f73ba59b8d6df4baf04a5 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Wed, 17 Jan 2024 06:09:53 +0000 Subject: [PATCH 1576/1892] Merged PR 2000: Pipeline Support-5: Formal backward hook This PR formally supports backward hook with new interface. The original hook implementation is hacky and will fail at pipeline phase since not every stage has loss in backward. --- cube/runtime/executor.py | 44 ++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/cube/runtime/executor.py b/cube/runtime/executor.py index 1823f03e..7e59b2a8 100644 --- a/cube/runtime/executor.py +++ b/cube/runtime/executor.py @@ -76,7 +76,7 @@ class Executor: # Each graph has its name, and multiple call for the graph will append # (instant id -> detached) input tensor pairs for backward reference. _detach: Dict[str, List[TensorPairs]] = dict() - _fn: Callable = None + _backward_pre_hook: Optional[Callable] = None @staticmethod def fexecute(name: str, subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True): @@ -143,12 +143,6 @@ def backward(name: str, gradient tensors corresponding to input_tensors. """ output_tensor_grads = Executor.sync_tensors(output_tensor_grads) - if Executor._fn is not None and output_tensor_grads[0] is None: - assert len(output_tensor_grads) == 1 - assert len(output_tensors) == 1 - output_tensors = (Executor._fn(output_tensors[0]), ) - - if len(output_tensors) == 0: return None saved_pairs = Executor._detach[name].pop(0) tensor_ids: List[int] = [pair[0] for pair in saved_pairs] @@ -162,7 +156,8 @@ def backward(name: str, f"Remain {len(Executor._detach[name])} segments.\n" f"{''.join(traceback.format_stack())}" ) - + + if len(output_tensors) == 0: return None input_tensors = [] for t in dtensors: @@ -181,6 +176,15 @@ def backward(name: str, dedup_output_tensors.append(t) dedup_output_tensor_grads.append(g) + # apply hook before backward + if Executor._backward_pre_hook is not None: + input_tensors, dedup_output_tensors, dedup_output_tensor_grads = \ + Executor._backward_pre_hook( + input_tensors, + dedup_output_tensors, + dedup_output_tensor_grads + ) + torch.autograd.backward( dedup_output_tensors, grad_tensors=dedup_output_tensor_grads, @@ -199,9 +203,33 @@ def sync_tensors(tensors: List[Any]) -> List[Any]: """ return [AsyncCommHandler().wait(t) if torch.is_tensor(t) else t for t in tensors] + + @staticmethod + def register_backward_pre_hook(hook: Optional[Callable]): + """Register a backward hook for the right before the backward executor. + + The backward hook will be called with the following arguments: + hook(input_tensors, output_tensors, output_tensor_grads) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]] + + The backward hook mainly serves for the scenarios like loss scaling. + + Notes: + Users can only register one backward pre_hook. If there was a hook + registered before, it will be overwritten. + + Args: + hook (Callable or None): the backward hook to be registered. The hook takes + input_tensors (List[torch.Tensor]), + output_tensors (List[torch.Tensor]), + output_tensor_grads (List[torch.Tensor]) as inputs and returns the + same format of updated tensors. + """ + Executor._backward_pre_hook = hook + @staticmethod def clear(): Executor._detach = dict() + Executor._backward_pre_hook = None @staticmethod def check_clear(): From b14fe0e3a972085aa3992f649b6bbf9239a80f5a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 18 Jan 2024 02:48:28 +0000 Subject: [PATCH 1577/1892] Merged PR 2001: Pipeline Support-6: fix staging and dataloader bugs --- cube/graph/graph.py | 24 +++++++++--------------- cube/graph/schedule/schedplan.py | 18 +++++++++++++----- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index ac8efa73..710c0984 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -856,13 +856,12 @@ def staging(self, nodes: Tuple[IRFwOperation]): # adjust the start of the first stage to involve beginning operators for idx in range(starts[0]): node = self.node(idx) + # IRDataOperation cannot be involved in the IRSegment in current + # implementation. if isinstance(node, IRDataOperation): continue assert isinstance(node, IRFwOperation), \ f"Expected nodes previous from the first stage are all IRFwOperation, but got {type(node)}" - if node.name == 'multiref' or isinstance(node, IRPyFunc): - pass - else: - _logger.info(f'involve node {node.name}({node.cid}) into the first stage') + _logger.info(f'involve node {node.name}({node.cid}) into the first stage') starts[0] = idx break @@ -895,11 +894,9 @@ def get_sid(fnode: IRCell) -> Optional[int]: def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: fwop = Identity(tensor) - fwop.infer_shape() - fwop.output(0).parent.dtype = tensor.dtype - fwop.set_output(0, fwop.output(0).tosub()) + output = tensor.parent.like().tosub() + fwop.set_output(0, output) if tensor.requires_grad: - fwop.output(0).parent.requires_grad = True # set input grad igrad = tensor.parent.grad.select(tensor.indmap, tensor.valmap) fwop.input(0).grad = igrad @@ -958,24 +955,21 @@ def insert_identity(tensor: IRSubTensor, sid: int) -> IRFwOperation: for cidx, consumer in enumerate(buckets[sid]): if fgrad is None: grad = None - elif isinstance(fgrad, float): - assert fgrad == 1.0, "Detect a backward tensor, but gradient can only be 1.0" - grad = fgrad else: valmap = curr_valmap.map((0, 2)) if cidx != nconsumers - 1 else curr_valmap grad = fgrad.select(ptensor.indmap, valmap) curr_valmap = curr_valmap.map((1, 2)) if cidx != nconsumers - 1 else curr_valmap # update forward consumer idx = consumer.inputs().index(ptensor) - ptensor = consumer.input(idx) + tensor = consumer.input(idx) with self.update(consumer) as consumer: consumer.set_input(idx, out) consumer.input(idx).grad = grad # update backward - if isinstance(consumer.mirror, IRCell): + if tensor.grad is not None: with self.update(consumer.mirror) as bconsumer: - idx = bconsumer.outputs().index(ptensor.grad) - bconsumer.set_output(idx,grad ) + idx = bconsumer.outputs().index(tensor.grad) + bconsumer.set_output(idx, grad) # grouping into segment for sid in range(len(fstages)): diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py index 1679ced1..a72ac484 100644 --- a/cube/graph/schedule/schedplan.py +++ b/cube/graph/schedule/schedplan.py @@ -284,6 +284,13 @@ def _place_dataloader(self): """ Place dataloaders together with segments """ + def insert_block(dl, mid, step): + dl_block = Block(dl, mid, 1) + # print(f'inserting microbatch {mid} at step {step} before {segment.name}{segment.cid}') + self._blocks.append(dl_block) + self._step_blocks[step+block.span-1].insert(0, dl_block) + self._block_start_step[dl_block] = step+block.span-1 + # insert dataloaders to its devices before the first required segment for dl in self._dependency.dataloaders: inserted_mids = set() @@ -294,13 +301,14 @@ def _place_dataloader(self): if mid in inserted_mids: continue if dl.device[0] not in segment.device: continue if self.graph.depends(dl, segment): - dl_block = Block(dl, mid, 1) - # print(f'inserting microbatch {mid} at step {step} before {segment.name}{segment.cid}') - self._blocks.append(dl_block) - self._step_blocks[step+block.span-1].insert(0, dl_block) - self._block_start_step[dl_block] = step+block.span-1 + insert_block(dl, mid, step) inserted_mids.add(mid) break + # we guarantee each dataloader is inserted into the schedule plan, + # in case that graph output requires the data from dataloader. + for mid in range(self._num_microbatches): + if mid not in inserted_mids: + insert_block(dl, mid, self.nsteps - 1) def topo_sort(self): """ From 86a81af23327b1eedb4b006de9dbdd5fa457641c Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 19 Jan 2024 02:39:39 +0000 Subject: [PATCH 1578/1892] Merged PR 2005: support dimops with identifiers in kwargs using IRObject support dimops with identifiers in kwargs using IRObject --- cube/graph/function/dimops.py | 20 +++++++++++++++----- tests/graph/parser/test_register.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 3fe2372a..1744f3b2 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -615,7 +615,7 @@ def __init__(self, create_fn: Callable, name: str, for anno in self._annos_candidates: anno = OpAnno(anno) # expand * and check shape dimension consistency - if self.align(inputs, anno, **kwargs): + if self.align(signature, inputs, anno, kwargs): self._iannos = anno.inputs() self._oannos = anno.outputs() self._anno = anno @@ -706,7 +706,7 @@ def new(self, inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): op.set_output(idx, output) return op - def align(self, inputs: List[IRTensor], op_anno: OpAnno, **kwargs) -> bool: + def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict) -> bool: """! Align input tensor shapes to the operator annotation. @@ -776,9 +776,19 @@ def align(self, inputs: List[IRTensor], op_anno: OpAnno, **kwargs) -> bool: if identifier not in kwargs: toinfer.append(identifier) else: - assert isinstance(kwargs[identifier], int), "require integer for annotation inference" - ret = op_anno.setlen(identifier, kwargs[identifier]) - accum *= kwargs[identifier] + if isinstance(kwargs[identifier], IRObject): + _logger.warning( + f"Function {signature}: Found identifier {identifier} in kwargs to be IRObject, " + f"this will turn it into a static value. Pay attention to the usage " + f"in dynamic-shape scenarios") + kwargs[identifier] = kwargs[identifier].value + length = kwargs[identifier] + if not isinstance(length, int): + raise ValueError( + f"Function {signature}: identifier {identifier} in kwargs " + f"must be int or IRObject[value=int], but got {length}") + ret = op_anno.setlen(identifier, length) + accum *= length else: accum *= length if len(toinfer) == 0 and accum != dimlen: diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index 544bd5e4..cb05381b 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -18,6 +18,11 @@ def mock_add2(x: torch.Tensor, y: torch.Tensor): return x + y +@cube.graph.parser.register('(h w^) k^ -> h (w^ k^)') +def mock_view_with_obj(x, h): + return x.view(h, -1) + + class MockAGF(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, y: torch.Tensor): @@ -58,6 +63,17 @@ def forward(self, x, y): return MockAGF.apply(x, y) +class MockModelObj(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x, h: int): + # x: [40, 10] + x = self.fc(x) + return mock_view_with_obj(x, h) + + # passed test @replace_all_device_with('cpu') def test_common_register(): @@ -93,3 +109,16 @@ def test_autograd_register(): for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'Function.apply']): profile_name = ProfileDataBase.get_func(node)[0].__qualname__ assert profile_name == p_name, f'{profile_name} should be {p_name}' + + +@replace_all_device_with('cpu') +def test_autograd_register(): + model = MockModelObj() + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(model, {'x': torch.rand(40, 10), 'h': 4}, tempdir, False) + + node = ir_graph.select(name='mock_view_with_obj')[0] + assert node.kwargs['h'] == 4 + sub_nodes = ir_graph.partition(node, node.algorithms('dim'), idx=0, dim=0, num=2) + for sub_node in sub_nodes: + assert sub_node.kwargs['h'] == 2 From 290a79b08445e18cf6b3832ccb63a3cb96cceff3 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 19 Jan 2024 04:00:06 +0000 Subject: [PATCH 1579/1892] Merged PR 2009: add doc for parallel module add doc for parallel module. This document doesn't include the case when parallelmodule is a submodule of a pytorch module. This part will be included when ParallelModule has fully support to that case. Related work items: #1758 --- cube/parallel.py | 5 +- docs/parallel_module.md | 291 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+), 1 deletion(-) create mode 100644 docs/parallel_module.md diff --git a/cube/parallel.py b/cube/parallel.py index 408b6b58..b39c4062 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -477,6 +477,7 @@ def _gencode( if is_module_class: del module else: + logger.info(f"Reuse graph dump in {outdir}") graph = IRGraph.load(graph_ckp) forward_args = torch.load(forward_args_ckp) @@ -577,7 +578,7 @@ def parallelize( Or you can unset load_module flag, and manually copy the generated files to other nodes. After all nodes have the generated files, you can call parallelize() again with load_module flag set. - Note: if reuse is not set to ReuseType.ALL, + Note: if reuse is not set to ReuseType.MATCH, the generated code in outdir will be removed EVEN IF the code generetion fails in this call. if the input is a module object. @@ -665,6 +666,8 @@ def __init__(self, init_params=True): module_dtype=module_dtype, module_fn=module_fn, ) + else: + logger.info(f"Reuse generated code in {outdir}") if load_module: if not torch.distributed.is_initialized(): # we only support loading in torchrun environment diff --git a/docs/parallel_module.md b/docs/parallel_module.md new file mode 100644 index 00000000..9ef7fe15 --- /dev/null +++ b/docs/parallel_module.md @@ -0,0 +1,291 @@ +# Parallel Module + +Besides the support of end-to-end model training, Cube can also convert a `torch.nn.Module` to a parallel module. A parallel module is a special `torch.nn.Module` but runs in multiple gpus/nodes. All the complexity of distributed training/inferring is hidden from the user. + +## An example + +```python +import torch +from cube.parallel import parallelize, ComputeConfig, build_optimizer + +class LLM(torch.nn.Module): + def __init__(self, ...): + ... + def forward(self, x): + ... + +llm_sample_input = ... # dummpy input will be used to do tracing +pas_policy = ... # the PAS policy, you can use autodist pas +compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + ..., +) # compute environment config +ParallelizedLLM = parallelize( + LLM, + {'x': llm_sample_input}, + pas_policy, + compute_config, +) + +# do inference exactly the same way +def infer(model: ParallelizedLLM, x): + model.eval() + with torch.inference_mode(): + return model(x) + + +# do training exactly the same way +# except you need to patch your optimizer to support distributed training via build_optimizer +def train(model: ParallelizedLLM, data): + loss_fn = ... + # build_optimizer function will help to create a distributed optimizer + optimizer = build_optimizer(model, ...) + + for i, (x, y) in enumerate(data): + model.train() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + optimizer.step() + optimizer.zero_grad() +``` + +## APIs + +### ComputeConfig +The configuration of the compute environment. It is a dataclass with the following fields: +```python +@dataclass(frozen=True) +class ComputeConfig: + plan_ngpus: int + runtime_ngpus: int + + dynamic_shape: bool = True + + reducer_op: str = 'sum' + use_zero: bool = False + zero_ngroups: int = 1 + + user_config: Optional[Dict[str, Any]] = None +``` +We can categorize the fields into 4 categories: + +1. Trace configuration + - dynamic_shape: whether to use dynamic shape or static shape. +2. Compute environment configuration + - plan_ngpus: the number of gpus to be used as a unit. The model is partitioned (TP or PP) within a unit, and then data parallelism is applied across multiple units. So every `plan_ngpus` devices holds the whole model. Furthermore, assume we have two workers, and their ranks are `rank1` and `rank2`: + 1. if `rank1 // plan_gpus == rank2 // plan_ngpus`, then they are in the same unit. + 2. If `rank1 % plan_ngpus == rank2 % plan_ngpus`, then the portion of model hold on both gpus are exactly the same. + - runtime_ngpus: the number of gpus to be used in runtime. It should be a multiple of `plan_ngpus`, which means we have `runtime_ngpus // plan_ngpus` units in runtime, and the data parallelism is `runtime_ngpus // plan_ngpus`. +3. Code generation feature configuration + - use_zero: whether to use zero. If it is true, the generated code will use zero1 to do distributed training. + - zero_ngroups: the number of groups to be used in zero. + - reducer_op: the reducer operation for the gradients. It can be `sum`, `mean`, `min`, `max`, `avg`. +4. User configuration + - user_config: the user configuration. A typical usage is deciding whether skipping compiling and reusing the previously compiled parallel module. If user_config is the same between two runs, compiling in the second run will be skipped. + +Note: +1. `reducer_op` represents which `torch.distributed.ReduceOp` is used when reduce gradients + by torch.distributed.all_reduce or torch.distributed.reduce_scatter + + In some cases, you may want to firstly divide the local gradients, and then use torch.distributed.ReduceOp.SUM to get the final the gradients. + You can achieve that speical mean with `optimizer.register_reducer_pre_hook` by setting `reducer_op` to `sum` and divide the local gradients with the following code: + ```python + def _mean_hook(reducer, grad): + if reducer.reduce_op == torch.distributed.ReduceOp.SUM: + grad.div_(reducer.ranks) + optimizer.register_reducer_pre_hook(_mean_hook) + ``` +2. You can put any graph related configuration here. The assumption is different user_config should generate different graph/code. So if the user config is changed, we will regenerate the graph/code automatically. Here are some examples: + + - Example 1: save module configuration + ```python + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + ... + if module_config.use_3d: + ... + ``` + here we can set `user_config={'use_3d': module_config.use_3d}`, + and we can be sure different use_3d config will never use the same generated code. + + - Example 2: save file stats + If you want to track all related file stats (just like traditional compilers do), + you can do + ```python + user_config = { + 'file_stats': { + str(f): os.stat(f).st_mtime_ns for f in Path('./src').glob('**/*.py') # assume all source code is in ./src + } + } + ``` + Or you can save the md5 of the files to save some bytes: + ```python + import hashlib + h = hashlib.md5() + for f in Path('./src').glob('**/*.py'): + with open(f, 'rb') as f: + h.update(f.read()) + user_config = { + 'files_md5': h.hexdigest() + } + ``` + +### ReuseType + +The reuse policy for the existing generated code. It is an enum with the following values: +```python +class ReuseType(Enum): + MATCH = 'match' + OVERRIDE = 'override' + MOO = 'moo' + GRAPH = 'graph' +``` +We call it a `match` when the `ComputeConfig` is the same with the previous run. + +1. MATCH: Reuse if match, error if not match, generate if no previous gerenated code exists. +2. OVERRIDE: Nothing will be reused. Everything will be regenerated. +3. MOO: MOO is short for 'match or override'. It will reuse if match, generate if not match or no previous gerenated code exists. +4. GRAPH: Reuse graph only if match, generate otherwise. + + +### Module Conversion + +We have `parallelize` function to Convert a torch.nn.Module to a ParallelModule. +```python +def parallelize( + module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]], + dummy_input: dict, + pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], + compute_config: ComputeConfig, + *, + cube_savedir: Union[str, Path] = './.cube', + reuse: Union[ReuseType, str] = ReuseType.MATCH, + instance_name: Optional[str] = None, + load_module: bool = True, + module_dtype: Optional[torch.dtype] = None, + module_fn: Optional[Callable[[], torch.nn.Module]] = None, + init_module_params: bool = True, +) -> Union[None, ParallelModule, Type[ParallelModule]]: +``` +It has the following parameters: + +- module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled. Please note if the input is a module object, we will return a `ParalleModule` object. If the input is a module class, we will return a `ParalleModule` class. + +- dummy_input (dict): the dummy input for the module + +- pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy, which describes how to place all computations across devices. You can use `autodist` to do the pas automatically in the most efficient way. + +- compute_config (ComputeConfig): the environment resource + +- reuse (ReuseType): specify which part can be reused. + +- cube_savedir (Union[str, Path]): the directory to save generated code + +- instance_name (Optional[str]): the instance name of the generated module. If it is `None`, will use the default name. + +- load_module (bool): whether to load the generated module or module class after conversion is done. +Currently the module can only be loaded in torchrun environment. So you can do the conversion in any environment (with `load_module` unset), and load the module in torchrun environment. + +- init_module_params (bool): If true, when we construct the module, all its parameters are initialized with the same value with when we traced. +Otherwise, they will be empty tensor. +This parameter will be passed to the module constructor, +so it is only used when `module_or_module_class` is a module object, and `load_module` is true. + +- module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. + +- module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use `__init__` if it is None. This parameter is only used when `module_or_module_class` is a module class. + +Note: + +1. This function can be used to convert both module object and module class to cube module or cube module class. +Among key-value arguments, +`module_fn` and `module_dtype` control how to create the module object. +whereas `init_module_params` controls how to load cube module object after conversion is done. + +2. If you want to save multiple instances of the same module (with differnt configurations), +you can specify the `instance_name` to distingish them. + +3. Currently you must use a shared file system to share the generated files (like mounted Azure Blob). +Or you can unset `load_module` flag, and manually copy the generated files to other nodes. +After all nodes have the generated files, you can call `parallelize()` again with `load_module` flag set. + +4. if reuse is not set to ReuseType.MATCH, +the generated code in outdir will be removed EVEN IF the code generetion fails in this call. + +After the module is converted, you can use it to create module object by calling it like a module class. +The module class is defined like: +```python +class GenModule(cube.runtime.module.ParallelModule): + def __init__(self, init_params=True): + super().__init__() + ... + ... +``` +So you can use `init_params` in `__init__` to control whether to initialize the module parameters. +For example, if you don't want to intialize module params: +```python +module = GenModule(init_params=False) +``` + +### Optimizer Creation + +We have `build_optimizer` to build an optimizer for distributed training. +```python +def build_optimizer( + module: torch.nn.Module, + optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], + *args, + **kwargs, +) -> OptimizerT: +``` +It has the following parameters: +- module (torch.nn.Module): the module to be optimized +- optimizer_fn (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): + It can be the optimizer class or optimizer factory function. +- *args: the args will pass to `optimizer_fn` +- **kwargs: the kwargs will pass to `optimizer_fn` + +To support distrubted training, in the function we need to hook 4 places: + +1. optimizer constructor: + the parameters of optimizer will not be the same with the parameters of the module if we use zero. + So we need to replace the parameters of optimizer with `CubeModule.parameters_for_optimizer`. + +2. `optimizer.step()`: + we need to call `optimizer.sync_shard_grad()` to sync the gradients of the module before `optimizer.step()`. + In zero mode (not supported yet), we have to call `CubeModule.gather_params()` after `optimizer.step()` + +3. `optimizer.zero_grad()`: + We need to call `CubeModule.zero_grad()` after `optimizer.zero_grad()` + +`build_optimizer` will patch optimizer for you. Besides the above patches, we also add several utility functions to optimizer: + +1. `sync_shard_grad`: Sync the shard gradients of the module from nodes with same shard to the optimizer. This function is called in optimizer's pre-step hook. But If you want to access the gradients before `optimizer.step()`(for example, you need gnorm), you need to call this function manually. + +2. `register_reducer_pre_hook`, `register_reducer_post_hook`: Register pre/post hooks to reducers which will be applied before/after gradient synchronization. + +### Dataset + +We use the same dataset/dataloader as pytorch. For example, you can use `torch.utils.data.DistributedSampler` to create a distributed sampler. + +`ParallelModule`s running in the same unit should use the same input, and will get the same output. `ParallelModule`s runing in different units should use different input and will get different output (similar to data parallelism). The gradients of all parameters will be synced across all the devices automatically. + +Take `torch.utils.data.DistributedSampler` for example, you can create the sampler like this: +```python +def create_distributed_sampler(dataset): + return torch.utils.data.DistributedSampler( + dataset=dataset, + num_replicas=compute_config.runtime_ngpus // compute_config.plan_ngpus, + rank=torch.distributed.get_rank() // compute_config.plan_ngpus, + ..., + ) +``` + +## TODOs +1. When ParallelModule is a submodule of another Module, Pytorch DDP is not supported yet. +2. Pipeline parallelism is not supported yet. From 4bb0c3304dd8c7808587794fb5635a84a0568298 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 22 Jan 2024 15:37:35 +0000 Subject: [PATCH 1580/1892] Merged PR 2015: add missed runtime tensor creation This PR add missed runtime tensor creation (required at function.py: https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube?path=/cube/graph/function/function.py&version=GBmain&line=297&lineEnd=297&lineStartColumn=1&lineEndColumn=47&lineStyle=plain&_a=contents) --- cube/graph/function/function.py | 27 ++++++++++++++++++++------ cube/runtime/function/function.py | 8 ++++++++ tests/graph/function/test_functions.py | 22 +++++++++++++++++++++ 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index d38630e1..33217777 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -292,15 +292,30 @@ def Full(size, fill_value, *, out=None, dtype=None, layout=None, def NewTensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False, signature=None): - # note: device is ignored - dtype = dtype if dtype is not None else torch.get_default_dtype() + """ + torch.tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) + """ signature = 'cube.runtime.function.tensor' - size = tuple(np.array(data).shape) if np.array(data).shape else (1,) # (1,) means it is a scalar - kwargs = {'size': size, 'requires_grad': requires_grad, + + val = data + if isinstance(data, IRTensor): + size = data.shape + elif isinstance(data, IRObject): + size = torch.tensor(data.value).shape + else: + # for non-IRObject instance, we will always convert to list + # through torch.tensor, since we cannot guarantee the `data` + # instance to be executable for its `repr(data)` string + # in gencode + val = torch.tensor(data) + size = val.shape + val = val.tolist() + size = size if len(size) > 0 else (1,) # for scalar + + kwargs = {'data': val, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} - anno, rules = _get_creator_anno_rules(size, True) + anno, rules = _get_creator_anno_rules(size, False) dimop = IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) - dimop.output(0).parent.dtype = dtype return dimop diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index d77378a6..0d194eb5 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -149,6 +149,14 @@ def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): return torch.masked_scatter(input, mask, src) +def tensor(data, *, dtype=None, requires_grad=False, pin_memory=False): + return torch.tensor( + data, dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=requires_grad, pin_memory=pin_memory + ) + + def empty(size: Tuple[int], dtype=None, requires_grad=False, pin_memory=False): return torch.empty( size, dtype=torch.get_default_dtype() if dtype is None else dtype, diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 3d59df47..4a90eab8 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -5,6 +5,7 @@ import pytest import torch +import numpy as np def o(value): @@ -273,3 +274,24 @@ def test_Unsqueeze(): def test_ScaledDotProductAttention(): op = F.ScaledDotProductAttention(IRTensor([8, 128, 64]), IRTensor([8, 256, 64]), IRTensor([8, 256, 32]), None, 0.05) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a e d^, a b^ d^, a b^ c -> a e c' + + +def test_NewTensor(): + op = F.NewTensor(torch.tensor(1)) + assert op.signature == 'cube.runtime.function.tensor' + assert repr(op.anno) == ' -> 1^' + assert op.kwargs['data'] == 1 + + op = F.NewTensor(torch.tensor([1,2])) + assert op.signature == 'cube.runtime.function.tensor' + assert repr(op.anno) == ' -> 2^' + assert op.kwargs['data'] == [1,2] + + obj = IRObject(value=np.array([1,2])) + op = F.NewTensor(obj) + assert repr(op.anno) == ' -> 2^' + assert op.kwargs['data'] == obj + + op = F.NewTensor(np.array([[1],[2],[3]])) + assert repr(op.anno) == ' -> 3^ 1^' + assert op.kwargs['data'] == [[1],[2],[3]] From fb8a1676d6d934160aabd967e5c3b457a85f5b04 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Jan 2024 06:07:47 +0000 Subject: [PATCH 1581/1892] Merged PR 2018: Pipeline Support-8: Refine code --- cube/compiler.py | 7 ++++--- cube/graph/schedule/predefined.py | 2 +- cube/graph/schedule/schedplan.py | 4 ++-- cube/parallel.py | 7 ++++--- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 28d39207..42b5a288 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -202,11 +202,12 @@ def decorator(fn: Callable) -> Callable: # check assignment and remove anchor node for node in graph.nodes(flatten=True): - # skip graph anchor and multiref: they will be removed or replaced by system + # skip graph anchor: will be removed + # skip multiref and IRPyFunc: they will be managed by system if isinstance(node, IRGraphAnchor) or node.name == 'multiref': - graph.assign(node, 0) + continue if isinstance(node, IRPyFunc): - graph.assign(node, 0) + continue if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") diff --git a/cube/graph/schedule/predefined.py b/cube/graph/schedule/predefined.py index bfbd3506..ebb81b18 100644 --- a/cube/graph/schedule/predefined.py +++ b/cube/graph/schedule/predefined.py @@ -28,7 +28,7 @@ def sched_1f1b(graph: IRGraph, num_microbatches: int, num_stages: int) -> Schedu raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) fsegs = [seg for seg in segments if seg.isfw()] - assert len(fsegs) == num_stages, f"Mismatch of forward segement number ({len(fsegs)}) with num_stages ({len(num_stages)})" + assert len(fsegs) == num_stages, f"Mismatch of forward segement number ({len(fsegs)}) with num_stages ({num_stages})" # describe schedule sched = SchedulePlan(graph, num_microbatches) diff --git a/cube/graph/schedule/schedplan.py b/cube/graph/schedule/schedplan.py index a72ac484..d973ea8f 100644 --- a/cube/graph/schedule/schedplan.py +++ b/cube/graph/schedule/schedplan.py @@ -371,10 +371,10 @@ def apply(self): """ # step 1: build dependency for scheduling self._dependency.build() - # step 2: apply this scheduling + # step 2: insert adapters and dataloaders to the plan self._place_adapters() self._place_dataloader() - # step 3: generate topological sequence + # step 3: generate topological sequence, append reducers self.topo_sort() def validate(self) -> bool: diff --git a/cube/parallel.py b/cube/parallel.py index b39c4062..bf427a09 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -487,11 +487,12 @@ def _gencode( # check assignment and remove anchor node for node in graph.nodes(flatten=True): - # skip graph anchor and multiref: they will be removed or replaced by system + # skip graph anchor: will be removed + # skip multiref and IRPyFunc: they will be managed by system if isinstance(node, IRGraphAnchor) or node.name == 'multiref': - graph.assign(node, 0) + continue if isinstance(node, IRPyFunc): - graph.assign(node, 0) + continue if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") graph = IRAdapterGener.gen(graph, cost_fn=None) From 9e8067a504a4004a43f56727c3080fecddfb733a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 23 Jan 2024 07:37:31 +0000 Subject: [PATCH 1582/1892] Merged PR 2016: Pipeline Support-7: fix adapter generation bug adapter for graph / segment output was skipped by only checking the consistency of forward tensors but not considering backward tensors. --- cube/graph/gener/gen.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 8732032f..a94cd30c 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -360,23 +360,22 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # print(f'backward:\n{graph.mirror.debug_tensor_map_str(ftensor.grad)}') # producers can be operators and graph inputs - fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) + fptensors = graph.ptensors(ftensor) if ftensor in input_producer: fptensors = fptensors + tuple(fop.output(0) for fop in input_producer[ftensor]) fptensors = expand_devices(fptensors, producer=True) assert all(len(ptensor.device) == 1 for ptensor in fptensors), "Not support for multi-device" # consumers can be operators and graph outputs - fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) + fctensors = graph.ctensors(ftensor) fctensors = expand_devices(fctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" - bproducers, bptensors = [], [] - bconsumers, bctensors = [], [] + bptensors, bctensors = [], [] if isinstance(ftensor.grad, IRFullTensor): - bproducers, bptensors = bgraph.producers(ftensor.grad), bgraph.ptensors(ftensor.grad) + bptensors = bgraph.ptensors(ftensor.grad) bptensors = expand_devices(bptensors, producer=True) - bconsumers, bctensors = bgraph.consumers(ftensor.grad), bgraph.ctensors(ftensor.grad) + bctensors = bgraph.ctensors(ftensor.grad) if ftensor in input_producer: bctensors = bctensors + tuple(fwop.output(0).grad for fwop in input_producer[ftensor]) bctensors = expand_devices(bctensors, consumer=True) @@ -392,22 +391,23 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: fadapters.append(fadapter) # (activation -> graph/segment output) generation: generate communication adapters between - # producer operatiors and graph/segment output tensors. Note graph/segment output tensors + # producer operators and graph/segment output tensors. Note graph/segment output tensors # always require for full-shape/value for output, while consumers may partition them. Therefore, # we need to additionally generate adapters for this case. if ftensor in output_consumer: out_fctensors = tuple(fwop.input(0) for fwop in output_consumer[ftensor]) out_fctensors = expand_devices(out_fctensors, consumer=True) + out_bptensors = [] + if isinstance(ftensor.grad, IRFullTensor): + out_bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) + out_bptensors = expand_devices(out_bptensors, consumer=True) # dedup adapter if the output is same with activation tensor if set(out_fctensors) == set(fctensors) and \ + set(out_bptensors) == set(bptensors) and \ set(t.device[0] for t in out_fctensors) == set(t.device[0] for t in fctensors): pass else: - fctensors = out_fctensors - bptensors = [] - if isinstance(ftensor.grad, IRFullTensor): - bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) - bptensors = expand_devices(bptensors, producer=True) + fctensors, bptensors = out_fctensors, out_bptensors if (not skip(fptensors, fctensors)) or (not skip(bptensors, bctensors)): fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) if fadapter is not None: @@ -438,6 +438,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # insert forward adapter # graph.insert(fadapter, max(producers) + 1) + fconsumers = graph.consumers(ftensor) if len(fconsumers) > 0: fidx = min(graph.multi_index(fconsumers)) else: @@ -458,6 +459,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: if badapter is not None: assert isinstance(badapter, IRAdapter) assert isinstance(bgraph, IRSegment) + bproducers = bgraph.producers(ftensor.grad) if len(bproducers) > 0: bidx = max(bgraph.multi_index(bproducers)) + 1 else: From 41e839c9f825b01e96d45bb87a83ca68869a29e5 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Tue, 23 Jan 2024 11:28:49 +0000 Subject: [PATCH 1583/1892] Merged PR 2026: Revert 'Pipeline Support-7: fix adapter generation bug' adapter for graph / segment output was skipped by only checking the consistency of forward tensors but not considering backward tensors. Reverts !2016 --- cube/graph/gener/gen.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index a94cd30c..8732032f 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -360,22 +360,23 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # print(f'backward:\n{graph.mirror.debug_tensor_map_str(ftensor.grad)}') # producers can be operators and graph inputs - fptensors = graph.ptensors(ftensor) + fproducers, fptensors = graph.producers(ftensor), graph.ptensors(ftensor) if ftensor in input_producer: fptensors = fptensors + tuple(fop.output(0) for fop in input_producer[ftensor]) fptensors = expand_devices(fptensors, producer=True) assert all(len(ptensor.device) == 1 for ptensor in fptensors), "Not support for multi-device" # consumers can be operators and graph outputs - fctensors = graph.ctensors(ftensor) + fconsumers, fctensors = graph.consumers(ftensor), graph.ctensors(ftensor) fctensors = expand_devices(fctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in fctensors), "Not support for multi-device" - bptensors, bctensors = [], [] + bproducers, bptensors = [], [] + bconsumers, bctensors = [], [] if isinstance(ftensor.grad, IRFullTensor): - bptensors = bgraph.ptensors(ftensor.grad) + bproducers, bptensors = bgraph.producers(ftensor.grad), bgraph.ptensors(ftensor.grad) bptensors = expand_devices(bptensors, producer=True) - bctensors = bgraph.ctensors(ftensor.grad) + bconsumers, bctensors = bgraph.consumers(ftensor.grad), bgraph.ctensors(ftensor.grad) if ftensor in input_producer: bctensors = bctensors + tuple(fwop.output(0).grad for fwop in input_producer[ftensor]) bctensors = expand_devices(bctensors, consumer=True) @@ -391,23 +392,22 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: fadapters.append(fadapter) # (activation -> graph/segment output) generation: generate communication adapters between - # producer operators and graph/segment output tensors. Note graph/segment output tensors + # producer operatiors and graph/segment output tensors. Note graph/segment output tensors # always require for full-shape/value for output, while consumers may partition them. Therefore, # we need to additionally generate adapters for this case. if ftensor in output_consumer: out_fctensors = tuple(fwop.input(0) for fwop in output_consumer[ftensor]) out_fctensors = expand_devices(out_fctensors, consumer=True) - out_bptensors = [] - if isinstance(ftensor.grad, IRFullTensor): - out_bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) - out_bptensors = expand_devices(out_bptensors, consumer=True) # dedup adapter if the output is same with activation tensor if set(out_fctensors) == set(fctensors) and \ - set(out_bptensors) == set(bptensors) and \ set(t.device[0] for t in out_fctensors) == set(t.device[0] for t in fctensors): pass else: - fctensors, bptensors = out_fctensors, out_bptensors + fctensors = out_fctensors + bptensors = [] + if isinstance(ftensor.grad, IRFullTensor): + bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) + bptensors = expand_devices(bptensors, producer=True) if (not skip(fptensors, fctensors)) or (not skip(bptensors, bctensors)): fadapter = ConcurrentGener.gen(fptensors, fctensors, bptensors, bctensors, cost_fn) if fadapter is not None: @@ -438,7 +438,6 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # insert forward adapter # graph.insert(fadapter, max(producers) + 1) - fconsumers = graph.consumers(ftensor) if len(fconsumers) > 0: fidx = min(graph.multi_index(fconsumers)) else: @@ -459,7 +458,6 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: if badapter is not None: assert isinstance(badapter, IRAdapter) assert isinstance(bgraph, IRSegment) - bproducers = bgraph.producers(ftensor.grad) if len(bproducers) > 0: bidx = max(bgraph.multi_index(bproducers)) + 1 else: From ae2a856d6175f84928c6990c0e44067dd709c228 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 24 Jan 2024 02:46:06 +0000 Subject: [PATCH 1584/1892] Merged PR 2003: revert the patched when tracing the leaf functions need add test for patcher ![image.png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2003/attachments/image.png) --- .../concrete_trace_utils/concrete_tracer.py | 33 ++--- .../concrete_trace_utils/function_patcher.py | 136 ++++++++++++++++++ .../concrete_trace_utils/operator_patcher.py | 4 +- 3 files changed, 150 insertions(+), 23 deletions(-) create mode 100644 cube/graph/parser/fx/concrete_trace_utils/function_patcher.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 0fe06796..759f7f75 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -29,7 +29,7 @@ import torch.fx from torch.fx import GraphModule from torch.fx._compatibility import compatibility -from torch.fx._symbolic_trace import _Patcher, _proxyable_classes +from torch.fx._symbolic_trace import _proxyable_classes from torch.fx.graph import Graph from torch.fx.node import Target, Node, Argument, _side_effectful_functions from torch.fx.proxy import TracerBase @@ -86,6 +86,7 @@ def __exit__(self, *args): return from . import concrete_proxy as ep +from .function_patcher import FunctionPatcher from .operator_patcher import OperatorPatcherContext from .utils import ( _orig_module_call, @@ -320,6 +321,7 @@ def __init__(self, cpu_offload = False, record_frames = False): self.node_name_to_scope = {} self.cpu_offload = cpu_offload self.record_frames = record_frames + self.patcher = FunctionPatcher() @contextmanager def do_temp_disable(self, call=False, attr=False, agfunc_apply=False): @@ -388,26 +390,17 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] if kind == 'call_function': assert isinstance(target, Callable) fn = target - if _orig_getattr(fn, '__module__', None) != self.__module__ \ - and hasattr(fn, '__globals__'): - _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - return OperatorPatcherContext.patch_run(fn, *args, **kwargs) + result = fn(*args, **kwargs) elif kind == 'call_method': self_obj, *args_tail = args fn = _orig_getattr(self_obj, target) - if _orig_getattr(fn, '__module__', None) != self.__module__ \ - and hasattr(fn, '__globals__'): - _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) result = fn(*args_tail, **kwargs) elif kind == 'call_module': assert isinstance(target, str) mod = self.fetch_attr(target) if self.cpu_offload: mod.cuda() # how it works in ddp? - if _orig_getattr(mod, '__module__', None) != self.__module__ \ - and hasattr(mod, '__globals__'): - _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - result = OperatorPatcherContext.patch_run(mod, *args, **kwargs) + result = mod(*args, **kwargs) if self.cpu_offload: mod.cpu() elif kind == 'get_attr': @@ -479,6 +472,7 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: """ similar to _symbolic_trace.Tracer.create_proxy. use the 'run_target' to actually execute the code, and store the value in 'value' field. + create the nodes for the target and the input of the target (if the target is one of call_method, call_function, call_module). """ def unwrap(obj: Any): while _orig_isinstance(obj, ep.ConcreteProxy): @@ -487,8 +481,8 @@ def unwrap(obj: Any): args_unwrapped = ep.map_aggregate_not_proxy(args, unwrap) kwargs_unwrapped = ep.map_aggregate_not_proxy(kwargs, unwrap) - # real value by execution - value_unwrapped = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) + with self.patcher.revert(): + value_unwrapped = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) args_ = self.create_arg(args) kwargs_ = self.create_arg(kwargs) @@ -1172,12 +1166,13 @@ def getattr_wrapper(obj, *args): self.temp_disable_attr_level = 0 self.temp_disable_agfunc_apply_level = 0 try: - with _Patcher() as self.patcher: + with self.patcher: # allow duplicate patches to support the case of nested calls self.patcher.patch_method(torch.nn.Module, "__getattribute__", module_getattribute_wrapper, deduplicate=False) self.patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False) - self.patcher.patch_method(torch.autograd.Function, "apply", agfunc_apply_wrapper, deduplicate=False) + # for cuda versions of pytorch, autograd.Function.apply should be reverted by delattr + self.patcher.patch_method(torch.autograd.Function, "apply", agfunc_apply_wrapper, deduplicate=False, revert_by_del=True) self.patcher.patch_method(torch, "_assert", torch_assert_wrapper, deduplicate=False) self.patcher.patch_method(builtins, "map", map_wrapper, deduplicate=False) @@ -1207,8 +1202,6 @@ def unwrap(obj: Any): self.create_node('output', 'output', (self.create_arg(results),), {}, type_expr=fn.__annotations__.get('return', None), node_result=ep.map_aggregate_not_proxy(results, unwrap)) finally: - # for cuda versions of pytorch, autograd.Function.apply should be reverted manually - delattr(torch.autograd.Function, 'apply') _retain_weight_consistency(self.root) pass @@ -1257,7 +1250,7 @@ def wrapped(*args, **kwargs): return wrapped -def _patch_wrapped_functions(patcher : _Patcher): +def _patch_wrapped_functions(patcher : FunctionPatcher): """ Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap the listed global functions in the `_create_wrapped_func` wrapper. @@ -1877,7 +1870,7 @@ def f(x, y): *side_effectful_inplace_ops } extra_side_effectful_functions = default_extra_side_effectful_functions | dce_ignored_function - with _Patcher() as patcher, ExtraSEFPatcher(extra_side_effectful_functions): + with FunctionPatcher() as patcher, ExtraSEFPatcher(extra_side_effectful_functions): patcher.patch_method(Node, 'is_impure', node_is_impure_wrapper, deduplicate=False) traced.graph.eliminate_dead_code() traced.recompile() # this need to be done in MagicMethodPatcher context diff --git a/cube/graph/parser/fx/concrete_trace_utils/function_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/function_patcher.py new file mode 100644 index 00000000..8bb97921 --- /dev/null +++ b/cube/graph/parser/fx/concrete_trace_utils/function_patcher.py @@ -0,0 +1,136 @@ +import builtins +from contextlib import contextmanager +from typing import Any, Callable, List, Dict, NamedTuple + +from torch.fx._symbolic_trace import _Patcher +from .utils import _orig_reversed + +class _PatchedFnReusable(NamedTuple): + frame_dict: Any + fn_name: str + orig_fn: Any + new_fn: Any + + def patch(self): + raise NotImplementedError() + + +class _PatchedFnSetItemReusable(_PatchedFnReusable): + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + + def revert(self): + self.frame_dict[self.fn_name] = self.orig_fn + + +class _PatchedFnDelReusable(_PatchedFnReusable): + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + + def revert(self): + del self.frame_dict[self.fn_name] + + +class _PatchedFnSetAttrReusable(_PatchedFnReusable): + def patch(self): + setattr(self.frame_dict, self.fn_name, self.new_fn) + + def revert(self): + setattr(self.frame_dict, self.fn_name, self.orig_fn) + + +class _PatchedFnSetAttrDelReusable(_PatchedFnReusable): + def patch(self): + setattr(self.frame_dict, self.fn_name, self.new_fn) + + def revert(self): + delattr(self.frame_dict, self.fn_name) + + +class FunctionPatcher(_Patcher): + def __init__(self): + super().__init__() + self.patches_made: List[_PatchedFnReusable] = [] + self.patch_mode = False + self.in_global_context = False + + def patch( + self, + frame_dict: Dict[str, Any], + name: str, + new_fn: Callable, + deduplicate: bool = True, + ): + """ + Replace frame_dict[name] with new_fn until we exit the context manager. + """ + if not self.patch_mode: + raise RuntimeError('only can do patch in patch mode') + setattr(new_fn, '__fx_already_patched', deduplicate) + if name not in frame_dict and hasattr(builtins, name): + self.patches_made.append(_PatchedFnDelReusable(frame_dict, name, None, new_fn)) + elif getattr(frame_dict[name], "__fx_already_patched", False): + return # already patched, no need to do it again + else: + self.patches_made.append( + _PatchedFnSetItemReusable(frame_dict, name, frame_dict[name], new_fn) + ) + self.patches_made[-1].patch() + + def patch_method( + self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True, revert_by_del: bool = False + ): + """ + Replace object_or_dict.name with new_fn until we exit the context manager. + """ + if not self.patch_mode: + raise RuntimeError('only can do patch in patch mode') + setattr(new_fn, '__fx_already_patched', deduplicate) + orig_fn = getattr(cls, name) + if getattr(orig_fn, "__fx_already_patched", False): + return # already patched, no need to do it again + if revert_by_del: + self.patches_made.append(_PatchedFnSetAttrDelReusable(cls, name, orig_fn, new_fn)) + else: + self.patches_made.append(_PatchedFnSetAttrReusable(cls, name, orig_fn, new_fn)) + self.patches_made[-1].patch() + + @contextmanager + def revert(self): + if self.in_global_context: + self._change_patch_mode_to(False) + for patch in _orig_reversed(self.patches_made): + # unpatch in reverse order to handle duplicates correctly + patch.revert() + try: + yield + finally: + self._change_patch_mode_to(True) + for patch in self.patches_made: + patch.patch() + else: + try: + yield + finally: + pass + + def _change_patch_mode_to(self, to_mode: bool): + if self.patch_mode != (not to_mode): + raise RuntimeError(f'want to change patch mode to {to_mode}, but get current patch mode {self.patch_mode}') + self.patch_mode = to_mode + + def __enter__(self): + self.in_global_context = True + self._change_patch_mode_to(True) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Undo all the changes made via self.patch() and self.patch_method() + """ + while self.patches_made: + # unpatch in reverse order to handle duplicates correctly + self.patches_made.pop().revert() + self.visited.clear() + self._change_patch_mode_to(False) + self.in_global_context = False diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index 91c36553..a7adb019 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -185,9 +185,7 @@ def patch_func_helper(self, func): return func # those flags are set by fx _Patcher when a method is patched # we don't want to patch it again - # _Patcher__fx_already_patched is for torch 2.0.1+ - # __fx_already_patched is for torch 2.0.0 - if hasattr(func, '_Patcher__fx_already_patched') or hasattr(func, '__fx_already_patched'): + if hasattr(func, '__fx_already_patched'): return func if self.use_operator_patch == (func in self.operator_patch_backlist): return func From 3e6d66be66f6a055fbee5676af39fcc1f4180c28 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 25 Jan 2024 01:53:27 +0000 Subject: [PATCH 1585/1892] Merged PR 2022: parallel module: add non-parallel module support Support the case when only some summodules are parallelized. --- cube/parallel.py | 106 +++++-- cube/runtime/module.py | 14 +- docs/parallel_module.md | 65 ++++- tests/parallel_module/common.py | 26 +- tests/parallel_module/test_ddp.py | 324 +++++++++++++++++++++ tests/parallel_module/test_reducer_hook.py | 11 +- tests/parallel_module/test_submodule.py | 11 +- tests/parallel_module/test_wholemodule.py | 11 +- 8 files changed, 501 insertions(+), 67 deletions(-) create mode 100644 tests/parallel_module/test_ddp.py diff --git a/cube/parallel.py b/cube/parallel.py index bf427a09..c93c7d39 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -33,9 +33,11 @@ from cube.program import Program from cube.runtime.adapter.reducer import Reducer from cube.runtime.module import CubeModule, ParallelModule +from cube.runtime.device import DeviceGroup logger = logging.getLogger(__name__) +_VALID_REDUCER_OPS = ['sum', 'avg', 'mean', 'max', 'min'] @dataclass(frozen=True) @@ -109,9 +111,16 @@ def __post_init__(self): raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be a multiple of plan_ngpus {self.plan_ngpus}") if self.use_zero and self.zero_ngroups < 0: raise ValueError(f"zero_ngroups {self.zero_ngroups} must be >= 0") - if self.reducer_op not in ['sum', 'avg', 'mean', 'max', 'min']: + if self.reducer_op not in _VALID_REDUCER_OPS: raise ValueError(f"reducer_op {self.reducer_op} is not supported.") + @property + def gpu_config(self) -> Dict[str, int]: + return { + 'plan_ngpus': self.plan_ngpus, + 'runtime_ngpus': self.runtime_ngpus, + } + @contextmanager def _flags(flags, /, **kwargs): @@ -692,6 +701,10 @@ class ParallelOptimizer(torch.optim.Optimizer): A optimizer stub to support parallelized module. The returned optimizer of build_optimizer() will have the same methods in this class. """ + + # this is a reducer for non-parallel modules + _non_parallel_module_reducer: Optional[Reducer] = None + def sync_shard_grad(self): """ Sync the shard gradients of the module from nodes with same shard to the optimizer. @@ -729,6 +742,7 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], + non_parallel_module_reducer_op: str = 'sum', *args, **kwargs, ) -> OptimizerT: @@ -755,7 +769,10 @@ def build_optimizer( optimizer_fn (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): It can be the optimizer class or optimizer factory function. If it is a factory function, the signature should be the same with optimizer class constructor. - *args: the args for optimizer constructor + non_parallel_module_reducer_op (str): the reducer op for non-parallel modules. Default is 'sum'. + *args: the args for optimizer constructor. + Note: If you use `*args`, you must specify `non_parallel_module_reducer_op`. + Suggest to use kwargs instead, so you don't need to explicitly specify the default value of `non_parallel_module_reducer_op`. **kwargs: the kwargs for optimizer constructor Returns: @@ -766,62 +783,93 @@ def build_optimizer( if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): raise RuntimeError("End2End mode is not supported") + if not non_parallel_module_reducer_op in _VALID_REDUCER_OPS: + raise ValueError(f"non_parallel_module_reducer_op {non_parallel_module_reducer_op} is not supported.") RuntimeFlag.skip_reducer = True RuntimeFlag.skip_zero_grad = False + non_parallel_module_reducer = None + non_parallel_modules = [m for m in module.modules() if not isinstance(m, ParallelModule)] + parallel_modules = [m for m in module.modules() if isinstance(m, ParallelModule)] + if not parallel_modules: + raise RuntimeError("No ParallelModule found in the module. Please make sure you have called parallelize() before build_optimizer().") + + # check if all ParallelModules have the same gpu_config + compute_configs = [m.get_compute_config() for m in parallel_modules] + for i in range(1, len(compute_configs)): + if compute_configs[i].gpu_config != compute_configs[0].gpu_config: + raise RuntimeError("All ParallelModules should have the same gpu_config.") + plan_ngpus, runtime_ngpus = compute_configs[0].plan_ngpus, compute_configs[0].runtime_ngpus + + # we need to add all parameters of non-parallel modules to a reducer to reduce grads + # if there are non-parallel parameters + if plan_ngpus != runtime_ngpus and non_parallel_modules and any(p.numel() for m in non_parallel_modules for p in m.parameters(False)): + rank = torch.distributed.get_rank() + # create all groups + for i in range(plan_ngpus): + DeviceGroup().get_group(list(range(i, runtime_ngpus, plan_ngpus))) + group = list(range(rank % plan_ngpus, runtime_ngpus, plan_ngpus)) + non_parallel_module_reducer = Reducer(group, reduce_op=non_parallel_module_reducer_op) + for m in non_parallel_modules: + for param in m.parameters(recurse=False): # only add leaf parameters to avoid duplicate + non_parallel_module_reducer.add_param(param) + non_parallel_module_reducer.build_buckets() + def _local_parameters(module: torch.nn.Module): gen = module._named_members(lambda m: m._parameters.items()) for _, param in gen: yield param optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), *args, **kwargs) + optimizer._non_parallel_module_reducer = non_parallel_module_reducer def _step_pre_hook(opt, *args, **kwargs): opt.sync_shard_grad() + def _step_post_hook(opt, *args, **kwargs): - for m in module.modules(): - if isinstance(m, ParallelModule): - m.gather_params() - else: - assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + for m in parallel_modules: + m.gather_params() + optimizer.register_step_pre_hook(_step_pre_hook) optimizer.register_step_post_hook(_step_post_hook) orig_zero_grad = optimizer.zero_grad def _patched_zero_grad_hook(self, set_to_none: bool = True): orig_zero_grad(set_to_none) - for m in module.modules(): - if isinstance(m, ParallelModule): - m.zero_grad() - else: - assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + for m in parallel_modules: + m.zero_grad() + if non_parallel_module_reducer: + non_parallel_module_reducer.zero_grad() optimizer.zero_grad = types.MethodType(_patched_zero_grad_hook, optimizer) def _sync_shard_grad(self): with _runtime_flags(skip_reducer=False): - for m in module.modules(): - if isinstance(m, ParallelModule): - m.sync_grad() - else: - assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + # HACK: we reuse the _sync_grad_required flag of the first parallel module + # in order to support calling sync_shard_grad() multiple times. + # _sync_grad_required will reset to `True` in forward() of ParallelModule. + if parallel_modules[0]._sync_grad_required: + for m in parallel_modules: + m.sync_grad() # _sync_grad_required flag will reset inside sync_grad() + + if non_parallel_module_reducer: + non_parallel_module_reducer.sync_grads() + optimizer.sync_shard_grad = types.MethodType(_sync_shard_grad, optimizer) def _register_reducer_pre_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): - for m in module.modules(): - if isinstance(m, ParallelModule): - for reducer in m.reducers: - reducer.register_pre_hook(partial(fn, reducer)) - else: - assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + for m in parallel_modules: + for reducer in m.reducers: + reducer.register_pre_hook(partial(fn, reducer)) + if non_parallel_module_reducer: + non_parallel_module_reducer.register_pre_hook(partial(fn, non_parallel_module_reducer)) def _register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): - for m in module.modules(): - if isinstance(m, ParallelModule): - for reducer in m.reducers: - reducer.register_post_hook(partial(fn, reducer)) - else: - assert not isinstance(m, CubeModule), "Only ParallelModule is supported in this mode" + for m in parallel_modules: + for reducer in m.reducers: + reducer.register_post_hook(partial(fn, reducer)) + if non_parallel_module_reducer: + non_parallel_module_reducer.register_post_hook(partial(fn, non_parallel_module_reducer)) optimizer.register_reducer_pre_hook = types.MethodType(_register_reducer_pre_hook, optimizer) optimizer.register_reducer_post_hook = types.MethodType(_register_reducer_post_hook, optimizer) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 499a387e..0ace43c8 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, TYPE_CHECKING import logging import os import sys @@ -12,6 +12,10 @@ from cube.runtime.adapter.reducer import Reducer from cube.runtime.gnorm import ParamsInfo +if TYPE_CHECKING: + from cube.parallel import ComputeConfig + + _logger = logging.getLogger(__name__) @@ -153,9 +157,9 @@ def get_full_map(self): def load_attr_content(self, filename: str): """Load module attribute (parameters and buffers) from file - + Args: - filename (str): base file name (without '.0', '.1', etc.) + filename (str): base file name (without '.0', '.1', etc.) that saved with model parameters """ npartitions = 0 @@ -165,7 +169,7 @@ def load_attr_content(self, filename: str): raise RuntimeError(f"Cannot find file {filename}.0 in load_attr_content") with torch.no_grad(): _logger.info(f'loading partitioned model from {filename}, number of model parameter chunks: {npartitions}') - # self._fullmap + # self._fullmap attr_names = set(self._fullmap.keys()) for file_idx in range(npartitions): # part_model contains a subset of attributes, where each attribute is a fulltensor @@ -518,7 +522,7 @@ def sync_grad(self): def get_dist_param_map(self): return self._dist_param_map - def get_compute_config(self): + def get_compute_config(self) -> 'ComputeConfig': return self._compute_config def get_rank(self): diff --git a/docs/parallel_module.md b/docs/parallel_module.md index 9ef7fe15..d68f4e53 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -4,6 +4,8 @@ Besides the support of end-to-end model training, Cube can also convert a `torch ## An example +- Example 1: Parallelize the whole module + ```python import torch from cube.parallel import parallelize, ComputeConfig, build_optimizer @@ -28,7 +30,55 @@ ParallelizedLLM = parallelize( pas_policy, compute_config, ) +``` + +- Example 2: Parallelize submodules. + +In this case, for non-paralle modules, they are replicated inside unit, and run data parallelism across units. See more details about unit in [Compute Config](###ComputeConfig) section. + +```python +import torch +from cube.parallel import parallelize, ComputeConfig, build_optimizer + +class HeavyModule(torch.nn.Module): + def __init__(self, ...): + ... + def forward(self, x): + ... + +class ParallelizedLLM(torch.nn.Module): + def __init__(self, ...): + ... + # use parallelize to convert submodules + heavy_module_sample_input = ... # dummpy input will be used to do tracing + pas_policy = ... # the PAS policy, you can use autodist pas + compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + ..., + ) # compute environment config + self.heavy_module = parallelize( + HeavyModule(), + {'x': heavy_module_sample_input}, + pas_policy, + compute_config, + ) + # you can add other submodules here + ... + + def forward(self, x, ...): + # call other submodules + ... + x = self.heavy_module(x) + ... + # call other submodules + return x +``` +For both example 1 & 2, you can train/infer that module in multiple GPUs/Nodes just like a normal `torch.nn.Module`: + +```python # do inference exactly the same way def infer(model: ParallelizedLLM, x): model.eval() @@ -239,6 +289,7 @@ We have `build_optimizer` to build an optimizer for distributed training. def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], + non_parallel_module_reducer_op: str = 'sum', *args, **kwargs, ) -> OptimizerT: @@ -247,7 +298,10 @@ It has the following parameters: - module (torch.nn.Module): the module to be optimized - optimizer_fn (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): It can be the optimizer class or optimizer factory function. -- *args: the args will pass to `optimizer_fn` +- non_parallel_module_reducer_op (str): the reducer op for non-parallel modules. Default is 'sum'. +- *args: the args will pass to `optimizer_fn`. + If you use `*args`, you must specify `non_parallel_module_reducer_op` explicitly even when you only need its default value. + So we suggest using `kwargs` instead of `args` to specify `optimizer_fn` arguments if possible. - **kwargs: the kwargs will pass to `optimizer_fn` To support distrubted training, in the function we need to hook 4 places: @@ -263,11 +317,13 @@ To support distrubted training, in the function we need to hook 4 places: 3. `optimizer.zero_grad()`: We need to call `CubeModule.zero_grad()` after `optimizer.zero_grad()` -`build_optimizer` will patch optimizer for you. Besides the above patches, we also add several utility functions to optimizer: +`build_optimizer` will patch optimizer for you. Besides the above patches, we also add several utility functions/variables to optimizer: 1. `sync_shard_grad`: Sync the shard gradients of the module from nodes with same shard to the optimizer. This function is called in optimizer's pre-step hook. But If you want to access the gradients before `optimizer.step()`(for example, you need gnorm), you need to call this function manually. -2. `register_reducer_pre_hook`, `register_reducer_post_hook`: Register pre/post hooks to reducers which will be applied before/after gradient synchronization. +2. `register_reducer_pre_hook`, `register_reducer_post_hook`: Register pre/post hooks to reducers which will be applied before/after gradient synchronization. These hooks will apply to all the reducers (including `_non_parallel_module_reducer`) in the optimizer. + +3. `_non_parallel_module_reducer`: The reducer for the modules which are not parallelized. It is used to sync the parameters in those modules across units. ### Dataset @@ -287,5 +343,4 @@ def create_distributed_sampler(dataset): ``` ## TODOs -1. When ParallelModule is a submodule of another Module, Pytorch DDP is not supported yet. -2. Pipeline parallelism is not supported yet. +1. Pipeline parallelism is not supported yet. diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index c5492eb2..4819e715 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -92,28 +92,18 @@ def PASData(graph: IRGraph, env_resource: ComputeConfig): for dl in graph.select(ntype=IRDataOperation): _replica(dl, list(range(ngpus))) - graph_inputs = IRSegment.get_objects_from_complex(graph.inputs()) - graph_outputs = IRSegment.get_objects_from_complex(graph.outputs()) for node in graph.nodes(): # print(node) if isinstance(node, IRFwOperation): - # Currently cube only support replicate if node's input or input is part of the graph output - # workaround for now - # will fix later. - if any(output in graph_outputs for output in node.outputs()) \ - or any(input in graph_outputs for input in node.inputs()): - # or any(input in graph_inputs for input in node.inputs()): + try: + algo = node.algorithms('dim') + idx = 0 + sub_nodes = graph.partition( + node, algo, idx=idx, dim=batch_dim, num=ngpus) + # except AssertionError: + except: + # print(f'WARNING: {node} cannot find dim algo, using replicate instead') sub_nodes = graph.replicate(node, ngpus) - else: - try: - algo = node.algorithms('dim') - idx = 0 - sub_nodes = graph.partition( - node, algo, idx=idx, dim=batch_dim, num=ngpus) - # except AssertionError: - except: - # print(f'WARNING: {node} cannot find dim algo, using replicate instead') - sub_nodes = graph.replicate(node, ngpus) for idx, node in enumerate(sub_nodes): graph.assign(node, idx) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py new file mode 100644 index 00000000..6b167881 --- /dev/null +++ b/tests/parallel_module/test_ddp.py @@ -0,0 +1,324 @@ +import tempfile +import itertools +import re +from pathlib import Path +import shutil +import pytest +from typing import Dict, Tuple, List +from dataclasses import dataclass + +import torch +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler + +import numpy as np + +from cube.parallel import ComputeConfig, parallelize, build_optimizer +from cube.runtime.module import ParallelModule + +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively + + +class FcRelu(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, out_features, bias=bias) + self.relu2 = nn.ReLU() + + + def forward(self, x): + return self.relu2(self.fc2(self.relu1(self.fc1(x)))) + + +class FcRelu_4_4(FcRelu): + def __init__(self): + super().__init__(4, 4) + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + +class OrigModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = FcRelu_4_4() + self.linear2 = nn.Linear(4, 4) + self.fc_relu2 = FcRelu_4_4() + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + + +def _create_torch_module(): + init_random() + return OrigModule().cuda() + + +def _create_cube_module(pas, compute_config, cube_savedir): + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu1') + self.linear2 = nn.Linear(4, 4) + self.fc_relu2 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu2') + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + compiled_module = CompiledModule().cuda() + return compiled_module + +DATA_SIZE = 32 + +@dataclass +class StepResult: + pred: torch.Tensor + loss: torch.Tensor + grads: Dict[str, torch.Tensor] + weights: Dict[str, torch.Tensor] + + +def _train_ddp(model, update_freq, num_replicas, rank): + from torch.nn.parallel import DistributedDataParallel as DDP + init_random() + + loss_fn = nn.BCELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + model = DDP(model, device_ids=[rank]) + + data = [] + UPDATE_FREQ = update_freq + init_random() + for _ in range(DATA_SIZE): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + data = [data[i] for i in range(rank, DATA_SIZE, num_replicas)] + results = [] + for i, (x, y) in enumerate(data): + model.train() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + if i % UPDATE_FREQ == UPDATE_FREQ - 1: + optimizer.step() + # remove leadding `module.` prefix + prefix_len = len('module.') + grads = {n[prefix_len:]: p.grad for n, p in model.named_parameters()} + results.append(clone_to_cpu_recursively([y_pred, loss, grads])) + optimizer.zero_grad() + weights = {n[prefix_len:]: p.data for n, p in model.named_parameters()} + results[-1].append(clone_to_cpu_recursively(weights)) + results[-1] = StepResult(*results[-1]) + return results + + +def _train(model, is_cube, update_freq, num_replicas, rank): + init_random() + + loss_fn = nn.BCELoss() + if is_cube: + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + data = [] + UPDATE_FREQ = update_freq + init_random() + for _ in range(DATA_SIZE): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + data = [data[i] for i in range(rank, DATA_SIZE, num_replicas)] + results = [] + for i, (x, y) in enumerate(data): + model.train() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + if i % UPDATE_FREQ == UPDATE_FREQ - 1: + optimizer.step() + grads = {n: p.grad for n, p in model.named_parameters()} + results.append(clone_to_cpu_recursively([y_pred, loss, grads])) + optimizer.zero_grad() + weights = {n: p.data for n, p in model.named_parameters()} + results[-1].append(clone_to_cpu_recursively(weights)) + results[-1] = StepResult(*results[-1]) + return results + + +def _gpu_worker_ga(update_freq): + init_distributed() + orig_module = _create_torch_module() + # update_freq *2 to simulate ddp = 2 + orig_results = _train(orig_module, False, update_freq*2, 1, 0) + return ( + orig_results, + ) + + +def _gpu_worker_ddp(update_freq): + init_distributed() + orig_module = _create_torch_module() + orig_results = _train_ddp(orig_module, update_freq, 2, torch.distributed.get_rank()) + return ( + orig_results, + ) + +def _gpu_worker_cube(pas, plan_ngpus, runtime_ngpus, update_freq): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + compiled_module = _create_cube_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus), tempdir) + compiled_results = _train( + compiled_module, True, update_freq, + runtime_ngpus // plan_ngpus, + torch.distributed.get_rank() // plan_ngpus + ) + return ( + compiled_results, + compiled_module.fc_relu1.get_full_map(), + compiled_module.fc_relu1.get_dist_param_map(), + compiled_module.fc_relu2.get_full_map(), + compiled_module.fc_relu2.get_dist_param_map(), + ) + +def _get_fc_weights(state_dict: dict, prefix): + result = {} + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith(prefix): + result[k[len(prefix):]] = v + else: + new_state_dict[k] = v + state_dict.clear() + state_dict.update(new_state_dict) + return result + + +def _compare_weights(orig0, compiled0, compiled1, fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map): + fc1_weights0 = _get_fc_weights(compiled0, 'fc_relu1.') + fc2_weights0 = _get_fc_weights(compiled0, 'fc_relu2.') + fc1_weights1 = _get_fc_weights(compiled1, 'fc_relu1.') + fc2_weights1 = _get_fc_weights(compiled1, 'fc_relu2.') + + cube_state_fc1 = [(fc1_weights0, {'state':{}}, fc1_dist_param_map[0], fc1_fullmap[0]), (fc1_weights1, {'state':{}}, fc1_dist_param_map[1], fc1_fullmap[1])] + cube_state_fc2 = [(fc2_weights0, {'state':{}}, fc2_dist_param_map[0], fc2_fullmap[0]), (fc2_weights1, {'state':{}}, fc2_dist_param_map[1], fc2_fullmap[1])] + merged_fc1, _ = ParallelModule.merge_partial_states(cube_state_fc1) + merged_fc1_fixed = {} + for k, v in merged_fc1.items(): + merged_fc1_fixed['fc_relu1.' + k] = v + merged_fc2, _ = ParallelModule.merge_partial_states(cube_state_fc2) + merged_fc2_fixed = {} + for k, v in merged_fc2.items(): + merged_fc2_fixed['fc_relu2.' + k] = v + assert len(merged_fc1_fixed) + len(merged_fc2_fixed) + len(compiled0) == len(orig0) + assert len(compiled1) == len(compiled0) + for k, v in compiled0.items(): + assert torch.allclose(compiled0[k], compiled1[k], rtol=1e-4, atol=1e-4) + for k, v in itertools.chain(merged_fc1_fixed.items(), merged_fc2_fixed.items(), compiled0.items()): + # print(f'key: {k}, max diff: {torch.max(torch.abs(orig0[k] - v))}') + assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('update_freq', [1, 2, 4]) +def test_tp_ddp(update_freq): + orig_results: Dict[int, tuple] = launch_torchrun(2, _gpu_worker_ddp, update_freq) + orig_results2: Dict[int, tuple] = launch_torchrun(1, _gpu_worker_ga, update_freq) + + # check equavalence of ddp and gradient accumulation + ddp_worker_result0, ddp_worker_result1 = orig_results[0], orig_results[1] + ddp_result0, ddp_result1 = ddp_worker_result0[0], ddp_worker_result1[0] + for i in range(len(ddp_result0)): + for k in ddp_result0[i].grads.keys(): # grad + ddp_result0[i].grads[k] += ddp_result1[i].grads[k] + for k in ddp_result0[i].weights.keys(): # weights + assert torch.equal(ddp_result0[i].weights[k], ddp_result1[i].weights[k]) + + ga_simulated_result0 = orig_results2[0][0] + assert len(ddp_result0) == len(ga_simulated_result0) + assert len(ddp_result1) == len(ga_simulated_result0) + for i in range(len(ddp_result0)): + a0, b = ddp_result0[i], ga_simulated_result0[i] + for k in b.grads.keys(): # grad + # print('grad: ', k, torch.max(torch.abs(a0[2][k] - b[2][k]))) + assert torch.allclose(a0.grads[k], b.grads[k], atol=1e-2, rtol=1e-2) # grad + for k in b.weights.keys(): # weights + # ddp will prefix `module.` to the key + # print('weight: ', k, torch.max(torch.abs(a0[3][k]- b[3][k]))) + assert torch.allclose(a0.weights[k], b.weights[k], atol=1e-2, rtol=1e-2) # weights + + cube_results = launch_torchrun(4, _gpu_worker_cube, PASRandomSPMD, 2, 4, update_freq) + worker_results0, worker_results1, worker_results2, worker_results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] + results0, results1, results2, results3 = worker_results0[0], worker_results1[0], worker_results2[0], worker_results3[0] + + fc1_fullmap = worker_results0[1], worker_results1[1] + assert fc1_fullmap == (worker_results2[1], worker_results3[1]) + fc1_dist_param_map = (worker_results0[2], worker_results1[2]) + assert fc1_dist_param_map == (worker_results2[2], worker_results3[2]) + + fc2_fullmap = worker_results0[3], worker_results1[3] + assert fc2_fullmap == (worker_results2[3], worker_results3[3]) + fc2_dist_param_map = worker_results0[4],worker_results1[4] + assert fc2_dist_param_map == (worker_results2[4], worker_results3[4]) + + # pred, loss + for r0, r1 in [(results0, results1), (results2, results3)]: + # have the same input + assert len(r0) == len(r1) # iteration count + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.pred, b.pred) # pred + assert torch.equal(a.loss, b.loss) # loss + + # grad, weights + for r0, r1 in [(results0, results2), (results1, results3)]: + # in the same shard, grads and weights are the same + assert len(r0) == len(r1) + for i in range(len(r0)): + a, b = r0[i], r1[i] + for k in a.grads.keys(): # grad + assert torch.equal(a.grads[k], b.grads[k]) + for k in a.weights.keys(): # weights + assert torch.equal(a.weights[k], b.weights[k]) + + assert len(ga_simulated_result0) == len(results0) + for i in range(len(ddp_result0)): + print('iteration: ', i) + orig0, compiled0, compiled1 = ga_simulated_result0[i], results0[i], results1[i] + + print('grads') + # grad + _compare_weights(orig0.grads, compiled0.grads, compiled1.grads, fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map) + + print('weights') + # weights + _compare_weights(orig0.weights, compiled0.weights, compiled1.weights, fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map) diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index 27d4a688..94e2bdd0 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -78,6 +78,8 @@ def post_hook(reducer, grad): for m in model.modules(): if isinstance(m, ParallelModule): reducers.extend(m.reducers) + if optimizer._non_parallel_module_reducer: + reducers.append(optimizer._non_parallel_module_reducer) if not reducers: print('No reducer found, skip test_hook') @@ -110,10 +112,10 @@ def post_hook(reducer, grad): return results -def _gpu_worker(pas, ngpus): +def _gpu_worker(pas, plan_ngpus, runtime_ngpus=None): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_hook') as tempdir: - compiled_module = _create_module(pas, ComputeConfig(ngpus, ngpus), tempdir) + compiled_module = _create_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus or plan_ngpus), tempdir) _train(compiled_module) @@ -127,6 +129,11 @@ def test_hook_tp_gpu2(): launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_hook_tp_gpu4(): + launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_hook_dp_gpu1(): launch_torchrun(1, _gpu_worker, PASData, 1) diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 1ac54ea1..aaccb8e1 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -80,11 +80,14 @@ def forward(self, x): return orig_module, compiled_module -def _train(model, update_freq): +def _train(model, update_freq, is_cube): init_random() loss_fn = nn.BCELoss() - optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + if is_cube: + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) data = [] DATA_SIZE = 20 UPDATE_FREQ = update_freq @@ -113,8 +116,8 @@ def _gpu_worker(pas, ngpus, update_freq): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) - orig_results = _train(orig_module, update_freq) - compiled_results = _train(compiled_module, update_freq) + orig_results = _train(orig_module, update_freq, False) + compiled_results = _train(compiled_module, update_freq, True) return ( orig_results, compiled_results, diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 4f787cfb..7309ff35 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -68,11 +68,14 @@ def forward(self, x): return orig_module, compiled_module -def _train(model): +def _train(model, is_cube): init_random() loss_fn = nn.BCELoss() - optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + if is_cube: + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) data = [] DATA_SIZE = 20 UPDATE_FREQ = 1 # TODO: update_freq support @@ -101,8 +104,8 @@ def _gpu_worker(pas, ngpus): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) - orig_results = _train(orig_module) - compiled_results = _train(compiled_module) + orig_results = _train(orig_module, False) + compiled_results = _train(compiled_module, True) return ( orig_results, compiled_results, From d9cce06229e529dc853bdf4ff57f48c7078b70b1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 26 Jan 2024 01:50:59 +0000 Subject: [PATCH 1586/1892] Merged PR 2031: remove unused IRFwOperation.infer_dtype It is not used, and the logic (infer output dtype from input dtype) looks not reasonable. unit test pass parity check pass --- cube/graph/function/anchor.py | 3 --- cube/graph/function/dimops.py | 2 +- cube/graph/function/function.py | 41 ++++++++++----------------------- cube/ir/operator.py | 26 ++++----------------- 4 files changed, 18 insertions(+), 54 deletions(-) diff --git a/cube/graph/function/anchor.py b/cube/graph/function/anchor.py index acaeca14..4a999f24 100644 --- a/cube/graph/function/anchor.py +++ b/cube/graph/function/anchor.py @@ -45,9 +45,6 @@ def __init__(self, signature: str, name: str): super().__init__(name, signature, [], 1) self.kwargs['name'] = name self.set_output(0, IRObject('anchor', value=None)) - - def infer_dtype(self): - return def infer_shape(self): return True diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 1744f3b2..005153a1 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -664,7 +664,7 @@ def oanno(self, index: int) -> Tuple[DimAnno]: def infer_shape(self) -> bool: """ - Shape and dtype inference using the matched annotation and tensor. + Shape inference using the matched annotation and tensor. @return sucess: True if successfully inferred shape """ diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 33217777..9873ca0d 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -180,9 +180,7 @@ def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Uni size = (math.ceil((end_val-start_val)/step_val),) anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), False) - dimop = IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) - dimop.output(0).parent.dtype = dtype - return dimop + return IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) def Arange(*args, out=None, dtype=None, layout=None, @@ -215,9 +213,7 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) - dimop = IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) - dimop.output(0).parent.dtype = dtype - return dimop + return IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) def Zeros(size, *arg_size, out=None, dtype=None, layout=None, @@ -232,9 +228,7 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) - dimop = IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) - dimop.output(0).parent.dtype = dtype - return dimop + return IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) def Ones(size, *arg_size, out=None, dtype=None, layout=None, @@ -249,9 +243,7 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) - dimop = IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) - dimop.output(0).parent.dtype = dtype - return dimop + return IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, @@ -267,9 +259,7 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) - dimop = IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) - dimop.output(0).parent.dtype = dtype - return dimop + return IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) def Full(size, fill_value, *, out=None, dtype=None, layout=None, @@ -284,10 +274,8 @@ def Full(size, fill_value, *, out=None, dtype=None, layout=None, size = tuple(size) if size else (1,) anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) - dimop = IRDimops(Full, 'full', signature, [anno], [], rules, + return IRDimops(Full, 'full', signature, [anno], [], rules, size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad) - dimop.output(0).parent.dtype = dtype - return dimop def NewTensor(data, *, dtype=None, device=None, @@ -311,12 +299,11 @@ def NewTensor(data, *, dtype=None, device=None, size = val.shape val = val.tolist() size = size if len(size) > 0 else (1,) # for scalar - + kwargs = {'data': val, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(size, False) - dimop = IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) - return dimop + return IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: @@ -823,8 +810,8 @@ def Where(condition, input, other, *, out=None, signature = None): edim_out = copy.copy(edim_in0) annos = [OpAnno.create_op_str([edim_in0, edim_in1, edim_in2], [edim_out])] - dimop = IRDimops(Where, 'where', signature, annos, [condition, input, other]) - return dimop + return IRDimops(Where, 'where', signature, annos, [condition, input, other]) + def CubeLayerNorm(input, weight=None, bias=None, normalized_shape=None, eps=1e-05, signature = None): """ @@ -1784,15 +1771,11 @@ def _comparison(creator: Callable, f: Callable, name: str, signature: str, if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - dimop = IRDimops(creator, name, signature, annos, [input, other]) - dimop.output(0).parent.dtype = torch.bool - return dimop + return IRDimops(creator, name, signature, annos, [input, other]) # case2: torch.equal(tensor1, obj2) / torch.equal(obj1, tensor2) if isinstance(input, IRTensor) or isinstance(other, IRTensor): annos = ['*, ? -> *', '?, * -> *',] - dimop = IRDimops(creator, name, signature, annos, [input, other]) - dimop.output(0).parent.dtype = torch.bool - return dimop + return IRDimops(creator, name, signature, annos, [input, other]) # case3: torch.equal(obj1, obj2) else: return IRPyFunc(signature, [input, other], [IRObject()]) diff --git a/cube/ir/operator.py b/cube/ir/operator.py index 9d61bfec..e783925d 100644 --- a/cube/ir/operator.py +++ b/cube/ir/operator.py @@ -30,35 +30,19 @@ def __init__(self, name: str, signature: str, # setup input for idx, input in enumerate(inputs): self.set_input(idx, input) - + # additional argument self.kwargs.update(kwargs) # default infer rule requires_grad = any( t.requires_grad for t in inputs if isinstance(t, IRTensor)) - + # setup output outputs = [IRFullTensor(requires_grad=requires_grad) for _ in range(num_outputs)] for idx, output in enumerate(outputs): self.set_output(idx, output) - def infer_dtype(self): - """ - Infer output value dtype. - By default will follow the same dtype promotion rule with PyTorch. - """ - itensors = [t for t in self.inputs() if isinstance(t, IRTensor)] - otensors = [t for t in self.outputs() if isinstance(t, IRTensor)] - odtype = DTypeInfo.promote([t.dtype for t in itensors]) - for tensor in otensors: - # in case of setting manually due to special rules - if tensor.dtype is None: - if isinstance(tensor, IRFullTensor): - tensor.dtype = odtype - else: - tensor.parent.dtype = odtype - def infer_shape(self): """ Infer output value shape @@ -155,7 +139,7 @@ class IRBpOperation(IRCell): def __init__(self, ograds: Tuple[Any], igrads: Tuple[Any]): """ Create dummy backward node for forward inputs and forward outputs - + @param fwop IRFwOperation: forward operator """ super().__init__( @@ -196,7 +180,7 @@ def __repr__(self) -> str: class IRDataOperation(IRCell): """Dataloader operator - + The output of a dataloader operator is a tuple of (IRObject,). """ @@ -233,7 +217,7 @@ def infer_shape(self): Infer output value shape """ return True - + def __repr__(self): dscp = (f"DataLoader{self._id}-{self.device}(outputs={self.outputs()})") return dscp From 5e68f30c03ef911ef06ee2a928501189a7a5b756 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 30 Jan 2024 03:25:19 +0000 Subject: [PATCH 1587/1892] Merged PR 2032: parallel module: add gnorm support unit test pass --- cube/parallel.py | 40 +++++++- cube/runtime/gnorm.py | 149 +++++++++++++++++++++++++++++- cube/runtime/module.py | 27 +++++- tests/parallel_module/test_ddp.py | 26 ++++-- 4 files changed, 230 insertions(+), 12 deletions(-) diff --git a/cube/parallel.py b/cube/parallel.py index c93c7d39..59381da9 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -1,7 +1,7 @@ from enum import Enum from functools import partial import types -from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar +from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar, List from pathlib import Path import inspect import sys @@ -34,6 +34,7 @@ from cube.runtime.adapter.reducer import Reducer from cube.runtime.module import CubeModule, ParallelModule from cube.runtime.device import DeviceGroup +from cube.runtime.gnorm import calcuate_gnorm, clip_grads logger = logging.getLogger(__name__) @@ -714,6 +715,18 @@ def sync_shard_grad(self): """ ... + def clip_gnorm(self, max_norm: Optional[float] = None) -> torch.Tensor: + """ + Clip the gradients with global norm, and return the global gnorm value. + + Args: + max_norm (Optional[float]): the max global norm. If it is None, no clipping will be applied. + + Returns: + torch.Tensor: the gradient norm. + """ + ... + def register_reducer_pre_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): """ Register pre hooks to reducers which will be applied before gradient synchronization. @@ -857,6 +870,31 @@ def _sync_shard_grad(self): optimizer.sync_shard_grad = types.MethodType(_sync_shard_grad, optimizer) + @torch.no_grad() + def _clip_gnorm(self, max_norm: Optional[float] = None): + self.sync_shard_grad() + total_norm_squared = 0.0 + grads: List[torch.Tensor] = [] + + for m in parallel_modules: + mnorm, mgrads = m.clip_gnorm(None) + total_norm_squared += torch.square(mnorm) + grads.extend(mgrads) + + if non_parallel_module_reducer: + params = non_parallel_module_reducer.parameters_for_optimizer() + mnorm, mgrads = calcuate_gnorm(params) + total_norm_squared += torch.square(mnorm) + grads.extend(mgrads) + + total_norm = torch.sqrt(total_norm_squared) + if max_norm is not None and max_norm > 0: + clip_grads(grads, total_norm, max_norm) + + return total_norm + + optimizer.clip_gnorm = types.MethodType(_clip_gnorm, optimizer) + def _register_reducer_pre_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): for m in parallel_modules: for reducer in m.reducers: diff --git a/cube/runtime/gnorm.py b/cube/runtime/gnorm.py index 73e1e6ab..05dc0bb3 100644 --- a/cube/runtime/gnorm.py +++ b/cube/runtime/gnorm.py @@ -1,13 +1,21 @@ -from typing import List, Dict, Tuple, TYPE_CHECKING +from typing import List, Dict, Tuple, Optional, TYPE_CHECKING from dataclasses import dataclass from collections import defaultdict import torch import torch.distributed as dist +try: + from amp_C import multi_tensor_l2norm + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + + if TYPE_CHECKING: from cube.runtime.module import CubeModule + @dataclass class ParamsInfo: # An instance of ParamsInfo corresponds to a group of parameters in cube reducer, @@ -17,6 +25,7 @@ class ParamsInfo: param_names: List[str] zero_ngroups: int + @dataclass class TidReplicaInfo: # the number of the replicas of the (partitioned) parameter with tid @@ -24,6 +33,7 @@ class TidReplicaInfo: # the number of all the involved ranks for this parameter with tid nranks: int + def _calc_grad_shape(slicers_list): # caculate the shape of each full parameters/grads tid2shape = {} @@ -47,6 +57,7 @@ def _calc_grad_shape(slicers_list): tid2nreplicas[tid] += factor return tid2nreplicas + def prepare_for_grad_clip_legacy(cube_model: 'CubeModule', curr_rank: int) -> Dict[int, List[torch.nn.Parameter]]: assert curr_rank == dist.get_rank() tid2param, tid2slicers = {}, {} @@ -67,12 +78,14 @@ def prepare_for_grad_clip_legacy(cube_model: 'CubeModule', curr_rank: int) -> Di nreplicas2localparams[nreplicas].append(param) return nreplicas2localparams + def _check_is_ordered(ranks: Tuple[int]) -> bool: for i in range(len(ranks)-1): if ranks[i] >= ranks[i+1]: return False return True + def _check_no_intersection(ranks_set): # ranks_set: set of tuple # check intersection between any two tuples @@ -84,6 +97,7 @@ def _check_no_intersection(ranks_set): return False return True + def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int, TidReplicaInfo]: """This function is used to calculate the number of replicas of each model parameter. Each parameter has a tuple of `len(ranksset)` (we call it nreplicated) and `nranks`, @@ -115,6 +129,7 @@ def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int tid2nreplicas[tid] = TidReplicaInfo(len(ranksset), nranks) return tid2nreplicas + def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, List[torch.nn.Parameter]]: params_info_for_gnorm = cube_model.parameters_for_calc_gnorm() tid2ranks = {} @@ -160,3 +175,135 @@ def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, nreplicas2localparams[nreplicas].extend(params_info.params) processed_seqs[seq] = replicated_info return nreplicas2localparams + + +def _multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: + """ + Returns: + Total norm of the input tensors in float32. + """ + per_device_grads = {} + norms = [] + for grad in grads: + device = grad.device + cur_device_grads = per_device_grads.get(device) + if cur_device_grads is None: + cur_device_grads = [] + per_device_grads[device] = cur_device_grads + cur_device_grads.append(grad) + for device in per_device_grads.keys(): + cur_device_grads = per_device_grads[device] + if device.type == "cuda": + # TODO(msb) return has_inf + has_inf = torch.zeros((1, 1), dtype=torch.int, device=device) + with torch.cuda.device(device): + norm = multi_tensor_l2norm( + chunk_size, has_inf, [cur_device_grads], False + ) + norms.append(norm[0].to(torch.cuda.current_device())) + else: + assert False, 'non cuda device is not supported.' + norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads] + assert len(norms) == 1 + total_norm = torch.norm(torch.stack(norms)) + return total_norm + + +@torch.no_grad() +def calcuate_gnorm(params: List[torch.Tensor], device: Optional[torch.device] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Calculate the gradient norm of the given parameters. + + Args: + params (List[torch.Tensor],): list of parameters + device (Optional[torch.device]): device to calculate the gradient norm. Default is the device of the first parameter + + Returns: + Tuple[torch.Tensor, List[torch.Tensor]]: Tuple of the gradient norm and the list of gradients. + """ + def grad_exists(p): + return p is not None and getattr(p, "grad", None) is not None + if device is None: + # assume all weights are on the same device + device = params[0].device + params = list(filter(grad_exists, params)) + grads = [] + for p in params: + grads.append(p.grad.detach()) + if len(grads) == 0: + total_norm = torch.tensor(0.0, dtype=torch.float32, device=device) # alway use float32 + elif len(grads) == 1: + total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) + else: + if multi_tensor_l2norm_available: + total_norm = _multi_tensor_total_norm(grads).to(device) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads] + ) + ) + + return total_norm, grads + + +@torch.no_grad() +def clip_grads(grads: List[torch.Tensor], gnorm, max_norm: float) -> None: + """ + Clip gradients. + + Args: + grads: list of gradients + gnorm: the norm of all the gradients (maybe in different devices) + max_norm: max norm value + + Returns: + None + """ + max_norm = float(max_norm) + clip_coef = (max_norm / (gnorm + 1e-6)).clamp_(max=1) + for g in grads: + g.mul_(clip_coef) + + +@torch.no_grad() +def clip_gnorm( + nreplicas2localparams: Dict[int, List[torch.Tensor]], + max_norm: Optional[float] = None +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Calculate gnorm and clip gradients + + Args: + nreplicas2localparams: a dict mapping from number_of_replicas to a list of local params. + For example, nreplicas2localparams[2] contains all the parameters that have replicated 2 times. + max_norm: max norm value. If None or <= 0, no clipping will be performed. + + Returns: + Tuple[torch.Tensor, List[torch.Tensor]]: Tuple of The gradient norm and the list of gradients. + """ + # assume all weights are on the same device + for localparams in nreplicas2localparams.values(): + if len(localparams) == 0: + continue + device = localparams[0].device + break + else: + raise RuntimeError('no parameters found') + + total_grad_square = torch.tensor(0.0, dtype=torch.float64, device=device) + grads = [] + for nreplicas, localparams in nreplicas2localparams.items(): + if len(localparams) == 0: + continue + # compute gnorm + local_gnorm, local_grads = calcuate_gnorm(localparams, device) + total_grad_square += local_gnorm.to(dtype=torch.float64).pow_(2).div_(nreplicas) + grads.extend(local_grads) + dist.all_reduce(total_grad_square) + total_norm = total_grad_square.sqrt_().to(torch.float32) + + if max_norm is not None and max_norm > 0: + clip_grads(grads, total_norm, max_norm) + + return total_norm, grads diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 0ace43c8..64972c4a 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Tuple, TYPE_CHECKING +from typing import List, Dict, Tuple, Optional, TYPE_CHECKING import logging import os import sys @@ -475,6 +475,12 @@ def __init__(self): super().__init__() # this is used to allow multiple sync_grad() calls self._sync_grad_required = False + # save the param replicas info for calculating gradient norm + # it is a dict mapping from number_of_replicas to a list of local params. + # For example, _nreplicas2localparams[2] contains all the parameters that have replicated 2 times. + # this is a lazy initialization, + # which will be initialized in the first call of `clip_gnorm` + self._nreplicas2localparams: Optional[Dict[int, List[torch.nn.Parameter]]] = None def _post_init(self, init_params=True): """ @@ -527,3 +533,22 @@ def get_compute_config(self) -> 'ComputeConfig': def get_rank(self): return self.rank # rank is a class varible defined in gencode + + def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Calculate the gradient norm and clip gradients. + + Args: + max_norm (Optional[float]): max norm value. If None or <= 0, no clipping will be performed. + + Returns: + Tuple of The gradient norm and the list of gradients. + """ + from cube.runtime.gnorm import prepare_for_grad_clip, clip_gnorm + if self._nreplicas2localparams is None: + self._nreplicas2localparams = prepare_for_grad_clip(self, self.get_compute_config().use_zero) + + # make sure the gradients are synchronized + self.sync_grad() + + return clip_gnorm(self._nreplicas2localparams, max_norm) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index 6b167881..62846e03 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -16,6 +16,7 @@ from cube.parallel import ComputeConfig, parallelize, build_optimizer from cube.runtime.module import ParallelModule +from cube.runtime.gnorm import calcuate_gnorm from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively @@ -102,6 +103,7 @@ class StepResult: pred: torch.Tensor loss: torch.Tensor grads: Dict[str, torch.Tensor] + gnorm: torch.Tensor weights: Dict[str, torch.Tensor] @@ -133,7 +135,8 @@ def _train_ddp(model, update_freq, num_replicas, rank): # remove leadding `module.` prefix prefix_len = len('module.') grads = {n[prefix_len:]: p.grad for n, p in model.named_parameters()} - results.append(clone_to_cpu_recursively([y_pred, loss, grads])) + gnorm = calcuate_gnorm(list(model.parameters()))[0] + results.append(clone_to_cpu_recursively([y_pred, loss, grads, gnorm])) optimizer.zero_grad() weights = {n[prefix_len:]: p.data for n, p in model.named_parameters()} results[-1].append(clone_to_cpu_recursively(weights)) @@ -167,7 +170,11 @@ def _train(model, is_cube, update_freq, num_replicas, rank): if i % UPDATE_FREQ == UPDATE_FREQ - 1: optimizer.step() grads = {n: p.grad for n, p in model.named_parameters()} - results.append(clone_to_cpu_recursively([y_pred, loss, grads])) + if is_cube: + gnorm = optimizer.clip_gnorm() + else: + gnorm = calcuate_gnorm(list(model.parameters()))[0] + results.append(clone_to_cpu_recursively([y_pred, loss, grads, gnorm])) optimizer.zero_grad() weights = {n: p.data for n, p in model.named_parameters()} results[-1].append(clone_to_cpu_recursively(weights)) @@ -263,7 +270,7 @@ def test_tp_ddp(update_freq): for k in ddp_result0[i].weights.keys(): # weights assert torch.equal(ddp_result0[i].weights[k], ddp_result1[i].weights[k]) - ga_simulated_result0 = orig_results2[0][0] + ga_simulated_result0: List[StepResult] = orig_results2[0][0] assert len(ddp_result0) == len(ga_simulated_result0) assert len(ddp_result1) == len(ga_simulated_result0) for i in range(len(ddp_result0)): @@ -278,7 +285,10 @@ def test_tp_ddp(update_freq): cube_results = launch_torchrun(4, _gpu_worker_cube, PASRandomSPMD, 2, 4, update_freq) worker_results0, worker_results1, worker_results2, worker_results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] - results0, results1, results2, results3 = worker_results0[0], worker_results1[0], worker_results2[0], worker_results3[0] + results0: List[StepResult] = worker_results0[0] + results1: List[StepResult] = worker_results1[0] + results2: List[StepResult] = worker_results2[0] + results3: List[StepResult] = worker_results3[0] fc1_fullmap = worker_results0[1], worker_results1[1] assert fc1_fullmap == (worker_results2[1], worker_results3[1]) @@ -298,6 +308,7 @@ def test_tp_ddp(update_freq): a, b = r0[i], r1[i] assert torch.equal(a.pred, b.pred) # pred assert torch.equal(a.loss, b.loss) # loss + assert torch.equal(a.gnorm, b.gnorm) # gnorm # grad, weights for r0, r1 in [(results0, results2), (results1, results3)]: @@ -305,6 +316,7 @@ def test_tp_ddp(update_freq): assert len(r0) == len(r1) for i in range(len(r0)): a, b = r0[i], r1[i] + assert torch.equal(a.gnorm, b.gnorm) # gnorm for k in a.grads.keys(): # grad assert torch.equal(a.grads[k], b.grads[k]) for k in a.weights.keys(): # weights @@ -312,13 +324,9 @@ def test_tp_ddp(update_freq): assert len(ga_simulated_result0) == len(results0) for i in range(len(ddp_result0)): - print('iteration: ', i) orig0, compiled0, compiled1 = ga_simulated_result0[i], results0[i], results1[i] - - print('grads') + assert torch.allclose(orig0.gnorm, compiled0.gnorm, atol=1e-6, rtol=1e-6) # gnorm # grad _compare_weights(orig0.grads, compiled0.grads, compiled1.grads, fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map) - - print('weights') # weights _compare_weights(orig0.weights, compiled0.weights, compiled1.weights, fc1_fullmap, fc2_fullmap, fc1_dist_param_map, fc2_dist_param_map) From 77334f194e7ad733a7940e6d712f9c67d81c716e Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 30 Jan 2024 08:44:17 +0000 Subject: [PATCH 1588/1892] Merged PR 2035: parallel module: add zero support parallel module: add zero support --- cube/parallel.py | 6 +++++- tests/parallel_module/test_ddp.py | 35 ++++++++++++++++++++++++++----- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/cube/parallel.py b/cube/parallel.py index 59381da9..9e1b58aa 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -830,7 +830,11 @@ def build_optimizer( non_parallel_module_reducer.build_buckets() def _local_parameters(module: torch.nn.Module): - gen = module._named_members(lambda m: m._parameters.items()) + gen = module._named_members( + lambda m: [(str(id(p)), p) for p in m.parameters_for_optimizer()] # (str(id(p)), p) to meet _named_members requirement + if isinstance(m, ParallelModule) + else m._parameters.items() + ) for _, param in gen: yield param diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index 62846e03..c019d92a 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -200,10 +200,13 @@ def _gpu_worker_ddp(update_freq): orig_results, ) -def _gpu_worker_cube(pas, plan_ngpus, runtime_ngpus, update_freq): +def _gpu_worker_cube(pas, plan_ngpus, runtime_ngpus, update_freq, use_zero): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: - compiled_module = _create_cube_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus), tempdir) + compiled_module = _create_cube_module(pas, + ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero), + tempdir + ) compiled_results = _train( compiled_module, True, update_freq, runtime_ngpus // plan_ngpus, @@ -283,12 +286,18 @@ def test_tp_ddp(update_freq): # print('weight: ', k, torch.max(torch.abs(a0[3][k]- b[3][k]))) assert torch.allclose(a0.weights[k], b.weights[k], atol=1e-2, rtol=1e-2) # weights - cube_results = launch_torchrun(4, _gpu_worker_cube, PASRandomSPMD, 2, 4, update_freq) + cube_results = launch_torchrun(4, _gpu_worker_cube, PASRandomSPMD, 2, 4, update_freq, False) + zcube_results = launch_torchrun(4, _gpu_worker_cube, PASRandomSPMD, 2, 4, update_freq, True) worker_results0, worker_results1, worker_results2, worker_results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] results0: List[StepResult] = worker_results0[0] results1: List[StepResult] = worker_results1[0] results2: List[StepResult] = worker_results2[0] results3: List[StepResult] = worker_results3[0] + zworker_results0, zworker_results1, zworker_results2, zworker_results3 = zcube_results[0], zcube_results[1], zcube_results[2], zcube_results[3] + zresults0: List[StepResult] = zworker_results0[0] + zresults1: List[StepResult] = zworker_results1[0] + zresults2: List[StepResult] = zworker_results2[0] + zresults3: List[StepResult] = zworker_results3[0] fc1_fullmap = worker_results0[1], worker_results1[1] assert fc1_fullmap == (worker_results2[1], worker_results3[1]) @@ -300,8 +309,21 @@ def test_tp_ddp(update_freq): fc2_dist_param_map = worker_results0[4],worker_results1[4] assert fc2_dist_param_map == (worker_results2[4], worker_results3[4]) + fc1_fullmap = zworker_results0[1], zworker_results1[1] + assert fc1_fullmap == (zworker_results2[1], zworker_results3[1]) + fc1_dist_param_map = (zworker_results0[2], zworker_results1[2]) + assert fc1_dist_param_map == (zworker_results2[2], zworker_results3[2]) + + fc2_fullmap = zworker_results0[3], zworker_results1[3] + assert fc2_fullmap == (zworker_results2[3], zworker_results3[3]) + fc2_dist_param_map = zworker_results0[4], zworker_results1[4] + assert fc2_dist_param_map == (zworker_results2[4], zworker_results3[4]) + # pred, loss - for r0, r1 in [(results0, results1), (results2, results3)]: + for r0, r1 in [(results0, results1), (results2, results3), + (zresults0, zresults1), (zresults2, zresults3), + (results0, zresults0), (results2, zresults2) + ]: # have the same input assert len(r0) == len(r1) # iteration count for i in range(len(r0)): @@ -311,7 +333,10 @@ def test_tp_ddp(update_freq): assert torch.equal(a.gnorm, b.gnorm) # gnorm # grad, weights - for r0, r1 in [(results0, results2), (results1, results3)]: + for r0, r1 in [(results0, results2), (results1, results3), + (zresults0, zresults2), (zresults1, zresults3), + (results0, zresults0), (results1, zresults1) + ]: # in the same shard, grads and weights are the same assert len(r0) == len(r1) for i in range(len(r0)): From 50a96f925ab4d39926ad96e2be168b1842ee5b32 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 1 Feb 2024 04:57:10 +0000 Subject: [PATCH 1589/1892] Merged PR 2034: keep IRObject contains input & IRTensor during parsing --- cube/graph/function/function.py | 88 +++++++++++++++------- cube/graph/parser/fx/mapping.py | 2 + cube/graph/parser/fx/parser.py | 26 ++++--- cube/ir/cten.py | 15 +++- tests/graph/parser/test_ir_obj_constant.py | 38 ++++++++++ 5 files changed, 127 insertions(+), 42 deletions(-) create mode 100644 tests/graph/parser/test_ir_obj_constant.py diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 9873ca0d..1a23b89b 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -45,6 +45,27 @@ def is_list_or_tuple(v: Any) -> bool: ) +def ir_object_recursive(obj: IRObject, fn: Callable): + """recursive on obj.value / dict / list / tuple""" + if isinstance(obj, dict): + return any(ir_object_recursive(v, fn) for v in obj.values()) + elif isinstance(obj, (list, tuple)): + return any(ir_object_recursive(v, fn) for v in obj) + elif isinstance(obj, IRObject): + if fn(obj): + return True + elif obj.value is not None: + return ir_object_recursive(obj.value, fn) + else: + return False + else: + return False + + +def ir_object_contains_dynamic(obj: IRObject): + return ir_object_recursive(obj, lambda a: not a.is_constant) + + def Identity(tensor: IRObject, signature = None): signature = 'cube.runtime.function.identity' eshape = ShapeAnno.create_shape_str(tensor.shape) @@ -459,14 +480,31 @@ def BitwiseNot(input, *, out=None, signature=None): return IRDimops(BitwiseNot, 'bitwise_not', signature, annos, [input]) +def _unwrap_value(obj: IRObject): + if isinstance(obj, IRObject): + return _unwrap_value(obj.value) + else: + return obj + + +def _compute_unary_op(input, fn, name): + out_val = fn(_unwrap_value(input)) + contains_dynamic_val = ir_object_contains_dynamic(input) + return IRObject(name=name, value=out_val, is_constant=not contains_dynamic_val) + + +def _compute_binary_op(input, other, fn, name): + out_val = fn(_unwrap_value(input), _unwrap_value(other)) + contains_dynamic_val = ir_object_contains_dynamic(input) or ir_object_contains_dynamic(other) + return IRObject(name=name, value=out_val, is_constant=not contains_dynamic_val) + + def Add(input, other, alpha=1, *, out=None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input + alpha * other if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): - iv = input.value if isinstance(input, IRObject) else input - ov = other.value if isinstance(other, IRObject) else other - return IRPyFunc(signature, [input, other], [IRObject(name='add', value=iv+ov)]) + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, operator.add, 'add')]) signature = 'torch.add' annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): @@ -481,9 +519,7 @@ def Sub(input, other, alpha=1, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input - alpha * other if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): - iv = input.value if isinstance(input, IRObject) else input - ov = other.value if isinstance(other, IRObject) else other - return IRPyFunc(signature, [input, other], [IRObject(name='sub', value=iv-ov)]) + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, operator.sub, 'sub')]) annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) @@ -496,9 +532,7 @@ def Mul(input, other, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input * other if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): - iv = input.value if isinstance(input, IRObject) else input - ov = other.value if isinstance(other, IRObject) else other - return IRPyFunc(signature, [input, other], [IRObject(name='mul', value=iv*ov)]) + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, operator.mul, 'mul')]) signature = 'torch.mul' annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): @@ -511,6 +545,8 @@ def Mod(input, other, *, out = None, signature = None): assert out is None if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input % other + if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, operator.mod, 'mod')]) signature = 'torch.fmod' annos = ['*, ? -> *'] if isinstance(input, IRTensor) and isinstance(other, IRTensor): @@ -524,9 +560,7 @@ def Div(input, other, *, rounding_mode=None, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input / other if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): - iv = input.value if isinstance(input, IRObject) else input - ov = other.value if isinstance(other, IRObject) else other - return IRPyFunc(signature, [input, other], [IRObject(name='div', value=iv/ov)]) + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, operator.truediv, 'div')]) signature = 'torch.div' annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): @@ -541,10 +575,10 @@ def Exp(input, *, out=None, signature=None): """ assert out is None if not isinstance(input, IRObject): - return torch.exp(input) + return torch.exp(input) if isinstance(input, torch.Tensor) else math.exp(input) if not isinstance(input, IRTensor): assert input.value is not None - return IRPyFunc(signature, [input], [IRObject(name='exp', value=torch.exp(input.value))]) + return IRPyFunc(signature, [input], [_compute_unary_op(input, math.exp, 'exp')]) shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] return IRDimops(Exp, 'exp', signature, annos, [input]) @@ -556,10 +590,9 @@ def Sqrt(input, *, out=None, signature=None): """ assert out is None if not isinstance(input, IRObject): - return torch.sqrt(input) + return torch.sqrt(input) if isinstance(input, torch.Tensor) else math.sqrt(input) if not isinstance(input, IRTensor): - iv = input.value if isinstance(input, IRObject) else input - return IRPyFunc(signature, [input], [IRObject(name='sqrt', value=torch.sqrt(iv))]) + return IRPyFunc(signature, [input], [_compute_unary_op(input, math.sqrt, 'sqrt')]) shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] return IRDimops(Sqrt, 'sqrt', signature, annos, [input]) @@ -570,8 +603,8 @@ def RSqrt(input, *, out=None, signature=None): if not isinstance(input, IRObject): return torch.rsqrt(input) if not isinstance(input, IRTensor): - iv = input.value if isinstance(input, IRObject) else input - return IRPyFunc(signature, [input], [IRObject(name='rsqrt', value=torch.rsqrt(iv))]) + # NOTE: can not find a common library implementation of rsqrt for non-tensor + return IRPyFunc(signature, [input], [_compute_unary_op(input, lambda a: 1 / math.sqrt(a), 'rsqrt')]) shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] return IRDimops(RSqrt, 'rsqrt', signature, annos, [input]) @@ -582,9 +615,7 @@ def FloorDiv(input, other, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input // other if (not isinstance(input, IRTensor)) and (not isinstance(other, IRTensor)): - iv = input.value if isinstance(input, IRObject) else input - ov = other.value if isinstance(other, IRObject) else other - return IRPyFunc(signature, [input, other], [IRObject(name='fdiv', value=iv//ov)]) + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, operator.floordiv, 'fdiv')]) annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(other, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, other) @@ -597,9 +628,7 @@ def Pow(input, exponent, *, out=None, signature = None): if (not isinstance(input, IRObject)) and (not isinstance(exponent, IRObject)): return input ** exponent if (not isinstance(input, IRTensor)) and (not isinstance(exponent, IRTensor)): - iv = input.value if isinstance(input, IRObject) else input - ev = exponent.value if isinstance(exponent, IRObject) else exponent - return IRPyFunc(signature, [input, exponent], [IRObject(name='pow', value=iv**ev)]) + return IRPyFunc(signature, [input, exponent], [_compute_binary_op(input, exponent, operator.pow, 'pow')]) annos = ['*, ? -> *', '?, * -> *',] if isinstance(input, IRTensor) and isinstance(exponent, IRTensor): lshape, rshape, oshape = _handle_broadcast(input, exponent) @@ -611,8 +640,7 @@ def Neg(input, *, out=None, signature = None): assert out is None if not isinstance(input, IRObject): return -1 * input if not isinstance(input, IRTensor): - iv = input.value if isinstance(input, IRObject) else input - return IRPyFunc(signature, [input], [IRObject(name='neg', value=-iv)]) + return IRPyFunc(signature, [input], [_compute_unary_op(input, operator.neg, 'neg')]) annos = ['* -> *'] return IRDimops(Neg, 'neg', signature, annos, [input]) @@ -1778,7 +1806,7 @@ def _comparison(creator: Callable, f: Callable, name: str, signature: str, return IRDimops(creator, name, signature, annos, [input, other]) # case3: torch.equal(obj1, obj2) else: - return IRPyFunc(signature, [input, other], [IRObject()]) + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, f, name)]) def CompareGT(input, other, *, out=None, signature = None): @@ -1953,7 +1981,9 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: if isinstance(obj.value[index], IRTensor): out = obj.value[index] else: - out = IRObject(name='getitem', value=obj.value[index]) + val = obj.value[index] + is_constant = not (isinstance(val, IRObject) and not val.is_constant) + out = IRObject(name='getitem', value=val, is_constant=is_constant) return IRPyFunc(signature, [obj, index], [out]) return obj[index] diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 0106188d..7bf192d0 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -59,7 +59,9 @@ def exist(signature: str) -> bool: __ttemplate('mean') : function.Mean, __ttemplate('abs'): function.Abs, __ttemplate('exp'): function.Exp, + 'math.exp': function.Exp, __ttemplate('sqrt'): function.Sqrt, + 'math.sqrt': function.Sqrt, __ttemplate('rsqrt'): function.RSqrt, __ttemplate('clamp'): function.Clamp, __ttemplate('clamp_min'): function.ClampMin, diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 2d90dc67..9c029297 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -1,16 +1,16 @@ -import os import torch import logging from pathlib import Path from typing import Any, List, Tuple, Callable, Union, Dict, Type, Optional from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRFullTensor +from cube.ir.tensor import IRFullTensor, IRTensor from cube.ir.cten import IRObject, IRCell from cube.graph.parser.frame import Frame from cube.graph.parser.fx.mapping import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import IRDimops +from cube.graph.function.function import ir_object_recursive import torch.fx from .concrete_trace_utils import TensorMetadata @@ -60,8 +60,10 @@ def parse(module: torch.fx.GraphModule, # create IRObjects and IRTensors for node in module.graph.nodes: - concrete_value = dummy_inputs.get(node.name) if node.op == 'placeholder' else None - FxModuleParser.init_objects(node, module, frame, concrete_value) + if node.op == 'placeholder': + FxModuleParser.init_objects(node, module, frame, dummy_inputs.get(node.name), is_constant=False) + else: + FxModuleParser.init_objects(node, module, frame, None, is_constant=True) # get graph inputs placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] @@ -111,7 +113,7 @@ def parse_node(node: torch.fx.Node, module, dynamic_shape: bool, frame: Frame) - @staticmethod def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, - frame: Frame, concrete_value: Optional[Any] = None): + frame: Frame, concrete_value: Optional[Any] = None, is_constant: bool = True): assert isinstance(node, torch.fx.Node) def meta2var(meta: Any) -> Any: @@ -133,7 +135,7 @@ def meta2var(meta: Any) -> Any: raise TypeError(f"only support dict type with str key, but got {meta.keys()}.\n{node}") return {key : meta2var(value) for key, value in meta.items()} # TODO: data type check, with cases like {'a': 1.2, 'b': torch.Tensor} - return IRObject(name=node.name, value=meta) + return IRObject(name=node.name, value=meta, is_constant=is_constant) if hasattr(node, 'meta') and node.meta.get('tensor_meta'): meta = node.meta['tensor_meta'] @@ -141,7 +143,7 @@ def meta2var(meta: Any) -> Any: else: # FIXME: double check: there should be a concrete value as example, # otherwise, it may fail in parsing node like getattr - val = IRObject(name=node.name, value=concrete_value) + val = IRObject(name=node.name, value=concrete_value, is_constant=is_constant) frame.add_var(node.name, val) @@ -209,6 +211,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # case2: python runtime function else: _logger.warning(f'Set python runtime function: {fsig}') + if ir_object_recursive(input_vals, lambda a: not a.is_constant): + err_msg = f'non register python runtime function {fsig} has a non constant input: {input_vals}, ' + \ + 'please register it as a customized function using cube.graph.parser.register' + raise RuntimeError(err_msg) ir_node = IRPyFunc(fsig, input_vals, [IRObject()], **kwargs) if isinstance(ir_node, IRCell): @@ -226,8 +232,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule assert len(vals) == len(ir_node.outputs()), f'{vals}, {ir_node.outputs()}' for i in range(len(vals)): ir_node.set_output(i, vals[i]) - elif ir_node.output(0).value is not None: - if dynamic_shape: + elif not isinstance(ir_node.output(0), IRTensor) and ir_node.output(0).value is not None: + if dynamic_shape or \ + ir_object_recursive(ir_node.output(0), lambda a: not a.is_constant) or \ + ir_object_recursive(ir_node.output(0), lambda a: isinstance(a, IRTensor)): frame.set_var(node.name, ir_node.output(0)) ir_node.output(0).name = node.name else: diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 1a9442db..7aef0e36 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -219,7 +219,7 @@ def reset_outputs(self, length:int) -> None: def set_output(self, index: int, val: NestedVarOrStatic): """ - Set the node inputs[output_index] with the tensor + Set the node outputs[output_index] with the tensor Args: val (NestedVarOrStatic): (nested) IRObject or any deterministic value (int, bool, str, etc) @@ -273,16 +273,19 @@ class IRObject: IRObject serves as general data of IRGraph edge """ - def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: Optional[None] = None): + def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: Optional[None] = None, is_constant: bool = True): """ @param name str: object name @param tid int: object unique id + @param val any: the value of this object + @param is_constant bool: if the value is a constant during the whole training / inference """ self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() self.name: str = name if name else 'obj' self._cell: Optional[IRCell] = None self._is_attr: bool = False self._value: Optional[Any] = value + self._is_constant: bool = is_constant def __eq__(self, obj): if not isinstance(obj, IRObject): @@ -340,6 +343,10 @@ def value(self) -> Any: """Get example value""" return self._value + @property + def is_constant(self) -> bool: + return self._is_constant + def __eq__(self, obj) -> bool: if not isinstance(obj, IRObject): return False @@ -347,7 +354,7 @@ def __eq__(self, obj) -> bool: def __copy__(self): """Copy this object but remove the cell information""" - return IRObject(self.name, self._id, self._value) + return IRObject(self.name, self._id, self._value, self._is_constant) def as_attr(self): """ @@ -374,7 +381,7 @@ def overlap(self, other: Any) -> bool: return False def __repr__(self): - return f'Object({self.name}{self.tid}, val={self.value})' + return f'Object({self.name}{self.tid}, val={self.value}, is_constant={self.is_constant})' class IRTensor(IRObject): diff --git a/tests/graph/parser/test_ir_obj_constant.py b/tests/graph/parser/test_ir_obj_constant.py new file mode 100644 index 00000000..d4d61cbf --- /dev/null +++ b/tests/graph/parser/test_ir_obj_constant.py @@ -0,0 +1,38 @@ +import pytest +import tempfile +import math +import torch + +from cube.graph.parser.converter import convert_model + +from ...utils import replace_all_device_with + + +@replace_all_device_with('cpu') +def test_input_broadcast_constant_attr(): + class SimpleModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(10, 5) + + def forward(self, sample): + res = sample['y'] + 1 + res = res - 1 + res = res * 1 + res = res / 1 + res = res // 1 + res = res % 1 + res = res ** 1 + res = res - 1 + res = -res + res = math.exp(res) + res = math.sqrt(res) + return self.fc(sample['x']), res + + with tempfile.TemporaryDirectory() as tempdir: + cube_graph = convert_model(SimpleModel(), {'sample': {'x': torch.rand(4, 10), 'y': 10}}, tempdir, dynamic_shape=True) + # check input is not constant + assert not cube_graph.input(0).value['y'].is_constant + for i, name in enumerate(['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow', 'sub', 'neg', 'exp', 'sqrt']): + op_node = cube_graph.nodes()[i + 1] + assert op_node.signature.split(".")[-1] == name and not op_node.output(0).is_constant From 77010f71d67e472d2dfa5e5a40172e0a7ec1b529 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 1 Feb 2024 04:58:03 +0000 Subject: [PATCH 1590/1892] Merged PR 2021: refine logging for parser and to device --- cube/graph/function/function.py | 28 +++++++++++++++++++ .../concrete_trace_utils/concrete_tracer.py | 2 +- cube/graph/parser/fx/mapping.py | 1 + cube/runtime/function/function.py | 6 ++++ 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 1a23b89b..8bc2d7c8 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -221,6 +221,29 @@ def Arange(*args, out=None, dtype=None, layout=None, return CubeArange(start, end, step, dtype, requires_grad=requires_grad) +def CubeLinspace(start: Union[int, IRObject], end: Union[int, IRObject], steps: Union[int, IRObject], + dtype=None, requires_grad=False, signature=None): + dtype = dtype if dtype is not None else torch.get_default_dtype() + assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + signature = 'cube.runtime.function.linspace' + kwargs = {'start': start, 'end': end, 'steps': steps, + 'dtype': dtype, 'requires_grad': requires_grad} + steps_val = steps.value if isinstance(steps, IRObject) else steps + anno, rules = _get_creator_anno_rules((steps_val,), False) + dimop = IRDimops(CubeLinspace, 'linspace', signature, [anno], [], rules, **kwargs) + dimop.output(0).parent.dtype = dtype + return dimop + + +def Linspace(start, end, steps, *, out=None, dtype=None, + layout=None, device=None, requires_grad=False, signature=None): + """ + torch.linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor + """ + assert layout is None + return CubeLinspace(start, end, steps, dtype, requires_grad=requires_grad) + + def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): # note: device is ignored @@ -1934,6 +1957,11 @@ def To(tensor: IRTensor, dtype_or_device=None, *, device=None, dtype=None, out=N # FIXME: support full version of torch.Tensor.to dtype_or_device = dtype if dtype is not None else dtype_or_device dtype_or_device = device if dtype_or_device is None else dtype_or_device + if isinstance(dtype_or_device, torch.device) or isinstance(device, torch.device): + warn_msg = 'Cube will handle the tensor device placement, the call of torch.Tensor.to(device=...) will be ignore, ' \ + 'if you really want to put the tensor on cpu to excute some op, please wrap all related ops in an independent function ' \ + 'and using cube.graph.parser.register to register this function.' + _logger.warning(warn_msg) # create "to" in cube runtime functions because dtype if not kwarg in torch.Tensor.to signature = 'cube.runtime.function.to' annos = ['* -> *'] diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 759f7f75..282f1271 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -420,7 +420,7 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] result = result.cpu() elif isinstance(result, (list, dict, tuple)): result = tree_map(to_cpu, result) - elif isinstance(result, (int, bool, torch.device, torch.dtype)) or result is None: + elif isinstance(result, (int, bool, float, torch.device, torch.dtype, torch.finfo)) or result is None: # avoid too noisy warning pass else: diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 7bf192d0..9965403d 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -100,6 +100,7 @@ def exist(signature: str) -> bool: __tttemplate('expand'): function.Expand, __tttemplate('expand_as'): function.ExpandAs, __ttemplate('arange'): function.Arange, + __ttemplate('linspace'): function.Linspace, __ttemplate('detach'): function.Detach, __ttemplate('_shape_as_tensor'): function.ShapeAsTensor, __ttemplate('index_select'): function.IndexSelect, diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index 0d194eb5..a01322f2 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -201,6 +201,12 @@ def arange(start: int, end: int, step: int, dtype: torch.dtype, requires_grad=Fa device=torch.cuda.current_device()) +def linspace(start: Union[int, torch.Tensor], end: Union[int, torch.Tensor], + steps: int, dtype: torch.dtype, requires_grad=False): + return torch.linspace(start, end, steps, dtype=dtype, requires_grad=requires_grad, + device=torch.cuda.current_device()) + + def index_select(input: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor: return torch.index_select(input, dim, index) From 75c11bdf1875290e9cde091891fae6c54e2ee3c1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Sun, 4 Feb 2024 05:01:01 +0000 Subject: [PATCH 1591/1892] Merged PR 2039: ParallelModule: add basic checkpoint support --- cube/parallel.py | 18 +- cube/runtime/adapter/reducer.py | 4 +- cube/runtime/module.py | 161 +++++++++++++++++- cube/utils.py | 23 ++- tests/parallel_module/test_checkpoint.py | 202 +++++++++++++++++++++++ 5 files changed, 400 insertions(+), 8 deletions(-) create mode 100644 tests/parallel_module/test_checkpoint.py diff --git a/cube/parallel.py b/cube/parallel.py index 9e1b58aa..26c370ef 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -17,6 +17,7 @@ from cube.ir.tensor import IRFullTensor from cube.flags import CompileFlag, RuntimeFlag +from cube.utils import get_shared_params from cube.graph import IRGraph from cube.graph import parser @@ -32,7 +33,7 @@ from cube.ir.unique import IDGenerator from cube.program import Program from cube.runtime.adapter.reducer import Reducer -from cube.runtime.module import CubeModule, ParallelModule +from cube.runtime.module import CubeModule, ParallelModule, OriginModuleMetadata from cube.runtime.device import DeviceGroup from cube.runtime.gnorm import calcuate_gnorm, clip_grads @@ -293,6 +294,7 @@ def _prepare_and_check_reusable( expected_output_files.append(config_file) expected_output_files.append(outdir / _GRAPH_DUMP_FILE) expected_output_files.append(outdir / _FORWARD_ARGS_DUMP_FILE) + expected_output_files.append(outdir / ParallelModule.ORIGIN_MODULE_METADATA_FILE) existing_output_files = [ f for f in outdir.glob('*') if f.is_file() and ( # just take fullmap.pt.0 to compare @@ -456,7 +458,8 @@ def _gencode( """ graph_ckp = outdir / _GRAPH_DUMP_FILE forward_args_ckp = outdir / _FORWARD_ARGS_DUMP_FILE - if not graph_ckp.exists() or not forward_args_ckp.exists(): + origin_module_metadata_ckp = outdir / ParallelModule.ORIGIN_MODULE_METADATA_FILE + if not graph_ckp.exists() or not forward_args_ckp.exists() or not origin_module_metadata_ckp.exists(): is_module_class = inspect.isclass(module_or_module_class) if is_module_class: try: @@ -480,6 +483,14 @@ def _gencode( if any(isinstance(m, CubeModule) for m in module.modules()): raise RuntimeError('CubeModule can not be nested.') + # save origin module metadata + meta_info = OriginModuleMetadata( + origin_param_names=[name for name, _ in module.named_parameters()], + origin_state_dict_names=list(module.state_dict().keys()), + origin_shared_param_names=get_shared_params(module), + ) + torch.save(meta_info, origin_module_metadata_ckp) + graph, forward_args = _gen_graph(module, dummy_input, outdir, compute_config.dynamic_shape) graph.dump(graph_ckp) torch.save(forward_args, forward_args_ckp) @@ -758,7 +769,7 @@ def build_optimizer( non_parallel_module_reducer_op: str = 'sum', *args, **kwargs, -) -> OptimizerT: +) -> Union[OptimizerT, ParallelOptimizer]: """ Build an optimizer for a module. @@ -792,6 +803,7 @@ def build_optimizer( torch.optim.Optimizer: the optimizer you should use to train the module The optimizer is created by optimizer_fn, and will be patched with the methods in ParallelModule class to support parallelized module. + Please note the type annotation of the returned optimizer (`Union[OptimizerT, ParallelOptimizer]`) is just for intellisense. """ if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 25396266..81ff105a 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -100,7 +100,7 @@ def numel(self) -> int: return self._numel @property - def params(self) -> Tuple: + def params(self) -> List[torch.nn.Parameter]: """Parameter list""" return self._params @@ -481,7 +481,7 @@ def build_buckets(self): for params in seq_buckets: starts.append(buffer_length) numel = sum(p.numel() for p in params) - padding = len(self._ranks) - numel % len(self._ranks) + padding = (len(self._ranks) - numel % len(self._ranks)) % len(self._ranks) buffer_length += numel + padding stops.append(buffer_length) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 64972c4a..7bd20d22 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,8 +1,9 @@ -from typing import List, Dict, Tuple, Optional, TYPE_CHECKING +from typing import List, Set, Dict, Tuple, Optional, TYPE_CHECKING import logging import os import sys from pathlib import Path +from dataclasses import dataclass, asdict import torch import torch.distributed as dist @@ -465,8 +466,40 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): }, filename_prefix + '.full.ckpt') +@dataclass +class OriginModuleMetadata: + origin_state_dict_names: List[str] + origin_param_names: List[str] + origin_shared_param_names: List[Set[str]] + + +@dataclass +class ZeroMetadata: + # a mapping from the index of the parameter in the model + # to (optimizer_index, the start and end in the bucket, the shape of the parameter) + model_idx2opt_idx: Optional[Dict] = None + # a mapping from optimizer_index to the related bucket information (sub_ranks, bucket_size) + opt_idx2ranks: Optional[Dict] = None + + +@dataclass +class ParallelModuleConfig: + rank: int + compute_config: 'ComputeConfig' + dist_param_map: Dict + param_area_map: Dict + cube_param_names: List[str] + + +@dataclass +class ExtraState(ZeroMetadata, OriginModuleMetadata, ParallelModuleConfig): + pass + + class ParallelModule(CubeModule): COMPUTE_CONFIG_FILE = 'compute_config.pt' + ORIGIN_MODULE_METADATA_FILE = 'origin_module_metadata.pt' + EXTRA_STATE_KEY = 'CUBE_EXTRA_STATE' def __init__(self): if self.__class__ == ParallelModule: # not init via super().__init__() @@ -500,11 +533,21 @@ def _post_init(self, init_params=True): if init_params: self.load_attr_content(str(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE_STEM}"))) self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}")) - self._compute_config = torch.load(module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}")) + self._compute_config: 'ComputeConfig' = torch.load(module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}")) + self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}")) for reducer in self.reducers: reducer.build_buckets() + self._zero_metadata = self._get_zero_metadata() + + # add state_dict hook to save extra state + # Please note extra_state is only used for merging, not for loading + # so we can safely remove it in load_state_dict pre hook + self._register_state_dict_hook(ParallelModule._post_state_dict_hook) + # add load_state_dict pre hook to pop extra state to prevent warning + self._register_load_state_dict_pre_hook(ParallelModule._pre_load_state_dict_hook, with_module=True) + def forward(self, *args, **kwargs): if self.training: self._sync_grad_required = True # mark sync_grad() can be called again @@ -552,3 +595,117 @@ def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, Li self.sync_grad() return clip_gnorm(self._nreplicas2localparams, max_norm) + + def _get_zero_metadata(self) -> ZeroMetadata: + """ + Get zero related metadata for checkpointing. + + In this function, we have a mocked optimizer index representing the combined flattened index of (reducer_index, bucket_index) + + Note: + Parameters can be in one bucket or not in any bucket. + When we need to reduce(sume) the gradient of a parameter across ranks, + the parameters will be added in one reducer based on the rank group. + There are two types of reducing: cross scale unit or intra scale unit. + + So when num of scale unit > 1, the parameters have to be reduced across scale units, + so they will be in a reducer. + + When the num of scale unit == 1, the parameters can still need to be reduced inside the scale unit, + when the parameters is replicated because the ops using that parameters are partitioned. + (when the paremeter is used by multiple ops, + but some of ops are partitioned and some of ops are replicated, + In that case, the parameter will not be in a reudcer. + We will use mutliref, and insert identity-allreduce in generated code to reduce the parameter instead of using a reducer. + ) + + Returns: + ZeroMetadata: the zero related metadata + """ + if not self.get_compute_config().use_zero: + return ZeroMetadata() + + model_params = self.parameters_for_optimizer() + opt_idx = 0 # the combined flattened index of (reducer_index, bucket_index) + # key: the index of the parameter in the model + # value: (opt_idx, param_start, param_end, param_shape) + # where param_start and param_end are the start and end index of the parameter in the bucket + model_idx2opt_idx: Dict[int, Tuple[int, int, int, torch.Size]] = {} + # key: opt_idx + # value: (sub_ranks, bucket_size), If value is None, then the parameter is not in a bucket + opt_idx2ranks: Dict[int, Optional[Tuple[List[int], int]]] = {} + model_params_id = [id(param) for param in self.parameters()] + + for reducer in self.reducers: + _, sub_ranks = self._get_zero_subranks(reducer) + for bucket in reducer.buckets: + pstart, pend = 0, 0 + for param in bucket.params: + pstart = pend + pend += param.numel() + model_idx = model_params_id.index(id(param)) + model_idx2opt_idx[model_idx] = (opt_idx, pstart, pend, param.shape) + assert len(bucket._contiguous_params.shape) == 1 + opt_idx2ranks[opt_idx] = (sub_ranks, bucket._contiguous_params.shape[0]) + opt_idx += 1 + + assert len(model_params) >= opt_idx + # The remaining parameters are not in any bucket + # we assign them to the next available opt_idx + # and set the opt_idx2ranks to None + for param in model_params[opt_idx:]: + model_idx = model_params_id.index(id(param)) + model_idx2opt_idx[model_idx] = opt_idx + opt_idx2ranks[opt_idx] = None + opt_idx += 1 + + assert len(model_params) == opt_idx + + return ZeroMetadata( + model_idx2opt_idx=model_idx2opt_idx, + opt_idx2ranks=opt_idx2ranks, + ) + + def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: + """ + Get the index in the zero subgroup the reducer belongs to, and the ranks of the subgroup. + + Args: + reducer (cube.runtime.adapter.Reducer): a reducer of cube model + + Returns: + rank_idx (int): the index of current rank in sub_ranks + sub_ranks (list): the ranks of ZeRO subgroup the current rank belongs to + """ + cf = self.get_compute_config() + if not cf.use_zero: + raise RuntimeError('ZERO is not enabled, cannot get the zero subgroup info') + + rank_idx = reducer.ranks.index(self.get_rank()) + if cf.zero_ngroups > 1: + assert len(reducer.ranks) % cf.zero_ngroups == 0, \ + f'reducer.ranks {reducer.ranks} should be divisible by ZERO_NUM_GROUPS {cf.zero_ngroups}' + zgroup_sz = len(reducer.ranks) // cf.zero_ngroups + group_idx = rank_idx // zgroup_sz + sub_ranks = reducer.ranks[group_idx * zgroup_sz : (group_idx + 1) * zgroup_sz] + new_rank_idx = sub_ranks.index(self.get_rank()) + return new_rank_idx, sub_ranks + else: + assert cf.zero_ngroups == 1 + return rank_idx, reducer.ranks + + def _post_state_dict_hook(self, state_dict, prefix, local_metadata) -> None: + state_dict[f'{prefix}{self.EXTRA_STATE_KEY}'] = asdict( + ExtraState( + rank=self.get_rank(), + compute_config=asdict(self._compute_config), + dist_param_map=self._dist_param_map, + param_area_map=self._fullmap, + cube_param_names=[name for name, _ in self.named_parameters()], + **asdict(self._orign_module_metadata), + **asdict(self._zero_metadata), + ) + ) + + def _pre_load_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None: + state_dict.pop(f'{prefix}{self.EXTRA_STATE_KEY}', None) diff --git a/cube/utils.py b/cube/utils.py index 9985dc02..90f44686 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -1,8 +1,9 @@ import os -from typing import Optional, Tuple, Callable +from typing import Optional, Tuple, Callable, List, Set import logging from pathlib import Path import sys +from collections import defaultdict import cube from cube.runtime.device import DeviceGroup @@ -80,6 +81,26 @@ def load_eval_schedule(filename: Optional[str] = None): return module._infer_step +def get_param_by_name(model: torch.nn.Module, name: str) -> torch.nn.Parameter: + """ + Get the parameter of the model by its full name. + """ + sliced_names = name.split(".") + model_attr = model + for sliced_name in sliced_names: + model_attr = getattr(model_attr, sliced_name) + return model_attr + + +def get_shared_params(model: torch.nn.Module) -> List[Set[str]]: + paramid2name = defaultdict(set) + for name in model.state_dict().keys(): + param = get_param_by_name(model, name) + paramid = id(param) + paramid2name[paramid].add(name) + return [names for _, names in paramid2name.items() if len(names) > 1] + + class accum_mode: """Make cube execution in gradient accumulation mode. diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py new file mode 100644 index 00000000..fca3dc24 --- /dev/null +++ b/tests/parallel_module/test_checkpoint.py @@ -0,0 +1,202 @@ +import tempfile +import itertools +import re +from pathlib import Path +import shutil +import pytest +from typing import Dict, Tuple, List +from dataclasses import dataclass + +import torch +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler + +import numpy as np + +from cube.parallel import ComputeConfig, parallelize, build_optimizer +from cube.runtime.module import ParallelModule, ExtraState +from cube.runtime.gnorm import calcuate_gnorm + +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively + + +class FcRelu(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, out_features, bias=bias) + self.relu2 = nn.ReLU() + + + def forward(self, x): + return self.relu2(self.fc2(self.relu1(self.fc1(x)))) + + +class FcRelu_4_4(FcRelu): + def __init__(self): + super().__init__(4, 4) + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + + +def _create_cube_module(pas, compute_config, cube_savedir): + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu1') + self.linear2 = nn.Linear(4, 4) + self.fc_relu2 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu2') + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + compiled_module = CompiledModule().cuda() + return compiled_module + +DATA_SIZE = 256 + +@dataclass +class StepResult: + pred: torch.Tensor + loss: torch.Tensor + grads: Dict[str, torch.Tensor] + gnorm: torch.Tensor + weights: Dict[str, torch.Tensor] + + +def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): + ckpt_file_template = 'ckpt_{rank}_{start}.pth' + ckpt_start_file = ckpt_dir / ckpt_file_template.format( + rank=torch.distributed.get_rank(), + start=start + ) + init_random() + + loss_fn = nn.BCELoss() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + if ckpt_start_file.exists(): + ckpt_dict = torch.load(ckpt_start_file) + model_state_dict = ckpt_dict['model'] + assert 'fc_relu1.CUBE_EXTRA_STATE' in model_state_dict + assert 'fc_relu2.CUBE_EXTRA_STATE' in model_state_dict + optimizer_state_dict = ckpt_dict['optimizer'] + model.load_state_dict(model_state_dict) + optimizer.load_state_dict(optimizer_state_dict) + data = [] + init_random() + for _ in range(DATA_SIZE): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + data = data[start:end] # continue from last training + data = [data[i] for i in range(rank, len(data), num_replicas)] + results = [] + for i, (x, y) in enumerate(data): + model.train() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + optimizer.step() + grads = {n: p.grad for n, p in model.named_parameters()} + gnorm = optimizer.clip_gnorm() + results.append(clone_to_cpu_recursively([y_pred, loss, grads, gnorm])) + optimizer.zero_grad() + weights = {n: p.data for n, p in model.named_parameters()} + results[-1].append(clone_to_cpu_recursively(weights)) + results[-1] = StepResult(*results[-1]) + + ckpt_file = ckpt_dir / ckpt_file_template.format( + rank=torch.distributed.get_rank(), + start=end + ) + model_state_dict = model.state_dict() + assert 'fc_relu1.CUBE_EXTRA_STATE' in model_state_dict + assert 'fc_relu2.CUBE_EXTRA_STATE' in model_state_dict + extra_state1 = ExtraState(**model_state_dict['fc_relu1.CUBE_EXTRA_STATE']) + assert extra_state1.compute_config + assert extra_state1.model_idx2opt_idx + assert extra_state1.opt_idx2ranks + assert extra_state1.origin_param_names + optimizer_state_dict = optimizer.state_dict() + torch.save({ + 'model': model_state_dict, + 'optimizer': optimizer_state_dict + }, ckpt_file) + return results + + +def _gpu_worker(pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_count): + init_distributed() + compiled_results = [] + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + for i in range(resume_count): + start = i * per_resume_update_count + end = (i + 1) * per_resume_update_count + compiled_module = _create_cube_module(pas, + ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=True), + tempdir + ) + compiled_results.extend(_train( + compiled_module, + runtime_ngpus // plan_ngpus, + torch.distributed.get_rank() // plan_ngpus, + start, end, tempdir + )) + return compiled_results + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_checkpoint(): + cube_results = launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4, 32, 1) + rcube_results = launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4, 16, 2) + + results0, results1, results2, results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] + rresults0, rresults1, rresults2, rresults3 = rcube_results[0], rcube_results[1], rcube_results[2], rcube_results[3] + + # pred, loss + for r0, r1 in [(results0, results1), (results2, results3), + (rresults0, rresults1), (rresults2, rresults3), + (results0, rresults0), (results2, rresults2) + ]: + # have the same input + assert len(r0) == len(r1) # iteration count + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.pred, b.pred) # pred + assert torch.equal(a.loss, b.loss) # loss + assert torch.equal(a.gnorm, b.gnorm) # gnorm + + # grad, weights + for r0, r1 in [(results0, results2), (results1, results3), + (rresults0, rresults2), (rresults1, rresults3), + (results0, rresults0), (results1, rresults1) + ]: + # in the same shard, grads and weights are the same + assert len(r0) == len(r1) + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.gnorm, b.gnorm) # gnorm + for k in a.grads.keys(): # grad + assert torch.equal(a.grads[k], b.grads[k]) + for k in a.weights.keys(): # weights + assert torch.equal(a.weights[k], b.weights[k]) From d43d712a5ec7fc4acd53322bc6d8342b6636b4db Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 5 Feb 2024 08:39:15 +0000 Subject: [PATCH 1592/1892] Merged PR 2037: support pipeline in autodist - remove useless log - refine PAS check --- cube/compiler.py | 6 +++++- cube/graph/function/function.py | 2 -- cube/graph/gener/gen.py | 2 +- cube/parallel.py | 6 +++++- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index 42b5a288..e9f5555d 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -11,6 +11,7 @@ from cube.ir.unique import IDGenerator from cube.graph.gener.gen import IRAdapterGener from cube.graph.graph import IRGraph +from cube.ir.operator import IRBpOperation from cube.ir.cten import IRObject from cube.ir.tensor import IRFullTensor from cube.graph.function.anchor import IRGraphAnchor @@ -200,7 +201,7 @@ def decorator(fn: Callable) -> Callable: if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") - # check assignment and remove anchor node + # check assignment for node in graph.nodes(flatten=True): # skip graph anchor: will be removed # skip multiref and IRPyFunc: they will be managed by system @@ -208,11 +209,14 @@ def decorator(fn: Callable) -> Callable: continue if isinstance(node, IRPyFunc): continue + if isinstance(node, IRBpOperation) and node.mirror.name == 'multiref': + continue if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") # generate adapter start = time.time() + # anchor node removed in gener graph = IRAdapterGener.gen(graph, cost_fn=comm_cost_fn) span = time.time() - start _logger.info('finish generating adapters: {:.2f} s'.format(span)) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 8bc2d7c8..29a81789 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -104,8 +104,6 @@ def Linear(input, weight, bias=None, signature = None): return IRDimops(Linear, 'linear', signature, annos, [input, weight], bias=None) else: annos = ['b * k^, n k^, n -> b * n'] - _logger.warning( - 'detected a linear operator has bias, the partition on reduction dimension is disabled.') return IRDimops(Linear, 'linear', signature, annos, [input, weight, bias]) diff --git a/cube/graph/gener/gen.py b/cube/graph/gener/gen.py index 8732032f..39170cd4 100644 --- a/cube/graph/gener/gen.py +++ b/cube/graph/gener/gen.py @@ -670,7 +670,7 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): devops.setdefault(devid, []).append(consumer) assert devtensors[devid][0] == ctensor, ( f"Detect that a full tensor is partitioned differently on a device.\n" - f"To achieve this, need call graph.multiref before graph transformation.\n" + f"To avoid this, need call graph.multiref before graph transformation.\n" f"{graph.debug_tensor_map_str(ftensor)}" ) diff --git a/cube/parallel.py b/cube/parallel.py index 26c370ef..c9be25c6 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -21,6 +21,7 @@ from cube.graph import IRGraph from cube.graph import parser +from cube.ir.operator import IRBpOperation from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.pyfunc import IRPyFunc from cube.graph.schedule.schedplan import SchedulePlan @@ -506,7 +507,7 @@ def _gencode( if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") - # check assignment and remove anchor node + # check assignment for node in graph.nodes(flatten=True): # skip graph anchor: will be removed # skip multiref and IRPyFunc: they will be managed by system @@ -514,8 +515,11 @@ def _gencode( continue if isinstance(node, IRPyFunc): continue + if isinstance(node, IRBpOperation) and node.mirror.name == 'multiref': + continue if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") + # anchor node removed in gener graph = IRAdapterGener.gen(graph, cost_fn=None) if graph.sched is not None: graph.sched.apply() From bc4f08c26ab7e4d7c4b80632f4120c4200b592f3 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Sun, 18 Feb 2024 08:41:27 +0000 Subject: [PATCH 1593/1892] Merged PR 2043: parallel module: remove reducer_op option reducer_op option is not very useful and can lead to inconsistence when there is intra scale unit reduce. For example, If a param is replicated due to op's partition, then the param will be added to reducer and its reduce behavior depends on reducer_op option. But if the param is used by multiple ops, and some ops are partitioned, and other ops are replicated, In that case, the param will be reduced by identity-allreduce, and the reducer_op is always `sum` (regardless the value of reducer_op option) So let's remove this option. --- cube/parallel.py | 55 +++---- cube/runtime/module.py | 2 +- docs/parallel_module.md | 40 +++-- tests/parallel_module/test_scale_grads.py | 178 ++++++++++++++++++++++ 4 files changed, 226 insertions(+), 49 deletions(-) create mode 100644 tests/parallel_module/test_scale_grads.py diff --git a/cube/parallel.py b/cube/parallel.py index c9be25c6..bd73bffc 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -40,7 +40,6 @@ logger = logging.getLogger(__name__) -_VALID_REDUCER_OPS = ['sum', 'avg', 'mean', 'max', 'min'] @dataclass(frozen=True) @@ -54,20 +53,6 @@ class ComputeConfig: use_zero: bool = False zero_ngroups: int = 1 - # which torch.distributed.ReduceOp is used when reduce gradients - # by torch.distributed.all_reduce or torch.distributed.reduce_scatter - # a special case for mean op - # In some cases, you may want to firstly divide the local gradients, and then use torch.distributed.ReduceOp.SUM - # to get the final the gradients - # example code to divide the local gradients: - #```python - # def _mean_hook(reducer, grad): - # if reducer.reduce_op == torch.distributed.ReduceOp.SUM: - # grad.div_(reducer.ranks) - # optimizer.register_reducer_pre_hook(_mean_hook) - # ``` - reducer_op: str = 'sum' - # you can put any configuration here # *Note*: the assumption is different user_config should generate different code. # Example 1: save module configuration @@ -114,8 +99,6 @@ def __post_init__(self): raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be a multiple of plan_ngpus {self.plan_ngpus}") if self.use_zero and self.zero_ngroups < 0: raise ValueError(f"zero_ngroups {self.zero_ngroups} must be >= 0") - if self.reducer_op not in _VALID_REDUCER_OPS: - raise ValueError(f"reducer_op {self.reducer_op} is not supported.") @property def gpu_config(self) -> Dict[str, int]: @@ -141,7 +124,7 @@ def _flags(flags, /, **kwargs): def _compile_flags(compute_config: ComputeConfig): return _flags( CompileFlag, - async_reducer=False, reducer_op=compute_config.reducer_op, async_comm=False, + async_reducer=False, reducer_op='sum', async_comm=False, use_zero=compute_config.use_zero, zero_ngroups=compute_config.zero_ngroups, ) @@ -742,6 +725,21 @@ def clip_gnorm(self, max_norm: Optional[float] = None) -> torch.Tensor: """ ... + def scale_grads(self, scale: float) -> None: + """ + Scale the gradients of the module. + + Please note + 1. you can only call this function **after** `sync_shard_grad`, + because the gradients are `None` until `sync_shard_grad` is called. + 2. Only the gradients of parameters in this optimizer be multiplied by this factor, + (When ZERO is on, not all parameters of the module are added to the optimizer). + + Args: + scale (float): the scale factor. Gradients will be multiplied by this factor. + """ + ... + def register_reducer_pre_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): """ Register pre hooks to reducers which will be applied before gradient synchronization. @@ -770,7 +768,6 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], - non_parallel_module_reducer_op: str = 'sum', *args, **kwargs, ) -> Union[OptimizerT, ParallelOptimizer]: @@ -796,11 +793,8 @@ def build_optimizer( module (torch.nn.Module): the module to be optimized optimizer_fn (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): It can be the optimizer class or optimizer factory function. - If it is a factory function, the signature should be the same with optimizer class constructor. - non_parallel_module_reducer_op (str): the reducer op for non-parallel modules. Default is 'sum'. - *args: the args for optimizer constructor. - Note: If you use `*args`, you must specify `non_parallel_module_reducer_op`. - Suggest to use kwargs instead, so you don't need to explicitly specify the default value of `non_parallel_module_reducer_op`. + The first parameter of the optimizer_fn should be the parameters of the module. + *args: other args for `optimizer_fn` besides parameters. **kwargs: the kwargs for optimizer constructor Returns: @@ -812,8 +806,6 @@ def build_optimizer( if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): raise RuntimeError("End2End mode is not supported") - if not non_parallel_module_reducer_op in _VALID_REDUCER_OPS: - raise ValueError(f"non_parallel_module_reducer_op {non_parallel_module_reducer_op} is not supported.") RuntimeFlag.skip_reducer = True RuntimeFlag.skip_zero_grad = False @@ -839,7 +831,7 @@ def build_optimizer( for i in range(plan_ngpus): DeviceGroup().get_group(list(range(i, runtime_ngpus, plan_ngpus))) group = list(range(rank % plan_ngpus, runtime_ngpus, plan_ngpus)) - non_parallel_module_reducer = Reducer(group, reduce_op=non_parallel_module_reducer_op) + non_parallel_module_reducer = Reducer(group) for m in non_parallel_modules: for param in m.parameters(recurse=False): # only add leaf parameters to avoid duplicate non_parallel_module_reducer.add_param(param) @@ -915,6 +907,15 @@ def _clip_gnorm(self, max_norm: Optional[float] = None): optimizer.clip_gnorm = types.MethodType(_clip_gnorm, optimizer) + def _scale_grads(self, scale: float) -> None: + if parallel_modules[0]._sync_grad_required: + raise RuntimeError("You can only call scale_grads() after gradients are synchronized.") + for pg in optimizer.param_groups: + for p in pg['params']: + if p.grad is not None: + p.grad.mul_(scale) + optimizer.scale_grads = types.MethodType(_scale_grads, optimizer) + def _register_reducer_pre_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): for m in parallel_modules: for reducer in m.reducers: diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 7bd20d22..6d27c607 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -612,7 +612,7 @@ def _get_zero_metadata(self) -> ZeroMetadata: so they will be in a reducer. When the num of scale unit == 1, the parameters can still need to be reduced inside the scale unit, - when the parameters is replicated because the ops using that parameters are partitioned. + when a parameter is replicated due to its operator's partition (i.e., through graph.partition) (when the paremeter is used by multiple ops, but some of ops are partitioned and some of ops are replicated, In that case, the parameter will not be in a reudcer. diff --git a/docs/parallel_module.md b/docs/parallel_module.md index d68f4e53..7d30b236 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -114,7 +114,6 @@ class ComputeConfig: dynamic_shape: bool = True - reducer_op: str = 'sum' use_zero: bool = False zero_ngroups: int = 1 @@ -132,23 +131,11 @@ We can categorize the fields into 4 categories: 3. Code generation feature configuration - use_zero: whether to use zero. If it is true, the generated code will use zero1 to do distributed training. - zero_ngroups: the number of groups to be used in zero. - - reducer_op: the reducer operation for the gradients. It can be `sum`, `mean`, `min`, `max`, `avg`. 4. User configuration - user_config: the user configuration. A typical usage is deciding whether skipping compiling and reusing the previously compiled parallel module. If user_config is the same between two runs, compiling in the second run will be skipped. Note: -1. `reducer_op` represents which `torch.distributed.ReduceOp` is used when reduce gradients - by torch.distributed.all_reduce or torch.distributed.reduce_scatter - - In some cases, you may want to firstly divide the local gradients, and then use torch.distributed.ReduceOp.SUM to get the final the gradients. - You can achieve that speical mean with `optimizer.register_reducer_pre_hook` by setting `reducer_op` to `sum` and divide the local gradients with the following code: - ```python - def _mean_hook(reducer, grad): - if reducer.reduce_op == torch.distributed.ReduceOp.SUM: - grad.div_(reducer.ranks) - optimizer.register_reducer_pre_hook(_mean_hook) - ``` -2. You can put any graph related configuration here. The assumption is different user_config should generate different graph/code. So if the user config is changed, we will regenerate the graph/code automatically. Here are some examples: +1. You can put any graph related configuration here. The assumption is different user_config should generate different graph/code. So if the user config is changed, we will regenerate the graph/code automatically. Here are some examples: - Example 1: save module configuration ```python @@ -289,7 +276,6 @@ We have `build_optimizer` to build an optimizer for distributed training. def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], - non_parallel_module_reducer_op: str = 'sum', *args, **kwargs, ) -> OptimizerT: @@ -298,10 +284,8 @@ It has the following parameters: - module (torch.nn.Module): the module to be optimized - optimizer_fn (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): It can be the optimizer class or optimizer factory function. -- non_parallel_module_reducer_op (str): the reducer op for non-parallel modules. Default is 'sum'. -- *args: the args will pass to `optimizer_fn`. - If you use `*args`, you must specify `non_parallel_module_reducer_op` explicitly even when you only need its default value. - So we suggest using `kwargs` instead of `args` to specify `optimizer_fn` arguments if possible. + The first parameter of the optimizer_fn should be the module parameters. +- *args: other args for `optimizer_fn` besides module parameters. - **kwargs: the kwargs will pass to `optimizer_fn` To support distrubted training, in the function we need to hook 4 places: @@ -321,9 +305,23 @@ To support distrubted training, in the function we need to hook 4 places: 1. `sync_shard_grad`: Sync the shard gradients of the module from nodes with same shard to the optimizer. This function is called in optimizer's pre-step hook. But If you want to access the gradients before `optimizer.step()`(for example, you need gnorm), you need to call this function manually. -2. `register_reducer_pre_hook`, `register_reducer_post_hook`: Register pre/post hooks to reducers which will be applied before/after gradient synchronization. These hooks will apply to all the reducers (including `_non_parallel_module_reducer`) in the optimizer. +2. `scale_grads`: Scale the gradients of the module by multiplying a factor. This function is useful to avoid overflow when the gradients are large. Please note you can only call this function **after** `sync_shard_grad`, because the gradients are `None` until `sync_shard_grad` is called. -3. `_non_parallel_module_reducer`: The reducer for the modules which are not parallelized. It is used to sync the parameters in those modules across units. +3. `clip_gnorm`: Clip the gradients with global norm, and return the global gnorm value, it will sync grads across devices if necessary. This function is useful to avoid gradient explosion. + +4. `register_reducer_pre_hook`, `register_reducer_post_hook`: Register pre/post hooks to reducers which will be applied before/after gradient synchronization. These hooks will apply to all the reducers (including `_non_parallel_module_reducer`) in the optimizer. + +You can use `register_reducer_pre_hook` and `register_reducer_post_hook` to do some operations before/after gradient synchronization. Not all paramers are managed by reducers, so it is tricky to use them. Actually we don't encourage you to use these functions. + +Here is one example (Assume we calculate loss with sum) showing how to carefully scale down the gradient locally and scale up the gradient after reduce. This is useful to avoid overflow when the gradients are large:. + +```python +num_scale_units = ... +optimizer.register_reducer_pre_hook(lambda reducer, grad: grad.div_(num_scale_units)) # scale down with factor num_scale_units before reduce +optimizer.register_reducer_post_hook(lambda reducer, grad: grad.mul_(num_scale_units) # scale up with factor num_scale_units after reduce +``` + +5. `_non_parallel_module_reducer`: The reducer for the modules which are not parallelized. It is used to sync the parameters in those modules across units. ### Dataset diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py new file mode 100644 index 00000000..62e7e8f7 --- /dev/null +++ b/tests/parallel_module/test_scale_grads.py @@ -0,0 +1,178 @@ +import tempfile +import itertools +import re +from pathlib import Path +import shutil +import pytest +from typing import Dict, Tuple, List +from dataclasses import dataclass + +import torch +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler + +import numpy as np + +from cube.parallel import ComputeConfig, parallelize, build_optimizer +from cube.runtime.module import ParallelModule, ExtraState +from cube.runtime.gnorm import calcuate_gnorm + +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively + + +class FcRelu(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, out_features, bias=bias) + self.relu2 = nn.ReLU() + + + def forward(self, x): + return self.relu2(self.fc2(self.relu1(self.fc1(x)))) + + +class FcRelu_4_4(FcRelu): + def __init__(self): + super().__init__(4, 4) + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + + +def _create_cube_module(pas, compute_config, cube_savedir): + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu1') + self.linear2 = nn.Linear(4, 4) + self.fc_relu2 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu2') + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + compiled_module = CompiledModule().cuda() + return compiled_module + +DATA_SIZE = 64 + +@dataclass +class StepResult: + pred: torch.Tensor + loss: torch.Tensor + grads: Dict[str, torch.Tensor] + gnorm: torch.Tensor + weights: Dict[str, torch.Tensor] + + +def _train(model: torch.nn.Module, num_replicas, rank, scale_grads: bool): + NUM_SCALE_UNITS = 2 + NUM_SAMPLES_PER_UPDATE = 2 + init_random() + + loss_fn = nn.BCELoss() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + if scale_grads: + # before reduce + optimizer.register_reducer_pre_hook(lambda reducer, grad: grad.div_(NUM_SCALE_UNITS)) + # after reduce + optimizer.register_reducer_post_hook(lambda reducer, grad: grad.mul_(NUM_SCALE_UNITS)) + data = [] + init_random() + for _ in range(DATA_SIZE): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + data = [data[i] for i in range(rank, len(data), num_replicas)] + results = [] + for i, (x, y) in enumerate(data): + model.train() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + optimizer.sync_shard_grad() + optimizer.scale_grads(1/NUM_SAMPLES_PER_UPDATE) + optimizer.step() + grads = {n: p.grad for n, p in model.named_parameters()} + gnorm = optimizer.clip_gnorm() + results.append(clone_to_cpu_recursively([y_pred, loss, grads, gnorm])) + optimizer.zero_grad() + weights = {n: p.data for n, p in model.named_parameters()} + results[-1].append(clone_to_cpu_recursively(weights)) + results[-1] = StepResult(*results[-1]) + + return results + + +def _gpu_worker(pas, plan_ngpus, runtime_ngpus, scale_grads: bool): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_scale_grads') as tempdir: + compiled_module = _create_cube_module(pas, + ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=True), + tempdir + ) + return _train( + compiled_module, + runtime_ngpus // plan_ngpus, + torch.distributed.get_rank() // plan_ngpus, + scale_grads + ) + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_scale_grads(): + cube_results = launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4, True) + rcube_results = launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4, False) + + results0, results1, results2, results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] + rresults0, rresults1, rresults2, rresults3 = rcube_results[0], rcube_results[1], rcube_results[2], rcube_results[3] + + # pred, loss, gnorm + for r0, r1 in [(results0, results1), (results2, results3), + (rresults0, rresults1), (rresults2, rresults3), + (results0, rresults0), (results2, rresults2) + ]: + assert len(r0) == len(r1) # iteration count + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.pred, b.pred) # pred + assert torch.equal(a.loss, b.loss) # loss + assert torch.equal(a.gnorm, b.gnorm) # gnorm + + # grad, weights + for r0, r1 in [(results0, results2), (results1, results3), + (rresults0, rresults2), (rresults1, rresults3), + (results0, rresults0), (results1, rresults1) + ]: + assert len(r0) == len(r1) + for i in range(len(r0)): + a, b = r0[i], r1[i] + # for grads, as we have scale_grads, + # grads in `parameters_for_optimizer` are scaled, but the rest are not + # so they can be the same or scaled by 2 or divided by 2 + for k in a.grads.keys(): # grad + assert torch.equal(a.grads[k], b.grads[k]) \ + or torch.equal(a.grads[k], b.grads[k] * 2) \ + or torch.equal(a.grads[k], b.grads[k] / 2) + # in the same shard, weights are the same + for k in a.weights.keys(): # weights + assert torch.equal(a.weights[k], b.weights[k]) From 06ec6591c8ddffc95ee126dd3d3b7b4c55b870a4 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 22 Feb 2024 02:29:10 +0000 Subject: [PATCH 1594/1892] Merged PR 2042: Support Flexible Merging Checkpoints 1. Add richer information in gencode with original tensor name, full tensor shape of the weight 2. Remove reliance of `dist_param_map` in merging full model states 3. Make merging full model dict states standalone 4. Fix bugs in merging model states and optimizer state dicts (for pipeline cases) --- cube/codegen/emit.py | 3 +- cube/codegen/module/module.py | 12 +- cube/graph/parser/frame.py | 2 +- cube/graph/parser/fx/parser.py | 4 +- cube/runtime/module.py | 469 +++++++++++++++-------------- tests/runtime/test_module_merge.py | 167 ++++++++++ 6 files changed, 426 insertions(+), 231 deletions(-) create mode 100644 tests/runtime/test_module_merge.py diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index 08427ff9..cb6de746 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -44,8 +44,7 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: """ if isinstance(val, IRObject): tensor_name = val.name - if '.' in tensor_name: - tensor_name = tensor_name.split('.')[0] + tensor_name = tensor_name.replace('.', '_') name = '_'.join([tensor_name, str(val.tid)]) if prefix_attr is not None and val.is_attr(): name = prefix_attr + name diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index dad58d81..2a64969f 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -598,7 +598,7 @@ def init_attributes(self, node: IRCell): """ psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}))" - map_sign = "self.add_full_map('{attr}', {tid}, {slicers}, {val_chunks})" + map_sign = "self.add_full_map('{attr}', {tid}, {is_param}, '{orig_name}', {full_shape}, {slicers}, {val_chunks})" if not isinstance(node, IRSegment): for itensor in node.inputs(): name = self.tensor_name(itensor, prefix_attr='self.') @@ -612,12 +612,16 @@ def init_attributes(self, node: IRCell): dtype=itensor.dtype ) self.model_init_statements.append(code) - tid = itensor.parent.tid slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) val_chunks = itensor.valmap[1] code = map_sign.format( - attr=self.tensor_name(itensor), tid=tid, - slicers=str(slicers), val_chunks=val_chunks + attr=self.tensor_name(itensor), + tid=itensor.parent.tid, + is_param=itensor.is_param(), + orig_name=itensor.parent.name, + full_shape=tuple(itensor.parent.shape), + slicers=str(slicers), + val_chunks=val_chunks ) self.model_init_statements.append(code) self.model_init_statements.append('') diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index ce5f1c37..7898d965 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -153,7 +153,7 @@ def save_attr_map(self, save_file: str = 'dist_param_map.pt'): """ Save local_param -> origin_param name map. """ - ir_name_to_orig_name = {str(t.name): name for t, (name, _) in self._attr_map.items()} + ir_name_to_orig_name = {str(t.name).replace('.', '_'): name for t, (name, _) in self._attr_map.items()} torch.save(ir_name_to_orig_name, save_file) def push_param(self, var_name): diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 9c029297..80d5a2c8 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -137,7 +137,7 @@ def meta2var(meta: Any) -> Any: # TODO: data type check, with cases like {'a': 1.2, 'b': torch.Tensor} return IRObject(name=node.name, value=meta, is_constant=is_constant) - if hasattr(node, 'meta') and node.meta.get('tensor_meta'): + if hasattr(node, 'meta') and 'tensor_meta' in node.meta: meta = node.meta['tensor_meta'] val = meta2var(meta) else: @@ -273,6 +273,8 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, # the case that the parameter is the first time used by getattr if not exist_tensor: tensor = frame.get_var(node.name) + # set tensor name same with the name in original model + tensor.name = node.target if tensor.requires_grad: tensor.as_param() else: diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 6d27c607..0eb1308b 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,4 +1,4 @@ -from typing import List, Set, Dict, Tuple, Optional, TYPE_CHECKING +from typing import List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any import logging import os import sys @@ -12,6 +12,7 @@ from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer from cube.runtime.gnorm import ParamsInfo +from cube.flags import CompileFlag if TYPE_CHECKING: from cube.parallel import ComputeConfig @@ -20,6 +21,22 @@ _logger = logging.getLogger(__name__) +@dataclass +class AttrMeta: + # full tensor ID + tid: int + # is this a parameter + is_param: bool + # original name in the module + orig_name: str + # shape of the full tensor + shape: Tuple[int, ...] + # list of slicers to index the full tensor + slicers: Tuple[slice, ...] + # the number of the partitioned values, usually 1 + # (i.e., no partition on value -> no need to sum up) + val_chunks: int + class CubeModule(torch.nn.Module): """ The module is responsible for parameter synchronization @@ -29,13 +46,9 @@ class CubeModule(torch.nn.Module): def __init__(self): super().__init__() self._reducers: List[Reducer] = list() - # self._fullmap contains the mapping of local attribute tensors to its fulltensor - # name (from named_parameters or named_buffers) -> ( - # fulltensor.tid, - # index position of its fulltensor, - # value partition num_chunks, - # ) - self._fullmap : Dict[str, Tuple[int, Tuple[slice], int]] = dict() + # self._fullmap is mapping from the name of local attribute tensor + # to its corresponding fulltensor meta + self._fullmap : Dict[str, AttrMeta] = dict() @property def reducers(self): @@ -138,19 +151,22 @@ def gather_params(self): if reducer.zero: reducer.gather_params() - def add_full_map(self, attr: str, tid: int, slicers: Tuple[slice], val_chunks: int): - """ - Add an attribute map. - The mapping includes current attribute name (str) to logical tensor id, - and the mapping of logical tensor id including spatial (slice) and val chunks - - @param attr str: attribute name of this moudle - @param tid int: full tensor id - @param slicers Tuple[slice]: indexing from full tensor - @param val_chunks int: the number of value chunks. + def add_full_map(self, attr: str, tid: int, is_param: bool, orig_name: str, shape: Tuple[int], + slicers: Tuple[slice], val_chunks: int): + """Add an attribute map. + + Args: + attr (str): attribute name of this module + tid (int): full tensor id + is_param (bool): whether this attribute is a parameter, otherwise it is a buffer + orig_name (str): attribute name in the original module + shape (Tuple[int]): shape of the full tensor + slicers (Tuple[slicer]): indexing from full tensor + val_chunks int: the number of value chunks. """ assert hasattr(self, attr), f"{attr} is not in the module" - self._fullmap[attr] = (tid, slicers, val_chunks) + meta = AttrMeta(tid, is_param, orig_name, shape, slicers, val_chunks) + self._fullmap[attr] = meta # TODO: remove this function, use the property instead def get_full_map(self): @@ -170,7 +186,6 @@ def load_attr_content(self, filename: str): raise RuntimeError(f"Cannot find file {filename}.0 in load_attr_content") with torch.no_grad(): _logger.info(f'loading partitioned model from {filename}, number of model parameter chunks: {npartitions}') - # self._fullmap attr_names = set(self._fullmap.keys()) for file_idx in range(npartitions): # part_model contains a subset of attributes, where each attribute is a fulltensor @@ -178,13 +193,13 @@ def load_attr_content(self, filename: str): part_model: Dict[int, torch.Tensor] = torch.load(filename + f'.{file_idx}') loaded_name = set() for attr_name in attr_names: - tid, slicers, val_nchunks = self._fullmap[attr_name] - if tid not in part_model: + meta = self._fullmap[attr_name] + if meta.tid not in part_model: continue attr = getattr(self, attr_name) - content = part_model[tid][slicers] - if val_nchunks != 1: - content = content / val_nchunks + content = part_model[meta.tid][meta.slicers] + if meta.val_chunks != 1: + content = content / meta.val_chunks attr.copy_(content) loaded_name.add(attr_name) for name in loaded_name: @@ -226,13 +241,76 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref }, filename) @staticmethod - def merge_partial_states(state_dicts, zero_idx_maps=None): + def merge_model_state_dicts(state_dicts: List[Dict], + fullmaps: List[Dict[str, AttrMeta]]): + """Merge model states from multiple shard into a single-model state. + + Note: + Users only need to provide as fewer local model states as necessary to + cover the full model state. + + Args: + state_dicts (List[Dict[str, torch.Tensor]]): per-rank local model state dict from model.state_dict() + fullmaps (List[Dict[str, AttrMeta]]): per-rank fullmap + + Returns: + full_state_dicts (List[Dict[str, torch.Tensor]]): Full model state dict """ - :param state_dicts: list of state_dict from different ranks - state_dict(model_state_dict, optimizer_state_dict, dist_param_map, param_area_map) - :return: merged state_dict(model_state_dict, optimizer_state_dict,) + if len(state_dicts) != len(fullmaps): + raise ValueError("Expected model state dicts to have the same length as fullmaps") + + full_model_state_dict: Dict[str, torch.Tensor] = {} + # gather param/buffer full tensor + for model_state_dict, local_fullmap in zip(state_dicts, fullmaps): + for local_name, meta in local_fullmap.items(): + # create full tensor on cpu + partial_tensor = model_state_dict[local_name] + if meta.orig_name not in full_model_state_dict: + full_model_state_dict[meta.orig_name] = torch.empty( + meta.shape, dtype=partial_tensor.dtype) + # assign partial tensor + if meta.val_chunks > 1: + raise NotImplementedError("Not support of partitioning parameter / buffer at value dimension") + full_model_state_dict[meta.orig_name][meta.slicers] = partial_tensor + return full_model_state_dict + + @staticmethod + def merge_partial_states(state_dicts: List, + zero_idx_maps=None): + """Merge model and optimizer states from different shard into a single-model state. + + Warnings: + * This function only supports merging optimizer states of Adam-like optimizers, + in which the optimizer state is expected to contain 'state' keyword. + * Only support single parameter group, i.e., code implementations like: `torch.optim.Adam(model.parameters(), lr=0.1)` + + Args: + state_dicts (List[(Dict, Dict, Dict, Dict)]): per-rank states containing: + * model_state_dicts (List[Dict[str, torch.Tensor]]): per-rank model state dict from model.state_dict() + * optim_state_dicts (Optional[List[Dict]]): per-rank optimizer state dict from optimizer.state_dict() + * dist_param_map: deprecated, will be removed in the future. + * fullmaps (List[Dict[str, AttrMeta]]): per-rank fullmap + zero_idx_maps (Optional[List[Dict]]) + + Returns: + Dict[str, torch.Tensor]: Full model state dict + Dict[str, Dict[str, torch.Tensor]]: Full optimizer state dict """ - assert len(state_dicts) > 0 + model_state_dicts = [states[0] for states in state_dicts] + optim_state_dicts = [states[1] for states in state_dicts] + fullmaps: List[Dict[str, AttrMeta]] = [states[-1] for states in state_dicts] + + if len(model_state_dicts) != len(fullmaps): + raise ValueError("Expected model state dicts to have the same length as fullmaps") + if optim_state_dicts is not None: + if len(optim_state_dicts) != len(fullmaps): + raise ValueError("Expected optimizer state dicts to have the same length as fullmaps") + + # gather model states + full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps) + + # gather optimizer states + full_optim_state_dict: Dict[str, Any] = {} # param_id -> Dict[state_name, value] plan_ngpus = -1 # TODO: remove this flag @@ -244,203 +322,148 @@ def merge_partial_states(state_dicts, zero_idx_maps=None): _logger.info(f'plan_ngpus = {plan_ngpus}') # at first, merge the partitioned optimizer states due to zero to the zero-disabled format - if zero_idx_maps is not None: - if bool(int(os.environ.get('USE_ZERO', default=0))): - def _check_state_size(opt_state_keys, bucket_state): - if len(opt_state_keys) <= 1: - return True - return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape - for key in opt_state_keys) - - def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): - assert bucket_size % len(bucket_states) == 0 - opt_state_keys = list(bucket_states[0].keys()) - if 'step' in bucket_states[0]: - opt_state_keys.remove('step') - assert _check_state_size(opt_state_keys, bucket_states[0]), f'the keys {opt_state_keys} have different shape' - # NOTE: only support adam for now - assert 'exp_avg' in opt_state_keys - assert 'exp_avg_sq' in opt_state_keys - chunk_size = bucket_size // len(bucket_states) - start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size - end_rank_id, end_offset = pend // chunk_size, pend % chunk_size - opt_states, opt_states_1d = {}, {} + if CompileFlag.use_zero: + if zero_idx_maps is None: + raise ValueError(f"Detected zero optimization enabled, " + f"expected zero_idx_maps for merging.") + def _check_state_size(opt_state_keys, bucket_state): + """ + Check that all the keys except the scalar step for a + parameter in optimizer states have the same shaped tensor. + + For example, exp_avg, exp_avg_sq in Adam. + """ + if len(opt_state_keys) <= 1: + return True + return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape + for key in opt_state_keys) + + def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): + assert bucket_size % len(bucket_states) == 0 + opt_state_keys = list(bucket_states[0].keys()) + if 'step' in bucket_states[0]: + opt_state_keys.remove('step') + assert _check_state_size(opt_state_keys, bucket_states[0]), f'the keys {opt_state_keys} have different shape' + # NOTE: only support adam for now + assert 'exp_avg' in opt_state_keys + assert 'exp_avg_sq' in opt_state_keys + chunk_size = bucket_size // len(bucket_states) + start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size + end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + opt_states, opt_states_1d = {}, {} + for key in opt_state_keys: + opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, + device=bucket_states[0][key].device, requires_grad=False) + opt_states_1d[key] = opt_states[key].view(-1) + + if start_rank_id == end_rank_id: for key in opt_state_keys: - opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, - device=bucket_states[0][key].device, requires_grad=False) - opt_states_1d[key] = opt_states[key].view(-1) - - if start_rank_id == end_rank_id: + opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] + else: + offset = chunk_size-start_offset + for key in opt_state_keys: + opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] + for i in range(start_rank_id+1, end_rank_id): for key in opt_state_keys: - opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] + opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] + offset += chunk_size + for key in opt_state_keys: + opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + + if 'step' in bucket_states[0]: + opt_states['step'] = bucket_states[0]['step'] + return opt_states + + opt_state_list = [] + worker_cnt = len(state_dicts) + for work_idx in (range(worker_cnt) if plan_ngpus < 0 else range(plan_ngpus)): + model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[work_idx] + opt_state = {} + for model_idx, opt_idx in model_idx2opt_idx.items(): + if isinstance(opt_idx, int): + # the param without reducer + assert opt_idx2ranks[opt_idx] is None + # state_dicts [worker idx][opt state]['state'][param idx] + opt_state[model_idx] = state_dicts[work_idx][1]['state'][opt_idx] else: - offset = chunk_size-start_offset - for key in opt_state_keys: - opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] - for i in range(start_rank_id+1, end_rank_id): - for key in opt_state_keys: - opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] - offset += chunk_size - for key in opt_state_keys: - opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] - - if 'step' in bucket_states[0]: - opt_states['step'] = bucket_states[0]['step'] - return opt_states - - opt_state_list = [] - worker_cnt = len(state_dicts) - for work_idx in (range(worker_cnt) if plan_ngpus < 0 else range(plan_ngpus)): - model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[work_idx] - opt_state = {} - for model_idx, opt_idx in model_idx2opt_idx.items(): - if isinstance(opt_idx, int): - # the param without reducer - assert opt_idx2ranks[opt_idx] is None - # state_dicts [worker idx][opt state]['state'][param idx] - opt_state[model_idx] = state_dicts[work_idx][1]['state'][opt_idx] + # the param in reducer bucket + opt_idx, pstart, pend, pshape = opt_idx + ranks, bucket_size = opt_idx2ranks[opt_idx] + bucket_states = [state_dicts[rank][1]['state'][opt_idx] for rank in ranks] + opt_state[model_idx] = _retrieve_param_opt_state( + bucket_states, + pstart, + pend, + pshape, + bucket_size) + opt_state_list.append(opt_state) + assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' + + # build parameter order to match with the optimizer state order + # NOTE: the param IDs in optimizer typically follow the same order of + # local `model.parameters()`. However, `state_dict.keys()` contains + # both parameters and buffers, we need to remove the buffers from the list. + # More details refer to the implementation: + # https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module._save_to_state_dict + origin_parameter_names: List[str] = [] + for local_fullmap in fullmaps: + for _, meta in local_fullmap.items(): + if not meta.is_param: continue + # shared parameters in CubeModule is already de-duplicated. So in the + # local model state, we will not have multiple parameters sharing with same content + # but in different names. + if meta.orig_name not in origin_parameter_names: + origin_parameter_names.append(meta.orig_name) + + # handle 'state' in optimizer state dict + # NOTE: each rank may have its local optimizer state working on a sub-set + # of parameters of the full model. So the param IDs in each local optimizer + # state is a sub-sequence of global parameter ordering. + + # we follow the order of in origin parameter names to find each (partitioned) + # parameter in the local model state, and assign the slice to the position. + full_optim_state_dict['state'] = {} + full_states = full_optim_state_dict['state'] + # full_index: param IDs in the full optimizer state + for full_index, param_name in enumerate(origin_parameter_names): + for optim_state, fullmap in zip(optim_state_dicts, fullmaps): + if 'state' not in optim_state: continue + # adam-like optimizers have optim_state['state']={} before any optimizer.step() + if not optim_state['state']: continue + # filter out non-param attributes as they don't appear in the optimizer state + param_fullmap = [meta for meta in fullmap.values() if meta.is_param] + # local index: param IDs in the local optimizer state, we assume + # it aligns with the order of local `model.parameters()` + for local_index, meta in enumerate(param_fullmap): + if meta.orig_name != param_name: continue + full_states.setdefault(full_index, {}) + # TODO: support customized param groups, where each parameter has IDs + # specified from its own param_group + states: Dict[str, torch.Tensor] = optim_state['state'][local_index] + for state_name in states.keys(): + value = states[state_name] + # special handle for step: scalar tensor type + if state_name == 'step': + full_states[full_index][state_name] = value + continue + # for non-tensor states + if not isinstance(value, torch.Tensor): + full_states[full_index][state_name] = value + # for tensor states, like 'exp_avg' else: - # the param in reducer bucket - opt_idx, pstart, pend, pshape = opt_idx - ranks, bucket_size = opt_idx2ranks[opt_idx] - bucket_states = [state_dicts[rank][1]['state'][opt_idx] for rank in ranks] - opt_state[model_idx] = _retrieve_param_opt_state( - bucket_states, - pstart, - pend, - pshape, - bucket_size) - opt_state_list.append(opt_state) - assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' - else: - if plan_ngpus > 0: - _logger.warning(f'plan_ngpus {plan_ngpus} not handled USE_ZERO == False') - def _check_opt_state(opt_state): - cnt = 0 - sorted_opt_state = {} - for idx in sorted(opt_state.keys()): - assert cnt == idx, f'opt state error: {idx} vs {cnt}, in {opt_state.keys()}' - sorted_opt_state[idx] = opt_state[idx] - cnt += 1 - return sorted_opt_state - optimizer_state_dict = {} - worker_cnt = len(state_dicts) - opt_state_list = [] - for work_idx in range(worker_cnt): - zero_idx2model_idx, model_idx2zero_idx, zero_rank_groups = zero_idx_maps[work_idx] - opt_state = {} - # first place local opt state to right index - if len(zero_idx2model_idx) == 0: - assert len(state_dicts[work_idx][1]['state']) == 0 - for local_idx, val in state_dicts[work_idx][1]['state'].items(): # worker / last_optimizer_state / state - global_idx = zero_idx2model_idx[local_idx] - assert global_idx not in opt_state - opt_state[global_idx] = val - # for each rank group, copy opt state from other buckets - for rank_group, param_idx_buckets in zero_rank_groups.items(): - for bucket_idx, rank in enumerate(rank_group): - if rank == work_idx: continue - for global_idx in param_idx_buckets[bucket_idx]: - other_local_idx = zero_idx_maps[rank][1][global_idx] # rank / model_idx2zero_idx / global_idx - assert global_idx not in opt_state - opt_state[global_idx] = state_dicts[rank][1]['state'][other_local_idx] # worker / last_optimizer_state / state / local idx - opt_state = _check_opt_state(opt_state) - opt_state_list.append(opt_state) - assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' - # assign opt_state to state_dicts, cannot be assigned in the above loop - opt_state_len = len(opt_state_list[0]) - for work_idx in (range(worker_cnt) if plan_ngpus < 0 else range(plan_ngpus)): - state_dicts[work_idx][1]['state'] = opt_state_list[work_idx] - state_dicts[work_idx][1]['param_groups'][0]['params'] = sorted(opt_state_list[work_idx].keys()) - assert len(opt_state_list[work_idx]) == opt_state_len - - # find tensor full shape - param_max_dimsize = {} - if plan_ngpus > 0: - state_dicts = state_dicts[0:plan_ngpus] - for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: - for param_area in param_area_map.items(): - local_name = param_area[0][0:param_area[0].rfind('_')] - assert len(local_name) > 0 - raw_name = dist_param_map[local_name] - slices = param_area[1][1] - if param_area[1][2] != 1: - _logger.error(f'value-split on {raw_name} is not supported') - if raw_name in param_max_dimsize: - param_max_dimsize[raw_name] = max(param_max_dimsize[raw_name], slices) - else: - param_max_dimsize[raw_name] = slices - - # create full tensors - param_full_tensors = {} - sample_step = -1 - optim_full_tensors: Dict[int, Dict[any, any]] = {} # param_id, (state_name, state_val) - for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: - if len(optimizer_state_dict['state'].items()) > 0: - optimizer_state_names = list(optimizer_state_dict['state'][0].keys()) - _logger.info(f'optimizer_state_names = {optimizer_state_names}') - if 'step' in optimizer_state_names: - sample_step = optimizer_state_dict['state'][0]['step'] - optimizer_state_names.remove('step') - _logger.info(f'optimizer_state_names (without step) = {optimizer_state_names}') - else: - optimizer_state_names = [] - - other_optim_keys = [key for key in optimizer_state_dict.keys() if key != 'state'] - optimizer_other_state_dict = {} - for key in other_optim_keys: - optimizer_other_state_dict[key] = optimizer_state_dict[key] - - # for raw_name in param_max_dimsize.keys(): - model_state_dict_keys = list(model_state_dict.keys()) - for param_area in param_area_map.items(): - local_name_with_id = param_area[0] - local_name = local_name_with_id[0:local_name_with_id.rfind('_')] - raw_name = dist_param_map[local_name] - - tensor_size_slice = param_max_dimsize[raw_name] - tensor_size = [] - for dim_slice in tensor_size_slice: - tensor_size.append(dim_slice.stop) - partial_tensor = model_state_dict[local_name_with_id] - param_full_tensors[raw_name] = torch.zeros(tuple(tensor_size), dtype=partial_tensor.dtype) - - index = model_state_dict_keys.index(local_name_with_id) - if index in optimizer_state_dict['state']: - for state_name in optimizer_state_names: # 'step' - if index not in optim_full_tensors: - optim_full_tensors[index] = {} - optim_full_tensors[index][state_name] = torch.zeros(tuple(tensor_size)) - else: - _logger.info(f'merge_checkpoint skips {local_name_with_id}\'s optimizer state') - break # only create once - - # assign value - for model_state_dict, optimizer_state_dict, dist_param_map, param_area_map in state_dicts: - model_state_dict_keys = list(model_state_dict.keys()) - for param_area in param_area_map.items(): - local_name_with_id = param_area[0] - local_name = local_name_with_id[0:local_name_with_id.rfind('_')] - raw_name = dist_param_map[local_name] - slices = param_area[1][1] - partial_tensor = model_state_dict[local_name_with_id] - param_full_tensors[raw_name][slices] = partial_tensor - - index = model_state_dict_keys.index(local_name_with_id) - if index in optimizer_state_dict['state']: - states = optimizer_state_dict['state'][index] - for name in optimizer_state_names: - val = states[name] - optim_full_tensors[index][name][slices] = val - if sample_step > 0: - optim_full_tensors[index]['step'] = sample_step - - # print(f'param_full_tensors (assigned) = {param_full_tensors}') - # print(f'optim_full_tensors (assigned) = {optim_full_tensors}') - - optimizer_other_state_dict.update({'state': optim_full_tensors}) - # dump to ckpt - return param_full_tensors, optimizer_other_state_dict + # create optimizer state tensor + if state_name not in full_states[full_index]: + full_states[full_index][state_name] = torch.empty(meta.shape, dtype=value.dtype) + # assign with partial tensor + full_states[full_index][state_name][meta.slicers] = value + + # handle additional state dict keys + for optim_state_dict in optim_state_dicts: + for key in optim_state_dict.keys(): + if key != 'state': + full_optim_state_dict[key] = optim_state_dict[key] + + return full_model_state_dict, full_optim_state_dict @staticmethod def merge_checkpoints(filename_prefix='dist_checkpoint'): diff --git a/tests/runtime/test_module_merge.py b/tests/runtime/test_module_merge.py new file mode 100644 index 00000000..ec0c9169 --- /dev/null +++ b/tests/runtime/test_module_merge.py @@ -0,0 +1,167 @@ +import torch +import cube + +from functools import partial + +from cube.ir.operator import IRFwOperation +from cube.runtime.device import DeviceGroup +from ..launch_torchrun import torchrun +import tempfile + + +class Module(torch.nn.Module): + def __init__(self): + super(Module, self).__init__() + + self.register_buffer('buffer0', torch.randn(8, 8)) + self.param0 = torch.nn.Parameter(torch.randn(8, 8)) + self.param1 = torch.nn.Parameter(torch.randn(8, 8)) + self.register_buffer('buffer1', torch.randn(8, 8)) + self.param2 = torch.nn.Parameter(torch.randn(8, 8)) + + def forward(self, x): + x = x * self.param0 + x = x + self.buffer0 + x = x * self.param1 + x = x + self.buffer1 + x = x * self.param2 + return torch.sum(x) + +def tp_policy(graph, resource): + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if node.name == 'add': + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=1, dim=idx % 2, num=resource.ngpus) + else: + sub_nodes = graph.replicate(node, times=resource.ngpus) + for devid, node in enumerate(sub_nodes): + graph.assign(node, devid) + return graph + + +def assert_same_state(origin, merged): + assert set(origin.keys()) == set(merged.keys()), \ + f"state keys are not equal: origin: {origin.keys()}, merged: {merged.keys()}" + for name in origin.keys(): + if isinstance(origin[name], dict): + assert_same_state(origin[name], merged[name]) + elif isinstance(origin[name], torch.Tensor): + assert torch.equal(origin[name].cpu(), merged[name].cpu()), \ + f"state {name} is not equal: origin:\n{origin[name]}\nmerged:\n{merged[name]}" + else: + assert origin[name] == merged[name], \ + f"state {name} is not equal: origin:\n{origin[name]}\nmerged:\n{merged[name]}" + + +def merge_model_states_test(): + cube.init() + + model = Module() + sample = torch.randn(8, 8, device=torch.cuda.current_device()) + + full_model_state = model.state_dict() + + @cube.compile(model, sample, PAS=tp_policy) + def train_iter(model, sample): + loss = model(sample) + loss.backward() + return loss + cube_model = cube.load_model() + + state_dict = cube_model.state_dict() + torch.save({'state_dict': state_dict, 'fullmap': cube_model.fullmap}, + f'checkpoint-shard{DeviceGroup().rank}.pt') + torch.distributed.barrier() + if DeviceGroup().rank == 0: + model_states = [] + fullmaps = [] + for i in range(DeviceGroup().world_size): + checkpoint = torch.load(f'checkpoint-shard{i}.pt') + model_states.append(checkpoint['state_dict']) + fullmaps.append(checkpoint['fullmap']) + merged_state_dict = cube_model.merge_model_state_dicts(model_states, fullmaps) + assert_same_state(full_model_state, merged_state_dict) + + +test_merge_model_states = partial(torchrun, 2, merge_model_states_test) + + +def merge_optimizer_states_test(): + cube.init() + + torch.manual_seed(0) + model = Module().cuda() + full_optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + sample = torch.randn(8, 8, device=torch.cuda.current_device()) + + full_model_state = model.state_dict() + full_optim_state = full_optimizer.state_dict() + + @cube.compile(model, sample, PAS=tp_policy) + def train_iter(model, sample): + loss = model(sample) + loss.backward() + return loss + + cube_model = cube.load_model() + optimizer = torch.optim.Adam(cube_model.parameters(), lr=0.01) + + # test for initial state + model_state_dict = cube_model.state_dict() + optim_state_dict = optimizer.state_dict() + states = { + 'model': model_state_dict, + 'optimizer': optim_state_dict, + 'fullmap': cube_model.fullmap + } + torch.save(states, f'checkpoint-shard{DeviceGroup().rank}.pt') + torch.distributed.barrier() + + if DeviceGroup().rank == 0: + states = [] + for i in range(DeviceGroup().world_size): + checkpoint = torch.load(f'checkpoint-shard{i}.pt') + states.append((checkpoint['model'], checkpoint['optimizer'], checkpoint['fullmap'])) + merged_model_states, merged_optim_states = cube_model.merge_partial_states(states) + assert_same_state(full_model_state, merged_model_states) + assert_same_state(full_optim_state, merged_optim_states) + torch.distributed.barrier() + + # test after training + + for _ in range(2): + # full model + loss = model(sample) + loss.backward() + full_optimizer.step() + full_optimizer.zero_grad() + + # cube model + loss = train_iter(cube_model, sample) + optimizer.step() + optimizer.zero_grad() + + model_state_dict = cube_model.state_dict() + optim_state_dict = optimizer.state_dict() + states = { + 'model': model_state_dict, + 'optimizer': optim_state_dict, + 'fullmap': cube_model.fullmap + } + + torch.save(states, f'checkpoint-shard{DeviceGroup().rank}.pt') + torch.distributed.barrier() + + full_model_state = model.state_dict() + full_optim_state = full_optimizer.state_dict() + + if DeviceGroup().rank == 0: + states = [] + for i in range(DeviceGroup().world_size): + checkpoint = torch.load(f'checkpoint-shard{i}.pt') + states.append((checkpoint['model'], checkpoint['optimizer'], checkpoint['fullmap'])) + merged_model_states, merged_optim_states = cube_model.merge_partial_states(states) + assert_same_state(full_model_state, merged_model_states) + assert_same_state(full_optim_state, merged_optim_states) + +test_merge_optim_states = partial(torchrun, 2, merge_optimizer_states_test) From 3b5566f6971b40cd3f09d8e7575c0e7f2c9aa704 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 23 Feb 2024 06:44:23 +0000 Subject: [PATCH 1595/1892] Merged PR 2046: support dependency track for IRObjects support dependency track for IRObjects --- cube/compiler.py | 2 +- cube/execplan/execplan.py | 44 +++++++++++++++++++++++++++------------ cube/graph/graph.py | 9 ++++---- cube/program.py | 41 ++++++++++++++++++++++++------------ 4 files changed, 64 insertions(+), 32 deletions(-) diff --git a/cube/compiler.py b/cube/compiler.py index e9f5555d..563966bb 100644 --- a/cube/compiler.py +++ b/cube/compiler.py @@ -149,7 +149,7 @@ def decorator(fn: Callable) -> Callable: if isinstance(input, SemanticModel): pinputs.append('model') elif isinstance(input, SemanticDataLoader): - pinputs.append(input.object) + pinputs.append(input.irobj) else: pinputs.append(input) Program().set_input(pinputs) diff --git a/cube/execplan/execplan.py b/cube/execplan/execplan.py index 2b802b2a..7aed4990 100644 --- a/cube/execplan/execplan.py +++ b/cube/execplan/execplan.py @@ -3,7 +3,7 @@ import numpy as np import sys -from cube.ir.cten import IRCell +from cube.ir.cten import IRCell, IRObject from cube.ir.tensor import IRSubTensor, IRFullTensor from cube.ir.adapter import IRAdapter, IRWeightReducer from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation @@ -88,20 +88,38 @@ def from_graph(graph: IRGraph): @staticmethod def from_schedplan(schedplan: SchedulePlan): + """Create execution plan from SchedulePlan + + A schedule plan has multiple micro-batches, where each micro-batch + goes through the all operators in the model graph. So an operator + will be executed multiple times with different data from different micro-batches. + + The IRGraph only contains operators / IRTensors / IRObjects of one micro-batch. + To represent data of a different micro-batch, we need to map the data in IRGraph to a + new one with different IDs. """ - Create execution plan from SchedulePlan - """ - micro_ftensors: Dict[int, Dict[IRFullTensor, IRFullTensor]] = {} - def get(tensor: IRSubTensor, micro_idx: int) -> IRSubTensor: - """Get a same-shape tensor for micro-batch index""" - if not isinstance(tensor, IRSubTensor): return tensor + graph_inputs = schedplan.graph.inputs() + micro_objs: Dict[int, Dict[IRObject, IRObject]] = {} + def get(tensor: IRObject, micro_idx: int) -> IRObject: + """Get an IRObject same to tensor, but with different tid for each given micro-batch index""" + if not isinstance(tensor, IRObject): return tensor + # NOTE: the graph inputs (e.g., dataloader) serves as the global variables during the + # execution of schedules, where every micro-batch shares the same one for execution. + # Typically, the graph inputs can be dataloader object + if tensor in graph_inputs: return tensor if micro_idx == 0: return tensor - ftensor = micro_ftensors.setdefault(micro_idx, {}).setdefault(tensor.parent, tensor.parent.like()) - t = ftensor.select(tensor.indmap, tensor.valmap) - if tensor.grad is not None: - fgrad: IRFullTensor = ftensor.grad - micro_ftensors.setdefault(micro_idx, {}).setdefault(tensor.parent.grad, fgrad) - t.grad = fgrad.select(tensor.grad.indmap, tensor.grad.valmap) + if not isinstance(tensor, IRSubTensor): + # IRObject but not IRSubTensor + micro_objs.setdefault(micro_idx, {}).setdefault(tensor, IRObject(tensor.name, value=tensor.value)) + t = micro_objs[micro_idx][tensor] + else: + # IRSubTensor + ftensor = micro_objs.setdefault(micro_idx, {}).setdefault(tensor.parent, tensor.parent.like()) + t = ftensor.select(tensor.indmap, tensor.valmap) + if tensor.grad is not None: + fgrad: IRFullTensor = ftensor.grad + micro_objs.setdefault(micro_idx, {}).setdefault(tensor.parent.grad, fgrad) + t.grad = fgrad.select(tensor.grad.indmap, tensor.grad.valmap) return t micro_fcells: Dict[(int, IRCell), ExeReuseCell] = {} diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 710c0984..df832f88 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -639,11 +639,10 @@ def depends(self, pre_node: IRCell, succ_node: IRCell) -> bool: Returns: ret (bool): True if post_node depends on pre_node on dataflow, otherwise False. """ - itensors = [t for t in succ_node.inputs() if isinstance(t, IRSubTensor)] - for otensor in pre_node.outputs(): - if not isinstance(otensor, IRSubTensor): continue - for itensor in itensors: - if otensor.overlap(itensor): + input_objs = IRSegment.get_objects_from_complex(succ_node.inputs()) + for out_obj in IRSegment.get_objects_from_complex(pre_node.outputs()): + for in_obj in input_objs: + if out_obj.overlap(in_obj): return True return False diff --git a/cube/program.py b/cube/program.py index 81f6a1df..b8646bcd 100644 --- a/cube/program.py +++ b/cube/program.py @@ -84,9 +84,13 @@ def __repr__(self): class SemanticDataLoader: def __init__(self, dataloader: MicroBatchDataLoader): - """ - Create semantic dataloader which will produces IRDataOperation - when calling `next`. + """Create semantic dataloader representing the dataloader in training iteration. + + Calling `next(SemanticDataLoader)` will generate an IRDataOperation in graph, + which takes the `self.irobj` (i.e., reperesenting the non-tensor value of real + dataloader instance) as input and produces outputs that are converted to + IRObject or IRTensor. The IRDataOperation will be added to the final + graph and generate code like `data = next(dataloader)` Args: dataloader (MicroBatchDataLoader): torch dataloader @@ -94,31 +98,42 @@ def __init__(self, dataloader: MicroBatchDataLoader): if not isinstance(dataloader, MicroBatchDataLoader): raise TypeError("Expected data loader to be MicroBatchDataLoader") self.dataloader: data.DataLoader = dataloader - self.object = IRObject(name='dataloader', value=None) + # the IRObject representing the `dataloader` instance, which is only used by the + # IRDataOperation. Since we already know the output of the dataloader, + # we don't need to set the value for it. + self.irobj = IRObject(name='dataloader', value=None) def __iter__(self): return self def __next__(self): - def generate_output(sample): - """Support complex of types: List, Tuple, torch.Tensor, object""" + def generate_output(sample, name='data'): + """Support complex of types: Tuple, List, Dict, torch.Tensor""" if isinstance(sample, tuple): - return tuple(generate_output(t) for t in sample) + return tuple(generate_output(t, name) for t in sample) if isinstance(sample, list): - return list(generate_output(t) for t in sample) + return list(generate_output(t, name) for t in sample) + if isinstance(sample, dict): + return {k: generate_output(v, str(k)) for k, v in sample.items()} if isinstance(sample, torch.Tensor): - tensor = IRFullTensor(list(sample.shape), 'data', dtype=sample.dtype).tosub() + tensor = IRFullTensor(list(sample.shape), name, dtype=sample.dtype).tosub() tensor._value = sample return tensor - return IRObject('data', value=sample) + return IRObject(name, value=sample, is_constant=False) # get dataloader sample sample = next(iter(self.dataloader)) + if not isinstance(sample, tuple): + sample = (sample,) # turn sample into IRObjects - outputs = generate_output(sample) + outputs = tuple(generate_output(s) for s in sample) + outputs = tuple(IRObject('data', value=out) if not isinstance(out, IRObject) else out for out in outputs) # create dataloader operation - node_outputs = outputs if isinstance(outputs, (tuple, list)) else (outputs,) - data_op = IRDataOperation(self.object, node_outputs) + # the `self.irobj` is the IRObject standing for the non-tensor value of real dataloader. + # the `self.irobj` are also usually used as one input of the whole graph + data_op = IRDataOperation(self.irobj, outputs) Program().add_node(data_op) + # return the outputs in the same format with real dataloader + outputs = outputs[0] if len(outputs) == 1 else outputs return outputs From b1b7d3d808482c81e9f0f7bb296b11c51ccd7917 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 26 Feb 2024 01:28:54 +0000 Subject: [PATCH 1596/1892] Merged PR 2051: parallel module: checkpoint merging support (weights only) Optimizer state dict merging will be another PR. --- cube/graph/parser/frame.py | 26 ++- cube/parallel.py | 273 ++++++++++++++++++++++- cube/runtime/module.py | 135 +++++++++-- tests/parallel_module/test_checkpoint.py | 131 ++++++++--- tests/runtime/test_module_merge.py | 34 ++- 5 files changed, 516 insertions(+), 83 deletions(-) diff --git a/cube/graph/parser/frame.py b/cube/graph/parser/frame.py index 7898d965..1a9c8e7d 100644 --- a/cube/graph/parser/frame.py +++ b/cube/graph/parser/frame.py @@ -125,29 +125,35 @@ def get_attr_var(self, concrete_value: torch.Tensor) -> Optional[IRTensor]: return tensor return None - def save_attr_content(self, save_file: str): + def save_attr_content(self, save_file_stem: str, params_per_file: int = 1024 * 1024 * 1024): """ Save attribute content into file. + + Args: + save_file_stem (str): stem file name. Actual file name will be `save_file_stem`.0, `save_file_stem`.1, etc. + params_per_file (int): number of params per file,default is 1 billion + + Returns: + None """ #TODO: use FxModuleParser.ATTR_CONTENT_FILE_FORMAT to name the files. - params_per_part = 1024 * 1024 * 1024 # 1 billion per part total_size = sum([val.numel() for _, (_, val) in self._attr_map.items()]) - model_pt_part_num = (total_size + params_per_part - 1) // params_per_part + model_pt_part_num = (total_size + params_per_file - 1) // params_per_file tid2value = {t.tid: val.cpu() for t, (_, val) in self._attr_map.items()} # it can be zero if there is no param in the module (self._attr_map is empty) if model_pt_part_num <= 1: - torch.save(tid2value, f'{save_file}.0') + torch.save(tid2value, f'{save_file_stem}.0') else: - sorted_keys = sorted(list(tid2value.keys())) - assert len(sorted_keys) > 0, "Empty attr map" - chunk_size = (len(sorted_keys) + model_pt_part_num - 1) // model_pt_part_num - chunks = [sorted_keys[i:min(i + chunk_size, len(sorted_keys))] for i in - range(0, len(sorted_keys), chunk_size)] + tids = list(tid2value.keys()) + assert len(tids) > 0, "Empty attr map" + chunk_size = (len(tids) + model_pt_part_num - 1) // model_pt_part_num + chunks = [tids[i:min(i + chunk_size, len(tids))] for i in + range(0, len(tids), chunk_size)] for idx, chunk in enumerate(chunks): assert len(chunk) > 0, f"Empty chunk {idx}" part = {k: tid2value[k] for k in chunk} - torch.save(part, f'{save_file}.{idx}') + torch.save(part, f'{save_file_stem}.{idx}') def save_attr_map(self, save_file: str = 'dist_param_map.pt'): """ diff --git a/cube/parallel.py b/cube/parallel.py index bd73bffc..f8bf169e 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -6,7 +6,7 @@ import inspect import sys import importlib -from dataclasses import dataclass +from dataclasses import dataclass, asdict from contextlib import contextmanager import logging @@ -34,7 +34,7 @@ from cube.ir.unique import IDGenerator from cube.program import Program from cube.runtime.adapter.reducer import Reducer -from cube.runtime.module import CubeModule, ParallelModule, OriginModuleMetadata +from cube.runtime.module import CubeModule, ParallelModule, OriginModuleMetadata, ExtraState from cube.runtime.device import DeviceGroup from cube.runtime.gnorm import calcuate_gnorm, clip_grads @@ -187,6 +187,11 @@ def _clean_files(_dir: Path, pattern = '*') -> None: f.unlink() +def _int_dict_to_list(d: Dict[int, Any]) -> List[Any]: + """Convert a dict with int keys to a list""" + return [d[i] for i in range(len(d))] + + _DEFAULT_INSTANCE_NAME = '_' _GENCODE_FILE_PREFIX = 'gencode' _GENCODE_FILE_TEMPLATE = _GENCODE_FILE_PREFIX + '{}.py' # 'gencode{}.py' @@ -762,6 +767,26 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] """ ... +@dataclass(unsafe_hash=True) +class ModuleParameterLocation: + # the location of the parameters of a module in optimizer.param_groups[0]['params'] + # [offset, offset + count) is the range of the parameters in optimizer.param_groups[0]['params'] + offset: int + count: int + + +@dataclass +class OptimizerExtraState: + rank: int + name: str # the name of the optimizer + module_locs: Dict[str, ModuleParameterLocation] + + def __post_init__(self): + for k in self.module_locs: + if isinstance(self.module_locs[k], dict): + self.module_locs[k] = ModuleParameterLocation(**self.module_locs[k]) + + OptimizerT = TypeVar('OptimizerT', bound=torch.optim.Optimizer) @@ -837,13 +862,24 @@ def build_optimizer( non_parallel_module_reducer.add_param(param) non_parallel_module_reducer.build_buckets() + opt_module_locs: Dict[str, ModuleParameterLocation] = {} def _local_parameters(module: torch.nn.Module): + cube_suffix = "_CUBE_SUFFIX" gen = module._named_members( - lambda m: [(str(id(p)), p) for p in m.parameters_for_optimizer()] # (str(id(p)), p) to meet _named_members requirement + lambda m: [(cube_suffix, p) for p in m.parameters_for_optimizer()] # (cube_suffix, p) to meet _named_members requirement if isinstance(m, ParallelModule) else m._parameters.items() ) - for _, param in gen: + for idx, (name, param) in enumerate(gen): + if name.endswith(cube_suffix): # is a parameter of ParallelModule + # -1 for removing the dot + # please note when the whole module is a ParallModule, + # the name will be empty after removing the suffix + name = name[:-len(cube_suffix) - 1] + if name not in opt_module_locs: + opt_module_locs[name] = ModuleParameterLocation(idx, 1) + else: + opt_module_locs[name].count += 1 yield param optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), *args, **kwargs) @@ -860,13 +896,30 @@ def _step_post_hook(opt, *args, **kwargs): optimizer.register_step_post_hook(_step_post_hook) orig_zero_grad = optimizer.zero_grad - def _patched_zero_grad_hook(self, set_to_none: bool = True): + def _patched_zero_grad(self, set_to_none: bool = True): orig_zero_grad(set_to_none) for m in parallel_modules: m.zero_grad() if non_parallel_module_reducer: non_parallel_module_reducer.zero_grad() - optimizer.zero_grad = types.MethodType(_patched_zero_grad_hook, optimizer) + optimizer.zero_grad = types.MethodType(_patched_zero_grad, optimizer) + + orig_state_dict = optimizer.state_dict + def _patched_state_dict(self): + state_dict = orig_state_dict() + state_dict[ParallelModule.EXTRA_STATE_KEY] = asdict(OptimizerExtraState( + rank=torch.distributed.get_rank(), + name=type(optimizer).__name__, + module_locs=opt_module_locs, + )) + return state_dict + optimizer.state_dict = types.MethodType(_patched_state_dict, optimizer) + + orig_load_state_dict = optimizer.load_state_dict + def _patched_load_state_dict(self, state_dict): + state_dict.pop(ParallelModule.EXTRA_STATE_KEY, None) + orig_load_state_dict(state_dict) + optimizer.load_state_dict = types.MethodType(_patched_load_state_dict, optimizer) def _sync_shard_grad(self): with _runtime_flags(skip_reducer=False): @@ -934,3 +987,211 @@ def _register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None optimizer.register_reducer_post_hook = types.MethodType(_register_reducer_post_hook, optimizer) return optimizer + + +def _get_parallel_module_state_dict_info( + model_state_dicts: List[Dict[str, Any]] +) -> Tuple[ + Dict[Tuple[str, ...], List[ExtraState]], # parallel module extrastate for each rank + Dict[Tuple[str,...], List[Dict[str, Any]]], # parallel module state dict for each rank + Dict[str, Any] # non-parallel module state dict +]: + # parted key model state dicts + pk_model_state_dicts: List[Dict[Tuple[str,...], Any]] = [] + for model_state_dict in model_state_dicts: + pk_model_state_dicts.append({tuple(k.split('.')): v for k, v in model_state_dict.items()}) + + # find all parallel module state keys (whose key ends with ParallelModule.EXTRA_STATE_KEY) + # key: the module prefix + # value: the list of extra states from all ranks + pm_extra_states: Dict[Tuple[str, ...], List[ExtraState]] = {} + for pk_model_state_dict in pk_model_state_dicts: + for k in pk_model_state_dict: + if k[-1] == ParallelModule.EXTRA_STATE_KEY: + module_prefix = k[:-1] + if module_prefix not in pm_extra_states: + pm_extra_states[module_prefix] = [None] * len(pk_model_state_dicts) + opt_extra_state = ExtraState(**pk_model_state_dict[k]) + pm_extra_states[module_prefix][opt_extra_state.rank] = opt_extra_state + + # collect ParallelModule state dicts + # key is the module prefix of the parallel module in state dict + # value is the list of state dicts of the parallel module from all ranks + pm_state_dicts: Dict[Tuple[str,...], List[Dict[str, Any]]] = {} + # non-parallel module state dict + non_pm_state_dict: Dict[str, Any] = {} + for pk_model_state_dict in pk_model_state_dicts: + for k in pk_model_state_dict: + if k[-1] == ParallelModule.EXTRA_STATE_KEY: # skip extra state, we already have them + continue + module_prefix = k[:-1] + if module_prefix in pm_extra_states: + if module_prefix not in pm_state_dicts: + pm_state_dicts[module_prefix] = [dict() for _ in range(len(pk_model_state_dicts))] + opt_extra_state = ExtraState(**pk_model_state_dict[module_prefix + (ParallelModule.EXTRA_STATE_KEY,)]) + pm_state_dicts[module_prefix][opt_extra_state.rank][k[-1]] = pk_model_state_dict[k] + else: + # no further processing + # here we assume values from all ranks are the same + non_pm_state_dict['.'.join(k)] = pk_model_state_dict[k] + + return pm_extra_states, pm_state_dicts, non_pm_state_dict + + +def _get_optimizer_state_dict_info( + optimizer_state_dicts: List[Dict[str, Any]] +) -> Tuple[ + List[OptimizerExtraState], + Dict[str, # key: the module prefix + List[Dict[ # value: a list of dict from all ranks. The dict is + str, # key: the state key `state` (all other keys will be ignored.) + Dict[ # value: a dict which is the same with opt_state_dict['state'], it is: + int, # key: an integer representing the parameter index + Dict[str, Any] # value: a dict contains the parameter related info, the keys include 'step', 'exp_avg', 'exp_avg_sq'. + ] + ] + ] + ], + Dict[str, Any] +]: + ret_opt_state_dict = {'state': {}} + # collect optimizer state dicts + # merge ParallelModule state dicts + # here we only need to handle `state` key in the optimizer state dict + # all other keys will be copied to the final state dict + opt_extra_states: List[OptimizerExtraState] = [None] * len(optimizer_state_dicts) + opt_state_dicts: Dict[str, # key: the module prefix + List[Dict[ # value: a list of dict from all ranks. The dict is + str, # key: the state key `state` (all other keys will be ignored.) + Dict[ # value: a dict which is the same with opt_state_dict['state'], it is: + int, # key: an integer representing the parameter index + Dict[str, Any] # value: a dict contains the parameter related info, the keys include 'step', 'exp_avg', 'exp_avg_sq'. + ] + ] + ] + ] = {} + for opt_state_dict in optimizer_state_dicts: + opt_extra_state = OptimizerExtraState(**opt_state_dict[ParallelModule.EXTRA_STATE_KEY]) + if 'adam' not in opt_extra_state.name.lower(): + raise ValueError("Only Adam-like optimizers are supported.") + opt_extra_states[opt_extra_state.rank] = opt_extra_state + + for module_prefix, loc in opt_extra_state.module_locs.items(): + if module_prefix not in opt_state_dicts: + opt_state_dicts[module_prefix] = [dict(state=[], param_groups=[]) for _ in range(len(optimizer_state_dicts))] + for i in range(loc.offset, loc.offset + loc.count): + opt_state_dicts[module_prefix][opt_extra_state.rank]['state'].append(opt_state_dict['state'][i]) + # TODO: inaccurate param_groups, for example, the 'params' in it is not right. + # we have this to make `ParallelModule.merge_partial_states` happy. + opt_state_dicts[module_prefix][opt_extra_state.rank]['param_groups'] = opt_state_dict['param_groups'] + + for k, v in opt_state_dict.items(): + if k == ParallelModule.EXTRA_STATE_KEY or k == 'state': + continue + # no further processing + # here we assume values from all ranks are the same + ret_opt_state_dict[k] = v + + return opt_extra_states, opt_state_dicts, ret_opt_state_dict + + +def merge_state_dicts( + model_state_dicts: List[Dict[str, Any]], + optimizer_state_dicts: Optional[List[Dict[str, Any]]], +) -> Tuple[Dict[str, Any], Optional[List[Dict[str, Any]]]]: + """ + Merge a list of shard state dicts (one for each rank) to a single full state dict + Note: Only Adam-like optimizers are supported for merging + + Args: + model_state_dicts (List[Dict[str, Any]]): the model state dicts from each rank + optimizer_state_dicts (Optional[List[Dict[str, Any]]]): the optimizer state dicts from each rank + + Returns: + Tuple[Dict[str, Any], Optional[List[Dict[str, Any]]]]: the merged model state dict and the merged optimizer state dict + """ + if optimizer_state_dicts is not None: + # TODO: support checkpoint optimization + # where the following check may be too strong. + if len(model_state_dicts) != len(optimizer_state_dicts): + raise ValueError("The length of model_state_dicts and optimizer_state_dicts should be the same.") + if not model_state_dicts: + raise ValueError("model_state_dicts should not be empty.") + + pm_extra_states, pm_state_dicts, ret_state_dict = _get_parallel_module_state_dict_info(model_state_dicts) + if optimizer_state_dicts is not None: + opt_extra_states, opt_state_dicts, ret_opt_state_dict = _get_optimizer_state_dict_info(optimizer_state_dicts) + # the new optimizer state dict for ParallelModules + # key: the parallel module location in the optimizer state + # value: the new state values for the parallel module + # (index is the parameter index in parallel module) + new_pm_states: Dict[ModuleParameterLocation, List[Any]] = {} + else: + opt_extra_states, opt_state_dicts, ret_opt_state_dict, new_pm_states = None, None, None, None + + # do merging + # every loop will merge one ParallelModule + for k, state_dicts_for_merge in pm_state_dicts.items(): + extra_states = pm_extra_states[k] + module_prefix = '.'.join(k) + opt_state_dicts_for_merge = [{'state': {}} for _ in range(len(state_dicts_for_merge))] \ + if opt_state_dicts is None else opt_state_dicts[module_prefix] + + merge_partial_states_state_dicts = [] + merge_partial_states_zero_idx_maps = [] + for m, opt, extra in zip(state_dicts_for_merge, opt_state_dicts_for_merge, extra_states): + merge_partial_states_state_dicts.append((m, opt, extra.dist_param_map, extra.param_area_map)) + merge_partial_states_zero_idx_maps.append((extra.model_idx2opt_idx, extra.opt_idx2ranks)) + if not extra_states[0].compute_config.use_zero: # all ranks should have the same use_zero + merge_partial_states_zero_idx_maps = None + merged_state_dict, merged_opt_state_dict = ParallelModule.merge_partial_states( + merge_partial_states_state_dicts, + merge_partial_states_zero_idx_maps + ) + + # merge back module state dict + for km, vm in merged_state_dict.items(): + key = km if not module_prefix else f'{module_prefix}.{km}' + ret_state_dict[key] = vm + + # merge back opt state dict + if opt_state_dicts is not None: + opt_module_locs = [opt_extra_states[i].module_locs[module_prefix] for i in range(len(opt_extra_states))] + + # Assume all ranks have the same opt_module_locs (offset and count) + # TODO: assert may fail for pipeline parallelism + for i in range(1, len(opt_module_locs)): + assert opt_module_locs[i] == opt_module_locs[0] + new_pm_states[opt_module_locs[0]] = _int_dict_to_list(merged_opt_state_dict['state']) + + if new_pm_states: + ret_state_list: List[Any] = _int_dict_to_list(optimizer_state_dicts[0]['state']) + sorted_keys = sorted(new_pm_states.keys(), key=lambda x: x.offset, reverse=True) + for loc in sorted_keys: + # this assign is only safe in reverse order + ret_state_list[loc.offset:loc.offset + loc.count] = new_pm_states[loc] + ret_opt_state_dict['state'] = {i: v for i, v in enumerate(ret_state_list)} + + return ret_state_dict, ret_opt_state_dict + + +def load_merged_state_dicts(module: torch.nn.Module, state_dict: Dict[str, Any]) -> torch.nn.Module: + """ + Load the merged state dicts to the module and optimizer. + + Args: + module (torch.nn.Module): the module to be loaded + state_dict (Dict[str, Any]): the merged model state dict + Returns: + torch.nn.Module: the module after loading the state dict + """ + # there will be mismatched keys if the module is a ParallelModule or contains ParallelModule + # so we need to ignore the mismatched keys + module.load_state_dict(state_dict, strict=False) + # load ParallelModule state dicts + for name, child_module in module.named_modules(): + if isinstance(child_module, ParallelModule): + prefix = name + '.' if name else '' + child_module.load_merged_state_dict(state_dict, prefix=prefix) + + return module diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 0eb1308b..cd4d79a7 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -4,6 +4,7 @@ import sys from pathlib import Path from dataclasses import dataclass, asdict +from collections import defaultdict import torch import torch.distributed as dist @@ -46,8 +47,9 @@ class CubeModule(torch.nn.Module): def __init__(self): super().__init__() self._reducers: List[Reducer] = list() - # self._fullmap is mapping from the name of local attribute tensor + # self._fullmap is mapping from the name of local attribute tensor # to its corresponding fulltensor meta + # please note there can be multiple entries with same tid self._fullmap : Dict[str, AttrMeta] = dict() @property @@ -244,7 +246,7 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref def merge_model_state_dicts(state_dicts: List[Dict], fullmaps: List[Dict[str, AttrMeta]]): """Merge model states from multiple shard into a single-model state. - + Note: Users only need to provide as fewer local model states as necessary to cover the full model state. @@ -252,13 +254,13 @@ def merge_model_state_dicts(state_dicts: List[Dict], Args: state_dicts (List[Dict[str, torch.Tensor]]): per-rank local model state dict from model.state_dict() fullmaps (List[Dict[str, AttrMeta]]): per-rank fullmap - + Returns: full_state_dicts (List[Dict[str, torch.Tensor]]): Full model state dict """ if len(state_dicts) != len(fullmaps): raise ValueError("Expected model state dicts to have the same length as fullmaps") - + full_model_state_dict: Dict[str, torch.Tensor] = {} # gather param/buffer full tensor for model_state_dict, local_fullmap in zip(state_dicts, fullmaps): @@ -278,20 +280,20 @@ def merge_model_state_dicts(state_dicts: List[Dict], def merge_partial_states(state_dicts: List, zero_idx_maps=None): """Merge model and optimizer states from different shard into a single-model state. - + Warnings: * This function only supports merging optimizer states of Adam-like optimizers, in which the optimizer state is expected to contain 'state' keyword. * Only support single parameter group, i.e., code implementations like: `torch.optim.Adam(model.parameters(), lr=0.1)` Args: - state_dicts (List[(Dict, Dict, Dict, Dict)]): per-rank states containing: + state_dicts (List[(Dict, Dict, Dict, Dict)]): per-rank states containing: * model_state_dicts (List[Dict[str, torch.Tensor]]): per-rank model state dict from model.state_dict() * optim_state_dicts (Optional[List[Dict]]): per-rank optimizer state dict from optimizer.state_dict() * dist_param_map: deprecated, will be removed in the future. * fullmaps (List[Dict[str, AttrMeta]]): per-rank fullmap zero_idx_maps (Optional[List[Dict]]) - + Returns: Dict[str, torch.Tensor]: Full model state dict Dict[str, Dict[str, torch.Tensor]]: Full optimizer state dict @@ -322,21 +324,18 @@ def merge_partial_states(state_dicts: List, _logger.info(f'plan_ngpus = {plan_ngpus}') # at first, merge the partitioned optimizer states due to zero to the zero-disabled format - if CompileFlag.use_zero: - if zero_idx_maps is None: - raise ValueError(f"Detected zero optimization enabled, " - f"expected zero_idx_maps for merging.") + if zero_idx_maps is not None: def _check_state_size(opt_state_keys, bucket_state): """ - Check that all the keys except the scalar step for a - parameter in optimizer states have the same shaped tensor. - + Check that all the keys except the scalar step for a + parameter in optimizer states have the same shaped tensor. + For example, exp_avg, exp_avg_sq in Adam. """ if len(opt_state_keys) <= 1: return True return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape - for key in opt_state_keys) + for key in opt_state_keys) def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): assert bucket_size % len(bucket_states) == 0 @@ -353,7 +352,7 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): opt_states, opt_states_1d = {}, {} for key in opt_state_keys: opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, - device=bucket_states[0][key].device, requires_grad=False) + device=bucket_states[0][key].device, requires_grad=False) opt_states_1d[key] = opt_states[key].view(-1) if start_rank_id == end_rank_id: @@ -367,8 +366,9 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): for key in opt_state_keys: opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] offset += chunk_size - for key in opt_state_keys: - opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + if end_offset: # skip if end_offset == 0, because it is a no-op + for key in opt_state_keys: + opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] if 'step' in bucket_states[0]: opt_states['step'] = bucket_states[0]['step'] @@ -399,9 +399,16 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): opt_state_list.append(opt_state) assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' + # assign opt_state to state_dicts, cannot be assigned in the above loop + opt_state_len = len(opt_state_list[0]) + for work_idx in (range(worker_cnt) if plan_ngpus < 0 else range(plan_ngpus)): + optim_state_dicts[work_idx]['state'] = opt_state_list[work_idx] + optim_state_dicts[work_idx]['param_groups'][0]['params'] = sorted(opt_state_list[work_idx].keys()) + assert len(opt_state_list[work_idx]) == opt_state_len + # build parameter order to match with the optimizer state order # NOTE: the param IDs in optimizer typically follow the same order of - # local `model.parameters()`. However, `state_dict.keys()` contains + # local `model.parameters()`. However, `state_dict.keys()` contains # both parameters and buffers, we need to remove the buffers from the list. # More details refer to the implementation: # https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module._save_to_state_dict @@ -414,7 +421,7 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): # but in different names. if meta.orig_name not in origin_parameter_names: origin_parameter_names.append(meta.orig_name) - + # handle 'state' in optimizer state dict # NOTE: each rank may have its local optimizer state working on a sub-set # of parameters of the full model. So the param IDs in each local optimizer @@ -509,10 +516,21 @@ class ZeroMetadata: class ParallelModuleConfig: rank: int compute_config: 'ComputeConfig' - dist_param_map: Dict - param_area_map: Dict + # the dist_param_map of ParallelModule + dist_param_map: Dict[str, str] + # the fullmap of ParallelModule + param_area_map: Dict[str, AttrMeta] + # the parameter names of ParallelModule cube_param_names: List[str] + def __post_init__(self): + if isinstance(self.compute_config, dict): + from cube.parallel import ComputeConfig + self.compute_config = ComputeConfig(**self.compute_config) + for k in self.param_area_map: + if isinstance(self.param_area_map[k], dict): + self.param_area_map[k] = AttrMeta(**self.param_area_map[k]) + @dataclass class ExtraState(ZeroMetadata, OriginModuleMetadata, ParallelModuleConfig): @@ -553,6 +571,7 @@ def _post_init(self, init_params=True): # raise RuntimeError(f"The rank to load this module file name is expected to be {self._rank}, but got {dist.get_rank()}") module_file = Path(sys.modules[self.__module__].__file__) + self.module_dir = module_file.parent if init_params: self.load_attr_content(str(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE_STEM}"))) self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}")) @@ -591,7 +610,12 @@ def sync_grad(self): for reducer in self._reducers: reducer.sync_grads() - def get_dist_param_map(self): + def get_dist_param_map(self) -> Dict[str, str]: + """ + Get the parameter map of the model. + The map is a dict mapping from the new parameter name (without tid suffix) in parallel module + to the parameter name in original module. + """ return self._dist_param_map def get_compute_config(self) -> 'ComputeConfig': @@ -721,7 +745,7 @@ def _post_state_dict_hook(self, state_dict, prefix, local_metadata) -> None: state_dict[f'{prefix}{self.EXTRA_STATE_KEY}'] = asdict( ExtraState( rank=self.get_rank(), - compute_config=asdict(self._compute_config), + compute_config=self._compute_config, dist_param_map=self._dist_param_map, param_area_map=self._fullmap, cube_param_names=[name for name, _ in self.named_parameters()], @@ -732,3 +756,66 @@ def _post_state_dict_hook(self, state_dict, prefix, local_metadata) -> None: def _pre_load_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None: state_dict.pop(f'{prefix}{self.EXTRA_STATE_KEY}', None) + + def _list_fullmodel_files(self) -> List[Path]: + legacy_fullmodel_path = self.module_dir / FxModuleParser.ATTR_CONTENT_FILE_STEM + files = [] + if not legacy_fullmodel_path.is_file(): + file_idx = 0 + while True: + filepath = self.module_dir / FxModuleParser.ATTR_CONTENT_FILE_FORMAT.format(stem=FxModuleParser.ATTR_CONTENT_FILE_STEM, idx=file_idx) + if not filepath.is_file(): + break + files.append(filepath) + file_idx += 1 + else: + files.append(legacy_fullmodel_path) + + return files + + def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', strict: bool = True): + """ + Load the model from a merged state dict. + + Args: + state_dict (Dict[str, Any]): the merged state dict + prefix (str): the prefix of the model state dict in the merged state dict + strict (bool, optional): whether to strictly enforce that state_dict has have all the parameters of the module + Note: unlike `torch.nn.Module.load_state_dict`, + we only make sure no missing keys. Unexpected keys are not checked. + Default: `True` + Returns: + None + Raises: + RuntimeError: if strict=True and there are missing keys. + """ + + dist2param = self.get_dist_param_map() + orig_param_names = list(dist2param.values()) # param names in original module (without prefix) + + with torch.no_grad(): + attr_names = set(self._fullmap.keys()) + + origname_tid_map = {meta.orig_name: meta.tid for meta in self._fullmap.values()} + tid_info = defaultdict(list) + for attr, meta in self._fullmap.items(): + tid_info[meta.tid].append((attr, meta.slicers, meta.val_chunks)) # multiple params may share the same tid + + for orig_param_name in orig_param_names: + orig_param_name_with_prefix = prefix + orig_param_name + param_value = state_dict[orig_param_name_with_prefix] + tid = origname_tid_map[orig_param_name] + for attr, slicer, nchunks in tid_info[tid]: + tensor: torch.Tensor = getattr(self, attr) + content = param_value[slicer] + if nchunks != 1: + content = content / nchunks + tensor.copy_(content) + attr_names.remove(attr) + + if len(attr_names) != 0: + erro_msg = f'Missing key(s) in state_dict: {[prefix + self._fullmap[attr].orig_name for attr in attr_names]}.' + if strict: + raise RuntimeError(erro_msg) + else: + _logger.warning(erro_msg) diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index fca3dc24..feab1b4b 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -14,7 +14,7 @@ import numpy as np -from cube.parallel import ComputeConfig, parallelize, build_optimizer +from cube.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts from cube.runtime.module import ParallelModule, ExtraState from cube.runtime.gnorm import calcuate_gnorm @@ -51,24 +51,45 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): ) -def _create_cube_module(pas, compute_config, cube_savedir): - class CompiledModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(4, 4) - self.fc_relu1 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu1') - self.linear2 = nn.Linear(4, 4) - self.fc_relu2 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu2') - self.linear3 = nn.Linear(4, 1) - self.sigmoid = nn.Sigmoid() - def forward(self, x): - x = self.linear1(x) - x = self.fc_relu1(x) - x = self.linear2(x) - x = self.fc_relu2(x) - x = self.linear3(x) - x = self.sigmoid(x) - return x +def _create_cube_module(pas, compute_config, cube_savedir, module_type='whole'): + init_random() + if module_type == 'whole': + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = FcRelu_4_4() + self.linear2 = nn.Linear(4, 4) + self.fc_relu2 = FcRelu_4_4() + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + CompiledModule = _to_cube_model(CompiledModule, pas, compute_config, cube_savedir, 'whole') + else: + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu1') + self.linear2 = nn.Linear(4, 4) + self.fc_relu2 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu2') + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x init_random() compiled_module = CompiledModule().cuda() return compiled_module @@ -86,10 +107,14 @@ class StepResult: def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): ckpt_file_template = 'ckpt_{rank}_{start}.pth' + ckpt_merged_file_template = 'ckpt_merged_{start}.pth' ckpt_start_file = ckpt_dir / ckpt_file_template.format( rank=torch.distributed.get_rank(), start=start ) + ckpt_start_merged_file = ckpt_dir / ckpt_merged_file_template.format( + start=start + ) init_random() loss_fn = nn.BCELoss() @@ -97,11 +122,32 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): if ckpt_start_file.exists(): ckpt_dict = torch.load(ckpt_start_file) model_state_dict = ckpt_dict['model'] - assert 'fc_relu1.CUBE_EXTRA_STATE' in model_state_dict - assert 'fc_relu2.CUBE_EXTRA_STATE' in model_state_dict + for name, m in model.named_modules(): + prefix = f'{name}.' if name else '' + if isinstance(m, ParallelModule): + assert f'{prefix}CUBE_EXTRA_STATE' in model_state_dict optimizer_state_dict = ckpt_dict['optimizer'] + assert 'CUBE_EXTRA_STATE' in optimizer_state_dict model.load_state_dict(model_state_dict) optimizer.load_state_dict(optimizer_state_dict) + + assert ckpt_start_merged_file.exists() + merged_ckpt_dict = torch.load(ckpt_start_merged_file) + merged_model_state_dict = merged_ckpt_dict['model'] + model_from_merged = load_merged_state_dicts(type(model)(), merged_model_state_dict) + + # check merged model + result_orig_model_state_dict = model.state_dict() + result_merged_model_state_dict = model_from_merged.state_dict() + assert set(result_orig_model_state_dict.keys()) == set(result_merged_model_state_dict.keys()) + for k in result_orig_model_state_dict.keys(): + if k.endswith('CUBE_EXTRA_STATE'): + continue + assert torch.equal(result_orig_model_state_dict[k], result_merged_model_state_dict[k]) + + # TODO: check merged optimizer + # merged_optimizer_state_dict = merged_ckpt_dict['optimizer'] + data = [] init_random() for _ in range(DATA_SIZE): @@ -130,23 +176,40 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): rank=torch.distributed.get_rank(), start=end ) + ckpt_merged_file = ckpt_dir / ckpt_merged_file_template.format( + start=end + ) model_state_dict = model.state_dict() - assert 'fc_relu1.CUBE_EXTRA_STATE' in model_state_dict - assert 'fc_relu2.CUBE_EXTRA_STATE' in model_state_dict - extra_state1 = ExtraState(**model_state_dict['fc_relu1.CUBE_EXTRA_STATE']) - assert extra_state1.compute_config - assert extra_state1.model_idx2opt_idx - assert extra_state1.opt_idx2ranks - assert extra_state1.origin_param_names + for name, m in model.named_modules(): + if isinstance(m, ParallelModule): + prefix = f'{name}.' if name else '' + assert f'{prefix}CUBE_EXTRA_STATE' in model_state_dict + extra_state1 = ExtraState(**model_state_dict[f'{prefix}CUBE_EXTRA_STATE']) + assert extra_state1.compute_config + assert extra_state1.model_idx2opt_idx + assert extra_state1.opt_idx2ranks + assert extra_state1.origin_param_names optimizer_state_dict = optimizer.state_dict() + assert 'CUBE_EXTRA_STATE' in optimizer_state_dict torch.save({ 'model': model_state_dict, 'optimizer': optimizer_state_dict }, ckpt_file) + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + ckpt_files = [ckpt_dir / ckpt_file_template.format(rank=i, start=end) for i in range(torch.distributed.get_world_size())] + ckpt_state_dicts = [torch.load(f) for f in ckpt_files] + model_state_dicts = [ckpt['model'] for ckpt in ckpt_state_dicts] + optimizer_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_state_dicts] + merged_model_state_dicts, merged_optimizer_state_dict = merge_state_dicts(model_state_dicts, optimizer_state_dicts) + torch.save({ + 'model': merged_model_state_dicts, + 'optimizer': merged_optimizer_state_dict + }, ckpt_merged_file) return results -def _gpu_worker(pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_count): +def _gpu_worker(module_type, pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_count): init_distributed() compiled_results = [] with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: @@ -155,7 +218,8 @@ def _gpu_worker(pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_ end = (i + 1) * per_resume_update_count compiled_module = _create_cube_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=True), - tempdir + tempdir, + module_type, ) compiled_results.extend(_train( compiled_module, @@ -166,9 +230,10 @@ def _gpu_worker(pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_ return compiled_results @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') -def test_checkpoint(): - cube_results = launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4, 32, 1) - rcube_results = launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4, 16, 2) +@pytest.mark.parametrize('module_type', ['sub', 'whole']) +def test_checkpoint(module_type): + cube_results = launch_torchrun(4, _gpu_worker, module_type, PASRandomSPMD, 2, 4, 32, 1) + rcube_results = launch_torchrun(4, _gpu_worker, module_type, PASRandomSPMD, 2, 4, 16, 2) results0, results1, results2, results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] rresults0, rresults1, rresults2, rresults3 = rcube_results[0], rcube_results[1], rcube_results[2], rcube_results[3] diff --git a/tests/runtime/test_module_merge.py b/tests/runtime/test_module_merge.py index ec0c9169..e5c946db 100644 --- a/tests/runtime/test_module_merge.py +++ b/tests/runtime/test_module_merge.py @@ -1,18 +1,31 @@ import torch import cube +import os from functools import partial +import pytest from cube.ir.operator import IRFwOperation from cube.runtime.device import DeviceGroup from ..launch_torchrun import torchrun -import tempfile + + +@pytest.fixture(autouse=True, scope='module') +def clean_checkpoints(): + yield + i = 0 + while True: + try: + os.remove(f'checkpoint-shard{i}.pt') + i += 1 + except Exception: + break class Module(torch.nn.Module): def __init__(self): super(Module, self).__init__() - + self.register_buffer('buffer0', torch.randn(8, 8)) self.param0 = torch.nn.Parameter(torch.randn(8, 8)) self.param1 = torch.nn.Parameter(torch.randn(8, 8)) @@ -26,7 +39,7 @@ def forward(self, x): x = x + self.buffer1 x = x * self.param2 return torch.sum(x) - + def tp_policy(graph, resource): for idx, node in enumerate(graph.select(ntype=IRFwOperation)): if node.name == 'add': @@ -60,16 +73,16 @@ def merge_model_states_test(): sample = torch.randn(8, 8, device=torch.cuda.current_device()) full_model_state = model.state_dict() - + @cube.compile(model, sample, PAS=tp_policy) def train_iter(model, sample): loss = model(sample) loss.backward() return loss cube_model = cube.load_model() - + state_dict = cube_model.state_dict() - torch.save({'state_dict': state_dict, 'fullmap': cube_model.fullmap}, + torch.save({'state_dict': state_dict, 'fullmap': cube_model.fullmap}, f'checkpoint-shard{DeviceGroup().rank}.pt') torch.distributed.barrier() if DeviceGroup().rank == 0: @@ -96,13 +109,13 @@ def merge_optimizer_states_test(): full_model_state = model.state_dict() full_optim_state = full_optimizer.state_dict() - + @cube.compile(model, sample, PAS=tp_policy) def train_iter(model, sample): loss = model(sample) loss.backward() return loss - + cube_model = cube.load_model() optimizer = torch.optim.Adam(cube_model.parameters(), lr=0.01) @@ -135,12 +148,12 @@ def train_iter(model, sample): loss.backward() full_optimizer.step() full_optimizer.zero_grad() - + # cube model loss = train_iter(cube_model, sample) optimizer.step() optimizer.zero_grad() - + model_state_dict = cube_model.state_dict() optim_state_dict = optimizer.state_dict() states = { @@ -164,4 +177,5 @@ def train_iter(model, sample): assert_same_state(full_model_state, merged_model_states) assert_same_state(full_optim_state, merged_optim_states) + test_merge_optim_states = partial(torchrun, 2, merge_optimizer_states_test) From 1834a050c7017dae2f26b63d502c868d6ec36b63 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 28 Feb 2024 06:25:48 +0000 Subject: [PATCH 1597/1892] Merged PR 2056: Add check for IRPyFunc when building graph Cube does not support to compute gradients for IRPyFunc. However, some IRPyFunc may require gradients. --- cube/graph/graph.py | 13 +++++++++++++ cube/graph/parser/fx/parser.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index df832f88..5fadda12 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -192,6 +192,19 @@ def from_logic_graph(nodes: List[IRCell], node.set_output(idx, subtensor) graph = IRGraph(nodes, inputs, outputs, module_name) + # check IRPyFunc + requires_grad_pyfunc: List[IRPyFunc] = [] + for node in nodes: + if not isinstance(node, IRPyFunc): continue + if any(isinstance(t, IRSubTensor) and t.requires_grad for t in node.outputs()): + requires_grad_pyfunc.append(node) + if len(requires_grad_pyfunc) > 0: + dscp = (f'Cube does not support to compute gradients for IRPyFunc.\n' + f'Following nodes require gradients, this may trigger error in backward:\n') + for node in requires_grad_pyfunc: + dscp += f'\t{node.signature}, cid: {node.cid}\n' + _logger.warning(dscp) + # check unused outputs unused_obj_nodes: Dict[IRObject, List[IRCell]] = {} graph_output_objects = [ diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 80d5a2c8..aec639a4 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -159,7 +159,7 @@ def parse_complex(val: Any, frame: Frame) -> Any: frame (Frame): the frame to get the fx.Node Returns: - the copied strcuture where the fx.Node is replaced by IRObjects/IRTensors + the copied structure where the fx.Node is replaced by IRObjects/IRTensors """ # to support more nested types, we can refer to the implementation of # https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py From 219c46c78739645c41816ce3a7a76523ad3f3ca9 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Thu, 29 Feb 2024 02:52:48 +0000 Subject: [PATCH 1598/1892] Merged PR 2057: Modify the way to get slicers from AttrMeta when using old gnorm Modify the way to get slicers from AttrMeta when using old gnorm --- cube/runtime/gnorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/runtime/gnorm.py b/cube/runtime/gnorm.py index 05dc0bb3..72033485 100644 --- a/cube/runtime/gnorm.py +++ b/cube/runtime/gnorm.py @@ -65,7 +65,7 @@ def prepare_for_grad_clip_legacy(cube_model: 'CubeModule', curr_rank: int) -> Di assert name in cube_model.fullmap if param.requires_grad: tid = cube_model.tid_of_param_name(name) - slicers = cube_model.fullmap[name][1] + slicers = cube_model.fullmap[name].slicers tid2param[tid] = param tid2slicers[tid] = slicers # gather all parameters' slicers From 3e9160ba1a02bb60082f3e201afbfc64191bfcfa Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 1 Mar 2024 02:42:31 +0000 Subject: [PATCH 1599/1892] Merged PR 2050: correct the inplace operation value & metadata 1. update the op input after concrete run op 2. treat `dict_keys_type, dict_values_type, dict_items_type` as `list` during tracing 3. add non-tensor input related inplace operation in SSA replacement 4. fix SSA setitem metadata 5. add SetItem in function.py parity check pass --- cube/graph/function/function.py | 83 ++++++++-- cube/graph/parser/converter.py | 6 +- .../concrete_trace_utils/concrete_tracer.py | 124 ++++++++++----- .../concrete_trace_utils/operator_patcher.py | 3 +- .../parser/fx/concrete_trace_utils/utils.py | 148 +++++++++++++----- cube/graph/parser/fx/mapping.py | 3 + cube/graph/parser/fx/parser.py | 8 +- tests/graph/function/test_functions.py | 21 +++ tests/graph/tracer/test_inplace.py | 43 +++++ tests/graph/tracer/test_pytree.py | 59 +++++++ 10 files changed, 401 insertions(+), 97 deletions(-) create mode 100644 tests/graph/tracer/test_pytree.py diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 29a81789..53f7c8eb 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -45,17 +45,23 @@ def is_list_or_tuple(v: Any) -> bool: ) -def ir_object_recursive(obj: IRObject, fn: Callable): - """recursive on obj.value / dict / list / tuple""" +# TODO: this function should rewrite with pytree +def any_ir_object_satisfy(obj: Union[Any, IRObject], condition: Callable[[IRObject], bool]) -> bool: + """ + recursive on obj.value / dict / list / tuple / slice with a function returned bool, + if any IRObject hit the condition, return True, or return false. + """ if isinstance(obj, dict): - return any(ir_object_recursive(v, fn) for v in obj.values()) + return any(any_ir_object_satisfy(v, condition) for v in obj.values()) elif isinstance(obj, (list, tuple)): - return any(ir_object_recursive(v, fn) for v in obj) + return any(any_ir_object_satisfy(v, condition) for v in obj) + elif isinstance(obj, slice): + return any(any_ir_object_satisfy(v, condition) for v in (obj.start, obj.stop, obj.step)) elif isinstance(obj, IRObject): - if fn(obj): + if condition(obj): return True elif obj.value is not None: - return ir_object_recursive(obj.value, fn) + return any_ir_object_satisfy(obj.value, condition) else: return False else: @@ -63,7 +69,7 @@ def ir_object_recursive(obj: IRObject, fn: Callable): def ir_object_contains_dynamic(obj: IRObject): - return ir_object_recursive(obj, lambda a: not a.is_constant) + return any_ir_object_satisfy(obj, lambda a: not a.is_constant) def Identity(tensor: IRObject, signature = None): @@ -501,9 +507,16 @@ def BitwiseNot(input, *, out=None, signature=None): return IRDimops(BitwiseNot, 'bitwise_not', signature, annos, [input]) -def _unwrap_value(obj: IRObject): +# TODO: this function should rewrite with pytree +def _unwrap_value(obj: Union[IRObject, Any]): if isinstance(obj, IRObject): return _unwrap_value(obj.value) + elif isinstance(obj, (list, tuple)): + return type(obj)(_unwrap_value(v) for v in obj) + elif isinstance(obj, dict): + return {k: _unwrap_value(v) for k, v in obj.items()} + elif isinstance(obj, slice): + return slice(_unwrap_value(obj.start), _unwrap_value(obj.stop), _unwrap_value(obj.step)) else: return obj @@ -1236,13 +1249,13 @@ def Reshape(input, shape: Tuple[int], *arg_shape, signature = None): assert shape.value is not None, f"shape should have a reference value but got: {shape}" if isinstance(shape.value, int): shape = (shape,) + arg_shape - ou_shape = [d.value if isinstance(d, IRObject) else d for d in shape] + ou_shape = _unwrap_value(list(shape)) else: # tuple[int] / list[int] assert len(arg_shape) == 0, f"already got a tuple of int shape" - ou_shape = list(shape.value) + ou_shape = _unwrap_value(list(shape.value)) else: # int / tuple[int] shape = ((shape,) if isinstance(shape, int) else tuple(shape)) + arg_shape - ou_shape = [d.value if isinstance(d, IRObject) else d for d in shape] + ou_shape = _unwrap_value(list(shape)) assert all(isinstance(d, int) for d in ou_shape), f"but got {ou_shape}" anno, rules = _reshape_anno(in_shape, ou_shape, kwarg_name='shape') @@ -2004,16 +2017,58 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: # object slice if isinstance(obj, IRObject): assert obj.value is not None - if isinstance(obj.value[index], IRTensor): - out = obj.value[index] + unwrap_index = _unwrap_value(index) + if isinstance(obj.value[unwrap_index], IRTensor): + out = obj.value[unwrap_index] else: - val = obj.value[index] + val = obj.value[unwrap_index] is_constant = not (isinstance(val, IRObject) and not val.is_constant) out = IRObject(name='getitem', value=val, is_constant=is_constant) return IRPyFunc(signature, [obj, index], [out]) + # obj is not a IRObject, index is a IRObject + if any_ir_object_satisfy(index, lambda a: isinstance(a, IRObject)): + # if index is not constant, than the out is not constant + is_constant = not ir_object_contains_dynamic(index) + val = obj[_unwrap_value(index)] + out = IRObject(name='getitem', value=val, is_constant=is_constant) + return IRPyFunc(signature, [obj, index], [out]) return obj[index] +def SetItem(__a: Any, __b: Any, __c: Any, signature = None) -> Union[Any, IRPyFunc]: + """_operator.setitem(__a, __b, __c) / cube.runtime.function.setitem(__a, __b, __c)""" + signature = 'cube.runtime.function.setitem' + obj, index, val = __a, __b, __c + if isinstance(obj, IRTensor): + # TODO: move to some function like FullSlice when ready + # TODO: give a IRTensor as return value or return a IRDimops + return IRPyFunc(signature, [__a, __b, __c], [IRObject()]) + + is_constant = not ir_object_contains_dynamic(index) + index = _unwrap_value(index) + if isinstance(obj, IRObject): + is_constant = is_constant and obj.is_constant + obj = obj.value + + # not sure if it is safe the original obj be modified, + # we can not get the original value in the following program if we need it but it is inplace modified, + # here use shallow copy to prevent modify the original obj + obj = copy.copy(obj) + obj[index] = val + return IRPyFunc(signature, [__a, __b, __c], [IRObject(value=obj, is_constant=is_constant)]) + + +def Len(__obj: Any, signature = None) -> Union[Any, IRPyFunc]: + """builtins.len""" + if isinstance(__obj, IRTensor): + # TODO: IRTensor did not support dynamic shape attr now, so here the returned IRObject is constant + return IRPyFunc(signature, [__obj], [IRObject(value=__obj.shape[0])]) + elif isinstance(__obj, IRObject): + return IRPyFunc(signature, [__obj], [IRObject(value=len(__obj.value), is_constant=__obj.is_constant)]) + else: + return IRPyFunc(signature, [__obj], [IRObject(value=len(__obj))]) + + def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], IRPyFunc]: """ builtins.getattr(object, name[, default]) diff --git a/cube/graph/parser/converter.py b/cube/graph/parser/converter.py index 3ae4e5b5..5e11e0f0 100644 --- a/cube/graph/parser/converter.py +++ b/cube/graph/parser/converter.py @@ -45,14 +45,14 @@ def _rewrite_inplace_ops(traced_model: torch.fx.GraphModule): done_nodes = set() for n in traced_model.graph.nodes: done_nodes.add(n) + # inplace operator on torch.Tensor has the pattern: first arg is tensor + "call_method" + method name end with single "_" if ( (n.op == "call_method" and n.target.endswith("_") and not n.target.endswith("__")) - or (n.op == "call_function" and n.target in side_effectful_inplace_ops) - ) and n.args[0].meta.get('type', None) == torch.Tensor: + and n.args[0].meta.get('type', None) == torch.Tensor + ) or (n.op == "call_function" and n.target in side_effectful_inplace_ops): # setitem is a special inplace operator that returns None instead of the first modified argument, # to make it align with SSA format, we use cube runtime function to return the first argument if n.op == "call_function" and n.target == operator.setitem: - n.meta = n.args[0].meta n.target = cube_rt_function.setitem n.args[0].replace_all_uses_with(n, delete_user_cb=lambda node: not node in done_nodes) # we can't recompile diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 282f1271..aca1121c 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -24,14 +24,14 @@ import torch from torch._C import ScriptObject from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_flatten, tree_unflatten import torch.fx from torch.fx import GraphModule from torch.fx._compatibility import compatibility from torch.fx._symbolic_trace import _proxyable_classes from torch.fx.graph import Graph -from torch.fx.node import Target, Node, Argument, _side_effectful_functions +from torch.fx.node import Target, Node, Argument, _side_effectful_functions, base_types from torch.fx.proxy import TracerBase from torch.fx.operator_schemas import check_for_mutable_operation @@ -134,7 +134,15 @@ def __exit__(self, *args): side_effectful_inplace_ops, ) -from .utils import FrameRecord, ExtraSEFPatcher, extract_results_metadata, EmptyResult +from .utils import ( + FrameRecord, + ExtraSEFPatcher, + EmptyResult, + extract_results_metadata, + flatten_trees_with_func, + flatten_trees_with_func_and_spec, + map_trees_with_func, +) # pyright: reportGeneralTypeIssues=false _logger = logging.getLogger(__name__) @@ -375,18 +383,11 @@ def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: D apply the patcher, and the _autowrap_check to the target function. """ if kind == 'output': - return args[0] + return args[0], args, kwargs elif kind == 'placeholder': - return self.placeholder_dict[target] - - to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t - to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t + return self.placeholder_dict[target], args, kwargs def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]): - if self.cpu_offload: - args = tree_map(to_cuda, args) - kwargs = tree_map(to_cuda, kwargs) - if kind == 'call_function': assert isinstance(target, Callable) fn = target @@ -411,25 +412,21 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] return result with self.do_temp_disable(call=True): + if self.cpu_offload: + args, kwargs = tree_to_cuda(args), tree_to_cuda(kwargs) + result = run(kind, target, args, kwargs) + if self.cpu_offload: - # move back arguments to cpu if cpu_offload - args = tree_map(to_cpu, args) - kwargs = tree_map(to_cpu, kwargs) - if isinstance(result, torch.Tensor): - result = result.cpu() - elif isinstance(result, (list, dict, tuple)): - result = tree_map(to_cpu, result) - elif isinstance(result, (int, bool, float, torch.device, torch.dtype, torch.finfo)) or result is None: - # avoid too noisy warning - pass - else: - _logger.warning(f"result of target {target} is {type(result)}, which is not a common behavior.") + args, kwargs, result = tree_to_cpu(args), tree_to_cpu(kwargs), tree_to_cpu(result) + unexpected_types = types_other_than(result, (*base_types, type(None), torch.Tensor)) + if not contains_types(result, (torch.Tensor,)) and unexpected_types: + _logger.warning(f"result of target {target} contains unexpected types {unexpected_types}, which is not a common behavior.") torch.cuda.empty_cache() self.temp_disable_call = False - return result + return result, args, kwargs @compatibility(is_backward_compatible=True) def create_node(self, kind : str, target : Target, @@ -474,22 +471,22 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: use the 'run_target' to actually execute the code, and store the value in 'value' field. create the nodes for the target and the input of the target (if the target is one of call_method, call_function, call_module). """ - def unwrap(obj: Any): - while _orig_isinstance(obj, ep.ConcreteProxy): - obj = obj.value - return obj - args_unwrapped = ep.map_aggregate_not_proxy(args, unwrap) - kwargs_unwrapped = ep.map_aggregate_not_proxy(kwargs, unwrap) - with self.patcher.revert(): - value_unwrapped = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) + args_unwrapped, kwargs_unwrapped = unwrap_nested_proxy(args), unwrap_nested_proxy(kwargs) + value_unwrapped, args_run, kwargs_run = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) + + # because setitem is an inplace operation and will not return the obj, so here is a workaound to record node result + node_result = args_run[0] if kind == "call_function" and target == operator.setitem else value_unwrapped + # here update the origin args/kwargs to prevent inplace operator to the input + args = update_tree_proxy_value(args, args_run) + kwargs = update_tree_proxy_value(kwargs, kwargs_run) args_ = self.create_arg(args) kwargs_ = self.create_arg(kwargs) assert isinstance(args_, tuple) assert isinstance(kwargs_, dict) - node = self.create_node(kind, target, args_, kwargs_, name, type_expr, value_unwrapped) + node = self.create_node(kind, target, args_, kwargs_, name, type_expr, node_result) if self.record_frames and kind != 'placeholder': with self.do_temp_disable(True, True, True): @@ -594,7 +591,9 @@ def create_arg(self, a: Any) -> Union[Node, Any]: return a if isinstance(a, (dict_keys_type, dict_values_type, dict_items_type)): - return a + # here we directly flat all values as a list, + # for the create_arg do not support (dict_keys_type, dict_values_type, dict_items_type) + a = list(a) return super().create_arg(a) @@ -1335,6 +1334,61 @@ def wrapped(*args, **kwargs): return wrapped +def contains_types(pytree, types) -> bool: + """if pytree leaf has the given types, return true""" + return any(flatten_trees_with_func(lambda x: isinstance(x, types), [pytree])[0]) + + +def types_other_than(pytree, given_types) -> Set[Type]: + """return a set of types of the pytree leaf other than given_types""" + types = set(flatten_trees_with_func(lambda x: type(x) if not isinstance(x, given_types) else None, [pytree])[0]) + if None in types: + types.remove(None) + return types + + +def tree_to_cuda(pytree): + """return a same spec pytree with all the given pytree leaf tensor to cuda""" + return map_trees_with_func(lambda a: a.cuda() if isinstance(a, torch.Tensor) else a, [pytree]) + + +def tree_to_cpu(pytree): + """return a same spec pytree with all the given pytree leaf tensor to cpu""" + return map_trees_with_func(lambda a: a.cpu() if isinstance(a, torch.Tensor) else a, [pytree]) + + +def unwrap_nested_proxy(pytree): + """ + return a same spec pytree with the ConcreteProxy in the old pytree replaced with ConcreteProxy.value + """ + def unwrap(obj: Any): + while isinstance(obj, ep.ConcreteProxy): + obj = obj.value + return obj + + while contains_types(pytree, (ep.ConcreteProxy,)): + pytree = map_trees_with_func(unwrap, [pytree]) + return pytree + + +def update_tree_proxy_value(dst_pytree, src_pytree): + """ + copy the value from src_pytree to dst_pytree with the dst_pytree spec, + if the leaf is proxy, only replace the proxy.value, not replace the proxy. + """ + _, spec = tree_flatten(dst_pytree) + + def update_proxy_value(a, b): + if isinstance(a, ep.ConcreteProxy): + a.value = update_tree_proxy_value(a.value, b) + return a + else: + return b + + flat_arg = flatten_trees_with_func_and_spec(update_proxy_value, [dst_pytree, src_pytree], spec) + return tree_unflatten(flat_arg, spec) + + @compatibility(is_backward_compatible=True) class GraphAppendingConcreteTracer(ConcreteTracer): def __init__(self, graph: Graph): diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index a7adb019..fc7747b9 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -181,7 +181,8 @@ def patch_func_helper(self, func): 6. super() -> super(self.__class__, self) 7. func(a, b, c) -> patch_run(func, a, b, c) # for patch the functions called in the current function """ - if not hasattr(func, '__module__') or func.__module__ is None or func.__module__.startswith('torch.'): + if not hasattr(func, '__module__') or func.__module__ is None \ + or func.__module__.startswith('torch.') or func.__module__.startswith('cube.'): return func # those flags are set by fx _Patcher when a method is patched # we don't want to patch it again diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index fe1abc74..733b596f 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -2,13 +2,17 @@ # Licensed under the MIT license. import builtins +from collections import namedtuple from dataclasses import dataclass import operator -from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Type -import functools +from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Type, List import torch +import torch.utils._pytree as torch_pytree from torch.fx.node import Node, map_aggregate, _side_effectful_functions +from torch.utils._pytree import tree_flatten, tree_unflatten, LeafSpec, TreeSpec, SUPPORTED_NODES + +from . import concrete_proxy as ep # These need to run in global scope to handle nested calls correctly _orig_module_call: Callable = torch.nn.Module.__call__ @@ -64,24 +68,6 @@ } -def run_onlyif_instance(cond_type: Type[Any], return_orig: bool = True, return_const: Any = None): - def helper(fn): - if return_orig: - @functools.wraps(fn) - def wrapper_orig(*args): - if _orig_isinstance(args[-1], cond_type): - return fn(*args) - return args[-1] - return wrapper_orig - else: - @functools.wraps(fn) - def wrapper_const(*args): - if _orig_isinstance(args[-1], cond_type): - return fn(*args) - return return_const - return wrapper_const - return helper - def map_recursive(fn: Callable, arg) -> Any: """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. @@ -97,28 +83,110 @@ def map_recursive(fn: Callable, arg) -> Any: else: return fn(arg) -def map_recursive_zip(fn: Callable, arg0, *args) -> Any: + +def _get_node_type(pytree: Any) -> Any: + if isinstance(pytree, ep.ConcreteProxy): + return _orig_type(pytree) + if torch_pytree._is_namedtuple_instance(pytree): + return namedtuple + return type(pytree) + +torch_pytree._get_node_type = _get_node_type + + +def flatten_trees_with_func(fn, pytrees): """ - Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + Each pytree in pytrees should have the same structure. + + Example: + + pytrees = [ + [1, 2, (3, 4)], # pytree 1 + [5, 6, (7, 8)], # pytree 2 + ] + + # the returned value is + [fn(1, 5), fn(2, 6), fn(3, 7), fn(4, 8)], [*, *, (*, *)] """ - if _orig_type(arg0) != torch.Size and _orig_isinstance(arg0, _orig_tuple): - for arg in args: - assert (not _orig_isinstance(arg, torch.Size)) and _orig_isinstance(arg, _orig_tuple) - assert len(arg0) == len(arg) - return _orig_tuple(map_recursive_zip(fn, *sub_args) for sub_args in _orig_zip(arg0, *args)) - elif _orig_isinstance(arg0, _orig_list): - for arg in args: - assert _orig_isinstance(arg, _orig_list) - assert len(arg0) == len(arg) - return _orig_list(map_recursive_zip(fn, *sub_args) for sub_args in _orig_zip(arg0, *args)) - elif _orig_isinstance(arg0, _orig_dict): - keys = _orig_set(arg0.keys()) - for arg in args: - assert _orig_isinstance(arg, _orig_dict) and len(keys.symmetric_difference(arg.keys())) == 0 - return {k: map_recursive_zip(fn, arg0[k], *(arg[k] for arg in args)) for k in keys} - else: - # assert not _orig_isinstance(arg0, slice) - return fn(arg0, *args) + flat_trees = [tree_flatten(pytree) for pytree in pytrees] + flat_args = [v[0] for v in flat_trees] + specs = [v[1] for v in flat_trees] + + if not all(len(flat_arg) == len(flat_args[0]) for flat_arg in flat_args): + raise RuntimeError('the element number of pytrees are not equal') + if not all(str(spec) == str(specs[0]) for spec in specs): + raise RuntimeError('the structure of pytrees are not equal') + + return [fn(*vals) for vals in zip(*flat_args)], specs[0] + + +def map_trees_with_func(fn, pytrees): + """ + Each pytree in pytrees should have the same structure. + The returned value has the same structure with pytree in pytrees. + + Example: + + pytrees = [ + [1, 2, (3, 4)], # pytree 1 + [5, 6, (7, 8)], # pytree 2 + ] + + # the returned value is + [fn(1, 5), fn(2, 6), (fn(3, 7), fn(4, 8))] + """ + flat_args, spec = flatten_trees_with_func(fn, pytrees) + return tree_unflatten([i for i in flat_args], spec) + + +def flatten_tree_with_spec(pytree, spec: TreeSpec) -> List: + """ + Flat a pytree with a given spec. + + Example: + + pytree = [1, (2, {3: 4})] + spec = TreeSpec([*, (*, *)]) + + # the returned value is + [1, 2, {3: 4}] + """ + assert isinstance(spec, TreeSpec) + + if isinstance(spec, LeafSpec): + return [pytree] + + flatten_fn = SUPPORTED_NODES[spec.type].flatten_fn + child_pytrees, _ = flatten_fn(pytree) + + if len(child_pytrees) != len(spec.children_specs): + raise RuntimeError(f'The number of pytree children is not equal to the give specs.') + + result = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = flatten_tree_with_spec(child, child_spec) + result += flat + + return result + + +def flatten_trees_with_func_and_spec(fn, pytrees, spec): + """ + Example: + + pytrees = [ + [1, (2, {3: 4})], + [5, (6, 7)] + ] + spec = [*, (*, *)] + + # the returned value is + [fn(1, 5), fn(2, 6), fn({3: 4}, 7)] + """ + flat_args = [flatten_tree_with_spec(pytree, spec) for pytree in pytrees] + if not all(len(flat_arg) == len(flat_args[0]) for flat_arg in flat_args): + raise RuntimeError('the element number of pytrees are not equal') + return [fn(*vals) for vals in zip(*flat_args)] @dataclass diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 9965403d..26629577 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -123,10 +123,12 @@ def exist(signature: str) -> bool: __tttemplate('to'): function.To, __tttemplate('dim'): function.Dim, '_operator.getitem': function.GetItem, + '_operator.setitem': function.SetItem, 'builtins.getattr': function.GetAttr, 'builtins.tuple': function.MakeTuple, 'builtins.list': function.MakeList, 'builtins.slice': function.MakeSlice, + 'builtins.len': function.Len, # # torch nn functional '_operator.matmul': function.Matmul, @@ -211,6 +213,7 @@ def exist(signature: str) -> bool: __rtemplate('identity'): function.Identity, __rtemplate('multiref'): function.MultiRef, __rtemplate('accum'): function.Accum, + __rtemplate('setitem'): function.SetItem, # #einops # __einopsize('apply_for_scriptable_torch'): function.ScriptEinOps, diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index aec639a4..efa5cab4 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -10,7 +10,7 @@ from cube.graph.parser.fx.mapping import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc from cube.graph.function.dimops import IRDimops -from cube.graph.function.function import ir_object_recursive +from cube.graph.function.function import any_ir_object_satisfy import torch.fx from .concrete_trace_utils import TensorMetadata @@ -211,7 +211,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # case2: python runtime function else: _logger.warning(f'Set python runtime function: {fsig}') - if ir_object_recursive(input_vals, lambda a: not a.is_constant): + if any_ir_object_satisfy(input_vals, lambda a: not a.is_constant): err_msg = f'non register python runtime function {fsig} has a non constant input: {input_vals}, ' + \ 'please register it as a customized function using cube.graph.parser.register' raise RuntimeError(err_msg) @@ -234,8 +234,8 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule ir_node.set_output(i, vals[i]) elif not isinstance(ir_node.output(0), IRTensor) and ir_node.output(0).value is not None: if dynamic_shape or \ - ir_object_recursive(ir_node.output(0), lambda a: not a.is_constant) or \ - ir_object_recursive(ir_node.output(0), lambda a: isinstance(a, IRTensor)): + any_ir_object_satisfy(ir_node.output(0), lambda a: not a.is_constant) or \ + any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, IRTensor)): frame.set_var(node.name, ir_node.output(0)) ir_node.output(0).name = node.name else: diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 4a90eab8..35b256d5 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -221,6 +221,8 @@ def test_GetItem(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, d -> d b c' op = F.GetItem(IRTensor([3, 4, 2]), IRTensor([3, 5], dtype=torch.int64)) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, d e -> d e b c' + op = F.GetItem([1, 2, 3], IRObject(value=0, is_constant=False), signature='operator.getitem') + assert op.outputs()[0].value == 1 and op.outputs()[0].is_constant == False def test_Max(): @@ -295,3 +297,22 @@ def test_NewTensor(): op = F.NewTensor(np.array([[1],[2],[3]])) assert repr(op.anno) == ' -> 3^ 1^' assert op.kwargs['data'] == [[1],[2],[3]] + + +def test_Setitem(): + set_val = IRObject(value=4, is_constant=False) + op = F.SetItem(IRObject(value=[1, 2, 3]), 0, set_val) + assert op.outputs()[0].value == [set_val, 2, 3] + assert op.outputs()[0].is_constant + + op = F.SetItem(IRObject(value=[1, 2, 3], is_constant=False), 0, set_val) + assert op.outputs()[0].value == [set_val, 2, 3] + assert not op.outputs()[0].is_constant + + +def test_Len(): + op = F.Len([1, 2, 3], signature='builtins.len') + assert op.outputs()[0].value == 3 + + op = F.Len(IRObject(value=[1, 2, 3], is_constant=False), signature='builtins.len') + assert op.outputs()[0].value == 3 and not op.outputs()[0].is_constant diff --git a/tests/graph/tracer/test_inplace.py b/tests/graph/tracer/test_inplace.py index 13ca00f3..58b27474 100644 --- a/tests/graph/tracer/test_inplace.py +++ b/tests/graph/tracer/test_inplace.py @@ -192,3 +192,46 @@ def test_inplace_setitem_op(): print(nodes[9].args[0]) assert nodes[9].args[0] == nodes[8] + + +class InplaceOpBranchModule(torch.nn.Module): + def forward(self, x: torch.Tensor, ls: list): + y = torch.ones_like(x) + y.add_(1) + # if y can be correct add 1 during trace, it will always go into the first branch + if y.mean().item() > 1.5: + x.sub_(y) + else: + x.mul_(y) + ls[-1] = 1 + ls[-2] = 2 + return ls + + +@replace_all_device_with('cpu') +def test_inplace_op_value(): + model = InplaceOpBranchModule() + dummy_input = {'x': torch.rand(10), 'ls': [3, 3, 3]} + traced_graph = to_fx_graph(model, dummy_input) + + contains_sub_node, contains_mul_node = False, False + setitem_count = 0 + for node in traced_graph.graph.nodes: + if node.op == 'call_method': + if node.target == 'sub_': + contains_sub_node = True + if node.target == 'mul_': + contains_mul_node = True + if node.op == 'call_function': + if node.target is cube_rt_function.setitem: + # this means during trace, the value of ls after setitem is correct + if setitem_count == 0: + assert node.meta['tensor_meta'] == [3, 3, 1] + if setitem_count == 1: + assert node.meta['tensor_meta'] == [3, 2, 1] + setitem_count += 1 + + # this means during trace, the value of y after inplace add is correct + assert contains_sub_node and not contains_mul_node + # there should have two setitem node in graph + assert setitem_count == 2 diff --git a/tests/graph/tracer/test_pytree.py b/tests/graph/tracer/test_pytree.py new file mode 100644 index 00000000..6eb712f8 --- /dev/null +++ b/tests/graph/tracer/test_pytree.py @@ -0,0 +1,59 @@ +import pytest + +from torch.utils._pytree import tree_flatten + +from cube.graph.parser.fx.concrete_trace_utils.utils import ( + flatten_tree_with_spec, + flatten_trees_with_func, + flatten_trees_with_func_and_spec, + map_trees_with_func +) + + +def test_flatten_tree_with_spec(): + pytree_1 = [1, (2, (3, 4))] + pytree_2 = [1, (2, (3, (4, 5)))] + pytree_3 = [1, (2, (3,))] + _, spec = tree_flatten(pytree_1) + + assert flatten_tree_with_spec(pytree_2, spec) == [1, 2, 3, (4, 5)] + + # pytree_3 can not flatten by pytree_1 spec, so it should raise error + with pytest.raises(RuntimeError): + flatten_tree_with_spec(pytree_3, spec) + + +def test_flatten_trees_with_func(): + pytree_1 = [1, (2, {3: 4})] + pytree_2 = [5, (6, {3: 5})] + flat_args, spec = flatten_trees_with_func(lambda a, b: a + b, [pytree_1, pytree_2]) + assert flat_args == [6, 8, 9] + assert spec == tree_flatten(pytree_1)[1] + + pytree_3 = [1, (2, (3,))] + # pytree_3 has different spec with pytree_1 and pytree_2, so it should raise error + with pytest.raises(RuntimeError): + flatten_trees_with_func(lambda a, b, c: a + b + c, [pytree_1, pytree_2, pytree_3]) + + +def test_flatten_trees_with_func_and_spec(): + pytree_0 = [1, (2, 3)] + _, spec = tree_flatten(pytree_0) + + def merge(a, b): + if isinstance(a, dict): + assert isinstance(b, dict) + return {**a, **b} + else: + return a + b + + pytree_1 = [1, (2, {3: 4})] + pytree_2 = [5, (6, {4: 5})] + assert flatten_trees_with_func_and_spec(merge, [pytree_1, pytree_2], spec) == [6, 8, {3: 4, 4: 5}] + + +def test_map_trees_with_func(): + pytree_1 = [1, (2, {3: 4})] + pytree_2 = [5, (6, {3: 5})] + + assert map_trees_with_func(lambda a, b: a + b, [pytree_1, pytree_2]) == [6, (8, {3: 9})] From 8cf14052caeb2280b3e116a50b88e7bebe00bc3b Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 6 Mar 2024 08:04:10 +0000 Subject: [PATCH 1600/1892] Merged PR 2063: add kwargs in node serialize sometimes kwargs will change the op output behavior, for example, the output of `tensor[:, 10]` and `tensor[:, 2: 10]` has different shape, and we should treat these two node different. --- cube/profiler/database.py | 5 ++++- tests/profiler/test_op_profile.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/cube/profiler/database.py b/cube/profiler/database.py index e30b5e7b..893f2c2f 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -230,11 +230,14 @@ def get_func(node: IRFwOperation) -> Tuple[Callable, Shapes, DTypes, Dict]: fn = eval(node.signature) shapes, dtypes, requires_grads, values = [], [], [], [] + # TODO: this function should rewrite with pytree def extract_val(val: Union[IRObject, Any]) -> Any: if isinstance(val, IRObject): return extract_val(val.value) elif isinstance(val, tuple): return tuple([extract_val(v) for v in val]) + elif isinstance(val, list): + return list([extract_val(v) for v in val]) elif isinstance(val, dict): return {k: extract_val(v) for k, v in val.items()} elif isinstance(val, slice): @@ -373,7 +376,7 @@ def _serialize(self, node: IRFwOperation) -> str: key str: the serialized string """ shapes, dtypes, requires_grads = [], [], [] - for t in node.inputs(): + for t in node.inputs() + node.outputs(): if isinstance(t, IRTensor): shapes.append(t.shape) dtypes.append(t.dtype) diff --git a/tests/profiler/test_op_profile.py b/tests/profiler/test_op_profile.py index d2ef7252..576b3697 100644 --- a/tests/profiler/test_op_profile.py +++ b/tests/profiler/test_op_profile.py @@ -5,6 +5,7 @@ import torch from cube.parallel import _gen_graph +from cube.ir.tensor import IRTensor from cube.ir.operator import IRFwOperation from cube.profiler.database import CompProfiler, ProfileDataBase @@ -34,3 +35,14 @@ def test_op_profile_times(): toc = time.perf_counter() # this is always true because the op is very small. assert toc - tic < 20, f'op profile time is too long {toc - tic}' + + +def test_serialize(): + op = IRFwOperation('test', 'cube_test', [IRTensor(shape=[10, 20], dtype=torch.float)], 1) + db = ProfileDataBase() + op.set_output(0, IRTensor(shape=[10, 20], dtype=torch.float)) + key1 = db._serialize(op) + op.set_output(0, IRTensor(shape=[10, 10], dtype=torch.float)) + key2 = db._serialize(op) + # test different output have different serialize + assert key1 != key2 From 75a18929013d18a86c14f9ac60edece4648b99fb Mon Sep 17 00:00:00 2001 From: "Xin Ji (CSI Interfusion Co Ltd)" Date: Wed, 13 Mar 2024 07:00:24 +0000 Subject: [PATCH 1601/1892] Merged PR 2061: support some functions for new models Add new functions and test them on yilei's new huggingface model --- cube/graph/function/function.py | 116 +++++++++++++++++++------ cube/graph/parser/fx/mapping.py | 4 + cube/runtime/function/function.py | 10 --- tests/graph/function/test_functions.py | 71 ++++++++++++++- tests/graph/parser/test_parser.py | 23 +++++ tests/parallel_module/test_gencode.py | 74 ++++++++++++++++ 6 files changed, 261 insertions(+), 37 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 53f7c8eb..d4dd9af3 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1885,40 +1885,42 @@ def CompareNE(input, other, *, out=None, signature = None): return _comparison(CompareNE, operator.eq, 'ne', signature, input, other) -def Max(input, other_or_dim=None, out_or_keepdim=None, *, out=None, signature = None, **kwargs): +def Max(input, other_or_dim=None, keepdim=False, *, out=None, signature = None, **kwargs): """ torch.max(input) torch.max(input, dim, keepdim=False, *, out=None) torch.max(input, other, *, out=None) """ - signature = 'cube.runtime.function.max_' + assert out is None if 'dim' in kwargs: + assert other_or_dim is None and 'other' not in kwargs, f'dim and other cannot be both specified, get {kwargs}' other_or_dim = kwargs['dim'] - if 'keepdim' in kwargs: - assert 'out' not in kwargs, f'out and keepdim cannot be both specified, get {kwargs}' - out_or_keepdim = kwargs['keepdim'] - if 'out' in kwargs: - assert 'keepdim' not in kwargs, f'out and keepdim cannot be both specified, get {kwargs}' - out_or_keepdim = kwargs['out'] - if other_or_dim is None: - edim_in = [s + '^' for s in ShapeAnno.create_shape_str(input.shape)] - annos = [OpAnno.create_op_str([edim_in], ['1'])] - return IRDimops(Max, 'max', signature, annos, [input]) - elif isinstance(other_or_dim, IRTensor): - lshape, rshape, oshape = _handle_broadcast(input, other_or_dim) + if 'other' in kwargs: + assert other_or_dim is None and 'dim' not in kwargs, f'dim and other cannot be both specified, get {kwargs}' + other_or_dim = kwargs['other'] + if isinstance(other_or_dim, IRTensor): + other = other_or_dim + lshape, rshape, oshape = _handle_broadcast(input, other) annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] - return IRDimops(Max, 'max', signature, annos, [input, other_or_dim]) + return IRDimops(Max, 'max', signature, annos, [input, other]) else: - assert isinstance(other_or_dim, int) and isinstance(out_or_keepdim, bool) - edim_in = ShapeAnno.create_shape_str(input.shape) - edim_in[other_or_dim] += '^' - edim_out = copy.copy(edim_in) - if out_or_keepdim: - edim_out[other_or_dim] = '1' - else: - edim_out.pop(other_or_dim) - annos = [OpAnno.create_op_str([edim_in], [edim_out, edim_out])] - return IRDimops(Max, 'max', signature, annos, [input], other_or_dim=other_or_dim, out_or_keepdim=out_or_keepdim) + other_or_dim_val = _unwrap_value(other_or_dim) + if other_or_dim_val is None: + edim_in = [s + '^' for s in ShapeAnno.create_shape_str(input.shape)] + annos = [OpAnno.create_op_str([edim_in], ['1'])] + return IRDimops(Max, 'max', signature, annos, [input]) + elif isinstance(other_or_dim_val, int): + keepdim_val = _unwrap_value(keepdim) + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_in[other_or_dim_val] += '^' + edim_out = copy.copy(edim_in) + if keepdim_val: + edim_out[other_or_dim_val] = '1' + else: + edim_out.pop(other_or_dim_val) + kwargs = {'dim': other_or_dim, 'keepdim': keepdim} + annos = [OpAnno.create_op_str([edim_in], [edim_out, edim_out])] + return IRDimops(Max, 'max', signature, annos, [input], **kwargs) def ShapeAsTensor(input: IRTensor, signature = None): @@ -2172,3 +2174,67 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, dropout_p=0.0, anno = OpAnno.create_op_str([query_anno, key_anno, value_anno], [out_anno]) return IRDimops(ScaledDotProductAttention, 'scaled_dot_product_attention', signature, [anno], [query, key, value], attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + + +def Min(input, other_or_dim=None, keepdim=False, *, out=None, signature = None, **kwargs): + """ + torch.min(input) + torch.min(input, dim, keepdim=False, *, out=None) + torch.min(input, other, *, out=None) + """ + assert out is None + if 'dim' in kwargs: + assert other_or_dim is None and 'other' not in kwargs, f'dim and other cannot be both specified, get {kwargs}' + other_or_dim = kwargs['dim'] + if 'other' in kwargs: + assert other_or_dim is None and 'dim' not in kwargs, f'dim and other cannot be both specified, get {kwargs}' + other_or_dim = kwargs['other'] + if isinstance(other_or_dim, IRTensor): + other = other_or_dim + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] + return IRDimops(Min, 'min', signature, annos, [input, other]) + else: + other_or_dim_val = _unwrap_value(other_or_dim) + if other_or_dim_val is None: + edim_in = [s + '^' for s in ShapeAnno.create_shape_str(input.shape)] + annos = [OpAnno.create_op_str([edim_in], ['1'])] + return IRDimops(Min, 'min', signature, annos, [input]) + elif isinstance(other_or_dim_val, int): + keepdim_val = _unwrap_value(keepdim) + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_in[other_or_dim_val] += '^' + edim_out = copy.copy(edim_in) + if keepdim_val: + edim_out[other_or_dim_val] = '1' + else: + edim_out.pop(other_or_dim_val) + kwargs = {'dim': other_or_dim, 'keepdim': keepdim} + annos = [OpAnno.create_op_str([edim_in], [edim_out, edim_out])] + return IRDimops(Min, 'min', signature, annos, [input], **kwargs) + + +def Log(input, *, out=None, signature=None): + """ + torch.log(input, *, out=None) -> Tensor + """ + assert out is None + if not isinstance(input, IRObject): + return torch.log(input) if isinstance(input, torch.Tensor) else math.log(input) + if not isinstance(input, IRTensor): + return IRPyFunc(signature, [input], [_compute_unary_op(input, math.log, 'log')]) + edim_in = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([edim_in], [edim_in])] + return IRDimops(Log, 'log', signature, annos, [input]) + + +def FullLike(input, fill_value, *, dtype=None, layout=None, + device=None, requires_grad=False, memory_format=None, signature=None): + """ + torch.full_like(input, fill_value, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor + """ + assert layout in (None, torch.strided) and memory_format is None, f"Not support for non-default memory_format and layout" + kwargs = {'fill_value': fill_value, 'requires_grad': requires_grad,'dtype': dtype} + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(FullLike, 'full_like', signature, annos,[input],**kwargs) \ No newline at end of file diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 26629577..543b55ac 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -62,6 +62,8 @@ def exist(signature: str) -> bool: 'math.exp': function.Exp, __ttemplate('sqrt'): function.Sqrt, 'math.sqrt': function.Sqrt, + __ttemplate('log'): function.Log, + 'math.log': function.Log, __ttemplate('rsqrt'): function.RSqrt, __ttemplate('clamp'): function.Clamp, __ttemplate('clamp_min'): function.ClampMin, @@ -78,6 +80,7 @@ def exist(signature: str) -> bool: __ttemplate('ne') : function.CompareNE, '_operator.ne': function.CompareNE, __ttemplate('max'): function.Max, + __ttemplate('min'): function.Min, __ttemplate('where'): function.Where, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('long'): function.Long, @@ -144,6 +147,7 @@ def exist(signature: str) -> bool: __ttemplate('ones'): function.Ones, __ttemplate('tensor'): function.NewTensor, __ttemplate('full'): function.Full, + __ttemplate('full_like'): function.FullLike, __ttemplate('rand'): function.Rand, __ttemplate('clone'): function.Clone, diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index a01322f2..dbe1029f 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -29,16 +29,6 @@ def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) return tensor.to(dtype_or_device) -def max_(input: torch.Tensor, other_or_dim: Union[torch.Tensor, int, None]=None, out_or_keepdim: Optional[bool]=None) -> torch.Tensor: - if other_or_dim is None: - return torch.max(input) - elif isinstance(other_or_dim, int): - return torch.max(input, other_or_dim, out_or_keepdim) - else: - assert isinstance(other_or_dim, torch.Tensor) - return torch.max(input, other_or_dim) - - def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: """ accumulate tensors in to one tensor diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 35b256d5..b3419768 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -6,6 +6,7 @@ import pytest import torch import numpy as np +import math def o(value): @@ -228,12 +229,28 @@ def test_GetItem(): def test_Max(): op = F.Max(IRTensor([2, 3, 4])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' - op = F.Max(IRTensor([2, 3, 4]), IRTensor([4])) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c, c -> a b c' op = F.Max(IRTensor([2, 3, 4]), 1, True) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a 1 c, a 1 c' op = F.Max(IRTensor([2, 3, 4]), 1, False) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a c, a c' + op = F.Max(IRTensor([2, 3, 4]), 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a c, a c' + op = F.Max(IRTensor([2, 3, 4]), IRTensor([4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c, c -> a b c' + op = F.Max(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False), True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' + op = F.Max(IRTensor([2, 3, 4]), 2,IRObject(value=True, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' + op = F.Max(IRTensor([2, 3, 4]), 2,IRObject(value=False, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b, a b' + op = F.Max(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False),IRObject(value=True, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' + op = F.Max(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False),IRObject(value=IRObject(value=True), is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' + op = F.Max(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False),IRObject(value=None, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b, a b' + op = F.Max(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b, a b' def test_Squeeze(): @@ -316,3 +333,53 @@ def test_Len(): op = F.Len(IRObject(value=[1, 2, 3], is_constant=False), signature='builtins.len') assert op.outputs()[0].value == 3 and not op.outputs()[0].is_constant + + +def test_Min(): + op = F.Min(IRTensor([2, 3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' + op = F.Min(IRTensor([2, 3, 4]), 1, True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a 1 c, a 1 c' + op = F.Min(IRTensor([2, 3, 4]), 1, False) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a c, a c' + op = F.Min(IRTensor([2, 3, 4]), 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a c, a c' + op = F.Min(IRTensor([2, 3, 4]), IRTensor([4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c, c -> a b c' + op = F.Min(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False), True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' + op = F.Min(IRTensor([2, 3, 4]), 2,IRObject(value=True, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' + op = F.Min(IRTensor([2, 3, 4]), 2,IRObject(value=False, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b, a b' + op = F.Min(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False),IRObject(value=True, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' + op = F.Min(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False),IRObject(value=None, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b, a b' + op = F.Min(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b, a b' + op = F.Min(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False),IRObject(value=IRObject(value=True), is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' + + +def test_FullLike(): + op = F.FullLike(IRTensor([2, 1, 4, 1]), 1.) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a b c d' + op_int = F.FullLike(IRTensor([3, 2]), 5) + assert len(op_int._annos_candidates) == 1 and op_int._annos_candidates[0] == 'a b -> a b' + op_true = F.FullLike(IRTensor([2, 2]), 1., requires_grad=True) + assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' + op_float = F.FullLike(IRTensor([1, 2],dtype=int), 1, dtype=torch.float) + assert len(op_float._annos_candidates) == 1 and op_float._annos_candidates[0] == 'a b -> a b' + + +def test_Log(): + result = F.Log(2) + assert result == math.log(2) + input_tensor = torch.rand(1, 2, 3) + op = F.Log(input_tensor) + assert torch.allclose(op, torch.log(input_tensor)) and op.shape == (1, 2, 3) + op = F.Log(IRTensor([1,2,3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c -> a b c' + op = F.Log(IRObject(value=6, is_constant=False), signature='math.log') + assert op.outputs()[0].value == math.log(6) and not op.outputs()[0].is_constant \ No newline at end of file diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 8e261eec..434a7e9c 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -93,3 +93,26 @@ def forward(self, x): assert isinstance(ir_graph.output(0), IRTensor) assert ir_graph.output(0).shape == (4, 1) + + +@replace_all_device_with('cpu') +def test_min(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.min(x, dim=1, keepdim=True)[0] + + dummy_input = {'x': torch.randn(10, 256)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + print(ir_graph.extra_repr()) + + assert isinstance(ir_graph.output(0), IRTensor) + assert ir_graph.output(0).shape == (10, 1) \ No newline at end of file diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 04c9c2ed..f87a85aa 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -481,3 +481,77 @@ def test_codegen_clone(): True ) assert isinstance(g.nodes()[0], cube.graph.function.dimops.IRDimops) + + +class MinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b): + return torch.min(a, b) + +def _gencode_min_function_worker(tempdir): + init_distributed() + m_new = parallelize( + MinModule(), + { + 'a': torch.tensor([5, 2, 3]), + 'b': torch.tensor([1, 8, 1]), + }, + PASData, + ComputeConfig(1, 1), + cube_savedir=tempdir, + load_module=True + ) + assert m_new is not None + args = inspect.signature(m_new._forward_impl) + assert len(args.parameters) == 2 + assert args.parameters['a'].default is inspect.Parameter.empty + assert args.parameters['b'].default is inspect.Parameter.empty + + + assert torch.equal(m_new(torch.tensor([5, 2, 3]), torch.tensor([1, 8, 1])), torch.tensor([1, 2, 1])), "Expected element-wise min" + assert torch.equal(m_new(torch.tensor([-5, -2, -3]), torch.tensor([-1, -8, -1])), torch.tensor([-5, -8, -3])), "Expected element-wise min with negative values" + + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_codegen_min(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, _gencode_min_function_worker, tempdir) + + + +class MaxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + return torch.max(a, dim=1, keepdim=True)[0] + +def _gencode_max_function(tempdir): + init_distributed() + m_new = parallelize( + MaxModule(), + { + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + }, + PASData, + ComputeConfig(1, 1), + cube_savedir=tempdir, + load_module=True + ) + assert m_new is not None + args = inspect.signature(m_new._forward_impl).parameters + + assert len(args) == 1, "Expected 1 argument in the forward method" + assert args['a'].default is inspect.Parameter.empty, "Expected 'a' to have no default value" + + expected_output = torch.tensor([[3], [6]]) + actual_output = m_new(torch.tensor([[1, 2, 3], [4, 5, 6]])) + assert torch.equal(actual_output, expected_output), "Expected each row's max value with original dimension" + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of GPU devices') +def test_codegen_max(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, _gencode_max_function, tempdir) \ No newline at end of file From b05805acc87874071ddd7025208c25335360444f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 18 Mar 2024 01:41:20 +0000 Subject: [PATCH 1602/1892] Merged PR 2060: Parallel Module: checkpoint support for optimizer 1. Many codes are from Fairseq 2. Unused parameter support 3. Shared parameter support 4. Keep paddings in reducer bucket (in zero mode) Unit test pass Parity check pass --- cube/parallel.py | 503 +++++++++++++++--- cube/runtime/adapter/reducer.py | 18 +- cube/runtime/module.py | 35 +- cube/utils.py | 15 +- docs/parallel_module.md | 48 ++ tests/launch_torchrun.py | 2 +- tests/parallel_module/test_checkpoint.py | 151 +++++- .../parallel_module/test_checkpoint_buffer.py | 70 +++ .../parallel_module/test_checkpoint_shared.py | 197 +++++++ .../parallel_module/test_checkpoint_unused.py | 130 +++++ tests/parallel_module/test_gencode.py | 81 ++- 11 files changed, 1143 insertions(+), 107 deletions(-) create mode 100644 tests/parallel_module/test_checkpoint_buffer.py create mode 100644 tests/parallel_module/test_checkpoint_shared.py create mode 100644 tests/parallel_module/test_checkpoint_unused.py diff --git a/cube/parallel.py b/cube/parallel.py index f8bf169e..715e50eb 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -1,7 +1,7 @@ from enum import Enum from functools import partial import types -from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar, List +from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar, List, Set from pathlib import Path import inspect import sys @@ -9,6 +9,7 @@ from dataclasses import dataclass, asdict from contextlib import contextmanager import logging +import copy import torch from cube.graph.parser.fx.parser import FxModuleParser @@ -34,10 +35,10 @@ from cube.ir.unique import IDGenerator from cube.program import Program from cube.runtime.adapter.reducer import Reducer -from cube.runtime.module import CubeModule, ParallelModule, OriginModuleMetadata, ExtraState +from cube.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState from cube.runtime.device import DeviceGroup from cube.runtime.gnorm import calcuate_gnorm, clip_grads - +from cube.utils import get_member_by_name logger = logging.getLogger(__name__) @@ -700,6 +701,48 @@ def __init__(self, init_params=True): return cube_module +@dataclass(unsafe_hash=True) +class ModuleParameterLocation: + """ + the location of the parameters of a module in optimizer.param_groups[0]['params'] + [offset, offset + count) is the range of the parameters in optimizer.param_groups[0]['params'] + + Args: + offset: the first parameter's index in optimizer.state + count: represents the number of parameters within this module. + """ + offset: int + count: int + +@dataclass +class OptimizerExtraState: + """ + Args: + rank: the rank of the worker in torchrun + name: the name of the optimizer type + parallel_module_locs: the locations of the parameters of the parallelized module. + the key is the module prefix of the parallel module. + A module prefix is the same prefix used when you call `module.state_dict()` without the ending dot. + For example, if you have a module + module + submodule1_1 + submodule2_1 + submodule1_2 + then the prefix of `module` itself is `` (empty str). + the prefix of `submodule1_1` is `submodule1_1`. + the prefix of `submodule2_1` is `submodule1_1.submodule2_1`. + etc. + """ + rank: int + name: str + parallel_module_locs: Dict[str, ModuleParameterLocation] + + def __post_init__(self): + for k in self.parallel_module_locs: + if isinstance(self.parallel_module_locs[k], dict): + self.parallel_module_locs[k] = ModuleParameterLocation(**self.parallel_module_locs[k]) + + class ParallelOptimizer(torch.optim.Optimizer): """ A optimizer stub to support parallelized module. @@ -708,6 +751,8 @@ class ParallelOptimizer(torch.optim.Optimizer): # this is a reducer for non-parallel modules _non_parallel_module_reducer: Optional[Reducer] = None + # the extra state that will be used when loading state dict. + _extra_state: Optional[OptimizerExtraState] = None def sync_shard_grad(self): """ @@ -767,25 +812,6 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] """ ... -@dataclass(unsafe_hash=True) -class ModuleParameterLocation: - # the location of the parameters of a module in optimizer.param_groups[0]['params'] - # [offset, offset + count) is the range of the parameters in optimizer.param_groups[0]['params'] - offset: int - count: int - - -@dataclass -class OptimizerExtraState: - rank: int - name: str # the name of the optimizer - module_locs: Dict[str, ModuleParameterLocation] - - def __post_init__(self): - for k in self.module_locs: - if isinstance(self.module_locs[k], dict): - self.module_locs[k] = ModuleParameterLocation(**self.module_locs[k]) - OptimizerT = TypeVar('OptimizerT', bound=torch.optim.Optimizer) @@ -866,7 +892,13 @@ def build_optimizer( def _local_parameters(module: torch.nn.Module): cube_suffix = "_CUBE_SUFFIX" gen = module._named_members( - lambda m: [(cube_suffix, p) for p in m.parameters_for_optimizer()] # (cube_suffix, p) to meet _named_members requirement + lambda m: [ + (cube_suffix, p) # (cube_suffix, p) to meet _named_members requirement + for p in ( + m.parameters_for_optimizer() if m.get_compute_config().use_zero + else m.parameters() # `CubeModule.merge_partial_states` supports parameters_for_optimizer() only in zero mode + ) + ] if isinstance(m, ParallelModule) else m._parameters.items() ) @@ -884,6 +916,11 @@ def _local_parameters(module: torch.nn.Module): optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), *args, **kwargs) optimizer._non_parallel_module_reducer = non_parallel_module_reducer + optimizer._extra_state = OptimizerExtraState( + rank=torch.distributed.get_rank(), + name=type(optimizer).__name__, + parallel_module_locs=opt_module_locs, + ) def _step_pre_hook(opt, *args, **kwargs): opt.sync_shard_grad() @@ -907,11 +944,7 @@ def _patched_zero_grad(self, set_to_none: bool = True): orig_state_dict = optimizer.state_dict def _patched_state_dict(self): state_dict = orig_state_dict() - state_dict[ParallelModule.EXTRA_STATE_KEY] = asdict(OptimizerExtraState( - rank=torch.distributed.get_rank(), - name=type(optimizer).__name__, - module_locs=opt_module_locs, - )) + state_dict[ParallelModule.EXTRA_STATE_KEY] = asdict(optimizer._extra_state) return state_dict optimizer.state_dict = types.MethodType(_patched_state_dict, optimizer) @@ -1054,6 +1087,28 @@ def _get_optimizer_state_dict_info( ], Dict[str, Any] ]: + """ + An example of optimizer state dict: + { + 'state': { + 0: {'step': 10, 'exp_avg': ..., 'exp_avg_sq': ...}, + 1: {'step': 10, 'exp_avg': ..., 'exp_avg_sq': ...}, + # no 2 here, because param 2 is not used + 3: {'step': 10, 'exp_avg': ..., 'exp_avg_sq': ...}, + 4: {'step': 10, 'exp_avg': ..., 'exp_avg_sq': ...}, + 5: {'step': 10, 'exp_avg': ..., 'exp_avg_sq': ...}, + 6: {'step': 10, 'exp_avg': ..., 'exp_avg_sq': ...}, + # no 7 here, because param 7 is not used + }, + 'param_groups': [ { # we only support the case when there is only one param_group + 'lr': ..., + 'betas': ..., + 'eps': ..., + ..., + 'params': [0, 1, 2, 3, 4, 5, 6, 7] # all params will be listed here, no matter it is used or not + }] + } + """ ret_opt_state_dict = {'state': {}} # collect optimizer state dicts # merge ParallelModule state dicts @@ -1076,14 +1131,19 @@ def _get_optimizer_state_dict_info( raise ValueError("Only Adam-like optimizers are supported.") opt_extra_states[opt_extra_state.rank] = opt_extra_state - for module_prefix, loc in opt_extra_state.module_locs.items(): + for module_prefix, loc in opt_extra_state.parallel_module_locs.items(): if module_prefix not in opt_state_dicts: - opt_state_dicts[module_prefix] = [dict(state=[], param_groups=[]) for _ in range(len(optimizer_state_dicts))] + opt_state_dicts[module_prefix] = [dict(state={}, param_groups=[]) for _ in range(len(optimizer_state_dicts))] for i in range(loc.offset, loc.offset + loc.count): - opt_state_dicts[module_prefix][opt_extra_state.rank]['state'].append(opt_state_dict['state'][i]) + # if the parameter is not used or requires_grad is False, it will not be in the state dict + # for us, as we use a continous buffer, it will always have grad, so it will always be in the state dict + # the state for each parameters is inserted in Adam in a lazy way. + # see https://github.com/pytorch/pytorch/blob/dad1b765848c4f52501c4c60b1c3e6fbd3cc8837/torch/optim/adam.py#L103 + assert i in opt_state_dict['state'] + opt_state_dicts[module_prefix][opt_extra_state.rank]['state'][i - loc.offset] = opt_state_dict['state'][i] # TODO: inaccurate param_groups, for example, the 'params' in it is not right. # we have this to make `ParallelModule.merge_partial_states` happy. - opt_state_dicts[module_prefix][opt_extra_state.rank]['param_groups'] = opt_state_dict['param_groups'] + opt_state_dicts[module_prefix][opt_extra_state.rank]['param_groups'] = copy.deepcopy(opt_state_dict['param_groups']) for k, v in opt_state_dict.items(): if k == ParallelModule.EXTRA_STATE_KEY or k == 'state': @@ -1095,14 +1155,22 @@ def _get_optimizer_state_dict_info( return opt_extra_states, opt_state_dicts, ret_opt_state_dict +@torch.no_grad() def merge_state_dicts( - model_state_dicts: List[Dict[str, Any]], - optimizer_state_dicts: Optional[List[Dict[str, Any]]], + module_state_dicts: List[Dict[str, Any]], + optimizer_state_dicts: Optional[List[Dict[str, Any]]] = None, ) -> Tuple[Dict[str, Any], Optional[List[Dict[str, Any]]]]: """ Merge a list of shard state dicts (one for each rank) to a single full state dict Note: Only Adam-like optimizers are supported for merging + Please Note: + We don't garantee the devices of tensors are the same in the merged state dict. + You can assume the device of the tensors in the merged state dict can be one of the following: + 1. the current device when running this function + 2. the current cuda device when running this function + 3. the device of the tensor in the original state dict + When you load the state dict from file, you can just use `torch.load(..., map_location='...')` to unify the device of the tensors. Args: model_state_dicts (List[Dict[str, Any]]): the model state dicts from each rank optimizer_state_dicts (Optional[List[Dict[str, Any]]]): the optimizer state dicts from each rank @@ -1113,23 +1181,28 @@ def merge_state_dicts( if optimizer_state_dicts is not None: # TODO: support checkpoint optimization # where the following check may be too strong. - if len(model_state_dicts) != len(optimizer_state_dicts): + if len(module_state_dicts) != len(optimizer_state_dicts): raise ValueError("The length of model_state_dicts and optimizer_state_dicts should be the same.") - if not model_state_dicts: + if not module_state_dicts: raise ValueError("model_state_dicts should not be empty.") - pm_extra_states, pm_state_dicts, ret_state_dict = _get_parallel_module_state_dict_info(model_state_dicts) + pm_extra_states, pm_state_dicts, ret_state_dict = _get_parallel_module_state_dict_info(module_state_dicts) if optimizer_state_dicts is not None: opt_extra_states, opt_state_dicts, ret_opt_state_dict = _get_optimizer_state_dict_info(optimizer_state_dicts) # the new optimizer state dict for ParallelModules # key: the parallel module location in the optimizer state - # value: the new state values for the parallel module + # value: A tuple of + # 0. the new state values for the parallel module # (index is the parameter index in parallel module) - new_pm_states: Dict[ModuleParameterLocation, List[Any]] = {} + # 1. the module prefix + # 2. the original parameter names (OriginModuleMetadata.origin_param_names) + opt_new_pm_states: Dict[ModuleParameterLocation, Tuple[Dict[int, Any], str, List[str]]] = {} else: - opt_extra_states, opt_state_dicts, ret_opt_state_dict, new_pm_states = None, None, None, None + opt_extra_states, opt_state_dicts, ret_opt_state_dict, opt_new_pm_states = None, None, None, None - # do merging + # merging parallel module state dicts, + # non parallel module parts have been handled at _get_parallel_module_state_dict_info + # and _get_optimizer_state_dict_info # every loop will merge one ParallelModule for k, state_dicts_for_merge in pm_state_dicts.items(): extra_states = pm_extra_states[k] @@ -1150,48 +1223,354 @@ def merge_state_dicts( ) # merge back module state dict - for km, vm in merged_state_dict.items(): - key = km if not module_prefix else f'{module_prefix}.{km}' - ret_state_dict[key] = vm + # all ranks have the same extra_states + origin_state_dict_names = extra_states[0].origin_state_dict_names + shared_param_names = extra_states[0].origin_shared_param_names + for name in origin_state_dict_names: + key = name if not module_prefix else f'{module_prefix}.{name}' + if name in merged_state_dict: + ret_state_dict[key] = merged_state_dict[name] + else: + name_in_merged = _get_valid_name_from_merged_model(name, shared_param_names, merged_state_dict) + if name_in_merged is not None: + ret_state_dict[key] = merged_state_dict[name_in_merged] + key_in_merged = name_in_merged if not module_prefix else f'{module_prefix}.{name_in_merged}' + logger.warning( + f"Missing param/buffer {key} in merged_model_state_dict, " + f"safely using its shared param/buffer {key_in_merged} as {key}." + ) + else: + logger.warning( + f"Missing param/buffer {key} in merged_model_state_dict, " + f"high likely because {key} is created but not used in your model." + ) # merge back opt state dict if opt_state_dicts is not None: - opt_module_locs = [opt_extra_states[i].module_locs[module_prefix] for i in range(len(opt_extra_states))] + opt_module_locs = [opt_extra_states[i].parallel_module_locs[module_prefix] for i in range(len(opt_extra_states))] # Assume all ranks have the same opt_module_locs (offset and count) # TODO: assert may fail for pipeline parallelism for i in range(1, len(opt_module_locs)): assert opt_module_locs[i] == opt_module_locs[0] - new_pm_states[opt_module_locs[0]] = _int_dict_to_list(merged_opt_state_dict['state']) - - if new_pm_states: - ret_state_list: List[Any] = _int_dict_to_list(optimizer_state_dicts[0]['state']) - sorted_keys = sorted(new_pm_states.keys(), key=lambda x: x.offset, reverse=True) - for loc in sorted_keys: - # this assign is only safe in reverse order - ret_state_list[loc.offset:loc.offset + loc.count] = new_pm_states[loc] - ret_opt_state_dict['state'] = {i: v for i, v in enumerate(ret_state_list)} + opt_new_pm_states[opt_module_locs[0]] = (merged_opt_state_dict['state'], module_prefix, extra_states[0].origin_param_names) + + if opt_new_pm_states: + pm_orig_param_names: Dict[str, List[str]] = {} + for k, extra_states in pm_extra_states.items(): + module_prefix = '.'.join(k) + pm_orig_param_names[module_prefix] = CubeModule.get_origin_parameter_names([e.param_area_map for e in extra_states]) + # now we can construct the merged state of optimizer from any rank + # let's just use the first rank + orig_states: Dict[int, Any] = optimizer_state_dicts[0]['state'] + ret_states: Dict[int, Any] = {} # see `_get_optimizer_state_dict_info` for the value structure. + sorted_pm_locs = sorted(opt_new_pm_states.keys(), key=lambda x: x.offset) + assert len(optimizer_state_dicts[0]['param_groups']) == 1 + orig_effective_state_len = len(optimizer_state_dicts[0]['param_groups'][0]['params']) + orig_cur_index = 0 # index of orig_states + ret_states_cur_index = 0 # index of ret_state_dict + sorted_pm_locs_cur_index = 0 # index of sorted_pm_locs + while orig_cur_index < orig_effective_state_len: + if ( + sorted_pm_locs_cur_index >= len(sorted_pm_locs) # after all parallel module parameters + or orig_cur_index < sorted_pm_locs[sorted_pm_locs_cur_index].offset # not in the range of current parallel module + ): + # non parallel module paramters + if orig_cur_index in orig_states: + ret_states[ret_states_cur_index] = orig_states[orig_cur_index] + orig_cur_index += 1 + ret_states_cur_index += 1 + else: + # parallel module parameters + pm_loc = sorted_pm_locs[sorted_pm_locs_cur_index] + state, module_prefix, orignal_param_names = opt_new_pm_states[pm_loc] + named_state = {} # the state dict with named keys + for i, v in state.items(): + named_state[pm_orig_param_names[module_prefix][i]] = v + # reorder with the order of original param names + for i, name in enumerate(orignal_param_names): + if name in named_state: + v = named_state[name] + ret_states[ret_states_cur_index + i] = v + # always increase the index by the count of the original module parameters + ret_states_cur_index += len(orignal_param_names) + orig_cur_index += pm_loc.count + sorted_pm_locs_cur_index += 1 + + ret_opt_state_dict['state'] = ret_states + ret_opt_state_dict['param_groups'][0]['params'] = list(range(ret_states_cur_index)) return ret_state_dict, ret_opt_state_dict -def load_merged_state_dicts(module: torch.nn.Module, state_dict: Dict[str, Any]) -> torch.nn.Module: +@torch.no_grad() +def load_merged_state_dicts( + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + device: Union[str, torch.device] = None +): """ - Load the merged state dicts to the module and optimizer. - + Load the merged state dicts to the module, and optionally the optimizer to a specified device. Args: module (torch.nn.Module): the module to be loaded - state_dict (Dict[str, Any]): the merged model state dict + module_state_dict (Dict[str, Any]): the merged model state dict + optimizer (Optional[torch.optim.Optimizer]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. Returns: - torch.nn.Module: the module after loading the state dict + None """ + device = device or torch.cuda.current_device() + + # non ParallelModule parameters will be loaded here # there will be mismatched keys if the module is a ParallelModule or contains ParallelModule # so we need to ignore the mismatched keys - module.load_state_dict(state_dict, strict=False) + module.load_state_dict(module_state_dict, strict=False) # load ParallelModule state dicts for name, child_module in module.named_modules(): if isinstance(child_module, ParallelModule): prefix = name + '.' if name else '' - child_module.load_merged_state_dict(state_dict, prefix=prefix) + child_module.load_merged_state_dict(module_state_dict, prefix=prefix) + + module.to(device) + + if optimizer is not None and optimizer_state_dict is not None: + if 'adam' not in optimizer._extra_state.name.lower(): + raise ValueError("Only Adam-like optimizers are supported.") + + # handle non-paralleled module parameters + # make sure the order of the parameters + pm_name_locs: Dict[str, ModuleParameterLocation] = dict(sorted(optimizer._extra_state.parallel_module_locs.items(), key=lambda x: x[1].offset)) + pm_modules: List[torch.nn.Module] = [] + pm_locs = list(pm_name_locs.values()) + for name in pm_name_locs: + m = get_member_by_name(module, name) + if not isinstance(m, ParallelModule): + raise ValueError(f"Module {name} is not a ParallelModule") + pm_modules.append(m) + + merged_cur = 0 # the current index of the merged state dict + pm_cur = 0 # the current index of the parallel module in pm_locs + new_states: Dict[int, Dict[str, Any]] = {} + new_cur = 0 # the current index of the new state dict + assert len(optimizer_state_dict['param_groups']) == 1 + effective_state_len = len(optimizer_state_dict['param_groups'][0]['params']) + while merged_cur < effective_state_len: + # N: non-paralleled module parameters, P: paralleled module (will have multiple parameters) + # The parameter list would look like: NNPNPPPN + # []: the current processing parameter + # <>: the current processing parallel module + if ( + pm_cur >= len(pm_modules) # NNPNPPP[N]: the ending parameters, no current parallel module + or new_cur < pm_locs[pm_cur].offset # [N]N

NPPPN: other parameters + ): + # non-parallel module + if merged_cur in optimizer_state_dict['state']: + new_states[new_cur] = optimizer_state_dict['state'][merged_cur] + merged_cur += 1 + new_cur += 1 + else: + # NNPN<[P]PP>N: the current parallel module + # parallel module + pm_param_count = len(pm_modules[pm_cur]._orign_module_metadata.origin_param_names) + # will map `pm_param_count` parameters in merge state dict + # to `pm_locs[pm_cur].count` in optimizer state. + cur_states = {} + for i in range(pm_param_count): + if merged_cur + i in optimizer_state_dict['state']: + cur_states[i] =optimizer_state_dict['state'][merged_cur + i] + pm_new_states = _opt_load_merged_state_dict(pm_modules[pm_cur], cur_states) + for idx, value in pm_new_states.items(): + new_states[new_cur + idx] = value + new_cur += pm_locs[pm_cur].count + merged_cur += pm_param_count + pm_cur += 1 + + # move the new states to the device if needed + for idx, state in new_states.items(): + for key, value in state.items(): + if isinstance(value, torch.Tensor): + new_states[idx][key] = value.to(device) + + new_optimizer_state_dict = {} + new_optimizer_state_dict['state'] = new_states + new_optimizer_state_dict['param_groups'] = copy.deepcopy(optimizer_state_dict['param_groups']) + new_optimizer_state_dict['param_groups'][0]['params'] = list(range(new_cur)) + optimizer.load_state_dict(new_optimizer_state_dict) + + +def _opt_load_merged_state_dict(module: ParallelModule, states: Dict[int, Dict[str, Any]]): + with torch.no_grad(): + # orig_name -> state + orig_param_dict: Dict[str, Dict[str, Any]] = {} + cnt = 0 + origin_param_names = module._orign_module_metadata.origin_param_names + for name in origin_param_names: + if cnt in states: # some parameters may not in the sates when it is not used or requires_grad is False in training + orig_param_dict[name] = states[cnt] + cnt = cnt + 1 + + if module.get_compute_config().use_zero: + return _construct_optim_state_zero(module, orig_param_dict) + else: + return _construct_optim_state_nonzero(module, orig_param_dict) + + +def _construct_optim_state_zero( + module: ParallelModule, + orig_param_dict: Dict[str, Dict[str, Any]], +): + dist_param_map = module.get_dist_param_map() # name in parallel module (without tid suffix) -> name in origin module + param_area_map = module.get_full_map() # str -> AttrMeta + def _get_optimizer_state_of_param(param, param_ids, local_names): + # find the parameter's optimizer state and pick the slices induced by tensor parallelism + param_idx = param_ids.index(id(param)) + local_name = local_names[param_idx] + return _extract_new_state(local_name, orig_param_dict, dist_param_map, param_area_map) + + # prepare param ids and corresponding local param names + param_ids, local_names = [], [] + for local_name, param in module.named_parameters(): + param_ids.append(id(param)) + local_names.append(local_name) + state_dict, opt_param_idx = {}, 0 + opt_param = module.parameters_for_optimizer() + # first load the params' optimizer state for the reducers's flattened params + for reducer in module.reducers: + rank_idx, sub_ranks = module._get_zero_subranks(reducer) + for bucket in reducer.buckets: + # one bucket corresponds to one flattened param + assert len(opt_param[opt_param_idx].shape) == 1 + assert bucket._contiguous_params.shape[0] % len(sub_ranks) == 0 + chunk_size = bucket._contiguous_params.shape[0] // len(sub_ranks) + # the flattened param is in the range [bucket_chunk_start, bucket_chunk_end) + bucket_chunk_start = rank_idx * chunk_size + bucket_chunk_end = (rank_idx + 1) * chunk_size + # NOTE: assume the traverse order of params is consistent + # with them in contiguous buffer. + # param_offset: the param's start offset in the contiguous buffer + # chunk_offset: the current offset of the current rank corresponding chunk + param_offset, chunk_offset = 0, 0 + step, opt_states, opt_state_keys = None, {}, None + for param in bucket.params: + sliced_new_val = _get_optimizer_state_of_param(param, param_ids, local_names) + # init the chunk's optimizer state + if opt_state_keys is None: + opt_state_keys = [key for key in sliced_new_val] + if 'step' in sliced_new_val: + step = sliced_new_val['step'] + if 'step' in sliced_new_val: + opt_state_keys.remove('step') + for key in opt_state_keys: + opt_states[key] = torch.zeros([chunk_size], dtype=sliced_new_val[key].dtype, + device=sliced_new_val[key].device, requires_grad=False) + # copy the param's slices to the optimizer's chunk + for key in opt_state_keys: + sliced_new_val[key] = sliced_new_val[key].view(-1) + + # parameter range: <> + # bucket range: [] + if param_offset < bucket_chunk_start \ + and bucket_chunk_start < param_offset + param.numel() < bucket_chunk_end: + # case: < [ > ] + copy_size = param_offset + param.numel() - bucket_chunk_start + for key in opt_state_keys: + opt_states[key][chunk_offset:chunk_offset+copy_size] = sliced_new_val[key][-copy_size:] + chunk_offset += copy_size + elif bucket_chunk_start <= param_offset < bucket_chunk_end \ + and bucket_chunk_start <= param_offset + param.numel() < bucket_chunk_end: + # case: [ < > ] + for key in opt_state_keys: + opt_states[key][chunk_offset:chunk_offset+param.numel()] = sliced_new_val[key][:] + chunk_offset += param.numel() + elif bucket_chunk_start <= param_offset < bucket_chunk_end \ + and param_offset + param.numel() >= bucket_chunk_end: + # case: [ < ] > + copy_size = bucket_chunk_end - param_offset + for key in opt_state_keys: + opt_states[key][chunk_offset:chunk_offset+copy_size] = sliced_new_val[key][:copy_size] + chunk_offset += copy_size + elif param_offset < bucket_chunk_start \ + and param_offset + param.numel() >= bucket_chunk_end: + # case: < [ ] > + copy_size = bucket_chunk_end - bucket_chunk_start + for key in opt_state_keys: + opt_states[key][chunk_offset:chunk_offset + copy_size] \ + = sliced_new_val[key][bucket_chunk_start-param_offset:bucket_chunk_start-param_offset + copy_size] + chunk_offset += copy_size + else: + # case: [] <>, <> [] + logger.debug(f'Skipped: parameter range({param_offset},{param_offset + param.numel()}) vs. bucket range({bucket_chunk_start},{bucket_chunk_end})') + param_offset += param.numel() + # as there is padding in chunk, slicing to obtain the correct shape opt states + for key in opt_state_keys: + opt_states[key] = opt_states[key][:opt_param[opt_param_idx].shape[0]] + if step is not None: + opt_states['step'] = step + state_dict[opt_param_idx] = opt_states + opt_param_idx += 1 + # load the params' optimizer state that are not in reducers + # this part corresponds to cube/runtime/module.py: parameters_for_optimizer + reducer_pids = set() + for reducer in module.reducers: + reducer_pids.update(id(p) for p in reducer.params) + for param in module.parameters(): + if id(param) not in reducer_pids: + sliced_new_val = _get_optimizer_state_of_param(param, param_ids, local_names) + state_dict[opt_param_idx] = sliced_new_val + opt_param_idx += 1 + return state_dict + + +def _construct_optim_state_nonzero( + module: ParallelModule, + orig_param_dict: Dict[str, Dict[str, Any]] +): + dist_param_map = module.get_dist_param_map() # name in parallel module (without tid suffix) -> name in origin module + param_area_map = module.get_full_map() # str -> AttrMeta - return module + new_states = {} + for index, (local_name, _) in enumerate(module.named_parameters()): + new_states[index] = _extract_new_state(local_name, orig_param_dict, dist_param_map, param_area_map) + + return new_states + + +def _extract_new_state( + local_name: str, + orig_param_dict: Dict[str, Dict[str, Any]], + dist_param_map: Dict[str, str], + param_area_map: Dict[str, AttrMeta], +): + name = '_'.join(local_name.split('_')[:-1]) # remove the integer suffix + assert name in dist_param_map + attr_meta = param_area_map[local_name] + new_val = orig_param_dict[dist_param_map[name]] + sliced_new_val = {} + for key in new_val: + if key in ('step',): + sliced_new_val[key] = new_val[key] + else: + sliced_new_val[key] = new_val[key][attr_meta.slicers] / attr_meta.val_chunks + return sliced_new_val + + +def _get_valid_name_from_merged_model( + target_name: str, + shared_param_names: List[List[str]], + merged_model_state_dict: Dict[str, Any] +) -> Optional[str]: + """Find target_name in one set of shared_param_names, then find a name in merged_model_state_dict + that is in the same set as target_name. + """ + for shared_names in shared_param_names: + if target_name in shared_names: + for name in shared_names: + if name in merged_model_state_dict: + return name + break + return None diff --git a/cube/runtime/adapter/reducer.py b/cube/runtime/adapter/reducer.py index 81ff105a..df61ed08 100644 --- a/cube/runtime/adapter/reducer.py +++ b/cube/runtime/adapter/reducer.py @@ -145,9 +145,13 @@ def build(self): opt = self._contiguous_params[:self._numel] else: rank = torch.distributed.get_rank(group=self._zero_subgroup) + assert len(self._contiguous_params) % self._zgroup_sz == 0 + # Note: + # There may be paddings both in the middle and at the end of the contiguous buffer + # When there are paddings in the middle or end of the contiguous buffer, + # the calculation of gnorm is not affected as long as the paddings are all 0. + # So for now, it looks harmless. opt = self._contiguous_params.chunk(self._zgroup_sz)[rank] - if rank == self._zgroup_sz - 1 and self._padding != 0: - opt = opt[:-self._padding] self._param_for_optimizer = torch.nn.Parameter(opt) def register_hooks(self): @@ -246,8 +250,6 @@ def sync_grads(self): if self._zero: rank = torch.distributed.get_rank(group=self._zero_subgroup) grad = self._contiguous_grads.chunk(self._zgroup_sz, dim=0)[rank] - if rank == self._zgroup_sz - 1 and self._padding != 0: - grad = grad[:-self._padding] self._param_for_optimizer.grad = grad else: self._param_for_optimizer.grad = self._contiguous_grads[:self._numel] @@ -559,6 +561,14 @@ def zero_grad(self): def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: """ Get parameters for optimizers + Please note for ZeRO optimization, + the returned parameters are not the same as the original parameters, + and can have paddings (with value 0.0) both at the end and in the middle of paramters data. + + the calculation of gnorm is not affected as paddings are all 0. + + Returns: + List[torch.nn.Parameter]: parameters for optimizer """ params = [] for bucket in self._buckets: diff --git a/cube/runtime/module.py b/cube/runtime/module.py index cd4d79a7..18ed857a 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -276,6 +276,23 @@ def merge_model_state_dicts(state_dicts: List[Dict], full_model_state_dict[meta.orig_name][meta.slicers] = partial_tensor return full_model_state_dict + @staticmethod + def get_origin_parameter_names(fullmaps: List[Dict[str, AttrMeta]]): + """ + Get a list of original parameter names from the fullmaps. + `merge_partial_states` will use this list to build the parameter order + """ + origin_parameter_names: List[str] = [] + for local_fullmap in fullmaps: + for _, meta in local_fullmap.items(): + if not meta.is_param: continue + # shared parameters in CubeModule is already de-duplicated. So in the + # local model state, we will not have multiple parameters sharing with same content + # but in different names. + if meta.orig_name not in origin_parameter_names: + origin_parameter_names.append(meta.orig_name) + return origin_parameter_names + @staticmethod def merge_partial_states(state_dicts: List, zero_idx_maps=None): @@ -412,15 +429,7 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): # both parameters and buffers, we need to remove the buffers from the list. # More details refer to the implementation: # https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module._save_to_state_dict - origin_parameter_names: List[str] = [] - for local_fullmap in fullmaps: - for _, meta in local_fullmap.items(): - if not meta.is_param: continue - # shared parameters in CubeModule is already de-duplicated. So in the - # local model state, we will not have multiple parameters sharing with same content - # but in different names. - if meta.orig_name not in origin_parameter_names: - origin_parameter_names.append(meta.orig_name) + origin_parameter_names: List[str] = CubeModule.get_origin_parameter_names(fullmaps) # handle 'state' in optimizer state dict # NOTE: each rank may have its local optimizer state working on a sub-set @@ -498,9 +507,9 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): @dataclass class OriginModuleMetadata: - origin_state_dict_names: List[str] - origin_param_names: List[str] - origin_shared_param_names: List[Set[str]] + origin_state_dict_names: List[str] # used for merging module state dict + origin_param_names: List[str] # used for merging optimizer state dict + origin_shared_param_names: List[List[str]]# used for merging module state dict @dataclass @@ -803,6 +812,8 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s for orig_param_name in orig_param_names: orig_param_name_with_prefix = prefix + orig_param_name + if orig_param_name_with_prefix not in state_dict: + continue param_value = state_dict[orig_param_name_with_prefix] tid = origname_tid_map[orig_param_name] for attr, slicer, nchunks in tid_info[tid]: diff --git a/cube/utils.py b/cube/utils.py index 90f44686..697a2368 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Tuple, Callable, List, Set +from typing import Optional, Tuple, Callable, List, Set, Any import logging from pathlib import Path import sys @@ -81,10 +81,13 @@ def load_eval_schedule(filename: Optional[str] = None): return module._infer_step -def get_param_by_name(model: torch.nn.Module, name: str) -> torch.nn.Parameter: +def get_member_by_name(model: torch.nn.Module, name: str) -> Any: """ - Get the parameter of the model by its full name. + Get the member of the model by its full name. + if name is empty, return the model itself. """ + if not name: + return model sliced_names = name.split(".") model_attr = model for sliced_name in sliced_names: @@ -92,13 +95,13 @@ def get_param_by_name(model: torch.nn.Module, name: str) -> torch.nn.Parameter: return model_attr -def get_shared_params(model: torch.nn.Module) -> List[Set[str]]: +def get_shared_params(model: torch.nn.Module) -> List[List[str]]: paramid2name = defaultdict(set) for name in model.state_dict().keys(): - param = get_param_by_name(model, name) + param = get_member_by_name(model, name) paramid = id(param) paramid2name[paramid].add(name) - return [names for _, names in paramid2name.items() if len(names) > 1] + return [list(names) for _, names in paramid2name.items() if len(names) > 1] class accum_mode: diff --git a/docs/parallel_module.md b/docs/parallel_module.md index 7d30b236..c17f6094 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -323,6 +323,54 @@ optimizer.register_reducer_post_hook(lambda reducer, grad: grad.mul_(num_scale_u 5. `_non_parallel_module_reducer`: The reducer for the modules which are not parallelized. It is used to sync the parameters in those modules across units. +### Checkpoint support + +You can save/load the checkpoints for parallel modules. +Each rank will save/load its own checkpoint just like the normal module. + +You can also merge the checkpoints from different ranks to a single checkpoint. +We call it a merged checkpoint. The merged checkpoint can be loaded by original module directly. +So you can easily share the checkpoint with the original module. + + +We provide two functions to help you save/load the merged checkpoint for the parallel module. + +#### `merge_state_dicts` +```python +def merge_state_dicts( + module_state_dicts: List[Dict[str, Any]], + optimizer_state_dicts: Optional[List[Dict[str, Any]]], +) -> Tuple[Dict[str, Any], Optional[List[Dict[str, Any]]]]: +``` + +Merge a list of shard state dicts (one for each rank) to a single full state dict +Note: Only Adam-like optimizers are supported for merging + +Please Note: + We don't garantee the devices of tensors are the same in the merged state dict. + You can assume the device of the tensors in the merged state dict can be one of the following: + 1. the current device when running this function + 2. the current cuda device when running this function + 3. the device of the tensor in the original state dict + When you load the state dict from file, you can just use `torch.load(..., map_location='...')` to unify the device of the tensors. + + +#### `load_merged_state_dicts` +```python +def load_merged_state_dicts( + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + device: Union[str, torch.device] = None +) -> None: +``` +Load the merged state dicts to the module, and optionally the optimizer to a specified device. + +Please note the `device` parameter. If it is None, we will use `torch.cuda.current_device()` to get the current device. If you want to load the state dict to a specific device, you can set it to the device you want. + + ### Dataset We use the same dataset/dataloader as pytorch. For example, you can use `torch.utils.data.DistributedSampler` to create a distributed sampler. diff --git a/tests/launch_torchrun.py b/tests/launch_torchrun.py index 6267aa19..f66a010f 100644 --- a/tests/launch_torchrun.py +++ b/tests/launch_torchrun.py @@ -11,7 +11,7 @@ def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): max_nodes=1, nproc_per_node=nproc_per_node, rdzv_backend = "c10d", - rdzv_endpoint = "localhost:29400", + rdzv_endpoint = "localhost:29401", run_id = str(uuid.uuid4()), monitor_interval=0.1, max_restarts=0, diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index feab1b4b..28b7ca91 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -20,6 +20,7 @@ from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively +from ..utils import replace_all_device_with class FcRelu(nn.Module): @@ -38,6 +39,9 @@ def forward(self, x): class FcRelu_4_4(FcRelu): def __init__(self): super().__init__(4, 4) + self.register_buffer('buffer', torch.ones(1, 4)) + def forward(self, x): + return super().forward(x + self.buffer) def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): @@ -72,7 +76,7 @@ def forward(self, x): x = self.sigmoid(x) return x CompiledModule = _to_cube_model(CompiledModule, pas, compute_config, cube_savedir, 'whole') - else: + elif module_type == 'sub': class CompiledModule(torch.nn.Module): def __init__(self): super().__init__() @@ -90,6 +94,54 @@ def forward(self, x): x = self.linear3(x) x = self.sigmoid(x) return x + elif module_type == 'start': + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = _to_cube_model(CubeLinear(4, 4, bias=True), + pas, compute_config, cube_savedir, f'start_linear1' + ) + self.linear2 = CubeLinear(4, 1, bias=True) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.sigmoid(x) + return x + elif module_type == 'end': + # parallel module as the last module + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = CubeLinear(4, 4, bias=True) + self.linear2 = _to_cube_model(CubeLinear(4, 4, bias=True), + pas, compute_config, cube_savedir, f'end_linear2' + ) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = torch.sum(x, dim=1, keepdim=True) + x = self.sigmoid(x) + return x + elif module_type == 'small': + # num of parameter elements is small + class CompiledModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = CubeLinear(4, 4, bias=True) + self.linear2 = _to_cube_model(CubeLinear(4, 1, bias=True), + pas, compute_config, cube_savedir, f'small_linear2' + ) + # the following tests depend on the rngstate in PASRandomSPMD + assert len(self.linear2.reducers) == 1 + assert len(self.linear2.reducers[0].ranks) == 4 + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.sigmoid(x) + return x init_random() compiled_module = CompiledModule().cuda() return compiled_module @@ -134,20 +186,33 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): assert ckpt_start_merged_file.exists() merged_ckpt_dict = torch.load(ckpt_start_merged_file) merged_model_state_dict = merged_ckpt_dict['model'] - model_from_merged = load_merged_state_dicts(type(model)(), merged_model_state_dict) + merged_opt_state_dict = merged_ckpt_dict['optimizer'] + model_from_merged = type(model)() + optimizer_from_merged = build_optimizer(model_from_merged, torch.optim.Adam, lr=0.01) + load_merged_state_dicts( + model_from_merged, merged_model_state_dict, + optimizer_from_merged, merged_opt_state_dict, + ) # check merged model result_orig_model_state_dict = model.state_dict() result_merged_model_state_dict = model_from_merged.state_dict() assert set(result_orig_model_state_dict.keys()) == set(result_merged_model_state_dict.keys()) - for k in result_orig_model_state_dict.keys(): - if k.endswith('CUBE_EXTRA_STATE'): + for index in result_orig_model_state_dict.keys(): + if index.endswith('CUBE_EXTRA_STATE'): continue - assert torch.equal(result_orig_model_state_dict[k], result_merged_model_state_dict[k]) - - # TODO: check merged optimizer - # merged_optimizer_state_dict = merged_ckpt_dict['optimizer'] + assert torch.equal(result_orig_model_state_dict[index], result_merged_model_state_dict[index]) + result_orig_opt_state_dict = optimizer.state_dict() + result_merged_opt_state_dict = optimizer_from_merged.state_dict() + assert set(result_orig_opt_state_dict.keys()) == set(result_merged_opt_state_dict.keys()) + assert result_orig_opt_state_dict['CUBE_EXTRA_STATE'] == result_merged_opt_state_dict['CUBE_EXTRA_STATE'] + assert result_orig_opt_state_dict['param_groups'] == result_merged_opt_state_dict['param_groups'] + assert set(result_orig_opt_state_dict['state']) == set(result_merged_opt_state_dict['state']) + for index in result_orig_opt_state_dict['state']: + for key in ('step', 'exp_avg', 'exp_avg_sq'): + assert torch.equal(result_orig_opt_state_dict['state'][index][key], result_merged_opt_state_dict['state'][index][key]) + torch.distributed.barrier() data = [] init_random() for _ in range(DATA_SIZE): @@ -186,8 +251,9 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): assert f'{prefix}CUBE_EXTRA_STATE' in model_state_dict extra_state1 = ExtraState(**model_state_dict[f'{prefix}CUBE_EXTRA_STATE']) assert extra_state1.compute_config - assert extra_state1.model_idx2opt_idx - assert extra_state1.opt_idx2ranks + if extra_state1.compute_config.use_zero: + assert extra_state1.model_idx2opt_idx + assert extra_state1.opt_idx2ranks assert extra_state1.origin_param_names optimizer_state_dict = optimizer.state_dict() assert 'CUBE_EXTRA_STATE' in optimizer_state_dict @@ -206,10 +272,11 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): 'model': merged_model_state_dicts, 'optimizer': merged_optimizer_state_dict }, ckpt_merged_file) + torch.distributed.barrier() return results -def _gpu_worker(module_type, pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_count): +def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_count, check_module=None): init_distributed() compiled_results = [] with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: @@ -217,10 +284,12 @@ def _gpu_worker(module_type, pas, plan_ngpus, runtime_ngpus, per_resume_update_c start = i * per_resume_update_count end = (i + 1) * per_resume_update_count compiled_module = _create_cube_module(pas, - ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=True), + ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero), tempdir, module_type, ) + if check_module: + check_module(compiled_module) compiled_results.extend(_train( compiled_module, runtime_ngpus // plan_ngpus, @@ -230,10 +299,13 @@ def _gpu_worker(module_type, pas, plan_ngpus, runtime_ngpus, per_resume_update_c return compiled_results @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') -@pytest.mark.parametrize('module_type', ['sub', 'whole']) -def test_checkpoint(module_type): - cube_results = launch_torchrun(4, _gpu_worker, module_type, PASRandomSPMD, 2, 4, 32, 1) - rcube_results = launch_torchrun(4, _gpu_worker, module_type, PASRandomSPMD, 2, 4, 16, 2) +@pytest.mark.parametrize('module_type', ['sub', 'whole', 'start', 'end', 'small']) +@pytest.mark.parametrize('use_zero', [True, False]) +def test_checkpoint(module_type, use_zero): + plan_ngpus = 2 + runtime_ngpus = 4 + cube_results = launch_torchrun(4, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus, 32, 1) + rcube_results = launch_torchrun(4, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus, 16, 2) results0, results1, results2, results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] rresults0, rresults1, rresults2, rresults3 = rcube_results[0], rcube_results[1], rcube_results[2], rcube_results[3] @@ -265,3 +337,50 @@ def test_checkpoint(module_type): assert torch.equal(a.grads[k], b.grads[k]) for k in a.weights.keys(): # weights assert torch.equal(a.weights[k], b.weights[k]) + + +def assert_intra_reducer(module: ParallelModule): + assert module.get_compute_config().plan_ngpus == module.get_compute_config().runtime_ngpus + assert len(module.reducers) > 0 + # so we have both parameters in reducers and not in reducers + # (assume one reducer gives one bucket, which is true in general.) + assert len(module.parameters_for_optimizer()) > len(module.reducers) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('module_type', ['whole']) +@pytest.mark.parametrize('use_zero', [True, False]) +def test_checkpoint_intra_reducer(module_type, use_zero): + """ + Test when: + Some of the parameters will be added to reducers, + but some of the parameters are not. + """ + plan_ngpus = 2 + runtime_ngpus = 2 + cube_results = launch_torchrun(2, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus, 32, 1, assert_intra_reducer) + rcube_results = launch_torchrun(2, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus, 16, 2, assert_intra_reducer) + results0 = cube_results[0] + rresults0 = rcube_results[0] + + # pred, loss + for r0, r1 in [(results0, rresults0)]: + # have the same input + assert len(r0) == len(r1) # iteration count + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.pred, b.pred) # pred + assert torch.equal(a.loss, b.loss) # loss + assert torch.equal(a.gnorm, b.gnorm) # gnorm + + # grad, weights + for r0, r1 in [(results0, rresults0)]: + # in the same shard, grads and weights are the same + assert len(r0) == len(r1) + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.gnorm, b.gnorm) # gnorm + for k in a.grads.keys(): # grad + assert torch.equal(a.grads[k], b.grads[k]) + for k in a.weights.keys(): # weights + assert torch.equal(a.weights[k], b.weights[k]) diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py new file mode 100644 index 00000000..40b8f32f --- /dev/null +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -0,0 +1,70 @@ +from pathlib import Path +import tempfile +import torch + +import pytest + +from cube.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dicts + +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun + + +class Net1(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer('buffer', torch.ones(128, 64), persistent=False) + self.fc = torch.nn.Linear(64, 64) + + # x with shape [128, 64] + def forward(self, x): + return self.fc(x + self.buffer) + + +class Net2(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer('buffer', torch.ones(256, 64), persistent=False) + self.fc = torch.nn.Linear(64, 64) + + # x with shape [256, 64] + def forward(self, x): + return self.fc(x + self.buffer) + + +def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_shape): + return parallelize( + module, + {'x': torch.randn(input_shape)}, + PASRandomSPMD, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + + +def _gpu_worker(): + init_distributed() + compute_config = ComputeConfig(1, 1, use_zero=False) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + net1 = _to_cube_model(Net1(), compute_config, tempdir, 'net1', (128, 64)) + cube_state_dict = net1.state_dict() + assert any(key.startswith('buffer') for key in cube_state_dict) + merged_state_dict, _ = merge_state_dicts([cube_state_dict]) + assert 'buffer' not in merged_state_dict + + net2 = Net2() + net2.load_state_dict(merged_state_dict, strict=False) # should success + + net2 = _to_cube_model(Net2(), compute_config, tempdir, 'net2', (256, 64)) + net2.load_merged_state_dict(merged_state_dict, strict=False) # should success + + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_checkpoint_buffer(): + """ + Please note the buffer size in Net1 and Net2 are different. + """ + launch_torchrun(1, _gpu_worker) diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py new file mode 100644 index 00000000..4bb4785a --- /dev/null +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -0,0 +1,197 @@ +import tempfile +from pathlib import Path +import pytest +from typing import Dict, Tuple, List, Any + +import torch +from torch import nn + +from cube.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts + +from .common import PASRandomSPMD, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun + + +class FcReluWithShared(nn.Module): + def __init__(self, in_features, bias=True): + super().__init__() + init_random() + self.unused_fc1 = CubeLinear(in_features, in_features, bias=bias) + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, in_features, bias=bias) + self.fc2.fc.weight = self.fc1.fc.weight # share the weights + self.relu2 = nn.ReLU() + + + def forward(self, x): + return self.relu2(self.fc2(self.relu1(self.fc1(x)))) + + +class FcRelu_4_WithShared(FcReluWithShared): + def __init__(self): + super().__init__(4, 4) + + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + + +def _create_cube_module(pas, compute_config, cube_savedir, module_type='raw'): + init_random() + if module_type == 'raw': + class RawModuleWithShared(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = FcRelu_4_WithShared() + self.unused_linear1 = nn.Linear(4, 4) + self.linear2 = nn.Linear(4, 4) + self.linear2.weight = self.linear1.weight # share the weights + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + return RawModuleWithShared().cuda() + else: + class ParallelModuleWithShared(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = _to_cube_model( + FcRelu_4_WithShared(), pas, + compute_config, cube_savedir, 'fc_relu1' + ) + self.unused_linear1 = nn.Linear(4, 4) + self.linear2 = nn.Linear(4, 4) + self.linear2.weight = self.linear1.weight # share the weights + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + return ParallelModuleWithShared().cuda() + + +DATA_SIZE = 256 +RAW_CKPT_FILE_NAME = 'raw.pth' + + +def _train_raw(model: torch.nn.Module, ckpt_dir): + DATA = [] + for _ in range(DATA_SIZE): + DATA.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + loss_fn = nn.BCELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + for i, (x, y) in enumerate(DATA): + model.train() + optimizer.zero_grad() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + optimizer.step() + torch.save({ + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict() + }, ckpt_dir / RAW_CKPT_FILE_NAME) + + +def _load_merged(parallel_model: torch.nn.Module, ckpt_dir): + raw_ckpt_dict = torch.load(ckpt_dir / RAW_CKPT_FILE_NAME) + raw_model_state_dict: Dict[str, Any] = raw_ckpt_dict['model'] + raw_opt_state_dict = raw_ckpt_dict['optimizer'] + optimizer = build_optimizer(parallel_model, torch.optim.Adam, lr=0.01) + load_merged_state_dicts( + parallel_model, raw_model_state_dict, + optimizer, raw_opt_state_dict, + ) + + ckpt_file_template = 'ckpt_{rank}.pth' + ckpt_merged_file = ckpt_dir / 'ckpt_merged.pth' + ckpt_file = ckpt_dir / ckpt_file_template.format( + rank=torch.distributed.get_rank() + ) + model_state_dict = parallel_model.state_dict() + optimizer_state_dict = optimizer.state_dict() + torch.save({ + 'model': model_state_dict, + 'optimizer': optimizer_state_dict + }, ckpt_file) + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + ckpt_files = [ckpt_dir / ckpt_file_template.format(rank=i) for i in range(torch.distributed.get_world_size())] + ckpt_state_dicts = [torch.load(f) for f in ckpt_files] + model_state_dicts = [ckpt['model'] for ckpt in ckpt_state_dicts] + optimizer_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_state_dicts] + merged_model_state_dicts, merged_optimizer_state_dict = merge_state_dicts(model_state_dicts, optimizer_state_dicts) + torch.save({ + 'model': merged_model_state_dicts, + 'optimizer': merged_optimizer_state_dict + }, ckpt_merged_file) + raw_model_state_dict = { + key: value + for key, value in raw_model_state_dict.items() + if not key.startswith('fc_relu1.unused_fc') + } + assert set(merged_model_state_dicts.keys()) == set(raw_model_state_dict.keys()) + for index in merged_model_state_dicts.keys(): + assert torch.equal(merged_model_state_dicts[index].cuda(), raw_model_state_dict[index].cuda()) + + assert set(merged_optimizer_state_dict.keys()) == set(raw_opt_state_dict.keys()) + assert merged_optimizer_state_dict['param_groups'] == raw_opt_state_dict['param_groups'] + assert set(merged_optimizer_state_dict['state']) == set(raw_opt_state_dict['state']) + for index in merged_optimizer_state_dict['state']: + for key in ('step', 'exp_avg', 'exp_avg_sq'): + assert torch.equal(merged_optimizer_state_dict['state'][index][key].cuda(), raw_opt_state_dict['state'][index][key].cuda()) + + +def _gpu_worker(use_zero, pas, plan_ngpus, runtime_ngpus): + # Basic logic: + # a. first train the original model, get a full state dict + # b. then use parallel model to load the full state dict as a merged state dict + # c. then parallel model save their own state dicts, and merge them to get a merged state dict. + # d. compare the full state dict in step a and the merged state dict in step c. They should be the same. + init_distributed() + compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + if torch.distributed.get_rank() == 0: + tempdir.mkdir(parents=True, exist_ok=True) + _train_raw(_create_cube_module(pas, compute_config, tempdir, 'raw'), tempdir) + torch.distributed.barrier() + _load_merged( + _create_cube_module(pas, compute_config, tempdir, 'cube'), + tempdir + ) + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('use_zero', [True, False]) +def test_checkpoint_load_from_raw_checkpoint(use_zero): + """ + Test when the checkpoint is generated from raw module and need to be loaded to parallel module. + """ + plan_ngpus = 2 + runtime_ngpus = 4 + launch_torchrun(4, _gpu_worker, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus) diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py new file mode 100644 index 00000000..6c554dba --- /dev/null +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -0,0 +1,130 @@ +import tempfile +import itertools +import re +from pathlib import Path +import shutil +import pytest +from typing import Dict, Tuple, List, Any +from dataclasses import dataclass + +import torch +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler + +import numpy as np + +from cube.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts +from cube.runtime.module import ParallelModule, ExtraState +from cube.runtime.gnorm import calcuate_gnorm + +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively +from .test_checkpoint_shared import _train_raw, _load_merged + + +class FcReluWithUnused(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.unused_fc0 = CubeLinear(out_features, out_features, bias=bias) + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.unused_fc1 = CubeLinear(out_features, out_features, bias=bias) + self.fc2 = CubeLinear(in_features, out_features, bias=bias) + self.unused_fc2 = CubeLinear(out_features, out_features, bias=bias) + self.relu2 = nn.ReLU() + + + def forward(self, x): + return self.relu2(self.fc2(self.relu1(self.fc1(x)))) + + +class FcRelu_4_4_WithUnused(FcReluWithUnused): + def __init__(self): + super().__init__(4, 4) + + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + + +def _create_cube_module(pas, compute_config, cube_savedir, module_type='raw'): + init_random() + if module_type == 'raw': + class RawModuleWithUnused(torch.nn.Module): + def __init__(self): + super().__init__() + self.unused_linear0 = nn.Linear(4, 4) + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = FcRelu_4_4_WithUnused() + self.unused_linear1 = nn.Linear(4, 4) + self.linear3 = nn.Linear(4, 1) + self.unused_linear2 = nn.Linear(4, 4) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + return RawModuleWithUnused().cuda() + else: + class ParallelModuleWithUnused(torch.nn.Module): + def __init__(self): + super().__init__() + self.unused_linear0 = nn.Linear(4, 4) + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = _to_cube_model( + FcRelu_4_4_WithUnused(), pas, + compute_config, cube_savedir, 'fc_relu1' + ) + self.unused_linear1 = nn.Linear(4, 4) + self.linear3 = nn.Linear(4, 1) + self.unused_linear2 = nn.Linear(4, 4) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + init_random() + return ParallelModuleWithUnused().cuda() + + +def _gpu_worker(use_zero, pas, plan_ngpus, runtime_ngpus): + # Basic logic: + # a. first train the original model, get a full state dict + # b. then use parallel model to load the full state dict as a merged state dict + # c. then parallel model save their own state dicts, and merge them to get a merged state dict. + # d. compare the full state dict in step a and the merged state dict in step c. They should be the same. + init_distributed() + compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + if torch.distributed.get_rank() == 0: + tempdir.mkdir(parents=True, exist_ok=True) + _train_raw(_create_cube_module(pas, compute_config, tempdir, 'raw'), tempdir) + torch.distributed.barrier() + _load_merged( + _create_cube_module(pas, compute_config, tempdir, 'cube'), + tempdir + ) + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('use_zero', [True, False]) +def test_checkpoint_load_from_raw_checkpoint(use_zero): + """ + Test when the checkpoint is generated from raw module and need to be loaded to parallel module. + """ + plan_ngpus = 2 + runtime_ngpus = 4 + launch_torchrun(4, _gpu_worker, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index f87a85aa..d0dd890f 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -212,7 +212,7 @@ def _gencode_contains(cubesave_dir, module_class, index, search_re): outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) filecontent = (outdir /f'gencode{index}.py').read_text() matches = re.findall(search_re, filecontent) - return bool(matches) + return matches class AttrHelper: def __init__(self) -> None: @@ -514,14 +514,12 @@ def _gencode_min_function_worker(tempdir): assert torch.equal(m_new(torch.tensor([-5, -2, -3]), torch.tensor([-1, -8, -1])), torch.tensor([-5, -8, -3])), "Expected element-wise min with negative values" - @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_codegen_min(): with tempfile.TemporaryDirectory() as tempdir: launch_torchrun(1, _gencode_min_function_worker, tempdir) - class MaxModule(torch.nn.Module): def __init__(self): super().__init__() @@ -536,8 +534,8 @@ def _gencode_max_function(tempdir): { 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), }, - PASData, - ComputeConfig(1, 1), + PASData, + ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True ) @@ -551,7 +549,78 @@ def _gencode_max_function(tempdir): actual_output = m_new(torch.tensor([[1, 2, 3], [4, 5, 6]])) assert torch.equal(actual_output, expected_output), "Expected each row's max value with original dimension" + @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of GPU devices') def test_codegen_max(): with tempfile.TemporaryDirectory() as tempdir: - launch_torchrun(1, _gencode_max_function, tempdir) \ No newline at end of file + launch_torchrun(1, _gencode_max_function, tempdir) + + +class SharedParameterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(3, 3) + self.linear2 = torch.nn.Linear(3, 3) + self.linear2.weight = self.linear1.weight # shared parameter + + def forward(self, x): + return self.linear2(self.linear1(x)) + + +@replace_all_device_with('cpu') +def test_codegen_shared_parameter(): + with tempfile.TemporaryDirectory() as tempdir: + m = SharedParameterModule() + m.train() + parallelize( + m, + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1), + cube_savedir=tempdir, + load_module=False, + reuse='override', + ) + assert _gencode_contains(tempdir, SharedParameterModule, 0, r"self\.register_parameter\('linear1_bias_*") + assert _gencode_contains(tempdir, SharedParameterModule, 0, r"self\.register_parameter\('linear2_bias_*") + assert _gencode_contains(tempdir, SharedParameterModule, 0, r"self\.register_parameter\('linear1_weight_*") + # linear2_weight shares the same parameter with linear1_weight + # so there will be no linear2_weight in the generated code + assert not _gencode_contains(tempdir, SharedParameterModule, 0, r"self\.register_parameter\('linear2_weight_*") + + +class BufferModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer('buffer', torch.ones(128, 64), persistent=False) + self.fc = torch.nn.Linear(64, 64) + + # x with shape [128, 64] + def forward(self, x): + return self.fc(x + self.buffer) + + +@replace_all_device_with('cpu') +def test_codegen_buffer(): + """ + Test even the buffer is not persistent, + it will be registered in the generated code as a persistent buffer. + """ + with tempfile.TemporaryDirectory() as tempdir: + m = BufferModule() + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + PASData, + ComputeConfig(1, 1), + cube_savedir=tempdir, + load_module=False, + reuse='override', + ) + matches = _gencode_contains(tempdir, BufferModule, 0, + r"self\.register_buffer\('buffer_*" + ) + assert len(matches) == 1 + match = matches[0] + assert 'persistent' not in match From 27236305b695948e642c754d12b8d00113dcb290 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 18 Mar 2024 02:34:32 +0000 Subject: [PATCH 1603/1892] Merged PR 1851: fix bug: parser support to IRDimops with returns of None fix bug: dimop support empty return This is the support to parse customized functions like ```python def no_return_function(x, y): logger.logging(x, y) # no return or # return None ``` The annotation can be `'a b, c d -> ?'`, since python function will always return None if `return` is not specified. Related work items: #1612 --- cube/graph/function/dimops.py | 7 ++++++- cube/graph/parser/fx/parser.py | 13 +++++++------ tests/graph/function/test_dimops.py | 19 +++++++++++++++---- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 005153a1..3145c377 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -18,7 +18,7 @@ Special identifier: 1) '*': this special identifier indicates the dimension is dynamic, which will automatically get expanded given the shape - 2) '?': this special identifier indicates the value is not a tensor, which will be ignored + 2) '?': this special identifier indicates the value is can only be replicated, no matter it is a tensor or a non-tensor. A `reduction` can be a set of {'', '+', '^'}: '' indicates this dimension can be partitioned, and each output should have this dimension. @@ -632,6 +632,11 @@ def __init__(self, create_fn: Callable, name: str, n_outputs = len(self._oannos) super().__init__(name, signature, inputs, n_outputs, **kwargs) + # change tensor to IRObject for '?' annotation + for idx, shape_anno in enumerate(self._oannos): + if shape_anno.ignore: + self.set_output(idx, IRObject()) + @property def anno(self) -> OpAnno: return self._anno diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index efa5cab4..2adf4d6b 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -4,8 +4,8 @@ from typing import Any, List, Tuple, Callable, Union, Dict, Type, Optional from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRFullTensor, IRTensor -from cube.ir.cten import IRObject, IRCell +from cube.ir.tensor import IRFullTensor +from cube.ir.cten import IRObject, IRCell, IRTensor from cube.graph.parser.frame import Frame from cube.graph.parser.fx.mapping import SignFx2Op from cube.graph.function.pyfunc import IRPyFunc @@ -247,10 +247,11 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule output_val = frame.get_var(node.name) if isinstance(ir_node, IRDimops): ir_node.infer_shape() - assert output_val.shape == ir_node.output(0).shape, ( - f'find shape inference not match: {output_val.shape} vs {ir_node.output(0).shape}' - f'\nnode: {node}' - ) + if isinstance(output_val, IRTensor) and isinstance(ir_node.output(0), IRTensor): + assert output_val.shape == ir_node.output(0).shape, ( + f'find shape inference not match: {output_val.shape} vs {ir_node.output(0).shape}' + f'\nnode: {node}' + ) ir_node.set_output(0, output_val) else: frame.set_var(node.name, ir_node) diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index 7f68f37a..68f929a8 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -8,6 +8,7 @@ import cube.graph.function as F from cube.graph.function.dimops import IRDimops from cube.ir.tensor import IRFullTensor +from cube.ir.cten import IRObject def create_op(creator: Callable, @@ -36,11 +37,21 @@ def partitionable(node: IRDimops, **config): idx=0, dim=1, num=2, ) -def create_udf_op1(input, weight, signature='test_udf_op1'): +def UDFOp1(input, weight, signature='test_udf_op1'): anno = 'L 8^ (L 2), L E -> 8^ (L 2) E ' - return IRDimops(create_udf_op1, 'udf_op1', signature, [anno], [input, weight]) + return IRDimops(UDFOp1, 'udf_op1', signature, [anno], [input, weight]) test_multi_dim_partition = partial(partitionable, - create_op(create_udf_op1, [(2048, 8, 4096), (2048, 4096)]), + create_op(UDFOp1, [(2048, 8, 4096), (2048, 4096)]), idx=0, dim=0, num=2, -) \ No newline at end of file +) + + +def test_no_return_op(): + + def NoReturnOp(input, weight, signature='no_return_op'): + anno = 'a b, b c -> ?' + return IRDimops(NoReturnOp, 'no_return_op', signature, [anno], [input, weight]) + + op = create_op(NoReturnOp, [(1024, 512), (512, 1024)]) + assert len(op.outputs()) == 1 and isinstance(op.output(0), IRObject) and (not isinstance(op.output(0), IRFullTensor)) From 3535db1f070a03ad4f8762fcfa05405269107355 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 18 Mar 2024 07:21:21 +0000 Subject: [PATCH 1604/1892] Merged PR 2065: parallel module: add broadcast support for generated files --- cube/parallel.py | 215 ++++++++++++++++++++++-- cube/runtime/module.py | 38 +++++ docs/parallel_module.md | 38 +++++ tests/parallel_module/test_broadcast.py | 126 ++++++++++++++ 4 files changed, 399 insertions(+), 18 deletions(-) create mode 100644 tests/parallel_module/test_broadcast.py diff --git a/cube/parallel.py b/cube/parallel.py index 715e50eb..a8bb7f82 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -10,6 +10,7 @@ from contextlib import contextmanager import logging import copy +import os import torch from cube.graph.parser.fx.parser import FxModuleParser @@ -209,18 +210,64 @@ class ReuseType(Enum): GRAPH = 'graph' # reuse graph only if present and match, generate otherwise. +class BroadcastGenFilesStrategy(Enum): + """ + The broadcast strategy for generated files. + Only new generated files can be broadcasted. + The files includes: + 1. config file: compute config (compute_config.pt) + 2. trace files: graph dump (graph.ckp), forward args dump(forward_args.pkl), + origin module metadata (origin_module_metadata.pt), init weights file(fullmodel.pt.*), + param name mapping (dist_param_map.pt) + 3. code: generated code files (gencode*.py) + Reused files will not be broadcasted with any of the following options. + """ + + # nothing will be broadcasted. + # You need to do it by yourself or the generated files are saved in a shared directory (like azure blob). + NONE = 'none' + + # broadcast all new generated files to all nodes. + # This is useful when you want to run the same code on all nodes. + # please note the init weight files can be huge. + ALL = 'all' + + # broadcast all new generated files except init weights (fullmodel.pt.*). + # Without weights, + # you can only construct the parallel module with `init_params=False`. + # You can then + # 1. Load the weights from a checkpoint file with `module.load_state_dict` or `load_merged_state_dict` + # 2. Or you can use `ParallelModule.broadcast_weights` to get the weights from the workers in node0. + # (local world size should be bigger than plan_ngpus) + NO_WEIGHTS = 'no_weights' + + # broadcast the new generated code only (gencode*.py) + # It's your responsibility to make sure other necessary files are available on all nodes. + CODE = 'code' + + +class RegenStatus(Enum): + NONE = 'none' # nothing is regenerated. + ALL = 'all' # everything is regenerated, including graph and code + CODE = 'code' # only code is regenerated. + + def _prepare_namespace( cube_savedir: str, module_or_module_class: Union[Type[torch.nn.Module], torch.nn.Module], instance_name: Optional[str] = None, -): +) -> Tuple[str, Path]: cube_savedir = _add_cube_savedir_to_syspath(cube_savedir) instance_name = instance_name or _DEFAULT_INSTANCE_NAME instance_name = instance_name.strip('.') if instance_name else '' instance_namespace = f'.{instance_name}' if instance_name else '' namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_or_module_class)}{instance_namespace}' - return namespace + + outdir = cube_savedir / Path(namespace.replace('.', '/').strip('/')) + outdir.mkdir(parents=True, exist_ok=True) + + return namespace, outdir def _prepare_and_check_reusable( @@ -247,9 +294,7 @@ def _prepare_and_check_reusable( RuntimeError: if the existing code is not reusable, will raise RuntimeError if the code is not reusable but the module is already loaded. """ - namespace = _prepare_namespace(cube_savedir, module_or_module_class, instance_name) - outdir = cube_savedir / Path(namespace.replace('.', '/').strip('/')) - outdir.mkdir(parents=True, exist_ok=True) + namespace, outdir = _prepare_namespace(cube_savedir, module_or_module_class, instance_name) # decision matrix for code generation # reuse flag | dir condition(imported, empty, match, unmatched) | action @@ -287,7 +332,7 @@ def _prepare_and_check_reusable( expected_output_files.append(outdir / ParallelModule.ORIGIN_MODULE_METADATA_FILE) existing_output_files = [ f for f in outdir.glob('*') - if f.is_file() and ( # just take fullmap.pt.0 to compare + if f.is_file() and ( # just take fullmodel.pt.0 to compare not f.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) or f.name == FxModuleParser.ATTR_CONTENT_FILE_0 ) @@ -364,7 +409,7 @@ def _gen_graph( ) # generate dummy inputs for logic graph - # that is, generate IRObject/IRFullTensor for fx graph dummpy input + # that is, generate IRObject/IRFullTensor for fx graph dummy input fx_input_nodes = [node for node in fx_graph.graph.nodes if node.op == 'placeholder'] # the inputs of graph is different with original forward args # so we get the real forward args from fx inputs @@ -423,7 +468,7 @@ def _gencode( *, module_dtype: Optional[torch.dtype] = None, module_fn: Optional[Callable[[], torch.nn.Module]] = None, - ) -> None: + ) -> RegenStatus: """ Generate cube module source code from a torch module, and save it to file. Generated module will be save according to its full qualified name. @@ -444,13 +489,15 @@ def _gencode( module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. Returns: - None + RegenStatus: which part is regenerated. """ graph_ckp = outdir / _GRAPH_DUMP_FILE forward_args_ckp = outdir / _FORWARD_ARGS_DUMP_FILE origin_module_metadata_ckp = outdir / ParallelModule.ORIGIN_MODULE_METADATA_FILE + ret = RegenStatus.NONE if not graph_ckp.exists() or not forward_args_ckp.exists() or not origin_module_metadata_ckp.exists(): is_module_class = inspect.isclass(module_or_module_class) + ret = RegenStatus.ALL if is_module_class: try: if module_fn is None: @@ -488,6 +535,7 @@ def _gencode( if is_module_class: del module else: + ret = RegenStatus.CODE logger.info(f"Reuse graph dump in {outdir}") graph = IRGraph.load(graph_ckp) forward_args = torch.load(forward_args_ckp) @@ -534,6 +582,8 @@ def _gencode( as_parallel_module=True, ) + return ret + def _load_cube_module_class( module_class: Type[torch.nn.Module], @@ -556,7 +606,7 @@ def _load_cube_module_class( when you need to load the generated module in a non-torchrun environment. """ rank = torch.distributed.get_rank() if rank is None else rank - namespace = _prepare_namespace(cube_savedir, module_class, instance_name) + namespace, _ = _prepare_namespace(cube_savedir, module_class, instance_name) gen_imported = importlib.import_module( f'{namespace}.{Path(_GENCODE_FILE_TEMPLATE.format(rank)).stem}' ) @@ -582,19 +632,20 @@ def parallelize( module_dtype: Optional[torch.dtype] = None, module_fn: Optional[Callable[[], torch.nn.Module]] = None, init_module_params: bool = True, + broadcast_strategy: Union[str, BroadcastGenFilesStrategy] = 'none', ) -> Union[None, ParallelModule, Type[ParallelModule]]: """ Convert a torch.nn.Module object or class to CubeModule object or class. If you want to save multiple instances of the same module, - you can specify the instance_name to distingish them. + you can specify the instance_name to distinguish them. Currently you must use a shared file system to share the generated files (like mounted Azure Blob) Or you can unset load_module flag, and manually copy the generated files to other nodes. After all nodes have the generated files, you can call parallelize() again with load_module flag set. Note: if reuse is not set to ReuseType.MATCH, - the generated code in outdir will be removed EVEN IF the code generetion fails in this call. + the generated code in outdir will be removed EVEN IF the code generation fails in this call. if the input is a module object. The module object will be copied to cpu to handle possible insufficient gpu memory. @@ -614,7 +665,7 @@ def parallelize( 2. If the input is a module class, it will return a CubeModule class if load_module is True. a. module_fn is used to create the module object, or module's__init__ if not prent. b. module_dtype is used to control the dtype of the created module (by constructor or module_fn). - Of cousre, it can be merged into module_fn. + Of course, it can be merged into module_fn. c. init_module_params is ignored. After the module is converted, you can use it to create module object by calling it like a module class. @@ -627,7 +678,7 @@ def __init__(self, init_params=True): ... ``` So you can use `init_params` in `__init__` to control whether to initialize the module parameters. - For example, if you don't want to intialize module params: + For example, if you don't want to initialize module params: ``` module = GenModule(init_params=False) ``` @@ -647,7 +698,9 @@ def __init__(self, init_params=True): so it is only used when module_or_module_class is a module object, and load_module is true. module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. - + broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for generated files. + Please note that the broadcasting will only be done in torchrun environment, + and will throw an error if torch.distributed is not initialized and broadcast_strategy is not NONE. Returns: Union[CubeModule, Type[CubeModule], None]: if load_module flag is set, return the converted CubeModule object or class @@ -662,6 +715,7 @@ def __init__(self, init_params=True): is_module_class = inspect.isclass(module_or_module_class) module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ reuse = ReuseType(reuse) if isinstance(reuse, str) else reuse + broadcast_strategy = BroadcastGenFilesStrategy(broadcast_strategy) if isinstance(broadcast_strategy, str) else broadcast_strategy # genereate code only in node0 # if it is not in a torchrun environment, just generate. @@ -672,7 +726,7 @@ def __init__(self, init_params=True): if not config_file.exists(): torch.save(compute_config, config_file) with _compile_flags(compute_config): - _gencode( + regen_status = _gencode( module_or_module_class, dummy_input, pas_policy, @@ -682,8 +736,47 @@ def __init__(self, init_params=True): module_fn=module_fn, ) else: + regen_status = RegenStatus.NONE logger.info(f"Reuse generated code in {outdir}") + if broadcast_strategy != BroadcastGenFilesStrategy.NONE: + if not torch.distributed.is_initialized(): # we only support loading in torchrun environment + raise RuntimeError("Broadcast generated files failed: torch.distributed is not initialized.") + torch.distributed.barrier() + # sync regen_status + curr_rank = torch.distributed.get_rank() + if curr_rank == 0: + sent_obj = [regen_status] + else: + sent_obj = [None] + torch.distributed.broadcast_object_list( + sent_obj, + src=0, + ) + if curr_rank != 0: + regen_status = sent_obj[0] + + # narrow down broadcast_strategy according to regen_status + if regen_status == RegenStatus.NONE: + # we don't need to broadcast anything + broadcast_strategy = BroadcastGenFilesStrategy.NONE + elif regen_status == RegenStatus.CODE: + # narrow ALL/NO_WEIGHTS down to code + broadcast_strategy = BroadcastGenFilesStrategy.CODE + else: + # we don't need to narrow broadcast_strategy in this case + # keep the original broadcast_strategy + assert regen_status == RegenStatus.ALL + + # broadcast generated files according to regen_status + if broadcast_strategy != BroadcastGenFilesStrategy.NONE: + _broadcast_gen_files( + module_class, + cube_savedir=cube_savedir, + instance_name=instance_name, + broadcast_strategy=broadcast_strategy, + ) + if load_module: if not torch.distributed.is_initialized(): # we only support loading in torchrun environment raise RuntimeError("Load CubeModule failed: torch.distributed is not initialized.") @@ -831,7 +924,7 @@ def build_optimizer( so we need to replace the parameters of optimizer with CubeModule.parameters_for_optimizer It is impossible to make this change transparent to end users. 2. optimizer.step(): - we need to call optimier.sync_shard_grad() to sync the gradients of the module before optimizer.step(). + we need to call optimizer.sync_shard_grad() to sync the gradients of the module before optimizer.step(). In zero mode, we have to call CubeModule.gather_params() after optimizer.step() 3. optimizer.zero_grad(): We need to call CubeModule.zero_grad() after optimizer.zero_grad() @@ -905,7 +998,7 @@ def _local_parameters(module: torch.nn.Module): for idx, (name, param) in enumerate(gen): if name.endswith(cube_suffix): # is a parameter of ParallelModule # -1 for removing the dot - # please note when the whole module is a ParallModule, + # please note when the whole module is a ParallelModule, # the name will be empty after removing the suffix name = name[:-len(cube_suffix) - 1] if name not in opt_module_locs: @@ -1574,3 +1667,89 @@ def _get_valid_name_from_merged_model( return name break return None + + +def _broadcast_gen_files( + module_class: Type[torch.nn.Module], + *, + cube_savedir: Union[str, Path] = './.cube', + instance_name: Optional[str] = None, + broadcast_strategy: Union[str, BroadcastGenFilesStrategy], +): + """ + Broadcast new generated files for a module to all nodes. + + Args: + module_class (Type[torch.nn.Module]): the original torch module class + cube_savedir (Union[str, Path]): the directory to save generated code + instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. + broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for generated files. + + Returns: + None + """ + + broadcast_strategy = BroadcastGenFilesStrategy(broadcast_strategy) if isinstance(broadcast_strategy, str) else broadcast_strategy + if broadcast_strategy == BroadcastGenFilesStrategy.NONE: + return + + world_size = torch.distributed.get_world_size() + local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', default=1)) + assert world_size % local_world_size == 0, "world_size should be a multiple of local_world_size" + nnode = world_size // local_world_size + + if nnode == 1: + # no need to broadcast generated files + return + + curr_rank = torch.distributed.get_rank() + ranks = list(range(0, world_size, local_world_size)) + group = DeviceGroup().get_group(ranks) + + # use the first rank of each node to broadcast + if curr_rank % local_world_size == 0: + _, outdir = _prepare_namespace(cube_savedir, module_class, instance_name) + files: List[str] = [] + # send file list + if curr_rank == 0: + for file in outdir.glob('*'): + if file.is_file() and ( + broadcast_strategy == BroadcastGenFilesStrategy.ALL or + ( + broadcast_strategy == BroadcastGenFilesStrategy.NO_WEIGHTS + and not file.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) + ) or + ( + broadcast_strategy == BroadcastGenFilesStrategy.CODE + and file.suffix == '.py' + ) + ): + files.append(file.name) + sent_obj = [files] + else: + sent_obj = [None] + torch.distributed.broadcast_object_list( + sent_obj, + src=0, + group=group, + ) + # get file list + if curr_rank != 0: + files = sent_obj[0] + + logging.info(f'File list broadcasted ({len(files)} in total).') + # send file content one by one + for fname in files: + if curr_rank == 0: + with open(outdir / fname, 'rb') as f: + data = [f.read()] + else: + data = [None] + torch.distributed.broadcast_object_list(data, src=0, group=group) + if curr_rank != 0: + with open(outdir / fname, 'wb') as f: + f.write(data[0]) + logging.info(f'File {fname} broadcasted.') + + # wait for all nodes to finish + torch.distributed.barrier() diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 18ed857a..c076aa0b 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -830,3 +830,41 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s raise RuntimeError(erro_msg) else: _logger.warning(erro_msg) + + def broadcast_weights(self): + """ + Broadcast weights (including parameters and buffers) across scale units. + The source ranks is the ranks in first scale unit. + The weights in the ranks in the rest scale units will be replace inplace. + """ + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', default=1)) + plan_ngpus = self.get_compute_config().plan_ngpus + + if local_world_size < plan_ngpus: + raise RuntimeError(f'LOCAL_WORLD_SIZE {local_world_size} is less than plan_ngpus {self.get_compute_config().plan_ngpus}. Cannot broadcast weights.') + + for i in range(plan_ngpus): + ranks = list(range(i, world_size, plan_ngpus)) + DeviceGroup().get_group(ranks) + + curr_parallel_group_ranks = list(range(rank % plan_ngpus, world_size, plan_ngpus)) + curr_parallel_group = DeviceGroup().get_group(curr_parallel_group_ranks) + src_rank = min(curr_parallel_group_ranks) + logging.info(f'Rank-{rank} is broadcasting weight to ranks {curr_parallel_group_ranks}, broadcast root: {src_rank}...') + + # NOTE: please make sure the above checkpoint load is from local checkpoint file, + # otherwise, the following broadcast may time out due to slow checkpoint file read. + # Broadcast parameters and buffers across scale units + params = self.parameters_for_broadcast() + logging.info(f'Inplace broadcasting {len(params)} parameters...') + for i, param in enumerate(params): + torch.distributed.broadcast(param.data, src=src_rank, group=curr_parallel_group) + logging.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') + + # NOTE: may batch buffers for efficient broadcast, + # current implementation is the most memory efficient way. + logging.info(f'Inplace broadcasting {len(self._buffers)} buffers...') + for _, buffer in self._buffers.items(): + torch.distributed.broadcast(buffer.data, src=src_rank, group=curr_parallel_group) diff --git a/docs/parallel_module.md b/docs/parallel_module.md index c17f6094..4bbc6e0f 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -175,6 +175,7 @@ Note: ### ReuseType The reuse policy for the existing generated code. It is an enum with the following values: + ```python class ReuseType(Enum): MATCH = 'match' @@ -189,6 +190,36 @@ We call it a `match` when the `ComputeConfig` is the same with the previous run. 3. MOO: MOO is short for 'match or override'. It will reuse if match, generate if not match or no previous gerenated code exists. 4. GRAPH: Reuse graph only if match, generate otherwise. +### BroadcastGenFilesStrategy + +The broadcast strategy for new generated files. +Please note we never broadcast reused files. + +```python +class BroadcastGenFilesStrategy(Enum): + NONE = 'none' + ALL = 'all' + NO_WEIGHTS = 'no_weights' + CODE = 'code' +``` + +1. None: nothing will be broadcasted. + You need to do it by yourself or the generated files are save in a shared directory (like azure blob). + +2. ALL: broadcast all the generated files to all nodes. + This is useful when you want to run the same code on all nodes. + please note the init weight files can be huge. + +3. NO_WEIGHTS: broadcast all except init weights. + Without weights, you can only construct the parallel module with `init_params=False`. + You can then + - Load the weights from a checkpoint file with `module.load_state_dict` or `load_merged_state_dict` + - Or you can use `broadcast_weights_inplace` to get the weights from the workers in node0. + (local world size should be bigger than plan_ngpus) + +4. CODE: broadcast the new generated code only + It's your responsibility to make sure other necessary files are available on all nodes. + ### Module Conversion @@ -207,6 +238,7 @@ def parallelize( module_dtype: Optional[torch.dtype] = None, module_fn: Optional[Callable[[], torch.nn.Module]] = None, init_module_params: bool = True, + broadcast_strategy: Union[str, BroadcastGenFilesStrategy] = 'none', ) -> Union[None, ParallelModule, Type[ParallelModule]]: ``` It has the following parameters: @@ -237,6 +269,8 @@ so it is only used when `module_or_module_class` is a module object, and `load_m - module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use `__init__` if it is None. This parameter is only used when `module_or_module_class` is a module class. +- broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for new generated files. + Note: 1. This function can be used to convert both module object and module class to cube module or cube module class. @@ -254,6 +288,9 @@ After all nodes have the generated files, you can call `parallelize()` again wit 4. if reuse is not set to ReuseType.MATCH, the generated code in outdir will be removed EVEN IF the code generetion fails in this call. +5. For broadcast_strategy, Please note that the broadcast will only be done in torchrun environment, + and will throw an error if torch.distributed is not initialized and broadcast_strategy is not NONE. + After the module is converted, you can use it to create module object by calling it like a module class. The module class is defined like: ```python @@ -269,6 +306,7 @@ For example, if you don't want to intialize module params: module = GenModule(init_params=False) ``` + ### Optimizer Creation We have `build_optimizer` to build an optimizer for distributed training. diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py new file mode 100644 index 00000000..4ea30cc5 --- /dev/null +++ b/tests/parallel_module/test_broadcast.py @@ -0,0 +1,126 @@ +import tempfile +import os +from pathlib import Path + +import pytest +import torch + +from cube.parallel import ComputeConfig, parallelize + +from .common import PASRandomSPMD, init_distributed +from ..launch_torchrun import launch_torchrun + + +class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x) + + +def _to_cube_model(module, compute_config, cube_savedir, + instance_name=None, load_module=False, + broadcast_strategy='none', + **kwargs, +): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + PASRandomSPMD, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name, + load_module=load_module, + broadcast_strategy=broadcast_strategy, + **kwargs, + ) + + +def _gpu_worker(): + init_distributed() + # fake two machines, as we use different cube_savedir for each worker + os.environ['LOCAL_WORLD_SIZE'] = '1' + p = lambda t, b, i, load_module=True, **kwargs: _to_cube_model( + Module(), + ComputeConfig(1, 2), + t, + load_module=load_module, + broadcast_strategy=b, + instance_name=i, + **kwargs, + ) + # case 1: no broadcast, so only rank 0 can load the module + # rank 1 will raise ModuleNotFoundError + with tempfile.TemporaryDirectory() as tempdir: + if torch.distributed.get_rank() == 0: + p(tempdir, 'none', '_1') + else: + with pytest.raises(ModuleNotFoundError): + p(tempdir, 'none', '_1') + + # case 2: broadcast only code, so only rank 0 can load the module + # rank 1 will raise RuntimeError because it will fail to load fullmodel.pt + with tempfile.TemporaryDirectory() as tempdir: + if torch.distributed.get_rank() == 0: + p(tempdir, 'code', '_2') + else: + with pytest.raises(RuntimeError, match='Cannot find file.*'): + p(tempdir, 'code', '_2') + + # case 3: broadcast except weights, so only rank 0 can load the module + # rank 1 will raise RuntimeError because it will fail to load fullmodel.pt + with tempfile.TemporaryDirectory() as tempdir: + if torch.distributed.get_rank() == 0: + p(tempdir, 'no_weights', '_3') + else: + with pytest.raises(RuntimeError, match='Cannot find file.*'): + p(tempdir, 'no_weights', '_3') + + # case 4: broadcast except weights, every rank can succeed if don't lood init params + with tempfile.TemporaryDirectory() as tempdir: + m = p(tempdir, 'no_weights', '_4', + init_module_params=torch.distributed.get_rank() == 0 + ) + if torch.distributed.get_rank() == 0: + for n, pa in m.named_parameters(): + if n.startswith('linear_weight'): + pa.data.fill_(1.0) + else: + for n, pa in m.named_parameters(): + if n.startswith('linear_weight'): + assert not torch.equal(pa.data, torch.ones_like(pa.data)) + m.broadcast_weights() + # check if broadcast_weights works + for n, pa in m.named_parameters(): + if n.startswith('linear_weight'): + assert torch.equal(pa.data, torch.ones_like(pa.data)) + + # case 5: broadcast all, all ranks will succeed + with tempfile.TemporaryDirectory() as tempdir: + p(tempdir, 'all', '_5') + + # case 6: test incremental broadcast + with tempfile.TemporaryDirectory() as tempdir: + # generate without broadcasting + m = p(tempdir, 'none', '_6', load_module=False) + if torch.distributed.get_rank() != 0: + assert list(Path(tempdir).glob('*')) == [] + + # case 6.1: broadcast code even we set broadcast_strategy to `all` + # because only code is new generated. + m = p(tempdir, 'all', '_6', load_module=False, reuse='graph') + if torch.distributed.get_rank() != 0: + # only python files are broadcasted + assert list(f.name for f in Path(tempdir).glob('**/*') if f.is_file()) == ['gencode0.py', 'gencode1.py'] + + # case 6.2: everything should be broadcasted, including weights + # so the load_module will succeed. + m = p(tempdir, 'all', '_6', load_module=True, reuse='override') + + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_broadcast(): + launch_torchrun(2, _gpu_worker) From da934fa6b0b415163bcf10199197f3ef6bf85310 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Tue, 19 Mar 2024 03:54:09 +0000 Subject: [PATCH 1605/1892] Merged PR 2068: add test for gnorm add test for gnorm - dp case - dp + scale unit case - pp case - pp + scale unit case --- tests/runtime/test_gnorm.py | 116 ++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tests/runtime/test_gnorm.py diff --git a/tests/runtime/test_gnorm.py b/tests/runtime/test_gnorm.py new file mode 100644 index 00000000..870ec565 --- /dev/null +++ b/tests/runtime/test_gnorm.py @@ -0,0 +1,116 @@ +""" +This test is to verify the correctness of the gradient norm algorithm for cube. + +To avoid other potential parity issues that may have influence the gradient value, +we use weight data as gradient, and calculate its norm to verify the correctness +of gnorm calculation. +""" +import torch +from functools import partial + +import cube +from cube.ir.operator import IRFwOperation +from cube.runtime.module import CubeModule +from cube.runtime.gnorm import prepare_for_grad_clip, clip_gnorm +from cube.flags import CompileFlag + +from ..launch_torchrun import torchrun +from ..utils import init_parameter + + +class Module(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(16, 16, bias=False) + self.linear2 = torch.nn.Linear(16, 16, bias=False) + self.linear3 = torch.nn.Linear(16, 16, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return torch.sum(x) + + +def tensor_parallelism(graph, node, idx, dim, num): + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=idx, dim=dim, num=num) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return sub_nodes + + +def cal_wnorm_baseline(model): + wnorm = torch.norm( + torch.stack([torch.norm(p, p=2, dtype=torch.float32) for p in model.parameters()]) + ) + return wnorm + + +def cal_wnorm_cube(model: CubeModule): + for p in model.parameters_for_optimizer(): + p.grad = p.data + # p.grad.copy_(p.data) + nreplicas2localparams = prepare_for_grad_clip(model, is_zero=CompileFlag.use_zero) + wnorm, _ = clip_gnorm(nreplicas2localparams, None) + # maps = {tid: [t.size() for t in ts] for tid, ts in nreplicas2localparams.items()} + # print(f'cube nrepicas len: {maps}') + return wnorm + +# su_num: scale unit number +def dp_policy(graph, resource, su_num): + ngpus = resource.ngpus // su_num + for node in graph.select(ntype=IRFwOperation): + tensor_parallelism(graph, node, idx=0, dim=0, num=ngpus) + return graph + +def pp_policy(graph, resource, su_num): + ngpus = resource.ngpus // su_num + devid = 0 + for node in graph.select(ntype=IRFwOperation): + graph.assign(node, devid) + devid = (devid + 1) % ngpus + return graph + + +def model_test(policy, su_num: int = 1, use_zero: bool = False): + # su_num: scale unit number + cube.init() + CompileFlag.use_zero = use_zero + + model = Module().cuda() + init_parameter(model) + + # get baseline weight norm + wnorm_baseline = cal_wnorm_baseline(model) + + sample = torch.randn(16, 16).cuda() + @cube.compile(model, sample, PAS=partial(policy, su_num=su_num), + scale=su_num > 1) + def train_iter(model, data): + loss = model(data) + loss.backward() + return loss + + model = cube.load_model() + + # train_iter(model, sample) # link .grad to reducer buffer + wnorm_cube = cal_wnorm_cube(model) + + for rank in range(torch.distributed.get_world_size()): + if rank == torch.distributed.get_rank(): + print(f'rank: {rank}: baseline wnorm: {wnorm_baseline}') + print(f'rank: {rank}: cube wnorm: {wnorm_cube}') + torch.distributed.barrier() + + assert wnorm_cube == wnorm_baseline + + +test_norm_case1_dp = partial(torchrun, 2, model_test, dp_policy) +test_norm_case1_dp_su = partial(torchrun, 4, model_test, dp_policy, 2) +test_norm_case1_dp_zero = partial(torchrun, 2, model_test, dp_policy, 1, True) + +test_norm_case2_pp = partial(torchrun, 2, model_test, pp_policy) +test_norm_case2_pp_su = partial(torchrun, 4, model_test, pp_policy, 2) +test_norm_case2_pp_su_zero = partial(torchrun, 4, model_test, dp_policy, 2, True) \ No newline at end of file From 345485a97be6b123395bac6dc493de0efb0f666a Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Fri, 22 Mar 2024 08:16:02 +0000 Subject: [PATCH 1606/1892] Merged PR 2073: fix pipeline bug in scale cases fix pipeline bug when the number of scale unit is larger than 1. The bug is due to the invisibility of parameters in sub-level segments, which misses to create reducers for non-reduced parameters. --- cube/codegen/module/module.py | 24 ++++-- cube/graph/segment.py | 36 ++++---- tests/compiler/test_compile.py | 147 ++++++++++++++++++++++----------- 3 files changed, 137 insertions(+), 70 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 2a64969f..b25f4d53 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -162,14 +162,22 @@ def add_scale_reducers(self): all_params.update(reducer.inputs()) # create a reducer for the rest parameters used for this device rest_params = [] - for param in self.execplan.graph.attributes(): - if not param.is_param(): continue - for ctensor in graph.ctensors(param): - if device not in ctensor.device: continue - if ctensor not in all_params: - # a same parameter can be consumed multiple times by different operators - if ctensor not in rest_params: - rest_params.append(ctensor) + + def collect_rest_params(segment): + """Resursively collect parameters. Note parameters can be in sub-segments, + which is invisible to its top-level segment.""" + for param in segment.attributes(): + if not param.is_param(): continue + for ctensor in segment.ctensors(param): + if device not in ctensor.device: continue + if ctensor not in all_params: + # a same parameter can be consumed multiple times by different operators + if ctensor not in rest_params: + rest_params.append(ctensor) + for seg in segment.select(ntype=IRSegment, flatten=False): + collect_rest_params(seg) + + collect_rest_params(graph) if len(rest_params) == 0: continue # create reducer and append to the execution diff --git a/cube/graph/segment.py b/cube/graph/segment.py index a894c079..91723810 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -642,25 +642,31 @@ def exist(self, node: IRCell, flatten: bool = True) -> bool: return False def select(self, name: Optional[str] = None, ntype: Optional[IRCell] = None, flatten: bool = True) -> List[IRCell]: - """ - Select all the nodes (including nodes in sub-segment) that - satisfy the condition. + """Select all the nodes that satisfy all the specified conditions. - @param name Optional[str]: the node name - @param ntype Optional[Type]: the node type - @param flatten bool: whether to flatten the segment to nodes. (Default True) + Note: + Current IRGraph can have at most a 2-level hierarchy (IRGraph[IRSegment]). + We don't allow IRSegment inside IRSegment. So when users try to index + IRSegment, turn `flatten=False` will get the same result as `flatten=True`, + and can save more time because `flatten=False` will not traverse the + nodes in IRSegment. + + Args: + name (Optional[str]): the node name + ntype (Optional[Type]): the node type + flatten (bool): whether to recursively search the nodes inside segments (Default True). - @return nodes List[IRCell]: the nodes that have the name. + Returns: + List[IRCell]: the nodes that satisfied the name or ntype. """ nodes = [] - for node in self.nodes(flatten=flatten): - if name is not None: - if node.name != name: - continue - if ntype is not None: - if not isinstance(node, ntype): - continue - nodes.append(node) + for node in self._nodes: + if (name is None or name == node.name) and \ + (ntype is None or isinstance(node, ntype)): + nodes.append(node) + # recursively search in sub-segment + if flatten and isinstance(node, IRSegment): + nodes += node.select(name, ntype, flatten) return nodes def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwOperation: diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index 49234484..e8615d29 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -2,12 +2,14 @@ pytest unit_tests/compiler/test_compile.py """ import torch -import logging from functools import partial +import more_itertools as mitr import cube +from cube.runtime.utils import microbatches from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation +from cube.graph.segment import IRSegment +from cube.ir.operator import IRFwOperation, IRDataOperation from cube.flags import CompileFlag from ..launch_torchrun import torchrun from ..utils import init_parameter, assert_parity @@ -57,38 +59,64 @@ def baseline(): return losses -def scale(ngpus_per_unit: int): +# ================================== cube functionality ======================================== - model = MLP() - init_parameter(model) +def pipe_policy(graph: IRGraph, resource, ngpus_per_unit: int): + + ngpus = min(ngpus_per_unit, resource.ngpus) + fnodes = graph.select(ntype=IRFwOperation) - def policy(graph: IRGraph, resource): + stages = mitr.divide(ngpus, fnodes) + stages = [list(s) for s in stages] + lead_nodes = [s[0] for s in stages] + graph.staging(lead_nodes) + + for dl in graph.select(ntype=IRDataOperation): + graph.assign(dl, 0) + + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + for idx, stage in enumerate(stages): + graph.assign(stage, idx) + return graph + + +def tp_policy(graph: IRGraph, resource, ngpus_per_unit: int): + + ngpus = min(ngpus_per_unit, resource.ngpus) - ngpus = min(ngpus_per_unit, resource.ngpus) + def tensor_parallelism(node, idx, dim, num): + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=idx, dim=dim, num=num) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return sub_nodes - def tensor_parallelism(node, idx, dim, num): - sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=idx, dim=dim, num=num) + l1, l2, l3, l4 = graph.select(name='linear') + + # l1 tensor parallelism + tensor_parallelism(l1, idx=1, dim=0, num=ngpus) + # l2 data parallelism + tensor_parallelism(l2, idx=0, dim=0, num=ngpus) + # l3 tensor parallelism + tensor_parallelism(l3, idx=1, dim=1, num=ngpus) + # l4 replicate + + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if len(node.device) == 0: + sub_nodes = graph.replicate(node, times=ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) - return sub_nodes - - l1, l2, l3, l4 = graph.select(name='linear') - - # l1 tensor parallelism - tensor_parallelism(l1, idx=1, dim=0, num=ngpus) - # l2 data parallelism - tensor_parallelism(l2, idx=0, dim=0, num=ngpus) - # l3 tensor parallelism - tensor_parallelism(l3, idx=1, dim=1, num=ngpus) - # l4 replicate - - for node in graph.select(ntype=IRFwOperation): - if len(node.device) == 0: - sub_nodes = graph.replicate(node, times=ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return graph + return graph + + +def cube_run(ngpus_per_unit: int, policy): + + cube.init() + CompileFlag.disable_code_line_info = True # speedup parse + + model = MLP() + init_parameter(model) ngpus_per_unit = min(ngpus_per_unit, torch.distributed.get_world_size()) nreplicas = torch.distributed.get_world_size() // ngpus_per_unit @@ -96,8 +124,13 @@ def tensor_parallelism(node, idx, dim, num): print('>> set batch size to', batch_size) x = get_dummy_data(batch_size=batch_size) - @cube.compile(model, x, PAS=policy, scale=True) - def train_iter(model, x): + dl = microbatches([x,]) + + policy = partial(policy, ngpus_per_unit=ngpus_per_unit) + + @cube.compile(model, dl, PAS=policy, scale=True) + def train_iter(model, dataloader): + x = next(iter(dataloader)) loss = model(x) loss.backward() return loss @@ -108,7 +141,8 @@ def train_iter(model, x): losses = [] for _ in range(3): x = get_dummy_data(batch_size=batch_size) - loss = train_iter(model, x) + dl = microbatches([x,]) + loss = train_iter(model, dl) loss = loss * nreplicas optimizer.step() optimizer.zero_grad() @@ -119,19 +153,38 @@ def train_iter(model, x): return losses - -def scale_test(): - cube.init() - CompileFlag.disable_code_line_info = True # speedup parse - assert_parity(baseline, partial(scale, 2)) - - -def scale_test_dp(): - cube.init() - CompileFlag.disable_code_line_info = True # speedup parse - assert_parity(baseline, partial(scale, 1)) - - -test_scale_2gpu = partial(torchrun, 2, scale_test) -test_scale_2gpu_dp = partial(torchrun, 2, scale_test_dp) -test_scale_4gpu = partial(torchrun, 4, scale_test) +# single-gpu test +test_single = partial(torchrun, 1, assert_parity, + baseline, + partial(cube_run, 1, tp_policy) +) + +# scale test +test_scale2 = partial(torchrun, 2, assert_parity, + baseline, + partial(cube_run, 1, tp_policy) +) + +# tensor parallelism test +test_tp2 = partial(torchrun, 2, assert_parity, + baseline, + partial(cube_run, 2, tp_policy) +) + +# tensor parallelism + scale test +test_tp2scale2 = partial(torchrun, 4, assert_parity, + baseline, + partial(cube_run, 2, tp_policy) +) + +# pipeline parallelism test +test_pipe2 = partial(torchrun, 2, assert_parity, + baseline, + partial(cube_run, 2, pipe_policy) +) + +# pipeline parallelism + scale test +test_pipe2scale2 = partial(torchrun, 4, assert_parity, + baseline, + partial(cube_run, 2, pipe_policy) +) From f3ffb602348e850f291ece3e4b68c778527de0e7 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 27 Mar 2024 02:38:21 +0000 Subject: [PATCH 1607/1892] Merged PR 2071: parallel module: add dedup state dict support unit test pass parity check pass Related work items: #1846 --- cube/parallel.py | 370 +++++++++++++++--- cube/runtime/module.py | 150 +++---- cube/utils.py | 44 ++- docs/parallel_module.md | 136 +++++-- tests/parallel_module/common.py | 18 +- tests/parallel_module/test_broadcast.py | 6 +- .../parallel_module/test_checkpoint_dedup.py | 198 ++++++++++ 7 files changed, 757 insertions(+), 165 deletions(-) create mode 100644 tests/parallel_module/test_checkpoint_dedup.py diff --git a/cube/parallel.py b/cube/parallel.py index a8bb7f82..899d4dd1 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -39,7 +39,7 @@ from cube.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState from cube.runtime.device import DeviceGroup from cube.runtime.gnorm import calcuate_gnorm, clip_grads -from cube.utils import get_member_by_name +from cube.utils import get_member_by_name, setup_stride_broadcast_group logger = logging.getLogger(__name__) @@ -56,7 +56,8 @@ class ComputeConfig: zero_ngroups: int = 1 # you can put any configuration here - # *Note*: the assumption is different user_config should generate different code. + # Note: different user_config should generate different graph/code. + # so if user_config is changed, both graph and code will be regenerated. # Example 1: save module configuration # ```python # class MyModule(torch.nn.Module): @@ -99,8 +100,11 @@ def __post_init__(self): raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be > 0") if self.runtime_ngpus % self.plan_ngpus != 0: raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be a multiple of plan_ngpus {self.plan_ngpus}") - if self.use_zero and self.zero_ngroups < 0: - raise ValueError(f"zero_ngroups {self.zero_ngroups} must be >= 0") + if self.use_zero and self.zero_ngroups <= 0: + raise ValueError(f"zero_ngroups {self.zero_ngroups} must be > 0") + if not self.use_zero and self.zero_ngroups != 1: + logger.warning(f"use_zero is False, but zero_ngroups is {self.zero_ngroups}. Will set zero_ngroups to 1.") + self.zero_ngroups = 1 @property def gpu_config(self) -> Dict[str, int]: @@ -109,6 +113,34 @@ def gpu_config(self) -> Dict[str, int]: 'runtime_ngpus': self.runtime_ngpus, } + @property + def graph_config(self) -> Dict[str, Any]: + return { + 'dynamic_shape': self.dynamic_shape, + 'user_config': self.user_config, + } + + @property + def module_dedup_group_size(self) -> int: + """ + Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. + """ + return self.plan_ngpus + + @property + def optimizer_dedup_group_size(self) -> int: + """ + Get the size of the deduplication group of the optimizer state dict. + + Nonzero mode: the group size is the same with plan_ngpus + Zero mode: the group size is `zero_group`, which equals `runtime_ngpus//zero_ngroups` + """ + + if self.use_zero: + return self.runtime_ngpus // self.zero_ngroups + else: + return self.plan_ngpus + @contextmanager def _flags(flags, /, **kwargs): @@ -237,7 +269,7 @@ class BroadcastGenFilesStrategy(Enum): # you can only construct the parallel module with `init_params=False`. # You can then # 1. Load the weights from a checkpoint file with `module.load_state_dict` or `load_merged_state_dict` - # 2. Or you can use `ParallelModule.broadcast_weights` to get the weights from the workers in node0. + # 2. Or you can use `broadcast_weights` to get the weights from the workers in node0. # (local world size should be bigger than plan_ngpus) NO_WEIGHTS = 'no_weights' @@ -807,6 +839,7 @@ class ModuleParameterLocation: offset: int count: int + @dataclass class OptimizerExtraState: """ @@ -829,11 +862,17 @@ class OptimizerExtraState: rank: int name: str parallel_module_locs: Dict[str, ModuleParameterLocation] + parallel_module_configs: Dict[str, ComputeConfig] def __post_init__(self): - for k in self.parallel_module_locs: - if isinstance(self.parallel_module_locs[k], dict): - self.parallel_module_locs[k] = ModuleParameterLocation(**self.parallel_module_locs[k]) + self.parallel_module_locs = { + k: ModuleParameterLocation(**v) if isinstance(v, dict) else v + for k, v in self.parallel_module_locs.items() + } + self.parallel_module_configs = { + k: ComputeConfig(**v) if isinstance(v, dict) else v + for k, v in self.parallel_module_configs.items() + } class ParallelOptimizer(torch.optim.Optimizer): @@ -1013,6 +1052,11 @@ def _local_parameters(module: torch.nn.Module): rank=torch.distributed.get_rank(), name=type(optimizer).__name__, parallel_module_locs=opt_module_locs, + parallel_module_configs={ + name: m.get_compute_config() + for name, m in module.named_modules() + if isinstance(m, ParallelModule) + } ) def _step_pre_hook(opt, *args, **kwargs): @@ -1137,8 +1181,8 @@ def _get_parallel_module_state_dict_info( module_prefix = k[:-1] if module_prefix not in pm_extra_states: pm_extra_states[module_prefix] = [None] * len(pk_model_state_dicts) - opt_extra_state = ExtraState(**pk_model_state_dict[k]) - pm_extra_states[module_prefix][opt_extra_state.rank] = opt_extra_state + pm_extra_state = ExtraState(**pk_model_state_dict[k]) + pm_extra_states[module_prefix][pm_extra_state.rank] = pm_extra_state # collect ParallelModule state dicts # key is the module prefix of the parallel module in state dict @@ -1152,10 +1196,13 @@ def _get_parallel_module_state_dict_info( continue module_prefix = k[:-1] if module_prefix in pm_extra_states: + pm_extra_state = ExtraState(**pk_model_state_dict[module_prefix + (ParallelModule.EXTRA_STATE_KEY,)]) + module_dedup_group_size = pm_extra_state.compute_config.module_dedup_group_size if module_prefix not in pm_state_dicts: - pm_state_dicts[module_prefix] = [dict() for _ in range(len(pk_model_state_dicts))] - opt_extra_state = ExtraState(**pk_model_state_dict[module_prefix + (ParallelModule.EXTRA_STATE_KEY,)]) - pm_state_dicts[module_prefix][opt_extra_state.rank][k[-1]] = pk_model_state_dict[k] + pm_state_dicts[module_prefix] = [dict() for _ in range(module_dedup_group_size)] + # only collect the state from the first module_dedup_group_size ranks + if pm_extra_state.rank < module_dedup_group_size: + pm_state_dicts[module_prefix][pm_extra_state.rank][k[-1]] = pk_model_state_dict[k] else: # no further processing # here we assume values from all ranks are the same @@ -1225,25 +1272,30 @@ def _get_optimizer_state_dict_info( opt_extra_states[opt_extra_state.rank] = opt_extra_state for module_prefix, loc in opt_extra_state.parallel_module_locs.items(): + opt_dedup_group_size = opt_extra_state.parallel_module_configs[module_prefix].optimizer_dedup_group_size if module_prefix not in opt_state_dicts: - opt_state_dicts[module_prefix] = [dict(state={}, param_groups=[]) for _ in range(len(optimizer_state_dicts))] - for i in range(loc.offset, loc.offset + loc.count): - # if the parameter is not used or requires_grad is False, it will not be in the state dict - # for us, as we use a continous buffer, it will always have grad, so it will always be in the state dict - # the state for each parameters is inserted in Adam in a lazy way. - # see https://github.com/pytorch/pytorch/blob/dad1b765848c4f52501c4c60b1c3e6fbd3cc8837/torch/optim/adam.py#L103 - assert i in opt_state_dict['state'] - opt_state_dicts[module_prefix][opt_extra_state.rank]['state'][i - loc.offset] = opt_state_dict['state'][i] - # TODO: inaccurate param_groups, for example, the 'params' in it is not right. - # we have this to make `ParallelModule.merge_partial_states` happy. - opt_state_dicts[module_prefix][opt_extra_state.rank]['param_groups'] = copy.deepcopy(opt_state_dict['param_groups']) + opt_state_dicts[module_prefix] = [dict(state={}, param_groups=[]) for _ in range(opt_dedup_group_size)] + # only collect the state from the first optimizer_dedup_group_size ranks + if opt_extra_state.rank < opt_dedup_group_size: + for i in range(loc.offset, loc.offset + loc.count): + # if the parameter is not used or requires_grad is False, it will not be in the state dict + # for us, as we use a continous buffer, it will always have grad, so it will always be in the state dict + # the state for each parameters is inserted in Adam in a lazy way. + # see https://github.com/pytorch/pytorch/blob/dad1b765848c4f52501c4c60b1c3e6fbd3cc8837/torch/optim/adam.py#L103 + assert i in opt_state_dict['state'] + opt_state_dicts[module_prefix][opt_extra_state.rank]['state'][i - loc.offset] = opt_state_dict['state'][i] + # TODO: inaccurate param_groups, for example, the 'params' in it is not right. + # we have this to make `ParallelModule.merge_partial_states` happy. + opt_state_dicts[module_prefix][opt_extra_state.rank]['param_groups'] = copy.deepcopy(opt_state_dict['param_groups']) for k, v in opt_state_dict.items(): if k == ParallelModule.EXTRA_STATE_KEY or k == 'state': continue # no further processing # here we assume values from all ranks are the same - ret_opt_state_dict[k] = v + # the value may change, so we deepcopy to make sure the input is not accidentally changed + # for example, it will updated in `merge_state_dict` function. + ret_opt_state_dict[k] = copy.deepcopy(v) return opt_extra_states, opt_state_dicts, ret_opt_state_dict @@ -1271,11 +1323,6 @@ def merge_state_dicts( Returns: Tuple[Dict[str, Any], Optional[List[Dict[str, Any]]]]: the merged model state dict and the merged optimizer state dict """ - if optimizer_state_dicts is not None: - # TODO: support checkpoint optimization - # where the following check may be too strong. - if len(module_state_dicts) != len(optimizer_state_dicts): - raise ValueError("The length of model_state_dicts and optimizer_state_dicts should be the same.") if not module_state_dicts: raise ValueError("model_state_dicts should not be empty.") @@ -1300,19 +1347,16 @@ def merge_state_dicts( for k, state_dicts_for_merge in pm_state_dicts.items(): extra_states = pm_extra_states[k] module_prefix = '.'.join(k) - opt_state_dicts_for_merge = [{'state': {}} for _ in range(len(state_dicts_for_merge))] \ - if opt_state_dicts is None else opt_state_dicts[module_prefix] - - merge_partial_states_state_dicts = [] - merge_partial_states_zero_idx_maps = [] - for m, opt, extra in zip(state_dicts_for_merge, opt_state_dicts_for_merge, extra_states): - merge_partial_states_state_dicts.append((m, opt, extra.dist_param_map, extra.param_area_map)) - merge_partial_states_zero_idx_maps.append((extra.model_idx2opt_idx, extra.opt_idx2ranks)) + opt_state_dicts_for_merge = None if opt_state_dicts is None else opt_state_dicts[module_prefix] + + merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks) for e in extra_states] if not extra_states[0].compute_config.use_zero: # all ranks should have the same use_zero merge_partial_states_zero_idx_maps = None - merged_state_dict, merged_opt_state_dict = ParallelModule.merge_partial_states( - merge_partial_states_state_dicts, - merge_partial_states_zero_idx_maps + merged_state_dict, merged_opt_state_dict = ParallelModule.merge_state_dicts( + [e.param_area_map for e in extra_states], + state_dicts_for_merge, + opt_state_dicts_for_merge, + merge_partial_states_zero_idx_maps, ) # merge back module state dict @@ -1398,12 +1442,12 @@ def merge_state_dicts( @torch.no_grad() def load_merged_state_dicts( - module: torch.nn.Module, - module_state_dict: Dict[str, Any], - optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, - optimizer_state_dict: Optional[Dict[str, Any]] = None, - *, - device: Union[str, torch.device] = None + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + device: Union[str, torch.device] = None ): """ Load the merged state dicts to the module, and optionally the optimizer to a specified device. @@ -1753,3 +1797,235 @@ def _broadcast_gen_files( # wait for all nodes to finish torch.distributed.barrier() + + +@torch.no_grad() +def deduped_state_dict( + module: torch.nn.Module, + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, +) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]: + """ + Return the state dict only for the ranks that is necessary. + For details, see `ComputeConfig.optimizer_dedup_group_size` + and `ComputeConfig.module_dedup_group_size`. + + Args: + module (torch.nn.Module): the module to get state dict + optimizer (Optional[Union[torch.optim.Optimizer, ParallelOptimizer]]): the optimizer to get state dict + + Returns: + Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]: the deduped state dict for the module and optimizer + """ + + cur_rank = torch.distributed.get_rank() + module_state_dict, opt_state_dict = None, None + parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} + + # The reason we use `Module.state_dict` on the whole to get the complete state dict + # instead of call `Module.state_dict` on each submodule + # is to make sure the hooks to state_dict are called. + module_state_dict = module.state_dict() + for key in list(module_state_dict.keys()): + if key.endswith(ParallelModule.EXTRA_STATE_KEY): # never remove extra state + continue + prefix = '.'.join(key.split('.')[:-1]) # remove the last part of the key + dedup_group_size = parallel_modules[prefix].module_dedup_group_size \ + if prefix in parallel_modules else 1 + # only keep the first `dedup_group_size` ranks' state + if cur_rank >= dedup_group_size: + module_state_dict.pop(key, None) + + if optimizer is not None: + opt_state_dict = optimizer.state_dict() + + # get the locations of non-parallel module parameters + # by removing the parallel module locations + non_parallel_module_locs: Set[int] = set(opt_state_dict['param_groups'][0]['params']) + for pm_loc in optimizer._extra_state.parallel_module_locs.values(): + non_parallel_module_locs.difference_update(range(pm_loc.offset, pm_loc.offset + pm_loc.count)) + + # only keep non-parallel module parameters in rank 0 + if cur_rank > 0: + for idx in non_parallel_module_locs: + opt_state_dict['state'].pop(idx, None) + + for pm_prefix, pm_loc in optimizer._extra_state.parallel_module_locs.items(): + dedup_group_size = optimizer._extra_state.parallel_module_configs[pm_prefix].optimizer_dedup_group_size + # only keep the first `dedup_group_size` ranks' state + if cur_rank >= dedup_group_size: + for idx in range(pm_loc.offset, pm_loc.offset + pm_loc.count): + opt_state_dict['state'].pop(idx, None) + + return module_state_dict, opt_state_dict + + +@torch.no_grad() +def load_deduped_state_dict( + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + device: Union[str, torch.device] = None +) -> None: + """ + Load the deduped state dicts to the module and optionally the optimizer to a specified device. + + Args: + module (torch.nn.Module): the module to be loaded + module_state_dict (Dict[str, Any]): the deduped model state dict + optimizer (Optional[Union[torch.optim.Optimizer, ParallelOptimizer]]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the deduped optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. + Returns: + None + """ + device = device or torch.cuda.current_device() + + # only load partial state for all ranks except rank 0 + module.load_state_dict(module_state_dict, strict=False) + module.to(device) + torch.distributed.barrier() + + # broadcast weights + broadcast_weights(module) + + if optimizer is not None: + if 'adam' not in optimizer._extra_state.name.lower(): + raise ValueError("Only Adam-like optimizers are supported.") + if optimizer_state_dict is None: + raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") + + for idx, state in optimizer_state_dict['state'].items(): + for key, value in state.items(): + if isinstance(value, torch.Tensor): + optimizer_state_dict['state'][idx][key] = value.to(device) + + # get the locations of non-parallel module parameters + # by removing the parallel module locations + non_parallel_module_locs: Set[int] = set(optimizer_state_dict['param_groups'][0]['params']) + # a list of tuple to track how to broadcast states + # Tuple: + # 0: a list of state idx + # 1: the dedup group size for the state idx's + opt_broadcast_groups: List[Tuple[List[int], int]] = [] + for prefix, pm_loc in optimizer._extra_state.parallel_module_locs.items(): + state_range = list(range(pm_loc.offset, pm_loc.offset + pm_loc.count)) + opt_broadcast_groups.append((state_range, optimizer._extra_state.parallel_module_configs[prefix].optimizer_dedup_group_size)) + non_parallel_module_locs.difference_update(state_range) + # append also works + # but insert to 0 feels better + # the dedup size for non-parallel module is 1 + opt_broadcast_groups.insert(0, (list(non_parallel_module_locs), 1)) + # TODO: what if opt_broadcast_groups are different in different ranks? + # Will it happend in pipeline parallelism? + for bg in opt_broadcast_groups: + _broadcast_opt_state(optimizer_state_dict, *bg) + optimizer.load_state_dict(optimizer_state_dict) + + torch.distributed.barrier() + + +def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_group_size: int): + rank = torch.distributed.get_rank() + broadcast_group = setup_stride_broadcast_group(dedup_group_size) + src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks + + logging.info(f'Rank-{rank} is broadcasting states to ranks {curr_parallel_group_ranks}, broadcast root: {src_rank}...') + + # broadcast param groups and state keys/shapes/dtypes via broadcast_object_list + if rank == src_rank: + state_info = {} + for idx in state_indexes: + state_info[idx] = {key: (value.shape, value.dtype) for key, value in optimizer_state_dict['state'][idx].items()} + sent = [state_info] + else: + sent = [None] + torch.distributed.broadcast_object_list( + sent, + src=src_rank, + group=curr_parallel_group, + ) + if rank != src_rank: + for k, v in sent[0].items(): + optimizer_state_dict['state'][k] = { + key: torch.zeros(value[0], dtype=value[1], device=torch.cuda.current_device()) + for key, value in v.items() + } + + # broadcast step + # step is too small, so we can just broadcast all of them all together + if rank == src_rank: + step_stack = torch.stack( + [optimizer_state_dict['state'][k]['step'] for k in state_indexes] + ) + else: + step_stack = torch.zeros( + len(state_indexes), + dtype=optimizer_state_dict['state'][k]['step'].dtype, + device=torch.cuda.current_device() + ) + torch.distributed.broadcast(step_stack, src=src_rank, group=curr_parallel_group) + if rank != src_rank: + for k, v in zip(state_indexes, step_stack): + optimizer_state_dict['state'][k]['step'].copy_(v) + + # broadcast other states + # TODO: can be slow? + for k in state_indexes: + keys = sorted(optimizer_state_dict['state'][k].keys()) + assert set(keys) == {'step', 'exp_avg', 'exp_avg_sq'} + keys.remove('step') # we have done step in previous. + for key in keys: + value = optimizer_state_dict['state'][k][key] + torch.distributed.broadcast(value.data, src=src_rank, group=curr_parallel_group) + + torch.distributed.barrier() + + +def broadcast_weights(module: torch.nn.Module, stride_size: Optional[int] = None): + """ + Broadcast the weights of the module from the ranks in dedup group to all ranks. + + When you load the deduped state dict to broadcast the weights, you don't need to specify the `stride_size`. + + Args: + module (torch.nn.Module): the module to be broadcasted + stride_size (Optional[int]): the stride size for broadcast. + If it is None, will use the dedup group size of each submodule. + Returns: + None + """ + parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} + + for prefix, m in module.named_modules(): + if stride_size is not None: + stride = stride_size + elif prefix not in parallel_modules: + stride = 1 + else: + stride = parallel_modules[prefix].module_dedup_group_size + _broadcast_weights(m, stride) + + +def _broadcast_weights(module: torch.nn.Module, stride_size: int): + broadcast_group = setup_stride_broadcast_group(stride_size) + rank = torch.distributed.get_rank() + src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks + logging.info(f'Rank-{rank} is broadcasting weight to ranks {curr_parallel_group_ranks}, broadcast root: {src_rank}...') + + # we have a special optimization for ParallelModule + params = module.parameters_for_broadcast() if isinstance(module, ParallelModule) else module._parameters.values() + logging.info(f'Inplace broadcasting {len(params)} parameters...') + for i, param in enumerate(params): + torch.distributed.broadcast(param.data, src=src_rank, group=curr_parallel_group) + logging.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') + + # NOTE: may batch buffers for efficient broadcast, + # current implementation is the most memory efficient way. + logging.info(f'Inplace broadcasting {len(module._buffers)} buffers...') + for _, buffer in module._buffers.items(): + torch.distributed.broadcast(buffer.data, src=src_rank, group=curr_parallel_group) + + torch.distributed.barrier() diff --git a/cube/runtime/module.py b/cube/runtime/module.py index c076aa0b..0c47dffd 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -243,8 +243,10 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref }, filename) @staticmethod - def merge_model_state_dicts(state_dicts: List[Dict], - fullmaps: List[Dict[str, AttrMeta]]): + def merge_model_state_dicts( + state_dicts: List[Dict], + fullmaps: List[Dict[str, AttrMeta]] + ): """Merge model states from multiple shard into a single-model state. Note: @@ -309,37 +311,57 @@ def merge_partial_states(state_dicts: List, * optim_state_dicts (Optional[List[Dict]]): per-rank optimizer state dict from optimizer.state_dict() * dist_param_map: deprecated, will be removed in the future. * fullmaps (List[Dict[str, AttrMeta]]): per-rank fullmap - zero_idx_maps (Optional[List[Dict]]) + zero_idx_maps (Optional[List[Dict]]): zero information for the model, `None` if zero is not enabled Returns: Dict[str, torch.Tensor]: Full model state dict - Dict[str, Dict[str, torch.Tensor]]: Full optimizer state dict + Dict[str, Any]: Full optimizer state dict """ - model_state_dicts = [states[0] for states in state_dicts] - optim_state_dicts = [states[1] for states in state_dicts] - fullmaps: List[Dict[str, AttrMeta]] = [states[-1] for states in state_dicts] + # the filtering below is to be compatible with fairseq + # which will set some model_state_dicts/optim_state_dicts to None for deduplication + return CubeModule.merge_state_dicts( + [state_dict[-1] for state_dict in state_dicts], + [state_dict[0] for state_dict in state_dicts if state_dict[0] is not None], + [state_dict[1] for state_dict in state_dicts if state_dict[1] is not None], + zero_idx_maps + ) - if len(model_state_dicts) != len(fullmaps): - raise ValueError("Expected model state dicts to have the same length as fullmaps") - if optim_state_dicts is not None: - if len(optim_state_dicts) != len(fullmaps): - raise ValueError("Expected optimizer state dicts to have the same length as fullmaps") + @staticmethod + def merge_state_dicts( + fullmaps: List[Dict[str, AttrMeta]], + model_state_dicts: List[Dict[str, torch.Tensor]], + optim_state_dicts: Optional[List[Dict[str, Any]]] = None, + zero_idx_maps: Optional[List[Dict]] = None + ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]: + """Merge model and optimizer states from different shard into a single-model state. + + `fullmaps` should always have the information for all ranks. + To support checkpoint deduplication, `model_state_dicts` and `optim_state_dicts` + can contains only the first `dedup_group_size` items. + + Warnings: + * This function only supports merging optimizer states of Adam-like optimizers, + in which the optimizer state is expected to contain 'state' keyword. + * Only support single parameter group, i.e., code implementations like: `torch.optim.Adam(model.parameters(), lr=0.1)` + + Args: + fullmaps (List[Dict[str, AttrMeta]]): per-rank fullmap + model_state_dicts (List[Dict[str, torch.Tensor]]): per-rank model state dict from model.state_dict() + optim_state_dicts (Optional[List[Dict]]): per-rank optimizer state dict from optimizer.state_dict() + zero_idx_maps (Optional[List[Dict]]): zero information for the model, `None` if zero is not enabled + Returns: + Dict[str, torch.Tensor]: Full model state dict + Dict[str, Any]: Full optimizer state dict + """ # gather model states - full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps) + full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps[0: len(model_state_dicts)]) + if optim_state_dicts is None: + return full_model_state_dict, None # gather optimizer states full_optim_state_dict: Dict[str, Any] = {} # param_id -> Dict[state_name, value] - plan_ngpus = -1 - # TODO: remove this flag - if 'PLAN_NGPUS' in os.environ: - plan_ngpus = int(os.environ['PLAN_NGPUS']) - assert plan_ngpus >= 1, plan_ngpus - assert plan_ngpus <= len(state_dicts), f'plan_ngpus = {plan_ngpus}, len(state_dicts) = {len(state_dicts)}' - assert len(state_dicts) % plan_ngpus == 0, f'plan_ngpus = {plan_ngpus}, len(state_dicts) = {len(state_dicts)}' - _logger.info(f'plan_ngpus = {plan_ngpus}') - # at first, merge the partitioned optimizer states due to zero to the zero-disabled format if zero_idx_maps is not None: def _check_state_size(opt_state_keys, bucket_state): @@ -392,21 +414,20 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): return opt_states opt_state_list = [] - worker_cnt = len(state_dicts) - for work_idx in (range(worker_cnt) if plan_ngpus < 0 else range(plan_ngpus)): + worker_cnt = len(optim_state_dicts) + for work_idx in range(worker_cnt): model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[work_idx] opt_state = {} for model_idx, opt_idx in model_idx2opt_idx.items(): if isinstance(opt_idx, int): # the param without reducer assert opt_idx2ranks[opt_idx] is None - # state_dicts [worker idx][opt state]['state'][param idx] - opt_state[model_idx] = state_dicts[work_idx][1]['state'][opt_idx] + opt_state[model_idx] = optim_state_dicts[work_idx]['state'][opt_idx] else: # the param in reducer bucket opt_idx, pstart, pend, pshape = opt_idx ranks, bucket_size = opt_idx2ranks[opt_idx] - bucket_states = [state_dicts[rank][1]['state'][opt_idx] for rank in ranks] + bucket_states = [optim_state_dicts[rank]['state'][opt_idx] for rank in ranks] opt_state[model_idx] = _retrieve_param_opt_state( bucket_states, pstart, @@ -414,11 +435,11 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): pshape, bucket_size) opt_state_list.append(opt_state) - assert len(state_dicts[work_idx][1]['param_groups']) == 1, 'only support param_groups to be one group' + assert len(optim_state_dicts[work_idx]['param_groups']) == 1, 'only support param_groups to be one group' # assign opt_state to state_dicts, cannot be assigned in the above loop opt_state_len = len(opt_state_list[0]) - for work_idx in (range(worker_cnt) if plan_ngpus < 0 else range(plan_ngpus)): + for work_idx in range(worker_cnt): optim_state_dicts[work_idx]['state'] = opt_state_list[work_idx] optim_state_dicts[work_idx]['param_groups'][0]['params'] = sorted(opt_state_list[work_idx].keys()) assert len(opt_state_list[work_idx]) == opt_state_len @@ -536,9 +557,10 @@ def __post_init__(self): if isinstance(self.compute_config, dict): from cube.parallel import ComputeConfig self.compute_config = ComputeConfig(**self.compute_config) - for k in self.param_area_map: - if isinstance(self.param_area_map[k], dict): - self.param_area_map[k] = AttrMeta(**self.param_area_map[k]) + self.param_area_map = { + k: AttrMeta(**v) if isinstance(v, dict) else v + for k, v in self.param_area_map.items() + } @dataclass @@ -630,7 +652,7 @@ def get_dist_param_map(self) -> Dict[str, str]: def get_compute_config(self) -> 'ComputeConfig': return self._compute_config - def get_rank(self): + def get_rank(self) -> int: return self.rank # rank is a class varible defined in gencode def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]: @@ -750,7 +772,7 @@ def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: assert cf.zero_ngroups == 1 return rank_idx, reducer.ranks - def _post_state_dict_hook(self, state_dict, prefix, local_metadata) -> None: + def _add_extra_state(self, state_dict, prefix) -> None: state_dict[f'{prefix}{self.EXTRA_STATE_KEY}'] = asdict( ExtraState( rank=self.get_rank(), @@ -763,9 +785,29 @@ def _post_state_dict_hook(self, state_dict, prefix, local_metadata) -> None: ) ) - def _pre_load_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None: + def _remove_extra_state(self, state_dict, prefix) -> None: state_dict.pop(f'{prefix}{self.EXTRA_STATE_KEY}', None) + def _post_state_dict_hook(self, state_dict, prefix, local_metadata) -> None: + self._add_extra_state(state_dict, prefix) + + def _pre_load_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None: + self._remove_extra_state(state_dict, prefix) + + @property + def module_dedup_group_size(self) -> int: + """ + Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. + """ + return self.get_compute_config().module_dedup_group_size + + @property + def optimizer_dedup_group_size(self) -> int: + """ + Get the size of the deduplication group of the optimizer state dict. + """ + return self.get_compute_config().optimizer_dedup_group_size + def _list_fullmodel_files(self) -> List[Path]: legacy_fullmodel_path = self.module_dir / FxModuleParser.ATTR_CONTENT_FILE_STEM files = [] @@ -830,41 +872,3 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s raise RuntimeError(erro_msg) else: _logger.warning(erro_msg) - - def broadcast_weights(self): - """ - Broadcast weights (including parameters and buffers) across scale units. - The source ranks is the ranks in first scale unit. - The weights in the ranks in the rest scale units will be replace inplace. - """ - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', default=1)) - plan_ngpus = self.get_compute_config().plan_ngpus - - if local_world_size < plan_ngpus: - raise RuntimeError(f'LOCAL_WORLD_SIZE {local_world_size} is less than plan_ngpus {self.get_compute_config().plan_ngpus}. Cannot broadcast weights.') - - for i in range(plan_ngpus): - ranks = list(range(i, world_size, plan_ngpus)) - DeviceGroup().get_group(ranks) - - curr_parallel_group_ranks = list(range(rank % plan_ngpus, world_size, plan_ngpus)) - curr_parallel_group = DeviceGroup().get_group(curr_parallel_group_ranks) - src_rank = min(curr_parallel_group_ranks) - logging.info(f'Rank-{rank} is broadcasting weight to ranks {curr_parallel_group_ranks}, broadcast root: {src_rank}...') - - # NOTE: please make sure the above checkpoint load is from local checkpoint file, - # otherwise, the following broadcast may time out due to slow checkpoint file read. - # Broadcast parameters and buffers across scale units - params = self.parameters_for_broadcast() - logging.info(f'Inplace broadcasting {len(params)} parameters...') - for i, param in enumerate(params): - torch.distributed.broadcast(param.data, src=src_rank, group=curr_parallel_group) - logging.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') - - # NOTE: may batch buffers for efficient broadcast, - # current implementation is the most memory efficient way. - logging.info(f'Inplace broadcasting {len(self._buffers)} buffers...') - for _, buffer in self._buffers.items(): - torch.distributed.broadcast(buffer.data, src=src_rank, group=curr_parallel_group) diff --git a/cube/utils.py b/cube/utils.py index 697a2368..3f2a94b6 100644 --- a/cube/utils.py +++ b/cube/utils.py @@ -1,9 +1,10 @@ import os -from typing import Optional, Tuple, Callable, List, Set, Any +from typing import Optional, Tuple, Callable, List, Set, Any, Iterable import logging from pathlib import Path import sys from collections import defaultdict +from dataclasses import dataclass import cube from cube.runtime.device import DeviceGroup @@ -104,6 +105,47 @@ def get_shared_params(model: torch.nn.Module) -> List[List[str]]: return [list(names) for _, names in paramid2name.items() if len(names) > 1] +@dataclass +class BroadcastGroup: + src_rank: int # the source rank in the group which the current rank belongs to + ranks: List[int] # the ranks in the group which the current rank belongs to + group: torch.distributed.ProcessGroup + + +def setup_stride_broadcast_group(stride_size: int) -> BroadcastGroup: + """ + Setup the broadcast group for the given stride size. + + For example, assume stride size is 4, then + we will create 4 broadcasting groups: + [0, 4, 8, ...], + [1, 5, 9, ...], + [2, 6, 10, ...], + [3, 7, 11, ...] + the broadcast will happen in above groups, the sending rank is the first rank in the group. + + Args: + stride_size (int): the stride size. + Returns: + BroadcastGroup: the source rank and the broadcast group. + """ + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + for i in range(stride_size): + ranks = list(range(i, world_size, stride_size)) + DeviceGroup().get_group(ranks) + + curr_parallel_group_ranks = list(range(rank % stride_size, world_size, stride_size)) + curr_parallel_group = DeviceGroup().get_group(curr_parallel_group_ranks) + src_rank = min(curr_parallel_group_ranks) + + return BroadcastGroup( + src_rank=src_rank, + ranks=curr_parallel_group_ranks, + group=curr_parallel_group + ) + + class accum_mode: """Make cube execution in gradient accumulation mode. diff --git a/docs/parallel_module.md b/docs/parallel_module.md index 4bbc6e0f..acf06461 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -128,6 +128,7 @@ We can categorize the fields into 4 categories: 1. if `rank1 // plan_gpus == rank2 // plan_ngpus`, then they are in the same unit. 2. If `rank1 % plan_ngpus == rank2 % plan_ngpus`, then the portion of model hold on both gpus are exactly the same. - runtime_ngpus: the number of gpus to be used in runtime. It should be a multiple of `plan_ngpus`, which means we have `runtime_ngpus // plan_ngpus` units in runtime, and the data parallelism is `runtime_ngpus // plan_ngpus`. + Please note all modules must have the same `plan_ngpus` and `runtime_ngpus`. 3. Code generation feature configuration - use_zero: whether to use zero. If it is true, the generated code will use zero1 to do distributed training. - zero_ngroups: the number of groups to be used in zero. @@ -187,13 +188,13 @@ We call it a `match` when the `ComputeConfig` is the same with the previous run. 1. MATCH: Reuse if match, error if not match, generate if no previous gerenated code exists. 2. OVERRIDE: Nothing will be reused. Everything will be regenerated. -3. MOO: MOO is short for 'match or override'. It will reuse if match, generate if not match or no previous gerenated code exists. +3. MOO: MOO is short for 'match or override'. It will reuse if match, generate if not match or no previous generated code exists. 4. GRAPH: Reuse graph only if match, generate otherwise. ### BroadcastGenFilesStrategy The broadcast strategy for new generated files. -Please note we never broadcast reused files. +Please note we never broadcast reused files (i.e., specified by `ReuseType`.). ```python class BroadcastGenFilesStrategy(Enum): @@ -203,25 +204,55 @@ class BroadcastGenFilesStrategy(Enum): CODE = 'code' ``` -1. None: nothing will be broadcasted. +1. `None`: nothing will be broadcasted. You need to do it by yourself or the generated files are save in a shared directory (like azure blob). -2. ALL: broadcast all the generated files to all nodes. +2. `ALL`: broadcast all the generated files to all nodes. This is useful when you want to run the same code on all nodes. please note the init weight files can be huge. -3. NO_WEIGHTS: broadcast all except init weights. +3. `NO_WEIGHTS`: broadcast all except init weights. Without weights, you can only construct the parallel module with `init_params=False`. You can then - - Load the weights from a checkpoint file with `module.load_state_dict` or `load_merged_state_dict` - - Or you can use `broadcast_weights_inplace` to get the weights from the workers in node0. + - Load the weights from a checkpoint file with `module.load_state_dict`, `load_merged_state_dict` + or `load_deduped_state_dict` + - Or you can use `broadcast_weights` to get the weights from the workers in node0. (local world size should be bigger than plan_ngpus) -4. CODE: broadcast the new generated code only +4. `CODE`: broadcast the new generated code only It's your responsibility to make sure other necessary files are available on all nodes. +Here are some guidelines to choose the strategy: -### Module Conversion +1. When restarting a training and there is a successful previous run: As we have a previous run, the compiling process has been done before. So there will be no new generated files and no broadcast will happen no matter what this option is. Please be sure the reuse flag of `parallelize` is `MATCH`, so we can make sure the generated code is the same with the previous run. + +2. When training a model from scratch. If there is only one node, `none` is good enough. +If there are multiple nodes, here are some strategies: + +a. If use `none`, the user should run `parallelize(..., load_module=False, ..)`, and then copy all files to all nodes manually, so all nodes have the same files. Then the user load the module by running `parallelize(..., load_module=True, ..)`. + +b. if they are using a NAS-like device to save generated files, and the upload/download speed is fast in the cluster, they can also use `none`, and just run `parallelize(..., load_module=True, ..)` to do the training. + +c. If use `all`, then user can just run `parallelize(..., load_module=True, ..)` safely. (remember to set `nccl` communication timeout to a very big value to tolerate the duration of this `nccl` broadcast) + +d. If use `no_weights`. then user can run `parallelize(..., load_module=True, init_module_params=rank Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]: +``` + +#### `load_deduped_state_dict` + +```python +def load_deduped_state_dict( + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + device: Union[str, torch.device] = None +) -> None: +``` + ### Dataset We use the same dataset/dataloader as pytorch. For example, you can use `torch.utils.data.DistributedSampler` to create a distributed sampler. -`ParallelModule`s running in the same unit should use the same input, and will get the same output. `ParallelModule`s runing in different units should use different input and will get different output (similar to data parallelism). The gradients of all parameters will be synced across all the devices automatically. +`ParallelModule`s running in the same unit should use the same input, and will get the same output. `ParallelModule`s running in different units should use different input and will get different output (similar to data parallelism). The gradients of all parameters will be synced across all the devices automatically. Take `torch.utils.data.DistributedSampler` for example, you can create the sampler like this: ```python @@ -427,4 +482,5 @@ def create_distributed_sampler(dataset): ``` ## TODOs + 1. Pipeline parallelism is not supported yet. diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 4819e715..d681162c 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -2,7 +2,7 @@ import math import random import shutil -from typing import List, Optional +from typing import Any, Dict, List, Optional import contextlib import torch @@ -158,3 +158,19 @@ def clear_dir_on_rank0(tempdir): torch.distributed.barrier() if torch.distributed.get_rank() == 0 and tempdir.exists(): shutil.rmtree(tempdir) + + +def assert_equal(a: Any, b: Any): + assert type(a) == type(b) + if isinstance(a, torch.Tensor): + assert torch.equal(a.cpu(), b.cpu()) + elif isinstance(a, dict): + assert len(a) == len(b) + for k in a.keys(): + assert_equal(a[k], b[k]) + elif isinstance(a, (list, tuple)): + assert len(a) == len(b) + for i in range(len(a)): + assert_equal(a[i], b[i]) + else: + assert a == b diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index 4ea30cc5..61f7e371 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -5,7 +5,7 @@ import pytest import torch -from cube.parallel import ComputeConfig, parallelize +from cube.parallel import ComputeConfig, parallelize, broadcast_weights from .common import PASRandomSPMD, init_distributed from ..launch_torchrun import launch_torchrun @@ -91,7 +91,7 @@ def _gpu_worker(): for n, pa in m.named_parameters(): if n.startswith('linear_weight'): assert not torch.equal(pa.data, torch.ones_like(pa.data)) - m.broadcast_weights() + broadcast_weights(m) # check if broadcast_weights works for n, pa in m.named_parameters(): if n.startswith('linear_weight'): @@ -113,7 +113,7 @@ def _gpu_worker(): m = p(tempdir, 'all', '_6', load_module=False, reuse='graph') if torch.distributed.get_rank() != 0: # only python files are broadcasted - assert list(f.name for f in Path(tempdir).glob('**/*') if f.is_file()) == ['gencode0.py', 'gencode1.py'] + assert set(f.name for f in Path(tempdir).glob('**/*') if f.is_file()) == set(['gencode0.py', 'gencode1.py']) # case 6.2: everything should be broadcasted, including weights # so the load_module will succeed. diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py new file mode 100644 index 00000000..f0c965b7 --- /dev/null +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -0,0 +1,198 @@ +import tempfile +from pathlib import Path +import pytest +from typing import Dict, Tuple, List, Any + +import torch +from torch import nn + +from cube.parallel import ComputeConfig, parallelize, build_optimizer, \ + merge_state_dicts, load_merged_state_dicts, \ + deduped_state_dict, load_deduped_state_dict +from cube.runtime.module import ParallelModule + +from .common import PASRandomSPMD, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, assert_equal +from ..launch_torchrun import launch_torchrun + + +class FcRelu(nn.Module): + def __init__(self, in_features, bias=True): + super().__init__() + init_random() + self.fc1 = CubeLinear(in_features, in_features, bias=bias) + self.relu1 = nn.ReLU() + self.fc2 = CubeLinear(in_features, in_features, bias=bias) + self.relu2 = nn.ReLU() + + + def forward(self, x): + return self.relu2(self.fc2(self.relu1(self.fc1(x)))) + + +class FcRelu4(FcRelu): + def __init__(self): + super().__init__(4, 4) + + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + + +def _create_cube_module(pas, compute_config1, compute_config2, cube_savedir): + init_random() + class ParallelModule0(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.fc_relu1 = _to_cube_model( + FcRelu4(), pas, + compute_config1, cube_savedir, f'fc_relu1' + ) + self.linear2 = nn.Linear(4, 4) + self.fc_relu2 = _to_cube_model( + FcRelu4(), pas, + compute_config2, cube_savedir, f'fc_relu2' + ) + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.linear1(x) + x = self.fc_relu1(x) + x = self.linear2(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + + init_random() + return ParallelModule0().cuda() + + +DATA_SIZE = 256 +CKPT_FILE_NAME_TEMPLATE = '{}.pth' + + +def _train(model: torch.nn.Module, ckpt_dir): + CKPT_FILE_NAME = CKPT_FILE_NAME_TEMPLATE.format(torch.distributed.get_rank()) + DATA = [] + for _ in range(DATA_SIZE): + DATA.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.randn((2, 1), device='cuda', dtype=torch.float32), + )) + loss_fn = nn.BCELoss() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + for i, (x, y) in enumerate(DATA): + model.train() + optimizer.zero_grad() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + optimizer.step() + deduped_model_state_dict, deduped_opt_state_dict = deduped_state_dict(model, optimizer) + torch.save({ + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'model-dedup': deduped_model_state_dict, + 'optimizer-dedup': deduped_opt_state_dict + }, ckpt_dir / CKPT_FILE_NAME) + + +def _check_deduped(model: torch.nn.Module, ckpt_dir): + rank = torch.distributed.get_rank() + ckpt_files = [ + ckpt_dir / CKPT_FILE_NAME_TEMPLATE.format(i) + for i in range(torch.distributed.get_world_size()) + ] + ckpt_state_dicts = [torch.load(f) for f in ckpt_files] + model_state_dicts = [ckpt['model'] for ckpt in ckpt_state_dicts] + optimizer_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_state_dicts] + dedupped_model_state_dicts = [ckpt['model-dedup'] for ckpt in ckpt_state_dicts] + dedupped_optimizer_state_dicts = [ckpt['optimizer-dedup'] for ckpt in ckpt_state_dicts] + + parallel_modules = [m for m in model.modules() if isinstance(m, ParallelModule)] + assert len(parallel_modules) == 2 + + module_dedup_group_size = [m.module_dedup_group_size for m in parallel_modules] + opt_dedup_group_size = [m.optimizer_dedup_group_size for m in parallel_modules] + assert all(s1 >= s2 for s1, s2 in zip(opt_dedup_group_size, module_dedup_group_size)) + assert all(s1 % s2 == 0 for s1, s2 in zip(opt_dedup_group_size, module_dedup_group_size)) + + # check deduped state dicts are correct + for i, ( + model_state_dict, + optimizer_state_dict, + dedupped_model_state_dict, + dedupped_optimizer_state_dict + ) in enumerate(zip(model_state_dicts, optimizer_state_dicts, dedupped_model_state_dicts, dedupped_optimizer_state_dicts)): + if i == 0: + assert_equal(model_state_dict, dedupped_model_state_dict) + elif i >= max(module_dedup_group_size): + # only EXTRA_STATEs are kept + assert len(dedupped_model_state_dict) == len(parallel_modules) + assert all(k.endswith(ParallelModule.EXTRA_STATE_KEY) for k in dedupped_model_state_dict.keys()) + else: + assert len(parallel_modules) < len(dedupped_model_state_dict) < len(model_state_dict) + for k, v in dedupped_model_state_dict.items(): + assert_equal(v, model_state_dict[k]) + + # we keep param_groups in all ranks. + assert_equal(dedupped_optimizer_state_dict['param_groups'], optimizer_state_dict['param_groups']) + if i == 0: + assert_equal(optimizer_state_dict, dedupped_optimizer_state_dict) + elif i >= max(opt_dedup_group_size): + # only EXTRA_STATEs and param_groups are kept + assert not dedupped_optimizer_state_dict['state'] # should have empty state + else: + assert 0 < len(dedupped_optimizer_state_dict['state']) < len(optimizer_state_dict['state']) + for k, v in dedupped_optimizer_state_dict['state'].items(): + assert_equal(v, optimizer_state_dict['state'][k]) + + # check deduped state dicts can be merged and output exactly the same state dict + merged_model_state_dicts, merged_optimizer_state_dict = \ + merge_state_dicts(model_state_dicts, optimizer_state_dicts) + merged_model_state_dicts_dedup, merged_optimizer_state_dict_dedup = \ + merge_state_dicts(dedupped_model_state_dicts, dedupped_optimizer_state_dicts) + + assert_equal(merged_model_state_dicts, merged_model_state_dicts_dedup) + assert_equal(merged_optimizer_state_dict, merged_optimizer_state_dict_dedup) + + # check deduped state dicts can be loaded to the model + # which should output the same state dict as the original model + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + load_deduped_state_dict(model, dedupped_model_state_dicts[rank], + optimizer, dedupped_optimizer_state_dicts[rank] + ) + assert_equal(model.state_dict(), model_state_dicts[rank]) + assert_equal(optimizer.state_dict(), optimizer_state_dicts[rank]) + + +def _gpu_worker(pas, cc1, cc2): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_compact') as tempdir: + _train(_create_cube_module(pas, cc1, cc2, tempdir), tempdir) + torch.distributed.barrier() + _check_deduped( + _create_cube_module(pas, cc1, cc2, tempdir), + tempdir + ) + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('use_zero', [True, False]) +def test_checkpoint_compact(use_zero): + cc1 = ComputeConfig(1, 4, use_zero=use_zero, zero_ngroups=2 if use_zero else 1) + cc2 = ComputeConfig(1, 4, use_zero=use_zero, zero_ngroups=4 if use_zero else 1) + launch_torchrun(4, _gpu_worker, PASRandomSPMD, cc1, cc2) + + # mixed zero and non-zero + cc1 = ComputeConfig(2, 4, use_zero=not use_zero, zero_ngroups=2 if not use_zero else 1) + cc2 = ComputeConfig(2, 4, use_zero=use_zero, zero_ngroups=1) + launch_torchrun(4, _gpu_worker, PASRandomSPMD, cc1, cc2) From 2ee9f9078513a2b92229a82316a59613c8e324b8 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 28 Mar 2024 10:38:11 +0000 Subject: [PATCH 1608/1892] Merged PR 2080: limited revert wrapped functions during tracing & support torch 2.2/2.3 Because the `revert` cost a lot of time, only revert during run customized function now --- .../concrete_trace_utils/concrete_tracer.py | 162 ++++++++++-------- .../concrete_trace_utils/operator_patcher.py | 2 +- requirements.txt | 2 +- 3 files changed, 94 insertions(+), 72 deletions(-) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index aca1121c..7c540698 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -331,42 +331,37 @@ def __init__(self, cpu_offload = False, record_frames = False): self.record_frames = record_frames self.patcher = FunctionPatcher() + # When we concrete executing some functions, + # we need revert all the patched function to the unpatched version to ensure the correctness of some underlying code. + # For most functions, disable_call is sufficient, but it is necessary when executing, for example, a triton function. + # Here we put all user wrapped function into the set, and unpatch all the patched functions when executing the user function. + self.need_revert_functions = set() + self.need_revert_wrapped_functions = set() + + self.temp_call_origin = False + + def add_need_revert_function(self, func, wrapped_func): + self.need_revert_functions.add(func) + self.need_revert_wrapped_functions.add(wrapped_func) + + def need_revert(self, func): + return func in self.need_revert_functions or func in self.need_revert_wrapped_functions + @contextmanager - def do_temp_disable(self, call=False, attr=False, agfunc_apply=False): - assert call | attr | agfunc_apply - # to pass pyright check - temp_disable_call, temp_disable_attr, temp_disable_agfunc_apply = False, False, False - if call: - self.temp_disable_call_level += 1 - temp_disable_call = self.temp_disable_call - self.temp_disable_call = True - if attr: - self.temp_disable_attr_level += 1 - temp_disable_attr = self.temp_disable_attr - self.temp_disable_attr = True - if agfunc_apply: - self.temp_disable_agfunc_apply_level += 1 - temp_disable_agfunc_apply = self.temp_disable_agfunc_apply - self.temp_disable_agfunc_apply = True + def do_temp_call_origin(self): + temp_call_origin = self.temp_call_origin + self.temp_call_origin = True try: yield finally: - if agfunc_apply: - self.temp_disable_agfunc_apply = temp_disable_agfunc_apply - self.temp_disable_agfunc_apply_level -= 1 - if attr: - self.temp_disable_attr = temp_disable_attr - self.temp_disable_attr_level -= 1 - if call: - self.temp_disable_call = temp_disable_call - self.temp_disable_call_level -= 1 + self.temp_call_origin = temp_call_origin @compatibility(is_backward_compatible=True) def fetch_attr(self, target: str) -> Any: """ to get the attr in self.root. only for execution of 'call_module' nodes. """ - with self.do_temp_disable(attr=True): + with self.do_temp_call_origin(): target_atoms = target.split('.') attr_itr = self.root for i, atom in _orig_enumerate(target_atoms): @@ -400,10 +395,14 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] assert isinstance(target, str) mod = self.fetch_attr(target) if self.cpu_offload: - mod.cuda() # how it works in ddp? - result = mod(*args, **kwargs) - if self.cpu_offload: - mod.cpu() + try: + mod.cuda() + result = mod(*args, **kwargs) + except: + mod.cpu() + raise + else: + result = mod(*args, **kwargs) elif kind == 'get_attr': assert isinstance(target, str) return self.fetch_attr(target) @@ -411,21 +410,27 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] raise RuntimeError() return result - with self.do_temp_disable(call=True): - if self.cpu_offload: - args, kwargs = tree_to_cuda(args), tree_to_cuda(kwargs) + if self.cpu_offload: + args, kwargs = tree_to_cuda(args), tree_to_cuda(kwargs) + try: result = run(kind, target, args, kwargs) - + except torch.cuda.OutOfMemoryError: if self.cpu_offload: - args, kwargs, result = tree_to_cpu(args), tree_to_cpu(kwargs), tree_to_cpu(result) + _logger.warning(f"cuda out of memory, try to trace {target} on cpu.") + args, kwargs = tree_to_cpu(args), tree_to_cpu(kwargs) + result = run(kind, target, args, kwargs) + else: + raise + + if self.cpu_offload: + args, kwargs, result = tree_to_cpu(args), tree_to_cpu(kwargs), tree_to_cpu(result) - unexpected_types = types_other_than(result, (*base_types, type(None), torch.Tensor)) - if not contains_types(result, (torch.Tensor,)) and unexpected_types: - _logger.warning(f"result of target {target} contains unexpected types {unexpected_types}, which is not a common behavior.") - torch.cuda.empty_cache() + unexpected_types = types_other_than(result, (*base_types, type(None), torch.Tensor)) + if not contains_types(result, (torch.Tensor,)) and unexpected_types: + _logger.warning(f"result of target {target} contains unexpected types {unexpected_types}, which is not a common behavior.") + torch.cuda.empty_cache() - self.temp_disable_call = False return result, args, kwargs @compatibility(is_backward_compatible=True) @@ -471,9 +476,14 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: use the 'run_target' to actually execute the code, and store the value in 'value' field. create the nodes for the target and the input of the target (if the target is one of call_method, call_function, call_module). """ - with self.patcher.revert(): + with self.do_temp_call_origin(): args_unwrapped, kwargs_unwrapped = unwrap_nested_proxy(args), unwrap_nested_proxy(kwargs) - value_unwrapped, args_run, kwargs_run = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) + + if self.need_revert(target): + with self.patcher.revert(): + value_unwrapped, args_run, kwargs_run = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) + else: + value_unwrapped, args_run, kwargs_run = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) # because setitem is an inplace operation and will not return the obj, so here is a workaound to record node result node_result = args_run[0] if kind == "call_function" and target == operator.setitem else value_unwrapped @@ -489,7 +499,7 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: node = self.create_node(kind, target, args_, kwargs_, name, type_expr, node_result) if self.record_frames and kind != 'placeholder': - with self.do_temp_disable(True, True, True): + with self.do_temp_call_origin(): # record code frame, include filename, line number, and function name frame_record = FrameRecord(None, None, None, None) cube_path = str(Path(importlib.util.find_spec('cube').origin).parent) + '/' # the cube path @@ -842,12 +852,12 @@ def get_middle_class(node, memo = set(), prefix = ''): @functools.wraps(_orig_module_getattribute) def module_getattribute_wrapper(mod, attr): - if self.temp_disable_call | self.temp_disable_attr: + if self.temp_call_origin: try: return _orig_module_getattribute(mod, attr) except AttributeError: return _orig_module_getattr(mod, attr) - with self.do_temp_disable(attr=True): + with self.do_temp_call_origin(): try: attr_val = _orig_module_getattribute(mod, attr) except AttributeError: @@ -881,7 +891,7 @@ def module_getattribute_wrapper(mod, attr): @functools.wraps(_orig_module_call) def module_call_wrapper(mod, *args, **kwargs): - if self.temp_disable_call: + if self.temp_call_origin: return _orig_module_call(mod, *args, **kwargs) else: # codes below corresponds to symbolic tracer's call_module @@ -1008,7 +1018,7 @@ def __hash__(self): def agfunc_apply_wrapper(clz, *args, **kwargs): if clz not in self.agfunc_dict: self.agfunc_dict[clz] = torch._C._FunctionBase.__dict__['apply'].__get__(None, clz) - if self.temp_disable_agfunc_apply or self.temp_disable_call: + if self.temp_call_origin: return self.agfunc_dict[clz](*args, **kwargs) tracers = _orig_set() def unwrap_detect_tracers(obj): @@ -1060,26 +1070,45 @@ def torch_assert_wrapper(condition, message): elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn' \ and not func.__qualname__.startswith('PyCapsule'): # method - if func.__module__.startswith('_') and func.__module__ != '__main__': - path = sys.modules[func.__module__[1:]] - else: - path = sys.modules[func.__module__] - path = getattr(path, func.__qualname__.split('.')[0]) - locations = (*locations, Location(path, func.__name__)) + # in torch >= 2.2, we found two functions under torch._C has no __module__: + # + # + if func.__module__ is not None: + if func.__module__.startswith('_') and func.__module__ != '__main__': + path = sys.modules[func.__module__[1:]] + else: + path = sys.modules[func.__module__] + path = getattr(path, func.__qualname__.split('.')[0]) + locations = (*locations, Location(path, func.__name__)) + if len(locations) == 0: + _logger.warning(f'Can not find location of {func}, skip wrap it.') + continue wrapped = _create_wrapped_leaf_method(self, func, func.__name__, wrap_info.replace_fn) else: # common function - if func.__module__.startswith('_') and func.__module__ != '__main__': - path = sys.modules[func.__module__[1:]] - else: - path = sys.modules[func.__module__] - locations = (*locations, Location(path, func.__name__)) + # in torch >= 2.2, we found two functions under torch._C has no __module__: + # + # + if func.__module__ is not None: + if func.__module__.startswith('_') and func.__module__ != '__main__': + path = sys.modules[func.__module__[1:]] + else: + path = sys.modules[func.__module__] + locations = (*locations, Location(path, func.__name__)) + if len(locations) == 0: + _logger.warning(f'Can not find location of {func}, skip wrap it.') + continue if wrap_info.is_force_trace: wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn, (self,)) else: wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn) self.wrapped_leaf[func] = (locations, wrapped) + # for the customized functions, we need to revert all the wrapped function to the original one to run it + # for the functions default wrapped, we don't revert to save time + for func in autowrap_leaf_function: + self.add_need_revert_function(func, self.wrapped_leaf.get(func, (None, None))[1]) + self.clz_wrapper_map: Dict[Any, Type] = { map_wrapper: _orig_map, enumerate_wrapper: _orig_enumerate, @@ -1157,13 +1186,6 @@ def getattr_wrapper(obj, *args): args[0] = args[0].value return _orig_getattr(obj, *args) - # for passing the tracing of leaf modules - self.temp_disable_call = False - self.temp_disable_attr = False - self.temp_disable_agfunc_apply = False - self.temp_disable_call_level = 0 - self.temp_disable_attr_level = 0 - self.temp_disable_agfunc_apply_level = 0 try: with self.patcher: # allow duplicate patches to support the case of nested calls @@ -1506,7 +1528,7 @@ def _create_wrapped_leaf_func(tracer: ConcreteTracer, func: Callable, to_func: O to_func = func @functools.wraps(func) def func_wrapper(*args, **kwargs): - if tracer.temp_disable_call: + if tracer.temp_call_origin: return func(*args, **kwargs) tracers = _orig_set(init_tracers) def unwrap_detect_tracers(obj): @@ -1525,7 +1547,7 @@ def unwrap_detect_tracers(obj): def _create_wrapped_leaf_method(tracer: ConcreteTracer, method, name: str, to_func: Optional[Callable]): @functools.wraps(method) def method_wrapper(*args, **kwargs): - if tracer.temp_disable_call: + if tracer.temp_call_origin: return method(*args, **kwargs) tracers = _orig_set() def unwrap_detect_tracers(obj): @@ -1560,7 +1582,7 @@ class clz_wrapper_clz: _fx_wrapped_ori_clz = clz def __new__(cls, *args, **kwargs): - if tracer.temp_disable_call: + if tracer.temp_call_origin: return clz(*args, **kwargs) tracers = _orig_set() def unwrap_detect_tracers(obj): @@ -1603,7 +1625,7 @@ class clz_wrapper_clz: _fx_wrapped_ori_clz = clz def __new__(cls, *args, **kwargs): - if tracer.temp_disable_call: + if tracer.temp_call_origin: return clz(*args, **kwargs) tracers = _orig_set() if _orig_len(args) != 0: @@ -1644,7 +1666,7 @@ def _create_wrapped_attr_for_middle_class(tracer: ConcreteTracer, clz, the_path_ _orig_clz_getattr = None @functools.wraps(_orig_clz_getattribute) def clz_getattr_wrapper(obj, attr): - if tracer.temp_disable_call | tracer.temp_disable_attr: + if tracer.temp_call_origin: if _orig_clz_getattr == None: return _orig_clz_getattribute(obj, attr) else: diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py index fc7747b9..d89e3f78 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -299,6 +299,6 @@ def __exit__(self, exc_type, exc_value, tb): def patch_run(func, *args, **kwargs): assert OperatorPatcherContext.ctx_tracer is not None assert OperatorPatcherContext.ctx_patcher is not None - with OperatorPatcherContext.ctx_tracer.do_temp_disable(True, True, True): + with OperatorPatcherContext.ctx_tracer.do_temp_call_origin(): new_func = OperatorPatcherContext.ctx_patcher.patch_func_or_module(func) return new_func(*args, **kwargs) diff --git a/requirements.txt b/requirements.txt index 84d96390..880014c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ numpy>=1.23.0 matplotlib more-itertools dill -torch>=1.13,<2.1 +torch>=2.0 From 2cbff22b9b72567f9927cc3af694bcd260fa1bda Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 29 Mar 2024 06:11:16 +0000 Subject: [PATCH 1609/1892] Merged PR 2087: parallel module: get function -> property / add long barrier by gloo group 1. use gloo group for long-time barrier. 2. change get_xx() to property for easy use. 3. add retry logic for torchrun test cases for listening port confliction. --- cube/graph/function/function.py | 12 ++++---- cube/parallel.py | 31 +++++++++++++------ cube/runtime/device.py | 20 ++++++++++-- cube/runtime/module.py | 35 ++++++++++++--------- tests/launch_torchrun.py | 4 +++ tests/parallel_module/test_checkpoint.py | 2 +- tests/parallel_module/test_ddp.py | 8 ++--- tests/parallel_module/test_init.py | 8 ++--- tests/parallel_module/test_submodule.py | 8 ++--- tests/parallel_module/test_wholemodule.py | 4 +-- tests/utils.py | 37 ++++++++++++++++++++++- 11 files changed, 121 insertions(+), 48 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index d4dd9af3..649caf16 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -928,7 +928,7 @@ def Norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None, signatur } if dim is None: einput = [edim + '^' for edim in einput] - anno = OpAnno.create_op_str([einput], ['1']) + anno = OpAnno.create_op_str([einput], [['1']]) return IRDimops(Norm, 'norm', signature, [anno], [input], **kwargs) else: dim = (dim,) if isinstance(dim, int) else dim @@ -959,7 +959,7 @@ def Sum(input, dim=None, keepdim=False, *, dtype=None, signature = None): eoutput = copy.copy(einput) if dim is None: einput = [edim + '+' for edim in einput] - anno = OpAnno.create_op_str([einput], ['1']) + anno = OpAnno.create_op_str([einput], [['1']]) return IRDimops(Sum, 'sum', signature, [anno], [input]) else: dim = (dim,) if isinstance(dim, int) else dim @@ -1907,7 +1907,7 @@ def Max(input, other_or_dim=None, keepdim=False, *, out=None, signature = None, other_or_dim_val = _unwrap_value(other_or_dim) if other_or_dim_val is None: edim_in = [s + '^' for s in ShapeAnno.create_shape_str(input.shape)] - annos = [OpAnno.create_op_str([edim_in], ['1'])] + annos = [OpAnno.create_op_str([edim_in], [['1']])] return IRDimops(Max, 'max', signature, annos, [input]) elif isinstance(other_or_dim_val, int): keepdim_val = _unwrap_value(keepdim) @@ -2198,7 +2198,7 @@ def Min(input, other_or_dim=None, keepdim=False, *, out=None, signature = None, other_or_dim_val = _unwrap_value(other_or_dim) if other_or_dim_val is None: edim_in = [s + '^' for s in ShapeAnno.create_shape_str(input.shape)] - annos = [OpAnno.create_op_str([edim_in], ['1'])] + annos = [OpAnno.create_op_str([edim_in], [['1']])] return IRDimops(Min, 'min', signature, annos, [input]) elif isinstance(other_or_dim_val, int): keepdim_val = _unwrap_value(keepdim) @@ -2212,7 +2212,7 @@ def Min(input, other_or_dim=None, keepdim=False, *, out=None, signature = None, kwargs = {'dim': other_or_dim, 'keepdim': keepdim} annos = [OpAnno.create_op_str([edim_in], [edim_out, edim_out])] return IRDimops(Min, 'min', signature, annos, [input], **kwargs) - + def Log(input, *, out=None, signature=None): """ @@ -2227,7 +2227,7 @@ def Log(input, *, out=None, signature=None): annos = [OpAnno.create_op_str([edim_in], [edim_in])] return IRDimops(Log, 'log', signature, annos, [input]) - + def FullLike(input, fill_value, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None, signature=None): """ diff --git a/cube/parallel.py b/cube/parallel.py index 899d4dd1..e334a963 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -749,7 +749,13 @@ def __init__(self, init_params=True): reuse = ReuseType(reuse) if isinstance(reuse, str) else reuse broadcast_strategy = BroadcastGenFilesStrategy(broadcast_strategy) if isinstance(broadcast_strategy, str) else broadcast_strategy - # genereate code only in node0 + # Call it here just to ensure the device group is initialized. + # If the user initializes torch.distributed + # and doesn't call `cube.init()` before calling this function, this is necessary. + if torch.distributed.is_initialized(): + _ = DeviceGroup() + + # generate code only in node0 # if it is not in a torchrun environment, just generate. if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: outdir, reusable = _prepare_and_check_reusable(cube_savedir, module_class, compute_config, instance_name, reuse) @@ -771,6 +777,13 @@ def __init__(self, init_params=True): regen_status = RegenStatus.NONE logger.info(f"Reuse generated code in {outdir}") + if torch.distributed.is_initialized(): + # code generation can take very long time (for example, over 1 hour) + # It is not always OK to use torch.distributed.barrier() directly. + # because the default timeout for nccl is 30 minutes + # (we can't control the timeout setting if torch.distributed is not initialized by us) + DeviceGroup().long_barrier() + if broadcast_strategy != BroadcastGenFilesStrategy.NONE: if not torch.distributed.is_initialized(): # we only support loading in torchrun environment raise RuntimeError("Broadcast generated files failed: torch.distributed is not initialized.") @@ -1000,7 +1013,7 @@ def build_optimizer( raise RuntimeError("No ParallelModule found in the module. Please make sure you have called parallelize() before build_optimizer().") # check if all ParallelModules have the same gpu_config - compute_configs = [m.get_compute_config() for m in parallel_modules] + compute_configs = [m.compute_config for m in parallel_modules] for i in range(1, len(compute_configs)): if compute_configs[i].gpu_config != compute_configs[0].gpu_config: raise RuntimeError("All ParallelModules should have the same gpu_config.") @@ -1027,7 +1040,7 @@ def _local_parameters(module: torch.nn.Module): lambda m: [ (cube_suffix, p) # (cube_suffix, p) to meet _named_members requirement for p in ( - m.parameters_for_optimizer() if m.get_compute_config().use_zero + m.parameters_for_optimizer() if m.compute_config.use_zero else m.parameters() # `CubeModule.merge_partial_states` supports parameters_for_optimizer() only in zero mode ) ] @@ -1053,7 +1066,7 @@ def _local_parameters(module: torch.nn.Module): name=type(optimizer).__name__, parallel_module_locs=opt_module_locs, parallel_module_configs={ - name: m.get_compute_config() + name: m.compute_config for name, m in module.named_modules() if isinstance(m, ParallelModule) } @@ -1551,7 +1564,7 @@ def _opt_load_merged_state_dict(module: ParallelModule, states: Dict[int, Dict[s orig_param_dict[name] = states[cnt] cnt = cnt + 1 - if module.get_compute_config().use_zero: + if module.compute_config.use_zero: return _construct_optim_state_zero(module, orig_param_dict) else: return _construct_optim_state_nonzero(module, orig_param_dict) @@ -1561,8 +1574,8 @@ def _construct_optim_state_zero( module: ParallelModule, orig_param_dict: Dict[str, Dict[str, Any]], ): - dist_param_map = module.get_dist_param_map() # name in parallel module (without tid suffix) -> name in origin module - param_area_map = module.get_full_map() # str -> AttrMeta + dist_param_map = module.dist_param_map # name in parallel module (without tid suffix) -> name in origin module + param_area_map = module.fullmap # str -> AttrMeta def _get_optimizer_state_of_param(param, param_ids, local_names): # find the parameter's optimizer state and pick the slices induced by tensor parallelism param_idx = param_ids.index(id(param)) @@ -1667,8 +1680,8 @@ def _construct_optim_state_nonzero( module: ParallelModule, orig_param_dict: Dict[str, Dict[str, Any]] ): - dist_param_map = module.get_dist_param_map() # name in parallel module (without tid suffix) -> name in origin module - param_area_map = module.get_full_map() # str -> AttrMeta + dist_param_map = module.dist_param_map # name in parallel module (without tid suffix) -> name in origin module + param_area_map = module.fullmap # str -> AttrMeta new_states = {} for index, (local_name, _) in enumerate(module.named_parameters()): diff --git a/cube/runtime/device.py b/cube/runtime/device.py index b6ce32a7..a84f6fcf 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -11,6 +11,7 @@ from cube.flags import CompileFlag _logger = logging.getLogger(__name__) +_LARGE_TIMEOUT = datetime.timedelta(seconds=21600) class DeviceGroup: @@ -27,7 +28,16 @@ def __init__(self): else: if not torch.distributed.is_initialized(): torch.distributed.init_process_group( - backend='nccl', timeout=datetime.timedelta(seconds=21600)) + backend='nccl', timeout=_LARGE_TIMEOUT + ) + + # create a barrier group for synchronization + # it is OK even the user has already created this gloo group + # this new timeout will override the old one. + self.barrier_gloo_group = torch.distributed.new_group( + backend='gloo', timeout=_LARGE_TIMEOUT + ) + self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() # assume each node has the same device number @@ -73,9 +83,15 @@ def get_group(self, ranks): rank_bits = DeviceGroup.bitmap(ranks) if rank_bits not in self.instance.groups: self.groups[rank_bits] = torch.distributed.new_group( - list(ranks), timeout=datetime.timedelta(seconds=21600)) + list(ranks), timeout=_LARGE_TIMEOUT) return self.groups[rank_bits] + def long_barrier(self): + """ + Barrier synchronization with very long timeout + """ + torch.distributed.barrier(group=self.instance.barrier_gloo_group) + def get_stream(self, name: str) -> torch.cuda.Stream: """ Get stream by name. If name doesn't exist, diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 0c47dffd..aa7b3e5c 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -57,7 +57,11 @@ def reducers(self): return self._reducers @property - def fullmap(self): + def fullmap(self) -> Dict[str, AttrMeta]: + """ + Get the mapping from the name of local attribute tensor + to its corresponding fulltensor meta + """ return self._fullmap def tid_of_param_name(self, name: str) -> int: @@ -572,6 +576,8 @@ class ParallelModule(CubeModule): COMPUTE_CONFIG_FILE = 'compute_config.pt' ORIGIN_MODULE_METADATA_FILE = 'origin_module_metadata.pt' EXTRA_STATE_KEY = 'CUBE_EXTRA_STATE' + # the rank of the module, will be assigned in the generated subclasses + rank: int def __init__(self): if self.__class__ == ParallelModule: # not init via super().__init__() @@ -641,7 +647,8 @@ def sync_grad(self): for reducer in self._reducers: reducer.sync_grads() - def get_dist_param_map(self) -> Dict[str, str]: + @property + def dist_param_map(self) -> Dict[str, str]: """ Get the parameter map of the model. The map is a dict mapping from the new parameter name (without tid suffix) in parallel module @@ -649,12 +656,10 @@ def get_dist_param_map(self) -> Dict[str, str]: """ return self._dist_param_map - def get_compute_config(self) -> 'ComputeConfig': + @property + def compute_config(self) -> 'ComputeConfig': return self._compute_config - def get_rank(self) -> int: - return self.rank # rank is a class varible defined in gencode - def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]: """ Calculate the gradient norm and clip gradients. @@ -667,7 +672,7 @@ def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, Li """ from cube.runtime.gnorm import prepare_for_grad_clip, clip_gnorm if self._nreplicas2localparams is None: - self._nreplicas2localparams = prepare_for_grad_clip(self, self.get_compute_config().use_zero) + self._nreplicas2localparams = prepare_for_grad_clip(self, self.compute_config.use_zero) # make sure the gradients are synchronized self.sync_grad() @@ -700,7 +705,7 @@ def _get_zero_metadata(self) -> ZeroMetadata: Returns: ZeroMetadata: the zero related metadata """ - if not self.get_compute_config().use_zero: + if not self.compute_config.use_zero: return ZeroMetadata() model_params = self.parameters_for_optimizer() @@ -755,18 +760,18 @@ def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: rank_idx (int): the index of current rank in sub_ranks sub_ranks (list): the ranks of ZeRO subgroup the current rank belongs to """ - cf = self.get_compute_config() + cf = self.compute_config if not cf.use_zero: raise RuntimeError('ZERO is not enabled, cannot get the zero subgroup info') - rank_idx = reducer.ranks.index(self.get_rank()) + rank_idx = reducer.ranks.index(self.rank) if cf.zero_ngroups > 1: assert len(reducer.ranks) % cf.zero_ngroups == 0, \ f'reducer.ranks {reducer.ranks} should be divisible by ZERO_NUM_GROUPS {cf.zero_ngroups}' zgroup_sz = len(reducer.ranks) // cf.zero_ngroups group_idx = rank_idx // zgroup_sz sub_ranks = reducer.ranks[group_idx * zgroup_sz : (group_idx + 1) * zgroup_sz] - new_rank_idx = sub_ranks.index(self.get_rank()) + new_rank_idx = sub_ranks.index(self.rank) return new_rank_idx, sub_ranks else: assert cf.zero_ngroups == 1 @@ -775,7 +780,7 @@ def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: def _add_extra_state(self, state_dict, prefix) -> None: state_dict[f'{prefix}{self.EXTRA_STATE_KEY}'] = asdict( ExtraState( - rank=self.get_rank(), + rank=self.rank, compute_config=self._compute_config, dist_param_map=self._dist_param_map, param_area_map=self._fullmap, @@ -799,14 +804,14 @@ def module_dedup_group_size(self) -> int: """ Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. """ - return self.get_compute_config().module_dedup_group_size + return self.compute_config.module_dedup_group_size @property def optimizer_dedup_group_size(self) -> int: """ Get the size of the deduplication group of the optimizer state dict. """ - return self.get_compute_config().optimizer_dedup_group_size + return self.compute_config.optimizer_dedup_group_size def _list_fullmodel_files(self) -> List[Path]: legacy_fullmodel_path = self.module_dir / FxModuleParser.ATTR_CONTENT_FILE_STEM @@ -841,7 +846,7 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s RuntimeError: if strict=True and there are missing keys. """ - dist2param = self.get_dist_param_map() + dist2param = self.dist_param_map orig_param_names = list(dist2param.values()) # param names in original module (without prefix) with torch.no_grad(): diff --git a/tests/launch_torchrun.py b/tests/launch_torchrun.py index f66a010f..886928f4 100644 --- a/tests/launch_torchrun.py +++ b/tests/launch_torchrun.py @@ -3,8 +3,12 @@ import torch from torch.distributed.run import elastic_launch, LaunchConfig +from torch.distributed.elastic.multiprocessing.errors import ChildFailedError +from .utils import retry + +@retry(ChildFailedError, delay=10, match='RuntimeError: The server socket has failed to listen on any local network address.') def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): launch_config = LaunchConfig( min_nodes=1, diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 28b7ca91..7eaac8cd 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -340,7 +340,7 @@ def test_checkpoint(module_type, use_zero): def assert_intra_reducer(module: ParallelModule): - assert module.get_compute_config().plan_ngpus == module.get_compute_config().runtime_ngpus + assert module.compute_config.plan_ngpus == module.compute_config.runtime_ngpus assert len(module.reducers) > 0 # so we have both parameters in reducers and not in reducers # (assume one reducer gives one bucket, which is true in general.) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index c019d92a..31dfd4ab 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -214,10 +214,10 @@ def _gpu_worker_cube(pas, plan_ngpus, runtime_ngpus, update_freq, use_zero): ) return ( compiled_results, - compiled_module.fc_relu1.get_full_map(), - compiled_module.fc_relu1.get_dist_param_map(), - compiled_module.fc_relu2.get_full_map(), - compiled_module.fc_relu2.get_dist_param_map(), + compiled_module.fc_relu1.fullmap, + compiled_module.fc_relu1.dist_param_map, + compiled_module.fc_relu2.fullmap, + compiled_module.fc_relu2.dist_param_map, ) def _get_fc_weights(state_dict: dict, prefix): diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 755fdd58..acabf819 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -32,9 +32,9 @@ def _init_params_worker(): module1 = cube_module() module2 = cube_module() module3 = cube_module(init_params=False) - assert module1.get_rank() == 0 - assert module2.get_rank() == 0 - assert module3.get_rank() == 0 + assert module1.rank == 0 + assert module2.rank == 0 + assert module3.rank == 0 for p1, p2 in zip(module1.parameters(), module2.parameters()): assert torch.equal(p1, p2) @@ -76,7 +76,7 @@ def test_empty_weights(model_class, tp): for i in range(4): module_class = _load_cube_module_class(model_class, cube_savedir=tempdir, rank=i) m = new_empty(module_class) - assert m.get_rank() == i + assert m.rank == i for p in m.parameters(): assert p.device == torch.device('meta') for r in m.reducers: diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index aaccb8e1..39050900 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -121,10 +121,10 @@ def _gpu_worker(pas, ngpus, update_freq): return ( orig_results, compiled_results, - compiled_module.fc_relu1.get_full_map(), - compiled_module.fc_relu1.get_dist_param_map(), - compiled_module.fc_relu2.get_full_map(), - compiled_module.fc_relu2.get_dist_param_map(), + compiled_module.fc_relu1.fullmap, + compiled_module.fc_relu1.dist_param_map, + compiled_module.fc_relu2.fullmap, + compiled_module.fc_relu2.dist_param_map, ) diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 7309ff35..81aeb5f1 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -109,8 +109,8 @@ def _gpu_worker(pas, ngpus): return ( orig_results, compiled_results, - compiled_module.get_full_map(), - compiled_module.get_dist_param_map(), + compiled_module.fullmap, + compiled_module.dist_param_map, ) diff --git a/tests/utils.py b/tests/utils.py index 7a4e4fa6..badc80a8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,10 @@ import os +import re import sys -from typing import Type +from typing import Optional, Tuple, Type, Union, Pattern from contextlib import contextmanager from typing import Callable +import functools import math import random from datetime import timedelta @@ -227,9 +229,11 @@ def mock_dist(rank, world_size): """ old_store_based_barrier = c10d._store_based_barrier + old_new_group = dist.new_group try: c10d._store_based_barrier = lambda *args, **kwargs: None mock_init_dist(rank, world_size) + dist.new_group = lambda *args, **kwargs: None yield finally: dist.destroy_process_group() @@ -281,3 +285,34 @@ def new_empty(cube_module_cls: Type[ParallelModule], device='meta', init_params= compute_config = torch.load(module_file.with_name(f"{cube_module_cls.COMPUTE_CONFIG_FILE}")) with replace_all_device_with(device, True), mock_cube_env(cube_module_cls, compute_config), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): return cube_module_cls(init_params=init_params) + + +def retry(*exceptions, max_tries=3, match: Optional[Union[str, Pattern[str]]] = None, delay=5): + """ + Retry the function if an exception is raised. + + Args: + max_tries (int): the maximum number of tries + + Example: + @retry(): + def f(*args, **kwargs): + ... + """ + exceptions = exceptions or (Exception,) + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + for i in range(max_tries): + try: + return func(*args, **kwargs) + except exceptions as e: + matched = not match or re.search(match, str(e)) + if i == max_tries - 1 or not matched: + raise + from time import sleep + print(f"retrying... {e} after {delay} seconds") + sleep(delay) + return wrapper + + return decorator From 9e3f91314f9debf3428fddabb54ae4ac8e1ef451 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 8 Apr 2024 06:36:00 +0000 Subject: [PATCH 1610/1892] Merged PR 2089: support trace model several times & add code log for .device --- .../fx/concrete_trace_utils/concrete_proxy.py | 4 +- .../concrete_trace_utils/concrete_tracer.py | 22 +-------- .../parser/fx/concrete_trace_utils/utils.py | 47 +++++++++++++------ 3 files changed, 38 insertions(+), 35 deletions(-) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 0c4f02bd..6926bc95 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -32,6 +32,7 @@ _orig_slice, _orig_set, map_recursive, + get_frame_record, ) _logger = logging.getLogger(__name__) @@ -290,7 +291,8 @@ def __init__(self, root: ConcreteProxy, attr: str): elif _orig_isinstance(root.value, torch.Tensor) and attr == 'device' and self.tracer.cpu_offload: self.value = torch.device('cuda') warning_msg = "operation .device is detected, it will always return torch.device('cuda') during trace, " + \ - "please make sure don't manually change the tensor device in the code." + "please make sure don't manually change the tensor device in the code.\n" + \ + f"\t{get_frame_record()}" _logger.warning(warning_msg) else: self.value = _orig_getattr(root.value, attr) diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 7c540698..c5b1fb4f 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -11,15 +11,12 @@ import operator import functools import builtins -import traceback -import importlib.util from dataclasses import dataclass, field from itertools import chain from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType from typing import Any, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, List, Callable, Union from contextlib import contextmanager -from pathlib import Path import torch from torch._C import ScriptObject @@ -142,6 +139,7 @@ def __exit__(self, *args): flatten_trees_with_func, flatten_trees_with_func_and_spec, map_trees_with_func, + get_frame_record, ) # pyright: reportGeneralTypeIssues=false @@ -500,20 +498,7 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: if self.record_frames and kind != 'placeholder': with self.do_temp_call_origin(): - # record code frame, include filename, line number, and function name - frame_record = FrameRecord(None, None, None, None) - cube_path = str(Path(importlib.util.find_spec('cube').origin).parent) + '/' # the cube path - torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path - ignore_dirs = [cube_path, torch_path] - for frame in traceback.extract_stack()[-2::-1]: - if any(p in frame.filename for p in ignore_dirs): - continue - frame_record.filename = frame.filename - frame_record.lineno = frame.lineno - frame_record.line = frame.line - frame_record.name = frame.name - break - node.meta['frame_record'] = frame_record + node.meta['frame_record'] = get_frame_record() proxy = self.proxy(value_unwrapped, node) return proxy @@ -867,9 +852,6 @@ def module_getattribute_wrapper(mod, attr): 'this is usually caused by directly assigning the return value of some leaf function to the attribute of the module. ' + \ 'Please note that this writing method may cause some trace errors.' _logger.warning(warn_msg) - if callable(attr_val) and not _orig_isinstance(attr_val, ep.ConcreteProxy): - if attr_val in self.wrapped_leaf: - return self.wrapped_leaf[attr_val][1] return attr_val # using isinstance instead of _orig_isinstance to judge whether # the ConcreteProxy.value is the following three types if the attr_val is a ConcreteProxy diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/cube/graph/parser/fx/concrete_trace_utils/utils.py index 733b596f..22233762 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/cube/graph/parser/fx/concrete_trace_utils/utils.py @@ -4,7 +4,10 @@ import builtins from collections import namedtuple from dataclasses import dataclass +import importlib import operator +import traceback +from pathlib import Path from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Type, List import torch @@ -189,20 +192,6 @@ def flatten_trees_with_func_and_spec(fn, pytrees, spec): return [fn(*vals) for vals in zip(*flat_args)] -@dataclass -class FrameRecord: - filename: str - lineno: str - line: str - name: str - - def __repr__(self) -> str: - if self.filename: - return f'File "{self.filename}", line {self.lineno}, in {self.name}, {self.line}' - else: - return '' - - class ExtraSEFPatcher: def __init__(self, extra_side_effectful_functions: Set[Callable]): self.extra_side_effectful_functions = extra_side_effectful_functions @@ -292,3 +281,33 @@ class EmptyResult: """Used for identification no results. """ pass + + +@dataclass +class FrameRecord: + filename: str + lineno: str + line: str + # the name of the frame is the function name + name: str + + def __repr__(self) -> str: + if self.filename: + return f'File "{self.filename}", line {self.lineno}, in {self.name}, {self.line}' + else: + return '' + + +def get_frame_record() -> Optional[FrameRecord]: + # record code frame, include filename, line number, and function name + frame_record = None + cube_path = str(Path(importlib.util.find_spec('cube').origin).parent) + '/' # the cube path + torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path + ignore_dirs = [cube_path, torch_path] + # the last frame is the current frame [get_frame_record], so we need to skip it + for frame in traceback.extract_stack()[-2::-1]: + if any(p in frame.filename for p in ignore_dirs): + continue + frame_record = FrameRecord(frame.filename, frame.lineno, frame.line, frame.name) + break + return frame_record From 16bfa48e801e3e8b1dfe4d24757ee7e797aecf13 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Mon, 8 Apr 2024 07:26:50 +0000 Subject: [PATCH 1611/1892] Merged PR 2095: fix: full-tensors in kwargs convert to sub tensor fix kwargs to sub tensor --- cube/graph/graph.py | 1 + tests/graph/test_graph.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tests/graph/test_graph.py diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 5fadda12..9d9444d2 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -190,6 +190,7 @@ def from_logic_graph(nodes: List[IRCell], if isinstance(ftensor, IRObject): subtensor = ftensor.tosub() if isinstance(ftensor, IRFullTensor) else ftensor node.set_output(idx, subtensor) + node.kwargs.update(IRSegment.modify_objects_of_complex(node.kwargs, modifier)) graph = IRGraph(nodes, inputs, outputs, module_name) # check IRPyFunc diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py new file mode 100644 index 00000000..784bd32c --- /dev/null +++ b/tests/graph/test_graph.py @@ -0,0 +1,28 @@ + +from cube.ir.tensor import IRFullTensor, IRSubTensor +from cube.ir.operator import IRFwOperation +from cube.graph.graph import IRGraph + + +def test_graph_from_logic(): + + node = IRFwOperation("test", "test", + inputs=[IRFullTensor([256, 256])], + num_outputs=1, + # kwargs + kw={ + 'a':[IRFullTensor([128, 256]),], + 'b':IRFullTensor([128, 128]) + }, + t=IRFullTensor([128, 256])) + output = IRFullTensor([128, 256]) + node.set_output(0, output) + graph = IRGraph.from_logic_graph([node], [node.input(0)], [output], 'GenModule') + assert len(graph.nodes()) == 1 + node = graph.node(0) + print(node.kwargs) + assert isinstance(node.input(0), IRSubTensor) + assert isinstance(node.output(0), IRSubTensor) + assert isinstance(node.kwargs['kw']['a'][0], IRSubTensor) + assert isinstance(node.kwargs['kw']['b'], IRSubTensor) + assert isinstance(node.kwargs['t'], IRSubTensor) From f5084313fe329fcb74ec87b8dc7f9a9d50cb5671 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 9 Apr 2024 01:31:15 +0000 Subject: [PATCH 1612/1892] Merged PR 2096: Avoid redundant retrieve in zero Problem solved in this PR: when the zero group size is very large, like 128, traversing all zero workers to rebuild optimizer states leads to out of cpu memory. Since the zero information (model_idx2opt_idx and opt_idx2ranks) is shared intra scale units, we only need to `retrieve_param_opt_state` for the 1st scale unit. In addition, add useful logging. --- cube/runtime/module.py | 28 +++++++++++++++---- tests/parallel_module/test_checkpoint.py | 2 +- .../parallel_module/test_checkpoint_dedup.py | 2 +- .../parallel_module/test_checkpoint_shared.py | 2 +- tests/parallel_module/test_ddp.py | 4 +-- tests/parallel_module/test_reducer_hook.py | 2 +- tests/parallel_module/test_scale_grads.py | 2 +- tests/parallel_module/test_submodule.py | 2 +- tests/parallel_module/test_wholemodule.py | 2 +- 9 files changed, 31 insertions(+), 15 deletions(-) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index aa7b3e5c..706cf262 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -358,8 +358,12 @@ def merge_state_dicts( Dict[str, torch.Tensor]: Full model state dict Dict[str, Any]: Full optimizer state dict """ + # state dicts in the 1st scale unit may be a subset of `model_state_dicts`. Using `plan_ngpus` here to + # help understand the whole logic. In other words, the real plan_ngpus is <= len(model_state_dicts). + plan_ngpus = len(model_state_dicts) # gather model states - full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps[0: len(model_state_dicts)]) + full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps[0: plan_ngpus]) + _logger.info('finish merge model states') if optim_state_dicts is None: return full_model_state_dict, None @@ -417,9 +421,14 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): opt_states['step'] = bucket_states[0]['step'] return opt_states + # Parameters are partitioned inside a scale unit composed of plan_ngpus GPUs. + # When ZeRO-1 is enabled, optimizer states (like exp_avg and exp_avg_sq) are partitioned within + # each ZeRO group. Since the training is done in a synchronized way, the optimizer states are + # identical across each ZeRO group. + # As a result, we can retrieve and merge the optimizer states in other scale units following the + # information stored in zero_idx_maps ONLY for the first scale unit. opt_state_list = [] - worker_cnt = len(optim_state_dicts) - for work_idx in range(worker_cnt): + for work_idx in range(plan_ngpus): model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[work_idx] opt_state = {} for model_idx, opt_idx in model_idx2opt_idx.items(): @@ -438,14 +447,16 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): pend, pshape, bucket_size) + _logger.info(f'finish handle optimizer state for worker {work_idx}') opt_state_list.append(opt_state) assert len(optim_state_dicts[work_idx]['param_groups']) == 1, 'only support param_groups to be one group' # assign opt_state to state_dicts, cannot be assigned in the above loop opt_state_len = len(opt_state_list[0]) - for work_idx in range(worker_cnt): + for work_idx in range(plan_ngpus): optim_state_dicts[work_idx]['state'] = opt_state_list[work_idx] optim_state_dicts[work_idx]['param_groups'][0]['params'] = sorted(opt_state_list[work_idx].keys()) + _logger.info(f'finish assign optimizer state for worker {work_idx}') assert len(opt_state_list[work_idx]) == opt_state_len # build parameter order to match with the optimizer state order @@ -467,7 +478,8 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): full_states = full_optim_state_dict['state'] # full_index: param IDs in the full optimizer state for full_index, param_name in enumerate(origin_parameter_names): - for optim_state, fullmap in zip(optim_state_dicts, fullmaps): + _logger.info(f'start to handle optimizer state for param {param_name} with full_index {full_index}') + for optim_state, fullmap in zip(optim_state_dicts[0 : plan_ngpus], fullmaps[0 : plan_ngpus]): if 'state' not in optim_state: continue # adam-like optimizers have optim_state['state']={} before any optimizer.step() if not optim_state['state']: continue @@ -499,9 +511,13 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): full_states[full_index][state_name][meta.slicers] = value # handle additional state dict keys - for optim_state_dict in optim_state_dicts: + for optim_state_dict in optim_state_dicts[0 : plan_ngpus]: for key in optim_state_dict.keys(): if key != 'state': + if key in full_optim_state_dict: + _logger.info(f'overwrite optimizer state key {key}') + else: + _logger.info(f'inherit optimizer state key {key}') full_optim_state_dict[key] = optim_state_dict[key] return full_model_state_dict, full_optim_state_dict diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 7eaac8cd..2d7aa7a8 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -218,7 +218,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): for _ in range(DATA_SIZE): data.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) data = data[start:end] # continue from last training data = [data[i] for i in range(rank, len(data), num_replicas)] diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index f0c965b7..f3b0a100 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -86,7 +86,7 @@ def _train(model: torch.nn.Module, ckpt_dir): for _ in range(DATA_SIZE): DATA.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) loss_fn = nn.BCELoss() optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index 4bb4785a..03b4227e 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -101,7 +101,7 @@ def _train_raw(model: torch.nn.Module, ckpt_dir): for _ in range(DATA_SIZE): DATA.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) loss_fn = nn.BCELoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index 31dfd4ab..647f2a01 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -121,7 +121,7 @@ def _train_ddp(model, update_freq, num_replicas, rank): for _ in range(DATA_SIZE): data.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) data = [data[i] for i in range(rank, DATA_SIZE, num_replicas)] results = [] @@ -158,7 +158,7 @@ def _train(model, is_cube, update_freq, num_replicas, rank): for _ in range(DATA_SIZE): data.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) data = [data[i] for i in range(rank, DATA_SIZE, num_replicas)] results = [] diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index 94e2bdd0..c6972233 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -91,7 +91,7 @@ def post_hook(reducer, grad): for _ in range(DATA_SIZE): data.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) results = [] for i, (x, y) in enumerate(data): diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py index 62e7e8f7..1c003540 100644 --- a/tests/parallel_module/test_scale_grads.py +++ b/tests/parallel_module/test_scale_grads.py @@ -101,7 +101,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, scale_grads: bool): for _ in range(DATA_SIZE): data.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) data = [data[i] for i in range(rank, len(data), num_replicas)] results = [] diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 39050900..3bf299f7 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -94,7 +94,7 @@ def _train(model, update_freq, is_cube): for _ in range(DATA_SIZE): data.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) results = [] for i, (x, y) in enumerate(data): diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 81aeb5f1..eec4a4a4 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -82,7 +82,7 @@ def _train(model, is_cube): for _ in range(DATA_SIZE): data.append(( torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.randn((2, 1), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), )) results = [] for i, (x, y) in enumerate(data): From 5cb4ea29b027b96f3e38563423198eb3142b289e Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 9 Apr 2024 02:58:23 +0000 Subject: [PATCH 1613/1892] Merged PR 2100: hotfix bug: remove new gloo group fix bug: new gloo group hangs --- cube/runtime/device.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/cube/runtime/device.py b/cube/runtime/device.py index a84f6fcf..081ae3e1 100644 --- a/cube/runtime/device.py +++ b/cube/runtime/device.py @@ -31,12 +31,14 @@ def __init__(self): backend='nccl', timeout=_LARGE_TIMEOUT ) + # disable it for now due to connection refused error when nnodes > 1 + # TODO: investigate the root cause # create a barrier group for synchronization # it is OK even the user has already created this gloo group # this new timeout will override the old one. - self.barrier_gloo_group = torch.distributed.new_group( - backend='gloo', timeout=_LARGE_TIMEOUT - ) + # self.barrier_gloo_group = torch.distributed.new_group( + # backend='gloo', timeout=_LARGE_TIMEOUT + # ) self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() @@ -90,7 +92,8 @@ def long_barrier(self): """ Barrier synchronization with very long timeout """ - torch.distributed.barrier(group=self.instance.barrier_gloo_group) + # torch.distributed.barrier(group=self.instance.barrier_gloo_group) + torch.distributed.barrier() def get_stream(self, name: str) -> torch.cuda.Stream: """ From 14c98c31c04686f406f92d32ed6b77e375fa2bb6 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 10 Apr 2024 06:23:48 +0000 Subject: [PATCH 1614/1892] Merged PR 2101: Add function: LogSigmoid Apart from LogSigmoid, this PR includes - add use_reentrant=True to recompute during codegen, since this keyword is required for some version of torch - refine the default value setting when operator profiling fails --- cube/codegen/module/module.py | 2 +- cube/graph/function/function.py | 6 ++++++ cube/graph/parser/fx/mapping.py | 1 + cube/profiler/database.py | 10 +++++++++- 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index b25f4d53..0b79a1db 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -830,7 +830,7 @@ def recompute(tensor_2222): fb.insert_body(f'return {output_names_tuple}') codes = [''] + fb.code + [''] codes.append( - f'{output_names_tuple} = ckpt.checkpoint(recompute, {input_names_tuple})' + f'{output_names_tuple} = ckpt.checkpoint(recompute, {input_names_tuple}, use_reentrant=True)' ) return codes diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 649caf16..5df77053 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -709,6 +709,12 @@ def SiLU(input, inplace=False, signature = None): return IRDimops(SiLU, 'silu', signature, annos, [input], inplace=inplace) +def LogSigmoid(input, signature = None): + annos = ['* -> *'] + signature = 'torch._C._nn.log_sigmoid' + return IRDimops(LogSigmoid, 'log_sigmoid', signature, annos, [input]) + + def ReLU(input, inplace=False, signature = None): annos = ['* -> *'] signature = 'torch.nn.functional.relu' diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 543b55ac..4b304739 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -74,6 +74,7 @@ def exist(signature: str) -> bool: __ttemplate('tril'): function.Tril, __ftemplate('relu'): function.ReLU, __ftemplate('silu'): function.SiLU, + __fcntemplate('log_sigmoid'): function.LogSigmoid, __fcntemplate('gelu'): function.GeLU, __ttemplate('eq') : function.CompareEQ, '_operator.eq': function.CompareEQ, diff --git a/cube/profiler/database.py b/cube/profiler/database.py index 893f2c2f..8a1a928d 100644 --- a/cube/profiler/database.py +++ b/cube/profiler/database.py @@ -6,6 +6,7 @@ import torch import time import os +import copy import json import math import logging @@ -296,7 +297,14 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b CompProfiler.profile(node, fn, shapes, dtypes, requires_grads, values, **kwargs) except Exception: _logger.exception(f'fail to profile {node}, use default values') - fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = 0, 0, 0, [], [] + fw_span, bw_span = 0, 0 + infer_memory = 0 + for t in node.outputs(): + if isinstance(t, IRTensor): + infer_memory += t.byte_size() + # by default, we assume that all the input tensors are saved for backward + train_mem_info = copy.deepcopy(in_mem_info) + train_mem2in_idx = list(range(len(in_mem_info))) # log to database key = self._serialize(node) profiled_metrics = ProfiledMetrics(in_mem_info, param_mem_info, buffer_mem_info, From 63c1b08a6a169d068537b5c4760371e5d45a237e Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Wed, 10 Apr 2024 08:29:40 +0000 Subject: [PATCH 1615/1892] Merged PR 2066: huggingface top models bug fix --- cube/codegen/emit.py | 2 +- cube/graph/function/dimops.py | 4 +- cube/graph/function/function.py | 174 +++++++++++++++++++------ cube/graph/graph.py | 2 +- cube/graph/parser/fx/mapping.py | 1 + cube/graph/segment.py | 2 + cube/runtime/function/function.py | 7 +- tests/codegen/test_emit.py | 2 + tests/graph/function/test_functions.py | 121 +++++++++++++++-- 9 files changed, 257 insertions(+), 58 deletions(-) diff --git a/cube/codegen/emit.py b/cube/codegen/emit.py index cb6de746..ec790afd 100644 --- a/cube/codegen/emit.py +++ b/cube/codegen/emit.py @@ -58,7 +58,7 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: elif isinstance(val, tuple): # TODO: support subclasses of tuple, like torch.Size? return tuple(_safe_repr_value(v, prefix_attr) for v in val) - elif isinstance(val, (int, str, bool, float, type(None), bytes)): # only primitive type supported + elif isinstance(val, (int, str, bool, float, type(None), bytes, type(Ellipsis))): return val raise ValueError(f'Unsupported data type: {type(val)}') diff --git a/cube/graph/function/dimops.py b/cube/graph/function/dimops.py index 3145c377..8929c179 100644 --- a/cube/graph/function/dimops.py +++ b/cube/graph/function/dimops.py @@ -826,7 +826,8 @@ def algorithms(self, tag: Optional[str] = None): def transform_space(self) -> List[Tuple[int, int]]: """ - Get transformation space of the operator + Get transformation space of the operator, the transformation space + represents all configurations that can be segmented @return List[Tuple[int, int]]: list of (idx, dim) """ @@ -834,6 +835,7 @@ def transform_space(self) -> List[Tuple[int, int]]: configs = [] ashapes = self.anno.inputs() + self.anno.outputs() for idx, eshape in enumerate(ashapes): + if eshape.ignore: continue if idx < len(self.inputs()): if not isinstance(self.input(idx), IRTensor): continue for dim, edim in enumerate(eshape.dims): diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 5df77053..b05c92df 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -194,7 +194,12 @@ def creator_modifier(kwargs: Dict, idx, dim, num: int) -> Dict: def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Union[int, IRObject], dtype=None, requires_grad=False, signature=None): - dtype = dtype if dtype is not None else torch.get_default_dtype() + if dtype is None: + if any(isinstance(_unwrap_value(s), float) for s in (start, end, step)) or \ + any(s.dtype in [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16] for s in (start, end, step) if s is IRTensor): + dtype = torch.get_default_dtype() + else: + dtype = torch.int64 assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" signature = 'cube.runtime.function.arange' kwargs = {'start': start, 'end': end, 'step': step, @@ -859,11 +864,43 @@ def MaskedFill(input, mask, value, signature = None): return IRDimops(MaskedFill, 'masked_fill', signature, [anno], [input, mask], value=value) -def Where(condition, input, other, *, out=None, signature = None): +def Topk(input, k, dim=None, largest=True, sorted=True, *, out=None, signature = None): + """ + torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) + """ + edim_in = ShapeAnno.create_shape_str(input.shape) + if dim is None: + edim_in[-1] += '^' + else: + edim_in[dim] += '^' + edim_ou = [['?'], ['?']] + anno = OpAnno.create_op_str([edim_in], edim_ou) + return IRDimops(Topk, 'topk', signature, [anno], [input], k=k, dim=dim, largest=largest, sorted=sorted) + + +def Nonzero(input, *, out=None, as_tuple=False, signature = None): + """ + torch.nonzero(input, *, out=None, as_tuple=False) + """ + edim_in = ShapeAnno.create_shape_str(input.shape, reduction="^") + if as_tuple: + edim_ou = list(['?'] for _ in range(len(input.shape))) + else: + edim_ou = [['?']] + anno = OpAnno.create_op_str([edim_in], edim_ou) + return IRDimops(Nonzero, 'nonzero', signature, [anno], [input], as_tuple=as_tuple) + + +def Where(condition, input=None, other=None, *, out=None, signature = None): """ torch.where """ assert isinstance(condition, IRTensor) + if input is None and other is None or \ + (input is IRObject and input.value is None) and (other is IRObject and other.value is None): + return Nonzero(condition, as_tuple=True, signature = 'torch.nonzero') + if input is None or other is None: + raise ValueError("Both input and other must be provided together") if isinstance(input, IRTensor) and isinstance(other, IRTensor): (edim_in0, edim_in1, edim_in2), edim_out = _handle_broadcast_multi([condition, input, other]) elif isinstance(input, IRTensor) and len(input.shape) > 0 and not (len(input.shape) == 1 and input.shape[0] == 1): @@ -1562,7 +1599,7 @@ def IndexSelect(input: torch.Tensor, dim: int, index: torch.Tensor, *, out=None, return CubeIndexSelect(input, index, dim, signature=signature) -def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice, int]], signature=None): +def FullSlice(tensor: IRTensor, *slicers: Tuple[Union[None, slice, int, IRTensor, IRObject]], signature=None): """ Examples: >>> a = torch.randn((4,2)) @@ -1572,9 +1609,9 @@ def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice, int]], signatu >>> a[(2, None)] # shape [1,2] >>> a[(2, slice(None, None, None)), None] # shape [2,1] >>> a[(2, None, slice(None, None, None))] # shape [1,2] + >>> a[(2, torch.tensor([0, 1]), None)] # shape [2,1] """ signature = 'cube.runtime.function.fullslice' - slicers = tuple(slicers) # deal with ... in slice if any(slicer is Ellipsis for slicer in slicers): @@ -1590,7 +1627,9 @@ def FullSlice(tensor: IRTensor, slicers: Tuple[Union[None, slice, int]], signatu mid_slicers = [slice(None, None, None) for _ in range(len(tensor.shape) - front_count - back_count)] slicers = tuple(front_slicers + mid_slicers + back_slicers) - edim_in = ShapeAnno.create_shape_str(tensor.shape) + edim_in_additional = [] + fullslice_iterator = iter(string.ascii_lowercase) + edim_in = ShapeAnno.create_shape_str(tensor.shape, iterator=fullslice_iterator) edim_ou = [] in_idx = 0 tensor_error_msg = ("Tensor is not supported in slice. " @@ -1604,13 +1643,24 @@ def obj_helper(obj): return obj.value else: return obj + + # If there are more than one tensors or lists in slicers and their date type is not bool, they will broadcast to each other, + # and the output shape will be infered by the shapes of all tensors and lists in slicers, will use '?' in edim_ou + _single_int_tensor = len([slicer for slicer in slicers if + (isinstance(slicer, IRTensor) and slicer.dtype is not bool ) + or (isinstance(slicer, list) and slicer[0] is not bool)]) <= 1 + output_shape_unkonwn = False + slicers = list(slicers) for slicer in slicers: if slicer is None: + edim_in_additional.append(['?']) edim_ou.append('1') elif isinstance(slicer, int): + edim_in_additional.append(['?']) edim_in[in_idx] += '^' in_idx += 1 elif isinstance(slicer, slice): + edim_in_additional.append(['?']) if slicer != slice(None, None, None): edim_in[in_idx] += '^' _start, _stop, _step = obj_helper(slicer.start), obj_helper(slicer.stop), obj_helper(slicer.step) @@ -1625,15 +1675,57 @@ def obj_helper(obj): edim_ou.append(str(dimlen)) in_idx += 1 elif isinstance(slicer, IRTensor): - raise RuntimeError(tensor_error_msg) + # TODO: output shape can be infered by shapes of all lists and tensors in slicers + # Examples: a = torch.randn(3,4) + # a[torch.tensor([0, 1, 2]) ,[0, 1, 1]] == a[[0, 1, 2] ,[0, 1, 1]] == [a[0, 0], a[1, 1], a[2, 1]] + # a[[0] ,[0, 1, 1]] == a[[0, 0, 0] ,[0, 1, 1]] + # a[[True, False, True]] == a[torch.tensor([0, 2])] == a[[0, 2]] == [a[0,:], a[2,:]] + # a[[True, False, True], [0, 1]] == a[torch.tensor([0, 2]), [0, 1]] == a[[0, 2], [0, 1]] == [a[0, 0], a[2, 1]] + # when dtype of IRTensor or value of list is bool, the input shape must be the same as the sliced tensor at corresponding dimensions + slicer_anno = ShapeAnno.create_shape_str(slicer.shape, iterator=fullslice_iterator) + if slicer.dtype != torch.bool: + edim_in[in_idx] += '^' + in_idx += 1 + else: + slen = len(slicer.shape) + for i in range(in_idx, in_idx+slen): + edim_in[i] += '^' + in_idx += slen + if not _single_int_tensor or slicer.dtype == torch.bool: + slicer_anno = [ anno + "^" for anno in slicer_anno ] + output_shape_unkonwn = True + edim_in_additional.append(slicer_anno) + edim_ou.extend(slicer_anno) + elif isinstance(slicer, list): + if len(slicer) == 0: + raise RuntimeError(f"Unsupported slicer {slicer}. The length of the list in the slicer cannot be 0") + def list_shape(lst): + return [len(lst)] + (list_shape(lst[0]) if isinstance(lst[0], list) else []) + if type(slicer[0]) == bool and len(list_shape(slicer)) > 1: + raise RuntimeError(f"Unsupported slicer {slicer}. The depth of the list in the slicer cannot exceed 1 when value type is bool") + edim_in_additional.append(['?']) + edim_in[in_idx] += '^' + in_idx += 1 + if not _single_int_tensor or type(slicer[0]) == bool: + output_shape_unkonwn = True + else: + edim_ou.extend([str(a) for a in list_shape(slicer)]) else: raise RuntimeError(f"Unsupported slicer {slicer}. you may need to wrap related logic in a Customized Op.") - edim_ou += edim_in[in_idx:] - # special case for scalar = torch.Tensor([1,2,3])[0] - if len(edim_ou) == 0: - edim_ou.append('1') - anno = OpAnno.create_op_str([edim_in], [edim_ou]) - return IRDimops(FullSlice, 'fullslice', signature, [anno], [tensor], slicers=slicers) + + if output_shape_unkonwn: + edim_ou = ['?'] + else: + edim_ou += edim_in[in_idx:] + if len(edim_ou) == 0: + # special case for scalar = torch.Tensor([1,2,3])[0] + edim_ou.append('1') + + edim_in = [edim_in] + edim_in.extend(edim_in_additional) + anno = OpAnno.create_op_str(edim_in, [edim_ou]) + + return IRDimops(FullSlice, 'fullslice', signature, [anno], [tensor] + slicers) def Slice(tensor: torch.Tensor, dim, start, end, step, signature = None): @@ -1742,6 +1834,9 @@ def Embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, def Flatten(input, start_dim=0, end_dim=-1, signature = None): + """ + torch.flatten(input, start_dim=0, end_dim=-1) -> Tensor + """ start_dim = len(input.shape) + start_dim if start_dim < 0 else start_dim end_dim = len(input.shape) + end_dim if end_dim < 0 else end_dim ishape = ShapeAnno.create_shape_str(input.shape) @@ -1749,6 +1844,7 @@ def Flatten(input, start_dim=0, end_dim=-1, signature = None): ishape[dim] += '^' oshape = ishape[:start_dim] oshape.append(ishape[start_dim:end_dim+1]) + oshape.extend(ishape[end_dim+1:]) anno = OpAnno.create_op_str([ishape], [oshape]) return IRDimops(Flatten, 'flatten', signature, [anno], [input], start_dim=start_dim, end_dim=end_dim) @@ -2001,27 +2097,9 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: obj, index = a, b # tensor slice if isinstance(obj, IRTensor): - # note `None` will always - if isinstance(index, IRTensor): # TODO: support general tensor slicing: https://pytorch.org/cppdocs/notes/tensor_indexing.html - # move to FullSlice when ready - """ - Examples: - >>> a = torch.randn((4,2)) - >>> b = torch.randn((3,5)).to(torch.int64) - >>> a[b] # shape [3,5,2] - """ - if index.dtype not in (torch.int64, torch.int32): - raise RuntimeError(f"index should be int64 or int32, but got {index.dtype}") - gener = iter(string.ascii_lowercase) - obj_shape = ShapeAnno.create_shape_str(obj.shape, iterator=gener) - obj_shape[0] = obj_shape[0] + '^' - index_shape = ShapeAnno.create_shape_str(index.shape, iterator=gener) - out_shape = index_shape + obj_shape[1:] - anno = OpAnno.create_op_str([obj_shape, index_shape], [out_shape]) - return IRDimops(GetItem, 'getitem', signature, [anno], [obj, index]) - index = (index,) if isinstance(index, (int, slice)) else tuple(index) - return FullSlice(obj, index) + index = (index,) if isinstance(index, (int, slice, IRTensor, IRObject)) else tuple(index) + return FullSlice(obj, *index) # object slice if isinstance(obj, IRObject): assert obj.value is not None @@ -2165,8 +2243,6 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, dropout_p=0.0, """ if not isinstance(query, IRTensor) or not isinstance(key, IRTensor) or not isinstance(value, IRTensor): raise RuntimeError(f'query: {query}, key: {key}, value: {value} should be IRTensor, something went wrong.') - if attn_mask is not None: - raise RuntimeError(f'Only support attn_mask is None in scaled_dot_product_attention now.') gener = iter(string.ascii_lowercase) value_anno = ShapeAnno.create_shape_str(value.shape, iterator=gener) value_anno[-2] += '^' @@ -2176,10 +2252,30 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, dropout_p=0.0, query_anno[-2] = next(gener) out_anno = copy.copy(query_anno) out_anno[-1] = value_anno[-1] - - anno = OpAnno.create_op_str([query_anno, key_anno, value_anno], [out_anno]) - return IRDimops(ScaledDotProductAttention, 'scaled_dot_product_attention', signature, [anno], [query, key, value], - attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + if attn_mask is not None: + if not isinstance(attn_mask, IRTensor): + raise RuntimeError(f'attn_mask: {attn_mask} should be IRTensor, something went wrong.') + if len(attn_mask.shape) < 2 or len(attn_mask.shape) > len(query.shape): + raise RuntimeError(f'attn_mask shape {attn_mask.shape} is not supported, while query shape is {query.shape}') + attn_mask_anno = [] + # the anno of attn_mask will conbine query and attn_mask shape except last dimension, + # the last dimension of the attn_mask anno will be the same as key penultimate dimension + for index, sval in enumerate(attn_mask.shape[-2::-1]): + if attn_mask.shape[-2-index] == query.shape[-2-index]: + attn_mask_anno.insert(0, query_anno[-2-index]) + else: + attn_mask_anno.insert(0, str(attn_mask.shape[-2-index])) + if attn_mask.shape[-1] == key.shape[-2]: + attn_mask_anno.append(key_anno[-2]) + else: + attn_mask_anno.append(str(attn_mask.shape[-1])) + anno = OpAnno.create_op_str([query_anno, key_anno, value_anno, attn_mask_anno], [out_anno]) + return IRDimops(ScaledDotProductAttention, 'scaled_dot_product_attention', signature, [anno], [query, key, value, attn_mask], + dropout_p=dropout_p, is_causal=is_causal, **kwargs) + else: + anno = OpAnno.create_op_str([query_anno, key_anno, value_anno], [out_anno]) + return IRDimops(ScaledDotProductAttention, 'scaled_dot_product_attention', signature, [anno], [query, key, value], + dropout_p=dropout_p, is_causal=is_causal, **kwargs) def Min(input, other_or_dim=None, keepdim=False, *, out=None, signature = None, **kwargs): @@ -2243,4 +2339,4 @@ def FullLike(input, fill_value, *, dtype=None, layout=None, kwargs = {'fill_value': fill_value, 'requires_grad': requires_grad,'dtype': dtype} shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] - return IRDimops(FullLike, 'full_like', signature, annos,[input],**kwargs) \ No newline at end of file + return IRDimops(FullLike, 'full_like', signature, annos,[input],**kwargs) diff --git a/cube/graph/graph.py b/cube/graph/graph.py index 9d9444d2..b51f1926 100644 --- a/cube/graph/graph.py +++ b/cube/graph/graph.py @@ -351,7 +351,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], ctensors: Dict[IRFullTensor, List[IRSubTensor]] = dict() consumers: Dict[IRFullTensor, List[IRCell]] = dict() for fnode in fnodes: - for itensor in set(fnode.inputs()): + for itensor in set(IRSegment.get_objects_from_complex(fnode.inputs())): if not isinstance(itensor, IRSubTensor): continue ctensors.setdefault(itensor.parent, []).append(itensor) consumers.setdefault(itensor.parent, []).append(fnode) diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index 4b304739..af388b98 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -226,4 +226,5 @@ def exist(signature: str) -> bool: 'torch.functional.split': function.Split, __ttemplate('split'): function.Split, __tttemplate('split'): function.Split, + __ttemplate('topk'): function.Topk, } diff --git a/cube/graph/segment.py b/cube/graph/segment.py index 91723810..6e835ee0 100644 --- a/cube/graph/segment.py +++ b/cube/graph/segment.py @@ -1119,6 +1119,8 @@ def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[ for key, value in val.items(): IRSegment.get_objects_from_complex(key, _objects) IRSegment.get_objects_from_complex(value, _objects) + if isinstance(val, slice): + IRSegment.get_objects_from_complex([val.start, val.stop, val.step], _objects) if isinstance(val, IRObject): _objects.append(val) return _objects diff --git a/cube/runtime/function/function.py b/cube/runtime/function/function.py index dbe1029f..b2431732 100644 --- a/cube/runtime/function/function.py +++ b/cube/runtime/function/function.py @@ -39,7 +39,7 @@ def accum(*tensors: Tuple[torch.Tensor]) -> torch.Tensor: return torch.sum(torch.stack(tensors, dim=0), dim=0) -def fullslice(input: torch.Tensor, slicers: Tuple[Union[None, slice, int]]): +def fullslice(input: torch.Tensor, *slicers: Union[None, slice, int, torch.Tensor]): """Slice tensors Note: @@ -49,13 +49,12 @@ def fullslice(input: torch.Tensor, slicers: Tuple[Union[None, slice, int]]): Args: input (torch.Tensor): input tensor - slicers (Tuple[None | slicer | int]): slicer tuple - + slicers (Union[None | slicer | int | torch.Tensor]): slicers for input Returns: torch.Tensor: sliced tensor """ - return input[slicers] + return input[tuple(slicers)] def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], diff --git a/tests/codegen/test_emit.py b/tests/codegen/test_emit.py index 5ba94634..496f61d4 100644 --- a/tests/codegen/test_emit.py +++ b/tests/codegen/test_emit.py @@ -19,6 +19,8 @@ def test_tensor_name(): assert repr_expr({'a': 1, 'b': IRObject('name', 111, 'value')}, 'model.') == "{'a': 1, 'b': name_111}" assert repr_expr([1], 'model.') == '[1]' assert repr_expr((1,), 'model.') == '(1,)' + + assert repr_expr((1,...), ) == '(1, Ellipsis)' with pytest.raises(ValueError): from datetime import datetime diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index b3419768..304aa35c 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -8,6 +8,8 @@ import numpy as np import math +from cube.ir.tensor import IRFullTensor + def o(value): return IRObject(value=value) @@ -199,21 +201,50 @@ def test_Where(): def test_FullSlice(): - op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, 3)) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' - op = F.FullSlice(IRTensor([2, 3, 4]), (..., 2)) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b' - op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, slice(0, 3, 2))) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 2' - op = F.FullSlice(IRTensor([2, 3, 4]), (1, 2, slice(1, 10, 1))) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 3' - with pytest.raises(RuntimeError): - op = F.FullSlice(IRTensor([2, 3, 4]), (IRTensor([1, 2, 3]),)) + op = F.FullSlice(IRTensor([2, 3, 4]), 1, [1.2, -1], 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, ?, ?, ? -> 2' + + op = F.FullSlice(IRTensor([2, 3, 4]), 1, 2, 3) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, ?, ?, ? -> 1' + op = F.FullSlice(IRTensor([2, 3, 4]), ..., 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^, ?, ?, ? -> a b' + op = F.FullSlice(IRTensor([2, 3, 4]), 1, 2, slice(0, 3, 2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, ?, ?, ? -> 2' + op = F.FullSlice(IRTensor([2, 3, 4]), 1, 2, slice(1, 10, 1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, ?, ?, ? -> 3' + op = F.FullSlice(IRTensor([2, 3, 4]), 1, None, ..., None, slice(0, 2, None)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c^, ?, ?, ?, ?, ? -> 1 b 1 2' + with pytest.raises(RuntimeError): - op = F.FullSlice(IRTensor([2, 3, 4]),(slice(1, IRTensor([2]), 3),)) + op = F.FullSlice(IRTensor([2, 3, 4]), slice(1, IRTensor([2]), 3)) + op = F.FullSlice(IRTensor([3, 4]), None, 0, slice(0, 4, 2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, ?, ?, ? -> 1 2' + op = F.FullSlice(IRTensor([3, 4]), [0,2], 0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, ?, ? -> 2' + op = F.FullSlice(IRTensor([3, 4]), [[0,1], [1,2]], 0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, ?, ? -> 2 2' + op = F.FullSlice(IRTensor([3, 4]), IRFullTensor([2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b, c -> c b' + op = F.FullSlice(IRTensor([3, 4]), IRFullTensor([2,2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b, c d -> c d b' + op = F.FullSlice(IRTensor([3, 4]), [True, False, True]) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b, ? -> ?' + op = F.FullSlice(IRTensor([3, 4]), IRFullTensor([3], dtype=torch.bool), 0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, c^, ? -> ?' + op = F.FullSlice(IRTensor([3, 4]), [True, False, True], [0,1]) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, ?, ? -> ?' + op = F.FullSlice(IRTensor([3, 4]), [True, False, True], IRFullTensor([2,2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, ?, c^ d^ -> ?' + op = F.FullSlice(IRTensor([3, 4]), IRFullTensor([3, 4], dtype=torch.bool)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, c^ d^ -> ?' + op = F.FullSlice(IRTensor([3, 4]), IRFullTensor([3]), IRFullTensor([3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, c^, d^ -> ?' + op = F.FullSlice(IRTensor([3, 4]), IRFullTensor([2,2]), IRFullTensor([2,2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, c^ d^, e^ f^ -> ?' def test_GetItem(): + # obj is IRTensor, index is IRTensor op = F.GetItem(IRTensor([4, 2]), IRTensor([3, 5], dtype=torch.int64)) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b, c d -> c d b' op = F.GetItem(IRTensor([4, 2]), IRTensor([3], dtype=torch.int64)) @@ -222,9 +253,39 @@ def test_GetItem(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, d -> d b c' op = F.GetItem(IRTensor([3, 4, 2]), IRTensor([3, 5], dtype=torch.int64)) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, d e -> d e b c' + + # obj is IRTensor, index is not IRTensor, will call FullSlice + op = F.GetItem(IRTensor([3, 4, 2]), 0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, ? -> b c' + op = F.GetItem(IRTensor([3, 4, 2]), [0, 1]) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c, ?, ? -> c' + op = F.GetItem(IRTensor([3, 4, 2]), slice(0, 3, 2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b c, ? -> 2 b c' + op = F.GetItem(IRTensor([3, 4, 2]), [slice(None), IRTensor([3, 5], dtype=torch.int64)]) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c, ?, d e -> a d e c' + op = F.GetItem(IRTensor([3, 4, 2]), [slice(None), IRTensor([4, 2], dtype=torch.bool)]) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c^, ?, d^ e^ -> ?' + + # obj is IRObject + op = F.GetItem(IRObject(value=[3, 4, 5], is_constant=False), IRObject(value=0, is_constant=False), signature='operator.getitem') + assert op.outputs()[0].value == 3 and op.outputs()[0].is_constant == True + op = F.GetItem(IRObject(value=[3, 4, 5], is_constant=False), IRObject(value=slice(0, 2, 1), is_constant=False), signature='operator.getitem') + assert op.outputs()[0].value == [3, 4] and op.outputs()[0].is_constant == True + op = F.GetItem(IRObject(value=[3, 4, 5], is_constant=False), 0, signature='operator.getitem') + assert op.outputs()[0].value == 3 and op.outputs()[0].is_constant == True + op = F.GetItem(IRObject(value=[3, 4, 5], is_constant=False), slice(0, 2, 1), signature='operator.getitem') + assert op.outputs()[0].value == [3, 4] and op.outputs()[0].is_constant == True + + # obj is not a IRObject, index is a IRObject op = F.GetItem([1, 2, 3], IRObject(value=0, is_constant=False), signature='operator.getitem') assert op.outputs()[0].value == 1 and op.outputs()[0].is_constant == False + # direct call obj[index] + op = F.GetItem([3, 4, 2], 1) + assert op == 4 + op = F.GetItem([3, 4, 2], slice(0, 2, 1)) + assert op == [3, 4] + def test_Max(): op = F.Max(IRTensor([2, 3, 4])) @@ -293,6 +354,17 @@ def test_Unsqueeze(): def test_ScaledDotProductAttention(): op = F.ScaledDotProductAttention(IRTensor([8, 128, 64]), IRTensor([8, 256, 64]), IRTensor([8, 256, 32]), None, 0.05) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a e d^, a b^ d^, a b^ c -> a e c' + op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([128, 256]), 0.05) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, f c^ -> a b f d' + op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([1, 128, 256]), 0.05) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, 1 f c^ -> a b f d' + op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([1, 8, 128, 256]), 0.05) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, 1 b f c^ -> a b f d' + op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([1, 1, 256]), 0.05) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, 1 1 c^ -> a b f d' + op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([1, 8, 128, 1]), 0.05) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, 1 b f 1 -> a b f d' + def test_NewTensor(): @@ -382,4 +454,29 @@ def test_Log(): op = F.Log(IRTensor([1,2,3])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c -> a b c' op = F.Log(IRObject(value=6, is_constant=False), signature='math.log') - assert op.outputs()[0].value == math.log(6) and not op.outputs()[0].is_constant \ No newline at end of file + assert op.outputs()[0].value == math.log(6) and not op.outputs()[0].is_constant + + +def test_Arange(): + op = F.Arange(10) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 10^' and op.kwargs['dtype'] == torch.int64 + op = F.Arange(1, 10, 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 5^' and op.kwargs['dtype'] == torch.int64 + op = F.Arange(1, 10, 2, dtype=torch.float) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 5^' and op.kwargs['dtype'] == torch.float + op = F.Arange(1.0, 10.0, 2.0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 5^' and op.kwargs['dtype'] == torch.float + op = F.Arange(IRObject(value=10)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 10^' and op.kwargs['dtype'] == torch.int64 + op = F.Arange(IRObject(value=1), IRObject(value=10.0), 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 5^' and op.kwargs['dtype'] == torch.float + +def test_Flatten(): + op = F.Flatten(IRTensor([2,3,4,5]), 1, 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c^ d -> a (b^ c^) d' + op = F.Flatten(IRTensor([2,3,4,5]), 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c^ d^ -> a (b^ c^ d^)' + op = F.Flatten(IRTensor([2,3,4,5]), end_dim = 2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ d -> (a^ b^ c^) d' + + From 0581c1cc182561b1effa521b158754d8ca792e11 Mon Sep 17 00:00:00 2001 From: Deming Chu Date: Thu, 11 Apr 2024 07:00:14 +0000 Subject: [PATCH 1616/1892] Merged PR 2099: fix: incorrect ValueMap.overlap() checking In /cube/ir/tensor.py line 187, this should be other.weight ``` idx1, nchunk1 = self.weight idx2, nchunk2 = self.weight ``` --- cube/ir/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index b4bb6819..b0d86e36 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -184,7 +184,7 @@ def overlap(self, other) -> bool: if not isinstance(other, ValueMap): raise TypeError("Expected ValueMap") idx1, nchunk1 = self.weight - idx2, nchunk2 = self.weight + idx2, nchunk2 = other.weight span1 = (idx1 * nchunk2, idx1 * nchunk2 + nchunk2) span2 = (idx2 * nchunk1, idx2 * nchunk1 + nchunk1) if max(span1[0], span2[0]) < min(span1[1], span2[1]): From a72708a753c828c35a8d709b99a0fb2156fbf29e Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 11 Apr 2024 07:31:55 +0000 Subject: [PATCH 1617/1892] Merged PR 2106: bug fix expand for dynamic shape bug fix expand for dynamic shape --- cube/graph/function/function.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index b05c92df..248c19ac 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -469,12 +469,23 @@ def Expand(input, size, *arg_size, signature = None): edim_ou[exp_len - ori_len + idx] = edim_in[idx] else: edim_ou[exp_len - ori_len + idx] = str(new_dim) - # explicit set tid to -1 to avoid changing IDGenerator state. - new_size[exp_len - ori_len + idx] = IRObject(tid=-1, value=expand_dim) if expand_dim_is_ir else new_dim + new_size[exp_len - ori_len + idx] = new_dim for idx in range(exp_len - ori_len): edim_ou[idx] = str(size[idx]) new_size[idx] = size[idx] anno = OpAnno.create_op_str([edim_in], [edim_ou]) + + # fix the size parameter with IRObject + if isinstance(complete_size, IRObject): + new_size = complete_size # use the original IRObject size + else: + assert isinstance(complete_size, (tuple, list)) + assert len(complete_size) == len(new_size) + for idx in range(len(new_size)): + if isinstance(complete_size[idx], IRObject): + # replace with IRObject version + new_size[idx] = complete_size[idx] + return IRDimops(Expand, 'expand', signature, [anno], [input], size=new_size) @@ -1646,8 +1657,8 @@ def obj_helper(obj): # If there are more than one tensors or lists in slicers and their date type is not bool, they will broadcast to each other, # and the output shape will be infered by the shapes of all tensors and lists in slicers, will use '?' in edim_ou - _single_int_tensor = len([slicer for slicer in slicers if - (isinstance(slicer, IRTensor) and slicer.dtype is not bool ) + _single_int_tensor = len([slicer for slicer in slicers if + (isinstance(slicer, IRTensor) and slicer.dtype is not bool ) or (isinstance(slicer, list) and slicer[0] is not bool)]) <= 1 output_shape_unkonwn = False slicers = list(slicers) @@ -1720,7 +1731,7 @@ def list_shape(lst): if len(edim_ou) == 0: # special case for scalar = torch.Tensor([1,2,3])[0] edim_ou.append('1') - + edim_in = [edim_in] edim_in.extend(edim_in_additional) anno = OpAnno.create_op_str(edim_in, [edim_ou]) From 4d1140ecb2f448a0be852fd7eba9c82e0b984b12 Mon Sep 17 00:00:00 2001 From: "Xin Ji (CSI Interfusion Co Ltd)" Date: Mon, 15 Apr 2024 03:31:08 +0000 Subject: [PATCH 1618/1892] Merged PR 2075: Bug fixes and support for more operations and testing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix and support operators: Find unknown pytorch operation: torch.outer Find unknown pytorch operation: torch.erf Find unknown pytorch operation: torch.addmm Find unknown pytorch operation: torch.Tensor.type Find unknown pytorch operation: torch.conv1d Find unknown pytorch operation: torch.multiply Find unknown pytorch operation: torch.softmax Find unknown pytorch operation: torch.zeros_like Find unknown pytorch operation: torch.ones_like --- cube/graph/function/function.py | 170 ++++++++++++++++++++++++- cube/graph/parser/fx/mapping.py | 10 ++ tests/graph/function/test_functions.py | 141 +++++++++++++++++++- 3 files changed, 316 insertions(+), 5 deletions(-) diff --git a/cube/graph/function/function.py b/cube/graph/function/function.py index 248c19ac..f3c84e16 100644 --- a/cube/graph/function/function.py +++ b/cube/graph/function/function.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Tuple, Dict, Union, Iterable +from typing import Any, Callable, List, Optional, Tuple, Dict, Union, Iterable import string import copy import torch @@ -2346,8 +2346,174 @@ def FullLike(input, fill_value, *, dtype=None, layout=None, """ torch.full_like(input, fill_value, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor """ - assert layout in (None, torch.strided) and memory_format is None, f"Not support for non-default memory_format and layout" + if not (layout in (None, torch.strided) and memory_format is None): + raise ValueError("Not support for non-default memory_format and layout") kwargs = {'fill_value': fill_value, 'requires_grad': requires_grad,'dtype': dtype} shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] return IRDimops(FullLike, 'full_like', signature, annos,[input],**kwargs) + + +def ZerosLike(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None, signature=None): + """ + torch.zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor + """ + if not (layout in (None, torch.strided) and memory_format is None): + raise ValueError("Not support for non-default memory_format and layout") + dtype = dtype if dtype is not None else torch.get_default_dtype() + if not isinstance(dtype, torch.dtype): + raise TypeError("only supports torch.dtype but got {}".format(dtype)) + kwargs = {'requires_grad': requires_grad, 'dtype': dtype} + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(ZerosLike, 'zeros_like', signature, annos,[input],**kwargs) + + +def OnesLike(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None, signature=None): + """ + torch.ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor + """ + if not (layout in (None, torch.strided) and memory_format is None): + raise ValueError("Not support for non-default memory_format and layout") + dtype = dtype if dtype is not None else torch.get_default_dtype() + if not isinstance(dtype, torch.dtype): + raise TypeError("only supports torch.dtype but got {}".format(dtype)) + kwargs = {'requires_grad': requires_grad, 'dtype': dtype} + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(OnesLike, 'onesLike', signature, annos,[input],**kwargs) + + +def Addmm(input: IRTensor, mat1: IRTensor, mat2: IRTensor, *, beta=1, alpha=1, out=None, signature = None): + """ + torch.addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) → Tensor + """ + if out is not None: + raise ValueError("Expected 'out' to be None") + if len(mat1.shape) != 2 or len(mat2.shape) != 2: + raise ValueError("mat1 and mat2 must both be 2-dimensional.") + if mat1.shape[-1] != mat2.shape[-2]: + raise ValueError("Shapes of mat1 and mat2 are incompatible for matrix multiplication.") + matmul_result_shape = (mat1.shape[0], mat2.shape[1]) + if len(input.shape) < 2: + matmul_result = IRTensor(shape=matmul_result_shape) + lshape, rshape, oshape = _handle_broadcast(input, matmul_result) + anno = f"{' '.join(lshape)}, {rshape[0]} k^, k^ {rshape[1]} -> {' '.join(oshape)}" + elif len(input.shape) == 2: + if (input.shape[0] != 1 and input.shape[0] != matmul_result_shape[0]) or \ + (input.shape[1] != 1 and input.shape[1] != matmul_result_shape[1]): + raise ValueError("`input` shape cannot be broadcasted to match the result of mat1 @ mat2.") + else: + anno = f'{"1" if input.shape[0] == 1 else "a"} {"1" if input.shape[1] == 1 else "b"}, a k^, k^ b -> a b' + else: + raise ValueError("The `input` tensor does not have a compatible shape for this operation.") + return IRDimops(Addmm, 'addmm', signature, [anno], [input, mat1, mat2], beta=beta, alpha=alpha) + + +def Type(tensor: IRTensor, dtype: Optional[Union[str, torch.dtype, IRObject]] = None, non_blocking: bool = False, out=None, signature=None, **kwargs): + """ + Tensor.type(dtype=None, non_blocking=False, **kwargs) → str or Tensor + """ + if out is not None: + raise ValueError("Expected 'out' to be None") + annos = ['* -> *'] + original_dtype = dtype + dtype = _unwrap_value(dtype) + if dtype is None: + return IRPyFunc(signature,[tensor], [IRObject(value=str(tensor.dtype))]) + else: + if isinstance(dtype, str): + return IRDimops(Type, 'type', signature, annos, [tensor], dtype=original_dtype, non_blocking=non_blocking) + elif isinstance(dtype, torch.dtype): + return IRDimops(Type, 'type', signature, annos, [tensor], dtype=original_dtype, non_blocking=non_blocking) + else: + raise RuntimeError(f'function.type with unknown arg: {dtype}') + + +def Outer(input, vec2, *, out=None, signature=None): + """ + torch.outer(input, vec2, *, out=None) → Tensor + """ + if out is not None: + raise ValueError("Expected 'out' to be None") + if not (len(input.shape) == 1 and len(vec2.shape) == 1): + raise ValueError("'input' and 'vec2' must both be 1-D tensors.") + anno = 'n, m -> n m' + return IRDimops(Outer, 'outer', signature, [anno], [input, vec2]) + + +def Erf(input, *, out=None, signature=None): + """ + torch.erf(input, *, out=None) → Tensor + """ + if out is not None: + raise ValueError("Expected 'out' to be None") + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(Erf, 'erf', signature, annos, [input]) + + +def Conv1D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, signature=None): + """ + torch.nn.functional.conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor + """ + if isinstance(stride, int): + stride = (stride,) + if isinstance(dilation, int): + dilation = (dilation,) + if isinstance(padding, str): + if padding == 'same': + # For 'same' padding, calculate padding needed to keep the output shape the same as input shape + # this mode doesn’t support any stride values other than 1. + kW = weight.shape[2] + iW = input.shape[2] + effective_kernel_size = (kW - 1) * dilation[0] + 1 + total_padding = max(0, (iW - 1) * stride[0] + effective_kernel_size - iW) + pad_ = total_padding // 2 + # NOTE: While we calculate padding for both sides, conv1d expects a single integer for symmetrical padding. + padding = (pad_, ) + elif padding == 'valid': + padding = (0, ) + else: + raise ValueError("Unsupported padding value: {}. Use 'valid', 'same', or an integer.".format(padding)) + elif isinstance(padding, int): + padding = (padding,) + elif not isinstance(padding, tuple): + raise ValueError("Padding must be a string ('valid', 'same'), an integer, or a tuple") + + ori_groups = groups + if isinstance(groups, IRObject): groups = groups.value + _, iW = input.shape[1:3] + oC, iC, kW = weight.shape + oW = (iW + 2 * padding[0] - dilation[0] * (kW - 1) - 1) // stride[0] + 1 + if input.shape[1] // groups != weight.shape[1]: + raise ValueError(f'Input shape and weight shape are not compatible for the number of groups. input shape: {input.shape}, weight shape: {weight.shape}, groups: {groups}') + if weight.shape[0] % groups != 0: + raise ValueError('The output channels of weight must be divisible by the number of groups.') + def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: + # only for partitioning groups + kwargs = dict(**kwargs) + kw_groups = kwargs['groups'] + if isinstance(kw_groups, IRObject): + _logger.warning(f'partition groups in IRObject: {kw_groups}') + kw_groups = kw_groups.value + kwargs['groups'] = kw_groups // num + return kwargs + if bias is None: + # NOTE: cannot support partitioning inchannel when groups>1 + if groups == 1: + annos = [f'n iC+ {iW}, oC iC+ {kW} -> n oC {oW}'] + rules = None + else: + rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0)], [DimopSplit.D(1)], modifier)] + annos = [f'n (g {iC}) {iW}, (g {oC//groups}) {iC} {kW} -> n (g {oC//groups}) {oW}'] + else: + # NOTE: not supported value partition of bias yet + if groups == 1: + annos = [f'n iC^ {iW}, oC iC^ {kW}, oC -> n oC {oW}'] + rules = None + else: + rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0), DimopSplit.D(0)], [DimopSplit.D(1)], modifier)] + annos = [f'n (g {iC}) {iW}, (g {oC//groups}) {iC} {kW}, (g {oC//groups}) -> n (g {oC//groups}) {oW}'] + return IRDimops(Conv1D, 'conv1d', signature, annos, [input, weight, bias] if bias is not None else [input, weight], rules, + stride=stride, padding=padding, dilation=dilation, groups=ori_groups) diff --git a/cube/graph/parser/fx/mapping.py b/cube/graph/parser/fx/mapping.py index af388b98..0ac80d11 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/cube/graph/parser/fx/mapping.py @@ -57,6 +57,8 @@ def exist(signature: str) -> bool: __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, __ttemplate('mean') : function.Mean, + __ttemplate('outer'): function.Outer, + __ttemplate('erf'): function.Erf, __ttemplate('abs'): function.Abs, __ttemplate('exp'): function.Exp, 'math.exp': function.Exp, @@ -84,6 +86,7 @@ def exist(signature: str) -> bool: __ttemplate('min'): function.Min, __ttemplate('where'): function.Where, __ttemplate('nan_to_num') : function.NanToNum, + __tttemplate('type'): function.Type, __tttemplate('long'): function.Long, __tttemplate('int'): function.Int, __tttemplate('float'): function.Float, @@ -94,6 +97,7 @@ def exist(signature: str) -> bool: __ttemplate('cumsum'): function.CumSum, __ttemplate('tanh'): function.Tanh, __ftemplate('softmax') : function.Softmax, + __ttemplate('softmax'): function.Softmax, __ftemplate('log_softmax') : function.LogSoftmax, __ttemplate('bmm') : function.BatchLinear, __ttemplate('pow'): function.Pow, @@ -145,7 +149,9 @@ def exist(signature: str) -> bool: # # creators __ttemplate('empty'): function.Empty, __ttemplate('zeros'): function.Zeros, + __ttemplate('zeros_like'): function.ZerosLike, __ttemplate('ones'): function.Ones, + __ttemplate('ones_like'): function.OnesLike, __ttemplate('tensor'): function.NewTensor, __ttemplate('full'): function.Full, __ttemplate('full_like'): function.FullLike, @@ -156,12 +162,14 @@ def exist(signature: str) -> bool: '_operator.is_not': function.IsNot, __ttemplate('add') : function.Add, '_operator.add': function.Add, + __ttemplate('addmm'): function.Addmm, '_operator.iadd': function.Add, # FIXME: may waste memory __ttemplate('sub') : function.Sub, '_operator.sub': function.Sub, __ttemplate('mul') : function.Mul, '_operator.mul': function.Mul, '_operator.imul': function.Mul, # FIXME: may waste memory + __ttemplate('multiply') : function.Mul, '_operator.mod': function.Mod, __ttemplate('div') : function.Div, @@ -190,6 +198,8 @@ def exist(signature: str) -> bool: __tttemplate('contiguous'): function.Contiguous, __ttemplate('reshape'): function.Reshape, + + __ttemplate('conv1d'): function.Conv1D, # # __ttemplate('conv2d'): function.Conv2D, # diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 304aa35c..db17afb6 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -418,6 +418,8 @@ def test_Min(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a c, a c' op = F.Min(IRTensor([2, 3, 4]), IRTensor([4])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c, c -> a b c' + op = F.Min(IRTensor([4]), IRTensor([2, 3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'c, a b c -> a b c' op = F.Min(IRTensor([2, 3, 4]), IRObject(value=2, is_constant=False), True) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 1, a b 1' op = F.Min(IRTensor([2, 3, 4]), 2,IRObject(value=True, is_constant=False)) @@ -457,6 +459,140 @@ def test_Log(): assert op.outputs()[0].value == math.log(6) and not op.outputs()[0].is_constant +def test_ZerosLike(): + op = F.ZerosLike(IRTensor([2, 1, 4, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a b c d' + op_true = F.ZerosLike(IRTensor([2, 2]), requires_grad=True) + assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' + op_float = F.ZerosLike(IRTensor([1, 2],dtype=int), dtype=torch.float) + assert len(op_float._annos_candidates) == 1 and op_float._annos_candidates[0] == 'a b -> a b' + + +def test_OnesLike(): + op = F.OnesLike(IRTensor([2, 1, 4, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a b c d' + op_true = F.OnesLike(IRTensor([2, 2]), requires_grad=True) + assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' + op_float = F.OnesLike(IRTensor([1, 2],dtype=int), dtype=torch.float) + assert len(op_float._annos_candidates) == 1 and op_float._annos_candidates[0] == 'a b -> a b' + + +def test_addmm(): + op = F.Addmm(IRTensor([2, 3]), mat1=IRTensor([2, 7]), mat2=IRTensor([7, 3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, a k^, k^ b -> a b' + op = F.Addmm(IRTensor([1, 3]), mat1=IRTensor([2, 7]), mat2=IRTensor([7, 3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '1 b, a k^, k^ b -> a b' + op = F.Addmm(IRTensor([2, 1]), mat1=IRTensor([2, 7]), mat2=IRTensor([7, 3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 1, a k^, k^ b -> a b' + op = F.Addmm(IRTensor([1, 1]), mat1=IRTensor([2, 7]), mat2=IRTensor([7, 3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '1 1, a k^, k^ b -> a b' + op = F.Addmm(IRTensor([3]), mat1=IRTensor([2, 3]), mat2=IRTensor([3, 3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b, a k^, k^ b -> a b' + op = F.Addmm(IRTensor([7]), mat1=IRTensor([2, 3]), mat2=IRTensor([3, 7])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'b, a k^, k^ b -> a b' + + +def test_type(): + op = F.Type(IRTensor([2,3],dtype=None),torch.float32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype'] == torch.float32 + op = F.Type(IRTensor([3, 5], dtype=torch.int64),torch.float32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype'] == torch.float32 + op = F.Type(IRTensor([3, 5], dtype=torch.int64),IRObject(value=torch.float32, is_constant=False), signature='torch.type') + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype'].value == torch.float32 + op = F.Type(IRTensor([3, 5], dtype=torch.int64),dtype=IRObject(value=None, is_constant=False), signature='torch.type') + assert op.outputs()[0].value == "torch.int64" + op = F.Type(IRTensor([3, 5], dtype=torch.int64),dtype=torch.int64) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype'] == torch.int64 + op = F.Type(IRTensor([3, 5], dtype=torch.int64),"torch.int64") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype'] == 'torch.int64' + op = F.Type(IRTensor([3, 5], dtype=torch.int64), signature='torch.type') + assert op.outputs()[0].value == "torch.int64" + + +def test_to(): + op = F.To(IRTensor([2, 3], dtype=None), dtype=torch.float32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype_or_device'] == torch.float32 + op = F.To(IRTensor([2, 3], dtype=torch.float32), device=torch.device('cuda:0')) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b' + op = F.To(IRTensor([3, 5], dtype=torch.int64), dtype=torch.float32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype_or_device'] == torch.float32 + op = F.To(IRTensor([2, 3], dtype=torch.float32), device=torch.device('cuda:0')) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b' + op = F.To(IRTensor([3, 5], dtype=torch.int64), dtype=IRTensor(dtype=torch.float32)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' + op = F.To(IRTensor([2, 3], dtype=torch.float32), device=torch.device('cuda:0'), dtype=torch.float32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype_or_device'] == torch.float32 + op = F.To(IRTensor([3, 5], dtype=torch.int64), dtype_or_device=IRTensor(dtype=torch.float32)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' + + + +def test_outer(): + op = F.Outer(IRTensor([2]), vec2=IRTensor([2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n, m -> n m' + + +def test_erf(): + op = F.Erf(IRTensor([2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a -> a' + op = F.Erf(IRTensor([2,3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b' + op = F.Erf(IRTensor([2,3,4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c -> a b c' + op = F.Erf(IRTensor([2,3,4,5])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a b c d' + + +def test_mul_or_multiply(): + op = F.Mul(IRTensor([1,2]),100) + assert len(op._annos_candidates) == 2 and op._annos_candidates[0] == '*, ? -> *' + op = F.Mul(IRTensor([1,2]),IRTensor([1,2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, a b -> a b' + op = F.Mul(IRTensor([2,2]),IRTensor([2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, b -> a b' + op = F.Mul(100,IRTensor([6])) + assert len(op._annos_candidates) == 2 and op._annos_candidates[1] == '?, * -> *' + op = F.Mul(torch.tensor([[1, 2], [1, 2]]),100) + assert torch.equal(torch.mul(torch.tensor([[1, 2], [1, 2]]), 100), op), "The result does not match the expected output" + op = F.Mul(torch.tensor([[1, 2], [1, 2]]),torch.tensor([[1, 2]])) + assert torch.equal(torch.mul(torch.tensor([[1, 2], [1, 2]]), torch.tensor([[1, 2]])), op), "The result does not match the expected output" + op = F.Mul(torch.tensor([1, 2]),IRObject(value=100, is_constant=False), signature='torch.mul') + assert torch.equal(op.outputs()[0].value, torch.mul(torch.tensor([1, 2]), 100)) and not op.outputs()[0].is_constant and torch.equal(op.outputs()[0].value, torch.tensor([100, 200])) + + +def test_Softmax(): + op = F.Softmax(IRTensor([2, 3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c -> a b c' + op = F.Softmax(IRTensor([2, 3, 4]), dim=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a b^ c' + op = F.Softmax(IRTensor([2, 3, 4]), dim=2, dtype=torch.float64) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b c^' + op = F.Softmax(IRTensor([2, 3, 4]), dtype=torch.float32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c -> a b c' + + +def test_Conv1D(): + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 4' + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), stride=2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 2' + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), padding=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 6' + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), dilation=2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 4' + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), groups=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 4' + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), groups=1,padding="valid") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 4' + op = F.Conv1D(input=IRTensor([4, 8, 32]), weight=IRTensor([16, 8, 3]), bias=IRTensor([16,]),groups=1,padding="same") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC^ 32, oC iC^ 3, oC -> n oC 32' + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 1, 1]), groups=3) + expected_annotation_for_groups = 'n (g 1) 4, (g 1) 1 1 -> n (g 1) 4' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation_for_groups + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), bias=IRTensor([3])) + assert op._annos_candidates[0] == 'n iC^ 4, oC iC^ 1, oC -> n oC 4', "Annotation mismatch." + + def test_Arange(): op = F.Arange(10) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 10^' and op.kwargs['dtype'] == torch.int64 @@ -471,12 +607,11 @@ def test_Arange(): op = F.Arange(IRObject(value=1), IRObject(value=10.0), 2) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 5^' and op.kwargs['dtype'] == torch.float + def test_Flatten(): op = F.Flatten(IRTensor([2,3,4,5]), 1, 2) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c^ d -> a (b^ c^) d' op = F.Flatten(IRTensor([2,3,4,5]), 1) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c^ d^ -> a (b^ c^ d^)' op = F.Flatten(IRTensor([2,3,4,5]), end_dim = 2) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ d -> (a^ b^ c^) d' - - + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ d -> (a^ b^ c^) d' \ No newline at end of file From 8ffd5e13be991bf81c83ae5dee810acf12a119d1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 17 Apr 2024 04:30:33 +0000 Subject: [PATCH 1619/1892] Merged PR 2103: parallel module: add pipeline (end2end) support 1. end2end module support 2. inference_only flag added. 3. pipeline related config added. --- cube/codegen/module/module.py | 24 +- cube/graph/schedule/predefined.py | 20 +- cube/parallel.py | 340 ++++++++++++----- cube/runtime/module.py | 146 ++++++- docs/parallel_module.md | 186 +++++++-- examples/mlp/train.py | 2 +- tests/parallel_module/common.py | 107 +++++- tests/parallel_module/test_broadcast.py | 2 +- tests/parallel_module/test_checkpoint.py | 207 ++++++++-- .../parallel_module/test_checkpoint_dedup.py | 59 ++- .../parallel_module/test_checkpoint_shared.py | 57 +-- tests/parallel_module/test_end2end.py | 356 ++++++++++++++++++ tests/parallel_module/test_gencode.py | 107 ++++++ tests/parallel_module/test_inference.py | 14 +- tests/parallel_module/test_override.py | 27 +- tests/parallel_module/test_submodule.py | 4 +- tests/parallel_module/test_wholemodule.py | 2 +- 17 files changed, 1432 insertions(+), 228 deletions(-) create mode 100644 tests/parallel_module/test_end2end.py diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index 0b79a1db..eab29896 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -297,6 +297,7 @@ def gen( attach: bool = False, *, as_parallel_module: bool = False, + end2end_mode: bool = False, forward_args: Optional[Dict[str, Any]] = None ) -> str: """ @@ -380,6 +381,10 @@ def forward(self, x, y=None, z=None): 1. Inherit from ParallelModule 2. Has forward method 3. Add more content to constructor + end2end_mode (bool): whether to generate code for end2end mode. + If True, a mocked `forward` will be generated which only raises NotImplementedError. + If False, the real forward function will be generated. + This is used only in parallel module. forward_args (Dict[str, Any]): argument names and their default values of forward function, if None, use node inputs. This is used only in parallel module. @@ -490,15 +495,20 @@ def forward(self, x, y=None, z=None): cb.insert_body(fb.code) if as_parallel_module: - if len(segment_idxs) > 1: - raise RuntimeError("The graph has more than one segment, forward code cannot be generated.") - elif not segment_idxs: + if not segment_idxs: raise RuntimeError("The graph has no segment, forward code cannot be generated.") - segment_idx = segment_idxs[0] - node = gen_nodes[segment_idx] - cb.insert_body('') - cb.insert_body(self._generate_forward(node, forward_args)) + if not end2end_mode: + if len(segment_idxs) > 1: + raise RuntimeError("The graph has more than one segment, forward code cannot be generated.") + segment_idx = segment_idxs[0] + node = gen_nodes[segment_idx] + cb.insert_body(self._generate_forward(node, forward_args)) + else: + msg = "Code of forward is not generated. You should use module.train_step/module.infer_step instead." + with FunctionBlock(func_name='_forward_impl', args=['self', '*args', '**kwargs']) as fb: + fb.insert_body(f'raise NotImplementedError("{msg}")') + cb.insert_body(fb.code) gencode += cb.code gencode += [''] diff --git a/cube/graph/schedule/predefined.py b/cube/graph/schedule/predefined.py index ebb81b18..b3436d6b 100644 --- a/cube/graph/schedule/predefined.py +++ b/cube/graph/schedule/predefined.py @@ -75,7 +75,7 @@ def sched_1f1b_plus(graph: IRGraph, num_microbatches: int, num_stages: int) -> S wait_steps = [sid for sid in range(num_stages)] bw_ofst = [num_stages - 1 - sid for sid in range(num_stages)] total_steps = num_microbatches * 2 + (num_stages - 1) * 2 - + # 1f1b schedule for step in range(total_steps): for sid in range(num_stages): @@ -88,7 +88,7 @@ def sched_1f1b_plus(graph: IRGraph, num_microbatches: int, num_stages: int) -> S # append for execution if mb_idx < 0 or mb_idx >= num_microbatches: continue sched.add_segment(segment, mb_idx, step) - + # insert for mid in range(num_microbatches): for tp_seg in tp_fsegs: @@ -104,7 +104,7 @@ def sched_1f1b_plus(graph: IRGraph, num_microbatches: int, num_stages: int) -> S # insert forward if next_seg in segments: sched.insert_step(step, tp_seg, mid, 1) - assert not insert_fw + assert not insert_fw insert_fw = True # insert backward if next_seg.mirror in segments: @@ -116,7 +116,7 @@ def sched_1f1b_plus(graph: IRGraph, num_microbatches: int, num_stages: int) -> S assert insert_fw and insert_bw, ( f'find one segment cannot be inserted in schedplan: ', f'mid: {mid}, fw: {insert_fw}, bs: {insert_bw}') - + sched.finish() # print(sched) return sched @@ -131,7 +131,7 @@ def sched_gpipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> Sched f0 f1 f2 f3 b0 b1 b2 b3 f0 f1 f2 f3 b0 b1 b2 b3 f0 f1 f2 f3 b0 b1 b2 b3 - f0 f1 f2 f3 b0 b1 b2 b3 + f0 f1 f2 f3 b0 b1 b2 b3 ``` """ if num_microbatches <= 0: @@ -171,14 +171,14 @@ def sched_chimera_direct(graph: IRGraph, num_microbatches: int, num_stages: int) 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 (-> steps) ``` - + Note the f0 and f2 (step 0) should be considered to be one segment in graph. """ if num_microbatches <= 0: raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) fsegs = [seg for seg in segments if seg.isfw()] - assert len(fsegs) == 4, f"Chimera-direct scheduling only applies for 4 segments, but {len(segments)} detected" + assert len(fsegs) == 4, f"Chimera-direct scheduling only applies for 4 segments, but {len(fsegs)} detected" sched = SchedulePlan(graph, num_microbatches) assert num_microbatches % 2 == 0 mid = 0 @@ -215,9 +215,9 @@ def sched_infer_pipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> An illustration of scheduling schema (the number is micro-batch index): ``` - f0 f1 f2 f3 - f0 f1 f2 f3 - f0 f1 f2 f3 + f0 f1 f2 f3 + f0 f1 f2 f3 + f0 f1 f2 f3 f0 f1 f2 f3 ``` """ diff --git a/cube/parallel.py b/cube/parallel.py index e334a963..8f8ef9bc 100644 --- a/cube/parallel.py +++ b/cube/parallel.py @@ -6,16 +6,18 @@ import inspect import sys import importlib -from dataclasses import dataclass, asdict +from dataclasses import dataclass, asdict, field from contextlib import contextmanager import logging import copy import os import torch +from cube.codegen.schedule.schedule import ScheduleCodeGen from cube.graph.parser.fx.parser import FxModuleParser -from cube.ir.cten import IRObject +from cube.graph.schedule.predefined import PredefinedSched +from cube.ir.cten import IRObject, IRTensor from cube.ir.tensor import IRFullTensor from cube.flags import CompileFlag, RuntimeFlag @@ -23,7 +25,7 @@ from cube.graph import IRGraph from cube.graph import parser -from cube.ir.operator import IRBpOperation +from cube.ir.operator import IRBpOperation, IRDataOperation from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.pyfunc import IRPyFunc from cube.graph.schedule.schedplan import SchedulePlan @@ -44,20 +46,17 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class ComputeConfig: - plan_ngpus: int - runtime_ngpus: int - - # whether to use dynamic shape to generate code - dynamic_shape: bool = True +_PREDEFINE_SCHEDS: Dict[str, Callable[[IRGraph, int, int], SchedulePlan]] = {} +_PREDEFINED_INFERENCE_SCHEDS = ['infer_pipe'] +_PREDEFINE_SCHED_NAME_PREFIX = 'sched_' +for k, v in PredefinedSched.__dict__.items(): + if isinstance(v, staticmethod) and k.startswith(_PREDEFINE_SCHED_NAME_PREFIX): + _PREDEFINE_SCHEDS[k[len(_PREDEFINE_SCHED_NAME_PREFIX):]] = v - use_zero: bool = False - zero_ngroups: int = 1 - - # you can put any configuration here - # Note: different user_config should generate different graph/code. - # so if user_config is changed, both graph and code will be regenerated. +@dataclass +class UserConfig: + # you should put any configuration that may affect the traced graph here. + # So we can track the changes and make sure the generated code is correct. # Example 1: save module configuration # ```python # class MyModule(torch.nn.Module): @@ -68,30 +67,59 @@ class ComputeConfig: # if module_config.use_3d: # ... # ``` - # here we can set `user_config={'use_3d': module_config.use_3d}`, + # here we can set `graph={'use_3d': module_config.use_3d}`, # and we can be sure different use_3d will never use the same generated code. # Example 2: save file stats # If you want to track all related file stats (just like traditional compilers do), - # you can do - # ```python - # user_config = { - # 'file_stats': { - # str(f): os.stat(f).st_mtime_ns for f in Path('./src').glob('**/*.py') # assume all source code is in ./src - # } - # } - # ``` - # Or you can save the md5 of the files to save some bytes: + # you can save the md5 of the files to save some bytes: # ```python # import hashlib # h = hashlib.md5() # for f in Path('./src').glob('**/*.py'): # with open(f, 'rb') as f: # h.update(f.read()) - # user_config = { + # graph = { # 'files_md5': h.hexdigest() # } # ``` - user_config: Optional[Dict[str, Any]] = None + graph: Dict[str, Any] = field(default_factory=dict) + # you can put any configuration that may affect the generated code (but not affect the traced graph) here. + # For example, extra arguments of your PAS function can put here. + code: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ComputeConfig: + plan_ngpus: int + runtime_ngpus: int + + # whether to use dynamic shape to generate code + dynamic_shape: bool = True + + use_zero: bool = False + zero_ngroups: int = 1 + + # whether the generated code is for inference only + inference_only: bool = False + + # end2end means, + # 1. the first argument of `module.forward` must be the data sample + # 2. the first return value of `module.forward` must be the loss + # which must be a scalar tensor + use_end2end: bool = False + + # current only end2end module supports in pipeline mode. + # so be sure to set use_end2end=True when use_pipeline=True + use_pipeline: bool = False + # number of micro-batches + pipeline_nmicros: int = -1 + # number of stages + pipeline_nstages: int = -1 + # it is pas's responsibility to apply the scheduler + pipeline_scheduler: str = '1f1b' + # the customized configs from user that can affect the graph and code generation. + # for example, module configuration or PAS policy settings. + user_config: UserConfig = field(default_factory=UserConfig) def __post_init__(self): if self.plan_ngpus <= 0: @@ -101,10 +129,40 @@ def __post_init__(self): if self.runtime_ngpus % self.plan_ngpus != 0: raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be a multiple of plan_ngpus {self.plan_ngpus}") if self.use_zero and self.zero_ngroups <= 0: - raise ValueError(f"zero_ngroups {self.zero_ngroups} must be > 0") + raise ValueError(f"zero_ngroups {self.zero_ngroups} must be > 0") if not self.use_zero and self.zero_ngroups != 1: logger.warning(f"use_zero is False, but zero_ngroups is {self.zero_ngroups}. Will set zero_ngroups to 1.") - self.zero_ngroups = 1 + # have to use __setattr__ for frozen dataclass + super().__setattr__('zero_ngroups', 1) + + if self.use_pipeline: + if not self.use_end2end: + raise ValueError("pipeline is only supported in end2end mode") + if self.pipeline_nmicros <= 0: + raise ValueError(f"pipeline_nmicros {self.pipeline_nmicros} must be > 0 when use pipeline") + if self.pipeline_nstages <= 0: + raise ValueError(f"pipeline_nstages {self.pipeline_nstages} must be > 0 when use pipeline") + if self.plan_ngpus % self.pipeline_nstages != 0: + raise ValueError(f"pipeline_nstages {self.plan_ngpus} must be a multiple of plan_ngpus {self.pipeline_nstages}") + if self.pipeline_scheduler not in _PREDEFINE_SCHEDS: + raise ValueError(f"pipeline_scheduler {self.pipeline_scheduler} is not supported. " + f"Supported schedulers are {_PREDEFINE_SCHEDS.keys()}") + if self.inference_only and self.pipeline_scheduler not in _PREDEFINED_INFERENCE_SCHEDS: + raise ValueError(f"pipeline_scheduler {self.pipeline_scheduler} is not supported in inference mode. " + f"Supported schedulers are {_PREDEFINED_INFERENCE_SCHEDS}") + if not self.inference_only and self.pipeline_scheduler in _PREDEFINED_INFERENCE_SCHEDS: + raise ValueError(f"pipeline_scheduler {self.pipeline_scheduler} is not supported in training mode.") + + if isinstance(self.user_config, dict): + super().__setattr__('user_config', UserConfig(**self.user_config)) + + def apply_pipeline_scheduler(self, graph: IRGraph) -> Optional[SchedulePlan]: + """ + Do nothing if not use_pipeline + """ + if self.use_pipeline: + sched = _PREDEFINE_SCHEDS[self.pipeline_scheduler] + return sched(graph, self.pipeline_nmicros, self.pipeline_nstages) @property def gpu_config(self) -> Dict[str, int]: @@ -115,9 +173,12 @@ def gpu_config(self) -> Dict[str, int]: @property def graph_config(self) -> Dict[str, Any]: - return { + return { 'dynamic_shape': self.dynamic_shape, - 'user_config': self.user_config, + 'graph_user_config': self.user_config.graph, + 'inference_only': self.inference_only, # there will be no backward nodes in the graph in inference mode + 'use_pipeline': self.use_pipeline, # pipeline option can affect the graph generation. + 'end2end_mode': self.use_end2end, # end2end_mode can affect the graph generation. } @property @@ -168,20 +229,59 @@ def _runtime_flags(**kwargs): return _flags(RuntimeFlag, **kwargs) -def _complex(val: Any): +def _to_cpu(val: Any): """Complex to CPU""" if isinstance(val, tuple): - return tuple(_complex(t) for t in val) + return tuple(_to_cpu(t) for t in val) if isinstance(val, list): - return list(_complex(t) for t in val) + return list(_to_cpu(t) for t in val) if isinstance(val, dict): - return {_complex(key):_complex(val) for key, val in val.items()} + return {_to_cpu(key):_to_cpu(val) for key, val in val.items()} if isinstance(val, set): - return {_complex(t) for t in val} + return {_to_cpu(t) for t in val} if isinstance(val, torch.Tensor): return val.cpu() return val +def to_ir_input(sample, name): + """Support complex of types: Tuple, List, Dict, torch.Tensor""" + if isinstance(sample, tuple): + return tuple(to_ir_input(t, name) for t in sample) + if isinstance(sample, list): + return list(to_ir_input(t, name) for t in sample) + if isinstance(sample, dict): + return {k: to_ir_input(v, str(k)) for k, v in sample.items()} + if isinstance(sample, torch.Tensor): + # note: we will always set tensor to require gradient, which may + # generate backward communications in adapter. However, as long as + # the data doesn't require gradient in real runtime, the backward + # communication will not be triggered. + tensor = IRFullTensor( + shape=sample.size(), + name=name, + requires_grad=True, + dtype=sample.dtype + ).tosub() + tensor._value = sample + tensor.grad = tensor.parent.grad.tosub() + return tensor + return IRObject(name, value=sample, is_constant=False) + + +def _contains_uncommutable_data(ir_outputs: Any): + """ + only IRObject (but not IRTensor) is not commutable between gpus. + """ + if isinstance(ir_outputs, (tuple, list)): + return any(_contains_uncommutable_data(t) for t in ir_outputs) + elif isinstance(ir_outputs, dict): + return any(_contains_uncommutable_data(k) or _contains_uncommutable_data(v) for k, v in ir_outputs.items()) + elif isinstance(ir_outputs, IRTensor): + return False + elif isinstance(ir_outputs, IRObject): + return True + return False + def _get_full_qualified_name(obj: Any) -> str: """Get full qualified name of an object""" @@ -221,11 +321,6 @@ def _clean_files(_dir: Path, pattern = '*') -> None: f.unlink() -def _int_dict_to_list(d: Dict[int, Any]) -> List[Any]: - """Convert a dict with int keys to a list""" - return [d[i] for i in range(len(d))] - - _DEFAULT_INSTANCE_NAME = '_' _GENCODE_FILE_PREFIX = 'gencode' _GENCODE_FILE_TEMPLATE = _GENCODE_FILE_PREFIX + '{}.py' # 'gencode{}.py' @@ -273,7 +368,7 @@ class BroadcastGenFilesStrategy(Enum): # (local world size should be bigger than plan_ngpus) NO_WEIGHTS = 'no_weights' - # broadcast the new generated code only (gencode*.py) + # broadcast the new generated code (gencode*.py) and compute_config.pt only. # It's your responsibility to make sure other necessary files are available on all nodes. CODE = 'code' @@ -331,24 +426,30 @@ def _prepare_and_check_reusable( # decision matrix for code generation # reuse flag | dir condition(imported, empty, match, unmatched) | action # --------------------------------------------------------- - # OVERRIDE/GRAPH | empty | generate - # OVERRIDE/GRAPH | imported | raise error - # OVERRIDE/GRAPH | match | generate - # OVERRIDE/GRAPH | unmatch | generate - # MATCH | empty | generate - # MATCH | match | reuse(do nothing) - # MATCH* | unmatch | raise error (except when there's no python source code, see below) - # MATCH | imported | doesn't matter - # MOO | empty | generate - # MOO | match | reuse(do nothing) - # MOO* | unmatch | generate (specail case is when there's no python source code, see below) - # MOO | imported | raise error if unmatch + # OVERRIDE | empty | generate + # OVERRIDE | imported | raise error + # OVERRIDE | whatever match | generate + # OVERRIDE | unmatch | generate + # GRAPH | empty | generate + # GRAPH | imported | raise error + # GRAPH | graph match | reuse graph, and regenerate code + # GRAPH | all match | reuse graph, and regenerate code + # GRAPH | unmatch | generate + # MATCH | empty | generate + # MATCH | match | reuse(do nothing) + # MATCH* | whatever unmatch| raise error (except when there's no python source code, see below) + # MATCH | imported | doesn't matter + # MOO | empty | generate + # MOO | match | reuse(do nothing) + # MOO | match graph | reuse graph, and regenerate code + # MOO | imported | raise error if whatever unmatch # *: The precondition for `except` part is the compute config should match. # you can take it as a continous operation after a failed generation. reusable = False config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE - old_config = torch.load(config_file) if config_file.exists() else None + old_config: ComputeConfig = torch.load(config_file) if config_file.exists() else None is_config_match = old_config == compute_config + is_graph_config_match = old_config is not None and old_config.graph_config == compute_config.graph_config trace_meta_files = [ outdir / FxModuleParser.ATTR_CONTENT_FILE_0, # just check the first is good enough outdir / FxModuleParser.ATTR_MAP_FILE, @@ -395,7 +496,11 @@ def _prepare_and_check_reusable( if _is_any_gencode_loaded(namespace): raise RuntimeError(f'Output directory {outdir} is already loaded. ' f'You can not override a loaded module.') - _clean_files(outdir) + elif is_graph_config_match: + # reuse the graph dump + _clean_files(outdir, '*.py') + else: + _clean_files(outdir) else: # check if the module is already loaded if _is_any_gencode_loaded(namespace): @@ -403,7 +508,7 @@ def _prepare_and_check_reusable( f'You can not override a loaded module.') # clear existing generated files if reuse == ReuseType.OVERRIDE \ - or not is_config_match \ + or not is_graph_config_match \ or not all([meta_file.exists() for meta_file in trace_meta_files]): # we have to trace the graph again if not all meta files are present even when reuse=graph. glob_pattern = '*' @@ -419,6 +524,8 @@ def _gen_graph( dummy_input: dict, outdir: Path, dynamic_shape: bool, + end2end_mode: bool = False, + inference_only: bool = False, ): # reset environment program = Program() @@ -432,7 +539,7 @@ def _gen_graph( raise ValueError(f"Default value type {type(v)} of forward args is not supported.") # generate fx graph - dummy_input = _complex(dummy_input) + dummy_input = _to_cpu(dummy_input) fx_graph = parser.to_fx_graph(module, dummy_input) # generate ir logic graph @@ -460,31 +567,48 @@ def _gen_graph( else: raise ValueError(f"Input {node.target} not in dummy input. Default value is not supported.") for i in range(len(ir_dummy_inputs)): - if isinstance(ir_dummy_inputs[i], torch.Tensor): - # note: we will always set tensor to require gradient, which may - # generate backward communications in adapter. However, as long as - # the data doesn't require gradient in real runtime, the backward - # communication will not be triggered. - ir_dummy_inputs[i] = IRFullTensor( - shape=ir_dummy_inputs[i].size(), - name=fx_input_nodes[i].target, - requires_grad=True, - dtype=ir_dummy_inputs[i].dtype).tosub() - ir_dummy_inputs[i].grad = ir_dummy_inputs[i].parent.grad.tosub() - else: - ir_dummy_inputs[i] = IRObject( - name=fx_input_nodes[i].target, - value=ir_dummy_inputs[i] - ) + ir_dummy_inputs[i] = to_ir_input(ir_dummy_inputs[i], fx_input_nodes[i].target) + # if the input is not a tensor, we should wrap it with IRObject + if not isinstance(ir_dummy_inputs[i], IRObject): + ir_dummy_inputs[i] = IRObject(fx_input_nodes[i].target, value=ir_dummy_inputs[i], is_constant=False) + # generate complete ir graph - ir_dummy_outputs = ir_graph(*ir_dummy_inputs) + if end2end_mode: + # in end2end mode, we must use dataloader as the first argument of forward + # we assume the first argument of forward is the data sample (which is a requirement in our doc) + + # the IRObject representing the `dataloader` instance, which is only used by the + # IRDataOperation. Since we already know the output of the dataloader, + # we don't need to set the value for it. + ir_root_obj = IRObject(name='dataloader', value=None) + Program().set_input([ir_root_obj]) + data_op = IRDataOperation(ir_root_obj, ir_dummy_inputs) + # add the data operation to the graph, which will use `next` to get data. + Program().add_node(data_op) + ir_dummy_outputs = ir_graph(*ir_dummy_inputs) + graph = program.get_graph() + # we require the first output is the loss + if isinstance(ir_dummy_outputs, (list, tuple)): + ir_loss = ir_dummy_outputs[0] + else: + ir_loss = ir_dummy_outputs + if not isinstance(ir_loss, IRTensor) or ir_loss.shape != (1,): + # TODO: update when we support scalar tensor + raise RuntimeError(f"Loss can only be scalar tensor but got {ir_loss.shape if isinstance(ir_loss, IRTensor) else ir_loss}") + if not inference_only: + ir_loss.backward() + else: + program.set_input(ir_dummy_inputs) + ir_dummy_outputs = ir_graph(*ir_dummy_inputs) + graph = program.get_graph() + if not inference_only: + graph.backward() - graph = program.get_graph() - graph.backward() - program.set_input(ir_dummy_inputs) if ir_dummy_outputs is None: ir_dummy_outputs = [] - elif not (isinstance(ir_dummy_outputs, tuple) or isinstance(ir_dummy_outputs, list)): + elif not isinstance(ir_dummy_outputs, (tuple, list)): ir_dummy_outputs = [ir_dummy_outputs] + if _contains_uncommutable_data(ir_dummy_outputs): + raise RuntimeError(f"Communication generation error: some of outputs are not commutable between gpus.") program.set_output(ir_dummy_outputs) program.finalize() @@ -560,7 +684,11 @@ def _gencode( ) torch.save(meta_info, origin_module_metadata_ckp) - graph, forward_args = _gen_graph(module, dummy_input, outdir, compute_config.dynamic_shape) + graph, forward_args = _gen_graph( + module, dummy_input, outdir, + dynamic_shape=compute_config.dynamic_shape, end2end_mode=compute_config.use_end2end, + inference_only=compute_config.inference_only + ) graph.dump(graph_ckp) torch.save(forward_args, forward_args_ckp) @@ -606,12 +734,21 @@ def _gencode( # code generation assert len(execplan.graph.device) == compute_config.plan_ngpus, f"{execplan.graph.device}" mgener = ModuleCodeGen(execplan, compute_config.runtime_ngpus) + sgener = ScheduleCodeGen(execplan, compute_config.runtime_ngpus) for rank in range(compute_config.runtime_ngpus): + fname = outdir / _GENCODE_FILE_TEMPLATE.format(rank) mgener.gen(rank, forward_args=forward_args, - outfile=outdir / _GENCODE_FILE_TEMPLATE.format(rank), + outfile=fname, attach=False, as_parallel_module=True, + end2end_mode=compute_config.use_end2end + ) + # generate temporal schedule code + sgener.gen( + device=rank, + outfile=fname, + attach=True ) return ret @@ -625,7 +762,8 @@ def _load_cube_module_class( rank: Optional[int] = None, ) -> Type[ParallelModule]: """ - Load the generated cube module class. + Load the generated cube module class, with train_step and infer_step assigned as member function.. + Please note that the cube module class should be generated beforehand by _gencode(). @@ -636,6 +774,8 @@ def _load_cube_module_class( rank (Optional[int]): the rank of the module. If it is None, will get the rank from torch.distributed.get_rank(). This option is only useful for debugging or writing pre/post-processing tools. when you need to load the generated module in a non-torchrun environment. + Returns: + Type[ParallelModule]: the generated module class """ rank = torch.distributed.get_rank() if rank is None else rank namespace, _ = _prepare_namespace(cube_savedir, module_class, instance_name) @@ -648,6 +788,8 @@ def _load_cube_module_class( cube_module_class.__qualname__ = module_class.__qualname__ # cube_module_class.__module__ = module_class.__module__ cube_module_class.__orig_module_class__ = module_class # save the original module class + cube_module_class._train_step = gen_imported._train_step + cube_module_class._infer_step = gen_imported._infer_step return cube_module_class @@ -761,8 +903,7 @@ def __init__(self, init_params=True): outdir, reusable = _prepare_and_check_reusable(cube_savedir, module_class, compute_config, instance_name, reuse) if not reusable: config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE - if not config_file.exists(): - torch.save(compute_config, config_file) + torch.save(compute_config, config_file) # always refresh compute config with _compile_flags(compute_config): regen_status = _gencode( module_or_module_class, @@ -1003,6 +1144,10 @@ def build_optimizer( if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): raise RuntimeError("End2End mode is not supported") + # only the root module can be end2end module. + if any(m != module and isinstance(m, ParallelModule) and m.compute_config.use_end2end for m in module.modules()): + raise RuntimeError("End2End module cannot be nested in another module") + RuntimeFlag.skip_reducer = True RuntimeFlag.skip_zero_grad = False @@ -1399,10 +1544,12 @@ def merge_state_dicts( if opt_state_dicts is not None: opt_module_locs = [opt_extra_states[i].parallel_module_locs[module_prefix] for i in range(len(opt_extra_states))] - # Assume all ranks have the same opt_module_locs (offset and count) - # TODO: assert may fail for pipeline parallelism - for i in range(1, len(opt_module_locs)): - assert opt_module_locs[i] == opt_module_locs[0] + # We can't assume all ranks have the same opt_module_locs (offset and count) + # when we use pipeline parallelism, different ranks may have different opt_module_locs + # fortunately, we can use the location information from any rank to do the merging in following + # here we always use the location information from rank 0 + # for i in range(1, len(opt_module_locs)): + # assert opt_module_locs[i] == opt_module_locs[0] opt_new_pm_states[opt_module_locs[0]] = (merged_opt_state_dict['state'], module_prefix, extra_states[0].origin_param_names) if opt_new_pm_states: @@ -1411,7 +1558,7 @@ def merge_state_dicts( module_prefix = '.'.join(k) pm_orig_param_names[module_prefix] = CubeModule.get_origin_parameter_names([e.param_area_map for e in extra_states]) # now we can construct the merged state of optimizer from any rank - # let's just use the first rank + # as said previously, the merge will be based on rank0's data orig_states: Dict[int, Any] = optimizer_state_dicts[0]['state'] ret_states: Dict[int, Any] = {} # see `_get_optimizer_state_dict_info` for the value structure. sorted_pm_locs = sorted(opt_new_pm_states.keys(), key=lambda x: x.offset) @@ -1777,8 +1924,11 @@ def _broadcast_gen_files( and not file.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) ) or ( + # broadcast code files and compute config file + # please note the compute config file can be updated + # even when the graph is reused. broadcast_strategy == BroadcastGenFilesStrategy.CODE - and file.suffix == '.py' + and (file.suffix == '.py' or file.name == ParallelModule.COMPUTE_CONFIG_FILE) ) ): files.append(file.name) @@ -1930,9 +2080,9 @@ def load_deduped_state_dict( # append also works # but insert to 0 feels better # the dedup size for non-parallel module is 1 - opt_broadcast_groups.insert(0, (list(non_parallel_module_locs), 1)) - # TODO: what if opt_broadcast_groups are different in different ranks? - # Will it happend in pipeline parallelism? + if non_parallel_module_locs: + opt_broadcast_groups.insert(0, (list(non_parallel_module_locs), 1)) + for bg in opt_broadcast_groups: _broadcast_opt_state(optimizer_state_dict, *bg) optimizer.load_state_dict(optimizer_state_dict) @@ -1976,7 +2126,7 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g else: step_stack = torch.zeros( len(state_indexes), - dtype=optimizer_state_dict['state'][k]['step'].dtype, + dtype=optimizer_state_dict['state'][0]['step'].dtype, device=torch.cuda.current_device() ) torch.distributed.broadcast(step_stack, src=src_rank, group=curr_parallel_group) diff --git a/cube/runtime/module.py b/cube/runtime/module.py index 706cf262..ad1e7e7d 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -1,4 +1,4 @@ -from typing import List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any +from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union import logging import os import sys @@ -12,8 +12,11 @@ from cube.graph.parser.fx.parser import FxModuleParser from cube.runtime.device import DeviceGroup from cube.runtime.adapter.reducer import Reducer +from cube.runtime.executor import Executor from cube.runtime.gnorm import ParamsInfo from cube.flags import CompileFlag +from cube.runtime.utils import microbatches +from cube.utils import accum_mode if TYPE_CHECKING: from cube.parallel import ComputeConfig @@ -452,12 +455,10 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): assert len(optim_state_dicts[work_idx]['param_groups']) == 1, 'only support param_groups to be one group' # assign opt_state to state_dicts, cannot be assigned in the above loop - opt_state_len = len(opt_state_list[0]) for work_idx in range(plan_ngpus): optim_state_dicts[work_idx]['state'] = opt_state_list[work_idx] optim_state_dicts[work_idx]['param_groups'][0]['params'] = sorted(opt_state_list[work_idx].keys()) _logger.info(f'finish assign optimizer state for worker {work_idx}') - assert len(opt_state_list[work_idx]) == opt_state_len # build parameter order to match with the optimizer state order # NOTE: the param IDs in optimizer typically follow the same order of @@ -663,6 +664,142 @@ def sync_grad(self): for reducer in self._reducers: reducer.sync_grads() + def _train_step(self, dataloader) -> Union[List[Any], Any]: + """ + This function is assigned automatically when loading module class + Returns: + Union[List[Any], Any]: the output of the training step, + In Pipeline mode, it should return a list of outputs for each sample + Otherwise, it should return a single output + """ + ... + + def _infer_step(self, dataloader) -> Union[List[Any], Any]: + """ + This function is assigned automatically when loading module class + Returns: + Union[List[Any], Any]: the output of the training step, + In Pipeline mode, it should return a list of outputs for each sample + Otherwise, it should return a single output + """ + ... + + def _scale_loss(self, is_dummy_batch: Optional[List[bool]], scale_fn: Optional[Callable[[torch.Tensor], torch.Tensor]]) -> None: + """Setup cube backward hook for loss scale and dummy batch. + + If the batch is a dummy batch, the loss will be 0 to make the + gradient 0. + + Args: + is_dummy_batch (List[bool]): indicate whether the each micro-batch is dummy + scale_fn (Callable[[torch.Tensor], torch.Tensor]): the function to scale the loss + """ + + # clear the previous hook + Executor.register_backward_pre_hook(None) + + if not is_dummy_batch and not scale_fn: + return + + accum_idx = 0 + def cube_scale(ins, outs, grads): + nonlocal accum_idx + if is_dummy_batch and accum_idx >= len(is_dummy_batch): + raise RuntimeError( + f"Expected {len(is_dummy_batch)} number of micro-batches, but got more than it." + ) + mul_coef = 0.0 if is_dummy_batch and is_dummy_batch[accum_idx] else 1.0 + # find loss + for idx in range(len(outs)): + # loss always requires to be a scalar, and its gradient should be None + if grads[idx] is None: + assert idx == 0, "Loss must be the first output." + if outs[idx].size() != torch.Size([]): + raise ValueError(f"Expected scalar loss, but got {outs[idx].size()}.") + if scale_fn: + outs[idx] = mul_coef * scale_fn(outs[idx]) + else: + outs[idx] = mul_coef * outs[idx] + break + accum_idx += 1 + return ins, outs, grads + + Executor.register_backward_pre_hook(cube_scale) + + def train_step(self, + samples: List[Any], + is_dummy_batch: Optional[List[bool]] = None, + scale_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ) -> List[Any]: + """ + The training step function. It should be called in the training loop. + Please note: + 1. This function is only supported in end2end mode. + 2. Gradient accumulation is done inside this function. + You shouldn't do gradient accumulation outside this function, + because the gradients will be cleared in the beginning of this function + Args: + samples (List[Any]): a list of samples. + if pipeline is used, it must have the same length as pipeline_nmicros + is_dummy_batch (Optional[List[bool]]): indicates whether the each micro-batch is dummy + scale_fn (Optional[Callable[[torch.Tensor], torch.Tensor]]): the function to scale the loss + Results: + List[Any]: a list of outputs for each sample + """ + if not self.compute_config.use_end2end: + raise RuntimeError("train_step() is only supported in end2end mode") + if is_dummy_batch and len(samples) != len(is_dummy_batch): + raise ValueError("The length of samples and is_dummy_batch should be the same") + + self._scale_loss(is_dummy_batch, scale_fn) + + # sync_grad will be done in _train_step + # so we never need to call it manually + self._sync_grad_required = False + sample_count = len(samples) + dataloader = microbatches(samples, cycle=False) + + if self.compute_config.use_pipeline: + if len(samples) != self.compute_config.pipeline_nmicros: + raise ValueError(f"Expected {self.compute_config.pipeline_nmicros} samples, but got {sample_count}") + # only one step, so begin/end are both True + with accum_mode(begin=True, end=True): + return self._train_step(dataloader) + else: + outputs = [] + for idx in range(sample_count): + with accum_mode(begin=(idx==0), end=(idx==sample_count-1)): + output = self._train_step(dataloader) + outputs.append(output) + return outputs + + def infer_step(self, samples: List[Any]) -> List[Any]: + """ + The inference step function. It should be called in the inference loop. + Please note this function is only supported in end2end mode. + + Args: + samples (List[Any]): a list of samples. + if pipeline is used, it must have the same length as pipeline_nmicros + Results: + List[Any]: a list of outputs for each sample + """ + if not self.compute_config.use_end2end: + raise RuntimeError("infer_step() is only supported in end2end mode") + + sample_count = len(samples) + dataloader = microbatches(samples, cycle=False) + if self.compute_config.use_pipeline: + if len(samples) != self.compute_config.pipeline_nmicros: + raise ValueError(f"Expected {self.compute_config.pipeline_nmicros} samples, but got {sample_count}") + return self._infer_step(dataloader) + else: + outputs = [] + for _ in range(sample_count): + output = self._infer_step(dataloader) + outputs.append(output) + return outputs + @property def dist_param_map(self) -> Dict[str, str]: """ @@ -874,6 +1011,9 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s tid_info[meta.tid].append((attr, meta.slicers, meta.val_chunks)) # multiple params may share the same tid for orig_param_name in orig_param_names: + if orig_param_name not in origname_tid_map: + # in pipeline mode, the parameter may not be in this rank + continue orig_param_name_with_prefix = prefix + orig_param_name if orig_param_name_with_prefix not in state_dict: continue diff --git a/docs/parallel_module.md b/docs/parallel_module.md index acf06461..7be5686b 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -1,8 +1,18 @@ # Parallel Module -Besides the support of end-to-end model training, Cube can also convert a `torch.nn.Module` to a parallel module. A parallel module is a special `torch.nn.Module` but runs in multiple gpus/nodes. All the complexity of distributed training/inferring is hidden from the user. +Cube can parallelize a `torch.nn.Module` to a parallel module. A parallel module is a special `torch.nn.Module` but runs in multiple gpus/nodes. All the complexity of distributed training/inferring is hidden from the user. -## An example +Currently we support three kinds of parallelism: data parallelism, tensor parallelism and pipeline parallelism (model parallelism). We can also combine them to get the best performance. + +Data parallelism and tensor parallelism are support for all kinds of module, but pipeline parallelism is only supported for end2end modules for scheduling reason. + +An end2end module is a module which satisfies: +- the first argument of `module.forward` is the data sample +- the first return value of `module.forward` is the loss (scalar tensor) + +The above restrictions are necessary for the pipeline parallelism to work. Of course, you can still use the parallel module without pipeline parallelism for end2end modules. + +## Examples - Example 1: Parallelize the whole module @@ -102,11 +112,80 @@ def train(model: ParallelizedLLM, data): optimizer.zero_grad() ``` +- Example 3: Parallelize end2end module. +```python +class End2EndMLP(nn.Module): + def __init__(self): + init_random() + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(8): + self.layers.append(nn.Linear(16, 16, bias=False)) + self.loss_fn = nn.BCELoss() + + def forward(self, data: Dict[str, torch.Tensor]): + x = data['data'] + for layer in self.layers: + x = layer(x) + x = torch.sigmoid(x) + loss = self.loss_fn(x, data['target']) + return loss + + llm_sample_input = {'data': ..., 'target': ...} # dummpy input will be used to do tracing + pas_policy = ... # the PAS policy, you can use autodist pas + compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + use_end2end=True, + use_pipeline=..., + pipeline_nmicros=..., + pipeline_nstages=..., + pipeline_scheduler=..., + ..., + ) # compute environment config + ParallelizedPipelinedLLM = parallelize( + LLM, + {'data': llm_sample_input}, + pas_policy, + compute_config, + ) +``` + +If you want to enable pipeline parallelism, you need to set `use_end2end=True` and `use_pipeline=True` in `ComputeConfig`. You also need to set `pipeline_nmicros` and `pipeline_nstages` to specify the number of microbatches and stages in the pipeline. The `pipeline_scheduler` is the scheduler to schedule the pipeline. See below for details. + +For end2end modules, you can't use `Module.forward`. +Instead, you must use `ParallelModule.train_step` and `ParallelModule.infer_step` to train/infer the module. + +```python +def infer(model: ParallelizedPipelinedLLM, data): + model.eval() + with torch.inference_mode(): + return model.infer_step(data) + + +def train(model: ParallelizedPipelinedLLM, data): + # build_optimizer function will help to create a distributed optimizer + optimizer = build_optimizer(model, ...) + + for i, x in enumerate(data): + model.train() + losses = model.train_step(x) + optimizer.step() + optimizer.zero_grad() +``` + ## APIs ### ComputeConfig The configuration of the compute environment. It is a dataclass with the following fields: ```python + +@dataclass +class UserConfig: + graph: Dict[str, Any] = field(default_factory=dict) + code: Dict[str, Any] = field(default_factory=dict) + @dataclass(frozen=True) class ComputeConfig: plan_ngpus: int @@ -117,26 +196,42 @@ class ComputeConfig: use_zero: bool = False zero_ngroups: int = 1 - user_config: Optional[Dict[str, Any]] = None + inference_only : bool = False + use_end2end: bool = False + + use_pipeline: bool = False + pipeline_nmicros: int = -1 + pipeline_nstages: int = 1 + pipeline_scheduler: Optional[str] = None + + user_config: UserConfig = field(default_factory=UserConfig) ``` We can categorize the fields into 4 categories: 1. Trace configuration - - dynamic_shape: whether to use dynamic shape or static shape. + - `dynamic_shape`: whether to use dynamic shape or static shape. 2. Compute environment configuration - - plan_ngpus: the number of gpus to be used as a unit. The model is partitioned (TP or PP) within a unit, and then data parallelism is applied across multiple units. So every `plan_ngpus` devices holds the whole model. Furthermore, assume we have two workers, and their ranks are `rank1` and `rank2`: + - `plan_ngpus`: the number of gpus to be used as a unit. The model is partitioned (TP or PP) within a unit, and then data parallelism is applied across multiple units. So every `plan_ngpus` devices holds the whole model. Furthermore, assume we have two workers, and their ranks are `rank1` and `rank2`: 1. if `rank1 // plan_gpus == rank2 // plan_ngpus`, then they are in the same unit. 2. If `rank1 % plan_ngpus == rank2 % plan_ngpus`, then the portion of model hold on both gpus are exactly the same. - - runtime_ngpus: the number of gpus to be used in runtime. It should be a multiple of `plan_ngpus`, which means we have `runtime_ngpus // plan_ngpus` units in runtime, and the data parallelism is `runtime_ngpus // plan_ngpus`. + - `runtime_ngpus`: the number of gpus to be used in runtime. It should be a multiple of `plan_ngpus`, which means we have `runtime_ngpus // plan_ngpus` units in runtime, and the data parallelism is `runtime_ngpus // plan_ngpus`. Please note all modules must have the same `plan_ngpus` and `runtime_ngpus`. 3. Code generation feature configuration - - use_zero: whether to use zero. If it is true, the generated code will use zero1 to do distributed training. - - zero_ngroups: the number of groups to be used in zero. + - `use_zero`: whether to use zero. If it is true, the generated code will use zero1 to do distributed training. + - `zero_ngroups`: the number of groups to be used in zero. + - `inference_only`: whether to generate code for inference only. If it is true, the generated code can not be used to train the model. + - `use_end2end`: whether to use end2end training. For the requirement of end2end, see the description above. + - `use_pipeline`: whether to use pipeline. Please note the pipeline parallelism is only supported for end2end modules, so you must set `use_end2end=True` if you want to use pipeline. + - `pipeline_nmicros`: the number of microbatches in the pipeline. + - `pipeline_nstages`: the number of stages in the pipeline. + - `pipeline_scheduler`: the scheduler name for the pipeline. Current we support four schedulers in training `1f1b`/`1f1b_plus`/`gpipe`/`chimera_direct` (4 stages pipeline only), and one scheduler in inference `infer_pipe`. 4. User configuration - - user_config: the user configuration. A typical usage is deciding whether skipping compiling and reusing the previously compiled parallel module. If user_config is the same between two runs, compiling in the second run will be skipped. + - user_config: the user configuration,which is used to decide whether skipping compiling and reusing the previously compiled parallel module. It has two categories of configuration: + - `graph`: the graph related configuration, which is used to decide whether skipping graph generation only. + - `code`: if it has changed, the code will be regenerated. Note: -1. You can put any graph related configuration here. The assumption is different user_config should generate different graph/code. So if the user config is changed, we will regenerate the graph/code automatically. Here are some examples: +1. You can put any custom configurations in `user_config`. The assumption is different `user_config` should generate different graph/code. So if the user config is changed, we will regenerate the graph/code automatically. Here are some examples: - Example 1: save module configuration ```python @@ -148,28 +243,25 @@ Note: if module_config.use_3d: ... ``` - here we can set `user_config={'use_3d': module_config.use_3d}`, - and we can be sure different use_3d config will never use the same generated code. + here we can set `user_config.graph` to `{'use_3d': module_config.use_3d}`, + and we can be sure different use_3d config will never use the same graph (and eventually the generated code). - Example 2: save file stats If you want to track all related file stats (just like traditional compilers do), - you can do - ```python - user_config = { - 'file_stats': { - str(f): os.stat(f).st_mtime_ns for f in Path('./src').glob('**/*.py') # assume all source code is in ./src - } - } - ``` - Or you can save the md5 of the files to save some bytes: + you can save the md5 of the files to save some bytes: ```python import hashlib h = hashlib.md5() for f in Path('./src').glob('**/*.py'): with open(f, 'rb') as f: h.update(f.read()) - user_config = { - 'files_md5': h.hexdigest() + compute_config = { + ...., + user_config: UserConfig( + graph = { + 'files_md5': h.hexdigest() + } + ) } ``` @@ -389,6 +481,50 @@ optimizer.register_reducer_post_hook(lambda reducer, grad: grad.mul_(num_scale_u 5. `_non_parallel_module_reducer`: The reducer for the modules which are not parallelized. It is used to sync the parameters in those modules across units. +### ParallelModule APIs + +The `ParallelModule` is a subclass of `torch.nn.Module`. It has the following APIs: + +1. constructor +```python +def __init__(self, init_params=True): + ... +``` +You can use `init_params` to control whether to initialize the module parameters with the module parameters' values when we trace it. You can set it to `False` if you don't want to. + +2. `train_step` +```python +def train_step(self, + samples: List[Any], + is_dummy_batch: Optional[List[bool]], + scale_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] +) -> List[Any]: + ... +``` +The training step function. It should be called in the training loop. +Please note: + 1. This function is only supported in end2end mode. + 2. Gradient accumulation is done in this function. + You shouldn't do it outside this function, + because `zero_grad` will be called in the beginning of this function + +It has the following arguments: +- `samples` (`List[Any]`): a list of samples. + if pipeline is used, it must have the same length as pipeline_nmicros +- `is_dummy_batch` (`Optional[List[bool]]`): indicates whether the each micro-batch is dummy +- `scale_fn` (`Optional[Callable[[torch.Tensor], torch.Tensor]]`): the function to scale the loss + +And it will return a list of outputs for the samples. + +3. `infer_step` +```python +def infer_step(self, samples: List[Any]) -> List[Any]: + ... +``` +The inference step function. It should be called in the inference loop. +The input is a list of samples, and returns a list of outputs for the samples. If pipeline is used, it must have the same length as pipeline_nmicros + + ### Checkpoint support You can save/load the checkpoints for parallel modules. @@ -480,7 +616,3 @@ def create_distributed_sampler(dataset): ..., ) ``` - -## TODOs - -1. Pipeline parallelism is not supported yet. diff --git a/examples/mlp/train.py b/examples/mlp/train.py index 4df366a1..409260e5 100644 --- a/examples/mlp/train.py +++ b/examples/mlp/train.py @@ -70,7 +70,7 @@ def train_iter(model, dataloader): loss.backward() # load generated model model = cube.utils.load_model() - + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) CudaTimer(enable=False).warmup() diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index d681162c..65b288c5 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -2,13 +2,15 @@ import math import random import shutil -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple +import more_itertools as mitr import contextlib import torch from torch import nn import numpy as np +from cube.graph.schedule.predefined import PredefinedSched from cube.parallel import ComputeConfig from cube.graph.function.anchor import IRGraphAnchor from cube.graph.function.dimops import IRDimops @@ -17,6 +19,34 @@ from cube.ir.operator import IRDataOperation, IRFwOperation +def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: + """ + Create hybrid (nested) groups given the each group number. + + The product of group_num should be same with total devices. + + e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: + ( + ( (0,1,2), (3,4,5) ), + ( (0,3), (2,5), (3,6) ), + ) + """ + group_num = np.array(group_num) + cnt = np.prod(group_num) + assert cnt == ngpus, 'total device not match' + grid = np.arange(cnt).reshape(tuple(group_num)) + dims = list(range(len(group_num))) + outputs = [] + for dim, num in enumerate(group_num): + remain = ngpus // num + order = tuple(dims[:dim] + dims[dim+1:] + [dim]) + grid_dim = np.transpose(grid, order).reshape((remain,num)) + grid_dim = grid_dim.tolist() + outputs.append(tuple(tuple(ranks) for ranks in grid_dim)) + assert len(outputs) == len(group_num) + return tuple(outputs) + + def _tp(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int): sub_nodes = graph.partition( node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) @@ -90,7 +120,7 @@ def PASData(graph: IRGraph, env_resource: ComputeConfig): batch_dim = 0 for dl in graph.select(ntype=IRDataOperation): - _replica(dl, list(range(ngpus))) + _replica(graph, dl, list(range(ngpus))) for node in graph.nodes(): # print(node) @@ -111,6 +141,78 @@ def PASData(graph: IRGraph, env_resource: ComputeConfig): return graph +def PASMegatron(graph: IRGraph, config: ComputeConfig): + num_stages = config.pipeline_nstages + tp_size = config.plan_ngpus // num_stages + _, tp_mesh = create_mesh(config.plan_ngpus, (num_stages, tp_size)) + + # group to sub-graphs + linears = graph.select(name='linear') + stage_start_nodes = linears[::len(linears) // num_stages][:num_stages] + graph.staging(stage_start_nodes) + + segments = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + + for sid, segment in enumerate(fsegs): + # get tensor parallel group + tp_group = tp_mesh[sid] + for idx, node in enumerate(segment.nodes()): + if node.name == 'linear': + _tp(graph, node, idx=1, dim=idx%2, devs=tp_group) + else: + _replica(graph, node, devs=tp_group) + + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs=list(range(config.plan_ngpus))) + config.apply_pipeline_scheduler(graph) + return graph + + +def PASHybrid(graph: IRGraph, config: ComputeConfig): + """ + Hybrid Tensor and Pipeline Parallelism + """ + ngpus: int = config.plan_ngpus + nstages = config.pipeline_nstages + tp_size: int = config.plan_ngpus // nstages + if ngpus % tp_size != 0: + raise ValueError(f'invalid tp_size {tp_size} for ngpus {ngpus}') + pp_size = ngpus // tp_size + + fnodes = graph.select(ntype=IRFwOperation) + stages = mitr.divide(pp_size, fnodes) + stages = [list(s) for s in stages] + for idx, stage in enumerate(stages): + print(f'> stage {idx}: {stage[0]}') + graph.staging([s[0] for s in stages]) + + stages: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + assert len(stages) == pp_size, "Internal Error" + + # stage-wise tensor parallelism + curr_devices = list(range(ngpus)) + for stage in stages: + for node in stage.nodes(): + devs = curr_devices[:tp_size] + try: + _tp(graph, node, devs, idx=0, dim=0) + except Exception as e: + _replica(graph, node, devs) + curr_devices = curr_devices[tp_size:] + assert len(curr_devices) == 0, f"remaining devices: {curr_devices} not used" + + # replicate dataloader + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs=list(range(ngpus))) + + # setup 1f1b pipeline scheduler + # PredefinedSched.sched_1f1b(graph, nmicros, pp_size) + config.apply_pipeline_scheduler(graph) + return graph + + class CubeLinear(nn.Module): def __init__(self, in_features, out_features, bias=False): super().__init__() @@ -149,7 +251,6 @@ def init_random(): torch.cuda.manual_seed(1) - @contextlib.contextmanager def clear_dir_on_rank0(tempdir): if torch.distributed.get_rank() == 0 and tempdir.exists(): diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index 61f7e371..97728bc1 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -113,7 +113,7 @@ def _gpu_worker(): m = p(tempdir, 'all', '_6', load_module=False, reuse='graph') if torch.distributed.get_rank() != 0: # only python files are broadcasted - assert set(f.name for f in Path(tempdir).glob('**/*') if f.is_file()) == set(['gencode0.py', 'gencode1.py']) + assert set(f.name for f in Path(tempdir).glob('**/*') if f.is_file()) == set(['gencode0.py', 'gencode1.py', 'compute_config.pt']) # case 6.2: everything should be broadcasted, including weights # so the load_module will succeed. diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 2d7aa7a8..46833965 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -5,7 +5,7 @@ import shutil import pytest from typing import Dict, Tuple, List -from dataclasses import dataclass +from dataclasses import dataclass, replace import torch from torch import nn @@ -18,7 +18,7 @@ from cube.runtime.module import ParallelModule, ExtraState from cube.runtime.gnorm import calcuate_gnorm -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, PASMegatron from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively from ..utils import replace_all_device_with @@ -44,10 +44,10 @@ def forward(self, x): return super().forward(x + self.buffer) -def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name, dummy_input = None): return parallelize( module, - {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + dummy_input if dummy_input is not None else {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, cube_savedir=cube_savedir, @@ -55,7 +55,122 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): ) -def _create_cube_module(pas, compute_config, cube_savedir, module_type='whole'): +def pipeline_dummy_data(): + return { + 'data': torch.randn( + 2, 16, device=torch.cuda.current_device()), + 'target': torch.rand( + 2, 16, device=torch.cuda.current_device()) + } + + +class End2EndMLP(nn.Module): + def __init__(self): + init_random() + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(8): + self.layers.append(nn.Linear(16, 16, bias=False)) + self.loss_fn = nn.BCELoss() + + def forward(self, data: Dict[str, torch.Tensor]): + x = data['data'] + for layer in self.layers: + x = layer(x) + x = torch.sigmoid(x) + loss = self.loss_fn(x, data['target']) + return loss + + @classmethod + def to_pipeline_module(cls, compute_config: ComputeConfig, cube_savedir, + instance_name='pipeline', scheduler='1f1b' + ): + assert compute_config.runtime_ngpus == 4 + assert compute_config.plan_ngpus == 2 + compute_config = replace(compute_config, + use_end2end=True, + use_pipeline=True, + pipeline_nmicros=2, + pipeline_nstages=2, + pipeline_scheduler=scheduler + ) + return parallelize( + cls, + {'data': pipeline_dummy_data()}, + PASMegatron, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name + ) + + @classmethod + def gen_pipeline_data(cls, data_size, start, end, rank, num_replicas): + data = [] + for _ in range(data_size): + data.append(pipeline_dummy_data()) + data = data[start:end] + data = [data[i] for i in range(rank, len(data), num_replicas)] + data = [(data[i:i + 2], None) for i in range(0, len(data), 2)] + return data + + @classmethod + def gen_raw_data(cls, data_size, start, end, rank, num_replicas): + data = [] + for _ in range(data_size): + data.append(pipeline_dummy_data()) + data = data[start:end] + data = [(data[i], None) for i in range(rank, len(data), num_replicas)] + return data + + +class End2EndMLPWithUnusedAndShared(End2EndMLP): + def __init__(self): + super().__init__() + self.linear0_unused = nn.Linear(4, 4) # unused weights + self.layers[5].weight = self.layers[0].weight # shared weights across stages + + +def train_step(model, x, y, optimizer): + model.train() + if isinstance(model, ParallelModule) and model.compute_config.use_pipeline: + # actually train_step will return two losses (for each input) + # here we fake one loss to y_pred, so we don't need to change the check logic + y_pred, loss = model.train_step(x) + # workaround scalar tensor bug + y_pred = y_pred.reshape(()) + loss = loss.reshape(()) + elif isinstance(model, End2EndMLP): + y_pred = model(x) + loss = y_pred + loss.backward() + else: + loss_fn = nn.BCELoss() + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + optimizer.step() + return y_pred, loss + + +def gendata(model, data_size, start, end, rank, num_replicas): + data = [] + init_random() + if isinstance(model, ParallelModule) and model.compute_config.use_pipeline: + data = End2EndMLP.gen_pipeline_data(data_size, start, end, rank, num_replicas) + elif isinstance(model, End2EndMLP): + data = End2EndMLP.gen_raw_data(data_size, start, end, rank, num_replicas) + else: + for _ in range(data_size): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), + )) + data = data[start:end] # continue from last training + data = [data[i] for i in range(rank, len(data), num_replicas)] + return data + + +def _create_cube_module(pas, compute_config: ComputeConfig, cube_savedir, module_type='whole'): init_random() if module_type == 'whole': class CompiledModule(torch.nn.Module): @@ -75,15 +190,20 @@ def forward(self, x): x = self.linear3(x) x = self.sigmoid(x) return x - CompiledModule = _to_cube_model(CompiledModule, pas, compute_config, cube_savedir, 'whole') + CompiledModule = _to_cube_model(CompiledModule, pas, compute_config, cube_savedir, f'whole-{compute_config.inference_only}') + elif module_type == 'pipeline': + CompiledModule = End2EndMLP.to_pipeline_module(compute_config, cube_savedir, + f'pipeline-{compute_config.inference_only}', + scheduler='infer_pipe' if compute_config.inference_only else '1f1b' + ) elif module_type == 'sub': class CompiledModule(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = nn.Linear(4, 4) - self.fc_relu1 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu1') + self.fc_relu1 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, f'fc_relu1-{compute_config.inference_only}') self.linear2 = nn.Linear(4, 4) - self.fc_relu2 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, 'fc_relu2') + self.fc_relu2 = _to_cube_model(FcRelu_4_4(), pas, compute_config, cube_savedir, f'fc_relu2-{compute_config.inference_only}') self.linear3 = nn.Linear(4, 1) self.sigmoid = nn.Sigmoid() def forward(self, x): @@ -99,7 +219,7 @@ class CompiledModule(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = _to_cube_model(CubeLinear(4, 4, bias=True), - pas, compute_config, cube_savedir, f'start_linear1' + pas, compute_config, cube_savedir, f'start_linear1-{compute_config.inference_only}' ) self.linear2 = CubeLinear(4, 1, bias=True) self.sigmoid = nn.Sigmoid() @@ -115,7 +235,7 @@ def __init__(self): super().__init__() self.linear1 = CubeLinear(4, 4, bias=True) self.linear2 = _to_cube_model(CubeLinear(4, 4, bias=True), - pas, compute_config, cube_savedir, f'end_linear2' + pas, compute_config, cube_savedir, f'end_linear2-{compute_config.inference_only}' ) self.sigmoid = nn.Sigmoid() def forward(self, x): @@ -131,11 +251,12 @@ def __init__(self): super().__init__() self.linear1 = CubeLinear(4, 4, bias=True) self.linear2 = _to_cube_model(CubeLinear(4, 1, bias=True), - pas, compute_config, cube_savedir, f'small_linear2' + pas, compute_config, cube_savedir, f'small_linear2-{compute_config.inference_only}' ) # the following tests depend on the rngstate in PASRandomSPMD - assert len(self.linear2.reducers) == 1 - assert len(self.linear2.reducers[0].ranks) == 4 + if not compute_config.inference_only: + assert len(self.linear2.reducers) == 1 + assert len(self.linear2.reducers[0].ranks) == 4 self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.linear1(x) @@ -157,9 +278,18 @@ class StepResult: weights: Dict[str, torch.Tensor] -def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): +def assert_model_state_dict_equal(state_dict1: dict, state_dict2: dict): + assert set(state_dict1.keys()) == set(state_dict2.keys()) + for index in state_dict1.keys(): + if index.endswith('CUBE_EXTRA_STATE'): + continue + assert torch.equal(state_dict1[index].cpu(), state_dict2[index].cpu()) + + +def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inference_module: torch.nn.Module = None): ckpt_file_template = 'ckpt_{rank}_{start}.pth' ckpt_merged_file_template = 'ckpt_merged_{start}.pth' + temp_inferenece_ckpt_file_template = 'inference-{rank}.pth' ckpt_start_file = ckpt_dir / ckpt_file_template.format( rank=torch.distributed.get_rank(), start=start @@ -167,6 +297,8 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): ckpt_start_merged_file = ckpt_dir / ckpt_merged_file_template.format( start=start ) + temp_inferenece_ckpt_file = ckpt_dir / temp_inferenece_ckpt_file_template.format(rank=torch.distributed.get_rank()) + init_random() loss_fn = nn.BCELoss() @@ -187,6 +319,21 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): merged_ckpt_dict = torch.load(ckpt_start_merged_file) merged_model_state_dict = merged_ckpt_dict['model'] merged_opt_state_dict = merged_ckpt_dict['optimizer'] + + # In most cases, we can't load state_dict directly + # because they are different models, and the names of parameters are changed. + # inference_module.load_state_dict(model_state_dict, strict=False) + # assert not check_model_state_dict_equal(inference_module.state_dict(), model_state_dict) + + # inference model can be loaded from merged state_dict + load_merged_state_dicts(inference_module, merged_model_state_dict) + torch.save(inference_module.state_dict(), temp_inferenece_ckpt_file) + torch.distributed.barrier() + inference_ckpt_files = [ckpt_dir / temp_inferenece_ckpt_file_template.format(rank=i) for i in range(torch.distributed.get_world_size())] + inference_state_dicts = [torch.load(f) for f in inference_ckpt_files] + merged_inference_state_dict, _ = merge_state_dicts(inference_state_dicts) + assert_model_state_dict_equal(merged_model_state_dict, merged_inference_state_dict) + model_from_merged = type(model)() optimizer_from_merged = build_optimizer(model_from_merged, torch.optim.Adam, lr=0.01) load_merged_state_dicts( @@ -197,11 +344,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): # check merged model result_orig_model_state_dict = model.state_dict() result_merged_model_state_dict = model_from_merged.state_dict() - assert set(result_orig_model_state_dict.keys()) == set(result_merged_model_state_dict.keys()) - for index in result_orig_model_state_dict.keys(): - if index.endswith('CUBE_EXTRA_STATE'): - continue - assert torch.equal(result_orig_model_state_dict[index], result_merged_model_state_dict[index]) + assert_model_state_dict_equal(result_orig_model_state_dict, result_merged_model_state_dict) result_orig_opt_state_dict = optimizer.state_dict() result_merged_opt_state_dict = optimizer_from_merged.state_dict() @@ -213,22 +356,10 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir): for key in ('step', 'exp_avg', 'exp_avg_sq'): assert torch.equal(result_orig_opt_state_dict['state'][index][key], result_merged_opt_state_dict['state'][index][key]) torch.distributed.barrier() - data = [] - init_random() - for _ in range(DATA_SIZE): - data.append(( - torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.rand((2, 1), device='cuda', dtype=torch.float32), - )) - data = data[start:end] # continue from last training - data = [data[i] for i in range(rank, len(data), num_replicas)] + data = gendata(model, DATA_SIZE, start, end, rank, num_replicas) results = [] for i, (x, y) in enumerate(data): - model.train() - y_pred = model(x) - loss = loss_fn(y_pred, y) - loss.backward() - optimizer.step() + y_pred, loss = train_step(model, x, y, optimizer) grads = {n: p.grad for n, p in model.named_parameters()} gnorm = optimizer.clip_gnorm() results.append(clone_to_cpu_recursively([y_pred, loss, grads, gnorm])) @@ -288,18 +419,24 @@ def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus, per_resum tempdir, module_type, ) + compiled_inference_module = _create_cube_module(pas, + ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero, inference_only=True), + tempdir, + module_type, + ) if check_module: check_module(compiled_module) compiled_results.extend(_train( compiled_module, runtime_ngpus // plan_ngpus, torch.distributed.get_rank() // plan_ngpus, - start, end, tempdir + start, end, tempdir, + inference_module=compiled_inference_module )) return compiled_results @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') -@pytest.mark.parametrize('module_type', ['sub', 'whole', 'start', 'end', 'small']) +@pytest.mark.parametrize('module_type', ['sub', 'whole', 'start', 'end', 'small', 'pipeline']) @pytest.mark.parametrize('use_zero', [True, False]) def test_checkpoint(module_type, use_zero): plan_ngpus = 2 diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index f3b0a100..e6533c35 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -11,8 +11,9 @@ deduped_state_dict, load_deduped_state_dict from cube.runtime.module import ParallelModule -from .common import PASRandomSPMD, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, assert_equal +from .common import PASRandomSPMD, PASMegatron, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, assert_equal from ..launch_torchrun import launch_torchrun +from .test_checkpoint import gendata, train_step, End2EndMLP, End2EndMLPWithUnusedAndShared class FcRelu(nn.Module): @@ -82,21 +83,12 @@ def forward(self, x): def _train(model: torch.nn.Module, ckpt_dir): CKPT_FILE_NAME = CKPT_FILE_NAME_TEMPLATE.format(torch.distributed.get_rank()) - DATA = [] - for _ in range(DATA_SIZE): - DATA.append(( - torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.rand((2, 1), device='cuda', dtype=torch.float32), - )) - loss_fn = nn.BCELoss() + + DATA = gendata(model, DATA_SIZE, 0, DATA_SIZE, 0, 1) optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) for i, (x, y) in enumerate(DATA): - model.train() + train_step(model, x, y, optimizer) optimizer.zero_grad() - y_pred = model(x) - loss = loss_fn(y_pred, y) - loss.backward() - optimizer.step() deduped_model_state_dict, deduped_opt_state_dict = deduped_state_dict(model, optimizer) torch.save({ 'model': model.state_dict(), @@ -119,7 +111,7 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): dedupped_optimizer_state_dicts = [ckpt['optimizer-dedup'] for ckpt in ckpt_state_dicts] parallel_modules = [m for m in model.modules() if isinstance(m, ParallelModule)] - assert len(parallel_modules) == 2 + # assert len(parallel_modules) == 2 module_dedup_group_size = [m.module_dedup_group_size for m in parallel_modules] opt_dedup_group_size = [m.optimizer_dedup_group_size for m in parallel_modules] @@ -140,7 +132,11 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): assert len(dedupped_model_state_dict) == len(parallel_modules) assert all(k.endswith(ParallelModule.EXTRA_STATE_KEY) for k in dedupped_model_state_dict.keys()) else: - assert len(parallel_modules) < len(dedupped_model_state_dict) < len(model_state_dict) + if not isinstance(model, ParallelModule): + # in this case, non parallel module is removed, so it should have less keys + assert len(parallel_modules) < len(dedupped_model_state_dict) < len(model_state_dict) + else: + assert len(dedupped_model_state_dict) == len(model_state_dict) for k, v in dedupped_model_state_dict.items(): assert_equal(v, model_state_dict[k]) @@ -152,7 +148,11 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): # only EXTRA_STATEs and param_groups are kept assert not dedupped_optimizer_state_dict['state'] # should have empty state else: - assert 0 < len(dedupped_optimizer_state_dict['state']) < len(optimizer_state_dict['state']) + if not isinstance(model, ParallelModule): + # in this case, non parallel module is removed, so it should have less keys + assert 0 < len(dedupped_optimizer_state_dict['state']) < len(optimizer_state_dict['state']) + else: + assert len(dedupped_optimizer_state_dict['state']) == len(optimizer_state_dict['state']) for k, v in dedupped_optimizer_state_dict['state'].items(): assert_equal(v, optimizer_state_dict['state'][k]) @@ -196,3 +196,30 @@ def test_checkpoint_compact(use_zero): cc1 = ComputeConfig(2, 4, use_zero=not use_zero, zero_ngroups=2 if not use_zero else 1) cc2 = ComputeConfig(2, 4, use_zero=use_zero, zero_ngroups=1) launch_torchrun(4, _gpu_worker, PASRandomSPMD, cc1, cc2) + + +def _gpu_worker_pipeline(cc): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_compact_pipeline') as tempdir: + for model_cls in [End2EndMLP, End2EndMLPWithUnusedAndShared]: + pipeline_moule_cls = model_cls.to_pipeline_module(cc, tempdir) + _train(pipeline_moule_cls().cuda(), tempdir) + torch.distributed.barrier() + _check_deduped( + pipeline_moule_cls().cuda(), + tempdir + ) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_checkpoint_compact_pipeline(): + cc1 = ComputeConfig(2, 4, use_zero=False) + launch_torchrun(4, _gpu_worker_pipeline, cc1) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_checkpoint_compact_pipeline_use_zero(): + cc1 = ComputeConfig(2, 4, use_zero=True, zero_ngroups=1) + cc2 = ComputeConfig(2, 4, use_zero=True, zero_ngroups=2) + launch_torchrun(4, _gpu_worker_pipeline, cc1) + launch_torchrun(4, _gpu_worker_pipeline, cc2) diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index 03b4227e..838761c3 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -10,6 +10,7 @@ from .common import PASRandomSPMD, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun +from .test_checkpoint import End2EndMLP, train_step, gendata class FcReluWithShared(nn.Module): @@ -45,9 +46,9 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): ) -def _create_cube_module(pas, compute_config, cube_savedir, module_type='raw'): +def _create_cube_module(pas, compute_config, cube_savedir, module_type='sub/raw'): init_random() - if module_type == 'raw': + if module_type == 'sub/raw': class RawModuleWithShared(torch.nn.Module): def __init__(self): super().__init__() @@ -67,7 +68,7 @@ def forward(self, x): return x init_random() return RawModuleWithShared().cuda() - else: + elif module_type == 'sub/cube': class ParallelModuleWithShared(torch.nn.Module): def __init__(self): super().__init__() @@ -90,6 +91,28 @@ def forward(self, x): return x init_random() return ParallelModuleWithShared().cuda() + elif module_type.startswith('pipeline/'): + class RawModuleWithUnused(End2EndMLP): + def __init__(self): + super().__init__() + self.linear0_unused = nn.Linear(4, 4) # unused weights + self.layers[2].weight = self.layers[0].weight # shared weights in same stage + init_random() + if module_type.endswith('/raw'): + return RawModuleWithUnused().cuda() + else: + return RawModuleWithUnused.to_pipeline_module(compute_config, cube_savedir, 'pipeline')().cuda() + elif module_type.startswith('pipeline2/'): + class RawModuleWithUnused(End2EndMLP): + def __init__(self): + super().__init__() + self.linear0_unused = nn.Linear(4, 4) # unused weights + self.layers[5].weight = self.layers[0].weight # shared weights across stages + init_random() + if module_type.endswith('/raw'): + return RawModuleWithUnused().cuda() + else: + return RawModuleWithUnused.to_pipeline_module(compute_config, cube_savedir, 'pipeline')().cuda() DATA_SIZE = 256 @@ -97,21 +120,11 @@ def forward(self, x): def _train_raw(model: torch.nn.Module, ckpt_dir): - DATA = [] - for _ in range(DATA_SIZE): - DATA.append(( - torch.randn((2, 4), device='cuda', dtype=torch.float32), - torch.rand((2, 1), device='cuda', dtype=torch.float32), - )) - loss_fn = nn.BCELoss() + DATA = gendata(model, DATA_SIZE, 0, DATA_SIZE, 0, 1) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) for i, (x, y) in enumerate(DATA): - model.train() + y_pred, loss = train_step(model, x, y, optimizer) optimizer.zero_grad() - y_pred = model(x) - loss = loss_fn(y_pred, y) - loss.backward() - optimizer.step() torch.save({ 'model': model.state_dict(), 'optimizer': optimizer.state_dict() @@ -151,10 +164,11 @@ def _load_merged(parallel_model: torch.nn.Module, ckpt_dir): 'model': merged_model_state_dicts, 'optimizer': merged_optimizer_state_dict }, ckpt_merged_file) + # only key that contains `unused`` and not start with `unused` will be removed raw_model_state_dict = { key: value for key, value in raw_model_state_dict.items() - if not key.startswith('fc_relu1.unused_fc') + if not ('unused' in key and not key.startswith('unused')) } assert set(merged_model_state_dicts.keys()) == set(raw_model_state_dict.keys()) for index in merged_model_state_dicts.keys(): @@ -168,7 +182,7 @@ def _load_merged(parallel_model: torch.nn.Module, ckpt_dir): assert torch.equal(merged_optimizer_state_dict['state'][index][key].cuda(), raw_opt_state_dict['state'][index][key].cuda()) -def _gpu_worker(use_zero, pas, plan_ngpus, runtime_ngpus): +def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus): # Basic logic: # a. first train the original model, get a full state dict # b. then use parallel model to load the full state dict as a merged state dict @@ -179,19 +193,20 @@ def _gpu_worker(use_zero, pas, plan_ngpus, runtime_ngpus): with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: if torch.distributed.get_rank() == 0: tempdir.mkdir(parents=True, exist_ok=True) - _train_raw(_create_cube_module(pas, compute_config, tempdir, 'raw'), tempdir) + _train_raw(_create_cube_module(pas, compute_config, tempdir, f'{module_type}/raw'), tempdir) torch.distributed.barrier() _load_merged( - _create_cube_module(pas, compute_config, tempdir, 'cube'), + _create_cube_module(pas, compute_config, tempdir, f'{module_type}/cube'), tempdir ) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @pytest.mark.parametrize('use_zero', [True, False]) -def test_checkpoint_load_from_raw_checkpoint(use_zero): +@pytest.mark.parametrize('module_type', ['sub', 'pipeline', 'pipeline2']) +def test_checkpoint_load_from_raw_checkpoint(module_type, use_zero): """ Test when the checkpoint is generated from raw module and need to be loaded to parallel module. """ plan_ngpus = 2 runtime_ngpus = 4 - launch_torchrun(4, _gpu_worker, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus) + launch_torchrun(4, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus) diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py new file mode 100644 index 00000000..501d9ad3 --- /dev/null +++ b/tests/parallel_module/test_end2end.py @@ -0,0 +1,356 @@ +""" +PYTHONPATH=.:$PYTHONPATH torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/train.py --policy PASMegatronTP +""" + +from pathlib import Path +import tempfile +from typing import Dict +import pytest +import torch +from torch import nn +import torch.distributed + +import cube +from cube.runtime.gnorm import calcuate_gnorm +from cube.runtime.utils import microbatches +from cube.runtime.module import ParallelModule +from cube.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts +from .common import PASData, PASRandomSPMD, assert_equal, clear_dir_on_rank0, init_distributed, PASMegatron, init_random, PASHybrid +from ..launch_torchrun import clone_to_cpu_recursively, launch_torchrun + +from .test_checkpoint import End2EndMLP + + +DATA_SIZE = 64 +MBS = 2 # microbatch size +DIM = 16 +LAYERS = 16 + +class MLP(nn.Module): + def __init__(self, dim: int = DIM, nlayers: int = LAYERS): + init_random() + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + self.loss_fn = nn.BCELoss() + + def forward(self, data: Dict[str, torch.Tensor]): + x = data['data'] + for layer in self.layers: + x = layer(x) + x = torch.sigmoid(x) + loss = self.loss_fn(x, data['target']) + return loss + + +def dummy_data(): + return { + 'data': torch.randn( + MBS, DIM, device=torch.cuda.current_device()), + 'target': torch.rand( + MBS, DIM, device=torch.cuda.current_device()) + } + + +def _train_cube(model: ParallelModule, mbs, num_replicas, rank): + init_random() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + data = [] + init_random() + for _ in range(DATA_SIZE): + data.append(dummy_data()) + data = [data[i] for i in range(rank, DATA_SIZE, num_replicas)] + chunks = [data[i:i + mbs] for i in range(0, len(data), mbs)] + results = [] + for _, x in enumerate(chunks): + model.train() + losses = model.train_step(x) + optimizer.step() + gnorm = optimizer.clip_gnorm() + grads = {n: p.grad for n, p in model.named_parameters()} + model._add_extra_state(grads, '') + weights = {n: p.data for n, p in model.named_parameters()} + model._add_extra_state(weights, '') + results.append(clone_to_cpu_recursively([grads, weights, gnorm])) + optimizer.zero_grad() + return results + + +def _train_ga(model, update_freq, data_size=DATA_SIZE): + init_random() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + data = [] + init_random() + for _ in range(data_size): + data.append(dummy_data()) + results = [] + for i, x in enumerate(data): + model.train() + loss = model(x) + loss.backward() + if i % update_freq == update_freq - 1: + optimizer.step() + gnorm = calcuate_gnorm(list(model.parameters()))[0] + grads = {n: p.grad for n, p in model.named_parameters()} + weights = {n: p.data for n, p in model.named_parameters()} + results.append(clone_to_cpu_recursively([grads, weights, gnorm])) + optimizer.zero_grad() + return results + + +def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, use_pipeline, nstages=None, nmicros=None, model_cls=MLP, pipeline_scheduler='1f1b'): + init_distributed() + init_random() + nstages = nstages or plan_ngpus + nmicros = nmicros or plan_ngpus + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end') as tempdir: + init_random() + model = model_cls() + model = parallelize( + model, + {'data': dummy_data()}, + pas_policy=policy, + compute_config= ComputeConfig( + plan_ngpus, runtime_ngpus, + use_end2end=True, + use_pipeline=use_pipeline, pipeline_nmicros=nmicros, pipeline_nstages=nstages, + pipeline_scheduler=pipeline_scheduler + ), + cube_savedir=tempdir + ) + model.cuda() + train_result = _train_cube(model, nmicros, runtime_ngpus // plan_ngpus, torch.distributed.get_rank() // plan_ngpus) + + with torch.inference_mode(): + model.eval() + init_random() + infer_data = [] + for _ in range(nmicros): + infer_data.append(dummy_data()) + infer_result = clone_to_cpu_recursively(model.infer_step(infer_data)) + + return train_result, infer_result, clone_to_cpu_recursively(infer_data) + + +def merge_cube_result(cube_results): + cube_result = [] + for i in range(len(cube_results[0])): + for rank in cube_results: + assert torch.equal(cube_results[rank][i][2], cube_results[0][i][2]) + cube_result.append([ + merge_state_dicts([cube_results[rank][i][0] for rank in cube_results])[0], + merge_state_dicts([cube_results[rank][i][1] for rank in cube_results])[0], + cube_results[0][i][2] + ]) + return cube_result + + +def allclose(a, b, atol=1e-6, rtol=1e-6): + assert len(a) == len(b) + for step in range(len(a)): + assert len(a[step][0]) == len(b[step][0]) + assert len(a[step][1]) == len(b[step][1]) + for k in a[step][0].keys(): # grads + assert torch.allclose(a[step][0][k].cpu(), b[step][0][k].cpu(), atol=atol, rtol=rtol) + for k in a[step][1].keys(): # weights + assert torch.allclose(a[step][1][k].cpu(), b[step][1][k].cpu(), atol=atol, rtol=rtol) + # gnorm + assert torch.allclose(a[step][2].cpu(), b[step][2].cpu(), atol=atol, rtol=rtol) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_end2end(): + torch.cuda.set_device(0) + torch.set_default_device(f'cuda:0') + model = MLP() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 + + cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, PASHybrid, True) # micro_batch_size = 4 + cube2_result = merge_cube_result({k: v[0] for k, v in cube2_results.items()}) + assert len(cube2_result) == 16 + allclose(cube2_result, ga4_result) + + cube4_results = launch_torchrun(4, gpu_worker_cube, 4, 4, PASMegatron, True) # micro_batch_size = 4 + cube4_result = merge_cube_result({k: v[0] for k, v in cube4_results.items()}) + assert len(cube4_result) == 16 + allclose(cube4_result, ga4_result) + + cube2_results_non_pipeline = launch_torchrun(4, gpu_worker_cube, 4, 2, PASRandomSPMD, False) # micro_batch_size = 4 + cube2_result_non_pipeline = merge_cube_result({k: v[0] for k, v in cube2_results_non_pipeline.items()}) + assert len(cube2_result_non_pipeline) == 16 + allclose(cube2_result_non_pipeline, ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error + + infer_results = {k: v[1] for k, v in cube2_results_non_pipeline.items()} + infer_datas = {k: v[2] for k, v in cube2_results_non_pipeline.items()} + assert len(infer_results) == 4 + assert len(infer_datas) == 4 + infer_result = infer_results[0] + infer_data = infer_datas[0] + for k in infer_results: + assert_equal(infer_results[k], infer_result) + for k in infer_datas: + assert_equal(infer_datas[k], infer_data) + + for i, data in enumerate(infer_data): + with torch.inference_mode(): + model.eval() + loss = model({key: v.cuda() for key, v in data.items()}) + assert torch.allclose(loss.cpu(), infer_result[i].cpu(), atol=1e-6, rtol=1e-6) + + +class MLPShared(End2EndMLP): + def __init__(self): + super().__init__() + self.layers[5].weight = self.layers[0].weight # shared weights across stages + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_pipeline_shared(): + torch.cuda.set_device(0) + torch.set_default_device(f'cuda:0') + model = MLPShared() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 + for step in range(len(ga4_result)): + # fake shared weights for later compare + ga4_result[step][0]['layers.5.weight'] = ga4_result[step][0]['layers.0.weight'] + ga4_result[step][1]['layers.5.weight'] = ga4_result[step][1]['layers.0.weight'] + + with pytest.raises(ValueError, match='is not supported in training mode'): + ComputeConfig( + 2, 2, + inference_only=False, + use_end2end=True, + use_pipeline=True, pipeline_nmicros=2, pipeline_nstages=2, + pipeline_scheduler='infer_pipe' + ) + with pytest.raises(ValueError, match='is not supported in inference mode'): + ComputeConfig( + 2, 2, + inference_only=True, + use_end2end=True, + use_pipeline=True, pipeline_nmicros=2, pipeline_nstages=2, + pipeline_scheduler='1f1b' + ) + + for ps in ['1f1b', '1f1b_plus','gpipe']: + # 'chimera_direct' needs more gpus + # 'infer_pipe' only work for inference + # None looks doesn't work + cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, PASHybrid, True, None, None, MLPShared, ps) # micro_batch_size = 4 + cube2_result = merge_cube_result({k: v[0] for k, v in cube2_results.items()}) + assert len(cube2_result) == 16 + allclose(cube2_result, ga4_result) + + # TODO: fix `chimera_direct` + # cube4_results = launch_torchrun(4, gpu_worker_cube, 4, 4, PASMegatron, True, None, None, MLPShared, 'chimera_direct') # micro_batch_size = 4 + # cube4_result = merge_cube_result({k: v[0] for k, v in cube4_results.items()}) + # assert len(cube4_result) == 16 + # allclose(cube4_result, ga4_result) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 8, reason='lack of gpu devices') +def test_pipeline(): + torch.cuda.set_device(0) + torch.set_default_device(f'cuda:0') + model = MLP() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 + + # pp_size = 2 + # tp_size = 2 + # scale unit size = 4 + cube8_results = launch_torchrun(8, gpu_worker_cube, 8, 4, PASMegatron, True, 2, 2) # micro_batch_size = 4 + cube8_result = merge_cube_result({k: v[0] for k, v in cube8_results.items()}) + assert len(cube8_result) == 16 + allclose(cube8_result, ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error + + # TODO: scalar type support + # `v[1].reshape(())` to unify torch.shape == [] or torch.shape == [1] + infer_results = {k: tuple(i.reshape(()) for i in v[1]) for k, v in cube8_results.items()} + infer_datas = {k: v[2] for k, v in cube8_results.items()} + assert len(infer_results) == 8 + assert len(infer_datas) == 8 + infer_result = infer_results[0] + infer_data = infer_datas[0] + for k in infer_results: + assert_equal(infer_results[k], infer_result) + for k in infer_datas: + assert_equal(infer_datas[k], infer_data) + + for i, data in enumerate(infer_data): + with torch.inference_mode(): + model.eval() + loss = model({key: v.cuda() for key, v in data.items()}) + assert torch.allclose(loss.cpu(), infer_result[i].cpu(), atol=1e-6, rtol=1e-6) + + +def _train_cube_one_sample(model: ParallelModule, mbs): + init_random() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + data = [] + init_random() + data_size = mbs + for _ in range(data_size): + data.append(dummy_data()) + chunks = [data[i:i + mbs] for i in range(0, len(data), mbs)] + results = [] + for _, x in enumerate(chunks): + model.train() + losses = model.train_step(x, [False, True], scale_fn=lambda t: t * 2.0) + optimizer.step() + gnorm = optimizer.clip_gnorm() + grads = {n: p.grad for n, p in model.named_parameters()} + model._add_extra_state(grads, '') + weights = {n: p.data for n, p in model.named_parameters()} + model._add_extra_state(weights, '') + results.append(clone_to_cpu_recursively([grads, weights, gnorm])) + optimizer.zero_grad() + return results + + +def gpu_worker_cube_one_sample(): + init_distributed() + init_random() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end') as tempdir: + init_random() + model = MLP() + model = parallelize( + model, + {'data': dummy_data()}, + pas_policy=PASHybrid, + compute_config= ComputeConfig( + 2, 2, + use_end2end=True, + use_pipeline=True, pipeline_nmicros=2, pipeline_nstages=2, + pipeline_scheduler='1f1b' + ), + cube_savedir=tempdir + ) + model.cuda() + train_result = _train_cube_one_sample(model, 2) + return train_result + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_loss_scaling(): + torch.cuda.set_device(0) + torch.set_default_device(f'cuda:0') + model = MLP() + ga4_result = _train_ga(model, 1, 1) + assert len(ga4_result) == 1 + ga4_grads = ga4_result[0][0] + scaled_ga4_grads = {n: g * 2.0 for n, g in ga4_grads.items()} + + cube2_results = launch_torchrun(2, gpu_worker_cube_one_sample) + cube2_result = merge_cube_result({k: v for k, v in cube2_results.items()}) + assert len(cube2_result) == 1 + cube2_grads = cube2_result[0][0] + assert len(cube2_grads) == len(scaled_ga4_grads) + for k in cube2_grads: + assert torch.allclose(cube2_grads[k].cpu(), scaled_ga4_grads[k].cpu(), atol=1e-6, rtol=1e-6) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index d0dd890f..b365f946 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -624,3 +624,110 @@ def test_codegen_buffer(): assert len(matches) == 1 match = matches[0] assert 'persistent' not in match + + +class End2EndModule(torch.nn.Module): + def __init__(self, dim: int = 1024, nlayers: int = 16): + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + def forward(self, data: torch.Tensor, return_type: int = 0): + x = data + for layer in self.layers: + x = layer(x) + loss = torch.sum(x) + if return_type == 0: + return loss + elif return_type == 1: + return loss, data.shape # the second return is not tensor + elif return_type == 2: + return loss, {'data': data} + elif return_type == 3: + return torch.sum(x, -1) # bad loss + elif return_type == 4: + return {'data': data} # not tensor + + +@replace_all_device_with('cpu') +def test_codegen_inference(): + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + Module0(), + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + PASData, + ComputeConfig(1, 1, inference_only=True), + cube_savedir=tempdir, + load_module=False + ) + assert _gencode_contains(tempdir, Module0, 0, + r"self\.register_buffer" + ) + assert not _gencode_contains(tempdir, Module0, 0, + r"self\.register_parameter" + ) + + +@replace_all_device_with('cpu') +def test_codegen_end2end(): + """ + Test end2end code generation for different configs + (use_pipeline, dynamic shape, return value) + """ + dim = 1024 + nlayers = 16 + batch_size = 64 + def p(cube_dir, use_pipeline, dynamic_shape, return_type, inference_only=False): + m = End2EndModule(dim, nlayers) + m.train() + parallelize( + m, + {'data': torch.randn(batch_size, dim), 'return_type': return_type}, + PASData, + compute_config= ComputeConfig( + 4, 4, + inference_only=inference_only, + dynamic_shape=dynamic_shape, + use_end2end=True, + use_pipeline=use_pipeline, + pipeline_nmicros=4, + pipeline_nstages=4, + pipeline_scheduler='infer_pipe' if inference_only else '1f1b' + ), + cube_savedir=cube_dir, + load_module=False, + reuse='override', + ) + with tempfile.TemporaryDirectory() as tempdir: + for use_pipeline in [True, False]: + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=0) # should success + assert not _gencode_contains(tempdir, End2EndModule, 0, + r"self\.register_buffer" + ) + assert _gencode_contains(tempdir, End2EndModule, 0, + r"self\.register_parameter" + ) + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=0) # should success + with pytest.raises(RuntimeError, match='.*Communication generation.*'): + # fail for non-tensor IRObject return + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=1) + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=1) # should success + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=2) # should success + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=2) # should success + with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=3) + with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=3) + with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=4) + with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=4) + + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=0, inference_only=True) # should success + assert not _gencode_contains(tempdir, End2EndModule, 0, + r"self\.register_parameter" + ) + assert _gencode_contains(tempdir, End2EndModule, 0, + r"self\.register_buffer" + ) diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 2e5d1928..751136b3 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -54,14 +54,18 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): instance_name=instance_name ) -def _inference_worker(ngpus): +def _inference_worker(ngpus, inference_only): init_distributed() init_random() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_inference_test') as tempdir: model = Module() model.eval() - cube_model = _to_cube_model(model, PASRandomSPMD, ComputeConfig(ngpus, ngpus), tempdir, 'test_inference') + + cube_model = _to_cube_model(model, PASRandomSPMD, + ComputeConfig(ngpus, ngpus, inference_only=inference_only), + tempdir, 'test_inference' + ) data = torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]) assert not model.training @@ -76,9 +80,11 @@ def _inference_worker(ngpus): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_inference1(): - torchrun(1, _inference_worker, 1) + torchrun(1, _inference_worker, 1, True) + torchrun(1, _inference_worker, 1, False) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_inference2(): - torchrun(2, _inference_worker, 2) + torchrun(2, _inference_worker, 2, True) + torchrun(2, _inference_worker, 2, False) diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index 348b89cb..1825978e 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -8,6 +8,7 @@ from cube.graph.parser.fx.parser import FxModuleParser from cube.parallel import ReuseType, parallelize, ComputeConfig, _load_cube_module_class +from cube.runtime.module import ParallelModule from ..utils import new_empty, replace_all_device_with from .common import PASData @@ -83,7 +84,7 @@ def test_override(): # MOO | unmatch | generate _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o1', load_module=False) - _to_cube_model(MyModule, ComputeConfig(2, 2),tempdir, ReuseType.MOO, 'o1') + _to_cube_model(MyModule, ComputeConfig(2, 2, dynamic_shape=False),tempdir, ReuseType.MOO, 'o1') # MOO | imported | raise error _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o2', load_module=True) @@ -173,6 +174,28 @@ def test_override(): g7_module_path = module_path.with_name('g7') graph_stat = (g7_module_path / 'graph.ckp').stat() args_stat = (g7_module_path / 'forward_args.pkl').stat() - _to_cube_model(MyModule, ComputeConfig(2, 2), tempdir, 'graph', 'g7', False) + _to_cube_model(MyModule, ComputeConfig(2, 2, dynamic_shape=False), tempdir, 'graph', 'g7', False) assert (g7_module_path / 'graph.ckp').stat().st_mtime_ns != graph_stat.st_mtime_ns assert (g7_module_path / 'forward_args.pkl').stat().st_mtime_ns != args_stat.st_mtime_ns + + # Graph | graph match | generate + _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'g8', False) + g8_module_path = module_path.with_name('g8') + assert torch.load(g8_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(1, 1) + graph_stat = (g8_module_path / 'graph.ckp').stat() + args_stat = (g8_module_path / 'forward_args.pkl').stat() + _to_cube_model(MyModule, ComputeConfig(2, 2), tempdir, 'graph', 'g8', False) + assert (g8_module_path / 'graph.ckp').stat().st_mtime_ns == graph_stat.st_mtime_ns + assert (g8_module_path / 'forward_args.pkl').stat().st_mtime_ns == args_stat.st_mtime_ns + assert torch.load(g8_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(2, 2) + + # MOO | graph match | generate code only + _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'moo', 'g9', False) + g9_module_path = module_path.with_name('g9') + assert torch.load(g9_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(1, 1) + graph_stat = (g9_module_path / 'graph.ckp').stat() + args_stat = (g9_module_path / 'forward_args.pkl').stat() + _to_cube_model(MyModule, ComputeConfig(2, 2), tempdir, 'moo', 'g9', False) + assert (g9_module_path / 'graph.ckp').stat().st_mtime_ns == graph_stat.st_mtime_ns + assert (g9_module_path / 'forward_args.pkl').stat().st_mtime_ns == args_stat.st_mtime_ns + assert torch.load(g9_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(2, 2) diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 3bf299f7..d0774529 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -182,9 +182,9 @@ def _compare_weights(orig0, orig1, compiled0, compiled1, fc1_fullmap, fc2_fullma assert len(merged_fc1_fixed) + len(merged_fc2_fixed) + len(compiled0) == len(orig0) assert len(compiled1) == len(compiled0) for k, v in compiled0.items(): - assert torch.allclose(compiled0[k], compiled1[k], rtol=1e-4, atol=1e-4) + assert torch.allclose(compiled0[k].cpu(), compiled1[k].cpu(), rtol=1e-4, atol=1e-4) for k, v in itertools.chain(merged_fc1_fixed.items(), merged_fc2_fixed.items(), compiled0.items()): - assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) + assert torch.allclose(v.cpu(), orig0[k].cpu(), rtol=1e-4, atol=1e-4) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index eec4a4a4..3a4f9da0 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -140,7 +140,7 @@ def _compare_weights(orig0, orig1, compiled0, compiled1, module_fullmap, module_ merged_state, _ = ParallelModule.merge_partial_states(cube_state) assert len(compiled1) == len(compiled0) == len(orig0) for k, v in merged_state.items(): - assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) + assert torch.allclose(v.cpu(), orig0[k].cpu(), rtol=1e-4, atol=1e-4) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') From 91314b3a1f4949062ba46de56cae9616b830bd82 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Wed, 17 Apr 2024 08:10:53 +0000 Subject: [PATCH 1620/1892] Merged PR 2102: Fix persistent buffer The original logic is to register all non-nn.Parameter tensors as persistent buffers, and then these buffers will be saved together in the checkpoint. On the one hand, this will lead to useless checkpoint storage, and some buffers will be stored in fp32 because they are initialized to fp32. When restoring in fp16 format, these tensors will be loaded into fp16 format. In the process of fp32_dump, load ckpt, and to_fp16, checkpoint comparison will be inconsistent due to accuracy issues. Now, ordinary tensors will be kept registered as non-persistent buffers, and the original buffers will remain persistent and consistent with the original. --- cube/codegen/module/module.py | 23 +++++++++++++------ .../concrete_trace_utils/concrete_tracer.py | 7 +++--- cube/graph/parser/fx/parser.py | 3 ++- cube/ir/cten.py | 17 ++++++++++++-- cube/ir/tensor.py | 4 +++- cube/runtime/module.py | 4 ++++ .../parallel_module/test_checkpoint_buffer.py | 2 +- 7 files changed, 45 insertions(+), 15 deletions(-) diff --git a/cube/codegen/module/module.py b/cube/codegen/module/module.py index eab29896..f2d07d7c 100644 --- a/cube/codegen/module/module.py +++ b/cube/codegen/module/module.py @@ -615,7 +615,7 @@ def init_attributes(self, node: IRCell): the names of the variables for the tensors ever encountered. """ psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" - bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}))" + bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}), persistent={persistent})" map_sign = "self.add_full_map('{attr}', {tid}, {is_param}, '{orig_name}', {full_shape}, {slicers}, {val_chunks})" if not isinstance(node, IRSegment): for itensor in node.inputs(): @@ -623,12 +623,21 @@ def init_attributes(self, node: IRCell): if isinstance(itensor, IRSubTensor): if itensor.is_attr() and not self.symbols.exist(name): self.symbols.create(name) - sign = psign if itensor.is_param() else bsign - code = sign.format( - name=self.tensor_name(itensor), - shape=tuple(itensor.shape), - dtype=itensor.dtype - ) + if itensor.is_param(): + code = psign.format( + name=self.tensor_name(itensor), + shape=tuple(itensor.shape), + dtype=itensor.dtype + ) + elif itensor.is_buffer(): + code = bsign.format( + name=self.tensor_name(itensor), + shape=tuple(itensor.shape), + dtype=itensor.dtype, + persistent=itensor.is_persistent() + ) + else: + raise RuntimeError(f"Unexpected tensor type: {itensor}") self.model_init_statements.append(code) slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) val_chunks = itensor.valmap[1] diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index c5b1fb4f..21ef0a92 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -1439,10 +1439,11 @@ def copy_attr_new(from_module: torch.nn.Module, to_module: torch.nn.Module, targ from_module, to_module = f, t orig = getattr(from_module, field) - # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. - # So, we register it as a named buffer in the target module. + + # If it is a buffer, register the tensor as the same type of buffer, otherwise, just set the attribute. if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): - to_module.register_buffer(field, orig) + persistent = field in from_module._buffers and field not in from_module._non_persistent_buffers_set + to_module.register_buffer(field, orig, persistent=persistent) else: setattr(to_module, field, orig) diff --git a/cube/graph/parser/fx/parser.py b/cube/graph/parser/fx/parser.py index 2adf4d6b..cc7422ea 100644 --- a/cube/graph/parser/fx/parser.py +++ b/cube/graph/parser/fx/parser.py @@ -279,7 +279,8 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, if tensor.requires_grad: tensor.as_param() else: - tensor.as_buffer() + persistent = node.name not in module._non_persistent_buffers_set + tensor.as_buffer(persistent=persistent) frame.add_attr(tensor, concrete_value, node.target) # the case that the parameter is consumed multiple times and regisetered previously else: diff --git a/cube/ir/cten.py b/cube/ir/cten.py index 7aef0e36..e1e8d2e9 100644 --- a/cube/ir/cten.py +++ b/cube/ir/cten.py @@ -392,7 +392,7 @@ class IRTensor(IRObject): and will be translated to None in code generation. """ - _meta = ['name', '_is_attr', '_is_grad', '_requires_grad', '_dtype'] + _meta = ['name', '_is_attr', '_is_grad', '_requires_grad', '_dtype', '_persistent'] def __init__(self, shape=None, name='tensor', dtype=None, tid=None): @@ -404,6 +404,9 @@ def __init__(self, shape=None, name='tensor', dtype=None, tid=None): self._is_grad: bool = False self._requires_grad: bool = False self._grad: Optional[Union[IRTensor, float]] = None + # _persistent is a buffer only field, but in inference mode all params will be post-processed to buffers, + # so set _persistent True in as_param() for register these params to persistent buffers. + self._persistent = False @property def dtype(self) -> Optional[torch.dtype]: @@ -435,6 +438,14 @@ def is_buffer(self) -> bool: @return is_buffer boolean: True if is buffer. """ return self._is_attr and not self.requires_grad + + def is_persistent(self) -> bool: + """! + Check if the tensor is persistent buffer. + + @return is_persistent boolean: True if is persistent. + """ + return self.is_buffer() and self._persistent def is_grad(self) -> bool: """! @@ -452,15 +463,17 @@ def as_param(self): self._requires_grad = True self._is_attr = True self._is_grad = False + self._persistent = True return self - def as_buffer(self): + def as_buffer(self, persistent=True): """ Set the tensor as un-trainable buffer """ self._requires_grad = False self._is_attr = True self._is_grad = False + self._persistent = persistent return self def as_grad(self): diff --git a/cube/ir/tensor.py b/cube/ir/tensor.py index b0d86e36..dc60335f 100644 --- a/cube/ir/tensor.py +++ b/cube/ir/tensor.py @@ -346,16 +346,18 @@ def as_param(self): self.requires_grad = True self._is_attr = True self._is_grad = False + self._persistent = True if isinstance(self.grad, IRFullTensor): self.grad._is_attr = True - def as_buffer(self): + def as_buffer(self, persistent=True): """ Set the tensor as un-trainable buffer """ self.requires_grad = False self._is_attr = True self._is_grad = False + self._persistent = persistent def as_grad(self, of_attr: bool = False): self._attr = True if of_attr else False diff --git a/cube/runtime/module.py b/cube/runtime/module.py index ad1e7e7d..7503e02f 100644 --- a/cube/runtime/module.py +++ b/cube/runtime/module.py @@ -274,6 +274,10 @@ def merge_model_state_dicts( # gather param/buffer full tensor for model_state_dict, local_fullmap in zip(state_dicts, fullmaps): for local_name, meta in local_fullmap.items(): + if local_name not in model_state_dict: + # this is a non persistent buffer, skip + # non persistent buffer should be stored in the fullmap, but not in the model state dict + continue # create full tensor on cpu partial_tensor = model_state_dict[local_name] if meta.orig_name not in full_model_state_dict: diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index 40b8f32f..ceca83e6 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -49,7 +49,7 @@ def _gpu_worker(): with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: net1 = _to_cube_model(Net1(), compute_config, tempdir, 'net1', (128, 64)) cube_state_dict = net1.state_dict() - assert any(key.startswith('buffer') for key in cube_state_dict) + assert not any(key.startswith('buffer') for key in cube_state_dict) merged_state_dict, _ = merge_state_dicts([cube_state_dict]) assert 'buffer' not in merged_state_dict From 8e0a87bfee9bfc9685dac97693e4e856b32b1f8c Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Thu, 18 Apr 2024 03:30:36 +0000 Subject: [PATCH 1621/1892] Merged PR 2117: [Re-org Repo, Part 1] Move "cube" to "nnscaler" This PR shows what has been changed in MagicCube/main repo. Merging AutoDist, and adding packaging scripts will be opened as two separate PRs, for clarity. This PR does not necessarily need to be merged. We may either develop on the new branch "nnscaler". I personally have no preference. --- cube/codegen/__init__.py | 2 - cube/execplan/__init__.py | 1 - cube/graph/__init__.py | 2 - cube/graph/function/__init__.py | 2 - cube/graph/parser/__init__.py | 4 -- cube/graph/schedule/__init__.py | 1 - cube/ir/__init__.py | 5 -- cube/ir/adapter/__init__.py | 1 - cube/profiler/__init__.py | 2 - cube/runtime/__init__.py | 6 -- cube/runtime/adapter/__init__.py | 4 -- cube/runtime/function/__init__.py | 1 - docs/parallel_module.md | 6 +- examples/alphafold2/alphafold2.py | 22 +++--- examples/alphafold2/model.py | 36 +++++----- examples/alphafold2/module.py | 50 ++++++------- examples/alphafold2/policy/spmd.py | 6 +- examples/llama/chat.py | 8 +-- examples/llama/generation.py | 10 +-- examples/llama/model.py | 8 +-- examples/llama/test_chat_completion.py | 8 +-- examples/megatron_gpt/convert.py | 12 ++-- examples/megatron_gpt/parallel.py | 22 +++--- examples/mlp/policy/gallery.py | 8 +-- examples/mlp/train.py | 14 ++-- examples/nlp/blocks/attention.py | 6 +- examples/nlp/blocks/mlp.py | 4 +- examples/nlp/gpt/model.py | 4 +- examples/nlp/gpt/policy/mpmd.py | 8 +-- examples/nlp/gpt/policy/spmd.py | 6 +- examples/nlp/gpt/train.py | 20 +++--- examples/nlp/mbart/model.py | 10 +-- examples/nlp/mbart/policy/gallery.py | 12 ++-- examples/nlp/mbart/train.py | 22 +++--- examples/openfold/blocks/attention.py | 30 ++++---- examples/openfold/blocks/embedder.py | 8 +-- examples/openfold/blocks/evoformer.py | 24 +++---- examples/openfold/blocks/opm.py | 14 ++-- examples/openfold/blocks/tmu.py | 12 ++-- examples/openfold/blocks/utils.py | 4 +- examples/openfold/model.py | 6 +- examples/openfold/policy/mpmd.py | 14 ++-- examples/openfold/train.py | 16 ++--- examples/policies/alpa/README.md | 2 +- examples/policies/alpa/__init__.py | 16 ++--- examples/policies/alpa/cost_model.py | 10 +-- examples/policies/alpa/estimator.py | 14 ++-- examples/policies/alpa/inter_op.py | 2 +- examples/policies/alpa/intra_op.py | 6 +- examples/policies/alpa/layer_op.py | 8 +-- examples/policies/alpa/plan.py | 4 +- examples/policies/gshard.py | 10 +-- examples/policies/random_spmd.py | 8 +-- examples/utils.py | 18 ++--- examples/vision/swin/baseline.py | 10 +-- examples/vision/swin/blocks/attention.py | 4 +- examples/vision/swin/blocks/mlp.py | 4 +- examples/vision/swin/blocks/patch.py | 6 +- examples/vision/swin/blocks/transformer.py | 8 +-- examples/vision/swin/model.py | 4 +- examples/vision/swin/policy/gallery.py | 8 +-- examples/vision/swin/train.py | 14 ++-- {cube => nnscaler}/__init__.py | 16 ++--- {cube => nnscaler}/algorithm/__init__.py | 0 {cube => nnscaler}/algorithm/factory.py | 4 +- {cube => nnscaler}/algorithm/generics.py | 2 +- {cube => nnscaler}/algorithm/ops/__init__.py | 0 {cube => nnscaler}/algorithm/ops/conv.py | 6 +- {cube => nnscaler}/algorithm/ops/dimops.py | 10 +-- nnscaler/codegen/__init__.py | 2 + {cube => nnscaler}/codegen/emit.py | 16 ++--- .../codegen/frontend_mapping.py | 6 +- {cube => nnscaler}/codegen/lifecycle.py | 10 +-- {cube => nnscaler}/codegen/module/__init__.py | 0 {cube => nnscaler}/codegen/module/autograd.py | 10 +-- {cube => nnscaler}/codegen/module/module.py | 40 +++++------ .../codegen/schedule/__init__.py | 0 .../codegen/schedule/schedule.py | 28 ++++---- {cube => nnscaler}/codegen/syntax/__init__.py | 0 {cube => nnscaler}/codegen/syntax/blocks.py | 0 {cube => nnscaler}/codegen/syntax/symtable.py | 0 {cube => nnscaler}/compiler.py | 52 +++++++------- nnscaler/execplan/__init__.py | 1 + {cube => nnscaler}/execplan/execplan.py | 12 ++-- .../execplan/planpass/__init__.py | 0 .../execplan/planpass/fusion.py | 24 +++---- .../execplan/planpass/grouping.py | 16 ++--- .../execplan/planpass/planpass.py | 2 +- {cube => nnscaler}/flags.py | 0 nnscaler/graph/__init__.py | 2 + nnscaler/graph/function/__init__.py | 2 + {cube => nnscaler}/graph/function/anchor.py | 6 +- {cube => nnscaler}/graph/function/conv.py | 8 +-- {cube => nnscaler}/graph/function/dimops.py | 6 +- {cube => nnscaler}/graph/function/function.py | 70 +++++++++---------- {cube => nnscaler}/graph/function/pyfunc.py | 4 +- {cube => nnscaler}/graph/gener/__init__.py | 0 {cube => nnscaler}/graph/gener/concurrent.py | 20 +++--- {cube => nnscaler}/graph/gener/gen.py | 26 +++---- .../graph/gener/rvd/__init__.py | 0 {cube => nnscaler}/graph/gener/rvd/inter.py | 18 ++--- {cube => nnscaler}/graph/gener/rvd/intra.py | 32 ++++----- {cube => nnscaler}/graph/gener/rvd/layout.py | 6 +- {cube => nnscaler}/graph/gener/utils.py | 8 +-- {cube => nnscaler}/graph/graph.py | 26 +++---- nnscaler/graph/parser/__init__.py | 4 ++ {cube => nnscaler}/graph/parser/converter.py | 18 ++--- .../graph/parser/external/__init__.py | 0 .../graph/parser/external/apex.py | 4 +- {cube => nnscaler}/graph/parser/frame.py | 2 +- .../fx/concrete_trace_utils/__init__.py | 0 .../fx/concrete_trace_utils/concrete_proxy.py | 0 .../concrete_trace_utils/concrete_tracer.py | 0 .../concrete_trace_utils/function_patcher.py | 0 .../concrete_trace_utils/operator_patcher.py | 2 +- .../parser/fx/concrete_trace_utils/utils.py | 2 +- {cube => nnscaler}/graph/parser/fx/mapping.py | 8 +-- {cube => nnscaler}/graph/parser/fx/parser.py | 18 ++--- {cube => nnscaler}/graph/parser/register.py | 20 +++--- nnscaler/graph/schedule/__init__.py | 1 + .../graph/schedule/predefined.py | 6 +- .../graph/schedule/schedplan.py | 16 ++--- {cube => nnscaler}/graph/segment.py | 12 ++-- nnscaler/ir/__init__.py | 5 ++ nnscaler/ir/adapter/__init__.py | 1 + {cube => nnscaler}/ir/adapter/adapter.py | 6 +- {cube => nnscaler}/ir/adapter/prim.py | 52 +++++++------- {cube => nnscaler}/ir/cten.py | 6 +- {cube => nnscaler}/ir/dtype.py | 0 {cube => nnscaler}/ir/operator.py | 10 +-- {cube => nnscaler}/ir/tensor.py | 2 +- {cube => nnscaler}/ir/unique.py | 0 {cube => nnscaler}/parallel.py | 64 ++++++++--------- {cube => nnscaler}/profiler/README.md | 4 +- nnscaler/profiler/__init__.py | 2 + {cube => nnscaler}/profiler/database.py | 10 +-- {cube => nnscaler}/profiler/estimator.py | 8 +-- {cube => nnscaler}/profiler/memory.py | 2 +- {cube => nnscaler}/profiler/timer.py | 2 +- {cube => nnscaler}/program.py | 18 ++--- nnscaler/runtime/__init__.py | 6 ++ nnscaler/runtime/adapter/__init__.py | 4 ++ .../runtime/adapter/collectives.py | 6 +- {cube => nnscaler}/runtime/adapter/nn.py | 4 +- {cube => nnscaler}/runtime/adapter/reducer.py | 8 +-- .../runtime/adapter/transform.py | 0 {cube => nnscaler}/runtime/device.py | 2 +- {cube => nnscaler}/runtime/executor.py | 0 nnscaler/runtime/function/__init__.py | 1 + .../runtime/function/function.py | 0 {cube => nnscaler}/runtime/gnorm.py | 4 +- {cube => nnscaler}/runtime/module.py | 26 +++---- {cube => nnscaler}/runtime/resource.py | 2 +- {cube => nnscaler}/runtime/utils.py | 2 +- {cube => nnscaler}/utils.py | 16 ++--- setup.py | 6 +- tests/algorithm/ops/test_dimops.py | 6 +- tests/codegen/test_emit.py | 10 +-- tests/compiler/test_compile.py | 18 ++--- tests/conftest.py | 2 +- tests/graph/function/test_dataloader.py | 6 +- tests/graph/function/test_dimops.py | 8 +-- tests/graph/function/test_functions.py | 10 +-- tests/graph/gener/check_inter_rvd.py | 12 ++-- tests/graph/gener/check_intra_rvd.py | 12 ++-- tests/graph/gener/test_reducer_gen.py | 16 ++--- tests/graph/parser/test_ast_transformer.py | 2 +- tests/graph/parser/test_converter.py | 12 ++-- tests/graph/parser/test_dce.py | 2 +- tests/graph/parser/test_ir_obj_constant.py | 2 +- tests/graph/parser/test_no_grad.py | 2 +- tests/graph/parser/test_parser.py | 4 +- tests/graph/parser/test_register.py | 14 ++-- tests/graph/parser/test_register_external.py | 6 +- tests/graph/test_graph.py | 6 +- tests/graph/test_multiref.py | 14 ++-- tests/graph/tracer/test_cls_wrapper.py | 2 +- tests/graph/tracer/test_getattr.py | 2 +- tests/graph/tracer/test_inplace.py | 6 +- tests/graph/tracer/test_namedtuple.py | 2 +- tests/graph/tracer/test_op_patcher.py | 2 +- tests/graph/tracer/test_pytree.py | 2 +- tests/graph/tracer/test_scope.py | 2 +- tests/ir/tensor.py | 2 +- tests/parallel_module/common.py | 14 ++-- tests/parallel_module/test_broadcast.py | 2 +- tests/parallel_module/test_checkpoint.py | 6 +- .../parallel_module/test_checkpoint_buffer.py | 2 +- .../parallel_module/test_checkpoint_dedup.py | 4 +- .../parallel_module/test_checkpoint_shared.py | 2 +- .../parallel_module/test_checkpoint_unused.py | 6 +- tests/parallel_module/test_ddp.py | 6 +- tests/parallel_module/test_end2end.py | 10 +-- tests/parallel_module/test_gencode.py | 8 +-- tests/parallel_module/test_inference.py | 2 +- tests/parallel_module/test_init.py | 2 +- tests/parallel_module/test_nested.py | 2 +- tests/parallel_module/test_override.py | 6 +- tests/parallel_module/test_reducer_hook.py | 4 +- tests/parallel_module/test_scale_grads.py | 6 +- tests/parallel_module/test_submodule.py | 4 +- tests/parallel_module/test_wholemodule.py | 4 +- tests/profiler/test_op_profile.py | 8 +-- tests/runtime/test_dataloader.py | 2 +- tests/runtime/test_gnorm.py | 18 ++--- tests/runtime/test_grad_accum.py | 10 +-- tests/runtime/test_module_merge.py | 18 ++--- tests/runtime/test_reducer.py | 14 ++-- tests/runtime/test_runtime_collectives.py | 34 ++++----- tests/test_program.py | 6 +- tests/utils.py | 6 +- tox.ini | 2 +- tutorial.md | 4 +- utility/test_rvd_prim.py | 10 +-- 214 files changed, 953 insertions(+), 953 deletions(-) delete mode 100644 cube/codegen/__init__.py delete mode 100644 cube/execplan/__init__.py delete mode 100644 cube/graph/__init__.py delete mode 100644 cube/graph/function/__init__.py delete mode 100644 cube/graph/parser/__init__.py delete mode 100644 cube/graph/schedule/__init__.py delete mode 100644 cube/ir/__init__.py delete mode 100644 cube/ir/adapter/__init__.py delete mode 100644 cube/profiler/__init__.py delete mode 100644 cube/runtime/__init__.py delete mode 100644 cube/runtime/adapter/__init__.py delete mode 100644 cube/runtime/function/__init__.py rename {cube => nnscaler}/__init__.py (68%) rename {cube => nnscaler}/algorithm/__init__.py (100%) rename {cube => nnscaler}/algorithm/factory.py (95%) rename {cube => nnscaler}/algorithm/generics.py (97%) rename {cube => nnscaler}/algorithm/ops/__init__.py (100%) rename {cube => nnscaler}/algorithm/ops/conv.py (98%) rename {cube => nnscaler}/algorithm/ops/dimops.py (98%) create mode 100644 nnscaler/codegen/__init__.py rename {cube => nnscaler}/codegen/emit.py (96%) rename {cube => nnscaler}/codegen/frontend_mapping.py (95%) rename {cube => nnscaler}/codegen/lifecycle.py (94%) rename {cube => nnscaler}/codegen/module/__init__.py (100%) rename {cube => nnscaler}/codegen/module/autograd.py (90%) rename {cube => nnscaler}/codegen/module/module.py (96%) rename {cube => nnscaler}/codegen/schedule/__init__.py (100%) rename {cube => nnscaler}/codegen/schedule/schedule.py (88%) rename {cube => nnscaler}/codegen/syntax/__init__.py (100%) rename {cube => nnscaler}/codegen/syntax/blocks.py (100%) rename {cube => nnscaler}/codegen/syntax/symtable.py (100%) rename {cube => nnscaler}/compiler.py (89%) create mode 100644 nnscaler/execplan/__init__.py rename {cube => nnscaler}/execplan/execplan.py (98%) rename {cube => nnscaler}/execplan/planpass/__init__.py (100%) rename {cube => nnscaler}/execplan/planpass/fusion.py (86%) rename {cube => nnscaler}/execplan/planpass/grouping.py (89%) rename {cube => nnscaler}/execplan/planpass/planpass.py (74%) rename {cube => nnscaler}/flags.py (100%) create mode 100644 nnscaler/graph/__init__.py create mode 100644 nnscaler/graph/function/__init__.py rename {cube => nnscaler}/graph/function/anchor.py (89%) rename {cube => nnscaler}/graph/function/conv.py (96%) rename {cube => nnscaler}/graph/function/dimops.py (99%) rename {cube => nnscaler}/graph/function/function.py (98%) rename {cube => nnscaler}/graph/function/pyfunc.py (92%) rename {cube => nnscaler}/graph/gener/__init__.py (100%) rename {cube => nnscaler}/graph/gener/concurrent.py (97%) rename {cube => nnscaler}/graph/gener/gen.py (98%) rename {cube => nnscaler}/graph/gener/rvd/__init__.py (100%) rename {cube => nnscaler}/graph/gener/rvd/inter.py (97%) rename {cube => nnscaler}/graph/gener/rvd/intra.py (97%) rename {cube => nnscaler}/graph/gener/rvd/layout.py (98%) rename {cube => nnscaler}/graph/gener/utils.py (96%) rename {cube => nnscaler}/graph/graph.py (98%) create mode 100644 nnscaler/graph/parser/__init__.py rename {cube => nnscaler}/graph/parser/converter.py (90%) rename {cube => nnscaler}/graph/parser/external/__init__.py (100%) rename {cube => nnscaler}/graph/parser/external/apex.py (97%) rename {cube => nnscaler}/graph/parser/frame.py (99%) rename {cube => nnscaler}/graph/parser/fx/concrete_trace_utils/__init__.py (100%) rename {cube => nnscaler}/graph/parser/fx/concrete_trace_utils/concrete_proxy.py (100%) rename {cube => nnscaler}/graph/parser/fx/concrete_trace_utils/concrete_tracer.py (100%) rename {cube => nnscaler}/graph/parser/fx/concrete_trace_utils/function_patcher.py (100%) rename {cube => nnscaler}/graph/parser/fx/concrete_trace_utils/operator_patcher.py (99%) rename {cube => nnscaler}/graph/parser/fx/concrete_trace_utils/utils.py (99%) rename {cube => nnscaler}/graph/parser/fx/mapping.py (97%) rename {cube => nnscaler}/graph/parser/fx/parser.py (97%) rename {cube => nnscaler}/graph/parser/register.py (93%) create mode 100644 nnscaler/graph/schedule/__init__.py rename {cube => nnscaler}/graph/schedule/predefined.py (98%) rename {cube => nnscaler}/graph/schedule/schedplan.py (98%) rename {cube => nnscaler}/graph/segment.py (99%) create mode 100644 nnscaler/ir/__init__.py create mode 100644 nnscaler/ir/adapter/__init__.py rename {cube => nnscaler}/ir/adapter/adapter.py (97%) rename {cube => nnscaler}/ir/adapter/prim.py (92%) rename {cube => nnscaler}/ir/cten.py (99%) rename {cube => nnscaler}/ir/dtype.py (100%) rename {cube => nnscaler}/ir/operator.py (96%) rename {cube => nnscaler}/ir/tensor.py (99%) rename {cube => nnscaler}/ir/unique.py (100%) rename {cube => nnscaler}/parallel.py (98%) rename {cube => nnscaler}/profiler/README.md (95%) create mode 100644 nnscaler/profiler/__init__.py rename {cube => nnscaler}/profiler/database.py (98%) rename {cube => nnscaler}/profiler/estimator.py (90%) rename {cube => nnscaler}/profiler/memory.py (98%) rename {cube => nnscaler}/profiler/timer.py (99%) rename {cube => nnscaler}/program.py (95%) create mode 100644 nnscaler/runtime/__init__.py create mode 100644 nnscaler/runtime/adapter/__init__.py rename {cube => nnscaler}/runtime/adapter/collectives.py (98%) rename {cube => nnscaler}/runtime/adapter/nn.py (98%) rename {cube => nnscaler}/runtime/adapter/reducer.py (99%) rename {cube => nnscaler}/runtime/adapter/transform.py (100%) rename {cube => nnscaler}/runtime/device.py (99%) rename {cube => nnscaler}/runtime/executor.py (100%) create mode 100644 nnscaler/runtime/function/__init__.py rename {cube => nnscaler}/runtime/function/function.py (100%) rename {cube => nnscaler}/runtime/gnorm.py (99%) rename {cube => nnscaler}/runtime/module.py (98%) rename {cube => nnscaler}/runtime/resource.py (97%) rename {cube => nnscaler}/runtime/utils.py (98%) rename {cube => nnscaler}/utils.py (94%) diff --git a/cube/codegen/__init__.py b/cube/codegen/__init__.py deleted file mode 100644 index 84f4bcd6..00000000 --- a/cube/codegen/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from cube.codegen.module.module import ModuleCodeGen -from cube.codegen.schedule.schedule import ScheduleCodeGen diff --git a/cube/execplan/__init__.py b/cube/execplan/__init__.py deleted file mode 100644 index c6d0899c..00000000 --- a/cube/execplan/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cube.execplan.execplan import ExecutionPlan \ No newline at end of file diff --git a/cube/graph/__init__.py b/cube/graph/__init__.py deleted file mode 100644 index ec86b08f..00000000 --- a/cube/graph/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from cube.graph.graph import IRGraph -from cube.graph import parser diff --git a/cube/graph/function/__init__.py b/cube/graph/function/__init__.py deleted file mode 100644 index fc28ba75..00000000 --- a/cube/graph/function/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from cube.graph.function.dimops import IRDimops -from cube.graph.function.function import * \ No newline at end of file diff --git a/cube/graph/parser/__init__.py b/cube/graph/parser/__init__.py deleted file mode 100644 index ba172d07..00000000 --- a/cube/graph/parser/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from cube.graph.parser.fx.parser import FxModuleParser -from cube.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph -from cube.graph.parser.register import register -from cube.graph.parser.external import * diff --git a/cube/graph/schedule/__init__.py b/cube/graph/schedule/__init__.py deleted file mode 100644 index 4c5f2f80..00000000 --- a/cube/graph/schedule/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cube.graph.schedule.schedplan import SchedulePlan diff --git a/cube/ir/__init__.py b/cube/ir/__init__.py deleted file mode 100644 index 23ad9584..00000000 --- a/cube/ir/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from cube.ir.dtype import * -from cube.ir.cten import IRTensor, IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor -from cube.ir.operator import IRFwOperation, IRBpOperation, IRDataOperation -from cube.ir.adapter.adapter import IRAdapter diff --git a/cube/ir/adapter/__init__.py b/cube/ir/adapter/__init__.py deleted file mode 100644 index 553b5db3..00000000 --- a/cube/ir/adapter/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cube.ir.adapter.adapter import IRAdapter, IRWeightReducer diff --git a/cube/profiler/__init__.py b/cube/profiler/__init__.py deleted file mode 100644 index 6bf47044..00000000 --- a/cube/profiler/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from cube.profiler.timer import CudaTimer -from cube.profiler.database import ProfileDataBase diff --git a/cube/runtime/__init__.py b/cube/runtime/__init__.py deleted file mode 100644 index 4fea0597..00000000 --- a/cube/runtime/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from cube.runtime import executor -from cube.runtime import device -from cube.runtime import adapter -from cube.runtime import resource -from cube.runtime import module -from cube.runtime import function diff --git a/cube/runtime/adapter/__init__.py b/cube/runtime/adapter/__init__.py deleted file mode 100644 index 6eb54da8..00000000 --- a/cube/runtime/adapter/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from cube.runtime.adapter.collectives import * -from cube.runtime.adapter.transform import * -from cube.runtime.adapter import nn -from cube.runtime.adapter.reducer import Reducer diff --git a/cube/runtime/function/__init__.py b/cube/runtime/function/__init__.py deleted file mode 100644 index c5b9ae13..00000000 --- a/cube/runtime/function/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cube.runtime.function.function import * \ No newline at end of file diff --git a/docs/parallel_module.md b/docs/parallel_module.md index 7be5686b..c35f7570 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -18,7 +18,7 @@ The above restrictions are necessary for the pipeline parallelism to work. Of co ```python import torch -from cube.parallel import parallelize, ComputeConfig, build_optimizer +from nnscaler.parallel import parallelize, ComputeConfig, build_optimizer class LLM(torch.nn.Module): def __init__(self, ...): @@ -48,7 +48,7 @@ In this case, for non-paralle modules, they are replicated inside unit, and run ```python import torch -from cube.parallel import parallelize, ComputeConfig, build_optimizer +from nnscaler.parallel import parallelize, ComputeConfig, build_optimizer class HeavyModule(torch.nn.Module): def __init__(self, ...): @@ -414,7 +414,7 @@ the generated code in outdir will be removed EVEN IF the code generation fails i After the module is converted, you can use it to create module object by calling it like a module class. The module class is defined like: ```python -class GenModule(cube.runtime.module.ParallelModule): +class GenModule(nnscaler.runtime.module.ParallelModule): def __init__(self, init_params=True): super().__init__() ... diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 5936be61..3b68484a 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -1,17 +1,17 @@ import torch import math -import cube +import nnscaler -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank +from nnscaler.profiler import CudaTimer +from nnscaler.profiler.timer import print_each_rank from examples.alphafold2.model import * import examples.alphafold2.policy.spmd as spmd -from cube.ir.operator import IRFwOperation, IRBpOperation -from cube.profiler.database import ProfileDataBase -from cube.algorithm.ops.dimops import gen_partitions -from cube.graph.function.anchor import IRGraphAnchor +from nnscaler.ir.operator import IRFwOperation, IRBpOperation +from nnscaler.profiler.database import ProfileDataBase +from nnscaler.algorithm.ops.dimops import gen_partitions +from nnscaler.graph.function.anchor import IRGraphAnchor @@ -44,15 +44,15 @@ def run(size_config, other_config, policy): if not is_train: model.eval() - model = cube.SemanticModel(model, + model = nnscaler.SemanticModel(model, input_shapes=([bs, s, r, cm], [bs, r, r, cz])) - dataloader = cube.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], + dataloader = nnscaler.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], [bs, r, r, cz]), dtypes=(dtype, dtype), batch_dims=(0, 0)) - @cube.compile(model, dataloader, PAS=policy, override=True) + @nnscaler.compile(model, dataloader, PAS=policy, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) @@ -125,5 +125,5 @@ def test_main(): if __name__ == '__main__': - cube.init() + nnscaler.init() test_main() diff --git a/examples/alphafold2/model.py b/examples/alphafold2/model.py index b78f4ded..c2109c91 100644 --- a/examples/alphafold2/model.py +++ b/examples/alphafold2/model.py @@ -1,4 +1,4 @@ -import cube +import nnscaler import torch import math from torch import nn @@ -155,7 +155,7 @@ def __init__(self, torch.randn(ff_mult * cz, cz)) def forward(self, msa_repr, pair_repr): - cube.runtime.function.anchor('MSARow') + nnscaler.runtime.function.anchor('MSARow') pair_repr, dummy_pair_repr = multi2ref(pair_repr) msa_repr = msa_repr + MSARowAttentionWithPairBias( @@ -164,7 +164,7 @@ def forward(self, msa_repr, pair_repr): self.msa_head, self.c, self.scale, self.msa_row_chunk, self.is_train) - cube.runtime.function.anchor('MSACol') + nnscaler.runtime.function.anchor('MSACol') if self.is_extra: msa_repr = msa_repr + MSAColGlobalAttention( self.col_norm(msa_repr), self.col_q_proj, self.col_k_proj, @@ -176,13 +176,13 @@ def forward(self, msa_repr, pair_repr): self.col_out_proj, self.msa_head, self.c, self.scale, self.msa_col_chunk, self.is_train) - cube.runtime.function.anchor('MSATrans') + nnscaler.runtime.function.anchor('MSATrans') msa_repr = msa_repr + MSATransition(self.msa_transition_norm(msa_repr), self.msa_transition_proj1, self.msa_transition_proj2) succ_msa_repr, msa_repr = multi2ref(msa_repr) - cube.runtime.function.anchor('OPM') + nnscaler.runtime.function.anchor('OPM') msa_repr = self.outer_norm(msa_repr) opm_left, opm_right = OPMLeftProj(msa_repr, self.outer_proj1), OPMRightProj( @@ -191,7 +191,7 @@ def forward(self, msa_repr, pair_repr): opm_left, opm_right, self.outer_out_proj, self.opm_chunk, self.is_train) - cube.runtime.function.anchor('TMO') + nnscaler.runtime.function.anchor('TMO') pair_repr = self.tri_mul_out_norm1(pair_repr) tmo_left, tmo_right = TMOLeftProj( pair_repr, self.tri_mul_out_proj1, @@ -203,7 +203,7 @@ def forward(self, msa_repr, pair_repr): tmo_left, tmo_right, tmo_g, self.tri_mul_out_norm2_weight, self.tri_mul_out_norm2_bias, self.tri_mul_out_proj5, self.cz) - cube.runtime.function.anchor('TMI') + nnscaler.runtime.function.anchor('TMI') pair_repr = self.tri_mul_in_norm1(pair_repr) tmi_left = TMILeftProj(pair_repr, self.tri_mul_in_proj1, self.tri_mul_in_proj2) @@ -214,7 +214,7 @@ def forward(self, msa_repr, pair_repr): tmi_left, tmi_right, tmi_gate, self.tri_mul_in_norm2_weight, self.tri_mul_in_norm2_bias, self.tri_mul_in_proj5, self.cz) - cube.runtime.function.anchor('TANS') + nnscaler.runtime.function.anchor('TANS') pair_repr = self.tri_att_start_norm(pair_repr) bias = TANSBias(pair_repr, self.tri_att_start_bias_proj) pair_repr = pair_repr + TriangleAttentionNodeStart( @@ -222,7 +222,7 @@ def forward(self, msa_repr, pair_repr): self.tri_att_start_qkv_proj, self.tri_att_start_out_proj, bias, self.pair_head, self.c, self.scale, self.tans_chunk, self.is_train) - cube.runtime.function.anchor('TANE') + nnscaler.runtime.function.anchor('TANE') pair_repr = self.tri_att_end_norm(pair_repr) bias = TANEBias(pair_repr, self.tri_att_end_bias_proj) pair_repr = pair_repr + TriangleAttentionNodeEnd( @@ -230,7 +230,7 @@ def forward(self, msa_repr, pair_repr): self.tri_att_end_out_proj, bias, self.pair_head, self.c, self.scale, self.tane_chunk, self.is_train) - cube.runtime.function.anchor('PairTrans') + nnscaler.runtime.function.anchor('PairTrans') pair_repr = pair_repr + PairTransition( self.pair_transition_norm(pair_repr), self.pair_transition_proj1, self.pair_transition_proj2) @@ -266,12 +266,12 @@ def forward(self, msa, pair): msa = self.msa_norm(msa) pair = self.pair_norm(pair) - cube.runtime.function.anchor('Evoformer Stack Start') + nnscaler.runtime.function.anchor('Evoformer Stack Start') for evoformer in self.evoformers: - cube.runtime.function.anchor('One Layer Evoformer Start') + nnscaler.runtime.function.anchor('One Layer Evoformer Start') msa, pair = evoformer(msa, pair) - cube.runtime.function.anchor('One Layer Evoformer End') - cube.runtime.function.anchor('Evoformer Stack End') + nnscaler.runtime.function.anchor('One Layer Evoformer End') + nnscaler.runtime.function.anchor('Evoformer Stack End') loss = torch.sum(msa) * torch.sum(pair) return loss @@ -296,11 +296,11 @@ def forward(self, msa, pair): msa = self.msa_norm(msa) pair = self.pair_norm(pair) - cube.runtime.function.anchor('Evoformer Stack Start') + nnscaler.runtime.function.anchor('Evoformer Stack Start') for evoformer in self.evoformers: - cube.runtime.function.anchor('One Layer Evoformer Start') + nnscaler.runtime.function.anchor('One Layer Evoformer Start') msa, pair = evoformer(msa, pair) - cube.runtime.function.anchor('One Layer Evoformer End') - cube.runtime.function.anchor('Evoformer Stack End') + nnscaler.runtime.function.anchor('One Layer Evoformer End') + nnscaler.runtime.function.anchor('Evoformer Stack End') loss = torch.sum(msa) * torch.sum(pair) return loss \ No newline at end of file diff --git a/examples/alphafold2/module.py b/examples/alphafold2/module.py index 09de7d1f..2e20fc93 100644 --- a/examples/alphafold2/module.py +++ b/examples/alphafold2/module.py @@ -1,9 +1,9 @@ -import cube +import nnscaler import torch import torch.utils.checkpoint as ckpt -@cube.graph.parser.register('*, *, * -> *, *, *, *', name='calc_qkvg') +@nnscaler.graph.parser.register('*, *, * -> *, *, *, *', name='calc_qkvg') def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, bs: int, s: int, r: int, head: int, c: int): gate = torch.sigmoid(torch.matmul(x, gate_proj)) @@ -23,7 +23,7 @@ def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, """ -@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R^ M^', +@nnscaler.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R^ M^', name='MSAAttention') @torch.jit.ignore def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, @@ -91,7 +91,7 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return out -@cube.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^, N 1^ 8^ R^ R^ -> N S R^ M^', +@nnscaler.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^, N 1^ 8^ R^ R^ -> N S R^ M^', name='MSAAttentionWithBias') @torch.jit.ignore def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, @@ -177,7 +177,7 @@ def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # note: code not reused constrained by cube's interface -@cube.graph.parser.register('N S R^ M^, N R^ R^ Z^, M^ E^, M^ F^, E^ M^, Z^ H^ -> N S R^ M^', +@nnscaler.graph.parser.register('N S R^ M^, N R^ R^ Z^, M^ E^, M^ F^, E^ M^, Z^ H^ -> N S R^ M^', name='MSARowAttentionWithPairBias') def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, pair_repr: torch.Tensor, @@ -196,7 +196,7 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, head, c, scale, chunk_size, is_train) -@cube.graph.parser.register('N S^ R M^, M^ E^, M^ F^, E^ M^ -> N S^ R M^', +@nnscaler.graph.parser.register('N S^ R M^, M^ E^, M^ F^, E^ M^ -> N S^ R M^', name='MSAColAttention') def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, @@ -207,7 +207,7 @@ def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, is_train).permute(0, 2, 1, 3) -@cube.graph.parser.register('N S^ R^ M^, M^ M^, M^ E^, M^ E^, M^ M^, M^ M^ -> N S^ R^ M^', +@nnscaler.graph.parser.register('N S^ R^ M^, M^ M^, M^ E^, M^ E^, M^ M^, M^ M^ -> N S^ R^ M^', name='MSAColGlobalAttention') def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor, @@ -250,7 +250,7 @@ def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, """ -@cube.graph.parser.register('N S R M^, M^ E^, E^ M^ -> N S R M^', +@nnscaler.graph.parser.register('N S R M^, M^ E^, E^ M^ -> N S R M^', name='MSATransition') def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -258,12 +258,12 @@ def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, torch.nn.functional.relu(torch.matmul(msa_repr, proj1)), proj2) -@cube.graph.parser.register('N S R M^, M^ C^ -> N S R C^', name='OPMLeftProj') +@nnscaler.graph.parser.register('N S R M^, M^ C^ -> N S R C^', name='OPMLeftProj') def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): return torch.matmul(msa_repr, proj) -@cube.graph.parser.register('N S R M^, M^ C^ -> N S R C^', name='OPMRightProj') +@nnscaler.graph.parser.register('N S R M^, M^ C^ -> N S R C^', name='OPMRightProj') def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): return torch.matmul(msa_repr, proj) @@ -273,7 +273,7 @@ def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): """ -@cube.graph.parser.register('N S^ R M^, N S^ T^ M^, F^ Z^ -> N R^ T Z^', +@nnscaler.graph.parser.register('N S^ R M^, N S^ T^ M^, F^ Z^ -> N R^ T Z^', name='OuterProductMean') @torch.jit.ignore def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, @@ -308,7 +308,7 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): return outer -@cube.graph.parser.register('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMOLeftProj') +@nnscaler.graph.parser.register('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMOLeftProj') def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): a = torch.sigmoid(torch.matmul(pair_repr, proj1)) @@ -316,7 +316,7 @@ def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return b -@cube.graph.parser.register('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', +@nnscaler.graph.parser.register('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMORightProj') def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -325,12 +325,12 @@ def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return b -@cube.graph.parser.register('N S T^ Z^, Z^ Z^ -> N S T^ Z^', name='TMOGate') +@nnscaler.graph.parser.register('N S T^ Z^, Z^ Z^ -> N S T^ Z^', name='TMOGate') def TMOGate(pair_repr: torch.Tensor, proj: torch.Tensor): return torch.sigmoid(torch.matmul(pair_repr, proj)) -@cube.graph.parser.register('N S R^ E^, N T^ R^ E^, N S T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', +@nnscaler.graph.parser.register('N S R^ E^, N T^ R^ E^, N S T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='TriangleMultiplicationOut') def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, @@ -347,7 +347,7 @@ def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, return p * g -@cube.graph.parser.register('N R^ S Z^, Z^ E^, Z^ E^ -> N R^ S E^', name='TMILeftProj') +@nnscaler.graph.parser.register('N R^ S Z^, Z^ E^, Z^ E^ -> N R^ S E^', name='TMILeftProj') def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): a = torch.sigmoid(torch.matmul(pair_repr, proj1)) @@ -355,7 +355,7 @@ def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return a -@cube.graph.parser.register('N R^ T Z^, Z^ E^, Z^ E^ -> N R^ T E^', +@nnscaler.graph.parser.register('N R^ T Z^, Z^ E^, Z^ E^ -> N R^ T E^', name='TMIRightProj') def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -364,12 +364,12 @@ def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return a -@cube.graph.parser.register('N S^ T Z^, Z^ Z^ -> N S^ T Z^', name='TMIGate') +@nnscaler.graph.parser.register('N S^ T Z^, Z^ Z^ -> N S^ T Z^', name='TMIGate') def TMIGate(pair_repr: torch.Tensor, proj: torch.Tensor): return torch.sigmoid(torch.matmul(pair_repr, proj)) -@cube.graph.parser.register('N R^ S E^, N R^ T^ E^, N T^ S Z^, E^, E^, E^ Z^ -> N T^ S Z^', +@nnscaler.graph.parser.register('N R^ S E^, N R^ T^ E^, N T^ S Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='TriangleMultiplicationIn') def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, tri_mul_norm2_weight: torch.Tensor, @@ -385,12 +385,12 @@ def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, return p.permute(0, 2, 1, 3) * g -@cube.graph.parser.register('N S R^ C^, C^ D^ -> N S R^ D^', name='TANSBias') +@nnscaler.graph.parser.register('N S R^ C^, C^ D^ -> N S R^ D^', name='TANSBias') def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): return torch.matmul(pair_repr, bias_proj) -@cube.graph.parser.register('N S R^ Z^, Z^ E^, Z^ F^, E^ Z^, N T^ R^ G^ -> N S R^ Z^', +@nnscaler.graph.parser.register('N S R^ Z^, Z^ E^, Z^ F^, E^ Z^, N T^ R^ G^ -> N S R^ Z^', name='TriangleAttentionNodeStart') def TriangleAttentionNodeStart(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, @@ -404,12 +404,12 @@ def TriangleAttentionNodeStart(pair_repr: torch.Tensor, head, c, scale, chunk_size, is_train) -@cube.graph.parser.register('N S^ R C^, C^ D^ -> N S^ R D^', name='TANEBias') +@nnscaler.graph.parser.register('N S^ R C^, C^ D^ -> N S^ R D^', name='TANEBias') def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): return torch.matmul(pair_repr, bias_proj) -@cube.graph.parser.register('N R^ S Z^, Z^ E^, Z^ F^, E^ Z^, N R^ T^ G^ -> N R^ S Z^', +@nnscaler.graph.parser.register('N R^ S Z^, Z^ E^, Z^ F^, E^ Z^, N R^ T^ G^ -> N R^ S Z^', name='TriangleAttentionNodeEnd') def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, @@ -424,7 +424,7 @@ def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, return out.permute(0, 2, 1, 3) -@cube.graph.parser.register('N R T^ Z^, Z^ E^, E^ Z^ -> N R T^ Z^', +@nnscaler.graph.parser.register('N R T^ Z^, Z^ E^, E^ Z^ -> N R T^ Z^', name='PairTransition') def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -432,6 +432,6 @@ def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, torch.nn.functional.relu(torch.matmul(pair_repr, proj1)), proj2) -@cube.graph.parser.register('* -> *, *', name='multi2ref') +@nnscaler.graph.parser.register('* -> *, *', name='multi2ref') def multi2ref(x: torch.Tensor): return (x, x) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py index 24d07f1c..1cc25fd0 100644 --- a/examples/alphafold2/policy/spmd.py +++ b/examples/alphafold2/policy/spmd.py @@ -1,9 +1,9 @@ from typing import List from numpy import TooHardError -from cube.graph import IRGraph -from cube.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation -from cube.graph.function.anchor import IRGraphAnchor +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation +from nnscaler.graph.function.anchor import IRGraphAnchor def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): diff --git a/examples/llama/chat.py b/examples/llama/chat.py index e89b6493..a9de661b 100644 --- a/examples/llama/chat.py +++ b/examples/llama/chat.py @@ -17,11 +17,11 @@ from examples.llama.generation import Llama -import cube +import nnscaler -cube.init() -cube.set_logger_level(level=logging.WARNING) -logging.getLogger('cube.compiler').setLevel(logging.INFO) +nnscaler.init() +nnscaler.set_logger_level(level=logging.WARNING) +logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) def main( diff --git a/examples/llama/generation.py b/examples/llama/generation.py index 3594ebd4..92558df8 100644 --- a/examples/llama/generation.py +++ b/examples/llama/generation.py @@ -16,8 +16,8 @@ Role = Literal["system", "user", "assistant"] -import cube -from cube.flags import CompileFlag +import nnscaler +from nnscaler.flags import CompileFlag class Message(TypedDict): @@ -108,12 +108,12 @@ def build_inference(self): 1, 1000, size=(4, 38), dtype=torch.int64) def policy(graph, resource): - from cube.ir.operator import IRFwOperation + from nnscaler.ir.operator import IRFwOperation for fwop in graph.select(ntype=IRFwOperation): graph.assign(fwop, 0) return graph - @cube.compile(self.model, sample_tokens, 0, + @nnscaler.compile(self.model, sample_tokens, 0, PAS=policy, model_dynamic_shape=True) def infer(model: torch.nn.Module, tokens: torch.Tensor, prev_pos: int): logits = model(tokens, prev_pos) @@ -123,7 +123,7 @@ def infer(model: torch.nn.Module, tokens: torch.Tensor, prev_pos: int): vocab_size, n_layers = params.vocab_size, params.n_layers del self.model - self.model = cube.load_model() + self.model = nnscaler.load_model() # TODO: support auto reset non-parameter attributes for llama model self.model.params = params diff --git a/examples/llama/model.py b/examples/llama/model.py index 646d13be..846ad825 100644 --- a/examples/llama/model.py +++ b/examples/llama/model.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from torch import nn -import cube +import nnscaler @dataclass class ModelArgs: @@ -58,7 +58,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): # TODO: fix annotation -@cube.graph.parser.register('*, *, 38^ 64^ -> *, *') +@nnscaler.graph.parser.register('*, *, 38^ 64^ -> *, *') def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, @@ -72,7 +72,7 @@ def apply_rotary_emb( return xq_out.type_as(xq), xk_out.type_as(xk) -@cube.graph.parser.register('N seqlen^, N seqlen^ H^ -> 1 1 seqlen^ seqlen^') +@nnscaler.graph.parser.register('N seqlen^, N seqlen^ H^ -> 1 1 seqlen^ seqlen^') def create_mask(tokens: torch.Tensor, h: torch.Tensor, start_pos: int): seqlen = tokens.shape[1] mask = None @@ -84,7 +84,7 @@ def create_mask(tokens: torch.Tensor, h: torch.Tensor, start_pos: int): return mask -@cube.graph.parser.register('N seqlen *, 1 1 * -> N seqlen *') +@nnscaler.graph.parser.register('N seqlen *, 1 1 * -> N seqlen *') def apply_mask(x: torch.Tensor, mask: torch.Tensor): return x if mask is None else x + mask diff --git a/examples/llama/test_chat_completion.py b/examples/llama/test_chat_completion.py index 1a95c739..0d98e664 100644 --- a/examples/llama/test_chat_completion.py +++ b/examples/llama/test_chat_completion.py @@ -17,11 +17,11 @@ from examples.llama.generation import Llama -import cube +import nnscaler -cube.init() -cube.set_logger_level(level=logging.WARNING) -logging.getLogger('cube.compiler').setLevel(logging.INFO) +nnscaler.init() +nnscaler.set_logger_level(level=logging.WARNING) +logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) def main( diff --git a/examples/megatron_gpt/convert.py b/examples/megatron_gpt/convert.py index dc0f3dfb..20c30c64 100644 --- a/examples/megatron_gpt/convert.py +++ b/examples/megatron_gpt/convert.py @@ -3,11 +3,11 @@ model = build_model() # 2. register customized op -from cube.graph.parser.register import register +from nnscaler.graph.parser.register import register register('* h, h -> * h')(GeLUFunction.apply) # 3. build semantic model -from cube import SemanticModel +from nnscaler import SemanticModel smodel = SemanticModel(model) # 4. set dummy input @@ -21,8 +21,8 @@ 'ntokens': 128, } -from cube.graph.function import IRObject -from cube.ir import IRFullTensor +from nnscaler.graph.function import IRObject +from nnscaler.ir import IRFullTensor src_tokens = IRFullTensor(shape=[batch_size, seq_len], name='src_tokens', @@ -35,8 +35,8 @@ ntokens = IRObject(name='ntokens') # 5. convert to graph -from cube.graph.segment import IRSegment -from cube.program import Program +from nnscaler.graph.segment import IRSegment +from nnscaler.program import Program from torch.autograd.graph import saved_tensors_hooks diff --git a/examples/megatron_gpt/parallel.py b/examples/megatron_gpt/parallel.py index e41df6f8..091896da 100644 --- a/examples/megatron_gpt/parallel.py +++ b/examples/megatron_gpt/parallel.py @@ -3,21 +3,21 @@ runtime_ngpus = int(os.environ['CUBE_SCALING_FACTOR']) * plan_ngpus # 1. load graph -from cube.graph import IRGraph +from nnscaler.graph import IRGraph graph = IRGraph.load('megatron_gpt2.cube') # 2. register customized op from gpt_model import GeLUFunction -from cube.graph.parser.register import register +from nnscaler.graph.parser.register import register register('* h, h -> * h')(GeLUFunction.apply) # 3. parallel model -from fairseq.cube.pas_policies import PASData, PASRandomSPMD +from fairseq.nnscaler.pas_policies import PASData, PASRandomSPMD graph = PASData(graph, plan_ngpus) for node in graph.nodes(flatten=True): - from cube.graph.function.anchor import IRGraphAnchor - from cube.graph.function.pyfunc import IRPyFunc + from nnscaler.graph.function.anchor import IRGraphAnchor + from nnscaler.graph.function.pyfunc import IRPyFunc # skip graph anchor and multiref: they will be removed or replaced by system if isinstance(node, IRGraphAnchor) or node.name == 'multiref': graph.assign(node, 0) @@ -25,28 +25,28 @@ graph.assign(node, 0) if len(node.device) == 0: raise RuntimeError(f"Node {node} device is not set") -from cube.graph.gener.gen import IRAdapterGener +from nnscaler.graph.gener.gen import IRAdapterGener graph = IRAdapterGener.gen(graph, cost_fn=None) if graph.sched is not None: graph.sched.apply() print(graph.sched) -from cube.graph.schedule.schedplan import SchedulePlan -from cube.execplan import ExecutionPlan +from nnscaler.graph.schedule.schedplan import SchedulePlan +from nnscaler.execplan import ExecutionPlan if isinstance(graph.sched, SchedulePlan): execplan = ExecutionPlan.from_schedplan(graph.sched) else: execplan = ExecutionPlan.from_graph(graph) # execplan.visualize('plan.png') -from cube.execplan.planpass.fusion import DiffFusion +from nnscaler.execplan.planpass.fusion import DiffFusion execplan = DiffFusion.apply(execplan) # plan pass for computation grouping -from cube.execplan.planpass.grouping import Grouping +from nnscaler.execplan.planpass.grouping import Grouping if not graph.sched: execplan = Grouping.apply(execplan) # 4. generate code -from cube.codegen import ModuleCodeGen, ScheduleCodeGen +from nnscaler.codegen import ModuleCodeGen, ScheduleCodeGen filename = 'gencode{}.py' _runtime_ngpus = None if plan_ngpus == runtime_ngpus else runtime_ngpus assert len(execplan.graph.device) == plan_ngpus, f"{execplan.graph.device}" diff --git a/examples/mlp/policy/gallery.py b/examples/mlp/policy/gallery.py index 5df705ac..b69974f5 100644 --- a/examples/mlp/policy/gallery.py +++ b/examples/mlp/policy/gallery.py @@ -1,8 +1,8 @@ from typing import List -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.schedule.predefined import PredefinedSched +from nnscaler.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.graph.schedule.predefined import PredefinedSched from examples.utils import tensor_parallelism, replica, create_mesh diff --git a/examples/mlp/train.py b/examples/mlp/train.py index 409260e5..e5a40fac 100644 --- a/examples/mlp/train.py +++ b/examples/mlp/train.py @@ -9,10 +9,10 @@ from torch import nn from functools import partial -import cube -from cube.profiler import CudaTimer -from cube.profiler.timer import print_each_rank -from cube.runtime.utils import microbatches +import nnscaler +from nnscaler.profiler import CudaTimer +from nnscaler.profiler.timer import print_each_rank +from nnscaler.runtime.utils import microbatches import examples.mlp.policy.gallery as gallery @@ -29,7 +29,7 @@ parser.add_argument('--tp-size', type=int, default=2, help='tensor parallelism size only for Megatron policy') args = parser.parse_args() -cube.init() +nnscaler.init() # get policy policy = get_policy([gallery], args.policy) @@ -63,13 +63,13 @@ def train(): dataloader = microbatches((dummy_data(),)) # compile a training iteration - @cube.compile(model, dataloader, PAS=policy) + @nnscaler.compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) loss.backward() # load generated model - model = cube.utils.load_model() + model = nnscaler.utils.load_model() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 2a15b28a..86d95ceb 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -1,8 +1,8 @@ import torch -import cube +import nnscaler -@cube.graph.parser.register('L^ N E^, (h+ d^ 3) E^, (h+ d^ 3), E^ (h+ d^) -> L^ N E^', name='self_attention') +@nnscaler.graph.parser.register('L^ N E^, (h+ d^ 3) E^, (h+ d^ 3), E^ (h+ d^) -> L^ N E^', name='self_attention') def self_attention(query: torch.Tensor, qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, @@ -56,7 +56,7 @@ def self_attention(query: torch.Tensor, return output -@cube.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') +@nnscaler.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') def cross_attention(query: torch.Tensor, key: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py index 767a54e6..95b52e09 100644 --- a/examples/nlp/blocks/mlp.py +++ b/examples/nlp/blocks/mlp.py @@ -1,8 +1,8 @@ import torch -import cube +import nnscaler -@cube.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') +@nnscaler.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj1_bias: torch.Tensor, proj2: torch.Tensor, diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py index 25123183..6fca2959 100644 --- a/examples/nlp/gpt/model.py +++ b/examples/nlp/gpt/model.py @@ -1,7 +1,7 @@ import torch from dataclasses import dataclass -import cube +import nnscaler from examples.nlp.blocks.transformer import TransformerLayer @@ -76,7 +76,7 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): enc = embed.transpose(0, 1) for layer in self.layers: - cube.runtime.function.anchor('transformer start') + nnscaler.runtime.function.anchor('transformer start') enc = layer(enc) enc = self.final_layernorm(enc) diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py index 8d9978b7..ec6b1041 100644 --- a/examples/nlp/gpt/policy/mpmd.py +++ b/examples/nlp/gpt/policy/mpmd.py @@ -1,8 +1,8 @@ """GPT policy gallery for MPMD Parallelism""" -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.schedule.predefined import PredefinedSched +from nnscaler.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.graph.schedule.predefined import PredefinedSched from examples.utils import create_mesh, tensor_parallelism, replica, group_to_layers diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py index f84f8cc0..1c6da6db 100644 --- a/examples/nlp/gpt/policy/spmd.py +++ b/examples/nlp/gpt/policy/spmd.py @@ -2,9 +2,9 @@ from typing import List -from cube.graph import IRGraph -from cube.graph.function.pyfunc import IRPyFunc -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from nnscaler.graph import IRGraph +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from examples.utils import tensor_parallelism, replica diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 54f5f1f3..8478c7fd 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -13,10 +13,10 @@ from model import GPT, Config, dummy_data -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.utils import microbatches +import nnscaler +from nnscaler.profiler.timer import CudaTimer, print_each_rank +from nnscaler.profiler.memory import memory_summary +from nnscaler.runtime.utils import microbatches import examples.nlp.gpt.policy.spmd as spmd import examples.nlp.gpt.policy.mpmd as mpmd @@ -51,9 +51,9 @@ args = parser.parse_args() -cube.init() -cube.set_logger_level(logging.WARN) -logging.getLogger('cube.compiler').setLevel(logging.INFO) +nnscaler.init() +nnscaler.set_logger_level(logging.WARN) +logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) # get policy policy = get_policy([spmd, mpmd], args.policy) @@ -80,12 +80,12 @@ def train(): gen_data = partial(dummy_data, args.mbs, config) dataloader = microbatches((gen_data(),), cycle=True) - @cube.compile(model, dataloader, PAS=policy) + @nnscaler.compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) loss.backward() - model = cube.utils.load_model() + model = nnscaler.utils.load_model() optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) @@ -123,5 +123,5 @@ def train_iter(model, dataloader): if __name__ == '__main__': - cube.init() + nnscaler.init() train() \ No newline at end of file diff --git a/examples/nlp/mbart/model.py b/examples/nlp/mbart/model.py index 6c9764c1..5a9e6675 100644 --- a/examples/nlp/mbart/model.py +++ b/examples/nlp/mbart/model.py @@ -4,7 +4,7 @@ from examples.nlp.blocks.transformer import TransformerLayer -import cube +import nnscaler @dataclass @@ -126,7 +126,7 @@ def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor): The loss computation is also simplified by using sum. """ # encoder embedding - cube.runtime.function.anchor('encoder embedding') + nnscaler.runtime.function.anchor('encoder embedding') enc_emb = torch.nn.functional.embedding(input_ids, self.vocab) enc_emb = enc_emb * self.embed_scale_encoder enc_emb = enc_emb + self.encoder_position @@ -136,12 +136,12 @@ def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor): # encoder layers for layer in self.encoders: - cube.runtime.function.anchor('encoder layer') + nnscaler.runtime.function.anchor('encoder layer') enc = layer(enc) enc = self.layer_norm_encoder(enc) # decoder embedding - cube.runtime.function.anchor('decoder embedding') + nnscaler.runtime.function.anchor('decoder embedding') dec_emb = torch.nn.functional.embedding(decoder_input_ids, self.vocab) dec_emb = dec_emb * self.embed_scale_decoder dec_emb = dec_emb + self.decoder_position @@ -151,7 +151,7 @@ def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor): # decoder layers for layer in self.decoders: - cube.runtime.function.anchor('decoder layer') + nnscaler.runtime.function.anchor('decoder layer') dec = layer(dec, enc) dec = self.layer_norm_decoder(dec) diff --git a/examples/nlp/mbart/policy/gallery.py b/examples/nlp/mbart/policy/gallery.py index ab45bd8a..0e675958 100644 --- a/examples/nlp/mbart/policy/gallery.py +++ b/examples/nlp/mbart/policy/gallery.py @@ -1,11 +1,11 @@ from typing import List -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.schedule.predefined import PredefinedSched -from cube.graph.segment import IRSegment -from cube.ir.cten import IRCell +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation, IRDataOperation +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.schedule.predefined import PredefinedSched +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.cten import IRCell from examples.utils import create_mesh, tensor_parallelism, replica diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index 596bce91..946ed414 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -16,10 +16,10 @@ from examples.nlp.mbart.model import MBartForSentenceClassification, Config from examples.nlp.mbart.model import dummy_data -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.utils import microbatches +import nnscaler +from nnscaler.profiler.timer import CudaTimer, print_each_rank +from nnscaler.profiler.memory import memory_summary +from nnscaler.runtime.utils import microbatches import examples.nlp.mbart.policy.gallery as gallery @@ -50,13 +50,13 @@ args = parser.parse_args() -cube.init() +nnscaler.init() print(args) -cube.init() -cube.set_logger_level(logging.WARN) -logging.getLogger('cube.compiler').setLevel(logging.INFO) +nnscaler.init() +nnscaler.set_logger_level(logging.WARN) +logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) # get policy policy = get_policy([gallery], args.policy) @@ -105,12 +105,12 @@ def train(): gen_data = partial(dummy_data, batch_size, config) dataloader = microbatches((gen_data(),), cycle=True) - @cube.compile(model, dataloader, PAS=policy) + @nnscaler.compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): input_ids, decoder_input_ids = next(dataloader) loss = model(input_ids, decoder_input_ids) loss.backward() - model = cube.load_model() + model = nnscaler.load_model() optimizer = torch.optim.Adam( model.parameters(), lr=3e-05, betas=(0.9, 0.98)) @@ -143,5 +143,5 @@ def train_iter(model, dataloader): if __name__ == '__main__': - cube.init() + nnscaler.init() train() \ No newline at end of file diff --git a/examples/openfold/blocks/attention.py b/examples/openfold/blocks/attention.py index 7c94aafb..079ef4ea 100644 --- a/examples/openfold/blocks/attention.py +++ b/examples/openfold/blocks/attention.py @@ -2,17 +2,17 @@ Attention Module for MSA Attention and Pair Attention in Evoformer """ -import cube +import nnscaler import torch import torch.utils.checkpoint as ckpt -@cube.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S R^ M^', name='msa_attn') +@nnscaler.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S R^ M^', name='msa_attn') @torch.jit.ignore def msa_attn(x: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): - # cube.profiler.CudaTimer().start('msa_attn') + # nnscaler.profiler.CudaTimer().start('msa_attn') bs, s, r, cm = x.size() if chunk_size == -1: @@ -60,17 +60,17 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out_chunks.append(attend) out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) - # cube.profiler.CudaTimer().stop('msa_attn') + # nnscaler.profiler.CudaTimer().stop('msa_attn') return out -@cube.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, N 1 head+ R^ R^ -> N S R^ M^', name='msa_attn_bias') +@nnscaler.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, N 1 head+ R^ R^ -> N S R^ M^', name='msa_attn_bias') @torch.jit.ignore def msa_attn_bias(x: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, bias: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): - # cube.profiler.CudaTimer().start('msa_attn_bias') + # nnscaler.profiler.CudaTimer().start('msa_attn_bias') bs, s, r, cm = x.size() assert gate_proj.size(1) % head == 0 c = gate_proj.size(1) // head @@ -127,12 +127,12 @@ def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attend = attention_bias(q, k, v, gate, bias, start) out_chunks.append(attend) out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) - # cube.profiler.CudaTimer().stop('msa_attn_bias') + # nnscaler.profiler.CudaTimer().stop('msa_attn_bias') return out # note: code not reused constrained by cube's interface -@cube.graph.parser.register('N S R^ M^, N R^ R^ Z^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, Z^ head+ -> N S R^ M^', name='row_attn') +@nnscaler.graph.parser.register('N S R^ M^, N R^ R^ Z^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, Z^ head+ -> N S R^ M^', name='row_attn') def row_attn(msa_repr: torch.Tensor, pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, @@ -147,7 +147,7 @@ def row_attn(msa_repr: torch.Tensor, pair_repr: torch.Tensor, head, c, scale, chunk_size, is_train) -@cube.graph.parser.register('N S^ R M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S^ R M^', name='col_attn') +@nnscaler.graph.parser.register('N S^ R M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S^ R M^', name='col_attn') def col_attn(msa_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): @@ -160,7 +160,7 @@ def col_attn(msa_repr: torch.Tensor, gate_proj: torch.Tensor, return out -# @cube.graph.parser.register('N S^ R^ M^, M^ (head+ dim^), M^ E^, M^ E^, M^ (head+ dim^), (head+ dim) M^ -> N S^ R^ M^', name='MSAColGlobalAttention') +# @nnscaler.graph.parser.register('N S^ R^ M^, M^ (head+ dim^), M^ E^, M^ E^, M^ (head+ dim^), (head+ dim) M^ -> N S^ R^ M^', name='MSAColGlobalAttention') def global_attn(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor, gate_proj: torch.Tensor, out_proj: torch.Tensor, @@ -197,21 +197,21 @@ def global_attn(msa_repr: torch.Tensor, q_proj: torch.Tensor, return torch.matmul(o, out_proj).transpose(-2, -3) -@cube.graph.parser.register('N S R M^, M^ E+, E+ M^ -> N S R M^', name='feedforward') +@nnscaler.graph.parser.register('N S R M^, M^ E+, E+ M^ -> N S R M^', name='feedforward') @torch.jit.ignore def feedforward(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): """ MSA transition """ - # cube.profiler.CudaTimer().start('ffn') + # nnscaler.profiler.CudaTimer().start('ffn') x = torch.matmul(msa_repr, proj1) x = torch.nn.functional.relu(x) x = torch.matmul(x, proj2) - # cube.profiler.CudaTimer().stop('ffn') + # nnscaler.profiler.CudaTimer().stop('ffn') return x -@cube.graph.parser.register('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N R^ R^ head+ -> N S R^ Z^', name='tri_attn_start') +@nnscaler.graph.parser.register('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N R^ R^ head+ -> N S R^ Z^', name='tri_attn_start') def tri_attn_start(pair_repr: torch.Tensor, gate: torch.Tensor, qkv: torch.Tensor, out: torch.Tensor, bias: torch.Tensor, @@ -224,7 +224,7 @@ def tri_attn_start(pair_repr: torch.Tensor, return out -@cube.graph.parser.register('N S^ R Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N S^ S^ head+ -> N S^ R Z^', name='tri_attn_end') +@nnscaler.graph.parser.register('N S^ R Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N S^ S^ head+ -> N S^ R Z^', name='tri_attn_end') def tri_attn_end(pair_repr: torch.Tensor, gate: torch.Tensor, qkv: torch.Tensor, out: torch.Tensor, bias: torch.Tensor, diff --git a/examples/openfold/blocks/embedder.py b/examples/openfold/blocks/embedder.py index cba55626..568c6621 100644 --- a/examples/openfold/blocks/embedder.py +++ b/examples/openfold/blocks/embedder.py @@ -3,11 +3,11 @@ from typing import Tuple, Optional -import cube +import nnscaler -@cube.graph.parser.register('N res, cz nobins, cz -> N res res cz', name='relpos') +@nnscaler.graph.parser.register('N res, cz nobins, cz -> N res res cz', name='relpos') def input_embedder_pair_emb(ri: torch.Tensor, tf_emb_i: torch.Tensor, tf_emb_j: torch.Tensor, w_relpos: torch.Tensor, b_relpos: torch.Tensor, @@ -32,7 +32,7 @@ def input_embedder_pair_emb(ri: torch.Tensor, return pair_emb -@cube.graph.parser.register('N res tfdim^, cm tfdim^, cm -> N nclust^, res, cm') +@nnscaler.graph.parser.register('N res tfdim^, cm tfdim^, cm -> N nclust^, res, cm') def input_embedder_tf_m(tf: torch.Tensor, w_tf_m: torch.Tensor, b_tf_m: torch.Tensor, nclust: int) -> torch.Tensor: tf_m = torch.nn.linear(tf, w_tf_m, b_tf_m) tf_m = tf_m.unsqueeze(-3).expand(((-1,) * len(tf.shape[:-2]) + (nclust, -1, -1))) @@ -119,7 +119,7 @@ def forward(self, tf: torch.Tensor, ri: torch.Tensor, msa: torch.Tensor) -> Tupl -@cube.graph.parser.register() +@nnscaler.graph.parser.register() def sum_d(x: torch.Tensor, bins: torch.Tensor, inf: float) -> torch.Tensor: squared_bins = bins ** 2 upper = torch.cat( diff --git a/examples/openfold/blocks/evoformer.py b/examples/openfold/blocks/evoformer.py index 0128fa1f..7ce24954 100644 --- a/examples/openfold/blocks/evoformer.py +++ b/examples/openfold/blocks/evoformer.py @@ -6,17 +6,17 @@ from examples.openfold.blocks.utils import multi2ref import math -import cube +import nnscaler -# @cube.graph.parser.register('N S^ R^ cm^, N R^ R^ cz^ -> N out^') +# @nnscaler.graph.parser.register('N S^ R^ cm^, N R^ R^ cz^ -> N out^') # @torch.jit.ignore # def input_packing(msa: torch.Tensor, pair: torch.Tensor, out: int) -> torch.Tensor: # buffer = torch.cat((torch.flatten(msa, start_dim=1), torch.flatten(pair, start_dim=1))) # return buffer # # -# @cube.graph.parser.register('N out^ -> N S^ R^ cm^, N R^ R^ cz^', name='input_unflatten') +# @nnscaler.graph.parser.register('N out^ -> N S^ R^ cm^, N R^ R^ cz^', name='input_unflatten') # @torch.jit.ignore # def input_unpacking(buffer: torch.Tensor, # S: int, R: int, cm: int, cz: int) -> Tuple[torch.Tensor, torch.Tensor]: @@ -86,46 +86,46 @@ def __init__(self, s: int, r: int, cm: int, cz: int, def forward(self, msa_repr, pair_repr): - cube.runtime.function.anchor('MSARow') + nnscaler.runtime.function.anchor('MSARow') pair_repr, dummy_pair_repr = multi2ref(pair_repr) residual = msa_repr msa_repr = self.row_norm_m(msa_repr) dummy_pair_repr = self.row_norm_z(dummy_pair_repr) msa_repr = residual + self.row_attn(msa_repr, dummy_pair_repr) - cube.runtime.function.anchor('MSACol') + nnscaler.runtime.function.anchor('MSACol') residual = msa_repr msa_repr = self.col_norm(msa_repr) msa_repr = residual + self.col_attn(msa_repr) - # cube.runtime.function.anchor('MSATrans') + # nnscaler.runtime.function.anchor('MSATrans') residual = msa_repr msa_repr = self.msa_transition_norm(msa_repr) msa_repr = self.msa_transition(msa_repr) msa_repr = residual + msa_repr succ_msa_repr, msa_repr = multi2ref(msa_repr) - cube.runtime.function.anchor('OPM') + nnscaler.runtime.function.anchor('OPM') msa_repr = self.outer_norm(msa_repr) pair_repr = pair_repr + self.outer_prod_mean(msa_repr) - cube.runtime.function.anchor('TMO') + nnscaler.runtime.function.anchor('TMO') pair_repr = self.tmo(pair_repr) - cube.runtime.function.anchor('TMI') + nnscaler.runtime.function.anchor('TMI') pair_repr = self.tmi(pair_repr) - cube.runtime.function.anchor('TANS') + nnscaler.runtime.function.anchor('TANS') residual = pair_repr pair_repr = self.tri_attn_node_start(pair_repr) pair_repr = residual + pair_repr - cube.runtime.function.anchor('TANE') + nnscaler.runtime.function.anchor('TANE') residual = pair_repr pair_repr = self.tri_attn_node_end(pair_repr) pair_repr = residual + pair_repr - cube.runtime.function.anchor('PairTrans') + nnscaler.runtime.function.anchor('PairTrans') residual = pair_repr pair_repr = self.pair_transition_norm(pair_repr) pair_repr = self.pair_transition(pair_repr) diff --git a/examples/openfold/blocks/opm.py b/examples/openfold/blocks/opm.py index 5ee13ad8..005841e3 100644 --- a/examples/openfold/blocks/opm.py +++ b/examples/openfold/blocks/opm.py @@ -2,16 +2,16 @@ Outer Product Mean module for Evoformer """ -import cube +import nnscaler import torch import torch.utils.checkpoint as ckpt -# @cube.graph.parser.register('N S+ R^ M^, M^ c^, M^ c^, (c^ c^) cz^ -> N R^ R^ cz^', name='outer_prod_mean') +# @nnscaler.graph.parser.register('N S+ R^ M^, M^ c^, M^ c^, (c^ c^) cz^ -> N R^ R^ cz^', name='outer_prod_mean') @torch.jit.ignore def outer_prod_mean(msa_repr: torch.Tensor, left_proj: torch.Tensor, right_proj: torch.Tensor, out_proj: torch.Tensor, chunk_size: int, training: bool): - # cube.profiler.CudaTimer().start('opm') + # nnscaler.profiler.CudaTimer().start('opm') # N S R M, M c -> N S R c opm_left = torch.matmul(msa_repr, left_proj) # N S T M, M c -> N S T c @@ -45,17 +45,17 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): ret = opm(a, b, start) out_chunks.append(ret) outer = torch.cat(out_chunks, dim=1) - # cube.profiler.CudaTimer().stop('opm') + # nnscaler.profiler.CudaTimer().stop('opm') return outer -@cube.graph.parser.register('N S R M+, M+ C -> N S R C', name='opm_projection') +@nnscaler.graph.parser.register('N S R M+, M+ C -> N S R C', name='opm_projection') def opm_projection(msa_repr: torch.Tensor, proj1: torch.Tensor): x = torch.matmul(msa_repr, proj1) return x -@cube.graph.parser.register('N S^ R C^, N S^ T^ C^, F^ Z^ -> N R T^ Z^') +@nnscaler.graph.parser.register('N S^ R C^, N S^ T^ C^, F^ Z^ -> N R T^ Z^') @torch.jit.ignore def opm(left: torch.Tensor, right: torch.Tensor, out_proj: torch.Tensor, chunk_size: int, training: bool): @@ -87,7 +87,7 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): ret = opm(a, b, start) out_chunks.append(ret) outer = torch.cat(out_chunks, dim=1) - # cube.profiler.CudaTimer().stop('opm') + # nnscaler.profiler.CudaTimer().stop('opm') return outer diff --git a/examples/openfold/blocks/tmu.py b/examples/openfold/blocks/tmu.py index c5a91ea9..ea83c3c4 100644 --- a/examples/openfold/blocks/tmu.py +++ b/examples/openfold/blocks/tmu.py @@ -1,21 +1,21 @@ -import cube +import nnscaler import torch from examples.openfold.blocks.utils import multi2ref -# @cube.graph.parser.register('N S R Z^, Z^ E, Z^ E -> N S R E') +# @nnscaler.graph.parser.register('N S R Z^, Z^ E, Z^ E -> N S R E') # def tmu_projection(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): # x = torch.matmul(pair_repr, proj1) # x = torch.sigmoid(x) # x = x * torch.matmul(pair_repr, proj2) # # -# @cube.graph.parser.register('N S R Z+, Z+ E-> N S R E') +# @nnscaler.graph.parser.register('N S R Z+, Z+ E-> N S R E') # def tmu_gate(pair_repr: torch.Tensor, proj: torch.Tensor): # return torch.sigmoid(torch.matmul(pair_repr, proj)) -@cube.graph.parser.register('N S R Z^, Z^ E^, Z^ E^, Z^ E, Z^ E^, Z^ Z^ -> N S R E, N S R E^, N S R Z^', name='tmu_projection') +@nnscaler.graph.parser.register('N S R Z^, Z^ E^, Z^ E^, Z^ E, Z^ E^, Z^ Z^ -> N S R E, N S R E^, N S R Z^', name='tmu_projection') def tmu_projection(pair_repr: torch.Tensor, left1: torch.Tensor, left2: torch.Tensor, right1: torch.Tensor, right2: torch.Tensor, @@ -34,7 +34,7 @@ def tmu_projection(pair_repr: torch.Tensor, return left, right, gate -@cube.graph.parser.register('N S R^ E, N T^ R^ E^, N S^ T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='tmo') +@nnscaler.graph.parser.register('N S R^ E, N T^ R^ E^, N S^ T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='tmo') def tmo(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, norm_w: torch.Tensor, norm_b: torch.Tensor, out: torch.Tensor): a = left.permute(0, 3, 1, 2) @@ -46,7 +46,7 @@ def tmo(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, return p -@cube.graph.parser.register('N R^ S E, N R^ T^ E^, N T^ S^ Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='tmi') +@nnscaler.graph.parser.register('N R^ S E, N R^ T^ E^, N T^ S^ Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='tmi') def tmi(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, norm_w: torch.Tensor, norm_b: torch.Tensor, out: torch.Tensor): a = left.permute(0, 3, 2, 1) diff --git a/examples/openfold/blocks/utils.py b/examples/openfold/blocks/utils.py index 520a3c1d..bf512f2e 100644 --- a/examples/openfold/blocks/utils.py +++ b/examples/openfold/blocks/utils.py @@ -1,7 +1,7 @@ -import cube +import nnscaler import torch -@cube.graph.parser.register('* -> *, *', name='multi2ref') +@nnscaler.graph.parser.register('* -> *, *', name='multi2ref') def multi2ref(x: torch.Tensor): return (x, x) \ No newline at end of file diff --git a/examples/openfold/model.py b/examples/openfold/model.py index 06953b13..3403b3c2 100644 --- a/examples/openfold/model.py +++ b/examples/openfold/model.py @@ -10,7 +10,7 @@ from dataclasses import dataclass -import cube +import nnscaler @dataclass @@ -143,10 +143,10 @@ def forward(self, msa, pair): """ msa = self.msa_norm(msa) pair = self.pair_norm(pair) - # cube.runtime.function.anchor('PackingRegion') + # nnscaler.runtime.function.anchor('PackingRegion') # x = input_packing(msa, pair, self.fout) for evoformer in self.evoformers: - cube.runtime.function.anchor('Evoformer Start') + nnscaler.runtime.function.anchor('Evoformer Start') msa, pair = evoformer(msa, pair) # x = evoformer(x) # msa, pair = input_unpacking(x, self.s, self.r, self.cm, self.cz) diff --git a/examples/openfold/policy/mpmd.py b/examples/openfold/policy/mpmd.py index 55ab9d6c..26353f1c 100644 --- a/examples/openfold/policy/mpmd.py +++ b/examples/openfold/policy/mpmd.py @@ -1,12 +1,12 @@ from typing import List -from cube.graph import IRGraph -from cube.ir.cten import IRCell -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.segment import IRSegment -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.graph.schedule.schednf1b import IRScheduleNF1B -from cube.graph.schedule.sched1f1b import IRSchedule1F1B +from nnscaler.graph import IRGraph +from nnscaler.ir.cten import IRCell +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from nnscaler.graph.schedule.schednf1b import IRScheduleNF1B +from nnscaler.graph.schedule.sched1f1b import IRSchedule1F1B import more_itertools import numpy as np diff --git a/examples/openfold/train.py b/examples/openfold/train.py index dec3ebbf..11186ad5 100644 --- a/examples/openfold/train.py +++ b/examples/openfold/train.py @@ -8,16 +8,16 @@ import torch from examples.openfold.model import AlphaFold, Config -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary +import nnscaler +from nnscaler.profiler.timer import CudaTimer, print_each_rank +from nnscaler.profiler.memory import memory_summary from examples.openfold.policy.mpmd import PASDAP, PASRoundRobin, PASNF1B, PASDAPPipe import argparse from functools import partial -cube.init() +nnscaler.init() parser = argparse.ArgumentParser(description='AlphaFold Train') parser.add_argument('--fp16', action='store_true', default=False, @@ -42,7 +42,7 @@ help='data parallelism size') args = parser.parse_args() -dp = cube.runtime.device.DeviceGroup().world_size // (args.tp * args.pp) +dp = nnscaler.runtime.device.DeviceGroup().world_size // (args.tp * args.pp) assert args.gbs % args.mbs == 0 assert args.mbs % dp == 0 assert args.msa_hidden % args.head_dim == 0 @@ -74,7 +74,7 @@ def train(): model = model.half() dtype = torch.float16 if args.fp16 else torch.float32 - dataloader = cube.runtime.syndata.SynDataLoader( + dataloader = nnscaler.runtime.syndata.SynDataLoader( shapes=([cfg.bs, cfg.evoformer_s, cfg.evoformer_r, cfg.evoformer_cm], [cfg.bs, cfg.evoformer_r, cfg.evoformer_r, cfg.evoformer_cz]), dtypes=(dtype, dtype), @@ -83,8 +83,8 @@ def train(): print_each_rank(f'before partitioned model parameter: {nparams(model)}') - model = cube.SemanticModel(model) - @cube.compile(model, dataloader, PAS=PASDAPPipe, override=True, load_content=True) + model = nnscaler.SemanticModel(model) + @nnscaler.compile(model, dataloader, PAS=PASDAPPipe, override=True, load_content=True) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) diff --git a/examples/policies/alpa/README.md b/examples/policies/alpa/README.md index 359e380b..5c847cea 100644 --- a/examples/policies/alpa/README.md +++ b/examples/policies/alpa/README.md @@ -9,7 +9,7 @@ pip install pulp ## Implementation Notes -* The implementation doesn't support auto_layer construction, and relies on the `cube.runtime.function.anchor` as stage division candidates. +* The implementation doesn't support auto_layer construction, and relies on the `nnscaler.runtime.function.anchor` as stage division candidates. * The implementation doesn't support `follow`, which relies on the user customized operator to achieve manual fusion. diff --git a/examples/policies/alpa/__init__.py b/examples/policies/alpa/__init__.py index af910735..7b012bc9 100644 --- a/examples/policies/alpa/__init__.py +++ b/examples/policies/alpa/__init__.py @@ -3,14 +3,14 @@ import warnings import torch -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.dimops import IRDimops, TransformRule, DimopSplit -from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.ir.tensor import IRFullTensor -from cube.graph.schedule.predefined import PredefinedSched -from cube.runtime.device import DeviceGroup +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.dimops import IRDimops, TransformRule, DimopSplit +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRFwOperation, IRDataOperation +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph.schedule.predefined import PredefinedSched +from nnscaler.runtime.device import DeviceGroup from examples.policies.alpa.plan import ParallelSpec from examples.policies.alpa.inter_op import inter_op diff --git a/examples/policies/alpa/cost_model.py b/examples/policies/alpa/cost_model.py index 57772c5b..91ba02a0 100644 --- a/examples/policies/alpa/cost_model.py +++ b/examples/policies/alpa/cost_model.py @@ -4,11 +4,11 @@ from typing import List, Callable, Tuple, Dict import numpy as np -from cube.graph import IRGraph -from cube.ir.cten import IRTensor -from cube.ir.operator import IRFwOperation -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.dimops import IRDimops, TransformRule, DimopSplit +from nnscaler.graph import IRGraph +from nnscaler.ir.cten import IRTensor +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.dimops import IRDimops, TransformRule, DimopSplit DistSpec = Dict[int, Tuple[Tuple[int, int]]] diff --git a/examples/policies/alpa/estimator.py b/examples/policies/alpa/estimator.py index a658c475..07391383 100644 --- a/examples/policies/alpa/estimator.py +++ b/examples/policies/alpa/estimator.py @@ -4,16 +4,16 @@ import json # ===== neccesaary for profiling ===== -import cube +import nnscaler import torch # ==================================== -from cube.ir.cten import IRTensor, IRObject, IRCell -from cube.ir.operator import IRFwOperation -from cube.graph.parser.register import CustomizedOps -from cube.graph.segment import IRSegment -from cube.graph.function.dimops import IRDimops -from cube.graph.function import IRGraphAnchor +from nnscaler.ir.cten import IRTensor, IRObject, IRCell +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.parser.register import CustomizedOps +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.function import IRGraphAnchor Shapes = NewType('Shapes', Tuple[Tuple[int]]) diff --git a/examples/policies/alpa/inter_op.py b/examples/policies/alpa/inter_op.py index 7db2a5f5..e8cb2bc4 100644 --- a/examples/policies/alpa/inter_op.py +++ b/examples/policies/alpa/inter_op.py @@ -8,7 +8,7 @@ from typing import List, Callable, Tuple, Dict, Optional import time -from cube.ir.operator import IRFwOperation +from nnscaler.ir.operator import IRFwOperation from examples.policies.alpa.layer_op import IRLayerOp, cluster_to_layer_ops from examples.policies.alpa.plan import StageSpec, ParallelSpec diff --git a/examples/policies/alpa/intra_op.py b/examples/policies/alpa/intra_op.py index e595f5b2..7f2af325 100644 --- a/examples/policies/alpa/intra_op.py +++ b/examples/policies/alpa/intra_op.py @@ -5,9 +5,9 @@ import warnings import time -from cube.ir.cten import IRTensor -from cube.ir.operator import IRFwOperation -from cube.graph.function.anchor import IRGraphAnchor +from nnscaler.ir.cten import IRTensor +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.function.anchor import IRGraphAnchor from examples.policies.alpa.layer_op import IRLayerOp from examples.policies.alpa.cost_model import CostModel diff --git a/examples/policies/alpa/layer_op.py b/examples/policies/alpa/layer_op.py index 1cd70f3a..bf220456 100644 --- a/examples/policies/alpa/layer_op.py +++ b/examples/policies/alpa/layer_op.py @@ -1,10 +1,10 @@ from typing import List, Dict, Tuple import more_itertools -from cube.ir.cten import IRCell -from cube.ir.operator import IRFwOperation -from cube.graph.graph import IRGraph -from cube.graph.function.anchor import IRGraphAnchor +from nnscaler.ir.cten import IRCell +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.function.anchor import IRGraphAnchor class IRLayerOp(IRCell): diff --git a/examples/policies/alpa/plan.py b/examples/policies/alpa/plan.py index 80e9212d..291ee7c8 100644 --- a/examples/policies/alpa/plan.py +++ b/examples/policies/alpa/plan.py @@ -2,8 +2,8 @@ from dataclasses import dataclass import json -from cube.ir.operator import IRFwOperation -from cube.graph.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.graph import IRGraph @dataclass class StageSpec: diff --git a/examples/policies/gshard.py b/examples/policies/gshard.py index 7e836dd0..f4378f6e 100644 --- a/examples/policies/gshard.py +++ b/examples/policies/gshard.py @@ -4,11 +4,11 @@ from typing import List -from cube.ir.tensor import IRSubTensor -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.graph import IRGraph -from cube.graph.function.dimops import IRDimops -from cube.graph.function.anchor import IRGraphAnchor +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.function.anchor import IRGraphAnchor def follow(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int, diff --git a/examples/policies/random_spmd.py b/examples/policies/random_spmd.py index 686dd8d9..7b176f8e 100644 --- a/examples/policies/random_spmd.py +++ b/examples/policies/random_spmd.py @@ -2,10 +2,10 @@ Random SPMD policy """ from typing import List, Optional -from cube.graph.graph import IRGraph -from cube.graph.function.dimops import IRDimops -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.graph.function.anchor import IRGraphAnchor from datetime import datetime import random diff --git a/examples/utils.py b/examples/utils.py index c1488ef1..f08b6191 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,15 +1,15 @@ from typing import List, Union, Callable, Optional, Tuple import logging -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.graph.function.dimops import IRDimops -from cube.graph.gener.rvd.intra import IntraAutoPlacer -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor -from cube.graph.function.anchor import IRGraphAnchor -from cube.utils import print_each_rank +from nnscaler.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.gener.rvd.intra import IntraAutoPlacer +from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.ir.cten import IRCell +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.utils import print_each_rank import numpy as np diff --git a/examples/vision/swin/baseline.py b/examples/vision/swin/baseline.py index 557fec3e..0678a105 100644 --- a/examples/vision/swin/baseline.py +++ b/examples/vision/swin/baseline.py @@ -13,10 +13,10 @@ import torch.nn as nn import torch.utils.checkpoint as checkpoint -from cube.profiler.memory import memory_summary, model_summary -from cube.profiler.timer import CudaTimer, print_each_rank +from nnscaler.profiler.memory import memory_summary, model_summary +from nnscaler.profiler.timer import CudaTimer, print_each_rank -import cube +import nnscaler def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): @@ -628,7 +628,7 @@ def flops(self): return flops -class ImageDataLoader(cube.runtime.syndata.CubeDataLoader): +class ImageDataLoader(nnscaler.runtime.syndata.CubeDataLoader): def __init__(self, batch_size: int, img_size: int, num_classes: int): @@ -769,5 +769,5 @@ def train_iter(model, dataloader): if __name__ == '__main__': - cube.init() + nnscaler.init() train() diff --git a/examples/vision/swin/blocks/attention.py b/examples/vision/swin/blocks/attention.py index c38c6987..b3ec28d8 100644 --- a/examples/vision/swin/blocks/attention.py +++ b/examples/vision/swin/blocks/attention.py @@ -1,13 +1,13 @@ from typing import Optional import torch -import cube +import nnscaler # REMARK: as default attention has qkv project weight of (3 head dim_head) C, # this cannot partition on head dimension # as the head dimension is a secondary hidden dimension in (3 head dim_head). # To make partition work (correctness guarantee), the dimension is swapped as (head dim_head 3) -@cube.graph.parser.register('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), nw N^ N^ -> B N^ C^') +@nnscaler.graph.parser.register('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), nw N^ N^ -> B N^ C^') def window_attn(x: torch.Tensor, qkv_w: torch.Tensor, qkv_bias: torch.Tensor, relative_position_index: torch.Tensor, diff --git a/examples/vision/swin/blocks/mlp.py b/examples/vision/swin/blocks/mlp.py index 1873cc53..d36d1456 100644 --- a/examples/vision/swin/blocks/mlp.py +++ b/examples/vision/swin/blocks/mlp.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn -import cube +import nnscaler -@cube.graph.parser.register('B HW^ E^, H+ E^, H+, E^ H+ -> B HW^ E^', name='feedforward') +@nnscaler.graph.parser.register('B HW^ E^, H+ E^, H+, E^ H+ -> B HW^ E^', name='feedforward') def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj1_bias: torch.Tensor, proj2: torch.Tensor, dropout: float) -> torch.Tensor: diff --git a/examples/vision/swin/blocks/patch.py b/examples/vision/swin/blocks/patch.py index 3b48677a..77cf3b8d 100644 --- a/examples/vision/swin/blocks/patch.py +++ b/examples/vision/swin/blocks/patch.py @@ -3,10 +3,10 @@ import torch import torch.nn as nn -import cube +import nnscaler -@cube.graph.parser.register('B (2 h^ 2 w^) C^ -> B (h w) (4 C)') +@nnscaler.graph.parser.register('B (2 h^ 2 w^) C^ -> B (h w) (4 C)') def patch_merge(x: torch.Tensor, h: int, w: int): B, L, C = x.shape H = 2 * h @@ -22,7 +22,7 @@ def patch_merge(x: torch.Tensor, h: int, w: int): x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C return x -@cube.graph.parser.register('B ic+ (ps^ w^) (ps^ h^), oc ic+ k^ k^, oc -> B oc w^ h^') +@nnscaler.graph.parser.register('B ic+ (ps^ w^) (ps^ h^), oc ic+ k^ k^, oc -> B oc w^ h^') def patch(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, ps: int): """ @param ps int: patch size diff --git a/examples/vision/swin/blocks/transformer.py b/examples/vision/swin/blocks/transformer.py index 7bda0406..559de247 100644 --- a/examples/vision/swin/blocks/transformer.py +++ b/examples/vision/swin/blocks/transformer.py @@ -5,10 +5,10 @@ from examples.vision.swin.blocks.attention import WindowAttention from examples.vision.swin.blocks.mlp import Mlp -import cube +import nnscaler -@cube.graph.parser.register('* -> *') +@nnscaler.graph.parser.register('* -> *') def drop_path(x: torch.Tensor, drop_prob: float, training: bool): if drop_prob <= 0. or not training: return x @@ -20,7 +20,7 @@ def drop_path(x: torch.Tensor, drop_prob: float, training: bool): return output -@cube.graph.parser.register('B (nh ws) (nw ws) C -> (B nh nw) ws ws C') +@nnscaler.graph.parser.register('B (nh ws) (nw ws) C -> (B nh nw) ws ws C') def window_partition(x: torch.Tensor, ws: int): """ Args: @@ -36,7 +36,7 @@ def window_partition(x: torch.Tensor, ws: int): return windows -@cube.graph.parser.register('(B nh nw) ws ws C -> B (nh ws) (nw ws) C') +@nnscaler.graph.parser.register('(B nh nw) ws ws C -> B (nh ws) (nw ws) C') def window_reverse(windows: torch.Tensor, ws: int, nh: int, nw: int): """ Args: diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index 5f4c8c44..a2853ca8 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -5,7 +5,7 @@ from examples.vision.swin.blocks.transformer import SwinTransformerBlock from examples.vision.swin.blocks.patch import PatchEmbed, PatchMerging -import cube +import nnscaler class Config: @@ -110,7 +110,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, def forward(self, x): for blk in self.blocks: - cube.runtime.function.anchor('transformer block start') + nnscaler.runtime.function.anchor('transformer block start') x = blk(x) if self.downsample is not None: x = self.downsample(x) diff --git a/examples/vision/swin/policy/gallery.py b/examples/vision/swin/policy/gallery.py index 79a6956b..bdc16cfa 100644 --- a/examples/vision/swin/policy/gallery.py +++ b/examples/vision/swin/policy/gallery.py @@ -1,9 +1,9 @@ from typing import List -from cube.graph import IRGraph -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.schedule.predefined import PredefinedSched -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from nnscaler.graph import IRGraph +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.schedule.predefined import PredefinedSched +from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from examples.utils import tensor_parallelism, replica, group_to_layers diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index acedd297..26f64542 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -12,10 +12,10 @@ from examples.vision.swin.blocks.attention import init_relative_position_index from examples.vision.swin.model import Config, SwinTransformer, dummy_data -import cube -from cube.profiler.timer import CudaTimer, print_each_rank -from cube.profiler.memory import memory_summary -from cube.runtime.utils import microbatches +import nnscaler +from nnscaler.profiler.timer import CudaTimer, print_each_rank +from nnscaler.profiler.memory import memory_summary +from nnscaler.runtime.utils import microbatches import examples.vision.swin.policy.gallery as gallery from examples.utils import get_policy @@ -35,7 +35,7 @@ parser.add_argument('--mbs', type=int, default=4, help='micro batch size') args = parser.parse_args() -cube.init() +nnscaler.init() # get policy @@ -62,12 +62,12 @@ def train(): gen_data = partial(dummy_data, args.mbs, torch.float16, cfg) dataloader = microbatches((gen_data(),)) - @cube.compile(model, dataloader, PAS=policy, load_content=load_content) + @nnscaler.compile(model, dataloader, PAS=policy, load_content=load_content) def train_iter(model, dataloader): imgs = next(dataloader) loss = model(imgs) loss.backward() - model = cube.utils.load_model() + model = nnscaler.utils.load_model() if not load_content: for name, buffer in model.named_buffers(): diff --git a/cube/__init__.py b/nnscaler/__init__.py similarity index 68% rename from cube/__init__.py rename to nnscaler/__init__.py index 647c6a24..8b0baa8c 100644 --- a/cube/__init__.py +++ b/nnscaler/__init__.py @@ -1,17 +1,17 @@ from typing import Optional import logging -from cube import runtime -from cube import utils +from nnscaler import runtime +from nnscaler import utils -from cube import profiler -from cube.profiler.timer import CudaTimer +from nnscaler import profiler +from nnscaler.profiler.timer import CudaTimer -from cube.compiler import SemanticModel, compile +from nnscaler.compiler import SemanticModel, compile -from cube.utils import load_model, load_default_schedule, load_eval_schedule -from cube.utils import accum_mode +from nnscaler.utils import load_model, load_default_schedule, load_eval_schedule +from nnscaler.utils import accum_mode -from cube.flags import CompileFlag +from nnscaler.flags import CompileFlag def _check_torch_version(): diff --git a/cube/algorithm/__init__.py b/nnscaler/algorithm/__init__.py similarity index 100% rename from cube/algorithm/__init__.py rename to nnscaler/algorithm/__init__.py diff --git a/cube/algorithm/factory.py b/nnscaler/algorithm/factory.py similarity index 95% rename from cube/algorithm/factory.py rename to nnscaler/algorithm/factory.py index 37c8a12d..12781511 100644 --- a/cube/algorithm/factory.py +++ b/nnscaler/algorithm/factory.py @@ -59,10 +59,10 @@ def algorithms(self, op, tag = None): def _load_predefined_algos(self): - import cube.algorithm.ops.dimops as dimops + import nnscaler.algorithm.ops.dimops as dimops self.register(dimops.IRDimops, dimops.DimSplitEinops, tag='dim') - import cube.algorithm.ops.conv as conv + import nnscaler.algorithm.ops.conv as conv self.register(conv.IRPad, conv.DimSplitPad, tag='dim') self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') self.register(conv.IRConv2D, conv.HaloSplitConv2D, tag='halo') diff --git a/cube/algorithm/generics.py b/nnscaler/algorithm/generics.py similarity index 97% rename from cube/algorithm/generics.py rename to nnscaler/algorithm/generics.py index e537c165..3c970e77 100644 --- a/cube/algorithm/generics.py +++ b/nnscaler/algorithm/generics.py @@ -1,6 +1,6 @@ from typing import List, Optional -from cube.ir.cten import IRCell +from nnscaler.ir.cten import IRCell class GenericDistAlgo: diff --git a/cube/algorithm/ops/__init__.py b/nnscaler/algorithm/ops/__init__.py similarity index 100% rename from cube/algorithm/ops/__init__.py rename to nnscaler/algorithm/ops/__init__.py diff --git a/cube/algorithm/ops/conv.py b/nnscaler/algorithm/ops/conv.py similarity index 98% rename from cube/algorithm/ops/conv.py rename to nnscaler/algorithm/ops/conv.py index d04f470b..cdf21d4c 100644 --- a/cube/algorithm/ops/conv.py +++ b/nnscaler/algorithm/ops/conv.py @@ -1,10 +1,10 @@ from typing import List, Tuple -from cube.ir.tensor import IRSubTensor -from cube.algorithm.generics import GenericDistAlgo +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.algorithm.generics import GenericDistAlgo -from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D +from nnscaler.graph.function.conv import IRPad, IRConv2D, IRConv3D def _split_axis_custom(tensor: IRSubTensor, dim: int, chunks: List[Tuple[int, int]]): diff --git a/cube/algorithm/ops/dimops.py b/nnscaler/algorithm/ops/dimops.py similarity index 98% rename from cube/algorithm/ops/dimops.py rename to nnscaler/algorithm/ops/dimops.py index 26b2b69c..9c200ad1 100644 --- a/cube/algorithm/ops/dimops.py +++ b/nnscaler/algorithm/ops/dimops.py @@ -1,12 +1,12 @@ from typing import List, Optional, Any, Dict, Union, Tuple import numpy as np import logging -from cube.algorithm.generics import GenericDistAlgo +from nnscaler.algorithm.generics import GenericDistAlgo -from cube.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule -from cube.ir.tensor import IRSubTensor -from cube.ir.cten import IRTensor -from cube.ir.operator import IRFwOperation +from nnscaler.graph.function.dimops import IRDimops, DimAnno, DimopSplit, TransformRule +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir.cten import IRTensor +from nnscaler.ir.operator import IRFwOperation from collections import deque _logger = logging.getLogger(__name__) diff --git a/nnscaler/codegen/__init__.py b/nnscaler/codegen/__init__.py new file mode 100644 index 00000000..ad3e3973 --- /dev/null +++ b/nnscaler/codegen/__init__.py @@ -0,0 +1,2 @@ +from nnscaler.codegen.module.module import ModuleCodeGen +from nnscaler.codegen.schedule.schedule import ScheduleCodeGen diff --git a/cube/codegen/emit.py b/nnscaler/codegen/emit.py similarity index 96% rename from cube/codegen/emit.py rename to nnscaler/codegen/emit.py index ec790afd..b278de0f 100644 --- a/cube/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -1,17 +1,17 @@ from typing import Generator, Iterable, List, Any, Optional, Tuple import logging -from cube.ir.cten import IRCell, IRTensor, IRObject -from cube.ir.tensor import IRSubTensor -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.ir.adapter import IRWeightReducer, IRAdapter -from cube.ir.adapter.prim import CommPrim +from nnscaler.ir.cten import IRCell, IRTensor, IRObject +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.ir.adapter import IRWeightReducer, IRAdapter +from nnscaler.ir.adapter.prim import CommPrim -from cube.graph.segment import IRSegment +from nnscaler.graph.segment import IRSegment -from cube.codegen.frontend_mapping import Sign2EmitRule +from nnscaler.codegen.frontend_mapping import Sign2EmitRule -from cube.flags import CompileFlag +from nnscaler.flags import CompileFlag _logger = logging.getLogger(__name__) diff --git a/cube/codegen/frontend_mapping.py b/nnscaler/codegen/frontend_mapping.py similarity index 95% rename from cube/codegen/frontend_mapping.py rename to nnscaler/codegen/frontend_mapping.py index a09bd171..26d76549 100644 --- a/cube/codegen/frontend_mapping.py +++ b/nnscaler/codegen/frontend_mapping.py @@ -3,9 +3,9 @@ from typing import Callable, Dict, List, Optional -from cube import ir -from cube.ir.cten import IRTensor -from cube.ir.operator import IRFwOperation +from nnscaler import ir +from nnscaler.ir.cten import IRTensor +from nnscaler.ir.operator import IRFwOperation class Sign2EmitRule: diff --git a/cube/codegen/lifecycle.py b/nnscaler/codegen/lifecycle.py similarity index 94% rename from cube/codegen/lifecycle.py rename to nnscaler/codegen/lifecycle.py index 0be80895..9aff2052 100644 --- a/cube/codegen/lifecycle.py +++ b/nnscaler/codegen/lifecycle.py @@ -1,12 +1,12 @@ from typing import Iterable, Dict, List, Any import itertools -from cube.ir.cten import IRCell, IRTensor, IRObject -from cube.ir.tensor import IRSubTensor -from cube.graph.segment import IRSegment -from cube.execplan.execplan import ExeReuseCell +from nnscaler.ir.cten import IRCell, IRTensor, IRObject +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.graph.segment import IRSegment +from nnscaler.execplan.execplan import ExeReuseCell -from cube.codegen.emit import FuncEmission +from nnscaler.codegen.emit import FuncEmission class LifeCycle: diff --git a/cube/codegen/module/__init__.py b/nnscaler/codegen/module/__init__.py similarity index 100% rename from cube/codegen/module/__init__.py rename to nnscaler/codegen/module/__init__.py diff --git a/cube/codegen/module/autograd.py b/nnscaler/codegen/module/autograd.py similarity index 90% rename from cube/codegen/module/autograd.py rename to nnscaler/codegen/module/autograd.py index 25fa2782..a6e25ba8 100644 --- a/cube/codegen/module/autograd.py +++ b/nnscaler/codegen/module/autograd.py @@ -1,12 +1,12 @@ from typing import List -from cube.ir.tensor import IRSubTensor -from cube.ir.adapter import IRAdapter -from cube.ir.adapter.prim import IRAdapterPrim +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir.adapter import IRAdapter +from nnscaler.ir.adapter.prim import IRAdapterPrim -from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock +from nnscaler.codegen.syntax.blocks import ClassBlock, FunctionBlock -from cube.codegen.emit import FuncEmission +from nnscaler.codegen.emit import FuncEmission class AutogradAdapterCodeGen(FuncEmission): diff --git a/cube/codegen/module/module.py b/nnscaler/codegen/module/module.py similarity index 96% rename from cube/codegen/module/module.py rename to nnscaler/codegen/module/module.py index f2d07d7c..a79f01f3 100644 --- a/cube/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -6,26 +6,26 @@ import numpy as np import inspect -from cube.ir.cten import IRCell -from cube.ir.tensor import IRSubTensor -from cube.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation -from cube.ir.adapter import IRWeightReducer, IRAdapter -from cube.ir.adapter.prim import CollectivePrim +from nnscaler.ir.cten import IRCell +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation +from nnscaler.ir.adapter import IRWeightReducer, IRAdapter +from nnscaler.ir.adapter.prim import CollectivePrim -from cube.graph.graph import IRSegment -from cube.graph.parser.register import CustomizedOps +from nnscaler.graph.graph import IRSegment +from nnscaler.graph.parser.register import CustomizedOps -from cube.execplan import ExecutionPlan -from cube.execplan.execplan import ExeReuseCell +from nnscaler.execplan import ExecutionPlan +from nnscaler.execplan.execplan import ExeReuseCell -from cube.codegen.syntax.symtable import SymbolTable -from cube.codegen.syntax.blocks import ClassBlock, FunctionBlock +from nnscaler.codegen.syntax.symtable import SymbolTable +from nnscaler.codegen.syntax.blocks import ClassBlock, FunctionBlock -from cube.codegen.emit import FuncEmission -from cube.codegen.module.autograd import AutogradAdapterCodeGen -from cube.codegen.lifecycle import LifeCycle +from nnscaler.codegen.emit import FuncEmission +from nnscaler.codegen.module.autograd import AutogradAdapterCodeGen +from nnscaler.codegen.lifecycle import LifeCycle -from cube.flags import CompileFlag +from nnscaler.flags import CompileFlag _logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ class ModuleCodeGen(FuncEmission): # intermediate codes for 'adapter456(self, tensor_4444)' [ - 'tensor_5555 = cube.runtime.adapter.all_reduce(tensor_4444, ranks=[0,1,2,3])' + 'tensor_5555 = nnscaler.runtime.adapter.all_reduce(tensor_4444, ranks=[0,1,2,3])' ] ] ``` @@ -119,7 +119,7 @@ def __init__( 'from typing import *', 'from pathlib import Path', 'import torch', 'import torch.utils.checkpoint as ckpt', - 'import cube', 'import _operator', 'from numpy import inf', 'import builtins', '', ''] + 'import nnscaler', 'import _operator', 'from numpy import inf', 'import builtins', '', ''] if CompileFlag.use_nnfusion: self.init_code.extend(['import nnfusion', '']) @@ -460,7 +460,7 @@ def forward(self, x, y=None, z=None): # generate full code with ClassBlock( class_name='GenModel', - derived=[f'cube.runtime.module.{"ParallelModule" if as_parallel_module else "CubeModule"}'] + derived=[f'nnscaler.runtime.module.{"ParallelModule" if as_parallel_module else "CubeModule"}'] ) as cb: if as_parallel_module: cb.insert_body(f'rank = {device}') # save rank in class level @@ -677,7 +677,7 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: reduce_op = f"'{CompileFlag.reducer_op}'" # reducer init interface reducer_init = ( - "{reducer} = cube.runtime.adapter.Reducer(" + "{reducer} = nnscaler.runtime.adapter.Reducer(" "ranks={ranks}, reduce_op={reduce_op}, " "async_op={async_op}, zero={zero}, max_bucket_size_bytes={max_nbytes}, " "zero_ngroups={zero_ngroups})" @@ -766,7 +766,7 @@ def _emit_nodes(self, nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: ``` tensor_2222 = torch.view(tensor_1111, size=[3,6,9]) del tensor_1111 # if no more reference - tensor_3333 = cube.runtime.adapter.allgather_reducescatter(tensor_2222, dim=1, rank=[0,1]) + tensor_3333 = nnscaler.runtime.adapter.allgather_reducescatter(tensor_2222, dim=1, rank=[0,1]) del tensor_2222 # if no more reference ``` diff --git a/cube/codegen/schedule/__init__.py b/nnscaler/codegen/schedule/__init__.py similarity index 100% rename from cube/codegen/schedule/__init__.py rename to nnscaler/codegen/schedule/__init__.py diff --git a/cube/codegen/schedule/schedule.py b/nnscaler/codegen/schedule/schedule.py similarity index 88% rename from cube/codegen/schedule/schedule.py rename to nnscaler/codegen/schedule/schedule.py index 496d7563..a81ab30a 100644 --- a/cube/codegen/schedule/schedule.py +++ b/nnscaler/codegen/schedule/schedule.py @@ -3,18 +3,18 @@ import copy import logging -from cube.ir.cten import IRCell, IRTensor -from cube.ir.operator import IRDataOperation, IRFwOperation -from cube.ir.tensor import IRSubTensor -from cube.ir.adapter import IRWeightReducer, IRAdapter -from cube.graph.graph import IRSegment +from nnscaler.ir.cten import IRCell, IRTensor +from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir.adapter import IRWeightReducer, IRAdapter +from nnscaler.graph.graph import IRSegment -from cube.execplan.execplan import ExecutionPlan, ExeReuseCell +from nnscaler.execplan.execplan import ExecutionPlan, ExeReuseCell -from cube.codegen.emit import FuncEmission -from cube.codegen.syntax.symtable import SymbolTable -from cube.codegen.lifecycle import LifeCycle -from cube.codegen.syntax.blocks import FunctionBlock +from nnscaler.codegen.emit import FuncEmission +from nnscaler.codegen.syntax.symtable import SymbolTable +from nnscaler.codegen.lifecycle import LifeCycle +from nnscaler.codegen.syntax.blocks import FunctionBlock _logger = logging.getLogger(__name__) @@ -57,7 +57,7 @@ def __init__( # model full code self.init_code: List[str] = [ '\n\n########## Generated Schedule Code ###########', - 'import torch', 'import cube', ''] + 'import torch', 'import nnscaler', ''] # module member name self.symbols = SymbolTable() @@ -137,9 +137,9 @@ def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: """ Emit node / subgraph code """ - fsign = '{outputs} = cube.runtime.executor.fexecute({name}, {model}, *{inputs}, requires_grad={req_grad})' - asign = '{outputs} = cube.runtime.executor.aexecute({model}, *{inputs}, requires_grad={req_grad})' - bsign = '{input_grads} = cube.runtime.executor.backward({name}, {input_tensors}, {output_tensors}, {output_grads})' + fsign = '{outputs} = nnscaler.runtime.executor.fexecute({name}, {model}, *{inputs}, requires_grad={req_grad})' + asign = '{outputs} = nnscaler.runtime.executor.aexecute({model}, *{inputs}, requires_grad={req_grad})' + bsign = '{input_grads} = nnscaler.runtime.executor.backward({name}, {input_tensors}, {output_tensors}, {output_grads})' node_inputs, node_outputs = node.inputs(), node.outputs() req_grad = any(t.requires_grad for t in node.outputs() if isinstance(t, IRTensor)) diff --git a/cube/codegen/syntax/__init__.py b/nnscaler/codegen/syntax/__init__.py similarity index 100% rename from cube/codegen/syntax/__init__.py rename to nnscaler/codegen/syntax/__init__.py diff --git a/cube/codegen/syntax/blocks.py b/nnscaler/codegen/syntax/blocks.py similarity index 100% rename from cube/codegen/syntax/blocks.py rename to nnscaler/codegen/syntax/blocks.py diff --git a/cube/codegen/syntax/symtable.py b/nnscaler/codegen/syntax/symtable.py similarity index 100% rename from cube/codegen/syntax/symtable.py rename to nnscaler/codegen/syntax/symtable.py diff --git a/cube/compiler.py b/nnscaler/compiler.py similarity index 89% rename from cube/compiler.py rename to nnscaler/compiler.py index 563966bb..7ad96093 100644 --- a/cube/compiler.py +++ b/nnscaler/compiler.py @@ -4,32 +4,32 @@ import os import logging -import cube +import nnscaler -from cube.ir.cten import IRObject -from cube.ir.tensor import IRFullTensor -from cube.ir.unique import IDGenerator -from cube.graph.gener.gen import IRAdapterGener -from cube.graph.graph import IRGraph -from cube.ir.operator import IRBpOperation -from cube.ir.cten import IRObject -from cube.ir.tensor import IRFullTensor -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.schedule.schedplan import SchedulePlan +from nnscaler.ir.cten import IRObject +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.ir.unique import IDGenerator +from nnscaler.graph.gener.gen import IRAdapterGener +from nnscaler.graph.graph import IRGraph +from nnscaler.ir.operator import IRBpOperation +from nnscaler.ir.cten import IRObject +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.schedule.schedplan import SchedulePlan -from cube.execplan import ExecutionPlan -from cube.execplan.planpass.fusion import DiffFusion -from cube.execplan.planpass.grouping import Grouping +from nnscaler.execplan import ExecutionPlan +from nnscaler.execplan.planpass.fusion import DiffFusion +from nnscaler.execplan.planpass.grouping import Grouping -from cube.codegen import ModuleCodeGen, ScheduleCodeGen +from nnscaler.codegen import ModuleCodeGen, ScheduleCodeGen -from cube.runtime.device import DeviceGroup -from cube.runtime.utils import MicroBatchDataLoader +from nnscaler.runtime.device import DeviceGroup +from nnscaler.runtime.utils import MicroBatchDataLoader -from cube.program import Program, SemanticDataLoader, SemanticModel -from cube.flags import CompileFlag -from cube.utils import print_each_rank +from nnscaler.program import Program, SemanticDataLoader, SemanticModel +from nnscaler.flags import CompileFlag +from nnscaler.utils import print_each_rank _logger = logging.getLogger(__name__) @@ -50,7 +50,7 @@ def compile(model: Union[torch.nn.Module, SemanticModel], *args, Examples: ``` - @cube.compile(model, data, PAS=policy) + @nnscaler.compile(model, data, PAS=policy) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) @@ -88,7 +88,7 @@ def train_iter(model, dataloader): if isinstance(model, torch.nn.Module): model = SemanticModel(model) - assert isinstance(model, SemanticModel), f'Require cube.SemanticModel or torch.nn.Module, but got model: {type(model)}' + assert isinstance(model, SemanticModel), f'Require nnscaler.SemanticModel or torch.nn.Module, but got model: {type(model)}' model.save_content = load_content model.dynamic_shape = model_dynamic_shape @@ -120,7 +120,7 @@ def decorator(fn: Callable) -> Callable: model.load_module(filename) # load schedule code _logger.info(f'loading existed schedule from {filename} ...') - return cube.load_default_schedule(filename) + return nnscaler.load_default_schedule(filename) ndevices = DeviceGroup().world_size local_ndevs = DeviceGroup().local_world_size @@ -132,7 +132,7 @@ def decorator(fn: Callable) -> Callable: if DeviceGroup().local_rank == 0: compile_start = time.time() - resource = cube.runtime.resource.EnvResource() + resource = nnscaler.runtime.resource.EnvResource() # run once to get model structure and tensor shape graph = None @@ -299,6 +299,6 @@ def decorator(fn: Callable) -> Callable: model.dummy_input = None # load temporal schedule print_each_rank(f'loading generated schedule from {filename} ...', logger=_logger) - return cube.load_default_schedule(filename) + return nnscaler.load_default_schedule(filename) return decorator diff --git a/nnscaler/execplan/__init__.py b/nnscaler/execplan/__init__.py new file mode 100644 index 00000000..a542cec3 --- /dev/null +++ b/nnscaler/execplan/__init__.py @@ -0,0 +1 @@ +from nnscaler.execplan.execplan import ExecutionPlan \ No newline at end of file diff --git a/cube/execplan/execplan.py b/nnscaler/execplan/execplan.py similarity index 98% rename from cube/execplan/execplan.py rename to nnscaler/execplan/execplan.py index 7aed4990..fd4ba6ea 100644 --- a/cube/execplan/execplan.py +++ b/nnscaler/execplan/execplan.py @@ -3,12 +3,12 @@ import numpy as np import sys -from cube.ir.cten import IRCell, IRObject -from cube.ir.tensor import IRSubTensor, IRFullTensor -from cube.ir.adapter import IRAdapter, IRWeightReducer -from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.graph.graph import IRGraph, IRSegment -from cube.graph.schedule.schedplan import SchedulePlan, Block +from nnscaler.ir.cten import IRCell, IRObject +from nnscaler.ir.tensor import IRSubTensor, IRFullTensor +from nnscaler.ir.adapter import IRAdapter, IRWeightReducer +from nnscaler.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation +from nnscaler.graph.graph import IRGraph, IRSegment +from nnscaler.graph.schedule.schedplan import SchedulePlan, Block class ExeReuseCell(IRCell): diff --git a/cube/execplan/planpass/__init__.py b/nnscaler/execplan/planpass/__init__.py similarity index 100% rename from cube/execplan/planpass/__init__.py rename to nnscaler/execplan/planpass/__init__.py diff --git a/cube/execplan/planpass/fusion.py b/nnscaler/execplan/planpass/fusion.py similarity index 86% rename from cube/execplan/planpass/fusion.py rename to nnscaler/execplan/planpass/fusion.py index ebe5024e..0b27c70b 100644 --- a/cube/execplan/planpass/fusion.py +++ b/nnscaler/execplan/planpass/fusion.py @@ -1,20 +1,20 @@ from typing import List, Union, Set import logging -from cube.graph.graph import IRSegment +from nnscaler.graph.graph import IRSegment -from cube.ir.adapter import IRAdapter +from nnscaler.ir.adapter import IRAdapter -from cube.execplan import ExecutionPlan -from cube.execplan.execplan import ExeReuseCell -from cube.execplan.planpass.planpass import PlanPass +from nnscaler.execplan import ExecutionPlan +from nnscaler.execplan.execplan import ExeReuseCell +from nnscaler.execplan.planpass.planpass import PlanPass -from cube.ir.adapter.prim import IRAdapterPrim -from cube.ir.adapter.prim import AllReducePrim, AllGatherPrim, ReduceScatterPrim, AllToAllPrim -from cube.ir.adapter.prim import IdentityPrim, ChunkPrim -from cube.ir.adapter.prim import IdentityAllreducePrim, AllReduceIdentityPrim, AllReduceAllReducePrim -from cube.ir.adapter.prim import AllGatherReduceScatterPrim, ReduceScatterAllGatherPrim -from cube.ir.adapter.prim import SplitAllGatherPrim, AllGatherSplitPrim -from cube.ir.adapter.prim import AllToAllAllToAllPrim +from nnscaler.ir.adapter.prim import IRAdapterPrim +from nnscaler.ir.adapter.prim import AllReducePrim, AllGatherPrim, ReduceScatterPrim, AllToAllPrim +from nnscaler.ir.adapter.prim import IdentityPrim, ChunkPrim +from nnscaler.ir.adapter.prim import IdentityAllreducePrim, AllReduceIdentityPrim, AllReduceAllReducePrim +from nnscaler.ir.adapter.prim import AllGatherReduceScatterPrim, ReduceScatterAllGatherPrim +from nnscaler.ir.adapter.prim import SplitAllGatherPrim, AllGatherSplitPrim +from nnscaler.ir.adapter.prim import AllToAllAllToAllPrim _logger = logging.getLogger(__name__) diff --git a/cube/execplan/planpass/grouping.py b/nnscaler/execplan/planpass/grouping.py similarity index 89% rename from cube/execplan/planpass/grouping.py rename to nnscaler/execplan/planpass/grouping.py index 3f6ea81d..dc3e3055 100644 --- a/cube/execplan/planpass/grouping.py +++ b/nnscaler/execplan/planpass/grouping.py @@ -3,15 +3,15 @@ """ from typing import List, Dict, Tuple -from cube.execplan import ExecutionPlan -from cube.execplan.planpass.planpass import PlanPass -from cube.ir.adapter import IRAdapter -from cube.ir.adapter.prim import IdentityPrim -from cube.ir.operator import IRFwOperation -from cube.graph.function.pyfunc import IRPyFunc -from cube.ir.cten import IRCell +from nnscaler.execplan import ExecutionPlan +from nnscaler.execplan.planpass.planpass import PlanPass +from nnscaler.ir.adapter import IRAdapter +from nnscaler.ir.adapter.prim import IdentityPrim +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.ir.cten import IRCell -from cube.flags import CompileFlag +from nnscaler.flags import CompileFlag class Grouping(PlanPass): diff --git a/cube/execplan/planpass/planpass.py b/nnscaler/execplan/planpass/planpass.py similarity index 74% rename from cube/execplan/planpass/planpass.py rename to nnscaler/execplan/planpass/planpass.py index 3d079b2b..373d21ad 100644 --- a/cube/execplan/planpass/planpass.py +++ b/nnscaler/execplan/planpass/planpass.py @@ -1,4 +1,4 @@ -from cube.execplan import ExecutionPlan +from nnscaler.execplan import ExecutionPlan class PlanPass: diff --git a/cube/flags.py b/nnscaler/flags.py similarity index 100% rename from cube/flags.py rename to nnscaler/flags.py diff --git a/nnscaler/graph/__init__.py b/nnscaler/graph/__init__.py new file mode 100644 index 00000000..b258cbe6 --- /dev/null +++ b/nnscaler/graph/__init__.py @@ -0,0 +1,2 @@ +from nnscaler.graph.graph import IRGraph +from nnscaler.graph import parser diff --git a/nnscaler/graph/function/__init__.py b/nnscaler/graph/function/__init__.py new file mode 100644 index 00000000..8afb8034 --- /dev/null +++ b/nnscaler/graph/function/__init__.py @@ -0,0 +1,2 @@ +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.function.function import * \ No newline at end of file diff --git a/cube/graph/function/anchor.py b/nnscaler/graph/function/anchor.py similarity index 89% rename from cube/graph/function/anchor.py rename to nnscaler/graph/function/anchor.py index 4a999f24..8f7fd236 100644 --- a/cube/graph/function/anchor.py +++ b/nnscaler/graph/function/anchor.py @@ -1,6 +1,6 @@ -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRObject +from nnscaler.ir.operator import IRFwOperation +from nnscaler.ir.cten import IRObject class IRGraphAnchor(IRFwOperation): @@ -23,7 +23,7 @@ def __init__(self): def forward(self, x): for layer in self.layers: - cube.runtime.function.anchor('layer start') + nnscaler.runtime.function.anchor('layer start') x = layer(x) return x ``` diff --git a/cube/graph/function/conv.py b/nnscaler/graph/function/conv.py similarity index 96% rename from cube/graph/function/conv.py rename to nnscaler/graph/function/conv.py index 771343fa..85e8e99c 100644 --- a/cube/graph/function/conv.py +++ b/nnscaler/graph/function/conv.py @@ -1,7 +1,7 @@ from typing import List -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRTensor +from nnscaler.ir.operator import IRFwOperation +from nnscaler.ir.cten import IRTensor class IRPad(IRFwOperation): @@ -52,7 +52,7 @@ class IRConv2D(IRFwOperation): def __init__(self, signature: str, inputs: List[IRTensor], name: str, **kwargs): - signature = 'cube.runtime.function.conv2d' + signature = 'nnscaler.runtime.function.conv2d' assert len(inputs) == 3, "Expected only input, weight, bias as inputs" assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" super().__init__(name, signature, inputs, 1, **kwargs) @@ -98,7 +98,7 @@ class IRConv3D(IRFwOperation): def __init__(self, signature: str, inputs: List[IRTensor], name: str, **kwargs): - signature = 'cube.runtime.function.conv3d' + signature = 'nnscaler.runtime.function.conv3d' assert len(inputs) == 3, "Expected only input, weight, bias as inputs" assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" super().__init__(name, signature, inputs, 1, **kwargs) diff --git a/cube/graph/function/dimops.py b/nnscaler/graph/function/dimops.py similarity index 99% rename from cube/graph/function/dimops.py rename to nnscaler/graph/function/dimops.py index 8929c179..23962a35 100644 --- a/cube/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -68,9 +68,9 @@ import string import logging -from cube.ir.cten import IRTensor, IRObject -from cube.ir.operator import IRFwOperation -from cube.algorithm.factory import DistAlgorithmFactory +from nnscaler.ir.cten import IRTensor, IRObject +from nnscaler.ir.operator import IRFwOperation +from nnscaler.algorithm.factory import DistAlgorithmFactory _kSpecialIdentifiers = ('*', '?') diff --git a/cube/graph/function/function.py b/nnscaler/graph/function/function.py similarity index 98% rename from cube/graph/function/function.py rename to nnscaler/graph/function/function.py index f3c84e16..6e86de91 100644 --- a/cube/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -8,12 +8,12 @@ import logging from collections.abc import Iterable -from cube.ir.cten import IRTensor, IRObject -from cube.ir.tensor import IRSubTensor, IRFullTensor -from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule -from cube.graph.function.conv import IRPad, IRConv2D, IRConv3D -from cube.graph.function.anchor import IRGraphAnchor +from nnscaler.ir.cten import IRTensor, IRObject +from nnscaler.ir.tensor import IRSubTensor, IRFullTensor +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule +from nnscaler.graph.function.conv import IRPad, IRConv2D, IRConv3D +from nnscaler.graph.function.anchor import IRGraphAnchor _logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ def ir_object_contains_dynamic(obj: IRObject): def Identity(tensor: IRObject, signature = None): - signature = 'cube.runtime.function.identity' + signature = 'nnscaler.runtime.function.identity' eshape = ShapeAnno.create_shape_str(tensor.shape) anno = OpAnno.create_op_str([eshape], [eshape]) return IRDimops(Identity, 'identity', signature, [anno], [tensor]) @@ -81,9 +81,9 @@ def Identity(tensor: IRObject, signature = None): def MultiRef(tensor: IRTensor, times: int, signature = None): """ - cube.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] + nnscaler.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] """ - signature = 'cube.runtime.function.multiref' + signature = 'nnscaler.runtime.function.multiref' assert isinstance(tensor, IRTensor), "require all inputs to be IRSubTensor" assert isinstance(times, int), "require int for second input" anno = '* -> ' + ', '.join('*' for _ in range(times)) @@ -93,10 +93,10 @@ def MultiRef(tensor: IRTensor, times: int, signature = None): def Accum(*inputs, signature = None): """ - tensor = cube.runtime.function.accum(tensors) + tensor = nnscaler.runtime.function.accum(tensors) """ assert all(isinstance(t, IRTensor) for t in inputs) - signature = 'cube.runtime.function.accum' + signature = 'nnscaler.runtime.function.accum' iannos = [ShapeAnno.create_shape_str(t.shape) for t in inputs] oannos = [copy.copy(iannos[0])] anno = OpAnno.create_op_str(iannos, oannos) @@ -136,7 +136,7 @@ def BMMAdd(input, batch1, batch2, *, beta=1, alpha=1, out=None, signature = None def CubeEinSum(*operands, equation=None, signature = None): assert isinstance(equation, str) - signature = 'cube.runtime.function.einsum' + signature = 'nnscaler.runtime.function.einsum' lhs, rhs = equation.split('->') assert ',' not in rhs lhs_dims = set(lhs.replace(',', ' ').split(' ')) @@ -201,7 +201,7 @@ def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Uni else: dtype = torch.int64 assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" - signature = 'cube.runtime.function.arange' + signature = 'nnscaler.runtime.function.arange' kwargs = {'start': start, 'end': end, 'step': step, 'dtype': dtype, 'requires_grad': requires_grad} start_val = start.value if isinstance(start, IRObject) else start @@ -234,7 +234,7 @@ def CubeLinspace(start: Union[int, IRObject], end: Union[int, IRObject], steps: dtype=None, requires_grad=False, signature=None): dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" - signature = 'cube.runtime.function.linspace' + signature = 'nnscaler.runtime.function.linspace' kwargs = {'start': start, 'end': end, 'steps': steps, 'dtype': dtype, 'requires_grad': requires_grad} steps_val = steps.value if isinstance(steps, IRObject) else steps @@ -259,7 +259,7 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi assert layout in (None, torch.strided) and memory_format is None, f"Not support for non-default memory_format and layout" dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" - signature = 'cube.runtime.function.empty' + signature = 'nnscaler.runtime.function.empty' size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) size: Tuple[Union[int, IRObject]] = size + arg_size kwargs = {'size': size, 'requires_grad': requires_grad, @@ -275,7 +275,7 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, assert layout in (None, torch.strided), f"Not support for non-strided layout, get {layout}" dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" - signature = 'cube.runtime.function.zeros' + signature = 'nnscaler.runtime.function.zeros' size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) size: Tuple[Union[int, IRObject]] = size + arg_size kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} @@ -290,7 +290,7 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, assert layout in (None, torch.strided), f"Not support for non-strided layout, get {layout}" dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" - signature = 'cube.runtime.function.ones' + signature = 'nnscaler.runtime.function.ones' size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) size: Tuple[Union[int, IRObject]] = size + arg_size kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} @@ -305,7 +305,7 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir assert layout in (None, torch.strided) and memory_format is None, f"Not support for non-default memory_format and layout" dtype = dtype if dtype is not None else torch.get_default_dtype() assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" - signature = 'cube.runtime.function.rand' + signature = 'nnscaler.runtime.function.rand' size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) size: Tuple[Union[int, IRObject]] = size + arg_size kwargs = {'size': size, 'requires_grad': requires_grad, @@ -322,7 +322,7 @@ def Full(size, fill_value, *, out=None, dtype=None, layout=None, """ assert layout in (None, torch.strided), f"Not support for non-default layout" dtype = dtype if dtype is not None else torch.get_default_dtype() - signature = 'cube.runtime.function.full' + signature = 'nnscaler.runtime.function.full' # cube treat scalar as size (1,) tensor now, scalar support will in another pr if necessary size = tuple(size) if size else (1,) anno, rules = _get_creator_anno_rules( @@ -336,7 +336,7 @@ def NewTensor(data, *, dtype=None, device=None, """ torch.tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) """ - signature = 'cube.runtime.function.tensor' + signature = 'nnscaler.runtime.function.tensor' val = data if isinstance(data, IRTensor): @@ -794,7 +794,7 @@ def nnDropout(input, p=0.5, inplace=False, signature=None): """ torch.nn.Dropout(p=0.5, inplace=False) """ - signature = 'cube.runtime.function.nndropout' + signature = 'nnscaler.runtime.function.nndropout' annos = ['* -> *'] return IRDimops(nnDropout, 'Dropout', signature, annos, [input], p=p, inplace=inplace) @@ -931,9 +931,9 @@ def Where(condition, input=None, other=None, *, out=None, signature = None): def CubeLayerNorm(input, weight=None, bias=None, normalized_shape=None, eps=1e-05, signature = None): """ - cube.runtime.function.layer_norm(input, weight, bias, normliazed_shape, eps) + nnscaler.runtime.function.layer_norm(input, weight, bias, normliazed_shape, eps) """ - signature = 'cube.runtime.function.layer_norm' + signature = 'nnscaler.runtime.function.layer_norm' assert not (weight is None and bias is not None), f"Not support for None of weight and parameter of bias" letters = iter(string.ascii_lowercase) einput = ShapeAnno.create_shape_str(input.shape, iterator=letters) @@ -1516,7 +1516,7 @@ def CubeCat(*tensors, dim=0, signature = None): # with dimension. dim=None is for the support of kwarg inputs from torchfx assert all(isinstance(tensor, IRTensor) for tensor in tensors) assert isinstance(dim, int) - signature = 'cube.runtime.function.cat' + signature = 'nnscaler.runtime.function.cat' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] dimlens = [t.shape[dim] for t in tensors] for ashape, dimlen in zip(iannos, dimlens): @@ -1541,7 +1541,7 @@ def CubeStack(*tensors, dim=0, signature=None): # with dimension. assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'but got {tensors}' assert isinstance(dim, int), f"but not {dim}" - signature = 'cube.runtime.function.stack' + signature = 'nnscaler.runtime.function.stack' iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] oanno = [None for i in range(len(tensors[0].shape) + 1)] oanno[dim] = f'{len(tensors)}^' @@ -1594,7 +1594,7 @@ def Select(input, dim, index, signature = None): def CubeIndexSelect(input: torch.Tensor, index: torch.Tensor, dim: int, signature = None): - signature = 'cube.runtime.function.index_select' + signature = 'nnscaler.runtime.function.index_select' edim_in = ShapeAnno.create_shape_str(input.shape) edim_in[dim] += '^' idx_anno = chr(ord(edim_in[-1]) + 1) @@ -1622,7 +1622,7 @@ def FullSlice(tensor: IRTensor, *slicers: Tuple[Union[None, slice, int, IRTensor >>> a[(2, None, slice(None, None, None))] # shape [1,2] >>> a[(2, torch.tensor([0, 1]), None)] # shape [2,1] """ - signature = 'cube.runtime.function.fullslice' + signature = 'nnscaler.runtime.function.fullslice' # deal with ... in slice if any(slicer is Ellipsis for slicer in slicers): @@ -1767,7 +1767,7 @@ def SelectScatter(self: torch.Tensor, input: torch.Tensor, dim: int, index: int, torch.select_scatter(self:Tensor, input:Tensor, dim:int, index:int) -> Tensor """ # 'torch.select_scatter' isn't supported by Torch2ONNX yet. - signature = 'cube.runtime.function.select_scatter' + signature = 'nnscaler.runtime.function.select_scatter' # shape check self_shape, input_shape = self.shape, input.shape self_shape.pop(dim) @@ -1823,9 +1823,9 @@ def Repeat(tensor, repeats: _VariadicInt, *arg_repeats, signature = None): def CubeEmbedding(input, weight, padding_idx, signature = None, **kwargs): """ - cube.runtime.function.embedding(input, weight, padding_idx, start, stop) + nnscaler.runtime.function.embedding(input, weight, padding_idx, start, stop) """ - signature = 'cube.runtime.function.embedding' + signature = 'nnscaler.runtime.function.embedding' if isinstance(weight, IRSubTensor): start, stop = weight.indmap[0] else: @@ -1924,7 +1924,7 @@ def CrossEntropy(input, target, weight=None, def GraphAnchor(name: str, signature = None): """ - cube.runtime.function.anchor() -> None + nnscaler.runtime.function.anchor() -> None """ node = IRGraphAnchor(signature, name) return node @@ -2086,10 +2086,10 @@ def To(tensor: IRTensor, dtype_or_device=None, *, device=None, dtype=None, out=N if isinstance(dtype_or_device, torch.device) or isinstance(device, torch.device): warn_msg = 'Cube will handle the tensor device placement, the call of torch.Tensor.to(device=...) will be ignore, ' \ 'if you really want to put the tensor on cpu to excute some op, please wrap all related ops in an independent function ' \ - 'and using cube.graph.parser.register to register this function.' + 'and using nnscaler.graph.parser.register to register this function.' _logger.warning(warn_msg) # create "to" in cube runtime functions because dtype if not kwarg in torch.Tensor.to - signature = 'cube.runtime.function.to' + signature = 'nnscaler.runtime.function.to' annos = ['* -> *'] if isinstance(dtype_or_device, torch.device): # skip device movement as policy can determine device for the tensor. @@ -2133,8 +2133,8 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: def SetItem(__a: Any, __b: Any, __c: Any, signature = None) -> Union[Any, IRPyFunc]: - """_operator.setitem(__a, __b, __c) / cube.runtime.function.setitem(__a, __b, __c)""" - signature = 'cube.runtime.function.setitem' + """_operator.setitem(__a, __b, __c) / nnscaler.runtime.function.setitem(__a, __b, __c)""" + signature = 'nnscaler.runtime.function.setitem' obj, index, val = __a, __b, __c if isinstance(obj, IRTensor): # TODO: move to some function like FullSlice when ready diff --git a/cube/graph/function/pyfunc.py b/nnscaler/graph/function/pyfunc.py similarity index 92% rename from cube/graph/function/pyfunc.py rename to nnscaler/graph/function/pyfunc.py index 68430e31..ddc34dd8 100644 --- a/cube/graph/function/pyfunc.py +++ b/nnscaler/graph/function/pyfunc.py @@ -1,7 +1,7 @@ from typing import Tuple -from cube.ir.operator import IRFwOperation -from cube.ir.cten import IRObject +from nnscaler.ir.operator import IRFwOperation +from nnscaler.ir.cten import IRObject class IRPyFunc(IRFwOperation): diff --git a/cube/graph/gener/__init__.py b/nnscaler/graph/gener/__init__.py similarity index 100% rename from cube/graph/gener/__init__.py rename to nnscaler/graph/gener/__init__.py diff --git a/cube/graph/gener/concurrent.py b/nnscaler/graph/gener/concurrent.py similarity index 97% rename from cube/graph/gener/concurrent.py rename to nnscaler/graph/gener/concurrent.py index b16deb7f..c90facf7 100644 --- a/cube/graph/gener/concurrent.py +++ b/nnscaler/graph/gener/concurrent.py @@ -6,16 +6,16 @@ import numpy as np import logging -from cube.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap -from cube.ir.adapter.prim import IRAdapterPrim -from cube.ir.adapter import IRAdapter -from cube.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim -from cube.ir.adapter.prim import BroadcastPrim - -from cube.graph.gener.rvd.layout import RVDLayout -from cube.graph.gener.rvd.intra import IntraPathFinder -from cube.graph.gener.rvd.inter import InterPathFinder -from cube.flags import CompileFlag +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap +from nnscaler.ir.adapter.prim import IRAdapterPrim +from nnscaler.ir.adapter import IRAdapter +from nnscaler.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim +from nnscaler.ir.adapter.prim import BroadcastPrim + +from nnscaler.graph.gener.rvd.layout import RVDLayout +from nnscaler.graph.gener.rvd.intra import IntraPathFinder +from nnscaler.graph.gener.rvd.inter import InterPathFinder +from nnscaler.flags import CompileFlag _logger = logging.getLogger(__name__) diff --git a/cube/graph/gener/gen.py b/nnscaler/graph/gener/gen.py similarity index 98% rename from cube/graph/gener/gen.py rename to nnscaler/graph/gener/gen.py index 39170cd4..070c83d3 100644 --- a/cube/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -3,19 +3,19 @@ import itertools import logging -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.gener.concurrent import ConcurrentGener -import cube.graph.gener.utils as utils -from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment, CellPosition -from cube.graph.function.pyfunc import IRPyFunc - -from cube.ir.cten import IRCell, IRObject -from cube.ir.tensor import IRFullTensor, IRSubTensor -from cube.ir.operator import IRFwOperation - -from cube.ir.adapter import IRAdapter, IRWeightReducer -from cube.graph.function.function import Accum, Cat, MultiRef +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.gener.concurrent import ConcurrentGener +import nnscaler.graph.gener.utils as utils +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.segment import IRSegment, CellPosition +from nnscaler.graph.function.pyfunc import IRPyFunc + +from nnscaler.ir.cten import IRCell, IRObject +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor +from nnscaler.ir.operator import IRFwOperation + +from nnscaler.ir.adapter import IRAdapter, IRWeightReducer +from nnscaler.graph.function.function import Accum, Cat, MultiRef DeviceID = int diff --git a/cube/graph/gener/rvd/__init__.py b/nnscaler/graph/gener/rvd/__init__.py similarity index 100% rename from cube/graph/gener/rvd/__init__.py rename to nnscaler/graph/gener/rvd/__init__.py diff --git a/cube/graph/gener/rvd/inter.py b/nnscaler/graph/gener/rvd/inter.py similarity index 97% rename from cube/graph/gener/rvd/inter.py rename to nnscaler/graph/gener/rvd/inter.py index 41e35589..f05dbd30 100644 --- a/cube/graph/gener/rvd/inter.py +++ b/nnscaler/graph/gener/rvd/inter.py @@ -4,17 +4,17 @@ import sys import copy -from cube.ir.tensor import IRFullTensor +from nnscaler.ir.tensor import IRFullTensor -from cube.ir.adapter.prim import IRAdapterPrim -from cube.ir.adapter.prim import MovePrim # p2p -from cube.ir.adapter.prim import BroadcastPrim -from cube.ir.adapter.prim import RDScatterPrim, RVScatterPrim -from cube.ir.adapter.prim import RDGatherPrim, RVGatherPrim +from nnscaler.ir.adapter.prim import IRAdapterPrim +from nnscaler.ir.adapter.prim import MovePrim # p2p +from nnscaler.ir.adapter.prim import BroadcastPrim +from nnscaler.ir.adapter.prim import RDScatterPrim, RVScatterPrim +from nnscaler.ir.adapter.prim import RDGatherPrim, RVGatherPrim -from cube.graph.gener.rvd.layout import RVDLayout -from cube.graph.gener.rvd.intra import IntraPathFinder -from cube.graph.gener.utils import tensor_vd_repr +from nnscaler.graph.gener.rvd.layout import RVDLayout +from nnscaler.graph.gener.rvd.intra import IntraPathFinder +from nnscaler.graph.gener.utils import tensor_vd_repr TShape = Tuple[int, ...] diff --git a/cube/graph/gener/rvd/intra.py b/nnscaler/graph/gener/rvd/intra.py similarity index 97% rename from cube/graph/gener/rvd/intra.py rename to nnscaler/graph/gener/rvd/intra.py index 2a705a54..08ae8a00 100644 --- a/cube/graph/gener/rvd/intra.py +++ b/nnscaler/graph/gener/rvd/intra.py @@ -5,22 +5,22 @@ import logging import torch -from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor - -from cube.ir.adapter.prim import IRAdapterPrim -from cube.ir.adapter.prim import AllGatherPrim # d2r -from cube.ir.adapter.prim import AllToAllPrim # d2d -from cube.ir.adapter.prim import AllReducePrim # v2r -from cube.ir.adapter.prim import ReduceScatterPrim # v2d -from cube.ir.adapter.prim import ChunkPrim # r2d -from cube.ir.adapter.prim import VChunkPrim # r2v - -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.graph.gener.rvd.layout import RVDLayout - -from cube.graph.gener.utils import tensor_vd_repr +from nnscaler.ir.cten import IRCell +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor + +from nnscaler.ir.adapter.prim import IRAdapterPrim +from nnscaler.ir.adapter.prim import AllGatherPrim # d2r +from nnscaler.ir.adapter.prim import AllToAllPrim # d2d +from nnscaler.ir.adapter.prim import AllReducePrim # v2r +from nnscaler.ir.adapter.prim import ReduceScatterPrim # v2d +from nnscaler.ir.adapter.prim import ChunkPrim # r2d +from nnscaler.ir.adapter.prim import VChunkPrim # r2v + +from nnscaler.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.gener.rvd.layout import RVDLayout + +from nnscaler.graph.gener.utils import tensor_vd_repr _logger = logging.getLogger(__name__) TShape = Tuple[int, ...] diff --git a/cube/graph/gener/rvd/layout.py b/nnscaler/graph/gener/rvd/layout.py similarity index 98% rename from cube/graph/gener/rvd/layout.py rename to nnscaler/graph/gener/rvd/layout.py index 656b5261..7fd2b0b3 100644 --- a/cube/graph/gener/rvd/layout.py +++ b/nnscaler/graph/gener/rvd/layout.py @@ -2,9 +2,9 @@ import copy import numpy as np -from cube.ir.cten import IRCell -from cube.ir.tensor import IRFullTensor, IRSubTensor -from cube.ir.tensor import ValueMap +from nnscaler.ir.cten import IRCell +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor +from nnscaler.ir.tensor import ValueMap TShape = Tuple[int, ...] diff --git a/cube/graph/gener/utils.py b/nnscaler/graph/gener/utils.py similarity index 96% rename from cube/graph/gener/utils.py rename to nnscaler/graph/gener/utils.py index 0e8f8e5e..9fd36a96 100644 --- a/cube/graph/gener/utils.py +++ b/nnscaler/graph/gener/utils.py @@ -2,10 +2,10 @@ Utilities for gradient modification """ from typing import Dict, List, Union, Tuple -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap +from nnscaler.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRFwOperation +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor, ValueMap class DummyInputOuput(IRFwOperation): diff --git a/cube/graph/graph.py b/nnscaler/graph/graph.py similarity index 98% rename from cube/graph/graph.py rename to nnscaler/graph/graph.py index b51f1926..23c4a587 100644 --- a/cube/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -13,18 +13,18 @@ import dill import hashlib -from cube.ir.cten import IRTensor, IRCell, IRObject -from cube.ir.unique import IDGenerator -from cube.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation -from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap +from nnscaler.ir.cten import IRTensor, IRCell, IRObject +from nnscaler.ir.unique import IDGenerator +from nnscaler.ir.operator import IRBpOperation, IRFwOperation, IRDataOperation +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor, ValueMap -from cube.graph.function.function import Identity -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.function.dimops import IRDimops, OpAnno -from cube.graph.segment import IRSegment +from nnscaler.graph.function.function import Identity +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.function.dimops import IRDimops, OpAnno +from nnscaler.graph.segment import IRSegment -from cube.algorithm.generics import GenericDistAlgo +from nnscaler.algorithm.generics import GenericDistAlgo _logger = logging.getLogger(__name__) @@ -107,7 +107,7 @@ def forward(self, *args: Tuple[IRObject]) -> Union[IRTensor, Tuple[IRTensor]]: self.output(oidx), lambda t: t if t != iobj else arg) self.set_output(oidx, output) - from cube.program import Program + from nnscaler.program import Program Program().add_nodes(self.nodes()) # return @@ -476,7 +476,7 @@ def fuse_op_fn(*args, **kwargs) -> IRDimops: return IRDimops(fuse_op_fn, fuse_op_name, signature, [fuse_op_anno], args, **kwargs) if make_customized_op: - from cube.graph.parser.register import CustomizedOps + from nnscaler.graph.parser.register import CustomizedOps def to_name(t: Any) -> str: """Convert an object to its name.""" @@ -680,7 +680,7 @@ def _bind_schedule(self, schedplan): Returns: None """ - from cube.graph.schedule import SchedulePlan + from nnscaler.graph.schedule import SchedulePlan if not isinstance(schedplan, SchedulePlan): raise TypeError(f"Expect a SchedulePlan but got: {type(schedplan)}") assert self._sched is None, "The graph is already bound with one schedule plan." diff --git a/nnscaler/graph/parser/__init__.py b/nnscaler/graph/parser/__init__.py new file mode 100644 index 00000000..b06caafd --- /dev/null +++ b/nnscaler/graph/parser/__init__.py @@ -0,0 +1,4 @@ +from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph +from nnscaler.graph.parser.register import register +from nnscaler.graph.parser.external import * diff --git a/cube/graph/parser/converter.py b/nnscaler/graph/parser/converter.py similarity index 90% rename from cube/graph/parser/converter.py rename to nnscaler/graph/parser/converter.py index 5e11e0f0..1455c3c9 100644 --- a/cube/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -3,17 +3,17 @@ from pathlib import Path import operator -from cube.ir.tensor import IRFullTensor -from cube.graph.parser.register import CustomizedOps -from cube.graph import IRGraph -from cube.flags import CompileFlag +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph.parser.register import CustomizedOps +from nnscaler.graph import IRGraph +from nnscaler.flags import CompileFlag -from cube.graph.parser.fx.parser import FxModuleParser -from cube.graph.parser.fx.concrete_trace_utils import concrete_trace -from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import Location, is_autograd_apply, LeafFnWrapInfo -from cube.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops +from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.parser.fx.concrete_trace_utils import concrete_trace +from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import Location, is_autograd_apply, LeafFnWrapInfo +from nnscaler.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops -import cube.runtime.function as cube_rt_function +import nnscaler.runtime.function as cube_rt_function import torch import torch.fx diff --git a/cube/graph/parser/external/__init__.py b/nnscaler/graph/parser/external/__init__.py similarity index 100% rename from cube/graph/parser/external/__init__.py rename to nnscaler/graph/parser/external/__init__.py diff --git a/cube/graph/parser/external/apex.py b/nnscaler/graph/parser/external/apex.py similarity index 97% rename from cube/graph/parser/external/apex.py rename to nnscaler/graph/parser/external/apex.py index ba50f074..fc56fb62 100644 --- a/cube/graph/parser/external/apex.py +++ b/nnscaler/graph/parser/external/apex.py @@ -2,8 +2,8 @@ import logging import string -from cube.graph.function.dimops import ShapeAnno, OpAnno -from cube.graph import parser +from nnscaler.graph.function.dimops import ShapeAnno, OpAnno +from nnscaler.graph import parser _logger = logging.getLogger(__name__) diff --git a/cube/graph/parser/frame.py b/nnscaler/graph/parser/frame.py similarity index 99% rename from cube/graph/parser/frame.py rename to nnscaler/graph/parser/frame.py index 1a9c8e7d..1c852946 100644 --- a/cube/graph/parser/frame.py +++ b/nnscaler/graph/parser/frame.py @@ -1,6 +1,6 @@ from collections import OrderedDict from typing import List, Any, Dict, Tuple, Optional -from cube.ir.cten import IRTensor +from nnscaler.ir.cten import IRTensor import torch diff --git a/cube/graph/parser/fx/concrete_trace_utils/__init__.py b/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py similarity index 100% rename from cube/graph/parser/fx/concrete_trace_utils/__init__.py rename to nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py similarity index 100% rename from cube/graph/parser/fx/concrete_trace_utils/concrete_proxy.py rename to nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py similarity index 100% rename from cube/graph/parser/fx/concrete_trace_utils/concrete_tracer.py rename to nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/function_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py similarity index 100% rename from cube/graph/parser/fx/concrete_trace_utils/function_patcher.py rename to nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py diff --git a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py similarity index 99% rename from cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py rename to nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py index d89e3f78..953a439f 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -182,7 +182,7 @@ def patch_func_helper(self, func): 7. func(a, b, c) -> patch_run(func, a, b, c) # for patch the functions called in the current function """ if not hasattr(func, '__module__') or func.__module__ is None \ - or func.__module__.startswith('torch.') or func.__module__.startswith('cube.'): + or func.__module__.startswith('torch.') or func.__module__.startswith('nnscaler.'): return func # those flags are set by fx _Patcher when a method is patched # we don't want to patch it again diff --git a/cube/graph/parser/fx/concrete_trace_utils/utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py similarity index 99% rename from cube/graph/parser/fx/concrete_trace_utils/utils.py rename to nnscaler/graph/parser/fx/concrete_trace_utils/utils.py index 22233762..dec5178e 100644 --- a/cube/graph/parser/fx/concrete_trace_utils/utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py @@ -301,7 +301,7 @@ def __repr__(self) -> str: def get_frame_record() -> Optional[FrameRecord]: # record code frame, include filename, line number, and function name frame_record = None - cube_path = str(Path(importlib.util.find_spec('cube').origin).parent) + '/' # the cube path + cube_path = str(Path(importlib.util.find_spec('nnscaler').origin).parent) + '/' # the cube path torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path ignore_dirs = [cube_path, torch_path] # the last frame is the current frame [get_frame_record], so we need to skip it diff --git a/cube/graph/parser/fx/mapping.py b/nnscaler/graph/parser/fx/mapping.py similarity index 97% rename from cube/graph/parser/fx/mapping.py rename to nnscaler/graph/parser/fx/mapping.py index 0ac80d11..ac4b8e5e 100644 --- a/cube/graph/parser/fx/mapping.py +++ b/nnscaler/graph/parser/fx/mapping.py @@ -2,9 +2,9 @@ from typing import Callable, Union from functools import partial -import cube.graph.function as function -from cube.ir.operator import IRFwOperation -from cube.graph.parser.register import CustomizedOps +import nnscaler.graph.function as function +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.parser.register import CustomizedOps class SignFx2Op: @@ -43,7 +43,7 @@ def exist(signature: str) -> bool: __tttemplate = lambda name: f'torch.Tensor.{name}' # runtime template - __rtemplate = lambda name: f'cube.runtime.function.function.{name}' + __rtemplate = lambda name: f'nnscaler.runtime.function.function.{name}' # einops __einopsize = lambda name: f'einops._torch_specific.{name}' diff --git a/cube/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py similarity index 97% rename from cube/graph/parser/fx/parser.py rename to nnscaler/graph/parser/fx/parser.py index cc7422ea..45948d81 100644 --- a/cube/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -3,14 +3,14 @@ from pathlib import Path from typing import Any, List, Tuple, Callable, Union, Dict, Type, Optional -from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRFullTensor -from cube.ir.cten import IRObject, IRCell, IRTensor -from cube.graph.parser.frame import Frame -from cube.graph.parser.fx.mapping import SignFx2Op -from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.function.dimops import IRDimops -from cube.graph.function.function import any_ir_object_satisfy +from nnscaler.ir.operator import IRFwOperation +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.ir.cten import IRObject, IRCell, IRTensor +from nnscaler.graph.parser.frame import Frame +from nnscaler.graph.parser.fx.mapping import SignFx2Op +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.function.function import any_ir_object_satisfy import torch.fx from .concrete_trace_utils import TensorMetadata @@ -213,7 +213,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule _logger.warning(f'Set python runtime function: {fsig}') if any_ir_object_satisfy(input_vals, lambda a: not a.is_constant): err_msg = f'non register python runtime function {fsig} has a non constant input: {input_vals}, ' + \ - 'please register it as a customized function using cube.graph.parser.register' + 'please register it as a customized function using nnscaler.graph.parser.register' raise RuntimeError(err_msg) ir_node = IRPyFunc(fsig, input_vals, [IRObject()], **kwargs) diff --git a/cube/graph/parser/register.py b/nnscaler/graph/parser/register.py similarity index 93% rename from cube/graph/parser/register.py rename to nnscaler/graph/parser/register.py index 5742316b..b3ed80f0 100644 --- a/cube/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -7,9 +7,9 @@ import inspect import logging -from cube.graph.function.dimops import IRDimops, OpAnno -from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply -from cube.ir.operator import IRTensor +from nnscaler.graph.function.dimops import IRDimops, OpAnno +from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply +from nnscaler.ir.operator import IRTensor _logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Call Returns: None """ - builtins = ['_operator.', 'torch.', 'cube.runtime.function.'] + builtins = ['_operator.', 'torch.', 'nnscaler.runtime.function.'] if any(signature.startswith(builtin) for builtin in builtins): raise RuntimeError(f"Cannot register operators with signature starting from any of {builtins}") assert signature not in CustomizedOps.kOpMap, f"function {signature} is already registered" @@ -82,19 +82,19 @@ def register(node_repr: Union[str, Callable], name: Optional[str] = None, Examples: ```python - import cube + import nnscaler from third_party import func - cube.graph.parser.register('a (b c) -> (a b) c')(func) + nnscaler.graph.parser.register('a (b c) -> (a b) c')(func) ``` or, ```python - import cube + import nnscaler from third_party import func - @cube.graph.parser.register('a (b c) -> (a b) c') + @nnscaler.graph.parser.register('a (b c) -> (a b) c') def func(x, b = 4): xxx ``` @@ -102,13 +102,13 @@ def func(x, b = 4): or, ```python - import cube + import nnscaler from third_party import func def anno_fn(*inputs, **kwargs): return 'a (b c) -> (a b) c' - cube.graph.parser.register(anno_fn)(func) + nnscaler.graph.parser.register(anno_fn)(func) ``` Args: diff --git a/nnscaler/graph/schedule/__init__.py b/nnscaler/graph/schedule/__init__.py new file mode 100644 index 00000000..3d5d29e0 --- /dev/null +++ b/nnscaler/graph/schedule/__init__.py @@ -0,0 +1 @@ +from nnscaler.graph.schedule.schedplan import SchedulePlan diff --git a/cube/graph/schedule/predefined.py b/nnscaler/graph/schedule/predefined.py similarity index 98% rename from cube/graph/schedule/predefined.py rename to nnscaler/graph/schedule/predefined.py index b3436d6b..1a346194 100644 --- a/cube/graph/schedule/predefined.py +++ b/nnscaler/graph/schedule/predefined.py @@ -4,9 +4,9 @@ from typing import List -from cube.graph.schedule.schedplan import SchedulePlan -from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment +from nnscaler.graph.schedule.schedplan import SchedulePlan +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.segment import IRSegment class PredefinedSched: diff --git a/cube/graph/schedule/schedplan.py b/nnscaler/graph/schedule/schedplan.py similarity index 98% rename from cube/graph/schedule/schedplan.py rename to nnscaler/graph/schedule/schedplan.py index d973ea8f..187687bf 100644 --- a/cube/graph/schedule/schedplan.py +++ b/nnscaler/graph/schedule/schedplan.py @@ -1,13 +1,13 @@ from typing import Dict, List, Optional, Tuple, Set -from cube.ir.cten import IRCell -from cube.ir.adapter import IRAdapter -from cube.ir.adapter import IRWeightReducer -from cube.ir.operator import IRDataOperation - -from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.flags import CompileFlag +from nnscaler.ir.cten import IRCell +from nnscaler.ir.adapter import IRAdapter +from nnscaler.ir.adapter import IRWeightReducer +from nnscaler.ir.operator import IRDataOperation + +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.flags import CompileFlag class Block: diff --git a/cube/graph/segment.py b/nnscaler/graph/segment.py similarity index 99% rename from cube/graph/segment.py rename to nnscaler/graph/segment.py index 6e835ee0..b9c33e37 100644 --- a/cube/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -2,13 +2,13 @@ from typing import Dict, Union, List, Optional, Set, Tuple, Any, Callable import numpy as np -from cube.ir.tensor import IRFullTensor, IRSubTensor, ValueMap -from cube.ir.cten import IRTensor, IRCell, IRObject -from cube.ir.operator import IRFwOperation, IRBpOperation -from cube.ir.adapter import IRAdapter +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor, ValueMap +from nnscaler.ir.cten import IRTensor, IRCell, IRObject +from nnscaler.ir.operator import IRFwOperation, IRBpOperation +from nnscaler.ir.adapter import IRAdapter -from cube.graph.function.function import MultiRef -from cube.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.function.function import MultiRef +from nnscaler.graph.function.pyfunc import IRPyFunc class CellPosition: diff --git a/nnscaler/ir/__init__.py b/nnscaler/ir/__init__.py new file mode 100644 index 00000000..0152b006 --- /dev/null +++ b/nnscaler/ir/__init__.py @@ -0,0 +1,5 @@ +from nnscaler.ir.dtype import * +from nnscaler.ir.cten import IRTensor, IRCell +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor +from nnscaler.ir.operator import IRFwOperation, IRBpOperation, IRDataOperation +from nnscaler.ir.adapter.adapter import IRAdapter diff --git a/nnscaler/ir/adapter/__init__.py b/nnscaler/ir/adapter/__init__.py new file mode 100644 index 00000000..a16ff34b --- /dev/null +++ b/nnscaler/ir/adapter/__init__.py @@ -0,0 +1 @@ +from nnscaler.ir.adapter.adapter import IRAdapter, IRWeightReducer diff --git a/cube/ir/adapter/adapter.py b/nnscaler/ir/adapter/adapter.py similarity index 97% rename from cube/ir/adapter/adapter.py rename to nnscaler/ir/adapter/adapter.py index 7cf9f36f..7b340cd9 100644 --- a/cube/ir/adapter/adapter.py +++ b/nnscaler/ir/adapter/adapter.py @@ -1,9 +1,9 @@ from typing import List, Optional, Dict import copy -from cube.ir.adapter.prim import IRAdapterPrim, IdentityPrim -from cube.ir.tensor import IRSubTensor -from cube.ir.cten import IRCell +from nnscaler.ir.adapter.prim import IRAdapterPrim, IdentityPrim +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir.cten import IRCell class IRAdapter(IRCell): diff --git a/cube/ir/adapter/prim.py b/nnscaler/ir/adapter/prim.py similarity index 92% rename from cube/ir/adapter/prim.py rename to nnscaler/ir/adapter/prim.py index 9eae0c9e..e89900d7 100644 --- a/cube/ir/adapter/prim.py +++ b/nnscaler/ir/adapter/prim.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union, Tuple import copy -from cube.ir.tensor import IRSubTensor, IndexMap, ValueMap +from nnscaler.ir.tensor import IRSubTensor, IndexMap, ValueMap # the general adapter primitive class @@ -122,7 +122,7 @@ class IdentityPrim(SpatialPrim): def __init__(self, itensor: IRSubTensor): super().__init__([itensor], [itensor]) - self.signature = 'cube.runtime.adapter.identity' + self.signature = 'nnscaler.runtime.adapter.identity' def __repr__(self): dscp = f"{self.output(0)} = identity({self.input(0)})" @@ -139,7 +139,7 @@ def __init__(self, indmap = tuple(slice(s, e) for s, e in indmap) valmap = ValueMap(valmap).weight[1] super().__init__([itensor], [otensor], indmap=indmap, valmap=valmap) - self.signature = f"cube.runtime.adapter.select" + self.signature = f"nnscaler.runtime.adapter.select" def __repr__(self): dscp = f"{self.output(0)} = select({self.input(0)}, indmap={self.kwargs['indmap']}, valmap={self.kwargs['valmap']})" @@ -153,7 +153,7 @@ class MergeDimPrim(SpatialPrim): def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, dim: int) -> None: assert all(itensor.device == itensors[0].device for itensor in itensors), "device not same" super().__init__(itensors, [otensor], dim=dim) - self.signature = 'cube.runtime.adapter.smerge' + self.signature = 'nnscaler.runtime.adapter.smerge' def __repr__(self) -> str: return f"dev{self.device}: {self.output(0)} = concat({self.inputs()}, dim={self.kwargs['dim']})" @@ -165,7 +165,7 @@ class SumPrim(ValuePrim): def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor): assert all(itensor.device == itensors[0].device for itensor in itensors), "device not same" super().__init__(itensors, [otensor]) - self.signature = 'cube.runtime.adapter.vmerge' + self.signature = 'nnscaler.runtime.adapter.vmerge' def __repr__(self) -> str: return f"dev{self.device}: {self.output(0)} = add({self.inputs()})" @@ -186,7 +186,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None shape, dtype, src, dst = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dst'] super().__init__(itensors, otensors, shape=shape, dtype=dtype, src=src, dst=dst) - self.signature = 'cube.runtime.adapter.move' + self.signature = 'nnscaler.runtime.adapter.move' def volume(self) -> int: if len(self._inputs) > 0: @@ -228,7 +228,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim kwargs['dsts'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) shape, dtype, src, dsts = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dsts'] super().__init__(itensors, otensors, shape=shape, dtype=dtype, dim=dim, src=src, dsts=dsts) - self.signature = 'cube.runtime.adapter.rdscatter' + self.signature = 'nnscaler.runtime.adapter.rdscatter' def volume(self) -> int: return self.input(0).nelement() @@ -259,7 +259,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k kwargs['dsts'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) shape, dtype, src, dsts = kwargs['shape'], kwargs['dtype'], kwargs['src'], kwargs['dsts'] super().__init__(itensors, otensors, shape=shape, dtype=dtype, src=src, dst=dsts) - self.signature = 'cube.runtime.adapter.rvscatter' + self.signature = 'nnscaler.runtime.adapter.rvscatter' def volume(self) -> int: return self.input(0).nelement() * len(self.outputs()) @@ -284,7 +284,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None shape, dtype, srcs, dst = kwargs['shape'], kwargs['dtype'], kwargs['srcs'], kwargs['dst'] super().__init__(itensors, otensors, shape=shape, dtype=dtype, srcs=srcs, dst=dst, dim=dim) - self.signature = 'cube.runtime.adapter.rdgather' + self.signature = 'nnscaler.runtime.adapter.rdgather' def volume(self) -> int: return self.output(0).nelement() @@ -309,7 +309,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None shape, dtype, srcs, dst = kwargs['shape'], kwargs['dtype'], kwargs['srcs'], kwargs['dst'] super().__init__(itensors, otensors, shape=shape, dtype=dtype, srcs=srcs, dst=dst) - self.signature = 'cube.runtime.adapter.rvgather' + self.signature = 'nnscaler.runtime.adapter.rvgather' def volume(self) -> int: return self.output(0).nelement() * len(self.inputs()) @@ -331,7 +331,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None super().__init__(itensors, otensors, **kwargs) - self.signature = 'cube.runtime.adapter.broadcast' + self.signature = 'nnscaler.runtime.adapter.broadcast' def volume(self) -> int: ndevs = len(self.outputs()) @@ -348,7 +348,7 @@ class AllReducePrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): super().__init__(itensors, otensors, **kwargs) - self.signature = 'cube.runtime.adapter.all_reduce' + self.signature = 'nnscaler.runtime.adapter.all_reduce' def volume(self) -> int: """ @@ -367,7 +367,7 @@ class AllGatherPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): super().__init__(itensors, otensors, dim=dim, **kwargs) - self.signature = 'cube.runtime.adapter.all_gather' + self.signature = 'nnscaler.runtime.adapter.all_gather' def volume(self) -> int: """ @@ -386,7 +386,7 @@ class ReduceScatterPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): super().__init__(itensors, otensors, dim=dim, **kwargs) - self.signature = 'cube.runtime.adapter.reduce_scatter' + self.signature = 'nnscaler.runtime.adapter.reduce_scatter' def volume(self) -> int: """ @@ -407,7 +407,7 @@ class ReducePrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensor: IRSubTensor, **kwargs): super().__init__(itensors, [otensor], dst=otensor.device[0], **kwargs) - self.signature = 'cube.runtime.adapter.reduce' + self.signature = 'nnscaler.runtime.adapter.reduce' def volume(self) -> int: ndevs = len(self.inputs()) @@ -428,7 +428,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idi idim != odim """ super().__init__(itensors, otensors, idim=idim, odim=odim, **kwargs) - self.signature = 'cube.runtime.adapter.all_to_all' + self.signature = 'nnscaler.runtime.adapter.all_to_all' def volume(self) -> int: ndevs = len(self.inputs()) @@ -444,7 +444,7 @@ class ChunkPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): super().__init__(itensors, otensors, dim=dim, **kwargs) - self.signature = 'cube.runtime.adapter.chunk' + self.signature = 'nnscaler.runtime.adapter.chunk' def volume(self) -> int: return 0 @@ -459,7 +459,7 @@ class VChunkPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): super().__init__(itensors, otensors, **kwargs) - self.signature = 'cube.runtime.adapter.vchunk' + self.signature = 'nnscaler.runtime.adapter.vchunk' def volume(self) -> int: return 0 @@ -475,7 +475,7 @@ class AllReduceIdentityPrim(AllReducePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): super().__init__(itensors, otensors, **kwargs) - self.signature = 'cube.runtime.adapter.nn.allreduce_identity' + self.signature = 'nnscaler.runtime.adapter.nn.allreduce_identity' def __repr__(self) -> str: return f"{self.outputs()} = allreduce_identity[{self.device}]({self.inputs()})" @@ -488,7 +488,7 @@ class IdentityAllreducePrim(AllReducePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): super().__init__(itensors, otensors, **kwargs) - self.signature = 'cube.runtime.adapter.nn.identity_allreduce' + self.signature = 'nnscaler.runtime.adapter.nn.identity_allreduce' def __repr__(self) -> str: return f"{self.outputs()} = identity_allreduce[{self.device}]({self.inputs()})" @@ -501,7 +501,7 @@ class AllReduceAllReducePrim(AllReducePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): super().__init__(itensors, otensors, **kwargs) - self.signature = 'cube.runtime.adapter.nn.allreduce_allreduce' + self.signature = 'nnscaler.runtime.adapter.nn.allreduce_allreduce' def __repr__(self) -> str: return f"{self.outputs} = nn.allreduce_allreduce[{self.device}]({self.inputs()}" @@ -514,7 +514,7 @@ class ReduceScatterAllGatherPrim(ReduceScatterPrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): super().__init__(itensors, otensors, dim, **kwargs) - self.signature = 'cube.runtime.adapter.nn.reducescatter_allgather' + self.signature = 'nnscaler.runtime.adapter.nn.reducescatter_allgather' class AllGatherReduceScatterPrim(AllGatherPrim): @@ -524,7 +524,7 @@ class AllGatherReduceScatterPrim(AllGatherPrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): super().__init__(itensors, otensors, dim, **kwargs) - self.signature = 'cube.runtime.adapter.nn.allgather_reducescatter' + self.signature = 'nnscaler.runtime.adapter.nn.allgather_reducescatter' class AllGatherSplitPrim(AllGatherPrim): @@ -534,7 +534,7 @@ class AllGatherSplitPrim(AllGatherPrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): super().__init__(itensors, otensors, dim, **kwargs) - self.signature = 'cube.runtime.adapter.nn.allgather_split' + self.signature = 'nnscaler.runtime.adapter.nn.allgather_split' class SplitAllGatherPrim(AllGatherPrim): @@ -544,7 +544,7 @@ class SplitAllGatherPrim(AllGatherPrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): super().__init__(itensors, otensors, dim, **kwargs) - self.signature = 'cube.runtime.adapter.nn.split_allgather' + self.signature = 'nnscaler.runtime.adapter.nn.split_allgather' class AllToAllAllToAllPrim(AllToAllPrim): @@ -554,7 +554,7 @@ class AllToAllAllToAllPrim(AllToAllPrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idim: int, odim: int, **kwargs): super().__init__(itensors, otensors, idim, odim, **kwargs) - self.signature = 'cube.runtime.adapter.nn.alltoall_alltoall' + self.signature = 'nnscaler.runtime.adapter.nn.alltoall_alltoall' class ReduceBroadcastPrim(CollectivePrim): diff --git a/cube/ir/cten.py b/nnscaler/ir/cten.py similarity index 99% rename from cube/ir/cten.py rename to nnscaler/ir/cten.py index e1e8d2e9..1089f419 100644 --- a/cube/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -21,8 +21,8 @@ import copy import torch -from cube.ir.unique import IDGenerator -from cube.ir.dtype import DTypeInfo +from nnscaler.ir.unique import IDGenerator +from nnscaler.ir.dtype import DTypeInfo NestedVarOrStatic = Any @@ -539,7 +539,7 @@ def backward(self) -> None: @return None """ - from cube.program import Program + from nnscaler.program import Program graph = Program().get_graph() return graph.backward(self) diff --git a/cube/ir/dtype.py b/nnscaler/ir/dtype.py similarity index 100% rename from cube/ir/dtype.py rename to nnscaler/ir/dtype.py diff --git a/cube/ir/operator.py b/nnscaler/ir/operator.py similarity index 96% rename from cube/ir/operator.py rename to nnscaler/ir/operator.py index e783925d..6433772c 100644 --- a/cube/ir/operator.py +++ b/nnscaler/ir/operator.py @@ -1,11 +1,11 @@ from typing import Optional, Tuple, Any, Union, List import copy -from cube.ir.cten import IRCell, IRTensor, IRObject -from cube.ir.tensor import IRFullTensor -from cube.algorithm.factory import DistAlgorithmFactory -from cube.algorithm.generics import GenericDistAlgo -from cube.ir.dtype import DTypeInfo +from nnscaler.ir.cten import IRCell, IRTensor, IRObject +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.algorithm.factory import DistAlgorithmFactory +from nnscaler.algorithm.generics import GenericDistAlgo +from nnscaler.ir.dtype import DTypeInfo class IRFwOperation(IRCell): diff --git a/cube/ir/tensor.py b/nnscaler/ir/tensor.py similarity index 99% rename from cube/ir/tensor.py rename to nnscaler/ir/tensor.py index dc60335f..fad9ae47 100644 --- a/cube/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -26,7 +26,7 @@ from typing import List, Optional, Union, Tuple, NewType, Dict, Any -from cube.ir.cten import IRTensor +from nnscaler.ir.cten import IRTensor StartEnd = NewType('[start:end)', Tuple[int, int]) IdxChunk = NewType('(index, chunks)', Tuple[int, int]) diff --git a/cube/ir/unique.py b/nnscaler/ir/unique.py similarity index 100% rename from cube/ir/unique.py rename to nnscaler/ir/unique.py diff --git a/cube/parallel.py b/nnscaler/parallel.py similarity index 98% rename from cube/parallel.py rename to nnscaler/parallel.py index 8f8ef9bc..12d6a3b7 100644 --- a/cube/parallel.py +++ b/nnscaler/parallel.py @@ -13,35 +13,35 @@ import os import torch -from cube.codegen.schedule.schedule import ScheduleCodeGen -from cube.graph.parser.fx.parser import FxModuleParser - -from cube.graph.schedule.predefined import PredefinedSched -from cube.ir.cten import IRObject, IRTensor -from cube.ir.tensor import IRFullTensor - -from cube.flags import CompileFlag, RuntimeFlag -from cube.utils import get_shared_params - -from cube.graph import IRGraph -from cube.graph import parser -from cube.ir.operator import IRBpOperation, IRDataOperation -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.pyfunc import IRPyFunc -from cube.graph.schedule.schedplan import SchedulePlan -from cube.graph.gener.gen import IRAdapterGener - -from cube.codegen import ModuleCodeGen -from cube.execplan import ExecutionPlan -from cube.execplan.planpass.grouping import Grouping -from cube.execplan.planpass.fusion import DiffFusion -from cube.ir.unique import IDGenerator -from cube.program import Program -from cube.runtime.adapter.reducer import Reducer -from cube.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState -from cube.runtime.device import DeviceGroup -from cube.runtime.gnorm import calcuate_gnorm, clip_grads -from cube.utils import get_member_by_name, setup_stride_broadcast_group +from nnscaler.codegen.schedule.schedule import ScheduleCodeGen +from nnscaler.graph.parser.fx.parser import FxModuleParser + +from nnscaler.graph.schedule.predefined import PredefinedSched +from nnscaler.ir.cten import IRObject, IRTensor +from nnscaler.ir.tensor import IRFullTensor + +from nnscaler.flags import CompileFlag, RuntimeFlag +from nnscaler.utils import get_shared_params + +from nnscaler.graph import IRGraph +from nnscaler.graph import parser +from nnscaler.ir.operator import IRBpOperation, IRDataOperation +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.schedule.schedplan import SchedulePlan +from nnscaler.graph.gener.gen import IRAdapterGener + +from nnscaler.codegen import ModuleCodeGen +from nnscaler.execplan import ExecutionPlan +from nnscaler.execplan.planpass.grouping import Grouping +from nnscaler.execplan.planpass.fusion import DiffFusion +from nnscaler.ir.unique import IDGenerator +from nnscaler.program import Program +from nnscaler.runtime.adapter.reducer import Reducer +from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState +from nnscaler.runtime.device import DeviceGroup +from nnscaler.runtime.gnorm import calcuate_gnorm, clip_grads +from nnscaler.utils import get_member_by_name, setup_stride_broadcast_group logger = logging.getLogger(__name__) @@ -845,7 +845,7 @@ def parallelize( After the module is converted, you can use it to create module object by calling it like a module class. The module class is defined like: ``` - class GenModule(cube.runtime.module.ParallelModule): + class GenModule(nnscaler.runtime.module.ParallelModule): def __init__(self, init_params=True): super().__init__() ... @@ -893,7 +893,7 @@ def __init__(self, init_params=True): # Call it here just to ensure the device group is initialized. # If the user initializes torch.distributed - # and doesn't call `cube.init()` before calling this function, this is necessary. + # and doesn't call `nnscaler.init()` before calling this function, this is necessary. if torch.distributed.is_initialized(): _ = DeviceGroup() @@ -1811,7 +1811,7 @@ def _get_optimizer_state_of_param(param, param_ids, local_names): state_dict[opt_param_idx] = opt_states opt_param_idx += 1 # load the params' optimizer state that are not in reducers - # this part corresponds to cube/runtime/module.py: parameters_for_optimizer + # this part corresponds to nnscaler/runtime/module.py: parameters_for_optimizer reducer_pids = set() for reducer in module.reducers: reducer_pids.update(id(p) for p in reducer.params) diff --git a/cube/profiler/README.md b/nnscaler/profiler/README.md similarity index 95% rename from cube/profiler/README.md rename to nnscaler/profiler/README.md index 6a02dc9a..86bf0577 100644 --- a/cube/profiler/README.md +++ b/nnscaler/profiler/README.md @@ -11,7 +11,7 @@ prof = cProfile.Profile() prof.enable() # our code to profile goes here -@cube.compile(...) +@nnscaler.compile(...) def iter(dataloader): x, y = next(dataloader) z = model(x, y) @@ -51,4 +51,4 @@ pip install viztracer viztracer --log_multiprocess torchrun --nproc_per_node=4 --nnodes=1 examples/mlp/linears.py ``` -For more configurations please check `viztracer -h`. \ No newline at end of file +For more configurations please check `viztracer -h`. diff --git a/nnscaler/profiler/__init__.py b/nnscaler/profiler/__init__.py new file mode 100644 index 00000000..15a9d386 --- /dev/null +++ b/nnscaler/profiler/__init__.py @@ -0,0 +1,2 @@ +from nnscaler.profiler.timer import CudaTimer +from nnscaler.profiler.database import ProfileDataBase diff --git a/cube/profiler/database.py b/nnscaler/profiler/database.py similarity index 98% rename from cube/profiler/database.py rename to nnscaler/profiler/database.py index 8a1a928d..4fe40d66 100644 --- a/cube/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -1,6 +1,6 @@ """ Usage: - python -m cube.profiler.database --export ./profile.dat.json + python -m nnscaler.profiler.database --export ./profile.dat.json """ from typing import Callable, Tuple, Union, Optional, Dict, NewType, List, Any import torch @@ -13,10 +13,10 @@ from dataclasses import dataclass, asdict import _operator # required by eval() -import cube # required by eval() -from cube.ir.cten import IRTensor, IRObject -from cube.ir.operator import IRFwOperation -from cube.graph.parser.register import CustomizedOps +import nnscaler # required by eval() +from nnscaler.ir.cten import IRTensor, IRObject +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.parser.register import CustomizedOps _logger = logging.getLogger(__name__) diff --git a/cube/profiler/estimator.py b/nnscaler/profiler/estimator.py similarity index 90% rename from cube/profiler/estimator.py rename to nnscaler/profiler/estimator.py index 51f63202..9a123119 100644 --- a/cube/profiler/estimator.py +++ b/nnscaler/profiler/estimator.py @@ -3,10 +3,10 @@ import os import logging -from cube.ir.operator import IRFwOperation -from cube.graph.segment import IRSegment -from cube.graph.function import IRGraphAnchor -from cube.profiler.database import ProfileDataBase +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.function import IRGraphAnchor +from nnscaler.profiler.database import ProfileDataBase _logger = logging.getLogger(__name__) diff --git a/cube/profiler/memory.py b/nnscaler/profiler/memory.py similarity index 98% rename from cube/profiler/memory.py rename to nnscaler/profiler/memory.py index 5d536664..b7e82cd6 100644 --- a/cube/profiler/memory.py +++ b/nnscaler/profiler/memory.py @@ -1,6 +1,6 @@ from typing import Any, List import logging -from cube.utils import print_each_rank +from nnscaler.utils import print_each_rank import torch _logger = logging.getLogger(__name__) diff --git a/cube/profiler/timer.py b/nnscaler/profiler/timer.py similarity index 99% rename from cube/profiler/timer.py rename to nnscaler/profiler/timer.py index 4fffa0e1..e73c504f 100644 --- a/cube/profiler/timer.py +++ b/nnscaler/profiler/timer.py @@ -3,7 +3,7 @@ import logging import torch -from cube.utils import print_each_rank +from nnscaler.utils import print_each_rank _logger = logging.getLogger(__name__) diff --git a/cube/program.py b/nnscaler/program.py similarity index 95% rename from cube/program.py rename to nnscaler/program.py index b8646bcd..03b6abfd 100644 --- a/cube/program.py +++ b/nnscaler/program.py @@ -1,18 +1,18 @@ from typing import List, Tuple, Optional, Any, Dict, Union import inspect -from cube.ir.cten import IRCell, IRObject -from cube.ir.tensor import IRFullTensor, IRSubTensor -from cube.ir.operator import IRBpOperation, IRDataOperation +from nnscaler.ir.cten import IRCell, IRObject +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor +from nnscaler.ir.operator import IRBpOperation, IRDataOperation -from cube.graph import IRGraph -from cube.graph import parser +from nnscaler.graph import IRGraph +from nnscaler.graph import parser -from cube.runtime.module import CubeModule -from cube.runtime.device import DeviceGroup -from cube.runtime.utils import MicroBatchDataLoader +from nnscaler.runtime.module import CubeModule +from nnscaler.runtime.device import DeviceGroup +from nnscaler.runtime.utils import MicroBatchDataLoader -from cube.utils import load_model +from nnscaler.utils import load_model import torch import torch.utils.data as data diff --git a/nnscaler/runtime/__init__.py b/nnscaler/runtime/__init__.py new file mode 100644 index 00000000..faac540f --- /dev/null +++ b/nnscaler/runtime/__init__.py @@ -0,0 +1,6 @@ +from nnscaler.runtime import executor +from nnscaler.runtime import device +from nnscaler.runtime import adapter +from nnscaler.runtime import resource +from nnscaler.runtime import module +from nnscaler.runtime import function diff --git a/nnscaler/runtime/adapter/__init__.py b/nnscaler/runtime/adapter/__init__.py new file mode 100644 index 00000000..574332e8 --- /dev/null +++ b/nnscaler/runtime/adapter/__init__.py @@ -0,0 +1,4 @@ +from nnscaler.runtime.adapter.collectives import * +from nnscaler.runtime.adapter.transform import * +from nnscaler.runtime.adapter import nn +from nnscaler.runtime.adapter.reducer import Reducer diff --git a/cube/runtime/adapter/collectives.py b/nnscaler/runtime/adapter/collectives.py similarity index 98% rename from cube/runtime/adapter/collectives.py rename to nnscaler/runtime/adapter/collectives.py index 46fe6618..b18645e5 100644 --- a/cube/runtime/adapter/collectives.py +++ b/nnscaler/runtime/adapter/collectives.py @@ -9,10 +9,10 @@ from typing import List, Tuple, Optional import torch -from cube.runtime.device import DeviceGroup -from cube.profiler.timer import CudaTimer +from nnscaler.runtime.device import DeviceGroup +from nnscaler.profiler.timer import CudaTimer -from cube.runtime.executor import AsyncCommHandler +from nnscaler.runtime.executor import AsyncCommHandler def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int, async_op=False): diff --git a/cube/runtime/adapter/nn.py b/nnscaler/runtime/adapter/nn.py similarity index 98% rename from cube/runtime/adapter/nn.py rename to nnscaler/runtime/adapter/nn.py index 8b6a025f..5ac2623b 100644 --- a/cube/runtime/adapter/nn.py +++ b/nnscaler/runtime/adapter/nn.py @@ -7,8 +7,8 @@ from typing import List, Tuple import torch -from cube.profiler.timer import CudaTimer -from cube.runtime.device import DeviceGroup +from nnscaler.profiler.timer import CudaTimer +from nnscaler.runtime.device import DeviceGroup from .collectives import ( all_reduce, all_gather, diff --git a/cube/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py similarity index 99% rename from cube/runtime/adapter/reducer.py rename to nnscaler/runtime/adapter/reducer.py index df61ed08..96563d8d 100644 --- a/cube/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -4,9 +4,9 @@ import torch from torch.utils.hooks import RemovableHandle -from cube.runtime.device import DeviceGroup -from cube.profiler.timer import CudaTimer -from cube.flags import RuntimeFlag +from nnscaler.runtime.device import DeviceGroup +from nnscaler.profiler.timer import CudaTimer +from nnscaler.flags import RuntimeFlag _logger = logging.getLogger(__name__) @@ -182,7 +182,7 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): if self._async_param_cnt > len(self._params): raise RuntimeError( "Detected gradient accumulation with asynchronous Reducer. " - "Users should run with `cube.accum_mode` to manage gradient synchronization.") + "Users should run with `nnscaler.accum_mode` to manage gradient synchronization.") if self._async_param_cnt == len(self._params): # apply pre hooks self._apply_pre_hooks() diff --git a/cube/runtime/adapter/transform.py b/nnscaler/runtime/adapter/transform.py similarity index 100% rename from cube/runtime/adapter/transform.py rename to nnscaler/runtime/adapter/transform.py diff --git a/cube/runtime/device.py b/nnscaler/runtime/device.py similarity index 99% rename from cube/runtime/device.py rename to nnscaler/runtime/device.py index 081ae3e1..63f00003 100644 --- a/cube/runtime/device.py +++ b/nnscaler/runtime/device.py @@ -8,7 +8,7 @@ import logging import datetime -from cube.flags import CompileFlag +from nnscaler.flags import CompileFlag _logger = logging.getLogger(__name__) _LARGE_TIMEOUT = datetime.timedelta(seconds=21600) diff --git a/cube/runtime/executor.py b/nnscaler/runtime/executor.py similarity index 100% rename from cube/runtime/executor.py rename to nnscaler/runtime/executor.py diff --git a/nnscaler/runtime/function/__init__.py b/nnscaler/runtime/function/__init__.py new file mode 100644 index 00000000..ae856192 --- /dev/null +++ b/nnscaler/runtime/function/__init__.py @@ -0,0 +1 @@ +from nnscaler.runtime.function.function import * \ No newline at end of file diff --git a/cube/runtime/function/function.py b/nnscaler/runtime/function/function.py similarity index 100% rename from cube/runtime/function/function.py rename to nnscaler/runtime/function/function.py diff --git a/cube/runtime/gnorm.py b/nnscaler/runtime/gnorm.py similarity index 99% rename from cube/runtime/gnorm.py rename to nnscaler/runtime/gnorm.py index 72033485..bb94b34b 100644 --- a/cube/runtime/gnorm.py +++ b/nnscaler/runtime/gnorm.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: - from cube.runtime.module import CubeModule + from nnscaler.runtime.module import CubeModule @dataclass @@ -101,7 +101,7 @@ def _check_no_intersection(ranks_set): def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int, TidReplicaInfo]: """This function is used to calculate the number of replicas of each model parameter. Each parameter has a tuple of `len(ranksset)` (we call it nreplicated) and `nranks`, - because a parameter may be replicated (not data parallelism) which is supported by cube. + because a parameter may be replicated (not data parallelism) which is supported by nnscaler. It affects the calculation of gnorm. So nreplicated represents the number of non-data-parallelism replicas for this parameter, and nranks represents the number of all the involved ranks for this parameter. diff --git a/cube/runtime/module.py b/nnscaler/runtime/module.py similarity index 98% rename from cube/runtime/module.py rename to nnscaler/runtime/module.py index 7503e02f..8375c97b 100644 --- a/cube/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -9,17 +9,17 @@ import torch import torch.distributed as dist -from cube.graph.parser.fx.parser import FxModuleParser -from cube.runtime.device import DeviceGroup -from cube.runtime.adapter.reducer import Reducer -from cube.runtime.executor import Executor -from cube.runtime.gnorm import ParamsInfo -from cube.flags import CompileFlag -from cube.runtime.utils import microbatches -from cube.utils import accum_mode +from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.runtime.device import DeviceGroup +from nnscaler.runtime.adapter.reducer import Reducer +from nnscaler.runtime.executor import Executor +from nnscaler.runtime.gnorm import ParamsInfo +from nnscaler.flags import CompileFlag +from nnscaler.runtime.utils import microbatches +from nnscaler.utils import accum_mode if TYPE_CHECKING: - from cube.parallel import ComputeConfig + from nnscaler.parallel import ComputeConfig _logger = logging.getLogger(__name__) @@ -86,7 +86,7 @@ def zero_grad(self): This function will be automatically inserted inside the generated code at the beginning of each iteration. - If the function is under the context of `with cube.accum_mode()`, the zero of gradients + If the function is under the context of `with nnscaler.accum_mode()`, the zero of gradients will be skipped. """ for reducer in self._reducers: @@ -580,7 +580,7 @@ class ParallelModuleConfig: def __post_init__(self): if isinstance(self.compute_config, dict): - from cube.parallel import ComputeConfig + from nnscaler.parallel import ComputeConfig self.compute_config = ComputeConfig(**self.compute_config) self.param_area_map = { k: AttrMeta(**v) if isinstance(v, dict) else v @@ -827,7 +827,7 @@ def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, Li Returns: Tuple of The gradient norm and the list of gradients. """ - from cube.runtime.gnorm import prepare_for_grad_clip, clip_gnorm + from nnscaler.runtime.gnorm import prepare_for_grad_clip, clip_gnorm if self._nreplicas2localparams is None: self._nreplicas2localparams = prepare_for_grad_clip(self, self.compute_config.use_zero) @@ -911,7 +911,7 @@ def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: Get the index in the zero subgroup the reducer belongs to, and the ranks of the subgroup. Args: - reducer (cube.runtime.adapter.Reducer): a reducer of cube model + reducer (nnscaler.runtime.adapter.Reducer): a reducer of cube model Returns: rank_idx (int): the index of current rank in sub_ranks diff --git a/cube/runtime/resource.py b/nnscaler/runtime/resource.py similarity index 97% rename from cube/runtime/resource.py rename to nnscaler/runtime/resource.py index d8f3dafa..9223fcef 100644 --- a/cube/runtime/resource.py +++ b/nnscaler/runtime/resource.py @@ -4,7 +4,7 @@ from typing import Tuple import torch -from cube.flags import CompileFlag +from nnscaler.flags import CompileFlag from dataclasses import dataclass diff --git a/cube/runtime/utils.py b/nnscaler/runtime/utils.py similarity index 98% rename from cube/runtime/utils.py rename to nnscaler/runtime/utils.py index ca595ac3..57e87ca5 100644 --- a/cube/runtime/utils.py +++ b/nnscaler/runtime/utils.py @@ -23,7 +23,7 @@ class MicroBatchDataLoader: # compilation phase dataloader = MicroBatchDataLoader([(input1,),]) # only need one micro-batch - @cube.compile(model, dataloader, ...) + @nnscaler.compile(model, dataloader, ...) def train_iter(model, dataloader): input1 = next(dataloader) loss = model(input1) diff --git a/cube/utils.py b/nnscaler/utils.py similarity index 94% rename from cube/utils.py rename to nnscaler/utils.py index 3f2a94b6..5de58c97 100644 --- a/cube/utils.py +++ b/nnscaler/utils.py @@ -6,9 +6,9 @@ from collections import defaultdict from dataclasses import dataclass -import cube -from cube.runtime.device import DeviceGroup -from cube.flags import RuntimeFlag, CompileFlag +import nnscaler +from nnscaler.runtime.device import DeviceGroup +from nnscaler.flags import RuntimeFlag, CompileFlag import torch @@ -57,7 +57,7 @@ def _load_module_attr(filename: str, name: str): def load_model(filename: Optional[str] = None, load_content: bool = True, fullmodel_filename: Optional[str] = None): filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename module = _load_module_attr(filename, Path(filename).stem) - loaded_module: cube.runtime.module.CubeModule = module.GenModel().cuda() + loaded_module: nnscaler.runtime.module.CubeModule = module.GenModel().cuda() # load parameter content if load_content: _logger.info("loading parameter content...") @@ -157,7 +157,7 @@ class accum_mode: for _ in range(num_iters): for step in range(accum_steps): datas = next(dataloader) - with cube.accum_mode(begin=(step == 0), end=(step == accum_steps - 1)): + with nnscaler.accum_mode(begin=(step == 0), end=(step == accum_steps - 1)): train_iter(model, *datas) optimizer.step() optimizer.zero_grad() @@ -167,7 +167,7 @@ class accum_mode: ``` for _ in range(num_iters): - for step in cube.accum_mode.steps(accum_steps): + for step in nnscaler.accum_mode.steps(accum_steps): datas = next(dataloader) train_iter(model, *datas) optimizer.step() @@ -198,7 +198,7 @@ def __enter__(self): for _ in range(num_iters): for step in range(accum_steps): datas = next(dataloader) - with cube.accum_mode(begin=(step == 0), end=(step == accum_steps - 1)): + with nnscaler.accum_mode(begin=(step == 0), end=(step == accum_steps - 1)): train_iter(model, *datas) optimizer.step() optimizer.zero_grad() @@ -224,7 +224,7 @@ def steps(nsteps: int): ``` for _ in range(num_iters): - for step in cube.accum_mode.steps(accum_steps): + for step in nnscaler.accum_mode.steps(accum_steps): datas = next(dataloader) train_iter(model, *datas) optimizer.step() diff --git a/setup.py b/setup.py index 20c504db..269504e1 100644 --- a/setup.py +++ b/setup.py @@ -6,12 +6,12 @@ ] setuptools.setup( - name= 'cube', + name= 'nnscaler', version= '0.2', - author= 'Cube Team', + author= 'nnScaler Team', description= 'Parallelize DNN Traning from A Systematic Way', long_description= 'Parallelize DNN Traning from A Systematic Way', - packages= ['cube'], + packages= ['nnscaler'], python_requires= '>=3.8', install_requires= install_requires, ) diff --git a/tests/algorithm/ops/test_dimops.py b/tests/algorithm/ops/test_dimops.py index 6c3c8590..b3e4ab9c 100644 --- a/tests/algorithm/ops/test_dimops.py +++ b/tests/algorithm/ops/test_dimops.py @@ -1,9 +1,9 @@ import tempfile import torch import os -from cube.parallel import _gen_graph -from cube.ir.operator import IRFwOperation -from cube.algorithm.ops.dimops import gen_partitions +from nnscaler.parallel import _gen_graph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.algorithm.ops.dimops import gen_partitions from ...utils import replace_all_device_with diff --git a/tests/codegen/test_emit.py b/tests/codegen/test_emit.py index 496f61d4..fef41acd 100644 --- a/tests/codegen/test_emit.py +++ b/tests/codegen/test_emit.py @@ -1,9 +1,9 @@ import pytest -from cube.codegen.emit import CodeEmission -from cube.ir.cten import IRObject -from cube.codegen.emit import FuncEmission -from cube.graph.function import Dropout -from cube.ir.tensor import IRFullTensor +from nnscaler.codegen.emit import CodeEmission +from nnscaler.ir.cten import IRObject +from nnscaler.codegen.emit import FuncEmission +from nnscaler.graph.function import Dropout +from nnscaler.ir.tensor import IRFullTensor def test_tensor_name(): diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index e8615d29..ee8007ec 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -5,12 +5,12 @@ from functools import partial import more_itertools as mitr -import cube -from cube.runtime.utils import microbatches -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRFwOperation, IRDataOperation -from cube.flags import CompileFlag +import nnscaler +from nnscaler.runtime.utils import microbatches +from nnscaler.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRFwOperation, IRDataOperation +from nnscaler.flags import CompileFlag from ..launch_torchrun import torchrun from ..utils import init_parameter, assert_parity @@ -112,7 +112,7 @@ def tensor_parallelism(node, idx, dim, num): def cube_run(ngpus_per_unit: int, policy): - cube.init() + nnscaler.init() CompileFlag.disable_code_line_info = True # speedup parse model = MLP() @@ -128,14 +128,14 @@ def cube_run(ngpus_per_unit: int, policy): policy = partial(policy, ngpus_per_unit=ngpus_per_unit) - @cube.compile(model, dl, PAS=policy, scale=True) + @nnscaler.compile(model, dl, PAS=policy, scale=True) def train_iter(model, dataloader): x = next(iter(dataloader)) loss = model(x) loss.backward() return loss - model = cube.load_model() + model = nnscaler.load_model() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) losses = [] diff --git a/tests/conftest.py b/tests/conftest.py index ab682345..1877d5d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest from pathlib import Path -from cube.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.parser.fx.parser import FxModuleParser @pytest.fixture(autouse=True) def clean_generated_files(): diff --git a/tests/graph/function/test_dataloader.py b/tests/graph/function/test_dataloader.py index d323cb5b..74c5f48f 100644 --- a/tests/graph/function/test_dataloader.py +++ b/tests/graph/function/test_dataloader.py @@ -4,9 +4,9 @@ import torch -from cube.ir.cten import IRObject -from cube.ir.tensor import IRFullTensor -from cube.ir.operator import IRDataOperation +from nnscaler.ir.cten import IRObject +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.ir.operator import IRDataOperation def test_data_operation(): diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index 68f929a8..aec2f28c 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -5,10 +5,10 @@ from typing import Callable, Tuple, List from functools import partial -import cube.graph.function as F -from cube.graph.function.dimops import IRDimops -from cube.ir.tensor import IRFullTensor -from cube.ir.cten import IRObject +import nnscaler.graph.function as F +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.ir.cten import IRObject def create_op(creator: Callable, diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index db17afb6..b3588f16 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -1,14 +1,14 @@ ### Only test the anno creation in these tests -import cube.graph.function.function as F -from cube.ir.cten import IRObject, IRTensor +import nnscaler.graph.function.function as F +from nnscaler.ir.cten import IRObject, IRTensor import pytest import torch import numpy as np import math -from cube.ir.tensor import IRFullTensor +from nnscaler.ir.tensor import IRFullTensor def o(value): @@ -369,12 +369,12 @@ def test_ScaledDotProductAttention(): def test_NewTensor(): op = F.NewTensor(torch.tensor(1)) - assert op.signature == 'cube.runtime.function.tensor' + assert op.signature == 'nnscaler.runtime.function.tensor' assert repr(op.anno) == ' -> 1^' assert op.kwargs['data'] == 1 op = F.NewTensor(torch.tensor([1,2])) - assert op.signature == 'cube.runtime.function.tensor' + assert op.signature == 'nnscaler.runtime.function.tensor' assert repr(op.anno) == ' -> 2^' assert op.kwargs['data'] == [1,2] diff --git a/tests/graph/gener/check_inter_rvd.py b/tests/graph/gener/check_inter_rvd.py index 6578b850..04be22dc 100644 --- a/tests/graph/gener/check_inter_rvd.py +++ b/tests/graph/gener/check_inter_rvd.py @@ -7,16 +7,16 @@ """ from typing import List, Tuple -import cube -from cube.ir.tensor import IRFullTensor -from cube.graph.gener.rvd.layout import RVDLayout, RVDInspector -from cube.graph.gener.rvd.inter import InterPathFinder +import nnscaler +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph.gener.rvd.layout import RVDLayout, RVDInspector +from nnscaler.graph.gener.rvd.inter import InterPathFinder import numpy as np -from cube.graph.gener.utils import tensor_vd_repr +from nnscaler.graph.gener.utils import tensor_vd_repr -cube.init() +nnscaler.init() def factors(k: int, num: int) -> List[Tuple[int]]: diff --git a/tests/graph/gener/check_intra_rvd.py b/tests/graph/gener/check_intra_rvd.py index 9bb970e1..efa9851e 100644 --- a/tests/graph/gener/check_intra_rvd.py +++ b/tests/graph/gener/check_intra_rvd.py @@ -7,16 +7,16 @@ """ from typing import List, Tuple -import cube -from cube.ir.tensor import IRFullTensor -from cube.graph.gener.rvd.layout import RVDLayout, RVDInspector -from cube.graph.gener.rvd.intra import IntraPathFinder, IntraAutoPlacer, IntraTransition +import nnscaler +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph.gener.rvd.layout import RVDLayout, RVDInspector +from nnscaler.graph.gener.rvd.intra import IntraPathFinder, IntraAutoPlacer, IntraTransition import numpy as np -from cube.graph.gener.utils import tensor_vd_repr +from nnscaler.graph.gener.utils import tensor_vd_repr -cube.init() +nnscaler.init() def factors(k: int, num: int) -> List[Tuple[int]]: diff --git a/tests/graph/gener/test_reducer_gen.py b/tests/graph/gener/test_reducer_gen.py index 9e74df82..332336b7 100644 --- a/tests/graph/gener/test_reducer_gen.py +++ b/tests/graph/gener/test_reducer_gen.py @@ -1,12 +1,12 @@ import pytest -from cube.graph.gener.gen import IRAdapterGener - -from cube.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.graph.parser.converter import convert_model -from cube.ir.operator import IRFwOperation -from cube.ir.tensor import IRFullTensor -from cube.ir.adapter import IRWeightReducer +from nnscaler.graph.gener.gen import IRAdapterGener + +from nnscaler.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.parser.converter import convert_model +from nnscaler.ir.operator import IRFwOperation +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.ir.adapter import IRWeightReducer import torch import tempfile diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 9e0aa18e..59625cf9 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -4,7 +4,7 @@ import pytest -from cube.graph.parser.fx.concrete_trace_utils.operator_patcher import ( +from nnscaler.graph.parser.fx.concrete_trace_utils.operator_patcher import ( OperatorTransformer, SuperTransformer, ProxyCallTransformer, diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index b933c3ab..e960488d 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -5,9 +5,9 @@ import torch import pytest -from cube.graph.parser.converter import to_fx_graph, to_ir_graph -from cube.graph.parser import FxModuleParser -from cube.ir.cten import IRObject, IRTensor +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.graph.parser import FxModuleParser +from nnscaler.ir.cten import IRObject, IRTensor from ...utils import replace_all_device_with @@ -74,7 +74,7 @@ def forward(self, x, *args): module = MyModule() fx_graph = to_fx_graph(module, dummy_input) - cube_path = str(Path(importlib.util.find_spec('cube').origin).parent) + '/' + cube_path = str(Path(importlib.util.find_spec('nnscaler').origin).parent) + '/' for node in fx_graph.graph.nodes: if 'frame_record' in node.meta and cube_path in str(node.meta['frame_record']): @@ -95,8 +95,8 @@ def forward(self, x): module = MyModule() fx_graph = to_fx_graph(module, dummy_input) - from cube.graph.parser.fx.concrete_trace_utils.concrete_proxy import ConcreteProxy - from cube.graph.parser.fx.concrete_trace_utils import TensorMetadata + from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_proxy import ConcreteProxy + from nnscaler.graph.parser.fx.concrete_trace_utils import TensorMetadata for node in fx_graph.graph.nodes: # this assert is only for this simple model, all node should have TensorMetadata type 'tensor_meta' diff --git a/tests/graph/parser/test_dce.py b/tests/graph/parser/test_dce.py index 49156d4b..60decdc8 100644 --- a/tests/graph/parser/test_dce.py +++ b/tests/graph/parser/test_dce.py @@ -1,7 +1,7 @@ import pytest import torch -from cube.graph.parser.converter import to_fx_graph +from nnscaler.graph.parser.converter import to_fx_graph from ...utils import replace_all_device_with diff --git a/tests/graph/parser/test_ir_obj_constant.py b/tests/graph/parser/test_ir_obj_constant.py index d4d61cbf..d493e2fb 100644 --- a/tests/graph/parser/test_ir_obj_constant.py +++ b/tests/graph/parser/test_ir_obj_constant.py @@ -3,7 +3,7 @@ import math import torch -from cube.graph.parser.converter import convert_model +from nnscaler.graph.parser.converter import convert_model from ...utils import replace_all_device_with diff --git a/tests/graph/parser/test_no_grad.py b/tests/graph/parser/test_no_grad.py index 444d6094..963f189b 100644 --- a/tests/graph/parser/test_no_grad.py +++ b/tests/graph/parser/test_no_grad.py @@ -1,7 +1,7 @@ import pytest import torch -from cube.graph.parser.converter import to_fx_graph +from nnscaler.graph.parser.converter import to_fx_graph from ...utils import replace_all_device_with diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 434a7e9c..9c6f6d6d 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -1,7 +1,7 @@ import tempfile import torch -from cube.ir.cten import IRObject, IRTensor -from cube.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.ir.cten import IRObject, IRTensor +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph from ...utils import replace_all_device_with diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index cb05381b..ec21d090 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -1,6 +1,6 @@ -import cube -from cube.graph.parser.converter import convert_model -from cube.profiler.database import ProfileDataBase +import nnscaler +from nnscaler.graph.parser.converter import convert_model +from nnscaler.profiler.database import ProfileDataBase import tempfile import torch @@ -10,15 +10,15 @@ def mock_add(x: torch.Tensor, y: torch.Tensor): return x + y -cube.graph.parser.register('*, * -> *')(mock_add) +nnscaler.graph.parser.register('*, * -> *')(mock_add) -@cube.graph.parser.register('*, * -> *') +@nnscaler.graph.parser.register('*, * -> *') def mock_add2(x: torch.Tensor, y: torch.Tensor): return x + y -@cube.graph.parser.register('(h w^) k^ -> h (w^ k^)') +@nnscaler.graph.parser.register('(h w^) k^ -> h (w^ k^)') def mock_view_with_obj(x, h): return x.view(h, -1) @@ -32,7 +32,7 @@ def forward(ctx, x: torch.Tensor, y: torch.Tensor): def backward(ctx, grad): return grad, grad -cube.graph.parser.register('*, * -> *')(MockAGF.apply) +nnscaler.graph.parser.register('*, * -> *')(MockAGF.apply) class MockModel(torch.nn.Module): diff --git a/tests/graph/parser/test_register_external.py b/tests/graph/parser/test_register_external.py index 842586c6..f5322f20 100644 --- a/tests/graph/parser/test_register_external.py +++ b/tests/graph/parser/test_register_external.py @@ -2,9 +2,9 @@ import torch import logging import tempfile -from cube.graph.parser.converter import convert_model -from cube.ir.operator import IRFwOperation -from cube.graph.function.dimops import IRDimops +from nnscaler.graph.parser.converter import convert_model +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.function.dimops import IRDimops _logger = logging.getLogger(__name__) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 784bd32c..5de8015c 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -1,7 +1,7 @@ -from cube.ir.tensor import IRFullTensor, IRSubTensor -from cube.ir.operator import IRFwOperation -from cube.graph.graph import IRGraph +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.graph import IRGraph def test_graph_from_logic(): diff --git a/tests/graph/test_multiref.py b/tests/graph/test_multiref.py index 935cf9fd..b9f81d99 100644 --- a/tests/graph/test_multiref.py +++ b/tests/graph/test_multiref.py @@ -5,9 +5,9 @@ import logging from functools import partial -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation +import nnscaler +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation from ..launch_torchrun import torchrun from ..utils import init_parameter, assert_parity @@ -94,13 +94,13 @@ def policy(graph: IRGraph, resource): x, y = get_dummy_data() - @cube.compile(model, x, y, PAS=policy) + @nnscaler.compile(model, x, y, PAS=policy) def train_iter(model, x, y): loss = model(x, y) loss.backward() return loss - model = cube.load_model() + model = nnscaler.load_model() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) losses = [] @@ -115,8 +115,8 @@ def train_iter(model, x, y): def multiref_test(): - cube.init() - cube.set_logger_level(logging.INFO) + nnscaler.init() + nnscaler.set_logger_level(logging.INFO) assert_parity(baseline, multiref) diff --git a/tests/graph/tracer/test_cls_wrapper.py b/tests/graph/tracer/test_cls_wrapper.py index a88569f6..192cd486 100644 --- a/tests/graph/tracer/test_cls_wrapper.py +++ b/tests/graph/tracer/test_cls_wrapper.py @@ -1,7 +1,7 @@ import torch import pytest -from cube.graph.parser.converter import to_fx_graph +from nnscaler.graph.parser.converter import to_fx_graph from ...utils import replace_all_device_with diff --git a/tests/graph/tracer/test_getattr.py b/tests/graph/tracer/test_getattr.py index 7a6e22c3..b4ac5b7e 100644 --- a/tests/graph/tracer/test_getattr.py +++ b/tests/graph/tracer/test_getattr.py @@ -1,6 +1,6 @@ import torch -from cube.graph.parser.converter import to_fx_graph +from nnscaler.graph.parser.converter import to_fx_graph from ...utils import replace_all_device_with diff --git a/tests/graph/tracer/test_inplace.py b/tests/graph/tracer/test_inplace.py index 58b27474..90552e46 100644 --- a/tests/graph/tracer/test_inplace.py +++ b/tests/graph/tracer/test_inplace.py @@ -2,9 +2,9 @@ import _operator import torch -from cube.graph.parser.converter import to_fx_graph -from cube.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops -import cube.runtime.function as cube_rt_function +from nnscaler.graph.parser.converter import to_fx_graph +from nnscaler.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops +import nnscaler.runtime.function as cube_rt_function from ...utils import replace_all_device_with diff --git a/tests/graph/tracer/test_namedtuple.py b/tests/graph/tracer/test_namedtuple.py index 4097fdaf..91b16245 100644 --- a/tests/graph/tracer/test_namedtuple.py +++ b/tests/graph/tracer/test_namedtuple.py @@ -1,7 +1,7 @@ from collections import namedtuple import torch -from cube.graph.parser.converter import to_fx_graph +from nnscaler.graph.parser.converter import to_fx_graph from ...utils import replace_all_device_with diff --git a/tests/graph/tracer/test_op_patcher.py b/tests/graph/tracer/test_op_patcher.py index 20ca94f0..a3ab400e 100644 --- a/tests/graph/tracer/test_op_patcher.py +++ b/tests/graph/tracer/test_op_patcher.py @@ -1,6 +1,6 @@ import torch from types import MethodType -from cube.graph.parser.fx.concrete_trace_utils.operator_patcher import OperatorPatcher +from nnscaler.graph.parser.fx.concrete_trace_utils.operator_patcher import OperatorPatcher def test_patch_func_or_module(): diff --git a/tests/graph/tracer/test_pytree.py b/tests/graph/tracer/test_pytree.py index 6eb712f8..043cb322 100644 --- a/tests/graph/tracer/test_pytree.py +++ b/tests/graph/tracer/test_pytree.py @@ -2,7 +2,7 @@ from torch.utils._pytree import tree_flatten -from cube.graph.parser.fx.concrete_trace_utils.utils import ( +from nnscaler.graph.parser.fx.concrete_trace_utils.utils import ( flatten_tree_with_spec, flatten_trees_with_func, flatten_trees_with_func_and_spec, diff --git a/tests/graph/tracer/test_scope.py b/tests/graph/tracer/test_scope.py index cdbd2fa4..18cf1932 100644 --- a/tests/graph/tracer/test_scope.py +++ b/tests/graph/tracer/test_scope.py @@ -1,6 +1,6 @@ import torch -from cube.graph.parser.converter import to_fx_graph +from nnscaler.graph.parser.converter import to_fx_graph from ...utils import replace_all_device_with diff --git a/tests/ir/tensor.py b/tests/ir/tensor.py index 4b9ba867..5f020d65 100644 --- a/tests/ir/tensor.py +++ b/tests/ir/tensor.py @@ -1,4 +1,4 @@ -from cube.ir.tensor import IRSubTensor, IRFullTensor +from nnscaler.ir.tensor import IRSubTensor, IRFullTensor def test_tensor_grad(): diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 65b288c5..a2ef0599 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -10,13 +10,13 @@ from torch import nn import numpy as np -from cube.graph.schedule.predefined import PredefinedSched -from cube.parallel import ComputeConfig -from cube.graph.function.anchor import IRGraphAnchor -from cube.graph.function.dimops import IRDimops -from cube.graph.graph import IRGraph -from cube.graph.segment import IRSegment -from cube.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.graph.schedule.predefined import PredefinedSched +from nnscaler.parallel import ComputeConfig +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRDataOperation, IRFwOperation def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index 97728bc1..12f04309 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -5,7 +5,7 @@ import pytest import torch -from cube.parallel import ComputeConfig, parallelize, broadcast_weights +from nnscaler.parallel import ComputeConfig, parallelize, broadcast_weights from .common import PASRandomSPMD, init_distributed from ..launch_torchrun import launch_torchrun diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 46833965..ccdd4965 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -14,9 +14,9 @@ import numpy as np -from cube.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts -from cube.runtime.module import ParallelModule, ExtraState -from cube.runtime.gnorm import calcuate_gnorm +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts +from nnscaler.runtime.module import ParallelModule, ExtraState +from nnscaler.runtime.gnorm import calcuate_gnorm from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, PASMegatron from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index ceca83e6..44d0c9fa 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -4,7 +4,7 @@ import pytest -from cube.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dicts +from nnscaler.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dicts from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index e6533c35..c6a3545d 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -6,10 +6,10 @@ import torch from torch import nn -from cube.parallel import ComputeConfig, parallelize, build_optimizer, \ +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, \ merge_state_dicts, load_merged_state_dicts, \ deduped_state_dict, load_deduped_state_dict -from cube.runtime.module import ParallelModule +from nnscaler.runtime.module import ParallelModule from .common import PASRandomSPMD, PASMegatron, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, assert_equal from ..launch_torchrun import launch_torchrun diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index 838761c3..e0aef12b 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -6,7 +6,7 @@ import torch from torch import nn -from cube.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts from .common import PASRandomSPMD, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py index 6c554dba..e6bc7a6a 100644 --- a/tests/parallel_module/test_checkpoint_unused.py +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -14,9 +14,9 @@ import numpy as np -from cube.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts -from cube.runtime.module import ParallelModule, ExtraState -from cube.runtime.gnorm import calcuate_gnorm +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts +from nnscaler.runtime.module import ParallelModule, ExtraState +from nnscaler.runtime.gnorm import calcuate_gnorm from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index 647f2a01..a8efcd19 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -14,9 +14,9 @@ import numpy as np -from cube.parallel import ComputeConfig, parallelize, build_optimizer -from cube.runtime.module import ParallelModule -from cube.runtime.gnorm import calcuate_gnorm +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.runtime.module import ParallelModule +from nnscaler.runtime.gnorm import calcuate_gnorm from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index 501d9ad3..c808a56e 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -13,11 +13,11 @@ from torch import nn import torch.distributed -import cube -from cube.runtime.gnorm import calcuate_gnorm -from cube.runtime.utils import microbatches -from cube.runtime.module import ParallelModule -from cube.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts +import nnscaler +from nnscaler.runtime.gnorm import calcuate_gnorm +from nnscaler.runtime.utils import microbatches +from nnscaler.runtime.module import ParallelModule +from nnscaler.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts from .common import PASData, PASRandomSPMD, assert_equal, clear_dir_on_rank0, init_distributed, PASMegatron, init_random, PASHybrid from ..launch_torchrun import clone_to_cpu_recursively, launch_torchrun diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index b365f946..8642bbfa 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -4,8 +4,8 @@ import torch import pytest -import cube.graph.function.dimops -from cube.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph +import nnscaler.graph.function.dimops +from nnscaler.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph from .common import PASData, init_distributed, PASRandomSPMD from ..launch_torchrun import launch_torchrun @@ -205,7 +205,7 @@ def forward(self, x, attr): def _gencode_contains(cubesave_dir, module_class, index, search_re): - from cube.parallel import _CUBE_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME + from nnscaler.parallel import _CUBE_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME from pathlib import Path import re namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' @@ -480,7 +480,7 @@ def test_codegen_clone(): tempdir, True ) - assert isinstance(g.nodes()[0], cube.graph.function.dimops.IRDimops) + assert isinstance(g.nodes()[0], nnscaler.graph.function.dimops.IRDimops) class MinModule(torch.nn.Module): diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 751136b3..23ab211d 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -6,7 +6,7 @@ import torch from torch import nn -from cube.parallel import ComputeConfig, parallelize +from nnscaler.parallel import ComputeConfig, parallelize from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD, clear_dir_on_rank0 from ..launch_torchrun import torchrun diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index acabf819..3e2ffc0b 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -3,7 +3,7 @@ import torch -from cube.parallel import _load_cube_module_class, parallelize, ComputeConfig +from nnscaler.parallel import _load_cube_module_class, parallelize, ComputeConfig from ..launch_torchrun import launch_torchrun from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD, clear_dir_on_rank0 diff --git a/tests/parallel_module/test_nested.py b/tests/parallel_module/test_nested.py index 4e64d23d..ea0a8080 100644 --- a/tests/parallel_module/test_nested.py +++ b/tests/parallel_module/test_nested.py @@ -3,7 +3,7 @@ import torch import pytest -from cube.parallel import parallelize, ComputeConfig +from nnscaler.parallel import parallelize, ComputeConfig from .common import PASData, init_distributed from ..launch_torchrun import launch_torchrun diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index 1825978e..6993d278 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -6,9 +6,9 @@ import torch import shutil -from cube.graph.parser.fx.parser import FxModuleParser -from cube.parallel import ReuseType, parallelize, ComputeConfig, _load_cube_module_class -from cube.runtime.module import ParallelModule +from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.parallel import ReuseType, parallelize, ComputeConfig, _load_cube_module_class +from nnscaler.runtime.module import ParallelModule from ..utils import new_empty, replace_all_device_with from .common import PASData diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index c6972233..a3d2e81f 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -6,8 +6,8 @@ import torch from torch import nn -from cube.parallel import ComputeConfig, parallelize, build_optimizer -from cube.runtime.module import ParallelModule +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.runtime.module import ParallelModule from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py index 1c003540..22e49bbd 100644 --- a/tests/parallel_module/test_scale_grads.py +++ b/tests/parallel_module/test_scale_grads.py @@ -14,9 +14,9 @@ import numpy as np -from cube.parallel import ComputeConfig, parallelize, build_optimizer -from cube.runtime.module import ParallelModule, ExtraState -from cube.runtime.gnorm import calcuate_gnorm +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.runtime.module import ParallelModule, ExtraState +from nnscaler.runtime.gnorm import calcuate_gnorm from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index d0774529..575cb642 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -9,8 +9,8 @@ from torch import nn import numpy as np -from cube.parallel import ComputeConfig, parallelize, build_optimizer -from cube.runtime.module import ParallelModule +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.runtime.module import ParallelModule from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 3a4f9da0..53c3c3b6 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -9,8 +9,8 @@ from torch import nn import numpy as np -from cube.parallel import ComputeConfig, parallelize, build_optimizer -from cube.runtime.module import ParallelModule +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.runtime.module import ParallelModule from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively diff --git a/tests/profiler/test_op_profile.py b/tests/profiler/test_op_profile.py index 576b3697..88105464 100644 --- a/tests/profiler/test_op_profile.py +++ b/tests/profiler/test_op_profile.py @@ -4,10 +4,10 @@ import pytest import torch -from cube.parallel import _gen_graph -from cube.ir.tensor import IRTensor -from cube.ir.operator import IRFwOperation -from cube.profiler.database import CompProfiler, ProfileDataBase +from nnscaler.parallel import _gen_graph +from nnscaler.ir.tensor import IRTensor +from nnscaler.ir.operator import IRFwOperation +from nnscaler.profiler.database import CompProfiler, ProfileDataBase class NaiveFFN(torch.nn.Module): diff --git a/tests/runtime/test_dataloader.py b/tests/runtime/test_dataloader.py index ee4d44e3..7f39104a 100644 --- a/tests/runtime/test_dataloader.py +++ b/tests/runtime/test_dataloader.py @@ -1,6 +1,6 @@ import torch -from cube.runtime.utils import MicroBatchDataLoader, microbatches +from nnscaler.runtime.utils import MicroBatchDataLoader, microbatches import pytest diff --git a/tests/runtime/test_gnorm.py b/tests/runtime/test_gnorm.py index 870ec565..624163b2 100644 --- a/tests/runtime/test_gnorm.py +++ b/tests/runtime/test_gnorm.py @@ -1,5 +1,5 @@ """ -This test is to verify the correctness of the gradient norm algorithm for cube. +This test is to verify the correctness of the gradient norm algorithm for nnscaler. To avoid other potential parity issues that may have influence the gradient value, we use weight data as gradient, and calculate its norm to verify the correctness @@ -8,11 +8,11 @@ import torch from functools import partial -import cube -from cube.ir.operator import IRFwOperation -from cube.runtime.module import CubeModule -from cube.runtime.gnorm import prepare_for_grad_clip, clip_gnorm -from cube.flags import CompileFlag +import nnscaler +from nnscaler.ir.operator import IRFwOperation +from nnscaler.runtime.module import CubeModule +from nnscaler.runtime.gnorm import prepare_for_grad_clip, clip_gnorm +from nnscaler.flags import CompileFlag from ..launch_torchrun import torchrun from ..utils import init_parameter @@ -76,7 +76,7 @@ def pp_policy(graph, resource, su_num): def model_test(policy, su_num: int = 1, use_zero: bool = False): # su_num: scale unit number - cube.init() + nnscaler.init() CompileFlag.use_zero = use_zero model = Module().cuda() @@ -86,14 +86,14 @@ def model_test(policy, su_num: int = 1, use_zero: bool = False): wnorm_baseline = cal_wnorm_baseline(model) sample = torch.randn(16, 16).cuda() - @cube.compile(model, sample, PAS=partial(policy, su_num=su_num), + @nnscaler.compile(model, sample, PAS=partial(policy, su_num=su_num), scale=su_num > 1) def train_iter(model, data): loss = model(data) loss.backward() return loss - model = cube.load_model() + model = nnscaler.load_model() # train_iter(model, sample) # link .grad to reducer buffer wnorm_cube = cal_wnorm_cube(model) diff --git a/tests/runtime/test_grad_accum.py b/tests/runtime/test_grad_accum.py index 6b8875b9..7191e015 100644 --- a/tests/runtime/test_grad_accum.py +++ b/tests/runtime/test_grad_accum.py @@ -2,8 +2,8 @@ import pytest from functools import partial -import cube -from cube.runtime.module import CubeModule +import nnscaler +from nnscaler.runtime.module import CubeModule from ..launch_torchrun import torchrun from ..utils import init_parameter, assert_parity @@ -18,7 +18,7 @@ def __init__(self, ngpus, async_op, dim=512, nlayers=4,): for _ in range(nlayers): self.layers.append(torch.nn.Linear(dim, dim, bias=False)) - self.wreducer1 = cube.runtime.adapter.Reducer(ranks=ranks, reduce_op='sum', async_op=async_op, zero=False, + self.wreducer1 = nnscaler.runtime.adapter.Reducer(ranks=ranks, reduce_op='sum', async_op=async_op, zero=False, max_bucket_size_bytes=137217728, zero_ngroups=1) for param in self.parameters(): self.wreducer1.add_param(param) @@ -151,7 +151,7 @@ def reducer_async_test_correct(accum_times: int = 4): for _ in range(3): model.zero_grad() for step in range(accum_times): - with cube.accum_mode(begin=(step == 0), end=(step == accum_times - 1)): + with nnscaler.accum_mode(begin=(step == 0), end=(step == accum_times - 1)): x = get_dummy_data() x = x.chunk(ngpus, dim=0)[rank] loss = model(x) @@ -170,7 +170,7 @@ def reducer_async_test_correct(accum_times: int = 4): def accum_test(): - cube.init() + nnscaler.init() print('starting reducer sync') assert_parity(baseline, partial(reducer_sync_test, 4)) print('starting reducer async') diff --git a/tests/runtime/test_module_merge.py b/tests/runtime/test_module_merge.py index e5c946db..dc255e96 100644 --- a/tests/runtime/test_module_merge.py +++ b/tests/runtime/test_module_merge.py @@ -1,12 +1,12 @@ import torch -import cube +import nnscaler import os from functools import partial import pytest -from cube.ir.operator import IRFwOperation -from cube.runtime.device import DeviceGroup +from nnscaler.ir.operator import IRFwOperation +from nnscaler.runtime.device import DeviceGroup from ..launch_torchrun import torchrun @@ -67,19 +67,19 @@ def assert_same_state(origin, merged): def merge_model_states_test(): - cube.init() + nnscaler.init() model = Module() sample = torch.randn(8, 8, device=torch.cuda.current_device()) full_model_state = model.state_dict() - @cube.compile(model, sample, PAS=tp_policy) + @nnscaler.compile(model, sample, PAS=tp_policy) def train_iter(model, sample): loss = model(sample) loss.backward() return loss - cube_model = cube.load_model() + cube_model = nnscaler.load_model() state_dict = cube_model.state_dict() torch.save({'state_dict': state_dict, 'fullmap': cube_model.fullmap}, @@ -100,7 +100,7 @@ def train_iter(model, sample): def merge_optimizer_states_test(): - cube.init() + nnscaler.init() torch.manual_seed(0) model = Module().cuda() @@ -110,13 +110,13 @@ def merge_optimizer_states_test(): full_model_state = model.state_dict() full_optim_state = full_optimizer.state_dict() - @cube.compile(model, sample, PAS=tp_policy) + @nnscaler.compile(model, sample, PAS=tp_policy) def train_iter(model, sample): loss = model(sample) loss.backward() return loss - cube_model = cube.load_model() + cube_model = nnscaler.load_model() optimizer = torch.optim.Adam(cube_model.parameters(), lr=0.01) # test for initial state diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py index 5a3281f1..2997ebea 100644 --- a/tests/runtime/test_reducer.py +++ b/tests/runtime/test_reducer.py @@ -5,10 +5,10 @@ import logging from functools import partial -import cube -from cube.graph import IRGraph -from cube.ir.operator import IRFwOperation -from cube.flags import CompileFlag +import nnscaler +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.flags import CompileFlag from ..launch_torchrun import torchrun from ..utils import init_parameter, assert_parity @@ -91,13 +91,13 @@ def tensor_parallelism(node, idx, dim, num): x = get_dummy_data() - @cube.compile(model, x, PAS=policy) + @nnscaler.compile(model, x, PAS=policy) def train_iter(model, x): loss = model(x) loss.backward() return loss - model = cube.load_model() + model = nnscaler.load_model() optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=0.01) losses = [] @@ -118,7 +118,7 @@ def train_iter(model, x): def reducer_test(): - cube.init() + nnscaler.init() CompileFlag.disable_code_line_info = True # speedup parse print('starting zero=True, async=True') assert_parity(baseline, partial(reducer, True, True)) diff --git a/tests/runtime/test_runtime_collectives.py b/tests/runtime/test_runtime_collectives.py index a31dc5aa..70b39dc5 100644 --- a/tests/runtime/test_runtime_collectives.py +++ b/tests/runtime/test_runtime_collectives.py @@ -1,6 +1,6 @@ from typing import List -import cube +import nnscaler import torch import pytest @@ -27,10 +27,10 @@ def _move_worker(async_op: bool): shape = [128, 256] tensor = _get_tensor(shape) - tensor = cube.runtime.adapter.move(tensor, shape, torch.float32, 0, 1, async_op=async_op) + tensor = nnscaler.runtime.adapter.move(tensor, shape, torch.float32, 0, 1, async_op=async_op) if async_op: - tensor = cube.runtime.executor.AsyncCommHandler().wait(tensor) + tensor = nnscaler.runtime.executor.AsyncCommHandler().wait(tensor) return clone_to_cpu(tensor) @@ -39,10 +39,10 @@ def _allreduce_worker(async_op: bool): tensor = _get_tensor(shape) sum_tensor = tensor.clone().detach() - sum_tensor = cube.runtime.adapter.all_reduce(sum_tensor, [0, 1], async_op=async_op) + sum_tensor = nnscaler.runtime.adapter.all_reduce(sum_tensor, [0, 1], async_op=async_op) if async_op: - sum_tensor = cube.runtime.executor.AsyncCommHandler().wait(sum_tensor) + sum_tensor = nnscaler.runtime.executor.AsyncCommHandler().wait(sum_tensor) return (clone_to_cpu(tensor), clone_to_cpu(sum_tensor)) @@ -50,10 +50,10 @@ def _allgather_worker(async_op: bool): shape = [128, 256] tensor = _get_tensor(shape) - otensor = cube.runtime.adapter.all_gather(tensor, 0, [0, 1], async_op=async_op) + otensor = nnscaler.runtime.adapter.all_gather(tensor, 0, [0, 1], async_op=async_op) if async_op: - otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + otensor = nnscaler.runtime.executor.AsyncCommHandler().wait(otensor) return (clone_to_cpu(tensor), clone_to_cpu(otensor)) @@ -61,10 +61,10 @@ def _reduce_scatter_worker(async_op: bool): shape = [128, 256] tensor = _get_tensor(shape) - otensor = cube.runtime.adapter.reduce_scatter(tensor, 0, [0, 1], async_op=async_op) + otensor = nnscaler.runtime.adapter.reduce_scatter(tensor, 0, [0, 1], async_op=async_op) if async_op: - otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + otensor = nnscaler.runtime.executor.AsyncCommHandler().wait(otensor) return (clone_to_cpu(tensor), clone_to_cpu(otensor)) @@ -75,10 +75,10 @@ def _all2all_worker(async_op): tensor = _get_tensor(shape) # # synchronize - otensor = cube.runtime.adapter.all_to_all(tensor, 0, 1, [0, 1], async_op=async_op) + otensor = nnscaler.runtime.adapter.all_to_all(tensor, 0, 1, [0, 1], async_op=async_op) if async_op: - otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + otensor = nnscaler.runtime.executor.AsyncCommHandler().wait(otensor) return (clone_to_cpu(tensor), clone_to_cpu(otensor)) @@ -139,11 +139,11 @@ def _rdscatter_worker(async_op): shape = [128, 256] tensor = _get_tensor(shape) - otensor = cube.runtime.adapter.rdscatter( + otensor = nnscaler.runtime.adapter.rdscatter( tensor, shape, torch.float32, dim=0, src=0, dsts=[1,2], async_op=async_op) if async_op: - otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + otensor = nnscaler.runtime.executor.AsyncCommHandler().wait(otensor) return (clone_to_cpu(tensor), clone_to_cpu(otensor)) @@ -152,11 +152,11 @@ def _rdgather_worker(async_op): shape = [128, 256] tensor = _get_tensor(shape) - otensor = cube.runtime.adapter.rdgather( + otensor = nnscaler.runtime.adapter.rdgather( tensor, shape, torch.float32, dim=0, srcs=[1,2], dst=0) if async_op: - otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + otensor = nnscaler.runtime.executor.AsyncCommHandler().wait(otensor) return (clone_to_cpu(tensor), clone_to_cpu(otensor)) @@ -167,10 +167,10 @@ def _broadcast_worker(async_op): tensor = _get_tensor(shape) # synchronize - otensor = cube.runtime.adapter.broadcast( + otensor = nnscaler.runtime.adapter.broadcast( tensor, shape, torch.float32, src=0, ranks=[0,1,2]) if async_op: - otensor = cube.runtime.executor.AsyncCommHandler().wait(otensor) + otensor = nnscaler.runtime.executor.AsyncCommHandler().wait(otensor) return (clone_to_cpu(tensor), clone_to_cpu(otensor)) diff --git a/tests/test_program.py b/tests/test_program.py index 7dbc3a34..53b34b3c 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -1,9 +1,9 @@ import pytest import torch -from cube.program import SemanticModel, Program -from cube.flags import CompileFlag -from cube.ir.cten import IRObject +from nnscaler.program import SemanticModel, Program +from nnscaler.flags import CompileFlag +from nnscaler.ir.cten import IRObject @pytest.mark.skipif(not torch.cuda.is_available(), reason='cuda is not available') diff --git a/tests/utils.py b/tests/utils.py index badc80a8..906d771f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,8 +14,8 @@ import torch.distributed as dist import torch.distributed.distributed_c10d as c10d -from cube.runtime.module import ParallelModule -from cube.runtime.device import DeviceGroup, CompileFlag +from nnscaler.runtime.module import ParallelModule +from nnscaler.runtime.device import DeviceGroup, CompileFlag def init_parameter(model: torch.nn.Module, seed: int = 0): @@ -94,7 +94,7 @@ def replace_all_device_with(device='cpu', force=False): yield return - from cube.graph.parser.fx.concrete_trace_utils.concrete_tracer import ConcreteTracer + from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import ConcreteTracer orig_to = torch.Tensor.to orig_cuda = torch.Tensor.cuda diff --git a/tox.ini b/tox.ini index 4411b3ac..e455477d 100644 --- a/tox.ini +++ b/tox.ini @@ -14,5 +14,5 @@ deps = -rrequirements.txt -rrequirements-dev.txt commands = coverage erase - pytest --cov={toxinidir}/cube -x tests + pytest --cov={toxinidir}/nnscaler -x tests coverage html diff --git a/tutorial.md b/tutorial.md index d71ba92e..db88c36f 100644 --- a/tutorial.md +++ b/tutorial.md @@ -73,7 +73,7 @@ If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimens To register a customized "matmul" operator in the runtime, user can simply define a python function and add an decorator on the function with its annotations: ```py -@cube.graph.parser.register('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom') +@nnscaler.graph.parser.register('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom') def operator(x: torch.Tensor, w: torch.Tensor, h: float) -> torch.Tensor: out = torch.matmul(x, w) out = out.view(h, out.size(0) // h, out.size(1)) @@ -106,4 +106,4 @@ def PAS(graph: IRGraph, resource): return graph ``` -Note: we require user to add type annotation of output and input in the function, to help system understand each identifier number. The non-tensor inputs should be listed at the last and don't need to be represented into annotation. \ No newline at end of file +Note: we require user to add type annotation of output and input in the function, to help system understand each identifier number. The non-tensor inputs should be listed at the last and don't need to be represented into annotation. diff --git a/utility/test_rvd_prim.py b/utility/test_rvd_prim.py index 739c8ca3..7e8251c2 100644 --- a/utility/test_rvd_prim.py +++ b/utility/test_rvd_prim.py @@ -10,14 +10,14 @@ """ from typing import Callable -import cube +import nnscaler import torch import time import argparse -from cube.profiler.timer import CudaTimer, print_each_rank +from nnscaler.profiler.timer import CudaTimer, print_each_rank -from cube.runtime.adapter.collectives import all_reduce, all_gather, reduce_scatter, all_to_all -from cube.runtime.device import DeviceGroup +from nnscaler.runtime.adapter.collectives import all_reduce, all_gather, reduce_scatter, all_to_all +from nnscaler.runtime.device import DeviceGroup def prim_allreduce(itensor, ranks, dim0=None, dim1=None): @@ -100,7 +100,7 @@ def prim_bw(prim: Callable, bandwidth: Callable, ranks, size, warmup=100, profil if __name__ == '__main__': - cube.init() + nnscaler.init() parser = argparse.ArgumentParser(description='comm primitive') parser.add_argument('--prims', type=str, nargs='+', action='append', From 2f9d3305efd054903b95cc8165f8372aed6d27d2 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 18 Apr 2024 06:44:28 +0000 Subject: [PATCH 1622/1892] Merged PR 2116: fix dimops align infer in dimops anno infer, we should support the following case: (a b), a -> a b in previous, if a tensor has combined identifiers like (a b), then we will directly tell user we cannot infer the unknown value of a b and raise error. but in fact, we can get a and b value from other tensor identifier, this pr support this case. --- nnscaler/graph/function/dimops.py | 21 ++++++++++++++++++--- tests/graph/function/test_dimops.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index 23962a35..eb58f7c3 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -758,6 +758,18 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict op_anno.set_output(idx, shape_anno) op_anno.reset_identifiers() + identifier_values: Dict[str, int] = dict() + for ashape, itensor in zip(op_anno.inputs(), inputs): + if not isinstance(itensor, IRTensor) or ashape.ignore: + continue + if ashape.ndims != len(itensor.shape): + return False + for adim, dimlen in zip(ashape.dims, itensor.shape): + if len(adim.identifiers) == 1: + if adim.identifiers[0] in identifier_values and identifier_values[adim.identifiers[0]] != dimlen: + raise RuntimeError(f'the exist identifier value {identifier_values[adim.identifiers[0]]} is not equal to the new value {dimlen}') + identifier_values[adim.identifiers[0]] = dimlen + # check dimension consistency for ashape, itensor in zip(op_anno.inputs(), inputs): if itensor is None: @@ -778,9 +790,7 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict for identifier in identifiers: length = op_anno.getlen(identifier) if length is None: - if identifier not in kwargs: - toinfer.append(identifier) - else: + if identifier in kwargs: if isinstance(kwargs[identifier], IRObject): _logger.warning( f"Function {signature}: Found identifier {identifier} in kwargs to be IRObject, " @@ -794,6 +804,11 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict f"must be int or IRObject[value=int], but got {length}") ret = op_anno.setlen(identifier, length) accum *= length + elif identifier in identifier_values: + ret = op_anno.setlen(identifier, identifier_values[identifier]) + accum *= identifier_values[identifier] + else: + toinfer.append(identifier) else: accum *= length if len(toinfer) == 0 and accum != dimlen: diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index aec2f28c..cbfb230d 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -55,3 +55,32 @@ def NoReturnOp(input, weight, signature='no_return_op'): op = create_op(NoReturnOp, [(1024, 512), (512, 1024)]) assert len(op.outputs()) == 1 and isinstance(op.output(0), IRObject) and (not isinstance(op.output(0), IRFullTensor)) + + +def test_inner_dimensions_infer(): + + def TestFunc(input, weight, signature='test_func'): + anno = 'a b, (b c) -> a (b c)' + return IRDimops(TestFunc, 'test_func', signature, [anno], [input, weight]) + + op = create_op(TestFunc, [(1024, 512), (2048,)]) + partitionable(op, idx=0, dim=1, num=2) + +def test_anno_kwargs_infer(): + + def TestFunc(input, weight, number=128, signature='test_func'): + anno = '(a number), (b number) -> (a b)' + return IRDimops(TestFunc, 'test_func', signature, [anno], [input, weight], number=number) + + op = create_op(TestFunc, [(1024,), (2048,)]) + partitionable(op, idx=0, dim=0, num=2) + + +def test_dynamic_shape_infer(): + # TODO: please note that this test should be rewritten after we can fully support dynamic shape + def TestFunc(input, weight, bias, number=128, signature='test_func'): + anno = '(a number), (b number), (a b) -> 1' + return IRDimops(TestFunc, 'test_func', signature, [anno], [input, weight, bias], number=number) + + op = create_op(TestFunc, [(1024,), (2048,), (128,)], number=IRObject(value=128)) + partitionable(op, idx=0, dim=0, num=2) From 7871d899dd708cf18f2ff0b0bce8b8bdba53d4c8 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Thu, 18 Apr 2024 06:54:01 +0000 Subject: [PATCH 1623/1892] Merged PR 2118: [Re-Org Repo, Part 2] Merge AutoDist This PR copies the files from the AutoDist repo (latest commit `afd8dc5`). The source code folder `autodist` is copied to `nnscaler/autodist`; The test cases folder `tests` is copied to `tests/autodist`; All other files are *temporarily* copied to `autodist` folder, and they will be migrated in following up PRs. This PR will be merged into `main` branch after #2117 It is currently opened to `nnscaler` branch for easier comparison. --- .gitignore | 17 + autodist/.pre-commit-config.yaml | 32 + autodist/.style.yapf | 2 + autodist/README.md | 17 + autodist/benchmark/alpa/Alpa_solver.md | 43 + autodist/benchmark/alpa/README.md | 203 +++ .../benchmark/alpa/analyse_strategy/README.md | 15 + .../alpa/analyse_strategy/gen_str.py | 137 ++ .../alpa/analyse_strategy/strategy.zip | Bin 0 -> 886 bytes autodist/benchmark/alpa/benchmark.py | 204 +++ autodist/benchmark/alpa/gpt_alpa_2d_table1.sh | 16 + autodist/benchmark/alpa/gpt_alpa_2d_table2.sh | 27 + autodist/benchmark/alpa/gpt_alpa_3d.sh | 16 + autodist/benchmark/alphafold2.md | 124 ++ autodist/benchmark/gpt.md | 234 ++++ autodist/benchmark/recompute.md | 28 + autodist/build_env.py | 58 + autodist/comm_profile.py | 107 ++ autodist/docs/descs.py | 58 + autodist/docs/images/arch.png | Bin 0 -> 49324 bytes autodist/docs/interface_design.md | 36 + .../solver_interface/partition_constraint.md | 43 + .../solver_interface/pc_examples/moe_pc.yaml | 20 + .../pc_examples/retnet_dp2_pc.yaml | 26 + .../pc_examples/retnet_hybrid2_pc.yaml | 21 + .../pc_examples/retnet_mp2_pc.yaml | 21 + .../profile_data/16xmi200/comm/intra_16.json | 122 ++ .../profile_data/16xmi200/comm/intra_2.json | 122 ++ .../profile_data/16xmi200/comm/intra_4.json | 122 ++ .../profile_data/16xmi200/comm/intra_8.json | 122 ++ autodist/script/alphafold/foldtp.sh | 64 + autodist/script/gpt/adapt_recom_tp.sh | 59 + autodist/script/gpt/analyze.py | 77 ++ autodist/script/gpt/analyze_adapt_recom.py | 77 ++ autodist/script/gpt/checker.sh | 42 + autodist/script/gpt/pp_all_run.sh | 45 + autodist/script/gpt/profile.sh | 66 + autodist/script/gpt/tp_all_run.sh | 57 + autodist/script/pre_install.sh | 7 + autodist/script/swin/analysis.py | 72 + autodist/script/swin/profile_swin.sh | 60 + autodist/script/swin/swintp.sh | 66 + nnscaler/autodist/__init__.py | 0 nnscaler/autodist/apis.py | 277 ++++ nnscaler/autodist/autodist_config.py | 224 ++++ nnscaler/autodist/cost_database.py | 502 +++++++ nnscaler/autodist/csrc/solver.cpp | 799 +++++++++++ nnscaler/autodist/cube_operator.py | 111 ++ nnscaler/autodist/descs.py | 187 +++ nnscaler/autodist/model_graph.py | 797 +++++++++++ nnscaler/autodist/op_partition.py | 150 +++ nnscaler/autodist/pipeline_solver.py | 295 +++++ nnscaler/autodist/spmd_solver.py | 1167 +++++++++++++++++ nnscaler/autodist/util.py | 83 ++ requirements-dev.txt | 10 +- requirements.txt | 6 +- tests/autodist/__init__.py | 0 tests/autodist/graph/__init__.py | 0 tests/autodist/graph/test_calc_flops.py | 61 + tests/autodist/graph/test_recompute.py | 125 ++ tests/autodist/partition/__init__.py | 0 tests/autodist/partition/test_state.py | 46 + tests/autodist/pas/__init__.py | 0 tests/autodist/pas/all_replicated_pp.json | 87 ++ .../pas/replicated_and_partition.json | 87 ++ .../pas/test_shared_param_pipeline.py | 79 ++ tests/autodist/spmd_solver/__init__.py | 0 .../spmd_solver/test_attention_follow.yaml | 18 + .../spmd_solver/test_cube_operator.py | 50 + tests/autodist/spmd_solver/test_follow.py | 283 ++++ .../spmd_solver/test_partition_constraint.py | 107 ++ tests/autodist/spmd_solver/test_pc.yaml | 16 + .../autodist/spmd_solver/test_shared_param.py | 66 + 73 files changed, 8312 insertions(+), 6 deletions(-) create mode 100644 autodist/.pre-commit-config.yaml create mode 100644 autodist/.style.yapf create mode 100644 autodist/README.md create mode 100644 autodist/benchmark/alpa/Alpa_solver.md create mode 100644 autodist/benchmark/alpa/README.md create mode 100644 autodist/benchmark/alpa/analyse_strategy/README.md create mode 100644 autodist/benchmark/alpa/analyse_strategy/gen_str.py create mode 100644 autodist/benchmark/alpa/analyse_strategy/strategy.zip create mode 100644 autodist/benchmark/alpa/benchmark.py create mode 100644 autodist/benchmark/alpa/gpt_alpa_2d_table1.sh create mode 100644 autodist/benchmark/alpa/gpt_alpa_2d_table2.sh create mode 100644 autodist/benchmark/alpa/gpt_alpa_3d.sh create mode 100644 autodist/benchmark/alphafold2.md create mode 100644 autodist/benchmark/gpt.md create mode 100644 autodist/benchmark/recompute.md create mode 100644 autodist/build_env.py create mode 100644 autodist/comm_profile.py create mode 100644 autodist/docs/descs.py create mode 100644 autodist/docs/images/arch.png create mode 100644 autodist/docs/interface_design.md create mode 100644 autodist/docs/solver_interface/partition_constraint.md create mode 100644 autodist/docs/solver_interface/pc_examples/moe_pc.yaml create mode 100644 autodist/docs/solver_interface/pc_examples/retnet_dp2_pc.yaml create mode 100644 autodist/docs/solver_interface/pc_examples/retnet_hybrid2_pc.yaml create mode 100644 autodist/docs/solver_interface/pc_examples/retnet_mp2_pc.yaml create mode 100644 autodist/profile_data/16xmi200/comm/intra_16.json create mode 100644 autodist/profile_data/16xmi200/comm/intra_2.json create mode 100644 autodist/profile_data/16xmi200/comm/intra_4.json create mode 100644 autodist/profile_data/16xmi200/comm/intra_8.json create mode 100755 autodist/script/alphafold/foldtp.sh create mode 100644 autodist/script/gpt/adapt_recom_tp.sh create mode 100644 autodist/script/gpt/analyze.py create mode 100644 autodist/script/gpt/analyze_adapt_recom.py create mode 100755 autodist/script/gpt/checker.sh create mode 100755 autodist/script/gpt/pp_all_run.sh create mode 100644 autodist/script/gpt/profile.sh create mode 100755 autodist/script/gpt/tp_all_run.sh create mode 100644 autodist/script/pre_install.sh create mode 100644 autodist/script/swin/analysis.py create mode 100644 autodist/script/swin/profile_swin.sh create mode 100755 autodist/script/swin/swintp.sh create mode 100644 nnscaler/autodist/__init__.py create mode 100644 nnscaler/autodist/apis.py create mode 100644 nnscaler/autodist/autodist_config.py create mode 100644 nnscaler/autodist/cost_database.py create mode 100644 nnscaler/autodist/csrc/solver.cpp create mode 100644 nnscaler/autodist/cube_operator.py create mode 100644 nnscaler/autodist/descs.py create mode 100644 nnscaler/autodist/model_graph.py create mode 100644 nnscaler/autodist/op_partition.py create mode 100644 nnscaler/autodist/pipeline_solver.py create mode 100644 nnscaler/autodist/spmd_solver.py create mode 100644 nnscaler/autodist/util.py create mode 100644 tests/autodist/__init__.py create mode 100644 tests/autodist/graph/__init__.py create mode 100644 tests/autodist/graph/test_calc_flops.py create mode 100644 tests/autodist/graph/test_recompute.py create mode 100644 tests/autodist/partition/__init__.py create mode 100644 tests/autodist/partition/test_state.py create mode 100644 tests/autodist/pas/__init__.py create mode 100644 tests/autodist/pas/all_replicated_pp.json create mode 100644 tests/autodist/pas/replicated_and_partition.json create mode 100644 tests/autodist/pas/test_shared_param_pipeline.py create mode 100644 tests/autodist/spmd_solver/__init__.py create mode 100644 tests/autodist/spmd_solver/test_attention_follow.yaml create mode 100644 tests/autodist/spmd_solver/test_cube_operator.py create mode 100644 tests/autodist/spmd_solver/test_follow.py create mode 100644 tests/autodist/spmd_solver/test_partition_constraint.py create mode 100644 tests/autodist/spmd_solver/test_pc.yaml create mode 100644 tests/autodist/spmd_solver/test_shared_param.py diff --git a/.gitignore b/.gitignore index b6267036..c866c6ef 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,20 @@ benchmark/deepspeed/Megatron-DeepSpeed gencode*.py fullmodel.pt dist_param_map.pt + +## autodist ## + +# Python cache +*.pyc +dist +.cache +*env + +# Generated by Cube +gencode* +*.pt + +# Other +shelf +*.iml +*.xml diff --git a/autodist/.pre-commit-config.yaml b/autodist/.pre-commit-config.yaml new file mode 100644 index 00000000..f332baa8 --- /dev/null +++ b/autodist/.pre-commit-config.yaml @@ -0,0 +1,32 @@ +# File introduces automated checks triggered on git events +# to enable run `pip install pre-commit && pre-commit install` + +repos: + - repo: local + hooks: + - id: yapf + name: yapf + language: python + entry: yapf + args: [-i, -vv] + types: [python] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: check-docstring-first + - id: check-json + - id: check-added-large-files + - id: check-yaml + - id: debug-statements + - id: requirements-txt-fixer + - id: check-merge-conflict + - id: double-quote-string-fixer + - id: end-of-file-fixer + - repo: meta + hooks: + - id: check-useless-excludes + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v15.0.7 + hooks: + - id: clang-format diff --git a/autodist/.style.yapf b/autodist/.style.yapf new file mode 100644 index 00000000..0e9640c2 --- /dev/null +++ b/autodist/.style.yapf @@ -0,0 +1,2 @@ +[style] +based_on_style = google diff --git a/autodist/README.md b/autodist/README.md new file mode 100644 index 00000000..b252f7ec --- /dev/null +++ b/autodist/README.md @@ -0,0 +1,17 @@ +# AutoDist + +AutoDist is a package that optimizes for efficient distributed execution plans, given a DL data flow graph and cluster specifications. Compared to [Alpa](https://github.com/alpa-projects/alpa), AutoDist has two main advantages: +- a topology aware dynamic programming solver, which is faster than integer linear programing formulation in most cases +- achieve a balance between memory and time automatically, instead of using a global option + +## Prerequisite + +```bash +bash ./script/pre_install.sh +``` + +## Pipeline + +

+ +

diff --git a/autodist/benchmark/alpa/Alpa_solver.md b/autodist/benchmark/alpa/Alpa_solver.md new file mode 100644 index 00000000..ef3e97c5 --- /dev/null +++ b/autodist/benchmark/alpa/Alpa_solver.md @@ -0,0 +1,43 @@ +# Alpa Solver Details + +We have conducted a detailed test on the  solver in Alpa and have two conclusions. +1) input and constraints limit the efficiency of the solver.  +2) Alpa is unable to correctly solve problems with memory constraints.  + +The table below shows the numbers of the Alpa solver under different micro batch sizes and GPT models in spmd. This table includes there parts, the number of free variables (the first two columns) , alpa solving times (the mid three columns) and autodist compile times (the last column). **Baseline time** represents the original solver time, **random** represents filling the array with random numbers, and **mem** represents adding memory constraint conditions.  + +**1.3B** +| | num_nodes | num_edges | baseline time/s | random time/s | mem time/s | autodist time/s | +|---:|------------:|------------:|------------------:|----------------:|:-------------|------------------:| +| 1 | 2637 | 4586 | 19.3861 | 19.832 | > 600 | 2.65 | +| 2 | 2472 | 4468 | 18.4109 | 34.7781 | > 600 | 3.96 | +| 4 | 2473 | 4470 | 21.6232 | 42.6273 | > 600 | 6 | +| 8 | 2473 | 4470 | 19.3299 | 25.2286 | > 600 | 9.78 | +| 16 | 2473 | 4470 | 19.34 | 43.1969 | > 600 | 14.91 | +| 32 | 2473 | 4470 | 20.1404 | 38.766 | > 600 | 9.29 | + +**2.6B** +| | num_nodes | num_edges | baseline time/s | random time/s | mem time/s | autodist time/s | +|---:|------------:|------------:|------------------:|----------------:|:-------------|------------------:| +| 1 | 3493 | 6090 | 27.2841 | 48.645 | > 600 | 17.57 | +| 2 | 3272 | 5932 | 27.0608 | 41.8054 | > 600 | 24.84 | +| 4 | 3272 | 5932 | 25.7738 | 67.6498 | > 600 | 36.19 | +| 8 | 3273 | 5934 | 29.1824 | 67.1933 | > 600 | 76.63 | +| 16 | 3273 | 5934 | 27.4334 | 33.0636 | > 600 | 115.04 | +| 32 | 3273 | 5934 | 30.3701 | 63.2393 | > 600 | 69.36 | + +We can see from each row of the table that randomizing the input (**1.5~3x**) and adding memory constraint conditions (**>20x**) will increase the solving time of the solver, thus leading to the first conclusion stated above. Meanwhile, Autodist can quickly solve problems with memory constraints, achieving a maximum of **226x** faster solving efficiency (the first row at table 1.3B).   + +For the second conclusion, we reduced the layers of GPT-3 1.3B from 24 to 12 under a 30GB memory constraint. Alpa was unable to find a solution (but Autodist is able to find one). We make experimental examples (shown in **Table\***) to state that Alpa solver with memory constraint is unreasonable. + +**Table\*** + +The GPT model is 1.3B, we decrease the layer from 24 into 1, 5 and 12 respectively. Time ratio = mem time / baseline time. + +| | baseline time/s | mem time/s | time ratio | +|---:|-----------------:|---------------:|-------------:| +| 1 | 1.00 | 2.55 | 2.55 | +| 5 | 4.11 | 72.60 | 17.66 | +| 12 | 9.93 | None solution | -- | + +There are two unreasonable aspects, the first aspect is that the time ratio increases exponentially as the gpt model increases. Second, Alpa solver will not search for a solution when the model becomes larger (although this solution certainly exists from the above statement). diff --git a/autodist/benchmark/alpa/README.md b/autodist/benchmark/alpa/README.md new file mode 100644 index 00000000..fd99a332 --- /dev/null +++ b/autodist/benchmark/alpa/README.md @@ -0,0 +1,203 @@ +# Benchmark Alpa + +## GPT-3 + +### Usage + +For the 3d setting, the config is the same with Table 4 in [1]. For the 2d setting, we test the GPT-3 6.7B with only 4 layers. Details of the model config can be found in the `benchmark.py`, `gpt_alpa_3d.sh`, `gpt_alpa_2d_table1.sh` and `gpt_alpa_2d_table2.sh`. + +You can cd the analyse_strategy folder for more specific analysis. + +### Experimental Config + +The benchmarks are implemented on a server runing on Ubuntu 20.04 system, which is equipped with an Intel(R) Xeon(R) Platinum 8160 CPU @ 2.10GHz and 16 NVIDIA V100-SXM2 32GB GPUs, each having a theoretical TFLOPS of 120 for FP16. The 16 GPUs are connected via NVLink and the interconnect bandwidth is 300GB/s (details seeing [NVIDIA TESLA V100 GPU ACCELERATOR](https://images.nvidia.com/content/technologies/volta/pdf/437317-Volta-V100-DS-NV-US-WEB.pdf). The version of CUDA is 11.3. + +**w/ pipeline parallelism (i.e. 3d)** + +We follow alpa's GPT-3 benchmark code (seeing Fig. 7a in [1]) on our testbed and results are in table 1. +In this case you can choose to overwrite the `benchmark.py` or not and run: + +```bash +bash gpt_alpa_3d.sh +``` + +**w/o pipeline parallelism (i.e. 2d)** + +We follow alpa's GPT-3 benchmark code under shard parallel (i.e. only intra-opeartor parallelism, no pipeline parallelism). +The results with 8 V100s are in table2.1 and those with 4 V100s in table2.2. +In this case you need to overwrite the `benchmark.py` and run: + +```bash +bash gpt_alpa_2d_table1.sh + +bash gpt_alpa_2d_table2.sh +``` + +**Description of parameters in alpa** + +- `shard-only` : Only profile the 2D case. No pipeline parallelism, default=`False` +- `num_micro_batches`: The number of micro batches, equal to batch size/micro batches. When `num_micro_batches>1`, the grad function will apply `alpa.grad`, which adds the gradient accumulation mechanism to `jax.grad`. The default is `1` +- `num_gpt_layer` : The number of the gpt layer, other config parameters can be seen in `benchmark.py`. +- `dp`: The number of channel for data parallelism, an `int` from [1,2,4,……,gpus]. +- `op`: The number of channel for operator parallelism, an `int` from [1,2,4,……,gpus]. +- `reduce-scatter`: If this is True, alpa will use **reduce-scatter** and **all-gather** to replace **all-reduce**. It will achieve a sppedup in execute for reduce-scatter-friendly system, but burden the optimization time. +- `parallel mode`. It can be selected from `uniform` and `zero-3`. +- `shard`. Not using the ray cluster, default=`False`. +- `profile driven time`. Profile the execution time on the driver instead of the workers, default=`False`. +- `recomputation`. This switch determines whether recomputation is turned on, default=`False`. If recomputation is open, the memory cost will increase during the backward and the communication overhead (**all-reduce** during the recomputation) saves. + +### Results + +**Table 1** + +| gpus | TFLOPs | Peak Mem/GB | Execute time/s (Mean, Std) | Complie + optimize time/s | +|:----:|:------: |:-----------:|:--------------------------:|:-------------------------:| +| 1 | 40.67 | 7.053 | (80.787, 0.004) | 50.19 | +| 2 | 119.04 | 10.376 | (57.36, 0.080) | 49.94 | +| 4 | 240.76 | 8.575 | (48.337, 0.023) | 57.52 | +| 8 | 511.36 | 11.66 | (45.646, 0.010) | 110.69 | +| 16 | 1110.72 | 11.346 | (51.868, 0.019) | 117.46 | + +**Table 2.1** + +Details for Table 2.1: + +```bash +--num-devices-per-host 8 --num_gpt_layer 4 --num_batch_size 32 --num_micro_batches 1 --reduce_scatter +``` + +| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| (1,8) | 462.56 | 6.478 | (0.565, 0.000) | 28 | 28 | 0 | 0 | 0 | 5.69 | +| (2,4) | 538.88 | 5.098 | (0.485, 0.001) | 33 | 29 | 1 | 3 | 0 | 7.98 | +| (4,2) | 571.20 | 5.449 | (0.457, 0.000) | 33 | 29 | 1 | 3 | 0 | 7.96 | +| (8,1) | 587.44 | 6.924 | (0.445, 0.003) | 4 | 1 | 1 | 2 | 0 | 4.00 | + +**Table 2.2** w/ recompute float16 + +Details for table 2.2 w/ recompute float16: + +```bash +--num-devices-per-host 4 --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 --recomputation +``` + +| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| (1,4) | 229.80 | 5.076 | (0.142, 0.000) | 28 | 28 | 0 | 0 | 0 | 5.98 | +| (2,2) | 179.44 | 10.287 | (0.182, 0.000) | 31 | 31 | 0 | 0 | 0 | 8.33 | +| (4,1) | 161.92 | 20.571 | (0.202, 0.001) | 3 | 3 | 0 | 0 | 0 | 4.08 | + +**Table 2.2** w/o recompute float16 + +Details for table 2.2 w/o recompute float16 + +```bash +--num-devices-per-host 4 --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 +``` + +| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| (1,4) | 220.48 | 6.288 | (0.117, 0.000) | 20 | 20 | 0 | 0 | 0 | 4.47 | +| (2,2) | 164.64 | 10.287 | (0.157, 0.000) | 23 | 23 | 0 | 0 | 0 | 6.45 | +| (4,1) | 143.00 | 20.571 | (0.180, 0.001) | 3 | 3 | 0 | 0 | 0 | 2.80 | + +**Table 2.2** w/ recompute float32 + +Details for table 2.2 w/ recompute float32: + +```bash +--num-devices-per-host 4 --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 --recomputation +``` + +| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| (1,4) | 48.12 | 5.485 | (0.679, 0.001) | 29 | 27 | 2 | 0 | 0 | 5.62 | +| (2,2) | 43.96 | 11.429 | (0.743, 0.000) | 30 | 30 | 0 | 0 | 0 | 8.59 | +| (4,1) | 43.20 | 22.857 | (0.756, 0.001) | 2 | 2 | 0 | 0 | 0 | 3.59 | + +**Table 2.2** w/o recompute float32 + +Details for table 2.2 w/o recompute float32: + +```bash +--num-devices-per-host 4 --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 +``` + +| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| (1,4) | 47.28 | 7.704 | (0.545, 0.000) | 21 | 19 | 2 | 0 | 0 | 4.44 | +| (2,2) | 42.08 | 11.429 | (0.613, 0.000) | 22 | 22 | 0 | 0 | 0 | 5.89 | +| (4,1) | 40.64 | 22.857 | (0.634, 0.001) | 2 | 2 | 0 | 0 | 0 | 2.81 | + +**Table 2.3** w/o recompute float16 + +Details for table 2.3 w/o recompute float16: + +```bash +--num-devices-per-host 4 --num_batch_size 4 --num_micro_batches 1 --dp 1 --op 4 +``` + +| layers | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| 8 | 228.12 | 10.791 | (0.203, 0.000) | 36 | 36 | 0 | 0 | 0 | 8.72 | +| 12 | 228.60 | 15.340 | (0.293, 0.000) | 52 | 52 | 0 | 0 | 0 | 13.14 | +| 16 | 231.32 | 19.843 | (0.379, 0.000) | 68 | 68 | 0 | 0 | 0 | 18.05 | + +**Table 2.3** w/ recompute float16 + +Details for table 2.3 w/o recompute float16: + +```bash +--num-devices-per-host 4 --num_batch_size 4 --num_micro_batches 1 --dp 1 --op 4 +``` + +| layers | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| 8 | 236.68 | 8.163 | (0.254, 0.000) | 52 | 52 | 0 | 0 | 0 | 11.50 | +| 12 | 237.52 | 11.290 | (0.369, 0.000) | 76 | 76 | 0 | 0 | 0 | 18.95 | +| 16 | 237.76 | 14.448 | (0.484, 0.001) | 100 | 100 | 0 | 0 | 0 | 23.82 | + +**Table 2.3** w/o recompute float32 + +Details for table 2.3 w/o recompute float32: + +```bash +--num-devices-per-host 4 --num_batch_size 4 --num_micro_batches 1 --dp 1 --op 4 +``` + +| layers | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| 8 | 49.32 | 12.707 | (0.940, 0.000) | 37 | 35 | 2 | 0 | 0 | 7.70 | +| 12 | 50.40 | 17.710 | (1.330, 0.001) | 53 | 51 | 2 | 0 | 0 | 13.46 | +| 16 | 50.68 | 22.744 | (1.729, 0.001) | 69 | 67 | 2 | 0 | 0 | 16.34 | + +**Table 2.3** w/ recompute float32 + +Details for table 2.3 w/ recompute float32: + +```bash +--num-devices-per-host 4 --num_batch_size 4 --num_micro_batches 1 --dp 1 --op 4 +``` + +| layers | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | +|:------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| +| 8 | 50.12 | 8.564 | (1.200, 0.001) | 53 | 51 | 2 | 0 | 0 | 12.01 | +| 12 | 50.72 | 11.644 | (1.727, 0.000) | 77 | 75 | 2 | 0 | 0 | 16.54 | +| 16 | 51.16 | 14.801 | (2.251, 0.002) | 101 | 99 | 2 | 0 | 0 | 22.59 | + +Remark 1: When `Prefer_reduce_scatter=False` and `recomputation=False`, the tensor parallelism strategy generated by *alpa* is consistent with that of *megatron-lm*. + +## Q&A + + Q1: Why the mean time for data parallelism (dp=8,op=1 in Table 2.1) is faster than tensor parallelism (dp=1,op=8 in Table 2.1)? + + A1: Because the communication volume of the former is 12 *hidden_size*hidden_size and that of the later is 4*batch size*hidden_size (2 all-reduce in the feedfoward and 2 in the backward). + And the mean time in Table 2.2 (both w/ and w/o recomputation) supports this view. When we reduce batch size from 32 to 4, then the data parallelism (dp=4,op=1 in Table 2.2) is slower than tensor parallelism (dp=1,op=4 in Table 2.2). + + Q2: Why the TFLOPs are reduced to 1/4 of the precision of 16 bits when the precision is 32 bits? + + A2: Because it uses the tensor core technique, which boosts the TFLOPS. + +## Reference + +\[1\] Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning diff --git a/autodist/benchmark/alpa/analyse_strategy/README.md b/autodist/benchmark/alpa/analyse_strategy/README.md new file mode 100644 index 00000000..290dc08c --- /dev/null +++ b/autodist/benchmark/alpa/analyse_strategy/README.md @@ -0,0 +1,15 @@ +# Analyse alpa strategy + +In this part, we write the gen_str.py to generate the partition strategy from the log in Alpa. The best spmd results with 760M and 1.3B are in **strategy.zip**. + +## Usage + +```bash +python gen_str.sh +``` + +The default load_file is log.txt and the save_file is test.txt, you can specific the load_file and save_file by adding **--load_file** and **--save_file**. If you want to see more information about the Alpa partition, you can add **--whole_strategy --detailed_partition_strs** + +## Comparsion with Autodist + + diff --git a/autodist/benchmark/alpa/analyse_strategy/gen_str.py b/autodist/benchmark/alpa/analyse_strategy/gen_str.py new file mode 100644 index 00000000..c7110175 --- /dev/null +++ b/autodist/benchmark/alpa/analyse_strategy/gen_str.py @@ -0,0 +1,137 @@ +import os +import time +import json +import argparse + +LAYER_DOT_NUM = 6 +MAX_LAYER = 64 + + +class strategy: + + def __init__(self, Instruction: str): + self.elements = Instruction.split(' ') + self.id = self.elements[1] + at_index = self.elements.index('@') + self.selected_str = self.elements[at_index - 5:at_index] + self.selected_str = ' '.join(self.selected_str) + + +def write_str(save_file, strs): + lines = [] + for i in strs: + lines += i + with open(save_file, 'w') as f: + f.writelines(lines) + return + + +def get_str(args, lines, indexs, selected_strs): + strs = [] + dot_name = { + 0: 'qvk_combined', + 1: '...qhd,...khd->...hqk', + 2: '...hqk,...khd->...qhd', + 3: 'attention/output', + 4: 'intermediate/dense', + 5: 'output/dense' + } + assert len(indexs) % LAYER_DOT_NUM == 0 + str_count = 0 + for dot_count, index in enumerate(indexs): + + this_strs = [] + assert 'Instruction' in lines[index] + this_id = lines[index].split(' ')[2].split('%')[-1] + for selected_str in selected_strs: + if this_id == strategy(selected_str).id: + this_s = selected_str + break + + if dot_count % LAYER_DOT_NUM == 0: + this_strs.append('transformer_layer:' + + str(dot_count // LAYER_DOT_NUM) + '\n') + this_strs.append(' ' + dot_name[dot_count % LAYER_DOT_NUM] + ':' + + '\n') + if args.whole_strategy: + this_strs.append(' ' + 'instruction: ' + this_s) + this_strs.append(' ' + 'partition: ' + + strategy(this_s).selected_str + '\n') + + i = index + while True: + i += 1 + if 'Instruction' in lines[i]: + break + if args.detailed_partition_strs: + this_strs.append(' ' + lines[i]) + str_count = i - index - 1 + this_strs.append(' ' + 'total strategy numbers: ' + str(str_count) + + '\n') + strs.append(this_strs) + + return strs + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--start_layer', + type=int, + default=0, + help='set the start layer') + parser.add_argument( + '--end_layer', + type=int, + default=-1, + help='set the end layer, and generate [start_layer,……,end_layer-1]') + parser.add_argument('--load_file', + type=str, + default='log.txt', + help='set the loader folder for experiment data') + parser.add_argument('--save_file', + type=str, + default='test.txt', + help='set the save folder for experiment data') + parser.add_argument('--whole_strategy', + action='store_true', + help='show the whole strategy instruction') + parser.add_argument('--detailed_partition_strs', + action='store_true', + help='show the partition strategy that can be chosen') + args = parser.parse_args() + + total_layers = list(range(0, MAX_LAYER + 1)) + layers = total_layers[args.start_layer:args.end_layer] + f = open(args.load_file, 'r') + lines = f.readlines() + indexs = [] + for i in range(len(lines)): + if 'Startegy Map' in lines[i]: + start_i = i + if 'Auto sharding strategy' in lines[i]: + end_i = i + break + end_i = len(lines) + assert end_i != len(lines) + + for i in range(start_i, end_i): + if 'dot(' in lines[i]: + for layer in layers: + if 'layer/' + str(layer) + '/' in lines[i]: + indexs.append(i) + break + selected_strs = [] + start_i = end_i + for i in range(len(lines)): + if 'Exit AutoSharding' in lines[i]: + end_i = i + break + end_i = len(lines) + for i in range(start_i, end_i): + if 'dot(' in lines[i]: + selected_strs.append(lines[i]) + selected_strs = selected_strs[1:] + f.close() + strs = get_str(args, lines, indexs, selected_strs) + write_str(args.save_file, strs) diff --git a/autodist/benchmark/alpa/analyse_strategy/strategy.zip b/autodist/benchmark/alpa/analyse_strategy/strategy.zip new file mode 100644 index 0000000000000000000000000000000000000000..e11c2630e9316d79a72e4e00f301da6740b55939 GIT binary patch literal 886 zcmWIWW@Zs#U|`^2h;+0K`#RxNodP2RL!1UkM25jAy`WUDq@pA=gp+}pQ|nstY9KDH z;AUWC+2qT{!@_vvHTT;a(S9uk608r(!$0afNFMy?wy1jnw}Qgl3kq)o>U!SYd$ec9 zgwB&Tm6JX@9bUS8%l+FIZ~XWv@F*ks@4cFv5%=OcedL}0iXFRFd3);1xF#GfmGe6?qvrEj9Rl?DE7HyJufF`F453+0559zFEPOGPisF z4i1|2%UY8)e#(^ztd@K4)EaEL|74`}LanZ&}leqyC=x zvj3l>d6ebWhAmuKjG1=>xjnM|brsfD139(<+#atdMma?JUtzjbwM=Y6*f}HSuI^km z5%JeN8hYO(17dVc7VNl@;ZV5oQp2O9<%~y<1_2eUWfKZN7zSpDh)45k=xvkwKmE0I z06c_Ku!OJ)atO-+LwD(D2$LSdj7)OOxKgnMFm*F9024LCl12~_CJ+w*uT)`u literal 0 HcmV?d00001 diff --git a/autodist/benchmark/alpa/benchmark.py b/autodist/benchmark/alpa/benchmark.py new file mode 100644 index 00000000..338a4fff --- /dev/null +++ b/autodist/benchmark/alpa/benchmark.py @@ -0,0 +1,204 @@ +"""The entry point of intra-op + inter-op parallelism benchmark.""" +import os +import argparse +from datetime import datetime +import time + +import numpy as np + +from alpa.util import (write_tsv, get_num_hosts_and_num_devices, to_str_round, + GB) +from collections import namedtuple +from benchmark_one_case import benchmark_one_case +import suite_auto_gpt +import suite_auto_moe +import suite_manual_gpt +import suite_manual_moe +import suite_wresnet +import suite_inference_gpt +from benchmark_parallel_utils import (BenchmarkCase, ShardParallelArgs, + UniformParallelArgs) +#from suite_manual_gpt import GPTModelConfig + +#B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size +GPTModelConfig = namedtuple( + 'GPTModelConfig', + ['seq_len', 'hidden_size', 'num_layers', 'num_heads', 'vocab_size']) + +benchmark_suites = { + 'gpt.tmp': suite_manual_gpt.tmp_suite, + 'gpt.tmp_auto': suite_auto_gpt.tmp_suite, + 'gpt.perf_test_fast_2d': suite_manual_gpt.perf_test_fast_2d_suite, + 'gpt.perf_test_manual': suite_manual_gpt.perf_test_suite, + 'gpt.perf_test_auto': suite_auto_gpt.perf_test_suite, + 'gpt.grid_search_auto': suite_auto_gpt.grid_search_suite, + 'gpt.correctness_test_auto': suite_auto_gpt.correctness_test_suite, + 'gpt_inference.profile': suite_inference_gpt.profile_suite, + 'gpt_no_embedding_inference.profile': suite_inference_gpt.profile_suite, + 'moe.tmp': suite_manual_moe.tmp_suite, + 'moe.tmp_auto': suite_auto_moe.tmp_suite, + 'moe.perf_test_fast_2d': suite_manual_moe.perf_test_fast_2d_suite, + 'moe.perf_test_auto': suite_auto_moe.perf_test_suite, + 'moe.grid_search_auto': suite_auto_moe.grid_search_suite, + 'wresnet.perf_test_2d': suite_wresnet.perf_test_2d_suite, + 'wresnet.perf_test_auto': suite_wresnet.perf_test_auto_suite, + 'wresnet.grid_search_auto': suite_wresnet.grid_search_auto_suite, +} + + +def benchmark_suite(suite_name, + num_hosts, + num_devices_per_host, + input_gpt_layer, + input_batch_size, + input_micro_batches, + reduce_scatter, + dp, + op, + recomputation, + exp_name='default', + niter=3, + shard_only=False, + local=False, + profile_driver_time=False, + profile_stage_execution_time=False, + disable_tqdm=False, + use_separate_process=True): + num_gpus = num_hosts * num_devices_per_host + + if local: + assert shard_only, ('Only shard-only mode is supported for execution ' + 'on local GPUs.') + + if num_gpus not in benchmark_suites[suite_name]: + return + suite = benchmark_suites[suite_name][num_gpus] + #print("suit is {},suit[0]is {}".format(suite,benchmark_case)) + os.makedirs('tmp', exist_ok=True) + + model_type = suite_name.split('.')[0] + output_name = f'{exp_name}.tsv' + + # Run all cases + for benchmark_case in suite: + + if shard_only: + assert dp * op == num_gpus, ('dp*op != num_gpus.') + # B, model, NB, PM, (RS, Remat, 3D Config, FM) + benchmark_case_new = BenchmarkCase( + input_batch_size, + GPTModelConfig(1024, 4096, input_gpt_layer, 32, 51200), + input_micro_batches, 'uniform', + UniformParallelArgs(reduce_scatter, recomputation, dp, op, 1, + True)) + + else: + benchmark_case_new = benchmark_case + + model_config = benchmark_case_new.model_config + num_micro_batches = benchmark_case_new.num_micro_batches + parallel_args = benchmark_case_new.parallel_args + + # Run one case + print('Working on case: {}'.format(str(benchmark_case_new))) + + result = benchmark_one_case(model_type, + benchmark_case_new, + niter, + num_hosts, + num_devices_per_host, + shard_only=shard_only, + local=local, + profile_driver_time=profile_driver_time, + disable_tqdm=disable_tqdm, + use_separate_process=use_separate_process) + + (parameter_count, peak_mem, latencies, tflops, metadata) = result + + heads = [ + 'Type', 'Model Config', '#Microbatch', '#GPU', 'Parallel Config', + 'Mean Time (s)', 'Std Time (s)', '#Params (Billion)', 'TFLOPs', + 'Peak Mem (GB)', 'Metadata' + ] + values = [ + model_type, model_config, num_micro_batches, num_gpus, + parallel_args, f'{np.mean(latencies):.3f}', + f'{np.std(latencies):.3f}', f'{parameter_count/1e9:.3f}B', + f'{tflops:.2f}', f'{peak_mem/GB:.3f}', + to_str_round(metadata, 2) + ] + write_tsv(heads, values, output_name) + + time.sleep(0.1) # for ctrl+c to work + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--suite', + choices=list(benchmark_suites.keys()), + type=str, + required=True) + parser.add_argument('--niter', + type=int, + default=3, + help='The number of benchmark iterations') + parser.add_argument('--num-hosts', type=int, default=None) + parser.add_argument('--num-devices-per-host', type=int, default=None) + parser.add_argument('--shard-only', + action='store_true', + help='Only profile the 2D case. No pipeline ' + 'parallelism.') + parser.add_argument('--local', + action='store_true', + help='Run on local GPUs. Do not use ray actors.') + parser.add_argument('--profile-driver-time', + action='store_true', + help='Profile the execution time on the driver instead ' + 'of the workers.') + parser.add_argument( + '--profile-stage-execution-time', + action='store_true', + help='Profile the execution timestamps of each pipeline ' + 'stage') + parser.add_argument('--no-separate-process', + action='store_false', + help='Do not launch separate processes for benchmark. ' + 'Errors in a single case will terminate this ' + 'script.', + dest='use_separate_process') + parser.add_argument('--exp-name', type=str, default='default') + parser.add_argument('--disable-tqdm', action='store_true') + parser.add_argument('--num_gpt_layer', type=int, default=1) + parser.add_argument('--num_batch_size', type=int, default=4) + parser.add_argument('--num_micro_batches', type=int, default=1) + parser.add_argument('--reduce_scatter', + action='store_true', + help='Prefer_reduce_scatter = True.') + parser.add_argument('--dp', type=int, default=4) + parser.add_argument('--op', type=int, default=1) + parser.add_argument('--recomputation', + action='store_true', + help='remat = True.') + args = parser.parse_args() + + num_hosts, num_devices_per_host = get_num_hosts_and_num_devices(args) + + benchmark_suite( + args.suite, + num_hosts, + num_devices_per_host, + args.num_gpt_layer, + args.num_batch_size, + args.num_micro_batches, + args.reduce_scatter, + args.dp, + args.op, + args.recomputation, + args.exp_name, + args.niter, + args.shard_only, + args.local, + args.profile_driver_time, + args.disable_tqdm, + args.use_separate_process, + ) diff --git a/autodist/benchmark/alpa/gpt_alpa_2d_table1.sh b/autodist/benchmark/alpa/gpt_alpa_2d_table1.sh new file mode 100644 index 00000000..597bbc0f --- /dev/null +++ b/autodist/benchmark/alpa/gpt_alpa_2d_table1.sh @@ -0,0 +1,16 @@ +#!/bin/bash --login +start_time=$(date +%s) +dp=(1 2 4 8) +op=(8 4 2 1) + + +for ((k=0; k<${#dp[*]}; k=k+1)); do + python benchmark.py --suite gpt.perf_test_fast_2d \ + --shard-only --num-hosts 1 --num-devices-per-host 8 \ + --num_gpt_layer 4 --num_batch_size 32 --num_micro_batches 1 \ + --dp ${dp[k]} --op ${op[k]} --reduce_scatter +done + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "running spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/benchmark/alpa/gpt_alpa_2d_table2.sh b/autodist/benchmark/alpa/gpt_alpa_2d_table2.sh new file mode 100644 index 00000000..4b89a854 --- /dev/null +++ b/autodist/benchmark/alpa/gpt_alpa_2d_table2.sh @@ -0,0 +1,27 @@ +#!/bin/bash --login +start_time=$(date +%s) +dp=(1 2 4) +op=(4 2 1) + + +for ((k=0; k<${#dp[*]}; k=k+1)); do + python benchmark.py --suite gpt.perf_test_fast_2d \ + --shard-only --num-hosts 1 --num-devices-per-host 4 \ + --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 \ + --dp ${dp[k]} --op ${op[k]} --recomputation + +done + +for ((k=0; k<${#dp[*]}; k=k+1)); do + python benchmark.py --suite gpt.perf_test_fast_2d \ + --shard-only --num-hosts 1 --num-devices-per-host 4 \ + --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 \ + --dp ${dp[k]} --op ${op[k]} + +done + + + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "running spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/benchmark/alpa/gpt_alpa_3d.sh b/autodist/benchmark/alpa/gpt_alpa_3d.sh new file mode 100644 index 00000000..1633ddeb --- /dev/null +++ b/autodist/benchmark/alpa/gpt_alpa_3d.sh @@ -0,0 +1,16 @@ +#!/bin/bash --login +start_time=$(date +%s) +gpus=(1 2 4 8 8 16) +device=(1 1 1 1 2 1) + + + +for ((k=0; k<${#gpus[*]}; k=k+1)); do + python benchmark.py --suite gpt.perf_test_auto \ + --num-hosts ${device[k]} --num-devices-per-host ${gpus[k]} + +done + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "running spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/benchmark/alphafold2.md b/autodist/benchmark/alphafold2.md new file mode 100644 index 00000000..9360c928 --- /dev/null +++ b/autodist/benchmark/alphafold2.md @@ -0,0 +1,124 @@ +# Alphafold2 + +## Model Config + +We focus on evoformer like structures during training currently. Data type is *float16*. + +**Evoformer Stack** + +| Case | s | r | cm | cz | +| ---------------- | ---- | ---- | ---- | ---- | +| initial training | 128 | 256 | 256 | 128 | +| 1st fine-tuning | 512 | 256 | 256 | 128 | +| 2nd fine-tuning | 512 | 384 | 256 | 128 | + +**Extra Msa Stack** + +| Case | s | r | cm | cz | +| ---------------- | ---- | ---- | ---- | ---- | +| initial training | 1024 | 256 | 64 | 128 | +| 2.1 fine-tuning | 1024 | 384 | 64 | 128 | +| 2.2 fine-tuning | 5120 | 384 | 64 | 128 | + +## Baselines + +**Deepmind's plan** + +data parallelism (each accelerator with exact 1 sample) + recompute. Since the parameter size is relatively small in Alphafold2, the latency can be approximately by single device execution. Hyperparameter setting is listed in the following table. + +| Case | evo_num | use_chunk | +| --------------- | ------- | --------- | +| Evoformer Stack | 48 | False | +| Extra Msa Stack | 4 | True | + +**Dynamic Axial Parallelism (DAP)** + +The end-to-end time is bounded by the computation. In other words, given input tensors with a fixed batch size, it is possible to reduce the time by introducing more devices (partition a operator into parallelizable sub-operators). Here are possible experiment dimensions. + +| batch size | #gpus | +| ---------- | ----- | +| 1 | 2 | +| 2 | 4 | +| 4 | 8 | +| 8 | 16 | + +**Table 1: Evoformer Stack & Training** + +| Case | batch size | #gpus | latency/ms | peak mem/MB | +| ---------------- | ---------- | ----- | ---------- | ----------- | +| initial training | 1 | 1 | 3521.98 | 4414 | +| initial training | 1 | 2 | 2430.38 | 2531 | +| initial training | 1 | 4 | 1497.77 | 1574 | +| initial training | 2 | 4 | 2485.53 | 2647 | +| 1st fine-tuning | 1 | 1 | 7696.62 | 10729 | +| 1st fine-tuning | 1 | 2 | 4663.32 | 5744 | +| 1st fine-tuning | 1 | 4 | 2620.09 | 3211 | +| 1st fine-tuning | 2 | 4 | 4717.36 | 5921 | +| 2nd fine-tuning | 1 | 1 | 16632.06 | 17810 | +| 2nd fine-tuning | 1 | 2 | 9377.98 | 9417 | +| 2nd fine-tuning | 1 | 4 | 5099.72 | 5157 | +| 2nd fine-tuning | 2 | 4 | 9422.99 | 9804 | + +**Table 2: Extra Msa Stack & Training** + +| Case | batch size | #gpus | latency/ms | peak mem/MB | +| ---------------- | ---------- | ----- | ---------- | ----------- | +| initial training | 1 | 1 | x | x | +| initial training | 1 | 2 | x | x | +| initial training | 1 | 4 | x | x | +| initial training | 2 | 4 | x | x | +| 2.1 fine-tuning | 1 | 1 | x | x | +| 2.1 fine-tuning | 1 | 2 | x | x | +| 2.1 fine-tuning | 1 | 4 | x | x | +| 2.1 fine-tuning | 2 | 4 | x | x | +| 2.2 fine-tuning | 1 | 1 | x | x | +| 2.2 fine-tuning | 1 | 2 | x | x | +| 2.2 fine-tuning | 1 | 4 | x | x | +| 2.2 fine-tuning | 2 | 4 | x | x | + +## End-to-end evaluation results (DAP vs Autodist) + +### Model Config + +Evoformer Stack + - shape config + - bs, s, r, cm, cz = 1, 128, 256, 256, 128 + - bs, s, r, cm, cz = 1, 512, 256, 256, 128 + - bs, s, r, cm, cz = 1, 512, 384, 256, 128 + - other config: dtype, use_chunk, is_train, is_extra = torch.float16, False, True, False + +*note*: results organized in (estimate time/ms, execution time/ms, device mem/GB) + +**Table 1: tensor parallelism(2gpu) w/o recompute** + +evo_num = 4 + +| s, r | DAP | Autodist | compile time/s | +| ------------- | --------------- | ------------------------ | ---------------| +| 128, 256 | (139.15, 4.58) | (127.13, 156.15, 5.35) | 0.77 | +| 512, 256 | (293.11, 11.02) | (286.04, 307.54, 12.86) | 0.77 | +| 512, 384 | (596.41, 20.91) | (568.72, 595.00, 24.44) | 0.77 | + +*note*: results organized in (estimate time/ms, execution time/ms, device mem/GB) + +**Table 2: tensor parallelism(2gpu) w/ adaptive recompute** + +evo_num = 48 +memory constraint = 40GB + +| s, r | DAP | Autodist | compile time/s | +| ------------- | --------------- | ------------------------- | ---------------| +| 128, 256 | (2250.27, 2.53) | (1690.71, 1915.13, 38.33) | 43.57 | +| 512, 256 | (4733.89, 5.74) | (4273.40, 4525.81, 39.06) | 45.39 | +| 512, 384 | (9673.10, 9.42) | (8911.85, 10042.22, 39.70)| 43.88 | + +**Table 3: tensor parallelism(4gpu) w/ adaptive recompute** + +evo_num = 48 +memory constraint = 40GB + +| s, r | DAP | Autodist | compile time/s | +| ------------- | --------------- | ------------------------- | ---------------| +| 128, 256 | (1874.73, 1.54) | (1083.93, 1400.29, 29.13) | 4650.48 | +| 512, 256 | (3350.06, 3.13) | (2388.69, 2965.40, 36.50) | 4483.49 | +| 512, 384 | (6724.48, 5.04) | (4932.62, 6450.42, 41.80) | 4427.15 | diff --git a/autodist/benchmark/gpt.md b/autodist/benchmark/gpt.md new file mode 100644 index 00000000..bf976a75 --- /dev/null +++ b/autodist/benchmark/gpt.md @@ -0,0 +1,234 @@ + + +# GPT-3 + +## Model Config + +**batch size**, **sequence length** and **vocabulary size** are fixed to 1024, 1024 and 51200 respectively. The data type is *float16*. + +| #params | Hidden size | #layers | #heads | #gpus | +| ------- | ----------- | ------- | ------ | ----- | +| 350M | 1024 | 24 | 16 | 1 | +| 760M | 1536 | 24 | 16 | 2 | +| 1.3B | 2048 | 24 | 32 | 4 | +| 2.6B | 2560 | 32 | 32 | 8 | +| 6.7B | 4096 | 32 | 32 | 16 | +| 15B | 5120 | 48 | 40 | 32 | +| 39B | 8192 | 48 | 64 | 64 | + +## End-to-end evaluation results + +*note*: results organized in (execution time/s, device mem/GB, compile time/s) + +**Table 1: include pipeline** + +| Config | alpa | ours | +| ------ | ---------------------- | ---------------------- | +| 760M | (59.21, 14.49, 232.51) | (46.13, 30.27, 7.56) | +| 1.3B | (47.23, 24.83, 355.14) | (42.45, 20.68, 33.96) | +| 2.6B | (45.00, 13.20, 731.86) | (39.72, 24.74, 295.58) | +| 6.7B | (46.19, 16.75, 832.01) | (45.17, 15.57, 1906.22)| + +**Table 2: tensor parallelism only** + +| Config | alpa | ours | +| ------ | --------------------- | ----------------------- | +| 350M | (54.88, 26.50, 8.77) | (56.03, 26.52, 0.37) | +| 760M | (51.06, 21.38, 27.04) | (52.14, 19.41, 2.00) | +| 1.3B | (47.83, 20.22, 25.86) | (47.84, 30.37, 14.91) | +| 2.6B | (55.10, 21.47, 65.88) | (69.92, 10.85, 36.91) | +| 6.7B | (84.16, 25.09, 65.21) | (61.79, 21.93, 226.11) | + +## AutoDist Details + +Memory constraint set to 30 GB in most test cases. + +For cases with \*, we set the memory constraint to 25GB. + +For the case with runtime=-1, we cannot run the case due to two errors from Cube: multiref and rvd failed. + +### Include pipeline + +**760M** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 65.37 | 56.58 | 6.19 | 6.71 | 5.19 | +| 2 | 53.18 | 51.11 | 10.26 | 10.88 | 5.66 | +| 4 | 49.13 | 46.74 | 17.47 | 18.14 | 6.05 | +| 8 | 48.52 | 46.13 | 29.56 | 30.27 | 7.56 | +| 16*| 53.86 | 53.47 | 24.97 | 29.98 | 9.3 | +| 32*| 57.49 | 56.48 | 24.81 | 27.5 | 9.59 | + +**1.3B** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 52.36 | 48.27 | 8.03 | 8.99 | 28.99 | +| 2 | 45.59 | 43.48 | 13.81 | 14.77 | 27.7 | +| 4 | 43.87 | 42.45 | 20.03 | 20.68 | 33.96 | +| 8 | 44.3 | -1 | 24.98 | -1 | 38.67 | +| 16*| 47.71 | 48.13 | 24.86 | 30.35 | 48.31 | +| 32*| 48.96 | 48.03 | 24.98 | 25.96 | 42.15 | + +**2.6B** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 46.32 | 46.45 | 7.34 | 8.94 | 228.92 | +| 2 | 42.6 | 40.96 | 12.58 | 14.26 | 248.44 | +| 4 | 41.48 | 39.72 | 23.06 | 24.74 | 295.58 | +| 8 | 41.01 | -1 | 28.74 | -1 | 436.54 | +| 16 | 42.68 | -1 | 29.89 | -1 | 476.78 | +| 32 | 44.98 | -1 | 29.98 | -1 | 394.06 | + +**6.7B** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 50.58 | 50.26 | 7.94 | 9.92 | 1496.21 | +| 2 | 46.37 | 45.17 | 13.54 | 15.57 | 1906.22 | +| 4 | 43.81 | -1 | 17.16 | -1 | 1937.48 | +| 8 | 43.42 | -1 | 29.65 | -1 | 2082.68 | +| 16 | 43.66 | -1 | 29.98 | -1 | 1834.14 | +| 32 | 47.15 | -1 | 29.96 | -1 | 1555.16 | + +Remark: + +### Tensor Parallelism Only + +**350M** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 82.23 | 71.08 | 5.92 | 5.64 | 0.38 | +| 2 | 69.33 | 62.75 | 8.91 | 8.63 | 0.39 | +| 4 | 62.98 | 58.48 | 14.88 | 14.26 | 0.37 | +| 8 | 59.6 | 56.03 | 26.84 | 26.52 | 0.37 | +| 16 | 64.73 | 62.61 | 29.8 | 29.21 | 0.38 | +| 32 | 71.85 | 68.8 | 29.43 | 26.36 | 0.3 | + +**760M** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 90.14 | 81.4 | 4.96 | 5.02 | 0.79 | +| 2 | 65.31 | 66.03 | 6.84 | 7.12 | 0.94 | +| 4 | 58.3 | 57.21 | 10.65 | 11.22 | 1.07 | +| 8 | 53.8 | 52.14 | 18.25 | 19.41 | 2 | +| 16*| 53.86 | 53.46 | 24.97 | 29.78 | 3.31 | +| 32*| 57.49 | 56.42 | 24.81 | 27.54 | 2.7 | + +**1.3B** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 79.5 | 83.84 | 4.42 | 4.41 | 2.65 | +| 2 | 61.4 | 69.35 | 6.25 | 6.3 | 3.96 | +| 4 | 55.88 | 57.85 | 9.91 | 10.29 | 6 | +| 8 | 50.25 | 50.38 | 17.2 | 19.65 | 9.78 | +| 16*| 47.71 | 48.08 | 24.86 | 30.37 | 14.91 | +| 32*| 48.96 | 47.84 | 24.98 | 25.76 | 9.29 | + +**2.6B** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 87.92 | 102.79 | 4.3 | 4.58 | 17.57 | +| 2 | 70.56 | 80.51 | 6.03 | 6.98 | 24.84 | +| 4 | 63.19 | 69.62 | 9.5 | 10.85 | 36.91 | +| 8 | 53.25 | -1 | 15.94 | -1 | 76.63 | +| 16 | 46.44 | -1 | 24.99 | -1 | 115.04 | +| 32 | 46.38 | -1 | 24.98 | -1 | 69.36 | + +**6.7B** + +| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | +|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| +| 1 | 107.5 | 127.41 | 5.01 | 5.24 | 61.34 | +| 2 | 88.84 | 101.43 | 6.84 | 7.78 | 79.19 | +| 4 | 77.33 | 89.68 | 10.45 | 13.54 | 116.27 | +| 8 | 61.7 | 61.79 | 19.2 | 21.93 | 226.11 | +| 16 | 50.44 | -1 | 24.16 | -1 | 189.93 | +| 32 | 49.57 | -1 | 25 | -1 | 138.96 | + + +## Alpa Details + +Here we show the key hyperparameters of alpa under different config in both **table 1** and **table 2**. + +**Details for table 1** + +| Config | micro batch size | Recompute | num_auto_layers | forward_stage_layer_ids | submesh_shapes | logical_mesh_shapes | autosharding_option_dicts | +|:-------|:-----------------|:----------|:----------------|:-------------------------------|:-------------------------------|:-------------------------------|:-----------------------------------------------------| +| 350M* | 32 | True | 1 | [0] | [1, 1] | [1, 1] | {} | +| 760M | 16 | True | 6 | [0, 1, 2], [3, 4, 5] | [1, 1], [1, 1] | [1, 1], [1, 1] | force_dp_dict,{} | +| 1.3B | 64 | True | 6 | [0, 1, 2, 3, 4, 5] | [1, 4] | [4, 1] | force_dp_dict,{} | +| 2.6B | 16 | True | 8 | [0, 1, 2, 3], [4, 5, 6, 7] | [1, 4], [1, 4] | [4, 1], [4, 1] | force_dp_dict,force_dp_dict | +| 6.7B | 16 | True | 8 | [0, 1], [2, 3], [4, 5], [6, 7] | [1, 4], [1, 4], [1, 4], [1, 4] | [4, 1], [4, 1], [4, 1], [4, 1] | force_dp_dict,{},force_dp_dict,force_dp_dict | +| 15B | 8 | True | 16 | [0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15] | [1, 4]\*8 | [4, 1]\*8 | force_dp_dict \* 8 | + +**15B** + +| batch size | time/s | memory/GB | recompute | +| ---------- | -------| --------- | --------- | +| 2 | 75.77 | 21.94 | False | +| 4 | 95.56 | 17.69 | True | +| 8 | 59.98 | 18.10 | True | +| 16 | 60.84 | 18.67 | True | +| 32 | 70.05 | 23.32 | True | +| 64 | 76.03 | 21.31 | True | + +**Details for table 2** + +**350M** + +| batch size | time/s | memory/GB | recompute | +| ---------- | -------| --------- | --------- | +| 4 | 56.46 | 16.30 | False | +| 8 | 54.88 | 26.50 | False | +| 16 | 71.06 | 16.10 | True | +| 32 | 70.22 | 26.09 | True | +| 64 | x | OOM | True | + +**760M** + +| batch size | time/s | memory/GB | recompute | +| ---------- | -------| --------- | --------- | +| 4 | 51.82 | 13.82 | False | +| 8 | 51.06 | 21.38 | False | +| 16 | 62.99 | 12.67 | True | +| 32 | 62.88 | 19.06 | True | +| 64 | x | OOM | True | + +**1.3B** + +| batch size | time/s | memory/GB | recompute | +| ---------- | -------| --------- | --------- | +| 4 | 49.35 | 12.77 | False | +| 8 | 47.83 | 20.22 | False | +| 16 | 60.01 | 10.67 | True | +| 32 | 59.31 | 16.01 | True | +| 64 | 58.81 | 26.67 | True | +| 128 | x | OOM | True | + +**2.6B** + +| batch size | time/s | memory/GB | recompute | +| ---------- | -------| --------- | --------- | +| 4 | 55.16 | 13.36 | False | +| 8 | 55.10 | 21.47 | False | +| 16 | 67.90 | 12.03 | True | +| 32 | 67.56 | 18.85 | True | +| 64 | x | OOM | True | + + +**6.7B** + +| batch size | time/s | memory/GB | +| ---------- | -------| --------- | +| 2 | 109.60 | 7.87 | +| 8 | 89.98 | 11.11 | +| 16 | 86.14 | 15.77 | +| 32 | 84.16 | 25.09 | +| 64 | x | OOM | diff --git a/autodist/benchmark/recompute.md b/autodist/benchmark/recompute.md new file mode 100644 index 00000000..39f77e08 --- /dev/null +++ b/autodist/benchmark/recompute.md @@ -0,0 +1,28 @@ +# Continuous Recompute + +We implement the continuous recompute search algorithm, which outperforms the alpa and the manual strategy of Megatron[1] on GPT. On the following cases, [Alpa](https://github.com/alpa-projects/alpa) is OOM and autodist has a 5.4% gain over megatron. + +## Experimental Config + +The model config is GPT3 760M, on 2 GPUs, num_layer increased from 24 into 48, global_batch_size = 1024 and micro_batch_size = 8. + +## Results + +| Search algorithm | runtime (time/s, memory/GB) | compile time/s | Remark | +|:-------------------------------|:----------------------------|:----------------|:---------| +| Megatron(selective recompute) | (137.10, 17.22) | 8.40 | OOM | +| Megatron(full recompute) | (171.78, 7.09) | 9.23 | | +| Megatron(search) | (154.22, 15.85) | 9.60 | | +| Alpa | (149.53, 16.15) | 129.96 | OOM | +| Autodist(continuous recompute) | **(145.90, 15.67)** | 2069.80 | | +| Autodist(single recompute) | (146.65, 15.99) | 130.31 | Multiref | + +## Details + +**Megatron:** Megatron using selective recompute(as well as using full recompute) represents selective recompute(as well as full recompute) for all layers. Megatron(search) is the optimal solution searched manually according to [args.recompute-method](https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/arguments.py#L375:~:text=group.add_argument(%27%2D%2Drecompute%2Dmethod%27%2C%20type%3Dstr%2C%20default%3DNone%2C)), [args.recompute-granularity](https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/arguments.py#L375:~:text=group.add_argument(%27%2D%2Drecompute%2Dgranularity%27%2C%20type%3Dstr%2C%20default%3DNone%2C)) and [args.recompute-num-layers](https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/arguments.py#L375:~:text=group.add_argument(%27%2D%2Drecompute%2Dnum%2Dlayers%27%2C%20type%3Dint%2C%20default%3D1%2C)) in Megatron. Depending on the permutation of these switches, without pipeline, we can specify that some layers recompute and others don't, and that the recompute strategy can be selective recompute or full recompute. But we can not obtain the strategy that different layers take different recompute strategies. The optimal solution of our artificial search is that Megatron(search) fully recomputes the first 31 layers. + +**Alpa:** The Alpa solution is OOM with the config(dp=1, op=2, use_remat=True). + +**Autodist:** The Autodist(continuous recompute) searches for a solution, where the first 23 layers fully recompute and the left 25 layers selective recompute. The compile time of Autodist(continuous recompute) is up to 15x of that of Autodist(single recompute), whose search solution has multiref BUG. + +**Remark**: We use whether the memory exceeds 16G to determine if it is OOM. Because heavy fragmentation arises in the above case, memory utilization is less than 70% and difficults the memory estimation. diff --git a/autodist/build_env.py b/autodist/build_env.py new file mode 100644 index 00000000..c13978fe --- /dev/null +++ b/autodist/build_env.py @@ -0,0 +1,58 @@ +import torch +import os +import shutil +import sys +from datetime import datetime +import subprocess +from pathlib import Path +import torch +from nnscaler.autodist.util import get_node_arch + +if bool(int(os.environ.get('PROFILE_COMM', default=0))): + profile_comm = True +else: + profile_comm = False + + +def main(): + base_path = str(Path.home()) + '/.autodist' + default_path = base_path + '/' + get_node_arch() + + code_path = Path(__file__).parents[1] + + if not os.path.exists(default_path): + os.makedirs(default_path) + print('> create folder: ', default_path) + os.makedirs(default_path + '/plan') + else: + print('> folder already exists: ', default_path) + + # profile communication cost + if profile_comm: + print('> CUDA device num: ', torch.cuda.device_count()) + for device_num in [2, 4, 8, 16]: + if device_num > torch.cuda.device_count(): + break + command = f'torchrun --master_port 21212 --nproc_per_node={device_num} ./comm_profile.py --comm_profile_dir={default_path}/comm' + output = subprocess.check_output(command, shell=True, text=True) + else: + print('> skip communication profiling, using mi200 profile data') + if os.path.exists(default_path + '/comm'): + print('> backup existing comm profile data') + shutil.move( + default_path + '/comm', + default_path + f'/comm_back_{str(datetime.now().timestamp())}') + shutil.copytree(code_path / 'autodist/profile_data/16xmi200/comm', default_path + '/comm') + + # compile solver + solver_csrc = code_path / 'nnscaler/autodist/csrc/solver.cpp' + compile_command = f'g++ -std=c++11 {solver_csrc} -O3 -pthread -o solver' + compile_out = subprocess.check_output(compile_command, + shell=True, + text=True) + subprocess.check_output(f'mv solver {base_path}/', shell=True, text=True) + print('> build env successfully') + + +if __name__ == '__main__': + main() diff --git a/autodist/comm_profile.py b/autodist/comm_profile.py new file mode 100644 index 00000000..852d61c3 --- /dev/null +++ b/autodist/comm_profile.py @@ -0,0 +1,107 @@ +import argparse +import json +import torch +from pathlib import Path +import os +from typing import Tuple, List, Dict + +import nnscaler +from nnscaler.runtime.adapter.collectives import all_gather, all_reduce, all_to_all, reduce_scatter +from nnscaler.profiler import CudaTimer +from nnscaler.runtime.device import DeviceGroup +from nnscaler.autodist.util import get_node_arch + + +class CommProfiler: + + def __init__(self, + nranks: int, + warmup_times: int = 10, + profile_times: int = 10) -> None: + self.nranks = nranks + self.warmup_times = warmup_times + self.profile_times = profile_times + self.ranks = tuple(range(self.nranks)) + + def collect_profile_info(self, + primitive: str) -> Tuple[List[float], List[float]]: + + b_size = 16 + sequence_len = 16 + quarter_mb_size_list = [ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 + ] + model_dim_list = [ + mem * 256 * 256 // b_size // sequence_len + for mem in quarter_mb_size_list + ] + sizes_in_mb = [0.25 * val for val in quarter_mb_size_list] + times_in_s = [] + for cur_sz, d_size in zip(sizes_in_mb, model_dim_list): + assert d_size % self.nranks == 0 + if primitive in ['all gather', 'all to all']: + d_size = d_size // self.nranks + tensor = torch.rand([b_size, sequence_len, d_size], + dtype=torch.float32, + device=torch.cuda.current_device()) + if primitive == 'all gather': + func = all_gather + kwargs = {'tensor': tensor, 'dim': 2, 'ranks': self.ranks} + elif primitive == 'all reduce': + func = all_reduce + kwargs = {'tensor': tensor, 'ranks': self.ranks} + elif primitive == 'reduce scatter': + func = reduce_scatter + kwargs = {'tensor': tensor, 'dim': 2, 'ranks': self.ranks} + elif primitive == 'all to all': + func = all_to_all + kwargs = { + 'tensor': tensor, + 'idim': 0, + 'odim': 2, + 'ranks': self.ranks + } + else: + raise ValueError('Unknown primitive: {}'.format(primitive)) + for _ in range(self.warmup_times): + func(**kwargs) + CudaTimer().clear() + for _ in range(self.profile_times): + otensor = func(**kwargs) + cur_t = CudaTimer().instance.field_data['comm'] / self.profile_times + times_in_s.append(cur_t) + return sizes_in_mb, times_in_s + + def profile(self) -> Dict[str, Tuple[List[float], List[float]]]: + profile_info = {} + for primitive in [ + 'all gather', 'all reduce', 'reduce scatter', 'all to all' + ]: + profile_info[primitive] = self.collect_profile_info( + primitive=primitive) + return profile_info + + +parser = argparse.ArgumentParser( + description='Profile runtime communication cost') +parser.add_argument('--comm_profile_dir', + type=str, + default=str(Path.home()) + '/.autodist/comm', + help='autodist comm profile folder') +args = parser.parse_args() + +nnscaler.init() + +CudaTimer(enable=True, predefined=True) +world_size = DeviceGroup().world_size +comm_profiler = CommProfiler(nranks=world_size) + +profile_info = comm_profiler.profile() + +if torch.distributed.get_rank() == 0: + dir_path = args.comm_profile_dir + if not os.path.exists(dir_path): + os.makedirs(dir_path) + file_name = dir_path + '/' + f'intra_{world_size}.json' + with open(file_name, 'w') as f: + json.dump(profile_info, f, indent=2) diff --git a/autodist/docs/descs.py b/autodist/docs/descs.py new file mode 100644 index 00000000..3567c2d0 --- /dev/null +++ b/autodist/docs/descs.py @@ -0,0 +1,58 @@ +@dataclass +class AutoDistConfig: + recompute: bool = False + mem_granularity_mb: bool = 1 + + +@dataclass +class NodePartitionDesc: + # list element: (idx, dim, num), the order matters + desc: List[Tuple[int, int, int]] + + +@dataclass +class DeviceDesc: + dev_num: int + peak_mem_gb: int = 30 + connection: str = 'NV3' + + +@dataclass +class TensorParallelDesc: + partition_descs: List[NodePartitionDesc] + recompute_groups: List[List[int]] + logical_desc: DeviceDesc + + +@dataclass +class ParallelDesc: + stages: List[Tuple[TensorParallelDesc, DeviceDesc]] + + +class TensorParallelDPSolver: + + # resource is a logical mesh + def __init__(graph: IRGraph, resource: DeviceDesc, config: AutoDistConfig): + pass + + def solver(): + pass + + # temp design + def get_optimal_plan( + start_desc: NodePartitionDesc, end_desc: NodePartitionDesc + ) -> Tuple[TensorParallelDesc, float, int]: + pass + + +class PipelineDPSolver: + + # resource is a physical mesh + def __init__(graph: IRGraph, resource: DeviceDesc, config: AutoDistConfig): + pass + + def solver(): + pass + + def get_optimal_plan() -> Tuple[ParallelDesc, float]: + pass diff --git a/autodist/docs/images/arch.png b/autodist/docs/images/arch.png new file mode 100644 index 0000000000000000000000000000000000000000..abba4589abb61eff4bd984bd9dad0c093d0844d3 GIT binary patch literal 49324 zcmc$`1yEjHvo1(Nkl^kP!QGuekl?}HgS)%CySoKUit z&Yg2-S5XwOS$nP3tGl1>=h<)Q4;e9d7#tWdFfe%W@51t6VBkAoU>^>kA%Wk_ghgBc z-#*yOiwS{Mj^iBxKYTJ4lokX7tBHnv)`I|k{%rGI%^nO4q38YU!+>@1PcX1Qg5tu0 ziZ0rxYfzeKVz`ih4`6awYwoqHaevq7_gLkHe#x$J4#kl~VGHd?k){o8Ul2A3t62!4 z41|&id)v`gsit&&Bv@x#;(=!vDII^OfGd zko@bS(PNbP*KdCs!u5xS{BxB@ba5j8bHyeyVNv|+&RS;3oqrF|ubc6&I|qlcB2fQz zs~Wby?Y{>#_%sE}YN*`#18x#XA2}gKm6`CH94mVl1iBj=oew`j{y17 z7zI0B%Pqj?`jhU+L{9AL~rJpW6$aTiDKwB7MgO-G796FKC+6laMeO5>K+1 z+>e^NEnGYx#O|#Io&>xGTIt+vx78W7os&bg$3x0y(uthMHH&$=WTo?e!tw=tw%Yqq zr$q<{x{~x*FXLN|6adpcCk4zOHlX{APCh>ZWh^GvyjvBP)u|a)%}p#h zuSt2R4goKK6EI{Ry(1VWpooUxrGej)ftNV(B?~W8Qg^bl*Bv=fLz`u)Q`~fSTmyy; zx1n1sQ$R4dG4e-+3;izrPSl6GcH4_YZ!8G`lIpd5v|2;D2!a!#*?M6q#G3v1kArFL zC>=LmFQg!8VM$jSa`VEmdkqybr(*B;EQY6avh>eQUii_b!G5|rXZGc{Tv0tM$p#iH z`jaIWMDNEQa;M7S_$QbUX&%z+CYjgEzSuKrDZPvDg) zJKHK}G;FjGa#jlwwMtn~sS-`T$52}}Y-`91F2?2w{YAdTS(oZTb~FB{5AFfOrkn{U zGiG^&*|agL%tfvXd*qXLC>7HBPI|su-InU9BX}x}SBUS%vrp>vzTP~dj)CrPGn=*s zsBi&me?fDa%>@s9Vwe3g!) zff+!|{KN87$p6Deo_%I%@mwBFD;xdbKT&BA&i`_^x++-NYMgOx$;U9Aprg5P1R*sx(QAimnW^L%gwh9-yX)h1A$HVi$`#t<5w;`eh-BELXlrk9ja_zl#uW*7~u$?iZ5bB zwz4kNM<3MH*-be68SZ-Ls>$q zoGPP8Uija1C=Q?OZy*{F$y9-`4T9i$jdf(YB(gO=TSbSqo&p=& zn!p9gu6|V=)zHTx+#a3pd+?h50e5=gtG6N5LtRDVB7P|bm6{e2m?UB2Uo69E6YeSSJtwLB(%Rb{C2A&N1 zlgN9DWek>WM+?pgJGF~0Z9IsRp2D!A#OdzC-B3~?rw$)Q&kRuLPHeuzC9a<9MT<#tHtw=!h+GcdO|J|CYpjrD*V{VEo z%N7U47Yk|ktsLKK^A1&B`56@EFdv=e@KbeD-0M5E^lx)Nk;&)lon!90WLR}FZYT_} zdiP8nMaUyX7nNURGGvYqWCn!3ecO2zUHk2~plhLg)v> zv$^Md^rhDFXxH$|jt65!x>zmaEX;=P77s)w+pkJgh$|si`O*Dt91fX>aazOQoE-b= zM}ksVm3DJuyOkL5FX+$uj=!?Wp?g|H6cZiD9jV^MM;l6Be3Q6(f^9#vMtOV5w!A@! zPe(^;G+L>cGQ_;4AV86wy+6m*lseqL8#wi(&kpIF6)@PH6X?w`TPmz&o&vEt<$bLa z?*>hkjpQU8FV+~W-eW~3=?DNmHg%#q;=*jo+>(1hVVJ}5pf`3!rta?Fq_?#^g5=8o zH5=hkdwP%PleSKSU=sudl5#P6hgLgg5rcMj>WjCo5GZbPZ7I+;6J#(}@o>KO3+jZi ze1H}2vz_oBKkSz46R*sfcEkiVfA4Wk_+M+3CKvwOR~-DOia&j{6@PoO%%NdE3k~jE zKe)TVZ{FzNoUtJ$%Ue#e(AhOvg}2Jr>$$mn;y0FG-$mqPi)*k~2 zYEy+|^fpW>SaG=!)1$S-eh8Cgt>5*9t?%h!I$$Y?G03=R!cx7W2rk~G}o9SDO zy4Z4ExpYPisd={w>wPr#ivp_JAlnsGiJa{8>>obmmSD09C;!K4%d=WYTczU-nH~NS zYiFW$`x+gz>Ac0s?zw>j(GYE`dj&np!{F^pM}3|&Bui=VYf@LGvPHbPW09&ow zr8^sc@%8g%w~A$}URNk1n%OO#@&#XHc=EXJ8zJgsM*)JfQ{=aUS94wE=lAf~-aebd ztMz??ZpS;|vtILG|Kdz>(t1~Z^b@~X4$6q_eh=rZf~j!%9)u8RJ){HeNoK3&t9YBb zyQoKIMkj2cxp5=p})_}V5_h$VCo)}6vJOOM&oA5W{#dHync%cx<5mAzweg; z_3~%ekq~i3dMkG)U8;ogA=Gyl9=hk5&6SEvB?@oRE+1(0>yWur9L6Zoup*rGRxyp$ z#{o`Q&>XX^lRKoS7Qz3ZD6NA1+&5vBTKK$AhBMYmvRCh-XS^Sh!>5k&EGq-D@-^~ex*^m6JfB|DO_1!SnhlC_ygOW zUY+bEFu<+Qd)>F;u(eN?7|DJSYW>z=5o9#-UOg0wSzyb4&?=Fg8XNn({fbepYCO+u z#Oc2$c~3msNZoqxO6X*EbuHzx+XCCMj&R{<*5E-%SorblaW8`u9Z6ao$N2Z*NeDIbwghzDeb5?a7uc4vizp-2zK zw~dte!fv0D^UZlmtCVE#3@&P*|KxsX^^oW7`o3|u2c>_UGaL%3H%nrCt?aTv0L{}3 zo>e0Y27BYT5Ngs~e;9#5|M7!1&t@6H^)H{yMMn~t<&Wnk3Zzap-Iy+sp4HmgOn^tj zWbf4+GbiK@Rqj|c0-rij@c!q!@(1*{l7_x&&+xw^M<$#c=Yi*iI(L5w!09tP{Y;3J z%4s@+3xe=QfwO+L#CADPgtnWNH2tZC-tOafCv@b8=MMX>iNpUZoO{7ZrPQDuP`Y} z;Gq`(xv}vovt|M*O_zv{Z`k<>5|u*cPr&CZR*Y=LdVf*jEj88IaIuymhhYpWg$Z)x z5XqS{)RZ49qMkbXqMx2DyH{%n4wpO7S8>?_G~#nroq5Q=rDgFT)l=@O;~*TDsRo^> z*nS(PX!SFUahP}GXE496|5}FQgd|VYX>UDhkMQmOqARyrsu(CXJ$?*)_?7Nny;Fb} znU3@j?Qtm`-X^V)@8q{RBLvdzYpa|gf8}Gfaj;8xWk5oYjnI&TOSXBHU=8TNqV7ok zc@$i02rJ&Q8lbUI`GpCI;64AW1a`T=7@DtvpTU?Q+umT7^U^DQl}4=A>&k8>zRI!8 zjv9QM)|m|`%GHyYG^tY5&n~jH`D%2egNz(yY=>tC^R1XSe+bSUwENN`{J=sP8P@G{ z7y>O!mO_#{qJ%Mb!Y($l{HV_#gVRmF)2SivsICIt7xu-hH5uVIo?&{9bUM3{f_c5! z4@mke`eWGKkmqk{1lygk_oIFmbhQgdp;H?k;ntmr6jTqko+9R??OIYF7P(443@e%k z6UQg%frfCQ;XH$EsIXWIbBMK!U|^N5QpoH0;!+z5zEy*%Hz23|bv%jRZIf>%h)9ZjGB1h98NvI1UA!Ug?O@Ky0P%DQ{CiBe+(!(2CXE!GOx*s9mlpZ&JOW1ziPtPf45pZZ4~3a8l#X@WW1(Y>v1U^2!2_V^X8R8X8qc z!4Io9i;{4QpT>VNIixzh>wuLGeKJv<#dusxml%QLtaGQR6j+PK2l@NrMNTc#bx5zo zfno@*+-gnnHg&5KyrA7OejZU227lkn!=PWPz21HAxmfefJIjw^%>DKbD|I4YU8;tT z|HaU=Wuu{ZG!J#r`Cw$52rtos5vXn8^LAdf~x7TuGr&(|IPmKq9LL#o`t%zKC z4m@lB3g^lFWO*W|=x$Fiw2Wd^1Kn2Gwi6DPwbq}|@$)q5l!-NT~KK5$p ztnRP-gXVa5PjHvGHseg9h5!EEE#vx4?|dXr%|+PPtX6nAac*tLFRN4a5d>3j=iNz> zE7kYJF%5d=gH1vM)u;RguG+Gn)8%dGN`w11qjv@fc(aQNnaA%R;IY`ED0K?U6=9Ix zfx;SAL?@)yk?d4O2YBz!SW3+e37kyiWSK^)us`1yk?_?q^&BH|B99GTscRX|!<3^} zcBw3HXt`xxlWS9&pGxv|IbQzM@_no*!OC?0It6jXA+1likT+SuZzvm@sSTgi2f_IQ z-SnrPcVev98WF$=@HBiU(mZ`ki%4QhPQQ*~sqHjS&pk< zy5<47E z;VrR#6f%kVEp6>Uo6?pouad{ne@6qqa%i9OqK9AqrjLVBr5+ln5R{{|_NLd;88grE zd1s|xdBc%%-jZ?oV2)(_;r=1#_8_IkfotD;0rvLf=Fi1MSbl+Begn2Cs5KHlKq9r< z)lAJc%J*O2s;k zeSA+p3pU7|1-hB3UB{EzWeXFo9a$-BC9u|o=Qh%(vy0q`uX`Fbm^%|^y$pIU}~(F9rf zuaH);e;nw11FI0o_AMMa*!it}D_1md;iBC3a}*15Vk%#r#u;r6*Jh_yAdf{!EFZE) zdZV#;y`0$2yvO$J7(XTBI5+aQE2b^%g@58ztKs2~ynEFf8q4rQ`?^cL%)1BuW&LvX z79@qBdZD#9y5P>iNcQ-C(#nVT4wPGZNi?S<=Jbwzzx!?@GW#PaL^%34#`clSmK<~Z zSG)1JbKmeCk*3gD;{!4#dUz3W?f4n~nnjV6KX4y&*)oFJyLBG4*jR>c5isP?%9j{+}L@DsDR8Y@49z-&v z_K(YbC$PN}f8s&)gY+Mv=%;O!q&)Ub@4lMt@YTGGaFaAiTdYq(S16cH+3tJQc>R#m z+nJT1ZNYSh6NA~PqJiPj!k98L^+^bi(}hC=by}0dr0t1BN)yOAAL?vpEPR(+W{)ps znv(UbOZ`sB)EbxXAayJ<|EsRo@OtOvNO`JLl{|NDIb@aajbxjhBB}76_>1dHp!N-i z>(V`L-^-q*2u5aRG2?zn8aXUVnSdQK_lo@XwPz>+IeqBOn?{7q2*W9gMVaX0ekJh} zE#H?cnpaw%32dM>o_F(|^i8%jrvNsU#aLPNLz18EKOJRHcacvITbo@o%>+Nc?w^YM zsFA@>n7ECTS<_pvU7UeB0B0Li`-Onhsk7T6hR-*E6;URJxc#sP#v9rwGXMPLN3I!q z^S;z&RCr{JvDGKRTSxW; zzG)eh3Pnuo3|GS$w?VX>h}bk$Vnlx~h!CW*l1nZTz$&Yz_Ml50ju>Eh_Ym4TOb%8XlDz@={WM&m?u@SzKQ3?&=Z*2#~>c z#qOS-ITc}RC`?RD6IPsCbYM1e%i=2kW_PB3`>f6WH?02uXCmppG=8e+_gPmG2{C%j zW}mTijjrRM(ApsYf9 zoDKQE>`thZDvGDdn!OMc6T3Y8C6-8_^c4z(p<3BJJuRpIhkneP?bXeogdbcwPq*&q z+8v6YOTerXW_)F~BkFLz&_<(CmXiL5t;?GQCO{cEs5a5l0}hKp`$Ib#n?>|<8|p~h z7g!if`tlBmj=z6xG&?-hJke#Xkgey-Q)@N%FD|0{emRmc2kV9d9l)jC=!MymyGg`nPMAkz&P?)3#LpM^}0WK zfX6~43rEZy#bpt$ziJ}A^jBP4~FgOrRSXCc00J&?vP^pxx9)uV2 z`}c1aq<>gc35xeK1a9~Bbxs$>bjIO$lpYG-8;#HNYZaFsYP-rFk^BAW z$Z81hr;4O8&||gPR^sb=gSpBc!4(HG8d|_F>2&7*7}WTEP!>B_SBtOf6&f7n+U@FE zam+Qo50^Uwi33SPzRk~AEdQK733oi{Ov=nSjiJ3ingoWPbEQfW3%9#r@V&AboXD6V zn6#=tv?Tu7FVjf69pR&dRc?By-5R>x&ex~1*4;e<0j~&h=`>1Tm&@(Gf3N(T1vZ7m z0l7>@G2AKnhC+wkp2QbM@7Tq!&wx#BeO-J%Ld+KaGbj+W3^OP*9GK zlXbl0lF5D&gd|hS`T73CJIisu+Z+60KiwMT)9w}D?mpJPHrbFCFd{QF7g08=X&g$- z{ppgzQ7t_z9Gvh@e`Mm+ZY1VPGneMXst57Md35C;7bayq{MF&oGpt7Bp8$-}Aq zL?6xl)Sn~wi_w|5lw7D=X*PHscxn7r2{cEAfjQ@i-cT)J@MWO#EKmGrX zm;WyZeGjLPtmx?Ifix^_nt@u5)9@xQrnHk;Z=f`W+-%fPAe0;Xg z54Kovkd6R{E_Xb0J1joUgtgtDBmFD=I2l6W(_sFF!w#e&mYl=ya)$(`FOQ&)n`qi3s4ve6N=$N|j2ytNpR?(Nvb6{(jMS zue_)QR1N__LEFm>=A5?MpFBK0QE#S`+{Zm%9}iMEZS(A2ztbth_5+g(l@)kF^Y!&L zgPE)~l!U`#1RG0ZGx#eFtXPQuakxoxsO!Ez5_hMSjfaQFV5!#d;bA8_aUtQgT(#yS z0Rh2q`xbMHNInqZC2ZDn1i=fLMF1L&`by-{@nl{PNSwoTGBvyNmz{5~xP=ml@(Q#g z)LqRUwWljh3=yLFSLz)C)?3|y#cFlGt^l0|B_F5fQ@vg6!(c3_84RQOJS`kHv#_Vg zk*}s>Dx==V4$wId8-&ShG4lRIHlm7(3R@uXPlmhKBNv@-QF=W=RN&zwcDsYVXDdxK zO|@$3rBFLpV3u#7jkj5qKft~vy0C9W**YGl(=Df!#F2HXJK zOmlxQv`aG&4uzn&55$P+I^YlqeHE^>Mc&8h3vR4WVopY~SN{IP`(X zWv8s#vybvPuSrmJZ*5uaV_-*04h~NKT?9Pckn9Tma#8nrz1`hbE1$}8P;1Hn{8BDo zlnF>2d5T{UKeYnvL-U8#UBUH6%avcF8bYP@5hEPGF{$K*v6;2K>IVxYlM}pwg)`$k zPEi<+B=uEjG$ouLCo|~mlTJ+gtg=-DVRYCdmD$kdc$O+k;G3|p^488yK5v+YGJp#9 z@P3BcO2`))dYrGeND$HEa#-Pk11h26YCL>=CfBQ9sB+&6;vl>h+XX>usL_JiFRr4|R8UcPtN-U=d*2S3L9`wF@!vwye02H>bW>Oc2DeS>cL-qw$= zFF?$sc|RKS%DPIzn0+f2-MI zC-SJ`RUg5d#|K1DbsI-VrGxDWY3uH742f=QnQ0;>DKw+iQ5h;^*C5(K=>_tBNoNq34P z*37e=bD9e1F-h-UE@j*ej}oOpB%3l9G}#Koa)%@=_2uuZi|ZYQdW8qGDqC zfIlzRo5rW5aqpT<-JUGi?hT_Cc62dEPSdn#Z*Ok;_(8x@@8x7>X10=@@p(QemXvL@ zyJK!`ZJANa9Wzg$R#sL@@O4e^wpNCRhcj>dcDAO~YE?}2wzOlk5}%G5>1vPhvt{dJ zCg61;0)z@A)GC#<1%&F42r?PdX0=UCm~k?yM+*rD0A-j`nR4jh;P>0&#!3JZUtUiLgr~8Y!&E4z?>gCS|Ev%n zM~%egz@SksfA78m6WQP0;l1N20EYGSJmSYn_HF|(Cry{^X_WnTLe?*oFA^cZ%!|cN zS0*MVWM`{%1co;JP>brbcH=LKtoZp9s*%86NvhsIk$VMNM#{l zG0?o9FN5v2R&jPkLjkBP-rU?w*20&;S$LZOV7wZqGxhM@&yr9(N3H3ttJ3A3pgT_2 ztAezTKgTn$&(F`_v(L1a>Urz0g_F5L;SvgJY8ZJ49j=c1vlS>E9UVvU2p~Tg+?_62 zJKo;{&)WnRtl|}}YM1QEB|Ge@ecl+K)jBk_GNQ{!xt3(T)*6^E7Of9N0cd!5rR0zA z@$l*PB+<$&sn_b*l(eC>q)M{|1xOZBHX|A}`hAn-YAlL?O` zUi~dl(uM-Z`b)E=5R8O~i%auUR8z3Wm&*}`&WIIRSF_pT*cceDGA&mMH*6{-O#6dz zbS}Hy5FXLJ>wW-R%{r?>YV^;DngO|?{aim9-8Js5QCnM^c80R2)$678@87@QCq*=2 zLi*wl56;=STKLm)Tu$eNL&KdKiN+*%qL(X~W zZePFpOlJ}~iLI@!%D?b`&0A;yYQITLpIOTvE3gy-qPp$D zIR4%(6ro3sgTQD!Mf6n~WxK7teNXPU;1~uS{=Hj5@8?_xSa*#KPTP+<9+!d43`!w2 z04^+)?BLb^BIia5u`$d9Knaz-Tj2Ek9iqs#5w-}kKVRkYxLYy?+GrsLfS+QqTVMf` z2KWy<0)7vyFA@nqv`20a-^-Qjr0uP(+=wLlGM&y2JCOENo=Dyys?Y#$3KF>z|c&}sC~mK&^Z`I>K>D7D)-8!eVloXfp#=Tsz+Vy1>L z2X;YJdmwwmE503|9ip-yI!%))6AT>bT$xPn;S*A^DBZ`ayQDY_3yzUG{tL*QU&2>IsNtPmT3;q;WDpl-yudrNQZGv0u52&xP zQ$6%f0Y%Sn8*h#rFm!@z-i^$8AK+GD&9FbO(El!mOWd5tL0!j!e^vu_km=_jl3~`N zC0WzQ{>IlKf3$JH1ya2g09IM?Am`@OXc}ADQy8UZJnjS#Ve>{u{E^?GV5+R-`;;^o zcFzr%U4Qlic!ug)7`qQZsyi|gSBGKP?dIjK&hC!^<5n37u|Pe;eEu=+*#V(N1E+mO z_1DyIqdz_$C{?R-m>D?i0)vBTuB%Rl6KVclUCHm`vY7?Em&?F_BAs3@339K7c8G|s zH@Ec(m4K!AJ5_<0rF0O&jL?#<48T2L^{e3516PKub{ylGJkHeDi8Lz8>aU>y{;DCZ z8J%FdIcG|hT5NSJV`k7Pj@9T7U3ao$Aa%=PXxV#5b=| zZHp|IryEmE>yrAoO25m+1NITE8P~Xih~eifvz8_?QAZ&BHPam0VLIkTeD4S$Lp&xQ z8GMXgU-f=_3C_tOGSsFoTB4S@fRUV+)IT|9c6N4ljG{CUmvy{#K#v&F()kHs#Qhx0 z1ic|}#Qre8Uo0fomK+-tc@e?{-UUF%2WUu-N8jwf@6>SIZQSC-!9+IUQk^ky6FE%1oJ-C|xTbzb^Pu74e?Svd^OfNCfVK*GTUs(9wNU#{EIg=MVmr&_) z={M?~!5yeEp3!+QSi?Bwb48V<`*9G?%?hJrwWf@hFtD1*R7;N4mO&R0M`99in8$z`8=>*s##9?DEj>h}`eZGN)xVBRD1 zkkJxe-Q1R#l=2o5B3~7S%5YlBG5^*AF`#e)8xj&1+SLq!WoI@O>#x{f z^G}$+G&=R4>%T$QRs~HgJZ(N~ovHFS(uE!Ej>U9nrRcwY8IH-bnX#wluaHA>m7Eva zVJFNt0AUoXlptN9l|}Q{A#A?lLhR`^?>VDUJo_7{*C~WNMjvYF_RghB$E;0#KtQNU z!;5X`ispQa>I-Km+ffZTqF07DiHhcF`1ae@W0e@WkPdRk`gjT1W1$1QW5X6dwba9} z$}1!M(f-$VdPl`3jwgz@#VckKf1*idSe{2}+o0A1UjoC=`N9nc{BVJasI&(4ZTF1? zXOl>7FKCLa52-kCf0{B+a?C0)4woyKU?jT0{-J!?E%G~m=hs-E7%sg(@x4zy@D~(N;<(e&hv{4jztaQ=CACRw?p;YgHyV|FZxF5!pM-2^`hV;MWxN%sq zpyhmoYi0gcqql+c{I?Ja((4X_oPIe%L1m)L+=5i~d7QNf+_-em9|jXt6kW5gwNwATe%-??#9aNuJ=?P8y7MvPayiU$B$HX8S~t9ohTP za$@`3xubfk=7~1={B+;>Sud_GXWH+Dn{H>+TlnH5-IY^H;2NlH>#Rb?9V1Ut?ntwe z#wXg?wZF*y!vDaJp7N$rE5=<8e?UKDa{IvBmSU5-5sp3|P&OCB17_+C;YqIA4psDZ zDp!O}#Gl<&mHG0cMT)(SkPGI+0m9DTz+j1FdxIQ8_ECaxJ%3)t=iiTuZp1BX8t^VI zm(U1V%~3^W61^1)i=A&+R)e2OKHiVx{UFf4Xn#Yq66tL_2+TY0zWBrc@QhjP(aH1| z#Uo)N>rZ5LN>f|HCGoh0R>XxEeBYl}0<60+*7B=XlqO3iDd-j_JbtJbn}e3K*rZ27 z|B2jdml$D&)yWfz0z+fu7?}3CLww!Vs61KEmSZ27h@%=&i01*g@QIpFn)VawI%JAnfOIHI)C1;o^qxMZoe2 zTmXckp!~Es56k>ice1Yf*h2|NzzF6qCf%Pm<^|1q!Jn<$4!J7O3Mt^R<@@2=fc|>a z8^1Qv$_7$ib=$R5-=8585)wPS{A)ha?Ci{`y{D4{5Pzt0DMCbKINi@R{T zDg!J6?cP{_3LCyl%)vONqENlprdSy34DEZCemLxqiH;1vME)jKE0n3S&#?AD1-cyc zPYaDKyS=^rY6+->i>Zz>q6-9Q5nWmMN2`j=)I=O#HD~b@J8c;7n zFDjyMJf^-eAE^LU`s&4Gp(G|7+fFKgepbC)TwO~_u%owz9poltNLc}iWDP~UGhCqc zIiMpTK2@hf6*N<)?dcoJOM(xO4rJO+X=@F!%yZ_qT(vgjH@&_Idz_S1{qh!AiKU60A$gR52ev)A*V3&;fVLX!6R!X8Knn-jGw>o zS>XRF{dV(L{&`ZXtxiB3HD*Euvtqt|N9QzsSR1-*a1VQ3Dqe8^Lr$m)wU2*=vW=Vr z6~bbSzi2A-^p+y*o9rwHO~m%ND5Bhu8ubpVo8HAX37%y^tuVDgsh2Z#Wd};U2$W>3gv(zT4Eu!HYZ}skS znV5N~XMO&Uv}wy9=9_r58TWkLNz>r1yf$xU^gypO8Ul7-@5kB0z!OoNqXyrHbtN#u$4vGM3Kk8AJJ}l;ZN`O0K^>&}wb?u!X*@c^ z-hiBYo#+vt!7%apU=nHUga=%XqT2pItB!#4T}B?Ba_S7mbN6M?v$=kUP%~maUwFqz z_l{J8)Q*eHulnS~R~304P2ij}(!N8ZdNu#oblu);$Bc~`JIXaXVb5(2v@m}={04fv z+&VpUa8XqMw&cDeOv@^YOT8&HKY=$!Wp>pc{!UoVjwAPa`r^JF@OUQu^~VS0-H`T_ zn@6FCP#moG0}MYuzrU;x+#br2qASefA)Pg1S{MDqhksv`6bY`OqB3~Es~)vB%|M?h zj9hZAp(-X1toit}+EI(f2?zRZ&j(2f{@6cbJ=+bG%M*#K4#bk#^N`%hbRWRJ5Sc~v zXK4sdgR)2-S`n2at-_hdjhcx0?jhpP=1dibH}?OZRPyc%qMM|733C2+;jT^U{f3aX zXZKm%Z1$r{x{;OyaxXKX+ZXoYtnp%Q;lzEn)1Z11m20(ahmhf9^;RTB=VPB@v9C#*ur zozaT~_j{1Gj&6LWsoDw89-Z~**x*1nAl_u_UQ0=j^vvIK=8@#Q+CA{E^qL5Li(%W| zG$3SWkiuZwo8P|5?~|*W*Y|&Okdu{INXDVpCIuDRy|RS)7*+(v-*^=%OYV3>?Km>c zjLDcdxaG!Q>Av!>>~yZ!C{fE67(T^qU%y~~x&H!ik=PqNYs{6q-X&hKCACxN>X>uVOvf ziTjku$t9?_4cYha6}PWRGREO*Hw4KaSRVc0Ja!1(gG$tpX$&9tbQ~8~5Wk$fl*<#V zA~h4+Bo}wrr6tv+e8$_ZHg5$|{EHJ2KghM~M3_}m-9Z!<(-*7?-MGb~(C-+gkS?Pq7i z6%$J)RHEKXZoI@k-*8RY%sTIlYr^N=9ZbhfWrKXE&a|7Eq*A1~N z$@zOsPIZlq=d~nqMb5WHw(^ieI-Fsh=mR>>p2>HBfnPHUFeLQAR97~}AfeG6!NF@b+Gz0MIS!>LA=39RBT28WuFK3%ev$DZs%dq8x04EmPdM5VqU<6uN4fk9_saVt+ruf#}7GHt?$ zPg+TP8WLtWdX>eRSn`0pujNs(dawQ{n_dK=K3DDUQcK8lPu(Ir?B#AlNsDS${{%xg zYwL5gPjJdvvdjPJwGMcxLVeeA8E4c>*i<@PsIoBDC1X+b%BMPgKPpgc;m%}l9U6Wu zP4}>Mh3VkFfS@XGMk3c-2oJ<4$B8#(wy8}5^tCxVJ>|g$WuI#UCkm(xw`FM58^Q7h~w75LfoY1 zmiW;nSwo!2-&yZcW5xMtu;D_Y#B1@i$T^cS*L*5V!bO-vc2;R%)cE z8E@G>pMJy5z z$42_zGN_(;kfwNm7ZFNrg!{be7;CfWSo^`A-A`N7UhVpGh>o1=cFIm&`GUyC+dP@rr)H*f) zcIIxq&4~PD&{Np>(<+$VUHT^dBy?oqtxpG)C$D@>>ojv6p3hh(?1(BRVWE5m3Avpr ziK0a0B!YIF$yxfyF;pg)L&IL40U13~A9{C?!UvV_P^25A)(%e3U+br;ku4G#T-4jLra<2=1D zNBFHMHrDEpPlw($<|CpODmD+4(QTSs%NE@5HRyEV>ovaOAvKRs)))t2H;3wcnXiTEAor6o;~ zK5b|%*4pCg*Xo!vf{+V!Vk!7)1pYvK7V@WX=0Mu1lRbHT&~j3II>aqsW5GPLGx>Wn z>Eu$l+hHLozVCEctI{pDQQnu;au)OL8ZEfdeJd~>)ar;slE=}j`z)P8pnEV8e|gYy zKxk~#{DeDC03AGuDOk}eGN-6?BR+cg3emU`&TuY`*r?yN{H8m(|1}@OQKw3yc>Y6F zYN5}c2k0hiiYoJG^O=3oIt;eLGq`~c*O^WH@?SJk&7GFa`-EYB?6SBX%hY9Q^fUPu zC;K)HOq*A&pC9%TkjF1Lx|f?hj`KWX4sU(eAPBmXnSr$0mxt~23{g+~X*ZeB5d(pt$aLxd(ou;j*A?#ZX4v0FPCnM%mo6BvH zTcu+rEVu_i_%N%Ipl#yNnsPoFRQv#HGwkYE^x8bf)xlzgcg+)}TXS=uCtcF!LP zkZPPS5hb_gk%GRt7wc{^-AWjPV6q+wz_dCyUoZ_}uBBY%x#PeXOYx&(!7z9U1u;;u zz(ISu5DFrj+I%>t13nNzI^}$4@c3e~MdGuf>4EQMUDDQqj*9Mh^Sa66ru`|*bgu5e zXDl^}DK;CARtH2w&WyMLMS&L-;vx3Fl3t8Gh0AmnOn88Gfgkec7;k2wyc} zG38w*FC9CfY4CHG&3{XLlZ z8;1T96hsapPpBH9vglY;5AaDIFwajTD9^l&gz5Swf4IMrmnr zzANXJmX>fvP>qS{-E`3iQ^ptuy-l4x2REj8%Z%V_xQ%^n=1$d#;I|w|wOcHT z)C#DX1EYNxE4(%C*AfvI?0DuEFh9+w5v0o%g~_aWgsQjdBq$Z?So%j_@K<=94X%eT z57|BF&ciHBrYnaek5Q|1igH`+?V2L&bk!m6RsUm zTXzuWrQT9`w(<$JUk4REl&3@vYTCzf8r)^rK5{iI zC_5wfowC}V%%flaNN!J?yV5C$#?l&+J7jf>Z8)NhWdvEGzl0m2s8>V|zFabpfE#s^ zwMR~1b6~;c0aWX7qf|D(mb5hA)djh?3ZOyi1K8j>#S_=EY!)j{ieIy~z=trD>YrYI z3JfFDd-dj*#OZ54)A>JGJIkQB+OEweA-F?uhv4oG2^u^=aEIV7!QG*O4({#_!QI^n z?(Xi+9G>@`neQ8^si~PN`VZCB)kph2d#`KlP@aa(A@=YYiI@{E#UmZTZ=Y&x(Msc% z6jy0W=?i%R2A#bcOp-+@k7&YB#iuOLJgh4B_73P`u`0JgY^E$BPl)0nrFEO|;P_I` zt1&qd6jcXH}tN^y16^ot)9QO*x10s@!e@ce^zVA2eTlCcttj{)H-d z(r_7NJfhM6_i*E`gKT3X`^PH$ zFnfsc;W(^3T*MU_wdWe87SR~)`ur6~+tp>XUhE~kngtX|G*70(4MZIPoIy-Rh7cMW zn$PV9LE-BgELL2I^Y!ydI=^pMe))!8J*KRSWTtv5E>$G^i$rp0!85jcEwj;TQBa}^ zA2Zt{!QII`_|ivxi%ts<(Hx_8_={oY%h|jKO8f_(Bg4Z?(__*?zIW;$KRh9AtnDAd zu>2IeA;pl-mvz~46}K*mqA3M>Pdj1|)YyIhVE$9*2)Co@eFT;yKd4gS`#I>&WEro1`o*?)&_7i0eJ zHVIP41zw_{4cU>X+5gK=@p)Gn`KE7X$lkqAw={MFPsequF)rL*3k$6;V8wE6egJU% z0@sU*f|8M^9zS97L%ln9O~Hj9K(4|F`PCEU=;h_*E2==$dKX&*egFnbv+HIa`Dz51 zmjDpOVxd|>MkW!UK{Ca`gl|cYYykTkd|GV3y26ccCg|=V=D$#~3Z{s)Ba=tmx=s|b zIo|t<|9DE6MhXwtrh;`%vs1ZuPoWmoK#=#_oXhnL`=#mjQ@;ZC>2zsOsEfj%8f)W3 zvqBz6pqFj?{P62r2Ui4d@dxW(awaW1`SM;l`WFEs6CT+@By;rJ@07__T zN(CMk)(0ddjfDgJ=Z(p`pUv~%!o0Go3CsVkIpi$Nd}-NGA~IHKG~`@;N+mCgoLN~& zo~K!T6gyv6W|*>~X013>zeEbacenM=k(B1OVSTa+Ul@S=K?Z}2hxM-^u1Jh#e)>&0 z-~njcKUIv(0%ihYSvz8rjYdOjJ~ ziKHG`+-HPZme+!y(YvYEXqa*vhHj{vw@0;|cqR1URsEF#D|pY6ye|0mU|Py_eFUJx zihdMkF;tcT+$U|zhb_1MTyISoa#Bf;UUOPz+jAEYGx*HK=G;A%llv88q+H&$zF&fy zn-*rigEp_9I7W%Uk%S%-O=HAiDfK8CV~`~K8W~+{ZN@kh;rPJ<9#Fb>_1t%lTYd?T z6Y0zCwv3GU=U${OC@^jIXxeFMXG@y_*|S|+PR*)6-;EOT`P(CJvE!VdP4NsrxiURS zF0z{S>Tq3iozIAntd$_UkA5UpaD9F58e%%B@;t-l$8tMOYMb?b%;(U5xr;p8Tg3B# z3h(X?(6D8ons^|!UfqB1+NlOcNjOKi@_Sqje-97BZUEQ0PoAR}Ii9;CNMN?AeB-DSCFmbB6JGcMM|J>xIk zqW&R$nKl~I;_C48!%($!%ay;k~b$4N1zGDMO- zlgoXNEi30NSGPrsDtXukRfw(hv`S^&Yn^H4bsQ(x5~AYvQEuF+ z=eoYnm4A^Hwktjj!Ys^Zwv?F9727`}Iq2+g4lGZS3{_hLZO}0(dUqbJFfJU_Jkr+N z*#oL{9zdhRR}j}40>SJHl36jF`b7L?$hh+sJ-ccOgNIhuvys6P(a{LfBZRk^KEg4$ z*wZF8p%g>b3@|=XJRh5vx^oJ;db!=ML=jrM7sb2Mv_o4D&@ec{wN7a{0K*5(ZU&9! z>fu#jt$Iyj(U$d!BNXENG&85Aez%b#H&|8-_TCLc}eQGU$hBP%~hmv>w_SNsje?0ywO}9v}PNoIzSt zc{Apv?&)My1*T%*NCom}&TpCLG;v;SpW?F!zPl}nJ+Z%d73h>%4#!-hy zLLyf_=Lc|Q@|COxYf#Frm&t=ekKHo0PZF#LV#+Za-7gV&F#mo%RDdPAItoWpvKJ3VYvYV~kk_q&a{#CJs(^0OGrS1&NrigVvh8qaC-=$hUc zWLhQE*>N3ghr<>egK9gwMaks5I6jHTm3A|4gEm`CbXJt!awi|reP4J#Sf6^#)iHYV z*z6GM^YjF*@3cLl(=G0{5cNH!_qwM9v!u*@(exzs<3qcMAmB}d?=}SJV=9F-JJlsQb8bJCamD|4ql7%Lthc+3mVU zTIGtYx>ye)e4Q@>A67w8=|O+*e$!;i@rMrWKfV>_rdj9~+RtWMGZ^CWgyHLS$3xp5 zZW%{uJqc$vvzYh@)1=_XzJscl&Hd2XuE&61OxnCYo5<+wm6Q#Prwl;S3Fm7JW1jnW zWGgGgSUBsWsh?}=OEK^&U7~cyGh7{i0<93g=l`WkkxRXVLCq!QMmYt32YtUZ1Hp!a zN$hDLl*(cwJ0h_{VQ=?^!AUKi2$2kWD;$?y|BI|hp|ClGMvGVyYH%2g3KX#$Mpw^e zJsN!m5JHWOGOKo^39<{gRj!=j`cLu81$EDWj3^&f|fHifVbKPH1ntWc66GC884B8&SwBazt2D{*S zey@TiV-hIMhAl7^3Cf>Gyn0TP3{AR-yZKBvRht18W?kj@7g&Kx`h=^+iKm z0E}D~pyslr-!dLec?+AV1#=)@Y=y%1o~|?|mgp{h!VP7*55GaGkSjvVtqk(hlt)_+ zhEcJOb1d;Ye|6T|-#wv1^B_$2yzq8Qj4BXxZwY`{*qvWARyc5-n7F92)9&#fA7|Qj zs)%IJSS^ z#oF9LNq-!_0NtcfX<(Y*X?MV3e&u;zS;0>)A^}#njNKBDEkui2+L!O&gVEM_j0cj! zps?xjV~I71KT~lH9J4_2iw&O``puuHc|*yLB2B*5P>q|uN4k6HylvQrapwSwM{g~O zy&?fh(ir)@BVFs$j346$lx~Z6-!#&x?R^cAZf+_7*<6fXzmGG3_P$ZYTs%Db0F8%3 ziUN+MR-_;&wq9xGhOi@*gsd&N|7!nf3RVPPJ;d-;tZ1Jxk8wXV%M~$<^Q9-G`<6jU z`qtg{s1X(S@uEOsR9PlGe@Gq4vP~0qOq3u}3O$sJw*{pN7m*Fu<&MLYpf@IHm-LRA zkkZPKGo_I#?4>d8%LnZcf7>5@P2Pz>>CNaZV9(=C*R zd}j^UQO;?r{rLyqXj?u1sDENwRN*~P8vrNnrb(;~wElmY_L4wv;c-}e0V2gH*w_Uh zw}7yTGH|-JBuZPQ>)B`-6JLDtJb8qyUGWe=;aFM(5j7zz5boxVEF;jH8GY!yHF$Bb z{)~+|a`ldvV$4s@AF4s!>1*!RLgBA=2`k**uk2=56m zA%}iHp?C3+C{GU7?nDUS2^J*4KQED^RVa7JIdY z0&m3`Yr$R-E+XezX4_ib1Qje=ea@rT4u}PLwvTgMkGl*QCfL<%G~?vTHGD+ z#}v-v80@@6Sf({qx=ep`z^lcr3xLP1we+3y>8kos60FbT^IJ=i+2O#hB!!o7(^F!_ z87E2l!|TYY_reI@+2+P}f?yf#l3RH%eKeUF!aW=jMU{D7nTU1Ru?lYO#4gigSEC81 z-9NGpQ9%UJWY^Tntv|30>ayMbH_UFv5(`4hOpf2lf5Hg{4wT|%^>ydox%Ckq%{Ran z?}iD#awx4YTqFAzo#tbqV7>(#mxh2eqzgOw{n7Dp>{zYIgpf?;Eb2b1X3V6yNFK#y zx0Zos7Pht|BU@rpu#0U7b^KOikD{&Z7?J+S1KEabbFkD-@R}MW&2{Qq)&j?$1|QbC-R% zl`s7S9dy*)uK?y=Z@|sMq9EhJ4?Kb@%|!D$S+Q@AST5BwTRpnoo|NVRBA+rX&>|y{ z#L{C)oFb<-HZaI8$p-*sGs(eH%?1`LH-LT4%F9E?jIYet#F^Jcx_oxwF;wZ3=#wi1_YlXt{-sf3vd+^A28-N>`Oe0r> zMQjUKj1D;@ED|PUV=BO_p5ZbK@@{ibw#!evq#ix97{~q$Qh1eYYmkP*)opifKZWIq z6+2gj06(ENNW^zp5wi3}wSLb+WMJ$z@mV$s%UvhubT%t1RraT8)sN3}2wj$ymKIX- zo$|?Nbpc{~sUXmfB3QD9Sji4h2d1Q^##;Vz0+f(Zlw7K<@PbBy^tPi$LwAe)Cf@Qq zV$!`P-K!D${vA{V>=a~Kt3N;Yf1a@fuu#HMo@}WgFq7w4@#!kb&;-x!7g;xc3w(Q$ zAV2qAv|RDtgkk7xmmMjV^t>Mm_Y3|_#eCyM_E-*gCVUCxf7_4+?%oMLo%yj4@tV1a zIi0TXZ?D*@1;AmP0kJQdTa`(W>K?vK9F{8tCDQ$dLZ|0YOs@r~w&M=b*!H%T=*CJD z6e|_4+p|E=W%I&d_54WLTSISjBW=qcTAL14Ud(ChxGMHkt4|#VxQ6-jtv6>X4fj;I zK@S3ngockVK9`(~o8z^D6=wSiB@gDDvRZASErQizo|kwHzXHzNrt@t*UtOO_Jzg2{ z0C^iYd@7(k!&YQ^Y!-ka#VcH{?MhUuaG+l=`+U*`MD#9CD@uJN^2u~zR>D>qbWm#d zL0QjijG8Sjn2#JUff!Qp2G@3UfFJOI-F z?!1p^d&O*Txrr0VqvbISIB$Jf8W};1XV6mka1P{#h=Ik&fkIYTxBnriDbq2v@oo)T zN2IVY#&d3)2DuKIC)@2*U~i&<{UC0Xkm!lrC>h-l&)r+j^-64v;ds12s0^!BB{`SM ztmY`#O2K96c$19jBcO}j~@JdPghYo6U04>D33fgvW2_l!b@O+)RB>A2l>h{?0 zYqrr^xgt24(1!eiA;MZcY#=)*tL;bSf#4P1@ZB-775RiE8_gP>yZb2d7j>g%FYKXS zk3dhtY-XoZt!skFdN+4;OgX~v%;(1px!24)qh4#SgsDMo@husAO;!Sfz$@WX0l%}J zEsV2P)Jz8VSgIvKCzMBzXVj>@AehTlKf`D-{fz2Yh(K^xN}JrqS6>Ea`AO?Th;-KX zlZ2AxCj^~E5?wi~Ll6c=R=xN6%%@QS+VegxtKZs&y!Ei#oEgg~oi6u|(Q$yyLtam@ zWBL!pmcEk$9Iy4M&77bwq0Mgi$(VAzdzsIR=Ri1>{{7#G~y=6R*_q^MrwwG(T(nPEupIxQ4hq0n7J0GEaSY};5bY_DTBVOv5a4m% z(;U7Zy+g^L8R?$&d(0ZkkWVS@UM{4zek-G(yJB-~)X{J#Lm_A_GZOG{_JO9Ul`(gY(PJ$S^1zo?YvVnZkh)>=W=8MSjBbmTPseXUYyWOGS-vDHNTF>8Fn+xs((wVL0DKrL|!QBQ6Nm=0_gn@uK2+1Z~PJsxYR z)c)lBX;C0%08*Jb43%MF?B2;qV2i;J#ySg3gQW{NyejvP;i*&$dWqk@i;ySAD=J$J z90v2JTg~x*Wg9itR*5`>uj#IP^qb;r3;VjHIxD!7LYv-JWr!XT0yQwXbHKn6P;4Dv zd%1LBl~r=(d5>?M&fe6oPnmxDW^iV#AlybxhPD4(@C#ZgqXp71XTYg41Mq9<98!Z|e9YqK#?(poFTHdn|D$5MD8| z_}%@`MNppLczZW~>Y^5LElwT>qPZvQ+@L2$aO@j*&?FkP&5*ADRWzP8gE%YL$5Z zISw}e>jyLhNK!8!n9k7&h{}UR-nY!g$WTnyOdzhp;wKP zn+(cEHQx`mRjT|-Kkmp`m$F4-XX-t~^d;mM`93a#|L}`WpzJr|WC&x#pF=Q*zRhl7 zy_x~t3_+7WJak-^%Kb15FF&-O;5|y|V@c4ZnazU3;Ky9T-| z5_WZ8|A#6h9RX)CQ+*zMR(WX`6BsghEK1IHXASbFH9H)rHoY@7e}ajxDJa?;YJ>j~ z1wWk`_Y(alsmt?+ue{<0YemL3=?}=G+7DQ~%80#>j^&YU(Z|#`5hvd7+qA`IIR37C zY%&uH!u4(rX)qN<5bmlVh>WT2Mlr8xNY_c;uinvd|8Fa*Y7t=rzvsB>S@Tv(d@2tk z)}wyR&>t%iCU(ghjIPO@=^YBFmcL@#JEnISwjC!-(g?DBg=X45#Hb%dFhvCthTT{D ze`A4!`=bG1lNoI0c4=6`SV~d^)E-mTU?(7WnJfDAW()zeAHYrjXrBIAWD$WPZ_<04 zpGSWoU<13NAKf#aan1Wj;FLp1Z_Tf+waT>v+QE>RI1RHK#3la@g4RB-?Y8!EfnD1k zYv{Zh+7DX6jCm0WeyrtGU!-qm_3zrrx$+Br${n>73FWXH+yq;#qtMk$;a6@G0(wA! zS!J%g2Y*|#C-mCO>r-kF^W0Y3f$Sp6(;Z7~SP{hYJqGriJto=@GJOuTWGJWh+t=7W zthR@>Ck(-(gc&4J2a>Z?o)7kD7C#4vM7L>u9rdr;m7I43FkVwy{JosM?Lrs6le4z5~POYcj#YyGN+5VSl=L*s<8QXlnpKXf2=I4xUBqYj!)OwJI))?A_5#o zeFaF{ktviRe|f(k4$|*fTH>gbshT^G$_;UEgx$wVbzz)WCHsxHi~YDCJ&gsUx?apu zk~S%m9WByffv4HFB`YWGp#Gb6+ADz4D@5)sf6yAPO4!C?Y9GDGV)bcummNcP(i7vP zg?%E{#E(m`A+F-@n=fB@60`T)Nh2awey}~Y8hw<-2pPmomY&Wl6vI?Wc z_TX+YHP8v>0xzRQ>3PG)n<731ZLt1CyXPtVh1cCfUE?4ALv}jRRn+yGv~sZsm7}LI zcoog)+XQR6EDvFIXTyU@C8L0_>DOT#m3)@}pAx(6K_R%MDe}13b@kq9_ne`>3d0hrrYI_&j_63Z z`*kH3sll{a+P~wY;~W|=>X7!`)vuzbH;UA5*ahyCs>h0AXLNLi>u)Zu1@Di^B1Ma^ zJ{{Guex1jC#GEgIY3g;Ab)?d2o9?lEIh3qF#3e6efio6-Rm0?d^ChNEMQ6xC z^jy9Pw_*a@=A|F?W8na!zDmYBJIP&P;crZ7e6zke|Kf%%#jmev=Pd(gc!A$S?Jod?~Wl{8vXF9js%Bvc8*P!>EGX)gSZVw0X{fi$)JAEv-BCjqC zA8S4JCp97yS8N&Jl-Ta@^0U_0B54D^g>(DQ#E_1s*0n9~G)BAw3>!7Zu z30M=bs4fBtofu5}d(VwGHHDOfK>jSQ03HOoGVb@QY{vpUvQ(lmq7GV6!RMCdrKqud zzkGLJeWQ7m>C*z=$=%((5d8HDBSPvqt}MP+0Yz`rU_yuDnM7x}NGWb^u1>$d+hiriCgPx| z2lLgFEnAdMuB7R`WhdE~*4bO9;`-|;Tu4$)?(^wnHgig@-R(sye|=nmiS+(`K=6tz zLa(Bb0lbFBCDIHFy09 zw&H$Jh-lI>L6+V8rY)qOQGBRNAC1kyk-|lt%tust9E?GO$6~sz`GRQ3E@*P_mRZ^I zhkNhb?swxb`EH9B7s;YiOoGL(%#V0l2-s>B(%eth`UB8bM!X=~;BLOhFUC%TWtpZB@v{!s~`-fn4;xOYc}MH|Boac$2|=I0qEY-Xn307F$?ZLCdDq}Z@darJHlZtWr1Wy=u+%ub z`u?^Ls99l5W_`(rTgch<9OG)RF zm}1h`MKNw%IUe0`l9+#WOVMKbn5teU)<*(=R(|}5UG2_K(bVHqGQ-79mWMk53y<_! z$>@F~c+4owYguS$V>^p{$oju`s#vD5udpwca~DeHf2XorkxG651;(IsI6raOqDQ?q zj!5l9kH%Ef8p`n2{4Jj5pYQZ8>_KvmwUVePaakkLE|P11Km1)bVeEsnNSoO_-K}^P z{rlMs$8qe_wT%GX?E|Z#aYNCDS{?n|T~deP+^>6)sb`Ys{N*BJ=pEp5_c%U3gZ#Ai zoDsNL7RVSl^DPNEXK3uq@+oNaM(rpBwJEpm`8`4$RB~H`)&;cfrs`2+nFpgNCLXZBG z(U29UVEQddbg1+v$bW25>5TP9w=o{$+Dz#s*THeXb6@$|$u&g6W7{OGn0M{JCG@A7 z-2p$StLZcu^6H;vzGR2&E9tUOAk|${?W=!sHC+UHVWc4}Q~qwPeKyMlJ|6cpRQVAC zaw0!RX+$42$R0i!IS~1jI&H%L{OdQ8(Dd*jt7d+pK(r+xIz|k-MVp}!F*q~|T_19D zj>CFMVq!o9B2~LQ1qK0DC2-APaV_;vEO``$jC#LckUtPpb>TYSSQJ+r{|2bFEvYAQ4lR}#7yKF^A-vw_H*H)*$je( zI_X>x7Mo9Vj)Y=FFP((`X=&lXy|uLH5=nCT5w6-o{nZrUgXq*^+)jQaviPX6Y+~~` zP#KLR<+l1d;8XB%=F|c(A+1k6NGfZhnkv#HC24|ytac^bnw_1$eW#(=mP8$N6t4b- zJ}c_ygdHtp*oDVlZ}A+9@%!4edOu?xx1TfO#Ahc4BA;I-OJv4L^TsHD4ZG2HzxTZfhtE?pZ2Ph=v)IkLq}9;ZAe`jxK)5ELy_O}3tsg?` z3WtvHoZj=6JN$1D$67vN__*tg(%ixTrSH}s(QOLE-{M%|tRV5V#Z?cRhbt0`yFzH( znzVr8T1Ux@7uRa4DN8yir9uBQgXoDOJ!EOk2TsSMdB4~~T+9k04$H?X<1yu(?XR6Bxo-n6folspm=Xk)*zJ|Dn#qym@)X=Nv}>TL z;&i3$0Z-9~i+wR2#MlBkoVBg|XFl^BOE5#R5++6chMT9x>_c5+&G+XPJ1Vu$(0{fV zf|PY$0;DQ8i%kKvZ|-YB3^jfqMV(L=A$sN2DZyNoF(pT7nxF0SVn&gVLZ)% z24w>NA6&fdfqS#oN`EpCCrb;~jRa$Muy%0la#+bNVa|qsJ=!h&wid*AxP{4aF(NYf zSlwCZ0L^UwO2M4tA8un-`(-J@ZOS<3?71^=#i~DpioDi$lMVHRFq`=ecIMC6G{iQh2UJ<5N3qB zx{ykGDHw6fqX&aJHO3SWCzc|(S5S?$;V=TQ6*xHj!fc#-R+%xTylCd$UMz7LFmG2e6#zxbkM)ErZGDwQ6Q z*O8fpUC2TgN3g%zB-M%KtHRWPC3?IQyE>(OW0!QL*IyL*&Rz*N4I2%QQ}_sbR_mxX zgO)R1gu0&EN_tUPaZ7jZY-!v~CkBQjIey&Nl0~eo^xatk`ce85D$Ueiurq~~LySwU z3#1!WTSJ3c-_Lg`=lw_X>am~P13w8))ncg!J?Oi-4*iY_drCQ1iTq&4?D60L1WY2m zIH^%o(#n%CS$6DuEdv&F5cS2H&^iT9Gz(BU8q7d?cx9CCD_&79+=kmTpg})upN)7& z1|ee&F?W|ULzi}{4^GJFvG_Tc<_&T^TLJJr&GUkweL+vgUZmsCW?PZ9q6V~sU;%s` zBndY=w?);7KmOU{N$Ho{$J4x^jgkD&JG1VgLtdY@YIFu34f&lRKeN!m58L`BF|n&? zUB4>xq?k#1_79v6L{P!}2SO76y&=jIh>o-WqTI(*|qf&_cKd7yB@TC$c<#8Vk8Nfd{9L=c9I*N&zMbAnxx+Gia6rg30W=@ zGGMhB9?ky}*Xdkt#s6dYvUayW z!ToO0x{dnM?!`SiZVS-jcJ%c@D=CuI zi#1!|Ex@Fq!b0jQvuT>Ey$RvBu>;js%j$`B%~VH3>X_lit3oY_Dr?+qlqUqs;U2pm zG6&W}7sAc5zD30b=BQOKxyBY`V{7FycwUwJP*dKl7qh^&?a#N86ihiFS^dBTdx675 zT`qfib?=CH{WmAM1wSufs$Ir@xi?e#{vZQaDnh6HwqigEXdAP_Z~g@HlM&=9V=m<- z{40{?VBG~@zyWubP_&&wL+?k4A{K^SZaQ8My;2_oU=uCVX)EKCSO=cF0aIfyS^C)j z$bg3&5>P;Hsx~?lvb-+D{015`Ifi^UaH7!^GE*@Qq-BQ9z}&{aq4CVr!Ypn26iQE1 zQ(?w$So>ko;T7gwQcJq?$$AOW6q$eb zV+`7WU2k@9Nc~o7N&~u_fPNYWoI-hYbQELI0La|T%q~ zCO=g9jyfB=zdsyqRYML*d6PUINZ@RyG7*HW2BO{S`d6y8I~#cb^H67JOq$)F>aybe z;;d@G$UWGG#i$+j=o8G61eivP?a%uh#D2CO?>vgcL;~A!cYjZPiNh^8nM|ux6vNZU zavWe`ZZOqpJ^tAX8)@tWvTm|r-6T6mYeuNxVx6j_DH^$y@+lyfAI|pU8p7vVgqRjn zZtTZ{I5eeo`6FtU#v`Tm`0WT8(KxJ)xCj*Q4p^Qu# z9ZG^B>RaIHn7xySpPbLGuKuy4yJXeHF13>*H&WF79D?8nxJ)cZ$8y1fG(P9#W`Q@8 zp6y!u`?n7a4hnirt%r|28$#!7SlA@QmCe2#V>xYq-JcyD$?$&)zh zRejryZrJUb%bPKSi=iCK7L6vg=G1bhWIw9veI_j9~S+4#+Ua%{UNtWSF`MC$m4X|NBQRk zzo*^d5@73$0d$fh03k9Cn7;WupHy9zfw344#|uh{UU-1N^!??|h~myRLn@%UWCk>` zy(g(L3$7#*M*n&9|BCJM`OckC06@KUzzFya+T*M^f=GG&wPv&6b5P)>+e*Z=Z{ALFHW$ zYt6`Ry{7=w+PMh+g?13+;41Ij*H+$2^9J(vAnKY`Ck7;v3edcAs*(*sWUx3xCi46cYZ7?3R* z0Ml9qoAG~x@i=91>cj}~Bd~C(vw6H;_yCwU58(%nhK|)jb)ntY$;oT92B!hQnEgFZ z8fUxaY1z(3K)*Y<-aCk@RIA?RYs0Y`A?ps%){;!(v)KitnSVz_9~{_|T=g%4!56T9 z-E+iRw9CVcXrVvYRWLZEPwzVb5Y7}Ccnc2~w+MgrBKvj=!AchC`?mMBB_O?dzE;_O zVY<{1-})lXFt>&2&|vTu$U za(HpLGW?ZqMskmbs}MYAMT%F;ZQVRTq1m6rPCLXVS+})<4-62hcfbF`QU^6K; zoip$Z+*8){V+n1qc<+Gaq9ME*Fci|Z-mkJ4kG^?m0{{>0A>_^&di;Q&^PnGYt3-p< z{>0QaZ71gT!2u2Qtq@%AbPOlibAdi^_L2gQsX02Kb86m-i7E8~;^bamS;$6H2*#Jr-iQr^?4;i$4YnbSHHiGb%b zR|KFe?pFkyKiA!@VZStwN+PAby`1!k|+{@p`$>zo2R!6MW^f8pZ|0msW+* z*0IrJV;7ZlPI7ga4dj|9aCFGOYxH~IwwwMw*vgAsuyzCIIT)b8X*O}#%x9#vqFQqf ztqvGT!7;R8H0s-DFadcKRm}h;9wjB^Ck9Pb9FIr{p`PednP|Dri(OGZhJUj27Xde zDM#Hjy{6LY|$1;s+fH2 zgWJRah}%+>n4o3f)d}=7WZ4%d(R*2u7WYvb2ZygOMf~SOu_{iftlDY;W!v^7plmYD zP>0SUjlz2CiHTqiSYj5qoto4b77>YS!a{~OEBqA_AU&;Lf3jAsy<`v~>pI^$;GxqY z`kkpBS|dw-rhWgRU#FRe26LvkP~o?y2Xo_enO`H((>tA;9$+GkKa_c|8 zP?2VJp090{-|h5yp%gJlsZ%-15rArTg;8*x5lEb5H;?F{>H$x^77tu6B3Hu=;~h>_ zw&f}Kx60$E$CLtRPu8z55t>?2U6Av50igb>C~(xl($jr_b;Kv;xTCg;OG$&sfss($}L! z*~Xq@eKgQ;?&_>3ofGmYT=dXNSTPGjJvQ7Lff$$@NB{hBGcxFWu@S_yW9D;$?w0*a zfW8y?IYN*jpAQ#YzaAbN-PYbN@XU0VoOoxMOv%T;ZHn9+vE0sj?ZJ#t?`1}|115q6 zx3)|ekAIWatjE~(jpEuDgpJ!v!@BCN`8iv4hV-ky+{Z2Mw5-1IY3#Zh;KGd*+l-{_%W8w+qBsSAdkRgvxPtco1R zJPLJ@g3j{IZlsDpr*TLPkf~ZIF5koESv8m%_t`sTMp%{BfGnpqqMZ4dw%&Lv z`NU+#1ciVdCLfKZSC#H6rQS`Svd~7`$d2fOS z>&%%luCyg`tJG=|bF)n+PS>~svFO)fL?~l;Pq?$Q5iO>dAe=DtO;BHjHIrp5mUg{c zRDoNoD(|K_0s$LLS07{G?sSC;2c5|Bbq^71pZ;Bj=bqqCguGI2f8HaUZz5NGdU}A( ze-LsFAXAE|2heM)+Ng5=hCH>@g?>jRX%?WRaKOcrH$0+7Ep4UL&N=&!*crk`Z{tgr zZHBUp{$V(~_~6>#vmC?$f=%DXtI)#*M?zaeS(>l1fKdrEuMittIbhG@!6dOJk;G zUyOgq#tO}3uYzxB)OMLP(Gn@Oz6B9qFT-8CIf(LiKcsH@YY?`Kcb96A;jiM6I6?=4 zK*er7PneZMY&?;rFSnhXKeUj68{4XP9;bfsC;2d-l8W=RDla~>J|e_cB;|R&o07aVTtQy zPk%ctrD^5GpeL=p#=Orl+&K(MtR|H!FFlZq!mQ0}u^BhDDXS45RsI13SV^c5-kdCJybU|ZF?^zSEME>=><7;f;C8-uEckmb{ApIS z>$H5;hBQ${QRY9CU~$k>gFhGAU7 zHJEkYJW}$~Qm^`bMIFwo zvK~QC`H`=&8oie024}vxs#lwoRJqU&F3FpR0YJd@9F0zm>+?>yZI-Rgg0ZSi9Zl=g zH1uxyonBwHBUKi>&*>Cuf<480i_XPuus9xb{pF9P5@yW92(8Wl2&su;cO(DQG{25? z_uv_i8R6G+{38K-9M8r43?)xQUnJ>{S=uuDDMw(YDQ^H_Z-=5hN@J(Vcs*UzqA73p zA#L@3?K{d&xXkIg;>z1AQB9sIA}?tZa8x7j9D5n-LwN&+(@WSKFKjyFm0tlfuA3Dk z3Dq98Z~3F{^j1&|fUZ_2Cym zrc4isy_2KP`qAggUW;+gzka{4XAE#7K5}Q-r3akxc@XA1GWHr{5b0mfiQaLBXiTQN zdW9oBE3fItDc|Xhv40HsNV0YA^2FZ^W#h{HaFY}yMSXe(+#+ri-UN@Oap&zdSq4a@ z;~02!H(~{*k*)eaxLdqi>#_eZaxg1Tswk#E-9+|hO@#P9JHj<&@+kikIWI>YB`QPn zF-uw>lh%e{g{Pl^{6l^h6z z{}q!|@6CR0RaCdsTl_JM{B3ULQE1ZTm6FpfbAl_BNAXFdQ~hgoJ=3xbm&C4I@@*@| zE6uK^Up-fDT!l|~0qb?>)q&_gdV_rrS9!B~IZclZr45bV*WU-MP`2wq3#8Fvv={D; z4D0uc&cee^PH4Cu)XYcT)zg9Y618D%)uZSh|E0C>j%up;+C_M=;42nDlq#YkARt}3 zihzLBgis?d(tGa&M8QH;TBM6K6Pf}cbdZiTfe?B?S|B0Rgg|Kb;P3u!S?jKQzwfRu z|0F9XnKLtIX3yU9JhPui&-@dyeAHXJJC-Pc@XBeGIzivlh3XY;Y}?#_EATH>_VW9r zJrSAzKKVuXCe1|~-181F(GWBnNBY)_-tS`<0R^{r5?SbPwo?)N>DamKhayp7W9LvK zSMA0ij|$^|wsr(?6ODuUp6P^&vt%iT#CSNux2}rDjE!{p)Ohp1RljlYi0)OAaQPj@ zttrMCk7k4cFS+!eBe(aZ4{FH{f7{2Nf>TL9P*03KDeXaXor9>&DG>+vw$rFHdOrl( zdrN;iSL8rqhkj{J2Cwh9N)N{9xQ6V&1X#uQ->IUIcXJb!)Ba6tI%*kOH(XRYIe<7_ z7aLqJ0SJ1z^9>Ido=rw-a%A!YTI$Aq;!0a>qBVxIQ-2-OAW!w*e(0y&oz6Y;CvRl^ zY@n)gIN<}oyPHi4;an&&Wv*HKv~jSvd=jmaGcVve6l?|Q?>?4yU5flP zvd2V&2Hux#KagE{+xVWSNE-r#8ZKvCiuHb;rSEqUNOia_{-pWTh1!r;bDP#f^gD~< z9)ufUUCd%L*@L}mH{x;nC^xUHp8;?XIpM2i-6fwlJrDd8^>uCmvBt#?9sh4f{$df! zS~<{8f_~5kVe;2soBZ{KvcK0-X1)nO(Fx`IUwA6fot}81m0Lc`XGJ+s9-@X^HOVYd zr`QK9A65R!ecGRB?3j1$e z6V2>oTWR|$2D!X$YrxWHD;QlKxuEq!!~!JDIIhBRD*h+)StT0(RF9GB+p?=x>y!x9 zHnS+1urp+bj|^!7!QbXak1^p>E#KS+#U5~zVmcaXTUYr+-(U04EZJr)+T}^WHI%;_ zZ*Ug(6kQqCWKXOzGD-=W_aV=rH4VuxJMsusv$qKvi)_L zf-=avd-v{_##eVg-!=K5`)xn!h{Trv)JXO6*&)^|AFeuGaiSf#%cfs?{8y64UxzRi z3kiifG}>RY8s=tak0g{fa%W&y!Q2XdQ~M*q*|VP|9~ZT=EnGn1)|YJRb)UD)61 zR{y4JuBnM;H#Q(PNczV6t#&Q7EV<6}_Xg*SovG#K2wuQF13(Z7qe z$pGdi;+w)!3+C2Q!FRJK9Lni=d>hKbX^!iOQL5ui8I7MaGG@*ozQ_eW-o(Om$`j`1 z*}^779%X%z#(dkO5XLHPCb;S7;PNqr4%_?QWbao=0@;=B+jm?il2)n$%M2gCS`Zgc zqaGDb1S9PsRdb`T+`x|rA(J~#n{+vgvoaNk*dN1W?Ph6?U77&Ls8EA-jX9s+S=p%Q z_kDiI4@sX|%S+FCu;0sTfgkXnr=#;wo@yr`A&&7$2X&zqRg=s)_!p4|b;hO^G0Y04 z`%;29+5=+V4REAHCTsP1Sx8g&Hf4Xuk{mXGLvn-%lUpO|eAe?s#V-p$HT2~|R& zTDUEhNRfBLSbsamH7m3bWwj{IG$=B=A2<>8o3B6EE%9(zp}4XTpRrxtz^7P7Xxo@d>*udPl3GWx7=O7jkd%CQ3o%YhpocjWOjHqRXM#i+7 zB(54QVQTi?HNmf3`l=i18%g{cdSI91?Pk?*9PExW zL~hq@qZy?z)BZC0{}9|@^42kTdXF>(7zo5tM`)F0Knp z15)rgO(oez`Lh|}6>7hJ10DOHJa?86mxh)E_3c2$n3+VAMj@g^IF}_`(9IfRo z&z3@Kzk%rDk<)?j(hD;32klQHU3DRiGSrn)f~mTpf@*NnL@Jc!I*f zU+SuBr^dUlYZSgLbv3m-rS&V`y$|xF;G)zO^r7Biu->4%X&K$!F~~ zq8l~z57&IQ7BmZQF`lv9z41bhr=zTTj%h#G#|OxZ{CH;$MGnh$FfY#*@?R@Z>7K(; zv&soU+LDZd2ExiAmYm1rt+IzG7&(wqe5njo!D~36OTr)nS)_}NLCHOqIrQyL$l;CS z;j8!y?H~kNeW$44W!{A}A8&78*n(PL2vMgmAoN{rWwlO4sGiYa+yEifEplkhgO3Sg zkjcc8TD?wg`M^%5r6)=UVJ7>v=$aQ_09sW$)RsswvbEU(+e#S~KvGuM@pA}elVO3W zyrIxe)Y3@0uhG1@G`R*5G@prl|M)7Na!*a4wl-g5oqQ<3Hy z3>NDT(RSAj(|xTVJxVIm_{mEABA)9GHWMHMa-I*NM(!cNR+FNJ=u??oDHv zQo)A_Ust>fUWxH5RKqwZX09>#(Zz<{THGf^vs~QKi0S@lbxP&HBp60%Qb~un@S2Rclk##&^^cA=EGLE8jd6m5P zR%OnFVd~0+9o3=#eH4-_Z83b>)_1fk9GQ;KgT$=l$xvelSM`>he!Qxot!8TCI@gP7 zM(ZOQc|LUT@4f4)a$cezWLi!812=x0DT&kGaqUSF(RG!%($dF*Vu%sQ3%{N$j)-l~)mQ!`71-ubj4K>YMj@6_<^d-AUQi7%M@g2yly{L*LB{L?)2M? zpUbdx#7pzPB}>Id?Q_C~n=wHLg`n+_J8Eh1ypFC~cu!ld3SaL%c4^}o+17(AL>K(i zvtdJ+bL3C-(3WRaAb3#raH21LY~e$87Qczp-FExI{-2n_3wb(0p&u&N$oK?b^jqAT z(_^Jm1+lk$<+2up5#TuK zWj7_84DNjDihr95$)}GM9Q#5Lf@%k_y5;$=JEyy%b-TEg=j3TYRm!tj01F^nLz`Hz7%^&RdSI zYZ`#4C5|lj0a4h@iKYQt*yh^3&id_sg{`c-u|Fm9j(k=b2!}Eeme5AQhu{p{g6B_uf`n%56kgF4j0!ru{SiIZ?EvY z=asr7GV?`MsF*{f6g0fWZWSs87|-9ZVIv?F7HJDmna>2%Ln$vCKsE8J6o$qBFlVMhWz$u#IjNJ zOsF$ltafai0IrX36duyEbkEFTHa%*Jx4i#UTK?H^mb2xc`}Ub4mpK1lXXp+R0cksJ z0Pg+X1Eem|2hswqN{BWBYzBl{^i{vW+2-&x3EL>>X}V_PD=aKV^#PvkZEedxh0DM; zHhC3G>ph!K&5&S0X?NU|C~*3lI?)KLgElU|LL$xNRPxvMOa}fAQ;kNij&3fn-tJh# zqzzd9ntqML_<>Bg*Xm94&&|=jYy$4WKwGUCSoH#A^o$yUeCz(6sa{HRmQY?kDbsg# zOuX)Ix|d}kw3R}?OXw-TN!K0fd0&GUjzOb7&P)lO^Bv=h(W0Lb>d%hLCX(e#v+qn7 zLPnwl+h+8_8GDdNef;FXwFJUJ@3G)&RyjLflY#E&`<-W>KJV@Z)J*}oV}osccjG$e z2VvA{dw2;8pIOH#y2hl8jAB5Oal)3!HH!;gNA<7^b=D!wi}ywW$bV4A%WBm&7ziH zqpx&`0Lbw7TK{!7IKbh^1JXMQhaDa4NC18@gR}DBi2~lun=Q|?q@7cyCE~Q_BiWz- zvW{1>gO_N}T@|&s3(y*ZuhYGZwgR>oou;b4U%7H62~f`op1MX%+syD7407=&V*46R zPxN-p-D=t!m$~U)_JohzCdDbNb)B8Vp2e*|^F19E4>nq_E0KyHO|LZ_hdM_=PD;(F z^2X}|*+|-%=dKd3tY$FHr9@*T+}_8=8UesN>AsueOk>0J3!W=iUWhT!ec&`5eQW;c zAW={GI~+hI5tgBQYVA>M6>!fL6HvDWkXAuHWcOP`$bb@c41^eZ%uLs%33-FsDxx%6 zuS}qdGHt!Rranc>r%4n7+YEnQp7)<0~1osO={HFTdOEA2WV#s7}|p2@G! z%?-r%F#FYDY|NXY}X3JP3PhqHs1hyFej(9+ha<{&HH(y>wK<^WJhOXJt7JAVQ_0k+}Zr)HY0wL}K zQ@03E0SuNCfMVPTsoe5DT|D$h0I=@kl&2WDHTJsR6bCPY@~=yrb&=4y~G#+10( zW4pd$pH8n7V&L=6P#tYbFN*kVywf&2X5+2q2NF-gt&CBiR?a z3uH%?w(sYv6R@xY+Fz|64cte#2~)s1026(9+m&W->JR$$W9TnHE(=dkHTBU)y8PcIvl zJ)dqo8cVAx22!TUDc?UIG9VuY$i+v(bZRKm97P?q;C^{@bwr^e-Aju$f?zVBMOtTE zCB5_%$n>=8+}e{QIPdbC6D}3QKnJ`w+}qkcPCn4>UbEu*^MRNDUzcXBBRv0n5fXEs z`Q*dP_{)H$0bU#8Ee}rK>6F#4ojG}L?mRnx^8SA+M09>J>D0naKp!3ThMAji?+^wk zdJe)ZLs#yd6xMa&?EFK~lcx*V_8wE(K0x&-0p#rM&jA-7x;?yh#tT8d|GX&_R(D`* zda52^bN#8-4fdNc3^m}OYwP0`_!;JTUn2JQpL*9^VyOO8%m1SK|Ce;azYFQLtFzp0ESYKV8?{4m=yt%fcVv z!p8lk3G%*#JYaImYZn4pNS^=);k;7%vr)$a*&#!w^3bJ}Eqqq6F28z=N6N_!cImCS zsw0Y zNB~5Ya)!XgpUEhgx_7{dhkoMi9DpuA@B1Qa6f?RuAEf2i)jJKaHOh! z7n6<2zf>=d z|LLikGge$(?L&3N#t6f9W5UjQVS`_Ik+q9ieBlYANT=aQQb=!~c&ki#5Z##aU@NdtM)z;?iY=)6?S zI>LLBMjpFeSh9oi;+n$33af>em6R`K;HjE&nkV)fSijZGlR!|359yzbD;) zzldmeaMU{f~t*PCVpPr>0|zur9|n);=toS4ytTP?ulrTaKQ;yp4l za@=;-%Q|usm{DW4e~hWumFYkI^q(`X|FDx%M7V55Il+#>hXFV63Co+Am(9M* z8-BN6l`sm%5z2^lRPbnNu9V6TL}O9{OY*&3@V)|2>dOkCRJ`z-KrAUBwl?!^0h-p-pM5Z+c@L4SXZ8c zgx@XjqBtRnvq2Yta%andm7`xfrt$g_Z|rj){j8`4GHt+FO4gYO-p!RvkI`c5+YKmA*KXtw^9Rj|F!$lweupj0|IBVc)X z3}pr|V`%>5Rb%MunNNtiqBC;Ny65hOML$o)S{fb1?!c5a-Cg=aZu@`vB72aREyMk- z&ZCek>Q`Z+I^lQtYbo*yg60G&mcWyCGo4X3-CCpv1nB+EiMmh(pnaLwu8noF51EkB zQvTehA~s4OVbxczpaz06F6`IeUca*sS#H4dj(k2+mGb(DrjY_BASv^cjl-1&s+h25 zeWiVxA*tVQR(kyi3(J)ih{1e&bcpe9Zspk9I>R*6bJrDUNHe4OL+0q%SRLnn3Jyq@ zN5VYE$;7$Z96&`A3n{{PgxV3kx`q*fwc)a{b@aU`F}2hf9hR%>*mDHEmvNY(* z#RbKuvbahO}2Z{ z>2;?oIrNyE-V?wiEzJ*aljkhZG*KJ_SH|PnwmZk4=k`3cv9amc1_aWpLFa+iSv&wt z%Jw&}_S3gHRg7G5hWP_K#`DH%u|)Rd&RfnsG>RY5N4}^t@F%bPGVVD`WlhUrnhcfQ zHOAzt=^@A6u$J|pu#FpiQRQaCha)!4&S}cF#cm~Y0)bIQht`SGz63C^NT@NSpUxxw zO21r^G7>fH$IbTSKL!Oc^5BhR?+AmR_vS$)HE#SB| z5o<;nzl#!cs|-7GP}9bUuCn zlQMNb=48rvTn9WKyj5rElk0Hv{2ut|c&FAjXuSM(eV7R%Y8q`RkU8Ol(74z=?e%2| zcAU1h?vacfbt_Q@%?yX7Gj_L$U^v%Z4Wg*HNDd#+%ziM!KpejCv&;-ia{?602{UEs z9IT4eY*)e{Y|H0UK4Z302~Vu^R|xS`nqp~PRLzx2E1aIt!G?XBit`6~OiAf!S*3i( zc&S8Lr1KtC+qY%>ZJh%z?7MS=a-h}cS_`OdPWUM%!I7ubXCaR>A49cF_7fom?tfNa zv6JcldsnmQg8airKUPYH!bE|bsyqIPNqKq2fEc6f;J4A!2;C*^ex>A*TyxbuVN=eQ z)`pJ{&g@#|{ur1pWlB|+>zc{??i2u@R;zPf=t>bv+8kXsK+t@Re!BX+m~rENVC?LV z8xn%Hkzk?UqvR6$y1amHr0aiwpE=_#T^O9|41hNS$>4_Yorn>O)dWkAZ?DRaB5UZ` zt~ro|CDBnv^@cg3xp@r=CD?448wh_Mmjh7hBFOYc9Q>M&6*jS6}cR}qogZW05+5b#S#jsE9<-kuIznLk&15M9RqDC_T@EPLgYTrG4wxKNRB%k( z(MFc?*d`uUVcnEoN|m|Y4gD5d>vz`INmATXo+}Z>8}#KSq;yz(9?NEgL_7)V_pPZv zF1E=+vS!DdWqp?3hToNM*Yn+1EINvfdbfLGqZhbNjOnY%i~8%efi^Gjj^vl4l{=p* zZYJ*bxjQw_qA40q&quiLTPQ`@ZR3+G)mhdae)n$&B}Za z3F8S+_r~%X4=$-vaA)|z>y@d)^caWPu#t+#FXPIBwOfN{3eRN2%!PLpN`?Zr1jptL za33zMr>tzi4{=;@*tiinch|$7;p~Kvj3y48TqF>c7VRO9m z)8ErfmfxOmUaT?7GXtUA40V6>7uY|zim=^Iaqi(mc*I_7Xp*T{(4nKjfk$bhLj8|DWzuDOXDLASx^FOFJPFwE5;>zAl- zI-ZQO=BCsCf5G%Dzo7hO@246pnQA~WvzCDCYH>Brd2&vV{2N}xnmGrC+Fc3b%Au)6 ze|5k+PNWmS^uP~~4I?~npdU1XE(=41C3brA(rWP{l%~eIAU_h)V4EdlqUV{{mjv4+ zr?8Vzb~Fvd2&Es8B*}GfYljZT)1|t6XTc9&XNW?=^}l;jz(s%?b+VAodZ_`y6eE8k zp}{mC<3jVm*F6$DE;k>XY{a}CQPRB){N`m(a!-Jq?G25#f4Rt_0N`W}oE`9!)3~)} zuv^GWAa2aC!lv|$!A`~fXsZ6O;i>-9L*G@fPSbG#P=k3VBax6GVXw#JGP4FsL=?8( zESaew>}v8?rQg|v*0NTLIoWd$YM>dbFjLB-qrHN8E*U*s;I4uZh`w1V*`NPvk@!^& zZter?dv!hA?e;1~S=XjqYqY4a{D*EnzqDNu#Pz{f`Vr2c_yNx1d4DHTR0)lgn^U`u zq6iH-mp@-O*5QGh=df`go~J6}f^QR>my7_f%5%g~|M}jKACN$zIS97D!fsY&1%HvS z{q^Zykkc?KeFFTm*7nPet4YaIBmLKD?F(|1y4j6~;VB1xPXd2FFX3EY&0n&m&vYFS z6$f;(HEN6yR=#YW+=gxJ>npjp(|x;)z|xas42io)SZcTv_FntVSM^EFiSao$*G!cs zwi0ENXx;6m%);zZ4MEydh+5_3gz`;H2W8 zu+HO0F%YQcjg6Vqg^$(-wpN3Jf6vJ_*sc{po2#Ls^WBLv5pQ-2lU2(k4pm|#GLr%=)s+BEDWV{3f)t!82Iey*QI%P zxN4u#p^B7&qe9;YFNf6YlZ{E|Jo1&A057d6Vr7F6gukYnCblMMH^n5CY1cU>U3e?K zZ$9l>*k1aCUU8-KT+^DXR;}KQK@-6|p&RkXF5w_S3)KwaPB&@XFPTGHi!7^kz*jj=CNc1XOsv&;W^c z`4SrY`xXUl(&8*$yvP)BmGe8m3rI*v%m>t4O(9O=A5pGpjA!73S`)pU{f$flwr{m)zre+yV|EcSNDJt4 z<&bkPd&Nl}jR|x{EmF^l?EO=}RwJbLyGO>xf=D`?^lF@((M`l&Kg7r!>6~6ht$?I< zutK<8>eqn%PN_h^Tg{Zw0iBjXp}*&?_t+5HdC(i;zBSHi5MHq&UA=}PUaTuNg`!RV zAvp~muP~p9t<|A6p+Tt%jpqB}XH-UK*K#S&BX7waU(ULkQEEE;MER+|^EqT)qGYlL z-jT3zhDX0RsXIdveQZ9*l{GY5>xirUgu>H%Ead(D*U}CU3}RGF0pV#@=a9MO8}J0w zXo4V#Z)!cL)&A{HhLkj4J0!Qvr#T*5=ye0)Felx8TOol#?k*r14eE_UkaU#cok{E4 zD7l6bx$Xm+iz>_PmMkjEWTo<1aJ&`n2B>|;b}CXKZRKP4og=vYv-yAEo5KEOdoE#S z6)r#H5Ct%0_-4|&bz}d~%S5Mm>wtC0PUdv!#`IqmO32Nt62LF68nmXaSD_an~J^mt9QDBR^JSyUFdd1 zjiY0ZQ%$x*<6m`&~!Hhm$Rl=Jvha$Z7=b88t$5dV-P;_O(jN{nL=p8CaS0d(p_MI*^it<%Y|qa9%=tG1hNJ zQ|W=_Ok{S_78=yli1MW+4}~KA29%r<^|U6w>2LenI124FsjGGzYa~lRKafPdVmXX= z3KlU96DVSc0PnVOqE(rre&!AP@i~JE$yaJEKa&}(Qh^Mr!VP*urf z@(ab<@*rP#a6LHz?SjGH;%{Q|1(&@9rJ9CwAgy#>r__QB3`|=Q8f(R`%j+H|iM7?z zD;R8=;vxGWQO47!gPOAXp~YWH|Mu;^MH=a!#xr*azP(A_TlmO9&UPbU8e`4SPKA8D zBrW1x*7WigV19t6oor=ltl8@lo`jeILRNVJ?VZ`7+8eA!v2Xf(VVwm+se+$tqN^bI zg2T~)fNF*j{OL0Udk!gas?#}E66 z0c}3YvN$)|%Sm6^^g8+P(tjAuvpGHmtb6MmC(XG`6=oUJ7zsdhMCniOoG>q^+NVI0 z8^CG<+&L*DWsuPy)>8ZF(i5J`pL7E(iEWsspugn;{?uAc9jG<_XW_vQ+^0p+qDlLE zMgal4spxDtA-Zrf0~@L2hr1g~Tm4M`^ZCk!l}^;|o(839dZnaTbg{=h(h2<&r>v@7 zF=@X$U`A^ncCuc_|2nuP*Gj=mT?;OHMRL<0zO*uSa6K)kKeN`T z9LhPeNnJW%tdNqt`2&H@Bm9aNL#z*bH^jSaf+DpXf4~mwYiRuZ<^m2q=`f7_5>~RM z2z5OBp&s`z>&xppzd-(dfTZdB>q6}IQvr6zpV855M9icib%p2&qF8@1Z?Bs|=rU%~ zxZruxEbG|ZnN><7Qyk`y<}|TS6B#cNo|yW)WPg3Me{r+R6-NIt7%I|+48*0uZZ>NSq7TL)xzExE{r1pXC0HX-eFuc68TqbjI9}d) zvb)W(HBqJ|deJOiUP(C&kAK~ z-FsbN68_;=dk>$jKjT1zW)h1C4ZOC0SvPUOr0X>V1LKar8gNe%k35GdSo2}&rK74^ zBq=woV(leMxsp(@wXF2^>BHRf*%%#v^tj865J;M5Rai!#I z?J(*qo}f+N2Jyg3^@hWu3K`BUEdmL)HbS3JX*Ja6ge4g?JlGa#kUNySu!gZfH}MF9 zilaMX8pV6glI6Pd744$$_`4&SFUF1bj)SBKaYW4ibf(q?$KoIHJ_df)J^AOW3GyhR zHG?5qD#s3nOT-+ls}d@2;E!%J0DYln^GM-YJ25!6eW`PQcZtiX3%BXT(@9<}Shuban8w)>_)njEQSfWx@O zyF@H>yU(UN)sen0K^nC^*fN&l7RXAgoAR@X3@ov~0J^a)UQMslZO718$2p9@j&4iW zaX;t;_u%`-`q@l6e6C5)tbRIviI~kEmEx(ujblG@+2R^uW1z)4|2e`&~bt>+AbAaQuKo(fgmYt6s!gUGduzam z+XD|m;R@LZKccW*w*W(5^3mS2Let<(*9NcE-GPhSCjjK^3r|zadV#cuaqs`cyk_Rb zvN6KhC z?o$hW_OWNlaW6&H@HdB(Clr)?(GQZT6H<9IMQpT|B(BwOZn8*}0#%re^Md-HViNeI zF@8h3r~uodrHg{a)4$JG+RXFSM1w)YSRL?MNZjPdASo>sF^keZ@lc^aBWF~R`Lode zT5Z99haR~5Cs@cIOmFt_<&KZ@qHcf{{U99!-_v-${rChQZf7BO&UD#{$b~Exw*C0P87h7{W|74l}!1kYrsQ=eJhp*7Jq5Nn7 zvs{(%sCge1We8wkt6LktTctw}0{*~o7I*&z%`hK;{=iEg5|<19JK#Mi5y#~Z@I94w z#;p6VYz+LWay37}?E9|(Rr0Xc0VQZbtzC literal 0 HcmV?d00001 diff --git a/autodist/docs/interface_design.md b/autodist/docs/interface_design.md new file mode 100644 index 00000000..9d4b3504 --- /dev/null +++ b/autodist/docs/interface_design.md @@ -0,0 +1,36 @@ +# Interface Design + +Similar to current user experiences in cube, the entrance to *AutoDist* is a function that accept a data flow graph and a resource descriptor as input. The function returns a rewritten graph. The core modules including +1. *profile*: build cost models to provide the underlying solver with operator and communication information +2. *dp_solver*: encapsulate existing dynamic programming logic + +```python +from cube.graph import IRGraph + +def annotate_graph(graph: IRGraph) -> AnnotatedIRGraph: + # TODO + pass + +def profile(anno_graph, resource): + # use_case: + # t = comm_cost_model.estimate_cost(primitive='allreduce', size=1024) + # TODO: multiple dim partition? + # t, m = comp_cost_model.estimate_cost(op_name='bmm0', idx=0, dim=0, num=4, recompute=False, chunk=False) + comm_cost_model = build_comm_cost_model(resource) + comp_cost_model = build_comp_cost_model(anno_graph, resource) + return (comm_cost_model, comp_cost_model) + +def dp_solver(anno_graph: AnnotatedIRGraph, cost_model) -> DistPolicy: + # TODO: solve the optimization problem + pass + +def rewrite_graph(graph: IRGraph, dist_policy) -> IRGraph: + # transform the initial dataflow graph according to generated distributed policy + pass + +def autodist(graph: IRGraph, resource: Resource) -> IRGraph: + anno_graph = annotate_graph(graph) + cost_model = profile(anno_graph, resource) + dist_policy = dp_solver(anno_graph, cost_model) + return rewrite_graph(graph, dist_policy) +``` diff --git a/autodist/docs/solver_interface/partition_constraint.md b/autodist/docs/solver_interface/partition_constraint.md new file mode 100644 index 00000000..64e4de57 --- /dev/null +++ b/autodist/docs/solver_interface/partition_constraint.md @@ -0,0 +1,43 @@ +# Control partitions of a model in AutoDist + +In autodist, we provide a set of partition constraints to control the distributed plan of the model. The partition constraints are specified in a yaml file. The following is an example of the partition constraints. + +```yaml +- allowed_partition_dims: + - 0,0 + name: torchscale.component.xmoe.routing.top2gating + parent_module: Top2Gate + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: torchscale.component.xmoe.moe_layer.dispatch_expert_inputs + parent_module: MOELayer + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: torchscale.component.xmoe.moe_layer.merge_expert_outputs + parent_module: MOELayer + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: torchscale.component.xmoe.routing.compute_logits + parent_module: Top2Gate + replica_allowed: true +``` + +In this example, we have four partition constraints for the MoE model in retnet. Each partition constraint has 4 fields: `name`, `parent_module`, `allowed_partition_dims`, and `replica_allowed`. + +- `name` is the name of the corresponding operator in the model. It equals to the `signature` field in the `IRFwOperation` in cube. Note: signature is the full name of the operator, for example, you should provide `torch.nn.functional.linear` instead of `linear`. +- `parent_module` is the **closest** father module name of the operator. You can provide two partition constraints with a same `name` but different `module` to control the partition of the same operator in different modules. +- `allowed_partition_dims` is a list of allowed partition dimensions of input tensors. Each element in the list is a list of two integers, which are the index of the partitioned tensor among inputs and the partitioned dimension of that tensor. For example, the annotation of `torchscale.component.xmoe.routing.compute_logits` can be `(C 16) E^ C, E^ C M^ -> (C 16) M^`. `allowed_partition_dims = [[0, 0]]` means we only allow to partition the first input tensor along the first dimension, which is `(C, 16)` in this case. An empty list means no partition is allowed, note that in yaml, you should give an empty list explicitly, i.e., `allowed_partition_dims: []`. +- `replica_allowed` is a boolean value. If it is `true`, it is allowed to replicate the operator across devices. + +After specifying the partition constraints in a yaml file, we can feed them to autodist by `--autodist-partition-constraints-path ` in fairseq. + +# Examples + +Three examples are provided in `pc_examples` folder. + +- `pc_examples/retnet_dp_pc.yaml` helps to generate a pure data parallel plan. +- `pc_examples/retnet_mp_pc.yaml` helps to generate a pure model parallel plan. +- `pc_examples/retnet_hybrid_pc.yaml` helps to generate a hybrid plan: data parallel for the attention module and model parallel for the feed forward module. diff --git a/autodist/docs/solver_interface/pc_examples/moe_pc.yaml b/autodist/docs/solver_interface/pc_examples/moe_pc.yaml new file mode 100644 index 00000000..6b3a66d9 --- /dev/null +++ b/autodist/docs/solver_interface/pc_examples/moe_pc.yaml @@ -0,0 +1,20 @@ +- allowed_partition_dims: + - 0,0 + name: torchscale.component.xmoe.routing.top2gating + parent_module: Top2Gate + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: torchscale.component.xmoe.moe_layer.dispatch_expert_inputs + parent_module: MOELayer + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: torchscale.component.xmoe.moe_layer.merge_expert_outputs + parent_module: MOELayer + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: torchscale.component.xmoe.routing.compute_logits + parent_module: Top2Gate + replica_allowed: true diff --git a/autodist/docs/solver_interface/pc_examples/retnet_dp2_pc.yaml b/autodist/docs/solver_interface/pc_examples/retnet_dp2_pc.yaml new file mode 100644 index 00000000..06fc0a9c --- /dev/null +++ b/autodist/docs/solver_interface/pc_examples/retnet_dp2_pc.yaml @@ -0,0 +1,26 @@ +- allowed_partition_dims: + - 0,0 + name: torch.nn.functional.linear + parent_module: MultiScaleRetention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: cube.runtime.function.embedding + parent_module: LMDecoder + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction.apply + parent_module: LMDecoder + replica_allowed: false +- allowed_partition_dims: + - 0,0 + - 0,1 + name: torch.matmul + parent_module: MultiScaleRetention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: torch.nn.functional.linear + parent_module: FeedForwardNetwork + replica_allowed: false diff --git a/autodist/docs/solver_interface/pc_examples/retnet_hybrid2_pc.yaml b/autodist/docs/solver_interface/pc_examples/retnet_hybrid2_pc.yaml new file mode 100644 index 00000000..1c787a2c --- /dev/null +++ b/autodist/docs/solver_interface/pc_examples/retnet_hybrid2_pc.yaml @@ -0,0 +1,21 @@ +- allowed_partition_dims: + - 0,0 + name: torch.nn.functional.linear + parent_module: MultiScaleRetention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: cube.runtime.function.embedding + parent_module: LMDecoder + replica_allowed: false +- allowed_partition_dims: + - 0,0 + - 0,1 + name: torch.matmul + parent_module: MultiScaleRetention + replica_allowed: false +- allowed_partition_dims: + - 1,0 + name: torch.nn.functional.linear + parent_module: FeedForwardNetwork + replica_allowed: false diff --git a/autodist/docs/solver_interface/pc_examples/retnet_mp2_pc.yaml b/autodist/docs/solver_interface/pc_examples/retnet_mp2_pc.yaml new file mode 100644 index 00000000..ba3694be --- /dev/null +++ b/autodist/docs/solver_interface/pc_examples/retnet_mp2_pc.yaml @@ -0,0 +1,21 @@ +- allowed_partition_dims: + - 1,0 + name: torch.nn.functional.linear + parent_module: MultiScaleRetention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: cube.runtime.function.embedding + parent_module: LMDecoder + replica_allowed: false +- allowed_partition_dims: + - 0,0 + - 0,1 + name: torch.matmul + parent_module: MultiScaleRetention + replica_allowed: false +- allowed_partition_dims: + - 1,0 + name: torch.nn.functional.linear + parent_module: FeedForwardNetwork + replica_allowed: false diff --git a/autodist/profile_data/16xmi200/comm/intra_16.json b/autodist/profile_data/16xmi200/comm/intra_16.json new file mode 100644 index 00000000..2c221272 --- /dev/null +++ b/autodist/profile_data/16xmi200/comm/intra_16.json @@ -0,0 +1,122 @@ +{ + "all gather": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.0007568836212158203, + 0.0007002830505371093, + 0.0007263660430908203, + 0.0007683753967285157, + 0.0008181095123291016, + 0.0009009361267089844, + 0.0008710384368896485, + 0.0011073827743530273, + 0.00132598876953125, + 0.0018424749374389648, + 0.002922821044921875, + 0.005091261863708496 + ] + ], + "all reduce": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.0001537322998046875, + 0.0001667022705078125, + 0.00018396377563476563, + 0.00022487640380859376, + 0.0005011320114135742, + 0.00040619373321533204, + 0.00041749477386474607, + 0.0006366968154907227, + 0.0008751630783081054, + 0.0016357183456420898, + 0.00281984806060791, + 0.005467009544372558 + ] + ], + "reduce scatter": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.0009616374969482422, + 0.0008935928344726562, + 0.0009247541427612304, + 0.0008810997009277344, + 0.0009346246719360351, + 0.0009841203689575195, + 0.0010200738906860352, + 0.0012418031692504883, + 0.0014723777770996095, + 0.001923823356628418, + 0.002941608428955078, + 0.0050580501556396484 + ] + ], + "all to all": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.0005378484725952149, + 0.0005496501922607422, + 0.0005365371704101563, + 0.0005259037017822266, + 0.0005403518676757813, + 0.0005397558212280274, + 0.0005373716354370117, + 0.0005512714385986328, + 0.0005965471267700195, + 0.000803828239440918, + 0.0011274337768554688, + 0.001718592643737793 + ] + ] +} diff --git a/autodist/profile_data/16xmi200/comm/intra_2.json b/autodist/profile_data/16xmi200/comm/intra_2.json new file mode 100644 index 00000000..651fe980 --- /dev/null +++ b/autodist/profile_data/16xmi200/comm/intra_2.json @@ -0,0 +1,122 @@ +{ + "all gather": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.00024454593658447267, + 0.00024988651275634763, + 0.0002572059631347656, + 0.0002519130706787109, + 0.00026090145111083983, + 0.0002851247787475586, + 0.0003421306610107422, + 0.0004983425140380859, + 0.000761103630065918, + 0.00133059024810791, + 0.0024867534637451174, + 0.004799938201904297 + ] + ], + "all reduce": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 9.734630584716797e-05, + 9.603500366210938e-05, + 0.00010254383087158204, + 0.00010678768157958984, + 0.00012938976287841798, + 0.00016248226165771484, + 0.00023658275604248046, + 0.0003794431686401367, + 0.0006722688674926757, + 0.001253032684326172, + 0.002414846420288086, + 0.004711151123046875 + ] + ], + "reduce scatter": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.00020804405212402344, + 0.0002099752426147461, + 0.00021717548370361328, + 0.00022215843200683593, + 0.0002355337142944336, + 0.0002671480178833008, + 0.0003258228302001953, + 0.00046095848083496096, + 0.0007699251174926758, + 0.0014074087142944337, + 0.0025737762451171877, + 0.004897451400756836 + ] + ], + "all to all": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.00021393299102783202, + 0.00020859241485595703, + 0.00022313594818115234, + 0.00024058818817138672, + 0.0002853870391845703, + 0.0003378629684448242, + 0.0004602193832397461, + 0.0006934881210327148, + 0.0012228965759277343, + 0.0022298812866210936, + 0.0039789676666259766, + 0.007758593559265137 + ] + ] +} diff --git a/autodist/profile_data/16xmi200/comm/intra_4.json b/autodist/profile_data/16xmi200/comm/intra_4.json new file mode 100644 index 00000000..ffcd6e76 --- /dev/null +++ b/autodist/profile_data/16xmi200/comm/intra_4.json @@ -0,0 +1,122 @@ +{ + "all gather": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.00028171539306640623, + 0.0002959251403808594, + 0.00030524730682373046, + 0.0003502368927001953, + 0.0004027366638183594, + 0.0005098581314086914, + 0.0007292509078979492, + 0.0011570215225219726, + 0.0020204782485961914, + 0.0038918256759643555, + 0.007619047164916992, + 0.014976143836975098 + ] + ], + "all reduce": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.0001232624053955078, + 0.00012753009796142578, + 0.00014941692352294923, + 0.00019860267639160156, + 0.00030059814453125, + 0.0004938364028930664, + 0.0008866548538208008, + 0.0016759634017944336, + 0.0032621145248413084, + 0.006457090377807617, + 0.012811422348022461, + 0.025483202934265137 + ] + ], + "reduce scatter": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.0002853631973266602, + 0.00028748512268066405, + 0.00030829906463623045, + 0.0003508090972900391, + 0.0004006624221801758, + 0.0005053281784057617, + 0.0007104158401489257, + 0.001131153106689453, + 0.0020192861557006836, + 0.003987717628479004, + 0.007789778709411621, + 0.015131282806396484 + ] + ], + "all to all": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.00024099349975585936, + 0.00023865699768066406, + 0.00023796558380126954, + 0.00024230480194091796, + 0.0002443075180053711, + 0.0003082513809204102, + 0.0004188776016235352, + 0.000558161735534668, + 0.0008316516876220703, + 0.0013772964477539063, + 0.002527189254760742, + 0.004883217811584473 + ] + ] +} diff --git a/autodist/profile_data/16xmi200/comm/intra_8.json b/autodist/profile_data/16xmi200/comm/intra_8.json new file mode 100644 index 00000000..dd2ca87f --- /dev/null +++ b/autodist/profile_data/16xmi200/comm/intra_8.json @@ -0,0 +1,122 @@ +{ + "all gather": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.00047366619110107423, + 0.0004703760147094727, + 0.0004788875579833984, + 0.0005095481872558593, + 0.0004987716674804688, + 0.0005119085311889648, + 0.0006889581680297851, + 0.0009741544723510742, + 0.0015397071838378906, + 0.002678585052490234, + 0.0050097227096557615, + 0.009758901596069337 + ] + ], + "all reduce": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.00016224384307861328, + 0.00016684532165527343, + 0.00019915103912353515, + 0.0002106904983520508, + 0.00026621818542480467, + 0.00036542415618896485, + 0.0005738735198974609, + 0.0009899616241455078, + 0.0018474102020263673, + 0.003551936149597168, + 0.006996250152587891, + 0.013917374610900878 + ] + ], + "reduce scatter": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.0005039691925048829, + 0.0005405664443969727, + 0.0005137443542480469, + 0.0005560636520385743, + 0.0006148099899291992, + 0.0006127595901489258, + 0.0007145166397094727, + 0.0009651422500610351, + 0.001480269432067871, + 0.00257718563079834, + 0.0047527790069580075, + 0.009169626235961913 + ] + ], + "all to all": [ + [ + 0.25, + 0.5, + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 512.0 + ], + [ + 0.0002892017364501953, + 0.00029306411743164064, + 0.0002937793731689453, + 0.0002928495407104492, + 0.0002915620803833008, + 0.0004178762435913086, + 0.0003500223159790039, + 0.0004662752151489258, + 0.0005607128143310547, + 0.0008365631103515625, + 0.0013719320297241211, + 0.0024344682693481444 + ] + ] +} diff --git a/autodist/script/alphafold/foldtp.sh b/autodist/script/alphafold/foldtp.sh new file mode 100755 index 00000000..defb30e5 --- /dev/null +++ b/autodist/script/alphafold/foldtp.sh @@ -0,0 +1,64 @@ +#!/bin/bash --login +start_time=$(date +%s) + +memory_constraint=38 # in GB +memory_granularity=1 # in byte +connect_type='NV2' +save_folder="alphafold" +topk=1 +cache_folder1="autodist/cost_model/comm/__pycache__" +cache_folder2="autodist/cost_model/__pycache__" +comm_dev=(2 4) + +for ((i=0; i<${#comm_dev[*]}; i=i+1)); do + torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type +done + +if [ ! -d $save_folder ] +then + mkdir $save_folder +fi + +# We run all cases in the machine with 4 gpus. + +mesh_rows=(1 1 1) +mesh_cols=(1 2 4) +setting=(1 2 3) +layer=48 + +for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do + for ((j=0; j<${#setting[*]}; j=j+1)); do + + echo "start runtime Alphafold2 setting=${setting[j]} gpus=${mesh_cols[k]}" + if [ -d $cache_folder1 ] + then + echo "Removing $cache_folder1 directory..." + rm -r $cache_folder1 + rm -r $cache_folder2 + else + echo "$cache_folder1 directory not found" + fi + + SINGLE_DEV_MODE=1 python main.py --is_train \ + --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity \ + --save_folder=$save_folder --connect_type=$connect_type \ + --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --compile \ + --topk=$topk --ignore_small_tensor_threshold=2048 \ + --verbose --alphafold_setting=${setting[j]} --alphafold \ + --alphafold_layer=$layer --recompute --adaptive_recom + + torchrun --master_port=30001 --nnodes=${mesh_rows[k]} \ + --nproc_per_node=${mesh_cols[k]} main.py --is_train \ + --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder \ + --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} \ + --plan_idx=0 --iter_num=4 --warm_num=2 \ + --global_batch_size=1 --alphafold --ignore_small_tensor_threshold=2048 \ + --alphafold_setting=${setting[j]} --alphafold_layer=$layer --recompute + done +done + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/adapt_recom_tp.sh b/autodist/script/gpt/adapt_recom_tp.sh new file mode 100644 index 00000000..b733ce8d --- /dev/null +++ b/autodist/script/gpt/adapt_recom_tp.sh @@ -0,0 +1,59 @@ +#!/bin/bash --login +start_time=$(date +%s) + +memory_constraint=30 +memory_granularity=1 # in byte +connect_type='NV2' +save_folder="tp_data" +topk=1 + +comm_dev=(2 4 8 16) + +for ((i=0; i<${#comm_dev[*]}; i=i+1)); do + torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type +done + +if [ ! -d $save_folder ] +then + mkdir $save_folder +fi + +# We run all cases in the machine with 4 gpus. + +bs=(1 2 4 8 16 32) +mesh_rows=(1 1 1 1 1) +mesh_cols=(1 2 4 8 16) +model_config=('350M' '760M' '1.3B' '2.6B' '6.7B') + +for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do + for ((j=0; j<${#bs[*]}; j=j+1)); do + + echo "start runtime ${bs[j]} ${model_config[k]}" + + SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ + --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --compile \ + --topk=$topk --fine_grained_GPT --ignore_small_tensor_threshold=1048576 \ + --verbose --adaptive_recom + + for (( i=0; i < $topk; ++i)) + do + torchrun --nnodes=${mesh_rows[k]} --nproc_per_node=${mesh_cols[k]} main.py --GPT_setting=${model_config[k]} --is_train --recompute \ + --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type \ + --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --plan_idx=$i --fine_grained_GPT --adaptive_recom + if [ $? -eq 0 ] + then + echo "success at $i trial" + break + else + echo "fail at $i trial" + fi + done + + done +done + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/analyze.py b/autodist/script/gpt/analyze.py new file mode 100644 index 00000000..1c965033 --- /dev/null +++ b/autodist/script/gpt/analyze.py @@ -0,0 +1,77 @@ +import json +import argparse + +parser = argparse.ArgumentParser(description='GPT Train') +parser.add_argument('--save_folder', + type=str, + default='tp_data', + help='set the save folder for experiment data') +parser.add_argument('--pp', + action='store_true', + help='for pipeline number analysis') +parser.add_argument('--is_train', + action='store_true', + help='True: train, False: inference') +args = parser.parse_args() +import pandas as pd + +model_setting_list = ['760M', '1.3B', '2.6B', '6.7B' + ] if args.pp else ['350M', '760M', '1.3B', '2.6B', '6.7B'] +gpus = {'350M': 1, '760M': 2, '1.3B': 4, '2.6B': 8, '6.7B': 16} +recompute_list = ['True'] +batch_size_list = [1, 2, 4, 8, 16, 32] + +for recompute in recompute_list: + table = {} + for model_setting in model_setting_list: + for batch_size in batch_size_list: + table[batch_size] = {} + fname = './' + args.save_folder + '/gpt3-' + model_setting + '-' + str( + gpus[model_setting]) + 'gpu-' + str( + batch_size) + 'batch_size-' + str(args.is_train) + estimated_fname = fname + '-estimate.json' + backup_fname = fname + '-backup.json' + real_fname = fname + '-real.json' + + try: + with open(backup_fname, 'r') as f: + estimated_dict = json.load(f) + try: + tmp = estimated_dict['estimated memory'] + except: + estimated_dict = estimated_dict[0] + estimated_time = estimated_dict['estimated time'] + estimated_memory = estimated_dict['estimated memory'][ + 0] if args.pp else estimated_dict['estimated memory'] + compile_time = estimated_dict['compile time'] + except: + try: + with open(estimated_fname, 'r') as f: + estimated_dict = json.load(f) + estimated_time = estimated_dict['estimated time'] + estimated_memory = estimated_dict['estimated memory'][ + 0] if args.pp else estimated_dict['estimated memory'] + compile_time = estimated_dict['compile time'] + except: + estimated_time = -1 + estimated_memory = -1 + compile_time = -1 + try: + with open(real_fname, 'r') as f: + real_dict = json.load(f) + real_time = real_dict['time/s'] + real_memory = max(real_dict['memory/GB'].values()) + except: + real_time = -1 + real_memory = -1 + + table[batch_size]['estimation time/s'] = estimated_time + table[batch_size]['runtime/s'] = real_time + table[batch_size][ + 'estimation memory/GB'] = estimated_memory if estimated_memory != -1 else -1 + table[batch_size][ + 'runtime memory/GB'] = real_memory if real_memory != -1 else -1 + table[batch_size]['compile time/s'] = compile_time + pdTable = pd.DataFrame(table).round(2).T + print(model_setting, recompute) + print(pdTable.to_markdown()) diff --git a/autodist/script/gpt/analyze_adapt_recom.py b/autodist/script/gpt/analyze_adapt_recom.py new file mode 100644 index 00000000..47f7cf06 --- /dev/null +++ b/autodist/script/gpt/analyze_adapt_recom.py @@ -0,0 +1,77 @@ +import json +import argparse + +parser = argparse.ArgumentParser(description='GPT Train') +parser.add_argument('--save_folder_tp', + type=str, + default='tp_data', + help='set the save folder for tp') +parser.add_argument('--save_folder_pp', + type=str, + default='pp_data', + help='set the save folder for pp') +parser.add_argument('--suffix', + type=str, + default='_nar', + help='set the save folder for w/o adaptive_recom') +parser.add_argument('--pp', + action='store_true', + help='for pipeline number analysis') +args = parser.parse_args() +import pandas as pd + +folders = [args.save_folder_tp, args.save_folder_tp + args.suffix] +if args.pp: + folders = [args.save_folder_pp, args.save_folder_pp + args.suffix] + +model_setting_list = ['760M', '1.3B'] if args.pp else ['350M', '760M', '1.3B'] +gpus = {'350M': 1, '760M': 2, '1.3B': 4, '2.6B': 8, '6.7B': 16} +recompute_list = ['True'] +batch_size_list = [1, 2, 4, 8, 16, 32] + +for recompute in recompute_list: + table = {} + for model_setting in model_setting_list: + for batch_size in batch_size_list: + table[batch_size] = {} + for index, folder in enumerate(folders): + fname = './' + folder + '/gpt3-' + model_setting + '-' + str( + gpus[model_setting]) + 'gpu-' + str( + batch_size) + 'batch_size' + backup_fname = fname + '-backup.json' + + try: + with open(backup_fname, 'r') as f: + estimated_dict = json.load(f) + try: + tmp = estimated_dict['estimated memory'] + except: + estimated_dict = estimated_dict[0] + estimated_time = estimated_dict['estimated time'] + estimated_memory = estimated_dict['estimated memory'][ + 0] if args.pp else estimated_dict['estimated memory'] + compile_time = estimated_dict['compile time'] + except: + estimated_time = -1 + estimated_memory = -1 + compile_time = -1 + + if index == 0: + table[batch_size][ + 'est time w/ adapt_recom /s'] = estimated_time + else: + table[batch_size][ + 'est time w/o adapt_recom /s'] = estimated_time + if index == 0: + table[batch_size][ + 'compile time w/ adapt_recom /s'] = compile_time + else: + table[batch_size][ + 'compile time w/o adapt_recom /s'] = compile_time + table[batch_size]['gain/%'] = ( + table[batch_size]['est time w/o adapt_recom /s'] - + table[batch_size]['est time w/ adapt_recom /s'] + ) / table[batch_size]['est time w/o adapt_recom /s'] * 100 + pdTable = pd.DataFrame(table).round(2).T + print(model_setting) + print(pdTable.to_markdown()) diff --git a/autodist/script/gpt/checker.sh b/autodist/script/gpt/checker.sh new file mode 100755 index 00000000..680100ff --- /dev/null +++ b/autodist/script/gpt/checker.sh @@ -0,0 +1,42 @@ +#!/bin/bash --login +start_time=$(date +%s) + +memory_constraint=30.5 +memory_granularity=1 # in byte +connect_type='NV2' +save_folder="exp_data_test" + +comm_dev=(2 4) + +for ((i=0; i<${#comm_dev[*]}; i=i+1)); do + torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type +done + +if [ ! -d $save_folder ] +then + mkdir $save_folder +fi + +# spmd for a simple case (with 1 gpu) and a complex case (with 4 gpus). + +bs=(32) +mesh_cols=(1 4) +model_config=('350M' '1.3B') + +for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do + for ((j=0; j<${#bs[*]}; j=j+1)); do + + SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --recompute \ + --batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_col=${mesh_cols[k]} --compile + + torchrun --nproc_per_node=${mesh_cols[k]} main.py --GPT_setting=${model_config[k]} --recompute \ + --batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_col=${mesh_cols[k]} + + done +done + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "checkRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/pp_all_run.sh b/autodist/script/gpt/pp_all_run.sh new file mode 100755 index 00000000..c7d38d89 --- /dev/null +++ b/autodist/script/gpt/pp_all_run.sh @@ -0,0 +1,45 @@ +#!/bin/bash --login +start_time=$(date +%s) + +memory_constraint=30 +memory_granularity=1 # in byte +connect_type='NV2' +save_folder="pp_data" +topk=1 + +comm_dev=(2 4 8 16) + +for ((i=0; i<${#comm_dev[*]}; i=i+1)); do + torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type +done + +if [ ! -d $save_folder ] +then + mkdir $save_folder +fi + +# We run all cases in the machine with 4 gpus. + +bs=(1 2 4 8 16 32) +mesh_rows=(1 1 1 1) +mesh_cols=(2 4 8 16) +model_config=('760M' '1.3B' '2.6B' '6.7B') + +for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do + for ((j=0; j<${#bs[*]}; j=j+1)); do + echo "start runtime ${bs[j]} ${model_config[k]}" + SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ + --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder \ + --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --compile --pipeline --topk=1 + + torchrun --nnodes=${mesh_rows[k]} --nproc_per_node=${mesh_cols[k]} main.py --GPT_setting=${model_config[k]} --is_train --recompute \ + --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --pipeline --plan_idx=0 + + done +done + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/profile.sh b/autodist/script/gpt/profile.sh new file mode 100644 index 00000000..5f31f008 --- /dev/null +++ b/autodist/script/gpt/profile.sh @@ -0,0 +1,66 @@ +#!/bin/bash --login +start_time=$(date +%s) + +memory_constraint=30 +memory_granularity=1 # in byte +connect_type='NV2' +save_folder="tp_data" +topk=20 + +comm_dev=(2 4 8 16) + +for ((i=0; i<${#comm_dev[*]}; i=i+1)); do + torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type +done + +if [ ! -d $save_folder ] +then + mkdir $save_folder +fi + +# Use nvidia-smi to get a list of GPUs +gpus=$(nvidia-smi -L) + +# Count the number of lines of output +num_gpus=$(echo "$gpus" | wc -l) + +bs=(1 2 4 8 16 32) +mesh_cols=(1 2 4 8 16) +model_config=('350M' '760M' '1.3B' '2.6B' '6.7B') + + +for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do + for ((j=0; j<${#bs[*]}; j=j+1)); do + + count=$((k * ${#bs[*]} + j)) + q=$(expr $count % $num_gpus) + + # bash profile.sh to profile the coarse-gained GPT + # bash profile.sh * to profile the fine-gained GPT + if [ $# -eq 0 ]; then + CUDA_VISIBLE_DEVICES=$q SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ + --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder \ + --connect_type=$connect_type --mesh_col=${mesh_cols[k]} --compile \ + --topk=$topk & + + else + CUDA_VISIBLE_DEVICES=$q SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ + --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder \ + --connect_type=$connect_type --mesh_col=${mesh_cols[k]} --compile \ + --topk=$topk --fine_grained_GPT & + fi + + if [ "$q" -eq "$((num_gpus-1))" ]; then + wait + fi + + done +done + +wait + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/tp_all_run.sh b/autodist/script/gpt/tp_all_run.sh new file mode 100755 index 00000000..bad93399 --- /dev/null +++ b/autodist/script/gpt/tp_all_run.sh @@ -0,0 +1,57 @@ +#!/bin/bash --login +start_time=$(date +%s) + +memory_constraint=30 +memory_granularity=1 # in byte +connect_type='NV2' +save_folder="tp_data" +topk=1 + +comm_dev=(2 4 8 16) + +for ((i=0; i<${#comm_dev[*]}; i=i+1)); do + torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type +done + +if [ ! -d $save_folder ] +then + mkdir $save_folder +fi + +# We run all cases in the machine with 4 gpus. + +bs=(1 2 4 8 16 32) +mesh_rows=(1 1 1 1 1) +mesh_cols=(1 2 4 8 16) +model_config=('350M' '760M' '1.3B' '2.6B' '6.7B') + +for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do + for ((j=0; j<${#bs[*]}; j=j+1)); do + + echo "start runtime ${bs[j]} ${model_config[k]}" + + SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ + --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --compile \ + --topk=$topk + + for (( i=0; i < $topk; ++i)) + do + torchrun --nnodes=${mesh_rows[k]} --nproc_per_node=${mesh_cols[k]} main.py --GPT_setting=${model_config[k]} --is_train --recompute \ + --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --plan_idx=$i + if [ $? -eq 0 ] + then + echo "success at $i trial" + break + else + echo "fail at $i trial" + fi + done + + done +done + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/pre_install.sh b/autodist/script/pre_install.sh new file mode 100644 index 00000000..e34fe0e8 --- /dev/null +++ b/autodist/script/pre_install.sh @@ -0,0 +1,7 @@ +sudo echo 'export PATH="$HOME/.local/bin:$PATH"' > $(dirname $(pwd))/.bashrc +source $(dirname $(pwd))/.bashrc +pip install -r requirements-dev.txt +pip install pre-commit +pre-commit install +pre-commit run --all-files +python setup.py develop --user diff --git a/autodist/script/swin/analysis.py b/autodist/script/swin/analysis.py new file mode 100644 index 00000000..7890cbe9 --- /dev/null +++ b/autodist/script/swin/analysis.py @@ -0,0 +1,72 @@ +import json +import argparse + +parser = argparse.ArgumentParser(description='Swin Train') +parser.add_argument('--save_folder', + type=str, + default='swin', + help='set the save folder for experiment data') +parser.add_argument('--pp', + action='store_true', + help='for pipeline number analysis') +args = parser.parse_args() +import pandas as pd + +model_setting_list = ['toy', '355M', '1.8B'] +gpus = {'toy': 1, '355M': 2, '1.8B': 4, '2.6B': 8, '6.7B': 16} +recompute_list = ['True'] +batch_size_list = [1, 2, 4, 8, 16, 32] + +for recompute in recompute_list: + table = {} + for model_setting in model_setting_list: + for batch_size in batch_size_list: + table[batch_size] = {} + fname = './' + args.save_folder + '/swin-' + model_setting + '-' + str( + gpus[model_setting]) + 'gpu-' + str(batch_size) + 'batch_size' + estimated_fname = fname + '-estimate.json' + backup_fname = fname + '-backup.json' + real_fname = fname + '-real.json' + + try: + with open(backup_fname, 'r') as f: + estimated_dict = json.load(f) + try: + tmp = estimated_dict['estimated memory'] + except: + estimated_dict = estimated_dict[0] + estimated_time = estimated_dict['estimated time'] + estimated_memory = estimated_dict['estimated memory'][ + 0] if args.pp else estimated_dict['estimated memory'] + compile_time = estimated_dict['compile time'] + except: + try: + with open(estimated_fname, 'r') as f: + estimated_dict = json.load(f) + estimated_time = estimated_dict['estimated time'] + estimated_memory = estimated_dict['estimated memory'][ + 0] if args.pp else estimated_dict['estimated memory'] + compile_time = estimated_dict['compile time'] + except: + estimated_time = -1 + estimated_memory = -1 + compile_time = -1 + try: + with open(real_fname, 'r') as f: + real_dict = json.load(f) + real_time = real_dict['time/s'] + real_memory = max(real_dict['memory/GB'].values()) + except: + real_time = -1 + real_memory = -1 + + table[batch_size]['estimation time/s'] = estimated_time + table[batch_size]['runtime/s'] = real_time + table[batch_size][ + 'estimation memory/GB'] = estimated_memory if estimated_memory != -1 else -1 + table[batch_size][ + 'runtime memory/GB'] = real_memory if real_memory != -1 else -1 + table[batch_size]['compile time/s'] = compile_time + pdTable = pd.DataFrame(table).round(2).T + print(model_setting, recompute) + print(pdTable.to_markdown()) diff --git a/autodist/script/swin/profile_swin.sh b/autodist/script/swin/profile_swin.sh new file mode 100644 index 00000000..8036e69b --- /dev/null +++ b/autodist/script/swin/profile_swin.sh @@ -0,0 +1,60 @@ +#!/bin/bash --login +start_time=$(date +%s) + +memory_constraint=30 +memory_granularity=1 # in byte +connect_type='NV2' +save_folder="tp_data" +topk=20 + +comm_dev=(2 4) + +for ((i=0; i<${#comm_dev[*]}; i=i+1)); do + torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type +done + +if [ ! -d $save_folder ] +then + mkdir $save_folder +fi + +# Use nvidia-smi to get a list of GPUs +gpus=$(nvidia-smi -L) + +# Count the number of lines of output +num_gpus=$(echo "$gpus" | wc -l) + +bs=(1 2 4 8 16) +mesh_cols=(1 2 4) +setting=('toy' '355M' '1.8B') + + +for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do + for ((j=0; j<${#bs[*]}; j=j+1)); do + + count=$((k * ${#bs[*]} + j)) + q=$(expr $count % $num_gpus) + + echo "start runtime Swin setting=${setting[k]} bs=${bs[j]}" + + CUDA_VISIBLE_DEVICES=$q LOG_TRANSFORM=1 SINGLE_DEV_MODE=1 python main.py --is_train \ + --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity \ + --save_folder=$save_folder --connect_type=$connect_type \ + --mesh_row=1 --mesh_col=${mesh_cols[k]} --compile \ + --topk=$topk \ + --verbose --swin_setting=${setting[k]} --swin \ + --recompute --micro_batch_size=${bs[j]} --global_batch_size=32 & + + if [ "$q" -eq "$((num_gpus-1))" ]; then + wait + fi + + done +done + +wait + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/swin/swintp.sh b/autodist/script/swin/swintp.sh new file mode 100755 index 00000000..9ce957fb --- /dev/null +++ b/autodist/script/swin/swintp.sh @@ -0,0 +1,66 @@ +#!/bin/bash --login +start_time=$(date +%s) + +memory_constraint=35 # in GB +memory_granularity=1 # in byte +connect_type='NV2' +save_folder="swin" +topk=1 +cache_folder1="autodist/cost_model/comm/__pycache__" +cache_folder2="autodist/cost_model/__pycache__" +comm_dev=(2 4) + +for ((i=0; i<${#comm_dev[*]}; i=i+1)); do + torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type +done + +if [ ! -d $save_folder ] +then + mkdir $save_folder +fi + +# We run all cases in the machine with 4 gpus. + +bs=(1 2 4 8 16 32) + +mesh_cols=(1 2 4) +mesh_rows=(1 1 1) +setting=('toy' '355M' '1.8B') + +for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do + for ((j=0; j<${#bs[*]}; j=j+1)); do + + echo "start runtime Swin setting=${setting[k]} bs=${bs[j]}" + if [ -d $cache_folder1 ] + then + echo "Removing $cache_folder1 directory..." + rm -r $cache_folder1 + rm -r $cache_folder2 + else + echo "$cache_folder1 directory not found" + fi + + LOG_TRANSFORM=1 SINGLE_DEV_MODE=1 python main.py --is_train \ + --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity \ + --save_folder=$save_folder --connect_type=$connect_type \ + --mesh_row=1 --mesh_col=${mesh_cols[k]} --compile \ + --topk=$topk \ + --verbose --swin_setting=${setting[k]} --swin \ + --micro_batch_size=${bs[j]} \ + --global_batch_size=32 --recompute --adaptive_recom + + torchrun --master_port=30001 --nnodes=${mesh_rows[k]} \ + --nproc_per_node=${mesh_cols[k]} main.py --is_train \ + --memory_constraint=$memory_constraint \ + --memory_granularity=$memory_granularity --save_folder=$save_folder \ + --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} \ + --plan_idx=0 --iter_num=2 --warm_num=1 --micro_batch_size=${bs[j]} \ + --global_batch_size=32 --swin_setting=${setting[k]} --swin \ + --recompute + done +done + +end_time=$(date +%s) +cost_time=$[ $end_time - $start_time] +echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/nnscaler/autodist/__init__.py b/nnscaler/autodist/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py new file mode 100644 index 00000000..63377b01 --- /dev/null +++ b/nnscaler/autodist/apis.py @@ -0,0 +1,277 @@ +from .spmd_solver import calc_optimal_spmd_plan +from .pipeline_solver import calc_optimal_pp_plan +from .autodist_config import AutoDistConfig +from .model_graph import ModelGraph, estimate_mem_lower_bound +from .descs import * +from .util import get_node_arch, replica, partition_node + +from nnscaler.graph import IRGraph +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.graph.function import IRDimops +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.schedule.predefined import PredefinedSched + +import json +import os +import logging +import time +from typing import Dict, List +from pathlib import Path +from collections import defaultdict + +_logger = logging.getLogger(__name__) + +__all__ = [ + 'parallelize_graph', +] + + +def check_env(autodist_config: AutoDistConfig): + error_msg = ' does not exist, please run \'python autodist/build_env.py\' first' + autodist_dir = autodist_config.profile_dir + '/' + get_node_arch() + assert os.path.exists(autodist_dir), autodist_dir + error_msg + comm_path = autodist_dir + '/comm' + assert os.path.exists(comm_path), comm_path + error_msg + + +def pre_estimate_mem(graph: ModelGraph): + ''' + Estimate a rough lower bound of memory consumption per device. Exit if the model is too large + for allocated resources. + ''' + + def to_mb(size): + return size // 1024 // 1024 + + def to_gb(size): + return to_mb(size) // 1024 + + # calculate sizes of activations, buffers and parameters, exit if the model is + # too large for allocated resources + param_mem, buffer_mem, activation_mem = graph.query_mem(0, graph.op_num - 1) + _logger.info( + f'param mem {to_mb(param_mem)} MB, buff mem {to_mb(buffer_mem)} MB, activation mem {to_mb(activation_mem)} MB' + ) + plan_ngpus = graph.autodist_config.mesh_desc.ngpus + if graph.autodist_config.zero_stage == 1: + zero_group_size = graph.autodist_config.world_size // graph.autodist_config.zero_ngroups + elif graph.autodist_config.zero_stage == 0: + zero_group_size = plan_ngpus + else: + raise RuntimeError( + f'invalid zero stage {graph.autodist_config.zero_stage}') + min_single_dev_mem = estimate_mem_lower_bound( + param_mem=param_mem, + buffer_mem=buffer_mem, + activation_mem=activation_mem, + plan_ngpus=plan_ngpus, + zero_group_size=zero_group_size, + cfg=graph.autodist_config, + ) + min_single_dev_mem += graph.recompute_mem + _logger.info( + f'estimated minimum memory per device {to_mb(min_single_dev_mem)} MB') + mem_constraint = graph.autodist_config.memory_constraint + if min_single_dev_mem > mem_constraint * 1024 * 1024 * 1024: + raise RuntimeError( + f'est min mem: {to_gb(min_single_dev_mem)} GB vs mem constraint: {mem_constraint} GB, ' + + 'model is too large for current resources, try to ' + + 'reduce batch size, add more devices or increase zero group size') + + +def calc_parallel_plan(graph: IRGraph, + autodist_config: AutoDistConfig) -> PipelineSearchOutput: + _logger.info(autodist_config) + check_env(autodist_config) + + autodist_graph = ModelGraph(ir_graph=graph, autodist_config=autodist_config) + pre_estimate_mem(autodist_graph) + + recompute_groups = autodist_graph.recompute_groups + recompute_groups = [ + [node.cid for node in group] for group in recompute_groups + ] + recompute_mem = autodist_graph.recompute_mem / 1024 / 1024 / 1024 + + if autodist_config.pipeline: + pp_out = calc_optimal_pp_plan(autodist_graph, autodist_config) + else: + pp_out = calc_optimal_spmd_plan(autodist_graph, autodist_config) + pp_out.desc.recompute_groups = recompute_groups + pp_out.stage_mems = [mem + recompute_mem for mem in pp_out.stage_mems] + return pp_out + + +def parallelize_graph(graph: IRGraph, + autodist_config: AutoDistConfig) -> IRGraph: + segments: List[IRSegment] = graph.select(ntype=IRSegment) + if segments: + raise RuntimeError('assume there is no segment in the graph') + + if autodist_config.load_plan_path: + _logger.info(f'load plan from {autodist_config.load_plan_path}') + with open(autodist_config.load_plan_path, 'r') as f: + search_out_json = json.load(f) + search_out = PipelineSearchOutput.from_json(search_out_json) + else: + _logger.info(f'save plan to {autodist_config.save_plan_path}') + compile_start_time = time.time() + search_out = calc_parallel_plan(graph, autodist_config) + compile_cost_time = time.time() - compile_start_time + + with open(autodist_config.save_plan_path, 'w') as f: + json.dump(search_out.to_json(), f, indent=2) + + _logger.info(f'use plan with e2e time/s {search_out.e2e_time}s,' + + f'stage mems/GB {search_out.stage_mems}, ' + + f'stage all times/s {search_out.stage_all_times}, ' + + f'stage comp times/s {search_out.stage_comp_times}') + pp_desc = search_out.desc + + cid2node: Dict[int, IRFwOperation] = dict() + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + cid2node[node.cid] = node + + # set recompute groups + for group in pp_desc.recompute_groups: + nodes = [cid2node[cid] for cid in group] + graph.recompute(nodes) + + # graph staging + if len(pp_desc.spmd_descs) > 1: + # add multiref for shared parameters across stages + shared_param2stage_info = defaultdict(dict) + for ftensor in graph.attributes(): + if not ftensor.is_param(): + continue + for ctensor, consumer in zip(graph.ctensors(ftensor), + graph.consumers(ftensor)): + if ctensor.grad is None: + continue + for stage_idx, stage_desc in enumerate(pp_desc.spmd_descs): + if consumer.cid in stage_desc.partition_descs: + if len(stage_desc.partition_descs[ + consumer.cid].desc) != 1: + raise RuntimeError( + f'node {consumer} has more than one partition dim' + ) + (p_idx, p_dim), p_num = stage_desc.partition_descs[ + consumer.cid].desc[0] + if p_idx != -1 and consumer.inputs()[p_dim] == ftensor: + raise RuntimeError( + f'node {consumer} has partitioned input {ftensor}' + ) + is_replicated = p_idx == -1 + if stage_idx not in shared_param2stage_info[ftensor]: + shared_param2stage_info[ftensor][stage_idx] = [] + shared_param2stage_info[ftensor][stage_idx].append( + is_replicated) + + for ftensor, stage_info in shared_param2stage_info.items(): + if len(stage_info) == 1: + continue + # special case: all stages have only one gpu + stage_idxs = list(stage_info.keys()) + stage_sizes = [ + pp_desc.spmd_descs[i].mesh_desc.ngpus for i in stage_idxs + ] + if all([s == 1 for s in stage_sizes]): + continue + # check whether all partitioned + # In AutoDist, shared parameters are not allowed to be partitioned. + # As a result, the related operator is replicated or in data parallel. + has_replicated = False + for stage_idx, replicate_info in stage_info.items(): + if any(replicate_info): + has_replicated = True + break + if has_replicated: + _logger.info(f'add multiref for shared param {ftensor}') + graph.multiref(ftensor) + + stages = [] + for spmd_desc in pp_desc.spmd_descs: + stage = [] + for cid in spmd_desc.partition_descs: + stage.append(cid2node[cid]) + stages.append(stage) + graph.staging([s[0] for s in stages]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + else: + stages = [graph] + + # add multiref to a tensor when + # 1. it is not a grad tensor + # 2. it has more than one consumers + # 3. consumers are different operators or in different partitions + for stage, spmd_desc in zip(stages, pp_desc.spmd_descs): + for ftensor in stage.full_tensors(): + if ftensor.is_grad(): + continue + if len(stage.consumers(ftensor)) <= 1: + continue + consumers = stage.consumers(ftensor) + splits = set() + for consumer in consumers: + if consumer.cid in spmd_desc.partition_descs: + node_desc = spmd_desc.partition_descs[consumer.cid].desc + if len(node_desc) != 1: + raise RuntimeError( + f'node {consumer} has more than one partition desc') + (p_idx, p_dim), p_num = node_desc[0] + else: + _logger.warning( + f'node {consumer} is not in any partition desc') + p_idx, p_dim, p_num = -1, -1, spmd_desc.mesh_desc.ngpus + repr_str = f'{consumer.signature}-{p_idx}-{p_dim}-{p_num}' + splits.add(repr_str) + if len(splits) > 1: + _logger.debug(f'add multiref {consumers}') + stage.multiref(ftensor) + + # partition and assign nodes to devices + # TODO(yizhu1): network topo aware device map + offset = 0 + for spmd_desc, stage in zip(pp_desc.spmd_descs, stages): + cur_ngpus = spmd_desc.mesh_desc.ngpus + dev = [offset + i for i in range(cur_ngpus)] + offset += cur_ngpus + for node in stage.nodes(): + if isinstance(node, IRFwOperation): + if isinstance( + node, + (IRGraphAnchor, IRPyFunc)) or node.name == 'multiref': + continue + if node.cid in spmd_desc.partition_descs: + p_desc = spmd_desc.partition_descs[node.cid] + partition_node(node, graph, dev, p_desc) + if isinstance(node, IRDimops): + _logger.info( + f'apply {node} with {node.anno} at {node.comment}, plan: {p_desc}' + ) + else: + _logger.info( + f'replicate non-IRDimops {node.signature} with {node.comment}' + ) + else: + replica(graph, node, dev) + _logger.info( + f'NOT included in plan, replicate {node.signature} with {node.comment}' + ) + + for dl in graph.select(ntype=IRDataOperation): + replica(graph, dl, devs=list(range(autodist_config.mesh_desc.ngpus))) + + # apply 1f1b schedule + if len(stages) > 1: + PredefinedSched.sched_1f1b( + graph, + autodist_config.update_freq, + len(stages), + ) + + return graph diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py new file mode 100644 index 00000000..2f2c5710 --- /dev/null +++ b/nnscaler/autodist/autodist_config.py @@ -0,0 +1,224 @@ +from pathlib import Path +import argparse +import logging +from .descs import MeshDesc + +_logger = logging.getLogger(__name__) + + +def _validate_file_path(path: str): + if not Path(path).exists(): + raise ValueError(f'file path {path} does not exist') + + +def _validate_dir_path(path: str): + if not Path(path).is_dir(): + raise ValueError(f'path {path} is not a directory') + + +class AutoDistConfig: + r""" + AutoDistConfig is the configuration for AutoDist. It contains the following fields: + + - task_name (`str`, *optional*, defaults to `'default'`): + The name of the current task to distinguish runs. + - consider_mem (`bool`, *optional*, defaults to `True`): + Whether to consider memory when searching plans. + - opt_resident_coef (`int`, *optional*, defaults to `2`): + The coefficient of the optimizer resident state compare with the model weight size. + For example: training a fp32 model with adam optimizer, movement1 and movement2 will be saved in the optimizer state, + movement1 and movement2 are fp32 and have the same size with model weight, + so the opt_resident_coef is (1 + 1) = 2. + Common cases: + - fp32 training w/ adam: (1 + 1) (fp32 movement1 + fp32 movement2) + - fp16 & bf16 training w/ adam: (2 + 2 + 2) (fp32 movement1 + fp32 movement2 + fp32 weight) + - fp16 & bf16 training w/ memory efficient adam: (2 + 2) (fp32 movement1 + fp32 movement2) + - opt_transient_coef (`int`, *optional*, defaults to `0`): + The coefficient of the optimizer transient state compare with the model weight size. + For example: training a fp16 model with adam optimizer, fp16 gradient will transient convert to fp32, + so the opt_transient_coef is 2. + Common cases: + - fp32 training w/ adam: 0 + - fp16 & bf16 training w/ adam w/o inkernal cast: (2) (fp32 gradient) + - fp16 & bf16 training w/ memory efficient adam w/o inkernal cast: (2 + 2) (fp32 weight + fp32 gradient) + - partition_constraints_path (`str`, *optional*, defaults to `''`): + The path to the partition constraints file. Details can be found in docs/solver_interface/partition_constraints.md + - profile_dir (`str`, *optional*, defaults to `~/.autodist`): + The directory to store the profiling results. + - load_plan_path (`str`, *optional*, defaults to `''`): + The path to the plan file to load. If specified, the plan will be loaded from the file instead of searching. + - save_plan_path (`str`, *optional*, defaults to `'./{task_name}.json'`): + The path to the plan file to save. + - topk (`int`, *optional*, defaults to `20`): + The number of plans to generate for robustness. + - zero_stage (`int`, *optional*, defaults to `0`): + The zero stage, see https://arxiv.org/abs/1910.02054 for details. Currently only support zero stage 0 and 1. + - zero_ngroups (`int`, *optional*, defaults to `1`): + The number of zero groups to balance memory usage and communication cost. The larger the number, + more memory will be used and less communication cost will be incurred. + - is_train (`bool`, *optional*, defaults to `True`): + Whether the model is for training or inference. + - mesh_row (`int`, *optional*, defaults to `1`): + The number of available nodes. + - mesh_col (`int`, *optional*, defaults to `1`): + The number of available devices in each node. + - recompute_modules (`str`, *optional*, defaults to `''`): + The module names to recompute, separated by `,`. For example, `module1,module2`. + - memory_constraint (`float`, *optional*, defaults to `32`): + The memory constraint in each device in GB. + - memory_granularity (`int`, *optional*, defaults to `1`): + The memory granularity in Byte. + - micro_batch_size (`int`, *optional*, defaults to `1`): + The micro batch size. + - update_freq (`int`, *optional*, defaults to `1`): + The update frequency (micro batch size x update freq = real batch size). + - world_size (`int`, *optional*, defaults to `1`): + The total number of devices. (mesh_row x mesh_col x scale_factor = world_size) + - nproc (`int`, *optional*, defaults to `1`): + The number of processes in pipeline parallelism search. + - ignore_small_tensor_threshold (`int`, *optional*, defaults to `1`): + The tensor size threshold to ignore. + - verbose (`bool`, *optional*, defaults to `False`): + Whether to print verbose information. + - re_profile (`bool`, *optional*, defaults to `False`): + If set to `True`, the computation profiling results will be overridden. + - pipeline (`bool`, *optional*, defaults to `False`): + Whether to use pipeline parallelism or tensor parallelism. + - pipeline_pivots (`str`, *optional*, defaults to `''`): + The module names to pivot the pipeline, separated by `,`. For example, if `module1,module2` + is specified, stages searched by pipeline solver only start from either `module1` or `module2`. + - max_pipeline_bubble_ratio (`float`, *optional*, defaults to `0.4`): + The maximum bubble ratio in pipeline parallelism. The higher the ratio, the more bubbles will be allowed, + the larger search space will be explored. + - max_pipeline_unbalance_ratio (`float`, *optional*, defaults to `0.5`): + The maximum unbalance ratio in pipeline parallelism. The higher the ratio, the more unbalance is required, + the smaller search space will be explored. + - solver (`str`, *optional*, defaults to `'dp'`): + The solver to use in spmd parallelism. Currently only support `'dp'` (dynamic programming) and `'ilp'` (integer linear programming). + """ + + def __init__(self, + task_name='default', + consider_mem=True, + opt_resident_coef=2, + opt_transient_coef=0, + partition_constraints_path='', + profile_dir=str(Path.home()) + '/.autodist', + load_plan_path='', + save_plan_path='', + topk=20, + zero_stage=0, + zero_ngroups=1, + is_train=True, + mesh_row=1, + mesh_col=1, + recompute_modules='', + memory_constraint=32, + memory_granularity=1, + micro_batch_size=1, + update_freq=1, + world_size=1, + nproc=1, + ignore_small_tensor_threshold=1, + verbose=False, + re_profile=False, + pipeline=False, + pipeline_pivots='', + max_pipeline_bubble_ratio=0.4, + max_pipeline_unbalance_ratio=0.5, + solver='dp', + **kwargs): + self.pc_path = partition_constraints_path + self.profile_dir = profile_dir + self.topk = topk + self.task_name = task_name + self.load_plan_path = load_plan_path + self.save_plan_path = save_plan_path + + self.consider_mem = consider_mem + self.zero_stage = zero_stage + self.zero_ngroups = zero_ngroups + self.opt_resident_coef = opt_resident_coef + self.opt_transient_coef = opt_transient_coef + self.is_train = is_train + self.mesh_desc = MeshDesc(mesh_row, mesh_col) + self.ngpus = self.mesh_desc.row * self.mesh_desc.col + self.recompute_modules = recompute_modules + # from GB to Byte + self.memory_constraint = int(memory_constraint * 1024 * 1024 * 1024) + self.memory_granularity = memory_granularity + self.micro_batch_size = micro_batch_size + self.update_freq = update_freq + self.world_size = world_size + self.nproc = nproc + + self.ignore_small_tensor_threshold = ignore_small_tensor_threshold + self.verbose = verbose + self.re_profile = re_profile + self.pipeline = pipeline + self.pipeline_pivots = pipeline_pivots + self.max_pipeline_bubble_ratio = max_pipeline_bubble_ratio + self.max_pipeline_unbalance_ratio = max_pipeline_unbalance_ratio + self.solver = solver + + ignored_keys = list(kwargs.keys()) + if ignored_keys: + warning_msg = f'autodist config got unknown config keys: {ignored_keys}' + _logger.warning(warning_msg) + + self._validate_config() + + def _validate_config(self): + if self.pc_path: + _validate_file_path(self.pc_path) + + _validate_dir_path(self.profile_dir) + + if self.pipeline: + if self.max_pipeline_bubble_ratio <= 0 or self.max_pipeline_bubble_ratio >= 1: + raise ValueError( + f'max pipeline bubble ratio {self.max_pipeline_bubble_ratio} must be in (0, 1)' + ) + if self.max_pipeline_unbalance_ratio <= 0 or self.max_pipeline_unbalance_ratio >= 1: + raise ValueError( + f'max pipeline unbalance ratio {self.max_pipeline_unbalance_ratio} must be in (0, 1)' + ) + + if self.topk < 1: + raise ValueError(f'topk {self.topk} must be positive') + + if not self.task_name: + raise RuntimeError('task name cannot be empty') + + if self.load_plan_path: + _validate_file_path(self.load_plan_path) + if self.save_plan_path: + raise ValueError( + 'cannot specify both load plan path and save plan path') + else: + self.save_plan_path = self.load_plan_path + + if self.save_plan_path: + _validate_dir_path(Path(self.save_plan_path).parent) + else: + self.save_plan_path = f'./{self.task_name}.json' + + if self.zero_stage not in [0, 1]: + raise ValueError(f'zero stage {self.zero_stage} must be 0 or 1') + else: + if self.zero_stage == 1: + if self.world_size % self.zero_ngroups != 0: + raise ValueError( + f'world size {self.world_size} must be divisible by zero num groups {self.zero_ngroups}' + ) + scale_factor = self.world_size // self.mesh_desc.ngpus + if scale_factor % self.zero_ngroups != 0: + raise ValueError( + f'world size {self.world_size} must be divisible by zero num groups {self.zero_ngroups}' + ) + + if not self.solver in ['dp', 'ilp']: + raise ValueError(f'solver {self.solver} must be dp or ilp') + + def __repr__(self): + return f'{self.__class__.__name__} {self.__dict__}' diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py new file mode 100644 index 00000000..a5ad5c92 --- /dev/null +++ b/nnscaler/autodist/cost_database.py @@ -0,0 +1,502 @@ +from typing import List, Tuple, Union +import numpy as np +import json +import os +from os import listdir +from pathlib import Path +import logging + +from nnscaler.graph import IRGraph +from nnscaler.ir.cten import IRTensor +from nnscaler.profiler.database import ProfileDataBase, ProfiledMetrics +from nnscaler.algorithm.ops.dimops import gen_partitions +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.dimops import DimopSplit, IRDimops + +from .util import get_node_arch +from .autodist_config import AutoDistConfig + +_logger = logging.getLogger(__name__) + + +def _piecewise_estimator(xs: List[float], ys: List[float], x: float) -> float: + """ + Piecewise linear estimator. + + Args: + xs: x coordinates of the points. + ys: y coordinates of the points. + x: x coordinate of the query point. + + Returns: + y coordinate of the query point. + """ + if x <= xs[0]: + return ys[0] + # Communication profile results vary across a large data range, e,g. x1 < x2 but y1 > y2. + # To make sure the returned time is always positive, using linear approximation when the + # message size is very large (>512MB). + if x >= xs[-1]: + assert xs[-1] > 0 and ys[ + -1] > 0, f'Unexpected val x={x}, xs={xs}, ys={ys}' + if xs[-1] < 512: + _logger.warning( + f'Estimation may be inaccurate for x={x} MB, xs={xs[-1]} MB, ys={ys[-1]} s' + ) + return x / xs[-1] * ys[-1] + for i in range(len(xs) - 1): + if xs[i] <= x < xs[i + 1]: + return ys[i] + (x - xs[i]) * (ys[i + 1] - ys[i]) / (xs[i + 1] - + xs[i]) + raise RuntimeError(f'x={x}, xs={xs}, ys={ys}, should not reach here') + + +class CostDatabase: + + def __init__(self, graph: IRGraph, config: AutoDistConfig): + self.comm_info = {} + + self.graph = graph + self.autodist_config = config + + self.profile_dir = Path(config.profile_dir) / get_node_arch() + self.db = ProfileDataBase() + self.comp_profile_path = self.profile_dir / 'comp' + if not self.comp_profile_path.exists(): + self.comp_profile_path.mkdir(parents=True) + self.db.load_ops(self.comp_profile_path) + + comm_dir = self.profile_dir / 'comm' + if not comm_dir.exists(): + raise RuntimeError( + f'{comm_dir} does not exist, please run \'python autodist/build_env.py\' first' + ) + for fname in listdir(comm_dir): + with open(comm_dir / fname, 'r') as f: + self.comm_info[fname] = json.load(f) + + self.memory_granularity = self.autodist_config.memory_granularity + self.ignore_small_tensor_threshold = self.autodist_config.ignore_small_tensor_threshold + + def profile_comp(self, partition_degree: int): + visited_nodes = set() + for node in self.graph.select(ntype=IRFwOperation): + if isinstance(node, (IRGraphAnchor, IRPyFunc)): + continue + hash_code = node.signature + ' : ' + self.db._serialize(node) + if hash_code in visited_nodes: + continue + if hasattr(node, 'anno'): + partition_nodes = gen_partitions(node, + partition_degree, + base=partition_degree, + depth=1) + else: + _logger.info(f'only profile replicated for {node}') + partition_nodes = [node] + for partition_node in partition_nodes: + # the returned schema may change over time, we re-profile + # if encountered an exception + try: + profiled_metrics: ProfiledMetrics = self.db.profile( + partition_node, + override=self.autodist_config.re_profile) + except Exception: + profiled_metrics: ProfiledMetrics = self.db.profile( + partition_node, override=True) + self.db.dump_op(self.comp_profile_path, + node.signature, + override=True) + visited_nodes.add(hash_code) + + def exist(self, node: IRFwOperation) -> bool: + return self.db.exist(node) + + def query_profiled_metrics( + self, obj: Union[IRFwOperation, 'CubeOperator', 'OpPartition'] + ) -> ProfiledMetrics: + node = obj if isinstance(obj, IRFwOperation) else obj.ir_cell + if not self.exist(node): + raise RuntimeError(f'cannot find {node} in the profile database') + return self.db.query(node) + + def round(self, mem): + if mem % self.memory_granularity == 0: + return mem + else: + return (mem + self.memory_granularity + ) // self.memory_granularity * self.memory_granularity + + def filter_then_sum(self, tensor_sizes: Tuple[int], mask=[]): + # assert len(tensor_sizes) == len( + # mask), f'len(tensor_sizes) is not equal to len(masks)' + masked_sizes = [i * j for i, j in zip(tensor_sizes, mask)] + return sum(masked_sizes) + + def get_mems(self, op_partition): + memory_types = ['train', 'infer', 'input', 'param', 'buffer'] + memory_results = {} + for memory_type in memory_types: + if isinstance(op_partition.operator.ir_cell, IRDimops): + mem = self.query_single_mem(op_partition, memory_type) + else: + mem = 0 + memory_results[memory_type] = mem + return memory_results + + def get_mem_and_buffer(self, op_partition, is_train: bool, stage_num: int): + """ + Get the memory consumption and buffer memory consumption of a partition option. + + Args: + op_partition: the partition option to be calculated + + Returns: + node_mem: the memory consumption of the partition option + node_buffer: the buffer memory consumption of the partition option + activation_mem: the activation memory consumption of the partition option + opt_transient_mem: the optimizer transient memory consumption of the partition option + """ + memory_results = self.get_mems(op_partition) + activation_mem = memory_results['train'] + if not self.autodist_config.zero_stage in [0, 1]: + raise RuntimeError( + f'invalid zero stage {self.autodist_config.zero_stage}') + # estimate optimizer memory consumption for training. + # no gradient no memory consumption, + # weight_mem should be 0 when require_grad is false. + opt_resident_mem, opt_transient_mem = 0, 0 + if is_train and memory_results['param'] > 0: + if self.autodist_config.zero_stage == 0: + weight_mem = memory_results['param'] + else: + # if zero-1 is used, we assume the full weight is distributed equally + # among all devices + weight_mem = self.query_single_mem(op_partition, 'full_weight') + opt_resident_mem = self.autodist_config.opt_resident_coef * weight_mem + opt_transient_mem = self.autodist_config.opt_transient_coef * weight_mem + if self.autodist_config.zero_stage == 1: + if op_partition.is_replicated(): + assert self.autodist_config.world_size % self.autodist_config.ngpus == 0 + scale_factor = self.autodist_config.world_size // self.autodist_config.ngpus + divisor = scale_factor // self.autodist_config.zero_ngroups + else: + assert self.autodist_config.world_size % self.autodist_config.zero_ngroups == 0 + divisor = self.autodist_config.world_size // self.autodist_config.zero_ngroups + opt_resident_mem = opt_resident_mem // divisor + opt_transient_mem = opt_transient_mem // divisor + + # optimizer state + saved activation tensors for backward + param + # + gradients + buffer tensors (has deduplicated with the saved tensors) + node_mem = opt_resident_mem + memory_results[ + 'train'] + 2 * memory_results['param'] + memory_results['buffer'] + node_mem = node_mem + (stage_num - 1) * activation_mem \ + if is_train else memory_results['param'] + node_buffer = max(memory_results.values()) \ + if is_train else memory_results['infer'] + + if node_mem != 0: + + def to_mb(x): + return x / 1024 / 1024 + + _logger.debug( + f'{op_partition.operator.ir_cell.cid}, {op_partition.ir_cell}, ' + + f'node mem: {to_mb(node_mem)} MB, ' + + f'activation mem: {to_mb(activation_mem)} MB, ' + + f'optimizer transient mem: {to_mb(opt_transient_mem)} MB') + + return node_mem, node_buffer, activation_mem, opt_transient_mem + + def query_single_mem(self, obj, memory_type, round=True) -> int: + """ + Query memory size of a single operator or partition. + OpPartition represents one partition of an operator. + CubeOperator represents the full operator before partitioning. + + 'input' is the total bytes of the input tensors excluding parameter and buffer tensors. + 'param' is the total bytes of the parameter tensors. + 'buffer' is the total bytes of the buffer tensors. + 'infer' is the peak bytes during op inference. + 'train' is the total bytes of the saved activation tensors for backward. + 'full_weight' is the total bytes of the weight of the full operator. + + Args: + obj: OpPartition or CubeOperator + memory_type: 'input', 'param', 'infer', 'train', 'full_weight' + round: whether to round the memory size up to the nearest multiple of memory_granularity + + Returns: + memory size in bytes + """ + from .op_partition import OpPartition + from .cube_operator import CubeOperator + if isinstance(obj, OpPartition): + masks = self.gen_masks(obj.operator) + else: + assert isinstance(obj, CubeOperator) + masks = self.gen_masks(obj) + if memory_type == 'full_weight' and isinstance(obj, OpPartition): + profiled_metrics = self.query_profiled_metrics(obj.operator) + else: + profiled_metrics = self.query_profiled_metrics(obj) + + if memory_type == 'input': + mask = masks['input'] + ret = self.filter_then_sum(profiled_metrics.in_mem_info, mask) + elif memory_type == 'param': + mask = masks['param'] + ret = self.filter_then_sum(profiled_metrics.param_mem_info, mask) + elif memory_type == 'buffer': + mask = masks['buffer'] + ret = self.filter_then_sum(profiled_metrics.buffer_mem_info, mask) + elif memory_type == 'infer': + ret = profiled_metrics.infer_memory + elif memory_type == 'train': + mask = masks['train'] + ret = self.filter_then_sum(profiled_metrics.train_mem_info, mask) + elif memory_type == 'full_weight': + mask = masks['param'] + ret = self.filter_then_sum(profiled_metrics.param_mem_info, mask) + else: + raise ValueError( + f'Invalid memory_type {memory_type} provided. Choose from: ' + + "'input', 'param', 'buffer', 'infer', 'train', 'full_weight'.") + if round: + return self.round(ret) + else: + return ret + + def query_comp_time(self, + op_or_partition: Union['CubeOperator', 'OpPartition'], + recompute: bool = False, + is_train: bool = True): + profiled_metrics = self.query_profiled_metrics(op_or_partition) + if not is_train: + return profiled_metrics.fw_span / 1000 + if recompute: + return (profiled_metrics.fw_span + profiled_metrics.bw_span + + profiled_metrics.fw_span) / 1000 + else: + return (profiled_metrics.fw_span + profiled_metrics.bw_span) / 1000 + + def primitive_to_cost(self, dev_num: int, byte_size: int, primitive: str): + if byte_size == 0: + return 0 + size_mb = byte_size / 1024 / 1024 + device_setting = f'intra_{dev_num}.json' + sizes_in_mb, times_in_s = self.comm_info[device_setting][primitive] + est_time = _piecewise_estimator(sizes_in_mb, times_in_s, size_mb) + assert est_time >= 0, f'{primitive} {dev_num} comm size: {size_mb} MB, est time: {est_time} s' + return est_time + + def calc_weight_update_time(self, cur_partition) -> float: + """ + Calculate communication cost for weight update. Currently cost is evaluated + by allreduce. + + Args: + cur_partition: one partition option of the operator + + Returns: + communication cost in seconds + """ + # partition_dims and partition_nums represent a concrete partition option of a node + # if the element in partition_dims is -1, it means the node is replicated. + # currently, len of partition_dims is 1, we only support partitioning one dimension + partition_dims = cur_partition.partition_dims + partition_nums = cur_partition.partition_nums + # TODO: remove this assertion, support partitioning multiple dimensions + assert len( + partition_dims + ) == 1, f'expect len(partition_dims) == 1, got {len(partition_dims)}' + full_weight_mem = self.query_single_mem(cur_partition, + 'full_weight', + round=False) + partitioned_weight_mem = self.query_single_mem(cur_partition, + 'param', + round=False) + + if partitioned_weight_mem == 0: + return 0 + if full_weight_mem % partitioned_weight_mem == 0: + mem_weight_spatial_num = full_weight_mem // partitioned_weight_mem + else: + # when setting memory granularity > 1, possible that the two numbers are not divisible + mem_weight_spatial_num = (full_weight_mem + partitioned_weight_mem + ) // partitioned_weight_mem + + replica_num = 1 + for i, partition_dim_name in enumerate(partition_dims): + if partition_dim_name == -1: + replica_num *= partition_nums[i] + all_num = 1 + for num in cur_partition.partition_nums: + all_num *= num + weight_update_num = all_num // (mem_weight_spatial_num * replica_num) + if weight_update_num == 1: + return 0 + comm_time = self.primitive_to_cost(dev_num=weight_update_num, + primitive='all reduce', + byte_size=partitioned_weight_mem) + + return comm_time + + def estimate_comm_cost(self, src_p, dst_p, is_forward) -> float: + """ + Estimate communication cost between src partition and dst partition. + Currently the communication is only for activation tensors. + + Args: + src_p: the partition of source operator + dst_p: the partition of destination operator + is_forward: whether the communication is for only forward pass + or only backward pass + + Returns: + communication cost in seconds + """ + assert len(src_p.partition_nums) == 1 and len(dst_p.partition_nums) == 1 + + def comm_cost(tensor: IRTensor, num_devices: int, src_split: DimopSplit, + dst_split: DimopSplit, dst_replica: bool, + is_forward: bool): + """ + Calculate communication cost for a single tensor. + Note for data parallel, we don't consider allreduce cost as it + will only be performed at the last of iteration. + + Args: + tensor: the tensor to be communicated + num_devices: number of devices + src_split: the split info of the tensor in the source operator + dst_split: the split info of the tensor in the destination operator + dst_replica: whether the destination operator is replicated + is_forward: whether the communication is for only forward pass or + only backward pass + + Returns: + communication cost in seconds + """ + assert not dst_split.isV() + assert not tensor.is_attr() + byte_size = tensor.byte_size() + + def helper(primitive: str): + return self.primitive_to_cost(num_devices, byte_size, primitive) + + # R: replicated, V: value split, D: dim split + if src_split.isR(): + if dst_split.isR(): + if dst_replica: + return 0.0 + else: + # identity-allreduce + if is_forward: + return 0.0 + else: + return helper('all reduce') + elif dst_split.isD(): + # split-allgather + if is_forward: + return 0.0 + else: + return helper('all gather') + if src_split.isV(): + if dst_split.isR(): + # allreduce-identity + if dst_replica: + if is_forward: + return helper('all reduce') + else: + return 0.0 + else: + # allreduce-allreduce + return helper('all reduce') + elif dst_split.isD(): + if is_forward: + return helper('reduce scatter') + else: + return helper('all gather') + if src_split.isD(): + # allgahter-reducescatter or allgather-split + if dst_split.isR(): + if is_forward: + return helper('all gather') + else: + if dst_replica: + return 0.0 + else: + return helper('reduce scatter') + # all2all-all2all or identity-identity + if dst_split.isD(): + return 0.0 if src_split.dims == dst_split.dims else helper( + 'all to all') + raise NotImplementedError( + f'Unknown split type: {src_split} -> {dst_split}') + + src_p_dim, src_p_num = src_p.partition_dims[0], src_p.partition_nums[0] + dst_p_dim, dst_p_num = dst_p.partition_dims[0], dst_p.partition_nums[0] + assert src_p_num == dst_p_num + src_idx, src_dim = src_p.operator.dim_id2pos(src_p_dim) + dst_idx, dst_dim = dst_p.operator.dim_id2pos(dst_p_dim) + rule_src, rule_dst = None, None + if src_idx != -1: + rule_src = src_p.operator.ir_cell.algorithms('dim').infer( + src_idx, src_dim, src_p_num) + if dst_idx != -1: + rule_dst = dst_p.operator.ir_cell.algorithms('dim').infer( + dst_idx, dst_dim, dst_p_num) + cost = 0.0 + for i, src_t in enumerate(src_p.operator.ir_cell.outputs()): + for j, dst_t in enumerate(dst_p.operator.ir_cell.inputs()): + if src_t == dst_t: + if not is_forward and not src_t.requires_grad: + # if the activation does not require grad, + # then no backward communication. + cost += 0.0 + else: + cost += comm_cost( + src_t, src_p_num, + rule_src.outputs()[i] + if rule_src is not None else DimopSplit(r=True), + rule_dst.inputs()[j] if rule_dst is not None else + DimopSplit(r=True), dst_idx == -1, is_forward) + break + return cost + + def gen_masks(self, op): + masks = {} + profiled_metrics = self.query_profiled_metrics(op) + inputs = profiled_metrics.in_mem_info + param = profiled_metrics.param_mem_info + buffer = profiled_metrics.buffer_mem_info + train_m = profiled_metrics.train_mem_info + + def helper(mems): + return [ + 0 if mem < self.ignore_small_tensor_threshold else 1 + for mem in mems + ] + + param_mask = helper(param) + for idx in op.omit_param_idx: + param_mask[idx] = 0 + train_m_mask = helper(train_m) + for idx in op.omit_train_idx: + train_m_mask[idx] = 0 + buffer_mask = helper(buffer) + for idx in op.omit_buffer_idx: + buffer_mask[idx] = 0 + # no need to deduplicate inputs, because input tensors are transient. + # the saved input tensors for backward have been considered in train_m. + + masks = { + 'input': helper(inputs), + 'param': param_mask, + 'train': train_m_mask, + 'buffer': buffer_mask, + } + return masks diff --git a/nnscaler/autodist/csrc/solver.cpp b/nnscaler/autodist/csrc/solver.cpp new file mode 100644 index 00000000..26cde49e --- /dev/null +++ b/nnscaler/autodist/csrc/solver.cpp @@ -0,0 +1,799 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(unsigned int n = std::thread::hardware_concurrency()); + + template void enqueue(F &&f); + void waitFinished(); + ~ThreadPool(); + + unsigned int getProcessed() const { return processed; } + +private: + std::vector workers; + std::deque> tasks; + std::mutex queue_mutex; + std::condition_variable cv_task; + std::condition_variable cv_finished; + std::atomic_uint processed; + unsigned int busy; + bool stop; + + void thread_proc(); +}; + +ThreadPool::ThreadPool(unsigned int n) : busy(), processed(), stop() { + for (unsigned int i = 0; i < n; ++i) + workers.emplace_back(std::bind(&ThreadPool::thread_proc, this)); +} + +ThreadPool::~ThreadPool() { + // set stop-condition + std::unique_lock latch(queue_mutex); + stop = true; + cv_task.notify_all(); + latch.unlock(); + + // all threads terminate, then we're done. + for (auto &t : workers) + t.join(); +} + +void ThreadPool::thread_proc() { + while (true) { + std::unique_lock latch(queue_mutex); + cv_task.wait(latch, [this]() { return stop || !tasks.empty(); }); + if (!tasks.empty()) { + // got work. set busy. + ++busy; + + // pull from queue + auto fn = tasks.front(); + tasks.pop_front(); + + // release lock. run async + latch.unlock(); + + // run function outside context + fn(); + ++processed; + + latch.lock(); + --busy; + cv_finished.notify_one(); + } else if (stop) { + break; + } + } +} + +// generic function push +template void ThreadPool::enqueue(F &&f) { + std::unique_lock lock(queue_mutex); + tasks.emplace_back(std::forward(f)); + cv_task.notify_one(); +} + +// waits until the queue is empty. +void ThreadPool::waitFinished() { + std::unique_lock lock(queue_mutex); + cv_finished.wait(lock, [this]() { return tasks.empty() && (busy == 0); }); +} + +struct DPNode; + +struct Node { + int id; + int father_id; + + int cut_len; + std::vector cut_nodes; + + int p_num; + std::vector p_time; + std::vector p_comp_mem; + std::vector p_buf_mem; + std::vector p_act_mem; + std::vector p_opt_mem; + std::vector p_father; + + int producer_num; + std::vector producers; + std::vector> comm_costs; + + // assume the number of combinations is less than 2e9 + int dp_num; + std::vector dp_nodes; +}; + +int verbose; + +struct DPNode { + Node *graph_node; + int pg_id; + std::vector ir; + std::vector> in_edges; + // mem, time, activation_mem, optimzer_mem + std::vector> state; +}; + +void resetNode(Node *node) { + for (DPNode *dp_node : node->dp_nodes) { + dp_node->state.clear(); + } +} + +void printNode(Node *node) { + std::cout << "id: " << node->id << std::endl; + std::cout << "father_id: " << node->father_id << std::endl; + std::cout << "cut_len: " << node->cut_len << std::endl; + std::cout << "cut_nodes: "; + for (auto cut_node : node->cut_nodes) { + std::cout << cut_node->id << " "; + } + std::cout << std::endl; + std::cout << "p_num: " << node->p_num << std::endl; + std::cout << "p_time: "; + for (auto p_time : node->p_time) { + std::cout << p_time << " "; + } + std::cout << std::endl; + std::cout << "p_comp_mem: "; + for (auto p_comp_mem : node->p_comp_mem) { + std::cout << p_comp_mem << " "; + } + std::cout << std::endl; + std::cout << "p_buf_mem: "; + for (auto p_buf_mem : node->p_buf_mem) { + std::cout << p_buf_mem << " "; + } + std::cout << std::endl; + std::cout << "p_act_mem: "; + for (auto p_act_mem : node->p_act_mem) { + std::cout << p_act_mem << " "; + } + std::cout << std::endl; + std::cout << "p_opt_mem: "; + for (auto p_opt_mem : node->p_opt_mem) { + std::cout << p_opt_mem << " "; + } + std::cout << std::endl; + std::cout << "producer_num: " << node->producer_num << std::endl; + std::cout << "producers: "; + for (auto producer : node->producers) { + std::cout << producer->id << " "; + } + std::cout << std::endl; + std::cout << "p_father: "; + for (auto p_father : node->p_father) { + std::cout << p_father << " "; + } + std::cout << std::endl; + std::cout << "comm_costs: " << std::endl; + for (auto comm_cost : node->comm_costs) { + for (auto cost : comm_cost) { + std::cout << cost << " "; + } + std::cout << std::endl; + } + std::cout << "dp_num: " << node->dp_num << std::endl; + std::cout << std::endl; +} + +std::unordered_map id2node; +std::vector> queries; +// mode = 0: training, use the sum of the two largest buffer sizes +// mode = 1: inference, use the largest buffer size +int mode; +// mem_bound: the maximum memory usage, in bytes +int mem_bound; +// mem_div: the memory divisor, to avoid overflow in int32 +int mem_div; +int topk; +const int MAX_CONCURRENCY = std::thread::hardware_concurrency(); + +bool is_big_endian(void) { + union { + uint32_t i; + char c[4]; + } bint = {0x01020304}; + + return bint.c[0] == 1; +} + +template void read_binary(T *ptr, std::ifstream &stream) { + stream.read(reinterpret_cast(ptr), sizeof(*ptr)); +} + +void build_graph(std::ifstream &input) { + queries.clear(); + id2node.clear(); + + int n, q_size; + read_binary(&mode, input); + read_binary(&n, input); + read_binary(&mem_bound, input); + read_binary(&mem_div, input); + read_binary(&topk, input); + read_binary(&q_size, input); + + if (verbose) { + printf("is big endian: %d\n", is_big_endian()); + } + printf("node num: %d, mem_bound: %d, mem_div: %d, topk: %d, query num: %d\n", + n, mem_bound, mem_div, topk, q_size); + queries.resize(q_size); + for (int i = 0; i < q_size; ++i) { + int start, end; + read_binary(&start, input); + read_binary(&end, input); + queries[i] = std::make_pair(start, end); + } + for (int i = 0; i < n; ++i) { + Node *node = new Node(); + read_binary(&node->id, input); + id2node[node->id] = node; + read_binary(&node->father_id, input); + + int cut_id; + read_binary(&node->cut_len, input); + node->cut_nodes.resize(node->cut_len); + for (int j = 0; j < node->cut_len; ++j) { + read_binary(&cut_id, input); + node->cut_nodes[j] = id2node[cut_id]; + } + + read_binary(&node->p_num, input); + node->p_father.resize(node->p_num); + node->p_time.resize(node->p_num); + node->p_comp_mem.resize(node->p_num); + node->p_buf_mem.resize(node->p_num); + node->p_act_mem.resize(node->p_num); + node->p_opt_mem.resize(node->p_num); + for (int j = 0; j < node->p_num; ++j) { + read_binary(node->p_time.data() + j, input); + read_binary(node->p_comp_mem.data() + j, input); + read_binary(node->p_buf_mem.data() + j, input); + read_binary(node->p_act_mem.data() + j, input); + read_binary(node->p_opt_mem.data() + j, input); + read_binary(node->p_father.data() + j, input); + } + + read_binary(&node->producer_num, input); + node->producers.clear(); + node->comm_costs.clear(); + node->comm_costs.resize(node->producer_num); + for (int j = 0; j < node->producer_num; ++j) { + int producer_id; + read_binary(&producer_id, input); + Node *producer = id2node[producer_id]; + node->producers.push_back(producer); + node->comm_costs[j].resize(node->p_num * producer->p_num); + for (int k = 0; k < node->p_num * producer->p_num; ++k) { + read_binary(node->comm_costs[j].data() + k, input); + } + } + node->dp_num = 1; + for (Node *cut_node : node->cut_nodes) { + node->dp_num *= cut_node->p_num; + } + node->dp_nodes.resize(node->dp_num); + for (int j = 0; j < node->dp_num; ++j) { + DPNode *dp_node = new DPNode(); + node->dp_nodes[j] = dp_node; + dp_node->graph_node = node; + // pg: partition group, denotes the maintained partition states in + // a node. to reduce memory usage, we use a single int to + // represent a partition group + dp_node->pg_id = j; + dp_node->ir.clear(); + dp_node->in_edges.clear(); + dp_node->state.clear(); + } + if (verbose) { + printNode(node); + } + } +} + +// lazy decode +// after decoding, ir stores the partition id of each cut node +void decodePGID(DPNode *dp_node) { + if (!dp_node->ir.empty()) { + return; + } + Node *node = dp_node->graph_node; + int val = dp_node->pg_id; + for (int i = 0; i < node->cut_len; ++i) { + Node *cur_node = node->cut_nodes[node->cut_len - i - 1]; + dp_node->ir.push_back(val % cur_node->p_num); + val /= cur_node->p_num; + } + std::reverse(dp_node->ir.begin(), dp_node->ir.end()); +} + +// lazy build edge +void buildInEdges(DPNode *dp_node) { + if (!dp_node->in_edges.empty()) { + return; + } + Node *node = dp_node->graph_node; + + // special case: the node does not have any producer + // the pred dp node is composed of the same cut nodes as the current node + // except the last one. the transition cost is 0 since there is no + // communication + if (node->producer_num == 0) { + int val = 0; + for (int i = 0; i < node->cut_len - 1; ++i) { + val += dp_node->ir[i]; + if (i < node->cut_len - 2) { + val *= node->cut_nodes[i + 1]->p_num; + } + } + Node *pre_node = id2node[node->id - 1]; + dp_node->in_edges.push_back(std::make_pair(pre_node->dp_nodes[val], 0)); + return; + } + + int cur_p = *(dp_node->ir.rbegin()); + // we have filtered out the partition that cannot find a father to follow + assert(node->p_father[cur_p] != -1); + std::map info; + for (int i = 0; i < node->cut_len - 1; ++i) { + info[node->cut_nodes[i]->id] = dp_node->ir[i]; + } + // TODO(yizhu1): optimize + int producer_comb_num = 1; + for (Node *producer : node->producers) { + producer_comb_num *= producer->p_num; + } + // enumerate all the possible producer partition combinations + // to build the in edges + for (int idx = 0; idx < producer_comb_num; ++idx) { + bool is_legal = true; + int val = idx; + std::vector producer_ps(node->producer_num); + // decode the producer partition combination + for (int j = 0; j < node->producer_num; ++j) { + int k = node->producer_num - 1 - j; + producer_ps[k] = val % node->producers[k]->p_num; + val /= node->producers[k]->p_num; + // constraint: if the producer shares the same father with the node, + // then the partition of the node should follow the producer's partition, + // except the producer is the father node. + if (node->father_id != node->id) { + Node *producer = node->producers[k]; + // TODO: do we need to check producer->father_id != producer->id? + // seems this case will be filtered out by checker in line 411 + if (producer->father_id == node->father_id && + producer->father_id != producer->id) { + if (node->p_father[cur_p] != producer->p_father[producer_ps[k]]) { + is_legal = false; + } + } + } + } + if (!is_legal) { + continue; + } + // + std::vector> cur_ir(node->cut_len - 1); + bool has_found_follow = false; + for (int i = 0; i < node->cut_len - 1; ++i) { + cur_ir[i] = std::make_pair(node->cut_nodes[i]->id, dp_node->ir[i]); + if (node->cut_nodes[i]->father_id == node->father_id) { + has_found_follow = true; + } + } + double cost = 0; + std::vector> follow_candidates; + for (int j = 0; j < node->producer_num; ++j) { + int producer_id = node->producers[j]->id; + int producer_p = producer_ps[j]; + auto iter = info.find(producer_id); + if (iter != info.end()) { + if (producer_p != iter->second) { + is_legal = false; + break; + } + } else { + Node *producer = node->producers[j]; + if (producer->father_id != node->father_id) { + // check that there is a existing node in cur_ir that in the same + // follow chain with the producer + bool find_existing_follow = false; + for (int i = 0; i < cur_ir.size(); ++i) { + Node *tmp = id2node[cur_ir[i].first]; + if (tmp->father_id == producer->father_id) { + find_existing_follow = true; + // update + if (tmp->id < producer->id) { + for (int _ = 0; _ < producer->p_num; ++_) { + if (producer->p_father[_] == + tmp->p_father[cur_ir[i].second]) { + // replace to align with the filter logic in python + // only the newest node in the follow chain is kept + cur_ir[i] = std::make_pair(producer->id, _); + break; + } + } + } + break; + } + } + if (!find_existing_follow) { + cur_ir.push_back(std::make_pair(producer_id, producer_p)); + } + } else { + follow_candidates.push_back(std::make_pair(producer_id, producer_p)); + } + } + cost += + node->comm_costs[j][cur_p * node->producers[j]->p_num + producer_p]; + } + if (!is_legal) { + continue; + } + // handle follow + bool find_pre_id = false; + for (int j = 0; j < cur_ir.size(); ++j) { + if (cur_ir[j].first == node->id - 1) { + find_pre_id = true; + break; + } + } + if (!find_pre_id) { + Node *pre_node = id2node[node->id - 1]; + if (pre_node->father_id != node->father_id) { + // do nothing, means the pre_node's output is not used + // we select the 1st partition of the pre_node + // need to be careful when the graph has multiple outputs + // shall we constrain that the output of the graph is replicated? + } else if (pre_node->father_id == pre_node->id) { + assert(follow_candidates.rbegin()->first == pre_node->id); + cur_ir.push_back(*follow_candidates.rbegin()); + } else { + bool find_same_follow_p = false; + for (int k = 0; k < pre_node->p_num; ++k) { + if (pre_node->p_father[k] == node->p_father[cur_p]) { + cur_ir.push_back(std::make_pair(node->id - 1, k)); + find_same_follow_p = true; + break; + } + } + assert(find_same_follow_p); + } + } else { + if (node->father_id != node->id && !has_found_follow && + !follow_candidates.empty()) { + cur_ir.push_back(*follow_candidates.rbegin()); + } + } + std::sort(cur_ir.begin(), cur_ir.end()); + val = 0; + for (int j = 0; j < cur_ir.size(); ++j) { + val += cur_ir[j].second; + if (j + 1 < cur_ir.size()) { + val *= id2node[cur_ir[j + 1].first]->p_num; + } + } + dp_node->in_edges.push_back( + std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); + } +} + +// do dp for a partition group +void update(DPNode *dp_node, int start_level) { + Node *node = dp_node->graph_node; + decodePGID(dp_node); + int cur_p_idx = *(dp_node->ir.rbegin()); + if (node->id == start_level) { + // each dp node maintains a list of states, each state is a tuple + // (mem, time, pred_dp_node, activation_mem, optimizer_mem) + dp_node->state.push_back(std::make_tuple( + node->p_comp_mem[cur_p_idx], node->p_time[cur_p_idx], nullptr, + node->p_act_mem[cur_p_idx], node->p_opt_mem[cur_p_idx])); + return; + } + + // storing edges takes space, so we build edges when needed + buildInEdges(dp_node); + int cur_p = *(dp_node->ir.rbegin()); + if (dp_node->in_edges.empty()) { + dp_node->state.push_back(std::make_tuple( + 0, std::numeric_limits::infinity(), nullptr, 0, 0)); + return; + } + + // use a priority queue to maintain the best state, similar to the merge sort + double cur_p_time = node->p_time[cur_p]; + int cur_p_comp_mem = node->p_comp_mem[cur_p]; + int cur_p_act_mem = node->p_act_mem[cur_p]; + int cur_p_opt_mem = node->p_opt_mem[cur_p]; + std::priority_queue> pq; + for (int i = 0; i < dp_node->in_edges.size(); ++i) { + DPNode *pred = dp_node->in_edges[i].first; + int mem = cur_p_comp_mem + std::get<0>(pred->state[0]); + double cost = + cur_p_time + dp_node->in_edges[i].second + std::get<1>(pred->state[0]); + int act_mem = cur_p_act_mem + std::get<3>(pred->state[0]); + int opt_mem = cur_p_opt_mem + std::get<4>(pred->state[0]); + pq.push(std::make_tuple(-mem, -cost, i, -act_mem, -opt_mem)); + } + + std::vector lows(dp_node->in_edges.size(), 1); + + int cur_mem; + double cur_cost; + int pred_idx; + int cur_act_mem; + int cur_opt_mem; + while (!pq.empty()) { + std::tie(cur_mem, cur_cost, pred_idx, cur_act_mem, cur_opt_mem) = pq.top(); + cur_mem = -cur_mem; + cur_cost = -cur_cost; + cur_act_mem = -cur_act_mem; + cur_opt_mem = -cur_opt_mem; + pq.pop(); + if (lows[pred_idx] < dp_node->in_edges[pred_idx].first->state.size()) { + DPNode *pred = dp_node->in_edges[pred_idx].first; + int mem = cur_p_comp_mem + std::get<0>(pred->state[lows[pred_idx]]); + double cost = cur_p_time + dp_node->in_edges[pred_idx].second + + std::get<1>(pred->state[lows[pred_idx]]); + int act_mem = cur_p_act_mem + std::get<3>(pred->state[lows[pred_idx]]); + int opt_mem = cur_p_opt_mem + std::get<4>(pred->state[lows[pred_idx]]); + pq.push(std::make_tuple(-mem, -cost, pred_idx, -act_mem, -opt_mem)); + ++lows[pred_idx]; + } + if (dp_node->state.empty()) { + dp_node->state.push_back( + std::make_tuple(cur_mem, cur_cost, dp_node->in_edges[pred_idx].first, + cur_act_mem, cur_opt_mem)); + } else { + int pre_mem = std::get<0>(dp_node->state[dp_node->state.size() - 1]); + double pre_cost = std::get<1>(dp_node->state[dp_node->state.size() - 1]); + // if (cur_mem > pre_mem && cur_cost < pre_cost && + // cur_mem + cur_opt_mem <= mem_bound) { + if (cur_mem > pre_mem && cur_cost < pre_cost && + cur_mem - cur_act_mem + std::max(cur_act_mem, cur_opt_mem) <= + mem_bound) { + dp_node->state.push_back(std::make_tuple( + cur_mem, cur_cost, dp_node->in_edges[pred_idx].first, cur_act_mem, + cur_opt_mem)); + } + } + } +} + +ThreadPool pool(MAX_CONCURRENCY); + +std::vector> split_work(int num) { + std::vector work; + if (num < MAX_CONCURRENCY) { + work = std::vector(num, 1); + } else { + work = std::vector(MAX_CONCURRENCY, num / MAX_CONCURRENCY); + for (int i = 0; i < num % MAX_CONCURRENCY; ++i) { + work[i] += 1; + } + } + std::vector> ret(work.size()); + int cum_sum = 0; + for (int i = 0; i < work.size(); ++i) { + ret[i] = std::make_pair(cum_sum, work[i]); + cum_sum += work[i]; + } + return ret; +} + +void do_dp(int start_level, int end_level) { + // reset all the dp nodes, since we may have multiple queries + for (auto iter = id2node.begin(); iter != id2node.end(); ++iter) { + resetNode(iter->second); + } + + for (int i = start_level; i <= end_level; ++i) { + // use multi-thread to do dp for each level to reduce time + auto iter = id2node.find(i); + if (iter == id2node.end()) { + // TODO(yizhu1): check here + assert(false); + } + if (verbose) { + std::cout << "Start to process level id: " << i + << ", state num: " << iter->second->dp_nodes.size() + << std::endl; + } + std::vector> split_info = + split_work(iter->second->dp_num); + for (const auto &item : split_info) { + pool.enqueue([=] { + for (int i = 0; i < item.second; ++i) { + int offset = item.first + i; + update(iter->second->dp_nodes[offset], start_level); + } + }); + } + pool.waitFinished(); + } +} + +std::tuple>> +process_state(DPNode *dp_node, int idx) { + // build the optimal path of each partition of last operator + // and return the best path + std::vector> path; + DPNode *cur_dp_node = dp_node; + int cur_idx = idx; + int best_mem = std::get<0>(dp_node->state[idx]); + double best_time = std::get<1>(dp_node->state[idx]); + int act_mem = std::get<3>(dp_node->state[idx]); + int opt_mem = std::get<4>(dp_node->state[idx]); + double inner_time = 0; + int cur_best_mem = best_mem; + std::vector buffers; + while (true) { + int cur_p = *(cur_dp_node->ir.rbegin()); + Node *node = cur_dp_node->graph_node; + path.push_back(std::make_pair(node->id, cur_p)); + buffers.push_back(node->p_buf_mem[cur_p]); + inner_time += node->p_time[cur_p]; + cur_best_mem -= node->p_comp_mem[cur_p]; + DPNode *pred_dp_node = std::get<2>(cur_dp_node->state[cur_idx]); + if (pred_dp_node == nullptr) { + break; + } else { + cur_dp_node = pred_dp_node; + cur_idx = std::lower_bound( + cur_dp_node->state.begin(), cur_dp_node->state.end(), + std::make_tuple(cur_best_mem, static_cast(-1), + static_cast(nullptr), -1, -1)) - + cur_dp_node->state.begin(); + } + } + std::reverse(path.begin(), path.end()); + std::sort(buffers.begin(), buffers.end()); + long long ret_mem = static_cast(best_mem); + if (mode == 0) { + ret_mem += buffers[buffers.size() - 1] + buffers[buffers.size() - 2]; + } else if (mode == 1) { + ret_mem += buffers[buffers.size() - 1]; + } + ret_mem = ret_mem - act_mem + std::max(act_mem, opt_mem); + if (ret_mem > mem_bound) { + return std::make_tuple(-1, -1, -1, std::vector>()); + } + if (verbose) { + std::cout << "best time: " << best_time + << ", best mem: " << best_mem / 1024 / 1024 * mem_div << "MB, " + << "activation mem: " << act_mem / 1024 / 1024 * mem_div << "MB, " + << "optimizer state mem: " << opt_mem / 1024 / 1024 * mem_div + << "MB" << std::endl; + } + return std::make_tuple(best_time, inner_time, static_cast(ret_mem), + path); +} + +template void write_binary(T val, std::ofstream &stream) { + stream.write(reinterpret_cast(&val), sizeof val); +} + +void post_process(int start_level, int end_level, int topk, + std::ofstream &output) { + std::vector>>> + best_info; + double best_time; + double inner_time; + int best_mem; + std::vector> path; + for (DPNode *dp_node : id2node[end_level]->dp_nodes) { + int cnt = 0; + for (int i = 0; i < dp_node->state.size(); ++i) { + std::tie(best_time, inner_time, best_mem, path) = + process_state(dp_node, dp_node->state.size() - i - 1); + if (best_time > 0) { + ++cnt; + best_info.push_back( + std::make_tuple(best_time, inner_time, best_mem, path)); + if (cnt == topk) { + break; + } + } + } + } + std::sort(best_info.begin(), best_info.end()); + int ret_size = std::min(topk, int(best_info.size())); + int path_len = 0; + if (ret_size > 0) { + path_len = static_cast(std::get<3>(best_info[0]).size()); + } + write_binary(start_level, output); + write_binary(end_level, output); + write_binary(ret_size, output); + write_binary(path_len, output); + for (int i = 0; i < ret_size; ++i) { + best_time = std::get<0>(best_info[i]); + inner_time = std::get<1>(best_info[i]); + best_mem = std::get<2>(best_info[i]); + write_binary(best_time, output); + write_binary(inner_time, output); + write_binary(best_mem, output); + for (auto &item : std::get<3>(best_info[i])) { + write_binary(item.first, output); + write_binary(item.second, output); + } + } +} + +int main(int argc, char **argv) { + if (argc != 4) { + std::cout << "Usage: ./solver in_path out_path is_verbose" << std::endl; + return 0; + } + std::ifstream input(argv[1], std::ios::binary); + std::ofstream output(argv[2], std::ios::out | std::ios::binary); + verbose = argv[3][0] == '1'; + build_graph(input); + // to reduce time, we first group the queries by start node (level) + std::unordered_map> intervals; + for (const auto &query : queries) { + auto iter = intervals.find(query.first); + if (iter == intervals.end()) { + intervals[query.first] = std::vector(1, query.second); + } else { + iter->second.push_back(query.second); + } + } + auto start = std::chrono::system_clock::now(); + for (auto &item : intervals) { + // for each start node, we do dp until the last end node + int start_level = item.first; + std::vector &end_levels = item.second; + std::sort(end_levels.begin(), end_levels.end()); + do_dp(start_level, *end_levels.rbegin()); + for (int end_level : end_levels) { + post_process(start_level, end_level, topk, output); + } + long long state_cnt = 0; + for (auto iter = id2node.begin(); iter != id2node.end(); ++iter) { + int cur_id = iter->first; + Node *cur_node = iter->second; + for (DPNode *dp_node : cur_node->dp_nodes) { + state_cnt += dp_node->state.size(); + } + } + if (verbose) { + std::cout << "state num: " << state_cnt << std::endl; + } + } + auto end = std::chrono::system_clock::now(); + + std::chrono::duration elapsed_seconds = end - start; + + std::cout << "elapsed time: " << elapsed_seconds.count() << " s" << std::endl; + input.close(); + output.close(); + return 0; +} diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py new file mode 100644 index 00000000..fe1ad671 --- /dev/null +++ b/nnscaler/autodist/cube_operator.py @@ -0,0 +1,111 @@ +from typing import List, Tuple, Dict, Set, Optional +from nnscaler.ir import IRTensor, IRFwOperation, IRSubTensor +from nnscaler.graph.function.dimops import DimAnno, IRDimops +from nnscaler.algorithm.ops.dimops import collect_split_info + + +class CubeOperator: + """ + CubeOperator is a wrapper for IRFwOperation. + Currently, it maintains the following information for an IRDimops: + - in_tensors: input tensors, including parameters and buffers + - out_tensors: output tensors + - producers: operators that produce the input tensors + - consumers: operators that consume the output tensors + - dim_info: a mapping from dimension name to its position and reduce type + - parallelable_dims: a set of dimension names that can be parallelized + - recompute: a flag indicating whether the operator will be recomputed + - has_batch_dim: a flag indicating whether the operator has a batch dimension + - since there can be shared tensors in the model, we use the following vars to estimate the memory usage accurately: + - omit_train_idx: a list of indices of activation tensors that should be omitted + - omit_param_idx: a list of indices of parameter tensors that should be omitted + - omit_buffer_idx: a list of indices of buffer tensors that should be omitted + """ + + def __init__(self, ir_cell: IRFwOperation): + self.ir_cell = ir_cell + self.in_tensors, self.out_tensors = [], [] + self.op_name = self.ir_cell.signature + + self.producers: Set[CubeOperator] = set() + self.consumers: Set[CubeOperator] = set() + + self.dim_info = {} + self.parallelable_dims = set() + self._recompute = False + + self.omit_train_idx = [] + self.omit_param_idx = [] + self.omit_buffer_idx = [] + self.has_batch_dim = False + + if not isinstance(ir_cell, IRDimops): + return + + for item in ir_cell.inputs(): + if isinstance(item, IRTensor): + self.in_tensors.append(item) + for item in ir_cell.outputs(): + if isinstance(item, IRTensor): + self.out_tensors.append(item) + + self.collect_anno_info() + + @property + def recompute(self): + return self._recompute + + @recompute.setter + def recompute(self, value: bool): + self._recompute = value + + def add_producer(self, producer: 'CubeOperator'): + self.producers.add(producer) + + def add_consumer(self, consumer: 'CubeOperator'): + self.consumers.add(consumer) + + def collect_anno_info(self): + for idx_shape, shape_anno in enumerate(self.ir_cell.anno.inputs()): + if not isinstance(self.ir_cell.inputs()[idx_shape], IRTensor): + continue + for idx_dim, dim_anno in enumerate(shape_anno.dims): + for idx_id, identifier in enumerate(dim_anno.identifiers): + reduce_type = dim_anno.reduces[idx_id] + if reduce_type != DimAnno.ReduceType.Freeze: + self.parallelable_dims.add(identifier) + val = (idx_shape, idx_dim, idx_id, reduce_type) + if identifier not in self.dim_info: + self.dim_info[identifier] = val + else: + if reduce_type != self.dim_info[identifier][-1]: + raise ValueError( + f'inconsistent reduce type for {identifier} in {self.ir_cell} with {self.ir_cell.anno}' + ) + + def dim_id2pos(self, dim_name: str) -> Tuple[int, int]: + if dim_name == -1: + return (-1, -1) + else: + assert dim_name in self.dim_info, f'{dim_name} not in {self.dim_info}' + idx, dim, _, _ = self.dim_info[dim_name] + return idx, dim + + def pos2dim_id(self, pos: Tuple[int, int]) -> str: + if pos == (-1, -1): + return -1 + else: + if not isinstance(self.ir_cell, IRDimops): + raise ValueError(f'{self.ir_cell} is not IRDimops') + idx, dim = pos + adim, reduce_type = self.ir_cell.algorithms( + 'dim').get_identifier_reduce(idx, dim, 2) + assert adim is not None, f'cannot find dim at {pos} in {self.ir_cell}' + return adim + + def get_reduce_type(self, dim_id: str): + return self.dim_info[dim_id][-1] + + def __repr__(self): + anno = self.ir_cell.anno if isinstance(self.ir_cell, IRDimops) else '' + return f'Operator {self.ir_cell} {anno} at {self.ir_cell.comment}' diff --git a/nnscaler/autodist/descs.py b/nnscaler/autodist/descs.py new file mode 100644 index 00000000..f6fcab20 --- /dev/null +++ b/nnscaler/autodist/descs.py @@ -0,0 +1,187 @@ +from dataclasses import dataclass +from typing import List, Dict, Tuple, Any, Optional +import json +import copy +import yaml + + +@dataclass +class NodePartitionDesc: + # list element: ((idx, dim), num), the order matters + desc: List[Tuple[Tuple[int, int], int]] + + +@dataclass +class MeshDesc: + # inter node + row: int + # intra node + col: int + + @property + def ngpus(self): + return self.row * self.col + + def to_json(self): + return (self.row, self.col) + + @staticmethod + def from_json(val): + return MeshDesc(*val) + + +@dataclass +class TensorParallelDesc: + partition_descs: Dict[int, NodePartitionDesc] + recompute_groups: List[List[int]] + mesh_desc: MeshDesc + + def to_json(self): + ret = {} + descs_list = [(k, v.desc) for k, v in self.partition_descs.items()] + ret['partition_descs'] = descs_list + ret['recompute_groups'] = self.recompute_groups + ret['mesh_desc'] = self.mesh_desc.to_json() + return ret + + @staticmethod + def from_json(ret): + partition_descs = {} + for k, v in ret['partition_descs']: + partition_descs[k] = NodePartitionDesc(v) + return TensorParallelDesc(partition_descs, + copy.deepcopy(ret['recompute_groups']), + MeshDesc.from_json(ret['mesh_desc'])) + + +@dataclass +class SPMDSearchOutput: + desc: TensorParallelDesc + memory: float + all_time: float + comp_time: float + + def to_json(self): + return { + 'desc': self.desc.to_json(), + 'memory': self.memory, + 'all_time': self.all_time, + 'comp_time': self.comp_time, + } + + @staticmethod + def from_json(json_val): + desc = TensorParallelDesc.from_json(json_val['desc']) + return SPMDSearchOutput(desc, json_val['memory'], json_val['all_time'], + json_val['comp_time']) + + +@dataclass +class PipelineParallelDesc: + spmd_descs: List[TensorParallelDesc] + recompute_groups: List[List[int]] + mesh_desc: MeshDesc + + def to_json(self): + return { + 'spmd_descs': [desc.to_json() for desc in self.spmd_descs], + 'recompute_groups': self.recompute_groups, + 'mesh_desc': self.mesh_desc.to_json(), + } + + @staticmethod + def from_json(json_val): + spmd_descs = [] + for spmd_desc_json in json_val['spmd_descs']: + spmd_descs.append(TensorParallelDesc.from_json(spmd_desc_json)) + recompute_groups = copy.deepcopy(json_val['recompute_groups']) + mesh_desc = MeshDesc.from_json(json_val['mesh_desc']) + return PipelineParallelDesc(spmd_descs, recompute_groups, mesh_desc) + + +@dataclass +class PipelineSearchOutput: + desc: PipelineParallelDesc + e2e_time: float + stage_mems: List[float] + stage_all_times: List[float] + stage_comp_times: List[float] + + def to_json(self): + return { + 'desc': self.desc.to_json(), + 'e2e_time': self.e2e_time, + 'stage_mems': self.stage_mems, + 'stage_all_times': self.stage_all_times, + 'stage_comp_times': self.stage_comp_times, + } + + @staticmethod + def from_json(json_val): + desc = PipelineParallelDesc.from_json(json_val['desc']) + return PipelineSearchOutput(desc, json_val['e2e_time'], + json_val['stage_mems'], + json_val['stage_all_times'], + json_val['stage_comp_times']) + + +@dataclass +class PartitionConstraint: + + # the name of the corresponding operator in the model. It equals + # to the `signature` field in the `IRFwOperation` in cube + name: str + # the **closest** father module name of the operator + parent_module: str + # a list of allowed partition dimensions of input tensors + allowed_partition_dims: List[Tuple[int, int]] + replica_allowed: bool = True + + @staticmethod + def from_json(content: Dict[str, Any]): + allowed_partition_dims = [ + tuple(x) for x in content['allowed_partition_dims'] + ] + return PartitionConstraint(content['name'], content['parent_module'], + allowed_partition_dims, + content['replica_allowed']) + + def to_json(self): + return { + 'name': self.name, + 'parent_module': self.parent_module, + 'allowed_partition_dims': self.allowed_partition_dims, + 'replica_allowed': self.replica_allowed, + } + + @staticmethod + def from_yaml(content: Dict[str, Any]): + + def _parse_dims(dims: str) -> List[int]: + return tuple([int(x) for x in dims.split(',')]) + + allowed_partition_dims = [ + _parse_dims(x) for x in content['allowed_partition_dims'] + ] + return PartitionConstraint(content['name'], content['parent_module'], + allowed_partition_dims, + content['replica_allowed']) + + def to_yaml(self): + + def to_str(dims: List[int]) -> str: + return ','.join([str(x) for x in dims]) + + allowed_partition_dims = [ + to_str(x) for x in self.allowed_partition_dims + ] + return { + 'name': self.name, + 'parent_module': self.parent_module, + 'allowed_partition_dims': allowed_partition_dims, + 'replica_allowed': self.replica_allowed, + } + + def __hash__(self): + return hash((self.name, self.parent_module, + tuple(self.allowed_partition_dims), self.replica_allowed)) diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py new file mode 100644 index 00000000..766efc6d --- /dev/null +++ b/nnscaler/autodist/model_graph.py @@ -0,0 +1,797 @@ +from __future__ import annotations + +from nnscaler.graph.graph import IRGraph +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir.operator import IRFwOperation +from nnscaler.ir.cten import IRObject, IRTensor +from .cube_operator import CubeOperator +from .autodist_config import AutoDistConfig +from .cost_database import CostDatabase + +from dataclasses import dataclass +from collections import deque +import logging +import copy +from typing import List, Tuple, Dict, Any, Callable + +_logger = logging.getLogger(__name__) + + +# expect ops with output tensors are all IRDimops +def should_include(node: IRFwOperation): + return any(isinstance(t, IRTensor) for t in node.outputs()) + + +def calc_flops(node: IRFwOperation): + if 'torch.nn.functional.linear' in node.signature: + assert len(node.inputs()) >= 2 + assert len(node.outputs()) == 1 + ret = 2 * node.inputs()[0].nelement() + if len(node.inputs()[1].shape) == 2: + ret = ret * node.inputs()[1].shape[0] + return ret + elif 'torch.bmm' in node.signature: + # this function do not support broadcast + assert len(node.inputs()) == 2 + assert len(node.outputs()) == 1 + b, m, k = node.inputs()[0].shape + _, _, n = node.inputs()[1].shape + return 2 * b * m * n * k + elif 'torch.matmul' in node.signature: + assert len(node.inputs()) == 2 + assert len(node.outputs()) == 1 + lhs, rhs = node.inputs() + out = node.outputs()[0] + if len(lhs.shape) == 1 and len(rhs.shape) == 1: + # vector-vector + ret = lhs.nelement() + elif len(lhs.shape) == 2 and len(rhs.shape) == 2: + # matrix-vector + m, k = lhs.shape + _, n = rhs.shape + ret = m * n * k + elif len(lhs.shape) == 1 and len(rhs.shape) == 2: + # vector-matrix + k, n = rhs.shape + ret = k * n + elif len(lhs.shape) == 2 and len(rhs.shape) == 1: + # matrix-vector + m, k = lhs.shape + ret = m * k + elif len(lhs.shape) > 2 or len(rhs.shape) > 2: + ret = out.nelement() + if len(lhs.shape) > 2: + ret = ret * lhs.shape[-1] + elif len(rhs.shape) > 2: + ret = ret * rhs.shape[-2] + else: + raise RuntimeError( + f'unsupported matmul {lhs.shape}, {rhs.shape}, {out.shape}') + return 2 * ret + return 0 + + +def estimate_mem_lower_bound( + param_mem: int, + buffer_mem: int, + activation_mem: int, + plan_ngpus: int, + zero_group_size: int, + cfg: AutoDistConfig, +) -> float: + ''' + Given memory consumption of parameters, buffers and activations, and the number of + pipeline stages (counting from the last stage, including itself), calculate the + minimum possible memory consumption of each device. + Assume the activation memory is shared with transient optimizer memory, since activations + have been deallocated before optimizer's step. + The minimum memory consumption is achieved when: + 1. activations, parameters, buffers and gradients are distributed evenly across plan_ngpus + 2. the optimizer memory is distributed evenly across zero_group_size (when zero stage 1 is enabled) or plan_ngpus + ''' + # avg memory cost of activation, param (grad), buffer + activation_mem = activation_mem / plan_ngpus + param_mem = param_mem / plan_ngpus + buffer_mem = buffer_mem / plan_ngpus + + # avg opt mem + opt_resident_mem = cfg.opt_resident_coef * param_mem + opt_transient_mem = cfg.opt_transient_coef * param_mem + if cfg.zero_stage == 1: + opt_resident_mem = opt_resident_mem / zero_group_size + opt_transient_mem = opt_transient_mem / zero_group_size + elif cfg.zero_stage == 0: + opt_resident_mem = opt_resident_mem / plan_ngpus + opt_transient_mem = opt_transient_mem / plan_ngpus + else: + raise RuntimeError(f'invalid zero stage {cfg.zero_stage}') + + min_single_dev_mem = max(opt_transient_mem, activation_mem) \ + + 2 * param_mem + buffer_mem + opt_resident_mem + return min_single_dev_mem + + +def aggregate_common_mem(sub_nodes: List[IRFwOperation], + check_connected: bool) -> Tuple[int, int, int]: + """ + Aggregate the memory size of input tensors, parameter tensors and buffer tensors + in the subgraph. + Use IRObject as edges to find the connected components, and check the connectivity + of the subgraph if check_connected is True. + + Args: + sub_nodes: a list of IRFwOperation from the whole graph + check_connected: whether to check the connectivity of the subgraph + + Returns: + in_mem: the memory size of input tensors + param_mem: the memory size of parameter tensors + buffer_mem: the memory size of buffer tensors + """ + + def _unfold_complex(data): + if isinstance(data, (list, tuple)): + ret = [] + for d in data: + ret += _unfold_complex(d) + return ret + elif isinstance(data, dict): + ret = [] + for _, d in data.items(): + ret += _unfold_complex(d) + return ret + elif isinstance(data, slice): + return _unfold_complex([data.start, data.stop, data.step]) + elif isinstance(data, IRObject): + return [data] + else: + return [] + + object2producer: Dict[IRObject, IRFwOperation] = dict() + for node in sub_nodes: + for complex_output in node.outputs(): + for t in _unfold_complex(complex_output): + assert isinstance(t, IRObject) + if t in object2producer: + raise RuntimeError(f'tensor {t} has multiple producers') + object2producer[t] = node + + # use union set to check whether the subgraph is connected + node2father: Dict[IRDimops, IRDimops] = dict() + for node in sub_nodes: + node2father[node] = node + + def get_father(node): + father = node2father[node] + if node == father: + return father + else: + father = get_father(father) + node2father[node] = father + return father + + def merge(lhs, rhs): + lhs_father = get_father(lhs) + rhs_father = get_father(rhs) + node2father[rhs_father] = lhs_father + + edges: Dict[IRFwOperation, List[IRFwOperation]] = dict() + in2consumer: Dict[IRObject, List[IRFwOperation]] = dict() + for node in sub_nodes: + # deal with both inputs and kwargs to track connecting edges + complex_inputs = list(node.inputs()) + list(node.kwargs.values()) + for complex_input in complex_inputs: + for t in _unfold_complex(complex_input): + assert isinstance(t, IRObject) + if t not in object2producer: + if t not in in2consumer: + in2consumer[t] = [] + in2consumer[t].append(node) + continue + src = object2producer[t] + if src not in edges: + edges[src] = [] + edges[src].append(node) + merge(src, node) + + for _, consumers in in2consumer.items(): + for i in range(len(consumers) - 1): + merge(consumers[i], consumers[i + 1]) + + components = set() + for node, father in node2father.items(): + components.add(get_father(node)) + + if check_connected and len(components) > 1: + for i, father in enumerate(components): + _logger.info(f'{i}-th component') + for node, _ in node2father.items(): + if get_father(node) == father: + _logger.info(node) + raise RuntimeError('more than one connect component') + + in_mem, param_mem, buffer_mem = 0, 0, 0 + for t, _ in in2consumer.items(): + if not isinstance(t, IRTensor): + continue + if t.is_param(): + param_mem += t.byte_size() + elif t.is_buffer(): + buffer_mem += t.byte_size() + else: + in_mem += t.byte_size() + return in_mem, param_mem, buffer_mem + + +def aggregate_train_mem(sub_nodes: List[IRFwOperation], db) -> int: + visited_tensors: Set[IRTensor] = set() + train_mem = 0 + for node in sub_nodes: + metrics = db.query(node) + if metrics is None: + # if the node is not in the database, skip it currently + continue + train_mem2in_idx = metrics.train_mem2in_idx + train_mem_info = metrics.train_mem_info + for mem, in_idx in zip(train_mem_info, train_mem2in_idx): + if in_idx == -1: + train_mem += mem + else: + t = node.inputs()[in_idx] + if t not in visited_tensors: + train_mem += mem + visited_tensors.add(t) + return train_mem + + +class ScopeNode: + + def __init__(self, + name: str, + module_type: Any, + parent=None, + node: IRFwOperation = None, + depth: int = 0, + leaf_size: int = 0, + flops: int = 0, + fw_span: float = 0, + start: int = 0, + end: int = 0): + self.name = name + self.module_type = module_type + self.children = [] + self.parent = parent + self.node = node + self.depth = depth + self.leaf_size = leaf_size + self.flops = flops + self.fw_span = fw_span + self.in_mem = 0 + self.train_mem = 0 + self.param_mem = 0 + self.buffer_mem = 0 + self.start = start + self.end = end + + def insert(self, node: IRFwOperation, module_info: List[Tuple[str, Any]], + flops: int, fw_span: float, idx: int): + self.leaf_size += 1 + self.flops += flops + self.fw_span += fw_span + if len(module_info) == 0: + child = ScopeNode(node.signature, + None, + parent=self, + node=node, + depth=self.depth + 1, + leaf_size=1, + flops=flops, + fw_span=fw_span, + start=idx, + end=idx) + self.children.append(child) + return child + module_path, module_type = module_info[0] + for i, child in enumerate(self.children): + if child.name == module_path: + if i == len(self.children) - 1: + return child.insert(node, + module_info[1:], + flops, + fw_span, + idx=idx) + else: + _logger.warning( + f'{node} with {module_info} used multiple times') + child = ScopeNode(module_path, + module_type, + parent=self, + depth=self.depth + 1) + ret = child.insert(node, module_info[1:], flops, fw_span, idx=idx) + self.children.append(child) + return ret + + @property + def is_leaf(self): + return self.node is not None + + @property + def is_root(self): + return self.parent is None + + def select(self, func): + if func(self): + return [self] + ret = [] + for child in self.children: + ret += child.select(func) + return ret + + # time complexity: O(depth * #nodes) + def pull_up(self, db): + # leaf node + if self.node is not None: + if not isinstance(self.node, IRFwOperation): + raise RuntimeError(f'expect IRFwOperation, got {self.node}') + if isinstance(self.node, IRDimops): + profiled_metrics = db.query(self.node) + if profiled_metrics is not None: + self.in_mem = sum(profiled_metrics.in_mem_info) + self.train_mem = sum(profiled_metrics.train_mem_info) + self.param_mem = sum(profiled_metrics.param_mem_info) + self.buffer_mem = sum(profiled_metrics.buffer_mem_info) + else: + raise RuntimeError(f'cannot find {self.node} in db') + else: + if should_include(self.node): + _logger.warning( + f'detect a non-IRDimops {self.node.signature} ' + \ + f'at {self.node.comment} that produces tensors') + return [self.node] + sub_nodes = [] + for child in self.children: + sub_nodes += child.pull_up(db) + # a sub-module can have more than one connected component, like RoPE + # we check the connectivity only when self is the root node + self.in_mem, self.param_mem, self.buffer_mem = aggregate_common_mem( + sub_nodes, self.parent is None) + self.train_mem = aggregate_train_mem(sub_nodes, db) + self.start = self.children[0].start + self.end = self.children[-1].end + return sub_nodes + + def query(self, start: int, end: int, cache: Dict[Tuple[int, int], Any], + leaf_handler: Callable[int, Any], merger: Callable[List[Any], + Any]): + ''' + Boost the query by segment tree and cache + Args: + start: the left index of nodes + end: the right index of nodes + cache: the cache for query + leaf_handler: the handler for leaf nodes + merger: the merger for sub-intervals + + Returns: + the result of the query + ''' + if not (self.start <= start and end <= self.end): + raise RuntimeError( + f'[{start}, {end}] not in [{self.start}, {self.end}]') + if (start, end) in cache: + return cache[(start, end)] + if start == end: + ret = leaf_handler(start) + else: + # break the interval into sub-intervals + def get_intersection(x1, y1, x2, y2): + return max(x1, x2), min(y1, y2) + + sub_rets = [] + for child in self.children: + x, y = get_intersection(start, end, child.start, child.end) + if x > y: + continue + sub_rets.append(child.query(x, y, cache, leaf_handler, merger)) + ret = merger(sub_rets) + cache[(start, end)] = ret + return ret + + def __repr__(self): + if self.node is not None: + return '' + desc = ' ' * self.depth + info = [ + self.name, + str(self.module_type), + f'depth: {self.depth}', + f'size: {self.leaf_size}', + 'FLOPs: {0:.3g}B'.format(self.flops / 1e9), + 'fw_span: {0:.3g}ms'.format(self.fw_span), + # TODO: the node may be a IRPytfunc whose fw_span = 0 currently + 'FLOPS: {0:.3g}T'.format(0. if self.fw_span == 0. else self.flops / + self.fw_span / 1e9), + 'in_mem: {0:.3g}MB'.format(self.in_mem / 1024 / 1024), + 'train_mem: {0:.3g}MB'.format(self.train_mem / 1024 / 1024), + 'param_mem: {0:.3g}MB'.format(self.param_mem / 1024 / 1024), + 'buffer_mem: {0:.3g}MB'.format(self.buffer_mem / 1024 / 1024) + ] + desc = desc + ', '.join(info) + '\n' + for child in self.children: + desc += child.__repr__() + return desc + + +# a class to store statistics of a continuous sub-sequence +# in the initial graph's topology sequence +@dataclass +class IntervalInfo: + start: int + end: int + fw_span: float + param_mem: int + buffer_mem: int + activation_mem: int + + def equivalent(self, other): + if self.end - self.start != other.end - other.start: + return False + if self.fw_span != other.fw_span: + return False + if self.param_mem != other.param_mem: + return False + if self.buffer_mem != other.buffer_mem: + return False + if self.activation_mem != other.activation_mem: + return False + # TODO(yizhu1): check whether the operators are the same + return True + + +class ModelGraph: + + def __init__(self, ir_graph: IRGraph, autodist_config: AutoDistConfig): + self.ir_graph = ir_graph + self.autodist_config = autodist_config + self.cost_database = CostDatabase(self.ir_graph, self.autodist_config) + self.cost_database.profile_comp(partition_degree=1) + + self.scope_tree_root = self.reconstruct_scope_tree() + self.scope_leaf_nodes = self.scope_tree_root.select(lambda x: x.is_leaf) + + self.recompute_mem, self.recompute_groups = self.init_recompute_nodes() + + self.operator_list: List[CubeOperator] = [] + self._ir_cell2idx: Dict[IRFwOperation, int] = dict() + self.init_operators() + + self._query_fw_span_cache: Dict[Tuple[int, int], float] = dict() + self._query_mem_cache = dict() + + @property + def op_num(self): + return len(self.operator_list) + + def get_op_idx(self, op: CubeOperator): + return self._ir_cell2idx[op.ir_cell] + + def reconstruct_scope_tree(self): + fw_cube_nodes = self.ir_graph.select(ntype=IRFwOperation) + root = ScopeNode('root', None) + db = self.cost_database.db + + for i, node in enumerate(fw_cube_nodes): + # filter out the anchor nodes, since they don't have module stack + if isinstance(node, IRGraphAnchor): + continue + if isinstance(node, IRDimops): + if not self.cost_database.exist(node): + fw_span = 0 + else: + fw_span = self.cost_database.query_profiled_metrics( + node).fw_span + else: + fw_span = 0 + module_info = [] + for module_path, module_type in node.module_stack.items(): + module_info.append((module_path.split('.')[-1], module_type)) + root.insert(node, module_info, calc_flops(node), fw_span, idx=i) + + root.pull_up(db) + _logger.info('\n' + root.__repr__()) + + return root + + def get_pipeline_pivots(self) -> List[int]: + ''' + To reduce the search space, we only consider limited number of pivot + operators which break the model into several pipeline stages. + Currently, user's guidance (autodist_config.pipeline_pivots) is required. + + Returns: + the indices of pivot operators in the operator list + ''' + # TODO(yizhu1): check recompute_modules are between pivots + if not self.autodist_config.pipeline: + raise RuntimeError('pipeline is not enabled') + pp_pivot_modules = self.autodist_config.pipeline_pivots.split(',') + pp_pivot_modules = [module for module in pp_pivot_modules if module] + if not pp_pivot_modules: + raise RuntimeError('pipeline_pivots is empty') + + def filter_func(scope_node): + if scope_node.is_leaf: + return False + for module in pp_pivot_modules: + if scope_node.is_root: + continue + if not isinstance(scope_node.module_type, type): + raise RuntimeError( + f'expect type, got {scope_node.module_type}') + if module == scope_node.module_type.__name__: + return True + return False + + pivot_modules = self.scope_tree_root.select(filter_func) + node2idx: Dict[IRFwOperation, int] = dict() + for i, op in enumerate(self.operator_list): + node2idx[op.ir_cell] = i + pivot_idxs = [] + for module in pivot_modules: + leaf_nodes = module.select(lambda x: x.is_leaf) + pivot_idxs.append(node2idx[leaf_nodes[0].node]) + if not pivot_idxs: + raise RuntimeError(f'cannot find any pivot in {pp_pivot_modules}') + return pivot_idxs + + def calc_interval_info(self, start: int, end: int) -> IntervalInfo: + ''' + calculate the interval info of nodes in [start, end] + ''' + fw_span = self.query_fw_span(start, end) + param_mem, buffer_mem, activation_mem = self.query_mem(start, end) + return IntervalInfo(start, end, fw_span, param_mem, buffer_mem, + activation_mem) + + def group_pipeline_intervals(self) -> List[List[IntervalInfo]]: + ''' + Group the pipeline intervals with the same interval info. It is used to + reduce the search time of a stage's (interval) spmd plan: only one + interval in a group needs to be searched. + + Returns: + a list of groups, each group contains a list of intervals + ''' + idxs = [0] + self.get_pipeline_pivots() + [self.op_num] + len2intervals: Dict[int, List[List[IntervalInfo]]] = dict() + for i in range(len(idxs) - 1): + start = idxs[i] + for j in range(i + 1, len(idxs)): + end = idxs[j] - 1 + length = end - start + 1 + cur_interval = self.calc_interval_info(start, end) + if length not in len2intervals: + len2intervals[length] = [[cur_interval]] + else: + found_equal = False + for group in len2intervals[length]: + if group[0].equivalent(cur_interval): + group.append(cur_interval) + found_equal = True + break + if not found_equal: + len2intervals[length].append([cur_interval]) + ret = [] + for _, groups in len2intervals.items(): + ret += groups + return ret + + def query_fw_span(self, start: int, end: int) -> float: + ''' + Time complexity: O(log(#nodes)) + Args: + start: the left index of the operator list + end: the right index of the operator list + + Returns: + the forward span of operators in [start, end] + ''' + + def leaf_handler(idx): + return self.scope_leaf_nodes[idx].fw_span + + def merger(sub_rets): + return sum(sub_rets) + + return self.scope_tree_root.query( + start, + end, + self._query_fw_span_cache, + leaf_handler, + merger, + ) + + def init_recompute_nodes(self): + recompute_modules = self.autodist_config.recompute_modules.split(',') + recompute_modules = [ + module for module in recompute_modules if len(module) > 0 + ] + if len(recompute_modules) == 0: + return 0, [] + + def fetch_module(scope_node): + if scope_node.node is not None: + return [] + for module in recompute_modules: + if module in str(scope_node.module_type): + return [scope_node] + ret = [] + for child in scope_node.children: + ret += fetch_module(child) + return ret + + modules = fetch_module(self.scope_tree_root) + in_mem, train_mem = 0, 0 + for module in modules: + in_mem += module.in_mem + train_mem = max(train_mem, module.train_mem) + recompute_mem = in_mem + train_mem + _logger.info(f'recompute mem {recompute_mem / 1024 / 1024} MB') + self.autodist_config.memory_constraint -= recompute_mem + + def fetch_nodes(scope_node): + if scope_node.node is not None: + return [scope_node.node] + ret = [] + for child in scope_node.children: + ret += fetch_nodes(child) + return ret + + recompute_groups = [] + for module in modules: + recompute_groups.append(fetch_nodes(module)) + return recompute_mem, recompute_groups + + def label_ops(self, operator_list: List[CubeOperator]): + # label the tensors that are shared by multiple operators, examples: + # 1. the embedding matrix is shared by embedding lookup and the last linear layer + # 2. the activation tensor is shared by query, key and value projections in transformer + # label the operators that have been set to recompute + counted_tensors: Set[IRTensor] = set() + recompute_nodes: Set[IRFwOperation] = set() + for group in self.recompute_groups: + recompute_nodes.update(group) + for operator in operator_list: + if not isinstance(operator.ir_cell, IRDimops): + continue + # deduplicate activation tensors + # train_mem2in_idx only includes activation tensors without param/buffer tensors + train_mem2in_idx = self.cost_database.query_profiled_metrics( + operator).train_mem2in_idx + for i, idx in enumerate(train_mem2in_idx): + if idx == -1: + continue + if operator.in_tensors[idx].tid in counted_tensors: + operator.omit_train_idx.append(i) + else: + counted_tensors.add(operator.in_tensors[idx].tid) + # deduplicate parameter and buffer tensors + # assume the traverse order of input tensors is the same as + # the order in profiling + b_idx, w_idx = -1, -1 + for in_tensor in operator.in_tensors: + if in_tensor.is_param(): + assert not in_tensor.is_buffer() + w_idx += 1 + if in_tensor.tid in counted_tensors: + operator.omit_param_idx.append(w_idx) + else: + counted_tensors.add(in_tensor.tid) + if in_tensor.is_buffer(): + assert not in_tensor.is_param() + b_idx += 1 + if in_tensor.tid in counted_tensors: + operator.omit_buffer_idx.append(b_idx) + else: + counted_tensors.add(in_tensor.tid) + if operator.ir_cell in recompute_nodes: + operator.recompute = True + operator.omit_train_idx = list(range(len(train_mem2in_idx))) + + def query_mem(self, start: int, end: int) -> Tuple[int, int, int]: + ''' + calculate memory consumption of operators in [start, end] + Time complexity: O(log(#nodes)) + + Args: + start: the left index of the operator list + end: the right index of the operator list + + Returns: + (param_mem, buffer_mem, activation_mem) + ''' + db_inst = self.cost_database + + def leaf_handler(idx): + op = self.operator_list[idx] + if not isinstance(op.ir_cell, IRDimops): + return 0, 0, 0 + return db_inst.query_single_mem(op, 'param', round=False), \ + db_inst.query_single_mem(op, 'buffer', round=False), \ + db_inst.query_single_mem(op, 'train', round=False) + + def merger(sub_rets): + param_mem, buffer_mem, activation_mem = 0, 0, 0 + for ret in sub_rets: + param_mem += ret[0] + buffer_mem += ret[1] + activation_mem += ret[2] + return param_mem, buffer_mem, activation_mem + + return self.scope_tree_root.query(start, end, self._query_mem_cache, + leaf_handler, merger) + + def init_operators(self): + cube_nodes = self.ir_graph.select(ntype=IRFwOperation) + cube_nodes = [ + node for node in cube_nodes if not isinstance(node, IRGraphAnchor) + ] + operator_list = [] + + tid2consumers = {} + for i, ir_cell in enumerate(cube_nodes): + operator_list.append(CubeOperator(ir_cell=ir_cell)) + for t in ir_cell.inputs(): + if isinstance(t, IRTensor): + if t.tid not in tid2consumers: + tid2consumers[t.tid] = [] + tid2consumers[t.tid].append(operator_list[-1]) + + # init producer and consumer relations + for src_op_idx in range(len(operator_list) - 1): + src_op = operator_list[src_op_idx] + for t in src_op.ir_cell.outputs(): + if not isinstance(t, IRTensor): + continue + # graph outputs (like loss) have no consumer + if t.tid not in tid2consumers: + continue + for dst_op in tid2consumers[t.tid]: + src_op.add_consumer(dst_op) + dst_op.add_producer(src_op) + + # infer batch dims + seed_ops = [] + visited = set() + for op in operator_list: + if len(op.producers) == 0 and len(op.in_tensors) > 0: + contain_non_param = False + for t in op.in_tensors: + if not t.is_attr(): + contain_non_param = True + break + if contain_non_param: + _logger.info(f'add seed op {op.ir_cell}') + seed_ops.append(op) + visited.add(op.ir_cell.cid) + dq = deque(seed_ops) + while len(dq) > 0: + op = dq.popleft() + op.has_batch_dim = True + for consumer in op.consumers: + if consumer.ir_cell.cid not in visited: + visited.add(consumer.ir_cell.cid) + dq.append(consumer) + for op in operator_list: + if not op.has_batch_dim: + _logger.debug(f'{op.ir_cell} don\'t have batch dim') + + self.label_ops(operator_list) + if len(operator_list) != len(self.scope_leaf_nodes): + raise RuntimeError( + f'expect {len(self.scope_leaf_nodes)} operators, got {len(operator_list)}' + ) + for i, op in enumerate(operator_list): + self._ir_cell2idx[op.ir_cell] = i + self.operator_list = operator_list diff --git a/nnscaler/autodist/op_partition.py b/nnscaler/autodist/op_partition.py new file mode 100644 index 00000000..d52c3de7 --- /dev/null +++ b/nnscaler/autodist/op_partition.py @@ -0,0 +1,150 @@ +from nnscaler.autodist.cube_operator import CubeOperator +from nnscaler.graph.function.dimops import DimAnno, IRDimops + +import itertools +from typing import List, Tuple + + +def calc_factors(val: int, num: int) -> List[Tuple[int, ...]]: + """ + Calculate all possible factors of val that can be divided into num parts. + NOTE: 6=2*3 and 6=3*2 are considered the same. + """ + plans = [] + + def backtrace(target: int, remaining: int, path: List[int]): + if remaining == 1: + if target != 1: + plans.append(path + [target]) + else: + if target != 1 or path: + raise RuntimeError(f'invalid target {target}, path {path}') + plans.append([1]) + return + + for i in range(2, target): + if target % i == 0: + backtrace(target // i, remaining - 1, path + [i]) + + backtrace(val, num, []) + + visited = set() + for plan in plans: + plan.sort() + visited.add(tuple(plan)) + return list(visited) + + +_factor_cache = {} + + +def calc_factors_cached(val: int, num: int) -> List[List[int]]: + if (val, num) not in _factor_cache: + _factor_cache[(val, num)] = calc_factors(val, num) + return _factor_cache[(val, num)] + + +def generate_partitions( + dim_ids: List[str], + device_num: int) -> List[Tuple[Tuple[str, ...], Tuple[int, ...]]]: + """ + Generate all possible partitions of dim_ids into device_num parts. + + Args: + dim_ids: a list of dimension names. + device_num: the number of devices. + + Returns: + A list of possible partitions. + + Example: + dim_ids = ['a', 'b'], device_num = 4 + possible partitions: + (('a', 'b'), (2, 2)) + (('b', 'a'), (2, 2)) + (('a',), (4,)) + (('b',), (4,)) + """ + candidates = [] + for i in range(1, device_num + 1): + if i > len(dim_ids): + break + factors = calc_factors_cached(device_num, i) + if not factors: + break + for factor in factors: + visited = set() + for factor_permutation in itertools.permutations(factor): + if factor_permutation not in visited: + visited.add(factor_permutation) + for dim_permutation in itertools.permutations(dim_ids, i): + if -1 in dim_permutation and dim_permutation[0] != -1: + continue + candidates.append((dim_permutation, factor_permutation)) + return candidates + + +class OpPartition: + """ + OpPartition represents a partition plan for a CubeOperator. + It is defined by a list of partition_dims and a list of partition_nums. + + If there is a matrix multiplication operator with annotation 'm k+, k+ n -> m n' + where m=512, k=1024, n=2048, a partition plan can be: + partition_dims = [-1, 'm', 'k'], partition_nums = [2, 2, 2]. + It means that the operator will be split into 8 sub-operators with shape + m=256, k=512, n=2048. + NOTE: + - if -1 in partition_dims, it should be placed at the first position. + - the example partition above is different from [-1, 'k', 'm'], [2, 2, 2] + """ + + def __init__(self, partition_dims: Tuple[str, ...], + partition_nums: Tuple[int, ...], operator: CubeOperator): + self.operator = operator + self.partition_dims = partition_dims + self.partition_nums = partition_nums + self.is_partial_val = False + + if len(partition_dims) != len(partition_nums): + raise ValueError( + 'partition_dims and partition_nums should have the same length') + if len(partition_dims) != 1: + raise ValueError('only support split along one dimension for now') + + if isinstance(self.operator.ir_cell, IRDimops): + if partition_dims[0] != -1: + idx, dim = operator.dim_id2pos(partition_dims[0]) + if not operator.ir_cell.algorithms('dim').satisfy( + idx, dim, partition_nums[0]): + raise ValueError( + f'invalid partition plan {partition_dims}, {partition_nums} for {operator.op_name}' + ) + # Store the first node among partition results of the full cube node. + # Other nodes are not stored because + # 1. they share the same shape with the first node. + # 2. we can calculate th intra-communication cost without knowing the device assignment now, + # since operator is constrained to be partitioned along one dimension. + # It is used to query the computation cost in the cost database. + self.ir_cell = operator.ir_cell.algorithms('dim').instantiate( + idx, dim, partition_nums[0])[0] + else: + self.ir_cell = operator.ir_cell + + for dim, num in zip(partition_dims, partition_nums): + if dim == -1: + continue + if operator.get_reduce_type(dim) == DimAnno.ReduceType.Sum and \ + num > 1: + self.is_partial_val = True + break + else: + if partition_dims[0] != -1: + raise ValueError('only support replicated for non-dimops') + self.ir_cell = operator.ir_cell + + def is_replicated(self): + return len(self.partition_dims) == 1 and self.partition_dims[0] == -1 + + def __repr__(self): + return f'OpPartition({self.partition_dims}, {self.partition_nums})' diff --git a/nnscaler/autodist/pipeline_solver.py b/nnscaler/autodist/pipeline_solver.py new file mode 100644 index 00000000..b3e4b5d4 --- /dev/null +++ b/nnscaler/autodist/pipeline_solver.py @@ -0,0 +1,295 @@ +from .model_graph import ModelGraph, estimate_mem_lower_bound, IntervalInfo +from .spmd_solver import SPMDSolver +from .descs import * +from .autodist_config import AutoDistConfig + +import os +import time +import json +import copy +import math +import multiprocessing +import logging +from typing import List, Dict, Tuple +from pathlib import Path + +__all__ = [ + 'calc_optimal_pp_plan', +] + +_logger = logging.getLogger(__name__) + + +def _dev_num2mesh_desc(dev_num: int, base_col: int) -> MeshDesc: + if dev_num <= base_col: + return MeshDesc(1, dev_num) + else: + assert dev_num % base_col == 0 + return MeshDesc(dev_num // base_col, base_col) + + +def _calc_legal_tp_degrees(max_tp_degree: int) -> List[int]: + ret = [] + tp_degree = 1 + while tp_degree <= max_tp_degree: + ret.append(tp_degree) + tp_degree = tp_degree * 2 + return ret + + +def _collect_tp_intervals( + model_graph: ModelGraph, + cfg: AutoDistConfig, + tp_degree: int, + stage_num: int, + interval_groups: List[List[IntervalInfo]], + spmd_solver: SPMDSolver, +) -> List[int]: + ''' + collect intervals for given tp_degree and stage_num + no need to calculate all possible intervals + 1. some intervals may not fit into the memory + 2. some intervals are sub-optimal + we want to make pipeline stages as balanced as possible + ideally, we want to make the time of each stage equal. + to be robust, we can constrain the average time of each stage + is within a certain range, like no more than 200% of the global + average time + 3. some intervals are identical (exactly the same ops and topology) + + Args: + model_graph: the graph in AutoDist + cfg: the AutoDistConfig + tp_degree: the tensor parallelism degree + stage_num: the pipeline stage number + interval_groups: a list of groups. identical intervals are in a group + spmd_solver: the solver for tensor parallelism + + Returns: + selected_groups: the indices of selected interval groups + ''' + + def calc_min_mem(start, end): + param_mem, buffer_mem, activation_mem = model_graph.query_mem( + start, end) + if cfg.zero_stage == 1: + zero_group_size = tp_degree * cfg.world_size // cfg.mesh_desc.ngpus // cfg.zero_ngroups + elif cfg.zero_stage == 0: + zero_group_size = tp_degree + else: + raise RuntimeError(f'invalid zero stage {cfg.zero_stage}') + return estimate_mem_lower_bound( + param_mem=param_mem, + buffer_mem=buffer_mem, + activation_mem=activation_mem * stage_num, + plan_ngpus=tp_degree, + zero_group_size=zero_group_size, + cfg=cfg, + ) + + idxs = [0] + model_graph.get_pipeline_pivots() + [model_graph.op_num] + global_fw_span = model_graph.query_fw_span( + 0, model_graph.op_num - 1) / model_graph.autodist_config.mesh_desc.ngpus + min_fw_span = global_fw_span * cfg.max_pipeline_unbalance_ratio + max_fw_span = global_fw_span / cfg.max_pipeline_unbalance_ratio + selected_groups = [] + for i, group in enumerate(interval_groups): + start, end = group[0].start, group[0].end + if calc_min_mem(start, end) > cfg.memory_constraint: + continue + if spmd_solver.estimate_min_mem(start, end) > cfg.memory_constraint: + continue + local_fw_span = model_graph.query_fw_span(start, end) / tp_degree + if local_fw_span < min_fw_span or local_fw_span > max_fw_span: + continue + selected_groups.append(i) + return selected_groups + + +def _compute_tp_info( + model_graph: ModelGraph, + cfg: AutoDistConfig, + legal_tp_degrees: List[int], +) -> Dict[Tuple[int, int, int, int], SPMDSearchOutput]: + ''' + Pre-compute the optimal spmd plan and store the result in a dict. + The key of the dict is (tp_degree, stage_num, start, end), + which means the optimal spmd plan for the interval [start, end] + with tp_degree devices and stage_num pipeline stages. + + Args: + model_graph: the graph in AutoDist + cfg: the AutoDistConfig + legal_tp_degrees: the legal tensor parallelism device numbers + + Returns: + tp_info: the dict that stores the optimal spmd plan for each interval + ''' + + _logger.info('start to compute tp info') + interval_groups = model_graph.group_pipeline_intervals() + # if there is no solution for (tp_degree, stage_num, start, end), + # there is no solution for (tp_degree, stage_num + 1, start, end) + no_solution_states = set() + + def process_case(device_num, stage_num): + solver = SPMDSolver(graph=model_graph, + mesh_desc=_dev_num2mesh_desc( + device_num, cfg.mesh_desc.col), + autodist_config=cfg, + stage_num=stage_num) + + selected_group_idxs = _collect_tp_intervals( + model_graph, + cfg, + device_num, + stage_num, + interval_groups, + solver, + ) + intervals = [] + for i in selected_group_idxs: + start, end = interval_groups[i][0].start, interval_groups[i][0].end + if (start, end, device_num) in no_solution_states: + continue + intervals.append((start, end)) + _logger.info( + f'process case: tp {device_num}, s {stage_num}, {len(intervals)} intervals' + ) + solver_ret = solver.solve(intervals, 1) + return intervals, solver_ret + + def _calc_upper_bound(tp_degree: int): + # bubble time percentage <= bubble_ratio: + # (stage_num - 1) / (stage_num - 1 + micro_batch_num) <= bubble_ratio + # stage_num <= 1 + bubble_ratio * micro_batch_num / (1 - bubble_ratio) + bubble_ratio = cfg.max_pipeline_bubble_ratio + micro_batch_num = cfg.update_freq + upper_bound = math.floor(bubble_ratio / + (1 - bubble_ratio) * micro_batch_num + 1) + return min(cfg.mesh_desc.ngpus - tp_degree + 1, upper_bound) + + # TODO(yizhu1): use multiprocessing to speed up + tp_info = {} + for tp_degree in legal_tp_degrees: + for stage_num in range(1, _calc_upper_bound(tp_degree) + 1): + intervals, solver_ret = process_case(tp_degree, stage_num) + for interval, spmd_descs in zip(intervals, solver_ret): + start, end = interval + if spmd_descs: + for group in interval_groups: + if group[0].start == start and group[0].end == end: + for interval in group: + tp_info[(tp_degree, stage_num, interval.start, + interval.end)] = spmd_descs[0] + else: + no_solution_states.add((start, end, tp_degree)) + _logger.info( + f'fail to find a valid plan for {start}, {end}') + _logger.info('finish computing tp info') + return tp_info + + +def calc_optimal_pp_plan( + model_graph: ModelGraph, + autodist_config: AutoDistConfig) -> PipelineSearchOutput: + # TODO: based on experience, tensor parallelism should <= 8 + legal_tp_degrees = _calc_legal_tp_degrees( + min(8, autodist_config.mesh_desc.col)) + + tp_info = _compute_tp_info(model_graph, autodist_config, legal_tp_degrees) + ''' + T: dynamic programming table + T[s, pp, tp, i]: optimal time of a pipeline state, where + - s: stage number + - pp: device number used for this state + - tp: device number used for the 1st pipeline stage in this state + - i: start operator index + Transitions of T: + - leaf state: + tp == pp and s == 1: means current tp is the last one in the pipeline + T[1, pp, tp, i] = tp[tp][s][i][end_op_idx] + - non-leaf state: + T[s, pp, tp, i] = min(max(T[s-1, pp-tp, tp', j+1], tp[tp][s][i][j])) + store optimal path during dynamic programming in T as well + ''' + ngpus = autodist_config.mesh_desc.ngpus + pp_idxs = [0] + model_graph.get_pipeline_pivots() + [model_graph.op_num] + T = {} + for s in range(1, ngpus + 1): + for pp in range(s, ngpus + 1): + for tp in range(1, pp - s + 1 + 1): + if tp not in legal_tp_degrees: + continue + for ii in range(len(pp_idxs) - 1 - 1, 0 - 1, -1): + i = pp_idxs[ii] + cur_idx = (s, pp, tp, i) + T[cur_idx] = [float('inf'), (-1, -1, -1, -1)] + + if tp == pp and s == 1: + tp_idx = (tp, s, i, model_graph.op_num - 1) + if tp_idx in tp_info: + T[cur_idx][0] = tp_info[tp_idx].all_time + continue + + for jj in range(len(pp_idxs) - 1 - 1, ii, -1): + j = pp_idxs[jj] + next_pp = pp - tp + for next_tp in range(1, next_pp - (s - 1) + 1 + 1): + if next_tp not in legal_tp_degrees: + continue + prev_idx = (s - 1, next_pp, next_tp, j) + if prev_idx not in T: + continue + prev_tp_idx = (tp, s, i, j - 1) + if prev_tp_idx not in tp_info: + continue + lhs, _ = T[prev_idx] + rhs = tp_info[prev_tp_idx].all_time + val = max(lhs, rhs) + if T[cur_idx][0] > val: + T[cur_idx] = [val, prev_idx] + + best_time = float('inf') + best_state = (-1, -1, -1, -1) + micro_batch_num = autodist_config.update_freq + for stage_num in range(1, ngpus + 1): + for pp_dev_num in range(stage_num, ngpus + 1): + for tp_degree in range(1, pp_dev_num - stage_num + 1 + 1): + if tp_degree not in legal_tp_degrees: + continue + + cur_idx = (stage_num, pp_dev_num, tp_degree, 0) + if cur_idx not in T: + continue + cur_time = T[cur_idx][0] * (micro_batch_num - 1 + stage_num) + if best_time > cur_time: + best_time, best_state = cur_time, cur_idx + + _logger.info( + f'best time/s: {best_time}, state (s, pp, tp, i): {best_state}') + if best_state == (-1, -1, -1, -1): + raise RuntimeError('fail to find a valid pipeline plan') + + spmd_outs = [] + + def build_answer(s, pp, tp, i): + _, prev_idx = T[(s, pp, tp, i)] + if prev_idx[0] == -1: + tp_idx = (tp, s, i, pp_idxs[-1] - 1) + else: + j_plus_1 = prev_idx[3] + tp_idx = (tp, s, i, j_plus_1 - 1) + spmd_outs.append(tp_info[tp_idx]) + if prev_idx[0] != -1: + build_answer(*prev_idx) + + build_answer(*best_state) + + spmd_descs = [spmd_out.desc for spmd_out in spmd_outs] + pp_desc = PipelineParallelDesc(spmd_descs, [], autodist_config.mesh_desc) + stage_mems = [spmd_out.memory for spmd_out in spmd_outs] + stage_all_times = [spmd_out.all_time for spmd_out in spmd_outs] + stage_comp_times = [spmd_out.comp_time for spmd_out in spmd_outs] + return PipelineSearchOutput(pp_desc, best_time, stage_mems, stage_all_times, + stage_comp_times) diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py new file mode 100644 index 00000000..b478e598 --- /dev/null +++ b/nnscaler/autodist/spmd_solver.py @@ -0,0 +1,1167 @@ +from .model_graph import ModelGraph +from .cube_operator import CubeOperator +from .descs import * +from .util import ( + int2byte, + int4byte, + double2byte, + double4byte, +) +from .cost_database import CostDatabase +from .autodist_config import AutoDistConfig +from .op_partition import OpPartition, generate_partitions +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir import IRTensor + +import os +import copy +import time +import json +import yaml +import numpy +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, Tuple, List, Set, Any + +__all__ = [ + 'SPMDSolver', + 'calc_optimal_spmd_plan', +] + +_logger = logging.getLogger(__name__) + + +@dataclass +class PartitionCostDesc: + # the computation time of a partition + comp_time: float + # the communication cost when updating the weights + # currently, it is estimated by allreduce time + weight_update_time: float + # the minimum memory required for the partition, including + # 1. activation memory + # 2. weight memory (weight needs gradient) + # 3. gradient memory, assume the same type as weight memory + # 4. buffer memory (weight does not need gradient) + # 5. optimizer resident memory: like 1st and 2nd moment in Adam + mem: int + # additional transient mem cost: currently use the maximum tensor size + transient_mem: int + # sum of the activation tensor size + activation_mem: int + # transient memory size in the optimizer, e,g. some optimizers cast + # the weight and gradient before stepping + opt_transient_mem: int + # comm_time[i][j] is the communication time between current partition + # and i-th producer's j-th partition + comm_time: List[List[float]] + + def __repr__(self): + contents = dict() + for k, v in self.__dict__.items(): + if 'mem' in k: + k_in_mb = k + ' (MB)' + contents[k_in_mb] = v // 1024 // 1024 + else: + contents[k] = v + return str(contents) + + +class SPMDSolver: + + def __init__( + self, + graph: ModelGraph, + autodist_config: AutoDistConfig, + mesh_desc: MeshDesc, + stage_num: int = 1, + micro_batch_num: int = 1, + ): + self.mesh_desc = mesh_desc + if mesh_desc.row != 1: + raise RuntimeError(f'mesh row should be 1, but got {mesh_desc.row}') + self.device_num = mesh_desc.col + self.autodist_config = autodist_config + self.micro_batch_num = micro_batch_num + self.mem_bound = autodist_config.memory_constraint + self.verbose = autodist_config.verbose + self.is_train = autodist_config.is_train + self.graph = graph + self.pcs: Dict[str, Dict[str, PartitionConstraint]] = dict() + self.non_used_pcs: Set[PartitionConstraint] = set() + if autodist_config.pc_path: + self._load_partition_constraints(autodist_config.pc_path) + else: + _logger.info('no partition constraint is loaded') + + self.cost_database = graph.cost_database + self.cost_database.profile_comp(self.device_num) + self.stage_num = stage_num + + # assume the dataflow graph is + # a + # / \ + # b c + # | / \ + # d e f + # | / | + # g h + # the ops are stored in a topological order [a, b, d, c, e, f, g, h] + # in spmd solver, dynamic programming is used to find the optimal partition plan + # dp[p(u), M] is the optimal plan for the subgraph ending with u in partition state p, + # with memory bound M. if v is the predecessor of u in the topological order, then + # dp[p(u), M] = min(dp[q(v), M - mem(p(u))] + comm_cost(p(u), q(v))) + comp_cost(p(u)) + comm_cost(p(u)) + # where q(v) is the partition state of v, mem(p(u)) is the memory cost of p(u), + # comm_cost(p(u), q(v)) is the communication cost between p(u) and q(v), and + # comp_cost(p(u)) is the computation cost of p(u), comm_cost(p(u)) is the communication + # cost of p(u) (like the allreduce cost in model update). + # However, u and v may not be connected in the dataflow graph, like [d, c], [e, f], [f, g] + # and [g, h] in the example above. To calculate the communication cost between p(u) and q(v), + # we need to store additional information in the partition state. For example, we need to maintain + # the partition state of node d in the partition state of node c, so that we can calculate the + # communication cost when reaching node g. + # to achieve this, we calcuate the 'cut ops' for each node, which is the set of nodes that + # need to be maintained in the partition state of the current node. The cut ops for the example + # above are: + # a: [a] + # b: [a, b] + # d: [a, d] + # c: [d, c] + # e: [d, c, e] + # f: [d, e, f] + # g: [f, g] + # h: [h] + self.initialize() + + def initialize(self): + self.build_cut_ops() + self.init_op_partitions() + self.build_following_relationships() + self.calc_partition_info() + + def _load_partition_constraints(self, pc_path: str): + pc_path = Path(pc_path) + if pc_path.exists(): + try: + with open(pc_path, 'r') as f: + pc_yamls = yaml.safe_load(f) + for pc_yaml in pc_yamls: + pc = PartitionConstraint.from_yaml(pc_yaml) + self.non_used_pcs.add(pc) + cur_name_pcs = self.pcs.setdefault(pc.name, {}) + if pc.parent_module in cur_name_pcs: + _logger.warning( + f'find duplicate partition constraint in {pc.parent_module}, omit {pc}' + ) + else: + cur_name_pcs[pc.parent_module] = pc + except Exception: + _logger.exception( + f'fail to load partition constraints from {pc_path}') + self.pcs = dict() + else: + _logger.warning(f'pc path {pc_path} does not exist') + + def get_operator(self, idx: int) -> CubeOperator: + return self.graph.operator_list[idx] + + def get_op_partition_count(self, idx: int) -> int: + return len(self._op_partitions[idx]) + + def init_op_partitions(self): + ''' + Autodist adopts a heuristic to force operators to be replicated if the following conditions + are satisfied: + 1. the operator does not have a batch dimension, i.e., the operator is not data dependent + 2. the sum of inputs and outputs size is smaller than `force_replica_threshold`, which implies + that the operator is small enough to be replicated + ''' + force_replica_threshold = 0 + for operator in self.graph.operator_list: + # In modern deep learning models + # 1. norm operators are in the backbone of the model (layernorm in transformer, batchnorm in CNN, etc.) + # 2. the output size of a norm operator is related to the batch size (do not replicate norm operators) + # As a result, if the sum of inputs and outputs size of an operator is smaller than + # the minimum output size of a norm operator, replicate it is safe. + if 'norm' in operator.op_name.lower(): + norm_size = operator.out_tensors[0].nelement() + if force_replica_threshold == 0: + force_replica_threshold = norm_size + else: + force_replica_threshold = min(force_replica_threshold, + norm_size) + _logger.info(f'force_replica_threshold is {force_replica_threshold}') + + def should_force_replica(operator: CubeOperator) -> bool: + if operator.has_batch_dim: + return False + cnt = 0 + for item in operator.ir_cell.inputs(): + if isinstance(item, IRTensor): + cnt += item.nelement() + for item in operator.ir_cell.outputs(): + if isinstance(item, IRTensor): + cnt += item.nelement() + return cnt < force_replica_threshold + + # Do not allow to partition shared parameters currently + param2consumers = defaultdict(list) + for operator in self.graph.operator_list: + for tensor in operator.ir_cell.inputs(): + if isinstance(tensor, IRTensor) and tensor.is_param(): + param2consumers[tensor].append(operator.ir_cell) + shared_param_constraints = defaultdict(set) + for param, consumers in param2consumers.items(): + if len(consumers) == 1: + continue + _logger.info(f'find shared parameter {param} in {consumers}') + for consumer in consumers: + if not isinstance(consumer, IRDimops): + # always replicate non-dimops + continue + idx = consumer.inputs().index(param) + shape_anno = consumer.anno.input(idx) + for dim_anno in shape_anno.dims: + shared_param_constraints[consumer].add(dim_anno.name) + + def is_valid_partition(operator: CubeOperator, p_ids: List[Any], + p_nums: List[int]) -> bool: + ''' + use a function to filter invalid partitions. Note: the partition + representation will be refined in the future. + Args: + operator (CubeOperator): the operator to be partitioned + p_ids (List[Any]): the partition identifiers, -1 means replicated + p_nums (List[int]): the partition numbers + Returns: + bool: True if the partition is valid, False otherwise + + Examples: + >>> # matmul with annotation m k+, k+ n -> m n + >>> # partition on 8 devices, possible inputs + >>> # partition along dim 'm' + >>> is_valid_partition(matmul, ['m',], [8,]) + >>> # replicated across all devices + >>> is_valid_partition(matmul, [-1,], [8,]) + >>> # partition along dim 'm', 'n' and 'k' (currently not supported) + >>> # each device has a partial value with shape: + >>> m // 2 k // 2, k // 2 n // 2 -> m // 2, n // 2 + >>> is_valid_partition(matmul, ['m', 'n', 'k'], [2, 2, 2]) + >>> # partition along dim 'm' and -1 (currently not supported) + >>> # each device has a complete value with shape: + >>> m // 2 k, k n -> m // 2, n + >>> is_valid_partition(matmul, ['m', -1], [2, 4]) + ''' + if len(p_ids) != len(p_nums): + raise RuntimeError( + f'invalid partition {p_ids} {p_nums} for {operator.op_name}' + ) + + if self.mesh_desc.col == 1: + return True + + # in order to reduce search space and simplify the communication pattern, + # we constrain operators to be partitioned along only one dimension + for u, v in zip(p_ids, p_nums): + if v != self.mesh_desc.col: + return False + + if len(p_ids) != 1: + raise RuntimeError( + f'exactly one dimension should be partitioned, but got {p_ids} {p_nums}' + ) + + # force replica for non-dimops + if not isinstance(operator.ir_cell, IRDimops): + return p_ids[0] == -1 + + p_idx, p_dim = operator.dim_id2pos(p_ids[0]) + + if operator.ir_cell in shared_param_constraints and isinstance( + operator.ir_cell, IRDimops): + if (p_ids[0] != -1) and ( + p_ids[0] in shared_param_constraints[operator.ir_cell]): + return False + + if operator.op_name in self.pcs: + if not isinstance(operator.ir_cell, IRDimops): + raise RuntimeError( + f'operator {operator.op_name} is not a dimops, check the partition constraint' + ) + module_info = [(module_path, module_type) + for module_path, module_type in + operator.ir_cell.module_stack.items()] + module_info.reverse() + selected_pc = None + for module_path, module_type in module_info: + for pc in self.pcs[operator.op_name].values(): + if pc.parent_module in str(module_type): + selected_pc = pc + _logger.debug( + f'find partition constraint {pc} for {operator.ir_cell}' + ) + break + if selected_pc is not None: + break + if selected_pc is not None: + _logger.debug( + f'find partition constraint {selected_pc} for {operator.ir_cell} {module_info}' + ) + self.non_used_pcs.discard(selected_pc) + for u, v in zip(p_ids, p_nums): + if u == -1: + if not selected_pc.replica_allowed: + return False + else: + allowed_pids = [ + operator.pos2dim_id(pos) + for pos in selected_pc.allowed_partition_dims + ] + if u not in allowed_pids: + return False + + if p_ids[0] != -1: + if not operator.ir_cell.algorithms('dim').satisfy( + p_idx, p_dim, p_nums[0]): + return False + return True + + def build_op_partitions(operator: CubeOperator) -> List[OpPartition]: + # force replica for non-dimops + if not isinstance(operator.ir_cell, IRDimops): + candidates = [((-1,), (self.device_num,))] + else: + if should_force_replica(operator): + _logger.debug(f'force replica {operator.ir_cell}') + candidates = [((-1,), (self.device_num,))] + elif self.device_num == 1: + candidates = [((-1,), (1,))] + else: + # python set is not stable + p_dims = [-1] + sorted(operator.parallelable_dims) + + candidates = generate_partitions(p_dims, self.device_num) + + op_partitions = [] + for dim_ids, p_nums in candidates: + if is_valid_partition(operator, dim_ids, p_nums): + op_partitions.append( + OpPartition( + partition_dims=dim_ids, + partition_nums=p_nums, + operator=operator, + )) + + return op_partitions + + # generate partitions for each operator + self._op_partitions: List[List[OpPartition]] = list() + replicated_ops = defaultdict(list) + for i, operator in enumerate(self.graph.operator_list): + self._op_partitions.append(build_op_partitions(operator)) + if not self._op_partitions[-1]: + raise RuntimeError( + f'node {operator} has no valid partition, check profiler, partition constraint and filter' + ) + if len(self._op_partitions[-1]) == 1 and \ + isinstance(operator.ir_cell, IRDimops): + if operator.ir_cell == self._op_partitions[-1][0].ir_cell: + replicated_ops[operator.ir_cell.signature].append( + operator.ir_cell) + if replicated_ops: + for signature, ops in replicated_ops.items(): + _logger.info(f'find {len(ops)} replicated {signature}') + for op in ops: + _logger.info(f'\t{op}\n\t{op.comment}\n\n') + if self.non_used_pcs: + _logger.warning( + f'find unused partition constraints {self.non_used_pcs}') + _logger.info('finish building op partitions') + + # use a union-find set to find the oldest operator in a following chain + def get_father_id(self, i): + if self.father_ids[i] == i: + return i + self.father_ids[i] = self.get_father_id(self.father_ids[i]) + return self.father_ids[i] + + def build_following_relationships(self): + # self.producers[i]: the indices of the operators that produce tensors for the i-th operator + self.producers: List[List[int]] = list() + for i, operator in enumerate(self.graph.operator_list): + self.producers.append([ + self.graph.get_op_idx(producer) + for producer in operator.producers + ]) + + # important: build following relationships + # a + # / \ + # b c + # | | + # d e + # | | + # f g + # \ / + # h + # a: layer norm + # b, c: reshape + # d, e: transpose + # f, g: view + # h: matmul + # assume operators are stored in a topological order [a, b, d, f, c, e, g, h] + # in order to reduce the search space and keep the partition plan optimal, + # we group some operators into 4 following chains + # 1. a + # 2. b -> d -> f + # 3. c -> e -> g + # 4. h + # in a chain, there are no communication adapters between operators. + # follow_ids[i] is the index of the operator that i follows, if follow_ids[i] = i, i is the oldest operator in the chain + # father_ids[i] is the index of the oldest operator in the following chain that i belongs to + # in the example above, + # follow_ids = [0, 1, 1, 2, 4, 4, 5, 7] + # father_ids = [0, 1, 1, 1, 4, 4, 4, 7] + self.follow_ids = list(range(self.graph.op_num)) + self.father_ids = list(range(self.graph.op_num)) + + for i, op in enumerate(self.graph.operator_list): + # - op consumes tensors from only one producer + # - op has only one input tensor + # - the producer has only one input tensor + if len(self.producers[i]) == 1: + if len(op.in_tensors) == 1: + j = self.producers[i][0] + # constrain the following chain starts from a unary operator + if len(self.graph.operator_list[j].in_tensors) == 1: + self.follow_ids[i] = j + self.father_ids[i] = self.get_father_id(j) + + _logger.info('finish building following relationships') + + # after follow, only keep the newest one in cut ops + for i in range(self.graph.op_num): + fathers = set() + pre_cut_ops = copy.copy(self.cut_ops[i]) + self.cut_ops[i] = [] + for j in range(len(pre_cut_ops)): + u = pre_cut_ops[-1 - j] + if self.get_father_id(u) in fathers: + continue + else: + fathers.add(self.get_father_id(u)) + self.cut_ops[i].append(u) + self.cut_ops[i].sort() + + def find_idx_map(src_op, tgt_op): + ret = [] + for i, src_t in enumerate(src_op.ir_cell.outputs()): + if not isinstance(src_t, IRTensor): + continue + for j, tgt_t in enumerate(tgt_op.ir_cell.inputs()): + if not isinstance(tgt_t, IRTensor): + continue + if src_t == tgt_t: + ret.append((i, j)) + return ret + + # After building following relationships for each operator, we can build the following chains + # for each operator's each partition. The communication cost between partitions in a + # following chain is 0, e,g. no communication adapter will be generated. + # p_fathers[i][j]: + # assume i-th operator is in the following chain indexed by fi = get_father_id(i), + # i-th operator's j-th partition follows fi-th operator's p_fathers[i][j]-th partition + # For example, there is a following chain composed of 3 operators: + # x1 = layer_norm(x0), annotation: a, b, c^ -> a, b, c^ + # x2 = permute(x1, [0, 2, 1]), annotation: a, b, c -> a, c, b + # x3 = gelu(x2), annotation: a, b, c -> a, b, c + # assume x0's shape is [2, 1024, 4096] and the device number is 2 + # then the partitions for each operators are: + # layer_norm: [(-1,), (2,)], [('a',), (2,)], [('b',), (2,)] + # permute: [(-1,), (2,)], [('a',), (2,)], [('b',), (2,)], [('c',), (2,)] + # gelu: [(-1,), (2,)], [('a',), (2,)], [('b',), (2,)], [('c',), (2,)] + # the p_fathers for each operator are: + # layer_norm: [0, 1, 2] + # permute: [0, 1, 2, -1] + # gelu: [0, 1, -1, 2] + def calc_father4op_partition(): + p_fathers = [] + father_id2preserved_pids = {} + for i in range(self.graph.op_num): + fi = self.get_father_id(i) + if fi == i: + p_fathers.append(list(range( + self.get_op_partition_count(i)))) + father_id2preserved_pids[i] = set(p_fathers[-1]) + else: + cur_p_fathers = [-1] * self.get_op_partition_count(i) + for producer in self.producers[i]: + if self.get_father_id(producer) != fi: + continue + # assume there is only one tensor from producer to consumer + idx_map = find_idx_map(self.get_operator(producer), + self.get_operator(i)) + if len(idx_map) != 1: + raise RuntimeError( + f'find multiple or no idx_map {idx_map}') + u, v = idx_map[0] + for j, tgt_p in enumerate(self._op_partitions[i]): + have_changed = False + p_father = -1 + for k, src_p in enumerate( + self._op_partitions[producer]): + # use shape to check follow relationship between partitions + # TODO: is this correct? what if the shape is the same but the partition is different? + if src_p.ir_cell.outputs()[u].shape == tgt_p.ir_cell.inputs()[v].shape and \ + not src_p.is_partial_val: + p_producer = p_fathers[producer][k] + if p_producer == -1: + p_father = -1 + else: + if not have_changed: + p_father = p_producer + have_changed = True + # if p_father = -1, this partition will be filtered out + if cur_p_fathers[j] != -1: + assert p_father == cur_p_fathers[ + j], f'{i} {self.get_operator(i).ir_cell} {fi} {self.get_operator(fi).ir_cell}' + cur_p_fathers[j] = p_father + p_fathers.append(cur_p_fathers) + # -1 will be filtered out in the intersection operation below + father_id2preserved_pids[fi] = father_id2preserved_pids[ + fi].intersection(set(p_fathers[-1])) + return p_fathers, father_id2preserved_pids + + p_fathers, father_id2preserved_pids = calc_father4op_partition() + + # filter useless partitions in following chains + for i in range(self.graph.op_num): + filtered_partitions = [] + fi = self.get_father_id(i) + for p_father, partition in zip(p_fathers[i], + self._op_partitions[i]): + if p_father in father_id2preserved_pids[fi]: + filtered_partitions.append(partition) + self._op_partitions[i] = filtered_partitions + if not filtered_partitions: + raise RuntimeError( + f'fail to find valid partition for {self.get_operator(i).ir_cell}' + ) + + self.p_fathers, _ = calc_father4op_partition() + + # reorder partition + for i in range(self.graph.op_num): + p_num = self.get_op_partition_count(i) + if p_num == 1: + continue + if self.get_father_id(i) == i: + continue + cur_p_fathers = self.p_fathers[i] + partitions = [None] * p_num + for j, p_father in enumerate(self.p_fathers[i]): + if p_father == -1: + raise RuntimeError(f'find -1 in p_fathers for operator {i}') + partitions[p_father] = self._op_partitions[i][j] + self._op_partitions[i] = partitions + self.p_fathers[i] = list(range(p_num)) + + _logger.info('finish filtering useless partitions') + + def calc_partition_cost(self, op_idx: int, partition_idx: int): + """ + Calculate the latency, memory, and communication features of a partition option. + + Args: + op_idx: the index of the current op + partition_idx: the index of the current partition option + + Returns: + a PartitionCostDesc object containing the calculated features + """ + micro_batch_num = self.micro_batch_num + is_train = self.autodist_config.is_train + tgt_p = self._op_partitions[op_idx][partition_idx] + if is_train: + # only calculate the communication cost for the weight that requires gradient + weights_require_grad = [] + for in_tensor in tgt_p.operator.ir_cell.inputs(): + if isinstance(in_tensor, IRTensor) and in_tensor.is_param(): + weights_require_grad.append(in_tensor.requires_grad) + # currently not support the case that there are two weights, one requires grad, the other not + # TODO: support this case when we encounter it. + assert all(weights_require_grad) or not any( + weights_require_grad + ), f'expect all weights require grad or not, got {weights_require_grad}' + if isinstance(tgt_p, IRDimops) and any(weights_require_grad): + weight_comm_time = self.cost_database.calc_weight_update_time( + cur_partition=tgt_p) + else: + weight_comm_time = 0 + else: + weight_comm_time = 0 + + if not self.autodist_config.consider_mem: + node_mem, node_buffer, act_mem, opt_transient_mem = 0, 0, 0, 0 + else: + node_mem, node_buffer, act_mem, opt_transient_mem = self.cost_database.get_mem_and_buffer( + tgt_p, self.is_train, self.stage_num) + + # communication cost induced by partitioning activation tensors of the given op partition + comm_vecs = [] + for producer in self.producers[op_idx]: + comm_vec = [0.0] * self.get_op_partition_count(producer) + for k, src_p in enumerate(self._op_partitions[producer]): + fw_comm_time = self.cost_database.estimate_comm_cost( + src_p, tgt_p, True) + if is_train and src_p.operator.ir_cell.mirror is not None: + bw_comm_time = self.cost_database.estimate_comm_cost( + src_p, tgt_p, False) + else: + bw_comm_time = 0 + intra_time = micro_batch_num * (fw_comm_time + bw_comm_time) + # double check the follow chain + if self.get_father_id(op_idx) == self.get_father_id( + producer) and intra_time == 0: + if src_p.operator.ir_cell.mirror is not None: + if self.p_fathers[op_idx][ + partition_idx] != self.p_fathers[producer][k]: + _logger.warning( + f'Unexpected comm cost, set to inf: {src_p.ir_cell} to {tgt_p.ir_cell}' + ) + intra_time = float('inf') + comm_vec[k] = intra_time + + comm_vecs.append(comm_vec) + + if isinstance(tgt_p.ir_cell, IRDimops): + comp_time = self.cost_database.query_comp_time( + op_or_partition=tgt_p, + recompute=tgt_p.operator.recompute, + is_train=is_train) + else: + comp_time = 0.0 + + return PartitionCostDesc( + comp_time=micro_batch_num * comp_time, + weight_update_time=weight_comm_time, + mem=node_mem, + transient_mem=node_buffer, + activation_mem=act_mem, + opt_transient_mem=opt_transient_mem, + comm_time=comm_vecs, + ) + + def calc_partition_info(self): + self.partition_info: List[List[PartitionCostDesc]] = list() + for i in range(self.graph.op_num): + cur_info = [] + _logger.info(f'calc partition info for {self.get_operator(i)}') + for j in range(self.get_op_partition_count(i)): + cost_desc = self.calc_partition_cost(i, j) + if cost_desc.comp_time == float('inf'): + _logger.warning( + f'profile error {self.get_operator(i).ir_cell}, reset compute time to 0.0' + ) + cost_desc.comp_time = 0.0 + cur_info.append(cost_desc) + _logger.info(f'{self._op_partitions[i][j]} {cost_desc}') + self.partition_info.append(cur_info) + _logger.info('finish spmd solver initializetion') + + def estimate_min_mem(self, start: int, end: int) -> int: + ''' + different from the estimation in ModelGraph, this function + in smaller granularity, i.e., the memory cost of a partition. + it helps to reduce the search cost in pipeline parallelism. + + Args: + start (int): the left index of the interval + end (int): the right index of the interval + + Returns: + int: the estimated minimum memory cost of the interval in bytes + ''' + node_mem, act_mem, opt_mem, tmp_mems = 0, 0, 0, [] + for i in range(start, end + 1): + cur_node_mem, cur_act_mem, cur_opt_mem, tmp_mem = [], [], [], [] + for j, tgt_p in enumerate(self._op_partitions[i]): + p_cost_desc = self.partition_info[i][j] + cur_node_mem.append(p_cost_desc.mem) + cur_act_mem.append(p_cost_desc.activation_mem) + cur_opt_mem.append(p_cost_desc.opt_transient_mem) + tmp_mem.append(p_cost_desc.transient_mem) + node_mem += min(cur_node_mem) + act_mem += min(cur_act_mem) + opt_mem += min(cur_opt_mem) + tmp_mems.append(min(tmp_mem)) + min_mem = node_mem - act_mem + max(act_mem, opt_mem) + if not tmp_mems: + raise RuntimeError('fail to estimate min mem') + tmp_mems.sort() + tmp_mems.reverse() + if len(tmp_mems) == 1 or not self.autodist_config.is_train: + return min_mem + tmp_mems[0] + else: + return min_mem + tmp_mems[0] + tmp_mems[1] + + def gen_min_mem_plan_greedy(self, start: int, + end: int) -> List[Tuple[int, int]]: + ''' + generate the minimum memory plan for the interval [start, end] in a greedy way. + for each operator, we choose the partition with the minimum memory cost. + NOTE: do not guarantee the plan satisfies the memory constraint. + + Args: + start (int): the left index of the interval + end (int): the right index of the interval + + Returns: + List[Tuple[int, int]]: the minimum memory plan + ''' + plan = [] + for i in range(start, end + 1): + cur_mem = [] + for desc in self.partition_info[i]: + cur_mem.append(desc.mem) + plan.append((i, cur_mem.index(min(cur_mem)))) + return plan + + def satisfy_mem_constraint(self, plan: List[Tuple[int, int]]) -> bool: + mem, act_mem, opt_transient_mem, transient_mem = 0, 0, 0, [] + for op_idx, p_idx in plan: + desc = self.partition_info[op_idx][p_idx] + mem += desc.mem + act_mem += desc.activation_mem + opt_transient_mem += desc.opt_transient_mem + transient_mem.append(desc.transient_mem) + cost = mem - act_mem + max(act_mem, opt_transient_mem) + # A heuristic that helps to estimate the memory cost accurately. + # It is hard to fully reuse large memory blocks in the cached allocator. + # - in training, use the maximum 2 transient memory + # - in inference, use the largest transient memory + if transient_mem: + transient_mem.sort() + transient_mem.reverse() + if len(transient_mem) == 1 or not self.autodist_config.is_train: + cost += transient_mem[0] + else: + cost += transient_mem[0] + transient_mem[1] + return cost <= self.mem_bound + + def build_cut_ops(self): + cid2idx = {} + for i, op in enumerate(self.graph.operator_list): + cid2idx[op.ir_cell.cid] = i + out_degs = [len(op.consumers) for op in self.graph.operator_list] + unclosed_idx = set() + self.cut_ops: List[List[int]] = list() + for i, op in enumerate(self.graph.operator_list): + for pred in op.producers: + pred_idx = cid2idx[pred.ir_cell.cid] + assert pred_idx in unclosed_idx + out_degs[pred_idx] -= 1 + if out_degs[pred_idx] == 0: + unclosed_idx.remove(pred_idx) + ret = list(unclosed_idx) + [i] + ret.sort() + self.cut_ops.append(ret) + if len(op.consumers) > 0: + unclosed_idx.add(i) + + def _solve_by_ilp(self, start: int, end: int) -> SPMDSearchOutput: + import pulp + import multiprocessing + from pulp import LpVariable, LpProblem, LpMinimize, LpStatus, lpSum, lpDot + tic = time.time() + + # 1. define the variables + # s[i][j] = 1 if the i-th operator selects the j-th partition + s = [] + # e[i][j][k] = 1, the i-th edge's source selects the j-th partition and the destination selects the k-th partition + e = [] + + num_nodes = 0 + for i in range(start, end + 1): + fi = self.get_father_id(i) + p_num = self.get_op_partition_count(i) + if fi == i or fi < start: + if p_num == 1: + s.append([1]) + else: + num_nodes += 1 + s.append( + LpVariable.matrix(f's[{i}]', (range(p_num),), + cat='Binary')) + else: + s.append(s[fi - start]) + + num_edges = 0 + for dst in range(start, end + 1): + for src in self.producers[dst]: + j = dst - start + i = src - start + # in pipeline parallelism, the producer may be in the previous stage + # omit the communication cost in this case + if i < 0: + continue + if len(s[i]) == 1: + e.append(s[j]) + elif len(s[j]) == 1: + e.append(s[i]) + else: + num_edges += 1 + e.append( + LpVariable.matrix(f'e[{i},{j}]', + (range(len(s[i]) * len(s[j])),), + cat='Binary')) + + # 2. set initial value for warm start + plan = self.gen_min_mem_plan_greedy(start, end) + for op_idx, p_idx in plan: + s_idx = op_idx - start + if len(s[s_idx]) == 1: + continue + for i in range(len(s[s_idx])): + s[s_idx][i].setInitialValue(i == p_idx) + + # 3. define the objective function + prob = LpProblem('SPMD', LpMinimize) + # inner cost + obj = 0 + for i in range(start, end + 1): + cost = [] + for desc in self.partition_info[i]: + cost.append(desc.comp_time + desc.weight_update_time) + obj += lpDot(s[i - start], cost) + + # intra communication cost + offset = 0 + for dst in range(start, end + 1): + dst_p_num = self.get_op_partition_count(dst) + j = dst - start + for idx, src in enumerate(self.producers[dst]): + if src < start: + continue + src_p_num = self.get_op_partition_count(src) + i = src - start + cost = [0 for _ in range(src_p_num * dst_p_num)] + for k, desc in enumerate(self.partition_info[dst]): + for l in range(src_p_num): + cost[l * dst_p_num + k] = desc.comm_time[idx][l] + obj += lpDot(e[offset], cost) + offset += 1 + assert offset == len(e) + + prob += obj + # 4. define the constraints + + # 4.1. each node can only choose one partition + for i in range(start, end + 1): + fi = self.get_father_id(i) + if fi == i or fi < start: + prob += lpSum(s[i - start]) == 1 + + # 4.2. satisfy memory constraint + mem = 0 + act_mem = 0 + opt_transient_mem = 0 + max_act_opt_transient = LpVariable('max_act_opt_transient', lowBound=0) + max_transient = LpVariable('max_transient', lowBound=0) + for i in range(start, end + 1): + cur_mem = [] + cur_act_mem = [] + cur_opt_transient_mem = [] + cur_transient_mem = [] + for desc in self.partition_info[i]: + cur_mem.append(desc.mem) + cur_act_mem.append(desc.activation_mem) + cur_opt_transient_mem.append(desc.opt_transient_mem) + cur_transient_mem.append(desc.transient_mem) + mem += lpDot(s[i - start], cur_mem) + act_mem += lpDot(s[i - start], cur_act_mem) + opt_transient_mem += lpDot(s[i - start], cur_opt_transient_mem) + prob += lpDot(s[i - start], cur_transient_mem) <= max_transient + prob += act_mem <= max_act_opt_transient + prob += opt_transient_mem <= max_act_opt_transient + if self.autodist_config.is_train: + transient_coef = 2 + else: + transient_coef = 1 + prob += mem - act_mem + max_act_opt_transient + transient_coef * max_transient <= self.mem_bound + + # 4.3. constraint over e + offset = 0 + for dst in range(start, end + 1): + for src in self.producers[dst]: + if src < start: + continue + dst_p_num = self.get_op_partition_count(dst) + src_p_num = self.get_op_partition_count(src) + if dst_p_num == 1 or src_p_num == 1: + offset += 1 + continue + prob += lpSum(e[offset]) == 1 + j = dst - start + i = src - start + for row in range(src_p_num): + prob += lpSum([ + e[offset][row * dst_p_num + col] + for col in range(dst_p_num) + ]) <= s[i][row] + for col in range(dst_p_num): + prob += lpSum([ + e[offset][row * dst_p_num + col] + for row in range(src_p_num) + ]) <= s[j][col] + offset += 1 + assert offset == len(e) + assert 'PULP_CBC_CMD' in pulp.listSolvers(onlyAvailable=True), ( + "Please install ILP solvers by 'sudo apt install coinor-cbc'") + + solver = pulp.PULP_CBC_CMD(mip=True, + msg=self.verbose, + timeLimit=600, + threads=multiprocessing.cpu_count()) + + prob.solve(solver) + status = prob.status + objective = pulp.value(prob.objective) + # corner case: no variables + if num_nodes == 0: + assert num_edges == 0 + objective = obj.constant + else: + objective = float(objective) if objective is not None else -1.0 + _logger.debug(f'\n {prob}') + _logger.debug( + f'status: {status}, objective: {objective}, time: {time.time() - tic}' + ) + if prob.status in [pulp.LpStatusInfeasible] or objective < 0: + return None + + def get_non_zero_index(binary_vector): + """Get the index of non-zero item in a vector.""" + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + assert ct == 1 + return ret + + s_val = [-1] * (end - start + 1) + for i in range(start, end + 1): + s_val[i - start] = get_non_zero_index(s[i - start]) + e_val = [-1] * len(e) + offset = 0 + for dst in range(start, end + 1): + for src in self.producers[dst]: + if src < start: + continue + j = dst - start + i = src - start + e_val[offset] = get_non_zero_index(e[offset]) + i_spec_index = e_val[offset] // len(s[j]) + j_spec_index = e_val[offset] % len(s[j]) + assert s_val[i] == i_spec_index + assert s_val[j] == j_spec_index + offset += 1 + plans = [] + all_time_cost = objective + inner_time_cost, mem_cost = 0, 0 + for i in range(start, end + 1): + plans.append((i, s_val[i - start])) + p_cost_desc = self.partition_info[i][s_val[i - start]] + inner_time_cost += p_cost_desc.comp_time + p_cost_desc.weight_update_time + mem_cost += p_cost_desc.mem + return SPMDSearchOutput(self.partition_path2desc(plans), + mem_cost / 1024 / 1024 / 1024, all_time_cost, + inner_time_cost) + + def do_ilp(self, intervals: List[Tuple[int, int]], + topk: int) -> List[SPMDSearchOutput]: + if topk != 1: + raise RuntimeError('topk != 1 is not supported') + ret = [] + for start, end in intervals: + solver_out = self._solve_by_ilp(start, end) + if solver_out is not None: + ret.append([solver_out]) + else: + ret.append([]) + _logger.debug(f'finish solving interval {start} {end}') + return ret + + def do_dp(self, intervals: List[Tuple[int, int]], + topk: int) -> List[SPMDSearchOutput]: + + idx_map = {} + for idx, item in enumerate(intervals): + idx_map[item] = idx + + ret = [None] * len(intervals) + + in_fname = f'./cpp_in_{self.device_num}_{self.stage_num}.bin' + out_fname = f'./cpp_out_{self.device_num}_{self.stage_num}.bin' + """ + cpp_in format + is_train node_num mem_bound mem_div topk interval_num + interval_num tuples of (start_idx, end_idx) + node_num groups of + - id + - father_id + - cut_lens, [cut_ids] + - partition nums, [p_father, comp_time + weight_update_time, train_mem, buffer_mem, act_mem, opt_mem] + - producer_num, [producer_id, comm time mat] + """ + # TODO: hardcode value which should change according to device memory + mem_div = 64 + with open(in_fname, 'wb') as f: + # header + if self.is_train: + f.write(int2byte(0)) + else: + f.write(int2byte(1)) + f.write( + int2byte(self.graph.op_num) + + int2byte(int(self.mem_bound) // mem_div) + int2byte(mem_div)) + f.write(int2byte(topk) + int2byte(len(intervals))) + # intervals + for u, v in intervals: + f.write(int2byte(u) + int2byte(v)) + # partition info + for idx in range(self.graph.op_num): + f.write(int2byte(idx)) + f.write(int2byte(self.father_ids[idx])) + + f.write(int2byte(len(self.cut_ops[idx]))) + for cut_op in self.cut_ops[idx]: + f.write(int2byte(cut_op)) + + f.write(int2byte(self.get_op_partition_count(idx))) + for i, partition in enumerate(self._op_partitions[idx]): + p_cost_desc = self.partition_info[idx][i] + f.write( + double2byte(p_cost_desc.comp_time + + p_cost_desc.weight_update_time)) + + f.write( + int2byte(p_cost_desc.mem // mem_div) + + int2byte(p_cost_desc.transient_mem // mem_div) + + int2byte(p_cost_desc.activation_mem // mem_div) + + int2byte(p_cost_desc.opt_transient_mem // mem_div) + + int2byte(self.p_fathers[idx][i])) + + f.write(int2byte(len(self.producers[idx]))) + for p_i, producer in enumerate(self.producers[idx]): + f.write(int2byte(producer)) + for i, tgt_p in enumerate(self._op_partitions[idx]): + comm_time = self.partition_info[idx][i].comm_time + for j, src_p in enumerate( + self._op_partitions[producer]): + f.write(double2byte(comm_time[p_i][j])) + + os.system( + f'{str(Path.home())}/.autodist/solver {in_fname} {out_fname} {int(self.autodist_config.verbose)}' + ) + """ + cpp_out file format + interval_num parts, each part's format is + start_idx end_idx path_num path_len + each path is (opt_time, inner_time, opt_mem, a sequence of path_len (op_idx, partition_idx)) + """ + with open(out_fname, 'rb') as f: + data = f.read() + offset = 0 + for _ in range(len(intervals)): + descs = [] + start_level = int4byte(data[offset:offset + 4]) + end_level = int4byte(data[offset + 4:offset + 8]) + num = int4byte(data[offset + 8:offset + 12]) + path_len = int4byte(data[offset + 12:offset + 16]) + offset += 16 + for i in range(num): + opt_time = double4byte(data[offset:offset + 8]) + offset += 8 + inner_time = double4byte(data[offset:offset + 8]) + offset += 8 + opt_mem = int4byte(data[offset:offset + 4]) + offset += 4 + plans = [] + for j in range(path_len): + cur_op_idx = int4byte(data[offset:offset + 4]) + offset += 4 + cur_p_idx = int4byte(data[offset:offset + 4]) + offset += 4 + plans.append((cur_op_idx, cur_p_idx)) + desc = self.partition_path2desc(plans) + descs.append( + SPMDSearchOutput(desc, + opt_mem * mem_div / 1024 / 1024 / 1024, + opt_time, inner_time)) + ret[idx_map[(start_level, end_level)]] = descs + os.system(f'rm {in_fname} {out_fname}') + return ret + + def solve(self, intervals: List[Tuple[int, int]], + topk: int) -> List[SPMDSearchOutput]: + if self.autodist_config.solver == 'ilp': + return self.do_ilp(intervals, topk) + elif self.autodist_config.solver == 'dp': + return self.do_dp(intervals, topk) + else: + raise RuntimeError( + f'unsupported solver {self.autodist_config.solver}') + + def partition_path2desc( + self, plans: List[Tuple[int, int]]) -> Dict[int, NodePartitionDesc]: + partitions = [self._op_partitions[u][v] for u, v in plans] + + partition_descs = {} + for p in partitions: + op = p.operator + p_info = tuple([ + (op.dim_id2pos(dim), num) + for dim, num in zip(p.partition_dims, p.partition_nums) + ]) + partition_descs[op.ir_cell.cid] = NodePartitionDesc(desc=p_info) + + return TensorParallelDesc(partition_descs=partition_descs, + mesh_desc=self.mesh_desc, + recompute_groups=[]) + + +def calc_optimal_spmd_plan( + model_graph: ModelGraph, + autodist_config: AutoDistConfig) -> PipelineSearchOutput: + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=autodist_config.mesh_desc, + autodist_config=autodist_config, + stage_num=1, + micro_batch_num=autodist_config.update_freq, + ) + + spmd_outs = spmd_solver.solve([(0, model_graph.op_num - 1)], 1)[0] + if not spmd_outs: + raise RuntimeError( + 'fail to find a valid partition plan, ' \ + 'try to increase device number or reduce batch size' + ) + spmd_out = spmd_outs[0] + pp_desc = PipelineParallelDesc( + spmd_descs=[spmd_out.desc], + recompute_groups=spmd_out.desc.recompute_groups, + mesh_desc=spmd_out.desc.mesh_desc, + ) + pp_out = PipelineSearchOutput( + desc=pp_desc, + e2e_time=spmd_out.all_time, + stage_mems=[spmd_out.memory], + stage_all_times=[spmd_out.all_time], + stage_comp_times=[spmd_out.comp_time], + ) + return pp_out diff --git a/nnscaler/autodist/util.py b/nnscaler/autodist/util.py new file mode 100644 index 00000000..53ee350f --- /dev/null +++ b/nnscaler/autodist/util.py @@ -0,0 +1,83 @@ +from .descs import NodePartitionDesc +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation + +import struct +from typing import List +from collections import deque + + +def int2byte(val): + return struct.pack('i', val) + + +def int4byte(val): + return struct.unpack('i', val)[0] + + +def double2byte(val): + return struct.pack('d', val) + + +def double4byte(val): + return struct.unpack('d', val)[0] + + +def get_node_arch(): + import torch + return torch.cuda.get_device_name(torch.cuda.current_device()).replace( + ' ', '_') + + +# tensor parallelism +def tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, + dim: int): + algo = node.algorithms('dim') + sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) + assert sub_nodes is not None + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def partition_node(node: IRFwOperation, graph: IRGraph, devs: [int], + desc: NodePartitionDesc) -> None: + min_dev_index = min(devs) + tp_size = len(devs) + info = desc.desc + + dq = deque() + dq.append((node, (0, tp_size))) + for (idx, dim), num in info: + + cur_nodes = [] + while dq: + u, (low, high) = dq.popleft() + assert (high - low) % num == 0 + inc = (high - low) // num + sub_intervals = list( + map(lambda x: (low + x * inc, low + (x + 1) * inc), + list(range(num)))) + if idx == -1 and dim == -1: + sub_nodes = graph.replicate(u, times=num) + else: + assert idx >= 0 and dim >= 0 + algo = u.algorithms('dim') + sub_nodes = graph.partition(u, algo, idx=idx, dim=dim, num=num) + for i in range(num): + cur_nodes.append((sub_nodes[i], sub_intervals[i])) + + for cur_node in cur_nodes: + dq.append(cur_node) + + while dq: + u, (low, high) = dq.popleft() + assert high - low == 1 + graph.assign(u, low + min_dev_index) diff --git a/requirements-dev.txt b/requirements-dev.txt index 275ac755..826839f5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,10 @@ -pytest -mock -pytest-mock -tox coverage +mock +pre-commit +pytest pytest-cov +pytest-mock tabulate +tox tox-conda +yapf diff --git a/requirements.txt b/requirements.txt index 880014c0..1912dc45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ -numpy>=1.23.0 +dill matplotlib more-itertools -dill +numpy>=1.23.0 +pulp +pyyaml torch>=2.0 diff --git a/tests/autodist/__init__.py b/tests/autodist/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/autodist/graph/__init__.py b/tests/autodist/graph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/autodist/graph/test_calc_flops.py b/tests/autodist/graph/test_calc_flops.py new file mode 100644 index 00000000..dac7e72c --- /dev/null +++ b/tests/autodist/graph/test_calc_flops.py @@ -0,0 +1,61 @@ +import pytest + +import tempfile +import torch +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.autodist.model_graph import calc_flops + + +class Model(torch.nn.Module): + + def __init__(self, hidden_dim): + super().__init__() + self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + + def forward(self, t_1d, t_2d, t_3d): + x = self.fc1(t_3d) + y = self.fc2(t_3d) + z1 = torch.bmm(x, y) + z2 = torch.matmul(t_1d, t_1d) + z3 = torch.matmul(t_2d, t_2d) + z4 = torch.matmul(t_1d, t_2d) + z5 = torch.matmul(t_2d, t_1d) + z6 = torch.matmul(t_2d, t_3d) + z7 = torch.matmul(t_3d, t_2d) + return x.sum() + y.sum() + z1.sum() + z2.sum() + z3.sum() + z4.sum( + ) + z5.sum() + z6.sum() + z7.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_calc_flops(): + batch_size, hidden_dim = 2, 1024 + dummy_input = { + 't_1d': torch.randn(hidden_dim), + 't_2d': torch.randn(hidden_dim, hidden_dim), + 't_3d': torch.randn(batch_size, hidden_dim, hidden_dim) + } + model = Model(hidden_dim) + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + dynamic_shape=True) + nodes = ir_graph.select(ntype=IRFwOperation) + assert calc_flops( + nodes[0]) == 2 * batch_size * hidden_dim * hidden_dim * hidden_dim + assert calc_flops( + nodes[1]) == 2 * batch_size * hidden_dim * hidden_dim * hidden_dim + assert calc_flops( + nodes[2]) == 2 * batch_size * hidden_dim * hidden_dim * hidden_dim + assert calc_flops(nodes[3]) == 2 * hidden_dim + assert calc_flops(nodes[4]) == 2 * hidden_dim * hidden_dim * hidden_dim + assert calc_flops(nodes[5]) == 2 * hidden_dim * hidden_dim + assert calc_flops(nodes[6]) == 2 * hidden_dim * hidden_dim + assert calc_flops( + nodes[7]) == 2 * batch_size * hidden_dim * hidden_dim * hidden_dim + assert calc_flops( + nodes[8]) == 2 * batch_size * hidden_dim * hidden_dim * hidden_dim diff --git a/tests/autodist/graph/test_recompute.py b/tests/autodist/graph/test_recompute.py new file mode 100644 index 00000000..e34254c1 --- /dev/null +++ b/tests/autodist/graph/test_recompute.py @@ -0,0 +1,125 @@ +import pytest + +import tempfile +import torch +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.model_graph import ModelGraph +from nnscaler.autodist.autodist_config import AutoDistConfig + + +class MLP(torch.nn.Module): + + def __init__(self, hidden_dim, ffn_dim): + super().__init__() + self.fc1 = torch.nn.Linear(hidden_dim, ffn_dim, bias=False) + self.fc2 = torch.nn.Linear(ffn_dim, hidden_dim, bias=False) + self.gelu = torch.nn.GELU() + + def forward(self, x): + return self.fc2(self.gelu(self.fc1(x))) + + +class Layer(torch.nn.Module): + + def __init__(self, hidden_dim, ffn_dim): + super().__init__() + self.mlp = MLP(hidden_dim, ffn_dim) + self.ln = torch.nn.LayerNorm(hidden_dim) + + def forward(self, x): + residual = x + x = self.ln(x) + x = self.mlp(x) + x = x + residual + return x + + +class Decoder(torch.nn.Module): + + def __init__(self, hidden_dim, ffn_dim, num_layers): + super().__init__() + self.layers = torch.nn.ModuleList( + [Layer(hidden_dim, ffn_dim) for _ in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = x.sum() + return x + + +class Model(torch.nn.Module): + + def __init__(self, hidden_dim, ffn_dim, num_layers): + super().__init__() + self.decoder = Decoder(hidden_dim, ffn_dim, num_layers) + + def forward(self, x): + return self.decoder.forward(x) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_recompute(): + batch_size = 2 + hidden_dim, ffn_dim, num_layers = 1024, 4096, 1 + + dummy_input = {'x': torch.randn(batch_size, hidden_dim)} + model = Model(hidden_dim, ffn_dim, num_layers) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + dynamic_shape=True) + + config = AutoDistConfig(recompute_modules='Layer') + model_graph = ModelGraph(ir_graph, config) + model_node = model_graph.scope_tree_root + print(model_node) + + assert len(model_node.children) == 1 + decoder_node = model_node.children[0] + + assert len(decoder_node.children) == num_layers + 1 + for layer_node in decoder_node.children[:-1]: + assert len(layer_node.children) == 3 + ln_node = layer_node.children[0] + assert ln_node.leaf_size == 1 + assert ln_node.in_mem == batch_size * hidden_dim * 4 + assert ln_node.train_mem == batch_size * hidden_dim * 4 + batch_size * 8 + assert ln_node.param_mem == hidden_dim * 8 + assert ln_node.buffer_mem == 0 + mlp_node = layer_node.children[1] + assert mlp_node.leaf_size == 3 + assert mlp_node.in_mem == batch_size * hidden_dim * 4 + assert mlp_node.train_mem == batch_size * hidden_dim * 4 + batch_size * ffn_dim * 8 + assert mlp_node.param_mem == hidden_dim * ffn_dim * 8 + assert mlp_node.buffer_mem == 0 + add_node = layer_node.children[2] + assert add_node.leaf_size == 1 + assert add_node.in_mem == batch_size * hidden_dim * 8 + assert add_node.train_mem == 0 + assert add_node.param_mem == 0 + assert add_node.buffer_mem == 0 + + assert layer_node.leaf_size == ln_node.leaf_size + mlp_node.leaf_size + add_node.leaf_size + assert layer_node.in_mem == batch_size * hidden_dim * 4 + assert layer_node.train_mem == ln_node.train_mem + mlp_node.train_mem + add_node.train_mem + assert layer_node.param_mem == ln_node.param_mem + mlp_node.param_mem + add_node.param_mem + assert layer_node.buffer_mem == 0 + + assert decoder_node.leaf_size == num_layers * layer_node.leaf_size + 1 + assert decoder_node.in_mem == batch_size * hidden_dim * 4 + assert decoder_node.train_mem == num_layers * layer_node.train_mem + assert decoder_node.param_mem == num_layers * layer_node.param_mem + assert decoder_node.buffer_mem == 0 + + assert model_node.leaf_size == decoder_node.leaf_size + assert model_node.in_mem == decoder_node.in_mem + assert model_node.train_mem == decoder_node.train_mem + assert model_node.param_mem == decoder_node.param_mem + assert model_node.buffer_mem == decoder_node.buffer_mem + + assert model_graph.recompute_mem == layer_node.train_mem + num_layers * layer_node.in_mem diff --git a/tests/autodist/partition/__init__.py b/tests/autodist/partition/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/autodist/partition/test_state.py b/tests/autodist/partition/test_state.py new file mode 100644 index 00000000..2c6a3ce7 --- /dev/null +++ b/tests/autodist/partition/test_state.py @@ -0,0 +1,46 @@ +from nnscaler.autodist.op_partition import calc_factors, generate_partitions + + +def test_calc_factors(): + assert calc_factors(1, 1) == [(1,)] + assert calc_factors(2, 1) == [(2,)] + assert calc_factors(2, 2) == [] + assert calc_factors(4, 2) == [(2, 2)] + assert calc_factors(6, 2) == [(2, 3)] + assert calc_factors(8, 2) == [(2, 4)] + assert calc_factors(8, 3) == [(2, 2, 2)] + assert calc_factors(16, 3) == [(2, 2, 4)] + + +def test_generate_partitions(): + # [['a'], [2]], [['b'], [2]] + assert len(generate_partitions(['a', 'b'], 2)) == 2 + # [['a'], [4]], [['b'], [4]], [['a', 'b'], [2, 2]], [['b', 'a'], [2, 2]] + assert len(generate_partitions(['a', 'b'], 4)) == 4 + # [['a'], [4]], [['b'], [4]], [['c'], [4]] + # [['a', 'b'], [2, 2]], [['a', 'c'], [2, 2]] + # [['b', 'a'], [2, 2]], [['b', 'c'], [2, 2]] + # [['c', 'a'], [2, 2]], [['c', 'b'], [2, 2] + assert len(generate_partitions(['a', 'b', 'c'], 4)) == 9 + # [['a'], [8]], [['b'], [8]] + # [['a', 'b'], [2, 4]], [['b', 'a'], [2, 4]] + # [['a', 'b'], [4, 2]], [['b', 'a'], [4, 2]] + assert len(generate_partitions(['a', 'b'], 8)) == 6 + # [['a'], [8]], [['b'], [8]], [['c'], [8]] + # [['a', 'b'], [2, 4]], [['a', 'c'], [2, 4]] + # [['b', 'a'], [2, 4]], [['b', 'c'], [2, 4]] + # [['c', 'a'], [2, 4]], [['c', 'b'], [2, 4]] + # [['a', 'b'], [4, 2]], [['a', 'c'], [4, 2]] + # [['b', 'a'], [4, 2]], [['b', 'c'], [4, 2]] + # [['c', 'a'], [4, 2]], [['c', 'b'], [4, 2]] + # [['a', 'b', 'c'], [2, 2, 2]], [['a', 'c', 'b'], [2, 2, 2]] + # [['b', 'a', 'c'], [2, 2, 2]], [['b', 'c', 'a'], [2, 2, 2]] + # [['c', 'a', 'b'], [2, 2, 2]], [['c', 'b', 'a'], [2, 2, 2]] + assert len(generate_partitions(['a', 'b', 'c'], 8)) == 21 + # [['a'], [8]], [['b'], [8]], [[-1], [8]] + # [[-1, 'a'], [2, 4]], [[-1, 'b'], [2, 4]] + # [['a', 'b'], [2, 4]], [['b', 'a'], [2, 4]] + # [[-1, 'a'], [4, 2]], [[-1, 'b'], [4, 2]] + # [['a', 'b'], [4, 2]], [['b', 'a'], [4, 2]] + # [[-1, 'a', 'b'], [2, 2, 2]], [[-1, 'b', 'a'], [2, 2, 2]] + assert len(generate_partitions(['a', 'b', -1], 8)) == 13 diff --git a/tests/autodist/pas/__init__.py b/tests/autodist/pas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/autodist/pas/all_replicated_pp.json b/tests/autodist/pas/all_replicated_pp.json new file mode 100644 index 00000000..ab6a532b --- /dev/null +++ b/tests/autodist/pas/all_replicated_pp.json @@ -0,0 +1,87 @@ +{ + "desc": { + "spmd_descs": [ + { + "partition_descs": [ + [ + 3, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 4, + [ + [ + [ + 0, + 0 + ], + 2 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ] + }, + { + "partition_descs": [ + [ + 5, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 6, + [ + [ + [ + 0, + 0 + ], + 2 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ] + } + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 4 + ] + }, + "e2e_time": 0.0, + "stage_mems": [ + 0.0 + ], + "stage_all_times": [ + 0.0 + ], + "stage_comp_times": [ + 0.0 + ] +} diff --git a/tests/autodist/pas/replicated_and_partition.json b/tests/autodist/pas/replicated_and_partition.json new file mode 100644 index 00000000..edcf8949 --- /dev/null +++ b/tests/autodist/pas/replicated_and_partition.json @@ -0,0 +1,87 @@ +{ + "desc": { + "spmd_descs": [ + { + "partition_descs": [ + [ + 3, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 4, + [ + [ + [ + 0, + 0 + ], + 2 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ] + }, + { + "partition_descs": [ + [ + 5, + [ + [ + [ + 0, + 0 + ], + 2 + ] + ] + ], + [ + 6, + [ + [ + [ + 0, + 0 + ], + 2 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ] + } + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 4 + ] + }, + "e2e_time": 0.0, + "stage_mems": [ + 0.0 + ], + "stage_all_times": [ + 0.0 + ], + "stage_comp_times": [ + 0.0 + ] +} diff --git a/tests/autodist/pas/test_shared_param_pipeline.py b/tests/autodist/pas/test_shared_param_pipeline.py new file mode 100644 index 00000000..31957240 --- /dev/null +++ b/tests/autodist/pas/test_shared_param_pipeline.py @@ -0,0 +1,79 @@ +import pytest + +import tempfile +import torch +import os +from pathlib import Path +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.autodist.apis import parallelize_graph + +import nnscaler +from nnscaler.ir.unique import IDGenerator +from nnscaler.graph.segment import IRSegment +from nnscaler.flags import CompileFlag +from nnscaler.runtime.utils import microbatches +from nnscaler.program import Program, SemanticDataLoader +from nnscaler.graph.gener.gen import IRAdapterGener + + +class Model(torch.nn.Module): + + def __init__(self, hidden_dim): + super().__init__() + self.w = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim)) + + def forward(self, x): + x = torch.matmul(x, self.w) + x = torch.nn.functional.relu(x) + x = torch.matmul(x, self.w) + return x.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_shared_param_pipeline(): + bsz, hidden_dim = 1024, 1024 + + CompileFlag.dev_mode = True + + for idx, cfg_fname in enumerate( + ['all_replicated_pp.json', 'replicated_and_partition.json']): + with tempfile.TemporaryDirectory() as tempdir: + model = Model(hidden_dim) + model.train() + + IDGenerator().clear() + if idx > 0: + Program().clear() + smodel = nnscaler.SemanticModel(model, attr_savedir=tempdir) + smodel.dummy_input = {'x': torch.randn(bsz, hidden_dim)} + smodel.dynamic_shape = False + + dataloader = SemanticDataLoader( + microbatches([{ + 'x': torch.randn(bsz, hidden_dim) + }])) + Program().set_input([dataloader.irobj]) + ir_dummy_input = next(dataloader) + outputs = smodel(ir_dummy_input) + outputs.backward() + Program().set_output([outputs]) + Program().finalize() + ir_graph = Program().get_graph() + + print(ir_graph.nodes()) + plan_path = Path(os.path.dirname(__file__)) / cfg_fname + cfg = AutoDistConfig(load_plan_path=plan_path, mesh_col=4) + graph = parallelize_graph(ir_graph, cfg) + assert isinstance(graph.nodes()[4], IRSegment) + # check multiref is correctly inserted at the 1st IRSegment (pipeline stage) + has_multiref = False + for node in graph.nodes()[4].nodes(): + if node.signature == 'nnscaler.runtime.function.multiref': + has_multiref = True + break + assert has_multiref + + graph = IRAdapterGener.gen(graph, cost_fn=None) + if graph.sched is not None: + graph.sched.apply() diff --git a/tests/autodist/spmd_solver/__init__.py b/tests/autodist/spmd_solver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/autodist/spmd_solver/test_attention_follow.yaml b/tests/autodist/spmd_solver/test_attention_follow.yaml new file mode 100644 index 00000000..00f4f9a8 --- /dev/null +++ b/tests/autodist/spmd_solver/test_attention_follow.yaml @@ -0,0 +1,18 @@ +- allowed_partition_dims: + - 1,0 + - 1,1 + name: torch.nn.functional.linear + parent_module: Attention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + - 0,1 + name: torch.matmul + parent_module: Attention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + - 0,1 + name: torch.nn.functional.softmax + parent_module: Attention + replica_allowed: false diff --git a/tests/autodist/spmd_solver/test_cube_operator.py b/tests/autodist/spmd_solver/test_cube_operator.py new file mode 100644 index 00000000..c632b6b7 --- /dev/null +++ b/tests/autodist/spmd_solver/test_cube_operator.py @@ -0,0 +1,50 @@ +import pytest + +import tempfile +import torch +import os +from pathlib import Path +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.model_graph import ModelGraph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.autodist.spmd_solver import SPMDSolver + +import nnscaler + + +@nnscaler.graph.parser.register( + '(1 h) l^ d^, (1 h) l^ d^, (1 h) l^ d^ -> (1 h) l^ d^', 'mock_attention') +def mock_attention(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): + return x + y + z + + +class Model(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + t = mock_attention(x, y, z) + return t.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_cube_operator(): + bsz, head_num, seq_len, head_dim = 1, 8, 128, 64 + data = torch.randn((bsz * head_num, seq_len, head_dim)) + + dummy_input = {'x': data, 'y': data, 'z': data} + model = Model() + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + dynamic_shape=False) + cfg = AutoDistConfig(mesh_col=2) + model_graph = ModelGraph(ir_graph, cfg) + mock_attention_op = model_graph.operator_list[0] + assert mock_attention_op.pos2dim_id((0, 0)) == 'h' + assert mock_attention_op.dim_id2pos('h') == (0, 0) diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py new file mode 100644 index 00000000..c398ba17 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow.py @@ -0,0 +1,283 @@ +import pytest + +import tempfile +import torch +import math +import os +from pathlib import Path +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.model_graph import ModelGraph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.autodist.spmd_solver import SPMDSolver + + +def rotate_half(x): + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Model(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, q, k, cos, sin, position_ids): + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + out = q + k + return out.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_follow_rope(): + bsz, seq_len, hidden_dim = 2, 128, 512 + dummy_input = { + 'q': torch.rand(bsz, 1, seq_len, hidden_dim), + 'k': torch.rand(bsz, 1, seq_len, hidden_dim), + 'cos': torch.rand(seq_len, hidden_dim), + 'sin': torch.rand(seq_len, hidden_dim), + 'position_ids': torch.arange(seq_len, dtype=torch.long), + } + model = Model() + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + dynamic_shape=False) + ''' + the computation graph is as follows: + getitem getitem + | | + unsqueeze unsqueeze + | \ | + | -------------------------------------------------mul + | | | + mul fullsclie fullslice | fullsclie fullslice | + | \ | | \ | | + | \ neg | \ neg | + | \ | | \ | | + | concat | concat | + | | | | | + add-----------mul-------------------------mul----------add + | | + | | + ---------------------------add-------------------------- + | + sum + currently, the following chain is only composed of unary ops + there are 2 chains in total: + 1. fullslice -> neg + 2. fullslice -> neg + in future, we may add follow chains for binary ops, like mul, add, etc. + ''' + + cfg = AutoDistConfig(mesh_col=2) + model_graph = ModelGraph(ir_graph, cfg) + + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=cfg.mesh_desc, + autodist_config=cfg, + stage_num=1, + micro_batch_num=cfg.update_freq, + ) + + assert spmd_solver.follow_ids == [ + 0, 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 11, 12, 13, 13, 15, 16, 17, 18, 19 + ] + partition_counts = [ + spmd_solver.get_op_partition_count(i) + for i in range(model_graph.op_num) + ] + assert partition_counts[6] == partition_counts[7] + assert partition_counts[13] == partition_counts[14] + + +class Attention(torch.nn.Module): + + def __init__(self, hidden_dim, num_heads): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.o_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + + def forward(self, x): + bsz, seq_len, _ = x.shape + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt( + self.head_dim) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_weights = torch.nn.functional.dropout(attn_weights, + p=0.0, + training=self.training) + attn_out = torch.matmul(attn_weights, v) + + attn_out = attn_out.transpose(1, 2).contiguous().reshape( + bsz, seq_len, self.hidden_dim) + attn_out = self.o_proj(attn_out) + return attn_out + + +class AttentionModel(torch.nn.Module): + + def __init__(self, hidden_dim, num_heads): + super().__init__() + self.attention = Attention(hidden_dim, num_heads) + + def forward(self, x): + return self.attention(x).sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_follow_attention(): + from nnscaler.ir.unique import IDGenerator + IDGenerator().clear() + bsz, seq_len, hidden_dim, num_heads = 2, 128, 512, 8 + dummy_input = { + 'x': torch.rand(bsz, seq_len, hidden_dim), + } + model = AttentionModel(hidden_dim, num_heads) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + dynamic_shape=False) + print(ir_graph.nodes()) + ''' + the computation graph is as follows: + linear linear linear + | | | + view view view + | | | + transpose transpose transpose + \ | | + | transpose | + \ / | + matmul | + | | + div | + | | + softmax | + | | + dropout | + \ / + \ / + matmul + | + transpose + | + contiguous + | + reshape + | + linear + | + sum + + the follow chain is as follows: + 1. view -> transpose + 2. view -> transpose -> transpose + 3. view -> transpose + 4. div -> softmax -> dropout + 5. transpose -> contiguous -> reshape + ''' + + pc_path = Path(os.path.dirname(__file__)) / 'test_attention_follow.yaml' + cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2) + model_graph = ModelGraph(ir_graph, cfg) + + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=cfg.mesh_desc, + autodist_config=cfg, + stage_num=1, + micro_batch_num=cfg.update_freq, + ) + + assert spmd_solver.follow_ids == [ + 0, 1, 2, 3, 3, 5, 5, 7, 7, 6, 10, 11, 11, 12, 14, 15, 15, 16, 18, 19 + ] + partition_counts = [ + spmd_solver.get_op_partition_count(i) + for i in range(model_graph.op_num) + ] + assert partition_counts == [ + 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 4, 4, 4, 2, 4 + ] + # under the current partition constraints, the solver should generate + # a Megatron-LM plan + expected_out = [ + # partition out feature for q_proj + (2, (((1, 0), 2),)), + # partition out feature for k_proj + (3, (((1, 0), 2),)), + # partition out feature for v_proj + (4, (((1, 0), 2),)), + # partition hidden dim for q's view + (5, (((0, 2), 2),)), + # partition the head dim for q's transpose + (6, (((0, 2), 2),)), + # partition the hidden dim for k's view + (7, (((0, 2), 2),)), + # partition the head dim for k's transpose + (8, (((0, 2), 2),)), + # partition the hidden dim for v's view + (9, (((0, 2), 2),)), + # partition the head dim for v's transpose + (10, (((0, 2), 2),)), + # partition the head dim for k's 2nd transpose + (11, (((0, 1), 2),)), + # partition the head dim for matmul(q, k) + (12, (((0, 1), 2),)), + # partition the head dim div + (13, (((0, 1), 2),)), + # partition the head dim for softmax + (14, (((0, 1), 2),)), + # partition the head dim for dropout + (15, (((0, 1), 2),)), + # partition the head dim for matmul(attn_weights, v) + (16, (((0, 1), 2),)), + # partition the head dim for attn_out.transpose + (17, (((0, 1), 2),)), + # partition the head dim for contiguous + (18, (((0, 2), 2),)), + # partition the head dim for reshape + (19, (((0, 2), 2),)), + # partition the input feature for o_proj + (20, (((0, 2), 2),)), + # replicate the sum + (21, (((-1, -1), 2),)) + ] + + def helper(search_out): + return search_out[0][0].to_json()['desc']['partition_descs'] + + dp_spmd_outs = spmd_solver.do_dp([(0, model_graph.op_num - 1)], 1) + ilp_spmd_outs = spmd_solver.do_ilp([(0, model_graph.op_num - 1)], 1) + assert helper(dp_spmd_outs) == expected_out + assert helper(ilp_spmd_outs) == expected_out diff --git a/tests/autodist/spmd_solver/test_partition_constraint.py b/tests/autodist/spmd_solver/test_partition_constraint.py new file mode 100644 index 00000000..e8e65d16 --- /dev/null +++ b/tests/autodist/spmd_solver/test_partition_constraint.py @@ -0,0 +1,107 @@ +import pytest + +import tempfile +import torch +import os +from pathlib import Path +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.model_graph import ModelGraph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.autodist.spmd_solver import SPMDSolver + + +class Attention(torch.nn.Module): + + def __init__(self, hidden_dim): + super().__init__() + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + + def forward(self, x): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + score = torch.matmul(q, k.transpose(-2, -1)) + score = torch.nn.functional.softmax(score, dim=-1) + out = torch.matmul(score, v) + out = self.out_proj(out) + return out + + +class FFN(torch.nn.Module): + + def __init__(self, hidden_dim): + super().__init__() + self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = torch.nn.functional.relu(x) + x = self.fc2(x) + return x + + +class Decoder(torch.nn.Module): + + def __init__(self, hidden_dim): + super().__init__() + self.attn = Attention(hidden_dim) + self.ffn = FFN(hidden_dim) + + def forward(self, x): + x = self.attn(x) + x = self.ffn(x) + x = x.sum() + return x + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_partition_constraint(): + bsz, seq_len, hidden_dim = 2, 128, 768 + + dummy_input = {'x': torch.randn(bsz, seq_len, hidden_dim)} + model = Decoder(hidden_dim) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + dynamic_shape=False) + + pc_path = Path(os.path.dirname( + os.path.realpath(__file__))) / 'test_pc.yaml' + cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2) + model_graph = ModelGraph(ir_graph, cfg) + + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=cfg.mesh_desc, + autodist_config=cfg, + stage_num=1, + micro_batch_num=cfg.update_freq, + ) + + partition_counts = [ + spmd_solver.get_op_partition_count(i) + for i in range(model_graph.op_num) + ] + ''' + q_proj: 1 + k_proj: 1 + v_proj: 1 + transpose: 4 + matmul: 1 + softmax: 3 + matmul: 1 + out_proj: 1 + fc1: 2 + relu: 4 + fc2: 2 + sum: 4 + ''' + assert partition_counts == [1, 1, 1, 4, 1, 3, 1, 1, 2, 4, 2, 4] diff --git a/tests/autodist/spmd_solver/test_pc.yaml b/tests/autodist/spmd_solver/test_pc.yaml new file mode 100644 index 00000000..bff44a94 --- /dev/null +++ b/tests/autodist/spmd_solver/test_pc.yaml @@ -0,0 +1,16 @@ +- allowed_partition_dims: + - 0,0 + name: torch.nn.functional.linear + parent_module: Attention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + name: torch.matmul + parent_module: Attention + replica_allowed: false +- allowed_partition_dims: + - 1,0 + - 1,1 + name: torch.nn.functional.linear + parent_module: FFN + replica_allowed: false diff --git a/tests/autodist/spmd_solver/test_shared_param.py b/tests/autodist/spmd_solver/test_shared_param.py new file mode 100644 index 00000000..f1312644 --- /dev/null +++ b/tests/autodist/spmd_solver/test_shared_param.py @@ -0,0 +1,66 @@ +import pytest + +import tempfile +import torch +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.model_graph import ModelGraph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.autodist.spmd_solver import SPMDSolver + + +class Model(torch.nn.Module): + + def __init__(self, dict_size, hidden_dim): + super().__init__() + self.embedding = torch.nn.Embedding(dict_size, hidden_dim) + self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.fc2.weight = self.fc1.weight + self.fc3 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.fc = torch.nn.Linear(hidden_dim, dict_size, bias=False) + self.fc.weight = self.embedding.weight + + def forward(self, x): + x = self.embedding(x) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + x = self.fc(x) + return x.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_shared_param(): + bsz, seq_len, hidden_dim, dict_size = 2, 128, 768, 1024 + + dummy_input = {'x': torch.randint(0, dict_size, (bsz, seq_len))} + model = Model(dict_size, hidden_dim) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + dynamic_shape=False) + + cfg = AutoDistConfig(mesh_col=4) + model_graph = ModelGraph(ir_graph, cfg) + + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=cfg.mesh_desc, + autodist_config=cfg, + stage_num=1, + micro_batch_num=cfg.update_freq, + ) + + partition_counts = [ + spmd_solver.get_op_partition_count(i) + for i in range(model_graph.op_num) + ] + # batch size cannot be partitioned on 4 devices + # each operator can be replicated and partitioned on the sequence length dim + # for fc3, the out_feature and in_feature dims can be partitioned + # for sum, the hidden dim can be partitioned + assert partition_counts == [2, 2, 2, 4, 2, 3] From 93b5b2387340bb81eea221c39dc8f35346fa6f55 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Fri, 19 Apr 2024 08:23:41 +0000 Subject: [PATCH 1624/1892] Merged PR 2122: Add packaging related files According to our discussion: 1. The C++ script (solver) will be ported to pybind11 and be built with cppimport at runtime; 2. The profiling script will be moved into nnscaler package, and will be invoked when using autodist for the first time. Since both of them need code changes, they are not seriously considered in this PR. How to use: For distribution: ``` python -m build ``` It will generate 2 files. `dist/nnscaler-0.3-py3-none-any.whl` is the package to be installed. `nnscaler-0.3.tar.gz` should be ignored since there is no benefit to use sdist for our planned solution. For development: ``` python -m pip install -e . ``` As usual. --- nnscaler/__init__.py | 2 ++ nnscaler/version.py | 1 + pyproject.toml | 30 ++++++++++++++++++++++++++++++ setup.py | 17 ----------------- 4 files changed, 33 insertions(+), 17 deletions(-) create mode 100644 nnscaler/version.py create mode 100644 pyproject.toml delete mode 100644 setup.py diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index 8b0baa8c..133d1a29 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -13,6 +13,8 @@ from nnscaler.flags import CompileFlag +from .version import __version__ + def _check_torch_version(): import torch diff --git a/nnscaler/version.py b/nnscaler/version.py new file mode 100644 index 00000000..cce384d3 --- /dev/null +++ b/nnscaler/version.py @@ -0,0 +1 @@ +__version__ = '0.3' diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..009d83ac --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,30 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +dynamic = ["version", "dependencies"] + +name = "nnscaler" +description = "Parallelize DNN Traning from A Systematic Way" +readme = "README.md" +requires-python = ">=3.8" +# TODO: license +authors = [ + {name = "nnScaler Team", email = "nnscaler@microsoft.com"} # FIXME: email +] +# TODO: keywords +# TODO: classifiers + +[project.urls] +Homepage = "https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube" + +[tool.setuptools] +dynamic.version.attr = "nnscaler.version.__version__" +dynamic.dependencies.file = "requirements.txt" + +# NOTE: +# the following part only affects wheel, not sdist +# since our current plan is to use cppimport, sdist is not needed +packages.find.include = ["nnscaler*"] +package-data.nnscaler = ["autodist/csrc/*"] diff --git a/setup.py b/setup.py deleted file mode 100644 index 269504e1..00000000 --- a/setup.py +++ /dev/null @@ -1,17 +0,0 @@ -import setuptools - -with open("requirements.txt") as f: - install_requires = [ - line.split("#")[0].strip() for line in f if not line.startswith("#") and line.split("#")[0].strip() != "" - ] - -setuptools.setup( - name= 'nnscaler', - version= '0.2', - author= 'nnScaler Team', - description= 'Parallelize DNN Traning from A Systematic Way', - long_description= 'Parallelize DNN Traning from A Systematic Way', - packages= ['nnscaler'], - python_requires= '>=3.8', - install_requires= install_requires, -) From 0ebef93a20321ba322b7249a1e32c31b87074b3e Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 22 Apr 2024 12:39:07 +0000 Subject: [PATCH 1625/1892] Merged PR 2124: Use cppimport + pybind11 to compile cpp code --- .gitignore | 2 + autodist/build_env.py | 7 - nnscaler/autodist/csrc/solver.cpp | 799 ---------------------------- nnscaler/autodist/dp_solver.cpp | 829 ++++++++++++++++++++++++++++++ nnscaler/autodist/spmd_solver.py | 136 +---- requirements.txt | 2 + tests/autodist/test_dp_solver.py | 38 ++ 7 files changed, 896 insertions(+), 917 deletions(-) delete mode 100644 nnscaler/autodist/csrc/solver.cpp create mode 100644 nnscaler/autodist/dp_solver.cpp create mode 100644 tests/autodist/test_dp_solver.py diff --git a/.gitignore b/.gitignore index c866c6ef..d7d4e70d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__ *.egg-info +*.so .vs/ .vscode/ @@ -14,6 +15,7 @@ benchmark/deepspeed/Megatron-DeepSpeed gencode*.py fullmodel.pt +fullmodel.pt.* dist_param_map.pt ## autodist ## diff --git a/autodist/build_env.py b/autodist/build_env.py index c13978fe..590ce804 100644 --- a/autodist/build_env.py +++ b/autodist/build_env.py @@ -44,13 +44,6 @@ def main(): default_path + f'/comm_back_{str(datetime.now().timestamp())}') shutil.copytree(code_path / 'autodist/profile_data/16xmi200/comm', default_path + '/comm') - # compile solver - solver_csrc = code_path / 'nnscaler/autodist/csrc/solver.cpp' - compile_command = f'g++ -std=c++11 {solver_csrc} -O3 -pthread -o solver' - compile_out = subprocess.check_output(compile_command, - shell=True, - text=True) - subprocess.check_output(f'mv solver {base_path}/', shell=True, text=True) print('> build env successfully') diff --git a/nnscaler/autodist/csrc/solver.cpp b/nnscaler/autodist/csrc/solver.cpp deleted file mode 100644 index 26cde49e..00000000 --- a/nnscaler/autodist/csrc/solver.cpp +++ /dev/null @@ -1,799 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { -public: - ThreadPool(unsigned int n = std::thread::hardware_concurrency()); - - template void enqueue(F &&f); - void waitFinished(); - ~ThreadPool(); - - unsigned int getProcessed() const { return processed; } - -private: - std::vector workers; - std::deque> tasks; - std::mutex queue_mutex; - std::condition_variable cv_task; - std::condition_variable cv_finished; - std::atomic_uint processed; - unsigned int busy; - bool stop; - - void thread_proc(); -}; - -ThreadPool::ThreadPool(unsigned int n) : busy(), processed(), stop() { - for (unsigned int i = 0; i < n; ++i) - workers.emplace_back(std::bind(&ThreadPool::thread_proc, this)); -} - -ThreadPool::~ThreadPool() { - // set stop-condition - std::unique_lock latch(queue_mutex); - stop = true; - cv_task.notify_all(); - latch.unlock(); - - // all threads terminate, then we're done. - for (auto &t : workers) - t.join(); -} - -void ThreadPool::thread_proc() { - while (true) { - std::unique_lock latch(queue_mutex); - cv_task.wait(latch, [this]() { return stop || !tasks.empty(); }); - if (!tasks.empty()) { - // got work. set busy. - ++busy; - - // pull from queue - auto fn = tasks.front(); - tasks.pop_front(); - - // release lock. run async - latch.unlock(); - - // run function outside context - fn(); - ++processed; - - latch.lock(); - --busy; - cv_finished.notify_one(); - } else if (stop) { - break; - } - } -} - -// generic function push -template void ThreadPool::enqueue(F &&f) { - std::unique_lock lock(queue_mutex); - tasks.emplace_back(std::forward(f)); - cv_task.notify_one(); -} - -// waits until the queue is empty. -void ThreadPool::waitFinished() { - std::unique_lock lock(queue_mutex); - cv_finished.wait(lock, [this]() { return tasks.empty() && (busy == 0); }); -} - -struct DPNode; - -struct Node { - int id; - int father_id; - - int cut_len; - std::vector cut_nodes; - - int p_num; - std::vector p_time; - std::vector p_comp_mem; - std::vector p_buf_mem; - std::vector p_act_mem; - std::vector p_opt_mem; - std::vector p_father; - - int producer_num; - std::vector producers; - std::vector> comm_costs; - - // assume the number of combinations is less than 2e9 - int dp_num; - std::vector dp_nodes; -}; - -int verbose; - -struct DPNode { - Node *graph_node; - int pg_id; - std::vector ir; - std::vector> in_edges; - // mem, time, activation_mem, optimzer_mem - std::vector> state; -}; - -void resetNode(Node *node) { - for (DPNode *dp_node : node->dp_nodes) { - dp_node->state.clear(); - } -} - -void printNode(Node *node) { - std::cout << "id: " << node->id << std::endl; - std::cout << "father_id: " << node->father_id << std::endl; - std::cout << "cut_len: " << node->cut_len << std::endl; - std::cout << "cut_nodes: "; - for (auto cut_node : node->cut_nodes) { - std::cout << cut_node->id << " "; - } - std::cout << std::endl; - std::cout << "p_num: " << node->p_num << std::endl; - std::cout << "p_time: "; - for (auto p_time : node->p_time) { - std::cout << p_time << " "; - } - std::cout << std::endl; - std::cout << "p_comp_mem: "; - for (auto p_comp_mem : node->p_comp_mem) { - std::cout << p_comp_mem << " "; - } - std::cout << std::endl; - std::cout << "p_buf_mem: "; - for (auto p_buf_mem : node->p_buf_mem) { - std::cout << p_buf_mem << " "; - } - std::cout << std::endl; - std::cout << "p_act_mem: "; - for (auto p_act_mem : node->p_act_mem) { - std::cout << p_act_mem << " "; - } - std::cout << std::endl; - std::cout << "p_opt_mem: "; - for (auto p_opt_mem : node->p_opt_mem) { - std::cout << p_opt_mem << " "; - } - std::cout << std::endl; - std::cout << "producer_num: " << node->producer_num << std::endl; - std::cout << "producers: "; - for (auto producer : node->producers) { - std::cout << producer->id << " "; - } - std::cout << std::endl; - std::cout << "p_father: "; - for (auto p_father : node->p_father) { - std::cout << p_father << " "; - } - std::cout << std::endl; - std::cout << "comm_costs: " << std::endl; - for (auto comm_cost : node->comm_costs) { - for (auto cost : comm_cost) { - std::cout << cost << " "; - } - std::cout << std::endl; - } - std::cout << "dp_num: " << node->dp_num << std::endl; - std::cout << std::endl; -} - -std::unordered_map id2node; -std::vector> queries; -// mode = 0: training, use the sum of the two largest buffer sizes -// mode = 1: inference, use the largest buffer size -int mode; -// mem_bound: the maximum memory usage, in bytes -int mem_bound; -// mem_div: the memory divisor, to avoid overflow in int32 -int mem_div; -int topk; -const int MAX_CONCURRENCY = std::thread::hardware_concurrency(); - -bool is_big_endian(void) { - union { - uint32_t i; - char c[4]; - } bint = {0x01020304}; - - return bint.c[0] == 1; -} - -template void read_binary(T *ptr, std::ifstream &stream) { - stream.read(reinterpret_cast(ptr), sizeof(*ptr)); -} - -void build_graph(std::ifstream &input) { - queries.clear(); - id2node.clear(); - - int n, q_size; - read_binary(&mode, input); - read_binary(&n, input); - read_binary(&mem_bound, input); - read_binary(&mem_div, input); - read_binary(&topk, input); - read_binary(&q_size, input); - - if (verbose) { - printf("is big endian: %d\n", is_big_endian()); - } - printf("node num: %d, mem_bound: %d, mem_div: %d, topk: %d, query num: %d\n", - n, mem_bound, mem_div, topk, q_size); - queries.resize(q_size); - for (int i = 0; i < q_size; ++i) { - int start, end; - read_binary(&start, input); - read_binary(&end, input); - queries[i] = std::make_pair(start, end); - } - for (int i = 0; i < n; ++i) { - Node *node = new Node(); - read_binary(&node->id, input); - id2node[node->id] = node; - read_binary(&node->father_id, input); - - int cut_id; - read_binary(&node->cut_len, input); - node->cut_nodes.resize(node->cut_len); - for (int j = 0; j < node->cut_len; ++j) { - read_binary(&cut_id, input); - node->cut_nodes[j] = id2node[cut_id]; - } - - read_binary(&node->p_num, input); - node->p_father.resize(node->p_num); - node->p_time.resize(node->p_num); - node->p_comp_mem.resize(node->p_num); - node->p_buf_mem.resize(node->p_num); - node->p_act_mem.resize(node->p_num); - node->p_opt_mem.resize(node->p_num); - for (int j = 0; j < node->p_num; ++j) { - read_binary(node->p_time.data() + j, input); - read_binary(node->p_comp_mem.data() + j, input); - read_binary(node->p_buf_mem.data() + j, input); - read_binary(node->p_act_mem.data() + j, input); - read_binary(node->p_opt_mem.data() + j, input); - read_binary(node->p_father.data() + j, input); - } - - read_binary(&node->producer_num, input); - node->producers.clear(); - node->comm_costs.clear(); - node->comm_costs.resize(node->producer_num); - for (int j = 0; j < node->producer_num; ++j) { - int producer_id; - read_binary(&producer_id, input); - Node *producer = id2node[producer_id]; - node->producers.push_back(producer); - node->comm_costs[j].resize(node->p_num * producer->p_num); - for (int k = 0; k < node->p_num * producer->p_num; ++k) { - read_binary(node->comm_costs[j].data() + k, input); - } - } - node->dp_num = 1; - for (Node *cut_node : node->cut_nodes) { - node->dp_num *= cut_node->p_num; - } - node->dp_nodes.resize(node->dp_num); - for (int j = 0; j < node->dp_num; ++j) { - DPNode *dp_node = new DPNode(); - node->dp_nodes[j] = dp_node; - dp_node->graph_node = node; - // pg: partition group, denotes the maintained partition states in - // a node. to reduce memory usage, we use a single int to - // represent a partition group - dp_node->pg_id = j; - dp_node->ir.clear(); - dp_node->in_edges.clear(); - dp_node->state.clear(); - } - if (verbose) { - printNode(node); - } - } -} - -// lazy decode -// after decoding, ir stores the partition id of each cut node -void decodePGID(DPNode *dp_node) { - if (!dp_node->ir.empty()) { - return; - } - Node *node = dp_node->graph_node; - int val = dp_node->pg_id; - for (int i = 0; i < node->cut_len; ++i) { - Node *cur_node = node->cut_nodes[node->cut_len - i - 1]; - dp_node->ir.push_back(val % cur_node->p_num); - val /= cur_node->p_num; - } - std::reverse(dp_node->ir.begin(), dp_node->ir.end()); -} - -// lazy build edge -void buildInEdges(DPNode *dp_node) { - if (!dp_node->in_edges.empty()) { - return; - } - Node *node = dp_node->graph_node; - - // special case: the node does not have any producer - // the pred dp node is composed of the same cut nodes as the current node - // except the last one. the transition cost is 0 since there is no - // communication - if (node->producer_num == 0) { - int val = 0; - for (int i = 0; i < node->cut_len - 1; ++i) { - val += dp_node->ir[i]; - if (i < node->cut_len - 2) { - val *= node->cut_nodes[i + 1]->p_num; - } - } - Node *pre_node = id2node[node->id - 1]; - dp_node->in_edges.push_back(std::make_pair(pre_node->dp_nodes[val], 0)); - return; - } - - int cur_p = *(dp_node->ir.rbegin()); - // we have filtered out the partition that cannot find a father to follow - assert(node->p_father[cur_p] != -1); - std::map info; - for (int i = 0; i < node->cut_len - 1; ++i) { - info[node->cut_nodes[i]->id] = dp_node->ir[i]; - } - // TODO(yizhu1): optimize - int producer_comb_num = 1; - for (Node *producer : node->producers) { - producer_comb_num *= producer->p_num; - } - // enumerate all the possible producer partition combinations - // to build the in edges - for (int idx = 0; idx < producer_comb_num; ++idx) { - bool is_legal = true; - int val = idx; - std::vector producer_ps(node->producer_num); - // decode the producer partition combination - for (int j = 0; j < node->producer_num; ++j) { - int k = node->producer_num - 1 - j; - producer_ps[k] = val % node->producers[k]->p_num; - val /= node->producers[k]->p_num; - // constraint: if the producer shares the same father with the node, - // then the partition of the node should follow the producer's partition, - // except the producer is the father node. - if (node->father_id != node->id) { - Node *producer = node->producers[k]; - // TODO: do we need to check producer->father_id != producer->id? - // seems this case will be filtered out by checker in line 411 - if (producer->father_id == node->father_id && - producer->father_id != producer->id) { - if (node->p_father[cur_p] != producer->p_father[producer_ps[k]]) { - is_legal = false; - } - } - } - } - if (!is_legal) { - continue; - } - // - std::vector> cur_ir(node->cut_len - 1); - bool has_found_follow = false; - for (int i = 0; i < node->cut_len - 1; ++i) { - cur_ir[i] = std::make_pair(node->cut_nodes[i]->id, dp_node->ir[i]); - if (node->cut_nodes[i]->father_id == node->father_id) { - has_found_follow = true; - } - } - double cost = 0; - std::vector> follow_candidates; - for (int j = 0; j < node->producer_num; ++j) { - int producer_id = node->producers[j]->id; - int producer_p = producer_ps[j]; - auto iter = info.find(producer_id); - if (iter != info.end()) { - if (producer_p != iter->second) { - is_legal = false; - break; - } - } else { - Node *producer = node->producers[j]; - if (producer->father_id != node->father_id) { - // check that there is a existing node in cur_ir that in the same - // follow chain with the producer - bool find_existing_follow = false; - for (int i = 0; i < cur_ir.size(); ++i) { - Node *tmp = id2node[cur_ir[i].first]; - if (tmp->father_id == producer->father_id) { - find_existing_follow = true; - // update - if (tmp->id < producer->id) { - for (int _ = 0; _ < producer->p_num; ++_) { - if (producer->p_father[_] == - tmp->p_father[cur_ir[i].second]) { - // replace to align with the filter logic in python - // only the newest node in the follow chain is kept - cur_ir[i] = std::make_pair(producer->id, _); - break; - } - } - } - break; - } - } - if (!find_existing_follow) { - cur_ir.push_back(std::make_pair(producer_id, producer_p)); - } - } else { - follow_candidates.push_back(std::make_pair(producer_id, producer_p)); - } - } - cost += - node->comm_costs[j][cur_p * node->producers[j]->p_num + producer_p]; - } - if (!is_legal) { - continue; - } - // handle follow - bool find_pre_id = false; - for (int j = 0; j < cur_ir.size(); ++j) { - if (cur_ir[j].first == node->id - 1) { - find_pre_id = true; - break; - } - } - if (!find_pre_id) { - Node *pre_node = id2node[node->id - 1]; - if (pre_node->father_id != node->father_id) { - // do nothing, means the pre_node's output is not used - // we select the 1st partition of the pre_node - // need to be careful when the graph has multiple outputs - // shall we constrain that the output of the graph is replicated? - } else if (pre_node->father_id == pre_node->id) { - assert(follow_candidates.rbegin()->first == pre_node->id); - cur_ir.push_back(*follow_candidates.rbegin()); - } else { - bool find_same_follow_p = false; - for (int k = 0; k < pre_node->p_num; ++k) { - if (pre_node->p_father[k] == node->p_father[cur_p]) { - cur_ir.push_back(std::make_pair(node->id - 1, k)); - find_same_follow_p = true; - break; - } - } - assert(find_same_follow_p); - } - } else { - if (node->father_id != node->id && !has_found_follow && - !follow_candidates.empty()) { - cur_ir.push_back(*follow_candidates.rbegin()); - } - } - std::sort(cur_ir.begin(), cur_ir.end()); - val = 0; - for (int j = 0; j < cur_ir.size(); ++j) { - val += cur_ir[j].second; - if (j + 1 < cur_ir.size()) { - val *= id2node[cur_ir[j + 1].first]->p_num; - } - } - dp_node->in_edges.push_back( - std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); - } -} - -// do dp for a partition group -void update(DPNode *dp_node, int start_level) { - Node *node = dp_node->graph_node; - decodePGID(dp_node); - int cur_p_idx = *(dp_node->ir.rbegin()); - if (node->id == start_level) { - // each dp node maintains a list of states, each state is a tuple - // (mem, time, pred_dp_node, activation_mem, optimizer_mem) - dp_node->state.push_back(std::make_tuple( - node->p_comp_mem[cur_p_idx], node->p_time[cur_p_idx], nullptr, - node->p_act_mem[cur_p_idx], node->p_opt_mem[cur_p_idx])); - return; - } - - // storing edges takes space, so we build edges when needed - buildInEdges(dp_node); - int cur_p = *(dp_node->ir.rbegin()); - if (dp_node->in_edges.empty()) { - dp_node->state.push_back(std::make_tuple( - 0, std::numeric_limits::infinity(), nullptr, 0, 0)); - return; - } - - // use a priority queue to maintain the best state, similar to the merge sort - double cur_p_time = node->p_time[cur_p]; - int cur_p_comp_mem = node->p_comp_mem[cur_p]; - int cur_p_act_mem = node->p_act_mem[cur_p]; - int cur_p_opt_mem = node->p_opt_mem[cur_p]; - std::priority_queue> pq; - for (int i = 0; i < dp_node->in_edges.size(); ++i) { - DPNode *pred = dp_node->in_edges[i].first; - int mem = cur_p_comp_mem + std::get<0>(pred->state[0]); - double cost = - cur_p_time + dp_node->in_edges[i].second + std::get<1>(pred->state[0]); - int act_mem = cur_p_act_mem + std::get<3>(pred->state[0]); - int opt_mem = cur_p_opt_mem + std::get<4>(pred->state[0]); - pq.push(std::make_tuple(-mem, -cost, i, -act_mem, -opt_mem)); - } - - std::vector lows(dp_node->in_edges.size(), 1); - - int cur_mem; - double cur_cost; - int pred_idx; - int cur_act_mem; - int cur_opt_mem; - while (!pq.empty()) { - std::tie(cur_mem, cur_cost, pred_idx, cur_act_mem, cur_opt_mem) = pq.top(); - cur_mem = -cur_mem; - cur_cost = -cur_cost; - cur_act_mem = -cur_act_mem; - cur_opt_mem = -cur_opt_mem; - pq.pop(); - if (lows[pred_idx] < dp_node->in_edges[pred_idx].first->state.size()) { - DPNode *pred = dp_node->in_edges[pred_idx].first; - int mem = cur_p_comp_mem + std::get<0>(pred->state[lows[pred_idx]]); - double cost = cur_p_time + dp_node->in_edges[pred_idx].second + - std::get<1>(pred->state[lows[pred_idx]]); - int act_mem = cur_p_act_mem + std::get<3>(pred->state[lows[pred_idx]]); - int opt_mem = cur_p_opt_mem + std::get<4>(pred->state[lows[pred_idx]]); - pq.push(std::make_tuple(-mem, -cost, pred_idx, -act_mem, -opt_mem)); - ++lows[pred_idx]; - } - if (dp_node->state.empty()) { - dp_node->state.push_back( - std::make_tuple(cur_mem, cur_cost, dp_node->in_edges[pred_idx].first, - cur_act_mem, cur_opt_mem)); - } else { - int pre_mem = std::get<0>(dp_node->state[dp_node->state.size() - 1]); - double pre_cost = std::get<1>(dp_node->state[dp_node->state.size() - 1]); - // if (cur_mem > pre_mem && cur_cost < pre_cost && - // cur_mem + cur_opt_mem <= mem_bound) { - if (cur_mem > pre_mem && cur_cost < pre_cost && - cur_mem - cur_act_mem + std::max(cur_act_mem, cur_opt_mem) <= - mem_bound) { - dp_node->state.push_back(std::make_tuple( - cur_mem, cur_cost, dp_node->in_edges[pred_idx].first, cur_act_mem, - cur_opt_mem)); - } - } - } -} - -ThreadPool pool(MAX_CONCURRENCY); - -std::vector> split_work(int num) { - std::vector work; - if (num < MAX_CONCURRENCY) { - work = std::vector(num, 1); - } else { - work = std::vector(MAX_CONCURRENCY, num / MAX_CONCURRENCY); - for (int i = 0; i < num % MAX_CONCURRENCY; ++i) { - work[i] += 1; - } - } - std::vector> ret(work.size()); - int cum_sum = 0; - for (int i = 0; i < work.size(); ++i) { - ret[i] = std::make_pair(cum_sum, work[i]); - cum_sum += work[i]; - } - return ret; -} - -void do_dp(int start_level, int end_level) { - // reset all the dp nodes, since we may have multiple queries - for (auto iter = id2node.begin(); iter != id2node.end(); ++iter) { - resetNode(iter->second); - } - - for (int i = start_level; i <= end_level; ++i) { - // use multi-thread to do dp for each level to reduce time - auto iter = id2node.find(i); - if (iter == id2node.end()) { - // TODO(yizhu1): check here - assert(false); - } - if (verbose) { - std::cout << "Start to process level id: " << i - << ", state num: " << iter->second->dp_nodes.size() - << std::endl; - } - std::vector> split_info = - split_work(iter->second->dp_num); - for (const auto &item : split_info) { - pool.enqueue([=] { - for (int i = 0; i < item.second; ++i) { - int offset = item.first + i; - update(iter->second->dp_nodes[offset], start_level); - } - }); - } - pool.waitFinished(); - } -} - -std::tuple>> -process_state(DPNode *dp_node, int idx) { - // build the optimal path of each partition of last operator - // and return the best path - std::vector> path; - DPNode *cur_dp_node = dp_node; - int cur_idx = idx; - int best_mem = std::get<0>(dp_node->state[idx]); - double best_time = std::get<1>(dp_node->state[idx]); - int act_mem = std::get<3>(dp_node->state[idx]); - int opt_mem = std::get<4>(dp_node->state[idx]); - double inner_time = 0; - int cur_best_mem = best_mem; - std::vector buffers; - while (true) { - int cur_p = *(cur_dp_node->ir.rbegin()); - Node *node = cur_dp_node->graph_node; - path.push_back(std::make_pair(node->id, cur_p)); - buffers.push_back(node->p_buf_mem[cur_p]); - inner_time += node->p_time[cur_p]; - cur_best_mem -= node->p_comp_mem[cur_p]; - DPNode *pred_dp_node = std::get<2>(cur_dp_node->state[cur_idx]); - if (pred_dp_node == nullptr) { - break; - } else { - cur_dp_node = pred_dp_node; - cur_idx = std::lower_bound( - cur_dp_node->state.begin(), cur_dp_node->state.end(), - std::make_tuple(cur_best_mem, static_cast(-1), - static_cast(nullptr), -1, -1)) - - cur_dp_node->state.begin(); - } - } - std::reverse(path.begin(), path.end()); - std::sort(buffers.begin(), buffers.end()); - long long ret_mem = static_cast(best_mem); - if (mode == 0) { - ret_mem += buffers[buffers.size() - 1] + buffers[buffers.size() - 2]; - } else if (mode == 1) { - ret_mem += buffers[buffers.size() - 1]; - } - ret_mem = ret_mem - act_mem + std::max(act_mem, opt_mem); - if (ret_mem > mem_bound) { - return std::make_tuple(-1, -1, -1, std::vector>()); - } - if (verbose) { - std::cout << "best time: " << best_time - << ", best mem: " << best_mem / 1024 / 1024 * mem_div << "MB, " - << "activation mem: " << act_mem / 1024 / 1024 * mem_div << "MB, " - << "optimizer state mem: " << opt_mem / 1024 / 1024 * mem_div - << "MB" << std::endl; - } - return std::make_tuple(best_time, inner_time, static_cast(ret_mem), - path); -} - -template void write_binary(T val, std::ofstream &stream) { - stream.write(reinterpret_cast(&val), sizeof val); -} - -void post_process(int start_level, int end_level, int topk, - std::ofstream &output) { - std::vector>>> - best_info; - double best_time; - double inner_time; - int best_mem; - std::vector> path; - for (DPNode *dp_node : id2node[end_level]->dp_nodes) { - int cnt = 0; - for (int i = 0; i < dp_node->state.size(); ++i) { - std::tie(best_time, inner_time, best_mem, path) = - process_state(dp_node, dp_node->state.size() - i - 1); - if (best_time > 0) { - ++cnt; - best_info.push_back( - std::make_tuple(best_time, inner_time, best_mem, path)); - if (cnt == topk) { - break; - } - } - } - } - std::sort(best_info.begin(), best_info.end()); - int ret_size = std::min(topk, int(best_info.size())); - int path_len = 0; - if (ret_size > 0) { - path_len = static_cast(std::get<3>(best_info[0]).size()); - } - write_binary(start_level, output); - write_binary(end_level, output); - write_binary(ret_size, output); - write_binary(path_len, output); - for (int i = 0; i < ret_size; ++i) { - best_time = std::get<0>(best_info[i]); - inner_time = std::get<1>(best_info[i]); - best_mem = std::get<2>(best_info[i]); - write_binary(best_time, output); - write_binary(inner_time, output); - write_binary(best_mem, output); - for (auto &item : std::get<3>(best_info[i])) { - write_binary(item.first, output); - write_binary(item.second, output); - } - } -} - -int main(int argc, char **argv) { - if (argc != 4) { - std::cout << "Usage: ./solver in_path out_path is_verbose" << std::endl; - return 0; - } - std::ifstream input(argv[1], std::ios::binary); - std::ofstream output(argv[2], std::ios::out | std::ios::binary); - verbose = argv[3][0] == '1'; - build_graph(input); - // to reduce time, we first group the queries by start node (level) - std::unordered_map> intervals; - for (const auto &query : queries) { - auto iter = intervals.find(query.first); - if (iter == intervals.end()) { - intervals[query.first] = std::vector(1, query.second); - } else { - iter->second.push_back(query.second); - } - } - auto start = std::chrono::system_clock::now(); - for (auto &item : intervals) { - // for each start node, we do dp until the last end node - int start_level = item.first; - std::vector &end_levels = item.second; - std::sort(end_levels.begin(), end_levels.end()); - do_dp(start_level, *end_levels.rbegin()); - for (int end_level : end_levels) { - post_process(start_level, end_level, topk, output); - } - long long state_cnt = 0; - for (auto iter = id2node.begin(); iter != id2node.end(); ++iter) { - int cur_id = iter->first; - Node *cur_node = iter->second; - for (DPNode *dp_node : cur_node->dp_nodes) { - state_cnt += dp_node->state.size(); - } - } - if (verbose) { - std::cout << "state num: " << state_cnt << std::endl; - } - } - auto end = std::chrono::system_clock::now(); - - std::chrono::duration elapsed_seconds = end - start; - - std::cout << "elapsed time: " << elapsed_seconds.count() << " s" << std::endl; - input.close(); - output.close(); - return 0; -} diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp new file mode 100644 index 00000000..8e3826b8 --- /dev/null +++ b/nnscaler/autodist/dp_solver.cpp @@ -0,0 +1,829 @@ +// cppimport +#include +#include + +namespace py = pybind11; + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(unsigned int n = std::thread::hardware_concurrency()); + + template void enqueue(F &&f); + void waitFinished(); + ~ThreadPool(); + + unsigned int getProcessed() const { return processed; } + +private: + std::vector workers; + std::deque> tasks; + std::mutex queue_mutex; + std::condition_variable cv_task; + std::condition_variable cv_finished; + std::atomic_uint processed; + unsigned int busy; + bool stop; + + void thread_proc(); +}; + +ThreadPool::ThreadPool(unsigned int n) : busy(), processed(), stop() { + for (unsigned int i = 0; i < n; ++i) + workers.emplace_back(std::bind(&ThreadPool::thread_proc, this)); +} + +ThreadPool::~ThreadPool() { + // set stop-condition + std::unique_lock latch(queue_mutex); + stop = true; + cv_task.notify_all(); + latch.unlock(); + + // all threads terminate, then we're done. + for (auto &t : workers) + t.join(); +} + +void ThreadPool::thread_proc() { + while (true) { + std::unique_lock latch(queue_mutex); + cv_task.wait(latch, [this]() { return stop || !tasks.empty(); }); + if (!tasks.empty()) { + // got work. set busy. + ++busy; + + // pull from queue + auto fn = tasks.front(); + tasks.pop_front(); + + // release lock. run async + latch.unlock(); + + // run function outside context + fn(); + ++processed; + + latch.lock(); + --busy; + cv_finished.notify_one(); + } else if (stop) { + break; + } + } +} + +// generic function push +template void ThreadPool::enqueue(F &&f) { + std::unique_lock lock(queue_mutex); + tasks.emplace_back(std::forward(f)); + cv_task.notify_one(); +} + +// waits until the queue is empty. +void ThreadPool::waitFinished() { + std::unique_lock lock(queue_mutex); + cv_finished.wait(lock, [this]() { return tasks.empty() && (busy == 0); }); +} + +const int MAX_CONCURRENCY = std::thread::hardware_concurrency(); +ThreadPool pool(MAX_CONCURRENCY); + +std::vector> split_work(int num) { + std::vector work; + if (num < MAX_CONCURRENCY) { + work = std::vector(num, 1); + } else { + work = std::vector(MAX_CONCURRENCY, num / MAX_CONCURRENCY); + for (int i = 0; i < num % MAX_CONCURRENCY; ++i) { + work[i] += 1; + } + } + std::vector> ret(work.size()); + int cum_sum = 0; + for (std::size_t i = 0; i < work.size(); ++i) { + ret[i] = std::make_pair(cum_sum, work[i]); + cum_sum += work[i]; + } + return ret; +} + +struct DPNode; + +struct Node { + int id; + int father_id; + + int cut_len; + std::vector cut_nodes; + + int p_num; + std::vector p_time; + std::vector p_comp_mem; + std::vector p_buf_mem; + std::vector p_act_mem; + std::vector p_opt_mem; + std::vector p_father; + + int producer_num; + std::vector producers; + std::vector> comm_costs; + + // assume the number of combinations is less than 2e9 + int dp_num; + std::vector dp_nodes; +}; + +struct DPNode { + Node *graph_node; + int pg_id; + std::vector ir; + std::vector> in_edges; + // mem, time, activation_mem, optimzer_mem + std::vector> state; +}; + +void resetNode(Node *node) { + for (DPNode *dp_node : node->dp_nodes) { + dp_node->state.clear(); + } +} + +void printNode(Node *node) { + std::cout << "id: " << node->id << std::endl; + std::cout << "father_id: " << node->father_id << std::endl; + std::cout << "cut_len: " << node->cut_len << std::endl; + std::cout << "cut_nodes: "; + for (auto cut_node : node->cut_nodes) { + std::cout << cut_node->id << " "; + } + std::cout << std::endl; + std::cout << "p_num: " << node->p_num << std::endl; + std::cout << "p_time: "; + for (auto p_time : node->p_time) { + std::cout << p_time << " "; + } + std::cout << std::endl; + std::cout << "p_comp_mem: "; + for (auto p_comp_mem : node->p_comp_mem) { + std::cout << p_comp_mem << " "; + } + std::cout << std::endl; + std::cout << "p_buf_mem: "; + for (auto p_buf_mem : node->p_buf_mem) { + std::cout << p_buf_mem << " "; + } + std::cout << std::endl; + std::cout << "p_act_mem: "; + for (auto p_act_mem : node->p_act_mem) { + std::cout << p_act_mem << " "; + } + std::cout << std::endl; + std::cout << "p_opt_mem: "; + for (auto p_opt_mem : node->p_opt_mem) { + std::cout << p_opt_mem << " "; + } + std::cout << std::endl; + std::cout << "producer_num: " << node->producer_num << std::endl; + std::cout << "producers: "; + for (auto producer : node->producers) { + std::cout << producer->id << " "; + } + std::cout << std::endl; + std::cout << "p_father: "; + for (auto p_father : node->p_father) { + std::cout << p_father << " "; + } + std::cout << std::endl; + std::cout << "comm_costs: " << std::endl; + for (auto comm_cost : node->comm_costs) { + for (auto cost : comm_cost) { + std::cout << cost << " "; + } + std::cout << std::endl; + } + std::cout << "dp_num: " << node->dp_num << std::endl; + std::cout << std::endl; +} + +// lazy decode +// after decoding, ir stores the partition id of each cut node +void decodePGID(DPNode *dp_node) { + if (!dp_node->ir.empty()) { + return; + } + Node *node = dp_node->graph_node; + int val = dp_node->pg_id; + for (int i = 0; i < node->cut_len; ++i) { + Node *cur_node = node->cut_nodes[node->cut_len - i - 1]; + dp_node->ir.push_back(val % cur_node->p_num); + val /= cur_node->p_num; + } + std::reverse(dp_node->ir.begin(), dp_node->ir.end()); +} + +struct SearchPlan { + double all_time; + double inner_time; + int memory; + std::vector> path; + + bool operator<(const SearchPlan &other) const { + return all_time < other.all_time; + } +}; + +class DPSolver { +public: + DPSolver(bool verbose, int mode, int mem_bound, int mem_div, int topk) : verbose(verbose), mode(mode), mem_bound(mem_bound), mem_div(mem_div), topk(topk) + { + queries.clear(); + id2node.clear(); + search_results.clear(); + } + + void add_interval(int start, int end) { + if (verbose) { + std::cout << "add interval start: " << start << ", end: " << end + << std::endl; + } + queries.push_back(std::make_pair(start, end)); + } + + void add_node(int id, int father_id, std::vector cut_ids, + std::vector producers, int p_num) { + if (verbose) { + std::cout << "id: " << id << ", father_id: " << father_id + << ", cut_ids: "; + for (int cut_id : cut_ids) { + std::cout << cut_id << " "; + } + std::cout << ", producers: "; + for (int producer : producers) { + std::cout << producer << " "; + } + std::cout << ", p_num: " << p_num << std::endl; + } + Node *node = new Node(); + id2node[id] = node; + node->id = id; + node->p_num = p_num; + node->p_father.resize(p_num); + node->p_time.resize(p_num); + node->p_comp_mem.resize(p_num); + node->p_buf_mem.resize(p_num); + node->p_act_mem.resize(p_num); + node->p_opt_mem.resize(p_num); + node->father_id = father_id; + node->cut_len = cut_ids.size(); + node->cut_nodes.resize(node->cut_len); + for (int i = 0; i < node->cut_len; ++i) { + node->cut_nodes[i] = id2node[cut_ids[i]]; + } + node->producer_num = producers.size(); + node->producers.resize(node->producer_num); + node->comm_costs.clear(); + node->comm_costs.resize(node->producer_num); + for (int i = 0; i < node->producer_num; ++i) { + node->producers[i] = id2node[producers[i]]; + node->comm_costs[i].resize(node->p_num * node->producers[i]->p_num); + } + } + + void add_partition(int node_id, int p_idx, double p_time, int p_comp_mem, + int p_buf_mem, int p_act_mem, int p_opt_mem, int p_father, + std::vector> comm_costs) { + if (verbose) { + std::cout << "node_id: " << node_id << ", p_idx: " << p_idx + << ", p_time: " << p_time << ", p_comp_mem: " << p_comp_mem + << ", p_buf_mem: " << p_buf_mem << ", p_act_mem: " << p_act_mem + << ", p_opt_mem: " << p_opt_mem << ", p_father: " << p_father + << std::endl; + std::cout << "comm_costs: " << std::endl; + for (std::size_t i = 0; i < comm_costs.size(); ++i) { + for (std::size_t j = 0; j < comm_costs[i].size(); ++j) { + std::cout << comm_costs[i][j] << " "; + } + std::cout << std::endl; + } + } + Node *node = id2node[node_id]; + node->p_time[p_idx] = p_time; + node->p_comp_mem[p_idx] = p_comp_mem; + node->p_buf_mem[p_idx] = p_buf_mem; + node->p_act_mem[p_idx] = p_act_mem; + node->p_opt_mem[p_idx] = p_opt_mem; + node->p_father[p_idx] = p_father; + for (int i = 0; i < node->producer_num; ++i) { + for (int j = 0; j < node->producers[i]->p_num; ++j) { + node->comm_costs[i][p_idx * node->producers[i]->p_num + j] = + comm_costs[i][j]; + } + } + } + + void init_dp_info() { + for (auto iter = id2node.begin(); iter != id2node.end(); ++iter) { + Node *node = iter->second; + node->dp_num = 1; + for (Node *cut_node : node->cut_nodes) { + node->dp_num *= cut_node->p_num; + } + node->dp_nodes.resize(node->dp_num); + // pg: partition group, denotes the maintained partition states in + // a node. to reduce memory usage, we use a single int to + // represent a partition group + for (int j = 0; j < node->dp_num; ++j) { + DPNode *dp_node = new DPNode(); + node->dp_nodes[j] = dp_node; + dp_node->graph_node = node; + dp_node->pg_id = j; + dp_node->ir.clear(); + dp_node->in_edges.clear(); + dp_node->state.clear(); + } + } + } + + // lazy build edge + void buildInEdges(DPNode *dp_node) { + if (!dp_node->in_edges.empty()) { + return; + } + Node *node = dp_node->graph_node; + + // special case: the node does not have any producer + // the pred dp node is composed of the same cut nodes as the current node + // except the last one. the transition cost is 0 since there is no + // communication + if (node->producer_num == 0) { + int val = 0; + for (int i = 0; i < node->cut_len - 1; ++i) { + val += dp_node->ir[i]; + if (i < node->cut_len - 2) { + val *= node->cut_nodes[i + 1]->p_num; + } + } + Node *pre_node = id2node[node->id - 1]; + dp_node->in_edges.push_back(std::make_pair(pre_node->dp_nodes[val], 0)); + return; + } + + int cur_p = *(dp_node->ir.rbegin()); + // we have filtered out the partition that cannot find a father to follow + assert(node->p_father[cur_p] != -1); + std::map info; + for (int i = 0; i < node->cut_len - 1; ++i) { + info[node->cut_nodes[i]->id] = dp_node->ir[i]; + } + // TODO(yizhu1): optimize + int producer_comb_num = 1; + for (Node *producer : node->producers) { + producer_comb_num *= producer->p_num; + } + // enumerate all the possible producer partition combinations + // to build the in edges + for (int idx = 0; idx < producer_comb_num; ++idx) { + bool is_legal = true; + int val = idx; + std::vector producer_ps(node->producer_num); + // decode the producer partition combination + for (int j = 0; j < node->producer_num; ++j) { + int k = node->producer_num - 1 - j; + producer_ps[k] = val % node->producers[k]->p_num; + val /= node->producers[k]->p_num; + // constraint: if the producer shares the same father with the node, + // then the partition of the node should follow the producer's + // partition, except the producer is the father node. + if (node->father_id != node->id) { + Node *producer = node->producers[k]; + // TODO: do we need to check producer->father_id != producer->id? + // seems this case will be filtered out by checker in line 411 + if (producer->father_id == node->father_id && + producer->father_id != producer->id) { + if (node->p_father[cur_p] != producer->p_father[producer_ps[k]]) { + is_legal = false; + } + } + } + } + if (!is_legal) { + continue; + } + // + std::vector> cur_ir(node->cut_len - 1); + bool has_found_follow = false; + for (int i = 0; i < node->cut_len - 1; ++i) { + cur_ir[i] = std::make_pair(node->cut_nodes[i]->id, dp_node->ir[i]); + if (node->cut_nodes[i]->father_id == node->father_id) { + has_found_follow = true; + } + } + double cost = 0; + std::vector> follow_candidates; + for (int j = 0; j < node->producer_num; ++j) { + int producer_id = node->producers[j]->id; + int producer_p = producer_ps[j]; + auto iter = info.find(producer_id); + if (iter != info.end()) { + if (producer_p != iter->second) { + is_legal = false; + break; + } + } else { + Node *producer = node->producers[j]; + if (producer->father_id != node->father_id) { + // check that there is a existing node in cur_ir that in the same + // follow chain with the producer + bool find_existing_follow = false; + for (std::size_t i = 0; i < cur_ir.size(); ++i) { + Node *tmp = id2node[cur_ir[i].first]; + if (tmp->father_id == producer->father_id) { + find_existing_follow = true; + // update + if (tmp->id < producer->id) { + for (int _ = 0; _ < producer->p_num; ++_) { + if (producer->p_father[_] == + tmp->p_father[cur_ir[i].second]) { + // replace to align with the filter logic in python + // only the newest node in the follow chain is kept + cur_ir[i] = std::make_pair(producer->id, _); + break; + } + } + } + break; + } + } + if (!find_existing_follow) { + cur_ir.push_back(std::make_pair(producer_id, producer_p)); + } + } else { + follow_candidates.push_back( + std::make_pair(producer_id, producer_p)); + } + } + cost += + node->comm_costs[j][cur_p * node->producers[j]->p_num + producer_p]; + } + if (!is_legal) { + continue; + } + // handle follow + bool find_pre_id = false; + for (std::size_t j = 0; j < cur_ir.size(); ++j) { + if (cur_ir[j].first == node->id - 1) { + find_pre_id = true; + break; + } + } + if (!find_pre_id) { + Node *pre_node = id2node[node->id - 1]; + if (pre_node->father_id != node->father_id) { + // do nothing, means the pre_node's output is not used + // we select the 1st partition of the pre_node + // need to be careful when the graph has multiple outputs + // shall we constrain that the output of the graph is replicated? + } else if (pre_node->father_id == pre_node->id) { + assert(follow_candidates.rbegin()->first == pre_node->id); + cur_ir.push_back(*follow_candidates.rbegin()); + } else { + bool find_same_follow_p = false; + for (int k = 0; k < pre_node->p_num; ++k) { + if (pre_node->p_father[k] == node->p_father[cur_p]) { + cur_ir.push_back(std::make_pair(node->id - 1, k)); + find_same_follow_p = true; + break; + } + } + assert(find_same_follow_p); + } + } else { + if (node->father_id != node->id && !has_found_follow && + !follow_candidates.empty()) { + cur_ir.push_back(*follow_candidates.rbegin()); + } + } + std::sort(cur_ir.begin(), cur_ir.end()); + val = 0; + for (std::size_t j = 0; j < cur_ir.size(); ++j) { + val += cur_ir[j].second; + if (j + 1 < cur_ir.size()) { + val *= id2node[cur_ir[j + 1].first]->p_num; + } + } + dp_node->in_edges.push_back( + std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); + } + } + + // do dp for a partition group + void update(DPNode *dp_node, int start_level) { + Node *node = dp_node->graph_node; + decodePGID(dp_node); + int cur_p_idx = *(dp_node->ir.rbegin()); + if (node->id == start_level) { + // each dp node maintains a list of states, each state is a tuple + // (mem, time, pred_dp_node, activation_mem, optimizer_mem) + dp_node->state.push_back(std::make_tuple( + node->p_comp_mem[cur_p_idx], node->p_time[cur_p_idx], nullptr, + node->p_act_mem[cur_p_idx], node->p_opt_mem[cur_p_idx])); + return; + } + + // storing edges takes space, so we build edges when needed + buildInEdges(dp_node); + int cur_p = *(dp_node->ir.rbegin()); + if (dp_node->in_edges.empty()) { + dp_node->state.push_back(std::make_tuple( + 0, std::numeric_limits::infinity(), nullptr, 0, 0)); + return; + } + + // use a priority queue to maintain the best state, similar to the merge + // sort + double cur_p_time = node->p_time[cur_p]; + int cur_p_comp_mem = node->p_comp_mem[cur_p]; + int cur_p_act_mem = node->p_act_mem[cur_p]; + int cur_p_opt_mem = node->p_opt_mem[cur_p]; + std::priority_queue> pq; + for (std::size_t i = 0; i < dp_node->in_edges.size(); ++i) { + DPNode *pred = dp_node->in_edges[i].first; + int mem = cur_p_comp_mem + std::get<0>(pred->state[0]); + double cost = cur_p_time + dp_node->in_edges[i].second + + std::get<1>(pred->state[0]); + int act_mem = cur_p_act_mem + std::get<3>(pred->state[0]); + int opt_mem = cur_p_opt_mem + std::get<4>(pred->state[0]); + pq.push(std::make_tuple(-mem, -cost, i, -act_mem, -opt_mem)); + } + + std::vector lows(dp_node->in_edges.size(), 1); + + int cur_mem; + double cur_cost; + int pred_idx; + int cur_act_mem; + int cur_opt_mem; + while (!pq.empty()) { + std::tie(cur_mem, cur_cost, pred_idx, cur_act_mem, cur_opt_mem) = + pq.top(); + cur_mem = -cur_mem; + cur_cost = -cur_cost; + cur_act_mem = -cur_act_mem; + cur_opt_mem = -cur_opt_mem; + pq.pop(); + if (lows[pred_idx] < dp_node->in_edges[pred_idx].first->state.size()) { + DPNode *pred = dp_node->in_edges[pred_idx].first; + int mem = cur_p_comp_mem + std::get<0>(pred->state[lows[pred_idx]]); + double cost = cur_p_time + dp_node->in_edges[pred_idx].second + + std::get<1>(pred->state[lows[pred_idx]]); + int act_mem = cur_p_act_mem + std::get<3>(pred->state[lows[pred_idx]]); + int opt_mem = cur_p_opt_mem + std::get<4>(pred->state[lows[pred_idx]]); + pq.push(std::make_tuple(-mem, -cost, pred_idx, -act_mem, -opt_mem)); + ++lows[pred_idx]; + } + if (dp_node->state.empty()) { + dp_node->state.push_back(std::make_tuple( + cur_mem, cur_cost, dp_node->in_edges[pred_idx].first, cur_act_mem, + cur_opt_mem)); + } else { + int pre_mem = std::get<0>(dp_node->state[dp_node->state.size() - 1]); + double pre_cost = + std::get<1>(dp_node->state[dp_node->state.size() - 1]); + // if (cur_mem > pre_mem && cur_cost < pre_cost && + // cur_mem + cur_opt_mem <= mem_bound) { + if (cur_mem > pre_mem && cur_cost < pre_cost && + cur_mem - cur_act_mem + std::max(cur_act_mem, cur_opt_mem) <= + mem_bound) { + dp_node->state.push_back(std::make_tuple( + cur_mem, cur_cost, dp_node->in_edges[pred_idx].first, cur_act_mem, + cur_opt_mem)); + } + } + } + } + + void do_dp(int start_level, int end_level) { + // reset all the dp nodes, since we may have multiple queries + for (auto iter = id2node.begin(); iter != id2node.end(); ++iter) { + resetNode(iter->second); + } + + for (int i = start_level; i <= end_level; ++i) { + // use multi-thread to do dp for each level to reduce time + auto iter = id2node.find(i); + if (iter == id2node.end()) { + // TODO(yizhu1): check here + assert(false); + } + if (verbose) { + std::cout << "Start to process level id: " << i + << ", state num: " << iter->second->dp_nodes.size() + << std::endl; + } + std::vector> split_info = + split_work(iter->second->dp_num); + for (const auto &item : split_info) { + pool.enqueue([=] { + for (int i = 0; i < item.second; ++i) { + int offset = item.first + i; + update(iter->second->dp_nodes[offset], start_level); + } + }); + } + pool.waitFinished(); + } + } + + SearchPlan process_state(DPNode *dp_node, int idx) { + // build the optimal path of each partition of last operator + // and return the best path + std::vector> path; + DPNode *cur_dp_node = dp_node; + int cur_idx = idx; + int best_mem = std::get<0>(dp_node->state[idx]); + double best_time = std::get<1>(dp_node->state[idx]); + int act_mem = std::get<3>(dp_node->state[idx]); + int opt_mem = std::get<4>(dp_node->state[idx]); + double inner_time = 0; + int cur_best_mem = best_mem; + std::vector buffers; + while (true) { + int cur_p = *(cur_dp_node->ir.rbegin()); + Node *node = cur_dp_node->graph_node; + path.push_back(std::make_pair(node->id, cur_p)); + buffers.push_back(node->p_buf_mem[cur_p]); + inner_time += node->p_time[cur_p]; + cur_best_mem -= node->p_comp_mem[cur_p]; + DPNode *pred_dp_node = std::get<2>(cur_dp_node->state[cur_idx]); + if (pred_dp_node == nullptr) { + break; + } else { + cur_dp_node = pred_dp_node; + cur_idx = std::lower_bound( + cur_dp_node->state.begin(), cur_dp_node->state.end(), + std::make_tuple(cur_best_mem, static_cast(-1), + static_cast(nullptr), -1, -1)) - + cur_dp_node->state.begin(); + } + } + std::reverse(path.begin(), path.end()); + std::sort(buffers.begin(), buffers.end()); + long long ret_mem = static_cast(best_mem); + if (mode == 0) { + ret_mem += buffers[buffers.size() - 1] + buffers[buffers.size() - 2]; + } else if (mode == 1) { + ret_mem += buffers[buffers.size() - 1]; + } + ret_mem = ret_mem - act_mem + std::max(act_mem, opt_mem); + if (ret_mem > mem_bound) { + return SearchPlan{-1, -1, -1, std::vector>()}; + } + if (verbose) { + std::cout << "best time: " << best_time + << ", best mem: " << best_mem / 1024 / 1024 * mem_div << "MB, " + << "activation mem: " << act_mem / 1024 / 1024 * mem_div + << "MB, " + << "optimizer state mem: " << opt_mem / 1024 / 1024 * mem_div + << "MB" << std::endl; + } + return SearchPlan{best_time, inner_time, static_cast(ret_mem), path}; + } + + void post_process(int start_level, int end_level, int topk) { + std::vector best_info; + double best_time; + double inner_time; + int best_mem; + std::vector> path; + for (DPNode *dp_node : id2node[end_level]->dp_nodes) { + int cnt = 0; + for (std::size_t i = 0; i < dp_node->state.size(); ++i) { + SearchPlan plan = process_state(dp_node, dp_node->state.size() - i - 1); + if (plan.all_time > 0) { + ++cnt; + best_info.push_back(plan); + if (cnt == topk) { + break; + } + } + } + } + std::sort(best_info.begin(), best_info.end()); + search_results[std::make_pair(start_level, end_level)] = best_info; + } + + void solve() { + if (verbose) { + std::cout << "start to solve" << std::endl; + std::cout << "verbose: " << verbose << std::endl; + std::cout << "mode: " << mode << std::endl; + std::cout << "mem_bound: " << mem_bound << std::endl; + std::cout << "mem_div: " << mem_div << std::endl; + std::cout << "topk: " << topk << std::endl; + } + init_dp_info(); + // to reduce time, we first group the queries by start node (level) + std::unordered_map> intervals; + for (const auto &query : queries) { + auto iter = intervals.find(query.first); + if (iter == intervals.end()) { + intervals[query.first] = std::vector(1, query.second); + } else { + iter->second.push_back(query.second); + } + } + + auto start = std::chrono::system_clock::now(); + for (auto &item : intervals) { + // for each start node, we do dp until the last end node + int start_level = item.first; + std::vector &end_levels = item.second; + std::sort(end_levels.begin(), end_levels.end()); + do_dp(start_level, *end_levels.rbegin()); + for (int end_level : end_levels) { + post_process(start_level, end_level, topk); + } + long long state_cnt = 0; + for (auto iter = id2node.begin(); iter != id2node.end(); ++iter) { + int cur_id = iter->first; + Node *cur_node = iter->second; + for (DPNode *dp_node : cur_node->dp_nodes) { + state_cnt += dp_node->state.size(); + } + } + if (verbose) { + std::cout << "state num: " << state_cnt << std::endl; + } + } + auto end = std::chrono::system_clock::now(); + + std::chrono::duration elapsed_seconds = end - start; + + std::cout << "elapsed time: " << elapsed_seconds.count() << " s" + << std::endl; + } + + std::vector get_results(int start_level, int end_level) { + return search_results[std::make_pair(start_level, end_level)]; + } + + bool verbose; + // mode = 0: training, use the sum of the two largest buffer sizes + // mode = 1: inference, use the largest buffer size + int mode; + // mem_bound: the maximum memory usage, in bytes + int mem_bound; + // mem_div: the memory divisor, to avoid overflow in int32 + int mem_div; + int topk; + + std::unordered_map id2node; + std::vector> queries; + std::map, std::vector> search_results; +}; + +PYBIND11_MODULE(dp_solver, m) { + py::class_(m, "SearchPlan") + .def_readonly("all_time", &SearchPlan::all_time) + .def_readonly("inner_time", &SearchPlan::inner_time) + .def_readonly("memory", &SearchPlan::memory) + .def_readonly("path", &SearchPlan::path); + + py::class_(m, "DPSolver") + .def(py::init()) + .def("add_interval", &DPSolver::add_interval) + .def("add_node", &DPSolver::add_node) + .def("add_partition", &DPSolver::add_partition) + .def("solve", &DPSolver::solve) + .def("get_results", &DPSolver::get_results); +} +/* +<% +setup_pybind11(cfg) +cfg['extra_compile_args'] = ['-std=c++11'] +cfg['extra_compile_args'] = ['-O3'] +cfg['extra_compile_args'] = ['-pthread'] +%> +*/ diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index b478e598..eb89db9d 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -1,12 +1,6 @@ from .model_graph import ModelGraph from .cube_operator import CubeOperator from .descs import * -from .util import ( - int2byte, - int4byte, - double2byte, - double4byte, -) from .cost_database import CostDatabase from .autodist_config import AutoDistConfig from .op_partition import OpPartition, generate_partitions @@ -997,113 +991,33 @@ def do_ilp(self, intervals: List[Tuple[int, int]], def do_dp(self, intervals: List[Tuple[int, int]], topk: int) -> List[SPMDSearchOutput]: + import cppimport.import_hook + import nnscaler.autodist.dp_solver as dp_solver - idx_map = {} - for idx, item in enumerate(intervals): - idx_map[item] = idx - - ret = [None] * len(intervals) - - in_fname = f'./cpp_in_{self.device_num}_{self.stage_num}.bin' - out_fname = f'./cpp_out_{self.device_num}_{self.stage_num}.bin' - """ - cpp_in format - is_train node_num mem_bound mem_div topk interval_num - interval_num tuples of (start_idx, end_idx) - node_num groups of - - id - - father_id - - cut_lens, [cut_ids] - - partition nums, [p_father, comp_time + weight_update_time, train_mem, buffer_mem, act_mem, opt_mem] - - producer_num, [producer_id, comm time mat] - """ - # TODO: hardcode value which should change according to device memory + mode = 0 if self.is_train else 1 mem_div = 64 - with open(in_fname, 'wb') as f: - # header - if self.is_train: - f.write(int2byte(0)) - else: - f.write(int2byte(1)) - f.write( - int2byte(self.graph.op_num) + - int2byte(int(self.mem_bound) // mem_div) + int2byte(mem_div)) - f.write(int2byte(topk) + int2byte(len(intervals))) - # intervals - for u, v in intervals: - f.write(int2byte(u) + int2byte(v)) - # partition info - for idx in range(self.graph.op_num): - f.write(int2byte(idx)) - f.write(int2byte(self.father_ids[idx])) - - f.write(int2byte(len(self.cut_ops[idx]))) - for cut_op in self.cut_ops[idx]: - f.write(int2byte(cut_op)) - - f.write(int2byte(self.get_op_partition_count(idx))) - for i, partition in enumerate(self._op_partitions[idx]): - p_cost_desc = self.partition_info[idx][i] - f.write( - double2byte(p_cost_desc.comp_time + - p_cost_desc.weight_update_time)) - - f.write( - int2byte(p_cost_desc.mem // mem_div) + - int2byte(p_cost_desc.transient_mem // mem_div) + - int2byte(p_cost_desc.activation_mem // mem_div) + - int2byte(p_cost_desc.opt_transient_mem // mem_div) + - int2byte(self.p_fathers[idx][i])) - - f.write(int2byte(len(self.producers[idx]))) - for p_i, producer in enumerate(self.producers[idx]): - f.write(int2byte(producer)) - for i, tgt_p in enumerate(self._op_partitions[idx]): - comm_time = self.partition_info[idx][i].comm_time - for j, src_p in enumerate( - self._op_partitions[producer]): - f.write(double2byte(comm_time[p_i][j])) - - os.system( - f'{str(Path.home())}/.autodist/solver {in_fname} {out_fname} {int(self.autodist_config.verbose)}' - ) - """ - cpp_out file format - interval_num parts, each part's format is - start_idx end_idx path_num path_len - each path is (opt_time, inner_time, opt_mem, a sequence of path_len (op_idx, partition_idx)) - """ - with open(out_fname, 'rb') as f: - data = f.read() - offset = 0 - for _ in range(len(intervals)): - descs = [] - start_level = int4byte(data[offset:offset + 4]) - end_level = int4byte(data[offset + 4:offset + 8]) - num = int4byte(data[offset + 8:offset + 12]) - path_len = int4byte(data[offset + 12:offset + 16]) - offset += 16 - for i in range(num): - opt_time = double4byte(data[offset:offset + 8]) - offset += 8 - inner_time = double4byte(data[offset:offset + 8]) - offset += 8 - opt_mem = int4byte(data[offset:offset + 4]) - offset += 4 - plans = [] - for j in range(path_len): - cur_op_idx = int4byte(data[offset:offset + 4]) - offset += 4 - cur_p_idx = int4byte(data[offset:offset + 4]) - offset += 4 - plans.append((cur_op_idx, cur_p_idx)) - desc = self.partition_path2desc(plans) - descs.append( - SPMDSearchOutput(desc, - opt_mem * mem_div / 1024 / 1024 / 1024, - opt_time, inner_time)) - ret[idx_map[(start_level, end_level)]] = descs - os.system(f'rm {in_fname} {out_fname}') + mem_bound = int(self.mem_bound) // mem_div + solver = dp_solver.DPSolver(self.autodist_config.verbose, mode, mem_bound, mem_div, topk) + for start, end in intervals: + solver.add_interval(start, end) + for idx in range(self.graph.op_num): + solver.add_node(idx, self.father_ids[idx], self.cut_ops[idx], + self.producers[idx], self.get_op_partition_count(idx)) + for i, partition in enumerate(self._op_partitions[idx]): + p_cost_desc = self.partition_info[idx][i] + solver.add_partition(idx, i, p_cost_desc.comp_time + p_cost_desc.weight_update_time, + p_cost_desc.mem // mem_div, p_cost_desc.transient_mem // mem_div, + p_cost_desc.activation_mem // mem_div, p_cost_desc.opt_transient_mem // mem_div, + self.p_fathers[idx][i], p_cost_desc.comm_time) + solver.solve() + ret = [] + for start, end in intervals: + cpp_results = solver.get_results(start, end) + descs = [] + for result in cpp_results: + desc = self.partition_path2desc(result.path) + descs.append(SPMDSearchOutput(desc, result.memory * mem_div / 1024 / 1024 / 1024, result.all_time, result.inner_time)) + ret.append(descs) return ret def solve(self, intervals: List[Tuple[int, int]], diff --git a/requirements.txt b/requirements.txt index 1912dc45..55e0755d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ +cppimport dill matplotlib more-itertools numpy>=1.23.0 pulp +pybind11 pyyaml torch>=2.0 diff --git a/tests/autodist/test_dp_solver.py b/tests/autodist/test_dp_solver.py new file mode 100644 index 00000000..b44abc79 --- /dev/null +++ b/tests/autodist/test_dp_solver.py @@ -0,0 +1,38 @@ +import cppimport.import_hook +import nnscaler.autodist.dp_solver as dp_solver + +# use a naive ffn to test the dynamic programming solver +# the ffn has 3 layers +# - linear layer +# - relu layer +# - linear layer +# each operator has 2 partition options + +def test_dp_solver(): + solver = dp_solver.DPSolver(True, 0, 80 * 1024, 1, 1) + solver.add_interval(0, 2) + + solver.add_node(0, 0, [0], [], 2) + solver.add_partition(0, 0, 1, 1, 1, 1, 1, 0, [[]]) + solver.add_partition(0, 1, 2, 2, 2, 2, 2, 1, [[]]) + + solver.add_node(1, 1, [1], [0], 2) + solver.add_partition(1, 0, 0.5, 1, 1, 1, 1, 0, [[0.1, 1]]) + solver.add_partition(1, 1, 1, 2, 2, 2, 2, 1, [[1, 0]]) + + solver.add_node(2, 2, [2], [1], 2) + solver.add_partition(2, 0, 1, 1, 1, 1, 1, 0, [[0.2, 1]]) + solver.add_partition(2, 1, 2, 2, 2, 2, 2, 1, [[1, 0]]) + + solver.solve() + + ans = solver.get_results(0, 2) + + best = ans[0] + + # optimal all time 1 + 0.5 + 0.1 + 1 + 0.2 = 2.8 + assert best.all_time == 2.8 + # optimal inner time 1 + 0.5 + 1 = 2.5 + assert best.inner_time == 2.5 + # the optimal plan is each operator's first partition + assert best.path == [(0, 0), (1, 0), (2, 0)] From c38b49e327ebba8c1dca4c0b48e7f109ddf9a46c Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 24 Apr 2024 17:46:47 +0000 Subject: [PATCH 1626/1892] Merged PR 2127: Refine non persistent at model loading according to [discussion](https://dev.azure.com/msrasrg/SuperScaler/_workitems/edit/1894), current implementation in nnscaler and fairseq cannot handle non persistent buffer correctly. To fix this problem, we will set load_content = True when detecting non persistent buffers. --- nnscaler/runtime/module.py | 10 ++++++++++ nnscaler/utils.py | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 8375c97b..5b385b6c 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -55,6 +55,16 @@ def __init__(self): # please note there can be multiple entries with same tid self._fullmap : Dict[str, AttrMeta] = dict() + def get_non_persistent_buffers(self) -> Dict[str, torch.Tensor]: + """ + Get non-persistent buffers in the module + """ + non_persistent_buffers = {} + for name, buffer in self.named_buffers(recurse=False): + if name in self._non_persistent_buffers_set: + non_persistent_buffers[name] = buffer + return non_persistent_buffers + @property def reducers(self): return self._reducers diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 5de58c97..fb1f47cf 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -58,6 +58,12 @@ def load_model(filename: Optional[str] = None, load_content: bool = True, fullmo filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename module = _load_module_attr(filename, Path(filename).stem) loaded_module: nnscaler.runtime.module.CubeModule = module.GenModel().cuda() + non_persistent_buffers = loaded_module.get_non_persistent_buffers() + if non_persistent_buffers: + names = [name for name, _ in non_persistent_buffers.items()] + _logger.warning(f'Detected non-persistent buffers: {names}, will load content, make sure fullmodel.pt.* are available and consistent.') + if not load_content: + load_content = True # load parameter content if load_content: _logger.info("loading parameter content...") From d98cbe833b56ff8fc2adf2e1fbcbf30ea6a44194 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sat, 27 Apr 2024 15:02:20 +0000 Subject: [PATCH 1627/1892] Merged PR 2129: Refine autodist's comm path to enable tests - if `comm` is not found, use the default comm data for mi200 - remove useless files --- autodist/README.md | 17 -- autodist/benchmark/alpa/Alpa_solver.md | 43 ---- autodist/benchmark/alpa/README.md | 203 --------------- .../benchmark/alpa/analyse_strategy/README.md | 15 -- .../alpa/analyse_strategy/gen_str.py | 137 ---------- .../alpa/analyse_strategy/strategy.zip | Bin 886 -> 0 bytes autodist/benchmark/alpa/benchmark.py | 204 --------------- autodist/benchmark/alpa/gpt_alpa_2d_table1.sh | 16 -- autodist/benchmark/alpa/gpt_alpa_2d_table2.sh | 27 -- autodist/benchmark/alpa/gpt_alpa_3d.sh | 16 -- autodist/benchmark/alphafold2.md | 124 ---------- autodist/benchmark/gpt.md | 234 ------------------ autodist/benchmark/recompute.md | 28 --- autodist/build_env.py | 51 ---- autodist/script/alphafold/foldtp.sh | 64 ----- autodist/script/gpt/adapt_recom_tp.sh | 59 ----- autodist/script/gpt/analyze.py | 77 ------ autodist/script/gpt/analyze_adapt_recom.py | 77 ------ autodist/script/gpt/checker.sh | 42 ---- autodist/script/gpt/pp_all_run.sh | 45 ---- autodist/script/gpt/profile.sh | 66 ----- autodist/script/gpt/tp_all_run.sh | 57 ----- autodist/script/pre_install.sh | 7 - autodist/script/swin/analysis.py | 72 ------ autodist/script/swin/profile_swin.sh | 60 ----- autodist/script/swin/swintp.sh | 66 ----- nnscaler/autodist/apis.py | 10 +- nnscaler/autodist/autodist_config.py | 9 +- nnscaler/autodist/cost_database.py | 7 +- nnscaler/autodist/util.py | 3 + nnscaler/profiler/database.py | 2 +- .../mi200}/comm/intra_16.json | 0 .../mi200}/comm/intra_2.json | 0 .../mi200}/comm/intra_4.json | 0 .../mi200}/comm/intra_8.json | 0 {autodist => utility}/comm_profile.py | 41 +-- utility/prim_profiler.py | 53 ++++ 37 files changed, 92 insertions(+), 1840 deletions(-) delete mode 100644 autodist/README.md delete mode 100644 autodist/benchmark/alpa/Alpa_solver.md delete mode 100644 autodist/benchmark/alpa/README.md delete mode 100644 autodist/benchmark/alpa/analyse_strategy/README.md delete mode 100644 autodist/benchmark/alpa/analyse_strategy/gen_str.py delete mode 100644 autodist/benchmark/alpa/analyse_strategy/strategy.zip delete mode 100644 autodist/benchmark/alpa/benchmark.py delete mode 100644 autodist/benchmark/alpa/gpt_alpa_2d_table1.sh delete mode 100644 autodist/benchmark/alpa/gpt_alpa_2d_table2.sh delete mode 100644 autodist/benchmark/alpa/gpt_alpa_3d.sh delete mode 100644 autodist/benchmark/alphafold2.md delete mode 100644 autodist/benchmark/gpt.md delete mode 100644 autodist/benchmark/recompute.md delete mode 100644 autodist/build_env.py delete mode 100755 autodist/script/alphafold/foldtp.sh delete mode 100644 autodist/script/gpt/adapt_recom_tp.sh delete mode 100644 autodist/script/gpt/analyze.py delete mode 100644 autodist/script/gpt/analyze_adapt_recom.py delete mode 100755 autodist/script/gpt/checker.sh delete mode 100755 autodist/script/gpt/pp_all_run.sh delete mode 100644 autodist/script/gpt/profile.sh delete mode 100755 autodist/script/gpt/tp_all_run.sh delete mode 100644 autodist/script/pre_install.sh delete mode 100644 autodist/script/swin/analysis.py delete mode 100644 autodist/script/swin/profile_swin.sh delete mode 100755 autodist/script/swin/swintp.sh rename {autodist/profile_data/16xmi200 => profile_data/mi200}/comm/intra_16.json (100%) rename {autodist/profile_data/16xmi200 => profile_data/mi200}/comm/intra_2.json (100%) rename {autodist/profile_data/16xmi200 => profile_data/mi200}/comm/intra_4.json (100%) rename {autodist/profile_data/16xmi200 => profile_data/mi200}/comm/intra_8.json (100%) rename {autodist => utility}/comm_profile.py (75%) create mode 100644 utility/prim_profiler.py diff --git a/autodist/README.md b/autodist/README.md deleted file mode 100644 index b252f7ec..00000000 --- a/autodist/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# AutoDist - -AutoDist is a package that optimizes for efficient distributed execution plans, given a DL data flow graph and cluster specifications. Compared to [Alpa](https://github.com/alpa-projects/alpa), AutoDist has two main advantages: -- a topology aware dynamic programming solver, which is faster than integer linear programing formulation in most cases -- achieve a balance between memory and time automatically, instead of using a global option - -## Prerequisite - -```bash -bash ./script/pre_install.sh -``` - -## Pipeline - -

- -

diff --git a/autodist/benchmark/alpa/Alpa_solver.md b/autodist/benchmark/alpa/Alpa_solver.md deleted file mode 100644 index ef3e97c5..00000000 --- a/autodist/benchmark/alpa/Alpa_solver.md +++ /dev/null @@ -1,43 +0,0 @@ -# Alpa Solver Details - -We have conducted a detailed test on the  solver in Alpa and have two conclusions. -1) input and constraints limit the efficiency of the solver.  -2) Alpa is unable to correctly solve problems with memory constraints.  - -The table below shows the numbers of the Alpa solver under different micro batch sizes and GPT models in spmd. This table includes there parts, the number of free variables (the first two columns) , alpa solving times (the mid three columns) and autodist compile times (the last column). **Baseline time** represents the original solver time, **random** represents filling the array with random numbers, and **mem** represents adding memory constraint conditions.  - -**1.3B** -| | num_nodes | num_edges | baseline time/s | random time/s | mem time/s | autodist time/s | -|---:|------------:|------------:|------------------:|----------------:|:-------------|------------------:| -| 1 | 2637 | 4586 | 19.3861 | 19.832 | > 600 | 2.65 | -| 2 | 2472 | 4468 | 18.4109 | 34.7781 | > 600 | 3.96 | -| 4 | 2473 | 4470 | 21.6232 | 42.6273 | > 600 | 6 | -| 8 | 2473 | 4470 | 19.3299 | 25.2286 | > 600 | 9.78 | -| 16 | 2473 | 4470 | 19.34 | 43.1969 | > 600 | 14.91 | -| 32 | 2473 | 4470 | 20.1404 | 38.766 | > 600 | 9.29 | - -**2.6B** -| | num_nodes | num_edges | baseline time/s | random time/s | mem time/s | autodist time/s | -|---:|------------:|------------:|------------------:|----------------:|:-------------|------------------:| -| 1 | 3493 | 6090 | 27.2841 | 48.645 | > 600 | 17.57 | -| 2 | 3272 | 5932 | 27.0608 | 41.8054 | > 600 | 24.84 | -| 4 | 3272 | 5932 | 25.7738 | 67.6498 | > 600 | 36.19 | -| 8 | 3273 | 5934 | 29.1824 | 67.1933 | > 600 | 76.63 | -| 16 | 3273 | 5934 | 27.4334 | 33.0636 | > 600 | 115.04 | -| 32 | 3273 | 5934 | 30.3701 | 63.2393 | > 600 | 69.36 | - -We can see from each row of the table that randomizing the input (**1.5~3x**) and adding memory constraint conditions (**>20x**) will increase the solving time of the solver, thus leading to the first conclusion stated above. Meanwhile, Autodist can quickly solve problems with memory constraints, achieving a maximum of **226x** faster solving efficiency (the first row at table 1.3B).   - -For the second conclusion, we reduced the layers of GPT-3 1.3B from 24 to 12 under a 30GB memory constraint. Alpa was unable to find a solution (but Autodist is able to find one). We make experimental examples (shown in **Table\***) to state that Alpa solver with memory constraint is unreasonable. - -**Table\*** - -The GPT model is 1.3B, we decrease the layer from 24 into 1, 5 and 12 respectively. Time ratio = mem time / baseline time. - -| | baseline time/s | mem time/s | time ratio | -|---:|-----------------:|---------------:|-------------:| -| 1 | 1.00 | 2.55 | 2.55 | -| 5 | 4.11 | 72.60 | 17.66 | -| 12 | 9.93 | None solution | -- | - -There are two unreasonable aspects, the first aspect is that the time ratio increases exponentially as the gpt model increases. Second, Alpa solver will not search for a solution when the model becomes larger (although this solution certainly exists from the above statement). diff --git a/autodist/benchmark/alpa/README.md b/autodist/benchmark/alpa/README.md deleted file mode 100644 index fd99a332..00000000 --- a/autodist/benchmark/alpa/README.md +++ /dev/null @@ -1,203 +0,0 @@ -# Benchmark Alpa - -## GPT-3 - -### Usage - -For the 3d setting, the config is the same with Table 4 in [1]. For the 2d setting, we test the GPT-3 6.7B with only 4 layers. Details of the model config can be found in the `benchmark.py`, `gpt_alpa_3d.sh`, `gpt_alpa_2d_table1.sh` and `gpt_alpa_2d_table2.sh`. - -You can cd the analyse_strategy folder for more specific analysis. - -### Experimental Config - -The benchmarks are implemented on a server runing on Ubuntu 20.04 system, which is equipped with an Intel(R) Xeon(R) Platinum 8160 CPU @ 2.10GHz and 16 NVIDIA V100-SXM2 32GB GPUs, each having a theoretical TFLOPS of 120 for FP16. The 16 GPUs are connected via NVLink and the interconnect bandwidth is 300GB/s (details seeing [NVIDIA TESLA V100 GPU ACCELERATOR](https://images.nvidia.com/content/technologies/volta/pdf/437317-Volta-V100-DS-NV-US-WEB.pdf). The version of CUDA is 11.3. - -**w/ pipeline parallelism (i.e. 3d)** - -We follow alpa's GPT-3 benchmark code (seeing Fig. 7a in [1]) on our testbed and results are in table 1. -In this case you can choose to overwrite the `benchmark.py` or not and run: - -```bash -bash gpt_alpa_3d.sh -``` - -**w/o pipeline parallelism (i.e. 2d)** - -We follow alpa's GPT-3 benchmark code under shard parallel (i.e. only intra-opeartor parallelism, no pipeline parallelism). -The results with 8 V100s are in table2.1 and those with 4 V100s in table2.2. -In this case you need to overwrite the `benchmark.py` and run: - -```bash -bash gpt_alpa_2d_table1.sh - -bash gpt_alpa_2d_table2.sh -``` - -**Description of parameters in alpa** - -- `shard-only` : Only profile the 2D case. No pipeline parallelism, default=`False` -- `num_micro_batches`: The number of micro batches, equal to batch size/micro batches. When `num_micro_batches>1`, the grad function will apply `alpa.grad`, which adds the gradient accumulation mechanism to `jax.grad`. The default is `1` -- `num_gpt_layer` : The number of the gpt layer, other config parameters can be seen in `benchmark.py`. -- `dp`: The number of channel for data parallelism, an `int` from [1,2,4,……,gpus]. -- `op`: The number of channel for operator parallelism, an `int` from [1,2,4,……,gpus]. -- `reduce-scatter`: If this is True, alpa will use **reduce-scatter** and **all-gather** to replace **all-reduce**. It will achieve a sppedup in execute for reduce-scatter-friendly system, but burden the optimization time. -- `parallel mode`. It can be selected from `uniform` and `zero-3`. -- `shard`. Not using the ray cluster, default=`False`. -- `profile driven time`. Profile the execution time on the driver instead of the workers, default=`False`. -- `recomputation`. This switch determines whether recomputation is turned on, default=`False`. If recomputation is open, the memory cost will increase during the backward and the communication overhead (**all-reduce** during the recomputation) saves. - -### Results - -**Table 1** - -| gpus | TFLOPs | Peak Mem/GB | Execute time/s (Mean, Std) | Complie + optimize time/s | -|:----:|:------: |:-----------:|:--------------------------:|:-------------------------:| -| 1 | 40.67 | 7.053 | (80.787, 0.004) | 50.19 | -| 2 | 119.04 | 10.376 | (57.36, 0.080) | 49.94 | -| 4 | 240.76 | 8.575 | (48.337, 0.023) | 57.52 | -| 8 | 511.36 | 11.66 | (45.646, 0.010) | 110.69 | -| 16 | 1110.72 | 11.346 | (51.868, 0.019) | 117.46 | - -**Table 2.1** - -Details for Table 2.1: - -```bash ---num-devices-per-host 8 --num_gpt_layer 4 --num_batch_size 32 --num_micro_batches 1 --reduce_scatter -``` - -| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| (1,8) | 462.56 | 6.478 | (0.565, 0.000) | 28 | 28 | 0 | 0 | 0 | 5.69 | -| (2,4) | 538.88 | 5.098 | (0.485, 0.001) | 33 | 29 | 1 | 3 | 0 | 7.98 | -| (4,2) | 571.20 | 5.449 | (0.457, 0.000) | 33 | 29 | 1 | 3 | 0 | 7.96 | -| (8,1) | 587.44 | 6.924 | (0.445, 0.003) | 4 | 1 | 1 | 2 | 0 | 4.00 | - -**Table 2.2** w/ recompute float16 - -Details for table 2.2 w/ recompute float16: - -```bash ---num-devices-per-host 4 --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 --recomputation -``` - -| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| (1,4) | 229.80 | 5.076 | (0.142, 0.000) | 28 | 28 | 0 | 0 | 0 | 5.98 | -| (2,2) | 179.44 | 10.287 | (0.182, 0.000) | 31 | 31 | 0 | 0 | 0 | 8.33 | -| (4,1) | 161.92 | 20.571 | (0.202, 0.001) | 3 | 3 | 0 | 0 | 0 | 4.08 | - -**Table 2.2** w/o recompute float16 - -Details for table 2.2 w/o recompute float16 - -```bash ---num-devices-per-host 4 --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 -``` - -| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| (1,4) | 220.48 | 6.288 | (0.117, 0.000) | 20 | 20 | 0 | 0 | 0 | 4.47 | -| (2,2) | 164.64 | 10.287 | (0.157, 0.000) | 23 | 23 | 0 | 0 | 0 | 6.45 | -| (4,1) | 143.00 | 20.571 | (0.180, 0.001) | 3 | 3 | 0 | 0 | 0 | 2.80 | - -**Table 2.2** w/ recompute float32 - -Details for table 2.2 w/ recompute float32: - -```bash ---num-devices-per-host 4 --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 --recomputation -``` - -| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| (1,4) | 48.12 | 5.485 | (0.679, 0.001) | 29 | 27 | 2 | 0 | 0 | 5.62 | -| (2,2) | 43.96 | 11.429 | (0.743, 0.000) | 30 | 30 | 0 | 0 | 0 | 8.59 | -| (4,1) | 43.20 | 22.857 | (0.756, 0.001) | 2 | 2 | 0 | 0 | 0 | 3.59 | - -**Table 2.2** w/o recompute float32 - -Details for table 2.2 w/o recompute float32: - -```bash ---num-devices-per-host 4 --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 -``` - -| (dp,op) | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:-------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| (1,4) | 47.28 | 7.704 | (0.545, 0.000) | 21 | 19 | 2 | 0 | 0 | 4.44 | -| (2,2) | 42.08 | 11.429 | (0.613, 0.000) | 22 | 22 | 0 | 0 | 0 | 5.89 | -| (4,1) | 40.64 | 22.857 | (0.634, 0.001) | 2 | 2 | 0 | 0 | 0 | 2.81 | - -**Table 2.3** w/o recompute float16 - -Details for table 2.3 w/o recompute float16: - -```bash ---num-devices-per-host 4 --num_batch_size 4 --num_micro_batches 1 --dp 1 --op 4 -``` - -| layers | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| 8 | 228.12 | 10.791 | (0.203, 0.000) | 36 | 36 | 0 | 0 | 0 | 8.72 | -| 12 | 228.60 | 15.340 | (0.293, 0.000) | 52 | 52 | 0 | 0 | 0 | 13.14 | -| 16 | 231.32 | 19.843 | (0.379, 0.000) | 68 | 68 | 0 | 0 | 0 | 18.05 | - -**Table 2.3** w/ recompute float16 - -Details for table 2.3 w/o recompute float16: - -```bash ---num-devices-per-host 4 --num_batch_size 4 --num_micro_batches 1 --dp 1 --op 4 -``` - -| layers | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| 8 | 236.68 | 8.163 | (0.254, 0.000) | 52 | 52 | 0 | 0 | 0 | 11.50 | -| 12 | 237.52 | 11.290 | (0.369, 0.000) | 76 | 76 | 0 | 0 | 0 | 18.95 | -| 16 | 237.76 | 14.448 | (0.484, 0.001) | 100 | 100 | 0 | 0 | 0 | 23.82 | - -**Table 2.3** w/o recompute float32 - -Details for table 2.3 w/o recompute float32: - -```bash ---num-devices-per-host 4 --num_batch_size 4 --num_micro_batches 1 --dp 1 --op 4 -``` - -| layers | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| 8 | 49.32 | 12.707 | (0.940, 0.000) | 37 | 35 | 2 | 0 | 0 | 7.70 | -| 12 | 50.40 | 17.710 | (1.330, 0.001) | 53 | 51 | 2 | 0 | 0 | 13.46 | -| 16 | 50.68 | 22.744 | (1.729, 0.001) | 69 | 67 | 2 | 0 | 0 | 16.34 | - -**Table 2.3** w/ recompute float32 - -Details for table 2.3 w/ recompute float32: - -```bash ---num-devices-per-host 4 --num_batch_size 4 --num_micro_batches 1 --dp 1 --op 4 -``` - -| layers | TFLOPS | Peak Mem/GB | Execute time/s (Mean, Std) | #comm | allreduce | allgather | reducescatter | all2all | Complie + optimize time/s | -|:------:|:-------:|:-----------:|:--------------------------:|:-----:|:---------:|:---------:|:-------------:|:-------:|:-------------------------:| -| 8 | 50.12 | 8.564 | (1.200, 0.001) | 53 | 51 | 2 | 0 | 0 | 12.01 | -| 12 | 50.72 | 11.644 | (1.727, 0.000) | 77 | 75 | 2 | 0 | 0 | 16.54 | -| 16 | 51.16 | 14.801 | (2.251, 0.002) | 101 | 99 | 2 | 0 | 0 | 22.59 | - -Remark 1: When `Prefer_reduce_scatter=False` and `recomputation=False`, the tensor parallelism strategy generated by *alpa* is consistent with that of *megatron-lm*. - -## Q&A - - Q1: Why the mean time for data parallelism (dp=8,op=1 in Table 2.1) is faster than tensor parallelism (dp=1,op=8 in Table 2.1)? - - A1: Because the communication volume of the former is 12 *hidden_size*hidden_size and that of the later is 4*batch size*hidden_size (2 all-reduce in the feedfoward and 2 in the backward). - And the mean time in Table 2.2 (both w/ and w/o recomputation) supports this view. When we reduce batch size from 32 to 4, then the data parallelism (dp=4,op=1 in Table 2.2) is slower than tensor parallelism (dp=1,op=4 in Table 2.2). - - Q2: Why the TFLOPs are reduced to 1/4 of the precision of 16 bits when the precision is 32 bits? - - A2: Because it uses the tensor core technique, which boosts the TFLOPS. - -## Reference - -\[1\] Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning diff --git a/autodist/benchmark/alpa/analyse_strategy/README.md b/autodist/benchmark/alpa/analyse_strategy/README.md deleted file mode 100644 index 290dc08c..00000000 --- a/autodist/benchmark/alpa/analyse_strategy/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# Analyse alpa strategy - -In this part, we write the gen_str.py to generate the partition strategy from the log in Alpa. The best spmd results with 760M and 1.3B are in **strategy.zip**. - -## Usage - -```bash -python gen_str.sh -``` - -The default load_file is log.txt and the save_file is test.txt, you can specific the load_file and save_file by adding **--load_file** and **--save_file**. If you want to see more information about the Alpa partition, you can add **--whole_strategy --detailed_partition_strs** - -## Comparsion with Autodist - - diff --git a/autodist/benchmark/alpa/analyse_strategy/gen_str.py b/autodist/benchmark/alpa/analyse_strategy/gen_str.py deleted file mode 100644 index c7110175..00000000 --- a/autodist/benchmark/alpa/analyse_strategy/gen_str.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import time -import json -import argparse - -LAYER_DOT_NUM = 6 -MAX_LAYER = 64 - - -class strategy: - - def __init__(self, Instruction: str): - self.elements = Instruction.split(' ') - self.id = self.elements[1] - at_index = self.elements.index('@') - self.selected_str = self.elements[at_index - 5:at_index] - self.selected_str = ' '.join(self.selected_str) - - -def write_str(save_file, strs): - lines = [] - for i in strs: - lines += i - with open(save_file, 'w') as f: - f.writelines(lines) - return - - -def get_str(args, lines, indexs, selected_strs): - strs = [] - dot_name = { - 0: 'qvk_combined', - 1: '...qhd,...khd->...hqk', - 2: '...hqk,...khd->...qhd', - 3: 'attention/output', - 4: 'intermediate/dense', - 5: 'output/dense' - } - assert len(indexs) % LAYER_DOT_NUM == 0 - str_count = 0 - for dot_count, index in enumerate(indexs): - - this_strs = [] - assert 'Instruction' in lines[index] - this_id = lines[index].split(' ')[2].split('%')[-1] - for selected_str in selected_strs: - if this_id == strategy(selected_str).id: - this_s = selected_str - break - - if dot_count % LAYER_DOT_NUM == 0: - this_strs.append('transformer_layer:' + - str(dot_count // LAYER_DOT_NUM) + '\n') - this_strs.append(' ' + dot_name[dot_count % LAYER_DOT_NUM] + ':' + - '\n') - if args.whole_strategy: - this_strs.append(' ' + 'instruction: ' + this_s) - this_strs.append(' ' + 'partition: ' + - strategy(this_s).selected_str + '\n') - - i = index - while True: - i += 1 - if 'Instruction' in lines[i]: - break - if args.detailed_partition_strs: - this_strs.append(' ' + lines[i]) - str_count = i - index - 1 - this_strs.append(' ' + 'total strategy numbers: ' + str(str_count) + - '\n') - strs.append(this_strs) - - return strs - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--start_layer', - type=int, - default=0, - help='set the start layer') - parser.add_argument( - '--end_layer', - type=int, - default=-1, - help='set the end layer, and generate [start_layer,……,end_layer-1]') - parser.add_argument('--load_file', - type=str, - default='log.txt', - help='set the loader folder for experiment data') - parser.add_argument('--save_file', - type=str, - default='test.txt', - help='set the save folder for experiment data') - parser.add_argument('--whole_strategy', - action='store_true', - help='show the whole strategy instruction') - parser.add_argument('--detailed_partition_strs', - action='store_true', - help='show the partition strategy that can be chosen') - args = parser.parse_args() - - total_layers = list(range(0, MAX_LAYER + 1)) - layers = total_layers[args.start_layer:args.end_layer] - f = open(args.load_file, 'r') - lines = f.readlines() - indexs = [] - for i in range(len(lines)): - if 'Startegy Map' in lines[i]: - start_i = i - if 'Auto sharding strategy' in lines[i]: - end_i = i - break - end_i = len(lines) - assert end_i != len(lines) - - for i in range(start_i, end_i): - if 'dot(' in lines[i]: - for layer in layers: - if 'layer/' + str(layer) + '/' in lines[i]: - indexs.append(i) - break - selected_strs = [] - start_i = end_i - for i in range(len(lines)): - if 'Exit AutoSharding' in lines[i]: - end_i = i - break - end_i = len(lines) - for i in range(start_i, end_i): - if 'dot(' in lines[i]: - selected_strs.append(lines[i]) - selected_strs = selected_strs[1:] - f.close() - strs = get_str(args, lines, indexs, selected_strs) - write_str(args.save_file, strs) diff --git a/autodist/benchmark/alpa/analyse_strategy/strategy.zip b/autodist/benchmark/alpa/analyse_strategy/strategy.zip deleted file mode 100644 index e11c2630e9316d79a72e4e00f301da6740b55939..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 886 zcmWIWW@Zs#U|`^2h;+0K`#RxNodP2RL!1UkM25jAy`WUDq@pA=gp+}pQ|nstY9KDH z;AUWC+2qT{!@_vvHTT;a(S9uk608r(!$0afNFMy?wy1jnw}Qgl3kq)o>U!SYd$ec9 zgwB&Tm6JX@9bUS8%l+FIZ~XWv@F*ks@4cFv5%=OcedL}0iXFRFd3);1xF#GfmGe6?qvrEj9Rl?DE7HyJufF`F453+0559zFEPOGPisF z4i1|2%UY8)e#(^ztd@K4)EaEL|74`}LanZ&}leqyC=x zvj3l>d6ebWhAmuKjG1=>xjnM|brsfD139(<+#atdMma?JUtzjbwM=Y6*f}HSuI^km z5%JeN8hYO(17dVc7VNl@;ZV5oQp2O9<%~y<1_2eUWfKZN7zSpDh)45k=xvkwKmE0I z06c_Ku!OJ)atO-+LwD(D2$LSdj7)OOxKgnMFm*F9024LCl12~_CJ+w*uT)`u diff --git a/autodist/benchmark/alpa/benchmark.py b/autodist/benchmark/alpa/benchmark.py deleted file mode 100644 index 338a4fff..00000000 --- a/autodist/benchmark/alpa/benchmark.py +++ /dev/null @@ -1,204 +0,0 @@ -"""The entry point of intra-op + inter-op parallelism benchmark.""" -import os -import argparse -from datetime import datetime -import time - -import numpy as np - -from alpa.util import (write_tsv, get_num_hosts_and_num_devices, to_str_round, - GB) -from collections import namedtuple -from benchmark_one_case import benchmark_one_case -import suite_auto_gpt -import suite_auto_moe -import suite_manual_gpt -import suite_manual_moe -import suite_wresnet -import suite_inference_gpt -from benchmark_parallel_utils import (BenchmarkCase, ShardParallelArgs, - UniformParallelArgs) -#from suite_manual_gpt import GPTModelConfig - -#B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size -GPTModelConfig = namedtuple( - 'GPTModelConfig', - ['seq_len', 'hidden_size', 'num_layers', 'num_heads', 'vocab_size']) - -benchmark_suites = { - 'gpt.tmp': suite_manual_gpt.tmp_suite, - 'gpt.tmp_auto': suite_auto_gpt.tmp_suite, - 'gpt.perf_test_fast_2d': suite_manual_gpt.perf_test_fast_2d_suite, - 'gpt.perf_test_manual': suite_manual_gpt.perf_test_suite, - 'gpt.perf_test_auto': suite_auto_gpt.perf_test_suite, - 'gpt.grid_search_auto': suite_auto_gpt.grid_search_suite, - 'gpt.correctness_test_auto': suite_auto_gpt.correctness_test_suite, - 'gpt_inference.profile': suite_inference_gpt.profile_suite, - 'gpt_no_embedding_inference.profile': suite_inference_gpt.profile_suite, - 'moe.tmp': suite_manual_moe.tmp_suite, - 'moe.tmp_auto': suite_auto_moe.tmp_suite, - 'moe.perf_test_fast_2d': suite_manual_moe.perf_test_fast_2d_suite, - 'moe.perf_test_auto': suite_auto_moe.perf_test_suite, - 'moe.grid_search_auto': suite_auto_moe.grid_search_suite, - 'wresnet.perf_test_2d': suite_wresnet.perf_test_2d_suite, - 'wresnet.perf_test_auto': suite_wresnet.perf_test_auto_suite, - 'wresnet.grid_search_auto': suite_wresnet.grid_search_auto_suite, -} - - -def benchmark_suite(suite_name, - num_hosts, - num_devices_per_host, - input_gpt_layer, - input_batch_size, - input_micro_batches, - reduce_scatter, - dp, - op, - recomputation, - exp_name='default', - niter=3, - shard_only=False, - local=False, - profile_driver_time=False, - profile_stage_execution_time=False, - disable_tqdm=False, - use_separate_process=True): - num_gpus = num_hosts * num_devices_per_host - - if local: - assert shard_only, ('Only shard-only mode is supported for execution ' - 'on local GPUs.') - - if num_gpus not in benchmark_suites[suite_name]: - return - suite = benchmark_suites[suite_name][num_gpus] - #print("suit is {},suit[0]is {}".format(suite,benchmark_case)) - os.makedirs('tmp', exist_ok=True) - - model_type = suite_name.split('.')[0] - output_name = f'{exp_name}.tsv' - - # Run all cases - for benchmark_case in suite: - - if shard_only: - assert dp * op == num_gpus, ('dp*op != num_gpus.') - # B, model, NB, PM, (RS, Remat, 3D Config, FM) - benchmark_case_new = BenchmarkCase( - input_batch_size, - GPTModelConfig(1024, 4096, input_gpt_layer, 32, 51200), - input_micro_batches, 'uniform', - UniformParallelArgs(reduce_scatter, recomputation, dp, op, 1, - True)) - - else: - benchmark_case_new = benchmark_case - - model_config = benchmark_case_new.model_config - num_micro_batches = benchmark_case_new.num_micro_batches - parallel_args = benchmark_case_new.parallel_args - - # Run one case - print('Working on case: {}'.format(str(benchmark_case_new))) - - result = benchmark_one_case(model_type, - benchmark_case_new, - niter, - num_hosts, - num_devices_per_host, - shard_only=shard_only, - local=local, - profile_driver_time=profile_driver_time, - disable_tqdm=disable_tqdm, - use_separate_process=use_separate_process) - - (parameter_count, peak_mem, latencies, tflops, metadata) = result - - heads = [ - 'Type', 'Model Config', '#Microbatch', '#GPU', 'Parallel Config', - 'Mean Time (s)', 'Std Time (s)', '#Params (Billion)', 'TFLOPs', - 'Peak Mem (GB)', 'Metadata' - ] - values = [ - model_type, model_config, num_micro_batches, num_gpus, - parallel_args, f'{np.mean(latencies):.3f}', - f'{np.std(latencies):.3f}', f'{parameter_count/1e9:.3f}B', - f'{tflops:.2f}', f'{peak_mem/GB:.3f}', - to_str_round(metadata, 2) - ] - write_tsv(heads, values, output_name) - - time.sleep(0.1) # for ctrl+c to work - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--suite', - choices=list(benchmark_suites.keys()), - type=str, - required=True) - parser.add_argument('--niter', - type=int, - default=3, - help='The number of benchmark iterations') - parser.add_argument('--num-hosts', type=int, default=None) - parser.add_argument('--num-devices-per-host', type=int, default=None) - parser.add_argument('--shard-only', - action='store_true', - help='Only profile the 2D case. No pipeline ' - 'parallelism.') - parser.add_argument('--local', - action='store_true', - help='Run on local GPUs. Do not use ray actors.') - parser.add_argument('--profile-driver-time', - action='store_true', - help='Profile the execution time on the driver instead ' - 'of the workers.') - parser.add_argument( - '--profile-stage-execution-time', - action='store_true', - help='Profile the execution timestamps of each pipeline ' - 'stage') - parser.add_argument('--no-separate-process', - action='store_false', - help='Do not launch separate processes for benchmark. ' - 'Errors in a single case will terminate this ' - 'script.', - dest='use_separate_process') - parser.add_argument('--exp-name', type=str, default='default') - parser.add_argument('--disable-tqdm', action='store_true') - parser.add_argument('--num_gpt_layer', type=int, default=1) - parser.add_argument('--num_batch_size', type=int, default=4) - parser.add_argument('--num_micro_batches', type=int, default=1) - parser.add_argument('--reduce_scatter', - action='store_true', - help='Prefer_reduce_scatter = True.') - parser.add_argument('--dp', type=int, default=4) - parser.add_argument('--op', type=int, default=1) - parser.add_argument('--recomputation', - action='store_true', - help='remat = True.') - args = parser.parse_args() - - num_hosts, num_devices_per_host = get_num_hosts_and_num_devices(args) - - benchmark_suite( - args.suite, - num_hosts, - num_devices_per_host, - args.num_gpt_layer, - args.num_batch_size, - args.num_micro_batches, - args.reduce_scatter, - args.dp, - args.op, - args.recomputation, - args.exp_name, - args.niter, - args.shard_only, - args.local, - args.profile_driver_time, - args.disable_tqdm, - args.use_separate_process, - ) diff --git a/autodist/benchmark/alpa/gpt_alpa_2d_table1.sh b/autodist/benchmark/alpa/gpt_alpa_2d_table1.sh deleted file mode 100644 index 597bbc0f..00000000 --- a/autodist/benchmark/alpa/gpt_alpa_2d_table1.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) -dp=(1 2 4 8) -op=(8 4 2 1) - - -for ((k=0; k<${#dp[*]}; k=k+1)); do - python benchmark.py --suite gpt.perf_test_fast_2d \ - --shard-only --num-hosts 1 --num-devices-per-host 8 \ - --num_gpt_layer 4 --num_batch_size 32 --num_micro_batches 1 \ - --dp ${dp[k]} --op ${op[k]} --reduce_scatter -done - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "running spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/benchmark/alpa/gpt_alpa_2d_table2.sh b/autodist/benchmark/alpa/gpt_alpa_2d_table2.sh deleted file mode 100644 index 4b89a854..00000000 --- a/autodist/benchmark/alpa/gpt_alpa_2d_table2.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) -dp=(1 2 4) -op=(4 2 1) - - -for ((k=0; k<${#dp[*]}; k=k+1)); do - python benchmark.py --suite gpt.perf_test_fast_2d \ - --shard-only --num-hosts 1 --num-devices-per-host 4 \ - --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 \ - --dp ${dp[k]} --op ${op[k]} --recomputation - -done - -for ((k=0; k<${#dp[*]}; k=k+1)); do - python benchmark.py --suite gpt.perf_test_fast_2d \ - --shard-only --num-hosts 1 --num-devices-per-host 4 \ - --num_gpt_layer 4 --num_batch_size 4 --num_micro_batches 1 \ - --dp ${dp[k]} --op ${op[k]} - -done - - - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "running spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/benchmark/alpa/gpt_alpa_3d.sh b/autodist/benchmark/alpa/gpt_alpa_3d.sh deleted file mode 100644 index 1633ddeb..00000000 --- a/autodist/benchmark/alpa/gpt_alpa_3d.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) -gpus=(1 2 4 8 8 16) -device=(1 1 1 1 2 1) - - - -for ((k=0; k<${#gpus[*]}; k=k+1)); do - python benchmark.py --suite gpt.perf_test_auto \ - --num-hosts ${device[k]} --num-devices-per-host ${gpus[k]} - -done - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "running spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/benchmark/alphafold2.md b/autodist/benchmark/alphafold2.md deleted file mode 100644 index 9360c928..00000000 --- a/autodist/benchmark/alphafold2.md +++ /dev/null @@ -1,124 +0,0 @@ -# Alphafold2 - -## Model Config - -We focus on evoformer like structures during training currently. Data type is *float16*. - -**Evoformer Stack** - -| Case | s | r | cm | cz | -| ---------------- | ---- | ---- | ---- | ---- | -| initial training | 128 | 256 | 256 | 128 | -| 1st fine-tuning | 512 | 256 | 256 | 128 | -| 2nd fine-tuning | 512 | 384 | 256 | 128 | - -**Extra Msa Stack** - -| Case | s | r | cm | cz | -| ---------------- | ---- | ---- | ---- | ---- | -| initial training | 1024 | 256 | 64 | 128 | -| 2.1 fine-tuning | 1024 | 384 | 64 | 128 | -| 2.2 fine-tuning | 5120 | 384 | 64 | 128 | - -## Baselines - -**Deepmind's plan** - -data parallelism (each accelerator with exact 1 sample) + recompute. Since the parameter size is relatively small in Alphafold2, the latency can be approximately by single device execution. Hyperparameter setting is listed in the following table. - -| Case | evo_num | use_chunk | -| --------------- | ------- | --------- | -| Evoformer Stack | 48 | False | -| Extra Msa Stack | 4 | True | - -**Dynamic Axial Parallelism (DAP)** - -The end-to-end time is bounded by the computation. In other words, given input tensors with a fixed batch size, it is possible to reduce the time by introducing more devices (partition a operator into parallelizable sub-operators). Here are possible experiment dimensions. - -| batch size | #gpus | -| ---------- | ----- | -| 1 | 2 | -| 2 | 4 | -| 4 | 8 | -| 8 | 16 | - -**Table 1: Evoformer Stack & Training** - -| Case | batch size | #gpus | latency/ms | peak mem/MB | -| ---------------- | ---------- | ----- | ---------- | ----------- | -| initial training | 1 | 1 | 3521.98 | 4414 | -| initial training | 1 | 2 | 2430.38 | 2531 | -| initial training | 1 | 4 | 1497.77 | 1574 | -| initial training | 2 | 4 | 2485.53 | 2647 | -| 1st fine-tuning | 1 | 1 | 7696.62 | 10729 | -| 1st fine-tuning | 1 | 2 | 4663.32 | 5744 | -| 1st fine-tuning | 1 | 4 | 2620.09 | 3211 | -| 1st fine-tuning | 2 | 4 | 4717.36 | 5921 | -| 2nd fine-tuning | 1 | 1 | 16632.06 | 17810 | -| 2nd fine-tuning | 1 | 2 | 9377.98 | 9417 | -| 2nd fine-tuning | 1 | 4 | 5099.72 | 5157 | -| 2nd fine-tuning | 2 | 4 | 9422.99 | 9804 | - -**Table 2: Extra Msa Stack & Training** - -| Case | batch size | #gpus | latency/ms | peak mem/MB | -| ---------------- | ---------- | ----- | ---------- | ----------- | -| initial training | 1 | 1 | x | x | -| initial training | 1 | 2 | x | x | -| initial training | 1 | 4 | x | x | -| initial training | 2 | 4 | x | x | -| 2.1 fine-tuning | 1 | 1 | x | x | -| 2.1 fine-tuning | 1 | 2 | x | x | -| 2.1 fine-tuning | 1 | 4 | x | x | -| 2.1 fine-tuning | 2 | 4 | x | x | -| 2.2 fine-tuning | 1 | 1 | x | x | -| 2.2 fine-tuning | 1 | 2 | x | x | -| 2.2 fine-tuning | 1 | 4 | x | x | -| 2.2 fine-tuning | 2 | 4 | x | x | - -## End-to-end evaluation results (DAP vs Autodist) - -### Model Config - -Evoformer Stack - - shape config - - bs, s, r, cm, cz = 1, 128, 256, 256, 128 - - bs, s, r, cm, cz = 1, 512, 256, 256, 128 - - bs, s, r, cm, cz = 1, 512, 384, 256, 128 - - other config: dtype, use_chunk, is_train, is_extra = torch.float16, False, True, False - -*note*: results organized in (estimate time/ms, execution time/ms, device mem/GB) - -**Table 1: tensor parallelism(2gpu) w/o recompute** - -evo_num = 4 - -| s, r | DAP | Autodist | compile time/s | -| ------------- | --------------- | ------------------------ | ---------------| -| 128, 256 | (139.15, 4.58) | (127.13, 156.15, 5.35) | 0.77 | -| 512, 256 | (293.11, 11.02) | (286.04, 307.54, 12.86) | 0.77 | -| 512, 384 | (596.41, 20.91) | (568.72, 595.00, 24.44) | 0.77 | - -*note*: results organized in (estimate time/ms, execution time/ms, device mem/GB) - -**Table 2: tensor parallelism(2gpu) w/ adaptive recompute** - -evo_num = 48 -memory constraint = 40GB - -| s, r | DAP | Autodist | compile time/s | -| ------------- | --------------- | ------------------------- | ---------------| -| 128, 256 | (2250.27, 2.53) | (1690.71, 1915.13, 38.33) | 43.57 | -| 512, 256 | (4733.89, 5.74) | (4273.40, 4525.81, 39.06) | 45.39 | -| 512, 384 | (9673.10, 9.42) | (8911.85, 10042.22, 39.70)| 43.88 | - -**Table 3: tensor parallelism(4gpu) w/ adaptive recompute** - -evo_num = 48 -memory constraint = 40GB - -| s, r | DAP | Autodist | compile time/s | -| ------------- | --------------- | ------------------------- | ---------------| -| 128, 256 | (1874.73, 1.54) | (1083.93, 1400.29, 29.13) | 4650.48 | -| 512, 256 | (3350.06, 3.13) | (2388.69, 2965.40, 36.50) | 4483.49 | -| 512, 384 | (6724.48, 5.04) | (4932.62, 6450.42, 41.80) | 4427.15 | diff --git a/autodist/benchmark/gpt.md b/autodist/benchmark/gpt.md deleted file mode 100644 index bf976a75..00000000 --- a/autodist/benchmark/gpt.md +++ /dev/null @@ -1,234 +0,0 @@ - - -# GPT-3 - -## Model Config - -**batch size**, **sequence length** and **vocabulary size** are fixed to 1024, 1024 and 51200 respectively. The data type is *float16*. - -| #params | Hidden size | #layers | #heads | #gpus | -| ------- | ----------- | ------- | ------ | ----- | -| 350M | 1024 | 24 | 16 | 1 | -| 760M | 1536 | 24 | 16 | 2 | -| 1.3B | 2048 | 24 | 32 | 4 | -| 2.6B | 2560 | 32 | 32 | 8 | -| 6.7B | 4096 | 32 | 32 | 16 | -| 15B | 5120 | 48 | 40 | 32 | -| 39B | 8192 | 48 | 64 | 64 | - -## End-to-end evaluation results - -*note*: results organized in (execution time/s, device mem/GB, compile time/s) - -**Table 1: include pipeline** - -| Config | alpa | ours | -| ------ | ---------------------- | ---------------------- | -| 760M | (59.21, 14.49, 232.51) | (46.13, 30.27, 7.56) | -| 1.3B | (47.23, 24.83, 355.14) | (42.45, 20.68, 33.96) | -| 2.6B | (45.00, 13.20, 731.86) | (39.72, 24.74, 295.58) | -| 6.7B | (46.19, 16.75, 832.01) | (45.17, 15.57, 1906.22)| - -**Table 2: tensor parallelism only** - -| Config | alpa | ours | -| ------ | --------------------- | ----------------------- | -| 350M | (54.88, 26.50, 8.77) | (56.03, 26.52, 0.37) | -| 760M | (51.06, 21.38, 27.04) | (52.14, 19.41, 2.00) | -| 1.3B | (47.83, 20.22, 25.86) | (47.84, 30.37, 14.91) | -| 2.6B | (55.10, 21.47, 65.88) | (69.92, 10.85, 36.91) | -| 6.7B | (84.16, 25.09, 65.21) | (61.79, 21.93, 226.11) | - -## AutoDist Details - -Memory constraint set to 30 GB in most test cases. - -For cases with \*, we set the memory constraint to 25GB. - -For the case with runtime=-1, we cannot run the case due to two errors from Cube: multiref and rvd failed. - -### Include pipeline - -**760M** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 65.37 | 56.58 | 6.19 | 6.71 | 5.19 | -| 2 | 53.18 | 51.11 | 10.26 | 10.88 | 5.66 | -| 4 | 49.13 | 46.74 | 17.47 | 18.14 | 6.05 | -| 8 | 48.52 | 46.13 | 29.56 | 30.27 | 7.56 | -| 16*| 53.86 | 53.47 | 24.97 | 29.98 | 9.3 | -| 32*| 57.49 | 56.48 | 24.81 | 27.5 | 9.59 | - -**1.3B** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 52.36 | 48.27 | 8.03 | 8.99 | 28.99 | -| 2 | 45.59 | 43.48 | 13.81 | 14.77 | 27.7 | -| 4 | 43.87 | 42.45 | 20.03 | 20.68 | 33.96 | -| 8 | 44.3 | -1 | 24.98 | -1 | 38.67 | -| 16*| 47.71 | 48.13 | 24.86 | 30.35 | 48.31 | -| 32*| 48.96 | 48.03 | 24.98 | 25.96 | 42.15 | - -**2.6B** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 46.32 | 46.45 | 7.34 | 8.94 | 228.92 | -| 2 | 42.6 | 40.96 | 12.58 | 14.26 | 248.44 | -| 4 | 41.48 | 39.72 | 23.06 | 24.74 | 295.58 | -| 8 | 41.01 | -1 | 28.74 | -1 | 436.54 | -| 16 | 42.68 | -1 | 29.89 | -1 | 476.78 | -| 32 | 44.98 | -1 | 29.98 | -1 | 394.06 | - -**6.7B** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 50.58 | 50.26 | 7.94 | 9.92 | 1496.21 | -| 2 | 46.37 | 45.17 | 13.54 | 15.57 | 1906.22 | -| 4 | 43.81 | -1 | 17.16 | -1 | 1937.48 | -| 8 | 43.42 | -1 | 29.65 | -1 | 2082.68 | -| 16 | 43.66 | -1 | 29.98 | -1 | 1834.14 | -| 32 | 47.15 | -1 | 29.96 | -1 | 1555.16 | - -Remark: - -### Tensor Parallelism Only - -**350M** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 82.23 | 71.08 | 5.92 | 5.64 | 0.38 | -| 2 | 69.33 | 62.75 | 8.91 | 8.63 | 0.39 | -| 4 | 62.98 | 58.48 | 14.88 | 14.26 | 0.37 | -| 8 | 59.6 | 56.03 | 26.84 | 26.52 | 0.37 | -| 16 | 64.73 | 62.61 | 29.8 | 29.21 | 0.38 | -| 32 | 71.85 | 68.8 | 29.43 | 26.36 | 0.3 | - -**760M** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 90.14 | 81.4 | 4.96 | 5.02 | 0.79 | -| 2 | 65.31 | 66.03 | 6.84 | 7.12 | 0.94 | -| 4 | 58.3 | 57.21 | 10.65 | 11.22 | 1.07 | -| 8 | 53.8 | 52.14 | 18.25 | 19.41 | 2 | -| 16*| 53.86 | 53.46 | 24.97 | 29.78 | 3.31 | -| 32*| 57.49 | 56.42 | 24.81 | 27.54 | 2.7 | - -**1.3B** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 79.5 | 83.84 | 4.42 | 4.41 | 2.65 | -| 2 | 61.4 | 69.35 | 6.25 | 6.3 | 3.96 | -| 4 | 55.88 | 57.85 | 9.91 | 10.29 | 6 | -| 8 | 50.25 | 50.38 | 17.2 | 19.65 | 9.78 | -| 16*| 47.71 | 48.08 | 24.86 | 30.37 | 14.91 | -| 32*| 48.96 | 47.84 | 24.98 | 25.76 | 9.29 | - -**2.6B** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 87.92 | 102.79 | 4.3 | 4.58 | 17.57 | -| 2 | 70.56 | 80.51 | 6.03 | 6.98 | 24.84 | -| 4 | 63.19 | 69.62 | 9.5 | 10.85 | 36.91 | -| 8 | 53.25 | -1 | 15.94 | -1 | 76.63 | -| 16 | 46.44 | -1 | 24.99 | -1 | 115.04 | -| 32 | 46.38 | -1 | 24.98 | -1 | 69.36 | - -**6.7B** - -| | estimation time/s | runtime/s | estimation memory/GB | runtime memory/GB | compile time/s | -|---:|--------------------:|------------:|-----------------------:|--------------------:|-----------------:| -| 1 | 107.5 | 127.41 | 5.01 | 5.24 | 61.34 | -| 2 | 88.84 | 101.43 | 6.84 | 7.78 | 79.19 | -| 4 | 77.33 | 89.68 | 10.45 | 13.54 | 116.27 | -| 8 | 61.7 | 61.79 | 19.2 | 21.93 | 226.11 | -| 16 | 50.44 | -1 | 24.16 | -1 | 189.93 | -| 32 | 49.57 | -1 | 25 | -1 | 138.96 | - - -## Alpa Details - -Here we show the key hyperparameters of alpa under different config in both **table 1** and **table 2**. - -**Details for table 1** - -| Config | micro batch size | Recompute | num_auto_layers | forward_stage_layer_ids | submesh_shapes | logical_mesh_shapes | autosharding_option_dicts | -|:-------|:-----------------|:----------|:----------------|:-------------------------------|:-------------------------------|:-------------------------------|:-----------------------------------------------------| -| 350M* | 32 | True | 1 | [0] | [1, 1] | [1, 1] | {} | -| 760M | 16 | True | 6 | [0, 1, 2], [3, 4, 5] | [1, 1], [1, 1] | [1, 1], [1, 1] | force_dp_dict,{} | -| 1.3B | 64 | True | 6 | [0, 1, 2, 3, 4, 5] | [1, 4] | [4, 1] | force_dp_dict,{} | -| 2.6B | 16 | True | 8 | [0, 1, 2, 3], [4, 5, 6, 7] | [1, 4], [1, 4] | [4, 1], [4, 1] | force_dp_dict,force_dp_dict | -| 6.7B | 16 | True | 8 | [0, 1], [2, 3], [4, 5], [6, 7] | [1, 4], [1, 4], [1, 4], [1, 4] | [4, 1], [4, 1], [4, 1], [4, 1] | force_dp_dict,{},force_dp_dict,force_dp_dict | -| 15B | 8 | True | 16 | [0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15] | [1, 4]\*8 | [4, 1]\*8 | force_dp_dict \* 8 | - -**15B** - -| batch size | time/s | memory/GB | recompute | -| ---------- | -------| --------- | --------- | -| 2 | 75.77 | 21.94 | False | -| 4 | 95.56 | 17.69 | True | -| 8 | 59.98 | 18.10 | True | -| 16 | 60.84 | 18.67 | True | -| 32 | 70.05 | 23.32 | True | -| 64 | 76.03 | 21.31 | True | - -**Details for table 2** - -**350M** - -| batch size | time/s | memory/GB | recompute | -| ---------- | -------| --------- | --------- | -| 4 | 56.46 | 16.30 | False | -| 8 | 54.88 | 26.50 | False | -| 16 | 71.06 | 16.10 | True | -| 32 | 70.22 | 26.09 | True | -| 64 | x | OOM | True | - -**760M** - -| batch size | time/s | memory/GB | recompute | -| ---------- | -------| --------- | --------- | -| 4 | 51.82 | 13.82 | False | -| 8 | 51.06 | 21.38 | False | -| 16 | 62.99 | 12.67 | True | -| 32 | 62.88 | 19.06 | True | -| 64 | x | OOM | True | - -**1.3B** - -| batch size | time/s | memory/GB | recompute | -| ---------- | -------| --------- | --------- | -| 4 | 49.35 | 12.77 | False | -| 8 | 47.83 | 20.22 | False | -| 16 | 60.01 | 10.67 | True | -| 32 | 59.31 | 16.01 | True | -| 64 | 58.81 | 26.67 | True | -| 128 | x | OOM | True | - -**2.6B** - -| batch size | time/s | memory/GB | recompute | -| ---------- | -------| --------- | --------- | -| 4 | 55.16 | 13.36 | False | -| 8 | 55.10 | 21.47 | False | -| 16 | 67.90 | 12.03 | True | -| 32 | 67.56 | 18.85 | True | -| 64 | x | OOM | True | - - -**6.7B** - -| batch size | time/s | memory/GB | -| ---------- | -------| --------- | -| 2 | 109.60 | 7.87 | -| 8 | 89.98 | 11.11 | -| 16 | 86.14 | 15.77 | -| 32 | 84.16 | 25.09 | -| 64 | x | OOM | diff --git a/autodist/benchmark/recompute.md b/autodist/benchmark/recompute.md deleted file mode 100644 index 39f77e08..00000000 --- a/autodist/benchmark/recompute.md +++ /dev/null @@ -1,28 +0,0 @@ -# Continuous Recompute - -We implement the continuous recompute search algorithm, which outperforms the alpa and the manual strategy of Megatron[1] on GPT. On the following cases, [Alpa](https://github.com/alpa-projects/alpa) is OOM and autodist has a 5.4% gain over megatron. - -## Experimental Config - -The model config is GPT3 760M, on 2 GPUs, num_layer increased from 24 into 48, global_batch_size = 1024 and micro_batch_size = 8. - -## Results - -| Search algorithm | runtime (time/s, memory/GB) | compile time/s | Remark | -|:-------------------------------|:----------------------------|:----------------|:---------| -| Megatron(selective recompute) | (137.10, 17.22) | 8.40 | OOM | -| Megatron(full recompute) | (171.78, 7.09) | 9.23 | | -| Megatron(search) | (154.22, 15.85) | 9.60 | | -| Alpa | (149.53, 16.15) | 129.96 | OOM | -| Autodist(continuous recompute) | **(145.90, 15.67)** | 2069.80 | | -| Autodist(single recompute) | (146.65, 15.99) | 130.31 | Multiref | - -## Details - -**Megatron:** Megatron using selective recompute(as well as using full recompute) represents selective recompute(as well as full recompute) for all layers. Megatron(search) is the optimal solution searched manually according to [args.recompute-method](https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/arguments.py#L375:~:text=group.add_argument(%27%2D%2Drecompute%2Dmethod%27%2C%20type%3Dstr%2C%20default%3DNone%2C)), [args.recompute-granularity](https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/arguments.py#L375:~:text=group.add_argument(%27%2D%2Drecompute%2Dgranularity%27%2C%20type%3Dstr%2C%20default%3DNone%2C)) and [args.recompute-num-layers](https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/arguments.py#L375:~:text=group.add_argument(%27%2D%2Drecompute%2Dnum%2Dlayers%27%2C%20type%3Dint%2C%20default%3D1%2C)) in Megatron. Depending on the permutation of these switches, without pipeline, we can specify that some layers recompute and others don't, and that the recompute strategy can be selective recompute or full recompute. But we can not obtain the strategy that different layers take different recompute strategies. The optimal solution of our artificial search is that Megatron(search) fully recomputes the first 31 layers. - -**Alpa:** The Alpa solution is OOM with the config(dp=1, op=2, use_remat=True). - -**Autodist:** The Autodist(continuous recompute) searches for a solution, where the first 23 layers fully recompute and the left 25 layers selective recompute. The compile time of Autodist(continuous recompute) is up to 15x of that of Autodist(single recompute), whose search solution has multiref BUG. - -**Remark**: We use whether the memory exceeds 16G to determine if it is OOM. Because heavy fragmentation arises in the above case, memory utilization is less than 70% and difficults the memory estimation. diff --git a/autodist/build_env.py b/autodist/build_env.py deleted file mode 100644 index 590ce804..00000000 --- a/autodist/build_env.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import os -import shutil -import sys -from datetime import datetime -import subprocess -from pathlib import Path -import torch -from nnscaler.autodist.util import get_node_arch - -if bool(int(os.environ.get('PROFILE_COMM', default=0))): - profile_comm = True -else: - profile_comm = False - - -def main(): - base_path = str(Path.home()) + '/.autodist' - default_path = base_path + '/' + get_node_arch() - - code_path = Path(__file__).parents[1] - - if not os.path.exists(default_path): - os.makedirs(default_path) - print('> create folder: ', default_path) - os.makedirs(default_path + '/plan') - else: - print('> folder already exists: ', default_path) - - # profile communication cost - if profile_comm: - print('> CUDA device num: ', torch.cuda.device_count()) - for device_num in [2, 4, 8, 16]: - if device_num > torch.cuda.device_count(): - break - command = f'torchrun --master_port 21212 --nproc_per_node={device_num} ./comm_profile.py --comm_profile_dir={default_path}/comm' - output = subprocess.check_output(command, shell=True, text=True) - else: - print('> skip communication profiling, using mi200 profile data') - if os.path.exists(default_path + '/comm'): - print('> backup existing comm profile data') - shutil.move( - default_path + '/comm', - default_path + f'/comm_back_{str(datetime.now().timestamp())}') - shutil.copytree(code_path / 'autodist/profile_data/16xmi200/comm', default_path + '/comm') - - print('> build env successfully') - - -if __name__ == '__main__': - main() diff --git a/autodist/script/alphafold/foldtp.sh b/autodist/script/alphafold/foldtp.sh deleted file mode 100755 index defb30e5..00000000 --- a/autodist/script/alphafold/foldtp.sh +++ /dev/null @@ -1,64 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) - -memory_constraint=38 # in GB -memory_granularity=1 # in byte -connect_type='NV2' -save_folder="alphafold" -topk=1 -cache_folder1="autodist/cost_model/comm/__pycache__" -cache_folder2="autodist/cost_model/__pycache__" -comm_dev=(2 4) - -for ((i=0; i<${#comm_dev[*]}; i=i+1)); do - torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type -done - -if [ ! -d $save_folder ] -then - mkdir $save_folder -fi - -# We run all cases in the machine with 4 gpus. - -mesh_rows=(1 1 1) -mesh_cols=(1 2 4) -setting=(1 2 3) -layer=48 - -for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do - for ((j=0; j<${#setting[*]}; j=j+1)); do - - echo "start runtime Alphafold2 setting=${setting[j]} gpus=${mesh_cols[k]}" - if [ -d $cache_folder1 ] - then - echo "Removing $cache_folder1 directory..." - rm -r $cache_folder1 - rm -r $cache_folder2 - else - echo "$cache_folder1 directory not found" - fi - - SINGLE_DEV_MODE=1 python main.py --is_train \ - --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity \ - --save_folder=$save_folder --connect_type=$connect_type \ - --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --compile \ - --topk=$topk --ignore_small_tensor_threshold=2048 \ - --verbose --alphafold_setting=${setting[j]} --alphafold \ - --alphafold_layer=$layer --recompute --adaptive_recom - - torchrun --master_port=30001 --nnodes=${mesh_rows[k]} \ - --nproc_per_node=${mesh_cols[k]} main.py --is_train \ - --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder \ - --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} \ - --plan_idx=0 --iter_num=4 --warm_num=2 \ - --global_batch_size=1 --alphafold --ignore_small_tensor_threshold=2048 \ - --alphafold_setting=${setting[j]} --alphafold_layer=$layer --recompute - done -done - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/adapt_recom_tp.sh b/autodist/script/gpt/adapt_recom_tp.sh deleted file mode 100644 index b733ce8d..00000000 --- a/autodist/script/gpt/adapt_recom_tp.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) - -memory_constraint=30 -memory_granularity=1 # in byte -connect_type='NV2' -save_folder="tp_data" -topk=1 - -comm_dev=(2 4 8 16) - -for ((i=0; i<${#comm_dev[*]}; i=i+1)); do - torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type -done - -if [ ! -d $save_folder ] -then - mkdir $save_folder -fi - -# We run all cases in the machine with 4 gpus. - -bs=(1 2 4 8 16 32) -mesh_rows=(1 1 1 1 1) -mesh_cols=(1 2 4 8 16) -model_config=('350M' '760M' '1.3B' '2.6B' '6.7B') - -for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do - for ((j=0; j<${#bs[*]}; j=j+1)); do - - echo "start runtime ${bs[j]} ${model_config[k]}" - - SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ - --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --compile \ - --topk=$topk --fine_grained_GPT --ignore_small_tensor_threshold=1048576 \ - --verbose --adaptive_recom - - for (( i=0; i < $topk; ++i)) - do - torchrun --nnodes=${mesh_rows[k]} --nproc_per_node=${mesh_cols[k]} main.py --GPT_setting=${model_config[k]} --is_train --recompute \ - --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type \ - --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --plan_idx=$i --fine_grained_GPT --adaptive_recom - if [ $? -eq 0 ] - then - echo "success at $i trial" - break - else - echo "fail at $i trial" - fi - done - - done -done - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/analyze.py b/autodist/script/gpt/analyze.py deleted file mode 100644 index 1c965033..00000000 --- a/autodist/script/gpt/analyze.py +++ /dev/null @@ -1,77 +0,0 @@ -import json -import argparse - -parser = argparse.ArgumentParser(description='GPT Train') -parser.add_argument('--save_folder', - type=str, - default='tp_data', - help='set the save folder for experiment data') -parser.add_argument('--pp', - action='store_true', - help='for pipeline number analysis') -parser.add_argument('--is_train', - action='store_true', - help='True: train, False: inference') -args = parser.parse_args() -import pandas as pd - -model_setting_list = ['760M', '1.3B', '2.6B', '6.7B' - ] if args.pp else ['350M', '760M', '1.3B', '2.6B', '6.7B'] -gpus = {'350M': 1, '760M': 2, '1.3B': 4, '2.6B': 8, '6.7B': 16} -recompute_list = ['True'] -batch_size_list = [1, 2, 4, 8, 16, 32] - -for recompute in recompute_list: - table = {} - for model_setting in model_setting_list: - for batch_size in batch_size_list: - table[batch_size] = {} - fname = './' + args.save_folder + '/gpt3-' + model_setting + '-' + str( - gpus[model_setting]) + 'gpu-' + str( - batch_size) + 'batch_size-' + str(args.is_train) - estimated_fname = fname + '-estimate.json' - backup_fname = fname + '-backup.json' - real_fname = fname + '-real.json' - - try: - with open(backup_fname, 'r') as f: - estimated_dict = json.load(f) - try: - tmp = estimated_dict['estimated memory'] - except: - estimated_dict = estimated_dict[0] - estimated_time = estimated_dict['estimated time'] - estimated_memory = estimated_dict['estimated memory'][ - 0] if args.pp else estimated_dict['estimated memory'] - compile_time = estimated_dict['compile time'] - except: - try: - with open(estimated_fname, 'r') as f: - estimated_dict = json.load(f) - estimated_time = estimated_dict['estimated time'] - estimated_memory = estimated_dict['estimated memory'][ - 0] if args.pp else estimated_dict['estimated memory'] - compile_time = estimated_dict['compile time'] - except: - estimated_time = -1 - estimated_memory = -1 - compile_time = -1 - try: - with open(real_fname, 'r') as f: - real_dict = json.load(f) - real_time = real_dict['time/s'] - real_memory = max(real_dict['memory/GB'].values()) - except: - real_time = -1 - real_memory = -1 - - table[batch_size]['estimation time/s'] = estimated_time - table[batch_size]['runtime/s'] = real_time - table[batch_size][ - 'estimation memory/GB'] = estimated_memory if estimated_memory != -1 else -1 - table[batch_size][ - 'runtime memory/GB'] = real_memory if real_memory != -1 else -1 - table[batch_size]['compile time/s'] = compile_time - pdTable = pd.DataFrame(table).round(2).T - print(model_setting, recompute) - print(pdTable.to_markdown()) diff --git a/autodist/script/gpt/analyze_adapt_recom.py b/autodist/script/gpt/analyze_adapt_recom.py deleted file mode 100644 index 47f7cf06..00000000 --- a/autodist/script/gpt/analyze_adapt_recom.py +++ /dev/null @@ -1,77 +0,0 @@ -import json -import argparse - -parser = argparse.ArgumentParser(description='GPT Train') -parser.add_argument('--save_folder_tp', - type=str, - default='tp_data', - help='set the save folder for tp') -parser.add_argument('--save_folder_pp', - type=str, - default='pp_data', - help='set the save folder for pp') -parser.add_argument('--suffix', - type=str, - default='_nar', - help='set the save folder for w/o adaptive_recom') -parser.add_argument('--pp', - action='store_true', - help='for pipeline number analysis') -args = parser.parse_args() -import pandas as pd - -folders = [args.save_folder_tp, args.save_folder_tp + args.suffix] -if args.pp: - folders = [args.save_folder_pp, args.save_folder_pp + args.suffix] - -model_setting_list = ['760M', '1.3B'] if args.pp else ['350M', '760M', '1.3B'] -gpus = {'350M': 1, '760M': 2, '1.3B': 4, '2.6B': 8, '6.7B': 16} -recompute_list = ['True'] -batch_size_list = [1, 2, 4, 8, 16, 32] - -for recompute in recompute_list: - table = {} - for model_setting in model_setting_list: - for batch_size in batch_size_list: - table[batch_size] = {} - for index, folder in enumerate(folders): - fname = './' + folder + '/gpt3-' + model_setting + '-' + str( - gpus[model_setting]) + 'gpu-' + str( - batch_size) + 'batch_size' - backup_fname = fname + '-backup.json' - - try: - with open(backup_fname, 'r') as f: - estimated_dict = json.load(f) - try: - tmp = estimated_dict['estimated memory'] - except: - estimated_dict = estimated_dict[0] - estimated_time = estimated_dict['estimated time'] - estimated_memory = estimated_dict['estimated memory'][ - 0] if args.pp else estimated_dict['estimated memory'] - compile_time = estimated_dict['compile time'] - except: - estimated_time = -1 - estimated_memory = -1 - compile_time = -1 - - if index == 0: - table[batch_size][ - 'est time w/ adapt_recom /s'] = estimated_time - else: - table[batch_size][ - 'est time w/o adapt_recom /s'] = estimated_time - if index == 0: - table[batch_size][ - 'compile time w/ adapt_recom /s'] = compile_time - else: - table[batch_size][ - 'compile time w/o adapt_recom /s'] = compile_time - table[batch_size]['gain/%'] = ( - table[batch_size]['est time w/o adapt_recom /s'] - - table[batch_size]['est time w/ adapt_recom /s'] - ) / table[batch_size]['est time w/o adapt_recom /s'] * 100 - pdTable = pd.DataFrame(table).round(2).T - print(model_setting) - print(pdTable.to_markdown()) diff --git a/autodist/script/gpt/checker.sh b/autodist/script/gpt/checker.sh deleted file mode 100755 index 680100ff..00000000 --- a/autodist/script/gpt/checker.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) - -memory_constraint=30.5 -memory_granularity=1 # in byte -connect_type='NV2' -save_folder="exp_data_test" - -comm_dev=(2 4) - -for ((i=0; i<${#comm_dev[*]}; i=i+1)); do - torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type -done - -if [ ! -d $save_folder ] -then - mkdir $save_folder -fi - -# spmd for a simple case (with 1 gpu) and a complex case (with 4 gpus). - -bs=(32) -mesh_cols=(1 4) -model_config=('350M' '1.3B') - -for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do - for ((j=0; j<${#bs[*]}; j=j+1)); do - - SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --recompute \ - --batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_col=${mesh_cols[k]} --compile - - torchrun --nproc_per_node=${mesh_cols[k]} main.py --GPT_setting=${model_config[k]} --recompute \ - --batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_col=${mesh_cols[k]} - - done -done - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "checkRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/pp_all_run.sh b/autodist/script/gpt/pp_all_run.sh deleted file mode 100755 index c7d38d89..00000000 --- a/autodist/script/gpt/pp_all_run.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) - -memory_constraint=30 -memory_granularity=1 # in byte -connect_type='NV2' -save_folder="pp_data" -topk=1 - -comm_dev=(2 4 8 16) - -for ((i=0; i<${#comm_dev[*]}; i=i+1)); do - torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type -done - -if [ ! -d $save_folder ] -then - mkdir $save_folder -fi - -# We run all cases in the machine with 4 gpus. - -bs=(1 2 4 8 16 32) -mesh_rows=(1 1 1 1) -mesh_cols=(2 4 8 16) -model_config=('760M' '1.3B' '2.6B' '6.7B') - -for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do - for ((j=0; j<${#bs[*]}; j=j+1)); do - echo "start runtime ${bs[j]} ${model_config[k]}" - SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ - --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder \ - --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --compile --pipeline --topk=1 - - torchrun --nnodes=${mesh_rows[k]} --nproc_per_node=${mesh_cols[k]} main.py --GPT_setting=${model_config[k]} --is_train --recompute \ - --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --pipeline --plan_idx=0 - - done -done - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/profile.sh b/autodist/script/gpt/profile.sh deleted file mode 100644 index 5f31f008..00000000 --- a/autodist/script/gpt/profile.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) - -memory_constraint=30 -memory_granularity=1 # in byte -connect_type='NV2' -save_folder="tp_data" -topk=20 - -comm_dev=(2 4 8 16) - -for ((i=0; i<${#comm_dev[*]}; i=i+1)); do - torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type -done - -if [ ! -d $save_folder ] -then - mkdir $save_folder -fi - -# Use nvidia-smi to get a list of GPUs -gpus=$(nvidia-smi -L) - -# Count the number of lines of output -num_gpus=$(echo "$gpus" | wc -l) - -bs=(1 2 4 8 16 32) -mesh_cols=(1 2 4 8 16) -model_config=('350M' '760M' '1.3B' '2.6B' '6.7B') - - -for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do - for ((j=0; j<${#bs[*]}; j=j+1)); do - - count=$((k * ${#bs[*]} + j)) - q=$(expr $count % $num_gpus) - - # bash profile.sh to profile the coarse-gained GPT - # bash profile.sh * to profile the fine-gained GPT - if [ $# -eq 0 ]; then - CUDA_VISIBLE_DEVICES=$q SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ - --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder \ - --connect_type=$connect_type --mesh_col=${mesh_cols[k]} --compile \ - --topk=$topk & - - else - CUDA_VISIBLE_DEVICES=$q SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ - --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder \ - --connect_type=$connect_type --mesh_col=${mesh_cols[k]} --compile \ - --topk=$topk --fine_grained_GPT & - fi - - if [ "$q" -eq "$((num_gpus-1))" ]; then - wait - fi - - done -done - -wait - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/gpt/tp_all_run.sh b/autodist/script/gpt/tp_all_run.sh deleted file mode 100755 index bad93399..00000000 --- a/autodist/script/gpt/tp_all_run.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) - -memory_constraint=30 -memory_granularity=1 # in byte -connect_type='NV2' -save_folder="tp_data" -topk=1 - -comm_dev=(2 4 8 16) - -for ((i=0; i<${#comm_dev[*]}; i=i+1)); do - torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type -done - -if [ ! -d $save_folder ] -then - mkdir $save_folder -fi - -# We run all cases in the machine with 4 gpus. - -bs=(1 2 4 8 16 32) -mesh_rows=(1 1 1 1 1) -mesh_cols=(1 2 4 8 16) -model_config=('350M' '760M' '1.3B' '2.6B' '6.7B') - -for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do - for ((j=0; j<${#bs[*]}; j=j+1)); do - - echo "start runtime ${bs[j]} ${model_config[k]}" - - SINGLE_DEV_MODE=1 python main.py --GPT_setting=${model_config[k]} --is_train --recompute \ - --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --compile \ - --topk=$topk - - for (( i=0; i < $topk; ++i)) - do - torchrun --nnodes=${mesh_rows[k]} --nproc_per_node=${mesh_cols[k]} main.py --GPT_setting=${model_config[k]} --is_train --recompute \ - --micro_batch_size=${bs[j]} --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} --plan_idx=$i - if [ $? -eq 0 ] - then - echo "success at $i trial" - break - else - echo "fail at $i trial" - fi - done - - done -done - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/pre_install.sh b/autodist/script/pre_install.sh deleted file mode 100644 index e34fe0e8..00000000 --- a/autodist/script/pre_install.sh +++ /dev/null @@ -1,7 +0,0 @@ -sudo echo 'export PATH="$HOME/.local/bin:$PATH"' > $(dirname $(pwd))/.bashrc -source $(dirname $(pwd))/.bashrc -pip install -r requirements-dev.txt -pip install pre-commit -pre-commit install -pre-commit run --all-files -python setup.py develop --user diff --git a/autodist/script/swin/analysis.py b/autodist/script/swin/analysis.py deleted file mode 100644 index 7890cbe9..00000000 --- a/autodist/script/swin/analysis.py +++ /dev/null @@ -1,72 +0,0 @@ -import json -import argparse - -parser = argparse.ArgumentParser(description='Swin Train') -parser.add_argument('--save_folder', - type=str, - default='swin', - help='set the save folder for experiment data') -parser.add_argument('--pp', - action='store_true', - help='for pipeline number analysis') -args = parser.parse_args() -import pandas as pd - -model_setting_list = ['toy', '355M', '1.8B'] -gpus = {'toy': 1, '355M': 2, '1.8B': 4, '2.6B': 8, '6.7B': 16} -recompute_list = ['True'] -batch_size_list = [1, 2, 4, 8, 16, 32] - -for recompute in recompute_list: - table = {} - for model_setting in model_setting_list: - for batch_size in batch_size_list: - table[batch_size] = {} - fname = './' + args.save_folder + '/swin-' + model_setting + '-' + str( - gpus[model_setting]) + 'gpu-' + str(batch_size) + 'batch_size' - estimated_fname = fname + '-estimate.json' - backup_fname = fname + '-backup.json' - real_fname = fname + '-real.json' - - try: - with open(backup_fname, 'r') as f: - estimated_dict = json.load(f) - try: - tmp = estimated_dict['estimated memory'] - except: - estimated_dict = estimated_dict[0] - estimated_time = estimated_dict['estimated time'] - estimated_memory = estimated_dict['estimated memory'][ - 0] if args.pp else estimated_dict['estimated memory'] - compile_time = estimated_dict['compile time'] - except: - try: - with open(estimated_fname, 'r') as f: - estimated_dict = json.load(f) - estimated_time = estimated_dict['estimated time'] - estimated_memory = estimated_dict['estimated memory'][ - 0] if args.pp else estimated_dict['estimated memory'] - compile_time = estimated_dict['compile time'] - except: - estimated_time = -1 - estimated_memory = -1 - compile_time = -1 - try: - with open(real_fname, 'r') as f: - real_dict = json.load(f) - real_time = real_dict['time/s'] - real_memory = max(real_dict['memory/GB'].values()) - except: - real_time = -1 - real_memory = -1 - - table[batch_size]['estimation time/s'] = estimated_time - table[batch_size]['runtime/s'] = real_time - table[batch_size][ - 'estimation memory/GB'] = estimated_memory if estimated_memory != -1 else -1 - table[batch_size][ - 'runtime memory/GB'] = real_memory if real_memory != -1 else -1 - table[batch_size]['compile time/s'] = compile_time - pdTable = pd.DataFrame(table).round(2).T - print(model_setting, recompute) - print(pdTable.to_markdown()) diff --git a/autodist/script/swin/profile_swin.sh b/autodist/script/swin/profile_swin.sh deleted file mode 100644 index 8036e69b..00000000 --- a/autodist/script/swin/profile_swin.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) - -memory_constraint=30 -memory_granularity=1 # in byte -connect_type='NV2' -save_folder="tp_data" -topk=20 - -comm_dev=(2 4) - -for ((i=0; i<${#comm_dev[*]}; i=i+1)); do - torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type -done - -if [ ! -d $save_folder ] -then - mkdir $save_folder -fi - -# Use nvidia-smi to get a list of GPUs -gpus=$(nvidia-smi -L) - -# Count the number of lines of output -num_gpus=$(echo "$gpus" | wc -l) - -bs=(1 2 4 8 16) -mesh_cols=(1 2 4) -setting=('toy' '355M' '1.8B') - - -for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do - for ((j=0; j<${#bs[*]}; j=j+1)); do - - count=$((k * ${#bs[*]} + j)) - q=$(expr $count % $num_gpus) - - echo "start runtime Swin setting=${setting[k]} bs=${bs[j]}" - - CUDA_VISIBLE_DEVICES=$q LOG_TRANSFORM=1 SINGLE_DEV_MODE=1 python main.py --is_train \ - --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity \ - --save_folder=$save_folder --connect_type=$connect_type \ - --mesh_row=1 --mesh_col=${mesh_cols[k]} --compile \ - --topk=$topk \ - --verbose --swin_setting=${setting[k]} --swin \ - --recompute --micro_batch_size=${bs[j]} --global_batch_size=32 & - - if [ "$q" -eq "$((num_gpus-1))" ]; then - wait - fi - - done -done - -wait - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/autodist/script/swin/swintp.sh b/autodist/script/swin/swintp.sh deleted file mode 100755 index 9ce957fb..00000000 --- a/autodist/script/swin/swintp.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash --login -start_time=$(date +%s) - -memory_constraint=35 # in GB -memory_granularity=1 # in byte -connect_type='NV2' -save_folder="swin" -topk=1 -cache_folder1="autodist/cost_model/comm/__pycache__" -cache_folder2="autodist/cost_model/__pycache__" -comm_dev=(2 4) - -for ((i=0; i<${#comm_dev[*]}; i=i+1)); do - torchrun --nproc_per_node=${comm_dev[i]} comm_profile.py --connect_type=$connect_type -done - -if [ ! -d $save_folder ] -then - mkdir $save_folder -fi - -# We run all cases in the machine with 4 gpus. - -bs=(1 2 4 8 16 32) - -mesh_cols=(1 2 4) -mesh_rows=(1 1 1) -setting=('toy' '355M' '1.8B') - -for ((k=0; k<${#mesh_cols[*]}; k=k+1)); do - for ((j=0; j<${#bs[*]}; j=j+1)); do - - echo "start runtime Swin setting=${setting[k]} bs=${bs[j]}" - if [ -d $cache_folder1 ] - then - echo "Removing $cache_folder1 directory..." - rm -r $cache_folder1 - rm -r $cache_folder2 - else - echo "$cache_folder1 directory not found" - fi - - LOG_TRANSFORM=1 SINGLE_DEV_MODE=1 python main.py --is_train \ - --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity \ - --save_folder=$save_folder --connect_type=$connect_type \ - --mesh_row=1 --mesh_col=${mesh_cols[k]} --compile \ - --topk=$topk \ - --verbose --swin_setting=${setting[k]} --swin \ - --micro_batch_size=${bs[j]} \ - --global_batch_size=32 --recompute --adaptive_recom - - torchrun --master_port=30001 --nnodes=${mesh_rows[k]} \ - --nproc_per_node=${mesh_cols[k]} main.py --is_train \ - --memory_constraint=$memory_constraint \ - --memory_granularity=$memory_granularity --save_folder=$save_folder \ - --connect_type=$connect_type --mesh_row=${mesh_rows[k]} --mesh_col=${mesh_cols[k]} \ - --plan_idx=0 --iter_num=2 --warm_num=1 --micro_batch_size=${bs[j]} \ - --global_batch_size=32 --swin_setting=${setting[k]} --swin \ - --recompute - done -done - -end_time=$(date +%s) -cost_time=$[ $end_time - $start_time] -echo "allRun.sh spends $(($cost_time/60))min $(($cost_time%60))s" diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 63377b01..86860703 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -29,12 +29,10 @@ def check_env(autodist_config: AutoDistConfig): - error_msg = ' does not exist, please run \'python autodist/build_env.py\' first' - autodist_dir = autodist_config.profile_dir + '/' + get_node_arch() - assert os.path.exists(autodist_dir), autodist_dir + error_msg - comm_path = autodist_dir + '/comm' - assert os.path.exists(comm_path), comm_path + error_msg - + arch_dir = autodist_config.profile_dir / get_node_arch() + if not arch_dir.exists(): + _logger.info(f'create folder: {arch_dir}') + arch_dir.mkdir(parents=True, exist_ok=True) def pre_estimate_mem(graph: ModelGraph): ''' diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 2f2c5710..f1bbcaf8 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -2,6 +2,7 @@ import argparse import logging from .descs import MeshDesc +from .util import get_default_profile_path _logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ class AutoDistConfig: - fp16 & bf16 training w/ memory efficient adam w/o inkernal cast: (2 + 2) (fp32 weight + fp32 gradient) - partition_constraints_path (`str`, *optional*, defaults to `''`): The path to the partition constraints file. Details can be found in docs/solver_interface/partition_constraints.md - - profile_dir (`str`, *optional*, defaults to `~/.autodist`): + - profile_dir (`str`, *optional*, defaults to `~/.nnscaler/autodist`): The directory to store the profiling results. - load_plan_path (`str`, *optional*, defaults to `''`): The path to the plan file to load. If specified, the plan will be loaded from the file instead of searching. @@ -103,7 +104,7 @@ def __init__(self, opt_resident_coef=2, opt_transient_coef=0, partition_constraints_path='', - profile_dir=str(Path.home()) + '/.autodist', + profile_dir=get_default_profile_path(), load_plan_path='', save_plan_path='', topk=20, @@ -172,7 +173,9 @@ def _validate_config(self): if self.pc_path: _validate_file_path(self.pc_path) - _validate_dir_path(self.profile_dir) + if not Path(self.profile_dir).exists(): + _logger.info(f'create folder: {self.profile_dir}') + Path(self.profile_dir).mkdir(parents=True, exist_ok=True) if self.pipeline: if self.max_pipeline_bubble_ratio <= 0 or self.max_pipeline_bubble_ratio >= 1: diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index a5ad5c92..49cf325f 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -52,6 +52,8 @@ def _piecewise_estimator(xs: List[float], ys: List[float], x: float) -> float: xs[i]) raise RuntimeError(f'x={x}, xs={xs}, ys={ys}, should not reach here') +import nnscaler +_DEFAULT_COMM_DATA_PATH = Path(nnscaler.__file__).parent.parent / 'profile_data/mi200/comm' class CostDatabase: @@ -70,9 +72,8 @@ def __init__(self, graph: IRGraph, config: AutoDistConfig): comm_dir = self.profile_dir / 'comm' if not comm_dir.exists(): - raise RuntimeError( - f'{comm_dir} does not exist, please run \'python autodist/build_env.py\' first' - ) + _logger.warning(f'Communication profile data not found, using default data at {_DEFAULT_COMM_DATA_PATH}') + comm_dir = Path(_DEFAULT_COMM_DATA_PATH) for fname in listdir(comm_dir): with open(comm_dir / fname, 'r') as f: self.comm_info[fname] = json.load(f) diff --git a/nnscaler/autodist/util.py b/nnscaler/autodist/util.py index 53ee350f..7a632046 100644 --- a/nnscaler/autodist/util.py +++ b/nnscaler/autodist/util.py @@ -3,6 +3,7 @@ from nnscaler.ir.operator import IRFwOperation import struct +from pathlib import Path from typing import List from collections import deque @@ -22,6 +23,8 @@ def double2byte(val): def double4byte(val): return struct.unpack('d', val)[0] +def get_default_profile_path(): + return Path.home() / '.cache' / 'nnscaler' / 'autodist' / '1.0' def get_node_arch(): import torch diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 4fe40d66..d7927bfd 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -433,7 +433,7 @@ def dump(self, file: str, override=False): json.dump(self._data, f) def dump_op(self, file: str, signature, override=False): - assert signature in self._data.keys(), f'this node not be profiled' + assert signature in self._data.keys(), f'{signature} has not been profiled' file_n = os.path.join(file, signature +'.json') with open(file_n, 'w') as f: to_dump = {key: asdict(value) for key, value in self._data[signature].items()} diff --git a/autodist/profile_data/16xmi200/comm/intra_16.json b/profile_data/mi200/comm/intra_16.json similarity index 100% rename from autodist/profile_data/16xmi200/comm/intra_16.json rename to profile_data/mi200/comm/intra_16.json diff --git a/autodist/profile_data/16xmi200/comm/intra_2.json b/profile_data/mi200/comm/intra_2.json similarity index 100% rename from autodist/profile_data/16xmi200/comm/intra_2.json rename to profile_data/mi200/comm/intra_2.json diff --git a/autodist/profile_data/16xmi200/comm/intra_4.json b/profile_data/mi200/comm/intra_4.json similarity index 100% rename from autodist/profile_data/16xmi200/comm/intra_4.json rename to profile_data/mi200/comm/intra_4.json diff --git a/autodist/profile_data/16xmi200/comm/intra_8.json b/profile_data/mi200/comm/intra_8.json similarity index 100% rename from autodist/profile_data/16xmi200/comm/intra_8.json rename to profile_data/mi200/comm/intra_8.json diff --git a/autodist/comm_profile.py b/utility/comm_profile.py similarity index 75% rename from autodist/comm_profile.py rename to utility/comm_profile.py index 852d61c3..5f767c24 100644 --- a/autodist/comm_profile.py +++ b/utility/comm_profile.py @@ -9,7 +9,7 @@ from nnscaler.runtime.adapter.collectives import all_gather, all_reduce, all_to_all, reduce_scatter from nnscaler.profiler import CudaTimer from nnscaler.runtime.device import DeviceGroup -from nnscaler.autodist.util import get_node_arch +from nnscaler.autodist.util import get_node_arch, get_default_profile_path class CommProfiler: @@ -81,27 +81,28 @@ def profile(self) -> Dict[str, Tuple[List[float], List[float]]]: primitive=primitive) return profile_info +if __name__ == '__main__': -parser = argparse.ArgumentParser( - description='Profile runtime communication cost') -parser.add_argument('--comm_profile_dir', - type=str, - default=str(Path.home()) + '/.autodist/comm', - help='autodist comm profile folder') -args = parser.parse_args() + parser = argparse.ArgumentParser( + description='Profile runtime communication cost') + parser.add_argument('--comm_profile_dir', + type=str, + default=get_default_profile_path() / get_node_arch() / 'comm', + help='autodist comm profile folder') + args = parser.parse_args() -nnscaler.init() + nnscaler.init() -CudaTimer(enable=True, predefined=True) -world_size = DeviceGroup().world_size -comm_profiler = CommProfiler(nranks=world_size) + CudaTimer(enable=True, predefined=True) + world_size = DeviceGroup().world_size + comm_profiler = CommProfiler(nranks=world_size) -profile_info = comm_profiler.profile() + profile_info = comm_profiler.profile() -if torch.distributed.get_rank() == 0: - dir_path = args.comm_profile_dir - if not os.path.exists(dir_path): - os.makedirs(dir_path) - file_name = dir_path + '/' + f'intra_{world_size}.json' - with open(file_name, 'w') as f: - json.dump(profile_info, f, indent=2) + if torch.distributed.get_rank() == 0: + dir_path = Path(args.comm_profile_dir) + if not dir_path.exists(): + dir_path.mkdir(parents=True, exist_ok=True) + file_name = dir_path / f'intra_{world_size}.json' + with open(file_name, 'w') as f: + json.dump(profile_info, f, indent=2) diff --git a/utility/prim_profiler.py b/utility/prim_profiler.py new file mode 100644 index 00000000..e68f5bf9 --- /dev/null +++ b/utility/prim_profiler.py @@ -0,0 +1,53 @@ +import torch +import os +import sys +import shutil +from datetime import datetime +import subprocess +import torch +import logging +from pathlib import Path +from nnscaler.autodist.util import get_node_arch, get_default_profile_path + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("nnscaler.comm_profiler") + + +def main(): + base_path = get_default_profile_path() + default_path = base_path / get_node_arch() + + if not default_path.is_dir(): + default_path.mkdir(parents=True) + logger.info(f'create folder: {default_path}') + else: + logger.info(f'folder already exists: {default_path}') + + comm_path = default_path / 'comm' + + if comm_path.is_dir(): + logger.info(f'back up legacy comm info: {comm_path}') + shutil.move( + comm_path, + default_path / f'comm_back_{str(datetime.now().timestamp())}') + comm_path.mkdir(parents=True, exist_ok=True) + + logger.info(f'CUDA device num: {torch.cuda.device_count()}') + profiler_fname = Path(__file__).parent / 'comm_profile.py' + device_num = 2 + while device_num <= torch.cuda.device_count(): + command = f'torchrun --master_port 21212 --nproc_per_node={device_num} {profiler_fname} --comm_profile_dir={comm_path}' + output = subprocess.check_output(command, shell=True, text=True) + device_num = device_num * 2 + + logger.info(f'comm profile done') + + +if __name__ == '__main__': + main() From d9f6fdb98a8dba1f6246c4db1117a37e73fc67c5 Mon Sep 17 00:00:00 2001 From: "Xin Ji (CSI Interfusion Co Ltd)" Date: Mon, 29 Apr 2024 06:10:50 +0000 Subject: [PATCH 1628/1892] Merged PR 2115: add new support operators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix and support operators: Find unknown pytorch operation: torch.nn.functional.unfold Find unknown pytorch operation: torch.gather Find unknown pytorch operation: torch.ceil Find unknown pytorch operation: torch.sign Find unknown pytorch operation: torch.sigmoid --- nnscaler/graph/function/function.py | 82 ++++++++++++++++++++++++++ nnscaler/graph/parser/fx/mapping.py | 8 +++ tests/graph/function/test_functions.py | 66 ++++++++++++++++++++- 3 files changed, 155 insertions(+), 1 deletion(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 6e86de91..ba4a0921 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -2517,3 +2517,85 @@ def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: annos = [f'n (g {iC}) {iW}, (g {oC//groups}) {iC} {kW}, (g {oC//groups}) -> n (g {oC//groups}) {oW}'] return IRDimops(Conv1D, 'conv1d', signature, annos, [input, weight, bias] if bias is not None else [input, weight], rules, stride=stride, padding=padding, dilation=dilation, groups=ori_groups) + + +def Gather(input: IRTensor, dim, index: IRTensor, sparse_grad=False, out=None, signature=None): + """ + torch.gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + """ + if not (0 <= dim < len(input.shape)): + raise ValueError(f"Dimension {dim} is out of bounds for input with {len(input.shape)} dimensions.") + if len(input.shape) != len(index.shape): + raise ValueError("The dimensions of 'input' and 'index' must be the same.") + for i, (dim_input, dim_index) in enumerate(zip(input.shape, index.shape)): + if i != dim and dim_index > dim_input: + raise ValueError(f"Index size {dim_index} at dimension {i} exceeds input size {dim_input} at the same dimension.") + gener = iter(string.ascii_lowercase) + input_anno = ShapeAnno.create_shape_str(input.shape, iterator=gener) + index_anno = ShapeAnno.create_shape_str(index.shape, iterator=gener) + for i, (dim_input, dim_index) in enumerate(zip(input.shape, index.shape)): + if dim_input != dim_index: + input_anno[i] += '^' + index_anno[i] += '^' + elif i == dim: + index_anno[i] = input_anno[i] + input_anno[i] += '^' + index_anno[i] += '^' + else: + # TODO: Currently, this only works in static cases. + # When dynamic shape is enabled, this partition may be incorrect. + # We keep the partition here for now, and consider reporting errors that cannot be partitioned at run time in future. + index_anno[i] = input_anno[i] + anno = OpAnno.create_op_str([input_anno, index_anno], [index_anno]) + return IRDimops(Gather, 'gather', signature, [anno], [input, index], dim=dim) + + +def Ceil(input: IRTensor, out=None, signature=None): + """ + # torch.ceil(input, *, out=None) → Tensor + """ + if out is not None: + raise ValueError("Expected 'out' to be None") + annos = ['* -> *'] + return IRDimops(Ceil, 'ceil', signature, annos, [input]) + + +def Sign(input: IRTensor, out=None, signature=None): + """ + torch.sign(input, *, out=None) → Tensor + """ + if out is not None: + raise ValueError("Expected 'out' to be None") + annos = ['* -> *'] + return IRDimops(Sign, 'sign', signature, annos, [input]) + + +def Unfold(input: IRTensor, kernel_size, dilation=1, padding=0, stride=1, signature=None): + """ + Extracts sliding local blocks from a batched input tensor. + torch.nn.functional.unfold(input, kernel_size, dilation=1, padding=0, stride=1) + """ + if not isinstance(input, IRTensor) or len(input.shape) != 4: + raise ValueError("Input must be an IRTensor with 4 dimensions, [N, C, H, W].") + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + dilation = (dilation, dilation) if isinstance(dilation, int) else dilation + padding = (padding, padding) if isinstance(padding, int) else padding + stride = (stride, stride) if isinstance(stride, int) else stride + N, C, H, W = input.shape + H_out = (H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1 + W_out = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1 + L = H_out * W_out + kernel_area = kernel_size[0] * kernel_size[1] + anno = f'N C {H} {W} -> N (C {kernel_area}) {L}' + return IRDimops(Unfold, 'unfold', signature, [anno], [input], kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) + + +def Sigmoid(input, *, out=None, signature=None): + ''' + torch.sigmoid(input, *, out=None) → Tensor + ''' + if out is not None: + raise ValueError("Expected 'out' to be None") + annos = ['* -> *'] + return IRDimops(Sigmoid, 'sigmoid', signature, annos, [input]) diff --git a/nnscaler/graph/parser/fx/mapping.py b/nnscaler/graph/parser/fx/mapping.py index ac4b8e5e..a8efb69b 100644 --- a/nnscaler/graph/parser/fx/mapping.py +++ b/nnscaler/graph/parser/fx/mapping.py @@ -72,6 +72,9 @@ def exist(signature: str) -> bool: __ttemplate('squeeze'): function.Squeeze, __ttemplate('unsqueeze'): function.Unsqueeze, __tttemplate('type_as'): function.TypeAs, + __ttemplate('gather'): function.Gather, + __ttemplate('ceil'): function.Ceil, + __ttemplate('sign'): function.Sign, __ttemplate('triu'): function.Triu, __ttemplate('tril'): function.Tril, __ftemplate('relu'): function.ReLU, @@ -95,6 +98,10 @@ def exist(signature: str) -> bool: __ttemplate('masked_fill'): function.MaskedFill, __tttemplate('masked_fill_'): function.MaskedFill, __ttemplate('cumsum'): function.CumSum, + __ttemplate('sigmoid'): function.Sigmoid, + __tttemplate('sigmoid'): function.Sigmoid, + __ftemplate('sigmoid') : function.Sigmoid, + __fcntemplate('sigmoid') : function.Sigmoid, __ttemplate('tanh'): function.Tanh, __ftemplate('softmax') : function.Softmax, __ttemplate('softmax'): function.Softmax, @@ -120,6 +127,7 @@ def exist(signature: str) -> bool: '_operator.invert': function.BitwiseNot, __ftemplate('embedding'): function.Embedding, 'torch.functional.einsum': function.EinSum, + __ftemplate('unfold'): function.Unfold, __ftemplate('nll_loss') : function.NLLLoss, 'torch.functional.norm': function.Norm, __ftemplate('layer_norm'): function.LayerNorm, diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index b3588f16..aa82bf36 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -614,4 +614,68 @@ def test_Flatten(): op = F.Flatten(IRTensor([2,3,4,5]), 1) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c^ d^ -> a (b^ c^ d^)' op = F.Flatten(IRTensor([2,3,4,5]), end_dim = 2) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ d -> (a^ b^ c^) d' \ No newline at end of file + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ d -> (a^ b^ c^) d' + + +def test_Gather(): + op = F.Gather(IRTensor([2, 5, 3]), 2, IRTensor([2, 5, 1])) + expected_annotation = 'a b c^, a b f^ -> a b f^' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + op = F.Gather(IRTensor([2, 5, 3]), 2, IRTensor([2, 5, 3])) + expected_annotation = 'a b c^, a b c^ -> a b c^' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + op = F.Gather(IRTensor([2, 5, 3]), 2, IRTensor([2, 4, 3])) + expected_annotation = 'a b^ c^, a e^ c^ -> a e^ c^' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + op = F.Gather(IRTensor([2, 5, 3]), 2, IRTensor([1, 3, 1])) + expected_annotation = 'a^ b^ c^, d^ e^ f^ -> d^ e^ f^' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + op = F.Gather(IRTensor([2, 5, 3]), 1, IRTensor([2, 2, 3])) + expected_annotation = 'a b^ c, a e^ c -> a e^ c' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + op = F.Gather(IRTensor([2, 5, 3]), 0, IRTensor([1, 5, 3])) + expected_annotation = 'a^ b c, d^ b c -> d^ b c' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + op = F.Gather(IRTensor([2, 3]), 1, IRTensor([2, 1])) + expected_annotation = 'a b^, a d^ -> a d^' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + op = F.Gather(IRTensor([2, 3]), 1, IRTensor([1, 1])) + expected_annotation = 'a^ b^, c^ d^ -> c^ d^' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + + +def test_Ceil(): + input_tensor = IRTensor([2, 3]) + op = F.Ceil(input_tensor) + expected_annotation = '* -> *' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Ceil." + input_tensor = IRTensor([2, 3, 4]) + op = F.Ceil(input_tensor) + expected_annotation = '* -> *' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Ceil." + + +def test_Sign(): + input_tensor = IRTensor([2, 3]) + op = F.Sign(input_tensor) + expected_annotation = '* -> *' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Sign." + input_tensor = IRTensor([2, 3, 4]) + op = F.Sign(input_tensor) + expected_annotation = '* -> *' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Sign." + + +def test_Unfold(): + input_tensor = IRTensor([2, 3, 32, 32]) + kernel_size = (3, 3) + stride = (2, 2) + padding = (1, 1) + dilation = (1, 1) + op = F.Unfold(input_tensor, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'N C 32 32 -> N (C 9) 256' + + +def test_Sigmoid(): + op = F.Sigmoid(IRTensor([2, 3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ No newline at end of file From 0430d5b52ac1283add2078323f76aaeac5e5253b Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 7 May 2024 01:33:12 +0000 Subject: [PATCH 1629/1892] Merged PR 2132: Fix autodist tests - clear nnscaler environment correctly - enforce re_profile=True when searching plans --- autodist/.pre-commit-config.yaml | 32 ---------- autodist/.style.yapf | 2 - autodist/docs/descs.py | 58 ------------------ autodist/docs/images/arch.png | Bin 49324 -> 0 bytes .../autodist}/interface_design.md | 0 .../solver_interface/partition_constraint.md | 0 .../solver_interface/pc_examples/moe_pc.yaml | 0 .../pc_examples/retnet_dp2_pc.yaml | 0 .../pc_examples/retnet_hybrid2_pc.yaml | 0 .../pc_examples/retnet_mp2_pc.yaml | 0 nnscaler/autodist/apis.py | 2 + tests/autodist/pas/all_replicated_pp.json | 8 +-- .../pas/replicated_and_partition.json | 8 +-- .../pas/test_shared_param_pipeline.py | 19 +++--- tests/autodist/spmd_solver/test_follow.py | 4 +- .../spmd_solver/test_partition_constraint.py | 2 +- 16 files changed, 23 insertions(+), 112 deletions(-) delete mode 100644 autodist/.pre-commit-config.yaml delete mode 100644 autodist/.style.yapf delete mode 100644 autodist/docs/descs.py delete mode 100644 autodist/docs/images/arch.png rename {autodist/docs => docs/autodist}/interface_design.md (100%) rename {autodist/docs => docs/autodist}/solver_interface/partition_constraint.md (100%) rename {autodist/docs => docs/autodist}/solver_interface/pc_examples/moe_pc.yaml (100%) rename {autodist/docs => docs/autodist}/solver_interface/pc_examples/retnet_dp2_pc.yaml (100%) rename {autodist/docs => docs/autodist}/solver_interface/pc_examples/retnet_hybrid2_pc.yaml (100%) rename {autodist/docs => docs/autodist}/solver_interface/pc_examples/retnet_mp2_pc.yaml (100%) diff --git a/autodist/.pre-commit-config.yaml b/autodist/.pre-commit-config.yaml deleted file mode 100644 index f332baa8..00000000 --- a/autodist/.pre-commit-config.yaml +++ /dev/null @@ -1,32 +0,0 @@ -# File introduces automated checks triggered on git events -# to enable run `pip install pre-commit && pre-commit install` - -repos: - - repo: local - hooks: - - id: yapf - name: yapf - language: python - entry: yapf - args: [-i, -vv] - types: [python] - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 - hooks: - - id: trailing-whitespace - - id: check-docstring-first - - id: check-json - - id: check-added-large-files - - id: check-yaml - - id: debug-statements - - id: requirements-txt-fixer - - id: check-merge-conflict - - id: double-quote-string-fixer - - id: end-of-file-fixer - - repo: meta - hooks: - - id: check-useless-excludes - - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v15.0.7 - hooks: - - id: clang-format diff --git a/autodist/.style.yapf b/autodist/.style.yapf deleted file mode 100644 index 0e9640c2..00000000 --- a/autodist/.style.yapf +++ /dev/null @@ -1,2 +0,0 @@ -[style] -based_on_style = google diff --git a/autodist/docs/descs.py b/autodist/docs/descs.py deleted file mode 100644 index 3567c2d0..00000000 --- a/autodist/docs/descs.py +++ /dev/null @@ -1,58 +0,0 @@ -@dataclass -class AutoDistConfig: - recompute: bool = False - mem_granularity_mb: bool = 1 - - -@dataclass -class NodePartitionDesc: - # list element: (idx, dim, num), the order matters - desc: List[Tuple[int, int, int]] - - -@dataclass -class DeviceDesc: - dev_num: int - peak_mem_gb: int = 30 - connection: str = 'NV3' - - -@dataclass -class TensorParallelDesc: - partition_descs: List[NodePartitionDesc] - recompute_groups: List[List[int]] - logical_desc: DeviceDesc - - -@dataclass -class ParallelDesc: - stages: List[Tuple[TensorParallelDesc, DeviceDesc]] - - -class TensorParallelDPSolver: - - # resource is a logical mesh - def __init__(graph: IRGraph, resource: DeviceDesc, config: AutoDistConfig): - pass - - def solver(): - pass - - # temp design - def get_optimal_plan( - start_desc: NodePartitionDesc, end_desc: NodePartitionDesc - ) -> Tuple[TensorParallelDesc, float, int]: - pass - - -class PipelineDPSolver: - - # resource is a physical mesh - def __init__(graph: IRGraph, resource: DeviceDesc, config: AutoDistConfig): - pass - - def solver(): - pass - - def get_optimal_plan() -> Tuple[ParallelDesc, float]: - pass diff --git a/autodist/docs/images/arch.png b/autodist/docs/images/arch.png deleted file mode 100644 index abba4589abb61eff4bd984bd9dad0c093d0844d3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 49324 zcmc$`1yEjHvo1(Nkl^kP!QGuekl?}HgS)%CySoKUit z&Yg2-S5XwOS$nP3tGl1>=h<)Q4;e9d7#tWdFfe%W@51t6VBkAoU>^>kA%Wk_ghgBc z-#*yOiwS{Mj^iBxKYTJ4lokX7tBHnv)`I|k{%rGI%^nO4q38YU!+>@1PcX1Qg5tu0 ziZ0rxYfzeKVz`ih4`6awYwoqHaevq7_gLkHe#x$J4#kl~VGHd?k){o8Ul2A3t62!4 z41|&id)v`gsit&&Bv@x#;(=!vDII^OfGd zko@bS(PNbP*KdCs!u5xS{BxB@ba5j8bHyeyVNv|+&RS;3oqrF|ubc6&I|qlcB2fQz zs~Wby?Y{>#_%sE}YN*`#18x#XA2}gKm6`CH94mVl1iBj=oew`j{y17 z7zI0B%Pqj?`jhU+L{9AL~rJpW6$aTiDKwB7MgO-G796FKC+6laMeO5>K+1 z+>e^NEnGYx#O|#Io&>xGTIt+vx78W7os&bg$3x0y(uthMHH&$=WTo?e!tw=tw%Yqq zr$q<{x{~x*FXLN|6adpcCk4zOHlX{APCh>ZWh^GvyjvBP)u|a)%}p#h zuSt2R4goKK6EI{Ry(1VWpooUxrGej)ftNV(B?~W8Qg^bl*Bv=fLz`u)Q`~fSTmyy; zx1n1sQ$R4dG4e-+3;izrPSl6GcH4_YZ!8G`lIpd5v|2;D2!a!#*?M6q#G3v1kArFL zC>=LmFQg!8VM$jSa`VEmdkqybr(*B;EQY6avh>eQUii_b!G5|rXZGc{Tv0tM$p#iH z`jaIWMDNEQa;M7S_$QbUX&%z+CYjgEzSuKrDZPvDg) zJKHK}G;FjGa#jlwwMtn~sS-`T$52}}Y-`91F2?2w{YAdTS(oZTb~FB{5AFfOrkn{U zGiG^&*|agL%tfvXd*qXLC>7HBPI|su-InU9BX}x}SBUS%vrp>vzTP~dj)CrPGn=*s zsBi&me?fDa%>@s9Vwe3g!) zff+!|{KN87$p6Deo_%I%@mwBFD;xdbKT&BA&i`_^x++-NYMgOx$;U9Aprg5P1R*sx(QAimnW^L%gwh9-yX)h1A$HVi$`#t<5w;`eh-BELXlrk9ja_zl#uW*7~u$?iZ5bB zwz4kNM<3MH*-be68SZ-Ls>$q zoGPP8Uija1C=Q?OZy*{F$y9-`4T9i$jdf(YB(gO=TSbSqo&p=& zn!p9gu6|V=)zHTx+#a3pd+?h50e5=gtG6N5LtRDVB7P|bm6{e2m?UB2Uo69E6YeSSJtwLB(%Rb{C2A&N1 zlgN9DWek>WM+?pgJGF~0Z9IsRp2D!A#OdzC-B3~?rw$)Q&kRuLPHeuzC9a<9MT<#tHtw=!h+GcdO|J|CYpjrD*V{VEo z%N7U47Yk|ktsLKK^A1&B`56@EFdv=e@KbeD-0M5E^lx)Nk;&)lon!90WLR}FZYT_} zdiP8nMaUyX7nNURGGvYqWCn!3ecO2zUHk2~plhLg)v> zv$^Md^rhDFXxH$|jt65!x>zmaEX;=P77s)w+pkJgh$|si`O*Dt91fX>aazOQoE-b= zM}ksVm3DJuyOkL5FX+$uj=!?Wp?g|H6cZiD9jV^MM;l6Be3Q6(f^9#vMtOV5w!A@! zPe(^;G+L>cGQ_;4AV86wy+6m*lseqL8#wi(&kpIF6)@PH6X?w`TPmz&o&vEt<$bLa z?*>hkjpQU8FV+~W-eW~3=?DNmHg%#q;=*jo+>(1hVVJ}5pf`3!rta?Fq_?#^g5=8o zH5=hkdwP%PleSKSU=sudl5#P6hgLgg5rcMj>WjCo5GZbPZ7I+;6J#(}@o>KO3+jZi ze1H}2vz_oBKkSz46R*sfcEkiVfA4Wk_+M+3CKvwOR~-DOia&j{6@PoO%%NdE3k~jE zKe)TVZ{FzNoUtJ$%Ue#e(AhOvg}2Jr>$$mn;y0FG-$mqPi)*k~2 zYEy+|^fpW>SaG=!)1$S-eh8Cgt>5*9t?%h!I$$Y?G03=R!cx7W2rk~G}o9SDO zy4Z4ExpYPisd={w>wPr#ivp_JAlnsGiJa{8>>obmmSD09C;!K4%d=WYTczU-nH~NS zYiFW$`x+gz>Ac0s?zw>j(GYE`dj&np!{F^pM}3|&Bui=VYf@LGvPHbPW09&ow zr8^sc@%8g%w~A$}URNk1n%OO#@&#XHc=EXJ8zJgsM*)JfQ{=aUS94wE=lAf~-aebd ztMz??ZpS;|vtILG|Kdz>(t1~Z^b@~X4$6q_eh=rZf~j!%9)u8RJ){HeNoK3&t9YBb zyQoKIMkj2cxp5=p})_}V5_h$VCo)}6vJOOM&oA5W{#dHync%cx<5mAzweg; z_3~%ekq~i3dMkG)U8;ogA=Gyl9=hk5&6SEvB?@oRE+1(0>yWur9L6Zoup*rGRxyp$ z#{o`Q&>XX^lRKoS7Qz3ZD6NA1+&5vBTKK$AhBMYmvRCh-XS^Sh!>5k&EGq-D@-^~ex*^m6JfB|DO_1!SnhlC_ygOW zUY+bEFu<+Qd)>F;u(eN?7|DJSYW>z=5o9#-UOg0wSzyb4&?=Fg8XNn({fbepYCO+u z#Oc2$c~3msNZoqxO6X*EbuHzx+XCCMj&R{<*5E-%SorblaW8`u9Z6ao$N2Z*NeDIbwghzDeb5?a7uc4vizp-2zK zw~dte!fv0D^UZlmtCVE#3@&P*|KxsX^^oW7`o3|u2c>_UGaL%3H%nrCt?aTv0L{}3 zo>e0Y27BYT5Ngs~e;9#5|M7!1&t@6H^)H{yMMn~t<&Wnk3Zzap-Iy+sp4HmgOn^tj zWbf4+GbiK@Rqj|c0-rij@c!q!@(1*{l7_x&&+xw^M<$#c=Yi*iI(L5w!09tP{Y;3J z%4s@+3xe=QfwO+L#CADPgtnWNH2tZC-tOafCv@b8=MMX>iNpUZoO{7ZrPQDuP`Y} z;Gq`(xv}vovt|M*O_zv{Z`k<>5|u*cPr&CZR*Y=LdVf*jEj88IaIuymhhYpWg$Z)x z5XqS{)RZ49qMkbXqMx2DyH{%n4wpO7S8>?_G~#nroq5Q=rDgFT)l=@O;~*TDsRo^> z*nS(PX!SFUahP}GXE496|5}FQgd|VYX>UDhkMQmOqARyrsu(CXJ$?*)_?7Nny;Fb} znU3@j?Qtm`-X^V)@8q{RBLvdzYpa|gf8}Gfaj;8xWk5oYjnI&TOSXBHU=8TNqV7ok zc@$i02rJ&Q8lbUI`GpCI;64AW1a`T=7@DtvpTU?Q+umT7^U^DQl}4=A>&k8>zRI!8 zjv9QM)|m|`%GHyYG^tY5&n~jH`D%2egNz(yY=>tC^R1XSe+bSUwENN`{J=sP8P@G{ z7y>O!mO_#{qJ%Mb!Y($l{HV_#gVRmF)2SivsICIt7xu-hH5uVIo?&{9bUM3{f_c5! z4@mke`eWGKkmqk{1lygk_oIFmbhQgdp;H?k;ntmr6jTqko+9R??OIYF7P(443@e%k z6UQg%frfCQ;XH$EsIXWIbBMK!U|^N5QpoH0;!+z5zEy*%Hz23|bv%jRZIf>%h)9ZjGB1h98NvI1UA!Ug?O@Ky0P%DQ{CiBe+(!(2CXE!GOx*s9mlpZ&JOW1ziPtPf45pZZ4~3a8l#X@WW1(Y>v1U^2!2_V^X8R8X8qc z!4Io9i;{4QpT>VNIixzh>wuLGeKJv<#dusxml%QLtaGQR6j+PK2l@NrMNTc#bx5zo zfno@*+-gnnHg&5KyrA7OejZU227lkn!=PWPz21HAxmfefJIjw^%>DKbD|I4YU8;tT z|HaU=Wuu{ZG!J#r`Cw$52rtos5vXn8^LAdf~x7TuGr&(|IPmKq9LL#o`t%zKC z4m@lB3g^lFWO*W|=x$Fiw2Wd^1Kn2Gwi6DPwbq}|@$)q5l!-NT~KK5$p ztnRP-gXVa5PjHvGHseg9h5!EEE#vx4?|dXr%|+PPtX6nAac*tLFRN4a5d>3j=iNz> zE7kYJF%5d=gH1vM)u;RguG+Gn)8%dGN`w11qjv@fc(aQNnaA%R;IY`ED0K?U6=9Ix zfx;SAL?@)yk?d4O2YBz!SW3+e37kyiWSK^)us`1yk?_?q^&BH|B99GTscRX|!<3^} zcBw3HXt`xxlWS9&pGxv|IbQzM@_no*!OC?0It6jXA+1likT+SuZzvm@sSTgi2f_IQ z-SnrPcVev98WF$=@HBiU(mZ`ki%4QhPQQ*~sqHjS&pk< zy5<47E z;VrR#6f%kVEp6>Uo6?pouad{ne@6qqa%i9OqK9AqrjLVBr5+ln5R{{|_NLd;88grE zd1s|xdBc%%-jZ?oV2)(_;r=1#_8_IkfotD;0rvLf=Fi1MSbl+Begn2Cs5KHlKq9r< z)lAJc%J*O2s;k zeSA+p3pU7|1-hB3UB{EzWeXFo9a$-BC9u|o=Qh%(vy0q`uX`Fbm^%|^y$pIU}~(F9rf zuaH);e;nw11FI0o_AMMa*!it}D_1md;iBC3a}*15Vk%#r#u;r6*Jh_yAdf{!EFZE) zdZV#;y`0$2yvO$J7(XTBI5+aQE2b^%g@58ztKs2~ynEFf8q4rQ`?^cL%)1BuW&LvX z79@qBdZD#9y5P>iNcQ-C(#nVT4wPGZNi?S<=Jbwzzx!?@GW#PaL^%34#`clSmK<~Z zSG)1JbKmeCk*3gD;{!4#dUz3W?f4n~nnjV6KX4y&*)oFJyLBG4*jR>c5isP?%9j{+}L@DsDR8Y@49z-&v z_K(YbC$PN}f8s&)gY+Mv=%;O!q&)Ub@4lMt@YTGGaFaAiTdYq(S16cH+3tJQc>R#m z+nJT1ZNYSh6NA~PqJiPj!k98L^+^bi(}hC=by}0dr0t1BN)yOAAL?vpEPR(+W{)ps znv(UbOZ`sB)EbxXAayJ<|EsRo@OtOvNO`JLl{|NDIb@aajbxjhBB}76_>1dHp!N-i z>(V`L-^-q*2u5aRG2?zn8aXUVnSdQK_lo@XwPz>+IeqBOn?{7q2*W9gMVaX0ekJh} zE#H?cnpaw%32dM>o_F(|^i8%jrvNsU#aLPNLz18EKOJRHcacvITbo@o%>+Nc?w^YM zsFA@>n7ECTS<_pvU7UeB0B0Li`-Onhsk7T6hR-*E6;URJxc#sP#v9rwGXMPLN3I!q z^S;z&RCr{JvDGKRTSxW; zzG)eh3Pnuo3|GS$w?VX>h}bk$Vnlx~h!CW*l1nZTz$&Yz_Ml50ju>Eh_Ym4TOb%8XlDz@={WM&m?u@SzKQ3?&=Z*2#~>c z#qOS-ITc}RC`?RD6IPsCbYM1e%i=2kW_PB3`>f6WH?02uXCmppG=8e+_gPmG2{C%j zW}mTijjrRM(ApsYf9 zoDKQE>`thZDvGDdn!OMc6T3Y8C6-8_^c4z(p<3BJJuRpIhkneP?bXeogdbcwPq*&q z+8v6YOTerXW_)F~BkFLz&_<(CmXiL5t;?GQCO{cEs5a5l0}hKp`$Ib#n?>|<8|p~h z7g!if`tlBmj=z6xG&?-hJke#Xkgey-Q)@N%FD|0{emRmc2kV9d9l)jC=!MymyGg`nPMAkz&P?)3#LpM^}0WK zfX6~43rEZy#bpt$ziJ}A^jBP4~FgOrRSXCc00J&?vP^pxx9)uV2 z`}c1aq<>gc35xeK1a9~Bbxs$>bjIO$lpYG-8;#HNYZaFsYP-rFk^BAW z$Z81hr;4O8&||gPR^sb=gSpBc!4(HG8d|_F>2&7*7}WTEP!>B_SBtOf6&f7n+U@FE zam+Qo50^Uwi33SPzRk~AEdQK733oi{Ov=nSjiJ3ingoWPbEQfW3%9#r@V&AboXD6V zn6#=tv?Tu7FVjf69pR&dRc?By-5R>x&ex~1*4;e<0j~&h=`>1Tm&@(Gf3N(T1vZ7m z0l7>@G2AKnhC+wkp2QbM@7Tq!&wx#BeO-J%Ld+KaGbj+W3^OP*9GK zlXbl0lF5D&gd|hS`T73CJIisu+Z+60KiwMT)9w}D?mpJPHrbFCFd{QF7g08=X&g$- z{ppgzQ7t_z9Gvh@e`Mm+ZY1VPGneMXst57Md35C;7bayq{MF&oGpt7Bp8$-}Aq zL?6xl)Sn~wi_w|5lw7D=X*PHscxn7r2{cEAfjQ@i-cT)J@MWO#EKmGrX zm;WyZeGjLPtmx?Ifix^_nt@u5)9@xQrnHk;Z=f`W+-%fPAe0;Xg z54Kovkd6R{E_Xb0J1joUgtgtDBmFD=I2l6W(_sFF!w#e&mYl=ya)$(`FOQ&)n`qi3s4ve6N=$N|j2ytNpR?(Nvb6{(jMS zue_)QR1N__LEFm>=A5?MpFBK0QE#S`+{Zm%9}iMEZS(A2ztbth_5+g(l@)kF^Y!&L zgPE)~l!U`#1RG0ZGx#eFtXPQuakxoxsO!Ez5_hMSjfaQFV5!#d;bA8_aUtQgT(#yS z0Rh2q`xbMHNInqZC2ZDn1i=fLMF1L&`by-{@nl{PNSwoTGBvyNmz{5~xP=ml@(Q#g z)LqRUwWljh3=yLFSLz)C)?3|y#cFlGt^l0|B_F5fQ@vg6!(c3_84RQOJS`kHv#_Vg zk*}s>Dx==V4$wId8-&ShG4lRIHlm7(3R@uXPlmhKBNv@-QF=W=RN&zwcDsYVXDdxK zO|@$3rBFLpV3u#7jkj5qKft~vy0C9W**YGl(=Df!#F2HXJK zOmlxQv`aG&4uzn&55$P+I^YlqeHE^>Mc&8h3vR4WVopY~SN{IP`(X zWv8s#vybvPuSrmJZ*5uaV_-*04h~NKT?9Pckn9Tma#8nrz1`hbE1$}8P;1Hn{8BDo zlnF>2d5T{UKeYnvL-U8#UBUH6%avcF8bYP@5hEPGF{$K*v6;2K>IVxYlM}pwg)`$k zPEi<+B=uEjG$ouLCo|~mlTJ+gtg=-DVRYCdmD$kdc$O+k;G3|p^488yK5v+YGJp#9 z@P3BcO2`))dYrGeND$HEa#-Pk11h26YCL>=CfBQ9sB+&6;vl>h+XX>usL_JiFRr4|R8UcPtN-U=d*2S3L9`wF@!vwye02H>bW>Oc2DeS>cL-qw$= zFF?$sc|RKS%DPIzn0+f2-MI zC-SJ`RUg5d#|K1DbsI-VrGxDWY3uH742f=QnQ0;>DKw+iQ5h;^*C5(K=>_tBNoNq34P z*37e=bD9e1F-h-UE@j*ej}oOpB%3l9G}#Koa)%@=_2uuZi|ZYQdW8qGDqC zfIlzRo5rW5aqpT<-JUGi?hT_Cc62dEPSdn#Z*Ok;_(8x@@8x7>X10=@@p(QemXvL@ zyJK!`ZJANa9Wzg$R#sL@@O4e^wpNCRhcj>dcDAO~YE?}2wzOlk5}%G5>1vPhvt{dJ zCg61;0)z@A)GC#<1%&F42r?PdX0=UCm~k?yM+*rD0A-j`nR4jh;P>0&#!3JZUtUiLgr~8Y!&E4z?>gCS|Ev%n zM~%egz@SksfA78m6WQP0;l1N20EYGSJmSYn_HF|(Cry{^X_WnTLe?*oFA^cZ%!|cN zS0*MVWM`{%1co;JP>brbcH=LKtoZp9s*%86NvhsIk$VMNM#{l zG0?o9FN5v2R&jPkLjkBP-rU?w*20&;S$LZOV7wZqGxhM@&yr9(N3H3ttJ3A3pgT_2 ztAezTKgTn$&(F`_v(L1a>Urz0g_F5L;SvgJY8ZJ49j=c1vlS>E9UVvU2p~Tg+?_62 zJKo;{&)WnRtl|}}YM1QEB|Ge@ecl+K)jBk_GNQ{!xt3(T)*6^E7Of9N0cd!5rR0zA z@$l*PB+<$&sn_b*l(eC>q)M{|1xOZBHX|A}`hAn-YAlL?O` zUi~dl(uM-Z`b)E=5R8O~i%auUR8z3Wm&*}`&WIIRSF_pT*cceDGA&mMH*6{-O#6dz zbS}Hy5FXLJ>wW-R%{r?>YV^;DngO|?{aim9-8Js5QCnM^c80R2)$678@87@QCq*=2 zLi*wl56;=STKLm)Tu$eNL&KdKiN+*%qL(X~W zZePFpOlJ}~iLI@!%D?b`&0A;yYQITLpIOTvE3gy-qPp$D zIR4%(6ro3sgTQD!Mf6n~WxK7teNXPU;1~uS{=Hj5@8?_xSa*#KPTP+<9+!d43`!w2 z04^+)?BLb^BIia5u`$d9Knaz-Tj2Ek9iqs#5w-}kKVRkYxLYy?+GrsLfS+QqTVMf` z2KWy<0)7vyFA@nqv`20a-^-Qjr0uP(+=wLlGM&y2JCOENo=Dyys?Y#$3KF>z|c&}sC~mK&^Z`I>K>D7D)-8!eVloXfp#=Tsz+Vy1>L z2X;YJdmwwmE503|9ip-yI!%))6AT>bT$xPn;S*A^DBZ`ayQDY_3yzUG{tL*QU&2>IsNtPmT3;q;WDpl-yudrNQZGv0u52&xP zQ$6%f0Y%Sn8*h#rFm!@z-i^$8AK+GD&9FbO(El!mOWd5tL0!j!e^vu_km=_jl3~`N zC0WzQ{>IlKf3$JH1ya2g09IM?Am`@OXc}ADQy8UZJnjS#Ve>{u{E^?GV5+R-`;;^o zcFzr%U4Qlic!ug)7`qQZsyi|gSBGKP?dIjK&hC!^<5n37u|Pe;eEu=+*#V(N1E+mO z_1DyIqdz_$C{?R-m>D?i0)vBTuB%Rl6KVclUCHm`vY7?Em&?F_BAs3@339K7c8G|s zH@Ec(m4K!AJ5_<0rF0O&jL?#<48T2L^{e3516PKub{ylGJkHeDi8Lz8>aU>y{;DCZ z8J%FdIcG|hT5NSJV`k7Pj@9T7U3ao$Aa%=PXxV#5b=| zZHp|IryEmE>yrAoO25m+1NITE8P~Xih~eifvz8_?QAZ&BHPam0VLIkTeD4S$Lp&xQ z8GMXgU-f=_3C_tOGSsFoTB4S@fRUV+)IT|9c6N4ljG{CUmvy{#K#v&F()kHs#Qhx0 z1ic|}#Qre8Uo0fomK+-tc@e?{-UUF%2WUu-N8jwf@6>SIZQSC-!9+IUQk^ky6FE%1oJ-C|xTbzb^Pu74e?Svd^OfNCfVK*GTUs(9wNU#{EIg=MVmr&_) z={M?~!5yeEp3!+QSi?Bwb48V<`*9G?%?hJrwWf@hFtD1*R7;N4mO&R0M`99in8$z`8=>*s##9?DEj>h}`eZGN)xVBRD1 zkkJxe-Q1R#l=2o5B3~7S%5YlBG5^*AF`#e)8xj&1+SLq!WoI@O>#x{f z^G}$+G&=R4>%T$QRs~HgJZ(N~ovHFS(uE!Ej>U9nrRcwY8IH-bnX#wluaHA>m7Eva zVJFNt0AUoXlptN9l|}Q{A#A?lLhR`^?>VDUJo_7{*C~WNMjvYF_RghB$E;0#KtQNU z!;5X`ispQa>I-Km+ffZTqF07DiHhcF`1ae@W0e@WkPdRk`gjT1W1$1QW5X6dwba9} z$}1!M(f-$VdPl`3jwgz@#VckKf1*idSe{2}+o0A1UjoC=`N9nc{BVJasI&(4ZTF1? zXOl>7FKCLa52-kCf0{B+a?C0)4woyKU?jT0{-J!?E%G~m=hs-E7%sg(@x4zy@D~(N;<(e&hv{4jztaQ=CACRw?p;YgHyV|FZxF5!pM-2^`hV;MWxN%sq zpyhmoYi0gcqql+c{I?Ja((4X_oPIe%L1m)L+=5i~d7QNf+_-em9|jXt6kW5gwNwATe%-??#9aNuJ=?P8y7MvPayiU$B$HX8S~t9ohTP za$@`3xubfk=7~1={B+;>Sud_GXWH+Dn{H>+TlnH5-IY^H;2NlH>#Rb?9V1Ut?ntwe z#wXg?wZF*y!vDaJp7N$rE5=<8e?UKDa{IvBmSU5-5sp3|P&OCB17_+C;YqIA4psDZ zDp!O}#Gl<&mHG0cMT)(SkPGI+0m9DTz+j1FdxIQ8_ECaxJ%3)t=iiTuZp1BX8t^VI zm(U1V%~3^W61^1)i=A&+R)e2OKHiVx{UFf4Xn#Yq66tL_2+TY0zWBrc@QhjP(aH1| z#Uo)N>rZ5LN>f|HCGoh0R>XxEeBYl}0<60+*7B=XlqO3iDd-j_JbtJbn}e3K*rZ27 z|B2jdml$D&)yWfz0z+fu7?}3CLww!Vs61KEmSZ27h@%=&i01*g@QIpFn)VawI%JAnfOIHI)C1;o^qxMZoe2 zTmXckp!~Es56k>ice1Yf*h2|NzzF6qCf%Pm<^|1q!Jn<$4!J7O3Mt^R<@@2=fc|>a z8^1Qv$_7$ib=$R5-=8585)wPS{A)ha?Ci{`y{D4{5Pzt0DMCbKINi@R{T zDg!J6?cP{_3LCyl%)vONqENlprdSy34DEZCemLxqiH;1vME)jKE0n3S&#?AD1-cyc zPYaDKyS=^rY6+->i>Zz>q6-9Q5nWmMN2`j=)I=O#HD~b@J8c;7n zFDjyMJf^-eAE^LU`s&4Gp(G|7+fFKgepbC)TwO~_u%owz9poltNLc}iWDP~UGhCqc zIiMpTK2@hf6*N<)?dcoJOM(xO4rJO+X=@F!%yZ_qT(vgjH@&_Idz_S1{qh!AiKU60A$gR52ev)A*V3&;fVLX!6R!X8Knn-jGw>o zS>XRF{dV(L{&`ZXtxiB3HD*Euvtqt|N9QzsSR1-*a1VQ3Dqe8^Lr$m)wU2*=vW=Vr z6~bbSzi2A-^p+y*o9rwHO~m%ND5Bhu8ubpVo8HAX37%y^tuVDgsh2Z#Wd};U2$W>3gv(zT4Eu!HYZ}skS znV5N~XMO&Uv}wy9=9_r58TWkLNz>r1yf$xU^gypO8Ul7-@5kB0z!OoNqXyrHbtN#u$4vGM3Kk8AJJ}l;ZN`O0K^>&}wb?u!X*@c^ z-hiBYo#+vt!7%apU=nHUga=%XqT2pItB!#4T}B?Ba_S7mbN6M?v$=kUP%~maUwFqz z_l{J8)Q*eHulnS~R~304P2ij}(!N8ZdNu#oblu);$Bc~`JIXaXVb5(2v@m}={04fv z+&VpUa8XqMw&cDeOv@^YOT8&HKY=$!Wp>pc{!UoVjwAPa`r^JF@OUQu^~VS0-H`T_ zn@6FCP#moG0}MYuzrU;x+#br2qASefA)Pg1S{MDqhksv`6bY`OqB3~Es~)vB%|M?h zj9hZAp(-X1toit}+EI(f2?zRZ&j(2f{@6cbJ=+bG%M*#K4#bk#^N`%hbRWRJ5Sc~v zXK4sdgR)2-S`n2at-_hdjhcx0?jhpP=1dibH}?OZRPyc%qMM|733C2+;jT^U{f3aX zXZKm%Z1$r{x{;OyaxXKX+ZXoYtnp%Q;lzEn)1Z11m20(ahmhf9^;RTB=VPB@v9C#*ur zozaT~_j{1Gj&6LWsoDw89-Z~**x*1nAl_u_UQ0=j^vvIK=8@#Q+CA{E^qL5Li(%W| zG$3SWkiuZwo8P|5?~|*W*Y|&Okdu{INXDVpCIuDRy|RS)7*+(v-*^=%OYV3>?Km>c zjLDcdxaG!Q>Av!>>~yZ!C{fE67(T^qU%y~~x&H!ik=PqNYs{6q-X&hKCACxN>X>uVOvf ziTjku$t9?_4cYha6}PWRGREO*Hw4KaSRVc0Ja!1(gG$tpX$&9tbQ~8~5Wk$fl*<#V zA~h4+Bo}wrr6tv+e8$_ZHg5$|{EHJ2KghM~M3_}m-9Z!<(-*7?-MGb~(C-+gkS?Pq7i z6%$J)RHEKXZoI@k-*8RY%sTIlYr^N=9ZbhfWrKXE&a|7Eq*A1~N z$@zOsPIZlq=d~nqMb5WHw(^ieI-Fsh=mR>>p2>HBfnPHUFeLQAR97~}AfeG6!NF@b+Gz0MIS!>LA=39RBT28WuFK3%ev$DZs%dq8x04EmPdM5VqU<6uN4fk9_saVt+ruf#}7GHt?$ zPg+TP8WLtWdX>eRSn`0pujNs(dawQ{n_dK=K3DDUQcK8lPu(Ir?B#AlNsDS${{%xg zYwL5gPjJdvvdjPJwGMcxLVeeA8E4c>*i<@PsIoBDC1X+b%BMPgKPpgc;m%}l9U6Wu zP4}>Mh3VkFfS@XGMk3c-2oJ<4$B8#(wy8}5^tCxVJ>|g$WuI#UCkm(xw`FM58^Q7h~w75LfoY1 zmiW;nSwo!2-&yZcW5xMtu;D_Y#B1@i$T^cS*L*5V!bO-vc2;R%)cE z8E@G>pMJy5z z$42_zGN_(;kfwNm7ZFNrg!{be7;CfWSo^`A-A`N7UhVpGh>o1=cFIm&`GUyC+dP@rr)H*f) zcIIxq&4~PD&{Np>(<+$VUHT^dBy?oqtxpG)C$D@>>ojv6p3hh(?1(BRVWE5m3Avpr ziK0a0B!YIF$yxfyF;pg)L&IL40U13~A9{C?!UvV_P^25A)(%e3U+br;ku4G#T-4jLra<2=1D zNBFHMHrDEpPlw($<|CpODmD+4(QTSs%NE@5HRyEV>ovaOAvKRs)))t2H;3wcnXiTEAor6o;~ zK5b|%*4pCg*Xo!vf{+V!Vk!7)1pYvK7V@WX=0Mu1lRbHT&~j3II>aqsW5GPLGx>Wn z>Eu$l+hHLozVCEctI{pDQQnu;au)OL8ZEfdeJd~>)ar;slE=}j`z)P8pnEV8e|gYy zKxk~#{DeDC03AGuDOk}eGN-6?BR+cg3emU`&TuY`*r?yN{H8m(|1}@OQKw3yc>Y6F zYN5}c2k0hiiYoJG^O=3oIt;eLGq`~c*O^WH@?SJk&7GFa`-EYB?6SBX%hY9Q^fUPu zC;K)HOq*A&pC9%TkjF1Lx|f?hj`KWX4sU(eAPBmXnSr$0mxt~23{g+~X*ZeB5d(pt$aLxd(ou;j*A?#ZX4v0FPCnM%mo6BvH zTcu+rEVu_i_%N%Ipl#yNnsPoFRQv#HGwkYE^x8bf)xlzgcg+)}TXS=uCtcF!LP zkZPPS5hb_gk%GRt7wc{^-AWjPV6q+wz_dCyUoZ_}uBBY%x#PeXOYx&(!7z9U1u;;u zz(ISu5DFrj+I%>t13nNzI^}$4@c3e~MdGuf>4EQMUDDQqj*9Mh^Sa66ru`|*bgu5e zXDl^}DK;CARtH2w&WyMLMS&L-;vx3Fl3t8Gh0AmnOn88Gfgkec7;k2wyc} zG38w*FC9CfY4CHG&3{XLlZ z8;1T96hsapPpBH9vglY;5AaDIFwajTD9^l&gz5Swf4IMrmnr zzANXJmX>fvP>qS{-E`3iQ^ptuy-l4x2REj8%Z%V_xQ%^n=1$d#;I|w|wOcHT z)C#DX1EYNxE4(%C*AfvI?0DuEFh9+w5v0o%g~_aWgsQjdBq$Z?So%j_@K<=94X%eT z57|BF&ciHBrYnaek5Q|1igH`+?V2L&bk!m6RsUm zTXzuWrQT9`w(<$JUk4REl&3@vYTCzf8r)^rK5{iI zC_5wfowC}V%%flaNN!J?yV5C$#?l&+J7jf>Z8)NhWdvEGzl0m2s8>V|zFabpfE#s^ zwMR~1b6~;c0aWX7qf|D(mb5hA)djh?3ZOyi1K8j>#S_=EY!)j{ieIy~z=trD>YrYI z3JfFDd-dj*#OZ54)A>JGJIkQB+OEweA-F?uhv4oG2^u^=aEIV7!QG*O4({#_!QI^n z?(Xi+9G>@`neQ8^si~PN`VZCB)kph2d#`KlP@aa(A@=YYiI@{E#UmZTZ=Y&x(Msc% z6jy0W=?i%R2A#bcOp-+@k7&YB#iuOLJgh4B_73P`u`0JgY^E$BPl)0nrFEO|;P_I` zt1&qd6jcXH}tN^y16^ot)9QO*x10s@!e@ce^zVA2eTlCcttj{)H-d z(r_7NJfhM6_i*E`gKT3X`^PH$ zFnfsc;W(^3T*MU_wdWe87SR~)`ur6~+tp>XUhE~kngtX|G*70(4MZIPoIy-Rh7cMW zn$PV9LE-BgELL2I^Y!ydI=^pMe))!8J*KRSWTtv5E>$G^i$rp0!85jcEwj;TQBa}^ zA2Zt{!QII`_|ivxi%ts<(Hx_8_={oY%h|jKO8f_(Bg4Z?(__*?zIW;$KRh9AtnDAd zu>2IeA;pl-mvz~46}K*mqA3M>Pdj1|)YyIhVE$9*2)Co@eFT;yKd4gS`#I>&WEro1`o*?)&_7i0eJ zHVIP41zw_{4cU>X+5gK=@p)Gn`KE7X$lkqAw={MFPsequF)rL*3k$6;V8wE6egJU% z0@sU*f|8M^9zS97L%ln9O~Hj9K(4|F`PCEU=;h_*E2==$dKX&*egFnbv+HIa`Dz51 zmjDpOVxd|>MkW!UK{Ca`gl|cYYykTkd|GV3y26ccCg|=V=D$#~3Z{s)Ba=tmx=s|b zIo|t<|9DE6MhXwtrh;`%vs1ZuPoWmoK#=#_oXhnL`=#mjQ@;ZC>2zsOsEfj%8f)W3 zvqBz6pqFj?{P62r2Ui4d@dxW(awaW1`SM;l`WFEs6CT+@By;rJ@07__T zN(CMk)(0ddjfDgJ=Z(p`pUv~%!o0Go3CsVkIpi$Nd}-NGA~IHKG~`@;N+mCgoLN~& zo~K!T6gyv6W|*>~X013>zeEbacenM=k(B1OVSTa+Ul@S=K?Z}2hxM-^u1Jh#e)>&0 z-~njcKUIv(0%ihYSvz8rjYdOjJ~ ziKHG`+-HPZme+!y(YvYEXqa*vhHj{vw@0;|cqR1URsEF#D|pY6ye|0mU|Py_eFUJx zihdMkF;tcT+$U|zhb_1MTyISoa#Bf;UUOPz+jAEYGx*HK=G;A%llv88q+H&$zF&fy zn-*rigEp_9I7W%Uk%S%-O=HAiDfK8CV~`~K8W~+{ZN@kh;rPJ<9#Fb>_1t%lTYd?T z6Y0zCwv3GU=U${OC@^jIXxeFMXG@y_*|S|+PR*)6-;EOT`P(CJvE!VdP4NsrxiURS zF0z{S>Tq3iozIAntd$_UkA5UpaD9F58e%%B@;t-l$8tMOYMb?b%;(U5xr;p8Tg3B# z3h(X?(6D8ons^|!UfqB1+NlOcNjOKi@_Sqje-97BZUEQ0PoAR}Ii9;CNMN?AeB-DSCFmbB6JGcMM|J>xIk zqW&R$nKl~I;_C48!%($!%ay;k~b$4N1zGDMO- zlgoXNEi30NSGPrsDtXukRfw(hv`S^&Yn^H4bsQ(x5~AYvQEuF+ z=eoYnm4A^Hwktjj!Ys^Zwv?F9727`}Iq2+g4lGZS3{_hLZO}0(dUqbJFfJU_Jkr+N z*#oL{9zdhRR}j}40>SJHl36jF`b7L?$hh+sJ-ccOgNIhuvys6P(a{LfBZRk^KEg4$ z*wZF8p%g>b3@|=XJRh5vx^oJ;db!=ML=jrM7sb2Mv_o4D&@ec{wN7a{0K*5(ZU&9! z>fu#jt$Iyj(U$d!BNXENG&85Aez%b#H&|8-_TCLc}eQGU$hBP%~hmv>w_SNsje?0ywO}9v}PNoIzSt zc{Apv?&)My1*T%*NCom}&TpCLG;v;SpW?F!zPl}nJ+Z%d73h>%4#!-hy zLLyf_=Lc|Q@|COxYf#Frm&t=ekKHo0PZF#LV#+Za-7gV&F#mo%RDdPAItoWpvKJ3VYvYV~kk_q&a{#CJs(^0OGrS1&NrigVvh8qaC-=$hUc zWLhQE*>N3ghr<>egK9gwMaks5I6jHTm3A|4gEm`CbXJt!awi|reP4J#Sf6^#)iHYV z*z6GM^YjF*@3cLl(=G0{5cNH!_qwM9v!u*@(exzs<3qcMAmB}d?=}SJV=9F-JJlsQb8bJCamD|4ql7%Lthc+3mVU zTIGtYx>ye)e4Q@>A67w8=|O+*e$!;i@rMrWKfV>_rdj9~+RtWMGZ^CWgyHLS$3xp5 zZW%{uJqc$vvzYh@)1=_XzJscl&Hd2XuE&61OxnCYo5<+wm6Q#Prwl;S3Fm7JW1jnW zWGgGgSUBsWsh?}=OEK^&U7~cyGh7{i0<93g=l`WkkxRXVLCq!QMmYt32YtUZ1Hp!a zN$hDLl*(cwJ0h_{VQ=?^!AUKi2$2kWD;$?y|BI|hp|ClGMvGVyYH%2g3KX#$Mpw^e zJsN!m5JHWOGOKo^39<{gRj!=j`cLu81$EDWj3^&f|fHifVbKPH1ntWc66GC884B8&SwBazt2D{*S zey@TiV-hIMhAl7^3Cf>Gyn0TP3{AR-yZKBvRht18W?kj@7g&Kx`h=^+iKm z0E}D~pyslr-!dLec?+AV1#=)@Y=y%1o~|?|mgp{h!VP7*55GaGkSjvVtqk(hlt)_+ zhEcJOb1d;Ye|6T|-#wv1^B_$2yzq8Qj4BXxZwY`{*qvWARyc5-n7F92)9&#fA7|Qj zs)%IJSS^ z#oF9LNq-!_0NtcfX<(Y*X?MV3e&u;zS;0>)A^}#njNKBDEkui2+L!O&gVEM_j0cj! zps?xjV~I71KT~lH9J4_2iw&O``puuHc|*yLB2B*5P>q|uN4k6HylvQrapwSwM{g~O zy&?fh(ir)@BVFs$j346$lx~Z6-!#&x?R^cAZf+_7*<6fXzmGG3_P$ZYTs%Db0F8%3 ziUN+MR-_;&wq9xGhOi@*gsd&N|7!nf3RVPPJ;d-;tZ1Jxk8wXV%M~$<^Q9-G`<6jU z`qtg{s1X(S@uEOsR9PlGe@Gq4vP~0qOq3u}3O$sJw*{pN7m*Fu<&MLYpf@IHm-LRA zkkZPKGo_I#?4>d8%LnZcf7>5@P2Pz>>CNaZV9(=C*R zd}j^UQO;?r{rLyqXj?u1sDENwRN*~P8vrNnrb(;~wElmY_L4wv;c-}e0V2gH*w_Uh zw}7yTGH|-JBuZPQ>)B`-6JLDtJb8qyUGWe=;aFM(5j7zz5boxVEF;jH8GY!yHF$Bb z{)~+|a`ldvV$4s@AF4s!>1*!RLgBA=2`k**uk2=56m zA%}iHp?C3+C{GU7?nDUS2^J*4KQED^RVa7JIdY z0&m3`Yr$R-E+XezX4_ib1Qje=ea@rT4u}PLwvTgMkGl*QCfL<%G~?vTHGD+ z#}v-v80@@6Sf({qx=ep`z^lcr3xLP1we+3y>8kos60FbT^IJ=i+2O#hB!!o7(^F!_ z87E2l!|TYY_reI@+2+P}f?yf#l3RH%eKeUF!aW=jMU{D7nTU1Ru?lYO#4gigSEC81 z-9NGpQ9%UJWY^Tntv|30>ayMbH_UFv5(`4hOpf2lf5Hg{4wT|%^>ydox%Ckq%{Ran z?}iD#awx4YTqFAzo#tbqV7>(#mxh2eqzgOw{n7Dp>{zYIgpf?;Eb2b1X3V6yNFK#y zx0Zos7Pht|BU@rpu#0U7b^KOikD{&Z7?J+S1KEabbFkD-@R}MW&2{Qq)&j?$1|QbC-R% zl`s7S9dy*)uK?y=Z@|sMq9EhJ4?Kb@%|!D$S+Q@AST5BwTRpnoo|NVRBA+rX&>|y{ z#L{C)oFb<-HZaI8$p-*sGs(eH%?1`LH-LT4%F9E?jIYet#F^Jcx_oxwF;wZ3=#wi1_YlXt{-sf3vd+^A28-N>`Oe0r> zMQjUKj1D;@ED|PUV=BO_p5ZbK@@{ibw#!evq#ix97{~q$Qh1eYYmkP*)opifKZWIq z6+2gj06(ENNW^zp5wi3}wSLb+WMJ$z@mV$s%UvhubT%t1RraT8)sN3}2wj$ymKIX- zo$|?Nbpc{~sUXmfB3QD9Sji4h2d1Q^##;Vz0+f(Zlw7K<@PbBy^tPi$LwAe)Cf@Qq zV$!`P-K!D${vA{V>=a~Kt3N;Yf1a@fuu#HMo@}WgFq7w4@#!kb&;-x!7g;xc3w(Q$ zAV2qAv|RDtgkk7xmmMjV^t>Mm_Y3|_#eCyM_E-*gCVUCxf7_4+?%oMLo%yj4@tV1a zIi0TXZ?D*@1;AmP0kJQdTa`(W>K?vK9F{8tCDQ$dLZ|0YOs@r~w&M=b*!H%T=*CJD z6e|_4+p|E=W%I&d_54WLTSISjBW=qcTAL14Ud(ChxGMHkt4|#VxQ6-jtv6>X4fj;I zK@S3ngockVK9`(~o8z^D6=wSiB@gDDvRZASErQizo|kwHzXHzNrt@t*UtOO_Jzg2{ z0C^iYd@7(k!&YQ^Y!-ka#VcH{?MhUuaG+l=`+U*`MD#9CD@uJN^2u~zR>D>qbWm#d zL0QjijG8Sjn2#JUff!Qp2G@3UfFJOI-F z?!1p^d&O*Txrr0VqvbISIB$Jf8W};1XV6mka1P{#h=Ik&fkIYTxBnriDbq2v@oo)T zN2IVY#&d3)2DuKIC)@2*U~i&<{UC0Xkm!lrC>h-l&)r+j^-64v;ds12s0^!BB{`SM ztmY`#O2K96c$19jBcO}j~@JdPghYo6U04>D33fgvW2_l!b@O+)RB>A2l>h{?0 zYqrr^xgt24(1!eiA;MZcY#=)*tL;bSf#4P1@ZB-775RiE8_gP>yZb2d7j>g%FYKXS zk3dhtY-XoZt!skFdN+4;OgX~v%;(1px!24)qh4#SgsDMo@husAO;!Sfz$@WX0l%}J zEsV2P)Jz8VSgIvKCzMBzXVj>@AehTlKf`D-{fz2Yh(K^xN}JrqS6>Ea`AO?Th;-KX zlZ2AxCj^~E5?wi~Ll6c=R=xN6%%@QS+VegxtKZs&y!Ei#oEgg~oi6u|(Q$yyLtam@ zWBL!pmcEk$9Iy4M&77bwq0Mgi$(VAzdzsIR=Ri1>{{7#G~y=6R*_q^MrwwG(T(nPEupIxQ4hq0n7J0GEaSY};5bY_DTBVOv5a4m% z(;U7Zy+g^L8R?$&d(0ZkkWVS@UM{4zek-G(yJB-~)X{J#Lm_A_GZOG{_JO9Ul`(gY(PJ$S^1zo?YvVnZkh)>=W=8MSjBbmTPseXUYyWOGS-vDHNTF>8Fn+xs((wVL0DKrL|!QBQ6Nm=0_gn@uK2+1Z~PJsxYR z)c)lBX;C0%08*Jb43%MF?B2;qV2i;J#ySg3gQW{NyejvP;i*&$dWqk@i;ySAD=J$J z90v2JTg~x*Wg9itR*5`>uj#IP^qb;r3;VjHIxD!7LYv-JWr!XT0yQwXbHKn6P;4Dv zd%1LBl~r=(d5>?M&fe6oPnmxDW^iV#AlybxhPD4(@C#ZgqXp71XTYg41Mq9<98!Z|e9YqK#?(poFTHdn|D$5MD8| z_}%@`MNppLczZW~>Y^5LElwT>qPZvQ+@L2$aO@j*&?FkP&5*ADRWzP8gE%YL$5Z zISw}e>jyLhNK!8!n9k7&h{}UR-nY!g$WTnyOdzhp;wKP zn+(cEHQx`mRjT|-Kkmp`m$F4-XX-t~^d;mM`93a#|L}`WpzJr|WC&x#pF=Q*zRhl7 zy_x~t3_+7WJak-^%Kb15FF&-O;5|y|V@c4ZnazU3;Ky9T-| z5_WZ8|A#6h9RX)CQ+*zMR(WX`6BsghEK1IHXASbFH9H)rHoY@7e}ajxDJa?;YJ>j~ z1wWk`_Y(alsmt?+ue{<0YemL3=?}=G+7DQ~%80#>j^&YU(Z|#`5hvd7+qA`IIR37C zY%&uH!u4(rX)qN<5bmlVh>WT2Mlr8xNY_c;uinvd|8Fa*Y7t=rzvsB>S@Tv(d@2tk z)}wyR&>t%iCU(ghjIPO@=^YBFmcL@#JEnISwjC!-(g?DBg=X45#Hb%dFhvCthTT{D ze`A4!`=bG1lNoI0c4=6`SV~d^)E-mTU?(7WnJfDAW()zeAHYrjXrBIAWD$WPZ_<04 zpGSWoU<13NAKf#aan1Wj;FLp1Z_Tf+waT>v+QE>RI1RHK#3la@g4RB-?Y8!EfnD1k zYv{Zh+7DX6jCm0WeyrtGU!-qm_3zrrx$+Br${n>73FWXH+yq;#qtMk$;a6@G0(wA! zS!J%g2Y*|#C-mCO>r-kF^W0Y3f$Sp6(;Z7~SP{hYJqGriJto=@GJOuTWGJWh+t=7W zthR@>Ck(-(gc&4J2a>Z?o)7kD7C#4vM7L>u9rdr;m7I43FkVwy{JosM?Lrs6le4z5~POYcj#YyGN+5VSl=L*s<8QXlnpKXf2=I4xUBqYj!)OwJI))?A_5#o zeFaF{ktviRe|f(k4$|*fTH>gbshT^G$_;UEgx$wVbzz)WCHsxHi~YDCJ&gsUx?apu zk~S%m9WByffv4HFB`YWGp#Gb6+ADz4D@5)sf6yAPO4!C?Y9GDGV)bcummNcP(i7vP zg?%E{#E(m`A+F-@n=fB@60`T)Nh2awey}~Y8hw<-2pPmomY&Wl6vI?Wc z_TX+YHP8v>0xzRQ>3PG)n<731ZLt1CyXPtVh1cCfUE?4ALv}jRRn+yGv~sZsm7}LI zcoog)+XQR6EDvFIXTyU@C8L0_>DOT#m3)@}pAx(6K_R%MDe}13b@kq9_ne`>3d0hrrYI_&j_63Z z`*kH3sll{a+P~wY;~W|=>X7!`)vuzbH;UA5*ahyCs>h0AXLNLi>u)Zu1@Di^B1Ma^ zJ{{Guex1jC#GEgIY3g;Ab)?d2o9?lEIh3qF#3e6efio6-Rm0?d^ChNEMQ6xC z^jy9Pw_*a@=A|F?W8na!zDmYBJIP&P;crZ7e6zke|Kf%%#jmev=Pd(gc!A$S?Jod?~Wl{8vXF9js%Bvc8*P!>EGX)gSZVw0X{fi$)JAEv-BCjqC zA8S4JCp97yS8N&Jl-Ta@^0U_0B54D^g>(DQ#E_1s*0n9~G)BAw3>!7Zu z30M=bs4fBtofu5}d(VwGHHDOfK>jSQ03HOoGVb@QY{vpUvQ(lmq7GV6!RMCdrKqud zzkGLJeWQ7m>C*z=$=%((5d8HDBSPvqt}MP+0Yz`rU_yuDnM7x}NGWb^u1>$d+hiriCgPx| z2lLgFEnAdMuB7R`WhdE~*4bO9;`-|;Tu4$)?(^wnHgig@-R(sye|=nmiS+(`K=6tz zLa(Bb0lbFBCDIHFy09 zw&H$Jh-lI>L6+V8rY)qOQGBRNAC1kyk-|lt%tust9E?GO$6~sz`GRQ3E@*P_mRZ^I zhkNhb?swxb`EH9B7s;YiOoGL(%#V0l2-s>B(%eth`UB8bM!X=~;BLOhFUC%TWtpZB@v{!s~`-fn4;xOYc}MH|Boac$2|=I0qEY-Xn307F$?ZLCdDq}Z@darJHlZtWr1Wy=u+%ub z`u?^Ls99l5W_`(rTgch<9OG)RF zm}1h`MKNw%IUe0`l9+#WOVMKbn5teU)<*(=R(|}5UG2_K(bVHqGQ-79mWMk53y<_! z$>@F~c+4owYguS$V>^p{$oju`s#vD5udpwca~DeHf2XorkxG651;(IsI6raOqDQ?q zj!5l9kH%Ef8p`n2{4Jj5pYQZ8>_KvmwUVePaakkLE|P11Km1)bVeEsnNSoO_-K}^P z{rlMs$8qe_wT%GX?E|Z#aYNCDS{?n|T~deP+^>6)sb`Ys{N*BJ=pEp5_c%U3gZ#Ai zoDsNL7RVSl^DPNEXK3uq@+oNaM(rpBwJEpm`8`4$RB~H`)&;cfrs`2+nFpgNCLXZBG z(U29UVEQddbg1+v$bW25>5TP9w=o{$+Dz#s*THeXb6@$|$u&g6W7{OGn0M{JCG@A7 z-2p$StLZcu^6H;vzGR2&E9tUOAk|${?W=!sHC+UHVWc4}Q~qwPeKyMlJ|6cpRQVAC zaw0!RX+$42$R0i!IS~1jI&H%L{OdQ8(Dd*jt7d+pK(r+xIz|k-MVp}!F*q~|T_19D zj>CFMVq!o9B2~LQ1qK0DC2-APaV_;vEO``$jC#LckUtPpb>TYSSQJ+r{|2bFEvYAQ4lR}#7yKF^A-vw_H*H)*$je( zI_X>x7Mo9Vj)Y=FFP((`X=&lXy|uLH5=nCT5w6-o{nZrUgXq*^+)jQaviPX6Y+~~` zP#KLR<+l1d;8XB%=F|c(A+1k6NGfZhnkv#HC24|ytac^bnw_1$eW#(=mP8$N6t4b- zJ}c_ygdHtp*oDVlZ}A+9@%!4edOu?xx1TfO#Ahc4BA;I-OJv4L^TsHD4ZG2HzxTZfhtE?pZ2Ph=v)IkLq}9;ZAe`jxK)5ELy_O}3tsg?` z3WtvHoZj=6JN$1D$67vN__*tg(%ixTrSH}s(QOLE-{M%|tRV5V#Z?cRhbt0`yFzH( znzVr8T1Ux@7uRa4DN8yir9uBQgXoDOJ!EOk2TsSMdB4~~T+9k04$H?X<1yu(?XR6Bxo-n6folspm=Xk)*zJ|Dn#qym@)X=Nv}>TL z;&i3$0Z-9~i+wR2#MlBkoVBg|XFl^BOE5#R5++6chMT9x>_c5+&G+XPJ1Vu$(0{fV zf|PY$0;DQ8i%kKvZ|-YB3^jfqMV(L=A$sN2DZyNoF(pT7nxF0SVn&gVLZ)% z24w>NA6&fdfqS#oN`EpCCrb;~jRa$Muy%0la#+bNVa|qsJ=!h&wid*AxP{4aF(NYf zSlwCZ0L^UwO2M4tA8un-`(-J@ZOS<3?71^=#i~DpioDi$lMVHRFq`=ecIMC6G{iQh2UJ<5N3qB zx{ykGDHw6fqX&aJHO3SWCzc|(S5S?$;V=TQ6*xHj!fc#-R+%xTylCd$UMz7LFmG2e6#zxbkM)ErZGDwQ6Q z*O8fpUC2TgN3g%zB-M%KtHRWPC3?IQyE>(OW0!QL*IyL*&Rz*N4I2%QQ}_sbR_mxX zgO)R1gu0&EN_tUPaZ7jZY-!v~CkBQjIey&Nl0~eo^xatk`ce85D$Ueiurq~~LySwU z3#1!WTSJ3c-_Lg`=lw_X>am~P13w8))ncg!J?Oi-4*iY_drCQ1iTq&4?D60L1WY2m zIH^%o(#n%CS$6DuEdv&F5cS2H&^iT9Gz(BU8q7d?cx9CCD_&79+=kmTpg})upN)7& z1|ee&F?W|ULzi}{4^GJFvG_Tc<_&T^TLJJr&GUkweL+vgUZmsCW?PZ9q6V~sU;%s` zBndY=w?);7KmOU{N$Ho{$J4x^jgkD&JG1VgLtdY@YIFu34f&lRKeN!m58L`BF|n&? zUB4>xq?k#1_79v6L{P!}2SO76y&=jIh>o-WqTI(*|qf&_cKd7yB@TC$c<#8Vk8Nfd{9L=c9I*N&zMbAnxx+Gia6rg30W=@ zGGMhB9?ky}*Xdkt#s6dYvUayW z!ToO0x{dnM?!`SiZVS-jcJ%c@D=CuI zi#1!|Ex@Fq!b0jQvuT>Ey$RvBu>;js%j$`B%~VH3>X_lit3oY_Dr?+qlqUqs;U2pm zG6&W}7sAc5zD30b=BQOKxyBY`V{7FycwUwJP*dKl7qh^&?a#N86ihiFS^dBTdx675 zT`qfib?=CH{WmAM1wSufs$Ir@xi?e#{vZQaDnh6HwqigEXdAP_Z~g@HlM&=9V=m<- z{40{?VBG~@zyWubP_&&wL+?k4A{K^SZaQ8My;2_oU=uCVX)EKCSO=cF0aIfyS^C)j z$bg3&5>P;Hsx~?lvb-+D{015`Ifi^UaH7!^GE*@Qq-BQ9z}&{aq4CVr!Ypn26iQE1 zQ(?w$So>ko;T7gwQcJq?$$AOW6q$eb zV+`7WU2k@9Nc~o7N&~u_fPNYWoI-hYbQELI0La|T%q~ zCO=g9jyfB=zdsyqRYML*d6PUINZ@RyG7*HW2BO{S`d6y8I~#cb^H67JOq$)F>aybe z;;d@G$UWGG#i$+j=o8G61eivP?a%uh#D2CO?>vgcL;~A!cYjZPiNh^8nM|ux6vNZU zavWe`ZZOqpJ^tAX8)@tWvTm|r-6T6mYeuNxVx6j_DH^$y@+lyfAI|pU8p7vVgqRjn zZtTZ{I5eeo`6FtU#v`Tm`0WT8(KxJ)xCj*Q4p^Qu# z9ZG^B>RaIHn7xySpPbLGuKuy4yJXeHF13>*H&WF79D?8nxJ)cZ$8y1fG(P9#W`Q@8 zp6y!u`?n7a4hnirt%r|28$#!7SlA@QmCe2#V>xYq-JcyD$?$&)zh zRejryZrJUb%bPKSi=iCK7L6vg=G1bhWIw9veI_j9~S+4#+Ua%{UNtWSF`MC$m4X|NBQRk zzo*^d5@73$0d$fh03k9Cn7;WupHy9zfw344#|uh{UU-1N^!??|h~myRLn@%UWCk>` zy(g(L3$7#*M*n&9|BCJM`OckC06@KUzzFya+T*M^f=GG&wPv&6b5P)>+e*Z=Z{ALFHW$ zYt6`Ry{7=w+PMh+g?13+;41Ij*H+$2^9J(vAnKY`Ck7;v3edcAs*(*sWUx3xCi46cYZ7?3R* z0Ml9qoAG~x@i=91>cj}~Bd~C(vw6H;_yCwU58(%nhK|)jb)ntY$;oT92B!hQnEgFZ z8fUxaY1z(3K)*Y<-aCk@RIA?RYs0Y`A?ps%){;!(v)KitnSVz_9~{_|T=g%4!56T9 z-E+iRw9CVcXrVvYRWLZEPwzVb5Y7}Ccnc2~w+MgrBKvj=!AchC`?mMBB_O?dzE;_O zVY<{1-})lXFt>&2&|vTu$U za(HpLGW?ZqMskmbs}MYAMT%F;ZQVRTq1m6rPCLXVS+})<4-62hcfbF`QU^6K; zoip$Z+*8){V+n1qc<+Gaq9ME*Fci|Z-mkJ4kG^?m0{{>0A>_^&di;Q&^PnGYt3-p< z{>0QaZ71gT!2u2Qtq@%AbPOlibAdi^_L2gQsX02Kb86m-i7E8~;^bamS;$6H2*#Jr-iQr^?4;i$4YnbSHHiGb%b zR|KFe?pFkyKiA!@VZStwN+PAby`1!k|+{@p`$>zo2R!6MW^f8pZ|0msW+* z*0IrJV;7ZlPI7ga4dj|9aCFGOYxH~IwwwMw*vgAsuyzCIIT)b8X*O}#%x9#vqFQqf ztqvGT!7;R8H0s-DFadcKRm}h;9wjB^Ck9Pb9FIr{p`PednP|Dri(OGZhJUj27Xde zDM#Hjy{6LY|$1;s+fH2 zgWJRah}%+>n4o3f)d}=7WZ4%d(R*2u7WYvb2ZygOMf~SOu_{iftlDY;W!v^7plmYD zP>0SUjlz2CiHTqiSYj5qoto4b77>YS!a{~OEBqA_AU&;Lf3jAsy<`v~>pI^$;GxqY z`kkpBS|dw-rhWgRU#FRe26LvkP~o?y2Xo_enO`H((>tA;9$+GkKa_c|8 zP?2VJp090{-|h5yp%gJlsZ%-15rArTg;8*x5lEb5H;?F{>H$x^77tu6B3Hu=;~h>_ zw&f}Kx60$E$CLtRPu8z55t>?2U6Av50igb>C~(xl($jr_b;Kv;xTCg;OG$&sfss($}L! z*~Xq@eKgQ;?&_>3ofGmYT=dXNSTPGjJvQ7Lff$$@NB{hBGcxFWu@S_yW9D;$?w0*a zfW8y?IYN*jpAQ#YzaAbN-PYbN@XU0VoOoxMOv%T;ZHn9+vE0sj?ZJ#t?`1}|115q6 zx3)|ekAIWatjE~(jpEuDgpJ!v!@BCN`8iv4hV-ky+{Z2Mw5-1IY3#Zh;KGd*+l-{_%W8w+qBsSAdkRgvxPtco1R zJPLJ@g3j{IZlsDpr*TLPkf~ZIF5koESv8m%_t`sTMp%{BfGnpqqMZ4dw%&Lv z`NU+#1ciVdCLfKZSC#H6rQS`Svd~7`$d2fOS z>&%%luCyg`tJG=|bF)n+PS>~svFO)fL?~l;Pq?$Q5iO>dAe=DtO;BHjHIrp5mUg{c zRDoNoD(|K_0s$LLS07{G?sSC;2c5|Bbq^71pZ;Bj=bqqCguGI2f8HaUZz5NGdU}A( ze-LsFAXAE|2heM)+Ng5=hCH>@g?>jRX%?WRaKOcrH$0+7Ep4UL&N=&!*crk`Z{tgr zZHBUp{$V(~_~6>#vmC?$f=%DXtI)#*M?zaeS(>l1fKdrEuMittIbhG@!6dOJk;G zUyOgq#tO}3uYzxB)OMLP(Gn@Oz6B9qFT-8CIf(LiKcsH@YY?`Kcb96A;jiM6I6?=4 zK*er7PneZMY&?;rFSnhXKeUj68{4XP9;bfsC;2d-l8W=RDla~>J|e_cB;|R&o07aVTtQy zPk%ctrD^5GpeL=p#=Orl+&K(MtR|H!FFlZq!mQ0}u^BhDDXS45RsI13SV^c5-kdCJybU|ZF?^zSEME>=><7;f;C8-uEckmb{ApIS z>$H5;hBQ${QRY9CU~$k>gFhGAU7 zHJEkYJW}$~Qm^`bMIFwo zvK~QC`H`=&8oie024}vxs#lwoRJqU&F3FpR0YJd@9F0zm>+?>yZI-Rgg0ZSi9Zl=g zH1uxyonBwHBUKi>&*>Cuf<480i_XPuus9xb{pF9P5@yW92(8Wl2&su;cO(DQG{25? z_uv_i8R6G+{38K-9M8r43?)xQUnJ>{S=uuDDMw(YDQ^H_Z-=5hN@J(Vcs*UzqA73p zA#L@3?K{d&xXkIg;>z1AQB9sIA}?tZa8x7j9D5n-LwN&+(@WSKFKjyFm0tlfuA3Dk z3Dq98Z~3F{^j1&|fUZ_2Cym zrc4isy_2KP`qAggUW;+gzka{4XAE#7K5}Q-r3akxc@XA1GWHr{5b0mfiQaLBXiTQN zdW9oBE3fItDc|Xhv40HsNV0YA^2FZ^W#h{HaFY}yMSXe(+#+ri-UN@Oap&zdSq4a@ z;~02!H(~{*k*)eaxLdqi>#_eZaxg1Tswk#E-9+|hO@#P9JHj<&@+kikIWI>YB`QPn zF-uw>lh%e{g{Pl^{6l^h6z z{}q!|@6CR0RaCdsTl_JM{B3ULQE1ZTm6FpfbAl_BNAXFdQ~hgoJ=3xbm&C4I@@*@| zE6uK^Up-fDT!l|~0qb?>)q&_gdV_rrS9!B~IZclZr45bV*WU-MP`2wq3#8Fvv={D; z4D0uc&cee^PH4Cu)XYcT)zg9Y618D%)uZSh|E0C>j%up;+C_M=;42nDlq#YkARt}3 zihzLBgis?d(tGa&M8QH;TBM6K6Pf}cbdZiTfe?B?S|B0Rgg|Kb;P3u!S?jKQzwfRu z|0F9XnKLtIX3yU9JhPui&-@dyeAHXJJC-Pc@XBeGIzivlh3XY;Y}?#_EATH>_VW9r zJrSAzKKVuXCe1|~-181F(GWBnNBY)_-tS`<0R^{r5?SbPwo?)N>DamKhayp7W9LvK zSMA0ij|$^|wsr(?6ODuUp6P^&vt%iT#CSNux2}rDjE!{p)Ohp1RljlYi0)OAaQPj@ zttrMCk7k4cFS+!eBe(aZ4{FH{f7{2Nf>TL9P*03KDeXaXor9>&DG>+vw$rFHdOrl( zdrN;iSL8rqhkj{J2Cwh9N)N{9xQ6V&1X#uQ->IUIcXJb!)Ba6tI%*kOH(XRYIe<7_ z7aLqJ0SJ1z^9>Ido=rw-a%A!YTI$Aq;!0a>qBVxIQ-2-OAW!w*e(0y&oz6Y;CvRl^ zY@n)gIN<}oyPHi4;an&&Wv*HKv~jSvd=jmaGcVve6l?|Q?>?4yU5flP zvd2V&2Hux#KagE{+xVWSNE-r#8ZKvCiuHb;rSEqUNOia_{-pWTh1!r;bDP#f^gD~< z9)ufUUCd%L*@L}mH{x;nC^xUHp8;?XIpM2i-6fwlJrDd8^>uCmvBt#?9sh4f{$df! zS~<{8f_~5kVe;2soBZ{KvcK0-X1)nO(Fx`IUwA6fot}81m0Lc`XGJ+s9-@X^HOVYd zr`QK9A65R!ecGRB?3j1$e z6V2>oTWR|$2D!X$YrxWHD;QlKxuEq!!~!JDIIhBRD*h+)StT0(RF9GB+p?=x>y!x9 zHnS+1urp+bj|^!7!QbXak1^p>E#KS+#U5~zVmcaXTUYr+-(U04EZJr)+T}^WHI%;_ zZ*Ug(6kQqCWKXOzGD-=W_aV=rH4VuxJMsusv$qKvi)_L zf-=avd-v{_##eVg-!=K5`)xn!h{Trv)JXO6*&)^|AFeuGaiSf#%cfs?{8y64UxzRi z3kiifG}>RY8s=tak0g{fa%W&y!Q2XdQ~M*q*|VP|9~ZT=EnGn1)|YJRb)UD)61 zR{y4JuBnM;H#Q(PNczV6t#&Q7EV<6}_Xg*SovG#K2wuQF13(Z7qe z$pGdi;+w)!3+C2Q!FRJK9Lni=d>hKbX^!iOQL5ui8I7MaGG@*ozQ_eW-o(Om$`j`1 z*}^779%X%z#(dkO5XLHPCb;S7;PNqr4%_?QWbao=0@;=B+jm?il2)n$%M2gCS`Zgc zqaGDb1S9PsRdb`T+`x|rA(J~#n{+vgvoaNk*dN1W?Ph6?U77&Ls8EA-jX9s+S=p%Q z_kDiI4@sX|%S+FCu;0sTfgkXnr=#;wo@yr`A&&7$2X&zqRg=s)_!p4|b;hO^G0Y04 z`%;29+5=+V4REAHCTsP1Sx8g&Hf4Xuk{mXGLvn-%lUpO|eAe?s#V-p$HT2~|R& zTDUEhNRfBLSbsamH7m3bWwj{IG$=B=A2<>8o3B6EE%9(zp}4XTpRrxtz^7P7Xxo@d>*udPl3GWx7=O7jkd%CQ3o%YhpocjWOjHqRXM#i+7 zB(54QVQTi?HNmf3`l=i18%g{cdSI91?Pk?*9PExW zL~hq@qZy?z)BZC0{}9|@^42kTdXF>(7zo5tM`)F0Knp z15)rgO(oez`Lh|}6>7hJ10DOHJa?86mxh)E_3c2$n3+VAMj@g^IF}_`(9IfRo z&z3@Kzk%rDk<)?j(hD;32klQHU3DRiGSrn)f~mTpf@*NnL@Jc!I*f zU+SuBr^dUlYZSgLbv3m-rS&V`y$|xF;G)zO^r7Biu->4%X&K$!F~~ zq8l~z57&IQ7BmZQF`lv9z41bhr=zTTj%h#G#|OxZ{CH;$MGnh$FfY#*@?R@Z>7K(; zv&soU+LDZd2ExiAmYm1rt+IzG7&(wqe5njo!D~36OTr)nS)_}NLCHOqIrQyL$l;CS z;j8!y?H~kNeW$44W!{A}A8&78*n(PL2vMgmAoN{rWwlO4sGiYa+yEifEplkhgO3Sg zkjcc8TD?wg`M^%5r6)=UVJ7>v=$aQ_09sW$)RsswvbEU(+e#S~KvGuM@pA}elVO3W zyrIxe)Y3@0uhG1@G`R*5G@prl|M)7Na!*a4wl-g5oqQ<3Hy z3>NDT(RSAj(|xTVJxVIm_{mEABA)9GHWMHMa-I*NM(!cNR+FNJ=u??oDHv zQo)A_Ust>fUWxH5RKqwZX09>#(Zz<{THGf^vs~QKi0S@lbxP&HBp60%Qb~un@S2Rclk##&^^cA=EGLE8jd6m5P zR%OnFVd~0+9o3=#eH4-_Z83b>)_1fk9GQ;KgT$=l$xvelSM`>he!Qxot!8TCI@gP7 zM(ZOQc|LUT@4f4)a$cezWLi!812=x0DT&kGaqUSF(RG!%($dF*Vu%sQ3%{N$j)-l~)mQ!`71-ubj4K>YMj@6_<^d-AUQi7%M@g2yly{L*LB{L?)2M? zpUbdx#7pzPB}>Id?Q_C~n=wHLg`n+_J8Eh1ypFC~cu!ld3SaL%c4^}o+17(AL>K(i zvtdJ+bL3C-(3WRaAb3#raH21LY~e$87Qczp-FExI{-2n_3wb(0p&u&N$oK?b^jqAT z(_^Jm1+lk$<+2up5#TuK zWj7_84DNjDihr95$)}GM9Q#5Lf@%k_y5;$=JEyy%b-TEg=j3TYRm!tj01F^nLz`Hz7%^&RdSI zYZ`#4C5|lj0a4h@iKYQt*yh^3&id_sg{`c-u|Fm9j(k=b2!}Eeme5AQhu{p{g6B_uf`n%56kgF4j0!ru{SiIZ?EvY z=asr7GV?`MsF*{f6g0fWZWSs87|-9ZVIv?F7HJDmna>2%Ln$vCKsE8J6o$qBFlVMhWz$u#IjNJ zOsF$ltafai0IrX36duyEbkEFTHa%*Jx4i#UTK?H^mb2xc`}Ub4mpK1lXXp+R0cksJ z0Pg+X1Eem|2hswqN{BWBYzBl{^i{vW+2-&x3EL>>X}V_PD=aKV^#PvkZEedxh0DM; zHhC3G>ph!K&5&S0X?NU|C~*3lI?)KLgElU|LL$xNRPxvMOa}fAQ;kNij&3fn-tJh# zqzzd9ntqML_<>Bg*Xm94&&|=jYy$4WKwGUCSoH#A^o$yUeCz(6sa{HRmQY?kDbsg# zOuX)Ix|d}kw3R}?OXw-TN!K0fd0&GUjzOb7&P)lO^Bv=h(W0Lb>d%hLCX(e#v+qn7 zLPnwl+h+8_8GDdNef;FXwFJUJ@3G)&RyjLflY#E&`<-W>KJV@Z)J*}oV}osccjG$e z2VvA{dw2;8pIOH#y2hl8jAB5Oal)3!HH!;gNA<7^b=D!wi}ywW$bV4A%WBm&7ziH zqpx&`0Lbw7TK{!7IKbh^1JXMQhaDa4NC18@gR}DBi2~lun=Q|?q@7cyCE~Q_BiWz- zvW{1>gO_N}T@|&s3(y*ZuhYGZwgR>oou;b4U%7H62~f`op1MX%+syD7407=&V*46R zPxN-p-D=t!m$~U)_JohzCdDbNb)B8Vp2e*|^F19E4>nq_E0KyHO|LZ_hdM_=PD;(F z^2X}|*+|-%=dKd3tY$FHr9@*T+}_8=8UesN>AsueOk>0J3!W=iUWhT!ec&`5eQW;c zAW={GI~+hI5tgBQYVA>M6>!fL6HvDWkXAuHWcOP`$bb@c41^eZ%uLs%33-FsDxx%6 zuS}qdGHt!Rranc>r%4n7+YEnQp7)<0~1osO={HFTdOEA2WV#s7}|p2@G! z%?-r%F#FYDY|NXY}X3JP3PhqHs1hyFej(9+ha<{&HH(y>wK<^WJhOXJt7JAVQ_0k+}Zr)HY0wL}K zQ@03E0SuNCfMVPTsoe5DT|D$h0I=@kl&2WDHTJsR6bCPY@~=yrb&=4y~G#+10( zW4pd$pH8n7V&L=6P#tYbFN*kVywf&2X5+2q2NF-gt&CBiR?a z3uH%?w(sYv6R@xY+Fz|64cte#2~)s1026(9+m&W->JR$$W9TnHE(=dkHTBU)y8PcIvl zJ)dqo8cVAx22!TUDc?UIG9VuY$i+v(bZRKm97P?q;C^{@bwr^e-Aju$f?zVBMOtTE zCB5_%$n>=8+}e{QIPdbC6D}3QKnJ`w+}qkcPCn4>UbEu*^MRNDUzcXBBRv0n5fXEs z`Q*dP_{)H$0bU#8Ee}rK>6F#4ojG}L?mRnx^8SA+M09>J>D0naKp!3ThMAji?+^wk zdJe)ZLs#yd6xMa&?EFK~lcx*V_8wE(K0x&-0p#rM&jA-7x;?yh#tT8d|GX&_R(D`* zda52^bN#8-4fdNc3^m}OYwP0`_!;JTUn2JQpL*9^VyOO8%m1SK|Ce;azYFQLtFzp0ESYKV8?{4m=yt%fcVv z!p8lk3G%*#JYaImYZn4pNS^=);k;7%vr)$a*&#!w^3bJ}Eqqq6F28z=N6N_!cImCS zsw0Y zNB~5Ya)!XgpUEhgx_7{dhkoMi9DpuA@B1Qa6f?RuAEf2i)jJKaHOh! z7n6<2zf>=d z|LLikGge$(?L&3N#t6f9W5UjQVS`_Ik+q9ieBlYANT=aQQb=!~c&ki#5Z##aU@NdtM)z;?iY=)6?S zI>LLBMjpFeSh9oi;+n$33af>em6R`K;HjE&nkV)fSijZGlR!|359yzbD;) zzldmeaMU{f~t*PCVpPr>0|zur9|n);=toS4ytTP?ulrTaKQ;yp4l za@=;-%Q|usm{DW4e~hWumFYkI^q(`X|FDx%M7V55Il+#>hXFV63Co+Am(9M* z8-BN6l`sm%5z2^lRPbnNu9V6TL}O9{OY*&3@V)|2>dOkCRJ`z-KrAUBwl?!^0h-p-pM5Z+c@L4SXZ8c zgx@XjqBtRnvq2Yta%andm7`xfrt$g_Z|rj){j8`4GHt+FO4gYO-p!RvkI`c5+YKmA*KXtw^9Rj|F!$lweupj0|IBVc)X z3}pr|V`%>5Rb%MunNNtiqBC;Ny65hOML$o)S{fb1?!c5a-Cg=aZu@`vB72aREyMk- z&ZCek>Q`Z+I^lQtYbo*yg60G&mcWyCGo4X3-CCpv1nB+EiMmh(pnaLwu8noF51EkB zQvTehA~s4OVbxczpaz06F6`IeUca*sS#H4dj(k2+mGb(DrjY_BASv^cjl-1&s+h25 zeWiVxA*tVQR(kyi3(J)ih{1e&bcpe9Zspk9I>R*6bJrDUNHe4OL+0q%SRLnn3Jyq@ zN5VYE$;7$Z96&`A3n{{PgxV3kx`q*fwc)a{b@aU`F}2hf9hR%>*mDHEmvNY(* z#RbKuvbahO}2Z{ z>2;?oIrNyE-V?wiEzJ*aljkhZG*KJ_SH|PnwmZk4=k`3cv9amc1_aWpLFa+iSv&wt z%Jw&}_S3gHRg7G5hWP_K#`DH%u|)Rd&RfnsG>RY5N4}^t@F%bPGVVD`WlhUrnhcfQ zHOAzt=^@A6u$J|pu#FpiQRQaCha)!4&S}cF#cm~Y0)bIQht`SGz63C^NT@NSpUxxw zO21r^G7>fH$IbTSKL!Oc^5BhR?+AmR_vS$)HE#SB| z5o<;nzl#!cs|-7GP}9bUuCn zlQMNb=48rvTn9WKyj5rElk0Hv{2ut|c&FAjXuSM(eV7R%Y8q`RkU8Ol(74z=?e%2| zcAU1h?vacfbt_Q@%?yX7Gj_L$U^v%Z4Wg*HNDd#+%ziM!KpejCv&;-ia{?602{UEs z9IT4eY*)e{Y|H0UK4Z302~Vu^R|xS`nqp~PRLzx2E1aIt!G?XBit`6~OiAf!S*3i( zc&S8Lr1KtC+qY%>ZJh%z?7MS=a-h}cS_`OdPWUM%!I7ubXCaR>A49cF_7fom?tfNa zv6JcldsnmQg8airKUPYH!bE|bsyqIPNqKq2fEc6f;J4A!2;C*^ex>A*TyxbuVN=eQ z)`pJ{&g@#|{ur1pWlB|+>zc{??i2u@R;zPf=t>bv+8kXsK+t@Re!BX+m~rENVC?LV z8xn%Hkzk?UqvR6$y1amHr0aiwpE=_#T^O9|41hNS$>4_Yorn>O)dWkAZ?DRaB5UZ` zt~ro|CDBnv^@cg3xp@r=CD?448wh_Mmjh7hBFOYc9Q>M&6*jS6}cR}qogZW05+5b#S#jsE9<-kuIznLk&15M9RqDC_T@EPLgYTrG4wxKNRB%k( z(MFc?*d`uUVcnEoN|m|Y4gD5d>vz`INmATXo+}Z>8}#KSq;yz(9?NEgL_7)V_pPZv zF1E=+vS!DdWqp?3hToNM*Yn+1EINvfdbfLGqZhbNjOnY%i~8%efi^Gjj^vl4l{=p* zZYJ*bxjQw_qA40q&quiLTPQ`@ZR3+G)mhdae)n$&B}Za z3F8S+_r~%X4=$-vaA)|z>y@d)^caWPu#t+#FXPIBwOfN{3eRN2%!PLpN`?Zr1jptL za33zMr>tzi4{=;@*tiinch|$7;p~Kvj3y48TqF>c7VRO9m z)8ErfmfxOmUaT?7GXtUA40V6>7uY|zim=^Iaqi(mc*I_7Xp*T{(4nKjfk$bhLj8|DWzuDOXDLASx^FOFJPFwE5;>zAl- zI-ZQO=BCsCf5G%Dzo7hO@246pnQA~WvzCDCYH>Brd2&vV{2N}xnmGrC+Fc3b%Au)6 ze|5k+PNWmS^uP~~4I?~npdU1XE(=41C3brA(rWP{l%~eIAU_h)V4EdlqUV{{mjv4+ zr?8Vzb~Fvd2&Es8B*}GfYljZT)1|t6XTc9&XNW?=^}l;jz(s%?b+VAodZ_`y6eE8k zp}{mC<3jVm*F6$DE;k>XY{a}CQPRB){N`m(a!-Jq?G25#f4Rt_0N`W}oE`9!)3~)} zuv^GWAa2aC!lv|$!A`~fXsZ6O;i>-9L*G@fPSbG#P=k3VBax6GVXw#JGP4FsL=?8( zESaew>}v8?rQg|v*0NTLIoWd$YM>dbFjLB-qrHN8E*U*s;I4uZh`w1V*`NPvk@!^& zZter?dv!hA?e;1~S=XjqYqY4a{D*EnzqDNu#Pz{f`Vr2c_yNx1d4DHTR0)lgn^U`u zq6iH-mp@-O*5QGh=df`go~J6}f^QR>my7_f%5%g~|M}jKACN$zIS97D!fsY&1%HvS z{q^Zykkc?KeFFTm*7nPet4YaIBmLKD?F(|1y4j6~;VB1xPXd2FFX3EY&0n&m&vYFS z6$f;(HEN6yR=#YW+=gxJ>npjp(|x;)z|xas42io)SZcTv_FntVSM^EFiSao$*G!cs zwi0ENXx;6m%);zZ4MEydh+5_3gz`;H2W8 zu+HO0F%YQcjg6Vqg^$(-wpN3Jf6vJ_*sc{po2#Ls^WBLv5pQ-2lU2(k4pm|#GLr%=)s+BEDWV{3f)t!82Iey*QI%P zxN4u#p^B7&qe9;YFNf6YlZ{E|Jo1&A057d6Vr7F6gukYnCblMMH^n5CY1cU>U3e?K zZ$9l>*k1aCUU8-KT+^DXR;}KQK@-6|p&RkXF5w_S3)KwaPB&@XFPTGHi!7^kz*jj=CNc1XOsv&;W^c z`4SrY`xXUl(&8*$yvP)BmGe8m3rI*v%m>t4O(9O=A5pGpjA!73S`)pU{f$flwr{m)zre+yV|EcSNDJt4 z<&bkPd&Nl}jR|x{EmF^l?EO=}RwJbLyGO>xf=D`?^lF@((M`l&Kg7r!>6~6ht$?I< zutK<8>eqn%PN_h^Tg{Zw0iBjXp}*&?_t+5HdC(i;zBSHi5MHq&UA=}PUaTuNg`!RV zAvp~muP~p9t<|A6p+Tt%jpqB}XH-UK*K#S&BX7waU(ULkQEEE;MER+|^EqT)qGYlL z-jT3zhDX0RsXIdveQZ9*l{GY5>xirUgu>H%Ead(D*U}CU3}RGF0pV#@=a9MO8}J0w zXo4V#Z)!cL)&A{HhLkj4J0!Qvr#T*5=ye0)Felx8TOol#?k*r14eE_UkaU#cok{E4 zD7l6bx$Xm+iz>_PmMkjEWTo<1aJ&`n2B>|;b}CXKZRKP4og=vYv-yAEo5KEOdoE#S z6)r#H5Ct%0_-4|&bz}d~%S5Mm>wtC0PUdv!#`IqmO32Nt62LF68nmXaSD_an~J^mt9QDBR^JSyUFdd1 zjiY0ZQ%$x*<6m`&~!Hhm$Rl=Jvha$Z7=b88t$5dV-P;_O(jN{nL=p8CaS0d(p_MI*^it<%Y|qa9%=tG1hNJ zQ|W=_Ok{S_78=yli1MW+4}~KA29%r<^|U6w>2LenI124FsjGGzYa~lRKafPdVmXX= z3KlU96DVSc0PnVOqE(rre&!AP@i~JE$yaJEKa&}(Qh^Mr!VP*urf z@(ab<@*rP#a6LHz?SjGH;%{Q|1(&@9rJ9CwAgy#>r__QB3`|=Q8f(R`%j+H|iM7?z zD;R8=;vxGWQO47!gPOAXp~YWH|Mu;^MH=a!#xr*azP(A_TlmO9&UPbU8e`4SPKA8D zBrW1x*7WigV19t6oor=ltl8@lo`jeILRNVJ?VZ`7+8eA!v2Xf(VVwm+se+$tqN^bI zg2T~)fNF*j{OL0Udk!gas?#}E66 z0c}3YvN$)|%Sm6^^g8+P(tjAuvpGHmtb6MmC(XG`6=oUJ7zsdhMCniOoG>q^+NVI0 z8^CG<+&L*DWsuPy)>8ZF(i5J`pL7E(iEWsspugn;{?uAc9jG<_XW_vQ+^0p+qDlLE zMgal4spxDtA-Zrf0~@L2hr1g~Tm4M`^ZCk!l}^;|o(839dZnaTbg{=h(h2<&r>v@7 zF=@X$U`A^ncCuc_|2nuP*Gj=mT?;OHMRL<0zO*uSa6K)kKeN`T z9LhPeNnJW%tdNqt`2&H@Bm9aNL#z*bH^jSaf+DpXf4~mwYiRuZ<^m2q=`f7_5>~RM z2z5OBp&s`z>&xppzd-(dfTZdB>q6}IQvr6zpV855M9icib%p2&qF8@1Z?Bs|=rU%~ zxZruxEbG|ZnN><7Qyk`y<}|TS6B#cNo|yW)WPg3Me{r+R6-NIt7%I|+48*0uZZ>NSq7TL)xzExE{r1pXC0HX-eFuc68TqbjI9}d) zvb)W(HBqJ|deJOiUP(C&kAK~ z-FsbN68_;=dk>$jKjT1zW)h1C4ZOC0SvPUOr0X>V1LKar8gNe%k35GdSo2}&rK74^ zBq=woV(leMxsp(@wXF2^>BHRf*%%#v^tj865J;M5Rai!#I z?J(*qo}f+N2Jyg3^@hWu3K`BUEdmL)HbS3JX*Ja6ge4g?JlGa#kUNySu!gZfH}MF9 zilaMX8pV6glI6Pd744$$_`4&SFUF1bj)SBKaYW4ibf(q?$KoIHJ_df)J^AOW3GyhR zHG?5qD#s3nOT-+ls}d@2;E!%J0DYln^GM-YJ25!6eW`PQcZtiX3%BXT(@9<}Shuban8w)>_)njEQSfWx@O zyF@H>yU(UN)sen0K^nC^*fN&l7RXAgoAR@X3@ov~0J^a)UQMslZO718$2p9@j&4iW zaX;t;_u%`-`q@l6e6C5)tbRIviI~kEmEx(ujblG@+2R^uW1z)4|2e`&~bt>+AbAaQuKo(fgmYt6s!gUGduzam z+XD|m;R@LZKccW*w*W(5^3mS2Let<(*9NcE-GPhSCjjK^3r|zadV#cuaqs`cyk_Rb zvN6KhC z?o$hW_OWNlaW6&H@HdB(Clr)?(GQZT6H<9IMQpT|B(BwOZn8*}0#%re^Md-HViNeI zF@8h3r~uodrHg{a)4$JG+RXFSM1w)YSRL?MNZjPdASo>sF^keZ@lc^aBWF~R`Lode zT5Z99haR~5Cs@cIOmFt_<&KZ@qHcf{{U99!-_v-${rChQZf7BO&UD#{$b~Exw*C0P87h7{W|74l}!1kYrsQ=eJhp*7Jq5Nn7 zvs{(%sCge1We8wkt6LktTctw}0{*~o7I*&z%`hK;{=iEg5|<19JK#Mi5y#~Z@I94w z#;p6VYz+LWay37}?E9|(Rr0Xc0VQZbtzC diff --git a/autodist/docs/interface_design.md b/docs/autodist/interface_design.md similarity index 100% rename from autodist/docs/interface_design.md rename to docs/autodist/interface_design.md diff --git a/autodist/docs/solver_interface/partition_constraint.md b/docs/autodist/solver_interface/partition_constraint.md similarity index 100% rename from autodist/docs/solver_interface/partition_constraint.md rename to docs/autodist/solver_interface/partition_constraint.md diff --git a/autodist/docs/solver_interface/pc_examples/moe_pc.yaml b/docs/autodist/solver_interface/pc_examples/moe_pc.yaml similarity index 100% rename from autodist/docs/solver_interface/pc_examples/moe_pc.yaml rename to docs/autodist/solver_interface/pc_examples/moe_pc.yaml diff --git a/autodist/docs/solver_interface/pc_examples/retnet_dp2_pc.yaml b/docs/autodist/solver_interface/pc_examples/retnet_dp2_pc.yaml similarity index 100% rename from autodist/docs/solver_interface/pc_examples/retnet_dp2_pc.yaml rename to docs/autodist/solver_interface/pc_examples/retnet_dp2_pc.yaml diff --git a/autodist/docs/solver_interface/pc_examples/retnet_hybrid2_pc.yaml b/docs/autodist/solver_interface/pc_examples/retnet_hybrid2_pc.yaml similarity index 100% rename from autodist/docs/solver_interface/pc_examples/retnet_hybrid2_pc.yaml rename to docs/autodist/solver_interface/pc_examples/retnet_hybrid2_pc.yaml diff --git a/autodist/docs/solver_interface/pc_examples/retnet_mp2_pc.yaml b/docs/autodist/solver_interface/pc_examples/retnet_mp2_pc.yaml similarity index 100% rename from autodist/docs/solver_interface/pc_examples/retnet_mp2_pc.yaml rename to docs/autodist/solver_interface/pc_examples/retnet_mp2_pc.yaml diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 86860703..87ca38c0 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -194,6 +194,8 @@ def parallelize_graph(graph: IRGraph, for spmd_desc in pp_desc.spmd_descs: stage = [] for cid in spmd_desc.partition_descs: + if cid not in cid2node: + raise RuntimeError(f'node {cid} not found in {cid2node}, make sure the plan is correct') stage.append(cid2node[cid]) stages.append(stage) graph.staging([s[0] for s in stages]) diff --git a/tests/autodist/pas/all_replicated_pp.json b/tests/autodist/pas/all_replicated_pp.json index ab6a532b..285edb7c 100644 --- a/tests/autodist/pas/all_replicated_pp.json +++ b/tests/autodist/pas/all_replicated_pp.json @@ -4,7 +4,7 @@ { "partition_descs": [ [ - 3, + 2, [ [ [ @@ -16,7 +16,7 @@ ] ], [ - 4, + 3, [ [ [ @@ -37,7 +37,7 @@ { "partition_descs": [ [ - 5, + 4, [ [ [ @@ -49,7 +49,7 @@ ] ], [ - 6, + 5, [ [ [ diff --git a/tests/autodist/pas/replicated_and_partition.json b/tests/autodist/pas/replicated_and_partition.json index edcf8949..6a133938 100644 --- a/tests/autodist/pas/replicated_and_partition.json +++ b/tests/autodist/pas/replicated_and_partition.json @@ -4,7 +4,7 @@ { "partition_descs": [ [ - 3, + 2, [ [ [ @@ -16,7 +16,7 @@ ] ], [ - 4, + 3, [ [ [ @@ -37,7 +37,7 @@ { "partition_descs": [ [ - 5, + 4, [ [ [ @@ -49,7 +49,7 @@ ] ], [ - 6, + 5, [ [ [ diff --git a/tests/autodist/pas/test_shared_param_pipeline.py b/tests/autodist/pas/test_shared_param_pipeline.py index 31957240..370df6fb 100644 --- a/tests/autodist/pas/test_shared_param_pipeline.py +++ b/tests/autodist/pas/test_shared_param_pipeline.py @@ -42,24 +42,25 @@ def test_shared_param_pipeline(): model = Model(hidden_dim) model.train() + program = Program() + program.clear() IDGenerator().clear() - if idx > 0: - Program().clear() - smodel = nnscaler.SemanticModel(model, attr_savedir=tempdir) - smodel.dummy_input = {'x': torch.randn(bsz, hidden_dim)} - smodel.dynamic_shape = False dataloader = SemanticDataLoader( microbatches([{ 'x': torch.randn(bsz, hidden_dim) }])) - Program().set_input([dataloader.irobj]) + + smodel = nnscaler.SemanticModel(model, attr_savedir=tempdir) + smodel.dummy_input = {'x': torch.randn(bsz, hidden_dim)} + smodel.dynamic_shape = False + program.set_input([dataloader.irobj]) ir_dummy_input = next(dataloader) outputs = smodel(ir_dummy_input) outputs.backward() - Program().set_output([outputs]) - Program().finalize() - ir_graph = Program().get_graph() + program.set_output([outputs]) + program.finalize() + ir_graph = program.get_graph() print(ir_graph.nodes()) plan_path = Path(os.path.dirname(__file__)) / cfg_fname diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index c398ba17..6643628e 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -82,7 +82,7 @@ def test_follow_rope(): in future, we may add follow chains for binary ops, like mul, add, etc. ''' - cfg = AutoDistConfig(mesh_col=2) + cfg = AutoDistConfig(mesh_col=2, re_profile=True) model_graph = ModelGraph(ir_graph, cfg) spmd_solver = SPMDSolver( @@ -208,7 +208,7 @@ def test_follow_attention(): ''' pc_path = Path(os.path.dirname(__file__)) / 'test_attention_follow.yaml' - cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2) + cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2, re_profile=True) model_graph = ModelGraph(ir_graph, cfg) spmd_solver = SPMDSolver( diff --git a/tests/autodist/spmd_solver/test_partition_constraint.py b/tests/autodist/spmd_solver/test_partition_constraint.py index e8e65d16..da51bbe1 100644 --- a/tests/autodist/spmd_solver/test_partition_constraint.py +++ b/tests/autodist/spmd_solver/test_partition_constraint.py @@ -75,7 +75,7 @@ def test_partition_constraint(): pc_path = Path(os.path.dirname( os.path.realpath(__file__))) / 'test_pc.yaml' - cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2) + cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2, re_profile=True) model_graph = ModelGraph(ir_graph, cfg) spmd_solver = SPMDSolver( From 71b169ba3f720f58a7f8f163b49fa054a7021a63 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 7 May 2024 06:22:33 +0000 Subject: [PATCH 1630/1892] Merged PR 2128: parallel module: refine non-persistent buffer support parallel module: refine non-persistent buffer support by updating document and adding more warning message when loading from checkpoint --- .gitignore | 3 + docs/parallel_module.md | 51 +++++---- nnscaler/parallel.py | 18 ++++ nnscaler/runtime/module.py | 36 +++++++ .../parallel_module/test_checkpoint_buffer.py | 102 ++++++++++++++++-- tests/utils.py | 15 +++ 6 files changed, 192 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index d7d4e70d..b216a0d1 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ gencode* shelf *.iml *.xml + +# cppimport generated file +.rendered.*.cpp \ No newline at end of file diff --git a/docs/parallel_module.md b/docs/parallel_module.md index c35f7570..c11c0d96 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -297,21 +297,26 @@ class BroadcastGenFilesStrategy(Enum): ``` 1. `None`: nothing will be broadcasted. + You need to do it by yourself or the generated files are save in a shared directory (like azure blob). -2. `ALL`: broadcast all the generated files to all nodes. +2. `ALL`: broadcast all the generated files to all nodes (Recommended). + This is useful when you want to run the same code on all nodes. please note the init weight files can be huge. -3. `NO_WEIGHTS`: broadcast all except init weights. +3. `NO_WEIGHTS`: broadcast all except init weights (Only for experts). + Without weights, you can only construct the parallel module with `init_params=False`. You can then - - Load the weights from a checkpoint file with `module.load_state_dict`, `load_merged_state_dict` - or `load_deduped_state_dict` - - Or you can use `broadcast_weights` to get the weights from the workers in node0. - (local world size should be bigger than plan_ngpus) + - Safe way: you can use `broadcast_weights` to get the weights from the workers who have init weights. By default rank 0 will run the `parallelize` and store all the generated files. So if local world size is bigger than plan_ngpus, you can use `broadcast_weights` to get the weights from workers on node0. + - Risk Way: Load the weights from a checkpoint file with `module.load_state_dict`, `load_merged_state_dict` or `load_deduped_state_dict`. + + Please note: the non-persistent buffers will remain uninitialized after loading the checkpoints, + because they are not saved in the state dict. + To make sure all the buffers are initialized, you still need to set `init_params=True` to make sure non-persistent buffers are initialized if you want to initialize weights by loading a checkpoint. -4. `CODE`: broadcast the new generated code only +4. `CODE`: broadcast the new generated code only (Not recommeneded) It's your responsibility to make sure other necessary files are available on all nodes. Here are some guidelines to choose the strategy: @@ -325,9 +330,9 @@ a. If use `none`, the user should run `parallelize(..., load_module=False, ..)`, b. if they are using a NAS-like device to save generated files, and the upload/download speed is fast in the cluster, they can also use `none`, and just run `parallelize(..., load_module=True, ..)` to do the training. -c. If use `all`, then user can just run `parallelize(..., load_module=True, ..)` safely. (remember to set `nccl` communication timeout to a very big value to tolerate the duration of this `nccl` broadcast) +c. If use `all`, then user can just run `parallelize(..., load_module=True, ..)` safely. (remember to set `nccl` communication timeout to a very big value to tolerate the duration of this `nccl` broadcast). This is the most recommended way. -d. If use `no_weights`. then user can run `parallelize(..., load_module=True, init_module_params=rank None: f.unlink() +def _broadcast_single_value(src_rank, group, obj=None): + sent_obj = [obj] + torch.distributed.broadcast_object_list( + sent_obj, + src=src_rank, + group=group, + ) + return sent_obj[0] + + _DEFAULT_INSTANCE_NAME = '_' _GENCODE_FILE_PREFIX = 'gencode' _GENCODE_FILE_TEMPLATE = _GENCODE_FILE_PREFIX + '{}.py' # 'gencode{}.py' @@ -2178,6 +2189,10 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks logging.info(f'Rank-{rank} is broadcasting weight to ranks {curr_parallel_group_ranks}, broadcast root: {src_rank}...') + if isinstance(module, ParallelModule): + if not _broadcast_single_value(src_rank, curr_parallel_group, module.non_presistent_buffers_inited): + module._warn_uninitialized_non_persistent_buffers(raise_error=True) + # we have a special optimization for ParallelModule params = module.parameters_for_broadcast() if isinstance(module, ParallelModule) else module._parameters.values() logging.info(f'Inplace broadcasting {len(params)} parameters...') @@ -2191,4 +2206,7 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): for _, buffer in module._buffers.items(): torch.distributed.broadcast(buffer.data, src=src_rank, group=curr_parallel_group) + if isinstance(module, ParallelModule): + module.mark_non_persistent_buffers_inited() + torch.distributed.barrier() diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 5b385b6c..5bc36a6d 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -623,6 +623,29 @@ def __init__(self): # this is a lazy initialization, # which will be initialized in the first call of `clip_gnorm` self._nreplicas2localparams: Optional[Dict[int, List[torch.nn.Parameter]]] = None + # track whether all the parames (especially the non-persistent buffers) have been initialized + self._non_presistent_buffers_inited = False + + @property + def non_presistent_buffers_inited(self): + return self._non_presistent_buffers_inited + + def mark_non_persistent_buffers_inited(self): + self._non_presistent_buffers_inited = True + + def _warn_uninitialized_non_persistent_buffers(self, raise_error = False): + _non_persistent_buffers_load_warning = ( + "Non-persistent buffers cannot be initialized with load_[/merged/dedupped]state_dict. " + "Please be sure to you will initialize them manually. " + ) + _non_persistent_buffers_load_error = ( + "Non-persistent buffers haven't been initialized." + ) + if not self._non_presistent_buffers_inited: + if raise_error: + raise RuntimeError(_non_persistent_buffers_load_error) + else: + _logger.warning(_non_persistent_buffers_load_warning) def _post_init(self, init_params=True): """ @@ -638,10 +661,14 @@ def _post_init(self, init_params=True): # if dist.is_initialized() and self.rank != dist.get_rank(): # raise RuntimeError(f"The rank to load this module file name is expected to be {self._rank}, but got {dist.get_rank()}") + self._non_presistent_buffers_inited = init_params or not self._non_persistent_buffers_set module_file = Path(sys.modules[self.__module__].__file__) self.module_dir = module_file.parent if init_params: self.load_attr_content(str(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE_STEM}"))) + + self._warn_uninitialized_non_persistent_buffers() + self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}")) self._compute_config: 'ComputeConfig' = torch.load(module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}")) self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}")) @@ -659,6 +686,7 @@ def _post_init(self, init_params=True): self._register_load_state_dict_pre_hook(ParallelModule._pre_load_state_dict_hook, with_module=True) def forward(self, *args, **kwargs): + self._warn_uninitialized_non_persistent_buffers(raise_error=True) if self.training: self._sync_grad_required = True # mark sync_grad() can be called again return self._forward_impl(*args, **kwargs) @@ -760,6 +788,8 @@ def train_step(self, Results: List[Any]: a list of outputs for each sample """ + self._warn_uninitialized_non_persistent_buffers(raise_error=True) + if not self.compute_config.use_end2end: raise RuntimeError("train_step() is only supported in end2end mode") if is_dummy_batch and len(samples) != len(is_dummy_batch): @@ -798,6 +828,8 @@ def infer_step(self, samples: List[Any]) -> List[Any]: Results: List[Any]: a list of outputs for each sample """ + self._warn_uninitialized_non_persistent_buffers(raise_error=True) + if not self.compute_config.use_end2end: raise RuntimeError("infer_step() is only supported in end2end mode") @@ -965,6 +997,8 @@ def _post_state_dict_hook(self, state_dict, prefix, local_metadata) -> None: def _pre_load_state_dict_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None: self._remove_extra_state(state_dict, prefix) + # Both load_state_dict and load_deduped_state_dict will trigger this hook + self._warn_uninitialized_non_persistent_buffers() @property def module_dedup_group_size(self) -> int: @@ -1047,3 +1081,5 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s raise RuntimeError(erro_msg) else: _logger.warning(erro_msg) + + self._warn_uninitialized_non_persistent_buffers() diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index 44d0c9fa..d8d603ff 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -3,11 +3,13 @@ import torch import pytest +import torch.distributed -from nnscaler.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dicts +from nnscaler.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dicts, broadcast_weights from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun +from ..utils import catch_log class Net1(torch.nn.Module): @@ -32,14 +34,26 @@ def forward(self, x): return self.fc(x + self.buffer) -def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_shape): +class Net3(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer('buffer', torch.ones(128, 64), persistent=True) + self.fc = torch.nn.Linear(64, 64) + + # x with shape [128, 64] + def forward(self, x): + return self.fc(x + self.buffer) + + +def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_shape, init_module_params=True): return parallelize( module, {'x': torch.randn(input_shape)}, PASRandomSPMD, compute_config, cube_savedir=cube_savedir, - instance_name=instance_name + instance_name=instance_name, + init_module_params=init_module_params ) @@ -56,10 +70,48 @@ def _gpu_worker(): net2 = Net2() net2.load_state_dict(merged_state_dict, strict=False) # should success - net2 = _to_cube_model(Net2(), compute_config, tempdir, 'net2', (256, 64)) - net2.load_merged_state_dict(merged_state_dict, strict=False) # should success + from nnscaler.runtime.module import _logger + with catch_log(_logger) as log_stream: + net2 = _to_cube_model(Net2(), compute_config, tempdir, 'net2', (256, 64)) + net2.load_merged_state_dict(merged_state_dict, strict=False) # should success + assert torch.equal(list(net2._buffers.values())[0], torch.ones(256, 64)) + + logs = log_stream.getvalue() + assert not 'Non-persistent buffers cannot be initialized with' in logs + + with catch_log(_logger) as log_stream: + net2 = _to_cube_model(Net2(), compute_config, tempdir, 'net2-2', (256, 64), init_module_params=False) + net2.load_merged_state_dict(merged_state_dict, strict=False) # should success + assert not torch.equal(list(net2._buffers.values())[0], torch.ones(256, 64)) + + logs = log_stream.getvalue() + assert 'Non-persistent buffers cannot be initialized with' in logs + + net3 = _to_cube_model(Net3(), compute_config, tempdir, 'net3', (128, 64)) + cube_state_dict = net3.state_dict() + assert any(key.startswith('buffer') for key in cube_state_dict) + merged_state_dict, _ = merge_state_dicts([cube_state_dict]) + assert 'buffer' in merged_state_dict + + net3 = Net3() + net3.load_state_dict(merged_state_dict, strict=False) # should success + assert torch.equal(net3.buffer, torch.ones(128, 64)) + + with catch_log(_logger) as log_stream: + net3 = _to_cube_model(Net3(), compute_config, tempdir, 'net3-2', (128, 64)) + net3.load_merged_state_dict(merged_state_dict, strict=False) # should success + assert torch.equal(list(net3._buffers.values())[0], torch.ones(128, 64)) - assert True + logs = log_stream.getvalue() + assert not 'Non-persistent buffers cannot be initialized with' in logs + + with catch_log(_logger) as log_stream: + net3 = _to_cube_model(Net3(), compute_config, tempdir, 'net3-2', (128, 64), init_module_params=False) + net3.load_merged_state_dict(merged_state_dict, strict=False) # should success + assert torch.equal(list(net3._buffers.values())[0], torch.ones(128, 64)) + + logs = log_stream.getvalue() + assert not 'Non-persistent buffers cannot be initialized with' in logs @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @@ -68,3 +120,41 @@ def test_checkpoint_buffer(): Please note the buffer size in Net1 and Net2 are different. """ launch_torchrun(1, _gpu_worker) + + +def _gpu_worker_broadcast(): + init_distributed() + compute_config = ComputeConfig(1, 2, use_zero=False) + rank = torch.distributed.get_rank() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_broadcast_fail') as tempdir: + net1 = _to_cube_model(Net1(), compute_config, tempdir, 'net1', (128, 64), init_module_params=False) + with pytest.raises(RuntimeError, match="Non-persistent buffers haven't been initialized."): + broadcast_weights(net1) + + with pytest.raises(RuntimeError, match="Non-persistent buffers haven't been initialized."): + net1(torch.randn(128, 64)) + + net1 = _to_cube_model(Net1(), compute_config, tempdir, 'net1-2', (128, 64), + init_module_params=rank < 1 + ) + + if rank == 0: + assert net1.non_presistent_buffers_inited + assert torch.equal(list(net1._buffers.values())[0], torch.ones(128, 64)) + else: + assert not net1.non_presistent_buffers_inited + assert not torch.equal(list(net1._buffers.values())[0], torch.ones(128, 64)) + + broadcast_weights(net1) + assert net1.non_presistent_buffers_inited + assert torch.equal(list(net1._buffers.values())[0], torch.ones(128, 64)) + + net1(torch.randn(128, 64)) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_checkpoint_buffer_broadcast(): + """ + Please note the buffer size in Net1 and Net2 are different. + """ + launch_torchrun(2, _gpu_worker_broadcast) diff --git a/tests/utils.py b/tests/utils.py index 906d771f..2170ecec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -316,3 +316,18 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +@contextmanager +def catch_log(_logger, loglevel='DEBUG'): + import logging + from io import StringIO + string_stream = StringIO() + old = _logger.level + _logger.setLevel(loglevel) + handler = logging.StreamHandler(string_stream) + handler.setLevel(loglevel) + _logger.addHandler(handler) + yield string_stream + _logger.removeHandler(handler) + _logger.setLevel(old) From 7abb89ab8f6f74027732d83e19388ad810c14575 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 7 May 2024 06:39:17 +0000 Subject: [PATCH 1631/1892] Merged PR 2131: refine nnscaler interface refine nnscaler interface 1. Clear the functions in nnscaler.__init__ 2. rename register => register_op 3. rename first argument of register_op to annotation TODO: convert all examples to parallelize interface --- docs/parallel_module.md | 4 +- docs/register_custom_op.md | 234 ++++++++++++++++++ examples/alphafold2/alphafold2.py | 7 +- examples/alphafold2/module.py | 48 ++-- examples/llama/chat.py | 5 +- examples/llama/generation.py | 10 +- examples/llama/model.py | 10 +- examples/llama/test_chat_completion.py | 3 +- examples/megatron_gpt/convert.py | 4 +- examples/megatron_gpt/parallel.py | 4 +- examples/mlp/train.py | 3 +- examples/nlp/blocks/attention.py | 10 +- examples/nlp/blocks/mlp.py | 2 +- examples/nlp/gpt/train.py | 14 +- examples/nlp/mbart/train.py | 14 +- examples/openfold/blocks/attention.py | 18 +- examples/openfold/blocks/embedder.py | 12 +- examples/openfold/blocks/evoformer.py | 20 +- examples/openfold/blocks/opm.py | 6 +- examples/openfold/blocks/tmu.py | 22 +- examples/openfold/blocks/utils.py | 2 +- examples/openfold/train.py | 7 +- examples/vision/swin/blocks/attention.py | 4 +- examples/vision/swin/blocks/mlp.py | 2 +- examples/vision/swin/blocks/patch.py | 6 +- examples/vision/swin/blocks/transformer.py | 6 +- examples/vision/swin/train.py | 9 +- nnscaler/__init__.py | 70 +++--- nnscaler/compiler.py | 8 +- nnscaler/graph/parser/register.py | 29 ++- nnscaler/utils.py | 13 + .../pas/test_shared_param_pipeline.py | 4 +- .../spmd_solver/test_cube_operator.py | 2 +- tests/compiler/test_compile.py | 28 ++- tests/graph/parser/test_register.py | 8 +- tests/graph/test_multiref.py | 20 +- tests/runtime/test_gnorm.py | 10 +- tests/runtime/test_grad_accum.py | 19 +- tests/runtime/test_module_merge.py | 10 +- tests/runtime/test_reducer.py | 18 +- tutorial.md | 12 +- 41 files changed, 509 insertions(+), 228 deletions(-) create mode 100644 docs/register_custom_op.md diff --git a/docs/parallel_module.md b/docs/parallel_module.md index c11c0d96..304871f8 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -18,7 +18,7 @@ The above restrictions are necessary for the pipeline parallelism to work. Of co ```python import torch -from nnscaler.parallel import parallelize, ComputeConfig, build_optimizer +from nnscaler import parallelize, ComputeConfig, build_optimizer class LLM(torch.nn.Module): def __init__(self, ...): @@ -48,7 +48,7 @@ In this case, for non-paralle modules, they are replicated inside unit, and run ```python import torch -from nnscaler.parallel import parallelize, ComputeConfig, build_optimizer +from nnscaler import parallelize, ComputeConfig, build_optimizer class HeavyModule(torch.nn.Module): def __init__(self, ...): diff --git a/docs/register_custom_op.md b/docs/register_custom_op.md new file mode 100644 index 00000000..85cb4928 --- /dev/null +++ b/docs/register_custom_op.md @@ -0,0 +1,234 @@ +# Register a new operator/function + +## Overview + +During iterating the model, users may encounter the situation that some operator is failed to be traced by nnscaler concrete tracer. In this case, users can register this operator to nnscaler, then nnscaler will treat it as one simple operator instead of tracing into the sub-operators of this operator. The registration also tells nnscaler the feasible partition options of this operator. + +Note, the registration only works for function while does not work for PyTorch Module, because nnscaler does not allow weight tensors to be managed by the registered operator. If you are dealing with a PyTorch Module, you can consider its underlying PyTorch function instead. + +Taking `torch.nn.InstanceNorm2d` (or actually `torch.nn.functional.instance_norm`) as an example. Currently nnscaler does not support partitioning of this operator. If you use this operator in your model, you will see a warning message "Find unknown pytorch operation: torch.xxx.xxx" or "Set python runtime function: xxx". Then you can register this operator into nnscaler and specify its partition options as follows: + +```python +import torch +import torch.nn as nn +import nnscaler +# you write a new function to wrap the operator. +# suggest to make all the argument of this function torch.Tensor, +# and *REMEMBER* to add type annotation for each input argument. + +# the first argument of register is the annotation to indicate how this function can be partitioned, +# very similar to einsum expression. '^' means the corresponding dimension cannot be partitioned. +@nnscaler.register_op('n c h^ w^, c, c -> n c h^ w^', name='my_instance_norm') +def my_instance_norm(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): + return nn.functional.instance_norm(input, weight=weight, bias=bias) +``` + +Here is another example to support a custom matmul operator: +```python +# file: custom_ops.py +def operator(x: torch.Tensor, w: torch.Tensor, h: float) -> torch.Tensor: + out = torch.matmul(x, w) + out = out.view(h, out.size(0) // h, out.size(1)) + return out + +# file: main.py +import nnscaler +from custom_ops import operator +nnscaler.register_op('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom')(operator) +``` + + +## Api Explains + +```python +def register_op( + annotation: Union[str, Callable], + name: Optional[str] = None, + code_impl_pattern: str = 'import' +) -> Callable: + ... +``` + +Register a function with dimension annotations. + +This function is cooperated with nnscaler tracer. Users can only register global functions(which are defined in a module level, instead of ones defined inside a function / class or __main__ scope). + +The annotation (`annotation`) specifies the number of inputs as `*args`, +and treat all the rest inputs as `**kwargs`. + +For tensor-type inputs, the annotation should be a string of identifiers separated by space, e.g., `'a b'`; +For non-tensor-type inputs, the annotation should be specified '?'. + +This function can be used as a decorator or a function. +Here are several Examples: + +```python +import nnscaler +from third_party import func + +nnscaler.register_op('a (b c) -> (a b) c')(func) +``` + +or, + +```python +import nnscaler +from third_party import func + +@nnscaler.register_op('a (b c) -> (a b) c') +def func(x, b = 4): + ... +``` + +or, + +```python +import nnscaler +from third_party import func + +def anno_fn(*inputs, **kwargs): + return 'a (b c) -> (a b) c' + +nnscaler.register_op(anno_fn)(func) +``` +This function has the following parameters: + +- `annotation` (`str | Callable`): operator annotation, it can be: + - op annotation: e.g., 'a (b c) -> (a b) c' + - a callable function that generates op annotation (str). The function + taks inputs and kwargs as arguments and returns the operator annotation. +- `name` (`str | None`): operator name. Only usable when `node_repr` is a string. +- `code_impl_pattern` (`str`): It can only be 'import' or 'source'. If 'import' (default), will generate code with import statement. If 'source', will take the source code directly. + + +## Dimension Annotion Operations + +An operator has (multiple) input tensors and (multiple) output tensors. +Each tensor can be annotated with dimension annotations (DimAnno) using `identifiers`. +The same `identifier` indicates the they have the same real length. + +### Dimension Annotation + + e.g., 'a+', 'ab^', 'cd', '(ab+ c^ d)', '64' + +A dimension of a tensor can be annotated by {identifier}{reduction} template. + +An `identifier` must be one of: + 1) symbolic annotation that must match with the criteria of python str.isidentifier. + 2) numeric string that must match with python str.isdecimal. This indicates the shape is the same value + numeric string will always have '^' reduction type' + +Special identifier: + 1) '*': this special identifier indicates the dimension is dynamic, which will automatically get expanded given the shape + 2) '?': this special identifier indicates the value is can only be replicated, no matter it is a tensor or a non-tensor. + +A `reduction` can be a set of {'', '+', '^'}: + '' indicates this dimension can be partitioned, and each output should have this dimension. + '+' indicates this dimension can be partitioned, and each output doesn't have this and need to do sum-reduction. + '^' means this dimension cannot be partitioned. + +A dimension can also be annotated with inner-dimensions using brackets, i.e., '(' and ')'. +The value of inner dimension needs to be inferrable, or indicated by function args (of same name). + +### Shape Annotation + + e.g., 'a (c+ d^) e' + +A shape annotation consists of dimension annotation separated by (multiple) spaces. + + +### Operator Annotation + + e.g., 'm k+, n k+ -> m n', '4 k+, k+ d -> 8 d', '* d^, s -> * s' + + An operator can be annotated with input shape annotations and output shape annotations. + + '->' seperates the inputs (left) and outputs (right) and ',' separates each input and output tensor. + + Identifiers in output tensor annotation needs to be + 1) apearred in input tensor annotations + 2) using numeric string + +### Operator Partitioning Rule: + + 1) Spatial Partition (dimension with '' reduce type): + tensors can be uniformly partitioned on dimensions having spatial reduction type. + other tensors in the operator that don't have this dimension will be replicated. + + 2) Value Partition (dimension with '+' reduce type): + * tensors can be uniformly partition on dimensions having numerical reduction type + * other tensors in the the operator that don't have this dimension will be partitioned numerically. + + 3) Illegal Splitting (dimension with '^' reduce type): + * tensors can not be partitioned on dimensions having '^' reduction type. + +### Hidden dimension + + Sometimes user need to reshape the tensor by splitting a dimension into multiple dimensions. For example, a tensor of (1024, 8) size needs to be reshaped into the shape of (8, 128, 8): + + ```python + # annotation: (h t) k -> h t k + def reshape(tensor: torch.Tensor, h : int = 8) -> torch.Tensor: + out = tensor.reshape(h, tensor.size(0) // h, tensor.size(-1)) + return out + ``` + + This can be represented by annotating a dimension using brackets `()`. The bracket contains multple identifiers (and their reductions), like `'(h t)'` here for the first dimension of the input tensor. To help system infer the number of `h` and `t` in the annotation, the function requires to put in a same-named argument `h` or `t` (`h=8` here in example). + +## Inplace Operators + +We assume the module is SSA (static single-assignment), which means you should avoid change the input tensors inplace in your custom operators. + +However, if you have to do this, it's your responsibility to make sure the inplace operation is correct. And to help us track the dependencies between tensors, you must return all the input tensors that are changed in the custom operators. + +```python +# this is wrong +# because x is changed inplace, but it is not returned. +@nnscaler.register_op('*, * -> *', name='inplace_operator) +def inplace_operator(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x.add_(y) + z = x + y + return z + +# this is correct +@nnscaler.register_op('*, * -> *, *', name='inplace_operator) +def inplace_operator(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x.add_(y) + z = x + y + return x, z +``` + +## Optional Tensor Input + +If you have an optional tensor input, you should tell `nnscaler` how this optional tensor will be used. + +There are two cases: + +1. The optional tensor can only be replicated. In this case, you should use '?' as the identifier. + +```python +@nnscaler.register_op('a^ b^, ? -> a^ b^', name='optional_tensor') +def optional_op(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor: + out = torch.triu(x) + if y is not None: + out += y + return out +``` + +2. The optional tensor can be partitioned. In this case, you should use an annotation function to tell `nnscaler` how to partition the optional tensor when it is not None. + +```python +def anno_fn(*inputs, **kwargs): + if inputs[1] is None: + return '*, ? -> *' + else: + return '*, * -> *' + +@nnscaler.register_op(anno_fn, name='optional_tensor') +def optional_op(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor: + if y is None: + return x + return x + y +``` + +Please note the value of the optional tensor should be consistent in runtime and tracing time. Which mean if the value of the optional tensor is `None` in tracing time, it should always be `None` in runtime, and if the value of optional tensor is not `None` in tracing time, it should always not be `None` in runtime. It may cause runtime error if the consistency is not guaranteed. diff --git a/examples/alphafold2/alphafold2.py b/examples/alphafold2/alphafold2.py index 3b68484a..10e70ce6 100644 --- a/examples/alphafold2/alphafold2.py +++ b/examples/alphafold2/alphafold2.py @@ -3,6 +3,7 @@ import nnscaler from nnscaler.profiler import CudaTimer +from nnscaler.compiler import compile, SemanticModel from nnscaler.profiler.timer import print_each_rank from examples.alphafold2.model import * @@ -14,7 +15,7 @@ from nnscaler.graph.function.anchor import IRGraphAnchor - + def build_alphafold_config(setting:int): assert setting in [1, 2, 3], "setting should be in [1, 2, 3]." # dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False @@ -44,7 +45,7 @@ def run(size_config, other_config, policy): if not is_train: model.eval() - model = nnscaler.SemanticModel(model, + model = SemanticModel(model, input_shapes=([bs, s, r, cm], [bs, r, r, cz])) dataloader = nnscaler.runtime.syndata.SynDataLoader(shapes=([bs, s, r, cm], @@ -52,7 +53,7 @@ def run(size_config, other_config, policy): dtypes=(dtype, dtype), batch_dims=(0, 0)) - @nnscaler.compile(model, dataloader, PAS=policy, override=True) + @compile(model, dataloader, PAS=policy, override=True) def train_iter(model, dataloader): msa_repr, pair_repr = next(dataloader) loss = model(msa_repr, pair_repr) diff --git a/examples/alphafold2/module.py b/examples/alphafold2/module.py index 2e20fc93..b6de9849 100644 --- a/examples/alphafold2/module.py +++ b/examples/alphafold2/module.py @@ -3,7 +3,7 @@ import torch.utils.checkpoint as ckpt -@nnscaler.graph.parser.register('*, *, * -> *, *, *, *', name='calc_qkvg') +@nnscaler.register_op('*, *, * -> *, *, *, *', name='calc_qkvg') def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, bs: int, s: int, r: int, head: int, c: int): gate = torch.sigmoid(torch.matmul(x, gate_proj)) @@ -23,7 +23,7 @@ def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, """ -@nnscaler.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R^ M^', +@nnscaler.register_op('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R^ M^', name='MSAAttention') @torch.jit.ignore def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, @@ -91,7 +91,7 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return out -@nnscaler.graph.parser.register('N S R^ M^, M^ E^, M^ F^, E^ M^, N 1^ 8^ R^ R^ -> N S R^ M^', +@nnscaler.register_op('N S R^ M^, M^ E^, M^ F^, E^ M^, N 1^ 8^ R^ R^ -> N S R^ M^', name='MSAAttentionWithBias') @torch.jit.ignore def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, @@ -177,7 +177,7 @@ def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # note: code not reused constrained by cube's interface -@nnscaler.graph.parser.register('N S R^ M^, N R^ R^ Z^, M^ E^, M^ F^, E^ M^, Z^ H^ -> N S R^ M^', +@nnscaler.register_op('N S R^ M^, N R^ R^ Z^, M^ E^, M^ F^, E^ M^, Z^ H^ -> N S R^ M^', name='MSARowAttentionWithPairBias') def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, pair_repr: torch.Tensor, @@ -196,7 +196,7 @@ def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, head, c, scale, chunk_size, is_train) -@nnscaler.graph.parser.register('N S^ R M^, M^ E^, M^ F^, E^ M^ -> N S^ R M^', +@nnscaler.register_op('N S^ R M^, M^ E^, M^ F^, E^ M^ -> N S^ R M^', name='MSAColAttention') def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, @@ -207,7 +207,7 @@ def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, is_train).permute(0, 2, 1, 3) -@nnscaler.graph.parser.register('N S^ R^ M^, M^ M^, M^ E^, M^ E^, M^ M^, M^ M^ -> N S^ R^ M^', +@nnscaler.register_op('N S^ R^ M^, M^ M^, M^ E^, M^ E^, M^ M^, M^ M^ -> N S^ R^ M^', name='MSAColGlobalAttention') def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor, @@ -250,7 +250,7 @@ def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, """ -@nnscaler.graph.parser.register('N S R M^, M^ E^, E^ M^ -> N S R M^', +@nnscaler.register_op('N S R M^, M^ E^, E^ M^ -> N S R M^', name='MSATransition') def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -258,12 +258,12 @@ def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, torch.nn.functional.relu(torch.matmul(msa_repr, proj1)), proj2) -@nnscaler.graph.parser.register('N S R M^, M^ C^ -> N S R C^', name='OPMLeftProj') +@nnscaler.register_op('N S R M^, M^ C^ -> N S R C^', name='OPMLeftProj') def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): return torch.matmul(msa_repr, proj) -@nnscaler.graph.parser.register('N S R M^, M^ C^ -> N S R C^', name='OPMRightProj') +@nnscaler.register_op('N S R M^, M^ C^ -> N S R C^', name='OPMRightProj') def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): return torch.matmul(msa_repr, proj) @@ -273,7 +273,7 @@ def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): """ -@nnscaler.graph.parser.register('N S^ R M^, N S^ T^ M^, F^ Z^ -> N R^ T Z^', +@nnscaler.register_op('N S^ R M^, N S^ T^ M^, F^ Z^ -> N R^ T Z^', name='OuterProductMean') @torch.jit.ignore def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, @@ -308,7 +308,7 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): return outer -@nnscaler.graph.parser.register('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMOLeftProj') +@nnscaler.register_op('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMOLeftProj') def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): a = torch.sigmoid(torch.matmul(pair_repr, proj1)) @@ -316,7 +316,7 @@ def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return b -@nnscaler.graph.parser.register('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', +@nnscaler.register_op('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMORightProj') def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -325,12 +325,12 @@ def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return b -@nnscaler.graph.parser.register('N S T^ Z^, Z^ Z^ -> N S T^ Z^', name='TMOGate') +@nnscaler.register_op('N S T^ Z^, Z^ Z^ -> N S T^ Z^', name='TMOGate') def TMOGate(pair_repr: torch.Tensor, proj: torch.Tensor): return torch.sigmoid(torch.matmul(pair_repr, proj)) -@nnscaler.graph.parser.register('N S R^ E^, N T^ R^ E^, N S T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', +@nnscaler.register_op('N S R^ E^, N T^ R^ E^, N S T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='TriangleMultiplicationOut') def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, @@ -347,7 +347,7 @@ def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, return p * g -@nnscaler.graph.parser.register('N R^ S Z^, Z^ E^, Z^ E^ -> N R^ S E^', name='TMILeftProj') +@nnscaler.register_op('N R^ S Z^, Z^ E^, Z^ E^ -> N R^ S E^', name='TMILeftProj') def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): a = torch.sigmoid(torch.matmul(pair_repr, proj1)) @@ -355,7 +355,7 @@ def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return a -@nnscaler.graph.parser.register('N R^ T Z^, Z^ E^, Z^ E^ -> N R^ T E^', +@nnscaler.register_op('N R^ T Z^, Z^ E^, Z^ E^ -> N R^ T E^', name='TMIRightProj') def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -364,12 +364,12 @@ def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, return a -@nnscaler.graph.parser.register('N S^ T Z^, Z^ Z^ -> N S^ T Z^', name='TMIGate') +@nnscaler.register_op('N S^ T Z^, Z^ Z^ -> N S^ T Z^', name='TMIGate') def TMIGate(pair_repr: torch.Tensor, proj: torch.Tensor): return torch.sigmoid(torch.matmul(pair_repr, proj)) -@nnscaler.graph.parser.register('N R^ S E^, N R^ T^ E^, N T^ S Z^, E^, E^, E^ Z^ -> N T^ S Z^', +@nnscaler.register_op('N R^ S E^, N R^ T^ E^, N T^ S Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='TriangleMultiplicationIn') def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, tri_mul_norm2_weight: torch.Tensor, @@ -385,12 +385,12 @@ def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, return p.permute(0, 2, 1, 3) * g -@nnscaler.graph.parser.register('N S R^ C^, C^ D^ -> N S R^ D^', name='TANSBias') +@nnscaler.register_op('N S R^ C^, C^ D^ -> N S R^ D^', name='TANSBias') def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): return torch.matmul(pair_repr, bias_proj) -@nnscaler.graph.parser.register('N S R^ Z^, Z^ E^, Z^ F^, E^ Z^, N T^ R^ G^ -> N S R^ Z^', +@nnscaler.register_op('N S R^ Z^, Z^ E^, Z^ F^, E^ Z^, N T^ R^ G^ -> N S R^ Z^', name='TriangleAttentionNodeStart') def TriangleAttentionNodeStart(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, @@ -404,12 +404,12 @@ def TriangleAttentionNodeStart(pair_repr: torch.Tensor, head, c, scale, chunk_size, is_train) -@nnscaler.graph.parser.register('N S^ R C^, C^ D^ -> N S^ R D^', name='TANEBias') +@nnscaler.register_op('N S^ R C^, C^ D^ -> N S^ R D^', name='TANEBias') def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): return torch.matmul(pair_repr, bias_proj) -@nnscaler.graph.parser.register('N R^ S Z^, Z^ E^, Z^ F^, E^ Z^, N R^ T^ G^ -> N R^ S Z^', +@nnscaler.register_op('N R^ S Z^, Z^ E^, Z^ F^, E^ Z^, N R^ T^ G^ -> N R^ S Z^', name='TriangleAttentionNodeEnd') def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, @@ -424,7 +424,7 @@ def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, return out.permute(0, 2, 1, 3) -@nnscaler.graph.parser.register('N R T^ Z^, Z^ E^, E^ Z^ -> N R T^ Z^', +@nnscaler.register_op('N R T^ Z^, Z^ E^, E^ Z^ -> N R T^ Z^', name='PairTransition') def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): @@ -432,6 +432,6 @@ def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, torch.nn.functional.relu(torch.matmul(pair_repr, proj1)), proj2) -@nnscaler.graph.parser.register('* -> *, *', name='multi2ref') +@nnscaler.register_op('* -> *, *', name='multi2ref') def multi2ref(x: torch.Tensor): return (x, x) diff --git a/examples/llama/chat.py b/examples/llama/chat.py index a9de661b..9e5925b9 100644 --- a/examples/llama/chat.py +++ b/examples/llama/chat.py @@ -18,9 +18,10 @@ from examples.llama.generation import Llama import nnscaler +from nnscaler.utils import set_default_logger_level nnscaler.init() -nnscaler.set_logger_level(level=logging.WARNING) +set_default_logger_level(level=logging.WARNING) logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) @@ -43,7 +44,7 @@ def main( ) dialog = [ - {"role": "system", "content": + {"role": "system", "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature."}, ] diff --git a/examples/llama/generation.py b/examples/llama/generation.py index 92558df8..52b12b0b 100644 --- a/examples/llama/generation.py +++ b/examples/llama/generation.py @@ -17,6 +17,8 @@ Role = Literal["system", "user", "assistant"] import nnscaler +from nnscaler.compiler import compile +from nnscaler.utils import load_model from nnscaler.flags import CompileFlag @@ -94,7 +96,7 @@ def build( def __init__(self, model: Transformer, tokenizer: Tokenizer, use_cube: bool): self.model = model self.tokenizer = tokenizer - + # ======================= cube initilizer ================= self.use_cube = use_cube if use_cube: @@ -113,17 +115,17 @@ def policy(graph, resource): graph.assign(fwop, 0) return graph - @nnscaler.compile(self.model, sample_tokens, 0, + @compile(self.model, sample_tokens, 0, PAS=policy, model_dynamic_shape=True) def infer(model: torch.nn.Module, tokens: torch.Tensor, prev_pos: int): logits = model(tokens, prev_pos) return logits - + params = self.model.params vocab_size, n_layers = params.vocab_size, params.n_layers del self.model - self.model = nnscaler.load_model() + self.model = load_model() # TODO: support auto reset non-parameter attributes for llama model self.model.params = params diff --git a/examples/llama/model.py b/examples/llama/model.py index 846ad825..64ae7a3b 100644 --- a/examples/llama/model.py +++ b/examples/llama/model.py @@ -58,7 +58,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): # TODO: fix annotation -@nnscaler.graph.parser.register('*, *, 38^ 64^ -> *, *') +@nnscaler.register_op('*, *, 38^ 64^ -> *, *') def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, @@ -72,7 +72,7 @@ def apply_rotary_emb( return xq_out.type_as(xq), xk_out.type_as(xk) -@nnscaler.graph.parser.register('N seqlen^, N seqlen^ H^ -> 1 1 seqlen^ seqlen^') +@nnscaler.register_op('N seqlen^, N seqlen^ H^ -> 1 1 seqlen^ seqlen^') def create_mask(tokens: torch.Tensor, h: torch.Tensor, start_pos: int): seqlen = tokens.shape[1] mask = None @@ -84,7 +84,7 @@ def create_mask(tokens: torch.Tensor, h: torch.Tensor, start_pos: int): return mask -@nnscaler.graph.parser.register('N seqlen *, 1 1 * -> N seqlen *') +@nnscaler.register_op('N seqlen *, 1 1 * -> N seqlen *') def apply_mask(x: torch.Tensor, mask: torch.Tensor): return x if mask is None else x + mask @@ -184,12 +184,12 @@ def forward( keys = keys.transpose(1, 2) values = values.transpose(1, 2) scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) - + # NOTE: cube doesn't support dynamic graph # if mask is not None: # scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = apply_mask(scores, mask) - + scores = F.softmax(scores.float(), dim=-1).type_as(xq) output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) diff --git a/examples/llama/test_chat_completion.py b/examples/llama/test_chat_completion.py index 0d98e664..1d5dab32 100644 --- a/examples/llama/test_chat_completion.py +++ b/examples/llama/test_chat_completion.py @@ -18,9 +18,10 @@ from examples.llama.generation import Llama import nnscaler +from nnscaler.utils import set_logger_level nnscaler.init() -nnscaler.set_logger_level(level=logging.WARNING) +set_logger_level(level=logging.WARNING) logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) diff --git a/examples/megatron_gpt/convert.py b/examples/megatron_gpt/convert.py index 20c30c64..ca5f1dbc 100644 --- a/examples/megatron_gpt/convert.py +++ b/examples/megatron_gpt/convert.py @@ -3,8 +3,8 @@ model = build_model() # 2. register customized op -from nnscaler.graph.parser.register import register -register('* h, h -> * h')(GeLUFunction.apply) +from nnscaler import register_op +register_op('* h, h -> * h')(GeLUFunction.apply) # 3. build semantic model from nnscaler import SemanticModel diff --git a/examples/megatron_gpt/parallel.py b/examples/megatron_gpt/parallel.py index 091896da..7de6570d 100644 --- a/examples/megatron_gpt/parallel.py +++ b/examples/megatron_gpt/parallel.py @@ -8,8 +8,8 @@ # 2. register customized op from gpt_model import GeLUFunction -from nnscaler.graph.parser.register import register -register('* h, h -> * h')(GeLUFunction.apply) +from nnscaler import register_op +register_op('* h, h -> * h')(GeLUFunction.apply) # 3. parallel model from fairseq.nnscaler.pas_policies import PASData, PASRandomSPMD diff --git a/examples/mlp/train.py b/examples/mlp/train.py index e5a40fac..dfefff34 100644 --- a/examples/mlp/train.py +++ b/examples/mlp/train.py @@ -10,6 +10,7 @@ from functools import partial import nnscaler +from nnscaler.compiler import compile from nnscaler.profiler import CudaTimer from nnscaler.profiler.timer import print_each_rank from nnscaler.runtime.utils import microbatches @@ -63,7 +64,7 @@ def train(): dataloader = microbatches((dummy_data(),)) # compile a training iteration - @nnscaler.compile(model, dataloader, PAS=policy) + @compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): data = next(dataloader) loss = model(data) diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py index 86d95ceb..913363ec 100644 --- a/examples/nlp/blocks/attention.py +++ b/examples/nlp/blocks/attention.py @@ -2,8 +2,8 @@ import nnscaler -@nnscaler.graph.parser.register('L^ N E^, (h+ d^ 3) E^, (h+ d^ 3), E^ (h+ d^) -> L^ N E^', name='self_attention') -def self_attention(query: torch.Tensor, +@nnscaler.register_op('L^ N E^, (h+ d^ 3) E^, (h+ d^ 3), E^ (h+ d^) -> L^ N E^', name='self_attention') +def self_attention(query: torch.Tensor, qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, out_proj: torch.Tensor, h: int, scale: float, dropout_p: float, mask: bool = False): @@ -17,14 +17,14 @@ def self_attention(query: torch.Tensor, q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - + # ======== replace the semantic into more efficient implementation ============ # q = q.transpose(0, 1) # L (N h) d -> (N h) L d # k = k.transpose(0, 1) # L (N h) d -> (N h) L d # q = q * scale # (N h) L d, 1 -> (N h) L d # k = k.transpose(1, 2) # (N h) L d -> (N h) d L # attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - + # preallocating input tensor: (N h) L L matmul_input_buffer = torch.empty([N * h, L, L], dtype=query.dtype, device=query.device) # L (N h) d, L (N h) d -> (N h) L L @@ -56,7 +56,7 @@ def self_attention(query: torch.Tensor, return output -@nnscaler.graph.parser.register('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') +@nnscaler.register_op('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') def cross_attention(query: torch.Tensor, key: torch.Tensor, q_proj: torch.Tensor, q_bias: torch.Tensor, k_proj: torch.Tensor, k_bias: torch.Tensor, diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py index 95b52e09..4a5f3e7c 100644 --- a/examples/nlp/blocks/mlp.py +++ b/examples/nlp/blocks/mlp.py @@ -2,7 +2,7 @@ import nnscaler -@nnscaler.graph.parser.register('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') +@nnscaler.register_op('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj1_bias: torch.Tensor, proj2: torch.Tensor, diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py index 8478c7fd..1982e135 100644 --- a/examples/nlp/gpt/train.py +++ b/examples/nlp/gpt/train.py @@ -14,6 +14,8 @@ from model import GPT, Config, dummy_data import nnscaler +from nnscaler.compiler import compile +from nnscaler.utils import set_default_logger_level from nnscaler.profiler.timer import CudaTimer, print_each_rank from nnscaler.profiler.memory import memory_summary from nnscaler.runtime.utils import microbatches @@ -34,7 +36,7 @@ help='micro-batch size') parser.add_argument('--gbs', type=int, default=8, help='global batch size') -parser.add_argument('--dp', type=int, default=1, +parser.add_argument('--dp', type=int, default=1, help='data parallel size, only for megatron') parser.add_argument('--tp', type=int, default=1, help='tensor parallel size, only for megatron') @@ -52,13 +54,13 @@ nnscaler.init() -nnscaler.set_logger_level(logging.WARN) +set_default_logger_level(logging.WARN) logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) # get policy policy = get_policy([spmd, mpmd], args.policy) -policy = partial(policy, - nmicros=args.gbs//args.mbs, +policy = partial(policy, + nmicros=args.gbs//args.mbs, dp_size=args.dp, tp_size=args.tp ) @@ -72,7 +74,7 @@ def train(): heads=args.heads, ffn_hidden_dim=4*args.hidden, num_embeddings=51200, - seqlen=args.seqlen, + seqlen=args.seqlen, ) model = GPT(config) model = model if not args.fp16 else model.half() @@ -80,7 +82,7 @@ def train(): gen_data = partial(dummy_data, args.mbs, config) dataloader = microbatches((gen_data(),), cycle=True) - @nnscaler.compile(model, dataloader, PAS=policy) + @compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py index 946ed414..28ae9da9 100644 --- a/examples/nlp/mbart/train.py +++ b/examples/nlp/mbart/train.py @@ -17,6 +17,8 @@ from examples.nlp.mbart.model import dummy_data import nnscaler +from nnscaler.compiler import compile +from nnscaler.utils import set_default_logger_level, load_model from nnscaler.profiler.timer import CudaTimer, print_each_rank from nnscaler.profiler.memory import memory_summary from nnscaler.runtime.utils import microbatches @@ -29,7 +31,7 @@ parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') -parser.add_argument('--dp', type=int, default=1, +parser.add_argument('--dp', type=int, default=1, help='data parallel size, only for megatron') parser.add_argument('--tp', type=int, default=1, help='tensor parallel size, only for megatron') @@ -55,13 +57,13 @@ nnscaler.init() -nnscaler.set_logger_level(logging.WARN) +set_default_logger_level(logging.WARN) logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) # get policy policy = get_policy([gallery], args.policy) -policy = partial(policy, - nmicros=args.gbs//args.mbs, +policy = partial(policy, + nmicros=args.gbs//args.mbs, dp_size=args.dp, tp_size=args.tp ) @@ -105,12 +107,12 @@ def train(): gen_data = partial(dummy_data, batch_size, config) dataloader = microbatches((gen_data(),), cycle=True) - @nnscaler.compile(model, dataloader, PAS=policy) + @compile(model, dataloader, PAS=policy) def train_iter(model, dataloader): input_ids, decoder_input_ids = next(dataloader) loss = model(input_ids, decoder_input_ids) loss.backward() - model = nnscaler.load_model() + model = load_model() optimizer = torch.optim.Adam( model.parameters(), lr=3e-05, betas=(0.9, 0.98)) diff --git a/examples/openfold/blocks/attention.py b/examples/openfold/blocks/attention.py index 079ef4ea..5913f791 100644 --- a/examples/openfold/blocks/attention.py +++ b/examples/openfold/blocks/attention.py @@ -7,7 +7,7 @@ import torch.utils.checkpoint as ckpt -@nnscaler.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S R^ M^', name='msa_attn') +@nnscaler.register_op('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S R^ M^', name='msa_attn') @torch.jit.ignore def msa_attn(x: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, @@ -64,7 +64,7 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return out -@nnscaler.graph.parser.register('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, N 1 head+ R^ R^ -> N S R^ M^', name='msa_attn_bias') +@nnscaler.register_op('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, N 1 head+ R^ R^ -> N S R^ M^', name='msa_attn_bias') @torch.jit.ignore def msa_attn_bias(x: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, @@ -132,7 +132,7 @@ def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # note: code not reused constrained by cube's interface -@nnscaler.graph.parser.register('N S R^ M^, N R^ R^ Z^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, Z^ head+ -> N S R^ M^', name='row_attn') +@nnscaler.register_op('N S R^ M^, N R^ R^ Z^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, Z^ head+ -> N S R^ M^', name='row_attn') def row_attn(msa_repr: torch.Tensor, pair_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, @@ -147,7 +147,7 @@ def row_attn(msa_repr: torch.Tensor, pair_repr: torch.Tensor, head, c, scale, chunk_size, is_train) -@nnscaler.graph.parser.register('N S^ R M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S^ R M^', name='col_attn') +@nnscaler.register_op('N S^ R M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S^ R M^', name='col_attn') def col_attn(msa_repr: torch.Tensor, gate_proj: torch.Tensor, qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, c: int, scale: float, chunk_size: int, is_train: bool): @@ -160,7 +160,7 @@ def col_attn(msa_repr: torch.Tensor, gate_proj: torch.Tensor, return out -# @nnscaler.graph.parser.register('N S^ R^ M^, M^ (head+ dim^), M^ E^, M^ E^, M^ (head+ dim^), (head+ dim) M^ -> N S^ R^ M^', name='MSAColGlobalAttention') +# @nnscaler.register_op('N S^ R^ M^, M^ (head+ dim^), M^ E^, M^ E^, M^ (head+ dim^), (head+ dim) M^ -> N S^ R^ M^', name='MSAColGlobalAttention') def global_attn(msa_repr: torch.Tensor, q_proj: torch.Tensor, k_proj: torch.Tensor, v_proj: torch.Tensor, gate_proj: torch.Tensor, out_proj: torch.Tensor, @@ -197,7 +197,7 @@ def global_attn(msa_repr: torch.Tensor, q_proj: torch.Tensor, return torch.matmul(o, out_proj).transpose(-2, -3) -@nnscaler.graph.parser.register('N S R M^, M^ E+, E+ M^ -> N S R M^', name='feedforward') +@nnscaler.register_op('N S R M^, M^ E+, E+ M^ -> N S R M^', name='feedforward') @torch.jit.ignore def feedforward(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): """ @@ -211,7 +211,7 @@ def feedforward(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor return x -@nnscaler.graph.parser.register('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N R^ R^ head+ -> N S R^ Z^', name='tri_attn_start') +@nnscaler.register_op('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N R^ R^ head+ -> N S R^ Z^', name='tri_attn_start') def tri_attn_start(pair_repr: torch.Tensor, gate: torch.Tensor, qkv: torch.Tensor, out: torch.Tensor, bias: torch.Tensor, @@ -224,7 +224,7 @@ def tri_attn_start(pair_repr: torch.Tensor, return out -@nnscaler.graph.parser.register('N S^ R Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N S^ S^ head+ -> N S^ R Z^', name='tri_attn_end') +@nnscaler.register_op('N S^ R Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N S^ S^ head+ -> N S^ R Z^', name='tri_attn_end') def tri_attn_end(pair_repr: torch.Tensor, gate: torch.Tensor, qkv: torch.Tensor, out: torch.Tensor, bias: torch.Tensor, @@ -286,7 +286,7 @@ def forward(self, msa_repr: torch.Tensor) -> torch.Tensor: msa_repr: [N S R M] """ out = col_attn( - msa_repr, self.gate, self.qkv, self.out, + msa_repr, self.gate, self.qkv, self.out, self.heads, self.dhead, self.scale,self.chunk_size, self.training ) return out diff --git a/examples/openfold/blocks/embedder.py b/examples/openfold/blocks/embedder.py index 568c6621..0a5f7080 100644 --- a/examples/openfold/blocks/embedder.py +++ b/examples/openfold/blocks/embedder.py @@ -7,12 +7,12 @@ -@nnscaler.graph.parser.register('N res, cz nobins, cz -> N res res cz', name='relpos') +@nnscaler.register_op('N res, cz nobins, cz -> N res res cz', name='relpos') def input_embedder_pair_emb(ri: torch.Tensor, tf_emb_i: torch.Tensor, tf_emb_j: torch.Tensor, w_relpos: torch.Tensor, b_relpos: torch.Tensor, relpos_k) -> torch.Tensor: - + ri = ri.type(tf_emb_i.dtype) d = ri[..., None] - ri[..., None, :] boundaries = torch.arange( @@ -25,14 +25,14 @@ def input_embedder_pair_emb(ri: torch.Tensor, d = nn.functional.one_hot(d, num_classes=len(boundaries)).float() d = d.to(ri.dtype) pair_emb = torch.nn.functional.linear(d, w_relpos, b_relpos) - + pair_emb = pair_emb + tf_emb_i[..., None, :] pair_emb = pair_emb + tf_emb_j[..., None, :, :] return pair_emb -@nnscaler.graph.parser.register('N res tfdim^, cm tfdim^, cm -> N nclust^, res, cm') +@nnscaler.register_op('N res tfdim^, cm tfdim^, cm -> N nclust^, res, cm') def input_embedder_tf_m(tf: torch.Tensor, w_tf_m: torch.Tensor, b_tf_m: torch.Tensor, nclust: int) -> torch.Tensor: tf_m = torch.nn.linear(tf, w_tf_m, b_tf_m) tf_m = tf_m.unsqueeze(-3).expand(((-1,) * len(tf.shape[:-2]) + (nclust, -1, -1))) @@ -104,7 +104,7 @@ def forward(self, tf: torch.Tensor, ri: torch.Tensor, msa: torch.Tensor) -> Tupl # [*, N_res, N_res, c_z] pair_emb = input_embedder_pair_emb( - ri, tf_emb_i, tf_emb_j, + ri, tf_emb_i, tf_emb_j, self.w_linear_relpos, self.b_linear_relpos ) # pair_emb = relpos(ri.type(tf_emb_i.dtype)) @@ -119,7 +119,7 @@ def forward(self, tf: torch.Tensor, ri: torch.Tensor, msa: torch.Tensor) -> Tupl -@nnscaler.graph.parser.register() +@nnscaler.register_op() def sum_d(x: torch.Tensor, bins: torch.Tensor, inf: float) -> torch.Tensor: squared_bins = bins ** 2 upper = torch.cat( diff --git a/examples/openfold/blocks/evoformer.py b/examples/openfold/blocks/evoformer.py index 7ce24954..c3b63977 100644 --- a/examples/openfold/blocks/evoformer.py +++ b/examples/openfold/blocks/evoformer.py @@ -9,14 +9,14 @@ import nnscaler -# @nnscaler.graph.parser.register('N S^ R^ cm^, N R^ R^ cz^ -> N out^') +# @nnscaler.register_op('N S^ R^ cm^, N R^ R^ cz^ -> N out^') # @torch.jit.ignore # def input_packing(msa: torch.Tensor, pair: torch.Tensor, out: int) -> torch.Tensor: # buffer = torch.cat((torch.flatten(msa, start_dim=1), torch.flatten(pair, start_dim=1))) # return buffer -# -# -# @nnscaler.graph.parser.register('N out^ -> N S^ R^ cm^, N R^ R^ cz^', name='input_unflatten') +# +# +# @nnscaler.register_op('N out^ -> N S^ R^ cm^, N R^ R^ cz^', name='input_unflatten') # @torch.jit.ignore # def input_unpacking(buffer: torch.Tensor, # S: int, R: int, cm: int, cz: int) -> Tuple[torch.Tensor, torch.Tensor]: @@ -59,11 +59,11 @@ def __init__(self, s: int, r: int, cm: int, cz: int, # MSA column-wise gated self-attention self.col_norm = torch.nn.LayerNorm(cm) self.col_attn = MSAColAttention(cm, msa_head, self.scale, self.msa_col_chunk) - + # MSA transition self.msa_transition_norm = torch.nn.LayerNorm(cm) self.msa_transition = Transition(cm, ff_mult) - + # Outer product mean self.outer_norm = torch.nn.LayerNorm(cm) self.outer_prod_mean = OuterProducterMean(cm, c, cz, self.opm_chunk) @@ -79,7 +79,7 @@ def __init__(self, s: int, r: int, cm: int, cz: int, # Triangular gated self-attention around ending node self.tri_attn_node_end = TriangleAttentionNodeEnd(cz, pair_head, c, self.scale, self.tane_chunk) - + # Transition in the pair stack self.pair_transition_norm = torch.nn.LayerNorm(cz) self.pair_transition = Transition(cz, ff_mult) @@ -145,20 +145,20 @@ def tflops(self, n_seq: int, n_res: int) -> float: flops += 4 * (msa_size * 4) # pair layer norm flops += 2 * (pair_size * 4) - + # attention: gate + qkv + q@k (N S head r c, N S head c r) + k@v + dense msa_attn = n_seq * n_res * self.cm * self.cm + \ 3 * n_seq * n_res * self.cm * self.cm + \ n_seq * (self.cm // self.c) * n_res * n_res * self.c + \ n_seq * (self.cm // self.c) * n_res * n_res * self.c + \ n_seq * n_res * self.cm * self.cm - + pair_attn = n_res * n_res * self.cz * self.cz + \ 3 * n_res * n_res * self.cz * self.cz + \ n_res * (self.cz // self.c) * n_res * n_res * self.c + \ n_res * (self.cz // self.c) * n_res * n_res * self.c + \ n_res * n_res * self.cz * self.cz - + # row and col end attention flops += 2 * msa_attn # tirangle start and triangle end diff --git a/examples/openfold/blocks/opm.py b/examples/openfold/blocks/opm.py index 005841e3..8d4e8271 100644 --- a/examples/openfold/blocks/opm.py +++ b/examples/openfold/blocks/opm.py @@ -7,7 +7,7 @@ import torch.utils.checkpoint as ckpt -# @nnscaler.graph.parser.register('N S+ R^ M^, M^ c^, M^ c^, (c^ c^) cz^ -> N R^ R^ cz^', name='outer_prod_mean') +# @nnscaler.register_op('N S+ R^ M^, M^ c^, M^ c^, (c^ c^) cz^ -> N R^ R^ cz^', name='outer_prod_mean') @torch.jit.ignore def outer_prod_mean(msa_repr: torch.Tensor, left_proj: torch.Tensor, right_proj: torch.Tensor, out_proj: torch.Tensor, chunk_size: int, training: bool): @@ -49,13 +49,13 @@ def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): return outer -@nnscaler.graph.parser.register('N S R M+, M+ C -> N S R C', name='opm_projection') +@nnscaler.register_op('N S R M+, M+ C -> N S R C', name='opm_projection') def opm_projection(msa_repr: torch.Tensor, proj1: torch.Tensor): x = torch.matmul(msa_repr, proj1) return x -@nnscaler.graph.parser.register('N S^ R C^, N S^ T^ C^, F^ Z^ -> N R T^ Z^') +@nnscaler.register_op('N S^ R C^, N S^ T^ C^, F^ Z^ -> N R T^ Z^') @torch.jit.ignore def opm(left: torch.Tensor, right: torch.Tensor, out_proj: torch.Tensor, chunk_size: int, training: bool): diff --git a/examples/openfold/blocks/tmu.py b/examples/openfold/blocks/tmu.py index ea83c3c4..32d0550e 100644 --- a/examples/openfold/blocks/tmu.py +++ b/examples/openfold/blocks/tmu.py @@ -3,20 +3,20 @@ from examples.openfold.blocks.utils import multi2ref -# @nnscaler.graph.parser.register('N S R Z^, Z^ E, Z^ E -> N S R E') +# @nnscaler.register_op('N S R Z^, Z^ E, Z^ E -> N S R E') # def tmu_projection(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): # x = torch.matmul(pair_repr, proj1) # x = torch.sigmoid(x) # x = x * torch.matmul(pair_repr, proj2) -# -# -# @nnscaler.graph.parser.register('N S R Z+, Z+ E-> N S R E') +# +# +# @nnscaler.register_op('N S R Z+, Z+ E-> N S R E') # def tmu_gate(pair_repr: torch.Tensor, proj: torch.Tensor): # return torch.sigmoid(torch.matmul(pair_repr, proj)) -@nnscaler.graph.parser.register('N S R Z^, Z^ E^, Z^ E^, Z^ E, Z^ E^, Z^ Z^ -> N S R E, N S R E^, N S R Z^', name='tmu_projection') -def tmu_projection(pair_repr: torch.Tensor, +@nnscaler.register_op('N S R Z^, Z^ E^, Z^ E^, Z^ E, Z^ E^, Z^ Z^ -> N S R E, N S R E^, N S R Z^', name='tmu_projection') +def tmu_projection(pair_repr: torch.Tensor, left1: torch.Tensor, left2: torch.Tensor, right1: torch.Tensor, right2: torch.Tensor, gate: torch.Tensor): @@ -34,7 +34,7 @@ def tmu_projection(pair_repr: torch.Tensor, return left, right, gate -@nnscaler.graph.parser.register('N S R^ E, N T^ R^ E^, N S^ T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='tmo') +@nnscaler.register_op('N S R^ E, N T^ R^ E^, N S^ T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='tmo') def tmo(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, norm_w: torch.Tensor, norm_b: torch.Tensor, out: torch.Tensor): a = left.permute(0, 3, 1, 2) @@ -46,7 +46,7 @@ def tmo(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, return p -@nnscaler.graph.parser.register('N R^ S E, N R^ T^ E^, N T^ S^ Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='tmi') +@nnscaler.register_op('N R^ S E, N R^ T^ E^, N T^ S^ Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='tmi') def tmi(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, norm_w: torch.Tensor, norm_b: torch.Tensor, out: torch.Tensor): a = left.permute(0, 3, 2, 1) @@ -68,7 +68,7 @@ def __init__(self, cz: int, mult: int, outgoing: bool) -> None: self.left2 = torch.nn.Parameter(torch.empty(cz, mult)) self.right1 = torch.nn.Parameter(torch.empty(cz, mult)) self.right2 = torch.nn.Parameter(torch.empty(cz, mult)) - + # self.norm = torch.nn.LayerNorm(mult) self.normw = torch.nn.Parameter(torch.empty(mult)) self.normb = torch.nn.Parameter(torch.empty(mult)) @@ -76,7 +76,7 @@ def __init__(self, cz: int, mult: int, outgoing: bool) -> None: self.out = torch.nn.Parameter(torch.empty(mult, cz)) self.gate = torch.nn.Parameter(torch.empty(cz, cz)) self.outgoing = outgoing - + def forward(self, pair_repr: torch.Tensor): """ pair_repr: [N S R Z] @@ -85,7 +85,7 @@ def forward(self, pair_repr: torch.Tensor): pair_repr = self.layer_norm(pair_repr) left, right, gate = tmu_projection(pair_repr, - self.left1, self.left2, + self.left1, self.left2, self.right1, self.right2, self.gate ) diff --git a/examples/openfold/blocks/utils.py b/examples/openfold/blocks/utils.py index bf512f2e..7bb61f49 100644 --- a/examples/openfold/blocks/utils.py +++ b/examples/openfold/blocks/utils.py @@ -2,6 +2,6 @@ import torch -@nnscaler.graph.parser.register('* -> *, *', name='multi2ref') +@nnscaler.register_op('* -> *, *', name='multi2ref') def multi2ref(x: torch.Tensor): return (x, x) \ No newline at end of file diff --git a/examples/openfold/train.py b/examples/openfold/train.py index 11186ad5..61417954 100644 --- a/examples/openfold/train.py +++ b/examples/openfold/train.py @@ -9,6 +9,7 @@ from examples.openfold.model import AlphaFold, Config import nnscaler +from nnscaler.compiler import compile, SemanticModel from nnscaler.profiler.timer import CudaTimer, print_each_rank from nnscaler.profiler.memory import memory_summary from examples.openfold.policy.mpmd import PASDAP, PASRoundRobin, PASNF1B, PASDAPPipe @@ -75,7 +76,7 @@ def train(): dtype = torch.float16 if args.fp16 else torch.float32 dataloader = nnscaler.runtime.syndata.SynDataLoader( - shapes=([cfg.bs, cfg.evoformer_s, cfg.evoformer_r, cfg.evoformer_cm], + shapes=([cfg.bs, cfg.evoformer_s, cfg.evoformer_r, cfg.evoformer_cm], [cfg.bs, cfg.evoformer_r, cfg.evoformer_r, cfg.evoformer_cz]), dtypes=(dtype, dtype), batch_dims=(0, 0) @@ -83,8 +84,8 @@ def train(): print_each_rank(f'before partitioned model parameter: {nparams(model)}') - model = nnscaler.SemanticModel(model) - @nnscaler.compile(model, dataloader, PAS=PASDAPPipe, override=True, load_content=True) + model = SemanticModel(model) + @compile(model, dataloader, PAS=PASDAPPipe, override=True, load_content=True) def train_iter(model, dataloader): input_ids, position_ids = next(dataloader) loss = model(input_ids, position_ids) diff --git a/examples/vision/swin/blocks/attention.py b/examples/vision/swin/blocks/attention.py index b3ec28d8..07375626 100644 --- a/examples/vision/swin/blocks/attention.py +++ b/examples/vision/swin/blocks/attention.py @@ -3,11 +3,11 @@ import nnscaler -# REMARK: as default attention has qkv project weight of (3 head dim_head) C, +# REMARK: as default attention has qkv project weight of (3 head dim_head) C, # this cannot partition on head dimension # as the head dimension is a secondary hidden dimension in (3 head dim_head). # To make partition work (correctness guarantee), the dimension is swapped as (head dim_head 3) -@nnscaler.graph.parser.register('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), nw N^ N^ -> B N^ C^') +@nnscaler.register_op('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), nw N^ N^ -> B N^ C^') def window_attn(x: torch.Tensor, qkv_w: torch.Tensor, qkv_bias: torch.Tensor, relative_position_index: torch.Tensor, diff --git a/examples/vision/swin/blocks/mlp.py b/examples/vision/swin/blocks/mlp.py index d36d1456..b8663783 100644 --- a/examples/vision/swin/blocks/mlp.py +++ b/examples/vision/swin/blocks/mlp.py @@ -3,7 +3,7 @@ import nnscaler -@nnscaler.graph.parser.register('B HW^ E^, H+ E^, H+, E^ H+ -> B HW^ E^', name='feedforward') +@nnscaler.register_op('B HW^ E^, H+ E^, H+, E^ H+ -> B HW^ E^', name='feedforward') def feedforward(x: torch.Tensor, proj1: torch.Tensor, proj1_bias: torch.Tensor, proj2: torch.Tensor, dropout: float) -> torch.Tensor: diff --git a/examples/vision/swin/blocks/patch.py b/examples/vision/swin/blocks/patch.py index 77cf3b8d..3d8a124d 100644 --- a/examples/vision/swin/blocks/patch.py +++ b/examples/vision/swin/blocks/patch.py @@ -6,7 +6,7 @@ import nnscaler -@nnscaler.graph.parser.register('B (2 h^ 2 w^) C^ -> B (h w) (4 C)') +@nnscaler.register_op('B (2 h^ 2 w^) C^ -> B (h w) (4 C)') def patch_merge(x: torch.Tensor, h: int, w: int): B, L, C = x.shape H = 2 * h @@ -22,7 +22,7 @@ def patch_merge(x: torch.Tensor, h: int, w: int): x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C return x -@nnscaler.graph.parser.register('B ic+ (ps^ w^) (ps^ h^), oc ic+ k^ k^, oc -> B oc w^ h^') +@nnscaler.register_op('B ic+ (ps^ w^) (ps^ h^), oc ic+ k^ k^, oc -> B oc w^ h^') def patch(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, ps: int): """ @param ps int: patch size @@ -92,7 +92,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.conv_w = nn.Parameter(torch.empty(embed_dim, in_chans, self.patch_size, self.patch_size)) self.conv_b = nn.Parameter(torch.empty(embed_dim)) - + if norm_layer is not None: self.norm = norm_layer(embed_dim) else: diff --git a/examples/vision/swin/blocks/transformer.py b/examples/vision/swin/blocks/transformer.py index 559de247..bfd520b1 100644 --- a/examples/vision/swin/blocks/transformer.py +++ b/examples/vision/swin/blocks/transformer.py @@ -8,7 +8,7 @@ import nnscaler -@nnscaler.graph.parser.register('* -> *') +@nnscaler.register_op('* -> *') def drop_path(x: torch.Tensor, drop_prob: float, training: bool): if drop_prob <= 0. or not training: return x @@ -20,7 +20,7 @@ def drop_path(x: torch.Tensor, drop_prob: float, training: bool): return output -@nnscaler.graph.parser.register('B (nh ws) (nw ws) C -> (B nh nw) ws ws C') +@nnscaler.register_op('B (nh ws) (nw ws) C -> (B nh nw) ws ws C') def window_partition(x: torch.Tensor, ws: int): """ Args: @@ -36,7 +36,7 @@ def window_partition(x: torch.Tensor, ws: int): return windows -@nnscaler.graph.parser.register('(B nh nw) ws ws C -> B (nh ws) (nw ws) C') +@nnscaler.register_op('(B nh nw) ws ws C -> B (nh ws) (nw ws) C') def window_reverse(windows: torch.Tensor, ws: int, nh: int, nw: int): """ Args: diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 26f64542..ba109c59 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -13,6 +13,7 @@ from examples.vision.swin.model import Config, SwinTransformer, dummy_data import nnscaler +from nnscaler.compiler import compile from nnscaler.profiler.timer import CudaTimer, print_each_rank from nnscaler.profiler.memory import memory_summary from nnscaler.runtime.utils import microbatches @@ -26,7 +27,7 @@ parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') parser.add_argument('--fp16', action='store_true', default=False, help='use fp16 for the training') -parser.add_argument('--dp', type=int, default=1, +parser.add_argument('--dp', type=int, default=1, help='data parallel size, only for megatron') parser.add_argument('--tp', type=int, default=1, help='tensor parallel size, only for megatron') @@ -40,8 +41,8 @@ # get policy policy = get_policy([gallery], args.policy) -policy = partial(policy, - nmicros=args.gbs//args.mbs, +policy = partial(policy, + nmicros=args.gbs//args.mbs, dp_size=args.dp, tp_size=args.tp ) @@ -62,7 +63,7 @@ def train(): gen_data = partial(dummy_data, args.mbs, torch.float16, cfg) dataloader = microbatches((gen_data(),)) - @nnscaler.compile(model, dataloader, PAS=policy, load_content=load_content) + @compile(model, dataloader, PAS=policy, load_content=load_content) def train_iter(model, dataloader): imgs = next(dataloader) loss = model(imgs) diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index 133d1a29..cb9e60ef 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -1,45 +1,51 @@ -from typing import Optional -import logging -from nnscaler import runtime -from nnscaler import utils - -from nnscaler import profiler -from nnscaler.profiler.timer import CudaTimer - -from nnscaler.compiler import SemanticModel, compile - -from nnscaler.utils import load_model, load_default_schedule, load_eval_schedule -from nnscaler.utils import accum_mode +from .version import __version__ +from .parallel import ( + ParallelModule, + UserConfig, + ComputeConfig, + ReuseType, + BroadcastGenFilesStrategy, + parallelize, + build_optimizer, + merge_state_dicts, + load_merged_state_dicts, + deduped_state_dict, + load_deduped_state_dict, + broadcast_weights, +) +from nnscaler.graph.parser.register import register_op -from nnscaler.flags import CompileFlag -from .version import __version__ +def init(): + """ + Initialize the nnscaler library. + It will initialize torch distributed nccl process_group + and set the default cuda device according to the local rank of the process. -def _check_torch_version(): - import torch - torch_version = str(torch.__version__).split('+')[0] - torch_version = float('.'.join(torch_version.split('.')[:2])) - if torch_version < 1.12: - logging.warn(f"expected PyTorch version >= 1.12 but got {torch_version}") + It is recommended to call this function before any other nnscaler functions, + although it is optional if you initialize the torch distributed nccl process_group + and set the default cuda device by yourself. + Please note that you should intialize torch distributed process_group with a large timeout (6 hours, for example), + because the parallelization of modules may take a long time, + and the default timeout (30 minutes) may be too short. -def init(): + Returns: + None + """ + from nnscaler import runtime _ = runtime.device.DeviceGroup() _ = runtime.resource.EnvResource() -def set_logger_level(level): - """Set the logger level with predefined logging format. - - Args: - level (int): the level of the logger. - """ - logging.basicConfig( - level=level, - format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", - datefmt="%Y-%m-%d %H:%M:%S" - ) +def _check_torch_version(): + import torch + import logging + torch_version = str(torch.__version__).split('+')[0] + torch_version = tuple(int(v) for v in torch_version.split('.')[:2]) + if torch_version < (2, 0): + logging.warn(f"expected PyTorch version >= 2.0 but got {torch_version}") _check_torch_version() diff --git a/nnscaler/compiler.py b/nnscaler/compiler.py index 7ad96093..22d7dfb6 100644 --- a/nnscaler/compiler.py +++ b/nnscaler/compiler.py @@ -29,7 +29,7 @@ from nnscaler.program import Program, SemanticDataLoader, SemanticModel from nnscaler.flags import CompileFlag -from nnscaler.utils import print_each_rank +from nnscaler.utils import print_each_rank, load_default_schedule _logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ def train_iter(model, dataloader): args (Tuple[Any]): compile function example inputs PAS (Callable | Tuple[Callable, Callable, Callable]): policy to transform and schedule graph model_dynamic_shape (bool): whether to compile model with dynamic shape - load_graph_file (str | None): + load_graph_file (str | None): load cached graph. This will skip parsing the function and model. Note the user should keep correct `fullmodel.pt` if load_content is True. save_graph_file (str | None): save parsed graph before applying policy. @@ -99,7 +99,7 @@ def train_iter(model, dataloader): arg = SemanticDataLoader(arg) elif isinstance(arg, torch.Tensor): tensor = arg - arg = IRFullTensor(arg.shape, name='tensor', + arg = IRFullTensor(arg.shape, name='tensor', requires_grad=arg.requires_grad, dtype=arg.dtype).tosub() arg._value = tensor @@ -299,6 +299,6 @@ def decorator(fn: Callable) -> Callable: model.dummy_input = None # load temporal schedule print_each_rank(f'loading generated schedule from {filename} ...', logger=_logger) - return nnscaler.load_default_schedule(filename) + return load_default_schedule(filename) return decorator diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index b3ed80f0..d7c381c3 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -19,7 +19,7 @@ class CustomizedOps: # signature -> IRDimop creation function kOpMap: Dict[str, Callable] = {} - # singature -> runtime function + # singature -> runtime function kOpRuntime: Dict[str, Callable] = {} # signature -> runtime function implementation code kOpCodeDef: Dict[str, str] = {} @@ -27,7 +27,7 @@ class CustomizedOps: @staticmethod def map(signature: str) -> Callable: """Get IRDimop creation function by signature - + Args: signature (str): operator signature @@ -65,7 +65,7 @@ def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Call CustomizedOps.kOpCodeDef[signature] = code -def register(node_repr: Union[str, Callable], name: Optional[str] = None, +def register_op(annotation: Union[str, Callable], name: Optional[str] = None, code_impl_pattern: str = 'import') -> Callable: """ Register a function with IRDimops annotations. @@ -73,9 +73,9 @@ def register(node_repr: Union[str, Callable], name: Optional[str] = None, This function is cooperated with IRDimops. Users can only register functions defined under a module, instead of ones defined inside a function / class or __main__ scope. - The annotation (`node_repr`) specifies the number of inputs as *args, - and treat all the rest inputs as **kwargs. - + The annotation (`annotation`) specifies the number of inputs as *args, + and treat all the rest inputs as **kwargs. + For tensor-type inputs, the annotation should be a string of identifiers separated by space, e.g., `'a b'`; For non-tensor-type inputs, the annotation should be specified '?'. @@ -98,7 +98,7 @@ def register(node_repr: Union[str, Callable], name: Optional[str] = None, def func(x, b = 4): xxx ``` - + or, ```python @@ -112,7 +112,7 @@ def anno_fn(*inputs, **kwargs): ``` Args: - node_repr (str | Callable): operator annotation of IRDimops or callable function that generates IRFwOperation. + annotation (str | Callable): operator annotation of IRDimops or callable function that generates IRFwOperation. - op annotation: e.g., 'a (b c) -> (a b) c' - a callable function that generates op annotation (str). The function taks inputs and kwargs as arguments and returns the operator annotation. @@ -166,7 +166,7 @@ def get_source_code(fn: Callable) -> str: code = inspect.getsource(fn) code = code[code.index('def'):] return code - + def get_import_code(fn: Callable) -> str: import_path = get_import_path(fn) code = f'import {import_path}' @@ -180,11 +180,11 @@ def get_import_code(fn: Callable) -> str: raise ValueError(f'code_impl_pattern should be either "import" or "source", got {code_impl_pattern}') # step 3. define customized IRDimops creation function - if not (isinstance(node_repr, str) or callable(node_repr)): - raise TypeError(f"node_repr should be either str or callable, got {type(node_repr)}") + if not (isinstance(annotation, str) or callable(annotation)): + raise TypeError(f"annotation should be either str or callable, got {type(annotation)}") def udfop(*args, signature=None, **kwargs): - anno = node_repr if isinstance(node_repr, str) else node_repr(*args, **kwargs) + anno = annotation if isinstance(annotation, str) else annotation(*args, **kwargs) if not isinstance(anno, str): raise TypeError(f"node_repr should return a string, but got {type(anno)}: {anno}") anno = OpAnno(anno) @@ -212,3 +212,8 @@ def udfop(*args, signature=None, **kwargs): return fn return decorator + + +# [Deprecated] register_op alias +# Will remove in future. +register = register_op diff --git a/nnscaler/utils.py b/nnscaler/utils.py index fb1f47cf..e14819b7 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -152,6 +152,19 @@ def setup_stride_broadcast_group(stride_size: int) -> BroadcastGroup: ) +def set_default_logger_level(level): + """Set the logger level with predefined logging format. + + Args: + level (int): the level of the logger. + """ + logging.basicConfig( + level=level, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" + ) + + class accum_mode: """Make cube execution in gradient accumulation mode. diff --git a/tests/autodist/pas/test_shared_param_pipeline.py b/tests/autodist/pas/test_shared_param_pipeline.py index 370df6fb..2f784287 100644 --- a/tests/autodist/pas/test_shared_param_pipeline.py +++ b/tests/autodist/pas/test_shared_param_pipeline.py @@ -13,7 +13,7 @@ from nnscaler.graph.segment import IRSegment from nnscaler.flags import CompileFlag from nnscaler.runtime.utils import microbatches -from nnscaler.program import Program, SemanticDataLoader +from nnscaler.program import Program, SemanticDataLoader, SemanticModel from nnscaler.graph.gener.gen import IRAdapterGener @@ -51,7 +51,7 @@ def test_shared_param_pipeline(): 'x': torch.randn(bsz, hidden_dim) }])) - smodel = nnscaler.SemanticModel(model, attr_savedir=tempdir) + smodel = SemanticModel(model, attr_savedir=tempdir) smodel.dummy_input = {'x': torch.randn(bsz, hidden_dim)} smodel.dynamic_shape = False program.set_input([dataloader.irobj]) diff --git a/tests/autodist/spmd_solver/test_cube_operator.py b/tests/autodist/spmd_solver/test_cube_operator.py index c632b6b7..0ce7c21c 100644 --- a/tests/autodist/spmd_solver/test_cube_operator.py +++ b/tests/autodist/spmd_solver/test_cube_operator.py @@ -12,7 +12,7 @@ import nnscaler -@nnscaler.graph.parser.register( +@nnscaler.register_op( '(1 h) l^ d^, (1 h) l^ d^, (1 h) l^ d^ -> (1 h) l^ d^', 'mock_attention') def mock_attention(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return x + y + z diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index ee8007ec..f54f3f4b 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -6,6 +6,8 @@ import more_itertools as mitr import nnscaler +from nnscaler.utils import load_model +from nnscaler.compiler import compile from nnscaler.runtime.utils import microbatches from nnscaler.graph import IRGraph from nnscaler.graph.segment import IRSegment @@ -33,7 +35,7 @@ def forward(self, data): def get_dummy_data(batch_size: int = 512): torch.random.manual_seed(0) return torch.randn( - [128, 512], dtype=torch.float32, + [128, 512], dtype=torch.float32, device=torch.cuda.current_device()).repeat([batch_size // 128, 1]) @@ -43,7 +45,7 @@ def baseline(): init_parameter(model) model.cuda() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - + losses = [] for _ in range(3): x = get_dummy_data() @@ -65,7 +67,7 @@ def pipe_policy(graph: IRGraph, resource, ngpus_per_unit: int): ngpus = min(ngpus_per_unit, resource.ngpus) fnodes = graph.select(ntype=IRFwOperation) - + stages = mitr.divide(ngpus, fnodes) stages = [list(s) for s in stages] lead_nodes = [s[0] for s in stages] @@ -117,7 +119,7 @@ def cube_run(ngpus_per_unit: int, policy): model = MLP() init_parameter(model) - + ngpus_per_unit = min(ngpus_per_unit, torch.distributed.get_world_size()) nreplicas = torch.distributed.get_world_size() // ngpus_per_unit batch_size = 512 // nreplicas @@ -128,16 +130,16 @@ def cube_run(ngpus_per_unit: int, policy): policy = partial(policy, ngpus_per_unit=ngpus_per_unit) - @nnscaler.compile(model, dl, PAS=policy, scale=True) + @compile(model, dl, PAS=policy, scale=True) def train_iter(model, dataloader): x = next(iter(dataloader)) loss = model(x) loss.backward() return loss - - model = nnscaler.load_model() + + model = load_model() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - + losses = [] for _ in range(3): x = get_dummy_data(batch_size=batch_size) @@ -161,30 +163,30 @@ def train_iter(model, dataloader): # scale test test_scale2 = partial(torchrun, 2, assert_parity, - baseline, + baseline, partial(cube_run, 1, tp_policy) ) # tensor parallelism test test_tp2 = partial(torchrun, 2, assert_parity, - baseline, + baseline, partial(cube_run, 2, tp_policy) ) # tensor parallelism + scale test test_tp2scale2 = partial(torchrun, 4, assert_parity, - baseline, + baseline, partial(cube_run, 2, tp_policy) ) # pipeline parallelism test test_pipe2 = partial(torchrun, 2, assert_parity, - baseline, + baseline, partial(cube_run, 2, pipe_policy) ) # pipeline parallelism + scale test test_pipe2scale2 = partial(torchrun, 4, assert_parity, - baseline, + baseline, partial(cube_run, 2, pipe_policy) ) diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index ec21d090..f439082d 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -10,15 +10,15 @@ def mock_add(x: torch.Tensor, y: torch.Tensor): return x + y -nnscaler.graph.parser.register('*, * -> *')(mock_add) +nnscaler.register_op('*, * -> *')(mock_add) -@nnscaler.graph.parser.register('*, * -> *') +@nnscaler.register_op('*, * -> *') def mock_add2(x: torch.Tensor, y: torch.Tensor): return x + y -@nnscaler.graph.parser.register('(h w^) k^ -> h (w^ k^)') +@nnscaler.register_op('(h w^) k^ -> h (w^ k^)') def mock_view_with_obj(x, h): return x.view(h, -1) @@ -32,7 +32,7 @@ def forward(ctx, x: torch.Tensor, y: torch.Tensor): def backward(ctx, grad): return grad, grad -nnscaler.graph.parser.register('*, * -> *')(MockAGF.apply) +nnscaler.register_op('*, * -> *')(MockAGF.apply) class MockModel(torch.nn.Module): diff --git a/tests/graph/test_multiref.py b/tests/graph/test_multiref.py index b9f81d99..42aa0e35 100644 --- a/tests/graph/test_multiref.py +++ b/tests/graph/test_multiref.py @@ -6,6 +6,8 @@ from functools import partial import nnscaler +from nnscaler.compiler import compile +from nnscaler.utils import set_default_logger_level, load_model from nnscaler.graph import IRGraph from nnscaler.ir.operator import IRFwOperation from ..launch_torchrun import torchrun @@ -31,7 +33,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): y = y * self.param loss = torch.sum(y) return loss - + def get_dummy_data(batch_size: int = 256): torch.random.manual_seed(0) @@ -47,7 +49,7 @@ def baseline(): init_parameter(model) model.cuda() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) - + losses = [] for _ in range(3): x, y = get_dummy_data() @@ -67,19 +69,19 @@ def multiref(): x, y = get_dummy_data() def policy(graph: IRGraph, resource): - + first_mul = graph.select('mul')[0] first_add = graph.select('add')[0] sub_muls = graph.partition( - first_mul, first_mul.algorithms('dim'), + first_mul, first_mul.algorithms('dim'), idx=0, dim=0, num=resource.ngpus ) for idx, sub_node in enumerate(sub_muls): graph.assign(sub_node, idx) sub_adds = graph.partition( - first_add, first_add.algorithms('dim'), + first_add, first_add.algorithms('dim'), idx=0, dim=0, num=resource.ngpus ) for idx, sub_node in enumerate(sub_adds): @@ -91,16 +93,16 @@ def policy(graph: IRGraph, resource): for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) return graph - + x, y = get_dummy_data() - @nnscaler.compile(model, x, y, PAS=policy) + @compile(model, x, y, PAS=policy) def train_iter(model, x, y): loss = model(x, y) loss.backward() return loss - model = nnscaler.load_model() + model = load_model() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) losses = [] @@ -116,7 +118,7 @@ def train_iter(model, x, y): def multiref_test(): nnscaler.init() - nnscaler.set_logger_level(logging.INFO) + set_default_logger_level(logging.INFO) assert_parity(baseline, multiref) diff --git a/tests/runtime/test_gnorm.py b/tests/runtime/test_gnorm.py index 624163b2..c1ef5eda 100644 --- a/tests/runtime/test_gnorm.py +++ b/tests/runtime/test_gnorm.py @@ -2,13 +2,15 @@ This test is to verify the correctness of the gradient norm algorithm for nnscaler. To avoid other potential parity issues that may have influence the gradient value, -we use weight data as gradient, and calculate its norm to verify the correctness +we use weight data as gradient, and calculate its norm to verify the correctness of gnorm calculation. """ import torch from functools import partial import nnscaler +from nnscaler.compiler import compile +from nnscaler.utils import load_model from nnscaler.ir.operator import IRFwOperation from nnscaler.runtime.module import CubeModule from nnscaler.runtime.gnorm import prepare_for_grad_clip, clip_gnorm @@ -86,14 +88,14 @@ def model_test(policy, su_num: int = 1, use_zero: bool = False): wnorm_baseline = cal_wnorm_baseline(model) sample = torch.randn(16, 16).cuda() - @nnscaler.compile(model, sample, PAS=partial(policy, su_num=su_num), + @compile(model, sample, PAS=partial(policy, su_num=su_num), scale=su_num > 1) def train_iter(model, data): loss = model(data) loss.backward() return loss - - model = nnscaler.load_model() + + model = load_model() # train_iter(model, sample) # link .grad to reducer buffer wnorm_cube = cal_wnorm_cube(model) diff --git a/tests/runtime/test_grad_accum.py b/tests/runtime/test_grad_accum.py index 7191e015..3c0b5909 100644 --- a/tests/runtime/test_grad_accum.py +++ b/tests/runtime/test_grad_accum.py @@ -3,6 +3,7 @@ from functools import partial import nnscaler +from nnscaler.utils import accum_mode from nnscaler.runtime.module import CubeModule from ..launch_torchrun import torchrun from ..utils import init_parameter, assert_parity @@ -17,7 +18,7 @@ def __init__(self, ngpus, async_op, dim=512, nlayers=4,): self.layers = torch.nn.ModuleList([]) for _ in range(nlayers): self.layers.append(torch.nn.Linear(dim, dim, bias=False)) - + self.wreducer1 = nnscaler.runtime.adapter.Reducer(ranks=ranks, reduce_op='sum', async_op=async_op, zero=False, max_bucket_size_bytes=137217728, zero_ngroups=1) for param in self.parameters(): @@ -30,7 +31,7 @@ def forward(self, data): x = layer(x) loss = torch.sum(x) return loss - + class BaseMLP(torch.nn.Module): def __init__(self, dim=512, nlayers=4,): @@ -50,7 +51,7 @@ def forward(self, data): def get_dummy_data(batch_size: int = 256): torch.random.manual_seed(0) return torch.randn( - [batch_size, 512], dtype=torch.float32, + [batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()) @@ -59,7 +60,7 @@ def baseline(accum_times: int = 4): init_parameter(model) model = model.cuda() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - + losses = [] for _ in range(3): for _ in range(accum_times): @@ -84,7 +85,7 @@ def reducer_sync_test(accum_times: int = 4): for reducer in model.reducers: reducer.build_buckets() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - + losses = [] for _ in range(3): model.zero_grad() @@ -115,7 +116,7 @@ def reducer_async_test_wrong(accum_times: int = 4): for reducer in model.reducers: reducer.build_buckets() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - + losses = [] for _ in range(3): model.zero_grad() @@ -146,12 +147,12 @@ def reducer_async_test_correct(accum_times: int = 4): for reducer in model.reducers: reducer.build_buckets() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - + losses = [] for _ in range(3): model.zero_grad() for step in range(accum_times): - with nnscaler.accum_mode(begin=(step == 0), end=(step == accum_times - 1)): + with accum_mode(begin=(step == 0), end=(step == accum_times - 1)): x = get_dummy_data() x = x.chunk(ngpus, dim=0)[rank] loss = model(x) @@ -167,7 +168,7 @@ def reducer_async_test_correct(accum_times: int = 4): loss /= 10.0 losses.append(loss) return losses - + def accum_test(): nnscaler.init() diff --git a/tests/runtime/test_module_merge.py b/tests/runtime/test_module_merge.py index dc255e96..e66e89c8 100644 --- a/tests/runtime/test_module_merge.py +++ b/tests/runtime/test_module_merge.py @@ -7,6 +7,8 @@ from nnscaler.ir.operator import IRFwOperation from nnscaler.runtime.device import DeviceGroup +from nnscaler.compiler import compile +from nnscaler.utils import load_model from ..launch_torchrun import torchrun @@ -74,12 +76,12 @@ def merge_model_states_test(): full_model_state = model.state_dict() - @nnscaler.compile(model, sample, PAS=tp_policy) + @compile(model, sample, PAS=tp_policy) def train_iter(model, sample): loss = model(sample) loss.backward() return loss - cube_model = nnscaler.load_model() + cube_model = load_model() state_dict = cube_model.state_dict() torch.save({'state_dict': state_dict, 'fullmap': cube_model.fullmap}, @@ -110,13 +112,13 @@ def merge_optimizer_states_test(): full_model_state = model.state_dict() full_optim_state = full_optimizer.state_dict() - @nnscaler.compile(model, sample, PAS=tp_policy) + @compile(model, sample, PAS=tp_policy) def train_iter(model, sample): loss = model(sample) loss.backward() return loss - cube_model = nnscaler.load_model() + cube_model = load_model() optimizer = torch.optim.Adam(cube_model.parameters(), lr=0.01) # test for initial state diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py index 2997ebea..6efe6e1b 100644 --- a/tests/runtime/test_reducer.py +++ b/tests/runtime/test_reducer.py @@ -6,6 +6,8 @@ from functools import partial import nnscaler +from nnscaler.compiler import compile +from nnscaler.utils import load_model from nnscaler.graph import IRGraph from nnscaler.ir.operator import IRFwOperation from nnscaler.flags import CompileFlag @@ -31,7 +33,7 @@ def forward(self, data): def get_dummy_data(batch_size: int = 256): torch.random.manual_seed(0) return torch.randn( - [batch_size, 512], dtype=torch.float32, + [batch_size, 512], dtype=torch.float32, device=torch.cuda.current_device()) @@ -41,7 +43,7 @@ def baseline(): init_parameter(model) model.cuda() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) - + losses = [] for _ in range(3): x = get_dummy_data() @@ -64,7 +66,7 @@ def reducer(use_zero: bool, async_reducer: bool): model = MLP() init_parameter(model) - + def policy(graph: IRGraph, resource): def tensor_parallelism(node, idx, dim, num): @@ -88,18 +90,18 @@ def tensor_parallelism(node, idx, dim, num): for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) return graph - + x = get_dummy_data() - @nnscaler.compile(model, x, PAS=policy) + @compile(model, x, PAS=policy) def train_iter(model, x): loss = model(x) loss.backward() return loss - - model = nnscaler.load_model() + + model = load_model() optimizer = torch.optim.Adam(model.parameters_for_optimizer(), lr=0.01) - + losses = [] for _ in range(3): x = get_dummy_data() diff --git a/tutorial.md b/tutorial.md index db88c36f..13f59621 100644 --- a/tutorial.md +++ b/tutorial.md @@ -1,6 +1,6 @@ # Dimop Tutorial -## Dimop: Dimension-annotated Operator +## Dimop: Dimension-annotated Operator ### Annotation for Shape Inference and Transformation @@ -16,11 +16,11 @@ def operator(x: torch.Tensor, w: torch.Tensor, h: float) -> torch.Tensor: return out ``` -To separate inputs and outputs of an operator, `'->'` is a separation keyword where its left part are inputs and right part are outputs. Inside inputs and outputs region, annotation of each tensor is further separated by `','`. +To separate inputs and outputs of an operator, `'->'` is a separation keyword where its left part are inputs and right part are outputs. Inside inputs and outputs region, annotation of each tensor is further separated by `','`. Every dimension of a tensor is annotated by a template of **{identifiers}{reduction}**, like `'m^ kd+'`, `'kd+ n'`, `'m^ n'`, where `m`, `kd` and `n` are identitifiers, `'^'` and `'+'` are reductions. -If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimensions, the first dimension is `m` and the second dimension is `kd`. Dimensions need to be separated by space `' '`. +If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimensions, the first dimension is `m` and the second dimension is `kd`. Dimensions need to be separated by space `' '`. * Identifiers @@ -35,7 +35,7 @@ If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimens Special identifier: 1) `'*'`: this special identifier indicates the dimension is dynamic, which will automatically get expanded given the shape. If there are multiple `*` for different tensors, then they must have same shape for the expanded dimensions, - + e.g., `'* t -> a * t'` can be expanded into `'b c t -> a b c t'` 2) `'?'`: this special identifier indicates the value is not a tensor, which will be ignored @@ -45,7 +45,7 @@ If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimens * Reductions Reductions are served for transformation plans. The reduction can be one of {`''`, `'+'`, `'^'`}: - + * `''` (empty) indicates this dimension can be spatially partitioned, and each output that have this identifier will also be spatially partitioned. * `'+'` indicates this dimension can be spatially partitioned. And each output that doesn't have this identifier will be numerically partitioned (sum-reduction required). @@ -73,7 +73,7 @@ If a tensor is represented as `'m^ kd+'`, it indicates the tensor has two dimens To register a customized "matmul" operator in the runtime, user can simply define a python function and add an decorator on the function with its annotations: ```py -@nnscaler.graph.parser.register('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom') +@nnscaler.register_op('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom') def operator(x: torch.Tensor, w: torch.Tensor, h: float) -> torch.Tensor: out = torch.matmul(x, w) out = out.view(h, out.size(0) // h, out.size(1)) From a94c2512cb54fa1e3c54ec2e3fef1e0e088cfbc6 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 8 May 2024 05:46:43 +0000 Subject: [PATCH 1632/1892] Merged PR 2133: Add inline timer According to our experience, it is helpful to get the time of each line in the gencode to - find the hang point - get the elapsed time cost --- nnscaler/codegen/emit.py | 11 ++++- nnscaler/codegen/schedule/schedule.py | 7 ++- nnscaler/flags.py | 1 + nnscaler/runtime/function/function.py | 13 +++++- tests/parallel_module/test_line_timer.py | 57 ++++++++++++++++++++++++ tests/utils.py | 11 +++++ 6 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 tests/parallel_module/test_line_timer.py diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index b278de0f..500b28e4 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -174,6 +174,8 @@ def emit_fnode(self, node: IRFwOperation, prefix_attr: str = None) -> List[str]: # insert comment if node.comment is not None: codes.append(f'# {node.comment}') + if CompileFlag.line_timer: + codes.append(f'nnscaler.runtime.function.print_time({repr(node.comment or node.signature)})') signature = node.signature # setup arg string @@ -242,6 +244,8 @@ def emit_adapter(self, node: IRAdapter, prefix_attr: Optional[str] = None, prim_kwargs['async_op'] = True kwargs = self.kwargs_name(**prim_kwargs) outputs = self.return_name(prim.outputs()) + if CompileFlag.line_timer: + codes.append(f'nnscaler.runtime.function.print_time({repr(prim.signature)})') code = f'{outputs} = {prim.signature}({itensors}, {kwargs})' codes.append(code) return codes @@ -254,8 +258,11 @@ def emit_reducer(self, node: IRWeightReducer) -> List[str]: - NONE """ reducer_name = f'self.wreducer{node._id}' - code = f'{reducer_name}.sync_grads()' - return [code] + codes = [] + if CompileFlag.line_timer: + codes.append(f'nnscaler.runtime.function.print_time({repr(reducer_name)})') + codes.append(f'{reducer_name}.sync_grads()') + return codes def emit_release(self, tensors: Iterable[IRTensor]) -> str: tnames : Generator = (self.tensor_name(t) for t in tensors) diff --git a/nnscaler/codegen/schedule/schedule.py b/nnscaler/codegen/schedule/schedule.py index a81ab30a..e46d6d86 100644 --- a/nnscaler/codegen/schedule/schedule.py +++ b/nnscaler/codegen/schedule/schedule.py @@ -15,6 +15,7 @@ from nnscaler.codegen.syntax.symtable import SymbolTable from nnscaler.codegen.lifecycle import LifeCycle from nnscaler.codegen.syntax.blocks import FunctionBlock +from nnscaler.flags import CompileFlag _logger = logging.getLogger(__name__) @@ -200,4 +201,8 @@ def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: else: raise RuntimeError(f"Unspported node type: {type(unwrap_node)}") - return [code] + if CompileFlag.line_timer: + type_str = type(unwrap_node).__name__ + return [f'nnscaler.runtime.function.print_time({repr(type_str)})', code] + else: + return [code] diff --git a/nnscaler/flags.py b/nnscaler/flags.py index d93bcbdf..2e843e6a 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -33,6 +33,7 @@ class CompileFlag: # ============== runtime ==================== dev_mode = _to_bool('SINGLE_DEV_MODE') # allow to use python xx.py async_comm = _to_bool('ASYNC_COMM') + line_timer = _to_bool('LINE_TIMER') # ============== reducer ================== # use zero optimization on optimizer status. diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index b2431732..b0f7867b 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -2,6 +2,8 @@ import torch import torch.nn.functional as TorchF import operator +import datetime +from nnscaler.flags import CompileFlag def identity(tensor: torch.Tensor) -> torch.Tensor: @@ -218,4 +220,13 @@ def nndropout(input: torch.Tensor, p=0.5, inplace=False): def setitem(__a, __b, __c): operator.setitem(__a, __b, __c) - return __a \ No newline at end of file + return __a + + +def print_time(content: str): + if not CompileFlag.line_timer: + return + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1 + if torch.cuda.is_available(): + torch.cuda.synchronize() + print(f"line timer: {rank} - {datetime.datetime.now()} - {content}") \ No newline at end of file diff --git a/tests/parallel_module/test_line_timer.py b/tests/parallel_module/test_line_timer.py new file mode 100644 index 00000000..668f2fb6 --- /dev/null +++ b/tests/parallel_module/test_line_timer.py @@ -0,0 +1,57 @@ +from pathlib import Path +import tempfile +import torch + +import pytest +import torch.distributed + +from nnscaler.parallel import parallelize, ComputeConfig +from nnscaler.flags import CompileFlag + +from .common import PASRandomSPMD, init_distributed, clear_dir_on_rank0 +from ..launch_torchrun import launch_torchrun +from ..utils import catch_stdout + + +class Net(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(64, 64) + + # x with shape [128, 64] + def forward(self, x): + return self.fc(x) + + +def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_shape, init_module_params=True): + return parallelize( + module, + {'x': torch.randn(input_shape)}, + PASRandomSPMD, + compute_config, + cube_savedir=cube_savedir, + instance_name=instance_name, + init_module_params=init_module_params + ) + + +def _gpu_worker(): + init_distributed() + compute_config = ComputeConfig(1, 1, use_zero=False) + try: + CompileFlag.line_timer = True + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_line_timer') as tempdir: + net = _to_cube_model(Net(), compute_config, tempdir, 'net', (128, 64)) + x = torch.randn(128, 64).cuda() + + with catch_stdout() as log_stream: + net(x) + logs = log_stream.getvalue() + assert 'line timer: 0' in logs + finally: + CompileFlag.line_timer = False + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_line_timer(): + launch_torchrun(1, _gpu_worker) diff --git a/tests/utils.py b/tests/utils.py index 2170ecec..0b9915dc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -331,3 +331,14 @@ def catch_log(_logger, loglevel='DEBUG'): yield string_stream _logger.removeHandler(handler) _logger.setLevel(old) + + +@contextmanager +def catch_stdout(): + import sys + from io import StringIO + old = sys.stdout + string_stream = StringIO() + sys.stdout = string_stream + yield string_stream + sys.stdout = old From 7bbf3e2899f1a25b27675707c7c1ed42a7d9d71f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 10 May 2024 02:54:18 +0000 Subject: [PATCH 1633/1892] Merged PR 2137: Add buitin policies Integrate autodist to parallelize api. --- .gitignore | 2 +- README.md | 2 +- docs/parallel_module.md | 41 ++- nnscaler/autodist/apis.py | 11 +- nnscaler/autodist/autodist_config.py | 10 +- nnscaler/parallel.py | 64 ++-- nnscaler/policies.py | 299 ++++++++++++++++++ tests/parallel_module/common.py | 148 +-------- tests/parallel_module/test_broadcast.py | 4 +- tests/parallel_module/test_checkpoint.py | 12 +- .../parallel_module/test_checkpoint_buffer.py | 4 +- .../parallel_module/test_checkpoint_dedup.py | 6 +- .../parallel_module/test_checkpoint_shared.py | 4 +- .../parallel_module/test_checkpoint_unused.py | 4 +- tests/parallel_module/test_ddp.py | 6 +- tests/parallel_module/test_end2end.py | 10 +- tests/parallel_module/test_gencode.py | 40 +-- tests/parallel_module/test_inference.py | 4 +- tests/parallel_module/test_init.py | 6 +- tests/parallel_module/test_line_timer.py | 4 +- tests/parallel_module/test_nested.py | 6 +- tests/parallel_module/test_override.py | 3 +- tests/parallel_module/test_reducer_hook.py | 12 +- tests/parallel_module/test_scale_grads.py | 6 +- tests/parallel_module/test_submodule.py | 10 +- tests/parallel_module/test_wholemodule.py | 6 +- tests/test_policies.py | 61 ++++ tests/utils.py | 9 + 28 files changed, 540 insertions(+), 254 deletions(-) create mode 100644 nnscaler/policies.py create mode 100644 tests/test_policies.py diff --git a/.gitignore b/.gitignore index b216a0d1..b6fbec72 100644 --- a/.gitignore +++ b/.gitignore @@ -36,4 +36,4 @@ shelf *.xml # cppimport generated file -.rendered.*.cpp \ No newline at end of file +.rendered.*.cpp diff --git a/README.md b/README.md index 8aa344bd..6f1f60df 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Install the following packages before the installation of cube: * Python >= 3.8 -* PyTorch >= 1.13 +* PyTorch >= 2.0 ## Install diff --git a/docs/parallel_module.md b/docs/parallel_module.md index 304871f8..22965cbf 100644 --- a/docs/parallel_module.md +++ b/docs/parallel_module.md @@ -377,7 +377,7 @@ It has the following parameters: - `dummy_input` (`dict`): the dummy input for the module. The keys are the argument names of `Module.forward` function, and the values are the dummy input for the arguments. The dummy input will be used to trace the module. Please note the module can't be parallelize if `Module.forward` has positional-only arguments. -- `pas_policy` (`Callable[[IRGraph, ComputeConfig], IRGraph]`): the pas policy, which describes how to place all computations across devices. You can use `autodist` to do the pas automatically in the most efficient way. +- `pas_policy` (`Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]`): the pas (partition-assign-schedule) policy, which describes how to place all computations across devices. You need either pass a builtin PAS policy name or a a custom policy function which should take an `IRGraph` and a `ComputeConfig` as input, and return a new `IRGraph` with the PAS policy applied. We have 6 builtin PAS policies: `dp`, `tp`, `pp`, `data`, `hybrid`, and `autodist`. Please note all builtin PAS policies except `autodist` are only for test purpose. The `autodist` policy is the recommended policy for most cases. For details, please refer to `PAS Policies` section. - `compute_config` (`ComputeConfig`): the environment resource @@ -439,7 +439,7 @@ It has the following parameters: - *args: other args for `optimizer_fn` besides module parameters. - **kwargs: the kwargs will pass to `optimizer_fn` -To support distributed training, in the function we need to hook 4 places: +To support distributed training, in the function we need to hook 4 places (which we have done for you in `build_optimizer`. That's why you should use `build_optimizer` to create optimizer): 1. optimizer constructor: the parameters of optimizer will not be the same with the parameters of the module if we use zero. @@ -447,7 +447,7 @@ To support distributed training, in the function we need to hook 4 places: 2. `optimizer.step()`: we need to call `optimizer.sync_shard_grad()` to sync the gradients of the module before `optimizer.step()`. - In zero mode (not supported yet), we have to call `CubeModule.gather_params()` after `optimizer.step()` + In zero mode, we have to call `CubeModule.gather_params()` after `optimizer.step()` 3. `optimizer.zero_grad()`: We need to call `CubeModule.zero_grad()` after `optimizer.zero_grad()` @@ -520,6 +520,41 @@ def infer_step(self, samples: List[Any]) -> List[Any]: The inference step function. It should be called in the inference loop. The input is a list of samples, and returns a list of outputs for the samples. If pipeline is used, it must have the same length as pipeline_nmicros +### PAS Policies + +Writing a pas policy can be very hard and error-prone. So we provide 6 builtin PAS policies to help you. `dp`, `tp`, `pp`, `data`, `hybrid`, and `autodist`. Please note only `autodist` policy is the recommended policy for most cases, and all other PAS policies are mainly test purpose only. + +The configuration of the PAS policy should be passed in the `user_config.code['pas']` of `ComputeConfig` as a dictionary. + +1. `dp`: data parallelism. It will replicate the module across all devices, and run data parallelism across all devices. It requires the `plan_ngpus` must be 1 and no configurations + +2. `tp`: tensor parallelism + data parallelism. It will do tensor parallelism inside a scale unit, and run data parallelism across scale units. It has only one configuration: + - seed: the random seed for choose the partition dimension. Default is `1` + +3. `pp`: pipeline parallelism + data parallelism. It will do model parallelism inside a scale unit, and run data parallelism across scale units. It requires the `use_end2end` and `use_pipeline` to be true. It has no configurations. + +4. `data`: tensor parallelism on batch dimension. It has no configurations. + +5. `hybrid`: pipeline parallelism + tensor parallelism + data parallelism. It will do model parallelism and tensor parallelism(on 0 dimension) inside a scale unit, and run data parallelism across scale units. It requires the `use_end2end` and `use_pipeline` to be true. It has no configurations. + +6. `autodist`: the recommended policy for most cases. Currently it only support Adam-like optimizers. It will automatically choose the best partition for you by balancing the memory usage and speed. It has the following configurations. + - `update_freq (int)`: the update frequency when training the module. Required. + - `mem_constraint (float)`: The memory constraint in each device in GB. Optional. + - `task_name (str)`: The name of the current task to distinguish runs. Optional. + - `use_fp16 (bool)`: Whether you use `fp16`. Default is `False`. Optional. + - `use_memory_efficient_fp16` Whether you use memory efficient fp16 optimizer. Default is `False`. Optional. + - `use_bf16`: Whether you use `bf16`. Default is `False`. Optional. + - `use_memory_efficient_bf16`: Whether you use memory efficient bf16 optimizer. Default is `False`. Optional. + - `re_profile (bool)`: If set to `True`, the computation profiling results will be overridden. Please note reprofiling will take some time. Optional. + - `verbose (bool)`: Whether to print verbose information. Optional. + - `load_plan_path (str)`: The path to the plan file to load. If specified, the plan will be loaded from the file instead of searching. Optional. + - `save_plan_path (str)`: The path to the plan file to save. Optional. + - `partition_constraints_path (str)`: The path to the partition constraints file. Optional. + - `recompute_modules (str)`: The module names to recompute, separated by `,`. For example, `module1,module2`. Optional. + - `pipeline_pivots (str)`: The module names to pivot the pipeline, separated by `,`. For example, if `module1,module2` is specified, stages searched by pipeline solver only start from either `module1` or `module2`. Optional. + - `use_apex_fused_adam_v2`: If set to `True`, the apex fused adam v2 optimizer will be used. Default is `False`. Optional. + +Please note all options to `autodist` are just suggestions. `autodist` will try to find the best partition for you, which may not be the same with your suggestions. ### Checkpoint support diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 87ca38c0..d0816d31 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -114,13 +114,14 @@ def parallelize_graph(graph: IRGraph, search_out_json = json.load(f) search_out = PipelineSearchOutput.from_json(search_out_json) else: - _logger.info(f'save plan to {autodist_config.save_plan_path}') compile_start_time = time.time() search_out = calc_parallel_plan(graph, autodist_config) compile_cost_time = time.time() - compile_start_time - with open(autodist_config.save_plan_path, 'w') as f: - json.dump(search_out.to_json(), f, indent=2) + if autodist_config.save_plan_path: + _logger.info(f'save plan to {autodist_config.save_plan_path}') + with open(autodist_config.save_plan_path, 'w') as f: + json.dump(search_out.to_json(), f, indent=2) _logger.info(f'use plan with e2e time/s {search_out.e2e_time}s,' + f'stage mems/GB {search_out.stage_mems}, ' + @@ -204,6 +205,10 @@ def parallelize_graph(graph: IRGraph, else: stages = [graph] + # TODO: check pipeline_nstages when ready. + # if autodist_config.pipeline and len(stages) != autodist_config.pipeline_nstages: + # raise RuntimeError("pipeline_nstages doesn't match the number of stages (based on your pipeline_pivots config) in the plan") + # add multiref to a tensor when # 1. it is not a grad tensor # 2. it has more than one consumers diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index f1bbcaf8..f2d9a863 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -88,6 +88,8 @@ class AutoDistConfig: - pipeline_pivots (`str`, *optional*, defaults to `''`): The module names to pivot the pipeline, separated by `,`. For example, if `module1,module2` is specified, stages searched by pipeline solver only start from either `module1` or `module2`. + - pipeline_nstages(`int`, *optional*, defaults to `1`): + The number of stages in pipeline parallelism. This option is only used when pipeline is True. - max_pipeline_bubble_ratio (`float`, *optional*, defaults to `0.4`): The maximum bubble ratio in pipeline parallelism. The higher the ratio, the more bubbles will be allowed, the larger search space will be explored. @@ -125,6 +127,7 @@ def __init__(self, re_profile=False, pipeline=False, pipeline_pivots='', + pipeline_nstages=1, max_pipeline_bubble_ratio=0.4, max_pipeline_unbalance_ratio=0.5, solver='dp', @@ -158,6 +161,7 @@ def __init__(self, self.re_profile = re_profile self.pipeline = pipeline self.pipeline_pivots = pipeline_pivots + self.pipeline_nstages = pipeline_nstages self.max_pipeline_bubble_ratio = max_pipeline_bubble_ratio self.max_pipeline_unbalance_ratio = max_pipeline_unbalance_ratio self.solver = solver @@ -198,13 +202,9 @@ def _validate_config(self): if self.save_plan_path: raise ValueError( 'cannot specify both load plan path and save plan path') - else: - self.save_plan_path = self.load_plan_path if self.save_plan_path: - _validate_dir_path(Path(self.save_plan_path).parent) - else: - self.save_plan_path = f'./{self.task_name}.json' + Path(self.save_plan_path).parent.mkdir(parents=True, exist_ok=True) if self.zero_stage not in [0, 1]: raise ValueError(f'zero stage {self.zero_stage} must be 0 or 1') diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index e595d5fe..3912fb9c 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -13,35 +13,37 @@ import os import torch -from nnscaler.codegen.schedule.schedule import ScheduleCodeGen -from nnscaler.graph.parser.fx.parser import FxModuleParser -from nnscaler.graph.schedule.predefined import PredefinedSched -from nnscaler.ir.cten import IRObject, IRTensor -from nnscaler.ir.tensor import IRFullTensor +from nnscaler.codegen import ModuleCodeGen +from nnscaler.codegen.schedule.schedule import ScheduleCodeGen -from nnscaler.flags import CompileFlag, RuntimeFlag -from nnscaler.utils import get_shared_params +from nnscaler.execplan import ExecutionPlan +from nnscaler.execplan.planpass.fusion import DiffFusion +from nnscaler.execplan.planpass.grouping import Grouping from nnscaler.graph import IRGraph from nnscaler.graph import parser -from nnscaler.ir.operator import IRBpOperation, IRDataOperation from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.function.pyfunc import IRPyFunc -from nnscaler.graph.schedule.schedplan import SchedulePlan from nnscaler.graph.gener.gen import IRAdapterGener +from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.schedule.predefined import PredefinedSched +from nnscaler.graph.schedule.schedplan import SchedulePlan -from nnscaler.codegen import ModuleCodeGen -from nnscaler.execplan import ExecutionPlan -from nnscaler.execplan.planpass.grouping import Grouping -from nnscaler.execplan.planpass.fusion import DiffFusion +from nnscaler.ir.cten import IRObject, IRTensor +from nnscaler.ir.operator import IRBpOperation, IRDataOperation +from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.unique import IDGenerator -from nnscaler.program import Program + from nnscaler.runtime.adapter.reducer import Reducer -from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState from nnscaler.runtime.device import DeviceGroup from nnscaler.runtime.gnorm import calcuate_gnorm, clip_grads -from nnscaler.utils import get_member_by_name, setup_stride_broadcast_group +from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState + +from nnscaler.flags import CompileFlag, RuntimeFlag +import nnscaler.policies as policies +from nnscaler.program import Program +from nnscaler.utils import get_member_by_name, setup_stride_broadcast_group, get_shared_params logger = logging.getLogger(__name__) @@ -53,6 +55,13 @@ if isinstance(v, staticmethod) and k.startswith(_PREDEFINE_SCHED_NAME_PREFIX): _PREDEFINE_SCHEDS[k[len(_PREDEFINE_SCHED_NAME_PREFIX):]] = v +_PREDEFINED_POLICIES: Dict[str, Callable[[IRGraph, 'ComputeConfig'], IRGraph]] = {} +_PREDEFINED_POLICIES_NAME_PREFIX = 'pas_' +for k, v in policies.__dict__.items(): + if callable(v) and k.startswith(_PREDEFINED_POLICIES_NAME_PREFIX): + _PREDEFINED_POLICIES[k[len(_PREDEFINED_POLICIES_NAME_PREFIX):]] = v + + @dataclass class UserConfig: # you should put any configuration that may affect the traced graph here. @@ -85,8 +94,15 @@ class UserConfig: graph: Dict[str, Any] = field(default_factory=dict) # you can put any configuration that may affect the generated code (but not affect the traced graph) here. # For example, extra arguments of your PAS function can put here. + # For all builtin pas, we will put PAS config in `code['pas']`. code: Dict[str, Any] = field(default_factory=dict) + def get_pas_config(self) -> Dict[str, Any]: + """ + All builtin pas will read their config here. + """ + return self.code.get('pas', {}) + @dataclass(frozen=True) class ComputeConfig: @@ -807,7 +823,7 @@ def _load_cube_module_class( def parallelize( module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]], dummy_input: dict, - pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], + pas_policy: Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]], compute_config: ComputeConfig, *, cube_savedir: Union[str, Path] = './.cube', @@ -871,7 +887,8 @@ def __init__(self, init_params=True): Args: module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled dummy_input (dict): the dummy input for the module - pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy + pas_policy (Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]): the pas policy, + it can be a name of builtin policies, or a custom policy function. compute_config (ComputeConfig): the environment resource reuse (ReuseType): specify which part can be reused. cube_savedir (Union[str, Path]): the directory to save generated code @@ -897,6 +914,11 @@ def __init__(self, init_params=True): ): return module_or_module_class if load_module else None + if isinstance(pas_policy, str): + if not pas_policy in _PREDEFINED_POLICIES: + raise ValueError(f"Invalid pas_policy: {pas_policy}") + pas_policy = _PREDEFINED_POLICIES[pas_policy] + is_module_class = inspect.isclass(module_or_module_class) module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ reuse = ReuseType(reuse) if isinstance(reuse, str) else reuse @@ -1122,7 +1144,7 @@ def build_optimizer( """ Build an optimizer for a module. - To support parallelized module (CubeModule), we need to hook 4 places: + To support parallelized module (CubeModule), we hook 4 places in this function: 1. optimizer constructor: the parameters of optimizer will not be the same with the parameters of the module if we use zero so we need to replace the parameters of optimizer with CubeModule.parameters_for_optimizer @@ -1135,8 +1157,6 @@ def build_optimizer( 4. backward(): you need to call optimizer.sync_shard_grad() manually if you want to read the gradients of the module before optimizer.step(). - Please note this DOES NOT work in end2end mode. - Args: module (torch.nn.Module): the module to be optimized optimizer_fn (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): @@ -1153,7 +1173,7 @@ def build_optimizer( """ if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): - raise RuntimeError("End2End mode is not supported") + raise RuntimeError("Old style CubeModule is not supported") # only the root module can be end2end module. if any(m != module and isinstance(m, ParallelModule) and m.compute_config.use_end2end for m in module.modules()): diff --git a/nnscaler/policies.py b/nnscaler/policies.py new file mode 100644 index 00000000..041681e0 --- /dev/null +++ b/nnscaler/policies.py @@ -0,0 +1,299 @@ +""" +Policy Writing Guidelines + +Users can write the policy following the steps: + +1. Apply multiref +2. Apply recompute +3. Graph staging (pipeline only) +4. Graph partition & assign +5. Apply schedule (pipeline only) + +Note the steps 1, 2, 3 must be finished before any graph partition. + +IRDataOperation is recommended to be replicated to all devices. +""" + +import logging +from typing import List, Optional, TYPE_CHECKING +import random + +import torch +import more_itertools as mitr + +from nnscaler.autodist.apis import parallelize_graph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.graph import IRGraph +from nnscaler.graph.function.anchor import IRGraphAnchor +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.segment import IRSegment +from nnscaler.ir.operator import IRDataOperation, IRFwOperation + + +if TYPE_CHECKING: + from nnscaler.parallel import ComputeConfig + + +_logger = logging.getLogger(__name__) + + +def _tp(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int): + if len(devs) > 1: + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) + else: + sub_nodes = [node] + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def _replica(graph: IRGraph, node, devs: List[int]): + sub_nodes = graph.replicate(node, times=len(devs)) + for devid, sub_node in zip(devs, sub_nodes): + graph.assign(sub_node, devid) + return sub_nodes + + +def pas_dp(graph: IRGraph, cfg: 'ComputeConfig'): + """ + pure data parallelism policy + """ + ngpus = cfg.plan_ngpus + if ngpus != 1: + raise ValueError("Data parallelism only supports 1 plan GPU") + + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): continue + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor) + + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'multiref' or isinstance(node, IRGraphAnchor): + continue + _replica(graph, node, [0]) + return graph + + +def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): + """ + random tensor parallelism inside a scale unit, and dp across scale units + """ + ngpus = cfg.plan_ngpus + # get the current random state + state = random.getstate() + + seed = cfg.user_config.get_pas_config().get('seed', 1) # by default we fix the seed for test reproducibility + random.seed(seed) + devs = list(range(ngpus)) + + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): continue + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor) + + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'multiref' or isinstance(node, IRGraphAnchor): + continue + if isinstance(node, IRDimops): + configs = node.transform_space() + if len(configs) == 0: + _replica(graph, node, devs) + else: + configs = sorted(configs, reverse=True, + key=lambda config: node.input(config[0]).shape[config[1]]) + random.shuffle(configs) + for (idx, dim) in configs: + if node.input(idx).shape[dim] % len(devs) != 0: continue + if node.algorithms('dim').satisfy(idx=idx, dim=dim, num=len(devs)): + _tp(graph, node, devs, idx, dim) + break + else: + _replica(graph, node, devs) + else: + _replica(graph, node, devs) + + # restore the random state + random.setstate(state) + return graph + + +def pas_pp(graph: IRGraph, cfg: 'ComputeConfig'): + """ + pipeline parallelism inside a scale unit, and dp across scale units + """ + if cfg.pipeline_nstages != cfg.plan_ngpus: + raise ValueError("pipeline_nstages should be equal to plan_ngpus") + return pas_hybrid(graph, cfg) + + +def pas_data(graph: IRGraph, env_resource: 'ComputeConfig'): + """ + tensor partition on batch dimension inside a scale unit, and dp across scale units + """ + ngpus = env_resource.plan_ngpus + # auto multi-ref + for ftensor in graph.full_tensors(): + if len(graph.consumers(ftensor)) > 1: + graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) + + batch_dim = 0 + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, list(range(ngpus))) + + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + try: + algo = node.algorithms('dim') + idx = 0 + sub_nodes = graph.partition( + node, algo, idx=idx, dim=batch_dim, num=ngpus) + except Exception: + sub_nodes = graph.replicate(node, ngpus) + + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): + """ + pipeline and tensor parallelism inside a scale unit, and dp across scale units + """ + if not cfg.use_pipeline: + raise ValueError("pipeline should be enabled") + + ngpus: int = cfg.plan_ngpus + nstages = cfg.pipeline_nstages + tp_size: int = cfg.plan_ngpus // nstages + if ngpus % tp_size != 0: + raise ValueError(f'invalid tp_size {tp_size} for ngpus {ngpus}') + pp_size = ngpus // tp_size + + fnodes = graph.select(ntype=IRFwOperation) + stages = mitr.divide(pp_size, fnodes) + stages = [list(s) for s in stages] + for idx, stage in enumerate(stages): + _logger.info(f'> stage {idx}: {stage[0]}') + graph.staging([s[0] for s in stages]) + + stages: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + assert len(stages) == pp_size, "Internal Error" + + # stage-wise tensor parallelism + curr_devices = list(range(ngpus)) + for stage in stages: + for node in stage.nodes(): + devs = curr_devices[:tp_size] + try: + _tp(graph, node, devs, idx=0, dim=0) + except Exception as e: + _replica(graph, node, devs) + curr_devices = curr_devices[tp_size:] + assert len(curr_devices) == 0, f"remaining devices: {curr_devices} not used" + + # replicate dataloader + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs=list(range(ngpus))) + + cfg.apply_pipeline_scheduler(graph) + return graph + + +def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: + pas_cfg = cfg.user_config.get_pas_config() + + # required parameters + update_freq = pas_cfg['update_freq'] + if isinstance(update_freq, (tuple, list)): + update_freq = update_freq[0] + if cfg.use_pipeline and update_freq != cfg.pipeline_nmicros: + raise ValueError("pipeline_nmicros should be equal to update_freq") + + # optional parameters + mesh_col = pas_cfg.get('mesh_col', cfg.plan_ngpus) + if mesh_col != cfg.plan_ngpus: + raise ValueError("mesh_col should be equal to plan_ngpus") + mem_constraint = pas_cfg.get('mem_constraint', -1) + task_name = pas_cfg.get('task_name', '_') + use_memory_efficient_fp16 = pas_cfg.get('use_memory_efficient_fp16', False) + use_memory_efficient_bf16 = pas_cfg.get('use_memory_efficient_bf16', False) + use_fp16 = pas_cfg.get('use_fp16', use_memory_efficient_fp16) + use_bf16 = pas_cfg.get('use_bf16', use_memory_efficient_bf16) + re_profile = pas_cfg.get('re_profile', False) + verbose = pas_cfg.get('verbose', False) + load_plan_path = pas_cfg.get('load_plan_path', None) + save_plan_path = pas_cfg.get('save_plan_path', None) + partition_constraints_path = pas_cfg.get('partition_constraints_path', '') + recompute_modules = pas_cfg.get('recompute_modules', '') + pipeline_pivots = pas_cfg.get('pipeline_pivots', '') + use_apex_fused_adam_v2 = pas_cfg.get('use_apex_fused_adam_v2', False) + + mesh_row = 1 + ngpus = mesh_row * mesh_col + task_name = f'{task_name}_{ngpus}gpus_{update_freq}update_freq' + if mem_constraint == -1: + # consider memory fragmentation and other buffers, use 80% of the memory + memory_constraint = int(0.8 * torch.cuda.mem_get_info()[1] / 1024 / + 1024 / 1024) + if cfg.use_zero: + zero_stage = 1 + zero_ngroups = cfg.zero_ngroups + else: + zero_stage = 0 + zero_ngroups = 1 + if use_fp16 or use_bf16: + support_inkernel_cast = use_apex_fused_adam_v2 + if use_memory_efficient_fp16 or use_memory_efficient_bf16: + # Check fairseq/optim/fused_adam.py + # If memory efficient: + # Considered in opt_resident_mem: fp32 moment1, fp32 moment2. + # Considered in opt_transient_mem: fp32 weight, fp32 gradient, + # because fp16 weight and gradient are casted to fp32. + # Here weight_mem is in fp16, so multiply by (2+2). + opt_resident_coef = 4 + opt_transient_coef = 0 if support_inkernel_cast else 4 + else: + # If not memory efficient: + # Considered in opt_resident_mem: fp32 moment1, fp32 moment2, fp32 weight. + # Considered in opt_transient_mem: fp32 gradient, + # because fp16 gradient are casted to fp32. + # Here weight_mem is in fp16, so multiply by (2+2+2). + opt_resident_coef = 6 + # inkernel cast between fp32 weight and fp16 grad has not support + opt_transient_coef = 2 if support_inkernel_cast else 2 + else: + # Considered in opt_resident_mem: fp32 moment1, fp32 moment2 + # Considered in opt_transient_mem: 0 + # Here weight_mem is in fp32, so multiply by (1+1). + opt_resident_coef = 2 + opt_transient_coef = 0 + + autodist_cfg = AutoDistConfig( + mesh_row=mesh_row, + mesh_col=mesh_col, + update_freq=update_freq, + task_name=task_name, + is_train=not cfg.inference_only, + ignore_small_tensor_threshold=524288, # 0.5 MB is a good threshold to reduce search time and make the result correct, will refine later + memory_granularity=524288, # 0.5 MB is a good threshold to reduce search time and make the result correct, will refine later + consider_mem=True, + partition_constraints_path=partition_constraints_path, + memory_constraint=memory_constraint, + opt_resident_coef=opt_resident_coef, + opt_transient_coef=opt_transient_coef, + verbose=verbose, + re_profile=re_profile, + world_size=cfg.runtime_ngpus, + recompute_modules=recompute_modules, + zero_stage=zero_stage, + zero_ngroups=zero_ngroups, + load_plan_path=load_plan_path, + save_plan_path=save_plan_path, + pipeline=cfg.use_pipeline, + pipeline_pivots=pipeline_pivots, + pipeline_nstages=cfg.pipeline_nstages, + ) + + return parallelize_graph(graph, autodist_cfg) diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index a2ef0599..5dbbb752 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -17,6 +17,9 @@ from nnscaler.graph.graph import IRGraph from nnscaler.graph.segment import IRSegment from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.policies import _tp, _replica + +from ..utils import init_random def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: @@ -47,100 +50,6 @@ def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: return tuple(outputs) -def _tp(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int): - sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def _replica(graph: IRGraph, node, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def PASRandomSPMD(graph: IRGraph, env_resource: ComputeConfig): - """ - Random SPMD policy - """ - ngpus = env_resource.plan_ngpus - # get the current random state - state = random.getstate() - - seed = 1 - # print(f'> set random SPDM policy seed to {seed}') - random.seed(seed) - devs = list(range(ngpus)) - - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): continue - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor) - - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if node.name == 'multiref' or isinstance(node, IRGraphAnchor): - continue - if isinstance(node, IRDimops): - configs = node.transform_space() - if len(configs) == 0: - _replica(graph, node, devs) - else: - configs = sorted(configs, reverse=True, - key=lambda config: node.input(config[0]).shape[config[1]]) - random.shuffle(configs) - for (idx, dim) in configs: - if node.input(idx).shape[dim] % len(devs) != 0: continue - if node.algorithms('dim').satisfy(idx=idx, dim=dim, num=len(devs)): - # print(f'> partition node {node.name} ({node.cid}) with config idx={idx}, dim={dim}') - _tp(graph, node, devs, idx, dim) - break - else: - _replica(graph, node, devs) - else: - _replica(graph, node, devs) - - # restore the random state - random.setstate(state) - # print(graph.extra_repr()) - return graph - - -def PASData(graph: IRGraph, env_resource: ComputeConfig): - """ - Data Parallel - """ - ngpus = env_resource.plan_ngpus - # auto multi-ref - for ftensor in graph.full_tensors(): - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) - - batch_dim = 0 - for dl in graph.select(ntype=IRDataOperation): - _replica(graph, dl, list(range(ngpus))) - - for node in graph.nodes(): - # print(node) - if isinstance(node, IRFwOperation): - try: - algo = node.algorithms('dim') - idx = 0 - sub_nodes = graph.partition( - node, algo, idx=idx, dim=batch_dim, num=ngpus) - # except AssertionError: - except: - # print(f'WARNING: {node} cannot find dim algo, using replicate instead') - sub_nodes = graph.replicate(node, ngpus) - - for idx, node in enumerate(sub_nodes): - graph.assign(node, idx) - # print(graph.extra_repr()) - return graph - - def PASMegatron(graph: IRGraph, config: ComputeConfig): num_stages = config.pipeline_nstages tp_size = config.plan_ngpus // num_stages @@ -169,50 +78,6 @@ def PASMegatron(graph: IRGraph, config: ComputeConfig): return graph -def PASHybrid(graph: IRGraph, config: ComputeConfig): - """ - Hybrid Tensor and Pipeline Parallelism - """ - ngpus: int = config.plan_ngpus - nstages = config.pipeline_nstages - tp_size: int = config.plan_ngpus // nstages - if ngpus % tp_size != 0: - raise ValueError(f'invalid tp_size {tp_size} for ngpus {ngpus}') - pp_size = ngpus // tp_size - - fnodes = graph.select(ntype=IRFwOperation) - stages = mitr.divide(pp_size, fnodes) - stages = [list(s) for s in stages] - for idx, stage in enumerate(stages): - print(f'> stage {idx}: {stage[0]}') - graph.staging([s[0] for s in stages]) - - stages: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) - stages = [s for s in stages if s.isfw()] - assert len(stages) == pp_size, "Internal Error" - - # stage-wise tensor parallelism - curr_devices = list(range(ngpus)) - for stage in stages: - for node in stage.nodes(): - devs = curr_devices[:tp_size] - try: - _tp(graph, node, devs, idx=0, dim=0) - except Exception as e: - _replica(graph, node, devs) - curr_devices = curr_devices[tp_size:] - assert len(curr_devices) == 0, f"remaining devices: {curr_devices} not used" - - # replicate dataloader - for dl in graph.select(ntype=IRDataOperation): - _replica(graph, dl, devs=list(range(ngpus))) - - # setup 1f1b pipeline scheduler - # PredefinedSched.sched_1f1b(graph, nmicros, pp_size) - config.apply_pipeline_scheduler(graph) - return graph - - class CubeLinear(nn.Module): def __init__(self, in_features, out_features, bias=False): super().__init__() @@ -244,13 +109,6 @@ def init_distributed(): torch.set_default_device(f'cuda:{rank}') -def init_random(): - np.random.seed(1) - torch.manual_seed(1) - if torch.cuda.is_available(): - torch.cuda.manual_seed(1) - - @contextlib.contextmanager def clear_dir_on_rank0(tempdir): if torch.distributed.get_rank() == 0 and tempdir.exists(): diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index 12f04309..f7a3a36e 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -7,7 +7,7 @@ from nnscaler.parallel import ComputeConfig, parallelize, broadcast_weights -from .common import PASRandomSPMD, init_distributed +from .common import init_distributed from ..launch_torchrun import launch_torchrun @@ -28,7 +28,7 @@ def _to_cube_model(module, compute_config, cube_savedir, return parallelize( module, {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, - PASRandomSPMD, + 'tp', compute_config, cube_savedir=cube_savedir, instance_name=instance_name, diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index ccdd4965..78dfb28d 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -14,11 +14,11 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts, UserConfig from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, PASMegatron +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0, PASMegatron from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively from ..utils import replace_all_device_with @@ -441,8 +441,8 @@ def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus, per_resum def test_checkpoint(module_type, use_zero): plan_ngpus = 2 runtime_ngpus = 4 - cube_results = launch_torchrun(4, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus, 32, 1) - rcube_results = launch_torchrun(4, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus, 16, 2) + cube_results = launch_torchrun(4, _gpu_worker, module_type, use_zero, 'tp', plan_ngpus, runtime_ngpus, 32, 1) + rcube_results = launch_torchrun(4, _gpu_worker, module_type, use_zero, 'tp', plan_ngpus, runtime_ngpus, 16, 2) results0, results1, results2, results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] rresults0, rresults1, rresults2, rresults3 = rcube_results[0], rcube_results[1], rcube_results[2], rcube_results[3] @@ -495,8 +495,8 @@ def test_checkpoint_intra_reducer(module_type, use_zero): """ plan_ngpus = 2 runtime_ngpus = 2 - cube_results = launch_torchrun(2, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus, 32, 1, assert_intra_reducer) - rcube_results = launch_torchrun(2, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus, 16, 2, assert_intra_reducer) + cube_results = launch_torchrun(2, _gpu_worker, module_type, use_zero, 'tp', plan_ngpus, runtime_ngpus, 32, 1, assert_intra_reducer) + rcube_results = launch_torchrun(2, _gpu_worker, module_type, use_zero, 'tp', plan_ngpus, runtime_ngpus, 16, 2, assert_intra_reducer) results0 = cube_results[0] rresults0 = rcube_results[0] diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index d8d603ff..8aec6d4b 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -7,7 +7,7 @@ from nnscaler.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dicts, broadcast_weights -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun from ..utils import catch_log @@ -49,7 +49,7 @@ def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_sh return parallelize( module, {'x': torch.randn(input_shape)}, - PASRandomSPMD, + 'tp', compute_config, cube_savedir=cube_savedir, instance_name=instance_name, diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index c6a3545d..2b8a0ab1 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -11,7 +11,7 @@ deduped_state_dict, load_deduped_state_dict from nnscaler.runtime.module import ParallelModule -from .common import PASRandomSPMD, PASMegatron, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, assert_equal +from .common import PASMegatron, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, assert_equal from ..launch_torchrun import launch_torchrun from .test_checkpoint import gendata, train_step, End2EndMLP, End2EndMLPWithUnusedAndShared @@ -190,12 +190,12 @@ def _gpu_worker(pas, cc1, cc2): def test_checkpoint_compact(use_zero): cc1 = ComputeConfig(1, 4, use_zero=use_zero, zero_ngroups=2 if use_zero else 1) cc2 = ComputeConfig(1, 4, use_zero=use_zero, zero_ngroups=4 if use_zero else 1) - launch_torchrun(4, _gpu_worker, PASRandomSPMD, cc1, cc2) + launch_torchrun(4, _gpu_worker, 'tp', cc1, cc2) # mixed zero and non-zero cc1 = ComputeConfig(2, 4, use_zero=not use_zero, zero_ngroups=2 if not use_zero else 1) cc2 = ComputeConfig(2, 4, use_zero=use_zero, zero_ngroups=1) - launch_torchrun(4, _gpu_worker, PASRandomSPMD, cc1, cc2) + launch_torchrun(4, _gpu_worker, 'tp', cc1, cc2) def _gpu_worker_pipeline(cc): diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index e0aef12b..0165f024 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -8,7 +8,7 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts -from .common import PASRandomSPMD, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun from .test_checkpoint import End2EndMLP, train_step, gendata @@ -209,4 +209,4 @@ def test_checkpoint_load_from_raw_checkpoint(module_type, use_zero): """ plan_ngpus = 2 runtime_ngpus = 4 - launch_torchrun(4, _gpu_worker, module_type, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus) + launch_torchrun(4, _gpu_worker, module_type, use_zero, 'tp', plan_ngpus, runtime_ngpus) diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py index e6bc7a6a..8e90041b 100644 --- a/tests/parallel_module/test_checkpoint_unused.py +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -18,7 +18,7 @@ from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively from .test_checkpoint_shared import _train_raw, _load_merged @@ -127,4 +127,4 @@ def test_checkpoint_load_from_raw_checkpoint(use_zero): """ plan_ngpus = 2 runtime_ngpus = 4 - launch_torchrun(4, _gpu_worker, use_zero, PASRandomSPMD, plan_ngpus, runtime_ngpus) + launch_torchrun(4, _gpu_worker, use_zero, 'tp', plan_ngpus, runtime_ngpus) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index a8efcd19..ec256176 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -18,7 +18,7 @@ from nnscaler.runtime.module import ParallelModule from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively @@ -286,8 +286,8 @@ def test_tp_ddp(update_freq): # print('weight: ', k, torch.max(torch.abs(a0[3][k]- b[3][k]))) assert torch.allclose(a0.weights[k], b.weights[k], atol=1e-2, rtol=1e-2) # weights - cube_results = launch_torchrun(4, _gpu_worker_cube, PASRandomSPMD, 2, 4, update_freq, False) - zcube_results = launch_torchrun(4, _gpu_worker_cube, PASRandomSPMD, 2, 4, update_freq, True) + cube_results = launch_torchrun(4, _gpu_worker_cube, 'tp', 2, 4, update_freq, False) + zcube_results = launch_torchrun(4, _gpu_worker_cube, 'tp', 2, 4, update_freq, True) worker_results0, worker_results1, worker_results2, worker_results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] results0: List[StepResult] = worker_results0[0] results1: List[StepResult] = worker_results1[0] diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index c808a56e..86f94370 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -18,7 +18,7 @@ from nnscaler.runtime.utils import microbatches from nnscaler.runtime.module import ParallelModule from nnscaler.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts -from .common import PASData, PASRandomSPMD, assert_equal, clear_dir_on_rank0, init_distributed, PASMegatron, init_random, PASHybrid +from .common import assert_equal, clear_dir_on_rank0, init_distributed, PASMegatron, init_random from ..launch_torchrun import clone_to_cpu_recursively, launch_torchrun from .test_checkpoint import End2EndMLP @@ -170,7 +170,7 @@ def test_end2end(): ga4_result = _train_ga(model, 4) # micro_batch_size = 4 assert len(ga4_result) == 16 - cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, PASHybrid, True) # micro_batch_size = 4 + cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid', True) # micro_batch_size = 4 cube2_result = merge_cube_result({k: v[0] for k, v in cube2_results.items()}) assert len(cube2_result) == 16 allclose(cube2_result, ga4_result) @@ -180,7 +180,7 @@ def test_end2end(): assert len(cube4_result) == 16 allclose(cube4_result, ga4_result) - cube2_results_non_pipeline = launch_torchrun(4, gpu_worker_cube, 4, 2, PASRandomSPMD, False) # micro_batch_size = 4 + cube2_results_non_pipeline = launch_torchrun(4, gpu_worker_cube, 4, 2, 'tp', False) # micro_batch_size = 4 cube2_result_non_pipeline = merge_cube_result({k: v[0] for k, v in cube2_results_non_pipeline.items()}) assert len(cube2_result_non_pipeline) == 16 allclose(cube2_result_non_pipeline, ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error @@ -242,7 +242,7 @@ def test_pipeline_shared(): # 'chimera_direct' needs more gpus # 'infer_pipe' only work for inference # None looks doesn't work - cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, PASHybrid, True, None, None, MLPShared, ps) # micro_batch_size = 4 + cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid', True, None, None, MLPShared, ps) # micro_batch_size = 4 cube2_result = merge_cube_result({k: v[0] for k, v in cube2_results.items()}) assert len(cube2_result) == 16 allclose(cube2_result, ga4_result) @@ -323,7 +323,7 @@ def gpu_worker_cube_one_sample(): model = parallelize( model, {'data': dummy_data()}, - pas_policy=PASHybrid, + pas_policy='hybrid', compute_config= ComputeConfig( 2, 2, use_end2end=True, diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 8642bbfa..d98c2433 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -7,7 +7,7 @@ import nnscaler.graph.function.dimops from nnscaler.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph -from .common import PASData, init_distributed, PASRandomSPMD +from .common import init_distributed from ..launch_torchrun import launch_torchrun from ..utils import replace_all_device_with @@ -15,7 +15,7 @@ def _to_cube_model(module, compute_config, cube_savedir, load_module): return parallelize( module, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'data', compute_config, cube_savedir=cube_savedir, load_module=load_module @@ -59,7 +59,7 @@ def test_codegen_slice(): m_new = parallelize( SliceModule(), {'x': torch.tensor([1.0, 2.0, 3.0, 6.0])}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False @@ -87,7 +87,7 @@ def test_codegen_args(): 'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 'y': 1.0, }, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True @@ -114,7 +114,7 @@ def _gencode_unused_args_worker(tempdir): 'm': 0, 'n': None, }, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True @@ -166,7 +166,7 @@ def _gencode_unused_args_worker2(tempdir): 'y': torch.tensor([1, 2, 3]), 'm': 0 }, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True @@ -225,7 +225,7 @@ def test_codegen_attr(): m_new = parallelize( AttrModule(), {'x': torch.tensor([1.0, 2.0, 3.0, 6.0]), 'attr': AttrHelper()}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False @@ -257,7 +257,7 @@ def test_codegen_getitem(): m_new = parallelize( GetItemModule(), {'batched_data': {'x': torch.tensor([[[1.0], [2.0], [3.0], [6.0]]])}}, - PASRandomSPMD, + 'tp', ComputeConfig(2, 2), cube_savedir=tempdir, load_module=False, @@ -287,7 +287,7 @@ def test_codegen_training_flag(): parallelize( m, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False @@ -346,7 +346,7 @@ def test_codegen_iter(): parallelize( m, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False @@ -376,7 +376,7 @@ def test_codegen_const(): parallelize( m, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False @@ -415,7 +415,7 @@ def test_codegen_tensor_slice(): parallelize( m, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False, @@ -426,7 +426,7 @@ def test_codegen_tensor_slice(): parallelize( m, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False, @@ -453,7 +453,7 @@ def test_codegen_dictget(): 'x': torch.tensor([[[1.0], [2.0], [3.0], [6.0]]]), 'z': torch.tensor([[[1.0], [2.0], [3.0], [6.0]]]) }}, - PASRandomSPMD, + 'tp', ComputeConfig(2, 2), cube_savedir=tempdir, load_module=False, @@ -498,7 +498,7 @@ def _gencode_min_function_worker(tempdir): 'a': torch.tensor([5, 2, 3]), 'b': torch.tensor([1, 8, 1]), }, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True @@ -534,7 +534,7 @@ def _gencode_max_function(tempdir): { 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), }, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True @@ -575,7 +575,7 @@ def test_codegen_shared_parameter(): parallelize( m, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False, @@ -612,7 +612,7 @@ def test_codegen_buffer(): parallelize( m, {'x': torch.randn(128, 64)}, - PASData, + 'dp', ComputeConfig(1, 1), cube_savedir=tempdir, load_module=False, @@ -656,7 +656,7 @@ def test_codegen_inference(): parallelize( Module0(), {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'dp', ComputeConfig(1, 1, inference_only=True), cube_savedir=tempdir, load_module=False @@ -684,7 +684,7 @@ def p(cube_dir, use_pipeline, dynamic_shape, return_type, inference_only=False): parallelize( m, {'data': torch.randn(batch_size, dim), 'return_type': return_type}, - PASData, + 'data', compute_config= ComputeConfig( 4, 4, inference_only=inference_only, diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 23ab211d..417eb17a 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -8,7 +8,7 @@ from nnscaler.parallel import ComputeConfig, parallelize -from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD, clear_dir_on_rank0 +from .common import CubeLinear, init_distributed, init_random, clear_dir_on_rank0 from ..launch_torchrun import torchrun @@ -62,7 +62,7 @@ def _inference_worker(ngpus, inference_only): model = Module() model.eval() - cube_model = _to_cube_model(model, PASRandomSPMD, + cube_model = _to_cube_model(model, 'tp', ComputeConfig(ngpus, ngpus, inference_only=inference_only), tempdir, 'test_inference' ) diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 3e2ffc0b..6340850b 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -6,7 +6,7 @@ from nnscaler.parallel import _load_cube_module_class, parallelize, ComputeConfig from ..launch_torchrun import launch_torchrun -from .common import CubeLinear, init_distributed, init_random, PASRandomSPMD, clear_dir_on_rank0 +from .common import CubeLinear, init_distributed, init_random, clear_dir_on_rank0 from ..utils import new_empty, replace_all_device_with, mock_dist class MyModule(torch.nn.Module): @@ -24,7 +24,7 @@ def _init_params_worker(): cube_module = parallelize( MyModule, {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, - PASRandomSPMD, + 'tp', ComputeConfig(1, 1), cube_savedir=tempdir, reuse='match', @@ -67,7 +67,7 @@ def test_empty_weights(model_class, tp): parallelize( model_class, {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, - PASRandomSPMD, + 'tp', ComputeConfig(2, 4, use_zero=True, zero_ngroups=2), cube_savedir=tempdir, reuse='match', diff --git a/tests/parallel_module/test_line_timer.py b/tests/parallel_module/test_line_timer.py index 668f2fb6..bc92dfc0 100644 --- a/tests/parallel_module/test_line_timer.py +++ b/tests/parallel_module/test_line_timer.py @@ -8,7 +8,7 @@ from nnscaler.parallel import parallelize, ComputeConfig from nnscaler.flags import CompileFlag -from .common import PASRandomSPMD, init_distributed, clear_dir_on_rank0 +from .common import init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun from ..utils import catch_stdout @@ -27,7 +27,7 @@ def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_sh return parallelize( module, {'x': torch.randn(input_shape)}, - PASRandomSPMD, + 'tp', compute_config, cube_savedir=cube_savedir, instance_name=instance_name, diff --git a/tests/parallel_module/test_nested.py b/tests/parallel_module/test_nested.py index ea0a8080..27018788 100644 --- a/tests/parallel_module/test_nested.py +++ b/tests/parallel_module/test_nested.py @@ -5,7 +5,7 @@ from nnscaler.parallel import parallelize, ComputeConfig -from .common import PASData, init_distributed +from .common import init_distributed from ..launch_torchrun import launch_torchrun def _to_cube_model(module, pas, compute_config, cube_savedir): @@ -32,7 +32,7 @@ def _nested_module_worker(): class Module1(torch.nn.Module): def __init__(self): super().__init__() - self.module0 = _to_cube_model(Module0(), PASData, ComputeConfig(1, 1), cube_savedir=tempdir) + self.module0 = _to_cube_model(Module0(), 'dp', ComputeConfig(1, 1), cube_savedir=tempdir) def forward(self, x): return self.module0(x) @@ -45,7 +45,7 @@ def forward(self, x): return self.module1(x) with pytest.raises(RuntimeError, match='CubeModule can not be nested.'): - _to_cube_model(Module2(), PASData, ComputeConfig(1, 1), cube_savedir=tempdir) + _to_cube_model(Module2(), 'data', ComputeConfig(1, 1), cube_savedir=tempdir) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index 6993d278..622c5139 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -11,14 +11,13 @@ from nnscaler.runtime.module import ParallelModule from ..utils import new_empty, replace_all_device_with -from .common import PASData def _to_cube_model(model_class, compute_config, cube_savedir, reuse, instance_name, load_module=True): parallelize( model_class, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, - PASData, + 'data', compute_config, reuse=reuse, cube_savedir=cube_savedir, diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index a3d2e81f..fc6e2234 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -9,7 +9,7 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.runtime.module import ParallelModule -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively @@ -121,24 +121,24 @@ def _gpu_worker(pas, plan_ngpus, runtime_ngpus=None): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_hook_tp_gpu1(): - launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) + launch_torchrun(1, _gpu_worker, 'tp', 1) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_hook_tp_gpu2(): - launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) + launch_torchrun(2, _gpu_worker, 'tp', 2) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_hook_tp_gpu4(): - launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4) + launch_torchrun(4, _gpu_worker, 'tp', 2, 4) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_hook_dp_gpu1(): - launch_torchrun(1, _gpu_worker, PASData, 1) + launch_torchrun(1, _gpu_worker, 'dp', 1) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_hook_dp_gpu2(): - launch_torchrun(2, _gpu_worker, PASData, 2) + launch_torchrun(2, _gpu_worker, 'data', 2) diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py index 22e49bbd..109bd6b0 100644 --- a/tests/parallel_module/test_scale_grads.py +++ b/tests/parallel_module/test_scale_grads.py @@ -18,7 +18,7 @@ from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively @@ -140,8 +140,8 @@ def _gpu_worker(pas, plan_ngpus, runtime_ngpus, scale_grads: bool): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_scale_grads(): - cube_results = launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4, True) - rcube_results = launch_torchrun(4, _gpu_worker, PASRandomSPMD, 2, 4, False) + cube_results = launch_torchrun(4, _gpu_worker, 'tp', 2, 4, True) + rcube_results = launch_torchrun(4, _gpu_worker, 'tp', 2, 4, False) results0, results1, results2, results3 = cube_results[0], cube_results[1], cube_results[2], cube_results[3] rresults0, rresults1, rresults2, rresults3 = rcube_results[0], rcube_results[1], rcube_results[2], rcube_results[3] diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 575cb642..45d7584a 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -12,7 +12,7 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.runtime.module import ParallelModule -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively @@ -131,7 +131,7 @@ def _gpu_worker(pas, ngpus, update_freq): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @pytest.mark.parametrize('update_freq', [1, 2, 4]) def test_submodules_tp_gpu1(update_freq): - results = launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1, update_freq) + results = launch_torchrun(1, _gpu_worker, 'tp', 1, update_freq) orig_results, compiled_results, _, _, _, _ = results[0] for orig, compiled in zip(orig_results, compiled_results): assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred @@ -190,7 +190,7 @@ def _compare_weights(orig0, orig1, compiled0, compiled1, fc1_fullmap, fc2_fullma @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') @pytest.mark.parametrize('update_freq', [1, 2, 4]) def test_submodules_tp_gpu2(update_freq): - results = launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2, update_freq) + results = launch_torchrun(2, _gpu_worker, 'tp', 2, update_freq) results0, results1 = results[0], results[1] eps = 1e-4 @@ -223,7 +223,7 @@ def test_submodules_tp_gpu2(update_freq): @pytest.mark.skipif(not torch.cuda.is_available(), reason='cuda is not available') @pytest.mark.parametrize('update_freq', [1, 2, 4]) def test_submodules_dp_gpu1(update_freq): - results = launch_torchrun(1, _gpu_worker, PASData, 1, update_freq) + results = launch_torchrun(1, _gpu_worker, 'dp', 1, update_freq) orig_results, compiled_results, _, _, _, _ = results[0] for orig, compiled in zip(orig_results, compiled_results): assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred @@ -246,7 +246,7 @@ def test_submodules_dp_gpu1(update_freq): @pytest.mark.parametrize('update_freq', [1, 2, 4]) def test_submodules_dp_gpu2(update_freq): eps = 1e-4 - results = launch_torchrun(2, _gpu_worker, PASData, 2, update_freq) + results = launch_torchrun(2, _gpu_worker, 'data', 2, update_freq) for r in results.values(): orig_results, compiled_results, _, _, _, _ = r for orig, compiled in zip(orig_results, compiled_results): diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 53c3c3b6..bb9b44f1 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -12,7 +12,7 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.runtime.module import ParallelModule -from .common import PASRandomSPMD, PASData, CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively @@ -116,7 +116,7 @@ def _gpu_worker(pas, ngpus): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_module_tp_gpu1(): - results = launch_torchrun(1, _gpu_worker, PASRandomSPMD, 1) + results = launch_torchrun(1, _gpu_worker, 'tp', 1) orig_results, compiled_results, _, _ = results[0] for orig, compiled in zip(orig_results, compiled_results): assert torch.allclose(orig[0], compiled[0], rtol=1e-6, atol=1e-6) # pred @@ -145,7 +145,7 @@ def _compare_weights(orig0, orig1, compiled0, compiled1, module_fullmap, module_ @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_module_tp_gpu2(): - results = launch_torchrun(2, _gpu_worker, PASRandomSPMD, 2) + results = launch_torchrun(2, _gpu_worker, 'tp', 2) results0, results1 = results[0], results[1] eps = 1e-4 diff --git a/tests/test_policies.py b/tests/test_policies.py new file mode 100644 index 00000000..b0a9f2f9 --- /dev/null +++ b/tests/test_policies.py @@ -0,0 +1,61 @@ +import tempfile +from typing import * + +import pytest +import torch +import torch.nn as nn + +from nnscaler.parallel import ComputeConfig, UserConfig, parallelize + +from .utils import init_random + +MBS = 2 +DIM = 16 +LAYERS = 16 + +class MLP(nn.Module): + def __init__(self, dim: int = DIM, nlayers: int = LAYERS): + init_random() + super().__init__() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + self.loss_fn = nn.BCELoss() + + def forward(self, data: Dict[str, torch.Tensor]): + x = data['data'] + for layer in self.layers: + x = layer(x) + x = torch.sigmoid(x) + loss = self.loss_fn(x, data['target']) + return loss + + +def dummy_data(): + return { + 'data': torch.randn( + MBS, DIM, device=torch.cuda.current_device()), + 'target': torch.rand( + MBS, DIM, device=torch.cuda.current_device()) + } + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_autodist(): + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + MLP(), + {'data': dummy_data()}, + 'autodist', + ComputeConfig(2, 4, user_config=UserConfig( + code={ + 'pas': { + 'update_freq': 1, + 'task_name': 'test_autodist', + } + } + )), + cube_savedir=tempdir, + load_module=False + ) + assert m_new is None diff --git a/tests/utils.py b/tests/utils.py index 0b9915dc..b21abd87 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,6 +10,8 @@ from datetime import timedelta from pathlib import Path +import numpy as np + import torch import torch.distributed as dist import torch.distributed.distributed_c10d as c10d @@ -43,6 +45,13 @@ def trunc_normal_(tensor: torch.Tensor, mean: float = 0., std: float = 1., a: fl torch.nn.init.constant_(param, 0) +def init_random(): + np.random.seed(1) + torch.manual_seed(1) + if torch.cuda.is_available(): + torch.cuda.manual_seed(1) + + def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-4) -> bool: """Compare the output of baseline_fn and compile_fn From eb68c09076ed98162f3ca511ca6313f3db049751 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 13 May 2024 02:33:07 +0000 Subject: [PATCH 1634/1892] Merged PR 2135: refine some graph functions & fix trace bug under torch.no_grad context 1. refine the following functions: is / is not / setitem / scaled_dot_product_attention / linear 2. fix a bug when tracing under `torch.no_grad`, caused by moving tensor between cpu and cuda, the `required_grad` attr of input tensors change to False. --- nnscaler/graph/function/function.py | 46 ++++++++++++++++--- nnscaler/graph/parser/converter.py | 1 + .../concrete_trace_utils/concrete_tracer.py | 8 +++- nnscaler/ir/tensor.py | 2 +- tests/graph/function/test_functions.py | 18 ++++++-- tests/graph/tracer/test_ctxt_manager.py | 45 ++++++++++++++++++ 6 files changed, 105 insertions(+), 15 deletions(-) create mode 100644 tests/graph/tracer/test_ctxt_manager.py diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index ba4a0921..794c9e42 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -105,11 +105,13 @@ def Accum(*inputs, signature = None): def Linear(input, weight, bias=None, signature = None): signature = 'torch.nn.functional.linear' + assert isinstance(input, IRTensor) and isinstance(weight, IRTensor) if bias is None: - annos = ['b * k+, n k+ -> b * n'] + annos = ['* k+, n k+ -> * n'] return IRDimops(Linear, 'linear', signature, annos, [input, weight], bias=None) else: - annos = ['b * k^, n k^, n -> b * n'] + assert isinstance(bias, IRTensor) + annos = ['* k^, n k^, n -> * n'] return IRDimops(Linear, 'linear', signature, annos, [input, weight, bias]) @@ -2139,7 +2141,31 @@ def SetItem(__a: Any, __b: Any, __c: Any, signature = None) -> Union[Any, IRPyFu if isinstance(obj, IRTensor): # TODO: move to some function like FullSlice when ready # TODO: give a IRTensor as return value or return a IRDimops - return IRPyFunc(signature, [__a, __b, __c], [IRObject()]) + gener = iter(string.ascii_lowercase) + # obj annotation + edim_obj = ShapeAnno.create_shape_str(obj.shape, '^', iterator=gener) + edim_out = copy.copy(edim_obj) + + edim_ins = [edim_obj] + + # index annotation + if isinstance(index, IRTensor): + edim_index = ShapeAnno.create_shape_str(index.shape, '^', iterator=gener) + edim_ins.append(edim_index) + elif isinstance(index, IRObject) and any_ir_object_satisfy(index, lambda a: isinstance(a, IRTensor)): + raise RuntimeError(f"setitem did not support slicers include tensor now, got {index}") + else: + edim_ins.append(['?']) + + # value annotation + if isinstance(val, IRTensor): + edim_val = ShapeAnno.create_shape_str(val.shape, '^', iterator=gener) + edim_ins.append(edim_val) + else: + edim_ins.append(['?']) + + anno = OpAnno.create_op_str(edim_ins, [edim_out]) + return IRDimops(SetItem, 'setitem', signature, [anno], [obj, index, val]) is_constant = not ir_object_contains_dynamic(index) index = _unwrap_value(index) @@ -2236,13 +2262,17 @@ def MakeSlice(*inputs: Iterable, signature=None): def Is(input, other, signature=None): - assert not isinstance(input, IRObject) and not isinstance(other, IRObject) - return input is other + if not isinstance(input, IRObject) and not isinstance(other, IRObject): + return input is other + else: + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, operator.is_, 'is')]) def IsNot(input, other, signature=None): - assert not isinstance(input, IRObject) and not isinstance(other, IRObject) - return input is not other + if not isinstance(input, IRObject) and not isinstance(other, IRObject): + return input is not other + else: + return IRPyFunc(signature, [input, other], [_compute_binary_op(input, other, operator.is_not, 'is_not')]) def ScaledDotProductAttention(query, key, value, attn_mask=None, dropout_p=0.0, @@ -2261,6 +2291,8 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, dropout_p=0.0, key_anno[-1] = next(gener) + '^' query_anno = copy.copy(key_anno) query_anno[-2] = next(gener) + if is_causal or attn_mask is not None: + query_anno[-2] += '^' out_anno = copy.copy(query_anno) out_anno[-1] = value_anno[-1] if attn_mask is not None: diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 1455c3c9..f614af72 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -158,5 +158,6 @@ def convert_model( IRGraph: IRGraph of model """ traced_model = to_fx_graph(model, dummy_input) + _logger.debug(f'the traced model is:\n{traced_model}') graph = to_ir_graph(traced_model, dummy_input, attr_savedir, dynamic_shape) return graph diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 21ef0a92..45884a31 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -1353,12 +1353,16 @@ def types_other_than(pytree, given_types) -> Set[Type]: def tree_to_cuda(pytree): """return a same spec pytree with all the given pytree leaf tensor to cuda""" - return map_trees_with_func(lambda a: a.cuda() if isinstance(a, torch.Tensor) else a, [pytree]) + # any operations under torch.no_grad context will have the result tensor with attribute requires_grad is False, + # here we must follow the original tensor requires_grad attribute when we move tensor to cuda to ensure the correctness of the tensor requires_grad state + return map_trees_with_func(lambda a: a.cuda().requires_grad_(a.requires_grad) if isinstance(a, torch.Tensor) else a, [pytree]) def tree_to_cpu(pytree): """return a same spec pytree with all the given pytree leaf tensor to cpu""" - return map_trees_with_func(lambda a: a.cpu() if isinstance(a, torch.Tensor) else a, [pytree]) + # any operations under torch.no_grad context will have the result tensor with attribute requires_grad is False, + # here we must follow the original tensor requires_grad attribute when we move tensor to cpu to ensure the correctness of the tensor requires_grad state + return map_trees_with_func(lambda a: a.cpu().requires_grad_(a.requires_grad) if isinstance(a, torch.Tensor) else a, [pytree]) def unwrap_nested_proxy(pytree): diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index fad9ae47..94a7553f 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -404,7 +404,7 @@ def tosub(self): return sub_tensor def __repr__(self): - dscp = f'FullTensor(id={self._id}, shape={self.shape})' + dscp = f'FullTensor(id={self._id}, shape={self.shape}, req_grad={self.requires_grad})' return dscp def extra_repr(self) -> str: diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index aa82bf36..d24fd357 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -354,16 +354,18 @@ def test_Unsqueeze(): def test_ScaledDotProductAttention(): op = F.ScaledDotProductAttention(IRTensor([8, 128, 64]), IRTensor([8, 256, 64]), IRTensor([8, 256, 32]), None, 0.05) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a e d^, a b^ d^, a b^ c -> a e c' + op = F.ScaledDotProductAttention(IRTensor([8, 128, 64]), IRTensor([8, 256, 64]), IRTensor([8, 256, 32]), None, 0.05, True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a e^ d^, a b^ d^, a b^ c -> a e^ c' op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([128, 256]), 0.05) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, f c^ -> a b f d' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f^ e^, a b c^ e^, a b c^ d, f^ c^ -> a b f^ d' op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([1, 128, 256]), 0.05) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, 1 f c^ -> a b f d' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f^ e^, a b c^ e^, a b c^ d, 1 f^ c^ -> a b f^ d' op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([1, 8, 128, 256]), 0.05) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, 1 b f c^ -> a b f d' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f^ e^, a b c^ e^, a b c^ d, 1 b f^ c^ -> a b f^ d' op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([1, 1, 256]), 0.05) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, 1 1 c^ -> a b f d' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f^ e^, a b c^ e^, a b c^ d, 1 1 c^ -> a b f^ d' op = F.ScaledDotProductAttention(IRTensor([16, 8, 128, 64]), IRTensor([16, 8, 256, 64]), IRTensor([16, 8, 256, 32]), IRTensor([1, 8, 128, 1]), 0.05) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f e^, a b c^ e^, a b c^ d, 1 b f 1 -> a b f d' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b f^ e^, a b c^ e^, a b c^ d, 1 b f^ 1 -> a b f^ d' @@ -398,6 +400,12 @@ def test_Setitem(): assert op.outputs()[0].value == [set_val, 2, 3] assert not op.outputs()[0].is_constant + op = F.SetItem(IRTensor([3, 4, 5]), IRObject(value=slice(0, 5, 1)), IRObject(value=1.)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, ?, ? -> a^ b^ c^' + + op = F.SetItem(IRTensor([3, 4, 5]), IRTensor([3, 4, 5]), IRObject(value=1.)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, d^ e^ f^, ? -> a^ b^ c^' + def test_Len(): op = F.Len([1, 2, 3], signature='builtins.len') diff --git a/tests/graph/tracer/test_ctxt_manager.py b/tests/graph/tracer/test_ctxt_manager.py new file mode 100644 index 00000000..8e0de855 --- /dev/null +++ b/tests/graph/tracer/test_ctxt_manager.py @@ -0,0 +1,45 @@ +import tempfile +import torch +from nnscaler.graph.parser.converter import convert_model + +from ...utils import replace_all_device_with + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x): + with torch.no_grad(): + y = self.fc(x) + z = self.fc(x) + return y + z + + +@replace_all_device_with('cpu') +def test_requires_grad(): + with tempfile.TemporaryDirectory() as tempdir: + model = SimpleModel() + dummy_input = {'x': torch.rand(10)} + graph = convert_model(model, dummy_input, tempdir) + + node_no_grad_fc, node_fc, node_add = graph.nodes() + # x under no grad context + assert node_no_grad_fc.inputs()[0].parent.requires_grad is False + # fc weight under no grad context + assert node_no_grad_fc.inputs()[1].parent.requires_grad is True + # fc output under no grad context + assert node_no_grad_fc.outputs()[0].parent.requires_grad is False + # x outside no grad context + assert node_fc.inputs()[0].parent.requires_grad is False + # fc weight outside no grad context + assert node_fc.inputs()[1].parent.requires_grad is True + # fc output outside no grad context + assert node_fc.outputs()[0].parent.requires_grad is True + # y + assert node_add.inputs()[0].parent.requires_grad is False + # z + assert node_add.inputs()[1].parent.requires_grad is True + # result + assert node_add.outputs()[0].parent.requires_grad is True From 8525b4b0043eb7c6cc94be5e4c12affa8fe60686 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 14 May 2024 08:49:20 +0000 Subject: [PATCH 1635/1892] Merged PR 2121: Refine recompute logic in autodist Refine the memory estimation formula in **ILP** solver, will update the dynamic solver in future PRs. As a result, the default solver is **ILP** now. The memory constraint in the optimization problem is: $$param + \max(activation, optimizer\_transient) + recompute + fragmentation\_penalty \leq mem\_bound$$ Here: - $param$ includes parameter, buffer, gradient of parameter, resident optimizer states - $activation$ is the size of the tensors stored in the forward pass for backward - $optimizer\_transient$ is the transient memory cost of optimizer, like type casting in the fp16 memory efficient optimizer - $recompute$ is the max size for the saved tensors across all recompute modules - $fragmentation\_penalty$ is a heuristic term that penalize operator partitions that generate large tensors To estimate the $recompute$ term correctly, in this PR - add `recompute_border` field in operator to annotate whether an operator is the *border* of a recompute module - add `omit_recompute_in_idx` field in operator: when applying recompute to a module, torch saves this module's input tensors, and launch another forward pass during backward. We use this flag to avoid counting the shared input tensor multiple times. - update the memory constraint term in the ILP solver --- .../profile}/mi200/comm/intra_16.json | 0 .../profile}/mi200/comm/intra_2.json | 0 .../profile}/mi200/comm/intra_4.json | 0 .../profile}/mi200/comm/intra_8.json | 0 .../solver_interface/partition_constraint.md | 2 +- nnscaler/autodist/apis.py | 5 +- nnscaler/autodist/autodist_config.py | 15 +- nnscaler/autodist/cost_database.py | 13 +- nnscaler/autodist/cube_operator.py | 12 + nnscaler/autodist/model_graph.py | 105 +++++++-- nnscaler/autodist/spmd_solver.py | 214 +++++++++++++++--- tests/autodist/graph/test_recompute.py | 48 ++-- 12 files changed, 334 insertions(+), 80 deletions(-) rename {profile_data => data/profile}/mi200/comm/intra_16.json (100%) rename {profile_data => data/profile}/mi200/comm/intra_2.json (100%) rename {profile_data => data/profile}/mi200/comm/intra_4.json (100%) rename {profile_data => data/profile}/mi200/comm/intra_8.json (100%) diff --git a/profile_data/mi200/comm/intra_16.json b/data/profile/mi200/comm/intra_16.json similarity index 100% rename from profile_data/mi200/comm/intra_16.json rename to data/profile/mi200/comm/intra_16.json diff --git a/profile_data/mi200/comm/intra_2.json b/data/profile/mi200/comm/intra_2.json similarity index 100% rename from profile_data/mi200/comm/intra_2.json rename to data/profile/mi200/comm/intra_2.json diff --git a/profile_data/mi200/comm/intra_4.json b/data/profile/mi200/comm/intra_4.json similarity index 100% rename from profile_data/mi200/comm/intra_4.json rename to data/profile/mi200/comm/intra_4.json diff --git a/profile_data/mi200/comm/intra_8.json b/data/profile/mi200/comm/intra_8.json similarity index 100% rename from profile_data/mi200/comm/intra_8.json rename to data/profile/mi200/comm/intra_8.json diff --git a/docs/autodist/solver_interface/partition_constraint.md b/docs/autodist/solver_interface/partition_constraint.md index 64e4de57..d71fc692 100644 --- a/docs/autodist/solver_interface/partition_constraint.md +++ b/docs/autodist/solver_interface/partition_constraint.md @@ -28,7 +28,7 @@ In autodist, we provide a set of partition constraints to control the distribute In this example, we have four partition constraints for the MoE model in retnet. Each partition constraint has 4 fields: `name`, `parent_module`, `allowed_partition_dims`, and `replica_allowed`. - `name` is the name of the corresponding operator in the model. It equals to the `signature` field in the `IRFwOperation` in cube. Note: signature is the full name of the operator, for example, you should provide `torch.nn.functional.linear` instead of `linear`. -- `parent_module` is the **closest** father module name of the operator. You can provide two partition constraints with a same `name` but different `module` to control the partition of the same operator in different modules. +- `parent_module` is the **closest** father module name of the operator. You can provide two partition constraints with a same `name` but different `module` to control the partition of the same operator in different modules. Similar to `recompute_modules`, Module name can be any suffix of the full module name, e.g., `module1` will match `x.module1`, `y.module1`, `x.module1` will match `x.module1` but not `y.module1`. - `allowed_partition_dims` is a list of allowed partition dimensions of input tensors. Each element in the list is a list of two integers, which are the index of the partitioned tensor among inputs and the partitioned dimension of that tensor. For example, the annotation of `torchscale.component.xmoe.routing.compute_logits` can be `(C 16) E^ C, E^ C M^ -> (C 16) M^`. `allowed_partition_dims = [[0, 0]]` means we only allow to partition the first input tensor along the first dimension, which is `(C, 16)` in this case. An empty list means no partition is allowed, note that in yaml, you should give an empty list explicitly, i.e., `allowed_partition_dims: []`. - `replica_allowed` is a boolean value. If it is `true`, it is allowed to replicate the operator across devices. diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index d0816d31..56bdbc7d 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -68,7 +68,7 @@ def to_gb(size): zero_group_size=zero_group_size, cfg=graph.autodist_config, ) - min_single_dev_mem += graph.recompute_mem + min_single_dev_mem += graph.min_recompute_mem _logger.info( f'estimated minimum memory per device {to_mb(min_single_dev_mem)} MB') mem_constraint = graph.autodist_config.memory_constraint @@ -91,14 +91,13 @@ def calc_parallel_plan(graph: IRGraph, recompute_groups = [ [node.cid for node in group] for group in recompute_groups ] - recompute_mem = autodist_graph.recompute_mem / 1024 / 1024 / 1024 if autodist_config.pipeline: pp_out = calc_optimal_pp_plan(autodist_graph, autodist_config) else: pp_out = calc_optimal_spmd_plan(autodist_graph, autodist_config) pp_out.desc.recompute_groups = recompute_groups - pp_out.stage_mems = [mem + recompute_mem for mem in pp_out.stage_mems] + pp_out.stage_mems = [mem for mem in pp_out.stage_mems] return pp_out diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index f2d9a863..b261de58 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -44,7 +44,7 @@ class AutoDistConfig: - fp16 & bf16 training w/ memory efficient adam w/o inkernal cast: (2 + 2) (fp32 weight + fp32 gradient) - partition_constraints_path (`str`, *optional*, defaults to `''`): The path to the partition constraints file. Details can be found in docs/solver_interface/partition_constraints.md - - profile_dir (`str`, *optional*, defaults to `~/.nnscaler/autodist`): + - profile_dir (`str`, *optional*, defaults to `~/.cache/nnscaler/autodist`): The directory to store the profiling results. - load_plan_path (`str`, *optional*, defaults to `''`): The path to the plan file to load. If specified, the plan will be loaded from the file instead of searching. @@ -65,6 +65,8 @@ class AutoDistConfig: The number of available devices in each node. - recompute_modules (`str`, *optional*, defaults to `''`): The module names to recompute, separated by `,`. For example, `module1,module2`. + Module name can be any suffix of the full module name, e.g., `module1` will match `x.module1`, `y.module1`, + `x.module1` will match `x.module1` but not `y.module1`. - memory_constraint (`float`, *optional*, defaults to `32`): The memory constraint in each device in GB. - memory_granularity (`int`, *optional*, defaults to `1`): @@ -97,7 +99,9 @@ class AutoDistConfig: The maximum unbalance ratio in pipeline parallelism. The higher the ratio, the more unbalance is required, the smaller search space will be explored. - solver (`str`, *optional*, defaults to `'dp'`): - The solver to use in spmd parallelism. Currently only support `'dp'` (dynamic programming) and `'ilp'` (integer linear programming). + The solver to use in spmd parallelism. Currently only support + `'dp'` (dynamic programming) + `'ilp'` (integer linear programming). """ def __init__(self, @@ -130,7 +134,7 @@ def __init__(self, pipeline_nstages=1, max_pipeline_bubble_ratio=0.4, max_pipeline_unbalance_ratio=0.5, - solver='dp', + solver='ilp', **kwargs): self.pc_path = partition_constraints_path self.profile_dir = profile_dir @@ -220,7 +224,10 @@ def _validate_config(self): f'world size {self.world_size} must be divisible by zero num groups {self.zero_ngroups}' ) - if not self.solver in ['dp', 'ilp']: + if not self.solver in [ + 'dp', + 'ilp', + ]: raise ValueError(f'solver {self.solver} must be dp or ilp') def __repr__(self): diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 49cf325f..d34125a3 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Callable, Dict import numpy as np import json import os @@ -53,7 +53,7 @@ def _piecewise_estimator(xs: List[float], ys: List[float], x: float) -> float: raise RuntimeError(f'x={x}, xs={xs}, ys={ys}, should not reach here') import nnscaler -_DEFAULT_COMM_DATA_PATH = Path(nnscaler.__file__).parent.parent / 'profile_data/mi200/comm' +_DEFAULT_COMM_DATA_PATH = Path(nnscaler.__file__).parent.parent / 'data/profile/mi200/comm' class CostDatabase: @@ -159,6 +159,7 @@ def get_mem_and_buffer(self, op_partition, is_train: bool, stage_num: int): node_buffer: the buffer memory consumption of the partition option activation_mem: the activation memory consumption of the partition option opt_transient_mem: the optimizer transient memory consumption of the partition option + input_mem: the input memory consumption of the partition option """ memory_results = self.get_mems(op_partition) activation_mem = memory_results['train'] @@ -209,7 +210,8 @@ def to_mb(x): f'activation mem: {to_mb(activation_mem)} MB, ' + f'optimizer transient mem: {to_mb(opt_transient_mem)} MB') - return node_mem, node_buffer, activation_mem, opt_transient_mem + return node_mem, node_buffer, activation_mem, opt_transient_mem, memory_results[ + 'input'] def query_single_mem(self, obj, memory_type, round=True) -> int: """ @@ -482,6 +484,9 @@ def helper(mems): for mem in mems ] + in_mask = helper(inputs) + for idx in op.omit_recompute_in_idx: + in_mask[idx] = 0 param_mask = helper(param) for idx in op.omit_param_idx: param_mask[idx] = 0 @@ -495,7 +500,7 @@ def helper(mems): # the saved input tensors for backward have been considered in train_m. masks = { - 'input': helper(inputs), + 'input': in_mask, 'param': param_mask, 'train': train_m_mask, 'buffer': buffer_mask, diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py index fe1ad671..24a987e9 100644 --- a/nnscaler/autodist/cube_operator.py +++ b/nnscaler/autodist/cube_operator.py @@ -15,8 +15,10 @@ class CubeOperator: - dim_info: a mapping from dimension name to its position and reduce type - parallelable_dims: a set of dimension names that can be parallelized - recompute: a flag indicating whether the operator will be recomputed + - recompute_start_op: a flag indicating whether the operator consumes tensors outside of a recompute region - has_batch_dim: a flag indicating whether the operator has a batch dimension - since there can be shared tensors in the model, we use the following vars to estimate the memory usage accurately: + - omit_recompute_in_idx: a list of indices of input tensors that should be omitted - omit_train_idx: a list of indices of activation tensors that should be omitted - omit_param_idx: a list of indices of parameter tensors that should be omitted - omit_buffer_idx: a list of indices of buffer tensors that should be omitted @@ -33,7 +35,9 @@ def __init__(self, ir_cell: IRFwOperation): self.dim_info = {} self.parallelable_dims = set() self._recompute = False + self._recompute_start_op = False + self.omit_recompute_in_idx = [] self.omit_train_idx = [] self.omit_param_idx = [] self.omit_buffer_idx = [] @@ -59,6 +63,14 @@ def recompute(self): def recompute(self, value: bool): self._recompute = value + @property + def recompute_start_op(self): + return self._recompute_start_op + + @recompute_start_op.setter + def recompute_start_op(self, value: bool): + self._recompute_start_op = value + def add_producer(self, producer: 'CubeOperator'): self.producers.add(producer) diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 766efc6d..44d4c240 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -460,7 +460,7 @@ def __init__(self, ir_graph: IRGraph, autodist_config: AutoDistConfig): self.scope_tree_root = self.reconstruct_scope_tree() self.scope_leaf_nodes = self.scope_tree_root.select(lambda x: x.is_leaf) - self.recompute_mem, self.recompute_groups = self.init_recompute_nodes() + self.min_recompute_mem, self.recompute_groups = self.init_recompute_nodes() self.operator_list: List[CubeOperator] = [] self._ir_cell2idx: Dict[IRFwOperation, int] = dict() @@ -620,25 +620,30 @@ def init_recompute_nodes(self): if len(recompute_modules) == 0: return 0, [] - def fetch_module(scope_node): + def fetch_module(scope_node: ScopeNode, prefix: List[str]): if scope_node.node is not None: return [] + if scope_node.is_root: + next_prefix = copy.deepcopy(prefix) + else: + next_prefix = prefix + [scope_node.module_type.__name__] + cur_name = '.'.join(next_prefix) for module in recompute_modules: - if module in str(scope_node.module_type): + if module in cur_name: return [scope_node] ret = [] for child in scope_node.children: - ret += fetch_module(child) + ret += fetch_module(child, next_prefix) return ret - modules = fetch_module(self.scope_tree_root) - in_mem, train_mem = 0, 0 + modules = fetch_module(self.scope_tree_root, []) + train_mem = 0 for module in modules: - in_mem += module.in_mem train_mem = max(train_mem, module.train_mem) - recompute_mem = in_mem + train_mem - _logger.info(f'recompute mem {recompute_mem / 1024 / 1024} MB') - self.autodist_config.memory_constraint -= recompute_mem + # calculate the lower bound of memory consumption for recompute + # assume the activation memory is evenly distributed across devices + min_recompute_mem = train_mem / self.autodist_config.ngpus + _logger.info(f'estimated recompute mem {min_recompute_mem / 1024 / 1024} MB') def fetch_nodes(scope_node): if scope_node.node is not None: @@ -651,14 +656,16 @@ def fetch_nodes(scope_node): recompute_groups = [] for module in modules: recompute_groups.append(fetch_nodes(module)) - return recompute_mem, recompute_groups + return min_recompute_mem, recompute_groups def label_ops(self, operator_list: List[CubeOperator]): + # NOTE: complicated input composed of tensors are not considered, like list of tensors # label the tensors that are shared by multiple operators, examples: # 1. the embedding matrix is shared by embedding lookup and the last linear layer # 2. the activation tensor is shared by query, key and value projections in transformer # label the operators that have been set to recompute counted_tensors: Set[IRTensor] = set() + counted_in_tensors: Set[IRTensor] = set() recompute_nodes: Set[IRFwOperation] = set() for group in self.recompute_groups: recompute_nodes.update(group) @@ -679,7 +686,7 @@ def label_ops(self, operator_list: List[CubeOperator]): # deduplicate parameter and buffer tensors # assume the traverse order of input tensors is the same as # the order in profiling - b_idx, w_idx = -1, -1 + in_idx, b_idx, w_idx = -1, -1, -1 for in_tensor in operator.in_tensors: if in_tensor.is_param(): assert not in_tensor.is_buffer() @@ -688,16 +695,60 @@ def label_ops(self, operator_list: List[CubeOperator]): operator.omit_param_idx.append(w_idx) else: counted_tensors.add(in_tensor.tid) - if in_tensor.is_buffer(): + elif in_tensor.is_buffer(): assert not in_tensor.is_param() b_idx += 1 if in_tensor.tid in counted_tensors: operator.omit_buffer_idx.append(b_idx) else: counted_tensors.add(in_tensor.tid) + else: + in_idx += 1 + # avoid an input tensor is counted multiple times + # when it is shared by multiple operators on the + # border of recompute groups. For example, if tensor + # x is consumed by two operators a and b who are on the + # border of a recompute group, x should not be counted twice. + if in_tensor.tid in counted_in_tensors: + operator.omit_recompute_in_idx.append(in_idx) + else: + counted_in_tensors.add(in_tensor.tid) if operator.ir_cell in recompute_nodes: operator.recompute = True - operator.omit_train_idx = list(range(len(train_mem2in_idx))) + + # label border operators for recompute groups + for group in self.recompute_groups: + output_tensors: Set[IRTensor] = set() + for node in group: + for t in node.outputs(): + if isinstance(t, IRTensor): + output_tensors.add(t) + for node in group: + is_border = False + for t in node.inputs(): + if isinstance(t, IRTensor) and not t.is_attr(): + if not t in output_tensors: + is_border = True + break + if is_border: + op = operator_list[self._ir_cell2idx[node]] + op.recompute_start_op = True + train_mem2in_idx = self.cost_database.query_profiled_metrics( + op).train_mem2in_idx + for idx, tensor in enumerate(op.in_tensors): + if tensor.is_attr(): + continue + if tensor in output_tensors: + # avoid count multiple times when the input is another + # border operator's output + op.omit_recompute_in_idx.append(idx) + else: + # avoid count multiple times when the input has been + # saved by the recompute interface + if idx in train_mem2in_idx: + i = train_mem2in_idx.index(idx) + if i not in op.omit_train_idx: + op.omit_train_idx.append(i) def query_mem(self, start: int, end: int) -> Tuple[int, int, int]: ''' @@ -717,9 +768,13 @@ def leaf_handler(idx): op = self.operator_list[idx] if not isinstance(op.ir_cell, IRDimops): return 0, 0, 0 - return db_inst.query_single_mem(op, 'param', round=False), \ - db_inst.query_single_mem(op, 'buffer', round=False), \ - db_inst.query_single_mem(op, 'train', round=False) + param_mem = db_inst.query_single_mem(op, 'param', round=False) + buffer_mem = db_inst.query_single_mem(op, 'buffer', round=False) + # set the activation memory to 0 if the operator is set to recompute. + # the memory is considered in `min_recompute_mem` instead + activation_mem = 0 if op.recompute else db_inst.query_single_mem( + op, 'train', round=False) + return param_mem, buffer_mem, activation_mem def merger(sub_rets): param_mem, buffer_mem, activation_mem = 0, 0, 0 @@ -787,11 +842,25 @@ def init_operators(self): if not op.has_batch_dim: _logger.debug(f'{op.ir_cell} don\'t have batch dim') - self.label_ops(operator_list) if len(operator_list) != len(self.scope_leaf_nodes): raise RuntimeError( f'expect {len(self.scope_leaf_nodes)} operators, got {len(operator_list)}' ) for i, op in enumerate(operator_list): self._ir_cell2idx[op.ir_cell] = i + self.label_ops(operator_list) self.operator_list = operator_list + + self._recompute_group_idxs: List[List[int]] = list() + for recompute_group in self.recompute_groups: + interval = [] + for node in recompute_group: + interval.append(self._ir_cell2idx[node]) + start, end = interval[0], interval[-1] + if end - start + 1 != len(interval): + raise RuntimeError('recompute nodes are not continuous') + self._recompute_group_idxs.append(interval) + + @property + def recompute_group_idxs(self) -> List[List[int]]: + return self._recompute_group_idxs diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index eb89db9d..8f0e5cb9 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -40,6 +40,8 @@ class PartitionCostDesc: # 4. buffer memory (weight does not need gradient) # 5. optimizer resident memory: like 1st and 2nd moment in Adam mem: int + # input memory size, used in recompute + in_mem: int # additional transient mem cost: currently use the maximum tensor size transient_mem: int # sum of the activation tensor size @@ -283,24 +285,23 @@ def is_valid_partition(operator: CubeOperator, p_ids: List[Any], raise RuntimeError( f'operator {operator.op_name} is not a dimops, check the partition constraint' ) - module_info = [(module_path, module_type) - for module_path, module_type in - operator.ir_cell.module_stack.items()] - module_info.reverse() - selected_pc = None - for module_path, module_type in module_info: - for pc in self.pcs[operator.op_name].values(): - if pc.parent_module in str(module_type): - selected_pc = pc - _logger.debug( - f'find partition constraint {pc} for {operator.ir_cell}' - ) - break - if selected_pc is not None: - break + nested_module_type = '.'.join( + [module_type.__name__ for _, module_type in operator.ir_cell.module_stack.items()]) + candidate_pcs: List[Tuple[int, PartitionConstraint]] = [] + for pc in self.pcs[operator.op_name].values(): + name_pos = nested_module_type.rfind(pc.parent_module) + if name_pos == -1: + continue + # use the length of the parent module name to find the closest partition constraint + candidate_pcs.append([len(pc.parent_module), pc]) + candidate_pcs.sort(key=lambda x: -x[0]) + if candidate_pcs: + selected_pc = candidate_pcs[0][1] + else: + selected_pc = None if selected_pc is not None: _logger.debug( - f'find partition constraint {selected_pc} for {operator.ir_cell} {module_info}' + f'find partition constraint {selected_pc} for {operator.ir_cell} {nested_module_type}' ) self.non_used_pcs.discard(selected_pc) for u, v in zip(p_ids, p_nums): @@ -597,9 +598,9 @@ def calc_partition_cost(self, op_idx: int, partition_idx: int): weight_comm_time = 0 if not self.autodist_config.consider_mem: - node_mem, node_buffer, act_mem, opt_transient_mem = 0, 0, 0, 0 + node_mem, node_buffer, act_mem, opt_transient_mem, in_mem = 0, 0, 0, 0, 0 else: - node_mem, node_buffer, act_mem, opt_transient_mem = self.cost_database.get_mem_and_buffer( + node_mem, node_buffer, act_mem, opt_transient_mem, in_mem = self.cost_database.get_mem_and_buffer( tgt_p, self.is_train, self.stage_num) # communication cost induced by partitioning activation tensors of the given op partition @@ -641,6 +642,7 @@ def calc_partition_cost(self, op_idx: int, partition_idx: int): comp_time=micro_batch_num * comp_time, weight_update_time=weight_comm_time, mem=node_mem, + in_mem=in_mem, transient_mem=node_buffer, activation_mem=act_mem, opt_transient_mem=opt_transient_mem, @@ -722,15 +724,54 @@ def gen_min_mem_plan_greedy(self, start: int, plan.append((i, cur_mem.index(min(cur_mem)))) return plan - def satisfy_mem_constraint(self, plan: List[Tuple[int, int]]) -> bool: + def calc_mem_cost(self, plan: List[Tuple[int, int]]) -> int: + ''' + calculate the memory cost of the plan + + Args: + plan (List[Tuple[int, int]]): the plan to be evaluated + + Returns: + int: the memory cost of the plan in bytes + ''' + + def to_mb(size: int) -> int: + return size // 1024 // 1024 + mem, act_mem, opt_transient_mem, transient_mem = 0, 0, 0, [] for op_idx, p_idx in plan: desc = self.partition_info[op_idx][p_idx] - mem += desc.mem - act_mem += desc.activation_mem + if self.graph.operator_list[op_idx].recompute: + if self.graph.operator_list[op_idx].recompute_start_op: + mem += desc.in_mem + mem += desc.mem - desc.activation_mem + else: + mem += desc.mem + act_mem += desc.activation_mem opt_transient_mem += desc.opt_transient_mem transient_mem.append(desc.transient_mem) + _logger.info(f'resident mem: {to_mb(mem)} MB') + _logger.info(f'activation mem: {to_mb(act_mem)} MB') + _logger.info(f'opt transient mem: {to_mb(opt_transient_mem)} MB') cost = mem - act_mem + max(act_mem, opt_transient_mem) + + start, end = plan[0][0], plan[-1][0] + recompute_mem_cost = 0 + for group in self.graph.recompute_group_idxs: + cur_start, cur_end = group[0], group[-1] + # do not consider the recompute cost when it is out of the current stage + if cur_start > end or cur_end < start: + continue + if cur_start >= start and cur_end <= end: + cur_recompute_mem_cost = 0 + for i in range(cur_start, cur_end + 1): + p_cost_desc = self.partition_info[i][plan[i - start][1]] + cur_recompute_mem_cost += p_cost_desc.activation_mem + recompute_mem_cost = max(recompute_mem_cost, + cur_recompute_mem_cost) + _logger.info(f'recompute mem: {to_mb(recompute_mem_cost)} MB') + cost += recompute_mem_cost + # A heuristic that helps to estimate the memory cost accurately. # It is hard to fully reuse large memory blocks in the cached allocator. # - in training, use the maximum 2 transient memory @@ -740,9 +781,51 @@ def satisfy_mem_constraint(self, plan: List[Tuple[int, int]]) -> bool: transient_mem.reverse() if len(transient_mem) == 1 or not self.autodist_config.is_train: cost += transient_mem[0] + _logger.info(f'transient mem: {to_mb(transient_mem[0])} MB') else: cost += transient_mem[0] + transient_mem[1] - return cost <= self.mem_bound + _logger.info( + f'transient mem: {to_mb(transient_mem[0])} MB, {to_mb(transient_mem[1])} MB' + ) + _logger.info(f'total mem cost: {to_mb(cost)} MB') + return cost + + def calc_inner_time_cost(self, plan: List[Tuple[int, int]]) -> float: + ''' + calculate the inner time cost of the plan: computation time + weight update time + + Args: + plan (List[Tuple[int, int]]): the plan to be evaluated + + Returns: + float: the inner time cost of the plan + ''' + cost = 0.0 + for op_idx, p_idx in plan: + desc = self.partition_info[op_idx][p_idx] + cost += desc.comp_time + desc.weight_update_time + return cost + + def calc_intra_time_cost(self, plan: List[Tuple[int, int]]) -> float: + ''' + calculate the intra time cost of the plan: communication time between operators + + Args: + plan (List[Tuple[int, int]]): the plan to be evaluated + + Returns: + float: the intra time cost of the plan + ''' + cost = 0.0 + op_idx2p_idx: Dict[int, int] = dict(plan) + for op_idx, p_idx in plan: + desc = self.partition_info[op_idx][p_idx] + for k, comm_vec in enumerate(desc.comm_time): + producer = self.producers[op_idx][k] + if not producer in op_idx2p_idx: + continue + cost += comm_vec[op_idx2p_idx[producer]] + return cost def build_cut_ops(self): cid2idx = {} @@ -811,14 +894,15 @@ def _solve_by_ilp(self, start: int, end: int) -> SPMDSearchOutput: (range(len(s[i]) * len(s[j])),), cat='Binary')) + # NOTE: comment temporarily, refine it later # 2. set initial value for warm start - plan = self.gen_min_mem_plan_greedy(start, end) - for op_idx, p_idx in plan: - s_idx = op_idx - start - if len(s[s_idx]) == 1: - continue - for i in range(len(s[s_idx])): - s[s_idx][i].setInitialValue(i == p_idx) + # plan = self.gen_min_mem_plan_greedy(start, end) + # for op_idx, p_idx in plan: + # s_idx = op_idx - start + # if len(s[s_idx]) == 1: + # continue + # for i in range(len(s[s_idx])): + # s[s_idx][i].setInitialValue(i == p_idx) # 3. define the objective function prob = LpProblem('SPMD', LpMinimize) @@ -865,25 +949,51 @@ def _solve_by_ilp(self, start: int, end: int) -> SPMDSearchOutput: max_transient = LpVariable('max_transient', lowBound=0) for i in range(start, end + 1): cur_mem = [] + cur_in_mem = [] cur_act_mem = [] + cur_param_mem = [] cur_opt_transient_mem = [] cur_transient_mem = [] for desc in self.partition_info[i]: cur_mem.append(desc.mem) + cur_in_mem.append(desc.in_mem) cur_act_mem.append(desc.activation_mem) + cur_param_mem.append(desc.mem - desc.activation_mem) cur_opt_transient_mem.append(desc.opt_transient_mem) cur_transient_mem.append(desc.transient_mem) - mem += lpDot(s[i - start], cur_mem) - act_mem += lpDot(s[i - start], cur_act_mem) + if not self.graph.operator_list[i].recompute: + mem += lpDot(s[i - start], cur_mem) + act_mem += lpDot(s[i - start], cur_act_mem) + else: + if self.graph.operator_list[i].recompute_start_op: + mem += lpDot(s[i - start], cur_in_mem) + mem += lpDot(s[i - start], cur_param_mem) opt_transient_mem += lpDot(s[i - start], cur_opt_transient_mem) prob += lpDot(s[i - start], cur_transient_mem) <= max_transient + recompute_mem = LpVariable('recompute_mem', lowBound=0) + for group in self.graph.recompute_group_idxs: + cur_start, cur_end = group[0], group[-1] + if cur_start > end or cur_end < start: + continue + if cur_start >= start and cur_end <= end: + cur_group_mem = 0 + for i in range(cur_start, cur_end + 1): + cur_act_mem = [] + for desc in self.partition_info[i]: + cur_act_mem.append(desc.activation_mem) + cur_group_mem += lpDot(s[i - start], cur_act_mem) + prob += cur_group_mem <= recompute_mem + else: + _logger.warning( + f'interval {start} {end} and recompute group {cur_start} {cur_end} overlap' + ) prob += act_mem <= max_act_opt_transient prob += opt_transient_mem <= max_act_opt_transient if self.autodist_config.is_train: transient_coef = 2 else: transient_coef = 1 - prob += mem - act_mem + max_act_opt_transient + transient_coef * max_transient <= self.mem_bound + prob += mem - act_mem + max_act_opt_transient + transient_coef * max_transient + recompute_mem <= self.mem_bound # 4.3. constraint over e offset = 0 @@ -965,18 +1075,18 @@ def get_non_zero_index(binary_vector): offset += 1 plans = [] all_time_cost = objective - inner_time_cost, mem_cost = 0, 0 + inner_time_cost = 0 for i in range(start, end + 1): plans.append((i, s_val[i - start])) p_cost_desc = self.partition_info[i][s_val[i - start]] inner_time_cost += p_cost_desc.comp_time + p_cost_desc.weight_update_time - mem_cost += p_cost_desc.mem + mem_cost = self.calc_mem_cost(plans) return SPMDSearchOutput(self.partition_path2desc(plans), mem_cost / 1024 / 1024 / 1024, all_time_cost, inner_time_cost) def do_ilp(self, intervals: List[Tuple[int, int]], - topk: int) -> List[SPMDSearchOutput]: + topk: int) -> List[List[SPMDSearchOutput]]: if topk != 1: raise RuntimeError('topk != 1 is not supported') ret = [] @@ -990,7 +1100,7 @@ def do_ilp(self, intervals: List[Tuple[int, int]], return ret def do_dp(self, intervals: List[Tuple[int, int]], - topk: int) -> List[SPMDSearchOutput]: + topk: int) -> List[List[SPMDSearchOutput]]: import cppimport.import_hook import nnscaler.autodist.dp_solver as dp_solver @@ -1022,6 +1132,18 @@ def do_dp(self, intervals: List[Tuple[int, int]], def solve(self, intervals: List[Tuple[int, int]], topk: int) -> List[SPMDSearchOutput]: + ''' + generate the optimal partition plan for operators in the interval [start, end] by + integer linear programming (ILP) or dynamic programming (DP). Communication cost + between the node in the interval to its producer outside the interval is not considered. + + Args: + intervals (List[Tuple[int, int]]): the intervals to be solved + topk (int): the number of top-k plans for each interval + + Returns: + List[List[SPMDSearchOutput]]: the top-k partition plans for each interval + ''' if self.autodist_config.solver == 'ilp': return self.do_ilp(intervals, topk) elif self.autodist_config.solver == 'dp': @@ -1032,6 +1154,15 @@ def solve(self, intervals: List[Tuple[int, int]], def partition_path2desc( self, plans: List[Tuple[int, int]]) -> Dict[int, NodePartitionDesc]: + ''' + convert the partition representation: (op_idx, partition_idx) to (op_cid, partition_desc) + + Args: + plans (List[Tuple[int, int]]): the partition plan to be converted + + Returns: + Dict[int, NodePartitionDesc]: the converted partition plan + ''' partitions = [self._op_partitions[u][v] for u, v in plans] partition_descs = {} @@ -1051,6 +1182,17 @@ def partition_path2desc( def calc_optimal_spmd_plan( model_graph: ModelGraph, autodist_config: AutoDistConfig) -> PipelineSearchOutput: + ''' + calculate the optimal sigle-program-multiple-data plan for the input graph, + the returned plan is wrapped in a PipelineSearchOutput object + + Args: + model_graph (ModelGraph): the wrapped input IRGraph + autodist_config (AutoDistConfig): the configuration for AutoDist + + Returns: + PipelineSearchOutput: the optimal plan + ''' spmd_solver = SPMDSolver( graph=model_graph, mesh_desc=autodist_config.mesh_desc, diff --git a/tests/autodist/graph/test_recompute.py b/tests/autodist/graph/test_recompute.py index e34254c1..a75d4ab4 100644 --- a/tests/autodist/graph/test_recompute.py +++ b/tests/autodist/graph/test_recompute.py @@ -3,6 +3,7 @@ import tempfile import torch from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.ir.operator import IRFwOperation from nnscaler.autodist.model_graph import ModelGraph from nnscaler.autodist.autodist_config import AutoDistConfig @@ -33,6 +34,18 @@ def forward(self, x): x = x + residual return x +class Encoder(torch.nn.Module): + + def __init__(self, hidden_dim, ffn_dim, num_layers): + super().__init__() + self.layers = torch.nn.ModuleList( + [Layer(hidden_dim, ffn_dim) for _ in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + class Decoder(torch.nn.Module): @@ -44,7 +57,6 @@ def __init__(self, hidden_dim, ffn_dim, num_layers): def forward(self, x): for layer in self.layers: x = layer(x) - x = x.sum() return x @@ -52,10 +64,14 @@ class Model(torch.nn.Module): def __init__(self, hidden_dim, ffn_dim, num_layers): super().__init__() + self.encoder = Encoder(hidden_dim, ffn_dim, num_layers) self.decoder = Decoder(hidden_dim, ffn_dim, num_layers) def forward(self, x): - return self.decoder.forward(x) + x = self.encoder.forward(x) + x = self.decoder.forward(x) + x = x.sum() + return x @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') @@ -74,16 +90,18 @@ def test_recompute(): attr_savedir=tempdir, dynamic_shape=True) - config = AutoDistConfig(recompute_modules='Layer') + config = AutoDistConfig(recompute_modules='Decoder.Layer') model_graph = ModelGraph(ir_graph, config) model_node = model_graph.scope_tree_root print(model_node) - assert len(model_node.children) == 1 - decoder_node = model_node.children[0] + assert len(model_node.children) == 3 + encoder_node = model_node.children[0] + decoder_node = model_node.children[1] + print(decoder_node) - assert len(decoder_node.children) == num_layers + 1 - for layer_node in decoder_node.children[:-1]: + assert len(decoder_node.children) == num_layers + for layer_node in decoder_node.children: assert len(layer_node.children) == 3 ln_node = layer_node.children[0] assert ln_node.leaf_size == 1 @@ -110,16 +128,18 @@ def test_recompute(): assert layer_node.param_mem == ln_node.param_mem + mlp_node.param_mem + add_node.param_mem assert layer_node.buffer_mem == 0 - assert decoder_node.leaf_size == num_layers * layer_node.leaf_size + 1 + assert decoder_node.leaf_size == num_layers * layer_node.leaf_size assert decoder_node.in_mem == batch_size * hidden_dim * 4 assert decoder_node.train_mem == num_layers * layer_node.train_mem assert decoder_node.param_mem == num_layers * layer_node.param_mem assert decoder_node.buffer_mem == 0 - assert model_node.leaf_size == decoder_node.leaf_size - assert model_node.in_mem == decoder_node.in_mem - assert model_node.train_mem == decoder_node.train_mem - assert model_node.param_mem == decoder_node.param_mem - assert model_node.buffer_mem == decoder_node.buffer_mem + assert model_node.leaf_size == encoder_node.leaf_size + decoder_node.leaf_size + 1 + assert model_node.in_mem == encoder_node.in_mem + assert model_node.train_mem == encoder_node.train_mem + decoder_node.train_mem + assert model_node.param_mem == encoder_node.param_mem + decoder_node.param_mem + assert model_node.buffer_mem == encoder_node.buffer_mem + decoder_node.buffer_mem - assert model_graph.recompute_mem == layer_node.train_mem + num_layers * layer_node.in_mem + assert model_graph.min_recompute_mem == layer_node.train_mem + fnodes = ir_graph.select(ntype=IRFwOperation) + assert model_graph.recompute_groups == [fnodes[5 * (num_layers + i) : 5 * (num_layers + i) + 5] for i in range(num_layers)] From 87209d28dcca691b6859a001473b141c1862a804 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 15 May 2024 06:32:22 +0000 Subject: [PATCH 1636/1892] Merged PR 2139: refine swin example --- examples/alphafold2/README.md | 210 --------- examples/alphafold2/alphafold2.py | 130 ------ examples/alphafold2/images/evoformer.png | Bin 159728 -> 0 bytes examples/alphafold2/model.py | 306 ------------ examples/alphafold2/module.py | 437 ------------------ examples/alphafold2/policy/spmd.py | 274 ----------- examples/megatron_gpt/.gitignore | 3 - examples/megatron_gpt/README.md | 27 -- examples/megatron_gpt/convert.py | 76 --- examples/megatron_gpt/gpt_model.py | 34 -- examples/megatron_gpt/parallel.py | 64 --- examples/megatron_gpt/run.sh | 108 ----- examples/openfold/blocks/attention.py | 360 --------------- examples/openfold/blocks/embedder.py | 230 --------- examples/openfold/blocks/evoformer.py | 177 ------- examples/openfold/blocks/opm.py | 114 ----- examples/openfold/blocks/tmu.py | 98 ---- examples/openfold/blocks/utils.py | 7 - examples/openfold/model.py | 164 ------- examples/openfold/policy/mpmd.py | 313 ------------- examples/openfold/train.py | 124 ----- examples/policies/__init__.py | 3 - examples/policies/alpa/README.md | 26 -- examples/policies/alpa/__init__.py | 240 ---------- examples/policies/alpa/cost_model.py | 227 --------- examples/policies/alpa/estimator.py | 402 ---------------- examples/policies/alpa/inter_op.py | 176 ------- examples/policies/alpa/intra_op.py | 230 --------- examples/policies/alpa/layer_op.py | 42 -- examples/policies/alpa/plan.py | 105 ----- examples/policies/gshard.py | 96 ---- examples/policies/random_spmd.py | 70 --- examples/utils.py | 23 +- .../blocks => vision/swin}/__init__.py | 0 examples/vision/swin/baseline.py | 51 +- examples/vision/swin/blocks/attention.py | 2 +- examples/vision/swin/blocks/patch.py | 3 +- examples/vision/swin/blocks/transformer.py | 4 +- examples/vision/swin/model.py | 6 +- examples/vision/swin/policy/gallery.py | 84 ++-- examples/vision/swin/train.py | 206 ++++++--- nnscaler/algorithm/ops/dimops.py | 3 +- nnscaler/autodist/apis.py | 4 +- nnscaler/autodist/autodist_config.py | 2 +- nnscaler/autodist/cost_database.py | 3 +- nnscaler/autodist/util.py | 11 +- nnscaler/graph/function/dimops.py | 65 ++- nnscaler/graph/parser/register.py | 2 +- nnscaler/policies.py | 13 +- nnscaler/profiler/database.py | 3 +- nnscaler/runtime/function/function.py | 6 +- tests/autodist/spmd_solver/test_follow.py | 6 +- .../comp/torch.Tensor.contiguous.json | 62 +++ .../comp/torch.Tensor.reshape.json | 50 ++ .../comp/torch.Tensor.view.json | 50 ++ .../comp/torch.div.json | 62 +++ .../comp/torch.matmul.json | 211 +++++++++ .../comp/torch.nn.functional.dropout.json | 62 +++ .../comp/torch.nn.functional.linear.json | 182 ++++++++ .../comp/torch.nn.functional.softmax.json | 66 +++ .../comp/torch.sum.json | 50 ++ .../comp/torch.transpose.json | 182 ++++++++ .../comp/_operator.neg.json | 50 ++ .../comp/nnscaler.runtime.function.cat.json | 41 ++ .../nnscaler.runtime.function.fullslice.json | 77 +++ .../comp/torch.add.json | 67 +++ .../comp/torch.mul.json | 67 +++ .../comp/torch.sum.json | 62 +++ .../comp/torch.unsqueeze.json | 38 ++ .../spmd_solver/test_partition_constraint.py | 3 +- .../comp/torch.matmul.json | 192 ++++++++ .../comp/torch.nn.functional.linear.json | 182 ++++++++ .../comp/torch.nn.functional.relu.json | 66 +++ .../comp/torch.nn.functional.softmax.json | 50 ++ .../comp/torch.sum.json | 50 ++ .../comp/torch.transpose.json | 50 ++ tests/graph/function/test_dimops.py | 17 +- 77 files changed, 2303 insertions(+), 5056 deletions(-) delete mode 100644 examples/alphafold2/README.md delete mode 100644 examples/alphafold2/alphafold2.py delete mode 100644 examples/alphafold2/images/evoformer.png delete mode 100644 examples/alphafold2/model.py delete mode 100644 examples/alphafold2/module.py delete mode 100644 examples/alphafold2/policy/spmd.py delete mode 100644 examples/megatron_gpt/.gitignore delete mode 100644 examples/megatron_gpt/README.md delete mode 100644 examples/megatron_gpt/convert.py delete mode 100644 examples/megatron_gpt/gpt_model.py delete mode 100644 examples/megatron_gpt/parallel.py delete mode 100644 examples/megatron_gpt/run.sh delete mode 100644 examples/openfold/blocks/attention.py delete mode 100644 examples/openfold/blocks/embedder.py delete mode 100644 examples/openfold/blocks/evoformer.py delete mode 100644 examples/openfold/blocks/opm.py delete mode 100644 examples/openfold/blocks/tmu.py delete mode 100644 examples/openfold/blocks/utils.py delete mode 100644 examples/openfold/model.py delete mode 100644 examples/openfold/policy/mpmd.py delete mode 100644 examples/openfold/train.py delete mode 100644 examples/policies/__init__.py delete mode 100644 examples/policies/alpa/README.md delete mode 100644 examples/policies/alpa/__init__.py delete mode 100644 examples/policies/alpa/cost_model.py delete mode 100644 examples/policies/alpa/estimator.py delete mode 100644 examples/policies/alpa/inter_op.py delete mode 100644 examples/policies/alpa/intra_op.py delete mode 100644 examples/policies/alpa/layer_op.py delete mode 100644 examples/policies/alpa/plan.py delete mode 100644 examples/policies/gshard.py delete mode 100644 examples/policies/random_spmd.py rename examples/{openfold/blocks => vision/swin}/__init__.py (100%) create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.contiguous.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.reshape.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.view.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.div.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.matmul.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.dropout.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.linear.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.softmax.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.sum.json create mode 100644 tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.transpose.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/_operator.neg.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.cat.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.fullslice.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.add.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.mul.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.sum.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.unsqueeze.json create mode 100644 tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.matmul.json create mode 100644 tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.linear.json create mode 100644 tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.relu.json create mode 100644 tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.softmax.json create mode 100644 tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.sum.json create mode 100644 tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.transpose.json diff --git a/examples/alphafold2/README.md b/examples/alphafold2/README.md deleted file mode 100644 index bb9b4cc2..00000000 --- a/examples/alphafold2/README.md +++ /dev/null @@ -1,210 +0,0 @@ -# Introduction - -Benchmark and analysis different schedule plans of Alphafold2 based on MagicCube. - -# Model - -## Structure - -An evoformer block is composed of 9 sub-modules. - -- Row-wise gated self-attention with pair bias & Column-wise gated self-attention -> customized attention module -- MSA transition -> feed forward network -- Outer product mean -- Triangle update using outgoing edges & Triangle update using incoming edges -- Triangle self-attention around starting nodes & Triangle self-attention around ending node -> customized attention module -- Pair transition -> feed forward network - -

- -

- -## Memory Estimation - -notation -- $s$: multiple sequence alignment (MSA) number -- $r$: residue number -- $c_{m}$: hidden dimension of MSA representation -- $c_{z}$: hidden dimension of pair representation -- $h$: head number. Different modules may differ - -activation -- one Evoformer's output: $s \cdot r \cdot c_{m} + r^{2} \cdot c_{z}$ -- Modules' outputs inside a Evoformer block: $3 \cdot s \cdot r \cdot c_{m} + 6 \cdot s \cdot r^{2} \cdot c_{z}$ - -peak memory -- MSA Row Attention with Bias: $h \cdot s \cdot r^2$, where $h=8$ -- MSA Col Attention: $h \cdot s^2 \cdot r$, where $h=8$ -- MSA Transition: $4 \cdot s \cdot r \cdot c_{m}$ -- Outer Product Mean: $r^2 \cdot c^2$, where $c=32$ -- Triangular Multiplicative Update using Outgoing Edges: $r^2 \cdot c$, where $c=128$ -- Triangular Multiplicative Update using Ingoing Edges: $r^2 \cdot c$, where $c=128$ -- Triangular Gated Self-Attention around Starting Node: $h \cdot r^3$, where $h=4$ -- Triangular Gated Self-Attention around Ending Node: $h \cdot r^3$, where $h=4$ -- Pair Transition: $4 \cdot r^2 \cdot c_{z}$ - -parameter -- less than 1M - -## Challenge - -The core problem is: the evoformer consumes large amount of memory and we need to find the minimal execution time under the accelerator's memory constraint. - -According to the estimation above, we find that the memory distribution of evoformer is different from the classical transformer. Using GPT as an example, batch size is 1 in both blocks. - -| Model | # Parameter | # Activation | # Output | -|:-------------------------|:------------|:-------------|:---------| -| Evoformer (Alphafold2) | < 1 M | 5120 M | 66 M | -| Transformer (GPT-3 6.7 B)| 192 M | 512 M | 8 M | - -Assume the data type is float16 in the following analysis. - -The memory usage of $n$ evoformer blocks is around $5 \cdot n$ GB without checkpoint (recompute). Since there are 48 evoformers in Alphafold2, checkpoint is inevitable during training. According to deepmind's nature paper, they store each evoformer's output tensors (*msa_repr* and *pair_repr*) and recompute all of the activations inside the evoformer when backward. The memory usage of this recompute policy is $2 * (48 * 66 + 5120) / 1024 \approx 16$ GB, which can be fit in accelerators like TPU, V100 and A100. - -However, this checkpoint policy cannot resolve all problems. - -1. If the device's memory is less than 16 GB, can we execute the model successfully and efficiently? In other words, given a random device, can we find the optimal checkpoint plan to minimize the latency? -2. In the *Extra MSA Stack*, $s$ can be a large number (1024 and 5120). As a result, the attention matrix in *Row-wise gated self-attention with pair bias* is very large, $2 * 8 * 5120 * 384 * 384 / 1024 / 1024 = 11520$ MB, which means activations are the bottle neck now. -3. In inference, the setting is different from training. For example, the length of the protein (residue number) can be very large (around 2048). Activations in many sub-modules are extremely large and far beyond the device's memory capacity. For example, the attention matrix in the *Row-wise gated self-attention with pair bias* is about $4 * 4 * 2048^{3} / 1024^3 = 128$ GB (in inference float32 is used). - -## Possible Solution - -To solve this problem, current dynamic programming formulation need to be updated. - -1. Instead of the activation memory size, we need to maintain the *peak memory*: the sum of preserved tensors and maximum intermediate variables. -2. Different from the previous binary choice (recompute or not), there is a much larger space indeed, a list of tuples $(inter\_mem, preserved\_mem, time)$. - - k pass recompute policy: reduce peak memory, increase execution time - - coshard / chunk: split computation with extremly large output size into acceptable ones - -**dynamic formulation** - -$f(i, max(p, r), q + s) = min (f(i, max(p, r), q + s), f(i-1, p, q) + t(i, r, s))$ - -- $f(i, p, q)$: the minimal execution time from 1st to i-th operator when maximum temporary tensor size = $p$ and the sum size of checkpointed tensor = $q$ -- $t(i, r, s)$: the minimal time of plans that schedule i-th operator when max temporary size = $r$ and checkpointed size = $s$. The space spanned by different checkpoint policies and chunk sizes is described in $t$. -- the optimal value in the end: ${min}_{p+q= warm_up: - CudaTimer(enable=True).start('e2e') - train_iter(model, dataloader) - if is_train: - optimizer.step() - optimizer.zero_grad() - if i >= warm_up: - CudaTimer().stop('e2e') - if i > 0 and (i + 1) % 20 == 0: - print_each_rank(f'iter [{i + 1}/{iter_num}]', rank_only=0) - - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num - warm_up, field_name='e2e'))) - CudaTimer().print_all(times=iter_num - warm_up) - print_each_rank('memory consumption: {} MB'.format( - int(torch.cuda.max_memory_allocated() / 1024 / 1024))) - - -def test_main(): - # Training && Evoformer Stack - # initial training - bs, s, r, cm, cz = 1, 128, 256, 256, 128 - # first fine-tuning - # bs, s, r, cm, cz = 1, 512, 256, 256, 128 - # second fine-tuning - # bs, s, r, cm, cz = 1, 512, 384, 256, 128 - # bs, s, r, cm, cz = 1, 512, 512, 256, 128 - - dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 48, False, True, False - policy = spmd.PASDAP - - # Training && Extra Sequence - # initial training - # bs, s, r, cm, cz = 1, 1024, 256, 64, 128 - # second fine-tuning - # bs, s, r, cm, cz = 1, 1024, 384, 64, 128 - # bs, s, r, cm, cz = 1, 5120, 384, 64, 128 - - # dtype, evo_num, use_chunk, is_train, is_extra = torch.float16, 4, True, True, True - # policy = spmd.PASExtraSingle - - # Inference - # bs, s, r, cm, cz = 1, 128, 2048, 256, 128 - # dtype, evo_num, use_chunk, is_train, is_extra = torch.float32, 48, True, False, False - # policy = spmd.PASSingleInference - # policy = spmd.PASDAPInference - - run((bs, s, r, cm, cz), (dtype, evo_num, use_chunk, is_train, is_extra), - policy) - - -if __name__ == '__main__': - nnscaler.init() - test_main() diff --git a/examples/alphafold2/images/evoformer.png b/examples/alphafold2/images/evoformer.png deleted file mode 100644 index c3f30249b54932b7bc12be52d538a45f9b1b6593..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 159728 zcmdq}g;!MV_dbrJfOK~YDBa!N;2=n+lyplEAYD=dg3{d}A`OEyh@^B6Idpe@59;$g zulHyD{(|4VmgAb4zV^QFW5i34JO(N$DjXafhN6OuIvgA_JsjL)Ipin6C#KrI z1;AgAT-D{J;3@{mH-R?@mXa?d;oz!b(QZr-f%hm*3c9Xva1>z=e;&oI9a_M_-RCRH zNNRW*?aU+UXy{TsaVn75+2hrkbl=tV7_Vt+^4Z-ZY-%EaN7Kh|iuz_dUSsI<0$pRt zr&TbypM%=~$Nb3?u>a7pf3WL5I=9Clwa9*&84}I^vzSc<0>c0ADvN|7{6ANHWeUan zzZZq~mJoxU{y&$M{{Q2T**#4=z^ucp{GQ_Q;$^<4nE6m5C1XN&7^pl@DL2UM!`KC* zwwi|36e6fUc6ku#@{P>zW$yK8$ocrjYYvRVz#mL-afjkybW7%u0g;cJJ9`m;1+D&t zAuO`DoG^!!dlSvMrv;C1BDlfM zkB#$t=^DKc#e2NB`#tW61e*0VDBVy-Ddb{$%cQn46O8p&Y{5GL?pg%gk{va#5$*<7 z8SISA&(XEdxo869p5I0ns0%*T47eh$z7~TB4&0frZgpw5K}MxSqQ9qA-LJn~g+Y)t^N`kp49{dI;WQXICxV zXBv8(l56i01tW7bigSA7M~wuVvp_KgWDVYrf6E`=bwMG0e-jZ`qd=Xfo*ITH27Oug zbGRrWXS6w0uD>4%df3>lP^v|%S-R`VUdLoAAbDSSVED5#1-bDtt<#{p$J=2ufi^7} zOMyp{LsbQSE&}-E5!`(9;t{{omZ3KClatgiElUwGyWN3AI)Vj#Ce|$$t2otrws;BZ ztSsp}Xd!Kb4D}ytGWvx0&b!h`4(?6Ww2*zZ`q_cv=Ty9coK!*FZOm(UEDF1`ORT92 z^6f(O#9f$!I1503Nqd%OMyTETIMjLk=fIhJdTkK3Zfjb74M}veJD_)8?7#UChlj%%G{?PJKhpq+;kF=y@`=A zzCGR&HWe~x$z-^n-}b5*@a@K(D_Nll^{OI@L6TI13Ky$y5}e zUSee1F^bmBy~?p%#7WuPil(P#ml;-0TZ9G@xXwJS1R?g|idNaJDV2%6#iP~+T&P-> zxfU$h;+?gRO|PO4A@XnK+ij?v;Fof-fAe;M$YJbn4S+Ruz7&9xfEs>&|7QJ5#dh8M z0`hSv9^Agr$|ZqEmzE`2P=RzR|2Ot*0uSuuk<6nq$y%^l1MZ=FsL)`zG%`(WAvdKl z!^>fF-0RDNb%n|i8hL-Lj%P(VaZc@26vNl&_`^jS!-8}YBIQ$ZQV&37ws1L&qSggO zvMV{6J_!j$pkr=TT^OjWw{VokE52gNss9280ou1f&5)#rlyqlM!zM=VX0vH#xTQ|) zy;OW7lGuUb2rkHW5n|cU5^(?56Diw4}`ng_O)YSvV0pS zb;ESKZC;56A)Hy29d_#|KQM-G2^S5EbCNP0qZC5*FSzPnLG_OG4N`g<;>oH>Iz_Kw zUO(|U(SmdBeJeQUj#OpX7~W9wy+AWkP!xTLo=EJ)=@+f$f%_Vwg2_Qt`9t0@p<7op z2$ucUD6uog1urs=r$4hdzGA_?{IU>?N*_!y#>9)eVH=yx>USw8K#h#Gt}Ovh!tClc z9;@M`J?+Ro9-&xeiKlkRvGVwt3yvhvtGMHDT6K!ZDL!c509i?ha! zjl&RPuKd^`6@;5MZ0)>_bO8T+1)O;w70>f%49V9y=Kvw*fCd=N??1# zZqgJF&Tkt}M0JCI#E_c9>gkV|+5cjxuQa;-Hg<+Yv%fcdVXR%(hxpTY=75V>!e8PQ zvI5i{#y%u}*+^tgR!{0M*DLVqt_Qy9akDDGW^TQP&ay1;E{2~m2WzcpnO-}1lx2<} zK7VD4!)_^W=VLD7RwjW-U!-9pU~c>Fll-1?LZ+l2f0eBYf%$MrR{sbF(&~5bDIk6K zSV}|2YzWA!$2$^z`46|M@AiAbjC2+nkUvD6wbL#=A1ZpXp!bfcZ~P@abBxjU;7S8e zz%{w0);C!rc>*4inb+6M&IOmN&e%p7=q2NUB!VN78DbY@3{tH4I=sc9IB`*;I3f5gK?E`)5f~ zIrX%#M?S}mt)@Neb-fPo_?WF_!>bW9{nGJ=*c|wjBw(Ma-PX$zUy*JG%}k~j@g@Gq z2KG_?*sx&U_4>+{Q&tpYs3^YqfZX-pUE( zr+6zwF&44izLq~%;XkL{O&9`7ATh*>u*3}x+fIENBo}o;AZI$8M07!Zf-*GNWGeBA zXV_{qt8FU|TNUut(r!pmj(-bJala|7);5li{lW6aHjzT-Fx)g;8^OzlSk3tV@pv+jzgw zJ7wdP=i*-d#CU(WvqYg%jau=AC2`}PB^CN$ZV2B7oF|xhQ|b@%z*pp?1R;Z+3Bn9E z+KRQ6HP}C)k$ezR*ast%DR8j;L#%UHU}6(IKi^%nVs@fo7)xM2r9@foe(jgm>rYa@D0M7+q~ zS%!t?m~&w)uUQ3XPQJg}Laz_3l70J9FtKAUPSHj2sYCsVOvjV-s*?}XD_m!r*o|*S zC9ajJ(l}JEeo(^+^Qw=zqNG3l*xXT2(3drF8;UX+NWZ9h`b(T%^Ks&5mvKA&kf$Q^ zQ|Ozr5?Kp2xgNWz5eKf}KUYw{cl8m(^D3jDi(B)h-g+b2biZoL{$VPTOuFh62wAim zo#;B6%^q5=K7YMBVjWk!>)y=abHiR&Cn*;_g}D-V(tL{*Q7%8Gpy!73(|uWDxFe}3 zRee#u6$d1-WdD<41x|BWNavKMo)I?9rJ;(Am9NhysbwhiZSi>_`^kzqsKImCd?&ai zryqRQy(A)J#Z}XF_63X3p35eSnFp09uh0_k;$|x2sqbl(?jks%yu5QOejSP+T=Z%z zC`$KYZN8vDpV2OmCR7y8AaQIF;e$Tg*=7BZi zXB(a)e$*H@CJKwxA>I>X{Sz%R@kFlB;CDV#f;n_9bR!Zh^42$*`il=?Qnq`%F;wrm zfeA_%3VD&rXEK;3n8IQ5xcfax&`7R)wGInmRF`R$T|EsQc3qbe!XcVzZ}IRhK2Yj$H2Lj_wA@%yAO(B2py5b%g)a0vo1$f zBrQu2$Fdz?+TBd-GCo_T|{z78}pLlKZY<5Gj9wje5ql4m&wWE41Uv5j$#VVUj zjT$k0&{475=$2qVZWRz;avIthjwpwBnUq{4*>+z2@#?mm7B;{%=8iwK*ay1|t4Mq4 zrg(7TTQnpvacpVz_L5`g+uf5#M4TmYo}lq6mCMIjwr?=TiuF2nOQYyzP6TC^of%O) zqkT58Of6bR6^d(veAi@=_g+v#5m;S5+ri_$QwD`(zbcfvl`V8WT$cTHlpZysIR1s5 z7o!dy-sw$e<<&+UTn)alfw-N`$H!ZaA$TXTYwKg;OIi1yS?dt#UB&8a()j$#h?wSaTYkx z=wZe^BO#}iW*TFzp>y*a&wM zabZM!6Ug|djz5BShF|%H-}mnBi0skzw%3^l6~+wt`;vcLt5XJJKJ-6ZzTevz%8V`7 zuWSGH3(sSJL000swx2OQO{63z9tgQPYWxLPhu9neuErPDB&^moCAwWnYw!0!SFenB zA75+Uhx>MC3I`F5Fj%XUm(1gz-=0#(!Ev&uWXe}~M?>2zMbCMP-+{7!mbIcvT9lC3 z^{v0aKLT+C zd<(~m8pAER7rz!T&j{964#VbQaUcx)S|fq8flE!VBc1#ZP@!xj%cnS4n9+i4)NlYMYEycAcVf2i7y{x!&s=p8my4>uRT6k2+EDG;|k<9oEA&xkj|1&-6 zPbs*&reUM|G#>HuPPH|(ce?&F9fOd^>!~dr3o^kO($6T9pTP}*D!Wozw{kQx&=qVd zTODt2yNBX%7Ofz68xsKqtb zBX8sn%+40aB^Qa3sB*LOSf6&HSrCE_JmjyxT|qvz)QzB$W{de?$gtfmnyv#B3ef2N zE9!PALW>8HR<>nGnL-xkg3zmccbP4bwPA2Q8aR{Wd`@+0o?4XWs%LT6sq=2~`A=8N zsz2#JEh8u(kfl>)3LDKwOG{68zP)n0C zcJjkV-O^Cd@ZdG-KgZs}!*m-!j}~H`X$Y&1c}ETk*EVZt_l1*;UC(YtCQAH7?dt&a zgKx9Axwib}2d1numlp#A_~~-<1~D-;KQ46@io#MP;a_}oxH!A9Q4o*}Q*ql^x!EpF z(Pd*C+e$2|TCFu*dJ3EOy=biL{FY6j)}$k6-LAbB2cJG~oWlCcaW#p%Boz!x-k9;823O2)`y1G}tWN!h zEZM>^bJtC4W=$@P{3EgUC$chfr356?yUQxU?eTlMC;mZs(lzmeIt?!QrsXrR0QrDS zr>tYuD_8F9e#v~aLlPr+S81{Ur^k;gO$|6HzIi?A=((pWxY>JW9?icLedlaMv_flw zAs~L-6Hlp|tvvr&e=T29i~{EoCB~>G8H!~{594@A^nvFyBdcns_CjA+-hS%Vn0?hN zlX|ew6u~DRlMWQc9AXeMVX()!FBIZJ8tPsaZAyYCQLf%bEpJ&*R$F+`bEd_ZJ@H0< ze`QhSOd%11^f#8Oaw_Fw$Qf@Jed`?NU)DJo-ki-k_NH*=-Z`c%y)JQH>y3xMy*}?b zTzs1A+LCpXsSU&YC+x>K_$U$pd>5VKn=}9q~hSVwP z3yy(f?#F}C5|3~|MRkNG+Bx^Z((^I`ZC0^<&s_v0o3C%P+O4VF-6XuDNv&ralrw=@ zO5om*oGuc(&=A?J}M|5M|C3|%(tNVfW*1bSitSkT{Urf6kuOSV@*h_ z@Jxiq&kfRd!-B^O!Ek|lX70L3MQ>#}wSXGt_jSo*)!z*y*giy!8_WEiNcuR!9+kAtgbGt{oqC@&#G}` z_&kV`UFNww?}st{pf9m}*jT=E(Ao}okG%VDK+Qs!?|RR5b~RZ=+#Ob+VScotrpuRx zeej*A$z!+8nvun1^%n+9MG~ovCe4^W*{JJcTRuqSc7-`6VzR@)@L>uZDZ2pijWX+r zF#+6aHS>*8Jx{@ei~i|C8bfW6=h;uIaS<9C=BNOtcKV%LWk#phNN!)fb#6sY#A5;} z?Io(wKXYtwrb*y?Sq`R20??vTYulT|oSay6a(LKwe|Ii^sq%1}t9V9-?E^phnfzDM z>CLF?i^#e^#C3*OqOI-{o|JU*LT@@nkYCGS(oYp9WH0ZunsC{ZJ5jhYKM&GNptPS0 zEPMxGJ#uqU0|Bm7eqVTZ>3+(b7;H`LgdK*Q#D{d1j_=R-WNA5t=8N`bT+^gUGpZQt zRN*nPF&FjN;oWz9>!WeIBb_6XWV&Q2nj})jQuc1v&2NTn{f8TSQFz_LxG=;tD#j3g z+)yH*5Jy}4!%il@2-cFm`9xbM)t25s5ce`)vb(#~x##5ZlGS~n6Mn+XEy(ZIaSd5K zOhJUrs|m%e@r+}j0Ng(QX)RvE9_cy~4SS9kLI@fb5Ai{J$rw+HqAbuhX}%cyRAhF7 zRNVSo@ARUy_dLAgkX!=-lNEPeXMudTFB@_t=9!Gc?eUkWmjaR^j#W$wB8d(!(X4l) zzHfe*YZb5G&P>LfCgtgN#Oxb!YEffHZ45(9!j(+_f{>;{#S)ldFKv74IlR7ug|*qA zdR{3p_gR+^<)qSV!hQB|3?S@($C4$w2g1tM@fO~ z91#+i1ZP7Lybv+nA_=ET>hFIKYrC`p0s;h_ml2+z6Ti#!znkVzeeLKNyFQS*@*|1) z=#z}h1RgvvAu=G=L>9VC%~0ID7f{tL!;l*?e1D1OlPwn$Y&`9%9)wPD&H)6>qdLdM zkLvN%$^ymXHe(fr`bsLe4=D2K6*@iaKBkGQROCac=5Ri10AKEUhed}FOZpfE0asU$ zQaW*K^qIPj+(I}x=q3G&&^*k|1NpE+okkNhzVfuty*2;9Af58;Zd2=6e`*y0$)Wob znM6=Ys8TXvzR>%wcHZqJk#Z`VJ6z-Y>k4?6w=)rzN9K;h{l2hx{rf4NzL!OpRxIxG z!Kcb}OO@)I*?1f7v$SD^K(c#h4w^~z1z&A^d^}D#FZFa`#VRvxv}xuSz9Uj$$o3bD zS;DeU!uY4h&jR%lFkkSCBrUjAsu;<>)F#el4@D56!2=B+i$_P~CrFi#Qrj0}5;r+v zObDt%^S?tbC3}QHZpu%i6cL6^h1GZ0Ax7 zefN|Zv{5wicY-}bMijSi8ejifchlj>sk$?ef2Nv&y@ec#hwVh#akPChWQbuuK`CH! z5gVH9RGg}_=+84uR~|P0M&AN_KoW!-aT?1ki^D_3NoC9L*gbsu&7^I~|n;>ik&1t-t8 zC~}8(M(F)@E-c{{QW76VWSwh!yPN(O`+m}&6na!%xSN^MjRjNKA3k`tVz9m{f$Eru zQHD+?JxnzjwD26Y=34#5x5kQT?GMShJceuS^??tJ#iGJ{kASk4bq_z|gM1X}!A0Eu zB%YP`1MgzVJzi!wWPQHL1NN&hHqje@g-$XM0XCA0JmGQ08Hc| z_35W+ruU^IcM|F0D3SE`Cwlbkca^CEf66{Gg_Y^w0_9mh+Kg z0Hy%ucgp)Aumhhb0wvOS1samcG~!!=yWK5ug8sVSZ|8|M>G?m9-rZ?PrgJbeVNAVG z4ON6y2IHNGrIQ3Lqny+!s$oQR9--V#F3mKFe0?@XOxcvK6qP5)l(mg(+AZk88q5A~|3Hf;Tq-yfx%WXlEx=X~DxY--n4LzruC1`g{8JL)%e%aEI zkE1!)4g^_BlMK9n1yrn>&mp6;r)3eHpwk!dOGC%bvvIit)Kp+AN2@tmpE# zz8O2E`?-fNZDw8w8vOR2^L!8zR$c>q`rY}k>^1^xuSegUb4s?fYi**RZ|07W-E$parxM6p89_v0J;=9})I zi<29=K`-(il>JiA6S1Exu6c-;t>5RFAp&kckdht6vHZJ`%^nd!H98-ml(Iw~HH`+{ zP%cLfF{T}e#Lg}@!l6R01;OL=&Go)GM)aaidlON^euq$MHgC`6qZNglo;TlZ=UQOV z36vYo+tXnO^nK!@Jp-o{XZ^irw93fex0{Hd<^BsBj-nl5y?5eWmf|A52iQYv zWPVfAn44o~4E)ytCk+;eZTDS<28W^{uf&p_uf3+0A|}7j4L>glMm*0I1O3i8cbrm= zs9g#vEtA1d-1H)%qHDP_G9Picp$OCmIG8VNBBhGoFW}&)GhuvKEeaK5BRViOrZ*o; zp_oi&x}!>QiQYt@F{A05D%EE1V8_z^YTg!4E58c-#zWRs2Of0hQ>qKy#q{0yoFNgs zE3LMn){a@^gbn*NydXTmW;^{QQO^7h$I|qv)u98ski`VF@NADmuQT{TI6n>u1k#(z zU&+(+wAP!L<24)vd`Uy~oF?*f3L1Uz#Rl)-9qx$^hmxZ9_wLo39@HO%tj6D4CYel_H?oNyqicCSE8eu=OQJ zJHVgFW=J}eDV`jQi{smxf<=$5mY6Q)OG?6J9_ObBaCO?V-C24n>faLRAK!I8#aA+Z z#C;bNVUNtPj=0p8acE_9MBk~6ROxMLSs2JSGQ=?HOTgBWpc)0LJkQe4Kra&6F1Z_d zD5*Gz$7g3B*U3Bg%RX^ic88!5tqCCF=83_q#^F;!1|DcEa{}4#wuo`~a@Egn@ zJP54bBk%J))&a%vh6qugv)*zU8Jr%c$7(FV^n+8H!vr>*?ZW(9uRkyGrzg9!+MiRy ziP$Ndo10g)XmY|$1Oga7HUKj>fX-6i9_%l)G|egfQSA4txo^6V+lO_A(R}5wFf6Jc zz&#ebAMUZdu@MHoJW>=t(XX|OxWC=^Kf?JRJ!rPkkdY-WkJj=VB6DP;+7}m%#iV5L z#egxNX?%AynX>o}Fzky3JlhZaz50Jz2|G8D9UoyL~+|9*degm3^hX-IG*!Dx8kC-9i6}Z}$pe(*5A4|zv_nS5S zHAh#PZuJ{QSE9|){EkxXO28B}7U=ZYU60+iCzAZ?iI_ByjUnK8{XaW=f2?f^difGt zDT(R%fXO$9c|AY|b-)tw(ZVjP;V07b3kybiH2-^juhv|%7hX(E%&0+VY-}uGR(wJr z0Qh$N@2{yM30awLB{Bt^bN;Wng+npO=msnjnYBR<)HJA3z)Xv*)}-UfW|&unK?621 zF~k4SV>*W4ap9@g(duYgH^le)K@WfyVox(v*#sg#52spH{>buwwXhokb~N(@F{)^I zaN=uz$4{REdveB_nug)HQ*=yTeB`6u>WZ#pMn*{1c6Gy_!6# zrH(MHWc@$J`SWTw8AmOh*AT5$uGhyZBa;od;j1JtzYEbT4{d1?1sp>2f5orxIp0$} zCd0(U^gLV+c`(t#6pxI@q#c1#cM52p8@kSb(7Nv^911wHG)H;#NOIKe`7Ay$aWMO% zo#T(4+t#yU{wOjpv0LeNBQNgQAVYZnxR{rkS$Zg4So^__501iM(ke?i3Gx9F*(GYX z{&=6vPWS*zZcs2p&#;!-C*MNhy65q;dwWWyIvw`g{;ziV%r)48r9!!qs{HQ_mm(;_ z9GBZx>W-MDCZ5stvm5ziChuwwTQ1`y+Sfe#)Xyqc^5K}km|1IVEBQx%T+W4vIYSFF z;&z2pYMT#-?KlbfAtYAY@n%gsq^xiM)-p}#8vA%ML%Ij&`ej)gsX0qEGNbaBA^pYzPTwAbtp|bznsP zvRFeQ?&+hOTQff{kjh{rhg64=a9mBhdRIhJ+*Ngcsi|=jK~))=Ka;#S!QDVk&lkz`}mufh7~#%32~cnk|x(fF8Z zjb~*rcAP-mfRH|A--FfG^P^JPzd3AvZPet*qs}Id6TFnB4l7G+lH)mWEff(+t`Q zn`42}?zH_7|4B(Z{RWq{yoMMIGQM6QHh&9M@FG=7{7W*cK3$zmCE8?0_WW>*aXLqz zS{F(5^-gEr^M%pQiCoKl0YNK^yUfQ<_api{A}-1^6Ot-)k(zjkMpRW@a)t5N&c}Dm zmlG$f4q*;Tc=Jk=uQZ_EY*ylS?Ca%%Q3;lJNwq{vVuus7bOYiS{%zx__a^Dri^M#; zDJNH8Rvwz#k0jiS@Q#bVANK>=X(5?EoduZ@KPOfcB0mk+bn;dec&|CjrYw%P;Cu0N zT(?XCEbeu^hX-a8>;4c;G(5fur^KgwJ`ShRM`xQq8Cr?*Hv^`>l5oFKaHXISc25M- z^dT7qr>iZK{azb4d#X7^P-eCwnBGXSclE`Ad9`0(p&E;L-e`%X+O_Q>Rpx&QM1G&r z`o>iFaTV_)`bx;&bQIQTE%Ev!NF#)gtjce4(R1ZoQ7qL|@Ad0A9wMA)Z`qF?r@ zCl0=xj0&~N2+Dl>O)(`D9gu5dyP7sl{Pr35rA&@e)o>KY!x#(B0Oi`-rhBvr zjO)F7qZKSy0z~d*pU(P_43lxTWeU=}}8~rw(KT#Eq`MLHnr8$&DiM^d> z^7vx$o>+$VFEY*^K|7FusWI%W5csN-KLAIK;{g_X+k%kto!a+X@7d~fl6924XIvIJ zuEQsHYbhojS`&wIres*!j3B3S36rl_!iF^Q10xjetL!5VPjPt7ubelDyV!D69mOxO zT(o!%MT1Kjf*aqi2rBXNV`N&e)e$q-8GV#O!*6Q7zd%e#z6!o<7|_K3Yqws)wSiRa z&&LA3SKsEOWWfRdY;P(= zqIp_VK}Hc4v;M1ny|#76PgPF|=qB@C8WMS?`7mTkR*m=bQ;j%xiX}{HjRm@Gta1CP z=$|hlb4M7qhLLG2g~M0e98=xdZQ2~oL~0T=4~l;n`se6mq!+8+x43l?pd>u6!ybu^ zdBArryD5%Mq|>P?vo*qP&)W$qg4xfIri}f@_dcEG&6iB57~YMF`^~o_;RMfr);Yt( zFBukmmyYh8Bd$mEO|;Zf>Zg>&uFQ0I-4Pih!;Fw{#PlI+AHa#rkE^am5-kc!8N!oX zFumsz{u4tt=n2M$brHH3pP&6J#wgGMZz9!HS69!OdI9Mt^@*@xzc#IEnJ4)sHn)=r zmG(_0k5`MW`vU5{>Z2nP=S*Ze;{9K2AHRw3boDjWE0yIuy{dQ4q-y&93J?cdpO0Al zelfM8V?GQ$vrChxQe1w|rXQoS4w;!FOX4p;*1kU5nI27(2b!fdLYk`)U?^xV&7@ty za+|eKZ5cafstaKK%iT)WP!#mT3O%64tmfh;N5CodZnKprE~a1LVxQM})_dI(_!W~k zng;5KLT?s_fNQiP4x(Wz-M)S?{koS{bG5fi;&jY(8dlXfl?xaB@{AR@(1buYZ4x7o zs=&u}o_`&blTz!vB4<6AYtj+wf`)Sodw>aQ0Iie!US^8<>N!-Et>1c((6Pxh>^0I+ zWmB~Q0ia^^kU;(N%LRar%ab$N32rng2j}HPM{1GVogdCwpQQK=%{JbLf_WT-+|H;+ zZzvl({qx;Zm(DZvlW!O(MSU!BUt%+NJ2lb$O{qvl`55x}b&($*S(*QJkb0KHBOp#5 zh&W+C4)e2M6`gRs=C&i_AiKq2StGjJ*Y|x*NAdNAw_jR9HyF=m#eFX$z`=g_nPNye z?vwVReho%8fdPlAgjgZO*E#a`7mFL#m=ATVG@h_Kg6bWOep`MafBV zK3Y|@9$bkPzl+lkhsgt^Iv%4hQ@r&g=;~;#-zCaI1bhXWu@pD|YI;$YaD;P(bIWVm zn$5mi40-H$Bi1Ls9AhZ&NS2XXR8QWXbr6vRA30`h4?`QcY5~6F56&pg=jdO<-s6S(;Uw33T)?Paf8~pW=M&ID^oHsPrEz$i>UQJ>rzGO-e|e5-0%-}-v7oKB9sV_N0B(FQ_OuQc zk$Yi*+9Kl8^I}X@3|F`XfLegs@rlq+8dmwECs>d;?!Wxnr}u9g#^4RiXl2%aC(TE6 z`tY?>v`JS^6(zeFvbr-J`u>y^T+-LYbAAbPq{8!;=h4sx1XYpNgjnnlF3c|~I;+p? zxJkGaM1vS)!DSq?`w0i zqZo~0aY!h{){^M;lx!VqNaXX=Sb($CU9>yZg)P~b4L)mpybgT)z{mN<%=B#ghGF4c4dk*!}Z(zR~;n!-|9G`CZR8=mCAIfDrC* z#h(2XZvP@2@&=$O(*DIG)}?c#O>bR=tUNKRcr|233Lnz_j?3cHL!iC&q9#r)h$bV^^F4cdjQtTu(@}v|4PaDA9Wc>)pURIB&bY}6{Bw+)uI}o!&2UGqP z^%@S(Q5VCf=`n8K9H~13-{>?Qmq4rP$QreQZ;U??R)qSkxZTD7XDQMDEcLt{1t^7l zDq?w;zKpFNn`K6gsck7%!RyT_1jwVs(Y zbcm*S0up`Wxl^r`4z!)^leRY|(yh9y-bVj+>|ZGXfAo?+bS#~?(BM~=RHSrNKWY|$ z-1F53&;-MG<0^fsI_$Qy%P6#GmuE=&L7DRFO5pw;HEV#?UN+pm=`a&t_BDR^CbB#k z{kEVV2LJD#3Xz@D_D{6I|>lzTof^M8kLQrd09-4NpWPLlBuYK-{lG=4{BChYY&nIHO)R z_!*n`8WJ&9k6)OS>RiTLluqK?O?e~VWs3D{GrZBXYZ@8$0JC@95Wf#EWM3eobd6ye zbSbsJIC>5@I~S|GIH|PN+M55v=v5W_4RR8jwtZshINojOP=_lo;0jh)%2S3 zFP6~Q0*D*_^Eb6&!2!WUso$zvKQI}s*(IMd$G0#+-OpJ>3H;M&d{VbpEgV_bI`4P* zQF5vS+>eQ6`xYTfUP;cE#rEr81iO14cs z=gui1N&03np7M~CckKGrg(ov_e&}QI+Pd|85yPu;x#F@*+Xx&u*1qScaX&*?HbRW@C0kty5ty+?I)Crbmdt2i=&Ca*6ugE6MVxrKlqSIrx2Ckke2Rs?&WmGH1 z`&c<0^8LHSWr=KSR@Obq6)1A~_+dozXKe$sxBqN}eXS6tB4#(|3M7%Psp5G?kfkCX z&$OR?ICV?+N_$L;r~f*-vhn?n@Ou>L0i=lg#+Nt1@nwfS`Gi}p{| z+VmDD3PhzE$t`1>cI)r>ps~iC7BXia)tQ-*FRQ@C z0|*-{Vxvr^|NRJGt@fY!GEGkX|CFt_0@MZRVjdW}fREZ+>)MMk93t@&%f7f-O!{e* zjq0MlwVQL-fqb5*GhacJUL+COt)H{Gw#y!hdL5YJ#G=2+Zm{U1aOh?c!iv+2Gn*EU zvKP}rC+{PX-Sy5e6jFUXWe)`Tv{Ys$M~AHUyxuu^h%(k7BdEK&{o^?XyN2zIexrNo z#D09`o0C4BH}To~&DRg6{SD3D6;of(X&V~3?P-S5F^u&2q8Bg zrAHw!y|WBx@cxSLdI@#P0qjDtYcherZ-AIr|8p6|@;3!Y$w(0JJ-5#Y-R9}mus?d8 z(?txfzUY`eH(c@8$*cQx? zh|Hi-G>S5f#@kkXZ`7wf=}im@Z$78GQCl54V>ps$eu9^y6XX+W7$VR7=$z~H5i5;zJd@4Wk=3>8#>~c$QQPrP~?}P%VE`L_<%)fjl zt{xXeEap(m)8{#CQN8@^8S}E+w68eVYDJJ%de1*11mqGb?E8t;x~DF0_e6TOoE*tT zaM`P6X^=Jh05za}!h#$$<(3EycD%f~L{rk-jhuV_8T!-SxJ%JPt8@F^zDeGszrW+8 zZy-3Nd`KV-=yd>u|Fgb6oNxV8^`5Z><2fw)K6^27w#_H!{P+XGnpKS`}KVH(JFl%g=@ z#k|qe7>zBqe`tLCSJ6nZf0=}W9Tt7~?Rhn!RWPP~C=*ENE2K5;3Fsl;m!Y{EN8`03 zp?EZV{mF1Q>h&mkyi`RO-c^uRK6V|}mV}${YQ~o~`mxReZMQI1AP>4^GKuzgh${ON zpASYR?3C?SwmylG4i#}bT;CP`b0u_F_&@JA{%3_LN|1odav6_TjboQwv-nDoEA4iU zdIT;5Re@5D&2$ND%2~jD_XL?+ki@p%>``bUQ-F|YT2ich_T2(SbJ|tOsn2Suy_Yh- z5zj+>wR=F1UTmID@hr+7o|^jlx164Ksd{ahc3$KII%P}*0r6ZU7{#ybdw0T0)rI0d zaXwiPv^QEa@85Y9koI%!L9*W~2iEZ~G-ARGyzmp@rv4dy!b~?u8trouiT=E>lz?z zH#AuCS4>}dEj!&+Gjc>Y#Lsu1x@$6$pr%BbpSy zTCFq#OWATZF#+cC_uOuK-VbKA7Jw%F#}g_3k9WWAabb2T*TvHlOShw77V&BY6a1um z>Kof`Cc&7eHY=IhLJ{MWN)a2qwh|-#tKjuYNy~$n7+dFZ)1O5j!~~++pcQ(Y_a%IE zyED=o113fY(3`l*+S6|&%{{{&z`;J_Z1LHLtdC6CE-TWmirgLf`b!6pOm*Ew8Qu4R zMEy-MzW-9JHVltxaz82*!@D2v{US1b_>o=+M3?u=FT%D&Lx@RCt4JeKa&03tGdIH| zrmNQ5=LoJMa>Hc`-jN`N?8h-jh5cg8w!BD38Qn--DV__VJ)NV^+wOk6Q~-D^wzYWAsZ(>O9nK0 zV5NW7NCY7DpXMiG7aA;H>LDLhaJz=mjU0>zp5w@VN9f*&o8romH&5h|hLok*_D!Lv zHlZRcT}X_Juq=de5NQ>~hn;*k4ypLV8`()j%Ix!vCaS0(g;Gb)NQP}@B8JbnDW?_~ z(A4u-qNh=*!-pf6Z!^-~(J>~_*tUEMb7PjBJkE0p%v?mejMSH(r3T!pab@5(+++QDxl@RFOS4qdU!&-kK4fMhAs=id0F8d= zY5uJL>Czq$Be8gbRUI0voqC<6@|a-f*gd4Oxx6DG-(HOaLa{^6oh#SBkh0J77*Fq` zET;U(M?7Cbs)RwGhI$;XQr?UFHJ4`8hM|OCBRs`Q{utxEV^*9-q2GK2bLHtU-nO27 z_o8apWhThK7TpZq5+eAZmCm&2zzIC$A(ix$Y|D)}*aEM|Ol`k_Ayo@a*k0>Im#q@n z`3GVy>a$?rRMZ;j>38zN_>|EKt|NL4>7;K61&(JQ2l1f6DwkXR+OtRRJ5b;YxzeyR z-cZPKr@ih%MWJs#^>{duXE<89%xFV3Q3}NA5^p8-LWyZ>nJgSi>dfe3X`MmQ>`eLsTH9`R?EY!a3lZPzR1E*@cUn<2^L$~g`GG}6 zbCtRR(1FjGy@i^z_UY5v0Z)8eoE$>HD0n%qd&WIoyOB1$iGp*51Z|?-`MtnHTJ92U zo;Go503n9k|yO89}?Z#4I5a$VRkeaNELl+;IIqQQCz~W zoqT~&Q-MskkFU>ru>_6YVX}uV1u^@gbr|_K;f(GMe&6r(GnlV} z*M4uiVRWoL2n+bI~=MgUO zA5`wAdncqwP$t;s)IQlMlW8f?)(AY9d10C}1bBA-zyAnendSj^(z@DFu&Bg3f&FGq zV4J?asfm_>K>^tE#v$PEKoLsZZ^10&w?{3Ht`Okzd8%g~izKY6#Pj*>OTFjNt`ky) z?9rUpg&3Y@JZdBcVgkQ-vAzKZ_C>~}ZG6O+1!|l;iEV2{YnF+pDcQ`UU4l>K;P4Y>e_EeGQuCfT`6Z=4~9IEjqi6=wPpM@yY@u1dQpv~*}x%8rm z05uq$UZ?H>id9a$Lh;J?21}R7-5JYyc86ObtI2^?v={`IJ1M<+?KxdC|WWI5^(6 z=JA%1*BA!8s&t@9<&iD|P~c#V`cvqnv40nEbR>@eZ+$2eD@iG2 zp~a`)Qr56H1pk~$9XMLbqLuff{(fqC>1L!IWj{7nJ{`e*czp}Oai80>8ba$gUJ_Tf zj39#UddFE-`~jcI-W}b5)-l&<9l@ZKK8){$sGaT%so-w4*|47TPdwqR!T>gXeeyIz zERt{Hxvk2*_B85?UafDS!FBTIOUq{z^{ziIAT7>5y7Y;lyg()k3Om6n6k1o??7p$b zNcr~9Rkr&USOuyrh$t`h{PmFlP6v#2x3%7e;LTtGH4`|T|HVO8TYy@>%OVu=KzRN( zDEbCi^uW;qL}2rOs>U|?8oy#OZea2=;|?b0mn>AVyVQx>O>NZ%FFoU+{&F$ zjl@blDJVymgNC8ed$P(`s*o0O^R!W_-fk?E<|m=-EE9b%ax19wg3iFV(!Z3vb|oX{ z{oLu#mCvMZ!=0^G19nn)MJ3#M3!9~iu6*~b5p1Q$i?>t#t=%Lf;{IT(X>XDx?Lfod zfhLLe2i5{XbBMIUrCq=)vD;%69h6d4uZ~SIDLfpgqj$sV!H@{ zVq=d4eVS;*#_ zcBEMW56l5iDZJAw$4r1D6?$95!;pLRO!|pQFC0FRh8+=fV}pKo_GoLIxfy(k%Wl;C zx(Z+U1Zza~e+-((iiq!5?-V#B_rAA0#4zxyj=RRU%jLbYtKN$0AR$GgZ+p;S4A%lP z#aJPR&l(uQyQ)^g9d?iJVb6FJasuT_Wj~;<LVEFP9728Aw0;alwaE zyx(A4?G(7!#%T3f2}3h*$7;&!nZX8o=R`!ul13$*y0wOap$WOe_12Iu2Squz&FRMN zdv6N)?`~~0c9|k@^N|ICTIuAx@(s$f20FrNU%+S`C^!V!PJ!njyS)pfubtHi-8I(5 z@z3-%RrKA)zKP2hh|uM#NeRPHPc_9TqnI1pQ7A(9xqq|$o`)m{KA2_(###RmIC{1) zyb9x1HqpBAQtf`$w$7UvE8|wbmN{JK2P@b7wX(3sUgDE!VB=x*8CtrOQo000qzCC%xi$x}_D6PRXIWbEw}v>OS>;&&yx42hIw2UGa$v`CO+Eadml)l0MqZHiqj~*_3KA z-$nYb(xs;cT`?lVSH;rlg0Jo%<9@Gt%SR?HRbJlbc`tMLy_$sOH^TED2z|0x_Crnv zJXy4ef<9xzP1ycSc)x#E)?Z3xVq&51h71RZi|yAEJ)Lp$`{^6R%4L&MnT-oxKU%OD zt@dCekNJMY;R?3em-=?`aw&|iNBRbO;8l6nINo==jK6_3oVGHN%^py;CsIgb^c3qD6e?XhM6DfMtu<&Yd+WmtdfFm1ko84Li||ICJM1A?RE^o* z>!Aog)(7}3R@2^+HWi+{11l8ZZD^N*BHAtgqjH-^30)qDn0CeH=v3tKjgc6CeS2dt zsp|TS<&AB?Je-vyk5qSADgzvKD zBcVCI-M~z}tFKvKiagm7pq^DbtPulcDVnXVt&5x6{BZ6w&UpD?F9u}bDoc7(kQWm& za#c26EZFNV^24ebNM)1Fp$vIT!9u!YbNyeIpfyK-8L+27bi1gOk8S){u+6Bon@Fpa z`~rx3HqIJ8MLw74v{Q2pZVS|isvR(iHs1fvFVazW`c+>Fmq-&fDf^4%o{z6b^5PC%0ZVY{ z0b_)HdO?nc)kjD;ST#cB9P7zS7z`GLp%2JbbIZ%jbeCQXSPZ38esYdMp)AC7#`78; zl0puvgj3VgrKa7tt){Apii(N`@3V2nvkuTnH4$D)9zq5TrJ9H?We=GKQl*-RFFEe8 zkqnV5z3|r@#b+rP?*i$+Kb6w%8}Vx^aGvaqcN1Y6($I$lTSIw9WIF^sVlQrM2{h8w zCA*oZba$){7jVhoeY{bOHuHJK4~xZRZnK_O>%JFO5t2NarOR0?LdmXWaLv-r5MQEx z{60Z9{PWzZmtp(w%Ws0ICM`>G5rZ@Kco;ArcW1q?DMTX`)+MGirUK<) zLFc6BFM}Kc<=o&msijrl(3* z^y&QFidC@7F4&57$;=DthFt)@XjI7>6w++s95<65M7?!AJDm$}#GMR)@3l|g)mIU% zP(m8K$R3r*ZsOT1WwzsXfWI!U2nK1bGu^A(LoM7!($#S68nK%C_)m%$56dThBK*o3 z19@x+xG|CiMRbl|N6yArd>Ry_Z){*bwanrSbq!7K7giL$T%m-vl@ZLml8z&k7Ub@H zu;wMX&4l*zaKq>R_oFKH4Z>ORYx1_5#Mm6cm)*`cNkE=cVhZwngwe0m_@B7mE)DV> z!uxdW%|dyHqbEo6FzxR3jjeo2Q0;yG0RXe2pU} zCT4DkTYBNA_1?HMf(H7sdLG55@vMr3T{6e^gxJTtDu)w=*6^Pd1s#EtD!gf zquN!DqlFpgoN6l(Tb=$ffw1-iRvWmee;t{p-C@6Y?(dfqwczcGk8uGW%ux^Qg|wd> zo`Kaqk-fmj6vF$$?~IQ}ELAeIP5#c}`|}d}Z$UaDD~Nd_o_diQOj(&x*=DJ95i*QI zVN#DqfRH)qyrxmhF#h0L?WQ)nU4xTEWg6IUbsnX$5oAS*G^#q6KTY22NH(QWz5fn7 zPl~j3y?W1cl~Bw}A=XCdW@BK(QRD(iQIADL%ih)}C3Uh%;l)nE;E*C%iy_nTuZ|mF zy*AgP4l%Vl9AIK+-30#zxE$NsVJ(2PtJ37<8RObtJVNJv!SIyOySr*x%JaC+UWpH` zme^VTlC9T7=L^X)e^YksvXpa8-Q*P50Dp*}((Bk**(bRz>Az^23{#xpW~^3v z;kYEFVt!zJc-yXMuhdra31Bzu)N%hiZ~u2&u9~&TA9B1s&0%eqD(Gmj*q2JDp7%U@ zSrYJw!}-hyKiSSU2ArwwVeG&K>=pXF2Q%gOMOe>A;j7-X?xZXCxzs}r65=`iU7sk% z4+?+2t7u0%^gip}2lL#Y^>)5Legw-1!eMh?ZC+6%l?u`j2F66wz!oalhyK>(`6j zyMAjszL0`5z*|U}m-(d~hSSctlQK|DupNIg)!>yweU8^YVxvywc7I*v=&N(`) z`SDp+RqV#aGY+I&jcj9ZRV>SlzFOE`u83AwzaUwTBsbDw%5L4r+1v6#B~QsRpev$^ zs*ch*;8A7l|Ei8iJm=M*;yp~kU&Wh!Zf3b=1rbh)@LB&P(J9n$TS1O_SSlT^eV-1o z{|Nqc)Oxm{*3%rdp3q41?`jF%<eKDy_6p@eDbo2a5QtB^f_b4h=cq9gv+4uFM*X653AJ z-o^9K)YNhC*?y1bO&^8UWgkK^;ECm+dQ${n%{6iZzjL!0lTs%G9_ zy)_RP1bV4RD=Kx*eU?6udAa*cP}b0v|Kzb9rqJvA4?AxV8ShT}vq?LY3@vvAg&S!? zT)ba%TCk{z$CR=i2OqThL5@u{>@LT`4R3J#V7VX=N|W)ZBgbVSr#p87m3i6$$dMSq zi~Or-f1|E%14s^rTFxGSDs`gvC78rgS6owVvS+bDLlmd>{k4T}Ah%tXq60~b&HpUU z#|NK5pm|+@!&0DKhGEhX)@)vN^#60!wsjvVTmq&qs=;1bCyL%dF~^A zmM&O}rh?tB))x|tNbpvnK45|mMt^AoOVN@PEm9?DDF%nz(bu-Mq&=B67m|&U{rGQ1 z19KnZId&Bgrzf!$tyuY$U};qmfsn%`_rHX6mSl|MK*M)1-JvidLC}qH^eIK8fId~a zgp6#?X1UzS%02i)dEm^Si|(u_BN=c*h^tjfQR*p6LPa z5+0jyPTxDoVhXWm+PIYWX)T%YJ9BDpMsv&GUQ{Rq>D-eK1DGgm#c%i5t5*b**tSSR z__197y97qDpIWih*m}Ulh+PIt@jK+zJW(V?+3^5}Nkt>@m&c-(;P+ZYWucB(@KXeO zlL(rMjAfu7Da(9K%O0p+2v`n3egimEBZU&-k~=zWMpC**hhP6{X9f8rL)O#xG!@s0 z=Cy)8dR~E=MjFY6(Y@lNZdBI9i(Hdfy3^szNN~gtm$m{(kKQoUG}>a$YWSa;E54X3 zFohx2yJ0>Q0Go zS^WQoTW~=EYJCfw&U9svN8a*>;5uUSkMT5DnnOxFTAo`u&iK?}r=cJ03pHXk=EF<9 z+XeMRH?v%N$JX1gxS-qhxPfCkZIX0L_Et3@BKD62Bm9EE|MgeKCp@A_jsji({Dv>B zaar!c>Er&7$kz^lgnE~~Oz}Iy?#EA%)j-rVxU$?yhKVa z7Z=I@fjZi;!AsnZEY@nQ#!JMJdLEn#^1y`wuk+&aGAoipJ@1ZZksrb@C+@$Ktq{;F zlkmru_x%Q3gq~$9#fft3*Ooi&T*#=W3Opon0g!vocfD#w2% zR;ZmWAX3_$x8VIlHsry*H9;syi0v3krU$_H<$Dh8*zwe@*svb5=5Gkb1 z9_ZO{iVPaqJvkK!Z};atL;u za|;9BGK*82@y3E{*g_}+*21>=Itt%H-S}e0iJ6<00G~*B_8}P>4zr&^95t7j>*}O1 z%zXpHh07u`T}P;mX9BZRUmsC+aO>i5<$Q}G5+irGKa(?{dFau6;|CrCTCKo1E*ypm zo$%qM&N({6y(14;t+LS_wcIF9Z=7l?s81DgvPrPyc6Z`(3>r=;2Yl|!Wk}i%c-?^n z@mD`!b=jf+zo4qcLLy2fmqU~Q?GK_{1kwh=#7EW^2?r!gb93|+p(;M7dOado0MXzV z*g>_L5Ui{jS|I9x@svMcsid76xVyctzCo%1uizm5ts}S>9K2tK_L`;&Jcq>Le zIs3WrcEan5KsL_DHyNKafR?ZPY>VRUeI21Y1yG_sycIXZ>!gDIyrJ~=LrFC)hQ*5$ zGyK+yR~9+Ii}+;&|BO3lo8j89%3aRjiEWvDwxIoK-*(d&3DVG3tOge9c)*eT1NneE zsnxV$`l7f}aZxd{7$PPi&wVj3n@q{aG_8&VqxSlH~ zpEqzXMKb&P4BeF)|8Zqc&jC-5F_Z;)*)A}!3%bE_bB`~s<{#rWfy9WmI{+e#(BL5* z&U^WqcZg}zuJ7h*f(8vG>T&V;lKZLD{;RU|#*1%ndB;n>-CQ|NZ5b*xCE{U3p!c<} z#wy`CcFpp-w7uq}Eqty7({3c|TN#tZsGy{v;{4i;+TZ-R!l-8^bqGJK(E2Ax1*tqB z>^vK-Y(M3dA^TM=YYC)je6VKryvz6y@W?$ZCyg)2GeBlI%B2{H65oa4W|4h7j0=kG zYwz3_`2*Y&NdqYZQLDhf4o`1zfT7(Z_&tYJm9>KqQjSDFNL#0AcYK^^L;;;Sy%)3u zeRAb#^XOR_Y&kRgd+9yLqrbQd6G94+2UT@eUpWpIG>;eNeHapE${F3BbH2k@?>6Z% z+(Ik}OS-C1jF7EdjSEzxRYs5p$tK11MT2xIA0%tl9(EPV46B>vtW%DVnM4}zvNQ;` z#TG8YkDuE*TK3v?p4iuJr1NmfTlV6{c@UtOoET}?VdfCWw|{9(t0Ks)a@R0AIHZ7< zw?e$OFI9ESFs&PHpq9r+Z6-~ktNhTP6E;x9@}(wfVSedJ{&UT4u@p?Zj${@~Wp0al zFwOp~Bp2&gZ2_1`H|3p<%+Xxv~u)dzM13%)B$`WbdS&fHgXa(H`1_LpbteB@98 z_&bn6j(`GI$95W-?jS6bGc@M2;WgWDA}Ul`?!lg09Owrz=qU>4V<#OPP6;PtSP9;r z@{^Z_NSJ0z*EFBcu~eOJ(FBQAILWDaq1Z2BJXf@O)L7~ftfSv=_Tg?J_eI_wJsf;o z_&v;^yCE}<@-s;|rqELzzU2BbSks^2&8)Mmjn9H?sEx4Fol+p3R5BrWOwq+RHw&_|V(+DyMR zB=Ux58c%iRxn|%0=^}i)50Dpzmsd}FcmxClC`8E)o-okK4ih|JaT~#V5E5JixjY)U z3QJ2;_NT>pMqB-8P7^9Orsn(iv_=y33jB2&d#^vl4+z*5cpTEDf^m6jbqZx#{IJpn zUc?bAYue3!4~mf`Cpy_?qiFXgf|80_P0KryTgfD?^Yn2&36 zTXfStLunWO-h=qBhRMhmZ;Fp!Yu_W?gse!Ug;7dEoBKQoy9CxdW$HO|hU~~So+kU4 zq6wv&r!30lTwjqWU$<_oMGiwa;JLDs$H?*8zVd8%tNw>r|C7Vs?pp0MHH zIm7ZQB^H=HyFc)sS?|6|sRMh3DhOxUGM-k2YpNZnnwMAIVNvwf98^D!vLePL=n@ zD%P-SxHH>zOfJ^@=VBrJzJ8?#TC+> zCsVGiKj{d)cb?zV(kkj7rC}ZD=btp6*0O)6&yDQyn>Z2n@^c)?`}t`KIwVONfg;Kv zMivAh(EayCf+W!_GG6M)3f#z+?lNRI-b`O2PgcgUuB_azNr7%XCzF1*zDTEm9Ap-6 z3=t`{+`9E#vEbS$vuVp#q_Rs_in_B=NjjP>?Kb;m&2g5@eA>msOMK2&KlYj^1+iY- z6H*?Uu1AhOV|HhD@=KwUp|u${$87=qQKyu-2P3V<*`L1^m2SaWfY>#<_GR0QyjRm@M6}8P59LtYF`cipSC)g6}G`nY?K1?~ju8>rQ|RM)}VLqYs1I zcV#wt6tY*%-kJ6J=Z3vOVVh95LssBE$yNw>A|o6&+00rxWh3f%C680wW06q(}{|qL`o>2*u&HF={kTsQE2AH#l>P-`gSJ{ z%WKt2McG*A=S+@;@-C)OS{*u#M~`n;+?K8@m&Gd`YhZxSz3XnUSM)7=ZL8^l|A6qQ zq7hR2QqHejOtY;N?^Wj`AtP~d%pB^27r_U(dau*zR@r5Tw?**{!}@SQHU&1y4>ahN zk9S(cFqLlx;wb$9JRJ7bkVH?BKFR9MT4K!j;`S^_gMi}|cRjn;?HE)0_XJOJL5`4LQ-bJ zB-|1Cuy))7$`JnnwEX6i7i;pk`N0p2Sj^d+-LCw8TKXq+W3%WroPpbNrV||J@MpDy zIJrPu6xjB0{0L0!i%t>qHgH>TdVT=|KEOt}I|d|sBi{#%d(-+C0t6#`e+-Eic84RT zuiiGuqgFaPSBnm?r|;a6i+b$k3>zn&6u3Zc6tB)UiEYm`aLlfsUj!li-n5Q=t?`fm z{;~d{vfuWi1aYg}*wUO>{mg}A(ktbm%~aPeW6BOZaYA25eLAH8}e_PqN&R=X}w+k+UB#?#JRYPJ0GZ(~N` zs_s+VK$z)7-Do~&7OPNyFd9!GHo9sIpv*pQ>V!GJyAZgY{*26>YGFICW^zuaL_bM) z&+wl|3+>iK@d5%y8yyhjWSz;Z3kzpb+4n{GNs2qn#X}6&avJmis@h9tm`XhlR8J{V zA!;mBC6l=<7<`stHrBN-mK`*w_+2?iIl#&HLzF2c9#UlebX|=1sI3fsb zP|&et(#{%iu}Hv)+vwpdf0`vzZ`=U0Gd3@S{k7SJy|=9!IdV<_U$_Qn(AiJn`-)GH;ja2;D`IC>lP@o{GD1f{jITRv$_#Ey#bn~ADEx`2Aep`&;#paK znz5slq4JIlzoG~D&7#$E9=Ewiet#X?k6$a{iQ;jXL)Tw9L^XEilJ?s@`JLX>^?iO*=*7y7jYt7N%HW=Mu^BY$0ehZxdU01pbn+2} z*v#)Rl!)*+NUoX^UoXi1t2KGzwccu}R)$XJ>nK+nyUOONF;!*+=XIS{`Mpu&@TV16gC=z?k`x-{|ngA86LF&b>+)v7=w^9ONbN5Al%-9M_^n|KXVbTqo* zc9Ho`P#w=o3^KT#>;0si&`@mLzQQ@6X~O#NVmJDJ`y|Ki?rDpq?b<{HI!7s(24L&- zqd6z5yvSjgZ5$ z|2_c>ehHey2dThMIYw6k&~vbW-IFW;nzQ)rA}JrTOB^VP^}tsPrxR55{2(Dl5hoP& z`Wl#PBYP@NnL>@^y#jAdfrV*uHMK^vsIT*bE5C;1+c({617;!{p)8c8iK31QQ@2=v zYEy14^}cD>H?A(5=~{oinm6*k4k(8o%eIRWgrATBbXX9LuhxUydCp;1pvuYUcxUgE zk~Urv_o~>n0<>A>n!1&1ELUr=>gG>%*0+Z5yU?D+kYv^Iw1noT`L6%Cpagm5LUu>l zUhnIYo78kZ7AU}8lgzz0pfdD#8~D_%LB#bh%|}Jer4OlKOC=xX1?(CMNR@UQcgMjt zj!li2{c9TuSXR&Lb?GkB$|Tdpl3|7x2kps3-t>*0;(Ud%v_E4jN~O_>4m&TL@E%jR zKO7uO3SF5;b?i?7&F-$(8b^q^dzhx^ZnJAPbB=DK=(6?hsQ0#ihV>VTe5KIaLj)hK zPhjLl51&aE{s?B+Xs_w$_7pqgB1b)qbrMA&6_?%fpeD6nK7$bH7LR#Rm?d7NagEM@^^Sp7-t za4nCf|8BdQvB^kiZt?N2`#2)rrRf_r0#H{8l(HKeoNGr+t0)!B<4J#UMQF>2kC3)UqzZZ_?~3 zhEl(@Gn`zV=IjtnkyCNdc`JDCIN!ZA&zLdiusBOHokZ>fdzR{5X^Fxs5B2+e$&Oza zuTYJJ00^MiqC!(zEjP2L>~kQ)++CdzI>Vhby-3JVbnt`BoLhCC{wVTQB zlkHR_8G|85=C77h;gw~4>*h$&?%*|?zi0uii#}K(^nJth?&wA3ZBENTNXkX_h0;DS z&T#+PkM2;goTc|A_7}ga);U?%-1RUPeI|Te zq+i#?)#Wt~YoVcYaq7hmKPJO=DU8nTR< z$j2Cevl?fV_ILUICH=|BvDW$2$mlFE0^Q}oRCLW`9d63RS~x}WYM(7@am418+V zDybgOMU@)*ujSG2*JC7=PBJ>^&ciJz|^ji3;ZOj7~nYx{Cv=Wpy-%LuIuLUcXR!%JtYLlii>JEn%Hp(17=_Eie!eA1~+M6#8yE<9(N|!nz|;pcX|9 z_~Hc0UEbp;`PuJ=mZo9XGm?K2M>}0T92P!{;ag&thY%khpT&4v$n~{ak)FUE`S<7q zbnH#4i-8{@>L3(MPepu}-bz!NB@`N~l)?|c?OIn9tr1Vw3BW-RGIB%Wz}%6ye^CoL zJ#qW$a|#6=MZ&UIKi=DLaJN)l+(T34h0`BDr|=1hu*f{UOEs8V!TQs-%!oc<8pPg} z^W}HG$&}{Z1Reg}*l2rjoDLEPzhptupFE;IpSn5jEs}-vxnyq>(7gQYh z$l|L6GW@EZzpcMeYHdu>x>m^d!v;(HCjp3E?~FLzoEWeCDg1mc_&`#TDXt42t2?3< zGV7V3K6N9{&$M*2WY4wv$f=LrDn%46(8(o0;MMn#2v`f|HP7i9>M{2q;Mp-&q&!BXjRtMTc$H4Rw(k0Fi-bl!1x6jrxDI8tqA zQf@m-%N~b~B2e^&K%?(~Z;=2K${T35L2$_e3POk*?%w3BUg9Rk-21 zq{fcU^r`|HB`shrm9_&~nTEeP%NbnYm7+?av45O@ot3M!M$*3o#21%5*7M^B+7;0< zNkoq9%l@A8*jjdNt4_-Z^vH?VaT0eud1mRp9Qo?mjyg(-7GABqyTt-1owL57Ole(2 z?SeIP9|v4_JPBTShV!Xd-p?3&!xM8};vg*xF2*mUD&DG|{x?|-6v`K#-$@nqUYT6w z-5Oqx8+VpH^O{#c{p>{qlZN80HKDN`sVisKKfxpYThI@LAK>gM9;D6ht#rlV*e~}) ziyBnJ(mZ4qds4*=YcGL4`BJ=!m=hF1{kI2PMM!nGBH~2X@WTF?sG%W9y^-_q*M1zM zz^?f#ga+k-PyWz?WWGI=#nYhJK;$e;#dfA%9N34V)KBI(m?nTX_!Z>lnUTQfietJH zSbb&nH7y5%1ZbnZo?VV@=7-KJa@r|uYe2|d<(Tu~h_7Oc*bL?yeE5j_d{JPaq|;&Q zJ;o4&Rc^B0_0xw8>q;_C&(5Q(976odv0-K~4B0`UDt?dj?eynqdpP_pvm-=Q34Tlc~za%4I!+w7Q^O)KbE* zQ0nxPgp;cOc12l7ph7Y+9?N!MD)eNS9k=hpF0%d`jM;$HmEd2CycO>VE*lx&ZYfi&y9Mu)Ta`evp3h9Vg@j zqRB4Ug4F_?>eQOpCM|q{!Q1bzmm;MP;k)zgIOd_}z+2vAp(FfTib&m|aN8&@SofGd z>q_SCrvBuoX4FzPcP7C0h1^!+c2BcUr3}OxTWpPz4?xEYGzaTf5k48%24V z0gj0X1H$itS?@Ql=rQB3+^sJsXCTgtqpL(KEm*WDEhl#3;Yf)?gHVyptbd@L%MyCp z$k*v_>W(#D!crc61qrvm@HO4ikQ&n<8K6*d*{Q(xEL-kjp35# z4ZsXN58utb5SL>+Ug6T3C}+JdkQeTyIz%u3*?1M&)1GO2PNNlorT&H?ZgllSt;?E9 zWiF&XA~40Q5O8xPRQj)@ikB7m*6N4~4j(1nOcTMCcQ0tDD6Tv& z(E7#Tzla5mTvn15sn$Y3F(RDg4!Ut4_^5(R5%{Oc9IcP+%{8O34Wn$|&u+oCem)d- zJCwV^9i+D!$s;Y-;mtJ;eEI#od@zFvb&zV#B=&i}RtYL!Qv@j1%eVv=!n?;G08^hE z4znW3Ti<+P_!CPxye~Wr6}!=iA8j>60ADl20ol;v=R)Gh2boQw-+!)BwhmA#mB7wi zsarV2o*n`o?${ZIUU*d4?VT15Cw2t%cwxXo=kmrNrkpJm?9rh<#YNDZ)R4a1>VKI5PDUI;4_#Hi&IhZTef`u)|E9=J;5sM7+Hrl!v|vvUp#te zy=|M*KVA-GXyTNw_ERPYgtyISbmDeO>+frYu^qfJ9*PUC_Lv}j(CH*IiT`w<($>3Y z%AG_ryf?h}3vlK~w;tWB4hga5Pdv2S6O#BWM#f-Yb64#XM`i;|gsex;m9o#kBWmX# z(ktt_u_c)N@k*75VVSeN#yEm49Y`eu0smLNO7R0X&5Bfp8XSUGWvmWG0H@-?PlU;$ zjM%q1S#MI%ih2YAawY`jj2h{8ce@Tun+u48w5hQy`^Ka3Vz<8|w*jQbmN@UL2T%Bd zvbr*K%l#q3{6+@;z#pke!-<`Xnug{R5#6MDbmQ=27gx$;cZa*ovGTT1&1`k~F|C*a zdJL|6I@q(o2BeH;!3*!|>ZRVXDSVf6@q(OH5k%3tcA~V9rxa;S1i0jxdLG~kB8jc=%qr|C#XblqnOm3fGo#x2jC3Y z6YShENO@4)uJXNU8XhBXo$+0%k==!!5Y|P^bi|s94vs<@R&M7uQHJXEG0}f@suVDv zGRp=|G;vDlieaD@-F!SNq4WB{v4wr?`1GOO?#aH(2n(X+Aar-UAOGh2o0tY8g0gpl zvi%i2b~ISpZJGGz-aCkq*;%}?TOFntc&+D5qI~x_a;RC=yXerKf3%^%606};FH^Lhu6bgkMw#&xGj?&Q?k#lH_Jy~%Aa zT5J$9e?lCIHn^wsw69Tlu z>DvDD4cuSp<59o0_qLs%n27t{Shg}9=o7Yr#siSr2G-TpdE6-ghM;mQ_!RfiffIH6 z2Ogsjs>_eVB<$!xu`B>sEjQ`zn~!*n8`L~RN@+q7SJnG+d1_3O6Yrp3W}jKOZ#NPu z-F+ru{ee(jb`uAm=5)I`L?QRnZHN+0cJoOT*@ZB{S9=M?i)cMO$Nhb+%1-`MucsPH zyDj=vc)L0Sq*)cn%JgZwo)J<}p~%GT4Z0tsBQT!?t94Kd4U2gJ0unGpZ2NIuyqICg z!tC5GwQY+xCRbHBB0q!H%NJ%B#{>lfwoIbqh=fk^rxSFSmC0y zY9DeUMIU@NSwm!0JK16-US$Tbn;SXklu?Tk^e9%JX1k1OaR{Kht;4w;Pkv&8)7k^S)TQ7waMCASYS#{0^*dl zS$rD7Z?Y|K?CoQEb9E}%QJCdQpYL^FG*An(qS_s$-0%{R|CmPw4fg>dY}Vs#@hiRc8exOk_Eb>ZKeuTnXJ(V98(8udQlP zmwb#En2*OyscKGM=ou+Qf}I^yLTpKOKi`-Snzu0Wm2a^l#yF?^tbc65N7eoCIrvxIm8ZmiirtpGl<=S?H75OxZS7O# z3rk#K_?M5nw1E{Ki?An>1UTL2sAXXTqhEh8S5L8z$*9H8Wq&c)Ln8;YnLc@*ACK3& zmQ4}IYdvL9P7zkSiJowEEeu|1m0z#4K2#hxVtpiE5ltyeOb)az_#C z%O{xwQLR!JCL8yJ)D^rPrJ7Ibccr%&@)1kx7CPK$Dk5pvwLT4Y6SCf3Qmt5`lcn~< zw|eM|ba1kZJ6ogd3!cwXDM2?K3Dy*Oq`ot6z|K0n-)x+?A9Wb&e3U$_5Fbjn?Q^#W zZ7eFJ$B~jQ$XA%WlZ(HYDY9MuURkkS|u2u5U<-9@lV_wfl0(PrkzI-B)oZU z6K(*S8^F2_y&j(U$qMNbbxbn$p^Jgy4QXM>J0&Hw`-9=3mQkm^~Pky3S*mA)}>2){zhO z83s%ncPCmg!=pYizMA3{`dpl^sn2Vzg7%#A>VX=|sKV^IdPHtC^T#WKZvaz4&#CMg z+Y|-EmAQL(vc<+O&Kqu@1#aD|vmV}X2jJX?6uEpil%4R7s7Jgmm)RnGZs!1JEcM@0 z5*WE31w#CQ!W%JFV?THi9R=7+e6}-ssE!Iil?!WC3UwaVw1X`E&$TdxHp82mn#vm_ zo4hU;#)^FPrEu5}ed##iy*_C%o9cqq>$ zc5`xWPv5rgS`!{wuEY7gD zX|!{vEC#5?3wQ8XI44}h*CC<&dQrNZiWgyo9IQdkxrWupPe9fgd3A$z^hJ1DpEI4k ziYT$C)m``}1$DAex&+$uNi$sb&m&*C4FpMSXKper_ciBT3L)(J0PR(X6;S-WZ^_WZ zVO#sRtF5#H;+QL_lll^7AFUsO$C*fmRj=)wqCktnCG##vL~3L89T{gX)? znuKg2(Cwcl@`(rFQe<*bGAdl88iNrf(RbvDFV1(G-sH3O27{M{h0+Fqj6NuJWhRjH zq>4&gec1r5THsqp`2Mb90GN`|U5G$Mii!Eh!b`Y2=8`4W0}@LM%fF;!c3kk?-8G?? zY)x(V%TZYQhl%m~W{$9&udCzTg_5dVAF-YGx?9ewd2I05B3$zPfCoy+ zcUQHu&E`_=p;Egu{Cbj9Sw1+SFX;sA*Bpnm!}Op)^^5I={h9|*?~S{R$o!EHZy&Q= zM3v<|FsQ#`t$Cu0<&~kSVa1BO#uwEycuR#}B%VKX$H6;X?MbRR%H_ zkyxdsO2BXxWoLZqeSey&SEucmjp0%RfEuM^F49NM+FwfCMV$TSyI`9G+bpc*Zt%hLXCZ3*GuBk1Itr za;+x5NV|z{ZxdaMeFHhUVwHBI2LPUKqfYxHN~O~;_UukiXnjnQ^hqO=Vq1WuhP@Y5 zpYiy{^o9C$8a8LOsMrmvR3LE&Qf6EXU@Sv)q)eC=F%mW5 z!%P(w?g>Ea>2G#0N=eJ9#ks2(!czZ*@ z{U}$SJUjncn|Bc$ddY*DW^t1}Jp|#`zb{mOV|ysO)pLR9l#Tj21#G3nhIn0cMa!U^ zmGdqZW*yMT_TH$_kg@r4ozsVwd(AZshB}ne@&-sE>sZJ1#}=7pAB{|q0k7NE^T_D) znon(jK4?~wsjZOnoNeI6G)2a7ceTeGdWAb0%}SboXK))1SG6a@e6%Q70EIVmPNB`y zJ9YNUPQdW>adOy<`kOzcGpkb>DAo!#$z$?`yw3Ke-44JY(~+h4-Ecr|cf0yEZ6FrJ z4Q$dlSj$IsEul|TpMTi5n`JE^k*}bLe!haO8R8-4AMGp|$0rMlfHHPF%bn|yh+!_t zGLX)qFe+LYXfW=9E7=WTUEqiT-b%&E$JI48-wTOa-;w(xXWqFuvb?(f?5}^ZujQ?T zsMbT^w>|5=S65OUIa_)aqQ$*rh|K_;SsQHR@J!@tF3f&sjQ~(8u%$5BTT2`MCr5i>~Qv?6aVq`8|`S7ESbbE zh~1v*hf$0{XRP;Z{ zHFO_=@ZfWcIY^yS>P3f6ubwM=gQak0-%5VXeYH!+nZ&ZetuLB7%!T*%MHcQ#6=1W# zwR}k%++NHd{&2}~5Dq!Aa8-Oc7WAnvWbRYsP1Kuw9yy~hi`L1#8>`DWf6ijz52mOr zR0UWTAni@S(3pTLZUyvvjHux?YfiwATVh4^+^+5X%DL%a8k_?%yO(NMZDZG)4m}1$ z)7i?c^-B=mFNRMSl}YJ>W1wG^4j>j!^#Vkk=aryFAW6++6(w+{Ym7$yTjw5N;>V## zl~?4qOHt)NudCQ2(3H`zuH(G))|}9wDuG?P3cC{x@Uv?x%{F?bY2@3TZ0pzDEdUH% z#dqj?QJZGIH@RHf0 zTOO9orJm@wi7oqn7V=9=4Pa$-qo3tKj=jIT8>K5D;oIiig`G#S|9tU~ndc~Y5#4YTo8=Ux;h6uB5ds07GMDGww_C74x+Agv9tQn-qjz@iMtV3Lgp>#v9WdE$ef|I!|#a2Uk{U`Id#W_-1e7V7dk zZEBgbK}ywp$Q{6-sg~{!vsLv9!x;01(6L5FjdDvLvmv*+=^&(*LB=hNM;3kr3@cBp z{C%H#cSa4K4V3BG^;l1mC?BmTbS}O-H9bM z^rMiQaVarne06B!n^LIz0=6TnRrv8Bp6AUnlFM2?f@#@;af% zE2jifkx63Ckwrd)QB6WJ1RqI^VN~3!wl}LTG{ONojZkja0c616&0(k@!|?B(FHr(;umFwCH~9;_m3DU;zU{2oYtVbD4Pf z!@}FTxsD*3s?SFa$RkkWyuxuxrN35biNOhvYCSv&oHp&=Z6HK;y~qc_RAbq(7p;3t z)i0eIxzvHcylT5C-k@p3Msq1XC`2c<$?hcX&{0||?4R(*2rxnLdo5Rj@m%tg=Q<`u zE5J`+yw158u-yu@AFR285mUGTaQO-b-}TozTCU+uri%cBvJl(tt3=ix!4bCDRj)`( z3k_vGQU-2T)oFh9!}c$&b{^}uZxmoz_emP6ebN3&{M7vfTR33|Z%9`|GEz{Tce7Xt zTJHAIsOUj;{9H@5l9s?L;j76e47$^iF)PLgB_F6)@nD&Zj&hH&tAKJwJ^IQ;l4VCY?g*L+^j#I-#imcvM+d< z^1y4~E4s7yMj@q+H&s;0!*eWpj;_AuRJGg_CyYYtvgME{_)n>y~0lR49gPEs?G&DjcoM zlWJD%A@OfR@nXpcJ0J0AKTKQ{wBy@GCr+N z_PC(4L3?+NKDX_TYAWi;y1~}_u!-Dy-b%AEo{w8W6Z~rT$V8mMw~=T-=nNbr{l8!Oz>_1wON%$H+3Q z?fsKmv*PiS{i^ZM%zMP8EAJn{M8mt_luFD^K+wesyCE|iL#@%QL0cWiW`=BO-o5kJ zc17c%h5Mq0Nn@KYUy+Ah7*}kD*}KE)xx)j(_0>?oho&GQ6~pPZ+7z=LYBhBnMxAWR z)o-`{QyPl%Wx^uKPaBt)D4QC*IQL1pYHzrUv_nOIC`gu0v(+il zSe*HLfXSKmvHZh|EgUd}`oFKdqt^h*vJtRWVRih+?|$U+-t@<3fXa(LY86oU*NYQi z*_Z=$)d?p97n@sN4qD19-Z~0YwwW-$CV|e<$a!1MVjt^m@19NBtm+$lzd73=saxL; zx+a65IJt-`5s#x&H-%W=C3rXY3-2mtzCPsCtITpdJhhT21QE)prpUuM>S53Eem-xV zC7TbPw#CZ`pMCrE3d*{~o-qQr9V!0o%>8ik*2P#`GH*^|2mM2*vxbv4zAAby- zRfZVY19oRPf@oYZ8n5Wp%77=1{rKYIB`Ym0tr0*W3k*bD=G-df@eu%`))^bN$9`Iz z0EQ3Uf)%$Zcx_$jvMT3eNoey6bjT7vX# zVee>R2C&_3RpudYR>rOHTjd$onOQr`jgZdHTwgKB7+W%rWB)3&KK#E9+e9XPu0&Xz_EU;x{r`d1AWMpl5)T z=G=D2#~$@4Nm11vChq&MExtGnBrji&U){vzG{5vTKI)4|%5*Wi^~br@K~z&Kznd=- z6z;H~G%=c%jB|4vgS1h6e}T$lx%$4fog}tu{t~S?#o@YgW`t<2 z{o!UqFO@^U!mzKjQ^!>2v(@x3D-0`X-9=X#JYM*2c~0REZrDh=*{4etzp1BD)39B) zSg1&v^IOB~<2;kia#|?BG(DxQsB3ZR4uLx(4FfwO!+{J05Igm2*6q@n`u=JTz7~KT zx%C$Q5d{*b`%}s8$8L0gp3>Qt2WdG%wUTz?QR2AUAS_9DI_A^Yj|7KzaTvVIviG+e zoB$pkZPqv_|8XOKBG7CzKrLDLVyC?0!4*Y;DRrBOiMjZ7l%kHNIfGvEpB#E7-Nf|j zwc54&aJ{$e9Ip>CBUXJ55(h5P0gV6M#jxDr9fN=2V^e

|-8%u%ST#=Qs)V$zNqv zzB7Nsy^DR;=%2Lws~gj*4b!;KB%O_dG~R%cund&U*_z1cqVQpKIpBk!hoIAwQS-<~ zSHp%Go!y_)&AZ{*o=_6Tk+}xfi z;=G~jd~$f`v%JMF1pe6L$CI-)+>~^z;RB1Js)ninr(DFxef3645AdmJlAFN$LsW(X$ayh)A&S<=D#*n_HbP$ z1r(!Qy!-OXp@l98Lx9z!-Ji8GBBL@sTht2=pGHhWj!8{xM)n+uB8WxuE6g8nTT{lj zT+h=nd{P^It@6?Yq%(-QhKQnue)Z@$Ze*A6r{0LptTD)!dWrjH-sq3yD7*mnbyx={ z$I44KTiwpAuS!a|3X1xGXAnMI_v2OmKumL1_}avBd69j(lj%E|EAl(PEwVcbs}BZA zAkJc^pN!Yf=FJ*wYzZhf6WUF~DUM=^*GEb`M=sW2l7`?9(LOwQsIn6$m`rj!Z!z~;F>8Ikn*rBdebqqJ|H7n!8Fa+# zPZ_vak7At`6gFtE&HLS5J``bYG=~Jo2Q|Q|Kn;COj4cMJTdg`gUe%<564ju}i;>=) z)Yl1bsU91})xsj{R0?5qONd@3b~_L1i0Ft5z<+a_m!y8I^6TDAxjcVyi-+UY-Tvw|bZzaj6PdKr7y_BI+E*iQ-x<0Dia^%tjo){>Vo5_$l zqH`B|fLQhGI@~gdPtBXBH!-~m;&8`JO-*BVY#yx=1V|}4l+7#Ai@!AscOCAy;==ZO zI}q~9kWYejg{HcAI;+&I?Wxf9C#kxLG>rG8bH17oW~dLBh^0nn?gS4~pG?{>QvDFJ zemcHzlx7A?4sw!PgNK>YHKbpes`TVX=Xz7}8F7J2_@{r~Sip}*kz?ttXC>^a2ApiT zzb&C+6?=&7oQ)`(ZEj(ECgnmalK2ISo;8KasECc%ok0)RX-E`G`DD-!Js7j8y^*ZZ z@~Z;*puYh(u(w1^fWy%^8;?@?lAQXd6K#Xc%m5-R0sp`%<=S57X-yt#qCT zc62E})oyWeHT$vVNXK&XSm@9MDrQf`cvew8YOzVy~SI*25d=(4y!sC!d)FisBh)3fg*ZBbV< zb^#5hwO2I-smV@1@U&Bvw{p+sf56@?L_(!6h^RRqAo@zxX!gl9V?+yCr^ayqh`otJ zFTkU>n+NBApBQm@Bdn(TqkyXm08kNu*i!`3&f}TWN2N6xhm05-2}_cP$Rtdj=|{>snoq!7j{j3Kceg4p50A|Qtrpfu;+*2428oc*HAWM3XA)f< z#JsO=Xz|OhS3{i@VLTRj@9<>goIN_Y>61n4EMC5#CaX1bepLR|lNI5+0;*e4_ZN1e z3i0DMF9QW*l1>&*&~Bhz-pB_%T?=oiVhlWh_#I!I zxQfB_r8?Sgdr`z1ZCXcsFcv79q|4c(zWB++#2Jp{g<)$sL!^xKdZ(wK|K*Ya7T8dM z;|F_&2hb1B0~Eh^MRRd7NtE4s4lhO8pJ;!w%^i6|ZX$8&>J)3sfTkzQTt|{~nGw+- zVR399Z>|7V?Kx5PUE|)~4N~R9TE+b_kRxHgVH~L^?7|&SpiIs_PPNU`TLEI@-B=oRb@QhAo@LC7K~*7d#9KUx#~~ z9*w?V|J8Rr-c=q0AGcR1`i`tchH#-d;C!>%(Mck2KK&PE ziO}H)R|&&2c$}|=E#!U#(-5B`UT`SN(VLLuLpP1FrtzNt$`N_%s7dGJQWEr%B~=H6 z3`INY$fW``HIY&CeP7Y=iFw#vXyIw9Ws#oxaQIufdcEN8wqxk(at-SY)FNmJIR9ox zCo^BNOrR z_J_^4lrD-8L&Q1!K!PRvkZcQV&WVfq!u+0El92T#+~YTK+6c znhVbhODu{gA$cA6dIfaPV})gk*tu$FxN|bSnh;}KSHV{%4LeIL_#7FYP8AlUy?uzj z1Se|{By#R*9S1j@-Oklfq#g*(t4h5PA&{u!mvKZ#0uJUTEY^8Vl2tA0t{V+5YFocn z=oLh!-VgKlJJU zVjBo@R*AKVl+t=rjfY-FTWM)&4M2RlKsgOBOT^>zY$g#Kt9@ROqYZZ&sFqex>W^9e z0#Plq$kRZiyM`^)d~jo2&OTR3tmU1Gli-_?-fG*ZvpO&CUmPWgUV^l|zRBZr!Dv92~CArG5{#8~o7IAdr&eTYp8Ct)V{FZi@ zhCPM}WxU`>RMu6;!fGRBBgL+--m8~&Sa^cW>tHL@)0~py8>nf`|D>>PMjnwdihz$lxmVT^&>t+a1rE3IX$v< zuWY`S{iT9=Xcw^lL#i7I0!p8Pc2fzDet=HWV0n~93RJpzaE{8xUCai@*K*p~kJ@Xh zAL0EXFpW8}9rjgbh^9N1S=S-kg9}GK1#Q~kTz3#N&Sq)Seb4#X5W{Rxfb`oeazsof zlj@Ic;e|{7sbQ{%RsEpMd3tfBRE~&a1Wpg5Brl=0xpO%VZv*_f>c|ZS7GE-GB6m1B zwy09`=@2H6L@5K8cJp3;Y3*PCp48DX4?#b9pBi}mq6p?Pi7gH-$sIrt!rO%?z^}6- z^}H-JnYLZ27OU@My;UIeGWl;O3IwwBIl%sNZ=7l{X?F9F$|^2f{}!&HDG#oPFoHD6 z_%oTnZPgg-1U8XX69uf=~ zS`@BKv;8E3~ju0JmBpJd?T;RzKJ!=tjqvkk2l@IN(Rg3w~+Lk++ zJ3l7w-TsYqFj7cWuxKTF^+C=@tdjnRlJkRvJ7jh6Qj*!Ty^y?u8U$RRT{6?CH1u|>qSz17HgM$jeHF%w zZl_loay|)O<wuryNDy|A4`OoZtU5o?@1dlC>Q@p+=~Nuo;zqG8Uo)d0 z-%l%~@q>TZ5$UM`7n^-U%U+}8R^-l1U_?XhO4OUoPAPs}!-JT{L@*g-H0F8iTH~RX z3bHgekCeOvs{?j6)gHSuC$oyIEc*GduOnL6e!($u=&Ev$FuQ9aY>@nU_k4Bg1nmP2 zjhrahX1sAoBR|TjpG?dq$5ez@pk%zL61el8QkVZQs`lMWw2EPsO!w>D@Lzd4#?&z^ zWxR7lyp-FH9DD9&Zuly2($(5msO-M^rLS&(MC(bxRG$8lllPhsw)~D&V#9^a)CqeH zTyvt5dPen2+%2yt_yr(9F}vArv&ww`+0|WUzt7ZcJCqk=xq&v`S+=_}S9@&`6LFj8 z%!Mn^G{5q-5ZaYOoz56Y62YcVPSfHhEhowfo>Th5XSLDfc70-njNDf>f=;ZNP8`2B zl6d#4YhBpy9E{XF(^sow`n7l0?pHH1lxNQ&8`p08^7zB_ zdr|sG*)sT8ut5z}`o+Q3a8FNaoo2a)r{x5zbKH-uu@1;hB~*G&%iZI)Qplacf2xmC zlFK%n%8nDuL*J^}{|Kt=oIL)!2+-mhL-uX#y8O+EudlLn zXzv!|qV*MCN*1q2H>&))dYyY>Zx;vjqjJfhotLzEt@FnVjkllfwwOFi71IlHPS>c& z#rXBF*|wMO9&{g)yYf7W$KAT5-Z<=Ej`*-@6)fBBVKskSjQ($}p!F#-65|>{QTpav z9+hH7&z>cl#IG6W9kWo`rL@QL`3!*G%p}pUElshFTCP`_eHU7$si80^8^a2=ERSvL zLx$$Fuf84sXXh-X%C&Obl__XHJxrWmp58qxUM%Vi=-*Ff9pVw@%&GK@RfleRKEWQe zJVvbG0{QrlMaa;OEumN9)g>kstbw|5_ceB{e5;!)%V$n=f~}HHp_MXI&MHAgc$29N*ZM^z+y)Mn)P@j6Z zSuhZoB{N+A|0Y!)TD=M=S9`KPst%+}{a zymx*Z1WqO&vU7_GEh1RBXS_P$6O;$|I%-$9he1eV0cLK|LMk6qMbnbTI@1!Lkx3a1Y(pH*q`7m(a#?YV`pRwjWXi6sX@Ra0ruyx zQ@orG;KN8e?1|$~24#;_#{RP7{N(;NN>X3#N#OmtUaLR(V{uK7WzBo3PG)3iLV2V3 z8K_YyDNm{4x3`slPJU@sjnvC+$VQ(_yv|AJcp9kZd=~rN@tbw~qL^OO$(`KnXsiqs ztE2tTt|{HJSsg6CRyWNR=qvW`^dPyeh$q+D8I^(e z+D-C{e_x8_1{mFgN^h`*5!2_!7j>{Pr#T7#8MEk4YkSi;yvna2?3c%w*J>sgCI z1lGORO3Lqj?uPS}%Gb}e)gl)o6JHk~YAST9y-kw`=7iwPxgXD3P9%Y!iKr_01e&U- ztG+3L{$iqtCe@4lMgyRkNsVSfiI{D% z>+(VwxgdO#qDBz&pfz@|!0sg{KzU-jcM9s@BOJo4SCt;bG}ohW{gxrcPt+VDH8JAT z9fzCdY+$|6{m8JK6g8pJluD^Gvh;l);JSl5*z*6o^8MLNUXWa0c)5P{HV2J(-AD)Z z-TfuP`{s=G;a=zf{~VAwAqBKj6tw=mF8_Dmw@2CBqPSgpCvCDOtpzITxFdhr2}1Q8 zV@Xfut&NbO)mMI=0aI>T^Z~h}giXKBo-F_)0@Dll;cj!jAG2Vs&v(Wi%W^lG2beLb zRo#Tezt>@5>66lEH=&LJeh+)T>*JMb*9gUzvr&uQ&CoQEr>8u1)>(iZ9#CorbD|RMqiA`zw}t8l&bgSx7-L02jzjk#)X`bnQ}!oz>b2v@XUU1ZGWBEkX`7F zI4b5-O9nl}9d$(#0h4>luA|}HE1~i<5DAAqojRSXG_c|#1q`QtW{dl+Y;8pVaGnA6 zF1q&MU_dB_Z|8ol2u6oo;ZWiHf@c@Y&gMUkH3Y`#uE!MTtyT-%4*krm8~aaVTTr>! zM;v6R3YVJ8uXIxJOrf5ccFwAL51*+4hBjt9Qh|dxQs=@ba%+f`*vQ;Y$y!-3bAX!v zy;Z=~xg%I5Ke%?*8LZ{4bW9u#mYGfe}R zs0&RA{M5pRk|KaArTUE;L5{9b<4c2wWJ~FSD5&%qeJI!Y&kI364un&8t25wp=o1#| z^f4_f%lR=H*a`l0|v7n4rO za7cEtmlnFmkWlOPa9Y`r_Z}RCEZt!I^lFWWt#1n&rb~%brut7EKQTZ#4k|ssJv*2N zA!$)HnEYL1od{N^x!xkR7iw@u^a13}jk>XN!-h40Ec}Fd=`L(RnH1m@HrIJ z>>j~X>xZdD8@(Phmn+_!8!zVFoE?n!7kMr~{^_S*pUwTewHOlRSf5b!^v^8MXB|pJpKdwj$8syWPp5z1SssJ0X!=4sp`H9HjZt*^vYLB6&w7w{T zcC?5h+(;wWz?kC@fw^$g@U!Q8lX*@A3)PD%*pd2*My+Dy=WKbi+Q;wYeeqebLh^d= zd{DvLqf3~yvv*kR#c^FDOdy_b9^@MeIfP%b-i(ph1S1VER#WtMTMs_glNZ4Iwa}=2 z&w5be9j}Z<9@HRGR11>C)4dVx9>0%#Zl#V4_-FyOxje;xj_;{7%9f{sX_$HHuv!Lk68Kq5ExnCT+GMZemM3@i zmAXVYRJ!gW6WP00;D0hRj|>&EE)XIP7^KTT^P6Nc)j8W1#W@W>lih;;F6?kt`sy#? z=m4Sjl_b$o1K(G;$Q?JgRr`-Tbx2(%j&n4^f*A4rK~N#>QSc9+h>@(TkG!_uhUr*y zh1F9xo1eXr6a&H49e?)lZ5Ij^ebK1UzzY;T?lk#8{heHXOpvtvdea3 zPq-!yt+)Jb#%zWCw}ycRxoxbugdN8JUe@PL%RxJd|K$*_>-)JmvMBVDZitaTd@O z@iMSVSnhb!(^x|q8@CWdWT4jQa~XDb>(H5DgY3UmxPx$erVt{WcR$(>CT=u6=geFG+>{C{Y!||VAwEFm25=!XCzfPwZ^g4g&Ljv{p znyp)b6n{%c_4xh8p!HZ!pn5mOro;9%*^J0+?9bHj*BU>)V>Jq;ii}<#sg8d1e(z3x z*ISo&-%wRY1=fn}GREHA7W+MOR3JE*o=i|F0YwZueN zJSjAs_}Rfzx4r}qsqGKl`P_MO9@29Kf!+sJWk~zW+*z7xmq0eZk_2OlXZEa+OTqxGg9XWK8lH#ymP{aP2tChdejcV! z1FC4lv1nOz^W2uRw0vc82VpDsMct^~+b};1*`|2S_FkOHT!Gbna}9X-!#JVwr0$)jJh=(&H9$1Bb4;p+$vFb3mR7~Sla zNn@ZWeGB+S9D?Tl;=;UoD%-g2od4lrLF933PbL+_g?=X8%U3wwx0?>#W01Di`@9Tm z`I=aH?1W`q9jJhsd{_IN#7J!kD4Zuto}D}rKCUiv4-zgApgyVImHSe(dpx^P(4nXB zNd|n_pQg=AOf1{^KlS0gw7MZofV`jV)%9GIbKU^(Ak%2`_3*#=ZGGZaqF>Db7;;7b zB~Aerj`}U0oOV+s^#6)5T}vE4C&sZaO7*`I)Rckb(utAAB(~Fp74tVUxWBQ6

f7 zy)`=ZkN3vy#ZPLp+ZV{Qh-j}lcDF?gZB9Jt# zZOyigEog)Wvpi4cHZ{rC2WDfRy|k(TQ)n=t_FuM0A>w;h;?MC7HEbh2(v>+5+9jIA zmKfF6{wdbHCaUDc<_%kYRAc+6LQKFn8}?4rHV=06Pf`LfL?N(hLYt3vBEyjU6e#u9CDh6!I<$|P*{z+C|*dCymNJ zMT3L9=BUvKMVRpZ739Q>d{BzP9G+bg9iMFYzpiIlzXg^6`)h0YxW}+i@2DM@LG&mg zY~F+6^DBY=CJ3llH3;mrfTNYT%P8@Ibe`P2H}4W0zp~gB94FROBY67k8B$R&=)oTu zX|*=cqo;ATTiKz4g=J!5a+2QYe-+gam4=x?I^Sz~LFlo(?n%4dY)oVRjFBMRcjCUS zr_rssgg#vS$;U7iz-5;ge65&;HDWd2?Y8r#>|!}=!H}fWYmmjAgX13g=MCCDvufC= zyR%UYPvD(~(Z!eWFXHWeLe$M>6Z-hFD4AS8B(U{APET$!zF?u?#&hfmUv1+pF*d~6 z5yaUQqS7KU$aj@&KP0WAjd-uxy<;V+LSwjb>_mHow;x77VS%@|Osh=l*h_t?tPzU( zZuHD{u9Y66k`ZsU2ewMKP)&W~xyRJ@`}c>1W1}Ao0RCllVz~H(WNxVs4RGUgbaSHg zv^iS--chS=V73qK8lq-Avg*7mnV6!kUoVb|gDHwg3P@sK9M>&vxbvppZDI~vxZX=*)5mwcr9}=QvoYc| z*J%FY=JJsu>*k`X3U#RZ^7}Cpz3d2`%L3Ie(-u@S$*k^gU&F|R@~}nD_ktfLN#4P& zPLF4l%?B}q=h|dor_YLMs|OB!x(~hm7IMEXRRx@kGHC(RqnDi4|53&@*9rk>h(7mV zA|pB=!cj~Wl{xoOq9V8p5IsU#0GCK#O+mU5cyYBXA!fgSSe=CW9J=T`xeqC8xBhuD5Iz3G~Jlx;s47zVeMQvdLpK(PwsHxe^MxrlYe8wGRC8vIU4^vuyLJCIu z*)5IP7(Fq`j*|8EM{qmP%r^r|UZBi76%7u#|IQE}5?CfztbI`VI+e(=$}C>%XTr(N znEup^twt3RLj^jDRet6xEF#()Gx|1wYNQVsvD!eS^tXddOj@tsD;qlPN2AqAAyRA@Rtti3U95XR-8CjwQu+=q=dlQll zY!8U0z1kBl_I0lEYa;SDB8{UMUDNsZ8eLa>FQ8N!_O|K&m1lvbdO&#=J|to!+XL7{ zi#7cHTkoiIJ91vSvHI^*lAQ)PJrD=83c8OG+r%DEAV_}~bkqpMZLe68<>WILh5au2 znlJ|qu5xj#^&h#+cJKcleP&`fG7!IDb>>BT2gTd$sUQSMqo0E$2>m*Gu>P;7(=&IIIk{$vFQr^C!+ys1JiSy_M_W#dBx zD85D>M<7F|^!^VCV?WieRROREByzd2VgDjw=#va$sfJTmet+RvT&}|r1ESE8v#61O>n}?in|u(2rI+e1i%R-j;J4x0LKVL^ z3F6YP7)B7ro@7V&7Xs#A$LrSXBo-OjmGtxc$jSKWm)>KdHW!s9fly#w6`Ws#4b|C; z8(!$0xf=Iw_QFI?Sdlxfjxn z$my6~k?y!3d%vrBXEMyvSorKrJC3U zIvkcm|9NGqrom{xZ+q2g>A-SFVfn#8k3#!dva(dq@%=IY1S#uWxict%4Vi7m_0PF7 zj7!bUjp0&w$O&XtAUZoW)HJNW<052=azY<%O`cim?rONc`Y_-s^EVOc9yUhqO^ay) zWfB#R>RoZW^K!sPxi+uCXen_pRnJ1lqzSK7)oVy+f*u%ueegn6xZRerf=V!x-}Z$f zP@@BxVJJ};*Avi zYL8@Qr5%7mCB+wtQK>8S30@h9R2JsBE){=MZO^bMjSWl@*Z7(q!{5PuQAyDbB1O=G z4s^B#T;sotpGW1Gz9>sSh35oWOfHw6zzDZm=Mp=0(@=hd!zDWsz4(Bt`m@0x-G5Ai zUFGdG4;F)37Q30UZ&bX`X=u0(DNKJ@D*A}K12X9+t};@<%KeaNXLpyKR_Hwoi-7%% z%0q^Io+8Gq?7HeFpf|_~xF83d$1RU1@X!Dk_1A!tl_tQAzqaS~1XOv_A#ThEhz@RA ze&*~d7M3@qx~g1)=^amCk%|~#^H!h&^IU5!3-Mu)adC_$^%qHz`HL&jfv?6v4Z8+0 z6C?KwpNe?A7ogpWzz7PObf60BtFRMJru;OdSuF~l%5-^g;fj=9d3W17_UF^P6FIkHT{&$b1^mWB`;+TT{`f`ls?CphC=EuwFoI3_^mrHzU=5 zds2A3GMOS2Wed%@CDbIPUDe+ZFFTWr57frbpU}(Ps5W1<>slK_38D}ZG>3H7SL*`d zm2*d8mBD?B(!b?e&&IA%xAflUAFtz^$rFEq!Z4rq#1kWyN%t%xM4L4q+L0MdILs!W}GyWzzHC)ijC~pY!Qv6s4Pf!3=q)WZsA@pq^ zR@ar?#Gy|FTRI09y%`I*@nsK@rY;XnCa^m~YdC28*~0|ZY|L(Vq0`D6H)!%6Ykh1u z!4rON;26`Sx>rADt)=lv6yZhHN_a{o5-0YRis`=~_e-1@!p!rg;~J*(j7Ky$=RnVl z0m0lPBQ=}xyfA1FA1{}7k%};MBeIp5X~qPwu^tv2F9@~IJ7(@p1%(2#E<+$WqrBCY z!*-w#y8(#zNKQ4`D0&GYr^UY-{db5p*IsVs>iZ~OwVoHDAl~J_-D&C&re>qSx0g_R z9hVVgTt%C_Ssh#nhJn#FIYw@?nVbx4sN3x`DJfzpcP_J0a=vMEzMo#KYkXeg7F?<` zggGj12ml8{#L9I6v{dW7AVx~B8m10-CHt%nxDCh8Qb0I8O*BLvv_SnYRd->yqC3t1 z{v%x2JXe=v&S}}}dB@{%1JwE7Q+yBUoSQb395S5tq6v7CBVgzZt%!zbl;6K^N~p2B zMZG)w)r=4=tWktW4V;TcRJ)z3#E{QZcnl$BMe4BF;7@Hjt_%Ea!(`JyG-cAaz;Ybg z&`F_;429qFYTadf-Qn;%DVS0YVv=hGLONd=Wo(=w;rcbFL};TpeR?66UlkF?Hi9lf zzN%uOR+OV_-e&{b!ufXGvQuK9Ny=t+L}m%ZZoC zQ%R$A`lMOgs3%}+I5XledgxIu2T+H>p%91!vN7hr*_aEAA$fQ7Y%qu$K?y=nAsJKp)Vhv6 zV_29%bS9zX6_x4Q$9?O#T$K8 zQ=8MzE>FE{KfroL&*KOpR;zYnB&5`dfmp_#!F zeOn8Lky=H|v%7adMnlSlhcJeC%MV&-{?hS0{jECh#1&{T^3d8u%^Orlip^daqLlQky!bE$8za2B zGXCbBf-4xn4DouaHa?_hrGi+@wA>-|Hw4mYzVB``#NRh$^`r=5ukgPoPW@#NVK@l8 zBT%e>K5dgMO#>w>bJ;|Dbf$_HH5iNE;7X(hsyFxOsikYby$Q(1!}VWd1&g(PSoy89 zk~cwTeG0)KK}G4F%TepBe1yAb*=klHR5I){yd0M-(Do0F3NUp;JzucZvc z-5y$wsFzbyDBTQOZ$wpWz%??B#i+6BBn7BrF9BuKh4gdHL_niwu!RGtc%L~V+Epp- z^8(kamY1HBE5Ga&l&tyRIqD>XaPso?=$ghQkxNS8V%W5>>GKhZ?`%cXxL1f0JB6g^rBri(3ltbP z!~4^|6A-5P{I7bwO`HeXJwS4Jtj=EV+NTL$kmKl{ROEe8lX zsm#dtitGTE&E1UuOQ7G`0I`9INc-H#s7)lYkP3CC z>@6d$sK*Ym@A(cBDVbb64QE5pvffTkq}!29L!Qg@W?$%A-NkmuZ=xUT;V(a671mZ0h;U$>{LDH{nURaL)SFp5rhI|^b*5)+BBv;_ z*1z&K^+LJmDFYHUKx-8J#Ob0!~ zt6U0Jj!<7sDdc^rDnNto1V0umb#))2S&Xj|U_J(dDG)fMu4AB${u*o%YkM`|y@w9G zRu^z)Bns2)VNTDuMv%i&$Mf5K^_Ac--kJV$(N1BPetg6Bg{YOxEsnTr4KoP+vA)`W z@$tj-TQ$sOtp`)(3m+=5I`0N_{Wi9QTvlEob-x13$Nn$;;D^hON3+|D{X?Uh&d*cr z%r{%DH!ibw8WqzTceX$3zsw1aL$1us{mLkha1R{bMy651yIQQ>a1MBhykcgBIA4^P zM-^p2gBxQ)(j5%Sb+auy9&Xr`(zs&xe%A~osoG2wU;?HjNq@kp^WZhpQ_8QO1qh=B z_M#EENw&ju-pNBa;O#mcTMvc9ua7TsMU1)5$VKp+RhhGWHdmkv{$sYM@AeW>ZBmr4 zQG9L&=UNe6Zp4@ueGem638PYvtkUHnArGK6HeBeNKar8ItQIwCy_-b)7uF9YSptTk z!NUay(5FdLKtU?`xdnD4=0ZIa`zP9rdfrUqvOQhg2Qm^WYe;oFNkdykvH0xSg2&IJ zK#SEwS_gwEURE&d)zpOT5jg@YXXnjW$`#g7v0Cp=ir>-1mBav5DrhgBa|jb>2ve~W z`d?3#c(3G$3ZXuoMk{8SsW2ua!k z;AWD98v-8wWIU@fm{Eej_{CB7vErSh!@da`YKq55j3D}YQ{fiNXSrEfZ6E?1{vUU5 z85VW7y^lXCf{2tL9TL(h-4fCb0s_)V#}EQiN+YF`(kUQ~#Lx}W9n!-vgeWZ?zdh;` z=X^cqy!yZS&vkiSzF^|B_u6Z(xbJ(ZQR|8&@j1kObluQO7IITg4O!A<^fTW5{`Ip) zW9KU>A-BzTho*D7s8E$JsUh%f(1ErCq)CaWH^ME@jmjF!hKSmS#8&Db{WW$&OO zn;2sRll_BQ#AuTwO7(8{F?}{i*xOwMwym%%@EHs#DKt81rfe8fNsGfw-hSslalHb~ z8~dd7flABsIXuGUY;f))>@k;+3lqH;r&-|rXI!mK0~kegeA0zeeSjTQmfYM<2$o7_ zHx{oI@_dPqVVQQGHb28z8O@^TqCc8V#?}2Om^1>jrOcQ|t(ikBw~70eeA4x_@4Kz! zZRRB5AdRTw{8+eyP~C?V3Sg?3Q1BfCc;GoWFIE3;?nTLWZ{Aos?n@TndK$40@>R>R zyfN%(m0S=w?!KA+J^9?cML*ymZA}YBWsfgDZ-u6#zRMVf8h_fqe&pF@!ieFx z7^XN!|1}T(Os%XP7}(RRm@t+Q$|7bI`_H-VsQ5$r3%t)pY z-&B5>xHv^II&-1Bfa`j{_O4+TjI%5C7BhxSta4&<>d|UK6Zd1b&n%a-qEWtB$DdhF zz2kwOqCrA$!M`XoW_pxCsjyokNoC2EurkpbgB@iWr)9oIJj)|_uggwo&+f)!W$v68 zA|*(vwR7gDYDI|HGut2*y|%X)(;l7F-DB2O#BR5N_(rGTQ9HT~0tV(r6p7$^!s!FD zLnH$e=X?qqAM83dC>4ToVQAOC9uR=yO9%Bu&W#jmq&73~;Rp$du1InOgj~>tgV;1A zDJkiUXhr>d;#zgvc0>}%v_22Sw!ruamle3Wl}bW0N!EE>`vNcE7E49LB)K`+2p+Yl znvWU)Idg=FH(W-2?M|J!PLuiqYA-*W`S1;^9&bQkpfS)T-7UV^iPQrIC97;<`B@}E zux@hnnLxd`AQjs*STlYK`bOH_lkYaiP}tie06F1zSsN1C$jXwAc9Tbx~|qa!OZl&z;NC}LaC$6!vHUVjQi#JJVzC6QfC*2c$|*B z?+HgGQr5Y=>Ly#i5$vA!C62J8^X&4GBT0sTxsmd3$&S|dP7u)-%Bs0tU6q0+11ywM zXn_q}n~18$BhB0(y%fle*b9nCj&5m|!u6Y##bi&nLZ4qcNsQTlO}dYs8m$QbR(1y) z)717qS!e_z01E2nq!Z$B`HQn-M9OM^&I-?PHN8wmr1JGA>m}RgqVUT8_To^sw#e!A zjER-DcEr-PMSZ*oc4Sk2tLc&&6yxo$KYlf(CjPZ_+fdzA049y&HJj7@zCX%po`CLg zGYc8{$PK%&rvd)lR4Kb&sFx8tO42G)RlV0z=rB(DblJSmDe)keMEipoLB*CB`Zp>Z z>jKYs5nMHAs@Vv4U^*cP6uQPAHyiMKh}@#lp%QfpXkIK|nKiu``*P$nkz80ijKQPt zV-D-!?)2g>;{R@-8EdQqC8i;k+G7+8*?WNiYe((~y>1OI(|W5NSWS7NLhp zHp?0bL7E0Yk$Sb8-+ele9D2OB3I{5lFlgx;6y2Z9^Knf2$Ww{w$s>&9(2g@-k=x^fd~Y12D)uDKxz^gW0>qY0Z0sO zgL2xlgcl33roOg`0jAEu@U=|hqsNM~bN)ljNtWkC#gityxYy!gnO*#3dkKD9O;6~_ z_u5YeN5pf^Z@}Y^Ng|#ETTmEgCi|1Di+u~(GNoG*)Zcn6b>cCa1z&yyo*t(cUC|2} zu+nM^zz{DZrYc>CB_8HCA60EX(b1y87-J&jPV%@~f*m3Ggr4XAMC!}aPWZAE>=aW0 zXyNMWqsKS;EI75If!OH)i9ZfP1hkt9{2JZx$)_wi-EgbWprQ6m(6UnZ*!N*ONq%%R zUTa&M#$n>x_#b`Dj1v!-jXj%tP>MsJ!(Oi5Kp7-|Svnqu;A~R9B=gATyd$Wr-kDGw z+BnCH(M1?3`KkQ6J4BUCVa6uYsy2D3$=BCxAeq>0YrX@RW|<>QLfH3$3@mXxJq(m* zVty<%&PzUw4Y;Y*Zp`KLOl}Wh$b082*&0@hS{1kH_BXID(HmehA9=0otN^(pim|ud zRbxMT@3>TDk--Oo_zJUptc7 zRW%FzBoEYyCQ8_=6cz@vl&;4(s!4uRC(?*3zW+S@(T%)~n_wyq%R)IIlaZPTP!&tT z(VB9msw^?t(#l=x)LIXz-QS>H<)vt6R;czJqFwaRxd9fHrltpdFpgIStCE1yO+O+H z06nDKw@ileN4f!;bNI9j`U+!gy~n7v6V!kP7Qe#+j|6j3r>~~ql2MU^ow)t-cf0b3 z2ffD*Nk|GFlBnV#V>dZ1lq9$Bg&$$W?t8z0RzRsY&ATP?-@ZoHc9_A0n6@cVzQ=iJ>i@W@8D* zvve(V9%@$!Lzw0$_^$B*hbqZJ-yX|hYrS?ewpe&&GfFEmaA~=aDuJ`_ngz9KSEDF5 zL8U8`Ss~x#iryf-a9nEFDX`+xH@5Mv{OkTNuV*ieP8T{lOb{wO?eCg6<|-}qBtMRb z?ku{(i?TkLlS|qtp0htvrg(3tc#^2_9Urz$=Ye!m<_IZc1ZkqmObPqf(&PuB&HvXQ z1CG!$&{1hZ1@p61Bilbb{`kJWK*HJ(&->#d2L332x^h#^iEAGasR5KzZiZRK7(4=F zNgj(D{>92}1!B;&P49mq_CEP_+|OjAiD+6&hCAJ zex7UBVcNzIs+YLPA~S&g2*=_5!kvB9wq~Y315Ge6Lp`!y`lOe=i9rT_2#*r$iFdG^ zg>vdv+q0HLO_Uf?f_yrfeUHPuCuXdPuGWp3j&PjHJFw(C<&@S0*MftRivSkSK%{(1 zt`1abd5&jzVRc;mO_UzosV|2lCc6co)R2|x+%yiG36>aUZMJ6+saH@Q`c=B$8u zwt0y`g9Jz=mQR7END{BzYeo&aW1BMb(3&wr<=1p`@9iz$3}1Ovr_w>WrT8AYmwQRn zSM>T$tYD~ugbs-*FD~|T>UuZgC0Qciab7CIhCiW%%XEF@agPtjD+e^3l?a@rFytIb9@7ban&Kj%~C#5du1&H^U zo=PH$xwhkncMcW4=;0}qAp#_Wv8ZGs$#$ji^upcU-HpJuAmN|2Zd+K|+S)a9AMNJ^ z5fodx_3oTw)|K4|QqYW!0%gD+jB<;CWVOgIIe|`oQz9_iNr_Ld3sAAXa-=CjR=;OI zx19dKSG-l_bt;4iIG5xox2CbxbAV>IovwcCWXnsNzy2MSEyN0^40O8ks{#Cl)aI$_ zc*SRaue&|VRjYONrO&@`+vG=IQy(rihg)qwsU;VaJ8W7J!YFvp-nJe+3~e&U z^~w#G=xPPfoSyWB?ELL_n-O8SPeHrot`tj3q(_5LreT*sf{5oed(tOYsf+8U;r`>J z?@b*Z5T&9Ochv$QVYau>tuF#7jiR1?#pe#P_?su}L90WJM8g@Zjq79N8eQz-?XVls zoMnTOzF)iS6c&Z-a9g9K;Yw3iP?Sgm)aK}5(0zTAtz3|+9nc*2qFVtZ^KauD&_y(v z^q;P@=^e~a&5%Hj_(P67E#V9$$3~XI7bY?vVgRo}+0&pH`ecoqoP0Q24$p&C9%n;s zNfZkUOQ*@FpPHJg_;?YwmcaB z!;8{cF8ENzS|3jd&|@L2_-afHhK<7Z^Y!^;Jl_GLGZRdy+eO*m-BpiFw4SbD?1^J# zzC4_(bK|t1<@Y$nYcnYfnR*EW0eM^(;1k03BP-+R-&C3yUwzEQ8BJ$|B26$5AlUG} zy$nNQ$wV~cHX%EiWg?mflh^sc1brOcYaUapAJFb8%6WyOdyfnJ|sL}bHYSY?WuVu5(efvCF=-K`{&T8SLysG8F@-1$0wm!7SN1b)q6 zH(5ZHknGjKSctdcZ~4LPe|EJVP`C0Snr3(Qo6TJRj~}znoIhxe z7^b>P)95o2ehS5EO)WNm{fOCqvbp8+^#Q^2%D3F_m`zsRbC?(I?BT?_hP1xBl9Mrd zrz;`0m!m$Md2Gp(f3lM#YwqqXWW)xpw;Axy_cy*zo{c`FGE3NF_b!-~zKd1`UjnEP z-xqo@M7L~`JF9FzQ#YZ4xpT3T86c)-^fj6p9n}zJQwJ@Qv+Xu)NO9euM47VW+Xzy@8zrQ1wVw1^OoQ#|)Y}OlxIbNV0uP3yFi>F-S z^#gTqwKaKaIve!Gc7029E9KdQLlngK*M_eFZm%>l7efb!;hBeGA`ll~=#1_pDC71h zBv!!tqzfMI*-~9$+lDud+Nj)h`a?CSJp+~0xAw6S!8I4bywS;^Ew#2I2B|}4^b|V6 z=bmAK-lGrCEyU|^YFzJ0#}A7m^f8@zPiRbLM1@~dOad&?2HES6CT_O#KKSgQPqsdP zk4UrUI+M~f7rmzp7&?)S$#)wVRfW=3=U_*hFQz2KNTS@>JK8^d&Um!0S`k4{LDK2z z6L>}fS{@sbo5Phq`?3&ChU#ed$C-^pcp-tCf_eI1jQxk%t^vBBVkOK|mbhR!*8}tQ<08q9 zD_7{c`b`<=*KWFXsO^#8iuBBI6J^z0|6tO6W#{Nk2$3f7b*K22HJ;9R$Je*-UymSr z1GX`#k&kxRKF$mc-BmeL{|4s-p z5QO?m29|p7Q>~hZCDIWjpW@jUgyP+F%FRO0bfPFk-ZuFhZD*Y6n9x6m%vRgt%%t&Qu1Iq}z}raS)KD4_hwJ z;ekAaPDa>@^I2;EI%o~wVbTooC{imYlb>lN zx;)igUUscY8Sx0LoaaaH$oQgQ6COnuHC=fBb46Q$g8Lh`CxXa)3-o(h(#8E>^Vm!zd{@L( zlfpqb?~EGnKMB!;na;Ps6r&EnBV<{iV;jDUf`ZbnNC)@2I9}b2=CObl`cH|s z%*~0c1peNTFPvJ5>u(YT$O~eMTEin&EvRv3YyKc}s%*I^y<2$HM~r2T1~(Za7T&Q8 z|2l2E29-55kwvQgG}Oogfl_p4dj71>ebEFCz8vdZ{2yHM?z>kzvD|h$dun3RJDsP+ zAas9yY$Ppn~?YtPZ4w@6@36noxse%U))RxSrlz(_WD0Q8KQIVG)jEVAv4oUWL|XOh z=yp{+WhpY(Wd=U_F5TM%d=n>PU7?lsGSJgV!UxjmQ!l|py)G#)yY)st270gkVVS8) zn|NXLhwH!h0;^3u=nk3BjGP$BD*|g_?o;-CEE_NSzOO8HajEh1<=t{w|IoJKNYrQ( zk5O>3)LP_ z1ddZML#XfGP?5Aqe>dtGgE#hg*dx?QF3cTS1~vN8BLKVfn&yL_xJ!%@DG#pl=?|=7Mh+&oOV?`uprWc`S!a$FLhDHMhA|q2FP{{suj*<79 z73-+390U;nfp>P>k(ZJB6@~OZz-Q4>ou8IcYz2wzr!H_(b^p&vJq7^9cd}9ktXEQs zx!)Hi1$440=ZVnal;q496mEns)v_EtzaR-``ye@i>bzsxcP=U=ApUiGG0UH;d9z@>g{|NZ8%Fu-t~T!6|KDtq;QC%vGI zmsT~Z)0(_dYnAM|U2a@@SQ1v?mkG&-!Qqwy8wtbbEU|c4t&%{tYrI1c@SikQE%L&T zg=iG*J&8GPx^m@ZKx|&wb>)~EE99in(* zRi!Kt&|8zd@#G5tH67?Q&IN%m=s30xLf1E_`7|9K10^*8RN z$t-!@UXA5%0s%?nAtQZ#=zd#XreGcqC}vhj8bM)IS=&V^6in-GFDEK_wa`fCy3_?= zeR8&i^yC^A(t$zb9EGcz0YQ&HCZ|ikzg6t$a*F@tRHZ%WcyE!bg4G(~^4h19#ggBL z&meRMRHyQs1lI@qi90q9I(G#yNC!`5&3B1FCW`mp%v%BQd$Qvw7uv{!A@>& z!AVJ^SY-U^Alj|mdI8kPKg%L}eM$`X?>=bZvOjy;=zAi~ln_zb9$jXM)&xi+Ok;l+H z4?0y${}cvinUu^A7@phqx0hy z{3}NO5x;@F#vDvXshOuL&PE+lboW-VVG~;vd40-RI-4xG@C?-dRwgbY%6W!zwD6w zr?|>p5845Gh(YL{Lb3%5J?-3SEZ$c!z|5b6h`Y0TpHsW@NlX^Da9(=56RpyJZi6^P ziA5GY4l5ODaX04oVZy<+nWXwMFyWB zx)x|n&m~&cBLaNRPrCIYLb~w@ccj`A00?g>H_b;q@+uN8qE%wR>HwjpTiIJ3jMdxk zKq;$}06cKV?_VAv@)$vX5HwYGb`c7~V1^TD9k&Bu>~eW~zrPTR%3Y9b{)mGA3Wt&Z zK`Ga2b+dajNCOcK-IGr?ToMMpd|ll}pk*Wh8y8vcIzt37%zih#HJIVenJKB~HJUH< zKGM)Hqz(6ECM$~CYnZ!`Y#jfdD-(ab74~4hM%lC{p}oUnlm*7Dp|Vmerri5$d{#_e z?>*rr0Xy*ErSd`UzOOqxjq5|0(VfrlQuFt$9RO}r>D7YVei>O+ZUeIE5X?4DQ`NdE zqA?qy2?u_BXX4OFnQEic3Pzow3J!pu(6kmq7uCo$y|iDqh?tI&{P^DZdpS8>6zgic zy%(PyL#K;a5s5{kJQ8KpZdKQR%>e^j6aZuOPAK}rxSn|6k`fNmY4lHpoeQfr=C z)^Zm?T@UHmWa8SOjGL|ja9=Ei|9Uw7Om51Zd+C1OQAl<+x8=}-VehY)BQy z#`6QcWAqMl5Su!^6Svi-rgPJ8$DpRmXOIrDsR~(Y10W4tz(L^?mN?vOZ)m#$C5K`z zY%Tbap7O#@48%6>?FJ}0KRWTSd$+M*_tC2OAw{*_1grA%3HHwKm0@cU<_pYbkd zR{VqOCKlUYHvIDtEqMVeu=$cg{UJSLF--!C$;}a%#mD_h3uF z)8rZEu%XFAe?L|kqqm+~mN}Dc!=L3md~}x!MX@a43=IFNj>Agt?&;YpRy}o9OcY!h zZ`h;QpUEb07)#W{JFxam_&j!?ndwGfn2*AxX8(%HNGHd_9!>WfW2!^oA52T8Ok@DFu0(J+ou`7CUWO{r6=(Sxn zrhYGtsT;jd37XCil7urlcy!o63pAbt=r;v}S?b21jG$BRUN*~-$Yl{(A91rSBxFxK znQ-Y0XEl`Hu-$`ZfD%Fxx+j;cy>xbal^6i5;;nnook^(mYHuAmWg3@FiXx-?kb>mx zu|4MH$-Fl?`haQTE#Z6d|2ZczD<@5okv@+MaDjtJ=2mG=EVErMHK66$G)&OE^zQlF z3qf*G!~_9*GgV6u2Ja8$C#yM3NlP6$w2hmNz;ze^l|cyI@Tr<4A_lu3wWPVj6xZGL zK28^1wRFv`C3p1j$0bew#*~a&%Gv)YGFOMCWyYY@s zm`z|>XA&4#Se*b9jf;bGck9#Q)>ae{G`b&71`!R}lSrmq62@H(<9BF3rWyF#BYQ6F zjl}6F8!cggP^`^K0O-^DGi5#4yie31T@&C>MWr$HN$q;q2RKAq;6Tt3-};NDV-X*k zN<^3R4@C{PJfxz=9nZ6lAbW<4(YuBk*t2-%*RxJ%4=rB_laTiu%rg=Gfz*L4pGXa& zRO}3DEfGh}bIsx8X?A1`vGfw_c|x)U0*w8Rs{&gw0LPTA;bZw|AdlBwrZ0!;Rush3 zUUK<%NBHz8MCK$FPy+-t?2agI=Y&q}i1pjN=(8qg!n(mBVE>;&$uTDh6aBhpN{wR; z0M68Xk|(>f)B!FrB14J=GUR&Gg7wy|Og2)icm>8{-&3i*at(D61SCDITPobV?f0mq z{NaZSzVRjgR0hJsn>aHnA@92Yo*9Rjn9)IwijBR6qZ4;EkYB9bgL;4vG9R)hll*6( zF1lR#DEu$bOme&{hl<&aV!Itbm0YJ!@3mSozVOeg*(>OPWd{;f4{`^>SKwzjR1L!% z-JNQdUq1v%)?`4*?f@Ov26Pz{N^NBntg1nbh~gX4#}Aph=07ImtVcm8!^p`+gIA45 z%U>otI_wYRycBfCvyJxIOi01|pcKMrMzm}4$P<5TP4z=?8em}p*ZaHH5i63)6L-8s zatWsVy_xhHNE0iWy)9N8dS6aS{xM$VV*o<&JS;v_=o%row`;DTJ&kG?8!A9aAjD&)h(S!Qmh$s9xe z&c)hFx!?_(CzjbbEl^YndO={Xs?bE)z7+z^>>fuwZNo$a>|2UV=Goq*NB=2f_CAQt z!4oBj*)$xlfpU*YR;;@CHUasoNT5-g$2YluEC&#Z2oQArJ5tdK^o>+2tsnikA-^<1 zlw#zp)dBodsTZAIIr+8k`u>-3KqstMjid|TF~hgRDw_*{Qgi}nJ0h-&z3Gg)sg!z2 zm-v<`Sz8U&03M_f*pa&`_j6Y>Kv7P$po3mIR>j$x+tJM}o|F{bJR#g93B4@<|68Qk<(Y<$ z_(LhUzcEXPV`RwQ!(_f7bcO+XPbRt>Ec6^ZNYs*XSu#!yE=|~4y3@}}h=UL>Iwl$5 zREKn8!#^}X{`TrAS7c`-4=oc;SG>E?>2!lTTSTQC;H~_zqV4iRUi)@*Rg#F?U4fRPZ|z_#ZnA#smv(qQfCN%K?TFK6=teJ7if|{`58C;OS4>; zg}zX)8Vg`HA6Qy4014ti5Jz?!I?O`Vh<6sCbUivWovTqIMWNICXtJZ%+e~~+qjO)lub-U z*1aD6U4G~>57ywnP$D){wRv4BMFT1=cu6b$?Mw5Rs-2H8BlO6RHVK7DfD)r4s>;fQ zq;xFFXL+9X7sVnlIN{TbsssUM!|2ITa&UmljlwS@Hdr&C*0TUqAgs21H=e!sw4F{O|9y&4Tm0#DQG!c0D zb^sVz&mQ^cP(wH=S&j$GO1ZqC%*X%38S#oC^1T{-QcKfRpe(wn<9(lkDQa)3n%x@l zED%sdG(^m%4AZw^BR%!$_DJJZ5*(2r>>(kz4Hojj?!BMaMzeQAb0G#+l20#GM0uTX* zDuAz@j|3SJ_l*Gf-ml==-J*m$weB^4*Ws(@ z3#xDk78eO_p0vRBxlc?PXw&$y)cWLpGH;F6C2_srQWSY~kcq@0*4j)lo-`IwRLVuL zDPasII(GZ;V;YVO8#J%cG9kKKDdBTMl5Z=CT9EN&Q*tvc&+p3(9RSe$4gjS-MFM%m z3!upNp=#{~(>v~wjrX8wOH(njL0?tFu4%#DYWnJi@zmJR9i}H9eRB@bLP>7>lC@Xz z?<#g}iiZW-#^stC8s(#jc~Bi%xe%Tr3!J#Ddjj3?u@UIL^2uqQi@=a$g;**x-1dJo z1CX#yK%2Z<4z{4)7sNexG_BF_KM0X6kmF2FM0LXF6gbPQpIJT<#J7Xz>$fdP2*6_K zpCcw7q?D0Kv&~0z4V%m|>Ko4Ejq8uATG$1p7%%B?OpNvaQAtOi7-0}K13Q6C)}v*n zCZRG-e%Tf@V-Ws!;nxK!3UQ| zWy>!Zu5ulR9<77Je z=z3NZ&X6!yJtQHoM!F?Ruti8<`)@Fr|F%u~6R=UyneE+_85*3CBAf~)?TQ3f{gdet zbMijC(BRW*qlGc$BbJZ4Bzkugu*2ZFj_KZ_Yns+(3_2M8LH(_6GyW!Q$FGsXQU9fp z4*CRAW&p28E%X*ZCMrp=6<3`6xDxw?t#%I3qDB3)~-4qP<*y!tpzXtOxR7j0F z1-_nngjTsfU6AsTMujdyO+Z|JO{>t_5dC`c_-s9F*r>YA_bBb)?alZL(;29T4X-8kFO3=KjZeI1LpsZUd6(hYhu zz~7_)gRb)T4yivE$nYsJww!sI(9t1nBV(xSDAm?YMPMJ*Sv0)qsJgy=TyCOM#*>&wovd;ZGAM2{6;OYX4nOfqA z*OVghDDcW*em!&dORJ-)a~s6?d+hwx^qM@~g1Qbddb;&@d=TJ+fKFF+53@3*F3ibp z*kRT-3}ZuaWHkMo5x13?u16&!)j#M>fB6Dgd4Nk(iS{+f&HeS;hZeJV4G7rE9Dg-T zBvkdAPIvE396jX_;TL!c@k@G5{!ihp5Fam0Gwp&! zT~9iSq}iuQE|z>()A+vT)=%$KDW!fy^-!p4E?9%I;kDA^5I<-2NBuA!z zrY;y=X4evV<7Nao{qB#k1^KT<-}Rv9%nBJT=5-+LWTg1K3!!q5Ht*WJly)b&LbG5_ zl1=E%Igy|}<0~y&3y5)woNjALaBq`}*Tj~Zft~z~zoBveYSrRhnTYFgrb*8p|N7-T z_styxvu1MFrhTb-`_jPV6KwhTEd3m`V^_JKzzbp7}X8|_8!FQDB!hCB}sADNJizkbneofV4P2Jt= z@zu{ROvjaY5Eo|uVmw}0SHcdUv99YJVI5Vm_$OnP9gdna1x@&CQ`@RznyD#txuh9& z7BkeNV>F*&ZGK8q#^ch;8KKs=kY5uqrlFATy3nuc@sm?n?kL#ug_Zy-%LrAl9V=ew zT1q2~K4N)Sr7t?kZdg}TKVpXMo@-zzAA^Ohlq*!_;{c872WtxeNT}Ym){p?W z(Lb^K$aCThPRQpsF$T&wL6LevUOBo;U3m1YY{Jt_?0#-EF8w5AdkewAZ{oUidiEAY zLUfu?TapeJ4FX_Cj9yz!1Txj+tz21UC2f8X*HdXAv42I)GjcJ1*siLtV9nA?w-QhH z8A(Zfj~Vuao#lsZ<0Hr8-8@(q!~-@e~z|4x5$t^u7d6_CU-o1 zMh6pkz6q_@_;i+&6s!I2Hqf;F6GQ*EzXn|V*4PcVbUS_vbzf%+-Cifqx zChRgri@5L2P%`KjvnKLT(z1WhJNamOJ8pc>d)e80(2X8^>euf4PYSzfgGa{-tZWrr z=qh!M@Yi)x&cZ&s`W~^28NYQS)(rGF&Sl*1T9iEvnJpV!x=OouaSYq$#AiFblUP9+ z;d^Cb+eVf2~T{Rpmsh`r~zzZ5^o@TiiD$!GYQ+JcO(zKg|$P;jOnLcxsNQiSaC zY%gl}%ZP9O)bwC>__nA3(Q}y^zF!j7_xSz_^Jpu9O3u(W#;0?L^O)UG{PjR7DSSM^ z3Q{WH4->wVxQIvKtCm*@^E9jZZpK;l`$7rNIZ(^DY7Wt!>`({WWopHbJBXa1sp7wzXy7!Gdzdr96q`Qf;* z=&~_@6N09VO;!cv=&nz?Bo~P+(Z!ApL*L+~R5eg4w}_&+WFf?FH%I4X(>2{0jP>AK zJtT|q6$>$Goo|lrDUf4)$&eoEClM$(SGMY&-to|VWv=jZCHX_Uokqi=nlcPA&egS>e=Y>}Hw z`6}Kaxc;7AtHTTx;!$?Y;*QGTQ1-cHjVAT81}KGl`%)UGGb-@^ZKk{G%dEp!JD8s} zXJo1M6v#5lx$m3DoYxoVvk`N)U0dBut%BYi)R&&BnNSz{Oth7bd+)0b9-i7(^%MhG>-}~;uQ5FDOHjypqB3l_E?0k z+7K2>^s_(jE>UZUlZ5O|cq7KMVwM+tF9Q#!=bUP2J%`_0F4*%Q9NTxO(6_N%e9q0O ze-+ue{sfJ~XU!I$aYkBnvhE-$GuHY=qrSIG z%M!J>ZcoRlMH$dSnLZ?s(x>E$Rv_Dpvu&NJ_GK=ixYt?xB%Xx;w@(9zQzQPI+Pa^C zbZb9j+k8+jE8SYmPf&?l1y^U<PvpZ!Re@Jg&()vtWL8^O4OnnfS=89i^zyU5_-ak&nIn z+}0zM$MTgih8SfJW}Jugk2MZsp6f%s6xbL;Od^k*2Q44urBJ${xTY=QaXZ0g3~H`Sk96A;@g*OBFj?5ef={ev|m8#x|N z+Z~T)ee+Zn50y9U4dbJv{``ttvdYBpaXqU5HJjV;hdmUutQ_gV)TZ{$YZ(cG zZmha4WWh||PZmcjTm}>dE3MBPcNZL%wge#EZx?(nB|Tg}>DERb513H3@E=k(s7vGh z<|xKG`3aOc?wxfopPhX*FMrxrjFC_)Jho(474BZ_;B$KI^3eA(xBeqBt8*WWnBfcL8Vz1fqPXS^4(T)Y{f-E@&3{_7hvv4Kh_ z^?krceM1gvbH_mZiSpxXDgu7w7b0*vB;e|PR3e>}n$o|JLc_s-yW`U2*Ir*-<_N$W zs)7jgq`;b}^-!_(o9Qd!{BW7Ezjy?q+<&=DkQpyS-~)SZ1YBSRwCsT>Q)Y}2@N?Y@ zh=&^`oik03d1-(CxN^dOc`VZ)ZD_YiknAMq0%P@U@Ky4Tt1SxAO}$+#!o{-^eev*ZPC5sO{GB zJ2bi+6nNo6$0V4!!*CDR*Db86jk)@BHW}5g1@uHF4}WI9MxOo_s7)I39NB7HDLo9} zPvp_ti)Cs=5B-pKe&g4Eea865a-htdoJTdOrAJ0}nD|k-Q8sEKZA(pl$j5;^kohZtqjqeY0Y8#?O z5Z2hG?z3_EXAdbg9mXhG1tZ&*1#ECixA-ZAY zhQZ*L%jjzU%zp=>=27{}7DfOD^sg!j%8bX!8J}6nidETbqtrEps2ESMMsLKM=W*~^ z7YqzSN=D5XwA?ozO06A2J=WzC2P*fXcIFznhy;7^9>z@OdJi(XpN(=MVrhQvhkVk% z>|hrVI}3Y=qChObE>%MaKDxo@JgJ`Mg^o+focoP~Z;@3{*+JnJ#7AsSn@xgc=JA_> zy_IbsEr`{A^YiyId^7$>(q5WU`=nZrg<3~GgWU_rbNj}40!z$!l9@2HR`BN2qj~{4 z6~m|vLJB?o7z;(`uiNHc!?p?fYdV$2CA-$(Pm!-2)_>Y{6*vvfVfO-#obE1HIXQsi zP_f?c;M3*5p+&|6jEixyRl6v2o$<~{;xlno$ z(sBJ6ot4#?HTRcFviJ4VzYbYG>0j|utCPeovgQ!-DQ5D^6VUbhGKNh+IQcHye2qb6 znQ{24!@@;T{nDRr{5$FV`;BAJb!N4-(W)PLw#CldKWff)7Q5&{(css*JpawQC}O6S zlHFr+Er*s4(25usoGv@H7R_MCXea+#lkb05LO`|q*P4)k#8_Ty(-xT+Kd-Mp>;+;t z)tk&E{J-qz%isA5xEtz`YeRj%<41q#mmo;ay?@$%Hz?iVxA&ic>u`sWWXw-SIU;tP9D zNKnYcp80fflbA8gqvG)zbhs)UdM`*(bY{Y$Ii9V{=~ud(qGA2mMzDG@asPCOzf4LI zh<-b=kvVUy`YY#1pGYWiI1W&t1kJV}xEA0?k^Kp?Fs|jH{B&hQf(EY=`r35I^G~*# zZz$-!{U1S%UE@*2X&{rffp5@-S2!873bw#ksuu{t83kBXMo}n~@K#{@CRcQBoAX|L zwRaiFeObJwiXuhX(H!Nvd&&UkAJg$0?Kv7vdY@@;U~vr9UHv0PH+-1dCJ?4WnfZ0h z@q#u8a9Iv_9BsEr0?LWUcHLKN4YWRo0!HtzGpz8t1uQgFb`RJ>jRkw$Mt@B*b`EE$ zdp*MP-6L{6X?MWYuJODvL+IB~5S&q{9O#m!(l+qIbjo*|u2{jegz|A{rvUfR(~_50 z6qN7oel^v{5_r`y_klV|+TZrOFfP+8Jb6HKslLs$tQbsoN62_8W_)`N^4NzgdMhpD zubs65EwNu|)!(N*=6f$Xc4@xp70XAFc_Q$*z?pt0wi^rhYup~g*6Xex-?m`7^A2K06yHMKJmY+yoURTS+!pA5v0ar~ zKh7&yl*|gyaQZta^g#L#DYG5|1j=P`Ai_x^f{~A9);{jzLWZ_Dt0q<@$2oMrlBWlK zRe~cg8|w>81ZV#x(E)10(^Y8Dr}SqJI&jsXPZ{Vf!_tEnu8nS41pe^(=&DiKjjjf_zq0cC_+)_@?K) zi0FTuFCPIIda*rmc&6?hGWm#>i%Y0t>*t(rub&k8DMWPKXu3-tR+8-alu!`i5#{7} zUm31NcjwR0kyXZ z|5OuAq{k_u`Nh^!__Fq53fQ>r_!Xx3_d2!)%KSBChR{4DxEuC%pJ+g}j=XKWuoh@K< z8Ql7yf}h8YxzO>o)R9nu)I305`>J>KEFX}b{z|HKeLd$;i4GC4D3HhGO&R~3*%W|s zM{eU;Q#DLqPdg9`*yHTV(T83geCm-tOdyWkg!bdYL>}|UV7A}?wb&xC3hZ+&#cHmv zpA`7@FtX3;-RcS;7S-8=1rkf1W*Y1i4#{u5OvZ3)EtiZ6-#jFu&*hUst-QH@lUw>y zj=9YCC0bBraX8t#SDCWPMREeLr`ha{E|0`hYRD2@NJ4zL{2b-Tq69-)%%5aTn2E1< zwsCrFlCUHRVP7MOeM_O|Pi9!yT(&m?NX|bUZ%AR#?+pAI0_Dvq{}&O82nEE8EZuue ztDI6Bx_>+pVv%U>?%JB)(#I#LWDM&QY;)%sK@x-sy9UUfwAjiN6}9S_HrOH9~cB@K2ZcW*liee|D(;x?8G}jXu#x zQQl}jK@^l(UI-o)3k3}^D8Ln+8JF>&*-wSp38}yKpRWf4t;>>-c{5wSN}he(wm7Gq?VK^dU?g@px>r%{~j5 z+q)t!*Dwg|FPo7u#E;j=ZD2Eh^>mkiM51r<@eP&vJ^v87X_>(~ML9E|3!O_06%K#c zMsNIeBpPbH{tV;4e*h!0*F724Cw&J7hXc$IgZr@Evb!v#d^Z~!H(%O&zc*<);R2ml z0T{(qZFw>H>>U5|`LCR8`kj#vM+PgniQbG1;(N}}NNr9B5c{KL=>?0f2^o<*^(2>x zfn5;V8x5A&=V!g8MR^NbPC4V%IsQeotn0|qUd0u5?fj4ZT(dwfHNZcLi4 zl7v`n9W#nt8}Wp+m_jUEfK@aoTHc?mf0E{8L?qxFM$7zoo2pi>MD1X%u|+fIS21_I z$q32?iT@rpFhLL~7yaeuA?%4}lb{#Rxf1pSgKe(u0gJGtC+v$s00HvtB6&?qg~-?9Iae#}T`?322^1 zs?AYb^FF*d-ZyPBob0Yfspcf7v-(l^V25{}hcqn2yi`W`C2ExDOnEGdPE_}zMF^7W zvNVJxtzu>Qa;`&0c$$+{_sZ0%wX)>sTr+#nwXaEMR7{~~OsGqdM1=t`!mmPF2U05O zufp4}7>h{_fwJ1ELu9pnkd*Lv4worKNcx*d20J~JyM;Qs?6CBMIaVIe^&yXM-dXS< z;m)$*@x^qf2bDYVj(hub6$9^scnuwFZzHbl?W{(cJb*j_KOnwdkh)GcEJ|>*(;NUo`$7d2bz6<<_+itDu5_AX~bmrP*{Np>#+KNO!}gQ@Xoh6Cw!G zNVn9cLAtxUyS^Je=Q+=F-sg?+zQ6yzF&t;W(S7fA&o$$k*P3h2!ll*39iudPIH^Jl z6%Uwee#|{SB;En8Y4mLiX7d2eDY`XMB80vO%$`2{hZH_M{5Oa}0a$es&>r<{wzg8l z{9W=kZceG)Cz;IK@6Syti%PmOq5Igxx_sW=;{HrifO+_f@(O-&j&JF7KT*||Q_u;C9edV9vj<6J zqDK%|>0hW<1)gYPc}qJp0MnU)J78I%OhC^B&jz&+2ZO3`x0c*j^(>>Z!05$Y>?iT) z+`KP^;z!s+VSG(11WU5?DY#)P6if3?Hziw55LrcF>P>y@Uy}Z3j_1$ZJT?a?<0WM> zuo6DCApo0oQPZBO_<;#>N#nhrCFP$m-uxh)4#@^azYOxiBY*QkHo>r+p#IYP<3fee z0SvSK7XGM@ItTyRmzhJFDhq5Prw4H2S@E;XSwgp)H;fjeCifWZ1}{eNNp??1jI_Ce zv|~v05MEb!O5q#b>DH~L zJ9XwC7Cv6A4-sDp*iN%;4Kc4g)&Na?D6O=}{LnWziPjih&=J60>fBIO&wi3JI} z3kQsey1sw?WdQQNh1`OX*DgolAYu7uU(@~OjrlO+cN;~blZBiaYY zP|=)TA*ivt`qw{g^vIUo*QmC_`#eY3I(qW7m~!9m1J3-6Iv-)+jJwN1Bk8mE;Ky*? z)h}(HakFo?RitEFt5&lJO;?8pz%}h_R$_7612+qnx<$x{|5TCxB=VNocx8b@tp95- z|3AQlz{pUQYkj3EJ+n`VKuMPJfBiseZJ?Xq=YJrce~ZHa7V|9S|} zzIvC~DvNTJ;hFHb?^BNIwx>&%?%vV_PqkY^+&J9d@$XK1?j&#;J*ct8|K!;cuOl@7 zv9$d?I*xBBGv!ipAMZf<8n)yIXXMB=x)5#;43~I@=7d}CfpQ_4<8@ZTA!U4_^ABIZ z*)rO_+ZZYh7fgj|^&u%pUv=Nar>+m{bHMsgTfafx=@$HVoSgHY@Jed-Irc>Mv+qVD zKRK&^yw8i8Ju-^?Quy@Rpj>Fdc5qCUl7ri1xaGW(7s$EZJ88Q?(WalV+146d?wm#-E(}P_?AiK! zL?Ie(*sQN^WVrRwG=|y5SEj+dk&baKj6Ujc$W4bhOW$WtHWw6R72uT35UA1Ycu%{l zgfRh-ka=r?Kaq5{KjxcWyDDKqfGe1Ax-YjE)@B&i#1Wf7SI6EOcMb@}(SQ7;fu)s0HdYk(kRZL6Tj zk;Bq4VFHmv6zKbwm9vxSRXF_;Kayc${{go%roVJFRzaLbyWLVryv2`FCkrE6eMYOb z@cX#)PN_G%PTyZU_<$Tg6Po&-D2!S}A+Pj~|TPbj+Po7JB_4OXY`$ z1$Hmqt`)x0+BGUGkjHszqvzwTTJtgGIj4<*Crpr&*bPooo$>Zo?oUTYBF^=1l1Zvs zT1x1Iue1}(&#IBOPsZJh4t`m*?dXlcRrY+(N;j9ilvYA`V!3qYIb<9Q_<);04%U6! zS2$gWg?~Ek7ZdRR21Uhu1V+%RHxK9j!6Lm|iuorNsnCMJODDrh)|@ab=n6M9IQ8@~ z{$C3P%18eyEB=$Ew+AM|A7i&u3=2A!Bk}EcAgqP9Yy2}0RJDA;`W?{xSsP}0oN#VX zkaaAxAe~#I35b5F#x~?n;7Ky|uV)K){j*M=*v;L{r|t0(zPzETjoFQ>khPzPLCM;# zXV>1VKYNZR`05R2#Dj;Aw8@Ck5GH);MbfqRzAImN?5)vshiwN)id94(xul*MSJ&8( zU8!%Wu8madNxwqydW9gLKLkyeKXymH<)s`>>OWBXvXH>zKoHlP;{OoEqs`C+<=OBD zPobIyEV$JS8#WAu3HLWQ4M-)92hOzE{EvFwyTkFTZMJSx34&OY;3_{_|M6Vl(ZHWX zV#p%_wABbs&sX^fOq8#Y5<&O>yZL1vE9JT^SPu(!tBCGYPHoZ9+POfZuv_)0mIUR- zLT4|(@0?dN&WCTZ{{V8+X*Ks9q1@)1ow;fZA*LtfU%ukQfA$rfy@*WfuH)eC$tt{A zwJ$GTxG9(QuZ#yHoX!MGr8aS(Id_k{Z~h>5bRmEZ=o;K$p#NUnuayV%z~BGI=--0~ zul+WHQ{Y1f^bH!k0NU4aFV3d8x2y2KUl8~&3A4ccZ$0G zl>d3HOrZPUn(=|Z_rU-@r2lK_@B#^ftME?$vq!d2(uy2BOA!WbWgeT3Z9gBxWC6CaKGGNfArMvU(<*4 zew{R4J&0Yv%I=@as2{s!PwzxXU@yf|EiInr7p59W(w4k%bA@`%E&Y#=qv2rPzq|Hx zKz7YRuOMFE?{^+pmn32^S4$#jdggfm#hBQJ{LL!q`|JtX+Zl>~S~tpnBJn!clP4D* zMj??)O2NRVjmN5&9QN%xN3l7L$?q&cf36)(80H7;mOz3Ku;yZt4rn^Cr^I8tmsn(xhan+%#GQxZu&cay)0??Hts?^+>auqShwH|x;a$#ayKt) z5^lUP&qWXIT`|a@-eF$|AJxd2HD*>tnf1lVBPrgRS8Sbg>ho~+87|JHP`-5vD!5|M z21d=XY*fOHa9O}LU!uYx^Q}fbxkRjiLrp4P@Dk! zG)+C#3>Q0PE?7^f1~z)fs9u@+t|!sO@#vKQ7DT+fI&&4y;eN|EdX?G~@u4Ya4npH_ zqn{>qW_=~h&n`Eq*bqyZYd;>+NKmN@9eUxoR~-;+EwkTv_g)c9t~Ga{yp*P+_8@r; ze}(TU+f#jr4jVt#Rgb7K))h*Ndvfs{0&DHL*7~-$)s#CXCH-{!73%0ovMl1?6ov%K zf_}f1+FkpI-H~_xFe^J0G^v{CD6Bl5ft(TSmsGRJSPXY9jiqr*OYL{oVs-Ri&TJ5E z7Sub`QCuuR=9E;V_RRA6DS

yI&QO5^vI@SJ6ava1vZPX^0b0MuAiu_!&It3HrNFH7+shT?WHFYB{^RrZz$Q=a1lihnnl5%Z_jKTn23}uEOO}PzrMaEoAX}x zcC2)PDoe}v?{zgtiF%mzf14!?tX21DVrXuo3Hw7eqKV|fjXoy1>~l!UQn>XHibq;My9nyGPgE2HxH&Ms=_F_mX?+jR)Ls7Gi5l{U(3tOaJVdmgDfG9 zBI*w8g@a4TQE5zoaH1NR-KJ2AEc4+KiSS#uZf$~+%tEa>la$S zetm3W@)6X`;V|DsM$&`mNoYP=Y?}c+1H(tzM1Js8jt+c8a4F^cJ}SYPzln{_eE*)V zFH<2v)8+JkaK>Fx6pTVb8UX5$>s#znHq$LbASzPW3J-;clNV^!rDK_mS8Kc^2U?Cj~fQd!NdIks`YV5EmPV3To6*ZbozZ-T1&QE#=oD>&9h38m8oZ)FY! zPy6|rO>mIN={vnzF)%T89~$wAb);WbEi`ypD0TWcit)Lc6V>@8;r{TkFWG<-K#k{5H2?A zszV-)fs)|qk3Y@-Ja(^EiH8qC%{KwDE9`&Gh!x+p2cd5oM(0|t{c~aqy_u=7fWa2L zIs4a$S@FT({z(hFX+qn-m-wF(__w9A;(HKA*#|dx0Wlii|K63S?I31HZ_QKjJMf?D z4*FVhFc|(nFU`~Tk#gH|G7Y}Qgg+PN-=B;64+{#~0bAoI$Nm3#u4M_)S{_d>JEk&+ z{k;!ft>vJaGrhRDh#bj3S5G?WPGV(Oq{6=r?SH);OC#k9WZY&}$eNh>dkxtPe`G%i zFC)`%zW*;Kp z<|sPjiCS;joO3F{|6D28)hZP((;s%o(fo6R#KS1t;Epl<>+t^9^ zE`-Cln-2ZlxrgEBaY)~Tu#`Y4_2~<9XzBwlsU-5-pE|U>%!{T!vY9pyU`Y^>kjy3M zp1!E;{MPDt5^~(WGSPU%#Od%JW(EdC>WF@b0qh+p)1Ht&9V}-4e{o9Pt2HM%l94y~ z%HyD?j~n^(>=s|;8%$;%9+hyrUtjUBtH|faN;5oRA=WoA;Q#v!)MgV7$XVcnA)k^2 z6;{Mu$;yVdzZH@KLl5o(#3PwOZxHbd%2`SfP#5W6X#Muh`l|;bE6np(OG?**|`Fxv%cpTI4?2@4Aw8Hu^mA;&@- zA}AyT7XMkksvV5;4bAs~H}*tPu$QyVf%OG^kE zqVL)Fr-PwKoCYW78#q4r(BEro6C?t!z~|}erhrAc20qC!T0sXcn$jn=$daYIv9Ym| zl9KXpKj+eBk^g;E{zPg%=nYmnGCaJF!y&Jzh!dR0e8pnU>6HYZ2hSmnJUpBhq#l4S zd3sOy_@02Iy8;;Uts&D~zRCnlW+tX@GzZg*o3w(v6F&QUvq{S}rAA#MaMAdkf4?Wl zI`u{qL+db+4unAOKu84JE1ed!Iu}4rs`RAVZeAXi;0J{aS=nq%WMO%kHupRd{`92K zfDf^Y^>e{_PV2fdG3@t1$1;kp?}MBumD`%PZOEUGChe|rbwuJdd~Tv-o#WVbp8r=QSm=J> zASNdE_b~o^PHU{7a{(!*(SMsGySYwpWG7&!qWyW<|K`ThA0_2X)(4zg`+pht=t9pE zk3W&&>FJ=@W{dkZ6~I}a?$Lj&(@a~uEl-6_FyH_5j{i1%WeKpVGqq|wPF?X|+g>|; z2(shtY-iH`v8V0N(Xs4$WO)Ajv@mdR`dCXxe*TAH`rrKm6r*~JO;t{<-E*tcvz@Rx z-e0`=UuOXs%5#jmR;GV^Jpm!%`-utd|4(0!2p|<3_nbQE7;5s<&o4{Z+_d|zkDrZr zbVshxpzYK?>DZmVs&*PgSH2;!;QtmTph(i?MMOk2wYA?BE;8=iIas=F7=8M_yM7#? zc_+~gidG3vB^tNa>MxO1)HF2Q|LX-%B-c1NUjO|P6g%$!BilgLJMmVYeiZ6)^&jiO z1p6KI@c3P~vkRpPY;A2L0s;bH*WG}9tduF=I7qML^w(LwQv_22;)Y-fz#`;>1!E3r zqy-BG=L}|!iJJN@C;~b-7!wiy{pK=V!}bI@*zAR^Wx581PvDQYxA$)lu!T_cvMu*j z7ARnW#RHRtSgqF9!kdL^pCrMHzy4UcfhxD{bt_T_2`gZ&wbejp~6sZr^$Jq?%= z)6_z1z;tVL7Hq&q$_z+x4t}T0#GCB>p5E4n4}=6D!F||ad1xL&Z2hClPc=of) z6llA@%XsS|394VcNuI~^OWiGz$Z6Uy{@C9hwzS`z>k5f_6wfUK4i`emzH0Y_H?LnO zgHX9d_;P~)tScOabui1+c(hA*tZWike&$S<@fMJ!~g6;kF+2w<7H zxs^ei=tq%pkKNJ$heUTcz0@YCfSZh0yCa`&UOJi%l1BXNMCIv!y?i6(^`V0NsGC|j z5M2_wyeKhCfR!r*(a}B#(Jl$IsujJ0Yuf?=gB^$#ONGz0C#ui=3@dTl;&)uh29aP9 z;e+kvi0J5OBn}QLG4nl~>aFoSp$#K2A0)Y1wTUE0MV8~$AwwVUtG65GtG2jzzPZW4 zAx9k2IH>KJ_B0)IQU1dlZR>(|4ABRTqS zt*uhwza}SXZr;2ZP1(O735xZ}>XubXFUfeK=(d)JRkjxTdcY|ssH^h9)xiRZa7Q)R z>SC?1Nwy$>F@@wh{JobKO6Xvza9i&)7-@-aAa~J8fGVFNTq+_83Jr$h=qmyaOOHW> zthe0OP@-h2_k-LuiaTzq;YEPd*)I*q!K>fU!J?KuQkBo`=$`d9Q!)Y9T#EV19`SsX4pu{Y@w^tE7pp<9;F6$<*`Jk$(KepU#d*ZV2 znKNgK^wA(0l7Z*?kUP8Dv1im8n*Z$EWiFH6hmfro>H8dl%I~fCw(UI5liT9Q(YZzX zdLWLb*^@4#aQCiP)u$Zz&*dZz%o_voS?|iqc&m=KhZp*?`o&sxb#;4T_t`HGfkdf8 zALw~5_qT>`c0HiVgnG`xP;u5u`7Grpwc?MR+pul6&0V(^(m?H*P4g-i&B>>Vzr-}3nw`ynNTr!j zT^5qf=@j;x*RF}m#Bsi+JA!My?Kuk-oZ?GaLN1#h;&?1Aw4zSD*;!fJti_^5kzlh| z`4XkzhuYe^5=Hu{LL1+2FASIF@;fZ4&RGu?aQqHvHHRD+F%8v6>8xZ^aDYhc(U-6H z0B{*57MAa0J)(dFh(IKxoUNJ++Pt!t%?BAl(zA+=DDFHBtXK~Zk8rp--5HOlV0p-@ zmwan)XDjGw@Wmk5grwFOFf@P@Dvb7E&E>kTls|~)vHStq8X~#qzCWK4{(iL5aS?(T z6{yOA99?!0!F3bZ)cp1fZ|InA-I9m(*9osL25Hf)l~HxrC^rd}rwdhCTtHx%({k*y zNEm1foBnJFLTt&+p(ZnKW&wdp#pZe<#FMox|6UJ`k5>=+aT`evc{_hkw-Q03Pb2D& zgp_cSNKW4a6}=Cq*@wd9eK9d62*7Vv@1_8Ru!hKzL$V+Rn@7IXVl+WR*3af=1BOJ& z*eei%-iOn&hrOl6X-V^ffQSe(@uU!!*OLoSlrRE$@HklY`&)|{V6-!>r<%4kIobWe zN`z7gTbv}m)w;rLaUemd9(fz%zAykV~bLxOeX!ELxRhO&wT| z^02t+CE&}FOnzh=YgOwtYbb(hfc863cl{MZB;zxA9;z3BL4+W1hLLhEKhV#w(*+4A_Q$O4`+Z?CfZNOSGn#v zhB1_@XTLJ?M9Xd6uWFf6z(IOv+_MI;laT80mRAJR)6&*?Bs9My%K*7A0)O7G@s`z@ zhHsrA@KAgRnSi{NRUV`Wx!^0!hl(UXFLC0~@Eu^Ha0XAdla-MVrc=Mf;Nm?<*k>vc z4zJZ(;)Z^qQb_OUFNJO3H@xH-`#>S!@Gu-C#Mwx_@kvQdR*9nCym{Z*`S{=#WE_AJ zWejNQfobT3BE!~FVTZ}_;T|llq~4E8%&uXr%+4B&6fWFnywJ#bt?q{1+5f=s2s%Jt zpuv_B1Gn1+4uTLLU(Uz~JEwhVuu!Se!Cayw8STO)%A%9X&mE={CYYLZ{=UA1aX#ox zOil@aerq0Bk&u{zV+N4$1t8W4qL%y`9D3WmaleQ=_#$D1TDh2W{mvkcS0mP?_MmEDL8}o1f%C#beGo zh6^`%X<|XC&OLnXd26f5=V#7WE?-U&KHg33)I(yKVw1iEMY+ZZ&)L-H8h$<9-AfTe zJf=UEQYz)rGBR5G`%9OCS+0ZqU%uDlM^$18Na(?{$vyB)OB<}d{{BNZ7dCe2=L;f^ zG)!U5s(kn~Q#uaE&>4mE8-V(j{(hB>*^Yjwf)s-=nnXAJVh>8&Lk0KP5;8KW5sNbX z^vf$wF0Q^i9b{nAVF6Uk`JV;TdBS1m2nf*!-lPy4vR}n;%gAbKURdO|l^7R6aL}8%Yn@~z2lNx%* zot4ooP{moj^ATA{3tV6 z$xTT8Nx~#vE_@p#cGf)w;Vdh7JfSeQO0s6RQWULh7)H$fBuYQ;Q;T6 z8+N>n1Ve>n*iTH#yDIEWx0VJKK#?>XiLD_5d`Ts&3FgNZLh6iP^WK&yvgz&)SyT{6dK0hH_NN__wUpWK1<_bX&%S?!4erk(HZEsWRH z&<)eQ0eSAcFgdE*1~0&1rU9@Z6@C^CEq1r!BH$r@LnaVq-Z|dYJFpoln1TzyVK%^y zL?e#l?xpWHnmIgoSs~m)7O8EAkjfnI%pWOh8Pm0s;cHOVh%--7+QoH^Deu4|G&KNd4K1_d%m1=@+ss=p7pGCueI*$ zy080QQNpIoK0FJ!kOKQB-CN4n(M2lNl?x=MKd?&RdG()n-i3i!tI|A)p(o=4#^B(< z8PnafsL?rKA!W>TWv{vbyaIM+ivI|bX18JsnEi>yXiIXyyo&GFtB`Wdl#liGEO1z& z$3TgT;MqG@V+goY0AB1u968G(+q@1NS>O&RQH0LEj!8%e94OAo>fE&L!r?4q@>`kt1o5d?lrzFw^aiZtg_b^Sk?qQN*r_u zn1Hqm8h}4vwDxKN1i68UiBiBiShZ9<`c5N8tE)OEz`n6N=Tw7N+`$xp&va%AUbz4M z*})T8V0-oee8?`C0r;a3xO(pGDViGq-@z)p?DVc(g9G#1jRSxKLRY{J%)}9DyFqtE zRF{vE0}0j@UZ^hQ`fN`=u(^o2vN7P%fGsUysi3k51S9m(p`oG8bDBqv9NG5^lwWKD z2L{UP=+yR4H9R#sd-iPb>VTW%z`|Bvo{`Y5If)(v=SR-~aw8sCEeMu?vCIZ%PIlIA z%M+@z=7%b9VDD&dYBIfb>)l{E!lR#wwIi2f%`tcScWgm>p8+f1|MVWl| zng~BZGO>(LO~p67N|@=zKm~PoP5^kvMvIpk>|)-($%czffS`%f@+6tx1dh9@XQ#wp z&Z2S-$-05(d7(SNucN^pF<&PpR0HOY1AFcGo+0bdkm4os7bw=aI@KHxGrgw2my7G| z`s6^nWAUxB`#%mhG&Xv)zU$h*Z=dgs$j`s|WRZ*ojw-4WDaeWgoWWf5 zHt_NQMvMWs6jt^=@(&J|1AcRr>*bWC4{#zQHE^_;+em%%el{4{-Tv6w+4KgW#$tD= z>0dtgn)k2X$awUHFX@g`^f4h7(A5q=ShefY{Pk*(Kdz#^J{kM;Q66~2%a>xs5B(hi zi1P!J6Z|{b+sjtQUGSW18G}Y30Eg0JX;Rjj5>~U@FtY=Hx9g%N*j}=84HN8@-9hh~ znwg=GDTYG98M&n8_VXeZeE+^KUQDIo(^iM53Roj-kE-P6V;lzlIUK2Ufh>7&L(Qw}4s6XleqbOVe%QzaLc`j5vsF!1iLUSC(U`kEQee)YcclTT7lHsB z(awGI=1)ksjkf~2$+@q`BpO(k5zs~5?aFEZv`N=|1%2UtroSlffc&qK5Y7V2K752k z^s22Y);gD52?N@Ve7vKnF%ghJkL z$bm^X3G7{K?(ZOxVbl}?kb;CGyy}kMZIu>tFLNEO{ zA(Q`#kozwpD;#HDKQBE{*_YaX2qR`}53(qnkauY}h}#!Z%|J+-*KNb7%VEK?M1ta4 z*rp5~SS_Ke%0#9s-U=6s-|McAp0$zin$1G%+x17nAntS>y&0eWNKWzW2!X014NVO%z$)06@rg93%XV0$2ThR(NQT$x{2V?;KUTX5izcw2 zhag61*0!Y+zo&ycHE=ntua_m&M2C`3m)C21# z9=X0HU(-1pP**!W-(^j5z@pglNu$!0?`rnAw9L}fd`**f7_pGXAGA|qiJvo9Wkm7?74@&2#9wzNn6WOYJgb}+V zuI;_8iAQ>C%E`{!S&~Jcey}_i(6P)cl#r;(fda?wDw87gRom3dX{^3`{ zZn6+jhg_<-kaNjoIM=m-r$zsK4wA_ytIZa+nkhc;?3>jDYhrql8H;apS@OytbLnP2 z_rfV~m1eh!tQ(&{EYW1L?MlJzOZDW_?f&R2B$f1|Mr~d@6WbWpS)?&k(n?p8QC6M` zJ)dB8R)|iG=m>{BdgR}7&8s1z())3NkYVp`?jrP9%q^XlUC@fj< z+zq~KRl~%{9K#Cg*Q3LqZFm-f^{;Gwx)t) z9K-DB6TgH*Tf=YIsjPv{)?o*FJ3Mu9lM$z^A^qdjcD(UwH8;0@KL#Aa{Ra;8fM6{V z2sL-fjo*zqI2!jaN0w0?3y{97af;FHDPtW5`sF5 z-8w3Qk3Oc}{ZmP8W+2PhJ3mV$1pdVMTe|oGZ3_o+o0J2|Xf^lC>G}Np4vt^ClU#yJ z*+bU4gk(i^*ipxVYj|p_jTYR%a^-7oaLIgZIY&X$QAA|wJ+EA<;Ex$|gyxz}^QyWk za(IE^X4xSfUB$rLio0)m$$nE)<1|kBNz)pifbRF;w~u zQ@^5|+m)w^5Q;+hQ9FQbmVx} zRV7@C4!6dBCm_5*DYhYI#P#vy94Qk*=d*d73bBTeV=H1uw&s3irLF6pbUv_%>)?A7 z=UJbzi@W3Z}p znQ|v=-L{@#SR&fh9jIp4uxvw|rH0GNjjB1%(sAV;)JL`ZksYut0jR3?BCJLu<0xsi zs%d^wUe<|rg9Yz0uUHeweos8BJVcw#ts1Ie<_;k!d-?41=7*-S2{{vopk-r@&nAa# z#p$CNJn~WnHH0`{pD6pyg3ep8#NWv9I_&)5_f_%z$b@^**WArcn9v`1TagKv)Am~Vb_K`64TJ7LIn zy_TVyl^{Kwc@1eZ+w{BxAtrzm*DEA}#d{abt4lfM%mi_}_E9HDlXpB7hurGX8z!a}vm5REMe9w9Xh?;$+iSS?hPwIo+~ z#g;*nYt;X@2{ZSSJOz{D-;=FmlW9HFTFptxdYJ8>vr|HtQ@> zBj^Q9j~#k=$o8-<2fdv<@oCqqIw!hpc$8+|rG>m)wapJOf=JD-IA}?pz50jcBATX_ zi@-q!Mb#`i!oO!9CM?ZYU7@F8MNeSj7S-=g-*%YSkC{XzY>dt~RqZ+R$L(0$jD}sv zTb?&jN*0R?en^dN`L@XW7IvT2e11Y56!k&!D2GNAWn;~QKm!Hbnxm9h7WjH}&c?Q= zX85I=HC3Lp>P!0;G8tsYB71im!|TWRx?h-lEn2{M?8Q9G$;M^{SB&IJQ4c5;jhr51 ztn%4rRu;?#vzYf~XGD#H<}axhbQ$p72y$3Mree>FMVep!NV6s`!-%std;2c!Ux|%7 z1*7ERyl~-l>TfIT!`FYiaN*t7@z0f$M5Eb3_~3ODkE@@en^_3LzC{Z5x!`?fa(F;D zldYhK64}7IvGn!w<~pSpe=>5Go_D?ktJI$#r7U|Kb9PwYp9iZbUwTeYBp8X;Q}JV) z-ouBsUE`DP8++2OiHc8LO`(8r8l-$T>fKm7|gt{8%97%gC zvX9Z9Ta{D4I)Be^nXN8Nt0*tVEqV1>>$UoD#@WZv)FkiX1cu7)Cd_MBx78>4To;Ex zn$PjCeL7)a(j!Tl9eCEQcP$O@PX1wkq-v;haNTZ~N|?a#c0`*B1wu}kZkX&`jT9Li zTxfnn)`M!TQa{vI;O5J^KmFw3s65UdWsLM~Z!Jc0&2;@%aSGl`O01Tgo#Yl&^(ni? zS8Zcbq!nnHog&PynUM)OKLQ`pZxLO4As8LPuq7Y-Ry*{3C{cWn^in(u#s2RDo zUiCc_72Kj007G#HD%K0U^0;COmgqi6H`|u5deRAYHe;E z`}QD0!o&&*Hz~{F(+>=+YAXMpur0Uo`Qpe*bM+3EDLa)E4Vd>M*mEsJ!@_acs*Qkmz?c98bbG+OznE~J7BKesp7IzBO_aB|W-(NCNI z{*XSZa$zgfXyjU?C62LRR%C@~f*4d7q_yR{PnE=w6gr`Ot-57ynO%w+Q?Oy~ztj$`4 zx>@95uSr;TQPKUi2b^}bM#t#2+S_pElGuqF^XvtxeFUdvzB7W=+(+p~+3%m`%J4sa z?H2^*s6X=;!%!)8D_%6*`a~5vlxH)7>hx-3p^T9ciUnqCcVtb8+$oqB>xUO7@>-qw;rIP)27SMe6i60o}@dR5PfOt<% zTb7WBt|Z#$0vcXhbvlFf2&w3|V_=QX*D5om`!3EnPzt z;>Gjy9I;Er>J4$L&K_n9NI)7bXJUA3RK5{aBbwOk&hbNd4SRP<{QaEA!I42m9(R_8 ze7UvF+;h=|oVHO`(qXr-)0^ecQXBH!8JiGAN>kO^z|N7k3<|_m1XQ^kx6^XJdPqPR zB9vxH%TwqrpY}1@=`cfR+YKAz2Q?Azui1VJ{p~{?h?*^p9i&XQ%?T$7Y<$Ar2@5Qb z44O!O_KlcarqB_`e&b>@n$wpVWYp;G?|M6`g7LDtv$H+uT)1r8=W@!#mBsvH8#R8j z0Rd;B*7iva@9g$GDBV^r^|N=MP7IO^dtMeAtG=4$*Au>+S9YL{{3C`xs622KokBSs zM1PG)5*YL{8!ca!pV|&IAXj$^i8a~8Sa%~i#8Wp<-|s;Lq0&N55!_~`baW>xx@3dc zZER1U1@@JmV*2EO*oLd8eNSgV5VW1(?9ph{_zmSGyRP+NsxbJJiqgtwsUSO1vGlc@ zOtc-%Q#ky`=3F@hzkE44szd&M!|eeL9fICWWKc!)mdv6$^w4>^%7qTX6p|N_L@S$} z+Tp`Dhn!MT>2Oh`JlHP*Soi2OBgOX6uiOf#pfVs?#ZQ0illyGzGrMp8#p~XvBd!gF zJUxEd*F9GJZd(gEy>*6dMH{k~X1WBmH(V69&+cS--&U$x)QH8>2dm%1o{+S|J}!hbi7jKYD-R3Ce^2#ED$gUP_@x@fF~nM(nhizt`A)zv2?9xRSJ# zJKIUp$DfZgpY%{^EQIy@WdmN`Z@-OD_GKGU>-;as=W$2b6IHaZA6k0yKCc~kX`kQS zDeo7;sQS(WJu_v#>XiCS#SPiz3Z979w)t1??qx>nR&>lZFaGqOr35DNBJA^#WD{E- z;qLv_-?G?!9AgnX8iLa68}b=BLUSY7)I3L4P~lh==J5jDMz6|O+TTEKeBA?_wWlIm z+2O?S50gYz;_%qX&$FK}=pxSJm?AEP!7PV60nC=58vf7$R>j>Bi`Yk-2Q9ce{P=&VClo59Mh;DQW(HeCm z_Va2LlymeaL`qWLwO-k1Zu`B|jY`YFNc-F)Xmq$0Ch>7F+}7N>X-#DKdJjF!CVS|Z z%oh?gBrK*nnP`#y71#G{Um&I@+pw(h^lS$*Nv9xofv1#5^AA#;5&fxA=MiX@GwxgE zv1pVnG-0bqOa1<|K>i^S(#*?1zAvYnwwTjNy*S6fL1;vhHjFKNuYwf0X(sJ#F|0G> zAK#dqL#*eBm`6cf+`o_ zpTWvFrrUX`W~a(H+(3D8T{m(TX1->RQ#);gKc_89yBI#(+BGs*;u3A_YB(Uh=HBvl zL@PMYaa)_ZnV#-Gl$MVCdIPdET0WmwTan+nsAO_H<2-kDX8H|1Hh#|sn3dnXTtS`Y zos$N;#=9~8DYu@@dyc5S8nTTGs^jiiq0M?$aXS!5X8SrEL{$8|3u~3C#dEij)&mQ5 zq*^ZsZO`jT!n3v_%a?AD5-HgRiILtMA0fo(q*PO!?}Fc37^FPDxB9yQ2_of^6YnoF zwp8Rjd?PF^!C|r5k(IbpKB~C&44+!A?n!K?jA>L`&{GPFYQ-0+4npEeLA4rNpG!B+ zFb}nb6(o_tiKqo6Yu>HcC5zY7M z+8p+~k{SfXW!qJYxVyurg`uz}N?bOS^U1E1ZhSu$pKMRn&A!L_-gsVhdAPy}B6gvW zm(bDMT%7LyB3!I(Jupto})Piai|D=L_HIN~ES_dXcf1Yw@0vZq`!+ z?yq&&jJ=zha?oINE zB>30oT6T*i*kL4?98t(S*K@)&WVg2`lod9WuaCGsjTLwF0<<={qjlp+Ays`$UpzB`JUqYsyEv^hLrCpGKv>mE)hW9Vwfy z^0|zLk7fq2a-W3k_s3XpQVH_2(Z@yJ_D=3!RoKpg`Ax%`oMq{(u|87v z>blq0HtERj`9rb939`6Z@_|iHk*;Rg^IT)U1M?MEDtMQohOsmCX32i%jcwK+W}&!b zmJZXD_fjc=omCkS5}LC1^buM2=2%u?q!&UC|KdER($H%!N4te3Dos?@ihE|@x=@Ru zMFT-NH!U_{rir%DSCrMwJ|KzjD=>{|W*o_^8d0qNO=f2c_9?0?xK~Ecm1X*A*h_B1 zItvUZhrvVw%uh8i*||yot!4q4Byhi8M|k;pj&sn+8fKM^BUVc!R!}k&>!)Xk(Ok2! zCwI_M9zS&HusgK%U{=<0T`qA}du&VfX}CP`kGDD-%TLT5PhTg8{ciLj{*>pv+kO4< zg6!xMgZGC|2krerH&o7tJe2J38`nL_0A{GGHzstYAPv=WW^F)b!ytBxM;@~!V zwmljU6wRjEs#R=_AVi2IYRsq6D+U?`qI(|2+NrX$56~wJ z8$bOt9mD#dX}#qUk1MUTg->(8GhZBVtegS@yt z)ZY%otWhc7@8rmfmsYy$GD>h{xFzm0{a&4N;QBXd_eR~G!TLyldKOo1VO>nPH;0VU z*3`88tbcp1w?81u-#=7>`+fKWze8!Hepepvcsfm37P*PnI<9t8$&0u11(luO_-fb` zbWch~7Z3bjvf|oCay~&K0^!!vl&>B+LZ0%NPHW#ES`>@jE_~sJ;7R%U{+Q06n8=VE zosk?+dkdo3Tx>9{bf%qX12K!UcksDYZDZ5+Mj|y{)2O8IIJO*jdF7KdPIUoI zcjP!Dc;pEgTi-R9^e5yxC5ucMo81)K+upYxZV+8X06oPpY>q8MTJa0${LnHlsLeIEz0m3AzHvZ~{a|vQsf5XWMp1*FbGt#!fyB0N zAB5WIHWB8}7$vRbESyr^+bj%NPo51b;WLxIE()c1Ev09?>v6pKd%0l7lu*hQ9lJV(DYY%gr0? z{q#)dQs+o|7Djxu`w)F0xG>x_)oWYT*BB5p)Gz0G$p;w}Bx7NLYpyj-+_q4|w=lwNlDOjWhO(aIcg_5PcQwAm22vYv^@4zH(b+W&iVSA$rrTr zjY_gv>37RJ&VD=6m9(1f?xG%alT#=@RAh`{h4dNl3->IBsdcFi8+W zCa0`lEz7H2Ni=)3KzqFI-H(uRzme>TXHPp>3p$mfj-_v)bqq0tli4H)$zDV(=xX%Q zjaf`G4np->z+B&rNLOMwZ+<&|Z5n$agh373rh0FMuQ=wj^_R>?cATh$uu6;hA-syJ z?DqXh^xRPCW7k5vg1I^+xNU;Ijmf_Pl~iZQH$5C`KNVi(`OGLa!3T$~2(xOC->Rl@ zxHJN)pexg?n<1d$^r}~T3pSJ@pvn!e%jxJHr|wJYTMKEM8^fd#+M(>Q306Ay`B`1` z1?DQPKa9m+3r`QohZq&DIv*M~J2~_jhq@nePqHdxBrB$+{=n;flrbMIZ77qUe`CoqgchZOGs4AD?|t$X+)1{(e9QqagTMud z@Xs|aj)dN;u{w4IpDs@)GT^G?Pz510JF3FQL+5qh=KwoZmk+F^J_{64@=AAD8uszx%P(3oxtv(zHsJ;IpIwYQHjp&rCKDg;L>)8e+fy`n7O|mX;Gv9T z{wNmdy`084CdX{fsQZ1a=(c4w^V+SZve<5?RqHb}Y_i6{;i0{sQce%G5Hiu)k72Vv^5adb9F6X7xck_& z@!1vCoO{JwL@pfv_A_COsIj!j!%@@}kcJ7}yVz)1LYLPK-&52~3VxDQ?cTSdo$jt* z=bc$N8=;FpuSpJ(=KWqbk>A1v$Yx-<7ygh`bbm-_RdOh8w{_s0;o4e`n`$6&e_yf> zLiVVdCiysE0uP_h(_5owvpuS8=^b+)o~m(tZ~^J;%74*^KB$v|YwI0?<5PN_zjOEm z8P5*Hy?K`3+JzR z4zAYQZ~LFQcKwctjH1q^yR*}a^9zfXMs>u;{#(u|)~eox?SZy;PQgJc_BI?>lwOMa zO9_+lRj%j#-66gZB8#%Q($hiV-=DaYo9ia0pbho?vCp96f;eP1TE6?5dSPRwZF>1m zM3~lbsAs47P3z#iu#I)~-TUr)DeGc>^4E1m7Ku-d&IgEf#`GUa-~l8>t{X#tZr^-A z>O)=pF9FU68JA;f4eZp6%AZf<2c>sQvoBPnz<9Vzr6g4E`t|d3n*S6@DB1($j@BiU2$0O= zKmv7WfG~Rc&GZkSFZRETc1-y{i8oDli-dgf>-xLALB?+&Va2-o{>xo?r)#S_NJswC z*5(`Y=RYFCy+A|=GK;Y{9!dRE?jhrHGhdQ%+J(D(N#N%}>M8-`L}|_i{jHPbkUNf} z0)5}L$Y29NUamIoDkVV@#EX3&f3?ad2&7#vRxd|syOTf_lwIv#pUtZkkip>rlz3Q6 zLRx)VdfMN`3;3^wbsiN8DeKR`gl%sCQ8*}k^J^l=*wZ z6p#uodsJm4)NkVQrJi?Zvq6GADpdz%4J4#pAcH_e&FvqR=l@qz{=crK{6DpJOnd`r zkgpoP&T0L!*PBtG$N{KC_NA?-v;?G4mNI`$hl`r7%88S+R#sl|K&ZU1Py`YXgEju1 zz!A4j9|OAg=8+LQ@TvLi>!uWGNRV4cGn^`. + +The example contains one single script, ``train.py``. + +*********** +Get Started +*********** + +Installation +============ + +0. Get your `Hugging Face token `_ to access Llama 3 model :: + + export HF_TOKEN=... + +1. Install nnScaler :: + + pip install nnscaler + +2. Clone nnScaler repo to get the example :: + + git clone --recursive https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube + cd MagicCube/examples/llama3_demo + +3. Install Llama 3 dependencies :: + + pip install -r requirements.txt + +4. Prepare dataset :: + + # To run Llama 3 8B: + python train.py --prepare_data + + # Or to run a shrinked Llama for debug: + python train.py --prepare_data --mini + +Train a Mini-model +================== + +This examples requires 8 × 80GB GPU memory to train a full 8B model. +If your have adequate GPUs, you can skip to :ref:`the next section `. + +Alternatively, you can start from a smaller model for verification: :: + + python train.py --prepare_data --mini + torchrun --nproc_per_node=2 train.py --mini + +This will resize Llama 3 to 4 hidden layers and reduce max sequence length to 4K. +We have tested it with 2 × 48GB memory. + +If the model is still too large, you can shrink it further: :: + + python train.py --prepare_data --max_seq_len=1024 + torchrun --nproc_per_node=2 train.py --max_seq_len=1024 --num_hidden_layers=2 --from_scratch + +With the default mini config (4 layers, 4K sequence length), the loss curve will be like following: + +.. image:: ./images/llama3-curves-mini.png + +Finetune Llama 3 8B +=================== + +Use the following commands to finetune `Meta-Llama-3-8B-Instruct `: :: + + python train.py --prepare_data + torchrun --nproc_per_node=8 train.py + +.. image:: ./images/llama3-curves-8b.png + +******** +Resuming +******** + +The example will save a checkpoint on finish. +To continue training from the checkpoint: :: + + torchrun --nproc_per_node=8 train.py --resume_from=last --max_train_steps=2000 + +Please note that the checkpoint is sharded according to the distribution strategy. +If you want to resume a checkpoint in a different environment, you need to merge it into an ordinal checkpoint first: :: + + python train.py --merge_checkpoint=./checkpoints/last + torchrun --nproc_per_node=8 train.py --resume_from=./checkpoints/merged.ckpt --max_train_steps=3000 diff --git a/examples/llama3_demo/README.rst b/examples/llama3_demo/README.rst new file mode 120000 index 00000000..52a40886 --- /dev/null +++ b/examples/llama3_demo/README.rst @@ -0,0 +1 @@ +../../docs/source/llama3_demo_example.rst \ No newline at end of file diff --git a/examples/llama3_demo/requirements.txt b/examples/llama3_demo/requirements.txt new file mode 100644 index 00000000..2cac3224 --- /dev/null +++ b/examples/llama3_demo/requirements.txt @@ -0,0 +1,3 @@ +datasets +tensorboard +transformers<4.43 diff --git a/examples/llama3_demo/train.py b/examples/llama3_demo/train.py new file mode 100644 index 00000000..4a48c0b2 --- /dev/null +++ b/examples/llama3_demo/train.py @@ -0,0 +1,271 @@ +import argparse +from datetime import datetime +import os +from pathlib import Path + +import datasets +import torch +from transformers import ( + AutoConfig, + AutoTokenizer, + AutoModelForCausalLM, + DataCollatorForLanguageModeling, +) + +from nnscaler.cli.loggers.tensorboard import TensorBoardLogger +from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer_args import ( + CheckpointConfig, + ComputeConfig, + DataloaderConfig, + DatasetConfig, + LogConfig, + ModelConfig, + OptimizerConfig, + TrainerArgs, +) +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdamW +import nnscaler.utils + +import torch._dynamo # FIXME: a workaround to avoid tracing the dynamic import + + +model_id = 'meta-llama/Meta-Llama-3-8B-Instruct' +tokenizer_id = model_id +dataset_id = 'bookcorpus/bookcorpus' + + +def prepare_data(max_seq_len, dataset_path=None): + if dataset_path is None: + dataset_path = f'./bookcorpus-{max_seq_len}' + + dataset = datasets.load_dataset(dataset_id)['train'] + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + + def _tokenize(sample): + text = tokenizer.bos_token + sample['text'] + tokenizer.eos_token + input_ids = tokenizer.encode(text, add_special_tokens=False) + return {'input_ids': input_ids} + + tokenized_dataset = dataset.map( + _tokenize, + remove_columns=dataset.column_names, + num_proc=32, + ) + + def _concat_split(samples): + buffer = [] + resized_ids = [] + for input_ids in samples['input_ids']: + buffer.extend(input_ids) + while len(buffer) >= max_seq_len: + resized_ids.append(buffer[:max_seq_len]) + buffer = buffer[max_seq_len:] + return {'input_ids': resized_ids} + + final_dataset = tokenized_dataset.map( + _concat_split, + remove_columns=tokenized_dataset.column_names, + num_proc=32, + batched=True, + batch_size=10000, + ) + + final_dataset.save_to_disk(dataset_path) + return dataset_path + + +class WrapperModel(torch.nn.Module): + def __init__(self, model_id, from_scratch=False, num_hidden_layers=None): + super().__init__() + + if num_hidden_layers is not None: + from_scratch = True + + if from_scratch: + config = AutoConfig.from_pretrained(model_id) + if num_hidden_layers: + config.num_hidden_layers = num_hidden_layers + self.model = AutoModelForCausalLM.from_config(config) + else: + self.model = AutoModelForCausalLM.from_pretrained(model_id) + + def forward(self, data): + result = self.model( + input_ids=data['input_ids'], + labels=data['labels'], + ) + return result.loss + + +def main(): + nnscaler.utils.set_default_logger_level('INFO') + + ## Parse Args ## + + parser = argparse.ArgumentParser() + parser.add_argument( + '--prepare_data', + action='store_true', + help='prepare dataset', + ) + parser.add_argument( + '--max_train_steps', + type=int, + default=1000, + help='specify max training steps', + ) + parser.add_argument( + '--mini', + action='store_true', + help='equals to "--from_scratch=True --num_hidden_layers=4 --max_seq_len=4096" (overrides these parameters)', + ) + parser.add_argument( + '--resume_from', + help='load specified checkpoint', + ) + parser.add_argument( + '--merge_checkpoint', + help='merge specified checkpoint', + ) + parser.add_argument( + '--from_scratch', + action='store_true', + help='train from scratch instead of finetune from huggingface checkpoint', + ) + parser.add_argument( + '--num_hidden_layers', + type=int, + help="specify the model's layer number", + ) + parser.add_argument( + '--max_seq_len', + type=int, + default=8192, + help="specify max sequence length", + ) + parser.add_argument( + '--dataset_path', + help='specify dataset path (default to "./bookcorpus-{max_seq_len}")', + ) + args = parser.parse_args() + + if args.mini: + args.from_scratch = True + args.num_hidden_layers = 4 + args.max_seq_len = 4096 + + ## Special Commands ## + + if args.prepare_data: + dataset_path = prepare_data(args.max_seq_len, args.dataset_path) + print(f'Dataset saved to {dataset_path}') + return + + if args.merge_checkpoint: + checkpoint_files = sorted(Path(args.merge_checkpoint).iterdir()) + Trainer.merge_checkpoint(checkpoint_files, './checkpoints/merged.ckpt') + print('Checkpoint merged to ./checkpoints/merged.ckpt') + return + + ## Setup Dataset ## + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + tokenizer.pad_token = tokenizer.eos_token + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + def collate(samples): + if len(samples) == 0: + return {} + + mini_batch = data_collator(samples) + + input_ids = mini_batch['input_ids'] + seq_len = input_ids.size(-1) + + shift_labels = mini_batch['labels'][..., 1:] + labels = torch.nn.functional.pad(shift_labels, (0, 1), 'constant', -100).contiguous() + + return { + 'input_ids': input_ids, + 'labels': labels, + } + + ## Config Trainer ## + + world_size = int(os.getenv('WORLD_SIZE')) + compute_config = ComputeConfig( + plan_ngpus=1, + runtime_ngpus=world_size, + constant_folding=True, + use_zero=True, + use_end2end=True, + pas_config={ # to reduce memory usage + 'recompute_modules': 'LlamaDecoderLayer', + 'transient_mem_coef': 0.5, + } + ) + + model_config = ModelConfig( + type=WrapperModel, + args={ + 'model_id': model_id, + 'from_scratch': args.from_scratch, + 'num_hidden_layers': args.num_hidden_layers, + }, + ) + + optimizer_config = OptimizerConfig( + type=MixedPrecisionAdamW, + args={'lr': 2e-5, 'fused': True}, + clip_gnorm=1.0, + ) + + dataset_path = args.dataset_path + if dataset_path is None: + dataset_path = f'./bookcorpus-{args.max_seq_len}' + dataset_config = DatasetConfig( + type=datasets.load_from_disk, + train_args={'dataset_path': dataset_path}, + ) + + dataloader_config = DataloaderConfig( + train_args={'collate_fn': collate, 'drop_last': True}, + ) + + checkpoint_config = CheckpointConfig( + every_n_epochs=1, + save_type='deduped', + resume_from=args.resume_from, + ) + + timestamp = datetime.now().strftime('%y%m%d%H%M%S') + log_config = LogConfig( + type=TensorBoardLogger, + args={ + 'name': f'llama3-example-{timestamp}', + 'root_dir': 'runs', + }, + ) + + trainer_args = TrainerArgs( + compute_config=compute_config, + pas_policy='autodist', + model=model_config, + optimizer=optimizer_config, + dataset=dataset_config, + dataloader=dataloader_config, + checkpoint=checkpoint_config, + log=[log_config], + precision='bf16', + grad_accumulation_steps=8, + max_train_steps=args.max_train_steps, + seed=0, + ) + + trainer = Trainer(train_args=trainer_args) + trainer.run() + + +if __name__ == '__main__': + main() From 1dc54fcb91177a24bc6f1397761a2cf3ea1d0035 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 24 Sep 2024 09:07:47 +0000 Subject: [PATCH 1733/1892] Merged PR 2269: UT refine: add model to cuda tracer will place the model to the strategy preference device, if want to reuse the original model after tracing, should place the model to the correct device, but it is not recommended to reuse the model after tracing. --- tests/runtime/test_module_merge.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/runtime/test_module_merge.py b/tests/runtime/test_module_merge.py index e66e89c8..ba208926 100644 --- a/tests/runtime/test_module_merge.py +++ b/tests/runtime/test_module_merge.py @@ -144,6 +144,7 @@ def train_iter(model, sample): # test after training + model.cuda() for _ in range(2): # full model loss = model(sample) From 716b9dc0706805af7279c8b986d8f38133177b2c Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Fri, 27 Sep 2024 01:20:36 +0000 Subject: [PATCH 1734/1892] Merged PR 2272: Pin llama 3 demo example's dependency versions torch 2.0 is compatible with transfromers 4.40+ Pin their versions to prevent this kind of problems Also added `torch<2.4` to the main requirements file torch 2.4 has been verified to be unsupported (known by Shang Ning; error recorded in troubleshooting) --- docs/source/llama3_demo_example.rst | 3 +++ examples/llama3_demo/requirements.txt | 5 +++-- requirements.txt | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source/llama3_demo_example.rst b/docs/source/llama3_demo_example.rst index c8771851..12eff6ec 100644 --- a/docs/source/llama3_demo_example.rst +++ b/docs/source/llama3_demo_example.rst @@ -30,6 +30,9 @@ Installation pip install -r requirements.txt + Note: The requirements file has pinned ``torch``, ``transformers``, and ``datasets`` versions + to ensure their compatibility with each others. + 4. Prepare dataset :: # To run Llama 3 8B: diff --git a/examples/llama3_demo/requirements.txt b/examples/llama3_demo/requirements.txt index 2cac3224..5fd38d67 100644 --- a/examples/llama3_demo/requirements.txt +++ b/examples/llama3_demo/requirements.txt @@ -1,3 +1,4 @@ -datasets +datasets==2.21.0 tensorboard -transformers<4.43 +torch==2.3.1 +transformers==4.42.4 diff --git a/requirements.txt b/requirements.txt index a96ac7c9..9fd451a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,5 @@ psutil pulp pybind11 pyyaml -torch>=2.0 +torch>=2.0,<2.4 tqdm From 2c05d391a23fd172be824c70b78307b8a260e21a Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Fri, 27 Sep 2024 05:13:10 +0000 Subject: [PATCH 1735/1892] Merged PR 2271: add copyright add copyright --- examples/__init__.py | 3 + .../chunk_linear_cross_entropy.py | 3 + .../test_chunk_linear_cross_entropy.py | 3 + .../core/ring_attn_implementation.py | 3 + .../ring_attention/core/utils.py | 5 ++ .../core/zigzag_attn_implementation.py | 5 ++ .../ring_attention/ring_attn.py | 3 + .../ring_attention/test_ring_attn.py | 3 + .../ring_attention/test_zigzag_attn.py | 3 + .../ring_attention/zigzag_attn.py | 3 + examples/huggingface_nlp/compile_hf.py | 3 + examples/huggingface_nlp/compile_interface.py | 3 + examples/llama3_8B_128K/bookcorpus.py | 3 + .../chunk_linear_cross_entropy.py | 3 + examples/llama3_8B_128K/ckpt_merger.py | 3 + examples/llama3_8B_128K/create_mini_model.py | 3 + examples/llama3_8B_128K/modeling_modifier.py | 3 + examples/llama3_8B_128K/train.py | 3 + examples/llama3_demo/train.py | 3 + examples/nanogpt/train_cli.py | 4 + examples/nanogpt/train_nnscaler.py | 3 + examples/utils.py | 3 + examples/vision/swin/__init__.py | 3 + examples/vision/swin/baseline.py | 4 +- examples/vision/swin/blocks/__init__.py | 3 + examples/vision/swin/blocks/attention.py | 3 + examples/vision/swin/blocks/mlp.py | 3 + examples/vision/swin/blocks/patch.py | 3 + examples/vision/swin/blocks/transformer.py | 3 + examples/vision/swin/blocks/utils.py | 3 + examples/vision/swin/model.py | 3 + examples/vision/swin/policy/__init__.py | 3 + examples/vision/swin/policy/gallery.py | 3 + examples/vision/swin/train.py | 3 + nnscaler/__init__.py | 3 + nnscaler/algorithm/factory.py | 3 + nnscaler/algorithm/generics.py | 3 + nnscaler/algorithm/ops/conv.py | 2 + nnscaler/algorithm/ops/dimops.py | 3 + nnscaler/autodist/apis.py | 3 + nnscaler/autodist/autodist_config.py | 3 + nnscaler/autodist/cost_database.py | 3 + nnscaler/autodist/cube_operator.py | 3 + nnscaler/autodist/descs.py | 3 + nnscaler/autodist/dp_solver.cpp | 6 ++ nnscaler/autodist/dp_solver.h | 5 ++ nnscaler/autodist/model_graph.py | 3 + nnscaler/autodist/op_partition.py | 3 + nnscaler/autodist/pipeline_solver.py | 3 + nnscaler/autodist/spmd_solver.py | 3 + nnscaler/autodist/util.py | 3 + nnscaler/cli/__init__.py | 3 + nnscaler/cli/arg_parser.py | 3 + nnscaler/cli/loggers/__init__.py | 3 + nnscaler/cli/loggers/logger_base.py | 3 + nnscaler/cli/loggers/tensorboard.py | 5 ++ nnscaler/cli/loggers/wandb.py | 5 ++ nnscaler/cli/train.py | 3 + nnscaler/cli/train_hook.py | 3 + nnscaler/cli/trainer.py | 3 + nnscaler/cli/trainer_args.py | 3 + nnscaler/codegen/__init__.py | 3 + nnscaler/codegen/emit.py | 3 + nnscaler/codegen/frontend_mapping.py | 3 + nnscaler/codegen/lifecycle.py | 3 + nnscaler/codegen/module/autograd.py | 3 + nnscaler/codegen/module/module.py | 3 + nnscaler/codegen/schedule/schedule.py | 2 + nnscaler/codegen/syntax/blocks.py | 3 + nnscaler/codegen/syntax/symtable.py | 3 +- nnscaler/compiler.py | 3 + nnscaler/execplan/__init__.py | 3 + nnscaler/execplan/execplan.py | 3 + nnscaler/execplan/planpass/fusion.py | 3 + nnscaler/execplan/planpass/grouping.py | 3 + nnscaler/execplan/planpass/planpass.py | 3 + nnscaler/flags.py | 3 + nnscaler/graph/__init__.py | 3 + nnscaler/graph/function/__init__.py | 3 + nnscaler/graph/function/anchor.py | 2 + nnscaler/graph/function/conv.py | 3 + nnscaler/graph/function/dimops.py | 3 + nnscaler/graph/function/function.py | 3 + nnscaler/graph/function/pyfunc.py | 3 + nnscaler/graph/function/wrapnn.py | 3 + nnscaler/graph/gener/concurrent.py | 3 + nnscaler/graph/gener/gen.py | 3 + nnscaler/graph/gener/rvd/inter.py | 3 + nnscaler/graph/gener/rvd/intra.py | 3 + nnscaler/graph/gener/rvd/layout.py | 3 + nnscaler/graph/gener/utils.py | 3 + nnscaler/graph/graph.py | 3 + nnscaler/graph/parser/__init__.py | 3 + nnscaler/graph/parser/converter.py | 3 + nnscaler/graph/parser/external/__init__.py | 3 + nnscaler/graph/parser/external/apex.py | 3 + nnscaler/graph/parser/frame.py | 3 + .../fx/concrete_trace_utils/__init__.py | 4 +- .../parser/fx/concrete_trace_utils/_pytree.py | 82 +++++++++++++++++++ .../fx/concrete_trace_utils/concrete_proxy.py | 6 +- .../concrete_trace_utils/concrete_tracer.py | 6 +- .../fx/concrete_trace_utils/frame_utils.py | 3 + .../concrete_trace_utils/function_patcher.py | 3 + .../fx/concrete_trace_utils/metadata.py | 3 + .../concrete_trace_utils/operator_patcher.py | 4 +- .../fx/concrete_trace_utils/orig_func.py | 5 +- .../fx/concrete_trace_utils/pytree_utils.py | 3 + .../concrete_trace_utils/torch_fx_patcher.py | 3 + .../fx/concrete_trace_utils/trace_strategy.py | 3 + .../fx/concrete_trace_utils/wrap_utils.py | 3 + nnscaler/graph/parser/fx/mapping.py | 2 + nnscaler/graph/parser/fx/parser.py | 3 + nnscaler/graph/parser/register.py | 3 + nnscaler/graph/schedule/__init__.py | 3 + nnscaler/graph/schedule/predefined.py | 3 + nnscaler/graph/schedule/schedplan.py | 3 + nnscaler/graph/segment.py | 3 + .../integration/lightning/pytorch/__init__.py | 3 + .../lightning/pytorch/precision.py | 19 +++++ .../integration/lightning/pytorch/strategy.py | 3 + nnscaler/integration/lightning/utils.py | 3 + nnscaler/ir/__init__.py | 3 + nnscaler/ir/adapter/__init__.py | 3 + nnscaler/ir/adapter/adapter.py | 3 + nnscaler/ir/adapter/prim.py | 3 + nnscaler/ir/cten.py | 3 + nnscaler/ir/dtype.py | 3 + nnscaler/ir/operator.py | 3 + nnscaler/ir/tensor.py | 3 + nnscaler/ir/unique.py | 2 + nnscaler/parallel.py | 3 + nnscaler/policies.py | 3 + nnscaler/profiler/__init__.py | 3 + nnscaler/profiler/database.py | 3 + nnscaler/profiler/estimator.py | 3 + nnscaler/profiler/memory.py | 3 + nnscaler/profiler/timer.py | 3 + nnscaler/program.py | 3 + nnscaler/resources/__init__.py | 3 + nnscaler/runtime/__init__.py | 3 + nnscaler/runtime/adapter/__init__.py | 3 + nnscaler/runtime/adapter/collectives.py | 3 + nnscaler/runtime/adapter/nn.py | 3 + nnscaler/runtime/adapter/reducer.py | 3 + nnscaler/runtime/adapter/transform.py | 3 + nnscaler/runtime/device.py | 3 + nnscaler/runtime/executor.py | 3 + nnscaler/runtime/f16_optimizer.py | 5 ++ nnscaler/runtime/function/__init__.py | 3 + nnscaler/runtime/function/function.py | 3 + nnscaler/runtime/gnorm.py | 5 ++ nnscaler/runtime/module.py | 3 + nnscaler/runtime/resource.py | 3 + nnscaler/runtime/utils.py | 3 + nnscaler/utils.py | 3 + nnscaler/version.py | 3 + tests/algorithm/ops/test_dimops.py | 3 + tests/autodist/graph/test_calc_flops.py | 3 + tests/autodist/graph/test_recompute.py | 3 + tests/autodist/partition/test_state.py | 3 + .../pas/test_shared_param_pipeline.py | 3 + .../spmd_solver/test_cube_operator.py | 3 + tests/autodist/spmd_solver/test_follow.py | 3 + .../spmd_solver/test_partition_constraint.py | 3 + .../autodist/spmd_solver/test_shared_param.py | 3 + tests/autodist/test_dp_solver.py | 3 + tests/cli/common.py | 3 + tests/cli/test_arg_parser.py | 3 + tests/cli/test_train_args.py | 3 + tests/cli/test_trainer.py | 3 + tests/codegen/test_emit.py | 3 + tests/compiler/test_compile.py | 3 + tests/compiler/test_model.py | 3 + tests/conftest.py | 3 + tests/graph/function/helper.py | 3 + tests/graph/function/test_dataloader.py | 3 + tests/graph/function/test_dict_values.py | 3 + tests/graph/function/test_dimops.py | 3 + tests/graph/function/test_functions.py | 3 + tests/graph/function/test_script_func.py | 3 + tests/graph/gener/check_inter_rvd.py | 3 + tests/graph/gener/check_intra_rvd.py | 3 + tests/graph/gener/test_producer_fusion.py | 2 + tests/graph/gener/test_reducer_gen.py | 3 + tests/graph/parser/test_ast_transformer.py | 3 + tests/graph/parser/test_converter.py | 3 + tests/graph/parser/test_dce.py | 3 + tests/graph/parser/test_ir_obj_constant.py | 3 + tests/graph/parser/test_no_grad.py | 3 + tests/graph/parser/test_parser.py | 3 + tests/graph/parser/test_register.py | 3 + tests/graph/parser/test_register_external.py | 2 + tests/graph/test_graph.py | 2 + tests/graph/test_loss.py | 3 + tests/graph/test_multiref.py | 3 + tests/graph/test_segment.py | 3 + tests/graph/tracer/test_buffer.py | 3 + tests/graph/tracer/test_cls_wrapper.py | 3 + tests/graph/tracer/test_ctxt_manager.py | 3 + tests/graph/tracer/test_getattr.py | 3 + tests/graph/tracer/test_inplace.py | 3 + tests/graph/tracer/test_module_jit_init.py | 3 + tests/graph/tracer/test_namedtuple.py | 3 + tests/graph/tracer/test_op_patcher.py | 3 + tests/graph/tracer/test_pytree.py | 3 + tests/graph/tracer/test_scope.py | 3 + tests/integration/common.py | 19 +++++ tests/integration/lightning/datasets.py | 3 +- .../lightning/pytorch/simple_datamodules.py | 2 - .../lightning/pytorch/simple_models.py | 17 +++- .../lightning/pytorch/test_strategy.py | 3 + tests/ir/test_cten.py | 3 + tests/ir/test_tensor.py | 3 + tests/launch_torchrun.py | 3 + tests/parallel_module/common.py | 3 + tests/parallel_module/test_attr_dedup.py | 3 + tests/parallel_module/test_broadcast.py | 3 + tests/parallel_module/test_checkpoint.py | 3 + .../parallel_module/test_checkpoint_buffer.py | 3 + .../parallel_module/test_checkpoint_dedup.py | 3 + .../parallel_module/test_checkpoint_shared.py | 3 + .../parallel_module/test_checkpoint_unused.py | 3 + tests/parallel_module/test_ddp.py | 3 + tests/parallel_module/test_embedding.py | 3 + tests/parallel_module/test_end2end.py | 3 + tests/parallel_module/test_gencode.py | 3 + tests/parallel_module/test_inference.py | 3 + tests/parallel_module/test_init.py | 3 + tests/parallel_module/test_line_timer.py | 3 + tests/parallel_module/test_nested.py | 3 + tests/parallel_module/test_normlayer.py | 3 + tests/parallel_module/test_override.py | 3 + tests/parallel_module/test_pyfunc.py | 3 + tests/parallel_module/test_reducer_hook.py | 3 + tests/parallel_module/test_scale_grads.py | 3 + tests/parallel_module/test_submodule.py | 3 + tests/parallel_module/test_wholemodule.py | 3 + tests/profiler/test_op_profile.py | 3 + tests/runtime/test_dataloader.py | 2 + tests/runtime/test_f16_optimizer.py | 3 + tests/runtime/test_gnorm.py | 3 + tests/runtime/test_grad_accum.py | 3 + tests/runtime/test_module_merge.py | 3 + tests/runtime/test_reducer.py | 3 + tests/runtime/test_runtime_collectives.py | 3 + tests/test_policies.py | 3 + tests/test_program.py | 3 + tests/test_torchrun.py | 3 + tests/test_utils.py | 3 + tests/utils.py | 3 + 250 files changed, 879 insertions(+), 15 deletions(-) diff --git a/examples/__init__.py b/examples/__init__.py index e69de29b..4f8b9058 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + diff --git a/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py b/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py index d56a32b0..2f81f00a 100644 --- a/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py +++ b/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import torch.utils.checkpoint as ckpt diff --git a/examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py b/examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py index 1be4b144..10c5cd47 100644 --- a/examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py +++ b/examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from chunk_linear_cross_entropy import chunk_linear_cross_entropy, linear_cross_entropy diff --git a/examples/customized_ops/ring_attention/core/ring_attn_implementation.py b/examples/customized_ops/ring_attention/core/ring_attn_implementation.py index e6731b54..42219ad3 100644 --- a/examples/customized_ops/ring_attention/core/ring_attn_implementation.py +++ b/examples/customized_ops/ring_attention/core/ring_attn_implementation.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import torch.distributed as dist from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward diff --git a/examples/customized_ops/ring_attention/core/utils.py b/examples/customized_ops/ring_attention/core/utils.py index ba1d1b61..643fa59d 100644 --- a/examples/customized_ops/ring_attention/core/utils.py +++ b/examples/customized_ops/ring_attention/core/utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention + from typing import Optional, Tuple from functools import reduce import operator diff --git a/examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py b/examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py index b1cbfdba..f18deac4 100644 --- a/examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py +++ b/examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention + import torch import torch.distributed as dist from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward diff --git a/examples/customized_ops/ring_attention/ring_attn.py b/examples/customized_ops/ring_attention/ring_attn.py index e6da2f3a..801378ce 100644 --- a/examples/customized_ops/ring_attention/ring_attn.py +++ b/examples/customized_ops/ring_attention/ring_attn.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Tuple, List, Dict import torch from torch import Tensor diff --git a/examples/customized_ops/ring_attention/test_ring_attn.py b/examples/customized_ops/ring_attention/test_ring_attn.py index 198a1f4a..62104533 100644 --- a/examples/customized_ops/ring_attention/test_ring_attn.py +++ b/examples/customized_ops/ring_attention/test_ring_attn.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import nnscaler from nnscaler.graph import IRGraph diff --git a/examples/customized_ops/ring_attention/test_zigzag_attn.py b/examples/customized_ops/ring_attention/test_zigzag_attn.py index e5eea77d..c00f082e 100644 --- a/examples/customized_ops/ring_attention/test_zigzag_attn.py +++ b/examples/customized_ops/ring_attention/test_zigzag_attn.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import nnscaler from nnscaler.graph import IRGraph diff --git a/examples/customized_ops/ring_attention/zigzag_attn.py b/examples/customized_ops/ring_attention/zigzag_attn.py index a20fcba0..3fccb59b 100644 --- a/examples/customized_ops/ring_attention/zigzag_attn.py +++ b/examples/customized_ops/ring_attention/zigzag_attn.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Tuple, List, Dict import torch from torch import Tensor diff --git a/examples/huggingface_nlp/compile_hf.py b/examples/huggingface_nlp/compile_hf.py index 55a01507..08c2f39f 100644 --- a/examples/huggingface_nlp/compile_hf.py +++ b/examples/huggingface_nlp/compile_hf.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM from _collections_abc import MutableMapping import os diff --git a/examples/huggingface_nlp/compile_interface.py b/examples/huggingface_nlp/compile_interface.py index 1e95cd66..92d5422d 100644 --- a/examples/huggingface_nlp/compile_interface.py +++ b/examples/huggingface_nlp/compile_interface.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from nnscaler.graph.parser.converter import to_fx_graph import nnscaler diff --git a/examples/llama3_8B_128K/bookcorpus.py b/examples/llama3_8B_128K/bookcorpus.py index b69b90b0..1a77273a 100644 --- a/examples/llama3_8B_128K/bookcorpus.py +++ b/examples/llama3_8B_128K/bookcorpus.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import argparse from typing import List, Dict diff --git a/examples/llama3_8B_128K/chunk_linear_cross_entropy.py b/examples/llama3_8B_128K/chunk_linear_cross_entropy.py index db902b92..5fd3a54f 100644 --- a/examples/llama3_8B_128K/chunk_linear_cross_entropy.py +++ b/examples/llama3_8B_128K/chunk_linear_cross_entropy.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import torch.utils.checkpoint as ckpt diff --git a/examples/llama3_8B_128K/ckpt_merger.py b/examples/llama3_8B_128K/ckpt_merger.py index ce80f5a7..50e17b14 100644 --- a/examples/llama3_8B_128K/ckpt_merger.py +++ b/examples/llama3_8B_128K/ckpt_merger.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import argparse import os diff --git a/examples/llama3_8B_128K/create_mini_model.py b/examples/llama3_8B_128K/create_mini_model.py index e0718480..1151771e 100644 --- a/examples/llama3_8B_128K/create_mini_model.py +++ b/examples/llama3_8B_128K/create_mini_model.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import argparse from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer diff --git a/examples/llama3_8B_128K/modeling_modifier.py b/examples/llama3_8B_128K/modeling_modifier.py index 57b65b90..7792e3d8 100644 --- a/examples/llama3_8B_128K/modeling_modifier.py +++ b/examples/llama3_8B_128K/modeling_modifier.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + # This file modifies the official modeling_llama.py file at runtime to # 1. register the flash attention function to nnscaler and update related code # 2. replace the un-fused RMSNorm with apex's fused version diff --git a/examples/llama3_8B_128K/train.py b/examples/llama3_8B_128K/train.py index f74c3318..47915dd7 100644 --- a/examples/llama3_8B_128K/train.py +++ b/examples/llama3_8B_128K/train.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import argparse import os diff --git a/examples/llama3_demo/train.py b/examples/llama3_demo/train.py index 4a48c0b2..7d1f2edc 100644 --- a/examples/llama3_demo/train.py +++ b/examples/llama3_demo/train.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import argparse from datetime import datetime import os diff --git a/examples/nanogpt/train_cli.py b/examples/nanogpt/train_cli.py index 84ad7e9d..95c8627a 100644 --- a/examples/nanogpt/train_cli.py +++ b/examples/nanogpt/train_cli.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Run training with this command in this directory: ``` @@ -5,6 +8,7 @@ ../../nnscaler/cli/train.py -f train_cli_args.yaml ``` """ + import math import os from pathlib import Path diff --git a/examples/nanogpt/train_nnscaler.py b/examples/nanogpt/train_nnscaler.py index a5fe394d..b97f58c0 100644 --- a/examples/nanogpt/train_nnscaler.py +++ b/examples/nanogpt/train_nnscaler.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import math import os from pathlib import Path diff --git a/examples/utils.py b/examples/utils.py index 74192693..25eea356 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Union, Callable, Optional, Tuple import logging diff --git a/examples/vision/swin/__init__.py b/examples/vision/swin/__init__.py index e69de29b..4f8b9058 100644 --- a/examples/vision/swin/__init__.py +++ b/examples/vision/swin/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + diff --git a/examples/vision/swin/baseline.py b/examples/vision/swin/baseline.py index f7976e41..4a5b9e00 100644 --- a/examples/vision/swin/baseline.py +++ b/examples/vision/swin/baseline.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ OMP_NUM_THREADS=4 torchrun \ --nproc_per_node=1 \ @@ -5,7 +8,6 @@ examples/vision/swin/baseline.py """ - import math from typing import List, Tuple import warnings diff --git a/examples/vision/swin/blocks/__init__.py b/examples/vision/swin/blocks/__init__.py index e69de29b..4f8b9058 100644 --- a/examples/vision/swin/blocks/__init__.py +++ b/examples/vision/swin/blocks/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + diff --git a/examples/vision/swin/blocks/attention.py b/examples/vision/swin/blocks/attention.py index 72f4c3d8..89484608 100644 --- a/examples/vision/swin/blocks/attention.py +++ b/examples/vision/swin/blocks/attention.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Optional import torch import nnscaler diff --git a/examples/vision/swin/blocks/mlp.py b/examples/vision/swin/blocks/mlp.py index b8663783..8748a3b5 100644 --- a/examples/vision/swin/blocks/mlp.py +++ b/examples/vision/swin/blocks/mlp.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import torch.nn as nn import nnscaler diff --git a/examples/vision/swin/blocks/patch.py b/examples/vision/swin/blocks/patch.py index 8c48f551..dc71e162 100644 --- a/examples/vision/swin/blocks/patch.py +++ b/examples/vision/swin/blocks/patch.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Tuple import torch diff --git a/examples/vision/swin/blocks/transformer.py b/examples/vision/swin/blocks/transformer.py index 946948a2..03790429 100644 --- a/examples/vision/swin/blocks/transformer.py +++ b/examples/vision/swin/blocks/transformer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Tuple import torch import torch.nn as nn diff --git a/examples/vision/swin/blocks/utils.py b/examples/vision/swin/blocks/utils.py index acf35e91..56ac7775 100644 --- a/examples/vision/swin/blocks/utils.py +++ b/examples/vision/swin/blocks/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import math diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index c28b3778..481feeca 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import torch.nn as nn diff --git a/examples/vision/swin/policy/__init__.py b/examples/vision/swin/policy/__init__.py index e69de29b..4f8b9058 100644 --- a/examples/vision/swin/policy/__init__.py +++ b/examples/vision/swin/policy/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + diff --git a/examples/vision/swin/policy/gallery.py b/examples/vision/swin/policy/gallery.py index 7576ae9c..22554a23 100644 --- a/examples/vision/swin/policy/gallery.py +++ b/examples/vision/swin/policy/gallery.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List import more_itertools as mitr diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 27416f71..8e724acb 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ example: diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index f43567cd..87ae960f 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .version import __version__ from .parallel import ( ParallelModule, diff --git a/nnscaler/algorithm/factory.py b/nnscaler/algorithm/factory.py index 12781511..e72bdc09 100644 --- a/nnscaler/algorithm/factory.py +++ b/nnscaler/algorithm/factory.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Dict, Any diff --git a/nnscaler/algorithm/generics.py b/nnscaler/algorithm/generics.py index 3c970e77..3e10ecd1 100644 --- a/nnscaler/algorithm/generics.py +++ b/nnscaler/algorithm/generics.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Optional from nnscaler.ir.cten import IRCell diff --git a/nnscaler/algorithm/ops/conv.py b/nnscaler/algorithm/ops/conv.py index cdf21d4c..1c961176 100644 --- a/nnscaler/algorithm/ops/conv.py +++ b/nnscaler/algorithm/ops/conv.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from typing import List, Tuple diff --git a/nnscaler/algorithm/ops/dimops.py b/nnscaler/algorithm/ops/dimops.py index 794e4a72..88d414b2 100644 --- a/nnscaler/algorithm/ops/dimops.py +++ b/nnscaler/algorithm/ops/dimops.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Optional, Any, Dict, Union, Tuple import numpy as np import logging diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 8e9073b2..31b9b1ad 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .spmd_solver import calc_optimal_spmd_plan, analysis_pretty_printer from .pipeline_solver import calc_optimal_pp_plan from .autodist_config import AutoDistConfig diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index fbef7612..511eb425 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pathlib import Path import argparse import logging diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index c7224960..53159bca 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Tuple, Union, Callable, Dict import json import os diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py index 1ec6ed7c..59be7bbe 100644 --- a/nnscaler/autodist/cube_operator.py +++ b/nnscaler/autodist/cube_operator.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Tuple, Dict, Set, Optional from nnscaler.ir import IRTensor, IRFwOperation, IRSubTensor from nnscaler.graph.function.dimops import DimAnno, IRDimops diff --git a/nnscaler/autodist/descs.py b/nnscaler/autodist/descs.py index 9d9bb9e3..19c5820d 100644 --- a/nnscaler/autodist/descs.py +++ b/nnscaler/autodist/descs.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from dataclasses import dataclass from typing import List, Dict, Tuple, Any, Optional import json diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index 515b412d..ecfda7f7 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -1,4 +1,10 @@ // cppimport + +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT License. + */ + #include "dp_solver.h" #include #include diff --git a/nnscaler/autodist/dp_solver.h b/nnscaler/autodist/dp_solver.h index 05bc2196..9e24ad44 100644 --- a/nnscaler/autodist/dp_solver.h +++ b/nnscaler/autodist/dp_solver.h @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT License. + */ + #include #include #include diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 1b824c46..8815e3a1 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from __future__ import annotations from nnscaler.graph.graph import IRGraph diff --git a/nnscaler/autodist/op_partition.py b/nnscaler/autodist/op_partition.py index d52c3de7..00450454 100644 --- a/nnscaler/autodist/op_partition.py +++ b/nnscaler/autodist/op_partition.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.autodist.cube_operator import CubeOperator from nnscaler.graph.function.dimops import DimAnno, IRDimops diff --git a/nnscaler/autodist/pipeline_solver.py b/nnscaler/autodist/pipeline_solver.py index 9d000461..306c8b4d 100644 --- a/nnscaler/autodist/pipeline_solver.py +++ b/nnscaler/autodist/pipeline_solver.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .model_graph import ModelGraph, estimate_mem_lower_bound, IntervalInfo from .spmd_solver import SPMDSolver from .descs import * diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index a9ff87df..9d60b963 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .model_graph import ModelGraph, collect_depth2scope_nodes from .cube_operator import CubeOperator from .descs import * diff --git a/nnscaler/autodist/util.py b/nnscaler/autodist/util.py index c256c328..62d7b3dd 100644 --- a/nnscaler/autodist/util.py +++ b/nnscaler/autodist/util.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .descs import NodePartitionDesc from nnscaler.graph import IRGraph from nnscaler.ir.operator import IRFwOperation diff --git a/nnscaler/cli/__init__.py b/nnscaler/cli/__init__.py index 134577af..958e874f 100644 --- a/nnscaler/cli/__init__.py +++ b/nnscaler/cli/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.cli.trainer import Trainer from nnscaler.cli.trainer_args import ( TrainerArgs, diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index ce82542f..b9f94910 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Tuple, Dict, Any, Union from dataclasses import dataclass, is_dataclass, asdict import enum diff --git a/nnscaler/cli/loggers/__init__.py b/nnscaler/cli/loggers/__init__.py index 20900fb3..0e870679 100644 --- a/nnscaler/cli/loggers/__init__.py +++ b/nnscaler/cli/loggers/__init__.py @@ -1,2 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .tensorboard import TensorBoardLogger from .wandb import WandbLogger diff --git a/nnscaler/cli/loggers/logger_base.py b/nnscaler/cli/loggers/logger_base.py index 00b2f0bc..0ec5eae8 100644 --- a/nnscaler/cli/loggers/logger_base.py +++ b/nnscaler/cli/loggers/logger_base.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from abc import ABC, abstractmethod from typing import Optional, Dict diff --git a/nnscaler/cli/loggers/tensorboard.py b/nnscaler/cli/loggers/tensorboard.py index 43295fe4..93acaf37 100644 --- a/nnscaler/cli/loggers/tensorboard.py +++ b/nnscaler/cli/loggers/tensorboard.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# CREDITS: This logger implementation is inspired by Fairseq https://github.com/facebookresearch/fairseq/blob/main/fairseq/logging/progress_bar.py + import atexit from pathlib import Path from typing import Dict, Optional diff --git a/nnscaler/cli/loggers/wandb.py b/nnscaler/cli/loggers/wandb.py index 08e2dd76..bac4fbbc 100644 --- a/nnscaler/cli/loggers/wandb.py +++ b/nnscaler/cli/loggers/wandb.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# CREDITS: This logger implementation is inspired by Fairseq https://github.com/facebookresearch/fairseq/blob/main/fairseq/logging/progress_bar.py + from typing import Dict, Optional from pathlib import Path diff --git a/nnscaler/cli/train.py b/nnscaler/cli/train.py index c6981d76..670e150a 100644 --- a/nnscaler/cli/train.py +++ b/nnscaler/cli/train.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import logging import nnscaler diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index ea68a823..09707f9c 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Any, Dict, List, TYPE_CHECKING import torch diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 188801e0..1424cc4e 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from dataclasses import dataclass, asdict from typing import Any, Dict, List, Optional, Union from pathlib import Path diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 575357cb..30af224d 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from dataclasses import asdict, dataclass, field import importlib from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union diff --git a/nnscaler/codegen/__init__.py b/nnscaler/codegen/__init__.py index ad3e3973..6a782b15 100644 --- a/nnscaler/codegen/__init__.py +++ b/nnscaler/codegen/__init__.py @@ -1,2 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.codegen.module.module import ModuleCodeGen from nnscaler.codegen.schedule.schedule import ScheduleCodeGen diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index d2060a22..31527b73 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Generator, Iterable, List, Any, Optional, Tuple, Dict import logging diff --git a/nnscaler/codegen/frontend_mapping.py b/nnscaler/codegen/frontend_mapping.py index 8bcce3d3..7eb362b4 100644 --- a/nnscaler/codegen/frontend_mapping.py +++ b/nnscaler/codegen/frontend_mapping.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + # Some operators should be specially handled during codegen to the frontend code, # here we define the customized rule for code emisson. diff --git a/nnscaler/codegen/lifecycle.py b/nnscaler/codegen/lifecycle.py index e2413cde..5bdcf5b5 100644 --- a/nnscaler/codegen/lifecycle.py +++ b/nnscaler/codegen/lifecycle.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Iterable, Dict, List, Any import itertools diff --git a/nnscaler/codegen/module/autograd.py b/nnscaler/codegen/module/autograd.py index a6e25ba8..facb165a 100644 --- a/nnscaler/codegen/module/autograd.py +++ b/nnscaler/codegen/module/autograd.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List from nnscaler.ir.tensor import IRSubTensor diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 1e5a2e25..79a82229 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Optional, Tuple, Dict, Any import more_itertools import logging diff --git a/nnscaler/codegen/schedule/schedule.py b/nnscaler/codegen/schedule/schedule.py index e46d6d86..c21a764e 100644 --- a/nnscaler/codegen/schedule/schedule.py +++ b/nnscaler/codegen/schedule/schedule.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from typing import List, Optional, Tuple import copy diff --git a/nnscaler/codegen/syntax/blocks.py b/nnscaler/codegen/syntax/blocks.py index d6198d10..822656ac 100644 --- a/nnscaler/codegen/syntax/blocks.py +++ b/nnscaler/codegen/syntax/blocks.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Optional class Block: diff --git a/nnscaler/codegen/syntax/symtable.py b/nnscaler/codegen/syntax/symtable.py index 8732112f..348ede13 100644 --- a/nnscaler/codegen/syntax/symtable.py +++ b/nnscaler/codegen/syntax/symtable.py @@ -1,4 +1,5 @@ - +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. class SymbolTable: """ diff --git a/nnscaler/compiler.py b/nnscaler/compiler.py index 4d99a815..78a8aa5b 100644 --- a/nnscaler/compiler.py +++ b/nnscaler/compiler.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Callable, Tuple, Union, Optional import torch import time diff --git a/nnscaler/execplan/__init__.py b/nnscaler/execplan/__init__.py index a542cec3..e6306e06 100644 --- a/nnscaler/execplan/__init__.py +++ b/nnscaler/execplan/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.execplan.execplan import ExecutionPlan \ No newline at end of file diff --git a/nnscaler/execplan/execplan.py b/nnscaler/execplan/execplan.py index e9bf96e1..bc86cdf5 100644 --- a/nnscaler/execplan/execplan.py +++ b/nnscaler/execplan/execplan.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Callable, Dict, List, Optional, Tuple, Any import copy import numpy as np diff --git a/nnscaler/execplan/planpass/fusion.py b/nnscaler/execplan/planpass/fusion.py index 0b27c70b..c65353eb 100644 --- a/nnscaler/execplan/planpass/fusion.py +++ b/nnscaler/execplan/planpass/fusion.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Union, Set import logging from nnscaler.graph.graph import IRSegment diff --git a/nnscaler/execplan/planpass/grouping.py b/nnscaler/execplan/planpass/grouping.py index dc3e3055..2e1d77b2 100644 --- a/nnscaler/execplan/planpass/grouping.py +++ b/nnscaler/execplan/planpass/grouping.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Operation grouping """ diff --git a/nnscaler/execplan/planpass/planpass.py b/nnscaler/execplan/planpass/planpass.py index 373d21ad..727b1b51 100644 --- a/nnscaler/execplan/planpass/planpass.py +++ b/nnscaler/execplan/planpass/planpass.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.execplan import ExecutionPlan diff --git a/nnscaler/flags.py b/nnscaler/flags.py index 31fd5c91..032fe3b6 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Environment flags for compiling options """ diff --git a/nnscaler/graph/__init__.py b/nnscaler/graph/__init__.py index b258cbe6..988fb6ca 100644 --- a/nnscaler/graph/__init__.py +++ b/nnscaler/graph/__init__.py @@ -1,2 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.graph.graph import IRGraph from nnscaler.graph import parser diff --git a/nnscaler/graph/function/__init__.py b/nnscaler/graph/function/__init__.py index 8afb8034..202fd241 100644 --- a/nnscaler/graph/function/__init__.py +++ b/nnscaler/graph/function/__init__.py @@ -1,2 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.graph.function.dimops import IRDimops from nnscaler.graph.function.function import * \ No newline at end of file diff --git a/nnscaler/graph/function/anchor.py b/nnscaler/graph/function/anchor.py index 8f7fd236..298145c5 100644 --- a/nnscaler/graph/function/anchor.py +++ b/nnscaler/graph/function/anchor.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from nnscaler.ir.operator import IRFwOperation from nnscaler.ir.cten import IRObject diff --git a/nnscaler/graph/function/conv.py b/nnscaler/graph/function/conv.py index 85e8e99c..b14aeedf 100644 --- a/nnscaler/graph/function/conv.py +++ b/nnscaler/graph/function/conv.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List from nnscaler.ir.operator import IRFwOperation diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index b555624f..8448bbfd 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Dimension Annotion Operations. diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 1a605aea..c7d5e28d 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Any, Callable, List, Optional, Tuple, Dict, Union, Iterable import string import copy diff --git a/nnscaler/graph/function/pyfunc.py b/nnscaler/graph/function/pyfunc.py index ddc34dd8..ce0a5643 100644 --- a/nnscaler/graph/function/pyfunc.py +++ b/nnscaler/graph/function/pyfunc.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Tuple from nnscaler.ir.operator import IRFwOperation diff --git a/nnscaler/graph/function/wrapnn.py b/nnscaler/graph/function/wrapnn.py index 70651a6e..300250c5 100644 --- a/nnscaler/graph/function/wrapnn.py +++ b/nnscaler/graph/function/wrapnn.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ This file deals with some special nn modules which have control flows (if/else) in their forward function. These control flows go different branches according to self.training. diff --git a/nnscaler/graph/gener/concurrent.py b/nnscaler/graph/gener/concurrent.py index c90facf7..1c91ab1b 100644 --- a/nnscaler/graph/gener/concurrent.py +++ b/nnscaler/graph/gener/concurrent.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Concurrent producer / consumer Adapter Generator """ diff --git a/nnscaler/graph/gener/gen.py b/nnscaler/graph/gener/gen.py index 79c69a3b..7f870458 100644 --- a/nnscaler/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Dict, List, Optional, Tuple, Callable, Set import numpy as np import itertools diff --git a/nnscaler/graph/gener/rvd/inter.py b/nnscaler/graph/gener/rvd/inter.py index 48df84d0..8640b7ef 100644 --- a/nnscaler/graph/gener/rvd/inter.py +++ b/nnscaler/graph/gener/rvd/inter.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Callable, Dict, List, Tuple, Optional, Set, Union from functools import partial import numpy as np diff --git a/nnscaler/graph/gener/rvd/intra.py b/nnscaler/graph/gener/rvd/intra.py index dcc6f82e..89c37dce 100644 --- a/nnscaler/graph/gener/rvd/intra.py +++ b/nnscaler/graph/gener/rvd/intra.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Callable, Dict, List, Tuple, Optional, Set from functools import partial import numpy as np diff --git a/nnscaler/graph/gener/rvd/layout.py b/nnscaler/graph/gener/rvd/layout.py index 7fd2b0b3..e20ba40b 100644 --- a/nnscaler/graph/gener/rvd/layout.py +++ b/nnscaler/graph/gener/rvd/layout.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Dict, List, Tuple, Optional import copy import numpy as np diff --git a/nnscaler/graph/gener/utils.py b/nnscaler/graph/gener/utils.py index 068583fe..6b8a9f3c 100644 --- a/nnscaler/graph/gener/utils.py +++ b/nnscaler/graph/gener/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Utilities for gradient modification """ diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index de4098d1..1d835316 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ IRGraph: a graph that is composed by node (IRFwOperation) and edge (IRTensor). diff --git a/nnscaler/graph/parser/__init__.py b/nnscaler/graph/parser/__init__.py index b06caafd..8b025b52 100644 --- a/nnscaler/graph/parser/__init__.py +++ b/nnscaler/graph/parser/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.graph.parser.fx.parser import FxModuleParser from nnscaler.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph from nnscaler.graph.parser.register import register diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 41a9debd..4a2b3251 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Any, Dict, Union import logging from pathlib import Path diff --git a/nnscaler/graph/parser/external/__init__.py b/nnscaler/graph/parser/external/__init__.py index 8574c302..5c71d8f9 100644 --- a/nnscaler/graph/parser/external/__init__.py +++ b/nnscaler/graph/parser/external/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .apex import * \ No newline at end of file diff --git a/nnscaler/graph/parser/external/apex.py b/nnscaler/graph/parser/external/apex.py index fc56fb62..c8274fe8 100644 --- a/nnscaler/graph/parser/external/apex.py +++ b/nnscaler/graph/parser/external/apex.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import copy import logging import string diff --git a/nnscaler/graph/parser/frame.py b/nnscaler/graph/parser/frame.py index 1c852946..6c1dcbba 100644 --- a/nnscaler/graph/parser/frame.py +++ b/nnscaler/graph/parser/frame.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from collections import OrderedDict from typing import List, Any, Dict, Tuple, Optional from nnscaler.ir.cten import IRTensor diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py b/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py index 282a505c..4439a216 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ FX is a toolkit for developers to use to transform ``nn.Module`` instances. FX consists of three main components, and this pipeline of diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py b/nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py index 3cc31f63..ce69b0bd 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py @@ -1,3 +1,85 @@ +# From PyTorch: +# +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) +# +# From Caffe2: +# +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. +# +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. +# +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. +# +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. +# +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain +# +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. +# +# All contributions by Arm: +# Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates +# +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. +# +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. +# +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + + """ NOTE: This file is copy from torch 2.3.0, and make some extension diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 82db105a..ef44c608 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -1,5 +1,7 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# CREDITS: This implementation is inspired by PyTorch fx symbolic trace: https://github.com/pytorch/pytorch/blob/main/torch/fx/_symbolic_trace.py from __future__ import annotations diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index f8fa1186..6626bedb 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -1,5 +1,7 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# CREDITS: This implementation is inspired by PyTorch fx symbolic trace: https://github.com/pytorch/pytorch/blob/main/torch/fx/_symbolic_trace.py from __future__ import annotations diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py index 34b741ed..09e470d8 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from dataclasses import dataclass import dis import importlib diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py index ff66d44b..ae979a1a 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import builtins from contextlib import contextmanager from typing import Any, Callable, List, Dict, NamedTuple diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py b/nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py index a463fc23..dfb9f187 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Any, Dict, NamedTuple, Optional, Tuple import torch diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py index 2b623d3b..235da0b5 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from typing import TYPE_CHECKING if TYPE_CHECKING: diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py b/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py index 3a3aa099..cae65f7e 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py @@ -1,9 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ During tracing, the function or class in this file might be wrapped as another function or class. If the original function is needed to use (usually in tracer), should call the function in this file. """ -# all functions in operator will be wrapped during tracing +# all functions in operator will be wrapped during tracing from operator import * # the wrapped functon/class in builtins diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py index f23a034a..55dd19e5 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ This file is the pytree extension by nnscaler. """ diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py index 0cf6c50b..8affab89 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import operator from typing import Any, Callable, Set diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py index 2abc82a7..a55c10ff 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import time from typing import TYPE_CHECKING, Any, Tuple, Dict, Callable, Type diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py index 1151d52d..808868dd 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import builtins from contextlib import contextmanager from dataclasses import dataclass, field diff --git a/nnscaler/graph/parser/fx/mapping.py b/nnscaler/graph/parser/fx/mapping.py index 85b13938..df7e246f 100644 --- a/nnscaler/graph/parser/fx/mapping.py +++ b/nnscaler/graph/parser/fx/mapping.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from typing import Callable, Union from functools import partial diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index 6d98a2ff..96db3ef1 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import logging from pathlib import Path diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index fbb4dec8..1b6de28a 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Register cutomized function """ diff --git a/nnscaler/graph/schedule/__init__.py b/nnscaler/graph/schedule/__init__.py index 3d5d29e0..7fbb7d90 100644 --- a/nnscaler/graph/schedule/__init__.py +++ b/nnscaler/graph/schedule/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.graph.schedule.schedplan import SchedulePlan diff --git a/nnscaler/graph/schedule/predefined.py b/nnscaler/graph/schedule/predefined.py index 1a346194..9d2ce581 100644 --- a/nnscaler/graph/schedule/predefined.py +++ b/nnscaler/graph/schedule/predefined.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Common scheduling descriptions """ diff --git a/nnscaler/graph/schedule/schedplan.py b/nnscaler/graph/schedule/schedplan.py index 187687bf..69914684 100644 --- a/nnscaler/graph/schedule/schedplan.py +++ b/nnscaler/graph/schedule/schedplan.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Dict, List, Optional, Tuple, Set from nnscaler.ir.cten import IRCell diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index 251f192d..ba2b122b 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from contextlib import contextmanager from typing import Dict, Union, List, Optional, Set, Tuple, Any, Callable import numpy as np diff --git a/nnscaler/integration/lightning/pytorch/__init__.py b/nnscaler/integration/lightning/pytorch/__init__.py index 80f37b9e..6c7161a1 100644 --- a/nnscaler/integration/lightning/pytorch/__init__.py +++ b/nnscaler/integration/lightning/pytorch/__init__.py @@ -1,2 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .precision import NnScalerPrecision from .strategy import NnScalerStrategy diff --git a/nnscaler/integration/lightning/pytorch/precision.py b/nnscaler/integration/lightning/pytorch/precision.py index 5f43e03c..5690b13b 100644 --- a/nnscaler/integration/lightning/pytorch/precision.py +++ b/nnscaler/integration/lightning/pytorch/precision.py @@ -1,3 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Code modified from: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/plugins/precision/fsdp.py + +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, Union import torch diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py index 83477023..54461d39 100644 --- a/nnscaler/integration/lightning/pytorch/strategy.py +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from contextlib import contextmanager, nullcontext from functools import partial import logging diff --git a/nnscaler/integration/lightning/utils.py b/nnscaler/integration/lightning/utils.py index bc105233..bb88c11f 100644 --- a/nnscaler/integration/lightning/utils.py +++ b/nnscaler/integration/lightning/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from torch.optim.lbfgs import LBFGS diff --git a/nnscaler/ir/__init__.py b/nnscaler/ir/__init__.py index 0152b006..10ffb7f2 100644 --- a/nnscaler/ir/__init__.py +++ b/nnscaler/ir/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.ir.dtype import * from nnscaler.ir.cten import IRTensor, IRCell from nnscaler.ir.tensor import IRFullTensor, IRSubTensor diff --git a/nnscaler/ir/adapter/__init__.py b/nnscaler/ir/adapter/__init__.py index a16ff34b..8bc65b2d 100644 --- a/nnscaler/ir/adapter/__init__.py +++ b/nnscaler/ir/adapter/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.ir.adapter.adapter import IRAdapter, IRWeightReducer diff --git a/nnscaler/ir/adapter/adapter.py b/nnscaler/ir/adapter/adapter.py index 7b340cd9..11fc6f53 100644 --- a/nnscaler/ir/adapter/adapter.py +++ b/nnscaler/ir/adapter/adapter.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Optional, Dict import copy diff --git a/nnscaler/ir/adapter/prim.py b/nnscaler/ir/adapter/prim.py index 47e4583f..d85250ef 100644 --- a/nnscaler/ir/adapter/prim.py +++ b/nnscaler/ir/adapter/prim.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ The primitive used for IRAdapter """ diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index f91f21e8..f3cb1544 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + r""" IRCell: a graph node component serving for different purpose, diff --git a/nnscaler/ir/dtype.py b/nnscaler/ir/dtype.py index 8040b6d3..c39c43c2 100644 --- a/nnscaler/ir/dtype.py +++ b/nnscaler/ir/dtype.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Any import torch diff --git a/nnscaler/ir/operator.py b/nnscaler/ir/operator.py index 9269ad5e..516dbd4f 100644 --- a/nnscaler/ir/operator.py +++ b/nnscaler/ir/operator.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Optional, Tuple, Any, Union, List import copy diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index 09312e76..b7b79861 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + r""" SubTensor Gradient rule: diff --git a/nnscaler/ir/unique.py b/nnscaler/ir/unique.py index b40851e4..dde3ceb2 100644 --- a/nnscaler/ir/unique.py +++ b/nnscaler/ir/unique.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. class IDGenerator: """ diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index f7c67670..5e7ba873 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from enum import Enum from functools import partial import types diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 567c5e3f..2ffffd10 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Policy Writing Guidelines diff --git a/nnscaler/profiler/__init__.py b/nnscaler/profiler/__init__.py index 15a9d386..43d65619 100644 --- a/nnscaler/profiler/__init__.py +++ b/nnscaler/profiler/__init__.py @@ -1,2 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.profiler.timer import CudaTimer from nnscaler.profiler.database import ProfileDataBase diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index d21c8c5b..1aeabbc9 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Usage: python -m nnscaler.profiler.database --export ./profile.dat.json diff --git a/nnscaler/profiler/estimator.py b/nnscaler/profiler/estimator.py index 9a123119..a2466426 100644 --- a/nnscaler/profiler/estimator.py +++ b/nnscaler/profiler/estimator.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Union, Tuple import sys import os diff --git a/nnscaler/profiler/memory.py b/nnscaler/profiler/memory.py index eeda2b0c..b54d03c0 100644 --- a/nnscaler/profiler/memory.py +++ b/nnscaler/profiler/memory.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Any, List import logging from nnscaler.utils import print_each_rank diff --git a/nnscaler/profiler/timer.py b/nnscaler/profiler/timer.py index e73c504f..ac371480 100644 --- a/nnscaler/profiler/timer.py +++ b/nnscaler/profiler/timer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Optional import time import logging diff --git a/nnscaler/program.py b/nnscaler/program.py index 421f6d04..4be41dca 100644 --- a/nnscaler/program.py +++ b/nnscaler/program.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Tuple, Optional, Any, Dict, Union import inspect diff --git a/nnscaler/resources/__init__.py b/nnscaler/resources/__init__.py index af4fe9c0..b757e372 100644 --- a/nnscaler/resources/__init__.py +++ b/nnscaler/resources/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Pseudo module of resource files. """ diff --git a/nnscaler/runtime/__init__.py b/nnscaler/runtime/__init__.py index faac540f..d0171757 100644 --- a/nnscaler/runtime/__init__.py +++ b/nnscaler/runtime/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.runtime import executor from nnscaler.runtime import device from nnscaler.runtime import adapter diff --git a/nnscaler/runtime/adapter/__init__.py b/nnscaler/runtime/adapter/__init__.py index 574332e8..3d53aa79 100644 --- a/nnscaler/runtime/adapter/__init__.py +++ b/nnscaler/runtime/adapter/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.runtime.adapter.collectives import * from nnscaler.runtime.adapter.transform import * from nnscaler.runtime.adapter import nn diff --git a/nnscaler/runtime/adapter/collectives.py b/nnscaler/runtime/adapter/collectives.py index b18645e5..5ec6fd76 100644 --- a/nnscaler/runtime/adapter/collectives.py +++ b/nnscaler/runtime/adapter/collectives.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ This module offers the wrap of communication primitives based on `torch.distributed`. The use of these primitives standalone is typically diff --git a/nnscaler/runtime/adapter/nn.py b/nnscaler/runtime/adapter/nn.py index 5ac2623b..92628b6f 100644 --- a/nnscaler/runtime/adapter/nn.py +++ b/nnscaler/runtime/adapter/nn.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ This module offers autograd functions for communication primitives. This is typically used in the training with tensor diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index 96563d8d..07993926 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List, Dict, Tuple, Any, Callable, Optional, Set from functools import partial import logging diff --git a/nnscaler/runtime/adapter/transform.py b/nnscaler/runtime/adapter/transform.py index d60a9b18..c2ddd1b2 100644 --- a/nnscaler/runtime/adapter/transform.py +++ b/nnscaler/runtime/adapter/transform.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Adapter: Tensor Transformation """ diff --git a/nnscaler/runtime/device.py b/nnscaler/runtime/device.py index f28c47b2..10e02922 100644 --- a/nnscaler/runtime/device.py +++ b/nnscaler/runtime/device.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Communication group settings among devices """ diff --git a/nnscaler/runtime/executor.py b/nnscaler/runtime/executor.py index c762f438..06a46707 100644 --- a/nnscaler/runtime/executor.py +++ b/nnscaler/runtime/executor.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + r""" Executor for runtime """ diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py index 7d128c68..fc6b770e 100644 --- a/nnscaler/runtime/f16_optimizer.py +++ b/nnscaler/runtime/f16_optimizer.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# CREDITS: This implementation is inspired by Fairseq https://github.com/facebookresearch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py + import logging from typing import Optional, TYPE_CHECKING diff --git a/nnscaler/runtime/function/__init__.py b/nnscaler/runtime/function/__init__.py index ae856192..e2044e28 100644 --- a/nnscaler/runtime/function/__init__.py +++ b/nnscaler/runtime/function/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.runtime.function.function import * \ No newline at end of file diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 9cc5f4a3..9d278a77 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Optional, List, Tuple, Union, Any import torch import torch.nn.functional as TorchF diff --git a/nnscaler/runtime/gnorm.py b/nnscaler/runtime/gnorm.py index 8f073da3..d8e7db4d 100644 --- a/nnscaler/runtime/gnorm.py +++ b/nnscaler/runtime/gnorm.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# CREDITS: This implementation is inspired by Fairseq https://github.com/facebookresearch/fairseq/blob/main/fairseq/utils.py + from typing import List, Dict, Tuple, Optional, TYPE_CHECKING from dataclasses import dataclass from collections import defaultdict diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index ff672e7a..0a406825 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union import logging import os diff --git a/nnscaler/runtime/resource.py b/nnscaler/runtime/resource.py index 9223fcef..889ca966 100644 --- a/nnscaler/runtime/resource.py +++ b/nnscaler/runtime/resource.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + r""" Runtime information """ diff --git a/nnscaler/runtime/utils.py b/nnscaler/runtime/utils.py index 57e87ca5..b15748ea 100644 --- a/nnscaler/runtime/utils.py +++ b/nnscaler/runtime/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + r"""Runtime Utilities""" from typing import Any, List diff --git a/nnscaler/utils.py b/nnscaler/utils.py index e10b8f9f..55104bca 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from contextlib import contextmanager from functools import wraps from typing import Generator, Optional, Tuple, Callable, List, Set, Any, Iterable, Type, Union diff --git a/nnscaler/version.py b/nnscaler/version.py index 506a4934..46a80f4f 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + __version__ = '0.2.dev0' diff --git a/tests/algorithm/ops/test_dimops.py b/tests/algorithm/ops/test_dimops.py index b3e4ab9c..193c77fb 100644 --- a/tests/algorithm/ops/test_dimops.py +++ b/tests/algorithm/ops/test_dimops.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import torch import os diff --git a/tests/autodist/graph/test_calc_flops.py b/tests/autodist/graph/test_calc_flops.py index 1f1d8295..62d9cc6d 100644 --- a/tests/autodist/graph/test_calc_flops.py +++ b/tests/autodist/graph/test_calc_flops.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/autodist/graph/test_recompute.py b/tests/autodist/graph/test_recompute.py index f1dbaccc..6fe0b93b 100644 --- a/tests/autodist/graph/test_recompute.py +++ b/tests/autodist/graph/test_recompute.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/autodist/partition/test_state.py b/tests/autodist/partition/test_state.py index 2c6a3ce7..b104dcfd 100644 --- a/tests/autodist/partition/test_state.py +++ b/tests/autodist/partition/test_state.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.autodist.op_partition import calc_factors, generate_partitions diff --git a/tests/autodist/pas/test_shared_param_pipeline.py b/tests/autodist/pas/test_shared_param_pipeline.py index 482f9365..74d91705 100644 --- a/tests/autodist/pas/test_shared_param_pipeline.py +++ b/tests/autodist/pas/test_shared_param_pipeline.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/autodist/spmd_solver/test_cube_operator.py b/tests/autodist/spmd_solver/test_cube_operator.py index 29cab169..6d70efe1 100644 --- a/tests/autodist/spmd_solver/test_cube_operator.py +++ b/tests/autodist/spmd_solver/test_cube_operator.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index 7f37b312..1b585308 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/autodist/spmd_solver/test_partition_constraint.py b/tests/autodist/spmd_solver/test_partition_constraint.py index 95bacaef..fc40dd39 100644 --- a/tests/autodist/spmd_solver/test_partition_constraint.py +++ b/tests/autodist/spmd_solver/test_partition_constraint.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/autodist/spmd_solver/test_shared_param.py b/tests/autodist/spmd_solver/test_shared_param.py index 920d93a1..081406e7 100644 --- a/tests/autodist/spmd_solver/test_shared_param.py +++ b/tests/autodist/spmd_solver/test_shared_param.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/autodist/test_dp_solver.py b/tests/autodist/test_dp_solver.py index 330e8630..29f33322 100644 --- a/tests/autodist/test_dp_solver.py +++ b/tests/autodist/test_dp_solver.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import cppimport.import_hook import nnscaler.autodist.dp_solver as dp_solver diff --git a/tests/cli/common.py b/tests/cli/common.py index ea300e1b..39e64d20 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from torch.utils.data import DataLoader, Dataset diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 7e4ba3c1..2bf853e8 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from dataclasses import asdict, dataclass, field from typing import List, Optional, Tuple, Dict, Any, Union import sys diff --git a/tests/cli/test_train_args.py b/tests/cli/test_train_args.py index b6c9b2ad..183ccfbe 100644 --- a/tests/cli/test_train_args.py +++ b/tests/cli/test_train_args.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import nnscaler diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 20b681c0..295e44f3 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pathlib import Path import shutil diff --git a/tests/codegen/test_emit.py b/tests/codegen/test_emit.py index 9ec6b894..62075826 100644 --- a/tests/codegen/test_emit.py +++ b/tests/codegen/test_emit.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest from nnscaler.codegen.emit import CodeEmission, IRValue from nnscaler.ir.cten import IRObject diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index a9d4d6a3..5a6fbf1a 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ pytest unit_tests/compiler/test_compile.py """ diff --git a/tests/compiler/test_model.py b/tests/compiler/test_model.py index 2d1b3ef9..58da2397 100644 --- a/tests/compiler/test_model.py +++ b/tests/compiler/test_model.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import os import logging from functools import partial diff --git a/tests/conftest.py b/tests/conftest.py index 1877d5d2..dbaa1cc9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest from pathlib import Path diff --git a/tests/graph/function/helper.py b/tests/graph/function/helper.py index 1bc5257f..f2dc38dc 100644 --- a/tests/graph/function/helper.py +++ b/tests/graph/function/helper.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from nnscaler import register_op diff --git a/tests/graph/function/test_dataloader.py b/tests/graph/function/test_dataloader.py index 74c5f48f..c950a487 100644 --- a/tests/graph/function/test_dataloader.py +++ b/tests/graph/function/test_dataloader.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ pytest unit_tests/graph/function/test_dataloader.py """ diff --git a/tests/graph/function/test_dict_values.py b/tests/graph/function/test_dict_values.py index 193f329f..7bf5c58b 100644 --- a/tests/graph/function/test_dict_values.py +++ b/tests/graph/function/test_dict_values.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import torch from nnscaler.parallel import parallelize, ComputeConfig diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index 900215f4..59c57057 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ pytest unit_tests/graph/function/test_dimops.py """ diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index aa1743bd..be0a1f98 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + ### Only test the anno creation in these tests import nnscaler.graph.function.function as F diff --git a/tests/graph/function/test_script_func.py b/tests/graph/function/test_script_func.py index f5d6ee4e..00da0520 100644 --- a/tests/graph/function/test_script_func.py +++ b/tests/graph/function/test_script_func.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import torch from nnscaler.parallel import parallelize, ComputeConfig diff --git a/tests/graph/gener/check_inter_rvd.py b/tests/graph/gener/check_inter_rvd.py index 04be22dc..dc4cc2b5 100644 --- a/tests/graph/gener/check_inter_rvd.py +++ b/tests/graph/gener/check_inter_rvd.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Note this is not for test. diff --git a/tests/graph/gener/check_intra_rvd.py b/tests/graph/gener/check_intra_rvd.py index efa9851e..84ce00d7 100644 --- a/tests/graph/gener/check_intra_rvd.py +++ b/tests/graph/gener/check_intra_rvd.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ Note this is not for test. diff --git a/tests/graph/gener/test_producer_fusion.py b/tests/graph/gener/test_producer_fusion.py index 72791c8e..22144064 100644 --- a/tests/graph/gener/test_producer_fusion.py +++ b/tests/graph/gener/test_producer_fusion.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from nnscaler.ir.tensor import IRFullTensor import nnscaler.graph.function.function as F diff --git a/tests/graph/gener/test_reducer_gen.py b/tests/graph/gener/test_reducer_gen.py index c318d39b..0a4cd4eb 100644 --- a/tests/graph/gener/test_reducer_gen.py +++ b/tests/graph/gener/test_reducer_gen.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest from nnscaler.graph.gener.gen import IRAdapterGener diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 429876a3..3aa0f13a 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import ast from textwrap import dedent import sys diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 65e4e905..04b24302 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import importlib from pathlib import Path diff --git a/tests/graph/parser/test_dce.py b/tests/graph/parser/test_dce.py index 60decdc8..d6730a0f 100644 --- a/tests/graph/parser/test_dce.py +++ b/tests/graph/parser/test_dce.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import torch diff --git a/tests/graph/parser/test_ir_obj_constant.py b/tests/graph/parser/test_ir_obj_constant.py index 0263d824..43b52c01 100644 --- a/tests/graph/parser/test_ir_obj_constant.py +++ b/tests/graph/parser/test_ir_obj_constant.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile import math diff --git a/tests/graph/parser/test_no_grad.py b/tests/graph/parser/test_no_grad.py index 963f189b..b2ddbfaf 100644 --- a/tests/graph/parser/test_no_grad.py +++ b/tests/graph/parser/test_no_grad.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import torch diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 85632f83..9c2da080 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import pytest import torch diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index a3b6dc70..7ee3fd51 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import nnscaler from nnscaler.graph.parser.converter import convert_model from nnscaler.profiler.database import get_func diff --git a/tests/graph/parser/test_register_external.py b/tests/graph/parser/test_register_external.py index f5322f20..ce5f3dca 100644 --- a/tests/graph/parser/test_register_external.py +++ b/tests/graph/parser/test_register_external.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import torch import logging diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 392421cd..07913bc7 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from nnscaler.ir.tensor import IRFullTensor, IRSubTensor from nnscaler.ir.operator import IRFwOperation diff --git a/tests/graph/test_loss.py b/tests/graph/test_loss.py index 3494726a..44015a6d 100644 --- a/tests/graph/test_loss.py +++ b/tests/graph/test_loss.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/graph/test_multiref.py b/tests/graph/test_multiref.py index 42aa0e35..3e1dcb3d 100644 --- a/tests/graph/test_multiref.py +++ b/tests/graph/test_multiref.py @@ -1,6 +1,9 @@ """ pytest unit_tests/graph/test_multiref.py """ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import logging from functools import partial diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py index 5af0678d..eb9c8821 100644 --- a/tests/graph/test_segment.py +++ b/tests/graph/test_segment.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import nnscaler import nnscaler.graph.function.function as F from nnscaler.ir.tensor import IRFullTensor diff --git a/tests/graph/tracer/test_buffer.py b/tests/graph/tracer/test_buffer.py index 60ecf086..83cc8047 100644 --- a/tests/graph/tracer/test_buffer.py +++ b/tests/graph/tracer/test_buffer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import torch diff --git a/tests/graph/tracer/test_cls_wrapper.py b/tests/graph/tracer/test_cls_wrapper.py index 192cd486..45614312 100644 --- a/tests/graph/tracer/test_cls_wrapper.py +++ b/tests/graph/tracer/test_cls_wrapper.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import pytest diff --git a/tests/graph/tracer/test_ctxt_manager.py b/tests/graph/tracer/test_ctxt_manager.py index 9d214e9f..b211287c 100644 --- a/tests/graph/tracer/test_ctxt_manager.py +++ b/tests/graph/tracer/test_ctxt_manager.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import torch from nnscaler.graph.parser.converter import convert_model diff --git a/tests/graph/tracer/test_getattr.py b/tests/graph/tracer/test_getattr.py index b4ac5b7e..0086426e 100644 --- a/tests/graph/tracer/test_getattr.py +++ b/tests/graph/tracer/test_getattr.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from nnscaler.graph.parser.converter import to_fx_graph diff --git a/tests/graph/tracer/test_inplace.py b/tests/graph/tracer/test_inplace.py index 6323c5ed..b77980bc 100644 --- a/tests/graph/tracer/test_inplace.py +++ b/tests/graph/tracer/test_inplace.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import operator import _operator import torch diff --git a/tests/graph/tracer/test_module_jit_init.py b/tests/graph/tracer/test_module_jit_init.py index 9165725d..09e97e2d 100644 --- a/tests/graph/tracer/test_module_jit_init.py +++ b/tests/graph/tracer/test_module_jit_init.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from nnscaler.graph.parser.converter import to_fx_graph diff --git a/tests/graph/tracer/test_namedtuple.py b/tests/graph/tracer/test_namedtuple.py index 91b16245..0731c960 100644 --- a/tests/graph/tracer/test_namedtuple.py +++ b/tests/graph/tracer/test_namedtuple.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from collections import namedtuple import torch diff --git a/tests/graph/tracer/test_op_patcher.py b/tests/graph/tracer/test_op_patcher.py index a3ab400e..e1f436b3 100644 --- a/tests/graph/tracer/test_op_patcher.py +++ b/tests/graph/tracer/test_op_patcher.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from types import MethodType from nnscaler.graph.parser.fx.concrete_trace_utils.operator_patcher import OperatorPatcher diff --git a/tests/graph/tracer/test_pytree.py b/tests/graph/tracer/test_pytree.py index 096a52a9..2f64a6f3 100644 --- a/tests/graph/tracer/test_pytree.py +++ b/tests/graph/tracer/test_pytree.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.graph.parser.fx.concrete_trace_utils import pytree_utils from nnscaler.graph.parser.fx.concrete_trace_utils.pytree_utils import ( get_common_spec, diff --git a/tests/graph/tracer/test_scope.py b/tests/graph/tracer/test_scope.py index 18cf1932..bf50ec47 100644 --- a/tests/graph/tracer/test_scope.py +++ b/tests/graph/tracer/test_scope.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch from nnscaler.graph.parser.converter import to_fx_graph diff --git a/tests/integration/common.py b/tests/integration/common.py index 4e8b5a83..28530ae0 100644 --- a/tests/integration/common.py +++ b/tests/integration/common.py @@ -1,3 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Code modified from: https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/tests_fabric/test_fabric.py + +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch class BoringModel(nn.Module): diff --git a/tests/integration/lightning/datasets.py b/tests/integration/lightning/datasets.py index 3769e76d..7e2b6068 100644 --- a/tests/integration/lightning/datasets.py +++ b/tests/integration/lightning/datasets.py @@ -1,4 +1,5 @@ -# Copyright The Lightning AI team. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/integration/lightning/pytorch/simple_datamodules.py b/tests/integration/lightning/pytorch/simple_datamodules.py index 0ead1445..eafd48ff 100644 --- a/tests/integration/lightning/pytorch/simple_datamodules.py +++ b/tests/integration/lightning/pytorch/simple_datamodules.py @@ -1,5 +1,3 @@ -# Copyright The Lightning AI team. -# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/integration/lightning/pytorch/simple_models.py b/tests/integration/lightning/pytorch/simple_models.py index 037abb22..01abc6a4 100644 --- a/tests/integration/lightning/pytorch/simple_models.py +++ b/tests/integration/lightning/pytorch/simple_models.py @@ -1,4 +1,19 @@ -# copied from lightning tests +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Dict, Iterator, List, Optional, Tuple import torch diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index 39229f1b..41023354 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from contextlib import contextmanager import os from pathlib import Path diff --git a/tests/ir/test_cten.py b/tests/ir/test_cten.py index 6c65b158..ab054861 100644 --- a/tests/ir/test_cten.py +++ b/tests/ir/test_cten.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import pytest diff --git a/tests/ir/test_tensor.py b/tests/ir/test_tensor.py index 5f020d65..724c88c0 100644 --- a/tests/ir/test_tensor.py +++ b/tests/ir/test_tensor.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from nnscaler.ir.tensor import IRSubTensor, IRFullTensor diff --git a/tests/launch_torchrun.py b/tests/launch_torchrun.py index 50cc596e..fce62ed2 100644 --- a/tests/launch_torchrun.py +++ b/tests/launch_torchrun.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Callable import uuid import torch diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index c06ac4a7..fa8900cd 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from datetime import datetime import math import random diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py index 97ac8ff4..b89ce66a 100644 --- a/tests/parallel_module/test_attr_dedup.py +++ b/tests/parallel_module/test_attr_dedup.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile from pathlib import Path import pytest diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index a732f860..c7bc814a 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import os from pathlib import Path diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index fb0ac375..34449e6d 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import itertools import re diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index d3ea99fb..5c93b4d9 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pathlib import Path import tempfile import torch diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index 59115655..944685bb 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile from pathlib import Path import pytest diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index d0d02f7e..0fb417fa 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile from pathlib import Path import pytest diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py index 21e886cf..3b9fa85c 100644 --- a/tests/parallel_module/test_checkpoint_unused.py +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import itertools import re diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index 57509ea6..bff5c315 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import itertools import re diff --git a/tests/parallel_module/test_embedding.py b/tests/parallel_module/test_embedding.py index 21d1a8d1..6d48f9aa 100644 --- a/tests/parallel_module/test_embedding.py +++ b/tests/parallel_module/test_embedding.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import tempfile import pytest diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index b1d5fc38..b4c12e87 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ PYTHONPATH=.:$PYTHONPATH torchrun \ --nproc_per_node=4 \ diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index a4295d49..44cb9373 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import inspect import tempfile from contextlib import nullcontext diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 5dfb2de9..fbba994c 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pathlib import Path import shutil import tempfile diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 91322dc8..9344c04a 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import pytest diff --git a/tests/parallel_module/test_line_timer.py b/tests/parallel_module/test_line_timer.py index 44c865f9..e1b8a0d8 100644 --- a/tests/parallel_module/test_line_timer.py +++ b/tests/parallel_module/test_line_timer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pathlib import Path import tempfile import torch diff --git a/tests/parallel_module/test_nested.py b/tests/parallel_module/test_nested.py index 1c215949..9a24c1f0 100644 --- a/tests/parallel_module/test_nested.py +++ b/tests/parallel_module/test_nested.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import torch diff --git a/tests/parallel_module/test_normlayer.py b/tests/parallel_module/test_normlayer.py index 2a4f063d..243a56d5 100644 --- a/tests/parallel_module/test_normlayer.py +++ b/tests/parallel_module/test_normlayer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import uuid import torch.distributed as dist import tempfile diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index 90da6b6d..c7886353 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pathlib import Path from time import sleep import sys diff --git a/tests/parallel_module/test_pyfunc.py b/tests/parallel_module/test_pyfunc.py index 86a854ff..29f90114 100644 --- a/tests/parallel_module/test_pyfunc.py +++ b/tests/parallel_module/test_pyfunc.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import tempfile diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index 2c145c3d..6aaf5ca6 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile from pathlib import Path from collections import defaultdict diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py index 0af1a934..c10d6f78 100644 --- a/tests/parallel_module/test_scale_grads.py +++ b/tests/parallel_module/test_scale_grads.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import itertools import re diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 40e76df3..20f72f86 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import itertools import re diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 84c4d886..a0e1edf4 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import itertools import re diff --git a/tests/profiler/test_op_profile.py b/tests/profiler/test_op_profile.py index 9af6ac11..40e8e32e 100644 --- a/tests/profiler/test_op_profile.py +++ b/tests/profiler/test_op_profile.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import time import tempfile diff --git a/tests/runtime/test_dataloader.py b/tests/runtime/test_dataloader.py index 7f39104a..d5dbc2d6 100644 --- a/tests/runtime/test_dataloader.py +++ b/tests/runtime/test_dataloader.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import torch from nnscaler.runtime.utils import MicroBatchDataLoader, microbatches diff --git a/tests/runtime/test_f16_optimizer.py b/tests/runtime/test_f16_optimizer.py index 95609c90..3f843496 100644 --- a/tests/runtime/test_f16_optimizer.py +++ b/tests/runtime/test_f16_optimizer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pathlib import Path import shutil diff --git a/tests/runtime/test_gnorm.py b/tests/runtime/test_gnorm.py index c1ef5eda..5d8c4a87 100644 --- a/tests/runtime/test_gnorm.py +++ b/tests/runtime/test_gnorm.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ This test is to verify the correctness of the gradient norm algorithm for nnscaler. diff --git a/tests/runtime/test_grad_accum.py b/tests/runtime/test_grad_accum.py index 3c0b5909..21f38763 100644 --- a/tests/runtime/test_grad_accum.py +++ b/tests/runtime/test_grad_accum.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import pytest from functools import partial diff --git a/tests/runtime/test_module_merge.py b/tests/runtime/test_module_merge.py index ba208926..3e35c72a 100644 --- a/tests/runtime/test_module_merge.py +++ b/tests/runtime/test_module_merge.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import nnscaler import os diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py index 6efe6e1b..811b3510 100644 --- a/tests/runtime/test_reducer.py +++ b/tests/runtime/test_reducer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ pytest unit_tests/runtime/test_reducer.py """ diff --git a/tests/runtime/test_runtime_collectives.py b/tests/runtime/test_runtime_collectives.py index 70b39dc5..e1cf375a 100644 --- a/tests/runtime/test_runtime_collectives.py +++ b/tests/runtime/test_runtime_collectives.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import List import nnscaler diff --git a/tests/test_policies.py b/tests/test_policies.py index 5f56c24a..4f1fa7b0 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile from typing import * diff --git a/tests/test_program.py b/tests/test_program.py index 8ed79062..3b59d99f 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest import torch diff --git a/tests/test_torchrun.py b/tests/test_torchrun.py index 3d09c952..0a9b0a38 100644 --- a/tests/test_torchrun.py +++ b/tests/test_torchrun.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import os from .launch_torchrun import launch_torchrun diff --git a/tests/test_utils.py b/tests/test_utils.py index 1f5f3d7f..dd6ec951 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import pytest from nnscaler.utils import select_many diff --git a/tests/utils.py b/tests/utils.py index bd1eca0a..0e3eaad0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import os import re import sys From 44725cf1abf5a5ac15fe11430abd561e0b253a61 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 27 Sep 2024 06:29:26 +0000 Subject: [PATCH 1736/1892] Merged PR 2273: bump version to v0.3 bump version to v0.3 --- README.md | 20 +++++++++----------- nnscaler/version.py | 2 +- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index d5c6cddb..0319019a 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -drawing +drawing nnScaler: Compiling DNN models for Parallel Training over Multiple Devices ============== @@ -40,7 +40,7 @@ Install the following packages before the installation of cube: PyTorch >= 2.0, < 2.4 (2.2.0 is recommanded) ### (Option 1) Install nnScaler from source -Execute below commands in nnScaler directory: +Execute below commands in nnScaler directory: pip install -r requirements.txt pip install -e . @@ -62,13 +62,13 @@ To get started, install the latest wheel by visiting [DevOps Artifacts](https:// ### Prerequisite for Llama-3 -Install packages required to run Llama-3. Besides, a certain version of CUDA library is needed during flash-attn installation. For example, [CUDA V11.8](https://developer.nvidia.com/cuda-11-8-0-download-archive) is needed if using PyTorch 2.20. +Install packages required to run Llama-3. Besides, a certain version of CUDA library is needed during flash-attn installation. For example, [CUDA V11.8](https://developer.nvidia.com/cuda-11-8-0-download-archive) is needed if using PyTorch 2.20. python -m pip install transformers==4.40.0 flash-attn==2.5.5 tensorboard ### Model Access -Obtain access of Llama-3 model from [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), where you will receive an access token which should be set as an environment variable: +Obtain access of Llama-3 model from [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), where you will receive an access token which should be set as an environment variable: export HF_TOKEN= @@ -101,7 +101,7 @@ class WrapperModel(torch.nn.Module): def main(args): # data config dataloader_config = ... - + # model config model_config = ModelConfig( type=WrapperModel, @@ -109,15 +109,15 @@ def main(args): 'model_id': args.model_id, }, ) - # optimizer hyperparameters + # optimizer hyperparameters optimizer_config = OptimizerConfig( type=MixedPrecisionAdamW, args={'lr': 2e-5, 'betas': (0.9, 0.95), 'weight_decay': 0.0, 'fused': True}, #... ) #... - - # setup trainer with configs of dataloader/model/optimizer, etc. + + # setup trainer with configs of dataloader/model/optimizer, etc. trainer = Trainer(train_args=TrainerArgs( #... model=model_config, @@ -131,7 +131,7 @@ def main(args): ### Run the example Llama-3 training -Then we can start the example, and all the parallelization tasks will be finished by nnScaler automatically. +Then we can start the example, and all the parallelization tasks will be finished by nnScaler automatically. ```shell cd examples/llama3_8B_128K @@ -190,8 +190,6 @@ Or if you have multiple nodes, for example 2 nodes with 4 GPUs each: NOTE: The local batch size is fixed by default, so using more workers will result in a larger global batch size. -💡 _For advanced usages, please refer to: **TODO:link to rst docs**_ - # Success Stories diff --git a/nnscaler/version.py b/nnscaler/version.py index 46a80f4f..409b28ef 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -__version__ = '0.2.dev0' +__version__ = '0.3' From 498491de6c649f0853a2bba6d6bf2fea61517a0b Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Fri, 27 Sep 2024 07:36:55 +0000 Subject: [PATCH 1737/1892] Merged PR 2275: Fix packaging bug --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4d02a157..398b97f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,4 +27,4 @@ dynamic.dependencies.file = "requirements.txt" # the following part only affects wheel, not sdist # since our current plan is to use cppimport, sdist is not needed packages.find.include = ["nnscaler*"] -package-data = { nnscaler = ["resources/**", "autodist/*.cpp"] } +package-data = { nnscaler = ["resources/**", "autodist/*.h", "autodist/*.cpp"] } From 2866af18e4fd444f5ccc5f15ed7240c7401fd1dd Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 8 Oct 2024 07:01:21 +0000 Subject: [PATCH 1738/1892] Merged PR 2267: Refine Async Support Refine Async Support 1. sort reducer params based on the first used position in `forward` 2. refine build bucket logic: strictly follow the order of params 3. refine mixed precision support: use a separate reducer for each dtype. 4. support buckets in sync-mode. Parity Check: Verified the result of async and async are exactly the same given they use the same buckets. But different bucket configuration can lead to different results. --- docs/source/parallel_module.md | 12 + nnscaler/codegen/lifecycle.py | 8 + nnscaler/codegen/module/module.py | 76 +++++- nnscaler/flags.py | 5 +- nnscaler/graph/gener/gen.py | 46 ++-- nnscaler/ir/adapter/adapter.py | 29 ++- nnscaler/parallel.py | 20 +- nnscaler/policies.py | 6 + nnscaler/runtime/adapter/reducer.py | 91 ++++--- nnscaler/runtime/module.py | 12 +- tests/parallel_module/test_end2end.py | 66 +++++- .../test_end2end_mix_precision.py | 222 ++++++++++++++++++ tests/parallel_module/test_init.py | 98 +++++++- tests/runtime/test_reducer.py | 22 +- tests/utils.py | 13 +- 15 files changed, 631 insertions(+), 95 deletions(-) create mode 100644 tests/parallel_module/test_end2end_mix_precision.py diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 616bb71e..301d59cb 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -191,6 +191,9 @@ class ComputeConfig: inference_only : bool = False use_end2end: bool = False + use_async_reducer: bool = False + reducer_bucket_cap_mb: Optional[float] = None + pas_config: Dict[str, Any] = field(default_factory=dict) user_config: Dict[str, Any] = field(default_factory=dict) ``` @@ -233,6 +236,15 @@ We can categorize the fields into 4 categories: - `zero_ngroups`: the number of groups to be used in zero. - `inference_only`: whether to generate code for inference only. If it is true, the generated code can not be used to train the model. - `use_end2end`: whether to use end2end training. For the requirement of end2end, see the description above. + - `use_async_reducer`: whether to use async reducer. + If it is true, the gradients will be reduced asynchronously. + Please note this only works when `use_end2end` is true. + - `reducer_bucket_cap_mb`: the bucket capacity of the reducer. + If it is `None` or `0`, the default value will be used, which is + - 25MB for async, the same default value with pytorch ddp implementation + - no limit for sync + + Please note this only works when `use_end2end` is true. - `pas_config`: the configuration for the PAS policy (partition-assign-schedule policy, which describes how to place all computations across devices. For details, please refer to [PAS Policies](#pas-policies)). It is a dictionary, and will be used by the PAS policy. Please note different PAS will have different configurations, diff --git a/nnscaler/codegen/lifecycle.py b/nnscaler/codegen/lifecycle.py index 5bdcf5b5..ee44811d 100644 --- a/nnscaler/codegen/lifecycle.py +++ b/nnscaler/codegen/lifecycle.py @@ -38,12 +38,20 @@ def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: inputs : Iterable[IRObject] if isinstance(node, (IRSegment, ExeReuseCell)): + # this is only for scheduler to track the lifetime of tensors + # for module code generation, no IRSegment/ExeReuseCell will be used. + # so will not go into this branch. + # forward segment if node.isfw(): outputs = node.outputs() inputs = node.inputs() # backward segment else: + # in `_train_step`, we will explicitly call backward to generate gradients. + # When pipeline is enabled, there will be multiple backward calls in the same segment. + # So we also need to track the temporary gradient tensors + # and delete them after the backward call to save memory. fw_inputs, fw_outputs, output_grads, input_grads = \ func_emission.get_backward_callsite_io_tensors(node) # remove loss gradient diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 79a82229..fad8ff28 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -10,7 +10,7 @@ import inspect from nnscaler.ir.cten import IRCell -from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from nnscaler.ir.adapter import IRWeightReducer, IRAdapter from nnscaler.ir.adapter.prim import CollectivePrim @@ -29,6 +29,7 @@ from nnscaler.codegen.lifecycle import LifeCycle from nnscaler.flags import CompileFlag +from nnscaler import __version__ as runtime_version _logger = logging.getLogger(__name__) @@ -118,11 +119,13 @@ def __init__( self.enable_dp = self.runtime_ndevs > len(self.devices) self.init_code: List[str] = [ - '\n\n########## Generated Model Code ###########', + '########## Generated Model Code ###########', 'from typing import *', 'from pathlib import Path', 'import torch', 'import torch.utils.checkpoint as ckpt', - 'import nnscaler', 'import _operator', 'from numpy import inf', 'import builtins', '', ''] + 'import nnscaler', 'import _operator', 'from numpy import inf', 'import builtins', '', + f'runtime_version = {runtime_version!r}', '', '' + ] if CompileFlag.use_nnfusion: self.init_code.extend(['import nnfusion', '']) @@ -184,9 +187,9 @@ def collect_rest_params(segment): if len(rest_params) == 0: continue # create reducer and append to the execution - reducer = IRWeightReducer(rest_params) - reducer.device = device # will be scaled in `self.scale` - self.execplan.at(device).append(reducer) + # device will be scaled in `self.scale` + for reducer in IRWeightReducer.from_weights(rest_params, device): + self.execplan.at(device).append(reducer) def get_comm_groups(self): """ @@ -423,6 +426,18 @@ def forward(self, x, y=None, z=None): # initialize communication groups self.emit_comm_groups() + # we can have multiple segments in the graph when pipeline is enabled. + # Here we don't use tid to sort parameters + # because that assumption may be not true in the future, + # and the current implementation is clearer and more robust. + # key: parameter tensor, value: (segment index, node index) + param_first_used_pos: Dict[IRFullTensor, Tuple[int, int]] = {} + for i, n in enumerate(sequence): + if isinstance(n, IRSegment) and n.isfw(): + for k, v in self._get_param_first_used_pos(n).items(): + if k not in param_first_used_pos: + param_first_used_pos[k] = (i, v) + # emit code for node in sequence: if isinstance(node, IRSegment): @@ -433,7 +448,7 @@ def forward(self, x, y=None, z=None): elif isinstance(node, IRAdapter): codes = self.emit_adapter(node, prefix_attr='self.', async_op=CompileFlag.async_comm) elif isinstance(node, IRWeightReducer): - self.init_reducer(node, device) + self.init_reducer(node, device, param_first_used_pos, as_parallel_module) codes = self.emit_reducer(node) elif isinstance(node, IRBpOperation): continue @@ -468,9 +483,19 @@ def forward(self, x, y=None, z=None): graph_sched = self.execplan.graph.sched cb.insert_body(f'use_scheduler = {graph_sched is not None}') cb.insert_body(f'nmicros_per_scheduler_step = {graph_sched.nmicros if graph_sched is not None else 1}') + if as_parallel_module: cb.insert_body(f'rank = {device}') # save rank in class level - with FunctionBlock(func_name='__init__', args=['self', 'init_params=True']) as ib: + # async_op and max_bucket_size_bytes parameters are for testing purpose + # and will not expose to user + with FunctionBlock(func_name='__init__', + args=[ + 'self', + 'init_params=True', + f'async_op={CompileFlag.async_reducer}', + f'max_bucket_size_bytes={CompileFlag.max_reducer_bucket}' + ] + ) as ib: ib.insert_body(self.model_init_statements) ib.insert_body('') ib.insert_body('self._post_init(init_params)') @@ -672,15 +697,22 @@ def init_attributes(self, node: IRCell): self.init_attributes(sub_node) return - def init_reducer(self, node: IRWeightReducer, device: int) -> None: + def init_reducer(self, + node: IRWeightReducer, + device: int, + param_first_used_pos: Dict[IRFullTensor, int], + as_parallel_module: bool = True, + ) -> None: """ Emit code to initialize involved reducer objects in `__init__`. The fields storing intermediate codes that are populated by this method: - `model_init_statements` """ - max_nbytes = CompileFlag.max_reducer_bucket - async_op = CompileFlag.async_reducer + # when parallel module is used, + # `max_bucket_size_bytes` and `async_op` are passed as arguments + max_nbytes = CompileFlag.max_reducer_bucket if not as_parallel_module else 'max_bucket_size_bytes' + async_op = CompileFlag.async_reducer if not as_parallel_module else 'async_op' zero = CompileFlag.use_zero zero_ngroups = CompileFlag.zero_ngroups reduce_op = f"'{CompileFlag.reducer_op}'" @@ -702,7 +734,12 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: reducer=reducer_name, ranks=ranks, reduce_op=reduce_op, async_op=async_op, zero=zero, max_nbytes=max_nbytes, zero_ngroups=zero_ngroups) self.model_init_statements.append(init_code) - weights = [self.tensor_name(t, prefix_attr='self.') for t in weights] + # sort weights by first used time (which is gradient all-reduce time in reverse order) + # so that weights with similar gradient all-reduce time are bucketed together + weights = [ + self.tensor_name(t, prefix_attr='self.') + for t in sorted(weights, key=lambda t: param_first_used_pos[t.parent]) + ] for weight in weights: add_param_code = add_param.format(reducer=reducer_name, weight=weight) self.model_init_statements.append(add_param_code) @@ -863,6 +900,21 @@ def recompute(tensor_2222): return codes + def _get_param_first_used_pos(self, segment: IRSegment) -> Dict[IRFullTensor, int]: + """ + Get the first used node index of each parameter in the segment. + """ + # get all the parameters in the segment + first_used_pos: Dict[IRFullTensor, int] = {} + + for i, node in enumerate(segment.nodes()): + # parameters are used as inputs of the node + for tin in IRSegment.get_objects_from_complex(node.inputs()): + if isinstance(tin, IRSubTensor) and tin.is_param() and tin.parent not in first_used_pos: + first_used_pos[tin.parent] = i + + return first_used_pos + def clear(self): """ Clear buffer that used for generating code diff --git a/nnscaler/flags.py b/nnscaler/flags.py index 032fe3b6..cde7965a 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -47,8 +47,9 @@ class CompileFlag: use_zero = _to_bool('USE_ZERO') # use async communication to overlap gradient synchronization and backward computation async_reducer = _to_bool('ASYNC_REDUCER') # use async reducer - # maximal reducer weight bytes for one allreduce (only effective for async): default 128MB - max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=137217728) + # maximal reducer weight bytes for one allreduce (only effective for async): + # default 0 means using the default value in reducer + max_reducer_bucket = _to_int('MAX_REDUCER_BUCKET', default=0) # perform reducer op on gradients, can be sum, avg, mean, max, min. Default is sum reducer_op = os.environ.get('REDUCER_OP', default='sum') # zero_ngroups is the number of subgroups in each original ZeRO gruop (e.g., weights reducer) diff --git a/nnscaler/graph/gener/gen.py b/nnscaler/graph/gener/gen.py index 7f870458..8184787e 100644 --- a/nnscaler/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -29,12 +29,12 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) -> List[IRFwOperation]: """ - Create dummy operators segment inputs and outputs. + Create dummy operators segment inputs and outputs. @param segment IRSegment: the target segment @param inputs bool: True for creating dummy operators to produce segement's inputs @param outputs bool: True for creating dummpy operators to consume segment's outputs - + @return nodes List[IRCell]: the generated operation """ # devices = segment.device @@ -66,7 +66,7 @@ def create_dummy(segment: IRSegment, inputs: bool = True, outputs: bool = True) return input_producers, output_consumers -def expand_devices(tensors: List[Optional[IRSubTensor]], +def expand_devices(tensors: List[Optional[IRSubTensor]], producer: bool = False, consumer: bool = False) -> List[IRSubTensor]: """ Scatter a tensor if it is on multiple devices. It produces a tensor list where @@ -110,7 +110,7 @@ def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: @param graph IRGraph: the graph without adapter @param cost_fn Optional[Callable]: takes an IRAdapterPrim and outputs a cost in float. default to be None, which will use communication volume. - + @return graph IRGraph: the graph with adapter inserted """ # reorder producer and consumer ordering @@ -145,7 +145,7 @@ def remove_anchor(graph: IRSegment): elif isinstance(anchor, IRSegment): IRAdapterGener.remove_anchor(anchor) return graph - + @staticmethod def auto_pyfunc(graph: IRGraph): """Transform and assign IRPyFunc. @@ -155,10 +155,10 @@ def auto_pyfunc(graph: IRGraph): To restrict the replicaed devices in pipeline-like scenarios, use `graph.staging` to group the operators into segments. - + Args: graph (IRGraph): the graph to be transformed - + Returns: graph (IRGraph): the transformed graph """ @@ -198,7 +198,7 @@ def auto_pyfunc(graph: IRGraph): def gen_weight(graph: IRGraph) -> IRGraph: """Generate cross-device weight reducers for gradient accumulation. - If a weight tensor is replicated across multiple devices by different / partitioned operators, + If a weight tensor is replicated across multiple devices by different / partitioned operators, the weight tensor is required to accumulate gradients according to chain rules. However, if the weight tensor is replicated across devices by replicated operators, @@ -244,7 +244,7 @@ def collect_sub_weight(graph: IRSegment): dev_cids = [tuple(sorted(cids)) for cids in dev_cids.values()] cross_device_replicated = all(cids == dev_cids[0] for cids in dev_cids) - # otherwise, we only support fully partitioned consumers, + # otherwise, we only support fully partitioned consumers, # the weight's gradient should be accumulated. fully_partitioned = len(set(c.cid for c in consumers)) == len(consumers) @@ -265,7 +265,7 @@ def collect_sub_weight(graph: IRSegment): # However, we don't support such fine-grained accumulation for now, and we only support # to either accumulate same sub-weight tensors or not accumulate non-overlapped sub-weight tensors. for ftensor, sub_ws in sub_weights.items(): - # all the sub weights can only be + # all the sub weights can only be # 1) replicated (sw1 == sw2) or, # 2) partitioned without overlapping (not sw1.overlap(sw2)) for sw1, sw2 in itertools.combinations(sub_ws, 2): @@ -277,7 +277,7 @@ def collect_sub_weight(graph: IRSegment): f"FullTensor weight: {ftensor}\n" f"Consumers:\n{nl.join([repr(w.cell) for w in sub_ws])}\n" ) - + # only record sub-weight that is consumed by multiple devices sub_weight_devices: Dict[IRSubTensor, Tuple[int,]] = dict() # - pop out replicated sub weights as they will have full gradients, @@ -290,16 +290,16 @@ def collect_sub_weight(graph: IRSegment): if len(devices) > 1: devices = tuple(sorted(devices)) sub_weight_devices[sub_weight] = devices - + # create reducer - reducers: Dict[Tuple[int,], List[IRSubTensor]] = dict() + reducers: Dict[Tuple[int,...], List[IRSubTensor]] = dict() for subw, devices in sub_weight_devices.items(): reducers.setdefault(devices, []).append(subw) + for devices, subws in reducers.items(): - reducer = IRWeightReducer(subws) - reducer.device = devices - # insert reducer to as the last node. - graph.insert(reducer, graph.nnodes) + for reducer in IRWeightReducer.from_weights(subws, devices): + # insert reducer to as the last node. + graph.insert(reducer, graph.nnodes) return graph @@ -317,7 +317,7 @@ def gen_activation(graph: IRSegment, allow_recompute: bool = True, cost_fn: Opti default to be None, which will use communication volume. Returns: - graph (IRGraph): the (inplace) modified graph with activation adapters. + graph (IRGraph): the (inplace) modified graph with activation adapters. """ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # e.g., loss or parameter/buffer @@ -331,7 +331,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: input_producer, output_consumer = create_dummy(graph, inputs=True, outputs=True) bgraph: Optional[IRSegment] = graph.mirror - + # local producer fusion and local consumer multiref ftensors = [] _cnt = 0 @@ -349,7 +349,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: if _cnt % 100 == 0: _logger.info(f'processed local fusion & multiref for {_cnt} tensors') _logger.info(f'finish local fusion & multiref for {_cnt} tensors') - + # reorder again since inserted multiref could be mis-ordered graph._reorder_producer_consumer() _logger.info("finish reordering producer and consumer") @@ -382,7 +382,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bptensors = expand_devices(bptensors, producer=True) bconsumers, bctensors = bgraph.consumers(ftensor.grad), bgraph.ctensors(ftensor.grad) if ftensor in input_producer: - bctensors = bctensors + tuple(fwop.output(0).grad for fwop in input_producer[ftensor]) + bctensors = bctensors + tuple(fwop.output(0).grad for fwop in input_producer[ftensor]) bctensors = expand_devices(bctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" # special case for loss tensor: @@ -512,7 +512,7 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens apllied with same recompute region. Otherwise no recompute. @param tensors List[IRSubTensor]: tensors to be fused in local device - + @return new_ftensor IRFullTensor: the new full tensor. If cannot fuse, the original ftensor. """ @@ -611,7 +611,7 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens f"Users can try to adjust node ordering to meet with accum order\n" f"{graph.debug_tensor_map_str(ftensor)}" ) - + # === Optimization: quick accumulation to early release tensor lhs, rhs = ptensors[0], None for ptensor in ptensors[1:]: diff --git a/nnscaler/ir/adapter/adapter.py b/nnscaler/ir/adapter/adapter.py index 11fc6f53..a21413cc 100644 --- a/nnscaler/ir/adapter/adapter.py +++ b/nnscaler/ir/adapter/adapter.py @@ -4,6 +4,8 @@ from typing import List, Optional, Dict import copy +import torch + from nnscaler.ir.adapter.prim import IRAdapterPrim, IdentityPrim from nnscaler.ir.tensor import IRSubTensor from nnscaler.ir.cten import IRCell @@ -112,7 +114,7 @@ def dispatch(self, devid: int, _mirror: bool = True): inputs = [] for itensor in self.inputs(): if devid in itensor.device and itensor not in inputs: - inputs.append(itensor) + inputs.append(itensor) outputs = [] for otensor in self.outputs(): if devid in otensor.device and otensor not in outputs: @@ -173,6 +175,9 @@ class IRWeightReducer(IRCell): def __init__(self, weights: List[IRSubTensor], name='reducer'): if not all(isinstance(w, IRSubTensor) and w.is_param() for w in weights): raise RuntimeError("Expected a list of gradient IRSubTensor") + if len(set(w.dtype for w in weights)) != 1: + raise RuntimeError("All weights should have the same dtype") + signature = None super().__init__(name, signature, len(weights), 0) for idx, weight in enumerate(weights): @@ -180,7 +185,7 @@ def __init__(self, weights: List[IRSubTensor], name='reducer'): def isfw(self) -> bool: return False - + def dispatch(self, device: int): return self @@ -190,3 +195,23 @@ def __repr__(self): def module_repr(self) -> str: return repr(self) + + @classmethod + def from_weights(cls, weights: List[IRSubTensor], devices, name='reducer') -> List['IRWeightReducer']: + """! + Create reducers from a list of weights + """ + if not weights: + return [] + + dtype_groups: Dict[torch.dtype, List[IRSubTensor]] = {} + for sub in weights: + dtype_groups.setdefault(sub.dtype, []).append(sub) + + reducers = [] + for typed_subws in dtype_groups.values(): + reducer = IRWeightReducer(typed_subws, name) + reducer.device = devices + reducers.append(reducer) + + return reducers diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 5e7ba873..86a9c0a8 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -87,6 +87,14 @@ class ComputeConfig: # 2. the first return value of `module.forward` must be the loss # which must be a scalar tensor use_end2end: bool = False + # whether to use async reducer + # if True, the gradient all-reduce will be async, + # This only works when the `use_end2end` is `True` for now. + use_async_reducer: bool = False + # the maximal reducer weight bytes for one allreduce in megabytes + # It is also effective for sync reducer. + # None/0 means using the default value. (25MB for async, no limit for sync) + reducer_bucket_cap_mb: Optional[float] = None # PAS policy settings # you can also put any other settings that can affect code generation here. @@ -140,6 +148,12 @@ def __post_init__(self): # have to use __setattr__ for frozen dataclass super().__setattr__('zero_ngroups', 1) + if self.use_async_reducer and not self.use_end2end: + raise ValueError("use_async_reducer is only supported in end2end mode.") + + if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: + raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") + def apply_pipeline_scheduler( self, graph: IRGraph, @@ -294,7 +308,10 @@ def _flags(flags, /, **kwargs): def _compile_flags(compute_config: ComputeConfig): return _flags( CompileFlag, - async_reducer=False, reducer_op='sum', async_comm=False, + async_reducer=compute_config.use_async_reducer, reducer_op='sum', + max_reducer_bucket=int(compute_config.reducer_bucket_cap_mb * 1024 * 1024) + if compute_config.reducer_bucket_cap_mb else None, + async_comm=False, use_zero=compute_config.use_zero, zero_ngroups=compute_config.zero_ngroups, trace_strategy=compute_config.trace_strategy, @@ -858,6 +875,7 @@ def _load_parallel_module_class( # parallel_module_class.__module__ = module_class.__module__ parallel_module_class.__orig_module_class__ = module_class # save the original module class # override train_step and infer_step only if they are defined in the generated module (end2end module only) + parallel_module_class.runtime_version = getattr(gen_imported, 'runtime_version', None) parallel_module_class._train_step = getattr(gen_imported, '_train_step', parallel_module_class._train_step) parallel_module_class._infer_step = getattr(gen_imported, '_infer_step', parallel_module_class._infer_step) return parallel_module_class diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 2ffffd10..9c0c5b0d 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -160,6 +160,9 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): """ if not cfg.use_end2end: raise ValueError("Hybrid policy only supports end2end module") + if cfg.use_async_reducer: + raise ValueError("Hybrid policy does not support async reducer") + ngpus: int = cfg.plan_ngpus nstages = cfg.pas_config.get('pipeline_nstages', cfg.plan_ngpus) nmicros = cfg.pas_config['pipeline_nmicros'] @@ -211,6 +214,9 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: explore_pipeline = pas_cfg.get('explore_pipeline', False) if explore_pipeline and not cfg.use_end2end: raise ValueError("explore_pipeline cannot be enabled if use_end2end is False") + if explore_pipeline and cfg.use_async_reducer: + raise ValueError("explore_pipeline cannot be enabled if use_async_reducer is True") + pipeline_scheduler = pas_cfg.get('pipeline_scheduler', '1f1b') if pipeline_scheduler != '1f1b': raise ValueError(f"Only 1f1b scheduler is supported in autodist.") diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index 07993926..1206a47e 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import List, Dict, Tuple, Any, Callable, Optional, Set +from typing import List, Dict, Tuple, Any, Callable, Optional, Set, Sequence from functools import partial import logging import torch @@ -328,8 +328,12 @@ def reset(self): class Reducer: + # the default bucket cap for async reducer in megabytes + # with the same value as pytorch + # https://github.com/pytorch/pytorch/blob/4fd16dd8aa259cd75c9a6d2ddcd8171cd1ee8e28/torch/nn/parallel/distributed.py#L548 + _DEFAULT_BUCKET_CAP_MB = 25 # 25MB, the same as pytorch - def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, + def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None, reduce_op: str = 'sum', async_op: bool = False, zero: bool = False, zero_ngroups: int = 1): """ @@ -337,20 +341,27 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes=536870912, This assumes the communication group is already created by every rank. - @param ranks List[int]: reducer communication group - @param max_bucket_size_bytes int: largest bucket size for one-time communication, - only work for asynchronous reducer. - @param reduce_op str: reduce operation, can be 'sum', 'avg', 'max' or 'min' (default 'sum') - @param async_op bool: whether to overlap with backward computation (default False) - @param zero bool: whether to apply ZeRO optimization on gradients - @param zero_ngroups int: number of ZeRO subgroups in the original ZeRO group + Args: + ranks (List[int]): reducer communication group + max_bucket_size_bytes (Optional[int]): largest bucket size for one-time communication, + `0` or `None` will use default value, + which is `_DEFAULT_BUCKET_CAP_MB` for async reducer, and no limit for sync reducer. + Default is `None` + reduce_op (str): reduce operation, can be 'sum', 'avg', 'max' or 'min' (default 'sum') + async_op (bool): whether to overlap with backward computation (default False) + zero (bool): whether to apply ZeRO optimization on gradients + zero_ngroups (int): number of ZeRO subgroups in the original ZeRO group """ self._params: List[torch.nn.Parameter] = list() self._param_ids: Set[int] = set() self._numel: int = 0 self._ranks = ranks self._group = DeviceGroup().get_group(ranks) - self._bucket_size: Optional[int] = max_bucket_size_bytes if async_op else None + + self._bucket_size: Optional[int] = max_bucket_size_bytes + if not self._bucket_size and async_op: + self._bucket_size = self._DEFAULT_BUCKET_CAP_MB * 1024 * 1024 + self._reduce_op = _get_reduce_op(reduce_op) # buckets stands for a transission unit self._buckets: List[Bucket] = list() @@ -397,11 +408,11 @@ def zero_ngroups(self) -> int: return self._zero_ngroups @property - def params(self) -> Tuple[torch.nn.Parameter]: + def params(self) -> Tuple[torch.nn.Parameter, ...]: return tuple(self._params) @property - def ranks(self) -> Tuple[int]: + def ranks(self) -> Tuple[int, ...]: return tuple(self._ranks) @property @@ -415,7 +426,7 @@ def zero(self) -> bool: return self._zero @property - def buckets(self) -> Tuple[Bucket]: + def buckets(self) -> Tuple[Bucket, ...]: return tuple(self._buckets) @property @@ -451,33 +462,39 @@ def build_buckets(self): than the max_bucket_size_bytes. """ # step 1: build bucket for overlapping gradient synchronization - bucket_size = self._numel * 8 + 1 if self._bucket_size is None else self._bucket_size - buckets = {} - dtype2size = {} + # self._numel * 8 + 1 here is to make sure + # the bucket size is larger than the total size of all parameters + # 8 is the size of float64, which is the largest data type in PyTorch + + # TODO: we may use a small bucket size for the first bucket, which is used in pytorch + # https://github.com/pytorch/pytorch/blob/4fd16dd8aa259cd75c9a6d2ddcd8171cd1ee8e28/torch/nn/parallel/distributed.py#L1172C17-L1172C36 + # TODO: use native version of reducer, which is more efficient + # (used in pytorch, with a couple percentage improvement) + bucket_size = self._numel * 8 + 1 if not self._bucket_size else self._bucket_size + + # items in the bucket is params list + seq_buckets: List[List[torch.nn.Parameter]] = [] + last_bucket_size = None + + assert len(set(p.dtype for p in self._params)) == 1, ( + "All parameters in the reducer should have the same data type" + ) for param in self._params: if param.requires_grad: cur_byte_size = param.nelement() * param.element_size() - tp = param.data.type() - if tp not in buckets: - buckets[tp] = [[param]] - dtype2size[tp] = cur_byte_size + # also work when cur_byte_size > bucket_size + # It will go the `else` branch + # and finish the current bucket and start a new bucket. + # This new bucket will be sealed in the next iteration + if len(seq_buckets) == 0: + seq_buckets.append([param]) + last_bucket_size = cur_byte_size + elif last_bucket_size + cur_byte_size <= bucket_size: + seq_buckets[-1].append(param) + last_bucket_size += cur_byte_size else: - if cur_byte_size > bucket_size: - _logger.warning(f'find one parameter {param.shape} ({cur_byte_size} bytes) larger than bucket size {self._bucket_size}') - buckets[tp].insert(0, [param]) - elif dtype2size[tp] + cur_byte_size <= bucket_size: - dtype2size[tp] = dtype2size[tp] + cur_byte_size - buckets[tp][-1].append(param) - else: - dtype2size[tp] = cur_byte_size - buckets[tp].append([param]) - seq_buckets: List[List[torch.nn.Parameter]] = [] - for dtype in buckets: - if not self._async: - assert len(buckets[dtype]) == 1, \ - f"internal error: synchronized reducer only needs one bucket, but got {len(buckets[dtype])}" - for bucket in buckets[dtype]: - seq_buckets.append(bucket) + seq_buckets.append([param]) + last_bucket_size = cur_byte_size # step 2: build meta data for the offset of each bucket # the start of each bucket will be padded to the next multiple of `len(self.ranks)` @@ -524,7 +541,9 @@ def build_buckets(self): ) buckets.append(bucket) torch.cuda.empty_cache() + # make it in reverse order as the backward happens from tail to head + # it is not important but may be helpful for waiting cuda stream to finish self._buckets: List[Bucket] = list(reversed(buckets)) assert len(self._buckets) > 0, ( f"Find {len(self._params)} parameters in the reducer. " diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 0a406825..e5fafaba 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -13,12 +13,15 @@ import torch.distributed as dist from nnscaler.graph.parser.fx.parser import FxModuleParser + from nnscaler.runtime.device import DeviceGroup from nnscaler.runtime.adapter.reducer import Reducer from nnscaler.runtime.executor import Executor from nnscaler.runtime.gnorm import ParamsInfo -from nnscaler.flags import CompileFlag from nnscaler.runtime.utils import microbatches + +from nnscaler import __version__ as runtime_version +from nnscaler.flags import CompileFlag from nnscaler.utils import accum_mode if TYPE_CHECKING: @@ -759,11 +762,18 @@ class ParallelModule(CubeModule): EXTRA_STATE_KEY = 'CUBE_EXTRA_STATE' # the rank of the module, will be assigned in the generated subclasses rank: int + # the runtime version of the module when it is generated, will be assigned in the generated subclasses + runtime_version: str def __init__(self): if self.__class__ == ParallelModule: # not init via super().__init__() raise RuntimeError(f"ParallelModule should not be initialized directly. Please derive it first") + rv = getattr(self, 'runtime_version', None) + if rv != runtime_version: + _logger.warning( + f"Runtime version mismatch: {rv} vs {runtime_version}. " + ) super().__init__() # this is used to allow multiple sync_grad() calls self._sync_grad_required = False diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index b4c12e87..62f83886 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -10,7 +10,7 @@ from pathlib import Path import tempfile -from typing import Dict +from typing import Dict, TypedDict import pytest import torch from torch import nn @@ -105,7 +105,7 @@ def _train_ga(model, update_freq, data_size=DATA_SIZE): return results -def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=None, model_cls=MLP, pipeline_scheduler='1f1b'): +def gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=None, model_cls=MLP, async_reducer=False, use_zero=False, use_bucket=False, pipeline_scheduler='1f1b'): init_distributed() init_random() nstages = nstages or plan_ngpus @@ -120,6 +120,9 @@ def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=Non compute_config= ComputeConfig( plan_ngpus, runtime_ngpus, use_end2end=True, + use_zero=use_zero, + use_async_reducer=async_reducer, + reducer_bucket_cap_mb=1e-6 if use_bucket else 0, # 1e-6 to make sure one parameter per bucket pas_config=dict( pipeline_nmicros=nmicros, pipeline_nstages=nstages, @@ -142,6 +145,24 @@ def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=Non return train_result, infer_result, clone_to_cpu_recursively(infer_data) +def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=None, model_cls=MLP, pipeline_scheduler='1f1b'): + return gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages, nmicros, model_cls, False, False, False, pipeline_scheduler) + + +class CubeOptions(TypedDict): + use_zero: bool = False + use_async_reducer: bool = False + use_bucket: bool = False + + +def gpu_work_cube_tp_2_4(option: CubeOptions): + return gpu_worker_cube_general(4, 2, 'tp', + use_zero=option['use_zero'], + use_bucket=option['use_bucket'], + async_reducer=option['use_async_reducer'] + ) + + def merge_cube_result(cube_results): cube_result = [] for i in range(len(cube_results[0])): @@ -192,16 +213,37 @@ def test_end2end(): assert len(cube4_result) == 16 allclose(cube4_result, ga4_result) - cube2_results_non_pipeline = launch_torchrun(4, gpu_worker_cube, 4, 2, 'tp') # micro_batch_size = 4 - for _, v in cube2_results.items(): - # all losses should be scalar tensor - assert all(i.shape == () for i in v[1]) - cube2_result_non_pipeline = merge_cube_result({k: v[0] for k, v in cube2_results_non_pipeline.items()}) - assert len(cube2_result_non_pipeline) == 16 - allclose(cube2_result_non_pipeline, ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error - - infer_results = {k: v[1] for k, v in cube2_results_non_pipeline.items()} - infer_datas = {k: v[2] for k, v in cube2_results_non_pipeline.items()} + cube2_results_non_pipeline = {} + for use_async_reducer in [False, True]: + for use_zero in [False, True]: + for use_bucket in [False, True]: + cube2_results_non_pipeline[(use_zero, use_async_reducer, use_bucket)] = launch_torchrun( + 4, gpu_work_cube_tp_2_4, + CubeOptions(use_zero=use_zero, use_async_reducer=use_async_reducer, use_bucket=use_bucket) + ) + + for r in cube2_results_non_pipeline.values(): + for _, v in r.items(): + # all losses should be scalar tensor + assert all(i.shape == () for i in v[1]) + + cube2_result_non_pipeline = {kk: merge_cube_result({k: v[0] for k, v in vv.items()}) for kk, vv in cube2_results_non_pipeline.items()} + + for r in cube2_result_non_pipeline.values(): + assert len(r) == 16 + + for use_async_reducer in [False, True]: + for use_zero in [False, True]: + for use_bucket in [False, True]: + allclose(cube2_result_non_pipeline[(use_zero, use_async_reducer, use_bucket)], ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error + + for use_zero in [False, True]: + # when use_bucket, it should be the same for both async and non-async + assert_equal(cube2_result_non_pipeline[(use_zero, use_async_reducer, True)], + cube2_result_non_pipeline[(use_zero, not use_async_reducer, True)]) + + infer_results = {k: v[1] for k, v in cube2_results_non_pipeline[(False, False, False)].items()} + infer_datas = {k: v[2] for k, v in cube2_results_non_pipeline[(False, False, False)].items()} assert len(infer_results) == 4 assert len(infer_datas) == 4 infer_result = infer_results[0] diff --git a/tests/parallel_module/test_end2end_mix_precision.py b/tests/parallel_module/test_end2end_mix_precision.py new file mode 100644 index 00000000..02d54dbc --- /dev/null +++ b/tests/parallel_module/test_end2end_mix_precision.py @@ -0,0 +1,222 @@ +""" +PYTHONPATH=.:$PYTHONPATH torchrun \ + --nproc_per_node=4 \ + --nnodes=1 \ + examples/mlp/train.py --policy PASMegatronTP +""" + +from pathlib import Path +import tempfile +from typing import Dict, TypedDict +import pytest +import torch +from torch import nn +import torch.distributed + +import nnscaler +from nnscaler.runtime.gnorm import calcuate_gnorm +from nnscaler.runtime.utils import microbatches +from nnscaler.runtime.module import ParallelModule +from nnscaler.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts +from .common import assert_equal, clear_dir_on_rank0, init_distributed, PASMegatron, init_random +from ..launch_torchrun import clone_to_cpu_recursively, launch_torchrun + +from .test_checkpoint import End2EndMLP +from .test_end2end import allclose, merge_cube_result +from ..utils import init_parameter + + +DATA_SIZE = 16 + + +class MPModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.w0 = torch.nn.Parameter(torch.empty(8, 8, dtype=torch.float32)) + self.b0 = torch.nn.Parameter(torch.empty(8, dtype=torch.float32)) + + self.w1 = torch.nn.Parameter(torch.empty(8, 8, dtype=torch.float64)) + self.b1 = torch.nn.Parameter(torch.empty(8, dtype=torch.float64)) + + self.w2 = torch.nn.Parameter(torch.empty(8, 8, dtype=torch.float32)) + self.b2 = torch.nn.Parameter(torch.empty(8, dtype=torch.float64)) + self.loss_fn = nn.BCELoss() + + self.reset_parameters(self.w0, self.b0) + self.reset_parameters(self.w1, self.b1) + self.reset_parameters(self.w2, self.b2) + + def reset_parameters(self, w, b) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see + # https://github.com/pytorch/pytorch/issues/57109 + import math + torch.nn.init.kaiming_uniform_(w, a=math.sqrt(5)) + if b is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(w) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + torch.nn.init.uniform_(b, -bound, bound) + + def forward(self, data: dict): + x = data['data'] + x = self.w0 @ x + self.b0 + x = x.to(torch.float64) + x = self.w1 @ x + self.b1 + x = self.w2 @ x.float() + x = x.to(torch.float64) + self.b2 + x = torch.sigmoid(x.float()) + loss = self.loss_fn(x, data['target']) + return loss + + +def dummy_data(): + return { + 'data': torch.randn( + 8, 8, device=torch.cuda.current_device()), + 'target': torch.rand( + 8, 8, device=torch.cuda.current_device()) + } + + +def _train_cube(model: ParallelModule, mbs, num_replicas, rank): + init_random() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + data = [] + init_random() + for _ in range(DATA_SIZE): + data.append(dummy_data()) + data = [data[i] for i in range(rank, DATA_SIZE, num_replicas)] + chunks = [data[i:i + mbs] for i in range(0, len(data), mbs)] + results = [] + for _, x in enumerate(chunks): + model.train() + losses = model.train_step(x) + print(f'loss {_}: {losses}') + optimizer.step() + # gnorm = optimizer.clip_gnorm() + grads = {n: p.grad for n, p in model.named_parameters()} + model._add_extra_state(grads, '') + weights = {n: p.data for n, p in model.named_parameters()} + model._add_extra_state(weights, '') + # gnorm calculation doesn't support float64, so let's skip it + results.append(clone_to_cpu_recursively([grads, weights, torch.tensor(0.0)])) + optimizer.zero_grad() + return results + + +def _train_ga(model, update_freq, data_size=DATA_SIZE): + init_random() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + data = [] + init_random() + for _ in range(data_size): + data.append(dummy_data()) + results = [] + for i, x in enumerate(data): + model.train() + loss = model(x) + print(f'loss {i}: {loss}') + loss.backward() + if i % update_freq == update_freq - 1: + optimizer.step() + grads = {n: p.grad for n, p in model.named_parameters()} + weights = {n: p.data for n, p in model.named_parameters()} + # gnorm calculation doesn't support float64, so let's skip it + results.append(clone_to_cpu_recursively([grads, weights, torch.tensor(0.0)])) + optimizer.zero_grad() + return results + + +def gpu_worker_cube(use_zero=False, async_reducer=False, use_bucket=False): + init_distributed() + init_random() + plan_ngpus = 2 + runtime_ngpus = 4 + nmicros = plan_ngpus + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end_mp') as tempdir: + init_random() + model = MPModule() + model = parallelize( + model, + {'data': dummy_data()}, + pas_policy='tp', + compute_config= ComputeConfig( + plan_ngpus, runtime_ngpus, + use_end2end=True, + use_zero=use_zero, + use_async_reducer=async_reducer, + reducer_bucket_cap_mb=1e-6 if use_bucket else 0, # 1e-6 to make sure one parameter per bucket + ), + gen_savedir=tempdir + ) + # (intra + inter) * (float32 + float64) + assert len(model.reducers) == 4 + model.cuda() + train_result = _train_cube(model, nmicros, runtime_ngpus // plan_ngpus, torch.distributed.get_rank() // plan_ngpus) + + with torch.inference_mode(): + model.eval() + init_random() + infer_data = [] + for _ in range(nmicros): + infer_data.append(dummy_data()) + infer_result = clone_to_cpu_recursively(model.infer_step(infer_data)) + + return train_result, infer_result, clone_to_cpu_recursively(infer_data) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_mixed_precision(): + torch.cuda.set_device(0) + torch.set_default_device(f'cuda:0') + init_random() + model = MPModule() + torch.save(model.state_dict(), 'model.pth') + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 4 + + cube2_results_non_pipeline = {} + for use_async_reducer in [False, True]: + for use_zero in [False, True]: + for use_bucket in [False, True]: + cube2_results_non_pipeline[(use_zero, use_async_reducer, use_bucket)] = launch_torchrun( + 4, gpu_worker_cube, + use_zero, use_async_reducer, use_bucket + ) + + for r in cube2_results_non_pipeline.values(): + for _, v in r.items(): + # all losses should be scalar tensor + assert all(i.shape == () for i in v[1]) + + cube2_result_non_pipeline = {kk: merge_cube_result({k: v[0] for k, v in vv.items()}) for kk, vv in cube2_results_non_pipeline.items()} + + for r in cube2_result_non_pipeline.values(): + assert len(r) == 4 + + for use_async_reducer in [False, True]: + for use_zero in [False, True]: + for use_bucket in [False, True]: + allclose(cube2_result_non_pipeline[(use_zero, use_async_reducer, use_bucket)], ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error + + for use_zero in [False, True]: + # when use_bucket, it should be the same for both async and non-async + assert_equal(cube2_result_non_pipeline[(use_zero, use_async_reducer, True)], + cube2_result_non_pipeline[(use_zero, not use_async_reducer, True)]) + + infer_results = {k: v[1] for k, v in cube2_results_non_pipeline[(False, False, False)].items()} + infer_datas = {k: v[2] for k, v in cube2_results_non_pipeline[(False, False, False)].items()} + assert len(infer_results) == 4 + assert len(infer_datas) == 4 + infer_result = infer_results[0] + infer_data = infer_datas[0] + for k in infer_results: + assert_equal(infer_results[k], infer_result) + for k in infer_datas: + assert_equal(infer_datas[k], infer_data) + + for i, data in enumerate(infer_data): + with torch.inference_mode(): + model.eval() + loss = model({key: v.cuda() for key, v in data.items()}) + assert torch.allclose(loss.cpu(), infer_result[i].cpu(), atol=1e-6, rtol=1e-6) diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 9344c04a..4f191e81 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -6,6 +6,7 @@ import torch +import nnscaler from nnscaler.parallel import _load_parallel_module_class, parallelize, ComputeConfig from ..launch_torchrun import launch_torchrun @@ -66,6 +67,7 @@ def forward(self, x): def test_empty_weights(model_class, tp): # MyModule2 uses CubeLinear, so tp works # MyModule uses torch.nn.Linear, so tp doesn't work + instance_name = f'm_{tp}' with tempfile.TemporaryDirectory() as tempdir: parallelize( model_class, @@ -75,9 +77,10 @@ def test_empty_weights(model_class, tp): gen_savedir=tempdir, reuse='match', load_module=False, + instance_name=instance_name, ) for i in range(4): - module_class = _load_parallel_module_class(model_class, gen_savedir=tempdir, rank=i) + module_class = _load_parallel_module_class(model_class, gen_savedir=tempdir, instance_name=instance_name, rank=i) m = new_empty(module_class) assert m.rank == i for p in m.parameters(): @@ -93,3 +96,96 @@ def test_empty_weights(model_class, tp): for b in r.buckets: assert b._contiguous_grads.device == torch.device('meta') assert b._contiguous_params.device == torch.device('meta') + + +class MyModule3(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = CubeLinear(8, 8, bias=True) + + def forward(self, x): + x = self.linear(x) + return torch.sum(x) + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('async_op', [True, False]) +def test_async_reducer(async_op): + instance_name = f'm_{async_op}' + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + MyModule3, + {'x': torch.randn(8, 8)}, + 'dp', + ComputeConfig(1, 2, use_zero=True, zero_ngroups=2, use_end2end=True, + use_async_reducer=async_op, + # 1e-6to make sure one parameter per bucket + reducer_bucket_cap_mb=1e-6 if async_op else 0 + ), + gen_savedir=tempdir, + reuse='match', + load_module=False, + instance_name=instance_name, + ) + for i in range(2): + module_class = _load_parallel_module_class(MyModule3, gen_savedir=tempdir, instance_name=instance_name, rank=i) + m = new_empty(module_class, device='cpu') + assert m.rank == i + assert m.runtime_version == nnscaler.__version__ + assert len(m.reducers) == 1 + assert m.reducers[0]._async == async_op + if async_op: + assert len(m.reducers[0].buckets) == 2 + else: + assert len(m.reducers[0].buckets) == 1 + + +class MyModule4(torch.nn.Module): + def __init__(self): + super().__init__() + self.w0 = torch.nn.Parameter(torch.randn(8, 8, dtype=torch.float32)) + self.b0 = torch.nn.Parameter(torch.randn(8, dtype=torch.float32)) + + self.w1 = torch.nn.Parameter(torch.randn(8, 8, dtype=torch.float16)) + self.b1 = torch.nn.Parameter(torch.randn(8, dtype=torch.float16)) + + self.w2 = torch.nn.Parameter(torch.randn(8, 8, dtype=torch.float32)) + self.b2 = torch.nn.Parameter(torch.randn(8, dtype=torch.float16)) + + def forward(self, x: torch.Tensor): + x = self.w0 @ x + self.b0 + x = x.half() + x = self.w1 @ x + self.b1 + x = self.w2 @ x.float() + x = x.half() + self.b2 + return torch.sum(x).float() + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('async_op', [True, False]) +def test_reducer_mixed_precision(async_op): + instance_name = f'm_{async_op}' + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + MyModule4, + {'x': torch.randn(8, 8)}, + 'tp', + ComputeConfig(2, 4, use_end2end=True, + use_async_reducer=async_op, + # a big number to make sure all parameters in one bucket + reducer_bucket_cap_mb=100 + ), + gen_savedir=tempdir, + reuse='match', + load_module=False, + instance_name=instance_name, + ) + for i in range(4): + module_class = _load_parallel_module_class(MyModule4, gen_savedir=tempdir, instance_name=instance_name, rank=i) + m = new_empty(module_class, device='cpu') + assert m.rank == i + assert m.runtime_version == nnscaler.__version__ + # (intra-group + inter-group) * (float16 + float32) + # totally 4 reducers + assert len(m.reducers) == 4 + assert m.reducers[0]._async == async_op diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py index 811b3510..3f080af0 100644 --- a/tests/runtime/test_reducer.py +++ b/tests/runtime/test_reducer.py @@ -14,8 +14,9 @@ from nnscaler.graph import IRGraph from nnscaler.ir.operator import IRFwOperation from nnscaler.flags import CompileFlag +from nnscaler.runtime.adapter.reducer import Reducer from ..launch_torchrun import torchrun -from ..utils import init_parameter, assert_parity +from ..utils import init_parameter, assert_parity, mock_reducer_env class MLP(torch.nn.Module): @@ -135,3 +136,22 @@ def reducer_test(): assert_parity(baseline, partial(reducer, False, False)) test_reducer_2gpu = partial(torchrun, 2, reducer_test) + + +@mock_reducer_env(0, 2) +def test_reducer_build(): + reducer = Reducer([0, 1], max_bucket_size_bytes=16) # 16 bytes means 4 float32 + reducer.add_param(torch.nn.Parameter(torch.randn(1, 2))) # small at first + reducer.add_param(torch.nn.Parameter(torch.randn(1, 10))) # bigger than max_bucket_size_bytes + reducer.add_param(torch.nn.Parameter(torch.randn(1, 3))) # small again + reducer.add_param(torch.nn.Parameter(torch.randn(1, 3))) # small again + reducer.add_param(torch.nn.Parameter(torch.randn(1, 1))) # small again + reducer.add_param(torch.nn.Parameter(torch.randn(1, 1))) # small again + reducer.build_buckets() + assert len(reducer.buckets) == 5 + buckets = list(reversed(reducer.buckets)) + assert buckets[0].numel == 2 + assert buckets[1].numel == 10 + assert buckets[2].numel == 3 + assert buckets[3].numel == 4 + assert buckets[4].numel == 1 diff --git a/tests/utils.py b/tests/utils.py index 0e3eaad0..ed1ef2ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -258,7 +258,7 @@ def mock_dist(rank, world_size): @contextmanager -def mock_cube_env(cube_module_cls: Type[ParallelModule], compute_config): +def mock_cube_env(rank, world_size): old_device_group = nnscaler.runtime.device._instance old_dev_mode = CompileFlag.dev_mode used_cuda_fns = ['set_device', 'current_device', 'default_stream'] @@ -276,8 +276,8 @@ def mock_cube_env(cube_module_cls: Type[ParallelModule], compute_config): CompileFlag.dev_mode = False for fname, fn in old_cuda_fns.items(): setattr(torch.cuda, fname, lambda *args, **kwargs: None) - os.environ['RANK'] = os.environ['LOCAL_RANK'] = str(cube_module_cls.rank) - os.environ['WORLD_SIZE'] = os.environ['LOCAL_WORLD_SIZE'] = str(compute_config.runtime_ngpus) + os.environ['RANK'] = os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = os.environ['LOCAL_WORLD_SIZE'] = str(world_size) os.environ['GROUP_RANK'] = '0' os.environ['TORCHELASTIC_RUN_ID'] = '0' # fake torchrun env yield @@ -293,6 +293,11 @@ def mock_cube_env(cube_module_cls: Type[ParallelModule], compute_config): nnscaler.runtime.device._instance = old_device_group +@contextmanager +def mock_reducer_env(rank, runtime_ngpus, device='cpu'): + with replace_all_device_with(device, True), mock_cube_env(rank, runtime_ngpus), mock_dist(rank, runtime_ngpus): + yield + def new_empty(cube_module_cls: Type[ParallelModule], device='meta', init_params=False): """ Create a new instance with empty weights. @@ -301,7 +306,7 @@ def new_empty(cube_module_cls: Type[ParallelModule], device='meta', init_params= """ module_file = Path(sys.modules[cube_module_cls.__module__].__file__) compute_config = ComputeConfig.safe_load_from_file(module_file.with_name(f"{cube_module_cls.COMPUTE_CONFIG_FILE}")) - with replace_all_device_with(device, True), mock_cube_env(cube_module_cls, compute_config), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): + with replace_all_device_with(device, True), mock_cube_env(cube_module_cls.rank, compute_config.runtime_ngpus), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): return cube_module_cls(init_params=init_params) From 58008fada2e03bb65665e0d011e7373005dae8ac Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Wed, 9 Oct 2024 04:15:40 +0000 Subject: [PATCH 1739/1892] Merged PR 2277: Update pyproject meta fields --- pyproject.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 398b97f8..212f82ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,18 +6,18 @@ build-backend = "setuptools.build_meta" dynamic = ["version", "dependencies"] name = "nnscaler" -description = "Parallelize DNN Traning from A Systematic Way" +description = "Parallelize DNN Training via A Systematic Way" readme = "README.md" requires-python = ">=3.8" -# TODO: license authors = [ {name = "nnScaler Team", email = "nnscaler@service.microsoft.com"} ] -# TODO: keywords -# TODO: classifiers +classifiers = [ + "License :: OSI Approved :: MIT License", +] [project.urls] -Homepage = "https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube" # TODO: github +Homepage = "https://github.com/microsoft/nnscaler" [tool.setuptools] dynamic.version.attr = "nnscaler.version.__version__" @@ -25,6 +25,6 @@ dynamic.dependencies.file = "requirements.txt" # NOTE: # the following part only affects wheel, not sdist -# since our current plan is to use cppimport, sdist is not needed +# since we are using cppimport, sdist is not needed packages.find.include = ["nnscaler*"] package-data = { nnscaler = ["resources/**", "autodist/*.h", "autodist/*.cpp"] } From 4f02e2cc12b71d326e2958eb7da9ba73eeb0dd93 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 11 Oct 2024 08:32:54 +0000 Subject: [PATCH 1740/1892] Merged PR 2280: add options to use reduce scatter when zero is on add options to use reduce scatter when zero is on --- docs/source/parallel_module.md | 6 ++ nnscaler/codegen/module/module.py | 14 +++- nnscaler/flags.py | 6 ++ nnscaler/parallel.py | 13 ++++ nnscaler/runtime/adapter/reducer.py | 57 +++++++++------- tests/parallel_module/test_end2end.py | 95 +++++++++++++++++++-------- tests/utils.py | 1 + 7 files changed, 138 insertions(+), 54 deletions(-) diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 301d59cb..ae24fd96 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -187,6 +187,7 @@ class ComputeConfig: use_zero: bool = False zero_ngroups: int = 1 + zero_use_reduce_scatter: bool = False inference_only : bool = False use_end2end: bool = False @@ -234,6 +235,11 @@ We can categorize the fields into 4 categories: 3. Code generation feature configuration - `use_zero`: whether to use zero. If it is true, the generated code will use zero1 to do distributed training. - `zero_ngroups`: the number of groups to be used in zero. + - `zero_use_reduce_scatter`: whether to use reduce scatter in zero. If it is true, the gradients will be reduced by reduce scatter in zero. + + Please note + - Reduce scatter is only available when `zero_ngroups` is 1. when `zero_ngroups` > 1, you should set it to `False`, or an error will be raised. + - In some cases, it can introduce parity issue. So use it with caution. - `inference_only`: whether to generate code for inference only. If it is true, the generated code can not be used to train the model. - `use_end2end`: whether to use end2end training. For the requirement of end2end, see the description above. - `use_async_reducer`: whether to use async reducer. diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index fad8ff28..e2c0b940 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -486,14 +486,17 @@ def forward(self, x, y=None, z=None): if as_parallel_module: cb.insert_body(f'rank = {device}') # save rank in class level - # async_op and max_bucket_size_bytes parameters are for testing purpose + # async_op, max_bucket_size_bytes and zero_use_reduce_scatter + # parameters are for testing purpose # and will not expose to user with FunctionBlock(func_name='__init__', args=[ 'self', 'init_params=True', + '*', f'async_op={CompileFlag.async_reducer}', - f'max_bucket_size_bytes={CompileFlag.max_reducer_bucket}' + f'max_bucket_size_bytes={CompileFlag.max_reducer_bucket}', + f'zero_use_reduce_scatter={CompileFlag.zero_use_reduce_scatter}', ] ) as ib: ib.insert_body(self.model_init_statements) @@ -713,6 +716,8 @@ def init_reducer(self, # `max_bucket_size_bytes` and `async_op` are passed as arguments max_nbytes = CompileFlag.max_reducer_bucket if not as_parallel_module else 'max_bucket_size_bytes' async_op = CompileFlag.async_reducer if not as_parallel_module else 'async_op' + zero_use_reduce_scatter = CompileFlag.zero_use_reduce_scatter if not as_parallel_module else 'zero_use_reduce_scatter' + zero = CompileFlag.use_zero zero_ngroups = CompileFlag.zero_ngroups reduce_op = f"'{CompileFlag.reducer_op}'" @@ -721,6 +726,7 @@ def init_reducer(self, "{reducer} = nnscaler.runtime.adapter.Reducer(" "ranks={ranks}, reduce_op={reduce_op}, " "async_op={async_op}, zero={zero}, max_bucket_size_bytes={max_nbytes}, " + "zero_use_reduce_scatter={zero_use_reduce_scatter}, " "zero_ngroups={zero_ngroups})" ) reducer_add = 'self.add_reducer({reducer})' @@ -732,7 +738,9 @@ def init_reducer(self, ranks = list(sorted(node.device)) init_code = reducer_init.format( reducer=reducer_name, ranks=ranks, reduce_op=reduce_op, - async_op=async_op, zero=zero, max_nbytes=max_nbytes, zero_ngroups=zero_ngroups) + async_op=async_op, zero=zero, max_nbytes=max_nbytes, + zero_ngroups=zero_ngroups, zero_use_reduce_scatter=zero_use_reduce_scatter + ) self.model_init_statements.append(init_code) # sort weights by first used time (which is gradient all-reduce time in reverse order) # so that weights with similar gradient all-reduce time are bucketed together diff --git a/nnscaler/flags.py b/nnscaler/flags.py index cde7965a..55ad0418 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -57,6 +57,12 @@ class CompileFlag: # it helps reduce communication cost of allgather weights in ZeRO, but increase the weights' # optimization states on each GPU. zero_ngroups = _to_int('ZERO_NUM_GROUPS', default=1) + # whether to use reduce scatter for zero (default False). + # By default we use `allreduce` for zero, which is due to + # 1) `reduce_scatter` will make some parameters have stale gradient after synchronization, + # hence break the consistency of `.data` and `.grad` of parameters. Need to be careful when using optimizer. + # 2) `reduce_scatter`` doesn't significantly improve performance comparing with `allreduce`. + zero_use_reduce_scatter = _to_bool('ZERO_USE_REDUCE_SCATTER') # use automate mixture precision training, where weights, gradients # and optimizer status are kept in its original data type (can be float32), diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 86a9c0a8..47bb5141 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -78,6 +78,11 @@ class ComputeConfig: use_zero: bool = False zero_ngroups: int = 1 + # whether to use reduce scatter for zero + # Please note + # 1. this only works when `use_zero` is True and `zero_ngroups` is 1. + # 2. In some cases, it can introduce parity issue. So use it with caution. + zero_use_reduce_scatter: bool = False # whether the generated code is for inference only inference_only: bool = False @@ -87,6 +92,7 @@ class ComputeConfig: # 2. the first return value of `module.forward` must be the loss # which must be a scalar tensor use_end2end: bool = False + # whether to use async reducer # if True, the gradient all-reduce will be async, # This only works when the `use_end2end` is `True` for now. @@ -154,6 +160,12 @@ def __post_init__(self): if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") + # TODO: Please note in current implementation of Bucket, + # zero_use_reduce_scatter still works when zero_ngroups > 1 in sync mode + # Let's hide this feature for now for consistency. + if self.use_zero and self.zero_use_reduce_scatter and self.zero_ngroups != 1: + raise ValueError("zero_use_reduce_scatter is only supported when zero_ngroups is 1.") + def apply_pipeline_scheduler( self, graph: IRGraph, @@ -314,6 +326,7 @@ def _compile_flags(compute_config: ComputeConfig): async_comm=False, use_zero=compute_config.use_zero, zero_ngroups=compute_config.zero_ngroups, + zero_use_reduce_scatter=compute_config.zero_use_reduce_scatter, trace_strategy=compute_config.trace_strategy, ) diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index 1206a47e..6c177abb 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -32,20 +32,14 @@ def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: class Bucket: - - # config: whether to use reduce scatter for zero (default False). - # By default we use `allreduce` for zero, which is due to - # 1) `reduce_scatter` will make some parameters have stale gradient after synchronization, - # hence break the consistency of `.data` and `.grad` of parameters. Need to be careful when using optimizer. - # 2) `reduce_scatter`` doesn't significantly improve performance comparing with `allreduce`. - use_reduce_scatter_for_zero: bool = False - def __init__(self, params: List[torch.nn.Parameter], param_buffer: torch.Tensor, grad_buffer: torch.Tensor, reduce_op: torch.distributed.ReduceOp, - group, async_op: bool, zero: bool, + group: torch.distributed.ProcessGroup, async_op: bool, zero: bool, zero_subgroup: torch.distributed.ProcessGroup = None, - zero_crossgroup: torch.distributed.ProcessGroup = None): + zero_crossgroup: torch.distributed.ProcessGroup = None, + zero_use_reduce_scatter: bool = False, + ): """ Create a communication unit for parameter allreduce. @@ -53,15 +47,16 @@ def __init__(self, params: List[torch.nn.Parameter], The parameters are assumed to participate in backward and generate gradient. Args: - params List[torch.nn.Parameter]: the parameters - param_buffer torch.Tensor: Paramter contiguous buffer - grad_buffer torch.Tensor: gradient contiguous buffer - reduce_op torch.distributed.ReduceOp: the reduce op used by collectives - group: communication group - async_op bool: whether to use asynchronous operation - zero bool: whether to use zero optimization on gradients - zero_subgroup: the subgroup for zero optimization the current rank belongs to - zero_crossgroup: the communication group for cross zero group allreduce when reduce scatter is enabled + params (List[torch.nn.Parameter]): the parameters + param_buffer (torch.Tensor): Paramter contiguous buffer + grad_buffer (torch.Tensor): gradient contiguous buffer + reduce_op (torch.distributed.ReduceOp): the reduce op used by collectives + group (torch.distributed.ProcessGroup): communication group + async_op (bool): whether to use asynchronous operation + zero (bool): whether to use zero optimization on gradients + zero_subgroup (torch.distributed.ProcessGroup): the subgroup for zero optimization the current rank belongs to + zero_crossgroup (torch.distributed.ProcessGroup): the communication group for cross zero group allreduce when reduce scatter is enabled + zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization """ self._params: List[torch.nn.Parameter] = params @@ -75,6 +70,7 @@ def __init__(self, params: List[torch.nn.Parameter], self._async: bool = async_op self._zero: bool = zero + self._zero_use_reduce_scatter = zero_use_reduce_scatter self._contiguous_params = param_buffer self._contiguous_grads = grad_buffer assert grad_buffer.size() == param_buffer.size() @@ -190,15 +186,17 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): # apply pre hooks self._apply_pre_hooks() # communication - if self._zero and Bucket.use_reduce_scatter_for_zero: + if self._zero and self._zero_use_reduce_scatter: if self._zgroup_sz == self._wsz: rank = torch.distributed.get_rank(group=self._group) shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) + # inplace reduce scatter is supported + # see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#c.ncclReduceScatter self._async_handle = torch.distributed.reduce_scatter( shards[rank], shards, op=self._reduce_op, group=self._group, async_op=True) else: - assert False, "reducescatter is not supported in async mode, " \ + assert False, "group zero + reducescatter is not supported in async mode, " \ "because the two steps (allreduce, reducescatter) use " \ "two communication groups, which may induce deadlock." self._group_reduce_scatter() @@ -237,7 +235,7 @@ def sync_grads(self): # apply pre-hooks self._apply_pre_hooks() # synchrnoize gradients - if self._zero and Bucket.use_reduce_scatter_for_zero: + if self._zero and self._zero_use_reduce_scatter: self._group_reduce_scatter() else: torch.distributed.all_reduce( @@ -335,7 +333,9 @@ class Reducer: def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None, reduce_op: str = 'sum', async_op: bool = False, - zero: bool = False, zero_ngroups: int = 1): + zero: bool = False, zero_ngroups: int = 1, + zero_use_reduce_scatter: bool = False, + ): """ Create a reducer applied on a set of weights for weight reduction @@ -351,6 +351,7 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None async_op (bool): whether to overlap with backward computation (default False) zero (bool): whether to apply ZeRO optimization on gradients zero_ngroups (int): number of ZeRO subgroups in the original ZeRO group + zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization """ self._params: List[torch.nn.Parameter] = list() self._param_ids: Set[int] = set() @@ -367,6 +368,7 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None self._buckets: List[Bucket] = list() self._async: bool = async_op self._zero: bool = zero + self._zero_use_reduce_scatter = zero_use_reduce_scatter # contiguous parameter buffer and gradient buffer self._contiguous_params: torch.Tensor = None self._contiguous_grads: torch.Tensor = None @@ -379,8 +381,14 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None # the ranks will be divided into [0, 1, 2, 3] and [4, 5, 6, 7]. # If the ranks are [0, 2, 4, 6], zero_ngroups=2, then the ranks # will be divided into [0, 2] and [4, 6]. - if self._zero and Bucket.use_reduce_scatter_for_zero: + if self._zero and self._zero_use_reduce_scatter: _logger.info(f"Using reduce scatter for ZeRO optimization") + # TODO: In current implementation of Bucket, + # zero_use_reduce_scatter works when zero_ngroups > 1 in sync mode + # We can enable it in sync mode when it is proved to be useful. + if zero_ngroups > 1: + raise ValueError("reduce scatter is not supported when zero_ngroups > 1") + if zero_ngroups > 1: assert self._zero, f"USE_ZERO must be set when ZERO_NUM_GROUPS is larger than 1" assert len(ranks) % zero_ngroups == 0, f"length of ranks {ranks} must be divisible by zero factor {zero_ngroups}" @@ -538,6 +546,7 @@ def build_buckets(self): self._zero, self._zero_subgroup, self._zero_crossgroup, + self._zero_use_reduce_scatter, ) buckets.append(bucket) torch.cuda.empty_cache() diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index 62f83886..13fb20aa 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -105,7 +105,7 @@ def _train_ga(model, update_freq, data_size=DATA_SIZE): return results -def gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=None, model_cls=MLP, async_reducer=False, use_zero=False, use_bucket=False, pipeline_scheduler='1f1b'): +def gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=None, model_cls=MLP, async_reducer=False, use_zero=False, use_bucket=False, zero_use_reduce_scatter=False, pipeline_scheduler='1f1b'): init_distributed() init_random() nstages = nstages or plan_ngpus @@ -121,6 +121,7 @@ def gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages=None, nmi plan_ngpus, runtime_ngpus, use_end2end=True, use_zero=use_zero, + zero_use_reduce_scatter=zero_use_reduce_scatter, use_async_reducer=async_reducer, reducer_bucket_cap_mb=1e-6 if use_bucket else 0, # 1e-6 to make sure one parameter per bucket pas_config=dict( @@ -146,47 +147,57 @@ def gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages=None, nmi def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=None, model_cls=MLP, pipeline_scheduler='1f1b'): - return gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages, nmicros, model_cls, False, False, False, pipeline_scheduler) + return gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages, nmicros, model_cls, False, False, False, False, pipeline_scheduler) class CubeOptions(TypedDict): use_zero: bool = False use_async_reducer: bool = False use_bucket: bool = False + zero_use_reduce_scatter: bool = False def gpu_work_cube_tp_2_4(option: CubeOptions): return gpu_worker_cube_general(4, 2, 'tp', use_zero=option['use_zero'], use_bucket=option['use_bucket'], - async_reducer=option['use_async_reducer'] + async_reducer=option['use_async_reducer'], + zero_use_reduce_scatter=option['zero_use_reduce_scatter'], ) -def merge_cube_result(cube_results): +def merge_cube_result(cube_results, zero_use_reduce_scatter=False): cube_result = [] for i in range(len(cube_results[0])): for rank in cube_results: assert torch.equal(cube_results[rank][i][2], cube_results[0][i][2]) - cube_result.append([ - merge_state_dicts([cube_results[rank][i][0] for rank in cube_results])[0], - merge_state_dicts([cube_results[rank][i][1] for rank in cube_results])[0], - cube_results[0][i][2] - ]) + if not zero_use_reduce_scatter: + cube_result.append([ + merge_state_dicts([cube_results[rank][i][0] for rank in cube_results])[0], + merge_state_dicts([cube_results[rank][i][1] for rank in cube_results])[0], + cube_results[0][i][2] + ]) + else: + # grads are not merged for zero_use_reduce_scatter + # as they are different in different ranks + cube_result.append([ + merge_state_dicts([cube_results[rank][i][1] for rank in cube_results])[0], + cube_results[0][i][2] + ]) return cube_result def allclose(a, b, atol=1e-6, rtol=1e-6): assert len(a) == len(b) for step in range(len(a)): - assert len(a[step][0]) == len(b[step][0]) - assert len(a[step][1]) == len(b[step][1]) - for k in a[step][0].keys(): # grads - assert torch.allclose(a[step][0][k].cpu(), b[step][0][k].cpu(), atol=atol, rtol=rtol) - for k in a[step][1].keys(): # weights - assert torch.allclose(a[step][1][k].cpu(), b[step][1][k].cpu(), atol=atol, rtol=rtol) - # gnorm - assert torch.allclose(a[step][2].cpu(), b[step][2].cpu(), atol=atol, rtol=rtol) + # grads and weights (grads can be absent in case of zero_use_reduce_scatter) + assert len(a[step]) == len(b[step]) + for i in range(len(a[step]) - 1): + assert len(a[step][i]) == len(b[step][i]) + for k in a[step][i].keys(): + assert torch.allclose(a[step][i][k].cpu(), b[step][i][k].cpu(), atol=atol, rtol=rtol) + # gnorm is last element + assert torch.allclose(a[step][-1].cpu(), b[step][-1].cpu(), atol=atol, rtol=rtol) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @@ -196,6 +207,10 @@ def test_end2end(): model = MLP() ga4_result = _train_ga(model, 4) # micro_batch_size = 4 assert len(ga4_result) == 16 + # will be used for comparision when zero_use_reduce_scatter is True + ga4_result_without_grads = [] + for i in range(len(ga4_result)): + ga4_result_without_grads.append([ga4_result[i][1], ga4_result[i][2]]) cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid') # micro_batch_size = 4 for _, v in cube2_results.items(): @@ -217,17 +232,37 @@ def test_end2end(): for use_async_reducer in [False, True]: for use_zero in [False, True]: for use_bucket in [False, True]: - cube2_results_non_pipeline[(use_zero, use_async_reducer, use_bucket)] = launch_torchrun( + zero_use_reduce_scatter = False + cube2_results_non_pipeline[(use_zero, use_async_reducer, use_bucket, zero_use_reduce_scatter)] = launch_torchrun( 4, gpu_work_cube_tp_2_4, - CubeOptions(use_zero=use_zero, use_async_reducer=use_async_reducer, use_bucket=use_bucket) + CubeOptions(use_zero=use_zero, + use_async_reducer=use_async_reducer, + use_bucket=use_bucket, + zero_use_reduce_scatter=zero_use_reduce_scatter + ) ) + if not use_zero: + cube2_results_non_pipeline[(use_zero, use_async_reducer, use_bucket, not zero_use_reduce_scatter)] = \ + cube2_results_non_pipeline[(use_zero, use_async_reducer, use_bucket, zero_use_reduce_scatter)] + else: + cube2_results_non_pipeline[(use_zero, use_async_reducer, use_bucket, not zero_use_reduce_scatter)] = launch_torchrun( + 4, gpu_work_cube_tp_2_4, + CubeOptions(use_zero=use_zero, + use_async_reducer=use_async_reducer, + use_bucket=use_bucket, + zero_use_reduce_scatter=not zero_use_reduce_scatter + ) + ) for r in cube2_results_non_pipeline.values(): for _, v in r.items(): # all losses should be scalar tensor assert all(i.shape == () for i in v[1]) - cube2_result_non_pipeline = {kk: merge_cube_result({k: v[0] for k, v in vv.items()}) for kk, vv in cube2_results_non_pipeline.items()} + cube2_result_non_pipeline = { + kk: merge_cube_result({k: v[0] for k, v in vv.items()}, zero_use_reduce_scatter=kk[3]) + for kk, vv in cube2_results_non_pipeline.items() + } for r in cube2_result_non_pipeline.values(): assert len(r) == 16 @@ -235,15 +270,21 @@ def test_end2end(): for use_async_reducer in [False, True]: for use_zero in [False, True]: for use_bucket in [False, True]: - allclose(cube2_result_non_pipeline[(use_zero, use_async_reducer, use_bucket)], ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error + for zero_use_reduce_scatter in [False, True]: + allclose(cube2_result_non_pipeline[(use_zero, use_async_reducer, use_bucket, zero_use_reduce_scatter)], + ga4_result if not zero_use_reduce_scatter else ga4_result_without_grads, + atol=1e-5, rtol=1e-5) # looks tp introduces more error for use_zero in [False, True]: - # when use_bucket, it should be the same for both async and non-async - assert_equal(cube2_result_non_pipeline[(use_zero, use_async_reducer, True)], - cube2_result_non_pipeline[(use_zero, not use_async_reducer, True)]) - - infer_results = {k: v[1] for k, v in cube2_results_non_pipeline[(False, False, False)].items()} - infer_datas = {k: v[2] for k, v in cube2_results_non_pipeline[(False, False, False)].items()} + for zero_use_reduce_scatter in [False, True]: + # when use_bucket, it should be the same for both async and non-async + use_async_reducer = True + use_bucket = True + assert_equal(cube2_result_non_pipeline[(use_zero, use_async_reducer, use_bucket, zero_use_reduce_scatter)], + cube2_result_non_pipeline[(use_zero, not use_async_reducer, use_bucket, zero_use_reduce_scatter)]) + + infer_results = {k: v[1] for k, v in cube2_results_non_pipeline[(False, False, False, False)].items()} + infer_datas = {k: v[2] for k, v in cube2_results_non_pipeline[(False, False, False, False)].items()} assert len(infer_results) == 4 assert len(infer_datas) == 4 infer_result = infer_results[0] diff --git a/tests/utils.py b/tests/utils.py index ed1ef2ee..9cf5213d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -298,6 +298,7 @@ def mock_reducer_env(rank, runtime_ngpus, device='cpu'): with replace_all_device_with(device, True), mock_cube_env(rank, runtime_ngpus), mock_dist(rank, runtime_ngpus): yield + def new_empty(cube_module_cls: Type[ParallelModule], device='meta', init_params=False): """ Create a new instance with empty weights. From 94e76e7054ec387346dc775623eb0b7a0cfa558c Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 14 Oct 2024 07:27:28 +0000 Subject: [PATCH 1741/1892] Merged PR 2283: Support constant pad --- nnscaler/graph/function/function.py | 21 ++++++++++++++++++++- tests/graph/function/test_functions.py | 13 ++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index c7d5e28d..e0d83e49 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1566,7 +1566,26 @@ def Pad(input, pad, mode='constant', value=0.0, signature = None): """ torch.nn.functional.pad(input, pad, mode='constant', value=0.0) """ - return IRPad(signature, [input], 'pad', pad=pad, mode=mode, value=value) + if mode != 'constant': + raise ValueError(f"Currently only support mode='constant' but got {mode}") + + pad_vals, _ = extract_variadic(pad) + if len(pad_vals) % 2 != 0: + raise ValueError(f"pad should be a list of even length but got {pad}") + + pad_vals = [(pad_l, pad_r) for pad_l, pad_r in zip(pad_vals[::2], pad_vals[1::2])] + pad_vals.reverse() + pad_dim_num = len(pad_vals) + + gener = iter(string.ascii_lowercase) + prefix_anno = ShapeAnno.create_shape_str(input.shape[:-pad_dim_num], iterator=gener) + in_pad_dim_anno = [str(dim) for dim in input.shape[-pad_dim_num:]] + out_pad_dim_anno = [str(dim + pad_l + pad_r) for dim, (pad_l, pad_r) in zip(input.shape[-pad_dim_num:], pad_vals)] + in_anno = prefix_anno + in_pad_dim_anno + out_anno = prefix_anno + out_pad_dim_anno + anno = OpAnno.create_op_str([in_anno], [out_anno]) + + return IRDimops(Pad, 'pad', signature, [anno], [input], pad=pad, mode=mode, value=value) # def Conv2D(signature, inputs): diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index be0a1f98..e211c89b 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -629,7 +629,6 @@ def test_Conv1D(): assert op._annos_candidates[0] == 'n iC^ 4, oC iC^ 1, oC -> n oC 4', "Annotation mismatch." - def test_Arange(): op = F.Arange(10) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 10^' and op.kwargs['dtype'] == torch.int64 @@ -859,3 +858,15 @@ def test_ConvTranspose1D(): op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), bias=IRObject(value=3)) assert op._annos_candidates[0] == 'n iC+ 4, iC+ oC 1, oC -> n oC 4', "Annotation mismatch." + +def test_Pad(): + op = F.Pad(IRTensor([3, 3, 4, 2]), pad=(1, 1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c 2 -> a b c 4' + op = F.Pad(IRTensor([3, 3, 4, 2]), pad=(1, 1, 2, 2)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b 4 2 -> a b 8 4' + op = F.Pad(IRTensor([3, 3, 4, 2]), pad=(0, 1, 2, 1, 3, 3)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 3 4 2 -> a 9 7 3' + op = F.Pad(IRTensor([3, 4, 2]), pad=(0, 1, 2, 1, 3, 3)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '3 4 2 -> 9 7 3' + op = F.Pad(IRTensor([3, 3, 4, 2]), pad=(o(1), o(1))) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c 2 -> a b c 4' From cdf7944148f62591d5d025c5443ad91eb2a17777 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 16 Oct 2024 06:41:02 +0000 Subject: [PATCH 1742/1892] Merged PR 2281: Refine code for MoE 1. fix bug in autodist when profiling sequentially 2. update the interface of TransformRule's modifier, it will take the partition position into consideration. 3. avoid redundant `import` when generating code 4. update register_op, it accepts two additional input parameters: transform_rules and input_gen_fn to handle partitioning and profiling of special functions 5. handle inplace function (like setitem) in profiler --- docs/source/register_custom_op.md | 16 +++- nnscaler/algorithm/ops/dimops.py | 4 +- nnscaler/autodist/cost_database.py | 14 +-- nnscaler/autodist/spmd_solver.py | 4 +- nnscaler/codegen/module/module.py | 2 +- nnscaler/graph/function/dimops.py | 9 +- nnscaler/graph/function/function.py | 6 +- nnscaler/graph/parser/fx/mapping.py | 24 +++--- nnscaler/graph/parser/register.py | 37 ++++++-- nnscaler/profiler/database.py | 20 ++++- tests/autodist/spmd_solver/test_setitem.py | 63 ++++++++++++++ tests/graph/parser/test_register.py | 99 ++++++++++++++++++++++ 12 files changed, 258 insertions(+), 40 deletions(-) create mode 100644 tests/autodist/spmd_solver/test_setitem.py diff --git a/docs/source/register_custom_op.md b/docs/source/register_custom_op.md index 6a7f120b..1fd6a2f2 100644 --- a/docs/source/register_custom_op.md +++ b/docs/source/register_custom_op.md @@ -44,7 +44,10 @@ nnscaler.register_op('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom')(oper def register_op( annotation: Union[str, Callable], name: Optional[str] = None, - code_impl_pattern: str = 'import' + code_impl_pattern: str = 'import', + emit_fn: Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str] = None, + transform_rules: Tuple[TransformRule] = None, + input_gen_fn: Callable[IRFwOperation, List[torch.Tensor]] = None) -> Callable: ) -> Callable: ... ``` @@ -99,6 +102,17 @@ This function has the following parameters: taks inputs and kwargs as arguments and returns the operator annotation. - `name` (`str | None`): operator name. Only usable when `node_repr` is a string. - `code_impl_pattern` (`str`): It can only be 'import' or 'source'. If 'import' (default), will generate code with import statement. If 'source', will take the source code directly. +- `emit_fn` (`Callable`): special emit function for codegen, it accepts the node, repred args, repred kwargs, runtime_devid, + plan_ndevs, runtime_ndevs as input and returns the generated code. Check examples/customized_ops/ring_attention/zigzag_attn.py for more details. + Default: None. +- `transform_rules` (`Tuple[TransformRule]`): a tuple of special TransformRules which will be used when partitioning the node. + Default: None. +- `input_gen_fn` (`Callable`): input generator function for profiler, this function accepts the IRFwOperation as input and returns + the list of input tensors, which is used during operator profiling. kwargs are same as that in the input node. By default, the + profiler will use `torch.rand` for floating point data types and `torch.zeros` for special types like `torch.int64` and `torch.bool`. + However, input tensors' contents may influence the speed dramatically. The mask in attention and dispatched expert index in MoE + are real examples. To handle this scenario, user can provide the customized `input_gen_fn`. + Default: None. ## Dimension Annotion Operations diff --git a/nnscaler/algorithm/ops/dimops.py b/nnscaler/algorithm/ops/dimops.py index 88d414b2..c156769c 100644 --- a/nnscaler/algorithm/ops/dimops.py +++ b/nnscaler/algorithm/ops/dimops.py @@ -158,12 +158,12 @@ def transform(tensor: Any, split: DimopSplit) -> List[Any]: ous = list() for split, otensor in zip(rule.outputs(), node.outputs()): ous.append(transform(otensor, split)) - kwargs = rule.modifier()(node.kwargs, idx, dim, num) sub_nodes = list() for nid in range(num): inputs = [t[nid] for t in ins] outputs = [t[nid] for t in ous] + kwargs = rule.modifier()(node.kwargs, idx, dim, num, nid) sub_node: IRDimops = node.new(inputs, outputs, **kwargs) sub_node.infer_shape() sub_nodes.append(sub_node) @@ -233,7 +233,7 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR ) otransform.append(DimopSplit.D(dims)) # modifier - def modify(kwargs: Dict, idx: int, dim: int, num: int): + def modify(kwargs: Dict, idx: int, dim: int, num: int, pos: int): updated_kwargs = dict(**kwargs) if adim in updated_kwargs: assert updated_kwargs[adim] % num == 0, \ diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 53159bca..75f0c754 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -138,6 +138,12 @@ def __init__(self, graph: IRGraph, config: AutoDistConfig): self.ignore_small_tensor_threshold = self.autodist_config.ignore_small_tensor_threshold def profile_comp(self, partition_degree: int): + def insert_profile_info(info: List[Tuple[str, str, ProfiledMetrics]]): + for sign, serialized, profiled_metrics in info: + _logger.debug(f'profiled {sign} in {serialized} with {profiled_metrics}') + if not self.db.exist_serialized(sign, serialized): + self.db.insert(sign, serialized, profiled_metrics) + if self.autodist_config.parallel_profile: _logger.info('Profiling in parallel') # use spawn to make sure the profiling process is independent from each other @@ -155,10 +161,7 @@ def profile_comp(self, partition_degree: int): # put queue.get() before join to avoid deadlock for p in processes: ret = results.get() - for sign, serialized, profiled_metrics in ret: - _logger.debug(f'profiled {sign} in {serialized} with {profiled_metrics}') - if not self.db.exist_serialized(sign, serialized): - self.db.insert(sign, serialized, profiled_metrics) + insert_profile_info(ret) results.close() for p in processes: @@ -166,7 +169,8 @@ def profile_comp(self, partition_degree: int): else: _logger.info('Profiling in serial') node_to_profile = _filter_nodes(self.graph, self.db) - _profile_nodes(node_to_profile, self.db, partition_degree, self.autodist_config.re_profile) + ret = _profile_nodes(node_to_profile, self.db, partition_degree, self.autodist_config.re_profile) + insert_profile_info(ret) self.db.dump_ops(self.comp_profile_path, override=True) diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 9d60b963..fc7badcd 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -615,8 +615,8 @@ def calc_father4op_partition(): cur_p_fathers = self.p_fathers[i] partitions = [None] * p_num for j, p_father in enumerate(self.p_fathers[i]): - if p_father == -1: - raise RuntimeError(f'find -1 in p_fathers for operator {i}') + if p_father == -1 or partitions[p_father] is not None: + raise RuntimeError(f'illegal p_fathers {self.p_fathers[i]} for {self.get_operator(i).ir_cell}') partitions[p_father] = self._op_partitions[i][j] self._op_partitions[i] = partitions self.p_fathers[i] = list(range(p_num)) diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index e2c0b940..351c3d92 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -131,7 +131,7 @@ def __init__( self.init_code.extend(['import nnfusion', '']) # customized op code - for _, op_impl in CustomizedOps.kOpCodeDef.items(): + for op_impl in set(CustomizedOps.kOpCodeDef.values()): # self.init_code.append('@torch.jit.script') self.init_code.append(op_impl) self.init_code += [''] diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index 8448bbfd..c0a76da7 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -597,7 +597,12 @@ class TransformRule: """ Partition rule """ - def __init__(self, irules: Tuple[DimopSplit], orules: Tuple[DimopSplit], kwarg_modifier: Optional[Callable] = None) -> None: + def __init__( + self, + irules: Tuple[DimopSplit], + orules: Tuple[DimopSplit], + kwarg_modifier: Optional[Callable[[Dict, int, Union[int, str], int, int], Dict]] = None, + ) -> None: self._inputs = tuple(irules) self._outputs = tuple(orules) modifier = kwarg_modifier if kwarg_modifier is not None else TransformRule.default_modifier @@ -624,7 +629,7 @@ def __repr__(self) -> str: return f'{inputs} -> {outputs}' @staticmethod - def default_modifier(kwargs: Dict, idx: int, dim: Union[int, str], num: int) -> Dict: + def default_modifier(kwargs: Dict, idx: int, dim: Union[int, str], num: int, subnode_idx: int) -> Dict: return kwargs diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index e0d83e49..0cf72a77 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1357,7 +1357,7 @@ def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[s if bracket[hdim] not in spatial: bracket[hdim] = str(shape_map[bracket[hdim]]) - def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: + def modifier(kwargs: Dict, idx, dim, num: int, subnode_idx: int) -> Dict: kwargs = dict(**kwargs) identifier = ifirst[dim] oidx = ofirst.index(identifier) @@ -2738,7 +2738,7 @@ def Conv1D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, raise ValueError(f'Input shape and weight shape are not compatible for the number of groups. input shape: {input.shape}, weight shape: {weight.shape}, groups: {groups_val}') if oC % groups_val != 0: raise ValueError('The output channels of weight must be divisible by the number of groups.') - def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: + def modifier(kwargs: Dict, idx, dim, num: int, subnode_idx: int) -> Dict: # only for partitioning groups kwargs = dict(**kwargs) kw_groups = kwargs['groups'] @@ -2887,7 +2887,7 @@ def Conv2D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, if oC % groups_val != 0: raise ValueError('The output channels of weight must be divisible by the number of groups.') - def modifier(kwargs: dict, idx, dim, num: int) -> dict: + def modifier(kwargs: dict, idx, dim, num: int, subnode_idx: int) -> dict: # only for partitioning groups kwargs = dict(**kwargs) kw_groups = kwargs['groups'] diff --git a/nnscaler/graph/parser/fx/mapping.py b/nnscaler/graph/parser/fx/mapping.py index df7e246f..c202231b 100644 --- a/nnscaler/graph/parser/fx/mapping.py +++ b/nnscaler/graph/parser/fx/mapping.py @@ -204,7 +204,7 @@ def exist(signature: str) -> bool: __ttemplate('neg'): function.Neg, '_operator.neg': function.Neg, - # + __ttemplate('gt'): function.CompareGT, '_operator.gt': function.CompareGT, __ttemplate('lt'): function.CompareLT, @@ -213,11 +213,11 @@ def exist(signature: str) -> bool: '_operator.ge': function.CompareGE, __ttemplate('le'): function.CompareLE, '_operator.le': function.CompareLE, - # + __ttemplate('sin'): function.Sin, - # + __ttemplate('cos'): function.Cos, - # + __tttemplate('view'): function.View, __tttemplate('contiguous'): function.Contiguous, @@ -227,32 +227,32 @@ def exist(signature: str) -> bool: __ftemplate('conv1d'): function.Conv1D, __ttemplate('conv_transpose1d'): function.ConvTranspose1D, __ftemplate('conv_transpose1d'): function.ConvTranspose1D, - # + __ttemplate('conv2d'): function.Conv2D, __ftemplate('conv2d'): function.Conv2D, __ttemplate('conv_transpose2d'): function.ConvTranspose2D, __ftemplate('conv_transpose2d'): function.ConvTranspose2D, - # + # __ttemplate('conv3d'): function.Conv3D, - # - # __ttemplate('pad'): function.Pad, - # + + __ftemplate('pad'): function.Pad, + # __ttemplate('select'): function.Select, # # __ttemplate('slice'): function.Slice, # # #pytorch1.11 # __ttemplate('select_scatter'): function.SelectScatter, - # + __tttemplate('repeat'): function.Repeat, __ttemplate('cat'): function.Cat, __ttemplate('stack'): function.Stack, __ttemplate('chunk'): function.Chunk, __ttemplate('flatten'): function.Flatten, # __ttemplate('roll'): function.Roll, - # + # __ttemplate('adaptive_avg_pool1d'): function.AdaptiveAvgPool1d, - # + # runtime functions __rtemplate('anchor'): function.GraphAnchor, __rtemplate('ifexpr'): function.Ifexpr, diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index 1b6de28a..b8b2923e 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -5,14 +5,15 @@ Register cutomized function """ -from typing import Dict, Callable, Optional, Union, List +from typing import Dict, Callable, Optional, Union, List, Tuple from functools import partial import inspect import logging +import torch from torch import ScriptFunction -from nnscaler.graph.function.dimops import IRDimops, OpAnno +from nnscaler.graph.function.dimops import IRDimops, OpAnno, TransformRule from nnscaler.graph.parser.fx.concrete_trace_utils.wrap_utils import is_autograd_apply from nnscaler.ir.operator import IRTensor, IRFwOperation @@ -32,6 +33,10 @@ class CustomizedOps: # It accepts the node, repred args, repred kwargs, runtime_devid, plan_ndevs, runtime_ndevs # as input and returns the generated code. kOpEmit: Dict[str, Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str]] = {} + # signature -> input generator function + # It accepts the IRFwOperation as input and returns the list of input tensors, which is used + # during operator profiling. + kOpInputGen: Dict[str, Callable[[IRFwOperation], List[torch.Tensor]]] = {} @staticmethod def map(signature: str) -> Callable: @@ -54,7 +59,8 @@ def exist(signature: str) -> bool: @staticmethod def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Callable, - emit_fn: Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str] = None): + emit_fn: Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str] = None, + input_gen_fn: Callable[[IRFwOperation], List[torch.Tensor]] = None) -> None: """Register an operator Args: @@ -65,6 +71,8 @@ def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Call emit_fn (Callable): special emit function for codegen, will use default emit function if emit_fn is None. It accepts the node, repred args, repred kwargs, runtime_devid, plan_ndevs, runtime_ndevs as input and returns the generated code. + input_gen_fn (Callable): input generator function for profiler, will use default input generator function + if input_gen_fn is None. kwargs are same as that in the input node. Returns: None @@ -78,10 +86,15 @@ def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Call CustomizedOps.kOpCodeDef[signature] = code if emit_fn is not None: CustomizedOps.kOpEmit[signature] = emit_fn + if input_gen_fn is not None: + CustomizedOps.kOpInputGen[signature] = input_gen_fn def register_op(annotation: Union[str, Callable], name: Optional[str] = None, - code_impl_pattern: str = 'import', emit_fn: Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str] = None) -> Callable: + code_impl_pattern: str = 'import', + emit_fn: Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str] = None, + transform_rules: Tuple[TransformRule] = None, + input_gen_fn: Callable[[IRFwOperation], List[torch.Tensor]] = None) -> Callable: """ Register a function with IRDimops annotations. @@ -136,10 +149,18 @@ def anno_fn(*inputs, **kwargs): can only be 'import' or 'source'. If 'import', will generate code with import statement. If 'source', will take the source code directly. Default: 'import'. - emit_fn (Callable): special emit function for codegen, this emit accepts the node, repred args, repred kwargs, runtime_devid, - plan_ndevs, runtime_ndevs as input and returns the generated code. Check examples/zigzag_ring_attention/zigzag_attn.py + emit_fn (Callable): special emit function for codegen, it accepts the node, repred args, repred kwargs, runtime_devid, + plan_ndevs, runtime_ndevs as input and returns the generated code. Check examples/customized_ops/ring_attention/zigzag_attn.py for more details. Default: None. + transform_rules (Tuple[TransformRule]): a tuple of special TransformRules which will be used when partitioning the node. + Default: None. + input_gen_fn (Callable): input generator function for profiler, this function accepts the IRFwOperation as input and returns + the list of input tensors, which is used during operator profiling. kwargs are same as that in the input node. By default, the + profiler will use `torch.rand` for floating point data types and `torch.zeros` for special types like `torch.int64` and `torch.bool`. + However, input tensors' contents may influence the speed dramatically. The mask in attention and dispatched expert index in MoE + are real examples. To handle this scenario, user can provide the customized `input_gen_fn`. + Default: None. Returns: fn (Callable): the runtime function @@ -234,11 +255,11 @@ def udfop(*args, signature=None, **kwargs): kwarg_vals = args[ninputs:] for name, val in zip(kwarg_names, kwarg_vals): kwargs[name] = val - return IRDimops(udfop, op_name, signature, [repr(anno)], tensors, **kwargs) + return IRDimops(udfop, op_name, signature, [repr(anno)], tensors, **kwargs, transform_rules=transform_rules) # step 4. register in CustomizedOps _logger.info(f'registering op {fsig}...') - CustomizedOps.register(fsig, udfop, code, fn, emit_fn) + CustomizedOps.register(fsig, udfop, code, fn, emit_fn, input_gen_fn) return fn return decorator diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 1aeabbc9..c6ae4671 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -150,10 +150,15 @@ def gen_torch_tensors(shape, dtype, requires_grad): constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) - tensors = tuple( - gen_torch_tensors(shape, dtype, requires_grad) if isinstance(value, IRTensor) else value \ - for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) - ) + if CustomizedOps.kOpInputGen.get(node.signature, None) is not None: + in_tensors = CustomizedOps.kOpInputGen[node.signature](node) + else: + in_tensors = tuple( + gen_torch_tensors(shape, dtype, requires_grad) if isinstance(value, IRTensor) else value \ + for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) + ) + # add clone() to avoid error "RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation." + tensors = tuple([t.clone() if torch.is_tensor(t) else t for t in in_tensors]) total_input_size = sum(t.numel() * t.element_size() for t in tensors if torch.is_tensor(t)) require_backward = any([t.requires_grad for t in tensors if hasattr(t, 'requires_grad')]) # FIXME: reconsidering requires_grad @@ -173,6 +178,13 @@ def gen_torch_tensors(shape, dtype, requires_grad): # run one sample outputs = func(*tensors, **train_kwargs) + + # check whether func is a in-place operation + for t1, t2 in zip(in_tensors, tensors): + if torch.is_tensor(t1) and not torch.equal(t1, t2): + _logger.warning(f"{node}: in-place operation detected, the input tensor is modified, will not profile backward") + require_backward = False + # only profile IRDimops currently, which has at least one tensor output and # may have non-tensor outputs (like list, tuple, dict, etc.). In addition, # we assume that non-tensor outputs will not be used in backward. diff --git a/tests/autodist/spmd_solver/test_setitem.py b/tests/autodist/spmd_solver/test_setitem.py new file mode 100644 index 00000000..49a576d4 --- /dev/null +++ b/tests/autodist/spmd_solver/test_setitem.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest + +import logging +import tempfile +import torch +import os +import nnscaler +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.profiler.database import ProfileDataBase +from ...utils import catch_log + + +class Module(torch.nn.Module): + + def __init__(self, hidden_dim): + super(Module, self).__init__() + self.hidden_dim = hidden_dim + self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + y1 = self.linear1(x) + y2 = self.linear2(x) + z = y1.new_empty(x.size(0), 2 * self.hidden_dim) + z[:, :self.hidden_dim] = y1 + z[:, self.hidden_dim:] = y2 + return z.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_set_item(): + nnscaler.utils.set_default_logger_level(logging.INFO) + from nnscaler.ir.unique import IDGenerator + IDGenerator().clear() + bsz, hidden_dim = 2, 10 + dummy_input = { + 'x': torch.rand(bsz, hidden_dim), + } + model = Module(hidden_dim) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + constant_folding=True) + + selected_nodes = [node for node in ir_graph.nodes() if 'setitem' in node.signature] + db = ProfileDataBase() + from nnscaler.profiler.database import _logger as _logger_profiler + with catch_log(_logger_profiler) as log_stream_profiler: + for node in selected_nodes: + ret = db.profile(node) + profiler_logs = log_stream_profiler.getvalue() + profiler_logs = profiler_logs.split('\n') + in_place_log = [log for log in profiler_logs if 'in-place operation detected, the input tensor is modified, will not profile backward' in log] + assert len(in_place_log) == 2 + fail_log = [log for log in profiler_logs if 'fail to profile' in log] + assert len(fail_log) == 0 diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index 7ee3fd51..fff23ec1 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -5,6 +5,8 @@ from nnscaler.graph.parser.converter import convert_model from nnscaler.profiler.database import get_func from nnscaler.codegen.emit import FuncEmission +from nnscaler.graph.function.dimops import DimopSplit, TransformRule +from nnscaler.graph.parser.register import CustomizedOps import tempfile import torch @@ -48,6 +50,7 @@ def forward(self, x, y): x, y = self.fc(x), self.fc(y) return mock_add(x, y) + class MockModel2(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -57,6 +60,7 @@ def forward(self, x, y): x, y = self.fc(x), self.fc(y) return mock_add2(x, y) + class MockModel3(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -141,6 +145,7 @@ def emit_customized_add(node, args, kwargs, runtime_devid, plan_ndevs, runtime_n nnscaler.register_op('*, * -> *', emit_fn=emit_customized_add)(customized_add) + class ModelCustomizedAdd(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -148,6 +153,7 @@ def __init__(self) -> None: def forward(self, x, y): return customized_add(x, y) + @replace_all_device_with('cpu') def test_customized_emit(): model = ModelCustomizedAdd() @@ -156,3 +162,96 @@ def test_customized_emit(): add_node = ir_graph.nodes()[0] code = FuncEmission().emit_fnode(add_node, runtime_devid=0, plan_ndevs=1, runtime_ndevs=1) assert 'torch.add' in code[-1] + + +def mock_transform_rule_add(x: torch.Tensor, y: torch.Tensor, z: int): + return x + y + + +def build_mock_transform_rules(): + itransform = [ + DimopSplit.D(0), + DimopSplit.D(0), + ] + + otransform = [ + DimopSplit.D(0), + ] + + def modifier(kwargs, idx, dim, num, subnode_idx): + updated_kwargs = dict(**kwargs) + if idx == 0 and dim == 0: + updated_kwargs['z'] = kwargs['z'] * (subnode_idx + 1) + else: + updated_kwargs['z'] = kwargs['z'] + return updated_kwargs + + return (TransformRule(itransform, otransform, modifier),) + +nnscaler.register_op('*, * -> *', transform_rules=build_mock_transform_rules())(mock_transform_rule_add) + + +class MockModelTransformRule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(10, 10, bias=False) + + def forward(self, x, y): + x1, y1 = self.fc(x), self.fc(y) + z1 = mock_transform_rule_add(x1, y1, 10) + x2, y2 = self.fc(x), self.fc(y) + z2 = mock_transform_rule_add(x2, y2, 10) + return z1 + z2 + + +@replace_all_device_with('cpu') +def test_transform_rule(): + model = MockModelTransformRule() + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) + add_node0 = ir_graph.nodes()[2] + add_node1 = ir_graph.nodes()[5] + sub0, sub1 = ir_graph.partition(add_node0, add_node0.algorithms('dim'), idx=0, dim=0, num=2) + assert sub0.kwargs['z'] == 10 + assert sub1.kwargs['z'] == 20 + + sub2, sub3 = ir_graph.partition(add_node1, add_node1.algorithms('dim'), idx=0, dim=1, num=2) + assert sub2.kwargs['z'] == 10 + assert sub3.kwargs['z'] == 10 + + +def mock_select(x: torch.Tensor, selected_rows: torch.Tensor): + return x[selected_rows, :] + + +def input_gen_fn(node): + inputs = [] + row = None + for i, t in enumerate(node.inputs()): + if i == 1: + inputs.append(torch.randint(low=0, high=row, size=t.shape, dtype=torch.int64, requires_grad=t.requires_grad)) + else: + row = t.shape[0] + inputs.append(torch.rand(t.shape, dtype=t.dtype, requires_grad=t.requires_grad)) + return tuple(inputs) + +nnscaler.register_op('a^ b^, c^ -> c^ b^', input_gen_fn=input_gen_fn)(mock_select) + + +class MockModelSelect(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, selected_rows): + return mock_select(x, selected_rows) + + +@replace_all_device_with('cpu') +def test_input_gen_fn(): + model = MockModelSelect() + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'selected_rows': torch.randint(0, 10, (5,), dtype=torch.int64)}, tempdir, False) + select_node = ir_graph.nodes()[0] + fn = CustomizedOps.kOpInputGen[select_node.signature] + ret = mock_select(*fn(select_node)) + assert True From c314670fca4bd01b549086f4822b5bd3797d8db8 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 17 Oct 2024 02:13:10 +0000 Subject: [PATCH 1743/1892] Merged PR 2284: Add modeling code and doc for deepseek coder v2 lite --- docs/source/register_custom_op.md | 11 +- examples/deepseek_coder_v2_lite/README.md | 78 + .../modeling/__init__.py | 0 .../modeling/configuration_deepseek.py | 206 ++ .../modeling/modeling_deepseek.py | 1922 +++++++++++++++++ .../modeling/modeling_deepseek_modifier.py | 491 +++++ .../modeling/moe_utils.py | 214 ++ examples/deepseek_coder_v2_lite/train.py | 337 +++ 8 files changed, 3256 insertions(+), 3 deletions(-) create mode 100644 examples/deepseek_coder_v2_lite/README.md create mode 100644 examples/deepseek_coder_v2_lite/modeling/__init__.py create mode 100644 examples/deepseek_coder_v2_lite/modeling/configuration_deepseek.py create mode 100644 examples/deepseek_coder_v2_lite/modeling/modeling_deepseek.py create mode 100644 examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py create mode 100644 examples/deepseek_coder_v2_lite/modeling/moe_utils.py create mode 100644 examples/deepseek_coder_v2_lite/train.py diff --git a/docs/source/register_custom_op.md b/docs/source/register_custom_op.md index 1fd6a2f2..286f28a6 100644 --- a/docs/source/register_custom_op.md +++ b/docs/source/register_custom_op.md @@ -108,10 +108,15 @@ This function has the following parameters: - `transform_rules` (`Tuple[TransformRule]`): a tuple of special TransformRules which will be used when partitioning the node. Default: None. - `input_gen_fn` (`Callable`): input generator function for profiler, this function accepts the IRFwOperation as input and returns - the list of input tensors, which is used during operator profiling. kwargs are same as that in the input node. By default, the + the list of input tensors, which is used during operator profiling. kwargs are same as that in the input node. By default, nnScaler's profiler will use `torch.rand` for floating point data types and `torch.zeros` for special types like `torch.int64` and `torch.bool`. - However, input tensors' contents may influence the speed dramatically. The mask in attention and dispatched expert index in MoE - are real examples. To handle this scenario, user can provide the customized `input_gen_fn`. + However, input tensors' contents may influence operator's behavior and speed dramatically. + Take function `nnscaler_moe_gmm` in `examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py` as an example. It dispatches + tokens (`hidden_states`) to experts according to another input tensor `topk_idx`. In most of the training time, tokens are distributed + evenly among experts with indices in `[local_expert_start, local_expert_end]`. Since `top_idx`'s type is `torch.int64`, if we generate + it with `torch.zeros` then all of the tokens are dispatched to the 1st expert, which can be ilegal and far from the real profile statistics + of the operator. By using `input_gen_fn`, we can provide compatible input tensors to the profiler so that the solver can generate a + good distributed plan. Default: None. diff --git a/examples/deepseek_coder_v2_lite/README.md b/examples/deepseek_coder_v2_lite/README.md new file mode 100644 index 00000000..f9aa6e5e --- /dev/null +++ b/examples/deepseek_coder_v2_lite/README.md @@ -0,0 +1,78 @@ +# Introduction + +This example demonstrates how to train deepseek-coder-v2-lite-2k on 8xH100s or 8xA100s. + +# Requirements + +To run this example, you need to install the following packages: + +```text +nnscaler +transformers==4.40.0 +datasets==2.20.0 +apex +flash-attn +grouped_gemm==1.1.4 +``` + +We recommend to launch the script under a Nvidia docker directly, like `nvidia/pytorch:24.02-py3`. You can find grouped_gemm at https://github.com/fanshiqing/grouped_gemm. + +# Data Preparation + +Like the *llama3_8B_128K* example, [bookcorpus](https://huggingface.co/datasets/bookcorpus) dataset is used for training. You can use the following command directly + +```bash +python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name deepseek-ai/DeepSeek-Coder-V2-Lite-Base --save_path ./bookcorpus_2k --sequence_length 2048 +``` + +# Training + +## Code Modification + +Modeling is based on the open source version for [deepseek coder v2](https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Base/tree/main). To boost the training performance and be compatible with nnScaler, the source code is modified. You can check modifications in details under `modeling` folder: + +- `configuration_deepseek.py` and `modeling_deepseek.py` are identical with the public available ones. +- Token dispatching logics are in `moe_utils.py`, which is adapted from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py). +- Most of the modifications are in `modeling_deepseek_modifier.py`. + +Similar to *llama3_8B_128K*, apex and flash-attn are introduced to reduce the execution time of RMSNorm and multi-head attention. In addition, there are several deepseek specific modifications: + +- register the routing function with annotation to nnScaler, since it is composed of fine-grained irregular operators and generating the annoation automatically is non-trivial. +- the for loop based MoE implementation is replaced with an efficient implementation built on [cutlass](https://github.com/NVIDIA/cutlass/blob/main/examples/24_gemm_grouped/gemm_grouped.cu). Along with kernel, separated expert weights are merged after loading the checkpoints. + +## Distributed Config + +The input data is organized into batches of 64 sequences whose length = 2048. The micro batch size is 4 and gradient accumulation step is 8. 8 GPUs are divided into 2 data parallel groups (4 GPUs maintain a full copy of weights). + +You can use following commands to compile and run the model. Checkpoints can be merged by the script in *llama3_8B_128K*. If you want to load the weights to huggingface, the merged experts should be split to the original names. + +**Compile** + +```bash +python train.py --run_mode compile --model_id deepseek-ai/DeepSeek-Coder-V2-Lite-Base --dataset_path ./bookcorpus_2k --plan_ngpus=4 --runtime_ngpus=8 2>&1 | tee compile.log +``` + +**Run** + +```bash +torchrun --nproc_per_node=8 train.py --model_id deepseek-ai/DeepSeek-Coder-V2-Lite-Base --dataset_path ./bookcorpus_2k --plan_ngpus=4 --runtime_ngpus=8 2>&1 | tee run.log +``` + +# Performance + +We have tested the training script on 8xH100 and each step takes about 2s. A step is composed of 128K tokens and the number of activated params is about 2.65B. Combining them together, the MFU is about 13% (attention's FLOPs is omitted since the sequence is short in this ). The root cause is the low utilization rate of the MoE part. We collect statistics for the grouped gemm in the table below. Note that in deepseek coder v2 lite, there are 64 experts with hidden size = 2048, intermediate size = 1408, each token will be dispatched to 8 experts. + +| # Dispatch Token | # Expert | forward / ms | backward / ms | MFU | +| :---- | :---- | :---- | :---- | :--- | +| 4096 | 64 | 3.190 | 6.363 | 13.5% | +| 2048 | 32 | 1.851 | 3.367 | 12.3% | +| 8192 | 64 | 5.148 | 8.964 | 18.2% | +| 2048 | 16 | 1.613 | 2.459 | 15.8% | +| 16384 | 64 | 8.901 | 14.90 | 21.6% | +| 2048 | 8 | 1.663 | 2.329 | 16.1% | + +To improve the performance, we recommend to + +- Replace the cutlass kernel with better ones. Current script is based on grouped_gemm@v1.14. +- Fuse more kernels like rope and memory slicing in attention. +- There are about 16 * 8 = 128 GB space used to store the optimizer states. Adding more devices helps to save more memory and nnScaler can find a better plan then. diff --git a/examples/deepseek_coder_v2_lite/modeling/__init__.py b/examples/deepseek_coder_v2_lite/modeling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/deepseek_coder_v2_lite/modeling/configuration_deepseek.py b/examples/deepseek_coder_v2_lite/modeling/configuration_deepseek.py new file mode 100644 index 00000000..82e0f5d9 --- /dev/null +++ b/examples/deepseek_coder_v2_lite/modeling/configuration_deepseek.py @@ -0,0 +1,206 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V2. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 102400): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import DeepseekV2Model, DeepseekV2Config + + >>> # Initializing a Deepseek-V2 style configuration + >>> configuration = DeepseekV2Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size = 1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts = None, + n_routed_experts = None, + ep_size = 1, + routed_scaling_factor = 1.0, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'gready', + n_group = None, + topk_group = None, + num_experts_per_tok = None, + moe_layer_freq = 1, + first_k_dense_replace = 0, + norm_topk_prob = False, + scoring_func = 'softmax', + aux_loss_alpha = 0.001, + seq_aux = True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek.py b/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek.py new file mode 100644 index 00000000..ea723720 --- /dev/null +++ b/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek.py @@ -0,0 +1,1922 @@ +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_deepseek import DeepseekV2Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV2Config" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class DeepseekV2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) + + +class DeepseekV2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV2MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None + ) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1, dtype=torch.float32) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) + elif self.topk_method == "group_limited_greedy": + group_scores = ( + scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weight, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * self.routed_scaling_factor + ### expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, self.n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( + dim=1 + ).mean() * self.alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class DeepseekV2MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave( + self.num_experts_per_tok, dim=0 + ) + y = torch.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.to(hidden_states.dtype).view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = ( + tokens_per_expert_group.view(self.ep_size, -1) + .sum(1) + .cpu() + .numpy() + .tolist() + ) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank + ).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 +class DeepseekV2FlashAttention2(DeepseekV2Attention): + """ + DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV2FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = ( + self.q_proj.weight.dtype + if self.q_lora_rank is None + else self.q_a_proj.weight.dtype + ) + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV2Attention, + "flash_attention_2": DeepseekV2FlashAttention2, +} + + +class DeepseekV2DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV2MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV2MLP(config) + ) + self.input_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2PreTrainedModel(PreTrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2Model(DeepseekV2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." + ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM + + >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py b/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py new file mode 100644 index 00000000..9ae7da5a --- /dev/null +++ b/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py @@ -0,0 +1,491 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# This file modifies the official modeling_llama.py file at runtime to +# 1. register the flash attention function to nnscaler and update related code +# 2. replace the un-fused RMSNorm with apex's fused version +# 3. register the MoE routing function to nnscaler +# 4. replace the for loop in MoE forward with grouped gemm implementation + +import types +from typing import List, Optional, Tuple, Union + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir import IRTensor + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .modeling_deepseek import DeepseekV2FlashAttention2, ATTENTION_CLASSES, apply_rotary_pos_emb, DeepseekV2RMSNorm, AddAuxiliaryLoss, MoEGate, DeepseekV2MoE, _get_unpad_data +from .moe_utils import moe_gather, moe_scatter, permute, unpermute + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +try: + from apex.normalization.fused_layer_norm import fused_rms_norm_affine + has_apex = True +except ImportError: + has_apex = False + + +try: + from grouped_gemm.ops import gmm +except ImportError: + raise ImportError( + "Grouped GEMM is not available. Please run " + "`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.0`." + ) + + +def rmsnorm_fwd(self, hidden_states): + if has_apex: + return fused_rms_norm_affine(hidden_states, self.weight, self.weight.shape, self.variance_epsilon) + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def moe_gate_fwd(self, hidden_states): + topk_idx, topk_weight, aux_loss = moe_route(hidden_states, self.weight, self.topk_method, self.top_k, self.n_group, self.n_routed_experts, + self.topk_group, self.training, self.alpha, self.norm_topk_prob, self.routed_scaling_factor, self.seq_aux) + return topk_idx, topk_weight, aux_loss + + +def moe_fwd(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + if self.training: + # gate_projs, up_projs, down_projs are merged after checkpoints are loaded + y = nnscaler_moe_gmm(hidden_states, topk_idx, topk_weight, aux_loss, self.gate_projs, self.up_projs, self.down_projs, self.config.n_routed_experts, 0, self.config.n_routed_experts) + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + +class NNScalerDeepseekFlashAttention2(DeepseekV2FlashAttention2): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV2FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) # start signal + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + query_states = torch.cat([q_nope, q_pe], dim=-1) + + # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + key_states = torch.cat([k_nope, k_pe.expand(-1, k_nope.size(1), -1, -1)], dim=-1) + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = ( + self.q_proj.weight.dtype + if self.q_lora_rank is None + else self.q_a_proj.weight.dtype + ) + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and q_len != 1 + + attn_output = nnscaler_flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, causal=causal + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +### register custom functions +def nnscaler_flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, causal=True +): + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = nnscaler_upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + return attn_output + + +def nnscaler_upad_input(query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, query_layer.shape[-2], head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def moe_route(hidden_states: torch.Tensor, weight: torch.Tensor, + topk_method: str, top_k: int, n_group: int, n_routed_experts: int, topk_group: int, + training: bool, alpha: float, norm_topk_prob: bool, routed_scaling_factor: float, seq_aux: bool): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), weight.type(torch.float32), None + ) + scores = nn.functional.softmax(logits, dim=-1, dtype=torch.float32) + + ### select top-k experts + if topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=top_k, dim=-1, sorted=False + ) + elif topk_method == "group_limited_greedy": + group_scores = ( + scores.view(bsz * seq_len, n_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, n_group, n_routed_experts // n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weight, topk_idx = torch.topk( + tmp_scores, k=top_k, dim=-1, sorted=False + ) + + ### norm gate to sum 1 + if top_k > 1 and norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * routed_scaling_factor + ### expert-level computation auxiliary loss + if training and alpha > 0.0: + scores_for_aux = scores + aux_topk = top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( + dim=1 + ).mean() * alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * n_routed_experts + aux_loss = (Pi * fi).sum() * alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.function.dimops import DimopSplit, TransformRule + + +# NOTE: moe_route is replicated intra scale unit, since: +# 1. the computation overhead is small +# 2. the returned aux_loss is summed by mean along the batch dimension, which makes +# it difficult to handle it correctly without modifying the code +# 3. dispatch by allgather is used currently, which is compatible with the replicated +# moe_route plan +register_op(f'n^ l^ h^, e^ h^ -> (n^ l^) k^, (n^ l^) k^, 1')(moe_route) + + +def nnscaler_llama_flash_attention_forward_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + from nnscaler.ir import IRTensor + if isinstance(attention_mask, IRTensor): + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, b l^ -> b l^ {q_anno} vd^' + else: + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' + + +register_op(nnscaler_llama_flash_attention_forward_anno)(nnscaler_flash_attention_forward) + + +def nnscaler_moe_gmm( + hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weight: torch.Tensor, aux_loss: torch.Tensor, + gate_projs: torch.Tensor, up_projs: torch.Tensor, down_projs: torch.Tensor, + n_routed_experts: int, local_expert_start: int, local_expert_end: int): + + orig_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) + topk_weight = topk_weight.reshape(-1, topk_weight.shape[-1]) + + with torch.no_grad(): + local_mask = (topk_idx >= local_expert_start) & (topk_idx < local_expert_end) + local_idx = topk_idx.masked_select(local_mask) + + local_prob = topk_weight.masked_select(local_mask) + local_prob = local_prob.view(-1, 1) + local_map = local_mask.nonzero()[:, 0] + local_map = local_map.view(-1, 1).expand(-1, hidden_states.shape[-1]) + local_hidden_states = moe_gather.apply(hidden_states, local_map) + + with torch.no_grad(): + tokens_per_expert = torch.histc(local_idx, bins=local_expert_end - local_expert_start, min=local_expert_start, max=local_expert_end - 1) + tokens_per_expert = tokens_per_expert.cpu().to(torch.long) + + permuted_inputs, row_id_map = permute(local_hidden_states, local_idx) + + fc1_output = gmm(permuted_inputs, gate_projs, tokens_per_expert, trans_b=True) + fc2_output = gmm(permuted_inputs, up_projs, tokens_per_expert, trans_b=True) + intermediate_parallel = torch.nn.functional.silu(fc1_output) * fc2_output + expert_outs = gmm(intermediate_parallel, down_projs, tokens_per_expert, trans_b=True) + + y = unpermute(expert_outs, row_id_map) + y = y * local_prob + y = moe_scatter.apply(y, local_map, hidden_states.shape) + + y = y.to(hidden_states.dtype).view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + + return y + + +def build_ep_transform_rule(): + itransform = [ + DimopSplit.R(), + DimopSplit.R(), + DimopSplit.R(), + DimopSplit.R(), + DimopSplit.D(0), + DimopSplit.D(0), + DimopSplit.D(0), + ] + + otransform = [ + DimopSplit.V(), + ] + + def modifier(kwargs, idx, dim, num, pos): + updated_kwargs = dict(**kwargs) + expert_num = kwargs['local_expert_end'] - kwargs['local_expert_start'] + updated_kwargs['local_expert_start'] = expert_num // num * pos + updated_kwargs['local_expert_end'] = expert_num // num * (pos + 1) + return updated_kwargs + + return TransformRule(itransform, otransform, modifier) + + +def input_gen_fn(node: IRFwOperation): + inputs = [] + device = torch.cuda.current_device() + for i, t in enumerate(node.inputs()): + if i == 1: + inputs.append(torch.randint(low=0, high=64, size=t.shape, dtype=torch.int64, device=device, requires_grad=t.requires_grad)) + else: + inputs.append(torch.rand(t.shape, dtype=t.dtype, device=device, requires_grad=t.requires_grad)) + return tuple(inputs) + + +register_op(f'n l h^, (n l) 6, (n l) 6, 1, E+ d+ h^, E+ d+ h^, E+ h^ d+ -> n l h^', transform_rules=(build_ep_transform_rule(),), input_gen_fn=input_gen_fn)(nnscaler_moe_gmm) + + +def nnscaler_deepseek_init(): + ATTENTION_CLASSES['flash_attention_2'] = NNScalerDeepseekFlashAttention2 + DeepseekV2RMSNorm.forward = rmsnorm_fwd + MoEGate.forward = moe_gate_fwd + DeepseekV2MoE.forward = moe_fwd diff --git a/examples/deepseek_coder_v2_lite/modeling/moe_utils.py b/examples/deepseek_coder_v2_lite/modeling/moe_utils.py new file mode 100644 index 00000000..01698d60 --- /dev/null +++ b/examples/deepseek_coder_v2_lite/modeling/moe_utils.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# This file is adapted from the Megatron-LM project. + +import torch + + +def permute(tokens, indices, num_out_tokens: int = None, padded_mode: bool = False): + """Permute the tokens based on the indices. Token with the same index will be grouped together. + The input indices shape is [tokens, top_k], it indicates which experts were selected by each + token separately. + Args: + tokens (torch.Tensor): The input token tensor. + indices (torch.Tensor): The token to expert indices tensor, should have a shape of + [num_tokens] or [num_tokens, topk]. + num_out_tokens (int, optional): The effective output token count, when enabling the + capacity factor, should equal the number of tokens not + dropped. By default, set to None, meaning no tokens are + dropped. + padded_mode (bool, optional): If True, indicating the indices are padded to + [num_expert, capacity] to denote selected tokens per expert. + Defaults to False. + + Returns: + torch.Tensor: The permuted tensor. + torch.Tensor: The sorted_indices corresponding permuted tensor. + """ + if padded_mode: + return permute_with_padded_tokens(tokens, indices) + + if indices.dim() == 1: + indices = indices.unsqueeze(1) + + topk = indices.size(1) + flatten_indices = indices.view(-1) + sorted_indices = torch.argsort(flatten_indices, stable=True) + if num_out_tokens is not None: + sorted_indices = sorted_indices[:num_out_tokens] + moe_gather_indices = (sorted_indices // topk).unsqueeze(1).expand(-1, tokens.size(-1)) + permuted_tokens = moe_gather.apply(tokens, moe_gather_indices) + + return permuted_tokens, sorted_indices + + +def unpermute( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + probs: torch.Tensor = None, + padded_mode: bool = False, + restore_shape: torch.Size = None, +): + """Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the + tokens with their corresponding probabilities. + + Args: + permuted_tokens (torch.Tensor): 2D tensor [num_tokens*topk, hidden]. The tensor of permuted + tokens to be unpermuted. + sorted_indices (torch.Tensor): 1D tensor [num_tokens*topk]. The tensor of sorted indices + used to unpermute the tokens. + probs (torch.Tensor, optional): 2D tensor [num_tokens, topk]. The tensor of probabilities + corresponding to the permuted tokens. If provided, + the unpermuted tokens will be merged with their respective + probabilities. + padded_mode (bool, optional): If True, indicating the indices are padded to + [num_expert, capacity] to denote selected tokens per expert. + Defaults to False. + restore_shape (torch.Size, optional): The input shape before permutation, only used in + padding mode. Defaults to None. + + Returns: + torch.Tensor: The unpermuted tokens, optionally merged with probabilities. + """ + if padded_mode: + return unpermute_with_padded_tokens( + permuted_tokens, sorted_indices, probs, restore_shape=restore_shape + ) + + assert sorted_indices.numel() == permuted_tokens.size( + 0 + ), f"Got {sorted_indices.numel()} != {permuted_tokens.size(0)}." + if probs is not None: + # Unpermute and merge the tokens with their probabilities + num_unpermuted_tokens = probs.numel() + assert probs.dim() == 2, f"Expected 2D tensor for probs, got {probs.dim()} dims." + topk = probs.size(1) + else: + # Unpermute the tokens without merge + num_unpermuted_tokens = permuted_tokens.size(0) + topk = 1 + + output_size = [num_unpermuted_tokens, permuted_tokens.shape[-1]] + moe_scatter_indices = sorted_indices.unsqueeze(1).expand(-1, permuted_tokens.size(-1)) + unpermuted_tokens = moe_scatter.apply(permuted_tokens, moe_scatter_indices, output_size) + unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1)) + if probs is not None: + unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1) + unpermuted_tokens = unpermuted_tokens.sum(dim=1) + + return unpermuted_tokens + + +def permute_with_padded_tokens(tokens, indices): + """Permute the tokens based on the indices, only used in padding mode. + The input indices shape is [num_expert, capacity], it indicates which tokens were selected + by each expert separately. + Args: + tokens (torch.Tensor): The input token tensor. + indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected + tokens for each expert. + + Returns: + torch.Tensor: The permuted tensor. + torch.Tensor: The sorted_indices corresponding permuted tensor. + """ + permuted_tokens = tokens.index_select(dim=0, index=indices.view(-1)) + + return permuted_tokens, indices + + +def unpermute_with_padded_tokens( + permuted_tokens: torch.Tensor, + indices: torch.Tensor, + probs: torch.Tensor, + restore_shape: torch.Size, +) -> torch.Tensor: + """ + Unpermutes a padded permuted tokens based on sorted indices and merges the tokens with their + corresponding probabilities. + + This function takes a tensor of permuted tokens and reorders them according to the provided + indices. It also combines the tokens with their associated probabilities. + + Parameters: + permuted_tokens (torch.Tensor): A 2D tensor containing permuted tokens. + indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected + tokens for each expert. + probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities + corresponding to each token. + restore_shape (torch.Size): The target shape for the unpermuted tokens tensor. + + Returns: + torch.Tensor: A tensor of unpermuted tokens, merged with their probabilities. + + """ + # Ensure permuted_tokens is 2D + assert permuted_tokens.dim() == 2, f"Got {permuted_tokens.dim()}D." + + # Reshape and expand probabilities and indices to match permuted_tokens + probs = probs.view(-1).unsqueeze(-1) + indices = indices.view(-1, 1).expand(-1, permuted_tokens.shape[1]) + assert ( + permuted_tokens.shape == indices.shape + ), "Shape mismatch between permuted_tokens and indices." + + # Combine tokens with their probabilities + combined_output = probs * permuted_tokens + + # Prepare a tensor of zeros with the desired output shape + empty_tokens = torch.zeros( + restore_shape, dtype=combined_output.dtype, device=combined_output.device + ) + + # Scatter the combined tokens back to their original positions + unpermuted_tokens = torch.scatter_add(empty_tokens, 0, indices, combined_output) + + return unpermuted_tokens + + +class moe_gather(torch.autograd.Function): + """Gather the input tensor based on the map tensor.""" + + @staticmethod + def forward(ctx, input_, map_): + """Gather the input tensor based on the map tensor.""" + ctx.input_size = input_.size() + ctx.map = map_ + return torch.gather(input_, 0, map_) + + @staticmethod + def backward(ctx, grad_output): + """Scatter the grad_output tensor based on the map tensor.""" + input_size = ctx.input_size + map_ = ctx.map + + output = torch.zeros( + input_size, dtype=grad_output.dtype, device=torch.cuda.current_device() + ) + output.scatter_add_(0, map_, grad_output) + return output, None, None + + +class moe_scatter(torch.autograd.Function): + """Scatter the input tensor based on the map tensor.""" + + @staticmethod + def forward(ctx, input_, map_, output_size=None): + """Scatter the input tensor based on the map tensor.""" + ctx.map = map_ + + if output_size is not None: + output = torch.zeros(output_size, dtype=input_.dtype, device=input_.device) + else: + output = torch.zeros_like(input_) + + output.scatter_add_(0, map_, input_) + return output + + @staticmethod + def backward(ctx, grad_output): + """Gather the grad_output tensor based on the map tensor.""" + map_ = ctx.map + grad_input = torch.gather(grad_output, 0, map_) + return grad_input, None, None, None diff --git a/examples/deepseek_coder_v2_lite/train.py b/examples/deepseek_coder_v2_lite/train.py new file mode 100644 index 00000000..3c6ecf37 --- /dev/null +++ b/examples/deepseek_coder_v2_lite/train.py @@ -0,0 +1,337 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import os +import logging +import math + +import datasets +from datasets import load_from_disk +import torch +from torch.optim.lr_scheduler import LRScheduler +from transformers import AutoConfig, AutoTokenizer, DataCollatorForLanguageModeling +from modeling.modeling_deepseek import DeepseekV2ForCausalLM, DeepseekV2MoE, DeepseekV2RotaryEmbedding +from modeling.modeling_deepseek_modifier import nnscaler_deepseek_init + +from nnscaler.utils import set_default_logger_level +from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer_args import ( + CheckpointConfig, + DatasetConfig, + HookMapConfig, + ModelConfig, + OptimizerConfig, + TrainerArgs, + DataloaderConfig, + AggregatedOutputs, + LogConfig, + DatasetSamplerConfig, + LRSchedulerConfig, +) +from nnscaler.parallel import ComputeConfig, BroadcastGenFilesStrategy +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdamW +from nnscaler.cli.loggers.tensorboard import TensorBoardLogger + +_logger = logging.getLogger(__name__) + + +IGNORE_IDX = -100 + + +class WarmupScheduler(LRScheduler): + + def __init__(self, optimizer, warmup_steps, last_epoch=-1): + self.warmup_steps = warmup_steps + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch + 1 >= self.warmup_steps: + return self.base_lrs + return [base_lr * (self.last_epoch + 1) / self.warmup_steps for base_lr in self.base_lrs] + + +def get_tokenizer(tokenizer_name_or_path, + model_max_length=None, + default_bos_token="", + default_eos_token="", + default_pad_token="[PAD]", + default_unk_token=""): + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + special_tokens_dict = dict() + if tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = default_pad_token + if tokenizer.eos_token is None: + special_tokens_dict["eos_token"] = default_eos_token + if tokenizer.bos_token is None: + special_tokens_dict["bos_token"] = default_bos_token + if tokenizer.unk_token is None: + special_tokens_dict["unk_token"] = default_unk_token + + tokenizer.add_special_tokens(special_tokens_dict) + if model_max_length: + tokenizer.model_max_length = model_max_length + return tokenizer + + +class WrapperModel(torch.nn.Module): + def __init__(self, model_id): + super().__init__() + self.model = DeepseekV2ForCausalLM.from_pretrained(model_id, attn_implementation='flash_attention_2') + self.model.train() + + # post-process model for usibility + # - merge small linear weights into large ones to use high performance kernel `grouped gemm` + # and avoid the overhead that merge them on the fly. Note that checkpoints of the wrapped model + # cannot be loaded directly to the transformers model. You need to split the weights with correct + # names. + # - reset `max_seq_len_cached`` in rotary embeddings since in transformers source code, `cos_cached` + # and `sin_cached` are evaluated during runtime, which violates the assumption of concrete tracer. + # As a result, we reset `max_seq_len_cached` to make caches static. + config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + for name, child in self.model.named_modules(): + if isinstance(child, DeepseekV2MoE): + _logger.info(f'Merging experts in {name} with {type(child)}') + # num_local_experts, intermediate_size, hidden_size + gate_projs = torch.stack([expert.gate_proj.weight for expert in child.experts], dim=0) + up_projs = torch.stack([expert.up_proj.weight for expert in child.experts], dim=0) + down_projs = torch.stack([expert.down_proj.weight for expert in child.experts], dim=0) + child.register_parameter('gate_projs', torch.nn.Parameter(gate_projs)) + child.register_parameter('up_projs', torch.nn.Parameter(up_projs)) + child.register_parameter('down_projs', torch.nn.Parameter(down_projs)) + elif isinstance(child, DeepseekV2RotaryEmbedding): + child.max_seq_len_cached = config.max_position_embeddings + + def forward(self, samples): + outputs = self.model( + input_ids=samples['net_input']['src_tokens'], + use_cache=False, + return_dict=False, + ) + logits = outputs[0].view(-1, outputs[0].size(-1)) + labels = samples['target'].view(-1) + normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) + loss = torch.nn.functional.nll_loss(normalized_logits, labels, reduction='sum', ignore_index=IGNORE_IDX) + return loss, loss.data, samples['ntokens'], samples['nsentences'] + + +def aggregate_outputs_fn(loss_outputs, sync_group) -> AggregatedOutputs: + losses, ntokens_info = [], [] + for _, loss, ntokens, _ in loss_outputs: + losses.append(loss) + ntokens_info.append(ntokens) + + loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) + torch.distributed.all_reduce(loss_sum, group=sync_group) + ntokens_sum = torch.sum(torch.tensor(ntokens_info, dtype=torch.float64, device=torch.cuda.current_device())) + torch.distributed.all_reduce(ntokens_sum, group=sync_group) + num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) + torch.distributed.all_reduce(num_batches, group=sync_group) + + return AggregatedOutputs( + loss_sum=loss_sum.item() / ntokens_sum.item() / math.log(2), + num_batches=num_batches.item(), + num_tokens=ntokens_sum.item(), + ) + + +def main(args): + + if args.run_mode == 'run': + broadcast_strategy = 'all' + else: + broadcast_strategy = 'none' + + set_default_logger_level('INFO') + + nnscaler_deepseek_init() + + ## Setup Dataset ## + + dataset = load_from_disk(args.dataset_path) + tokenizer = get_tokenizer(args.model_id) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + def collate_fn(samples): + if len(samples) == 0: + return {} + + mini_batch = data_collator(samples) + _mini_batch = {} + + src_tokens = mini_batch.pop('input_ids') + seq_len = src_tokens.size(-1) + _mini_batch['src_tokens'] = src_tokens + + shift_labels = mini_batch['labels'][..., 1:] + _mini_batch['labels'] = torch.nn.functional.pad(shift_labels, (0, 1), 'constant', IGNORE_IDX).contiguous() + + return { + "nsentences": len(samples), + "ntokens": len(samples) * seq_len, + "net_input": _mini_batch, + "target": _mini_batch.pop('labels'), + } + + ## Config Trainer ## + + if args.run_mode == 'compile': + if args.runtime_ngpus is None: + raise ValueError('runtime_ngpus must be specified in compile mode') + runtime_ngpus = args.runtime_ngpus + elif args.run_mode == 'run': + world_size = int(os.getenv('WORLD_SIZE')) + if args.runtime_ngpus is None: + runtime_ngpus = world_size + else: + if args.runtime_ngpus != world_size: + raise ValueError('runtime_ngpus must match the number of GPUs in run mode') + runtime_ngpus = args.runtime_ngpus + if runtime_ngpus % args.plan_ngpus != 0: + raise ValueError('runtime_ngpus must be a multiple of plan_ngpus') + + compute_config = ComputeConfig( + plan_ngpus=args.plan_ngpus, + runtime_ngpus=runtime_ngpus, + constant_folding=True, + use_zero=True, + use_end2end=True, + # autodist config: + # - memory constraint is set to 64GB + pas_config={ + 'mem_constraint': 64, + }, + ) + + model_config = ModelConfig( + type=WrapperModel, + args={ + 'model_id': args.model_id, + }, + ) + + # optimizer hyperparameters are from YaRN + lrscheduler_config = LRSchedulerConfig( + type=WarmupScheduler, + args={ + 'warmup_steps': 10, + }, + interval='step', + ) + + optimizer_config = OptimizerConfig( + type=MixedPrecisionAdamW, + args={'lr': 1e-5, 'betas': (0.9, 0.95), 'weight_decay': 0.0, 'fused': True}, + clip_gnorm=1.0, + loss_reduction='sum', + grad_reduction='per-token-mean', + aggregate_outputs_fn=aggregate_outputs_fn, + ) + + dataset_config = DatasetConfig( + type=(lambda split: dataset), + train_args={'split': 'train'}, + ) + + dataloader_config = DataloaderConfig( + train_args={ + 'collate_fn': collate_fn, + 'drop_last': True, + }, + ) + + sampler_config = DatasetSamplerConfig( + train_args={ + 'shuffle': True, + }, + ) + + checkpoint_config = CheckpointConfig( + every_n_train_steps=1000, + save_type='deduped', + resume_from=(args.resume_path or 'last'), + ) + + log_config = LogConfig( + type=TensorBoardLogger, + args={ + 'name': args.name, + 'root_dir': './runs', + }, + ) + + trainer_args = TrainerArgs( + instance_name=args.name, + run_mode=args.run_mode, + compute_config=compute_config, + pas_policy='autodist', + model=model_config, + optimizer=optimizer_config, + dataset=dataset_config, + dataloader=dataloader_config, + checkpoint=checkpoint_config, + precision='bf16', + max_epochs=1, + micro_batch_size=4, + grad_accumulation_steps=8, + log=[log_config], + seed=0, + broadcast_strategy=broadcast_strategy, + dataset_sampler=sampler_config, + lr_scheduler=lrscheduler_config, + ) + + trainer = Trainer(train_args=trainer_args) + trainer.run() + + +if __name__ == '__main__': + ## Parse Args ## + + parser = argparse.ArgumentParser() + parser.add_argument( + '--name', + default='deepseek-coder-v2-lite-2k', + type=str, + help='name of the experiment', + ) + parser.add_argument( + '--run_mode', + default='run', + choices=['run', 'compile'], + help='run or compile', + ) + parser.add_argument( + '--plan_ngpus', + type=int, + required=True, + help='specify the scale unit size', + ) + parser.add_argument( + '--runtime_ngpus', + type=int, + required=True, + help='specify the number of GPUs to use', + ) + parser.add_argument( + '--resume_path', + default=None, + type=str, + help='path to dir of ckpts or the ckpt file to resume from', + ) + parser.add_argument( + '--dataset_path', + default=None, + type=str, + help='path to the dataset', + ) + parser.add_argument( + '--model_id', + default=None, + type=str, + help='transformers model id', + ) + args = parser.parse_args() + + main(args) From 85560cbc8841d96af89984aef123ebc33871b1e7 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 17 Oct 2024 04:01:22 +0000 Subject: [PATCH 1744/1892] Merged PR 2286: Fix bugs for gnorm computation Gnorm calculation uses `nreplica` to compute the global gnorm. However in previous implementation, it is possible that params with different replica_info are organized into a same reducer. To fix this, we decouple the partitioned params and replicated params when generating cross scale unit reducers. --- nnscaler/codegen/module/module.py | 24 +++++++---- nnscaler/runtime/gnorm.py | 3 +- nnscaler/runtime/module.py | 4 ++ tests/graph/gener/test_reducer_gen.py | 59 +++++++++++++++++++++++++-- 4 files changed, 79 insertions(+), 11 deletions(-) diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 351c3d92..8214c978 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -166,8 +166,14 @@ def add_scale_reducers(self): assert param not in all_params, \ f'detected a parameter {param} in multiple reducers on device {device}' all_params.update(reducer.inputs()) - # create a reducer for the rest parameters used for this device - rest_params = [] + # create reducers for the rest parameters used for this device + # nnscaler's weights are either fully replicated or partitioned, which has been checked + # at graph/gener/gen.py/gen_weights. + # We decouple the replicated and partitioned weights to align with the calculation of + # gradient norm which uses the replicated number of each weight to make the global value + # correct. + rest_params_replicated = [] + rest_params_partitioned = [] def collect_rest_params(segment): """Resursively collect parameters. Note parameters can be in sub-segments, @@ -178,17 +184,21 @@ def collect_rest_params(segment): if device not in ctensor.device: continue if ctensor not in all_params: # a same parameter can be consumed multiple times by different operators - if ctensor not in rest_params: - rest_params.append(ctensor) + if ctensor.shape == ctensor.parent.shape: + if ctensor not in rest_params_replicated: + rest_params_replicated.append(ctensor) + else: + if ctensor not in rest_params_partitioned: + rest_params_partitioned.append(ctensor) for seg in segment.select(ntype=IRSegment, flatten=False): collect_rest_params(seg) collect_rest_params(graph) - if len(rest_params) == 0: - continue # create reducer and append to the execution # device will be scaled in `self.scale` - for reducer in IRWeightReducer.from_weights(rest_params, device): + for reducer in IRWeightReducer.from_weights(rest_params_replicated, device): + self.execplan.at(device).append(reducer) + for reducer in IRWeightReducer.from_weights(rest_params_partitioned, device): self.execplan.at(device).append(reducer) def get_comm_groups(self): diff --git a/nnscaler/runtime/gnorm.py b/nnscaler/runtime/gnorm.py index d8e7db4d..eb6a3e5b 100644 --- a/nnscaler/runtime/gnorm.py +++ b/nnscaler/runtime/gnorm.py @@ -142,8 +142,9 @@ def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, for seq, params_info in enumerate(params_info_for_gnorm): # params_info is ParamsInfo, which is defined in this file assert isinstance(params_info.ranks, tuple), f'ranks {params_info.ranks} should be tuple' - for name, param in zip(params_info.param_names, params_info.params): + for param in params_info.params: assert param.requires_grad + for name in params_info.param_names: tid = cube_model.tid_of_param_name(name) tid2ranks[tid] = params_info.ranks tid2info_list_seq[tid] = seq diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index e5fafaba..69255a5b 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -237,6 +237,10 @@ def parameters_for_calc_gnorm(self) -> List[ParamsInfo]: reducer_pids = set() for reducer in self._reducers: param_names = [paramid2name[id(p)] for p in reducer.params] + # we should use `parameters_for_optimizer` here since calculating gnorm + # is ahead of the optimizer step. When ZeRO is enabled, each device only + # maintains a subset of the parameters. As a result, `param_names` may not + # align with the value of `reducer.parameters_for_optimizer()`. params_info = ParamsInfo(reducer.ranks, reducer.parameters_for_optimizer(), param_names, reducer.zero_ngroups) params_info_for_gnorm.append(params_info) diff --git a/tests/graph/gener/test_reducer_gen.py b/tests/graph/gener/test_reducer_gen.py index 0a4cd4eb..2a1e5a99 100644 --- a/tests/graph/gener/test_reducer_gen.py +++ b/tests/graph/gener/test_reducer_gen.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import pytest +from pathlib import Path from nnscaler.graph.gener.gen import IRAdapterGener from nnscaler.graph import IRGraph @@ -10,9 +11,12 @@ from nnscaler.ir.operator import IRFwOperation from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.adapter import IRWeightReducer +from nnscaler.parallel import ComputeConfig, _load_parallel_module_class, parallelize +from ...utils import new_empty import torch import tempfile +import importlib from ...utils import replace_all_device_with @@ -27,8 +31,8 @@ class ReducerModule(torch.nn.Module): def __init__(self): super().__init__() - self.param1 = torch.nn.Parameter(torch.zeros([128, 128], dtype=torch.float16)) - self.param2 = torch.nn.Parameter(torch.zeros([128, 128], dtype=torch.float16)) + self.param1 = torch.nn.Parameter(torch.zeros([128, 128], dtype=torch.float32)) + self.param2 = torch.nn.Parameter(torch.zeros([128, 128], dtype=torch.float32)) def forward(self, x): x = torch.matmul(x, self.param1) @@ -44,7 +48,7 @@ def build_graph(): with tempfile.TemporaryDirectory() as tempdir: graph = convert_model( model, - {'x': torch.randn([128, 128], dtype=torch.float16)}, + {'x': torch.randn([128, 128], dtype=torch.float32)}, attr_savedir=tempdir, constant_folding=True ) @@ -117,3 +121,52 @@ def test_reducer_partially_shared_part(): with pytest.raises(RuntimeError): graph = IRAdapterGener.gen_weight(graph) print(graph.extra_repr()) + + +def pas_intra_reducer(graph: IRGraph, config: ComputeConfig): + dataloader = graph.nodes()[0] + sn0, sn1 = graph.replicate(dataloader, 2) + graph.assign(sn0, 0) + graph.assign(sn1, 1) + + fw_nodes = graph.select(ntype=IRFwOperation) + + for i, node in enumerate(fw_nodes): + if i == 1: + sn0, sn1 = graph.partition(node, node.algorithms('dim'), idx=1, dim=0, num=2) + else: + sn0, sn1 = graph.replicate(node, 2) + graph.assign(sn0, 0) + graph.assign(sn1, 1) + return graph + + +@replace_all_device_with('cpu') +def test_intra_scale_unit_reducers(): + compute_config = ComputeConfig( + plan_ngpus=2, + runtime_ngpus=4, + constant_folding=True, + use_zero=True, + use_end2end=True, + ) + model = ReducerModule() + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + model, + {'x': torch.randn([128, 128], dtype=torch.float32)}, + pas_intra_reducer, + compute_config, + gen_savedir=tempdir, + reuse='match', + load_module=False, + ) + for i in range(4): + module_class = _load_parallel_module_class(ReducerModule, gen_savedir=Path(tempdir), rank=i) + m = new_empty(module_class) + assert len(m.reducers) == 2 + reducer0, reducer1 = m.reducers + assert len(reducer0.params) == 1 + assert reducer0.params[0].shape == torch.Size([128, 128]) + assert len(reducer1.params) == 1 + assert reducer1.params[0].shape == torch.Size([64, 128]) From 59827669b37674725bd71d81baf23633bfdd12a0 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 22 Oct 2024 08:39:10 +0000 Subject: [PATCH 1745/1892] Merged PR 2282: add option to use reduce-scatter adapter add option to use reduce-scatter adapter --- nnscaler/flags.py | 9 +- nnscaler/graph/gener/rvd/inter.py | 33 ++- nnscaler/graph/gener/rvd/intra.py | 141 ++++++---- nnscaler/graph/gener/rvd/layout.py | 127 +++++++-- nnscaler/ir/adapter/prim.py | 10 +- nnscaler/ir/tensor.py | 1 + nnscaler/utils.py | 12 + .../{check_inter_rvd.py => test_inter_rvd.py} | 56 ++-- .../{check_intra_rvd.py => test_intra_rvd.py} | 260 +++++++++++++----- tests/graph/gener/test_layout.py | 44 +++ tests/parallel_module/test_gencode.py | 73 +++++ tests/test_utils.py | 31 ++- 12 files changed, 622 insertions(+), 175 deletions(-) rename tests/graph/gener/{check_inter_rvd.py => test_inter_rvd.py} (62%) rename tests/graph/gener/{check_intra_rvd.py => test_intra_rvd.py} (53%) create mode 100644 tests/graph/gener/test_layout.py diff --git a/nnscaler/flags.py b/nnscaler/flags.py index 55ad0418..cc8cb4bf 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -8,8 +8,8 @@ import os -def _to_bool(s: str) -> bool: - val = os.environ.get(s, default=0) +def _to_bool(s: str, default=False) -> bool: + val = os.environ.get(s, default=default) return bool(int(val)) @@ -34,6 +34,11 @@ class CompileFlag: disable_code_line_info = _to_bool('DISABLE_CODE_LINE_INFO') # will add original code information in generated code, note that this will make trace slow # how to execute the functions during trace, available choices ['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'] trace_strategy = os.environ.get('TRACE_STRATEGY', default='cuda_run_cpu_offload') + # reduce scatter adapter can reduce the communication cost, and improve the performance + # but sometimes it may cause communication problems, so we provide an option to enable/disable it + # by default, we disable it and use allreduce+chunk instead. + # TODO: enable it by default after we fix the parity issue + enable_reduce_scatter_adapter = _to_bool('ENABLE_REDUCE_SCATTER_ADAPTER', False) # ============== runtime ==================== dev_mode = _to_bool('SINGLE_DEV_MODE') # allow to use python xx.py diff --git a/nnscaler/graph/gener/rvd/inter.py b/nnscaler/graph/gener/rvd/inter.py index 8640b7ef..926d4c1f 100644 --- a/nnscaler/graph/gener/rvd/inter.py +++ b/nnscaler/graph/gener/rvd/inter.py @@ -4,8 +4,6 @@ from typing import Callable, Dict, List, Tuple, Optional, Set, Union from functools import partial import numpy as np -import sys -import copy from nnscaler.ir.tensor import IRFullTensor @@ -17,7 +15,9 @@ from nnscaler.graph.gener.rvd.layout import RVDLayout from nnscaler.graph.gener.rvd.intra import IntraPathFinder -from nnscaler.graph.gener.utils import tensor_vd_repr + +from nnscaler.utils import classproperty +from nnscaler.flags import CompileFlag TShape = Tuple[int, ...] @@ -140,6 +140,8 @@ def transitionable(src_rvd: TRVD, dst_rvd: TRVD) -> Optional[Callable]: decd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 > d2] if len(incd) == 0 and len(decd) == 0: decd = [0] + # only support one dimension change + # this only happens when the device number changes if len(incd) + len(decd) != 1: return trans_fn if len(incd) == 1: incd = incd[0] @@ -215,11 +217,28 @@ class InterPathFinder: """ inter-RVD Path finder for generating communication plans for RVDLayout """ + # Key is configuration. + # Currently only CompileFlag.enable_reduce_scatter_adapter is considered + _config_cached_inter_nodes: Dict[Tuple, Dict[Tuple[TShape, int, int], Tuple[Tuple[InterRVD]]]] = {} + _config_cached_inter_edges: Dict[Tuple, Dict[Tuple[TShape, int, int], Tuple[np.ndarray]]] = {} + _config_cached_inter_paths: Dict[Tuple, Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]]] = {} + + @classproperty + def _cached_inter_nodes(cls): + return cls._config_cached_inter_nodes.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + + @classproperty + def _cached_inter_edges(cls): + return cls._config_cached_inter_edges.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) - _cached_inter_nodes: Dict[Tuple[TShape, int, int], Tuple[Tuple[InterRVD]]] = {} - _cached_inter_edges: Dict[Tuple[TShape, int, int], Tuple[np.ndarray]] = {} - _cached_inter_paths: Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]] = {} + @classproperty + def _cached_inter_paths(cls): + return cls._config_cached_inter_paths.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + # type annotation because type cannot be inferred from `classproperty` + _cached_inter_nodes: Dict[Tuple[TShape, int, int], Tuple[Tuple[InterRVD]]] + _cached_inter_edges: Dict[Tuple[TShape, int, int], Tuple[np.ndarray]] + _cached_inter_paths: Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]] @staticmethod def path(ilayout: RVDLayout, olayout: RVDLayout, cost_fn: Optional[Callable] = None) -> List[IRAdapterPrim]: @@ -396,7 +415,7 @@ def init_graph(ftensor: IRFullTensor, src_ndevs: int, dst_ndevs: int, cost_fn: C IntraPathFinder._cached_intra_edges[(shape, src_ndevs)] = src_edges IntraPathFinder._cached_intra_paths[(shape, src_ndevs)] = {} - if (shape, dst_ndevs) in InterPathFinder._cached_inter_edges: + if (shape, dst_ndevs) in IntraPathFinder._cached_intra_nodes: dst_nodes = IntraPathFinder._cached_intra_nodes[(shape, dst_ndevs)] dst_edges = IntraPathFinder._cached_intra_edges[(shape, dst_ndevs)] else: diff --git a/nnscaler/graph/gener/rvd/intra.py b/nnscaler/graph/gener/rvd/intra.py index 89c37dce..a5d10d15 100644 --- a/nnscaler/graph/gener/rvd/intra.py +++ b/nnscaler/graph/gener/rvd/intra.py @@ -25,6 +25,10 @@ from nnscaler.graph.gener.utils import tensor_vd_repr +from nnscaler.utils import classproperty +from nnscaler.flags import CompileFlag + + _logger = logging.getLogger(__name__) TShape = Tuple[int, ...] TRVD = Tuple[int, ...] @@ -104,7 +108,7 @@ def v2d(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: @staticmethod def r2d(rvd: TRVD, dim: int, chunks: int) -> Tuple: """ - intra-RVD primitive V->D: schunk + intra-RVD primitive R->D: schunk @param dim int: tensor axis @param chunks int: the number of chunks to transfer @@ -120,9 +124,8 @@ def r2d(rvd: TRVD, dim: int, chunks: int) -> Tuple: @staticmethod def r2v(rvd: TRVD, chunks: int) -> Tuple: """ - intra-RVD primitive V->D: schunk + intra-RVD primitive R->V: vchunk - @param dim int: tensor axis @param chunks int: the number of chunks to transfer @return rvd Tuple[int]: output RVD @@ -173,13 +176,16 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD) -> List[Tuple[RVDLayout, Li """ Transfer from source RVD to destination RVD. Get all possible device-placement choices for RVD + (for returned RVDLayout, only device placement are different.) given the fixed device placement of RVD. - @param src_layout RVDLayout: source ilayout - @param dst_rvd Tuple[int]: destination RVD + Args: + src_layout (RVDLayout): source ilayout + dst_rvd (Tuple[int, ...]): destination RVD - @return rets List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: - tuple of pairs of with each has a different device mapping. + Returns: + List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: + tuple of pairs of with each has a different device mapping. """ src_rvd = src_layout.vec if src_rvd == dst_rvd: return [(src_layout, [])] @@ -198,6 +204,7 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD) -> List[Tuple[RVDLayout, Li ilayouts: List[RVDLayout] = [src_layout] olayouts: List[RVDLayout] = [RVDLayout.grid(src_layout.ftensor, r=dst_rvd[0], v=dst_rvd[1], dims=dst_rvd[2:])] # setup ilayout choices + # add alternative choices for device placement with inner-transpose if decd in optional_dims: ftensor = src_layout.ftensor for k in range(2, src_rvd[decd]): @@ -228,11 +235,29 @@ class IntraPathFinder: """ intra-RVD Path finder for generating communication plans for RVDLayout """ - + # Key is configuration. + # Currently only CompileFlag.enable_reduce_scatter_adapter is considered # intra-shard: cached nodes. paths[shape][i][j] = List[int] of indices from (src -> dst] - _cached_intra_nodes: Dict[Tuple[TShape, int], Tuple[TRVD]] = {} - _cached_intra_edges: Dict[Tuple[TShape, int], np.ndarray] = {} - _cached_intra_paths: Dict[Tuple[TShape, int], Dict[TRVD, List[List[int]]]] = {} + _config_cached_intra_nodes: Dict[Tuple, Dict[Tuple[TShape, int], Tuple[TRVD]]] = {} + _config_cached_intra_edges: Dict[Tuple, Dict[Tuple[TShape, int], np.ndarray]] = {} + _config_cached_intra_paths: Dict[Tuple, Dict[Tuple[TShape, int], Dict[TRVD, List[List[int]]]]] = {} + + @classproperty + def _cached_intra_nodes(cls): + return cls._config_cached_intra_nodes.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + + @classproperty + def _cached_intra_edges(cls): + return cls._config_cached_intra_edges.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + + @classproperty + def _cached_intra_paths(cls): + return cls._config_cached_intra_paths.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + + # type annotation because type cannot be inferred from `classproperty` + _cached_intra_nodes: Dict[Tuple[TShape, int], Tuple[TRVD]] + _cached_intra_edges: Dict[Tuple[TShape, int], np.ndarray] + _cached_intra_paths: Dict[Tuple[TShape, int], Dict[TRVD, List[List[int]]]] @staticmethod def path(ilayout: RVDLayout, olayout: RVDLayout, @@ -307,17 +332,30 @@ def backup_path(ilayout: RVDLayout, olayout: RVDLayout, @staticmethod def device_align(ilayout: RVDLayout, olayout: RVDLayout, - rvd_path: Tuple[TRVD], _all_prims: Optional[None] = None) -> Tuple[bool, List[IRAdapterPrim]]: + rvd_path: Tuple[TRVD, ...], _all_prims: Optional[None] = None) -> Tuple[bool, List[IRAdapterPrim]]: """ Align devices for intra-RVD - - @param ilayouts RVDLayout: source layout - @param olayout RVDLayout: target layout with correct device mapping - @param rvd_hops: Tuple[TRVD]: the hops from ilayout to olayout, which - contains ilayout and olayout at beginning and last, respectively. - - @return success bool: True if found device, else False. - @return primitives List[IRAdapterPrim]: the correspoinding primitives + We recursively search for the correct device mapping from `ilayout` to `olayout` + The success of the search is determined by the device placement of `ilayout` and `olayout`. + + `rvd_path` is the transition path from ilayout to olayout. + The first item can be assumed ilayout (R/V/D are same but device placement may be different in recursive calls), + and the last item is olayout. + + The exit condition is when the length of `rvd_path` is 1, + which means ilayout and olayout have the same R/V/D, + and we just check the device placement are compatible (via `RVDLayout.align`). + + Args: + ilayouts (RVDLayout): source layout + olayout (RVDLayout): target layout with correct device mapping + rvd_hops (Tuple[TRVD, ...]): the hops from ilayout to olayout, which + contains ilayout and olayout at beginning and last, respectively. + _all_prims (List[IRAdapterPrim]): the previous primitives, only for recursive calls + Returns: + Tuple[bool, List[IRAdapterPrim]]: + - success bool: True if found device, else False. + - primitives List[IRAdapterPrim]: the correspoinding primitives """ _all_prims = [] if _all_prims is None else _all_prims assert ilayout.vec == rvd_path[0] and olayout.vec == rvd_path[-1] @@ -424,15 +462,16 @@ def get_backup_path(ftensor: IRFullTensor, src_rvd: TRVD, dst_rvd: TRVD, return left + right[1:] @staticmethod - def get_device_space(ftensor: IRFullTensor, rvd_paths: List[TRVD], placement: Tuple[int]) -> Set[Tuple[int]]: + def get_device_space(ftensor: IRFullTensor, rvd_paths: List[TRVD], placement: Tuple[int, ...]) -> Set[Tuple[int, ...]]: """ - Get all possible device placement of the last RVD given the rvd transition paths. + Get all possible device placement of the destination RVD given the rvd transition paths. - @param ftensor IRFullTensor - @param rvd_paths Tuple[TRVDS]: transition RVD paths from source to destination - @param placement Tuple[int]: device placement of the first RVD in rvd_paths - - @return placements Set[Tuple[int]]: all possible device placement + Args: + ftensor (IRFullTensor): the full tensor + rvd_paths (List[TRVDS]): transition RVD paths from source to destination + placement (Tuple[int, ...]): device placement of the first RVD in rvd_paths + Returns: + Set[Tuple[int, ...]]: all possible device placement of the destination RVD """ init, hops = rvd_paths[0], rvd_paths[1:] rvds: List[RVDLayout] = [RVDLayout.grid(ftensor, r=init[0], v=init[1], dims=init[2:], devices=placement)] @@ -475,16 +514,13 @@ def init_graph(ftensor: IRFullTensor, ndevs: int, cost_fn: Optional[Callable] = @staticmethod def get_rvd_space(ftensor: IRFullTensor, ndevs: int) -> List[Tuple[int, ...]]: """ - Get all possible RVD representations given ftensor. - - This space is pruned by limiting partition number of each RVD dimension - in the range of [min(ilayout[dim], olayout[dim]), max(ilayout[dim], olayout[dim])] - - @param ftensor IRFullTensor - @param ilayout GridLayout: input layout - @param olayout GridLayout: output layout + Get all possible RVD representations given ftensor and device number. - @return layouts List[GridLayout]: + Args: + ftensor (IRFullTensor): the full tensor + ndevs (int): the number of devices + Returns: + List[Tuple[int, ...]]: all possible RVD representations """ all_layouts: List[int] = [] @@ -616,21 +652,22 @@ def advice(shape: TShape, fw_src_rvd: TRVD, fw_dst_rvd: TRVD, bw_src_rvd: Optional[TRVD], bw_dst_rvd: Optional[TRVD], src_placement: List[int], - cost_fn: Optional[Callable] = None) -> Tuple[Tuple[int], float]: - """ - Search for a good device placement for - source and destination RVD partition - - @param shape Tuple[int]: full tensor shape - @param fw_src_rvd Tuple[int]: forward producer RVD layout vector - @param fw_dst_rvd Tuple[int]: forward consumer RVD layout vector - @param bw_src_rvd Optional[Tuple[int]]: backward producer RVD layout vector - @param bw_dst_rvd Optional[Tuple[int]]: backward consumer RVD layout vector - @param cost_fn Optional[Callable]: cost function of each primitive. - Default (None) will use communication volume as metrics - - @return devices Tuple[int]: device sequence for RVD tensors - @return cost float: Cost of communication plan + cost_fn: Optional[Callable] = None) -> Tuple[Tuple[int, ...], float]: + """ + Search for a good device placement for destination RVD partition (fw_dst_rvd and bw_src_rvd) + + Args: + shape (Tuple[int]): full tensor shape + fw_src_rvd (TRVD): forward producer RVD layout vector + fw_dst_rvd (TRVD): forward consumer RVD layout vector + bw_src_rvd (Optional[TRVD]): backward producer RVD layout vector + bw_dst_rvd (Optional[TRVD]): backward consumer RVD layout vector + src_placement (List[int]): device placement of source RVD + cost_fn (Optional[Callable]): cost function of each primitive. + Default (None) will use communication volume as metrics + Returns: + Tuple[int, ...]: best device placement for RVD tensors + float: Cost of communication plan """ src_placement = tuple(src_placement) ftensor = IRFullTensor(shape, dtype=torch.float16) @@ -661,6 +698,8 @@ def advice(shape: TShape, placement = None # - if find, choose one + # FIXME: looks the above code (`devices = fw_consumer_devices` as a fallback) should be removed. + # so here we check whether we have found a valid placement. if len(devices) > 0: placement = list(devices)[0] # - if not find, keep forward one as optimal and adopt backup plan for backward one diff --git a/nnscaler/graph/gener/rvd/layout.py b/nnscaler/graph/gener/rvd/layout.py index e20ba40b..f2ba7b0f 100644 --- a/nnscaler/graph/gener/rvd/layout.py +++ b/nnscaler/graph/gener/rvd/layout.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Dict, List, Tuple, Optional +from typing import Dict, Iterator, List, Tuple, Optional import copy import numpy as np @@ -19,8 +19,27 @@ class RVDLayout: This class assumes a full-tensor can only be uniformly partitioned / replicated on dimensions and values. - A partition plan N-dim tensor layout can be represented as + DNN clusters are usually equipped with homogeneous accelerator devices. + Therefore, most parallelization plans partition operators evenly. + Thus, a partition plan N-dim tensor layout can be simply represented as : R (replica), V (value), dim_i (dimension) + + which means: + 1) R(i), the tensor is replicated to i copies; + 2) V(j), value split, the tensor is decomposed to j copies with the same shape; + 3) D(k1,k2,...,kn), uniformly partition the tensor into k1 parts in + the first dimension, k2 parts in the second dimension, so on + so forth. + + We use RVD to denote the transformation of a tensor. + For example, R(1)V(2)D(1,2) indicates a 2-D pTensor + requires no replication, is decomposed into 2 vTensors with + the same shape, and each is partitioned into 2 vTensors by + partitioning the second axis. + Thus, R(1)V(2)D(1,2) can represent 4 vTensors. + + RVD can represent both producer vTensors and consumer vTensors + as they are both transformed from the pTensor. """ def __init__(self, ftensor: IRFullTensor, subtensors: List[IRSubTensor], mats: np.ndarray): @@ -73,7 +92,7 @@ def tensor(self, r: int, v: int, d: List[int]) -> IRSubTensor: def __repr__(self): dscp = f'T{self.ftensor._id}' return dscp - + def __copy__(self): tensors = [] for t in self.mat.flatten(): @@ -91,7 +110,7 @@ def align(self, layout) -> bool: @param layout RVDLayout - @return same bool: + @return same bool: """ if not isinstance(layout, RVDLayout): return False @@ -109,7 +128,23 @@ def align(self, layout) -> bool: def inner_transpose(self, dim: int, chunks: int): """ - transpose ordering of tensor within a dimension. + Transpose ordering of tensor within a dimension. + The only goal is to shuffle the tensors (but RVD values are the same) in a dimension + to try to find a better path. + + Currently only R abd V dim are using this function. + If dim is 0 (R), then the tensor is shuffled in the first dimension. + which means the dp units are shuffled. + For example, we have 8 devices, and R=4, chunks=2, then + before: devices of 0~3 replica: [0, 1], [2, 3], [4, 5], [6, 7] + after: devices of 0~3 replica: [0, 1], [4, 5], [2, 3], [6, 7] + If dim is 1 (V), we have similar behavior. + For example, we have 8 devices, and R=1 V=4, chunks=2, then + before: devices of 0~3 value partitions: [0, 1], [2, 3], [4, 5], [6, 7] + after: devices of 0~3 value partitions: [0, 1], [4, 5], [2, 3], [6, 7] + + You can see after the shuffle, nothing is changed except the device assignment order. + """ assert 0 <= dim and dim < len(self._mats.shape) assert self.vec[dim] % chunks == 0 @@ -125,6 +160,24 @@ def inner_transpose(self, dim: int, chunks: int): @staticmethod def dim2last(mat: np.ndarray, dim: int, chunk: int) -> np.ndarray: + """ + Move the dimension that needs to be operated on to the last. + So in the following operation we can operate on the last dimension, like + ``` + for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): + prims.append(primitive(itensors, otensors)) + ``` + For example, if we want to transform R(1)V(2)D(1, 4) to R(1)V(1)D(1, 8). + Essentially, we want to transform + `imat[*, *, 0, *, *]` and `imat[*, *, 1, *, *]` + to + `omat[*, *, 0, *, *, 0] and `omat[*, *, 0, *, *, 1]` + + and reshape omat to R(1)V(1)D(1, 8) + + We don't bother to use a nested for loop, instead, + we move the related dimension to the last, imat[*, *, V, *, *] -> imat[*, *, *, *, V] + """ shape = list(mat.shape) assert shape[dim] % chunk == 0 shape[dim] = shape[dim] // chunk @@ -135,9 +188,29 @@ def dim2last(mat: np.ndarray, dim: int, chunk: int) -> np.ndarray: return mat @staticmethod - def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int], devices: Optional[Tuple[int]] = None): + def grid(ftensor: IRFullTensor, r: int, v: int, dims: Tuple[int], devices: Optional[Tuple[int, ...]] = None): """ partition a ftensor using grid layout of + + For device assignment, if devices is not None, assign devices in order. + For example, you have 8 devices, and r=2, v=2, dims=(1, 2) Then + 1. Split devices into r groups, which mean the outmost is data parallelism. + So (0, 1, 2, 3) is a sub group, and (4, 5, 6, 7) is another sub group + These two sub groups are replicated. + 2. Split devices in each r-group into v groups. + V is for value parallelism. + When V > 1, the value is partitioned. + That happens when previous forward op splits reducer dimention (the `+` in dimop annoation). + For the example above, (0, 1, 2, 3) will be splitted into (0, 1) and (2, 3) + 3. Split devices in each v-group into dims groups. It is tensor parallelism, + and is the innermost. + So (0, 1) is splitted into (0,) and (1,) + + Please note that is not the only way to assign devices. But it is our best guess. + `.inner_transpose()` can be used to shuffle the tensor within a dimension, + and hope to find a match for devices + + TODO: We need to support more flexible device assignment. """ dims = tuple(dims) def dummy_assign(tensor: IRSubTensor, devid: int): @@ -147,7 +220,7 @@ def dummy_assign(tensor: IRSubTensor, devid: int): mats = np.empty((r, v) + dims, dtype=IRSubTensor) all_subtensors = [] - def iter_idx(dims: List[int]) -> Tuple[int]: + def iter_idx(dims: List[int]) -> Iterator[Tuple[int, ...]]: if len(dims) == 0: yield () else: @@ -182,30 +255,44 @@ def iter_idx(dims: List[int]) -> Tuple[int]: @staticmethod def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): """ - convert ftensor and subtensors into a RVDLayout. - - If failed, raise error + Convert ftensor and subtensors into a RVDLayout. + Here we requires all subtensors are well formed, and can be organized as R(...)V(...)D(...) format. + + Please note the devices are kept as it is, and may be different with how `.grid()` assigns the devices. + + Args: + ftensor (IRFullTensor): full tensor + subtensors (List[IRSubTensor]): subtensors of the full tensor. + Returns: + RVDLayout: rvd layout + Raises: + RuntimeError: if subtensors are not well formed. """ _replica: int = None _value: int = None _dims: List[int] = [None] * len(ftensor.shape) + # id(subtensor) -> [replica_index, value_index, dim1_index, dim2_index, ...] + # Plese note key is not subtensor.id, but id(subtensor) _tindex: Dict[int, List[int]] = dict() ndims = len(ftensor.shape) + # Key: subtensor id + # Please note subtensors with same indmap and valmap have the same tid. + # which indicates they are replicated. replicas: Dict[int, List[IRSubTensor]] = dict() vchunks: set = set() dchunks: List[set] = [set() for _ in range(ndims)] for subtensor in subtensors: - tid = id(subtensor) + oid = id(subtensor) # set up replica if subtensor.tid not in replicas: replicas[subtensor.tid] = [] - _tindex[tid] = [len(replicas[subtensor.tid])] + _tindex[oid] = [len(replicas[subtensor.tid])] replicas[subtensor.tid].append(subtensor) # setup value - _tindex[tid].append(subtensor.valmap[0]) + _tindex[oid].append(subtensor.valmap[0]) vchunks.add(subtensor.valmap[1]) # setup dimensions for dim in range(ndims): @@ -219,7 +306,7 @@ def togrid(ftensor: IRFullTensor, subtensors: List[IRSubTensor]): f"full nele: {fnele}, sub nele: {snele}, start: {start}" ) dchunks[dim].add(fnele // snele) - _tindex[tid].append(start // snele) + _tindex[oid].append(start // snele) # replica (R) nreplicas = set(len(ts) for ts in replicas.values()) if len(nreplicas) != 1: @@ -255,6 +342,7 @@ def draw(prvd: RVDLayout, crvd: RVDLayout, outfile: str) -> None: """ import matplotlib.pyplot as plt from matplotlib.patches import Rectangle + import matplotlib.axes max_dev = max( max(t.device[0] for t in prvd.subtensors), max(t.device[0] for t in crvd.subtensors) @@ -266,6 +354,7 @@ def draw(prvd: RVDLayout, crvd: RVDLayout, outfile: str) -> None: plt.close('all') plt.rcParams['figure.figsize'] = (4.0 * devlen, 7.0) fig, ax = plt.subplots() + ax: matplotlib.axes.Axes fontsize = 30 @@ -285,7 +374,7 @@ def draw_subtensor(t: IRSubTensor, xy: Tuple[int], color: str): subx_nchunks = t.parent.shape[1] // t.shape[1] subw = recflen / subx_nchunks subx = x + subw * (t.indmap[1][0] // t.shape[1]) - + suby_nchunks = t.parent.shape[0] // t.shape[0] subh = recflen / suby_nchunks suby = y + subh * (t.indmap[0][0] // t.shape[0]) @@ -293,7 +382,7 @@ def draw_subtensor(t: IRSubTensor, xy: Tuple[int], color: str): # if t.valmap != (0, 1): ax.text(x=x+recflen/2, y=y+recflen+recflen/2, s=f'val({t.valmap[0]}/{t.valmap[1]})', fontsize=fontsize, ha='center', va='center', color='black') - + subrec = Rectangle((subx, suby), subw, subh, color=color, ec='black', lw=2.0) ax.add_artist(rec) ax.add_artist(subrec) @@ -311,9 +400,10 @@ def draw_subtensor(t: IRSubTensor, xy: Tuple[int], color: str): ax.text(x=-1, y=0.5+recflen/2, s='Consumer', fontsize=fontsize, ha='center', va='center', color='black') - + for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(fontsize) + tick.label1.set_fontsize(fontsize) + tick.label2.set_fontsize(fontsize) ax.spines['bottom'].set_color('white') ax.spines['top'].set_color('white') @@ -323,4 +413,3 @@ def draw_subtensor(t: IRSubTensor, xy: Tuple[int], color: str): ax.get_yaxis().set_visible(False) plt.savefig(outfile) - \ No newline at end of file diff --git a/nnscaler/ir/adapter/prim.py b/nnscaler/ir/adapter/prim.py index d85250ef..b67acd36 100644 --- a/nnscaler/ir/adapter/prim.py +++ b/nnscaler/ir/adapter/prim.py @@ -9,6 +9,7 @@ import copy from nnscaler.ir.tensor import IRSubTensor, IndexMap, ValueMap +from nnscaler.flags import CompileFlag # the general adapter primitive class @@ -344,7 +345,6 @@ def __repr__(self) -> str: return f"{self.outputs()} = broadcast{self.device}({self.inputs()}, src={self.kwargs['src']})" - class AllReducePrim(CollectivePrim): """ non-differentiable allreduce @@ -396,9 +396,10 @@ def volume(self) -> int: Use ring-based communication cost """ ndevs = len(self.inputs()) - # FIXME: temporally disable reduce scatter in code generation - # which has parity issues for now. - return 100 * (ndevs - 1) * self.input(0).nelement() // ndevs + vol = (ndevs - 1) * self.input(0).nelement() // ndevs + if not CompileFlag.enable_reduce_scatter_adapter: + vol *= 100 + return vol def __repr__(self) -> str: return f'{self.outputs()} = reduce_scatter[{self.device}]({self.inputs()})' @@ -462,6 +463,7 @@ class VChunkPrim(CollectivePrim): """ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): super().__init__(itensors, otensors, **kwargs) + # FIXME: nnscaler.runtime.adapter.vchunk does not exist self.signature = 'nnscaler.runtime.adapter.vchunk' def volume(self) -> int: diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index b7b79861..a0526df2 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -487,6 +487,7 @@ def __init__(self, ftensor: IRFullTensor, **kwargs): """ Create an IRSubTensor. + Please note same sub-tensor (parent+indmap+valmap) will have the same tid @param ftensor IRFullTensor: the full tensor @param indmap IndexMap: index map diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 55104bca..b5209250 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -280,6 +280,18 @@ def select_many(data: Iterable[Any], fn: Callable[[Any], Iterable[Any]]) -> Iter yield from fn(item) +# ref: https://stackoverflow.com/questions/128573/using-property-on-classmethods +class classproperty(property): + """ + A simple class property decorator. + """ + def __get__(self, obj, objtype=None): + # obj will be None when accessed from the class like `MyClass.my_property` + return super(classproperty, self).__get__(objtype) + # This hack doesn't work for __set__ and __delete__. + # so here __set__ and __delete__ are not implemented, and the property is read-only + + class accum_mode: """Make cube execution in gradient accumulation mode. diff --git a/tests/graph/gener/check_inter_rvd.py b/tests/graph/gener/test_inter_rvd.py similarity index 62% rename from tests/graph/gener/check_inter_rvd.py rename to tests/graph/gener/test_inter_rvd.py index dc4cc2b5..82ba5a32 100644 --- a/tests/graph/gener/check_inter_rvd.py +++ b/tests/graph/gener/test_inter_rvd.py @@ -1,14 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -""" -Note this is not for test. - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - tests/adapter/test_inter_rvd.py -""" - from typing import List, Tuple import nnscaler from nnscaler.ir.tensor import IRFullTensor @@ -16,10 +8,7 @@ from nnscaler.graph.gener.rvd.inter import InterPathFinder import numpy as np -from nnscaler.graph.gener.utils import tensor_vd_repr - - -nnscaler.init() +from .test_intra_rvd import enable_reduce_scatter_adapter def factors(k: int, num: int) -> List[Tuple[int]]: @@ -36,7 +25,6 @@ def factors(k: int, num: int) -> List[Tuple[int]]: def test_one_f_case(): - fshape = [128, 256, 512] src_r, src_v, src_d = 1,4,(1,1,2) @@ -46,7 +34,7 @@ def test_one_f_case(): pndevs = np.prod(src_rvd) cndevs = np.prod(dst_rvd) - + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) pdevs = list(range(pndevs)) @@ -56,21 +44,44 @@ def test_one_f_case(): fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=cdevs) rvds = InterPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) - print(f"optimal path: {' -> '.join(str(rvd) for rvd in rvds)}") + assert rvds == (('p', 1, 4, 1, 1, 2), ('p', 1, 1, 4, 1, 2), ('c', 1, 1, 4, 1, 2), ('c', 2, 1, 2, 1, 2)) fprims = InterPathFinder.path(fp_rvd, fc_rvd) - for prim in fprims: - print(prim) + assert len(fprims) == 14 + # producer part, v->d, so reduce_scatter + assert fprims[0].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[0].device == [0, 2, 4, 6] + assert fprims[1].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[1].device == [1, 3, 5, 7] + # inter part move + src_devs = set() + dst_devs = set() + for i in range(8): + assert fprims[2 + i].signature == 'nnscaler.runtime.adapter.move' + src_devs.add(fprims[2 + i].kwargs['src']) + dst_devs.add(fprims[2 + i].kwargs['dst']) + + assert src_devs == set([0, 1, 2, 3, 4, 5, 6, 7]) + assert dst_devs == set([8, 9, 10, 11, 12, 13, 14, 15]) + + # consumer part, d->v, so all_gather + assert fprims[10].signature == 'nnscaler.runtime.adapter.all_gather' + assert fprims[10].device == [8, 12] + assert fprims[11].signature == 'nnscaler.runtime.adapter.all_gather' + assert fprims[11].device == [9, 13] + assert fprims[12].signature == 'nnscaler.runtime.adapter.all_gather' + assert fprims[12].device == [10, 14] + assert fprims[13].signature == 'nnscaler.runtime.adapter.all_gather' + assert fprims[13].device == [11, 15] def test_all_f_cases_fix_placement(): - fshape = [128, 256, 512] ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) pndevs = 4 cndevs = 8 - + ndims = len(fshape) + 2 for src_rvd in factors(pndevs, ndims): for dst_rvd in factors(cndevs, ndims): @@ -86,8 +97,5 @@ def test_all_f_cases_fix_placement(): rvds = InterPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) print(f"==> path: {'->'.join(str(rvd) for rvd in rvds)}") - -if __name__ == '__main__': - - # test_one_f_case() - test_all_f_cases_fix_placement() \ No newline at end of file + # should not raise any exception + assert True diff --git a/tests/graph/gener/check_intra_rvd.py b/tests/graph/gener/test_intra_rvd.py similarity index 53% rename from tests/graph/gener/check_intra_rvd.py rename to tests/graph/gener/test_intra_rvd.py index 84ce00d7..0c425d5a 100644 --- a/tests/graph/gener/check_intra_rvd.py +++ b/tests/graph/gener/test_intra_rvd.py @@ -1,14 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -""" -Note this is not for test. - -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - unit_test/graph/gener/test_intra_rvd.py -""" - from typing import List, Tuple import nnscaler from nnscaler.ir.tensor import IRFullTensor @@ -16,10 +8,17 @@ from nnscaler.graph.gener.rvd.intra import IntraPathFinder, IntraAutoPlacer, IntraTransition import numpy as np -from nnscaler.graph.gener.utils import tensor_vd_repr + +import pytest -nnscaler.init() +@pytest.fixture(autouse=True) +def enable_reduce_scatter_adapter(): + from nnscaler.flags import CompileFlag + old = CompileFlag.enable_reduce_scatter_adapter + CompileFlag.enable_reduce_scatter_adapter = True + yield + CompileFlag.enable_reduce_scatter_adapter = old def factors(k: int, num: int) -> List[Tuple[int]]: @@ -35,25 +34,28 @@ def factors(k: int, num: int) -> List[Tuple[int]]: return res -def test_intra_transition(): - +def test_intra_transition(tmp_path): fshape = [256, 256] ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) src = (1, 2, 1, 4) dst = (1, 1, 1, 8) - + devs = list(range(8)) src_rvd = RVDLayout.grid(ftensor, r=src[0], v=src[1], dims=src[2:], devices=devs) - rets = IntraTransition.transition(src, dst, src_rvd, True) + rets = IntraTransition.transition(src_rvd, dst) + assert len(rets) == 1 + ret = rets[0] + assert ret[0].vec == dst + assert len(ret[1]) == 4 # one prim will handle 2 devices + # v->d will generate reduce_scatter + assert all(p.signature == 'nnscaler.runtime.adapter.reduce_scatter' for p in ret[1]) for idx, (layout, prims) in enumerate(rets): - RVDInspector.draw(src_rvd, layout, f'rvd-trans-{idx}.png') - + RVDInspector.draw(src_rvd, layout, tmp_path / 'rvd-trans-{idx}.png') -def test_transition_space(): - +def test_transition_space(tmp_path): fshape = [256, 256] ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) @@ -61,27 +63,20 @@ def test_transition_space(): dst = (1, 1, 1, 8) devs = list(range(8)) - choices = IntraPathFinder.get_device_space(ftensor, [src, dst], src_placement=devs) - print('choices:', choices) - - reverse_choices = IntraPathFinder.get_device_space(ftensor, [src, dst], dst_placement=devs) - print('reverse_choices:', reverse_choices) + choices = IntraPathFinder.get_device_space(ftensor, [src, dst], placement=devs) + assert len(choices) == 1 + # 0/4, 1/5, 2/6, 3/7 have the same indmap, different valmap + # reduce_scatter will be generated + assert choices.pop() == (0, 4, 1, 5, 2, 6, 3, 7) - # draw reverse output + # draw output for idx, choice in enumerate(choices): src_rvd = RVDLayout.grid(ftensor, r=src[0], v=src[1], dims=src[2:], devices=devs) dst_rvd = RVDLayout.grid(ftensor, r=dst[0], v=dst[1], dims=dst[2:], devices=choice) - RVDInspector.draw(src_rvd, dst_rvd, f'rvd-{idx}.png') - - # draw reverse output - for idx, choice in enumerate(reverse_choices): - src_rvd = RVDLayout.grid(ftensor, r=src[0], v=src[1], dims=src[2:], devices=choice) - dst_rvd = RVDLayout.grid(ftensor, r=dst[0], v=dst[1], dims=dst[2:], devices=devs) - RVDInspector.draw(src_rvd, dst_rvd, f'rvd-reverse-{idx}.png') + RVDInspector.draw(src_rvd, dst_rvd, tmp_path / f'rvd-{idx}.png') def test_one_f_case(): - fshape = [128, 256, 512] src_r, src_v, src_d = 1,4,(1,1,2) @@ -89,7 +84,7 @@ def test_one_f_case(): src_rvd = (src_r, src_v) + src_d dst_rvd = (dst_r, dst_v) + dst_d ndevs = src_r * src_v * np.prod(np.array(src_d)) - + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) pdevs = list(range(ndevs)) @@ -99,15 +94,132 @@ def test_one_f_case(): fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=cdevs) rvds = IntraPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) - print(f"optimal path: {' -> '.join(str(rvd) for rvd in rvds)}") + # reduce-scatter(v2d) and then all-gather(d2r) + assert rvds == ((1, 4, 1, 1, 2), (1, 1, 4, 1, 2), (2, 1, 2, 1, 2)) fprims = IntraPathFinder.path(fp_rvd, fc_rvd) - for prim in fprims: - print(prim) + assert len(fprims) == 6 + # (1, 4, 1, 1, 2) => (1, 1, 4, 1, 2) + # here the device align is found with `inner_transpose` alternative. + assert fprims[0].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[0].device == [0, 2, 4, 6] + assert fprims[1].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[1].device == [1, 3, 5, 7] + # (1, 1, 4, 1, 2), (2, 1, 2, 1, 2) + assert fprims[2].signature == 'nnscaler.runtime.adapter.all_gather' + assert fprims[2].device == [0, 4] + assert fprims[3].signature == 'nnscaler.runtime.adapter.all_gather' + assert fprims[3].device == [1, 5] + assert fprims[4].signature == 'nnscaler.runtime.adapter.all_gather' + assert fprims[4].device == [2, 6] + assert fprims[5].signature == 'nnscaler.runtime.adapter.all_gather' + assert fprims[5].device == [3, 7] + + +def test_f_reducescatter_alltoall(): + # this functio is trying to reproduce the case where reduce-scatter + all2all are used + # which sometimes can lead some bugs + # but currently we still can't reproduce the bug + # this test case is for reference + fshape = [8, 8] + + src_r, src_v, src_d = 1,2,(1,2) + dst_r, dst_v, dst_d = 1,1,(4,1) + src_rvd = (src_r, src_v) + src_d + dst_rvd = (dst_r, dst_v) + dst_d + ndevs = src_r * src_v * np.prod(np.array(src_d)) + assert ndevs == 4 + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) -def test_all_f_cases_fix_placement(): + pdevs = list(range(ndevs)) + assert pdevs == [0, 1, 2, 3] + fp_rvd = RVDLayout.grid(ftensor, r=src_r, v=src_v, dims=src_d, devices=pdevs) + fp_subtensors = { + f.device: f for f in fp_rvd.mat.flatten() + } + cdevs = list(range(ndevs)) + assert cdevs == [0, 1, 2, 3] + fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=cdevs) + fc_subtensors = { + f.device: f for f in fc_rvd.mat.flatten() + } + + rvds = IntraPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) + # reduce-scatter(v2d) and then all-2-all(d2d) + assert rvds == ((1, 2, 1, 2), (1, 1, 2, 2), (1, 1, 4, 1)) + + fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + assert len(fprims) == 4 + # (1, 2, 1, 2) => (1, 1, 2, 2) + assert fprims[0].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[0].device == [0, 2] + assert fprims[0]._inputs[0].device == (0,) + assert fprims[0]._inputs[0].indmap == ((0,8), (0,4)) + assert fprims[0]._inputs[0].valmap == (0, 2) + assert fprims[0]._inputs[0] == fp_subtensors[fprims[0]._inputs[0].device] + + assert fprims[0]._inputs[1].device == (2,) + assert fprims[0]._inputs[1].indmap == ((0,8), (0,4)) + assert fprims[0]._inputs[1].valmap == (1, 2) + assert fprims[0]._inputs[1] == fp_subtensors[fprims[0]._inputs[1].device] + + assert fprims[0]._outputs[0].device == (0,) + assert fprims[0]._outputs[0].indmap == ((0,4), (0,4)) + assert fprims[0]._outputs[0].valmap == (0, 1) + + assert fprims[0]._outputs[1].device == (2,) + assert fprims[0]._outputs[1].indmap == ((4,8), (0,4)) + assert fprims[0]._outputs[1].valmap == (0, 1) + + assert fprims[1].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[1].device == [1, 3] + assert fprims[1]._inputs[0].device == (1,) + assert fprims[1]._inputs[0].indmap == ((0,8), (4,8)) + assert fprims[1]._inputs[0].valmap == (0, 2) + assert fprims[1]._inputs[0] == fp_subtensors[fprims[1]._inputs[0].device] + + assert fprims[1]._inputs[1].device == (3,) + assert fprims[1]._inputs[1].indmap == ((0,8), (4,8)) + assert fprims[1]._inputs[1].valmap == (1, 2) + assert fprims[1]._inputs[1] == fp_subtensors[fprims[1]._inputs[1].device] + + assert fprims[1]._outputs[0].device == (1,) + assert fprims[1]._outputs[0].indmap == ((0,4), (4,8)) + assert fprims[1]._outputs[0].valmap == (0, 1) + + assert fprims[1]._outputs[1].device == (3,) + assert fprims[1]._outputs[1].indmap == ((4,8), (4,8)) + assert fprims[1]._outputs[1].valmap == (0, 1) + + # (1, 1, 2, 2) => (1, 1, 4, 1) d2d + assert fprims[2].signature == 'nnscaler.runtime.adapter.all_to_all' + assert fprims[2].device == [0, 1] + assert fprims[2]._inputs[0] == fprims[0]._outputs[0] + assert fprims[2]._inputs[0].device == fprims[0]._outputs[0].device + assert fprims[2]._inputs[1] == fprims[1]._outputs[0] + assert fprims[2]._inputs[1].device == fprims[1]._outputs[0].device + + assert fprims[2]._outputs[0].device == (0,) + assert fprims[2]._outputs[1].device == (1,) + assert fprims[2]._outputs[0] == fc_subtensors[fprims[2]._outputs[0].device] + assert fprims[2]._outputs[1] == fc_subtensors[fprims[2]._outputs[1].device] + + assert fprims[3].signature == 'nnscaler.runtime.adapter.all_to_all' + assert fprims[3].device == [2, 3] + assert fprims[3]._inputs[0] == fprims[0]._outputs[1] + assert fprims[3]._inputs[0].device == fprims[0]._outputs[1].device + assert fprims[3]._inputs[1] == fprims[1]._outputs[1] + assert fprims[3]._inputs[1].device == fprims[1]._outputs[1].device + + assert fprims[3]._outputs[0].device == (2,) + assert fprims[3]._outputs[1].device == (3,) + assert fprims[3]._outputs[0] == fc_subtensors[fprims[3]._outputs[0].device] + assert fprims[3]._outputs[1] == fc_subtensors[fprims[3]._outputs[1].device] + + +def test_all_f_cases_fix_placement(): fshape = [128, 256, 512] ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) @@ -126,10 +238,11 @@ def test_all_f_cases_fix_placement(): fctensors = fc_rvd.subtensors fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + # the above code will not raise any exception + assert True def test_all_f_cases_auto_placement(): - fshape = [128, 256, 512] ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) @@ -142,7 +255,7 @@ def test_all_f_cases_auto_placement(): pdevs = list(range(ndevs)) fp_rvd = RVDLayout.grid(ftensor, r=src_rvd[0], v=src_rvd[1], dims=src_rvd[2:], devices=pdevs) - placement, cost = IntraAutoPlacer.auto_place( + placement, cost = IntraAutoPlacer.advice( ftensor.shape, src_rvd, dst_rvd, None, None, src_placement=pdevs @@ -151,10 +264,11 @@ def test_all_f_cases_auto_placement(): fprims = IntraPathFinder.path(fp_rvd, fc_rvd) print(f'cost: {cost}') + # the above code will not raise any exception + assert True def test_one_fb_case(): - fshape = [128, 256, 512] # forward @@ -176,10 +290,12 @@ def test_one_fb_case(): bc_rvd = RVDLayout.grid(btensor, r=bdst_r, v=bdst_v, dims=bdst_d, devices=fpdevs) # forward consumer / backward producer - fcdevs, _ = IntraAutoPlacer.auto_place( + fcdevs, _ = IntraAutoPlacer.advice( fshape, (fsrc_r, fsrc_v) + fsrc_d, (fdst_r, fdst_v) + fdst_d, (bsrc_r, bsrc_v) + bsrc_d, (bdst_r, bdst_v) + bdst_d, fpdevs) - + + assert fcdevs == (0, 2, 1, 3, 4, 6, 5, 7) + fc_rvd = RVDLayout.grid(ftensor, r=fdst_r, v=fdst_v, dims=fdst_d, devices=fcdevs) # print('forward consumer tensor:') # for t in fc_rvd.mat.flatten(): @@ -189,16 +305,32 @@ def test_one_fb_case(): fprims = IntraPathFinder.path(fp_rvd, fc_rvd) bprims = IntraPathFinder.path(bp_rvd, bc_rvd) - print('forward prims:') - for prim in fprims: - print('\t', prim) - print('backward prims:') - for prim in bprims: - print('\t', prim) + assert len(fprims) == 4 + assert fprims[0].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[0].device == [0, 2] + assert fprims[1].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[1].device == [1, 3] + assert fprims[2].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[2].device == [4, 6] + assert fprims[3].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[3].device == [5, 7] + + assert len(bprims) == 6 + assert bprims[0].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert bprims[0].device == [0, 4] + assert bprims[1].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert bprims[1].device == [2, 6] + assert bprims[2].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert bprims[2].device == [1, 5] + assert bprims[3].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert bprims[3].device == [3, 7] + assert bprims[4].signature == 'nnscaler.runtime.adapter.all_gather' + assert bprims[4].device == [0, 2, 4, 6] + assert bprims[5].signature == 'nnscaler.runtime.adapter.all_gather' + assert bprims[5].device == [1, 3, 5, 7] def test_all_fb_cases_fix_placement(): - fshape = [128, 256, 512] ndevs = 8 @@ -210,16 +342,16 @@ def test_all_fb_cases_fix_placement(): fdevs = list(range(ndevs)) fp = RVDLayout.grid(ftensor, r=fp_rvd[0], v=fp_rvd[1], dims=fp_rvd[2:], devices=fdevs) - + for fc_rvd in factors(ndevs, ndims): if fc_rvd[1] != 1: continue fc = RVDLayout.grid(ftensor, r=fc_rvd[0], v=fc_rvd[1], dims=fc_rvd[2:], devices=fdevs) - + # case1: forward replica -> backward replica bp_rvd = fc_rvd bc_rvd = (fp_rvd[0] * fp_rvd[1], 1) + fp_rvd[2:] print(f'test generating | fp rvd: {fp_rvd}, fc rvd: {fc_rvd}, bp rvd: {bp_rvd}, bc rvd: {bc_rvd}') - + bp = RVDLayout.grid(btensor, r=bp_rvd[0], v=bp_rvd[1], dims=bp_rvd[2:], devices=fdevs) bc = RVDLayout.grid(btensor, r=bc_rvd[0], v=bc_rvd[1], dims=bc_rvd[2:], devices=fdevs) @@ -237,9 +369,11 @@ def test_all_fb_cases_fix_placement(): fprims = IntraPathFinder.path(fp, fc) bprims = IntraPathFinder.path(bp, bc) + # the above code will not raise any exception + assert True -def test_all_fb_cases_advisor(): +def test_all_fb_cases_advisor(): fshape = [128, 256, 512] ndevs = 8 @@ -251,15 +385,15 @@ def test_all_fb_cases_advisor(): fdevs = list(range(ndevs)) fp = RVDLayout.grid(ftensor, r=fp_rvd[0], v=fp_rvd[1], dims=fp_rvd[2:], devices=fdevs) - + for fc_rvd in factors(ndevs, ndims): if fc_rvd[1] != 1: continue - + # case1: forward replica -> backward replica bp_rvd = fc_rvd bc_rvd = (fp_rvd[0] * fp_rvd[1], 1) + fp_rvd[2:] print(f'test generating | fp rvd: {fp_rvd}, fc rvd: {fc_rvd}, bp rvd: {bp_rvd}, bc rvd: {bc_rvd}') - + placement, cost = IntraAutoPlacer.advice( fshape, fp_rvd, fc_rvd, bp_rvd, bc_rvd, fdevs) @@ -281,17 +415,9 @@ def test_all_fb_cases_advisor(): fc = RVDLayout.grid(ftensor, r=fc_rvd[0], v=fc_rvd[1], dims=fc_rvd[2:], devices=placement) bp = RVDLayout.grid(btensor, r=bp_rvd[0], v=bp_rvd[1], dims=bp_rvd[2:], devices=placement) bc = RVDLayout.grid(btensor, r=bc_rvd[0], v=bc_rvd[1], dims=bc_rvd[2:], devices=fdevs) - + fprims = IntraPathFinder.path(fp, fc) bprims = IntraPathFinder.path(bp, bc) - -if __name__ == '__main__': - # test_intra_transition() - # test_transition_space() - # test_one_f_case() - # test_all_f_cases_fix_placement() - # test_all_f_cases_auto_placement() - # test_one_fb_case() - # test_all_fb_cases_fix_placement() - test_all_fb_cases_advisor() \ No newline at end of file + # the above code will not raise any exception + assert True diff --git a/tests/graph/gener/test_layout.py b/tests/graph/gener/test_layout.py new file mode 100644 index 00000000..8158fa0c --- /dev/null +++ b/tests/graph/gener/test_layout.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import List +import numpy as np + +from nnscaler.graph.gener.rvd.layout import RVDLayout +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor + + +def test_rvd_layout(): + fshape = [128, 256, 512] + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=True) + + fsrc_r, fsrc_v, fsrc_d = 2,8,(1,1,2) + + ndevs = fsrc_r * fsrc_v * np.prod(np.array(fsrc_d)) + + fpdevs = list(range(ndevs)) + fp_rvd = RVDLayout.grid(ftensor, r=fsrc_r, v=fsrc_v, dims=fsrc_d, devices=fpdevs) + assert True + assert fp_rvd.R == fsrc_r + assert fp_rvd.V == fsrc_v + assert fp_rvd.D == fsrc_d + assert len(fp_rvd.subtensors) == ndevs + assert fp_rvd.mat.shape == (fsrc_r, fsrc_v, *fsrc_d) + # 0/1 are replicated. They should be the same. + assert np.array_equal(fp_rvd.mat[0], fp_rvd.mat[1]) + mat = fp_rvd.mat[0] + # check valmap + for i in range(fp_rvd.V): + for j in mat[i].flatten(): + j: IRSubTensor + assert j.valmap == (i, fp_rvd.V) + + # check idxmap + mat: List[IRSubTensor] = fp_rvd.mat[0][0].flatten().tolist() + for i in range(0, len(mat)//2): + assert mat[i].indmap == ((0, fshape[0]), (0, fshape[1]), (0, fshape[2]//2)) + assert mat[i + len(mat)//2].indmap == ((0, fshape[0]), (0, fshape[1]), (fshape[2]//2, fshape[2])) + + # check device + assert all([s.device == (i,) for i, s in enumerate(fp_rvd.mat.flatten().tolist())]) + diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 44cb9373..fe860d61 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -8,6 +8,7 @@ import torch import pytest +from nnscaler.flags import CompileFlag import nnscaler.graph.function.dimops from nnscaler.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph @@ -1097,6 +1098,78 @@ def test_codegen_dictout(tmp_path): ) +class ReduceScatterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(512, 1024, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.linear(x) + x = self.relu(x) + return x + + +def pas_reduce_scatter(graph, cfg): + from nnscaler.ir import IRFwOperation, IRDataOperation + from nnscaler.policies import _tp, _replica + ngpus = cfg.plan_ngpus + + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'linear': + _tp(graph, node, list(range(ngpus)), 0, 1) + elif node.name == 'relu': + _tp(graph, node, list(range(ngpus)), 0, 0) + else: + _replica(graph, node, list(range(ngpus))) + return graph + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('enable_reduce_scatter_adapter', [True, False]) +def test_codegen_reduce_scatter(tmp_path, enable_reduce_scatter_adapter): + old = CompileFlag.enable_reduce_scatter_adapter + CompileFlag.enable_reduce_scatter_adapter = enable_reduce_scatter_adapter + m = ReduceScatterModule() + m.train() + parallelize( + m, + {'x': torch.randn(2, 512)}, + pas_reduce_scatter, + ComputeConfig(2, 2), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # With reduce-scatter, it should looks like: + # ... + # linear_40 = nnscaler.runtime.adapter.nn.reducescatter_allgather(linear_30, dim=0, ranks=[0, 1]) + # ... + + # without reduce-scatter, it should looks like: + # ... + # class Adapter24(torch.autograd.Function): + # @staticmethod + # def forward(ctx, linear_30): + # linear_18 = nnscaler.runtime.adapter.all_reduce(linear_30, ranks=[0, 1]) + # linear_40 = nnscaler.runtime.adapter.chunk(linear_18, dim=0, ranks=[0, 1]) + # return linear_40 + # @staticmethod + # def backward(ctx, glinear_48): + # glinear_25 = nnscaler.runtime.adapter.all_gather(glinear_48, dim=0, ranks=[0, 1]) + # return glinear_25 + # ... + CompileFlag.enable_reduce_scatter_adapter = old + if enable_reduce_scatter_adapter: + assert _gencode_contains(tmp_path, ReduceScatterModule, 0, + r"nnscaler.runtime.adapter.nn.reducescatter_allgather" + ) + else: + assert not _gencode_contains(tmp_path, ReduceScatterModule, 0, + r"nnscaler.runtime.adapter.nn.reducescatter_allgather" + ) + + class KwargsModule(torch.nn.Module): def forward(self, x): return x + torch.zeros_like(x, dtype=torch.float32) diff --git a/tests/test_utils.py b/tests/test_utils.py index dd6ec951..ffb2b2dd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ import pytest -from nnscaler.utils import select_many +from nnscaler.utils import select_many, classproperty def test_select_many(): @@ -11,3 +11,32 @@ def test_select_many(): assert list(select_many([1, [2, 3]], lambda k: k if isinstance(k, list) else [k])) == [1, 2, 3] with pytest.raises(TypeError): list(select_many([1, [2, 3]], lambda k: k)) + + +def test_classproperty_int(): + class A: + _x = 1234567 + @classproperty + def value(cls): + return cls._x + + assert A.value == 1234567 + assert id(A().value) == id(A.value) + + with pytest.raises(AttributeError): + A().value = 43 + + assert A.value == 1234567 + + +def test_classproperty_dict(): + class A: + _x = {} + @classproperty + def cfg(cls): + return cls._x.setdefault('a', {}) + + x = A.cfg + x[1] = 2 + assert A.cfg == {1: 2} + assert id(A().cfg) == id(x) From ce1e31a096f87804f5983f4dfe0f04fb009bfff1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 23 Oct 2024 09:05:14 +0000 Subject: [PATCH 1746/1892] Merged PR 2292: rvd: retry without reduce-scatter adapter when any adapters are invalid It is still impossible to repro the case when reduce-scatter adapter + all-to-all adapter introduces invalid communication. As a workaround, in this PR, if we find bad communication adapters, we will retry with reduce-scatter disabled. --- nnscaler/flags.py | 6 +- nnscaler/graph/gener/concurrent.py | 74 ++++++++++++++++++---- nnscaler/graph/gener/rvd/inter.py | 8 +-- nnscaler/graph/gener/rvd/intra.py | 12 ++-- nnscaler/ir/adapter/prim.py | 30 ++++++++- nnscaler/ir/tensor.py | 27 ++++++++ nnscaler/runtime/adapter/collectives.py | 26 ++++++-- tests/graph/gener/test_concurrent.py | 51 +++++++++++++++ tests/graph/gener/test_inter_rvd.py | 2 +- tests/graph/gener/test_intra_rvd.py | 43 ++++++++++++- tests/graph/parser/test_ast_transformer.py | 2 - tests/ir/test_tensor.py | 23 ++++++- tests/parallel_module/test_gencode.py | 12 ++-- tox.ini | 2 +- 14 files changed, 271 insertions(+), 47 deletions(-) create mode 100644 tests/graph/gener/test_concurrent.py diff --git a/nnscaler/flags.py b/nnscaler/flags.py index cc8cb4bf..77333987 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -35,10 +35,8 @@ class CompileFlag: # how to execute the functions during trace, available choices ['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'] trace_strategy = os.environ.get('TRACE_STRATEGY', default='cuda_run_cpu_offload') # reduce scatter adapter can reduce the communication cost, and improve the performance - # but sometimes it may cause communication problems, so we provide an option to enable/disable it - # by default, we disable it and use allreduce+chunk instead. - # TODO: enable it by default after we fix the parity issue - enable_reduce_scatter_adapter = _to_bool('ENABLE_REDUCE_SCATTER_ADAPTER', False) + # but sometimes it may cause communication bugs, so we provide an option to enable/disable it + disable_reduce_scatter_adapter = _to_bool('DISABLE_REDUCE_SCATTER_ADAPTER', False) # ============== runtime ==================== dev_mode = _to_bool('SINGLE_DEV_MODE') # allow to use python xx.py diff --git a/nnscaler/graph/gener/concurrent.py b/nnscaler/graph/gener/concurrent.py index 1c91ab1b..ee0fe1d8 100644 --- a/nnscaler/graph/gener/concurrent.py +++ b/nnscaler/graph/gener/concurrent.py @@ -8,9 +8,10 @@ import copy import numpy as np import logging +from contextlib import contextmanager from nnscaler.ir.tensor import IRFullTensor, IRSubTensor, IndexMap, ValueMap -from nnscaler.ir.adapter.prim import IRAdapterPrim +from nnscaler.ir.adapter.prim import IRAdapterPrim, ReduceScatterPrim, AllToAllPrim from nnscaler.ir.adapter import IRAdapter from nnscaler.ir.adapter.prim import SelectPrim, MovePrim, SumPrim, MergeDimPrim from nnscaler.ir.adapter.prim import BroadcastPrim @@ -31,10 +32,18 @@ _logger.warning('Detected disabling general communication fusion, which may have big impact on performance in certain cases.') +@contextmanager +def _temp_disable_reduce_scatter_adapter(): + assert not CompileFlag.disable_reduce_scatter_adapter, "Already disabled" + CompileFlag.disable_reduce_scatter_adapter = True + yield + CompileFlag.disable_reduce_scatter_adapter = False + + class ConcurrentGener: @staticmethod - def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], cost_fn: Optional[Callable] = None) -> Optional[IRAdapter]: """ @@ -90,7 +99,7 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], # warnings.warn('The adapter is generated using P2P communication') if fadapter is None: fadapter = ConcurrentGener.gen_general(fptensors, fctensors, bptensors, bctensors) - + if set(pdevs) == set(cdevs) and fadapter.mirror is not None: fadapter.differentiable = True fadapter.mirror.differentiable = True @@ -98,7 +107,46 @@ def gen(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], return fadapter @staticmethod - def gen_intra_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + def _path( + path_fn: Callable, + ilayout: RVDLayout, olayout: RVDLayout, + cost_fn: Optional[Callable] = None + ) -> List[IRAdapterPrim]: + prims = path_fn(ilayout, olayout, cost_fn) + if any(isinstance(prim, AllToAllPrim) and not prim.is_valid() for prim in prims): + if not CompileFlag.disable_reduce_scatter_adapter \ + and any(isinstance(prim, ReduceScatterPrim) for prim in prims): + _logger.warning( + 'Detected invalid AllToAllPrim, retrying with reduce-scatter disabled.' + 'Please report this issue to the developers.' + ) + # the problem may be caused by the ReduceScatterPrim + # let's retry without it. + with _temp_disable_reduce_scatter_adapter(): + prims = path_fn(ilayout, olayout, cost_fn) + + if any(not prim.is_valid() for prim in prims): + # will use `ConcurrentGener.gen_general` to generate adapter + raise RuntimeError('Invalid primitives detected. Please report this issue to the developers.') + + return prims + + @staticmethod + def _intra_path( + ilayout: RVDLayout, olayout: RVDLayout, + cost_fn: Optional[Callable] = None + ) -> List[IRAdapterPrim]: + return ConcurrentGener._path(IntraPathFinder.path, ilayout, olayout, cost_fn) + + @staticmethod + def _inter_path( + ilayout: RVDLayout, olayout: RVDLayout, + cost_fn: Optional[Callable] = None + ) -> List[IRAdapterPrim]: + return ConcurrentGener._path(InterPathFinder.path, ilayout, olayout, cost_fn) + + @staticmethod + def gen_intra_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], cost_fn: Optional[Callable] = None) -> IRAdapter: """ @@ -109,7 +157,7 @@ def gen_intra_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], @param bptensors List[IRSubTensor]: backward produced tensors @param bctensors List[IRSubTensor]: backward consumed tensors @param cost_fn Optional[Callable]: takes in an IRAdapterPrim and outputs a cost in float - + @return adapter IRAdapter: forward IRAdapter with backward (if has) in its .mirror attribute. """ ftensor = fptensors[0].parent @@ -124,7 +172,7 @@ def gen_intra_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], assert all(t is not None for t in ctensors), f"empty device slot {ctensors}" olayout = RVDLayout.togrid(ftensor, ctensors) # get forward primitives - fprims = IntraPathFinder.path(ilayout, olayout, cost_fn) + fprims = ConcurrentGener._intra_path(ilayout, olayout, cost_fn) fadapter = IRAdapter(fptensors, fctensors) fadapter.prims = fprims @@ -143,7 +191,7 @@ def gen_intra_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], ilayout = RVDLayout.togrid(grad, ptensors) olayout = RVDLayout.togrid(grad, bctensors) # paths, bprims = ilayout.path(olayout) - bprims = IntraPathFinder.path(ilayout, olayout, cost_fn) + bprims = ConcurrentGener._intra_path(ilayout, olayout, cost_fn) # generate backward adapter badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims @@ -152,7 +200,7 @@ def gen_intra_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], return fadapter @staticmethod - def gen_inter_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], + def gen_inter_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], bptensors: List[IRSubTensor], bctensors: List[IRSubTensor], cost_fn: Optional[Callable] = None) -> IRAdapter: """ @@ -170,7 +218,7 @@ def gen_inter_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], ftensor = fptensors[0].parent ilayout = RVDLayout.togrid(ftensor, fptensors) olayout = RVDLayout.togrid(ftensor, fctensors) - fprims = InterPathFinder.path(ilayout, olayout, cost_fn) + fprims = ConcurrentGener._inter_path(ilayout, olayout, cost_fn) fadapter = IRAdapter(fptensors, fctensors) fadapter.prims = fprims @@ -178,7 +226,7 @@ def gen_inter_rvd(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], if len(bptensors) > 0 or len(bctensors) > 0: ilayout = RVDLayout.togrid(grad, bptensors) olayout = RVDLayout.togrid(grad, bctensors) - bprims = InterPathFinder.path(ilayout, olayout, cost_fn) + bprims = ConcurrentGener._inter_path(ilayout, olayout, cost_fn) badapter = IRAdapter(bptensors, bctensors) badapter.prims = bprims IRAdapter.make_pair(fadapter, badapter) @@ -189,7 +237,7 @@ def gen_general(fptensors: List[IRSubTensor], fctensors: List[IRSubTensor], bptensors: List[IRSubTensor], bctensors: List[IRSubTensor]) -> IRAdapter: """ A general way to generate adapter. - + @param ftensor IRFullTensor @return adapter IRAdapter """ @@ -250,7 +298,7 @@ def gen_subtensor_coll(ctensors: List[IRSubTensor], ptensors: List[IRSubTensor], fuse_broadcast = False break # fuse to broadcast - if fuse_broadcast: + if fuse_broadcast: cdev_tensors, pdev_tensors = dict(), dict() for ptensor in ptensors: pdev_tensors.setdefault(ptensor.device[0], []).append(ptensor) @@ -278,7 +326,7 @@ def gen_subtensor_coll(ctensors: List[IRSubTensor], ptensors: List[IRSubTensor], def gen_subtensor(ctensor: IRSubTensor, ptensors: List[IRSubTensor], workload: Dict[int, int]) -> List[IRAdapterPrim]: """ Generate communiction primitives for ctensor - + @param ctensor IRSubTensor: the consumed tensor as destination @param ptensors List[IRSubTensor]: the produced tensors as source diff --git a/nnscaler/graph/gener/rvd/inter.py b/nnscaler/graph/gener/rvd/inter.py index 926d4c1f..a57a8a21 100644 --- a/nnscaler/graph/gener/rvd/inter.py +++ b/nnscaler/graph/gener/rvd/inter.py @@ -218,22 +218,22 @@ class InterPathFinder: inter-RVD Path finder for generating communication plans for RVDLayout """ # Key is configuration. - # Currently only CompileFlag.enable_reduce_scatter_adapter is considered + # Currently only CompileFlag.disable_reduce_scatter_adapter is considered _config_cached_inter_nodes: Dict[Tuple, Dict[Tuple[TShape, int, int], Tuple[Tuple[InterRVD]]]] = {} _config_cached_inter_edges: Dict[Tuple, Dict[Tuple[TShape, int, int], Tuple[np.ndarray]]] = {} _config_cached_inter_paths: Dict[Tuple, Dict[Tuple[TShape, int, int], Dict[TRVD, List[List[int]]]]] = {} @classproperty def _cached_inter_nodes(cls): - return cls._config_cached_inter_nodes.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + return cls._config_cached_inter_nodes.setdefault((CompileFlag.disable_reduce_scatter_adapter,), {}) @classproperty def _cached_inter_edges(cls): - return cls._config_cached_inter_edges.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + return cls._config_cached_inter_edges.setdefault((CompileFlag.disable_reduce_scatter_adapter,), {}) @classproperty def _cached_inter_paths(cls): - return cls._config_cached_inter_paths.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + return cls._config_cached_inter_paths.setdefault((CompileFlag.disable_reduce_scatter_adapter,), {}) # type annotation because type cannot be inferred from `classproperty` _cached_inter_nodes: Dict[Tuple[TShape, int, int], Tuple[Tuple[InterRVD]]] diff --git a/nnscaler/graph/gener/rvd/intra.py b/nnscaler/graph/gener/rvd/intra.py index a5d10d15..2783019e 100644 --- a/nnscaler/graph/gener/rvd/intra.py +++ b/nnscaler/graph/gener/rvd/intra.py @@ -226,7 +226,7 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD) -> List[Tuple[RVDLayout, Li otensor.cell = itensor.cell prims = [] for itensors, otensors in zip(imat.reshape(-1, chunks), omat.reshape(-1, chunks)): - prims.append(primitive(itensors, otensors)) + prims.append(primitive(itensors.tolist(), otensors.tolist())) rets.append((olayout, prims)) return rets @@ -236,7 +236,7 @@ class IntraPathFinder: intra-RVD Path finder for generating communication plans for RVDLayout """ # Key is configuration. - # Currently only CompileFlag.enable_reduce_scatter_adapter is considered + # Currently only CompileFlag.disable_reduce_scatter_adapter is considered # intra-shard: cached nodes. paths[shape][i][j] = List[int] of indices from (src -> dst] _config_cached_intra_nodes: Dict[Tuple, Dict[Tuple[TShape, int], Tuple[TRVD]]] = {} _config_cached_intra_edges: Dict[Tuple, Dict[Tuple[TShape, int], np.ndarray]] = {} @@ -244,15 +244,15 @@ class IntraPathFinder: @classproperty def _cached_intra_nodes(cls): - return cls._config_cached_intra_nodes.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + return cls._config_cached_intra_nodes.setdefault((CompileFlag.disable_reduce_scatter_adapter,), {}) @classproperty def _cached_intra_edges(cls): - return cls._config_cached_intra_edges.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + return cls._config_cached_intra_edges.setdefault((CompileFlag.disable_reduce_scatter_adapter,), {}) @classproperty def _cached_intra_paths(cls): - return cls._config_cached_intra_paths.setdefault((CompileFlag.enable_reduce_scatter_adapter,), {}) + return cls._config_cached_intra_paths.setdefault((CompileFlag.disable_reduce_scatter_adapter,), {}) # type annotation because type cannot be inferred from `classproperty` _cached_intra_nodes: Dict[Tuple[TShape, int], Tuple[TRVD]] @@ -563,7 +563,7 @@ def estimate_cost(ftensor: IRFullTensor, rvd_paths: List[Tuple[TRVD]], cost_fn: olayout: RVDLayout = RVDLayout.grid(ftensor, r=hop[0], v=hop[1], dims=hop[2:]) imat = RVDLayout.dim2last(ilayout.mat, decd, chunks) omat = RVDLayout.dim2last(olayout.mat, incd, chunks) - prim = primitive(imat.reshape(-1, chunks)[0], omat.reshape(-1, chunks)[0]) + prim = primitive(imat.reshape(-1, chunks)[0].tolist(), omat.reshape(-1, chunks)[0].tolist()) cost += cost_fn(prim) src = hop return cost diff --git a/nnscaler/ir/adapter/prim.py b/nnscaler/ir/adapter/prim.py index b67acd36..46256f5a 100644 --- a/nnscaler/ir/adapter/prim.py +++ b/nnscaler/ir/adapter/prim.py @@ -14,7 +14,6 @@ # the general adapter primitive class class IRAdapterPrim: - def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor], **kwargs): self._inputs = list(inputs) self._outputs = list(outputs) @@ -26,6 +25,17 @@ def __init__(self, inputs: List[IRSubTensor], outputs: List[IRSubTensor], **kwar # whether the primitive is happened locally self.local: bool = False + def is_valid(self) -> bool: + """ + check if the input to the adapter primitive is valid + """ + # TODO: put this check to the constructor + # In current implementation of RVDLayout optimal path search + # Invalid inputs can be generated, but then discarded later. + # In order to keep current flow, let's disable this check in construction, + # and call it after all the prims are generated + return True + def input(self, idx:int): return self._inputs[idx] @@ -397,7 +407,7 @@ def volume(self) -> int: """ ndevs = len(self.inputs()) vol = (ndevs - 1) * self.input(0).nelement() // ndevs - if not CompileFlag.enable_reduce_scatter_adapter: + if CompileFlag.disable_reduce_scatter_adapter: vol *= 100 return vol @@ -434,6 +444,22 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], idi super().__init__(itensors, otensors, idim=idim, odim=odim, **kwargs) self.signature = 'nnscaler.runtime.adapter.all_to_all' + def is_valid(self) -> bool: + """ + check if the input to all-to-all primitive is valid + """ + indmaps = [t.indmap for t in self._inputs] + + idim = self.kwargs['idim'] + odim = self.kwargs['odim'] + + # odim should be the same for all input tensors + for i in range(1, len(indmaps)): + if indmaps[i][odim] != indmaps[0][odim]: + return False + + return IRSubTensor.is_dim_continous(self._inputs, idim) + def volume(self) -> int: ndevs = len(self.inputs()) return self.input(0).nelement() * (ndevs - 1) // ndevs diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index a0526df2..b385209b 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -914,6 +914,33 @@ def common(self, other) -> Optional[IRTensor]: return sub_tensor return None + @classmethod + def is_dim_continous(cls, tensors: List['IRSubTensor'], dim: int) -> bool: + """ + Check if the tensors are continuous along a dimension + + Args: + tensors (List[IRSubTensor]): the tensors to check + dim (int): the dimension + Returns: + bool: True if continuous + Raises: + ValueError: if `tensors` is empty + """ + if not tensors: + raise ValueError("Expected a non-empty tensor list") + if dim < 0: + dim += len(tensors[0].shape) + if dim < 0 or dim >= len(tensors[0].shape): + raise ValueError(f"Expected 0 <= dim < {len(tensors[0].shape)}. Got {dim}") + indmaps = [t.indmap[dim] for t in tensors] + indmaps.sort() + # [start, end) should be continuous after sorted + for idx in range(1, len(indmaps)): + if indmaps[idx][0] != indmaps[idx-1][1]: + return False + return True + def __repr__(self) -> str: anno = 't' if self.is_attr(): diff --git a/nnscaler/runtime/adapter/collectives.py b/nnscaler/runtime/adapter/collectives.py index 5ec6fd76..2a3bafb3 100644 --- a/nnscaler/runtime/adapter/collectives.py +++ b/nnscaler/runtime/adapter/collectives.py @@ -108,8 +108,26 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, def all_to_all(tensor: torch.Tensor, idim: int, odim: int, - ranks: Tuple[int], async_op=False) -> torch.Tensor: - """All-to-all""" + ranks: Tuple[int, ...], async_op=False) -> torch.Tensor: + """ + All-to-all (but different with torch.distributed.all_to_all) + + 1. Each device will split the tensor into `len(ranks)` chunks on `odim` + 2. Send each chunk to the corresponding device with `torch.distributed.all_to_all`. + 3. Concatenate the received chunks on `idim`. + + So the overall work is to change the tensor partitioning from `idim` to `odim`. + + Args: + tensor (torch.Tensor): input tensor + idim (int): the dimension to concatenate the received chunks + odim (int): the dimension to split the tensor + ranks (Tuple[int]): the order of split tensor. + async_op (bool): whether to use async communication + + Returns: + torch.Tensor: the output tensor + """ if not async_op: CudaTimer().start(field_name='comm', predefined=True) itensors = list(tensor.chunk(len(ranks), dim=odim)) @@ -139,11 +157,11 @@ def all_to_all_single(tensor: torch.Tensor, idim: int, odim: int, group = DeviceGroup().get_group(ranks) otensor = torch.empty_like(tensor) work = torch.distributed.all_to_all_single(otensor, tensor, group=group, async_op=async_op) - + def all2all_callback(t): t = t.transpose(0, odim) if odim != 0 else t return torch.concat(tuple(t.chunk(len(ranks), dim=odim)), dim=idim) - + if work: AsyncCommHandler().submit(tensor, [work], all2all_callback) else: diff --git a/tests/graph/gener/test_concurrent.py b/tests/graph/gener/test_concurrent.py new file mode 100644 index 00000000..26e192ba --- /dev/null +++ b/tests/graph/gener/test_concurrent.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph.gener.concurrent import ConcurrentGener, CompileFlag, \ + AllToAllPrim, ReduceScatterPrim, _logger +from ...utils import catch_log + + +def test_path_retry(): + ftensor = IRFullTensor((128, 512), requires_grad=True) + indmap = [] + for dimlen in ftensor.shape: + indmap.append((0, dimlen)) + indmap[0] = (0, 2) + sub1 = ftensor.select(tuple(indmap), (0, 1)) + indmap[0] = (2, 4) + sub2 = ftensor.select(tuple(indmap), (0, 1)) + indmap[0] = (4, 6) + sub3 = ftensor.select(tuple(indmap), (0, 1)) + + wrong_called = False + right_called = False + def path_with_reduce_scatter(*args, **kwargs): + nonlocal wrong_called, right_called + if not CompileFlag.disable_reduce_scatter_adapter: + # the parameter is fake, just for testing + wrong_called = True + return [ReduceScatterPrim([sub1, sub2], [sub3], dim=0), AllToAllPrim([sub1, sub3], [sub2], idim=0, odim=1)] + else: + right_called = True + return [AllToAllPrim([sub1, sub2], [sub3], idim=0, odim=1)] + + with catch_log(_logger, 'WARNING') as log_stream: + assert ConcurrentGener._path(path_with_reduce_scatter, None, None, None) + assert right_called and wrong_called + assert 'Detected invalid AllToAllPrim' in log_stream.getvalue() + + called = 0 + def path_without_rc(*args, **kwargs): + nonlocal called + called += 1 + return [AllToAllPrim([sub1, sub3], [sub2], idim=0, odim=1)] + + with pytest.raises(RuntimeError, match='Invalid primitives detected.*'): + with catch_log(_logger) as log_stream: + ConcurrentGener._path(path_without_rc, None, None, None) + + assert called == 1 + assert 'Detected invalid AllToAllPrim' not in log_stream.getvalue() diff --git a/tests/graph/gener/test_inter_rvd.py b/tests/graph/gener/test_inter_rvd.py index 82ba5a32..08d7fe33 100644 --- a/tests/graph/gener/test_inter_rvd.py +++ b/tests/graph/gener/test_inter_rvd.py @@ -8,7 +8,7 @@ from nnscaler.graph.gener.rvd.inter import InterPathFinder import numpy as np -from .test_intra_rvd import enable_reduce_scatter_adapter +from .test_intra_rvd import enable_reduce_scatter_adapter # noqa def factors(k: int, num: int) -> List[Tuple[int]]: diff --git a/tests/graph/gener/test_intra_rvd.py b/tests/graph/gener/test_intra_rvd.py index 0c425d5a..1a784a1a 100644 --- a/tests/graph/gener/test_intra_rvd.py +++ b/tests/graph/gener/test_intra_rvd.py @@ -15,10 +15,10 @@ @pytest.fixture(autouse=True) def enable_reduce_scatter_adapter(): from nnscaler.flags import CompileFlag - old = CompileFlag.enable_reduce_scatter_adapter - CompileFlag.enable_reduce_scatter_adapter = True + old = CompileFlag.disable_reduce_scatter_adapter + CompileFlag.disable_reduce_scatter_adapter = False yield - CompileFlag.enable_reduce_scatter_adapter = old + CompileFlag.disable_reduce_scatter_adapter = old def factors(k: int, num: int) -> List[Tuple[int]]: @@ -219,6 +219,43 @@ def test_f_reducescatter_alltoall(): assert fprims[3]._outputs[1] == fc_subtensors[fprims[3]._outputs[1].device] +def print_prims(fp_rvd, fc_rvd, fprims): + print('fp_rvd:') + for f in fp_rvd.mat.flatten(): + print(f'\tdevice({f.device[0]}): indmap({f.indmap}) | valmap({f.valmap})') + + print('prims:') + for f in fprims: + print(f.signature) + for i, t in enumerate(f._inputs): + print(f'\tinput {i}: device({t.device[0]}) | indmap({t.indmap}) | valmap({t.valmap})') + for i, t in enumerate(f._outputs): + print(f'\toutput {i}: device({t.device[0]}) | indmap({t.indmap}) | valmap({t.valmap})') + + print('fc_rvd:') + for f in fc_rvd.mat.flatten(): + print(f'\tdevice({f.device[0]}): indmap({f.indmap}) | valmap({f.valmap})') + + +def test_align(): + fshape = [8, 8] + + src_r, src_v, src_d = 1,2,(1,2) + dst_r, dst_v, dst_d = 1,1,(4,1) + + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + + pdevs = [0, 2, 1, 3] + fp_rvd = RVDLayout.grid(ftensor, r=src_r, v=src_v, dims=src_d, devices=pdevs) + + cdevs = [0, 1, 2, 3] + fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=cdevs) + + rvds = ((1,2,1,2), (1, 1, 1, 4), (1, 1, 4, 1)) + align, all_prims = IntraPathFinder.device_align(fp_rvd, fc_rvd, rvds) + assert True + + def test_all_f_cases_fix_placement(): fshape = [128, 256, 512] ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 3aa0f13a..00f1bbde 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -68,8 +68,6 @@ def test_ifexpr_transfomer(): assert not modified - - @pytest.mark.skipif(sys.version_info < (3, 9), reason='ast.unparse is not available in python3.8') def test_op_transfomer(): tree = ast.parse(dedent(''' diff --git a/tests/ir/test_tensor.py b/tests/ir/test_tensor.py index 724c88c0..59e528f4 100644 --- a/tests/ir/test_tensor.py +++ b/tests/ir/test_tensor.py @@ -3,9 +3,10 @@ from nnscaler.ir.tensor import IRSubTensor, IRFullTensor +import pytest -def test_tensor_grad(): +def test_tensor_grad(): ftensor = IRFullTensor((128, 512), requires_grad=True) subtensor = ftensor.tosub() @@ -18,3 +19,23 @@ def test_tensor_grad(): assert ftensor.grad is None assert subtensor.grad is None assert subtensor.requires_grad is False + + +def test_continous(): + ftensor = IRFullTensor((128, 512), requires_grad=True) + with pytest.raises(ValueError): + IRSubTensor.is_dim_continous([], dim=0) + + indmap = [] + for dimlen in ftensor.shape: + indmap.append((0, dimlen)) + indmap[0] = (0, 2) + sub1 = ftensor.select(tuple(indmap), (0, 1)) + indmap[0] = (2, 4) + sub2 = ftensor.select(tuple(indmap), (0, 1)) + indmap[0] = (4, 6) + sub3 = ftensor.select(tuple(indmap), (0, 1)) + + assert IRSubTensor.is_dim_continous([sub1, sub2, sub3], dim=0) + assert not IRSubTensor.is_dim_continous([sub1, sub2, sub3], dim=1) + assert not IRSubTensor.is_dim_continous([sub1, sub3], dim=0) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index fe860d61..5dd2c3c6 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -1126,10 +1126,10 @@ def pas_reduce_scatter(graph, cfg): @replace_all_device_with('cpu') -@pytest.mark.parametrize('enable_reduce_scatter_adapter', [True, False]) -def test_codegen_reduce_scatter(tmp_path, enable_reduce_scatter_adapter): - old = CompileFlag.enable_reduce_scatter_adapter - CompileFlag.enable_reduce_scatter_adapter = enable_reduce_scatter_adapter +@pytest.mark.parametrize('disable_reduce_scatter_adapter', [True, False]) +def test_codegen_reduce_scatter(tmp_path, disable_reduce_scatter_adapter): + old = CompileFlag.disable_reduce_scatter_adapter + CompileFlag.disable_reduce_scatter_adapter = disable_reduce_scatter_adapter m = ReduceScatterModule() m.train() parallelize( @@ -1159,8 +1159,8 @@ def test_codegen_reduce_scatter(tmp_path, enable_reduce_scatter_adapter): # glinear_25 = nnscaler.runtime.adapter.all_gather(glinear_48, dim=0, ranks=[0, 1]) # return glinear_25 # ... - CompileFlag.enable_reduce_scatter_adapter = old - if enable_reduce_scatter_adapter: + CompileFlag.disable_reduce_scatter_adapter = old + if not disable_reduce_scatter_adapter: assert _gencode_contains(tmp_path, ReduceScatterModule, 0, r"nnscaler.runtime.adapter.nn.reducescatter_allgather" ) diff --git a/tox.ini b/tox.ini index fd753382..b1236354 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = py38,py310 +envlist = py310 skipsdist = True [testenv] From 2e7cbd8c1ff9c6c89f4f1ef6a28f3c041491937b Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 24 Oct 2024 01:35:08 +0000 Subject: [PATCH 1747/1892] Merged PR 2293: add dis test add dis test --- .../fx/concrete_trace_utils/concrete_proxy.py | 12 +- .../fx/concrete_trace_utils/frame_utils.py | 34 ++- tests/graph/tracer/test_dis.py | 211 ++++++++++++++++++ 3 files changed, 246 insertions(+), 11 deletions(-) create mode 100644 tests/graph/tracer/test_dis.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index ef44c608..1c1d33d5 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -20,7 +20,7 @@ from . import concrete_tracer as et from . import pytree_utils, orig_func, wrap_utils, trace_strategy -from .frame_utils import get_frame_record, get_instruction +from .frame_utils import get_frame_record, get_instructions _logger = logging.getLogger(__name__) @@ -77,8 +77,8 @@ def __call__(self, *args, **kwargs) -> ConcreteProxy: return self.value.__call__(*args, **kwargs) return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) - def __iter__(self) -> Union[Iterable, ConcreteProxy]: - insts, cur = get_instruction(1) + def __iter__(self) -> Union[Iterable, ConcreteProxy]: + insts, cur = get_instructions(1) if insts[cur].opcode == self.op_call_ex: # in executing func(..., *proxy) @@ -109,7 +109,7 @@ def __next__(self) -> ConcreteProxy: return self.tracer.create_proxy('call_function', next, (self,), {}) def __len__(self) -> Union[int, ConcreteProxy]: - insts, cur = get_instruction(1) + insts, cur = get_instructions(1) if insts[cur].opcode == self.op_call_ex: # in executing func(..., *proxy) @@ -132,7 +132,7 @@ def __setitem__(self, *args, **kwargs) -> ConcreteProxy: return self.tracer.create_proxy('call_function', orig_func.setitem, (self,) + args, kwargs) def __bool__(self) -> Union[bool, ConcreteProxy]: - insts, cur = get_instruction(1) + insts, cur = get_instructions(1) if insts[cur].opcode in self.jump_opcodes or ( insts[cur].opcode in self.jump_before_opcodes and insts[cur + 1].opcode in self.jump_opcodes): @@ -181,7 +181,7 @@ def __exit__(self, exc_type, exc_value, traceback): @compatibility(is_backward_compatible=True) def keys(self): - insts, cur = get_instruction(1) + insts, cur = get_instructions(1) if insts[cur].opcode == self.op_call_ex or insts[cur].opcode == self.op_dict_merge: # in executing `**proxy` diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py index 09e470d8..c509dc85 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py @@ -9,13 +9,21 @@ import sys import traceback -from typing import List, Optional +from typing import List, Tuple, Optional -def get_instruction(back_times=1) -> dis.Instruction: +def get_instructions(back_times=1) -> Tuple[List[dis.Instruction], int]: """ - Get the instruction of the (back_times)-th frame from the bottom. - By default (back_times=1), the instruction of the frame who call this function will be returned. + Get the instructions of the (back_times)-th frame from the bottom. + + Args: + back_times: The number of frames to go back. + By default (back_times=1), the instruction of the frame who call this function will be returned. + + Returns: + A tuple of two elements: + - A list of dis.Instruction objects in frame. + - The index of the current instruction in the list. """ frame = inspect.currentframe() assert frame is not None @@ -40,13 +48,29 @@ def get_instruction(back_times=1) -> dis.Instruction: # From python doc: # EXTENDED_ARG(ext): Prefixes any opcode which has an argument too big to fit into the default one byte. - # ext holds an additional byte which act as higher bits in the argument. + # ext holds an additional byte which act as higher bits in the argument. # For each opcode, at most three prefixal EXTENDED_ARG are allowed, forming an argument from two-byte to four-byte. while insts[cur].opname == 'EXTENDED_ARG': cur += 1 return insts, cur +def get_last_instruction(back_times=1) -> dis.Instruction: + """ + Get the current instruction of the (back_times)-th frame from the bottom. + + Args: + back_times: The number of frames to go back. + By default (back_times=1), the instruction of the frame who call this function will be returned. + + Returns: + The current instruction in that frame. + """ + # +1 because the first frame is the frame of get_last_instruction + insts, cur = get_instructions(back_times + 1) + return insts[cur] + + @dataclass class FrameRecord: filename: str diff --git a/tests/graph/tracer/test_dis.py b/tests/graph/tracer/test_dis.py new file mode 100644 index 00000000..c0a741b5 --- /dev/null +++ b/tests/graph/tracer/test_dis.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys + +import pytest + +from nnscaler.graph.parser.fx.concrete_trace_utils.frame_utils import get_last_instruction, get_instructions + + +class A: + def __init__(self) -> None: + self.caller_inst = None + self.len_caller_inst = None + def __iter__(self): + self.caller_inst = get_last_instruction() + return iter([1, 2, 3]) + def __len__(self): + self.len_caller_inst = get_last_instruction() + return 3 + + +class B: + def __init__(self) -> None: + self.value = {'1':2, '3':4, '5':6} + self.caller_inst = None + self.getitem_count = 0 + def __iter__(self): + return iter(self.value) + def __getitem__(self, key): + self.getitem_count += 1 + return self.value[key] + def __len__(self): + return len(self.value) + def keys(self): + self.caller_inst = get_last_instruction() + return self.value.keys() + def values(self): + return self.value.values() + + +class C: + def __init__(self) -> None: + self.caller_inst = None + def __bool__(self): + self.caller_inst = get_last_instruction() + return True + + +def func0(*args, **kwargs): + pass + + +def test_for(): + a = A() + for _ in a: + break + assert a.caller_inst.opname == 'GET_ITER' + assert a.len_caller_inst is None + + +def test_single_starargs(): + a = A() + func0(*a) + assert a.caller_inst.opname == 'CALL_FUNCTION_EX' + assert a.len_caller_inst.opname == 'CALL_FUNCTION_EX' + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason='behavior is different in python3.8') +def test_multi_starargs(): + # in <= python 3.8 + # the opname will be BUILD_TUPLE_UNPACK_WITH_CALL + # in >= python 3.9 + # the opname will be LIST_EXTEND + a = A() + func0(*[1,2], *a) + assert a.caller_inst.opname == 'LIST_EXTEND' + assert a.len_caller_inst.opname == 'LIST_EXTEND' + + a = A() + func0(*a, *[1,2]) + assert a.caller_inst.opname == 'LIST_EXTEND' + assert a.len_caller_inst.opname == 'LIST_EXTEND' + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason='behavior is different in python3.8') +def test_normal_item_with_starargs(): + # in <= python 3.8 + # the opname will be BUILD_LIST_UNPACK + # in >= python 3.9 + # the opname will be LIST_EXTEND + a = A() + [1,2, *a] + assert a.caller_inst.opname == 'LIST_EXTEND' + assert a.len_caller_inst.opname == 'LIST_EXTEND' + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason='behavior is different in python3.8') +def test_normal_item_with_starargs2(): + # in <= python 3.8 + # the opname will be BUILD_TUPLE_UNPACK + # in >= python 3.9 + # the opname will be LIST_EXTEND + a = A() + (1,2, *a) + assert a.caller_inst.opname == 'LIST_EXTEND' + assert a.len_caller_inst.opname == 'LIST_EXTEND' + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason='behavior is different in python3.8') +def test_extend(): + a = A() + [1,2].extend(a) + assert a.caller_inst.opname == 'CALL_METHOD' + assert a.len_caller_inst.opname == 'CALL_METHOD' + + [1, *a] + assert a.caller_inst.opname == 'LIST_EXTEND' # BUILD_LIST_UNPACK in python 3.8 + assert a.len_caller_inst.opname == 'LIST_EXTEND' + + (1, *a) + assert a.caller_inst.opname == 'LIST_EXTEND' # BUILD_TUPLE_UNPACK in python 3.8 + assert a.len_caller_inst.opname == 'LIST_EXTEND' + + +def test_unpack(): + a = A() + x, y, z = a + assert a.caller_inst.opname == 'UNPACK_SEQUENCE' + assert a.len_caller_inst is None + + +def test_dict_keys1(mocker): + b = B() + mock1 = mocker.patch.object(b, '__iter__') + mock2 = mocker.patch.object(b, '__len__') + mock3 = mocker.patch.object(b, 'keys', side_effect=b.keys) + mock4 = mocker.patch.object(b, 'values') + + func0(**b) + assert mock1.call_count == 0 + assert mock2.call_count == 0 + assert mock3.call_count == 1 + assert mock4.call_count == 0 + assert b.getitem_count == 3 # 3 times of __getitem__ + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason='behavior is different in python3.8') +def test_dict_key2(): + b = B() + func0(**b) + assert b.caller_inst.opname == 'DICT_MERGE' # CALL_FUNCTION_EX in python 3.8 + + b = B() + func0(**b, **{'a': 1}) + assert b.caller_inst.opname == 'DICT_MERGE' # BUILD_MAP_UNPACK_WITH_CALL in python 3.8 + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason='behavior is different in python3.8') +def test_dict_key3(): + b = B() + {'a': 1, **b} + assert b.caller_inst.opname == 'DICT_UPDATE' # BUILD_MAP_UNPACK in python 3.8 + b.caller_inst = None + + {**b, **{'a': 1}} + assert b.caller_inst.opname == 'DICT_UPDATE' # BUILD_MAP_UNPACK in python 3.8 + + +def test_bool(): + c0 = 1 + c1 = 0 + c = C() + not c # UNARY_NOT + assert c.caller_inst.opname == 'UNARY_NOT' + c.caller_inst = None + + x = {c: c} + bool(x[c]) # CALL_FUNCTION + assert c.caller_inst.opname == 'CALL_FUNCTION' + + c and 1 # JUMP_IF_FALSE_OR_POP + assert c.caller_inst.opname == 'JUMP_IF_FALSE_OR_POP' + c.caller_inst = None + + c or 1 # JUMP_IF_TRUE_OR_POP + assert c.caller_inst.opname == 'JUMP_IF_TRUE_OR_POP' + c.caller_inst = None + + bool(c) # CALL_FUNCTION + assert c.caller_inst.opname == 'CALL_FUNCTION' + c.caller_inst = None + + if c: # POP_JUMP_IF_FALSE + pass + assert c.caller_inst.opname == 'POP_JUMP_IF_FALSE' + c.caller_inst = None + + if not c: # POP_JUMP_IF_TRUE + pass + assert c.caller_inst.opname == 'POP_JUMP_IF_TRUE' + c.caller_inst = None + + if bool(c): # CALL_FUNCTION + pass + assert c.caller_inst.opname == 'CALL_FUNCTION' + c.caller_inst = None + + x = 1 if c else 0 # POP_JUMP_IF_FALSE + assert c.caller_inst.opname == 'POP_JUMP_IF_FALSE' + c.caller_inst = None From 16a1635661b141694521ba01aa1e59d20a4cc4d3 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 24 Oct 2024 02:52:07 +0000 Subject: [PATCH 1748/1892] Merged PR 2291: update llama 128k example args --- examples/llama3_8B_128K/README.md | 17 +++++++++++++++++ examples/llama3_8B_128K/train.py | 17 +++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/examples/llama3_8B_128K/README.md b/examples/llama3_8B_128K/README.md index 6dd30eda..449d2fcc 100644 --- a/examples/llama3_8B_128K/README.md +++ b/examples/llama3_8B_128K/README.md @@ -104,6 +104,23 @@ We execute the training script on a node with 8xH100 80GB HBM3. The time cost is - add more devices to avoid recomputation: in order to fit the model into the memory, we recompute by layer. - do more kernel optimizations. For example, the swiglu activation can be fused into the matmul ahead of it. +## Trace Strategy + +During compiling, the time cost of trace model graph can vary significantly depending on the tracing strategy employed. Below are some reference time to trace `meta-llama/Meta-Llama-3-8B-Instruct` with different strategies and different context length, the time tested on one single A100 80GB: + +| Strategy | Context Length | Time/seconds | +| :------: | :------------: | :----------: | +| `reuse_cache` | 8k | 8.11 | +| `reuse_cache` | 32k | 11.06 | +| `reuse_cache` | 64k | 15.36 | +| `reuse_cache` | 128k | 26.29 | +| `cuda_run_cpu_offload` | 8k | 55.28 | +| `cuda_run_cpu_offload` | 32k | 194.27 | +| `cuda_run_cpu_offload` | 64k | 342.15 | +| `cuda_run_cpu_offload` | 128k | 789.15 | + +The trace strategy can be changed by setting `--trace_strategy` option. Please note that different strategies have different applicable scenarios. For more information and explanation to the different strategies, please read `docs/source/parallel_module.md`. + # Debugging Since the 128K config is challenging, it is recommended to use a smaller model for debugging. For example, you can use the following command to prepare data and train a smaller llama3 (same architecture, but with 4 decoder layers) model on two GPUs. diff --git a/examples/llama3_8B_128K/train.py b/examples/llama3_8B_128K/train.py index 47915dd7..10d5b776 100644 --- a/examples/llama3_8B_128K/train.py +++ b/examples/llama3_8B_128K/train.py @@ -157,12 +157,13 @@ def collate_fn(samples): use_zero=True, use_end2end=True, # autodist config: - # - memory constraint is set to 64GB + # - memory constraint default value is 64GB # - recompute by the transformer layer in Llama pas_config={ - 'mem_constraint': 64, + 'mem_constraint': args.gpu_mem_constraint, 'recompute_modules': 'LlamaDecoderLayer', }, + trace_strategy=args.trace_strategy, ) model_config = ModelConfig( @@ -283,6 +284,18 @@ def collate_fn(samples): type=str, help='transformers model id', ) + parser.add_argument( + '--gpu_mem_constraint', + default=64, + type=int, + help='the max memory usage constraint (GB) per GPU during nnscaler generating distribution plan, recommended to be 80 percent of GPU memory', + ) + parser.add_argument( + '--trace_strategy', + default='reuse_cache', + type=str, + help='trace strategy control the function execution during tracing model graph, `cuda_run_cpu_offload` and `reuse_cache` are recommended, please read `docs/source/parallel_module.md` for more information', + ) args = parser.parse_args() main(args) From f5688b2c34b3c47d840eb0c610e952563314e8c1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 25 Oct 2024 08:25:31 +0000 Subject: [PATCH 1749/1892] Merged PR 2294: wrapnn: integrate to parallelize integrate wrapnn to parallelize --- nnscaler/graph/function/wrapnn.py | 67 ++++++++++++--- nnscaler/parallel.py | 13 +-- nnscaler/program.py | 14 ++-- tests/parallel_module/test_normlayer.py | 105 ++++++++++++++++++------ 4 files changed, 151 insertions(+), 48 deletions(-) diff --git a/nnscaler/graph/function/wrapnn.py b/nnscaler/graph/function/wrapnn.py index 300250c5..fa3cd4e1 100644 --- a/nnscaler/graph/function/wrapnn.py +++ b/nnscaler/graph/function/wrapnn.py @@ -14,6 +14,7 @@ At last, we provide a utility function to replace the original nn modules with the wrapped nn modules. """ +from contextlib import contextmanager from dataclasses import dataclass from typing import Tuple, List, Dict from typing import Tuple @@ -138,12 +139,12 @@ def batchnorm2d_annotation_fn(*inputs, **kwargs): input, weight, bias, running_mean, running_var, num_batches_tracked = inputs """ Restrictions: - 1. If `weight` is None, then `bias` must also be None. This is because in the absence of `weight`, + 1. If `weight` is None, then `bias` must also be None. This is because in the absence of `weight`, BatchNorm2d does not apply affine transformation, which means there is no need for `bias`. - 2. If `running_mean` is None, then `running_var` and `num_batches_tracked` must also be None. - This is because `running_mean` and `running_var` are used for tracking the statistics of - the batch normalization during training. If `running_mean` is not provided, it implies - that the module should not track statistics, hence `running_var` and `num_batches_tracked` + 2. If `running_mean` is None, then `running_var` and `num_batches_tracked` must also be None. + This is because `running_mean` and `running_var` are used for tracking the statistics of + the batch normalization during training. If `running_mean` is not provided, it implies + that the module should not track statistics, hence `running_var` and `num_batches_tracked` should also be absent. Reference: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html """ @@ -270,11 +271,11 @@ def emit_batchnorm2d( """ This function wraps the original InstanceNorm2d forward function. - + The logic in this function is exactly the same as in the original PyTorch implementation. - We copied the logic here to register it as a customized operation because nnscaler's - `register_op` only supports functions, not nn.Module classes. Therefore, this function - serves as a wrapper around the InstanceNorm2d forward logic, treating the entire function + We copied the logic here to register it as a customized operation because nnscaler's + `register_op` only supports functions, not nn.Module classes. Therefore, this function + serves as a wrapper around the InstanceNorm2d forward logic, treating the entire function as a black-box leaf node in nnscaler. """ @@ -451,11 +452,17 @@ def instancenorm2d_reinit(module: _InstanceNorm) -> _InstanceNorm: } -def convert_to_wrapnn(module: torch.nn.Module): +_ORIGINAL_MODULE_ATTR = "__nnscaler_original_module__" + + +def convert_to_wrapnn(module: torch.nn.Module) -> torch.nn.Module: """Traverse the module and replace the original nn module with its wrapped version if it is in the `wrapped_modules`. Currently `wrapped_modules` contains `BatchNorm2d` and `InstanceNorm2d`. + Please note the child modules of the input module will be replaced in-place. + You can use `undo_convert_to_wrapnn` to revert the changes. + It is necessary to call this function on user instantiated model before parallelizing the it, otherwise the modules in `wrapped_modules` cannot be partitioned, but be always replicated. @@ -464,10 +471,48 @@ def convert_to_wrapnn(module: torch.nn.Module): does not have the modules in `wrapped_modules`. """ if type(module) in wrapped_modules: - return wrapped_modules[type(module)](module) + wrapped = wrapped_modules[type(module)](module) + # module will be save to children module if we use setattr(wrapped,...) + object.__setattr__(wrapped, _ORIGINAL_MODULE_ATTR, module) + return wrapped for name, child in module.named_children(): module.add_module( name, convert_to_wrapnn(child) ) # will inplace replace the module with the same name return module + + +def undo_convert_to_wrapnn(module: torch.nn.Module) -> torch.nn.Module: + """ + Undo the effect of `convert_to_wrapnn` function. + """ + if hasattr(module, _ORIGINAL_MODULE_ATTR): + origin_module = getattr(module, _ORIGINAL_MODULE_ATTR) + delattr(module, _ORIGINAL_MODULE_ATTR) + return origin_module + + for name, child in module.named_children(): + module.add_module( + name, undo_convert_to_wrapnn(child) + ) # will inplace replace the module with the same name + return module + + +@contextmanager +def wrapnn(module: torch.nn.Module, *, restore: bool = True): + """ + wrap the nn module and undo the wrap after the context. + Args: + module: the nn module to wrap + restore: whether to restore the original module after the context + Returns: + the wrapped module + """ + try: + yield convert_to_wrapnn(module) + finally: + # just restore the original module inplace + # return value is discarded + if restore: + undo_convert_to_wrapnn(module) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 47bb5141..86c3d5eb 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -28,6 +28,7 @@ from nnscaler.graph import parser from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.function.wrapnn import convert_to_wrapnn, wrapnn from nnscaler.graph.gener.gen import IRAdapterGener from nnscaler.graph.parser.fx.parser import FxModuleParser from nnscaler.graph.schedule.predefined import PredefinedSched @@ -772,11 +773,13 @@ def _gencode( ) torch.save(meta_info, origin_module_metadata_ckp) - graph, forward_args = _gen_graph( - module, dummy_forward_args, outdir, - constant_folding=compute_config.constant_folding, end2end_mode=compute_config.use_end2end, - inference_only=compute_config.inference_only, - ) + with wrapnn(module, restore=not is_module_class) as wrapped_module: + graph, forward_args = _gen_graph( + wrapped_module, dummy_forward_args, outdir, + constant_folding=compute_config.constant_folding, end2end_mode=compute_config.use_end2end, + inference_only=compute_config.inference_only, + ) + graph.dump(graph_ckp) torch.save(forward_args, forward_args_ckp) diff --git a/nnscaler/program.py b/nnscaler/program.py index 4be41dca..7fdd126e 100644 --- a/nnscaler/program.py +++ b/nnscaler/program.py @@ -10,6 +10,7 @@ from nnscaler.graph import IRGraph from nnscaler.graph import parser +from nnscaler.graph.function.wrapnn import wrapnn from nnscaler.runtime.module import CubeModule from nnscaler.runtime.device import DeviceGroup @@ -233,10 +234,11 @@ def __call__(self, *args): dummy_input[str(name)] = value self.dummy_input = dummy_input # parse graph - self._ir_graph = parser.convert_model( - self.model, - dummy_input=self.dummy_input, - attr_savedir=self.attr_savedir, - constant_folding=self.constant_folding - ) + with wrapnn(self.model) as wrapped_model: + self._ir_graph = parser.convert_model( + wrapped_model, + dummy_input=self.dummy_input, + attr_savedir=self.attr_savedir, + constant_folding=self.constant_folding + ) return self._ir_graph(*args) diff --git a/tests/parallel_module/test_normlayer.py b/tests/parallel_module/test_normlayer.py index 243a56d5..aa9918be 100644 --- a/tests/parallel_module/test_normlayer.py +++ b/tests/parallel_module/test_normlayer.py @@ -7,12 +7,14 @@ import torch import pytest import random +from unittest.mock import patch + import nnscaler from nnscaler.graph.graph import IRGraph from nnscaler.ir.operator import IRFwOperation from nnscaler.runtime.device import DeviceGroup from tests.parallel_module.test_gencode import _gencode_contains -from nnscaler.graph.function.wrapnn import convert_to_wrapnn +from nnscaler.graph.function.wrapnn import convert_to_wrapnn, wrapnn, NnScalerBatchNorm2d, undo_convert_to_wrapnn, _ORIGINAL_MODULE_ATTR from nnscaler.parallel import parallelize, ComputeConfig from torch.nn.parallel import DistributedDataParallel as DDP @@ -77,18 +79,23 @@ def forward(self, x): def _gencode_batchnorm2d_function(tempdir, config, pas_policy): init_distributed() m = BatchNorm2dModule().cuda() - m_2d = convert_to_wrapnn(m) x = torch.randn(8, 8, 32, 32).cuda() - m_new = parallelize( - m_2d, - {"x": x}, - pas_policy, - config, - gen_savedir=tempdir, - load_module=True, - reuse="override", - ) + with patch("nnscaler.graph.function.wrapnn.undo_convert_to_wrapnn", side_effect=undo_convert_to_wrapnn) as px: + m_new = parallelize( + m, + {"x": x}, + pas_policy, + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + px.assert_called() + # bn should be restored after parallelize + assert isinstance(m.bn, torch.nn.BatchNorm2d) + assert not hasattr(m.bn, _ORIGINAL_MODULE_ATTR) + assert m_new is not None m_new.train() output = m_new(x) @@ -128,12 +135,11 @@ def _gencode_batchnorm2d_function_2(tempdir, config, pas_policy): device = torch.device(f"cuda:{rank_id}") m = BatchNorm2dModule().to(device) - m_2d = convert_to_wrapnn(m) shared_data = generate_parallel_data((8, 8, 32, 32), device, dtype) x_part = shared_data[rank_id] m_new = parallelize( - m_2d, + m, {"x": x_part}, pas_policy, config, @@ -193,13 +199,12 @@ def _gencode_batchnorm2d_function_4(tempdir, config, pas_policy, dim): device = torch.device(f"cuda:{rank_id}") m = BatchNorm2dModule().to(device) - m_2d = convert_to_wrapnn(m) x_list = generate_parallel_data((8, 8, 32, 32), device, dtype) x = x_list[rank_id // 2] m_new = parallelize( - m_2d, + m, {"x": x}, lambda graph, resource: pas_policy(graph, resource, dim), config, @@ -274,7 +279,6 @@ def _gencode_batchnorm2d_function_eval(tempdir, config, pas_policy): init_distributed() m = BatchNorm2dModule().cuda() x = torch.randn(8, 8, 32, 32).cuda() - m = convert_to_wrapnn(m) m_new = parallelize( m, {"x": x}, @@ -324,12 +328,11 @@ def _gencode_batchnorm2d_function_eval_2(tempdir, config, pas_policy): device = torch.device(f"cuda:{rank_id}") m = BatchNorm2dModule().to(device) - m_2d = convert_to_wrapnn(m) shared_data = generate_parallel_data((4, 8, 32, 32), device, dtype) x_part = shared_data[rank_id] m_new = parallelize( - m_2d, + m, {"x": x_part}, pas_policy, config, @@ -395,13 +398,12 @@ def _gencode_batchnorm2d_function_eval_4(tempdir, config, pas_policy, dim): device = torch.device(f"cuda:{rank_id}") m = BatchNorm2dModule().to(device) - m_2d = convert_to_wrapnn(m) x_list = generate_parallel_data((8, 8, 32, 32), device, dtype) x = x_list[rank_id // 2] m_new = parallelize( - m_2d, + m, {"x": x}, lambda graph, resource: pas_policy(graph, resource, dim), config, @@ -489,7 +491,6 @@ def forward(self, x): def _gencode_instancenorm2d_function(tempdir, config, pas_policy): init_distributed() m = InstanceNorm2dModule().cuda() - m = convert_to_wrapnn(m) m_new = parallelize( m, {"x": torch.randn(4, 4, 32, 32).cuda()}, @@ -529,7 +530,6 @@ def _gencode_instancenorm2d_function_2(tempdir, config, pas_policy): init_random() device = torch.device(f"cuda:{rank_id}") m = InstanceNorm2dModule().cuda() - m = convert_to_wrapnn(m) shared_data = generate_parallel_data((2, 4, 32, 32), device, dtype) x_part = shared_data[rank_id] @@ -582,7 +582,6 @@ def _gencode_instancenorm2d_function_4(tempdir, config, pas_policy): init_random() device = torch.device(f"cuda:{rank_id}") m = InstanceNorm2dModule().cuda() - m = convert_to_wrapnn(m) x_list = generate_parallel_data((2, 4, 32, 32), device, dtype) x = x_list[rank_id // 2] @@ -631,7 +630,6 @@ def test_codegen_instancenorm2d_2_4(): def _gencode_instancenorm2d_function_eval(tempdir, config, pas_policy): init_distributed() m = InstanceNorm2dModule().cuda() - m = convert_to_wrapnn(m) m.eval() m_new = parallelize( m, @@ -676,7 +674,6 @@ def _gencode_instancenorm2d_function_eval_2(tempdir, config, pas_policy): init_random() device = torch.device(f"cuda:{rank_id}") m = InstanceNorm2dModule().cuda() - m = convert_to_wrapnn(m) shared_data = generate_parallel_data((2, 4, 32, 32), device, dtype) x_part = shared_data[rank_id] @@ -733,7 +730,6 @@ def _gencode_instancenorm2d_function_eval_4(tempdir, config, pas_policy): init_random() device = torch.device(f"cuda:{rank_id}") m = InstanceNorm2dModule().cuda() - m = convert_to_wrapnn(m) x_list = generate_parallel_data((2, 4, 32, 32), device, dtype) x = x_list[rank_id // 2] @@ -781,3 +777,60 @@ def test_codegen_instancenorm2d_2_4_eval(): ComputeConfig(2, 4), policy, ) + + +class NestedBatchNorm2dModule(torch.nn.Module): + def __init__(self): + super(NestedBatchNorm2dModule, self).__init__() + self.nested = BatchNorm2dModule() + self.linear = torch.nn.Linear(8, 8) + + def forward(self, x): + # doesn't care about forward + pass + + +def test_convert_to_wrapnn(): + m = NestedBatchNorm2dModule() + + def check_converted(mc): + assert len(list(mc.children())) == 2 + assert isinstance(mc.nested, BatchNorm2dModule) + assert len(list(mc.nested.children())) == 1 + assert isinstance(mc.nested.bn, NnScalerBatchNorm2d) + assert len(list(mc.nested.bn.children())) == 0 + assert id(m.linear) == id(mc.linear) + assert len(list(m.modules())) == len(list(mc.modules())) + + assert isinstance(getattr(mc.nested.bn, _ORIGINAL_MODULE_ATTR), torch.nn.BatchNorm2d) + assert not hasattr(mc.linear, _ORIGINAL_MODULE_ATTR) + assert not hasattr(mc, _ORIGINAL_MODULE_ATTR) + assert not hasattr(mc.nested, _ORIGINAL_MODULE_ATTR) + + def check_undo_converted(mcc): + assert len(list(mcc.children())) == 2 + assert isinstance(mcc.nested, BatchNorm2dModule) + assert len(list(mcc.nested.children())) == 1 + assert not isinstance(mcc.nested.bn, NnScalerBatchNorm2d) + assert len(list(mcc.nested.bn.children())) == 0 + assert id(m.linear) == id(mcc.linear) + assert not hasattr(mc.linear, _ORIGINAL_MODULE_ATTR) + assert len(list(m.modules())) == len(list(mcc.modules())) + + assert not hasattr(mcc.nested.bn, _ORIGINAL_MODULE_ATTR) + assert not hasattr(mcc.linear, _ORIGINAL_MODULE_ATTR) + assert not hasattr(mcc, _ORIGINAL_MODULE_ATTR) + assert not hasattr(mcc.nested, _ORIGINAL_MODULE_ATTR) + + mc = convert_to_wrapnn(m) + check_converted(mc) + mcc = undo_convert_to_wrapnn(mc) + check_undo_converted(mcc) + + with wrapnn(m) as mc: + check_converted(mc) + check_undo_converted(m) + + with wrapnn(m, restore=False) as mc: + check_converted(mc) + check_converted(m) From 4e305a18af45a2661f87886a26d5f36fc37e355b Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 25 Oct 2024 08:40:29 +0000 Subject: [PATCH 1750/1892] Merged PR 2296: [skip ci] bump version to 0.4 bump version to 0.4 --- nnscaler/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnscaler/version.py b/nnscaler/version.py index 409b28ef..47fd5b28 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -__version__ = '0.3' +__version__ = '0.4' From 4e2c30e56efa1eb7d9d59516ba1dd7bae09e890f Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Wed, 30 Oct 2024 04:05:18 +0000 Subject: [PATCH 1751/1892] Merged PR 2297: fix unpartitionable identifiers in annotation Always unpartitionable identifiers in inner dimension, so autodist can correctly handle them. parity check pass unit test pass --- nnscaler/graph/function/dimops.py | 97 +++++++++++++++---- nnscaler/graph/function/function.py | 20 ++-- nnscaler/graph/graph.py | 3 +- .../spmd_solver/test_cube_operator.py | 60 +++++++++++- tests/graph/function/test_dimops.py | 16 +++ tests/parallel_module/test_gencode.py | 58 +++++++++++ 6 files changed, 217 insertions(+), 37 deletions(-) diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index c0a76da7..cfdc12dc 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -112,11 +112,11 @@ def name(self) -> str: return '(' + ' '.join(self._identifiers) + ')' @property - def identifiers(self) -> Tuple[str]: + def identifiers(self) -> Tuple[str, ...]: return self._identifiers @property - def reduces(self) -> Tuple[ReduceType]: + def reduces(self) -> Tuple[ReduceType, ...]: return self._reduces def __eq__(self, other): @@ -175,7 +175,7 @@ def __init__(self, dim_annos: Union[str, Tuple[DimAnno]]): self._dims: Tuple[DimAnno] = dim_annos @property - def dims(self) -> Tuple[DimAnno]: + def dims(self) -> Tuple[DimAnno, ...]: return self._dims @property @@ -417,14 +417,15 @@ def __repr__(self) -> str: outputs = ', '.join(repr(output) for output in self.outputs()) return inputs + ' -> ' + outputs - @staticmethod - def parse(anno: str) -> Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]]: + @classmethod + def parse(cls, anno: str) -> Tuple[Tuple[ShapeAnno, ...], Tuple[ShapeAnno, ...]]: """! Parse op annotation string to input shape annos and output shape annos. - @param anno str: operator annotation - - @return (inputs, outputs) Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]] + Args: + anno (str): operator annotation + Returns: + tuple[tuple[ShapeAnno, ...], tuple[ShapeAnno, ...]]: input shape annos and output shape annos """ # to inputs and outputs if '->' not in anno: @@ -438,12 +439,69 @@ def parse(anno: str) -> Tuple[Tuple[ShapeAnno], Tuple[ShapeAnno]]: # to ShapeAnnos inputs: Tuple[ShapeAnno] = tuple(ShapeAnno(shape) for shape in inputs) outputs: Tuple[ShapeAnno] = tuple(ShapeAnno(shape) for shape in outputs) + cls._verify_and_fix_inner_dim_anno(anno, inputs, outputs) + return inputs, outputs - @staticmethod - def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], - ous: Tuple[Tuple[Union[str, Tuple[str]]]]) -> str: - """! + @classmethod + def _verify_and_fix_inner_dim_anno( + cls, + anno: str, + inputs: Tuple[ShapeAnno, ...], + outputs: Tuple[ShapeAnno, ...] + ): + """ + Verify to make sure reduce type of annotations are consistent. + We also force reduce type of all inner dimension identifiers to be freeze, + Because we can't partition inner dimensions in current implementation. + """ + # used to track reduce type of each identifier + id_reduce_map: dict[str, DimAnno.ReduceType] = dict() + for shape in inputs + outputs: + for edim in shape.dims: + for idx, identifier in enumerate(edim.identifiers): + if id_reduce_map.setdefault(identifier, edim.reduces[idx]) != edim.reduces[idx]: + raise ValueError(f"Reduce type of identifier {identifier} is not consistent") + + non_first_inner_dim_ids = cls._get_non_leading_anno_ids(*inputs, *outputs) + updated_ids = set() + + for shape in inputs + outputs: + for edim in shape.dims: + reduces = [] + for idx, identifier in enumerate(edim.identifiers): + if identifier in non_first_inner_dim_ids and edim.reduces[idx] != DimAnno.ReduceType.Freeze: + updated_ids.add(identifier) + reduces.append(DimAnno.ReduceType.Freeze) + else: + reduces.append(edim.reduces[idx]) + # HACK: modify protected member to fix reduce type inplace + edim._reduces = tuple(reduces) + + if updated_ids: + _logger.debug(f"Inner dimensions {updated_ids} in {anno} are forced to be frozen because they can't be partitioned") + + @classmethod + def _get_non_leading_anno_ids(cls, *shape_annos: ShapeAnno) -> Set[str]: + """ + collect all unpartitioned identifiers in inner dimensions, which most are not in the first position. + See `transform_space` and `_verify_and_fix_inner_dim_anno` for more information. + """ + nonleading_ids = set() + for shape in shape_annos: + for dim, dim_anno in enumerate(shape.dims): + for identifier in list(dropwhile(lambda x: x == '1', dim_anno.identifiers))[1:]: + if not str.isdecimal(identifier): + nonleading_ids.add(identifier) + return nonleading_ids + + @classmethod + def create_op_str( + cls, + ins: Tuple[Tuple[Union[str, Tuple[str]]]], + ous: Tuple[Tuple[Union[str, Tuple[str]]]] + ) -> str: + """ Create operator annotation string e.g., ins = [ ['a', 'b', 'c+'], ['c+', ['d', 'e']] ] @@ -451,10 +509,12 @@ def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], => 'a b c+, c+ (d e) -> a b d e' - @param ins Tuple[Tuple[Union[str, Tuple[str]]]: input identifier list - @param ous Tuple[Tuple[Union[str, Tuple[str]]]: output identifier list + Args: + ins (Tuple[Tuple[Union[str, Tuple[str]]]): input identifier list + ous (Tuple[Tuple[Union[str, Tuple[str]]]): output identifier list - @return anno str: operator annotation + Returns: + str: operator annotation """ in_annos = list() ou_annos = list() @@ -501,12 +561,7 @@ def transform_space(self) -> List[Tuple[int, int]]: # in both cases, a can be partitioned, but b can't # collect all unpartitioned identifiers that are not in first position - nonleading_ids = set() - for shape in self.inputs() + self.outputs(): - for dim, dim_anno in enumerate(shape.dims): - for identifier in list(dropwhile(lambda x: x == '1', dim_anno.identifiers))[1:]: - if not str.isdecimal(identifier): - nonleading_ids.add(identifier) + nonleading_ids = self._get_non_leading_anno_ids(*self.inputs(), *self.outputs()) visited : Set[str] = set() # to remove equivalent configurations configs = [] diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 0cf72a77..adf58836 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -2707,9 +2707,9 @@ def Conv1D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, padding_val = unwrap_if_irobject(padding) dilation_val = unwrap_if_irobject(dilation) groups_val = unwrap_if_irobject(groups) - if isinstance(stride_val, int): + if isinstance(stride_val, int): stride_val = (stride_val,) - if isinstance(dilation_val, int): + if isinstance(dilation_val, int): dilation_val = (dilation_val,) kW = weight.shape[-1] effective_kernel_size = (kW - 1) * dilation_val[0] @@ -2828,7 +2828,7 @@ def ConvTranspose1D(input, weight, bias=None, stride=1, padding=0, output_paddin [f'(groups group_size^) {iW}, (groups group_size^) oC {kW}, oC -> (groups oC) {oW}'] return IRDimops(ConvTranspose1D, 'conv_transpose1d', signature, annos, [input, weight, bias], stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) - if len(input.shape) == 3: + if len(input.shape) == 3: if bias is None: annos = [f'n iC+ {iW}, iC+ oC {kW} -> n oC {oW}'] if groups_val == 1 else \ [f'n (groups group_size^) {iW}, (groups group_size^) oC {kW} -> n (groups oC) {oW}'] @@ -2854,9 +2854,9 @@ def Conv2D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, padding_val = unwrap_if_irobject(padding) dilation_val = unwrap_if_irobject(dilation) groups_val = unwrap_if_irobject(groups) - if isinstance(stride_val, int): + if isinstance(stride_val, int): stride_val = (stride_val, stride_val) - if isinstance(dilation_val, int): + if isinstance(dilation_val, int): dilation_val = (dilation_val, dilation_val) if isinstance(padding_val, str): if padding_val == 'same': @@ -2946,13 +2946,13 @@ def ConvTranspose2D(input, weight, bias=None, stride=1, padding=0, output_paddin output_padding_val = unwrap_if_irobject(output_padding) dilation_val = unwrap_if_irobject(dilation) groups_val = unwrap_if_irobject(groups) - if isinstance(stride_val, int): + if isinstance(stride_val, int): stride_val = (stride_val, stride_val) - if isinstance(padding_val, int): + if isinstance(padding_val, int): padding_val = (padding_val, padding_val) - if isinstance(output_padding_val, int): + if isinstance(output_padding_val, int): output_padding_val = (output_padding_val, output_padding_val) - if isinstance(dilation_val, int): + if isinstance(dilation_val, int): dilation_val = (dilation_val, dilation_val) if not (len(stride_val) == 2 and len(padding_val) == 2 and len(output_padding_val) == 2 and len(dilation_val) == 2): raise ValueError("stride, padding, output_padding, and dilation must have a length of 2") @@ -2979,7 +2979,7 @@ def ConvTranspose2D(input, weight, bias=None, stride=1, padding=0, output_paddin [f'(groups group_size^) {iH} {iW}, (groups group_size^) oC {kH} {kW}, oC -> (groups oC) {oH} {oW}'] return IRDimops(ConvTranspose2D, 'conv_transpose2d', signature, annos, [input, weight, bias], stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) - if len(input.shape) == 4: + if len(input.shape) == 4: if bias is None: annos = [f'n iC+ {iH} {iW}, iC+ oC {kH} {kW} -> n oC {oH} {oW}'] if groups_val == 1 else \ [f'n (groups group_size^) {iH} {iW}, (groups group_size^) oC {kH} {kW} -> n (groups oC) {oH} {oW}'] diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 1d835316..c3c920ff 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -391,7 +391,8 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], # get partitioned sub-nodes fnodes = algo.instantiate(**config) - assert fnodes is not None, f"Fail to partition node: {node} use algorithm and config: {config}" + if not fnodes: + raise ValueError(f"Fail to partition node: {node}. Please check your config: {config}.") # insert forward node fsegment: IRSegment = self.segment(node) diff --git a/tests/autodist/spmd_solver/test_cube_operator.py b/tests/autodist/spmd_solver/test_cube_operator.py index 6d70efe1..dd404377 100644 --- a/tests/autodist/spmd_solver/test_cube_operator.py +++ b/tests/autodist/spmd_solver/test_cube_operator.py @@ -1,19 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import pytest - import tempfile -import torch import os from pathlib import Path + +import pytest +import torch +from torch.nn import functional as F + +import nnscaler from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph from nnscaler.autodist.model_graph import ModelGraph from nnscaler.autodist.autodist_config import AutoDistConfig from nnscaler.autodist.spmd_solver import SPMDSolver -import nnscaler - @nnscaler.register_op( '(1 h) l^ d^, (1 h) l^ d^, (1 h) l^ d^ -> (1 h) l^ d^', 'mock_attention') @@ -51,3 +52,52 @@ def test_cube_operator(): mock_attention_op = model_graph.operator_list[0] assert mock_attention_op.pos2dim_id((0, 0)) == 'h' assert mock_attention_op.dim_id2pos('h') == (0, 0) + + +class CVModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.out_channel = 32 + self.kernel_size = 3 + + def forward(self, input): + batch, in_channel, height, width = input.shape + + input = input.view(1, batch * in_channel, height, width) + weight = torch.randn(batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + return out + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_cube_operator_conv_transpose2d(): + """ + ConvTranspose2D and ConvTranspose1D oC dim can't be split + """ + batch, in_channel, height, width = 2, 16, 32, 32 + input = torch.randn((batch, in_channel, height, width)) + + dummy_input = {'input': input} + model = CVModel() + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir) + cfg = AutoDistConfig(mesh_col=2) + model_graph = ModelGraph(ir_graph, cfg) + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=cfg.mesh_desc, + autodist_config=cfg, + ) + + partition_counts = [ + spmd_solver.get_op_partition_count(i) + for i in range(model_graph.op_num) + ] + assert partition_counts == [4, 1, 2, 2] diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index 59c57057..67ae3551 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -8,6 +8,8 @@ from typing import Callable, Tuple, List from functools import partial +import pytest + import nnscaler.graph.function as F from nnscaler.graph.function.dimops import IRDimops, OpAnno from nnscaler.ir.tensor import IRFullTensor @@ -98,3 +100,17 @@ def test_transform_space(): assert OpAnno('a b, (b n) c -> a (1 1 1 b c) n').transform_space() == [(0, 0), (0, 1)] assert OpAnno('a b, (b n) c^ -> a (1 1 1 b) n c^').transform_space() == [(0, 0), (0, 1)] assert OpAnno('a b, (d^ n) c -> a (c n) d^').transform_space() == [(0, 0), (0, 1), (1,1)] + + +def test_parse_op(): + assert str(OpAnno('a b, b c -> a c')) == 'a b, b c -> a c' + assert str(OpAnno('a^ b, b c -> a^ c')) == 'a^ b, b c -> a^ c' + assert str(OpAnno('a b, (b n) c -> a (n c)')) == 'a b, (b n^) c^ -> a (n^ c^)' + assert str(OpAnno('a b, (b n) c -> a (n b c)')) == 'a b^, (b^ n^) c^ -> a (n^ b^ c^)' + assert str(OpAnno('a b, (b n) c -> a (1 b c) n')) == 'a b, (b n^) c^ -> a (1^ b c^) n^' + assert str(OpAnno('a b, (b n) c -> a (1 1 1 b c) n')) == 'a b, (b n^) c^ -> a (1^ 1^ 1^ b c^) n^' + assert str(OpAnno('a b, (b n) c^ -> a (1 1 1 b) n c^')) == 'a b, (b n^) c^ -> a (1^ 1^ 1^ b) n^ c^' + assert str(OpAnno('a b, (d^ n) c -> a (c n) d^')) == 'a b, (d^ n^) c -> a (c n^) d^' + + with pytest.raises(ValueError): + str(OpAnno('a b^, b^ c -> a c^')) == 'a b^, b^ c -> a c^' diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 5dd2c3c6..e3981ee8 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -6,6 +6,7 @@ from contextlib import nullcontext import torch +import torch.nn.functional as F import pytest from nnscaler.flags import CompileFlag @@ -1170,6 +1171,63 @@ def test_codegen_reduce_scatter(tmp_path, disable_reduce_scatter_adapter): ) +class CVModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.out_channel = 32 + self.kernel_size = 3 + + def forward(self, input): + batch, in_channel, height, width = input.shape + + input = input.view(1, batch * in_channel, height, width) + weight = torch.randn(batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + return out + + +def pas_conv2d(graph, cfg): + from nnscaler.ir import IRFwOperation, IRDataOperation + from nnscaler.policies import _tp, _replica + ngpus = cfg.plan_ngpus + + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'conv_transpose2d': + # this is an invalid partition + # ValueError will be raised + _tp(graph, node, list(range(ngpus)), 1, 1) + else: + _replica(graph, node, list(range(ngpus))) + return graph + + +@replace_all_device_with('cpu') +def test_invalid_partition(tmp_path): + """ + ConvTranspose2D and ConvTranspose1D oC dim can't be split + """ + batch, in_channel, height, width = 2, 16, 32, 32 + input = torch.randn((batch, in_channel, height, width)) + + dummy_input = {'input': input} + + m = CVModel() + m.train() + + with pytest.raises(ValueError): + parallelize( + m, + dummy_input, + pas_conv2d, + ComputeConfig(2, 2), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + + class KwargsModule(torch.nn.Module): def forward(self, x): return x + torch.zeros_like(x, dtype=torch.float32) From fa98abb31c31d31ad5ad99bcaf6e23e48471f8db Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 31 Oct 2024 08:01:11 +0000 Subject: [PATCH 1752/1892] Merged PR 2295: add grad check in trainer --- docs/source/trainer.md | 72 ++++++++++++++++------------ nnscaler/cli/trainer.py | 92 ++++++++++++++++++++++++++++++++++++ nnscaler/cli/trainer_args.py | 11 +++++ nnscaler/runtime/module.py | 3 +- tests/cli/test_trainer.py | 41 ++++++++++++++++ 5 files changed, 187 insertions(+), 32 deletions(-) diff --git a/docs/source/trainer.md b/docs/source/trainer.md index 5b13ece8..78e2acda 100644 --- a/docs/source/trainer.md +++ b/docs/source/trainer.md @@ -231,6 +231,16 @@ Internally we will get the final value with `__value_type(value)`. - `type` (`str`): The logger type or factory function. - `args` (`Dict[str, Any]`): The arguments of the logger. +- `debug` (`DebugConfig`): Trainer debug related setting. + + ```python + @dataclass + class DebugConfig: + check_gradient_sync_cross_devices: bool = True + ``` + + - `check_gradient_sync_cross_devices` (`bool`): Before gradient clip norm, check the gradient sync for the same parameter is consistent cross devices, if ZeRO is enabled, will check the gradient cross each ZeRO group, if ZeRO is not enabled, will check the gradient cross each nnscaler scale unit. This helps to find bugs related to gradient updates during training. Default is `True`. + - `hook` (`Union[HookConfig, HookMapConfig, None]`): The hooks to be used. You can provide a hook with a hook class or a map of hook functions. Please note if your `model`/`optimizer`/`lr_scheduler` inherit from `TrainHook`, @@ -240,51 +250,51 @@ and hooks passed with this config is called in the last. Hook class: - ```python - @dataclass - class HookConfig: - type: str = None - args: Dict[str, Any] = field(default_factory=dict) - ``` + ```python + @dataclass + class HookConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + ``` - `type` (`str`): The hook type or factory function. - `args` (`Dict[str, Any]`): The arguments of the hook. Hook map: - ```python - @dataclass - class HookMapConfig: - after_setup: str = None + ```python + @dataclass + class HookMapConfig: + after_setup: str = None - on_train_start: str = None - on_train_end: str = None - on_val_start: str = None - on_val_end: str = None + on_train_start: str = None + on_train_end: str = None + on_val_start: str = None + on_val_end: str = None - on_epoch_start: str = None - on_epoch_end: str = None + on_epoch_start: str = None + on_epoch_end: str = None - on_train_step_start: str = None - on_train_step_end: str = None - on_val_step_start: str = None - on_val_step_end: str = None + on_train_step_start: str = None + on_train_step_end: str = None + on_val_step_start: str = None + on_val_step_end: str = None - after_aggregate_train_step_outputs: str = None - after_aggregate_val_step_outputs: str = None + after_aggregate_train_step_outputs: str = None + after_aggregate_val_step_outputs: str = None - before_zero_grad: str = None - after_zero_grad: str = None + before_zero_grad: str = None + after_zero_grad: str = None - before_gnorm_clip: str = None - after_gnorm_clip: str = None + before_gnorm_clip: str = None + after_gnorm_clip: str = None - before_optimizer_step: str = None - after_optimizer_step: str = None + before_optimizer_step: str = None + after_optimizer_step: str = None - on_load_checkpoint: str = None - on_save_checkpoint: str = None - ``` + on_load_checkpoint: str = None + on_save_checkpoint: str = None + ``` - `after_setup` (`str`): The hook function to be called after setting up the trainer. Only be called when `run_mode == 'run'`. Signature: `def after_setup(trainer: 'Trainer') -> None:` diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 1424cc4e..3d2797df 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -538,6 +538,94 @@ def _fix_batches(self, batches): batches += [self.dummy_input] * gap return batches, is_dummy_batch + @torch.no_grad() + def _check_grad_cross_devices_correctness(self): + # if ZeRO is enabled, will check the gradient cross each ZeRO group. + # if ZeRO is not enabled, will check the gradient cross each nnscaler scale unit. + def get_optimizer_sync_group(): + # if ZeRO is enabled, `compute_config.optimizer_dedup_group_size` is the ZeRO group size, + # return the corresponding rank list and device group cross ZeRO groups, parallel to the current rank, + # if ZeRO is not enabled, `compute_config.optimizer_dedup_group_size` is the plan_ngpus, + # return the corresponding rank list and device group cross scale units, parallel to the current rank. + rank = torch.distributed.get_rank() + group_size = self.train_args.compute_config.optimizer_dedup_group_size + runtime_ngpus = self.train_args.compute_config.runtime_ngpus + + # group_size equal to runtime_ngpus means one of: + # 1. ZeRO is enabled and ZeRO group number is 1 + # 2. ZeRO is not enabled and nnscaler scale unit number is 1 + # in these cases, the gradient of one parameter/sub-parameter have only one copy, + # so there is not need to check the gradient consistent cross rank. + if group_size == runtime_ngpus: + return [rank], None + + from nnscaler.runtime.device import DeviceGroup + # make sure all needed device groups have been created to make safe + for i in range(group_size): + DeviceGroup().get_group( + list(range(i, runtime_ngpus, group_size)) + ) + rank_list = list(range(rank % group_size, runtime_ngpus, group_size)) + return rank_list, DeviceGroup().get_group(rank_list) + + rank_list, sync_group = get_optimizer_sync_group() + + if sync_group is None: + return + + params_info_for_gnorm = self.model.parameters_for_calc_gnorm() + tidx2param = {} + for r_idx, params_info in enumerate(params_info_for_gnorm): + for p_idx, param in enumerate(params_info.params): + # each param is the `Bucket._param_for_optimizer` of one of reducer's bucket, + # r_idx is the index of the reducer, p_idx is the index of the bucket. + tidx2param[(r_idx, p_idx)] = param + tidx2grad = {k: v.grad for k, v in sorted(tidx2param.items(), key=lambda item: item[0])} + + def get_grad_metric(grad: torch.Tensor): + mean, max, min, norm = grad.float().mean().item(), grad.max().item(), grad.min().item(), grad.float().norm().item() + return mean, max, min, norm + + # check gradient metric: (mean, max, min, norm) + tidx2metric = {k: get_grad_metric(v) for k, v in tidx2grad.items()} + tidx2ranks_metric = [None for _ in range(len(rank_list))] + torch.distributed.all_gather_object(tidx2ranks_metric, tidx2metric, group=sync_group) + + def is_consistent(m1, m2, delta=1e-6): + # refer to Fairseq's approach, we don't check the completely equal here + # to ignore the precision loss due to communication. + abs_diff = abs(m1 - m2) + return abs_diff / (abs(m1) + delta) < delta + + # check if all the gradient metric gathered from other rank is consistent with corrent rank + grad_consistent = True + for _tidx2metric in tidx2ranks_metric: + for tidx, metric in tidx2metric.items(): + check_result = [is_consistent(m1, m2) for m1, m2 in zip(_tidx2metric[tidx], metric)] + if not all(check_result): + grad_consistent = False + break + + if not grad_consistent: + pretty_detail = [] + header = "rank mean{:6} max{:7} min{:7} norm{:7}".format("", "", "", "") + line = "-" * 80 + for tidx, _ in tidx2metric.items(): + pretty_detail.extend([line, f"reducer {tidx[0]} bucket {tidx[1]}", line, header]) + for r, _tidx2metric in zip(rank_list, tidx2ranks_metric): + pretty_detail.append("{:4d} {:10.6f} {:10.6f} {:10.6f} {:10.6f}".format(r, *_tidx2metric[tidx])) + pretty_detail.append(line) + pretty_detail = "\n".join(pretty_detail) + + error_detail = "grad metric detail across the workers:\n{}\n".format(pretty_detail) + raise RuntimeError( + "Fatal error: gradients are inconsistent between workers. " + + "\n" + + "=" * 80 + + "\n{}\n".format(error_detail) + + "=" * 80 + ) + def _train(self): logger.info('Training...') # reset peak memory stats before training @@ -733,6 +821,10 @@ def _train_epoch(self, epoch): multiplier /= aggregated_outputs.num_tokens self.optimizer.scale_grads(multiplier) + # check gradient sync & scale correctness + if self.train_args.check_gradient_sync_cross_devices: + self._check_grad_cross_devices_correctness() + # clip gradients self.hook.before_gnorm_clip(self) if self.train_args.optimizer.clip_gnorm: diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 30af224d..3d43356a 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -225,6 +225,15 @@ class LogConfig: args: Dict[str, Any] = field(default_factory=dict) +@dataclass +class DebugConfig: + # before gradient clip norm, check the gradient sync for the same parameter is consistent cross devices, + # if ZeRO is enabled, will check the gradient cross each ZeRO group, + # if ZeRO is not enabled, will check the gradient cross each nnscaler scale unit. + # this helps to find bugs related to gradient updates during training. + check_gradient_sync_cross_devices: bool = True + + @dataclass class HookConfig: type: str = None @@ -315,6 +324,8 @@ class TrainerArgs: # It can be `HookConfig` or `HookMapConfig` hook: Union[HookConfig, HookMapConfig, None] = None + debug: DebugConfig = field(default_factory=DebugConfig) + # TODO: mixed precision support precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 69255a5b..58addd16 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -240,7 +240,8 @@ def parameters_for_calc_gnorm(self) -> List[ParamsInfo]: # we should use `parameters_for_optimizer` here since calculating gnorm # is ahead of the optimizer step. When ZeRO is enabled, each device only # maintains a subset of the parameters. As a result, `param_names` may not - # align with the value of `reducer.parameters_for_optimizer()`. + # align with the value of `reducer.parameters_for_optimizer()`, only part of + # parameters assigned to a bucket will be shown in `reducer.parameters_for_optimizer()`. params_info = ParamsInfo(reducer.ranks, reducer.parameters_for_optimizer(), param_names, reducer.zero_ngroups) params_info_for_gnorm.append(params_info) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 295e44f3..30bafde5 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -469,3 +469,44 @@ def _empty_train_args(): train_args.dataset.val_args = {} assert train_args.create_dataset() is not None assert train_args.create_dataset('val') is None + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4 or torch.cuda.device_count() >= 8, reason='lack of gpu devices') +@pytest.mark.parametrize('use_bf16', [True, False]) +@pytest.mark.parametrize('zero_ngroups', [None, '1', '2']) +def test_trainer_grad_sync_check_4gpu(tmp_path, use_bf16, zero_ngroups): + launch_torchrun(4, trainer_grad_sync_check, tmp_path, use_bf16, zero_ngroups, '4') + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 8, reason='lack of gpu devices') +@pytest.mark.parametrize('use_bf16', [True, False]) +@pytest.mark.parametrize('zero_ngroups', [None, '1', '2', '4']) +def test_trainer_grad_sync_check_8gpu(tmp_path, use_bf16, zero_ngroups): + launch_torchrun(8, trainer_grad_sync_check, tmp_path, use_bf16, zero_ngroups, '8') + + +def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + optimizer_type = 'torch.optim.Adam' + use_zero = False if zero_ngroups is None else True + zero_ngroups = '1' if zero_ngroups is None else zero_ngroups + + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16' if use_bf16 else 'none', + '--optimizer.type', optimizer_type, + '--max_epochs', '1', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', runtime_ngpus, + '--compute_config.use_zero', str(use_zero), + '--compute_config.zero_ngroups', zero_ngroups, + '--checkpoint.save_dir', str(ckpt_savedir), + '--debug.check_gradient_sync_cross_devices', 'true', + ]) + trainer.run() + torch.distributed.barrier() From 442e012e83a37eeb7090fa529cf9e936253b6b06 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 1 Nov 2024 05:29:32 +0000 Subject: [PATCH 1753/1892] Merged PR 2302: [Bugfix] grad check config fix the bug introduced in https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube/pullrequest/2295 --- nnscaler/cli/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 3d2797df..d84b88f4 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -822,7 +822,7 @@ def _train_epoch(self, epoch): self.optimizer.scale_grads(multiplier) # check gradient sync & scale correctness - if self.train_args.check_gradient_sync_cross_devices: + if self.train_args.debug.check_gradient_sync_cross_devices: self._check_grad_cross_devices_correctness() # clip gradients From b59b82f2da79a73c073cab9dc6ca8def03991314 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 1 Nov 2024 06:41:04 +0000 Subject: [PATCH 1754/1892] Merged PR 2299: refine tracer hierarchy unit test passed parity check passed --- examples/huggingface_nlp/compile_hf.py | 2 +- nnscaler/codegen/frontend_mapping.py | 2 +- nnscaler/graph/parser/__init__.py | 2 +- nnscaler/graph/parser/converter.py | 8 ++++---- nnscaler/graph/parser/{fx => }/mapping.py | 0 nnscaler/graph/parser/{fx => }/parser.py | 4 ++-- nnscaler/graph/parser/register.py | 2 +- .../fx/concrete_trace_utils => tracer}/__init__.py | 0 .../{parser/fx/concrete_trace_utils => tracer}/_pytree.py | 0 .../fx/concrete_trace_utils => tracer}/concrete_proxy.py | 0 .../fx/concrete_trace_utils => tracer}/concrete_tracer.py | 0 .../fx/concrete_trace_utils => tracer}/frame_utils.py | 0 .../concrete_trace_utils => tracer}/function_patcher.py | 0 .../fx/concrete_trace_utils => tracer}/metadata.py | 0 .../concrete_trace_utils => tracer}/operator_patcher.py | 0 .../fx/concrete_trace_utils => tracer}/orig_func.py | 0 .../fx/concrete_trace_utils => tracer}/pytree_utils.py | 2 +- .../concrete_trace_utils => tracer}/torch_fx_patcher.py | 0 .../fx/concrete_trace_utils => tracer}/trace_strategy.py | 0 .../fx/concrete_trace_utils => tracer}/wrap_utils.py | 2 +- nnscaler/parallel.py | 2 +- nnscaler/runtime/module.py | 2 +- tests/conftest.py | 2 +- tests/graph/parser/test_ast_transformer.py | 2 +- tests/graph/parser/test_converter.py | 4 ++-- tests/graph/tracer/test_dis.py | 2 +- tests/graph/tracer/test_inplace.py | 2 +- tests/graph/tracer/test_op_patcher.py | 2 +- tests/graph/tracer/test_pytree.py | 8 ++++---- tests/ir/test_cten.py | 2 +- tests/parallel_module/test_override.py | 2 +- tests/parallel_module/test_pyfunc.py | 2 +- tests/utils.py | 2 +- 33 files changed, 29 insertions(+), 29 deletions(-) rename nnscaler/graph/parser/{fx => }/mapping.py (100%) rename nnscaler/graph/parser/{fx => }/parser.py (99%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/__init__.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/_pytree.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/concrete_proxy.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/concrete_tracer.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/frame_utils.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/function_patcher.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/metadata.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/operator_patcher.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/orig_func.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/pytree_utils.py (98%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/torch_fx_patcher.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/trace_strategy.py (100%) rename nnscaler/graph/{parser/fx/concrete_trace_utils => tracer}/wrap_utils.py (99%) diff --git a/examples/huggingface_nlp/compile_hf.py b/examples/huggingface_nlp/compile_hf.py index 08c2f39f..376a27ea 100644 --- a/examples/huggingface_nlp/compile_hf.py +++ b/examples/huggingface_nlp/compile_hf.py @@ -431,7 +431,7 @@ def parse_arguments() -> argparse.Namespace: # add model name to FxModuleParser log fxparser_warning_path = args.log_dir / FXMODULE_PARSER_WARNING_FNAME file_handler = logging.FileHandler(fxparser_warning_path) - from nnscaler.graph.parser.fx.parser import _logger + from nnscaler.graph.parser.parser import _logger _logger.addHandler(file_handler) _logger.warning(f"\n{args.model_name}") diff --git a/nnscaler/codegen/frontend_mapping.py b/nnscaler/codegen/frontend_mapping.py index 7eb362b4..a64845be 100644 --- a/nnscaler/codegen/frontend_mapping.py +++ b/nnscaler/codegen/frontend_mapping.py @@ -11,7 +11,7 @@ from nnscaler.ir.operator import IRFwOperation from nnscaler.graph.parser.register import CustomizedOps -from nnscaler.graph.parser.fx.parser import SELF_GETATTR_SIG +from nnscaler.graph.parser.parser import SELF_GETATTR_SIG class Sign2EmitRule: diff --git a/nnscaler/graph/parser/__init__.py b/nnscaler/graph/parser/__init__.py index 8b025b52..1dea36e7 100644 --- a/nnscaler/graph/parser/__init__.py +++ b/nnscaler/graph/parser/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.parser.parser import FxModuleParser from nnscaler.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph from nnscaler.graph.parser.register import register from nnscaler.graph.parser.external import * diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 4a2b3251..308fff73 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -12,10 +12,10 @@ from nnscaler.graph import IRGraph from nnscaler.flags import CompileFlag -from nnscaler.graph.parser.fx.parser import FxModuleParser -from nnscaler.graph.parser.fx.concrete_trace_utils import concrete_trace -from nnscaler.graph.parser.fx.concrete_trace_utils.wrap_utils import Location, is_autograd_apply, LeafWrapInfo -from nnscaler.graph.parser.fx.concrete_trace_utils.torch_fx_patcher import side_effectful_inplace_ops +from nnscaler.graph.parser import FxModuleParser +from nnscaler.graph.tracer import concrete_trace +from nnscaler.graph.tracer.wrap_utils import Location, is_autograd_apply, LeafWrapInfo +from nnscaler.graph.tracer.torch_fx_patcher import side_effectful_inplace_ops import nnscaler.runtime.function as cube_rt_function diff --git a/nnscaler/graph/parser/fx/mapping.py b/nnscaler/graph/parser/mapping.py similarity index 100% rename from nnscaler/graph/parser/fx/mapping.py rename to nnscaler/graph/parser/mapping.py diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/parser.py similarity index 99% rename from nnscaler/graph/parser/fx/parser.py rename to nnscaler/graph/parser/parser.py index 96db3ef1..751aacf1 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -11,13 +11,13 @@ from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.cten import IRObject, IRCell, IRTensor from nnscaler.graph.parser.frame import Frame -from nnscaler.graph.parser.fx.mapping import SignFx2Op +from nnscaler.graph.parser.mapping import SignFx2Op from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import IRDimops from nnscaler.graph.function.function import any_ir_object_satisfy +from nnscaler.graph.tracer import TensorMetadata, DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE import torch.fx -from .concrete_trace_utils import TensorMetadata, DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE _logger = logging.getLogger(__name__) diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index b8b2923e..7d9c1661 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -14,7 +14,7 @@ from torch import ScriptFunction from nnscaler.graph.function.dimops import IRDimops, OpAnno, TransformRule -from nnscaler.graph.parser.fx.concrete_trace_utils.wrap_utils import is_autograd_apply +from nnscaler.graph.tracer.wrap_utils import is_autograd_apply from nnscaler.ir.operator import IRTensor, IRFwOperation _logger = logging.getLogger(__name__) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py b/nnscaler/graph/tracer/__init__.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py rename to nnscaler/graph/tracer/__init__.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py b/nnscaler/graph/tracer/_pytree.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py rename to nnscaler/graph/tracer/_pytree.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/tracer/concrete_proxy.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py rename to nnscaler/graph/tracer/concrete_proxy.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py rename to nnscaler/graph/tracer/concrete_tracer.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py b/nnscaler/graph/tracer/frame_utils.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py rename to nnscaler/graph/tracer/frame_utils.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py b/nnscaler/graph/tracer/function_patcher.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py rename to nnscaler/graph/tracer/function_patcher.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py b/nnscaler/graph/tracer/metadata.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py rename to nnscaler/graph/tracer/metadata.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/nnscaler/graph/tracer/operator_patcher.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py rename to nnscaler/graph/tracer/operator_patcher.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py b/nnscaler/graph/tracer/orig_func.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py rename to nnscaler/graph/tracer/orig_func.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py b/nnscaler/graph/tracer/pytree_utils.py similarity index 98% rename from nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py rename to nnscaler/graph/tracer/pytree_utils.py index 55dd19e5..5fc83be8 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py +++ b/nnscaler/graph/tracer/pytree_utils.py @@ -13,7 +13,7 @@ from . import orig_func, _pytree from ._pytree import * -import nnscaler.graph.parser.fx.concrete_trace_utils.concrete_proxy as cct +import nnscaler.graph.tracer as cct # if pytree is a ConcreteProxy, type(pytree) will return the type of ConcreteProxy.value diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py b/nnscaler/graph/tracer/torch_fx_patcher.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py rename to nnscaler/graph/tracer/torch_fx_patcher.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py b/nnscaler/graph/tracer/trace_strategy.py similarity index 100% rename from nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py rename to nnscaler/graph/tracer/trace_strategy.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py b/nnscaler/graph/tracer/wrap_utils.py similarity index 99% rename from nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py rename to nnscaler/graph/tracer/wrap_utils.py index 808868dd..28f63905 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py +++ b/nnscaler/graph/tracer/wrap_utils.py @@ -14,7 +14,7 @@ import torch from torch.fx.proxy import Scope, ScopeContextManager -import nnscaler.graph.parser.fx.concrete_trace_utils as cct +import nnscaler.graph.tracer as cct from . import pytree_utils, orig_func, operator_patcher if TYPE_CHECKING: from .concrete_tracer import ConcreteTracer diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 86c3d5eb..2a6cb375 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -30,7 +30,7 @@ from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.wrapnn import convert_to_wrapnn, wrapnn from nnscaler.graph.gener.gen import IRAdapterGener -from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.parser import FxModuleParser from nnscaler.graph.schedule.predefined import PredefinedSched from nnscaler.graph.schedule.schedplan import SchedulePlan diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 58addd16..8c0b0403 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -12,7 +12,7 @@ import torch import torch.distributed as dist -from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.parser import FxModuleParser from nnscaler.runtime.device import DeviceGroup from nnscaler.runtime.adapter.reducer import Reducer diff --git a/tests/conftest.py b/tests/conftest.py index dbaa1cc9..e42de300 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import pytest from pathlib import Path -from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.parser import FxModuleParser @pytest.fixture(autouse=True) def clean_generated_files(): diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 00f1bbde..51330fd9 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -7,7 +7,7 @@ import pytest -from nnscaler.graph.parser.fx.concrete_trace_utils.operator_patcher import ( +from nnscaler.graph.tracer.operator_patcher import ( OperatorTransformer, SuperTransformer, ProxyCallTransformer, diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 04b24302..20f2bcff 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -98,8 +98,8 @@ def forward(self, x): module = MyModule() fx_graph = to_fx_graph(module, dummy_input) - from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_proxy import ConcreteProxy - from nnscaler.graph.parser.fx.concrete_trace_utils import TensorMetadata + from nnscaler.graph.tracer.concrete_proxy import ConcreteProxy + from nnscaler.graph.tracer import TensorMetadata for node in fx_graph.graph.nodes: # this assert is only for this simple model, all node should have TensorMetadata type 'tensor_meta' diff --git a/tests/graph/tracer/test_dis.py b/tests/graph/tracer/test_dis.py index c0a741b5..03b30cdd 100644 --- a/tests/graph/tracer/test_dis.py +++ b/tests/graph/tracer/test_dis.py @@ -5,7 +5,7 @@ import pytest -from nnscaler.graph.parser.fx.concrete_trace_utils.frame_utils import get_last_instruction, get_instructions +from nnscaler.graph.tracer.frame_utils import get_last_instruction, get_instructions class A: diff --git a/tests/graph/tracer/test_inplace.py b/tests/graph/tracer/test_inplace.py index b77980bc..a4222b87 100644 --- a/tests/graph/tracer/test_inplace.py +++ b/tests/graph/tracer/test_inplace.py @@ -6,7 +6,7 @@ import torch from nnscaler.graph.parser.converter import to_fx_graph -from nnscaler.graph.parser.fx.concrete_trace_utils.torch_fx_patcher import side_effectful_inplace_ops +from nnscaler.graph.tracer.torch_fx_patcher import side_effectful_inplace_ops import nnscaler.runtime.function as cube_rt_function from ...utils import replace_all_device_with diff --git a/tests/graph/tracer/test_op_patcher.py b/tests/graph/tracer/test_op_patcher.py index e1f436b3..1c60491c 100644 --- a/tests/graph/tracer/test_op_patcher.py +++ b/tests/graph/tracer/test_op_patcher.py @@ -3,7 +3,7 @@ import torch from types import MethodType -from nnscaler.graph.parser.fx.concrete_trace_utils.operator_patcher import OperatorPatcher +from nnscaler.graph.tracer.operator_patcher import OperatorPatcher def test_patch_func_or_module(): diff --git a/tests/graph/tracer/test_pytree.py b/tests/graph/tracer/test_pytree.py index 2f64a6f3..59d83095 100644 --- a/tests/graph/tracer/test_pytree.py +++ b/tests/graph/tracer/test_pytree.py @@ -1,13 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from nnscaler.graph.parser.fx.concrete_trace_utils import pytree_utils -from nnscaler.graph.parser.fx.concrete_trace_utils.pytree_utils import ( +from nnscaler.graph.tracer import pytree_utils +from nnscaler.graph.tracer.pytree_utils import ( get_common_spec, tree_leaves_with_spec, ) -from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_proxy import ConcreteProxy -from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import ( +from nnscaler.graph.tracer.concrete_proxy import ConcreteProxy +from nnscaler.graph.tracer.concrete_tracer import ( update_tree_proxy_value, ) diff --git a/tests/ir/test_cten.py b/tests/ir/test_cten.py index ab054861..9baba6db 100644 --- a/tests/ir/test_cten.py +++ b/tests/ir/test_cten.py @@ -7,7 +7,7 @@ from nnscaler.ir.cten import IRObject from nnscaler.ir.tensor import IRFullTensor, IRSubTensor -from nnscaler.graph.parser.fx.parser import TensorMetadata, DICT_VALUES_TYPE, DICT_ITEMS_TYPE +from nnscaler.graph.parser.parser import TensorMetadata, DICT_VALUES_TYPE, DICT_ITEMS_TYPE @pytest.mark.parametrize('tosub', [True, False]) diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index c7886353..3b24267c 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -9,7 +9,7 @@ import torch import shutil -from nnscaler.graph.parser.fx.parser import FxModuleParser +from nnscaler.graph.parser import FxModuleParser from nnscaler.parallel import ReuseType, parallelize, ComputeConfig, _load_parallel_module_class from nnscaler.runtime.module import ParallelModule diff --git a/tests/parallel_module/test_pyfunc.py b/tests/parallel_module/test_pyfunc.py index 29f90114..e30c1afe 100644 --- a/tests/parallel_module/test_pyfunc.py +++ b/tests/parallel_module/test_pyfunc.py @@ -43,7 +43,7 @@ def _worker(): init_distributed() dummy_input = {'x': torch.rand(2, 10)} - from nnscaler.graph.parser.fx.parser import _logger as _logger_parser + from nnscaler.graph.parser.parser import _logger as _logger_parser from nnscaler.graph.graph import _logger as _logger_graph from nnscaler.graph.segment import _logger as _logger_seg with tempfile.TemporaryDirectory() as tempdir, \ diff --git a/tests/utils.py b/tests/utils.py index 9cf5213d..17ca2659 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -108,7 +108,7 @@ def replace_all_device_with(device='cpu', force=False): yield return - from nnscaler.graph.parser.fx.concrete_trace_utils import wrap_utils + from nnscaler.graph.tracer import wrap_utils orig_to = torch.Tensor.to orig_cuda = torch.Tensor.cuda From aba5acad34e51e8c827b1c4e0c8a53d95e7c4871 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 5 Nov 2024 04:50:37 +0000 Subject: [PATCH 1755/1892] Merged PR 2298: [BugFix]: infer grad correctly Fix bug trigger by previous [PR](https://dev.azure.com/msrasrg/SuperScaler/_git/MagicCube/pullrequest/2207) To align with behavior of `backward` in `IRGraph`. We will not record gradient for IRPyFunc and node that doesn't have a output tensor that requires grad. --- nnscaler/graph/gener/gen.py | 23 +-- nnscaler/graph/graph.py | 9 +- nnscaler/graph/segment.py | 65 +++++-- nnscaler/ir/tensor.py | 2 + tests/graph/test_segment.py | 175 ++++++++++++++++++ tests/parallel_module/common.py | 10 - tests/parallel_module/test_attr_dedup.py | 3 +- tests/parallel_module/test_checkpoint.py | 4 +- .../parallel_module/test_checkpoint_buffer.py | 4 +- .../parallel_module/test_checkpoint_dedup.py | 3 +- .../parallel_module/test_checkpoint_shared.py | 3 +- .../parallel_module/test_checkpoint_unused.py | 3 +- tests/parallel_module/test_ddp.py | 3 +- tests/parallel_module/test_end2end.py | 3 +- .../test_end2end_mix_precision.py | 4 +- tests/parallel_module/test_inference.py | 3 +- tests/parallel_module/test_init.py | 4 +- tests/parallel_module/test_line_timer.py | 4 +- tests/parallel_module/test_reducer_hook.py | 3 +- tests/parallel_module/test_scale_grads.py | 3 +- tests/parallel_module/test_submodule.py | 3 +- tests/parallel_module/test_wholemodule.py | 3 +- tests/utils.py | 11 ++ 23 files changed, 281 insertions(+), 67 deletions(-) diff --git a/nnscaler/graph/gener/gen.py b/nnscaler/graph/gener/gen.py index 8184787e..14c79c05 100644 --- a/nnscaler/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -385,23 +385,6 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bctensors = bctensors + tuple(fwop.output(0).grad for fwop in input_producer[ftensor]) bctensors = expand_devices(bctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" - # special case for loss tensor: - # 1) Since loss is the output of the whole graph, we don't have a backward producer node for loss. - # Therefore, bptensors is empty for loss tensor. - # 2) We must make sure bptensors to be non-empty to generate correct communication prims. If bptensor - # is empty, grad communication (the backward adapter) will not be generated, so only forward 'all-reduce' - # will be used. As a result, the loss tensor's requires_grad will be set to False at runtime. - # 3) According to loss's semantics in current deep learning, the backward prim should be `identity`. When - # the loss tensor is partitioned along the value dimension, since it is reduced by `add` operation, it is - # safe to use `identity` as the backward prim. - # 4) To generated `identity`, we follow the implementation at activation -> graph/segment output below: create - # dummy producer tensor and assign device information. Note, it is equivalent to copy bptensors from bctensors. - if ftensor.is_loss() and ftensor.requires_grad: - assert len(bptensors) == 0, f'expect no backward producer for loss tensor {ftensor}, but got {bproducers} with {bptensors}' - assert ftensor in output_consumer, f'expect loss tensor {ftensor} in output_consumer' - bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) - bptensors = expand_devices(bptensors, producer=True) - fadapters = [] @@ -414,13 +397,15 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: # (activation -> graph/segment output) generation: generate communication adapters between # producer operators and graph/segment output tensors. Note graph/segment output tensors - # always require for full-shape/value for output, while consumers may partition them. Therefore, + # always require for full-shape/value for output, while producers may partition them. Therefore, # we need to additionally generate adapters for this case. if ftensor in output_consumer: out_fctensors = tuple(fwop.input(0) for fwop in output_consumer[ftensor]) out_fctensors = expand_devices(out_fctensors, consumer=True) - # dedup adapter if the output is same with activation tensor + out_bptensors = [t.grad for t in out_fctensors if isinstance(t, IRSubTensor)] + # skip if the output is same with activation tensor if set(out_fctensors) == set(fctensors) and \ + set(out_bptensors) == set(bptensors) and \ set(t.device[0] for t in out_fctensors) == set(t.device[0] for t in fctensors): pass else: diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index c3c920ff..803859f1 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -132,10 +132,15 @@ def backward(self, loss: Optional[IRSubTensor] = None): For operators that doesn't need backward, all gradients of their input/output tensors will make to None (despite require_grad is True) - @param loss IRSubTensor: the loss tensor, must be in the output + Note grad of input tensors of a IRPyFunc will be None and we will not + generate a backward node for IRPyFunc. + + Args: + loss (IRSubTensor): the loss tensor, must be in the output of current graph. The loss shape should be (1,) - @return self IRGraph: None + Returns: + self (IRGraph): updated graph with backward operators """ # set mirror as self self._mirror = self diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index ba2b122b..0bfd429f 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -376,25 +376,21 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: t.grad = grad # set for consumers - # We strictly follow the behavior in the fx graph. It is possible that there - # exists a node that consumes a tensor with gradient but generates tensors without - # gradient, e.g., the `.data` operation in torch. As a result, nnscaler will generate - # backward adapter (communications) between this consumer and its producer. - # According to the runtime behavior, we have - # case 1: there are gradients flowing in the consume full tensor. This case happens - # when the full tensor is the segment output at the same time. Note in nnscaler - # we will replicate segment's outputs and we will generate another adapter for - # the activation -> segment output case if the two adapters are different. As a - # result, the node (not matter IRDimOps or IRPyFunc, e.g., .data) should be replicated - # as well. In this case, the backward adapter is correct. - # case 2: no gradients exist, then the backward adapter does not influence the result. + # We strictly follow the `requires_grad` in the fx graph in most cases. However, we will + # ignore the gradient when the corresponding subtensor is consumed by a IRPyFunc, since + # nnScaler will not generate a backward node for IRPyFunc currently (check IRGraph.backward). consumers, ctensors = [], [] for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): itensors = consumer.find(ctensor) # set by default None for itensor in itensors: itensor.grad = None - if fgrad is not None: + if isinstance(consumer, IRPyFunc): + continue + # filter out non-autograd operators + if fgrad is None: continue + if isinstance(consumer, IRPyFunc): continue + if any(isinstance(t, IRSubTensor) and t.requires_grad for t in consumer.outputs()): consumers.append(consumer) ctensors.append(ctensor) @@ -1022,8 +1018,47 @@ def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> I inputs.add(itensor) for otensor in node.oobjs(): # if the tensor is required by segment outputs, set as output - if otensor in segment_outputs: - outputs.add(otensor) + # NOTE: there may be several identical tensors in `segment_outputs`, for example when + # user code contains `return t, t`. To handle this case, we will break the loop once + # finding a matched tensor. + seg_out_matched = False + for seg_out in segment_outputs: + if otensor != seg_out: + continue + # Since we set a consume tensor's grad to None if it is a input of a IRPyFunc, + # it is possible that two tensors share a same id but with different grad. + # For example, consider the following case: + # t1 = dimops(xx) + # t2 = pyfunc(t1) + # return t1, t2 + # Furtherly assume: + # - t1 requires grad + # - `dimops` is partitioned along `xx`'s batch dim + # - `pyfunc` is replicated. + # In nnscaler, there is one fulltensor representing `t1` and several subtensors. + # For simplicity, we name them separately: + # - when t1 is at the output of `dimops`, it is called `t1_a` + # - when t1 is at the input of `pyfunc`, it is called `t1_b` + # - when t1 is the output of the segment, it is called `t1_c` + # In this case, `t1_b`'s grad is set to None (check func `infer_grad`), but the + # the `t1_a` and `t1_c`'s grads are not None. The three subtensors share the same + # id inherited from the fulltensor, which means they are considered as the same + # tensor when calling `__equal__` method. + # Since partition plans for operators are different, there will be two adapters + # in communication generation: + # - between `t1_a` and the returned `t1_c`, this adapter's forward + # output subtensor's grad is not None and share same id with `t1_c`. + # - between `t1_a` and the consumed `t1_b`, this adapter doesn't have + # a mirror (backward op) and its output subtensor's grad is None. Its id is + # the identical to `t1_b` as well. + # `create_segment` is called after `gen_activation`, and we construct the segment's + # output here. However, we don't want to treat `t1_b` as output. To distinguish between + # output tensors of the two adapters, we double check the grad here. + if not isinstance(otensor, IRSubTensor) or otensor.grad == seg_out.grad: + outputs.add(otensor) + seg_out_matched = True + break + if seg_out_matched: continue consumers, ctensors = self.consumers(otensor.parent), self.ctensors(otensor.parent) cids = set(c.cid for c, t in zip(consumers, ctensors) if dmatch(t, otensor)) diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index b385209b..6546720f 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -518,6 +518,8 @@ def __init__(self, ftensor: IRFullTensor, def __eq__(self, other) -> bool: if isinstance(other, IRSubTensor): return self._id == other._id + else: + return False def __hash__(self) -> int: return self._id diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py index eb9c8821..3024faad 100644 --- a/tests/graph/test_segment.py +++ b/tests/graph/test_segment.py @@ -1,11 +1,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import torch +import pytest +import torch.nn as nn +import tempfile +import shutil +import contextlib +from pathlib import Path + + import nnscaler import nnscaler.graph.function.function as F from nnscaler.ir.tensor import IRFullTensor from nnscaler.graph import IRGraph from nnscaler.ir.adapter import IRAdapter +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.ir.operator import IRFwOperation, IRDataOperation +from tests.parallel_module.test_gencode import _gencode_contains +from ..utils import replace_all_device_with, clear_dir_on_rank0, init_random +from ..launch_torchrun import torchrun def _tensor(shape, requires_grad=True): @@ -60,3 +74,164 @@ def test_create_segment_loss_adapter(): print(segment.extra_repr()) assert len(segment.outputs()) == 1 assert segment.output(0) == loss + + +class ModelA(nn.Module): + + def __init__(self): + super(ModelA, self).__init__() + self.fc = nn.Linear(8, 8) + + def forward(self, q): + q = self.fc(q) + q = q.reshape(q.size(0), q.size(1) * q.size(2), -1) + q = q.transpose(0, 1) + l = q.sum() + return l, l.data + + +def policy_transpose(graph: IRGraph, resource: ComputeConfig) -> IRGraph: + ngpus = resource.plan_ngpus + for _, node in enumerate(graph.select(ntype=IRFwOperation)): + print(node.signature) + if node.signature in ["torch.transpose"]: + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=0, dim=0, num=ngpus) + else: + sub_nodes = graph.replicate(node, times=ngpus) + + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + for node in graph.select(ntype=IRDataOperation): + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + + return graph + + +def worker_a(): + nnscaler.init() + init_random() + m = ModelA() + m.train() + trace_data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) + data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_infer_grad_pyfunc') as tempdir: + pm = parallelize( + m, + {'q': trace_data,}, + policy_transpose, + ComputeConfig(2, 2, use_end2end=True), + gen_savedir=tempdir, + reuse='override', + ) + pm.to('cuda') + ret = pm.train_step((data,)) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_infer_grad_pyfunc(): + torchrun(2, worker_a) + # should not raise any exception + assert True + + +def func(x: torch.Tensor) -> torch.Tensor: + return x.detach().clone() + +nnscaler.register_op('* -> *')(func) + +class ModelB(nn.Module): + + def __init__(self): + super(ModelB, self).__init__() + self.fc1 = nn.Linear(8, 8, bias=False) + + def forward(self, q): + q = self.fc1(q) + k = func(q) + l = q.sum() + k.sum() + return l, l + + +def policy_nograd(graph: IRGraph, cfg: ComputeConfig) -> IRGraph: + ngpus = cfg.plan_ngpus + # print(graph.nodes()) + if cfg.use_end2end: + fc1_node = graph.nodes()[1] + func_node = graph.nodes()[2] + else: + fc1_node = graph.nodes()[0] + func_node = graph.nodes()[1] + assert fc1_node.inputs()[0].requires_grad and fc1_node.inputs()[0].grad + assert fc1_node.inputs()[1].requires_grad and fc1_node.inputs()[1].grad + assert fc1_node.outputs()[0].requires_grad and fc1_node.outputs()[0].grad + assert func_node.inputs()[0].requires_grad and not func_node.inputs()[0].grad + assert not func_node.outputs()[0].requires_grad and not func_node.outputs()[0].grad + # add multiref since consumers of fc1's output may in different partition states + # without it generated adapters are wrong + graph.multiref(fc1_node.outputs()[0].parent) + + for _, node in enumerate(graph.select(ntype=IRFwOperation)): + # print(node.signature) + if node.signature == 'torch.nn.functional.linear': + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=0, dim=0, num=ngpus) + elif node.signature == 'torch.sum': + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=0, dim=1, num=ngpus) + elif 'func' in node.signature: + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=0, dim=0, num=ngpus) + else: + sub_nodes = graph.replicate(node, times=ngpus) + + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + + for node in graph.select(ntype=IRDataOperation): + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + return graph + + +def worker_b(use_end2end): + nnscaler.init() + m = ModelB() + m.train() + init_random() + trace_data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) + data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_infer_grad_no_grad') as tempdir: + pm = parallelize( + m, + {'q': trace_data,}, + policy_nograd, + ComputeConfig(2, 2, use_end2end=use_end2end), + gen_savedir=tempdir, + reuse='override', + ) + # adapter between q to q.sum() + assert len(_gencode_contains(tempdir, ModelB, pm.rank, 'nnscaler.runtime.adapter.nn.alltoall_alltoall')) == 1 + # adapter between q to func(q) + assert len(_gencode_contains(tempdir, ModelB, pm.rank, 'nnscaler.runtime.adapter.all_to_all')) == 1 + # adapter between q.sum() to add + assert len(_gencode_contains(tempdir, ModelB, pm.rank, 'nnscaler.runtime.adapter.nn.allreduce_identity')) == 1 + # adapter between k.sum() to add + assert len(_gencode_contains(tempdir, ModelB, pm.rank, 'nnscaler.runtime.adapter.all_reduce')) == 1 + + pm.to('cuda') + if use_end2end: + ret = pm.train_step((data,)) + else: + ret = pm.forward(data) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('use_end2end', [True, False]) +def test_infer_grad_no_grad(use_end2end): + torchrun(2, worker_b, use_end2end) + # should not raise any exception + assert True diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index fa8900cd..0cfac768 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -114,16 +114,6 @@ def init_distributed(): torch.set_default_device(f'cuda:{rank}') -@contextlib.contextmanager -def clear_dir_on_rank0(tempdir): - if torch.distributed.get_rank() == 0 and tempdir.exists(): - shutil.rmtree(tempdir) - yield tempdir - torch.distributed.barrier() - if torch.distributed.get_rank() == 0 and tempdir.exists(): - shutil.rmtree(tempdir) - - def assert_equal(a: Any, b: Any): assert type(a) == type(b) if isinstance(a, torch.Tensor): diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py index b89ce66a..64bf490a 100644 --- a/tests/parallel_module/test_attr_dedup.py +++ b/tests/parallel_module/test_attr_dedup.py @@ -18,8 +18,9 @@ from nnscaler.policies import _tp, _replica from nnscaler.runtime.module import dedup_attrs -from .common import init_distributed, clear_dir_on_rank0, assert_equal +from .common import init_distributed, assert_equal from ..launch_torchrun import launch_torchrun +from ..utils import clear_dir_on_rank0 class Net(torch.nn.Module): def __init__(self): diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 34449e6d..f81514b9 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -21,9 +21,9 @@ from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0, PASMegatron +from .common import CubeLinear, init_random, init_distributed, PASMegatron from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import replace_all_device_with +from ..utils import replace_all_device_with, clear_dir_on_rank0 class FcRelu(nn.Module): diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index 5c93b4d9..cde6f864 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -10,9 +10,9 @@ from nnscaler.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dict, broadcast_weights -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun -from ..utils import catch_log +from ..utils import catch_log, clear_dir_on_rank0 class Net1(torch.nn.Module): diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index 944685bb..d0f5ebae 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -14,9 +14,10 @@ deduped_state_dict, load_deduped_state_dict from nnscaler.runtime.module import ParallelModule -from .common import PASMegatron, CubeLinear, init_random, init_distributed, clear_dir_on_rank0, assert_equal +from .common import PASMegatron, CubeLinear, init_random, init_distributed, assert_equal from ..launch_torchrun import launch_torchrun from .test_checkpoint import gendata, train_step, End2EndMLP, End2EndMLPWithUnusedAndShared +from ..utils import clear_dir_on_rank0 class FcRelu(nn.Module): diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index 0fb417fa..8d71a2f8 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -11,9 +11,10 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dict -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun from .test_checkpoint import End2EndMLP, train_step, gendata +from ..utils import clear_dir_on_rank0 class FcReluWithShared(nn.Module): diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py index 3b9fa85c..ca8ea820 100644 --- a/tests/parallel_module/test_checkpoint_unused.py +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -21,9 +21,10 @@ from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively from .test_checkpoint_shared import _train_raw, _load_merged +from ..utils import clear_dir_on_rank0 class FcReluWithUnused(nn.Module): diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index bff5c315..b43e0d20 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -21,8 +21,9 @@ from nnscaler.runtime.module import ParallelModule from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively +from ..utils import clear_dir_on_rank0 class FcRelu(nn.Module): diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index 13fb20aa..c38d4403 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -21,8 +21,9 @@ from nnscaler.runtime.utils import microbatches from nnscaler.runtime.module import ParallelModule from nnscaler.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts -from .common import assert_equal, clear_dir_on_rank0, init_distributed, PASMegatron, init_random +from .common import assert_equal, init_distributed, PASMegatron, init_random from ..launch_torchrun import clone_to_cpu_recursively, launch_torchrun +from ..utils import replace_all_device_with, clear_dir_on_rank0 from .test_checkpoint import End2EndMLP diff --git a/tests/parallel_module/test_end2end_mix_precision.py b/tests/parallel_module/test_end2end_mix_precision.py index 02d54dbc..aa84dfd9 100644 --- a/tests/parallel_module/test_end2end_mix_precision.py +++ b/tests/parallel_module/test_end2end_mix_precision.py @@ -18,12 +18,12 @@ from nnscaler.runtime.utils import microbatches from nnscaler.runtime.module import ParallelModule from nnscaler.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts -from .common import assert_equal, clear_dir_on_rank0, init_distributed, PASMegatron, init_random +from .common import assert_equal, init_distributed, PASMegatron, init_random from ..launch_torchrun import clone_to_cpu_recursively, launch_torchrun from .test_checkpoint import End2EndMLP from .test_end2end import allclose, merge_cube_result -from ..utils import init_parameter +from ..utils import init_parameter, clear_dir_on_rank0 DATA_SIZE = 16 diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index fbba994c..6b15d8a1 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -11,8 +11,9 @@ from nnscaler.parallel import ComputeConfig, parallelize -from .common import CubeLinear, init_distributed, init_random, clear_dir_on_rank0 +from .common import CubeLinear, init_distributed, init_random from ..launch_torchrun import torchrun +from ..utils import clear_dir_on_rank0 class FcRelu(nn.Module): diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 4f191e81..d2917526 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -10,8 +10,8 @@ from nnscaler.parallel import _load_parallel_module_class, parallelize, ComputeConfig from ..launch_torchrun import launch_torchrun -from .common import CubeLinear, init_distributed, init_random, clear_dir_on_rank0 -from ..utils import new_empty, replace_all_device_with, mock_dist +from .common import CubeLinear, init_distributed, init_random +from ..utils import new_empty, replace_all_device_with, mock_dist, clear_dir_on_rank0 class MyModule(torch.nn.Module): def __init__(self): diff --git a/tests/parallel_module/test_line_timer.py b/tests/parallel_module/test_line_timer.py index e1b8a0d8..90483848 100644 --- a/tests/parallel_module/test_line_timer.py +++ b/tests/parallel_module/test_line_timer.py @@ -11,9 +11,9 @@ from nnscaler.parallel import parallelize, ComputeConfig from nnscaler.flags import CompileFlag -from .common import init_distributed, clear_dir_on_rank0 +from .common import init_distributed from ..launch_torchrun import launch_torchrun -from ..utils import catch_stdout +from ..utils import catch_stdout, clear_dir_on_rank0 class Net(torch.nn.Module): diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index 6aaf5ca6..851dfad0 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -12,8 +12,9 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.runtime.module import ParallelModule -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively +from ..utils import clear_dir_on_rank0 class FcRelu(nn.Module): diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py index c10d6f78..255869cf 100644 --- a/tests/parallel_module/test_scale_grads.py +++ b/tests/parallel_module/test_scale_grads.py @@ -21,8 +21,9 @@ from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively +from ..utils import clear_dir_on_rank0 class FcRelu(nn.Module): diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 20f72f86..047f9f5b 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -15,8 +15,9 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.runtime.module import ParallelModule -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively +from ..utils import clear_dir_on_rank0 class FcRelu(nn.Module): diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index a0e1edf4..64ec5ff0 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -15,8 +15,9 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.runtime.module import ParallelModule -from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 +from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively +from ..utils import clear_dir_on_rank0 class FcRelu(nn.Module): diff --git a/tests/utils.py b/tests/utils.py index 17ca2659..272b02c2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,6 +12,7 @@ import random from datetime import timedelta from pathlib import Path +import shutil import numpy as np @@ -366,3 +367,13 @@ def catch_stdout(): sys.stdout = string_stream yield string_stream sys.stdout = old + + +@contextmanager +def clear_dir_on_rank0(tempdir): + if torch.distributed.get_rank() == 0 and tempdir.exists(): + shutil.rmtree(tempdir) + yield tempdir + torch.distributed.barrier() + if torch.distributed.get_rank() == 0 and tempdir.exists(): + shutil.rmtree(tempdir) From 51dd454cd4a26fdf00778e52c13bfa81643c38c7 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 11 Nov 2024 06:31:22 +0000 Subject: [PATCH 1756/1892] Merged PR 2304: Transformer (vit) example Transformer (vit) example --- docs/source/trainer.md | 2 +- examples/llama3_demo/.gitignore | 4 + examples/nanogpt/train_cli.py | 2 +- examples/vit/.gitignore | 4 + examples/vit/README.md | 43 ++++++ examples/vit/__init__.py | 0 examples/vit/requirements.txt | 5 + examples/vit/train_cli_args.yaml | 56 +++++++ examples/vit/vit_cli.py | 258 +++++++++++++++++++++++++++++++ nnscaler/cli/train.py | 7 +- nnscaler/cli/trainer.py | 7 +- nnscaler/cli/trainer_args.py | 7 +- nnscaler/codegen/emit.py | 24 ++- nnscaler/parallel.py | 2 +- nnscaler/version.py | 2 +- pyproject.toml | 3 + 16 files changed, 409 insertions(+), 17 deletions(-) create mode 100644 examples/llama3_demo/.gitignore create mode 100644 examples/vit/.gitignore create mode 100644 examples/vit/README.md create mode 100644 examples/vit/__init__.py create mode 100644 examples/vit/requirements.txt create mode 100644 examples/vit/train_cli_args.yaml create mode 100644 examples/vit/vit_cli.py diff --git a/docs/source/trainer.md b/docs/source/trainer.md index 78e2acda..d51db712 100644 --- a/docs/source/trainer.md +++ b/docs/source/trainer.md @@ -414,7 +414,7 @@ you can run `compile` mode without `torchrun`. - `val_every_n_epochs` (`Optional[int]`): Validate every `val_every_n_epochs` epochs. Default is `1`. - `enable_progress_bar` (`bool`): Whether to enable the progress bar. Default is `True`. - `seed` (`Optional[int]`): The random seed. Default is `None`. -- `init_env_fn` (`str`): The function to initialize the environment. Default is `None`. +- `init_env_fn` (`str`): The function to initialize the environment. Its only input is `Trainer`. Default is `None`. ## CLI diff --git a/examples/llama3_demo/.gitignore b/examples/llama3_demo/.gitignore new file mode 100644 index 00000000..ea0ff54d --- /dev/null +++ b/examples/llama3_demo/.gitignore @@ -0,0 +1,4 @@ +runs*/ +events.* +.nnscaler*/ +bookcorpus-4096*/ diff --git a/examples/nanogpt/train_cli.py b/examples/nanogpt/train_cli.py index 95c8627a..04586397 100644 --- a/examples/nanogpt/train_cli.py +++ b/examples/nanogpt/train_cli.py @@ -34,7 +34,7 @@ from model import GPTConfig, GPT -def init_env(train_args: 'TrainerArgs'): +def init_env(trainer: 'Trainer'): torch.manual_seed(0) np.random.seed(0) random.seed(0) diff --git a/examples/vit/.gitignore b/examples/vit/.gitignore new file mode 100644 index 00000000..affaef81 --- /dev/null +++ b/examples/vit/.gitignore @@ -0,0 +1,4 @@ +checkpoints/ +logs/ +test-cifar-10/ +wandb/ diff --git a/examples/vit/README.md b/examples/vit/README.md new file mode 100644 index 00000000..96dd2ed5 --- /dev/null +++ b/examples/vit/README.md @@ -0,0 +1,43 @@ +# Introduction + +This example demonstrates how to use nnscaler to fine-tuning a transformer model. Here we use ViT as an example. + +# Requirements + +To run this example, you need to install the packages listed in the `requirements.txt` file. You can install them by running the following command: + +```bash +pip install -r requirements.txt +``` + +*nnScaler* is a framework for distributed training by automatically partitioning the model. Apart from the core nnScaler library, it also includes a mini-trainer for modern model training. You can find related documents and examples at [nnScaler](https://github.com/microsoft/nnscaler) + +*transformers* and *datasets* are required to prepare the data and loading the model. + +The implementation is inspired by [here](https://medium.com/@supersjgk/fine-tuning-vision-transformer-with-hugging-face-and-pytorch-df19839d5396). Many thanks to the author. + + +## Run + +First go to `examples/vit` directory, You can use the following command to run the example: + +1. Use transformer.train() to train the model + - `python examples/vit/vit_cli.py`: will use `DataParallel` to train the model. + It will utilize all your GPUs in current node. + You can specify the GPUs with `CUDA_VISIBLE_DEVICES` env variable. + - `torchrun --nproc_per_node= --nnodes= examples/vit/vit_cli.py`: will use `DistributedDataParallel` to train the model. + +2. Use nnscaler to train the model + `torchrun --nproc_per_node= --nnodes= $(which nnscaler-train) -f train_cli_args.yaml` + +In order to be consistent with `transformers.train()`, +we use dataloader/scheduler from `transformers`. +See `accelerator_dataloader_fn` and `scheduler_fn` functions in the code for details. +If you don't need to be consistent with `transformers`, you can just use your own dataloader/scheduler. + +Please note `nnscaler` only supports 1 optimizer parameter group for now. So we also disable multiple parameter groups in the code when you use `transformers.train()`. +See `SingleParamGroupTrainer` in the code for details. + +The loss of `nnscaler` will be exactly the same as `transformers` given gnorm clipping is disabled. +When gnorm clipping is enabled, the loss will be slightly different +due to the difference in gnorm calculation. diff --git a/examples/vit/__init__.py b/examples/vit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/vit/requirements.txt b/examples/vit/requirements.txt new file mode 100644 index 00000000..b8e0b879 --- /dev/null +++ b/examples/vit/requirements.txt @@ -0,0 +1,5 @@ +transformers +datasets +transformers[torch] +torch +torchvision diff --git a/examples/vit/train_cli_args.yaml b/examples/vit/train_cli_args.yaml new file mode 100644 index 00000000..2d269f99 --- /dev/null +++ b/examples/vit/train_cli_args.yaml @@ -0,0 +1,56 @@ +compute_config: + plan_ngpus: 1 + constant_folding: false + use_zero: true + use_end2end: true + +init_env_fn: examples.vit.vit_cli.init_env +run_mode: run +pas_policy: dp +micro_batch_size: 10 +grad_accumulation_steps: 1 +max_epochs: 3 +enable_progress_bar: true +# precision: bf16 + +model: + type: examples.vit.vit_cli.VModel + +optimizer: + type: torch.optim.AdamW + args: + lr: 2e-5 + weight_decay: 0.01 + clip_gnorm: 1.0 + +lr_scheduler: + type: examples.vit.vit_cli.scheduler_fn + args: + num_warmup_steps: 0 + interval: step + +dataset: + type: examples.vit.vit_cli.cifar10_dataset + train_args: + split: train + val_args: + split: val + +dataloader: + type: examples.vit.vit_cli.accelerator_dataloader_fn + train_args: + collate_fn: examples.vit.vit_cli.cifar10_collate_fn + drop_last: false + +checkpoint: + keep_last_n_checkpoints: 10 + every_n_epochs: 1 + save_type: deduped + +log: + - type: nnscaler.cli.loggers.TensorBoardLogger + args: + name: logs/tb + root_dir: . +# hook: +# on_train_step_end: examples.vit.vit_cli.on_train_step_end diff --git a/examples/vit/vit_cli.py b/examples/vit/vit_cli.py new file mode 100644 index 00000000..92bb5e08 --- /dev/null +++ b/examples/vit/vit_cli.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# ref: https://medium.com/@supersjgk/fine-tuning-vision-transformer-with-hugging-face-and-pytorch-df19839d5396 + +""" +Run example: + +First go to `examples/vit` directory, Then + +1. Use transformer.train() to train the model + a. `python examples/vit/vit_cli.py`: will use `dp` to train the model. + It will utilize all your GPUs in current node. + You can specify the GPUs with `CUDA_VISIBLE_DEVICES` env variable. + b. `torchrun --nproc_per_node= --nnodes= examples/vit/vit_cli.py`: will use `ddp` to train the model. +2. Use nnscaler to train the model + `torchrun --nproc_per_node= --nnodes= $(which nnscaler-train) -f train_cli_args.yaml` + +Here in order to be consistent with transformers, +we use dataloader/scheduler from transformers. +See `accelerator_dataloader_fn` and `scheduler_fn` functions below for details. +If you don't need to be consistent with transformers, you can use your own dataloader/scheduler. + +Please note `nnscaler` only supports 1 optimizer parameter group for now. So we also disable multiple parameter groups in the code when you use `transformers.train()`. +See `SingleParamGroupTrainer` in the code for details. + +The loss of nnscaler will be exactly the same as transformers given gnorm clipping is disabled. +When gnorm clipping is enabled, the loss will be slightly different +due to the difference in gnorm calculation. + +""" + +import random +from typing import TYPE_CHECKING +import time +import os + +from datasets import load_dataset + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Normalize, Resize, ToTensor, Compose +from transformers import ViTImageProcessor, ViTForImageClassification, get_linear_schedule_with_warmup + +import nnscaler + +if TYPE_CHECKING: + from nnscaler.cli.trainer import Trainer + from nnscaler.cli.trainer_args import TrainerArgs + + +VIT_MODEL_NAME = "google/vit-base-patch16-224" + + +_trainer: 'Trainer' = None + + +def init_env(trainer: 'Trainer'): + global _trainer + # save trainer for later use (e.g. to get max_train_steps) + _trainer = trainer + torch.manual_seed(0) + np.random.seed(0) + random.seed(0) + if os.environ.get('DETERMINISTIC') is not None: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + + if int(os.environ.get('NNSCALER_DEBUG', 0)): + import debugpy + # 5678 is the default attach port in the VS Code debug configurations. + # Unless a host and port are specified, host defaults to 127.0.0.1 + # see https://code.visualstudio.com/docs/python/debugging for more details + debugpy.listen(5678) + print("Waiting for debugger attach") + debugpy.wait_for_client() + debugpy.breakpoint() + print('Resume on this line') + + +def on_train_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: + if torch.distributed.get_rank() == 0: + print(f'# train_loss {idx:03d}', outputs[0].item()) + + +def on_val_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: + if torch.distributed.get_rank() == 0: + print(f'# val_loss {idx:03d}', outputs[0].item()) + + +datasets = None +itos = None +stoi = None +vit_processor = None + + +def init_dataset(): + global datasets, itos, stoi, vit_processor + + vit_processor = ViTImageProcessor.from_pretrained(VIT_MODEL_NAME) + + mu, sigma = vit_processor.image_mean, vit_processor.image_std #get default mu,sigma + size = vit_processor.size + + norm = Normalize(mean=mu, std=sigma) #normalize image pixels range to [-1,1] + + # resize 3x32x32 to 3x224x224 -> convert to Pytorch tensor -> normalize + _transf = Compose([ + Resize(size['height']), + ToTensor(), + norm + ]) + + # apply transforms to PIL Image and store it to 'pixels' key + def transf(arg): + arg['pixels'] = [_transf(image.convert('RGB')) for image in arg['img']] + return arg + + trainds, = load_dataset("cifar10", split=["train[:5000]"]) + + itos = dict((k,v) for k,v in enumerate(trainds.features['label'].names)) + stoi = dict((v,k) for k,v in enumerate(trainds.features['label'].names)) + + splits = trainds.train_test_split(test_size=0.1, shuffle=False) + trainds = splits['train'] + valds = splits['test'] + + trainds.set_transform(transf) + valds.set_transform(transf) + + datasets = { + 'train': trainds, + 'val': valds, + } + + +def cifar10_dataset(split): + if not datasets: + init_dataset() + return datasets[split] + + +def cifar10_collate_fn(batch): + return { + 'pixel_values': torch.stack([x['pixels'] for x in batch]), + 'labels': torch.tensor([x['label'] for x in batch]), + } + + +def accelerator_dataloader_fn(dataset, batch_size, collate_fn, num_workers=0, drop_last=False, **kwargs): + from accelerate import Accelerator + from accelerate.utils import GradientAccumulationPlugin, DataLoaderConfiguration + from torch.utils.data.sampler import RandomSampler + from transformers.trainer_utils import seed_worker + + accelerator = Accelerator( + gradient_accumulation_plugin=GradientAccumulationPlugin(num_steps=1), + dataloader_config=DataLoaderConfiguration(even_batches=True, use_seedable_sampler=True) + ) + + sampler = RandomSampler(dataset) + + dataloader_params = { + "batch_size": batch_size, + "collate_fn": collate_fn, + "num_workers": num_workers, + "drop_last": drop_last, + "sampler": sampler, + "worker_init_fn": seed_worker, + } + return accelerator.prepare(DataLoader(dataset, **dataloader_params)) + + +def vit_model(): + with torch.random.fork_rng(): + torch.manual_seed(0) + return ViTForImageClassification.from_pretrained( + VIT_MODEL_NAME, num_labels=10, + ignore_mismatched_sizes=True, + id2label=itos, + label2id=stoi + ) + + +class VModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = vit_model() + + def forward(self, data): + outputs = self.model(pixel_values=data['pixel_values'], labels=data['labels']) + return outputs.loss + + +def scheduler_fn(optimizer, num_warmup_steps): + return get_linear_schedule_with_warmup(optimizer, num_warmup_steps, _trainer.max_train_steps) + + +if __name__ == '__main__': + from transformers import TrainingArguments, Trainer + from sklearn.metrics import accuracy_score + from nnscaler.cli.trainer_args import TrainerArgs + from pathlib import Path + import yaml + + with open(Path(__file__).absolute().with_name('train_cli_args.yaml'), 'r') as f: + trainer_args = yaml.safe_load(f) + + init_env(None) + init_dataset() + + args = TrainingArguments( + f"test-cifar-10", + save_strategy="epoch", + evaluation_strategy="epoch", + learning_rate=float(trainer_args['optimizer']['args']['lr']), + per_device_train_batch_size=int(trainer_args['micro_batch_size']), + max_grad_norm=float(trainer_args['optimizer']['clip_gnorm']), + per_device_eval_batch_size=4, + num_train_epochs=int(trainer_args['max_epochs']), + weight_decay=float(trainer_args['optimizer']['args']['weight_decay']), + warmup_steps=int(trainer_args['lr_scheduler']['args']['num_warmup_steps']), + load_best_model_at_end=True, + metric_for_best_model="accuracy", + logging_dir='logs', + logging_steps=1, + remove_unused_columns=False, + seed=0, + ) + + def compute_metrics(eval_pred): + predictions, labels = eval_pred + predictions = np.argmax(predictions, axis=1) + return dict(accuracy=accuracy_score(predictions, labels)) + + model = vit_model() + class SingleParamGroupTrainer(Trainer): + """ + For parity check reason, + we need to override the `create_optimizer` method to use only one param group for optimizer + """ + def get_decay_parameter_names(self, model) -> list[str]: + # make all parameters decay + return [n for n, _ in model.named_parameters()] + + trainer = SingleParamGroupTrainer( + model, + args, + train_dataset=datasets['train'], + eval_dataset=datasets['val'], + data_collator=cifar10_collate_fn, + compute_metrics=compute_metrics, + tokenizer=vit_processor, + ) + trainer.train() diff --git a/nnscaler/cli/train.py b/nnscaler/cli/train.py index 670e150a..c2a81519 100644 --- a/nnscaler/cli/train.py +++ b/nnscaler/cli/train.py @@ -4,11 +4,14 @@ import logging import nnscaler - from nnscaler.cli.trainer import Trainer -if __name__ == '__main__': +def main(): nnscaler.utils.set_default_logger_level(level=logging.INFO) trainer = Trainer() trainer.run() + + +if __name__ == '__main__': + main() diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index d84b88f4..04ee9d5b 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -133,7 +133,7 @@ def _load_dummy_input(self): return next(iter(dataloader)) def _setup(self): - self.train_args.init_env() + self.train_args.init_env(self) compile_only = self.train_args.compile_mode if is_running_distributed(): @@ -641,7 +641,10 @@ def _train(self): self.hook.on_train_start(self) for epoch in range(start_epoch, self.train_args.max_epochs or sys.maxsize): - self.dataloader['train'].sampler.set_epoch(epoch) + if hasattr(self.dataloader['train'], 'set_epoch'): + self.dataloader['train'].set_epoch(epoch) + elif hasattr(self.dataloader['train'].sampler, 'set_epoch'): + self.dataloader['train'].sampler.set_epoch(epoch) torch.distributed.barrier() diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 3d43356a..e3d01cdf 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -26,6 +26,9 @@ from .loggers.logger_base import LoggerBase from .train_hook import TrainHook +if TYPE_CHECKING: + from .trainer import Trainer + logger = logging.getLogger(__name__) @@ -516,7 +519,7 @@ def buffer_dtype(self) -> torch.dtype: def input_dtype(self) -> torch.dtype: return _PRECISION_MAP[self.precision['input']] - def init_env(self): + def init_env(self, trainer: 'Trainer'): if self.seed is not None: import random import numpy as np @@ -527,7 +530,7 @@ def init_env(self): if self.init_env_fn is None: return init_env_fn = load_type(self.init_env_fn) - init_env_fn(self) + init_env_fn(trainer) def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model.args) diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index 31527b73..fc197b71 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -341,21 +341,31 @@ def emit_release(self, tensors: Iterable[IRTensor]) -> str: tnames : Generator = (self.tensor_name(t) for t in tensors) return 'del ' + ', '.join(tnames) - def get_backward_callsite_io_tensors(self, bwop: IRCell) -> Tuple: + def get_backward_callsite_io_tensors( + self, bwop: IRCell + ) -> Tuple[List[IRSubTensor], List[IRSubTensor], List[IRSubTensor], List[IRSubTensor]]: """ Get backward inputs and outputs + + A tuple of 4 lists will be returned: ``` (input_tensors, output_tensors, output_grads, input_grads) #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~ #inputs to 'backward' outputs of 'backward' ``` + See `nnscaler.runtime.executor.backward` for more details. + + Args: + bwop (IRCell): backward node - @return input_tensors List[IRSubTensor]: forward input tensors (backward input) - @return output_tensors List[IRSubTensor]: forward output tensors (backward output) - @return output_grads List[IRSubTensor]: gradient of forward output tensors - (backward input) - @return input_grads List[IRSubTensor]: gradient of forward input tensors - (backward output) + Returns: + tupe of 4 lists: + input_tensors (List[IRSubTensor]): forward input tensors (also backward iutput) + output_tensors (List[IRSubTensor]): forward output tensors (also backward input) + output_grads (List[IRSubTensor]): gradient of forward output tensors + (also backward input) + input_grads (List[IRSubTensor]): gradient of forward input tensors + (also backward output) """ assert not bwop.isfw() fwop: IRCell = bwop.mirror diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 2a6cb375..cb98e023 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -348,7 +348,7 @@ def _to_cpu(val: Any): return {_to_cpu(t) for t in val} if isinstance(val, torch.Tensor): requires_grad = val.is_floating_point() or val.is_complex() - return val.cpu().requires_grad_(requires_grad) + return val.detach().clone().cpu().requires_grad_(requires_grad) return val diff --git a/nnscaler/version.py b/nnscaler/version.py index 47fd5b28..2fa3e43b 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -__version__ = '0.4' +__version__ = '0.5' diff --git a/pyproject.toml b/pyproject.toml index 212f82ff..37f38af6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,9 @@ classifiers = [ "License :: OSI Approved :: MIT License", ] +[project.scripts] +nnscaler-train = "nnscaler.cli.train:main" + [project.urls] Homepage = "https://github.com/microsoft/nnscaler" From 206e541f8a6e23df308452b6a35912c5842e6cd9 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Tue, 12 Nov 2024 08:43:05 +0000 Subject: [PATCH 1757/1892] Merged PR 2305: add unit test for depthwise conv2d Add unit test for depthwise conv2d It's used to test the correctness of the gen_graph and segmentation for depthwise conv2d, which means groups==in_channels and out_channels==k*in_channels in conv2d. Related work items: #2060 --- nnscaler/algorithm/ops/dimops.py | 45 +++++++++++++++++------------- tests/algorithm/ops/test_dimops.py | 36 ++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/nnscaler/algorithm/ops/dimops.py b/nnscaler/algorithm/ops/dimops.py index c156769c..f1aa1cae 100644 --- a/nnscaler/algorithm/ops/dimops.py +++ b/nnscaler/algorithm/ops/dimops.py @@ -57,11 +57,14 @@ def get_identifier_reduce(self, idx: int, dim: int, num: int) -> Tuple[str, DimA If the partitioned number is 1, return the first hidden identitifer Otherwise, return the first hidden identifier whose length > 1 - @param idx int: input/output index. Take the idx-th input tensor or (idx-ninputs)-th output - @param dim int: input dimension + Args: + idx (int): input/output index. Take the idx-th input tensor or (idx-ninputs)-th output + dim (int): input dimension + num (int): chunks to partition the dimension - @return identifier Optional[str]: annotated dimension identifier - @return reduction Optional[DimAnno.ReduceType] + Returns: + identifier (Optional[str]): annotated dimension identifier + reduction (Optional[DimAnno.ReduceType]) """ node: IRDimops = self.node eshapes = node.anno.inputs() + node.anno.outputs() @@ -79,11 +82,13 @@ def satisfy(self, idx: int, dim: Union[int, str], num: int) -> bool: """ Check whether the condition satisfies. - @param idx int: input/output index. Take the idx-th input tensor or (idx-ninputs)-th output tensor - @param dim Union[int, str]: tensor dimension or 'v', i.e., partition at value dimension. - @param num int: chunks to partition the dimension + Args: + idx (int): input/output index. Take the idx-th input tensor or (idx-ninputs)-th output tensor + dim (Union[int, str]): tensor dimension or 'v', i.e., partition at value dimension. + num (int): chunks to partition the dimension - @return satisfy bool: true if can be partitioned, elsewise false. + Returns: + satisfy (bool): true if can be partitioned, elsewise false. """ assert all(isinstance(cond, int) for cond in [idx, num]), "expect int condition" assert isinstance(dim, int) or dim == 'v', f"expect dim to be int or 'v'" @@ -176,12 +181,12 @@ def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformR return the partitioning of the output tensor. Args: - idx int: the input index - dim int: the dimension to partition - num int: the number of partitions + idx (int): the input index + dim (int): the dimension to partition + num (int): the number of partitions Returns: - rule TransformRule: the transformation rule + rule (TransformRule): the transformation rule """ node: IRDimops = self.node assert isinstance(dim, int) or dim == 'v', f"expect dim to be int or 'v'" @@ -275,19 +280,19 @@ def collect_split_info(node: IRDimops): def gen_partitions(node: IRFwOperation, ngpus: int, base: int = 2, depth: int = -1) -> List[IRFwOperation]: """ - Generate the partitioned nodes of the given node. Each node in the returned list is an - partition instance of a policy. For example, if the input node is a matmul with shape + Generate the partitioned nodes of the given node. Each node in the returned list is a possible partition + instance of a policy in one of the devices. For example, if the input node is a matmul with shape (1024, 4096), (4096, 2048) -> (1024, 2048), the ngpus is 2, base is 2, then the returned list will contain 4 instances: - 1. matmul with shape (1024, 4096), (4096, 2048) -> (1024, 2048) - 2. matmul with shape (1024, 2048), (2048, 2048) -> (1024, 2048) - 3. matmul with shape ( 512, 4096), (4096, 2048) -> ( 512, 2048) - 4. matmul with shape (1024, 4096), (4096, 1024) -> (1024, 1024) + 1. matmul with shape (1024, 4096), (4096, 2048) -> (1024, 2048), this means no partition, replicate on 2 gpus + 2. matmul with shape ( 512, 4096), (4096, 2048) -> ( 512, 2048), partition first input first dimension + 3. matmul with shape (1024, 2048), (2048, 2048) -> (1024, 2048), partition first input second dimension + 4. matmul with shape (1024, 4096), (4096, 1024) -> (1024, 1024), partition second input second dimension Args: node (IRFwOperation): the node to be partitioned ngpus (int): the number of gpus - base (int): the base of the division for the partitioning + base (int): the partition number at each generation step in breadth-first-search depth (int): the maximum depth of the search process, -1 for no limit Returns: @@ -316,7 +321,7 @@ def gen_hash(node: IRFwOperation) -> str: while dq: cur_node, cur_ngpus, cur_depth = dq.popleft() gen_nodes.append(cur_node) - if depth != -1 and cur_depth >= depth: + if (depth != -1 and cur_depth >= depth) or base > cur_ngpus: continue split_info = collect_split_info(cur_node) diff --git a/tests/algorithm/ops/test_dimops.py b/tests/algorithm/ops/test_dimops.py index 193c77fb..cf3386d2 100644 --- a/tests/algorithm/ops/test_dimops.py +++ b/tests/algorithm/ops/test_dimops.py @@ -36,3 +36,39 @@ def test_gen_partitions(): assert len(gen_partitions(fc1, 4)) == 14 # C(4, 1) + 1 - 1 = 4 assert len(gen_partitions(fc1, 4, base=4, depth=1)) == 4 + + +class DepthwiseConv2d(torch.nn.Module): + def __init__(self, in_channels, multiplier_k, kernel_size, stride=1, padding=0): + super(DepthwiseConv2d, self).__init__() + self.depthwise = torch.nn.Conv2d(in_channels, in_channels * multiplier_k, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + + def forward(self, x): + return self.depthwise(x) + + +@replace_all_device_with('cpu') +def test_gen_partitions_depthwise_conv2d(): + in_channels = 8 + multiplier_k = 4 + kernel_size = 3 + stride = 1 + padding = 1 + batch_size = 16 + height = 256 + width = 256 + with tempfile.TemporaryDirectory() as tempdir: + graph, _ = _gen_graph(DepthwiseConv2d(in_channels, multiplier_k, kernel_size, stride, padding), + {'x': torch.randn(batch_size, in_channels, height, width)}, + tempdir, False) + depthwise = graph.select(ntype=IRFwOperation)[0] + # anno: n (g 1^) 256^ 256^, (g 4^) 1^ 3^ 3^, (g 4^) -> n (g 4^) 256^ 256^ + assert len(gen_partitions(depthwise, 1)) == 1 + # n g, n/2 g, n g/2 + assert len(gen_partitions(depthwise, 2)) == 3 + # n g, n/2 g, n g/2, n/2 g/2, n g/2/2, n/2/2 g + assert len(gen_partitions(depthwise, 4)) == 6 + # n g, n/4 g, n g/4 + assert len(gen_partitions(depthwise, 4, base=4, depth=1)) == 3 + From edb4510162c82d30435953a7812d4153daf5ab2f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 13 Nov 2024 05:41:56 +0000 Subject: [PATCH 1758/1892] Merged PR 2307: add async support for non-end2end modules with sync_grad_when function add async support for non-end2end modules with sync_grad_when function --- nnscaler/parallel.py | 33 ++++- tests/parallel_module/test_async.py | 159 +++++++++++++++++++++++ tests/parallel_module/test_checkpoint.py | 1 + 3 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 tests/parallel_module/test_async.py diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index cb98e023..b72522f1 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -155,9 +155,6 @@ def __post_init__(self): # have to use __setattr__ for frozen dataclass super().__setattr__('zero_ngroups', 1) - if self.use_async_reducer and not self.use_end2end: - raise ValueError("use_async_reducer is only supported in end2end mode.") - if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") @@ -2365,3 +2362,33 @@ def load_sharded_state_dict( if optimizer_state_dict is None: raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") optimizer.load_state_dict(optimizer_state_dict) + + +def sync_grad_when(cond: bool): + """ + Context manager to enable/disable gradient synchronizations across workers. + + Within this context, gradients will be accumulated + only when `cond` is True. + + This is needed when + 1. The mode is not end2end model. + For end2end model, gradients are synchronized across workers automatically. + 2. async is enabled (`compute_config.use_async_reducer` is `True`). + + If both conditions are not satisfied, this function has no effect. + + Example: + >>> model = parallelize(model, ...) + >>> accum_steps = ... + >>> for step in range(accum_steps) + >>> with sync_grad_when(step == accum_steps - 1): + >>> loss = ... + >>> loss.backward() + >>> optimizer.step() + >>> optimizer.zero_grad() + + Args: + cond (bool): whether to synchronize gradients. + """ + return _runtime_flags(skip_reducer=not cond) diff --git a/tests/parallel_module/test_async.py b/tests/parallel_module/test_async.py new file mode 100644 index 00000000..27d4d101 --- /dev/null +++ b/tests/parallel_module/test_async.py @@ -0,0 +1,159 @@ +from pathlib import Path +import tempfile +import pytest +import torch +from torch import nn + +from nnscaler import parallelize, ComputeConfig, ParallelModule + +from nnscaler.parallel import build_optimizer, sync_grad_when, merge_state_dicts +from tests.launch_torchrun import launch_torchrun +from tests.launch_torchrun import clone_to_cpu_recursively +from tests.parallel_module.common import assert_equal, init_distributed +from tests.utils import clear_dir_on_rank0, init_random +from .test_wholemodule import FcRelu_4_4 + + +class OrigModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc_relu1 = FcRelu_4_4() + self.fc_relu2 = FcRelu_4_4() + self.linear3 = nn.Linear(4, 1) + self.sigmoid = nn.Sigmoid() + def forward(self, x): + x = self.fc_relu1(x) + x = self.fc_relu2(x) + x = self.linear3(x) + x = self.sigmoid(x) + return x + + +def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): + return parallelize( + module, + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + compute_config, + gen_savedir=cube_savedir, + instance_name=instance_name + ) + + +def _create_modules(pas, compute_config, cube_savedir, name_prefix=''): + init_random() + whole_module = _to_cube_model( + OrigModule(), pas, compute_config, cube_savedir, f'{name_prefix}whole' + ).cuda() + init_random() + sub_module = OrigModule().cuda() + sub_module.fc_relu1 = _to_cube_model( + sub_module.fc_relu1, pas, compute_config, cube_savedir, f'{name_prefix}fc_relu1' + ).cuda() + sub_module.fc_relu2 = _to_cube_model( + sub_module.fc_relu2, pas, compute_config, cube_savedir, f'{name_prefix}fc_relu2' + ).cuda() + return whole_module, sub_module + + +def _train(model: ParallelModule, update_freq): + init_random() + + loss_fn = nn.BCELoss() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + data = [] + DATA_SIZE = 20 + UPDATE_FREQ = update_freq + for _ in range(DATA_SIZE): + data.append(( + torch.randn((2, 4), device='cuda', dtype=torch.float32), + torch.rand((2, 1), device='cuda', dtype=torch.float32), + )) + results = [] + for i, (x, y) in enumerate(data): + model.train() + with sync_grad_when(i % UPDATE_FREQ == UPDATE_FREQ - 1): + y_pred = model(x) + loss = loss_fn(y_pred, y) + loss.backward() + if i % UPDATE_FREQ == UPDATE_FREQ - 1: + optimizer.step() + optimizer.zero_grad() + results.append(clone_to_cpu_recursively([y_pred, model.state_dict()])) + return results + + +def _gpu_worker(pas, ngpus, update_freq): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_async') as tempdir: + whole_module_async, sub_module_async = _create_modules( + pas, ComputeConfig( + 1, ngpus, use_async_reducer=True, + reducer_bucket_cap_mb=1e-6 + ), + tempdir, + 'async_', + ) + whole_module_sync, sub_module_sync = _create_modules( + pas, ComputeConfig( + 1, ngpus, use_async_reducer=False, + reducer_bucket_cap_mb=100 + ), + tempdir, + 'sync_', + ) + whole_async_results = _train(whole_module_async, update_freq) + whole_sync_results = _train(whole_module_sync, update_freq) + sub_async_results = _train(sub_module_async, update_freq) + sub_sync_results = _train(sub_module_sync, update_freq) + return ( + whole_async_results, + whole_sync_results, + sub_async_results, + sub_sync_results + ) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('update_freq', [1, 4]) +def test_dp2(update_freq): + results = launch_torchrun(2, _gpu_worker, 'dp', 2, update_freq) + whole_async0, whole_sync0, sub_async0, sub_sync0 = results[0] + whole_async1, whole_sync1, sub_async1, sub_sync1 = results[1] + + assert len(whole_async0) == len(whole_sync0) == len(sub_async0) == len(sub_sync0) + + for iter in range(len(whole_async0)): # for each iteration + iter_whole_async0 = whole_async0[iter] + iter_whole_sync0 = whole_sync0[iter] + iter_sub_async0 = sub_async0[iter] + iter_sub_sync0 = sub_sync0[iter] + + iter_whole_async1 = whole_async1[iter] + iter_whole_sync1 = whole_sync1[iter] + iter_sub_async1 = sub_async1[iter] + iter_sub_sync1 = sub_sync1[iter] + + # pred + assert torch.equal(iter_whole_async0[0], iter_whole_async1[0]) + assert torch.equal(iter_sub_async0[0], iter_sub_async1[0]) + assert torch.equal(iter_whole_sync0[0], iter_whole_sync1[0]) + assert torch.equal(iter_sub_sync0[0], iter_sub_sync1[0]) + + assert torch.equal(iter_whole_async0[0], iter_whole_sync0[0]) + assert torch.equal(iter_sub_async0[0], iter_sub_sync0[0]) + assert torch.equal(iter_whole_async0[0], iter_sub_async0[0]) + + # weights + whole_async_weights, _ = merge_state_dicts([iter_whole_async0[1], iter_whole_async1[1]]) + whole_sync_weights, _ = merge_state_dicts([iter_whole_sync0[1], iter_whole_sync1[1]]) + sub_async_weights, _ = merge_state_dicts([iter_sub_async0[1], iter_sub_async1[1]]) + sub_sync_weights, _ = merge_state_dicts([iter_sub_sync0[1], iter_sub_sync1[1]]) + + assert_equal(whole_async_weights, whole_sync_weights) + assert_equal(sub_async_weights, sub_sync_weights) + + assert set(whole_async_weights.keys()) == set(sub_async_weights.keys()) + + for key in whole_async_weights.keys(): + assert torch.equal(whole_async_weights[key], sub_async_weights[key]) diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index f81514b9..a0be40ca 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -457,6 +457,7 @@ def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus, per_resum )) return compiled_results + @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @pytest.mark.parametrize('module_type', ['sub', 'whole', 'start', 'end', 'small', 'pipeline']) @pytest.mark.parametrize('use_zero', [True, False]) From 8d44619b3402c496066d73f04abfe46a9288c482 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 13 Nov 2024 06:44:27 +0000 Subject: [PATCH 1759/1892] Merged PR 2306: Refine function.To implementation Refine function.To implementation to be more compatible with torch.Tensor.to Related work items: #2075 --- nnscaler/graph/function/function.py | 185 +++++++++++++++++++------ nnscaler/graph/function/wrapnn.py | 22 ++- nnscaler/ir/cten.py | 25 +++- nnscaler/runtime/function/function.py | 2 + tests/graph/function/test_functions.py | 107 ++++++++++++-- tests/parallel_module/test_gencode.py | 26 ++++ 6 files changed, 297 insertions(+), 70 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index adf58836..43b6ca99 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -2213,32 +2213,135 @@ def Dim(tensor, signature=None) -> Union[List[int], IRPyFunc]: return len(tensor.shape) -def To(tensor: IRTensor, dtype_or_device=None, *, device=None, dtype=None, out=None, signature = None): +def _resolve_overload_args( + args: List[Any], + kwargs: Dict[str, Any], + positional_arg_names: List[List[str]], + kwarg_names: List[List[str]], + arg_types: Dict[str, Any], +) -> Dict[str, Any]: + """ + Resolve the arguments of a function with multiple overloads. + """ + + def _find_first_invalid_arg_name(arg_values, overload_idx): + for arg_name in arg_values: + if arg_name not in positional_arg_names[overload_idx] \ + and arg_name not in kwarg_names[overload_idx]: + return arg_name + return None + + arg_values = dict(kwargs) + + if args: + # some parameters are passed as positional arguments + # overload matching is done by checking the type of the first positional argument + # here we use unwrapped value, + # because if it's a wrapped IRObject, it's impossible for us to select the correct overload + arg0 = IRObject.try_unwrap(args[0]) + for overload_idx in range(len(positional_arg_names)): + # if arg[0] is None, use the first overload + if arg0 is None or isinstance(arg0, arg_types[positional_arg_names[overload_idx][0]]): + if len(args) > len(positional_arg_names[overload_idx]): + raise ValueError('Received too many positional arguments') + + for i, arg in enumerate(args): + arg_name = positional_arg_names[overload_idx][i] + if arg_name in arg_values: + raise ValueError(f'{arg_name} is specified as a keyword argument and as a positional argument') + # TODO: check the type of arg + # Currently we assume that the type of arg is correct + # We may need to add type checking in the future + # when we know 100% how pytorch handles the overloads + arg_values[arg_name] = arg + + if invalid_arg_name := _find_first_invalid_arg_name(arg_values, overload_idx): + raise ValueError(f'{invalid_arg_name} is not a valid argument for this overload') + + break + else: + raise ValueError('Received an invalid combination of arguments') + else: + # no positional arguments are passed + # In this case, we dont' know which overload to use + # Here we will check the arguments and report error if it doesn't match any overloads. + invalids = [_find_first_invalid_arg_name(arg_values, i) for i in range(len(positional_arg_names))] + if all(invalids): # all overloads have error + # just report the invalid argument of first overload + raise ValueError(f'{invalids[0]} is not a valid argument') + + return arg_values + + +def To( + input: IRTensor, + *args, + **kwargs, +): """ torch.Tensor.to(*args, **kwargs) → Tensor - """ - assert out is None - # FIXME: support full version of torch.Tensor.to - dtype_or_device = dtype if dtype is not None else dtype_or_device - dtype_or_device = device if dtype_or_device is None else dtype_or_device - if isinstance(dtype_or_device, torch.device) or isinstance(device, torch.device): - warn_msg = 'Cube will handle the tensor device placement, the call of torch.Tensor.to(device=...) will be ignore, ' \ + three overloads: + ``` + to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None) + to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None) + to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None) + ``` + Pytorch will try to match the overloads from top to bottom. + We will mimic the behavior here. + """ + assert kwargs.pop('out', None) is None + signature = kwargs.pop('signature', 'torch.Tensor.to') + + args_types ={ + 'device': (int, str, torch.device), + 'dtype': torch.dtype, + 'tensor': IRTensor, + 'non_blocking': bool, + 'copy': bool, + 'memory_format': torch.memory_format, + } + + positional_arg_names = [ + ('device', 'dtype', 'non_blocking', 'copy'), # 1st overload + ('dtype', 'non_blocking', 'copy'), # 2nd overload + ('tensor', 'non_blocking', 'copy'), # 3rd overload + ] + + kwarg_names = [ + ('memory_format',), + ('memory_format',), + ('memory_format',), + ] + + arg_values = _resolve_overload_args( + args, + kwargs, + positional_arg_names, + kwarg_names, + args_types, + ) + + if arg_values.get('device', None) is not None: + warn_msg = 'nnscaler will handle the tensor device placement, the call of torch.Tensor.to(device=...) will be ignore, ' \ 'if you really want to put the tensor on cpu to excute some op, please wrap all related ops in an independent function ' \ - 'and using nnscaler.graph.parser.register to register this function.' + 'and using nnscaler.register_op to register this function.' _logger.warning(warn_msg) - # create "to" in cube runtime functions because dtype if not kwarg in torch.Tensor.to - signature = 'nnscaler.runtime.function.to' + annos = ['* -> *'] - if isinstance(dtype_or_device, torch.device): - # skip device movement as policy can determine device for the tensor. - return Identity(tensor) - elif isinstance(dtype_or_device, torch.dtype): - return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype_or_device) - elif isinstance(dtype_or_device, IRTensor): - dtype = dtype_or_device.dtype - return IRDimops(To, 'to', signature, annos, [tensor], dtype_or_device=dtype) - else: - raise RuntimeError(f'function.To with unknown arg: {dtype_or_device}') + if 'tensor' in arg_values: # overload 3 + # Here we keep tensor.dtype, + # and discard tensor.device, because we will handle the device placement + arg_values['dtype'] = arg_values['tensor'].dtype + arg_values.pop('tensor') + # TODO: It may be better if we wrap dtype in an IRObject but it will introduce a lot of code. + # (e.g. we need to insert a getattr op in the graph) + return IRDimops(To, 'to', signature, annos, [input], **arg_values) + elif 'device' not in arg_values and 'dtype' in arg_values: # overload 2 + return IRDimops(To, 'to', signature, annos, [input], **arg_values) + else: # overload 1 + # remove device from kwargs because we will handle the device placement + filtered_kwargs = {k: v for k, v in arg_values.items() if k != 'device'} + return IRDimops(To, 'to', signature, annos, [input], **filtered_kwargs) def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: @@ -2693,20 +2796,16 @@ def Erf(input, *, out=None, signature=None): return IRDimops(Erf, 'erf', signature, annos, [input]) -def unwrap_if_irobject(x): - return x.value if isinstance(x, IRObject) else x - - def Conv1D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, signature=None): """ torch.nn.functional.conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor """ if len(input.shape) not in [2, 3]: raise ValueError(f"Expected input tensor to have 2 or 3 dimensions, but got {input.shape}") - stride_val = unwrap_if_irobject(stride) - padding_val = unwrap_if_irobject(padding) - dilation_val = unwrap_if_irobject(dilation) - groups_val = unwrap_if_irobject(groups) + stride_val = IRObject.try_unwrap(stride) + padding_val = IRObject.try_unwrap(padding) + dilation_val = IRObject.try_unwrap(dilation) + groups_val = IRObject.try_unwrap(groups) if isinstance(stride_val, int): stride_val = (stride_val,) if isinstance(dilation_val, int): @@ -2789,11 +2888,11 @@ def ConvTranspose1D(input, weight, bias=None, stride=1, padding=0, output_paddin """ if len(input.shape) not in [2, 3]: raise ValueError(f"Expected input tensor to have 2 or 3 dimensions, but got {input.shape}") - stride_val = unwrap_if_irobject(stride) - padding_val = unwrap_if_irobject(padding) - output_padding_val = unwrap_if_irobject(output_padding) - dilation_val = unwrap_if_irobject(dilation) - groups_val = unwrap_if_irobject(groups) + stride_val = IRObject.try_unwrap(stride) + padding_val = IRObject.try_unwrap(padding) + output_padding_val = IRObject.try_unwrap(output_padding) + dilation_val = IRObject.try_unwrap(dilation) + groups_val = IRObject.try_unwrap(groups) if isinstance(stride_val, int): stride_val = (stride_val,) if isinstance(padding_val, int): @@ -2850,10 +2949,10 @@ def Conv2D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, """ if len(input.shape) not in [3, 4]: raise ValueError(f"Expected input tensor to have 3 or 4 dimensions, but got {input.shape}") - stride_val = unwrap_if_irobject(stride) - padding_val = unwrap_if_irobject(padding) - dilation_val = unwrap_if_irobject(dilation) - groups_val = unwrap_if_irobject(groups) + stride_val = IRObject.try_unwrap(stride) + padding_val = IRObject.try_unwrap(padding) + dilation_val = IRObject.try_unwrap(dilation) + groups_val = IRObject.try_unwrap(groups) if isinstance(stride_val, int): stride_val = (stride_val, stride_val) if isinstance(dilation_val, int): @@ -2941,11 +3040,11 @@ def ConvTranspose2D(input, weight, bias=None, stride=1, padding=0, output_paddin """ if len(input.shape) not in [3, 4]: raise ValueError(f"Expected input tensor to have 3 or 4 dimensions, but got {input.shape}") - stride_val = unwrap_if_irobject(stride) - padding_val = unwrap_if_irobject(padding) - output_padding_val = unwrap_if_irobject(output_padding) - dilation_val = unwrap_if_irobject(dilation) - groups_val = unwrap_if_irobject(groups) + stride_val = IRObject.try_unwrap(stride) + padding_val = IRObject.try_unwrap(padding) + output_padding_val = IRObject.try_unwrap(output_padding) + dilation_val = IRObject.try_unwrap(dilation) + groups_val = IRObject.try_unwrap(groups) if isinstance(stride_val, int): stride_val = (stride_val, stride_val) if isinstance(padding_val, int): diff --git a/nnscaler/graph/function/wrapnn.py b/nnscaler/graph/function/wrapnn.py index fa3cd4e1..a65a0615 100644 --- a/nnscaler/graph/function/wrapnn.py +++ b/nnscaler/graph/function/wrapnn.py @@ -128,10 +128,6 @@ def wrap_batchnorm2d_func( ) -def unwrap_if_irobject(x): - return x.value if isinstance(x, IRObject) and not isinstance(x, IRTensor) else x - - def batchnorm2d_annotation_fn(*inputs, **kwargs): assert ( len(inputs) == 6 @@ -148,11 +144,11 @@ def batchnorm2d_annotation_fn(*inputs, **kwargs): should also be absent. Reference: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html """ - weight = unwrap_if_irobject(weight) - bias = unwrap_if_irobject(bias) - running_mean = unwrap_if_irobject(running_mean) - running_var = unwrap_if_irobject(running_var) - num_batches_tracked = unwrap_if_irobject(num_batches_tracked) + weight = IRObject.try_unwrap(weight) + bias = IRObject.try_unwrap(bias) + running_mean = IRObject.try_unwrap(running_mean) + running_var = IRObject.try_unwrap(running_var) + num_batches_tracked = IRObject.try_unwrap(num_batches_tracked) if weight is None: assert bias is None @@ -379,10 +375,10 @@ def instancenorm2d_annotation_fn(*inputs, **kwargs): ), "Expected 5 inputs: input, weight, bias, running_mean, running_var" input, weight, bias, running_mean, running_var = inputs - weight = unwrap_if_irobject(weight) - bias = unwrap_if_irobject(bias) - running_mean = unwrap_if_irobject(running_mean) - running_var = unwrap_if_irobject(running_var) + weight = IRObject.try_unwrap(weight) + bias = IRObject.try_unwrap(bias) + running_mean = IRObject.try_unwrap(running_mean) + running_var = IRObject.try_unwrap(running_var) if weight is None: assert bias is None diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index f3cb1544..03550aa5 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -665,7 +665,7 @@ def _inner(obj) -> Tuple[Any, bool]: return IRObject(name, value=obj, is_constant=is_constant), False else: return {k: r[0] for k, r in result.items()}.items(), True - + # slice will go here, as its start/stop/step are never tensor-like objects return IRObject(name, value=obj, is_constant=is_constant), False return _inner(data)[0] @@ -682,6 +682,29 @@ def tosub_complex(cls, obj: Any) -> Any: modifier = lambda t: t.tosub() if isinstance(t, IRFullTensor) else t return IRCell.modify_objects_of_complex(obj, modifier) + @classmethod + def try_unwrap(cls, x: Union[Any, 'IRObject']) -> Any: + """ + Unwrap the IRObject to its original value if it is an IRObject + otherwise, go recursively. + + Args: + x (Any): the object to unwrap + + Returns: + Any: the original value + """ + if isinstance(x, IRObject) and not isinstance(x, IRTensor): + return x.value + elif isinstance(x, (list, tuple)): + return type(x)(cls.try_unwrap(v) for v in x) + elif isinstance(x, dict): + return {k: cls.try_unwrap(v) for k, v in x.items()} + elif isinstance(x, slice): + return slice(cls.try_unwrap(x.start), cls.try_unwrap(x.stop), cls.try_unwrap(x.step)) + else: + return x + def __repr__(self): return f'Object({self.name}{self.tid}, val={self.value}, is_constant={self.is_constant})' diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 9d278a77..ee1f7166 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -39,6 +39,8 @@ def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) -> torch.Tensor: + # deprecated + # keep it only for backward compatibility return tensor.to(dtype_or_device) diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index e211c89b..951cfb35 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -538,21 +538,102 @@ def test_type(): def test_to(): - op = F.To(IRTensor([2, 3], dtype=None), dtype=torch.float32) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype_or_device'] == torch.float32 + with pytest.raises(ValueError, match='.*is not a valid argument.*'): + op = F.To(IRTensor([2, 3], dtype=torch.float32), xx=None) + + with pytest.raises(ValueError, match='.*is not a valid argument.*'): + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0, xx=None) + + with pytest.raises(ValueError, match='.*is not a valid argument.*'): + op = F.To(IRTensor([2, 3], dtype=torch.float32), torch.float32, xx=None) + + # 1st overload + op = F.To(IRTensor([2, 3], dtype=torch.float32)) # No arguments + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs + op = F.To(IRTensor([2, 3], dtype=torch.float32), None) # only None + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs + op = F.To(IRTensor([2, 3], dtype=torch.float32), torch.device('cuda:0')) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs + op = F.To(IRTensor([2, 3], dtype=torch.float32), 'cuda:0') + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs + op = F.To(IRTensor([2, 3], dtype=torch.float32), device=None) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs op = F.To(IRTensor([2, 3], dtype=torch.float32), device=torch.device('cuda:0')) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b' - op = F.To(IRTensor([3, 5], dtype=torch.int64), dtype=torch.float32) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype_or_device'] == torch.float32 - op = F.To(IRTensor([2, 3], dtype=torch.float32), device=torch.device('cuda:0')) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b' - op = F.To(IRTensor([3, 5], dtype=torch.int64), dtype=IRTensor(dtype=torch.float32)) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' - op = F.To(IRTensor([2, 3], dtype=torch.float32), device=torch.device('cuda:0'), dtype=torch.float32) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs['dtype_or_device'] == torch.float32 - op = F.To(IRTensor([3, 5], dtype=torch.int64), dtype_or_device=IRTensor(dtype=torch.float32)) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs + op = F.To(IRTensor([2, 3], dtype=torch.float32), device=0) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs + op = F.To(IRTensor([2, 3], dtype=torch.float32), device='cuda:0') + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and not op.kwargs + + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0, None) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ + and op.kwargs == {'dtype': None} + + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0, None, True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ + and op.kwargs == {'dtype': None, 'non_blocking': True} + + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0, None, True, None) + # Note type of copy is None, which is not correct. + # because currently we don't do type checking + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ + and op.kwargs == {'dtype': None, 'non_blocking': True, 'copy': None} + + # 1st overload with options + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0, copy=True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ + and op.kwargs == {'copy': True} + + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0, None, copy=True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ + and op.kwargs == {'dtype': None, 'copy': True} + + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0, copy=True, non_blocking=True, memory_format=torch.contiguous_format) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ + and op.kwargs == {'copy': True, 'non_blocking': True, 'memory_format': torch.contiguous_format} + + # 1st overload with duplicate options + with pytest.raises(ValueError): + op = F.To(IRTensor([2, 3], dtype=torch.float32), 0, copy=True, device=None) + # 2nd overload + op = F.To(IRTensor([2, 3]), torch.float32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs == {'dtype': torch.float32} + + op = F.To(IRTensor([2, 3]), dtype=torch.float32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs == {'dtype': torch.float32} + + op = F.To(IRTensor([2, 3]), torch.float32, False) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ + and op.kwargs == {'dtype': torch.float32, 'non_blocking': False} + + op = F.To(IRTensor([2, 3]), torch.float32, False, copy=False, memory_format=torch.contiguous_format) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ + and op.kwargs == {'dtype': torch.float32, 'non_blocking': False, + 'copy': False, 'memory_format': torch.contiguous_format} + + # duplicate options + with pytest.raises(ValueError): + op = F.To(IRTensor([2, 3]), torch.float32, False, non_blocking=True) + + # 3rd overload + op = F.To(IRTensor([3, 5], dtype=torch.int64), IRTensor(dtype=torch.float32)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs == {'dtype': torch.float32} + op = F.To(IRTensor([3, 5], dtype=torch.int64), IRTensor(dtype=torch.float32), True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs == {'dtype': torch.float32, 'non_blocking': True} + op = F.To(IRTensor([3, 5], dtype=torch.int64), IRTensor(dtype=torch.float32), True, None) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs == {'dtype': torch.float32, 'non_blocking': True, 'copy': None} + op = F.To(IRTensor([3, 5], dtype=torch.int64), IRTensor(dtype=torch.float32), True, copy=True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' and op.kwargs == {'dtype': torch.float32, 'non_blocking': True, 'copy': True} + + # duplicate options + with pytest.raises(ValueError): + op = F.To(IRTensor([2, 3]), IRTensor(dtype=torch.float32), False, non_blocking=True) + # too many positional arguments + with pytest.raises(ValueError, match='.*too many positional arguments.*'): + op = F.To(IRTensor([3, 5], dtype=torch.int64), IRTensor(dtype=torch.float32), True, None, True) def test_outer(): diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index e3981ee8..85f699fe 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -1459,3 +1459,29 @@ def _gencode_conv2d_function_(tempdir): def test_codegen_conv2d_groups(): with tempfile.TemporaryDirectory() as tempdir: launch_torchrun(1, _gencode_conv2d_function_, tempdir) + + +class FunctionToModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + return self.linear(x.to(0)).to(torch.float32, copy=True) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of GPU devices') +def test_codegen_function_to(tmp_path): + parallelize( + FunctionToModule(), + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False + ) + # device argument is removed + # to_23 = torch.Tensor.to(x_29) + assert _gencode_contains(tmp_path, FunctionToModule, 0, r'to_\d+ = torch\.Tensor\.to\(x_\d+\)') + # to_1_22 = torch.Tensor.to(linear_26, copy=True, dtype=torch.float32) + assert _gencode_contains(tmp_path, FunctionToModule, 0, r'torch\.Tensor\.to([^, ]*, copy=True, dtype=torch.float32)') From 4126a1405762633a4e8b337eb7a805fc03f0337b Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Thu, 14 Nov 2024 00:14:40 +0000 Subject: [PATCH 1760/1892] Merged PR 2290: Resume RNG state and add test case Store and resume torch's RNG state in mini-trainer's sharded checkpoints. Python and numpy's RNG are not touched, mentioned in doc. It has been tested on MI300X. --- docs/source/trainer.md | 5 ++ nnscaler/cli/trainer.py | 49 ++++++++++-- tests/cli/test_resume_seed.py | 137 ++++++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 tests/cli/test_resume_seed.py diff --git a/docs/source/trainer.md b/docs/source/trainer.md index d51db712..a4635eba 100644 --- a/docs/source/trainer.md +++ b/docs/source/trainer.md @@ -388,6 +388,11 @@ we will run validation on the validation dataset and save the validation loss to The validation run will ignore the `val_every_n_train_steps` and `val_every_n_epochs` configurations. If no valid dataset is provided, validation is skipped and `valid_loss` is set to `train_loss` by default. +3. The sharded checkpoints will contain PyTorch's RNG state, but not Python's or NumPy's. +The checkpoint's RNG state will be resumed right before training start, +which means the initialization stage will use `TrainerArgs.seed` instead. +Merged checkpoints will discard the RNG state. + ### Other configs - `gen_savedir` (`str`): The directory to save the generated files. Default is `./.nnscaler`. - `gen_reuse` (`str`): the reuse strategy of the generated code, it can be diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 04ee9d5b..a965cbcf 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations + from dataclasses import dataclass, asdict from typing import Any, Dict, List, Optional, Union from pathlib import Path @@ -93,6 +95,8 @@ def __init__(self, self.max_train_steps = None self.loggers = [] self.hook = None + # RNG states pending resume; reset to None after resuming + self.rng_states_from_resume: dict[str, torch.Tensor] | None = None def run(self): self._setup() @@ -288,6 +292,7 @@ def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): 'lr_scheduler': state_dicts[0].get('lr_scheduler', None), 'train_status': state_dicts[0]['train_status'], 'train_args': train_args, + 'rng_states': None, } torch.save(merged_state_dict, output_file) @@ -341,6 +346,7 @@ def _load_checkpoint(self): if self.lr_scheduler: self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) self.train_status = TrainStatus(**state_dict['train_status']) + self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() def _log_mem_stats(self, tag=None): # log minimum free memory over the iteration @@ -409,6 +415,7 @@ def _save_checkpoint(self, loss): 'lr_scheduler': self.lr_scheduler.state_dict() if self.lr_scheduler else None, 'train_status': asdict(self.train_status), 'train_args': self.train_args.to_dict(), + 'rng_states': self._get_rng_states(), } self.hook.on_save_checkpoint(self, state_dict) ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( @@ -499,11 +506,26 @@ def _expire_checkpoints(self): logger.info('Removing old checkpoint: %s', ckpt_name) shutil.rmtree(save_dir / ckpt_name) - def _global_batch_iterator(self, num_skip_first = 0, stage='train'): + def _global_batch_iterator(self, num_skip_first=0, stage='train'): + if num_skip_first == 0: + # if the checkpoint stops at the end of an epoch, + # the rng states must be resumed before creating iterator + # because `DataLoader.__iter__()` uses the rng (dunno why), + # and the previous run had not call it yet + self._try_resume_rng_states() + + it = iter(self.dataloader[stage]) + for _ in range(num_skip_first * self.train_args.update_freq): + _sample = next(it) + + if num_skip_first != 0: + # if the checkpoint stops in the middle of an epoch, + # the rng states must be resumed before loading the first batch, which depends on the rng; + # and must be resumed after skipping unused batches, which will affect the rng + self._try_resume_rng_states() + samples = [] - for idx, sample in enumerate(self.dataloader[stage]): - if idx < num_skip_first * self.train_args.update_freq: - continue + for sample in it: sample = self._fix_input(sample) samples.append(sample) if len(samples) == self.train_args.update_freq: @@ -738,14 +760,14 @@ def _validate(self, step_stat: _StepStat): logger.info(self._format_metrics(f'Validation', None, val_metrics)) return step_stat.val_loss - def _train_epoch(self, epoch): + def _train_epoch(self, epoch: int) -> None: VAL_STATUS_NO = 0 # not validated or saved VAL_STATUS_VAL = 1 # validated but not saved VAL_STATUS_SAVE = 2 # validated and saved has_validated = VAL_STATUS_NO # 3 states resume_from_idx = self.train_status.finished_train_steps % self.total_train_steps_per_epoch - data_iter = enumerate(self._global_batch_iterator(num_skip_first=resume_from_idx)) + data_iter = enumerate(self._global_batch_iterator(resume_from_idx)) max_epoch = self.max_train_steps // self.total_train_steps_per_epoch if self.max_train_steps % self.total_train_steps_per_epoch != 0: @@ -898,3 +920,18 @@ def _train_epoch(self, epoch): and (epoch + 1) % self.train_args.val_every_n_epochs == 0: self._validate(step_stat) has_validated = VAL_STATUS_VAL + + def _get_rng_states(self) -> dict[str, torch.Tensor]: + return { + 'torch': torch.get_rng_state(), + 'torch_cuda': torch.cuda.get_rng_state(), + } + + def _try_resume_rng_states(self) -> None: + # assuming hooks do not use rng + if self.rng_states_from_resume is not None: + if self.rng_states_from_resume.get('torch') is not None: + torch.set_rng_state(self.rng_states_from_resume['torch']) + if self.rng_states_from_resume.get('torch_cuda') is not None: + torch.cuda.set_rng_state(self.rng_states_from_resume['torch_cuda']) + self.rng_states_from_resume = None diff --git a/tests/cli/test_resume_seed.py b/tests/cli/test_resume_seed.py new file mode 100644 index 00000000..30d4992a --- /dev/null +++ b/tests/cli/test_resume_seed.py @@ -0,0 +1,137 @@ +import os +import pytest +import torch +from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer_args import * + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='no gpu') +def test_resume_seed(): + _set_envs({ + # required by deterministic + 'CUBLAS_WORKSPACE_CONFIG': ':4096:8', + + # fake torchrun environment, check https://pytorch.org/docs/stable/elastic/run.html#environment-variables + 'LOCAL_RANK': 0, + 'RANK': 0, + 'GROUP_RANK': 0, + 'LOCAL_WORLD_SIZE': 1, + 'WORLD_SIZE': 1, + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': 29470, + 'TORCHELASTIC_RUN_ID': 'UT', + }) + + torch.use_deterministic_algorithms(True) + + # compile separately because run multiple trainers in one process will confuse `gen_reuse` + _compile() + + _test_resume_seed(steps_per_epoch=100, max_steps=20, resume_at=10) + + _test_resume_seed(steps_per_epoch=5, max_steps=20, resume_at=10) + + _restore_envs() + + +def _test_resume_seed(steps_per_epoch, max_steps, resume_at): + # no resume + model_1 = _train(steps_per_epoch, max_train_steps=max_steps, resume_from=None) + weight_1 = next(model_1.parameters()).data + + # resume + _train(steps_per_epoch, max_train_steps=resume_at, resume_from=None) + model_2 = _train(steps_per_epoch, max_train_steps=max_steps, resume_from='last') + weight_2 = next(model_2.parameters()).data + + assert torch.equal(weight_1, weight_2) + + ## resume without resuming seeds + _train(steps_per_epoch, max_train_steps=resume_at, resume_from=None) + _remove_rng_states() + model_3 = _train(steps_per_epoch, max_train_steps=max_steps, resume_from='last') + weight_3 = next(model_3.parameters()).data + + assert not torch.equal(weight_1, weight_3) + + +def _compile(): + trainer_args = TrainerArgs( + compute_config=ComputeConfig(plan_ngpus=1, runtime_ngpus=1, use_end2end=True), + gen_reuse='override', + run_mode='compile', + model=ModelConfig(type=Model), + optimizer=OptimizerConfig(type=torch.optim.AdamW), + dataset=DatasetConfig(type=RandomDataset, train_args={'length': 100}), + max_train_steps=1, + enable_progress_bar=False, + seed=0, + ) + trainer = Trainer(train_args=trainer_args) + trainer.run() + + +def _train(steps_per_epoch, max_train_steps, resume_from): + trainer_args = TrainerArgs( + compute_config=ComputeConfig(plan_ngpus=1, runtime_ngpus=1, use_end2end=True), + model=ModelConfig(type=Model), + optimizer=OptimizerConfig(type=torch.optim.AdamW), + dataset=DatasetConfig(type=RandomDataset, train_args={'length': steps_per_epoch}), + checkpoint=CheckpointConfig(resume_from=resume_from), + max_train_steps=max_train_steps, + enable_progress_bar=False, + seed=0, + ) + trainer = Trainer(train_args=trainer_args) + trainer.run() + return trainer.model + + +def _remove_rng_states(): + ckpt_path = 'checkpoints/last/0.ckpt' + ckpt = torch.load(ckpt_path, weights_only=False) + ckpt['rng_states'] = None + torch.save(ckpt, ckpt_path) + + +_backup_envs = {} + +def _set_envs(envs): + _backup_envs.clear() + for key, value in envs.items(): + _backup_envs[key] = os.environ.get(key, None) + os.environ[key] = str(value) + +def _restore_envs(): + for key, value in _backup_envs.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(100, 10) + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, data): + x = data['x'] + x = self.linear(x) + x = self.dropout(x) + return torch.nn.functional.cross_entropy(x, data['y']) + + +class RandomDataset: + def __init__(self, length): + self.length = length + + def __getitem__(self, i): + return { + 'x': torch.rand(100), + 'y': torch.randint(10, tuple()), + } + + def __len__(self): + return self.length From 91fd30a68797eb4cd164e393b080fc7a3c0f03d2 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Thu, 14 Nov 2024 07:54:18 +0000 Subject: [PATCH 1761/1892] Merged PR 2288: make all parameters in reducer buffer to be aligned to 16 bytes Make all parameters in reducer buffer to be aligned to 16 bytes. The alignment can affect the result of torch ops. This change can ease our parity check work. --- nnscaler/parallel.py | 41 +++++++------ nnscaler/runtime/adapter/reducer.py | 77 ++++++++++++++++++++---- nnscaler/runtime/module.py | 5 +- tests/parallel_module/test_checkpoint.py | 4 +- tests/runtime/test_reducer.py | 25 +++++--- 5 files changed, 108 insertions(+), 44 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index b72522f1..8011949a 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1884,6 +1884,8 @@ def _get_optimizer_state_of_param(param, param_ids, local_names): step, opt_states, opt_state_keys = None, {}, None for param in bucket.params: sliced_new_val = _get_optimizer_state_of_param(param, param_ids, local_names) + # there are padding in the chunk, so `param.numel()` doesn't work here + param_numel = bucket.get_aligned_numel(param) # init the chunk's optimizer state if opt_state_keys is None: opt_state_keys = [key for key in sliced_new_val] @@ -1900,41 +1902,46 @@ def _get_optimizer_state_of_param(param, param_ids, local_names): # parameter range: <> # bucket range: [] + # in the following branches, we check the range including paddings. + # but in branch body, we only copy the valid range (without paddings) but update the chunk_offset with paddings. if param_offset < bucket_chunk_start \ - and bucket_chunk_start < param_offset + param.numel() < bucket_chunk_end: + and bucket_chunk_start < param_offset + param_numel < bucket_chunk_end: # case: < [ > ] - copy_size = param_offset + param.numel() - bucket_chunk_start - for key in opt_state_keys: - opt_states[key][chunk_offset:chunk_offset+copy_size] = sliced_new_val[key][-copy_size:] + copy_size = param_offset + param_numel - bucket_chunk_start + copy_size_without_padding = param_offset + param.numel() - bucket_chunk_start + if copy_size_without_padding > 0: + for key in opt_state_keys: + opt_states[key][chunk_offset:chunk_offset+copy_size_without_padding] = sliced_new_val[key][-copy_size_without_padding:] chunk_offset += copy_size elif bucket_chunk_start <= param_offset < bucket_chunk_end \ - and bucket_chunk_start <= param_offset + param.numel() < bucket_chunk_end: + and bucket_chunk_start <= param_offset + param_numel < bucket_chunk_end: # case: [ < > ] for key in opt_state_keys: opt_states[key][chunk_offset:chunk_offset+param.numel()] = sliced_new_val[key][:] - chunk_offset += param.numel() + chunk_offset += param_numel elif bucket_chunk_start <= param_offset < bucket_chunk_end \ - and param_offset + param.numel() >= bucket_chunk_end: + and param_offset + param_numel >= bucket_chunk_end: # case: [ < ] > copy_size = bucket_chunk_end - param_offset + copy_size_without_padding = min(copy_size, param.numel()) for key in opt_state_keys: - opt_states[key][chunk_offset:chunk_offset+copy_size] = sliced_new_val[key][:copy_size] + opt_states[key][chunk_offset:chunk_offset+copy_size_without_padding] = sliced_new_val[key][:copy_size_without_padding] chunk_offset += copy_size elif param_offset < bucket_chunk_start \ - and param_offset + param.numel() >= bucket_chunk_end: + and param_offset + param_numel >= bucket_chunk_end: # case: < [ ] > copy_size = bucket_chunk_end - bucket_chunk_start - for key in opt_state_keys: - opt_states[key][chunk_offset:chunk_offset + copy_size] \ - = sliced_new_val[key][bucket_chunk_start-param_offset:bucket_chunk_start-param_offset + copy_size] + copy_size_without_padding = min(copy_size, param_offset + param.numel() - bucket_chunk_start) + if copy_size_without_padding > 0: + for key in opt_state_keys: + opt_states[key][chunk_offset:chunk_offset + copy_size_without_padding] \ + = sliced_new_val[key][bucket_chunk_start-param_offset:bucket_chunk_start-param_offset + copy_size_without_padding] chunk_offset += copy_size else: # case: [] <>, <> [] - logger.debug(f'Skipped: parameter range({param_offset},{param_offset + param.numel()}) vs. bucket range({bucket_chunk_start},{bucket_chunk_end})') - param_offset += param.numel() - # as there is padding in chunk, slicing to obtain the correct shape opt states - for key in opt_state_keys: - opt_states[key] = opt_states[key][:opt_param[opt_param_idx].shape[0]] + logger.debug(f'Skipped: parameter range({param_offset},{param_offset + param_numel}) vs. bucket range({bucket_chunk_start},{bucket_chunk_end})') + param_offset += param_numel + if step is not None: opt_states['step'] = step state_dict[opt_param_idx] = opt_states diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index 6c177abb..be83fdc3 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -3,6 +3,7 @@ from typing import List, Dict, Tuple, Any, Callable, Optional, Set, Sequence from functools import partial +import math import logging import torch from torch.utils.hooks import RemovableHandle @@ -14,6 +15,33 @@ _logger = logging.getLogger(__name__) +# According to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#device-memory-accesses +# Any address of a variable residing in global memory or returned by one of the memory allocation +# routines from the driver or runtime API is always aligned to at least 256 bytes. +# But in our practice, we found that 16 bytes alignment is enough, it can be modified if unaligned access is detected. +ALIGNED_BYTES = 16 + + +def _aligned_nbyte(nelement: int, element_size: int, align_size: int = ALIGNED_BYTES) -> int: + """ + Align the number of elements, so the total byte size of elements is multiple of `align_size` + Returns: + the aligned number of bytes + """ + if align_size % element_size != 0: + raise ValueError(f"align_size {align_size} must be divisible by element_size {element_size}") + return (nelement * element_size + align_size - 1) // align_size * align_size + + +def _aligned_nelement(nelement: int, element_size: int, align_size: int = ALIGNED_BYTES) -> int: + """ + Align the number of elements, so the total byte size of elements is multiple of `align_size` + Returns: + the aligned number of elements + """ + return _aligned_nbyte(nelement, element_size, align_size) // element_size + + def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: """ Get reduce op from string @@ -39,6 +67,7 @@ def __init__(self, params: List[torch.nn.Parameter], zero_subgroup: torch.distributed.ProcessGroup = None, zero_crossgroup: torch.distributed.ProcessGroup = None, zero_use_reduce_scatter: bool = False, + align_size: int = ALIGNED_BYTES, ): """ Create a communication unit for parameter allreduce. @@ -57,6 +86,7 @@ def __init__(self, params: List[torch.nn.Parameter], zero_subgroup (torch.distributed.ProcessGroup): the subgroup for zero optimization the current rank belongs to zero_crossgroup (torch.distributed.ProcessGroup): the communication group for cross zero group allreduce when reduce scatter is enabled zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization + align_size (int): the alignment size in bytes for each parameter """ self._params: List[torch.nn.Parameter] = params @@ -78,8 +108,12 @@ def __init__(self, params: List[torch.nn.Parameter], # the parameter exposed for optimizer self._param_for_optimizer: torch.nn.Parameter = None # total number of parameters + self._align_size: int = align_size + if self._align_size % ALIGNED_BYTES != 0: + raise ValueError(f"align_size {self._align_size} must be divisible by {ALIGNED_BYTES}") + self._numel: int = sum(p.numel() for p in self._params) - self._padding: int = self._contiguous_grads.size(0) - self._numel + self._aligned_numel: int = sum(_aligned_nelement(p.nelement(), p.element_size(), self._align_size) for p in self._params) self._zero_subgroup = self._group if zero_subgroup is None else zero_subgroup self._zgroup_sz: int = torch.distributed.get_world_size(group=self._zero_subgroup) @@ -108,6 +142,12 @@ def zero(self) -> bool: """Whether enable zero for this bucket""" return self._zero + def get_aligned_numel(self, param) -> int: + """ + Get the aligned number of elements for a parameter + """ + return _aligned_nelement(param.nelement(), param.element_size(), self._align_size) + def _group_reduce_scatter(self): """currently this function is only used in synchronous mode""" rank = torch.distributed.get_rank(group=self._zero_subgroup) @@ -133,15 +173,14 @@ def build(self): Build offset for each parameter This should only be called once during the construction of bucket. """ - self._numel = sum(p.numel() for p in self._params) ofst = 0 for param in self._params: self._pofset[param] = ofst - ofst += param.numel() + ofst += _aligned_nelement(param.nelement(), param.element_size(), self._align_size) # build parameter for optimizer (shared storage). # Its gradient will be updated everytime calling `self.sync_grads()` if not self._zero: - opt = self._contiguous_params[:self._numel] + opt = self._contiguous_params else: rank = torch.distributed.get_rank(group=self._zero_subgroup) assert len(self._contiguous_params) % self._zgroup_sz == 0 @@ -209,7 +248,7 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): # same trick with FSDP and Megatron # reference: https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/fully_sharded_data_parallel.py#L3177-L3188 param_tmp = param.expand_as(param) - # gets its AccumulateGrad object. + # gets its AccumulateGrad object grad_acc = param_tmp.grad_fn.next_functions[0][0] hook = grad_acc.register_hook(partial(post_grad_hook, param)) # grad_acc must keep, otherwise the hook won't take effect @@ -253,7 +292,7 @@ def sync_grads(self): grad = self._contiguous_grads.chunk(self._zgroup_sz, dim=0)[rank] self._param_for_optimizer.grad = grad else: - self._param_for_optimizer.grad = self._contiguous_grads[:self._numel] + self._param_for_optimizer.grad = self._contiguous_grads # apply post-hooks self._apply_post_hooks() @@ -297,7 +336,7 @@ def _apply_pre_hooks(self): The pre-hooks will be applied one by one following the order of registration. """ if len(self._pre_hooks) == 0: return - grads = self._contiguous_grads[:self._numel] + grads = self._contiguous_grads for hook in self._pre_hooks: hook(grads) @@ -307,7 +346,7 @@ def _apply_post_hooks(self): The post-hooks will be applied one by one following the order of registration. """ if len(self._post_hooks) == 0: return - grads = self._contiguous_grads[:self._numel] + grads = self._contiguous_grads for hook in self._post_hooks: hook(grads) @@ -335,6 +374,7 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None reduce_op: str = 'sum', async_op: bool = False, zero: bool = False, zero_ngroups: int = 1, zero_use_reduce_scatter: bool = False, + align_size: int = ALIGNED_BYTES ): """ Create a reducer applied on a set of weights for weight reduction @@ -352,12 +392,14 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None zero (bool): whether to apply ZeRO optimization on gradients zero_ngroups (int): number of ZeRO subgroups in the original ZeRO group zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization + align_size (int): the alignment size in bytes for each parameter """ self._params: List[torch.nn.Parameter] = list() self._param_ids: Set[int] = set() self._numel: int = 0 self._ranks = ranks self._group = DeviceGroup().get_group(ranks) + self._wsz: int = torch.distributed.get_world_size(group=self._group) self._bucket_size: Optional[int] = max_bucket_size_bytes if not self._bucket_size and async_op: @@ -369,6 +411,10 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None self._async: bool = async_op self._zero: bool = zero self._zero_use_reduce_scatter = zero_use_reduce_scatter + self._align_size: int = align_size + if self._align_size % ALIGNED_BYTES != 0: + raise ValueError(f"align_size {self._align_size} must be divisible by {ALIGNED_BYTES}") + # contiguous parameter buffer and gradient buffer self._contiguous_params: torch.Tensor = None self._contiguous_grads: torch.Tensor = None @@ -489,7 +535,7 @@ def build_buckets(self): ) for param in self._params: if param.requires_grad: - cur_byte_size = param.nelement() * param.element_size() + cur_byte_size = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) * param.element_size() # also work when cur_byte_size > bucket_size # It will go the `else` branch # and finish the current bucket and start a new bucket. @@ -510,8 +556,11 @@ def build_buckets(self): starts, stops = [], [] for params in seq_buckets: starts.append(buffer_length) - numel = sum(p.numel() for p in params) - padding = (len(self._ranks) - numel % len(self._ranks)) % len(self._ranks) + numel = sum(_aligned_nelement(p.nelement(), p.element_size(), self._align_size) for p in params) + # this pad is for zero, which needs numels in each Bucket can be divided by the number of ranks in this group * _align_size + # so that each chunck during zero can be divided by _align_size + align_nelements = self._align_size // params[0].element_size() * len(self._ranks) + padding = (align_nelements - numel % align_nelements) % len(self._ranks) buffer_length += numel + padding stops.append(buffer_length) @@ -521,7 +570,7 @@ def build_buckets(self): (buffer_length,), dtype=self._params[0].dtype, device=torch.cuda.current_device(), requires_grad=False) # parameter buffer - self._contiguous_params: torch.Tensor = torch.empty( + self._contiguous_params: torch.Tensor = torch.zeros( (buffer_length,), dtype=self._params[0].dtype, device=torch.cuda.current_device(), requires_grad=False) @@ -534,7 +583,8 @@ def build_buckets(self): with torch.no_grad(): self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) - ofst += param.numel() + aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) + ofst += aligned_nelements # initialize buckets bucket = Bucket( params, @@ -547,6 +597,7 @@ def build_buckets(self): self._zero_subgroup, self._zero_crossgroup, self._zero_use_reduce_scatter, + self._align_size, ) buckets.append(bucket) torch.cuda.empty_cache() diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 8c0b0403..6dc7cf97 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -1093,9 +1093,10 @@ def _get_zero_metadata(self) -> ZeroMetadata: pstart, pend = 0, 0 for param in bucket.params: pstart = pend - pend += param.numel() + pend = pstart + bucket.get_aligned_numel(param) + pend_without_padding = pstart + param.numel() model_idx = model_params_id.index(id(param)) - model_idx2opt_idx[model_idx] = (opt_idx, pstart, pend, param.shape) + model_idx2opt_idx[model_idx] = (opt_idx, pstart, pend_without_padding, param.shape) assert len(bucket._contiguous_params.shape) == 1 opt_idx2ranks[opt_idx] = (sub_ranks, bucket._contiguous_params.shape[0]) opt_idx += 1 diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index a0be40ca..627f76bb 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -364,11 +364,11 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf results = [] for i, (x, y) in enumerate(data): y_pred, loss = train_step(model, x, y, optimizer) - grads = {n: p.grad for n, p in model.named_parameters()} + grads = {n: p.grad.clone() for n, p in model.named_parameters()} gnorm = optimizer.clip_gnorm() results.append(clone_to_cpu_recursively([y_pred, loss, grads, gnorm])) optimizer.zero_grad() - weights = {n: p.data for n, p in model.named_parameters()} + weights = {n: p.data.clone() for n, p in model.named_parameters()} results[-1].append(clone_to_cpu_recursively(weights)) results[-1] = StepResult(*results[-1]) diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py index 3f080af0..2721c3ee 100644 --- a/tests/runtime/test_reducer.py +++ b/tests/runtime/test_reducer.py @@ -140,18 +140,23 @@ def reducer_test(): @mock_reducer_env(0, 2) def test_reducer_build(): - reducer = Reducer([0, 1], max_bucket_size_bytes=16) # 16 bytes means 4 float32 - reducer.add_param(torch.nn.Parameter(torch.randn(1, 2))) # small at first - reducer.add_param(torch.nn.Parameter(torch.randn(1, 10))) # bigger than max_bucket_size_bytes - reducer.add_param(torch.nn.Parameter(torch.randn(1, 3))) # small again - reducer.add_param(torch.nn.Parameter(torch.randn(1, 3))) # small again - reducer.add_param(torch.nn.Parameter(torch.randn(1, 1))) # small again - reducer.add_param(torch.nn.Parameter(torch.randn(1, 1))) # small again + reducer = Reducer([0, 1], max_bucket_size_bytes=48) # 24 bytes means 12 float32 + reducer.add_param(torch.nn.Parameter(torch.randn(1, 2))) # 4 floats # small at first + reducer.add_param(torch.nn.Parameter(torch.randn(1, 14))) # 16 floats # bigger than max_bucket_size_bytes + reducer.add_param(torch.nn.Parameter(torch.randn(1, 5))) # 8 floats # small again + reducer.add_param(torch.nn.Parameter(torch.randn(1, 5))) # 8 floats # small again + reducer.add_param(torch.nn.Parameter(torch.randn(1, 1))) # 4 floats small again + reducer.add_param(torch.nn.Parameter(torch.randn(1, 1))) # 4 floats small again reducer.build_buckets() assert len(reducer.buckets) == 5 buckets = list(reversed(reducer.buckets)) assert buckets[0].numel == 2 - assert buckets[1].numel == 10 - assert buckets[2].numel == 3 - assert buckets[3].numel == 4 + assert buckets[0]._aligned_numel == 4 + assert buckets[1].numel == 14 + assert buckets[1]._aligned_numel == 16 + assert buckets[2].numel == 5 + assert buckets[2]._aligned_numel == 8 + assert buckets[3].numel == 6 + assert buckets[3]._aligned_numel == 12 assert buckets[4].numel == 1 + assert buckets[4]._aligned_numel == 4 \ No newline at end of file From 9fb800d11640bff2713823677422bd8a148ec9e1 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 15 Nov 2024 06:48:52 +0000 Subject: [PATCH 1762/1892] Merged PR 2308: [BugFix] add detach loss in codegen to deallocate tensors correctly add detach loss in codegen --- nnscaler/codegen/schedule/schedule.py | 43 +++-- tests/parallel_module/test_e2e_detach_loss.py | 173 ++++++++++++++++++ 2 files changed, 205 insertions(+), 11 deletions(-) create mode 100644 tests/parallel_module/test_e2e_detach_loss.py diff --git a/nnscaler/codegen/schedule/schedule.py b/nnscaler/codegen/schedule/schedule.py index c21a764e..6e260b9d 100644 --- a/nnscaler/codegen/schedule/schedule.py +++ b/nnscaler/codegen/schedule/schedule.py @@ -136,6 +136,12 @@ def gen(self, device: int, outfile=None, attach=None) -> str: f.write(code) return code + def emit_detach(self, tensor: IRTensor) -> str: + """ + Emit detach code + """ + return f'{self.tensor_name(tensor)} = {self.tensor_name(tensor)}.detach()' + def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: """ Emit node / subgraph code @@ -158,13 +164,13 @@ def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: if isinstance(unwrap_node, IRSegment): # emit forward segment if node.isfw(): - code = fsign.format( + codes = [fsign.format( outputs = outputs, name = f"'{name}'", model = f'model.{name}', inputs = inputs, req_grad = req_grad - ) + )] else: # get gradient computation arguments input_tensors, output_tensors, output_grads, input_grads = \ @@ -173,38 +179,53 @@ def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: for idx, tensor in enumerate(output_grads): if isinstance(tensor, IRSubTensor) and tensor.is_loss(): output_grads[idx] = None - code = bsign.format( + codes = [bsign.format( name = f"'{self.node_name(unwrap_node.mirror)}'", input_grads = self.return_name(input_grads), input_tensors = self.tuple_name(input_tensors, skip_attr=True, prefix_attr='model.'), output_tensors = self.tuple_name(output_tensors, skip_attr=True, prefix_attr='model.'), output_grads = self.tuple_name(output_grads, skip_attr=True, prefix_attr='model.') - ) + )] + """ + In the end2end mode, although the graph's output may contain tensors that requires grad, + like the loss tensor, the backward pass has been done by the nnscaler runtime by calling + `nnscaler.runtime.executor.backward` in the generated code. In other words, the returned + loss cannot be used by `loss.backward()`. + In pipeline parallelism, loss tensors of micro-batches are generated at the last stage and + transferred to remaining stages at the end for `_train_step` function. We add the detach + operation here so that the backward graph's tensors can be deallocated right after the + backward pass. + """ + for tensor in output_tensors: + if not isinstance(tensor, IRTensor): + continue + if tensor in self.execplan.outputs(): + codes.append(self.emit_detach(tensor)) elif isinstance(unwrap_node, IRDataOperation): - code = f'{outputs} = {unwrap_node.signature}(*{inputs})' + codes = [f'{outputs} = {unwrap_node.signature}(*{inputs})'] elif isinstance(unwrap_node, IRAdapter): - code = asign.format( + codes = [asign.format( outputs = outputs, model = f'model.{name}', inputs = inputs, req_grad = req_grad - ) + )] elif isinstance(unwrap_node, IRWeightReducer): - code = asign.format( + codes = [asign.format( outputs = outputs, model=f'model.{name}', inputs='()', req_grad=req_grad - ) + )] else: raise RuntimeError(f"Unspported node type: {type(unwrap_node)}") if CompileFlag.line_timer: type_str = type(unwrap_node).__name__ - return [f'nnscaler.runtime.function.print_time({repr(type_str)})', code] + return [f'nnscaler.runtime.function.print_time({repr(type_str)})'] + codes else: - return [code] + return codes diff --git a/tests/parallel_module/test_e2e_detach_loss.py b/tests/parallel_module/test_e2e_detach_loss.py new file mode 100644 index 00000000..aaa9648a --- /dev/null +++ b/tests/parallel_module/test_e2e_detach_loss.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +import tempfile +import shutil +import contextlib +import pytest +from pathlib import Path + + +import nnscaler +import nnscaler.graph.function.function as F +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph import IRGraph +from nnscaler.ir.adapter import IRAdapter +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.ir.operator import IRFwOperation, IRDataOperation +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.schedule.predefined import PredefinedSched +from tests.utils import clear_dir_on_rank0, init_random +from tests.launch_torchrun import torchrun +from tests.parallel_module.test_gencode import _gencode_contains + + +def get_mem(): + return torch.cuda.max_memory_allocated() // 1024 // 1024 + + +class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(4096, 4096, bias=False) + self.fc2 = torch.nn.Linear(4096, 4096, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x.sum() + + +def policy_pp(graph, cfg): + data_loader, fc1, fc2, loss = graph.nodes()[:4] + graph.staging([fc1, fc2]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + sub_nodes = graph.replicate(data_loader, 4) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + sub_nodes = graph.partition(fc1, fc1.algorithms('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + identity = stages[1].nodes()[0] + sub_nodes = graph.replicate(identity, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.partition(fc2, fc2.algorithms('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.partition(loss, loss.algorithms('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + PredefinedSched.sched_1f1b(graph, 4, len(stages)) + + return graph + + +def worker_pipeline_2x2(): + nnscaler.init() + m = Model() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) + + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2x2') as tempdir: + pm = parallelize( + m, + {'x': trace_data}, + policy_pp, + ComputeConfig(4, 4, use_end2end=True), + reuse='override', + gen_savedir=tempdir, + ) + + if pm.rank in [2, 3]: + assert len(_gencode_contains(tempdir, Model, pm.rank, 'detach\(\)')) == 4 + + samples = [torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) for _ in range(4)] + ret = pm.train_step(samples) + mem0 = get_mem() + ret = pm.train_step(samples) + mem1 = get_mem() + ret = pm.train_step(samples) + mem2 = get_mem() + ret = pm.train_step(samples) + mem3 = get_mem() + # print(mem0, mem1, mem2, mem3) + assert mem0 == mem1 == mem2 == mem3 + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_detach_loss_pipeline_hard(): + # should not raise any exception + torchrun(4, worker_pipeline_2x2) + + +def policy_easy(graph, cfg): + data_loader, fc1, fc2, loss = graph.nodes()[:4] + graph.staging([fc1, fc2]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + sub_nodes = graph.replicate(data_loader, 2) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + graph.assign(fc1, 0) + + identity = stages[1].nodes()[0] + graph.assign(identity, 1) + graph.assign(fc2, 1) + graph.assign(loss, 1) + + PredefinedSched.sched_1f1b(graph, 4, len(stages)) + + return graph + + +def worker_pipeline_2(): + nnscaler.init() + m = Model() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) + + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2') as tempdir: + pm = parallelize( + m, + {'x': trace_data}, + policy_easy, + ComputeConfig(2, 2, use_end2end=True), + reuse='override', + gen_savedir=tempdir, + ) + pm.to(torch.cuda.current_device()) + + if pm.rank == 1: + assert len(_gencode_contains(tempdir, Model, pm.rank, 'detach\(\)')) == 4 + samples = [torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) for _ in range(4)] + ret = pm.train_step(samples) + mem0 = get_mem() + ret = pm.train_step(samples) + mem1 = get_mem() + ret = pm.train_step(samples) + mem2 = get_mem() + ret = pm.train_step(samples) + mem3 = get_mem() + assert mem0 == mem1 == mem2 == mem3 + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_detach_loss_pipeline_easy(): + torchrun(2, worker_pipeline_2) + # should not raise any exception + assert True From cfc53405e1ca0c7ffef0be4097235738193a0328 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 15 Nov 2024 06:51:47 +0000 Subject: [PATCH 1763/1892] Merged PR 2309: [Bugfix] tracer: handle importlib carefully Currently we only wrap function calls via `ProxyCallTransformer`, but some functions can be triggered by getattr (e.g. torch._dynamo), in which it is impossible to 'unpatch' the calls triggered inside `getattr` (Make it leaf function doesn't work because we don't have opportunity to unpatch it). This PR is trying to fix a special case when `importlib.import_module` is called inside `getattr`. Two reasons: 1. import_module will run the code of modules recursively (the imported module can import other modules), and is easy to have problem. 2. `torch._dynamo` is used in some pytorch modules, and we want to support it. Two ways to fix this: 1. Handle `import_module` in a special way. 2. Refine the `ProxyCallTransformer` to patch getattr as well. The second way is more general, but it's more complex and may introduce potential bugs. In this PR, we choose the first way as a quick fix. --- nnscaler/graph/parser/converter.py | 8 +++ nnscaler/graph/tracer/concrete_tracer.py | 15 +++++- nnscaler/graph/tracer/orig_func.py | 3 ++ nnscaler/graph/tracer/wrap_utils.py | 22 ++++++++- tests/parallel_module/test_gencode.py | 62 ++++++++++++++++++++++++ 5 files changed, 107 insertions(+), 3 deletions(-) diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 308fff73..9840fba8 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -82,6 +82,14 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: autowrap_funcs = [fn for fn in autowrap_funcs if not is_autograd_apply(fn)] leaf_functions = {func: LeafWrapInfo([], True, None) for func in autowrap_funcs if func is not None} + # importlib functions + # currently only import_module is handled in the code + import importlib + leaf_functions.update({ + func: LeafWrapInfo([Location(importlib, func.__name__)], False, None) + for func in [importlib.import_module] + }) + # get cube runtime functions cube_rt_funcs = [cube_rt_function.anchor, cube_rt_function.ifexpr] leaf_functions.update({ diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index 6626bedb..f1a44179 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -7,6 +7,7 @@ import collections import copy +from functools import partial import sys import inspect import logging @@ -640,8 +641,20 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): wrapped = wrap_utils.create_wrapped_nn_module_func(self, mod, forward_function_name) self.wrapped_leaf[mod.forward] = ((wrap_utils.Location(mod, forward_function_name),), wrapped) + # make sure never_wrap_function are called right + def wrap_never_wrap_function(func, *args, **kwargs): + if self.patcher.patch_mode: + with self.patcher.revert(): + return func(*args, **kwargs) + else: + return func(*args, **kwargs) + try: with self.patcher: + for func, leaf_info in wrap_utils.default_never_wrap_function.items(): + for loc in leaf_info.extra_locs: + self.patcher.patch_method(loc.ns, loc.name, partial(wrap_never_wrap_function, func), deduplicate=True) + # allow duplicate patches to support the case of nested calls self.patcher.patch_method(torch.nn.Module, "__getattribute__", wrap_utils.create_wrapped_module_getattribute(self), deduplicate=False) @@ -661,7 +674,7 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): for obj, (positions, wrapped) in self.wrapped_leaf.items(): for loc in positions: self.patcher.patch_method(loc.ns, loc.name, wrapped, deduplicate=False) - + wrap_utils.autowrap_check(self, fn_globals) with OperatorPatcherContext(self, use_operator_patch, operator_patch_backlist): diff --git a/nnscaler/graph/tracer/orig_func.py b/nnscaler/graph/tracer/orig_func.py index cae65f7e..cad0680b 100644 --- a/nnscaler/graph/tracer/orig_func.py +++ b/nnscaler/graph/tracer/orig_func.py @@ -48,3 +48,6 @@ torch_assert = torch._assert torch_Size = torch.Size torch_finfo = torch.finfo + +import importlib +import_module = importlib.import_module diff --git a/nnscaler/graph/tracer/wrap_utils.py b/nnscaler/graph/tracer/wrap_utils.py index 28f63905..f495fd04 100644 --- a/nnscaler/graph/tracer/wrap_utils.py +++ b/nnscaler/graph/tracer/wrap_utils.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field import functools import operator +import importlib from types import MethodType, ModuleType from typing import Any, Dict, Optional, Type, List, Callable, Union, TYPE_CHECKING, Tuple @@ -127,6 +128,23 @@ def _functions_in_module(module: ModuleType): yield op, name +# the functions that should never be wrapped +# TODO: +# currently we only have import_module, and should add more if needed +# Putting these functions as leaf functions doesn't work +# because +# 1. We only wrap function calls via `ProxyCallTransformer` +# 2. But some functions can be triggered by getattr (e.g. torch._dynamo) +# Two ways to fix this: +# 1. Handle popular functions in `default_never_wrap_function` in a special way. +# 2. Refine the `ProxyCallTransformer` to handle getattr as well. +# The second way is more general, but it's more complex and may introduce potential bugs. +# For now, we choose the first way as a quick fix. +default_never_wrap_function: Dict[Callable, LeafWrapInfo] = { + orig_func.import_module: LeafWrapInfo([Location(importlib, 'import_module')], False, None) +} + + # get all functions in the default_autowrap_modules and add them to default_autowrap_leaf_function default_autowrap_modules = (operator, math, torch, torch.functional, torch.nn.functional) for module in default_autowrap_modules: @@ -220,7 +238,7 @@ def create_wrapped_leaf_class(clz, *, replace_cls: Optional[Callable]=None, defa x_value = int(x) new_x = torch.tensor([x_value, x_value]) ... - + Args: clz : the original class. replace_cls : forward the call to another function. @@ -441,7 +459,7 @@ def detect_tracer(obj): raise Exception('more than 1 tracer detected. please report the issue') elif len(tracers) == 1: return next(iter(tracers)).create_proxy('call_function', orig_func.tuple, (results,), {}) - + return orig_func.tuple(results) def __eq__(self, __o: object) -> bool: diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 85f699fe..320baf31 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -8,6 +8,7 @@ import torch import torch.nn.functional as F import pytest +from torch.torch_version import TorchVersion from nnscaler.flags import CompileFlag import nnscaler.graph.function.dimops @@ -1330,6 +1331,67 @@ def test_codegen_scalar_tensor(tmp_path): r"self\.add_full_map\('num_batches_tracked_\d+', \d+, False, 'num_batches_tracked', \(\), \.\.\., 1\)") +class ImportlibModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = torch.nn.Linear(1024, 1024) + + def forward(self, data): + import importlib + x = importlib.import_module('datetime') + r = self.model(data + x.datetime.now().year) + return torch.sum(r) + + +@replace_all_device_with('cpu') +def test_codegen_importlib(tmp_path): + m = ImportlibModel() + m.train() + parallelize( + m, + {'data': torch.randn(1024, 1024)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # should success + assert True + + +class ImportlibModel2(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = torch.nn.Linear(1024, 1024) + + def forward(self, data): + torch._dynamo + r = self.model(data) + return torch.sum(r) + + +@replace_all_device_with('cpu') +@pytest.mark.skipif(torch.torch_version.__version__ < (2,1,0), reason='torch._dynamo is not a valid import') +def test_codegen_importlib2(tmp_path): + m = ImportlibModel2() + m.train() + parallelize( + m, + {'data': torch.randn(1024, 1024)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + import nnscaler.graph.tracer.orig_func as orig_func + import importlib + assert orig_func.import_module == importlib.import_module + # should success + assert True + + class ConvTranspose1DModule(torch.nn.Module): def __init__(self, weight, bias=None, stride=1, padding=0, output_padding=0, dilation=1, groups=1): super().__init__() From 592ff83ee5009a1092486c3a3263cfbb8d857c39 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 19 Nov 2024 03:33:33 +0000 Subject: [PATCH 1764/1892] Merged PR 2316: [BugFix] fix progress bar and file expire in trainer 1. Progress bar reports wrong speed (because we set progress to 1 at the beginning) 2. File Expire doesn't respect the symbol links in best/last due to a regression change. --- nnscaler/cli/trainer.py | 34 ++++++++++++++-------------------- tests/cli/test_resume_seed.py | 2 +- tests/cli/test_trainer.py | 3 +++ 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index a965cbcf..62f88063 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -487,22 +487,25 @@ def _expire_checkpoints(self): if len(checkpoints) <= self.train_args.checkpoint.keep_last_n_checkpoints: return - # (step, num) pairs + # (step, ckpt_name) pairs checkpoint_info = [(int(p.split('-')[1]), p) for p in checkpoints] checkpoint_info.sort() - expire_list = checkpoint_info[:-self.train_args.checkpoint.keep_last_n_checkpoints] + expire_list = [c[1] for c in checkpoint_info[:-self.train_args.checkpoint.keep_last_n_checkpoints]] best_ckpt = save_dir / CHECKPOINT_BEST_DIR_NAME - if best_ckpt.exists(): - for p in best_ckpt.glob('*.ckpt'): + last_ckpt = save_dir / CHECKPOINT_LAST_DIR_NAME + for ckpt_dir in [best_ckpt, last_ckpt]: + if not ckpt_dir.exists(): + continue + for p in ckpt_dir.glob('*.ckpt'): if p.is_symlink(): ckpt_name = p.resolve().parent.name if ckpt_name in expire_list: expire_list.remove(ckpt_name) - logger.info('Keep old checkpoint `%s` because it is the best.', ckpt_name) + logger.info('Keep old checkpoint `%s` because it is symbol linked in best or last.', ckpt_name) break # just check the first file is enough - for _, ckpt_name in expire_list: + for ckpt_name in expire_list: logger.info('Removing old checkpoint: %s', ckpt_name) shutil.rmtree(save_dir / ckpt_name) @@ -777,26 +780,17 @@ def _train_epoch(self, epoch: int) -> None: epoch_desc = f'Epoch {format(epoch, epoch_format)}' if self.rank == 0: - progress = tqdm( - None, + data_iter = tqdm( + data_iter, total=self.total_train_steps_per_epoch, initial=resume_from_idx, desc=epoch_desc, disable=not self.train_args.enable_progress_bar, ) - else: - progress = None step_stat: Optional[_StepStat] = None for i, batches in data_iter: idx = i + resume_from_idx - - if self.rank == 0: - # looks manually update progress bar is easier - # than using tqdm directly - # the difference is we update progress bar at the beginning of the loop - # instead of the end of the loop - progress.update(1) step_start_at = time.perf_counter() step_stat = _StepStat() step_metrics = {} @@ -873,7 +867,7 @@ def _train_epoch(self, epoch: int) -> None: step_metrics['train_wall'] = time.perf_counter() - step_start_at self.log_metrics(step_metrics, tag='train') if self.rank == 0: - progress.set_postfix(step_metrics) + data_iter.set_postfix(step_metrics) if self.train_args.enable_log_progress \ and self.train_status.finished_train_steps % self.train_args.log_progress_every_n_train_steps == 0: logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) @@ -895,8 +889,8 @@ def _train_epoch(self, epoch: int) -> None: has_validated = VAL_STATUS_SAVE if self.rank == 0: # disable refresh the progress bar to avoid redundant progress bar - progress.leave = False - progress.close() + data_iter.leave = False + data_iter.close() break if not has_validated and self.train_args.val_every_n_train_steps and \ diff --git a/tests/cli/test_resume_seed.py b/tests/cli/test_resume_seed.py index 30d4992a..53f31686 100644 --- a/tests/cli/test_resume_seed.py +++ b/tests/cli/test_resume_seed.py @@ -5,7 +5,7 @@ from nnscaler.cli.trainer_args import * -@pytest.mark.skipif(not torch.cuda.is_available(), reason='no gpu') +@pytest.mark.skipif(True, reason='no gpu') def test_resume_seed(): _set_envs({ # required by deterministic diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 30bafde5..7981b952 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -259,6 +259,9 @@ def trainer_last_checkpoint_worker(save_dir): gen_savedir = save_dir / 'gen' ckpt_savedir = save_dir / 'ckpt' + for i in range (100): # make a lot of fake checkpoints + (ckpt_savedir / f'0000-{i*15:04d}').mkdir(parents=True, exist_ok=True) + trainer = Trainer([ '-f', config_path, '--max_epochs', '1', From 5fc33262ea5b0fab5c2f6d246518036d6233ed25 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 19 Nov 2024 11:02:13 +0000 Subject: [PATCH 1765/1892] Merged PR 2311: [BugFix] Estimate memory cost correctly in Pipeline Solver In autodist, memory cost for a parameter is estimated according to - world size - scale unit size (ngpus) - zero_stage - zero_ngroups In pipeline solver, we need to call multiple SPMDSolver instances with different configs listed above. In this PR, we fix the bug to pass in config correctly to the `CostDatabase`. Related work items: #2082 --- nnscaler/autodist/autodist_config.py | 5 +- nnscaler/autodist/cost_database.py | 82 ++++++++++++++++------------ nnscaler/autodist/model_graph.py | 9 ++- nnscaler/autodist/pipeline_solver.py | 34 ++++++------ nnscaler/autodist/spmd_solver.py | 9 +-- 5 files changed, 82 insertions(+), 57 deletions(-) diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 511eb425..a846f550 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -166,7 +166,6 @@ def __init__(self, self.opt_transient_coef = opt_transient_coef self.is_train = is_train self.mesh_desc = MeshDesc(mesh_row, mesh_col) - self.ngpus = self.mesh_desc.row * self.mesh_desc.col self.recompute_modules = recompute_modules # from GB to Byte self.memory_constraint = int(memory_constraint * 1024 * 1024 * 1024) @@ -258,3 +257,7 @@ def _validate_config(self): def __repr__(self): return f'{self.__class__.__name__} {self.__dict__}' + + @property + def ngpus(self): + return self.mesh_desc.ngpus diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 75f0c754..5a4451b0 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -113,13 +113,12 @@ def _profile_graph(dilled_info: str, dev_id: int, partition_degree: int, re_prof class CostDatabase: - def __init__(self, graph: IRGraph, config: AutoDistConfig): + def __init__(self, graph: IRGraph, profile_dir: str, memory_granularity: int, ignore_small_tensor_threshold: int): self.comm_info = {} self.graph = graph - self.autodist_config = config - self.profile_dir = Path(config.profile_dir) + self.profile_dir = Path(profile_dir) self.db = ProfileDataBase() self.comp_profile_path = self.profile_dir / 'comp' if not self.comp_profile_path.exists(): @@ -134,17 +133,17 @@ def __init__(self, graph: IRGraph, config: AutoDistConfig): with open(comm_dir / fname, 'r') as f: self.comm_info[fname] = json.load(f) - self.memory_granularity = self.autodist_config.memory_granularity - self.ignore_small_tensor_threshold = self.autodist_config.ignore_small_tensor_threshold + self.memory_granularity = memory_granularity + self.ignore_small_tensor_threshold = ignore_small_tensor_threshold - def profile_comp(self, partition_degree: int): + def profile_comp(self, partition_degree: int, parallel_profile: bool, re_profile: bool): def insert_profile_info(info: List[Tuple[str, str, ProfiledMetrics]]): for sign, serialized, profiled_metrics in info: _logger.debug(f'profiled {sign} in {serialized} with {profiled_metrics}') if not self.db.exist_serialized(sign, serialized): self.db.insert(sign, serialized, profiled_metrics) - if self.autodist_config.parallel_profile: + if parallel_profile: _logger.info('Profiling in parallel') # use spawn to make sure the profiling process is independent from each other # and the main process, this is also required by torch @@ -154,7 +153,7 @@ def insert_profile_info(info: List[Tuple[str, str, ProfiledMetrics]]): processes = [] for i in range(torch.cuda.device_count()): p = mp_context.Process(target=_profile_graph, - args=(self.graph.dumps(), i, partition_degree, self.autodist_config.re_profile, self.comp_profile_path, results)) + args=(self.graph.dumps(), i, partition_degree, re_profile, self.comp_profile_path, results)) processes.append(p) p.start() @@ -169,7 +168,7 @@ def insert_profile_info(info: List[Tuple[str, str, ProfiledMetrics]]): else: _logger.info('Profiling in serial') node_to_profile = _filter_nodes(self.graph, self.db) - ret = _profile_nodes(node_to_profile, self.db, partition_degree, self.autodist_config.re_profile) + ret = _profile_nodes(node_to_profile, self.db, partition_degree, re_profile) insert_profile_info(ret) self.db.dump_ops(self.comp_profile_path, override=True) @@ -209,12 +208,31 @@ def get_mems(self, op_partition): memory_results[memory_type] = mem return memory_results - def get_mem_and_buffer(self, op_partition, is_train: bool, stage_num: int): + def get_mem_and_buffer( + self, + op_partition, + is_train: bool, + stage_num: int, + world_size: int, + plan_ngpus: int, + zero_stage: int, + zero_ngroups: int, + opt_resident_coef: float, + opt_transient_coef: float + ) -> Tuple[int, int, int, int, int]: """ Get the memory consumption and buffer memory consumption of a partition option. Args: op_partition: the partition option to be calculated + is_train: whether the partition is for training + stage_num: the number of stages + world_size: the total number of devices + plan_ngpus: the number of GPUs planned + zero_stage: the zero optimization stage + zero_ngroups: the number of zero optimization groups + opt_resident_coef: the coefficient for optimizer resident memory + opt_transient_coef: the coefficient for optimizer transient memory Returns: node_mem: the memory consumption of the partition option @@ -225,55 +243,51 @@ def get_mem_and_buffer(self, op_partition, is_train: bool, stage_num: int): """ memory_results = self.get_mems(op_partition) activation_mem = memory_results['train'] - if not self.autodist_config.zero_stage in [0, 1]: - raise RuntimeError( - f'invalid zero stage {self.autodist_config.zero_stage}') + if zero_stage not in [0, 1]: + raise RuntimeError(f'invalid zero stage {zero_stage}') + # estimate optimizer memory consumption for training. # no gradient no memory consumption, # weight_mem should be 0 when require_grad is false. opt_resident_mem, opt_transient_mem = 0, 0 if is_train and memory_results['param'] > 0: - if self.autodist_config.zero_stage == 0: + if zero_stage == 0: weight_mem = memory_results['param'] else: # if zero-1 is used, we assume the full weight is distributed equally # among all devices weight_mem = self.query_single_mem(op_partition, 'full_weight') - opt_resident_mem = self.autodist_config.opt_resident_coef * weight_mem - opt_transient_mem = self.autodist_config.opt_transient_coef * weight_mem - if self.autodist_config.zero_stage == 1: + opt_resident_mem = opt_resident_coef * weight_mem + opt_transient_mem = opt_transient_coef * weight_mem + if zero_stage == 1: if op_partition.is_replicated(): - assert self.autodist_config.world_size % self.autodist_config.ngpus == 0 - scale_factor = self.autodist_config.world_size // self.autodist_config.ngpus - divisor = scale_factor // self.autodist_config.zero_ngroups + assert world_size % plan_ngpus == 0, f'world_size {world_size} is not divisible by ngpus {plan_ngpus}' + scale_factor = world_size // plan_ngpus + divisor = scale_factor // zero_ngroups else: - assert self.autodist_config.world_size % self.autodist_config.zero_ngroups == 0 - divisor = self.autodist_config.world_size // self.autodist_config.zero_ngroups + assert world_size % zero_ngroups == 0 + divisor = world_size // zero_ngroups opt_resident_mem = opt_resident_mem // divisor opt_transient_mem = opt_transient_mem // divisor # optimizer state + saved activation tensors for backward + param # + gradients + buffer tensors (has deduplicated with the saved tensors) - node_mem = opt_resident_mem + memory_results[ - 'train'] + 2 * memory_results['param'] + memory_results['buffer'] - node_mem = node_mem + (stage_num - 1) * activation_mem \ - if is_train else node_mem - node_buffer = max(memory_results.values()) \ - if is_train else memory_results['infer'] + node_mem = opt_resident_mem + memory_results['train'] + 2 * memory_results['param'] + memory_results['buffer'] + node_mem = node_mem + (stage_num - 1) * activation_mem if is_train else node_mem + node_buffer = max(memory_results.values()) if is_train else memory_results['infer'] if node_mem != 0: - def to_mb(x): return x / 1024 / 1024 _logger.debug( f'{op_partition.operator.ir_cell.cid}, {op_partition.ir_cell}, ' - + f'node mem: {to_mb(node_mem)} MB, ' + - f'activation mem: {to_mb(activation_mem)} MB, ' + - f'optimizer transient mem: {to_mb(opt_transient_mem)} MB') + + f'node mem: {to_mb(node_mem)} MB, ' + + f'activation mem: {to_mb(activation_mem)} MB, ' + + f'optimizer transient mem: {to_mb(opt_transient_mem)} MB' + ) - return node_mem, node_buffer, activation_mem, opt_transient_mem, memory_results[ - 'input'] + return node_mem, node_buffer, activation_mem, opt_transient_mem, memory_results['input'] def query_single_mem(self, obj, memory_type, round=True) -> int: """ diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 8815e3a1..552815d2 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -523,8 +523,13 @@ class ModelGraph: def __init__(self, ir_graph: IRGraph, autodist_config: AutoDistConfig): self.ir_graph = ir_graph self.autodist_config = autodist_config - self.cost_database = CostDatabase(self.ir_graph, self.autodist_config) - self.cost_database.profile_comp(partition_degree=1) + self.cost_database = CostDatabase( + self.ir_graph, + profile_dir=autodist_config.profile_dir, + memory_granularity=autodist_config.memory_granularity, + ignore_small_tensor_threshold=autodist_config.ignore_small_tensor_threshold, + ) + self.cost_database.profile_comp(1, autodist_config.parallel_profile, autodist_config.re_profile) self.scope_tree_root = self.reconstruct_scope_tree() self.scope_leaf_nodes = self.scope_tree_root.select(lambda x: x.is_leaf) diff --git a/nnscaler/autodist/pipeline_solver.py b/nnscaler/autodist/pipeline_solver.py index 306c8b4d..90ae3213 100644 --- a/nnscaler/autodist/pipeline_solver.py +++ b/nnscaler/autodist/pipeline_solver.py @@ -153,10 +153,26 @@ def process_case(device_num, stage_num): return None, [], [] # postpone the initialization of SPMDSolver to save time cur_cfg = copy.deepcopy(cfg) + # In current parallel profiler's implementation, the profiling is divided into + # following steps: + # 1. searialize the input graph + # 2. lauch the multi-process profiling by python's spawn method + # 3. each process loads the serialized graph and do profiling + # 4. transport the profiling result back to the main process + # It helps to reduce the profiling time when the graph has not been met before. + # But the procedure itself has a large overhead. + # In PipelineSolver, the SPMDSolver is constructed and used to search the optimal + # plan for multiple times. For given `tp_degree`, cases that need to be profiled + # are the same. As a result, we set `cfg.parallel_profile` to True at the first time + # and set it to False for the rest of the time. + if stage_num == 1: + cur_cfg.parallel_profile = True + else: + cur_cfg.parallel_profile = False cur_cfg.world_size = cfg.world_size // cfg.mesh_desc.ngpus * device_num + cur_cfg.mesh_desc = _dev_num2mesh_desc(device_num, cfg.mesh_desc.col) solver = SPMDSolver(graph=model_graph, - mesh_desc=_dev_num2mesh_desc( - device_num, cfg.mesh_desc.col), + mesh_desc=cur_cfg.mesh_desc, autodist_config=cur_cfg, stage_num=stage_num) @@ -195,22 +211,8 @@ def shift_plan(solver, spmd_desc, offset: int, shifted_start: int, shifted_end: tp_info = {} for tp_degree in legal_tp_degrees: - # In current parallel profiler's implementation, the profiling is divided into - # following steps: - # 1. searialize the input graph - # 2. lauch the multi-process profiling by python's spawn method - # 3. each process loads the serialized graph and do profiling - # 4. transport the profiling result back to the main process - # It helps to reduce the profiling time when the graph has not been met before. - # But the procedure itself has a large overhead. - # In PipelineSolver, the SPMDSolver is constructed and used to search the optimal - # plan for multiple times. For given `tp_degree`, cases that need to be profiled - # are the same. As a result, we set `cfg.parallel_profile` to True at the first time - # and set it to False for the rest of the time. - cfg.parallel_profile = True for stage_num in range(1, _calc_upper_bound(tp_degree) + 1): solver, intervals, solver_ret = process_case(tp_degree, stage_num) - cfg.parallel_profile = False for interval, spmd_descs in zip(intervals, solver_ret): start, end = interval if spmd_descs: diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index fc7badcd..d6a3f61a 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -174,7 +174,7 @@ def __init__( _logger.info('no partition constraint is loaded') self.cost_database = graph.cost_database - self.cost_database.profile_comp(self.device_num) + self.cost_database.profile_comp(self.device_num, autodist_config.parallel_profile, autodist_config.re_profile) self.stage_num = stage_num self.initialize() @@ -656,11 +656,12 @@ def calc_partition_cost(self, op_idx: int, partition_idx: int): else: weight_comm_time = 0 - if not self.autodist_config.consider_mem: + cfg = self.autodist_config + if not cfg.consider_mem: node_mem, node_buffer, act_mem, opt_transient_mem, in_mem = 0, 0, 0, 0, 0 else: node_mem, node_buffer, act_mem, opt_transient_mem, in_mem = self.cost_database.get_mem_and_buffer( - tgt_p, self.is_train, self.stage_num) + tgt_p, self.is_train, self.stage_num, cfg.world_size, cfg.ngpus, cfg.zero_stage, cfg.zero_ngroups, cfg.opt_resident_coef, cfg.opt_transient_coef) # communication cost induced by partitioning activation tensors of the given op partition comm_vecs = [] @@ -751,7 +752,7 @@ def calc_partition_info(self): ratio, i = importance_ratios[idx] node = self.get_operator(i).ir_cell desc_str += f'operator {node} has {self.get_op_partition_count(i)} partitions, importance ratio {ratio:.3f}\nat {node.comment}\n\n' - _logger.info(desc_str) + _logger.debug(desc_str) _logger.info('finish spmd solver initializetion') def estimate_min_mem(self, start: int, end: int) -> int: From d04f20a5d0453b93fea2fe65c318a46a87761b16 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 21 Nov 2024 07:43:09 +0000 Subject: [PATCH 1766/1892] Merged PR 2319: [Tracer] add scoped constant folding option Add scoped constant folding option Please note constant folding only controls the output of ops. If you need to fold input, you should add extra op `fold_constant` parity check pass unit test pass --- nnscaler/__init__.py | 7 + nnscaler/graph/function/function.py | 9 + nnscaler/graph/parser/converter.py | 6 +- nnscaler/graph/parser/mapping.py | 1 + nnscaler/graph/parser/parser.py | 13 ++ nnscaler/graph/tracer/concrete_proxy.py | 12 +- nnscaler/graph/tracer/concrete_tracer.py | 3 +- nnscaler/graph/tracer/metadata.py | 22 ++ nnscaler/ir/cten.py | 10 + nnscaler/runtime/function/function.py | 44 ++++ nnscaler/utils.py | 32 ++- tests/graph/tracer/test_namedtuple.py | 3 +- tests/graph/tracer/test_op_context.py | 271 +++++++++++++++++++++++ tests/parallel_module/test_gencode.py | 129 +++++++++++ tests/test_utils.py | 15 +- 15 files changed, 566 insertions(+), 11 deletions(-) create mode 100644 tests/graph/tracer/test_op_context.py diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index 87ae960f..b6af82e7 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -15,8 +15,15 @@ load_deduped_state_dict, broadcast_weights, load_sharded_state_dict, + sync_grad_when, ) from nnscaler.graph.parser.register import register_op +from nnscaler.runtime.function.function import ( + anchor, + constant_folding, + no_constant_folding, + fold_constant, +) def init(): diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 43b6ca99..94cc616e 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -94,6 +94,15 @@ def Ifexpr(cond: Any, true_value: Any, false_value: Any, signature = None) -> IR ) +def FoldConstant(value: Any, signature = None): + if any_ir_object_satisfy(value, lambda x: isinstance(x, IRTensor)): + raise ValueError("FoldConstant doesn't support IRTensor") + + # always return a constant + # no node will be created + return IRObject.try_unwrap(value) + + def MultiRef(tensor: IRTensor, times: int, signature = None): """ nnscaler.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 9840fba8..85ad099b 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -91,7 +91,11 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: }) # get cube runtime functions - cube_rt_funcs = [cube_rt_function.anchor, cube_rt_function.ifexpr] + cube_rt_funcs = [ + cube_rt_function.anchor, + cube_rt_function.ifexpr, + cube_rt_function.fold_constant + ] leaf_functions.update({ func: LeafWrapInfo([Location(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs diff --git a/nnscaler/graph/parser/mapping.py b/nnscaler/graph/parser/mapping.py index c202231b..2b460d84 100644 --- a/nnscaler/graph/parser/mapping.py +++ b/nnscaler/graph/parser/mapping.py @@ -256,6 +256,7 @@ def exist(signature: str) -> bool: # runtime functions __rtemplate('anchor'): function.GraphAnchor, __rtemplate('ifexpr'): function.Ifexpr, + __rtemplate('fold_constant'): function.FoldConstant, __rtemplate('identity'): function.Identity, __rtemplate('multiref'): function.MultiRef, __rtemplate('accum'): function.Accum, diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 751aacf1..93cbe202 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -7,6 +7,8 @@ from typing import Any, List, Tuple, Callable, Union, Dict, Type, Optional import nnscaler +from nnscaler.utils import fields +from nnscaler.graph.tracer.metadata import OpContext from nnscaler.ir.operator import IRFwOperation from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.cten import IRObject, IRCell, IRTensor @@ -193,6 +195,16 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule input_vals = FxModuleParser.parse_complex(list(node.args), frame) kwargs = FxModuleParser.parse_complex(node.kwargs, frame) + # use context constant_folding if set + # Please note constant_folding only controls the output of the op + # For op inputs, we will not fold/unfold them + # when we enter the code block with different constant folding setting + # as a workaround, + # you can use `nnscaler.runtime.function.fold_constant` to fold inputs if needed + op_context: Optional[Dict[str, Any]] = node.meta.get('op_context') + if op_context is not None and op_context.get(fields(OpContext).constant_folding) is not None: + constant_folding = op_context[fields(OpContext).constant_folding] + if SignFx2Op.exist(fsig): ir_node = SignFx2Op.map(fsig)(*input_vals, **kwargs) else: @@ -330,6 +342,7 @@ def _set_node_meta(node: torch.fx.Node, ir_node: Union[IRCell, Any]): if not isinstance(ir_node, IRCell): return + ir_node.op_context = node.meta.get('op_context') module_stack = node.meta.get('nn_module_stack') ir_node.module_stack = module_stack comment = str(node.meta.get('frame_record', '')) diff --git a/nnscaler/graph/tracer/concrete_proxy.py b/nnscaler/graph/tracer/concrete_proxy.py index 1c1d33d5..560a14a2 100644 --- a/nnscaler/graph/tracer/concrete_proxy.py +++ b/nnscaler/graph/tracer/concrete_proxy.py @@ -252,7 +252,12 @@ def __init__(self, root: ConcreteProxy, attr: str): self.root = root self.attr = attr self.tracer = root.tracer - self._node: Optional[Node] = None + # In previous version, the node creation is done lazily. + # But when we support scoped context, + # Lazy creation of node will cause the node to be created in the wrong context. + # Please note unused nodes can still be removed by DCE later. + self._node: Node = self.tracer.create_proxy( + 'call_function', orig_func.getattr, (self.root, self.attr), {}).node if orig_func.isinstance(root.value, torch.Tensor) and attr == 'is_cuda': self.value = True elif orig_func.isinstance(root.value, torch.Tensor) and attr == 'device': @@ -272,11 +277,6 @@ def __repr__(self) -> str: @property def node(self): - # the node for attributes is added lazily, since most will just be method calls - # which do not rely on the getitem call - if self._node is None: - self._node = self.tracer.create_proxy( - 'call_function', orig_func.getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index f1a44179..eef7209f 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -37,7 +37,7 @@ from . import pytree_utils, orig_func, wrap_utils from .frame_utils import get_frame_record from .function_patcher import FunctionPatcher -from .metadata import EmptyResult, extract_results_metadata +from .metadata import EmptyResult, extract_results_metadata, get_op_context from .operator_patcher import OperatorPatcherContext from .torch_fx_patcher import TorchFXPatcher, ExtraSEFPatcher, side_effectful_inplace_ops from .trace_strategy import TRACE_STRATEGY @@ -148,6 +148,7 @@ def create_node(self, kind : str, target : Target, check_for_mutable_operation(target, args, kwargs) node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) + node.meta['op_context'] = get_op_context() # TODO node_name_to_scope will be depricated in favor of # node.meta['nn_module_stack'] self.node_name_to_scope[node.name] = ( diff --git a/nnscaler/graph/tracer/metadata.py b/nnscaler/graph/tracer/metadata.py index dfb9f187..62e2ce9c 100644 --- a/nnscaler/graph/tracer/metadata.py +++ b/nnscaler/graph/tracer/metadata.py @@ -2,6 +2,8 @@ # Licensed under the MIT License. from typing import Any, Dict, NamedTuple, Optional, Tuple +from dataclasses import dataclass, asdict +import copy import torch from torch.fx.node import Node @@ -20,6 +22,26 @@ class EmptyResult: pass +@dataclass +class OpContext: + """ + OpContext is a dataclass that holds the context of an operation. + + Args: + constant_folding: Whether constant folding is enabled. + Please note we will not unfold/fold inputs + when we enter the code block with different constant folding setting. + """ + constant_folding: Optional[bool] = None + + +_GLOBAL_OP_CONTEXT = OpContext() + + +def get_op_context() -> OpContext: + return asdict(_GLOBAL_OP_CONTEXT) + + class TensorMetadata(NamedTuple): # TensorMetadata is a structure containing pertinent information # about a tensor within a PyTorch program. diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 03550aa5..68e1e571 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -77,6 +77,8 @@ def __init__(self, self._comment: Optional[str] = None # the module stack that preserves the hierarchy information self._module_stack: Optional[OrderedDict[str, Any]] = None + # the operation context information + self._op_context: Optional[Dict[str, Any]] = None @property def cid(self) -> int: @@ -384,6 +386,14 @@ def module_stack(self, stack: OrderedDict[str, Any]): """ self._module_stack = stack + @property + def op_context(self) -> Optional[Dict[str, Any]]: + return self._op_context + + @op_context.setter + def op_context(self, context: Optional[Dict[str, Any]]): + self._op_context = context + def __repr__(self) -> str: """ Cell string presentation diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index ee1f7166..22f250a8 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from contextlib import contextmanager from typing import Optional, List, Tuple, Union, Any import torch import torch.nn.functional as TorchF @@ -31,6 +32,49 @@ def anchor(name: str): return None +@contextmanager +def constant_folding(constant_folding: bool = True): + """ + Context manager to enable/disable constant folding. + You can put it inside your forward function to control the constant folding behavior. + Please note as we don't set it as leaf function in tracer, + it will not be present in the traced graph. + """ + from nnscaler.graph.tracer.metadata import _GLOBAL_OP_CONTEXT + + old_constant_folding = _GLOBAL_OP_CONTEXT.constant_folding + _GLOBAL_OP_CONTEXT.constant_folding = constant_folding + try: + yield + finally: + _GLOBAL_OP_CONTEXT.constant_folding = old_constant_folding + + +def no_constant_folding(): + """ + Context manager to disable constant folding. + """ + return constant_folding(constant_folding=False) + + +def fold_constant(a: Any) -> Any: + """ + Fold a constant(non-tensor) if constant folding is enabled. + + Please note this should be only used in `constant_folding` block + to make sure the input to a `constant_folding` block is not wrapped in an IRObject in the graph. + + Example: + ``` + a = some_func() # the value is wrapped in IRObject in graph + with constant_folding(): + a = fold_constant(a) # unwrap value + torch.add(t, a) # in graph a is a constant + ``` + """ + return a + + def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: """ identity forward. Create multiple same tensor. diff --git a/nnscaler/utils.py b/nnscaler/utils.py index b5209250..00408e24 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -3,7 +3,10 @@ from contextlib import contextmanager from functools import wraps -from typing import Generator, Optional, Tuple, Callable, List, Set, Any, Iterable, Type, Union +from typing import ( + Generator, Optional, Tuple, Callable, Dict, List, Set, Any, + Iterable, Type, Union, Protocol, ClassVar, cast, TypeVar +) import logging from pathlib import Path import sys @@ -292,6 +295,33 @@ def __get__(self, obj, objtype=None): # so here __set__ and __delete__ are not implemented, and the property is read-only +# ref: https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass +class IsDataclass(Protocol): + # as already noted in comments, checking for this attribute is currently + # the most reliable way to ascertain that something is a dataclass + __dataclass_fields__: ClassVar[Dict[str, Any]] + + +# ref: https://github.com/pydantic/pydantic/discussions/8600 +@dataclass(frozen=True) +class _GetFields: + _dataclass_type: Type[IsDataclass] + + def __getattr__(self, item: str) -> Any: + if item in self._dataclass_type.__dataclass_fields__: + return item + raise AttributeError(f'"{item}" is not a valid field in type: {self._dataclass_type}') + + +TDataClass = TypeVar("TDataClass", bound=Type[IsDataclass]) +def fields(model: TDataClass, /) -> TDataClass: + """ + This function is used to get the field names(in str) of a dataclass. + This is a workaround for the lack of `__name__` of dataclass field. + """ + return cast(TDataClass, _GetFields(model)) + + class accum_mode: """Make cube execution in gradient accumulation mode. diff --git a/tests/graph/tracer/test_namedtuple.py b/tests/graph/tracer/test_namedtuple.py index 0731c960..20ff7ccf 100644 --- a/tests/graph/tracer/test_namedtuple.py +++ b/tests/graph/tracer/test_namedtuple.py @@ -14,11 +14,12 @@ def __init__(self): super().__init__() self.fc1 = torch.nn.Linear(10, 5) self.fc2 = torch.nn.Linear(10, 5) - + def forward(self, x): Result = namedtuple('Result', ['r1', 'r2']) return Result(self.fc1(x), self.fc2(x)) + @replace_all_device_with('cpu') def test_namedtuple(): model = SimpleModel() diff --git a/tests/graph/tracer/test_op_context.py b/tests/graph/tracer/test_op_context.py new file mode 100644 index 00000000..f85aab7f --- /dev/null +++ b/tests/graph/tracer/test_op_context.py @@ -0,0 +1,271 @@ +import operator + +import torch +import nnscaler +from nnscaler.graph.parser.converter import to_fx_graph +from tests.utils import replace_all_device_with + + +class ContextConstantFoldingModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ma = torch.nn.Parameter(torch.rand(3, 3)) + self.mb = torch.nn.Parameter(torch.rand(3, 3)) + self.mc = torch.nn.Parameter(torch.rand(3, 3)) + self.md = torch.nn.Parameter(torch.rand(3, 3)) + self.me = torch.nn.Parameter(torch.rand(3, 3)) + self.mf = torch.nn.Parameter(torch.rand(3, 3)) + self.mg = torch.nn.Parameter(torch.rand(3, 3)) + + def forward(self, a: torch.Tensor): + b = self.ma * a + with nnscaler.constant_folding(): + + c = self.mb * b + with nnscaler.no_constant_folding(): + d = self.mc * c + with nnscaler.constant_folding(): + e = self.md * d + + f = self.me * e + + g = self.mf * f + + h = self.mg * g + return h + + +@replace_all_device_with('cpu') +def test_context_folding(): + model = ContextConstantFoldingModule() + dummy_input = {'a': torch.rand(3, 3)} + traced_graph = to_fx_graph(model, dummy_input) + nodes = list(traced_graph.graph.nodes) + + assert nodes[0].name == 'a' + assert nodes[0].op == 'placeholder' + assert nodes[0].meta['op_context']['constant_folding'] is None + + assert nodes[-1].name == 'output' + assert nodes[-1].op == 'output' + assert nodes[-1].meta['op_context']['constant_folding'] is None + + # b = self.ma * a + # h = self.mg * g + assert nodes[1].name == 'ma' + assert nodes[1].op == 'get_attr' + assert nodes[1].meta['op_context']['constant_folding'] is None + assert nodes[2].name == 'mul' + assert nodes[2].op == 'call_function' + assert nodes[2].target == operator.mul + assert nodes[2].meta['op_context']['constant_folding'] is None + + assert nodes[13].name == 'mg' + assert nodes[13].op == 'get_attr' + assert nodes[13].meta['op_context']['constant_folding'] is None + assert nodes[14].name.startswith('mul_') + assert nodes[14].op == 'call_function' + assert nodes[14].target == operator.mul + assert nodes[14].meta['op_context']['constant_folding'] is None + + # with nnscaler.constant_folding(): + # c = self.mb * b + # g = self.mf * f + assert nodes[3].name == 'mb' + assert nodes[3].op == 'get_attr' + assert nodes[3].meta['op_context']['constant_folding'] is True + assert nodes[4].name.startswith('mul') + assert nodes[4].op == 'call_function' + assert nodes[4].target == operator.mul + assert nodes[4].meta['op_context']['constant_folding'] is True + + assert nodes[11].name == 'mf' + assert nodes[11].op == 'get_attr' + assert nodes[11].meta['op_context']['constant_folding'] is True + assert nodes[12].name.startswith('mul') + assert nodes[12].op == 'call_function' + assert nodes[12].target == operator.mul + assert nodes[12].meta['op_context']['constant_folding'] is True + + # d = self.mc * c + # f = self.me * e + assert nodes[5].name == 'mc' + assert nodes[5].op == 'get_attr' + assert nodes[5].meta['op_context']['constant_folding'] is False + assert nodes[6].name.startswith('mul') + assert nodes[6].op == 'call_function' + assert nodes[6].target == operator.mul + assert nodes[6].meta['op_context']['constant_folding'] is False + + assert nodes[9].name == 'me' + assert nodes[9].op == 'get_attr' + assert nodes[9].meta['op_context']['constant_folding'] is False + assert nodes[10].name.startswith('mul') + assert nodes[10].op == 'call_function' + assert nodes[10].target == operator.mul + assert nodes[10].meta['op_context']['constant_folding'] is False + + # e = self.md * d + assert nodes[7].name == 'md' + assert nodes[7].op == 'get_attr' + assert nodes[7].meta['op_context']['constant_folding'] is True + assert nodes[8].name.startswith('mul') + assert nodes[8].op == 'call_function' + assert nodes[8].target == operator.mul + assert nodes[8].meta['op_context']['constant_folding'] is True + + +class CFModule1(torch.nn.Module): + def __init__(self): + super().__init__() + self.mb = torch.nn.Parameter(torch.rand(3, 3)) + + @nnscaler.constant_folding() + def forward(self, a: torch.Tensor): + return self.mb * a + + +class CFModule2(torch.nn.Module): + def __init__(self): + super().__init__() + self.mc = torch.nn.Parameter(torch.rand(3, 3)) + + @nnscaler.constant_folding(False) + def forward(self, a: torch.Tensor): + return self.mc * a + + +class CFModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ma = torch.nn.Parameter(torch.rand(3, 3)) + self.mb = CFModule1() + self.mc = CFModule2() + self.md = torch.nn.Parameter(torch.rand(3, 3)) + self.me = torch.nn.Parameter(torch.rand(3, 3)) + + def forward(self, a: torch.Tensor): + x = self.ma * a + x = self.mb(x) + with nnscaler.constant_folding(): + x = self.mc(x) + x = self.md * x + x = self.me * x + return x + + +@replace_all_device_with('cpu') +def test_context_folding_module(): + model = CFModule() + dummy_input = {'a': torch.rand(3, 3)} + traced_graph = to_fx_graph(model, dummy_input) + nodes = list(traced_graph.graph.nodes) + + assert nodes[0].name == 'a' + assert nodes[0].op == 'placeholder' + assert nodes[0].meta['op_context']['constant_folding'] is None + + assert nodes[-1].name == 'output' + assert nodes[-1].op == 'output' + assert nodes[-1].meta['op_context']['constant_folding'] is None + + # x = self.ma * a + # x = self.me * x + assert nodes[1].name == 'ma' + assert nodes[1].op == 'get_attr' + assert nodes[1].meta['op_context']['constant_folding'] is None + assert nodes[2].name == 'mul' + assert nodes[2].op == 'call_function' + assert nodes[2].target == operator.mul + assert nodes[2].meta['op_context']['constant_folding'] is None + + assert nodes[9].name == 'me' + assert nodes[9].op == 'get_attr' + assert nodes[9].meta['op_context']['constant_folding'] is None + assert nodes[10].name.startswith('mul_') + assert nodes[10].op == 'call_function' + assert nodes[10].target == operator.mul + assert nodes[10].meta['op_context']['constant_folding'] is None + + # x = self.mb(x) + assert nodes[3].name == 'mb_mb' + assert nodes[3].op == 'get_attr' + assert nodes[3].meta['op_context']['constant_folding'] is True + assert nodes[4].name.startswith('mul') + assert nodes[4].op == 'call_function' + assert nodes[4].target == operator.mul + assert nodes[4].meta['op_context']['constant_folding'] is True + + # x = self.mc(x) + assert nodes[5].name == 'mc_mc' + assert nodes[5].op == 'get_attr' + assert nodes[5].meta['op_context']['constant_folding'] is False + assert nodes[6].name.startswith('mul') + assert nodes[6].op == 'call_function' + assert nodes[6].target == operator.mul + assert nodes[6].meta['op_context']['constant_folding'] is False + + # x = self.md * x + assert nodes[7].name == 'md' + assert nodes[7].op == 'get_attr' + assert nodes[7].meta['op_context']['constant_folding'] is True + assert nodes[8].name.startswith('mul') + assert nodes[8].op == 'call_function' + assert nodes[8].target == operator.mul + assert nodes[8].meta['op_context']['constant_folding'] is True + + +class OpReorderModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ma = torch.nn.Parameter(torch.rand(3, 3)) + + def forward(self, a: torch.Tensor): + shape = a.shape + with nnscaler.constant_folding(): + b = self.ma.abs() + shape[0] + return b + + +@replace_all_device_with('cpu') +def test_context_folding_reordered_and_dce(): + """ + Test getattr will not be reordered (old lazy-style implementation will reorder it) + """ + model = OpReorderModule() + dummy_input = {'a': torch.rand(3, 3)} + traced_graph = to_fx_graph(model, dummy_input) + nodes = list(traced_graph.graph.nodes) + assert nodes[0].name == 'a' + assert nodes[0].op == 'placeholder' + assert nodes[0].meta['op_context']['constant_folding'] is None + assert nodes[-1].name == 'output' + assert nodes[-1].op == 'output' + assert nodes[-1].meta['op_context']['constant_folding'] is None + + # a.shape + assert nodes[1].name.startswith('getattr') + assert nodes[1].op == 'call_function' + assert nodes[1].target == getattr + nodes[1].args[1] == 'shape' + + #b = self.ma.abs() + shape[0] + assert nodes[2].name == 'ma' + assert nodes[2].op == 'get_attr' + assert nodes[2].meta['op_context']['constant_folding'] is True + + # the ma.abs getattr will be eliminated by dce. + assert nodes[3].name.startswith('abs') + assert nodes[3].op == 'call_method' + assert nodes[3].target == 'abs' + assert nodes[3].meta['op_context']['constant_folding'] is True + + assert nodes[4].name == 'getitem' + assert nodes[4].op == 'call_function' + assert nodes[4].target == operator.getitem + assert nodes[4].meta['op_context']['constant_folding'] is True + + assert nodes[5].name == 'add' + assert nodes[5].op == 'call_function' + assert nodes[5].target == operator.add + assert nodes[5].meta['op_context']['constant_folding'] is True diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 320baf31..96bbfaa7 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -3,6 +3,7 @@ import inspect import tempfile +import re from contextlib import nullcontext import torch @@ -1547,3 +1548,131 @@ def test_codegen_function_to(tmp_path): assert _gencode_contains(tmp_path, FunctionToModule, 0, r'to_\d+ = torch\.Tensor\.to\(x_\d+\)') # to_1_22 = torch.Tensor.to(linear_26, copy=True, dtype=torch.float32) assert _gencode_contains(tmp_path, FunctionToModule, 0, r'torch\.Tensor\.to([^, ]*, copy=True, dtype=torch.float32)') + + +class CCFModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(3, 3) + self.linear2 = torch.nn.Linear(3, 3) + self.linear3 = torch.nn.Linear(3, 3) + self.linear4 = torch.nn.Linear(3, 3) + self.linear5 = torch.nn.Linear(3, 3) + self.linear6 = torch.nn.Linear(3, 3) + self.linear7 = torch.nn.Linear(3, 3) + self.linear8 = torch.nn.Linear(3, 3) + + def forward(self, a: torch.Tensor): + ashape = a.shape[0] # not folded + b = self.linear1(a) + ashape + bshape = b.shape[0] # not folded + with nnscaler.constant_folding(): + d = self.linear3(b) + ashape + dshape = d.shape[0] # folded + e = self.linear4(d) + dshape + bshape + ashape + with nnscaler.no_constant_folding(): + f = self.linear5(e) + dshape + bshape + ashape + fshape = f.shape[0] # not folded + g = self.linear6(f) + fshape + dshape + bshape + ashape + gshape = g.shape[0] # folded + h = self.linear7(g) + gshape + fshape + dshape + bshape + ashape + hshape = h.shape[0] + i = self.linear8(h) + hshape + gshape + fshape + dshape + bshape + ashape + return i + + +@replace_all_device_with('cpu') +def test_constant_folding_context(tmp_path): + parallelize( + CCFModule(), + {'a': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + 'dp', + ComputeConfig(1, 1, constant_folding=False), + gen_savedir=tmp_path, + load_module=False + ) + # Just check all torch.add code + add_codes = _gencode_contains(tmp_path, CCFModule, 0, r'.*torch\.add.*') + assert len(add_codes) == 23 + + not_folded_names = ['ashape', 'bshape', 'fshape', 'hshape'] + folded_names = ['dshape', 'gshape'] + + def check_op(*names): + for name in names: + code = add_codes.pop(0) + if name in not_folded_names: + assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, getitem_.*, alpha=1\)', code) + else: + assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, 2, alpha=1\)', code) + + # b = self.linear1(a) + ashape + check_op('ashape') + # d = self.linear3(b) + ashape + check_op('ashape') + # e = self.linear4(d) + dshape + bshape + ashape + check_op('dshape', 'bshape', 'ashape') + # f = self.linear5(e) + dshape + bshape + ashape + check_op('dshape', 'bshape', 'ashape') + # g = self.linear6(f) + fshape + dshape + bshape + ashape + check_op('fshape', 'dshape', 'bshape', 'ashape') + # h = self.linear7(g) + gshape + fshape + dshape + bshape + ashape + check_op('gshape', 'fshape', 'dshape', 'bshape', 'ashape') + # i = self.linear8(h) + hshape + gshape + fshape + dshape + bshape + ashape + check_op('hshape', 'gshape', 'fshape', 'dshape', 'bshape', 'ashape') + + assert not add_codes + + +class CCFModule2(torch.nn.Module): + def __init__(self, fold_input=False): + super().__init__() + self.linear1 = torch.nn.Linear(3, 3) + self.fold_input = fold_input + + def forward(self, a: torch.Tensor): + from nnscaler.runtime.function import fold_constant + ashape = a.shape[0] # not folded + ashape2 = a.shape[1] # not folded + ashape3 = ashape + ashape2 # not folded + with nnscaler.constant_folding(): + if self.fold_input: + ashape = fold_constant(ashape) + b = self.linear1(a) + ashape + if self.fold_input: + # check if the constant folding is correctly applied to tuple + # here we have 3 constants to fold + # In graph, it will be two nodes `fold_constant` and `getitem` + ashape, ashape2, ashape3 = fold_constant((ashape, ashape2, ashape3)) + b = b * ashape * ashape2 * ashape3 + return b + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('fold_input', [False, True]) +def test_fold_constant(tmp_path, fold_input): + parallelize( + CCFModule2(fold_input), + {'a': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + 'dp', + ComputeConfig(1, 1, constant_folding=False), + gen_savedir=tmp_path, + reuse='override', + load_module=False + ) + if fold_input: + # add_28 = torch.add(linear_31, 2, alpha=1) + assert _gencode_contains(tmp_path, CCFModule2, 0, + r'add_.* = torch\.add\(linear_.*, 2, alpha=1\)') + # b = b * ashape3 + # mul_2_59 = torch.mul(mul_1_65, 5) + assert _gencode_contains(tmp_path, CCFModule2, 0, + r'mul_.* = torch\.mul\(mul_.*, 5\)') + else: + # add_27 = torch.add(linear_30, getitem_20, alpha=1) + assert _gencode_contains(tmp_path, CCFModule2, 0, + r'add_.* = torch\.add\(linear_.*, getitem_.*, alpha=1\)') + # b = b * ashape3 + # mul_2_51 = torch.mul(mul_1_57, add_38) + assert _gencode_contains(tmp_path, CCFModule2, 0, + r'mul_.* = torch\.mul\(mul_.*, add_.*\)') diff --git a/tests/test_utils.py b/tests/test_utils.py index ffb2b2dd..7fa7d80a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from dataclasses import dataclass import pytest -from nnscaler.utils import select_many, classproperty +from nnscaler.utils import select_many, classproperty, fields def test_select_many(): @@ -40,3 +41,15 @@ def cfg(cls): x[1] = 2 assert A.cfg == {1: 2} assert id(A().cfg) == id(x) + + +def test_fields(): + @dataclass + class A: + x: int + y: int + + assert fields(A).x == 'x' + assert fields(A).y == 'y' + with pytest.raises(AttributeError): + fields(A).z From fb86ffc096d3680349af7475b0f5d57163982fd1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 21 Nov 2024 08:53:56 +0000 Subject: [PATCH 1767/1892] Merged PR 2321: [BugFix] Fix cli random seed unit test [BugFix] Fix cli random seed unit test by using launch_torchrun to run the train code. The previous version mocked env variable to fake a torchrun environment in current process doesn't work well. --- tests/cli/test_resume_seed.py | 73 +++++++++++------------------------ 1 file changed, 23 insertions(+), 50 deletions(-) diff --git a/tests/cli/test_resume_seed.py b/tests/cli/test_resume_seed.py index 53f31686..44356b45 100644 --- a/tests/cli/test_resume_seed.py +++ b/tests/cli/test_resume_seed.py @@ -3,62 +3,50 @@ import torch from nnscaler.cli.trainer import Trainer from nnscaler.cli.trainer_args import * +from tests.launch_torchrun import launch_torchrun -@pytest.mark.skipif(True, reason='no gpu') -def test_resume_seed(): - _set_envs({ - # required by deterministic - 'CUBLAS_WORKSPACE_CONFIG': ':4096:8', +@pytest.mark.skipif(not torch.cuda.is_available(), reason='no gpu') +def test_resume_seed(tmp_path): + launch_torchrun(1, resume_seed_worker, tmp_path) - # fake torchrun environment, check https://pytorch.org/docs/stable/elastic/run.html#environment-variables - 'LOCAL_RANK': 0, - 'RANK': 0, - 'GROUP_RANK': 0, - 'LOCAL_WORLD_SIZE': 1, - 'WORLD_SIZE': 1, - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': 29470, - 'TORCHELASTIC_RUN_ID': 'UT', - }) +def resume_seed_worker(tmp_path): + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' torch.use_deterministic_algorithms(True) # compile separately because run multiple trainers in one process will confuse `gen_reuse` - _compile() + _compile(tmp_path) + _test_resume_seed(tmp_path, steps_per_epoch=100, max_steps=20, resume_at=10) + _test_resume_seed(tmp_path, steps_per_epoch=5, max_steps=20, resume_at=10) - _test_resume_seed(steps_per_epoch=100, max_steps=20, resume_at=10) - _test_resume_seed(steps_per_epoch=5, max_steps=20, resume_at=10) - - _restore_envs() - - -def _test_resume_seed(steps_per_epoch, max_steps, resume_at): +def _test_resume_seed(tmp_path, steps_per_epoch, max_steps, resume_at): # no resume - model_1 = _train(steps_per_epoch, max_train_steps=max_steps, resume_from=None) + model_1 = _train(tmp_path, steps_per_epoch, max_train_steps=max_steps, resume_from=None) weight_1 = next(model_1.parameters()).data # resume - _train(steps_per_epoch, max_train_steps=resume_at, resume_from=None) - model_2 = _train(steps_per_epoch, max_train_steps=max_steps, resume_from='last') + _train(tmp_path, steps_per_epoch, max_train_steps=resume_at, resume_from=None) + model_2 = _train(tmp_path, steps_per_epoch, max_train_steps=max_steps, resume_from='last') weight_2 = next(model_2.parameters()).data assert torch.equal(weight_1, weight_2) ## resume without resuming seeds - _train(steps_per_epoch, max_train_steps=resume_at, resume_from=None) - _remove_rng_states() - model_3 = _train(steps_per_epoch, max_train_steps=max_steps, resume_from='last') + _train(tmp_path, steps_per_epoch, max_train_steps=resume_at, resume_from=None) + _remove_rng_states(tmp_path) + model_3 = _train(tmp_path, steps_per_epoch, max_train_steps=max_steps, resume_from='last') weight_3 = next(model_3.parameters()).data assert not torch.equal(weight_1, weight_3) -def _compile(): +def _compile(tmp_path): trainer_args = TrainerArgs( compute_config=ComputeConfig(plan_ngpus=1, runtime_ngpus=1, use_end2end=True), gen_reuse='override', + gen_savedir=tmp_path/'src', run_mode='compile', model=ModelConfig(type=Model), optimizer=OptimizerConfig(type=torch.optim.AdamW), @@ -71,13 +59,14 @@ def _compile(): trainer.run() -def _train(steps_per_epoch, max_train_steps, resume_from): +def _train(tmp_path, steps_per_epoch, max_train_steps, resume_from): trainer_args = TrainerArgs( + gen_savedir=tmp_path/'src', compute_config=ComputeConfig(plan_ngpus=1, runtime_ngpus=1, use_end2end=True), model=ModelConfig(type=Model), optimizer=OptimizerConfig(type=torch.optim.AdamW), dataset=DatasetConfig(type=RandomDataset, train_args={'length': steps_per_epoch}), - checkpoint=CheckpointConfig(resume_from=resume_from), + checkpoint=CheckpointConfig(resume_from=resume_from, save_dir=tmp_path/'checkpoints'), max_train_steps=max_train_steps, enable_progress_bar=False, seed=0, @@ -87,29 +76,13 @@ def _train(steps_per_epoch, max_train_steps, resume_from): return trainer.model -def _remove_rng_states(): - ckpt_path = 'checkpoints/last/0.ckpt' +def _remove_rng_states(tmp_path): + ckpt_path = tmp_path / 'checkpoints/last/0.ckpt' ckpt = torch.load(ckpt_path, weights_only=False) ckpt['rng_states'] = None torch.save(ckpt, ckpt_path) -_backup_envs = {} - -def _set_envs(envs): - _backup_envs.clear() - for key, value in envs.items(): - _backup_envs[key] = os.environ.get(key, None) - os.environ[key] = str(value) - -def _restore_envs(): - for key, value in _backup_envs.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value - - class Model(torch.nn.Module): def __init__(self): super().__init__() From 1bdedcc449e4e8257acb8682761aacceb77f871f Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Fri, 22 Nov 2024 04:30:31 +0000 Subject: [PATCH 1768/1892] Merged PR 2285: Add max train steps arg to llama 128k example The default (2 epochs) is too long for a test. Add a command line arg to reduce test time. --- examples/llama3_8B_128K/train.py | 11 +++++++++++ examples/llama3_demo/train.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/examples/llama3_8B_128K/train.py b/examples/llama3_8B_128K/train.py index 10d5b776..390ef0e1 100644 --- a/examples/llama3_8B_128K/train.py +++ b/examples/llama3_8B_128K/train.py @@ -227,6 +227,7 @@ def collate_fn(samples): checkpoint=checkpoint_config, precision='bf16', max_epochs=2, + max_train_steps=args.max_train_steps, grad_accumulation_steps=4, log=[log_config], seed=0, @@ -296,6 +297,16 @@ def collate_fn(samples): type=str, help='trace strategy control the function execution during tracing model graph, `cuda_run_cpu_offload` and `reuse_cache` are recommended, please read `docs/source/parallel_module.md` for more information', ) + parser.add_argument( + '--max_train_steps', + default=None, + type=int, + help='max training steps', + ) args = parser.parse_args() + if os.getenv('DETERMINISTIC'): # reduce randomness for integration test + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.use_deterministic_algorithms(True) + main(args) diff --git a/examples/llama3_demo/train.py b/examples/llama3_demo/train.py index 7d1f2edc..b7a47084 100644 --- a/examples/llama3_demo/train.py +++ b/examples/llama3_demo/train.py @@ -271,4 +271,8 @@ def collate(samples): if __name__ == '__main__': + if os.getenv('DETERMINISTIC'): # reduce randomness for integration test + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.use_deterministic_algorithms(True) + main() From 1be48b2ba1ae70c2e66e3dca52b30a0d0c97060b Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Fri, 22 Nov 2024 08:44:52 +0000 Subject: [PATCH 1769/1892] Merged PR 2322: trust remote code for load dataset trust remote code should be set to True for loading dataset --- examples/llama3_8B_128K/bookcorpus.py | 2 +- examples/llama3_demo/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/llama3_8B_128K/bookcorpus.py b/examples/llama3_8B_128K/bookcorpus.py index 1a77273a..31ba97e8 100644 --- a/examples/llama3_8B_128K/bookcorpus.py +++ b/examples/llama3_8B_128K/bookcorpus.py @@ -52,7 +52,7 @@ def create_dataset(tokenizer: PreTrainedTokenizer, raw_dataset: Dataset, text_ke save_path = args.save_path sequence_length = args.sequence_length - raw_dataset = load_dataset(data_path_or_name)["train"] + raw_dataset = load_dataset(data_path_or_name, trust_remote_code=True)["train"] tokenizer = get_tokenizer(tokenizer_path_or_name) dataset = create_dataset(tokenizer, raw_dataset, "text", sequence_length) dataset.save_to_disk(save_path) diff --git a/examples/llama3_demo/train.py b/examples/llama3_demo/train.py index b7a47084..bfed4389 100644 --- a/examples/llama3_demo/train.py +++ b/examples/llama3_demo/train.py @@ -42,7 +42,7 @@ def prepare_data(max_seq_len, dataset_path=None): if dataset_path is None: dataset_path = f'./bookcorpus-{max_seq_len}' - dataset = datasets.load_dataset(dataset_id)['train'] + dataset = datasets.load_dataset(dataset_id, trust_remote_code=True)['train'] tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) def _tokenize(sample): From 7abfca31d1c809abf80206dc4ae0829a33c7d60c Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 22 Nov 2024 09:00:13 +0000 Subject: [PATCH 1770/1892] Merged PR 2315: [Model Example] Support Llama3 70B 8k by pipeline parallelism on mi300 Related work items: #2050, #2071 --- examples/{llama3_8B_128K => llama}/.gitignore | 0 examples/llama/README.md | 189 ++++++++++++++++++ .../{llama3_8B_128K => llama}/bookcorpus.py | 2 +- .../chunk_linear_cross_entropy.py | 0 .../{llama3_8B_128K => llama}/ckpt_merger.py | 0 .../create_mini_model.py | 0 .../modeling_modifier.py | 0 .../requirements.txt | 0 examples/{llama3_8B_128K => llama}/train.py | 79 ++++++-- examples/llama3_8B_128K/README.md | 138 ------------- 10 files changed, 251 insertions(+), 157 deletions(-) rename examples/{llama3_8B_128K => llama}/.gitignore (100%) create mode 100644 examples/llama/README.md rename examples/{llama3_8B_128K => llama}/bookcorpus.py (97%) rename examples/{llama3_8B_128K => llama}/chunk_linear_cross_entropy.py (100%) rename examples/{llama3_8B_128K => llama}/ckpt_merger.py (100%) rename examples/{llama3_8B_128K => llama}/create_mini_model.py (100%) rename examples/{llama3_8B_128K => llama}/modeling_modifier.py (100%) rename examples/{llama3_8B_128K => llama}/requirements.txt (100%) rename examples/{llama3_8B_128K => llama}/train.py (77%) delete mode 100644 examples/llama3_8B_128K/README.md diff --git a/examples/llama3_8B_128K/.gitignore b/examples/llama/.gitignore similarity index 100% rename from examples/llama3_8B_128K/.gitignore rename to examples/llama/.gitignore diff --git a/examples/llama/README.md b/examples/llama/README.md new file mode 100644 index 00000000..68ab68fd --- /dev/null +++ b/examples/llama/README.md @@ -0,0 +1,189 @@ +# Introduction + +This example demonstrates how to train llama models in challenging distributed configurations by nnscaler. + +# Requirements + +Assume following packages have been installed in the environment. + +```text +nnscaler +transformers==4.40.0 +datasets==2.20.0 +apex +flash-attn +``` + +*nnScaler* is a framework for distributed training by automatically partitioning the model. Apart from the core nnScaler library, it also includes a mini-trainer for modern model training. You can find related documents and examples at [nnScaler](https://nnscaler.readthedocs.io/en/latest/). + +*transformers* and *datasets* are used to prepare the data and loading the Llama model. + +To speed up the training, [*apex*](https://github.com/NVIDIA/apex) and [*flash-attn*](https://github.com/Dao-AILab/flash-attention) are required. You can install them by following instructions in their official repositories. We also recommend to launch training in a docker directly, like nvidia/pytorch:24.02-py3 and rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0. + +# Supported Models + +The following table lists the supported model architectures and their corresponding distributed environments. A performance analysis for these will be provided later in the document. We plan to support more model combinations in the future and encourage you to experiment and contribute. + +| Model ID | Sequence Length | Device Type | Device Number | +| :---------------------------------: | :-------------: | :---------: | :-----------: | +| meta-llama/Meta-Llama-3-8B-Instruct | 131072 | H100 | 8 | +| meta-llama/Meta-Llama-3-70B | 8192 | MI300 | 16 | + +# Data Preparation + +We use the [bookcorpus](https://huggingface.co/datasets/bookcorpus) dataset for demonstrating in this doc. You can change related code to support your own dataset. Here we give an example that downloads and tokenizes `bookcorpus` for Llama. + +In the example command below, the dataset is tokenized by [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) tokenizer and grouped into 128K, tokenized data is saved in `bookcorpus_llama3_128K` directory. + +```bash +python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_128K --sequence_length 131072 +``` + +# Training + +nnScaler adopts a compiler approach to train deep learning models on multiple deivices. The processing pipeline is divided into two stages: + +1. Compile stage: trace the original PyTorch model and get the dataflow graph. Analyze the graph and generate an efficient plan for distributed training. Generate python code for the runtime stage. +2. Runtime stage: run the generated python code to train the model. + +For better user experience, we recommend to use separate commands for the compile and runtime stages at your first trial of nnScaler. You can use the `Run` command directly to combine the two stages when you are familiar with the system. + +**Note**: currently we only tested `"_attn_implementation": "flash_attention_2"` and `"use_cache": false` in the config file. Other configurations may trigger errors. + +## Trace Strategy + +During compiling, the time cost of trace model graph can vary significantly depending on the tracing strategy employed. Below are some reference time to trace `meta-llama/Meta-Llama-3-8B-Instruct` with different strategies and different context length, the time tested on one single A100 80GB: + +| Strategy | Context Length | Time/seconds | +| :------: | :------------: | :----------: | +| `reuse_cache` | 8k | 8.11 | +| `reuse_cache` | 32k | 11.06 | +| `reuse_cache` | 64k | 15.36 | +| `reuse_cache` | 128k | 26.29 | +| `cuda_run_cpu_offload` | 8k | 55.28 | +| `cuda_run_cpu_offload` | 32k | 194.27 | +| `cuda_run_cpu_offload` | 64k | 342.15 | +| `cuda_run_cpu_offload` | 128k | 789.15 | + +The trace strategy can be changed by setting `--trace_strategy` option. Please note that different strategies have different applicable scenarios. For more information and explanation to the different strategies, please read `docs/source/parallel_module.md`. + +## Register Customized Function + +Llama3's vocabulary size is about 128K, which is much larger then the 32K in Llama2. When the sequence length is very long like 128K, the output tensor size of the last projection layer is quite large: 128K x 128K x 2 bytes = 32GB in fp16 or bf16. +Although this tensor can be partitioned evenly to 8 GPUs, 4GB memory is still large due to limited GPU memory. What makes it worse is that we need to store additional 8GB for `log_softmax` and `cross_entropy_loss` computation. +In order to reduce the memory consumption: +- we split the input sequence on each device to chunks of 1K tokens +- for each chunk, we recompute a function which is composed of last projection layer, log_softmax and loss +- as a result, we only need to store the input tensor to the last projection layer, whose initial size is 128K x 4K x 2 bytes = 1GB, which is much smaller than 32GB + +You can find the detailed implementation in `chunk_linear_cross_entropy.py`. +The interface of the `chunk_linear_cross_entropy` function is `(hidden_states: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor`, where +- `hidden_states` is the output of the last transformer layer, with shape `[batch_size, sequence_length, hidden_size]` +- `weight` is the weight matrix of the last projection layer, with shape `[vocab_size, hidden_size]` +- `labels` is the target labels, with shape `[batch_size, sequence_length]` +- `padding_idx` is the padding index +- `chunk_size` is the size of the chunk, default is 1024 + +We want to register this function to nnScaler and tell it to partition this function along batch size or sequence dimension. A possible annotation is `b l d^, n^ d^, b l -> b l`. Here `b` stands for batch size, `l` stands for sequence length, `d` stands for hidden size, and `n` stands for vocab size. The `^` means the dimension cannot be partitioned. More details about the annotation can be found in `docs/source/register_custom_op.md`. + +You can enable this customized function by passing `--enable-chunk-loss` to `train.py` when compiling. When the sequence length is small (like 8K), this option can be turned off. + +## Profile Communication + +To generate an efficient distributed plan in your environment, we recommend to profile the intra-node communication before compiling. The profiler records the time of different communication primitives (like allgather, allreduce, reducescatter and alltoall) for some message sizes. If the profiling is skipped, the system will use MI250's data by default. You can use the command below to profile. + +```bash +cd nnscaler && python utility/prim_profiler.py +``` + +## Checkpoint + +`train.py` will save the model checkpoint in the `./checkpoints` directory by default. You can change the checkpoint directory by updating the `CheckpointConfig` in the source code. + +nnScalar saves checkpoints in shards: each rank may save parameters and optimizer states in a file. These checkpoints can be directly loaded by nnScaler if the partitioning strategy is the same. If you want to evaluate the checkpoints on downstream tasks, you need to merge the shards into a single file. You can use the following command to merge the shards: + +```bash +python ckpt_merger.py --ckpt_dir ./checkpoints --output_fname ./merged.ckpt +``` + +The merged checkpoint can be loaded by nnScaler by setting the `--resume_path` option to the merged file. + +If the script is modified for different hardware configurations. +- All sharded checkpoint files should be collected and placed in a same directory before `ckpt_merger.py` is called. +- If the config is changed (plan_ngus/runtime_ngus/etc), the sharded checkpoint can not be used anymore. You need to merge them so the trainer can load from merged checkpoint. + +# Performance Analysis + +The flops of the forward computation for llama is + +$2 \cdot ( param\_num \cdot seqlen + 2 \cdot layer\_num \cdot hidden\_dim \cdot seqlen ^ 2)$ + +## Llama3 8B 128K on 8xH100 + +Commands below is used for this setting. + +**Compile** + +```bash +python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee compile.log +``` + +**Run** + +```bash +torchrun --nproc_per_node=8 train.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee run.log +``` + +For the 8B model, the forward flops is about 11104.35 TFLOPs. The detailed config is as following: +- $param\_num = 8 \times 10^9$ +- $seqlen = 128 \times 1024$ +- $layer\_num = 32$ +- $hidden\_dim = 4096$ + +Generally, the computational cost of backpropagation is twice that of the forward pass. In addition, the gradient accumulation number is set to 4. As a result, the flops for a step of the training script is 133252.22 TFLOPs. + +We execute the training script on a node with 8xH100 80GB HBM3. The time cost is about 41.12s for a step. The theoretical BF16 computational speed of the H100 is 989 TFLOPS. Combine them together, this script can achieve 40.96% MFU. You can optimize the performance furtherly by +- add more devices to avoid recomputation: in order to fit the model into the memory, we recompute by layer. +- do more kernel optimizations. For example, the swiglu activation can be fused into the matmul ahead of it. + +## Llama3 70B 8K on 16xMI300 + +Different from the 8B example, a merged command is used for the multi-node setting. Since 70b model is trained on 2 nodes, we use mpi to execute `torchrun` on them at the same time. If you want to run the command on your own, you can replace `MASTER_ADDR` with the IP address of the first node, `MASTER_PORT` with the available port on the first node and fill `OMPI_COMM_WORLD_RANK` with 0 and 1 on two nodes respectively. + +**Combined Command** + +```bash +torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$$OMPI_COMM_WORLD_RANK --master_addr="$$MASTER_ADDR" --master_port=$$MASTER_PORT train.py --name llama3-70b --model_id meta-llama/Meta-Llama-3-70B --dataset_path ./bookcorpus_llama3_8K --gpu_mem_constraint 153 --plan_ngpus=8 --runtime_ngpus=16 --explore_pipeline --grad_accumulation_steps 64 --pipeline_pivots LlamaDecoderLayer 2>&1 | tee run.log +``` + +Note that in the command above, we enable searching for pipeline parallelism by passing `--explore_pipeline` and set the possible pipeline stage boundaries by `--pipeline_pivots LlamaDecoderLayer`. + +For the 70B model, the flops for forward and backward is about 3968.41 TFLOPs. The detailed config is as following: +- $param\_num = 70 \times 10^9$ +- $seqlen = 8192$ +- $layer\_num = 80$ +- $hidden\_dim = 8192$ + +[MI300X](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf)'s peak theoritical performance for BF16 is 1307.4 TFLOPS. It takes about 100.3 s to finish 64 gradient accumulation steps in the experiment. Combine them together, the MFU of this distributed plan is 24.2 %. + +Based on AutoDist's analysis, the low utilization results from following aspects +- We observe MFU for important operators are low. For example, `linear`'s MFU is 40% ~ 50%, the real MFU of `flash-attn` is 14%. +- Like the 8B 128K example, we can fuse operators like RoPE and swiglu to reduce time. +- There are two pipeline stages each with 4 devices. In each stage, communication takes about 450ms and computation takes about 1000ms. According to our experiences, the communication time is higher than expected. Adding more devices may help to reduce it since the optimizer states still takes about 52GB in each device. +- Enlarge search space in the future. Currently we only consider plan_ngpus=8 and fix the pipeline schedule to be `1f1b`. We can refine this assumption in the future. + +# Debugging + +Since the large setting is challenging, it is recommended to use a smaller model for debugging. For example, you can use the following command to prepare data and train a smaller llama3 (same architecture, but with 4 decoder layers) model on two GPUs. + +```bash +# prepare data +python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 + +# build the mini model +python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini + +# compile and run using data parallelism + zero1 +torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K + +``` diff --git a/examples/llama3_8B_128K/bookcorpus.py b/examples/llama/bookcorpus.py similarity index 97% rename from examples/llama3_8B_128K/bookcorpus.py rename to examples/llama/bookcorpus.py index 31ba97e8..7d2441c2 100644 --- a/examples/llama3_8B_128K/bookcorpus.py +++ b/examples/llama/bookcorpus.py @@ -11,7 +11,7 @@ def get_tokenizer(model_path): - return AutoTokenizer.from_pretrained(model_path) + return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) def tokenize(sample: Dict[str, str], tokenizer: PreTrainedTokenizer, text_key: str): diff --git a/examples/llama3_8B_128K/chunk_linear_cross_entropy.py b/examples/llama/chunk_linear_cross_entropy.py similarity index 100% rename from examples/llama3_8B_128K/chunk_linear_cross_entropy.py rename to examples/llama/chunk_linear_cross_entropy.py diff --git a/examples/llama3_8B_128K/ckpt_merger.py b/examples/llama/ckpt_merger.py similarity index 100% rename from examples/llama3_8B_128K/ckpt_merger.py rename to examples/llama/ckpt_merger.py diff --git a/examples/llama3_8B_128K/create_mini_model.py b/examples/llama/create_mini_model.py similarity index 100% rename from examples/llama3_8B_128K/create_mini_model.py rename to examples/llama/create_mini_model.py diff --git a/examples/llama3_8B_128K/modeling_modifier.py b/examples/llama/modeling_modifier.py similarity index 100% rename from examples/llama3_8B_128K/modeling_modifier.py rename to examples/llama/modeling_modifier.py diff --git a/examples/llama3_8B_128K/requirements.txt b/examples/llama/requirements.txt similarity index 100% rename from examples/llama3_8B_128K/requirements.txt rename to examples/llama/requirements.txt diff --git a/examples/llama3_8B_128K/train.py b/examples/llama/train.py similarity index 77% rename from examples/llama3_8B_128K/train.py rename to examples/llama/train.py index 390ef0e1..e80f73fb 100644 --- a/examples/llama3_8B_128K/train.py +++ b/examples/llama/train.py @@ -59,19 +59,31 @@ def get_tokenizer(tokenizer_name_or_path, class WrapperModel(torch.nn.Module): - def __init__(self, model_id): + def __init__(self, model_id, enable_chunk_loss): super().__init__() self.model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation='flash_attention_2') + self.enable_chunk_loss = enable_chunk_loss def forward(self, samples): - outputs = self.model.model( - input_ids=samples['net_input']['src_tokens'], - use_cache=False, - return_dict=False, - ) - hidden_states = outputs[0] - losses = chunk_linear_cross_entropy(hidden_states, self.model.lm_head.weight, samples['target'], IGNORE_IDX, 1024) - loss = torch.sum(losses) + if self.enable_chunk_loss: + outputs = self.model.model( + input_ids=samples['net_input']['src_tokens'], + use_cache=False, + return_dict=False, + ) + hidden_states = outputs[0] + losses = chunk_linear_cross_entropy(hidden_states, self.model.lm_head.weight, samples['target'], IGNORE_IDX, 1024) + loss = torch.sum(losses) + else: + outputs = self.model( + input_ids=samples['net_input']['src_tokens'], + use_cache=False, + return_dict=False, + ) + logits = outputs[0].view(-1, outputs[0].size(-1)) + labels = samples['target'].view(-1) + normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) + loss = torch.nn.functional.nll_loss(normalized_logits, labels, reduction='sum', ignore_index=IGNORE_IDX) return loss, loss.data, samples['ntokens'], samples['nsentences'] @@ -126,9 +138,10 @@ def collate_fn(samples): shift_labels = mini_batch['labels'][..., 1:] _mini_batch['labels'] = torch.nn.functional.pad(shift_labels, (0, 1), 'constant', IGNORE_IDX).contiguous() + # cast `nsentences` and `ntokens` to tensor since current pipeline parallelism can only transfer data in tensor format return { - "nsentences": len(samples), - "ntokens": len(samples) * seq_len, + "nsentences": torch.tensor(len(samples), dtype=torch.long), + "ntokens": torch.tensor(len(samples) * seq_len, dtype=torch.long), "net_input": _mini_batch, "target": _mini_batch.pop('labels'), } @@ -156,12 +169,11 @@ def collate_fn(samples): constant_folding=True, use_zero=True, use_end2end=True, - # autodist config: - # - memory constraint default value is 64GB - # - recompute by the transformer layer in Llama pas_config={ 'mem_constraint': args.gpu_mem_constraint, - 'recompute_modules': 'LlamaDecoderLayer', + 'explore_pipeline': args.explore_pipeline, + 'pipeline_pivots': args.pipeline_pivots, + 'recompute_modules': args.recompute_modules, }, trace_strategy=args.trace_strategy, ) @@ -170,6 +182,7 @@ def collate_fn(samples): type=WrapperModel, args={ 'model_id': args.model_id, + 'enable_chunk_loss': args.enable_chunk_loss, }, ) @@ -197,7 +210,7 @@ def collate_fn(samples): sampler_config = DatasetSamplerConfig( train_args={ - 'shuffle': False, + 'shuffle': True, }, ) @@ -227,8 +240,8 @@ def collate_fn(samples): checkpoint=checkpoint_config, precision='bf16', max_epochs=2, + grad_accumulation_steps=args.grad_accumulation_steps, max_train_steps=args.max_train_steps, - grad_accumulation_steps=4, log=[log_config], seed=0, broadcast_strategy=broadcast_strategy, @@ -245,7 +258,7 @@ def collate_fn(samples): parser = argparse.ArgumentParser() parser.add_argument( '--name', - default='llama3-8b', + default='llama', type=str, help='name of the experiment', ) @@ -297,6 +310,34 @@ def collate_fn(samples): type=str, help='trace strategy control the function execution during tracing model graph, `cuda_run_cpu_offload` and `reuse_cache` are recommended, please read `docs/source/parallel_module.md` for more information', ) + parser.add_argument( + '--enable-chunk-loss', + action='store_true', + help='enable chunk loss that exchanges the speed of training for the memory usage', + ) + parser.add_argument( + '--explore_pipeline', + action='store_true', + help='explore pipeline parallelism in autodist', + ) + parser.add_argument( + '--pipeline_pivots', + default='', + type=str, + help='specify the pipeline pivots for autodist', + ) + parser.add_argument( + '--recompute_modules', + default='', + type=str, + help='specify the modules to recompute in autodist', + ) + parser.add_argument( + '--grad_accumulation_steps', + default=4, + type=int, + help='number of gradient accumulation steps', + ) parser.add_argument( '--max_train_steps', default=None, @@ -304,6 +345,8 @@ def collate_fn(samples): help='max training steps', ) args = parser.parse_args() + if args.explore_pipeline and not args.pipeline_pivots: + raise ValueError('pipeline_pivots must be specified when explore_pipeline is enabled') if os.getenv('DETERMINISTIC'): # reduce randomness for integration test os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' diff --git a/examples/llama3_8B_128K/README.md b/examples/llama3_8B_128K/README.md deleted file mode 100644 index 449d2fcc..00000000 --- a/examples/llama3_8B_128K/README.md +++ /dev/null @@ -1,138 +0,0 @@ -# Introduction - -This example demonstrates how to train llama3-8B-128k model with 8xH100s or 8xA100s. - -# Requirements - -To run this example, you need to install the following packages: - -```text -nnscaler -transformers==4.40.0 -datasets==2.20.0 -apex -flash-attn -``` - -*nnScaler* is a framework for distributed training by automatically partitioning the model. Apart from the core nnScaler library, it also includes a mini-trainer for modern model training. You can find related documents and examples at [nnScaler](TODO). - -*transformers* and *datasets* are required to prepare the data and loading the Llama model. - -To speed up the training process, [*apex*](https://github.com/NVIDIA/apex) and [*flash-attn*](https://github.com/Dao-AILab/flash-attention) are required. You can install them by following instructions in their official repositories. We also recommend to launch the script under a Nvidia docker directly, like nvidia/pytorch:24.02-py3. - -# Data Preparation - -We use the [bookcorpus](https://huggingface.co/datasets/bookcorpus) dataset for training. The dataset is tokenized with the [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) tokenizer. Tokenized data is saved in the `bookcorpus_llama3_128K` directory. - -```bash -python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_128K --sequence_length 131072 -``` - -# Training - -nnScaler adopts a compiler approach to launch the distributed training. The processing pipeline is divided into two stages: - -1. Compile stage: trace the original PyTorch model and get the dataflow graph. Analyze the graph and generate an efficient plan for distributed training. Generate python code for the runtime stage. -2. Runtime stage: run the generated python code to train the model. - -For better user experience, we recommend to use separate commands for the compile and runtime stages. You can also use the `Run` command directly to combine the two stages. - -**Note**: currently we only tested `"_attn_implementation": "flash_attention_2"` and `"use_cache": false` in the config file. Other configurations may not work properly. - -## Register Customized Function - -Llama3's vocabulary size is about 128K, which is much larger then the 32K in Llama2. At the same time the sequence length in this example is 128K, the output tensor size of the last projection layer is quite large: 128K x 128K x 2 bytes = 32GB. -Although this tensor can be partitioned evenly to 8 GPUs, 4GB memory is still quite large for limited GPU memory. What makes it worse is that we need to store additional 8GB for `log_softmax` and `cross_entropy_loss` computation. -In order to reduce the memory consumption: -- we split the input sequence on each device to chunks of 1K tokens -- for each chunk, we recompute a function which is composed of last projection layer, log_softmax and loss -- as a result, we only need to store the input tensor to the last projection layer, whose initial size is 128K x 4K x 2 bytes = 1GB, which is much smaller than 32GB - -You can find the detailed implementation in `chunk_linear_cross_entropy.py`. -The interface of the `chunk_linear_cross_entropy` function is `(hidden_states: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor`, where -- `hidden_states` is the output of the last transformer layer, with shape `[batch_size, sequence_length, hidden_size]` -- `weight` is the weight matrix of the last projection layer, with shape `[vocab_size, hidden_size]` -- `labels` is the target labels, with shape `[batch_size, sequence_length]` -- `padding_idx` is the padding index -- `chunk_size` is the size of the chunk, default is 1024 - -We want to register this function to nnScaler and tell it to partition this function along batch size or sequence dimension. A possible annotation is `b l d^, n^ d^, b l -> b l`. Here `b` stands for batch size, `l` stands for sequence length, `d` stands for hidden size, and `n` stands for vocab size. The `^` means the dimension cannot be partitioned. More details about the annotation can be found in related documents. - -## Compile - -```bash -python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 2>&1 | tee compile.log -``` - -## Run - -```bash -torchrun --nproc_per_node=8 train.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 2>&1 | tee run.log -``` - -## Checkpoint - -This script will save the model checkpoint in the `./checkpoints` directory. You can change the checkpoint directory by updating the `CheckpointConfig` in the `train.py` script. - -nnScalar saves checkpoints in shards: each rank may save parameters and optimizer states in a file. These checkpoints can be directly loaded by nnScaler if the partitioning strategy is the same. If you want to evaluate the checkpoints on downstream tasks, you need to merge the shards into a single file. You can use the following command to merge the shards: - -```bash -python ckpt_merger.py --ckpt_dir ./checkpoints --output_fname ./merged.ckpt -``` - -The merged checkpoint can be loaded by nnScaler by setting the `--resume_path` option to the merged file. - -If the script is modified for different hardware configurations. -- All sharded checkpoint files should be collected and placed in a same directory before `ckpt_merger.py` is called. -- If the config is changed (plan_ngus/runtime_ngus/etc), the sharded checkpoint can not be used anymore. You need to merge them so the trainer can load from merged checkpoint. - -# Performance - -The flops of the forward computation for Llama3 is - -$2 \cdot ( param\_num \cdot seqlen + 2 \cdot layer\_num \cdot hidden\_dim \cdot seqlen ^ 2)$ - -For the 8B model, the forward flops is about 11104.35 TFLOPs. The detailed config is as following: -- $param\_num = 8 \times 10^9$ -- $seqlen = 128 \times 1024$ -- $layer\_num = 32$ -- $hidden\_dim = 4096$ - -Generally, the computational cost of backpropagation is twice that of the forward pass. In addition, the gradient accumulation number is set to 4. As a result, the flops for a step of the training script is 133252.22 TFLOPs. - -We execute the training script on a node with 8xH100 80GB HBM3. The time cost is about 41.12s for a step. The theoretical BF16 computational speed of the H100 is 989 TFLOPS. Combine them together, this script can achieve 40.96% MFU. You can optimize the performance furtherly by -- add more devices to avoid recomputation: in order to fit the model into the memory, we recompute by layer. -- do more kernel optimizations. For example, the swiglu activation can be fused into the matmul ahead of it. - -## Trace Strategy - -During compiling, the time cost of trace model graph can vary significantly depending on the tracing strategy employed. Below are some reference time to trace `meta-llama/Meta-Llama-3-8B-Instruct` with different strategies and different context length, the time tested on one single A100 80GB: - -| Strategy | Context Length | Time/seconds | -| :------: | :------------: | :----------: | -| `reuse_cache` | 8k | 8.11 | -| `reuse_cache` | 32k | 11.06 | -| `reuse_cache` | 64k | 15.36 | -| `reuse_cache` | 128k | 26.29 | -| `cuda_run_cpu_offload` | 8k | 55.28 | -| `cuda_run_cpu_offload` | 32k | 194.27 | -| `cuda_run_cpu_offload` | 64k | 342.15 | -| `cuda_run_cpu_offload` | 128k | 789.15 | - -The trace strategy can be changed by setting `--trace_strategy` option. Please note that different strategies have different applicable scenarios. For more information and explanation to the different strategies, please read `docs/source/parallel_module.md`. - -# Debugging - -Since the 128K config is challenging, it is recommended to use a smaller model for debugging. For example, you can use the following command to prepare data and train a smaller llama3 (same architecture, but with 4 decoder layers) model on two GPUs. - -```bash -# prepare data -python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 - -# build the mini model -python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini - -# compile and run using data parallelism + zero1 -torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K - -``` From 926ee1cca6a80e7218da0bf6830fdb56ae3fdca9 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Fri, 22 Nov 2024 11:57:25 +0000 Subject: [PATCH 1771/1892] Merged PR 2287: add dagan example to nnscaler examples add dagan example to nnscaler examples Related work items: #2026 --- examples/vision/DaGAN/README.md | 24 ++ examples/vision/DaGAN/dataset.py | 56 ++++ examples/vision/DaGAN/model_full.py | 219 ++++++++++++++ examples/vision/DaGAN/requirements.txt | 34 +++ examples/vision/DaGAN/run.py | 394 +++++++++++++++++++++++++ examples/vision/DaGAN/vox-adv-256.yaml | 81 +++++ 6 files changed, 808 insertions(+) create mode 100644 examples/vision/DaGAN/README.md create mode 100644 examples/vision/DaGAN/dataset.py create mode 100644 examples/vision/DaGAN/model_full.py create mode 100644 examples/vision/DaGAN/requirements.txt create mode 100644 examples/vision/DaGAN/run.py create mode 100644 examples/vision/DaGAN/vox-adv-256.yaml diff --git a/examples/vision/DaGAN/README.md b/examples/vision/DaGAN/README.md new file mode 100644 index 00000000..833f9088 --- /dev/null +++ b/examples/vision/DaGAN/README.md @@ -0,0 +1,24 @@ +This example demonstrates a GAN-like vision model. The nnscaler trainer assumes there is only one end-to-end module that needs to be parallelized. However, GAN-like models always have both a generator and a discriminator. Here, you will learn how to run your code without the nnscaler trainer, and how to parallelize, synchronize, and update modules during training. + +In this example, both `GeneratorFullModel` and `DiscriminatorFullModel` contain the same keypoint detector, generator, and discriminator modules. A module cannot be parallelized multiple times, so keypoint detector, generator, and discriminator must be parallelized separately. Separate synchronization and updates are also needed during training. + + +# clone CVPR2022-DaGAN repository from github +``` +cd MagicCube/examples/vision/DaGAN +git clone https://github.com/harlanhong/CVPR2022-DaGAN.git +``` + +# Install dependent packages +``` +mv CVPR2022-DaGAN CVPR2022_DaGAN +cd CVPR2022_DaGAN +pip install --ignore-installed -r requirements.txt +cd .. +export PYTHONPATH=$PYTHONPATH:CVPR2022_DaGAN +``` + +# run +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc-per-node=4 --master_port=12348 run.py --config vox-adv-256.yaml --name DaGAN --batchsize 8 --kp_num 15 --generator DepthAwareGenerator +``` \ No newline at end of file diff --git a/examples/vision/DaGAN/dataset.py b/examples/vision/DaGAN/dataset.py new file mode 100644 index 00000000..b2bf65c3 --- /dev/null +++ b/examples/vision/DaGAN/dataset.py @@ -0,0 +1,56 @@ +import random +import string + +import torch +import torch.utils.data as data + + +class VDataset(data.Dataset): + """ + What has been changed: + Generates random data for training and evaluation. + """ + def __init__( + self, + size=256, + is_train=True, + evaluate_all=True, + data_type="two", + **kwargs, + ) -> None: + super().__init__() + self.size = size + self.is_train = is_train + self.evaluate_all = evaluate_all + self.data_type = data_type + self.len = 128 + torch.manual_seed(42) + self.sources = torch.normal(0.0, 1.0, size=(self.len, 3, 256, 256)) + self.driving = torch.normal(1.1, 3.0, size=(self.len, 3, 256, 256)) + self.video = torch.normal(2.1, 5.1, size=(self.len, 3, 256, 256)) + + + def __len__(self): + return self.len + + def __getitem__(self, index): + if self.is_train: + source = self.sources[index] + driving = self.driving[index] + data_sample = { + "source": source, + "driving": driving, + } + return data_sample + + else: + video = self.video[index] + if "norm" in self.data_type: + video = video * 2.0 - 1.0 + + out_name = ''.join(random.choices(string.ascii_letters + string.digits, k=8)) + ".mp4" + data_sample = { + "video": video, + "out_name": out_name, + } + return data_sample diff --git a/examples/vision/DaGAN/model_full.py b/examples/vision/DaGAN/model_full.py new file mode 100644 index 00000000..57f64770 --- /dev/null +++ b/examples/vision/DaGAN/model_full.py @@ -0,0 +1,219 @@ +import torch +import CVPR2022_DaGAN.depth as depth +from CVPR2022_DaGAN.modules.model import ImagePyramide +from CVPR2022_DaGAN.modules.model import Vgg19 +from CVPR2022_DaGAN.modules.model import GeneratorFullModel +from CVPR2022_DaGAN.modules.model import DiscriminatorFullModel +from CVPR2022_DaGAN.modules.model import Transform +from CVPR2022_DaGAN.modules.model import detach_kp +import torch.nn.functional as F + +class GeneratorFullModel_NNSCALER(GeneratorFullModel): + """ + Merge all generator related updates into single model for better multi-gpu usage + What has been changed: + 1. Replace train_params by config in __init__, which include all the content in yaml, original get scale + by self.discriminator.module.scales, but in nnscaler, there is no scales in discriminator.module + 2. Remove self.depth_encoder.load_state_dict() and self.depth_decoder.load_state_dict() in __init__, + which needs extra download, for this example, pretriained weights is not necessary. + 3. Remove passing in driving_depth in forward, generated = self.generator(...), + because driving_depth is not used in generator forward, and pass in an unused argument is not allowed in nnscaler. + """ + + def __init__(self, kp_extractor, generator, discriminator, config, opt): + super(GeneratorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.discriminator = discriminator + self.train_params = config['train_params'] + self.scales = self.train_params['scales'] + self.disc_scales = config['model_params']['discriminator_params']['scales'] + self.pyramid = ImagePyramide(self.scales, config['model_params']['common_params']['num_channels']) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + self.opt = opt + self.loss_weights = self.train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + self.depth_encoder = depth.ResnetEncoder(50, False).cuda() + self.depth_decoder = depth.DepthDecoder(num_ch_enc=self.depth_encoder.num_ch_enc, scales=range(4)).cuda() + self.set_requires_grad(self.depth_encoder, False) + self.set_requires_grad(self.depth_decoder, False) + self.depth_decoder.eval() + self.depth_encoder.eval() + + def forward(self, x): + depth_source = None + depth_driving = None + outputs = self.depth_decoder(self.depth_encoder(x['source'])) + depth_source = outputs[("disp", 0)] + outputs = self.depth_decoder(self.depth_encoder(x['driving'])) + depth_driving = outputs[("disp", 0)] + + if self.opt.use_depth: + kp_source = self.kp_extractor(depth_source) + kp_driving = self.kp_extractor(depth_driving) + elif self.opt.rgbd: + source = torch.cat((x['source'],depth_source),1) + driving = torch.cat((x['driving'],depth_driving),1) + kp_source = self.kp_extractor(source) + kp_driving = self.kp_extractor(driving) + else: + kp_source = self.kp_extractor(x['source']) + kp_driving = self.kp_extractor(x['driving']) + + generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving, source_depth = depth_source) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + loss_values = {} + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction']) + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_values['perceptual'] = value_total + + if self.loss_weights['generator_gan'] != 0: + discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) + discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) + value_total = 0 + for scale in self.disc_scales: + key = 'prediction_map_%s' % scale + value = ((1 - discriminator_maps_generated[key]) ** 2).mean() + value_total += self.loss_weights['generator_gan'] * value + loss_values['gen_gan'] = value_total + + if sum(self.loss_weights['feature_matching']) != 0: + value_total = 0 + for scale in self.disc_scales: + key = 'feature_maps_%s' % scale + for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): + if self.loss_weights['feature_matching'][i] == 0: + continue + value = torch.abs(a - b).mean() + value_total += self.loss_weights['feature_matching'][i] * value + loss_values['feature_matching'] = value_total + + if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0: + transform = Transform(x['driving'].shape[0], **self.train_params['transform_params']) + transformed_frame = transform.transform_frame(x['driving']) + if self.opt.use_depth: + outputs = self.depth_decoder(self.depth_encoder(transformed_frame)) + depth_transform = outputs[("disp", 0)] + transformed_kp = self.kp_extractor(depth_transform) + elif self.opt.rgbd: + outputs = self.depth_decoder(self.depth_encoder(transformed_frame)) + depth_transform = outputs[("disp", 0)] + transform_img = torch.cat((transformed_frame,depth_transform),1) + transformed_kp = self.kp_extractor(transform_img) + else: + transformed_kp = self.kp_extractor(transformed_frame) + + generated['transformed_frame'] = transformed_frame + generated['transformed_kp'] = transformed_kp + + ## Value loss part + if self.loss_weights['equivariance_value'] != 0: + value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean() + loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value + + ## jacobian loss part + if self.loss_weights['equivariance_jacobian'] != 0: + jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']), + transformed_kp['jacobian']) + + normed_driving = torch.inverse(kp_driving['jacobian']) + normed_transformed = jacobian_transformed + value = torch.matmul(normed_driving, normed_transformed) + + eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) + + value = torch.abs(eye - value).mean() + loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value + + + if self.loss_weights['kp_distance']: + bz,num_kp,kp_dim = kp_source['value'].shape + sk = kp_source['value'].unsqueeze(2)-kp_source['value'].unsqueeze(1) + dk = kp_driving['value'].unsqueeze(2)-kp_driving['value'].unsqueeze(1) + source_dist_loss = (-torch.sign((torch.sqrt((sk*sk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()*0.2)-0.2)+1).mean() + driving_dist_loss = (-torch.sign((torch.sqrt((dk*dk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()*0.2)-0.2)+1).mean() + # driving_dist_loss = (torch.sign(1-(torch.sqrt((dk*dk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()))+1).mean() + value_total = self.loss_weights['kp_distance']*(source_dist_loss+driving_dist_loss) + loss_values['kp_distance'] = value_total + if self.loss_weights['kp_prior']: + bz,num_kp,kp_dim = kp_source['value'].shape + sk = kp_source['value'].unsqueeze(2)-kp_source['value'].unsqueeze(1) + dk = kp_driving['value'].unsqueeze(2)-kp_driving['value'].unsqueeze(1) + dis_loss = torch.relu(0.1-torch.sqrt((sk*sk).sum(-1)+1e-8))+torch.relu(0.1-torch.sqrt((dk*dk).sum(-1)+1e-8)) + bs,nk,_=kp_source['value'].shape + scoor_depth = F.grid_sample(depth_source,kp_source['value'].view(bs,1,nk,-1)) + dcoor_depth = F.grid_sample(depth_driving,kp_driving['value'].view(bs,1,nk,-1)) + sd_loss = torch.abs(scoor_depth.mean(-1,keepdim=True) - kp_source['value'].view(bs,1,nk,-1)).mean() + dd_loss = torch.abs(dcoor_depth.mean(-1,keepdim=True) - kp_driving['value'].view(bs,1,nk,-1)).mean() + value_total = self.loss_weights['kp_distance']*(dis_loss+sd_loss+dd_loss) + loss_values['kp_distance'] = value_total + + + if self.loss_weights['kp_scale']: + bz,num_kp,kp_dim = kp_source['value'].shape + if self.opt.rgbd: + outputs = self.depth_decoder(self.depth_encoder(generated['prediction'])) + depth_pred = outputs[("disp", 0)] + pred = torch.cat((generated['prediction'],depth_pred),1) + kp_pred = self.kp_extractor(pred) + elif self.opt.use_depth: + outputs = self.depth_decoder(self.depth_encoder(generated['prediction'])) + depth_pred = outputs[("disp", 0)] + kp_pred = self.kp_extractor(depth_pred) + else: + kp_pred = self.kp_extractor(generated['prediction']) + + pred_mean = kp_pred['value'].mean(1,keepdim=True) + driving_mean = kp_driving['value'].mean(1,keepdim=True) + pk = kp_source['value']-pred_mean + dk = kp_driving['value']- driving_mean + pred_dist_loss = torch.sqrt((pk*pk).sum(-1)+1e-8) + driving_dist_loss = torch.sqrt((dk*dk).sum(-1)+1e-8) + scale_vec = driving_dist_loss/pred_dist_loss + bz,n = scale_vec.shape + value = torch.abs(scale_vec[:,:n-1]-scale_vec[:,1:]).mean() + value_total = self.loss_weights['kp_scale']*value + loss_values['kp_scale'] = value_total + if self.loss_weights['depth_constraint']: + bz,num_kp,kp_dim = kp_source['value'].shape + outputs = self.depth_decoder(self.depth_encoder(generated['prediction'])) + depth_pred = outputs[("disp", 0)] + value_total = self.loss_weights['depth_constraint']*torch.abs(depth_driving-depth_pred).mean() + loss_values['depth_constraint'] = value_total + return loss_values, generated + + +class DiscriminatorFullModel_NNSCALER(DiscriminatorFullModel): + """ + Merge all discriminator related updates into single model for better multi-gpu usage + What has been changed: + 1. Replace train_params by config in __init__, the same reason as GeneratorFullModel_NNSCALER + """ + + def __init__(self, kp_extractor, generator, discriminator, config): + super(DiscriminatorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.discriminator = discriminator + self.train_params = config['train_params'] + self.scales = config['model_params']['discriminator_params']['scales'] + + self.pyramid = ImagePyramide(self.scales, config['model_params']['common_params']['num_channels']) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = self.train_params['loss_weights'] \ No newline at end of file diff --git a/examples/vision/DaGAN/requirements.txt b/examples/vision/DaGAN/requirements.txt new file mode 100644 index 00000000..fd56852e --- /dev/null +++ b/examples/vision/DaGAN/requirements.txt @@ -0,0 +1,34 @@ +absl-py +certifi +cycler +fonttools +grpcio +imageio +importlib-metadata +joblib +kiwisolver +Markdown +matplotlib +networkx +numpy==1.23.0 +packaging +pandas +Pillow +protobuf +pyparsing +python-dateutil +pytz +PyWavelets +PyYAML +scikit-image +scikit-learn +scipy +six +sklearn +tensorboard +threadpoolctl +tifffile +tqdm +typing_extensions +Werkzeug +zipp diff --git a/examples/vision/DaGAN/run.py b/examples/vision/DaGAN/run.py new file mode 100644 index 00000000..3bfb1862 --- /dev/null +++ b/examples/vision/DaGAN/run.py @@ -0,0 +1,394 @@ +import matplotlib + +matplotlib.use('Agg') + +import os, sys +import yaml +from argparse import ArgumentParser +from time import gmtime, strftime +from shutil import copy + +from torch.utils.tensorboard import SummaryWriter +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +import torch +import json +import hashlib + +from tqdm import trange +import torch + +from torch.utils.data import DataLoader + +from tqdm import tqdm +from torch.optim.lr_scheduler import MultiStepLR + +from CVPR2022_DaGAN.logger import Logger +from CVPR2022_DaGAN.frames_dataset import DatasetRepeater +import CVPR2022_DaGAN.modules.keypoint_detector as KPD +from CVPR2022_DaGAN.animate import animate +import CVPR2022_DaGAN.modules.generator as gen_module +from CVPR2022_DaGAN.modules.discriminator import MultiScaleDiscriminator +from model_full import DiscriminatorFullModel_NNSCALER +import model_full as MODEL_FULL + +from nnscaler.parallel import ComputeConfig, ReuseType, build_optimizer, parallelize +from dataset import VDataset + +""" +Main modifications: + 1. build dummy inputs and ComputeConfig for parallization + 2. parallelize generator, discriminator, kp_detector + 3. build_optimizer for generator, discriminator, kp_detector + 4. after loss.backward(), need to sync_shard_grad() and scale_grads() according optimizers +Note: + batchsize in training and validation and test should be the same. +""" + + +def init_seeds(cuda_deterministic=True): + import random + import numpy as np + import torch.backends.cudnn as cudnn + seed = 0 + dist.get_rank() + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html + if cuda_deterministic: # slower, more reproducible + import os + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + cudnn.deterministic = True + cudnn.benchmark = False + else: # faster, less reproducible + cudnn.deterministic = False + cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + + +def main(rank, world_size): + if sys.version_info[0] < 3: + raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") + + parser = ArgumentParser() + parser.add_argument("--config", required=True, help="path to config") + parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"]) + parser.add_argument("--log_dir", default='log', help="path to log into") + parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") + # parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))), + # help="Names of the devices comma separated.") + parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture") + parser.add_argument("--use_depth",action='store_true',help='depth mode') + parser.add_argument("--rgbd",action='store_true',help='rgbd mode') + parser.add_argument("--kp_prior",action='store_true',help='use kp_prior in final objective function') + + # alter model + parser.add_argument("--generator",required=True,help='the type of genertor') + parser.add_argument("--kp_detector",default='KPDetector',type=str,help='the type of KPDetector') + parser.add_argument("--GFM",default='GeneratorFullModel_NNSCALER',help='the type of GeneratorFullModel') + + parser.add_argument("--batchsize",type=int, default=-1,help='user defined batchsize') + parser.add_argument("--kp_num",type=int, default=-1,help='user defined keypoint number') + parser.add_argument("--kp_distance",type=int, default=10,help='the weight of kp_distance loss') + parser.add_argument("--depth_constraint",type=int, default=0,help='the weight of depth_constraint loss') + + parser.add_argument("--name",type=str,help='user defined model saved name') + + parser.set_defaults(verbose=False) + opt = parser.parse_args() + with open(opt.config) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + if opt.checkpoint is not None: + log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1]) + else: + log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) + log_dir += opt.name + + + print("Training...") + + device=torch.device("cuda",rank) + torch.cuda.set_device(device) + config['train_params']['loss_weights']['depth_constraint'] = opt.depth_constraint + config['train_params']['loss_weights']['kp_distance'] = opt.kp_distance + if opt.kp_prior: + config['train_params']['loss_weights']['kp_distance'] = 0 + config['train_params']['loss_weights']['kp_prior'] = 10 + if opt.batchsize != -1: + config['train_params']['batch_size'] = opt.batchsize + if opt.kp_num != -1: + config['model_params']['common_params']['num_kp'] = opt.kp_num + + # create generator + generator = getattr(gen_module, opt.generator)(**config['model_params']['generator_params'], + **config['model_params']['common_params']) + generator.to(device) + if opt.verbose: + print(generator) + + # create discriminator + discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'], + **config['model_params']['common_params']) + + discriminator.to(device) + if opt.verbose: + print(discriminator) + + kp_detector = getattr(KPD, opt.kp_detector)(**config['model_params']['kp_detector_params'], + **config['model_params']['common_params']) + kp_detector.to(device) + if opt.verbose: + print(kp_detector) + + if config['backend_params']['backend_name'] == "nnscaler": + # backend configs + plan_ngpus = config["backend_params"]["plan_ngpus"] + runtime_ngpus = config["backend_params"]["runtime_ngpus"] + batchsize = config["train_params"]["batch_size"] # per gpu batch size + batchsize *= plan_ngpus + frame_shape = config["dataset_params"]["frame_shape"] + # create dummy input tensors for nnscaler graph tracing + kp_detector_dummy_input = {"x": torch.randn(batchsize, 3, frame_shape[0], frame_shape[1])} + + generator_dummy_input = { + 'source_image': torch.randn(batchsize, 3, 256, 256), + 'kp_driving': { + 'value': torch.randn(batchsize, 15, 2), + 'jacobian': torch.randn(batchsize, 15, 2, 2) + }, + 'kp_source': { + 'value': torch.randn(batchsize, 15, 2), + 'jacobian': torch.randn(batchsize, 15, 2, 2) + }, + 'source_depth': torch.randn(batchsize, 1, 256, 256), + 'driving_depth': None # torch.randn(batchsize, 1, 256, 256) + } + + pyramide_generated = { + 'prediction_1': torch.randn(batchsize, 3, 256, 256), + 'prediction_0.5': torch.randn(batchsize, 3, 128, 128), + 'prediction_0.25': torch.randn(batchsize, 3, 64, 64), + 'prediction_0.125': torch.randn(batchsize, 3, 32, 32) + } + + detached_kp_driving = { + 'value': torch.randn(batchsize, 15, 2), + 'jacobian': torch.randn(batchsize, 15, 2, 2) + } + + compute_config = ComputeConfig( + plan_ngpus, runtime_ngpus, use_zero=False, user_config={"batch_size": batchsize} + ) + # parallelize models + kp_detector = parallelize( + kp_detector, + kp_detector_dummy_input, + 'data', # autodist + compute_config, + reuse=ReuseType.MOO, + ) + + generator = parallelize( + generator, + generator_dummy_input, + 'data', # autodist + compute_config, + reuse=ReuseType.MOO, + ) + + discriminator = parallelize( + discriminator, + {"x": pyramide_generated, "kp": detached_kp_driving}, + 'data', # autodist + compute_config, + reuse=ReuseType.MOO, + ) + else: + generator= torch.nn.SyncBatchNorm.convert_sync_batchnorm(generator) + generator = DDP(generator,device_ids=[rank],broadcast_buffers=False) + + discriminator= torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator) + discriminator = DDP(discriminator,device_ids=[rank],broadcast_buffers=False) + + kp_detector= torch.nn.SyncBatchNorm.convert_sync_batchnorm(kp_detector) + kp_detector = DDP(kp_detector,device_ids=[rank],broadcast_buffers=False) + + generator.to(device) + discriminator.to(device) + kp_detector.to(device) + + dataset = VDataset(is_train=True) + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): + copy(opt.config, log_dir) + + combined_str = json.dumps(config, sort_keys=True) + json.dumps(vars(opt), sort_keys=True) + hashstr = hashlib.sha256(combined_str.encode('utf-8')).hexdigest() + if rank == 0: + writer = SummaryWriter(os.path.join(log_dir, 'tensorboard-logs', hashstr)) + else: + writer = None + + if opt.mode == 'train': + train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, rank, device, opt, writer) + + +def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, rank, device, opt, writer=None): + """ + Steps: + 1. build optimizer for generator, discriminator, kp_detector + 2. build scheduler for generator, discriminator, kp_detector + 3. build train and valid DataLoader with DistributedSampler + 4. build generator_full and discriminator_full + 5. train and validate epoches + """ + train_params = config['train_params'] + + if config['backend_params']['backend_name'] == "nnscaler": + optimizer_generator = build_optimizer( + generator, + torch.optim.Adam, + lr=train_params['lr_generator'], + betas=(0.5, 0.999), + ) + optimizer_discriminator = build_optimizer( + discriminator, + torch.optim.Adam, + lr=train_params['lr_discriminator'], + betas=(0.5, 0.999), + ) + optimizer_kp_detector = build_optimizer( + kp_detector, + torch.optim.Adam, + lr=train_params['lr_kp_detector'], + betas=(0.5, 0.999), + ) + scale_factor = 1.0 / (config["backend_params"]["runtime_ngpus"] / config["backend_params"]["plan_ngpus"]) + else: + optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999)) + optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999)) + optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999)) + + if checkpoint is not None: + start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, + optimizer_generator, optimizer_discriminator, + None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector) + else: + start_epoch = 0 + + scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1, + last_epoch=start_epoch - 1) + scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1, + last_epoch=start_epoch - 1) + scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1, + last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0)) + + if 'num_repeats' in train_params and train_params['num_repeats'] != 1: + dataset = DatasetRepeater(dataset, train_params['num_repeats']) + sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=torch.cuda.device_count(), shuffle=True, rank=rank) + dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], num_workers=16, sampler=sampler, drop_last=True) + + generator_full = getattr(MODEL_FULL, opt.GFM)(kp_detector, generator, discriminator, config, opt) + discriminator_full = DiscriminatorFullModel_NNSCALER(kp_detector, generator, discriminator, config) + test_dataset = VDataset(is_train=True) + test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas=torch.cuda.device_count(), rank=rank) + test_dataloader = DataLoader(test_dataset, batch_size=train_params['batch_size'], shuffle=False, num_workers=8, sampler=test_sampler) + + with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger: + # logger.register_tensorboard_writer(writer) + for epoch in trange(start_epoch, train_params['num_epochs']): + #parallel + sampler.set_epoch(epoch) + total = len(dataloader) + epoch_train_loss = 0 + generator.train(), discriminator.train(), kp_detector.train() + with tqdm(total=total, position=rank, desc=f"Rank {rank}, Epoch {epoch}", leave=True) as par: + for i,x in enumerate(dataloader): + x['source'] = x['source'].to(device) + x['driving'] = x['driving'].to(device) + losses_generator, generated = generator_full(x) + + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + loss.backward() + if config['backend_params']['backend_name'] == "nnscaler": + optimizer_generator.sync_shard_grad() + optimizer_kp_detector.sync_shard_grad() + optimizer_generator.scale_grads(scale_factor) + optimizer_kp_detector.scale_grads(scale_factor) + optimizer_generator.step() + optimizer_generator.zero_grad() + optimizer_kp_detector.step() + optimizer_kp_detector.zero_grad() + epoch_train_loss+=loss.item() + + if train_params['loss_weights']['generator_gan'] != 0: + optimizer_discriminator.zero_grad() + losses_discriminator = discriminator_full(x, generated) + loss_values = [val.mean() for val in losses_discriminator.values()] + loss = sum(loss_values) + + loss.backward() + if config['backend_params']['backend_name'] == "nnscaler": + optimizer_discriminator.sync_shard_grad() + optimizer_discriminator.scale_grads(scale_factor) + optimizer_discriminator.step() + optimizer_discriminator.zero_grad() + else: + losses_discriminator = {} + + losses_generator.update(losses_discriminator) + losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} + # for k,v in losses.items(): + # writer.add_scalar(k, v, total*epoch+i) + if rank == 0: + logger.log_iter(losses=losses) + par.update(1) + torch.distributed.barrier() + epoch_train_loss = epoch_train_loss/total + if (epoch + 1) % train_params['checkpoint_freq'] == 0: + if rank == 0: + writer.add_scalar('epoch_train_loss', epoch_train_loss, epoch) + scheduler_generator.step() + scheduler_discriminator.step() + scheduler_kp_detector.step() + if rank == 0: + logger.log_epoch(epoch, {'generator': generator, + 'discriminator': discriminator, + 'kp_detector': kp_detector, + 'optimizer_generator': optimizer_generator, + 'optimizer_discriminator': optimizer_discriminator, + 'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated) + generator.eval(), discriminator.eval(), kp_detector.eval() + if (epoch + 1) % train_params['checkpoint_freq'] == 0: + epoch_eval_loss = 0 + for i, data in tqdm(enumerate(test_dataloader), position=rank, desc=f"Rank {rank}", leave=True): + data['source'] = data['source'].cuda() + data['driving'] = data['driving'].cuda() + losses_generator, generated = generator_full(data) + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) + epoch_eval_loss+=loss.item() + epoch_eval_loss = torch.tensor(epoch_eval_loss).cuda() + gather_epoch_eval_loss = [torch.zeros_like(epoch_eval_loss) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gather_epoch_eval_loss, epoch_eval_loss) + gathered_epoch_eval_loss = torch.mean(torch.tensor(gather_epoch_eval_loss)).item() + epoch_eval_loss = gathered_epoch_eval_loss / len(test_dataloader) + if rank == 0: + logger.log_iter({'epoch_eval_loss': epoch_eval_loss}) + + +if __name__ == "__main__": + world_size = int(os.environ['WORLD_SIZE']) + rank = int(os.environ['RANK']) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + init_seeds() + main(rank, world_size) + dist.destroy_process_group() + + \ No newline at end of file diff --git a/examples/vision/DaGAN/vox-adv-256.yaml b/examples/vision/DaGAN/vox-adv-256.yaml new file mode 100644 index 00000000..ce130efe --- /dev/null +++ b/examples/vision/DaGAN/vox-adv-256.yaml @@ -0,0 +1,81 @@ +dataset_params: + frame_shape: [256, 256, 3] + +model_params: + common_params: + num_kp: 10 + num_channels: 3 + estimate_jacobian: True + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 + num_blocks: 5 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + num_bottleneck_blocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + use_kp: True + + +train_params: + num_epochs: 1 # 150 + num_repeats: 1 # 75 + epoch_milestones: [] + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-4 + lr_kp_detector: 2.0e-4 + batch_size: 16 + scales: [1, 0.5, 0.25, 0.125] + checkpoint_freq: 1 + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + generator_gan: 1 + discriminator_gan: 1 + feature_matching: [10, 10, 10, 10] + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + equivariance_jacobian: 10 + kp_distance: 10 + kp_prior: 0 + kp_scale: 0 + depth_constraint: 0 + +reconstruction_params: + num_videos: 1000 + format: '.mp4' + +animate_params: + num_pairs: 50 + format: '.mp4' + normalization_params: + adapt_movement_scale: False + use_relative_movement: True + use_relative_jacobian: True + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' + + +backend_params: + backend_name: 'nnscaler' # 'nnscaler' or 'ddp' + plan_ngpus: 1 # the size of scale unit, using tp/dp/pp within a scale unit and dp across scale units + runtime_ngpus: 4 # total number of gpus \ No newline at end of file From a6f82b71e1794362d1dace600633d25b3a267c4c Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 25 Nov 2024 09:04:40 +0000 Subject: [PATCH 1772/1892] Merged PR 2320: [Tracer] add new jump op name python 3.11 & 3.12 added & deprecated some instructions, so added the new instructions to support new python version --- nnscaler/graph/tracer/concrete_proxy.py | 19 ++++++---- tests/graph/tracer/test_dis.py | 48 ++++++++++++++++++++----- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/nnscaler/graph/tracer/concrete_proxy.py b/nnscaler/graph/tracer/concrete_proxy.py index 560a14a2..d6dc943c 100644 --- a/nnscaler/graph/tracer/concrete_proxy.py +++ b/nnscaler/graph/tracer/concrete_proxy.py @@ -31,13 +31,20 @@ class ConcreteProxy(Proxy): We can use it to trace a more compatible model, and pass the branches. """ - # TODO: python bytecode changes a lot in version 3.11. these ops should be updated. + # some jump ops have not find practical examples, add them because they are in python doc, + # TODO: after finding specific cases, need to add them to the unit tests. jump_opnames = ( - 'JUMP_IF_FALSE_OR_POP', - 'JUMP_IF_TRUE_OR_POP', - 'POP_JUMP_IF_FALSE', - 'POP_JUMP_IF_TRUE', - 'JUMP_IF_NOT_EXC_MATCH', # occurred in new python vertion, not tested + 'JUMP_IF_NOT_EXC_MATCH', # <= python 3.10 + 'JUMP_IF_FALSE_OR_POP', # <= python 3.11 + 'JUMP_IF_TRUE_OR_POP', # <= python 3.11 + 'POP_JUMP_IF_FALSE', # != python 3.11 + 'POP_JUMP_IF_TRUE', # != python 3.11 + 'POP_JUMP_FORWARD_IF_FALSE', # == python 3.11 + 'POP_JUMP_FORWARD_IF_TRUE', # == python 3.11 + 'POP_JUMP_FORWARD_IF_NOT_NONE', # == python 3.11, not included in unit test + 'POP_JUMP_FORWARD_IF_NONE', # == python 3.11, not included in unit test + 'POP_JUMP_IF_NOT_NONE', # >= python 3.12, not included in unit test + 'POP_JUMP_IF_NONE', # >= python 3.12, not included in unit test ) jump_opcodes = orig_func.tuple(dis.opmap[name] for name in jump_opnames if name in dis.opmap) op_compare = dis.opmap['COMPARE_OP'] diff --git a/tests/graph/tracer/test_dis.py b/tests/graph/tracer/test_dis.py index 03b30cdd..0af05417 100644 --- a/tests/graph/tracer/test_dis.py +++ b/tests/graph/tracer/test_dis.py @@ -111,8 +111,14 @@ def test_normal_item_with_starargs2(): def test_extend(): a = A() [1,2].extend(a) - assert a.caller_inst.opname == 'CALL_METHOD' - assert a.len_caller_inst.opname == 'CALL_METHOD' + # in <= python 3.10, opname is CALL_METHOD + # in >= python 3.11, opname is CALL + if sys.version_info.minor <= 10: + assert a.caller_inst.opname == 'CALL_METHOD' + assert a.len_caller_inst.opname == 'CALL_METHOD' + else: + assert a.caller_inst.opname == 'CALL' + assert a.len_caller_inst.opname == 'CALL' [1, *a] assert a.caller_inst.opname == 'LIST_EXTEND' # BUILD_LIST_UNPACK in python 3.8 @@ -177,7 +183,12 @@ def test_bool(): x = {c: c} bool(x[c]) # CALL_FUNCTION - assert c.caller_inst.opname == 'CALL_FUNCTION' + # in <= python 3.10, opname is CALL_FUNCTION + # in >= python 3.11, opname is CALL + if sys.version_info.minor <= 10: + assert c.caller_inst.opname == 'CALL_FUNCTION' + else: + assert c.caller_inst.opname == 'CALL' c and 1 # JUMP_IF_FALSE_OR_POP assert c.caller_inst.opname == 'JUMP_IF_FALSE_OR_POP' @@ -188,24 +199,43 @@ def test_bool(): c.caller_inst = None bool(c) # CALL_FUNCTION - assert c.caller_inst.opname == 'CALL_FUNCTION' + # in <= python 3.10, opname is CALL_FUNCTION + # in >= python 3.11, opname is CALL + if sys.version_info.minor <= 10: + assert c.caller_inst.opname == 'CALL_FUNCTION' + else: + assert c.caller_inst.opname == 'CALL' c.caller_inst = None - if c: # POP_JUMP_IF_FALSE + if c: pass - assert c.caller_inst.opname == 'POP_JUMP_IF_FALSE' + if sys.version_info.minor != 11: + assert c.caller_inst.opname == 'POP_JUMP_IF_FALSE' + else: + assert c.caller_inst.opname == 'POP_JUMP_FORWARD_IF_FALSE' c.caller_inst = None if not c: # POP_JUMP_IF_TRUE pass - assert c.caller_inst.opname == 'POP_JUMP_IF_TRUE' + if sys.version_info.minor != 11: + assert c.caller_inst.opname == 'POP_JUMP_IF_TRUE' + else: + assert c.caller_inst.opname == 'POP_JUMP_FORWARD_IF_TRUE' c.caller_inst = None if bool(c): # CALL_FUNCTION pass - assert c.caller_inst.opname == 'CALL_FUNCTION' + # in <= python 3.10, opname is CALL_FUNCTION + # in >= python 3.11, opname is CALL + if sys.version_info.minor <= 10: + assert c.caller_inst.opname == 'CALL_FUNCTION' + else: + assert c.caller_inst.opname == 'CALL' c.caller_inst = None x = 1 if c else 0 # POP_JUMP_IF_FALSE - assert c.caller_inst.opname == 'POP_JUMP_IF_FALSE' + if sys.version_info.minor != 11: + assert c.caller_inst.opname == 'POP_JUMP_IF_FALSE' + else: + assert c.caller_inst.opname == 'POP_JUMP_FORWARD_IF_FALSE' c.caller_inst = None From c0338f2bc277b8232701c56228077feb1103ea62 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Mon, 25 Nov 2024 11:23:43 +0000 Subject: [PATCH 1773/1892] Merged PR 2313: [Doc] doc update 1) doc restructure 2) wording update 3) including new content: a. autodist docs b. dimops c. verify_op Preview at: http://srgws-15:8080/index.html (http://10.190.175.247:8080/index.html) Related work items: #2070 --- README.md | 31 +- docs/source/autodist/interface_design.md | 6 +- .../solver_interface/partition_constraint.md | 8 +- docs/source/conf.py | 7 +- .../{self_training.md => control_flow.md} | 22 +- tutorial.md => docs/source/dimops.md | 4 +- docs/source/faq.rst | 53 +- docs/source/index.rst | 119 ++- docs/source/install_from_source.rst | 37 + docs/source/llama3_8b_128k_example.rst | 152 ++++ docs/source/llama3_demo_example.rst | 50 +- docs/source/nanogpt_example.rst | 34 +- docs/source/parallel_module.md | 250 +----- docs/source/pytorch_lightning.md | 12 +- docs/source/quickstart.rst | 279 ++++--- docs/source/quickstart_internal.rst | 116 +++ docs/source/readme.rst | 76 -- docs/source/register_custom_op.md | 5 +- docs/source/trainer.md | 477 ----------- docs/source/trainer.rst | 769 ++++++++++++++++++ docs/source/troubleshooting.rst | 291 +++++++ .../verify_ops => docs/source}/verify_op.md | 26 +- nnscaler/profiler/README.md | 4 +- 23 files changed, 1790 insertions(+), 1038 deletions(-) rename docs/source/{self_training.md => control_flow.md} (80%) rename tutorial.md => docs/source/dimops.md (98%) create mode 100644 docs/source/install_from_source.rst create mode 100644 docs/source/llama3_8b_128k_example.rst create mode 100644 docs/source/quickstart_internal.rst delete mode 100644 docs/source/readme.rst delete mode 100644 docs/source/trainer.md create mode 100644 docs/source/trainer.rst create mode 100644 docs/source/troubleshooting.rst rename {utility/verify_ops => docs/source}/verify_op.md (96%) diff --git a/README.md b/README.md index 0319019a..10ea3495 100644 --- a/README.md +++ b/README.md @@ -33,13 +33,13 @@ For **_DNN system experts_**, they can leverage nnScaler to explore new DNN para ### Prerequisite -Install the following packages before the installation of cube: +Install the following packages before the installation of nnScaler: Python >= 3.8, < 3.11 (3.10 is recommanded) PyTorch >= 2.0, < 2.4 (2.2.0 is recommanded) -### (Option 1) Install nnScaler from source +### Install nnScaler from source Execute below commands in nnScaler directory: pip install -r requirements.txt @@ -52,11 +52,6 @@ Besides, to avoid *cppimport* error, it also needs to include nnScaler directory [//]: # (Reference output: Successfully installed MarkupSafe-2.1.5 contourpy-1.3.0 cppimport-22.8.2 cycler-0.12.1 dill-0.3.8 filelock-3.15.4 fonttools-4.53.1 fsspec-2024.6.1 importlib-resources-6.4.4 jinja2-3.1.4 kiwisolver-1.4.5 mako-1.3.5 matplotlib-3.9.2 more-itertools-10.4.0 mpmath-1.3.0 networkx-3.3 numpy-2.1.0 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.6.68 nvidia-nvtx-cu12-12.1.105 packaging-24.1 pillow-10.4.0 psutil-6.0.0 pulp-2.9.0 pybind11-2.13.5 pyparsing-3.1.4 python-dateutil-2.9.0.post0 pyyaml-6.0.2 six-1.16.0 sympy-1.13.2 torch-2.4.0 tqdm-4.66.5 triton-3.0.0 typing-extensions-4.12.2) -### (Option 2) Install nnScaler from whl package - -To get started, install the latest wheel by visiting [DevOps Artifacts](https://msrasrg.visualstudio.com/SuperScaler/_artifacts/feed/nightly/PyPI/nnscaler/overview/). You may follow DevOps guide to set up the repository, or alternatively download the **.whl** file from the “Files” section of the website, then install locally: - - python -m pip install nnscaler-*.whl ## Example Llama-3 @@ -152,23 +147,20 @@ torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name lla We also provide an example to demonstrate how to parallelize a model through a [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)-compatible interface in nnScaler. * Find the [nanoGPT](https://github.com/karpathy/nanoGPT) example in nnScaler repo: - - - cd MagicCube/examples/nanogpt - +```shell + cd examples/nanogpt +``` * Install nanoGPT's dependencies: - - +```shell pip install -r requirements.txt - +``` * Prepare dataset: - - +```shell python nanoGPT/data/shakespeare_char/prepare.py - +``` * Test with Single GPU -Now you can run ``train_nnscaler.py`` with `torchrun `_: +Now you can run ``train_nnscaler.py`` with `torchrun `: torchrun --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py @@ -190,6 +182,8 @@ Or if you have multiple nodes, for example 2 nodes with 4 GPUs each: NOTE: The local batch size is fixed by default, so using more workers will result in a larger global batch size. +💡 For advanced usages, please stay tuned for our future release. + # Success Stories @@ -201,6 +195,7 @@ nnScaler has been adopted by multiple projects, including both product and resea # Reference --------- +You may find the Artifact Evaluation for OSDI'24 with the guidance [here](https://github.com/microsoft/nnscaler/tree/osdi24ae). Please cite nnScaler in your publications if it helps your research: @inproceedings{lin2024nnscaler, diff --git a/docs/source/autodist/interface_design.md b/docs/source/autodist/interface_design.md index 9d4b3504..da4a164d 100644 --- a/docs/source/autodist/interface_design.md +++ b/docs/source/autodist/interface_design.md @@ -1,11 +1,11 @@ -# Interface Design +# AutoDist Interface Design -Similar to current user experiences in cube, the entrance to *AutoDist* is a function that accept a data flow graph and a resource descriptor as input. The function returns a rewritten graph. The core modules including +Similar to current user experiences in nnScaler, the entrance to *AutoDist* is a function that accepts a data flow graph and a resource descriptor as input. The function returns a rewritten graph. The core modules include: 1. *profile*: build cost models to provide the underlying solver with operator and communication information 2. *dp_solver*: encapsulate existing dynamic programming logic ```python -from cube.graph import IRGraph +from nnscaler.graph import IRGraph def annotate_graph(graph: IRGraph) -> AnnotatedIRGraph: # TODO diff --git a/docs/source/autodist/solver_interface/partition_constraint.md b/docs/source/autodist/solver_interface/partition_constraint.md index d71fc692..ed8cacd7 100644 --- a/docs/source/autodist/solver_interface/partition_constraint.md +++ b/docs/source/autodist/solver_interface/partition_constraint.md @@ -27,7 +27,7 @@ In autodist, we provide a set of partition constraints to control the distribute In this example, we have four partition constraints for the MoE model in retnet. Each partition constraint has 4 fields: `name`, `parent_module`, `allowed_partition_dims`, and `replica_allowed`. -- `name` is the name of the corresponding operator in the model. It equals to the `signature` field in the `IRFwOperation` in cube. Note: signature is the full name of the operator, for example, you should provide `torch.nn.functional.linear` instead of `linear`. +- `name` is the name of the corresponding operator in the model. It equals to the `signature` field in the `IRFwOperation` in nnScaler. Note: signature is the full name of the operator, for example, you should provide `torch.nn.functional.linear` instead of `linear`. - `parent_module` is the **closest** father module name of the operator. You can provide two partition constraints with a same `name` but different `module` to control the partition of the same operator in different modules. Similar to `recompute_modules`, Module name can be any suffix of the full module name, e.g., `module1` will match `x.module1`, `y.module1`, `x.module1` will match `x.module1` but not `y.module1`. - `allowed_partition_dims` is a list of allowed partition dimensions of input tensors. Each element in the list is a list of two integers, which are the index of the partitioned tensor among inputs and the partitioned dimension of that tensor. For example, the annotation of `torchscale.component.xmoe.routing.compute_logits` can be `(C 16) E^ C, E^ C M^ -> (C 16) M^`. `allowed_partition_dims = [[0, 0]]` means we only allow to partition the first input tensor along the first dimension, which is `(C, 16)` in this case. An empty list means no partition is allowed, note that in yaml, you should give an empty list explicitly, i.e., `allowed_partition_dims: []`. - `replica_allowed` is a boolean value. If it is `true`, it is allowed to replicate the operator across devices. @@ -38,6 +38,6 @@ After specifying the partition constraints in a yaml file, we can feed them to a Three examples are provided in `pc_examples` folder. -- `pc_examples/retnet_dp_pc.yaml` helps to generate a pure data parallel plan. -- `pc_examples/retnet_mp_pc.yaml` helps to generate a pure model parallel plan. -- `pc_examples/retnet_hybrid_pc.yaml` helps to generate a hybrid plan: data parallel for the attention module and model parallel for the feed forward module. +- `docs/source/autodist/solver_interface/pc_examples/retnet_dp_pc.yaml` helps to generate a pure data parallel plan. +- `docs/source/autodist/solver_interface/pc_examples/retnet_mp_pc.yaml` helps to generate a pure model parallel plan. +- `docs/source/autodist/solver_interface/pc_examples/retnet_hybrid_pc.yaml` helps to generate a hybrid plan: data parallel for the attention module and model parallel for the feed forward module. diff --git a/docs/source/conf.py b/docs/source/conf.py index 0a698442..013bb11b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,4 +30,9 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = 'furo' -# html_static_path = ['_static'] +html_logo = './images/nnScaler-c-1.png' + +html_static_path = ['_static'] +html_css_files = ['nnscaler.css'] + +mathjax3_config = {'chtml': {'displayAlign': 'left'}} diff --git a/docs/source/self_training.md b/docs/source/control_flow.md similarity index 80% rename from docs/source/self_training.md rename to docs/source/control_flow.md index 26bdf566..07f539df 100644 --- a/docs/source/self_training.md +++ b/docs/source/control_flow.md @@ -1,13 +1,4 @@ -# self.training support - -To parallelize the training process, we firstly need to trace the module and get a static computational graph. - -A common problem with static graph is that it is impossible to handle control flow. - -But on the other hand, `self.training` is very common used in module forward method. -So we add a very limited support for `self.training` in tracing. - -Please note that user code is flattened and transformed into a single `ParallelModule` at runtime, so `training` is a global module state, and we don't support the case that user want to set a sub-module's training to True but remaining modules to False. +# PyTorch control flow ## `if` statement @@ -16,7 +7,7 @@ We don't support any control flow, so For the following code, we only put the `i ```python if self.training: ... -else +else: ... ``` The consequence is that model training/validation will use exactly the same code path. @@ -52,9 +43,9 @@ This trick is not free. It will introduce two side effects: Both branches will be evaluated, so you must make sure that both branches are valid, and have no side effect. To reduce the side effect, we will check true expr/false expr, and requires both don't contain function calls. so the following code will not be converted: - ```python - x = f(a) if self.training else b - ``` +```python +x = f(a) if self.training else b +``` 2. We will convert `if` expression only if the condition is `self.training`. So if a non-module class has a `training` attribute, the `if` expression in its member functions will also be converted if its condition is `self.training`. @@ -64,7 +55,6 @@ For example, you can convert the above code to: import nnscaler import torch - @nnscaler.register_op('?, ? -> ?') def get_dropout(training, dropout): return dropout if training else 0 @@ -74,7 +64,7 @@ torch.nn.functional.scaled_dot_product_attention( dropout_p=get_dropout(self, self.dropout), is_causal=self.is_causal ) -`` +``` ## self.training as a parameter diff --git a/tutorial.md b/docs/source/dimops.md similarity index 98% rename from tutorial.md rename to docs/source/dimops.md index 13f59621..df7d4288 100644 --- a/tutorial.md +++ b/docs/source/dimops.md @@ -1,10 +1,10 @@ -# Dimop Tutorial +# Dimops ## Dimop: Dimension-annotated Operator ### Annotation for Shape Inference and Transformation -SuperScaler uses annotation to represent an operator (Dimop). +nnScaler uses annotation to represent an operator (Dimop). The goal of annotation is for 1). shape inference and 2) transformation plan. To annotate an operator, following example shows the annotation of matrix multiplication. An operator has inputs and outputs. The inputs can be tensors or non-tensors, while outputs are usually tensors. diff --git a/docs/source/faq.rst b/docs/source/faq.rst index 0b40ce88..6fe53a6d 100644 --- a/docs/source/faq.rst +++ b/docs/source/faq.rst @@ -1,50 +1,37 @@ -Frequent asked questions +Frequent Asked Questions ------------------------ -**What is nnScaler?** +* **What is nnScaler?** -The nnScaler is a system that takes a DNN model that is designed for a single device, e.g., GPU, and automatically converts it into a program that can execute concurrently on multiple devices. +nnScaler is a system that converts a Deep Neural Network (DNN) model designed for a single device (e.g., GPU) into a program capable of executing concurrently on multiple devices. -**What can nnScaler do?** +* **What can nnScaler do?** -Under the hood, nnScaler analyzes the given DNN models, plans for appropriate parallelization strategies, and generates corresponding execution code. With nnScaler, users can focus on single-device DNN model design, offload the complex parallelization work to nnScaler, and easily achieve high-performance parallel DNN execution. +Under the hood, nnScaler analyzes the given DNN models, plans appropriate parallelization strategies, and generates the corresponding execution code. This allows users to focus on single-device DNN model design while nnScaler handles the complex parallelization work, enabling high-performance parallel DNN execution with ease. -**What is/are nnScaler’s intended use(s)?** +* **What is/are nnScaler's intended use(s)?** -Due to high compatibility and extensibility, nnScaler can be used for the innovation of a wide range of new DNN models and DNN systems, including new model structures, training patterns, as well as new parallelization techniques that go beyond existing data-parallelism, tensor-parallelism, or pipeline parallelism. +Thanks to its high compatibility and extensibility, nnScaler can innovate a wide range of new DNN models and systems. This includes supporting new model structures, training patterns, and parallelization techniques that go beyond existing data-parallelism, tensor-parallelism, and pipeline parallelism. -**How was nnScaler evaluated? What metrics are used to measure performance?** +* **How was nnScaler evaluated? What metrics are used to measure performance?** -For execution performance, nnScaler can support new parallelisms that outperform existing parallel execution approaches: -1. Fitting larger DNN models given the same hardware. -2. Providing faster execution for the same model on the same hardware (included in our OSDI’24 paper). +For execution performance, nnScaler supports new parallelisms that outperform existing parallel execution approaches: -For compatibility, nnScaler can support paralleling new DNN models by providing user-defined functions (a few lines of code) for the new operators unrecognized by the nnScaler. + 1. Fitting larger DNN models on the same hardware. + 2. Providing faster execution for the same model on the same hardware (as detailed in our OSDI'24 paper). -**What are the limitations of nnScaler? How can users minimize the impact of nnScaler’s limitations when using the system?** +For compatibility, nnScaler supports paralleling new DNN models by allowing user-defined functions (a few lines of code) for operators not recognized by nnScaler. -- Certain DNN model architectures or execution patterns may violate the assumptions of nnScaler and, therefore, cannot be supported by nnScaler. -- The nnScaler does not guarantee the optimality of parallelization, so it is possible for nnScaler to miss the optimal parallelization strategy given DNN model and device settings, while only providing suboptimal solutions. -- Despite our best efforts to ensure the parallelization process is correct, it is possible for nnScaler to generate parallelized programs for concurrent execution that are inconsistent with the original DNN model for a single device. +* **What are the limitations of nnScaler? How can users minimize the impact of nnScaler's limitations when using the system?** -**What operational factors and settings allow for effective and responsible use of nnScaler?** +- Certain DNN model architectures or execution patterns may violate nnScaler's assumptions and cannot be supported. +- nnScaler does not guarantee optimal parallelization. It may miss the best strategy, providing suboptimal solutions. +- Despite efforts to ensure correctness, nnScaler might generate parallelized programs that are inconsistent with the original single-device DNN model. -- We provide documentation to guide users in the usage of the nnScaler. -- We provide parallelization examples that users can directly leverage for parallel execution if they intend to execute the same DNN models. -- We also provide certain cases of customization, including reconfiguring the device settings, adopting new DNN models in nnScaler, and supporting customized operators. +* **License** -**What are extensions(plugins) in nnScaler and how does nnScaler use them?** +- Please visit our `License Information `_ for details. -The nnScaler supports the extension with customized parallelization of DNN modules, allowing new DNN models to be parallelized. During this process, nnScaler will handle the new modules in the same way as those it already supports. +* **Security** -**What can nnScaler provide to extensions(plugins)?** - -The nnScaler provides an easy-to-use interface so users can conveniently realize customized parallelization of certain DNN modules by only implementing a few user-defined functions. - -**What kinds of issues may arise when using nnScaler enabled with extensions(plugins)?** - -- When paralleling new DNN models, users may try some structures or execution patterns that violate the assumptions and fail to support. -- When adapting new DNN models for parallelization, users may incorrectly implement the user-defined function, causing nnScaler to produce incorrect parallelized programs. -- Certain unforeseen mistakes in nnScaler implementation may cause it to produce incorrect parallelized programs without warning, leading to incorrect execution. -- To mitigate unsupported issues, users may disable parallelization for the entire DNN model or certain parts of the model as a workaround. -- To mitigate incorrect execution, users may compare the parallelized programs and original DNN model execution on small datasets to confirm their consistency before deploying to large scale for long-term execution. \ No newline at end of file +- Please visit our `Security Information `_ for details. \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 7934db1e..d8b433a7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,21 +3,132 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. +######## +Overview +######## + Welcome to nnScaler's documentation! ==================================== +Project Website: https://github.com/microsoft/nnscaler + +What is nnScaler? +----------------- + +nnScaler is a parallelization engine that compiles a Deep neural network (DNN) model that designed for single-GPU execution into a program that capable of running in parallel across multiple GPUs. + +.. image:: ./images/nnScaler_flow.png + +System Highlights +----------------- + +* Ease of Use: Enable parallelization with just a few lines of code, producing a Pythonic parallel program easy for further development. +* Extensibility: Seamlessly integrates new operators to support emerging models through nnScaler's exposed API. +* Reliability: Verified through extensive end-to-end training sessions, nnScaler is a dependable system. +* Performance: By exploring a larger parallelization space, nnScaler can significantly enhance parallel training performance. + +``nnScaler`` allows **DNN scientist** to concentrate on model design with PyTorch on single GPU, while leaving parallelization complexities to the system. It introduces innovative parallelism techniques that surpass existing methods in performance. Additionally, nnScaler supports the extension of DNN modules with new structures or execution patterns, enabling users to parallelize custom DNN models. + +``nnScaler`` helps **DNN system experts** to explore new DNN parallelization mechanisms and policies for emerging models. By providing user-defined functions for new operators not recognized by nnScaler, it ensures seamless parallelization of novel DNN models, such as facilitate long sequence support in LLMs. + + +Success Stories +--------------- + +nnScaler has been adopted by multiple projects, including both product and research explorations: + * `(YOCO)You only cache once: Decoder-decoder architectures for language models `_ + * `LongRoPE: Extending LLM context window beyond 2 million tokens `_ + * Post training for the long context version of `Phi-3 series `_ SLMs + + +Get Started +=========== + +* :doc:`quickstart` +* :doc:`llama3_demo_example` +* :doc:`llama3_8b_128k_example` +* :doc:`nanogpt_example` + + +Reference +--------- + +Please cite nnScaler in your publications if it helps your research:: + + @inproceedings{lin2024nnscaler, + title = {nnScaler: Constraint-Guided Parallelization Plan Generation for Deep Learning Training}, + author={Lin, Zhiqi and Miao, Youshan and Zhang, Quanlu and Yang, Fan and Zhu, Yi and Li, Cheng and Maleki, Saeed and Cao, Xu and Shang, Ning and Yang, Yilei and Xu, Weijiang and Yang, Mao and Zhang, Lintao and Zhou, Lidong}, + booktitle={18th USENIX Symposium on Operating Systems Design and Implementation (OSDI 24)}, + pages={347--363}, + year={2024} + } + +You may find the Artifact Evaluation for OSDI'24 with the guidance `here `_. + +Contributing +------------ + +This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. + +When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the `Microsoft Open Source Code of Conduct `_. For more information, see the `Code of Conduct FAQ `_ or contact opencode@microsoft.com with any additional questions or comments. + +Trademarks +---------- + +This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow `Microsoft's Trademark & Brand Guidelines `_. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third-party's policies. + +Contact +------- + +You may find our public repo from https://github.com/microsoft/nnscaler or microsoft internal repo https://aka.ms/ms-nnscaler. +For any questions or inquiries, please contact us at nnscaler@service.microsoft.com. + .. toctree:: :maxdepth: 1 - :caption: Contents: + :hidden: + :caption: Get Started - readme + self + install_from_source quickstart + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Examples + + llama3_demo_example + llama3_8b_128k_example nanogpt_example + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Basic Usages + + trainer pytorch_lightning - parallel_module register_custom_op - parallel + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Advanced Usages + + parallel_module + dimops + verify_op + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Miscellaneous + + control_flow faq + troubleshooting Indices and tables ================== diff --git a/docs/source/install_from_source.rst b/docs/source/install_from_source.rst new file mode 100644 index 00000000..4af9b996 --- /dev/null +++ b/docs/source/install_from_source.rst @@ -0,0 +1,37 @@ +################### +Install from Source +################### + +************** +Clone the Repo +************** + +The nnScaler repository is hosted on GitHub. + +:: + + git clone https://github.com/microsoft/nnscaler + +**************** +Editable Install +**************** + +nnScaler uses ``pybind11`` and ``cppimport`` to speedup partitioning. +The c++ modules must be manually compiled for an editable install. + +:: + + cd nnscaler + pip install -e . + python -c "import cppimport.import_hook ; import nnscaler.autodist.dp_solver" + +************* +Build a Wheel +************* + +:: + + cd nnscaler + pip install build + python -m build + pip install dist/nnscaler-*.whl diff --git a/docs/source/llama3_8b_128k_example.rst b/docs/source/llama3_8b_128k_example.rst new file mode 100644 index 00000000..26701d6b --- /dev/null +++ b/docs/source/llama3_8b_128k_example.rst @@ -0,0 +1,152 @@ +####################### +Llama 3 8B 128K Example +####################### + +************ +Introduction +************ + +This example demonstrates how to train llama3-8B-128k model with 8xH100s or 8xA100s. + +************ +Requirements +************ + +To run this example, you need to install the following packages: :: + + nnscaler + transformers==4.40.0 + datasets==2.20.0 + apex + flash-attn + +*nnScaler* is a framework for distributed training by automatically partitioning the model. +Apart from the core nnScaler library, it also includes a mini-trainer for modern model training. + +*transformers* and *datasets* are required to prepare the data and loading the Llama model. + +To speed up the training process, +`apex `_ and `flash-attn `_ are required. +You can install them by following instructions in their official repositories. +You may also launch the script in an Nvidia docker container, e.g., ``nvidia/pytorch:24.02-py3``. + +**************** +Data Preparation +**************** + +We use the `bookcorpus `_ dataset for training, which is tokenized with the `Meta-Llama-3-8B-Instruct `_ tokenizer. +Tokenized data is saved in the ``bookcorpus_llama3_128K`` directory. + +.. code-block:: bash + + python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_128K --sequence_length 131072 + +******** +Training +******** + +nnScaler adopts a compiler approach to launch the distributed training, which consists of two stages: + +#. Compile stage: trace the original PyTorch model into a dataflow graph, analyzing the graph to get an efficient plan for distributed training, and + generate python code based on the plan. +#. Runtime stage: run the generated python code to train the model. + +**Note**: We recommend to use well-tested config ``"_attn_implementation": "flash_attention_2"`` and ``"use_cache": false`` in the config file. + +Register Customized Function +============================ + +Llama3's vocabulary size is about 128K, which is much larger then the 32K in Llama2. At the same time the sequence length in this example is 128K, the output tensor size of the last projection layer is quite large: 128K x 128K x 2 bytes = 32GB. +Although this tensor can be partitioned evenly to 8 GPUs, 4GB memory is still quite large for limited GPU memory. What makes it worse is that we need to store additional 8GB for `log_softmax` and `cross_entropy_loss` computation. +In order to reduce the memory consumption: + +* we split the input sequence on each device to chunks of 1K tokens +* for each chunk, we recompute a function which is composed of last projection layer, log_softmax and loss +* as a result, we only need to store the input tensor to the last projection layer, whose initial size is 128K x 4K x 2 bytes = 1GB, which is much smaller than 32GB + +You can find the detailed implementation in ``chunk_linear_cross_entropy.py``. +The interface of the ``chunk_linear_cross_entropy`` function is ``(hidden_states: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor``, where + +* ``hidden_states`` is the output of the last transformer layer, with shape ``[batch_size, sequence_length, hidden_size]`` +* ``weight`` is the weight matrix of the last projection layer, with shape ``[vocab_size, hidden_size]`` +* ``labels`` is the target labels, with shape ``[batch_size, sequence_length]`` +* ``padding_idx`` is the padding index +* ``chunk_size`` is the size of the chunk, default is 1024 + +We want to register this function to nnScaler and tell it to partition this function along batch size or sequence dimension. A possible annotation is ``b l d^, n^ d^, b l -> b l``. Here ``b`` stands for batch size, ``l`` stands for sequence length, ``d`` stands for hidden size, and ``n`` stands for vocab size. The ``^`` means the dimension cannot be partitioned. More details about the annotation can be found in related documents. + +Compile +======= + +.. code-block:: bash + + python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee compile.log + + +Run +=== + +.. code-block:: bash + + torchrun --nproc_per_node=8 train.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee run.log + + +**Note**: You may directly the ``Run`` command which will compile implicitly, but for clearer log and debug information, we recommend to run ``Compile`` command explicitly before runtime stage. + +Checkpoint +========== + +This script will save the model checkpoint in the ``./checkpoints`` directory. You can change the checkpoint directory by updating the ``CheckpointConfig`` in the ``train.py`` script. + +nnScalar saves checkpoints in shards: each rank may save parameters and optimizer states in a file. These checkpoints can be directly loaded by nnScaler if the partitioning strategy is the same. If you want to evaluate the checkpoints on downstream tasks, you need to merge the shards into a single file. You can use the following command to merge the shards: + +.. code-block:: bash + + python ckpt_merger.py --ckpt_dir ./checkpoints --output_fname ./merged.ckpt + +The merged checkpoint can be loaded by nnScaler by setting the ``--resume_path`` option to the merged file. + +If the script is modified for different hardware configurations. + +* All sharded checkpoint files should be collected and placed in a same directory before ``ckpt_merger.py`` is called. +* If the config is changed (plan_ngpus/runtime_ngpus/etc), the sharded checkpoint can not be used anymore. You need to merge them so the trainer can load from merged checkpoint. + +*********** +Performance +*********** + +The flops of the forward computation for Llama3 is + +.. math:: 2 \cdot ( param\_num \cdot seqlen + 2 \cdot layer\_num \cdot hidden\_dim \cdot seqlen ^ 2) + +For the 8B model, the forward flops is about 11104.35 TFLOPs. The detailed config is as following: + + +* .. math:: param\_num = 8 \times 10^9 +* .. math:: seqlen = 128 \times 1024 +* .. math:: layer\_num = 32 +* .. math:: hidden\_dim = 4096 + +Generally, the computational cost of backpropagation is twice that of the forward pass. In addition, the gradient accumulation number is set to 4. As a result, the flops for a step of the training script is 133252.22 TFLOPs. + +We execute the training script on a node with 8xH100 80GB HBM3. The time cost is about 41.12s for a step. The theoretical BF16 computational speed of the H100 is 989 TFLOPS. Combine them together, this script can achieve 40.96% MFU. You can further optimize the performance by + +* add more devices to avoid recomputation: in order to fit the model into the memory, we recompute by layer. +* do more kernel optimizations. For example, the swiglu activation can be fused into the matmul ahead of it. + +********* +Debugging +********* + +Since the 128K config is challenging, it is recommended to use a smaller model for debugging. For example, you can use the following command to prepare data and train a smaller llama3 (same architecture, but with 4 decoder layers) model on two GPUs. + +.. code-block:: bash + + ## prepare data + python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 + + ## build the mini model + python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini + + ## compile and run using data parallelism + zero1 + torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K diff --git a/docs/source/llama3_demo_example.rst b/docs/source/llama3_demo_example.rst index 12eff6ec..b10bdc33 100644 --- a/docs/source/llama3_demo_example.rst +++ b/docs/source/llama3_demo_example.rst @@ -1,6 +1,6 @@ -############### -Llama 3 Example -############### +############ +Llama 3 Demo +############ This is an example demostrating how to train Llama 3 8B with nnScaler's :doc:`trainer `. @@ -17,21 +17,19 @@ Installation export HF_TOKEN=... -1. Install nnScaler :: +1. Clone nnScaler repo :: - pip install nnscaler + git clone --recursive https://github.com/microsoft/nnscaler -2. Clone nnScaler repo to get the example :: - - git clone --recursive https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube - cd MagicCube/examples/llama3_demo - -3. Install Llama 3 dependencies :: +2. Install dependencies (including Llama 3 dependencies) and :doc:`nnScaler from source ` :: + cd nnscaler pip install -r requirements.txt + pip install -e . - Note: The requirements file has pinned ``torch``, ``transformers``, and ``datasets`` versions - to ensure their compatibility with each others. +3. Find the Llama 3 example :: + + cd nnscaler/examples/llama3_demo 4. Prepare dataset :: @@ -44,30 +42,32 @@ Installation Train a Mini-model ================== -This examples requires 8 × 80GB GPU memory to train a full 8B model. -If your have adequate GPUs, you can skip to :ref:`the next section `. +This examples requires 8 x 80GB GPU memory to train a full 8B model. +If your have qualified GPUs, you can go to :ref:`the next section `. -Alternatively, you can start from a smaller model for verification: :: +Alternatively, you may start from a smaller model for verification: :: python train.py --prepare_data --mini torchrun --nproc_per_node=2 train.py --mini -This will resize Llama 3 to 4 hidden layers and reduce max sequence length to 4K. -We have tested it with 2 × 48GB memory. +This will resize Llama 3 into a model with 4 hidden layers and max-sequence-length reduced to 4K (4096). +We have tested it with 2 x 48GB GPUs. -If the model is still too large, you can shrink it further: :: +You may further shrink it if the model is still too large: :: python train.py --prepare_data --max_seq_len=1024 torchrun --nproc_per_node=2 train.py --max_seq_len=1024 --num_hidden_layers=2 --from_scratch -With the default mini config (4 layers, 4K sequence length), the loss curve will be like following: +Here is the training loss with the default mini config (4 layers, 4K sequence length): .. image:: ./images/llama3-curves-mini.png +.. _finetune: + Finetune Llama 3 8B =================== -Use the following commands to finetune `Meta-Llama-3-8B-Instruct `: :: +Use the following commands to finetune `Meta-Llama-3-8B-Instruct `_: :: python train.py --prepare_data torchrun --nproc_per_node=8 train.py @@ -78,13 +78,13 @@ Use the following commands to finetune `Meta-Llama-3-8B-Instruct `_ with nnScaler and `Lightning `_ trainer. @@ -14,23 +14,17 @@ Get Started Installation ============ -1. Install nnScaler :: +1. Clone nnScaler repo :: - pip install nnscaler + git clone --recursive https://github.com/microsoft/nnscaler -2. Clone nnScaler repo to get the example :: - - git clone --recursive https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube - cd MagicCube/examples/nanogpt - -.. - FIXME: update url to github? - -3. Install nanoGPT's dependencies :: +2. Install dependencies (including nanoGPT's dependencies) and :doc:`nnScaler from source ` :: + cd nnscaler pip install -r requirements.txt + pip install -e . -4. Prepare dataset :: +3. Prepare dataset :: python nanoGPT/data/shakespeare_char/prepare.py @@ -47,7 +41,7 @@ It will take several minutes and the best validation loss will be around 1.47. Get Distributed =============== -nnScaler is meant for distribution. For v0.1 release, we are focusing on data parallel. +nnScaler is meant for distribution. For the current release, we are focusing on data parallel. If you have 4 GPUs on one node: :: @@ -64,7 +58,7 @@ NOTE: The local batch size is fixed by default, so using more workers will resul Tensor Parallel (Experimental) ============================== -nnScaler will support tensor parallel and hybrid parallel in the next release. +nnScaler will support tensor parallel and hybrid parallel in following release. You can try this feature now, but its stability and parity has not been strictly verified yet. Using data parallel: (each model instance runs on 1 GPU, 4 instances using DP) :: @@ -82,7 +76,7 @@ Using hybrid parallel: (each model instance runs on 2 GPUs, 2 instances using DP Resuming ======== -You may resume a interrupted training: :: +You may resume an interrupted training: :: torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --init_from=resume @@ -141,7 +135,7 @@ To parallelize the lightning model with nnScaler, there are 2 noteworthy places: Other parameters are used for performance (efficiency) tuning. -For details, please check the :doc:`API reference `. +.. For details, please check the :doc:`API reference `. ********************** Parity and Limitations @@ -188,6 +182,6 @@ The Lightning Port The Lightning port is not exactly the same as the original nanoGPT training script for the following reaons: 1. The Lightning ``Trainer`` is different from nanoGPT's training loop. -2. nnScaler v0.1 lacks the support for multiple parameter groups, and therefore the weight decay is configured for all parameters. +2. nnScaler currently lacks the support for multiple parameter groups, and therefore the weight decay is configured for all parameters. .. image:: ./images/nanogpt-curves-orig.png diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index ae24fd96..3cd1df5f 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -1,12 +1,10 @@ -# Parallel Module +# Paralleling a Module -nnScaler can parallelize a `torch.nn.Module` to a parallel module. -A parallel module is a special `torch.nn.Module` but runs in multiple gpus/nodes. -All the complexity of distributed training/inferring is hidden from the user. +nnScaler can transform a `torch.nn.Module` into a parallel module, which is a specialized version of `torch.nn.Module` capable of running across multiple GPUs or nodes. This process hides the complexity of distributed training and inference from the user. -Currently we support three kinds of parallelism: data parallelism, tensor parallelism and pipeline parallelism (model parallelism). We can also combine them to get the best performance. +Currently, we support three kinds of parallelism: data parallelism, tensor parallelism and pipeline parallelism. We can also combine them for better performance. -Data parallelism and tensor parallelism are support for all kinds of module, but pipeline parallelism is only supported for end2end modules for scheduling reason. +Data parallelism and tensor parallelism can be supported for any module, but pipeline parallelism is only supported for end2end modules for scheduling reason. An end2end module is a module which satisfies: - the first argument of `module.forward` is the data sample, and every other argument should have default value, and use its default value in `module.forward` function. @@ -171,155 +169,6 @@ def train(model: ParallelizedPipelinedLLM, data): optimizer.zero_grad() ``` -## APIs - -### ComputeConfig -The configuration of the compute environment. It is a dataclass with the following fields: -```python - -@dataclass(frozen=True) -class ComputeConfig: - plan_ngpus: int - runtime_ngpus: int - - constant_folding: bool = False - trace_strategy: Literal['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'] = 'cuda_run_cpu_offload' - - use_zero: bool = False - zero_ngroups: int = 1 - zero_use_reduce_scatter: bool = False - - inference_only : bool = False - use_end2end: bool = False - - use_async_reducer: bool = False - reducer_bucket_cap_mb: Optional[float] = None - - pas_config: Dict[str, Any] = field(default_factory=dict) - user_config: Dict[str, Any] = field(default_factory=dict) -``` -We can categorize the fields into 4 categories: - -1. Trace configuration - - `constant_folding`: whether to enable constant folding when generating code. - When it is true, all non-tensor non-input values will be folded into the generated code. - - For example, if user's code contains following snippet, and `bsz=1`, `num_heads=32`, `len=1024`, `hidden_dim=128` at tracing. - ```python - bsz, num_heads, len, hidden_dim = x.size() - x = x.view(bsz * num_heads, len, hidden_dim) - ``` - The code (graph) is folded into the following format - - ```python - y = x.view(32, 1024, 128) - ``` - - Constant folding is helpful to simplify the input program, - and can make the compiling process faster and reduce the communication cost at runtime. - However, user should make sure that inputs at runtime share a same schema (including shape) with tracing and correspond to a same computation graph. - Errors may be raised at runtime when this assumption is broken. - - `trace_strategy`: how to execute the functions during trace. - Five strategies are supported: - 1. `cpu`: Execute all functions on cpu device, model weights and intermediate results are on cpu device. - 2. `cuda`: Execute all functions on cuda device, model weights and intermediate results are on cuda device. This strategy is recommended if the model can inference on single gpu. - 3. `meta`: Execute all functions on meta device, model weights are on cpu and intermediate results are on meta device. For more information about meta device type, please view https://pytorch.org/docs/stable/meta.html. - 4. `cuda_run_cpu_offload`: Try to execute all functions on cuda, and retry to execute the function on cpu as backup if OOM is catched, model weights and intermediate results are on cpu. This strategy is recommanded for most case if the model is too large to inference on single gpu. - 5. `reuse_cache`: Compared to `cuda_run_cpu_offload` strategy, maintains a map from function signatures to output values. The cached output is returned when the signature of the function that generates it has been executed. Same signature means the funtions are the same and have almost the same inputs (for tensor type input, just check if they have same tensor meta data[shape, dtyep, requires_grad, stride, memory_format, ...], and don't check the value). This strategy is an experimental strategy to speedup the large-model-large-input case, and have risk to trace an incorrect graph if the signature defined here can not distinguish the differnet functions used in the model, for example, torch.nonzero will always return the same result if the input have same meta data but different value. We have plan to continue improve this strategy to handle most these kind of data dependence cases, but please note that the risk is still inevitable. -2. Compute environment configuration - - `plan_ngpus`: the number of gpus to be used as a unit. The model is partitioned (TP or PP) within a unit, and then data parallelism is applied across multiple units. So every `plan_ngpus` devices holds the whole model. Furthermore, assume we have two workers, and their ranks are `rank1` and `rank2`: - 1. if `rank1 // plan_gpus == rank2 // plan_ngpus`, then they are in the same unit. - 2. If `rank1 % plan_ngpus == rank2 % plan_ngpus`, then the portion of model hold on both gpus are exactly the same. - - `runtime_ngpus`: the number of gpus to be used in runtime. It should be a multiple of `plan_ngpus`, which means we have `runtime_ngpus // plan_ngpus` units in runtime, and the data parallelism is `runtime_ngpus // plan_ngpus`. - Please note all modules must have the same `plan_ngpus` and `runtime_ngpus`. -3. Code generation feature configuration - - `use_zero`: whether to use zero. If it is true, the generated code will use zero1 to do distributed training. - - `zero_ngroups`: the number of groups to be used in zero. - - `zero_use_reduce_scatter`: whether to use reduce scatter in zero. If it is true, the gradients will be reduced by reduce scatter in zero. - - Please note - - Reduce scatter is only available when `zero_ngroups` is 1. when `zero_ngroups` > 1, you should set it to `False`, or an error will be raised. - - In some cases, it can introduce parity issue. So use it with caution. - - `inference_only`: whether to generate code for inference only. If it is true, the generated code can not be used to train the model. - - `use_end2end`: whether to use end2end training. For the requirement of end2end, see the description above. - - `use_async_reducer`: whether to use async reducer. - If it is true, the gradients will be reduced asynchronously. - Please note this only works when `use_end2end` is true. - - `reducer_bucket_cap_mb`: the bucket capacity of the reducer. - If it is `None` or `0`, the default value will be used, which is - - 25MB for async, the same default value with pytorch ddp implementation - - no limit for sync - - Please note this only works when `use_end2end` is true. - - `pas_config`: the configuration for the PAS policy (partition-assign-schedule policy, which describes how to place all computations across devices. For details, please refer to [PAS Policies](#pas-policies)). - It is a dictionary, and will be used by the PAS policy. - Please note different PAS will have different configurations, - You can also put any other settings that can affect code generation here. but please prefix the keys with `_` to avoid conflicts with PAS configurations. - - `user_config`: the user configuration, which is used to decide whether skipping compiling and reusing the previously traced graph. - -Note: -1. You can put any custom configurations in `user_config`. The assumption is different `user_config` should generate different graph/code. So if the user config is changed, we will regenerate the graph/code automatically. Here are some examples: - - - Example 1: save module configuration - ```python - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): - ... - if module_config.use_3d: - ... - ``` - here we can set `user_config` to `{'use_3d': module_config.use_3d}`, - and we can be sure different use_3d config will never use the same graph (and eventually the generated code). - - - Example 2: save file stats - If you want to track all related file stats (just like traditional compilers do), - you can save the md5 of the files to save some bytes: - ```python - import hashlib - h = hashlib.md5() - for f in Path('./src').glob('**/*.py'): - with open(f, 'rb') as f: - h.update(f.read()) - compute_config = { - ...., - user_config: { - 'files_md5': h.hexdigest() - } - } - ``` -2. If some settings doesn't affect tracing/graph generation, but do affect code generation, you can put them in `pas_config`. Please prefix the keys with `_` to avoid conflicts with predefined PAS configurations. One typical example is you can put the name of selected PAS policy in `pas_config`, so changing PAS policy will regenerate code but the graph will be reused. - - ```python - compute_config = ComputeConfig( - ... - pas_config={ - '_pas_name': ..., - # PAS policy specific configurations - ... - }, - ) - ``` - -### ReuseType - -The reuse policy for the existing generated code. It is an enum with the following values: - -```python -class ReuseType(Enum): - MATCH = 'match' - OVERRIDE = 'override' - MOO = 'moo' - GRAPH = 'graph' -``` -We call it a `match` when the `ComputeConfig` is the same with the previous run. - -1. `MATCH`: Reuse if match, error if not match, generate if no previous gerenated code exists. -2. `OVERRIDE`: Nothing will be reused. Everything will be regenerated. -3. `MOO`: `MOO` is short for 'match or override'. It will reuse if match, generate if not match or no previous generated code exists. -4. `GRAPH`: Reuse graph only if match, generate otherwise. - ### BroadcastGenFilesStrategy The broadcast strategy for new generated files. @@ -568,86 +417,6 @@ def infer_step(self, samples: List[Any]) -> List[Any]: The inference step function. It should be called in the inference loop. The input is a list of samples, and returns a list of outputs for the samples. If pipeline is used, it must have the same length as configured to pas policy. -### PAS Policies - -Writing a pas policy can be very hard and error-prone. So we provide 6 builtin PAS policies to help you. `dp`, `tp`, `pp`, `data`, `hybrid`, and `autodist`. Please note only `autodist` policy is the recommended policy for most cases, and all other PAS policies are mainly test purpose only. - -The configuration of the PAS policy should be passed in the `pas_config` of `ComputeConfig` as a dictionary. - -1. `dp`: data parallelism. It will replicate the module across all devices, and run data parallelism across all devices. It requires the `plan_ngpus` must be 1 and no configurations - -2. `tp`: tensor parallelism + data parallelism. It will do tensor parallelism inside a scale unit, and run data parallelism across scale units. It has only one configuration: - - seed: the random seed for choose the partition dimension. Default is `1` - -3. `pp`: pipeline parallelism + data parallelism. -It will do model parallelism inside a scale unit, -and run data parallelism across scale units. -It requires the `use_end2end` be true. -It has two configurations `pipeline_nmicros` and `pipeline_scheduler`. -See `hybrid` policy for more details. - -4. `data`: tensor parallelism on batch dimension. It has no configurations. - -5. `hybrid`: pipeline parallelism + tensor parallelism + data parallelism. -It will do model parallelism and tensor parallelism(on 0 dimension) inside a scale unit, -and run data parallelism across scale units. -It requires the `use_end2end` to be true. It has the following configurations. - - `pipeline_nstages`: the number of stages in the pipeline. Default is `plan_ngpus`. Optional. - - `pipeline_nmicros`: the number of microbatches in the pipeline. Required. - - `pipeline_scheduler`: the scheduler name for the pipeline. Current we support four schedulers in training `1f1b`/`1f1b_plus`/`gpipe`/`chimera_direct` (4 stages pipeline only), and one scheduler in inference `infer_pipe`. Default is `1f1b`. Optional. - -6. `autodist`: the recommended policy for most cases. Currently it only support Adam-like optimizers. It will automatically choose the best partition for you by balancing the memory usage and speed. It has the following configurations. - - `update_freq (int)`: the update frequency when training the module. Default is 1. Optional. - - `mem_constraint (float)`: The memory constraint in each device in GB. Optional. - - `task_name (str)`: The name of the current task to distinguish runs. Optional. - - `use_fp16 (bool)`: Whether you use `fp16`. Default is `False`. Optional. - - `use_memory_efficient_fp16` Whether you use memory efficient fp16 optimizer. Default is `False`. Optional. - - `use_bf16`: Whether you use `bf16`. Default is `False`. Optional. - - `use_memory_efficient_bf16`: Whether you use memory efficient bf16 optimizer. Default is `False`. Optional. - - `re_profile (bool)`: If set to `True`, the computation profiling results will be overridden. Please note reprofiling will take some time. Optional. - - `verbose (bool)`: Whether to print verbose information. Optional. - - `load_plan_path (str)`: The path to the plan file to load. If specified, the plan will be loaded from the file instead of searching. Optional. - - `save_plan_path (str)`: The path to the plan file to save. Optional. - - `partition_constraints_path (str)`: The path to the partition constraints file. Optional. - - `recompute_modules (str)`: The module names to recompute, separated by `,`. For example, `module1,module2`. Optional. - - `pipeline_pivots (str)`: The module names to pivot the pipeline, separated by `,`. For example, if `module1,module2` is specified, stages searched by pipeline solver only start from either `module1` or `module2`. Optional. - - `use_apex_fused_adam_v2`: If set to `True`, the apex fused adam v2 optimizer will be used. Default is `False`. Optional. - - `explore_pipeline`: If set to `True`, autodist will try pipeline parallelism to find the best partition plan - (but the selected partition plan is not necessarily pipeline parallelism). - - `pipeline_scheduler`: The scheduler name for the pipeline. Please note currently `1f1b` is the only supported scheduler in `autodist`. Default is `1f1b`. Optional. - - `parallel_profile`: If set to `True`, autodist will profile operators in parallel by using available gpus. Default is `True`. Optional. - - `max_partition_degree`: Max degree when partitioning an operator / node. When pipeline parallelism is enabled to explore (`explore_pipeline` is True), user can change the value to constrain the plan to be composed of stages that span on less or equal to `max_partition_degree` devices (recommend to set `max_partition_degree` to the number of devices in a node to avoid inter-node communication, but should be be no more than `plan_ngpus`). Default is `plan_ngpus`. Optional. - - `transient_mem_coef`: In autodist, a heuristic is used to estimate the transient memory size: `transient_mem_size = opt_transient_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)`. This formula is useful in many cases, but it may be too strict when some operators consume or generate a large tensor (>= 4GB). In this case, you can set `transient_mem_coef` to a smaller value to relax the constraint. Default is `2`. Optional. - - You can also put any other settings that can affect code generation here. but please prefix the keys with `_` to avoid conflicts with predefined keys. - -Here is an example: -```python -compute_config = ComputeConfig( - plan_ngpus=..., - runtime_ngpus=..., - use_zero=..., - pas_config={ - '__pas_name': ..., # addtional configurations that can affect code generation. - 'update_freq': ..., - 'mem_constraint': ..., - 'task_name': ..., - 'use_fp16': ..., - 'use_memory_efficient_fp16': ..., - 'use_bf16': ..., - 'use_memory_efficient_bf16': ..., - 're_profile': ..., - 'verbose': ..., - 'load_plan_path': ..., - 'save_plan_path': ..., - 'partition_constraints_path': ..., - 'recompute_modules': ..., - 'pipeline_pivots': ..., - 'use_apex_fused_adam_v2': ..., - }, -) -``` - ### Checkpoint support You can save/load the checkpoints for parallel modules. @@ -740,3 +509,14 @@ def create_distributed_sampler(dataset): ..., ) ``` + +### self.training support + +To parallelize the training process, we firstly need to trace the module and get a static computational graph. + +A common problem with static graph is that it is impossible to handle control flow. + +But on the other hand, `self.training` is very common used in module forward method. +So we add a very limited support for `self.training` in tracing. + +Please note that user code is flattened and transformed into a single `ParallelModule` at runtime, so `training` is a global module state, and we don't support the case that user want to set a sub-module's training to True but remaining modules to False. \ No newline at end of file diff --git a/docs/source/pytorch_lightning.md b/docs/source/pytorch_lightning.md index a66bc082..8ca793ac 100644 --- a/docs/source/pytorch_lightning.md +++ b/docs/source/pytorch_lightning.md @@ -1,4 +1,5 @@ -# Pytorch Lightning support +# PyTorch Lightning +[//]: # (# Pytorch Lightning support) We support Pytorch Lightning by `NnScalerStrategy` and `NnScalerPrecision`. You can use `nnscaler` strategy in pytorch lightning like this: @@ -75,8 +76,9 @@ def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str) -> None ``` where `checkpoint_files` is a list of checkpoint files to merge, and `output_file` is the output file path. -## Limitation +## Limitations -1. Only one optimizer is supported. -2. Only one lr scheduler is supported. -3. Only one parameter group is supported. +Currently, nnScaler only supports: +- single parameter group. +- single optimizer. +- single learning rate scheduler. \ No newline at end of file diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 0c9732e7..2349a869 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -1,117 +1,190 @@ -########### -Get Started -########### - -Repo address: https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube - -The nnScaler repo is currently internal. -If you do not have access, please contact cubedev@microsoft.com +########## +Quickstart +########## ************ Installation ************ -To get started, install the latest wheel from -`DevOps Artifacts `_. - -If you are familiar with Azure stuffs, you can follow DevOps' guide to set up the repository. - -Or if you prefer the simpler way, download the ``.whl`` file in the "Files" section of the website, -and install it locally: +nnScaler can be :doc:`installed from GitHub `: :: -:: - - python -m pip install nnscaler-*.whl - -********** -Quickstart -********** + git clone --recursive https://github.com/microsoft/nnscaler + cd nnscaler/ + pip install -r requirements.txt + pip install -e . -The next step depends on your choice of the training framework. +*************************** +Parallelize a Minimal Model +*************************** -- **No framework**: if you write your own training code and do not use a framework, - see :ref:`Parallelize API` section. -- **Fairseq**: if you use fairseq, see :ref:`Fairseq` section. -- **Lightning**: TODO +You can verify the installation by parallize a minimal model: -.. _Parallelize API: +.. code-block:: python -Parallelize API -=============== - -TODO: write a hello world example, assigned to Zhe Liu - -If you write your own training code, you can use the *parallelize* API to make your model parallel: - -:: + # model.py + import os import torch - from nnscaler import parallelize, ComputeConfig, build_optimizer - - class LLM(torch.nn.Module): - def __init__(self, ...): - ... - def forward(self, x): - ... - - llm_sample_input = ... # dummpy input will be used to do tracing - pas_policy = ... # the PAS policy, you can use autodist pas - compute_config = ComputeConfig( - plan_ngpus=..., - runtime_ngpus=..., - use_zero=..., - ..., - ) # compute environment config - ParallelizedLLM = parallelize( - LLM, - {'x': llm_sample_input}, - pas_policy, - compute_config, - ) - -Example -------- - -An example of the parallelize API is provided in the repo: -`train.py `_ + from nnscaler.cli.trainer import Trainer + from nnscaler.cli.trainer_args import * + from nnscaler.utils import set_default_logger_level + + set_default_logger_level('INFO') + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(100, 10) + + def forward(self, data): + x = self.linear(data['x']) + return torch.nn.functional.cross_entropy(x, data['y']) + + class RandomDataset: + def __init__(self, split): + pass + + def __getitem__(self, i): + return { + 'x': torch.rand(100), + 'y': torch.randint(10, tuple()), + } + + def __len__(self): + return 100 + + if __name__ == '__main__': + world_size = int(os.getenv('WORLD_SIZE', 1)) + trainer_args = TrainerArgs( + compute_config=ComputeConfig(plan_ngpus=1, runtime_ngpus=world_size, use_end2end=True), + model=ModelConfig(type=Model), + optimizer=OptimizerConfig(type=torch.optim.AdamW), + dataset=DatasetConfig(type=RandomDataset, train_args={'split': 'train'}), + max_train_steps=10, + enable_progress_bar=False, + ) + trainer = Trainer(train_args=trainer_args) + trainer.run() + +To run it in parallel, use `torchrun `_: :: + + torchrun --nproc_per_node=2 model.py + +Expected output: + +.. (FIXME: adjust log level) -You can download and try it: :: - - torchrun --nproc_per_node=4 --nnodes=1 train.py - -Documentation -------------- - -If the example works for you, you can now follow the documentation to parallelize your model: -:doc:`parallel_module` - -.. _Fairseq: - -Fairseq (TODO) -======= - -.. TODO: - - nnScaler provides `fairseq integration `_. - - TODO: refine the example (and its doc), assigned to Youshan Miao - - TODO (long term): write an example using unmodified fairseq - - Installation - ------------ - - To use fairseq, clone the fork and install it: :: - - python -m pip uninstall fairseq - - git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Fairseq - cd Fairseq - python -m pip install -e . - - Example - ------- - - Follow the example - `here `_. +:: + 2024-09-09 20:28:04 | INFO | nnscaler.graph.parser.converter | constant folding disabled to parse graph + 2024-09-09 20:28:04 | WARNING | nnscaler.graph.graph | nnScaler does not support to compute gradients for IRPyFunc. + Following nodes require gradients, this may trigger error in backward: + _operator.getitem, cid: 1 + + 2024-09-09 20:28:04 | WARNING | nnscaler.graph.segment | nnScaler does not support backward of IRPyFunc: PyOp1-()(sign=getitem, inputs=(Object(data35, val={'x': t32(p30,(1, 100),d(),v(0/1)), 'y': t34(p33,(1,),d(),v(0/1))}, is_constant=False), 'x'), outputs=(t25(p4,(1, 100),d(),v(0/1)),)), skip setting gradient, please register it as IRDimOps. + 2024-09-09 20:28:04 | INFO | nnscaler.autodist.apis | AutoDistConfig {'pc_path': '', 'profile_dir': PosixPath('/home/.cache/nnscaler/autodist/1.0/NVIDIA_RTX_A6000'), 'topk': 20, 'task_name': '__1gpus_1update_freq', 'load_plan_path': None, 'save_plan_path': None, 'consider_mem': True, 'zero_stage': 0, 'zero_ngroups': 1, 'opt_resident_coef': 2, 'opt_transient_coef': 0, 'is_train': True, 'mesh_desc': MeshDesc(row=1, col=1), 'ngpus': 1, 'recompute_modules': '', 'memory_constraint': 40802189312, 'memory_granularity': 524288, 'micro_batch_size': 1, 'update_freq': 1, 'world_size': 1, 'nproc': 1, 'ignore_small_tensor_threshold': 524288, 'verbose': False, 're_profile': False, 'pipeline': False, 'pipeline_pivots': '', 'pipeline_nstages': 1, 'pipeline_scheduler': '1f1b', 'max_pipeline_bubble_ratio': 0.2, 'max_pipeline_unbalance_ratio': 0.5, 'solver': 'dp', 'parallel_profile': True, 'transient_mem_coef': 2} + 2024-09-09 20:28:04 | WARNING | nnscaler.autodist.cost_database | Communication profile data not found, using default data at /home/nnscaler/nnscaler/resources/profile/mi200/comm + 2024-09-09 20:28:04 | INFO | nnscaler.autodist.cost_database | Profiling in parallel + 2024-09-09 20:28:06 | INFO | nnscaler.autodist.cost_database | device 0 finished profiling 1 nodes + 2024-09-09 20:28:06 | INFO | nnscaler.autodist.cost_database | device 2 finished profiling 0 nodes + 2024-09-09 20:28:06 | INFO | nnscaler.autodist.cost_database | device 1 finished profiling 1 nodes + 2024-09-09 20:28:06 | INFO | nnscaler.autodist.cost_database | device 3 finished profiling 0 nodes + 2024-09-09 20:28:07 | WARNING | nnscaler.autodist.model_graph | detect a non-IRDimops _operator.getitem at File "/home/nnscaler/test.py", line 16, in forward, x = self.linear(data['x']) that produces tensors + 2024-09-09 20:28:07 | WARNING | nnscaler.autodist.model_graph | detect a non-IRDimops _operator.getitem at File "/home/nnscaler/test.py", line 17, in forward, return torch.nn.functional.cross_entropy(x, data['y']) that produces tensors + 2024-09-09 20:28:07 | INFO | nnscaler.autodist.model_graph | + -------------------------nnScaler Graph Profiling Result------------------------- + + depth 1 + param_mem - [('linear, Linear', '0.00 MB'), ('_operator.getitem', '0.00 MB'), ('_operator.getitem', '0.00 MB')] + fw_span - [('torch.nn.functional.cross_entropy', '0.08 ms'), ('linear, Linear', '0.08 ms'), ('_operator.getitem', '0.00 ms')] + train_mem - [('linear, Linear', '0.00 MB'), ('torch.nn.functional.cross_entropy', '0.00 MB'), ('_operator.getitem', '0.00 MB')] + buffer_mem - [('_operator.getitem', '0.00 MB'), ('linear, Linear', '0.00 MB'), ('_operator.getitem', '0.00 MB')] + depth 2 + param_mem - [('torch.nn.functional.linear', '0.00 MB')] + fw_span - [('torch.nn.functional.linear', '0.08 ms')] + train_mem - [('torch.nn.functional.linear', '0.00 MB')] + buffer_mem - [('torch.nn.functional.linear', '0.00 MB')] + + 2024-09-09 20:28:07 | INFO | nnscaler.autodist.apis | param mem 0 MB, buff mem 0 MB, activation mem 0 MB + 2024-09-09 20:28:07 | INFO | nnscaler.autodist.apis | estimated minimum memory per device 0.0 MB + 2024-09-09 20:28:07 | INFO | nnscaler.autodist.spmd_solver | no partition constraint is loaded + 2024-09-09 20:28:07 | INFO | nnscaler.autodist.cost_database | Profiling in parallel + 2024-09-09 20:28:08 | INFO | nnscaler.autodist.cost_database | device 1 finished profiling 1 nodes + 2024-09-09 20:28:08 | INFO | nnscaler.autodist.cost_database | device 3 finished profiling 0 nodes + 2024-09-09 20:28:08 | INFO | nnscaler.autodist.cost_database | device 2 finished profiling 0 nodes + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.cost_database | device 0 finished profiling 1 nodes + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.spmd_solver | force_replica_threshold is 0 + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.spmd_solver | finish building op partitions + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.spmd_solver | finish building following relationships + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.spmd_solver | finish filtering useless partitions + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.spmd_solver | total state num is 4 + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.spmd_solver | output each operator's importance ratio (percentages of states that can be reduced by forcing the operator to be partitioned in a single partition) + + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.spmd_solver | finish spmd solver initializetion + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.apis | use plan with e2e time/s 0.94ms + 2024-09-09 20:28:09 | INFO | nnscaler.autodist.apis | + autodist plan analysis for stage 0 on devices [0] with mem 0.00 GB: + + Total computation time: 0.94 ms + Top 10 of operators that consume the most computation time: + torch.nn.functional.cross_entropy: 0.50 ms + torch.nn.functional.linear: 0.44 ms + _operator.getitem: 0.00 ms + Top 10 of operators computation time sum: 0.94 ms + + Top 2 operators split info: + torch.nn.functional.cross_entropy: + FwOp4-()(name=cross_entropy, inputs=(t28(p10,(1, 10),d(),v(0/1)), t29(p12,(1,),d(),v(0/1))), outputs=(t24(p13,(1,),d(),v(0/1)),)) + File "/home/nnscaler/test.py", line 17, in forward, return torch.nn.functional.cross_entropy(x, data['y']) + N^ C^, N^ -> 1^, OpPartition((-1,), (1,)), comp_time: 0.50 ms, comm_time: 0.00 ms + + + torch.nn.functional.linear: + FwOp2-()(name=linear, inputs=(t25(p4,(1, 100),d(),v(0/1)), w26(p6,(10, 100),d(),v(0/1)), w27(p8,(10,),d(),v(0/1))), outputs=(t28(p10,(1, 10),d(),v(0/1)),)) + File "/home/nnscaler/test.py", line 16, in forward, x = self.linear(data['x']) + a k^, n k^, n -> a n, OpPartition((-1,), (1,)), comp_time: 0.44 ms, comm_time: 0.00 ms + + + Total communication time: 0.00 ms + Top 10 operators that consume the most communication time: + Top 10 of operators communication time sum: 0.00 ms + + Module analysis: + Depth 1: + Top 3 modules that consume the most computation time: + Top 3 modules that consume the most communication time: + Top 3 modules that consume the most memory: + Depth 2: + Top 3 modules that consume the most computation time: + Top 3 modules that consume the most communication time: + Top 3 modules that consume the most memory: + + 2024-09-09 20:28:09 | INFO | nnscaler.graph.gener.gen | finish reordering producer and consumer + 2024-09-09 20:28:09 | INFO | nnscaler.graph.gener.gen | finish removing anchor nodes + 2024-09-09 20:28:09 | INFO | nnscaler.graph.gener.gen | finish replacing auto pyfunc + 2024-09-09 20:28:09 | INFO | nnscaler.graph.gener.gen | finish transforming multiref nodes + 2024-09-09 20:28:09 | INFO | nnscaler.graph.gener.gen | finish local fusion & multiref for 4 tensors + 2024-09-09 20:28:09 | INFO | nnscaler.graph.gener.gen | finish reordering producer and consumer + 2024-09-09 20:28:09 | INFO | nnscaler.graph.gener.gen | finish generating 4 activation adapters + 2024-09-09 20:28:09 | INFO | nnscaler.execplan.planpass.fusion | adapter fusion: successfully fuse 0 differentiable adapters + 2024-09-09 20:28:09 | INFO | nnscaler.runtime.module | loading partitioned model from /home/nnscaler/.nnscaler/_parallel_modules/__main__/Model/_/fullmodel.pt, number of model parameter chunks: 1 + 2024-09-09 20:28:09 | INFO | nnscaler.cli.trainer | Training... + 2024-09-09 20:28:10 | INFO | nnscaler.cli.trainer | Epoch 0: 010/100 train_loss=2.261, lr=0.001, gnorm=5.590, train_wall=0.004 + 2024-09-09 20:28:10 | INFO | nnscaler.cli.trainer | Saving checkpoint after 10 steps with loss=2.261. + 2024-09-09 20:28:10 | INFO | nnscaler.cli.trainer | Saving checkpoint to checkpoints/0000-0010 + 2024-09-09 20:28:10 | INFO | nnscaler.cli.trainer | Saving checkpoint as the last checkpoint. + 2024-09-09 20:28:10 | INFO | nnscaler.cli.trainer | Best loss updated: inf -> 2.261 + 2024-09-09 20:28:10 | INFO | nnscaler.cli.trainer | Saving checkpoint as the best checkpoint. + 2024-09-09 20:28:10 | INFO | nnscaler.cli.trainer | Reached max train steps(10): Training is done. + +********* +Next Step +********* + +The above example uses nnScaler's :doc:`Trainer APIs `. +To learn more about it, you may check our :doc:`Llama 3 example `. + +Or if you prefer to use a familiar trainer, we also provides integration with `PyTorch Lightning `_. +The usage is demostrated by :doc:`nanoGPT example `. + +If you want to try a more advanced model, please check :doc:`Llama 3 128K sequence length example `. diff --git a/docs/source/quickstart_internal.rst b/docs/source/quickstart_internal.rst new file mode 100644 index 00000000..28b6cc04 --- /dev/null +++ b/docs/source/quickstart_internal.rst @@ -0,0 +1,116 @@ +########### +Get Started +########### + +The nnScaler internal repo: https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube + +If you do not have access, please contact nnscaler@service.microsoft.com + +************ +Installation +************ + +To get started, install the latest wheel from +`DevOps Artifacts `_. + +If you are familiar with Azure stuffs, you can follow DevOps' guide to set up the repository. + +Or if you prefer the simpler way, download the ``.whl`` file in the "Files" section of the website, +and install it locally: + +:: + + python -m pip install nnscaler-*.whl + +********** +Quickstart +********** + +The next step depends on your choice of the training framework. + +- **No framework**: if you write your own training code and do not use a framework, + see :ref:`Parallelize API` section. +- **Fairseq**: if you use fairseq, see :ref:`Fairseq` section. +- **Lightning**: TODO + +.. _Parallelize API: + +Parallelize API +=============== + +TODO: write a hello world example, assigned to Zhe Liu + +If you write your own training code, you can use the *parallelize* API to make your model parallel: + +.. code-block:: python + + import torch + from nnscaler import parallelize, ComputeConfig, build_optimizer + + class LLM(torch.nn.Module): + def __init__(self, ...): + ... + def forward(self, x): + ... + + llm_sample_input = ... # dummpy input will be used to do tracing + pas_policy = ... # the PAS policy, you can use autodist pas + compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + ..., + ) # compute environment config + ParallelizedLLM = parallelize( + LLM, + {'x': llm_sample_input}, + pas_policy, + compute_config, + ) + +Example +------- + +An example of the parallelize API is provided in the repo: +`train.py `_ + +You can download and try it: :: + + torchrun --nproc_per_node=4 --nnodes=1 train.py + +Documentation +------------- + +If the example works for you, you can now follow the documentation to parallelize your model: +:doc:`parallel_module` + +.. _Fairseq: + +Fairseq (To be retired) +======= + +.. TODO: + + nnScaler provides `fairseq integration `_. + + TODO: refine the example (and its doc), assigned to Youshan Miao + + TODO (long term): write an example using unmodified fairseq + + Installation + ------------ + + To use fairseq, clone the fork and install it: :: + + python -m pip uninstall fairseq + + git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Fairseq + cd Fairseq + python -m pip install -e . + + Example + ------- + + Follow the example + `here `_. + diff --git a/docs/source/readme.rst b/docs/source/readme.rst deleted file mode 100644 index 11a93b10..00000000 --- a/docs/source/readme.rst +++ /dev/null @@ -1,76 +0,0 @@ -============================================================== -nnScaler: A Parallelization System for DNN Model Training -============================================================== - -Introduction ------------- -**nnScaler** is a parallelization system for deep neural network (DNN) model training. - - -nnScaler automatically parallelizes DNN models across multiple devices, enabling users to focus on model design. nnScaler supports new parallelisms that outperform existing parallel execution approaches. nnScaler supports extending DNN modules with new structures or execution patterns, enabling users to parallelize their own new DNN models. nnScaler can support paralleling new DNN models by providing user-defined functions for the new operators unrecognized by the nnScaler. - -Features --------- -- **Automatic Parallelization**: nnScaler automatically parallelizes DNN models across multiple devices, enabling users to focus on model design. -- **High Performance**: nnScaler supports new parallelisms that outperform existing parallel execution approaches. -- **Extensibility**: nnScaler supports extending DNN modules with new structures or execution patterns, enabling users to parallelize their own new DNN models. -- **Compatibility**: nnScaler can support paralleling new DNN models by providing user-defined functions for the new operators unrecognized by the nnScaler. - -Overview --------- - -Below is an overview of the nnScaler system. The nnScaler system consists of three main components: the parallelization compiler, the planner, and the interface. The parallelization compiler takes a DNN model as input, converts into intermediate representation (Graph IR) and generates execution for multiple devices. The parallelization planner will provide efficient strategies during parallelization. The nnScaler interface provides a set of parallelization APIs to support different trainers through certain adapters, as well as extending the nnScaler system. - -.. figure:: images/overview.png - :alt: overview - :figwidth: 80% - :align: center - - **nnScaler Overview** - -Outline --------- -- **Quick Start**: Learn how to install and use nnScaler. - - **Installation**: Install nnScaler on your machine. - - **Get Started**: Started from a simple example. -- **User Guide**: Learn how to use nnScaler to parallelize a model. - - **Example**: Parallelize NanoGPT through PyTorch Lightning interface. -- **Developer Guide**: Find detailed information about nnScaler. - - **Extending nnScaler**: Learn how to extend nnScaler. -- **Frequently Asked Questions**: Find answers to common questions about nnScaler. - - -Reference ---------- -Please cite nnScaler in your publications if it helps your research: :: - - @inproceedings {nnscaler-osdi24, - author = {Zhiqi Lin and Youshan Miao and Quanlu Zhang and Fan Yang and Yi Zhu and Cheng Li and Saeed Maleki and Xu Cao and Ning Shang and Yilei Yang and Weijiang Xu and Mao Yang and Lintao Zhang and Lidong Zhou}, - title = {nnScaler: Constraint-Guided Parallelization Plan Generation for Deep Learning Training}, - booktitle = {18th {USENIX} Symposium on Operating Systems Design and Implementation ({OSDI} 24)}, - year = {2024}, - publisher = {{USENIX} Association}, - } - -Contributing ------------- -This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. - -When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. - -This project has adopted the `Microsoft Open Source Code of Conduct `_. For more information, see the `Code of Conduct FAQ `_ or contact `opencode@microsoft.com `_ with any additional questions or comments. - -Trademarks ----------- -This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow `Microsoft's Trademark & Brand Guidelines `_. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third-party's policies. - -Contact -------- -You may find our public repo from `https://github.com/microsoft/nnscaler`_ or microsoft internal repo `https://aka.ms/ms-nnscaler`_. - -.. _`https://github.com/microsoft/nnscaler`: https://github.com/microsoft/nnscaler - -.. _`https://aka.ms/ms-nnscaler`: https://aka.ms/ms-nnscaler - -For any questions or inquiries, please contact us at nnscaler@service.microsoft.com. - diff --git a/docs/source/register_custom_op.md b/docs/source/register_custom_op.md index 286f28a6..75532cef 100644 --- a/docs/source/register_custom_op.md +++ b/docs/source/register_custom_op.md @@ -1,4 +1,5 @@ -# Register a new operator/function +# Customized Operator +[//]: # (# Register a new operator/function) ## Overview @@ -38,7 +39,7 @@ nnscaler.register_op('(h^ m^) kd+, kd+ n -> h^ m^ n', name='matmul_custom')(oper ``` -## Api Explains +## API Explained ```python def register_op( diff --git a/docs/source/trainer.md b/docs/source/trainer.md deleted file mode 100644 index a4635eba..00000000 --- a/docs/source/trainer.md +++ /dev/null @@ -1,477 +0,0 @@ -# Trainer - -We provide a `Trainer` class that can be used to train and evaluate models. It will firstly parallelize the model on multiple GPUs with `parallelize` API, and then train the model with the given dataset and optimizer in a distributed way. - - -## Arguments - -All the arguments are defined in `TrainerArgs` class. Here is the definition of `TrainerArgs`: - -```python -@dataclass -class TrainerArgs: - compute_config: ComputeConfig = None - - gen_savedir: str = './.nnscaler' - gen_reuse: str = 'auto' - pas_policy: str = 'autodist' - broadcast_strategy: str = 'all' - instance_name: str = None - run_mode: str = 'run' - tracing_from_weights: str = None - - model: ModelConfig = field(default_factory=ModelConfig) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) - dataset: DatasetConfig = field(default_factory=DatasetConfig) - dataloader: DataloaderConfig = field(default_factory=DataloaderConfig) - dataset_sampler: DatasetSamplerConfig = field(default_factory=DatasetSamplerConfig) - lr_scheduler: Optional[LRSchedulerConfig] = None - checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) - log: List[LogConfig] = field(default_factory=list) - hook: Union[HookConfig, HookMapConfig, None] = None - - precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None - - micro_batch_size: int = 1 - global_batch_size: Optional[int] = None - grad_accumulation_steps: Optional[int] = None - - max_epochs: Optional[int] = None - max_train_steps: Optional[int] = None - max_val_steps: Optional[int] = None - - val_every_n_train_steps: Optional[int] = None - val_every_n_epochs: Optional[int] = 1 - - enable_progress_bar: bool = True - - seed: Optional[int] = None - init_env_fn: str = None -``` - -The design philosophy of `Trainer` arguments is: -The classes(or factory functions) of components(model/optimizer/etc) -and their arguments are provided in the `TrainerArgs` class (functions/types are passed as fully qualified names), -and we are responsible for creating them. - -For example, you can tell me how to create a model by providing the model type and its arguments in `ModelConfig` class. - -Please note some of the arguments of components are set automatically, and you should not set them manually. -For example, arguments `dataset`, `num_replicas` and `rank` of the dataset sampler are set automatically by the `Trainer` class. -Those 3 arguments passed in the `DatasetSamplerConfig.train_args/val_args`(if any) will be ignored. - -```python -'dataset': { - 'type': 'SomeDataset', - 'train_args': { - ... - }, - 'val_args': { - ... - } -} -'dataset_sampler': { - 'type': 'SomeDatasetSampler', - 'train_args': { - 'num_replicas': ..., # this will be ignored - 'dataset': ..., # this will be ignored - 'rank': ..., # this will be ignored - ... - }, - 'val_args': { - 'num_replicas': ..., # this will be ignored - 'dataset': ..., # this will be ignored - 'rank': ..., # this will be ignored - ... - }, -} -``` - -If any argument type is a class, you can pass it as a dict, and add a special key `__type` to specify the class type. - -For example, if the module `__init__` takes `ModelConfig` object -```python -class SomeModule(torch.nn.Module): - def __init__(self, model_config: ModelConfig): - ... -``` -You can pass the `model_config` as -```python -{ - 'type': 'SomeModule', - 'args': { - 'model_config': { - '__type': 'ModelConfig', - # arguments to create ModelConfig - } - } -} -``` - -We also use `ast.literal_eval` to guess the type of the string arguments, You can skip it by passing a dict with `__value_type` and `value` keys. For example, you want a number to be a str, you can use -```python -{ - '__value_type': 'str', - 'value': '1' -} -``` -Internally we will get the final value with `__value_type(value)`. - -### Component Configs - -- `model` (`ModelConfig`): The model to be trained. You need to provide the model type and its arguments in `ModelConfig` class. Here is the definition of `ModelConfig`: - - ```python - @dataclass - class ModelConfig: - type: str = None - args: Dict[str, Any] = field(default_factory=dict) - ``` -- `optimizer` (`OptimizerConfig`): The optimizer to be used. - - ```python - @dataclass - class OptimizerConfig: - type: str = None - args: Dict[str, Any] = field(default_factory=dict) - clip_gnorm: float = 0.0 - - loss_reduction: str = 'mean' - grad_reduction: str = 'mean' - aggregate_outputs_fn: str = None - ``` - - `type` (`str`): The optimizer type or factory function. - Please note the first parameter of the optimizer constructor must be the model parameters. - - `args` (`Dict[str, Any]`): The arguments of the optimizer. - - `clip_gnorm` (`float`): The maximum norm value for gradient clipping. 0.0/None means no clipping. - - `loss_reduction` (`str`): The reduction method for loss. - It can be `mean` (average the loss over all micro-batches), - `sum` (sum the loss of all micro-batches). - Default is `mean`. - Please note in validation stage, this configuration is ignored the loss is always averaged over all batches - - `grad_reduction` (`str`): The reduction method for gradients. It can be `mean` (average the gradients over all micro-batches), `sum` (sum the gradients of all micro-batches), `per-token-mean` (average the gradients over all tokens). Default is `mean`. Please note if `per-token-mean` is used, you need to specify `aggregate_outputs_fn`, which will return the number of tokens - - `aggregate_outputs_fn` (`str`): The function to aggregate the outputs of the model. It is required when `grad_reduction` is `per-token-mean`. Its signature should be `def aggregate_outputs(self, loss_outputs, sync_group) -> AggregatedOutputs`, where `loss_outputs` is a list of outputs of the model, and `sync_group` is the `torch.distributed.ProcessGroup` to sync with. The function should return an `AggregatedOutputs` object, which defines as: - ```python - @dataclass - class AggregatedOutputs: - # the aggregated loss as a sum - loss_sum: float = None - # number of mini batches - num_batches: int = None - # number of tokens (necessary when grad_reduction is 'per-token-mean') - num_tokens: Optional[int] = None - # any other custom outputs - aggregated_outputs: Any = None - ``` -- `dataset` (`DatasetConfig`): The dataset to be used. - ```python - @dataclass - class DatasetConfig: - type: str = None - train_args: Dict[str, Any] = field(default_factory=dict) - val_args: Dict[str, Any] = field(default_factory=dict) - ``` - - `type` (`str`): The dataset type or factory function. - - `train_args` (`Dict[str, Any]`): The arguments of the training dataset. - - `val_args` (`Dict[str, Any]`): The arguments of the validation dataset. -- `dataloader` (`DataloaderConfig`): The dataloader to be used. Please note we recommend to pass `drop_last=True` in the dataloader arguments to avoid the last batch with different sizes. - - ```python - @dataclass - class DataloaderConfig: - type: str = 'torch.utils.data.DataLoader' - train_args: Dict[str, Any] = field(default_factory=dict) - # default to train_args - val_args: Dict[str, Any] = field(default_factory=dict) - # default to train_args - test_args: Dict[str, Any] = field(default_factory=dict) - ``` - - `type` (`str`): The dataloader type or factory function. - Please note the dataloader constructor must at least have 3 parameters `dataset`, `batch_size`, `sampler`. - - `train_args` (`Dict[str, Any]`): The arguments (except `dataset`,`batch_size`, `sampler`) of the training dataloader. Argument `batch_size` will be set to `micro_batch_size`. - - `val_args` (`Dict[str, Any]`): The arguments (except `dataset`,`batch_size`, `sampler`) of the validation dataloader. - -- `dataset_sampler` (`DatasetSamplerConfig`): The dataset sampler to be used. - - ```python - @dataclass - class DatasetSamplerConfig: - type: str = 'torch.utils.data.DistributedSampler' - train_args: Dict[str, Any] = field(default_factory=dict) - val_args: Dict[str, Any] = field(default_factory=dict) - test_args: Dict[str, Any] = field(default_factory=dict) - ``` - - `type` (`str`): The dataset sampler type or factory function. - Please note the dataset sampler constructor must at least have 3 parameters `dataset`, `num_replicas`, `rank`. - - `train_args` (`Dict[str, Any]`): The arguments (except `dataset`,`num_replicas`, `rank`) of the training dataset sampler. - - `val_args` (`Dict[str, Any]`): The arguments (except `dataset`,`num_replicas`, `rank`) of the validation dataset sampler. - -- `lr_scheduler` (`LRSchedulerConfig`): The learning rate scheduler to be used. This is optional. - - ```python - @dataclass - class LRSchedulerConfig: - type: str = None - args: Dict[str, Any] = field(default_factory=dict) - interval: str = 'epoch' - ``` - - `type` (`str`): The learning rate scheduler type or factory function. - Please note the first parameter of the learning rate scheduler constructor must be optimizer. - - `args` (`Dict[str, Any]`): The arguments of the learning rate scheduler. - - `interval` (`str`): The interval to update the learning rate. It can be `epoch` or `step`. Default is `epoch`. - -- `log` (`List[LogConfig]`): The loggers to be used. You can provide multiple loggers. Currently we have two builtin loggers: `TensorBoardLogger` and `WandbLogger`. - - ```python - @dataclass - class LogConfig: - type: str = None - args: Dict[str, Any] = field(default_factory=dict) - ``` - - `type` (`str`): The logger type or factory function. - - `args` (`Dict[str, Any]`): The arguments of the logger. - -- `debug` (`DebugConfig`): Trainer debug related setting. - - ```python - @dataclass - class DebugConfig: - check_gradient_sync_cross_devices: bool = True - ``` - - - `check_gradient_sync_cross_devices` (`bool`): Before gradient clip norm, check the gradient sync for the same parameter is consistent cross devices, if ZeRO is enabled, will check the gradient cross each ZeRO group, if ZeRO is not enabled, will check the gradient cross each nnscaler scale unit. This helps to find bugs related to gradient updates during training. Default is `True`. - -- `hook` (`Union[HookConfig, HookMapConfig, None]`): The hooks to be used. -You can provide a hook with a hook class or a map of hook functions. -Please note if your `model`/`optimizer`/`lr_scheduler` inherit from `TrainHook`, -their hook functions will be called automatically. -The order of the hook functions called is `model` -> `optimizer` -> `lr_scheduler`, -and hooks passed with this config is called in the last. - - Hook class: - - ```python - @dataclass - class HookConfig: - type: str = None - args: Dict[str, Any] = field(default_factory=dict) - ``` - - - `type` (`str`): The hook type or factory function. - - `args` (`Dict[str, Any]`): The arguments of the hook. - - Hook map: - - ```python - @dataclass - class HookMapConfig: - after_setup: str = None - - on_train_start: str = None - on_train_end: str = None - on_val_start: str = None - on_val_end: str = None - - on_epoch_start: str = None - on_epoch_end: str = None - - on_train_step_start: str = None - on_train_step_end: str = None - on_val_step_start: str = None - on_val_step_end: str = None - - after_aggregate_train_step_outputs: str = None - after_aggregate_val_step_outputs: str = None - - before_zero_grad: str = None - after_zero_grad: str = None - - before_gnorm_clip: str = None - after_gnorm_clip: str = None - - before_optimizer_step: str = None - after_optimizer_step: str = None - - on_load_checkpoint: str = None - on_save_checkpoint: str = None - ``` - - `after_setup` (`str`): The hook function to be called after setting up the trainer. - Only be called when `run_mode == 'run'`. - Signature: `def after_setup(trainer: 'Trainer') -> None:` - - `on_train_start` (`str`): The hook function to be called at the start of the training stage. Signature: `def on_train_start(trainer: 'Trainer') -> None:` - - `on_train_end` (`str`): The hook function to be called at the end of the training stage. Signature: `def on_train_end(trainer: 'Trainer') -> None:` - - `on_val_start` (`str`): The hook function to be called at the start of the validation stage. Signature: `def on_val_start(trainer: 'Trainer') -> None:` - - `on_val_end` (`str`): The hook function to be called at the end of the validation stage. Signature: `def on_val_end(trainer: 'Trainer', val_loss: float) -> None:` - - `on_epoch_start` (`str`): The hook function to be called at the start of each epoch. Signature: `def on_epoch_start(trainer: 'Trainer', epoch: int) -> None:` - - `on_epoch_end` (`str`): The hook function to be called at the end of each epoch. Signature: `def on_epoch_end(trainer: 'Trainer', epoch: int) -> None:` - - `on_train_step_start` (`str`): The hook function to be called at the start of each training step. Signature: `def on_train_step_start(trainer: 'Trainer', batches: List[Any], idx: int) -> None:` - - `on_train_step_end` (`str`): The hook function to be called at the end of each training step. Signature: `def on_train_step_end(trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None:` - - `on_val_step_start` (`str`): The hook function to be called at the start of each validation step. Signature: `def on_val_step_start(trainer: 'Trainer', batches: List[Any], idx: int) -> None:` - - `on_val_step_end` (`str`): The hook function to be called at the end of each validation step. Signature: `def on_val_step_end(trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None:` - - `after_aggregate_train_step_outputs` (`str`): The hook function to be called after aggregating the outputs of the model in the training step. Signature: `def after_aggregate_train_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None:` - - `after_aggregate_val_step_outputs` (`str`): The hook function to be called after aggregating the outputs of the model in the validation step. Signature: `def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None:` - - `before_zero_grad` (`str`): The hook function to be called before zeroing the gradients. Signature: `def before_zero_grad(trainer: 'Trainer') -> None:` - - `after_zero_grad` (`str`): The hook function to be called after zeroing the gradients. Signature: `def after_zero_grad(trainer: 'Trainer') -> None:` - - `before_sync_grad` (`str`): The hook function to be called before syncing the gradients between ranks. - Please note this hook can't be triggered correctly, - and you should not reply on this. - Will fix it later. - Signature: `def before_sync_grad(trainer: 'Trainer') -> None:` - - `after_sync_grad` (`str`): The hook function to be called after syncing the gradients between ranks. Signature: `def after_sync_grad(trainer: 'Trainer') -> None:` - - `before_gnorm_clip` (`str`): The hook function to be called before gradient clipping. Signature: `def before_gnorm_clip(trainer: 'Trainer') -> None:` - - `after_gnorm_clip` (`str`): The hook function to be called after gradient clipping. Signature: `def after_gnorm_clip(trainer: 'Trainer', gnorm: torch.Tensor) -> None:` - - `before_optimizer_step` (`str`): The hook function to be called before the optimizer step. Signature: `def before_optimizer_step(trainer: 'Trainer') -> None:` - - `after_optimizer_step` (`str`): The hook function to be called after the optimizer step. Signature: `def after_optimizer_step(trainer: 'Trainer') -> None:` - - `on_load_checkpoint` (`str`): The hook function to be called after loading the checkpoint. If you saved something with `on_save_checkpoint` this is - your chance to restore this. Signature: `def on_load_checkpoint(trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None:` - - `on_save_checkpoint` (`str`): The hook function to be called before saving the checkpoint. If you want to save something, you can add it to the checkpoint here. Signature: `def on_save_checkpoint(trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None:` - -### Compute Config - -All compute configs are put in `compute_config` (`ComputeConfig`). Please refer to [link](./parallel_module.md#ComputeConfig) for more information. - -Please note only end2end mode is supported in the trainer, so you must set `compute_config.use_end2end` to `True` to make it work. - -### Checkpoint Config - -```python -@dataclass -class CheckpointConfig: - save_dir: str = './checkpoints' - no_save: bool = False - - save_type: str = 'sharded' - - save_last: bool = True - save_best: bool = True - symlink_best_and_last: bool = True - - every_n_train_steps: Optional[int] = None - every_n_epochs: Optional[int] = None - keep_last_n_checkpoints: Optional[int] = None - - resume_from: str = None -``` - -- `save_dir` (`str`): The directory to save the checkpoints. -- `no_save` (`bool`): Whether to save the checkpoints. Default is `False`. -- `save_type` (`str`): The type of saving checkpoint. It can be `sharded` or `deduped`. Default is `sharded`. - - `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. - The checkpoint is a folder with as many files as the world size. - - `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. - The checkpoint is a folder with as many files as the world size. - - `"merged"`: everything has been merged into a single file. Used internally only when you merge the checkpoint files via `Trainer.merge_checkpoints` -- `save_last` (`bool`): Whether to save the last checkpoint. Default is `True`. -- `save_best` (`bool`): Whether to save the best (lowest `val_loss`) checkpoint. Default is `True`. -- `symlink_best_and_last` (`bool`): Whether to use symlink (instead of copy) to the best and last checkpoint. Default is `True`. -- `every_n_train_steps` (`Optional[int]`): Save the checkpoint every `every_n_train_steps` training steps. Default is `None`, which means no checkpoint is saved based on training steps. -- `every_n_epochs` (`Optional[int]`): Save the checkpoint every `every_n_epochs` epochs. Default is `None`, which means no checkpoint is saved based on epochs. -- `keep_last_n_checkpoints` (`Optional[int]`): Keep the last `keep_last_n_checkpoints` checkpoints. If we have more than `keep_last_n_checkpoints` checkpoints, we will remove the oldest ones. -Default is `None`, which means all checkpoints are kept. -- `resume_from` (`str`): The path to the checkpoint to resume from. It can be `last`/`best`/a specific folder/file. -We will not resume (nor report error) if resume_from is `last` or `best` but the corresponding checkpoint does not exist. -Default is `None`. - -Please note - -1. When the parallel plan is changed (i.e you re-trace the model with different configurations), -the checkpoints become incompatible, and can't be loaded any more. -You must firstly merge the checkpoints to a merged checkpoint with `Trainer.merge_checkpoint` and then load the merged checkpoint just like a regular checkpoint. - - ```python - def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): - ``` - where `checkpoint_files` is a list of checkpoint files to merge, and `output_file` is the output file path. - -2. When a checkpoint is saved, -we will run validation on the validation dataset and save the validation loss to the checkpoint file. -The validation run will ignore the `val_every_n_train_steps` and `val_every_n_epochs` configurations. -If no valid dataset is provided, validation is skipped and `valid_loss` is set to `train_loss` by default. - -3. The sharded checkpoints will contain PyTorch's RNG state, but not Python's or NumPy's. -The checkpoint's RNG state will be resumed right before training start, -which means the initialization stage will use `TrainerArgs.seed` instead. -Merged checkpoints will discard the RNG state. - -### Other configs -- `gen_savedir` (`str`): The directory to save the generated files. Default is `./.nnscaler`. -- `gen_reuse` (`str`): the reuse strategy of the generated code, it can be - - `auto`: automatically decide the reuse strategy (`moo` for `compile`, `match` for `run`) - - one of `match`/`override`/`moo`/`graph`. See `parallelize` API for more information. -- `pas_policy` (`str`): The policy of parameter partitioning. Default is `autodist`. -You can pass builtin pas policy name or your own pas policy function. -See `parallelize` API for more information. -- `broadcast_strategy` (`str`): The strategy of broadcasting the model. Default is `all`. See `parallelize` API for more information. -- `instance_name` (`str`): The instance name of the trainer. Default is `None`. See `parallelize` API for more information. -- `run_mode` (`str`): The run mode of the trainer. -It can be `run` (compile and train the model in a single python script OR train from previous compiling results) and `compile` (only compile the model for code generation). Default is `run`. -Please note you can only use `run` mode with `torchrun`. -On the other hand, if you disable broadcasting generated files (by setting `broadcast_strategy` to `none`), -you can run `compile` mode without `torchrun`. -- `tracing_from_weights` (`str`): The path to the weights to be loaded when tracing(compiling) the model. It is only used in tracing to serve as the initial state dict of the model. Default is `None`. -- `precison`(`Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None]`): The precision of the model. It can be a `str`, which means the same precision for all tensors, or a `Dict[_TENSOR_TYPE, _PRECISION_TYPE]`, which means the precision for each tensor type. Default is `None`. Currently we support 3 tensor types (`param`, `buffer`, `input`) and three precisions (`fp32`, `fp16`, `bf16`). You can set precision to `none` to avoid any precision conversion. -- `micro_batch_size` (`int`): The micro batch size. Default is `1`. -- `global_batch_size` (`Optional[int]`) and `grad_accumulation_steps` (`Optional[int]`): You can set one of `global_batch_size` and `grad_accumulation_steps` option. Please note if both are set, they must be consistent. Default is `micro_batch_size*scaling_factor` and `1` respectively. -- `max_epochs` (`Optional[int]`): The maximum number of epochs to train. Default is `None`, which means no limit. -- `max_train_steps` (`Optional[int]`): The maximum number of training steps to train. Default is `None`, which means no limit. -- `max_val_steps` (`Optional[int]`): The maximum number of validation steps to validate. Default is `None`, which means no limit. -- `val_every_n_train_steps` (`Optional[int]`): Validate every `val_every_n_train_steps` training steps. Default is `None`, which means no validation based on training steps. -- `val_every_n_epochs` (`Optional[int]`): Validate every `val_every_n_epochs` epochs. Default is `1`. -- `enable_progress_bar` (`bool`): Whether to enable the progress bar. Default is `True`. -- `seed` (`Optional[int]`): The random seed. Default is `None`. -- `init_env_fn` (`str`): The function to initialize the environment. Its only input is `Trainer`. Default is `None`. - -## CLI - -You can run the trainer with the following command: - -```bash -torchrun [torchrun arguments] ${NNSCALER_HOME}/cli/train.py -f ${CONFIG_FILE} [other arguments] -``` - -CONFIG_FILE is the path to the configuration yaml file. It looks like (taken from our test case) - -```yaml -compute_config: - plan_ngpus: 4 - runtime_ngpus: 100 - constant_folding: true - use_zero: true - use_end2end: true - -run_mode: run -pas_policy: autodist -micro_batch_size: 2 -global_batch_size: 8 -max_epochs: 4 -max_train_steps: 10 - -model: - type: tests.cli.common.MLP - args: - dim: 16 - nlayers: 16 - -optimizer: - type: torch.optim.Adam - args: - lr: 0.01 - -dataset: - type: tests.cli.common.SimpleDataset - train_args: - dim: 16 - size: 100 - val_args: - dim: 16 - size: 10 - -checkpoint: - keep_last_n_checkpoints: 30 - every_n_train_steps: 1 - save_type: deduped -``` - -All the arguments in the yaml file are the same as the arguments in the `TrainerArgs` class. -And they can be override with the command line arguments. -For example, you can override the `max_epochs` with `--max_epochs 2`, or override the `model` with `--model.args.dim 32 --model.args.nlayers 32`. diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst new file mode 100644 index 00000000..98b76dc5 --- /dev/null +++ b/docs/source/trainer.rst @@ -0,0 +1,769 @@ +####### +Native Trainer +####### + +``nnScaler`` provides a ``Trainer`` class for training and evaluating model parallelization. +Let's start from an example to demonstrate how to parallelize a model using the ``parallelize`` API. +Next, we'll illustrate how to train the model across multiple GPUs using the provided dataset and optimizer. + +********* +Arguments +********* + +All the arguments are defined in ``TrainerArgs`` class. Here is the definition of ``TrainerArgs``: + +.. code-block:: python + + @dataclass + class TrainerArgs: + compute_config: ComputeConfig = None + + gen_savedir: str = './.nnscaler' + gen_reuse: str = 'auto' + pas_policy: str = 'autodist' + broadcast_strategy: str = 'all' + instance_name: str = None + run_mode: str = 'run' + tracing_from_weights: str = None + + model: ModelConfig = field(default_factory=ModelConfig) + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + dataset: DatasetConfig = field(default_factory=DatasetConfig) + dataloader: DataloaderConfig = field(default_factory=DataloaderConfig) + dataset_sampler: DatasetSamplerConfig = field(default_factory=DatasetSamplerConfig) + lr_scheduler: Optional[LRSchedulerConfig] = None + checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) + log: List[LogConfig] = field(default_factory=list) + hook: Union[HookConfig, HookMapConfig, None] = None + + precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None + + micro_batch_size: int = 1 + global_batch_size: Optional[int] = None + grad_accumulation_steps: Optional[int] = None + + max_epochs: Optional[int] = None + max_train_steps: Optional[int] = None + max_val_steps: Optional[int] = None + + val_every_n_train_steps: Optional[int] = None + val_every_n_epochs: Optional[int] = 1 + + enable_progress_bar: bool = True + + seed: Optional[int] = None + init_env_fn: str = None + +The design philosophy of ``Trainer`` arguments is: +The classes(or factory functions) of components(model/optimizer/etc) +and their arguments are provided in the ``TrainerArgs`` class (functions/types are passed as fully qualified names), +and we are responsible for creating them. + +For example, you can tell me how to create a model by providing the model type and its arguments in ``ModelConfig`` class. + +Please note some of the arguments of components are set automatically, and you should not set them manually. +For example, arguments ``dataset``, ``num_replicas`` and ``rank`` of the dataset sampler are set automatically by the ``Trainer`` class. +Those 3 arguments passed in the ``DatasetSamplerConfig.train_args/val_args`` (if any) will be ignored. + +.. code-block:: python + + 'dataset': { + 'type': 'SomeDataset', + 'train_args': { + ... + }, + 'val_args': { + ... + } + } + 'dataset_sampler': { + 'type': 'SomeDatasetSampler', + 'train_args': { + 'num_replicas': ..., # this will be ignored + 'dataset': ..., # this will be ignored + 'rank': ..., # this will be ignored + ... + }, + 'val_args': { + 'num_replicas': ..., # this will be ignored + 'dataset': ..., # this will be ignored + 'rank': ..., # this will be ignored + ... + }, + } + +If any argument type is a class, you can pass it as a dict, and add a special key ``__type`` to specify the class type. + +For example, if the module ``__init__`` takes ``ModelConfig`` object + +.. code-block:: python + + class SomeModule(torch.nn.Module): + def __init__(self, model_config: ModelConfig): + ... + +You can pass the `model_config` as + +.. code-block:: python + + { + 'type': 'SomeModule', + 'args': { + 'model_config': { + '__type': 'ModelConfig', + # arguments to create ModelConfig + } + } + } + +We also use ``ast.literal_eval`` to guess the type of the string arguments, +You can skip it by passing a dict with ``__value_type`` and ``value`` keys. +For example, you want a number to be a str, you can use + +.. code-block:: python + + { + '__value_type': 'str', + 'value': '1' + } + +Internally we will get the final value with ``__value_type(value)``. + +Component Configs +================= + +* ``model`` (``ModelConfig``): The model to be trained. + You need to provide the model type and its arguments in ``ModelConfig`` class. + Here is the definition of ``ModelConfig``: + + .. code-block:: python + + @dataclass + class ModelConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + +* ``optimizer`` (``OptimizerConfig``): The optimizer to be used. + + .. code-block:: python + + @dataclass + class OptimizerConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + clip_gnorm: float = 0.0 + + loss_reduction: str = 'mean' + grad_reduction: str = 'mean' + aggregate_outputs_fn: str = None + + * ``type`` (``str``): The optimizer type or factory function. + Please note the first parameter of the optimizer constructor must be the model parameters. + * ``args`` (``Dict[str, Any]``): The arguments of the optimizer. + * ``clip_gnorm`` (``float``): The maximum norm value for gradient clipping. 0.0/None means no clipping. + * ``loss_reduction`` (``str``): The reduction method for loss. + It can be ``mean`` (average the loss over all micro-batches), + ``sum`` (sum the loss of all micro-batches). + Default is ``mean``. + Please note in validation stage, this configuration is ignored the loss is always averaged over all batches + * ``grad_reduction`` (``str``): The reduction method for gradients. It can be ``mean`` (average the gradients over all micro-batches), ``sum`` (sum the gradients of all micro-batches), ``per-token-mean`` (average the gradients over all tokens). Default is ``mean``. Please note if ``per-token-mean`` is used, you need to specify ``aggregate_outputs_fn``, which will return the number of tokens + * ``aggregate_outputs_fn`` (``str``): The function to aggregate the outputs of the model. It is required when ``grad_reduction`` is ``per-token-mean``. Its signature should be ``def aggregate_outputs(self, loss_outputs, sync_group) -> AggregatedOutputs``, where ``loss_outputs`` is a list of outputs of the model, and ``sync_group`` is the ``torch.distributed.ProcessGroup`` to sync with. The function should return an ``AggregatedOutputs`` object, which defines as: + + .. code-block:: python + + @dataclass + class AggregatedOutputs: + # the aggregated loss as a sum + loss_sum: float = None + # number of mini batches + num_batches: int = None + # number of tokens (necessary when grad_reduction is 'per-token-mean') + num_tokens: Optional[int] = None + # any other custom outputs + aggregated_outputs: Any = None + +* ``dataset`` (``DatasetConfig``): The dataset to be used. + + .. code-block:: python + + @dataclass + class DatasetConfig: + type: str = None + train_args: Dict[str, Any] = field(default_factory=dict) + val_args: Dict[str, Any] = field(default_factory=dict) + + * ``type`` (``str``): The dataset type or factory function. + * ``train_args`` (``Dict[str, Any]``): The arguments of the training dataset. + * ``val_args`` (``Dict[str, Any]``): The arguments of the validation dataset. +* ``dataloader`` (``DataloaderConfig``): The dataloader to be used. + Please note we recommend to pass ``drop_last=True`` in the dataloader arguments to avoid the last batch with different sizes. + + .. code-block:: python + + @dataclass + class DataloaderConfig: + type: str = 'torch.utils.data.DataLoader' + train_args: Dict[str, Any] = field(default_factory=dict) + # default to train_args + val_args: Dict[str, Any] = field(default_factory=dict) + # default to train_args + test_args: Dict[str, Any] = field(default_factory=dict) + + * ``type`` (``str``): The dataloader type or factory function. + Please note the dataloader constructor must at least have 3 parameters ``dataset``, ``batch_size``, ``sampler``. + * ``train_args`` (``Dict[str, Any]``): The arguments (except ``dataset``,``batch_size``, ``sampler``) of the training dataloader. + Argument ``batch_size`` will be set to ``micro_batch_size``. + * ``val_args`` (``Dict[str, Any]``): The arguments (except ``dataset``,``batch_size``, ``sampler``) of the validation dataloader. + +* ``dataset_sampler`` (``DatasetSamplerConfig``): The dataset sampler to be used. + + .. code-block:: python + + @dataclass + class DatasetSamplerConfig: + type: str = 'torch.utils.data.DistributedSampler' + train_args: Dict[str, Any] = field(default_factory=dict) + val_args: Dict[str, Any] = field(default_factory=dict) + test_args: Dict[str, Any] = field(default_factory=dict) + + * ``type`` (``str``): The dataset sampler type or factory function. + Please note the dataset sampler constructor must at least have 3 parameters ``dataset``, ``num_replicas``, ``rank``. + * ``train_args`` (``Dict[str, Any]``): The arguments (except ``dataset``,``num_replicas``, ``rank``) of the training dataset sampler. + * ``val_args`` (``Dict[str, Any]``): The arguments (except ``dataset``,``num_replicas``, ``rank``) of the validation dataset sampler. + +* ``lr_scheduler`` (``LRSchedulerConfig``): The learning rate scheduler to be used. This is optional. + + .. code-block:: python + + @dataclass + class LRSchedulerConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + interval: str = 'epoch' + + * ``type`` (``str``): The learning rate scheduler type or factory function. + Please note the first parameter of the learning rate scheduler constructor must be optimizer. + * ``args`` (``Dict[str, Any]``): The arguments of the learning rate scheduler. + * ``interval`` (``str``): The interval to update the learning rate. It can be ``epoch`` or ``step``. Default is ``epoch``. + +* ``log`` (``List[LogConfig]``): The loggers to be used. You can provide multiple loggers. + Currently we have two builtin loggers: ``TensorBoardLogger`` and ``WandbLogger``. + + .. code-block:: python + + @dataclass + class LogConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + + * ``type`` (``str``): The logger type or factory function. + * ``args`` (``Dict[str, Any]``): The arguments of the logger. + +* ``hook`` (``Union[HookConfig, HookMapConfig, None]``): The hooks to be used. + You can provide a hook with a hook class or a map of hook functions. + Please note if your ``model``/``optimizer``/``lr_scheduler`` inherit from ``TrainHook``, + their hook functions will be called automatically. + The order of the hook functions called is ``model`` -> ``optimizer`` -> ``lr_scheduler``, + and hooks passed with this config is called in the last. + + Hook class: + + .. code-block:: python + + @dataclass + class HookConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + + * ``type`` (``str``): The hook type or factory function. + * ``args`` (``Dict[str, Any]``): The arguments of the hook. + + Hook map: + + .. code-block:: python + + @dataclass + class HookMapConfig: + after_setup: str = None + + on_train_start: str = None + on_train_end: str = None + on_val_start: str = None + on_val_end: str = None + + on_epoch_start: str = None + on_epoch_end: str = None + + on_train_step_start: str = None + on_train_step_end: str = None + on_val_step_start: str = None + on_val_step_end: str = None + + after_aggregate_train_step_outputs: str = None + after_aggregate_val_step_outputs: str = None + + before_zero_grad: str = None + after_zero_grad: str = None + + before_gnorm_clip: str = None + after_gnorm_clip: str = None + + before_optimizer_step: str = None + after_optimizer_step: str = None + + on_load_checkpoint: str = None + on_save_checkpoint: str = None + + * ``after_setup`` (``str``): The hook function to be called after setting up the trainer. + Only be called when ``run_mode == 'run'``. + Signature: ``def after_setup(trainer: 'Trainer') -> None:`` + * ``on_train_start`` (``str``): The hook function to be called at the start of the training stage. Signature: ``def on_train_start(trainer: 'Trainer') -> None:`` + * ``on_train_end`` (``str``): The hook function to be called at the end of the training stage. Signature: ``def on_train_end(trainer: 'Trainer') -> None:`` + * ``on_val_start`` (``str``): The hook function to be called at the start of the validation stage. Signature: ``def on_val_start(trainer: 'Trainer') -> None:`` + * ``on_val_end`` (``str``): The hook function to be called at the end of the validation stage. Signature: ``def on_val_end(trainer: 'Trainer', val_loss: float) -> None:`` + * ``on_epoch_start`` (``str``): The hook function to be called at the start of each epoch. Signature: ``def on_epoch_start(trainer: 'Trainer', epoch: int) -> None:`` + * ``on_epoch_end`` (``str``): The hook function to be called at the end of each epoch. Signature: ``def on_epoch_end(trainer: 'Trainer', epoch: int) -> None:`` + * ``on_train_step_start`` (``str``): The hook function to be called at the start of each training step. + Signature: ``def on_train_step_start(trainer: 'Trainer', batches: List[Any], idx: int) -> None:`` + * ``on_train_step_end`` (``str``): The hook function to be called at the end of each training step. Signature: ``def on_train_step_end(trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None:`` + * ``on_val_step_start`` (``str``): The hook function to be called at the start of each validation step. Signature: ``def on_val_step_start(trainer: 'Trainer', batches: List[Any], idx: int) -> None:`` + * ``on_val_step_end`` (``str``): The hook function to be called at the end of each validation step. Signature: ``def on_val_step_end(trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None:`` + * ``after_aggregate_train_step_outputs`` (``str``): The hook function to be called after aggregating the outputs of the model in the training step. Signature: ``def after_aggregate_train_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None:`` + * ``after_aggregate_val_step_outputs`` (``str``): The hook function to be called after aggregating the outputs of the model in the validation step. Signature: ``def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None:`` + * ``before_zero_grad`` (``str``): The hook function to be called before zeroing the gradients. Signature: ``def before_zero_grad(trainer: 'Trainer') -> None:`` + * ``after_zero_grad`` (``str``): The hook function to be called after zeroing the gradients. Signature: ``def after_zero_grad(trainer: 'Trainer') -> None:`` + * ``before_sync_grad`` (``str``): The hook function to be called before syncing the gradients between ranks. + Please note this hook can't be triggered correctly, + and you should not reply on this. + Will fix it later. + Signature: ``def before_sync_grad(trainer: 'Trainer') -> None:`` + * ``after_sync_grad`` (``str``): The hook function to be called after syncing the gradients between ranks. + Signature: ``def after_sync_grad(trainer: 'Trainer') -> None:`` + * ``before_gnorm_clip`` (``str``): The hook function to be called before gradient clipping. + Signature: ``def before_gnorm_clip(trainer: 'Trainer') -> None:`` + * ``after_gnorm_clip`` (``str``): The hook function to be called after gradient clipping. + Signature: ``def after_gnorm_clip(trainer: 'Trainer', gnorm: torch.Tensor) -> None:`` + * ``before_optimizer_step`` (``str``): The hook function to be called before the optimizer step. + Signature: ``def before_optimizer_step(trainer: 'Trainer') -> None:`` + * ``after_optimizer_step`` (``str``): The hook function to be called after the optimizer step. + Signature: ``def after_optimizer_step(trainer: 'Trainer') -> None:`` + * ``on_load_checkpoint`` (``str``): The hook function to be called after loading the checkpoint. + If you saved something with ``on_save_checkpoint`` this is your chance to restore this. + Signature: ``def on_load_checkpoint(trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None:`` + * ``on_save_checkpoint`` (``str``): The hook function to be called before saving the checkpoint. + If you want to save something, you can add it to the checkpoint here. + Signature: ``def on_save_checkpoint(trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None:`` + +Compute Config +============== + +.. _end2end: + +All compute configs are put in ``compute_config`` (``ComputeConfig``). Please refer to :ref:`ComputeConfig ` for more information. + +Please note only end2end mode is supported in the trainer, so you must set ``compute_config.use_end2end`` to ``True`` to make it work. + +An end2end module is a module which satisfies: + +* the first argument of ``module.forward`` is the data sample, and every other argument should have default value, + and use its default value in ``module.forward`` function. +* the first return value of ``module.forward`` is the loss (scalar tensor) + +Checkpoint Config +================= + + .. code-block:: python + + @dataclass + class CheckpointConfig: + save_dir: str = './checkpoints' + no_save: bool = False + + save_type: str = 'sharded' + + save_last: bool = True + save_best: bool = True + symlink_best_and_last: bool = True + + every_n_train_steps: Optional[int] = None + every_n_epochs: Optional[int] = None + keep_last_n_checkpoints: Optional[int] = None + + resume_from: str = None + +* ``save_dir`` (``str``): The directory to save the checkpoints. +* ``no_save`` (``bool``): Whether to save the checkpoints. Default is ``False``. +* ``save_type`` (``str``): The type of saving checkpoint. It can be ``sharded`` or ``deduped``. Default is ``sharded``. + + * ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. + The checkpoint is a folder with as many files as the world size. + * ``"deduped"``: Each rank saves its deduped shard of weights and optimizer states to a file. + The checkpoint is a folder with as many files as the world size. + * ``"merged"``: everything has been merged into a single file. + Used internally only when you merge the checkpoint files via ``Trainer.merge_checkpoints`` + +* ``save_last`` (``bool``): Whether to save the last checkpoint. Default is ``True``. +* ``save_best`` (``bool``): Whether to save the best (lowest ``val_loss``) checkpoint. Default is ``True``. +* ``symlink_best_and_last`` (``bool``): Whether to use symlink (instead of copy) to the best and last checkpoint. Default is ``True``. +* ``every_n_train_steps`` (``Optional[int]``): Save the checkpoint every ``every_n_train_steps`` training steps. Default is ``None``, which means no checkpoint is saved based on training steps. +* ``every_n_epochs`` (``Optional[int]``): Save the checkpoint every ``every_n_epochs`` epochs. Default is ``None``, which means no checkpoint is saved based on epochs. +* ``keep_last_n_checkpoints`` (``Optional[int]``): Keep the last ``keep_last_n_checkpoints`` checkpoints. If we have more than ``keep_last_n_checkpoints`` checkpoints, we will remove the oldest ones. + Default is ``None``, which means all checkpoints are kept. +* ``resume_from`` (``str``): The path to the checkpoint to resume from. It can be ``last``/``best``/a specific folder/file. + We will not resume (nor report error) if resume_from is ``last`` or ``best`` but the corresponding checkpoint does not exist. + Default is ``None``. + +Please note + +#. When the parallel plan is changed (i.e you re-trace the model with different configurations), + the checkpoints become incompatible, and can't be loaded any more. + You must firstly merge the checkpoints to a merged checkpoint with ``Trainer.merge_checkpoint`` and then load the merged checkpoint just like a regular checkpoint. + + .. code-block:: python + + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): + + where ``checkpoint_files`` is a list of checkpoint files to merge, and ``output_file`` is the output file path. + +#. When a checkpoint is saved, + we will run validation on the validation dataset and save the validation loss to the checkpoint file. + The validation run will ignore the ``val_every_n_train_steps`` and ``val_every_n_epochs`` configurations. + If no valid dataset is provided, validation is skipped and ``valid_loss`` is set to ``train_loss`` by default. + +#. The sharded checkpoints will contain PyTorch's RNG state, but not Python's or NumPy's. + The checkpoint's RNG state will be resumed right before training start, which means the initialization stage will use `TrainerArgs.seed` instead. + Merged checkpoints will discard the RNG state. + +Other configs +============= + +* ``gen_savedir`` (``str``): The directory to save the generated files. Default is ``./.nnscaler``. +* ``gen_reuse`` (``str``): the reuse strategy of the generated code, it can be + + * ``auto``: automatically decide the reuse strategy (``moo`` for ``compile``, ``match`` for ``run``) + * one of ``match``/``override``/``moo``/``graph``. See ``parallelize`` API for more information. + +* ``pas_policy`` (``str``): The policy of parameter partitioning. Default is ``autodist``. + You can pass builtin pas policy name or your own pas policy function. + See ``parallelize`` API for more information. +* ``broadcast_strategy`` (``str``): The strategy of broadcasting the model. Default is ``all``. See ``parallelize`` API for more information. +* ``instance_name`` (``str``): The instance name of the trainer. Default is ``None``. See ``parallelize`` API for more information. +* ``run_mode`` (``str``): The run mode of the trainer. + It can be ``run`` (compile and train the model in a single python script OR train from previous compiling results) + and ``compile`` (only compile the model for code generation). + Default is ``run``. + Please note you can only use ``run`` mode with ``torchrun``. + On the other hand, if you disable broadcasting generated files (by setting ``broadcast_strategy`` to ``none``), + you can run ``compile`` mode without ``torchrun``. +* ``tracing_from_weights`` (``str``): The path to the weights to be loaded when tracing(compiling) the model. It is only used in tracing to serve as the initial state dict of the model. Default is ``None``. +* ``precison``(``Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None]``): The precision of the model. It can be a ``str``, which means the same precision for all tensors, or a ``Dict[_TENSOR_TYPE, _PRECISION_TYPE]``, which means the precision for each tensor type. Default is ``None``. Currently we support 3 tensor types (``param``, ``buffer``, ``input``) and three precisions (``fp32``, ``fp16``, ``bf16``). You can set precision to ``none`` to avoid any precision conversion. +* ``micro_batch_size`` (``int``): The micro batch size. Default is ``1``. +* ``global_batch_size`` (``Optional[int]``) and ``grad_accumulation_steps`` (``Optional[int]``): You can set one of ``global_batch_size`` and ``grad_accumulation_steps`` option. Please note if both are set, they must be consistent. Default is ``micro_batch_size*scaling_factor`` and ``1`` respectively. +* ``max_epochs`` (``Optional[int]``): The maximum number of epochs to train. Default is ``None``, which means no limit. +* ``max_train_steps`` (``Optional[int]``): The maximum number of training steps to train. Default is ``None``, which means no limit. +* ``max_val_steps`` (``Optional[int]``): The maximum number of validation steps to validate. Default is ``None``, which means no limit. +* ``val_every_n_train_steps`` (``Optional[int]``): Validate every ``val_every_n_train_steps`` training steps. Default is ``None``, which means no validation based on training steps. +* ``val_every_n_epochs`` (``Optional[int]``): Validate every ``val_every_n_epochs`` epochs. Default is ``1``. +* ``enable_progress_bar`` (``bool``): Whether to enable the progress bar. Default is ``True``. +* ``seed`` (``Optional[int]``): The random seed. Default is ``None``. +* ``init_env_fn`` (``str``): The function to initialize the environment. Default is ``None``. + +*** +CLI +*** + +You can run the trainer with the following command: + +.. code-block:: bash + + torchrun [torchrun arguments] ${NNSCALER_HOME}/cli/train.py -f ${CONFIG_FILE} [other arguments] + +CONFIG_FILE is the path to the configuration yaml file. It looks like (taken from our test case) + +.. code-block:: yaml + + compute_config: + plan_ngpus: 4 + runtime_ngpus: 100 + constant_folding: true + use_zero: true + use_end2end: true + + run_mode: run + pas_policy: autodist + micro_batch_size: 2 + global_batch_size: 8 + max_epochs: 4 + max_train_steps: 10 + + model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + + optimizer: + type: torch.optim.Adam + args: + lr: 0.01 + + dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + + checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped + +All the arguments in the yaml file are the same as the arguments in the ``TrainerArgs`` class. +And they can be override with the command line arguments. +For example, you can override the ``max_epochs`` with ``--max_epochs 2``, or override the ``model`` with ``--model.args.dim 32 --model.args.nlayers 32``. + +*********************** +Appendix: ComputeConfig +*********************** + +.. _computeconfig: + +ComputeConfig +============= + +The configuration of the compute environment. It is a dataclass with the following fields: + +.. code-block:: python + + @dataclass(frozen=True) + class ComputeConfig: + plan_ngpus: int + runtime_ngpus: int + + constant_folding: bool = False + trace_strategy: Literal['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'] = 'cuda_run_cpu_offload' + + use_zero: bool = False + zero_ngroups: int = 1 + + inference_only : bool = False + use_end2end: bool = False + + pas_config: Dict[str, Any] = field(default_factory=dict) + user_config: Dict[str, Any] = field(default_factory=dict) + +We can categorize the fields into 4 categories: + +#. Trace configuration + + * ``constant_folding``: whether to enable constant folding when generating code. + When it is true, all non-tensor non-input values will be folded into the generated code. + + For example, if user's code contains following snippet, and ``bsz=1``, ``num_heads=32``, ``len=1024``, ``hidden_dim=128`` at tracing. + + .. code-block:: python + + bsz, num_heads, len, hidden_dim = x.size() + x = x.view(bsz * num_heads, len, hidden_dim) + + The code (graph) is folded into the following format + + .. code-block:: python + + y = x.view(32, 1024, 128) + + Constant folding is helpful to simplify the input program, + and can make the compiling process faster and reduce the communication cost at runtime. + However, user should make sure that inputs at runtime share a same schema (including shape) with tracing and correspond to a same computation graph. + Errors may be raised at runtime when this assumption is broken. + * ``trace_strategy``: how to execute the functions during trace. + Five strategies are supported: + + #. ``cpu``: Execute all functions on cpu device, model weights and intermediate results are on cpu device. + #. ``cuda``: Execute all functions on cuda device, model weights and intermediate results are on cuda device. This strategy is recommended if the model can inference on single gpu. + #. ``meta``: Execute all functions on meta device, model weights are on cpu and intermediate results are on meta device. For more information about meta device type, please view https://pytorch.org/docs/stable/meta.html. + #. ``cuda_run_cpu_offload``: Try to execute all functions on cuda, and retry to execute the function on cpu as backup if OOM is catched, model weights and intermediate results are on cpu. This strategy is recommanded for most case if the model is too large to inference on single gpu. + #. ``reuse_cache``: Compared to ``cuda_run_cpu_offload`` strategy, maintains a map from function signatures to output values. The cached output is returned when the signature of the function that generates it has been executed. Same signature means the funtions are the same and have almost the same inputs (for tensor type input, just check if they have same tensor meta data[shape, dtyep, requires_grad, stride, memory_format, ...], and don't check the value). This strategy is an experimental strategy to speedup the large-model-large-input case, and have risk to trace an incorrect graph if the signature defined here can not distinguish the differnet functions used in the model, for example, torch.nonzero will always return the same result if the input have same meta data but different value. We have plan to continue improve this strategy to handle most these kind of data dependence cases, but please note that the risk is still inevitable. + +#. Compute environment configuration + + * ``plan_ngpus``: the number of gpus to be used as a unit. The model is partitioned (TP or PP) within a unit, and then data parallelism is applied across multiple units. So every ``plan_ngpus`` devices holds the whole model. Furthermore, assume we have two workers, and their ranks are ``rank1`` and ``rank2``: + + #. if ``rank1 // plan_gpus == rank2 // plan_ngpus``, then they are in the same unit. + #. If ``rank1 % plan_ngpus == rank2 % plan_ngpus``, then the portion of model hold on both gpus are exactly the same. + + * ``runtime_ngpus``: the number of gpus to be used in runtime. It should be a multiple of ``plan_ngpus``, which means we have ``runtime_ngpus // plan_ngpus`` units in runtime, and the data parallelism is ``runtime_ngpus // plan_ngpus``. + Please note all modules must have the same ``plan_ngpus`` and ``runtime_ngpus``. + +#. Code generation feature configuration + + * ``use_zero``: whether to use zero. If it is true, the generated code will use zero1 to do distributed training. + * ``zero_ngroups``: the number of groups to be used in zero. + * ``inference_only``: whether to generate code for inference only. If it is true, the generated code can not be used to train the model. + * ``use_end2end``: whether to use end2end training. For the requirement of end2end, see the description above. + * ``pas_config``: the configuration for the PAS policy (partition-assign-schedule policy, which describes how to place all computations across devices. For details, please refer to :ref:`PAS Policies `. + It is a dictionary, and will be used by the PAS policy. + Please note different PAS will have different configurations, + You can also put any other settings that can affect code generation here. but please prefix the keys with ``_`` to avoid conflicts with PAS configurations. + * ``user_config``: the user configuration, which is used to decide whether skipping compiling and reusing the previously traced graph. + +Note: + +#. You can put any custom configurations in ``user_config``. The assumption is different ``user_config`` should generate different graph/code. So if the user config is changed, we will regenerate the graph/code automatically. Here are some examples: + + * Example 1: save module configuration + + .. code-block:: python + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + ... + if module_config.use_3d: + ... + + here we can set ``user_config`` to ``{'use_3d': module_config.use_3d}``, + and we can be sure different use_3d config will never use the same graph (and eventually the generated code). + + * Example 2: save file stats + + If you want to track all related file stats (just like traditional compilers do), + you can save the md5 of the files to save some bytes: + + .. code-block:: python + + import hashlib + h = hashlib.md5() + for f in Path('./src').glob('**/*.py'): + with open(f, 'rb') as f: + h.update(f.read()) + compute_config = { + ...., + user_config: { + 'files_md5': h.hexdigest() + } + } + +#. If some settings doesn't affect tracing/graph generation, but do affect code generation, you can put them in ``pas_config``. Please prefix the keys with ``_`` to avoid conflicts with predefined PAS configurations. One typical example is you can put the name of selected PAS policy in ``pas_config``, so changing PAS policy will regenerate code but the graph will be reused. + + .. code-block:: python + + compute_config = ComputeConfig( + ... + pas_config={ + '_pas_name': ..., + # PAS policy specific configurations + ... + }, + ) + +ReuseType +========= + +The reuse policy for the existing generated code. It is an enum with the following values: + +.. code-block:: python + + class ReuseType(Enum): + MATCH = 'match' + OVERRIDE = 'override' + MOO = 'moo' + GRAPH = 'graph' + +We call it a ``match`` when the ``ComputeConfig`` is the same with the previous run. + +#. ``MATCH``: Reuse if match, error if not match, generate if no previous gerenated code exists. +#. ``OVERRIDE``: Nothing will be reused. Everything will be regenerated. +#. ``MOO``: ``MOO`` is short for 'match or override'. It will reuse if match, generate if not match or no previous generated code exists. +#. ``GRAPH``: Reuse graph only if match, generate otherwise. + +.. _pas-policies: + +PAS Policies +============ + +Writing a pas policy can be very hard and error-prone. So we provide 6 builtin PAS policies to help you. ``dp``, ``tp``, ``pp``, ``data``, ``hybrid``, and ``autodist``. Please note only ``autodist`` policy is the recommended policy for most cases, and all other PAS policies are mainly test purpose only. + +The configuration of the PAS policy should be passed in the ``pas_config`` of ``ComputeConfig`` as a dictionary. + +#. ``dp``: data parallelism. It will replicate the module across all devices, and run data parallelism across all devices. It requires the ``plan_ngpus`` must be 1 and no configurations + +#. ``tp``: tensor parallelism + data parallelism. It will do tensor parallelism inside a scale unit, and run data parallelism across scale units. It has only one configuration: + + * seed: the random seed for choose the partition dimension. Default is ``1`` + +#. ``pp``: pipeline parallelism + data parallelism. + It will do model parallelism inside a scale unit, + and run data parallelism across scale units. + It requires the ``use_end2end`` be true. + It has two configurations ``pipeline_nmicros`` and ``pipeline_scheduler``. + See ``hybrid`` policy for more details. + +#. ``data``: tensor parallelism on batch dimension. It has no configurations. + +#. ``hybrid``: pipeline parallelism + tensor parallelism + data parallelism. + It will do model parallelism and tensor parallelism(on 0 dimension) inside a scale unit, + and run data parallelism across scale units. + It requires the ``use_end2end`` to be true. It has the following configurations. + + * ``pipeline_nstages``: the number of stages in the pipeline. Default is ``plan_ngpus``. Optional. + * ``pipeline_nmicros``: the number of microbatches in the pipeline. Required. + * ``pipeline_scheduler``: the scheduler name for the pipeline. Current we support four schedulers in training ``1f1b``/``1f1b_plus``/``gpipe``/``chimera_direct`` (4 stages pipeline only), and one scheduler in inference ``infer_pipe``. Default is ``1f1b``. Optional. + +#. ``autodist``: the recommended policy for most cases. Currently it only support Adam-like optimizers. It will automatically choose the best partition for you by balancing the memory usage and speed. It has the following configurations. + + * ``update_freq (int)``: the update frequency when training the module. Default is 1. Optional. + * ``mem_constraint (float)``: The memory constraint in each device in GB. Optional. + * ``task_name (str)``: The name of the current task to distinguish runs. Optional. + * ``use_fp16 (bool)``: Whether you use ``fp16``. Default is ``False``. Optional. + * ``use_memory_efficient_fp16`` Whether you use memory efficient fp16 optimizer. Default is ``False``. Optional. + * ``use_bf16``: Whether you use ``bf16``. Default is ``False``. Optional. + * ``use_memory_efficient_bf16``: Whether you use memory efficient bf16 optimizer. Default is ``False``. Optional. + * ``re_profile (bool)``: If set to ``True``, the computation profiling results will be overridden. Please note reprofiling will take some time. Optional. + * ``verbose (bool)``: Whether to print verbose information. Optional. + * ``load_plan_path (str)``: The path to the plan file to load. If specified, the plan will be loaded from the file instead of searching. Optional. + * ``save_plan_path (str)``: The path to the plan file to save. Optional. + * ``partition_constraints_path (str)``: The path to the partition constraints file. Optional. + * ``recompute_modules (str)``: The module names to recompute, separated by ``,``. For example, ``module1,module2``. Optional. + * ``pipeline_pivots (str)``: The module names to pivot the pipeline, separated by ``,``. For example, if ``module1,module2`` is specified, stages searched by pipeline solver only start from either ``module1`` or ``module2``. Optional. + * ``use_apex_fused_adam_v2``: If set to ``True``, the apex fused adam v2 optimizer will be used. Default is ``False``. Optional. + * ``explore_pipeline``: If set to ``True``, autodist will try pipeline parallelism to find the best partition plan + (but the selected partition plan is not necessarily pipeline parallelism). + * ``pipeline_scheduler``: The scheduler name for the pipeline. Please note currently ``1f1b`` is the only supported scheduler in ``autodist``. Default is ``1f1b``. Optional. + * ``parallel_profile``: If set to ``True``, autodist will profile operators in parallel by using available gpus. Default is ``True``. Optional. + * ``max_partition_degree``: Max degree when partitioning an operator / node. When pipeline parallelism is enabled to explore (``explore_pipeline`` is True), user can change the value to constrain the plan to be composed of stages that span on less or equal to ``max_partition_degree`` devices (recommend to set ``max_partition_degree`` to the number of devices in a node to avoid inter-node communication, but should be be no more than ``plan_ngpus``). Default is ``plan_ngpus``. Optional. + * ``transient_mem_coef``: In autodist, a heuristic is used to estimate the transient memory size: ``transient_mem_size = opt_transient_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)``. This formula is useful in many cases, but it may be too strict when some operators consume or generate a large tensor (>= 4GB). In this case, you can set ``transient_mem_coef`` to a smaller value to relax the constraint. Default is ``2``. Optional. + +You can also put any other settings that can affect code generation here. but please prefix the keys with ``_`` to avoid conflicts with predefined keys. + +Here is an example: + +.. code-block:: python + + compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + pas_config={ + '__pas_name': ..., # addtional configurations that can affect code generation. + 'update_freq': ..., + 'mem_constraint': ..., + 'task_name': ..., + 'use_fp16': ..., + 'use_memory_efficient_fp16': ..., + 'use_bf16': ..., + 'use_memory_efficient_bf16': ..., + 're_profile': ..., + 'verbose': ..., + 'load_plan_path': ..., + 'save_plan_path': ..., + 'partition_constraints_path': ..., + 'recompute_modules': ..., + 'pipeline_pivots': ..., + 'use_apex_fused_adam_v2': ..., + }, + ) diff --git a/docs/source/troubleshooting.rst b/docs/source/troubleshooting.rst new file mode 100644 index 00000000..10a7a01f --- /dev/null +++ b/docs/source/troubleshooting.rst @@ -0,0 +1,291 @@ +############### +Troubleshooting +############### + +Reuse Cache +=========== + +I have modified the model but the result does not change +-------------------------------------------------------- + +Remove ``.nnscaler`` directory in the working path and try again. + +nnScaler's workflow is first compiling the model, and then running the compiled (generated) model. +After modifying the original model, you need to tell nnScaler to re-compile it. + +This can be achieved by two ways: + +1. Remove the compiled model (located in ``.nnscaler`` directory); +2. Set ``TrainerArgs.gen_reuse`` to ``"override"``. + +We recommend to set ``gen_reuse="override"`` to debug the model, +and change it to ``gen_reuse="auto"`` for deployment. + +.. code-block:: python + + trainer_args = TrainerArgs( + gen_reuse='override', + ... + ) + trainer = Trainer(trainer_args=trainer_args) + trainer.run() + +Note that setting ``gen_reuse="match"`` will NOT solve this problem, +since it only checks ``compute_config``, not the model. + +"RuntimeError: Output directory ... is not empty. And the existing files do not match..." after modifying models +---------------------------------------------------------------------------------------------------------------- + +As the error message said, please remove the ``.nnscaler`` directory. + +To prevent this kind of errors permanently, you can set ``gen_reuse`` to ``"override"``, at the expense of time. + +Example stacktrace: :: + + Traceback (most recent call last): + File "train.py", line 244, in + main() + File "train.py", line 240, in main + trainer.run() + File ".../nnscaler/cli/trainer.py", line 95, in run + self._setup() + File ".../nnscaler/cli/trainer.py", line 206, in _setup + pmodel_class = nnscaler.parallelize( + File ".../nnscaler/parallel.py", line 983, in parallelize + outdir, reusable = _prepare_and_check_reusable(gen_savedir, module_class, compute_config, instance_name, reuse) + File ".../nnscaler/parallel.py", line 547, in _prepare_and_check_reusable + raise RuntimeError(f'Output directory {outdir} is not empty. ' + RuntimeError: Output directory .../.nnscaler/_parallel_modules/__main__/WrapperModel/_ is not empty. And the existing files do not match with current config. You can remove the directory and try again, or set reuse to ReuseType.NONE/ReuseType.OVERRIDE to regenerate the code. + +Known Issues +============ + +"KeyError: '__mro__'" and errors mentioning "_dynamo" +----------------------------------------------------- + +Add ``import torch._dynamo`` to the beginning of your main script. + +Due to a limitation in nnScaler, the dynamic import of ``torch._dynamo`` cannot be correctly traced. +This can be workaround by importing it before tracing. + +Example stacktrace: :: + + Traceback (most recent call last): + File "train.py", line 286, in + trainer.run() + File ".../nnscaler/cli/trainer.py", line 95, in run + self._setup() + File ".../nnscaler/cli/trainer.py", line 206, in _setup + pmodel_class = nnscaler.parallelize( + File ".../nnscaler/parallel.py", line 993, in parallelize + regen_status = _gencode( + + ...... + + File ".../site-packages/transformers/models/llama/modeling_llama.py", line 1041, in _update_causal_mask + if AttentionMaskConverter._ignore_causal_mask_sdpa( + File ".../nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py", line 354, in patch_run + return new_func(*args, **kwargs) + File ".../site-packages/transformers/modeling_attn_mask_utils.py", line 259, in _ignore_causal_mask_sdpa + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + File ".../nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py", line 354, in patch_run + return new_func(*args, **kwargs) + File ".../site-packages/torch/__init__.py", line 2003, in __getattr__ + return importlib.import_module(f".{name}", __name__) + File ".../importlib/__init__.py", line 126, in import_module + return _bootstrap._gcd_import(name[level:], package, level) + + ...... + + File ".../site-packages/torch/_dynamo/utils.py", line 567, in unwrap_with_attr_name_if_wrapper + elif is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False): + File ".../inspect.py", line 1738, in getattr_static + if not _is_type(obj): + File ".../inspect.py", line 1707, in _is_type + _static_getmro(obj) + File ".../inspect.py", line 1685, in _static_getmro + return type.__dict__['__mro__'].__get__(klass) + KeyError: '__mro__' + +"ModuleNotFoundError: No module named 'nnscaler.autodist.dp_solver'" when using editable install +------------------------------------------------------------------------------------------------ + +Run the following command: :: + + python -c 'import os,sys,nnscaler,cppimport.import_hook ; sys.path.append(os.path.dirname(nnscaler.__path__[0])) ; import nnscaler.autodist.dp_solver' + +Example stacktrace: :: + + Traceback (most recent call last): + File "model.py", line 48, in + trainer.run() + File ".../nnscaler/cli/trainer.py", line 95, in run + self._setup() + File ".../nnscaler/cli/trainer.py", line 206, in _setup + pmodel_class = nnscaler.parallelize( + File ".../nnscaler/parallel.py", line 988, in parallelize + regen_status = _gencode( + File ".../nnscaler/parallel.py", line 753, in _gencode + graph = pas_policy(graph, compute_config) + File ".../nnscaler/policies.py", line 303, in pas_autodist + return parallelize_graph(graph, autodist_cfg) + File ".../nnscaler/autodist/apis.py", line 117, in parallelize_graph + search_out = calc_parallel_plan(graph, autodist_config) + File ".../nnscaler/autodist/apis.py", line 98, in calc_parallel_plan + pp_out = calc_optimal_spmd_plan(autodist_graph, autodist_config) + File ".../nnscaler/autodist/spmd_solver.py", line 1503, in calc_optimal_spmd_plan + spmd_outs = spmd_solver.solve([(0, model_graph.op_num - 1)], 1)[0] + File ".../nnscaler/autodist/spmd_solver.py", line 1374, in solve + return self.do_dp(intervals, topk) + File ".../nnscaler/autodist/spmd_solver.py", line 1183, in do_dp + import nnscaler.autodist.dp_solver as dp_solver + ModuleNotFoundError: No module named 'nnscaler.autodist.dp_solver' + +Incorrect Usages +================ + +"RuntineError: Loss can only be scalar tensor ..." when forward returns dict +---------------------------------------------------------------------------- + +When using nnScaler's Trainer, the return value of the top-level ``forward()`` must not be a dict. +It can either be: + +1. A loss tensor; +2. A tuple where the first element is a loss tensor. + +Detailed explaination: :ref:`end2end model `. + +How to fix: + +.. code-block:: diff + + def forward(self, data): + ... + -return {'loss': loss, 'ntokens': ntokens} + +return loss, ntokens + +Example stacktrace: :: + + Traceback (most recent call last): + File "example.py", line 27, in + trainer.run() + File ".../nnscaler/cli/trainer.py", line 95, in run + self._setup() + File ".../nnscaler/cli/trainer.py", line 206, in _setup + pmodel_class = nnscaler.parallelize( + File ".../nnscaler/parallel.py", line 988, in parallelize + regen_status = _gencode( + File ".../nnscaler/parallel.py", line 737, in _gencode + graph, forward_args = _gen_graph( + File ".../nnscaler/parallel.py", line 656, in _gen_graph + raise RuntimeError(f"Loss can only be scalar tensor but got {ir_loss.shape if isinstance(ir_loss, IRTensor) else ir_loss}") + RuntimeError: Loss can only be scalar tensor but got {'loss': t1596(p920,(1,),d(),v(0/1)), 'ntokens': t1597(p922,(1,),d(),v(0/1))} + +"TypeError: ... 'device_type' must be str, not ConcreteAttrProxy" when using torch>=2.4 +--------------------------------------------------------------------------------------- + +nnScaler does not support torch 2.4 yet. +Downgrade to torch 2.3.* will fix the issue: :: + + pip install "torch<2.4" + +Example stacktrace: :: + + Traceback (most recent call last): + File "model.py", line 43, in + trainer.run() + File ".../nnscaler/cli/trainer.py", line 95, in run + self._setup() + File ".../nnscaler/cli/trainer.py", line 206, in _setup + pmodel_class = nnscaler.parallelize( + File ".../nnscaler/parallel.py", line 988, in parallelize + regen_status = _gencode( + + ...... + + File ".../nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py", line 354, in patch_run + return new_func(*args, **kwargs) + File ".../torch/amp/autocast_mode.py", line 237, in __init__ + if not is_autocast_available(self.device): + File ".../torch/amp/autocast_mode.py", line 36, in is_autocast_available + return torch._C._is_autocast_available(device_type) + TypeError: _is_autocast_available(): argument 'device_type' (position 1) must be str, not ConcreteAttrProxy + +Flash Attention Problems +======================== + +"NameError: name 'flash_attn' is not defined" +--------------------------------------------- + +When using flash attention, it must be registered with ``register_op`` API. +Check :doc:`the llama 3 example ` for its usage. + +Example stacktrace: :: + + Traceback (most recent call last): + File "train.py", line 247, in + trainer.run() + File ".../nnscaler/cli/trainer.py", line 98, in run + self._train() + File ".../nnscaler/cli/trainer.py", line 558, in _train + self._train_epoch(epoch) + File ".../nnscaler/cli/trainer.py", line 698, in _train_epoch + losses = self.model.train_step(batches, is_dummy_batch) + File ".../nnscaler/runtime/module.py", line 967, in train_step + output = self._train_step(dataloader) + File ".nnscaler/_parallel_modules/__main__/WrapperModel/_/gencode0.py", line 1228, in _train_step + cross_entropy_1433, getitem_62_1431 = nnscaler.runtime.executor.fexecute('segment1977', model.segment1977, *(data_1780, ), requires_grad=True) + File ".../nnscaler/runtime/executor.py", line 105, in fexecute + outputs = subgraph(*input_dtensors) + File ".nnscaler/_parallel_modules/__main__/WrapperModel/_/gencode0.py", line 452, in segment1977 + add_7_2220, add_7_2221 = ckpt.checkpoint(recompute, unsqueeze_1439, embedding_2130, embedding_2131, use_reentrant=False) + File ".../site-packages/torch/_compile.py", line 24, in inner + return torch._dynamo.disable(fn, recursive)(*args, **kwargs) + File ".../site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn + return fn(*args, **kwargs) + File ".../site-packages/torch/_dynamo/external_utils.py", line 36, in inner + return fn(*args, **kwargs) + File ".../site-packages/torch/utils/checkpoint.py", line 494, in checkpoint + ret = function(*args, **kwargs) + File ".nnscaler/_parallel_modules/__main__/WrapperModel/_/gencode0.py", line 386, in recompute + apply_1495 = flash_attn.flash_attn_interface.FlashAttnFunc.apply(transpose_4_1492, transpose_5_1493, transpose_6_1494, ifexpr_930, None, True, (-1, -1), 0.0, None, False, False) + NameError: name 'flash_attn' is not defined + +"ImportError" when using flash attention +---------------------------------------- + +This is likely an error in flash attention itself. +Please try the related import command outside nnScaler. +If it still fails, please refer to `flash attention `_'s docs. + +If your ``flash-attn`` package is installed from pip, +you can try to use a wheel its `release page _` +which matches your environment more accurately. + +Example stacktrace: :: + + Traceback (most recent call last): + File "train.py", line 9, in + from modeling_modifier import nnscaler_llama_init + File "modeling_modifier.py", line 14, in + from transformers.models.llama.modeling_llama import LlamaAttention, LLAMA_ATTENTION_CLASSES, apply_rotary_pos_emb, LlamaRMSNorm + File ".../site-packages/transformers/models/llama/modeling_llama.py", line 53, in + from flash_attn import flash_attn_func, flash_attn_varlen_func + File ".../site-packages/flash_attn/__init__.py", line 3, in + from flash_attn.flash_attn_interface import ( + File ".../site-packages/flash_attn/flash_attn_interface.py", line 10, in + import flash_attn_2_cuda as flash_attn_cuda + ImportError: .../site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c105Error4whatEv + +Hugging Face Access +=================== + +"Access to model meta-llama/Meta-Llama-3-8B-Instruct is restricted. ... Please log in." +--------------------------------------------------------------------------------------- + +You need to request for `Llama 3 access `_ on Hugging Face first. +Once you get access, generates your `Hugging Face token `_ and export it: :: + + export HF_TOKEN=hf_... + +.. (FIXME: check it) Or alternatively, you can try replacing ``meta-llama/Meta-Llama-3-8B-Instruct`` with ``microsoft/Phi-3-mini-4k-instruct``. diff --git a/utility/verify_ops/verify_op.md b/docs/source/verify_op.md similarity index 96% rename from utility/verify_ops/verify_op.md rename to docs/source/verify_op.md index 46cfc5d2..ba70fdf0 100644 --- a/utility/verify_ops/verify_op.md +++ b/docs/source/verify_op.md @@ -1,16 +1,18 @@ -## verify-graph support -""" -Used to verify operations in IRGraph to ensure their functionality and consistency across single and multiple Gpus. -""" -## example: +# Verify-graph support +Used to verify operations in IRGraph to ensure their functionality and consistency across single and multiple GPUs. + Command-line interface for verifying operations in an IRGraph. Usage: +``` python verify_graph_operations.py --graph --outdir +``` Parameters: ---graph (str): Path to the graph checkpoint file (.ckp) to be loaded. This is the same graph used as the input for the pas policy. ---outdir (str): Directory where verification results will be saved. + + --graph (str): Path to the graph checkpoint file (.ckp) to be loaded. This is the same graph used as the input for the pas policy. + --outdir (str): Directory where verification results will be saved. + This script performs the following steps: 1. Load the IRGraph: Reads the graph checkpoint file specified by the `--graph` argument. @@ -22,13 +24,13 @@ This script performs the following steps: To test a module: you should first use parallelize to generate the required graph.ckp file, then test graph against the current script. -## verify-dimops support -""" +## Verify-dimops support Define a configuration for verifying partition options of a tensor operation. This configuration helps ensure that the operation's partitioning logic is valid by specifying the function signature, arguments, expected outputs, and partitioning options. -""" -## example 1: + +## Example of Conv2D +```python: This is used to verify that Conv2D's partition configuration is correct. This configuration defines a basic Conv2D operation with input Tensor, convolution kernel, and bias. ```python @dataclass @@ -68,7 +70,7 @@ conv2d_config = VerifyConfig( verify_partition_options(conv2d_config) ``` -## Examples for configuring different op +## Examples for more operators ``` dropout_config = VerifyConfig( diff --git a/nnscaler/profiler/README.md b/nnscaler/profiler/README.md index 86bf0577..9d91cb1a 100644 --- a/nnscaler/profiler/README.md +++ b/nnscaler/profiler/README.md @@ -21,7 +21,7 @@ for i in range(N): # our code ends prof.disabled() -prof.dump_stats('cube_RANK%d.prof' % torch.distributed.get_rank()) # or use TID/PID, if to profile multi-thread/-process program. +prof.dump_stats('nnScaler_RANK%d.prof' % torch.distributed.get_rank()) # or use TID/PID, if to profile multi-thread/-process program. ``` After the modification, run the Python file using the same command line with `torchrun` as usual. @@ -30,7 +30,7 @@ After dumping the profiling data, we can use `snakeviz` to visualize it: ```shell pip install snakeviz -snakeviz cube_RANK0.prof +snakeviz nnScaler_RANK0.prof ``` ### Use viztracer From 1861c7c3e1a7e2fbfab535ce26fb371fca496df9 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 26 Nov 2024 03:31:25 +0000 Subject: [PATCH 1774/1892] Merged PR 2325: [BugFix] detach correctly when multiple outputs --- nnscaler/codegen/schedule/schedule.py | 3 +- tests/parallel_module/test_e2e_detach_loss.py | 36 +++++++++++++------ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/nnscaler/codegen/schedule/schedule.py b/nnscaler/codegen/schedule/schedule.py index 6e260b9d..8484ef5b 100644 --- a/nnscaler/codegen/schedule/schedule.py +++ b/nnscaler/codegen/schedule/schedule.py @@ -196,10 +196,11 @@ def emit_node(self, node: IRCell, force_no_grad: bool = False) -> List[str]: operation here so that the backward graph's tensors can be deallocated right after the backward pass. """ + plan_outputs = IRCell.get_objects_from_complex(self.execplan.outputs()) for tensor in output_tensors: if not isinstance(tensor, IRTensor): continue - if tensor in self.execplan.outputs(): + if tensor in plan_outputs: codes.append(self.emit_detach(tensor)) elif isinstance(unwrap_node, IRDataOperation): diff --git a/tests/parallel_module/test_e2e_detach_loss.py b/tests/parallel_module/test_e2e_detach_loss.py index aaa9648a..d478c5fc 100644 --- a/tests/parallel_module/test_e2e_detach_loss.py +++ b/tests/parallel_module/test_e2e_detach_loss.py @@ -37,6 +37,19 @@ def forward(self, x): return x.sum() +class Model2(torch.nn.Module): + def __init__(self): + super(Model2, self).__init__() + self.fc1 = torch.nn.Linear(4096, 4096, bias=False) + self.fc2 = torch.nn.Linear(4096, 4096, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + l = x.sum() + return l, l.data + + def policy_pp(graph, cfg): data_loader, fc1, fc2, loss = graph.nodes()[:4] graph.staging([fc1, fc2]) @@ -69,9 +82,9 @@ def policy_pp(graph, cfg): return graph -def worker_pipeline_2x2(): +def worker_pipeline_2x2(model_cls): nnscaler.init() - m = Model() + m = model_cls() m.train() torch.manual_seed(0) if torch.cuda.is_available(): @@ -89,7 +102,7 @@ def worker_pipeline_2x2(): ) if pm.rank in [2, 3]: - assert len(_gencode_contains(tempdir, Model, pm.rank, 'detach\(\)')) == 4 + assert len(_gencode_contains(tempdir, model_cls, pm.rank, 'detach\(\)')) == 4 samples = [torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) for _ in range(4)] ret = pm.train_step(samples) @@ -105,9 +118,11 @@ def worker_pipeline_2x2(): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') -def test_detach_loss_pipeline_hard(): +@pytest.mark.parametrize('model_cls', [Model, Model2]) +def test_detach_loss_pipeline_hard(model_cls): + torchrun(4, worker_pipeline_2x2, model_cls) # should not raise any exception - torchrun(4, worker_pipeline_2x2) + assert True def policy_easy(graph, cfg): @@ -132,9 +147,9 @@ def policy_easy(graph, cfg): return graph -def worker_pipeline_2(): +def worker_pipeline_2(model_cls): nnscaler.init() - m = Model() + m = model_cls() m.train() torch.manual_seed(0) if torch.cuda.is_available(): @@ -153,7 +168,7 @@ def worker_pipeline_2(): pm.to(torch.cuda.current_device()) if pm.rank == 1: - assert len(_gencode_contains(tempdir, Model, pm.rank, 'detach\(\)')) == 4 + assert len(_gencode_contains(tempdir, model_cls, pm.rank, 'detach\(\)')) == 4 samples = [torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) for _ in range(4)] ret = pm.train_step(samples) mem0 = get_mem() @@ -167,7 +182,8 @@ def worker_pipeline_2(): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') -def test_detach_loss_pipeline_easy(): - torchrun(2, worker_pipeline_2) +@pytest.mark.parametrize('model_cls', [Model, Model2]) +def test_detach_loss_pipeline_easy(model_cls): + torchrun(2, worker_pipeline_2, model_cls) # should not raise any exception assert True From 0975639a4d08c7169f8619460ecd82a6b7e35fa9 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Thu, 28 Nov 2024 04:54:48 +0000 Subject: [PATCH 1775/1892] Merged PR 2328: Merge changes from github release - Fix docs format - Update example README's dir hierarchy - Add missing license headers - Rename DaGAN to dagan --- docs/Makefile | 4 + docs/source/_static/nnscaler.css | 4 + docs/source/examples/dagan.rst | 1 + docs/source/examples/deepseek.md | 1 + docs/source/examples/llama.rst | 1 + docs/source/examples/llama3_demo.rst | 1 + docs/source/examples/nanogpt.rst | 1 + docs/source/examples/vit.md | 1 + docs/source/index.rst | 17 +- docs/source/installation.rst | 44 +++ docs/source/llama3_8b_128k_example.rst | 152 --------- docs/source/llama3_demo_example.rst | 90 ------ docs/source/nanogpt_example.rst | 187 ----------- docs/source/quickstart.rst | 16 +- docs/source/trainer.rst | 4 +- docs/source/troubleshooting.rst | 4 +- docs/source/verify_op.md | 3 +- examples/deepseek_coder_v2_lite/README.md | 16 +- .../modeling/configuration_deepseek.py | 5 +- .../modeling/modeling_deepseek.py | 2 +- examples/llama/README.md | 189 ----------- examples/llama/README.rst | 296 ++++++++++++++++++ examples/llama3_demo/README.rst | 91 +++++- examples/nanogpt/README.rst | 188 ++++++++++- examples/vision/DaGAN/README.md | 24 -- examples/vision/dagan/README.rst | 28 ++ examples/vision/{DaGAN => dagan}/dataset.py | 3 + .../vision/{DaGAN => dagan}/model_full.py | 5 +- .../vision/{DaGAN => dagan}/requirements.txt | 0 examples/vision/{DaGAN => dagan}/run.py | 5 +- .../vision/{DaGAN => dagan}/vox-adv-256.yaml | 0 examples/vit/README.md | 8 +- tests/cli/test_resume_seed.py | 3 + tests/graph/test_multiref.py | 5 +- tests/graph/tracer/test_op_context.py | 3 + tests/parallel_module/test_async.py | 3 + tests/parallel_module/test_e2e_detach_loss.py | 3 + .../test_end2end_mix_precision.py | 3 + 38 files changed, 730 insertions(+), 681 deletions(-) create mode 100644 docs/source/_static/nnscaler.css create mode 120000 docs/source/examples/dagan.rst create mode 120000 docs/source/examples/deepseek.md create mode 120000 docs/source/examples/llama.rst create mode 120000 docs/source/examples/llama3_demo.rst create mode 120000 docs/source/examples/nanogpt.rst create mode 120000 docs/source/examples/vit.md create mode 100644 docs/source/installation.rst delete mode 100644 docs/source/llama3_8b_128k_example.rst delete mode 100644 docs/source/llama3_demo_example.rst delete mode 100644 docs/source/nanogpt_example.rst delete mode 100644 examples/llama/README.md create mode 100644 examples/llama/README.rst mode change 120000 => 100644 examples/llama3_demo/README.rst mode change 120000 => 100644 examples/nanogpt/README.rst delete mode 100644 examples/vision/DaGAN/README.md create mode 100644 examples/vision/dagan/README.rst rename examples/vision/{DaGAN => dagan}/dataset.py (95%) rename examples/vision/{DaGAN => dagan}/model_full.py (98%) rename examples/vision/{DaGAN => dagan}/requirements.txt (100%) rename examples/vision/{DaGAN => dagan}/run.py (99%) rename examples/vision/{DaGAN => dagan}/vox-adv-256.yaml (100%) diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf1..2177cdf8 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -8,6 +8,10 @@ SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build +default: + rm -rf build + sphinx-build -M html source build + # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/_static/nnscaler.css b/docs/source/_static/nnscaler.css new file mode 100644 index 00000000..013d9f4d --- /dev/null +++ b/docs/source/_static/nnscaler.css @@ -0,0 +1,4 @@ +/* remove "nnScaler documentation" text below icon in the left side bar */ +.sidebar-brand-text { + display: none; +} diff --git a/docs/source/examples/dagan.rst b/docs/source/examples/dagan.rst new file mode 120000 index 00000000..361869e5 --- /dev/null +++ b/docs/source/examples/dagan.rst @@ -0,0 +1 @@ +../../../examples/vision/dagan/README.rst \ No newline at end of file diff --git a/docs/source/examples/deepseek.md b/docs/source/examples/deepseek.md new file mode 120000 index 00000000..a0efbeff --- /dev/null +++ b/docs/source/examples/deepseek.md @@ -0,0 +1 @@ +../../../examples/deepseek_coder_v2_lite/README.md \ No newline at end of file diff --git a/docs/source/examples/llama.rst b/docs/source/examples/llama.rst new file mode 120000 index 00000000..8ad88198 --- /dev/null +++ b/docs/source/examples/llama.rst @@ -0,0 +1 @@ +../../../examples/llama/README.rst \ No newline at end of file diff --git a/docs/source/examples/llama3_demo.rst b/docs/source/examples/llama3_demo.rst new file mode 120000 index 00000000..e79a819f --- /dev/null +++ b/docs/source/examples/llama3_demo.rst @@ -0,0 +1 @@ +../../../examples/llama3_demo/README.rst \ No newline at end of file diff --git a/docs/source/examples/nanogpt.rst b/docs/source/examples/nanogpt.rst new file mode 120000 index 00000000..da52a600 --- /dev/null +++ b/docs/source/examples/nanogpt.rst @@ -0,0 +1 @@ +../../../examples/nanogpt/README.rst \ No newline at end of file diff --git a/docs/source/examples/vit.md b/docs/source/examples/vit.md new file mode 120000 index 00000000..d30ceedb --- /dev/null +++ b/docs/source/examples/vit.md @@ -0,0 +1 @@ +../../../examples/vit/README.md \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index d8b433a7..126507b2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -45,9 +45,9 @@ Get Started =========== * :doc:`quickstart` -* :doc:`llama3_demo_example` -* :doc:`llama3_8b_128k_example` -* :doc:`nanogpt_example` +* :doc:`examples/llama3_demo` +* :doc:`examples/llama` +* :doc:`examples/nanogpt` Reference @@ -91,7 +91,7 @@ For any questions or inquiries, please contact us at nnscaler@service.microsoft. :caption: Get Started self - install_from_source + installation quickstart .. toctree:: @@ -99,9 +99,12 @@ For any questions or inquiries, please contact us at nnscaler@service.microsoft. :hidden: :caption: Examples - llama3_demo_example - llama3_8b_128k_example - nanogpt_example + examples/llama3_demo + examples/llama + examples/dagan + examples/vit + examples/deepseek + examples/nanogpt .. toctree:: :maxdepth: 1 diff --git a/docs/source/installation.rst b/docs/source/installation.rst new file mode 100644 index 00000000..6cedb7fb --- /dev/null +++ b/docs/source/installation.rst @@ -0,0 +1,44 @@ +############ +Installation +############ + +nnScaler can either be installed from the wheel package or from the source code. + +****************** +Install from Wheel +****************** + +The wheel package is hosted on `GitHub release `_. + +.. code-block:: bash + + pip install https://github.com/microsoft/nnscaler/releases/download/0.5/nnscaler-0.5-py3-none-any.whl + +************************ +Install from Source Code +************************ + +Editable Install +================ + +nnScaler uses ``pybind11`` and ``cppimport`` to dynamically build C++ modules. +The C++ modules must be manually compiled for an editable install. + +.. code-block:: bash + + git clone --recursive https://github.com/microsoft/nnscaler + cd nnscaler + pip install -e . + python -c "import cppimport.import_hook ; import nnscaler.autodist.dp_solver" + +Build a Wheel +============= + +Alternatively you can build the wheel package by yourself. + +.. code-block:: bash + + cd nnscaler + pip install build + python -m build + pip install dist/nnscaler-*.whl diff --git a/docs/source/llama3_8b_128k_example.rst b/docs/source/llama3_8b_128k_example.rst deleted file mode 100644 index 26701d6b..00000000 --- a/docs/source/llama3_8b_128k_example.rst +++ /dev/null @@ -1,152 +0,0 @@ -####################### -Llama 3 8B 128K Example -####################### - -************ -Introduction -************ - -This example demonstrates how to train llama3-8B-128k model with 8xH100s or 8xA100s. - -************ -Requirements -************ - -To run this example, you need to install the following packages: :: - - nnscaler - transformers==4.40.0 - datasets==2.20.0 - apex - flash-attn - -*nnScaler* is a framework for distributed training by automatically partitioning the model. -Apart from the core nnScaler library, it also includes a mini-trainer for modern model training. - -*transformers* and *datasets* are required to prepare the data and loading the Llama model. - -To speed up the training process, -`apex `_ and `flash-attn `_ are required. -You can install them by following instructions in their official repositories. -You may also launch the script in an Nvidia docker container, e.g., ``nvidia/pytorch:24.02-py3``. - -**************** -Data Preparation -**************** - -We use the `bookcorpus `_ dataset for training, which is tokenized with the `Meta-Llama-3-8B-Instruct `_ tokenizer. -Tokenized data is saved in the ``bookcorpus_llama3_128K`` directory. - -.. code-block:: bash - - python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_128K --sequence_length 131072 - -******** -Training -******** - -nnScaler adopts a compiler approach to launch the distributed training, which consists of two stages: - -#. Compile stage: trace the original PyTorch model into a dataflow graph, analyzing the graph to get an efficient plan for distributed training, and - generate python code based on the plan. -#. Runtime stage: run the generated python code to train the model. - -**Note**: We recommend to use well-tested config ``"_attn_implementation": "flash_attention_2"`` and ``"use_cache": false`` in the config file. - -Register Customized Function -============================ - -Llama3's vocabulary size is about 128K, which is much larger then the 32K in Llama2. At the same time the sequence length in this example is 128K, the output tensor size of the last projection layer is quite large: 128K x 128K x 2 bytes = 32GB. -Although this tensor can be partitioned evenly to 8 GPUs, 4GB memory is still quite large for limited GPU memory. What makes it worse is that we need to store additional 8GB for `log_softmax` and `cross_entropy_loss` computation. -In order to reduce the memory consumption: - -* we split the input sequence on each device to chunks of 1K tokens -* for each chunk, we recompute a function which is composed of last projection layer, log_softmax and loss -* as a result, we only need to store the input tensor to the last projection layer, whose initial size is 128K x 4K x 2 bytes = 1GB, which is much smaller than 32GB - -You can find the detailed implementation in ``chunk_linear_cross_entropy.py``. -The interface of the ``chunk_linear_cross_entropy`` function is ``(hidden_states: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor``, where - -* ``hidden_states`` is the output of the last transformer layer, with shape ``[batch_size, sequence_length, hidden_size]`` -* ``weight`` is the weight matrix of the last projection layer, with shape ``[vocab_size, hidden_size]`` -* ``labels`` is the target labels, with shape ``[batch_size, sequence_length]`` -* ``padding_idx`` is the padding index -* ``chunk_size`` is the size of the chunk, default is 1024 - -We want to register this function to nnScaler and tell it to partition this function along batch size or sequence dimension. A possible annotation is ``b l d^, n^ d^, b l -> b l``. Here ``b`` stands for batch size, ``l`` stands for sequence length, ``d`` stands for hidden size, and ``n`` stands for vocab size. The ``^`` means the dimension cannot be partitioned. More details about the annotation can be found in related documents. - -Compile -======= - -.. code-block:: bash - - python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee compile.log - - -Run -=== - -.. code-block:: bash - - torchrun --nproc_per_node=8 train.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee run.log - - -**Note**: You may directly the ``Run`` command which will compile implicitly, but for clearer log and debug information, we recommend to run ``Compile`` command explicitly before runtime stage. - -Checkpoint -========== - -This script will save the model checkpoint in the ``./checkpoints`` directory. You can change the checkpoint directory by updating the ``CheckpointConfig`` in the ``train.py`` script. - -nnScalar saves checkpoints in shards: each rank may save parameters and optimizer states in a file. These checkpoints can be directly loaded by nnScaler if the partitioning strategy is the same. If you want to evaluate the checkpoints on downstream tasks, you need to merge the shards into a single file. You can use the following command to merge the shards: - -.. code-block:: bash - - python ckpt_merger.py --ckpt_dir ./checkpoints --output_fname ./merged.ckpt - -The merged checkpoint can be loaded by nnScaler by setting the ``--resume_path`` option to the merged file. - -If the script is modified for different hardware configurations. - -* All sharded checkpoint files should be collected and placed in a same directory before ``ckpt_merger.py`` is called. -* If the config is changed (plan_ngpus/runtime_ngpus/etc), the sharded checkpoint can not be used anymore. You need to merge them so the trainer can load from merged checkpoint. - -*********** -Performance -*********** - -The flops of the forward computation for Llama3 is - -.. math:: 2 \cdot ( param\_num \cdot seqlen + 2 \cdot layer\_num \cdot hidden\_dim \cdot seqlen ^ 2) - -For the 8B model, the forward flops is about 11104.35 TFLOPs. The detailed config is as following: - - -* .. math:: param\_num = 8 \times 10^9 -* .. math:: seqlen = 128 \times 1024 -* .. math:: layer\_num = 32 -* .. math:: hidden\_dim = 4096 - -Generally, the computational cost of backpropagation is twice that of the forward pass. In addition, the gradient accumulation number is set to 4. As a result, the flops for a step of the training script is 133252.22 TFLOPs. - -We execute the training script on a node with 8xH100 80GB HBM3. The time cost is about 41.12s for a step. The theoretical BF16 computational speed of the H100 is 989 TFLOPS. Combine them together, this script can achieve 40.96% MFU. You can further optimize the performance by - -* add more devices to avoid recomputation: in order to fit the model into the memory, we recompute by layer. -* do more kernel optimizations. For example, the swiglu activation can be fused into the matmul ahead of it. - -********* -Debugging -********* - -Since the 128K config is challenging, it is recommended to use a smaller model for debugging. For example, you can use the following command to prepare data and train a smaller llama3 (same architecture, but with 4 decoder layers) model on two GPUs. - -.. code-block:: bash - - ## prepare data - python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 - - ## build the mini model - python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini - - ## compile and run using data parallelism + zero1 - torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K diff --git a/docs/source/llama3_demo_example.rst b/docs/source/llama3_demo_example.rst deleted file mode 100644 index b10bdc33..00000000 --- a/docs/source/llama3_demo_example.rst +++ /dev/null @@ -1,90 +0,0 @@ -############ -Llama 3 Demo -############ - -This is an example demostrating how to train Llama 3 8B with nnScaler's :doc:`trainer `. - -The example contains one single script, ``train.py``. - -*********** -Get Started -*********** - -Installation -============ - -0. Get your `Hugging Face token `_ to access Llama 3 model :: - - export HF_TOKEN=... - -1. Clone nnScaler repo :: - - git clone --recursive https://github.com/microsoft/nnscaler - -2. Install dependencies (including Llama 3 dependencies) and :doc:`nnScaler from source ` :: - - cd nnscaler - pip install -r requirements.txt - pip install -e . - -3. Find the Llama 3 example :: - - cd nnscaler/examples/llama3_demo - -4. Prepare dataset :: - - # To run Llama 3 8B: - python train.py --prepare_data - - # Or to run a shrinked Llama for debug: - python train.py --prepare_data --mini - -Train a Mini-model -================== - -This examples requires 8 x 80GB GPU memory to train a full 8B model. -If your have qualified GPUs, you can go to :ref:`the next section `. - -Alternatively, you may start from a smaller model for verification: :: - - python train.py --prepare_data --mini - torchrun --nproc_per_node=2 train.py --mini - -This will resize Llama 3 into a model with 4 hidden layers and max-sequence-length reduced to 4K (4096). -We have tested it with 2 x 48GB GPUs. - -You may further shrink it if the model is still too large: :: - - python train.py --prepare_data --max_seq_len=1024 - torchrun --nproc_per_node=2 train.py --max_seq_len=1024 --num_hidden_layers=2 --from_scratch - -Here is the training loss with the default mini config (4 layers, 4K sequence length): - -.. image:: ./images/llama3-curves-mini.png - -.. _finetune: - -Finetune Llama 3 8B -=================== - -Use the following commands to finetune `Meta-Llama-3-8B-Instruct `_: :: - - python train.py --prepare_data - torchrun --nproc_per_node=8 train.py - -.. image:: ./images/llama3-curves-8b.png - -******** -Resuming -******** - -The example will save checkpoint files after finishing 1000 steps then exit. -To continue training from the saved checkpoint: :: - - torchrun --nproc_per_node=8 train.py --resume_from=last --max_train_steps=2000 - -Please note that the checkpoint is sharded as multiple files. -If you want to resume a checkpoint in a different environment, you need to merge it into an single checkpoint file first: :: - - python train.py --merge_checkpoint=./checkpoints/last - torchrun --nproc_per_node=8 train.py --resume_from=./checkpoints/merged.ckpt --max_train_steps=3000 diff --git a/docs/source/nanogpt_example.rst b/docs/source/nanogpt_example.rst deleted file mode 100644 index f1360744..00000000 --- a/docs/source/nanogpt_example.rst +++ /dev/null @@ -1,187 +0,0 @@ -######################### -nanoGPT Lightning Example -######################### - -This is an example showing how to parallelize `nanoGPT `_ -with nnScaler and `Lightning `_ trainer. - -This example contains one single script, ``train_nnscaler.py``, besides the original nanoGPT repo. - -*********** -Get Started -*********** - -Installation -============ - -1. Clone nnScaler repo :: - - git clone --recursive https://github.com/microsoft/nnscaler - -2. Install dependencies (including nanoGPT's dependencies) and :doc:`nnScaler from source ` :: - - cd nnscaler - pip install -r requirements.txt - pip install -e . - -3. Prepare dataset :: - - python nanoGPT/data/shakespeare_char/prepare.py - -Test with Single GPU -==================== - -Now you can run ``train_nnscaler.py`` with `torchrun `_: :: - - torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py - -This will train a baby GPT model on a single GPU. -It will take several minutes and the best validation loss will be around 1.47. - -Get Distributed -=============== - -nnScaler is meant for distribution. For the current release, we are focusing on data parallel. - -If you have 4 GPUs on one node: :: - - torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py - -Or if you have multiple nodes, for example 2 nodes with 4 GPUs each: :: - - # on each node - torchrun --nnodes=2 --nproc_per_node=4 --rdzv-id=NNSCALER_NANOGPT --rdzv-backend=c10d --rdzv-endpoint= \ - train_nnscaler.py nanoGPT/config/train_shakespeare_char.py - -NOTE: The local batch size is fixed by default, so using more workers will result in larger total batch size. - -Tensor Parallel (Experimental) -============================== - -nnScaler will support tensor parallel and hybrid parallel in following release. -You can try this feature now, but its stability and parity has not been strictly verified yet. - -Using data parallel: (each model instance runs on 1 GPU, 4 instances using DP) :: - - torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=1 --runtime_ngpus=4 - -Using model parallel: (a model instance runs on all 4 GPUs, no DP) :: - - torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=4 --runtime_ngpus=4 - -Using hybrid parallel: (each model instance runs on 2 GPUs, 2 instances using DP) :: - - torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=2 --runtime_ngpus=4 - -Resuming -======== - -You may resume an interrupted training: :: - - torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --init_from=resume - -This will load the latest checkpoint saved by Lightning. - -For distributed environments, checkpoints must be *merged* when the environment changes. -Check :doc:`the reference ` for details. - -.. - FIXME: link to the section (dunno how to link into markdown) - -******** -The Code -******** - -The example code ``train_nnscaler.py`` is modified from nanoGPT's ``train.py``. - -The modification consists of two parts, (1) porting to Lightning trainer and (2) using nnScaler for distribution. - -The Lightning port is not the point of this example. Check the source code if you are interested. - -To parallelize the lightning model with nnScaler, there are 2 noteworthy places: - -1. Define the forward function and declare it's inputs: - - .. code-block:: python - - class LitModel(L.LightningModule): - def __init__(self): - super().__init__() - self.model = model - self.dummy_forward_args_fn = lambda batch: {'x': batch[0], 'y': batch[1]} - - def forward(self, x, y): - _logits, loss = self.model(x, y) - return loss - - A separate forward function is *required* because nnScaler will only parallelizes the codes in ``forward()``, - and will not touch those in ``training_step()``. - - And then, a special function ``dummy_forward_args_fn`` need to be defined to the ``LightningModule``. - It takes ``training_step()``'s ``batch`` argument, and returns a ``dict`` presenting ``forward()``'s parameters. - This function will be used to trace the module's forward graph. - -2. Register nnScaler's strategy and plugin to the Lightning trainer: - - .. code-block:: python - - compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, constant_folding=True) - strategy = NnScalerStrategy(compute_config=compute_config, pas_policy='autodist') - plugins = [NnScalerPrecision(precision)] - - trainer = L.Trainer(strateg=strategy, plugins=plugins, ...) - - For data parallel, always set ``plan_ngpus`` to 1 and set ``runtime_ngpus`` to the total GPU number. - - Other parameters are used for performance (efficiency) tuning. - -.. For details, please check the :doc:`API reference `. - -********************** -Parity and Limitations -********************** - -Single GPU -========== - -For comparison, you can run the script without using nnScaler: :: - - torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --use_nnscaler=False - -This will result in a similar loss curve: - -.. image:: ./images/nanogpt-curves.png - -There are several causes for the mismatch: - -1. nnScaler and Lightning have slightly different gradient clip implementation. -2. It cannot fully syncronize the random state for dropouts. -3. PyTorch is not deterministic by default. - -To get a perfectly matched curve, use the following command: -(The overfitting is significant due to the lack of dropout) -:: - - torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --deterministic=True - torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --deterministic=True --use_nnscaler=False - -.. image:: ./images/nanogpt-curves-deterministic.png - -Data Parallel -============= - -Here is a comparison between nnScaler's and Lightning's builtin data parallel: - -The curve is not fully reproducable due the nature of parallel. - -.. image:: ./images/nanogpt-curves-dp2.png - -The Lightning Port -================== - -The Lightning port is not exactly the same as the original nanoGPT training script for the following reaons: - -1. The Lightning ``Trainer`` is different from nanoGPT's training loop. -2. nnScaler currently lacks the support for multiple parameter groups, and therefore the weight decay is configured for all parameters. - -.. image:: ./images/nanogpt-curves-orig.png diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 2349a869..23da4eec 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -6,12 +6,14 @@ Quickstart Installation ************ -nnScaler can be :doc:`installed from GitHub `: :: +nnScaler can be installed from GitHub: +.. code-block:: bash + + pip install https://github.com/microsoft/nnscaler/releases/download/0.5/nnscaler-0.5-py3-none-any.whl + + # You may also want to clone the repo to try out the examples git clone --recursive https://github.com/microsoft/nnscaler - cd nnscaler/ - pip install -r requirements.txt - pip install -e . *************************** Parallelize a Minimal Model @@ -182,9 +184,9 @@ Next Step ********* The above example uses nnScaler's :doc:`Trainer APIs `. -To learn more about it, you may check our :doc:`Llama 3 example `. +To learn more about it, you may check our :doc:`Llama 3 example `. Or if you prefer to use a familiar trainer, we also provides integration with `PyTorch Lightning `_. -The usage is demostrated by :doc:`nanoGPT example `. +The usage is demostrated by :doc:`nanoGPT example `. -If you want to try a more advanced model, please check :doc:`Llama 3 128K sequence length example `. +If you want to try a more advanced model, please check :doc:`Llama 3 128K sequence length example `. diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 98b76dc5..a4e7331a 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -1,6 +1,6 @@ -####### +############## Native Trainer -####### +############## ``nnScaler`` provides a ``Trainer`` class for training and evaluating model parallelization. Let's start from an example to demonstrate how to parallelize a model using the ``parallelize`` API. diff --git a/docs/source/troubleshooting.rst b/docs/source/troubleshooting.rst index 10a7a01f..99104fac 100644 --- a/docs/source/troubleshooting.rst +++ b/docs/source/troubleshooting.rst @@ -218,7 +218,7 @@ Flash Attention Problems --------------------------------------------- When using flash attention, it must be registered with ``register_op`` API. -Check :doc:`the llama 3 example ` for its usage. +Check :doc:`the llama 3 example ` for its usage. Example stacktrace: :: @@ -259,7 +259,7 @@ Please try the related import command outside nnScaler. If it still fails, please refer to `flash attention `_'s docs. If your ``flash-attn`` package is installed from pip, -you can try to use a wheel its `release page _` +you can try to use a wheel its `release page `_ which matches your environment more accurately. Example stacktrace: :: diff --git a/docs/source/verify_op.md b/docs/source/verify_op.md index ba70fdf0..a829b1be 100644 --- a/docs/source/verify_op.md +++ b/docs/source/verify_op.md @@ -30,7 +30,6 @@ This configuration helps ensure that the operation's partitioning logic is valid by specifying the function signature, arguments, expected outputs, and partitioning options. ## Example of Conv2D -```python: This is used to verify that Conv2D's partition configuration is correct. This configuration defines a basic Conv2D operation with input Tensor, convolution kernel, and bias. ```python @dataclass @@ -346,4 +345,4 @@ verify_config = VerifyConfig( non_grad_indices=[3, 4] ) verify_partition_options(verify_config) -``` \ No newline at end of file +``` diff --git a/examples/deepseek_coder_v2_lite/README.md b/examples/deepseek_coder_v2_lite/README.md index f9aa6e5e..6a83be4e 100644 --- a/examples/deepseek_coder_v2_lite/README.md +++ b/examples/deepseek_coder_v2_lite/README.md @@ -1,8 +1,10 @@ -# Introduction +# DeepSeek Example + +## Introduction This example demonstrates how to train deepseek-coder-v2-lite-2k on 8xH100s or 8xA100s. -# Requirements +## Requirements To run this example, you need to install the following packages: @@ -17,7 +19,7 @@ grouped_gemm==1.1.4 We recommend to launch the script under a Nvidia docker directly, like `nvidia/pytorch:24.02-py3`. You can find grouped_gemm at https://github.com/fanshiqing/grouped_gemm. -# Data Preparation +## Data Preparation Like the *llama3_8B_128K* example, [bookcorpus](https://huggingface.co/datasets/bookcorpus) dataset is used for training. You can use the following command directly @@ -25,9 +27,9 @@ Like the *llama3_8B_128K* example, [bookcorpus](https://huggingface.co/datasets/ python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name deepseek-ai/DeepSeek-Coder-V2-Lite-Base --save_path ./bookcorpus_2k --sequence_length 2048 ``` -# Training +## Training -## Code Modification +### Code Modification Modeling is based on the open source version for [deepseek coder v2](https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Base/tree/main). To boost the training performance and be compatible with nnScaler, the source code is modified. You can check modifications in details under `modeling` folder: @@ -40,7 +42,7 @@ Similar to *llama3_8B_128K*, apex and flash-attn are introduced to reduce the ex - register the routing function with annotation to nnScaler, since it is composed of fine-grained irregular operators and generating the annoation automatically is non-trivial. - the for loop based MoE implementation is replaced with an efficient implementation built on [cutlass](https://github.com/NVIDIA/cutlass/blob/main/examples/24_gemm_grouped/gemm_grouped.cu). Along with kernel, separated expert weights are merged after loading the checkpoints. -## Distributed Config +### Distributed Config The input data is organized into batches of 64 sequences whose length = 2048. The micro batch size is 4 and gradient accumulation step is 8. 8 GPUs are divided into 2 data parallel groups (4 GPUs maintain a full copy of weights). @@ -58,7 +60,7 @@ python train.py --run_mode compile --model_id deepseek-ai/DeepSeek-Coder-V2-Lite torchrun --nproc_per_node=8 train.py --model_id deepseek-ai/DeepSeek-Coder-V2-Lite-Base --dataset_path ./bookcorpus_2k --plan_ngpus=4 --runtime_ngpus=8 2>&1 | tee run.log ``` -# Performance +## Performance We have tested the training script on 8xH100 and each step takes about 2s. A step is composed of 128K tokens and the number of activated params is about 2.65B. Combining them together, the MFU is about 13% (attention's FLOPs is omitted since the sequence is short in this ). The root cause is the low utilization rate of the MoE part. We collect statistics for the grouped gemm in the table below. Note that in deepseek coder v2 lite, there are 64 experts with hidden size = 2048, intermediate size = 1408, each token will be dispatched to 8 experts. diff --git a/examples/deepseek_coder_v2_lite/modeling/configuration_deepseek.py b/examples/deepseek_coder_v2_lite/modeling/configuration_deepseek.py index 82e0f5d9..a9e6c21e 100644 --- a/examples/deepseek_coder_v2_lite/modeling/configuration_deepseek.py +++ b/examples/deepseek_coder_v2_lite/modeling/configuration_deepseek.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -203,4 +206,4 @@ def __init__( eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, - ) \ No newline at end of file + ) diff --git a/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek.py b/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek.py index ea723720..847a458b 100644 --- a/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek.py +++ b/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek.py @@ -1919,4 +1919,4 @@ def forward( past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, - ) \ No newline at end of file + ) diff --git a/examples/llama/README.md b/examples/llama/README.md deleted file mode 100644 index 68ab68fd..00000000 --- a/examples/llama/README.md +++ /dev/null @@ -1,189 +0,0 @@ -# Introduction - -This example demonstrates how to train llama models in challenging distributed configurations by nnscaler. - -# Requirements - -Assume following packages have been installed in the environment. - -```text -nnscaler -transformers==4.40.0 -datasets==2.20.0 -apex -flash-attn -``` - -*nnScaler* is a framework for distributed training by automatically partitioning the model. Apart from the core nnScaler library, it also includes a mini-trainer for modern model training. You can find related documents and examples at [nnScaler](https://nnscaler.readthedocs.io/en/latest/). - -*transformers* and *datasets* are used to prepare the data and loading the Llama model. - -To speed up the training, [*apex*](https://github.com/NVIDIA/apex) and [*flash-attn*](https://github.com/Dao-AILab/flash-attention) are required. You can install them by following instructions in their official repositories. We also recommend to launch training in a docker directly, like nvidia/pytorch:24.02-py3 and rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0. - -# Supported Models - -The following table lists the supported model architectures and their corresponding distributed environments. A performance analysis for these will be provided later in the document. We plan to support more model combinations in the future and encourage you to experiment and contribute. - -| Model ID | Sequence Length | Device Type | Device Number | -| :---------------------------------: | :-------------: | :---------: | :-----------: | -| meta-llama/Meta-Llama-3-8B-Instruct | 131072 | H100 | 8 | -| meta-llama/Meta-Llama-3-70B | 8192 | MI300 | 16 | - -# Data Preparation - -We use the [bookcorpus](https://huggingface.co/datasets/bookcorpus) dataset for demonstrating in this doc. You can change related code to support your own dataset. Here we give an example that downloads and tokenizes `bookcorpus` for Llama. - -In the example command below, the dataset is tokenized by [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) tokenizer and grouped into 128K, tokenized data is saved in `bookcorpus_llama3_128K` directory. - -```bash -python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_128K --sequence_length 131072 -``` - -# Training - -nnScaler adopts a compiler approach to train deep learning models on multiple deivices. The processing pipeline is divided into two stages: - -1. Compile stage: trace the original PyTorch model and get the dataflow graph. Analyze the graph and generate an efficient plan for distributed training. Generate python code for the runtime stage. -2. Runtime stage: run the generated python code to train the model. - -For better user experience, we recommend to use separate commands for the compile and runtime stages at your first trial of nnScaler. You can use the `Run` command directly to combine the two stages when you are familiar with the system. - -**Note**: currently we only tested `"_attn_implementation": "flash_attention_2"` and `"use_cache": false` in the config file. Other configurations may trigger errors. - -## Trace Strategy - -During compiling, the time cost of trace model graph can vary significantly depending on the tracing strategy employed. Below are some reference time to trace `meta-llama/Meta-Llama-3-8B-Instruct` with different strategies and different context length, the time tested on one single A100 80GB: - -| Strategy | Context Length | Time/seconds | -| :------: | :------------: | :----------: | -| `reuse_cache` | 8k | 8.11 | -| `reuse_cache` | 32k | 11.06 | -| `reuse_cache` | 64k | 15.36 | -| `reuse_cache` | 128k | 26.29 | -| `cuda_run_cpu_offload` | 8k | 55.28 | -| `cuda_run_cpu_offload` | 32k | 194.27 | -| `cuda_run_cpu_offload` | 64k | 342.15 | -| `cuda_run_cpu_offload` | 128k | 789.15 | - -The trace strategy can be changed by setting `--trace_strategy` option. Please note that different strategies have different applicable scenarios. For more information and explanation to the different strategies, please read `docs/source/parallel_module.md`. - -## Register Customized Function - -Llama3's vocabulary size is about 128K, which is much larger then the 32K in Llama2. When the sequence length is very long like 128K, the output tensor size of the last projection layer is quite large: 128K x 128K x 2 bytes = 32GB in fp16 or bf16. -Although this tensor can be partitioned evenly to 8 GPUs, 4GB memory is still large due to limited GPU memory. What makes it worse is that we need to store additional 8GB for `log_softmax` and `cross_entropy_loss` computation. -In order to reduce the memory consumption: -- we split the input sequence on each device to chunks of 1K tokens -- for each chunk, we recompute a function which is composed of last projection layer, log_softmax and loss -- as a result, we only need to store the input tensor to the last projection layer, whose initial size is 128K x 4K x 2 bytes = 1GB, which is much smaller than 32GB - -You can find the detailed implementation in `chunk_linear_cross_entropy.py`. -The interface of the `chunk_linear_cross_entropy` function is `(hidden_states: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor`, where -- `hidden_states` is the output of the last transformer layer, with shape `[batch_size, sequence_length, hidden_size]` -- `weight` is the weight matrix of the last projection layer, with shape `[vocab_size, hidden_size]` -- `labels` is the target labels, with shape `[batch_size, sequence_length]` -- `padding_idx` is the padding index -- `chunk_size` is the size of the chunk, default is 1024 - -We want to register this function to nnScaler and tell it to partition this function along batch size or sequence dimension. A possible annotation is `b l d^, n^ d^, b l -> b l`. Here `b` stands for batch size, `l` stands for sequence length, `d` stands for hidden size, and `n` stands for vocab size. The `^` means the dimension cannot be partitioned. More details about the annotation can be found in `docs/source/register_custom_op.md`. - -You can enable this customized function by passing `--enable-chunk-loss` to `train.py` when compiling. When the sequence length is small (like 8K), this option can be turned off. - -## Profile Communication - -To generate an efficient distributed plan in your environment, we recommend to profile the intra-node communication before compiling. The profiler records the time of different communication primitives (like allgather, allreduce, reducescatter and alltoall) for some message sizes. If the profiling is skipped, the system will use MI250's data by default. You can use the command below to profile. - -```bash -cd nnscaler && python utility/prim_profiler.py -``` - -## Checkpoint - -`train.py` will save the model checkpoint in the `./checkpoints` directory by default. You can change the checkpoint directory by updating the `CheckpointConfig` in the source code. - -nnScalar saves checkpoints in shards: each rank may save parameters and optimizer states in a file. These checkpoints can be directly loaded by nnScaler if the partitioning strategy is the same. If you want to evaluate the checkpoints on downstream tasks, you need to merge the shards into a single file. You can use the following command to merge the shards: - -```bash -python ckpt_merger.py --ckpt_dir ./checkpoints --output_fname ./merged.ckpt -``` - -The merged checkpoint can be loaded by nnScaler by setting the `--resume_path` option to the merged file. - -If the script is modified for different hardware configurations. -- All sharded checkpoint files should be collected and placed in a same directory before `ckpt_merger.py` is called. -- If the config is changed (plan_ngus/runtime_ngus/etc), the sharded checkpoint can not be used anymore. You need to merge them so the trainer can load from merged checkpoint. - -# Performance Analysis - -The flops of the forward computation for llama is - -$2 \cdot ( param\_num \cdot seqlen + 2 \cdot layer\_num \cdot hidden\_dim \cdot seqlen ^ 2)$ - -## Llama3 8B 128K on 8xH100 - -Commands below is used for this setting. - -**Compile** - -```bash -python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee compile.log -``` - -**Run** - -```bash -torchrun --nproc_per_node=8 train.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee run.log -``` - -For the 8B model, the forward flops is about 11104.35 TFLOPs. The detailed config is as following: -- $param\_num = 8 \times 10^9$ -- $seqlen = 128 \times 1024$ -- $layer\_num = 32$ -- $hidden\_dim = 4096$ - -Generally, the computational cost of backpropagation is twice that of the forward pass. In addition, the gradient accumulation number is set to 4. As a result, the flops for a step of the training script is 133252.22 TFLOPs. - -We execute the training script on a node with 8xH100 80GB HBM3. The time cost is about 41.12s for a step. The theoretical BF16 computational speed of the H100 is 989 TFLOPS. Combine them together, this script can achieve 40.96% MFU. You can optimize the performance furtherly by -- add more devices to avoid recomputation: in order to fit the model into the memory, we recompute by layer. -- do more kernel optimizations. For example, the swiglu activation can be fused into the matmul ahead of it. - -## Llama3 70B 8K on 16xMI300 - -Different from the 8B example, a merged command is used for the multi-node setting. Since 70b model is trained on 2 nodes, we use mpi to execute `torchrun` on them at the same time. If you want to run the command on your own, you can replace `MASTER_ADDR` with the IP address of the first node, `MASTER_PORT` with the available port on the first node and fill `OMPI_COMM_WORLD_RANK` with 0 and 1 on two nodes respectively. - -**Combined Command** - -```bash -torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$$OMPI_COMM_WORLD_RANK --master_addr="$$MASTER_ADDR" --master_port=$$MASTER_PORT train.py --name llama3-70b --model_id meta-llama/Meta-Llama-3-70B --dataset_path ./bookcorpus_llama3_8K --gpu_mem_constraint 153 --plan_ngpus=8 --runtime_ngpus=16 --explore_pipeline --grad_accumulation_steps 64 --pipeline_pivots LlamaDecoderLayer 2>&1 | tee run.log -``` - -Note that in the command above, we enable searching for pipeline parallelism by passing `--explore_pipeline` and set the possible pipeline stage boundaries by `--pipeline_pivots LlamaDecoderLayer`. - -For the 70B model, the flops for forward and backward is about 3968.41 TFLOPs. The detailed config is as following: -- $param\_num = 70 \times 10^9$ -- $seqlen = 8192$ -- $layer\_num = 80$ -- $hidden\_dim = 8192$ - -[MI300X](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf)'s peak theoritical performance for BF16 is 1307.4 TFLOPS. It takes about 100.3 s to finish 64 gradient accumulation steps in the experiment. Combine them together, the MFU of this distributed plan is 24.2 %. - -Based on AutoDist's analysis, the low utilization results from following aspects -- We observe MFU for important operators are low. For example, `linear`'s MFU is 40% ~ 50%, the real MFU of `flash-attn` is 14%. -- Like the 8B 128K example, we can fuse operators like RoPE and swiglu to reduce time. -- There are two pipeline stages each with 4 devices. In each stage, communication takes about 450ms and computation takes about 1000ms. According to our experiences, the communication time is higher than expected. Adding more devices may help to reduce it since the optimizer states still takes about 52GB in each device. -- Enlarge search space in the future. Currently we only consider plan_ngpus=8 and fix the pipeline schedule to be `1f1b`. We can refine this assumption in the future. - -# Debugging - -Since the large setting is challenging, it is recommended to use a smaller model for debugging. For example, you can use the following command to prepare data and train a smaller llama3 (same architecture, but with 4 decoder layers) model on two GPUs. - -```bash -# prepare data -python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 - -# build the mini model -python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini - -# compile and run using data parallelism + zero1 -torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K - -``` diff --git a/examples/llama/README.rst b/examples/llama/README.rst new file mode 100644 index 00000000..4dd5be41 --- /dev/null +++ b/examples/llama/README.rst @@ -0,0 +1,296 @@ +###################### +Advanced Llama Example +###################### + +************ +Introduction +************ + +This example demonstrates how to train llama models in challenging distributed configurations by nnscaler. + +************ +Requirements +************ + +Assume following packages have been installed in the environment. :: + + nnscaler + transformers==4.40.0 + datasets==2.20.0 + apex + flash-attn + +*nnScaler* is a framework for distributed training by automatically partitioning the model. +Apart from the core nnScaler library, it also includes a mini-trainer for modern model training. +You can find related documents and examples at `nnScaler `_. + +*transformers* and *datasets* are used to prepare the data and loading the Llama model. + +To speed up the training, +`apex `_ and `flash-attn `_ are required. +You can install them by following instructions in their official repositories. +We also recommend to launch training in a docker directly, +like ``nvidia/pytorch:24.02-py3`` and ``rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0``. + +**************** +Supported Models +**************** + +The following table lists the supported model architectures and their corresponding distributed environments. +A performance analysis for these will be provided later in the document. +We plan to support more model combinations in the future and encourage you to experiment and contribute. + ++-------------------------------------+-----------------+-------------+---------------+ +| Model ID | Sequence Length | Device Type | Device Number | ++=====================================+=================+=============+===============+ +| meta-llama/Meta-Llama-3-8B-Instruct | 131072 | H100 | 8 | ++-------------------------------------+-----------------+-------------+---------------+ +| meta-llama/Meta-Llama-3-70B | 8192 | MI300 | 16 | ++-------------------------------------+-----------------+-------------+---------------+ + +**************** +Data Preparation +**************** + +We use the `bookcorpus `_ dataset for demonstrating in this doc. +You can change related code to support your own dataset. +Here we give an example that downloads and tokenizes ``bookcorpus`` for Llama. + +In the example command below, +the dataset is tokenized by `Meta-Llama-3-8B-Instruct `_ tokenizer and grouped into 128K, +tokenized data is saved in ``bookcorpus_llama3_128K`` directory. + +.. code-block:: bash + + python bookcorpus.py \ + --data_path_or_name bookcorpus/bookcorpus \ + --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct \ + --save_path ./bookcorpus_llama3_128K \ + --sequence_length 131072 + +******** +Training +******** + +nnScaler adopts a compiler approach to train deep learning models on multiple deivices. +The processing pipeline is divided into two stages: + +#. Compile stage: trace the original PyTorch model and get the dataflow graph. + Analyze the graph and generate an efficient plan for distributed training. + Generate python code for the runtime stage. +#. Runtime stage: run the generated python code to train the model. + +For better user experience, we recommend to use separate commands for the compile and runtime stages at your first trial of nnScaler. +You can use the ``Run`` command directly to combine the two stages when you are familiar with the system. + +**Note**: currently we only tested ``"_attn_implementation": "flash_attention_2"`` and ``"use_cache": false`` in the config file. +Other configurations may trigger errors. + +Trace Strategy +============== + +During compiling, the time cost of trace model graph can vary significantly depending on the tracing strategy employed. +Below are some reference time to trace ``meta-llama/Meta-Llama-3-8B-Instruct`` with different strategies and different context length, +the time tested on one single A100 80GB: + ++------------------------+----------------+--------------+ +| Strategy | Context Length | Time/seconds | ++========================+================+==============+ +| `reuse_cache` | 8k | 8.11 | ++------------------------+----------------+--------------+ +| `reuse_cache` | 32k | 11.06 | ++------------------------+----------------+--------------+ +| `reuse_cache` | 64k | 15.36 | ++------------------------+----------------+--------------+ +| `reuse_cache` | 128k | 26.29 | ++------------------------+----------------+--------------+ +| `cuda_run_cpu_offload` | 8k | 55.28 | ++------------------------+----------------+--------------+ +| `cuda_run_cpu_offload` | 32k | 194.27 | ++------------------------+----------------+--------------+ +| `cuda_run_cpu_offload` | 64k | 342.15 | ++------------------------+----------------+--------------+ +| `cuda_run_cpu_offload` | 128k | 789.15 | ++------------------------+----------------+--------------+ + +The trace strategy can be changed by setting ``--trace_strategy`` option. +Please note that different strategies have different applicable scenarios. +For more information and explanation to the different strategies, please read :doc:`../parallel_module`. + +Register Customized Function +============================ + +Llama3's vocabulary size is about 128K, which is much larger then the 32K in Llama2. +When the sequence length is very long like 128K, +the output tensor size of the last projection layer is quite large: +128K x 128K x 2 bytes = 32GB in fp16 or bf16. +Although this tensor can be partitioned evenly to 8 GPUs, 4GB memory is still large due to limited GPU memory. +What makes it worse is that we need to store additional 8GB for ``log_softmax`` and ``cross_entropy_loss`` computation. +In order to reduce the memory consumption: + +* we split the input sequence on each device to chunks of 1K tokens +* for each chunk, we recompute a function which is composed of last projection layer, log_softmax and loss +* as a result, we only need to store the input tensor to the last projection layer, + whose initial size is 128K x 4K x 2 bytes = 1GB, which is much smaller than 32GB + +You can find the detailed implementation in ``chunk_linear_cross_entropy.py``. +The interface of the ``chunk_linear_cross_entropy`` function is +``(hidden_states: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor``, +where + +* ``hidden_states`` is the output of the last transformer layer, with shape ``[batch_size, sequence_length, hidden_size]`` +* ``weight`` is the weight matrix of the last projection layer, with shape ``[vocab_size, hidden_size]`` +* ``labels`` is the target labels, with shape ``[batch_size, sequence_length]`` +* ``padding_idx`` is the padding index +* ``chunk_size`` is the size of the chunk, default is 1024 + +We want to register this function to nnScaler and tell it to partition this function along batch size or sequence dimension. +A possible annotation is ``b l d^, n^ d^, b l -> b l``. +Here ``b`` stands for batch size, ``l`` stands for sequence length, ``d`` stands for hidden size, and ``n`` stands for vocab size. +The ``^`` means the dimension cannot be partitioned. +More details about the annotation can be found in :doc:`../register_custom_op`. + +You can enable this customized function by passing ``--enable-chunk-loss`` to ``train.py`` when compiling. +When the sequence length is small (like 8K), this option can be turned off. + +Profile Communication +===================== + +To generate an efficient distributed plan in your environment, we recommend to profile the intra-node communication before compiling. +The profiler records the time of different communication primitives (like allgather, allreduce, reducescatter and alltoall) for some message sizes. +If the profiling is skipped, the system will use MI250's data by default. You can use the command below to profile. + +.. code-block:: bash + + cd nnscaler && python utility/prim_profiler.py + +Checkpoint +========== + +``train.py`` will save the model checkpoint in the ``./checkpoints`` directory by default. +You can change the checkpoint directory by updating the ``CheckpointConfig`` in the source code. + +nnScalar saves checkpoints in shards: each rank may save parameters and optimizer states in a file. +These checkpoints can be directly loaded by nnScaler if the partitioning strategy is the same. +If you want to evaluate the checkpoints on downstream tasks, you need to merge the shards into a single file. +You can use the following command to merge the shards: + +.. code-block:: bash + + python ckpt_merger.py --ckpt_dir ./checkpoints --output_fname ./merged.ckpt + +The merged checkpoint can be loaded by nnScaler by setting the ``--resume_path`` option to the merged file. + +If the script is modified for different hardware configurations. + +* All sharded checkpoint files should be collected and placed in a same directory before ``ckpt_merger.py`` is called. +* If the config is changed (plan_ngus/runtime_ngus/etc), the sharded checkpoint can not be used anymore. + You need to merge them so the trainer can load from merged checkpoint. + +******************** +Performance Analysis +******************** + +The flops of the forward computation for llama is + +.. math:: 2 \cdot ( param\_num \cdot seqlen + 2 \cdot layer\_num \cdot hidden\_dim \cdot seqlen ^ 2) + +Llama3 8B 128K on 8xH100 +======================== + +Commands below is used for this setting. + +Compile +------- + +.. code-block:: bash + + python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee compile.log + +Run +--- + +.. code-block:: bash + + torchrun --nproc_per_node=8 train.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee run.log + +For the 8B model, the forward flops is about 11104.35 TFLOPs. The detailed config is as following: + +* .. math:: param\_num = 8 \times 10^9 +* .. math:: seqlen = 128 \times 1024 +* .. math:: layer\_num = 32 +* .. math:: hidden\_dim = 4096 + +Generally, the computational cost of backpropagation is twice that of the forward pass. +In addition, the gradient accumulation number is set to 4. +As a result, the flops for a step of the training script is 133252.22 TFLOPs. + +We execute the training script on a node with 8xH100 80GB HBM3. +The time cost is about 41.12s for a step. +The theoretical BF16 computational speed of the H100 is 989 TFLOPS. +Combine them together, this script can achieve 40.96% MFU. +You can optimize the performance furtherly by + +* add more devices to avoid recomputation: in order to fit the model into the memory, we recompute by layer. +* do more kernel optimizations. For example, the swiglu activation can be fused into the matmul ahead of it. + +Llama3 70B 8K on 16xMI300 +========================= + +Different from the 8B example, a merged command is used for the multi-node setting. +Since 70b model is trained on 2 nodes, we use mpi to execute ``torchrun`` on them at the same time. +If you want to run the command on your own, you can replace ``MASTER_ADDR`` with the IP address of the first node, +``MASTER_PORT`` with the available port on the first node and fill ``OMPI_COMM_WORLD_RANK`` with 0 and 1 on two nodes respectively. + +Combined Command +---------------- + +.. code-block:: bash + + torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$$OMPI_COMM_WORLD_RANK --master_addr="$$MASTER_ADDR" --master_port=$$MASTER_PORT train.py --name llama3-70b --model_id meta-llama/Meta-Llama-3-70B --dataset_path ./bookcorpus_llama3_8K --gpu_mem_constraint 153 --plan_ngpus=8 --runtime_ngpus=16 --explore_pipeline --grad_accumulation_steps 64 --pipeline_pivots LlamaDecoderLayer 2>&1 | tee run.log + +Note that in the command above, we enable searching for pipeline parallelism by passing ``--explore_pipeline`` +and set the possible pipeline stage boundaries by ``--pipeline_pivots LlamaDecoderLayer``. + +For the 70B model, the flops for forward and backward is about 3968.41 TFLOPs. The detailed config is as following: + +* .. math:: param\_num = 70 \times 10^9 +* .. math:: seqlen = 8192 +* .. math:: layer\_num = 80 +* .. math:: hidden\_dim = 8192 + +`MI300X `_'s +peak theoritical performance for BF16 is 1307.4 TFLOPS. +It takes about 100.3 s to finish 64 gradient accumulation steps in the experiment. +Combine them together, the MFU of this distributed plan is 24.2%. + +Based on AutoDist's analysis, the low utilization results from following aspects + +* We observe MFU for important operators are low. + For example, ``linear``'s MFU is 40% ~ 50%, the real MFU of ``flash-attn`` is 14%. +* Like the 8B 128K example, we can fuse operators like RoPE and swiglu to reduce time. +* There are two pipeline stages each with 4 devices. + In each stage, communication takes about 450ms and computation takes about 1000ms. + According to our experiences, the communication time is higher than expected. Adding more devices may help to reduce it since the optimizer states still takes about 52GB in each device. +* Enlarge search space in the future. + Currently we only consider plan_ngpus=8 and fix the pipeline schedule to be ``1f1b``. + We can refine this assumption in the future. + +********* +Debugging +********* + +Since the large setting is challenging, it is recommended to use a smaller model for debugging. +For example, you can use the following command to prepare data and train a smaller llama3 +(same architecture, but with 4 decoder layers) model on two GPUs. + +.. code-block:: bash + + # prepare data + python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 + + # build the mini model + python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini + + # compile and run using data parallelism + zero1 + torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K diff --git a/examples/llama3_demo/README.rst b/examples/llama3_demo/README.rst deleted file mode 120000 index 52a40886..00000000 --- a/examples/llama3_demo/README.rst +++ /dev/null @@ -1 +0,0 @@ -../../docs/source/llama3_demo_example.rst \ No newline at end of file diff --git a/examples/llama3_demo/README.rst b/examples/llama3_demo/README.rst new file mode 100644 index 00000000..93b9d5c5 --- /dev/null +++ b/examples/llama3_demo/README.rst @@ -0,0 +1,90 @@ +############ +Llama 3 Demo +############ + +This is an example demostrating how to train Llama 3 8B with nnScaler's :doc:`trainer <../trainer>`. + +The example contains one single script, ``train.py``. + +*********** +Get Started +*********** + +Installation +============ + +0. Get your `Hugging Face token `_ to access Llama 3 model :: + + export HF_TOKEN=... + +1. Clone nnScaler repo :: + + git clone --recursive https://github.com/microsoft/nnscaler + +2. Install dependencies (including Llama 3 dependencies) and :doc:`nnScaler from source <../install_from_source>` :: + + cd nnscaler + pip install -r requirements.txt + pip install -e . + +3. Find the Llama 3 example :: + + cd nnscaler/examples/llama3_demo + +4. Prepare dataset :: + + # To run Llama 3 8B: + python train.py --prepare_data + + # Or to run a shrinked Llama for debug: + python train.py --prepare_data --mini + +Train a Mini-model +================== + +This examples requires 8 x 80GB GPU memory to train a full 8B model. +If your have qualified GPUs, you can go to :ref:`the next section `. + +Alternatively, you may start from a smaller model for verification: :: + + python train.py --prepare_data --mini + torchrun --nproc_per_node=2 train.py --mini + +This will resize Llama 3 into a model with 4 hidden layers and max-sequence-length reduced to 4K (4096). +We have tested it with 2 x 48GB GPUs. + +You may further shrink it if the model is still too large: :: + + python train.py --prepare_data --max_seq_len=1024 + torchrun --nproc_per_node=2 train.py --max_seq_len=1024 --num_hidden_layers=2 --from_scratch + +Here is the training loss with the default mini config (4 layers, 4K sequence length): + +.. image:: ../images/llama3-curves-mini.png + +.. _finetune: + +Finetune Llama 3 8B +=================== + +Use the following commands to finetune `Meta-Llama-3-8B-Instruct `_: :: + + python train.py --prepare_data + torchrun --nproc_per_node=8 train.py + +.. image:: ../images/llama3-curves-8b.png + +******** +Resuming +******** + +The example will save checkpoint files after finishing 1000 steps then exit. +To continue training from the saved checkpoint: :: + + torchrun --nproc_per_node=8 train.py --resume_from=last --max_train_steps=2000 + +Please note that the checkpoint is sharded as multiple files. +If you want to resume a checkpoint in a different environment, you need to merge it into an single checkpoint file first: :: + + python train.py --merge_checkpoint=./checkpoints/last + torchrun --nproc_per_node=8 train.py --resume_from=./checkpoints/merged.ckpt --max_train_steps=3000 diff --git a/examples/nanogpt/README.rst b/examples/nanogpt/README.rst deleted file mode 120000 index a9f2be20..00000000 --- a/examples/nanogpt/README.rst +++ /dev/null @@ -1 +0,0 @@ -../../docs/source/nanogpt_example.rst \ No newline at end of file diff --git a/examples/nanogpt/README.rst b/examples/nanogpt/README.rst new file mode 100644 index 00000000..41054473 --- /dev/null +++ b/examples/nanogpt/README.rst @@ -0,0 +1,187 @@ +######################### +nanoGPT Lightning Example +######################### + +This is an example showing how to parallelize `nanoGPT `_ +with nnScaler and `Lightning `_ trainer. + +This example contains one single script, ``train_nnscaler.py``, besides the original nanoGPT repo. + +*********** +Get Started +*********** + +Installation +============ + +1. Clone nnScaler repo :: + + git clone --recursive https://github.com/microsoft/nnscaler + +2. Install dependencies (including nanoGPT's dependencies) and :doc:`nnScaler from source <../install_from_source>` :: + + cd nnscaler + pip install -r requirements.txt + pip install -e . + +3. Prepare dataset :: + + python nanoGPT/data/shakespeare_char/prepare.py + +Test with Single GPU +==================== + +Now you can run ``train_nnscaler.py`` with `torchrun `_: :: + + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +This will train a baby GPT model on a single GPU. +It will take several minutes and the best validation loss will be around 1.47. + +Get Distributed +=============== + +nnScaler is meant for distribution. For the current release, we are focusing on data parallel. + +If you have 4 GPUs on one node: :: + + torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +Or if you have multiple nodes, for example 2 nodes with 4 GPUs each: :: + + # on each node + torchrun --nnodes=2 --nproc_per_node=4 --rdzv-id=NNSCALER_NANOGPT --rdzv-backend=c10d --rdzv-endpoint= \ + train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +NOTE: The local batch size is fixed by default, so using more workers will result in larger total batch size. + +Tensor Parallel (Experimental) +============================== + +nnScaler will support tensor parallel and hybrid parallel in following release. +You can try this feature now, but its stability and parity has not been strictly verified yet. + +Using data parallel: (each model instance runs on 1 GPU, 4 instances using DP) :: + + torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=1 --runtime_ngpus=4 + +Using model parallel: (a model instance runs on all 4 GPUs, no DP) :: + + torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=4 --runtime_ngpus=4 + +Using hybrid parallel: (each model instance runs on 2 GPUs, 2 instances using DP) :: + + torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=2 --runtime_ngpus=4 + +Resuming +======== + +You may resume an interrupted training: :: + + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --init_from=resume + +This will load the latest checkpoint saved by Lightning. + +For distributed environments, checkpoints must be *merged* when the environment changes. +Check :doc:`the reference <../pytorch_lightning>` for details. + +.. + FIXME: link to the section (dunno how to link into markdown) + +******** +The Code +******** + +The example code ``train_nnscaler.py`` is modified from nanoGPT's ``train.py``. + +The modification consists of two parts, (1) porting to Lightning trainer and (2) using nnScaler for distribution. + +The Lightning port is not the point of this example. Check the source code if you are interested. + +To parallelize the lightning model with nnScaler, there are 2 noteworthy places: + +1. Define the forward function and declare it's inputs: + + .. code-block:: python + + class LitModel(L.LightningModule): + def __init__(self): + super().__init__() + self.model = model + self.dummy_forward_args_fn = lambda batch: {'x': batch[0], 'y': batch[1]} + + def forward(self, x, y): + _logits, loss = self.model(x, y) + return loss + + A separate forward function is *required* because nnScaler will only parallelizes the codes in ``forward()``, + and will not touch those in ``training_step()``. + + And then, a special function ``dummy_forward_args_fn`` need to be defined to the ``LightningModule``. + It takes ``training_step()``'s ``batch`` argument, and returns a ``dict`` presenting ``forward()``'s parameters. + This function will be used to trace the module's forward graph. + +2. Register nnScaler's strategy and plugin to the Lightning trainer: + + .. code-block:: python + + compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, constant_folding=True) + strategy = NnScalerStrategy(compute_config=compute_config, pas_policy='autodist') + plugins = [NnScalerPrecision(precision)] + + trainer = L.Trainer(strateg=strategy, plugins=plugins, ...) + + For data parallel, always set ``plan_ngpus`` to 1 and set ``runtime_ngpus`` to the total GPU number. + + Other parameters are used for performance (efficiency) tuning. + +.. For details, please check the :doc:`API reference `. + +********************** +Parity and Limitations +********************** + +Single GPU +========== + +For comparison, you can run the script without using nnScaler: :: + + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --use_nnscaler=False + +This will result in a similar loss curve: + +.. image:: ../images/nanogpt-curves.png + +There are several causes for the mismatch: + +1. nnScaler and Lightning have slightly different gradient clip implementation. +2. It cannot fully syncronize the random state for dropouts. +3. PyTorch is not deterministic by default. + +To get a perfectly matched curve, use the following command: +(The overfitting is significant due to the lack of dropout) +:: + + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --deterministic=True + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --deterministic=True --use_nnscaler=False + +.. image:: ../images/nanogpt-curves-deterministic.png + +Data Parallel +============= + +Here is a comparison between nnScaler's and Lightning's builtin data parallel: + +The curve is not fully reproducable due the nature of parallel. + +.. image:: ../images/nanogpt-curves-dp2.png + +The Lightning Port +================== + +The Lightning port is not exactly the same as the original nanoGPT training script for the following reaons: + +1. The Lightning ``Trainer`` is different from nanoGPT's training loop. +2. nnScaler currently lacks the support for multiple parameter groups, and therefore the weight decay is configured for all parameters. + +.. image:: ../images/nanogpt-curves-orig.png diff --git a/examples/vision/DaGAN/README.md b/examples/vision/DaGAN/README.md deleted file mode 100644 index 833f9088..00000000 --- a/examples/vision/DaGAN/README.md +++ /dev/null @@ -1,24 +0,0 @@ -This example demonstrates a GAN-like vision model. The nnscaler trainer assumes there is only one end-to-end module that needs to be parallelized. However, GAN-like models always have both a generator and a discriminator. Here, you will learn how to run your code without the nnscaler trainer, and how to parallelize, synchronize, and update modules during training. - -In this example, both `GeneratorFullModel` and `DiscriminatorFullModel` contain the same keypoint detector, generator, and discriminator modules. A module cannot be parallelized multiple times, so keypoint detector, generator, and discriminator must be parallelized separately. Separate synchronization and updates are also needed during training. - - -# clone CVPR2022-DaGAN repository from github -``` -cd MagicCube/examples/vision/DaGAN -git clone https://github.com/harlanhong/CVPR2022-DaGAN.git -``` - -# Install dependent packages -``` -mv CVPR2022-DaGAN CVPR2022_DaGAN -cd CVPR2022_DaGAN -pip install --ignore-installed -r requirements.txt -cd .. -export PYTHONPATH=$PYTHONPATH:CVPR2022_DaGAN -``` - -# run -``` -CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc-per-node=4 --master_port=12348 run.py --config vox-adv-256.yaml --name DaGAN --batchsize 8 --kp_num 15 --generator DepthAwareGenerator -``` \ No newline at end of file diff --git a/examples/vision/dagan/README.rst b/examples/vision/dagan/README.rst new file mode 100644 index 00000000..b16fba64 --- /dev/null +++ b/examples/vision/dagan/README.rst @@ -0,0 +1,28 @@ +############# +DaGAN Example +############# + +This example demonstrates a GAN-like vision model. +The nnscaler trainer assumes there is only one end-to-end module that needs to be parallelized. +However, GAN-like models always have both a generator and a discriminator. +Here, you will learn how to run your code without the nnscaler trainer, and how to parallelize, synchronize, and update modules during training. + +In this example, both ``GeneratorFullModel`` and ``DiscriminatorFullModel`` contain the same keypoint detector, generator, and discriminator modules. +A module cannot be parallelized multiple times, so keypoint detector, generator, and discriminator must be parallelized separately. +Separate synchronization and updates are also needed during training. + +.. code-block:: bash + + # clone nnScaler & CVPR2022-DaGAN repositories + git clone --recursive https://github.com/microsoft/nnscaler + cd nnscaler/examples/vision/dagan + git clone https://github.com/harlanhong/CVPR2022-DaGAN.git + + # Install dependent packages + mv CVPR2022-DaGAN CVPR2022_DaGAN + pip install --ignore-installed -r CVPR2022_DaGAN/requirements.txt + export PYTHONPATH=$PYTHONPATH:CVPR2022_DaGAN + + # Run + CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc-per-node=4 --master_port=12348 run.py \ + --config vox-adv-256.yaml --name DaGAN --batchsize 8 --kp_num 15 --generator DepthAwareGenerator diff --git a/examples/vision/DaGAN/dataset.py b/examples/vision/dagan/dataset.py similarity index 95% rename from examples/vision/DaGAN/dataset.py rename to examples/vision/dagan/dataset.py index b2bf65c3..19363f85 100644 --- a/examples/vision/DaGAN/dataset.py +++ b/examples/vision/dagan/dataset.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import random import string diff --git a/examples/vision/DaGAN/model_full.py b/examples/vision/dagan/model_full.py similarity index 98% rename from examples/vision/DaGAN/model_full.py rename to examples/vision/dagan/model_full.py index 57f64770..2b05209d 100644 --- a/examples/vision/DaGAN/model_full.py +++ b/examples/vision/dagan/model_full.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import CVPR2022_DaGAN.depth as depth from CVPR2022_DaGAN.modules.model import ImagePyramide @@ -216,4 +219,4 @@ def __init__(self, kp_extractor, generator, discriminator, config): if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() - self.loss_weights = self.train_params['loss_weights'] \ No newline at end of file + self.loss_weights = self.train_params['loss_weights'] diff --git a/examples/vision/DaGAN/requirements.txt b/examples/vision/dagan/requirements.txt similarity index 100% rename from examples/vision/DaGAN/requirements.txt rename to examples/vision/dagan/requirements.txt diff --git a/examples/vision/DaGAN/run.py b/examples/vision/dagan/run.py similarity index 99% rename from examples/vision/DaGAN/run.py rename to examples/vision/dagan/run.py index 3bfb1862..44c41e36 100644 --- a/examples/vision/DaGAN/run.py +++ b/examples/vision/dagan/run.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import matplotlib matplotlib.use('Agg') @@ -390,5 +393,3 @@ def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, da init_seeds() main(rank, world_size) dist.destroy_process_group() - - \ No newline at end of file diff --git a/examples/vision/DaGAN/vox-adv-256.yaml b/examples/vision/dagan/vox-adv-256.yaml similarity index 100% rename from examples/vision/DaGAN/vox-adv-256.yaml rename to examples/vision/dagan/vox-adv-256.yaml diff --git a/examples/vit/README.md b/examples/vit/README.md index 96dd2ed5..53ceb520 100644 --- a/examples/vit/README.md +++ b/examples/vit/README.md @@ -1,8 +1,10 @@ -# Introduction +# ViT Example + +## Introduction This example demonstrates how to use nnscaler to fine-tuning a transformer model. Here we use ViT as an example. -# Requirements +## Requirements To run this example, you need to install the packages listed in the `requirements.txt` file. You can install them by running the following command: @@ -17,7 +19,7 @@ pip install -r requirements.txt The implementation is inspired by [here](https://medium.com/@supersjgk/fine-tuning-vision-transformer-with-hugging-face-and-pytorch-df19839d5396). Many thanks to the author. -## Run +### Run First go to `examples/vit` directory, You can use the following command to run the example: diff --git a/tests/cli/test_resume_seed.py b/tests/cli/test_resume_seed.py index 44356b45..43cfb4ae 100644 --- a/tests/cli/test_resume_seed.py +++ b/tests/cli/test_resume_seed.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import os import pytest import torch diff --git a/tests/graph/test_multiref.py b/tests/graph/test_multiref.py index 3e1dcb3d..4b4f1f4a 100644 --- a/tests/graph/test_multiref.py +++ b/tests/graph/test_multiref.py @@ -1,8 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ pytest unit_tests/graph/test_multiref.py """ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. import torch import logging diff --git a/tests/graph/tracer/test_op_context.py b/tests/graph/tracer/test_op_context.py index f85aab7f..ee8e88df 100644 --- a/tests/graph/tracer/test_op_context.py +++ b/tests/graph/tracer/test_op_context.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import operator import torch diff --git a/tests/parallel_module/test_async.py b/tests/parallel_module/test_async.py index 27d4d101..1ed114b9 100644 --- a/tests/parallel_module/test_async.py +++ b/tests/parallel_module/test_async.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pathlib import Path import tempfile import pytest diff --git a/tests/parallel_module/test_e2e_detach_loss.py b/tests/parallel_module/test_e2e_detach_loss.py index d478c5fc..eb2b6a53 100644 --- a/tests/parallel_module/test_e2e_detach_loss.py +++ b/tests/parallel_module/test_e2e_detach_loss.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import torch.nn as nn import tempfile diff --git a/tests/parallel_module/test_end2end_mix_precision.py b/tests/parallel_module/test_end2end_mix_precision.py index aa84dfd9..c1922b5b 100644 --- a/tests/parallel_module/test_end2end_mix_precision.py +++ b/tests/parallel_module/test_end2end_mix_precision.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ PYTHONPATH=.:$PYTHONPATH torchrun \ --nproc_per_node=4 \ From 76acedbc28f110ac40d3a633976df8ca5e108c47 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Thu, 28 Nov 2024 06:07:13 +0000 Subject: [PATCH 1776/1892] Merged PR 2331: Minor fixes for llama3 demo and misc changes 0. Update version to 0.6 1. With the latest transformers version `use_cache=False` is required for tracing (requirements.txt specifies an old version that does not require it) 2. When nnodes/nproc_per_node is large, the timestamp can mismatch. Using timestamp as log name in this case will prevent merging checkpoints because of different trainer args 3. Added a known environment issue to troubleshooting --- docs/source/troubleshooting.rst | 21 +++++++++++++++++++++ examples/llama3_demo/train.py | 11 +++++------ nnscaler/version.py | 2 +- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/docs/source/troubleshooting.rst b/docs/source/troubleshooting.rst index 99104fac..30e8e331 100644 --- a/docs/source/troubleshooting.rst +++ b/docs/source/troubleshooting.rst @@ -114,6 +114,8 @@ Run the following command: :: python -c 'import os,sys,nnscaler,cppimport.import_hook ; sys.path.append(os.path.dirname(nnscaler.__path__[0])) ; import nnscaler.autodist.dp_solver' +If it complains ``GLIBCXX_x.y.z`` not found, check the next issue. + Example stacktrace: :: Traceback (most recent call last): @@ -141,6 +143,25 @@ Example stacktrace: :: import nnscaler.autodist.dp_solver as dp_solver ModuleNotFoundError: No module named 'nnscaler.autodist.dp_solver' +"ImportError: ...... libstdc++.so.6: version \`GLIBCXX_x.y.z' not found" +------------------------------------------------------------------------- + +This is caused by gcc and glibc version mismatch. +Typically it means it's using the system gcc and conda's glibc. + +You can remove conda's glibc to force it use system glibc: :: + + rm /lib/libstdc++.so.6 + +The path is shown in the error message. + +Example stacktrace: :: + + $ python -c 'import nnscaler,cppimport.import_hook ; import nnscaler.autodist.dp_solver' + Traceback (most recent call last): + File "", line 1, in + ImportError: /home/user/miniconda3/envs/user/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by .../nnscaler/autodist/dp_solver.cpython-310-x86_64-linux-gnu.so) + Incorrect Usages ================ diff --git a/examples/llama3_demo/train.py b/examples/llama3_demo/train.py index bfed4389..11850605 100644 --- a/examples/llama3_demo/train.py +++ b/examples/llama3_demo/train.py @@ -89,9 +89,12 @@ def __init__(self, model_id, from_scratch=False, num_hidden_layers=None): config = AutoConfig.from_pretrained(model_id) if num_hidden_layers: config.num_hidden_layers = num_hidden_layers + # using kv cache may fail nnscaler's tracing in certain transformers versions, + # and training does not need kv cache + config.use_cache = False self.model = AutoModelForCausalLM.from_config(config) else: - self.model = AutoModelForCausalLM.from_pretrained(model_id) + self.model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False) def forward(self, data): result = self.model( @@ -242,13 +245,9 @@ def collate(samples): resume_from=args.resume_from, ) - timestamp = datetime.now().strftime('%y%m%d%H%M%S') log_config = LogConfig( type=TensorBoardLogger, - args={ - 'name': f'llama3-example-{timestamp}', - 'root_dir': 'runs', - }, + args={'name': 'llama3_demo', 'root_dir': 'runs'}, ) trainer_args = TrainerArgs( diff --git a/nnscaler/version.py b/nnscaler/version.py index 2fa3e43b..49925383 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -__version__ = '0.5' +__version__ = '0.6' From b1eb6667cb880e048a1dc58e1b7b14af5483aa6c Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Thu, 28 Nov 2024 06:50:22 +0000 Subject: [PATCH 1777/1892] Merged PR 2329: [CI/Build] attn_implementation as an option in example/llama for V100 not support flash_attention_2 V100 not support flash_attention_2 --- examples/llama/train.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/llama/train.py b/examples/llama/train.py index e80f73fb..777838ad 100644 --- a/examples/llama/train.py +++ b/examples/llama/train.py @@ -59,9 +59,9 @@ def get_tokenizer(tokenizer_name_or_path, class WrapperModel(torch.nn.Module): - def __init__(self, model_id, enable_chunk_loss): + def __init__(self, model_id, enable_chunk_loss, attn_implementation='flash_attention_2'): super().__init__() - self.model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation='flash_attention_2') + self.model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation) self.enable_chunk_loss = enable_chunk_loss def forward(self, samples): @@ -183,6 +183,7 @@ def collate_fn(samples): args={ 'model_id': args.model_id, 'enable_chunk_loss': args.enable_chunk_loss, + 'attn_implementation': args.attn_implementation, }, ) @@ -344,6 +345,12 @@ def collate_fn(samples): type=int, help='max training steps', ) + parser.add_argument( + '--attn_implementation', + default='flash_attention_2', + type=str, + help='attn implementation, can be flash_attention_2, spda or eager', + ) args = parser.parse_args() if args.explore_pipeline and not args.pipeline_pivots: raise ValueError('pipeline_pivots must be specified when explore_pipeline is enabled') From 78009275ecaab46bdad76c5b29131c421e4359af Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 29 Nov 2024 08:38:18 +0000 Subject: [PATCH 1778/1892] Merged PR 2327: [Reorg] Refine DistAlgo logic [Reorg]Refine DistAlgo logic 1. Split DistAlgorithmFactory.algorithms to two functions. 2. Add mro support (algorithm checks all registered distalgo along mro) unit test pass parity check pass --- docs/source/dimops.md | 2 +- .../ring_attention/test_ring_attn.py | 2 +- .../ring_attention/test_zigzag_attn.py | 2 +- examples/utils.py | 2 +- examples/vision/swin/policy/gallery.py | 2 +- nnscaler/algorithm/factory.py | 96 +++++++++++-------- nnscaler/algorithm/ops/dimops.py | 6 +- nnscaler/autodist/cost_database.py | 22 ++--- nnscaler/autodist/cube_operator.py | 2 +- nnscaler/autodist/op_partition.py | 4 +- nnscaler/autodist/spmd_solver.py | 2 +- nnscaler/autodist/util.py | 8 +- nnscaler/graph/function/dimops.py | 21 +--- nnscaler/ir/cten.py | 18 ++-- nnscaler/ir/operator.py | 38 ++++---- nnscaler/policies.py | 6 +- tests/algorithm/test_factory.py | 81 ++++++++++++++++ tests/compiler/test_compile.py | 2 +- tests/graph/function/test_dimops.py | 2 +- tests/graph/gener/test_producer_fusion.py | 4 +- tests/graph/gener/test_reducer_gen.py | 6 +- tests/graph/parser/test_register.py | 6 +- tests/graph/test_multiref.py | 4 +- tests/graph/test_segment.py | 20 ++-- tests/parallel_module/test_e2e_detach_loss.py | 6 +- tests/parallel_module/test_normlayer.py | 4 +- tests/runtime/test_gnorm.py | 4 +- tests/runtime/test_module_merge.py | 2 +- tests/runtime/test_reducer.py | 2 +- utility/verify_ops/verify_dimops.py | 14 +-- 30 files changed, 238 insertions(+), 152 deletions(-) create mode 100644 tests/algorithm/test_factory.py diff --git a/docs/source/dimops.md b/docs/source/dimops.md index df7d4288..3e3b7270 100644 --- a/docs/source/dimops.md +++ b/docs/source/dimops.md @@ -96,7 +96,7 @@ During policy decsion, user can see the operator and its name is 'matmul_custom' def PAS(graph: IRGraph, resource): for node in graph.nodes(): if node.name == 'matmul_custom': - algo = node.algorithms('dim') + algo = node.algorithm('dim') # partition kd+ config = dict(idx=0, dim=1, num=resource.ngpus) subnodes = graph.partition(node, algo, **config) diff --git a/examples/customized_ops/ring_attention/test_ring_attn.py b/examples/customized_ops/ring_attention/test_ring_attn.py index 62104533..757ee84e 100644 --- a/examples/customized_ops/ring_attention/test_ring_attn.py +++ b/examples/customized_ops/ring_attention/test_ring_attn.py @@ -62,7 +62,7 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: if not partitioned and node.signature == 'ring_attn.wrap_ring_attn_func': print('Partitioned node: ', node) sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=0, dim=1, num=ngpus) + node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) partitioned = True else: sub_nodes = graph.replicate(node, times=ngpus) diff --git a/examples/customized_ops/ring_attention/test_zigzag_attn.py b/examples/customized_ops/ring_attention/test_zigzag_attn.py index c00f082e..e97a5b88 100644 --- a/examples/customized_ops/ring_attention/test_zigzag_attn.py +++ b/examples/customized_ops/ring_attention/test_zigzag_attn.py @@ -62,7 +62,7 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: if not partitioned and node.signature == 'zigzag_attn.wrap_zigzag_attn_func': print('Partitioned node: ', node) sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=0, dim=1, num=ngpus) + node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) partitioned = True else: sub_nodes = graph.replicate(node, times=ngpus) diff --git a/examples/utils.py b/examples/utils.py index 25eea356..a574b0e0 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -102,7 +102,7 @@ def tensor_parallelism(graph: IRGraph, node: IRDimops, graph.assign(node, devs[0]) return [node] # transformation - algo = node.algorithms('dim') + algo = node.algorithm('dim') sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) assert sub_nodes is not None diff --git a/examples/vision/swin/policy/gallery.py b/examples/vision/swin/policy/gallery.py index 22554a23..ac3b2177 100644 --- a/examples/vision/swin/policy/gallery.py +++ b/examples/vision/swin/policy/gallery.py @@ -21,7 +21,7 @@ def coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, idx: int, dim: int): - algo = node.algorithms('dim') + algo = node.algorithm('dim') sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) assert sub_nodes is not None graph.recompute(sub_nodes) diff --git a/nnscaler/algorithm/factory.py b/nnscaler/algorithm/factory.py index e72bdc09..c9ae88d7 100644 --- a/nnscaler/algorithm/factory.py +++ b/nnscaler/algorithm/factory.py @@ -1,64 +1,76 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Dict, Any +from typing import Dict, Any, Union, List, Optional, Type, overload +from nnscaler.algorithm.generics import GenericDistAlgo -class DistAlgorithmFactory: - class __DistAlgorithmFactory: - - def __init__(self): - # [LogicOp][tag] = algorithm - self._algos: Dict[Any, Dict[str, Any]] = dict() - - instance = None - +class _DistAlgorithmFactory: def __init__(self): - if not DistAlgorithmFactory.instance: - DistAlgorithmFactory.instance = DistAlgorithmFactory.__DistAlgorithmFactory() - self._load_predefined_algos() + self._algos: dict[type, dict[str, type[GenericDistAlgo]]] = {} + self._load_predefined_algos() - def __getattr__(self, name): - return getattr(self.instance, name) - - def exist(self, op, tag=None): + def exist(self, op: Type, tag: Optional[str] = None): """ Check if the factory has op's algorithm recorded Returns: True if have, False if not """ - if tag is None: - return op in self.instance._algos - else: - return op in self.instance._algos and tag in self.instance._algos[op] - - def register(self, op, algorithm, tag: str): + for op_class in op.mro(): + if op_class not in self._algos: + continue + if tag is None or tag in self._algos[op_class]: + return True + return False + + def register(self, op, algorithm: type[GenericDistAlgo], tag: str): """ - Register a holistic op (class) as one of the anchors + Register a holistic op (class) as one of the anchors """ - if op not in self.instance._algos: - self.instance._algos[op] = dict() - self.instance._algos[op][tag] = algorithm + if op not in self._algos: + self._algos[op] = dict() + self._algos[op][tag] = algorithm - def algorithms(self, op, tag = None): + def algorithms(self, op: Type) -> List[GenericDistAlgo]: """ - Get op tranformed algorithms + Get all transform algorithms for the op Args: - op (IRFwOperation): index for the holist op factory - args, kwargs: (logical) tensor inputs + op (IRFwOperation): the op to be transformed Returns: - algorithm class + List[GenericDistAlgo]: the algorithms for the op """ - if op not in self.instance._algos: - raise KeyError("Op {op} is not registered in factory") - if tag: - return self.instance._algos[op][tag] - else: - return self.instance._algos[op].values() + algos = [self._algos[op_class] for op_class in op.mro() if op_class in self._algos] + # use dict to remove duplicates and keep order + algos_all: dict[type[GenericDistAlgo], None] = {} + for tag_algo_map in algos: + for algo in tag_algo_map.values(): + algos_all[algo] = None + return list(algos_all.keys()) + + def algorithm(self, op: Type, tag: str) -> GenericDistAlgo: + """ + Get best matched tranform algorithm for the op with tag + + Args: + op (IRFwOperation): the op to be transformed + tag (str): the tag of the algorithm + + Returns: + GenericDistAlgo: the algorithm for the op + + Raises: + ValueError: if the op + tag is not registered in the factory + """ + for op_class in op.mro(): + if op_class not in self._algos: + continue + if tag in self._algos[op_class]: + return self._algos[op_class][tag] + raise ValueError("Op {op} + Tag {tag} is not registered in factory") def _load_predefined_algos(self): @@ -70,3 +82,11 @@ def _load_predefined_algos(self): self.register(conv.IRConv2D, conv.DimSplitConv2D, tag='dim') self.register(conv.IRConv2D, conv.HaloSplitConv2D, tag='halo') self.register(conv.IRConv3D, conv.HaloSplitConv3D, tag='halo') + + +_instance: Optional[_DistAlgorithmFactory] = None +def DistAlgorithmFactory() -> _DistAlgorithmFactory: + global _instance + if _instance is None: + _instance = _DistAlgorithmFactory() + return _instance diff --git a/nnscaler/algorithm/ops/dimops.py b/nnscaler/algorithm/ops/dimops.py index f1aa1cae..668e4507 100644 --- a/nnscaler/algorithm/ops/dimops.py +++ b/nnscaler/algorithm/ops/dimops.py @@ -286,7 +286,7 @@ def gen_partitions(node: IRFwOperation, ngpus: int, base: int = 2, depth: int = list will contain 4 instances: 1. matmul with shape (1024, 4096), (4096, 2048) -> (1024, 2048), this means no partition, replicate on 2 gpus 2. matmul with shape ( 512, 4096), (4096, 2048) -> ( 512, 2048), partition first input first dimension - 3. matmul with shape (1024, 2048), (2048, 2048) -> (1024, 2048), partition first input second dimension + 3. matmul with shape (1024, 2048), (2048, 2048) -> (1024, 2048), partition first input second dimension 4. matmul with shape (1024, 4096), (4096, 1024) -> (1024, 1024), partition second input second dimension Args: @@ -311,7 +311,7 @@ def gen_hash(node: IRFwOperation) -> str: ret = ret + '-' + str(it.shape) return ret - dq = deque() + dq: deque[tuple[IRFwOperation, int, int]] = deque() visited = set() dq.append((node, ngpus, 0)) visited.add(gen_hash(node)) @@ -336,7 +336,7 @@ def gen_hash(node: IRFwOperation) -> str: if cur_ngpus % split_deg != 0: break - new_nodes = cur_node.algorithms('dim').instantiate(idx=idx_1st, dim=dim_1st, num=split_deg) + new_nodes = cur_node.algorithm('dim').instantiate(idx=idx_1st, dim=dim_1st, num=split_deg) # instantiate may return None if the partition is not possible if new_nodes is None: break diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 5a4451b0..0e24ddbb 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -209,15 +209,15 @@ def get_mems(self, op_partition): return memory_results def get_mem_and_buffer( - self, - op_partition, - is_train: bool, - stage_num: int, - world_size: int, - plan_ngpus: int, - zero_stage: int, - zero_ngroups: int, - opt_resident_coef: float, + self, + op_partition, + is_train: bool, + stage_num: int, + world_size: int, + plan_ngpus: int, + zero_stage: int, + zero_ngroups: int, + opt_resident_coef: float, opt_transient_coef: float ) -> Tuple[int, int, int, int, int]: """ @@ -523,10 +523,10 @@ def helper(primitive: str): dst_idx, dst_dim = dst_p.operator.dim_id2pos(dst_p_dim) rule_src, rule_dst = None, None if src_idx != -1: - rule_src = src_p.operator.ir_cell.algorithms('dim').infer( + rule_src = src_p.operator.ir_cell.algorithm('dim').infer( src_idx, src_dim, src_p_num) if dst_idx != -1: - rule_dst = dst_p.operator.ir_cell.algorithms('dim').infer( + rule_dst = dst_p.operator.ir_cell.algorithm('dim').infer( dst_idx, dst_dim, dst_p_num) cost = 0.0 for i, src_t in enumerate(src_p.operator.ir_cell.outputs()): diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py index 59be7bbe..f25f1ab5 100644 --- a/nnscaler/autodist/cube_operator.py +++ b/nnscaler/autodist/cube_operator.py @@ -136,7 +136,7 @@ def pos2dim_id(self, pos: Tuple[int, int]) -> str: if not isinstance(self.ir_cell, IRDimops): raise ValueError(f'{self.ir_cell} is not IRDimops') idx, dim = pos - adim, reduce_type = self.ir_cell.algorithms( + adim, reduce_type = self.ir_cell.algorithm( 'dim').get_identifier_reduce(idx, dim, 2) assert adim is not None, f'cannot find dim at {pos} in {self.ir_cell}' return adim diff --git a/nnscaler/autodist/op_partition.py b/nnscaler/autodist/op_partition.py index 00450454..46417ead 100644 --- a/nnscaler/autodist/op_partition.py +++ b/nnscaler/autodist/op_partition.py @@ -118,7 +118,7 @@ def __init__(self, partition_dims: Tuple[str, ...], if isinstance(self.operator.ir_cell, IRDimops): if partition_dims[0] != -1: idx, dim = operator.dim_id2pos(partition_dims[0]) - if not operator.ir_cell.algorithms('dim').satisfy( + if not operator.ir_cell.algorithm('dim').satisfy( idx, dim, partition_nums[0]): raise ValueError( f'invalid partition plan {partition_dims}, {partition_nums} for {operator.op_name}' @@ -129,7 +129,7 @@ def __init__(self, partition_dims: Tuple[str, ...], # 2. we can calculate th intra-communication cost without knowing the device assignment now, # since operator is constrained to be partitioned along one dimension. # It is used to query the computation cost in the cost database. - self.ir_cell = operator.ir_cell.algorithms('dim').instantiate( + self.ir_cell = operator.ir_cell.algorithm('dim').instantiate( idx, dim, partition_nums[0])[0] else: self.ir_cell = operator.ir_cell diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index d6a3f61a..4eb08889 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -366,7 +366,7 @@ def is_valid_partition(operator: CubeOperator, p_ids: List[Any], return False if p_ids[0] != -1: - if not operator.ir_cell.algorithms('dim').satisfy( + if not operator.ir_cell.algorithm('dim').satisfy( p_idx, p_dim, p_nums[0]): return False return True diff --git a/nnscaler/autodist/util.py b/nnscaler/autodist/util.py index 62d7b3dd..bb1968c1 100644 --- a/nnscaler/autodist/util.py +++ b/nnscaler/autodist/util.py @@ -43,7 +43,7 @@ def get_node_arch(): # tensor parallelism def tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, dim: int): - algo = node.algorithms('dim') + algo = node.algorithm('dim') sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) assert sub_nodes is not None for devid, sub_node in zip(devs, sub_nodes): @@ -58,13 +58,13 @@ def replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): return sub_nodes -def partition_node(node: IRFwOperation, graph: IRGraph, devs: [int], +def partition_node(node: IRFwOperation, graph: IRGraph, devs: List[int], desc: NodePartitionDesc) -> None: min_dev_index = min(devs) tp_size = len(devs) info = desc.desc - dq = deque() + dq: deque[tuple[IRFwOperation, tuple[int, int]]] = deque() dq.append((node, (0, tp_size))) for (idx, dim), num in info: @@ -80,7 +80,7 @@ def partition_node(node: IRFwOperation, graph: IRGraph, devs: [int], sub_nodes = graph.replicate(u, times=num) else: assert idx >= 0 and dim >= 0 - algo = u.algorithms('dim') + algo = u.algorithm('dim') sub_nodes = graph.partition(u, algo, idx=idx, dim=dim, num=num) for i in range(num): cur_nodes.append((sub_nodes[i], sub_intervals[i])) diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index cfdc12dc..e544c50e 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -65,7 +65,7 @@ """ -from typing import Callable, Dict, Iterable, List, Union, Set, Tuple, Optional +from typing import Callable, Dict, Iterable, List, Union, Set, Tuple, Optional, overload import enum import re import string @@ -74,7 +74,6 @@ from nnscaler.ir.cten import IRTensor, IRObject from nnscaler.ir.operator import IRFwOperation -from nnscaler.algorithm.factory import DistAlgorithmFactory _kSpecialIdentifiers = ('*', '?') @@ -921,24 +920,6 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict return False return True - def algorithms(self, tag: Optional[str] = None): - factory = DistAlgorithmFactory() - if tag is None: - algos = list() - if factory.exist(type(self)): - algos += [template(self) for template in factory.algorithms(type(self))] - if factory.exist(IRDimops): - algos += [template(self) for template in factory.algorithms(IRDimops)] - return algos - else: - if factory.exist(type(self), tag): - template = factory.algorithms(type(self), tag) - return template(self) - if factory.exist(IRDimops, tag): - template = factory.algorithms(IRDimops, tag) - return template(self) - return None - def transform_space(self) -> List[Tuple[int, int]]: """ Get transformation space of the operator, the transformation space diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 68e1e571..2318bc6f 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -408,14 +408,15 @@ def __repr__(self) -> str: def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[IRObject]: """Get all IRObjects from a complex data structure - Supported complex of types: List, Tuple, Dict, IRTensor, IRObject + Supported complex of types: List, Tuple, Dict, Slice, IRTensor, IRObject Args: val (Any): the complex data structure to be modified _objects (List[IRObject] | None): if provided, the objects will be appened into this - @return _objects List[IRObject]: all IRObject + Return: + List[IRObject]: all IRObject """ _objects = [] if _objects is None else _objects if isinstance(val, (tuple, list)): @@ -432,10 +433,10 @@ def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[ return _objects @staticmethod - def modify_objects_of_complex(val: Any, modifier: Callable) -> Any: + def modify_objects_of_complex(val: Any, modifier: Callable[['IRObject'], 'IRObject']) -> Any: """Return a complex data structure with modified IRObjects - Supported complex of types: List, Tuple, Dict, IRTensor, IRObject + Supported complex of types: List, Tuple, Dict, Slice, IRTensor, IRObject Args: val (Any): the complex data structure to be modified @@ -693,19 +694,22 @@ def tosub_complex(cls, obj: Any) -> Any: return IRCell.modify_objects_of_complex(obj, modifier) @classmethod - def try_unwrap(cls, x: Union[Any, 'IRObject']) -> Any: + def try_unwrap(cls, x: Union[Any, 'IRObject'], unwrap_ir_tensor=False) -> Any: """ Unwrap the IRObject to its original value if it is an IRObject otherwise, go recursively. Args: x (Any): the object to unwrap + unwrap_ir_tensor (bool): whether unwrap IRTensor Returns: Any: the original value """ - if isinstance(x, IRObject) and not isinstance(x, IRTensor): - return x.value + if isinstance(x, IRObject): + if not isinstance(x, IRTensor) or unwrap_ir_tensor: + return x.value + return x elif isinstance(x, (list, tuple)): return type(x)(cls.try_unwrap(v) for v in x) elif isinstance(x, dict): diff --git a/nnscaler/ir/operator.py b/nnscaler/ir/operator.py index 516dbd4f..11bf7bad 100644 --- a/nnscaler/ir/operator.py +++ b/nnscaler/ir/operator.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Optional, Tuple, Any, Union, List +from typing import Optional, Tuple, Any, Union, List, overload import copy from nnscaler.ir.cten import IRCell, IRTensor, IRObject @@ -76,29 +76,29 @@ def recompute(self, group_id: Optional[int]): assert self._recompute == group_id, "The operator is set to recompute in another recompute group." self._recompute = group_id - def algorithms(self, tag: Optional[str] = None) -> Union[Tuple[GenericDistAlgo], GenericDistAlgo]: + def algorithms(self) -> List[GenericDistAlgo]: """ - get algorithm from algorithm factory + get all algorithms from algorithm factory - @param tag Optional[str]: the queried tag (default None for all) + Returns: + List[GenericDistAlgo]: all possible algorithms + """ + factory = DistAlgorithmFactory() + return [template(self) for template in factory.algorithms(type(self))] + + def algorithm(self, tag: str) -> GenericDistAlgo: + """ + get a specific algorithm from algorithm factory + + Args: + tag (str): the tag of the algorithm - @return algorithm(s) Union[Tuple[GenericDistAlgo], GenericDistAlgo]: - If None (default), return all possible algorithms. - Otherwise, return the specified one. + Returns: + GenericDistAlgo: the algorithm """ factory = DistAlgorithmFactory() - if tag is None: - templates = list() - if factory.exist(type(self)): - templates = factory.algorithms(type(self)) - algos = list() - for template in templates: - algos.append(template(self)) - return algos - else: - assert factory.exist(type(self), tag), f"Node {self} doesn't have transformation algorithm tag: {tag}" - template = factory.algorithms(type(self), tag) - return template(self) + template = factory.algorithm(type(self), tag) + return template(self) def replicate(self): """! diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 9c0c5b0d..7275157c 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -44,7 +44,7 @@ def _tp(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int): if len(devs) > 1: sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) + node, node.algorithm('dim'), idx=idx, dim=dim, num=len(devs)) else: sub_nodes = [node] for devid, sub_node in zip(devs, sub_nodes): @@ -103,7 +103,7 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): random.shuffle(configs) for (idx, dim) in configs: if node.input(idx).shape[dim] % len(devs) != 0: continue - if node.algorithms('dim').satisfy(idx=idx, dim=dim, num=len(devs)): + if node.algorithm('dim').satisfy(idx=idx, dim=dim, num=len(devs)): _tp(graph, node, devs, idx, dim) break else: @@ -142,7 +142,7 @@ def pas_data(graph: IRGraph, env_resource: 'ComputeConfig'): for node in graph.nodes(): if isinstance(node, IRFwOperation): try: - algo = node.algorithms('dim') + algo = node.algorithm('dim') idx = 0 sub_nodes = graph.partition( node, algo, idx=idx, dim=batch_dim, num=ngpus) diff --git a/tests/algorithm/test_factory.py b/tests/algorithm/test_factory.py new file mode 100644 index 00000000..07735efa --- /dev/null +++ b/tests/algorithm/test_factory.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest +from nnscaler.algorithm.factory import GenericDistAlgo, _DistAlgorithmFactory + + +def test_mro(): + factory = _DistAlgorithmFactory() + factory._algos.clear() + + class A: + pass + + class B(A): + pass + + class C(B, A): + pass + + class D(C): + pass + + class AlgoA(GenericDistAlgo): + pass + + class AlgoB(GenericDistAlgo): + pass + + class AlgoC(GenericDistAlgo): + pass + + class AlgoA2(GenericDistAlgo): + pass + + class AlgoB2(GenericDistAlgo): + pass + + class AlgoC2(GenericDistAlgo): + pass + + factory.register(A, AlgoA, 'tag') + factory.register(B, AlgoB, 'tag') + factory.register(C, AlgoC, 'tag') + + # different tag with diffent algorithm + factory.register(A, AlgoA2, 'tag2') + factory.register(B, AlgoB2, 'tag2') + factory.register(C, AlgoC2, 'tag2') + + # different tag with the same algorithm + factory.register(A, AlgoA, 'tag3') + factory.register(B, AlgoB, 'tag3') + factory.register(C, AlgoC, 'tag3') + + assert factory.algorithms(D) == [AlgoC, AlgoC2, AlgoB, AlgoB2, AlgoA, AlgoA2] + assert factory.algorithms(C) == [AlgoC, AlgoC2, AlgoB, AlgoB2, AlgoA, AlgoA2] + assert factory.algorithms(B) == [AlgoB, AlgoB2, AlgoA, AlgoA2] + assert factory.algorithms(A) == [AlgoA, AlgoA2] + + assert factory.algorithm(D, 'tag3') == AlgoC + assert factory.algorithm(D, 'tag2') == AlgoC2 + assert factory.algorithm(D, 'tag') == AlgoC + with pytest.raises(ValueError): + factory.algorithm(D, 'tag4') + + assert factory.algorithm(C, 'tag3') == AlgoC + assert factory.algorithm(C, 'tag2') == AlgoC2 + assert factory.algorithm(C, 'tag') == AlgoC + with pytest.raises(ValueError): + factory.algorithm(C, 'tag4') + + assert factory.algorithm(B, 'tag3') == AlgoB + assert factory.algorithm(B, 'tag2') == AlgoB2 + assert factory.algorithm(B, 'tag') == AlgoB + + assert factory.algorithm(A, 'tag3') == AlgoA + assert factory.algorithm(A, 'tag2') == AlgoA2 + assert factory.algorithm(A, 'tag') == AlgoA + + diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index 5a6fbf1a..851d36d6 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -95,7 +95,7 @@ def tp_policy(graph: IRGraph, resource, ngpus_per_unit: int): def tensor_parallelism(node, idx, dim, num): sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=idx, dim=dim, num=num) + node, node.algorithm('dim'), idx=idx, dim=dim, num=num) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) return sub_nodes diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index 67ae3551..0065d9c1 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -25,7 +25,7 @@ def create_op(creator: Callable, def partitionable(node: IRDimops, **config): print(f'\n\n# {node.anno}') print(f'testing node: {node}') - sub_nodes = node.algorithms('dim').instantiate(**config) + sub_nodes = node.algorithm('dim').instantiate(**config) print(f'partitioned sub nodes:') for sub_node in sub_nodes: print(f'# {sub_node.anno}') diff --git a/tests/graph/gener/test_producer_fusion.py b/tests/graph/gener/test_producer_fusion.py index 22144064..9c458dc3 100644 --- a/tests/graph/gener/test_producer_fusion.py +++ b/tests/graph/gener/test_producer_fusion.py @@ -35,11 +35,11 @@ def test_gener_producer_fusion_replicate(): graph.assign(l1, 0) - s1, s2 = graph.partition(l2, l2.algorithms('dim'), idx=0, dim=0, num=2) + s1, s2 = graph.partition(l2, l2.algorithm('dim'), idx=0, dim=0, num=2) r1, r2 = graph.replicate(s1, 2) graph.assign(r1, 0) graph.assign(r2, 0) - s3, s4 = graph.partition(s2, s2.algorithms('dim'), idx=0, dim=1, num=2) + s3, s4 = graph.partition(s2, s2.algorithm('dim'), idx=0, dim=1, num=2) graph.assign(s3, 1) graph.assign(s4, 1) diff --git a/tests/graph/gener/test_reducer_gen.py b/tests/graph/gener/test_reducer_gen.py index 2a1e5a99..e88299cf 100644 --- a/tests/graph/gener/test_reducer_gen.py +++ b/tests/graph/gener/test_reducer_gen.py @@ -103,11 +103,11 @@ def test_reducer_partially_shared_part(): graph = build_graph() [matmul1, matmul2, add, sum] = graph.select(ntype=IRFwOperation) - m1, m2 = graph.partition(matmul1, matmul1.algorithms('dim'), idx=0, dim=1, num=2) + m1, m2 = graph.partition(matmul1, matmul1.algorithm('dim'), idx=0, dim=1, num=2) graph.assign(m1, 0) graph.assign(m2, 1) - add1, add2 = graph.partition(add, add.algorithms('dim'), idx=0, dim=1, num=2) + add1, add2 = graph.partition(add, add.algorithm('dim'), idx=0, dim=1, num=2) graph.assign(add1, 0) graph.assign(add2, 1) @@ -133,7 +133,7 @@ def pas_intra_reducer(graph: IRGraph, config: ComputeConfig): for i, node in enumerate(fw_nodes): if i == 1: - sn0, sn1 = graph.partition(node, node.algorithms('dim'), idx=1, dim=0, num=2) + sn0, sn1 = graph.partition(node, node.algorithm('dim'), idx=1, dim=0, num=2) else: sn0, sn1 = graph.replicate(node, 2) graph.assign(sn0, 0) diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index fff23ec1..d13e2db6 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -127,7 +127,7 @@ def test_autograd_register(): node = ir_graph.select(name='mock_view_with_obj')[0] assert node.kwargs['h'] == 4 - sub_nodes = ir_graph.partition(node, node.algorithms('dim'), idx=0, dim=0, num=2) + sub_nodes = ir_graph.partition(node, node.algorithm('dim'), idx=0, dim=0, num=2) for sub_node in sub_nodes: assert sub_node.kwargs['h'] == 2 @@ -211,11 +211,11 @@ def test_transform_rule(): ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) add_node0 = ir_graph.nodes()[2] add_node1 = ir_graph.nodes()[5] - sub0, sub1 = ir_graph.partition(add_node0, add_node0.algorithms('dim'), idx=0, dim=0, num=2) + sub0, sub1 = ir_graph.partition(add_node0, add_node0.algorithm('dim'), idx=0, dim=0, num=2) assert sub0.kwargs['z'] == 10 assert sub1.kwargs['z'] == 20 - sub2, sub3 = ir_graph.partition(add_node1, add_node1.algorithms('dim'), idx=0, dim=1, num=2) + sub2, sub3 = ir_graph.partition(add_node1, add_node1.algorithm('dim'), idx=0, dim=1, num=2) assert sub2.kwargs['z'] == 10 assert sub3.kwargs['z'] == 10 diff --git a/tests/graph/test_multiref.py b/tests/graph/test_multiref.py index 4b4f1f4a..ac0e28ad 100644 --- a/tests/graph/test_multiref.py +++ b/tests/graph/test_multiref.py @@ -78,14 +78,14 @@ def policy(graph: IRGraph, resource): first_add = graph.select('add')[0] sub_muls = graph.partition( - first_mul, first_mul.algorithms('dim'), + first_mul, first_mul.algorithm('dim'), idx=0, dim=0, num=resource.ngpus ) for idx, sub_node in enumerate(sub_muls): graph.assign(sub_node, idx) sub_adds = graph.partition( - first_add, first_add.algorithms('dim'), + first_add, first_add.algorithm('dim'), idx=0, dim=0, num=resource.ngpus ) for idx, sub_node in enumerate(sub_adds): diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py index 3024faad..a6af9d3f 100644 --- a/tests/graph/test_segment.py +++ b/tests/graph/test_segment.py @@ -88,25 +88,25 @@ def forward(self, q): q = q.transpose(0, 1) l = q.sum() return l, l.data - - + + def policy_transpose(graph: IRGraph, resource: ComputeConfig) -> IRGraph: ngpus = resource.plan_ngpus for _, node in enumerate(graph.select(ntype=IRFwOperation)): print(node.signature) if node.signature in ["torch.transpose"]: sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=0, dim=0, num=ngpus) + node, node.algorithm('dim'), idx=0, dim=0, num=ngpus) else: sub_nodes = graph.replicate(node, times=ngpus) - + for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) for node in graph.select(ntype=IRDataOperation): sub_nodes = graph.replicate(node, times=ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) - + return graph @@ -158,7 +158,7 @@ def forward(self, q): def policy_nograd(graph: IRGraph, cfg: ComputeConfig) -> IRGraph: ngpus = cfg.plan_ngpus # print(graph.nodes()) - if cfg.use_end2end: + if cfg.use_end2end: fc1_node = graph.nodes()[1] func_node = graph.nodes()[2] else: @@ -177,16 +177,16 @@ def policy_nograd(graph: IRGraph, cfg: ComputeConfig) -> IRGraph: # print(node.signature) if node.signature == 'torch.nn.functional.linear': sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=0, dim=0, num=ngpus) + node, node.algorithm('dim'), idx=0, dim=0, num=ngpus) elif node.signature == 'torch.sum': sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=0, dim=1, num=ngpus) + node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) elif 'func' in node.signature: sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=0, dim=0, num=ngpus) + node, node.algorithm('dim'), idx=0, dim=0, num=ngpus) else: sub_nodes = graph.replicate(node, times=ngpus) - + for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) diff --git a/tests/parallel_module/test_e2e_detach_loss.py b/tests/parallel_module/test_e2e_detach_loss.py index eb2b6a53..f2284af9 100644 --- a/tests/parallel_module/test_e2e_detach_loss.py +++ b/tests/parallel_module/test_e2e_detach_loss.py @@ -63,7 +63,7 @@ def policy_pp(graph, cfg): for i, sub_node in enumerate(sub_nodes): graph.assign(sub_node, i) - sub_nodes = graph.partition(fc1, fc1.algorithms('dim'), idx=0, dim=0, num=2) + sub_nodes = graph.partition(fc1, fc1.algorithm('dim'), idx=0, dim=0, num=2) graph.assign(sub_nodes[0], 0) graph.assign(sub_nodes[1], 1) @@ -72,11 +72,11 @@ def policy_pp(graph, cfg): graph.assign(sub_nodes[0], 2) graph.assign(sub_nodes[1], 3) - sub_nodes = graph.partition(fc2, fc2.algorithms('dim'), idx=0, dim=0, num=2) + sub_nodes = graph.partition(fc2, fc2.algorithm('dim'), idx=0, dim=0, num=2) graph.assign(sub_nodes[0], 2) graph.assign(sub_nodes[1], 3) - sub_nodes = graph.partition(loss, loss.algorithms('dim'), idx=0, dim=0, num=2) + sub_nodes = graph.partition(loss, loss.algorithm('dim'), idx=0, dim=0, num=2) graph.assign(sub_nodes[0], 2) graph.assign(sub_nodes[1], 3) diff --git a/tests/parallel_module/test_normlayer.py b/tests/parallel_module/test_normlayer.py index aa9918be..ff3c828a 100644 --- a/tests/parallel_module/test_normlayer.py +++ b/tests/parallel_module/test_normlayer.py @@ -35,7 +35,7 @@ def policy(graph: IRGraph, resource: ComputeConfig, dim: int) -> IRGraph: ): print("Partitioned node: ", node) sub_nodes = graph.partition( - node, node.algorithms("dim"), idx=0, dim=dim, num=ngpus + node, node.algorithm("dim"), idx=0, dim=dim, num=ngpus ) partitioned = True elif ( @@ -45,7 +45,7 @@ def policy(graph: IRGraph, resource: ComputeConfig, dim: int) -> IRGraph: ): print("Partitioned node: ", node) sub_nodes = graph.partition( - node, node.algorithms("dim"), idx=0, dim=0, num=ngpus + node, node.algorithm("dim"), idx=0, dim=0, num=ngpus ) partitioned = True else: diff --git a/tests/runtime/test_gnorm.py b/tests/runtime/test_gnorm.py index 5d8c4a87..80025906 100644 --- a/tests/runtime/test_gnorm.py +++ b/tests/runtime/test_gnorm.py @@ -38,9 +38,9 @@ def forward(self, x): return torch.sum(x) -def tensor_parallelism(graph, node, idx, dim, num): +def tensor_parallelism(graph, node: IRFwOperation, idx, dim, num): sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=idx, dim=dim, num=num) + node, node.algorithm('dim'), idx=idx, dim=dim, num=num) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) return sub_nodes diff --git a/tests/runtime/test_module_merge.py b/tests/runtime/test_module_merge.py index 3e35c72a..9b4e8984 100644 --- a/tests/runtime/test_module_merge.py +++ b/tests/runtime/test_module_merge.py @@ -49,7 +49,7 @@ def tp_policy(graph, resource): for idx, node in enumerate(graph.select(ntype=IRFwOperation)): if node.name == 'add': sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=1, dim=idx % 2, num=resource.ngpus) + node, node.algorithm('dim'), idx=1, dim=idx % 2, num=resource.ngpus) else: sub_nodes = graph.replicate(node, times=resource.ngpus) for devid, node in enumerate(sub_nodes): diff --git a/tests/runtime/test_reducer.py b/tests/runtime/test_reducer.py index 2721c3ee..8b897f77 100644 --- a/tests/runtime/test_reducer.py +++ b/tests/runtime/test_reducer.py @@ -75,7 +75,7 @@ def policy(graph: IRGraph, resource): def tensor_parallelism(node, idx, dim, num): sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=idx, dim=dim, num=num) + node, node.algorithm('dim'), idx=idx, dim=dim, num=num) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) return sub_nodes diff --git a/utility/verify_ops/verify_dimops.py b/utility/verify_ops/verify_dimops.py index f9c56eca..e34c1fa4 100644 --- a/utility/verify_ops/verify_dimops.py +++ b/utility/verify_ops/verify_dimops.py @@ -48,21 +48,21 @@ class TestModule(torch.nn.Module): def __init__(self): super(TestModule, self).__init__() - + def forward(self, {args}): # Add clone to resolve the issue: # a leaf Variable that requires grad is being used in an in-place operation. {clone_args} - + {func_sig_call} - + out = 0 for one_out in [{outputs}]: if not isinstance(one_out, torch.Tensor): continue out += torch.sum(one_out) return out - + model = TestModule() #.to(torch.float16) """ @@ -91,18 +91,18 @@ def forward(self, {args}): def policy(graph: IRGraph, resource) -> IRGraph: ngpus = 2 partitioned = False - + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): if not partitioned and node.signature == '{func_sig}': print('Partitioned node: ', node) sub_nodes = graph.partition( - node, node.algorithms('dim'), idx={idx}, dim={dim}, num=ngpus) + node, node.algorithm('dim'), idx={idx}, dim={dim}, num=ngpus) partitioned = True else: sub_nodes = graph.replicate(node, times=ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) - + assert partitioned, f'No node is partitioned for {func_sig}.' return graph From ba1ec5bca6c365d6b89230cbd31f88f8f999f3df Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 29 Nov 2024 14:25:18 +0000 Subject: [PATCH 1779/1892] Merged PR 2317: [BugFix] Fix wrong loss name in generated code Fix regression triggered by [PR](https://dev.azure.com/msrasrg/SuperScaler/_git/MagicCube/pullrequest/2298), same variable name (nll_loss_23127) is shared by tensors with different semantics (requires_grad): ![image.png](https://dev.azure.com/msrasrg/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2317/attachments/image.png) In this PR, we assume PAS will insert multiref node correctly to handle IRPyfunc. Note this PR also updates the logic that adding multiref in autodist, this may influence the numerical values of gradients since the order of accumulation is changed. Related work items: #2080 --- nnscaler/autodist/apis.py | 140 +++--- nnscaler/graph/gener/gen.py | 2 + nnscaler/graph/segment.py | 201 ++++++-- nnscaler/policies.py | 27 +- tests/autodist/pas/multiref_plan1.json | 100 ++++ tests/autodist/pas/multiref_plan2.json | 100 ++++ .../autodist/pas/test_multiref_activation.py | 163 ++++++ tests/autodist/pas/test_multiref_param.py | 82 +++ .../pas/test_shared_param_pipeline.py | 83 --- tests/graph/test_loss.py | 12 +- tests/graph/test_segment.py | 7 +- .../test_shared_param_pipeline.py | 473 ++++++++++++++++++ 12 files changed, 1169 insertions(+), 221 deletions(-) create mode 100644 tests/autodist/pas/multiref_plan1.json create mode 100644 tests/autodist/pas/multiref_plan2.json create mode 100644 tests/autodist/pas/test_multiref_activation.py create mode 100644 tests/autodist/pas/test_multiref_param.py delete mode 100644 tests/autodist/pas/test_shared_param_pipeline.py create mode 100644 tests/parallel_module/test_shared_param_pipeline.py diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 31b9b1ad..beae2a56 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -11,15 +11,16 @@ from nnscaler.graph import IRGraph from nnscaler.graph.segment import IRSegment from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.ir.tensor import IRSubTensor +from nnscaler.ir import IRCell from nnscaler.graph.function import IRDimops from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.schedule.predefined import PredefinedSched import json -import os import logging -import time +import copy from typing import Dict, List from pathlib import Path from collections import defaultdict @@ -37,6 +38,7 @@ def check_env(autodist_config: AutoDistConfig): _logger.info(f'create folder: {arch_dir}') arch_dir.mkdir(parents=True, exist_ok=True) + def pre_estimate_mem(graph: ModelGraph): ''' Estimate a rough lower bound of memory consumption per device. Exit if the model is too large @@ -116,9 +118,7 @@ def parallelize_graph(graph: IRGraph, search_out_json = json.load(f) search_out = PipelineSearchOutput.from_json(search_out_json) else: - compile_start_time = time.time() search_out = calc_parallel_plan(graph, autodist_config) - compile_cost_time = time.time() - compile_start_time if autodist_config.save_plan_path: _logger.info(f'save plan to {autodist_config.save_plan_path}') @@ -138,57 +138,64 @@ def parallelize_graph(graph: IRGraph, nodes = [cid2node[cid] for cid in group] graph.recompute(nodes) + def subtensor_desc(t): + return (t.indmap, t.grad is not None) + tensor_split_info = defaultdict(dict) + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): + continue + consumers = graph.consumers(ftensor) + if not consumers: + continue + for consumer in consumers: + find_desc = False + for stage_idx, stage_desc in enumerate(pp_desc.spmd_descs): + if consumer.cid not in stage_desc.partition_descs: + continue + find_desc = True + node_desc = stage_desc.partition_descs[consumer.cid].desc + if len(node_desc) != 1: + raise RuntimeError(f'node {consumer} is partitioned along multiple dims') + + (p_idx, p_dim), p_num = node_desc[0] + if p_idx == -1: + partitioned_node = consumer + else: + partitioned_nodes = consumer.algorithm('dim').instantiate(idx=p_idx, dim=p_dim, num=p_num) + if partitioned_nodes is None: + raise RuntimeError(f'node {consumer} cannot be partitioned by {p_idx}-{p_dim}-{p_num}') + partitioned_node = partitioned_nodes[0] + + if stage_idx not in tensor_split_info[ftensor]: + tensor_split_info[ftensor][stage_idx] = set() + for input in partitioned_node.inputs(): + if isinstance(input, IRSubTensor) and input.parent == ftensor: + if p_idx == -1 and stage_desc.mesh_desc.ngpus > 1: + tensor_split_info[ftensor][stage_idx].add(('REPLICATED', subtensor_desc(input))) + else: + # special case: if the stage has only one gpu, we treat it as partitioned + tensor_split_info[ftensor][stage_idx].add(('PARTITIONED', subtensor_desc(input))) + break + assert find_desc, f'node {consumer} not found in any stage' + # graph staging if len(pp_desc.spmd_descs) > 1: # add multiref for shared parameters across stages - shared_param2stage_info = defaultdict(dict) - for ftensor in graph.attributes(): + # note that we have constrained that shared parameters cannot + # be partitioned in SPMDSolver. + for ftensor, stage_info in tensor_split_info.items(): if not ftensor.is_param(): continue - for ctensor, consumer in zip(graph.ctensors(ftensor), - graph.consumers(ftensor)): - if ctensor.grad is None: - continue - for stage_idx, stage_desc in enumerate(pp_desc.spmd_descs): - if consumer.cid in stage_desc.partition_descs: - if len(stage_desc.partition_descs[ - consumer.cid].desc) != 1: - raise RuntimeError( - f'node {consumer} has more than one partition dim' - ) - (p_idx, p_dim), p_num = stage_desc.partition_descs[ - consumer.cid].desc[0] - if p_idx != -1 and consumer.inputs()[p_idx] == ftensor: - raise RuntimeError( - f'node {consumer} has partitioned input {ftensor}' - ) - is_replicated = p_idx == -1 - if stage_idx not in shared_param2stage_info[ftensor]: - shared_param2stage_info[ftensor][stage_idx] = [] - shared_param2stage_info[ftensor][stage_idx].append( - is_replicated) - - for ftensor, stage_info in shared_param2stage_info.items(): - if len(stage_info) == 1: - continue - # special case: all stages have only one gpu - stage_idxs = list(stage_info.keys()) - stage_sizes = [ - pp_desc.spmd_descs[i].mesh_desc.ngpus for i in stage_idxs - ] - if all([s == 1 for s in stage_sizes]): - continue - # check whether all partitioned - # In AutoDist, shared parameters are not allowed to be partitioned. - # As a result, the related operator is replicated or in data parallel. - has_replicated = False - for stage_idx, replicate_info in stage_info.items(): - if any(replicate_info): - has_replicated = True - break - if has_replicated: + splits = set() + find_replicated = False + for stage_splits in stage_info.values(): + splits.update(stage_splits) + if any(s[0] == 'REPLICATED' for s in stage_splits): + find_replicated = True + splits = list(splits) + if len(splits) > 1 or find_replicated: _logger.info(f'add multiref for shared param {ftensor}') - graph.multiref(ftensor) + graph.multiref(ftensor, comment='shared param') stages = [] for spmd_desc in pp_desc.spmd_descs: @@ -208,34 +215,21 @@ def parallelize_graph(graph: IRGraph, # if autodist_config.pipeline and len(stages) != autodist_config.pipeline_nstages: # raise RuntimeError("pipeline_nstages doesn't match the number of stages (based on your pipeline_pivots config) in the plan") - # add multiref to a tensor when - # 1. it is not a grad tensor - # 2. it has more than one consumers - # 3. consumers are different operators or in different partitions - for stage, spmd_desc in zip(stages, pp_desc.spmd_descs): + # add multiref to an activation tensor when the states of the tensor and its grad are different + # among consumers and current segment's outputs + for idx, (stage, spmd_desc) in enumerate(zip(stages, pp_desc.spmd_descs)): for ftensor in stage.full_tensors(): - if ftensor.is_grad(): + if ftensor.is_grad() or ftensor.is_param(): continue - if len(stage.consumers(ftensor)) <= 1: + if idx not in tensor_split_info[ftensor]: continue - consumers = stage.consumers(ftensor) - splits = set() - for consumer in consumers: - if consumer.cid in spmd_desc.partition_descs: - node_desc = spmd_desc.partition_descs[consumer.cid].desc - if len(node_desc) != 1: - raise RuntimeError( - f'node {consumer} has more than one partition desc') - (p_idx, p_dim), p_num = node_desc[0] - else: - _logger.warning( - f'node {consumer} is not in any partition desc') - p_idx, p_dim, p_num = -1, -1, spmd_desc.mesh_desc.ngpus - repr_str = f'{consumer.signature}-{p_idx}-{p_dim}-{p_num}' - splits.add(repr_str) + splits = copy.deepcopy(tensor_split_info[ftensor][idx]) + for output in IRCell.get_objects_from_complex(stage.outputs()): + if isinstance(output, IRSubTensor) and output.parent == ftensor: + splits.add(('REPLICATED', subtensor_desc(output))) if len(splits) > 1: - _logger.debug(f'add multiref {consumers}') - stage.multiref(ftensor) + _logger.debug(f'add multiref for {ftensor} in stage {stage}') + stage.multiref(ftensor, comment='activation') # partition and assign nodes to devices # TODO(yizhu1): network topo aware device map diff --git a/nnscaler/graph/gener/gen.py b/nnscaler/graph/gener/gen.py index 14c79c05..fbb0e398 100644 --- a/nnscaler/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -704,6 +704,7 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): ) multiref = MultiRef(devtensors[devid][0], len(grads)) + multiref.comment = 'created at IRAdapterGener:local_consumer_multiref' # set input gradient multiref.input(0).grad = accum_grad # set output and its gradient @@ -757,6 +758,7 @@ def autoref(graph: IRSegment) -> IRGraph: ptensors = sorted(ptensors, key=lambda t: t.device[0]) for tensor in ptensors: mr = MultiRef(tensor, len(multiref.outputs())) + mr.comment = f'create at IRAdapterGener:autoref, src tensor is {multiref.comment}' mr.input(0).grad = tensor.grad for idx, out in enumerate(multiref.outputs()): output = out.parent.select(tensor.indmap, tensor.valmap) diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index 0bfd429f..3c8ea635 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -720,20 +720,151 @@ def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwO # ===================== Advance Graph manipulations ================== - def multiref(self, ftensor: IRFullTensor, *deprecated_args) -> IRFwOperation: + def multiref(self, ftensor: IRFullTensor, comment: Optional[str] = None, *deprecated_args) -> IRFwOperation: """ - Add multiref to separate forward nodes that consume a same tensor into different tensor alias. - This should be called before any graph transformation. + Multiref accepts a full tensor that used in multiple places (consumed by a node, + or belongs to a graph's outputs). Its output tensors are full tensors with new + ids and dispatched to the corresponding consumers. + The input tensor can be parameter, buffer or activation tensors. - Operators in a group can only be partitioned by a same tensor split strategy. - The created multiref operator will be partitioned automatically when generating - tensor adapters. + Args: + tensor (IRSubTensor): full tensor to be multiref. - @param tensor IRSubTensor: tensor. - @return multiref IRFwOperation: the inserted multiref operator. + Returns: + multiref (IRFwOperation): the inserted multiref operator. + + This function should be called before any graph transformation, like replicate, + partition. The created multiref operator will be partitioned automatically when + generating adapters. + + multiref can be regarded as an approach to create different aliases for the input + full tensor, so we can overcome the limitation of communication generation logic + and correctly generate communications. + + At runtime, multiref just creates multiple tensors with the same storage, as the following + code snippet shows: + ```python + def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: + return tensor if times == 1 else tuple([tensor] * times) + ``` + + There are two kinds of communications in the system. + - Adapter: Which is used to exchange tensors data during forward and backward across + devices in the same scale unit. We use RVDLayout algorithm to generate adapters which + is composed of collective primitives at runtime. The limitation here is the communication + should be simple. If the communication is too complex, the generation will fail. + - IRWeightReducer: Which is used to sync weight (parameter) gradients across devices after + backward. IRWeightReducer will be mapped to nnscaler.runtime.adapter.Reducer in runtime. + The limitation here is the weight should be ALL partitioned or ALL replicated (check + gen_weight in IRAdapterGener). Put it in a simple word, reducer is added when the parameter + can be simply aggregated (summed) across certain devices. + + multiref is here to rescue. Some typical usage of this function are listed below with explanations. + + - Adapter generation case 1: If the full tensor has multiple consumers, and consumers consume + different portion of the full tensor (different tp partition). In this case, RVDLayout may + fail to generate backward communication. We should use multiref to create an alias for each + consumer. RVDLayout will generate communication between each alias and its consumer correctly. + The inserted multiref will aggregate the gradients automatically in the backward pass according + to the multiref's implementation and torch.autograd's mechanism. + + Example: If op1/op2 are consumers of fulltensor ft, and will be partitioned different: + op1(ft) + op2(ft) + multref should be inserted + ft1, ft2 = multiref(ft, 2) + op1(ft1) + op2(ft2) + Note that when op1 is replicated over multiple devices, op2 partitions its another input (not ft2), + although ft's indmap is same on op1 and op2, but we cannot add the gradients directly. As a result, + the multiref is needed too. + + - Adapter generation case 2: If the full tensor has multiple consumers, but these consumers have + different behavior in backward, ie, some of consumers generate gradient (normal torch ops), and + some of consumers don't generate gradient (mostly IRPyFunc). In this case, we need to use multiref, + so each alias can have different behaviors in backward. + + Example: If op1 (generate grad)/getitem (doesn't generate grad) are consumers of fulltensor ft, but with different backward behavior: + torch_op(ft) + getitem(ft) + multref should be inserted + ft1, ft2 = multiref(ft, 2) + torch_op(ft1) + getitem(ft2) + + - Adapter generation case 3: When the full tensor has consumers and also is graph's output (a specail + consumer). If consumers and graph outputs satisify case 1 or case 2, we also need to insert multiref. + It is a little difference with previous cases, because we don't update the tensor of graph outputs, + but use the old name. This is correct since the IRPyFunc and the segment's outputs are forced to be + replicated by the system. + + Example: If op is the only consumer of fulltensor ft: + op(ft) + return ft + multref should be inserted + ft1 = multiref(ft, 1) + op(ft1) + return ft # note old name is used. + + - IRWeightReducer generation case 1: when gradients over devices can not be accumulated directly to + synchronize. This typically happens when a parameter is shared, especially in pipeline parallelism. + Here we can use multiref to synchronize gradients, but the semantic is different. (TODO: add a new + function to handle this case to make it more clear). With multiref, the weight becomes an activation, + so no IRWeightReducer will be generated. Instead, Adapter will be used in runtime. + + Example: weight w is shared by two consumers, the distributed plan is two-stage pipeline, + stage 0 uses gpu 0 + stage 1 uses gpu (1, 2) + weights are all replicated in all devices. + If we don't insert multiref, the weight will be held in both stages, but the communication will fail to generate. + If we ignore the communication generation, the code will look like: + gencode0: + ``` + def __init__(....): + self.w = torch.nn.Parameter(...) + ... + def forward(...): + ... + op1(self.w) + ... + ``` + gencode 1: + ``` + def __init__(....): + self.w = torch.nn.Parameter(...) + ... + def forward(...): + ... + op2(self.w) + ... + ``` + We cannot sum up the gradients on gpu 0/1/2 directly. In logic, the real gradient should be a sum of gpu0 and gpu1's gradients or + gpu0 and gpu2's gradients. As the example shows, generating a reducer is hard in this case. Multiref is inserted to convert the + param to activation to bypass the difficulty, the code will look like: + gencode0: + ``` + def __init__(....): + self.w = torch.nn.Parameter(...) + ... + def forward(...): + ... + w1, w2 = multiref(self.w) + op1(w1) + ... + ``` + gencode1: + ``` + def __init__(....): + ... + def forward(w2, ...): + ... + op2(w2) + ... + ``` + You can see weight is gone in gencode1's constructor. Instead, it will be passed as forward argument. multiref here is just a way to + convert weight to activation, change the way to generate communication and aggregate gradients by adapters and multiref correctly. """ assert ftensor in self._fobjects, f"tensor: {ftensor} not in this graph." - if len(self.consumers(ftensor)) <= 1: return assert not ftensor.is_grad(), f"graph.multiref can only be applied on a non-gradient full tensor." # check no transformation assert len(self.ptensors(ftensor)) <= 1, f"no transformation should be called before multiref" @@ -746,6 +877,8 @@ def multiref(self, ftensor: IRFullTensor, *deprecated_args) -> IRFwOperation: otensors: List[IRSubTensor] = [ft.select(tensor.indmap, tensor.valmap) for ft in ftensors] # create multiref multiref = MultiRef(tensor, len(consumers)) + if comment: + multiref.comment = comment for idx, otensor in enumerate(otensors): multiref.set_output(idx, otensor) # setup gradient @@ -761,7 +894,10 @@ def multiref(self, ftensor: IRFullTensor, *deprecated_args) -> IRFwOperation: fidx = min(self.index(consumer) for consumer in self.consumers(ftensor)) else: fidx = max(self.index(prod) for prod in self.producers(ftensor)) + 1 - if req_grad: + # when the consumer is a IRPyFunc, the tensor at the consumer side will not have a grad + # in this case, we only insert the multiref node in the forward graph + req_backward = any(output.grad is not None for output in multiref.outputs()) + if req_backward: self.finsert(multiref, fidx) else: self.insert(multiref, fidx) @@ -857,6 +993,7 @@ def single_consume(self, one_for_all: bool = True): if len(cnodes) > 0: itensors = [ftensor.like() for _ in range(2)] multiref = MultiRef(reftensor, 2) + multiref.comment = 'create at IRSegment:single_consume' for idx, itensor in enumerate(itensors): multiref.set_output(idx, itensor) multiref.infer_shape() @@ -897,6 +1034,7 @@ def single_consume(self, one_for_all: bool = True): consumer.set_input(idx, itensor) # create and insert multiref operation multiref = MultiRef(ftensor, len(cnodes)) + multiref.comment = 'create at IRSegment:single_consume' for idx, itensor in enumerate(itensors): multiref.set_output(idx, itensor) multiref.infer_shape() @@ -1018,47 +1156,8 @@ def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> I inputs.add(itensor) for otensor in node.oobjs(): # if the tensor is required by segment outputs, set as output - # NOTE: there may be several identical tensors in `segment_outputs`, for example when - # user code contains `return t, t`. To handle this case, we will break the loop once - # finding a matched tensor. - seg_out_matched = False - for seg_out in segment_outputs: - if otensor != seg_out: - continue - # Since we set a consume tensor's grad to None if it is a input of a IRPyFunc, - # it is possible that two tensors share a same id but with different grad. - # For example, consider the following case: - # t1 = dimops(xx) - # t2 = pyfunc(t1) - # return t1, t2 - # Furtherly assume: - # - t1 requires grad - # - `dimops` is partitioned along `xx`'s batch dim - # - `pyfunc` is replicated. - # In nnscaler, there is one fulltensor representing `t1` and several subtensors. - # For simplicity, we name them separately: - # - when t1 is at the output of `dimops`, it is called `t1_a` - # - when t1 is at the input of `pyfunc`, it is called `t1_b` - # - when t1 is the output of the segment, it is called `t1_c` - # In this case, `t1_b`'s grad is set to None (check func `infer_grad`), but the - # the `t1_a` and `t1_c`'s grads are not None. The three subtensors share the same - # id inherited from the fulltensor, which means they are considered as the same - # tensor when calling `__equal__` method. - # Since partition plans for operators are different, there will be two adapters - # in communication generation: - # - between `t1_a` and the returned `t1_c`, this adapter's forward - # output subtensor's grad is not None and share same id with `t1_c`. - # - between `t1_a` and the consumed `t1_b`, this adapter doesn't have - # a mirror (backward op) and its output subtensor's grad is None. Its id is - # the identical to `t1_b` as well. - # `create_segment` is called after `gen_activation`, and we construct the segment's - # output here. However, we don't want to treat `t1_b` as output. To distinguish between - # output tensors of the two adapters, we double check the grad here. - if not isinstance(otensor, IRSubTensor) or otensor.grad == seg_out.grad: - outputs.add(otensor) - seg_out_matched = True - break - if seg_out_matched: + if otensor in segment_outputs: + outputs.add(otensor) continue consumers, ctensors = self.consumers(otensor.parent), self.ctensors(otensor.parent) cids = set(c.cid for c, t in zip(consumers, ctensors) if dmatch(t, otensor)) diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 7275157c..4b6ab30a 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -32,6 +32,7 @@ from nnscaler.graph.function.dimops import IRDimops from nnscaler.graph.segment import IRSegment from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.ir import IRCell, IRSubTensor, IRFullTensor if TYPE_CHECKING: @@ -59,6 +60,21 @@ def _replica(graph: IRGraph, node, devs: List[int]): return sub_nodes +def is_tensor_in_output(t: IRFullTensor, graph: IRSegment) -> bool: + for output in IRCell.get_objects_from_complex(graph.outputs()): + if isinstance(output, IRSubTensor) and output.parent == t: + return True + return False + + +def auto_multiref(graph: IRGraph): + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): continue + in_output = int(is_tensor_in_output(ftensor, graph)) + if len(graph.consumers(ftensor)) + in_output > 1: + graph.multiref(ftensor, comment='auto_multiref') + + def pas_dp(graph: IRGraph, cfg: 'ComputeConfig'): """ pure data parallelism policy @@ -85,10 +101,7 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): random.seed(seed) devs = list(range(ngpus)) - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): continue - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor) + auto_multiref(graph) for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): if node.name == 'multiref' or isinstance(node, IRGraphAnchor): @@ -130,10 +143,7 @@ def pas_data(graph: IRGraph, env_resource: 'ComputeConfig'): tensor partition on batch dimension inside a scale unit, and dp across scale units """ ngpus = env_resource.plan_ngpus - # auto multi-ref - for ftensor in graph.full_tensors(): - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor, [[n] for n in graph.consumers(ftensor)]) + auto_multiref(graph) batch_dim = 0 for dl in graph.select(ntype=IRDataOperation): @@ -172,6 +182,7 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): raise ValueError(f'invalid tp_size {tp_size} for ngpus {ngpus}') pp_size = ngpus // tp_size + auto_multiref(graph) fnodes = graph.select(ntype=IRFwOperation) stages = mitr.divide(pp_size, fnodes) stages = [list(s) for s in stages] diff --git a/tests/autodist/pas/multiref_plan1.json b/tests/autodist/pas/multiref_plan1.json new file mode 100644 index 00000000..266d3622 --- /dev/null +++ b/tests/autodist/pas/multiref_plan1.json @@ -0,0 +1,100 @@ +{ + "desc": { + "spmd_descs": [ + { + "partition_descs": [ + [ + 1, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 2, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 3, + [ + [ + [ + 0, + 0 + ], + 2 + ] + ] + ], + [ + 4, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 5, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } + } + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ] + }, + "e2e_time": 0.0, + "stage_mems": [ + 0.0, 0.0 + ], + "stage_all_times": [ + 0.0, 0.0 + ], + "stage_comp_times": [ + 0.0, 0.0 + ] +} diff --git a/tests/autodist/pas/multiref_plan2.json b/tests/autodist/pas/multiref_plan2.json new file mode 100644 index 00000000..4de944d5 --- /dev/null +++ b/tests/autodist/pas/multiref_plan2.json @@ -0,0 +1,100 @@ +{ + "desc": { + "spmd_descs": [ + { + "partition_descs": [ + [ + 1, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 2, + [ + [ + [ + 1, + 0 + ], + 2 + ] + ] + ], + [ + 3, + [ + [ + [ + 1, + 0 + ], + 2 + ] + ] + ], + [ + 4, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 5, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } + } + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ] + }, + "e2e_time": 0.0, + "stage_mems": [ + 0.0, 0.0 + ], + "stage_all_times": [ + 0.0, 0.0 + ], + "stage_comp_times": [ + 0.0, 0.0 + ] +} diff --git a/tests/autodist/pas/test_multiref_activation.py b/tests/autodist/pas/test_multiref_activation.py new file mode 100644 index 00000000..33e29491 --- /dev/null +++ b/tests/autodist/pas/test_multiref_activation.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +import tempfile +import shutil +import contextlib +import pytest +from pathlib import Path + + +import nnscaler +import nnscaler.graph.function.function as F +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph import IRGraph +from nnscaler.ir.adapter import IRAdapter +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.ir.operator import IRFwOperation, IRDataOperation +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.schedule.predefined import PredefinedSched +from tests.utils import clear_dir_on_rank0, init_random +from tests.launch_torchrun import torchrun +from tests.parallel_module.test_gencode import _gencode_contains + + +class ModelA(torch.nn.Module): + + def __init__(self): + super(ModelA, self).__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x): + x = self.fc(x) + l = x.sum() + return l, l.data + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_loss_multiref(): + m = ModelA() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 10], dtype=torch.float32, device=torch.cuda.current_device()) + + with tempfile.TemporaryDirectory() as tempdir: + pas_cfg = { + 'parallel_profile': False + } + parallelize( + m, + {'x': trace_data}, + 'autodist', + ComputeConfig(1, 1, use_end2end=True, pas_config=pas_cfg), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) + + assert len(_gencode_contains(tempdir, ModelA, 0, '\.multiref')) == 1 + + +class ModelB(torch.nn.Module): + + def __init__(self): + super(ModelB, self).__init__() + self.fc = torch.nn.Linear(10, 10, bias=False) + self.fc1 = torch.nn.Linear(10, 10, bias=False) + self.fc2 = torch.nn.Linear(10, 10, bias=False) + + def forward(self, x): + x = self.fc(x) + x1 = self.fc1(x) + x2 = self.fc2(x) + y = x1 + x2 + l = y.sum() + return l + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_same_partition(): + m = ModelB() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 10], dtype=torch.float32, device=torch.cuda.current_device()) + + with tempfile.TemporaryDirectory() as tempdir: + pas_cfg = { + 'parallel_profile': False + } + parallelize( + m, + {'x': trace_data}, + 'autodist', + ComputeConfig(1, 1, use_end2end=True, pas_config=pas_cfg), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) + + # this multiref is generated by `local_consumer_multiref` in `IRAdapterGener` + assert len(_gencode_contains(tempdir, ModelB, 0, '\.multiref')) == 1 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_diff_partition_1(): + m = ModelB() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 10], dtype=torch.float32, device=torch.cuda.current_device()) + + with tempfile.TemporaryDirectory() as tempdir: + pas_cfg = { + 'load_plan_path': Path(__file__).parent / 'multiref_plan1.json', + } + parallelize( + m, + {'x': trace_data}, + 'autodist', + ComputeConfig(2, 2, use_end2end=True, pas_config=pas_cfg), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) + + # this multiref is generated by `local_consumer_multiref` in `IRAdapterGener` + assert len(_gencode_contains(tempdir, ModelB, 0, '\.multiref')) == 1 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_diff_partition_2(): + m = ModelB() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 10], dtype=torch.float32, device=torch.cuda.current_device()) + + with tempfile.TemporaryDirectory() as tempdir: + pas_cfg = { + 'load_plan_path': Path(__file__).parent / 'multiref_plan2.json', + } + parallelize( + m, + {'x': trace_data}, + 'autodist', + ComputeConfig(2, 2, use_end2end=True, pas_config=pas_cfg), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) + + # this multiref is generated by `local_consumer_multiref` in `IRAdapterGener` + assert len(_gencode_contains(tempdir, ModelB, 0, '\.multiref')) == 1 + # generate code like, should be only one identity_allreduce + # linear_34 = torch.nn.functional.linear(x_42, self.fc_weight_33, bias=None) + # del x_42 + # linear_34 = nnscaler.runtime.adapter.nn.identity_allreduce(linear_34, ranks=[0, 1]) + # linear_105, linear_109 = nnscaler.runtime.function.multiref(linear_34, times=2) + assert len(_gencode_contains(tempdir, ModelB, 0, 'nnscaler.runtime.adapter.nn.identity_allreduce')) == 1 diff --git a/tests/autodist/pas/test_multiref_param.py b/tests/autodist/pas/test_multiref_param.py new file mode 100644 index 00000000..ef8c958e --- /dev/null +++ b/tests/autodist/pas/test_multiref_param.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest + +import tempfile +import torch +import os +from pathlib import Path +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.autodist.apis import parallelize_graph + +import nnscaler +from nnscaler.ir.unique import IDGenerator +from nnscaler.graph.segment import IRSegment +from nnscaler.flags import CompileFlag +from nnscaler.runtime.utils import microbatches +from nnscaler.program import Program, SemanticDataLoader, SemanticModel +from nnscaler.graph.gener.gen import IRAdapterGener + + +class Model(torch.nn.Module): + + def __init__(self, hidden_dim): + super().__init__() + self.w = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim)) + + def forward(self, x): + x = torch.matmul(x, self.w) + x = torch.nn.functional.relu(x) + x = torch.matmul(x, self.w) + return x.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +@pytest.mark.parametrize('cfg_fname', ['all_replicated_pp.json', 'replicated_and_partition.json']) +def test_shared_param_pipeline(cfg_fname): + bsz, hidden_dim = 4, 1024 + + CompileFlag.dev_mode = True + + with tempfile.TemporaryDirectory() as tempdir: + model = Model(hidden_dim) + model.train() + + program = Program() + program.clear() + IDGenerator().clear() + + dataloader = SemanticDataLoader( + microbatches([{ + 'x': torch.randn(bsz, hidden_dim) + }])) + + smodel = SemanticModel(model, attr_savedir=tempdir) + smodel.dummy_input = {'x': torch.randn(bsz, hidden_dim)} + smodel.constant_folding = True + program.set_input([dataloader.irobj]) + ir_dummy_input = next(dataloader) + outputs = smodel(ir_dummy_input) + outputs.backward() + program.set_output([outputs]) + program.finalize() + ir_graph = program.get_graph() + + print(ir_graph.nodes()) + plan_path = Path(os.path.dirname(__file__)) / cfg_fname + cfg = AutoDistConfig(load_plan_path=plan_path, mesh_col=4) + graph = parallelize_graph(ir_graph, cfg) + assert isinstance(graph.nodes()[4], IRSegment) + # check multiref is correctly inserted at the 1st IRSegment (pipeline stage) + has_multiref = False + for node in graph.nodes()[4].nodes(): + if node.signature == 'nnscaler.runtime.function.multiref': + has_multiref = True + break + assert has_multiref + + graph = IRAdapterGener.gen(graph, cost_fn=None) + if graph.sched is not None: + graph.sched.apply() diff --git a/tests/autodist/pas/test_shared_param_pipeline.py b/tests/autodist/pas/test_shared_param_pipeline.py deleted file mode 100644 index 74d91705..00000000 --- a/tests/autodist/pas/test_shared_param_pipeline.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import pytest - -import tempfile -import torch -import os -from pathlib import Path -from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph -from nnscaler.autodist.autodist_config import AutoDistConfig -from nnscaler.autodist.apis import parallelize_graph - -import nnscaler -from nnscaler.ir.unique import IDGenerator -from nnscaler.graph.segment import IRSegment -from nnscaler.flags import CompileFlag -from nnscaler.runtime.utils import microbatches -from nnscaler.program import Program, SemanticDataLoader, SemanticModel -from nnscaler.graph.gener.gen import IRAdapterGener - - -class Model(torch.nn.Module): - - def __init__(self, hidden_dim): - super().__init__() - self.w = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim)) - - def forward(self, x): - x = torch.matmul(x, self.w) - x = torch.nn.functional.relu(x) - x = torch.matmul(x, self.w) - return x.sum() - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') -def test_shared_param_pipeline(): - bsz, hidden_dim = 1024, 1024 - - CompileFlag.dev_mode = True - - for idx, cfg_fname in enumerate( - ['all_replicated_pp.json', 'replicated_and_partition.json']): - with tempfile.TemporaryDirectory() as tempdir: - model = Model(hidden_dim) - model.train() - - program = Program() - program.clear() - IDGenerator().clear() - - dataloader = SemanticDataLoader( - microbatches([{ - 'x': torch.randn(bsz, hidden_dim) - }])) - - smodel = SemanticModel(model, attr_savedir=tempdir) - smodel.dummy_input = {'x': torch.randn(bsz, hidden_dim)} - smodel.constant_folding = True - program.set_input([dataloader.irobj]) - ir_dummy_input = next(dataloader) - outputs = smodel(ir_dummy_input) - outputs.backward() - program.set_output([outputs]) - program.finalize() - ir_graph = program.get_graph() - - print(ir_graph.nodes()) - plan_path = Path(os.path.dirname(__file__)) / cfg_fname - cfg = AutoDistConfig(load_plan_path=plan_path, mesh_col=4) - graph = parallelize_graph(ir_graph, cfg) - assert isinstance(graph.nodes()[4], IRSegment) - # check multiref is correctly inserted at the 1st IRSegment (pipeline stage) - has_multiref = False - for node in graph.nodes()[4].nodes(): - if node.signature == 'nnscaler.runtime.function.multiref': - has_multiref = True - break - assert has_multiref - - graph = IRAdapterGener.gen(graph, cost_fn=None) - if graph.sched is not None: - graph.sched.apply() diff --git a/tests/graph/test_loss.py b/tests/graph/test_loss.py index 44015a6d..b30e4343 100644 --- a/tests/graph/test_loss.py +++ b/tests/graph/test_loss.py @@ -14,7 +14,7 @@ from nnscaler.execplan import ExecutionPlan from nnscaler.execplan.planpass.fusion import DiffFusion from nnscaler.execplan.planpass.grouping import Grouping -from nnscaler.ir.adapter.prim import AllReduceIdentityPrim, AllToAllAllToAllPrim, AllGatherSplitPrim +from nnscaler.ir.adapter.prim import AllReduceIdentityPrim, AllToAllAllToAllPrim, AllGatherSplitPrim, AllReducePrim from nnscaler.codegen.emit import FuncEmission from ..utils import replace_all_device_with @@ -61,6 +61,7 @@ def pas_partition_loss_hard(graph): linear = graph.nodes()[1] loss = graph.nodes()[2] get_attr = graph.nodes()[3] + graph.multiref(loss.outputs()[0].parent) _replica(graph, dataloader, [0, 1]) _tp(graph, linear, [0, 1], 0, 0) _tp(graph, loss, [0, 1], 0, 0) @@ -130,9 +131,12 @@ def test_loss_partition_hard(): def checker(init_graph, partitioned_graph, adapter_graph, execplan): fw_graph = execplan.seq(0)[1] bw_graph = execplan.seq(0)[2] - adapter = fw_graph.nodes()[-2] - assert len(adapter.prims) == 1 - assert isinstance(adapter.prims[0], AllReduceIdentityPrim) + adapter1 = fw_graph.nodes()[-4] + adapter2 = fw_graph.nodes()[-2] + assert len(adapter1.prims) == 1 + assert isinstance(adapter1.prims[0], AllReduceIdentityPrim) + assert len(adapter2.prims) == 1 + assert isinstance(adapter2.prims[0], AllReducePrim) assert fw_graph.outputs() == init_graph.outputs() emit = FuncEmission() input_tensors, output_tensors, output_grads, input_grads = \ diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py index a6af9d3f..108b9e3e 100644 --- a/tests/graph/test_segment.py +++ b/tests/graph/test_segment.py @@ -17,7 +17,7 @@ from nnscaler.ir.adapter import IRAdapter from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.ir.operator import IRFwOperation, IRDataOperation -from tests.parallel_module.test_gencode import _gencode_contains +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode from ..utils import replace_all_device_with, clear_dir_on_rank0, init_random from ..launch_torchrun import torchrun @@ -94,7 +94,9 @@ def policy_transpose(graph: IRGraph, resource: ComputeConfig) -> IRGraph: ngpus = resource.plan_ngpus for _, node in enumerate(graph.select(ntype=IRFwOperation)): print(node.signature) - if node.signature in ["torch.transpose"]: + if node.signature == 'torch.sum': + graph.multiref(node.outputs()[0].parent) + if node.signature in ["torch.transpose", "torch.sum"]: sub_nodes = graph.partition( node, node.algorithm('dim'), idx=0, dim=0, num=ngpus) else: @@ -126,6 +128,7 @@ def worker_a(): gen_savedir=tempdir, reuse='override', ) + # print_gencode(tempdir, ModelA, pm.rank) pm.to('cuda') ret = pm.train_step((data,)) diff --git a/tests/parallel_module/test_shared_param_pipeline.py b/tests/parallel_module/test_shared_param_pipeline.py new file mode 100644 index 00000000..3d67c662 --- /dev/null +++ b/tests/parallel_module/test_shared_param_pipeline.py @@ -0,0 +1,473 @@ +import torch +import torch.nn as nn +import tempfile +import shutil +import contextlib +import pytest +from pathlib import Path + + +import nnscaler +import nnscaler.graph.function.function as F +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph import IRGraph +from nnscaler.ir.adapter import IRAdapter +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.ir.operator import IRFwOperation, IRDataOperation +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.schedule.predefined import PredefinedSched +from tests.utils import clear_dir_on_rank0, init_random +from tests.launch_torchrun import torchrun +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode + + +# This test file demonstrates when to use multiref for shared parameters in pipeline parallelism. +# The criteria is simple, if we can insert reducers across stages to sync gradients, multiref is +# not needed. Otherwise, multiref is inserted into the graph so that gradients sync is achieved by +# combination of multiref and communications. +# The fundamental reason is that nnScaler's reducer requires a shared parameter should be ALL +# partitioned or ALL replicated, check `gen_weight` in IRAdapterGener for more details. + + +class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.weight = nn.Parameter(torch.randn(16, 16)) + + def forward(self, x): + x = torch.matmul(x, self.weight) + x = torch.matmul(x, self.weight) + return x.sum() + + +class Model2(torch.nn.Module): + def __init__(self): + super(Model2, self).__init__() + self.weight = nn.Parameter(torch.randn(16, 16)) + + def forward(self, x): + x = torch.matmul(x, self.weight) + x = torch.matmul(x, self.weight) + l = x.sum() + return l, l.data + + +class Model3(torch.nn.Module): + def __init__(self): + super(Model3, self).__init__() + self.weight = nn.Parameter(torch.randn(16, 16)) + + def forward(self, x): + x = torch.matmul(x, self.weight) + x = torch.matmul(x, self.weight) + x = torch.matmul(x, self.weight) + return x.sum() + + +class Model4(torch.nn.Module): + def __init__(self): + super(Model4, self).__init__() + self.weight = nn.Parameter(torch.randn(16, 16)) + + def forward(self, x): + x = torch.matmul(x, self.weight) + x = torch.matmul(x, self.weight) + x = torch.matmul(x, self.weight) + l = x.sum() + return l, l.data + + +def policy_easy_no_multiref(graph, cfg): + data_loader, fc1, fc2, loss = graph.nodes()[:4] + graph.staging([fc1, fc2]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + ngpus = cfg.plan_ngpus + sub_nodes = graph.replicate(data_loader, ngpus) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + if ngpus == 2: + graph.assign(fc1, 0) + + identity = stages[1].nodes()[0] + graph.assign(identity, 1) + graph.assign(fc2, 1) + graph.assign(loss, 1) + elif ngpus == 4: + sub_nodes = graph.partition(fc1, fc1.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + identity = stages[1].nodes()[0] + sub_nodes = graph.replicate(identity, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.partition(fc2, fc2.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.partition(loss, loss.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + else: + raise NotImplementedError + + PredefinedSched.sched_1f1b(graph, 4, len(stages)) + + return graph + + +def policy_hard_no_multiref(graph, cfg): + data_loader, fc1, fc2, fc3, loss = graph.nodes()[:5] + graph.staging([fc1, fc2, fc3]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + ngpus = cfg.plan_ngpus + sub_nodes = graph.replicate(data_loader, ngpus) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + if ngpus == 4: + graph.assign(fc1, 0) + + identity = stages[1].nodes()[0] + graph.assign(identity, 1) + graph.assign(fc2, 1) + + identity = stages[2].nodes()[0] + sub_nodes = graph.replicate(identity, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.partition(fc3, fc3.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.partition(loss, loss.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + else: + raise NotImplementedError + + PredefinedSched.sched_1f1b(graph, 4, len(stages)) + + return graph + + +def policy_hard_multiref(graph, cfg): + data_loader, fc1, fc2, fc3, loss = graph.nodes()[:5] + + # need multiref here + param = fc1.inputs()[1].parent + graph.multiref(param) + + graph.staging([fc1, fc2, fc3]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + ngpus = cfg.plan_ngpus + sub_nodes = graph.replicate(data_loader, ngpus) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + if ngpus == 4: + multiref = stages[0].nodes()[0] + graph.assign(multiref, 0) + graph.assign(fc1, 0) + + identity1, identity2, identity3 = stages[1].nodes()[:3] + graph.assign(identity1, 1) + graph.assign(identity2, 1) + graph.assign(identity3, 1) + graph.assign(fc2, 1) + + identity1, identity2 = stages[2].nodes()[:2] + sub_nodes = graph.replicate(identity1, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + sub_nodes = graph.replicate(identity2, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.replicate(fc3, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.replicate(loss, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + else: + raise NotImplementedError + + PredefinedSched.sched_1f1b(graph, 4, len(stages)) + + return graph + + +def policy_hard_multiref2(graph, cfg): + data_loader, fc1, fc2, loss = graph.nodes()[:4] + + # need multiref here + param = fc1.inputs()[1].parent + graph.multiref(param) + + graph.staging([fc1, fc2]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + ngpus = cfg.plan_ngpus + sub_nodes = graph.replicate(data_loader, ngpus) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + if ngpus == 4: + sub_nodes = graph.partition(fc1, fc1.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + identity1, identity2 = stages[1].nodes()[:2] + sub_nodes = graph.replicate(identity1, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + sub_nodes = graph.replicate(identity2, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.replicate(fc2, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.replicate(loss, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + else: + raise NotImplementedError + + PredefinedSched.sched_1f1b(graph, 4, len(stages)) + + return graph + + +def worker_pipeline(model_cls, pas, plan_ngpus, checker): + nnscaler.init() + m = model_cls() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 16], dtype=torch.float32, device=torch.cuda.current_device()) + + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2') as tempdir: + pm = parallelize( + m, + {'x': trace_data}, + pas, + ComputeConfig(plan_ngpus, plan_ngpus, use_end2end=True), + reuse='override', + gen_savedir=tempdir, + ) + pm.to(torch.cuda.current_device()) + + # print_gencode(tempdir, model_cls, pm.rank) + checker(model_cls, pm, tempdir) + samples = [torch.randn([2, 16], dtype=torch.float32, device=torch.cuda.current_device()) for _ in range(4)] + ret = pm.train_step(samples) + + +def checker_no_multiref(model_cls, pm, tempdir): + assert len(pm.reducers) == 1 + assert len(pm.reducers[0].params) == 1 + assert pm.reducers[0].params[0].shape == torch.Size([16, 16]) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('model_cls', [Model, Model2]) +@pytest.mark.parametrize('plan_ngpus', [2, 4]) +def test_shared_param_pipeline_no_multiref_easy(model_cls, plan_ngpus): + torchrun(plan_ngpus, worker_pipeline, model_cls, policy_easy_no_multiref, plan_ngpus, checker_no_multiref) + # should not raise any exception + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('model_cls', [Model3, Model4]) +@pytest.mark.parametrize('plan_ngpus', [4]) +def test_shared_param_pipeline_no_multiref_hard(model_cls, plan_ngpus): + torchrun(plan_ngpus, worker_pipeline, model_cls, policy_hard_no_multiref, plan_ngpus, checker_no_multiref) + # should not raise any exception + assert True + + +def checker_multiref(model_cls, pm, tempdir): + # no reducer should be created in any rank + # gradient accumulation and sync is achieved by multiref and communications + assert not pm.reducers + all_params = list(pm.parameters()) + if pm.rank == 0: + assert len(all_params) == 1 + assert all_params[0].shape == torch.Size([16, 16]) + assert len(_gencode_contains(tempdir, model_cls, pm.rank, 'multiref\(')) == 1 + else: + assert not all_params + assert not _gencode_contains(tempdir, model_cls, pm.rank, 'multiref\(') + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('model_cls', [Model3, Model4]) +@pytest.mark.parametrize('plan_ngpus', [4]) +def test_shared_param_pipeline_multiref_hard(model_cls, plan_ngpus): + torchrun(plan_ngpus, worker_pipeline, model_cls, policy_hard_multiref, plan_ngpus, checker_multiref) + # should not raise any exception + assert True + + +def checker_multiref2(model_cls, pm, tempdir): + # no reducer should be created in any rank + # gradient accumulation and sync is achieved by multiref and communications + # print_gencode(tempdir, model_cls, pm.rank) + assert not pm.reducers + all_params = list(pm.parameters()) + if pm.rank in [0, 1]: + assert len(all_params) == 1 + assert all_params[0].shape == torch.Size([16, 16]) + assert len(_gencode_contains(tempdir, model_cls, pm.rank, 'multiref\(')) == 1 + else: + assert not all_params + assert not _gencode_contains(tempdir, model_cls, pm.rank, 'multiref\(') + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('model_cls', [Model, Model2]) +@pytest.mark.parametrize('plan_ngpus', [4]) +def test_shared_param_pipeline_multiref_hard2(model_cls, plan_ngpus): + torchrun(plan_ngpus, worker_pipeline, model_cls, policy_hard_multiref2, plan_ngpus, checker_multiref2) + # should not raise any exception + assert True + + +def policy_hard_multiref_error(graph, cfg): + data_loader, fc1, fc2, fc3, loss = graph.nodes()[:5] + + graph.staging([fc1, fc2, fc3]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + ngpus = cfg.plan_ngpus + sub_nodes = graph.replicate(data_loader, ngpus) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + if ngpus == 4: + multiref = stages[0].nodes()[0] + graph.assign(multiref, 0) + graph.assign(fc1, 0) + + identity = stages[1].nodes()[0] + graph.assign(identity, 1) + graph.assign(fc2, 1) + + identity = stages[2].nodes()[0] + sub_nodes = graph.replicate(identity, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.replicate(fc3, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.replicate(loss, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + else: + raise NotImplementedError + + PredefinedSched.sched_1f1b(graph, 4, len(stages)) + + return graph + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +@pytest.mark.parametrize('model_cls', [Model3, Model4]) +def test_shared_param_error(model_cls): + m = model_cls() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 16], dtype=torch.float32, device=torch.cuda.current_device()) + + with tempfile.TemporaryDirectory() as tempdir: + with pytest.raises(RuntimeError, match='The weight consumers can either be ALL replicated or ALL partitioned'): + parallelize( + m, + {'x': trace_data}, + policy_hard_multiref_error, + ComputeConfig(4, 4, use_end2end=True), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) + + +def policy_hard_multiref2_error(graph, cfg): + data_loader, fc1, fc2, loss = graph.nodes()[:4] + + graph.staging([fc1, fc2]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + ngpus = cfg.plan_ngpus + sub_nodes = graph.replicate(data_loader, ngpus) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + if ngpus == 4: + sub_nodes = graph.partition(fc1, fc1.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + identity = stages[1].nodes()[0] + sub_nodes = graph.replicate(identity, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.replicate(fc2, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + + sub_nodes = graph.replicate(loss, 2) + graph.assign(sub_nodes[0], 2) + graph.assign(sub_nodes[1], 3) + else: + raise NotImplementedError + + PredefinedSched.sched_1f1b(graph, 4, len(stages)) + + return graph + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +@pytest.mark.parametrize('model_cls', [Model, Model2]) +def test_shared_param_error2(model_cls): + m = model_cls() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 16], dtype=torch.float32, device=torch.cuda.current_device()) + + with tempfile.TemporaryDirectory() as tempdir: + with pytest.raises(RuntimeError, match='The weight consumers can either be ALL replicated or ALL partitioned'): + parallelize( + m, + {'x': trace_data}, + policy_hard_multiref2_error, + ComputeConfig(4, 4, use_end2end=True), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) From df23b260ed4adf2b7d26164dff44de3aebf5abfd Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Mon, 2 Dec 2024 03:35:45 +0000 Subject: [PATCH 1780/1892] Merged PR 2337: [BUG] Fix split function While split_size_or_sections is an IRObject, sum(sections) will cause error. --- nnscaler/graph/function/function.py | 11 ++++++----- tests/graph/function/test_functions.py | 9 +++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 94cc616e..1aa380b1 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1200,12 +1200,13 @@ def Split(tensor, split_size_or_sections, dim = 0, signature = None): """ torch.functional.split(tensor, split_size_or_sections, dim=0) -> List[Tensor] """ - if isinstance(split_size_or_sections, int): - sections = [split_size_or_sections for _ in range(tensor.shape[dim] // split_size_or_sections)] - if tensor.shape[dim] % split_size_or_sections != 0: - sections.append(tensor.shape[dim] % split_size_or_sections) + unwrap_split_size = _unwrap_value(split_size_or_sections) + if isinstance(unwrap_split_size, int): + sections = [unwrap_split_size for _ in range(tensor.shape[dim] // unwrap_split_size)] + if tensor.shape[dim] % unwrap_split_size != 0: + sections.append(tensor.shape[dim] % unwrap_split_size) else: - sections = split_size_or_sections + sections = unwrap_split_size assert sum(sections) == tensor.shape[dim] edim_in = ShapeAnno.create_shape_str(tensor.shape) edim_ous = [copy.copy(edim_in) for _ in sections] diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 951cfb35..9bd72a14 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -951,3 +951,12 @@ def test_Pad(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '3 4 2 -> 9 7 3' op = F.Pad(IRTensor([3, 3, 4, 2]), pad=(o(1), o(1))) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c 2 -> a b c 4' + + +def test_Split(): + op = F.Split(IRTensor([3, 3, 4, 2]), split_size_or_sections=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '3 b c d -> 1 b c d, 1 b c d, 1 b c d' + op = F.Split(IRTensor([5, 3, 4, 2]), split_size_or_sections=IRObject(value=2, is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '5 b c d -> 2 b c d, 2 b c d, 1 b c d' + op = F.Split(IRTensor([7, 3, 4, 2]), split_size_or_sections=IRObject(value=[2, 2, 3], is_constant=False)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '7 b c d -> 2 b c d, 2 b c d, 3 b c d' From 9b0f871a9b93341128bc331ea45c8d1067dd049f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 3 Dec 2024 03:54:37 +0000 Subject: [PATCH 1781/1892] Merged PR 2326: [BugFix] reimplement reshape function [BugFix] reimplement reshape function unit test pass parity check pass --- nnscaler/graph/function/function.py | 326 ++++++++++++++----------- tests/graph/function/test_functions.py | 114 +++++++++ utility/verify_ops/verify_dimops.py | 14 +- 3 files changed, 315 insertions(+), 139 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 1aa380b1..23d46ed3 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1229,143 +1229,188 @@ def Contiguous(input, memory_format = None, signature = None): def _reshape_anno(in_shape: List[int], ou_shape: List[int], kwarg_name: str) -> Tuple[str, List[TransformRule]]: """ - reshape / view annotation and transformation rule generator - - Args: - in_shape List[int]: input shape - ou_shape List[int]: output shape - kwarg_name str: kwarg name of reshape / view op - - Returns: - str: annotation string - List[TransformRule]: transformation rules - """ - def nele(shape, nele=1): - for dimlen in shape: nele *= dimlen - return nele - - # infer -1 - cnt = nele(in_shape) + The general rule is that aligned dimensions in the input shape and output shape + are labeled with partitionable. + Here `align` means when we group input and output shapes (see `_group` function), + the input dims and output dims in a matched groups have exactly one non-1 dim. + + For example, + input shape: [10, 12, 16, 18] + output shape: [10, 2, 3, 32, 18] + The first and the last dimension is aligned, so we can partition both dimensions. + + There are two additional rules when we apply inner dimensions: + 1. When `input_dim % output_dim == 0 or output_dim % input_dim == 0`, + we can break the larger dimension to inner dimensions. + In above example, we can partition the second dimension, as `12 % 2 == 0`. + We first break the second dimension into `(2 6)`, then align the `2` with `2` in output shape. + So the final annotation is `a (b 6) 16 c -> a b 3 32 c` + 2. When a dimension size is `1`, we can skip it and align to the next dimension. + For example, input_shape: [2, 16, 32, 32] + output shape: [1, 32, 32, 32] + We can align the first dimension of input to the second dimension of output. + The final annotation is `a 16 b c -> 1 (a 16) b c`. Please note the first dimension of output is skipped. + + TODO: + We assume all dimensions are static. + We need to handle dynamic shapes (passed as IRObject) in the future. + """ + + if not in_shape or not ou_shape: + # scalar tensor, no need to partition + return '1 -> 1', [] + + from functools import reduce + from operator import mul + + nele = reduce(mul, in_shape) if -1 in ou_shape: idx = ou_shape.index(-1) - ou_shape[idx] = cnt // (-nele(ou_shape)) - assert nele(in_shape) == nele(ou_shape), f"shape mismatch: {in_shape}, {ou_shape}" - - # generate annotation - rest_inshape = [dimlen for dimlen in in_shape] - rest_oushape = [dimlen for dimlen in ou_shape] - chain = [] - can_bucket = True - while len(rest_inshape) != 0 or len(rest_oushape) != 0: - if len(rest_inshape) == 0: - chain = chain + rest_oushape - rest_oushape = [] - elif len(rest_oushape) == 0: - chain = chain + rest_inshape - rest_inshape = [] - else: - dimlen = min(rest_inshape[0], rest_oushape[0]) - if max(rest_inshape[0], rest_oushape[0]) % dimlen == 0: - chain.append(dimlen) - if dimlen == rest_inshape[0]: - rest_inshape.pop(0) - else: - rest_inshape[0] = rest_inshape[0] // dimlen - if dimlen == rest_oushape[0]: - rest_oushape.pop(0) + ou_shape[idx] = nele // (-reduce(mul, ou_shape)) + if nele != reduce(mul, ou_shape): + raise ValueError(f"shape mismatch: {in_shape}, {ou_shape}") + + def _group(input_shape, output_shape) -> List[Tuple[List[int], List[int]]]: + """ + Group input and output shape into groups that can be aligned together + For example + input shape: [10, 12, 16, 18] + output shape: [10, 2, 3, 32, 18] + We can group them into [ + ([10], [10]), + ([12, 16], [2, 3, 32]), + ([18], [18]) + ] + Please note when we group the dimensions, + the dimensions with size 1 will be grouped with next dimension if not matched. + The only exception is when the 1 dimension is the last dimensions. + + For example + [10, 1, 1, 5, 1, 7, 1, 1] and [10, 5, 7] + We will group them into + ([10], [10]) + ([1, 1, 5], [5]) + ([1, 7, 1, 1], [7]) + + And + [10, 1, 1, 5, 1, 7, 1, 1] and [10, 5, 7, 1] + We will group them into + ([10], [10]) + ([1, 1, 5], [5]) + ([1, 7], [7]) + ([1, 1], [1]) + """ + + groups = [] + input_idx = 0 + output_idx = 0 + + while input_idx < len(input_shape) and output_idx < len(output_shape): + # find one group in each iteration until all dimensions are aligned + input_dim = input_shape[input_idx] + output_dim = output_shape[output_idx] + group_input = [input_dim] + group_output = [output_dim] + # we find a group match when input_dim == output_dim + while input_dim != output_dim: + if input_dim < output_dim: + # add more dimensions from input shape + input_idx += 1 + input_dim *= input_shape[input_idx] + group_input.append(input_shape[input_idx]) else: - rest_oushape[0] = rest_oushape[0] // dimlen - else: - can_bucket = False + # add more dimensions from output shape + output_idx += 1 + output_dim *= output_shape[output_idx] + group_output.append(output_shape[output_idx]) + groups.append((group_input, group_output)) + input_idx += 1 + output_idx += 1 + + # at least one of input_shape and output_shape is exhausted + # put all remaining dimensions into the last group + for i in range(input_idx, len(input_shape)): + assert input_shape[i] == 1 + groups[-1][0].append(input_shape[i]) + for i in range(output_idx, len(output_shape)): + assert output_shape[i] == 1 + groups[-1][1].append(output_shape[i]) + + return groups + + shape_groups = _group(in_shape, ou_shape) + letters = iter(string.ascii_lowercase) + in_anno = [] + ou_anno = [] + # the first letter in each in/out annotation + # which will be used to partition the tensor + ifirst = [] + ofirst = [] + + def _append_partitionable(extra_in_anno=None, extra_out_anno=None): + """ + Allocate a letter for the partitionable dimension + If we need to break the dimension into inner dimensions, + pass the rest of annotation with `extra_in_anno` and `extra_out_anno` + """ + letter = next(letters) + in_anno.append(letter if extra_in_anno is None else f'({letter} {extra_in_anno})') + ou_anno.append(letter if extra_out_anno is None else f'({letter} {extra_out_anno})') + ifirst.append(letter) + ofirst.append(letter) + + # handle each aligned group + for in_group, ou_group in shape_groups: + # step 1: remove all leading 1's + # find the first non-1 dimension in input + for in_group_non_one_idx in range(len(in_group)): + if in_group[in_group_non_one_idx] != 1: + break + in_anno.append(f'1') + ifirst.append(None) + else: + in_group_non_one_idx = len(in_group) + in_group = in_group[in_group_non_one_idx:] + + # find the first non-1 dimension in output + for ou_group_non_one_idx in range(len(ou_group)): + if ou_group[ou_group_non_one_idx] != 1: break + ou_anno.append(f'1') + ofirst.append(None) + else: + ou_group_non_one_idx = len(ou_group) + ou_group = ou_group[ou_group_non_one_idx:] - letters = iter(string.ascii_lowercase) - if can_bucket: - inchain = ouchain = chain - inedims = ouedims = edims = [next(letters) for _ in chain] - else: - inchain, ouchain = in_shape, ou_shape - inedims = [str(dimlen) for dimlen in in_shape] - ouedims = [str(dimlen) for dimlen in ou_shape] - chain = inchain + ouchain - edims = inedims + ouedims - shape_map: Dict[str, int] = {edim: eshape for (edim, eshape) in zip(edims, chain)} - - # generate input and output shape annotations - # greedy fuse suffix number - def buckets(shape: List[int], chain: List[int], edims: List[int]) -> List[List[str]]: - anno = [] - dimidx = 0 - for idx, dimlen in enumerate(shape): - elements, bracket = 1, [] - maxele = len(chain) - dimidx - (len(shape) - 1 - idx) - while True: - if len(bracket) == maxele: - assert elements == dimlen, f"internal match error1: {bracket}" - break - if dimidx >= len(chain) or elements * chain[dimidx] > dimlen: - assert elements == dimlen, f"internal match error2: {bracket}" - break - else: - elements *= chain[dimidx] - bracket.append(edims[dimidx]) - dimidx += 1 - # fetch as many 1^ as possible from tail of the previous bracket - if len(bracket) == 0: - assert dimlen == 1, f"internal match error3: dimlen={dimlen}" - back = 0 - for edim in anno[-1][1:][::-1]: - if chain[edims.index(edim)] != 1: - break - back += 1 - assert back > 0, f"internal match error4: dimlen={dimlen}" - bracket = anno[-1][-back:] - anno[-1] = anno[-1][:-back] - assert len(bracket) > 0, f"got a dimension with no edim" - anno.append(bracket) - return anno - - in_anno = buckets(in_shape, inchain, inedims) - ou_anno = buckets(ou_shape, ouchain, ouedims) - - # postprocess on dimlen == 1 - shape_map['1'] = 1 - for bracket in in_anno + ou_anno: - for subdim, edim in enumerate(bracket): - if shape_map[edim] == 1: - bracket[subdim] = str(shape_map[edim]) - - # find out the axis that can be partitioned - ispatial, ifirst = set(), [] - for bracket in in_anno: - sdim = None - for hdim in range(len(bracket)): - if bracket[hdim] == '1' or shape_map[bracket[hdim]] == 1: continue - sdim = bracket[hdim] - break - if sdim is not None: - ispatial.add(sdim) - ifirst.append(sdim) - - ospatial, ofirst = set(), [] - for bracket in ou_anno: - sdim = None - for hdim in range(len(bracket)): - if bracket[hdim] == '1' or shape_map[bracket[hdim]] == 1: continue - sdim = bracket[hdim] - break - if sdim is not None: - ospatial.add(sdim) - ofirst.append(sdim) - - # intersection for spatial partitioned dimensions - spatial = ispatial.intersection(ospatial) - - # set dimension cannot be partitioned - for bracket in in_anno + ou_anno: - for hdim in range(len(bracket)): - if bracket[hdim] not in spatial: - bracket[hdim] = str(shape_map[bracket[hdim]]) + if not in_group or not ou_group: + # all dimensions are 1, we are done + assert len(in_group) == 0 and len(ou_group) == 0 + continue + + if len(in_group) == 1 and len(ou_group) == 1: # aligned + _append_partitionable() + else: + # use inner dimention to partition when possible + rest_start = 0 + + if in_group[0] == ou_group[0]: # special case: no need to use inner dimension + _append_partitionable() + rest_start = 1 + elif in_group[0] % ou_group[0] == 0: + _append_partitionable(extra_in_anno=in_group[0] // ou_group[0]) + rest_start = 1 + elif ou_group[0] % in_group[0] == 0: + _append_partitionable(extra_out_anno=ou_group[0] // in_group[0]) + rest_start = 1 + + for _ in in_group[rest_start:]: + in_anno.append(f'{in_shape[len(in_anno)]}') + ifirst.append(None) + for _ in ou_group[rest_start:]: + ou_anno.append(f'{ou_shape[len(ou_anno)]}') + ofirst.append(None) + + anno = OpAnno.create_op_str([in_anno], [ou_anno]) def modifier(kwargs: Dict, idx, dim, num: int, subnode_idx: int) -> Dict: kwargs = dict(**kwargs) @@ -1379,20 +1424,29 @@ def modifier(kwargs: Dict, idx, dim, num: int, subnode_idx: int) -> Dict: if isinstance(size[oidx], IRObject): _logger.warning(f'partition dim size in IRObject: {size[oidx]}') size[oidx] = size[oidx].value - size[oidx] = size[oidx] // num + if size[oidx] != -1: + size[oidx] = size[oidx] // num kwargs[kwarg_name] = tuple(size) return kwargs + non_none_ifirst = [i for i in ifirst if i is not None] + non_none_ofirst = [i for i in ofirst if i is not None] + # no duplicated identifier + assert len(set(non_none_ifirst)) == len(non_none_ifirst) + assert len(set(non_none_ofirst)) == len(non_none_ofirst) + # all identifier shown in input shape are also shown in output shape + assert set(non_none_ifirst) == set(non_none_ofirst) + # special rules: to change output size argument rules: TransformRule = [] - for identifier in spatial: - iidx = ifirst.index(identifier) + for iidx, identifier in enumerate(ifirst): + if identifier is None: + continue oidx = ofirst.index(identifier) rules.append( TransformRule([DimopSplit.D(iidx)], [DimopSplit.D(oidx)], modifier) ) - anno = OpAnno.create_op_str([in_anno], [ou_anno]) return anno, rules diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 9bd72a14..8d48153a 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -3,6 +3,9 @@ ### Only test the anno creation in these tests +from functools import reduce +from operator import add +from nnscaler.graph.function.dimops import IRDimops, OpAnno import nnscaler.graph.function.function as F from nnscaler.ir.cten import IRObject, IRTensor @@ -960,3 +963,114 @@ def test_Split(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '5 b c d -> 2 b c d, 2 b c d, 1 b c d' op = F.Split(IRTensor([7, 3, 4, 2]), split_size_or_sections=IRObject(value=[2, 2, 3], is_constant=False)) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '7 b c d -> 2 b c d, 2 b c d, 3 b c d' + + +def factors(n): + return set(reduce( + list.__add__, + ([i, n//i] for i in range(1, int(n**0.5) + 1) if n % i == 0))) + + +def verify_partition(op: IRDimops): + anno = op.anno + inputs = torch.randn(op.inputs()[0].shape) + outputs = inputs.reshape(**IRObject.try_unwrap(op.kwargs)).clone() \ + if 'reshape' in op.signature \ + else inputs.view(**IRObject.try_unwrap(op.kwargs)).clone() + + def _get_anno_ids_map(shape_annos: tuple): + anno_ids = {} + for idx, shape in enumerate(shape_annos): + for eidx, edim in enumerate(shape.dims): + for identifier in edim.identifiers: + if identifier[0].isalpha(): + anno_ids[(idx, eidx)] = identifier + break + return anno_ids + + # (input_idx, input_dim) -> identifier + input_anno_ids = _get_anno_ids_map(anno.inputs()) + output_anno_ids = _get_anno_ids_map(anno.outputs()) + # assume each identifier is unique, which is true for reshape/view + # identifier -> (input_idx, input_dim) + reverse_input_anno_ids = {v: k for k, v in input_anno_ids.items()} + reverse_output_anno_ids = {v: k for k, v in output_anno_ids.items()} + + transforms = anno.transform_space() + transform_rules = {} + for transform_rule in op.transform_rules: + transform_rules[ + (transform_rule.inputs()[0].dims[0], transform_rule.outputs()[0].dims[0]) + ] = transform_rule.modifier() + + for transform in transforms: + input_idx, input_dim = transform + identifier = input_anno_ids[transform] + output_idx, output_dim = reverse_output_anno_ids[identifier] + + # only one input/one output for reshape/view + assert input_idx == 0 + assert output_idx == 0 + + dim_size = anno.getlen(identifier) + for factor in factors(dim_size): + # simulate the partition process + # 1. chunk input tensor + input_chunks = torch.chunk(inputs, factor, dim=input_dim) + # 2. update kwargs + kwargs = transform_rules[(input_dim, output_dim)](op.kwargs, 0, input_dim, factor, 0) + kwargs = IRObject.try_unwrap(kwargs) + # 3. reshape/view + reshaped_input_chunks = [chunk.reshape(**kwargs) for chunk in input_chunks] \ + if 'reshape' in op.signature \ + else [chunk.view(**kwargs) for chunk in input_chunks] + # 4. compare with actual output + output_chunks = torch.chunk(outputs, factor, dim=output_dim) + for i in range(factor): + assert torch.equal(reshaped_input_chunks[i], output_chunks[i]) + + +def test_reshape_view(): + RVF = {F.Reshape: 'shape', F.View: 'size'} + for f, kwname in RVF.items(): + query = IRTensor([2048, 1, 2, 512]) + op = f(query, IRObject(value=2048), IRObject(value=1), -1, 32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 1 b 512 -> a 1 (b 16) 32' + assert IRObject.try_unwrap(op.kwargs[kwname]) == (2048, 1, -1, 32) + assert [type(x) for x in op.kwargs[kwname]] == [IRObject, IRObject, int, int] + verify_partition(op) + + query = IRTensor([10, 12, 16, 18]) + op = f(query, 10, 2, 3, 32, 18) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a (b 6) 16 c -> a b 3 32 c' + verify_partition(op) + + query = IRTensor([2, 16, 32, 32]) + op = f(query, 1, 32, 32, 32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 16 b c -> 1 (a 16) b c' + verify_partition(op) + + query = IRTensor([10, 12, 16, 18]) + op = f(query, 10, 24, 8, 18) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b 16 c -> a (b 2) 8 c' + verify_partition(op) + + query = IRTensor([10, 12, 16, 18]) + op = f(query, 10, 16, 12, 18) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 12 16 b -> a 16 12 b' + verify_partition(op) + + query = IRTensor([2, 1, 1, 16, 32, 1, 1, 32]) + op = f(query, 1, 32, 32, 1, 32) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 1 1 16 b 1 1 c -> 1 (a 16) b 1 c' + verify_partition(op) + + query = IRTensor([10, 1, 1, 5, 1, 7, 1, 1]) + op = f(query, 10, 5, 7, 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 1 1 b 1 c 1 1 -> a b c 1' + verify_partition(op) + + query = IRTensor([10, 1, 1, 5, 1, 7, 1, 1]) + op = f(query, 10, 5, 7) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 1 1 b 1 c 1 1 -> a b c' + verify_partition(op) diff --git a/utility/verify_ops/verify_dimops.py b/utility/verify_ops/verify_dimops.py index e34c1fa4..bb88bc72 100644 --- a/utility/verify_ops/verify_dimops.py +++ b/utility/verify_ops/verify_dimops.py @@ -305,11 +305,19 @@ def verify_partition_options(verify_config: VerifyConfig) -> bool: for k, v in verify_config.kwargs.items() ] ) - args_str = ", ".join([f"_in{i}" for i, tinfo in enumerate(verify_config.args)]) + func_sig_call = verify_config.fsig + args_str = ", ".join([f"_in{i}" for i in range(len(verify_config.args))]) + tensor_member_methods_prefix = 'torch.Tensor.' + if func_sig_call.startswith(tensor_member_methods_prefix): + # workaround because tracer does not support tensor member methods + func_sig_call = f'_in0.' + func_sig_call[len(tensor_member_methods_prefix):] + func_args_str = ", ".join([f"_in{i}" for i in range(1, len(verify_config.args))]) + else: + func_args_str = args_str - if args_str: - func_call = f"{outputs_str} = {func_sig_call}({args_str}, {kwargs_str})" + if func_args_str: + func_call = f"{outputs_str} = {func_sig_call}({func_args_str}, {kwargs_str})" else: func_call = f"{outputs_str} = {func_sig_call}({kwargs_str})" From 0a80aff0817c3ebdce82f43715743f21ef2a9455 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Tue, 3 Dec 2024 11:53:02 +0000 Subject: [PATCH 1782/1892] Merged PR 2338: Drop python 3.8 support --- README.md | 4 ++-- nnscaler/resources/__init__.py | 10 +++------- pyproject.toml | 2 +- requirements.txt | 1 - 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 10ea3495..83bcb86e 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ For **_DNN system experts_**, they can leverage nnScaler to explore new DNN para Install the following packages before the installation of nnScaler: - Python >= 3.8, < 3.11 (3.10 is recommanded) + Python >= 3.9, < 3.11 (3.10 is recommanded) PyTorch >= 2.0, < 2.4 (2.2.0 is recommanded) @@ -221,4 +221,4 @@ This project may contain trademarks or logos for projects, products, or services ## Contact You may find our public repo from or microsoft internal repo . -For any questions or inquiries, please contact us at [nnscaler@service.microsoft.com](mailto:nnscaler@service.microsoft.com). \ No newline at end of file +For any questions or inquiries, please contact us at [nnscaler@service.microsoft.com](mailto:nnscaler@service.microsoft.com). diff --git a/nnscaler/resources/__init__.py b/nnscaler/resources/__init__.py index b757e372..6bd0b71d 100644 --- a/nnscaler/resources/__init__.py +++ b/nnscaler/resources/__init__.py @@ -5,15 +5,11 @@ Pseudo module of resource files. """ -from __future__ import annotations - __all__ = 'files' -# TODO: when drop python 3.8 support, change it to `importlib.resources` -import importlib_resources -from importlib_resources.abc import Traversable +import importlib.resources -def files() -> Traversable: +def files(): """ Alias of ``importlib.resources.files('nnscaler.resources')``. @@ -25,4 +21,4 @@ def files() -> Traversable: import nnscaler.resources (nnscaler.resources.files() / 'path/to/my_file.txt').read_text() """ - return importlib_resources.files(__name__) + return importlib.resources.files(__name__) diff --git a/pyproject.toml b/pyproject.toml index 37f38af6..2e48bd29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version", "dependencies"] name = "nnscaler" description = "Parallelize DNN Training via A Systematic Way" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" authors = [ {name = "nnScaler Team", email = "nnscaler@service.microsoft.com"} ] diff --git a/requirements.txt b/requirements.txt index 9fd451a7..2fec92d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ cppimport dill -importlib-resources matplotlib more-itertools numpy>=1.23.0 From f02bdcc903efb474a65a20bd6fce8af7992fd514 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 6 Dec 2024 05:53:05 +0000 Subject: [PATCH 1783/1892] Merged PR 2340: [Doc] add test and doc for local_consumer_multiref --- nnscaler/graph/gener/gen.py | 69 +++++++++----- nnscaler/graph/gener/utils.py | 44 +++++++-- nnscaler/graph/segment.py | 5 + .../gener/test_local_consumer_multiref.py | 93 +++++++++++++++++++ .../test_shared_param_pipeline.py | 3 + 5 files changed, 187 insertions(+), 27 deletions(-) create mode 100644 tests/graph/gener/test_local_consumer_multiref.py diff --git a/nnscaler/graph/gener/gen.py b/nnscaler/graph/gener/gen.py index fbb0e398..fbcd65af 100644 --- a/nnscaler/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -107,11 +107,13 @@ def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: Generate tensor adapter for both activations and weights Note weight reducers are always append to the last. - @param graph IRGraph: the graph without adapter - @param cost_fn Optional[Callable]: takes an IRAdapterPrim and outputs a cost in float. - default to be None, which will use communication volume. + Args: + graph (IRGraph): the graph without adapter + cost_fn Optional[Callable]: takes an IRAdapterPrim and outputs a cost in float. + default to be None, which means communication volume is used as cost. - @return graph IRGraph: the graph with adapter inserted + Returns: + graph (IRGraph): the graph with adapter inserted """ # reorder producer and consumer ordering graph._reorder_producer_consumer() @@ -119,7 +121,7 @@ def gen(graph: IRGraph, cost_fn: Optional[Callable] = None) -> IRGraph: # remove anchor node graph = IRAdapterGener.remove_anchor(graph) _logger.info("finish removing anchor nodes") - # automatic replace pyfunc + # automatic replicate pyfunc graph = IRAdapterGener.auto_pyfunc(graph) _logger.info("finish replacing auto pyfunc") # automatic transform multiref @@ -153,7 +155,7 @@ def auto_pyfunc(graph: IRGraph): Warning: Each IRPyFunc will be replicated to all devices of its segment. - To restrict the replicaed devices in pipeline-like scenarios, use `graph.staging` + To restrict the replicated devices in pipeline-like scenarios, use `graph.staging` to group the operators into segments. Args: @@ -332,16 +334,36 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: input_producer, output_consumer = create_dummy(graph, inputs=True, outputs=True) bgraph: Optional[IRSegment] = graph.mirror - # local producer fusion and local consumer multiref + # Here are two optimization passes that are applied before generating communication adapters: + # - local producer fusion: If an operator is partitioned and there are multiple + # different sub-tensors on the same device, insert appropriate concat or accumulate + # operators to merge the results on the current device before generating communication. + # This way, part of the communication between multiple devices can be converted into + # local data processing. + # + # - local consumer multiref: When a full tensor has multiple consumers and, after partitioning, + # there exists a device that contains multiple partitioned consumers (note that this pass assumes + # these consumers share a same sub-tensor in the forward graph), a multiref node will be + # inserted before these consumers on that device. This way, during the backward pass through + # the multiref node, the gradients from the consumers are automatically accumulated together, + # avoiding the need for accumulation operations in the backward adapter. Note that to make + # this optimization work properly, `flatten_grad` should be called to adjust the valuemap of + # the gradient sub-tensors. + # + # Apart from the purpose of improving the efficiency of communication adapters, these two passes + # also reduce the number of sub-tensors that need to be considered when generating adapters, which + # can help to reduce the complexity of the adapter generation algorithm. More specifically, if the + # plan of the input graph is SPMD, the local consumer multiref pass will ensure the number of + # fptensors and bptensors is the same, which raise the possibility of generating high performance + # collectives, like allgather, allreduce, etc. ftensors = [] _cnt = 0 for ftensor in graph.full_tensors(): - # backward will gen in forward + # backward adapter will be generated along with the forward adapter if ftensor.is_param() or ftensor.is_grad(): continue # flatten gradient utils.flatten_grad(graph, ftensor) - # optimization: local fusion / multiref on producer / consumer ftensor = IRAdapterGener.local_producer_fusion(graph, ftensor) IRAdapterGener.local_consumer_multiref(graph, ftensor) ftensors.append(ftensor) @@ -485,7 +507,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: @staticmethod def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTensor: - """! + """ Fuse the producer tensors using concat and add. This will add a new full tensor by chaging from: producer --(ftensor)--> consumer @@ -496,10 +518,13 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens recompute group, then the additional generated cat/add are also apllied with same recompute region. Otherwise no recompute. - @param tensors List[IRSubTensor]: tensors to be fused in local device + Args: + graph (IRSegment): the graph that contains the full tensor + ftensor (IRFullTensor): the full tensor to be manipulated - @return new_ftensor IRFullTensor: the new full tensor. - If cannot fuse, the original ftensor. + Returns: + new_ftensor IRFullTensor: the new full tensor. If cannot fuse, + return the original ftensor. """ if not ftensor.requires_grad: return ftensor @@ -648,17 +673,17 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): then create a multiref forward node for it to make each sub-tensor to be consumed only once in each device. - This is to adapt with pytorch autograd function. - producer -> consumers[0,1] producer -> multiref -> consumer[0] |-----> consumer[1] - @param graph IRGraph - @param ftensor IRFullTensor: the forward full tensor + Args: + graph (IRSegment): the graph that contains the full tensor + ftensor (IRFullTensor): the full tensor to be manipulated - @return None + Returns: + None: the graph is modified inplace. """ if not ftensor.requires_grad: return @@ -735,9 +760,11 @@ def autoref(graph: IRSegment) -> IRGraph: Automatically transform inserted multiref. Multiref is transformed to align with the output tensors on each device. - @param graph IRGraph + Args: + graph (IRGraph): the graph to be transformed - @return None + Returns: + graph (IRGraph): the graph with transformed multiref """ for multiref in graph.select(name='multiref', flatten=False): # setup recompute @@ -758,7 +785,7 @@ def autoref(graph: IRSegment) -> IRGraph: ptensors = sorted(ptensors, key=lambda t: t.device[0]) for tensor in ptensors: mr = MultiRef(tensor, len(multiref.outputs())) - mr.comment = f'create at IRAdapterGener:autoref, src tensor is {multiref.comment}' + mr.comment = f'create at IRAdapterGener:autoref, comment before transformation: {multiref.comment}' mr.input(0).grad = tensor.grad for idx, out in enumerate(multiref.outputs()): output = out.parent.select(tensor.indmap, tensor.valmap) diff --git a/nnscaler/graph/gener/utils.py b/nnscaler/graph/gener/utils.py index 6b8a9f3c..00b92899 100644 --- a/nnscaler/graph/gener/utils.py +++ b/nnscaler/graph/gener/utils.py @@ -73,14 +73,46 @@ def convert_add_to_valmap(graph: IRGraph, add_node: IRFwOperation): def flatten_grad(graph: IRSegment, ftensor: IRFullTensor): """ - Reset gradient for consumers that are different (no replica) - Gradient valuemap will be flatten inter-devices, e.g.,(0,3), (1,3), (2,3) - Gradient valuemap will be exponent intra-devices, e.g., (0,2), (2,4), (3,4) + Normalize gradient's valuemap for consumers that are different (no replica). + Gradient valuemap will be flatten inside a device, e.g.,(0,3), (1,3), (2,3). + Gradient valuemap will be exponent cross devices, e.g., (0,2), (2,4), (3,4). - @param graph IRGraph: the graph - @param ftensor IRFullTensor: the fulltensor + Example: + Assume a tensor is consumed by 2 linear operators, the source code is: + ```python + x1 = self.fc0(x0) + x2 = self.fc1(x1) + x3 = self.fc2(x1) + ``` + There are two devices, `fc0` is partitioned along x's dim 0 (the batch dim), + fc1 and fc2 are partitioned along the weight's dim 0. In the forward pass, + a `AllGather` is inserted since each device only has a part of x1, while + fc1 and fc2 need the full x1. However, in the backward pass, it is hard to + generate the communication primitive directly since the `valmap` of consumer + sub-tensors are not well ordered. After graph transformation, the `valmap` of + x1's grad sub tensors are + - fc1 on device 0: (1, 4) + - fc1 on device 1: (0, 4) + - fc2 on device 0: (3, 4) + - fc2 on device 1: (2, 4) + The reason for the valmap values is that: 1) After constructing the graph, + the corresponding valmap for each consumer is calculated based on the number + of consumers. 2) After calling graph.partition, the valmap for the partitioned + sub-consumers is updated. + To align with `local_consumer_multiref` and the adapter generation, this function + will update the grad tensor of x1 to + - fc1 on device 0: (0, 4) + - fc2 on device 0: (1, 4) + - fc1 on device 1: (2, 4) + - fc2 on device 1: (3, 4) + e.g., the grad's valuemap is split along the device then along local consumers. - @return None: this is an inplacement update. + Args: + graph (IRGraph): the graph + ftensor (IRFullTensor): the fulltensor + + Returns: + None: input graph is modified inplace """ if not isinstance(ftensor.grad, IRFullTensor): return diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index 3c8ea635..3f6f289e 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -727,6 +727,11 @@ def multiref(self, ftensor: IRFullTensor, comment: Optional[str] = None, *deprec ids and dispatched to the corresponding consumers. The input tensor can be parameter, buffer or activation tensors. + Note that during the adapter generation (IRAdapterGener), the multiref inserted + here will be partitioned automatically by `autoref`. Further more, multiref may + be added to the graph at that time to reduce the communication time, check + `gen_activation` and `local_consumer_multiref` for more details. + Args: tensor (IRSubTensor): full tensor to be multiref. diff --git a/tests/graph/gener/test_local_consumer_multiref.py b/tests/graph/gener/test_local_consumer_multiref.py new file mode 100644 index 00000000..b04261e1 --- /dev/null +++ b/tests/graph/gener/test_local_consumer_multiref.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn as nn +import tempfile +import pytest + +import nnscaler.graph.function.function as F +from nnscaler.parallel import ComputeConfig, parallelize +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode + + +class Model(torch.nn.Module): + + def __init__(self): + super(Model, self).__init__() + self.fc0 = torch.nn.Linear(2, 2, bias=False) + self.fc1 = torch.nn.Linear(2, 2, bias=False) + self.fc2 = torch.nn.Linear(2, 2, bias=False) + + def forward(self, x): + x = self.fc0(x) + x1 = self.fc1(x) + x2 = self.fc2(x) + x = x1 + x2 + loss = torch.sum(x) + return loss + + +def pas(graph, compute_config): + fc0, fc1, fc2, add, loss = graph.nodes()[:5] + sub_nodes = graph.partition(fc0, fc0.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + sub_nodes = graph.partition(fc1, fc1.algorithm('dim'), idx=1, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + sub_nodes = graph.partition(fc2, fc2.algorithm('dim'), idx=1, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + sub_nodes = graph.partition(add, add.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + sub_nodes = graph.partition(loss, loss.algorithm('dim'), idx=0, dim=0, num=2) + graph.assign(sub_nodes[0], 0) + graph.assign(sub_nodes[1], 1) + + return graph + + +def test_local_consumer_multiref(): + m = Model() + m.train() + torch.manual_seed(0) + trace_data = torch.randn([2, 2]) + + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + m, + {'x': trace_data}, + pas, + ComputeConfig(2, 2, use_end2end=False, trace_strategy='cpu',), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) + # print_gencode(tempdir, Model, 0) + # The generated code should be like: + # linear_56 = torch.nn.functional.linear(x_54, self.fc0_weight_33, bias=None) + # del x_54 + # linear_34 = nnscaler.runtime.adapter.nn.allgather_reducescatter(linear_56, dim=0, ranks=[0, 1]) + # del linear_56 + # linear_125, linear_129 = nnscaler.runtime.function.multiref(linear_34, times=2) + # del linear_34 + # linear_1_70 = torch.nn.functional.linear(linear_125, self.fc1_weight_68, bias=None) + # del linear_125 + # linear_2_84 = torch.nn.functional.linear(linear_129, self.fc2_weight_82, bias=None) + # del linear_129 + # linear_1_96 = nnscaler.runtime.adapter.nn.alltoall_alltoall(linear_1_70, idim=1, odim=0, ranks=[0, 1]) + # del linear_1_70 + # linear_2_98 = nnscaler.runtime.adapter.nn.alltoall_alltoall(linear_2_84, idim=1, odim=0, ranks=[0, 1]) + # del linear_2_84 + # add_100 = torch.add(linear_1_96, linear_2_98, alpha=1) + for i in range(2): + # output of fc0 is used two times in each device, so local multiref will be added in each device + assert len(_gencode_contains(tempdir, Model, i, 'nnscaler.runtime.function.multiref')) == 1 + assert len(_gencode_contains(tempdir, Model, i, 'nnscaler.runtime.adapter.nn.allgather_reducescatter')) == 1 + assert len(_gencode_contains(tempdir, Model, i, 'nnscaler.runtime.adapter.nn.alltoall_alltoall')) == 2 diff --git a/tests/parallel_module/test_shared_param_pipeline.py b/tests/parallel_module/test_shared_param_pipeline.py index 3d67c662..222c0962 100644 --- a/tests/parallel_module/test_shared_param_pipeline.py +++ b/tests/parallel_module/test_shared_param_pipeline.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import torch import torch.nn as nn import tempfile From 1ac5b1760aeeba3d947e672556cb4a1229243a6d Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 10 Dec 2024 06:02:20 +0000 Subject: [PATCH 1784/1892] Merged PR 2341: [Reorg + Parser] Polish parser logic Refine parser logic to make it clearer and maintainable. parity check pass unit test pass --- nnscaler/graph/function/dimops.py | 2 +- nnscaler/graph/function/function.py | 62 ++++--- nnscaler/graph/function/wrapnn.py | 21 ++- nnscaler/graph/parser/frame.py | 16 +- nnscaler/graph/parser/parser.py | 216 +++++++++++++++++------ nnscaler/ir/cten.py | 159 +++++++++++++++-- nnscaler/parallel.py | 4 +- nnscaler/program.py | 4 +- tests/graph/function/test_dict_values.py | 1 + tests/graph/function/test_functions.py | 10 +- tests/graph/parser/test_parser.py | 150 ++++++++++++++++ tests/ir/test_cten.py | 26 +-- tests/parallel_module/test_gencode.py | 150 +++++++++++++++- 13 files changed, 702 insertions(+), 119 deletions(-) diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index e544c50e..24241de2 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -737,7 +737,7 @@ def __init__(self, create_fn: Callable, name: str, # change tensor to IRObject for '?' annotation for idx, shape_anno in enumerate(self._oannos): if shape_anno.ignore: - self.set_output(idx, IRObject()) + self.set_output(idx, IRObject.missing) @property def anno(self) -> OpAnno: diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 23d46ed3..b8a99197 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1,6 +1,24 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +""" +Rules: +1. `dict`/`list`/`tuple`/`slice` are the only supported container types for IRObject/IRTensor. +2. IRObjects created in functions should be the outputs. Never create new IRObjects as function inputs/kwargs. +3. `iter` is not compatible with our system, and should be avoided. + That's because the values in the iterator have been exausted in tracer +4. Never nest IRObject, i.e. put another IRObject in IRObject.value. + Currently some functions still nest IRObject. Will be fixed in the future. +5. You can return concrete value in function only when there is no IRObjects in function args/kwargs. +6. When the tensor shape is unknown, you can annotate with `?`. In this case, the tensor will never be partitioned. +7. When a function returns multiple tensors, you should return them as tuple/list, + and annotate them correctly. +8. When a function is annotated with no-output (the right side of `->` of annotation is empty), + If you don't use its output in source code, this function will be removed in tracer + But if you use its output in source code, KeyError will be trigger in parser. + So it is a bad idea to annotate a function with no-output. +""" + from typing import Any, Callable, List, Optional, Tuple, Dict, Union, Iterable import string import copy @@ -11,7 +29,7 @@ import logging from collections.abc import Iterable -from nnscaler.ir.cten import IRTensor, IRObject +from nnscaler.ir.cten import IRTensor, IRObject, IR from nnscaler.ir.tensor import IRSubTensor, IRFullTensor from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule @@ -100,7 +118,7 @@ def FoldConstant(value: Any, signature = None): # always return a constant # no node will be created - return IRObject.try_unwrap(value) + return IR.try_unwrap(value) def MultiRef(tensor: IRTensor, times: int, signature = None): @@ -2302,7 +2320,7 @@ def _find_first_invalid_arg_name(arg_values, overload_idx): # overload matching is done by checking the type of the first positional argument # here we use unwrapped value, # because if it's a wrapped IRObject, it's impossible for us to select the correct overload - arg0 = IRObject.try_unwrap(args[0]) + arg0 = IR.try_unwrap(args[0]) for overload_idx in range(len(positional_arg_names)): # if arg[0] is None, use the first overload if arg0 is None or isinstance(arg0, arg_types[positional_arg_names[overload_idx][0]]): @@ -2538,7 +2556,7 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], return torch.strided if isinstance(obj, torch.finfo): return getattr(obj, name) - return IRPyFunc(signature, [instance, field], [IRObject()]) + return IRPyFunc(signature, [instance, field], [IRObject.missing]) def FInfo(dtype: torch.dtype, signature = None) -> torch.finfo: @@ -2866,10 +2884,10 @@ def Conv1D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, """ if len(input.shape) not in [2, 3]: raise ValueError(f"Expected input tensor to have 2 or 3 dimensions, but got {input.shape}") - stride_val = IRObject.try_unwrap(stride) - padding_val = IRObject.try_unwrap(padding) - dilation_val = IRObject.try_unwrap(dilation) - groups_val = IRObject.try_unwrap(groups) + stride_val = IR.try_unwrap(stride) + padding_val = IR.try_unwrap(padding) + dilation_val = IR.try_unwrap(dilation) + groups_val = IR.try_unwrap(groups) if isinstance(stride_val, int): stride_val = (stride_val,) if isinstance(dilation_val, int): @@ -2952,11 +2970,11 @@ def ConvTranspose1D(input, weight, bias=None, stride=1, padding=0, output_paddin """ if len(input.shape) not in [2, 3]: raise ValueError(f"Expected input tensor to have 2 or 3 dimensions, but got {input.shape}") - stride_val = IRObject.try_unwrap(stride) - padding_val = IRObject.try_unwrap(padding) - output_padding_val = IRObject.try_unwrap(output_padding) - dilation_val = IRObject.try_unwrap(dilation) - groups_val = IRObject.try_unwrap(groups) + stride_val = IR.try_unwrap(stride) + padding_val = IR.try_unwrap(padding) + output_padding_val = IR.try_unwrap(output_padding) + dilation_val = IR.try_unwrap(dilation) + groups_val = IR.try_unwrap(groups) if isinstance(stride_val, int): stride_val = (stride_val,) if isinstance(padding_val, int): @@ -3013,10 +3031,10 @@ def Conv2D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, """ if len(input.shape) not in [3, 4]: raise ValueError(f"Expected input tensor to have 3 or 4 dimensions, but got {input.shape}") - stride_val = IRObject.try_unwrap(stride) - padding_val = IRObject.try_unwrap(padding) - dilation_val = IRObject.try_unwrap(dilation) - groups_val = IRObject.try_unwrap(groups) + stride_val = IR.try_unwrap(stride) + padding_val = IR.try_unwrap(padding) + dilation_val = IR.try_unwrap(dilation) + groups_val = IR.try_unwrap(groups) if isinstance(stride_val, int): stride_val = (stride_val, stride_val) if isinstance(dilation_val, int): @@ -3104,11 +3122,11 @@ def ConvTranspose2D(input, weight, bias=None, stride=1, padding=0, output_paddin """ if len(input.shape) not in [3, 4]: raise ValueError(f"Expected input tensor to have 3 or 4 dimensions, but got {input.shape}") - stride_val = IRObject.try_unwrap(stride) - padding_val = IRObject.try_unwrap(padding) - output_padding_val = IRObject.try_unwrap(output_padding) - dilation_val = IRObject.try_unwrap(dilation) - groups_val = IRObject.try_unwrap(groups) + stride_val = IR.try_unwrap(stride) + padding_val = IR.try_unwrap(padding) + output_padding_val = IR.try_unwrap(output_padding) + dilation_val = IR.try_unwrap(dilation) + groups_val = IR.try_unwrap(groups) if isinstance(stride_val, int): stride_val = (stride_val, stride_val) if isinstance(padding_val, int): diff --git a/nnscaler/graph/function/wrapnn.py b/nnscaler/graph/function/wrapnn.py index a65a0615..a2d44ad2 100644 --- a/nnscaler/graph/function/wrapnn.py +++ b/nnscaler/graph/function/wrapnn.py @@ -26,10 +26,9 @@ from torch.nn.modules.instancenorm import _InstanceNorm from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm -from nnscaler.graph.function.function import _unwrap_value from nnscaler.graph.parser.register import register_op from nnscaler.ir.operator import IRFwOperation -from nnscaler.ir.cten import IRObject, IRTensor +from nnscaler.ir.cten import IRObject, IRTensor, IR from nnscaler.runtime.device import DeviceGroup @@ -144,11 +143,11 @@ def batchnorm2d_annotation_fn(*inputs, **kwargs): should also be absent. Reference: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html """ - weight = IRObject.try_unwrap(weight) - bias = IRObject.try_unwrap(bias) - running_mean = IRObject.try_unwrap(running_mean) - running_var = IRObject.try_unwrap(running_var) - num_batches_tracked = IRObject.try_unwrap(num_batches_tracked) + weight = IR.try_unwrap(weight) + bias = IR.try_unwrap(bias) + running_mean = IR.try_unwrap(running_mean) + running_var = IR.try_unwrap(running_var) + num_batches_tracked = IR.try_unwrap(num_batches_tracked) if weight is None: assert bias is None @@ -375,10 +374,10 @@ def instancenorm2d_annotation_fn(*inputs, **kwargs): ), "Expected 5 inputs: input, weight, bias, running_mean, running_var" input, weight, bias, running_mean, running_var = inputs - weight = IRObject.try_unwrap(weight) - bias = IRObject.try_unwrap(bias) - running_mean = IRObject.try_unwrap(running_mean) - running_var = IRObject.try_unwrap(running_var) + weight = IR.try_unwrap(weight) + bias = IR.try_unwrap(bias) + running_mean = IR.try_unwrap(running_mean) + running_var = IR.try_unwrap(running_var) if weight is None: assert bias is None diff --git a/nnscaler/graph/parser/frame.py b/nnscaler/graph/parser/frame.py index 6c1dcbba..228242a5 100644 --- a/nnscaler/graph/parser/frame.py +++ b/nnscaler/graph/parser/frame.py @@ -80,6 +80,16 @@ def add_var(self, var_name: str, val: Any, graph_arg: int = -1): else: raise ValueError("graph_arg (int) must be >= 0") + def del_val(self, var_name: str): + """ + Delete a variable from the current frame. + Do nothing if the variable doesn't exist. + + Args: + var_name (str): variable name + """ + self._vars[-1].pop(var_name, None) + def set_var(self, var_name: str, val: Any): """ Reset a variable with arbitrary value. @@ -104,7 +114,11 @@ def get_var(self, var_name: str) -> Any: # first check whether we have variable in this frame if var_name in self._vars[-1]: return self._vars[-1][var_name] - raise KeyError(f"Cannot find var name {var_name} in {self._vars}") + # See rule 8 in graph/function/functions.py + raise KeyError( + f"Cannot find var name {var_name} in {self._vars}. " + f"Please check whether the variable is from a function that is annotated as no-output." + ) def add_attr(self, tensor: IRTensor, concrete_value: torch.Tensor, name: str): """Add module attribute content diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 93cbe202..cb26d26c 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -11,7 +11,7 @@ from nnscaler.graph.tracer.metadata import OpContext from nnscaler.ir.operator import IRFwOperation from nnscaler.ir.tensor import IRFullTensor -from nnscaler.ir.cten import IRObject, IRCell, IRTensor +from nnscaler.ir.cten import IRObject, IRCell, IRTensor, IR from nnscaler.graph.parser.frame import Frame from nnscaler.graph.parser.mapping import SignFx2Op from nnscaler.graph.function.pyfunc import IRPyFunc @@ -62,11 +62,23 @@ def parse(module: torch.fx.GraphModule, all_ir_nodes (List[IRFwOperation]): the IRFwOperation nodes outputs (List[IRObject]): the output IRObjects """ + # frame will save the outputs of all ops (including `placeholder` and `output`) + # because some ir ops creators just return ops with empty ouputs + # (Those ops creators include user registered function, all functions returning tensors and more) + # We will connect the real op outputs (saved in frame) to all ir op outputs and inputs later. + frame = Frame() frame.push_var() # shape propagation assert isinstance(dummy_inputs, dict), "Expected dummy inputs to parse module" + output_nodes = [node for node in module.graph.nodes if node.op == 'output'] + # currently fx graph always has only one output + # even if a tuple/list is returned, it is still just one output + assert len(output_nodes) == 1, f"Expect only one output, but got {len(output_nodes)}" + output_node = output_nodes[0] + # currently output of fx graph satisfies the following: + assert len(output_node.args) == 1 and len(output_node.kwargs) == 0 # create IRObjects and IRTensors for node in module.graph.nodes: @@ -75,6 +87,23 @@ def parse(module: torch.fx.GraphModule, else: FxModuleParser.init_objects(node, module, frame, is_constant=True) + # note the output node will be reset later by `parse_prim_output_node` + # with the help of `parse_complex` + + # if fx graph output (output_node.args[0]) is a nested structure + # the nested structure will be kept in `parse_complex` + + # but if the node is the only output of the graph + # and it is a tuple, we need to keep it a tuple + # to make sure the IRGraph has the correct output number + # see `IRGrpah.from_logic_graph` + + val = frame.get_var(node.name) + if node == output_node.args[0] \ + and IR.is_object(val) and isinstance(val.value, tuple): + tuple_val = tuple(IRObject(name=node.name, value=v, is_constant=val.is_constant) for v in val.value) + frame.set_var(node.name, tuple_val) + # get graph inputs placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] inputs = [frame.get_var(n.name) for n in placeholders] @@ -131,7 +160,7 @@ def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, assert hasattr(node, 'meta') and 'tensor_meta' in node.meta, f"Node {node} should have tensor_meta" meta = node.meta['tensor_meta'] - val = IRObject.from_complex(node.name, meta, + val = IR.new(node.name, meta, collection_types=(list, tuple, dict, DICT_VALUES_TYPE, DICT_ITEMS_TYPE), tensor_types=(TensorMetadata,), is_constant=is_constant @@ -189,6 +218,20 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: @staticmethod def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: + """ + Convert `call_function`/`call_method` op to IRFwOperation. + + Args: + node (torch.fx.Node): the node to be parsed + module (torch.fx.GraphModule): the module containing the node + constant_folding (bool): global setting of whether to fold the constant + + Returns: + List[IRFwOperation]: the IRFwOperation nodes. + The returned list can be empty if the node is folded, + or contains exactly one node if the node is not folded. + + """ # get signature fsig = FxModuleParser._get_qualified_name(node.target, node) # get inputs @@ -229,56 +272,130 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule output = IRObject(name=node.name, value=output, is_constant=is_constant) ir_node = IRPyFunc(fsig, input_vals, [output], **kwargs) - FxModuleParser._set_node_meta(node, ir_node) - - ir_nodes = [] - if isinstance(ir_node, IRCell): - ir_nodes.append(ir_node) - if len(ir_node.outputs()) > 1: - vals = frame.get_var(node.name) - assert len(vals) == len(ir_node.outputs()), f'{vals}, {ir_node.outputs()}' - for i in range(len(vals)): - ir_node.set_output(i, vals[i]) - elif not isinstance(ir_node.output(0), IRTensor) and ir_node.output(0).value is not None: - # never fold our own functions defined in `nnscaler.runtime.function` module. - # currently only `ifexpr` will go here, and it will never be folded. - if not constant_folding or \ - ir_node.signature.startswith(nnscaler.runtime.function.__name__ + '.') or \ - any_ir_object_satisfy(ir_node.output(0), lambda a: not a.is_constant) or \ - any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, IRTensor)) or \ - any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, (DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE))): - # type of return values of dict.keys, dict.values and dict.items can not be repr, so we must take it as a node - frame.set_var(node.name, ir_node.output(0)) - ir_node.output(0).name = node.name - else: - # if use static shape graph, all IRObject will be converted to real traced value. - # the ir_node will be folded and not appeared in the final graph - frame.set_var(node.name, ir_node.output(0).value) - ir_nodes.pop(-1) - else: - output_val = frame.get_var(node.name) - if isinstance(ir_node, IRDimops): - # TODO: refine here - # infer_type actually just check whether the annoation is consistent - # with actual output - # internally it will set the shape of output, - # but the output is quickly rewritten by the actual output - # in following code `ir_node.set_output(0, output_val)` - # So the scalar-tensor flag is not removed with `infer_shape` - ir_node.infer_shape() - if isinstance(output_val, IRTensor) and isinstance(ir_node.output(0), IRTensor): - assert output_val.shape == ir_node.output(0).shape, ( - f'find shape inference not match: {output_val.shape} vs {ir_node.output(0).shape}' - f'\nnode: {node}' - ) - ir_node.set_output(0, output_val) - else: - # SignFx2Op may return object that is not IRCell but a concrete value, for example Add. + if not isinstance(ir_node, IRCell): + # SignFx2Op may return object that is not IRCell but a value (concrete or IRObject), + # for example Add or GetItem. # As node is deleted, we must set concrete value or IRTensor/IRObject into framework. + + # TODO: check the value saved in frame should equal to the value returned by the op frame.set_var(node.name, ir_node) + return [] - _logger.debug(f'parsing result: {ir_node}') - return ir_nodes + FxModuleParser._set_node_meta(node, ir_node) + + # step 1: align the node output with the value in frame + + # TODO: handle the case when the function has no output + # Currently this can only happened for user registered functions + # We may need to assume the function has side effect, and keep it in the graph + # But in current implementation, this kind of function will be removed by DCE in tracer + # As a result, the follow check will not be triggered in normal cases + if not ir_node.outputs(): + # To avoid the case that the function is annotated no output + # but its output is used in other nodes. + # By removing from frame, + # we can catch the case earlier + frame.del_val(node.name) + # if the function has no output, just return + return [ir_node] + + vals = frame.get_var(node.name) + if len(ir_node.outputs()) == 1: + vals = [vals] + elif IR.is_object(vals): + # fix the case that multiple outputs are returned as a single IRObject + # Because IR.new doesn't know the number of outputs + # it will wrap the output as a single IRObject if no tensor is in the outputs + # so we need to unwrap it here, and align the outputs with annotations. + is_constant = vals.is_constant + vals = vals.value + if not isinstance(vals, (list, tuple)): + raise RuntimeError(f'Expect list or tuple for multiple outputs, but got {type(vals)}') + vals = type(vals)(IRObject(name=node.name, value=v, is_constant=is_constant) for v in vals) + frame.set_var(node.name, vals) + + # this is only for annoation check + # to make sure the annoation is consistent with actual output + if isinstance(ir_node, IRDimops): + # TODO: refine here + # infer_shape actually just check whether the annoation is consistent + # with actual output + # internally it will set the shape of output, + # but the output is quickly rewritten by the actual output + # in following code `ir_node.set_output(0, output_val)` + # So the scalar-tensor flag is not removed with `infer_shape` + ir_node.infer_shape() + for oidx, otensor in enumerate(ir_node.outputs()): + shape_anno = ir_node.oanno(oidx) + if shape_anno.ignore: + continue + assert isinstance(vals[oidx], IRTensor) and isinstance(otensor, IRTensor), \ + f'find type inference not match: {vals[oidx]} vs {otensor}' + + assert vals[oidx].shape == otensor.shape, ( + f'find shape inference not match: {vals[oidx].shape} vs {otensor.shape}' + f'\nnode: {node}' + ) + + assert len(vals) == len(ir_node.outputs()), f'{vals}, {ir_node.outputs()}' + contains_undefined_output = any(v is IRObject.missing for v in ir_node.outputs()) + for i in range(len(vals)): + if isinstance(ir_node.output(i), IRTensor) or ir_node.output(i) is IRObject.missing: + # 1. output tensors are not set in function.py + # 2. IRObject output from some functions (registered functions/getattr) are not set + # For above two cases, we need to set them with values from frame. + ir_node.set_output(i, vals[i]) + + # update frame with ir output + # Please note when there is only one output, we will unwrap it from `ir_node.outputs()` here + frame.set_var( + node.name, + type(vals)(ir_node.outputs()) if len(ir_node.outputs()) > 1 else ir_node.output(0) + ) + + # update the name of output tensors + # Note assignment is not allowed in lambda + # so we use a helper function to update the name + def _update_name(x: IRObject): + x.name = node.name + IR.modify_objects_inplace(ir_node.outputs(), _update_name) + + # step 2: constant folding + + # rules: + # 1. never fold our own functions defined in `nnscaler.runtime.function` module. + # currently only `ifexpr` will go here, and it will never be folded. + # 2. never fold tensors. + # 3. never fold non-constant IRObject + # 4. never fold functions that contains undefined output + # TODO: This is for backward compatibility, Not sure why. + # In current implementation, + # only user registered functions and `getattr` will have undefined output. + # So I think the original intention is to avoid folding user registered functions. + # 5. Only fold primitive types (int, float, bool, None, str, Ellipsis) and its complex types + def _is_primitive_type(val): + # we don't fold a list/tuple/dict with length larger than this + # Just a quick filter, and may not work when val has multiple nested levels + FOLD_MAX_LEN = 10 + if isinstance(val, (list, tuple)): + return len(val) < FOLD_MAX_LEN and all(_is_primitive_type(v) for v in val) + elif isinstance(val, dict): + return len(val) < FOLD_MAX_LEN and all(_is_primitive_type(k) and _is_primitive_type(v) for k, v in val.items()) + # use a white list instead of a black list + return isinstance(val, (int, float, bool, type(None), str, type(Ellipsis))) + + # Note when it is not IRObject as a whole, we will not fold it + if constant_folding and len(ir_node.outputs()) == 1 \ + and isinstance(ir_node.output(0), IRObject) \ + and not isinstance(ir_node.output(0), IRTensor) \ + and not contains_undefined_output \ + and not ir_node.signature.startswith(nnscaler.runtime.function.__name__ + '.')\ + and ir_node.output(0).is_constant \ + and _is_primitive_type(ir_node.output(0).value): + frame.set_var(node.name, ir_node.output(0).value) + return [] + else: + return [ir_node] @staticmethod def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: @@ -349,7 +466,6 @@ def _set_node_meta(node: torch.fx.Node, ir_node: Union[IRCell, Any]): if comment: ir_node.comment = comment - @staticmethod def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: if isinstance(node_target, str): diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 2318bc6f..2b260f92 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -19,7 +19,7 @@ from __future__ import annotations from functools import lru_cache -from typing import List, Tuple, Union, Optional, Any, Dict, Callable +from typing import ClassVar, List, Tuple, Union, Optional, Any, Dict, Callable from collections import OrderedDict import copy import torch @@ -406,16 +406,16 @@ def __repr__(self) -> str: @staticmethod def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[IRObject]: - """Get all IRObjects from a complex data structure + """ + Get all IRObjects (including IRTensor) from a complex data structure - Supported complex of types: List, Tuple, Dict, Slice, IRTensor, IRObject + Supported complex of types: List, Tuple, Dict, Slice Args: - val (Any): the complex data structure to be modified + val (Any): the complex data structure to be traversed _objects (List[IRObject] | None): if provided, the objects will be appened into this - - Return: + Returns: List[IRObject]: all IRObject """ _objects = [] if _objects is None else _objects @@ -463,6 +463,8 @@ class IRObject: """ IRObject serves as general data of IRGraph edge """ + # will be set after class definition + missing: ClassVar['IRObject'] = None def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: Optional[None] = None, is_constant: bool = True): """ @@ -547,6 +549,8 @@ def __eq__(self, obj) -> bool: def __copy__(self): """Copy this object but remove the cell information""" + if self is IRObject.missing: # missing object is singleton + return IRObject.missing return IRObject(self.name, self._id, self._value, self._is_constant) def as_attr(self): @@ -573,8 +577,16 @@ def overlap(self, other: Any) -> bool: else: return False + def __repr__(self): + return f'Object({self.name}{self.tid}, val={self.value}, is_constant={self.is_constant})' + + +IRObject.missing = IRObject('missing', -1, None) + + +class IR: @classmethod - def from_complex(cls, + def new(cls, name: str, data: Any, *, @@ -608,6 +620,7 @@ def from_complex(cls, collection_types (Tuple): the complex data types to be converted tensor_types (Tuple): the tensor data types to be converted tosub(bool): whether convert full tensor to sub-tensor + is_constant (bool): whether the object is constant requires_grad (Optional[bool]): the requires_grad flag for the tensor-like object None: will respect the original requires_grad flag True: will set requires_grad to True @@ -682,7 +695,128 @@ def _inner(obj) -> Tuple[Any, bool]: return _inner(data)[0] @classmethod - def tosub_complex(cls, obj: Any) -> Any: + def get_objects(cls, val: Any) -> List[IRObject]: + """ + Get all IRObjects from a complex data structure + + Supported complex of types: List, Tuple, Dict, Slice + + Args: + val (Any): the complex data structure to be modified + _objects (List[IRObject] | None): + if provided, the objects will be appened into this + + Return: + List[IRObject]: all IRObject + """ + return IRCell.get_objects_from_complex(val) + + @classmethod + def get_object_paths(cls, val: Any) -> Dict[IRObject, List[str]]: + irobj_path = {} + def r(t, current_path): + if isinstance(t, IRObject): + irobj_path[t] = current_path + elif isinstance(t, (list, tuple)): + for i, v in enumerate(t): + r(v, current_path + [i]) + elif isinstance(t, dict): + for k, v in t.items(): + r(v, current_path + [k]) + elif isinstance(t, slice): + raise ValueError("slice is not supported in get_object_paths") + else: + # do nothing + pass + r(val, []) + return irobj_path + + @classmethod + def contains_object(cls, val: Any, condition: Optional[Callable[[IRObject], bool]]=None) -> bool: + """ + Check if there is any IRObject in the complex data structure that satisfies the condition + + Supported complex of types: List, Tuple, Dict, Slice + + Args: + val (Any): the complex data structure to be checked + condition (Optional[Callable[[IRObject], bool]]): the condition to check. If None, check if there is any IRObject + + Return: + bool: True if there is any IRObject that matches the condition + """ + if isinstance(val, dict): + return any(cls.contains_object(v, condition) for v in val.values()) + elif isinstance(val, (list, tuple)): + return any(cls.contains_object(v, condition) for v in val) + elif isinstance(val, slice): + return any(cls.contains_object(v, condition) for v in (val.start, val.stop, val.step)) + elif isinstance(val, IRObject): + return condition is None or condition(val) + else: + return False + + @classmethod + def contains_non_constant_object(cls, val: Any) -> bool: + """ + Check if there is any non-constant IRObject in the complex data structure + + Supported complex of types: List, Tuple, Dict, Slice + + Args: + val (Any): the complex data structure to be checked + + Return: + bool: True if there is any non-constant IRObject + """ + return cls.contains_object(val, lambda x: not x.is_constant) + + @classmethod + def modify_objects(cls, val: Any, modifier: Callable[['IRObject'], 'IRObject']) -> Any: + """ + Return a complex data structure with modified IRObjects + + Supported complex of types: List, Tuple, Dict, Slice + + Args: + val (Any): the complex data structure to be modified + modifier (Callable): a modifier that takes an IRObject and return a new one. + + Return: + new_val (Any): complex data structure with modified IRObjects + """ + return IRCell.modify_objects_of_complex(val, modifier) + + @classmethod + def modify_objects_inplace(cls, val: Any, modifier: Callable[['IRObject'], None]) -> None: + """Modify a complex data structure inplace + + Supported complex of types: List, Tuple, Dict, Slice, IRTensor, IRObject + + Args: + val (Any): the complex data structure to be modified + modifier (Callable): a modifier that takes an IRObject and return nothing. + + Return: + None + """ + rcall = cls.modify_objects_inplace + if isinstance(val, (list, tuple)): + for item in val: + rcall(item, modifier) + if isinstance(val, dict): + for k, v in val.items(): + rcall(k, modifier) + rcall(v, modifier) + if isinstance(val, slice): + for v in (val.start, val.stop, val.step): + rcall(v, modifier) + if isinstance(val, IRObject): + modifier(val) + return val + + @classmethod + def tosub(cls, obj: Any) -> Any: """ Convert complex data type of tensor-like object into sub-tensor @@ -719,8 +853,12 @@ def try_unwrap(cls, x: Union[Any, 'IRObject'], unwrap_ir_tensor=False) -> Any: else: return x - def __repr__(self): - return f'Object({self.name}{self.tid}, val={self.value}, is_constant={self.is_constant})' + @classmethod + def is_object(cls, val: Any, include_ir_tensor=False) -> bool: + """ + Check if the value is an IRObject + """ + return isinstance(val, IRObject) and (include_ir_tensor or not isinstance(val, IRTensor)) class IRTensor(IRObject): @@ -734,7 +872,6 @@ class IRTensor(IRObject): So all further operations could ignore the scalar tensor case. You can get the original shape with `origin_shape` property. """ - def __init__(self, shape=None, name='tensor', dtype=None, tid=None, *, is_attr=False, is_grad=False, requires_grad=False, persistent=False ): diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 8011949a..25ce15d1 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -34,7 +34,7 @@ from nnscaler.graph.schedule.predefined import PredefinedSched from nnscaler.graph.schedule.schedplan import SchedulePlan -from nnscaler.ir.cten import IRObject, IRTensor +from nnscaler.ir.cten import IRObject, IRTensor, IR from nnscaler.ir.operator import IRBpOperation, IRDataOperation from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.unique import IDGenerator @@ -665,7 +665,7 @@ def _gen_graph( # generate backward communications in adapter. However, as long as # the data doesn't require gradient in real runtime, the backward # communication will not be triggered. - ir_dummy_inputs[i] = IRObject.from_complex( + ir_dummy_inputs[i] = IR.new( fx_input_nodes[i].target, ir_dummy_inputs[i], requires_grad=True, tosub=True, diff --git a/nnscaler/program.py b/nnscaler/program.py index 7fdd126e..8867f750 100644 --- a/nnscaler/program.py +++ b/nnscaler/program.py @@ -4,7 +4,7 @@ from typing import List, Tuple, Optional, Any, Dict, Union import inspect -from nnscaler.ir.cten import IRCell, IRObject +from nnscaler.ir.cten import IRCell, IRObject, IR from nnscaler.ir.tensor import IRFullTensor, IRSubTensor from nnscaler.ir.operator import IRBpOperation, IRDataOperation @@ -119,7 +119,7 @@ def __next__(self): if not isinstance(sample, tuple): sample = (sample,) # turn sample into IRObjects - outputs = tuple(IRObject.from_complex('data', s, tosub=True, requires_grad=False, is_constant=False) for s in sample) + outputs = tuple(IR.new('data', s, tosub=True, requires_grad=False, is_constant=False) for s in sample) outputs = tuple(IRObject('data', value=out) if not isinstance(out, IRObject) else out for out in outputs) # create dataloader operation # the `self.irobj` is the IRObject standing for the non-tensor value of real dataloader. diff --git a/tests/graph/function/test_dict_values.py b/tests/graph/function/test_dict_values.py index 7bf5c58b..d97ec854 100644 --- a/tests/graph/function/test_dict_values.py +++ b/tests/graph/function/test_dict_values.py @@ -31,3 +31,4 @@ def test_script_func(): gen_savedir=tempdir, load_module=False ) + assert True diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 8d48153a..345ac500 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -7,7 +7,7 @@ from operator import add from nnscaler.graph.function.dimops import IRDimops, OpAnno import nnscaler.graph.function.function as F -from nnscaler.ir.cten import IRObject, IRTensor +from nnscaler.ir.cten import IR, IRObject, IRTensor import pytest import torch @@ -974,9 +974,9 @@ def factors(n): def verify_partition(op: IRDimops): anno = op.anno inputs = torch.randn(op.inputs()[0].shape) - outputs = inputs.reshape(**IRObject.try_unwrap(op.kwargs)).clone() \ + outputs = inputs.reshape(**IR.try_unwrap(op.kwargs)).clone() \ if 'reshape' in op.signature \ - else inputs.view(**IRObject.try_unwrap(op.kwargs)).clone() + else inputs.view(**IR.try_unwrap(op.kwargs)).clone() def _get_anno_ids_map(shape_annos: tuple): anno_ids = {} @@ -1019,7 +1019,7 @@ def _get_anno_ids_map(shape_annos: tuple): input_chunks = torch.chunk(inputs, factor, dim=input_dim) # 2. update kwargs kwargs = transform_rules[(input_dim, output_dim)](op.kwargs, 0, input_dim, factor, 0) - kwargs = IRObject.try_unwrap(kwargs) + kwargs = IR.try_unwrap(kwargs) # 3. reshape/view reshaped_input_chunks = [chunk.reshape(**kwargs) for chunk in input_chunks] \ if 'reshape' in op.signature \ @@ -1036,7 +1036,7 @@ def test_reshape_view(): query = IRTensor([2048, 1, 2, 512]) op = f(query, IRObject(value=2048), IRObject(value=1), -1, 32) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 1 b 512 -> a 1 (b 16) 32' - assert IRObject.try_unwrap(op.kwargs[kwname]) == (2048, 1, -1, 32) + assert IR.try_unwrap(op.kwargs[kwname]) == (2048, 1, -1, 32) assert [type(x) for x in op.kwargs[kwname]] == [IRObject, IRObject, int, int] verify_partition(op) diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 9c2da080..a0bc33b8 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -174,3 +174,153 @@ def forward(self, x): assert len(ir_graph.nodes()[-1].outputs()) == 1 else: assert len(ir_graph.nodes()[-1].outputs()) == 2 + + +@nnscaler.register_op('m n -> ?, ?') +def func_output_list3(x, factor=1): + return factor, factor * 2 + + +@nnscaler.register_op('m n -> ?, ?') +def func_output_list4(x, factor=1): + return [factor, factor * 2] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('output_list', [True, False]) +def test_non_tensor_multiple_outputs(tmp_path, output_list): + # test the case when multiple outputs are all non-tensors + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + if output_list: + return func_output_list4(x, factor=4) + else: + return func_output_list3(x, factor=4) + + dummy_input = {'x': torch.randn(4, 4)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) + print(ir_graph.extra_repr()) + + # the output number of node depends on the annoation. + # the right part of both func_output_list3 and func_output_list4 are the same ( `?, ?`) + # so two outputs are expected + assert len(ir_graph.nodes()[-1].outputs()) == 2 + + # the graph output number depends on the function return type + # func_output_list3 returns a tuple of two items, so the number of graph output is 2 + # func_output_list4 returns a list, so the number of graph output is 1 (the whole list is a single output) + if output_list: + assert len(ir_graph.outputs()) == 1 + else: + assert len(ir_graph.outputs()) == 2 + + +@nnscaler.register_op('m n -> ?') +def func_output_list5(x, factor=1): + return factor, factor * 2 + + +@nnscaler.register_op('m n -> ?') +def func_output_list6(x, factor=1): + return [factor, factor * 2] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('output_list', [True, False]) +def test_non_tensor_multiple_outputs2(tmp_path, output_list): + # test the case when multiple outputs are all non-tensors + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + if output_list: + return func_output_list6(x, factor=4) + else: + return func_output_list5(x, factor=4) + + dummy_input = {'x': torch.randn(4, 4)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) + print(ir_graph.extra_repr()) + + # the output number of node depends on the annoation. + # the right part of both func_output_list3 and func_output_list4 are the same ( `?`) + # so 1 output are expected + assert len(ir_graph.nodes()[-1].outputs()) == 1 + + # the graph output number depends on the function return type + # func_output_list3 returns a tuple of two items, so the number of graph output is 2 + # func_output_list4 returns a list, so the number of graph output is 1 (the whole list is a single output) + if output_list: + assert len(ir_graph.outputs()) == 1 + else: + assert len(ir_graph.outputs()) == 2 + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('output_list', [True, False]) +def test_non_tensor_multiple_outputs3(tmp_path, output_list): + # test the case when multiple outputs are all non-tensors + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + if output_list: + return func_output_list6(x, factor=4), 1 + else: + return func_output_list5(x, factor=4), 2 + + dummy_input = {'x': torch.randn(4, 4)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) + print(ir_graph.extra_repr()) + + # the output number of node depends on the annoation. + # the right part of both func_output_list3 and func_output_list4 are the same ( `?`) + # so 1 output are expected + assert len(ir_graph.nodes()[-1].outputs()) == 1 + # as both returns tuple, the graph output is 2. + assert len(ir_graph.outputs()) == 2 + + +@replace_all_device_with('cpu') +def test_non_tensor_multiple_outputs4(tmp_path): + # test the case when multiple outputs are all non-tensors + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + dummy_input = {'x': (1, 2)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) + print(ir_graph.extra_repr()) + + # FIXME: + # This is not by design. + # The output number should be 2 + # but we wrap the whole input in a IRObject + # when it goes to output, we can't unpack it. + # so the output number is 1 for now. + # Will be fixed later. + assert len(ir_graph.outputs()) == 1 diff --git a/tests/ir/test_cten.py b/tests/ir/test_cten.py index 9baba6db..d2671b53 100644 --- a/tests/ir/test_cten.py +++ b/tests/ir/test_cten.py @@ -5,7 +5,7 @@ import pytest -from nnscaler.ir.cten import IRObject +from nnscaler.ir.cten import IRObject, IR from nnscaler.ir.tensor import IRFullTensor, IRSubTensor from nnscaler.graph.parser.parser import TensorMetadata, DICT_VALUES_TYPE, DICT_ITEMS_TYPE @@ -20,28 +20,28 @@ def test_from_complex(tosub, requires_grad): rgt = requires_grad if rgt is None: rgt = True - obj = IRObject.from_complex('n', 1, tosub=tosub, requires_grad=requires_grad) + obj = IR.new('n', 1, tosub=tosub, requires_grad=requires_grad) assert type(obj) == IRObject and obj.value == 1 and not obj.is_constant and obj.name == 'n' - obj = IRObject.from_complex('n', [1, 2], tosub=tosub, requires_grad=requires_grad) + obj = IR.new('n', [1, 2], tosub=tosub, requires_grad=requires_grad) assert type(obj) == IRObject and obj.value == [1, 2] and not obj.is_constant and obj.name == 'n' - obj = IRObject.from_complex('n', {'a': 1, 'b': 2}, tosub=tosub, requires_grad=requires_grad) + obj = IR.new('n', {'a': 1, 'b': 2}, tosub=tosub, requires_grad=requires_grad) assert type(obj) == IRObject and obj.value == {'a': 1, 'b': 2} and not obj.is_constant and obj.name == 'n' - obj = IRObject.from_complex('n', {'a': {'c': [3, 4], 'd': [4, 5]}, 'b': [1,2]}, tosub=tosub, requires_grad=requires_grad) + obj = IR.new('n', {'a': {'c': [3, 4], 'd': [4, 5]}, 'b': [1,2]}, tosub=tosub, requires_grad=requires_grad) assert type(obj) == IRObject and obj.value == {'a': {'c': [3, 4], 'd': [4, 5]}, 'b': [1,2]} and not obj.is_constant and obj.name == 'n' t1 = torch.tensor(1.0) t2 = torch.tensor([2.0, 3.0], requires_grad=True) - obj = IRObject.from_complex('n', t1, tosub=tosub, requires_grad=requires_grad) + obj = IR.new('n', t1, tosub=tosub, requires_grad=requires_grad) assert type(obj) == tensor_type and id(obj.value) == id(t1) \ and obj.shape == (1,) and obj.origin_shape == () and obj.dtype == torch.float \ and obj.requires_grad == rg and not obj.is_constant \ and obj.name == 'n' - obj = IRObject.from_complex('n', [t1, t2, 1], tosub=tosub, requires_grad=requires_grad) + obj = IR.new('n', [t1, t2, 1], tosub=tosub, requires_grad=requires_grad) assert type(obj) == list and len(obj) == 3 assert type(obj[0]) == tensor_type and id(obj[0].value) == id(t1) \ and obj[0].shape == (1,) and obj[0].origin_shape == () and obj[0].dtype == torch.float \ @@ -53,7 +53,7 @@ def test_from_complex(tosub, requires_grad): and obj[1].name == 'n' assert type(obj[2]) == IRObject and obj[2].value == 1 and not obj[2].is_constant and obj[2].name == 'n' - obj = IRObject.from_complex('n', {'a': [1, 2, t1], 'b': 2}, tosub=tosub, requires_grad=requires_grad) + obj = IR.new('n', {'a': [1, 2, t1], 'b': 2}, tosub=tosub, requires_grad=requires_grad) assert type(obj) == dict and len(obj) == 2 x = obj['a'] assert type(x) == list and len(x) == 3 @@ -67,13 +67,13 @@ def test_from_complex(tosub, requires_grad): assert type(y) == IRObject and y.value == 2 and not y.is_constant and y.name == 'n' x = [t1, t2, 1] - obj = IRObject.from_complex('n', x, tosub=tosub, tensor_types=(), requires_grad=requires_grad) + obj = IR.new('n', x, tosub=tosub, tensor_types=(), requires_grad=requires_grad) assert type(obj) == IRObject and id(obj.value) == id(x) and not obj.is_constant and obj.name == 'n' - obj = IRObject.from_complex('n', x, tosub=tosub, collection_types=(tuple,), requires_grad=requires_grad) + obj = IR.new('n', x, tosub=tosub, collection_types=(tuple,), requires_grad=requires_grad) assert type(obj) == IRObject and id(obj.value) == id(x) and not obj.is_constant and obj.name == 'n' - obj = IRObject.from_complex('n', [t1, [1, 2, {'a': 3}], (4, 5, {'b': 6, 'c': t2})], tosub=tosub, requires_grad=requires_grad) + obj = IR.new('n', [t1, [1, 2, {'a': 3}], (4, 5, {'b': 6, 'c': t2})], tosub=tosub, requires_grad=requires_grad) assert type(obj) == list and len(obj) == 3 assert type(obj[0]) == tensor_type and id(obj[0].value) == id(t1) \ and obj[0].shape == (1,) and obj[0].origin_shape == () and obj[0].dtype == torch.float \ @@ -97,7 +97,7 @@ def test_from_complex(tosub, requires_grad): t2 = TensorMetadata(shape=(2,), dtype=torch.float, requires_grad=True, stride=None, memory_format=None, is_quantized=None, qparams=None) - obj = IRObject.from_complex('n', {'a': t1, 'b': t2}.values(), + obj = IR.new('n', {'a': t1, 'b': t2}.values(), collection_types=(DICT_VALUES_TYPE,), tensor_types=(TensorMetadata,), tosub=tosub, requires_grad=requires_grad @@ -114,7 +114,7 @@ def test_from_complex(tosub, requires_grad): and y.requires_grad == rgt and not y.is_constant \ and y.name == 'n' - obj = IRObject.from_complex('n', {'a': t1, 'b': t2}.items(), + obj = IR.new('n', {'a': t1, 'b': t2}.items(), collection_types=(DICT_ITEMS_TYPE,), tensor_types=(TensorMetadata,), tosub=tosub, requires_grad=requires_grad diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 96bbfaa7..c3f5b164 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -5,14 +5,18 @@ import tempfile import re from contextlib import nullcontext +from typing import Union import torch import torch.nn.functional as F import pytest -from torch.torch_version import TorchVersion +from unittest.mock import patch from nnscaler.flags import CompileFlag import nnscaler.graph.function.dimops +from nnscaler.graph.function.pyfunc import IRPyFunc +from nnscaler.graph.parser.mapping import SignFx2Op +from nnscaler.ir.cten import IR, IRObject from nnscaler.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph from .common import init_distributed @@ -1676,3 +1680,147 @@ def test_fold_constant(tmp_path, fold_input): # mul_2_51 = torch.mul(mul_1_57, add_38) assert _gencode_contains(tmp_path, CCFModule2, 0, r'mul_.* = torch\.mul\(mul_.*, add_.*\)') + + +@nnscaler.register_op('? ->') +def _op1(k): + pass + + +@nnscaler.register_op('? -> ?') +def _op2(k): + pass + + +@nnscaler.register_op(' -> ?') +def _op3(): + return 1 + + +@nnscaler.register_op('? -> ?') +def _op4(k): + return 1 if k else 0 + + +class IRNoneModule(torch.nn.Module): + def forward(self, x): + _op1(2) + r = _op2(3) + r = _op3() + _op4(r) + return x + r + + +@replace_all_device_with('cpu') +def test_no_return(tmp_path): + m = IRNoneModule() + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + reuse='override', + load_module=False, + ) + # it should looks like: + # def segment18(self, x_22): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1700, in forward, r = _op2(3) + # _op2_3 = tests.parallel_module.test_gencode._op2(3) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1701, in forward, r = _op3() + _op4(r) + # _op3_4 = tests.parallel_module.test_gencode._op3() + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1701, in forward, r = _op3() + _op4(r) + # _op4_5 = tests.parallel_module.test_gencode._op4(_op2_3) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1701, in forward, r = _op3() + _op4(r) + # add_14 = _operator.add(_op3_4, _op4_5) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1702, in forward, return x + r + # add_1_19 = torch.add(x_22, add_14, alpha=1) + # del x_22 + # return add_1_19 + + # _op1 will be removed by DCE in tracer + assert not _gencode_contains(tmp_path, IRNoneModule, 0, + r'tests\.parallel_module\.test_gencode\._op1') + + +class IRUseNoneModule(torch.nn.Module): + def forward(self, x): + r = _op3() + _op4(_op1(2)) + return x + r + + +@replace_all_device_with('cpu') +def test_use_none_return(tmp_path): + m = IRUseNoneModule() + m.train() + # it should raise an error, because _op1 has no return value, but it is used in _op4 + with pytest.raises(KeyError): + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + reuse='override', + load_module=False, + ) + + + +@nnscaler.register_op('? -> ?, ?') +def _op5(k): + return 1 + k, 2 + + +def _op6(k): + return 1 + k, 2 + + +# the ops registered with register_op can't cover all code path in parser +def Op6(o: Union[int, IRObject], signature=None): + o = IR.try_unwrap(o) + return IRPyFunc(signature, inputs=[o], outputs=[ + IRObject(name='_op6', value=o + 1, is_constant=True), + IRObject(name='_op6', value=2, is_constant=True), + ]) + + +class IRMultiOutputModule(torch.nn.Module): + def forward(self, x): + r0, _ = _op5(2) + r1, _ = _op6(3) + return x + r0 + r1 + + + +@replace_all_device_with('cpu') +def test_multi_output_op(tmp_path): + SignFx2Op.kOpMap['tests.parallel_module.test_gencode._op6'] = Op6 + + from nnscaler.graph.tracer import concrete_trace + from nnscaler.graph.tracer.wrap_utils import LeafWrapInfo + def patched_concrete_trace(*args, **kwargs): + kwargs['dce_ignored_function'].add(_op6) + kwargs['autowrap_leaf_function'][_op6] = LeafWrapInfo([], True, None) + + return concrete_trace(*args, **kwargs) + + with patch( + "nnscaler.graph.parser.converter.concrete_trace", + side_effect=patched_concrete_trace + ): + m = IRMultiOutputModule() + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + reuse='override', + load_module=False, + ) + + SignFx2Op.kOpMap.pop('tests.parallel_module.test_gencode._op6') + # should success + assert True From 93268ef2ac730548c038b87b3181ca3e04d48b43 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Tue, 10 Dec 2024 06:17:50 +0000 Subject: [PATCH 1785/1892] Merged PR 2339: [AutoDist] Multi-nodes communication profiling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Support profiling multi-nodes communication; move the script into nnscaler package. Usage: `torchrun ... -m nnscaler.profiler` The result format is "compatible" with current autodist. Multi-nodes setups like 4×8 GPUs will be saved as "intra_32". The old scripts are not removed (yet) because they behave differently. Let @ decide whether/when to remove them.. Related work items: #2089 --- examples/llama/README.rst | 2 +- nnscaler/profiler/__main__.py | 6 + nnscaler/profiler/comm_profile.py | 197 ++++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 nnscaler/profiler/__main__.py create mode 100644 nnscaler/profiler/comm_profile.py diff --git a/examples/llama/README.rst b/examples/llama/README.rst index 4dd5be41..401dd136 100644 --- a/examples/llama/README.rst +++ b/examples/llama/README.rst @@ -162,7 +162,7 @@ If the profiling is skipped, the system will use MI250's data by default. You ca .. code-block:: bash - cd nnscaler && python utility/prim_profiler.py + torchrun --nnodes= --nproc_per_node= -m nnscaler.profiler Checkpoint ========== diff --git a/nnscaler/profiler/__main__.py b/nnscaler/profiler/__main__.py new file mode 100644 index 00000000..0340eaf4 --- /dev/null +++ b/nnscaler/profiler/__main__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from . import comm_profile + +comm_profile.main() diff --git a/nnscaler/profiler/comm_profile.py b/nnscaler/profiler/comm_profile.py new file mode 100644 index 00000000..ec12b6b0 --- /dev/null +++ b/nnscaler/profiler/comm_profile.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from datetime import datetime +import json +import logging +import sys + +import torch + +import nnscaler +from nnscaler.autodist.util import get_node_arch, get_default_profile_path +from nnscaler.profiler import CudaTimer +from nnscaler.runtime.adapter.collectives import all_gather, all_reduce, all_to_all, reduce_scatter +from nnscaler.runtime.device import DeviceGroup +from nnscaler.utils import is_running_distributed + + +_logger = logging.getLogger('nnscaler.profiler') + + +# The profiling result of a primitive function, as two lists of the same length. +# The first list contains tensor sizes (in MB). +# The second list contains corresponding time consumption (in seconds). +PrimitiveProfile = tuple[list[float], list[float]] + +# The profiling result of a GPU group. +# The key is a primitive function's name: "all gather", "all reduce", "reduce scatter", "all to all" +# The value is its profiling result. +# NOTE: the function names use spaces, not underscores +Profile = dict[str, PrimitiveProfile] + + +class CommProfiler: + def __init__(self, warmup_times: int = 10, profile_times: int = 10): + self.warmup_times = warmup_times + self.profile_times = profile_times + + def profile_all(self) -> dict[str, Profile]: + ret = {} + + # run on all nodes to simplify barrier + ret.update(self.profile_single_node()) + + if DeviceGroup().world_size > DeviceGroup().local_world_size: + ret.update(self.profile_multi_nodes()) + + return ret + + def profile_single_node(self) -> dict[str, Profile]: + # The key is GPU numbers in string format: "2", "4", "8", ... + ret = {} + n_procs = DeviceGroup().local_world_size + + device_num = 2 + while device_num <= n_procs: + key = str(device_num) + if DeviceGroup().local_rank == 0: + _logger.info(f'Profiling {key} GPUs...') + ranks = tuple(range(device_num)) + + # dist.new_group() must be invoked on all ranks, + # but invoking primitives on all ranks will raise warning + DeviceGroup().get_group(ranks) + if DeviceGroup().rank in ranks: + ret[key] = self.profile_ranks(ranks) + DeviceGroup().long_barrier() + + device_num *= 2 + + return ret + + def profile_multi_nodes(self) -> dict[str, Profile]: + # The key is "{nnodes}x{ngpus}": "2x8", "4x8", "8x8", ... + # Because 2x2 is likely to slower than 1x4, we only test N x local_world_size + ret = {} + + # assuming all nodes have the same GPU numbers + world_size = DeviceGroup().world_size + local_world_size = DeviceGroup().local_world_size + assert world_size % local_world_size == 0, 'The nodes are heterogeneous' + + n_nodes = world_size // local_world_size + n_procs = local_world_size + + node_num = 2 + while node_num <= n_nodes: + key = f'{node_num}x{n_procs}' + if DeviceGroup().local_rank == 0: + _logger.info(f'Profiling {key} GPUs...') + ranks = list(range(n_procs * node_num)) + + # dist.new_group() must be invoked on all ranks, + # but invoking primitives on all ranks will raise warning + DeviceGroup().get_group(ranks) + if DeviceGroup().rank in ranks: + ret[key] = self.profile_ranks(ranks) + DeviceGroup().long_barrier() + + node_num *= 2 + + return ret + + def profile_ranks(self, ranks: list[int]) -> Profile: + profile_info = {} + for primitive in ['all gather', 'all reduce', 'reduce scatter', 'all to all']: + profile_info[primitive] = self.profile_primitive(primitive, ranks) + return profile_info + + def profile_primitive(self, primitive: str, ranks: list[int]) -> PrimitiveProfile: + b_size = 16 + sequence_len = 16 + quarter_mb_size_list = [ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 + ] + model_dim_list = [ + mem * 256 * 256 // b_size // sequence_len + for mem in quarter_mb_size_list + ] + sizes_in_mb = [0.25 * val for val in quarter_mb_size_list] + times_in_s = [] + for cur_sz, d_size in zip(sizes_in_mb, model_dim_list): + assert d_size % len(ranks) == 0 + if primitive in ['all gather', 'all to all']: + d_size = d_size // len(ranks) + # Here dtype has little impact. Here we just use `float32` + tensor = torch.rand([b_size, sequence_len, d_size], + dtype=torch.float32, + device=torch.cuda.current_device()) + # dim has no impact on transmission. In the following test, we use 0 for idim and 2 for odim. + if primitive == 'all gather': + func = lambda: all_gather(tensor=tensor, dim=2, ranks=ranks) + elif primitive == 'all reduce': + func = lambda: all_reduce(tensor=tensor, ranks=ranks) + elif primitive == 'reduce scatter': + func = lambda: reduce_scatter(tensor=tensor, dim=2, ranks=ranks) + elif primitive == 'all to all': + func = lambda: all_to_all(tensor=tensor, idim=0, odim=2, ranks=ranks) + else: + raise ValueError('Unknown primitive: {}'.format(primitive)) + for _ in range(self.warmup_times): + func() + CudaTimer().clear() + for _ in range(self.profile_times): + _otensor = func() + cur_t = CudaTimer().instance.field_data['comm'] / self.profile_times + times_in_s.append(cur_t) + return sizes_in_mb, times_in_s + + +def main() -> bool: + if not is_running_distributed(): + print('Usage: torchrun {TORCHRUN_ARGS} -m nnscaler.profiler') + sys.exit(1) + + nnscaler.init() + + if DeviceGroup().local_rank == 0: + nnscaler.utils.set_default_logger_level('INFO') + else: + nnscaler.utils.set_default_logger_level('DEBUG') + + CudaTimer(enable=True, predefined=True) + + comm_profiler = CommProfiler() + profile_info = comm_profiler.profile_all() + + if DeviceGroup().rank == 0: + comm_path = get_default_profile_path() / 'comm' + if comm_path.exists(): + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + backup_path = comm_path.with_name(f'comm-bak-{timestamp}') + _logger.info('Profiling data already exists') + _logger.info(f'Backup old data to {backup_path}') + comm_path.rename(backup_path) + + comm_path.mkdir(parents=True, exist_ok=True) + + for key, profile in profile_info.items(): + if 'x' in key: + # FIXME: saving inter-nodes results as intra + x, y = key.split('x') + key = str(int(x) * int(y)) + file_name = comm_path / f'intra_{key}.json' + with open(file_name, 'w') as f: + json.dump(profile, f, indent=2) + + _logger.info('Profiling done') + + elif DeviceGroup().local_rank == 0: + _logger.info('Multi-nodes profiling done') + + +if __name__ == '__main__': + main() From 7433277728bfbae1fb59e9ce48f9f9705766f934 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Tue, 10 Dec 2024 06:52:10 +0000 Subject: [PATCH 1786/1892] Merged PR 2343: Update README.md (merge github) Moving success stories and releases to the top by Scarlett --- README.md | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 83bcb86e..37630147 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -drawing +drawing nnScaler: Compiling DNN models for Parallel Training over Multiple Devices ============== @@ -11,6 +11,12 @@ nnScaler is a parallelization engine that compiles a Deep neural network (DNN) m drawing +# Latest News +nnScaler (also known as CUBE as code name) has been adopted by multiple product and research projects, this section includes some of the latest news from the team and partner projects. +* **2024-11-26** nnScaler 0.5 released: https://github.com/microsoft/nnscaler/releases/tag/0.5 +* **2024-05-09** YOCO utilizes nnScaler for long-sequence training: [(YOCO)You only cache once: Decoder-decoder architectures for language models](https://arxiv.org/abs/2405.05254) +* **2024-04-22** Post training for the long context version of [Phi-3 series](https://arxiv.org/abs/2404.14219) +* **2024-02-21** LongRoPE utilizes nnScaler to reduce both the training and inference costs: [LongRoPE: Extending LLM context window beyond 2 million tokens](https://arxiv.org/abs/2402.13753) ### System Highlights: @@ -40,7 +46,7 @@ Install the following packages before the installation of nnScaler: PyTorch >= 2.0, < 2.4 (2.2.0 is recommanded) ### Install nnScaler from source -Execute below commands in nnScaler directory: +Execute below commands in nnScaler directory: pip install -r requirements.txt pip install -e . @@ -57,13 +63,13 @@ Besides, to avoid *cppimport* error, it also needs to include nnScaler directory ### Prerequisite for Llama-3 -Install packages required to run Llama-3. Besides, a certain version of CUDA library is needed during flash-attn installation. For example, [CUDA V11.8](https://developer.nvidia.com/cuda-11-8-0-download-archive) is needed if using PyTorch 2.20. +Install packages required to run Llama-3. Besides, a certain version of CUDA library is needed during flash-attn installation. For example, [CUDA V11.8](https://developer.nvidia.com/cuda-11-8-0-download-archive) is needed if using PyTorch 2.20. python -m pip install transformers==4.40.0 flash-attn==2.5.5 tensorboard ### Model Access -Obtain access of Llama-3 model from [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), where you will receive an access token which should be set as an environment variable: +Obtain access of Llama-3 model from [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), where you will receive an access token which should be set as an environment variable: export HF_TOKEN= @@ -96,7 +102,7 @@ class WrapperModel(torch.nn.Module): def main(args): # data config dataloader_config = ... - + # model config model_config = ModelConfig( type=WrapperModel, @@ -104,15 +110,15 @@ def main(args): 'model_id': args.model_id, }, ) - # optimizer hyperparameters + # optimizer hyperparameters optimizer_config = OptimizerConfig( type=MixedPrecisionAdamW, args={'lr': 2e-5, 'betas': (0.9, 0.95), 'weight_decay': 0.0, 'fused': True}, #... ) #... - - # setup trainer with configs of dataloader/model/optimizer, etc. + + # setup trainer with configs of dataloader/model/optimizer, etc. trainer = Trainer(train_args=TrainerArgs( #... model=model_config, @@ -126,7 +132,7 @@ def main(args): ### Run the example Llama-3 training -Then we can start the example, and all the parallelization tasks will be finished by nnScaler automatically. +Then we can start the example, and all the parallelization tasks will be finished by nnScaler automatically. ```shell cd examples/llama3_8B_128K @@ -184,14 +190,6 @@ NOTE: The local batch size is fixed by default, so using more workers will resul 💡 For advanced usages, please stay tuned for our future release. - -# Success Stories - -nnScaler has been adopted by multiple projects, including both product and research explorations: -* [(YOCO)You only cache once: Decoder-decoder architectures for language models](https://arxiv.org/abs/2405.05254) -* [LongRoPE: Extending LLM context window beyond 2 million tokens](https://arxiv.org/abs/2402.13753) -* Post training for the long context version of [Phi-3 series](https://arxiv.org/abs/2404.14219) - # Reference --------- From 4d6cd755687021fb4b513bd01da862f2faab24a7 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 12 Dec 2024 03:15:55 +0000 Subject: [PATCH 1787/1892] Merged PR 2344: [Parser] never do dce on leaf functions never do dce on leaf functions --- nnscaler/graph/parser/converter.py | 8 +++++-- nnscaler/graph/parser/parser.py | 6 ++--- tests/parallel_module/test_gencode.py | 34 ++++++++++++++------------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 85ad099b..a30dfa23 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -76,7 +76,8 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: Returns: torch.fx.GraphModule representation of model """ - # get registered leaf function + # get user registered functions, and treat them as leaf functions + # torch function/operators/builtins/... are automatically handled as leaf functions by concrete trace autowrap_funcs = [CustomizedOps.kOpRuntime[sign] for sign in CustomizedOps.kOpMap] # filter out torch.autograd.Function.apply as concrete trace already treats them as leaf function autowrap_funcs = [fn for fn in autowrap_funcs if not is_autograd_apply(fn)] @@ -100,7 +101,10 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: func: LeafWrapInfo([Location(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs }) - dce_ignored_funcs = set(cube_rt_funcs) + + # keep all leaf functions in the graph (even if leaf functions have no return value) + # Please note the result graph is not always DAG, and can have outliers. + dce_ignored_funcs = set(leaf_functions) with no_save_tensor_hook(), warnings.catch_warnings(): # ignore the warning from fx about get_attr diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index cb26d26c..49b71836 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -285,11 +285,9 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # step 1: align the node output with the value in frame - # TODO: handle the case when the function has no output + # handle the case when the function has no output # Currently this can only happened for user registered functions - # We may need to assume the function has side effect, and keep it in the graph - # But in current implementation, this kind of function will be removed by DCE in tracer - # As a result, the follow check will not be triggered in normal cases + # We need to assume the function has side effect, and keep it in the graph if not ir_node.outputs(): # To avoid the case that the function is annotated no output # but its output is used in other nodes. diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index c3f5b164..c21f9421 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -1724,22 +1724,24 @@ def test_no_return(tmp_path): load_module=False, ) # it should looks like: - # def segment18(self, x_22): - # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1700, in forward, r = _op2(3) - # _op2_3 = tests.parallel_module.test_gencode._op2(3) - # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1701, in forward, r = _op3() + _op4(r) - # _op3_4 = tests.parallel_module.test_gencode._op3() - # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1701, in forward, r = _op3() + _op4(r) - # _op4_5 = tests.parallel_module.test_gencode._op4(_op2_3) - # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1701, in forward, r = _op3() + _op4(r) - # add_14 = _operator.add(_op3_4, _op4_5) - # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1702, in forward, return x + r - # add_1_19 = torch.add(x_22, add_14, alpha=1) - # del x_22 - # return add_1_19 - - # _op1 will be removed by DCE in tracer - assert not _gencode_contains(tmp_path, IRNoneModule, 0, + # def segment19(self, x_23): + # # File "/home/weijiangxu/nanogpt/MagicCube/tests/parallel_module/test_gencode.py", line 1707, in forward, _op1(2) + # tests.parallel_module.test_gencode._op1(2) + # # File "/home/weijiangxu/nanogpt/MagicCube/tests/parallel_module/test_gencode.py", line 1708, in forward, r = _op2(3) + # _op2_4 = tests.parallel_module.test_gencode._op2(3) + # # File "/home/weijiangxu/nanogpt/MagicCube/tests/parallel_module/test_gencode.py", line 1709, in forward, r = _op3() + _op4(r) + # _op3_5 = tests.parallel_module.test_gencode._op3() + # # File "/home/weijiangxu/nanogpt/MagicCube/tests/parallel_module/test_gencode.py", line 1709, in forward, r = _op3() + _op4(r) + # _op4_6 = tests.parallel_module.test_gencode._op4(_op2_4) + # # File "/home/weijiangxu/nanogpt/MagicCube/tests/parallel_module/test_gencode.py", line 1709, in forward, r = _op3() + _op4(r) + # add_15 = _operator.add(_op3_5, _op4_6) + # # File "/home/weijiangxu/nanogpt/MagicCube/tests/parallel_module/test_gencode.py", line 1710, in forward, return x + r + # add_1_20 = torch.add(x_23, add_15, alpha=1) + # del x_23 + # return add_1_20 + + # _op1 will not be removed by DCE in tracer + assert _gencode_contains(tmp_path, IRNoneModule, 0, r'tests\.parallel_module\.test_gencode\._op1') From a7a5089787e0a7d486e9a36724842d710a7f2f84 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 13 Dec 2024 05:59:18 +0000 Subject: [PATCH 1788/1892] Merged PR 2342: [Reorg + Parser] refine infer_shape refine infer_shape 1. Now `infer_shape` is a const method, which returns the inferred shape based on its inputs and kwargs (and annotation if op is IRDimops) 2. Add `verify_shape` to verify whether the shape of tensors in outputs matches the inferred shape. Other assumption change: 1. After IRFwOperation is created, all its outputs will be IRObject.missing, and should be assigned (with .set_output()) before you can use it. 2. Most infer_shape calls are replaced with verify_shape, which means all ops should assign outputs by themselves. parity check passes unit test passes --- nnscaler/algorithm/ops/dimops.py | 2 +- nnscaler/graph/function/anchor.py | 4 +-- nnscaler/graph/function/conv.py | 27 +++++++++----------- nnscaler/graph/function/dimops.py | 17 ++++--------- nnscaler/graph/function/pyfunc.py | 8 +++--- nnscaler/graph/parser/parser.py | 25 +++--------------- nnscaler/graph/segment.py | 6 ++--- nnscaler/ir/operator.py | 39 +++++++++++++++++++++++------ pyproject.toml | 7 ++++++ tests/graph/function/test_dimops.py | 13 +++++++++- 10 files changed, 80 insertions(+), 68 deletions(-) diff --git a/nnscaler/algorithm/ops/dimops.py b/nnscaler/algorithm/ops/dimops.py index 668e4507..1ba38c79 100644 --- a/nnscaler/algorithm/ops/dimops.py +++ b/nnscaler/algorithm/ops/dimops.py @@ -170,7 +170,7 @@ def transform(tensor: Any, split: DimopSplit) -> List[Any]: outputs = [t[nid] for t in ous] kwargs = rule.modifier()(node.kwargs, idx, dim, num, nid) sub_node: IRDimops = node.new(inputs, outputs, **kwargs) - sub_node.infer_shape() + sub_node.verify_shape() sub_nodes.append(sub_node) return sub_nodes diff --git a/nnscaler/graph/function/anchor.py b/nnscaler/graph/function/anchor.py index 298145c5..6a5986c6 100644 --- a/nnscaler/graph/function/anchor.py +++ b/nnscaler/graph/function/anchor.py @@ -48,8 +48,8 @@ def __init__(self, signature: str, name: str): self.kwargs['name'] = name self.set_output(0, IRObject('anchor', value=None)) - def infer_shape(self): - return True + def infer_shape(self) -> dict[int, tuple[int, ...]]: + return {} def __repr__(self) -> str: return f"AnchorOp-{self.cid}(name={self.name})" diff --git a/nnscaler/graph/function/conv.py b/nnscaler/graph/function/conv.py index b14aeedf..322dc65b 100644 --- a/nnscaler/graph/function/conv.py +++ b/nnscaler/graph/function/conv.py @@ -17,12 +17,12 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, assert len(kwargs) == 3, "Expected 2 kwargs: mode, value" super().__init__(name, signature, inputs, 1, **kwargs) - def infer_shape(self) -> bool: + def infer_shape(self) -> dict[int, tuple[int, ...]]: """ Output shape inference given the input shapes """ if len(self.input(0).shape) == 0: - return False + return {} pad = self.kwargs['pad'] assert len(pad) % 2 == 0, "IRPad::infer_shape len(pad) % 2 == 0" @@ -31,8 +31,7 @@ def infer_shape(self) -> bool: for pad_idx, pad_size in enumerate(pad): shape[-1 - (pad_idx // 2)] += pad_size - self.output(0).shape = shape - return True + return {0: shape} def new(self, inputs: List, outputs: List, pad = None): """ @@ -47,7 +46,7 @@ def new(self, inputs: List, outputs: List, pad = None): pad=pad, mode=mode, value=value) assert len(outputs) == 1 op.set_output(0, outputs[0]) - op.infer_shape() + op.verify_shape() return op @@ -60,12 +59,12 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" super().__init__(name, signature, inputs, 1, **kwargs) - def infer_shape(self) -> bool: + def infer_shape(self) -> dict[int, tuple[int, ...]]: """ Output shape inference given the input shapes """ if len(self.input(0).shape) == 0 or len(self.input(1).shape) == 0: - return False + return {} N = self.input(0).shape[0] iH, iW = self.input(0).shape[2:4] oC = self.input(1).shape[0] @@ -77,8 +76,7 @@ def infer_shape(self) -> bool: oH = (iH + padding[0] + padding[1] - dilation[0] * (dH - 1) - 1) // stride[0] + 1 oW = (iW + padding[2] + padding[3] - dilation[1] * (dW - 1) - 1) // stride[1] + 1 shape = [N, oC, oH, oW] - self.output(0).shape = shape - return True + return {0: shape} def new(self, inputs: List, outputs: List): """ @@ -93,7 +91,7 @@ def new(self, inputs: List, outputs: List): stride=stride, padding=padding, dilation=dilation, groups=groups) assert len(outputs) == 1 op.set_output(0, outputs[0]) - op.infer_shape() + op.verify_shape() return op @@ -106,12 +104,12 @@ def __init__(self, signature: str, inputs: List[IRTensor], name: str, assert len(kwargs) == 4, "Expected 4 kwargs: stride, padding, dialation, groups" super().__init__(name, signature, inputs, 1, **kwargs) - def infer_shape(self) -> bool: + def infer_shape(self) -> dict[int, tuple[int, ...]]: """ Output shape inference given the input shapes """ if len(self.input(0).shape) == 0 or len(self.input(1).shape) == 0: - return False + return {} N = self.input(0).shape[0] iC = self.input(0).shape[1] iT, iH, iW = self.input(0).shape[2:5] @@ -129,8 +127,7 @@ def infer_shape(self) -> bool: oW = (iW + 2 * padding[2] - dilation[2] * (dW - 1) - 1) // stride[2] + 1 shape = [N, oC, oT, oH, oW] - self.output(0).shape = shape - return True + return {0: shape} def new(self, inputs: List, outputs: List): """ @@ -145,5 +142,5 @@ def new(self, inputs: List, outputs: List): stride=stride, padding=padding, dilation=dilation, groups=groups) assert len(outputs) == 1 op.set_output(0, outputs[0]) - op.infer_shape() + op.verify_shape() return op diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index 24241de2..b5d875c8 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -734,11 +734,6 @@ def __init__(self, create_fn: Callable, name: str, n_outputs = len(self._oannos) super().__init__(name, signature, inputs, n_outputs, **kwargs) - # change tensor to IRObject for '?' annotation - for idx, shape_anno in enumerate(self._oannos): - if shape_anno.ignore: - self.set_output(idx, IRObject.missing) - @property def anno(self) -> OpAnno: return self._anno @@ -769,13 +764,14 @@ def oanno(self, index: int) -> ShapeAnno: assert index < len(self.outputs()), "index out of boudary" return self._oannos[index] - def infer_shape(self) -> bool: + def infer_shape(self) -> dict[int, tuple[int, ...]]: """ Shape inference using the matched annotation and tensor. @return sucess: True if successfully inferred shape """ - for oidx, otensor in enumerate(self.outputs()): + shapes = {} + for oidx in range(len(self.outputs())): shape_anno = self.oanno(oidx) if shape_anno.ignore: # otensor can be any type, including IRObject, collection types (list, dict, etc.) @@ -786,11 +782,8 @@ def infer_shape(self) -> bool: for identifier in shape_anno[odim].identifiers: accum *= self.anno.getlen(identifier) shape.append(accum) - otensor.shape = shape - # print(f'=> sign: {self.signature} anno: {self.anno}\n' - # f'=> inputs: {self.inputs()}\n' - # f'=> outputs: {self.outputs()}') - return True + shapes[oidx] = tuple(shape) + return shapes def new(self, inputs: List[IRTensor], outputs: List[IRTensor], **kwargs): """! diff --git a/nnscaler/graph/function/pyfunc.py b/nnscaler/graph/function/pyfunc.py index ce0a5643..733849ed 100644 --- a/nnscaler/graph/function/pyfunc.py +++ b/nnscaler/graph/function/pyfunc.py @@ -12,7 +12,7 @@ class IRPyFunc(IRFwOperation): Python runtime function """ - def __init__(self, signature: str, + def __init__(self, signature: str, inputs: Tuple[IRObject], outputs: Tuple[IRObject], **kwargs): name = signature.split('.')[-1] super().__init__(name, signature, inputs, len(outputs)) @@ -20,12 +20,12 @@ def __init__(self, signature: str, self.set_output(idx, t) self.kwargs.update(**kwargs) - def infer_shape(self) -> bool: + def infer_shape(self) -> dict[int, tuple[int, ...]]: """ Shape will not be inferred for python runtime """ - return True - + return {} + def __repr__(self) -> str: sign = self.signature.split('.')[-1] dscp = (f"PyOp{self._id}-{self.device}(sign={sign}, " diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 49b71836..046ac2ab 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -312,28 +312,9 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule vals = type(vals)(IRObject(name=node.name, value=v, is_constant=is_constant) for v in vals) frame.set_var(node.name, vals) - # this is only for annoation check - # to make sure the annoation is consistent with actual output - if isinstance(ir_node, IRDimops): - # TODO: refine here - # infer_shape actually just check whether the annoation is consistent - # with actual output - # internally it will set the shape of output, - # but the output is quickly rewritten by the actual output - # in following code `ir_node.set_output(0, output_val)` - # So the scalar-tensor flag is not removed with `infer_shape` - ir_node.infer_shape() - for oidx, otensor in enumerate(ir_node.outputs()): - shape_anno = ir_node.oanno(oidx) - if shape_anno.ignore: - continue - assert isinstance(vals[oidx], IRTensor) and isinstance(otensor, IRTensor), \ - f'find type inference not match: {vals[oidx]} vs {otensor}' - - assert vals[oidx].shape == otensor.shape, ( - f'find shape inference not match: {vals[oidx].shape} vs {otensor.shape}' - f'\nnode: {node}' - ) + # verify the inferred shape are consistent with actual output + if isinstance(ir_node, IRFwOperation): + ir_node.verify_shape(vals) assert len(vals) == len(ir_node.outputs()), f'{vals}, {ir_node.outputs()}' contains_undefined_output = any(v is IRObject.missing for v in ir_node.outputs()) diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index 3f6f289e..e71bfe1f 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -813,7 +813,7 @@ def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: - IRWeightReducer generation case 1: when gradients over devices can not be accumulated directly to synchronize. This typically happens when a parameter is shared, especially in pipeline parallelism. - Here we can use multiref to synchronize gradients, but the semantic is different. (TODO: add a new + Here we can use multiref to synchronize gradients, but the semantic is different. (TODO: add a new function to handle this case to make it more clear). With multiref, the weight becomes an activation, so no IRWeightReducer will be generated. Instead, Adapter will be used in runtime. @@ -1001,7 +1001,7 @@ def single_consume(self, one_for_all: bool = True): multiref.comment = 'create at IRSegment:single_consume' for idx, itensor in enumerate(itensors): multiref.set_output(idx, itensor) - multiref.infer_shape() + multiref.verify_shape() # insert multiref right before the consumor idx = self.index(consumer) # require backward @@ -1042,7 +1042,7 @@ def single_consume(self, one_for_all: bool = True): multiref.comment = 'create at IRSegment:single_consume' for idx, itensor in enumerate(itensors): multiref.set_output(idx, itensor) - multiref.infer_shape() + multiref.verify_shape() idx = self.index(producers[ftensor]) + 1 if ftensor in producers else 0 # idx = nodes.index(cnodes[0]) if any(itensor.requires_grad for itensor in node.inputs()): diff --git a/nnscaler/ir/operator.py b/nnscaler/ir/operator.py index 11bf7bad..952c83da 100644 --- a/nnscaler/ir/operator.py +++ b/nnscaler/ir/operator.py @@ -38,20 +38,43 @@ def __init__(self, name: str, signature: str, for name, value in kwargs.items(): self.set_kwarg(name, value) - # default infer rule - requires_grad = any( - t.requires_grad for t in inputs if isinstance(t, IRTensor)) - # setup output - outputs = [IRFullTensor(requires_grad=requires_grad) for _ in range(num_outputs)] + outputs = [IRObject.missing for _ in range(num_outputs)] for idx, output in enumerate(outputs): self.set_output(idx, output) - def infer_shape(self): + def infer_shape(self) -> dict[int, tuple[int, ...]]: """ - Infer output value shape + Infer output value shape for each output + Will not update graph or shape of `self._outputs` + """ + # by default, no shape inference + return {} + + def verify_shape(self, outputs=None) -> None: + """ + Verify the shape of the outputs with inferred shape of the operator. + Raise error if shape mismatch. + + Args: + outputs: the outputs to match. If None, use self.outputs() + + Raises: + ValueError: if shape mismatch """ - raise NotImplementedError + infered_shapes = self.infer_shape() + outputs = outputs if outputs is not None else self.outputs() + for oidx in range(len(outputs)): + if oidx not in infered_shapes: + continue + if not isinstance(outputs[oidx], IRTensor): + raise ValueError(f'find type inference not match: {outputs[oidx]} expected to be a tensor') + + if tuple(outputs[oidx].shape) != tuple(infered_shapes[oidx]): + raise ValueError( + f'find shape inference not match: {outputs[oidx].shape} vs {infered_shapes[oidx]}' + f'\nnode: {self}' + ) @property def recompute(self) -> Optional[int]: diff --git a/pyproject.toml b/pyproject.toml index 2e48bd29..e73228f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,3 +31,10 @@ dynamic.dependencies.file = "requirements.txt" # since we are using cppimport, sdist is not needed packages.find.include = ["nnscaler*"] package-data = { nnscaler = ["resources/**", "autodist/*.h", "autodist/*.cpp"] } + +[tool.coverage.run] +omit = [ + "nnscaler/algorithm/ops/conv.py", + "nnscaler/graph/function/conv.py", + "nnscaler/graph/tracer/_pytree.py" +] diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index 0065d9c1..ab4839b4 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -13,7 +13,7 @@ import nnscaler.graph.function as F from nnscaler.graph.function.dimops import IRDimops, OpAnno from nnscaler.ir.tensor import IRFullTensor -from nnscaler.ir.cten import IRObject +from nnscaler.ir.cten import IRObject, IRTensor def create_op(creator: Callable, @@ -22,7 +22,18 @@ def create_op(creator: Callable, return creator(*(inputs+args), **kwargs) +def set_outputs(op: IRDimops): + inputs = op.inputs() + require_grads = any( + t.requires_grad for t in inputs if isinstance(t, IRTensor)) + inferred_shape = op.infer_shape() + for idx, shape in inferred_shape.items(): + op.set_output(idx, IRFullTensor(shape=shape, requires_grad=require_grads).tosub()) + return op + + def partitionable(node: IRDimops, **config): + set_outputs(node) print(f'\n\n# {node.anno}') print(f'testing node: {node}') sub_nodes = node.algorithm('dim').instantiate(**config) From a9bdf58242381e9c2257581e398acbb6a14cab0e Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 19 Dec 2024 06:56:21 +0000 Subject: [PATCH 1789/1892] Merged PR 2346: [Profiler + Test] add tests for profiler add tests for profiler --- examples/llama/README.rst | 6 ++-- nnscaler/profiler/__main__.py | 6 ---- .../{comm_profile.py => benchmark_comm.py} | 8 +++-- nnscaler/profiler/database.py | 28 +++++++++--------- tests/profiler/test_benchmark_comm.py | 29 +++++++++++++++++++ 5 files changed, 52 insertions(+), 25 deletions(-) delete mode 100644 nnscaler/profiler/__main__.py rename nnscaler/profiler/{comm_profile.py => benchmark_comm.py} (97%) create mode 100644 tests/profiler/test_benchmark_comm.py diff --git a/examples/llama/README.rst b/examples/llama/README.rst index 401dd136..cae35b42 100644 --- a/examples/llama/README.rst +++ b/examples/llama/README.rst @@ -162,7 +162,7 @@ If the profiling is skipped, the system will use MI250's data by default. You ca .. code-block:: bash - torchrun --nnodes= --nproc_per_node= -m nnscaler.profiler + torchrun --nnodes= --nproc_per_node= -m nnscaler.profiler.benchmark_comm Checkpoint ========== @@ -288,9 +288,9 @@ For example, you can use the following command to prepare data and train a small # prepare data python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 - + # build the mini model python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini - + # compile and run using data parallelism + zero1 torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K diff --git a/nnscaler/profiler/__main__.py b/nnscaler/profiler/__main__.py deleted file mode 100644 index 0340eaf4..00000000 --- a/nnscaler/profiler/__main__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from . import comm_profile - -comm_profile.main() diff --git a/nnscaler/profiler/comm_profile.py b/nnscaler/profiler/benchmark_comm.py similarity index 97% rename from nnscaler/profiler/comm_profile.py rename to nnscaler/profiler/benchmark_comm.py index ec12b6b0..5871238c 100644 --- a/nnscaler/profiler/comm_profile.py +++ b/nnscaler/profiler/benchmark_comm.py @@ -150,13 +150,17 @@ def profile_primitive(self, primitive: str, ranks: list[int]) -> PrimitiveProfil return sizes_in_mb, times_in_s -def main() -> bool: +def main(): if not is_running_distributed(): - print('Usage: torchrun {TORCHRUN_ARGS} -m nnscaler.profiler') + print('Usage: torchrun {TORCHRUN_ARGS} -m nnscaler.profiler.benchmark_comm') sys.exit(1) nnscaler.init() + if DeviceGroup().world_size == 1: + _logger.warning('Single GPU profiling is not supported') + return + if DeviceGroup().local_rank == 0: nnscaler.utils.set_default_logger_level('INFO') else: diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index c6ae4671..30f76f10 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -305,22 +305,21 @@ def profile(self, node: IRFwOperation, override: bool = False) -> ProfiledMetric if not override and self.exist(node): return self.query(node) - fn, shapes, dtypes, requires_grads, values, kwargs = get_func(node) - - in_mem_info, param_mem_info, buffer_mem_info, in_mem_idx = [], [], [], [] - for idx, t in enumerate(node.inputs()): - if isinstance(t, IRTensor) and t.is_param(): - param_mem_info.append(t.byte_size()) - elif isinstance(t, IRTensor) and t.is_buffer(): - buffer_mem_info.append(t.byte_size()) - elif hasattr(t, 'byte_size'): - in_mem_info.append(t.byte_size()) - in_mem_idx.append(idx) - else: - _logger.debug(f'node {node}: skip input {t}') - # run profiling try: + in_mem_info, param_mem_info, buffer_mem_info, in_mem_idx = [], [], [], [] + fn, shapes, dtypes, requires_grads, values, kwargs = get_func(node) + + for idx, t in enumerate(node.inputs()): + if isinstance(t, IRTensor) and t.is_param(): + param_mem_info.append(t.byte_size()) + elif isinstance(t, IRTensor) and t.is_buffer(): + buffer_mem_info.append(t.byte_size()) + elif hasattr(t, 'byte_size'): + in_mem_info.append(t.byte_size()) + in_mem_idx.append(idx) + else: + _logger.debug(f'node {node}: skip input {t}') fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = \ profile(node, fn, shapes, dtypes, requires_grads, values, **kwargs) except Exception: @@ -333,6 +332,7 @@ def profile(self, node: IRFwOperation, override: bool = False) -> ProfiledMetric # by default, we assume that all the input tensors are saved for backward train_mem_info = copy.deepcopy(in_mem_info) train_mem2in_idx = in_mem_idx + profiled_metrics = ProfiledMetrics(in_mem_info, param_mem_info, buffer_mem_info, fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx) diff --git a/tests/profiler/test_benchmark_comm.py b/tests/profiler/test_benchmark_comm.py new file mode 100644 index 00000000..e53b2f50 --- /dev/null +++ b/tests/profiler/test_benchmark_comm.py @@ -0,0 +1,29 @@ +import torch + +from unittest.mock import patch +import pytest + +from nnscaler.profiler.benchmark_comm import main + +from ..launch_torchrun import launch_torchrun + + +def comm_profile_worker(tmp_path): + def patched_save_path(*args, **kwargs): + return tmp_path + + with patch( + "nnscaler.profiler.benchmark_comm.get_default_profile_path", + side_effect=patched_save_path + ): + main() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_comm_profile(tmp_path): + # just a smoke test + launch_torchrun(2, comm_profile_worker, tmp_path) + assert (tmp_path / 'comm' / 'intra_2.json').exists() + launch_torchrun(2, comm_profile_worker, tmp_path) + comm_bakup_dirs = list(tmp_path.glob('comm-bak-*')) + assert len(comm_bakup_dirs) == 1 From e8d763217b4bb25c9f6d61c275fe15740e23d991 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Fri, 20 Dec 2024 09:07:07 +0000 Subject: [PATCH 1790/1892] Merged PR 2348: [Model Example] Support diff-attention This PR supports diff attention based on llama, which replaces original eager and flash_attn_2 attention implementations. Related work items: #2121 --- README.md | 5 +- .../lm_models/diff_transformer_modifier.py | 248 ++++++++++++++++++ .../llama_modifier.py} | 27 -- examples/llama/lm_models/utils.py | 35 +++ examples/llama/train.py | 11 +- 5 files changed, 294 insertions(+), 32 deletions(-) create mode 100644 examples/llama/lm_models/diff_transformer_modifier.py rename examples/llama/{modeling_modifier.py => lm_models/llama_modifier.py} (92%) create mode 100644 examples/llama/lm_models/utils.py diff --git a/README.md b/README.md index 37630147..50a27325 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ nnScaler is a parallelization engine that compiles a Deep neural network (DNN) m # Latest News nnScaler (also known as CUBE as code name) has been adopted by multiple product and research projects, this section includes some of the latest news from the team and partner projects. * **2024-11-26** nnScaler 0.5 released: https://github.com/microsoft/nnscaler/releases/tag/0.5 +* **2024-10-07** Diff-Transformer utilizes nnScaler for differential attention mechanism: [DIFFERENTIAL TRANSFORMER](https://arxiv.org/abs/2410.05258) * **2024-05-09** YOCO utilizes nnScaler for long-sequence training: [(YOCO)You only cache once: Decoder-decoder architectures for language models](https://arxiv.org/abs/2405.05254) * **2024-04-22** Post training for the long context version of [Phi-3 series](https://arxiv.org/abs/2404.14219) * **2024-02-21** LongRoPE utilizes nnScaler to reduce both the training and inference costs: [LongRoPE: Extending LLM context window beyond 2 million tokens](https://arxiv.org/abs/2402.13753) @@ -75,7 +76,7 @@ Obtain access of Llama-3 model from [HuggingFace](https://huggingface.co/meta-ll ### Code Changes for Parallelization -You can find all the example code at `examples/llama3_8B_128K`. As shown below, a user needs to: +You can find all the example code at `examples/llama`. As shown below, a user needs to: * Wrap the Model: Include loss computation and other necessary components. * Configure Components: Set up the model, optimizer, and dataloader. * Initialize and Start: In the main function, create an nnScaler trainer with the above configurations and start the training process. @@ -135,7 +136,7 @@ def main(args): Then we can start the example, and all the parallelization tasks will be finished by nnScaler automatically. ```shell -cd examples/llama3_8B_128K +cd examples/llama # prepare training data: python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 diff --git a/examples/llama/lm_models/diff_transformer_modifier.py b/examples/llama/lm_models/diff_transformer_modifier.py new file mode 100644 index 00000000..909969a7 --- /dev/null +++ b/examples/llama/lm_models/diff_transformer_modifier.py @@ -0,0 +1,248 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb +try: + from apex.normalization import FusedRMSNorm as RMSNorm +except ModuleNotFoundError: + print("No fused RMSNorm") + from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm + +import logging +import math + +from transformers.utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, +) +from flash_attn import flash_attn_func + +from nnscaler.graph.parser.register import register_op +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func + + +logger = logging.getLogger(__name__) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + bs, n_kv_heads, slen, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, None, :, :] + .expand(bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(bs, n_kv_heads * n_rep, slen, head_dim) + ) + +def lambda_init_fn(depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + +class NNScalerMultiheadDiffAttn(LlamaAttention): + """ + Llama attention module using Diff-transformer attention. This module inherits from `LlamaAttention` as the weights + of the module stays untouched. The only changes are on attention part using implementation of multihead_diffattn.py, + original implementation can be refered to https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_diffattn.py + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.head_dim = self.hidden_size // self.num_heads // 2 + self.scaling = self.head_dim ** -0.5 + + if (self.head_dim * self.num_heads) != self.hidden_size // 2: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.num_key_value_groups, bias=self.config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.num_key_value_groups, bias=self.config.attention_bias) + self._init_rope() + + assert self.layer_idx is not None, "layer_idx must be provided for NNScalerMultiheadDiffAttn" + self.lambda_init = lambda_init_fn(self.layer_idx) + self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + + self.subln = RMSNorm(2 * self.head_dim, eps=1e-5) + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 2 * self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # In case static cache is used, it is an instance attribute. + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + query_states *= self.scaling + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + else: + causal_mask = torch.triu(torch.zeros([q_len, q_len]).float().fill_(float("-inf")).type_as(query_states), 1) + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) + attn_weights = torch.nan_to_num(attn_weights) + attn_weights += causal_mask + attn_weights = F.softmax(attn_weights, dim=-1) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1)) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1)) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, q_len) + attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1] + + attn_output = torch.matmul(attn_weights, value_states) + attn_output = self.subln(attn_output) + attn_output = attn_output * (1 - self.lambda_init) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class NNScalerMultiheadDiffFlashAttn(NNScalerMultiheadDiffAttn): + """ + Llama attention module using Diff-transformer flash attention. This module inherits from `LlamaAttention` as the weights + of the module stays untouched. The only changes are on attention part using implementation of multihead_flashdiff_2.py, + original implementation can be refered to https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_flashdiff_2.py + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning("output_attentions is not supported for NNScalerMultiheadDiffFlashAttn.") + if attention_mask: + logger.warning("attention_mask is not supported for NNScalerMultiheadDiffFlashAttn.") + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 2, self.head_dim) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2) + + # In case static cache is used, it is an instance attribute. + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if query_states.device.type == "cuda": + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + query_states = query_states.reshape(bsz, q_len, self.num_heads, 2, self.head_dim) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, 2, self.head_dim) + q1, q2 = query_states[:, :, :, 0], query_states[:, :, :, 1] + k1, k2 = key_states[:, :, :, 0], key_states[:, :, :, 1] + v1, v2 = value_states[:, :, :, 0], value_states[:, :, :, 1] + + attn11 = flash_attn_func(q1, k1, v1, causal=True) + attn12 = flash_attn_func(q1, k1, v2, causal=True) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = flash_attn_func(q2, k2, v1, causal=True) + attn22 = flash_attn_func(q2, k2, v2, causal=True) + attn2 = torch.cat([attn21, attn22], dim=-1) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1)) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1)) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + attn_output = attn1 - lambda_full * attn2 + attn_output = self.subln(attn_output) + attn_output = attn_output * (1 - self.lambda_init) + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * 2 * self.head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' + + +register_op(flash_attention_anno)(flash_attn_func) diff --git a/examples/llama/modeling_modifier.py b/examples/llama/lm_models/llama_modifier.py similarity index 92% rename from examples/llama/modeling_modifier.py rename to examples/llama/lm_models/llama_modifier.py index 7792e3d8..cce87f3c 100644 --- a/examples/llama/modeling_modifier.py +++ b/examples/llama/lm_models/llama_modifier.py @@ -16,12 +16,8 @@ from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.models.llama.modeling_llama import LlamaAttention, LLAMA_ATTENTION_CLASSES, apply_rotary_pos_emb, LlamaRMSNorm from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, ) @@ -30,24 +26,6 @@ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -try: - from apex.normalization.fused_layer_norm import fused_rms_norm_affine - has_apex = True -except ImportError: - has_apex = False - - -def rmsnorm_fwd(self, hidden_states): - if has_apex: - return fused_rms_norm_affine(hidden_states, self.weight, self.weight.shape, self.variance_epsilon) - else: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - class NNScalerLlamaFlashAttention2(LlamaAttention): """ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays @@ -263,8 +241,3 @@ def llama_flash_attention_anno(query_states, key_states, value_states, attention register_op(llama_flash_attention_anno)(nnscaler_flash_attention_forward) - - -def nnscaler_llama_init(): - LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerLlamaFlashAttention2 - LlamaRMSNorm.forward = rmsnorm_fwd diff --git a/examples/llama/lm_models/utils.py b/examples/llama/lm_models/utils.py new file mode 100644 index 00000000..566ed1bf --- /dev/null +++ b/examples/llama/lm_models/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES, LlamaRMSNorm +from .llama_modifier import NNScalerLlamaFlashAttention2 +from .diff_transformer_modifier import NNScalerMultiheadDiffAttn, NNScalerMultiheadDiffFlashAttn + +try: + from apex.normalization.fused_layer_norm import fused_rms_norm_affine + has_apex = True +except ImportError: + has_apex = False + + +def rmsnorm_fwd(self, hidden_states): + if has_apex: + return fused_rms_norm_affine(hidden_states, self.weight, self.weight.shape, self.variance_epsilon) + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def nnscaler_lm_init(args): + if args.enable_diff_attn: + if args.attn_implementation == "sdpa": + raise ValueError("sdpa is currently not supported in Diff-Transformer") + LLAMA_ATTENTION_CLASSES["eager"] = NNScalerMultiheadDiffAttn + LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerMultiheadDiffFlashAttn + else: + LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerLlamaFlashAttention2 + LlamaRMSNorm.forward = rmsnorm_fwd \ No newline at end of file diff --git a/examples/llama/train.py b/examples/llama/train.py index 777838ad..c998698c 100644 --- a/examples/llama/train.py +++ b/examples/llama/train.py @@ -9,7 +9,7 @@ import huggingface_hub import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling -from modeling_modifier import nnscaler_llama_init +from lm_models.utils import nnscaler_lm_init from chunk_linear_cross_entropy import chunk_linear_cross_entropy from nnscaler.utils import set_default_logger_level @@ -116,7 +116,7 @@ def main(args): set_default_logger_level('INFO') - nnscaler_llama_init() + nnscaler_lm_init(args) ## Setup Dataset ## @@ -349,7 +349,12 @@ def collate_fn(samples): '--attn_implementation', default='flash_attention_2', type=str, - help='attn implementation, can be flash_attention_2, spda or eager', + help='attn implementation, can be flash_attention_2, spda, eager', + ) + parser.add_argument( + '--enable_diff_attn', + action='store_true', + help='enable diff attention implementation, eager is normal diff attention, flash_attention_2 is diff flash attention, and spda diff attention is not currently supported', ) args = parser.parse_args() if args.explore_pipeline and not args.pipeline_pivots: From 0c23d5314a1cbe15ccc3c78225b37674a45c48c6 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 26 Dec 2024 06:06:06 +0000 Subject: [PATCH 1791/1892] Merged PR 2347: [Parser] Clarify only dict/tuple/list/slice are supported collection types Make it clear that only dict/tuple/list/slice are supported collection types. DICT_VALUE_TYPES/DICT_ITEM_TYPES are supported by converting them to tuple. --- nnscaler/graph/function/function.py | 73 ++++++++++++--- nnscaler/graph/parser/mapping.py | 2 +- nnscaler/graph/parser/parser.py | 11 ++- nnscaler/ir/cten.py | 118 ++++++++++++++++--------- nnscaler/runtime/function/function.py | 12 +++ tests/graph/function/test_functions.py | 56 ++++++++++++ tests/ir/test_cten.py | 45 +++++++--- 7 files changed, 247 insertions(+), 70 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index b8a99197..ba75d68b 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -3,7 +3,12 @@ """ Rules: -1. `dict`/`list`/`tuple`/`slice` are the only supported container types for IRObject/IRTensor. +1. `dict`/`list`/`tuple`/`slice` are the only supported container types for IRTensor. + `DictValues`/`DictItems` will be converted to tuple to make it compatible with our system. + TODO: Set is incompatible with our system, but it is really rare to put tensors in a set. + The problem to support set is that hash of IRObject is different with the hash of original value. + So set functions (add/discard/etc) can have different behavior. + `DictKeys` has the same problem, so we don't support it either. 2. IRObjects created in functions should be the outputs. Never create new IRObjects as function inputs/kwargs. 3. `iter` is not compatible with our system, and should be avoided. That's because the values in the iterator have been exausted in tracer @@ -2627,14 +2632,31 @@ def L1Loss(input, target, size_average=None, reduce=None, reduction='mean', sign def MakeTuple(inputs: Iterable, signature=None): - return tuple(inputs) + """ + builtins.tuple + 1. inputs can be an IRObject or a tuple/list of any type(including IRObject) + 2. If inputs is IRObject, return IRPyFunc op + Otherwise, return concrete value. + """ + if not isinstance(inputs, IRObject): + return tuple(inputs) + + ir_value = IR.new('tuple', tuple(inputs.value), is_constant=inputs.is_constant) + return IRPyFunc(signature, inputs=[inputs], outputs=[ir_value]) def MakeList(inputs: Iterable, signature=None): - if isinstance(inputs, Iterable): + """ + builtins.list + 1. inputs can be an IRObject or a tuple/list of any type(including IRObject) + 2. If inputs is IRObject, return IRPyFunc op + Otherwise, return concrete value. + """ + if not isinstance(inputs, IRObject): return list(inputs) - else: - return IRPyFunc(signature, [inputs], [IRObject(value=list(inputs.value))]) + + ir_value = IR.new('list', list(inputs.value), is_constant=inputs.is_constant) + return IRPyFunc(signature, inputs=[inputs], outputs=[ir_value]) def MakeSlice(*inputs: Iterable, signature=None): @@ -3308,16 +3330,43 @@ def Sigmoid(input, *, out=None, signature=None): return IRDimops(Sigmoid, 'sigmoid', signature, annos, [input]) -def Dictkeys(o: Union[Dict, IRObject], signature=None): - assert isinstance(o, dict) or isinstance(o.value, dict), f'the input should be a dict or an IRObject with dict value, but get {o}' - return IRPyFunc(signature, inputs=[o], outputs=[IRObject(name='dictkeys', value=o.value.keys(), is_constant=o.is_constant)]) +def DictKeys(o: Union[Dict, IRObject], signature=None): + signature = 'nnscaler.runtime.function.dict_keys' + + if not isinstance(o, dict) and not (isinstance(o, IRObject) and isinstance(o.value, dict)): + raise ValueError(f'the input should be a dict or an IRObject with dict value, but get {o}') + + # put tuple of keys as value, because tuple supports all operations of dict.keys() + value = tuple(o.keys() if isinstance(o, dict) else o.value.keys()) + + # set is_constant to False to make sure it will never be folded. + ir_value = IR.new('dictkeys', value, is_constant=False) + return IRPyFunc(signature, inputs=[o], outputs=[ir_value]) def DictValues(o: Union[Dict, IRObject], signature=None): - assert isinstance(o, dict) or isinstance(o.value, dict), f'the input should be a dict or an IRObject with dict value, but get {o}' - return IRPyFunc(signature, inputs=[o], outputs=[IRObject(name='dictvalues', value=o.value.values(), is_constant=o.is_constant)]) + signature = 'nnscaler.runtime.function.dict_values' + + if not isinstance(o, dict) and not (isinstance(o, IRObject) and isinstance(o.value, dict)): + raise ValueError(f'the input should be a dict or an IRObject with dict value, but get {o}') + + # put tuple of values as value, because tuple supports all operations of dict.values() + value = o.values() if isinstance(o, dict) else o.value.values() + + # set is_constant to False to make sure it will never be folded. + ir_value = IR.new('dictvalues', value, is_constant=False) + return IRPyFunc(signature, inputs=[o], outputs=[ir_value]) def DictItems(o: Union[Dict, IRObject], signature=None): - assert isinstance(o, dict) or isinstance(o.value, dict), f'the input should be a dict or an IRObject with dict value, but get {o}' - return IRPyFunc(signature, inputs=[o], outputs=[IRObject(name='dictitems', value=o.value.items(), is_constant=o.is_constant)]) + signature = 'nnscaler.runtime.function.dict_values' + + if not isinstance(o, dict) and not (isinstance(o, IRObject) and isinstance(o.value, dict)): + raise ValueError(f'the input should be a dict or an IRObject with dict value, but get {o}') + + # put tuple of kv pairs as value, because tuple supports all operations of dict.items() + value = o.items() if isinstance(o, dict) else o.value.items() + + # set is_constant to False to make sure it will never be folded. + ir_value = IR.new('dictitems', value, is_constant=False) + return IRPyFunc(signature, inputs=[o], outputs=[ir_value]) diff --git a/nnscaler/graph/parser/mapping.py b/nnscaler/graph/parser/mapping.py index 2b460d84..03724a28 100644 --- a/nnscaler/graph/parser/mapping.py +++ b/nnscaler/graph/parser/mapping.py @@ -152,7 +152,7 @@ def exist(signature: str) -> bool: 'builtins.list': function.MakeList, 'builtins.slice': function.MakeSlice, 'builtins.len': function.Len, - 'builtins.dict.keys': function.Dictkeys, + 'builtins.dict.keys': function.DictKeys, 'builtins.dict.values': function.DictValues, 'builtins.dict.items': function.DictItems, diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 046ac2ab..fc9f96a5 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -161,7 +161,6 @@ def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, assert hasattr(node, 'meta') and 'tensor_meta' in node.meta, f"Node {node} should have tensor_meta" meta = node.meta['tensor_meta'] val = IR.new(node.name, meta, - collection_types=(list, tuple, dict, DICT_VALUES_TYPE, DICT_ITEMS_TYPE), tensor_types=(TensorMetadata,), is_constant=is_constant ) @@ -189,11 +188,17 @@ def parse_complex(val: Any, frame: Frame) -> Any: return list(FxModuleParser.parse_complex(t, frame) for t in val) if isinstance(val, dict): return {key: FxModuleParser.parse_complex(val, frame) for key, val in val.items()} + # TODO: Currently slice/DICT_VALUES_TYPE/DICT_ITEMS_TYPE cases are never found. + # We need to find some examples to test them. + if isinstance(val, slice): + return slice(FxModuleParser.parse_complex(val.start, frame), + FxModuleParser.parse_complex(val.stop, frame), + FxModuleParser.parse_complex(val.step, frame)) # because fx node cannot be a dict key, so skip DICT_KEYS_TYPE here if isinstance(val, DICT_VALUES_TYPE): - return {i: FxModuleParser.parse_complex(x, frame) for i, x in enumerate(val)}.values() + return tuple(FxModuleParser.parse_complex(x, frame) for x in val) if isinstance(val, DICT_ITEMS_TYPE): - return {i: FxModuleParser.parse_complex(x, frame) for i, x in val}.items() + return tuple((i, FxModuleParser.parse_complex(x, frame)) for i, x in val) if isinstance(val, torch.fx.Node): return frame.get_var(val.name) return val diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 2b260f92..11bca941 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -590,7 +590,6 @@ def new(cls, name: str, data: Any, *, - collection_types: Tuple = (tuple, list, dict), tensor_types: Tuple = (torch.Tensor,), is_constant: bool = False, requires_grad: Optional[bool] = None, @@ -598,10 +597,15 @@ def new(cls, ) -> Any: """ Convert complex data type of - collection_types (tuple, list, dict) + collection_types (tuple, list, dict, slice, DICT_VALUES_TYPE, DICT_ITEMS_TYPE) tensor_types (has shape/dtype/requires_grad) into intermediate representation object. + Note: + 1. dict_values will be converted into `Tuple` + dict_items will be converted into `Tuple[Tuple[Key, Value]]` + 2. This function cannot be used after logical graph is created (i.e., after parser). + Rule: 1. All tensor-like objects will be converted into IRFullTensor 2. For any complex types, @@ -617,25 +621,33 @@ def new(cls, Args: name (str): the object name data (Any): the complex data structure to be converted - collection_types (Tuple): the complex data types to be converted tensor_types (Tuple): the tensor data types to be converted - tosub(bool): whether convert full tensor to sub-tensor + tosub(bool): whether convert all full tensors to sub-tensor is_constant (bool): whether the object is constant requires_grad (Optional[bool]): the requires_grad flag for the tensor-like object None: will respect the original requires_grad flag True: will set requires_grad to True False: will set requires_grad to False """ - from nnscaler.ir.tensor import IRFullTensor + from nnscaler.ir.tensor import IRFullTensor, IRSubTensor - collection_types = tuple(collection_types) tensor_types = tuple(tensor_types) - supported_collection_types = (tuple, list, dict, _DICT_VALUES_TYPE, _DICT_ITEMS_TYPE) - if any(t not in supported_collection_types for t in collection_types): - raise ValueError(f"Only support converting complex data type of {supported_collection_types}") def _inner(obj) -> Tuple[Any, bool]: - # second return is to know if there is any tensor-like object + """second return is to know if there is any tensor-like object""" + + if isinstance(obj, IRObject) : + assert not isinstance(obj, IRSubTensor), "IRSubTensor is not supported" + # Never reuse existing ir object + # to make sure we have SSA semantics. + if isinstance(obj, IRFullTensor): + new_ir_tensor = obj.like() + if tosub: + new_ir_tensor = new_ir_tensor.tosub() + new_ir_tensor._value = obj.value + return new_ir_tensor, True + else: + return IRObject(name, value=obj.value, is_constant=is_constant), False if isinstance(obj, tensor_types): if requires_grad is None: @@ -656,40 +668,58 @@ def _inner(obj) -> Tuple[Any, bool]: tensor._value = obj # is required in SemanticModel.forward return tensor, True - if isinstance(obj, collection_types): - if isinstance(obj, tuple): - result = [_inner(item) for item in obj] - if not any(r[1] for r in result): - return IRObject(name, value=obj, is_constant=is_constant), False - else: - return tuple(r[0] for r in result), True - if isinstance(obj, list): - result = [_inner(item) for item in obj] - if not any(r[1] for r in result): - return IRObject(name, value=obj, is_constant=is_constant), False - else: - return [r[0] for r in result], True - if isinstance(obj, dict): - if not all(isinstance(key, str) for key in obj.keys()): - raise TypeError(f"only support dict type with str key, but got {obj.keys()}.") - result = {k: _inner(v) for k, v in obj.items()} - if not any(r[1] for r in result.values()): - return IRObject(name, value=obj, is_constant=is_constant), False - else: - return {k: r[0] for k, r in result.items()}, True - if isinstance(obj, _DICT_VALUES_TYPE): - result = [_inner(item) for item in obj] - if not any(r[1] for r in result): - return IRObject(name, value=obj, is_constant=is_constant), False - else: - return {k: r[0] for k, r in enumerate(result)}.values(), True - if isinstance(obj, _DICT_ITEMS_TYPE): - result = {k: _inner(v) for k, v in obj} - if not any(r[1] for r in result.values()): - return IRObject(name, value=obj, is_constant=is_constant), False - else: - return {k: r[0] for k, r in result.items()}.items(), True - # slice will go here, as its start/stop/step are never tensor-like objects + if isinstance(obj, slice): + result = [_inner(item) for item in [obj.start, obj.stop, obj.step]] + if not any(r[1] for r in result): + # try not to re-construct the slice if possible. + unwrapped_value = cls.try_unwrap(obj) if cls.contains_object(obj) else obj + return IRObject(name, value=unwrapped_value, is_constant=is_constant), False + else: + return slice(*[r[0] for r in result]), True + + if isinstance(obj, tuple): + result = [_inner(item) for item in obj] + if not any(r[1] for r in result): + # try not to re-construct the tuple if possible. + unwrapped_value = cls.try_unwrap(obj) if cls.contains_object(obj) else obj + return IRObject(name, value=unwrapped_value, is_constant=is_constant), False + else: + return tuple(r[0] for r in result), True + + if isinstance(obj, list): + result = [_inner(item) for item in obj] + if not any(r[1] for r in result): + # try not to re-construct the list if possible. + unwrapped_value = cls.try_unwrap(obj) if cls.contains_object(obj) else obj + return IRObject(name, value=unwrapped_value, is_constant=is_constant), False + else: + return [r[0] for r in result], True + + if isinstance(obj, dict): + if not all(isinstance(key, str) for key in obj.keys()): + raise TypeError(f"only support dict type with str key, but got {obj.keys()}.") + result = {k: _inner(v) for k, v in obj.items()} + if not any(r[1] for r in result.values()): + # try not to re-construct the dict if possible. + unwrapped_value = cls.try_unwrap(obj) if cls.contains_object(obj) else obj + return IRObject(name, value=unwrapped_value, is_constant=is_constant), False + else: + return {k: r[0] for k, r in result.items()}, True + + if isinstance(obj, _DICT_VALUES_TYPE): + result = [_inner(item) for item in obj] + if not any(r[1] for r in result): + return IRObject(name, value=cls.try_unwrap(tuple(obj)), is_constant=is_constant), False + else: + return tuple(r[0] for r in result), True + + if isinstance(obj, _DICT_ITEMS_TYPE): + result = {k: _inner(v) for k, v in obj} + if not any(r[1] for r in result.values()): + return IRObject(name, value=cls.try_unwrap(tuple(obj)), is_constant=is_constant), False + else: + return tuple((k,r[0]) for k, r in result.items()), True + return IRObject(name, value=obj, is_constant=is_constant), False return _inner(data)[0] diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 22f250a8..f6a8f06f 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -342,6 +342,18 @@ def setitem(__a, *__bc): return __a +def dict_keys(d: dict): + return tuple(d.keys()) + + +def dict_values(d: dict): + return tuple(d.values()) + + +def dict_items(d: dict): + return tuple(d.items()) + + def print_time(content: str): if not CompileFlag.line_timer: return diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 345ac500..1b6277e2 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -1074,3 +1074,59 @@ def test_reshape_view(): op = f(query, 10, 5, 7) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a 1 1 b 1 c 1 1 -> a b c' verify_partition(op) + + +def test_make_collection(): + # test non-irobject + l = [1, 2, 3] + t = (1, 2, 3) + s = slice(*l) + r = F.MakeList(l) + assert r == l + r = F.MakeTuple(t) + assert r == t + r = F.MakeSlice(*l) + assert r == s + + # test irobject items + l = [IRObject(value=1), IRFullTensor([2]), 3] + t = (IRObject(value=1), IRFullTensor([2]), 3) + s = slice(*l) + r = F.MakeList(l) + assert r == l + r = F.MakeTuple(t) + assert r == t + r = F.MakeSlice(*l) + assert r == s + + # test whole irobject + l = IRObject(value=[1, 2, 3]) + t = IRObject(value=(1, 2, 3)) + r = F.MakeList(l, signature='builtins.list') + assert r.output(0).value == l.value + r = F.MakeTuple(t, signature='builtins.tuple') + assert r.output(0).value == t.value + # MakeSlice is not valid. + # F.MakeSlice(s) + + +def test_dict_keys_values_items(): + # normal dict + d = {'a': 1, 'b': 2, 'c': 3} + r = F.DictKeys(d) + assert r.output(0).value == tuple(d.keys()) + r = F.DictValues(d) + assert r.output(0).value == tuple(d.values()) + r = F.DictItems(d) + assert r.output(0).value == tuple(d.items()) + + d = {'a': IRFullTensor([1]), 'b': IRFullTensor([2]), 'c': IRFullTensor([3])} + r = F.DictKeys(d) + assert r.output(0).value == tuple(d.keys()) + r = F.DictValues(d) + # IRFullTensor will be reconstructed, so their ids are different + assert all(x.shape == y.shape and x != y for x, y in zip(r.output(0), d.values())) + r = F.DictItems(d) + # key will never be wrapped with IRObject + # IRFullTensor will be reconstructed, so their ids are different + assert all(x[0] == y[0] and x[1].shape == y[1].shape and x[1] != y[1] for x, y in zip(r.output(0), d.items())) diff --git a/tests/ir/test_cten.py b/tests/ir/test_cten.py index d2671b53..c406f7bc 100644 --- a/tests/ir/test_cten.py +++ b/tests/ir/test_cten.py @@ -70,8 +70,9 @@ def test_from_complex(tosub, requires_grad): obj = IR.new('n', x, tosub=tosub, tensor_types=(), requires_grad=requires_grad) assert type(obj) == IRObject and id(obj.value) == id(x) and not obj.is_constant and obj.name == 'n' - obj = IR.new('n', x, tosub=tosub, collection_types=(tuple,), requires_grad=requires_grad) - assert type(obj) == IRObject and id(obj.value) == id(x) and not obj.is_constant and obj.name == 'n' + x_set = {1, t1, t2} # set is not supported + obj = IR.new('n', x_set, tosub=tosub, requires_grad=requires_grad) + assert type(obj) == IRObject and id(obj.value) == id(x_set) and not obj.is_constant and obj.name == 'n' obj = IR.new('n', [t1, [1, 2, {'a': 3}], (4, 5, {'b': 6, 'c': t2})], tosub=tosub, requires_grad=requires_grad) assert type(obj) == list and len(obj) == 3 @@ -92,42 +93,66 @@ def test_from_complex(tosub, requires_grad): and y['c'].requires_grad == rgt and not y['c'].is_constant \ and y['c'].name == 'n' + obj_item = IRObject('obj_item', value=1, is_constant=False) + obj_tensor_item = IRFullTensor(t1.shape, 'obj_tensor_item') + + obj = IR.new('n', slice(obj_item, 2, obj_item)) + assert type(obj) == IRObject and obj.value == slice(1, 2, 1) + + obj = IR.new('n', [obj_item, 2], tosub=tosub, requires_grad=requires_grad) + assert type(obj) == IRObject and obj.value == [1, 2] + + obj = IR.new('n', [obj_item, 2, obj_tensor_item], tosub=tosub, requires_grad=requires_grad) + assert type(obj) == list and len(obj) == 3 + assert obj[0].value == 1 and obj[0].tid != obj_item.tid + assert obj[1].value == 2 + assert type(obj[2]) == tensor_type and obj[2].parent.tid != obj_tensor_item.tid + + obj = IR.new('n', [t1, obj_item], tosub=tosub, requires_grad=requires_grad) + assert type(obj) == list and len(obj) == 2 + assert type(obj[0]) == tensor_type and id(obj[0].value) == id(t1) + assert type(obj[1]) == IRObject and obj[1].value == 1 and obj[1].tid != obj_item.tid + + obj = IR.new('n', [t1, obj_item, obj_tensor_item], tosub=tosub, requires_grad=requires_grad) + assert type(obj) == list and len(obj) == 3 + assert type(obj[0]) == tensor_type and id(obj[0].value) == id(t1) + assert type(obj[1]) == IRObject and obj[1].value == 1 and obj[1].tid != obj_item.tid + assert type(obj[2]) == tensor_type and obj[2].parent.tid != obj_tensor_item.tid + t1 = TensorMetadata(shape=(), dtype=torch.float, requires_grad=False, stride=None, memory_format=None, is_quantized=None, qparams=None) t2 = TensorMetadata(shape=(2,), dtype=torch.float, requires_grad=True, stride=None, memory_format=None, is_quantized=None, qparams=None) obj = IR.new('n', {'a': t1, 'b': t2}.values(), - collection_types=(DICT_VALUES_TYPE,), tensor_types=(TensorMetadata,), tosub=tosub, requires_grad=requires_grad ) - assert type(obj) == DICT_VALUES_TYPE and len(obj) == 2 - x = list(obj)[0] + assert type(obj) == tuple and len(obj) == 2 + x = obj[0] assert type(x) == tensor_type and id(x.value) == id(t1) \ and x.shape == (1,) and x.origin_shape == () and x.dtype == torch.float \ and x.requires_grad == rg and not x.is_constant \ and x.name == 'n' - y = list(obj)[1] + y = obj[1] assert type(y) == tensor_type and id(y.value) == id(t2) \ and y.shape == (2,) and y.origin_shape == (2,) and y.dtype == torch.float \ and y.requires_grad == rgt and not y.is_constant \ and y.name == 'n' obj = IR.new('n', {'a': t1, 'b': t2}.items(), - collection_types=(DICT_ITEMS_TYPE,), tensor_types=(TensorMetadata,), tosub=tosub, requires_grad=requires_grad ) - assert type(obj) == DICT_ITEMS_TYPE and len(obj) == 2 - x = list(obj)[0] + assert type(obj) == tuple and len(obj) == 2 + x = obj[0] assert x[0] == 'a' x = x[1] assert type(x) == tensor_type and id(x.value) == id(t1) \ and x.shape == (1,) and x.origin_shape == () and x.dtype == torch.float \ and x.requires_grad == rg and not x.is_constant \ and x.name == 'n' - y = list(obj)[1] + y = obj[1] assert y[0] == 'b' y = y[1] assert type(y) == tensor_type and id(y.value) == id(t2) \ From 7899103de629115043c9ca09140c3676c3ab6cf1 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Sun, 29 Dec 2024 12:30:53 +0000 Subject: [PATCH 1792/1892] Merged PR 2350: [BugFix] import flash_attn only when available remove import flash_attn outside check flash_attn available --- examples/llama/lm_models/diff_transformer_modifier.py | 5 ++--- examples/llama/lm_models/utils.py | 4 ++-- examples/llama/requirements.txt | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/llama/lm_models/diff_transformer_modifier.py b/examples/llama/lm_models/diff_transformer_modifier.py index 909969a7..134d1337 100644 --- a/examples/llama/lm_models/diff_transformer_modifier.py +++ b/examples/llama/lm_models/diff_transformer_modifier.py @@ -20,7 +20,6 @@ is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, ) -from flash_attn import flash_attn_func from nnscaler.graph.parser.register import register_op if is_flash_attn_2_available(): @@ -244,5 +243,5 @@ def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs q_anno = kv_anno = 'num_heads' return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' - -register_op(flash_attention_anno)(flash_attn_func) +if is_flash_attn_2_available(): + register_op(flash_attention_anno)(flash_attn_func) \ No newline at end of file diff --git a/examples/llama/lm_models/utils.py b/examples/llama/lm_models/utils.py index 566ed1bf..3b3745a2 100644 --- a/examples/llama/lm_models/utils.py +++ b/examples/llama/lm_models/utils.py @@ -3,8 +3,6 @@ import torch from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES, LlamaRMSNorm -from .llama_modifier import NNScalerLlamaFlashAttention2 -from .diff_transformer_modifier import NNScalerMultiheadDiffAttn, NNScalerMultiheadDiffFlashAttn try: from apex.normalization.fused_layer_norm import fused_rms_norm_affine @@ -26,10 +24,12 @@ def rmsnorm_fwd(self, hidden_states): def nnscaler_lm_init(args): if args.enable_diff_attn: + from .diff_transformer_modifier import NNScalerMultiheadDiffAttn, NNScalerMultiheadDiffFlashAttn if args.attn_implementation == "sdpa": raise ValueError("sdpa is currently not supported in Diff-Transformer") LLAMA_ATTENTION_CLASSES["eager"] = NNScalerMultiheadDiffAttn LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerMultiheadDiffFlashAttn else: + from .llama_modifier import NNScalerLlamaFlashAttention2 LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerLlamaFlashAttention2 LlamaRMSNorm.forward = rmsnorm_fwd \ No newline at end of file diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index 8001637d..f23e777d 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,2 +1,3 @@ transformers==4.40.0 datasets==2.20.0 +tensorboard From 4c75567fb2f3b515a7d13153bc44a57319ce6aeb Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 3 Jan 2025 02:59:59 +0000 Subject: [PATCH 1793/1892] Merged PR 2349: [BugFix] Fix profiler's test Related work items: #2123 --- nnscaler/autodist/cost_database.py | 47 ++++++++++++++----- nnscaler/autodist/model_graph.py | 1 + nnscaler/graph/parser/register.py | 2 +- .../spmd_solver/test_cube_operator.py | 2 +- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 0e24ddbb..82817af3 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -111,11 +111,41 @@ def _profile_graph(dilled_info: str, dev_id: int, partition_degree: int, re_prof result.put(ret) -class CostDatabase: +def _load_comm_data(profile_dir: Path, plan_ngpus: int) -> Dict[str, Dict[str, List[Tuple[float, float]]]]: + ''' + Load communication profile data from the profile directory. If the data is not found, use the default data + at _DEFAULT_COMM_DATA_PATH. Note that in autodist's current design, we only consider the communication + cost across 2^n devices, where n is an positive integer. For example, if plan_ngpus is 8, we will try to + load intra_2.json, intra_4.json, and intra_8.json from the profile directory. If any of the files is not + found, we will use the default data as well. + ''' + def loader(path: Path): + if not os.path.exists(path): + return False, None + info = {} + dev = 2 + while dev <= plan_ngpus: + fname = f'intra_{dev}.json' + if not (path / fname).exists(): + return False, None + with open(path / fname, 'r') as f: + info[fname] = json.load(f) + dev *= 2 + return True, info + + comm_path = profile_dir / 'comm' + success, comm_info = loader(comm_path) + if not success: + _logger.warning(f'Communication profile data not found, using default data at {_DEFAULT_COMM_DATA_PATH}') + success, comm_info = loader(Path(_DEFAULT_COMM_DATA_PATH)) + if not success: + raise RuntimeError(f'Communication profile data is not compatible with plan_ngpus {plan_ngpus}') + return comm_info + - def __init__(self, graph: IRGraph, profile_dir: str, memory_granularity: int, ignore_small_tensor_threshold: int): - self.comm_info = {} +class CostDatabase: + def __init__(self, graph: IRGraph, profile_dir: str, plan_ngpus: int, memory_granularity: int, ignore_small_tensor_threshold: int): self.graph = graph self.profile_dir = Path(profile_dir) @@ -125,13 +155,7 @@ def __init__(self, graph: IRGraph, profile_dir: str, memory_granularity: int, ig self.comp_profile_path.mkdir(parents=True) self.db.load_ops(self.comp_profile_path) - comm_dir = self.profile_dir / 'comm' - if not comm_dir.exists(): - _logger.warning(f'Communication profile data not found, using default data at {_DEFAULT_COMM_DATA_PATH}') - comm_dir = Path(_DEFAULT_COMM_DATA_PATH) - for fname in listdir(comm_dir): - with open(comm_dir / fname, 'r') as f: - self.comm_info[fname] = json.load(f) + self.comm_info = _load_comm_data(self.profile_dir, plan_ngpus) self.memory_granularity = memory_granularity self.ignore_small_tensor_threshold = ignore_small_tensor_threshold @@ -188,8 +212,7 @@ def round(self, mem): if mem % self.memory_granularity == 0: return mem else: - return (mem + self.memory_granularity - ) // self.memory_granularity * self.memory_granularity + return (mem + self.memory_granularity) // self.memory_granularity * self.memory_granularity def filter_then_sum(self, tensor_sizes: Tuple[int], mask=[]): # assert len(tensor_sizes) == len( diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 552815d2..31b287e5 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -526,6 +526,7 @@ def __init__(self, ir_graph: IRGraph, autodist_config: AutoDistConfig): self.cost_database = CostDatabase( self.ir_graph, profile_dir=autodist_config.profile_dir, + plan_ngpus=autodist_config.ngpus, memory_granularity=autodist_config.memory_granularity, ignore_small_tensor_threshold=autodist_config.ignore_small_tensor_threshold, ) diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index 7d9c1661..f042bb18 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -186,7 +186,7 @@ def get_import_path(fn: Callable) -> str: import_path = get_import_path(fn) if import_path == '__main__': raise NotImplementedError( - f"Cannot register function {fsig} in __main__ module. " + f"Cannot register function {fn} in __main__ module. " f"Try to define it in another module and import into main") if is_autograd_apply(fn): diff --git a/tests/autodist/spmd_solver/test_cube_operator.py b/tests/autodist/spmd_solver/test_cube_operator.py index dd404377..ba81036a 100644 --- a/tests/autodist/spmd_solver/test_cube_operator.py +++ b/tests/autodist/spmd_solver/test_cube_operator.py @@ -47,7 +47,7 @@ def test_cube_operator(): dummy_input, attr_savedir=tempdir, constant_folding=True) - cfg = AutoDistConfig(mesh_col=2) + cfg = AutoDistConfig(mesh_col=2, parallel_profile=False) model_graph = ModelGraph(ir_graph, cfg) mock_attention_op = model_graph.operator_list[0] assert mock_attention_op.pos2dim_id((0, 0)) == 'h' From 5795f7107c315cc02ef3f3f0d234fb49e3f9d277 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 6 Jan 2025 07:10:36 +0000 Subject: [PATCH 1794/1892] Merged PR 2352: [UT] fix dis test fix dis test --- tests/graph/tracer/test_dis.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/graph/tracer/test_dis.py b/tests/graph/tracer/test_dis.py index 0af05417..3a253852 100644 --- a/tests/graph/tracer/test_dis.py +++ b/tests/graph/tracer/test_dis.py @@ -113,7 +113,7 @@ def test_extend(): [1,2].extend(a) # in <= python 3.10, opname is CALL_METHOD # in >= python 3.11, opname is CALL - if sys.version_info.minor <= 10: + if sys.version_info < (3, 11): assert a.caller_inst.opname == 'CALL_METHOD' assert a.len_caller_inst.opname == 'CALL_METHOD' else: @@ -185,23 +185,33 @@ def test_bool(): bool(x[c]) # CALL_FUNCTION # in <= python 3.10, opname is CALL_FUNCTION # in >= python 3.11, opname is CALL - if sys.version_info.minor <= 10: + if sys.version_info < (3, 11): assert c.caller_inst.opname == 'CALL_FUNCTION' else: assert c.caller_inst.opname == 'CALL' c and 1 # JUMP_IF_FALSE_OR_POP - assert c.caller_inst.opname == 'JUMP_IF_FALSE_OR_POP' + # in <= python 3.11, opname is JUMP_IF_FALSE_OR_POP + # in >= python 3.12, opname is POP_JUMP_IF_FALSE + if sys.version_info < (3, 12): + assert c.caller_inst.opname == 'JUMP_IF_FALSE_OR_POP' + else: + assert c.caller_inst.opname == 'POP_JUMP_IF_FALSE' c.caller_inst = None c or 1 # JUMP_IF_TRUE_OR_POP - assert c.caller_inst.opname == 'JUMP_IF_TRUE_OR_POP' + # in <= python 3.11, opname is JUMP_IF_TRUE_OR_POP + # in >= python 3.12, opname is POP_JUMP_IF_TRUE + if sys.version_info < (3, 12): + assert c.caller_inst.opname == 'JUMP_IF_TRUE_OR_POP' + else: + assert c.caller_inst.opname == 'POP_JUMP_IF_TRUE' c.caller_inst = None bool(c) # CALL_FUNCTION # in <= python 3.10, opname is CALL_FUNCTION # in >= python 3.11, opname is CALL - if sys.version_info.minor <= 10: + if sys.version_info < (3, 11): assert c.caller_inst.opname == 'CALL_FUNCTION' else: assert c.caller_inst.opname == 'CALL' @@ -227,7 +237,7 @@ def test_bool(): pass # in <= python 3.10, opname is CALL_FUNCTION # in >= python 3.11, opname is CALL - if sys.version_info.minor <= 10: + if sys.version_info < (3, 11): assert c.caller_inst.opname == 'CALL_FUNCTION' else: assert c.caller_inst.opname == 'CALL' From 9caab232aa6b6d99dce7edc3483ba3a765e3ce20 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Thu, 9 Jan 2025 08:12:42 +0000 Subject: [PATCH 1795/1892] Merged PR 2353: Update version numbers and add missing headers --- README.md | 2 +- docs/source/installation.rst | 2 +- docs/source/quickstart.rst | 2 +- nnscaler/version.py | 2 +- tests/autodist/pas/test_multiref_activation.py | 3 +++ tests/profiler/test_benchmark_comm.py | 3 +++ 6 files changed, 10 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 50a27325..88d55bce 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ nnScaler is a parallelization engine that compiles a Deep neural network (DNN) m # Latest News nnScaler (also known as CUBE as code name) has been adopted by multiple product and research projects, this section includes some of the latest news from the team and partner projects. -* **2024-11-26** nnScaler 0.5 released: https://github.com/microsoft/nnscaler/releases/tag/0.5 +* **2025-01-08** nnScaler 0.6 released: https://github.com/microsoft/nnscaler/releases/tag/0.6 * **2024-10-07** Diff-Transformer utilizes nnScaler for differential attention mechanism: [DIFFERENTIAL TRANSFORMER](https://arxiv.org/abs/2410.05258) * **2024-05-09** YOCO utilizes nnScaler for long-sequence training: [(YOCO)You only cache once: Decoder-decoder architectures for language models](https://arxiv.org/abs/2405.05254) * **2024-04-22** Post training for the long context version of [Phi-3 series](https://arxiv.org/abs/2404.14219) diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 6cedb7fb..d71553c9 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -12,7 +12,7 @@ The wheel package is hosted on `GitHub release Date: Thu, 16 Jan 2025 04:22:15 +0000 Subject: [PATCH 1796/1892] Merged PR 2355: [Tracer] support subscriptable type hint --- nnscaler/graph/tracer/wrap_utils.py | 4 ++++ tests/graph/tracer/test_type_hint.py | 33 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 tests/graph/tracer/test_type_hint.py diff --git a/nnscaler/graph/tracer/wrap_utils.py b/nnscaler/graph/tracer/wrap_utils.py index f495fd04..710871f5 100644 --- a/nnscaler/graph/tracer/wrap_utils.py +++ b/nnscaler/graph/tracer/wrap_utils.py @@ -298,6 +298,10 @@ def __hash__(self): else: setattr(clz_wrapper_clz, name, attr) + # to support subscriptable type hint like func(x: dict[str, str]) + if hasattr(clz, '__class_getitem__'): + setattr(clz_wrapper_clz, '__class_getitem__', clz.__class_getitem__) + wrapped_cls_to_orig_cls[clz_wrapper_clz] = clz return clz_wrapper_clz diff --git a/tests/graph/tracer/test_type_hint.py b/tests/graph/tracer/test_type_hint.py new file mode 100644 index 00000000..b8560185 --- /dev/null +++ b/tests/graph/tracer/test_type_hint.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import namedtuple + +import torch +from nnscaler.graph.parser.converter import to_fx_graph + +from ...utils import replace_all_device_with + + +def func_with_type_hint(x: list[torch.Tensor]) -> torch.Tensor: + return x[0] + x[1] + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(10, 5) + + def forward(self, data: dict[str, torch.Tensor]): + return func_with_type_hint([self.fc1(data['x']), self.fc2(data['x'])]) + + +@replace_all_device_with('cpu') +def test_type_hint(): + model = SimpleModel() + dummy_input = {'data': {'x': torch.rand(10)}} + traced_graph = to_fx_graph(model, dummy_input) + + # just check if we can trace a model contains original type hint + assert True From ac454264fd995750ba5ef02236fadca91e83e99d Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 16 Jan 2025 04:22:40 +0000 Subject: [PATCH 1797/1892] Merged PR 2345: [Tracer + Codegen] support grad mode & autocast --- nnscaler/codegen/module/module.py | 129 ++++++- nnscaler/graph/graph.py | 29 +- nnscaler/graph/tracer/concrete_tracer.py | 6 +- nnscaler/graph/tracer/metadata.py | 48 ++- nnscaler/graph/tracer/orig_func.py | 1 + nnscaler/graph/tracer/wrap_utils.py | 20 ++ nnscaler/runtime/function/function.py | 6 + tests/graph/tracer/test_dict_iter.py | 52 +++ .../test_gencode_ctx_manager.py | 317 ++++++++++++++++++ 9 files changed, 582 insertions(+), 26 deletions(-) create mode 100644 tests/graph/tracer/test_dict_iter.py create mode 100644 tests/parallel_module/test_gencode_ctx_manager.py diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 8214c978..4ac9c853 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import List, Optional, Tuple, Dict, Any +from typing import List, Optional, Tuple, Dict, Any, Literal import more_itertools import logging import copy @@ -17,18 +17,20 @@ from nnscaler.graph.graph import IRSegment from nnscaler.graph.parser.register import CustomizedOps +from nnscaler.graph.tracer.metadata import AutocastInfo, GradMode from nnscaler.execplan import ExecutionPlan from nnscaler.execplan.execplan import ExeReuseCell from nnscaler.codegen.syntax.symtable import SymbolTable -from nnscaler.codegen.syntax.blocks import ClassBlock, FunctionBlock +from nnscaler.codegen.syntax.blocks import ClassBlock, FunctionBlock, Block from nnscaler.codegen.emit import FuncEmission from nnscaler.codegen.module.autograd import AutogradAdapterCodeGen from nnscaler.codegen.lifecycle import LifeCycle from nnscaler.flags import CompileFlag +from nnscaler.utils import fields from nnscaler import __version__ as runtime_version @@ -837,23 +839,136 @@ def _emit_nodes(self, nodes: List[IRCell], lifecycle: LifeCycle, runtime_devid: The fields storing intermediate codes that are populated by this method: - NONE """ - node_codes = [] - for node in nodes: + def has_op_context_info(node: IRCell): + if node.op_context is not None: + return True + else: + # if node is a IRFwOperation convert from fx graph (not create by nnscaler), it should have `op_context` field, + # otherwise the node is created by nnscaler. + return False + + def emit_context_manager(node: IRCell): + # Consider to emit torch.no_grad and torch.autocast context manager. + # + # There have two kinds of return values: + # 1. str, "", an empty str means there is no need to add context manager, current node is under default context. + # 2. str, "with xxx as yyy, aaa as bbb:", current node is under a specific context, need to add context manager. + assert node.op_context is not None + grad_mode = GradMode(**node.op_context["grad_mode"]) + autocast_info = AutocastInfo(**node.op_context["autocast_info"]) + if grad_mode.grad_mode and autocast_info.nesting == 0: + return "" + else: + ctx_managers = [] + if grad_mode.inference_mode: + ctx_managers.append("torch.inference_mode()") + elif grad_mode.no_grad_mode: + ctx_managers.append("torch.no_grad()") + + # NOTE: assume all tensor on cuda device, just care about cuda autocast now + if autocast_info.nesting > 0: + ctx_managers.append(f"torch.autocast(device_type='cuda', dtype={autocast_info.cuda_dtype!r}, enabled={autocast_info.cuda_enabled!r}, cache_enabled={autocast_info.cache_enabled!r})") + else: + assert autocast_info.nesting == 0, f'get autocast nesting state: {autocast_info.nesting}' + code = "with " + ", ".join(ctx_managers) + ":" + return code + + def emit_node(node): + node_code = [] # execute if isinstance(node, IRFwOperation): code = self.emit_fnode(node, runtime_devid=runtime_devid, plan_ndevs=len(self.devices), runtime_ndevs=self.runtime_ndevs, prefix_attr='self.') - node_codes += code + node_code += code elif isinstance(node, IRAdapter): # for adapters inside an IRSegment, we don't apply async communication to it # as it is mostly in critical path. code = self.emit_adapter(node, async_op=False) - node_codes += code + node_code += code else: raise RuntimeError(f"unexpected type {type(node)} in IRSegment") # release tensors_to_del = lifecycle.release_tensors_after_node(node) if len(tensors_to_del) > 0: - node_codes.append(self.emit_release(tensors_to_del)) + node_code.append(self.emit_release(tensors_to_del)) + return node_code + + def insert_codes_under_ctx(ctx_code, codes): + if ctx_code != "" and codes: + with Block(ctx_code) as cblock: + cblock.insert_body(codes) + # [''] to make a new line, make the generated code pretty + return [''] + cblock.code + [''] + else: + return codes + + node_codes = [] + current_context_manager_code = "" + current_codes = [] + for node in nodes: + if has_op_context_info(node): + new_context_manager_code = emit_context_manager(node) + if current_context_manager_code != new_context_manager_code: + node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) + current_codes = emit_node(node) + current_context_manager_code = new_context_manager_code + else: + current_codes.extend(emit_node(node)) + else: + # Node without op context infortmation means it is inserted by nnscaler, not convert from original fx graph, + # for example, multiref node and adapter node, currently for nodes inserted by nnscaler we have the following assumption: + # - the fx traced graph shows a context manager's impact to tensors, + # - the behavior of an inserted node is determined by tensor properties, like data type and requires grad, + # - combine the two points together, it is safe to put the inserted node in the default context. + # Base on this assumption, when we meet a node without op context infortmation, will force break the current code context scope + # and emit current node code without context managers. + # + # Here is an example about the inserted node code generation, the inserted node is a multiref node of y, + # and all the inserted nodes will under default context (without context managers): + # """ + # # original code + # with torch.no_grad(): + # y = func1(x) + # z = func2(x) + # + # # generated code + # with torch.no_grad(): + # y = func1(x) + # y_1, y_2 = nnscaler.runtime.function.multiref(y, 2) + # with torch.no_grad(): + # z = func2(x) + # """ + # + # This way have two risks: + # 1. the assumption is no longer valid in subsequent development. + # 2. nodes that originally belonged to the same context need to be executed in a continuous context and cannot be interrupted. + # Fortunately, these two risks have not yet occurred for the current inserted nodes and supported context managers. + # + # Please note that if one entire context scope is split to multiple sub-scope, it may introduce additional overhead. + # Here is an overhead example about torch.autocast: + # """ + # # original code + # with torch.autocast(device_type='cuda', dtype=torch.float32, enabled=True, cache_enabled=True): + # y = func1(x, a) + # z = func2(x, b) + # + # # generated code + # with torch.autocast(device_type='cuda', dtype=torch.float32, enabled=True, cache_enabled=True): + # y = func1(x, a) + # ... + # with torch.autocast(device_type='cuda', dtype=torch.float32, enabled=True, cache_enabled=True): + # z = func2(x, b) + # """ + # In the original code, x might cast to float32 and used by both func1 and func2, + # but in generated code, because the scope is interrupted, the two new scopes cannot share the x cast result, + # then there might have a additional cast x to float32 for func2. + # + # For torch.no_grad and torch.inference_mode, the overhead is not significant, so we can ignore it. + # + # TODO: all inserted nodes should have its op context field. + node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) + node_codes += emit_node(node) + current_codes = [] + node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) return node_codes diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 803859f1..27b4dc2a 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -340,12 +340,7 @@ def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> Lis rtensor.grad = copy.copy(itensor.grad) # insert forward for fnode in fnodes: - if isinstance(node, IRFwOperation): - fnode.recompute = node.recompute - if isinstance(node.comment, str): - fnode.comment = node.comment - fnode.module_stack = node.module_stack - fnode.device = node.device + self.copy_node_meta_info(node, fnode) fsegment.replace(node, fnodes) # insert backward bsegment: IRSegment = fsegment.mirror @@ -402,12 +397,7 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], # insert forward node fsegment: IRSegment = self.segment(node) for fnode in fnodes: - if isinstance(node, IRFwOperation): - fnode.recompute = node.recompute - if isinstance(node.comment, str): - fnode.comment = node.comment - fnode.module_stack = node.module_stack - fnode.device = node.device + self.copy_node_meta_info(node, fnode) fsegment.replace(node, fnodes) if node.mirror is None: return fnodes @@ -1219,3 +1209,18 @@ def checksum(self, strict: bool = True) -> str: states = str((max_tensor_id, max_cell_id, self.extra_repr())) checksum = hashlib.md5(states.encode()).hexdigest() return checksum + + @staticmethod + def copy_node_meta_info(src_node: Union[IRFwOperation, IRDataOperation], dest_node: Union[IRFwOperation, IRDataOperation]): + """ + Copy meta information from src_node to dest_node. + Current copy fields: ['recompute', 'comment', 'op_context', 'module_stack', 'device'] + """ + if isinstance(src_node, IRFwOperation): + dest_node.recompute = src_node.recompute + if isinstance(src_node.comment, str): + dest_node.comment = src_node.comment + if src_node.op_context is not None: + dest_node.op_context = src_node.op_context + dest_node.module_stack = src_node.module_stack + dest_node.device = src_node.device diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index eef7209f..9f4c0e0a 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -37,7 +37,7 @@ from . import pytree_utils, orig_func, wrap_utils from .frame_utils import get_frame_record from .function_patcher import FunctionPatcher -from .metadata import EmptyResult, extract_results_metadata, get_op_context +from .metadata import EmptyResult, extract_metadata from .operator_patcher import OperatorPatcherContext from .torch_fx_patcher import TorchFXPatcher, ExtraSEFPatcher, side_effectful_inplace_ops from .trace_strategy import TRACE_STRATEGY @@ -148,7 +148,6 @@ def create_node(self, kind : str, target : Target, check_for_mutable_operation(target, args, kwargs) node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) - node.meta['op_context'] = get_op_context() # TODO node_name_to_scope will be depricated in favor of # node.meta['nn_module_stack'] self.node_name_to_scope[node.name] = ( @@ -165,7 +164,7 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): # unwrap all proxy in the node result here, because no proxy should be record in the tensor metadata node_result = pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, node_result) - extract_results_metadata(node_result, node) + extract_metadata(node_result, node) return node @compatibility(is_backward_compatible=True) @@ -663,6 +662,7 @@ def wrap_never_wrap_function(func, *args, **kwargs): # for cuda versions of pytorch, autograd.Function.apply should be reverted by delattr self.patcher.patch_method(torch.autograd.Function, "apply", wrap_utils.create_wrapped_autograd_apply(self), deduplicate=False, revert_by_del=True) self.patcher.patch_method(torch, "_assert", wrap_utils.torch_assert_wrapper, deduplicate=False) + self.patcher.patch_method(torch, "autocast", wrap_utils.torch_autocast_wrapper_clz, deduplicate=False) self.patcher.patch_method(builtins, "map", wrap_utils.map_wrapper_clz, deduplicate=False) self.patcher.patch_method(builtins, "enumerate", wrap_utils.enumerate_wrapper_clz, deduplicate=False) diff --git a/nnscaler/graph/tracer/metadata.py b/nnscaler/graph/tracer/metadata.py index 62e2ce9c..f6f898a2 100644 --- a/nnscaler/graph/tracer/metadata.py +++ b/nnscaler/graph/tracer/metadata.py @@ -3,7 +3,6 @@ from typing import Any, Dict, NamedTuple, Optional, Tuple from dataclasses import dataclass, asdict -import copy import torch from torch.fx.node import Node @@ -22,6 +21,40 @@ class EmptyResult: pass +@dataclass +class GradMode: + grad_mode: bool + no_grad_mode: bool + inference_mode: bool + + @classmethod + def from_context(cls): + return cls(torch.is_grad_enabled(), not torch.is_grad_enabled(), torch.is_inference_mode_enabled()) + + +@dataclass +class AutocastInfo: + # the nesting number of autocast context, if =0, means it is not under autocast context + # torch use this field to determine whether the cache needs to be cleaned + # nnscaler use this field to determine whether generating autocast context manager in code + nesting: int + + cache_enabled: bool + cpu_enabled: bool + cpu_dtype: torch.dtype + cuda_enabled: bool + cuda_dtype: torch.dtype + # NOTE: not care about "xpu" and "hpu" now + + @classmethod + def from_context(cls): + # use function pair [torch.autocast_increment_nesting, torch.autocast_decrement_nesting] to get the nesting number + torch.autocast_increment_nesting() + return cls(torch.autocast_decrement_nesting(), torch.is_autocast_cache_enabled(), + torch.is_autocast_cpu_enabled(), torch.get_autocast_cpu_dtype(), + torch.is_autocast_enabled(), torch.get_autocast_gpu_dtype()) + + @dataclass class OpContext: """ @@ -39,7 +72,12 @@ class OpContext: def get_op_context() -> OpContext: - return asdict(_GLOBAL_OP_CONTEXT) + """ + Get op context information. + Please note that current only tracked context managers that modify the tensor properties, for example, modify the requires_grad, dtype, + so that nnscaler can generate context manager code for them safety. + """ + return asdict(_GLOBAL_OP_CONTEXT) | {'grad_mode': asdict(GradMode.from_context()), 'autocast_info': asdict(AutocastInfo.from_context())} class TensorMetadata(NamedTuple): @@ -99,16 +137,18 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: return TensorMetadata(shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) -def extract_results_metadata(results: Any, node: Node): +def extract_metadata(results: Any, node: Node): if results is not EmptyResult: res = tuple(results) if isinstance(results, (DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE)) else results meta = pytree_utils.tree_map_only(torch.Tensor, _extract_tensor_metadata, res) # we should get the meta info of the inner element of these type obj if isinstance(results, DICT_KEYS_TYPE): - meta = {i: m for i, m in enumerate(meta)}.keys() + meta = {m: i for i, m in enumerate(meta)}.keys() if isinstance(results, DICT_VALUES_TYPE): meta = {i: m for i, m in enumerate(meta)}.values() if isinstance(results, DICT_ITEMS_TYPE): meta = {i: m for i, m in meta}.items() node.meta['tensor_meta'] = meta node.meta['type'] = type(results) + + node.meta['op_context'] = get_op_context() diff --git a/nnscaler/graph/tracer/orig_func.py b/nnscaler/graph/tracer/orig_func.py index cad0680b..04fa4a79 100644 --- a/nnscaler/graph/tracer/orig_func.py +++ b/nnscaler/graph/tracer/orig_func.py @@ -48,6 +48,7 @@ torch_assert = torch._assert torch_Size = torch.Size torch_finfo = torch.finfo +torch_autocast = torch.autocast import importlib import_module = importlib.import_module diff --git a/nnscaler/graph/tracer/wrap_utils.py b/nnscaler/graph/tracer/wrap_utils.py index 710871f5..28ff2e37 100644 --- a/nnscaler/graph/tracer/wrap_utils.py +++ b/nnscaler/graph/tracer/wrap_utils.py @@ -112,6 +112,9 @@ class LeafWrapInfo: torch.nn.ParameterDict.__len__: LeafWrapInfo([], False, builtins.len), torch.nn.ParameterDict.__iter__: LeafWrapInfo([], False, builtins.iter), torch.nn.ParameterDict.__contains__: LeafWrapInfo([], False, operator.contains), + + torch.autocast.__enter__: LeafWrapInfo([], False, None), + torch.autocast.__exit__: LeafWrapInfo([], False, None), } @@ -544,6 +547,23 @@ def __hash__(self): wrapped_cls_to_orig_cls[type_wrapper_clz] = orig_func.type +# wrap autocast to make it support proxy input and the related node will be DCE in DCE stage. +class torch_autocast_wrapper_clz: + # used to track the original class + _fx_wrapped_ori_clz = orig_func.torch_autocast + + def __new__(cls, *args, **kwargs): + return orig_func.torch_autocast(*args, **kwargs) + + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(orig_func.type)) + + def __hash__(self): + return id(self) + +wrapped_cls_to_orig_cls[torch_autocast_wrapper_clz] = orig_func.torch_autocast + + @functools.wraps(orig_func.torch_assert) def torch_assert_wrapper(condition, message): if orig_func.isinstance(condition, cct.ConcreteProxy): diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index f6a8f06f..51ef947c 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -1,6 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +""" +The functions in this file might be inserted as node to graph, to ensure that the inserted node can generate the correct code, +please following the assumption: + - should execute under default context (not under for example, torch.no_grad) no matter what the producer and consumer context are. +""" + from contextlib import contextmanager from typing import Optional, List, Tuple, Union, Any import torch diff --git a/tests/graph/tracer/test_dict_iter.py b/tests/graph/tracer/test_dict_iter.py new file mode 100644 index 00000000..9cb60613 --- /dev/null +++ b/tests/graph/tracer/test_dict_iter.py @@ -0,0 +1,52 @@ +import pytest + +from nnscaler.graph.tracer import concrete_trace, wrap_utils +from nnscaler.graph.tracer.metadata import DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE +import torch + + +def wrap_as_dict(x): + return {'x': x} + +def dict_keys_as_input(x_keys): + return x_keys + +def dict_values_as_input(x_values): + return x_values + +def dict_items_as_input(x_items): + return x_items + + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.randn(5, 10)) + + def forward(self, x): + x_dict = wrap_as_dict(x) + x_keys = [_ for _ in dict_keys_as_input(x_dict.keys())] + x_values = [_ for _ in dict_values_as_input(x_dict.values())] + x_items = [_ for _ in dict_items_as_input(x_dict.items())] + + x = self.param + x_dict[x_keys[0]] + x_values[0] + x_items[0][1] + return x + + +def test_dict_iter_metadata(): + graph = concrete_trace(TestModule(), + {'x': torch.randn(5, 10)}, + autowrap_leaf_function={ + wrap_as_dict: wrap_utils.LeafWrapInfo([], True, None), + dict_keys_as_input: wrap_utils.LeafWrapInfo([], True, None), + dict_values_as_input: wrap_utils.LeafWrapInfo([], True, None), + dict_items_as_input: wrap_utils.LeafWrapInfo([], True, None) + }, + strategy='cpu') + nodes = list(graph.graph.nodes) + dict_keys_as_input_node = nodes[2] + assert isinstance(dict_keys_as_input_node.meta['tensor_meta'], DICT_KEYS_TYPE) + dict_valus_as_input_node = nodes[6] + assert isinstance(dict_valus_as_input_node.meta['tensor_meta'], DICT_VALUES_TYPE) + dict_items_as_input_node = nodes[10] + assert isinstance(dict_items_as_input_node.meta['tensor_meta'], DICT_ITEMS_TYPE) diff --git a/tests/parallel_module/test_gencode_ctx_manager.py b/tests/parallel_module/test_gencode_ctx_manager.py new file mode 100644 index 00000000..eba20786 --- /dev/null +++ b/tests/parallel_module/test_gencode_ctx_manager.py @@ -0,0 +1,317 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest + +import ast +import tempfile +import torch + +from pathlib import Path +from nnscaler.parallel import parallelize, ComputeConfig, ParallelModule, build_optimizer +from .common import init_distributed, init_random +from .test_end2end import merge_cube_result +from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively +from ..utils import clear_dir_on_rank0 + + +class CtxManagerModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.param_1 = torch.nn.Parameter(torch.rand(16, 16)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + r_1 = torch.matmul(x, self.param_1) + r_2 = torch.matmul(y, self.param_1) + with torch.no_grad(): + r_3 = torch.matmul(r_1, self.param_1) + with torch.enable_grad(): + r_4 = torch.matmul(r_2, self.param_1) + with torch.autocast(r_4.device.type): + r_5 = r_3 * r_4 + r = r_1 * r_2 * r_3 * r_4 * r_5 + return torch.matmul(r, self.param_1).norm() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_ctx_manager_codegen(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, check_ctx_manager_codegen, tempdir) + + +def dummy_data(): + return {'x': torch.rand(4, 16), 'y': torch.rand(4, 16)} + + +def check_ctx_manager_codegen(tempdir): + init_distributed() + m = CtxManagerModel() + m_new = parallelize( + m, + dummy_data(), + 'data', + ComputeConfig(2, 4), + gen_savedir=tempdir, + load_module=False + ) + for i in range(4): + code = get_gencode(tempdir, CtxManagerModel, i) + ########## Generated Model Code ########### + # from typing import * + # from pathlib import Path + # import torch + # import torch.utils.checkpoint as ckpt + # import nnscaler + # import _operator + # from numpy import inf + # import builtins + + # runtime_version = '0.6' + + + # import nnscaler.graph.function.wrapnn + + # import apex.normalization.fused_layer_norm + + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = False + # nmicros_per_scheduler_step = 1 + # rank = 0 + + # def __init__(self, init_params=True, *, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 2]) + # self.init_group(ranks=[1, 3]) + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('param_1_62', torch.nn.Parameter(torch.empty((16, 16), dtype=torch.float32))) + # self.add_full_map('param_1_62', 5, True, 'param_1', (16, 16), (slice(0, 16, None), slice(0, 16, None)), 1) + + + # self.wreducer312 = nnscaler.runtime.adapter.Reducer(ranks=[0, 2], reduce_op='sum', async_op=async_op, zero=False, max_bucket_size_bytes=max_bucket_size_bytes, zero_use_reduce_scatter=zero_use_reduce_scatter, zero_ngroups=1) + # self.wreducer312.add_param(self.param_1_62) + # self.add_reducer(self.wreducer312) + + # self._post_init(init_params) + + # def segment308(self, x_75, y_78): + # # auto_multiref + # param_1_106, param_1_107, param_1_108, param_1_109, param_1_110 = nnscaler.runtime.function.multiref(self.param_1_62, times=5) + # x_166 = nnscaler.runtime.adapter.nn.split_allgather(x_75, dim=0, ranks=[0, 1]) + # del x_75 + # param_1_109 = nnscaler.runtime.adapter.nn.identity_allreduce(param_1_109, ranks=[0, 1]) + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 21, in forward, r_1 = torch.matmul(x, self.param_1) + # matmul_168 = torch.matmul(x_166, param_1_109) + # del param_1_109, x_166 + # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref + # matmul_226, matmul_194 = nnscaler.runtime.function.multiref(matmul_168, times=2) + # del matmul_168 + # y_180 = nnscaler.runtime.adapter.nn.split_allgather(y_78, dim=0, ranks=[0, 1]) + # del y_78 + # param_1_110 = nnscaler.runtime.adapter.nn.identity_allreduce(param_1_110, ranks=[0, 1]) + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 22, in forward, r_2 = torch.matmul(y, self.param_1) + # matmul_1_182 = torch.matmul(y_180, param_1_110) + # del param_1_110, y_180 + # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref + # matmul_1_202, matmul_1_228 = nnscaler.runtime.function.multiref(matmul_1_182, times=2) + # del matmul_1_182 + + # with torch.no_grad(): + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 24, in forward, r_3 = torch.matmul(r_1, self.param_1) + # matmul_2_196 = torch.matmul(matmul_194, param_1_106) + # del param_1_106, matmul_194 + + # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref + # matmul_2_216, matmul_2_242 = nnscaler.runtime.function.multiref(matmul_2_196, times=2) + # del matmul_2_196 + # param_1_107 = nnscaler.runtime.adapter.nn.identity_allreduce(param_1_107, ranks=[0, 1]) + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 26, in forward, r_4 = torch.matmul(r_2, self.param_1) + # matmul_3_204 = torch.matmul(matmul_1_202, param_1_107) + # del param_1_107, matmul_1_202 + # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref + # matmul_3_252, matmul_3_218 = nnscaler.runtime.function.multiref(matmul_3_204, times=2) + # del matmul_3_204 + + # with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True, cache_enabled=True): + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 28, in forward, r_5 = r_3 * r_4 + # mul_220 = torch.mul(matmul_2_216, matmul_3_218) + # del matmul_2_216, matmul_3_218 + + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 29, in forward, r = r_1 * r_2 * r_3 * r_4 * r_5 + # mul_1_230 = torch.mul(matmul_226, matmul_1_228) + # del matmul_226, matmul_1_228 + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 29, in forward, r = r_1 * r_2 * r_3 * r_4 * r_5 + # mul_2_244 = torch.mul(mul_1_230, matmul_2_242) + # del matmul_2_242, mul_1_230 + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 29, in forward, r = r_1 * r_2 * r_3 * r_4 * r_5 + # mul_3_254 = torch.mul(mul_2_244, matmul_3_252) + # del matmul_3_252, mul_2_244 + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 29, in forward, r = r_1 * r_2 * r_3 * r_4 * r_5 + # mul_4_264 = torch.mul(mul_3_254, mul_220) + # del mul_220, mul_3_254 + # param_1_108 = nnscaler.runtime.adapter.nn.identity_allreduce(param_1_108, ranks=[0, 1]) + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 30, in forward, return torch.matmul(r, self.param_1).norm() + # matmul_4_272 = torch.matmul(mul_4_264, param_1_108) + # del param_1_108, mul_4_264 + # matmul_4_72 = nnscaler.runtime.adapter.nn.allgather_split(matmul_4_272, dim=0, ranks=[0, 1]) + # del matmul_4_272 + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 30, in forward, return torch.matmul(r, self.param_1).norm() + # norm_61 = torch.norm(matmul_4_72, p='fro', dim=None, keepdim=False, out=None, dtype=None) + # del matmul_4_72 + # return norm_61 + + # def reducer312(self): + # self.wreducer312.sync_grads() + # return + + # def _forward_impl(self, x, y): + # norm_61 = self.segment308(x, y) + # return norm_61 + + # with torch.no_grad() as _nnscaler_no_grad: + def first_with_node_check(node: ast.With): + hit_no_grad = False + for item in node.items: + if isinstance(item.context_expr, ast.Call): + func = item.context_expr.func + if isinstance(func, ast.Attribute): + module_name = func.value.id if isinstance(func.value, ast.Name) else None + context_manager_name = func.attr + if module_name == 'torch' and context_manager_name == 'no_grad': + assert not hit_no_grad + hit_no_grad = True + else: + assert False, f"detect unexcepted context manager in first with code: {module_name}.{context_manager_name}" + assert hit_no_grad, f"context manager torch.no_grad not existed" + + # with torch.no_grad() as _nnscaler_no_grad, torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True, cache_enabled=True) as _nnscaler_autocast: + def second_with_node_check(node: ast.With): + hit_no_grad = False + hit_autocast = False + for item in node.items: + if isinstance(item.context_expr, ast.Call): + func = item.context_expr.func + if isinstance(func, ast.Attribute): + module_name = func.value.id if isinstance(func.value, ast.Name) else None + context_manager_name = func.attr + if module_name == 'torch' and context_manager_name == 'no_grad': + assert not hit_no_grad + hit_no_grad = True + elif module_name == 'torch' and context_manager_name == 'autocast': + assert not hit_autocast + hit_autocast = True + else: + assert False, f"detect unexcepted context manager in second with code: {module_name}.{context_manager_name}" + assert hit_no_grad, f"context manager torch.no_grad not existed" + assert hit_autocast, f"context manager torch.autocast not existed" + + with_node_count = 0 + for node in ast.walk(ast.parse(code)): + if isinstance(node, ast.With): + if with_node_count == 0: + first_with_node_check(node) + elif with_node_count == 1: + second_with_node_check(node) + else: + assert False, f"detect unexcepted third with code" + with_node_count += 1 + + +def get_gencode(cubesave_dir, module_class, index=0): + from nnscaler.parallel import _PARALLEL_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME + from pathlib import Path + + namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' + outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) + filecontent = (outdir /f'gencode{index}.py').read_text() + return filecontent + + +def _train_cube_one_sample(model: ParallelModule, mbs): + init_random() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) + data = [] + init_random() + data_size = mbs + for _ in range(data_size): + data.append(tuple(dummy_data().values())) + chunks = [data[i:i + mbs] for i in range(0, len(data), mbs)] + results = [] + for _, x in enumerate(chunks): + model.train() + losses = model.train_step(x) + print(f'loss {_}: {losses}') + optimizer.step() + gnorm = optimizer.clip_gnorm() + grads = {n: p.grad for n, p in model.named_parameters()} + model._add_extra_state(grads, '') + weights = {n: p.data for n, p in model.named_parameters()} + model._add_extra_state(weights, '') + results.append(clone_to_cpu_recursively([grads, weights, gnorm])) + optimizer.zero_grad() + return results + + +def gpu_worker_cube_one_sample(): + init_distributed() + init_random() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ctx_manager') as tempdir: + init_random() + model = CtxManagerModel() + model = parallelize( + model, + dummy_data(), + pas_policy='tp', + compute_config= ComputeConfig( + 2, 2, + use_end2end=True, + ), + gen_savedir=tempdir + ) + model.cuda() + train_result = _train_cube_one_sample(model, 1) + return train_result + + +def _train_ga(model, update_freq, data_size): + init_random() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + data = [] + init_random() + for _ in range(data_size): + data.append(dummy_data()) + results = [] + for i, x in enumerate(data): + model.train() + loss = model(**x) + print(f'loss {i}: {loss}') + loss.backward() + if i % update_freq == update_freq - 1: + optimizer.step() + grads = {n: p.grad for n, p in model.named_parameters()} + weights = {n: p.data for n, p in model.named_parameters()} + # gnorm calculation doesn't support float64, so let's skip it + results.append(clone_to_cpu_recursively([grads, weights, torch.tensor(0.0)])) + optimizer.zero_grad() + return results + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_loss_scaling(): + torch.cuda.set_device(0) + torch.set_default_device(f'cuda:0') + init_random() + model = CtxManagerModel() + ga4_result = _train_ga(model, 1, 1) + assert len(ga4_result) == 1 + ga4_grads = ga4_result[0][0] + + cube2_results = launch_torchrun(2, gpu_worker_cube_one_sample) + cube2_result = merge_cube_result({k: v for k, v in cube2_results.items()}) + assert len(cube2_result) == 1 + cube2_grads = cube2_result[0][0] + assert len(cube2_grads) == len(ga4_grads) + for k in cube2_grads: + assert torch.allclose(cube2_grads[k].cpu(), ga4_grads[k].cpu(), atol=1e-6, rtol=1e-6) From fc91977c23d93c32f6ee7cfe1c110ead99533feb Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 20 Jan 2025 02:16:29 +0000 Subject: [PATCH 1798/1892] Merged PR 2356: [Schedule] add interleaved 1f1b --- nnscaler/graph/graph.py | 4 +- nnscaler/graph/parser/parser.py | 2 +- nnscaler/graph/schedule/interleaved_1f1b.py | 324 ++++++++++++++++++ nnscaler/graph/schedule/predefined.py | 154 ++++++++- nnscaler/graph/schedule/schedplan.py | 40 ++- nnscaler/graph/tracer/concrete_tracer.py | 2 +- tests/graph/schedule/test_interleaved_1f1b.py | 172 ++++++++++ 7 files changed, 688 insertions(+), 10 deletions(-) create mode 100644 nnscaler/graph/schedule/interleaved_1f1b.py create mode 100644 tests/graph/schedule/test_interleaved_1f1b.py diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 27b4dc2a..39f52dc7 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -910,8 +910,8 @@ def staging(self, nodes: Tuple[IRFwOperation]): Returns: None """ - assert all(isinstance(node, IRFwOperation) for node in nodes), \ - f"Find node is not IRFwOperation or IRDataOperation: {node}" + for node in nodes: + assert isinstance(node, IRFwOperation), f"Expected node to be IRFwOperation, but got {node}" assert all(node in self._nodes for node in nodes), \ f"Exist node is not in graph nodes" starts = list(self._nodes.index(node) for node in nodes) diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index fc9f96a5..dc1566af 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -71,7 +71,7 @@ def parse(module: torch.fx.GraphModule, frame.push_var() # shape propagation - assert isinstance(dummy_inputs, dict), "Expected dummy inputs to parse module" + assert isinstance(dummy_inputs, dict), f"Expected dummy inputs to parse module, but got {dummy_inputs} of type {type(dummy_inputs)}" output_nodes = [node for node in module.graph.nodes if node.op == 'output'] # currently fx graph always has only one output # even if a tuple/list is returned, it is still just one output diff --git a/nnscaler/graph/schedule/interleaved_1f1b.py b/nnscaler/graph/schedule/interleaved_1f1b.py new file mode 100644 index 00000000..20c87374 --- /dev/null +++ b/nnscaler/graph/schedule/interleaved_1f1b.py @@ -0,0 +1,324 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# CREDITS: most of the code is from torch: https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/schedules.py + +from collections import defaultdict +from typing import Dict, List, Optional +from enum import Enum +from typing import NamedTuple +import re +import logging + +logger = logging.getLogger(__name__) + + +class _ComputationType(Enum): + # TODO(whc) rename to _ActType? + FORWARD = 1 + BACKWARD_INPUT = 2 + BACKWARD_WEIGHT = 3 + UNSHARD = 4 + RESHARD = 5 + SEND_F = 6 + RECV_F = 7 + SEND_B = 8 + RECV_B = 9 + FULL_BACKWARD = 10 + + def __str__(self): + str_map = { + _ComputationType.FORWARD: "F", + _ComputationType.BACKWARD_INPUT: "I", + _ComputationType.BACKWARD_WEIGHT: "W", + _ComputationType.UNSHARD: "UNSHARD", + _ComputationType.RESHARD: "RESHARD", + _ComputationType.SEND_F: "SEND_F", + _ComputationType.RECV_F: "RECV_F", + _ComputationType.SEND_B: "SEND_B", + _ComputationType.RECV_B: "RECV_B", + _ComputationType.FULL_BACKWARD: "B", + } + return str_map[self] + + @staticmethod + def from_str(action): + if action == "F": + return _ComputationType.FORWARD + elif action == "I": + return _ComputationType.BACKWARD_INPUT + elif action == "W": + return _ComputationType.BACKWARD_WEIGHT + elif action == "UNSHARD": + return _ComputationType.UNSHARD + elif action == "RESHARD": + return _ComputationType.RESHARD + elif action == "SEND_F": + return _ComputationType.SEND_F + elif action == "RECV_F": + return _ComputationType.RECV_F + elif action == "SEND_B": + return _ComputationType.SEND_B + elif action == "RECV_B": + return _ComputationType.RECV_B + elif action == "B": + return _ComputationType.FULL_BACKWARD + else: + raise RuntimeError(f"Invalid computation type {action}") + + +FORWARD = _ComputationType.FORWARD +BACKWARD_INPUT = _ComputationType.BACKWARD_INPUT +BACKWARD_WEIGHT = _ComputationType.BACKWARD_WEIGHT +UNSHARD = _ComputationType.UNSHARD +RESHARD = _ComputationType.RESHARD +SEND_F = _ComputationType.SEND_F +RECV_F = _ComputationType.RECV_F +SEND_B = _ComputationType.SEND_B +RECV_B = _ComputationType.RECV_B +FULL_BACKWARD = _ComputationType.FULL_BACKWARD + +# Convenience shorthand for compute actions only since they are used in 'simple schedule format' +F = FORWARD +I = BACKWARD_INPUT +W = BACKWARD_WEIGHT +B = FULL_BACKWARD + +# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index) +_action_regex = re.compile( + r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)" +) + + +class _Action(NamedTuple): + stage_index: int + computation_type: _ComputationType + microbatch_index: Optional[int] = None + + def __repr__(self): + repr = str(self.stage_index) + repr += str(self.computation_type) + if self.microbatch_index is not None: + repr += str(self.microbatch_index) + return repr + + @staticmethod + def from_str(action_string: str): + """ + Reverse of __repr__ + + String should be formatted as [stage][action type][(microbatch)] + e.g. `2F0`, `1UNSHARD`, `3SEND_F1` + """ + action_string = action_string.strip() + if match := _action_regex.match(action_string): + stage_index, computation_type, microbatch_index = match.groups() + return _Action( + int(stage_index), + _ComputationType.from_str(computation_type), + int(microbatch_index) if len(microbatch_index) else None, + ) + elif action_string == "": + return None + raise RuntimeError( + f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0" + ) + + +def _get_1f1b_rank_ops( + n_local_stages, + pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches=0, + enable_zero_bubble=False, +): + # All stages start with handling microbatch 0 + fwd_stage_mb_index: Dict[int, int] = defaultdict(int) + bwd_stage_mb_index: Dict[int, int] = defaultdict(int) + weight_stage_mb_index: Dict[int, int] = defaultdict(int) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: List[Optional[_Action]] = [None for _ in range(rank)] + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. + # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + if enable_zero_bubble: + post_warmup_ops = pp_group_size - rank - 1 + + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + + backward_op_ids = [] + weight_op_count = 0 + + FULL_BACKWARD_OR_BACKWARD_INPUT = ( + BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD + ) + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index) + ) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + if not enable_zero_bubble: + rank_ops.append(None) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + + while enable_zero_bubble and weight_op_count < len(backward_op_ids): + weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count]) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, _ComputationType.BACKWARD_WEIGHT, weight_mb_index + ) + ) + weight_op_count += 1 + + return rank_ops + + +# use `self` here since it is a member function in torch, refer to +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/schedules.py#L1999 +def _calculate_single_rank_operations(self, rank): + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 2 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.info( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + ) diff --git a/nnscaler/graph/schedule/predefined.py b/nnscaler/graph/schedule/predefined.py index 9d2ce581..4d0e7cd2 100644 --- a/nnscaler/graph/schedule/predefined.py +++ b/nnscaler/graph/schedule/predefined.py @@ -31,7 +31,7 @@ def sched_1f1b(graph: IRGraph, num_microbatches: int, num_stages: int) -> Schedu raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) fsegs = [seg for seg in segments if seg.isfw()] - assert len(fsegs) == num_stages, f"Mismatch of forward segement number ({len(fsegs)}) with num_stages ({num_stages})" + assert len(fsegs) == num_stages, f"Mismatch of forward segment number ({len(fsegs)}) with num_stages ({num_stages})" # describe schedule sched = SchedulePlan(graph, num_microbatches) @@ -54,6 +54,152 @@ def sched_1f1b(graph: IRGraph, num_microbatches: int, num_stages: int) -> Schedu sched.finish() return sched + @staticmethod + def sched_1f1b_interleaved(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: + """ + 1F1B interleaved scheduling. The graph should be staged into segments. You can refer to the paper + [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/pdf/2104.04473) + for more details. Different from the 1f1b scheduling where each pipeline device group corresponds to exactly + one forward segment and its backward segment, in 1F1B interleaved scheduling, each pipeline device group + maintains multiple forward segments and their corresponding backward segments. + + Notations: + - `n`: number of pipeline device groups + - `m`: number of pipeline stages (model is split into m parts) + - `k`: number of local stages in each pipeline device group, thus `m = n * k` + - `p`: number of micro-batches in a training step, currently constrained to be a multiple of n + - `q`: in interleaved 1f1b, p is divided into q groups, each group contains n micro-batches, thus `p = n * q` + - (x)F(y) denotes the x-th forward segment with micro-batch y, B denotes the backward segment. + + Therefore: + - each pipeline rank runs k * p * 2 steps in total (will run `k` forward and `k` backward for each micro-batch). + - i-th segment is placed at *i mod n*-th device group + - if i mod n = j mod n, then i-th segment and j-th segment should be placed at the same device group. + + Furthermore, the paper also assumes that all forward segments have a similar execution time, and all backward + segments have a similar execution time. + + For 1f1b like schedule, computation process on each rank is divided into three parts: + 1. warmup: composed of forward stages + 2. steady: a list of pairs of a forward stage and a backward stage + 3. cooldown: composed of backward stages + + Here is an example with 4 devices, 8 stages, and 4 micro-batches. Note that in this schedule representation, + the steady part is different from the formulation that is currently used in nnScaler. + 0F0 0F1 0F2 0F3 4F0 4F1 4F2 | 4F3 4B0 | 4B1 4B2 4B3 0B0 0B1 0B2 0B3 + 1F0 1F1 1F2 1F3 5F0 5F1 5F2 | 5B0 5F3 | 5B1 5B2 5B3 1B0 1B1 1B2 1B3 + 2F0 2F1 2F2 2F3 6F0 6F1 6B0 | 6F2 6B1 | 6F3 6B2 6B3 2B0 2B1 2B2 2B3 + 3F0 3F1 3F2 3F3 7F0 7B0 7F1 | 7B1 7F2 | 7B2 7F3 7B3 3B0 3B1 3B2 3B3 + + In modern LLMs, the backward segments takes more time than the forward segments. As a result, + the schedule can be adjusted like that in Megatron-LM (this schedule is clearer and easier to + calculate the end to end span). In addition, `4F3` above is executed much earlier in the runtime. + As a result, we prefer to make the schedule closer to the real execution order. The schedule is like: + 0F0 0F1 0F2 0F3 4F0 4F1 4F2 4F3 4B0 4B1 4B2 4B3 0B0 0B1 0B2 0B3 + 1F0 1F1 1F2 1F3 5F0 5F1 5F2 5F3 5B0 5B1 5B2 5B3 1B0 1B1 1B2 1B3 + 2F0 2F1 2F2 2F3 6F0 6F1 6F2 6B0 6F3 6B1 6B2 6B3 2B0 2B1 2B2 2B3 + 3F0 3F1 3F2 3F3 7F0 7B0 7F1 7B1 7F2 7B2 7F3 7B3 3B0 3B1 3B2 3B3 + In this representation, #step for the 3 parts in each rank is: + | rank | warmup | steady | cooldown | + | 0 | 8 | 0 | 8 | + | 1 | 8 | 0 | 8 | + | 2 | 6 | 4 | 6 | + | 3 | 4 | 8 | 4 | + There is a subtle difference between the two schedules on memory usage on rank 1 and rank 2. However, the order + of the difference is a small constant (1 forward stage's memory footprint). Considering the memory is bounded + by the first device group, we can omit the difference for now. + + In torch, another schedule representation is used, which is equivalent to the Megatron-LM schedule. + Note the blank step between 3F3 and 7F0 will be 'squeezed' in runtime. + 0F0 0F1 0F2 0F3 4F0 4F1 4F2 4F3 4B0 4B1 4B2 4B3 0B0 0B1 0B2 0B3 + 1F0 1F1 1F2 1F3 5F0 5F1 5F2 5F3 5B0 5B1 5B2 5B3 1B0 1B1 1B2 1B3 + 2F0 2F1 2F2 2F3 6F0 6F1 6F2 6B0 6F3 6B1 6B2 6B3 2B0 2B1 2B2 2B3 + 3F0 3F1 3F2 3F3 7F0 7B0 7F1 7B1 7F2 7B2 7F3 7B3 3B0 3B1 3B2 3B3 + + Here is another example when num_microbatches is 8: + 0F0 0F1 0F2 0F3 4F0 4F1 4F2 4F3 0F4 0F5 0F6 4B0 0F7 4B1 4F4 4B2 4F5 4B3 4F6 0B0 4F7 0B1 0B2 0B3 4B4 4B5 4B6 4B7 0B4 0B5 0B6 0B7 + 1F0 1F1 1F2 1F3 5F0 5F1 5F2 5F3 1F4 5B0 1F5 5B1 1F6 5B2 1F7 5B3 5F4 1B0 5F5 1B1 5F6 1B2 5F7 1B3 5B4 5B5 5B6 5B7 1B4 1B5 1B6 1B7 + 2F0 2F1 2F2 2F3 6F0 6F1 6F2 6B0 6F3 6B1 2F4 6B2 2F5 6B3 2F6 2B0 2F7 2B1 6F4 2B2 6F5 2B3 6F6 6B4 6F7 6B5 6B6 6B7 2B4 2B5 2B6 2B7 + 3F0 3F1 3F2 3F3 7F0 7B0 7F1 7B1 7F2 7B2 7F3 7B3 3F4 3B0 3F5 3B1 3F6 3B2 3F7 3B3 7F4 7B4 7F5 7B5 7F6 7B6 7F7 7B7 3B4 3B5 3B6 3B7 + In this setting, #step for the 3 parts in each rank is: + | rank | warmup | steady | cooldown | + | 0 | 10 | 12 | 10 | + | 1 | 8 | 16 | 8 | + | 2 | 6 | 20 | 6 | + | 3 | 4 | 24 | 4 | + + Based on the example above, we can deduce the whole schedule from the last rank. + For the last pipeline rank, the steady part starts as long as it receives the last forward stage for the + 0-th micro-batch (we index from 0). It is easy to calculate that the last rank's warmup part takes n * (k - 1) + steps. + After the warmup part, the steady part begins: + - in the 0th round, it executes the (k-1)th stage's forward and backward stage for 0th micro batch groups + - in the 1st round, it executes the 0th stage's forward for 1st micro batch group and (k-2)th stage's backward for 0th micro batch group + - in the 2nd round, it executes the 1st stage's forward for 1st micro batch group and (k-3)th stage's backward for 0th micro batch group + - ... + - in the kth round, it executes the (k-1)th stage's forward and backward for 1st micro batch group + - ... + - in the ((q-1) * k)th round, it executes the (k-1)th stage's forward and backward for (q-1)th micro batch group + In all, the steady part takes ((q-1) * k + 1) * n * 2 steps. + The cooldown part for the last rank is symmetric to the warmup part. It takes n * (k - 1) steps to execute the backward + stage for the last micro-batch group on 0-th to (k-2)-th stages. + + Based on the analysis of the last rank, we can deduce the execution order for remaining ranks. For example, for the + (n-2)th rank. The steady part takes 2 less 1f1b pairs than the last rank. Since + - it depends on the backward stage in 0-th 1f1b pair finishes on the last rank + - the forward stage finishes one step earlier than the last rank + As a result, there will be + - 2 additional forward steps in the warmup part to provide the data that 0th and 1st 1f1b pair need for the last rank + - 2 additional backward steps in the cooldown part to consume the data that last and (last-1)th 1f1b pair produce for the last rank + + In general, for the i-th rank: + - the warmup part takes min(n * (k - 1) + 2 * (n - 1 - i), p * k) steps for forward computation + - the steady part is composed of (p * k - warmup_steps) 1f1b pairs + - the cooldown_steps equals to warmup_steps for backward computation + """ + if num_microbatches <= 0: + raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") + segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + fsegs = [seg for seg in segments if seg.isfw()] + assert len(fsegs) == num_stages, f"Mismatch of forward segment number ({len(fsegs)}) with num_stages ({num_stages})" + # collect segments by device assignment info + devs2segs = {} + for seg in fsegs: + cur_devs = tuple(seg.device) + for devs in devs2segs.keys(): + for dev in devs: + if dev in cur_devs: + assert devs == cur_devs, f"find illegal device assignment: {devs} vs {cur_devs} in 1f1b interleaved scheduling" + devs2segs.setdefault(cur_devs, []).append(seg) + assert num_microbatches % len(devs2segs) == 0, f"num_microbatches: {num_microbatches} should be a multiple of the number of pipeline groups: {len(devs2segs)}" + + sched = SchedulePlan(graph, num_microbatches) + # an adapter class to fit in torch's implementation + class ScheduleInfo: + def __init__(self, pp_group_size, num_stages, num_micro_batch): + self.pp_group_size = pp_group_size + self.n_local_stages = num_stages // pp_group_size + self.num_of_rounds = max(1, num_micro_batch // pp_group_size) + self.microbatches_per_round = num_micro_batch // self.num_of_rounds + self._n_microbatches = num_micro_batch + assert num_micro_batch % self.num_of_rounds == 0 + + from nnscaler.graph.schedule.interleaved_1f1b import _calculate_single_rank_operations + pp_group_size = len(devs2segs) + schedule_info = ScheduleInfo(pp_group_size, num_stages, num_microbatches) + for rank in range(pp_group_size): + rank_ops = _calculate_single_rank_operations(schedule_info, rank) + for step, op in enumerate(rank_ops): + # use None to represent the blank step + if op is None: continue + seg = fsegs[op.stage_index] + if str(op.computation_type) == 'B': + seg = seg.mirror + sched.add_segment(seg, op.microbatch_index, step) + + sched.finish() + return sched + @staticmethod def sched_1f1b_plus(graph: IRGraph, num_microbatches: int, num_stages: int) -> SchedulePlan: """1F1B Plus Scheduling. @@ -61,7 +207,7 @@ def sched_1f1b_plus(graph: IRGraph, num_microbatches: int, num_stages: int) -> S f0 f0 f1 f1 f2 f2 | f3 f3 b0 | b1 b2 b3 f0 f0 f1 f1 f2 f2 | f3 b0 f3 | b1 b2 b3 f0 f1 f0 f2 f1 b0 | f3 f2 b1 | f3 b2 b3 - f0 f1 f0 f2 b0 f1 | f3 b1 f2 | b2 f3 b + f0 f1 f0 f2 b0 f1 | f3 b1 f2 | b2 f3 b3 """ if num_microbatches <= 0: raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") @@ -141,7 +287,7 @@ def sched_gpipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> Sched raise ValueError(f"expected num_microbatches > 0, but got {num_microbatches} ") segments: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) fsegs = [seg for seg in segments if seg.isfw()] - assert len(fsegs) == num_stages, "Mismatch of forward segement number with num_stages" + assert len(fsegs) == num_stages, "Mismatch of forward segment number with num_stages" # describe schedule sched = SchedulePlan(graph, num_microbatches) @@ -226,7 +372,7 @@ def sched_infer_pipe(graph: IRGraph, num_microbatches: int, num_stages: int) -> """ fsegs: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) assert all(seg.isfw() for seg in fsegs), f"Detect backward. The predefined scheduling only applies for inference" - assert len(fsegs) == num_stages, "Mismatch of forward segement number with num_stages" + assert len(fsegs) == num_stages, "Mismatch of forward segment number with num_stages" # describe schedule sched = SchedulePlan(graph, num_microbatches) fwait_steps = [sid for sid in range(num_stages)] diff --git a/nnscaler/graph/schedule/schedplan.py b/nnscaler/graph/schedule/schedplan.py index 69914684..4c20ed63 100644 --- a/nnscaler/graph/schedule/schedplan.py +++ b/nnscaler/graph/schedule/schedplan.py @@ -1,6 +1,41 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +"""In pipeline parallelism, we want to execute micro-batches of samples on multiple devices +simultaneously. To achieve this, the computation dataflow graph is split into +several sub-graph (`IRSegment` in nnScaler), each of which has been assigned to related +devices (by default, we assumed a forward segment and its corresponding backward segment +share a same device placement). +In a pipeline schedule, each segment is replicated and annotated with different micro-batch +index (`Block` in nnScaler). Therefore, the schedule plan is a list of `Block` on each device. +To be valid, the blocks should satisfy the data dependency constraints, i.e., if block A +and block B share the same micro-batch index, and segment in B is dependent on segment in A, +then block A should be executed before block B. Note that there is no data dependency +between blocks with different micro-batch index. +Given the schedule (execution orders of blocks in each device), we can estimate the execution +time of the whole pipeline by: +- the execution time (span) of each block +- the data dependency between blocks belonging to different device groups +- the communication time between devices + +In nnScaler's current implementation, we use a global view to define and validate the schedule +plan. To be more specific +- the execution time is discreted into integer steps +- each block takes an integer span to finish execution +- the communication time between blocks is omitted +- a valid schedule plan should satisfy that if block A depends on block B, then the start step + of block A should not be earlier than end time of block B. + +This implementation may be improved in the future, since +- it is hard to estimate the span of each block before the actual execution +- the communication time between blocks cannot be ignored in a real system, especially when + the network bandwidth is limited +- the real start time of each block may be different from that defined in the schedule plan. + Dependencies are materialized by inserting send and recv adapters between blocks. As a result, + a block may start as soon as its input data is ready, rather than strictly following the start + time in the schedule plan. +""" + from typing import Dict, List, Optional, Tuple, Set from nnscaler.ir.cten import IRCell @@ -384,12 +419,13 @@ def validate(self) -> bool: """ Validate the plan to check if it satisfies data dependency - @return valid bool + Returns: + valid (bool): whether the plan is valid """ for block1 in self._blocks: for block2 in self._blocks: if self._dependency.depends(block1, block2): - if self.start(block1) >= self.start(block2): + if self.start(block1) + block1.span > self.start(block2): return False return True diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index 9f4c0e0a..d5bb0656 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -615,7 +615,7 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): if isinstance(fn, MethodType): fn = fn.__func__ - assert isinstance(fn, FunctionType) + assert isinstance(fn, FunctionType), f"Expected a function, but got {fn} with type {type(fn)}" fn_globals = fn.__globals__ # run before it gets patched diff --git a/tests/graph/schedule/test_interleaved_1f1b.py b/tests/graph/schedule/test_interleaved_1f1b.py new file mode 100644 index 00000000..593ff23f --- /dev/null +++ b/tests/graph/schedule/test_interleaved_1f1b.py @@ -0,0 +1,172 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.distributed +import torch.nn as nn +import tempfile +import shutil +import contextlib +import pytest +from pathlib import Path + + +import nnscaler +from nnscaler import parallelize, ComputeConfig, ParallelModule +from nnscaler.parallel import build_optimizer, sync_grad_when, merge_state_dicts +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph import IRGraph +from nnscaler.ir.adapter import IRAdapter +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.ir.operator import IRFwOperation, IRDataOperation +from nnscaler.graph.segment import IRSegment +from nnscaler.graph.schedule.predefined import PredefinedSched +from tests.utils import clear_dir_on_rank0, init_random +from tests.launch_torchrun import torchrun +from tests.parallel_module.common import assert_equal +from tests.parallel_module.test_gencode import _gencode_contains +from tests.launch_torchrun import launch_torchrun, clone_to_cpu_recursively + + +class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(32, 32, bias=False) + self.fc2 = torch.nn.Linear(32, 32, bias=False) + self.fc3 = torch.nn.Linear(32, 32, bias=False) + self.fc4 = torch.nn.Linear(32, 32, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + x = self.fc4(x) + return x.sum() + + +def policy_1f1b(graph, cfg): + data_loader, fc1, fc2, fc3, fc4, loss = graph.nodes()[:6] + graph.staging([fc1, fc3,]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + sub_nodes = graph.replicate(data_loader, 2) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + graph.assign(fc1, 0) + graph.assign(fc2, 0) + + identity = stages[1].nodes()[0] + graph.assign(identity, 1) + graph.assign(fc3, 1) + graph.assign(fc4, 1) + graph.assign(loss, 1) + + PredefinedSched.sched_1f1b(graph, cfg.pas_config['n_micro_batches'], len(stages)) + + return graph + + +def policy_1f1b_interleaved(graph, cfg): + data_loader, fc1, fc2, fc3, fc4, loss = graph.nodes()[:6] + graph.staging([fc1, fc2, fc3, fc4]) + stages = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + sub_nodes = graph.replicate(data_loader, 2) + for i, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, i) + + graph.assign(fc1, 0) + + identity = stages[1].nodes()[0] + graph.assign(identity, 1) + graph.assign(fc2, 1) + + identity = stages[2].nodes()[0] + graph.assign(identity, 0) + graph.assign(fc3, 0) + + identity = stages[3].nodes()[0] + graph.assign(identity, 1) + graph.assign(fc4, 1) + graph.assign(loss, 1) + + PredefinedSched.sched_1f1b_interleaved(graph, cfg.pas_config['n_micro_batches'], len(stages)) + + return graph + + +def _train_pp(model: ParallelModule, num_replicas, rank): + mbs = model.nmicros_per_scheduler_step + assert model.use_scheduler + init_random() + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + data = [] + DATA_SIZE = mbs * 4 + for _ in range(DATA_SIZE): + data.append( + torch.randn((2, 32), device='cuda', dtype=torch.float32) + ) + data = [data[i] for i in range(rank, DATA_SIZE, num_replicas)] + chunks = [data[i:i + mbs] for i in range(0, len(data), mbs)] + results = [] + for _, x in enumerate(chunks): + model.train() + _ = model.train_step(x) + optimizer.step() + optimizer.zero_grad() + results.append(clone_to_cpu_recursively(model.state_dict())) + return results + + +def worker_pipeline_2(n_micro_batches): + nnscaler.init() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + m = Model() + m.train() + trace_data = torch.randn([2, 32], dtype=torch.float32, device=torch.cuda.current_device()) + cfg = ComputeConfig(2, 2, use_end2end=True, pas_config=dict(n_micro_batches=n_micro_batches)) + + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_1f1b_interleaved') as tempdir: + pm_1f1b = parallelize( + m, + {'x': trace_data}, + policy_1f1b, + cfg, + reuse='override', + gen_savedir=tempdir, + instance_name='1f1b', + ).cuda() + pm_1f1b_interleaved = parallelize( + m, + {'x': trace_data}, + policy_1f1b_interleaved, + cfg, + reuse='override', + gen_savedir=tempdir, + instance_name='1f1b_interleaved', + ).cuda() + + results_1f1b = _train_pp(pm_1f1b, 1, 0) + results_1f1b_interleaved = _train_pp(pm_1f1b_interleaved, 1, 0) + return (results_1f1b, results_1f1b_interleaved) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('n_micro_batches', [2, 4, 6]) +def test_interleaved_1f1b(n_micro_batches): + results = launch_torchrun(2, worker_pipeline_2, n_micro_batches) + results_1f1b0, results_1f1b_interleaved0 = results[0] + results_1f1b1, results_1f1b_interleaved1 = results[1] + + assert len(results_1f1b0) == len(results_1f1b_interleaved0) + + for i in range(len(results_1f1b0)): + assert_equal( + merge_state_dicts([results_1f1b0[i], results_1f1b1[i]]), + merge_state_dicts([results_1f1b_interleaved0[i], results_1f1b_interleaved1[i]]) + ) From 2952d2bd2f5cfecd7296e22c64f0a728e097b340 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 20 Jan 2025 06:36:11 +0000 Subject: [PATCH 1799/1892] Merged PR 2359: [BugFix] Insert multiref for single stage correctly --- nnscaler/autodist/apis.py | 38 ++++---- ....json => replicated_and_partition_pp.json} | 0 .../pas/replicated_and_partition_spmd.json | 88 +++++++++++++++++++ tests/autodist/pas/test_multiref_param.py | 27 ++++-- 4 files changed, 127 insertions(+), 26 deletions(-) rename tests/autodist/pas/{replicated_and_partition.json => replicated_and_partition_pp.json} (100%) create mode 100644 tests/autodist/pas/replicated_and_partition_spmd.json diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index beae2a56..7fcd229a 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -178,25 +178,29 @@ def subtensor_desc(t): break assert find_desc, f'node {consumer} not found in any stage' + # add multiref for shared parameters across stages + # note that we have constrained that shared parameters cannot be partitioned in SPMDSolver, other input tensors + # belonging to the same operator can be partitioned. For example, in some LLMs, the embedding matrix is shared + # with the output layer. In this case, the batch dim / seq dim of the activation tensor can be partitioned. + for ftensor, stage_info in tensor_split_info.items(): + if not ftensor.is_param(): + continue + splits = set() + find_replicated = False + for stage_splits in stage_info.values(): + splits.update(stage_splits) + if any(s[0] == 'REPLICATED' for s in stage_splits): + find_replicated = True + splits = list(splits) + # For safety, we will add multiref when detecting shared param are all replicated for pipeline parallelism. + # The reason is that stages may have different number of devices, it is hard to synchronize gradients directly + # by inserting reducers although weights are all REPLICAED. + if len(splits) > 1 or (len(pp_desc.spmd_descs) > 1 and find_replicated): + _logger.info(f'add multiref for shared param {ftensor}') + graph.multiref(ftensor, comment='shared param') + # graph staging if len(pp_desc.spmd_descs) > 1: - # add multiref for shared parameters across stages - # note that we have constrained that shared parameters cannot - # be partitioned in SPMDSolver. - for ftensor, stage_info in tensor_split_info.items(): - if not ftensor.is_param(): - continue - splits = set() - find_replicated = False - for stage_splits in stage_info.values(): - splits.update(stage_splits) - if any(s[0] == 'REPLICATED' for s in stage_splits): - find_replicated = True - splits = list(splits) - if len(splits) > 1 or find_replicated: - _logger.info(f'add multiref for shared param {ftensor}') - graph.multiref(ftensor, comment='shared param') - stages = [] for spmd_desc in pp_desc.spmd_descs: stage = [] diff --git a/tests/autodist/pas/replicated_and_partition.json b/tests/autodist/pas/replicated_and_partition_pp.json similarity index 100% rename from tests/autodist/pas/replicated_and_partition.json rename to tests/autodist/pas/replicated_and_partition_pp.json diff --git a/tests/autodist/pas/replicated_and_partition_spmd.json b/tests/autodist/pas/replicated_and_partition_spmd.json new file mode 100644 index 00000000..5a05dd6f --- /dev/null +++ b/tests/autodist/pas/replicated_and_partition_spmd.json @@ -0,0 +1,88 @@ +{ + "desc": { + "spmd_descs": [ + { + "partition_descs": [ + [ + 2, + [ + [ + [ + -1, + -1 + ], + 4 + ] + ] + ], + [ + 3, + [ + [ + [ + 0, + 0 + ], + 4 + ] + ] + ], + [ + 4, + [ + [ + [ + 0, + 0 + ], + 4 + ] + ] + ], + [ + 5, + [ + [ + [ + 0, + 0 + ], + 4 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 4 + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } + } + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 4 + ] + }, + "e2e_time": 0.0, + "stage_mems": [ + 0.0, 0.0 + ], + "stage_all_times": [ + 0.0, 0.0 + ], + "stage_comp_times": [ + 0.0, 0.0 + ] +} diff --git a/tests/autodist/pas/test_multiref_param.py b/tests/autodist/pas/test_multiref_param.py index ef8c958e..5af156c3 100644 --- a/tests/autodist/pas/test_multiref_param.py +++ b/tests/autodist/pas/test_multiref_param.py @@ -34,7 +34,7 @@ def forward(self, x): @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') -@pytest.mark.parametrize('cfg_fname', ['all_replicated_pp.json', 'replicated_and_partition.json']) +@pytest.mark.parametrize('cfg_fname', ['all_replicated_pp.json', 'replicated_and_partition_pp.json', 'replicated_and_partition_spmd.json']) def test_shared_param_pipeline(cfg_fname): bsz, hidden_dim = 4, 1024 @@ -68,15 +68,24 @@ def test_shared_param_pipeline(cfg_fname): plan_path = Path(os.path.dirname(__file__)) / cfg_fname cfg = AutoDistConfig(load_plan_path=plan_path, mesh_col=4) graph = parallelize_graph(ir_graph, cfg) - assert isinstance(graph.nodes()[4], IRSegment) - # check multiref is correctly inserted at the 1st IRSegment (pipeline stage) - has_multiref = False - for node in graph.nodes()[4].nodes(): - if node.signature == 'nnscaler.runtime.function.multiref': - has_multiref = True - break - assert has_multiref + if 'pp' in cfg_fname: + assert isinstance(graph.nodes()[4], IRSegment) + # check multiref is correctly inserted at the 1st IRSegment (pipeline stage) + has_multiref = False + for node in graph.nodes()[4].nodes(): + if node.signature == 'nnscaler.runtime.function.multiref': + has_multiref = True + break + assert has_multiref + else: + has_multiref = False + for node in graph.nodes(): + if node.signature == 'nnscaler.runtime.function.multiref': + assert not has_multiref, 'multiple multiref nodes found' + has_multiref = True graph = IRAdapterGener.gen(graph, cost_fn=None) if graph.sched is not None: graph.sched.apply() + + assert True, 'should not raise exception' From 5a3f1781758851b0dc307a1e461d6cd0d12958a1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 22 Jan 2025 01:52:16 +0000 Subject: [PATCH 1800/1892] Merged PR 2351: [Codegen] add async support for pipeline parallelism add support for async pipeline background: https://github.com/microsoft/nnscaler/pull/23 unit tests pass parity check pass --- docs/source/trainer.rst | 45 ++-- nnscaler/codegen/module/module.py | 3 +- nnscaler/codegen/schedule/schedule.py | 35 ++- nnscaler/graph/schedule/schedplan.py | 42 ++-- nnscaler/parallel.py | 2 - nnscaler/policies.py | 25 +- tests/cli/test_arg_parser.py | 16 ++ tests/parallel_module/test_async.py | 347 ++++++++++++++++++++++++++ 8 files changed, 455 insertions(+), 60 deletions(-) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index a4e7331a..b79a59f1 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -17,7 +17,7 @@ All the arguments are defined in ``TrainerArgs`` class. Here is the definition o @dataclass class TrainerArgs: compute_config: ComputeConfig = None - + gen_savedir: str = './.nnscaler' gen_reuse: str = 'auto' pas_policy: str = 'autodist' @@ -25,7 +25,7 @@ All the arguments are defined in ``TrainerArgs`` class. Here is the definition o instance_name: str = None run_mode: str = 'run' tracing_from_weights: str = None - + model: ModelConfig = field(default_factory=ModelConfig) optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) dataset: DatasetConfig = field(default_factory=DatasetConfig) @@ -35,22 +35,22 @@ All the arguments are defined in ``TrainerArgs`` class. Here is the definition o checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) log: List[LogConfig] = field(default_factory=list) hook: Union[HookConfig, HookMapConfig, None] = None - + precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None - + micro_batch_size: int = 1 global_batch_size: Optional[int] = None grad_accumulation_steps: Optional[int] = None - + max_epochs: Optional[int] = None max_train_steps: Optional[int] = None max_val_steps: Optional[int] = None - + val_every_n_train_steps: Optional[int] = None val_every_n_epochs: Optional[int] = 1 - + enable_progress_bar: bool = True - + seed: Optional[int] = None init_env_fn: str = None @@ -378,17 +378,17 @@ Checkpoint Config class CheckpointConfig: save_dir: str = './checkpoints' no_save: bool = False - + save_type: str = 'sharded' - + save_last: bool = True save_best: bool = True symlink_best_and_last: bool = True - + every_n_train_steps: Optional[int] = None every_n_epochs: Optional[int] = None keep_last_n_checkpoints: Optional[int] = None - + resume_from: str = None * ``save_dir`` (``str``): The directory to save the checkpoints. @@ -488,25 +488,25 @@ CONFIG_FILE is the path to the configuration yaml file. It looks like (taken fro constant_folding: true use_zero: true use_end2end: true - + run_mode: run pas_policy: autodist micro_batch_size: 2 global_batch_size: 8 max_epochs: 4 max_train_steps: 10 - + model: type: tests.cli.common.MLP args: dim: 16 nlayers: 16 - + optimizer: type: torch.optim.Adam args: lr: 0.01 - + dataset: type: tests.cli.common.SimpleDataset train_args: @@ -515,7 +515,7 @@ CONFIG_FILE is the path to the configuration yaml file. It looks like (taken fro val_args: dim: 16 size: 10 - + checkpoint: keep_last_n_checkpoints: 30 every_n_train_steps: 1 @@ -542,16 +542,16 @@ The configuration of the compute environment. It is a dataclass with the followi class ComputeConfig: plan_ngpus: int runtime_ngpus: int - + constant_folding: bool = False trace_strategy: Literal['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'] = 'cuda_run_cpu_offload' - + use_zero: bool = False zero_ngroups: int = 1 - + inference_only : bool = False use_end2end: bool = False - + pas_config: Dict[str, Any] = field(default_factory=dict) user_config: Dict[str, Any] = field(default_factory=dict) @@ -712,7 +712,8 @@ The configuration of the PAS policy should be passed in the ``pas_config`` of `` * ``pipeline_nstages``: the number of stages in the pipeline. Default is ``plan_ngpus``. Optional. * ``pipeline_nmicros``: the number of microbatches in the pipeline. Required. - * ``pipeline_scheduler``: the scheduler name for the pipeline. Current we support four schedulers in training ``1f1b``/``1f1b_plus``/``gpipe``/``chimera_direct`` (4 stages pipeline only), and one scheduler in inference ``infer_pipe``. Default is ``1f1b``. Optional. + * ``pipeline_scheduler``: the scheduler name for the pipeline. Current we support four schedulers in training ``1f1b``/``1f1b_plus``/``1f1b_interleaved``/``gpipe``/``chimera_direct`` (4 stages pipeline only), and one scheduler in inference ``infer_pipe``. Default is ``1f1b``. Optional. + * ``pp_size``: the pipeline parallelism size. Default is ``pipeline_nstages``. Optional. #. ``autodist``: the recommended policy for most cases. Currently it only support Adam-like optimizers. It will automatically choose the best partition for you by balancing the memory usage and speed. It has the following configurations. diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 4ac9c853..32cd75fe 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -125,7 +125,8 @@ def __init__( 'from typing import *', 'from pathlib import Path', 'import torch', 'import torch.utils.checkpoint as ckpt', - 'import nnscaler', 'import _operator', 'from numpy import inf', 'import builtins', '', + 'import nnscaler', 'import nnscaler.flags', + 'import _operator', 'from numpy import inf', 'import builtins', '', f'runtime_version = {runtime_version!r}', '', '' ] diff --git a/nnscaler/codegen/schedule/schedule.py b/nnscaler/codegen/schedule/schedule.py index 8484ef5b..e2c44d27 100644 --- a/nnscaler/codegen/schedule/schedule.py +++ b/nnscaler/codegen/schedule/schedule.py @@ -71,6 +71,10 @@ def gen(self, device: int, outfile=None, attach=None) -> str: gencode = copy.copy(self.init_code) device_map = device % len(self.devices) device_nodes = self.execplan.seq(device_map) + # We will manually set the `skip_reducer` flag and `skip_zero_grad` + # when we use scheduler (i.e. pipeline parallelism) + # Otherwise, the caller of `_train_step` will set these flags to support gradient accumulation + use_scheduler = self.execplan.graph.sched is not None assert all(not isinstance(n, IRFwOperation) for n in device_nodes), \ "Expected all forward operators have been grouped into IRSegment" @@ -82,13 +86,42 @@ def gen(self, device: int, outfile=None, attach=None) -> str: with FunctionBlock(func_name='_train_step', args=args) as fb: fb.insert_body('_ = None') + + if use_scheduler: + fb.insert_body('nnscaler.flags.RuntimeFlag.skip_zero_grad = False') fb.insert_body('model.zero_grad()') + # body code if len(device_nodes) == 0: fb.insert_body('pass') else: + def _is_backward_segment(node: IRCell) -> bool: + node = node.cell if isinstance(node, ExeReuseCell) else node + return isinstance(node, IRSegment) and not node.isfw() + + # collect backward segments that needs to reduce gradients + # which are the last backward segments of every stage. + # (Every segment will be used multiple times via `ExeReuseCell`) + last_backward_node_oids = [] + if use_scheduler: + # Key: segment id + # Value: the last backward ExeReuseCell of the segment + last_backwards = {} + for node in device_nodes[::-1]: + if not _is_backward_segment(node): + continue + assert isinstance(node, ExeReuseCell), 'Expected ExeReuseCell for backward segment when using scheduler' + if node.cell.cid not in last_backwards: + last_backwards[node.cell.cid] = node + last_backward_node_oids = [id(node) for node in last_backwards.values()] + for line, node in enumerate(device_nodes): - # execute + # when use scheduler, skip reducer if it is not the last backward of same segments + if use_scheduler and _is_backward_segment(node): + fb.insert_body( + f'nnscaler.flags.RuntimeFlag.skip_reducer = ' + f'{id(node) not in last_backward_node_oids !r}' + ) codes = self.emit_node(node) fb.insert_body(codes) # release diff --git a/nnscaler/graph/schedule/schedplan.py b/nnscaler/graph/schedule/schedplan.py index 4c20ed63..72ea3aef 100644 --- a/nnscaler/graph/schedule/schedplan.py +++ b/nnscaler/graph/schedule/schedplan.py @@ -55,7 +55,7 @@ class Block: """ def __init__(self, cell: IRCell, micro_batch_id: int, span: int) -> None: - """Create an execution block with IRCell on microbatch index. The + """Create an execution block with IRCell on microbatch index. The block will take `span` steps to finish execution. """ assert isinstance(cell, IRCell), f"Expected IRCell, but got {type(cell)}: {cell}" @@ -67,7 +67,7 @@ def __eq__(self, other): if isinstance(other, Block): return other.content == self.content and other.mid == self.mid return False - + def __hash__(self) -> int: return hash((self._content, self._micro_batch_id)) @@ -78,15 +78,15 @@ def device(self) -> Tuple[int]: @property def mid(self) -> int: return self._micro_batch_id - + @property def content(self) -> IRCell: return self._content - + @property def span(self) -> int: return self._span - + def dispatch(self, devid: int): return Block(self._content.dispatch(devid), self._micro_batch_id) @@ -132,7 +132,7 @@ def build(self): self.senders[adapter] = segment # get all weight reducers self.reducers = self.graph.select(ntype=IRWeightReducer, flatten=False) - + def depends(self, prev: Block, next: Block) -> bool: return prev.mid == next.mid and self.graph.depends(prev.content, next.content) @@ -163,11 +163,11 @@ def __init__(self, graph: IRGraph, _dependency: Optional[ScheduleDependency] = N @property def nsteps(self) -> int: return len(self._step_blocks) - + @property def graph(self) -> IRGraph: return self._graph - + @property def device(self) -> Tuple[int]: device = set() @@ -177,7 +177,7 @@ def device(self) -> Tuple[int]: def nodes(self) -> Tuple[Block]: return tuple(self._seqs) - + def add_block(self, block: Block, step: int): """Add a block to start executing from step""" self._extend_step(step + block.span - 1) @@ -246,10 +246,10 @@ def insert_step(self, step: int, seg: IRSegment, micro_batch_id: int, span: Opti self._block_start_step[block] = step self._blocks.append(block) return block - + def remove_step(self, step: int): """Remove the step if there are no blocks in execution. - + All the blocks after the `step` will be shifted earlier. This can only apply when no adapters are placed. @@ -273,7 +273,7 @@ def remove_step(self, step: int): def shrink(self): """Remove steps that have no blocks in execution - + Note the implementation is costly. Users should avoid calling it many times. """ @@ -295,21 +295,21 @@ def start_blocks(self, step: int) -> Tuple[Block]: blocks = self._step_blocks[step] blocks = tuple(blk for blk in blocks if self.start(blk) == step) return blocks - + def start(self, block: Block) -> int: """Get the start step of the block""" return self._block_start_step[block] - + def all_blocks(self) -> Tuple[Block]: """ Get all segment blocks """ return tuple(self._blocks) - + def depends(self, prev: Block, succ: Block) -> bool: """Check whether prev block directly depends on succ block""" return self._dependency.depends(prev, succ) - + def _extend_step(self, step: int): """Extend the maximal accessible steps of plan to `step` index""" if len(self._step_blocks) <= step: @@ -377,8 +377,6 @@ class SchedulePlan(PlanBase): def __init__(self, graph: IRGraph, num_microbatches: int): super().__init__(graph) - if CompileFlag.async_reducer: - raise NotImplementedError("Async reducer is not supported for schedule plan yet.") # execution sequence self._num_microbatches = num_microbatches # bind to the graph @@ -390,7 +388,7 @@ def nmicros(self) -> int: Get number of micro-batches """ return self._num_microbatches - + @property def graph(self) -> IRGraph: return self._graph @@ -479,12 +477,12 @@ def str(self, show_max_steps: Optional[int] = None) -> str: for block in self._blocks: if block.content not in sids: sids[block.content] = len(sids) - + for idx, (cell, sid) in enumerate(sids.items()): dscp += f'{cell.name}{cell.cid:<3} = {sid}; ' if (idx + 1) % 3 == 0: dscp += '\n' - + dscp += '\nAnnotation: i(f/b)j = segment i on executing (forward/backward) microbatch j' for devid in sorted(self.device): timeline = '\n' @@ -517,6 +515,6 @@ def str(self, show_max_steps: Optional[int] = None) -> str: timeline += f" ... (remaining {self.nsteps-show_max_steps} steps)" dscp += timeline return dscp - + def __repr__(self): return self.str(show_max_steps=20) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 25ce15d1..06a2b1b3 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -180,8 +180,6 @@ def apply_pipeline_scheduler( raise ValueError(f"pipeline_nmicros {pipeline_nmicros} must be > 0.") if pipeline_nstages <= 0: raise ValueError(f"pipeline_nstages {pipeline_nstages} must be > 0.") - if self.plan_ngpus % pipeline_nstages != 0: - raise ValueError(f"pipeline_nstages {pipeline_nstages} must be a multiple of plan_ngpus {self.plan_ngpus}") if pipeline_scheduler not in _PREDEFINE_SCHEDS: raise ValueError(f"pipeline_scheduler {pipeline_scheduler} is not supported. " f"Supported schedulers are {_PREDEFINE_SCHEDS.keys()}") diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 4b6ab30a..c8d4ba91 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -170,21 +170,23 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): """ if not cfg.use_end2end: raise ValueError("Hybrid policy only supports end2end module") - if cfg.use_async_reducer: - raise ValueError("Hybrid policy does not support async reducer") ngpus: int = cfg.plan_ngpus nstages = cfg.pas_config.get('pipeline_nstages', cfg.plan_ngpus) nmicros = cfg.pas_config['pipeline_nmicros'] scheduler = cfg.pas_config.get('pipeline_scheduler', '1f1b') - tp_size: int = cfg.plan_ngpus // nstages - if ngpus % tp_size != 0: - raise ValueError(f'invalid tp_size {tp_size} for ngpus {ngpus}') - pp_size = ngpus // tp_size + pp_size = cfg.pas_config.get('pp_size', nstages) + + if nstages % pp_size != 0: + raise ValueError(f'invalid pp_size {pp_size} for nstages {nstages}') + if ngpus % pp_size != 0: + raise ValueError(f'invalid pp_size {pp_size} for ngpus {ngpus}') + tp_size = ngpus // pp_size + auto_multiref(graph) fnodes = graph.select(ntype=IRFwOperation) - stages = mitr.divide(pp_size, fnodes) + stages = mitr.divide(nstages, fnodes) stages = [list(s) for s in stages] for idx, stage in enumerate(stages): _logger.info(f'> stage {idx}: {stage[0]}') @@ -192,19 +194,18 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): stages: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) stages = [s for s in stages if s.isfw()] - assert len(stages) == pp_size, "Internal Error" + assert len(stages) == nstages, "Internal Error" # stage-wise tensor parallelism curr_devices = list(range(ngpus)) - for stage in stages: + for idx, stage in enumerate(stages): + idx = idx % pp_size + devs = curr_devices[idx * tp_size: (idx + 1)* tp_size] for node in stage.nodes(): - devs = curr_devices[:tp_size] try: _tp(graph, node, devs, idx=0, dim=0) except Exception as e: _replica(graph, node, devs) - curr_devices = curr_devices[tp_size:] - assert len(curr_devices) == 0, f"remaining devices: {curr_devices} not used" # replicate dataloader for dl in graph.select(ntype=IRDataOperation): diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 2bf853e8..e83ac94b 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -188,3 +188,19 @@ class A: x = parse_args(['--p.value=1']) y = deserialize_dataclass(x, A) assert y.p == {'value': 1} + + +def test_deserialize_union_type(): + # union annotation is actually ignored. + # Here the correct result depends on `_guess_deserialize_object` function. + @dataclass + class A: + p: Union[int, str] = None + + x = parse_args(['--p.value=1']) + y = deserialize_dataclass(x, A) + assert y.p == {'value': 1} + + x = parse_args(['--p.value=auto']) + y = deserialize_dataclass(x, A) + assert y.p == {'value': 'auto'} diff --git a/tests/parallel_module/test_async.py b/tests/parallel_module/test_async.py index 1ed114b9..97b770b5 100644 --- a/tests/parallel_module/test_async.py +++ b/tests/parallel_module/test_async.py @@ -160,3 +160,350 @@ def test_dp2(update_freq): for key in whole_async_weights.keys(): assert torch.equal(whole_async_weights[key], sub_async_weights[key]) + + +class OrigModuleEnd2End(torch.nn.Module): + def __init__(self): + super().__init__() + self.orig_module = OrigModule() + self.loss_fn = nn.BCELoss() + + def forward(self, data): + x = data['data'] + x = self.orig_module(x) + loss = self.loss_fn(x, data['target']) + return loss + + +def _train_pp(model: ParallelModule, num_replicas, rank): + mbs = model.nmicros_per_scheduler_step + assert model.use_scheduler + + init_random() + + optimizer = build_optimizer(model, torch.optim.Adam, lr=0.1) + data = [] + DATA_SIZE = 64 + for _ in range(DATA_SIZE): + data.append({ + 'data': torch.randn((2, 4), device='cuda', dtype=torch.float32), + 'target': torch.rand((2, 1), device='cuda', dtype=torch.float32), + }) + data = [data[i] for i in range(rank, DATA_SIZE, num_replicas)] + chunks = [data[i:i + mbs] for i in range(0, len(data), mbs)] + results = [] + for _, x in enumerate(chunks): + model.train() + _ = model.train_step(x) + optimizer.step() + optimizer.zero_grad() + results.append(clone_to_cpu_recursively(model.state_dict())) + return results + + +def _gpu_worker_pp(pas, pp_ngpus, runtime_ngpus, update_freq): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_pp_async') as tempdir: + init_random() + whole_module_async = parallelize( + OrigModuleEnd2End(), { + 'data': { + 'data': torch.randn(2, 4, device=torch.cuda.current_device()), + 'target': torch.rand(2, 1, device=torch.cuda.current_device()) + } + }, + pas, ComputeConfig( + pp_ngpus, runtime_ngpus, use_async_reducer=True, + reducer_bucket_cap_mb=1e-6, + use_end2end=True, + pas_config=dict( + pipeline_nmicros=update_freq, + pipeline_nstages=pp_ngpus, + pipeline_scheduler='1f1b', + ) + ), + gen_savedir=tempdir, + instance_name='async_pp_whole' + ).cuda() + + init_random() + whole_module_sync = parallelize( + OrigModuleEnd2End(), { + 'data': { + 'data': torch.randn(2, 4, device=torch.cuda.current_device()), + 'target': torch.rand(2, 1, device=torch.cuda.current_device()) + } + }, pas, + ComputeConfig( + pp_ngpus, runtime_ngpus, use_async_reducer=False, + reducer_bucket_cap_mb=1e-6, + use_end2end=True, + pas_config=dict( + pipeline_nmicros=update_freq, + pipeline_nstages=pp_ngpus, + pipeline_scheduler='1f1b', + ) + ), + gen_savedir=tempdir, + instance_name='sync_pp_whole' + ).cuda() + + whole_async_results = _train_pp(whole_module_async, runtime_ngpus // pp_ngpus, torch.distributed.get_rank() // pp_ngpus) + whole_sync_results = _train_pp(whole_module_sync, runtime_ngpus // pp_ngpus, torch.distributed.get_rank() // pp_ngpus) + + return ( + whole_async_results, + whole_sync_results, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_pp2(): + results = launch_torchrun(4, _gpu_worker_pp, 'pp', 2, 4, 4) + whole_async0, whole_sync0 = results[0] + whole_async1, whole_sync1 = results[1] + whole_async2, whole_sync2 = results[2] + whole_async3, whole_sync3 = results[3] + + assert len(whole_async0) == len(whole_sync0) + + for iter in range(len(whole_async0)): # for each iteration + assert_equal( + merge_state_dicts( + [whole_async0[iter], whole_async1[iter], whole_async2[iter], whole_async3[iter]] + ), + merge_state_dicts( + [whole_sync0[iter], whole_sync1[iter], whole_sync2[iter], whole_sync3[iter]] + ) + ) + + +def _gpu_worker_interleaved_pp(tempdir, tp_size=1): + init_distributed() + pp_size = 2 + stages = 4 + plan_ngpus = pp_size * tp_size + runtime_ngpus = 4 + update_freq = 8 + # the generated train_step: + # Please note + # 1. the assignment of runtime flags (starting with `nnscaler.flags.RuntimeFlag`) + # 2. Each gpus will hold 2 segments(stages) of pipeline (`segment53` and `segment69`) + # def _train_step(model, dataloader_126): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # data_125 = next(*(dataloader_126, )) + # add_1_92 = nnscaler.runtime.executor.fexecute('segment53', model.segment53, *(data_125, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter256, *(add_1_92, ), requires_grad=False) + # data_316 = next(*(dataloader_126, )) + # add_1_321 = nnscaler.runtime.executor.fexecute('segment53', model.segment53, *(data_316, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter256, *(add_1_321, ), requires_grad=False) + # add_3_102 = nnscaler.runtime.executor.aexecute(model.adapter212, *(), requires_grad=True) + # add_5_112 = nnscaler.runtime.executor.fexecute('segment69', model.segment69, *(add_3_102, ), requires_grad=True) + # add_3_340 = nnscaler.runtime.executor.aexecute(model.adapter212, *(), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter286, *(add_5_112, ), requires_grad=False) + # add_5_360 = nnscaler.runtime.executor.fexecute('segment69', model.segment69, *(add_3_340, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter286, *(add_5_360, ), requires_grad=False) + # gadd_5_157 = nnscaler.runtime.executor.aexecute(model.adapter297, *(), requires_grad=False) + # data_378 = next(*(dataloader_126, )) + # add_1_383 = nnscaler.runtime.executor.fexecute('segment53', model.segment53, *(data_378, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter256, *(add_1_383, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gadd_3_147 = nnscaler.runtime.executor.backward('segment69', (add_3_102, ), (add_5_112, ), (gadd_5_157, )) + # del add_5_112, gadd_5_157 + # gadd_5_361 = nnscaler.runtime.executor.aexecute(model.adapter297, *(), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter223, *(gadd_3_147, ), requires_grad=False) + # del add_3_102, gadd_3_147 + # data_407 = next(*(dataloader_126, )) + # add_1_412 = nnscaler.runtime.executor.fexecute('segment53', model.segment53, *(data_407, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter256, *(add_1_412, ), requires_grad=False) + # add_3_419 = nnscaler.runtime.executor.aexecute(model.adapter212, *(), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gadd_3_341 = nnscaler.runtime.executor.backward('segment69', (add_3_340, ), (add_5_360, ), (gadd_5_361, )) + # del add_5_360, gadd_5_361 + # _ = nnscaler.runtime.executor.aexecute(model.adapter223, *(gadd_3_341, ), requires_grad=False) + # del add_3_340, gadd_3_341 + # gadd_1_137 = nnscaler.runtime.executor.aexecute(model.adapter267, *(), requires_grad=False) + # add_5_451 = nnscaler.runtime.executor.fexecute('segment69', model.segment69, *(add_3_419, ), requires_grad=True) + # add_3_459 = nnscaler.runtime.executor.aexecute(model.adapter212, *(), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter286, *(add_5_451, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # _ = nnscaler.runtime.executor.backward('segment53', (), (add_1_92, ), (gadd_1_137, )) + # del add_1_92, gadd_1_137 + # gadd_1_322 = nnscaler.runtime.executor.aexecute(model.adapter267, *(), requires_grad=False) + # add_5_493 = nnscaler.runtime.executor.fexecute('segment69', model.segment69, *(add_3_459, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter286, *(add_5_493, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # _ = nnscaler.runtime.executor.backward('segment53', (), (add_1_321, ), (gadd_1_322, )) + # del add_1_321, gadd_1_322 + # gadd_5_452 = nnscaler.runtime.executor.aexecute(model.adapter297, *(), requires_grad=False) + # data_520 = next(*(dataloader_126, )) + # add_1_525 = nnscaler.runtime.executor.fexecute('segment53', model.segment53, *(data_520, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter256, *(add_1_525, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gadd_3_420 = nnscaler.runtime.executor.backward('segment69', (add_3_419, ), (add_5_451, ), (gadd_5_452, )) + # del add_5_451, gadd_5_452 + # gadd_5_494 = nnscaler.runtime.executor.aexecute(model.adapter297, *(), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter223, *(gadd_3_420, ), requires_grad=False) + # del add_3_419, gadd_3_420 + # data_549 = next(*(dataloader_126, )) + # add_1_554 = nnscaler.runtime.executor.fexecute('segment53', model.segment53, *(data_549, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter256, *(add_1_554, ), requires_grad=False) + # add_3_561 = nnscaler.runtime.executor.aexecute(model.adapter212, *(), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gadd_3_460 = nnscaler.runtime.executor.backward('segment69', (add_3_459, ), (add_5_493, ), (gadd_5_494, )) + # del add_5_493, gadd_5_494 + # _ = nnscaler.runtime.executor.aexecute(model.adapter223, *(gadd_3_460, ), requires_grad=False) + # del add_3_459, gadd_3_460 + # gadd_1_384 = nnscaler.runtime.executor.aexecute(model.adapter267, *(), requires_grad=False) + # add_5_593 = nnscaler.runtime.executor.fexecute('segment69', model.segment69, *(add_3_561, ), requires_grad=True) + # add_3_601 = nnscaler.runtime.executor.aexecute(model.adapter212, *(), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter286, *(add_5_593, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # _ = nnscaler.runtime.executor.backward('segment53', (), (add_1_383, ), (gadd_1_384, )) + # del add_1_383, gadd_1_384 + # gadd_1_413 = nnscaler.runtime.executor.aexecute(model.adapter267, *(), requires_grad=False) + # add_5_635 = nnscaler.runtime.executor.fexecute('segment69', model.segment69, *(add_3_601, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter286, *(add_5_635, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # _ = nnscaler.runtime.executor.backward('segment53', (), (add_1_412, ), (gadd_1_413, )) + # del add_1_412, gadd_1_413 + # gadd_5_594 = nnscaler.runtime.executor.aexecute(model.adapter297, *(), requires_grad=False) + # data_662 = next(*(dataloader_126, )) + # add_1_667 = nnscaler.runtime.executor.fexecute('segment53', model.segment53, *(data_662, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter256, *(add_1_667, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gadd_3_562 = nnscaler.runtime.executor.backward('segment69', (add_3_561, ), (add_5_593, ), (gadd_5_594, )) + # del add_5_593, gadd_5_594 + # gadd_5_636 = nnscaler.runtime.executor.aexecute(model.adapter297, *(), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter223, *(gadd_3_562, ), requires_grad=False) + # del add_3_561, gadd_3_562 + # data_691 = next(*(dataloader_126, )) + # add_1_696 = nnscaler.runtime.executor.fexecute('segment53', model.segment53, *(data_691, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter256, *(add_1_696, ), requires_grad=False) + # add_3_703 = nnscaler.runtime.executor.aexecute(model.adapter212, *(), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gadd_3_602 = nnscaler.runtime.executor.backward('segment69', (add_3_601, ), (add_5_635, ), (gadd_5_636, )) + # del add_5_635, gadd_5_636 + # _ = nnscaler.runtime.executor.aexecute(model.adapter223, *(gadd_3_602, ), requires_grad=False) + # del add_3_601, gadd_3_602 + # gadd_1_526 = nnscaler.runtime.executor.aexecute(model.adapter267, *(), requires_grad=False) + # add_5_735 = nnscaler.runtime.executor.fexecute('segment69', model.segment69, *(add_3_703, ), requires_grad=True) + # add_3_743 = nnscaler.runtime.executor.aexecute(model.adapter212, *(), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter286, *(add_5_735, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # _ = nnscaler.runtime.executor.backward('segment53', (), (add_1_525, ), (gadd_1_526, )) + # del add_1_525, gadd_1_526 + # gadd_1_555 = nnscaler.runtime.executor.aexecute(model.adapter267, *(), requires_grad=False) + # add_5_777 = nnscaler.runtime.executor.fexecute('segment69', model.segment69, *(add_3_743, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter286, *(add_5_777, ), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # _ = nnscaler.runtime.executor.backward('segment53', (), (add_1_554, ), (gadd_1_555, )) + # del add_1_554, gadd_1_555 + # gadd_5_736 = nnscaler.runtime.executor.aexecute(model.adapter297, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gadd_3_704 = nnscaler.runtime.executor.backward('segment69', (add_3_703, ), (add_5_735, ), (gadd_5_736, )) + # del add_5_735, gadd_5_736 + # gadd_5_778 = nnscaler.runtime.executor.aexecute(model.adapter297, *(), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter223, *(gadd_3_704, ), requires_grad=False) + # del add_3_703, gadd_3_704 + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gadd_3_744 = nnscaler.runtime.executor.backward('segment69', (add_3_743, ), (add_5_777, ), (gadd_5_778, )) + # del add_5_777, gadd_5_778 + # _ = nnscaler.runtime.executor.aexecute(model.adapter223, *(gadd_3_744, ), requires_grad=False) + # del add_3_743, gadd_3_744 + # gadd_1_668 = nnscaler.runtime.executor.aexecute(model.adapter267, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # _ = nnscaler.runtime.executor.backward('segment53', (), (add_1_667, ), (gadd_1_668, )) + # del add_1_667, gadd_1_668 + # gadd_1_697 = nnscaler.runtime.executor.aexecute(model.adapter267, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # _ = nnscaler.runtime.executor.backward('segment53', (), (add_1_696, ), (gadd_1_697, )) + # del add_1_696, gadd_1_697 + # binary_cross_entropy_82 = nnscaler.runtime.executor.aexecute(model.adapter236, *(), requires_grad=True) + # binary_cross_entropy_391 = nnscaler.runtime.executor.aexecute(model.adapter236, *(), requires_grad=True) + # binary_cross_entropy_502 = nnscaler.runtime.executor.aexecute(model.adapter236, *(), requires_grad=True) + # binary_cross_entropy_533 = nnscaler.runtime.executor.aexecute(model.adapter236, *(), requires_grad=True) + # binary_cross_entropy_644 = nnscaler.runtime.executor.aexecute(model.adapter236, *(), requires_grad=True) + # binary_cross_entropy_675 = nnscaler.runtime.executor.aexecute(model.adapter236, *(), requires_grad=True) + # binary_cross_entropy_786 = nnscaler.runtime.executor.aexecute(model.adapter236, *(), requires_grad=True) + # binary_cross_entropy_809 = nnscaler.runtime.executor.aexecute(model.adapter236, *(), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.reducer548, *(), requires_grad=False) + # return binary_cross_entropy_82, binary_cross_entropy_391, binary_cross_entropy_502, binary_cross_entropy_533, binary_cross_entropy_644, binary_cross_entropy_675, binary_cross_entropy_786, binary_cross_entropy_809 + + init_random() + whole_module_async = parallelize( + OrigModuleEnd2End(), { + 'data': { + 'data': torch.randn(2, 4, device=torch.cuda.current_device()), + 'target': torch.rand(2, 1, device=torch.cuda.current_device()) + } + }, + 'hybrid', ComputeConfig( + plan_ngpus, runtime_ngpus, use_async_reducer=True, + reducer_bucket_cap_mb=1e-6, + use_end2end=True, + pas_config=dict( + pipeline_nmicros=update_freq, + pipeline_nstages=stages, + pipeline_scheduler='1f1b_interleaved', + pp_size=pp_size, + ) + ), + gen_savedir=tempdir, + instance_name='async_interleaved_pp_whole' + ).cuda() + + init_random() + whole_module_sync = parallelize( + OrigModuleEnd2End(), { + 'data': { + 'data': torch.randn(2, 4, device=torch.cuda.current_device()), + 'target': torch.rand(2, 1, device=torch.cuda.current_device()) + } + }, 'hybrid', + ComputeConfig( + plan_ngpus, runtime_ngpus, use_async_reducer=False, + reducer_bucket_cap_mb=1e-6, + use_end2end=True, + pas_config=dict( + pipeline_nmicros=update_freq, + pipeline_nstages=stages, + pipeline_scheduler='1f1b_interleaved', + pp_size=pp_size, + ) + ), + gen_savedir=tempdir, + instance_name='sync_interleaved_pp_whole' + ).cuda() + + whole_async_results = _train_pp(whole_module_async, runtime_ngpus // pp_size, torch.distributed.get_rank() // pp_size) + whole_sync_results = _train_pp(whole_module_sync, runtime_ngpus // pp_size, torch.distributed.get_rank() // pp_size) + + return ( + whole_async_results, + whole_sync_results, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('tp_size', [1, 2]) +def test_interleaved_pp(tmp_path, tp_size): + results = launch_torchrun(4, _gpu_worker_interleaved_pp, tmp_path, tp_size) + whole_async0, whole_sync0 = results[0] + whole_async1, whole_sync1 = results[1] + whole_async2, whole_sync2 = results[2] + whole_async3, whole_sync3 = results[3] + + assert len(whole_async0) == len(whole_sync0) + + for iter in range(len(whole_async0)): # for each iteration + assert_equal( + merge_state_dicts( + [whole_async0[iter], whole_async1[iter], whole_async2[iter], whole_async3[iter]] + ), + merge_state_dicts( + [whole_sync0[iter], whole_sync1[iter], whole_sync2[iter], whole_sync3[iter]] + ) + ) From 30b42862daf173e72ae7b7ee177211b14b8e7298 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 23 Jan 2025 02:43:32 +0000 Subject: [PATCH 1801/1892] Merged PR 2357: [Trainer] Add partial parallelized module support. Add partial parallelized module support. --- docs/source/parallel_module.md | 8 +- docs/source/pytorch_lightning.md | 2 +- docs/source/trainer.rst | 93 +++++ nnscaler/cli/mixed_module.py | 333 ++++++++++++++++++ nnscaler/cli/trainer.py | 88 +---- nnscaler/cli/trainer_args.py | 263 +++++++++++--- nnscaler/parallel.py | 130 +++++-- tests/cli/common.py | 72 ++++ tests/cli/test_train_args.py | 13 +- tests/cli/test_trainer.py | 288 ++++++++++++++- tests/cli/trainer_args.yaml | 1 + tests/parallel_module/test_gencode.py | 106 +++++- tests/parallel_module/test_init.py | 1 - tests/parallel_module/test_nested.py | 3 +- tests/parallel_module/test_override.py | 14 +- .../test_shared_param_pipeline.py | 6 +- tests/utils.py | 32 +- 17 files changed, 1254 insertions(+), 199 deletions(-) create mode 100644 nnscaler/cli/mixed_module.py diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 3cd1df5f..4bdee9a2 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -320,7 +320,7 @@ We have `build_optimizer` to build an optimizer for distributed training. def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], - *args, + compute_config: Optional[ComputeConfig] = None, **kwargs, ) -> OptimizerT: ``` @@ -329,8 +329,10 @@ It has the following parameters: - `optimizer_fn` (`Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]`): It can be the optimizer class or optimizer factory function. The first parameter of the `optimizer_fn` should be the module parameters. -- *args: other args for `optimizer_fn` besides module parameters. -- **kwargs: the kwargs will pass to `optimizer_fn` +- compute_config (Optional[ComputeConfig]): + The config will be used to generate communication reducer. + If it is None, Default configuration will be used when creating reducer for non-parallel modules. +- **kwargs: the kwargs will pass to `optimizer_fn`. To support distributed training, in the function we need to hook 4 places (which we have done for you in `build_optimizer`. That's why you should use `build_optimizer` to create optimizer): diff --git a/docs/source/pytorch_lightning.md b/docs/source/pytorch_lightning.md index 8ca793ac..0bea4486 100644 --- a/docs/source/pytorch_lightning.md +++ b/docs/source/pytorch_lightning.md @@ -78,7 +78,7 @@ where `checkpoint_files` is a list of checkpoint files to merge, and `output_fil ## Limitations -Currently, nnScaler only supports: +Currently, nnScaler only supports: - single parameter group. - single optimizer. - single learning rate scheduler. \ No newline at end of file diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index b79a59f1..58225b73 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -142,6 +142,98 @@ Component Configs class ModelConfig: type: str = None args: Dict[str, Any] = field(default_factory=dict) + parallel_modules: list[ModuleParallelizeConfig] = field(default_factory=list) + + * ``type`` (``str``): The model type. Note: It can't be a factory function. + * ``args`` (``Dict[str, Any]``): The arguments of the model's ``__init__`` function. + * ``parallel_modules`` (``List[ModuleParallelizeConfig]``): The sub modules to be parallelized. + If this is not empty, these modules will be parallelized instead of the whole model. + i.e. sub modules (in the list of ``parallel_modules``) in the model + will be replaced with parallelized version + + Note: When parallel_modules is not empty, + pipeline parallelism is not supported as the model is not end-to-end parallelized any more. + + .. code-block:: python + + @dataclass(frozen=True) + class OptionalComputeConfig: + constant_folding: Optional[bool] = None + trace_strategy: Optional[str] = None + use_zero: Optional[bool] = None + zero_ngroups: Optional[int] = None + zero_use_reduce_scatter: Optional[bool] = None + use_async_reducer: Optional[bool] = None + reducer_bucket_cap_mb: Optional[float] = None + + pas_config: Optional[Dict[str, Any]] = None + user_config: Optional[Dict[str, Any]] = None + +This is an optional version of the ``ComputeConfig``. +Please refer to :ref:`ComputeConfig ` for more information. + + .. code-block:: python + + @dataclass + class ModuleParallelizeConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + forward_args_gen_fn: str = None + tracing_from_weights: str = None + tracing_from_weights_prefix: str = None + + # For the following config, If None, the config of the trainer_args will be used + compute_config: Optional[OptionalComputeConfig] = None + gen_savedir: Optional[str] = None + gen_reuse: Optional[str] = None + pas_policy: Optional[str] = None + broadcast_strategy: Optional[str] = None + instance_name: Optional[str] = None + precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None + + * ``type`` (``str``): The sub model type to be parallelized. Note: It can't be a factory function. + * ``args`` (``Dict[str, Any]``): The arguments of the model's ``__init__`` function. + * ``forward_args_gen_fn`` (``str``): The full qualified name of the function to generate dummy forward args. + Its type should be ``Callable[[TrainerArgs],Dict[str, Any]]``. + The function should return a dict of dummy forward args for the model. + * ``tracing_from_weights`` (``str``): The path to the weights to be loaded when tracing(compiling) the model. + It is only used in tracing to serve as the initial state dict of the model. Default is ``None``. + * ``tracing_from_weights_prefix`` (``str``): the prefix in the state dict (loaded from ``trainer_args.tracing_from_weights``) to be used for tracing. + Please note ``trainer_args.tracing_from_weights`` must be set if you want to use this, + and ``tracing_from_weights`` and ``tracing_from_weights_prefix`` shouldn't be set at the same time. + * ``compute_config`` (``Optional[OptionalComputeConfig]``): The compute config for the parallelized module. + The merged config with the compute config of the ``trainer_args.compute_config`` will be used. + * ``gen_savedir`` (``Optional[str]``): The directory to save the generated files. + If None, the config of the trainer_args will be used. You can find more information below. + * ``gen_reuse`` (``Optional[str]``): The reuse strategy of the generated code. + If None, the config of the trainer_args will be used. You can find more information below. + * ``pas_policy`` (``Optional[str]``): The policy of parameter partitioning. + If None, the config of the trainer_args will be used. You can find more information below. + * ``broadcast_strategy`` (``Optional[str]``): The strategy of broadcasting the model. + If None, the config of the trainer_args will be used. You can find more information below. + * ``instance_name`` (``Optional[str]``): The instance name of the trainer. + If None, the config of the trainer_args will be used. You can find more information below. + * ``precision`` (``Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None]``): The precision of the model. + If None, the config of the trainer_args will be used. You can find more information below. + +Please Note: +1. The parallelization is per-module-type, which means one module type can only be parallelized once. + Moreover, the initial weights of the parallelized modules with the same type are all the same. + + So if you want to parallelize a module multiple times (with different arguments or different inital weights), + you need to create an alias for it. + + For example, if you want to parallelize a module named ``SomeModule`` twice, you can create an alias for it: + .. code-block:: python + + class SomeModuleAlias(SomeModule): + pass + +2. The initial weights of the whole model will be different when sub module parallelization is enabled, + since parallelization process will change the ``rng_state`` of torch. + + To make the initial weights of the whole model the same as the original model, + We recommend to save the initial weights of the original model and load them before training. * ``optimizer`` (``OptimizerConfig``): The optimizer to be used. @@ -467,6 +559,7 @@ Other configs * ``enable_progress_bar`` (``bool``): Whether to enable the progress bar. Default is ``True``. * ``seed`` (``Optional[int]``): The random seed. Default is ``None``. * ``init_env_fn`` (``str``): The function to initialize the environment. Default is ``None``. + Note: one of ``seed`` and ``init_env_fn`` must be set. *** CLI diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py new file mode 100644 index 00000000..89741ff0 --- /dev/null +++ b/nnscaler/cli/mixed_module.py @@ -0,0 +1,333 @@ +import types +import torch +from typing import Any, Optional +from dataclasses import asdict, replace +import inspect +import copy + +import nnscaler +from nnscaler.runtime.adapter.reducer import Reducer +from nnscaler.runtime.gnorm import ParamsInfo +from nnscaler.utils import fields + +from .trainer_args import ( + TrainerArgs, PrecisionMixin, PolicyMixin, ModuleParallelizeConfig, ComputeConfig, + load_type +) + + +def fork_rng(): + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + return torch.random.fork_rng([rank]) + else: + return torch.random.fork_rng() + + +class ModuleParallelizeConfigAdapter(PrecisionMixin, PolicyMixin): + """ + Adapter for ModuleParallelizeConfig and TrainerArgs + """ + def __init__( + self, trainer_args: TrainerArgs, + parallel_module: Optional[ModuleParallelizeConfig] = None, + tracing_weights: Optional[dict[str, Any]] = None + ): + """ + Args: + trainer_args: the trainer args + parallelized_module: the parallelized module config. + If None, the whole model will be parallelized + """ + self.trainer_args = trainer_args + self.parallel_module = parallel_module + self.tracing_weights = tracing_weights + + # we don't want to load the tracing weights every time + # It should be loaded only once outside, and passed to the adapter + if self.parallel_module \ + and self.parallel_module.tracing_from_weights_prefix \ + and not self.tracing_weights: + raise ValueError('tracing_weights should be provided when tracing_from_weights_prefix is set') + + @property + def model_type(self): + return ( + self.parallel_module.model_type + if self.parallel_module + else self.trainer_args.model_type + ) + + @property + def compute_config(self): + if self.parallel_module: + if self.parallel_module.compute_config is not None: + return self.parallel_module.compute_config.resolve(self.trainer_args.compute_config) + else: + return replace(self.trainer_args.compute_config, use_end2end=False) + else: + return self.trainer_args.compute_config + + @property + def gen_savedir(self): + return ( + self.parallel_module.gen_savedir + if self.parallel_module and self.parallel_module.gen_savedir is not None + else self.trainer_args.gen_savedir + ) + + @property + def gen_reuse(self): + return ( + self.parallel_module.gen_reuse + if self.parallel_module and self.parallel_module.gen_reuse is not None + else self.trainer_args.gen_reuse + ) + + @property + def pas_policy(self): + return ( + self.parallel_module.pas_policy + if self.parallel_module and self.parallel_module.pas_policy is not None + else self.trainer_args.pas_policy + ) + + @property + def broadcast_strategy(self): + return ( + self.parallel_module.broadcast_strategy + if self.parallel_module and self.parallel_module.broadcast_strategy is not None + else self.trainer_args.broadcast_strategy + ) + + @property + def instance_name(self): + return ( + self.parallel_module.instance_name + if self.parallel_module and self.parallel_module.instance_name is not None + else self.trainer_args.instance_name + ) + + @property + def tracing_from_weights(self): + return ( + self.parallel_module.tracing_from_weights + if self.parallel_module + else self.trainer_args.tracing_from_weights + ) + + def load_tracing_weights(self) -> Optional[dict[str, Any]]: + tracing_weights = None + if not self.parallel_module: + # try to reuse the weights from the tracing weights + tracing_weights = self.tracing_weights + if self.tracing_from_weights and tracing_weights is None: + tracing_weights = torch.load(self.tracing_from_weights) + else: + if self.tracing_from_weights: + tracing_weights = torch.load(self.tracing_from_weights) + elif self.parallel_module.tracing_from_weights_prefix: + leading_key = self.parallel_module.tracing_from_weights_prefix + '.' + tracing_weights = {} + for key in self.tracing_weights: + if key.startswith(leading_key): + tracing_weights[key[len(leading_key):]] = self.tracing_weights[key] + return tracing_weights + + @property + def precision(self): + return ( + self.parallel_module.precision + if self.parallel_module and self.parallel_module.precision is not None + else self.trainer_args.precision + ) + + def create_model(self) -> torch.nn.Module: + model = ( + self.parallel_module.create_model(self.trainer_args) + if self.parallel_module + else self.trainer_args.create_model() + ) + model = self.to_precision(model) + tracing_weights = self.load_tracing_weights() + if tracing_weights: + model.load_state_dict(tracing_weights) + return model + + def create_dummy_forward_args(self, dummy_input) -> dict[str, Any]: + if self.parallel_module: + return self.fix_input( + self.parallel_module.create_dummy_forward_args(self.trainer_args) + ) + + # forward args of whole model + arg_names = list( + inspect.signature( + inspect.unwrap(getattr(self.model_type, 'forward')) + ).parameters.keys() + ) + return {arg_names[1]: self.fix_input(dummy_input)} # arg_names[0] is self + + def resolve_compute_config(self): + compute_config = copy.deepcopy(self.compute_config) + compute_config.pas_config['__pas_name'] = self.pas_policy + # autodist configs + compute_config.pas_config['update_freq'] = self.trainer_args.update_freq + compute_config.pas_config['use_bf16'] = self.param_dtype == torch.bfloat16 + compute_config.pas_config['use_fp16'] = self.param_dtype == torch.float16 + + compute_config.user_config['__from_trainer_args'] = { + 'mbs': self.trainer_args.micro_batch_size, + 'gbs': self.trainer_args.global_batch_size, + 'precision': self.trainer_args.precision, + 'model_args': self.trainer_args.model.args, + } + return compute_config + + def parallelize(self, dummy_input: Optional[dict[str, Any]] = None, *, load_module: bool = True): + pmodel_class = nnscaler.parallelize( + self.model_type, + self.create_dummy_forward_args(dummy_input), + self.resolved_pas_policy, + self.resolve_compute_config(), + module_fn=self.create_model, + gen_savedir=self.gen_savedir, + reuse=self.gen_reuse, + instance_name=self.instance_name, + broadcast_strategy=self.broadcast_strategy, + load_module=load_module, + ) + if load_module: + return pmodel_class() + return pmodel_class + + +def mixin_module(model: torch.nn.Module, optimizer: torch.optim.Optimizer): + if isinstance(model, nnscaler.ParallelModule): + return model + + def train_step(self, + samples: list[Any], + is_dummy_batch: Optional[list[bool]] = None + ) -> list[Any]: + if is_dummy_batch is not None: + if len(samples) != len(is_dummy_batch): + raise ValueError('The length of samples and is_dummy_batch should be the same') + samples = [ + sample + for sample, is_dummy in zip(samples, is_dummy_batch) + if not is_dummy + ] + if not samples: + raise ValueError('No real samples in the batch') + + if not all(is_dummy_batch[len(samples):]): + raise ValueError('Dummy samples should be at the end of the batch') + + forward_outputs = [] + for idx, sample in enumerate(samples): + with nnscaler.sync_grad_when(idx == len(samples) - 1): + output = model(sample) + loss = output[0] if isinstance(output, tuple) else output + loss.backward() + forward_outputs.append(output) + return forward_outputs + + def infer_step(self, samples: list[Any]) -> list[Any]: + forward_outputs = [] + for sample in samples: + output = model(sample) + forward_outputs.append(output) + return forward_outputs + + def parameters_for_calc_gnorm(self): + parallel_modules = [m for m in model.modules() if isinstance(m, nnscaler.ParallelModule)] + + params_info = [] + for module in parallel_modules: + params_info.extend(module.parameters_for_calc_gnorm()) + + non_parallel_module_reducer: Reducer = optimizer._non_parallel_module_reducer + if non_parallel_module_reducer: + param_info = ParamsInfo( + non_parallel_module_reducer.ranks, + non_parallel_module_reducer.parameters_for_optimizer(), + [], + non_parallel_module_reducer.zero_ngroups + ) + params_info.append(param_info) + + return params_info + + model.train_step = types.MethodType(train_step, model) + model.infer_step = types.MethodType(infer_step, model) + model.parameters_for_calc_gnorm = types.MethodType(parameters_for_calc_gnorm, model) + return model + + +def parallelize_model(trainer_args: TrainerArgs, dummy_input: dict[str, Any], load_module: bool): + tracing_weights = None + if trainer_args.tracing_from_weights: + tracing_weights = torch.load(trainer_args.tracing_from_weights) + + def _new_adapter(parallel_module=None): + return ModuleParallelizeConfigAdapter( + trainer_args, parallel_module, + tracing_weights=tracing_weights + ) + + if not trainer_args.model.parallel_modules: + # parallelize the whole model + return _new_adapter().parallelize(dummy_input, load_module=load_module) + + if not load_module: + for m in trainer_args.model.parallel_modules: + _new_adapter(m).parallelize(dummy_input, load_module=False) + return + + parallel_sub_modules = { + load_type(m.type): m + for m in trainer_args.model.parallel_modules + } + + def _default_new(cls, *args, **kwargs): + return object.__new__(cls) + + # mock the __new__ method of sub modules to replace them with parallelized version + # Please note mocking __new__ is very dangerous and error-prone + # And once you set it, you can never restore it + # Here we use _default_new to restore it, + # Setting it to object.__new__ will be wrong + # Deleting the __new__ method will also be wrong + # See more https://github.com/python/cpython/issues/105888 + def _patch_new(): + for m in parallel_sub_modules: + m.__new__ = __parallel__new__ + + def _restore_new(): + for m in parallel_sub_modules: + m.__new__ = _default_new + + # parallelize modules hook + def __parallel__new__(cls, *args, **kwargs): + try: + _restore_new() + # it can go here when a subclass module of a parallelized module is instantiated + if cls not in parallel_sub_modules: + return cls.__new__(cls) + else: + parallel_module_config = parallel_sub_modules[cls] + adapter = _new_adapter(parallel_module_config) + # fork the random state to + # make sure the modules after parallelized module + # are the same in all devices. + with fork_rng(): + return adapter.parallelize(dummy_input, load_module=True) + finally: + _patch_new() + + _patch_new() + try: + return trainer_args.to_precision(trainer_args.create_model()) + finally: + _restore_new() diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 62f88063..0834d906 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -23,10 +23,10 @@ import nnscaler from nnscaler.utils import enforce_zero_num_worker, is_running_distributed -import nnscaler.utils -from .trainer_args import AggregatedOutputs, TrainerArgs +from .trainer_args import AggregatedOutputs, TrainerArgs, fix_input from .train_hook import AggregatedTrainHook, TrainHook +from .mixed_module import parallelize_model, mixin_module logger = logging.getLogger(__name__) @@ -105,29 +105,7 @@ def run(self): self._train() def _fix_input(self, input): - if isinstance(input, dict): - return {k: self._fix_input(v) for k, v in input.items()} - elif isinstance(input, list): - return [self._fix_input(v) for v in input] - elif isinstance(input, tuple): - return tuple(self._fix_input(v) for v in input) - elif isinstance(input, torch.Tensor): - if input.is_floating_point() and self.train_args.input_dtype is not None: - return input.to(self.train_args.input_dtype).cuda() - else: - return input.cuda() - return input - - def _create_dummy_forward_args(self): - assert self.dummy_input is not None, "dummy_input is not set" - assert self.train_args.model_type is not None, "model_type is not set" - - arg_names = list( - inspect.signature( - inspect.unwrap(getattr(self.train_args.model_type, 'forward')) - ).parameters.keys() - ) - return {arg_names[1]: self.dummy_input} # arg_names[0] is self + return fix_input(input, self.train_args.input_dtype) def _load_dummy_input(self): with enforce_zero_num_worker(DataLoader): @@ -147,33 +125,6 @@ def _setup(self): else: logging.getLogger().setLevel(logging.WARNING) - def _create_model(): - model = self.train_args.create_model() - if self.train_args.param_dtype == self.train_args.buffer_dtype: - if self.train_args.param_dtype is not None: - model = model.to(self.train_args.param_dtype) - else: - # separate param and buffer dtype - # TODO: a little hacky. A better way? - # 3 kinds of tensors are converted in Module._apply: - # model parameters, its grad, and buffer - # param_dtype controls the first two, (but grad is `None` here) - # and buffer_dtype controls the last one - buf_ids = { id(buf) for buf in model.buffers(recurse=True) } - if self.train_args.param_dtype is not None: - model._apply( - lambda t: t.to(self.train_args.param_dtype) - if t.is_floating_point() and id(t) not in buf_ids - else t) - if self.train_args.buffer_dtype is not None: - model._apply( - lambda t: t.to(self.train_args.buffer_dtype) - if t.is_floating_point() and id(t) in buf_ids - else t) - if self.train_args.tracing_from_weights: - model.load_state_dict(torch.load(self.train_args.tracing_from_weights)) - return model - # create dataset and dataloader for stage in ['train', 'val', 'test']: self.dataset[stage] = self.train_args.create_dataset(stage) @@ -194,34 +145,7 @@ def _create_model(): f"You can specify `drop_last=True` in DataLoader to fix this problem." ) - # setup compute config - compute_config = copy.deepcopy(self.train_args.compute_config) - compute_config.pas_config['__pas_name'] = self.train_args.pas_policy - # autodist configs - compute_config.pas_config['update_freq'] = self.train_args.update_freq - compute_config.pas_config['use_bf16'] = self.train_args.param_dtype == torch.bfloat16 - compute_config.pas_config['use_fp16'] = self.train_args.param_dtype == torch.float16 - - compute_config.user_config['__from_trainer_args'] = { - 'mbs': self.train_args.micro_batch_size, - 'gbs': self.train_args.global_batch_size, - 'precision': self.train_args.precision, - 'model_args': self.train_args.model.args, - } - - # parallalize model - pmodel_class = nnscaler.parallelize( - self.train_args.model_type, - self._create_dummy_forward_args(), - self.train_args.resolved_pas_policy, - compute_config, - module_fn=_create_model, - gen_savedir=self.train_args.gen_savedir, - reuse=self.train_args.gen_reuse, - instance_name=self.train_args.instance_name, - broadcast_strategy=self.train_args.broadcast_strategy, - load_module=not compile_only, - ) + pmodel = parallelize_model(self.train_args, self.dummy_input, load_module=not compile_only) if compile_only: return @@ -244,9 +168,11 @@ def _create_model(): self.max_train_steps = self.total_train_steps_per_epoch * self.train_args.max_epochs _, self.sync_group = self.train_args.compute_config.get_sync_group() - self.model = pmodel_class() + self.model = pmodel self.model.cuda() self.optimizer = self.train_args.create_parallel_optimizer(self.model) + # unify the interface of ParallelModule and partial-parallelized model + self.model = mixin_module(self.model, self.optimizer) # Here we carefully scale down the gradient locally with 1/scale_factor before reduce, # (the reduce op is `sum` by default, follow torch's c10d, grad is divided by scaling_factor before allreduce) # and scale up the gradient after reduce diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index e3d01cdf..58a1d00f 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, replace import importlib from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union from typing_extensions import get_args from pathlib import Path import logging +import inspect import copy import os import builtins @@ -18,7 +19,8 @@ import yaml import torch -from nnscaler.utils import transform_recursively +import nnscaler +from nnscaler.utils import fields, transform_recursively from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule @@ -33,6 +35,110 @@ logger = logging.getLogger(__name__) +_TENSOR_TYPE = Literal['param', 'buffer', 'input'] +_PRECISION_TYPE = Literal['fp32', 'fp16', 'bf16', 'none'] +_PRECISION_MAP = { + 'fp32': torch.float32, + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + 'none': None # as it is. no conversion will happen. +} + + +def _get_tensor_dtype(precision: Dict[_TENSOR_TYPE, _PRECISION_TYPE], tensor_type: _TENSOR_TYPE) -> torch.dtype: + return _PRECISION_MAP[precision[tensor_type]] + + +def _to_precision(module: torch.nn.Module, precision: Dict[_TENSOR_TYPE, _PRECISION_TYPE]): + param_dtype = _get_tensor_dtype(precision, 'param') + buffer_dtype = _get_tensor_dtype(precision, 'buffer') + + if param_dtype == buffer_dtype: + if param_dtype is not None: + module = module.to(param_dtype) + else: + # separate param and buffer dtype + # TODO: a little hacky. A better way? + # 3 kinds of tensors are converted in Module._apply: + # model parameters, its grad, and buffer + # param_dtype controls the first two, (but grad is `None` here) + # and buffer_dtype controls the last one + buf_ids = { id(buf) for buf in module.buffers(recurse=True) } + if param_dtype is not None: + module._apply( + lambda t: t.to(param_dtype) + if t.is_floating_point() and id(t) not in buf_ids + else t) + if buffer_dtype is not None: + module._apply( + lambda t: t.to(buffer_dtype) + if t.is_floating_point() and id(t) in buf_ids + else t) + + return module + + +def _resolve_precision(precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE]]): + supported_precision_type = get_args(_PRECISION_TYPE) + supported_tensor_type = get_args(_TENSOR_TYPE) + if not precision: + precision = 'none' + if isinstance(precision, str): + precision = {k: precision for k in supported_tensor_type} + for tensor_type in supported_tensor_type: + if tensor_type not in precision: + precision[tensor_type] = 'none' + if precision[tensor_type] not in supported_precision_type: + raise ValueError(f"Invalid precision {precision[tensor_type]} for {tensor_type}") + if any(k not in supported_tensor_type for k in precision): + raise ValueError(f"Invalid tensor type found in {precision.keys()}") + + return precision + + +def fix_input(input, input_dtype=None): + if isinstance(input, dict): + return {k: fix_input(v, input_dtype) for k, v in input.items()} + elif isinstance(input, list): + return [fix_input(v, input_dtype) for v in input] + elif isinstance(input, tuple): + return tuple(fix_input(v, input_dtype) for v in input) + elif isinstance(input, torch.Tensor): + if input.is_floating_point() and input_dtype is not None: + return input.to(input_dtype).cuda() + else: + return input.cuda() + return input + + +class PrecisionMixin: + @property + def param_dtype(self): + return _get_tensor_dtype(self.precision, 'param') + + @property + def buffer_dtype(self): + return _get_tensor_dtype(self.precision, 'buffer') + + @property + def input_dtype(self): + return _get_tensor_dtype(self.precision, 'input') + + def fix_input(self, input): + return fix_input(input, input_dtype=self.input_dtype) + + def to_precision(self, module): + return _to_precision(module, self.precision) + + +class PolicyMixin: + @property + def resolved_pas_policy(self): + if self.pas_policy in _PREDEFINED_POLICIES: + return self.pas_policy + return load_type(self.pas_policy) + + def load_type(type_name: str): """ Load function/class from its full qualified name @@ -79,12 +185,95 @@ class AggregatedOutputs: aggregated_outputs: Any = None +@dataclass(frozen=True) +class OptionalComputeConfig: + constant_folding: Optional[bool] = None + trace_strategy: Optional[str] = None + use_zero: Optional[bool] = None + zero_ngroups: Optional[int] = None + zero_use_reduce_scatter: Optional[bool] = None + use_async_reducer: Optional[bool] = None + reducer_bucket_cap_mb: Optional[float] = None + + pas_config: Optional[Dict[str, Any]] = None + user_config: Optional[Dict[str, Any]] = None + + def resolve(self, compute_config: ComputeConfig) -> ComputeConfig: + replace_values = { + k: v for k, v in asdict(self).items() + if v is not None + } + resolved_values = asdict(compute_config) + resolved_values.update(replace_values) + resolved_values[fields(ComputeConfig).use_end2end] = False + return ComputeConfig(**resolved_values) + + @dataclass -class ModelConfig: +class ModuleParallelizeConfig: + # The type to be parallelized + # Please note if you specify this + # pipeline parallelism will be disabled, and you must ensure ComputeConfig.use_end2end is False type: str = None args: Dict[str, Any] = field(default_factory=dict) + # the full qualified name of the function to generate dummy forward args + # Its type should be `Callable[[TrainerArgs],Dict[str, Any]]` + forward_args_gen_fn: str = None + # the model state dict file for tracing. + # It is only used in tracing to serve as the initial state dict of the model. + tracing_from_weights: str = None + # the prefix in the state dict (loaded from trainer_args.tracing_from_weights) to be used for tracing + tracing_from_weights_prefix: str = None + + # For the following config, If None, the config of the trainer_args will be used + compute_config: Optional[OptionalComputeConfig] = None + gen_savedir: Optional[str] = None + gen_reuse: Optional[str] = None + pas_policy: Optional[str] = None + broadcast_strategy: Optional[str] = None + instance_name: Optional[str] = None + precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None + + def __post_init__(self): + if not self.type: + raise ValueError("type is required") + if not self.forward_args_gen_fn: + raise ValueError("forward_args_gen_fn is required") + + if self.tracing_from_weights and self.tracing_from_weights_prefix: + raise ValueError("tracing_from_weights and tracing_from_weights_prefix must not be used together") + + if self.precision is not None: + self.precision = _resolve_precision(self.precision) + + @property + def model_type(self): + return load_type(self.type) + + def create_model(self, trainer_args: 'TrainerArgs') -> torch.nn.Module: + kwargs = trainer_args.create_kwarg(self.args) + return self.model_type(**kwargs) + + def create_dummy_forward_args(self, trainer_args: 'TrainerArgs') -> dict[str, Any]: + forward_args_gen_fn = load_type(self.forward_args_gen_fn) + return forward_args_gen_fn(trainer_args) +@dataclass +class ModelConfig: + type: str = None + args: dict[str, Any] = field(default_factory=dict) + # if parallel_modules is not empty, + # these modules will be parallelized instead of the whole model + # and sub modules (in the list of `parallel_modules`) in the model + # will be replaced with parallelized version + parallel_modules: list[ModuleParallelizeConfig] = field(default_factory=list) + + def __post_init__(self): + parallel_sub_modules = [load_type(m.type) for m in self.parallel_modules] + if set(parallel_sub_modules) != set(parallel_sub_modules): + raise ValueError(f"parallelized sub modules must be unique") + @dataclass class OptimizerConfig: type: str = None @@ -287,18 +476,8 @@ def __init__(self, hook_config: HookMapConfig): setattr(self, k, load_type(v)) -_TENSOR_TYPE = Literal['param', 'buffer', 'input'] -_PRECISION_TYPE = Literal['fp32', 'fp16', 'bf16', 'none'] -_PRECISION_MAP = { - 'fp32': torch.float32, - 'fp16': torch.float16, - 'bf16': torch.bfloat16, - 'none': None # as it is. no conversion will happen. -} - - @dataclass -class TrainerArgs: +class TrainerArgs(PrecisionMixin, PolicyMixin): compute_config: ComputeConfig = None gen_savedir: str = './.nnscaler' @@ -390,24 +569,25 @@ def __post_init__(self): if self.broadcast_strategy not in [e.value for e in BroadcastGenFilesStrategy]: raise ValueError(f"Invalid broadcast_strategy {self.broadcast_strategy}") - supported_precision_type = get_args(_PRECISION_TYPE) - supported_tensor_type = get_args(_TENSOR_TYPE) - if not self.precision: - self.precision = 'none' - if isinstance(self.precision, str): - self.precision = {k: self.precision for k in supported_tensor_type} - for tensor_type in supported_tensor_type: - if tensor_type not in self.precision: - self.precision[tensor_type] = 'none' - if self.precision[tensor_type] not in supported_precision_type: - raise ValueError(f"Invalid precision {self.precision[tensor_type]} for {tensor_type}") - if any(k not in supported_tensor_type for k in self.precision): - raise ValueError(f"Invalid tensor type found in {self.precision.keys()}") + self.precision = _resolve_precision(self.precision) if not self.max_epochs and not self.max_train_steps: raise ValueError("max_epochs or max_train_steps is required") + if not self.model.type: raise ValueError("model type is required") + + for m in self.model.parallel_modules: + if m.compute_config: + # will raise ValueError if m.compute_config is invalid when combining with the global compute_config + m.compute_config.resolve(self.compute_config) + + if load_type(m.type) == self.model_type: + raise ValueError(f"parallelized sub module {m.type} cannot be the same as the model type in trainer args") + + if m.tracing_from_weights_prefix and not self.tracing_from_weights: + raise ValueError("`tracing_from_weights` is required when `tracing_from_weights_prefix` is specified") + if not self.optimizer.type: raise ValueError("optimizer type is required") if not self.dataset.type: @@ -419,6 +599,13 @@ def __post_init__(self): if self.lr_scheduler and not self.lr_scheduler.type: raise ValueError("lr_scheduler type is required") + if self.seed is None and self.init_env_fn is None: + logger.warning( + "Neither `seed` nor `init_env_fn` is not provided. " + "The training may not be reproducible " + "and the model weights on different devices can be different." + ) + @classmethod def from_cli(cls, argv: List[str]) -> 'TrainerArgs': d = {} @@ -485,12 +672,6 @@ def resolved_aggregate_outputs_fn(self): return None return load_type(self.optimizer.aggregate_outputs_fn) - @property - def resolved_pas_policy(self): - if self.pas_policy in _PREDEFINED_POLICIES: - return self.pas_policy - return load_type(self.pas_policy) - @property def scaling_factor(self): return self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus @@ -507,18 +688,6 @@ def enable_log_progress(self): def compile_mode(self) -> bool: return self.run_mode == 'compile' - @property - def param_dtype(self) -> torch.dtype: - return _PRECISION_MAP[self.precision['param']] - - @property - def buffer_dtype(self) -> torch.dtype: - return _PRECISION_MAP[self.precision['buffer']] - - @property - def input_dtype(self) -> torch.dtype: - return _PRECISION_MAP[self.precision['input']] - def init_env(self, trainer: 'Trainer'): if self.seed is not None: import random @@ -536,10 +705,10 @@ def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model.args) return self.model_type(**kwargs) - def create_parallel_optimizer(self, parallel_model: ParallelModule): + def create_parallel_optimizer(self, parallel_model: torch.nn.Module): kwargs = self.create_kwarg(self.optimizer.args) optimizer_class = load_type(self.optimizer.type) - return build_optimizer(parallel_model, optimizer_class, **kwargs) + return build_optimizer(parallel_model, optimizer_class, self.compute_config, **kwargs) def create_dataset(self, stage='train'): dataset_args = getattr(self.dataset, f'{stage}_args') diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 06a2b1b3..0eda7f57 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -230,6 +230,12 @@ def optimizer_dedup_group_size(self) -> int: else: return self.plan_ngpus + @property + def max_bucket_size_bytes(self) -> Optional[int]: + return int(self.reducer_bucket_cap_mb * 1024 * 1024) \ + if self.reducer_bucket_cap_mb \ + else None + def get_sync_group(self) -> Tuple[List[int], torch.distributed.ProcessGroup]: """ Get sync group for the current rank. @@ -317,8 +323,7 @@ def _compile_flags(compute_config: ComputeConfig): return _flags( CompileFlag, async_reducer=compute_config.use_async_reducer, reducer_op='sum', - max_reducer_bucket=int(compute_config.reducer_bucket_cap_mb * 1024 * 1024) - if compute_config.reducer_bucket_cap_mb else None, + max_reducer_bucket=compute_config.max_bucket_size_bytes, async_comm=False, use_zero=compute_config.use_zero, zero_ngroups=compute_config.zero_ngroups, @@ -468,6 +473,7 @@ class RegenStatus(Enum): NONE = 'none' # nothing is regenerated. ALL = 'all' # everything is regenerated, including graph and code CODE = 'code' # only code is regenerated. + ERROR = 'error' # error occurs during generation. def _prepare_namespace( @@ -1014,26 +1020,38 @@ def __init__(self, init_params=True): if torch.distributed.is_initialized(): _ = DeviceGroup() - # generate code only in node0 - # if it is not in a torchrun environment, just generate. - if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - outdir, reusable = _prepare_and_check_reusable(gen_savedir, module_class, compute_config, instance_name, reuse) - if not reusable: - config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE - ComputeConfig.safe_dump_to_file(compute_config, config_file) # always refresh compute config - with _compile_flags(compute_config): - regen_status = _gencode( - module_or_module_class, - dummy_forward_args, - pas_policy, - compute_config, - outdir, - module_dtype=module_dtype, - module_fn=module_fn, - ) - else: - regen_status = RegenStatus.NONE - logger.info(f"Reuse generated code in {outdir}") + # try...finally to ensure the barrier is called + # even if an exception is raised in the middle of the code generation. + try: + # generate code only in node0 + # if it is not in a torchrun environment, just generate. + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + outdir, reusable = _prepare_and_check_reusable(gen_savedir, module_class, compute_config, instance_name, reuse) + if not reusable: + config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE + ComputeConfig.safe_dump_to_file(compute_config, config_file) # always refresh compute config + with _compile_flags(compute_config): + regen_status = _gencode( + module_or_module_class, + dummy_forward_args, + pas_policy, + compute_config, + outdir, + module_dtype=module_dtype, + module_fn=module_fn, + ) + else: + regen_status = RegenStatus.NONE + logger.info(f"Reuse generated code in {outdir}") + except Exception as e: + regen_status = RegenStatus.ERROR + regen_exception = e + else: + # if the code generation is successful, set `regen_exception` to `None` in all nodes + # If the code generation is failed, `regen_exception` will be set to `None` in non-zero rank nodes. + # Please note exception is not broadcasted to other nodes + # because it may contain unpicklable objects. + regen_exception = None if torch.distributed.is_initialized(): # code generation can take very long time (for example, over 1 hour) @@ -1041,11 +1059,6 @@ def __init__(self, init_params=True): # because the default timeout for nccl is 30 minutes # (we can't control the timeout setting if torch.distributed is not initialized by us) DeviceGroup().long_barrier() - - if broadcast_strategy != BroadcastGenFilesStrategy.NONE: - if not torch.distributed.is_initialized(): # we only support loading in torchrun environment - raise RuntimeError("Broadcast generated files failed: torch.distributed is not initialized.") - torch.distributed.barrier() # sync regen_status curr_rank = torch.distributed.get_rank() if curr_rank == 0: @@ -1059,7 +1072,16 @@ def __init__(self, init_params=True): if curr_rank != 0: regen_status = sent_obj[0] - # narrow down broadcast_strategy according to regen_status + # all nodes will raise an exception if the code generation is failed. + if regen_status == RegenStatus.ERROR: + raise RuntimeError("Reuse generated code failed.") from regen_exception + + if broadcast_strategy != BroadcastGenFilesStrategy.NONE: + if not torch.distributed.is_initialized(): # we only support loading in torchrun environment + raise RuntimeError("Broadcast generated files failed: torch.distributed is not initialized.") + torch.distributed.barrier() + + # narrow down broadcast_strategy according to regen_status if regen_status == RegenStatus.NONE: # we don't need to broadcast anything broadcast_strategy = BroadcastGenFilesStrategy.NONE @@ -1226,7 +1248,7 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], - *args, + compute_config: Optional[ComputeConfig] = None, **kwargs, ) -> Union[OptimizerT, ParallelOptimizer]: """ @@ -1251,7 +1273,9 @@ def build_optimizer( optimizer_fn (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): It can be the optimizer class or optimizer factory function. The first parameter of the optimizer_fn should be the parameters of the module. - *args: other args for `optimizer_fn` besides parameters. + compute_config (Optional[ComputeConfig]): + The config will be used to generate communication reducer. + If it is None, Default configuration will be used when creating reducer for non-parallel modules. **kwargs: the kwargs for optimizer constructor Returns: @@ -1282,13 +1306,24 @@ def build_optimizer( for i in range(1, len(compute_configs)): if compute_configs[i].gpu_config != compute_configs[0].gpu_config: raise RuntimeError("All ParallelModules should have the same gpu_config.") + if compute_config and compute_config.gpu_config != compute_configs[0].gpu_config: + raise RuntimeError("All ParallelModules should have the same gpu_config.") plan_ngpus, runtime_ngpus = compute_configs[0].plan_ngpus, compute_configs[0].runtime_ngpus # we need to add all parameters of non-parallel modules to a reducer to reduce grads # if there are non-parallel parameters if plan_ngpus != runtime_ngpus and non_parallel_modules and any(p.numel() for m in non_parallel_modules for p in m.parameters(False)): group, _ = compute_configs[0].get_sync_group() - non_parallel_module_reducer = Reducer(group) + reducer_config = {} + if compute_config: + reducer_config = { + 'async_op': compute_config.use_async_reducer, + 'zero': compute_config.use_zero, + 'max_bucket_size_bytes': compute_config.max_bucket_size_bytes, + 'zero_use_reduce_scatter': compute_config.zero_use_reduce_scatter, + 'zero_ngroups': compute_config.zero_ngroups, + } + non_parallel_module_reducer = Reducer(group, **reducer_config) for m in non_parallel_modules: for param in m.parameters(recurse=False): # only add leaf parameters to avoid duplicate non_parallel_module_reducer.add_param(param) @@ -1320,7 +1355,7 @@ def _local_parameters(module: torch.nn.Module): opt_module_locs[name].count += 1 yield param - optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), *args, **kwargs) + optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), **kwargs) optimizer._non_parallel_module_reducer = non_parallel_module_reducer optimizer._extra_state = OptimizerExtraState( rank=torch.distributed.get_rank(), @@ -1609,8 +1644,30 @@ def merge_state_dicts( if not module_state_dicts: raise ValueError("model_state_dicts should not be empty.") + def _get_state_dict_rank(state_dict: Dict[str, Any]) -> int: + for k in state_dict: + if k.split('.')[-1] == ParallelModule.EXTRA_STATE_KEY: + return state_dict[k]['rank'] + raise ValueError("Invalid state dict: no rank found.") + + def _sort_state_dicts(state_dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + sorted_state_dicts =[None] * len(state_dicts) + for state_dict in state_dicts: + rank = _get_state_dict_rank(state_dict) + if sorted_state_dicts[rank] is not None: + raise ValueError(f"Duplicate rank {rank} in state_dicts.") + if rank >= len(state_dicts): + raise ValueError(f"Invalid rank {rank} in state_dicts.") + sorted_state_dicts[rank] = state_dict + return sorted_state_dicts + + # sort state dicts by rank + module_state_dicts = _sort_state_dicts(module_state_dicts) + pm_extra_states, pm_state_dicts, ret_state_dict = _get_parallel_module_state_dict_info(module_state_dicts) if optimizer_state_dicts is not None: + # sort state dicts by rank + optimizer_state_dicts = _sort_state_dicts(optimizer_state_dicts) opt_extra_states, opt_state_dicts, ret_opt_state_dict = _get_optimizer_state_dict_info(optimizer_state_dicts) # the new optimizer state dict for ParallelModules # key: the parallel module location in the optimizer state @@ -2260,7 +2317,7 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g else: step_stack = torch.zeros( len(state_indexes), - dtype=optimizer_state_dict['state'][0]['step'].dtype, + dtype=optimizer_state_dict['state'][state_indexes[0]]['step'].dtype, device=torch.cuda.current_device() ) torch.distributed.broadcast(step_stack, src=src_rank, group=curr_parallel_group) @@ -2318,7 +2375,7 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): module._warn_uninitialized_non_persistent_buffers(raise_error=True) # we have a special optimization for ParallelModule - params = module.parameters_for_broadcast() if isinstance(module, ParallelModule) else module._parameters.values() + params = module.parameters_for_broadcast() if isinstance(module, ParallelModule) else list(module.parameters(False)) logging.info(f'Inplace broadcasting {len(params)} parameters...') for i, param in enumerate(params): torch.distributed.broadcast(param.data, src=src_rank, group=curr_parallel_group) @@ -2326,8 +2383,9 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): # NOTE: may batch buffers for efficient broadcast, # current implementation is the most memory efficient way. - logging.info(f'Inplace broadcasting {len(module._buffers)} buffers...') - for _, buffer in module._buffers.items(): + buffers = list(module.buffers(False)) + logging.info(f'Inplace broadcasting {len(buffers)} buffers...') + for buffer in buffers: torch.distributed.broadcast(buffer.data, src=src_rank, group=curr_parallel_group) if isinstance(module, ParallelModule): diff --git a/tests/cli/common.py b/tests/cli/common.py index 39e64d20..86b4b00b 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -2,9 +2,81 @@ # Licensed under the MIT License. import torch +from torch import nn from torch.utils.data import DataLoader, Dataset +from typing import Dict +from nnscaler.cli.trainer_args import TrainerArgs from tests.parallel_module.test_end2end import MLP +from tests.utils import init_random as init_random_fn + + +class MixModuleMLP(nn.Module): + def __init__(self, dim: int, nlayers: int, init_random: bool = True): + super().__init__() + if init_random: + init_random_fn() + self.layers = torch.nn.ModuleList([]) + for _ in range(nlayers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, input): + x = input + for layer in self.layers: + x = layer(x) + return x + + +class MixModuleMLP2(MixModuleMLP): + pass + + +class MixModuleMLP3(MixModuleMLP): + pass + + +class MixModuleMLP4(MixModuleMLP): + pass + + +class MixModuleMLPWithLoss(nn.Module): + def __init__(self, dim: int, nlayers: int, init_random: bool = True): + super().__init__() + self.mlp = MixModuleMLP(dim, nlayers, init_random=init_random) + self.loss_fn = nn.BCELoss() + + def forward(self, input, target): + x = self.mlp(input) + x = torch.sigmoid(x) + loss = self.loss_fn(x, target) + return loss + + +class MixedModule(torch.nn.Module): + def __init__(self, dim: int, nlayers: int, init_random: bool = True): + super().__init__() + self.mlp0 = MixModuleMLP(dim, nlayers, init_random=init_random) + self.mlp1 = MixModuleMLP2(dim, nlayers, init_random=init_random) + self.mlp2 = MixModuleMLP3(dim, nlayers, init_random=init_random) + self.mlploss = MixModuleMLPWithLoss(dim, nlayers, init_random=init_random) + + def forward(self, data: Dict[str, torch.Tensor]): + x = data['data'] + target = data['target'] + x = self.mlp0(x) + x = self.mlp1(x) + x = self.mlp2(x) + return self.mlploss(x, target) + + +def forward_args_gen_fn(trainer_args: TrainerArgs): + return { + 'input': + torch.randn(trainer_args.dataset.train_args['size'], trainer_args.dataset.train_args['dim']), + 'target': + torch.rand(trainer_args.dataset.train_args['size'], trainer_args.dataset.train_args['dim']), + } + class SimpleDataset(Dataset): def __init__(self, dim: int, size: int = 100): diff --git a/tests/cli/test_train_args.py b/tests/cli/test_train_args.py index 183ccfbe..f4b1b62d 100644 --- a/tests/cli/test_train_args.py +++ b/tests/cli/test_train_args.py @@ -4,7 +4,7 @@ import pytest import nnscaler -from nnscaler.cli.trainer_args import load_type +from nnscaler.cli.trainer_args import load_type, ComputeConfig, OptionalComputeConfig def test_load_type(): @@ -26,3 +26,14 @@ def test_load_type(): with pytest.raises(RuntimeError): load_type('nnscaler.cli.trainer_args.TrainerArgs.not_exist_name') + + +def test_compute_config_merge(): + cc = ComputeConfig(1, 2, constant_folding=True, use_end2end=True, use_zero=True) + occ = OptionalComputeConfig(constant_folding=False, use_zero=False) + rcc = occ.resolve(cc) + assert rcc == ComputeConfig(1, 2, constant_folding=False, use_end2end=False) + + occ2 = OptionalComputeConfig(zero_ngroups=-1) + with pytest.raises(ValueError): + occ2.resolve(cc) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 7981b952..32f66764 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -8,11 +8,13 @@ import pytest import torch.distributed +from nnscaler import merge_state_dicts from nnscaler.cli.trainer import Trainer from nnscaler.cli.trainer_args import AggregatedOutputs, TrainerArgs -from tests.parallel_module.common import assert_equal -from tests.utils import replace_all_device_with +from tests.parallel_module.common import assert_equal, assert_close +from tests.utils import init_random, replace_all_device_with, clear_parallel_cache from ..launch_torchrun import launch_torchrun +from .common import MixedModule, MixModuleMLP, MixModuleMLP3 def trainer_logging_worker(save_dir): @@ -85,17 +87,75 @@ def test_trainer_compile_worker(tmp_path): shutil.rmtree(gen_savedir) -def trainer_resume_worker(save_dir, save_type, bf16): +def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) gen_savedir = save_dir / 'gen' ckpt_savedir = save_dir / 'ckpt' - # train 4 epcho in one time + optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' \ if bf16 == 'Mixed' \ else 'torch.optim.Adam' use_zero = save_type == 'sharded' + if parallel_type == 0: + additional_args = [] + elif parallel_type == 1: + # 1. parallelize MixModuleMLP2 (self.mlp1) + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP2', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + elif parallel_type == 2: + # 2. parallelize MixModuleMLP (self.mlp0, self.mlploss.mlp) + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + elif parallel_type == 3: + # 3. parallelize MixModuleMLP and MixModuleMLP3 (self.mlp0, self.mlploss.mlp, self.mlp2) + # We will use different compute_config for the two parallelized modules + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.compute_config.use_zero', 'False' if use_zero else 'True', + '--model.parallel_modules.0.compute_config.constant_folding', 'False', + '--model.parallel_modules.0.pas_policy', 'tp', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + '--model.parallel_modules.1.type', 'tests.cli.common.MixModuleMLP3', + '--model.parallel_modules.1.args.dim', '16', + '--model.parallel_modules.1.args.nlayers', '16', + '--model.parallel_modules.1.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + elif parallel_type == 4: + # 4. parallelize MixModuleMLP and MixModuleMLPWithLoss (self.mlp0, self.mlploss) + # Note MixModuleMLP is also a member of MixModuleMLPWithLoss + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.compute_config.use_zero', 'False' if use_zero else 'True', + '--model.parallel_modules.0.compute_config.constant_folding', 'False', + '--model.parallel_modules.0.pas_policy', 'tp', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + '--model.parallel_modules.1.type', 'tests.cli.common.MixModuleMLPWithLoss', + '--model.parallel_modules.1.args.dim', '16', + '--model.parallel_modules.1.args.nlayers', '16', + '--model.parallel_modules.1.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + else: + raise ValueError(f'parallel_type {parallel_type} is not supported') + + # train 4 epcho in one time trainer = Trainer([ '-f', config_path, '--precision', 'bf16' if bf16 else 'none', @@ -110,6 +170,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.save_dir', str(ckpt_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + *additional_args, ]) trainer.run() ckpt_files = set(ckpt_savedir.glob('**/*.ckpt')) @@ -133,6 +194,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + *additional_args, ]) trainer.run() ckpt0_files0 = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} @@ -153,6 +215,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + *additional_args, ]) trainer.run() ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} @@ -181,6 +244,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + *additional_args, ]) trainer.run() left_files = { @@ -210,6 +274,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.save_dir', str(ckpt1_savedir), '--checkpoint.resume_from', str(ckpt1_savedir / 'merged.pt'), '--checkpoint.keep_last_n_checkpoints', '30', + *additional_args, ]) trainer.run() left_files = { @@ -253,6 +318,19 @@ def test_trainer_resume(tmp_path, save_type, bf16): launch_torchrun(4, trainer_resume_worker, tmp_path, save_type, bf16) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('parallel_type', [1, 2, 3, 4]) +def test_trainer_resume_mixed(tmp_path, parallel_type): + # we will parallelize the sub models in MixedModule + # We have different ways to parallelize the sub models + # 1. parallelize MixModuleMLP2 (self.mlp1) + # 2. parallelize MixModuleMLP (self.mlp0, self.mlploss.mlp) + # 3. parallelize MixModuleMLP and MixModuleMLP3 (self.mlp0, self.mlploss.mlp, self.mlp2) + # 4. parallelize MixModuleMLP and MixModuleMLPWithLoss (self.mlp0, self.mlploss) + # Note MixModuleMLP is also a member of MixModuleMLPWithLoss + launch_torchrun(4, trainer_resume_worker, tmp_path, 'deduped', True, parallel_type) + + def trainer_last_checkpoint_worker(save_dir): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) @@ -513,3 +591,205 @@ def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): ]) trainer.run() torch.distributed.barrier() + + +def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + if parallel_type == 0: + # parallelize the whole MixedModule + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + ] + elif parallel_type == 1: + # 1. parallelize MixModuleMLP2 (self.mlp1) + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP2', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + elif parallel_type == 2: + # 2. parallelize MixModuleMLP (self.mlp0, self.mlploss.mlp) + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + elif parallel_type == 3: + # 3. parallelize MixModuleMLP and MixModuleMLP3 (self.mlp0, self.mlploss.mlp, self.mlp2) + # We will use different compute_config for the two parallelized modules + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.compute_config.use_zero', 'False', + '--model.parallel_modules.0.compute_config.constant_folding', 'False', + '--model.parallel_modules.0.pas_policy', 'tp', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + '--model.parallel_modules.1.type', 'tests.cli.common.MixModuleMLP3', + '--model.parallel_modules.1.args.dim', '16', + '--model.parallel_modules.1.args.nlayers', '16', + '--model.parallel_modules.1.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + elif parallel_type == 4: + # 4. parallelize MixModuleMLP and MixModuleMLPWithLoss (self.mlp0, self.mlploss) + # Note MixModuleMLP is also a member of MixModuleMLPWithLoss + additional_args = [ + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.compute_config.use_zero', 'False', + '--model.parallel_modules.0.compute_config.constant_folding', 'False', + '--model.parallel_modules.0.pas_policy', 'tp', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + '--model.parallel_modules.1.type', 'tests.cli.common.MixModuleMLPWithLoss', + '--model.parallel_modules.1.args.dim', '16', + '--model.parallel_modules.1.args.nlayers', '16', + '--model.parallel_modules.1.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + else: + raise ValueError(f'parallel_type {parallel_type} is not supported') + + # train 4 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--precision', 'fp32', + '--max_epochs', '2', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.use_zero', 'False', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '2', + '--compute_config.use_async_reducer', str(async_reducer), + '--compute_config.reducer_bucket_cap_mb', '1e-6', + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '5', + # '--model.args.dim', '16', + # '--model.args.nlayers', '2', + *additional_args, + ]) + trainer.run() + + torch.distributed.barrier() + + # create merged checkpoint + if trainer.rank == 0: + Trainer.merge_checkpoint(list((ckpt_savedir / 'last').glob('*.ckpt')), save_dir / 'merged.pt') + shutil.rmtree(gen_savedir) + + clear_parallel_cache() + + torch.distributed.barrier() + + +def trainer_correctness_worker_aggregate(tmp_path): + for parallel_type in range(5): + for async_reducer in [False, True]: + print(f'parallel_type={parallel_type}, async_reducer={async_reducer}') + save_dir = tmp_path/f'{parallel_type}-{async_reducer}' + trainer_correctness_worker(save_dir, parallel_type, async_reducer) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_trainer_correctness(tmp_path): + launch_torchrun(2, trainer_correctness_worker_aggregate, tmp_path) + merged_ckpts = {} + for parallel_type in range(5): + for async_reducer in [False, True]: + save_dir = tmp_path/f'{parallel_type}-{async_reducer}' + merged_ckpts[(parallel_type, async_reducer)] = torch.load(save_dir/'merged.pt') + + for parallel_type in range(5): + for async_reducer in [False, True]: + assert_equal( + merged_ckpts[(parallel_type, async_reducer)]['model'], + merged_ckpts[(0, False)]['model'] + ) + assert_equal( + merged_ckpts[(parallel_type, async_reducer)]['optimizer'], + merged_ckpts[(0, False)]['optimizer'] + ) + + +def tracing_from_weights_worker(tmp_path): + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + + init_random(1141) + mixed_module = MixedModule(16, 16, False) + mixed_module_2 = MixedModule(16, 16) + assert not torch.equal(mixed_module.mlp0.layers[0].weight, mixed_module_2.mlp0.layers[0].weight) + assert not torch.equal(mixed_module.mlp2.layers[0].weight, mixed_module_2.mlp2.layers[0].weight) + assert not torch.equal(mixed_module.mlploss.mlp.layers[0].weight, mixed_module_2.mlploss.mlp.layers[0].weight) + + tracing_weights = mixed_module.state_dict() + tracing_from_weights = tmp_path / 'tracing_weights.pt' + torch.save(tracing_weights, tracing_from_weights) + + def _compile(index, *additional_args): + gen_dir = tmp_path / f'gen{index}' + trainer = Trainer([ + '-f', config_path, + '--gen_savedir', str(gen_dir), + '--global_batch_size', '0', + '--max_epochs', '-1', # HACK: will exit without training. + '--max_train_steps', '-1', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + '--broadcast_strategy', 'none', + '--model.type', 'tests.cli.common.MixedModule', + *additional_args, + ]) + trainer.run() + import shutil + shutil.rmtree(gen_dir) + clear_parallel_cache() + return merge_state_dicts([trainer.model.state_dict()])[0] + + model1 = _compile(1) + model3 = _compile(3, '--tracing_from_weights', str(tracing_from_weights)) + model2 = _compile(2) + + + assert_equal(model1, model2) + assert_equal(model1, dict(**mixed_module_2.state_dict())) + assert_equal(model3, dict(**tracing_weights)) + + # parallelize MixModuleMLP2 and MixModuleMLP3 (self.mlp1, self.mlp2) + # We will use different compute_config for the two parallelized modules + additional_args = [ + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP2', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + '--model.parallel_modules.1.type', 'tests.cli.common.MixModuleMLP3', + '--model.parallel_modules.1.args.dim', '16', + '--model.parallel_modules.1.args.nlayers', '16', + '--model.parallel_modules.1.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ] + model4 = _compile(4, *additional_args) + assert_equal(model4, dict(**mixed_module_2.state_dict())) + + model5 = _compile(5, + '--tracing_from_weights', str(tracing_from_weights), + '--model.parallel_modules.0.tracing_from_weights_prefix', 'mlp1', + '--model.parallel_modules.1.tracing_from_weights_prefix', 'mlp2', + *additional_args + ) + for key in tracing_weights: + if key.startswith('mlp1') or key.startswith('mlp2'): + assert torch.equal(model5[key], tracing_weights[key]) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_tracing_from_weights(tmp_path): + launch_torchrun(1, tracing_from_weights_worker, tmp_path) diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml index 272ce791..72e1c202 100644 --- a/tests/cli/trainer_args.yaml +++ b/tests/cli/trainer_args.yaml @@ -11,6 +11,7 @@ micro_batch_size: 2 global_batch_size: 8 max_epochs: 4 max_train_steps: 100 +seed: 0 model: type: tests.cli.common.MLP diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index c21f9421..594a73d5 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -21,7 +21,7 @@ from .common import init_distributed from ..launch_torchrun import launch_torchrun -from ..utils import replace_all_device_with +from ..utils import replace_all_device_with, raises_with_cause def _to_cube_model(module, compute_config, cube_savedir, load_module): return parallelize( @@ -151,7 +151,7 @@ def forward(self, x, y): @replace_all_device_with('cpu') @pytest.mark.parametrize('return_type', [0, 1]) def test_codegen_tuple_return2(return_type): - test_context = nullcontext() if return_type != 0 else pytest.raises(RuntimeError, match='Single tuple outputs.*') + test_context = nullcontext() if return_type != 0 else raises_with_cause(RuntimeError, match='Single tuple outputs.*') with tempfile.TemporaryDirectory() as tempdir, test_context: parallelize( TupleReturnModule2(return_type), @@ -603,7 +603,7 @@ def test_codegen_tensor_slice(): with tempfile.TemporaryDirectory() as tempdir: m = TensorSliceModule() m.train() - with pytest.raises(RuntimeError, match='Tensor is not supported in slice.'): + with raises_with_cause(RuntimeError, match='Tensor is not supported in slice.'): parallelize( m, {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, @@ -903,7 +903,7 @@ def p(cube_dir, use_pipeline, constant_folding, return_type, inference_only=Fals ) p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=0) # should success if use_pipeline: - with pytest.raises(RuntimeError, match='.*Communication generation.*'): + with raises_with_cause(RuntimeError, match='.*Communication generation.*'): # fail for non-tensor IRObject return in pipeline mode p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=1) else: @@ -911,13 +911,13 @@ def p(cube_dir, use_pipeline, constant_folding, return_type, inference_only=Fals p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=1) # should success p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=2) # should success p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=2) # should success - with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): + with raises_with_cause(RuntimeError, match='.*Loss can only be scalar tensor.*'): p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=3) - with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): + with raises_with_cause(RuntimeError, match='.*Loss can only be scalar tensor.*'): p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=3) - with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): + with raises_with_cause(RuntimeError, match='.*Loss can only be scalar tensor.*'): p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=4) - with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): + with raises_with_cause(RuntimeError, match='.*Loss can only be scalar tensor.*'): p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=4) p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=0, inference_only=True) # should success @@ -1222,7 +1222,7 @@ def test_invalid_partition(tmp_path): m = CVModel() m.train() - with pytest.raises(ValueError): + with raises_with_cause(ValueError): parallelize( m, dummy_input, @@ -1756,7 +1756,7 @@ def test_use_none_return(tmp_path): m = IRUseNoneModule() m.train() # it should raise an error, because _op1 has no return value, but it is used in _op4 - with pytest.raises(KeyError): + with raises_with_cause(KeyError): parallelize( m, {'x': torch.randn(128, 64)}, @@ -1826,3 +1826,89 @@ def patched_concrete_trace(*args, **kwargs): SignFx2Op.kOpMap.pop('tests.parallel_module.test_gencode._op6') # should success assert True + + +class InitErrorModule(torch.nn.Module): + def __init__(self): + super().__init__() + raise ValueError('world error') + + def forward(self, input): + pass + + +def _gencode_init_error_worker(tmp_path, without_init_distributed=False): + if not without_init_distributed: + init_distributed() + try: + m_new = parallelize( + InitErrorModule, + { + 'input': torch.randn(2, 3, 32, 32), + }, + 'dp', + ComputeConfig(1, 2), + gen_savedir=tmp_path, + load_module=True + ) + except Exception as e: + assert isinstance(e, RuntimeError) + if without_init_distributed or torch.distributed.get_rank() == 0: + root_cause = e.__cause__ + while root_cause.__cause__ is not None: + root_cause = root_cause.__cause__ + assert isinstance(root_cause, ValueError) + assert root_cause.args[0] == 'world error' + else: + assert e.__cause__ is None + + +@replace_all_device_with('cpu') +def test_codegen_init_error_compile(tmp_path): + _gencode_init_error_worker(tmp_path, without_init_distributed=True) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of GPU devices') +def test_codegen_init__error(tmp_path): + launch_torchrun(2, _gencode_init_error_worker, tmp_path) + + +class ForwardErrorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + raise ValueError('hello error') + + +def _gencode_forward_error_worker(tmp_path, without_init_distributed=False): + if not without_init_distributed: + init_distributed() + try: + m_new = parallelize( + ForwardErrorModule, + { + 'input': torch.randn(2, 3, 32, 32), + }, + 'dp', + ComputeConfig(1, 2), + gen_savedir=tmp_path, + load_module=True + ) + except Exception as e: + assert isinstance(e, RuntimeError) + if without_init_distributed or torch.distributed.get_rank() == 0: + assert isinstance(e.__cause__, ValueError) + assert e.__cause__.args[0] == 'hello error' + else: + assert e.__cause__ is None + + +@replace_all_device_with('cpu') +def test_codegen_forward_error_compile(tmp_path): + _gencode_forward_error_worker(tmp_path, without_init_distributed=True) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of GPU devices') +def test_codegen_forward_error(tmp_path): + launch_torchrun(2, _gencode_forward_error_worker, tmp_path) diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index d2917526..3d046393 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -45,7 +45,6 @@ def _init_params_worker(): for p1, p3 in zip(module1.parameters(), module3.parameters()): assert not torch.equal(p1, p3) - assert torch.all(p3 == 0) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') diff --git a/tests/parallel_module/test_nested.py b/tests/parallel_module/test_nested.py index 9a24c1f0..a75b832a 100644 --- a/tests/parallel_module/test_nested.py +++ b/tests/parallel_module/test_nested.py @@ -10,6 +10,7 @@ from .common import init_distributed from ..launch_torchrun import launch_torchrun +from ..utils import raises_with_cause def _to_cube_model(module, pas, compute_config, cube_savedir): return parallelize( @@ -47,7 +48,7 @@ def __init__(self) -> None: def forward(self, x): return self.module1(x) - with pytest.raises(RuntimeError, match='Parallel modules can not be nested.'): + with raises_with_cause(RuntimeError, match='Parallel modules can not be nested.'): _to_cube_model(Module2(), 'data', ComputeConfig(1, 1), cube_savedir=tempdir) diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index 3b24267c..401acd80 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -13,7 +13,7 @@ from nnscaler.parallel import ReuseType, parallelize, ComputeConfig, _load_parallel_module_class from nnscaler.runtime.module import ParallelModule -from ..utils import new_empty, replace_all_device_with +from ..utils import new_empty, replace_all_device_with, raises_with_cause def _to_cube_model(model_class, compute_config, cube_savedir, reuse, instance_name, load_module=True): @@ -73,7 +73,7 @@ def test_override(): # MATCH | unmatch | raise error _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MATCH, 'm0') - with pytest.raises(RuntimeError, match='.*not empty.*'): + with raises_with_cause(RuntimeError, match='.*not empty.*'): _to_cube_model(MyModule, ComputeConfig(2, 2),tempdir, 'match', 'm0') # MOO | empty | generate @@ -90,19 +90,19 @@ def test_override(): # MOO | imported | raise error _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o2', load_module=True) - with pytest.raises(RuntimeError): + with raises_with_cause(RuntimeError): _to_cube_model(MyModule, ComputeConfig(2, 2),tempdir, ReuseType.MOO, 'o2') # OVERRIDE | imported | raise error - with pytest.raises(RuntimeError): + with raises_with_cause(RuntimeError): _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.OVERRIDE, 'mm0') # OVERRIDE | imported | raise error - with pytest.raises(RuntimeError): + with raises_with_cause(RuntimeError): _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.OVERRIDE, 'test') # OVERRIDE | imported | raise error - with pytest.raises(RuntimeError): + with raises_with_cause(RuntimeError): _to_cube_model(MyModule, ComputeConfig(2, 2),tempdir, ReuseType.OVERRIDE, 'test') # OVERRIDE | empty | generate @@ -168,7 +168,7 @@ def test_override(): g6_module = _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'g6') # Graph | imported | raise error - with pytest.raises(RuntimeError): + with raises_with_cause(RuntimeError): _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'g6') # Graph | unmatch | generate diff --git a/tests/parallel_module/test_shared_param_pipeline.py b/tests/parallel_module/test_shared_param_pipeline.py index 222c0962..86dc4380 100644 --- a/tests/parallel_module/test_shared_param_pipeline.py +++ b/tests/parallel_module/test_shared_param_pipeline.py @@ -19,7 +19,7 @@ from nnscaler.ir.operator import IRFwOperation, IRDataOperation from nnscaler.graph.segment import IRSegment from nnscaler.graph.schedule.predefined import PredefinedSched -from tests.utils import clear_dir_on_rank0, init_random +from tests.utils import clear_dir_on_rank0, init_random, raises_with_cause from tests.launch_torchrun import torchrun from tests.parallel_module.test_gencode import _gencode_contains, print_gencode @@ -404,7 +404,7 @@ def test_shared_param_error(model_cls): trace_data = torch.randn([2, 16], dtype=torch.float32, device=torch.cuda.current_device()) with tempfile.TemporaryDirectory() as tempdir: - with pytest.raises(RuntimeError, match='The weight consumers can either be ALL replicated or ALL partitioned'): + with raises_with_cause(RuntimeError, match='The weight consumers can either be ALL replicated or ALL partitioned'): parallelize( m, {'x': trace_data}, @@ -464,7 +464,7 @@ def test_shared_param_error2(model_cls): trace_data = torch.randn([2, 16], dtype=torch.float32, device=torch.cuda.current_device()) with tempfile.TemporaryDirectory() as tempdir: - with pytest.raises(RuntimeError, match='The weight consumers can either be ALL replicated or ALL partitioned'): + with raises_with_cause(RuntimeError, match='The weight consumers can either be ALL replicated or ALL partitioned'): parallelize( m, {'x': trace_data}, diff --git a/tests/utils.py b/tests/utils.py index 272b02c2..22036f42 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,11 +51,11 @@ def trunc_normal_(tensor: torch.Tensor, mean: float = 0., std: float = 1., a: fl torch.nn.init.constant_(param, 0) -def init_random(): - np.random.seed(1) - torch.manual_seed(1) +def init_random(seed: int = 1): + np.random.seed(seed) + torch.manual_seed(seed) if torch.cuda.is_available(): - torch.cuda.manual_seed(1) + torch.cuda.manual_seed(seed) def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-4) -> bool: @@ -377,3 +377,27 @@ def clear_dir_on_rank0(tempdir): torch.distributed.barrier() if torch.distributed.get_rank() == 0 and tempdir.exists(): shutil.rmtree(tempdir) + + +def clear_parallel_cache(): + """ + Clear all parallel modules in sys.modules + """ + import sys + parallel_modules = [name for name in sys.modules if name.startswith('_parallel_modules')] + for name in parallel_modules: + del sys.modules[name] + + +@contextmanager +def raises_with_cause(exception_type, match=None): + try: + yield + except Exception as e: + cause = e + while cause.__cause__: + cause = cause.__cause__ + assert isinstance(cause, exception_type), f"unexpected cause: {cause}" + assert not match or re.search(match, str(cause)), f"unexpected cause message: {cause}" + else: + raise AssertionError(f"expected exception {exception_type} not raised") From 1b91d7e25fac3b1a1b96246ef5501731a30b9efd Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Wed, 12 Feb 2025 04:49:45 +0000 Subject: [PATCH 1802/1892] Merged PR 2363: Merge github PR and update version --- README.md | 2 +- docs/source/installation.rst | 2 +- docs/source/quickstart.rst | 2 +- nnscaler/autodist/spmd_solver.py | 2 +- .../comp/torch.Tensor.contiguous.json | 60 +++++ .../comp/torch.Tensor.reshape.json | 48 ++++ .../comp/torch.Tensor.view.json | 48 ++++ .../comp/torch.div.json | 60 +++++ .../comp/torch.matmul.json | 228 ++++++++++++++++++ .../comp/torch.nn.functional.dropout.json | 60 +++++ .../comp/torch.nn.functional.linear.json | 180 ++++++++++++++ .../comp/torch.nn.functional.softmax.json | 64 +++++ .../comp/torch.sum.json | 48 ++++ .../comp/torch.transpose.json | 180 ++++++++++++++ .../spmd_solver/test_weight_comm_time.py | 60 +++++ .../spmd_solver/test_weight_comm_time.yaml | 17 ++ 16 files changed, 1057 insertions(+), 4 deletions(-) create mode 100644 tests/autodist/spmd_solver/test_weight_comm_time.py create mode 100644 tests/autodist/spmd_solver/test_weight_comm_time.yaml diff --git a/README.md b/README.md index 88d55bce..4062c2fa 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ nnScaler is a parallelization engine that compiles a Deep neural network (DNN) m # Latest News nnScaler (also known as CUBE as code name) has been adopted by multiple product and research projects, this section includes some of the latest news from the team and partner projects. -* **2025-01-08** nnScaler 0.6 released: https://github.com/microsoft/nnscaler/releases/tag/0.6 +* **2025-02-12** nnScaler 0.7 released: https://github.com/microsoft/nnscaler/releases/tag/0.7 * **2024-10-07** Diff-Transformer utilizes nnScaler for differential attention mechanism: [DIFFERENTIAL TRANSFORMER](https://arxiv.org/abs/2410.05258) * **2024-05-09** YOCO utilizes nnScaler for long-sequence training: [(YOCO)You only cache once: Decoder-decoder architectures for language models](https://arxiv.org/abs/2405.05254) * **2024-04-22** Post training for the long context version of [Phi-3 series](https://arxiv.org/abs/2404.14219) diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d71553c9..2e26a9af 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -12,7 +12,7 @@ The wheel package is hosted on `GitHub release 0 for comm in weight_comm_times) + else: + assert all(comm == 0 for comm in weight_comm_times) + diff --git a/tests/autodist/spmd_solver/test_weight_comm_time.yaml b/tests/autodist/spmd_solver/test_weight_comm_time.yaml new file mode 100644 index 00000000..f316c930 --- /dev/null +++ b/tests/autodist/spmd_solver/test_weight_comm_time.yaml @@ -0,0 +1,17 @@ +- allowed_partition_dims: + - 0,0 + name: torch.nn.functional.linear + parent_module: Attention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + - 0,1 + name: torch.matmul + parent_module: Attention + replica_allowed: false +- allowed_partition_dims: + - 0,0 + - 0,1 + name: torch.nn.functional.softmax + parent_module: Attention + replica_allowed: false From 8c57978323358a9ab7ce882cea5b4876f19924a8 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Wed, 12 Feb 2025 05:15:55 +0000 Subject: [PATCH 1803/1892] Merged PR 2358: [AutoDist] Support pipeline_nstages option Added the option to Llama example. Tested with nightly tests. Together with: https://dev.azure.com/msrasrg/SuperScaler/_git/Fairseq/pullrequest/2354 Config update: `explore_pipeline` is removed from pas config and `pipeline` is removed from autodist config (they are the same despite different names). `pipeline_nstages` can be int or "auto"; (not documented) added autodist config `max_pipeline_bubble_ratio` to pas config, so we can use tiny model for UT without being rejected by autodist. when `pipeline_nstages` is 1, pipeline is disabled; when `pipeline_pivots` is empty and `pipeline_nstages` is "auto", pipeline is disabled; when `pipeline_pivots` is empty and `pipeline_nstages` is a number >= 2, raise; when `pipeline_nstages` is "auto" and `not end2end or use_async_reducer`, pipeline is disabled; (I don't really understand it, but there was an assertion that they cannot be used with `explore_pipeline`) for any other cases, it follows when `explore_pipeline` was True. (sorry) included some irrelevant doc changes --- docs/source/parallel_module.md | 6 +- docs/source/quickstart_internal.rst | 2 +- docs/source/trainer.rst | 17 ++++- docs/source/troubleshooting.rst | 32 ++++++++ examples/llama/README.rst | 6 +- examples/llama/train.py | 21 +++--- nnscaler/autodist/apis.py | 7 +- nnscaler/autodist/autodist_config.py | 22 +++--- nnscaler/autodist/model_graph.py | 2 +- nnscaler/autodist/pipeline_solver.py | 8 +- nnscaler/parallel.py | 4 +- nnscaler/policies.py | 41 ++++++++--- tests/autodist/test_pipeline_nstages.py | 98 +++++++++++++++++++++++++ 13 files changed, 220 insertions(+), 46 deletions(-) create mode 100644 tests/autodist/test_pipeline_nstages.py diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 4bdee9a2..f68ec024 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -44,7 +44,7 @@ ParallelizedLLM = parallelize( - Example 2: Parallelize submodules. -In this case, for non-paralle modules, they are replicated inside unit, and run data parallelism across units. See more details about unit in [Compute Config](###ComputeConfig) section. +In this case, for non-paralle modules, they are replicated inside unit, and run data parallelism across units. See more details about unit in [Compute Config](./trainer) section. ```python import torch @@ -270,7 +270,7 @@ Please note the module can't be parallelize if `Module.forward` has positional-o - `pas_policy` (`Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]`): the pas (partition-assign-schedule) policy, which describes how to place all computations across devices. You need either pass a builtin PAS policy name or a a custom policy function which should take an `IRGraph` and a `ComputeConfig` as input, and return a new `IRGraph` with the PAS policy applied. We have 6 builtin PAS policies: `dp`, `tp`, `pp`, `data`, `hybrid`, and `autodist`. Please note all builtin PAS policies except `autodist` are only for test purpose. The `autodist` policy is the recommended policy for most cases. - For details, please refer to [PAS Policies](#pas-policies) section. + For details, please refer to [PAS Policies](./trainer) section. - `compute_config` (`ComputeConfig`): the environment resource @@ -521,4 +521,4 @@ A common problem with static graph is that it is impossible to handle control fl But on the other hand, `self.training` is very common used in module forward method. So we add a very limited support for `self.training` in tracing. -Please note that user code is flattened and transformed into a single `ParallelModule` at runtime, so `training` is a global module state, and we don't support the case that user want to set a sub-module's training to True but remaining modules to False. \ No newline at end of file +Please note that user code is flattened and transformed into a single `ParallelModule` at runtime, so `training` is a global module state, and we don't support the case that user want to set a sub-module's training to True but remaining modules to False. diff --git a/docs/source/quickstart_internal.rst b/docs/source/quickstart_internal.rst index 28b6cc04..98bd6e5c 100644 --- a/docs/source/quickstart_internal.rst +++ b/docs/source/quickstart_internal.rst @@ -87,7 +87,7 @@ If the example works for you, you can now follow the documentation to paralleliz .. _Fairseq: Fairseq (To be retired) -======= +======================= .. TODO: diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 58225b73..310a75d2 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -803,7 +803,15 @@ The configuration of the PAS policy should be passed in the ``pas_config`` of `` and run data parallelism across scale units. It requires the ``use_end2end`` to be true. It has the following configurations. - * ``pipeline_nstages``: the number of stages in the pipeline. Default is ``plan_ngpus``. Optional. + * ``pipeline_nstages``: the number of stages in the pipeline, or ``"auto"`` (let autodist to decide). + Default is ``"auto"``. Optional. + + * If ``pipeline_nstages`` is ``"auto"`` and ``pipeline_pivots`` is specified, it will use pipeline. + (The number of stages will be determined automatically by autodist) + * If ``pipeline_nstages`` is ``"auto"`` and ``pipeline_pivots`` is not specified, it will not use pipeline. + * If ``pipeline_nstages`` is a 1, pipeline will not be used. (``pipeline_pivots`` must not be set) + * If ``pipeline_nstages`` is a number > 1, pipeline will be used. (``pipeline_pivots`` must be set) + * ``pipeline_nmicros``: the number of microbatches in the pipeline. Required. * ``pipeline_scheduler``: the scheduler name for the pipeline. Current we support four schedulers in training ``1f1b``/``1f1b_plus``/``1f1b_interleaved``/``gpipe``/``chimera_direct`` (4 stages pipeline only), and one scheduler in inference ``infer_pipe``. Default is ``1f1b``. Optional. * ``pp_size``: the pipeline parallelism size. Default is ``pipeline_nstages``. Optional. @@ -823,10 +831,11 @@ The configuration of the PAS policy should be passed in the ``pas_config`` of `` * ``save_plan_path (str)``: The path to the plan file to save. Optional. * ``partition_constraints_path (str)``: The path to the partition constraints file. Optional. * ``recompute_modules (str)``: The module names to recompute, separated by ``,``. For example, ``module1,module2``. Optional. - * ``pipeline_pivots (str)``: The module names to pivot the pipeline, separated by ``,``. For example, if ``module1,module2`` is specified, stages searched by pipeline solver only start from either ``module1`` or ``module2``. Optional. + * ``pipeline_pivots (str)``: If set, autodist will try pipeline parallelism to find the best partition plan. + It specifies the module names to pivot the pipeline, separated by ``,``. + For example, if ``module1,module2`` is specified, stages searched by pipeline solver only start from either ``module1`` or ``module2``. + Optional. * ``use_apex_fused_adam_v2``: If set to ``True``, the apex fused adam v2 optimizer will be used. Default is ``False``. Optional. - * ``explore_pipeline``: If set to ``True``, autodist will try pipeline parallelism to find the best partition plan - (but the selected partition plan is not necessarily pipeline parallelism). * ``pipeline_scheduler``: The scheduler name for the pipeline. Please note currently ``1f1b`` is the only supported scheduler in ``autodist``. Default is ``1f1b``. Optional. * ``parallel_profile``: If set to ``True``, autodist will profile operators in parallel by using available gpus. Default is ``True``. Optional. * ``max_partition_degree``: Max degree when partitioning an operator / node. When pipeline parallelism is enabled to explore (``explore_pipeline`` is True), user can change the value to constrain the plan to be composed of stages that span on less or equal to ``max_partition_degree`` devices (recommend to set ``max_partition_degree`` to the number of devices in a node to avoid inter-node communication, but should be be no more than ``plan_ngpus``). Default is ``plan_ngpus``. Optional. diff --git a/docs/source/troubleshooting.rst b/docs/source/troubleshooting.rst index 30e8e331..512c728b 100644 --- a/docs/source/troubleshooting.rst +++ b/docs/source/troubleshooting.rst @@ -232,6 +232,38 @@ Example stacktrace: :: return torch._C._is_autocast_available(device_type) TypeError: _is_autocast_available(): argument 'device_type' (position 1) must be str, not ConcreteAttrProxy +"RuntimeError: Broadcast generated files failed" when use ``run_mode='compile'`` +-------------------------------------------------------------------------------- + +When using ``Trainer``'s ``run_mode='compile'`` option, ``broadcast_strategy`` must be set to ``'none'``. + +How to fix: + +.. code-block:: diff + + trainer_args = TrainerArgs( + run_mode='compile', + ... + +broadcast_strategy=('none' if run_mode=='compile' else 'all'), + ) + +Example stacktrace: :: + + Traceback (most recent call last): + File "model.py", line 63, in + trainer.run() + File ".../nnscaler/cli/trainer.py", line 102, in run + self._setup() + File ".../nnscaler/cli/trainer.py", line 148, in _setup + pmodel = parallelize_model(self.train_args, self.dummy_input, load_module=not compile_only) + File ".../nnscaler/cli/mixed_module.py", line 281, in parallelize_model + return _new_adapter().parallelize(dummy_input, load_module=load_module) + File ".../nnscaler/cli/mixed_module.py", line 188, in parallelize + pmodel_class = nnscaler.parallelize( + File ".../nnscaler/parallel.py", line 1081, in parallelize + raise RuntimeError("Broadcast generated files failed: torch.distributed is not initialized.") + RuntimeError: Broadcast generated files failed: torch.distributed is not initialized. + Flash Attention Problems ======================== diff --git a/examples/llama/README.rst b/examples/llama/README.rst index cae35b42..51a64a7b 100644 --- a/examples/llama/README.rst +++ b/examples/llama/README.rst @@ -247,10 +247,10 @@ Combined Command .. code-block:: bash - torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$$OMPI_COMM_WORLD_RANK --master_addr="$$MASTER_ADDR" --master_port=$$MASTER_PORT train.py --name llama3-70b --model_id meta-llama/Meta-Llama-3-70B --dataset_path ./bookcorpus_llama3_8K --gpu_mem_constraint 153 --plan_ngpus=8 --runtime_ngpus=16 --explore_pipeline --grad_accumulation_steps 64 --pipeline_pivots LlamaDecoderLayer 2>&1 | tee run.log + torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$$OMPI_COMM_WORLD_RANK --master_addr="$$MASTER_ADDR" --master_port=$$MASTER_PORT train.py --name llama3-70b --model_id meta-llama/Meta-Llama-3-70B --dataset_path ./bookcorpus_llama3_8K --gpu_mem_constraint 153 --plan_ngpus=8 --runtime_ngpus=16 --grad_accumulation_steps 64 --pipeline_pivots LlamaDecoderLayer --pipeline_nstages auto 2>&1 | tee run.log -Note that in the command above, we enable searching for pipeline parallelism by passing ``--explore_pipeline`` -and set the possible pipeline stage boundaries by ``--pipeline_pivots LlamaDecoderLayer``. +Note that in the command above, we enable searching for pipeline parallelism and set the possible pipeline stage boundaries +by passing ``--pipeline_pivots LlamaDecoderLayer --pipeline_nstages auto``. For the 70B model, the flops for forward and backward is about 3968.41 TFLOPs. The detailed config is as following: diff --git a/examples/llama/train.py b/examples/llama/train.py index c998698c..2963ddba 100644 --- a/examples/llama/train.py +++ b/examples/llama/train.py @@ -171,8 +171,8 @@ def collate_fn(samples): use_end2end=True, pas_config={ 'mem_constraint': args.gpu_mem_constraint, - 'explore_pipeline': args.explore_pipeline, 'pipeline_pivots': args.pipeline_pivots, + 'pipeline_nstages': args.pipeline_nstages, 'recompute_modules': args.recompute_modules, }, trace_strategy=args.trace_strategy, @@ -316,16 +316,17 @@ def collate_fn(samples): action='store_true', help='enable chunk loss that exchanges the speed of training for the memory usage', ) - parser.add_argument( - '--explore_pipeline', - action='store_true', - help='explore pipeline parallelism in autodist', - ) parser.add_argument( '--pipeline_pivots', default='', type=str, - help='specify the pipeline pivots for autodist', + help='explore pipeline parallelism by specifying the pipeline pivots for autodist', + ) + parser.add_argument( + '--pipeline_nstages', + default=1, + type=str, + help='specify the number of stages in the pipeline (use "1" to disable pipeline; use "auto" for autodist)', ) parser.add_argument( '--recompute_modules', @@ -357,8 +358,10 @@ def collate_fn(samples): help='enable diff attention implementation, eager is normal diff attention, flash_attention_2 is diff flash attention, and spda diff attention is not currently supported', ) args = parser.parse_args() - if args.explore_pipeline and not args.pipeline_pivots: - raise ValueError('pipeline_pivots must be specified when explore_pipeline is enabled') + if args.pipeline_nstages != 'auto': + args.pipeline_nstages = int(args.pipeline_nstages) + if args.pipeline_nstages > 1 and not args.pipeline_pivots: + raise ValueError('pipeline_pivots must be specified when pipeline is enabled') if os.getenv('DETERMINISTIC'): # reduce randomness for integration test os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 7fcd229a..4bf4ea14 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -97,7 +97,7 @@ def calc_parallel_plan(graph: IRGraph, [node.cid for node in group] for group in recompute_groups ] - if autodist_config.pipeline: + if autodist_config.pipeline_enabled: pp_out = calc_optimal_pp_plan(autodist_graph, autodist_config) else: pp_out = calc_optimal_spmd_plan(autodist_graph, autodist_config) @@ -215,9 +215,8 @@ def subtensor_desc(t): else: stages = [graph] - # TODO: check pipeline_nstages when ready. - # if autodist_config.pipeline and len(stages) != autodist_config.pipeline_nstages: - # raise RuntimeError("pipeline_nstages doesn't match the number of stages (based on your pipeline_pivots config) in the plan") + if autodist_config.pipeline_nstages != 'auto' and len(stages) != autodist_config.pipeline_nstages: + raise RuntimeError("pipeline_nstages doesn't match the number of stages (based on your pipeline_pivots config) in the plan") # add multiref to an activation tensor when the states of the tensor and its grad are different # among consumers and current segment's outputs diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index a846f550..c5e4ab77 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -88,13 +88,11 @@ class AutoDistConfig: Whether to print verbose information. - re_profile (`bool`, *optional*, defaults to `False`): If set to `True`, the computation profiling results will be overridden. - - pipeline (`bool`, *optional*, defaults to `False`): - Whether to use pipeline parallelism or tensor parallelism. - pipeline_pivots (`str`, *optional*, defaults to `''`): The module names to pivot the pipeline, separated by `,`. For example, if `module1,module2` is specified, stages searched by pipeline solver only start from either `module1` or `module2`. - - pipeline_nstages(`int`, *optional*, defaults to `1`): - The number of stages in pipeline parallelism. This option is only used when pipeline is True. + - pipeline_nstages(`int | Literal['auto']`, *optional*, defaults to `'auto'`): + When `pipeline_pivots` is not empty, this specify the number of stages in pipeline parallelism. `1` means not to use pipelines. - pipeline_scheduler (`str`, *optional*, defaults to `'1f1b'`): The pipeline scheduler to use. Currently only support `'1f1b'`. - max_pipeline_bubble_ratio (`float`, *optional*, defaults to `0.2`): @@ -142,9 +140,8 @@ def __init__(self, ignore_small_tensor_threshold=1, verbose=False, re_profile=False, - pipeline=False, pipeline_pivots='', - pipeline_nstages=1, + pipeline_nstages='auto', pipeline_scheduler='1f1b', max_pipeline_bubble_ratio=0.2, max_pipeline_unbalance_ratio=0.5, @@ -178,7 +175,6 @@ def __init__(self, self.ignore_small_tensor_threshold = ignore_small_tensor_threshold self.verbose = verbose self.re_profile = re_profile - self.pipeline = pipeline self.pipeline_pivots = pipeline_pivots self.pipeline_nstages = pipeline_nstages self.pipeline_scheduler = pipeline_scheduler @@ -187,7 +183,7 @@ def __init__(self, self.max_pipeline_bubble_ratio = max_pipeline_bubble_ratio self.max_pipeline_unbalance_ratio = max_pipeline_unbalance_ratio self.solver = solver - if pipeline and solver != 'dp': + if self.pipeline_enabled and solver != 'dp': _logger.warning( f'pipeline is enabled, but solver is not dp, set solver to dp' ) @@ -210,7 +206,7 @@ def _validate_config(self): _logger.info(f'create folder: {self.profile_dir}') Path(self.profile_dir).mkdir(parents=True, exist_ok=True) - if self.pipeline: + if self.pipeline_enabled: if self.max_pipeline_bubble_ratio <= 0 or self.max_pipeline_bubble_ratio >= 1: raise ValueError( f'max pipeline bubble ratio {self.max_pipeline_bubble_ratio} must be in (0, 1)' @@ -261,3 +257,11 @@ def __repr__(self): @property def ngpus(self): return self.mesh_desc.ngpus + + @property + def pipeline_enabled(self) -> bool: + # whether to explore pipeline + # "auto" is considered as enabled although the exploration result might be not to use pipeline + if not self.pipeline_pivots: + return False + return self.pipeline_nstages == 'auto' or self.pipeline_nstages > 1 diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 31b287e5..7ab6f6c2 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -589,7 +589,7 @@ def get_pipeline_pivots(self) -> List[int]: the indices of pivot operators in the operator list ''' # TODO(yizhu1): check recompute_modules are between pivots - if not self.autodist_config.pipeline: + if not self.autodist_config.pipeline_enabled: raise RuntimeError('pipeline is not enabled') pp_pivot_modules = self.autodist_config.pipeline_pivots.split(',') pp_pivot_modules = [module for module in pp_pivot_modules if module] diff --git a/nnscaler/autodist/pipeline_solver.py b/nnscaler/autodist/pipeline_solver.py index 90ae3213..75f085f6 100644 --- a/nnscaler/autodist/pipeline_solver.py +++ b/nnscaler/autodist/pipeline_solver.py @@ -211,7 +211,10 @@ def shift_plan(solver, spmd_desc, offset: int, shifted_start: int, shifted_end: tp_info = {} for tp_degree in legal_tp_degrees: - for stage_num in range(1, _calc_upper_bound(tp_degree) + 1): + stage_num_bound = _calc_upper_bound(tp_degree) + if cfg.pipeline_nstages != 'auto': + stage_num_bound = min(stage_num_bound, cfg.pipeline_nstages) + for stage_num in range(1, stage_num_bound + 1): solver, intervals, solver_ret = process_case(tp_degree, stage_num) for interval, spmd_descs in zip(intervals, solver_ret): start, end = interval @@ -292,11 +295,12 @@ def calc_optimal_pp_plan( val = max(lhs, rhs) if T[cur_idx][0] > val: T[cur_idx] = [val, prev_idx] - best_time = float('inf') best_state = (-1, -1, -1, -1) micro_batch_num = autodist_config.update_freq for stage_num in range(1, ngpus + 1): + if autodist_config.pipeline_nstages != 'auto' and autodist_config.pipeline_nstages != stage_num: + continue for pp_dev_num in range(stage_num, ngpus + 1): for tp_degree in range(1, pp_dev_num - stage_num + 1 + 1): if tp_degree not in legal_tp_degrees: diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 0eda7f57..d47274e2 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2435,8 +2435,10 @@ def sync_grad_when(cond: bool): only when `cond` is True. This is needed when + 1. The mode is not end2end model. - For end2end model, gradients are synchronized across workers automatically. + For end2end model, gradients are synchronized across workers automatically. + 2. async is enabled (`compute_config.use_async_reducer` is `True`). If both conditions are not satisfied, this function has no effect. diff --git a/nnscaler/policies.py b/nnscaler/policies.py index c8d4ba91..5d534ab5 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -133,7 +133,8 @@ def pas_pp(graph: IRGraph, cfg: 'ComputeConfig'): """ pipeline parallelism inside a scale unit, and dp across scale units """ - if cfg.pas_config.get('pipeline_nstages', cfg.plan_ngpus) != cfg.plan_ngpus: + nstages = cfg.pas_config.get('pipeline_nstages', 'auto') + if nstages != 'auto' and nstages != cfg.plan_ngpus: raise ValueError("pas_pp requires pipeline_nstages == plan_ngpus") return pas_hybrid(graph, cfg) @@ -172,7 +173,9 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): raise ValueError("Hybrid policy only supports end2end module") ngpus: int = cfg.plan_ngpus - nstages = cfg.pas_config.get('pipeline_nstages', cfg.plan_ngpus) + nstages = cfg.pas_config.get('pipeline_nstages', 'auto') + if nstages == 'auto': + nstages = cfg.plan_ngpus nmicros = cfg.pas_config['pipeline_nmicros'] scheduler = cfg.pas_config.get('pipeline_scheduler', '1f1b') pp_size = cfg.pas_config.get('pp_size', nstages) @@ -223,11 +226,29 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: update_freq = update_freq[0] # optional parameters - explore_pipeline = pas_cfg.get('explore_pipeline', False) - if explore_pipeline and not cfg.use_end2end: - raise ValueError("explore_pipeline cannot be enabled if use_end2end is False") - if explore_pipeline and cfg.use_async_reducer: - raise ValueError("explore_pipeline cannot be enabled if use_async_reducer is True") + + # Note we don't directly pass pipeline_nstages to autodist. + # when `pipeline_nstages == 'auto'`, we will check if there are options incompatible with pipeline. + # if we find incompabible options (here use_async_reducer and pipeline_pivots), + # we will disable pipeline effectively by setting it to 1. + pipeline_nstages = pas_cfg.get('pipeline_nstages', 'auto') + + if pipeline_nstages == 'auto': + if not pas_cfg.get('pipeline_pivots'): + pipeline_nstages = 1 + if not cfg.use_end2end or cfg.use_async_reducer: + pipeline_nstages = 1 + elif pipeline_nstages > 1: + # the user manually enabled pipeline, should not disable, so raise + if not pas_cfg.get('pipeline_pivots'): + raise ValueError("pipeline_pivots must be set to enable pipeline") + if not cfg.use_end2end: + raise ValueError("explore_pipeline cannot be enabled if use_end2end is False") + if cfg.use_async_reducer: + raise ValueError("explore_pipeline cannot be enabled if use_async_reducer is True") + else: + if pas_cfg.get('pipeline_pivots'): + raise ValueError("pipeline_pivots must not be set because pipeline is disabled by pipeline_nstages<=1") pipeline_scheduler = pas_cfg.get('pipeline_scheduler', '1f1b') if pipeline_scheduler != '1f1b': @@ -237,7 +258,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: if cfg.plan_ngpus % mesh_col != 0: raise ValueError(f"plan_ngpus {cfg.plan_ngpus} should be divisible by max_partition_degree {mesh_col}") mesh_row = cfg.plan_ngpus // mesh_col - if not explore_pipeline and mesh_row != 1: + if pipeline_nstages == 1 and mesh_row != 1: raise ValueError("mesh_row should be 1 if pipeline is not enabled") memory_constraint = pas_cfg.get('mem_constraint', -1) task_name = pas_cfg.get('task_name', '_') @@ -252,6 +273,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: partition_constraints_path = pas_cfg.get('partition_constraints_path', '') recompute_modules = pas_cfg.get('recompute_modules', '') pipeline_pivots = pas_cfg.get('pipeline_pivots', '') + max_pipeline_bubble_ratio = pas_cfg.get('max_pipeline_bubble_ratio', 0.2) use_apex_fused_adam_v2 = pas_cfg.get('use_apex_fused_adam_v2', False) parallel_profile = pas_cfg.get('parallel_profile', True) transient_mem_coef = pas_cfg.get('transient_mem_coef', 2) @@ -315,8 +337,9 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: zero_ngroups=zero_ngroups, load_plan_path=load_plan_path, save_plan_path=save_plan_path, - pipeline=explore_pipeline, pipeline_pivots=pipeline_pivots, + pipeline_nstages=pipeline_nstages, + max_pipeline_bubble_ratio=max_pipeline_bubble_ratio, parallel_profile=parallel_profile, transient_mem_coef=transient_mem_coef, ) diff --git a/tests/autodist/test_pipeline_nstages.py b/tests/autodist/test_pipeline_nstages.py new file mode 100644 index 00000000..07952abb --- /dev/null +++ b/tests/autodist/test_pipeline_nstages.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import pytest +import torch +from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer_args import * +from tests.launch_torchrun import launch_torchrun + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='need 2 gpus') +def test_1_stage(tmp_path): + launch_torchrun(1, _compile_worker, tmp_path, 1) + # for TP, the scripts should be identical (except tensor names) + lines0 = _count_gencode_lines(tmp_path, 0) + lines1 = _count_gencode_lines(tmp_path, 1) + assert lines0 == lines1 + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='need 2 gpus') +def test_2_stages(tmp_path): + launch_torchrun(1, _compile_worker, tmp_path, 2) + # for PP, since we have 3 linears, the scripts should be different + lines0 = _count_gencode_lines(tmp_path, 0) + lines1 = _count_gencode_lines(tmp_path, 1) + assert lines0 != lines1 + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='need 2 gpus') +def test_auto_stages(tmp_path): + launch_torchrun(1, _compile_worker, tmp_path, 'auto') + # just check it does not throw + # because both results are possible theoretically + + +def _compile_worker(tmp_path, nstages): + _compile(tmp_path, nstages) + + +def _compile(tmp_path, nstages): + trainer_args = TrainerArgs( + compute_config=ComputeConfig( + plan_ngpus=2, + runtime_ngpus=2, + use_end2end=True, + pas_config={ + 'pipeline_pivots': 'Linear', + 'pipeline_nstages': nstages, + 'max_pipeline_bubble_ratio': 0.99, # force autodist to accept unbalanced stages + }, + ), + gen_reuse='override', + gen_savedir=tmp_path/'src', + run_mode='compile', + model=ModelConfig(type=Model), + optimizer=OptimizerConfig(type=torch.optim.AdamW), + dataset=DatasetConfig(type=RandomDataset, train_args={'length': 100}), + max_train_steps=1, + ) + trainer = Trainer(train_args=trainer_args) + trainer.run() + + +def _count_gencode_lines(tmp_path, index): + script = 'tests/autodist/test_pipeline_nstages' + path = f'_parallel_modules/{script}/Model/_/gencode{index}.py' + text = Path(tmp_path, 'src', path).read_text() + return text.count('\n') + + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 10) + self.linear2 = torch.nn.Linear(10, 10) + self.linear3 = torch.nn.Linear(10, 10) + + def forward(self, data): + x = data['x'] + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return torch.nn.functional.cross_entropy(x, data['y']) + + +class RandomDataset: + def __init__(self, length): + self.length = length + + def __getitem__(self, i): + return { + 'x': torch.rand(10), + 'y': torch.randint(10, tuple()), + } + + def __len__(self): + return self.length From 078d187f7c6e1c3aac4632eae09fb9e56672ea7a Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Wed, 12 Feb 2025 07:06:14 +0000 Subject: [PATCH 1804/1892] Merged PR 2361: [Example] improve diff attention This version only implements the one found at [multihead_flashdiff_2.py](https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_flashdiff_2.py). The original flash_attn library does not support different qk/v dimensions, resulting in the execution of the flash_attn_func four times, which doubles the computational load compared to the original LlamaFlashAttention2. For scenarios involving long sequence lengths and a small number of heads, this method takes nearly four times as long as the original. For the more efficient version available at [multihead_flashdiff_1.py](https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_flashdiff_1.py), it is necessary to utilize a separate flash_attention library. If required, you can implement this version independently by following the guidelines provided in the [README.md](https://github.com/microsoft/unilm/blob/master/Diff-Transformer/README.md). On an 8*H100 setup, the model used is Llama3 with attention in 28 layers, a hidden size of 3072, an FFN size of 8192, and 12 heads. The input has a batch size of 1 and a sequence length of 64K. The parallel plan includes TP 4, ZeRO, and recompute for LlamaDecoderLayer. Performance is 60 seconds per iteration, resulting in an MFU of approximately 7500 TFLOPS / 60s / 4 / 312 Tflops/s = 10% Related work items: #2122 --- .../core/ring_attn_implementation.py | 19 ++- .../ring_attention/ring_attn.py | 11 +- examples/llama/README.rst | 11 +- .../lm_models/diff_transformer_modifier.py | 62 ++++---- examples/llama/lm_models/llama_modifier.py | 134 +---------------- examples/llama/lm_models/utils.py | 139 +++++++++++++++++- examples/llama/train.py | 7 +- 7 files changed, 209 insertions(+), 174 deletions(-) diff --git a/examples/customized_ops/ring_attention/core/ring_attn_implementation.py b/examples/customized_ops/ring_attention/core/ring_attn_implementation.py index 42219ad3..f7c23f16 100644 --- a/examples/customized_ops/ring_attention/core/ring_attn_implementation.py +++ b/examples/customized_ops/ring_attention/core/ring_attn_implementation.py @@ -10,6 +10,13 @@ _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() +import flash_attn + +version = flash_attn.__version__ +if not version.startswith("2.6"): + raise ImportError("The current version of Ring Attention is not compatible with Flash Attention versions other than 2.6.x.") + + def ring_flash_attn_forward( process_group, q: torch.Tensor, @@ -19,6 +26,7 @@ def ring_flash_attn_forward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -45,6 +53,7 @@ def ring_flash_attn_forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -63,6 +72,7 @@ def ring_flash_attn_forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) @@ -84,6 +94,7 @@ def ring_flash_attn_backward( dropout_p=0, causal=True, window_size=(-1, -1), + softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -124,6 +135,7 @@ def ring_flash_attn_backward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, @@ -156,6 +168,7 @@ def ring_flash_attn_backward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, @@ -192,6 +205,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -219,6 +233,7 @@ def forward( dropout_p=dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=False, ) @@ -228,6 +243,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic ctx.group = group @@ -259,8 +275,9 @@ def backward(ctx, dout, *args): dropout_p=ctx.dropout_p, causal=ctx.causal, window_size=ctx.window_size, + softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, ) dq = recover_output(dq, ctx.group) - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None diff --git a/examples/customized_ops/ring_attention/ring_attn.py b/examples/customized_ops/ring_attention/ring_attn.py index 801378ce..1a167a7f 100644 --- a/examples/customized_ops/ring_attention/ring_attn.py +++ b/examples/customized_ops/ring_attention/ring_attn.py @@ -2,21 +2,19 @@ # Licensed under the MIT License. from typing import Tuple, List, Dict -import torch from torch import Tensor -import torch.distributed from nnscaler.graph.parser.register import register_op from nnscaler.ir.operator import IRFwOperation from core.ring_attn_implementation import RingFlashAttnFunc from flash_attn import flash_attn_func -import torch.distributed as dist from nnscaler.runtime.device import DeviceGroup + def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), - alibi_slopes: Tensor=None, deterministic: bool=False, + softcap: float=0.0, alibi_slopes: Tensor=None, deterministic: bool=False, return_attn_probs: bool=False, process_group: Tuple[int]=None) -> Tensor: ''' @@ -49,6 +47,10 @@ def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=N local_process_group = DeviceGroup().get_group(process_group) + # In the RingFlashAttnFunc.apply function, the torch.distributed._all_gather_base function + # requires that the k and v tensors be contiguous. + k = k.contiguous() + v = v.contiguous() output = RingFlashAttnFunc.apply( q, k, @@ -57,6 +59,7 @@ def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=N softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, diff --git a/examples/llama/README.rst b/examples/llama/README.rst index 51a64a7b..a5a8c148 100644 --- a/examples/llama/README.rst +++ b/examples/llama/README.rst @@ -150,7 +150,7 @@ Here ``b`` stands for batch size, ``l`` stands for sequence length, ``d`` stands The ``^`` means the dimension cannot be partitioned. More details about the annotation can be found in :doc:`../register_custom_op`. -You can enable this customized function by passing ``--enable-chunk-loss`` to ``train.py`` when compiling. +You can enable this customized function by passing ``--enable_chunk_loss`` to ``train.py`` when compiling. When the sequence length is small (like 8K), this option can be turned off. Profile Communication @@ -205,7 +205,7 @@ Compile .. code-block:: bash - python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable-chunk-loss 2>&1 | tee compile.log + python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --enable_chunk_loss 2>&1 | tee compile.log Run --- @@ -276,6 +276,13 @@ Based on AutoDist's analysis, the low utilization results from following aspects Currently we only consider plan_ngpus=8 and fix the pipeline schedule to be ``1f1b``. We can refine this assumption in the future. +**************** +DIFFERENTIAL TRANSFORMER +**************** + +Users can utilize ``DIFFERENTIAL TRANSFORMER`` by using the flag ``--enable_diff_attn``. ``DIFFERENTIAL TRANSFORMER`` is an alternative to the traditional attention mechanism ``. It implements differential attention as the Attention module, replacing the conventional ``eager attention`` or ``flash attention`` modules. +When the sequence length is extremely long, ``ring diff attn`` can be employed by using the flag ``--enable_ring_attn``. This approach can break through the limitations of key - value heads, achieving a more refined partitioning. + ********* Debugging ********* diff --git a/examples/llama/lm_models/diff_transformer_modifier.py b/examples/llama/lm_models/diff_transformer_modifier.py index 134d1337..6e3c5820 100644 --- a/examples/llama/lm_models/diff_transformer_modifier.py +++ b/examples/llama/lm_models/diff_transformer_modifier.py @@ -16,19 +16,26 @@ import logging import math -from transformers.utils import ( - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, -) - -from nnscaler.graph.parser.register import register_op -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func +from transformers.utils import is_flash_attn_greater_or_equal_2_10 +from .utils import nnscaler_flash_attention_forward logger = logging.getLogger(__name__) +try: + import os + import sys + sys.path.append(os.path.abspath(os.path.join(os.path.dirname( + os.path.abspath(__file__)), '../../customized_ops/ring_attention'))) + from ring_attn import wrap_ring_attn_func + + def nnscaler_ring_attn_func(query_states, key_states, value_states, *args, **kwargs): + return wrap_ring_attn_func(query_states, key_states, value_states) +except ModuleNotFoundError: + logger.warning("Ring Attention is not import correctly.") + + def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" bs, n_kv_heads, slen, head_dim = x.shape @@ -63,9 +70,6 @@ def __init__(self, *args, **kwargs): f" and `num_heads`: {self.num_heads})." ) - self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.num_key_value_groups, bias=self.config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.num_key_value_groups, bias=self.config.attention_bias) self._init_rope() assert self.layer_idx is not None, "layer_idx must be provided for NNScalerMultiheadDiffAttn" @@ -118,6 +122,7 @@ def forward( causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] else: causal_mask = torch.triu(torch.zeros([q_len, q_len]).float().fill_(float("-inf")).type_as(query_states), 1) + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() @@ -163,6 +168,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.attn_func = nnscaler_flash_attention_forward def forward( self, @@ -199,11 +205,6 @@ def forward( # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if query_states.device.type == "cuda": - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() query_states = query_states.reshape(bsz, q_len, self.num_heads, 2, self.head_dim) key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, 2, self.head_dim) @@ -211,12 +212,12 @@ def forward( k1, k2 = key_states[:, :, :, 0], key_states[:, :, :, 1] v1, v2 = value_states[:, :, :, 0], value_states[:, :, :, 1] - attn11 = flash_attn_func(q1, k1, v1, causal=True) - attn12 = flash_attn_func(q1, k1, v2, causal=True) + attn11 = self.attn_func(q1, k1, v1, attention_mask, q_len, causal=True) + attn12 = self.attn_func(q1, k1, v2, attention_mask, q_len, causal=True) attn1 = torch.cat([attn11, attn12], dim=-1) - attn21 = flash_attn_func(q2, k2, v1, causal=True) - attn22 = flash_attn_func(q2, k2, v2, causal=True) + attn21 = self.attn_func(q2, k2, v1, attention_mask, q_len, causal=True) + attn22 = self.attn_func(q2, k2, v2, attention_mask, q_len, causal=True) attn2 = torch.cat([attn21, attn22], dim=-1) lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1)) @@ -225,23 +226,14 @@ def forward( attn_output = attn1 - lambda_full * attn2 attn_output = self.subln(attn_output) attn_output = attn_output * (1 - self.lambda_init) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * 2 * self.head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * 2 * self.head_dim).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value - -def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: - if query_states.shape[2] != key_states.shape[2]: - assert query_states.shape[2] % key_states.shape[2] == 0 - group_size = query_states.shape[2] // key_states.shape[2] - assert query_states.shape[2] == value_states.shape[2] * group_size - q_anno = f'(group_num {group_size})' - kv_anno = 'group_num' - else: - q_anno = kv_anno = 'num_heads' - return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' - -if is_flash_attn_2_available(): - register_op(flash_attention_anno)(flash_attn_func) \ No newline at end of file + +class NNScalerMultiheadDiffRingAttn(NNScalerMultiheadDiffFlashAttn): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_func = nnscaler_ring_attn_func diff --git a/examples/llama/lm_models/llama_modifier.py b/examples/llama/lm_models/llama_modifier.py index cce87f3c..084d0f2b 100644 --- a/examples/llama/lm_models/llama_modifier.py +++ b/examples/llama/lm_models/llama_modifier.py @@ -5,25 +5,19 @@ # 1. register the flash attention function to nnscaler and update related code # 2. replace the un-fused RMSNorm with apex's fused version -import types -from typing import List, Optional, Tuple, Union - -from nnscaler.graph.parser.register import register_op -from nnscaler.ir import IRTensor +from typing import Optional, Tuple import torch +import logging + +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb +from transformers.utils import is_flash_attn_greater_or_equal_2_10 -from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.models.llama.modeling_llama import LlamaAttention, LLAMA_ATTENTION_CLASSES, apply_rotary_pos_emb, LlamaRMSNorm -from transformers.utils import ( - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, -) +from .utils import nnscaler_flash_attention_forward -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +logger = logging.getLogger(__name__) class NNScalerLlamaFlashAttention2(LlamaAttention): @@ -129,115 +123,3 @@ def forward( return attn_output, attn_weights, past_key_value - -def nnscaler_flash_attention_forward( - query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, causal=True -): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = nnscaler_upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) - - return attn_output - - -def nnscaler_upad_input(query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - _, _, num_heads, _ = query_layer.shape - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -def llama_flash_attention_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str: - if query_states.shape[2] != key_states.shape[2]: - assert query_states.shape[2] % key_states.shape[2] == 0 - group_size = query_states.shape[2] // key_states.shape[2] - assert query_states.shape[2] == value_states.shape[2] * group_size - q_anno = f'(group_num {group_size})' - kv_anno = 'group_num' - else: - q_anno = kv_anno = 'num_heads' - if isinstance(attention_mask, IRTensor): - return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, b l^ -> b l^ {q_anno} vd^' - else: - return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' - - -register_op(llama_flash_attention_anno)(nnscaler_flash_attention_forward) diff --git a/examples/llama/lm_models/utils.py b/examples/llama/lm_models/utils.py index 3b3745a2..a3673bed 100644 --- a/examples/llama/lm_models/utils.py +++ b/examples/llama/lm_models/utils.py @@ -2,7 +2,9 @@ # Licensed under the MIT License. import torch -from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES, LlamaRMSNorm + +from nnscaler.ir import IRTensor +from nnscaler.graph.parser.register import register_op try: from apex.normalization.fused_layer_norm import fused_rms_norm_affine @@ -10,6 +12,14 @@ except ImportError: has_apex = False +from transformers.utils import is_flash_attn_2_available +from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES, LlamaRMSNorm + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + def rmsnorm_fwd(self, hidden_states): if has_apex: @@ -24,12 +34,131 @@ def rmsnorm_fwd(self, hidden_states): def nnscaler_lm_init(args): if args.enable_diff_attn: - from .diff_transformer_modifier import NNScalerMultiheadDiffAttn, NNScalerMultiheadDiffFlashAttn if args.attn_implementation == "sdpa": raise ValueError("sdpa is currently not supported in Diff-Transformer") - LLAMA_ATTENTION_CLASSES["eager"] = NNScalerMultiheadDiffAttn - LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerMultiheadDiffFlashAttn + if args.enable_ring_attn: + if args.attn_implementation == "eager": + raise ValueError("Ring Attention only support flash attention") + from .diff_transformer_modifier import NNScalerMultiheadDiffRingAttn + LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerMultiheadDiffRingAttn + else: + from .diff_transformer_modifier import NNScalerMultiheadDiffAttn, NNScalerMultiheadDiffFlashAttn + LLAMA_ATTENTION_CLASSES["eager"] = NNScalerMultiheadDiffAttn + LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerMultiheadDiffFlashAttn else: from .llama_modifier import NNScalerLlamaFlashAttention2 LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerLlamaFlashAttention2 - LlamaRMSNorm.forward = rmsnorm_fwd \ No newline at end of file + LlamaRMSNorm.forward = rmsnorm_fwd + + +def nnscaler_flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, causal=True +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = nnscaler_upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + +def nnscaler_upad_input(query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + _, _, num_heads, _ = query_layer.shape + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def flash_attention_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + if isinstance(attention_mask, IRTensor): + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, b l^ -> b l^ {q_anno} vd^' + else: + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' + + +register_op(flash_attention_anno)(nnscaler_flash_attention_forward) diff --git a/examples/llama/train.py b/examples/llama/train.py index 2963ddba..66deeb40 100644 --- a/examples/llama/train.py +++ b/examples/llama/train.py @@ -312,7 +312,7 @@ def collate_fn(samples): help='trace strategy control the function execution during tracing model graph, `cuda_run_cpu_offload` and `reuse_cache` are recommended, please read `docs/source/parallel_module.md` for more information', ) parser.add_argument( - '--enable-chunk-loss', + '--enable_chunk_loss', action='store_true', help='enable chunk loss that exchanges the speed of training for the memory usage', ) @@ -357,6 +357,11 @@ def collate_fn(samples): action='store_true', help='enable diff attention implementation, eager is normal diff attention, flash_attention_2 is diff flash attention, and spda diff attention is not currently supported', ) + parser.add_argument( + '--enable_ring_attn', + action='store_true', + help='enable ring attention, currently only diff flash attention is supported', + ) args = parser.parse_args() if args.pipeline_nstages != 'auto': args.pipeline_nstages = int(args.pipeline_nstages) From d5e8b305252e28224bf5f883ead4182c9c39b416 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 13 Feb 2025 02:56:20 +0000 Subject: [PATCH 1805/1892] Merged PR 2366: [BugFix] Fix autodist's data parallel test --- tests/autodist/spmd_solver/test_follow.py | 2 +- .../comp/torch.Tensor.contiguous.json | 50 ++--- .../comp/torch.Tensor.reshape.json | 40 ++-- .../comp/torch.Tensor.view.json | 40 ++-- .../comp/torch.div.json | 50 ++--- .../comp/torch.matmul.json | 192 +++++++++--------- .../comp/torch.nn.functional.dropout.json | 50 ++--- .../comp/torch.nn.functional.linear.json | 120 +++++------ .../comp/torch.nn.functional.softmax.json | 48 ++--- .../comp/torch.sum.json | 40 ++-- .../comp/torch.transpose.json | 150 +++++++------- 11 files changed, 391 insertions(+), 391 deletions(-) diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index 1b585308..6af9ab35 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -309,7 +309,7 @@ def helper(search_out): def test_solver_data_parallel(): from nnscaler.ir.unique import IDGenerator IDGenerator().clear() - bsz, seq_len, hidden_dim, num_heads = 2, 2048, 512, 8 + bsz, seq_len, hidden_dim, num_heads = 2, 8192, 512, 8 dummy_input = { 'x': torch.rand(bsz, seq_len, hidden_dim), } diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.contiguous.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.contiguous.json index 66990601..a06e37e0 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.contiguous.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.contiguous.json @@ -1,61 +1,61 @@ { - "(2, 2048, 8, 64)-(2, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 8, 64)-(2, 8192, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.00392962247133255, - "bw_span": 0.0361565500497818, - "infer_memory": 8388608, + "fw_span": 0.004092603921890259, + "bw_span": 0.14557528775185347, + "infer_memory": 33554432, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 2048, 8, 64)-(1, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "(1, 8192, 8, 64)-(1, 8192, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.005548819899559021, - "bw_span": 0.022963620722293854, - "infer_memory": 4194304, + "fw_span": 0.004374142736196518, + "bw_span": 0.0721710966899991, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 1024, 8, 64)-(2, 1024, 8, 64) : torch.float32-torch.float32 : True-True": { + "(2, 4096, 8, 64)-(2, 4096, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.006691738963127136, - "bw_span": 0.0254802405834198, - "infer_memory": 4194304, + "fw_span": 0.004373118281364441, + "bw_span": 0.0719255767762661, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 2048, 4, 64)-(2, 2048, 4, 64) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 4, 64)-(2, 8192, 4, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.006204098463058472, - "bw_span": 0.023501552641391754, - "infer_memory": 4194304, + "fw_span": 0.0038990983739495277, + "bw_span": 0.07239815313369036, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 2048, 8, 32)-(2, 2048, 8, 32) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 8, 32)-(2, 8192, 8, 32) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.004957988858222961, - "bw_span": 0.01678112894296646, - "infer_memory": 4194304, + "fw_span": 0.004175095818936825, + "bw_span": 0.07230215705931187, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] } diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.reshape.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.reshape.json index f0933e95..916e57b2 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.reshape.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.reshape.json @@ -1,49 +1,49 @@ { - "(2, 2048, 8, 64)-(2, 2048, 512) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 8, 64)-(2, 8192, 512) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.007220730185508728, - "bw_span": 0.026119686663150787, - "infer_memory": 8388608, + "fw_span": 0.005718669854104519, + "bw_span": 0.09603481739759445, + "infer_memory": 33554432, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 2048, 8, 64)-(1, 2048, 512) : torch.float32-torch.float32 : True-True": { + "(1, 8192, 8, 64)-(1, 8192, 512) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009209662675857544, - "bw_span": 0.0355185940861702, - "infer_memory": 4194304, + "fw_span": 0.007527763955295086, + "bw_span": 0.056579639203846455, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 1024, 8, 64)-(2, 1024, 512) : torch.float32-torch.float32 : True-True": { + "(2, 4096, 8, 64)-(2, 4096, 512) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.00962037593126297, - "bw_span": 0.03517419099807739, - "infer_memory": 4194304, + "fw_span": 0.009833252988755703, + "bw_span": 0.06164633668959141, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 2048, 4, 64)-(2, 2048, 256) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 4, 64)-(2, 8192, 256) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009422190487384796, - "bw_span": 0.03640800714492798, - "infer_memory": 4194304, + "fw_span": 0.0054121483117341995, + "bw_span": 0.04739384166896343, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] } diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.view.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.view.json index 68f18636..16b55f4e 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.view.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.view.json @@ -1,49 +1,49 @@ { - "(2, 2048, 512)-(2, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 512)-(2, 8192, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009936653077602386, - "bw_span": 0.033933669328689575, - "infer_memory": 8388608, + "fw_span": 0.005937204696238041, + "bw_span": 0.09612578433007002, + "infer_memory": 33554432, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 2048, 512)-(1, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "(1, 8192, 512)-(1, 8192, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.010452233254909515, - "bw_span": 0.036592036485672, - "infer_memory": 4194304, + "fw_span": 0.006045191548764706, + "bw_span": 0.04667786415666342, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 1024, 512)-(2, 1024, 8, 64) : torch.float32-torch.float32 : True-True": { + "(2, 4096, 512)-(2, 4096, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.010039843618869781, - "bw_span": 0.04428252577781677, - "infer_memory": 4194304, + "fw_span": 0.0060086604207754135, + "bw_span": 0.046695396304130554, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 2048, 256)-(2, 2048, 4, 64) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 256)-(2, 8192, 4, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009671784937381744, - "bw_span": 0.03549344837665558, - "infer_memory": 4194304, + "fw_span": 0.005979649722576141, + "bw_span": 0.046865385957062244, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] } diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.div.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.div.json index e2375b2c..8e8a308f 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.div.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.div.json @@ -1,61 +1,61 @@ { - "(2, 8, 2048, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 8192)-(2, 8, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 268435456 + 4294967296 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.7881719619035721, - "bw_span": 1.9588613882660866, - "infer_memory": 536870912, + "fw_span": 12.559171300381422, + "bw_span": 31.287124007940292, + "infer_memory": 8589934592, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 8, 2048, 2048)-(1, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(1, 8, 8192, 8192)-(1, 8, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.3966972231864929, - "bw_span": 0.9810343384742737, - "infer_memory": 268435456, + "fw_span": 6.275683059357107, + "bw_span": 15.639891382306814, + "infer_memory": 4294967296, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 4, 2048, 2048)-(2, 4, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 4, 8192, 8192)-(2, 4, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.3962540999054909, - "bw_span": 0.9834336116909981, - "infer_memory": 268435456, + "fw_span": 6.277594109997153, + "bw_span": 15.663174027577043, + "infer_memory": 4294967296, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 1024, 2048)-(2, 8, 1024, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 8, 4096, 8192)-(2, 8, 4096, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.3961294889450073, - "bw_span": 0.9815100580453873, - "infer_memory": 268435456, + "fw_span": 6.273878482170403, + "bw_span": 15.646824124269187, + "infer_memory": 4294967296, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 2048, 1024)-(2, 8, 2048, 1024) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 4096)-(2, 8, 8192, 4096) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.39616115391254425, - "bw_span": 0.9814225137233734, - "infer_memory": 268435456, + "fw_span": 6.275715050287545, + "bw_span": 15.648416592739522, + "infer_memory": 4294967296, "train_mem_info": [], "train_mem2in_idx": [] } diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.matmul.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.matmul.json index aa662f0b..a9ec13e1 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.matmul.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.matmul.json @@ -1,226 +1,226 @@ { - "(2, 8, 2048, 64)-(2, 8, 64, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8, 8192, 64)-(2, 8, 64, 8192)-(2, 8, 8192, 8192) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 8388608, - 8388608 + 33554432, + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.5448272451758385, - "bw_span": 2.108476497232914, - "infer_memory": 285212672, + "fw_span": 10.149017279036343, + "bw_span": 25.981219834648073, + "infer_memory": 4362076160, "train_mem_info": [ - 8388608, - 8388608 + 33554432, + 33554432 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 8, 2048, 2048)-(2, 8, 2048, 64)-(2, 8, 2048, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8, 8192, 8192)-(2, 8, 8192, 64)-(2, 8, 8192, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 268435456, - 8388608 + 4294967296, + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.9633848443627357, - "bw_span": 2.7184441685676575, - "infer_memory": 285212672, + "fw_span": 13.920847116969526, + "bw_span": 39.25901562906802, + "infer_memory": 4362076160, "train_mem_info": [ - 8388608, - 268435456 + 33554432, + 4294967296 ], "train_mem2in_idx": [ 1, 0 ] }, - "(1, 8, 2048, 64)-(1, 8, 64, 2048)-(1, 8, 2048, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(1, 8, 8192, 64)-(1, 8, 64, 8192)-(1, 8, 8192, 8192) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 4194304, - 4194304 + 16777216, + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.2818003296852112, - "bw_span": 1.0446002706885338, - "infer_memory": 142606336, + "fw_span": 5.109956441447139, + "bw_span": 13.891059276647866, + "infer_memory": 2181038080, "train_mem_info": [ - 4194304, - 4194304 + 16777216, + 16777216 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 4, 2048, 64)-(2, 4, 64, 2048)-(2, 4, 2048, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 4, 8192, 64)-(2, 4, 64, 8192)-(2, 4, 8192, 8192) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 4194304, - 4194304 + 16777216, + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.28081703931093216, - "bw_span": 1.0572420433163643, - "infer_memory": 142606336, + "fw_span": 5.099610146135092, + "bw_span": 13.918159017339349, + "infer_memory": 2181038080, "train_mem_info": [ - 4194304, - 4194304 + 16777216, + 16777216 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 8, 1024, 64)-(2, 8, 64, 2048)-(2, 8, 1024, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8, 4096, 64)-(2, 8, 64, 8192)-(2, 8, 4096, 8192) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 4194304, - 8388608 + 16777216, + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.28518345206975937, - "bw_span": 1.0718883946537971, - "infer_memory": 146800640, + "fw_span": 5.08829178288579, + "bw_span": 13.72850202023983, + "infer_memory": 2197815296, "train_mem_info": [ - 8388608, - 4194304 + 33554432, + 16777216 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 8, 2048, 32)-(2, 8, 32, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8, 8192, 32)-(2, 8, 32, 8192)-(2, 8, 8192, 8192) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 4194304, - 4194304 + 16777216, + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.4970652982592583, - "bw_span": 1.9734794273972511, - "infer_memory": 276824064, + "fw_span": 7.5393385021016, + "bw_span": 25.859177764505148, + "infer_memory": 4328521728, "train_mem_info": [ - 4194304, - 4194304 + 16777216, + 16777216 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 8, 2048, 64)-(2, 8, 64, 1024)-(2, 8, 2048, 1024) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8, 8192, 64)-(2, 8, 64, 4096)-(2, 8, 8192, 4096) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 8388608, - 4194304 + 33554432, + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.2919815480709076, - "bw_span": 1.0712673887610435, - "infer_memory": 146800640, + "fw_span": 4.288905439898372, + "bw_span": 14.000384951941669, + "infer_memory": 2197815296, "train_mem_info": [ - 4194304, - 8388608 + 16777216, + 33554432 ], "train_mem2in_idx": [ 1, 0 ] }, - "(1, 8, 2048, 2048)-(1, 8, 2048, 64)-(1, 8, 2048, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(1, 8, 8192, 8192)-(1, 8, 8192, 64)-(1, 8, 8192, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 134217728, - 4194304 + 2147483648, + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.4901982843875885, - "bw_span": 1.4053288847208023, - "infer_memory": 142606336, + "fw_span": 7.462437194772065, + "bw_span": 20.21532175131142, + "infer_memory": 2181038080, "train_mem_info": [ - 4194304, - 134217728 + 16777216, + 2147483648 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 4, 2048, 2048)-(2, 4, 2048, 64)-(2, 4, 2048, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 4, 8192, 8192)-(2, 4, 8192, 64)-(2, 4, 8192, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 134217728, - 4194304 + 2147483648, + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.4908433184027672, - "bw_span": 1.399235613644123, - "infer_memory": 142606336, + "fw_span": 7.47751968447119, + "bw_span": 20.190081978216767, + "infer_memory": 2181038080, "train_mem_info": [ - 4194304, - 134217728 + 16777216, + 2147483648 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 8, 1024, 2048)-(2, 8, 2048, 64)-(2, 8, 1024, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8, 4096, 8192)-(2, 8, 8192, 64)-(2, 8, 4096, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 134217728, - 8388608 + 2147483648, + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.4898613318800926, - "bw_span": 1.4226442202925682, - "infer_memory": 146800640, + "fw_span": 7.488103001378477, + "bw_span": 19.93988968897611, + "infer_memory": 2197815296, "train_mem_info": [ - 8388608, - 134217728 + 33554432, + 2147483648 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 8, 2048, 1024)-(2, 8, 1024, 64)-(2, 8, 2048, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8, 8192, 4096)-(2, 8, 4096, 64)-(2, 8, 8192, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 134217728, - 4194304 + 2147483648, + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.5006546154618263, - "bw_span": 1.410149410367012, - "infer_memory": 146800640, + "fw_span": 7.145568006671965, + "bw_span": 19.709997391328216, + "infer_memory": 2197815296, "train_mem_info": [ - 4194304, - 134217728 + 16777216, + 2147483648 ], "train_mem2in_idx": [ 1, 0 ] }, - "(2, 8, 2048, 2048)-(2, 8, 2048, 32)-(2, 8, 2048, 32) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8, 8192, 8192)-(2, 8, 8192, 32)-(2, 8, 8192, 32) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 268435456, - 4194304 + 4294967296, + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.9725717827677727, - "bw_span": 2.6656288653612137, - "infer_memory": 276824064, + "fw_span": 14.046217314898968, + "bw_span": 38.024049531668425, + "infer_memory": 4328521728, "train_mem_info": [ - 4194304, - 268435456 + 16777216, + 4294967296 ], "train_mem2in_idx": [ 1, diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.dropout.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.dropout.json index c7ca2767..f5d483ad 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.dropout.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.dropout.json @@ -1,61 +1,61 @@ { - "(2, 8, 2048, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 8192)-(2, 8, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 268435456 + 4294967296 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.0060519203543663025, - "bw_span": 1.1683166027069092, - "infer_memory": 268435456, + "fw_span": 0.006165704689919949, + "bw_span": 18.729831255041063, + "infer_memory": 4294967296, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 8, 2048, 2048)-(1, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(1, 8, 8192, 8192)-(1, 8, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.005933083593845367, - "bw_span": 0.5830274894833565, - "infer_memory": 134217728, + "fw_span": 0.006175646558403969, + "bw_span": 9.37040860299021, + "infer_memory": 2147483648, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 4, 2048, 2048)-(2, 4, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 4, 8192, 8192)-(2, 4, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.005894899368286133, - "bw_span": 0.5836460739374161, - "infer_memory": 134217728, + "fw_span": 0.005974690429866314, + "bw_span": 9.368689474649727, + "infer_memory": 2147483648, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 1024, 2048)-(2, 8, 1024, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 8, 4096, 8192)-(2, 8, 4096, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.006268918514251709, - "bw_span": 0.5835466086864471, - "infer_memory": 134217728, + "fw_span": 0.005687680095434189, + "bw_span": 9.368521464057267, + "infer_memory": 2147483648, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 2048, 1024)-(2, 8, 2048, 1024) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 4096)-(2, 8, 8192, 4096) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.006314180791378021, - "bw_span": 0.5827156826853752, - "infer_memory": 134217728, + "fw_span": 0.00601569190621376, + "bw_span": 9.366218908689916, + "infer_memory": 2147483648, "train_mem_info": [], "train_mem2in_idx": [] } diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.linear.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.linear.json index 2882d7e6..391c8493 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.linear.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.linear.json @@ -1,179 +1,179 @@ { - "(2, 2048, 512)-(512, 512)-(2, 2048, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "(2, 8192, 512)-(512, 512)-(2, 8192, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [ 1048576 ], "buffer_mem_info": [], - "fw_span": 0.13340581208467484, - "bw_span": 0.11256430298089984, - "infer_memory": 17825792, + "fw_span": 0.456011900678277, + "bw_span": 0.4641034873202443, + "infer_memory": 68157440, "train_mem_info": [ - 8388608 + 33554432 ], "train_mem2in_idx": [ 0 ] }, - "(2, 2048, 512)-(512, 512)-(2, 2048, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8192, 512)-(512, 512)-(2, 8192, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [ 1048576 ], "buffer_mem_info": [], - "fw_span": 0.12639649212360382, - "bw_span": 0.2559272572398186, - "infer_memory": 17825792, + "fw_span": 0.44349594973027706, + "bw_span": 1.0563238989561796, + "infer_memory": 68157440, "train_mem_info": [ - 8388608 + 33554432 ], "train_mem2in_idx": [ 0 ] }, - "(1, 2048, 512)-(512, 512)-(1, 2048, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "(1, 8192, 512)-(512, 512)-(1, 8192, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [ 1048576 ], "buffer_mem_info": [], - "fw_span": 0.06875265389680862, - "bw_span": 0.08587203919887543, - "infer_memory": 9437184, + "fw_span": 0.23856249172240496, + "bw_span": 0.23950396571308374, + "infer_memory": 34603008, "train_mem_info": [ - 4194304 + 16777216 ], "train_mem2in_idx": [ 0 ] }, - "(2, 1024, 512)-(512, 512)-(2, 1024, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "(2, 4096, 512)-(512, 512)-(2, 4096, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [ 1048576 ], "buffer_mem_info": [], - "fw_span": 0.06851088255643845, - "bw_span": 0.08921511471271515, - "infer_memory": 9437184, + "fw_span": 0.24063654709607363, + "bw_span": 0.2445041434839368, + "infer_memory": 34603008, "train_mem_info": [ - 4194304 + 16777216 ], "train_mem2in_idx": [ 0 ] }, - "(2, 2048, 256)-(512, 256)-(2, 2048, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "(2, 8192, 256)-(512, 256)-(2, 8192, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [ 524288 ], "buffer_mem_info": [], - "fw_span": 0.07312837988138199, - "bw_span": 0.09327642619609833, - "infer_memory": 13107200, + "fw_span": 0.24551521055400372, + "bw_span": 0.25930958800017834, + "infer_memory": 50855936, "train_mem_info": [ - 4194304 + 16777216 ], "train_mem2in_idx": [ 0 ] }, - "(2, 2048, 512)-(256, 512)-(2, 2048, 256) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "(2, 8192, 512)-(256, 512)-(2, 8192, 256) : torch.float32-torch.float32-torch.float32 : False-True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [ 524288 ], "buffer_mem_info": [], - "fw_span": 0.0704590231180191, - "bw_span": 0.08722972124814987, - "infer_memory": 13107200, + "fw_span": 0.24952581152319908, + "bw_span": 0.2285446971654892, + "infer_memory": 50855936, "train_mem_info": [ - 8388608 + 33554432 ], "train_mem2in_idx": [ 0 ] }, - "(1, 2048, 512)-(512, 512)-(1, 2048, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(1, 8192, 512)-(512, 512)-(1, 8192, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [ 1048576 ], "buffer_mem_info": [], - "fw_span": 0.0701218843460083, - "bw_span": 0.14707427471876144, - "infer_memory": 9437184, + "fw_span": 0.24234461598098278, + "bw_span": 0.5651834886521101, + "infer_memory": 34603008, "train_mem_info": [ - 4194304 + 16777216 ], "train_mem2in_idx": [ 0 ] }, - "(2, 1024, 512)-(512, 512)-(2, 1024, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 4096, 512)-(512, 512)-(2, 4096, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [ 1048576 ], "buffer_mem_info": [], - "fw_span": 0.07040128111839294, - "bw_span": 0.14658160507678986, - "infer_memory": 9437184, + "fw_span": 0.2415340393781662, + "bw_span": 0.5497391102835536, + "infer_memory": 34603008, "train_mem_info": [ - 4194304 + 16777216 ], "train_mem2in_idx": [ 0 ] }, - "(2, 2048, 256)-(512, 256)-(2, 2048, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8192, 256)-(512, 256)-(2, 8192, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [ 524288 ], "buffer_mem_info": [], - "fw_span": 0.07458627223968506, - "bw_span": 0.14493074268102646, - "infer_memory": 13107200, + "fw_span": 0.24739371147006747, + "bw_span": 0.5779084283858538, + "infer_memory": 50855936, "train_mem_info": [ - 4194304 + 16777216 ], "train_mem2in_idx": [ 0 ] }, - "(2, 2048, 512)-(256, 512)-(2, 2048, 256) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "(2, 8192, 512)-(256, 512)-(2, 8192, 256) : torch.float32-torch.float32-torch.float32 : True-True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [ 524288 ], "buffer_mem_info": [], - "fw_span": 0.070917047560215, - "bw_span": 0.16025099903345108, - "infer_memory": 13107200, + "fw_span": 0.24541020393371582, + "bw_span": 0.6191035965457559, + "infer_memory": 50855936, "train_mem_info": [ - 8388608 + 33554432 ], "train_mem2in_idx": [ 0 diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.softmax.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.softmax.json index 8bd4b216..92421db0 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.softmax.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.softmax.json @@ -1,63 +1,63 @@ { - "(2, 8, 2048, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 8192)-(2, 8, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 268435456 + 4294967296 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 1.0575706139206886, - "bw_span": 3.7416458129882812, - "infer_memory": 536870912, + "fw_span": 23.9705030573532, + "bw_span": 61.606465093791485, + "infer_memory": 8589934592, "train_mem_info": [ - 268435456 + 4294967296 ], "train_mem2in_idx": [ -1 ] }, - "(1, 8, 2048, 2048)-(1, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(1, 8, 8192, 8192)-(1, 8, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.517265684902668, - "bw_span": 1.8769219517707825, - "infer_memory": 268435456, + "fw_span": 11.99300775770098, + "bw_span": 30.78675337601453, + "infer_memory": 4294967296, "train_mem_info": [ - 134217728 + 2147483648 ], "train_mem2in_idx": [ -1 ] }, - "(2, 4, 2048, 2048)-(2, 4, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 4, 8192, 8192)-(2, 4, 8192, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.5280718207359314, - "bw_span": 1.860242709517479, - "infer_memory": 268435456, + "fw_span": 11.974411201663315, + "bw_span": 30.794763541780412, + "infer_memory": 4294967296, "train_mem_info": [ - 134217728 + 2147483648 ], "train_mem2in_idx": [ -1 ] }, - "(2, 8, 1024, 2048)-(2, 8, 1024, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 8, 4096, 8192)-(2, 8, 4096, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 134217728 + 2147483648 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.529155321419239, - "bw_span": 1.8756849691271782, - "infer_memory": 268435456, + "fw_span": 11.989785148762167, + "bw_span": 30.799460248090327, + "infer_memory": 4294967296, "train_mem_info": [ - 134217728 + 2147483648 ], "train_mem2in_idx": [ -1 diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.sum.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.sum.json index f58f2764..d2f1ebdc 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.sum.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.sum.json @@ -1,49 +1,49 @@ { - "(2, 2048, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 512)-(1,) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.02030692994594574, - "bw_span": 0.03039538860321045, - "infer_memory": 8390656, + "fw_span": 0.05562114529311657, + "bw_span": 0.09991039987653494, + "infer_memory": 33556480, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 2048, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "(1, 8192, 512)-(1,) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.015971437096595764, - "bw_span": 0.033936649560928345, - "infer_memory": 4195840, + "fw_span": 0.03152289427816868, + "bw_span": 0.051691546104848385, + "infer_memory": 16779264, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 1024, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "(2, 4096, 512)-(1,) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.0163249671459198, - "bw_span": 0.03274437040090561, - "infer_memory": 4195840, + "fw_span": 0.031939405016601086, + "bw_span": 0.050949002616107464, + "infer_memory": 16779264, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 2048, 256)-(1,) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 256)-(1,) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.01620892435312271, - "bw_span": 0.03477875143289566, - "infer_memory": 4195840, + "fw_span": 0.03197593614459038, + "bw_span": 0.05111047066748142, + "infer_memory": 16779264, "train_mem_info": [], "train_mem2in_idx": [] } diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.transpose.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.transpose.json index 9c71e6e1..29dd5c5b 100644 --- a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.transpose.json +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.transpose.json @@ -1,181 +1,181 @@ { - "(2, 8, 2048, 64)-(2, 8, 64, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 64)-(2, 8, 64, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.01199282705783844, - "bw_span": 0.03650356084108353, - "infer_memory": 8388608, + "fw_span": 0.005893642082810402, + "bw_span": 0.09586436208337545, + "infer_memory": 33554432, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 2048, 8, 64)-(2, 8, 2048, 64) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 8, 64)-(2, 8, 8192, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.005709193646907806, - "bw_span": 0.023253075778484344, - "infer_memory": 8388608, + "fw_span": 0.005636643618345261, + "bw_span": 0.0960023608058691, + "infer_memory": 33554432, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 2048, 64)-(2, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 64)-(2, 8192, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 8388608 + 33554432 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.010336004197597504, - "bw_span": 0.03510527312755585, - "infer_memory": 8388608, + "fw_span": 0.005456642247736454, + "bw_span": 0.09644185192883015, + "infer_memory": 33554432, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 8, 2048, 64)-(1, 8, 64, 2048) : torch.float32-torch.float32 : True-True": { + "(1, 8, 8192, 64)-(1, 8, 64, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.00971127301454544, - "bw_span": 0.03457833081483841, - "infer_memory": 4194304, + "fw_span": 0.0061027007177472115, + "bw_span": 0.04664934240281582, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 4, 2048, 64)-(2, 4, 64, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 4, 8192, 64)-(2, 4, 64, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.00948682427406311, - "bw_span": 0.03553144633769989, - "infer_memory": 4194304, + "fw_span": 0.005742162466049194, + "bw_span": 0.046914396807551384, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 1024, 64)-(2, 8, 64, 1024) : torch.float32-torch.float32 : True-True": { + "(2, 8, 4096, 64)-(2, 8, 64, 4096) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009621679782867432, - "bw_span": 0.03477856516838074, - "infer_memory": 4194304, + "fw_span": 0.0056121498346328735, + "bw_span": 0.04703190643340349, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 2048, 32)-(2, 8, 32, 2048) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 32)-(2, 8, 32, 8192) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009464845061302185, - "bw_span": 0.0379662960767746, - "infer_memory": 4194304, + "fw_span": 0.005402648821473122, + "bw_span": 0.04719835706055164, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 2048, 8, 64)-(1, 8, 2048, 64) : torch.float32-torch.float32 : True-True": { + "(1, 8192, 8, 64)-(1, 8, 8192, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009648129343986511, - "bw_span": 0.03602709621191025, - "infer_memory": 4194304, + "fw_span": 0.0054776668548583984, + "bw_span": 0.04702990408986807, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 1024, 8, 64)-(2, 8, 1024, 64) : torch.float32-torch.float32 : True-True": { + "(2, 4096, 8, 64)-(2, 8, 4096, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009131059050559998, - "bw_span": 0.03459863364696503, - "infer_memory": 4194304, + "fw_span": 0.0057601602748036385, + "bw_span": 0.04674734082072973, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 2048, 4, 64)-(2, 4, 2048, 64) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 4, 64)-(2, 4, 8192, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.00681169331073761, - "bw_span": 0.024368613958358765, - "infer_memory": 4194304, + "fw_span": 0.005707703530788422, + "bw_span": 0.04695134703069925, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 2048, 8, 32)-(2, 8, 2048, 32) : torch.float32-torch.float32 : True-True": { + "(2, 8192, 8, 32)-(2, 8, 8192, 32) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.006612204015254974, - "bw_span": 0.024354644119739532, - "infer_memory": 4194304, + "fw_span": 0.00533214770257473, + "bw_span": 0.04725989419966936, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(1, 8, 2048, 64)-(1, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "(1, 8, 8192, 64)-(1, 8192, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.009928643703460693, - "bw_span": 0.03518201410770416, - "infer_memory": 4194304, + "fw_span": 0.005862163379788399, + "bw_span": 0.0465993769466877, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 4, 2048, 64)-(2, 2048, 4, 64) : torch.float32-torch.float32 : True-True": { + "(2, 4, 8192, 64)-(2, 8192, 4, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.005669891834259033, - "bw_span": 0.024593807756900787, - "infer_memory": 4194304, + "fw_span": 0.009811297059059143, + "bw_span": 0.0724355923011899, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 1024, 64)-(2, 1024, 8, 64) : torch.float32-torch.float32 : True-True": { + "(2, 8, 4096, 64)-(2, 4096, 8, 64) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.010106153786182404, - "bw_span": 0.03500021994113922, - "infer_memory": 4194304, + "fw_span": 0.009760749526321888, + "bw_span": 0.04399379249662161, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] }, - "(2, 8, 2048, 32)-(2, 2048, 8, 32) : torch.float32-torch.float32 : True-True": { + "(2, 8, 8192, 32)-(2, 8192, 8, 32) : torch.float32-torch.float32 : True-True": { "in_mem_info": [ - 4194304 + 16777216 ], "param_mem_info": [], "buffer_mem_info": [], - "fw_span": 0.010841339826583862, - "bw_span": 0.034562498331069946, - "infer_memory": 4194304, + "fw_span": 0.009862799197435379, + "bw_span": 0.05074900109320879, + "infer_memory": 16777216, "train_mem_info": [], "train_mem2in_idx": [] } From 2a50991a1666fd8457ab5db7b48fd453a101191b Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Thu, 13 Feb 2025 05:34:00 +0000 Subject: [PATCH 1806/1892] Merged PR 2364: [UT] Fix pipeline_nstages UT Forgot to rerun UT after solving comments. There is another UT fail caused by the github PR https://dev.azure.com/msrasrg/SuperScaler/_git/MagicCube/pullrequest/2363 / https://github.com/microsoft/nnscaler/pull/21 Asking @ to fix it. --- tests/autodist/test_pipeline_nstages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/autodist/test_pipeline_nstages.py b/tests/autodist/test_pipeline_nstages.py index 07952abb..2b3413da 100644 --- a/tests/autodist/test_pipeline_nstages.py +++ b/tests/autodist/test_pipeline_nstages.py @@ -45,7 +45,7 @@ def _compile(tmp_path, nstages): runtime_ngpus=2, use_end2end=True, pas_config={ - 'pipeline_pivots': 'Linear', + 'pipeline_pivots': 'Linear' if nstages != 1 else '', 'pipeline_nstages': nstages, 'max_pipeline_bubble_ratio': 0.99, # force autodist to accept unbalanced stages }, From 075979f6b2095597996e1e38cd2eae51157832b7 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Thu, 20 Feb 2025 08:21:34 +0000 Subject: [PATCH 1807/1892] Merged PR 2368: [AutoDist] Fix when model too small autodist does not use all GPUs When the model is too small, autodist may choose not to use all plan_ngpus. This kind of plans are not acceptable to downstream codegen. Also removed a fixed fixme. --- examples/llama3_demo/train.py | 2 - nnscaler/autodist/autodist_config.py | 2 +- nnscaler/autodist/pipeline_solver.py | 25 ++++++++-- nnscaler/policies.py | 2 + tests/autodist/test_pipeline_nstages.py | 66 ++++++++++++++++++++----- 5 files changed, 76 insertions(+), 21 deletions(-) diff --git a/examples/llama3_demo/train.py b/examples/llama3_demo/train.py index 11850605..8c271a35 100644 --- a/examples/llama3_demo/train.py +++ b/examples/llama3_demo/train.py @@ -30,8 +30,6 @@ from nnscaler.runtime.f16_optimizer import MixedPrecisionAdamW import nnscaler.utils -import torch._dynamo # FIXME: a workaround to avoid tracing the dynamic import - model_id = 'meta-llama/Meta-Llama-3-8B-Instruct' tokenizer_id = model_id diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index c5e4ab77..2b33f202 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -144,7 +144,7 @@ def __init__(self, pipeline_nstages='auto', pipeline_scheduler='1f1b', max_pipeline_bubble_ratio=0.2, - max_pipeline_unbalance_ratio=0.5, + max_pipeline_unbalance_ratio=0.5, # FIXME: this is in fact "min_pipeline_balance_ratio" solver='dp', parallel_profile=True, transient_mem_coef=2, diff --git a/nnscaler/autodist/pipeline_solver.py b/nnscaler/autodist/pipeline_solver.py index 75f085f6..ea7acfcd 100644 --- a/nnscaler/autodist/pipeline_solver.py +++ b/nnscaler/autodist/pipeline_solver.py @@ -295,8 +295,18 @@ def calc_optimal_pp_plan( val = max(lhs, rhs) if T[cur_idx][0] > val: T[cur_idx] = [val, prev_idx] - best_time = float('inf') + + # why there are two bests here: + # if the model is too small, the best solution may not fully utilize all gpus (because of comm overhead, etc) + # this violates the user's plan_ngpus config so we must pick another + # here it records both bests to identify this case and warn the user + # best_time/state: the overall best solution + # valid_best_time/state: the best solution when we respect plan_ngpus config + best_time = math.inf best_state = (-1, -1, -1, -1) + valid_best_time = math.inf + valid_best_state = (-1, -1, -1, -1) + micro_batch_num = autodist_config.update_freq for stage_num in range(1, ngpus + 1): if autodist_config.pipeline_nstages != 'auto' and autodist_config.pipeline_nstages != stage_num: @@ -312,10 +322,15 @@ def calc_optimal_pp_plan( cur_time = T[cur_idx][0] * (micro_batch_num - 1 + stage_num) if best_time > cur_time: best_time, best_state = cur_time, cur_idx + if pp_dev_num == ngpus and valid_best_time > cur_time: + valid_best_time, valid_best_state = cur_time, cur_idx _logger.info( - f'best time/s: {best_time}, state (s, pp, tp, i): {best_state}') - if best_state == (-1, -1, -1, -1): + f'best time/s: {valid_best_time}, state (s, pp, tp, i): {valid_best_state}') + if best_state != valid_best_state: + _s, pp, _tp, _i = best_state + _logger.warning(f'the model is too small for {ngpus} GPUs; please use {pp} GPUs for better performance') + if valid_best_state == (-1, -1, -1, -1): raise RuntimeError('fail to find a valid pipeline plan') spmd_outs = [] @@ -331,12 +346,12 @@ def build_answer(s, pp, tp, i): if prev_idx[0] != -1: build_answer(*prev_idx) - build_answer(*best_state) + build_answer(*valid_best_state) spmd_descs = [spmd_out.desc for spmd_out in spmd_outs] pp_desc = PipelineParallelDesc(spmd_descs, [], autodist_config.mesh_desc) stage_mems = [spmd_out.memory for spmd_out in spmd_outs] stage_all_times = [spmd_out.all_time for spmd_out in spmd_outs] stage_comp_times = [spmd_out.comp_time for spmd_out in spmd_outs] - return PipelineSearchOutput(pp_desc, best_time, stage_mems, stage_all_times, + return PipelineSearchOutput(pp_desc, valid_best_time, stage_mems, stage_all_times, stage_comp_times) diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 5d534ab5..f1db5858 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -274,6 +274,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: recompute_modules = pas_cfg.get('recompute_modules', '') pipeline_pivots = pas_cfg.get('pipeline_pivots', '') max_pipeline_bubble_ratio = pas_cfg.get('max_pipeline_bubble_ratio', 0.2) + max_pipeline_unbalance_ratio = pas_cfg.get('max_pipeline_unbalance_ratio', 0.5) use_apex_fused_adam_v2 = pas_cfg.get('use_apex_fused_adam_v2', False) parallel_profile = pas_cfg.get('parallel_profile', True) transient_mem_coef = pas_cfg.get('transient_mem_coef', 2) @@ -340,6 +341,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: pipeline_pivots=pipeline_pivots, pipeline_nstages=pipeline_nstages, max_pipeline_bubble_ratio=max_pipeline_bubble_ratio, + max_pipeline_unbalance_ratio=max_pipeline_unbalance_ratio, parallel_profile=parallel_profile, transient_mem_coef=transient_mem_coef, ) diff --git a/tests/autodist/test_pipeline_nstages.py b/tests/autodist/test_pipeline_nstages.py index 2b3413da..28f9d2bc 100644 --- a/tests/autodist/test_pipeline_nstages.py +++ b/tests/autodist/test_pipeline_nstages.py @@ -1,57 +1,97 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import logging import os +import unittest.mock import pytest import torch +import nnscaler.autodist.pipeline_solver from nnscaler.cli.trainer import Trainer from nnscaler.cli.trainer_args import * -from tests.launch_torchrun import launch_torchrun +from .. import utils -@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='need 2 gpus') + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='no gpu') def test_1_stage(tmp_path): - launch_torchrun(1, _compile_worker, tmp_path, 1) + _compile(tmp_path, 1) # for TP, the scripts should be identical (except tensor names) lines0 = _count_gencode_lines(tmp_path, 0) lines1 = _count_gencode_lines(tmp_path, 1) assert lines0 == lines1 -@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='need 2 gpus') +@pytest.mark.skipif(not torch.cuda.is_available(), reason='no gpu') def test_2_stages(tmp_path): - launch_torchrun(1, _compile_worker, tmp_path, 2) + _compile(tmp_path, 2) # for PP, since we have 3 linears, the scripts should be different lines0 = _count_gencode_lines(tmp_path, 0) lines1 = _count_gencode_lines(tmp_path, 1) assert lines0 != lines1 -@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='need 2 gpus') +@pytest.mark.skipif(not torch.cuda.is_available(), reason='no gpu') def test_auto_stages(tmp_path): - launch_torchrun(1, _compile_worker, tmp_path, 'auto') + _compile(tmp_path, 'auto') # just check it does not throw # because both results are possible theoretically -def _compile_worker(tmp_path, nstages): - _compile(tmp_path, nstages) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='no gpu') +def test_small(tmp_path): + # check it spreads the model to all gpus even when less are required + + # the graph is as follow: + # [0] data['x'] + # [1] linear1 + # [2] linear2 + # [3] linear3 + # [4] data['y'] + # [5] cross_entroy + # we assume linear costs 0.2, cross_entroy costs 0.3, getitem costs 0, unavoidable overhead costs 0.1, + # and ngpus does not affect the time + # since tp has no gain in our "profiling", the algorithm will tend to use 1 gpu per stage + + # the "best" result can be T[2,2,1,0] <- min(T[1,1,1,3]=tp_info[1,1,3,5], tp_info[1,2,0,2]) + # what we expect is T[2,4,2,0] <- min(T[1,2,2,3]=tp_info[2,1,3,5], tp_info[2,2,2,2]) + + costs = [0.0, 0.2, 0.2, 0.2, 0.0, 0.3] + + orig_compute_tp_info = nnscaler.autodist.pipeline_solver._compute_tp_info + + def patched_compute_tp_info(model_graph, cfg, legal_tp_degrees): + tp_info = orig_compute_tp_info(model_graph, cfg, legal_tp_degrees) + for k, v in tp_info.items(): + _ngpus, _nstages, start, end = k + v.all_time = 0.1 + sum(costs[i] for i in range(start, end + 1)) + return tp_info + patch = unittest.mock.patch('nnscaler.autodist.pipeline_solver._compute_tp_info', patched_compute_tp_info) -def _compile(tmp_path, nstages): + with utils.catch_log(nnscaler.autodist.pipeline_solver._logger, 'WARNING') as log: + with patch: + _compile(tmp_path, nstages=2, ngpus=4) + + assert 'model is too small' in log.getvalue() + + +def _compile(tmp_path, nstages, ngpus=2): trainer_args = TrainerArgs( compute_config=ComputeConfig( - plan_ngpus=2, - runtime_ngpus=2, + plan_ngpus=ngpus, + runtime_ngpus=ngpus, use_end2end=True, pas_config={ 'pipeline_pivots': 'Linear' if nstages != 1 else '', 'pipeline_nstages': nstages, 'max_pipeline_bubble_ratio': 0.99, # force autodist to accept unbalanced stages + 'max_pipeline_unbalance_ratio': 0.01, }, ), - gen_reuse='override', gen_savedir=tmp_path/'src', + gen_reuse='override', + broadcast_strategy='none', run_mode='compile', model=ModelConfig(type=Model), optimizer=OptimizerConfig(type=torch.optim.AdamW), From e81bb5d8e00f8ce0f3716c25f2e5101932efa498 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 21 Mar 2025 08:56:31 +0000 Subject: [PATCH 1808/1892] Merged PR 2362: [Example] longrope --- examples/longrope2/.gitignore | 7 + examples/longrope2/README.rst | 68 +++++ examples/longrope2/__init__.py | 0 examples/longrope2/data/__init__.py | 0 examples/longrope2/data/dataset.py | 37 +++ examples/longrope2/data/download.py | 126 ++++++++ examples/longrope2/data/process.py | 140 +++++++++ .../longrope2/llama3_8b_longrope2_config.json | 31 ++ examples/longrope2/requirement.txt | 4 + examples/longrope2/rope_modifier.py | 56 ++++ examples/longrope2/run.sh | 13 + examples/longrope2/train.py | 288 ++++++++++++++++++ examples/transformers_utils/__init__.py | 14 + .../transformers_utils/causal_lm_wrapper.py | 63 ++++ .../chunk_linear_cross_entropy.py | 68 +++++ .../transformers_utils/flash_attn_anno.py | 139 +++++++++ examples/transformers_utils/tokenizer.py | 24 ++ examples/warmup_schedular.py | 82 +++++ nnscaler/graph/tracer/wrap_utils.py | 2 + 19 files changed, 1162 insertions(+) create mode 100644 examples/longrope2/.gitignore create mode 100644 examples/longrope2/README.rst create mode 100644 examples/longrope2/__init__.py create mode 100644 examples/longrope2/data/__init__.py create mode 100644 examples/longrope2/data/dataset.py create mode 100644 examples/longrope2/data/download.py create mode 100644 examples/longrope2/data/process.py create mode 100644 examples/longrope2/llama3_8b_longrope2_config.json create mode 100644 examples/longrope2/requirement.txt create mode 100644 examples/longrope2/rope_modifier.py create mode 100644 examples/longrope2/run.sh create mode 100644 examples/longrope2/train.py create mode 100644 examples/transformers_utils/__init__.py create mode 100644 examples/transformers_utils/causal_lm_wrapper.py create mode 100644 examples/transformers_utils/chunk_linear_cross_entropy.py create mode 100644 examples/transformers_utils/flash_attn_anno.py create mode 100644 examples/transformers_utils/tokenizer.py create mode 100644 examples/warmup_schedular.py diff --git a/examples/longrope2/.gitignore b/examples/longrope2/.gitignore new file mode 100644 index 00000000..497bbed0 --- /dev/null +++ b/examples/longrope2/.gitignore @@ -0,0 +1,7 @@ +data/fineweb-edu* +data/RedPajama-Data-1T* +data/mix-context-win-* +*.log +runs +checkpoints*/ +gpucore.* diff --git a/examples/longrope2/README.rst b/examples/longrope2/README.rst new file mode 100644 index 00000000..35648d40 --- /dev/null +++ b/examples/longrope2/README.rst @@ -0,0 +1,68 @@ +########################################## +LongRope2 context length extension Example +########################################## + +************ +Introduction +************ + +`LongRoPE2 `_ is an advanced version of `LongRoPE `_ that significantly improves long-context extension for RoPE-based LLMs. It has been adopted in Phi4-mini and Phi4-multimodal. + +This example includes the training part for LongRope2. Before training, please using `LongRoPE repo ` for searching the rope extension scaling factor for your model. +This example provides the extension scaling factor of llama3-8b-base as a reference. If you want to have a try with llama3-8b-base, you can run this example directly. + + +*********** +Preparation +*********** + +If this is the first time you use nnScalar, it would be better start with ``examples/llama`` for more using detail. +But it is OK to directly follow this example to run pass. + +Assume following packages have been installed in the environment. :: + + nnscaler + zstandard + transformers>=4.48 + datasets + tensorboard + apex + flash-attn + +A new model config includes the longrope ``rope_scaling`` field and ``original_max_position_embeddings`` are needed, please reference ``examples/longrope2/llama3_8b_longrope2_config.json`` + + +**************** +Data Preparation +**************** + +We use ``HuggingFaceFW/fineweb-edu`` for short context window training and ``togethercomputer/RedPajama-Data-1T`` for long context window training. + +.. code-block:: bash + export PYTHONPATH=$PYTHONPATH:/home/USER_NAME/MagicCube:/home/USER_NAME/MagicCube/examples + # download data to at MagicCube/examples/longrope2/data, will take around 100GB disk memory. + python data/download.py + # process the data to mix context window length format for long context training, will take around 900GB disk memory. + python data/process.py --tokenizer_name_or_path "meta-llama/Meta-Llama-3-8B" + +If you don't have large disk memory, i.e., 1 TB free memory, you could take a sub-dataset by modify the code. + + +******** +Training +******** + +The main different compared with the common long context training example ``examples/llama`` is we need to pass ``--model_config`` to passin the rope extension scaling factor to the model. + +.. code-block:: bash + # compile the distributed code for llama3 model with dp2, tp4 on 8 gpus + python train.py --run_mode compile --model_id "meta-llama/Meta-Llama-3-8B" --model_config llama3_8b_longrope2_config.json --dataset_path data/mix-context-win-short-8192-long-131072 --plan_ngpus=4 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --gpu_mem_constraint 64 --enable-chunk-loss --grad_accumulation_steps 16 --max_train_steps 2250 2>&1 | tee compile.log + # run the training job + torchrun --nproc_per_node=8 train.py --model_id "meta-llama/Meta-Llama-3-8B" --model_config llama3_8b_longrope2_config.json --dataset_path data/mix-context-win-short-8192-long-131072 --plan_ngpus=4 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --gpu_mem_constraint 64 --enable-chunk-loss --grad_accumulation_steps 16 --max_train_steps 2250 2>&1 | tee run.log + + +********** +Additional +********** + +More details about how to change distributed plan or merge checkpoints, please reference ``examples/llama/README.rst``. diff --git a/examples/longrope2/__init__.py b/examples/longrope2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/longrope2/data/__init__.py b/examples/longrope2/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/longrope2/data/dataset.py b/examples/longrope2/data/dataset.py new file mode 100644 index 00000000..29eacf2d --- /dev/null +++ b/examples/longrope2/data/dataset.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch + +from datasets import load_from_disk +from transformers import DataCollatorForLanguageModeling + +from ...transformers_utils import get_tokenizer + + +IGNORE_IDX=-100 + + +def get_dataset(dataset_path, tokenizer_name_or_path): + dataset = load_from_disk(dataset_path) + tokenizer = get_tokenizer(tokenizer_name_or_path) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + def collate_fn(samples): + if len(samples) == 0: + return {} + + mini_batch = dict(data_collator(samples)) + seq_len = mini_batch['input_ids'].size(-1) + shift_labels = mini_batch['input_ids'][..., 1:] + mini_batch['labels'] = torch.nn.functional.pad(shift_labels, (0, 1), 'constant', IGNORE_IDX).contiguous() + + # cast `nsentences` and `ntokens` to tensor since current pipeline parallelism can only transfer data in tensor format + return { + "nsentences": torch.tensor(len(samples), dtype=torch.long), + "ntokens": torch.tensor(len(samples) * seq_len, dtype=torch.long), + "net_input": mini_batch, + "target": mini_batch.pop('labels'), + } + + return dataset, collate_fn diff --git a/examples/longrope2/data/download.py b/examples/longrope2/data/download.py new file mode 100644 index 00000000..535aa948 --- /dev/null +++ b/examples/longrope2/data/download.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import io +import json +import os +import subprocess +import time +import zstandard as zstd + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from tqdm import tqdm +from huggingface_hub import snapshot_download + + +ROOT_SAVE_DIR = Path(__file__).parent +MAX_WORKERS = 16 + + +def read_jsonl_zst(file_path): + if str(file_path).endswith('.zst'): + with open(file_path, 'rb') as f: + dctx = zstd.ZstdDecompressor() + with dctx.stream_reader(f) as reader: + text_stream = io.TextIOWrapper(reader, encoding='utf-8') + for line in tqdm(text_stream): + data = json.loads(line) + yield data + else: + with open(file_path, 'r') as f: + for line in tqdm(f): + data = json.loads(line) + yield data + + +def filter_jsonl_zst(file_path, min_text_length=None, max_text_length=None): + def filter_func(data): + if min_text_length and len(data["text"]) < min_text_length: + return False + if max_text_length and len(data["text"]) > max_text_length: + return False + return True + + filtered_data = [] + for data in read_jsonl_zst(file_path): + if filter_func(data): + filtered_data.append(json.dumps(data)+'\n') + + os.remove(file_path) + if not str(file_path).endswith('.zst'): + file_path = str(file_path) + '.zst' + + with open(file_path, 'wb') as f: + cctx = zstd.ZstdCompressor() + with cctx.stream_writer(f) as writer: + writer.write(''.join(filtered_data).encode('utf-8')) + print(f"{Path(file_path).name} sample number: {len(filtered_data)}") + + +def download_file(url, download_folder, retries=3, delay=5, min_text_length=None, max_text_length=None): + attempt = 0 + while attempt <= retries: + try: + wget_command = ['wget', '-P', download_folder, url] + subprocess.run(wget_command, check=True) + print(f"Downloaded: {url}") + if min_text_length or max_text_length: + file_name = url.split("/")[-1] + filter_jsonl_zst(os.path.join(download_folder, file_name), min_text_length, max_text_length) + return True + except subprocess.CalledProcessError as e: + attempt += 1 + if attempt > retries: + print(f"Failed to download {url} after {retries} retries: {e}") + return False + else: + print(f"Retrying {url} ({attempt}/{retries})...") + time.sleep(delay) + + +def download_files_with_wget(urls, download_folder, retries=3, delay=5, min_text_length=None, max_text_length=None, max_workers=8): + if not os.path.exists(download_folder): + os.makedirs(download_folder) + + with ThreadPoolExecutor(max_workers) as executor: + futures = {executor.submit(download_file, url, download_folder, retries, delay, min_text_length, max_text_length): url for url in urls if url} + for future in as_completed(futures): + url = futures[future] + try: + future.result() + except Exception as e: + print(f"Exception occurred while downloading {url}: {e}") + + +if __name__ == "__main__": + root_save_dir = ROOT_SAVE_DIR + max_workers = MAX_WORKERS + + # For short context, using fineweb-edu dataset as example + snapshot_download( + "HuggingFaceFW/fineweb-edu", + repo_type="dataset", + local_dir=root_save_dir / "fineweb-edu", + allow_patterns="sample/10BT/*", + max_workers=max_workers, + ) + + # For long context, using RedPajama-Data-1T dataset as example + snapshot_download( + "togethercomputer/RedPajama-Data-1T", + repo_type="dataset", + local_dir=root_save_dir / "RedPajama-Data-1T", + allow_patterns="urls/*", + max_workers=max_workers, + ) + + for split in ["arxiv", "wikipedia"]: + with (root_save_dir / "RedPajama-Data-1T" / "urls" / f"{split}.txt").open("r") as f: + urls = [url.strip() for url in f.readlines() if url.strip()] + download_files_with_wget(urls, root_save_dir / "RedPajama-Data-1T" / split, min_text_length=32 * 1024, max_text_length=800 * 1024, max_workers=max_workers) + + with (root_save_dir / "RedPajama-Data-1T" / "urls" / "common_crawl.txt").open("r") as f: + # using 2023-06/en_head only for demonstration + urls = [url.strip() for url in f.readlines() if (url.strip() and "2023-06/en_head" in url)] + download_files_with_wget(urls, root_save_dir / "RedPajama-Data-1T" / "common_crawl", min_text_length=32 * 1024, max_text_length=800 * 1024, max_workers=max_workers) diff --git a/examples/longrope2/data/process.py b/examples/longrope2/data/process.py new file mode 100644 index 00000000..3a8e0856 --- /dev/null +++ b/examples/longrope2/data/process.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from datasets import load_from_disk, concatenate_datasets, load_dataset +from pathlib import Path +from transformers import PreTrainedTokenizer +from typing import List, Dict + +from examples.transformers_utils import get_tokenizer + + +ROOT_SAVE_DIR = Path(__file__).parent +MAX_WORKERS = 32 + + +def tokenize(samples: Dict[str, List], tokenizer: PreTrainedTokenizer, min_length=None, max_length=None): + def condition(text): + return (min_length is None or len(text) >= min_length) and (max_length is None or len(text) <= max_length) + + input_ids_list = [] + for text in samples["text"]: + if condition(text): + input_ids_list.append(tokenizer.encode(tokenizer.bos_token + text + tokenizer.eos_token, add_special_tokens=False)) + return {"input_ids": input_ids_list, "length": [len(input_ids) for input_ids in input_ids_list]} + + +def cat(samples: Dict[str, List], max_seq_len=128 * 1024, context_len=128 * 1024): + input_ids_list = [] + position_ids_list = [] + + input_ids_buffer = [] + position_ids_buffer = [] + + for input_ids in samples["input_ids"]: + input_ids = input_ids[:context_len] + input_ids_buffer.extend(input_ids) + position_ids_buffer.extend(range(len(input_ids))) + + if len(input_ids_buffer) >= max_seq_len: + input_ids_list.append(input_ids_buffer[:max_seq_len]) + position_ids_list.append(position_ids_buffer[:max_seq_len]) + + input_ids_buffer = [] + position_ids_buffer = [] + + return {"input_ids": input_ids_list, "position_ids": position_ids_list} + + +if __name__ == "__main__": + root_save_dir = ROOT_SAVE_DIR + max_workers = MAX_WORKERS + + parser = argparse.ArgumentParser(description="Set the tokenizer name or path.") + parser.add_argument( + '--tokenizer_name_or_path', + type=str, + default='meta-llama/Meta-Llama-3-8B', + help='Path to the tokenizer model or name of the tokenizer.' + ) + args = parser.parse_args() + tokenizer_name_or_path = args.tokenizer_name_or_path + + max_seq_len = 128 * 1024 + short_context_len = 8 * 1024 + long_context_len = 128 * 1024 + short_len_split = [64, 2 * 1024, 4 * 1024, 9 * 1024] + long_len_split = [8 * 1024, 32 * 1024, 64 * 1024, 128 * 1024, 200 * 1024] + + tokenizer = get_tokenizer(tokenizer_name_or_path) + + fweb_dataset = load_dataset(str(root_save_dir / "fineweb-edu/sample/10BT"), num_proc=max_workers)["train"] + fweb_dataset = fweb_dataset.map(tokenize, + fn_kwargs={"tokenizer": tokenizer, "max_length": 5 * short_len_split[-1]}, + batched=True, + num_proc=max_workers, + batch_size=10000, + remove_columns=fweb_dataset.column_names) + + sub_fweb_datasets = [] + fweb_sample_size = [8000, 8000, 16000] + for left, right, size in zip(short_len_split[:-1], short_len_split[1:], fweb_sample_size): + assert left < right + fweb_dataset_idx = [idx for idx, length in enumerate(fweb_dataset["length"]) if left < length <= right] + sub_fweb_dataset = fweb_dataset.select(fweb_dataset_idx) + sub_fweb_dataset = fweb_dataset.map(cat, + fn_kwargs={"max_seq_len": max_seq_len, "context_len": short_context_len}, + batched=True, + num_proc=max_workers, + batch_size=10000, + remove_columns=sub_fweb_dataset.column_names) + sub_fweb_dataset = sub_fweb_dataset.select(range(size)) + print(f"Short context [cat]: {left} - {right}, sample size: {len(sub_fweb_dataset)}") + sub_fweb_datasets.append(sub_fweb_dataset) + concatenate_datasets(sub_fweb_datasets).save_to_disk(root_save_dir / "fineweb-edu-sample-10BT-short-context") + del sub_fweb_datasets + + for split in ["arxiv", "common_crawl", "wikipedia"]: + rp_dataset = load_dataset(str(root_save_dir / "RedPajama-Data-1T" / f"{split}"), num_proc=max_workers)["train"] + rp_dataset = rp_dataset.map(tokenize, + batched=True, + fn_kwargs={"tokenizer": tokenizer, "min_length": 4 * long_len_split[0], "max_length": 5 * long_len_split[-1]}, + num_proc=max_workers, + batch_size=10000, + remove_columns=rp_dataset.column_names) + rp_dataset_idx = [idx for idx, length in enumerate(rp_dataset["length"]) if long_len_split[0] < length <= long_len_split[-1]] + rp_dataset = rp_dataset.select(rp_dataset_idx) + print(f"Long context [{split} filter]: {long_len_split[0]} - {long_len_split[-1]}, sample size: {len(rp_dataset)}") + rp_dataset.save_to_disk(root_save_dir / f"RedPajama-Data-1T-{split}-long-context-filtered") + del rp_dataset + + sub_rp_datasets = [] + rp_sample_size = { + "arxiv": [3000, 4000, 8000, 3000], + "common_crawl": [2000, 3000, 8000, 5000], + "wikipedia": [3000, 1000, 0, 0] + } + for split in ["arxiv", "common_crawl", "wikipedia"]: + rp_dataset = load_from_disk(root_save_dir / f"RedPajama-Data-1T-{split}-long-context-filtered") + for left, right, size in zip(long_len_split[:-1], long_len_split[1:], rp_sample_size[split]): + assert left < right + rp_dataset_idx = [idx for idx, length in enumerate(rp_dataset["length"]) if left < length <= right] + sub_rp_dataset = rp_dataset.select(rp_dataset_idx) + sub_rp_dataset = sub_rp_dataset.map(cat, + fn_kwargs={"max_seq_len": max_seq_len, "context_len": long_context_len}, + batched=True, + num_proc=max_workers, + batch_size=10000, + remove_columns=sub_rp_dataset.column_names) + sub_rp_dataset = sub_rp_dataset.select(range(size)) + print(f"Long context [{split} cat]: {left} - {right}, sample size: {len(sub_rp_dataset)}") + sub_rp_datasets.append(sub_rp_dataset) + concatenate_datasets(sub_rp_datasets).save_to_disk(root_save_dir / "RedPajama-Data-1T-long-context") + del sub_rp_datasets + + # create final mix context window dataset + concatenate_datasets([ + load_from_disk(root_save_dir / "fineweb-edu-sample-10BT-short-context"), + load_from_disk(root_save_dir / "RedPajama-Data-1T-long-context"), + ]).save_to_disk(root_save_dir / f"mix-context-win-short-{short_context_len}-long-{long_context_len}") diff --git a/examples/longrope2/llama3_8b_longrope2_config.json b/examples/longrope2/llama3_8b_longrope2_config.json new file mode 100644 index 00000000..8e70ce4d --- /dev/null +++ b/examples/longrope2/llama3_8b_longrope2_config.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "original_max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "type": "longrope", + "long_factor": [1, 1.097906416, 1.205398499, 1.323414747, 1.452985542, 1.595242149, 1.751426591, 1.922902493, 2.111166985, 2.317863779, 2.544797515, 2.79394952, 3.067495105, 3.367822558, 3.697553996, 4.059568257, 4.457026037, 4.893397485, 5.372492496, 5.898493984, 6.475994392, 7.110035795, 7.806153921, 8.570426477, 9.40952622, 10.33077921, 11.34222878, 12.45270576, 13.67190555, 15.01047283, 16.48009443, 18.09360142, 19.8650811, 21.81, 21.87, 21.88, 22.23, 22.28, 22.53, 22.65, 22.73, 22.89, 22.92, 23.18, 23.23, 23.36, 23.42, 23.47, 23.51, 23.52, 23.75, 23.82, 23.89, 24.06, 24.13, 24.14, 24.3, 24.53, 24.67, 24.75, 24.97, 25.49, 26.01, 26.05], + "short_factor": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/examples/longrope2/requirement.txt b/examples/longrope2/requirement.txt new file mode 100644 index 00000000..d17e4d81 --- /dev/null +++ b/examples/longrope2/requirement.txt @@ -0,0 +1,4 @@ +zstandard +transformers>=4.48 +datasets +tensorboard diff --git a/examples/longrope2/rope_modifier.py b/examples/longrope2/rope_modifier.py new file mode 100644 index 00000000..b10852d7 --- /dev/null +++ b/examples/longrope2/rope_modifier.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch + +from nnscaler.graph.parser.register import register_op + +def get_longrope_inv_freq(position_ids, base, head_dim, original_max_position_embeddings, long_factor, short_factor): + seq_len = torch.max(position_ids) + 1 + if seq_len and seq_len > original_max_position_embeddings: + ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=position_ids.device) + else: + ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=position_ids.device) + inv_freq_shape = torch.arange(0, head_dim, 2, dtype=torch.int64, device=position_ids.device).float() / head_dim + inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) + return inv_freq + +register_op("b^ l^ -> ?")(get_longrope_inv_freq) + + +@torch.no_grad() +def longrope_forward(self, x, position_ids): + assert self.rope_type == "longrope" + base = self.config.rope_theta + head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) + long_factor = self.config.rope_scaling["long_factor"] + short_factor = self.config.rope_scaling["short_factor"] + + if hasattr(self.config, "original_max_position_embeddings"): + original_max_position_embeddings = self.config.original_max_position_embeddings + else: + original_max_position_embeddings = self.config.max_position_embeddings + inv_freq = get_longrope_inv_freq(position_ids, base, head_dim, original_max_position_embeddings, long_factor, short_factor) + + # Core RoPE block + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def modify_rope_cls(cls): + cls.forward = longrope_forward + return cls diff --git a/examples/longrope2/run.sh b/examples/longrope2/run.sh new file mode 100644 index 00000000..d1395102 --- /dev/null +++ b/examples/longrope2/run.sh @@ -0,0 +1,13 @@ +SCRIPT_DIR=$(dirname "$(realpath "$0")") +EXAMPLE_DIR=$(dirname "$SCRIPT_DIR") +MAGICCUBE_DIR=$(dirname "$EXAMPLE_DIR") +export PYTHONPATH=$PYTHONPATH:$MAGICCUBE_DIR:$EXAMPLE_DIR + +# download data to at MagicCube/examples/longrope2/data, will take around 100GB disk memory. +python data/download.py +# process the data to mix context window length format for long context training, will take around 900GB disk memory. +python data/process.py --tokenizer_name_or_path "meta-llama/Meta-Llama-3-8B" +# compile the distributed code for llama3 model with dp2, tp4 on 8 gpus +python train.py --run_mode compile --model_id "meta-llama/Meta-Llama-3-8B" --model_config llama3_8b_longrope2_config.json --dataset_path data/mix-context-win-short-8192-long-131072 --plan_ngpus=4 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --gpu_mem_constraint 64 --enable-chunk-loss --grad_accumulation_steps 16 --max_train_steps 2250 2>&1 | tee compile.log +# run the training job +torchrun --nproc_per_node=8 train.py --model_id "meta-llama/Meta-Llama-3-8B" --model_config llama3_8b_longrope2_config.json --dataset_path data/mix-context-win-short-8192-long-131072 --plan_ngpus=4 --runtime_ngpus=8 --recompute_modules LlamaDecoderLayer --gpu_mem_constraint 64 --enable-chunk-loss --grad_accumulation_steps 16 --max_train_steps 2250 2>&1 | tee run.log diff --git a/examples/longrope2/train.py b/examples/longrope2/train.py new file mode 100644 index 00000000..4e002cae --- /dev/null +++ b/examples/longrope2/train.py @@ -0,0 +1,288 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import os + +import torch + +from nnscaler.utils import set_default_logger_level +from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer_args import ( + CheckpointConfig, + DatasetConfig, + ModelConfig, + OptimizerConfig, + TrainerArgs, + DataloaderConfig, + LogConfig, + DatasetSamplerConfig, + LRSchedulerConfig, +) +from nnscaler.parallel import ComputeConfig +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdamW +from nnscaler.cli.loggers.tensorboard import TensorBoardLogger + +from examples.transformers_utils import WrapperModel, aggregate_outputs_fn +from examples.longrope2.data.dataset import get_dataset +from examples.longrope2.rope_modifier import modify_rope_cls +from examples.warmup_schedular import WarmupCosineAnnealingLR + + +def main(args): + + if args.run_mode == 'run': + broadcast_strategy = 'all' + else: + broadcast_strategy = 'none' + + set_default_logger_level('INFO') + + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + modify_rope_cls(LlamaRotaryEmbedding) + + ## Setup Dataset ## + + dataset, collate_fn = get_dataset(args.dataset_path, args.model_id) + + ## Config Trainer ## + + if args.run_mode == 'compile': + if args.runtime_ngpus is None: + raise ValueError('runtime_ngpus must be specified in compile mode') + runtime_ngpus = args.runtime_ngpus + elif args.run_mode == 'run': + world_size = int(os.getenv('WORLD_SIZE')) + if args.runtime_ngpus is None: + runtime_ngpus = world_size + else: + if args.runtime_ngpus != world_size: + raise ValueError('runtime_ngpus must match the number of GPUs in run mode') + runtime_ngpus = args.runtime_ngpus + if runtime_ngpus % args.plan_ngpus != 0: + raise ValueError('runtime_ngpus must be a multiple of plan_ngpus') + + compute_config = ComputeConfig( + plan_ngpus=args.plan_ngpus, + runtime_ngpus=runtime_ngpus, + constant_folding=True, + use_zero=True, + use_end2end=True, + pas_config={ + 'mem_constraint': args.gpu_mem_constraint, + 'pipeline_pivots': args.pipeline_pivots, + 'pipeline_nstages': args.pipeline_nstages, + 'recompute_modules': args.recompute_modules, + }, + trace_strategy=args.trace_strategy, + ) + + model_config = ModelConfig( + type=WrapperModel, + args={ + 'model_id': args.model_id, + 'config': args.model_config, + 'enable_chunk_loss': args.enable_chunk_loss, + 'attn_implementation': args.attn_implementation, + }, + ) + + # optimizer hyperparameters are from YaRN + optimizer_config = OptimizerConfig( + type=MixedPrecisionAdamW, + args={'lr': 2e-5, 'betas': (0.9, 0.95), 'weight_decay': 0.0, 'fused': True}, + clip_gnorm=1.0, + loss_reduction='sum', + grad_reduction='per-token-mean', + aggregate_outputs_fn=aggregate_outputs_fn, + ) + + lrscheduler_config = LRSchedulerConfig( + type=WarmupCosineAnnealingLR, + args={ + 'warmup_steps': args.warmup_steps, + 'T_max': args.max_train_steps, + }, + interval='step', + ) + + dataset_config = DatasetConfig( + type=(lambda split: dataset), + train_args={'split': 'train'}, + ) + + dataloader_config = DataloaderConfig( + train_args={ + 'collate_fn': collate_fn, + 'drop_last': True, + }, + ) + + sampler_config = DatasetSamplerConfig( + train_args={ + 'shuffle': True, + }, + ) + + checkpoint_config = CheckpointConfig( + every_n_train_steps=200, + save_type='deduped', + resume_from=(args.resume_path or 'last'), + ) + + log_config = LogConfig( + type=TensorBoardLogger, + args={ + 'name': args.name, + 'root_dir': './runs', + }, + ) + + trainer_args = TrainerArgs( + instance_name=args.name, + run_mode=args.run_mode, + compute_config=compute_config, + pas_policy='autodist', + model=model_config, + optimizer=optimizer_config, + lr_scheduler=lrscheduler_config, + dataset=dataset_config, + dataloader=dataloader_config, + checkpoint=checkpoint_config, + precision='bf16', + max_epochs=None, + grad_accumulation_steps=args.grad_accumulation_steps, + max_train_steps=args.max_train_steps, + log=[log_config], + seed=0, + broadcast_strategy=broadcast_strategy, + dataset_sampler=sampler_config, + ) + + trainer = Trainer(train_args=trainer_args) + trainer.run() + + +if __name__ == '__main__': + ## Parse Args ## + + parser = argparse.ArgumentParser() + parser.add_argument( + '--name', + default='llama', + type=str, + help='name of the experiment', + ) + parser.add_argument( + '--run_mode', + default='run', + choices=['run', 'compile'], + help='run or compile', + ) + parser.add_argument( + '--plan_ngpus', + type=int, + required=True, + help='specify the scale unit size', + ) + parser.add_argument( + '--runtime_ngpus', + type=int, + required=True, + help='specify the number of GPUs to use', + ) + parser.add_argument( + '--resume_path', + default=None, + type=str, + help='path to dir of ckpts or the ckpt file to resume from', + ) + parser.add_argument( + '--dataset_path', + default=None, + type=str, + help='path to the dataset', + ) + parser.add_argument( + '--model_id', + default=None, + type=str, + help='transformers model id', + ) + parser.add_argument( + '--model_config', + default=None, + type=str, + help='transformers model config json path', + ) + parser.add_argument( + '--gpu_mem_constraint', + default=64, + type=int, + help='the max memory usage constraint (GB) per GPU during nnscaler generating distribution plan, recommended to be 80 percent of GPU memory', + ) + parser.add_argument( + '--trace_strategy', + default='reuse_cache', + type=str, + help='trace strategy control the function execution during tracing model graph, `cuda_run_cpu_offload` and `reuse_cache` are recommended, please read `docs/source/parallel_module.md` for more information', + ) + parser.add_argument( + '--enable-chunk-loss', + action='store_true', + help='enable chunk loss that exchanges the speed of training for the memory usage', + ) + parser.add_argument( + '--pipeline_pivots', + default='', + type=str, + help='specify the pipeline pivots for autodist', + ) + parser.add_argument( + '--pipeline_nstages', + default=1, + type=str, + help='specify the number of stages in the pipeline (use "1" to disable pipeline; use "auto" for autodist)', + ) + parser.add_argument( + '--recompute_modules', + default='', + type=str, + help='specify the modules to recompute in autodist', + ) + parser.add_argument( + '--grad_accumulation_steps', + default=4, + type=int, + help='number of gradient accumulation steps', + ) + parser.add_argument( + '--max_train_steps', + default=1000, + type=int, + help='max training steps', + ) + parser.add_argument( + '--warmup_steps', + default=40, + type=int, + help='warmup steps', + ) + parser.add_argument( + '--attn_implementation', + default='flash_attention_2', + type=str, + help='attn implementation, can be flash_attention_2, spda, eager', + ) + + args = parser.parse_args() + if args.pipeline_nstages != 'auto': + args.pipeline_nstages = int(args.pipeline_nstages) + if args.pipeline_nstages > 1 and not args.pipeline_pivots: + raise ValueError('pipeline_pivots must be specified when pipeline is enabled') + + if os.getenv('DETERMINISTIC'): # reduce randomness for integration test + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.use_deterministic_algorithms(True) + + main(args) diff --git a/examples/transformers_utils/__init__.py b/examples/transformers_utils/__init__.py new file mode 100644 index 00000000..97e866da --- /dev/null +++ b/examples/transformers_utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from packaging import version +import transformers + +from .causal_lm_wrapper import WrapperModel, aggregate_outputs_fn +from .tokenizer import get_tokenizer + +if version.parse(transformers.__version__) >= version.parse('4.43.0'): + from .flash_attn_anno import * +else: + # need specified support for each model if transformers version < 4.43.0 + pass diff --git a/examples/transformers_utils/causal_lm_wrapper.py b/examples/transformers_utils/causal_lm_wrapper.py new file mode 100644 index 00000000..48ace681 --- /dev/null +++ b/examples/transformers_utils/causal_lm_wrapper.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from transformers import AutoModelForCausalLM, AutoConfig + +from nnscaler.cli.trainer_args import AggregatedOutputs + +from .chunk_linear_cross_entropy import chunk_linear_cross_entropy + + +IGNORE_IDX = -100 + + +class WrapperModel(torch.nn.Module): + def __init__(self, model_id, config=None, enable_chunk_loss=False, attn_implementation='flash_attention_2'): + super().__init__() + if isinstance(config, str): + config = AutoConfig.from_pretrained(config) + self.model = AutoModelForCausalLM.from_pretrained(model_id, config=config, attn_implementation=attn_implementation) + self.enable_chunk_loss = enable_chunk_loss + + def forward(self, samples): + if self.enable_chunk_loss: + outputs = self.model.model( + **samples['net_input'], + use_cache=False, + return_dict=False, + ) + hidden_states = outputs[0] + losses = chunk_linear_cross_entropy(hidden_states, self.model.lm_head.weight, samples['target'], IGNORE_IDX, 1024) + loss = torch.sum(losses) + else: + outputs = self.model( + **samples['net_input'], + use_cache=False, + return_dict=False, + ) + logits = outputs[0].view(-1, outputs[0].size(-1)) + labels = samples['target'].view(-1) + normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) + loss = torch.nn.functional.nll_loss(normalized_logits, labels, reduction='sum', ignore_index=IGNORE_IDX) + return loss, loss.data, samples['ntokens'], samples['nsentences'] + + +def aggregate_outputs_fn(loss_outputs, sync_group) -> AggregatedOutputs: + losses, ntokens_info = [], [] + for _, loss, ntokens, _ in loss_outputs: + losses.append(loss) + ntokens_info.append(ntokens) + + loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) + torch.distributed.all_reduce(loss_sum, group=sync_group) + ntokens_sum = torch.sum(torch.tensor(ntokens_info, dtype=torch.float64, device=torch.cuda.current_device())) + torch.distributed.all_reduce(ntokens_sum, group=sync_group) + num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) + torch.distributed.all_reduce(num_batches, group=sync_group) + + return AggregatedOutputs( + loss_sum=loss_sum.item() / ntokens_sum.item(), + num_batches=num_batches.item(), + num_tokens=ntokens_sum.item(), + ) diff --git a/examples/transformers_utils/chunk_linear_cross_entropy.py b/examples/transformers_utils/chunk_linear_cross_entropy.py new file mode 100644 index 00000000..5fd3a54f --- /dev/null +++ b/examples/transformers_utils/chunk_linear_cross_entropy.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.utils.checkpoint as ckpt + +from nnscaler.graph.parser.register import register_op + + +def linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int = 0) -> torch.Tensor: + """ + Compute the cross entropy loss of a linear layer. + + Args: + + x: [token_num, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [token_num], the target token index + padding_idx: int, the index of padding token + + Returns: + + losses: [token_num], the cross entropy loss of each token + """ + logits = torch.nn.functional.linear(x, w) + normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) + losses = torch.nn.functional.nll_loss(normalized_logits, y, reduction='none', ignore_index=padding_idx) + return losses + + +def chunk_linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor: + """ + In order to reduce the memory usage when the sequence length and dictionary size are large, we can split the input + tensor into chunks and compute the cross entropy loss of each chunk separately. + You can register this function with annotation 'b l d^, n^ d^, b l -> b l'. + + Args: + + x: [bsz, seq_len, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [bsz, seq_len], the target token index + padding_idx: int, the index of padding token + chunk_size: int, the size of each chunk + + Returns: + + losses: [bsz, seq_len], the cross entropy loss of each token + """ + bsz, seq_len, hidden_size = x.size() + token_num = bsz * seq_len + x = x.view(token_num, hidden_size) + y = y.view(token_num) + + if token_num % chunk_size != 0: + raise ValueError(f"token_num {token_num} is not divisible by chunk_size {chunk_size}") + + chunk_num = token_num // chunk_size + xs = x.view(chunk_num, chunk_size, hidden_size) + ys = y.view(chunk_num, chunk_size) + losses = [] + for i in range(chunk_num): + loss = ckpt.checkpoint(linear_cross_entropy, xs[i], w, ys[i], padding_idx, use_reentrant=False) + losses.append(loss) + losses = torch.stack(losses).view(bsz, seq_len) + return losses + + +register_op('b l d^, n^ d^, b l -> b l')(chunk_linear_cross_entropy) diff --git a/examples/transformers_utils/flash_attn_anno.py b/examples/transformers_utils/flash_attn_anno.py new file mode 100644 index 00000000..0162b90d --- /dev/null +++ b/examples/transformers_utils/flash_attn_anno.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Optional + +from nnscaler.ir import IRTensor +from nnscaler.graph.parser.register import register_op +from transformers.modeling_flash_attention_utils import _flash_attention_forward + +import torch + + +def flash_attention_anno( + query_states: IRTensor, + key_states: IRTensor, + value_states: IRTensor, + attention_mask: Optional[IRTensor], + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[IRTensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: bool = None, + *args, + **kwargs + # added >= 4.47 + # cu_seq_lens_q: Optional[IRTensor] = None, + # cu_seq_lens_k: Optional[IRTensor] = None, + # max_length_q: Optional[int] = None, + # max_length_k: Optional[int] = None, + # target_dtype: Optional[torch.dtype] = None + ) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + input_anno = f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^' + + if isinstance(attention_mask, IRTensor): + # add attention mask annotation + input_anno += ', b l^' + if isinstance(position_ids, IRTensor): + # add position_ids annotation + input_anno += ', ?, ?, ?, b l^' + elif isinstance(position_ids, IRTensor): + # add position_ids annotation + input_anno += ', ?, ?, ?, ?, b l^' + + if 'cu_seq_lens_q' in kwargs: + cu_seq_lens_q = kwargs['cu_seq_lens_q'] + cu_seq_lens_k = kwargs['cu_seq_lens_k'] + assert not isinstance(cu_seq_lens_k, IRTensor) and not isinstance(cu_seq_lens_q, IRTensor), f'cu_seq_lens_k: {cu_seq_lens_k}, cu_seq_lens_q: {cu_seq_lens_q}, not supported' + + return f'{input_anno} -> b l^ {q_anno} vd^' + + +register_op(flash_attention_anno)(_flash_attention_forward) + + +# Copy from transformers/integrations/flash_attention.py +# To solve the issue of transformers/integrations/flash_attention.py using relative import _flash_attention_forward, +# and the anno issue mentioned in the following code snippet. +from typing import Optional, Tuple +import torch +from transformers.utils import is_flash_attn_greater_or_equal_2_10 +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers import modeling_flash_attention_utils + +_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + +def flash_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + # This is before the transpose + seq_len = query.shape[2] + + # FA2 uses non-transposed inputs + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (usually our RMSNorm modules handle it correctly) + target_dtype = None + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(module.config, "_pre_quantization_dtype"): + target_dtype = module.config._pre_quantization_dtype + else: + target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype + + # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice + kwargs.pop("is_causal", None) + + # NNScaler: can not use for example, ```query_length=seq_len```, will case anno error, + # all inputs that have annotation should not use xxx_name=xxx format + attn_output = modeling_flash_attention_utils._flash_attention_forward( + query, + key, + value, + attention_mask, + seq_len, + module.is_causal, + dropout, + kwargs.pop("position_ids", None), + softmax_scale=scaling, + sliding_window=sliding_window, + softcap=softcap, + use_top_left_mask=_use_top_left_mask, + target_dtype=target_dtype, + **kwargs, + ) + + return attn_output, None + + +ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward diff --git a/examples/transformers_utils/tokenizer.py b/examples/transformers_utils/tokenizer.py new file mode 100644 index 00000000..1a0636ec --- /dev/null +++ b/examples/transformers_utils/tokenizer.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from transformers import AutoTokenizer + + +IGNORE_IDX = -100 + + +def get_tokenizer(tokenizer_name_or_path, + model_max_length=None,): + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True) + special_tokens_dict = dict() + assert tokenizer.bos_token is not None and tokenizer.eos_token is not None + if tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = tokenizer.unk_token if tokenizer.unk_token else tokenizer.eos_token + if tokenizer.unk_token is None: + special_tokens_dict["unk_token"] = tokenizer.pad_token if tokenizer.pad_token else tokenizer.eos_token + + tokenizer.add_special_tokens(special_tokens_dict) + if model_max_length: + tokenizer.model_max_length = model_max_length + return tokenizer diff --git a/examples/warmup_schedular.py b/examples/warmup_schedular.py new file mode 100644 index 00000000..b066c385 --- /dev/null +++ b/examples/warmup_schedular.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import math +from torch.optim.lr_scheduler import LRScheduler, Optimizer, _warn_get_lr_called_within_step + + +class WarmupCosineAnnealingLR(LRScheduler): + r""" + torch.optim.lr_scheduler.CosineAnnealingLR with warmup. + + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_steps (int): Number of warmup steps. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + + def __init__( + self, + optimizer: Optimizer, + warmup_steps: int, + T_max: int, + eta_min=0.0, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.warmup_steps = warmup_steps + self.T_max = T_max - warmup_steps + 1 + self.eta_min = eta_min + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Retrieve the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + last_epoch_wo_warmup = self.last_epoch - self.warmup_steps + 1 + if last_epoch_wo_warmup < 0: + return [base_lr * (self.last_epoch + 1) / self.warmup_steps for base_lr in self.base_lrs] + elif last_epoch_wo_warmup == 0: + return [base_lr for base_lr in self.base_lrs] + elif self._step_count == 1 and last_epoch_wo_warmup > 0: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos((last_epoch_wo_warmup) * math.pi / self.T_max)) + / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif (last_epoch_wo_warmup - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * last_epoch_wo_warmup / self.T_max)) + / (1 + math.cos(math.pi * (last_epoch_wo_warmup - 1) / self.T_max)) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + last_epoch_wo_warmup = self.last_epoch - self.warmup_steps + 1 + if last_epoch_wo_warmup < 0: + return [base_lr * (self.last_epoch + 1) / self.warmup_steps for base_lr in self.base_lrs] + else: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * last_epoch_wo_warmup / self.T_max)) + / 2 + for base_lr in self.base_lrs + ] diff --git a/nnscaler/graph/tracer/wrap_utils.py b/nnscaler/graph/tracer/wrap_utils.py index 28ff2e37..318dfb6b 100644 --- a/nnscaler/graph/tracer/wrap_utils.py +++ b/nnscaler/graph/tracer/wrap_utils.py @@ -553,6 +553,8 @@ class torch_autocast_wrapper_clz: _fx_wrapped_ori_clz = orig_func.torch_autocast def __new__(cls, *args, **kwargs): + args = (arg.value if orig_func.isinstance(arg, cct.ConcreteProxy) else arg for arg in args) + kwargs = {k: v.value if orig_func.isinstance(v, cct.ConcreteProxy) else v for k, v in kwargs.items()} return orig_func.torch_autocast(*args, **kwargs) def __eq__(self, __o: object) -> bool: From ffa8831030ec6d9658fd59576b308ab11c87ec04 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 24 Mar 2025 08:41:30 +0000 Subject: [PATCH 1809/1892] Merged PR 2370: [BugFix] Fix autocast in AutoDist --- nnscaler/algorithm/ops/dimops.py | 3 + nnscaler/graph/graph.py | 2 - nnscaler/profiler/database.py | 14 ++++- tests/autodist/spmd_solver/test_autocast.py | 61 +++++++++++++++++++++ 4 files changed, 76 insertions(+), 4 deletions(-) create mode 100644 tests/autodist/spmd_solver/test_autocast.py diff --git a/nnscaler/algorithm/ops/dimops.py b/nnscaler/algorithm/ops/dimops.py index 1ba38c79..246b749b 100644 --- a/nnscaler/algorithm/ops/dimops.py +++ b/nnscaler/algorithm/ops/dimops.py @@ -10,6 +10,7 @@ from nnscaler.ir.tensor import IRSubTensor from nnscaler.ir.cten import IRTensor from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph import IRGraph from collections import deque _logger = logging.getLogger(__name__) @@ -173,6 +174,8 @@ def transform(tensor: Any, split: DimopSplit) -> List[Any]: sub_node.verify_shape() sub_nodes.append(sub_node) + for sub_node in sub_nodes: + IRGraph.copy_node_meta_info(node, sub_node) return sub_nodes def infer(self, idx: int, dim: Union[int, str], num: int) -> Optional[TransformRule]: diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 39f52dc7..550b21f1 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -396,8 +396,6 @@ def partition(self, node: Union[IRFwOperation, IRDataOperation], # insert forward node fsegment: IRSegment = self.segment(node) - for fnode in fnodes: - self.copy_node_meta_info(node, fnode) fsegment.replace(node, fnodes) if node.mirror is None: return fnodes diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 30f76f10..2d162e50 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -15,6 +15,7 @@ import logging from dataclasses import dataclass, asdict from pathlib import Path +import contextlib import _operator # required by eval() import nnscaler # required by eval() @@ -22,6 +23,7 @@ from nnscaler.ir.cten import IRTensor, IRObject from nnscaler.ir.operator import IRFwOperation from nnscaler.graph.parser.register import CustomizedOps +from nnscaler.graph.tracer.metadata import AutocastInfo _logger = logging.getLogger(__name__) @@ -176,8 +178,15 @@ def gen_torch_tensors(shape, dtype, requires_grad): train_kwargs[name] = train_val eval_kwargs[name] = eval_val + assert node.op_context is not None, f"node {node}: op_context is None" + autocast_info = AutocastInfo(**node.op_context["autocast_info"]) + if autocast_info.nesting > 0: + ctx = torch.autocast(device_type='cuda', dtype=autocast_info.cuda_dtype, enabled=autocast_info.cuda_enabled, cache_enabled=autocast_info.cache_enabled) + else: + ctx = contextlib.nullcontext() # run one sample - outputs = func(*tensors, **train_kwargs) + with ctx: + outputs = func(*tensors, **train_kwargs) # check whether func is a in-place operation for t1, t2 in zip(in_tensors, tensors): @@ -195,7 +204,8 @@ def gen_torch_tensors(shape, dtype, requires_grad): grads = tuple(torch.zeros_like(otensor) for otensor in outputs) def run_step(func, tensors, kwargs, backward: bool): - outputs = func(*tensors, **kwargs) + with ctx: + outputs = func(*tensors, **kwargs) outputs = (outputs,) if torch.is_tensor(outputs) else outputs outputs = tuple(filter(lambda x: torch.is_tensor(x) and x.requires_grad, outputs)) if backward: diff --git a/tests/autodist/spmd_solver/test_autocast.py b/tests/autodist/spmd_solver/test_autocast.py new file mode 100644 index 00000000..c191e0ea --- /dev/null +++ b/tests/autodist/spmd_solver/test_autocast.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest + +import tempfile +import torch +import torch.nn.functional as F +import os +from pathlib import Path +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.model_graph import ModelGraph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.autodist.spmd_solver import SPMDSolver + + +class Model(torch.nn.Module): + def __init__(self, hidden_dim): + super().__init__() + self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + x = F.relu(self.fc1(x)) + x = self.fc2(x) + x = x.sum() + return x + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_autocast(): + bsz, seq_len, hidden_dim = 2, 16, 16 + + dummy_input = {'x': torch.randn(bsz, seq_len, hidden_dim)} + model = Model(hidden_dim) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + constant_folding=True) + + cfg = AutoDistConfig(mesh_col=2, re_profile=True, parallel_profile=False) + model_graph = ModelGraph(ir_graph, cfg) + + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=cfg.mesh_desc, + autodist_config=cfg, + stage_num=1, + micro_batch_num=cfg.update_freq, + ) + + partition_counts = [ + spmd_solver.get_op_partition_count(i) + for i in range(model_graph.op_num) + ] + assert partition_counts == [4, 4, 4, 4] From 4390c16b53f06ab8e2b9ce4f8ba21fd7d344d3ea Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 26 Mar 2025 02:46:40 +0000 Subject: [PATCH 1810/1892] Merged PR 2371: [BugFix] Fix AutoDist's implementation - refine re_profile's behavior - fix recompute estimation when handling border ops - refine info and add test when SPMD's follow fails --- nnscaler/autodist/cost_database.py | 5 +- nnscaler/autodist/model_graph.py | 3 + nnscaler/autodist/spmd_solver.py | 6 +- .../spmd_solver/test_trigger_follow_error.py | 71 ++++++++++++++ ...t_trigger_follow_error.mock_attention.json | 44 +++++++++ .../comp/torch.Tensor.reshape.json | 50 ++++++++++ .../comp/torch.Tensor.view.json | 38 ++++++++ .../comp/torch.nn.functional.linear.json | 92 +++++++++++++++++++ .../comp/torch.sum.json | 62 +++++++++++++ 9 files changed, 368 insertions(+), 3 deletions(-) create mode 100644 tests/autodist/spmd_solver/test_trigger_follow_error.py create mode 100644 tests/autodist/spmd_solver/test_trigger_follow_error/comp/tests.autodist.spmd_solver.test_trigger_follow_error.mock_attention.json create mode 100644 tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.Tensor.reshape.json create mode 100644 tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.Tensor.view.json create mode 100644 tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.nn.functional.linear.json create mode 100644 tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.sum.json diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 82817af3..2e003359 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -164,8 +164,9 @@ def profile_comp(self, partition_degree: int, parallel_profile: bool, re_profile def insert_profile_info(info: List[Tuple[str, str, ProfiledMetrics]]): for sign, serialized, profiled_metrics in info: _logger.debug(f'profiled {sign} in {serialized} with {profiled_metrics}') - if not self.db.exist_serialized(sign, serialized): - self.db.insert(sign, serialized, profiled_metrics) + # Align with `re_profile`'s semantic, as long as we get the profile info, + # we will override the old one + self.db.insert(sign, serialized, profiled_metrics) if parallel_profile: _logger.info('Profiling in parallel') diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 7ab6f6c2..6f6da5d6 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -803,6 +803,9 @@ def label_ops(self, operator_list: List[CubeOperator]): if isinstance(t, IRTensor): output_tensors.add(t) for node in group: + # Since we only profile IRDimops, skip if the node is not + if not isinstance(node, IRDimops): + continue is_border = False for t in node.inputs(): if isinstance(t, IRTensor) and not t.is_attr(): diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index da2e755a..cca7f0be 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -579,7 +579,11 @@ def calc_father4op_partition(): if not cur_p_fathers[j]: cur_p_fathers[j] = -1 else: - assert len(cur_p_fathers[j]) == 1, f'unexpected partition {self.get_operator(i).ir_cell}, {cur_p_fathers[j]}' + error_msg = f'find multiple p_fathers {cur_p_fathers[j]} for {self.get_operator(i)}' \ + f', this may due to its producer has multiple partitions but hard to distinguish' \ + f' at current operator\'s perspective, you can try to\n 1. add partition constraint' \ + f' to the producer or the consumer \n 2. double check annotations of related ops' + assert len(cur_p_fathers[j]) == 1, error_msg cur_p_fathers[j] = cur_p_fathers[j][0] p_fathers.append(cur_p_fathers) # -1 will be filtered out in the intersection operation below diff --git a/tests/autodist/spmd_solver/test_trigger_follow_error.py b/tests/autodist/spmd_solver/test_trigger_follow_error.py new file mode 100644 index 00000000..ce7774fe --- /dev/null +++ b/tests/autodist/spmd_solver/test_trigger_follow_error.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest +import os +from pathlib import Path + +import tempfile +import torch +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph +from nnscaler.autodist.model_graph import ModelGraph +from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.autodist.spmd_solver import SPMDSolver +import nnscaler +from tests.utils import raises_with_cause + + +# this is a wrong annotation +@nnscaler.register_op('l^ hq dim^, l^ hkv dim^, l^ hkv dim^ -> l^ hq dim^') +def mock_attention(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): + return x + y + z + + +class Model(torch.nn.Module): + + def __init__(self, hidden_dim): + super().__init__() + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + + def forward(self, x): + bsz, seq_len, hidden_dim = x.shape + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + q = q.view(bsz * seq_len, 2, hidden_dim // 2) + k = k.view(bsz * seq_len, 2, hidden_dim // 2) + v = v.view(bsz * seq_len, 2, hidden_dim // 2) + x = mock_attention(q, k, v) + x = x.reshape(bsz, seq_len, 2, hidden_dim // 2) + return x.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_trigger_follow_error(): + bsz, seq_len, hidden_dim = 2, 16, 16 + + dummy_input = {'x': torch.randn(bsz, seq_len, hidden_dim)} + model = Model(hidden_dim) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + profile_dir = Path(os.path.dirname(__file__)) / './test_trigger_follow_error' + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + constant_folding=True) + + cfg = AutoDistConfig(mesh_col=2, parallel_profile=False, profile_dir=profile_dir) + model_graph = ModelGraph(ir_graph, cfg) + + with raises_with_cause(AssertionError, match='find multiple p_fathers'): + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=cfg.mesh_desc, + autodist_config=cfg, + stage_num=1, + micro_batch_num=cfg.update_freq, + ) diff --git a/tests/autodist/spmd_solver/test_trigger_follow_error/comp/tests.autodist.spmd_solver.test_trigger_follow_error.mock_attention.json b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/tests.autodist.spmd_solver.test_trigger_follow_error.mock_attention.json new file mode 100644 index 00000000..608d096f --- /dev/null +++ b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/tests.autodist.spmd_solver.test_trigger_follow_error.mock_attention.json @@ -0,0 +1,44 @@ +{ + "(32, 2, 8)-(32, 2, 8)-(32, 2, 8)-(32, 2, 8) : torch.float32-torch.float32-torch.float32-torch.float32 : True-True-True-True": { + "in_mem_info": [ + 2048, + 2048, + 2048 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.029117846861481667, + "bw_span": 0.14072032645344734, + "infer_memory": 10240, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(32, 1, 8)-(32, 2, 8)-(32, 2, 8)-(32, 1, 8) : torch.float32-torch.float32-torch.float32-torch.float32 : True-True-True-True": { + "in_mem_info": [ + 1024, + 2048, + 2048 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03680214285850525, + "bw_span": 0.17663338221609592, + "infer_memory": 9216, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(32, 2, 8)-(32, 1, 8)-(32, 1, 8)-(32, 2, 8) : torch.float32-torch.float32-torch.float32-torch.float32 : True-True-True-True": { + "in_mem_info": [ + 2048, + 1024, + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03265049308538437, + "bw_span": 0.18002847209572792, + "infer_memory": 8192, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.Tensor.reshape.json b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.Tensor.reshape.json new file mode 100644 index 00000000..51de7fdf --- /dev/null +++ b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.Tensor.reshape.json @@ -0,0 +1,50 @@ +{ + "(32, 2, 8)-(2, 16, 2, 8) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 2048 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010654376819729805, + "bw_span": 0.07317820563912392, + "infer_memory": 2048, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(16, 2, 8)-(1, 16, 2, 8) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010553840547800064, + "bw_span": 0.06544496864080429, + "infer_memory": 1024, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(32, 1, 8)-(2, 16, 1, 8) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010478775948286057, + "bw_span": 0.06197541952133179, + "infer_memory": 1024, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(32, 2, 4)-(2, 16, 2, 4) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010212836787104607, + "bw_span": 0.06344295106828213, + "infer_memory": 1024, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.Tensor.view.json b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.Tensor.view.json new file mode 100644 index 00000000..5a7db450 --- /dev/null +++ b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.Tensor.view.json @@ -0,0 +1,38 @@ +{ + "(2, 16, 16)-(32, 2, 8) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 2048 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010434770956635475, + "bw_span": 0.07250779308378696, + "infer_memory": 2048, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 16, 16)-(16, 2, 8) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010340288281440735, + "bw_span": 0.07582530379295349, + "infer_memory": 1024, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 16, 8)-(32, 1, 8) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010559335350990295, + "bw_span": 0.07563922554254532, + "infer_memory": 1024, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.nn.functional.linear.json b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.nn.functional.linear.json new file mode 100644 index 00000000..b0771db3 --- /dev/null +++ b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.nn.functional.linear.json @@ -0,0 +1,92 @@ +{ + "(2, 16, 16)-(16, 16)-(2, 16, 16) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 2048 + ], + "param_mem_info": [ + 1024 + ], + "buffer_mem_info": [], + "fw_span": 0.043286802247166634, + "bw_span": 0.1421193592250347, + "infer_memory": 5120, + "train_mem_info": [ + 2048 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 16, 16)-(16, 16)-(1, 16, 16) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [ + 1024 + ], + "buffer_mem_info": [], + "fw_span": 0.04025171510875225, + "bw_span": 0.18368610180914402, + "infer_memory": 3072, + "train_mem_info": [ + 1024 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 8, 16)-(16, 16)-(2, 8, 16) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [ + 1024 + ], + "buffer_mem_info": [], + "fw_span": 0.03982819616794586, + "bw_span": 0.18744319677352905, + "infer_memory": 3072, + "train_mem_info": [ + 1024 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 16, 8)-(16, 8)-(2, 16, 16) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [ + 512 + ], + "buffer_mem_info": [], + "fw_span": 0.04101274535059929, + "bw_span": 0.18919017165899277, + "infer_memory": 3584, + "train_mem_info": [ + 1024 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 16, 16)-(8, 16)-(2, 16, 8) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 2048 + ], + "param_mem_info": [ + 512 + ], + "buffer_mem_info": [], + "fw_span": 0.043114274740219116, + "bw_span": 0.18649823032319546, + "infer_memory": 3584, + "train_mem_info": [ + 2048 + ], + "train_mem2in_idx": [ + 0 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.sum.json b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.sum.json new file mode 100644 index 00000000..df01f322 --- /dev/null +++ b/tests/autodist/spmd_solver/test_trigger_follow_error/comp/torch.sum.json @@ -0,0 +1,62 @@ +{ + "(2, 16, 2, 8)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 2048 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.014059478417038918, + "bw_span": 0.06327787414193153, + "infer_memory": 2560, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 16, 2, 8)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013872934505343437, + "bw_span": 0.052794069051742554, + "infer_memory": 1536, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 2, 8)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.023470260202884674, + "bw_span": 0.05659819580614567, + "infer_memory": 1536, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 16, 1, 8)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.023735733702778816, + "bw_span": 0.09062276221811771, + "infer_memory": 1536, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 16, 2, 4)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013737892732024193, + "bw_span": 0.05688727833330631, + "infer_memory": 1536, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file From eded7910790b316727232d71b14391eda17a5f43 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 26 Mar 2025 03:36:25 +0000 Subject: [PATCH 1811/1892] Merged PR 2372: [Parser] Support new functions & Refine apex register --- nnscaler/graph/function/function.py | 29 ++++++++++++++++++++++++++ nnscaler/graph/parser/external/apex.py | 9 ++++++++ nnscaler/graph/parser/mapping.py | 2 ++ tests/graph/function/test_functions.py | 14 +++++++++++++ 4 files changed, 54 insertions(+) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index ba75d68b..505dc61f 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -869,6 +869,35 @@ def Clamp(input, min=None, max=None, *, out=None, signature = None): return IRDimops(Clamp, 'clamp', signature, annos, [input], min=min, max=max) +def ViewAsComplex(input, signature = None): + """ + torch.view_as_complex(input) + """ + assert input.shape[-1] == 2 + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_in[-1] = '2' + if len(edim_in) == 1: + edim_ou = ['1'] + else: + edim_ou = copy.copy(edim_in[:-1]) + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(ViewAsComplex, 'view_as_complex', signature, [anno], [input]) + + +def ViewAsReal(input, signature = None): + """ + torch.view_as_real(input) + """ + if input.is_scalar_tensor(): + edim_in, edim_ou = ['1'], ['2'] + else: + edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = copy.copy(edim_in) + edim_ou.append('2') + anno = OpAnno.create_op_str([edim_in], [edim_ou]) + return IRDimops(ViewAsReal, 'view_as_real', signature, [anno], [input]) + + def ClampMin(input, min, *, out=None, signature = None): return Clamp(input, min=min, out=out, signature='torch.clamp') diff --git a/nnscaler/graph/parser/external/apex.py b/nnscaler/graph/parser/external/apex.py index c8274fe8..94b22209 100644 --- a/nnscaler/graph/parser/external/apex.py +++ b/nnscaler/graph/parser/external/apex.py @@ -75,5 +75,14 @@ def apex_fused_rms_norm_affine_anno(input, weight, normalized_shape, eps, *args, parser.register(apex_fused_rms_norm_anno)(FusedRMSNormFunction.apply) parser.register(apex_fused_rms_norm_affine_anno)(FusedRMSNormAffineFunction.apply) + + # wrap at a higher level since `Function.apply` may not be called in newer versions of apex + from apex.normalization.fused_layer_norm import fused_layer_norm, fused_layer_norm_affine, fused_rms_norm, fused_rms_norm_affine + + parser.register(apex_fused_layer_norm_anno)(fused_layer_norm) + parser.register(apex_fused_layer_norm_affine_anno)(fused_layer_norm_affine) + parser.register(apex_fused_rms_norm_anno)(fused_rms_norm) + parser.register(apex_fused_rms_norm_affine_anno)(fused_rms_norm_affine) + except: _logger.warning('skip apex ops as it is not installed.') diff --git a/nnscaler/graph/parser/mapping.py b/nnscaler/graph/parser/mapping.py index 03724a28..b6865401 100644 --- a/nnscaler/graph/parser/mapping.py +++ b/nnscaler/graph/parser/mapping.py @@ -140,6 +140,8 @@ def exist(signature: str) -> bool: __ftemplate('layer_norm'): function.LayerNorm, __ftemplate('scaled_dot_product_attention'): function.ScaledDotProductAttention, __fcntemplate('scaled_dot_product_attention'): function.ScaledDotProductAttention, + __ttemplate('view_as_complex'): function.ViewAsComplex, + __ttemplate('view_as_real'): function.ViewAsReal, # ============== runtime function ================= __tttemplate('size'): function.Size, diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 1b6277e2..5106768d 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -965,6 +965,20 @@ def test_Split(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '7 b c d -> 2 b c d, 2 b c d, 3 b c d' +def test_ViewAsComplex(): + op = F.ViewAsComplex(IRTensor([2, 3, 2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b 2 -> a b' + op = F.ViewAsComplex(IRTensor([2])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '2 -> 1' + + +def test_ViewAsReal(): + op = F.ViewAsReal(IRTensor([2, 3])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b -> a b 2' + op = F.ViewAsReal(IRTensor(shape=None)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '1 -> 2' + + def factors(n): return set(reduce( list.__add__, From a1368ee799d0a87e587fcee0b4872daa30967e4d Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 26 Mar 2025 06:38:26 +0000 Subject: [PATCH 1812/1892] Merged PR 2373: [BugFix] torch load fix for pytorch 2.6 pytorch 2.6 made a break change for torch.load. Default value of param `weights_only` is changed from `False` to `True`. We need it to be False, as we are saving partition-related information in checkpoint. unit test pass (for both torch==2.6 and torch==2.0.1) --- nnscaler/cli/trainer.py | 4 ++-- nnscaler/integration/lightning/pytorch/strategy.py | 6 +++--- nnscaler/parallel.py | 4 ++-- nnscaler/runtime/module.py | 8 ++++---- requirements.txt | 2 +- tests/cli/test_trainer.py | 6 +++--- tests/parallel_module/test_checkpoint.py | 6 +++--- tests/parallel_module/test_checkpoint_dedup.py | 2 +- tests/parallel_module/test_checkpoint_shared.py | 2 +- tests/runtime/test_f16_optimizer.py | 4 ++-- tests/runtime/test_module_merge.py | 6 +++--- 11 files changed, 25 insertions(+), 25 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 0834d906..c9804eea 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -201,7 +201,7 @@ def reducer_pre_hook(reducer, grad): @classmethod def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): - state_dicts = [torch.load(f, map_location='cpu') for f in checkpoint_files] + state_dicts = [torch.load(f, map_location='cpu', weights_only=False) for f in checkpoint_files] for i in range(1, len(state_dicts)): if state_dicts[i]['train_args'] != state_dicts[0]['train_args']: raise ValueError(f"train_args in {checkpoint_files[i]} is different from {checkpoint_files[0]}") @@ -244,7 +244,7 @@ def _load_checkpoint(self): resume_from = resume_from # when we load from merged checkpoint else: resume_from = resume_from / f'{self.rank}.ckpt' - state_dict = torch.load(resume_from, map_location='cpu') + state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) self.hook.on_load_checkpoint(self, state_dict) ckpt_save_type = state_dict['train_args']['checkpoint']['save_type'] diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py index 54461d39..d6619ced 100644 --- a/nnscaler/integration/lightning/pytorch/strategy.py +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -514,9 +514,9 @@ def load_checkpoint( raise FileNotFoundError(f"Checkpoint file {path} not found.") if path.is_dir(): - state_dict: dict = torch.load(path / f'{self.global_rank}.pt') + state_dict: dict = torch.load(path / f'{self.global_rank}.pt', weights_only=False) else: - state_dict: dict = torch.load(path) + state_dict: dict = torch.load(path, weights_only=False) nnscaler_extra_state = state_dict.pop(self._nnscaler_extra_state_key) # load the extra states of the pl module self._lightning_module.load_state_dict(nnscaler_extra_state[self._pl_module_name_key], strict=False) @@ -551,7 +551,7 @@ def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str) -> None checkpoint_files: The list of checkpoint files to merge. output_file: The output file path. """ - state_dicts = [torch.load(f, map_location='cpu') for f in checkpoint_files] + state_dicts = [torch.load(f, map_location='cpu', weights_only=False) for f in checkpoint_files] module_state_dicts = [s[cls._module_name_key] for s in state_dicts] opt_state_dicts = None diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index d47274e2..37fb55d2 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -277,7 +277,7 @@ def safe_load_from_file(cls, file: Union[str, Path], return_none_on_error=True) """ if Path(file).exists(): try: - cfg = torch.load(file) + cfg = torch.load(file, weights_only=False) if isinstance(cfg, dict): # in old version, we save the object directly (not save as dict) # this can raise if cfg has extra keys. # which means some fields of ComputeConfig has been removed(we should avoid this). @@ -790,7 +790,7 @@ def _gencode( ret = RegenStatus.CODE logger.info(f"Reuse graph dump in {outdir}") graph = IRGraph.load(graph_ckp) - forward_args = torch.load(forward_args_ckp) + forward_args = torch.load(forward_args_ckp, weights_only=False) graph = pas_policy(graph, compute_config) if not isinstance(graph, IRGraph): diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 6dc7cf97..9213dcdd 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -334,7 +334,7 @@ def get_checkpoint(self, optimizer: torch.optim.Optimizer = None): if not dist_param_map: module_file = Path(sys.modules[self.__module__].__file__) # load from the same directory as the module file - dist_param_map = torch.load(module_file.with_name(FxModuleParser.ATTR_MAP_FILE)) + dist_param_map = torch.load(module_file.with_name(FxModuleParser.ATTR_MAP_FILE), weights_only=False) param_area_map = self._fullmap optimizer_state_dict = optimizer.state_dict() if optimizer is not None else None return state_dict, dist_param_map, param_area_map, optimizer_state_dict @@ -700,7 +700,7 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): ckpts = {} for rank in range(DeviceGroup().world_size): filename = f"{filename_prefix}-{rank}.ckpt" - ckpts[rank] = torch.load(filename) + ckpts[rank] = torch.load(filename, weights_only=False) _logger.info(f'checkpoints = {ckpts}') state_dicts = [] @@ -835,12 +835,12 @@ def _post_init(self, init_params=True): self._warn_uninitialized_non_persistent_buffers() - self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}")) + self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}"), weights_only=False) self._compute_config: ComputeConfig = ComputeConfig.safe_load_from_file( module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}"), return_none_on_error=False ) - self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}")) + self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}"), weights_only=False) for reducer in self.reducers: reducer.build_buckets() diff --git a/requirements.txt b/requirements.txt index 2fec92d6..8ab5bea9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,5 @@ psutil pulp pybind11 pyyaml -torch>=2.0,<2.4 +torch>=2.0,<=2.6 tqdm diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 32f66764..59b77617 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -293,9 +293,9 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): if torch.distributed.get_rank() == 0: assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} for i in range(4): - x = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt') - y = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt') - z = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt') + x = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + z = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) assert_equal(x['model'], y['model']) assert_equal(x['optimizer'], y['optimizer']) assert_equal(x['lr_scheduler'], y['lr_scheduler']) diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 627f76bb..6faee640 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -308,7 +308,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf loss_fn = nn.BCELoss() optimizer = build_optimizer(model, torch.optim.Adam, lr=0.01) if ckpt_start_file.exists(): - ckpt_dict = torch.load(ckpt_start_file) + ckpt_dict = torch.load(ckpt_start_file, weights_only=False) model_state_dict = ckpt_dict['model'] for name, m in model.named_modules(): prefix = f'{name}.' if name else '' @@ -334,7 +334,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf torch.save(inference_module.state_dict(), temp_inferenece_ckpt_file) torch.distributed.barrier() inference_ckpt_files = [ckpt_dir / temp_inferenece_ckpt_file_template.format(rank=i) for i in range(torch.distributed.get_world_size())] - inference_state_dicts = [torch.load(f) for f in inference_ckpt_files] + inference_state_dicts = [torch.load(f, weights_only=False) for f in inference_ckpt_files] merged_inference_state_dict, _ = merge_state_dicts(inference_state_dicts) assert_model_state_dict_equal(merged_model_state_dict, merged_inference_state_dict) @@ -399,7 +399,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf torch.distributed.barrier() if torch.distributed.get_rank() == 0: ckpt_files = [ckpt_dir / ckpt_file_template.format(rank=i, start=end) for i in range(torch.distributed.get_world_size())] - ckpt_state_dicts = [torch.load(f) for f in ckpt_files] + ckpt_state_dicts = [torch.load(f, weights_only=False) for f in ckpt_files] model_state_dicts = [ckpt['model'] for ckpt in ckpt_state_dicts] optimizer_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_state_dicts] if check_merge_log: diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index d0f5ebae..339ea888 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -108,7 +108,7 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): ckpt_dir / CKPT_FILE_NAME_TEMPLATE.format(i) for i in range(torch.distributed.get_world_size()) ] - ckpt_state_dicts = [torch.load(f) for f in ckpt_files] + ckpt_state_dicts = [torch.load(f, weights_only=False) for f in ckpt_files] model_state_dicts = [ckpt['model'] for ckpt in ckpt_state_dicts] optimizer_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_state_dicts] dedupped_model_state_dicts = [ckpt['model-dedup'] for ckpt in ckpt_state_dicts] diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index 8d71a2f8..8b5e91b1 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -160,7 +160,7 @@ def _load_merged(parallel_model: torch.nn.Module, ckpt_dir): torch.distributed.barrier() if torch.distributed.get_rank() == 0: ckpt_files = [ckpt_dir / ckpt_file_template.format(rank=i) for i in range(torch.distributed.get_world_size())] - ckpt_state_dicts = [torch.load(f) for f in ckpt_files] + ckpt_state_dicts = [torch.load(f, weights_only=False) for f in ckpt_files] model_state_dicts = [ckpt['model'] for ckpt in ckpt_state_dicts] optimizer_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_state_dicts] merged_model_state_dicts, merged_optimizer_state_dict = merge_state_dicts(model_state_dicts, optimizer_state_dicts) diff --git a/tests/runtime/test_f16_optimizer.py b/tests/runtime/test_f16_optimizer.py index 3f843496..2a0d0630 100644 --- a/tests/runtime/test_f16_optimizer.py +++ b/tests/runtime/test_f16_optimizer.py @@ -42,8 +42,8 @@ def trainer_worker(save_dir): if torch.distributed.get_rank() == 0: for i in range(2): - x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt') - y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt') + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) # actually they are not close # assert_close(x['model'], y['model']) # assert_close(x['optimizer'], y['optimizer']) diff --git a/tests/runtime/test_module_merge.py b/tests/runtime/test_module_merge.py index 9b4e8984..9bab4e64 100644 --- a/tests/runtime/test_module_merge.py +++ b/tests/runtime/test_module_merge.py @@ -94,7 +94,7 @@ def train_iter(model, sample): model_states = [] fullmaps = [] for i in range(DeviceGroup().world_size): - checkpoint = torch.load(f'checkpoint-shard{i}.pt') + checkpoint = torch.load(f'checkpoint-shard{i}.pt', weights_only=False) model_states.append(checkpoint['state_dict']) fullmaps.append(checkpoint['fullmap']) merged_state_dict = cube_model.merge_model_state_dicts(model_states, fullmaps) @@ -138,7 +138,7 @@ def train_iter(model, sample): if DeviceGroup().rank == 0: states = [] for i in range(DeviceGroup().world_size): - checkpoint = torch.load(f'checkpoint-shard{i}.pt') + checkpoint = torch.load(f'checkpoint-shard{i}.pt', weights_only=False) states.append((checkpoint['model'], checkpoint['optimizer'], checkpoint['fullmap'])) merged_model_states, merged_optim_states = cube_model.merge_partial_states(states) assert_same_state(full_model_state, merged_model_states) @@ -177,7 +177,7 @@ def train_iter(model, sample): if DeviceGroup().rank == 0: states = [] for i in range(DeviceGroup().world_size): - checkpoint = torch.load(f'checkpoint-shard{i}.pt') + checkpoint = torch.load(f'checkpoint-shard{i}.pt', weights_only=False) states.append((checkpoint['model'], checkpoint['optimizer'], checkpoint['fullmap'])) merged_model_states, merged_optim_states = cube_model.merge_partial_states(states) assert_same_state(full_model_state, merged_model_states) From e96c0f4e8eff5aace68f357c34901cb99061a4f0 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 9 Jun 2025 06:01:25 +0000 Subject: [PATCH 1813/1892] Merged PR 2378: add iter dataset support and stateful dataloader add iter dataset support --- nnscaler/cli/arg_parser.py | 155 ++++++++++++++++- nnscaler/cli/trainer.py | 54 ++++-- nnscaler/cli/trainer_args.py | 32 ++-- requirements-dev.txt | 1 + tests/cli/common.py | 27 +++ .../simple_dataset_train/index.json | 1 + .../simple_dataset_train/shard.00000.mds | Bin 0 -> 14810 bytes .../simple_dataset_train/shard.00000.mds.zstd | Bin 0 -> 12627 bytes .../simple_dataset_val/index.json | 1 + .../simple_dataset_val/shard.00000.mds | Bin 0 -> 1670 bytes .../simple_dataset_val/shard.00000.mds.zstd | Bin 0 -> 1495 bytes tests/cli/test_arg_parser.py | 159 +++++++++++++++++- tests/cli/test_trainer.py | 144 ++++++++++++++++ tests/cli/trainer_args_streaming.yaml | 50 ++++++ 14 files changed, 591 insertions(+), 33 deletions(-) create mode 100644 tests/cli/streaming_data/simple_dataset_train/index.json create mode 100644 tests/cli/streaming_data/simple_dataset_train/shard.00000.mds create mode 100644 tests/cli/streaming_data/simple_dataset_train/shard.00000.mds.zstd create mode 100644 tests/cli/streaming_data/simple_dataset_val/index.json create mode 100644 tests/cli/streaming_data/simple_dataset_val/shard.00000.mds create mode 100644 tests/cli/streaming_data/simple_dataset_val/shard.00000.mds.zstd create mode 100644 tests/cli/trainer_args_streaming.yaml diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index b9f94910..e883de4e 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import List, Tuple, Dict, Any, Union +import os +import copy + +from typing import List, Optional, Tuple, Dict, Any, Union from dataclasses import dataclass, is_dataclass, asdict import enum import ast @@ -22,7 +25,7 @@ def parse_args(argv: List[str]) -> dict: raw_args = {} last_key = None for v in argv: - if v.startswith('--'): + if isinstance(v, str) and v.startswith('--'): if '=' in v: k, v = v[2:].split('=', 1) raw_args[k] = v @@ -51,12 +54,146 @@ def parse_args(argv: List[str]) -> dict: return args -def merge_args(args: dict, new_args: dict): +def merge_args(args: dict, argv: List[str]): + """ + Please note that this function will modify the args in place. + """ + _merge_args(args, parse_args(argv)) + + +def _merge_args(args: dict, new_args: dict): + MISSING = object() + + def _is_removed_key(k): + return isinstance(k, str) and k.endswith('!') + + def _clear_keys(data): + # values in new_args can only be dict or str + if isinstance(data, dict): + new_data = {} + for k, v in data.items(): + if _is_removed_key(k): + continue + v = _clear_keys(v) + if v is not MISSING: + new_data[k] = v + return new_data if new_data else MISSING + else: + return data + for k, v in new_args.items(): - if k in args and isinstance(args[k], dict) and isinstance(v, dict): - merge_args(args[k], v) + if _is_removed_key(k): + # if the key ends with '!', we will remove the key from args + k = k[:-1] + args.pop(k, None) + continue + if k not in args or not isinstance(args[k], (dict, list)): + # if the existing value is not a dict/list, we will overwrite it with the new value + # for example, if args is {'a': 1} and new_args is {'a': {'b': 2}}, + # we will overwrite args['a'] with new_args['a'] + new_v = _clear_keys(v) + if new_v is not MISSING: + args[k] = new_v + elif isinstance(args[k], dict): + if isinstance(v, dict): + _merge_args(args[k], v) + else: + args[k] = v + else: + assert isinstance(args[k], list) + # we only update per-element value if the new value is a dict + if isinstance(v, dict) \ + and all( + isinstance(item, int) or + (isinstance(item, str) and item.isdigit()) + for item in v.keys() + ): + # note: you can't delete an item in a list by index (with '!' ending), + current_value = {idx: item for idx, item in enumerate(args[k])} + new_value = {int(idx): item for idx, item in v.items()} + _merge_args(current_value, new_value) + args[k] = [None] * (max(current_value.keys()) + 1) + for nk, nv in current_value.items(): + args[k][nk] = nv + else: + args[k] = v + + +def resolve_args(args: dict): + """ + Substitute the args with the value from the args. + For example, if args is {'a': '$(b)', 'b': 'c'}, then + it will be updated to {'a': 'c', 'b': 'c'}. + """ + def _is_variable(var_path): + return isinstance(var_path, str) and ( + (var_path.startswith('$(') and var_path.endswith(')')) or + (var_path.startswith('${') and var_path.endswith('}')) + ) + + def _get_variable(var_path: str) -> Optional[str]: + if not _is_variable(var_path): + return None + return var_path[2:-1] + + def _get_value(data, var_path: list[Any]): + for key in var_path: + if isinstance(data, list): + data = data[int(key)] + elif key in data: + data = data[key] + else: + raise ValueError(f"{var_path} not found in args") + return data + + def _set_value(data, var_path: list[Any], value): + value = copy.deepcopy(value) + for key in var_path[:-1]: + if isinstance(data, list): + data = data[int(key)] + elif key in data: + data = data[key] + else: + raise ValueError(f"{var_path} not found in args") + + if isinstance(data, list): + data[int(var_path[-1])] = value else: - args[k] = v + data[var_path[-1]] = value + + pending_values = set() + def _resolve(var_path: list[Any], value: Any): + if isinstance(value, dict): + for k, v in value.items(): + _resolve(var_path + [k], v) + return value + elif isinstance(value, list): + for i, v in enumerate(value): + _resolve(var_path + [i], v) + return value + else: + ref_key = _get_variable(value) + if ref_key: + if ref_key in pending_values: + raise ValueError(f"Circular reference detected for {ref_key}") + pending_values.add(ref_key) + ref_var_path = ref_key.split('.') + try: + value = _get_value(args, ref_var_path) + resolved_value = _resolve(ref_var_path, value) + except ValueError as e: + if ref_key in os.environ: + resolved_value = os.environ[ref_key] + else: + raise + + _set_value(args, var_path, resolved_value) + pending_values.remove(ref_key) + return resolved_value + else: + return value + + _resolve([], args) def _fix_any(type_): @@ -154,6 +291,12 @@ def _guess_deserialize_object(value): if isinstance(value, tuple): return tuple(_guess_deserialize_object(v) for v in value) if isinstance(value, str): + # special handling for 'false'/'true'. + # 'False'/'True' are handled in ast.literal_eval + if value == 'false': + return False + if value == 'true': + return True try: # try to parse as literal # if failed, return as it is diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index c9804eea..2e201a2a 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -151,6 +151,7 @@ def _setup(self): torch.distributed.barrier() self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq if len(self.dataloader['train']) % self.train_args.update_freq != 0: @@ -205,6 +206,10 @@ def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): for i in range(1, len(state_dicts)): if state_dicts[i]['train_args'] != state_dicts[0]['train_args']: raise ValueError(f"train_args in {checkpoint_files[i]} is different from {checkpoint_files[0]}") + if state_dicts[i].get('dataloader', None) != state_dicts[0].get('dataloader', None): + raise ValueError(f"dataloader state in {checkpoint_files[i]} is different from {checkpoint_files[0]}") + if state_dicts[i].get('lr_scheduler', None) != state_dicts[0].get('lr_scheduler', None): + raise ValueError(f"lr_scheduler state in {checkpoint_files[i]} is different from {checkpoint_files[0]}") module_state_dict, opt_state_dict = nnscaler.merge_state_dicts( [s['model'] for s in state_dicts], @@ -219,6 +224,8 @@ def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): 'train_status': state_dicts[0]['train_status'], 'train_args': train_args, 'rng_states': None, + # assume the dataloader state is the same for all checkpoints + 'dataloader': state_dicts[0].get('dataloader', None) } torch.save(merged_state_dict, output_file) @@ -271,6 +278,10 @@ def _load_checkpoint(self): raise ValueError("lr_scheduler is not set in the current trainer") if self.lr_scheduler: self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + if 'dataloader' in state_dict and state_dict['dataloader'] is not None: + if not self._is_resumable_dataloader(): + raise ValueError("dataloader is not resumable, but checkpoint contains dataloader state") + self.dataloader['train'].load_state_dict(state_dict['dataloader']) self.train_status = TrainStatus(**state_dict['train_status']) self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() @@ -307,6 +318,12 @@ def _format_metrics(self, epoch_desc, idx, metrics: Dict[str, Union[float,int]]) step_str = f'' return f"{epoch_desc}: {step_str}{metris_str}" + def _is_resumable_dataloader(self): + return ( + callable(getattr(self.dataloader['train'], 'state_dict', None)) and + callable(getattr(self.dataloader['train'], 'load_state_dict', None)) + ) + def _save_checkpoint(self, loss): checkpoint_config = self.train_args.checkpoint @@ -343,6 +360,8 @@ def _save_checkpoint(self, loss): 'train_args': self.train_args.to_dict(), 'rng_states': self._get_rng_states(), } + if self._is_resumable_dataloader(): + state_dict['dataloader'] = self.dataloader['train'].state_dict() self.hook.on_save_checkpoint(self, state_dict) ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( epoch=current_epoch, @@ -436,22 +455,25 @@ def _expire_checkpoints(self): shutil.rmtree(save_dir / ckpt_name) def _global_batch_iterator(self, num_skip_first=0, stage='train'): - if num_skip_first == 0: - # if the checkpoint stops at the end of an epoch, - # the rng states must be resumed before creating iterator - # because `DataLoader.__iter__()` uses the rng (dunno why), - # and the previous run had not call it yet - self._try_resume_rng_states() - - it = iter(self.dataloader[stage]) - for _ in range(num_skip_first * self.train_args.update_freq): - _sample = next(it) - - if num_skip_first != 0: - # if the checkpoint stops in the middle of an epoch, - # the rng states must be resumed before loading the first batch, which depends on the rng; - # and must be resumed after skipping unused batches, which will affect the rng - self._try_resume_rng_states() + if stage == 'train': + if self._is_resumable_dataloader() or num_skip_first == 0: + # if the checkpoint stops at the end of an epoch, + # the rng states must be resumed before creating iterator + # because `DataLoader.__iter__()` uses the rng (dunno why), + # and the previous run had not call it yet + self._try_resume_rng_states() + it = iter(self.dataloader[stage]) + else: # dry run until reach the desired batch. + it = iter(self.dataloader[stage]) + for _ in range(num_skip_first * self.train_args.update_freq): + _sample = next(it) + # if the checkpoint stops in the middle of an epoch, + # the rng states must be resumed before loading the first batch, which depends on the rng; + # and must be resumed after skipping unused batches, which will affect the rng + self._try_resume_rng_states() + else: + # for validation and test, we don't need to resume rng states + it = iter(self.dataloader[stage]) samples = [] for sample in it: diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 58a1d00f..12117634 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -24,7 +24,12 @@ from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule -from .arg_parser import deserialize_dataclass, merge_args, parse_args, _TYPE_KEY, _VALUE_TYPE_KEY, _VALUE_KEY +from .arg_parser import ( + deserialize_dataclass, + merge_args, parse_args, + _TYPE_KEY, _VALUE_TYPE_KEY, _VALUE_KEY, + resolve_args +) from .loggers.logger_base import LoggerBase from .train_hook import TrainHook @@ -499,7 +504,7 @@ class TrainerArgs(PrecisionMixin, PolicyMixin): optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) dataset: DatasetConfig = field(default_factory=DatasetConfig) dataloader: DataloaderConfig = field(default_factory=DataloaderConfig) - dataset_sampler: DatasetSamplerConfig = field(default_factory=DatasetSamplerConfig) + dataset_sampler: Optional[DatasetSamplerConfig] = None lr_scheduler: Optional[LRSchedulerConfig] = None checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) log: List[LogConfig] = field(default_factory=list) @@ -594,7 +599,7 @@ def __post_init__(self): raise ValueError("dataset type is required") if not self.dataloader.type: raise ValueError("dataloader type is required") - if not self.dataset_sampler.type: + if self.dataset_sampler and not self.dataset_sampler.type: raise ValueError("dataset_sampler type is required") if self.lr_scheduler and not self.lr_scheduler.type: raise ValueError("lr_scheduler type is required") @@ -612,9 +617,10 @@ def from_cli(cls, argv: List[str]) -> 'TrainerArgs': if argv[0] == '-f': with open(argv[1], 'r') as f: d = yaml.safe_load(f) + resolve_args(d) argv = argv[2:] - merge_args(d, parse_args(argv)) + merge_args(d, argv) return cls.from_dict(d) @classmethod @@ -721,19 +727,18 @@ def create_dataset(self, stage='train'): kwargs = self.create_kwarg(dataset_args) dataset_class = load_type(self.dataset.type) dataset = dataset_class(**kwargs) - if isinstance(dataset_class, torch.utils.data.IterableDataset): - raise ValueError("IterableDataset is not supported") return dataset def create_sampler(self, dataset, stage='train'): - sampler_args = getattr(self.dataset_sampler, f'{stage}_args') - sampler_args = sampler_args or self.dataset_sampler.train_args + dataset_sampler = self.dataset_sampler or DatasetSamplerConfig() + sampler_args = getattr(dataset_sampler, f'{stage}_args') + sampler_args = sampler_args or dataset_sampler.train_args kwargs = self.create_kwarg(sampler_args) kwargs['dataset'] = dataset kwargs['num_replicas'] = self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus # if not distributed, we use the rank 0 sampler kwargs['rank'] = int(os.environ.get('RANK', 0)) // self.compute_config.plan_ngpus - sampler_class = load_type(self.dataset_sampler.type) + sampler_class = load_type(dataset_sampler.type) return sampler_class(**kwargs) def create_dataloader(self, stage='train', dataset=None): @@ -751,8 +756,15 @@ def create_dataloader(self, stage='train', dataset=None): # here we don't use self.collate_fn to avoid its implementation hacking kwargs['collate_fn'] = load_type(kwargs['collate_fn']) kwargs['batch_size'] = self.micro_batch_size - kwargs['sampler'] = self.create_sampler(kwargs['dataset'], stage) + dataloader_class = load_type(self.dataloader.type) + if isinstance(dataset, torch.utils.data.IterableDataset): + if self.dataset_sampler: + raise ValueError("IterableDataset does not support sampler. " + "Please remove dataset_sampler from TrainerArgs.") + else: + kwargs['sampler'] = self.create_sampler(kwargs['dataset'], stage) + return dataloader_class(**kwargs) def create_lr_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.LRScheduler: diff --git a/requirements-dev.txt b/requirements-dev.txt index 9107136d..3f64d831 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,3 +17,4 @@ tox-conda yapf wandb tensorboard +mosaicml-streaming diff --git a/tests/cli/common.py b/tests/cli/common.py index 86b4b00b..f5be1c9d 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -1,11 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from pathlib import Path import torch from torch import nn from torch.utils.data import DataLoader, Dataset from typing import Dict +from streaming import MDSWriter, StreamingDataset, StreamingDataLoader + from nnscaler.cli.trainer_args import TrainerArgs from tests.parallel_module.test_end2end import MLP from tests.utils import init_random as init_random_fn @@ -92,3 +95,27 @@ def __getitem__(self, idx: int): def __len__(self): return len(self.data) + + +class SimpleIterDataset(StreamingDataset): + def __init__(self, split, *args, **kwargs): + name = Path(__file__).parent / f'streaming_data/simple_dataset_{split}' + super().__init__(local=name, *args, **kwargs) + # the data files are created using: + # dataset = SimpleDataset(dim, size) + # with MDSWriter( + # columns={'data' : 'ndarray', 'target': 'ndarray'}, + # out=name, compression='zstd' + # ) as out: + # for item in dataset: + # out.write({ + # 'data': item['data'].numpy(), + # 'target': item['target'].numpy() + # }) + + def __iter__(self): + for item in super().__iter__(): + yield { + 'data': torch.tensor(item['data']), + 'target': torch.tensor(item['target']) + } diff --git a/tests/cli/streaming_data/simple_dataset_train/index.json b/tests/cli/streaming_data/simple_dataset_train/index.json new file mode 100644 index 00000000..755ac103 --- /dev/null +++ b/tests/cli/streaming_data/simple_dataset_train/index.json @@ -0,0 +1 @@ +{"shards": [{"column_encodings": ["ndarray", "ndarray"], "column_names": ["data", "target"], "column_sizes": [null, null], "compression": "zstd", "format": "mds", "hashes": [], "raw_data": {"basename": "shard.00000.mds", "bytes": 14810, "hashes": {}}, "samples": 100, "size_limit": 67108864, "version": 2, "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 12627, "hashes": {}}}], "version": 2} \ No newline at end of file diff --git a/tests/cli/streaming_data/simple_dataset_train/shard.00000.mds b/tests/cli/streaming_data/simple_dataset_train/shard.00000.mds new file mode 100644 index 0000000000000000000000000000000000000000..3fe098bd2fe8c0c9fa7cf208360f08796a037321 GIT binary patch literal 14810 zcmY*=d00(f^!`z4F3po>sZ^RYse9JCRD?7kMX9J%nkbbd2_cn)L?|>NO6GLWzCuFC zl*~y8kujn4JNNtjKELOF{&;)tbKCdaw|BjJueI0Sg;Nx@ScIa!u*ejpC|NOz+RMUp zG)0|d;U-Q|_gI8VP}FA@8IlwwEk#i!ER3Zo>NJa~G8EO$Vv#IG^|45oqbNyviYjJd zG=`!YSa206>NblYMT+{sB29^+#FZ&(CksOrimGP;sub16Vu2b(^|DA&rzkNEiYj8E zuSroSSa7r`>IRE|u@v=|MY1+UiRe&NAq!nyimGMdtVdC;Ed2E;s)xl21NM(Jq^JTG z+T$qdC<`YeifUmocRWRPvq&_isDCW-Oejjrl%i@_I830ZD=hrXD5{G^yg5Y;vB`6TdPx)XxR3*LzfPZGQl6_ynBzn>R;O7O6d*yy-u!4L;W8*3*g z2Rrutl!z7o_axgjuB`Y17>X%u99u|ovv!b) zO89a5AQ`bf1s5;xht!0h^p0f*u;5z(W;m>2*hlbRCc*aQ(O7fV72mFV0IIp-*wohp zd(Y|MwmdQ9o6iKrbLuD^D2m&sNuv1j1Mu#qK6-q+OzXHBqY#_samUFtmB&O(K@o$; z&BCp-HF2TbHFBfn1mk7(375%o(fqY4`B-4XRJ?Kle#c3qJ&frcSfL6e$n<*)kVr+FpUr z$uB{OO{=mF312#jHsua6iJ_U$uJ;N$9N*G|(niR6`3f#j1n+tcn&4slo^R^%_mWSXE)lGO*+W{1wFJmNkwlcx0GN|r% zgV~wV!x*IjuB>au;-TC)t6im^M;&-ok94QIX*KWjNWY086wg z@wOe0Y@KbyZRUJqyl)+YOZn}DYLX=#wLh7u9V1M8(PlFLYyn!xDljkHy^v*7bLlw5 z_0X93ek$C`NXJUMg;@7n9qDgQxU|p$TVMFX<}1VS*}xD>99!T5la8OK?FXNK3+OJr z9N|a>O(eXWZY%64^(W)6M>~h)%*W|>+Zc7Zjo2A-4r<@6X5xdCxi(L~kirwQc+W0v zB=qxSaymeY=Cs-i*l<0R>FD%m^iatM?p1ZDn9~D0MNA-VYci~yG6fT@M(A3DXo=UvuJ0S^ z?=t{B5)MF1)M2pRdY?9$YlmscGSE|d9d70(Ih&pok%W!a#HTvo-MU5vdw9dOobO_4@@l_YW;0lEwhVQ^BUrkIIvjSt) zyo4k5!Dumgu)2fMmhvJuZz|wZt~LhA^n+pDZ``lEnfwYmz>8V@nQk4UL8K>_k*}(E zc|Nl?5L=~k#=34buxxtrOprh7I<2ayg4HXc@q?E?&fFh|Iy2p|`|bzWzsvyBC3b;x zRV-?#bFp=b6et{z$E-<1(0bs9FdJU~2Bu-bH#)g}6&zCBg;V%D!05|WIQ4oqb9((( z-h(@Cpj_)tHo$k{-98%jdD@Y&;|{`Ja|3Yox+-9ErT+p9BsYOvlsJlruZQ}6d0Nsf z4K~Q%f*m!X_+!E%>~)z9CFTRb(K!qIG~(gt-+Qpvq8d(aoh}@yXT_-o*HEfn>Jix{n|*%%yD z!W+7^mDHbWh3B>!cr$kb>PGD)=Ih2{nv$%5P4<%r)G>F(?UN_rv_c-F{dGmRYA zEzok`Amr&PlUD?l0GI}bH}k(5xBuY6Hje4$9m6~^vrn|VO*&U>Sm8Y`GQJe zHfOKbG4^`b@uY(r@HRDrwTlYV=R6LljV@#YCm$u#mnkyWc#_0?Xf)ZnH3&y5EFoVm z_@iRc4m`T-l|UP+FW;p;mIwWm4Q_Lc$MTYC^bTzr%cJjdx`rA!T|o#ptX{x7G7bwl zKj~}x7qhEC5^8iwV7YaaaHNhK^T@*sU1V^mls-Ntmly`0NBgziJe}jp_}gV7vr2D0 z&*0+*EFK?)`6=HSx%Z36YDX2gRrno$zPcuuQ=M9Q;MKSndIw%W)Q}twKGcFVnKnAP z&=z~%8Q}-VHK<>{8%|G}h^J4MK(7^I`Q5Kz`Fas{Y5!+VX}KHlJSNU2L+6H>5RD7W z-|MOJ3{*Yb>es%F)qtyQEm** zy*35a?}*_xuXwDVItS)7reUhe8g!hNj#pwfL+sf^n5Zd<-A_%>*Pe@y`GE7_xDrV92f+5**AMfp3=NMgkpDPpV{Oy;NZo2LKXg^l`cpqep@dGT`x z$u>O1gq};ps5?c1NQH{l!lQR)xain8RPmn&tHNZ_Y107R{N^-t=afSCq(tN|_oUlT z`2ttLisR;?k2M`1>DK58_+_ECFq^J6J9GcLfhaHX!RHtW;1Q89$A(U%p9BahxxX0 z25t&?Mq3(hA>wa7^E558$-?=2$jzTWKt60gSrA-Dt`%KH1u;pUt$zoz<9W7#&Fq;9 zSidSA!ZqK6gG?7tvfUi>7JoF>eG3XVbLolh5m>k`7JPgsV&a!f{LnWVSKT+lV^3}I zGA~rv#ee6DG5iA|y!%Rbq1mDk2kL&|{l8hLS(1fb1J9Y|a%E&(Yy~baTLpBB1bQmP zH;Jf?;dfPCL`~^^f=DSOp9Ig>ArSQIBXG6CF?o~>?(AplJ&k|R?QVk0lEiT1Dp4#q zYvJ%)vuKeDX&eqj9R52GujP*xX5-ptgjO3@bItDk<4HK~2JXkpklQFr^jBq)BhzcK zM&>tGnpKj!!D@U?loN)E#P_cSU#kH-ZH zFCfc16TaAdq64hwz^fld7`aGCSet@_#o(voiFfWE#JG$E%sMQCrq{2KpTEi=Kbgjk z@(Y-N>$&2F+c3GHl}LTDB}3B6_%!VRCjJ>Gh}4$PZYVwQ7(A~|!S*E~^wB6o{53TL zsn!~oKD!i9{OXIFe?NjRDhaSh%?k^5N@ASoXiSF&IFtAv7iXS!B_9r?K=v{-kn51= zAHRDNUzAY9f_j8K3hT+DwnsScc_}&kz!lA1_Mpis2a@WU!{g}6aaU*cAj`&=nvOc3 z)v#r=GHP2-#T$$_N)H!9>-X2tW8;K1-`2yqx>RhPZU7AmE*LOb0nTVlKy8*Tm;Fybxw%+3pl=@*PhR#g(|5ZekRedg%Y)`xP-OR$4d zCl&*2mbJYk&?aU6TR0Hmi+w=~;8bJ`CAgdBcep{CpD6zMuoav(XyYYcF^t+2gYqMV zko)TbY^#}zpWO(&-mO1$2+fHE@M`!FIZOiU8IiG%u=MHTmz1IS{ z3wJ6KucYtHpx#Tg_VHyDM;&4!_EchlO)c8)nFbXW*YWAk%Zx@*gCJ6&amiq)bQToK z?7&IN5z9?wu{*j34vD;l!oEK+gCB=IQ&-?1J6#vOx5OE7HW>TW8bdXo!)pOeH~7O4nSBV$f!Vyj6PNakMUjc6Q!A7Wfa{^(h9@T5BK$DKqrU5Q?58_Y@{L+p=S z!?KPUe3nhQy*|qC%E0pJM>!{+9DuAG6G*(b9`+e%;xPp;lsx>KZs5*A%a1YG*}e`u z=BwdO<5iFrGZw{swg_tzQ!$3T^M1x;?mNl^e7Q<2^_DlKw(rBBp_};bK|8tkD~dPI z#(+(8xlp{%ihLkj$=s0+JhCQ?NUl*4Xk$2SA8@8l01dUH^t|zr*mdkB^t>5_#*!?E za16w*uM5CpWeF5kbm*Zyl_q1KjGx+l6j4+$Pj7B)P_B8$6qnhk^U`J+} znleTWX{6!XdZK=MJQiBJGKndUs3l7AI~=w6OX@?xHTf~IFRe!J$_D}$Z|JtfIF2E< ziuBW3KUN@>HthVpMiyV}`v6t7SHR18J{Bp(vJ0yM@^@Ck!U-X0w9E%1ht<&Z?*(Bt z@i%42f;k2xKFyHqU%v`}cKO53r$0DOtxwSR=47&2d>Y9-sDX`g<>c0(BFuKD@n%#x zZZ3JvDCYPIv{}nm<9(_taQ!PeG|ImPQR4gQw(>ZzS-}>4Lj%zK{WBDG*r2tM6}&k8 z8P1NJgLi{daj|(6P7m=B_Dy!QF9}?I9*U3MCIz!}`RW=fw2fmOV;OguS^hhU=lg3u z-fOqSYjhxY!`_GF;589E99zWnZaoIlBEq*%F!kiTXwlal(UiDeU6odrQs z-@$Z94YgmM2K@sMIVwrrP^uG#UiphLQkjeSyOdDuwjS0Of1t-aoP?B)w#&Zh#W43m zlQ0|ctE1Rm(he4_=b8Ak)p$}amGBe7Q0>KC-e#YP#9-<@=Dx)p;wCPF z9>-L;&NG*R*KP@cZ}$8P#^No5z%1DcPlMS0jL!-BnNt&(CyirEfPRSAc>}}FUZ`L+ zK!>NVfi-VeV_KRfQu!SaGJC3Uq$IY7kg&sBT@pkNlD0d!#Pu>4P0v(<&h0#ArfxN+ zJq&zIFkO}l3yREKu*laV=4^CfS!$rAX4Y4+4#)qA>;+`f(nxk@K3kI4BwY9 zk~tF-zc$eW5!n#9GYoItP{6zvO$1FdoSA(Ys>=32*?A*jHey}*FmcjqvSRBJ?0C?M zfrqMjMLiA7TMuveSG9wg`&1n=+qAf4zm8zN#wKQZX(O23HRQVJTky}!(H2zcpN*Tq z35&tyy#fyY)WxdqUMO!G5A8D_L(jut2wdcfO~e9p#zdm_Ut26U*g=~OFG9-wj7yEq zSK&y7{hh$+)?7dg~kn^^MzHE>J4l=;{-)=-IU%x;F-)r1(?X8e*`+X*N#OPQ)LQXW)w2 zGzjQ4hj)oFICLxv&Zl=lc5bIIn}zA3Bsx=pSZFOHOJ5_gk8S2zz0bpmylTAHzl-_Q zK7;%%p2+lD%kT{sJ0WFU1TS)8V8YZ2yt$=b;G21))li}F9IfxL6kCl)!?ZVRu$}1a(g>aZEKy=B^RJyYbHkSGTnZw2~1->cE$e zuz5*4-6^(?p1Vd5L}Up_ZJCL0LyO^%`~ave83M}J20knP2G`xI@%QXVD(Rzv_dZB`L1Sw=RZf@SHh4?i&-+SOtiy$cd?* zc=xgrsrcsywa*s_BK6?6I}UXo1!Har^skS@t6mc5Gs_s(JXwYesN$6jH}t5?q1*ZD zSUA%QS13LOl|Mlc8Wo9ZTU>^2QKwy`>xnAy zJ_5|T@B`%cmffg1-h-jkj*-l*Ik;=ZdQxh;gISW>hXd0z`B!=EP&6=6P~VgsUkaa7 zK7ez^I2`>$5oc`h!+>k-Y_8YCcGdUrsz(W3`YrLc-+U||?%?PK?E)^>kGAOF1EM0`)%JCg-2H5e=t(BB_DKO7uH++5;*Ys>oX$+tuk?k{`N3$?ZGzqg`siK9?(Dv)(T^{(D`fTr z41Z?|kB#)uxcWR~z#ed()+0Qp&ZXGlL8m{Q;JO7wX5wKD6+5}raCEyzKcxN>D&Ha zU?iFCT@}nj4_z@F;f}|30|uD4CKy_yr-H(^Ksso67#u52gths)_BzjG+aLsuv&Vno z*0kuqF6OD~L~?h>3)FO8PUcSAOx`#xV;bj`9sg^)csdrJ)l%x_@Qx9vsU&hcdK4RKUmnnX-}rTF{5T?X!ryUbWs znkaYOM>WUM_{KI~piN}bWULR@rB_ZD$Ha3yI5zBuB?{)~J9`qU4tQcyT?-iP=?AQj z#MYxNbh3snj1QK=#k==H&0Gm#Z6c2U1G0Jx1WEJoRP`B18n}bGAJlQilV>D`qeoh< zwOy<>k3cK$p^IIkC)0i2D)?vH4U)e~jnpT+7j)XbDlW$de%&zmwwS(Mz7=HhS7KuG zHTXcb0ogwp3p&Rj<+p;qR1$-$R6KCylmwh2>VzKkY_S{npT1x8_DN{ZktbekJ>xss zo)^S!yQ+3h#nEIvac3(m=gwYUinR_oyJ!Kb+|z|;1^FnUo`*5Riui#{6J$+INx+hj z+u)(0iCMCl9F634P_osE-7tTLoKG9Tr!AY7Sm4XJpf2iV;~H{Bo}ivSi|D{r3UTuvZQZYsB!$Wqp)6q$cbe@376JTK6_|tiHo! z%vU8!!5f%wqsNlsW-ilOvl7)+EwSdK8yRuk#VaJ|$?@$6$><$w_&Xw$l&s$%(57?e zH`tozicu3XuwqXPW_l;%;yp*8q>SBD)xCz7560t-JCiWf@;#W&u!n%~zjUzkYIL)> z2s$vX`!Mu|Ys zm`<3~whDhd9*aln6!C{%7c3m|Msb~W!fblxP3NvE@FuCL`Q(7vWU}S@47^)&0$scb zILT*`#OOBXkaOZpF5aiB=kb_v<_@GPWPsP5bpR`8%mS8;!gYT{S{J!4@+eoPh$S31 z5K|06QD;{y(SAuQjF+SJH5tg1nui4gtiunbpvBHi2$h;fbG)2{BPF-Yleqim!+iDY zbjcH4?(w=kgm)qvCm6IaJ)+UP&r}|>TRENg+d&dLM(cIwE)feZn)w^H(Y&YDy)soWglJxHekQBGS}>nI6fC03&-z@VdPge zGE*%Jix1Y2%j?d$Tn|;mNu5~3dcW@kpyy20mRzJa<3Vgf-Ctfndw2#N>`o0D@ zwNw*3CzR8*(^WA_DGt79--4AlCSt(zYjn>HM<@?-MrzwBmjbl_bc*>PJi}ECvT^?D z3OsB#Hq%!IOVjFxaEBEPPXve+JyD`Sf@Y<2ht^N(;A&Y_r-9gpzsaYb%p zf3iRue%D{%P1C|r*|yl}I2Gz9D`JL@BgS2i!}=w|wEu-Hcrs%=>bCv@g*C?DV| zK5aOw%$5PGQn4;=`gZ0$0 zcuX5KUN&1F_l(=-;` z*@ACAd(+GM?~{oheaw-8vFri&5uzg{kGsVt6JEOnx31cgc3txhYBu@coy+nJo0^$0 zvkw9`QP~Cb`2kzpBISigMGpftb*YO~<5GOqdl?2xHPJ)rJZK;E#UBqdkjwGGotsWT z#;HKePX_#$AuG)0&Z$Z=l*MMwq#q>cbq*TK=a8xmk&q*0iO-iv;JctR@N1t678|S6 z2Qu|A?2kW^>tSS)Tq?FNTO@F?g7r0erg9Nvy*}zQ(I>;2xy~P;334&RhD)7XY4dm$fyQFQW6e{e<#e-j-2wdzG+6@|$Pt!G0gx0E4 zKyHvLXQ`?omM=D-&$AKCeRTt}TrYCghAzcpT~X+{JsF@~3umsLf;VdaGs7J(=aTpL zMH$tFmZbJ*HHng#fa$D<^6!7ZkfQhCb2ATQ6+Gcp%1fHLJ)cDOY2mk;8mv>>f)m32 z=bS<5%yS4W?d7~y3dZ`=t7(35J5+s*;@J5&PgonzVSTJUw+;_hdlKt9idii+5#P22VdC%`Tr|ulABT6+=I538^6S`6 zNwy`v+(by$`qiYS*AY{eyb@#vKi$jl_FFeBn==Yq?mdJfL*wyJR5b3^Rl(sPOYD{Q zM!Vf(QGJvPrW?7k(MiR^Zd>|On>nUM{^xYQw{r}s7@UBw_KEYA8$vPf)e5FAx)sw> zi%3V38Q*D8g4?s_BwHGJqwPauC>=2;+`3LOY4;YQdv<}qH;csYf$r@6VD!TiFR8O@ zvR@PZ#>5@XTm$gmA~mG4N8$aoi!tY>C^nAgLypEL&{Ld%Ix_mGZ~aYJn?=#;IGl9#dtn&9sS36EBMc3qq2v+u`o4< z^Yo-Fc3~Fh^UhzwY@YJAkvqS1aq+@XvQg#bE z@WR(1bi3h?3mEp)NADTk^<$W$F~<@^cg>_Jw{YinSOe|xWwe4$DX7%a!rDAit7mrX ziGlXqm!Li`7aobvX0|WPMw_&U%npO0rqN?Ehy=$Q#BV)-xypOVinWqVz~o|PzxZd6 z+Ndn>&4xLv!2cv$NnN@FewHyP{YV5eR9=F!>>02SQGzdS*Xij8129?pBxL>aowX$Lr9oR}XX8`&_?0+;P0LHctKepF3?<4tJpXhzYUJnn#ndC!lJ5AX6pR ziaTARNOM{%dG~l4?d6||PxBPH0{(|=*;;G;TJ_f z;%x?|?cNTRX*1F1cPRXMO5x(W=16Tfp^u++My}Nty84+SPUYN!u#g!zW$AxzO0D#k zB@wDN#IVnV9DZwqK8F2_^^iF;*L*5wDV1PfgaJwTzL6Igt&4{G6gO_dZU9M1{B>tE zvAcLvz~*|VHlC;-gw1B_K;wcH%C>HRu$Ok&BfT9)?)kGN!eZ>q@Iu;S6#iJ~g7r7; z>0h^&;oTrRK`+1Gs49q>1O+K6F;@w+Cn3ZJ1B(2V-%kmBpnbCU0`bHA- z`Q%wzZtYtu#0CeX(n)V5Au6oGX<+>(LK#J zjC`OQ9(T~ikNLLTUrEowDNvF-f6_Q^k>ef+N>C-|E305%k|aEB9ua8s>{|iIR*eDs zU&=Vs;TcFciDTMENA%&@W3fveY`r0geM@yQ`+^8M?AJtT&k|7l7==H0Z{QF7$2WUc z6H*trixdB?9zGo&;rtnCgT2F#q2z-J!Ku~RU_4v{%FJ!^%uP7dh>V%3NnGjF8VyDV`&=s4Ai|;H$aj#sUMrA`|K^p}4 zeSrE8YoT*XGyFR+D6CC*vJ1ZXqREtr+mqW1`{C}10@(QLD0;lwfoDr1dAn9UgZr1_ z$*dE6(!2XUt#-K?(+n?y&bzbta!aFNP9?=mM&7a%{MX9%fb-aW zr7z%{s0Eh1WN3YpJ@my56q-dfKxXJV=!%o02YK^^ee-YQBMc2%hZ%NynCqm(U0r+~ z5B)vByi6jP5TAvSBj?DY9BbO`;5`z(>LLksKLtN~I(hQjX<}!!Q{bB$ZS1#xhy~B{ zyWvii0g9cFg>6o3C+X1;bnG?9Pc61=#c2b5ZjN}SIuwg(NxWr01$$J=pe|NZm`xG? zJgQHa1)EdyiAJahzmb;YS_b`5l1==Y5JqN7;TG%(m0pFaELZx~`6#20hbA7AfIcGlp zo)V8fUa@refG&H3$^)GhJE62<8g5@8hKVPV%U1_I5eqCq7(g4SewS6Z6x>8 z2c{;glBi1MV*A|B^iAFtIQi-eF*_wgrWP!LQ@3|8m;KE#Y8Ju$cGo!Bul_NMyE!Nt zd|1HdpS>6QtkI?Sw>#jOS#!}^%$#W*!!Qg*m=d@I(C>NksOBDe~&A z$ilh1n00~u=$nyGatS=$0dmol>Q={-IEudEhblPu>N+%6DlF z9gPzkIH;pJ8~Wu1hZ74F|5}N z_i2`Zal8zwEogDXdQ@FE_3^K}fX~5Cc7qD4S5{Py|Phm#4GB)h?!I!4~xX9ZL4ZZF|V(cQCdUu%f z-Nh8A4Ge+NdJa5$>xfoX$!IM(lRXi7D$Hiksgr15I*R^T6xVTbB6DfO7LsT(o6K%q zg2gKfnH#oWndYs7s5QeDKT0W+V6zl5?`Q*v zwkQBIl(s@)>Rz_9Rz*u|G($7z9h)t3;hD}tJi%TM3(ZcX7l*$ENvZ$blHX|A0dMv? zV2jB`NccUE)Cc`0K3X5K(6NL0c!*+r2hS4?)nVs17pgI0w>T-iC&O!SZNj=Aj$zu{uZGxi3CJ?-oOV*Dm0yp{*{M}lDgJ%!(K5Fg7BeLJo zU|tJS<^HgJkuyz~8$)tYf*@<|2R&5LO#u1X?ClNDBKSO22M;_t1CvC#sIM)8yV)+l zx8scvn?DD{m2J@IX$bIL&(TwoZO}r|Q<%-FIn#-D;(C<7u7^o>yC88lTa`~nqPIDT zzDIAv73YzYvp0@2BsZBHeLkCsdlO7XJuJiAV1Ht#6)DJ?CtAv2P?--cYEtOqUkz!` z4}(VyfjoBSe|U~DX7D$Fin=bqFA;1UH33U1{(#cwnV5BC8IIa8UU-Hljt(G>f5TC$ zIvUn#S;9V95A=+dB)#8aN#!qZW^>&D?CJjm2?rODpH8hrb+iHVd0QfWIH$p^kW?19 z_<@5N-kLfSPkYut;%PNpma!DgqBnxqfD~$dC}(JgWrjQ(wfT3iz( ze{8UH z9SLFQ)#JZGc;4m}`+pDaxG8J`+WFN9N6Js*5gC=U1Jl;M!&BQ^NrKoZ%zmB+`E5s; zJL@JA3AswFi5vmLR0*3$e%vV^3>AK}R@_rV9|6S>p5l0659apZ~RYZ zpK{Z3so$+cn>0}PPyZ;~u?R#tz7cAS-$aM%&ZJkrn~!Rx_u#fjA#z}W;~N=wEst%X#)Zawef1d0`u{)B?(_}5>uS3c`?L{ z7?0|M+Pk-yM~(`3av)xyO~AC9^r@NdxOuTBB(rBOK3R&GDCrMdF1X{)=jtGtsE6C@ zGGqU3UT9cE>?y`?ZVNo0^&~>TW~bR|$~N)=uhUssz4?$fEFKaV!qpxDGE~QuwX9 u5K?;i&{$!F+(X+rvRYxdP9_M)+B=~1y%5att_Plp4l1k{XLou!@c#gajYE_G literal 0 HcmV?d00001 diff --git a/tests/cli/streaming_data/simple_dataset_train/shard.00000.mds.zstd b/tests/cli/streaming_data/simple_dataset_train/shard.00000.mds.zstd new file mode 100644 index 0000000000000000000000000000000000000000..adca84d761fd64081ccbb3335b68604455e9b79f GIT binary patch literal 12627 zcmV-ZF|5ugwJ-f(+Bi*$0ctw~zcUbUPtpMpVc0RF7&pu@KHRns!yh%&AjaTy+^{mE zdo|ZhHfHL!B%O|>hEQq+B?Ll60^$qf3+oGRlt|&nJ6#yl$n$J1w46=PVJ3vQ|LHQB z)4uRHVV5smB0_DK%FjyFMEodZP#!!;tKjQwk@&Z>|n8S z0Wz=r0-)KSZ}v$*gkVb4uGY%f=Z8qFI1wX=8=!M-K~~-G)tPZ60F4*K+P0iNef`JA zu8}l4b9YM5t5X63^FE_|zF8)v zy+DEC{_W_-L>2%YkUQmSW{MmP42cJbjq^!6i${=$*M}%JKS+lFzSQE>Znna%XN64b zy7YFdQd`5h66{D4q<4ZKp%RUZcvVl0dqm{)Dx6pHs9%PxBe1FU?9q~KV9<~&%dWU| z+7x9MVP^1|Jr;j30#E!WL6O)`5k%&C(V%DRP}|c_?O>Z%{?OiJUk@vpFIZfW0f~KvW*-!NST-%cOgU!4+wGyq@x$WblO=O?EF9>DlBI;^j&px z+B2tZ=Os~v?Lu_Fn83AhSh}{7Hd0_dEZ)8pNTM(3Nu5BE4Ig_mrbUMUIlB;ku?b-8 z!l`-&IY9#D{|iG zeCEavXK{S3Y0S55c{gCj^+#6-dXdRtw=ksg$OdjVA@z7j(+XXI0qGO(V4SE+)fZf` zaKp0i)Whz$`TONalPoUvRc1^jalFE^7)^^wkHxLN02~C*ZAzZ3AZ>0d-Hm=%|VW;5(+c!+_8WW54%9=gTv$>_E+)CsPaG2sej+rMc(0LVdytpa2NW=)@ z{5;Bt7T5Iz+nxA*t#DbL$f^nHy$=5bg6utyX93YR^H(6?l(!^91snPIH!} z-`A=w4b-ajQuui|O9tG)^Ye*rFrN0-Y3?CuTxRUM1_Y7I;C#B?6Fa$sQ-Jrqc##7H zBFjhIEAbNu2bqWhzIZc0gA;&@jiutyS_%tp&$S^r7B04U6bq9enoSo_xpgHR1RBZ( z2(TJE(vz)7@~bu#A3Pgyx+{wBf)aZM6o%ib75D-TS-NU5^n*uXZCPW_iRg7RcwlG^ZhuO;}Y#z>xx*Qy)${OQDi3 z@^SV{z|MGy3@-V~ngwx_$UH6!#}**;Ea)p!PVQzZrANgyHYnU%+trZm5`vT`>tguG zC$$|MbzMn8=z>$aO?u_pe^6+TLc-W6lGaL_$Y{qJNBJ^pK>UeCt(|Bh;O;@w))^T4 z^k)`7->U0#2CQT8APg^l;zk*+8Pq?0>X3GdBB2)%GWBUApMFw@N|igg_9DB{JqQfQ zYj!z0fa-*tr&VM+(Wsv55{7#xg_2kTn**k99Z?5_>)pk9p$H+Dx|)Kq0O;NT z=BpFYSC<$?i$9^%vWf_JTrX31G7kkw*-8K^OoxLZcSm2J+#*WH{fvHf9i_f!ZOzSbD`t26V1ui~c7aR7!iDYt)~pzC%FeUyh#n9T^lxWmG#1gThN^{; z_LSsCG;m%K;^q zn-TQ2C08q)K@eW+Xh7$Y@sqaV=J_u#Iz3CJMdzhKec9%Rzu|#jj0`yMsh;5E*VS0~bwcBxPfW&E+TML%i~<3oH&uWItt5XibG!h~_HCO7&u_Vh)UQym6{ zo69i?ePM_g&8Vj9MD!J#Xp_7q9?kCnAp_o{_}Dt8Ki>n*&xj5q-VyD`!GY#GvC>SJ_(I+Z$S%TU#Mw;o;w4 zWFaFVV;~`;AYx-6q6kF~f{l!kUtMrtiHe1cjDsQ;5xdF`4@V?M7BV&teh?AZ_|;YR z71$UkIEWaS*j48DMIumyh(N)_udcE$5`iK_1PUg8Y3PFcV*HR;D5%&Waquq@am1lw zhp!UvtDwyLDzJOsxA=YE(B9Y0$NRcTy>FZ4`?ATtFPq`}uAO<`wY&FC(0$(oYVWHC zxCgpw2*4`HZ==&nDdf&3k_azJTzGR8-JJ$NX zV-()^0G0PW(DuG!q3U&=&F7NBKeP02k?<)Z8eVaPpx9RA8nIhkJDfWGrHoUKr>ia4&c;6%?@0%3% zzDS<$i9>}?;AkueTiH?y7e0{BI8mFdk=LX;Wqx9o))1szu{q)Rb}u*+O+QD%W*%}vHqN5_9v?VS5$Rt zSs^(tS(Mpl-Z1GVuzH`m@YiW?0(F-$v`coDo|7u&MD%6)nYZFP_^h77QBZpgPI;5( zhy3YfObp_!O>WV^LJk3U^)xcPWCT-Y;#NUJFq)vCvk~ZPSr`VmSWeuzfgo-@JKG2I z!EkB@g{bXk0`p#cP`to)(22I{enIogj~WNH9X&4ahc%!5N^0eaWG1dS=S8!)gQ^gw zWqSl@GDh{((Qoh&iB-1@`IV!I7xjSB9Xc`BUfdL4+vL{$`$#;CxKilLJ!vu8RWTq4dF)VLfM)z;F|Ih@M?lK59P!%=XkTq zb7h=lE^)XHB(uyd8&iA=lgRIil(daHWiW;9wrTWjX;)`IkAVW0o7IUs)a#mtGo^Yi zQp`JD**gUl5B{VlzzImgbsR@J{lSPQ0zRCr%B|A9YQX2QmmOjmYB+KV5j1Z3Vlo$@ zdgHnRSFhp1Pu?MjBJLp*@>H4DOyh0nCxoWFgsbV#+6aM_z!qUT3R>K&tD)rxnd>x0 zRLy3S^=PMm{|f6nTC0bqphK)tktqDZhrSDe>2J}sI9OwfKXRZ4!6cbS9HPi7$N=cl zMKvFiGWH)Wa(sov-Sh0sWLC4ab49rkvsEb26{5qJP3Q6+N928D1qqLHL3Iy=fZg{; zs?9u!_#K4JS9Mik%OG{1lce&gjt`=Yi4Kc^ccZ&1M@&(tJ-1;^gf-CV{DjGXMzkpi z0+OMoc)1;eBbmhCpr(c~9s){YG+x}e6_0`Kdqn4WRj}kG?&W1HaFf^g)NQ#?Q)C_` zI9!Mhtle}EA1KMa2MK;&qJ~`id7{&4d2a{mK-GMjd~_$En?rDc)Oks$xlCAqQ4q1h zsvTOi96LLd&OG!G*ufTx&zjhV1_G}!2un8Q5^_NxRNdtaP;A7h!moxnc$Ef&S6cBH zC!yh>?*xI#B7RgpNwMh5mCj9U`M#au;ghHop5Z7G?nLU_sxi!7V9jfv=@O?in>1(~ zVN!3RO@j^)hWZa}U2qdqr&9&EK|FwtO}pG6<})>@-=JhaxqwURHK%q?4vB66l7DZ<>sCuN#a}=_N;c zzL&ihs@adh!Lzo=YYVss+7OB^(27ojT zKg<^McIFaZCqijlE#M1CjQQY5c0wfWn!$<8h;*L)qJ#9F7|09Rf)HbJ(dr;e#J-&( zmoK}d<{m0ni*maabPk@oE zPv6eKNAO5b)n}$(bckElMd$ql*)Xbc*B~IkZKaCTU8ra-cPRrCe_`O25#Y7pdPzg| zK*er;WI}QwOcG-!sAa9{xcCTsIO3M%JPn{q(diQl-RU4lT$a}R1Y!ihaXLxQ#f9il zG%~-9@^B8BcMsIU(XNbAS}{AiT8WyhQw0+1JV<>_Da3;l(bwspN5|NLiS7f<=pIRE zP=mT@6N@VH*}hVUxQtzB#CLRhz+F!h-KoU9eCFDgJFkAkf42BtP(bB_fFWuF6S}aA zuC7O_NNuyN?&F?#9>juDct!xRfABQ(HZms~O;E|do;blb7<4Z{Y(WRn`h@vPRsWZ` zp;IWSi@bo1IS+qa=t_?)=7x$bAAIFzs$Y(wPac=hPl$L5qFFrZmIg~%Rdu%_Xl5=& zI@?JJzMNNBdAvUdCnV!=2C{Fu$si_8M@gDK14Db-DreKlBc=HS>i3CAYsAk~SVs|I zZ^;BZ)0z|)F%|m&gD3JVpJFziE&fTUB=>M@xc6t&0-wB)@E>YG-N znHjCuy5}TP=*VA?iP@_{hODF|v7Z1EJJqa*HcTGhU zwBzJR^QF){aI=u_Md-nT>jTmzDJ=wi98cHbeayE!h# za%h_`prk6Nz*CwbOTl6>F$j4rjEpR#4h^j2M2|L=1o=sZC+bVk=A=CZ~-PlWz`25Xgf!PFDVm8InczAuMV=1 zU*tIXEJtTwu7=(#?E;su&zGzk=FPPP>U$2=$I}p8X1R);y_iUj$@t8KH(3A>!*+p3 z^Z|yDO(JZ@sW8+v3RTt{$}xL|Z7J;M$Iqdq$k9OzqX^0PQ24@2-M^yGA7`z^_`=I~CZnhgCA01lg z`X#Z?yCJ>x8V{#t0`=96h=8S>R`HZQW{#x{NH?&dcF{D=2eC02`vCA1=i#AU4VP@3 zhw&~y&Csyry8@e8_-JxBO5jevfJn_YBnx{69zCz%V#TQj!@TQm+ysmg=acI8pVgzz za)w<8O6e`^C1}OH#-}Oycw7@MJ!S7m&#GP z;npA~I+MUt+$O%{AwkJbLneCBKwu3EM;Xg%Qn-!>2)0gsuaPg{be^Mg3!yOjNuSH! z9Mp6}7KA6ti}EZLZ~Y?9pAS)h?77u0T;rY{#`)GxeX-9Q+|h0@I7U%Byd$5oT=lNr zGX&FPQ{PS&>)9`eRx=|gWd?f^`9GNvXId$WH36mFwTqJ6q7m?sS{_dq%vN4NBH8&q zB>jNZl6F$C_c*80_Huyw$I&L80#E4Q(q5ZQ9@OrfP5Oj5%Nd|$J*tJ3Q_<17gE#M$ zvc#5YOUt#id(3Sl4BSQ>+2z9g$T-kk_?XlUIfc{f$Gm0v^-OalUW{~`R+)MXL}fef zh?C`%Oad-0>C%pTfFZ;m&CvWEm5o-EveRSyw0Pc~r*>S6jSEG6FUg?KzX*xIe;ri6 zpo(3S;gj?^IxYzOS^N%2GmlCNl-2AYWmL_^!}u!Nrnuste6AjD%xoTg_GBtp#IE#% z@`{p(@8PA(?X)O#XAe!S4IU1rXH8c?+k@w%+3Cp%l%2*Y!xPaw9fkv@dw7QFK|ms5 zK0;KV;4wj`en^7bG<^2nBOsgxNp|r{OAI$8=1gTgW3@Cf*+-bpxDO^@7NF)a=CP;c zVskpUOmsN&5$2oTg96Nj+eI)9$wF`(Yl;95zcap5Gf-B0$T;@($Ah) zHHxtoElYajS*%#JKv~NHy7c+cQFo04h}*a8>sy6<+!c8gv)5qkJ)tPifXqTn76>h_ z^M&^`H$Sq0*sq*6HJFzsjQFtN5B>ssiZEg}CxCR~JlOV6Yx1hoq!#Ki{2>@Xr0svn@w>+jMegZvH8HN!L*nn4T zPSW((TPOx(%+0gIpfH-UjpUm=VLTNJBnvYz;vq}fKy?~=K)HvZT-b~hY9AQPy3n7P zkANK0+V!DqBw3B$>)`;#GH0iugvxr!5?rg1lDyEEq#K%1=q^$8oCFfRCQ(5})C#1g zQ=}vARInExoe+_4L|N!CTZ(#u)+kE~2yhz6OGYdPGVPTbV#E@P#`NWds#bHU+J-h{ z#H3VXU{R^Yq2&bgNK|bTtqLbN_refp0N$*Em3CH7D}07)+m~^M>t&gH-<>D_Q<`cu zamv~bRV93;@MXgQ**2ueZul6Ooh)HtC$B`L^lt|`E>@6267S4 zH;uZu)pKCo@&u*7m_m#PpBdtGu*<3*^oGwbW5(-_jCq?54;_c+@Q9RTJ;+4r8N?WI zF95kC`(XYh#Jl8zrl#_h^|K{s-atv@0A*I(MHz}-w1@x)(<>! z^&Bh9?iMipkQU2xI05DWD&JFW5cEB{vC|o71D{f`^AS~1Zl?#Vbx7FiFGU+NFxkp3 zY9z?4quWe22uFvBQTv3hlOwV5x;2u0>y9z5X3j^*qZW?FBGzCYur#V$lw-{g$O6d% zSYbUto-B0^tmeA;$1Zi31B-kdfT|CuvGxf`us7?Qz0_8X7ZUz?fj~oA%oi&MPys}{ z;&IbcwkWLW;CMIK&|D$~WzHgEWLO2qFd zTDB9?<0zVlzH8R5F#)l@NC12ebk=kcCu3Kvp`cyqFu4)aBzzf&w0-!w=^~_kBRin2 z+S!FmC?UNc6dNrp4!}mU1S&`1fxDm@eP}42W5s&N7`Vi+8pkzl=d{$ATJRbU=ilRi z$;?XP$R98w%tDBWW-|>_@emj|`ErVzXqGQ10pVDx-HkEqPS{%H9{^8ip~z8omNl0> z;M581#o_Z`Qy;W)ecL;YV!;_H6jq|Zs$Bu!%Gl}gOBqf_nvt5l6d;mS*g0|_EFyGY z8?Shah0VCfhQiiFv+WKdruvCyYR(hFp@Y=%lMiFkcc38K1G?0m08*SEUE#JHP(e7a zKvxrCa?_HC+m~>2{nl3TVs3^G0fv`PTBiAh92nw^u%T?uCjEZc($&}`Whc1t#WSq~%I6Ho@qL6US?l9cmq z%sgKb$v(nj#__blqTz5dtCei6X_cS7ImKmm@fb;wuK+v+kJVt5}|0*aqV6J23Xgzlx!ng zdVYavt_Pfk_8E^BY<35>ZwDWlQ%4okr3-ktLDcrTLy?Ucv&f!!Nts^$!eZK~0yh|6?2Dgrro0BMa)#N3Ov|H7+!~& zaOys&TQBgSOjAD5gJ(OC9uW)T-LxeCq+qlexUpWX62P$-$8^byK(MuivTaUib3;sLHx1z;QCI}St%hF5w?!Ep@V zzDK1QmjuRpxWLbLd+-WhL9=%RSw{4hEIW+|MffQPqVVSd)?a-{(xtJaun{?s|B8a( z9mw!eKn2a*Hh)b()!>GoKH{$PbdN@4vKVx^-p00apRuRw^W;QKg!qUSZtWUd8e`sw%oByG0|a0vLh4# zY}Z*ScPm}^f)F(X8Hce!LY#GxL>_g*&7e7m;eIr%sY;YEzrDdQ|Ou2Rg zH#8@rFBk_H0y%^hJg2BoN7=!WHyuU~V`gy~;6*%)aUJaY2WDWS8Jy8j%I01I&JPA8xYukC#2Stn;1hJo@&aMhG@DT$9m>PA zhuGlo#tD-*v;pLj2M8~R`J(a6U{#UAGXoC#IhBk6xkn5m*-75y(vTaJ=j4RmY^v%b zLQ&(bl+HXL0G?QdA5@;gAQ^WlvD#d=gya^16Hetr>0MLGoGA3`OgBn<4^9`1s6^SR zU+i2CkWY#wHk00IvtgsB$z-$O97YIxL>=0Xs4VF$qH+Ep6mC1HzP~NJyaXwMvZKjN zMo{pBieShctS2~ zL0&us@h&%M*`6Rqp(vuB9zt68x1j}BTrlEwQY3EZ83h(W;k^J>%`4E68%B=PSEMX{ zM5a=|_~Zh2x{+;H3Sl1N=?)x4(o97nGiQPE=TZ}5VFdwNp#++6_CYcsP*Sg0$y{2) ze5TBUSH-z}7e0)=L`afzllt04iQIyW4Bp5>Xy2&8;YfZ3ddsR2wn7B;9!TGHY=E_u z(bL=UnR)^Vc_$04rRVTPDV(llfZL=KEdTc<>1}KS-lv4(fmlo07(TWK17g};SjXNJ zg6FTAG>s(-pC5KY`AaznYDk(@Qx#ir&DLRTK{-8TQ3>UWDr+OwnEMkHa`}Orvr`Hy z-egJ{-FJrsvlSa9noOBB*-}VLe8+5!OSDA4u|d8)*tn748Ld^^dow}qfFP0af zH-)(L7aUZ_@TLR|VG9suQNZe8U8eJub!j49g8r5ofT3!T#$nC9S>Ut^BBV}j68VfQ z4Wu$!W=II6L(0WWA0#JK-W>yC%@jE#z2_di9cizZ86&$zpG;)r`n0V zGp$Uf`#Qj-`YX~nttH;)1;BuFDx*o*lG{C0Q-Pnz+|fPe$bJpbl;(o<{02KUc$Ujs zccutyIzOJ5OrmrY6@yXppx{cK?#O!s_*%+jiTI2f56|+W@J*dkS_J_Xai(Zecj}Yv zsxj1L26meN`#Q^UDMI8qoLN7NK)o@`lLiyTtlzY8(tK9Voa~Q#q$?-w*-57>IB0A< zS-e}wLvjLC{B9u^fVp(;^=KBK7m8t7dVRL$Qx7g{RF4ZF+-og-T}LbOwxwp%B`5|u zjh$mA8!>9@4t1!^d~m9KRTWM)O0oAxKxEo(5Cog-rOqB3%9J5o_qKQrp+$Y8+wOH| z+A4bUNfP|PrDLtVhxE_!j-)wY;b0LNp9&i zX-?h;h|I;}{@#tNzdC$s&TZqbzTR2`lkVjJA!r;xu;3GaHXj%^=Z`1koHh;VE~TQz zZtOsvcKc03S;WCj;*831HmHcfFpBBEwz}9(nyP!nDYcIgd@Z$ABX@utM}rz@n@Fhi z1|x7h?N7deE!NNc1ZWT*=I+5k;pI~0;8-5p#t{eWNPO;kMd8>#NP;+(lEn_SQ*sCl zwtTEDOw^Hx*s4jKFT$V`dj*)vsBe36h8~JNr;d&0Guk4a&?iVf!~>g;3hiFEl%vhC zP;{XXtP28i%X-k-thZq3auMJ`>0S%F?fGLm`@rI@LPDC6A`~_xg z!IiJz&YNm1HjOpLb75(Eip8w#paKcIr?RItE`XcJb^Ff@*wdnlxt2IGBtVYvM01>>} zrc|2_kopNSJPra@g8!DferJK(I#j#*PN?MXXfS9gSSnv)GV(^Mi!BuF=uhQJxk_Jf zzdq!|t8SZ@gN2v<^pWNZeiR#u0V{92(WY0;Rv;a<8$k~6 zVkO`oj5u-)T4(z5P3$*XDJ{h+2?l*2);AtE`vsKb_cmOC-7vCVYbe(W3d8&chs`Pn*lllQrW?=Maw2-gtI)H#cWmn*yu@@o8|exIEe)Z(exA;bMci`}*934c8M>bG#a* zk8*v`X8P97=b54FO52d5QL*e03EKsaPQD}03Fh)a%ml8;iR*M3@V&{|oVobE0;~DX zH72=$47PlECJ)U_x(qp9dW96#4t)Z8rb$JZ5{mV8aqus!=92-@@dz>+pCkgw0i%en z2l~SisA!2sytf)bphGAqdj+r(Z5Cs2H?%1_;G44!M6f1TA%x#`s+d#LRL^2jK!(Ra zW-nO}vZ+>*L($^rZ&X0EQyZY|(waRV6|&F>u=b91wbeBqVE)Y#yIxr78Cknv)jz0r zV1-Y|88D+!Gf-VAF|0kWa(oOn@P$$lng!ZscejDe_#n%jI%LabVZd=9nuv)qw>Nvz z7VC9|7>A6wyQEkc+qkTfMbUc1lbmF|ni<+<%6t4G7^zcWvF2lXwB9sjz}FJrG#KaM znYMo};Y#b;QUouAwXl~c*ruUCugUCya6lhjFb*WvCt1!N>Cp6may#l4ITR>_%38h-Z-@ucNq_Cni>W0b~;w?4qZu4or= zbUF>MtkFD?(lDHMTm)2xKJq2Q8CaOmTTJQnP=~TJg1*6IL}<4VVaeN6b--ID@P5Z8 z=!f6XwXPguL2D|tSJreqM<6_z-jrAdmn7XmlR;KWGM4SELDDX`T3|XBRDX#ym4i@* z#&PCcW*B{@uo(F5M^&e7qzUSY{v3ON(9*FKpEMN{Tc64-{e&(GE+?yWE-Q;abTjm@ zr@aTV)%DvNl?z3Afdfxu{B^pGn<+$-HVNw{kW{utJsk7hjXOC=8W6dS=ImE55Xm*z zD4f#LvN_L?VoVrjvDQ|b-#JzWhA=%*h8S-E1V^(z!MzTJ;kWn*!iRTwPh8;V7_?|w zNZ#d7=n?uCDb$gUw0YZt*kkRXb`wBOnXf!E)sm~G%_h$NzF3jY%Rkv znz*d5dpZdnV1`1XFn{SW2cubz$xQ}%IAx;ZGozY#2OuBE(ph2)^>%r#HOZss1N@j3 zf_{N33o{|MHVVMWE+riMs{#sa!$21ttMPo<5S~u+VXpUJGI2C$V-_*1%s&po4ZFM`=%1)cdH43)u4nZx&1#6#(L)vjD>~*T9X;%Vx`|_^nAi8ik0-~*kBgJ~T zH?b#W0BKf2A0GmA*Gp-n;JnvXC!#N|rkI+nL`3yEq-=8_Km~GXptCRi8OdSjEMK(6 zYzW`K|kg+DV{_v4ZhKwOW5c?-G-~@gmRwR(=O>(I@&Jur1D!I zVtJ^<8*ZFi z0tql|kNDmi4P&3JTBe+DwHf&wUsIeW^|{(2lz~!2~lwLMLX-f|j`gC}kg*JmAU_jBy$xxNAA$^<%JqJ(+KY+)mAp{9RfDjVGh=2$W>pbuR5fPcG=@4il zB2nJI1MmMw1dQ+9C#gq?=vQsvNhW>y|M~tt9DWwi-+=mL_4oPV&shQ5N0{@Jm;UPi z`H=^3{8R@3e=-6%*h@=@KUSXr{3G@Jf9MB)hYC=?qXCD1EHl{j9oB#N2YwEF-A9;t zOG$mB{>Og+KCcIW=y%Y&{$uZd2jU+3;_om6)6Y@&^7{jHzQg(HBl+y-{2|CkFM&S+ zj=!3Lefrq1$_#wq?*Q(9hyeT@^3{GE>w&**G=P1pT3`eH4ER4{f!QC&2;kpA0elGX z->Q%O4tnPw$4wyrPk{eBZUhO?f8+q%{}2KE$td9bk9c6ee`^5j-@$KvvH{%xKl}Fe z^6%Kd`||wWN9yUndIQl%C@}f|83BJi0Q~-|@%LM03Lv15{>Vr0kKf^e#z*1-```ll zCnEqpJYfG10Dosdry)|Oiy1WDWezr<4xGM4khS38;GYmqAmISKp0lrYmR0=nsRCOE B0fhhn literal 0 HcmV?d00001 diff --git a/tests/cli/streaming_data/simple_dataset_val/index.json b/tests/cli/streaming_data/simple_dataset_val/index.json new file mode 100644 index 00000000..c566c0de --- /dev/null +++ b/tests/cli/streaming_data/simple_dataset_val/index.json @@ -0,0 +1 @@ +{"shards": [{"column_encodings": ["ndarray", "ndarray"], "column_names": ["data", "target"], "column_sizes": [null, null], "compression": "zstd", "format": "mds", "hashes": [], "raw_data": {"basename": "shard.00000.mds", "bytes": 1670, "hashes": {}}, "samples": 10, "size_limit": 67108864, "version": 2, "zip_data": {"basename": "shard.00000.mds.zstd", "bytes": 1495, "hashes": {}}}], "version": 2} \ No newline at end of file diff --git a/tests/cli/streaming_data/simple_dataset_val/shard.00000.mds b/tests/cli/streaming_data/simple_dataset_val/shard.00000.mds new file mode 100644 index 0000000000000000000000000000000000000000..b82a2b928d15e64e5140a70ae56ccb4a787e0492 GIT binary patch literal 1670 zcmY+Cdo)#S@n(5LlTg)-Sul1W|-o z%#V+gAH>-zmV_(nBjZKFP>CWV3>PaB(V^n#FZGHlMYCW5U&vQj9G2K{9UUF)4Frgdc1^j@OInubB9?wIE>69aUv zlFlpjjOU^ujMe6$>=Q@_~M7a2)KataHPkah)O?yByAQBt4 zU#2aabx;=B30lzssCv;97iZAuZMP7439TTQ-wBjiBRwy`71@;$SYPO`Ow+-hO~Rv8 zXbavrlMu2EWOJTC_mb!IdrdQR?t20kc63Qg9XIm1)YotjMG9g7J<14y-YD0M^4;W@=#v6Qr+&hCZE4VNyS1wgoWpWGmLW zZsPTRY9praOG)eC0FHBY6f_oVtkKP$h4W@i026)_+;mxJm)8ziiznc{FAFQw`)KVd zF$@Mo;lPI$+PG{8Ldv(H8f}ej<&%_Y;FoCxGh@+dQ9d4+RfV_gq$JP7jMwTs%6Q#8 z0hjj32-U1jx@+GuE4n{3vVFP4|6D#=X|tF|uAV5Ti7w8i^P&+)yx36ddmauN)xfC1 zJrG8!pnCxyo5pmJbyS20)_J1D(;14Qov>+|0obM&fUIh+GELw@BF%Za2)hrhCv&gy zmj>tfQ;KXDt?X4R7?F%6(&ppJDs^A`5JUb3W#&$z7r0ao2hh2f-9dZun2Y^)vt-8X;2 zSZ)xxG06?}RtOc7N}v4Wvej3eFhM69LVOm0v&Kf4e_##jU2{PT2*YziYb?vC0fVF_ zI4YZq@s4~9y2-)Oz8p|JzC)R2L_dTv?lEfd^DUD4*=tFEib_K3E2m+&KbB-zMM^FC zbI62Xha+~m(<%mi`eQ_Fx-2Sx_=zfXy07FpxG| z`QJ!~vzXJHMrnyG84gb?!sRUmU^dbYXP$X5jhU~dcYa+7>{?fn1%DAQnL3oLvm-{P zm2lAVYjE*wSJ0@M&O|}aB>2rsAA7SLu(-w+W3KeUft^is;LaW>lvrU*@nM*~pNlTw Uhr5qSz}!UzcMjR%nyH%j4-2S?5&!@I literal 0 HcmV?d00001 diff --git a/tests/cli/streaming_data/simple_dataset_val/shard.00000.mds.zstd b/tests/cli/streaming_data/simple_dataset_val/shard.00000.mds.zstd new file mode 100644 index 0000000000000000000000000000000000000000..7688459def38afbf50af6510814fbe25d176e256 GIT binary patch literal 1495 zcmV;|1t|I`wJ-f(h6Qad0NPsU6+IwYs}5!W;ROKC6jV?FF@$ILL9uy|8fT~?wV)ql z92yxiQDx#N-!kY4qcb)+(h+UpLEgvoY zfC~+LO3+MR(_+d+HKO}wTRR}=H8(k2p7Yx1rn64=0QCAK5y&;v<`0V<#&efXxW5$YlIeParXy|E zM-GGOGLYTa$YVP(T?Q)$+XnDw+fX>R4VAO)Cl=d&a&5aOZQDImw#`Gswt1+wy#r|5 zHm7adKLKk8cpJub9u#TkY7<7fjwWOHCYR)vNh~l&Yn^-zBo;3yrS)?lJ|7~|dtyv} z)cyPpF}xiOqSBp?Gv?vCXS%sY(2WUAe5o0n-dW|>DFd}H6$0x`z^RFQSOe zvuV(hW5;65y{YRnIFUZhaIRPQ0obb^@a{O8M_&el=b>$Uol`E`xnJ=9ii`EL8;SkR zNwXUfZFwjVbEjV7SaO7zM?QsYr)XynwIX$+5gFlxV-#LO78qzKB8GG>m|V~F1LqEd zweDOFsCx%*9X@9Zb??A|a|6>#G(yLBQ=Gn6*Xe#_vb$vhahI&y{sJx8-yW)dsR++^ zDPXVHC=+D3=GqxMA)h#xf$@@*aJyOa=p1QrQ;nh#f3Q0DC+qIIv8 z}1#o^)WZndZSQ*`|vKa?^%J+PgLl_!;n~c6{l(XY1{Zv zYP9o^I8Wa8a}!S>ZiD02N- zdcOe}^xImy-nin-djsUXg04ubO0L_-!ef5Ua_1q8U_pmDqR2I`Ol_xXWSoN-#x8{I z@SDql1vrT1Kl zzrYd?CAX_1xVq?yOG&wQ0*nq-kL7E+U+%>??GK(zL8}4rh>Lmo(J3{V`e3@j7Y@nj zdFu1MdBF$nmx(sOH>ThDL_W0>M)%v}>^aOxcrBfo&ST~BE%=e>4>FK-Bw}^vgrq(c zUiGgMt{v>E*4Z4A=#+MP(xN02ct*`D7w&09bqwKmp+71BMSQ_<-R9 x3krZsL&TWeL(zm(M%iVxU Date: Tue, 10 Jun 2025 07:01:18 +0000 Subject: [PATCH 1814/1892] Merged PR 2379: [Trainer]Load auto-merged checkpoint when world size changes --- docs/source/trainer.rst | 2 + nnscaler/cli/arg_parser.py | 81 +++++++++++++++++++++---------- nnscaler/cli/trainer.py | 94 +++++++++++++++++++++++++++++++----- nnscaler/cli/trainer_args.py | 70 +++++++++++++++++++-------- tests/cli/test_arg_parser.py | 17 ++++++- tests/cli/test_trainer.py | 55 ++++++++++++++++++++- 6 files changed, 258 insertions(+), 61 deletions(-) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 310a75d2..8afa5041 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -441,6 +441,8 @@ Please Note: Signature: ``def after_optimizer_step(trainer: 'Trainer') -> None:`` * ``on_load_checkpoint`` (``str``): The hook function to be called after loading the checkpoint. If you saved something with ``on_save_checkpoint`` this is your chance to restore this. + Please note when checkpoints are merged, the custom data saved in the checkpoint + will be collected and saved as array in merged checkpoint. You must handle this case. Signature: ``def on_load_checkpoint(trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None:`` * ``on_save_checkpoint`` (``str``): The hook function to be called before saving the checkpoint. If you want to save something, you can add it to the checkpoint here. diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index e883de4e..27edb215 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -5,7 +5,7 @@ import copy from typing import List, Optional, Tuple, Dict, Any, Union -from dataclasses import dataclass, is_dataclass, asdict +from dataclasses import dataclass, field, is_dataclass, asdict import enum import ast @@ -20,6 +20,20 @@ _VALUE_TYPE_KEY = '__value_type' _VALUE_KEY = 'value' +# Keys for metadata in dataclass fields +# These keys are used to control the deserialization and normalization behavior + +# specify a custom deserialization function, +# the return value of this function will be assigned to the dataclass field directly. +# So it should return a value of the type specified in the dataclass field +DESERIALIZE_KEY = 'deserialize' +# specify a custom normalization function. +# The value returned by this function will be further deserialized with default deserialization logic. +NORMALIZE_KEY = 'normalize' +# if set to True, the field will be skipped during deserialization +# You can use `__post_init__` to handle the deserialization of the field. +SKIP_DESERIALIZATION_KEY = 'skip_deserialization' + def parse_args(argv: List[str]) -> dict: raw_args = {} @@ -62,28 +76,16 @@ def merge_args(args: dict, argv: List[str]): def _merge_args(args: dict, new_args: dict): - MISSING = object() - + """ + Note: values in new_args can only be dict or str or None. + """ def _is_removed_key(k): return isinstance(k, str) and k.endswith('!') - def _clear_keys(data): - # values in new_args can only be dict or str - if isinstance(data, dict): - new_data = {} - for k, v in data.items(): - if _is_removed_key(k): - continue - v = _clear_keys(v) - if v is not MISSING: - new_data[k] = v - return new_data if new_data else MISSING - else: - return data - for k, v in new_args.items(): if _is_removed_key(k): # if the key ends with '!', we will remove the key from args + args.pop(k, None) # a little trick to support merge self k = k[:-1] args.pop(k, None) continue @@ -91,9 +93,15 @@ def _clear_keys(data): # if the existing value is not a dict/list, we will overwrite it with the new value # for example, if args is {'a': 1} and new_args is {'a': {'b': 2}}, # we will overwrite args['a'] with new_args['a'] - new_v = _clear_keys(v) - if new_v is not MISSING: - args[k] = new_v + if isinstance(v, dict): + new_v = copy.deepcopy(v) + # merge self trick is here. + # directly assign v to args[k] will not work + # because v can have removed keys. + _merge_args(new_v, v) + args[k] = new_v # do we need to keep the empty dict? + else: + args[k] = v elif isinstance(args[k], dict): if isinstance(v, dict): _merge_args(args[k], v) @@ -105,13 +113,14 @@ def _clear_keys(data): if isinstance(v, dict) \ and all( isinstance(item, int) or - (isinstance(item, str) and item.isdigit()) + (isinstance(item, str) and item.isdigit()) or + (_is_removed_key(item) and item[:-1].isdigit()) for item in v.keys() ): - # note: you can't delete an item in a list by index (with '!' ending), - current_value = {idx: item for idx, item in enumerate(args[k])} - new_value = {int(idx): item for idx, item in v.items()} + current_value = {str(idx): item for idx, item in enumerate(args[k])} + new_value = {str(idx): item for idx, item in v.items()} _merge_args(current_value, new_value) + current_value = {int(k): v for k, v in current_value.items()} args[k] = [None] * (max(current_value.keys()) + 1) for nk, nv in current_value.items(): args[k][nk] = nv @@ -230,6 +239,7 @@ class _TypeInfo: key_type: Any = None value_type: Any = None item_type: Any = None + metadata: dict = field(default_factory=dict) def _get_type_info_from_annotation(type_info): @@ -267,9 +277,16 @@ def _get_type_info_from_annotation(type_info): def _get_type_info(dataclass_type) -> Dict[str, _TypeInfo]: if not is_dataclass(dataclass_type): raise ValueError(f"{dataclass_type} is not a dataclass") - type_dict = {} + type_dict: dict[str, _TypeInfo] = {} for k, v in dataclass_type.__dataclass_fields__.items(): - type_dict[k] = _get_type_info_from_annotation(v.type) + if v.metadata.get(SKIP_DESERIALIZATION_KEY, False) or DESERIALIZE_KEY in v.metadata: + # if the field is marked as skip_deserialization, + # or if it has a custom deserialize function, + # we don't need to extract the type information + type_dict[k] = _TypeInfo(type=None) + else: + type_dict[k] = _get_type_info_from_annotation(v.type) + type_dict[k].metadata = v.metadata return type_dict @@ -360,8 +377,20 @@ def deserialize_dataclass(value, value_type): for key, ti in type_info.items(): if not key in value: continue + used_keys.add(key) + v = value[key] + + if deserialize_func := ti.metadata.get(DESERIALIZE_KEY, None): + v = deserialize_func(v) + member_values[key] = v + continue + + if normalize_func := ti.metadata.get(NORMALIZE_KEY, None): + v = normalize_func(v) + # will continue to process the value + if ti.type is bool and v is None: v = True # set bool to True if it shows up in cmd line if v is None: diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 2e201a2a..6c6689c7 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -39,6 +39,7 @@ CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/{rank}.ckpt' CHECKPOINT_LAST_DIR_NAME: str = 'last' CHECKPOINT_BEST_DIR_NAME: str = 'best' +CHECKPOINT_MERGED_FILE_NAME: str = 'merged.ckpt' CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}.ckpt' CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}.ckpt' @@ -83,11 +84,16 @@ def __init__(self, self.train_args = TrainerArgs.from_cli(cli_args) self.rank = None + self.world_size = None + self.local_world_size = None + self.local_rank = None + self.node_rank = None self.sync_group = None self.model = None self.optimizer = None self.dataset = {'train': None, 'val': None, 'test': None} self.dataloader: Dict[str, Optional[DataLoader]] = {'train': None, 'val': None, 'test': None} + self.dataloader_resumed = False # whether the dataloader is resumed from checkpoint self.lr_scheduler = None self.train_status = TrainStatus() self.dummy_input = None @@ -152,6 +158,9 @@ def _setup(self): torch.distributed.barrier() self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() + self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) + self.local_rank = int(os.environ.get('LOCAL_RANK')) + self.node_rank = int(os.environ.get('GROUP_RANK')) self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq if len(self.dataloader['train']) % self.train_args.update_freq != 0: @@ -201,13 +210,11 @@ def reducer_pre_hook(reducer, grad): self.hook.after_setup(self) @classmethod - def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): + def _merge_checkpoint(cls, checkpoint_files: List[str]): state_dicts = [torch.load(f, map_location='cpu', weights_only=False) for f in checkpoint_files] for i in range(1, len(state_dicts)): if state_dicts[i]['train_args'] != state_dicts[0]['train_args']: raise ValueError(f"train_args in {checkpoint_files[i]} is different from {checkpoint_files[0]}") - if state_dicts[i].get('dataloader', None) != state_dicts[0].get('dataloader', None): - raise ValueError(f"dataloader state in {checkpoint_files[i]} is different from {checkpoint_files[0]}") if state_dicts[i].get('lr_scheduler', None) != state_dicts[0].get('lr_scheduler', None): raise ValueError(f"lr_scheduler state in {checkpoint_files[i]} is different from {checkpoint_files[0]}") @@ -217,16 +224,35 @@ def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): ) train_args = copy.deepcopy(state_dicts[0]['train_args']) train_args['checkpoint']['save_type'] = 'merged' + + global_keys = { + 'model', 'optimizer', 'train_args', + 'train_status', 'lr_scheduler', 'rank' + } + # for extra keys (including `dataloader` and `rng_states`), we will not merge them. + # Intead we will collect them from all state_dicts + extra_keys: Dict[str, list] = {} + for s in state_dicts: + extra_keys.update({k: [] for k in s.keys() if k not in global_keys}) + if extra_keys: + sorted_state_dicts = sorted(state_dicts, key=lambda x: x['rank']) + for s in sorted_state_dicts: + for k in extra_keys: + extra_keys[k].append(s.get(k, None)) + merged_state_dict = { 'model': module_state_dict, 'optimizer': opt_state_dict, 'lr_scheduler': state_dicts[0].get('lr_scheduler', None), 'train_status': state_dicts[0]['train_status'], 'train_args': train_args, - 'rng_states': None, - # assume the dataloader state is the same for all checkpoints - 'dataloader': state_dicts[0].get('dataloader', None) + **extra_keys, } + return merged_state_dict + + @classmethod + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): + merged_state_dict = cls._merge_checkpoint(checkpoint_files) torch.save(merged_state_dict, output_file) def _log_finalize(self): @@ -249,9 +275,32 @@ def _load_checkpoint(self): logger.info(f"Resuming from {resume_from}") if resume_from.is_file(): resume_from = resume_from # when we load from merged checkpoint + state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) else: - resume_from = resume_from / f'{self.rank}.ckpt' - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + ckpt_files = list(resume_from.glob('*.ckpt')) + rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} + if set(rank_ckpt_files.keys()) != set(range(len(rank_ckpt_files))): + raise ValueError(f"Checkpoint files in {resume_from} are not complete: {rank_ckpt_files.keys()}") + if len(rank_ckpt_files) != self.world_size \ + and self.train_args.checkpoint.resume_from.with_merged is False: + raise ValueError(f"World size is different with original one: {len(rank_ckpt_files)} != {self.world_size}") + + if len(rank_ckpt_files) != self.world_size or self.train_args.checkpoint.resume_from.with_merged: + # merge the checkpoint files from all ranks and broadcast to all ranks + torch.distributed.barrier() + if self.rank == 0: + logger.info(f"Merging checkpoint files from {resume_from}") + state_dict = self._merge_checkpoint(list(rank_ckpt_files.values())) + else: + state_dict = None + state_dict_list = [state_dict] + logger.info(f"Broadcasting merged checkpoint to all ranks.") + torch.distributed.broadcast_object_list(state_dict_list, src=0) + state_dict = state_dict_list[0] + else: + resume_from = resume_from / f'{self.rank}.ckpt' + state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + self.hook.on_load_checkpoint(self, state_dict) ckpt_save_type = state_dict['train_args']['checkpoint']['save_type'] @@ -281,9 +330,28 @@ def _load_checkpoint(self): if 'dataloader' in state_dict and state_dict['dataloader'] is not None: if not self._is_resumable_dataloader(): raise ValueError("dataloader is not resumable, but checkpoint contains dataloader state") - self.dataloader['train'].load_state_dict(state_dict['dataloader']) + if ckpt_save_type == 'merged': + dataloader_states = state_dict['dataloader'] + # only load dataloader state when all ranks have the same state + # TODO: is this reasonable? + if all(dataloader_states[i] == dataloader_states[0] for i in range(1, len(dataloader_states))): + self.dataloader['train'].load_state_dict(dataloader_states[0]) + self.dataloader_resumed = True + else: + logger.warning("Dataloader states are not the same across ranks, will use dry run to resume dataloader state.") + self.dataloader_resumed = False + else: + self.dataloader['train'].load_state_dict(state_dict['dataloader']) + self.dataloader_resumed = True + else: + self.dataloader_resumed = False self.train_status = TrainStatus(**state_dict['train_status']) - self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() + + # we don't resume rng states when loading merged checkpoint, + if ckpt_save_type != 'merged': + self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() + else: + logger.warning("RNG states are not resumed when loading merged checkpoint.") def _log_mem_stats(self, tag=None): # log minimum free memory over the iteration @@ -359,6 +427,7 @@ def _save_checkpoint(self, loss): 'train_status': asdict(self.train_status), 'train_args': self.train_args.to_dict(), 'rng_states': self._get_rng_states(), + 'rank': self.rank, } if self._is_resumable_dataloader(): state_dict['dataloader'] = self.dataloader['train'].state_dict() @@ -410,9 +479,8 @@ def _save_checkpoint(self, loss): torch.distributed.barrier() # remove old checkpoints - local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) # only the first rank in the group will do the job - if self.rank % local_world_size == 0: + if self.rank % self.local_world_size == 0: try: self._expire_checkpoints() except Exception as e: @@ -456,7 +524,7 @@ def _expire_checkpoints(self): def _global_batch_iterator(self, num_skip_first=0, stage='train'): if stage == 'train': - if self._is_resumable_dataloader() or num_skip_first == 0: + if self.dataloader_resumed or num_skip_first == 0: # if the checkpoint stops at the end of an epoch, # the rng states must be resumed before creating iterator # because `DataLoader.__iter__()` uses the rng (dunno why), diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 12117634..a65901b7 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -348,6 +348,17 @@ def __post_init__(self): raise ValueError(f"Invalid interval {self.interval}") +@dataclass +class ResumeOptions: + checkpoint: str = 'last' + # whether to merge the checkpoint files + # Only used when `checkpoint` is a directory. + # `True` means will load the merged checkpoint (without saving) + # `False` means will load the sharded checkpoint files + # `None` means will load the sharded checkpoint files if the world size is not changed. + # and will load merged checkpoint if the world size is changed. + with_merged: Optional[bool] = None + @dataclass class CheckpointConfig: save_dir: str = './checkpoints' @@ -374,27 +385,33 @@ class CheckpointConfig: # resume training from a checkpoint folder/file # can be 'last'/'best'/a specific folder/file # we will not resume if resume_from is last or best but the corresponding checkpoint does not exist - resume_from: str = None + resume_from: Optional[ResumeOptions] = field(default=None, metadata={ + 'normalize': lambda x: {'checkpoint': x} if isinstance(x, str) else x + }) def get_resume_checkpoint_dir(self) -> Optional[Path]: if not self.resume_from: return None - if self.resume_from in ['last', 'best']: - d = Path(self.save_dir) / self.resume_from + if self.resume_from.checkpoint in ['last', 'best']: + d = Path(self.save_dir) / self.resume_from.checkpoint if not d.exists(): return None return d - return Path(self.resume_from) + return Path(self.resume_from.checkpoint) def __post_init__(self): + if isinstance(self.resume_from, str): + self.resume_from = ResumeOptions(checkpoint=self.resume_from) + elif isinstance(self.resume_from, dict): + self.resume_from = deserialize_dataclass(self.resume_from, ResumeOptions) if self.resume_from: - if self.resume_from in ['last', 'best']: + if self.resume_from.checkpoint in ['last', 'best']: if not self.save_dir: raise ValueError("save_dir is required when resume_from is 'last'/'best'") - if not (Path(self.save_dir) / self.resume_from).exists(): - logger.warning(f"`{self.resume_from}` checkpoint does not exist. Will train from scratch.") - elif not Path(self.resume_from).exists(): - raise ValueError(f"resume_from {self.resume_from} does not exist") + if not (Path(self.save_dir) / self.resume_from.checkpoint).exists(): + logger.warning(f"`{self.resume_from.checkpoint}` checkpoint does not exist. Will train from scratch.") + elif not Path(self.resume_from.checkpoint).exists(): + raise ValueError(f"resume_from {self.resume_from.checkpoint} does not exist") if self.no_save: return @@ -481,6 +498,18 @@ def __init__(self, hook_config: HookMapConfig): setattr(self, k, load_type(v)) +def _deserialize_hook_config(hook) -> Union[HookConfig, HookMapConfig]: + if isinstance(hook, dict): + if 'type' in hook: + return deserialize_dataclass(hook, HookConfig) + else: + # treat hook map as a dict. this is for backward compatibility + # don't use `deserialize_dataclass` here + # because hooks can be functions (not str) + return HookMapConfig(**hook) + raise ValueError(f"Invalid hook config {hook}.") + + @dataclass class TrainerArgs(PrecisionMixin, PolicyMixin): compute_config: ComputeConfig = None @@ -509,12 +538,16 @@ class TrainerArgs(PrecisionMixin, PolicyMixin): checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) log: List[LogConfig] = field(default_factory=list) # It can be `HookConfig` or `HookMapConfig` - hook: Union[HookConfig, HookMapConfig, None] = None + hook: Union[HookConfig, HookMapConfig, None] = field(default=None, metadata={ + 'deserialize': _deserialize_hook_config + }) debug: DebugConfig = field(default_factory=DebugConfig) - # TODO: mixed precision support - precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None + # None value will be resolved in __post_init__ + precision: Dict[_TENSOR_TYPE, _PRECISION_TYPE] = field(default=None, metadata={ + 'skip_deserialization': True, + }) micro_batch_size: int = 1 # You can set one of `global_batch_size` and `grad_accumulation_steps` option. @@ -604,6 +637,11 @@ def __post_init__(self): if self.lr_scheduler and not self.lr_scheduler.type: raise ValueError("lr_scheduler type is required") + if isinstance(self.hook, dict): + # if it is a dict, we will deserialize it to HookMapConfig + # This is for backward compatibility + self.hook = _deserialize_hook_config(self.hook) + if self.seed is None and self.init_env_fn is None: logger.warning( "Neither `seed` nor `init_env_fn` is not provided. " @@ -786,13 +824,7 @@ def create_hook(self) -> TrainHook: if not self.hook: return TrainHook() # empty hook - if isinstance(self.hook, dict): - if 'type' in self.hook: - hook_config = HookConfig(**self.hook) - else: - hook_config = HookMapConfig(**self.hook) - else: - hook_config = self.hook + hook_config = self.hook if isinstance(hook_config, HookConfig): kwargs = self.create_kwarg(hook_config.args) diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 5839f8ec..fff9bcbd 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -257,14 +257,17 @@ class A: z0 = copy.deepcopy(z) z1 = copy.deepcopy(z) z2 = copy.deepcopy(z) + z3 = copy.deepcopy(z) + z4 = copy.deepcopy(z) merge_args(z, ['--a.2=3','--b.0.h=3', '--b.1.h!', '--b.2.h!']) assert z == { 'a': [1, 2, '3'], - 'b': [{'h': '3'}, {}], + 'b': [{'h': '3'}, {}, {}], } merge_args(z0, ['--a.2=3','--b!', '--b.2.h!']) assert z0 == { 'a': [1, 2, '3'], + 'b': {'2': {}} } merge_args(z1, ['--a.2=3','--b!', '--b.2.h=3']) assert z1 == { @@ -280,6 +283,18 @@ class A: 'b': [None, {'h': '3'}], } + merge_args(z3, ['--a.2=3','--b.0.h=3', '--b.1.h!', '--b.2.h=3', '--b.2.h!', '--b.2.g=3']) + assert z3 == { + 'a': [1, 2, '3'], + 'b': [{'h': '3'}, {}, {'g': '3'}], + } + + merge_args(z4, ['--a.2=3','--b.1!', '--b.2!', '--b.3.h=3', '--b.4!']) + assert z4 == { + 'a': [1, 2, '3'], + 'b': [None, None, None, {'h': '3'}], + } + def test_resolve_args(): import os diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index a4cd4439..10443430 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -9,10 +9,10 @@ import torch.distributed from nnscaler import merge_state_dicts -from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer import Trainer, logger from nnscaler.cli.trainer_args import AggregatedOutputs, TrainerArgs from tests.parallel_module.common import assert_equal, assert_close -from tests.utils import init_random, replace_all_device_with, clear_parallel_cache +from tests.utils import catch_log, init_random, replace_all_device_with, clear_parallel_cache from ..launch_torchrun import launch_torchrun from .common import MixedModule, MixModuleMLP, MixModuleMLP3 @@ -825,6 +825,7 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.keep_last_n_checkpoints', 30, ]) trainer.run() + assert not trainer.dataloader_resumed # train 4 epcho in one time ckpt0_savedir = save_dir / 'ckpt0' @@ -842,6 +843,7 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.run() + assert not trainer.dataloader_resumed torch.distributed.barrier() # train 4 epcho two times (resume from last) @@ -860,6 +862,7 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.run() + assert not trainer.dataloader_resumed ckpt1_files0 = {f: f.stat().st_mtime_ns for f in ckpt1_savedir.glob('**/*.ckpt')} # resume from last without update max_epochs @@ -876,6 +879,7 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.run() + assert trainer.dataloader_resumed ckpt1_files0_x = {f: f.stat().st_mtime_ns for f in ckpt1_savedir.glob('**/*.ckpt')} # nothing should be updated in this case. assert ckpt1_files0 == ckpt1_files0_x @@ -885,6 +889,12 @@ def trainer_resumable_dataloader(save_dir): ckpt2_savedir.mkdir(parents=True, exist_ok=True) if trainer.rank == 0: Trainer.merge_checkpoint(list((ckpt1_savedir / 'last').glob('*.ckpt')), ckpt2_savedir / 'merged.pt') + merged_state_dict = torch.load(ckpt2_savedir / 'merged.pt') + assert 'dataloader' in merged_state_dict + assert isinstance(merged_state_dict['dataloader'], list) + assert len(merged_state_dict['dataloader']) == merged_state_dict['train_args']['compute_config']['runtime_ngpus'] + merged_state_dict.pop('dataloader') # remove the merged_state_dict from the cache + torch.save(merged_state_dict, ckpt2_savedir / 'merged2.pt') torch.distributed.barrier() # resume for sharded/deduped checkpoint @@ -900,6 +910,7 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.run() + assert trainer.dataloader_resumed torch.distributed.barrier() @@ -916,20 +927,60 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.run() + assert trainer.dataloader_resumed torch.distributed.barrier() + # resume for merged without dataloader states + ckpt3_savedir = save_dir / 'ckpt3' + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', str(ckpt2_savedir / 'merged2.pt'), + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert not trainer.dataloader_resumed + + # resume for auto-merged checkpoint + ckpt4_savedir = save_dir / 'ckpt4' + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt4_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt1_savedir / '0002-0035'), + '--checkpoint.resume_from.with_merged', True, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting merged checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states + if torch.distributed.get_rank() == 0: for i in range(4): g = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) z = torch.load(ckpt2_savedir / 'last' / f'{i}.ckpt', weights_only=False) + w = torch.load(ckpt3_savedir / 'last' / f'{i}.ckpt', weights_only=False) + v = torch.load(ckpt4_savedir / 'last' / f'{i}.ckpt', weights_only=False) assert 'dataloader' not in g assert 'dataloader' in x for key in ['model', 'optimizer', 'lr_scheduler', 'dataloader']: assert_equal(x[key], y[key]) assert_equal(x[key], z[key]) + assert_equal(x[key], w[key]) + assert_equal(x[key], v[key]) if key != 'dataloader': assert_equal(g[key], x[key]) From ac9bb543d86d4226b83676705afa48546b19c99a Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 18 Jun 2025 05:54:57 +0000 Subject: [PATCH 1815/1892] Merged PR 2377: [AutoDist] fix bug: calculate split info correctly --- nnscaler/autodist/apis.py | 113 +++++++++-------- tests/autodist/pas/activation_pp.json | 119 ++++++++++++++++++ .../autodist/pas/test_multiref_activation.py | 53 ++++++++ 3 files changed, 233 insertions(+), 52 deletions(-) create mode 100644 tests/autodist/pas/activation_pp.json diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 4bf4ea14..3924699c 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -138,46 +138,7 @@ def parallelize_graph(graph: IRGraph, nodes = [cid2node[cid] for cid in group] graph.recompute(nodes) - def subtensor_desc(t): - return (t.indmap, t.grad is not None) - tensor_split_info = defaultdict(dict) - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): - continue - consumers = graph.consumers(ftensor) - if not consumers: - continue - for consumer in consumers: - find_desc = False - for stage_idx, stage_desc in enumerate(pp_desc.spmd_descs): - if consumer.cid not in stage_desc.partition_descs: - continue - find_desc = True - node_desc = stage_desc.partition_descs[consumer.cid].desc - if len(node_desc) != 1: - raise RuntimeError(f'node {consumer} is partitioned along multiple dims') - - (p_idx, p_dim), p_num = node_desc[0] - if p_idx == -1: - partitioned_node = consumer - else: - partitioned_nodes = consumer.algorithm('dim').instantiate(idx=p_idx, dim=p_dim, num=p_num) - if partitioned_nodes is None: - raise RuntimeError(f'node {consumer} cannot be partitioned by {p_idx}-{p_dim}-{p_num}') - partitioned_node = partitioned_nodes[0] - - if stage_idx not in tensor_split_info[ftensor]: - tensor_split_info[ftensor][stage_idx] = set() - for input in partitioned_node.inputs(): - if isinstance(input, IRSubTensor) and input.parent == ftensor: - if p_idx == -1 and stage_desc.mesh_desc.ngpus > 1: - tensor_split_info[ftensor][stage_idx].add(('REPLICATED', subtensor_desc(input))) - else: - # special case: if the stage has only one gpu, we treat it as partitioned - tensor_split_info[ftensor][stage_idx].add(('PARTITIONED', subtensor_desc(input))) - break - assert find_desc, f'node {consumer} not found in any stage' - + tensor_split_info = collect_tensor_split_info(graph, pp_desc) # add multiref for shared parameters across stages # note that we have constrained that shared parameters cannot be partitioned in SPMDSolver, other input tensors # belonging to the same operator can be partitioned. For example, in some LLMs, the embedding matrix is shared @@ -212,6 +173,9 @@ def subtensor_desc(t): graph.staging([s[0] for s in stages]) stages = graph.select(ntype=IRSegment, flatten=False) stages = [s for s in stages if s.isfw()] + # update tensor_split_info since the graph has been transformed + for stage in stages: + tensor_split_info.update(collect_tensor_split_info(stage, pp_desc)) else: stages = [graph] @@ -245,26 +209,18 @@ def subtensor_desc(t): offset += cur_ngpus for node in stage.nodes(): if isinstance(node, IRFwOperation): - if isinstance( - node, - (IRGraphAnchor, IRPyFunc)) or node.name == 'multiref': + if isinstance(node, (IRGraphAnchor, IRPyFunc)) or node.name == 'multiref': continue if node.cid in spmd_desc.partition_descs: p_desc = spmd_desc.partition_descs[node.cid] partition_node(node, graph, dev, p_desc) if isinstance(node, IRDimops): - _logger.debug( - f'apply {node} with {node.anno} at {node.comment}, plan: {p_desc}' - ) + _logger.debug(f'apply {node} with {node.anno} at {node.comment}, plan: {p_desc}') else: - _logger.debug( - f'replicate non-IRDimops {node.signature} with {node.comment}' - ) + _logger.debug(f'replicate non-IRDimops {node.signature} with {node.comment}') else: replica(graph, node, dev) - _logger.debug( - f'NOT included in plan, replicate {node.signature} with {node.comment}' - ) + _logger.debug(f'NOT included in plan, replicate {node.signature} with {node.comment}') for dl in graph.select(ntype=IRDataOperation): replica(graph, dl, devs=list(range(autodist_config.mesh_desc.ngpus))) @@ -278,3 +234,56 @@ def subtensor_desc(t): ) return graph + + +def subtensor_desc(t): + return (t.indmap, t.grad is not None) + + +def collect_tensor_split_info(graph: IRGraph, pp_desc: PipelineSearchOutput): + """ + Collect information about how tensors are split across stages in the pipeline parallelism. + This function populates the `tensor_split_info` dictionary with details about each tensor's partitioning + across different stages, including whether they are replicated or partitioned. + """ + + tensor_split_info = defaultdict(dict) + for ftensor in graph.full_tensors(): + if ftensor.is_grad(): + continue + consumers = graph.consumers(ftensor) + if not consumers: + continue + for consumer in consumers: + find_desc = False + for stage_idx, stage_desc in enumerate(pp_desc.spmd_descs): + if consumer.cid not in stage_desc.partition_descs: + continue + find_desc = True + node_desc = stage_desc.partition_descs[consumer.cid].desc + if len(node_desc) != 1: + raise RuntimeError(f'node {consumer} is partitioned along multiple dims') + + (p_idx, p_dim), p_num = node_desc[0] + if p_idx == -1: + partitioned_node = consumer + else: + partitioned_nodes = consumer.algorithm('dim').instantiate(idx=p_idx, dim=p_dim, num=p_num) + if partitioned_nodes is None: + raise RuntimeError(f'node {consumer} cannot be partitioned by {p_idx}-{p_dim}-{p_num}') + partitioned_node = partitioned_nodes[0] + + if stage_idx not in tensor_split_info[ftensor]: + tensor_split_info[ftensor][stage_idx] = set() + for input in partitioned_node.inputs(): + if isinstance(input, IRSubTensor) and input.parent == ftensor: + if p_idx == -1 and stage_desc.mesh_desc.ngpus > 1: + tensor_split_info[ftensor][stage_idx].add(('REPLICATED', subtensor_desc(input))) + else: + # special case: if the stage has only one gpu, we treat it as partitioned + tensor_split_info[ftensor][stage_idx].add(('PARTITIONED', subtensor_desc(input))) + break + # operator inserted by nnscaler + if consumer.name not in ('multiref', 'identity'): + assert find_desc, f'node {consumer} not found in any stage' + return tensor_split_info diff --git a/tests/autodist/pas/activation_pp.json b/tests/autodist/pas/activation_pp.json new file mode 100644 index 00000000..3d2ee73c --- /dev/null +++ b/tests/autodist/pas/activation_pp.json @@ -0,0 +1,119 @@ +{ + "desc": { + "spmd_descs": [ + { + "partition_descs": [ + [ + 1, + [ + [ + [ + 0, + 0 + ], + 2 + ] + ] + ], + [ + 2, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } + }, + { + "partition_descs": [ + [ + 3, + [ + [ + [ + 0, + 0 + ], + 2 + ] + ] + ], + [ + 4, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ], + [ + 5, + [ + [ + [ + -1, + -1 + ], + 2 + ] + ] + ] + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 2 + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } + } + ], + "recompute_groups": [], + "mesh_desc": [ + 1, + 4 + ] + }, + "e2e_time": 0.0, + "stage_mems": [ + 0.0, 0.0 + ], + "stage_all_times": [ + 0.0, 0.0 + ], + "stage_comp_times": [ + 0.0, 0.0 + ] +} diff --git a/tests/autodist/pas/test_multiref_activation.py b/tests/autodist/pas/test_multiref_activation.py index 79ba2b87..cdb8f139 100644 --- a/tests/autodist/pas/test_multiref_activation.py +++ b/tests/autodist/pas/test_multiref_activation.py @@ -164,3 +164,56 @@ def test_diff_partition_2(): # linear_34 = nnscaler.runtime.adapter.nn.identity_allreduce(linear_34, ranks=[0, 1]) # linear_105, linear_109 = nnscaler.runtime.function.multiref(linear_34, times=2) assert len(_gencode_contains(tempdir, ModelB, 0, 'nnscaler.runtime.adapter.nn.identity_allreduce')) == 1 + + +class Layer(torch.nn.Module): + + def __init__(self): + super(Layer, self).__init__() + self.fc = torch.nn.Linear(10, 10, bias=False) + + def forward(self, x): + y = self.fc(x) + x = x + y + return x + + +class Decoder(torch.nn.Module): + + def __init__(self): + super(Decoder, self).__init__() + self.layer1 = Layer() + self.layer2 = Layer() + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = x.sum() + return x + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_activation_pp(): + m = Decoder() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 10], dtype=torch.float32, device=torch.cuda.current_device()) + + with tempfile.TemporaryDirectory() as tempdir: + pas_cfg = { + 'load_plan_path': Path(__file__).parent / 'activation_pp.json', + 'pipeline_nstages': 2, + 'pipeline_pivots': 'Layer', + } + parallelize( + m, + {'x': trace_data}, + 'autodist', + ComputeConfig(4, 4, use_end2end=True, pas_config=pas_cfg), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) + assert True, "should not raise any exception" From 344083cc789d9397859889bcda1a2a3eb28eaa8f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 21 Jul 2025 02:36:30 +0000 Subject: [PATCH 1816/1892] Merged PR 2380: add Fairseq example --- azure-pipelines.yml | 4 +- nnscaler/__init__.py | 17 ++- nnscaler/cli/arg_parser.py | 67 ++++++--- nnscaler/cli/mixed_module.py | 40 +++++- nnscaler/cli/train.py | 5 +- nnscaler/cli/train_hook.py | 22 +++ nnscaler/cli/trainer.py | 125 +++++++++++------ nnscaler/cli/trainer_args.py | 190 ++++++++++++++++++-------- nnscaler/graph/function/dimops.py | 4 +- nnscaler/graph/function/function.py | 8 ++ nnscaler/graph/function/pyfunc.py | 5 +- nnscaler/graph/parser/mapping.py | 1 + nnscaler/graph/parser/parser.py | 8 +- nnscaler/graph/parser/register.py | 9 +- nnscaler/ir/cten.py | 4 + nnscaler/ir/operator.py | 12 +- nnscaler/parallel.py | 100 +++++++++++--- nnscaler/profiler/database.py | 3 +- nnscaler/runtime/device.py | 18 +++ nnscaler/runtime/f16_optimizer.py | 12 ++ nnscaler/runtime/module.py | 27 +++- nnscaler/utils.py | 38 ++++++ requirements.txt | 2 +- tests/cli/test_arg_parser.py | 10 +- tests/cli/test_resume_seed.py | 2 +- tests/cli/test_trainer.py | 21 +++ tests/cli/trainer_args.yaml | 13 +- tests/graph/parser/test_register.py | 23 ++++ tests/parallel_module/test_gencode.py | 31 +++++ 29 files changed, 657 insertions(+), 164 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index b1769e99..f8a2f6de 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -8,9 +8,11 @@ trigger: pool: vmImage: ubuntu-latest - + steps: - script: | + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main; + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r; pip install tox pip install tox-conda displayName: 'Install tox' diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index b6af82e7..2bf5867a 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -45,17 +45,30 @@ def init(): None """ from nnscaler import runtime - _ = runtime.device.DeviceGroup() + runtime.device.init_device() _ = runtime.resource.EnvResource() +def uninit(): + """ + Uninitialize the nnscaler library. + + It will destroy the torch distributed nccl process_group + + Returns: + None + """ + from nnscaler.runtime.device import uninit_device + uninit_device() + + def _check_torch_version(): import torch import logging torch_version = str(torch.__version__).split('+')[0] torch_version = tuple(int(v) for v in torch_version.split('.')[:2]) if torch_version < (2, 0): - logging.warn(f"expected PyTorch version >= 2.0 but got {torch_version}") + logging.warning(f"expected PyTorch version >= 2.0 but got {torch_version}") _check_torch_version() diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index 27edb215..f5caab0d 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field, is_dataclass, asdict import enum import ast +import regex try: @@ -35,6 +36,10 @@ SKIP_DESERIALIZATION_KEY = 'skip_deserialization' +class _KeyNotFoundError(KeyError): + pass + + def parse_args(argv: List[str]) -> dict: raw_args = {} last_key = None @@ -134,17 +139,47 @@ def resolve_args(args: dict): For example, if args is {'a': '$(b)', 'b': 'c'}, then it will be updated to {'a': 'c', 'b': 'c'}. """ + pattern = r'(\$\{[^}]+\}|\$\([^)]+\))' + def _is_variable(var_path): return isinstance(var_path, str) and ( (var_path.startswith('$(') and var_path.endswith(')')) or (var_path.startswith('${') and var_path.endswith('}')) ) - def _get_variable(var_path: str) -> Optional[str]: + def _get_variable(var_path: Any) -> Optional[str]: if not _is_variable(var_path): return None return var_path[2:-1] + def _get_variables(var_path: str) -> List[str]: + """ + Get all variables in the var_path. + For example, if var_path is 'a$(a.b.c)b$(c.d)c', it will return ['a.b.c', 'c.d']. + """ + # use regex to find all variables in the var_path + matches = regex.findall(pattern, var_path) + return [_get_variable(m) for m in matches] + + def _resolve_variables(var_path: Any, resolved_vars: dict[str, str]) -> str | Any: + """ + Resolve all variables in the var_path by replacing them with their values. + For example, if var_path is 'a$(b.c)d$(e.f)g', and resolved_vars is {'b.c': 'x', 'e.f': 'y'}, + it will return 'axdyg'. + """ + # special case, this will keep the type of the variable + if _is_variable(var_path): + return resolved_vars[_get_variable(var_path)] + + # always return a string + var_path = regex.sub( + pattern, + lambda m: str(resolved_vars[_get_variable(m.group(0))]), + var_path + ) + var_path = var_path.replace(r'$\(', '$(').replace(r'$\{', '${') # escape the variable syntax + return var_path + def _get_value(data, var_path: list[Any]): for key in var_path: if isinstance(data, list): @@ -152,7 +187,7 @@ def _get_value(data, var_path: list[Any]): elif key in data: data = data[key] else: - raise ValueError(f"{var_path} not found in args") + raise _KeyNotFoundError(f"{var_path} not found in args") return data def _set_value(data, var_path: list[Any], value): @@ -163,7 +198,7 @@ def _set_value(data, var_path: list[Any], value): elif key in data: data = data[key] else: - raise ValueError(f"{var_path} not found in args") + raise _KeyNotFoundError(f"{var_path} not found in args") if isinstance(data, list): data[int(var_path[-1])] = value @@ -180,28 +215,28 @@ def _resolve(var_path: list[Any], value: Any): for i, v in enumerate(value): _resolve(var_path + [i], v) return value - else: - ref_key = _get_variable(value) - if ref_key: + elif isinstance(value, str): + ref_keys = _get_variables(value) + ref_values = {} + for ref_key in ref_keys: if ref_key in pending_values: raise ValueError(f"Circular reference detected for {ref_key}") pending_values.add(ref_key) ref_var_path = ref_key.split('.') try: - value = _get_value(args, ref_var_path) - resolved_value = _resolve(ref_var_path, value) - except ValueError as e: + raw_ref_value = _get_value(args, ref_var_path) + ref_values[ref_key] = _resolve(ref_var_path, raw_ref_value) + except _KeyNotFoundError as e: if ref_key in os.environ: - resolved_value = os.environ[ref_key] + ref_values[ref_key] = os.environ[ref_key] else: raise - - _set_value(args, var_path, resolved_value) pending_values.remove(ref_key) - return resolved_value - else: - return value - + resolved_value = _resolve_variables(value, ref_values) + _set_value(args, var_path, resolved_value) + return resolved_value + else: + return value _resolve([], args) diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index 89741ff0..fd2b4601 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -4,6 +4,8 @@ from dataclasses import asdict, replace import inspect import copy +import logging +from functools import partial import nnscaler from nnscaler.runtime.adapter.reducer import Reducer @@ -16,6 +18,9 @@ ) +logger = logging.getLogger(__name__) + + def fork_rng(): if torch.distributed.is_initialized(): rank = torch.distributed.get_rank() @@ -142,9 +147,9 @@ def precision(self): else self.trainer_args.precision ) - def create_model(self) -> torch.nn.Module: + def create_model(self, module_args: Optional[tuple[tuple, dict]]=None) -> torch.nn.Module: model = ( - self.parallel_module.create_model(self.trainer_args) + self.parallel_module.create_model(self.trainer_args, module_args) if self.parallel_module else self.trainer_args.create_model() ) @@ -184,13 +189,17 @@ def resolve_compute_config(self): } return compute_config - def parallelize(self, dummy_input: Optional[dict[str, Any]] = None, *, load_module: bool = True): + def parallelize(self, + dummy_input: Optional[dict[str, Any]] = None, *, + load_module: bool = True, + module_args: Optional[tuple[tuple, dict]] = None + ): pmodel_class = nnscaler.parallelize( self.model_type, self.create_dummy_forward_args(dummy_input), self.resolved_pas_policy, self.resolve_compute_config(), - module_fn=self.create_model, + module_fn=partial(self.create_model, module_args=module_args), gen_savedir=self.gen_savedir, reuse=self.gen_reuse, instance_name=self.instance_name, @@ -280,7 +289,7 @@ def _new_adapter(parallel_module=None): # parallelize the whole model return _new_adapter().parallelize(dummy_input, load_module=load_module) - if not load_module: + if not load_module and all(pm.args is not None for pm in trainer_args.model.parallel_modules): for m in trainer_args.model.parallel_modules: _new_adapter(m).parallelize(dummy_input, load_module=False) return @@ -289,6 +298,7 @@ def _new_adapter(parallel_module=None): load_type(m.type): m for m in trainer_args.model.parallel_modules } + paralleled_sub_modules = set() def _default_new(cls, *args, **kwargs): return object.__new__(cls) @@ -314,20 +324,36 @@ def __parallel__new__(cls, *args, **kwargs): _restore_new() # it can go here when a subclass module of a parallelized module is instantiated if cls not in parallel_sub_modules: + # TODO: pass *args and **kwargs? return cls.__new__(cls) else: + if cls in paralleled_sub_modules: + logger.warning( + f'Parallelized module {cls.__name__} is already created. Previously Parallelized version will be reused.' + ) + paralleled_sub_modules.add(cls) parallel_module_config = parallel_sub_modules[cls] adapter = _new_adapter(parallel_module_config) # fork the random state to # make sure the modules after parallelized module # are the same in all devices. + # TODO: This will cause the random state to be different to non-parallel version. + # This is a trade-off to make sure the parallelized module is consistent. + # Maybe we can use torch.distributed.broadcast to sync the random state in all devices. with fork_rng(): - return adapter.parallelize(dummy_input, load_module=True) + return adapter.parallelize(dummy_input, load_module=load_module, module_args=(args, kwargs)) finally: _patch_new() _patch_new() try: - return trainer_args.to_precision(trainer_args.create_model()) + model = trainer_args.to_precision(trainer_args.create_model()) + missing_modules = set(parallel_sub_modules.keys()) - paralleled_sub_modules + if missing_modules: + logger.warning( + f'The following modules are not parallelized because they are not used: {", ".join(m.__name__ for m in missing_modules)}' + ) + if load_module: + return model finally: _restore_new() diff --git a/nnscaler/cli/train.py b/nnscaler/cli/train.py index c2a81519..6de3bc68 100644 --- a/nnscaler/cli/train.py +++ b/nnscaler/cli/train.py @@ -10,7 +10,10 @@ def main(): nnscaler.utils.set_default_logger_level(level=logging.INFO) trainer = Trainer() - trainer.run() + try: + trainer.run() + finally: + nnscaler.uninit() if __name__ == '__main__': diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index 09707f9c..aebb44d3 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -21,6 +21,11 @@ def after_setup(self, trainer: 'Trainer') -> None: When run_mode == 'compile', this hook will not be called. """ + def on_finalize(self, trainer: 'Trainer') -> None: + """ + Called after training is done. + """ + def on_train_start(self, trainer: 'Trainer') -> None: """Called at the beginning of training""" @@ -152,6 +157,15 @@ def on_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> checkpoint: the checkpoint loaded """ + def after_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: + """ + Called after setting model/optimizer/etc from checkpoint. + You can use this to restore some states for model/optimizer/etc that are not saved in the checkpoint. + + Args: + checkpoint: the checkpoint loaded + """ + def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: """ Called before saving checkpoint. @@ -170,6 +184,10 @@ def after_setup(self, trainer: 'Trainer') -> None: for hook in self.hooks: hook.after_setup(trainer) + def on_finalize(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.on_finalize(trainer) + def on_train_start(self, trainer: 'Trainer') -> None: for hook in self.hooks: hook.on_train_start(trainer) @@ -254,6 +272,10 @@ def on_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> for hook in self.hooks: hook.on_load_checkpoint(trainer, checkpoint) + def after_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: + for hook in self.hooks: + hook.after_load_checkpoint(trainer, checkpoint) + def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: for hook in self.hooks: hook.on_save_checkpoint(trainer, checkpoint) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 6c6689c7..4338ad83 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -105,25 +105,40 @@ def __init__(self, self.rng_states_from_resume: dict[str, torch.Tensor] | None = None def run(self): - self._setup() - if self.train_args.compile_mode: - return - self._train() + try: + self._setup() + if not self.train_args.compile_mode: + self._train() + finally: + for stage in ['train', 'val', 'test']: + if self.dataloader[stage] is not None and (close_fn := getattr(self.dataloader[stage], 'close', None)): + close_fn() + self.dataset[stage] = None + self.dataloader[stage] = None + if self.hook: + self.hook.on_finalize(self) + # It is very common to use `torch.distributed` after training + # So let's not uninitialize nnscaler here. + # TODO: make it configurable? + # nnscaler.uninit() def _fix_input(self, input): return fix_input(input, self.train_args.input_dtype) def _load_dummy_input(self): + if dummy_sample_gen_fn := self.train_args.resolved_dummy_sample_gen_fn: + return dummy_sample_gen_fn(self.train_args) + with enforce_zero_num_worker(DataLoader): - assert self.dataset['train'] is not None, "train dataset is not set" - dataloader = self.train_args.create_dataloader('train', self.dataset['train']) + dataset = self.train_args.create_dataset('train') + dataloader = self.train_args.create_dataloader('train', dataset) assert dataloader.num_workers == 0, "The dataloader must have `num_workers=0`." - return next(iter(dataloader)) + value = next(iter(dataloader)) + if close_fn := getattr(dataloader, 'close', None): + close_fn() + return value def _setup(self): - self.train_args.init_env(self) - compile_only = self.train_args.compile_mode - if is_running_distributed(): nnscaler.init() if torch.distributed.get_rank() == 0: @@ -131,14 +146,28 @@ def _setup(self): else: logging.getLogger().setLevel(logging.WARNING) - # create dataset and dataloader - for stage in ['train', 'val', 'test']: - self.dataset[stage] = self.train_args.create_dataset(stage) + self.train_args.init_env(self) + + # make sure all ranks are synchronized after init_env + if is_running_distributed(): + torch.distributed.barrier() + + compile_only = self.train_args.compile_mode # load a dummy input from training dataset self.dummy_input = self._load_dummy_input() self.dummy_input = self._fix_input(self.dummy_input) + pmodel = parallelize_model(self.train_args, self.dummy_input, load_module=not compile_only) + if compile_only: + return + + torch.distributed.barrier() + + # create dataset and dataloader + for stage in ['train', 'val', 'test']: + self.dataset[stage] = self.train_args.create_dataset(stage) + for stage in ['train', 'val', 'test']: self.dataloader[stage] = self.train_args.create_dataloader(stage, self.dataset[stage]) if self.dataloader[stage] is not None \ @@ -151,11 +180,6 @@ def _setup(self): f"You can specify `drop_last=True` in DataLoader to fix this problem." ) - pmodel = parallelize_model(self.train_args, self.dummy_input, load_module=not compile_only) - if compile_only: - return - - torch.distributed.barrier() self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) @@ -269,13 +293,15 @@ def _log_config(self, config: Dict): logger.setup(config) def _load_checkpoint(self): - resume_from = self.train_args.checkpoint.get_resume_checkpoint_dir() + resume_from = self.train_args.checkpoint.get_resume_checkpoint() if not resume_from: return logger.info(f"Resuming from {resume_from}") if resume_from.is_file(): resume_from = resume_from # when we load from merged checkpoint state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + if convert_fn := self.train_args.checkpoint.resolved_convert_fn: + state_dict = convert_fn(state_dict) else: ckpt_files = list(resume_from.glob('*.ckpt')) rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} @@ -295,6 +321,8 @@ def _load_checkpoint(self): state_dict = None state_dict_list = [state_dict] logger.info(f"Broadcasting merged checkpoint to all ranks.") + # TODO: it will be easily out of memory when the model is large + # We should broadcast the state_dict one parameter by one torch.distributed.broadcast_object_list(state_dict_list, src=0) state_dict = state_dict_list[0] else: @@ -302,7 +330,11 @@ def _load_checkpoint(self): state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) self.hook.on_load_checkpoint(self, state_dict) - ckpt_save_type = state_dict['train_args']['checkpoint']['save_type'] + # if it is not a well-formed state_dict (from third party) + # we will treat it as a merged state_dict + ckpt_save_type = state_dict.get('train_args', {}) \ + .get('checkpoint', {}) \ + .get('save_type', 'merged') if ckpt_save_type == 'merged': # it is a merged state dict nnscaler.load_merged_state_dict( @@ -345,7 +377,9 @@ def _load_checkpoint(self): self.dataloader_resumed = True else: self.dataloader_resumed = False - self.train_status = TrainStatus(**state_dict['train_status']) + + if 'train_status' in state_dict: + self.train_status = TrainStatus(**state_dict['train_status']) # we don't resume rng states when loading merged checkpoint, if ckpt_save_type != 'merged': @@ -353,6 +387,8 @@ def _load_checkpoint(self): else: logger.warning("RNG states are not resumed when loading merged checkpoint.") + self.hook.after_load_checkpoint(self, state_dict) + def _log_mem_stats(self, tag=None): # log minimum free memory over the iteration cuda_free, _ = torch.cuda.mem_get_info() @@ -374,9 +410,17 @@ def _format_metrics(self, epoch_desc, idx, metrics: Dict[str, Union[float,int]]) idx_format = f"0{ndigits}d" int_format = '' float_format = '.3f' - metris_str = ', '.join( + float_scientific_format = '.3e' + def _select_format(v): + if isinstance(v, float): + if v != 0.0 and (v < 1e-3 or v > 1e3): + return float_scientific_format + else: + return float_format + return int_format + metrics_str = ', '.join( [ - f"{k}={format(v, float_format if isinstance(v, float) else int_format)}" + f"{k}={format(v, _select_format(v))}" for k, v in metrics.items() ] ) @@ -384,7 +428,7 @@ def _format_metrics(self, epoch_desc, idx, metrics: Dict[str, Union[float,int]]) step_str = f'{format(idx, idx_format)}/{self.total_train_steps_per_epoch} ' else: step_str = f'' - return f"{epoch_desc}: {step_str}{metris_str}" + return f"{epoch_desc}: {step_str}{metrics_str}" def _is_resumable_dataloader(self): return ( @@ -428,10 +472,14 @@ def _save_checkpoint(self, loss): 'train_args': self.train_args.to_dict(), 'rng_states': self._get_rng_states(), 'rank': self.rank, + 'nnscaler': nnscaler.__version__, } + if self._is_resumable_dataloader(): - state_dict['dataloader'] = self.dataloader['train'].state_dict() + state_dict['dataloader'] = self.dataloader['train'].state_dict() # problematic + self.hook.on_save_checkpoint(self, state_dict) + ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( epoch=current_epoch, step=self.train_status.finished_train_steps, @@ -555,19 +603,9 @@ def _global_batch_iterator(self, num_skip_first=0, stage='train'): def aggregate_outputs(self, loss_outputs, sync_group) -> AggregatedOutputs: # loss is the first element of the output (or the only element) - losses = [ - loss if isinstance(loss, torch.Tensor) - else loss[0] - for loss in loss_outputs - ] - loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) - torch.distributed.all_reduce(loss_sum, group=sync_group) - num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) - torch.distributed.all_reduce(num_batches, group=sync_group) - - return AggregatedOutputs( - loss_sum=loss_sum.item(), - num_batches=num_batches.item(), + return AggregatedOutputs.aggregate( + loss_outputs, sync_group=sync_group, + loss_fn=lambda loss: loss if isinstance(loss, torch.Tensor) else loss[0] ) def _fix_batches(self, batches): @@ -682,6 +720,7 @@ def _train(self): self.hook.on_train_start(self) for epoch in range(start_epoch, self.train_args.max_epochs or sys.maxsize): + # TODO: make sure set_epoch doesn't have negative effect when called multiple times (i.e. when resuming) if hasattr(self.dataloader['train'], 'set_epoch'): self.dataloader['train'].set_epoch(epoch) elif hasattr(self.dataloader['train'].sampler, 'set_epoch'): @@ -828,15 +867,18 @@ def _train_epoch(self, epoch: int) -> None: aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) if self.train_args.optimizer.loss_reduction == 'mean': loss = aggregated_outputs.loss_sum / aggregated_outputs.num_batches + elif self.train_args.optimizer.loss_reduction == 'per-token-mean': + if not aggregated_outputs.num_tokens: + raise RuntimeError("`aggregate_outputs` doesn't set `num_tokens` field") + loss = aggregated_outputs.loss_sum / aggregated_outputs.num_tokens else: loss = aggregated_outputs.loss_sum step_stat.train_loss = loss self.hook.after_aggregate_train_step_outputs(self, aggregated_outputs, loss, idx) self.hook.before_sync_grad(self) - # actually `sync_shard_grad` is no-op here - # because trainer only supports end2end model - # and syncing grad in end2end model is done in `_train_step`. + # `sync_shard_grad` is no-op if the whole model is parallelized + # because syncing grad in end2end model is done in `_train_step`. self.optimizer.sync_shard_grad() self.hook.after_sync_grad(self) @@ -881,6 +923,7 @@ def _train_epoch(self, epoch: int) -> None: self._log_mem_stats(tag='train') step_metrics = {k:v for k, v in asdict(step_stat).items() if v is not None} step_metrics['train_wall'] = time.perf_counter() - step_start_at + step_metrics['loss'] = step_metrics['train_loss'] self.log_metrics(step_metrics, tag='train') if self.rank == 0: data_iter.set_postfix(step_metrics) diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index a65901b7..221e1717 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass, field, replace import importlib -from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union, TypeVar from typing_extensions import get_args from pathlib import Path import logging @@ -20,7 +20,7 @@ import torch import nnscaler -from nnscaler.utils import fields, transform_recursively +from nnscaler.utils import fields, transform_recursively, load_type from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule @@ -48,6 +48,8 @@ 'bf16': torch.bfloat16, 'none': None # as it is. no conversion will happen. } +_SELF_ARG_VALUE = 'self' +_LOSS_TYPE = TypeVar('_LOSS_TYPE') def _get_tensor_dtype(precision: Dict[_TENSOR_TYPE, _PRECISION_TYPE], tensor_type: _TENSOR_TYPE) -> torch.dtype: @@ -144,37 +146,6 @@ def resolved_pas_policy(self): return load_type(self.pas_policy) -def load_type(type_name: str): - """ - Load function/class from its full qualified name - """ - if callable(type_name): # a function or class - return type_name - - parts = type_name.split('.') - - # s: the number of parts to be the namespace - # s == 0: use builtins - # so the range() part includes 0 (with stop=-1) - for s in range(len(parts) - 1, -1, -1): - if s == 0: - nm = builtins - else: - namespace = '.'.join(parts[:s]) - try: - nm = importlib.import_module(namespace) - break - except (ImportError, ModuleNotFoundError): - pass - - try: - for i in range(s, len(parts)): - nm = getattr(nm, parts[i]) - return nm - except AttributeError as e: - raise RuntimeError(f"Failed to load type {type_name}") from e - - @dataclass class AggregatedOutputs: """ @@ -189,6 +160,37 @@ class AggregatedOutputs: # any other custom outputs aggregated_outputs: Any = None + @classmethod + def aggregate(cls, + loss_outputs: list[_LOSS_TYPE], + sync_group: torch.distributed.ProcessGroup, + loss_fn: Callable[[_LOSS_TYPE], torch.Tensor], + ntokens_fn: Callable[[_LOSS_TYPE], torch.Tensor] | None = None, + ) -> 'AggregatedOutputs': + losses, ntokens = [], [] + for output in loss_outputs: + losses.append(loss_fn(output)) + if ntokens_fn is not None: + ntokens.append(ntokens_fn(output)) + + loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) + torch.distributed.all_reduce(loss_sum, group=sync_group) + + if ntokens_fn is not None: + ntokens_sum = torch.sum(torch.tensor(ntokens, dtype=torch.int64, device=torch.cuda.current_device())) + torch.distributed.all_reduce(ntokens_sum, group=sync_group) + else: + ntokens_sum = None + + num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) + torch.distributed.all_reduce(num_batches, group=sync_group) + + return AggregatedOutputs( + loss_sum=loss_sum.item(), + num_batches=num_batches.item(), + num_tokens=ntokens_sum.item() if ntokens_sum is not None else None, + ) + @dataclass(frozen=True) class OptionalComputeConfig: @@ -220,7 +222,11 @@ class ModuleParallelizeConfig: # Please note if you specify this # pipeline parallelism will be disabled, and you must ensure ComputeConfig.use_end2end is False type: str = None - args: Dict[str, Any] = field(default_factory=dict) + # the module args to be used for creating the module + # If run_mode is 'compile' and `args` is not None + # we can parallelize submodules instead of creating whole model. + # This is useful sometimes. + args: Optional[Dict[str, Any]] = None # the full qualified name of the function to generate dummy forward args # Its type should be `Callable[[TrainerArgs],Dict[str, Any]]` forward_args_gen_fn: str = None @@ -255,9 +261,14 @@ def __post_init__(self): def model_type(self): return load_type(self.type) - def create_model(self, trainer_args: 'TrainerArgs') -> torch.nn.Module: - kwargs = trainer_args.create_kwarg(self.args) - return self.model_type(**kwargs) + def create_model(self, trainer_args: 'TrainerArgs', module_args: Optional[tuple[tuple, dict]]=None) -> torch.nn.Module: + if self.args: + args, kwargs = (), trainer_args.create_kwarg(self.args) + elif module_args: + args, kwargs = module_args + else: + raise ValueError("`module_args` or `args` must be provided") + return self.model_type(*args, **kwargs) def create_dummy_forward_args(self, trainer_args: 'TrainerArgs') -> dict[str, Any]: forward_args_gen_fn = load_type(self.forward_args_gen_fn) @@ -275,9 +286,9 @@ class ModelConfig: parallel_modules: list[ModuleParallelizeConfig] = field(default_factory=list) def __post_init__(self): - parallel_sub_modules = [load_type(m.type) for m in self.parallel_modules] - if set(parallel_sub_modules) != set(parallel_sub_modules): - raise ValueError(f"parallelized sub modules must be unique") + if len(set(m.type for m in self.parallel_modules)) != len(self.parallel_modules): + raise ValueError(f"parallelized sub modules must be unique by type") + @dataclass class OptimizerConfig: @@ -288,6 +299,8 @@ class OptimizerConfig: # loss reduction method # mean: average the loss over all micro-batches # sum: sum the loss of all micro-batches + # per-token-mean: average the gradients over all tokens + # you must specify `aggregate_outputs_fn` and return the number of tokens # Please note in validation stage, this configuration is ignored # the loss is always averaged over all batches loss_reduction: str = 'mean' @@ -308,9 +321,12 @@ def __post_init__(self): raise ValueError(f"Invalid gradient_accumulation {self.grad_reduction}") if self.grad_reduction == 'per-token-mean' and not self.aggregate_outputs_fn: raise ValueError("aggregate_outputs_fn is required when grad_reduction is 'per-token-mean'") - if self.loss_reduction not in ('mean', 'sum'): + if self.loss_reduction == 'per-token-mean' and not self.aggregate_outputs_fn: + raise ValueError("aggregate_outputs_fn is required when loss_reduction is 'per-token-mean'") + if self.loss_reduction not in ('mean', 'sum', 'per-token-mean'): raise ValueError(f"Invalid loss_reduction {self.loss_reduction}") + @dataclass class DatasetConfig: type: str = None @@ -351,6 +367,11 @@ def __post_init__(self): @dataclass class ResumeOptions: checkpoint: str = 'last' + # the full qualified name of the function to + # convert the checkpoint to nnscaler format + # It should be `Callable[[Dict[str, Any]], Dict[str, Any]]` + # Only applied when `checkpoint` is a file. + convert_fn: Optional[str] = None # whether to merge the checkpoint files # Only used when `checkpoint` is a directory. # `True` means will load the merged checkpoint (without saving) @@ -359,6 +380,7 @@ class ResumeOptions: # and will load merged checkpoint if the world size is changed. with_merged: Optional[bool] = None + @dataclass class CheckpointConfig: save_dir: str = './checkpoints' @@ -389,8 +411,8 @@ class CheckpointConfig: 'normalize': lambda x: {'checkpoint': x} if isinstance(x, str) else x }) - def get_resume_checkpoint_dir(self) -> Optional[Path]: - if not self.resume_from: + def get_resume_checkpoint(self) -> Optional[Path]: + if not self.resume_from or not self.resume_from.checkpoint: return None if self.resume_from.checkpoint in ['last', 'best']: d = Path(self.save_dir) / self.resume_from.checkpoint @@ -399,12 +421,14 @@ def get_resume_checkpoint_dir(self) -> Optional[Path]: return d return Path(self.resume_from.checkpoint) + @property + def resolved_convert_fn(self) -> Optional[Callable[[Dict[str, Any]], Dict[str, Any]]]: + if not self.resume_from or not self.resume_from.convert_fn: + return None + return load_type(self.resume_from.convert_fn) + def __post_init__(self): - if isinstance(self.resume_from, str): - self.resume_from = ResumeOptions(checkpoint=self.resume_from) - elif isinstance(self.resume_from, dict): - self.resume_from = deserialize_dataclass(self.resume_from, ResumeOptions) - if self.resume_from: + if self.resume_from and self.resume_from.checkpoint: if self.resume_from.checkpoint in ['last', 'best']: if not self.save_dir: raise ValueError("save_dir is required when resume_from is 'last'/'best'") @@ -438,6 +462,13 @@ class LogConfig: type: str = None args: Dict[str, Any] = field(default_factory=dict) + def __post_init__(self): + if not self.type: + raise ValueError("type is required") + if isinstance(self.type, str) and '.' not in self.type: + # assume it is a built-in logger + self.type = f'nnscaler.cli.loggers.{self.type}' + @dataclass class DebugConfig: @@ -457,6 +488,7 @@ class HookConfig: @dataclass class HookMapConfig: after_setup: str = None + on_finalize: str = None on_train_start: str = None on_train_end: str = None @@ -487,6 +519,7 @@ class HookMapConfig: after_optimizer_step: str = None on_load_checkpoint: str = None + after_load_checkpoint: str = None on_save_checkpoint: str = None @@ -512,6 +545,8 @@ def _deserialize_hook_config(hook) -> Union[HookConfig, HookMapConfig]: @dataclass class TrainerArgs(PrecisionMixin, PolicyMixin): + init_module: Optional[str] = None + vars: Dict[str, Any] = field(default_factory=dict) compute_config: ComputeConfig = None gen_savedir: str = './.nnscaler' @@ -525,6 +560,9 @@ class TrainerArgs(PrecisionMixin, PolicyMixin): # compile: compile the model but not training # run: compile and run the model run_mode: str = 'run' + # the full qualified name of the function to generate dummy sample for forward + # Its type should be `Callable[[TrainerArgs], Any]` + dummy_sample_gen_fn: str = None # the model state dict file for tracing. # It is only used in tracing to serve as the initial state dict of the model. tracing_from_weights: str = None @@ -649,6 +687,8 @@ def __post_init__(self): "and the model weights on different devices can be different." ) + self._vars = self.create_kwarg(self.vars) + @classmethod def from_cli(cls, argv: List[str]) -> 'TrainerArgs': d = {} @@ -663,6 +703,8 @@ def from_cli(cls, argv: List[str]) -> 'TrainerArgs': @classmethod def from_dict(cls, d: Dict[str, Any]) -> 'TrainerArgs': + if init_module := d.get('init_module', None): + importlib.import_module(init_module) ta = deserialize_dataclass(d, TrainerArgs) return ta @@ -677,13 +719,11 @@ def to_dict(self): @classmethod def from_yaml(cls, path: str) -> 'TrainerArgs': - with open(path, 'r') as f: - return cls.from_dict(yaml.safe_load(f)) + return cls.from_cli(['-f', path]) - @classmethod - def create_kwarg(cls, value: Any): + def create_kwarg(self, value: Any) -> Any: if isinstance(value, dict): - value = {k: cls.create_kwarg(v) for k, v in value.items()} + value = {k: self.create_kwarg(v) for k, v in value.items()} if _TYPE_KEY in value: value_type = load_type(value.pop(_TYPE_KEY)) return value_type(**value) @@ -700,15 +740,34 @@ def create_kwarg(cls, value: Any): else: return value elif isinstance(value, list): - return [cls.create_kwarg(i) for i in value] + return [self.create_kwarg(i) for i in value] elif isinstance(value, tuple): - return tuple(cls.create_kwarg(i) for i in value) + return tuple(self.create_kwarg(i) for i in value) + elif isinstance(value, str): + # resolved reference + # Note: resolved reference can only be used in various args + # (train/optimizer/dataloader/etc args). + if (value.startswith('$!(') and value.endswith(')')) \ + or (value.startswith('$!{') and value.endswith('}')): + value = value[3:-1] + if value == 'self': + return self + else: + parts = value.split('.') + if parts[0] != 'vars': + raise ValueError(f"Invalid resolved reference {value}. It must be `self` or start with `vars`.") + # resolve self.vars.x.y.z + return self.get_resolved_var('.'.join(parts[1:])) + return value else: return value @property def model_type(self): - return load_type(self.model.type) + m = load_type(self.model.type) + if not inspect.isclass(m) or not issubclass(m, torch.nn.Module): + raise ValueError(f"Invalid model type {self.model.type}. It must be a subclass of torch.nn.Module") + return m @property def resolved_aggregate_outputs_fn(self): @@ -716,6 +775,12 @@ def resolved_aggregate_outputs_fn(self): return None return load_type(self.optimizer.aggregate_outputs_fn) + @property + def resolved_dummy_sample_gen_fn(self): + if not self.dummy_sample_gen_fn: + return None + return load_type(self.dummy_sample_gen_fn) + @property def scaling_factor(self): return self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus @@ -745,6 +810,19 @@ def init_env(self, trainer: 'Trainer'): init_env_fn = load_type(self.init_env_fn) init_env_fn(trainer) + def get_resolved_var(self, fqn: str) -> Any: + """ + Get a resolved variable from the vars dictionary. + The fqn is a full qualified name of the variable, e.g. 'x.y.z'. + """ + parts = fqn.split('.') + var = self._vars + for part in parts: + if part not in var: + raise ValueError(f"Variable {fqn} not found in vars") + var = var[part] + return var + def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model.args) return self.model_type(**kwargs) diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index b5d875c8..333b01ad 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -695,6 +695,8 @@ def __init__(self, create_fn: Callable, name: str, signature: str, annos: Tuple[str], inputs: List[Union[IRTensor, IRObject]], transform_rules: Optional[Tuple[TransformRule]] = None, + *, + constant_foldable=False, **kwargs): """! Create a IRDimops @@ -732,7 +734,7 @@ def __init__(self, create_fn: Callable, name: str, ) n_outputs = len(self._oannos) - super().__init__(name, signature, inputs, n_outputs, **kwargs) + super().__init__(name, signature, inputs, n_outputs, constant_foldable=constant_foldable, **kwargs) @property def anno(self) -> OpAnno: diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 505dc61f..c4bd3366 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -3359,6 +3359,14 @@ def Sigmoid(input, *, out=None, signature=None): return IRDimops(Sigmoid, 'sigmoid', signature, annos, [input]) +def Item(input, signature = None): + """ + torch.Tensor.item() + """ + anno = '? -> ?' + return IRDimops(Item, 'item', signature, [anno], [input], constant_foldable=False) + + def DictKeys(o: Union[Dict, IRObject], signature=None): signature = 'nnscaler.runtime.function.dict_keys' diff --git a/nnscaler/graph/function/pyfunc.py b/nnscaler/graph/function/pyfunc.py index 733849ed..a759eaad 100644 --- a/nnscaler/graph/function/pyfunc.py +++ b/nnscaler/graph/function/pyfunc.py @@ -13,9 +13,10 @@ class IRPyFunc(IRFwOperation): """ def __init__(self, signature: str, - inputs: Tuple[IRObject], outputs: Tuple[IRObject], **kwargs): + inputs: Tuple[IRObject], outputs: Tuple[IRObject], + *, constant_foldable=True, **kwargs): name = signature.split('.')[-1] - super().__init__(name, signature, inputs, len(outputs)) + super().__init__(name, signature, inputs, len(outputs), constant_foldable=constant_foldable) for idx, t in enumerate(outputs): self.set_output(idx, t) self.kwargs.update(**kwargs) diff --git a/nnscaler/graph/parser/mapping.py b/nnscaler/graph/parser/mapping.py index b6865401..55c2792f 100644 --- a/nnscaler/graph/parser/mapping.py +++ b/nnscaler/graph/parser/mapping.py @@ -119,6 +119,7 @@ def exist(signature: str) -> bool: __ttemplate('transpose'): function.Transpose, __tttemplate('expand'): function.Expand, __tttemplate('expand_as'): function.ExpandAs, + __tttemplate('item'): function.Item, __ttemplate('arange'): function.Arange, __ttemplate('linspace'): function.Linspace, __ttemplate('detach'): function.Detach, diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index dc1566af..48001642 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -275,6 +275,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule if not isinstance(output, IRObject): # avoid nested IRObject output = IRObject(name=node.name, value=output, is_constant=is_constant) + elif not isinstance(output, IRTensor): + # make sure is_constant is set correctly + # IRTensor is always non-constant + output.is_constant = is_constant ir_node = IRPyFunc(fsig, input_vals, [output], **kwargs) if not isinstance(ir_node, IRCell): @@ -357,6 +361,7 @@ def _update_name(x: IRObject): # only user registered functions and `getattr` will have undefined output. # So I think the original intention is to avoid folding user registered functions. # 5. Only fold primitive types (int, float, bool, None, str, Ellipsis) and its complex types + # 6. Only fold constant_foldable node def _is_primitive_type(val): # we don't fold a list/tuple/dict with length larger than this # Just a quick filter, and may not work when val has multiple nested levels @@ -369,7 +374,8 @@ def _is_primitive_type(val): return isinstance(val, (int, float, bool, type(None), str, type(Ellipsis))) # Note when it is not IRObject as a whole, we will not fold it - if constant_folding and len(ir_node.outputs()) == 1 \ + if constant_folding and ir_node.constant_foldable \ + and len(ir_node.outputs()) == 1 \ and isinstance(ir_node.output(0), IRObject) \ and not isinstance(ir_node.output(0), IRTensor) \ and not contains_undefined_output \ diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index f042bb18..ae128a3f 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -241,7 +241,14 @@ def udfop(*args, signature=None, **kwargs): anno = OpAnno(anno) ninputs = len(anno.inputs()) if len(args) < ninputs: - raise ValueError(f"calling function {signature} should include at least {ninputs} *args") + # try to fill args with kwargs + args = list(args) + kwargs = dict(kwargs) + for idx in range(len(args), ninputs): + if arg_names[idx] in kwargs: + args.append(kwargs.pop(arg_names[idx])) + else: + raise ValueError(f"calling function {signature} should include at least {ninputs} *args") tensors = args[:ninputs] for idx, t in enumerate(tensors): # argument check diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 11bca941..f359fec3 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -542,6 +542,10 @@ def value(self) -> Any: def is_constant(self) -> bool: return self._is_constant + @is_constant.setter + def is_constant(self, val: bool): + self._is_constant = val + def __eq__(self, obj) -> bool: if not isinstance(obj, IRObject): return False diff --git a/nnscaler/ir/operator.py b/nnscaler/ir/operator.py index 952c83da..77f7722f 100644 --- a/nnscaler/ir/operator.py +++ b/nnscaler/ir/operator.py @@ -17,7 +17,8 @@ class IRFwOperation(IRCell): """ def __init__(self, name: str, signature: str, - inputs: List[IRObject], num_outputs: int, **kwargs): + inputs: List[IRObject], num_outputs: int, + *, constant_foldable=False, **kwargs): """! Create a forward operation. @@ -28,6 +29,7 @@ def __init__(self, name: str, signature: str, """ # recompute schedule self._recompute = None + self._constant_foldable = constant_foldable super().__init__(name, signature, len(inputs), num_outputs) # setup input @@ -76,6 +78,14 @@ def verify_shape(self, outputs=None) -> None: f'\nnode: {self}' ) + @property + def constant_foldable(self) -> bool: + """ + Get whether the operator is constant foldable. + Constant foldable operators can be folded during graph optimization. + """ + return self._constant_foldable + @property def recompute(self) -> Optional[int]: """! diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 37fb55d2..b728269b 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -16,6 +16,7 @@ import os import torch +import torch.distributed from nnscaler.codegen import ModuleCodeGen from nnscaler.codegen.schedule.schedule import ScheduleCodeGen @@ -1074,7 +1075,7 @@ def __init__(self, init_params=True): # all nodes will raise an exception if the code generation is failed. if regen_status == RegenStatus.ERROR: - raise RuntimeError("Reuse generated code failed.") from regen_exception + raise RuntimeError("Code generation failed.") from regen_exception if broadcast_strategy != BroadcastGenFilesStrategy.NONE: if not torch.distributed.is_initialized(): # we only support loading in torchrun environment @@ -1298,9 +1299,17 @@ def build_optimizer( non_parallel_module_reducer = None non_parallel_modules = [m for m in module.modules() if not isinstance(m, ParallelModule)] parallel_modules = [m for m in module.modules() if isinstance(m, ParallelModule)] + if not parallel_modules: raise RuntimeError("No ParallelModule found in the module. Please make sure you have called parallelize() before build_optimizer().") + non_parallel_parameters_dict = {} # use dict for dedup and order + for m in non_parallel_modules: + for param in m.parameters(recurse=False): # only leaf parameters to avoid duplicate + if param is not None and param.requires_grad: + non_parallel_parameters_dict[param] = None + non_parallel_parameters = list(non_parallel_parameters_dict.keys()) + # check if all ParallelModules have the same gpu_config compute_configs = [m.compute_config for m in parallel_modules] for i in range(1, len(compute_configs)): @@ -1313,6 +1322,10 @@ def build_optimizer( # we need to add all parameters of non-parallel modules to a reducer to reduce grads # if there are non-parallel parameters if plan_ngpus != runtime_ngpus and non_parallel_modules and any(p.numel() for m in non_parallel_modules for p in m.parameters(False)): + # For non-parallel modules, we use a Reducer to reduce the gradients. + # Please note here we still follow the original compute_config, + # although we can use a different compute_config for non-parallel modules. + # for example, we can always use plan_ngpus=1, and that may lead better gpu memory usage whe zero is ON. group, _ = compute_configs[0].get_sync_group() reducer_config = {} if compute_config: @@ -1324,9 +1337,8 @@ def build_optimizer( 'zero_ngroups': compute_config.zero_ngroups, } non_parallel_module_reducer = Reducer(group, **reducer_config) - for m in non_parallel_modules: - for param in m.parameters(recurse=False): # only add leaf parameters to avoid duplicate - non_parallel_module_reducer.add_param(param) + for param in non_parallel_parameters: + non_parallel_module_reducer.add_param(param) non_parallel_module_reducer.build_buckets() opt_module_locs: Dict[str, ModuleParameterLocation] = {} @@ -1339,9 +1351,14 @@ def _local_parameters(module: torch.nn.Module): m.parameters_for_optimizer() if m.compute_config.use_zero else m.parameters() # `ParallelModule.merge_partial_states` supports parameters_for_optimizer() only in zero mode ) + if p is not None and p.requires_grad ] if isinstance(m, ParallelModule) - else m._parameters.items() + else [ + (name, p) + for name, p in m._parameters.items() + if p is not None and p.requires_grad + ] ) for idx, (name, param) in enumerate(gen): if name.endswith(pm_suffix): # is a parameter of ParallelModule @@ -1374,6 +1391,8 @@ def _step_pre_hook(opt, *args, **kwargs): def _step_post_hook(opt, *args, **kwargs): for m in parallel_modules: m.gather_params() + if non_parallel_module_reducer: + non_parallel_module_reducer.gather_params() # Please note: # register_step_pre_hook doesn't work expectly @@ -1389,6 +1408,19 @@ def _patched_zero_grad(self, set_to_none: bool = True): m.zero_grad() if non_parallel_module_reducer: non_parallel_module_reducer.zero_grad() + elif non_parallel_parameters: + # for the case when non-parallel modules are not managed by a reducer + for p in non_parallel_parameters: + # copied from Module.zero_grad() + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + p.grad.zero_() + optimizer.zero_grad = types.MethodType(_patched_zero_grad, optimizer) orig_state_dict = optimizer.state_dict @@ -1430,8 +1462,25 @@ def _clip_gnorm(self, max_norm: Optional[float] = None): grads.extend(mgrads) if non_parallel_module_reducer: + # all non parallel module parameters are the same across all ranks + # but we still need to handle the case when zero is on to get correct gnorm params = non_parallel_module_reducer.parameters_for_optimizer() mnorm, mgrads = calcuate_gnorm(params) + mnorm_squared = torch.square(mnorm) + if non_parallel_module_reducer.zero: + torch.distributed.all_reduce(mnorm_squared) + # parameters are duplicated `zero_ngroups * plan_ngpus` times. + # so we need to divide the norm by `zero_ngroups * plan_ngpus` to get the correct gnorm + # Reason (also see how non_parallel_module_reducer is constructed above): + # 1. Ranks in the same scale unit (plan_ngpus) have grads from exactly the same parameters. + # because they are in the same position of a zero group. + # 2. Ranks in the same position of different zero groups have grads from exactly the same parameters + mnorm_squared.div_(non_parallel_module_reducer.zero_ngroups * plan_ngpus) + total_norm_squared += mnorm_squared + grads.extend(mgrads) + elif non_parallel_parameters: + # for the case when non-parallel modules are not managed by a reducer + mnorm, mgrads = calcuate_gnorm(non_parallel_parameters) total_norm_squared += torch.square(mnorm) grads.extend(mgrads) @@ -2282,11 +2331,14 @@ def load_deduped_state_dict( def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_group_size: int): + if not state_indexes: + return + rank = torch.distributed.get_rank() broadcast_group = setup_stride_broadcast_group(dedup_group_size) src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks - logging.info(f'Rank-{rank} is broadcasting states to ranks {curr_parallel_group_ranks}, broadcast root: {src_rank}...') + logging.info(f'Rank-{rank} is broadcasting optimizer states to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') # broadcast param groups and state keys/shapes/dtypes via broadcast_object_list if rank == src_rank: @@ -2310,20 +2362,23 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g # broadcast step # step is too small, so we can just broadcast all of them all together - if rank == src_rank: - step_stack = torch.stack( - [optimizer_state_dict['state'][k]['step'] for k in state_indexes] - ) - else: - step_stack = torch.zeros( - len(state_indexes), - dtype=optimizer_state_dict['state'][state_indexes[0]]['step'].dtype, - device=torch.cuda.current_device() - ) - torch.distributed.broadcast(step_stack, src=src_rank, group=curr_parallel_group) - if rank != src_rank: - for k, v in zip(state_indexes, step_stack): - optimizer_state_dict['state'][k]['step'].copy_(v) + # some adam/adamw optimizers may not have step in their state dict + # so we need to check if 'step' is in the state dict + if 'step' in optimizer_state_dict['state'][state_indexes[0]]: + if rank == src_rank: + step_stack = torch.stack( + [optimizer_state_dict['state'][k]['step'] for k in state_indexes] + ) + else: + step_stack = torch.zeros( + len(state_indexes), + dtype=optimizer_state_dict['state'][state_indexes[0]]['step'].dtype, + device=torch.cuda.current_device() + ) + torch.distributed.broadcast(step_stack, src=src_rank, group=curr_parallel_group) + if rank != src_rank: + for k, v in zip(state_indexes, step_stack): + optimizer_state_dict['state'][k]['step'].copy_(v) # broadcast other states # TODO: can be slow? @@ -2331,7 +2386,8 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g keys = sorted(optimizer_state_dict['state'][k].keys()) # for mixed precision f16 optimizer, we will add custom keys # assert set(keys) == {'step', 'exp_avg', 'exp_avg_sq'} - keys.remove('step') # we have done step in previous. + if 'step' in keys: + keys.remove('step') # we have done step in previous. for key in keys: value = optimizer_state_dict['state'][k][key] torch.distributed.broadcast(value.data, src=src_rank, group=curr_parallel_group) @@ -2368,7 +2424,7 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): broadcast_group = setup_stride_broadcast_group(stride_size) rank = torch.distributed.get_rank() src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks - logging.info(f'Rank-{rank} is broadcasting weight to ranks {curr_parallel_group_ranks}, broadcast root: {src_rank}...') + logging.info(f'Rank-{rank} is broadcasting weights of {module.__class__.__name__} to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') if isinstance(module, ParallelModule): if not _broadcast_single_value(src_rank, curr_parallel_group, module.non_presistent_buffers_inited): diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 2d162e50..27c7f54d 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -24,6 +24,7 @@ from nnscaler.ir.operator import IRFwOperation from nnscaler.graph.parser.register import CustomizedOps from nnscaler.graph.tracer.metadata import AutocastInfo +from nnscaler.utils import load_type _logger = logging.getLogger(__name__) @@ -86,7 +87,7 @@ def get_func(node: IRFwOperation) -> Tuple[Callable, Shapes, DTypes, Dict]: if node.signature in CustomizedOps.kOpRuntime: fn = CustomizedOps.kOpRuntime[node.signature] else: - fn = eval(node.signature) + fn = load_type(node.signature) shapes, dtypes, requires_grads, values = [], [], [], [] # TODO: this function should rewrite with pytree diff --git a/nnscaler/runtime/device.py b/nnscaler/runtime/device.py index 10e02922..d8c81484 100644 --- a/nnscaler/runtime/device.py +++ b/nnscaler/runtime/device.py @@ -20,6 +20,7 @@ class _DeviceGroup: def __init__(self): + self._is_pg_initer = False if CompileFlag.dev_mode or not is_running_distributed(): self.rank = 0 self.world_size = 1 @@ -31,6 +32,7 @@ def __init__(self): torch.distributed.init_process_group( backend='nccl', timeout=_LARGE_TIMEOUT ) + self._is_pg_initer = True # disable it for now due to connection refused error when nnodes > 1 # TODO: investigate the root cause @@ -53,6 +55,11 @@ def __init__(self): self.streams: Dict[str, torch.cuda.Stream] = { 'default': torch.cuda.default_stream()} + def close(self): + if self._is_pg_initer: + torch.distributed.destroy_process_group() + self._is_pg_initer = False + def group_exists(self, ranks): """ Check if group exists @@ -143,3 +150,14 @@ def DeviceGroup() -> _DeviceGroup: if _instance is None: _instance = _DeviceGroup() return _instance + + +def init_device(): + _ = DeviceGroup() + + +def uninit_device(): + global _instance + if _instance is not None: + _instance.close() + _instance = None diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py index fc6b770e..414361a5 100644 --- a/nnscaler/runtime/f16_optimizer.py +++ b/nnscaler/runtime/f16_optimizer.py @@ -141,6 +141,18 @@ def _sync_fp32_params_to_f16(self): continue p.data.copy_(p32.data) + def _sync_fp16_params_to_fp32(self): + # copy FP16 params to FP32 + for p, p32 in zip(self.f16_params, self.fp32_params): + if not p.requires_grad: + continue + p32.data.copy_(p.data) + + def after_load_checkpoint(self, trainer, checkpoint) -> None: + if 'nnscaler' not in checkpoint: + # this checkpoint is not created by nnscaler. + self._sync_fp16_params_to_fp32() + def overrided_scale_grads(self, scale: float): """ Scale the gradients by a factor. diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 9213dcdd..0e26d483 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -1170,6 +1170,26 @@ def _pre_load_state_dict_hook(self, state_dict, prefix, local_metadata, strict, # Both load_state_dict and load_deduped_state_dict will trigger this hook self._warn_uninitialized_non_persistent_buffers() + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """ + Override the default load_from_state_dict to support loading from a merged state dict. + Please note + 1. we have to trigger the pre_load_state_dict_hook explicitly when loading merged state dict, + 2. post_load_state_dict_hook will be triggered in `super().load_state_dict`. + """ + if f'{prefix}{self.EXTRA_STATE_KEY}' in state_dict: + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs + ) + else: + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + new_missing_keys = self.load_merged_state_dict(state_dict, prefix, strict=False) + if strict: + missing_keys.extend(new_missing_keys) + @property def module_dedup_group_size(self) -> int: """ @@ -1212,7 +1232,8 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s we only make sure no missing keys. Unexpected keys are not checked. Default: `True` Returns: - None + list[str]: a list of missing keys in the state_dict. + Please note currently we don't check unexpected keys. Raises: RuntimeError: if strict=True and there are missing keys. """ @@ -1247,11 +1268,13 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s tensor.copy_(content) attr_names.remove(attr) + missing_keys = [prefix + self._fullmap[attr].orig_name for attr in attr_names] if len(attr_names) != 0: - erro_msg = f'Missing key(s) in state_dict: {[prefix + self._fullmap[attr].orig_name for attr in attr_names]}.' + erro_msg = f'Missing key(s) in state_dict: {missing_keys}.' if strict: raise RuntimeError(erro_msg) else: _logger.warning(erro_msg) self._warn_uninitialized_non_persistent_buffers() + return missing_keys diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 00408e24..310c6649 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import builtins +import importlib from contextlib import contextmanager from functools import wraps from typing import ( @@ -180,6 +182,7 @@ def enforce_zero_num_worker(cls) -> Generator[None, None, None]: def _new__init__(self, *args, **kwargs) -> None: kwargs['num_workers'] = 0 kwargs['prefetch_factor'] = None + kwargs['persistent_workers'] = False _old__init__(self, *args, **kwargs) cls.__init__ = _new__init__ yield @@ -322,6 +325,41 @@ def fields(model: TDataClass, /) -> TDataClass: return cast(TDataClass, _GetFields(model)) +def load_type(type_name: str): + """ + Load function/class from its full qualified name + """ + if callable(type_name): # a function or class + return type_name + + parts = type_name.split('.') + + last_ex = None + # s: the number of parts to be the namespace + # s == 0: use builtins + # so the range() part includes 0 (with stop=-1) + for s in range(len(parts) - 1, -1, -1): + if s == 0: + nm = builtins + else: + namespace = '.'.join(parts[:s]) + try: + nm = importlib.import_module(namespace) + break + except (ImportError, ModuleNotFoundError) as e: + last_ex = e + + try: + for i in range(s, len(parts)): + nm = getattr(nm, parts[i]) + return nm + except AttributeError as e: + # give a hint of the import error + # TODO: a better way? + e.__context__ = last_ex + raise RuntimeError(f"Failed to load type {type_name}") from e + + class accum_mode: """Make cube execution in gradient accumulation mode. diff --git a/requirements.txt b/requirements.txt index 8ab5bea9..2f7f75b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ more-itertools numpy>=1.23.0 psutil pulp -pybind11 +pybind11<3.0.0 pyyaml torch>=2.0,<=2.6 tqdm diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index fff9bcbd..feff6391 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -8,7 +8,7 @@ import pytest from nnscaler.cli.arg_parser import ( - parse_args, deserialize_dataclass, _fix_type, + _KeyNotFoundError, parse_args, deserialize_dataclass, _fix_type, resolve_args, merge_args ) @@ -332,6 +332,9 @@ def test_resolve_args2(): 'v': {'a': 10, 'b': 20}, 'x': {'y': '$(c.d)', 'z': '$(k.2.y)'}, 'k': ['$(f.g.h)', {'x': 20}, {'y': '$(c.e)'}], + 'm': '${k.0}1', + 'n': '${m}2$(x.y)', + 'o': r'$\{k.0}1$(x.y)2$\(x.y)', } resolve_args(data) assert data == { @@ -342,6 +345,9 @@ def test_resolve_args2(): 'v': {'a': 10, 'b': 20}, 'x': {'y': 3, 'z': 4}, 'k': [5, {'x': 20}, {'y': 4}], + 'm': '51', + 'n': '5123', + 'o': '${k.0}132$(x.y)', } def test_circular_resolve_args(): @@ -374,5 +380,5 @@ def test_missing_resolve_args(): 'b': True, 'c': {'d': 3, 'e': '$(k.2.y)'}, } - with pytest.raises(ValueError): + with pytest.raises(_KeyNotFoundError): resolve_args(data) diff --git a/tests/cli/test_resume_seed.py b/tests/cli/test_resume_seed.py index 43cfb4ae..06ea9d57 100644 --- a/tests/cli/test_resume_seed.py +++ b/tests/cli/test_resume_seed.py @@ -69,7 +69,7 @@ def _train(tmp_path, steps_per_epoch, max_train_steps, resume_from): model=ModelConfig(type=Model), optimizer=OptimizerConfig(type=torch.optim.AdamW), dataset=DatasetConfig(type=RandomDataset, train_args={'length': steps_per_epoch}), - checkpoint=CheckpointConfig(resume_from=resume_from, save_dir=tmp_path/'checkpoints'), + checkpoint=CheckpointConfig(resume_from=ResumeOptions(checkpoint=resume_from), save_dir=tmp_path/'checkpoints'), max_train_steps=max_train_steps, enable_progress_bar=False, seed=0, diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 10443430..d69a392f 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -86,6 +86,27 @@ def test_trainer_compile_worker(tmp_path): assert set([f.name for f in gen_savedir.glob('**/*.py')]) == set(['gencode0.py', 'gencode1.py', 'gencode2.py', 'gencode3.py']) shutil.rmtree(gen_savedir) + # mixed compile only + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + '--model.type', 'tests.cli.common.MixedModule', + '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP2', + '--model.parallel_modules.0.args.dim', '16', + '--model.parallel_modules.0.args.nlayers', '16', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', + ]) + trainer.run() + + assert set([f.name for f in gen_savedir.glob('**/*.py')]) == set(['gencode0.py', 'gencode1.py', 'gencode2.py', 'gencode3.py']) + shutil.rmtree(gen_savedir) + def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): save_dir = Path(save_dir) diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml index 72e1c202..db05b3e6 100644 --- a/tests/cli/trainer_args.yaml +++ b/tests/cli/trainer_args.yaml @@ -1,3 +1,6 @@ +vars: + dim: 16 + drop_last: true compute_config: plan_ngpus: 4 runtime_ngpus: 100 @@ -16,7 +19,7 @@ seed: 0 model: type: tests.cli.common.MLP args: - dim: 16 + dim: $(vars.dim) nlayers: 16 optimizer: @@ -27,17 +30,17 @@ optimizer: dataset: type: tests.cli.common.SimpleDataset train_args: - dim: 16 + dim: $(vars.dim) size: 100 val_args: - dim: 16 + dim: $(vars.dim) size: 10 dataloader: train_args: - drop_last: true + drop_last: $(vars.drop_last) val_args: - drop_last: true + drop_last: $(vars.drop_last) checkpoint: keep_last_n_checkpoints: 30 diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index d13e2db6..4669ee0b 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -255,3 +255,26 @@ def test_input_gen_fn(): fn = CustomizedOps.kOpInputGen[select_node.signature] ret = mock_select(*fn(select_node)) assert True + + +class MockModelKwargs(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x, y): + x, y = self.fc(x), self.fc(y) + return mock_add(x=x, y=y) + + +# passed test +@replace_all_device_with('cpu') +def test_kw_args(): + model = MockModelKwargs() + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) + + # test profiler.database + for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'mock_add']): + profile_name = get_func(node)[0].__qualname__ + assert profile_name == p_name, f'{profile_name} should be {p_name}' diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 594a73d5..ef1fc2d1 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -675,6 +675,37 @@ def test_codegen_clone(): assert isinstance(g.nodes()[0], nnscaler.graph.function.dimops.IRDimops) +class ItemModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + return a.item() + + +def _gencode_item_function_worker(tempdir): + init_distributed() + m_new = parallelize( + ItemModule(), + { + 'a': torch.tensor([5.0]), + }, + 'dp', + ComputeConfig(1, 1, constant_folding=True), + gen_savedir=tempdir, + load_module=True + ) + assert m_new is not None + # never fold torch.Tensor.item() to constant + assert _gencode_contains(tempdir, ItemModule, 0, '.*torch.Tensor.item.*') + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_codegen_item(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, _gencode_item_function_worker, tempdir) + + class MinModule(torch.nn.Module): def __init__(self): super().__init__() From 7ad94d174a21da8e77216883f36f5af71a2221ce Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 22 Jul 2025 09:12:11 +0000 Subject: [PATCH 1817/1892] Merged PR 2381: [AutoDist] Remove cppimport Since cppimport is not actively maintained, we remove it and compile by pybind11 directly --- nnscaler/autodist/dp_solver.cpp | 11 ----- nnscaler/autodist/spmd_solver.py | 6 +-- pyproject.toml | 4 +- requirements.txt | 1 - setup.py | 82 ++++++++++++++++++++++++++++++++ tests/autodist/test_dp_solver.py | 1 - tox.ini | 2 +- 7 files changed, 87 insertions(+), 20 deletions(-) create mode 100644 setup.py diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index ecfda7f7..334617e5 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -1,5 +1,3 @@ -// cppimport - /* * Copyright (c) Microsoft Corporation. * Licensed under the MIT License. @@ -625,12 +623,3 @@ PYBIND11_MODULE(dp_solver, m) { .def("solve", &DPSolver::solve) .def("get_results", &DPSolver::get_results); } -/* -<% -setup_pybind11(cfg) -cfg['extra_compile_args'] = ['-std=c++11'] -cfg['extra_compile_args'] = ['-O3'] -cfg['extra_compile_args'] = ['-pthread'] -cfg['dependencies'] = ['dp_solver.h'] -%> -*/ diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index cca7f0be..1a1e0ec8 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -1187,14 +1187,12 @@ def do_ilp(self, intervals: List[Tuple[int, int]], def do_dp(self, intervals: List[Tuple[int, int]], topk: int) -> List[List[SPMDSearchOutput]]: - import cppimport.import_hook try: import nnscaler.autodist.dp_solver as dp_solver except ImportError: raise RuntimeError( - 'Failed to import solver. ' - 'If you installed nnscaler from source (`pip install -e .`), ' - 'please also make sure to put parent directory of `nnscaler` in `PYTHONPATH`.' + 'Failed to import dp_solver. ' + 'Please install nnscaler with: pip install -e .' ) if self.autodist_config.memory_granularity < 1024: diff --git a/pyproject.toml b/pyproject.toml index e73228f2..654da046 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools"] +requires = ["setuptools>=45", "wheel", "pybind11>=2.6", "pybind11[global]"] build-backend = "setuptools.build_meta" [project] @@ -28,7 +28,7 @@ dynamic.dependencies.file = "requirements.txt" # NOTE: # the following part only affects wheel, not sdist -# since we are using cppimport, sdist is not needed +# C++ extensions are now built via setup.py instead of cppimport packages.find.include = ["nnscaler*"] package-data = { nnscaler = ["resources/**", "autodist/*.h", "autodist/*.cpp"] } diff --git a/requirements.txt b/requirements.txt index 2f7f75b3..6ae52a83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -cppimport dill matplotlib more-itertools diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..9d4dc0c0 --- /dev/null +++ b/setup.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +""" +Setup script for nnscaler with C++ extensions +""" + +import os +from pybind11.setup_helpers import Pybind11Extension, build_ext +import pybind11 +from setuptools import setup + +# Define C++ extensions +def get_ext_modules(): + """Get extension modules with appropriate compiler flags""" + + # Base compile args + compile_args = ['-O3', '-fPIC'] + + # Try to use older ABI for better compatibility (following PyTorch's approach) + compile_args.append('-D_GLIBCXX_USE_CXX11_ABI=0') + + # Link arguments + link_args = ['-lpthread'] + + # conda environment handling, since: + # - nnscaler may be installed in conda, for example, user's development environment and our ci. + # - libstdc++ in conda may be different from system libstdc++. + # - we prefer the conda version for compatibility. + conda_prefix = os.environ.get('CONDA_PREFIX') + if not conda_prefix: + # Fallback to ANACONDA_PYTHON_VERSION like PyTorch does + anaconda_python_version = os.environ.get('ANACONDA_PYTHON_VERSION') + if anaconda_python_version: + conda_prefix = f"/opt/conda/envs/py_{anaconda_python_version}" + + if conda_prefix: + # Add conda library path with RPATH for runtime discovery + conda_lib_path = os.path.join(conda_prefix, 'lib') + if os.path.exists(conda_lib_path): + link_args.extend([f'-L{conda_lib_path}', f'-Wl,-rpath,{conda_lib_path}']) + + ext_modules = [ + Pybind11Extension( + "nnscaler.autodist.dp_solver", + [ + "nnscaler/autodist/dp_solver.cpp", + ], + include_dirs=[ + pybind11.get_include(), + "nnscaler/autodist", + ], + language='c++', + cxx_std=11, + extra_compile_args=compile_args, + extra_link_args=link_args, + ), + ] + + return ext_modules + +# Custom build_ext class to provide feedback +class CustomBuildExt(build_ext): + """Custom build extension to handle C++ compilation""" + + def build_extensions(self): + print("Building C++ extensions...") + for ext in self.extensions: + print(f" - {ext.name}") + + # Print environment info + conda_prefix = os.environ.get('CONDA_PREFIX') + if conda_prefix: + print(f" Using conda environment: {conda_prefix}") + + super().build_extensions() + print("C++ extensions built successfully!") + +setup( + ext_modules=get_ext_modules(), + cmdclass={"build_ext": CustomBuildExt}, + zip_safe=False, +) diff --git a/tests/autodist/test_dp_solver.py b/tests/autodist/test_dp_solver.py index 29f33322..846e1c05 100644 --- a/tests/autodist/test_dp_solver.py +++ b/tests/autodist/test_dp_solver.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import cppimport.import_hook import nnscaler.autodist.dp_solver as dp_solver # use a naive ffn to test the dynamic programming solver diff --git a/tox.ini b/tox.ini index b1236354..04e0c3ae 100644 --- a/tox.ini +++ b/tox.ini @@ -14,8 +14,8 @@ install_command = pip install {opts} {packages} deps = -rrequirements.txt -rrequirements-dev.txt + -e . commands = coverage erase - rm -f {envdir}/lib/libstdc++.so.6 # force using system libstdc++ pytest --cov={toxinidir}/nnscaler -x tests coverage html rm -rf {envdir} From ec519358b4c668dd8b5c42a1f92bcd03295b97e0 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 23 Jul 2025 06:25:56 +0000 Subject: [PATCH 1818/1892] Merged PR 2382: [BugFix] TrainerArgs Resolution for command line args When an argument is passed from command line. the new value is not used in args resolution. --- nnscaler/cli/trainer_args.py | 3 ++- tests/cli/test_arg_parser.py | 1 + tests/cli/test_train_args.py | 16 +++++++++++++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 221e1717..62d1b9b8 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -695,10 +695,11 @@ def from_cli(cls, argv: List[str]) -> 'TrainerArgs': if argv[0] == '-f': with open(argv[1], 'r') as f: d = yaml.safe_load(f) - resolve_args(d) argv = argv[2:] merge_args(d, argv) + resolve_args(d) + return cls.from_dict(d) @classmethod diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index feff6391..427c8f26 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -350,6 +350,7 @@ def test_resolve_args2(): 'o': '${k.0}132$(x.y)', } + def test_circular_resolve_args(): data = { 'a': 1, diff --git a/tests/cli/test_train_args.py b/tests/cli/test_train_args.py index f4b1b62d..aabce253 100644 --- a/tests/cli/test_train_args.py +++ b/tests/cli/test_train_args.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from pathlib import Path import pytest import nnscaler -from nnscaler.cli.trainer_args import load_type, ComputeConfig, OptionalComputeConfig +from nnscaler.cli.trainer_args import load_type, ComputeConfig, OptionalComputeConfig, TrainerArgs def test_load_type(): @@ -37,3 +38,16 @@ def test_compute_config_merge(): occ2 = OptionalComputeConfig(zero_ngroups=-1) with pytest.raises(ValueError): occ2.resolve(cc) + + +def test_arg_merge_resolve(): + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + args = TrainerArgs.from_cli(['-f', config_path, + '--vars.dim', '22', + '--vars.hello', '$(compute_config.plan_ngpus)', + '--global_batch_size!' + ]) + assert args.vars['dim'] == 22 + assert args.dataset.train_args['dim'] == 22 + assert args.dataset.val_args['dim'] == 22 + assert args.vars['hello'] == args.compute_config.plan_ngpus From 3537d431f14e793b005312e88b5df3b345567e6f Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 24 Jul 2025 03:10:37 +0000 Subject: [PATCH 1819/1892] Merged PR 2374: [Example] Refine ring-attention --- .../chunk_linear_cross_entropy.py | 63 --- .../customized_ops}/README.md | 1 + .../chunk_linear_cross_entropy.py | 0 .../customized_ops/ring_attention/__init__.py | 5 + .../core/ring_attn_implementation.py | 183 ++++--- .../core/ring_attn_varlen_implementation.py | 508 ++++++++++++++++++ .../ring_attention/core/utils.py | 86 +++ .../core/zigzag_attn_implementation.py | 87 ++- .../ring_attention/ring_attn.py | 20 +- .../ring_attention/ring_attn_varlen.py | 145 +++++ .../ring_attention/zigzag_attn.py | 16 +- .../test_chunk_linear_cross_entropy.py | 0 .../customized_ops}/test_ring_attn.py | 104 ++-- .../customized_ops/test_ring_attn_varlen.py | 129 +++++ .../customized_ops}/test_zigzag_attn.py | 104 ++-- .../lm_models/diff_transformer_modifier.py | 20 +- examples/llama/train.py | 2 +- 17 files changed, 1157 insertions(+), 316 deletions(-) delete mode 100644 examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py rename examples/{customized_ops/ring_attention => llama/customized_ops}/README.md (96%) rename examples/llama/{ => customized_ops}/chunk_linear_cross_entropy.py (100%) create mode 100644 examples/llama/customized_ops/ring_attention/__init__.py rename examples/{ => llama}/customized_ops/ring_attention/core/ring_attn_implementation.py (63%) create mode 100644 examples/llama/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py rename examples/{ => llama}/customized_ops/ring_attention/core/utils.py (75%) rename examples/{ => llama}/customized_ops/ring_attention/core/zigzag_attn_implementation.py (86%) rename examples/{ => llama}/customized_ops/ring_attention/ring_attn.py (83%) create mode 100644 examples/llama/customized_ops/ring_attention/ring_attn_varlen.py rename examples/{ => llama}/customized_ops/ring_attention/zigzag_attn.py (85%) rename examples/{customized_ops/chunk_linear_cross_entropy => llama/customized_ops}/test_chunk_linear_cross_entropy.py (100%) rename examples/{customized_ops/ring_attention => llama/customized_ops}/test_ring_attn.py (50%) create mode 100644 examples/llama/customized_ops/test_ring_attn_varlen.py rename examples/{customized_ops/ring_attention => llama/customized_ops}/test_zigzag_attn.py (50%) diff --git a/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py b/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py deleted file mode 100644 index 2f81f00a..00000000 --- a/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch -import torch.utils.checkpoint as ckpt - - -def linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int = 0) -> torch.Tensor: - """ - Compute the cross entropy loss of a linear layer. - - Args: - - x: [token_num, hidden_size], the last hidden state of the model - w: [dict_size, hidden_size], the weight matrix of the last linear layer - y: [token_num], the target token index - padding_idx: int, the index of padding token - - Returns: - - losses: [token_num], the cross entropy loss of each token - """ - logits = torch.nn.functional.linear(x, w) - normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) - losses = torch.nn.functional.nll_loss(normalized_logits, y, reduction='none', ignore_index=padding_idx) - return losses - - -def chunk_linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor: - """ - In order to reduce the memory usage when the sequence length and dictionary size are large, we can split the input - tensor into chunks and compute the cross entropy loss of each chunk separately. - You can register this function with annotation 'b l d^, n^ d^, b l -> b l'. - - Args: - - x: [bsz, seq_len, hidden_size], the last hidden state of the model - w: [dict_size, hidden_size], the weight matrix of the last linear layer - y: [bsz, seq_len], the target token index - padding_idx: int, the index of padding token - chunk_size: int, the size of each chunk - - Returns: - - losses: [bsz, seq_len], the cross entropy loss of each token - """ - bsz, seq_len, hidden_size = x.size() - token_num = bsz * seq_len - x = x.view(token_num, hidden_size) - y = y.view(token_num) - - if token_num % chunk_size != 0: - raise ValueError(f"token_num {token_num} is not divisible by chunk_size {chunk_size}") - - chunk_num = token_num // chunk_size - xs = x.view(chunk_num, chunk_size, hidden_size) - ys = y.view(chunk_num, chunk_size) - losses = [] - for i in range(chunk_num): - loss = ckpt.checkpoint(linear_cross_entropy, xs[i], w, ys[i], padding_idx, use_reentrant=False) - losses.append(loss) - losses = torch.stack(losses).view(bsz, seq_len) - return losses diff --git a/examples/customized_ops/ring_attention/README.md b/examples/llama/customized_ops/README.md similarity index 96% rename from examples/customized_ops/ring_attention/README.md rename to examples/llama/customized_ops/README.md index 35f9fede..a20398c8 100644 --- a/examples/customized_ops/ring_attention/README.md +++ b/examples/llama/customized_ops/README.md @@ -18,4 +18,5 @@ Test can be run with the following command: ```bash torchrun --nproc_per_node 4 test_ring_attn.py torchrun --nproc_per_node 4 test_zigzag_attn.py +torchrun --nproc_per_node 4 test_ring_attn_varlen.py ``` \ No newline at end of file diff --git a/examples/llama/chunk_linear_cross_entropy.py b/examples/llama/customized_ops/chunk_linear_cross_entropy.py similarity index 100% rename from examples/llama/chunk_linear_cross_entropy.py rename to examples/llama/customized_ops/chunk_linear_cross_entropy.py diff --git a/examples/llama/customized_ops/ring_attention/__init__.py b/examples/llama/customized_ops/ring_attention/__init__.py new file mode 100644 index 00000000..405f1207 --- /dev/null +++ b/examples/llama/customized_ops/ring_attention/__init__.py @@ -0,0 +1,5 @@ +from .ring_attn_varlen import wrap_ring_attn_varlen_func + +from .zigzag_attn import wrap_zigzag_attn_func + +from .ring_attn import wrap_ring_attn_func \ No newline at end of file diff --git a/examples/customized_ops/ring_attention/core/ring_attn_implementation.py b/examples/llama/customized_ops/ring_attention/core/ring_attn_implementation.py similarity index 63% rename from examples/customized_ops/ring_attention/core/ring_attn_implementation.py rename to examples/llama/customized_ops/ring_attention/core/ring_attn_implementation.py index f7c23f16..85b374d8 100644 --- a/examples/customized_ops/ring_attention/core/ring_attn_implementation.py +++ b/examples/llama/customized_ops/ring_attention/core/ring_attn_implementation.py @@ -4,19 +4,12 @@ import torch import torch.distributed as dist from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward -from .utils import shuffle_input, recover_output, GlobalMemoryBuffer +from .utils import shuffle_input, recover_output, GlobalMemoryBuffer, get_default_args _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() -import flash_attn - -version = flash_attn.__version__ -if not version.startswith("2.6"): - raise ImportError("The current version of Ring Attention is not compatible with Flash Attention versions other than 2.6.x.") - - def ring_flash_attn_forward( process_group, q: torch.Tensor, @@ -26,10 +19,40 @@ def ring_flash_attn_forward( dropout_p=0, causal=True, window_size=(-1, -1), - softcap=0.0, alibi_slopes=None, deterministic=False, ): + def forward(q, k, v, causal): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + block_len = q.size(1) // 2 curr_rank = dist.get_rank(process_group) world_size = dist.get_world_size(process_group) @@ -45,18 +68,8 @@ def ring_flash_attn_forward( up_v = v[:, :(up_rank + 1) * block_len] else: up_k, up_v = k, v - up_out, _, _, _, _, up_lse, _, _ = _flash_attn_forward( - up_q, - up_k, - up_v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) + + up_out, up_lse = forward(up_q, up_k, up_v, causal) down_q = q[:, block_len:] if causal: @@ -64,18 +77,7 @@ def ring_flash_attn_forward( down_v = v[:, :(down_rank + 1) * block_len] else: down_k, down_v = k, v - down_out, _, _, _, _, down_lse, _, _ = _flash_attn_forward( - down_q, - down_k, - down_v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) + down_out, down_lse = forward(down_q, down_k, down_v, causal) out = torch.cat([up_out, down_out], dim=1) return out, up_lse, down_lse @@ -94,7 +96,6 @@ def ring_flash_attn_backward( dropout_p=0, causal=True, window_size=(-1, -1), - softcap=0.0, alibi_slopes=None, deterministic=False, ): @@ -121,25 +122,36 @@ def ring_flash_attn_backward( up_v = v[:, :(up_rank + 1) * block_len] else: up_k, up_v = k, v - _flash_attn_backward( - up_dout, - up_q, - up_k, - up_v, - up_out, - up_lse, - dq[:, :block_len], - dk_buffer[:, :(up_rank + 1) * block_len], - dv_buffer[:, :(up_rank + 1) * block_len], - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - rng_state=None, + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": up_dout, + "q": up_q, + "k": up_k, + "v": up_v, + "out": up_out, + "softmax_lse": up_lse, + "dq": dq[:, :block_len], + "dk": dk_buffer[:, :(up_rank + 1) * block_len], + "dv": dv_buffer[:, :(up_rank + 1) * block_len], + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) down_q = q[:, block_len:] down_out = out[:, block_len:] @@ -154,25 +166,36 @@ def ring_flash_attn_backward( down_v = v[:, :(down_rank + 1) * block_len] else: down_k, down_v = k, v - _flash_attn_backward( - down_dout, - down_q, - down_k, - down_v, - down_out, - down_lse, - dq[:, block_len:], - down_dk_buffer[:, :(down_rank + 1) * block_len], - down_dv_buffer[:, :(down_rank + 1) * block_len], - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - rng_state=None, + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": down_dout, + "q": down_q, + "k": down_k, + "v": down_v, + "out": down_out, + "softmax_lse": down_lse, + "dq": dq[:, block_len:], + "dk": down_dk_buffer[:, :(down_rank + 1) * block_len], + "dv": down_dv_buffer[:, :(down_rank + 1) * block_len], + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) dk_buffer.add_(down_dk_buffer) dv_buffer.add_(down_dv_buffer) @@ -180,8 +203,8 @@ def ring_flash_attn_backward( dim_size[1] = dim_size[1] // world_size dk = torch.empty(dim_size, dtype=k.dtype, device=k.device) dv = torch.empty(dim_size, dtype=v.dtype, device=v.device) - dist._reduce_scatter_base(dk, dk_buffer, group=process_group) - dist._reduce_scatter_base(dv, dv_buffer, group=process_group) + dist.reduce_scatter_tensor(dk, dk_buffer, group=process_group) + dist.reduce_scatter_tensor(dv, dv_buffer, group=process_group) return dq, dk, dv @@ -205,7 +228,6 @@ def forward( softmax_scale, causal, window_size, - softcap, alibi_slopes, deterministic, return_softmax, @@ -221,8 +243,8 @@ def forward( dim_size[1] = dim_size[1] * world_size k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") - torch.distributed._all_gather_base(k_buffer, k, group=group) - torch.distributed._all_gather_base(v_buffer, v, group=group) + torch.distributed.all_gather_into_tensor(k_buffer, k, group=group) + torch.distributed.all_gather_into_tensor(v_buffer, v, group=group) out, up_lse, down_lse = ring_flash_attn_forward( group, @@ -233,7 +255,6 @@ def forward( dropout_p=dropout_p, causal=causal, window_size=window_size, - softcap=softcap, alibi_slopes=alibi_slopes, deterministic=False, ) @@ -243,12 +264,11 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic ctx.group = group out = recover_output(out, process_group=group) - return out if not return_softmax else (out, softmax_lse, None) + return out @staticmethod def backward(ctx, dout, *args): @@ -259,8 +279,8 @@ def backward(ctx, dout, *args): dim_size[1] = dim_size[1] * world_size k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") - torch.distributed._all_gather_base(k_buffer, k, group=ctx.group) - torch.distributed._all_gather_base(v_buffer, v, group=ctx.group) + torch.distributed.all_gather_into_tensor(k_buffer, k, group=ctx.group) + torch.distributed.all_gather_into_tensor(v_buffer, v, group=ctx.group) dq, dk, dv = ring_flash_attn_backward( ctx.group, @@ -275,9 +295,8 @@ def backward(ctx, dout, *args): dropout_p=ctx.dropout_p, causal=ctx.causal, window_size=ctx.window_size, - softcap=ctx.softcap, alibi_slopes=ctx.alibi_slopes, deterministic=ctx.deterministic, ) dq = recover_output(dq, ctx.group) - return dq, dk, dv, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None diff --git a/examples/llama/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py b/examples/llama/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py new file mode 100644 index 00000000..924f4eda --- /dev/null +++ b/examples/llama/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py @@ -0,0 +1,508 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: Most of code is copied from project https://github.com/zhuzilin/ring-flash-attention + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, +) +from .utils import get_default_args, AllGatherComm as Comm + + +def llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens: torch.Tensor, causal: bool, rank: int, world_size: int +): + """ + Args: + cu_seqlens: torch.Tensor, the cu_seqlens of all the sequences across the ring process group. + + Returns: + cu_seqlens_q: torch.Tensor, the cu_seqlens of the q slice for this rank. + cu_seqlens_k: torch.Tensor, the cu_seqlens of the k slice that the local q need. Note + that this may be longer than `total_seq_len // world_size`. + local_k_slice: slice, the slice of the k that the local q need. Note + that this may be longer than `total_seq_len // world_size`. + """ + total_length = cu_seqlens[-1].item() + assert total_length % world_size == 0 + length_per_rank = total_length // world_size + left = torch.searchsorted(cu_seqlens, rank * length_per_rank) + right = torch.searchsorted(cu_seqlens, (rank + 1) * length_per_rank) + + # after this, cu_seqlens[left:right + 1] contains all the sequence for this rank + if cu_seqlens[left] != rank * length_per_rank: + left -= 1 + left = left.item() + right = right.item() + + # q is always the same. just calculate the cu_seqlens for the local slice + cu_seqlens_q = cu_seqlens[left : right + 1].clone() + cu_seqlens_q -= rank * length_per_rank + cu_seqlens_q[0] = 0 + cu_seqlens_q[-1] = length_per_rank + + cu_seqlens_k = cu_seqlens[left : right + 1].clone() + if causal: + # when causal, we hope + # - the last k seq is of the same length as the last q seq + slice_right = (rank + 1) * length_per_rank + cu_seqlens_k[-1] = slice_right + else: + # when not causal, we hope + # - the last k is full seq + slice_right = cu_seqlens[right].item() + + slice_left = cu_seqlens[left].item() + cu_seqlens_k -= slice_left + + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + local_k_slice = slice(slice_left, slice_right) + return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, local_k_slice + + +def llama3_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + out_list = [] + lse_list = [] + + nheads = q.shape[1] + total_k, nheads_k, head_dim = k.shape + assert nheads_k % heads_k_stride == 0 + + world_size = dist.get_world_size(process_group) + kv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + kv_buffer_copy = torch.empty_like(kv_buffer) + + k_0 = k[:, :heads_k_stride].contiguous() + v_0 = v[:, :heads_k_stride].contiguous() + comm = Comm(process_group) + + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + for i in range(0, nheads_k, heads_k_stride): + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + if i < nheads_k - heads_k_stride: + # all_gather the next kv slice + kv_slice_left = i + heads_k_stride + kv_slice_right = kv_slice_left + heads_k_stride + send_k = k[:, kv_slice_left:kv_slice_right].contiguous() + send_v = v[:, kv_slice_left:kv_slice_right].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + q_i = q[:, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + k_i = kv_buffer[0][local_k_slice] + v_i = kv_buffer[1][local_k_slice] + + params = get_default_args(_flash_attn_varlen_forward).copy() + params.update( + { + "q": q_i, + "k": k_i, + "v": v_i, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_varlen_forward(**params) + if len(outputs) == 8: + out, _, _, _, _, lse, _, _ = outputs + else: + assert len(outputs) == 4 + out, lse, _, _ = outputs + out_list.append(out) + lse_list.append(lse) + + out = torch.cat(out_list, dim=1) + lse = torch.cat(lse_list, dim=-2) + return out, lse + + +def llama3_flash_attn_varlen_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + nheads = q.shape[1] + total_k, nheads_k, head_dim = k.shape + assert nheads_k % heads_k_stride == 0 + + world_size = dist.get_world_size(process_group) + kv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + kv_buffer_copy = torch.empty_like(kv_buffer) + + dkv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + if heads_k_stride != nheads_k: + kv_contiguous_buffer = torch.empty( + (2, total_k, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + comm = Comm(process_group) + + k_0 = k[:, :heads_k_stride].contiguous() + v_0 = v[:, :heads_k_stride].contiguous() + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + for i in range(0, nheads_k, heads_k_stride): + dkv_buffer.zero_() + + q_slice = slice( + i * nheads // nheads_k, (i + heads_k_stride) * nheads // nheads_k + ) + q_i = q[:, q_slice] + dout_i = dout[:, q_slice] + out_i = out[:, q_slice] + dq_i = dq[:, q_slice] + if softmax_lse.dim() == 3: + lse_i = softmax_lse[:, q_slice].contiguous() + else: + lse_i = softmax_lse[q_slice] + + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + if i < nheads_k - heads_k_stride: + # all_gather the next kv slice + kv_slice_left = i + heads_k_stride + kv_slice_right = kv_slice_left + heads_k_stride + send_k = k[:, kv_slice_left:kv_slice_right].contiguous() + send_v = v[:, kv_slice_left:kv_slice_right].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + k_i = kv_buffer[0][local_k_slice] + v_i = kv_buffer[1][local_k_slice] + dk_i = dkv_buffer[0][local_k_slice] + dv_i = dkv_buffer[1][local_k_slice] + + params = get_default_args(_flash_attn_varlen_backward).copy() + params.update( + { + "dout": dout_i, + "q": q_i, + "k": k_i, + "v": v_i, + "out": out_i, + "softmax_lse": lse_i, + "dq": dq_i, + "dk": dk_i, + "dv": dv_i, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_varlen_backward(**params) + + if heads_k_stride != nheads_k: + # reduce_scatter needs contiguous buffer + dk_i = kv_contiguous_buffer[0] + dv_i = kv_contiguous_buffer[1] + else: + dk_i = dk + dv_i = dv + + dist.reduce_scatter_tensor(dk_i, dkv_buffer[0], group=process_group) + dist.reduce_scatter_tensor(dv_i, dkv_buffer[1], group=process_group) + + if heads_k_stride != nheads_k: + dk[:, i : i + heads_k_stride] = dk_i + dv[:, i : i + heads_k_stride] = dv_i + + return dq, dk, dv + + +class Llama3FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = llama3_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.heads_k_stride = heads_k_stride + ctx.local_k_slice = local_k_slice + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = llama3_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.heads_k_stride, + ctx.local_k_slice, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return (dq, dk, dv) + (None,) * 15 + + +def llama3_flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def llama3_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def llama3_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/examples/customized_ops/ring_attention/core/utils.py b/examples/llama/customized_ops/ring_attention/core/utils.py similarity index 75% rename from examples/customized_ops/ring_attention/core/utils.py rename to examples/llama/customized_ops/ring_attention/core/utils.py index 643fa59d..57383516 100644 --- a/examples/customized_ops/ring_attention/core/utils.py +++ b/examples/llama/customized_ops/ring_attention/core/utils.py @@ -6,11 +6,97 @@ from typing import Optional, Tuple from functools import reduce import operator +import inspect +from functools import cache +import random import torch import torch.distributed as dist +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def gen_head_anno(query_states, key_states, value_states): + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + return q_anno, kv_anno + + +# copied from project https://github.com/zhuzilin/ring-flash-attention +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +class AllGatherComm: + def __init__(self, group=None) -> None: + self.group = group + self.handles = [] + + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): + handle = dist.all_gather_into_tensor( + output_tensor, input_tensor, group=self.group, async_op=True + ) + self.handles.append(handle) + + def wait(self): + for handle in self.handles: + handle.wait() + self.handles = [] + + # copy from megatron/core/utils.py class GlobalMemoryBuffer: """Global buffer to avoid dynamic memory allocations. diff --git a/examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py b/examples/llama/customized_ops/ring_attention/core/zigzag_attn_implementation.py similarity index 86% rename from examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py rename to examples/llama/customized_ops/ring_attention/core/zigzag_attn_implementation.py index f18deac4..fddf74d7 100644 --- a/examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py +++ b/examples/llama/customized_ops/ring_attention/core/zigzag_attn_implementation.py @@ -1,12 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention +# Credits: This implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention import torch import torch.distributed as dist from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward -from .utils import RingComm, update_out_and_lse, shuffle_input, recover_output +from .utils import RingComm, update_out_and_lse, shuffle_input, recover_output, get_default_args ''' Assume we have 4 GPUs A, B, C, D. @@ -157,17 +157,34 @@ def zigzag_ring_flash_attn_forward( next_k, next_v = None, None def forward(q, k, v, causal): - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs return block_out, block_lse for step in range(comm.world_size): @@ -260,24 +277,36 @@ def zigzag_ring_flash_attn_backward( def backward(dout, q, k, v, out, softmax_lse, causal): seqlen_q = q.shape[1] seqlen_kv = k.shape[1] - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq_buffer[:, :seqlen_q], - dk_buffer[:, :seqlen_kv], - dv_buffer[:, :seqlen_kv], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout, + "q": q, + "k": k, + "v": v, + "out": out, + "softmax_lse": softmax_lse, + "dq": dq_buffer[:, :seqlen_q], + "dk": dk_buffer[:, :seqlen_kv], + "dv": dv_buffer[:, :seqlen_kv], + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) for step in range(kv_comm.world_size): if step + 1 != kv_comm.world_size: diff --git a/examples/customized_ops/ring_attention/ring_attn.py b/examples/llama/customized_ops/ring_attention/ring_attn.py similarity index 83% rename from examples/customized_ops/ring_attention/ring_attn.py rename to examples/llama/customized_ops/ring_attention/ring_attn.py index 1a167a7f..6a24bb7f 100644 --- a/examples/customized_ops/ring_attention/ring_attn.py +++ b/examples/llama/customized_ops/ring_attention/ring_attn.py @@ -6,7 +6,8 @@ from nnscaler.graph.parser.register import register_op from nnscaler.ir.operator import IRFwOperation -from core.ring_attn_implementation import RingFlashAttnFunc +from .core.ring_attn_implementation import RingFlashAttnFunc +from .core.utils import gen_head_anno from flash_attn import flash_attn_func from nnscaler.runtime.device import DeviceGroup @@ -14,7 +15,7 @@ def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), - softcap: float=0.0, alibi_slopes: Tensor=None, deterministic: bool=False, + alibi_slopes: Tensor=None, deterministic: bool=False, return_attn_probs: bool=False, process_group: Tuple[int]=None) -> Tensor: ''' @@ -25,12 +26,15 @@ def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=N required communications. ''' + assert alibi_slopes is None, "alibi_slopes is not supported in ring_attn_func" + assert return_attn_probs is False, "return_attn_probs is not supported in ring_attn_func" + if process_group is None or len(process_group) == 1: # there is an additional checker for the `softmax_scale`, which is equivalent # to the behavior of the original flash_attn_func. if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + output = flash_attn_func(q, k, v, softmax_scale=softmax_scale, causal=causal, window_size=window_size,) return output assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" @@ -59,7 +63,6 @@ def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=N softmax_scale, causal, window_size, - softcap, alibi_slopes, deterministic, return_attn_probs, @@ -68,6 +71,7 @@ def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=N return output + def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: """Special rule to generate ring_attn node""" @@ -104,4 +108,10 @@ def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runt args = ", ".join(list(args) + kw_pairs) return f"{signature}({args})" -register_op('bs l h dim^, bs l h dim^, bs l h dim^ -> bs l h dim^', emit_fn=emit_ring)(wrap_ring_attn_func) + +def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states) + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + + +register_op(flash_attention_anno, emit_fn=emit_ring)(wrap_ring_attn_func) diff --git a/examples/llama/customized_ops/ring_attention/ring_attn_varlen.py b/examples/llama/customized_ops/ring_attention/ring_attn_varlen.py new file mode 100644 index 00000000..869dc0bd --- /dev/null +++ b/examples/llama/customized_ops/ring_attention/ring_attn_varlen.py @@ -0,0 +1,145 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Tuple, List, Dict +from torch import Tensor +import torch.distributed as dist + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir.operator import IRFwOperation +from flash_attn import flash_attn_varlen_func +from .core.ring_attn_varlen_implementation import llama3_flash_attn_prepare_cu_seqlens, llama3_flash_attn_varlen_func +from .core.utils import gen_head_anno + +from nnscaler.runtime.device import DeviceGroup + + +def wrap_ring_attn_varlen_func( + q: Tensor, + k: Tensor, + v: Tensor, + cu_seqlens_q: Tensor, + cu_seqlens_k: Tensor, + dropout_p: float = 0.0, + softmax_scale: Tensor = None, + causal: bool = False, + window_size: Tuple[int] = (-1, -1), + alibi_slopes: Tensor = None, + deterministic: bool = False, + return_attn_probs: bool = False, + process_group: Tuple[int] = None, +): + ''' + wrap the ring_attn_varlen_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_varlen_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + assert not return_attn_probs, "return_attn_probs is not supported in ring-attention" + assert alibi_slopes is None, "alibi_slopes is not supported in ring-attention" + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + + if process_group is None or len(process_group) == 1: + output = flash_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=False, + ) + return output + + assert len(q.shape) == 3, "q must have shape [total_q, qh, dim]" + assert len(k.shape) == 3, "k must have shape [total_k, kh, dim]" + assert len(v.shape) == 3, "v must have shape [total_k, vh, dim]" + total_q, qheads, qdim = q.shape + total_k, kheads, kdim = k.shape + total_v, vheads, vdim = v.shape + assert total_q == total_k == total_v, "total_q, total_k and total_v must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + local_rank = dist.get_rank(local_process_group) + local_world_size = dist.get_world_size(local_process_group) + + ( + local_cu_seqlens_q, + local_cu_seqlens_k, + local_max_seqlen_q, + local_max_seqlen_k, + local_k_slice, + ) = llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens_q, + causal=causal, + rank=local_rank, + world_size=local_world_size, + ) + + output = llama3_flash_attn_varlen_func( + q, + k, + v, + local_cu_seqlens_q, + local_cu_seqlens_k, + local_max_seqlen_q, + local_max_seqlen_k, + heads_k_stride=1, + local_k_slice=local_k_slice, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + group=local_process_group, + ) + + return output + + +def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate ring_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + if partition_dims[0] == 0: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 1: + # partition the head dim, use local flash_attn_func + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states) + return f'l {q_anno} hd^, l {kv_anno} hd^, l {kv_anno} vd^, e^, e^ -> l {q_anno} vd^' + + +register_op(flash_attention_anno, emit_fn=emit_ring)(wrap_ring_attn_varlen_func) diff --git a/examples/customized_ops/ring_attention/zigzag_attn.py b/examples/llama/customized_ops/ring_attention/zigzag_attn.py similarity index 85% rename from examples/customized_ops/ring_attention/zigzag_attn.py rename to examples/llama/customized_ops/ring_attention/zigzag_attn.py index 3fccb59b..2373f9d0 100644 --- a/examples/customized_ops/ring_attention/zigzag_attn.py +++ b/examples/llama/customized_ops/ring_attention/zigzag_attn.py @@ -8,12 +8,14 @@ from nnscaler.graph.parser.register import register_op from nnscaler.ir.operator import IRFwOperation -from core.zigzag_attn_implementation import ZigZagRingFlashAttnFunc +from .core.zigzag_attn_implementation import ZigZagRingFlashAttnFunc +from .core.utils import gen_head_anno from flash_attn import flash_attn_func import torch.distributed as dist from nnscaler.runtime.device import DeviceGroup + def wrap_zigzag_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), alibi_slopes: Tensor=None, deterministic: bool=False, @@ -27,6 +29,10 @@ def wrap_zigzag_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor required communications. ''' + assert window_size == (-1, -1), "window_size is not supported in zigzag-attention" + assert not return_attn_probs, "return_attn_probs is not supported in zigzag-attention" + assert alibi_slopes is None, "alibi_slopes is not supported in zigzag-attention" + if process_group is None or len(process_group) == 1: # there is an additional checker for the `softmax_scale`, which is equivalent # to the behavior of the original flash_attn_func. @@ -102,4 +108,10 @@ def emit_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], ru args = ", ".join(list(args) + kw_pairs) return f"{signature}({args})" -register_op('bs l h dim^, bs l h dim^, bs l h dim^ -> bs l h dim^', emit_fn=emit_zigzag)(wrap_zigzag_attn_func) + +def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states) + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + + +register_op(flash_attention_anno, emit_fn=emit_zigzag)(wrap_zigzag_attn_func) diff --git a/examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py b/examples/llama/customized_ops/test_chunk_linear_cross_entropy.py similarity index 100% rename from examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py rename to examples/llama/customized_ops/test_chunk_linear_cross_entropy.py diff --git a/examples/customized_ops/ring_attention/test_ring_attn.py b/examples/llama/customized_ops/test_ring_attn.py similarity index 50% rename from examples/customized_ops/ring_attention/test_ring_attn.py rename to examples/llama/customized_ops/test_ring_attn.py index 757ee84e..fc72a91a 100644 --- a/examples/customized_ops/ring_attention/test_ring_attn.py +++ b/examples/llama/customized_ops/test_ring_attn.py @@ -2,67 +2,35 @@ # Licensed under the MIT License. import torch +import argparse import nnscaler from nnscaler.graph import IRGraph from nnscaler.ir.operator import IRFwOperation from nnscaler.parallel import parallelize, ComputeConfig, ReuseType import torch.distributed as dist -from flash_attn import flash_attn_func import nnscaler.graph import nnscaler.graph.function -from ring_attn import wrap_ring_attn_func +from ring_attention import wrap_ring_attn_func +from ring_attention.core.utils import set_seed, log -import random - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " - f"max {a.abs().max().item()}, " - f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " - f"max {a.abs().max().item()}, " - f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() class TestModule(torch.nn.Module): def __init__(self): super(TestModule, self).__init__() - def forward(self, _in0, _in1, _in2): - out = wrap_ring_attn_func(_in0, _in1, _in2) + def forward(self, q, k, v): + out = wrap_ring_attn_func(q, k, v) return out + def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: ngpus = resource.plan_ngpus partitioned = False for idx, node in enumerate(graph.select(ntype=IRFwOperation)): - if not partitioned and node.signature == 'ring_attn.wrap_ring_attn_func': - print('Partitioned node: ', node) - sub_nodes = graph.partition( - node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) + if not partitioned and node.signature == 'ring_attention.ring_attn.wrap_ring_attn_func': + print('\nPartitioned node: ', node, '\n') + sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) partitioned = True else: sub_nodes = graph.replicate(node, times=ngpus) @@ -71,7 +39,18 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: assert partitioned, f'expect ring_attn_func in graph, but not found.' return graph + if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + args = parser.parse_args() + nnscaler.init() rank_id = torch.distributed.get_rank() world_size = dist.get_world_size() @@ -83,18 +62,21 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: d = 128 device = torch.device(f"cuda:{rank_id}") - # dtype = torch.float16 - dtype = torch.bfloat16 + dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 - q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) - k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) - v = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) + k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) + v = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) dist.broadcast(q, src=0) dist.broadcast(k, src=0) dist.broadcast(v, src=0) dist.barrier() + q.requires_grad = True + k.requires_grad = True + v.requires_grad = True + single_out = wrap_ring_attn_func(q, k, v) single_out.retain_grad() single_loss = single_out.sum() @@ -102,22 +84,22 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: model = TestModule() - _in0 = q.detach().clone().requires_grad_() - _in1 = k.detach().clone().requires_grad_() - _in2 = v.detach().clone().requires_grad_() + qq = q.detach().clone().requires_grad_() + kk = k.detach().clone().requires_grad_() + vv = v.detach().clone().requires_grad_() - parallel_model = parallelize(model, dummy_forward_args={"_in0": _in0, "_in1": _in1, "_in2": _in2}, pas_policy=policy, + parallel_model = parallelize(model, dummy_forward_args={"q": qq, "k": kk, "v": vv}, pas_policy=policy, compute_config=ComputeConfig(world_size, world_size), reuse=ReuseType.OVERRIDE) parallel_model = parallel_model.cuda() parallel_model.train() - _in0 = q.detach().clone().requires_grad_() - _in1 = k.detach().clone().requires_grad_() - _in2 = v.detach().clone().requires_grad_() + qq = q.detach().clone().requires_grad_() + kk = k.detach().clone().requires_grad_() + vv = v.detach().clone().requires_grad_() - para_out = parallel_model(_in0, _in1, _in2) + para_out = parallel_model(qq, kk, vv) para_loss = para_out.sum() para_loss.backward() parallel_model.sync_grad() @@ -127,13 +109,15 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: log("out diff", single_out - para_out, rank0_only=True) log("single dq", q.grad, rank0_only=True) - log("multi dq", _in0.grad, rank0_only=True) - log("dq diff", q.grad - _in0.grad, rank0_only=True) + log("multi dq", qq.grad, rank0_only=True) + log("dq diff", q.grad - qq.grad, rank0_only=True) log("single dk", k.grad, rank0_only=True) - log("multi dk", _in1.grad, rank0_only=True) - log("dk diff", k.grad - _in1.grad, rank0_only=True) + log("multi dk", kk.grad, rank0_only=True) + log("dk diff", k.grad - kk.grad, rank0_only=True) log("single dv", v.grad, rank0_only=True) - log("multi dv", _in2.grad, rank0_only=True) - log("dv diff", v.grad - _in2.grad, rank0_only=True) + log("multi dv", vv.grad, rank0_only=True) + log("dv diff", v.grad - vv.grad, rank0_only=True) + + dist.destroy_process_group() diff --git a/examples/llama/customized_ops/test_ring_attn_varlen.py b/examples/llama/customized_ops/test_ring_attn_varlen.py new file mode 100644 index 00000000..0df2e863 --- /dev/null +++ b/examples/llama/customized_ops/test_ring_attn_varlen.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import nnscaler +import argparse +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType +import torch.distributed as dist + +import nnscaler.graph +import nnscaler.graph.function +from ring_attention import wrap_ring_attn_varlen_func +from ring_attention.core.utils import set_seed, log + + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v, cu_seqlens_q, cu_seqlens_k): + out = wrap_ring_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k) + return out + + +def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: + ngpus = resource.plan_ngpus + partitioned = False + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == 'ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func': + print('\nPartitioned node: ', node, '\n') + sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=0, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + assert partitioned, f'expect ring_attn_varlen_func in graph, but not found.' + return graph + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + args = parser.parse_args() + + nnscaler.init() + rank_id = torch.distributed.get_rank() + world_size = dist.get_world_size() + + set_seed(rank_id) + seqlen = 8192 + nheads = 24 + d = 128 + + device = torch.device(f"cuda:{rank_id}") + dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 + cu_seqlens = [0, 120, 1248, 4232, 8192] + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + + q = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) + k = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) + v = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) + + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.barrier() + + q.requires_grad = True + k.requires_grad = True + v.requires_grad = True + + single_out = wrap_ring_attn_varlen_func(q, k, v, cu_seqlens_tensor, cu_seqlens_tensor) + single_out.retain_grad() + single_loss = single_out.sum() + single_loss.backward() + + model = TestModule() + + qq = q.detach().clone().requires_grad_() + kk = k.detach().clone().requires_grad_() + vv = v.detach().clone().requires_grad_() + + parallel_model = parallelize( + model, + dummy_forward_args={"q": qq, "k": kk, "v": vv, 'cu_seqlens_q': cu_seqlens_tensor, 'cu_seqlens_k': cu_seqlens_tensor}, + pas_policy=policy, + compute_config=ComputeConfig(world_size, world_size), + reuse=ReuseType.OVERRIDE + ) + parallel_model = parallel_model.cuda() + + + parallel_model.train() + + qq = q.detach().clone().requires_grad_() + kk = k.detach().clone().requires_grad_() + vv = v.detach().clone().requires_grad_() + + para_out = parallel_model(qq, kk, vv, cu_seqlens_tensor, cu_seqlens_tensor) + para_loss = para_out.sum() + para_loss.backward() + parallel_model.sync_grad() + + log("single out", single_out, rank0_only=True) + log("multi out", para_out, rank0_only=True) + log("out diff", single_out - para_out, rank0_only=True) + + log("single dq", q.grad, rank0_only=True) + log("multi dq", qq.grad, rank0_only=True) + log("dq diff", q.grad - qq.grad, rank0_only=True) + + log("single dk", k.grad, rank0_only=True) + log("multi dk", kk.grad, rank0_only=True) + log("dk diff", k.grad - kk.grad, rank0_only=True) + + log("single dv", v.grad, rank0_only=True) + log("multi dv", vv.grad, rank0_only=True) + log("dv diff", v.grad - vv.grad, rank0_only=True) + + dist.destroy_process_group() diff --git a/examples/customized_ops/ring_attention/test_zigzag_attn.py b/examples/llama/customized_ops/test_zigzag_attn.py similarity index 50% rename from examples/customized_ops/ring_attention/test_zigzag_attn.py rename to examples/llama/customized_ops/test_zigzag_attn.py index e97a5b88..0e0d8069 100644 --- a/examples/customized_ops/ring_attention/test_zigzag_attn.py +++ b/examples/llama/customized_ops/test_zigzag_attn.py @@ -2,67 +2,35 @@ # Licensed under the MIT License. import torch +import argparse import nnscaler from nnscaler.graph import IRGraph from nnscaler.ir.operator import IRFwOperation from nnscaler.parallel import parallelize, ComputeConfig, ReuseType import torch.distributed as dist -from flash_attn import flash_attn_func import nnscaler.graph import nnscaler.graph.function -from zigzag_attn import wrap_zigzag_attn_func +from ring_attention import wrap_zigzag_attn_func +from ring_attention.core.utils import set_seed, log -import random - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " - f"max {a.abs().max().item()}, " - f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " - f"max {a.abs().max().item()}, " - f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() class TestModule(torch.nn.Module): def __init__(self): super(TestModule, self).__init__() - def forward(self, _in0, _in1, _in2): - out = wrap_zigzag_attn_func(_in0, _in1, _in2) + def forward(self, q, k, v): + out = wrap_zigzag_attn_func(q, k, v) return out + def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: ngpus = resource.plan_ngpus partitioned = False for idx, node in enumerate(graph.select(ntype=IRFwOperation)): - if not partitioned and node.signature == 'zigzag_attn.wrap_zigzag_attn_func': - print('Partitioned node: ', node) - sub_nodes = graph.partition( - node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) + if not partitioned and node.signature == 'ring_attention.zigzag_attn.wrap_zigzag_attn_func': + print('\nPartitioned node: ', node, '\n') + sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) partitioned = True else: sub_nodes = graph.replicate(node, times=ngpus) @@ -71,7 +39,18 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: assert partitioned, f'expect zigzag_attn_func in graph, but not found.' return graph + if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + args = parser.parse_args() + nnscaler.init() rank_id = torch.distributed.get_rank() world_size = dist.get_world_size() @@ -83,18 +62,21 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: d = 128 device = torch.device(f"cuda:{rank_id}") - # dtype = torch.float16 - dtype = torch.bfloat16 + dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 - q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) - k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) - v = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) + k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) + v = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) dist.broadcast(q, src=0) dist.broadcast(k, src=0) dist.broadcast(v, src=0) dist.barrier() + q.requires_grad = True + k.requires_grad = True + v.requires_grad = True + single_out = wrap_zigzag_attn_func(q, k, v) single_out.retain_grad() single_loss = single_out.sum() @@ -102,22 +84,22 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: model = TestModule() - _in0 = q.detach().clone().requires_grad_() - _in1 = k.detach().clone().requires_grad_() - _in2 = v.detach().clone().requires_grad_() + qq = q.detach().clone().requires_grad_() + kk = k.detach().clone().requires_grad_() + vv = v.detach().clone().requires_grad_() - parallel_model = parallelize(model, dummy_forward_args={"_in0": _in0, "_in1": _in1, "_in2": _in2}, pas_policy=policy, + parallel_model = parallelize(model, dummy_forward_args={"q": qq, "k": kk, "v": vv}, pas_policy=policy, compute_config=ComputeConfig(world_size, world_size), reuse=ReuseType.OVERRIDE) parallel_model = parallel_model.cuda() parallel_model.train() - _in0 = q.detach().clone().requires_grad_() - _in1 = k.detach().clone().requires_grad_() - _in2 = v.detach().clone().requires_grad_() + qq = q.detach().clone().requires_grad_() + kk = k.detach().clone().requires_grad_() + vv = v.detach().clone().requires_grad_() - para_out = parallel_model(_in0, _in1, _in2) + para_out = parallel_model(qq, kk, vv) para_loss = para_out.sum() para_loss.backward() parallel_model.sync_grad() @@ -127,13 +109,15 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: log("out diff", single_out - para_out, rank0_only=True) log("single dq", q.grad, rank0_only=True) - log("multi dq", _in0.grad, rank0_only=True) - log("dq diff", q.grad - _in0.grad, rank0_only=True) + log("multi dq", qq.grad, rank0_only=True) + log("dq diff", q.grad - qq.grad, rank0_only=True) log("single dk", k.grad, rank0_only=True) - log("multi dk", _in1.grad, rank0_only=True) - log("dk diff", k.grad - _in1.grad, rank0_only=True) + log("multi dk", kk.grad, rank0_only=True) + log("dk diff", k.grad - kk.grad, rank0_only=True) log("single dv", v.grad, rank0_only=True) - log("multi dv", _in2.grad, rank0_only=True) - log("dv diff", v.grad - _in2.grad, rank0_only=True) + log("multi dv", vv.grad, rank0_only=True) + log("dv diff", v.grad - vv.grad, rank0_only=True) + + dist.destroy_process_group() diff --git a/examples/llama/lm_models/diff_transformer_modifier.py b/examples/llama/lm_models/diff_transformer_modifier.py index 6e3c5820..17542b37 100644 --- a/examples/llama/lm_models/diff_transformer_modifier.py +++ b/examples/llama/lm_models/diff_transformer_modifier.py @@ -19,21 +19,13 @@ from transformers.utils import is_flash_attn_greater_or_equal_2_10 from .utils import nnscaler_flash_attention_forward +from customized_ops.ring_attention import wrap_ring_attn_func -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) -try: - import os - import sys - sys.path.append(os.path.abspath(os.path.join(os.path.dirname( - os.path.abspath(__file__)), '../../customized_ops/ring_attention'))) - from ring_attn import wrap_ring_attn_func - - def nnscaler_ring_attn_func(query_states, key_states, value_states, *args, **kwargs): - return wrap_ring_attn_func(query_states, key_states, value_states) -except ModuleNotFoundError: - logger.warning("Ring Attention is not import correctly.") +def nnscaler_ring_attn_func(query_states, key_states, value_states, *args, **kwargs): + return wrap_ring_attn_func(query_states, key_states, value_states) def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -181,9 +173,9 @@ def forward( cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: - logger.warning("output_attentions is not supported for NNScalerMultiheadDiffFlashAttn.") + _logger.warning("output_attentions is not supported for NNScalerMultiheadDiffFlashAttn.") if attention_mask: - logger.warning("attention_mask is not supported for NNScalerMultiheadDiffFlashAttn.") + _logger.warning("attention_mask is not supported for NNScalerMultiheadDiffFlashAttn.") bsz, q_len, _ = hidden_states.size() diff --git a/examples/llama/train.py b/examples/llama/train.py index 66deeb40..efbec4e7 100644 --- a/examples/llama/train.py +++ b/examples/llama/train.py @@ -10,7 +10,7 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling from lm_models.utils import nnscaler_lm_init -from chunk_linear_cross_entropy import chunk_linear_cross_entropy +from customized_ops.chunk_linear_cross_entropy import chunk_linear_cross_entropy from nnscaler.utils import set_default_logger_level from nnscaler.cli.trainer import Trainer From c929f270cae3ebf5cd9922641ea547bab3aab813 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 24 Jul 2025 03:55:01 +0000 Subject: [PATCH 1820/1892] Merged PR 2383: [BugFix] Refine broadcasting merge state dict to avoid OOM Send one parameter at a time, instead of sending all at one time to avoid OOM. --- nnscaler/autodist/dp_solver.cpp | 13 ++++++ nnscaler/cli/mixed_module.py | 6 ++- nnscaler/cli/train_hook.py | 33 ++++++++++++++- nnscaler/cli/trainer.py | 74 ++++++++++++++++++++++++++++++--- nnscaler/cli/trainer_args.py | 40 ++++++++++++++++-- requirements-dev.txt | 1 + tests/cli/test_train_args.py | 15 +++++++ tests/conftest.py | 10 +++++ 8 files changed, 180 insertions(+), 12 deletions(-) diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index 334617e5..f0b8e6d7 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -623,3 +623,16 @@ PYBIND11_MODULE(dp_solver, m) { .def("solve", &DPSolver::solve) .def("get_results", &DPSolver::get_results); } + +// the following is used to build the cpp file in cppimport +// which is just for local development convenience +// For production, `setup.py` will be used to build the cpp file +/* +<% +setup_pybind11(cfg) +cfg['extra_compile_args'] = ['-std=c++11'] +cfg['extra_compile_args'] = ['-O3'] +cfg['extra_compile_args'] = ['-pthread'] +cfg['dependencies'] = ['dp_solver.h'] +%> +*/ diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index fd2b4601..7849a67f 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -23,8 +23,10 @@ def fork_rng(): if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - return torch.random.fork_rng([rank]) + # only capture the random state of the current device + # which is good enough for us + device = torch.cuda.current_device() + return torch.random.fork_rng([device]) else: return torch.random.fork_rng() diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index aebb44d3..620cedf7 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, Dict, List, TYPE_CHECKING, TypedDict, Optional import torch @@ -10,6 +10,14 @@ from nnscaler.cli.trainer_args import AggregatedOutputs +class StepMetrics(TypedDict): + train_loss: float + loss: float # alias for train_loss + lr: float + gnorm: float + train_wall: float # wall time for training step + + class TrainHook: """ Note: All hooks are called in all ranks, and the inputs of hooks are only the local data. @@ -52,6 +60,21 @@ def on_epoch_end(self, trainer: 'Trainer', epoch: int) -> None: epoch: the current epoch index """ + def on_step_start(self, trainer: 'Trainer', epoch: int, idx: int) -> None: + """ + Called at the beginning of each step + Args: + idx: the index of current step + """ + + def on_step_end(self, trainer: 'Trainer', epoch: int, idx: int, step_metrics: StepMetrics) -> None: + """ + Called at the end of each step (validation and checkpoint saving are not included) + Args: + idx: the index of current step + step_metrics: the metrics of the current step + """ + def on_train_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: """ Called at the beginning of each training step @@ -212,6 +235,14 @@ def on_epoch_end(self, trainer: 'Trainer', epoch: int) -> None: for hook in self.hooks: hook.on_epoch_end(trainer, epoch) + def on_step_start(self, trainer: 'Trainer', epoch: int, idx: int) -> None: + for hook in self.hooks: + hook.on_step_start(trainer, epoch, idx) + + def on_step_end(self, trainer: 'Trainer', epoch: int, idx: int, step_metrics: StepMetrics) -> None: + for hook in self.hooks: + hook.on_step_end(trainer, epoch, idx, step_metrics) + def on_train_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: for hook in self.hooks: hook.on_train_step_start(trainer, batches, idx) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 4338ad83..7a176e3d 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -274,6 +274,69 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): } return merged_state_dict + def _broadcast_merged_state_dict(self, state_dict: Dict[str, Any]): + """ + Broadcast the merged state dict to all ranks. + We can't broadcast the whole state_dict at once, because it may be too large, and leads to OOM. + Here we will break the model and optimizer state_dict into smaller pieces and broadcast them one by one. + Please note we use `torch.distributed.broadcast_object_list` to broadcast the state_dict (including tensors inside). + """ + + def _broadcast_keys(sdict: Dict[str, Any], set_keys=True): + if self.rank == 0: + state_keys = list(sdict.keys()) + else: + state_keys = None + state_key_list = [state_keys] + torch.distributed.broadcast_object_list(state_key_list, src=0) + state_keys = state_key_list[0] + if set_keys and self.rank != 0: + for key in state_keys: + sdict[key] = {} # assume the values are empty dicts + return state_keys + + def _broadcast_value(sdict, key): + if self.rank == 0: + value_list = [sdict[key]] + else: + value_list = [None] + torch.distributed.broadcast_object_list(value_list, src=0) + if self.rank != 0: + sdict[key] = value_list[0] + + def _broadcast_values(sdict, keys): + for key in keys: + _broadcast_value(sdict, key) + + if self.rank == 0: + if state_dict is None: + raise ValueError("state_dict should not be None in rank 0 when broadcasting") + else: + if state_dict is not None: + raise ValueError("state_dict should be None in other ranks when broadcasting") + state_dict = {} + + state_keys = _broadcast_keys(state_dict) + + for skey in state_keys: + logger.info(f"Broadcasting {skey}.") + if skey == 'optimizer': + opt_keys = _broadcast_keys(state_dict['optimizer']) + opt_keys_without_state = [ + k for k in opt_keys if k != 'state' + ] + _broadcast_values(state_dict['optimizer'], opt_keys_without_state) + idxs = _broadcast_keys(state_dict['optimizer']['state']) + for idx in idxs: + idx_keys = _broadcast_keys(state_dict['optimizer']['state'][idx]) + _broadcast_values(state_dict['optimizer']['state'][idx], idx_keys) + elif skey == 'model': + model_keys = _broadcast_keys(state_dict['model']) + _broadcast_values(state_dict['model'], model_keys) + else: + _broadcast_value(state_dict, skey) + return state_dict + @classmethod def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): merged_state_dict = cls._merge_checkpoint(checkpoint_files) @@ -319,12 +382,9 @@ def _load_checkpoint(self): state_dict = self._merge_checkpoint(list(rank_ckpt_files.values())) else: state_dict = None - state_dict_list = [state_dict] logger.info(f"Broadcasting merged checkpoint to all ranks.") - # TODO: it will be easily out of memory when the model is large - # We should broadcast the state_dict one parameter by one - torch.distributed.broadcast_object_list(state_dict_list, src=0) - state_dict = state_dict_list[0] + state_dict = self._broadcast_merged_state_dict(state_dict) + logger.info(f"Broadcasted merged checkpoint to all ranks.") else: resume_from = resume_from / f'{self.rank}.ckpt' state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) @@ -846,6 +906,8 @@ def _train_epoch(self, epoch: int) -> None: step_stat: Optional[_StepStat] = None for i, batches in data_iter: idx = i + resume_from_idx + self.hook.on_step_start(self, epoch, idx) + step_start_at = time.perf_counter() step_stat = _StepStat() step_metrics = {} @@ -932,6 +994,8 @@ def _train_epoch(self, epoch: int) -> None: logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) step_metrics = {} + self.hook.on_step_end(self, epoch, idx, step_metrics) + # validate and save checkpoint if self.train_args.checkpoint.every_n_train_steps and \ self.train_status.finished_train_steps % self.train_args.checkpoint.every_n_train_steps == 0: diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 62d1b9b8..8552ad53 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -103,6 +103,17 @@ def _resolve_precision(precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE] return precision +def _factory_normalize(value): + if isinstance(value, dict) and _TYPE_KEY in value: + kwargs = { + k: v for k, v in value.items() + if k != _TYPE_KEY + } + return load_type(value[_TYPE_KEY])(**kwargs) + else: + return value + + def fix_input(input, input_dtype=None): if isinstance(input, dict): return {k: fix_input(v, input_dtype) for k, v in input.items()} @@ -242,8 +253,15 @@ class ModuleParallelizeConfig: gen_reuse: Optional[str] = None pas_policy: Optional[str] = None broadcast_strategy: Optional[str] = None - instance_name: Optional[str] = None - precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None + # sometimes you want to dynamically set the instance name + # for example, you can set it to the hash of related files + # In that case, we can pass a dict with callable __type field. + instance_name: Optional[str] = field(default=None, metadata={ + 'normalize': _factory_normalize + }) + precision: Optional[Dict[_TENSOR_TYPE, _PRECISION_TYPE]] = field(default=None, metadata={ + 'skip_deserialization': True, + }) def __post_init__(self): if not self.type: @@ -366,11 +384,17 @@ def __post_init__(self): @dataclass class ResumeOptions: - checkpoint: str = 'last' + # sometimes you want to dynamically set checkpoint path + # for example, you can set it to finetune model if no `last` checkpoint exists + checkpoint: str = field(default='last', metadata={ + 'normalize': _factory_normalize + }) # the full qualified name of the function to # convert the checkpoint to nnscaler format # It should be `Callable[[Dict[str, Any]], Dict[str, Any]]` # Only applied when `checkpoint` is a file. + # Please note you should handle the case + # when checkpoint file comes from a factory method convert_fn: Optional[str] = None # whether to merge the checkpoint files # Only used when `checkpoint` is a directory. @@ -498,6 +522,9 @@ class HookMapConfig: on_epoch_start: str = None on_epoch_end: str = None + on_step_start: str = None + on_step_end: str = None + on_train_step_start: str = None on_train_step_end: str = None on_val_step_start: str = None @@ -556,7 +583,12 @@ class TrainerArgs(PrecisionMixin, PolicyMixin): gen_reuse: str = 'auto' pas_policy: str = 'autodist' broadcast_strategy: str = 'all' - instance_name: str = None + # sometimes you want to dynamically set the instance name + # for example, you can set it to the hash of related files + # In that case, we can pass a dict with callable __type field. + instance_name: Optional[str] = field(default=None, metadata={ + 'normalize': _factory_normalize + }) # compile: compile the model but not training # run: compile and run the model run_mode: str = 'run' diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f64d831..7d181749 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,3 +18,4 @@ yapf wandb tensorboard mosaicml-streaming +cppimport diff --git a/tests/cli/test_train_args.py b/tests/cli/test_train_args.py index aabce253..389e5d75 100644 --- a/tests/cli/test_train_args.py +++ b/tests/cli/test_train_args.py @@ -51,3 +51,18 @@ def test_arg_merge_resolve(): assert args.dataset.train_args['dim'] == 22 assert args.dataset.val_args['dim'] == 22 assert args.vars['hello'] == args.compute_config.plan_ngpus + + +def gen_instance_name(stem): + return f'instance_{stem}' + + +def test_dyn_str_config(): + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + args = TrainerArgs.from_cli(['-f', config_path, + '--instance_name.__type', 'tests.cli.test_train_args.gen_instance_name', + '--instance_name.stem', 'p$(compute_config.plan_ngpus)', + '--compute_config.plan_ngpus', '1', + '--global_batch_size!', + ]) + assert args.instance_name == 'instance_p1' diff --git a/tests/conftest.py b/tests/conftest.py index e42de300..b581d41c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,16 @@ from nnscaler.graph.parser import FxModuleParser +try: + import nnscaler.autodist.dp_solver +except ImportError: + from pathlib import Path + from cppimport import build_filepath + import nnscaler.autodist + # lazy build the cpp file if it is not built yet + build_filepath(Path(nnscaler.autodist.__file__).with_name("dp_solver.cpp"), fullname="nnscaler.autodist.dp_solver") + + @pytest.fixture(autouse=True) def clean_generated_files(): print('hello') From df303c179f7f18e9b0e27e76b4dc08f5c5bf41ad Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 25 Jul 2025 06:32:57 +0000 Subject: [PATCH 1821/1892] Merged PR 2384: [BugFix] dij implementation and deepseek coder v2 lite example resolve git issues - https://github.com/microsoft/nnscaler/issues/36 - https://github.com/microsoft/nnscaler/issues/37 --- .../modeling/modeling_deepseek_modifier.py | 2 +- nnscaler/graph/gener/rvd/inter.py | 2 +- nnscaler/graph/gener/rvd/intra.py | 2 +- tests/graph/gener/test_intra_rvd.py | 204 ++++++++++++++++-- 4 files changed, 192 insertions(+), 18 deletions(-) diff --git a/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py b/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py index 9ae7da5a..54774044 100644 --- a/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py +++ b/examples/deepseek_coder_v2_lite/modeling/modeling_deepseek_modifier.py @@ -383,7 +383,7 @@ def moe_route(hidden_states: torch.Tensor, weight: torch.Tensor, # it difficult to handle it correctly without modifying the code # 3. dispatch by allgather is used currently, which is compatible with the replicated # moe_route plan -register_op(f'n^ l^ h^, e^ h^ -> (n^ l^) k^, (n^ l^) k^, 1')(moe_route) +register_op(f'n^ l^ h^, e^ h^ -> (n^ l^) 64, (n^ l^) 64, 1')(moe_route) def nnscaler_llama_flash_attention_forward_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str: diff --git a/nnscaler/graph/gener/rvd/inter.py b/nnscaler/graph/gener/rvd/inter.py index a57a8a21..e078be3b 100644 --- a/nnscaler/graph/gener/rvd/inter.py +++ b/nnscaler/graph/gener/rvd/inter.py @@ -364,7 +364,7 @@ def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, cost_fn: Optional[Ca min_cost, visit = np.inf, None for idx in unvisited: if cost[idx] < min_cost: - min_cost = idx + min_cost = cost[idx] visit = idx if visit is None: break for neighbor in np.where(edges[visit] != np.inf)[0]: diff --git a/nnscaler/graph/gener/rvd/intra.py b/nnscaler/graph/gener/rvd/intra.py index 2783019e..1cba9a34 100644 --- a/nnscaler/graph/gener/rvd/intra.py +++ b/nnscaler/graph/gener/rvd/intra.py @@ -420,7 +420,7 @@ def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, min_cost, visit = np.inf, None for idx in unvisited: if cost[idx] < min_cost: - min_cost = idx + min_cost = cost[idx] visit = idx if visit is None: break # for remaining states that cannot reach for neighbor in np.where(edges[visit] != np.inf)[0]: diff --git a/tests/graph/gener/test_intra_rvd.py b/tests/graph/gener/test_intra_rvd.py index 1a784a1a..e71e90bc 100644 --- a/tests/graph/gener/test_intra_rvd.py +++ b/tests/graph/gener/test_intra_rvd.py @@ -77,6 +77,26 @@ def test_transition_space(tmp_path): def test_one_f_case(): + """ + Test complex RVD transformation: (1,4,1,1,2) → (2,1,2,1,2) + + Note: This test case has multiple optimal paths with equal costs, so the specific + path selection depends on algorithm implementation details (e.g., Dijkstra's + tie-breaking behavior). Once a path is chosen, the device assignments are + deterministic based on the RVD layout structure. + + Current path (one of multiple possible optimal paths): + - Step 1: (1,4,1,1,2) → (1,2,2,1,2) - v2d transformation via reduce_scatter + - Step 2: (1,2,2,1,2) → (2,1,2,1,2) - v2r transformation via all_reduce + + Alternative path with same cost exists: + - (1,4,1,1,2) → (1,1,4,1,2) → (2,1,2,1,2) + + The test verifies the current algorithm's choice and the resulting device mappings. + If the path selection algorithm changes, this test may need to be updated to + reflect the new chosen path, but the device assignments for any given path + should remain deterministic. + """ fshape = [128, 256, 512] src_r, src_v, src_d = 1,4,(1,1,2) @@ -94,26 +114,35 @@ def test_one_f_case(): fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=cdevs) rvds = IntraPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) - # reduce-scatter(v2d) and then all-gather(d2r) - assert rvds == ((1, 4, 1, 1, 2), (1, 1, 4, 1, 2), (2, 1, 2, 1, 2)) + # reduce-scatter(v2d) and then all-reduce(v2r) + # Note: This is one of multiple possible optimal paths + assert rvds == ((1, 4, 1, 1, 2), (1, 2, 2, 1, 2), (2, 1, 2, 1, 2)) fprims = IntraPathFinder.path(fp_rvd, fc_rvd) - assert len(fprims) == 6 - # (1, 4, 1, 1, 2) => (1, 1, 4, 1, 2) + assert len(fprims) == 8 + + # Step 1: (1, 4, 1, 1, 2) => (1, 2, 2, 1, 2) via reduce_scatter + # Once the path is determined, device mappings are deterministic based on RVDLayout.grid # here the device align is found with `inner_transpose` alternative. assert fprims[0].signature == 'nnscaler.runtime.adapter.reduce_scatter' - assert fprims[0].device == [0, 2, 4, 6] + assert fprims[0].device == [0, 2] # Deterministic given the chosen path and layout assert fprims[1].signature == 'nnscaler.runtime.adapter.reduce_scatter' - assert fprims[1].device == [1, 3, 5, 7] - # (1, 1, 4, 1, 2), (2, 1, 2, 1, 2) - assert fprims[2].signature == 'nnscaler.runtime.adapter.all_gather' - assert fprims[2].device == [0, 4] - assert fprims[3].signature == 'nnscaler.runtime.adapter.all_gather' - assert fprims[3].device == [1, 5] - assert fprims[4].signature == 'nnscaler.runtime.adapter.all_gather' - assert fprims[4].device == [2, 6] - assert fprims[5].signature == 'nnscaler.runtime.adapter.all_gather' - assert fprims[5].device == [3, 7] + assert fprims[1].device == [1, 3] # Deterministic given the chosen path and layout + assert fprims[2].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[2].device == [4, 6] # Deterministic given the chosen path and layout + assert fprims[3].signature == 'nnscaler.runtime.adapter.reduce_scatter' + assert fprims[3].device == [5, 7] # Deterministic given the chosen path and layout + + # Step 2: (1, 2, 2, 1, 2) => (2, 1, 2, 1, 2) via all_reduce + # These device assignments follow from the intermediate layout structure + assert fprims[4].signature == 'nnscaler.runtime.adapter.all_reduce' + assert fprims[4].device == [0, 4] # Deterministic based on intermediate RVD layout + assert fprims[5].signature == 'nnscaler.runtime.adapter.all_reduce' + assert fprims[5].device == [1, 5] # Deterministic based on intermediate RVD layout + assert fprims[6].signature == 'nnscaler.runtime.adapter.all_reduce' + assert fprims[6].device == [2, 6] # Deterministic based on intermediate RVD layout + assert fprims[7].signature == 'nnscaler.runtime.adapter.all_reduce' + assert fprims[7].device == [3, 7] # Deterministic based on intermediate RVD layout def test_f_reducescatter_alltoall(): @@ -219,6 +248,151 @@ def test_f_reducescatter_alltoall(): assert fprims[3]._outputs[1] == fc_subtensors[fprims[3]._outputs[1].device] +def test_simple_v2r_transformation(): + """ + Test simple value-to-replica transformation: (1,8,1,1,1) → (8,1,1,1,1) + This is a stable test case with deterministic path selection. + Only one optimal path exists, making this test robust to algorithm changes. + """ + fshape = [32, 64, 128] + src_rvd = (1, 8, 1, 1, 1) + dst_rvd = (8, 1, 1, 1, 1) + ndevs = 8 + + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + pdevs = list(range(ndevs)) + + src_r, src_v = src_rvd[0], src_rvd[1] + src_d = src_rvd[2:] + dst_r, dst_v = dst_rvd[0], dst_rvd[1] + dst_d = dst_rvd[2:] + + fp_rvd = RVDLayout.grid(ftensor, r=src_r, v=src_v, dims=src_d, devices=pdevs) + fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=pdevs) + + # Get path and communication operations + rvds = IntraPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) + fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + + # Verify path properties + assert rvds[0] == src_rvd + assert rvds[-1] == dst_rvd + assert len(rvds) == 2 # Direct transformation + assert len(fprims) == 1 # Single operation + + # Verify operation types (more stable than specific devices) + expected_ops = {'all_reduce': 1} + actual_ops = {} + for p in fprims: + op_type = p.signature.split('.')[-1] + actual_ops[op_type] = actual_ops.get(op_type, 0) + 1 + assert actual_ops == expected_ops + + # Verify device coverage + all_devices = set() + for p in fprims: + all_devices.update(p.device) + assert all_devices == set(range(8)) + + +def test_simple_r2d_transformation(): + """ + Test simple replica-to-dimension transformation: (8,1,1,1,1) → (1,1,8,1,1) + This is a stable test case with deterministic path selection. + Only one optimal path exists, making this test robust to algorithm changes. + """ + fshape = [32, 64, 128] + src_rvd = (8, 1, 1, 1, 1) + dst_rvd = (1, 1, 8, 1, 1) + ndevs = 8 + + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + pdevs = list(range(ndevs)) + + src_r, src_v = src_rvd[0], src_rvd[1] + src_d = src_rvd[2:] + dst_r, dst_v = dst_rvd[0], dst_rvd[1] + dst_d = dst_rvd[2:] + + fp_rvd = RVDLayout.grid(ftensor, r=src_r, v=src_v, dims=src_d, devices=pdevs) + fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=pdevs) + + # Get path and communication operations + rvds = IntraPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) + fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + + # Verify path properties + assert rvds[0] == src_rvd + assert rvds[-1] == dst_rvd + assert len(rvds) == 2 # Direct transformation + assert len(fprims) == 1 # Single operation + + # Verify operation types (more stable than specific devices) + expected_ops = {'chunk': 1} + actual_ops = {} + for p in fprims: + op_type = p.signature.split('.')[-1] + actual_ops[op_type] = actual_ops.get(op_type, 0) + 1 + assert actual_ops == expected_ops + + # Verify device coverage + all_devices = set() + for p in fprims: + all_devices.update(p.device) + assert all_devices == set(range(8)) + + +def test_partial_v2r_transformation(): + """ + Test partial value-to-replica transformation: (1,4,1,1,2) → (4,1,1,1,2) + This is a stable test case with deterministic path selection. + The transformation is simple and direct, making the path selection robust. + """ + fshape = [32, 64, 128] + src_rvd = (1, 4, 1, 1, 2) + dst_rvd = (4, 1, 1, 1, 2) + ndevs = 8 + + ftensor = IRFullTensor(shape=fshape, name='tensor', requires_grad=False) + pdevs = list(range(ndevs)) + + src_r, src_v = src_rvd[0], src_rvd[1] + src_d = src_rvd[2:] + dst_r, dst_v = dst_rvd[0], dst_rvd[1] + dst_d = dst_rvd[2:] + + fp_rvd = RVDLayout.grid(ftensor, r=src_r, v=src_v, dims=src_d, devices=pdevs) + fc_rvd = RVDLayout.grid(ftensor, r=dst_r, v=dst_v, dims=dst_d, devices=pdevs) + + # Get path and communication operations + rvds = IntraPathFinder.get_optimal_path(ftensor, src_rvd, dst_rvd) + fprims = IntraPathFinder.path(fp_rvd, fc_rvd) + + # Verify path properties + assert rvds[0] == src_rvd + assert rvds[-1] == dst_rvd + assert len(rvds) == 2 # Direct transformation + assert len(fprims) == 2 # Two parallel operations + + # Verify operation types (more stable than specific devices) + expected_ops = {'all_reduce': 2} + actual_ops = {} + for p in fprims: + op_type = p.signature.split('.')[-1] + actual_ops[op_type] = actual_ops.get(op_type, 0) + 1 + assert actual_ops == expected_ops + + # Verify device coverage + all_devices = set() + for p in fprims: + all_devices.update(p.device) + assert all_devices == set(range(8)) + + # Verify each operation involves exactly 4 devices (value=4) + for p in fprims: + assert len(p.device) == 4 + + def print_prims(fp_rvd, fc_rvd, fprims): print('fp_rvd:') for f in fp_rvd.mat.flatten(): From f01591e2b7a49a9bb19de45924b7ade7308cb4ad Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 25 Jul 2025 06:44:58 +0000 Subject: [PATCH 1822/1892] Merged PR 2385: [BugFix] fix arange op --- nnscaler/graph/function/function.py | 50 +++++++++++++++++----- tests/graph/function/test_functions.py | 3 ++ tests/parallel_module/test_gencode.py | 58 ++++++++++++++++++-------- 3 files changed, 83 insertions(+), 28 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index c4bd3366..c083ec9c 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -262,21 +262,49 @@ def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Uni return IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) -def Arange(*args, out=None, dtype=None, layout=None, +def Arange(*args, start=None, end=None, step=None, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): """ torch.arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor """ assert layout is None - if len(args) == 1: - start, end, step = 0, args[0], 1 + if len(args) == 0: + if end is None: + # torch.arange() is invalid + raise ValueError("torch.arange() requires end to be set") + if start is None and step is not None: + # torch.arange(end=end, step=step) is invalid + raise ValueError("start should be set when step is provided") + resolved_start, resolved_end, resolved_step = \ + (0 if start is None else start), end, (1 if step is None else step) + elif len(args) == 1: + if start is not None: + # torch.arange(number, start=start) is invalid + # torch.arange(number, start=start, end=end) is invalid + # torch.arange(number, start=start, step=step) is invalid + # torch.arange(number, start=start, end=end, step=step) is invalid + raise ValueError("start should not be set when only one argument is provided") + if end is None and step is not None: + # torch.arange(number, step=step) is invalid + raise ValueError("end should be set when step is provided") + if end is None: + # case 1: torch.arange(number) # number is end + resolved_start, resolved_end, resolved_step = 0, args[0], 1 + else: + # case 2: torch.arange(number, end=end) # number is start + # or torch.arange(number, end=end, step=step) # number is start + resolved_start, resolved_end, resolved_step = args[0], end, (1 if step is None else step) elif len(args) == 2: - start, end, step = args[0], args[1], 1 + if start is not None or end is not None: + raise ValueError("start, end should not be set when two arguments are provided") + resolved_start, resolved_end, resolved_step = args[0], args[1], (1 if step is None else step) elif len(args) == 3: - start, end, step = args + resolved_start, resolved_end, resolved_step = args + if start is not None or end is not None or step is not None: + raise ValueError("start, end, step should not be set when three arguments are provided") else: - raise RuntimeError(f'Invalid number {len(args)} of args in Arange.') - return CubeArange(start, end, step, dtype, requires_grad=requires_grad) + raise ValueError(f'Invalid number {len(args)} of args in Arange.') + return CubeArange(resolved_start, resolved_end, resolved_step, dtype, requires_grad=requires_grad) def CubeLinspace(start: Union[int, IRObject], end: Union[int, IRObject], steps: Union[int, IRObject], @@ -2466,7 +2494,7 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: # tensor slice if isinstance(obj, IRTensor): # TODO: support general tensor slicing: https://pytorch.org/cppdocs/notes/tensor_indexing.html - index = (index,) if isinstance(index, (int, slice, IRTensor, IRObject)) else tuple(index) + index = (index,) if isinstance(index, (int, slice, IRTensor, IRObject, type(None))) else tuple(index) return FullSlice(obj, *index) # object slice if isinstance(obj, IRObject): @@ -3363,8 +3391,10 @@ def Item(input, signature = None): """ torch.Tensor.item() """ - anno = '? -> ?' - return IRDimops(Item, 'item', signature, [anno], [input], constant_foldable=False) + # set output to IRObject.missing, + # because the output is unknown here. + # It will be filled with real value in parser. + return IRPyFunc(signature, inputs=[input], outputs=[IRObject.missing], constant_foldable=False) def DictKeys(o: Union[Dict, IRObject], signature=None): diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 5106768d..319e3b6f 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -221,6 +221,9 @@ def test_Where(): def test_FullSlice(): + op = F.FullSlice(IRTensor([2, 3, 4]), None) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c, ? -> 1 a b c' + op = F.FullSlice(IRTensor([2, 3, 4]), 1, [1.2, -1], 2) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, ?, ?, ? -> 2' diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index ef1fc2d1..79ae00cf 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -683,27 +683,48 @@ def forward(self, a): return a.item() -def _gencode_item_function_worker(tempdir): - init_distributed() - m_new = parallelize( - ItemModule(), - { - 'a': torch.tensor([5.0]), - }, - 'dp', - ComputeConfig(1, 1, constant_folding=True), - gen_savedir=tempdir, - load_module=True - ) - assert m_new is not None - # never fold torch.Tensor.item() to constant - assert _gencode_contains(tempdir, ItemModule, 0, '.*torch.Tensor.item.*') +@replace_all_device_with('cpu') +def test_codegen_item(): + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + ItemModule(), + { + 'a': torch.tensor([5.0]), + }, + 'dp', + ComputeConfig(1, 1, constant_folding=True), + gen_savedir=tempdir, + load_module=False + ) + assert m_new is None + # never fold torch.Tensor.item() to constant + assert _gencode_contains(tempdir, ItemModule, 0, '.*torch.Tensor.item.*') -@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') -def test_codegen_item(): +class ArangeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, end): + return torch.arange(start=0, end=end, dtype=torch.float32) + + +@replace_all_device_with('cpu') +def test_codegen_arange(): with tempfile.TemporaryDirectory() as tempdir: - launch_torchrun(1, _gencode_item_function_worker, tempdir) + m_new = parallelize( + ArangeModule(), + { + 'end': 5, + }, + 'dp', + ComputeConfig(1, 1, constant_folding=True), + gen_savedir=tempdir, + load_module=False + ) + assert m_new is None + # never fold torch.Tensor.item() to constant + assert _gencode_contains(tempdir, ArangeModule, 0, '.*nnscaler.runtime.function.arange\(start=0, end=.*, step=1.*') class MinModule(torch.nn.Module): @@ -713,6 +734,7 @@ def __init__(self): def forward(self, a, b): return torch.min(a, b) + def _gencode_min_function_worker(tempdir): init_distributed() m_new = parallelize( From 7728b2ad642b83684373fa9b0429fdda16e8de2b Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 6 Aug 2025 02:43:09 +0000 Subject: [PATCH 1823/1892] Merged PR 2386: [BugFix] torch.compile support 1. clarify torch.compile support 2. revert the patcher in tracer to support the ccase when autograd functions use torch.compile --- docs/source/register_custom_op.md | 44 ++++ nnscaler/graph/parser/parser.py | 15 +- nnscaler/graph/parser/register.py | 6 +- nnscaler/graph/tracer/concrete_tracer.py | 6 +- nnscaler/graph/tracer/wrap_utils.py | 8 + tests/parallel_module/test_gencode.py | 1 - .../test_gencode_torch_compile.py | 192 ++++++++++++++++++ tests/parallel_module/test_pyfunc.py | 8 +- 8 files changed, 271 insertions(+), 9 deletions(-) create mode 100644 tests/parallel_module/test_gencode_torch_compile.py diff --git a/docs/source/register_custom_op.md b/docs/source/register_custom_op.md index 75532cef..e81103ab 100644 --- a/docs/source/register_custom_op.md +++ b/docs/source/register_custom_op.md @@ -120,6 +120,50 @@ This function has the following parameters: good distributed plan. Default: None. +## `torch.autograd.Function` + +If you are using `torch.autograd.Function`, you should register it(internally its `apply` function is registered). +Otherwise it will be replicated by default, which may lead to poor performance. + +``` +import torch +import nnscaler + +annotation = ... + +@nnscaler.register_op(annotation) +class MyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, *args, **kwargs): + ... # your forward implementation + + @staticmethod + def backward(ctx, *grad_outputs): + ... # your backward implementation +``` +If you can't use class decorator, you can also register like this: +``` +nnscaler.register_op(annotation)(MyFunction) +``` +or +``` +nnscaler.register_op(annotation)(MyFunction.apply) +``` + +## `torch.compile` functions + +If you are using `torch.compile` for better performance, you must register the function to avoid tracing into the compiling logic, which will cause the tracing to fail. +```python +import torch +import nnscaler + +@torch.compile +def my_function(x: torch.Tensor) -> torch.Tensor: + return x * 2 + +nnscaler.register_op('* -> *')(my_function) + +``` ## Dimension Annotion Operations diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 48001642..02c52611 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -262,7 +262,12 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule _logger.warning(f'Find unknown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fsig ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) - # case2: python runtime function + # case2: custom autograd function + elif FxModuleParser._is_custom_autograd_op(node): + # custom autograd function + _logger.warning(f'Find unknown custom autograd operation: {fsig}. You should register it with nnscaler.register_op') + ir_node = IRFwOperation(fsig, fsig, input_vals, 1, **kwargs) + # case3: python runtime function else: _logger.warning(f'Set python runtime function: {fsig}') is_constant = True @@ -549,3 +554,11 @@ def _is_torch_autograd_op(node: torch.fx.Node, frame: Frame, signature: str) -> # an IRTensor, thus cannot be considered as a pytorch autograd operator. return signature.startswith('torch.') and \ isinstance(frame.get_var(node.name), IRFullTensor) + + @staticmethod + def _is_custom_autograd_op(node: torch.fx.Node) -> bool: + node_target = node.target + return callable(node_target) \ + and getattr(node_target, '__name__', None) == 'apply' \ + and isinstance(getattr(node_target, '__self__', None), Type) \ + and issubclass(node_target.__self__, torch.autograd.Function) diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index ae128a3f..d7efea3d 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -14,7 +14,7 @@ from torch import ScriptFunction from nnscaler.graph.function.dimops import IRDimops, OpAnno, TransformRule -from nnscaler.graph.tracer.wrap_utils import is_autograd_apply +from nnscaler.graph.tracer.wrap_utils import is_autograd_apply, is_autograd_op from nnscaler.ir.operator import IRTensor, IRFwOperation _logger = logging.getLogger(__name__) @@ -172,6 +172,10 @@ def decorator(fn: Callable): if not callable(fn): raise TypeError("Expected a runtime function") + if inspect.isclass(fn) and is_autograd_op(fn): + _ = decorator(fn.apply) # register `apply` method of the autograd function + return fn # return the class itself + # step 1. get function signature and inputs def get_import_path(fn: Callable) -> str: if is_autograd_apply(fn): diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index d5bb0656..aab86ba0 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -189,8 +189,10 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): args_unwrapped = pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, args) kwargs_unwrapped = pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, kwargs) - - if self.need_revert(target): + # A lot of autograd functions are using torch.compile + # We must revert the patcher to the original function so torch.compile can work. + # (For non-torch.compile functions, this is not necessary, but it is safe to do so.) + if self.need_revert(target) or wrap_utils.is_autograd_apply(target): with self.patcher.revert(): value_unwrapped, args_run, kwargs_run = self.strategy.run_target(kind, target, args_unwrapped, kwargs_unwrapped) else: diff --git a/nnscaler/graph/tracer/wrap_utils.py b/nnscaler/graph/tracer/wrap_utils.py index 318dfb6b..3be03593 100644 --- a/nnscaler/graph/tracer/wrap_utils.py +++ b/nnscaler/graph/tracer/wrap_utils.py @@ -408,6 +408,14 @@ def is_autograd_apply(func) -> bool: and orig_func.isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) +def is_autograd_op(func) -> bool: + try: + return issubclass(func, torch.autograd.Function) + except TypeError: + # if func is not a class, then it is not an autograd function + return False + + def create_wrapped_autograd_apply(default_tracer: 'ConcreteTracer'): @classmethod @functools.wraps(orig_func.torch_agfunc_apply) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 79ae00cf..4a286dd7 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -352,7 +352,6 @@ def test_codegen_recompute_kwargs(): ) - class DefaultArgsModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/tests/parallel_module/test_gencode_torch_compile.py b/tests/parallel_module/test_gencode_torch_compile.py new file mode 100644 index 00000000..700bac9d --- /dev/null +++ b/tests/parallel_module/test_gencode_torch_compile.py @@ -0,0 +1,192 @@ +import tempfile +import pytest +import torch +import math +import torch.nn.functional as F + +from nnscaler import parallelize, ComputeConfig, register_op + +from tests.utils import replace_all_device_with + +from .test_gencode import _gencode_contains, print_gencode + + +class ActQuant(torch.autograd.Function): + + @staticmethod + @torch.compile + def forward(ctx, x): + dtype = x.dtype + x = x.float() + s = 127 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) + x = (x * s).round().clamp(-128, 127) / s + return x.to(dtype) + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + return grad_input + + +class ActQuantInt4(torch.autograd.Function): + + @staticmethod + @torch.compile + def forward(ctx, x): + dtype = x.dtype + x = x.float() + s = math.sqrt(7) / x.abs().mean(dim=-1, keepdim=True).clamp_(min=1e-5) + x = (x * s).round().clamp(-8, 7) / s + return x.to(dtype) + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + return grad_input + + +class ActQuantInt2(torch.autograd.Function): + + @staticmethod + @torch.compile + def forward(ctx, x): + dtype = x.dtype + x = x.float() + s = math.sqrt(3) / x.abs().mean(dim=-1, keepdim=True).clamp_(min=1e-5) + x = (x * s).round().clamp(-4, 3) / s + return x.to(dtype) + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + return grad_input + +# Two ways to register autograd functions: +# 1. Use `@register_op` decorator +# 2. Use `register_op` function directly, and pass `Function` or `Function.apply`. + +@register_op('*^ -> *^') +class WeightQuant(torch.autograd.Function): + + @staticmethod + @torch.compile + def forward(ctx, x): + dtype = x.dtype + x = x.float() + s = 1.0 / x.abs().mean().clamp_(min=1e-5) + x = (x * s).round().clamp(-1, 1) / s + return x.to(dtype) + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + return grad_input + + +register_op('*^ -> *^')(ActQuant) +register_op('*^ -> *^')(ActQuantInt2.apply) +register_op('*^ -> *^')(ActQuantInt4.apply) + + +class BitLinear(torch.nn.Linear): + def __init__(self, in_features: int, out_features: int, split_size: list[int], bias: bool = True, act_bits: int = 8): + super(BitLinear, self).__init__(in_features, out_features, bias) + self.split_size = split_size + self.act_bits = act_bits + assert sum(split_size) == out_features + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if len(self.split_size) == 1: + weight = WeightQuant.apply(self.weight) + else: + weight = torch.split(self.weight, self.split_size, dim=0) + weight = [WeightQuant.apply(w) for w in weight] + weight = torch.cat(weight, dim=0) + if self.act_bits == 8: + input = ActQuant.apply(x) + elif self.act_bits == 4: + input = ActQuantInt4.apply(x) + elif self.act_bits == 2: + input = ActQuantInt2.apply(x) + else: + raise ValueError(f"Unsupported act_bits: {self.act_bits}") + return F.linear(input, weight, self.bias) + + +# looks torch.compile needs gpu, not sure why. +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_codegen_compile_apply(tmp_path): + m = BitLinear(64, 128, [64, 64], bias=False) + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert True + + +@register_op('* -> *') +@torch.compile +def f(x): + return x * 2 + + +class Module1(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return self.linear(x) + f(x) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_codegen_compile_f(): + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + Module1(), + {'x': torch.randn(3, 3)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False + ) + # parallelize will succeed. + assert True + + +@torch.compile +def g(x): + # g is not registered in nnscaler + # RuntimeError will be raised + # when parallelize is called. + return x * 2 + + +class Module2(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return self.linear(x) + g(x) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_codegen_compile_failed_g(): + with pytest.raises(RuntimeError), tempfile.TemporaryDirectory() as tempdir: + parallelize( + Module2(), + {'x': torch.randn(3, 3)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False + ) + # parallelize will succeed. + assert True diff --git a/tests/parallel_module/test_pyfunc.py b/tests/parallel_module/test_pyfunc.py index e30c1afe..c064dfa4 100644 --- a/tests/parallel_module/test_pyfunc.py +++ b/tests/parallel_module/test_pyfunc.py @@ -35,14 +35,14 @@ def __init__(self): self.weight = torch.nn.Parameter(torch.rand(10, 10)) def forward(self, x): - x = MyMatmul.apply(x, self.weight) + x = MyMatmul.apply(x[0], self.weight) return x def _worker(): init_distributed() - dummy_input = {'x': torch.rand(2, 10)} + dummy_input = {'x': (torch.rand(2, 10), torch.rand(10, 10))} from nnscaler.graph.parser.parser import _logger as _logger_parser from nnscaler.graph.graph import _logger as _logger_graph from nnscaler.graph.segment import _logger as _logger_seg @@ -63,8 +63,8 @@ def _worker(): seg_logs = log_stream_seg.getvalue() graph_logs = log_stream_graph.getvalue() # parser.py: parse_prim_function_method - assert 'non register python runtime function' in parser_logs - # segment.py: infer_grad + assert 'Find unknown custom autograd operation' in parser_logs + # segment.py: infer_grad assert 'nnScaler does not support backward of IRPyFunc' in seg_logs # graph.py: from_logic_graph assert 'nnScaler does not support to compute gradients for IRPyFunc.' in graph_logs From aa5ac3bd00bbc7fb9429972fb62ee187a16be419 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 7 Aug 2025 01:53:57 +0000 Subject: [PATCH 1824/1892] Merged PR 2389: [Feature][Breaking Change] Refine hooks 1. Add more hooks for logging metrics 2. Simplify existing hook parameters. This will break some old code, because it removes some parameters. 3. Document. --- docs/source/trainer.rst | 27 ++++-- examples/nanogpt/train_cli.py | 8 +- examples/vit/vit_cli.py | 8 +- nnscaler/cli/train_hook.py | 85 ++++++++++++------- nnscaler/cli/trainer.py | 24 +++--- nnscaler/cli/trainer_args.py | 3 + tests/cli/test_trainer.py | 4 +- .../lightning/pytorch/test_strategy.py | 6 +- 8 files changed, 105 insertions(+), 60 deletions(-) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 8afa5041..7b311391 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -377,6 +377,7 @@ Please Note: @dataclass class HookMapConfig: after_setup: str = None + on_finalize: str = None on_train_start: str = None on_train_end: str = None @@ -391,6 +392,9 @@ Please Note: on_val_step_start: str = None on_val_step_end: str = None + on_step_start: str = None + on_step_end: str = None + after_aggregate_train_step_outputs: str = None after_aggregate_val_step_outputs: str = None @@ -403,12 +407,17 @@ Please Note: before_optimizer_step: str = None after_optimizer_step: str = None + before_log_train_metrics: str = None + before_log_val_metrics: str = None + on_load_checkpoint: str = None on_save_checkpoint: str = None * ``after_setup`` (``str``): The hook function to be called after setting up the trainer. Only be called when ``run_mode == 'run'``. Signature: ``def after_setup(trainer: 'Trainer') -> None:`` + * ``on_finalize`` (``str``): The hook function to be called when the training is done. + Signature: ``def on_finalize(trainer: 'Trainer') -> None:`` * ``on_train_start`` (``str``): The hook function to be called at the start of the training stage. Signature: ``def on_train_start(trainer: 'Trainer') -> None:`` * ``on_train_end`` (``str``): The hook function to be called at the end of the training stage. Signature: ``def on_train_end(trainer: 'Trainer') -> None:`` * ``on_val_start`` (``str``): The hook function to be called at the start of the validation stage. Signature: ``def on_val_start(trainer: 'Trainer') -> None:`` @@ -416,12 +425,14 @@ Please Note: * ``on_epoch_start`` (``str``): The hook function to be called at the start of each epoch. Signature: ``def on_epoch_start(trainer: 'Trainer', epoch: int) -> None:`` * ``on_epoch_end`` (``str``): The hook function to be called at the end of each epoch. Signature: ``def on_epoch_end(trainer: 'Trainer', epoch: int) -> None:`` * ``on_train_step_start`` (``str``): The hook function to be called at the start of each training step. - Signature: ``def on_train_step_start(trainer: 'Trainer', batches: List[Any], idx: int) -> None:`` - * ``on_train_step_end`` (``str``): The hook function to be called at the end of each training step. Signature: ``def on_train_step_end(trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None:`` - * ``on_val_step_start`` (``str``): The hook function to be called at the start of each validation step. Signature: ``def on_val_step_start(trainer: 'Trainer', batches: List[Any], idx: int) -> None:`` - * ``on_val_step_end`` (``str``): The hook function to be called at the end of each validation step. Signature: ``def on_val_step_end(trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None:`` - * ``after_aggregate_train_step_outputs`` (``str``): The hook function to be called after aggregating the outputs of the model in the training step. Signature: ``def after_aggregate_train_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None:`` - * ``after_aggregate_val_step_outputs`` (``str``): The hook function to be called after aggregating the outputs of the model in the validation step. Signature: ``def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None:`` + Signature: ``def on_train_step_start(trainer: 'Trainer', batches: List[Any]) -> None:`` + * ``on_train_step_end`` (``str``): The hook function to be called at the end of each training step. Signature: ``def on_train_step_end(trainer: 'Trainer', outputs: List[Any]) -> None:`` + * ``on_val_step_start`` (``str``): The hook function to be called at the start of each validation step. Signature: ``def on_val_step_start(trainer: 'Trainer', batches: List[Any]) -> None:`` + * ``on_val_step_end`` (``str``): The hook function to be called at the end of each validation step. Signature: ``def on_val_step_end(trainer: 'Trainer', outputs: List[Any]) -> None:`` + * ``on_step_start`` (``str``): The hook function to be called at the start of each step. Signature: ``def on_step_start(self, trainer: 'Trainer', epoch: int, idx: int) -> None:`` + * ``on_step_end`` (``str``): The hook function to be called at the end of each step. Signature: ``def on_step_end(self, trainer: 'Trainer', epoch: int, idx: int, step_metrics: TrainStepMetrics, aggregated_outputs: 'AggregatedOutputs') -> None:`` + * ``after_aggregate_train_step_outputs`` (``str``): The hook function to be called after aggregating the outputs of the model in the training step. Signature: ``def after_aggregate_train_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float) -> None:`` + * ``after_aggregate_val_step_outputs`` (``str``): The hook function to be called after aggregating the outputs of the model in the validation step. Signature: ``def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float) -> None:`` * ``before_zero_grad`` (``str``): The hook function to be called before zeroing the gradients. Signature: ``def before_zero_grad(trainer: 'Trainer') -> None:`` * ``after_zero_grad`` (``str``): The hook function to be called after zeroing the gradients. Signature: ``def after_zero_grad(trainer: 'Trainer') -> None:`` * ``before_sync_grad`` (``str``): The hook function to be called before syncing the gradients between ranks. @@ -439,6 +450,10 @@ Please Note: Signature: ``def before_optimizer_step(trainer: 'Trainer') -> None:`` * ``after_optimizer_step`` (``str``): The hook function to be called after the optimizer step. Signature: ``def after_optimizer_step(trainer: 'Trainer') -> None:`` + * ``before_log_train_metrics`` (``str``): The hook function to be called before logging the training metrics. You can use this to modify the training metrics before logging. + Signature: ``def before_log_train_metrics(self, trainer: 'Trainer', step_metrics: TrainStepMetrics, aggregated_outputs: 'AggregatedOutputs') -> None:`` + * ``before_log_val_metrics`` (``str``): The hook function to be called before logging the validation metrics. You can use this to modify the validation metrics before logging. + Signature: ``def before_log_val_metrics(self, trainer: 'Trainer', metrics: ValMetrics) -> None:`` * ``on_load_checkpoint`` (``str``): The hook function to be called after loading the checkpoint. If you saved something with ``on_save_checkpoint`` this is your chance to restore this. Please note when checkpoints are merged, the custom data saved in the checkpoint diff --git a/examples/nanogpt/train_cli.py b/examples/nanogpt/train_cli.py index 04586397..a1baaaee 100644 --- a/examples/nanogpt/train_cli.py +++ b/examples/nanogpt/train_cli.py @@ -49,14 +49,14 @@ def init_env(trainer: 'Trainer'): torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn -def on_train_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: +def on_train_step_end(trainer: 'Trainer', outputs) -> None: if torch.distributed.get_rank() == 0: - print(f'# train_loss {idx:03d}', outputs[0].item()) + print(f'# train_loss {trainer.train_status.finished_train_steps:03d}', outputs[0].item()) -def on_val_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: +def on_val_step_end(trainer: 'Trainer', outputs) -> None: if torch.distributed.get_rank() == 0: - print(f'# val_loss {idx:03d}', outputs[0].item()) + print(f'# val_loss {trainer.train_status.finished_train_steps:03d}', outputs[0].item()) # poor man's data loader diff --git a/examples/vit/vit_cli.py b/examples/vit/vit_cli.py index 92bb5e08..4543e4db 100644 --- a/examples/vit/vit_cli.py +++ b/examples/vit/vit_cli.py @@ -81,14 +81,14 @@ def init_env(trainer: 'Trainer'): print('Resume on this line') -def on_train_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: +def on_train_step_end(trainer: 'Trainer', outputs) -> None: if torch.distributed.get_rank() == 0: - print(f'# train_loss {idx:03d}', outputs[0].item()) + print(f'# train_loss {trainer.train_status.finished_train_steps:03d}', outputs[0].item()) -def on_val_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: +def on_val_step_end(trainer: 'Trainer', outputs) -> None: if torch.distributed.get_rank() == 0: - print(f'# val_loss {idx:03d}', outputs[0].item()) + print(f'# val_loss {trainer.train_status.finished_train_steps:03d}', outputs[0].item()) datasets = None diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index 620cedf7..76abeb8b 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Any, Dict, List, TYPE_CHECKING, TypedDict, Optional +from typing import Any, Dict, List, TYPE_CHECKING, Literal, TypedDict, Optional import torch @@ -10,7 +10,7 @@ from nnscaler.cli.trainer_args import AggregatedOutputs -class StepMetrics(TypedDict): +class TrainStepMetrics(TypedDict): train_loss: float loss: float # alias for train_loss lr: float @@ -18,6 +18,11 @@ class StepMetrics(TypedDict): train_wall: float # wall time for training step +class ValMetrics(TypedDict): + val_loss: float + val_wall: float # wall time for validation + + class TrainHook: """ Note: All hooks are called in all ranks, and the inputs of hooks are only the local data. @@ -67,66 +72,59 @@ def on_step_start(self, trainer: 'Trainer', epoch: int, idx: int) -> None: idx: the index of current step """ - def on_step_end(self, trainer: 'Trainer', epoch: int, idx: int, step_metrics: StepMetrics) -> None: + def on_step_end(self, trainer: 'Trainer', epoch: int, idx: int, step_metrics: TrainStepMetrics, aggregated_outputs: 'AggregatedOutputs') -> None: """ Called at the end of each step (validation and checkpoint saving are not included) Args: idx: the index of current step step_metrics: the metrics of the current step + aggregated_outputs: the aggregated outputs of the current step """ - def on_train_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: + def on_train_step_start(self, trainer: 'Trainer', batches: List[Any]) -> None: """ Called at the beginning of each training step Please note one train step may contain multiple batches Args: batches: the current batches - idx: the index of current step """ - def on_train_step_end(self, trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None: + def on_train_step_end(self, trainer: 'Trainer', outputs: List[Any]) -> None: """ Called at the end of each training step Args: outputs: the outputs of the train_step - batches: the current batches - idx: the index of current step """ - def on_val_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: + def on_val_step_start(self, trainer: 'Trainer', batches: List[Any]) -> None: """ Called at the beginning of each validating step Please note one val step may contain multiple batches Args: batches: the current batches - idx: the index of current step """ - def on_val_step_end(self, trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None: + def on_val_step_end(self, trainer: 'Trainer', outputs: List[Any]) -> None: """ Called at the end of each validating step Args: outputs: the outputs of the val_step - batches: the current batches - idx: the index of current step """ - def after_aggregate_train_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None: + def after_aggregate_train_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float) -> None: """ Called after aggregating outputs in train step Args: aggregated_outputs: the aggregated outputs train_loss: the loss of the current step - idx: the index of current step """ - def after_aggregate_val_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None: + def after_aggregate_val_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float) -> None: """ Called after aggregating outputs in val step Args: aggregated_outputs: the aggregated outputs val_loss: the loss of the current step - idx: the index of current step """ def before_zero_grad(self, trainer: 'Trainer') -> None: @@ -170,6 +168,23 @@ def after_optimizer_step(self, trainer: 'Trainer') -> None: Called after optimizer.step() """ + def before_log_train_metrics(self, trainer: 'Trainer', step_metrics: TrainStepMetrics, aggregated_outputs: 'AggregatedOutputs') -> None: + """ + Called before logging metrics. + This is useful for modifying the metrics (inplace) before logging. + Args: + step_metrics: the metrics of the current step + aggregated_outputs: the aggregated outputs of the current step + """ + + def before_log_val_metrics(self, trainer: 'Trainer', metrics: ValMetrics) -> None: + """ + Called before logging validation metrics. + This is useful for modifying the metrics (inplace) before logging. + Args: + metrics: the metrics of the current validation + """ + def on_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: """ Called after loading checkpoint. @@ -239,33 +254,33 @@ def on_step_start(self, trainer: 'Trainer', epoch: int, idx: int) -> None: for hook in self.hooks: hook.on_step_start(trainer, epoch, idx) - def on_step_end(self, trainer: 'Trainer', epoch: int, idx: int, step_metrics: StepMetrics) -> None: + def on_step_end(self, trainer: 'Trainer', epoch: int, idx: int, step_metrics: TrainStepMetrics, aggregated_outputs: 'AggregatedOutputs') -> None: for hook in self.hooks: - hook.on_step_end(trainer, epoch, idx, step_metrics) + hook.on_step_end(trainer, epoch, idx, step_metrics, aggregated_outputs) - def on_train_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: + def on_train_step_start(self, trainer: 'Trainer', batches: List[Any]) -> None: for hook in self.hooks: - hook.on_train_step_start(trainer, batches, idx) + hook.on_train_step_start(trainer, batches) - def on_train_step_end(self, trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None: + def on_train_step_end(self, trainer: 'Trainer', outputs: List[Any]) -> None: for hook in self.hooks: - hook.on_train_step_end(trainer, outputs, batches, idx) + hook.on_train_step_end(trainer, outputs) - def on_val_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: + def on_val_step_start(self, trainer: 'Trainer', batches: List[Any]) -> None: for hook in self.hooks: - hook.on_val_step_start(trainer, batches, idx) + hook.on_val_step_start(trainer, batches) - def on_val_step_end(self, trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None: + def on_val_step_end(self, trainer: 'Trainer', outputs: List[Any]) -> None: for hook in self.hooks: - hook.on_val_step_end(trainer, outputs, batches, idx) + hook.on_val_step_end(trainer, outputs) - def after_aggregate_train_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None: + def after_aggregate_train_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float) -> None: for hook in self.hooks: - hook.after_aggregate_train_step_outputs(trainer, aggregated_outputs, train_loss, idx) + hook.after_aggregate_train_step_outputs(trainer, aggregated_outputs, train_loss) - def after_aggregate_val_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None: + def after_aggregate_val_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float) -> None: for hook in self.hooks: - hook.after_aggregate_val_step_outputs(trainer, aggregated_outputs, val_loss, idx) + hook.after_aggregate_val_step_outputs(trainer, aggregated_outputs, val_loss) def before_zero_grad(self, trainer: 'Trainer') -> None: for hook in self.hooks: @@ -299,6 +314,14 @@ def after_optimizer_step(self, trainer: 'Trainer') -> None: for hook in self.hooks: hook.after_optimizer_step(trainer) + def before_log_train_metrics(self, trainer: 'Trainer', step_metrics: TrainStepMetrics, aggregated_outputs: 'AggregatedOutputs') -> None: + for hook in self.hooks: + hook.before_log_train_metrics(trainer, step_metrics, aggregated_outputs) + + def before_log_val_metrics(self, trainer: 'Trainer', metrics: ValMetrics) -> None: + for hook in self.hooks: + hook.before_log_val_metrics(trainer, metrics) + def on_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: for hook in self.hooks: hook.on_load_checkpoint(trainer, checkpoint) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 7a176e3d..d167b95f 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -47,7 +47,8 @@ @dataclass class TrainStatus: best_loss = float('inf') - # the train steps done so far + # the train steps done (forward/backward/optimizer step) so far + # This will be updated after optimizer.step is done, but before validation/logging metrics/saving checkpoint. finished_train_steps: int = 0 @@ -844,6 +845,7 @@ def _validate(self, step_stat: _StepStat): batches_count = 0 self.hook.on_val_start(self) + val_start_at = time.perf_counter() for idx, batches in data_iter: if self.train_args.max_val_steps and idx >= self.train_args.max_val_steps: break @@ -853,29 +855,30 @@ def _validate(self, step_stat: _StepStat): self.model.eval() with torch.inference_mode(): - self.hook.on_val_step_start(self, batches[:num_batches], idx) + self.hook.on_val_step_start(self, batches[:num_batches]) losses = self.model.infer_step(batches) - self.hook.on_val_step_end(self, losses[:num_batches], batches[:num_batches], idx) + self.hook.on_val_step_end(self, losses[:num_batches]) aggregate_outputs = self.train_args.resolved_aggregate_outputs_fn or self.aggregate_outputs aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) self.hook.after_aggregate_val_step_outputs( self, aggregated_outputs, aggregated_outputs.loss_sum / aggregated_outputs.num_batches, - idx ) loss_sum += aggregated_outputs.loss_sum batches_count += aggregated_outputs.num_batches + val_wall = time.perf_counter() - val_start_at # update train status loss = loss_sum / batches_count self.hook.on_val_end(self, loss) step_stat.val_loss = loss - val_metrics = asdict(step_stat) + val_metrics = {'val_loss': loss, 'val_wall': val_wall} + self.hook.before_log_val_metrics(self, val_metrics) self.log_metrics(val_metrics, tag='val') if self.rank == 0 and self.train_args.enable_log_progress: - logger.info(self._format_metrics(f'Validation', None, val_metrics)) + logger.info(self._format_metrics(f'Validation', None, asdict(step_stat))) return step_stat.val_loss def _train_epoch(self, epoch: int) -> None: @@ -921,9 +924,9 @@ def _train_epoch(self, epoch: int) -> None: self.optimizer.zero_grad() self.hook.after_zero_grad(self) - self.hook.on_train_step_start(self, batches[:num_batches], idx) + self.hook.on_train_step_start(self, batches[:num_batches]) losses = self.model.train_step(batches, is_dummy_batch) - self.hook.on_train_step_end(self, losses[:num_batches], batches[:num_batches], idx) + self.hook.on_train_step_end(self, losses[:num_batches]) aggregate_outputs = self.train_args.resolved_aggregate_outputs_fn or self.aggregate_outputs aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) @@ -936,7 +939,7 @@ def _train_epoch(self, epoch: int) -> None: else: loss = aggregated_outputs.loss_sum step_stat.train_loss = loss - self.hook.after_aggregate_train_step_outputs(self, aggregated_outputs, loss, idx) + self.hook.after_aggregate_train_step_outputs(self, aggregated_outputs, loss) self.hook.before_sync_grad(self) # `sync_shard_grad` is no-op if the whole model is parallelized @@ -986,6 +989,7 @@ def _train_epoch(self, epoch: int) -> None: step_metrics = {k:v for k, v in asdict(step_stat).items() if v is not None} step_metrics['train_wall'] = time.perf_counter() - step_start_at step_metrics['loss'] = step_metrics['train_loss'] + self.hook.before_log_train_metrics(self, step_metrics, aggregated_outputs) self.log_metrics(step_metrics, tag='train') if self.rank == 0: data_iter.set_postfix(step_metrics) @@ -994,7 +998,7 @@ def _train_epoch(self, epoch: int) -> None: logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) step_metrics = {} - self.hook.on_step_end(self, epoch, idx, step_metrics) + self.hook.on_step_end(self, epoch, idx, step_metrics, aggregated_outputs) # validate and save checkpoint if self.train_args.checkpoint.every_n_train_steps and \ diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 8552ad53..2fb99daa 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -545,6 +545,9 @@ class HookMapConfig: before_optimizer_step: str = None after_optimizer_step: str = None + before_log_train_metrics: str = None + before_log_val_metrics: str = None + on_load_checkpoint: str = None after_load_checkpoint: str = None on_save_checkpoint: str = None diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index d69a392f..2dc1d252 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -388,11 +388,11 @@ def test_trainer_last_checkpoint(tmp_path): _val_losses = [] -def after_aggregate_train_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None: +def after_aggregate_train_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float) -> None: _train_losses.append(train_loss) -def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None: +def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float) -> None: _val_losses.append(val_loss) diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index 41023354..8adf8a7f 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -195,11 +195,11 @@ def on_before_grad_clip(trainer: Trainer): _correctnes_worker_update_history.append((grads, weights)) -def after_aggregate_train_step_outputs(trainer: Trainer, aggregated_outputs, train_loss, idx): +def after_aggregate_train_step_outputs(trainer: Trainer, aggregated_outputs, train_loss): _correctnes_worker_train_loss_history.append(train_loss) -def on_train_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: +def on_train_step_end(trainer: 'Trainer', outputs) -> None: _correctnes_worker_single_loss_history.append(outputs[0].item()) @@ -243,7 +243,7 @@ def correctnes_worker_cli( with_tp=False ): - def on_val_step_end(trainer: Trainer, outputs, batches, idx) -> None: + def on_val_step_end(trainer: Trainer, outputs) -> None: _correctnes_worker_val_loss_history.append(outputs[0].item()) assert precision == '32-true' From 65dfa75512d29e29c42b2cf15acbacd83cef44d2 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 7 Aug 2025 03:15:41 +0000 Subject: [PATCH 1825/1892] Merged PR 2388: [AutoDist] Refine recompute modules implementation --- docs/source/autodist/configuration.rst | 305 +++++++++++++++++++++++ docs/source/index.rst | 7 + nnscaler/autodist/autodist_config.py | 2 + nnscaler/autodist/model_graph.py | 8 +- nnscaler/graph/tracer/concrete_tracer.py | 2 +- tests/autodist/graph/test_recompute.py | 120 +++++++++ 6 files changed, 442 insertions(+), 2 deletions(-) create mode 100644 docs/source/autodist/configuration.rst diff --git a/docs/source/autodist/configuration.rst b/docs/source/autodist/configuration.rst new file mode 100644 index 00000000..d02ae5a4 --- /dev/null +++ b/docs/source/autodist/configuration.rst @@ -0,0 +1,305 @@ +AutoDist Configuration Reference +==================================== + +This document provides a comprehensive guide to all configuration options available in AutoDist's ``AutoDistConfig`` class. + +Overview +-------- + +``AutoDistConfig`` is the central configuration class for AutoDist, allowing you to control various aspects of automatic parallelization including memory optimization, pipeline parallelism, tensor parallelism, and recomputation strategies. + +Basic Usage +----------- + +.. code-block:: python + + from nnscaler.autodist.autodist_config import AutoDistConfig + + # Basic configuration + config = AutoDistConfig( + task_name='my_experiment', + memory_constraint=32, # 32GB memory limit + recompute_modules='transformer.layer' # Recompute transformer layers + ) + +Configuration Parameters +------------------------ + +Task Configuration +~~~~~~~~~~~~~~~~~~ + +**task_name** (*str*, optional, default: ``'default'``) + The name of the current task to distinguish different runs. Used for naming saved plans and logs. + + .. code-block:: python + + config = AutoDistConfig(task_name='bert_large_training') + +Memory Management +~~~~~~~~~~~~~~~~~ + +**consider_mem** (*bool*, optional, default: ``True``) + Whether to consider memory constraints when searching for parallelization plans. + +**memory_constraint** (*float*, optional, default: ``32``) + The memory constraint for each device in GB. AutoDist will ensure that the parallelization plan fits within this memory limit. + + .. code-block:: python + + config = AutoDistConfig(memory_constraint=80) # 80GB A100 + +**memory_granularity** (*int*, optional, default: ``1``) + The memory granularity in bytes. Used for memory profiling and estimation. + +**transient_mem_coef** (*float*, optional, default: ``2``) + Coefficient for estimating transient memory size. Formula: ``transient_mem_size = transient_mem_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)``. + + Reduce this value if operators consume/generate very large tensors (≥4GB). + +Optimizer Configuration +~~~~~~~~~~~~~~~~~~~~~~~ + +**opt_resident_coef** (*int*, optional, default: ``2``) + Coefficient for optimizer resident state compared to model weight size. + + Common cases: + + - FP32 training with Adam: ``2`` (FP32 momentum1 + FP32 momentum2) + - FP16/BF16 training with Adam: ``6`` (FP32 momentum1 + FP32 momentum2 + FP32 weight) + - FP16/BF16 training with memory-efficient Adam: ``4`` (FP32 momentum1 + FP32 momentum2) + +**opt_transient_coef** (*int*, optional, default: ``0``) + Coefficient for optimizer transient state compared to model weight size. + + Common cases: + + - FP32 training with Adam: ``0`` + - FP16/BF16 training with Adam without internal cast: ``2`` (FP32 gradient) + - FP16/BF16 training with memory-efficient Adam without internal cast: ``4`` (FP32 weight + FP32 gradient) + +Recomputation +~~~~~~~~~~~~~ + +**recompute_modules** (*str*, optional, default: ``''``) + Module names to recompute, separated by commas. Recomputation trades computation for memory by not storing intermediate activations during forward pass and recomputing them during backward pass. Note that recomputation still requires storing some tensors for gradient computation, so the memory savings depend on the specific model structure and recomputation granularity. + + Examples: + + .. code-block:: python + + # Recompute specific modules + config = AutoDistConfig(recompute_modules='transformer.layer,attention') + + # Recompute entire model + config = AutoDistConfig(recompute_modules='ROOT') + + # Recompute multiple specific modules + config = AutoDistConfig(recompute_modules='encoder.layer,decoder.layer') + + **Note**: Module names can be any suffix of the full module name. For example, ``layer`` will match ``transformer.layer``, ``encoder.layer``, etc. ``ROOT`` recomputes the entire model but may not always provide maximum memory savings due to the need to store intermediate tensors for backward pass. + +ZeRO Optimization +~~~~~~~~~~~~~~~~~ + +**zero_stage** (*int*, optional, default: ``0``) + ZeRO optimization stage (see `ZeRO paper `_). + + - ``0``: No ZeRO optimization + - ``1``: Optimizer state partitioning + +**zero_ngroups** (*int*, optional, default: ``1``) + Number of ZeRO groups to balance memory usage and communication cost. Larger values use more memory but reduce communication overhead. + +Pipeline Parallelism +~~~~~~~~~~~~~~~~~~~~ + +**pipeline_pivots** (*str*, optional, default: ``''``) + Module names that serve as pipeline stage boundaries, separated by commas. + + .. code-block:: python + + config = AutoDistConfig(pipeline_pivots='encoder,decoder') + +**pipeline_nstages** (*int* or *'auto'*, optional, default: ``'auto'``) + Number of pipeline stages. Set to ``1`` to disable pipeline parallelism. + + - ``'auto'``: Automatically determine optimal number of stages + - ``int``: Fixed number of stages + +**pipeline_scheduler** (*str*, optional, default: ``'1f1b'``) + Pipeline scheduling strategy. Currently only supports ``'1f1b'`` (1-forward-1-backward). + +**max_pipeline_bubble_ratio** (*float*, optional, default: ``0.2``) + Maximum allowed bubble ratio in pipeline parallelism. Higher values allow more pipeline bubbles but explore larger search space. + +**max_pipeline_unbalance_ratio** (*float*, optional, default: ``0.5``) + Maximum unbalance ratio between pipeline stages (min_stage_time / max_stage_time). Higher values require better balance but reduce search space. + +Mesh and Parallelism +~~~~~~~~~~~~~~~~~~~~ + +**mesh_row** (*int*, optional, default: ``1``) + Number of available nodes in the device mesh. + +**mesh_col** (*int*, optional, default: ``1``) + Number of available devices per node in the device mesh. + +**world_size** (*int*, optional, default: ``1``) + Total number of devices (mesh_row × mesh_col × scale_factor). + +**micro_batch_size** (*int*, optional, default: ``1``) + Micro batch size for gradient accumulation. + +**update_freq** (*int*, optional, default: ``1``) + Update frequency. The effective batch size is micro_batch_size × update_freq. + +Profiling and Search +~~~~~~~~~~~~~~~~~~~~ + +**profile_dir** (*str*, optional, default: ``~/.cache/nnscaler/autodist/1.0/get_node_arch()``) + Directory to store profiling results for computation cost estimation. + +**parallel_profile** (*bool*, optional, default: ``True``) + Whether to profile on multiple devices in parallel. Set to ``False`` for sequential profiling on a single device. + +**re_profile** (*bool*, optional, default: ``False``) + Whether to override existing profiling results and re-profile operations. + +**topk** (*int*, optional, default: ``20``) + Number of parallelization plans to generate for robustness. Higher values provide more options but increase search time. + +**solver** (*str*, optional, default: ``'dp'``) + Solver algorithm for SPMD parallelism: + + - ``'dp'``: Dynamic programming + - ``'ilp'``: Integer linear programming + +**nproc** (*int*, optional, default: ``1``) + Number of processes for pipeline parallelism search. + +Plan Management +~~~~~~~~~~~~~~~ + +**load_plan_path** (*str*, optional, default: ``''``) + Path to load an existing parallelization plan. When specified, skips plan searching and uses the loaded plan. + +**save_plan_path** (*str*, optional, default: ``''``) + Path to save the generated parallelization plan for reuse. + +**partition_constraints_path** (*str*, optional, default: ``''``) + Path to partition constraints file. See :doc:`solver_interface/partition_constraints` for details. + +Training Configuration +~~~~~~~~~~~~~~~~~~~~~~ + +**is_train** (*bool*, optional, default: ``True``) + Whether the model is for training or inference. Affects memory estimation and operator selection. + +Debug and Optimization +~~~~~~~~~~~~~~~~~~~~~~ + +**verbose** (*bool*, optional, default: ``False``) + Whether to print verbose information during plan generation. + +**ignore_small_tensor_threshold** (*int*, optional, default: ``1``) + Tensor size threshold (in elements) to ignore during analysis. Small tensors below this threshold are not considered for partitioning. + +Example Configurations +---------------------- + +High Memory Training +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + # Configuration for large model training with high memory + config = AutoDistConfig( + task_name='large_model_training', + memory_constraint=80, # 80GB A100 + recompute_modules='transformer.layer', # Selective recomputation + zero_stage=1, # Enable ZeRO stage 1 + zero_ngroups=4, # Use 4 ZeRO groups + opt_resident_coef=6, # FP16 training with Adam + opt_transient_coef=2, + topk=50 # More plan options + ) + +Pipeline Parallelism +~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + # Configuration for pipeline parallelism + config = AutoDistConfig( + task_name='pipeline_training', + pipeline_pivots='encoder,decoder', + pipeline_nstages=4, + pipeline_scheduler='1f1b', + max_pipeline_bubble_ratio=0.1, # Strict bubble control + mesh_row=2, # 2 nodes + mesh_col=4, # 4 GPUs per node + micro_batch_size=2, + update_freq=4 # Effective batch size = 2 * 4 = 8 + ) + +Memory-Efficient Training +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + # Configuration for memory-efficient training + config = AutoDistConfig( + task_name='efficient_training', + is_train=True, + consider_mem=True, + memory_constraint=24, # 24GB RTX 4090 + recompute_modules='attention,mlp', # Selective recomputation + solver='ilp', # More precise optimization + topk=10 + ) + +Best Practices +-------------- + +1. **Start Simple**: Begin with default settings and gradually tune parameters based on your needs. + +2. **Memory Tuning**: + - Consider ``recompute_modules`` for memory savings, but note that more aggressive recomputation (like ``'ROOT'``) doesn't always provide maximum memory savings + - Adjust ``memory_constraint`` based on your hardware + - Fine-tune optimizer coefficients based on your training setup + - Experiment with different recomputation granularities to find the optimal memory-computation trade-off + +3. **Pipeline Parallelism**: + - Choose ``pipeline_pivots`` at natural module boundaries + - Start with ``pipeline_nstages='auto'`` to find optimal stages + - Monitor bubble ratio and adjust ``max_pipeline_bubble_ratio`` + +4. **Profiling**: + - Enable ``parallel_profile`` for faster profiling + - Set ``re_profile=True`` when changing hardware or model architecture + - Use appropriate ``profile_dir`` for different experiments + +5. **Plan Management**: + - Save successful plans with ``save_plan_path`` for reuse + - Use descriptive ``task_name`` for better organization + +Troubleshooting +--------------- + +**Out of Memory Errors** + - Reduce ``memory_constraint`` + - Experiment with different ``recompute_modules`` strategies (selective recomputation may be more effective than ``'ROOT'``) + - Increase ``zero_ngroups`` or enable higher ZeRO stages + - Reduce ``transient_mem_coef`` + +**Slow Plan Generation** + - Reduce ``topk`` for faster search + - Use ``'dp'`` solver instead of ``'ilp'`` + - Set ``parallel_profile=True`` + - Increase ``ignore_small_tensor_threshold`` + +**Poor Performance** + - Check ``max_pipeline_bubble_ratio`` if using pipeline parallelism + - Verify ``mesh_row`` and ``mesh_col`` match your hardware + - Tune ``micro_batch_size`` and ``update_freq`` + - Consider different ``recompute_modules`` strategies diff --git a/docs/source/index.rst b/docs/source/index.rst index 126507b2..8787ce04 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -124,6 +124,13 @@ For any questions or inquiries, please contact us at nnscaler@service.microsoft. dimops verify_op +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: AutoDist + + autodist/configuration + .. toctree:: :maxdepth: 1 :hidden: diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 2b33f202..796d35e2 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -70,6 +70,8 @@ class AutoDistConfig: The module names to recompute, separated by `,`. For example, `module1,module2`. Module name can be any suffix of the full module name, e.g., `module1` will match `x.module1`, `y.module1`, `x.module1` will match `x.module1` but not `y.module1`. + Due to constraint of the tracer, you can pass `ROOT` to recompute_modules if you want the whole module to + be recomputed. - memory_constraint (`float`, *optional*, defaults to `32`): The memory constraint in each device in GB. - memory_granularity (`int`, *optional*, defaults to `1`): diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 6f6da5d6..44be8ca2 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -712,7 +712,13 @@ def fetch_module(scope_node: ScopeNode, prefix: List[str]): ret += fetch_module(child, next_prefix) return ret - modules = fetch_module(self.scope_tree_root, []) + # the root module's name is not tracked in the tracer, to enable recomputing the + # whole module, we add a special 'ROOT' module name + if 'ROOT' in recompute_modules: + modules = [self.scope_tree_root] + else: + modules = fetch_module(self.scope_tree_root, []) + train_mem = 0 for module in modules: train_mem = max(train_mem, module.train_mem) diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index aab86ba0..9ec92c27 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -160,7 +160,7 @@ def create_node(self, kind : str, target : Target, node.meta['nn_module_stack'] = collections.OrderedDict() def unwrap_nested_proxy(proxy: ep.ConcreteProxy): - return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) + return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) # unwrap all proxy in the node result here, because no proxy should be record in the tensor metadata node_result = pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, node_result) diff --git a/tests/autodist/graph/test_recompute.py b/tests/autodist/graph/test_recompute.py index 6fe0b93b..1f87434a 100644 --- a/tests/autodist/graph/test_recompute.py +++ b/tests/autodist/graph/test_recompute.py @@ -9,6 +9,9 @@ from nnscaler.ir.operator import IRFwOperation from nnscaler.autodist.model_graph import ModelGraph from nnscaler.autodist.autodist_config import AutoDistConfig +from nnscaler.parallel import parallelize, ComputeConfig +from pathlib import Path +from tests.parallel_module.test_gencode import print_gencode, _gencode_contains class MLP(torch.nn.Module): @@ -150,3 +153,120 @@ def test_recompute(): # will label operator like GELU and add with `has_batch_dim=True` for op in model_graph.operator_list: assert op.has_batch_dim, f'{op} does not have batch dim' + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_recompute_root_module(): + """ + Test that when recompute_modules='ROOT' is set, the entire module is marked for recompute + """ + batch_size = 2 + hidden_dim, ffn_dim, num_layers = 64, 64, 1 + + dummy_input = {'x': torch.randn(batch_size, hidden_dim)} + model = Model(hidden_dim, ffn_dim, num_layers) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + constant_folding=False) + + # Test with ROOT recompute + config = AutoDistConfig(recompute_modules='ROOT') + model_graph = ModelGraph(ir_graph, config) + + print("=== ROOT recompute module configuration ===") + print(f"min_recompute_mem: {model_graph.min_recompute_mem}") + print(f"Number of recompute groups: {len(model_graph.recompute_groups)}") + + # With ROOT recompute, the entire model should be one recompute group + # The recompute memory should be the total training memory of the entire model + model_node = model_graph.scope_tree_root + + # Since we're recomputing the ROOT module, the min_recompute_mem should be + # the training memory of the entire model + expected_recompute_mem = model_node.train_mem + assert model_graph.min_recompute_mem == expected_recompute_mem, \ + f"Expected recompute mem {expected_recompute_mem}, got {model_graph.min_recompute_mem}" + + # All forward operations should be in one big recompute group + fnodes = ir_graph.select(ntype=IRFwOperation) + print(f"Total forward nodes: {len(fnodes)}") + print(f"Recompute groups: {len(model_graph.recompute_groups)}") + + # With ROOT recompute, there should be one recompute group containing all operations + assert len(model_graph.recompute_groups) == 1, \ + f"Expected 1 recompute group for ROOT, got {len(model_graph.recompute_groups)}" + + # The single recompute group should contain all forward operations + recompute_group = model_graph.recompute_groups[0] + assert len(recompute_group) == len(fnodes), \ + f"Expected {len(fnodes)} nodes in recompute group, got {len(recompute_group)}" + + # Verify that all forward nodes are in the recompute group + recompute_node_set = set(recompute_group) + fnodes_set = set(fnodes) + assert recompute_node_set == fnodes_set, \ + "Recompute group should contain exactly all forward nodes" + + print("ROOT recompute test passed: entire model is marked for recompute") + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 10, bias=False) + self.linear2 = torch.nn.Linear(10, 10, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = torch.relu(x) + x = self.linear2(x) + return x.sum() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_parallelize_with_root_recompute(): + """ + Test parallelize with recompute_modules='ROOT' and examine generated code + """ + m = SimpleModel() + m.train() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + trace_data = torch.randn([2, 10], dtype=torch.float32, device=torch.cuda.current_device()) + + with tempfile.TemporaryDirectory() as tempdir: + # Test with ROOT recompute + pas_cfg = { + 'recompute_modules': 'ROOT', + 'parallel_profile': False + } + + print("=== Testing parallelize with recompute_modules='ROOT' ===") + parallelize( + m, + {'x': trace_data}, + 'autodist', + ComputeConfig(1, 1, use_end2end=True, pas_config=pas_cfg), + reuse='override', + gen_savedir=tempdir, + load_module=False, + ) + + print("\n=== Generated code with ROOT recompute ===") + print_gencode(tempdir, SimpleModel, 0) + + # Check that recompute is applied + recompute_matches = _gencode_contains(tempdir, SimpleModel, 0, r'def recompute\(') + checkpoint_matches = _gencode_contains(tempdir, SimpleModel, 0, r'ckpt\.checkpoint\(recompute') + + print(f"\nFound {len(recompute_matches)} recompute function definitions") + print(f"Found {len(checkpoint_matches)} checkpoint calls") + + assert len(recompute_matches) >= 1, "Should generate at least one recompute function" + assert len(checkpoint_matches) >= 1, "Should use checkpoint for recompute function" From fde55a206b874ebdac65cd5a001c6b93a78cf903 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Tue, 12 Aug 2025 08:38:38 +0000 Subject: [PATCH 1826/1892] Merged PR 2387: Add Copyright and update version of package datasets version of package datasets from 2.20.0 to 3.6.0, for fixing an unicode UnicodeDecodeError Add Copyright for release --- docs/source/examples/longrope2.rst | 1 + docs/source/index.rst | 1 + examples/deepseek_coder_v2_lite/README.md | 2 +- examples/llama/README.rst | 2 +- examples/llama/customized_ops/ring_attention/__init__.py | 3 +++ examples/llama/requirements.txt | 2 +- examples/llama3_demo/requirements.txt | 2 +- nnscaler/cli/mixed_module.py | 3 +++ tests/graph/tracer/test_dict_iter.py | 3 ++- 9 files changed, 14 insertions(+), 5 deletions(-) create mode 120000 docs/source/examples/longrope2.rst diff --git a/docs/source/examples/longrope2.rst b/docs/source/examples/longrope2.rst new file mode 120000 index 00000000..d0628769 --- /dev/null +++ b/docs/source/examples/longrope2.rst @@ -0,0 +1 @@ +../../../examples/longrope2/README.rst \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 8787ce04..4d680603 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -105,6 +105,7 @@ For any questions or inquiries, please contact us at nnscaler@service.microsoft. examples/vit examples/deepseek examples/nanogpt + examples/longrope2 .. toctree:: :maxdepth: 1 diff --git a/examples/deepseek_coder_v2_lite/README.md b/examples/deepseek_coder_v2_lite/README.md index 6a83be4e..c84eb7e6 100644 --- a/examples/deepseek_coder_v2_lite/README.md +++ b/examples/deepseek_coder_v2_lite/README.md @@ -11,7 +11,7 @@ To run this example, you need to install the following packages: ```text nnscaler transformers==4.40.0 -datasets==2.20.0 +datasets==3.6.0 apex flash-attn grouped_gemm==1.1.4 diff --git a/examples/llama/README.rst b/examples/llama/README.rst index a5a8c148..31620f1d 100644 --- a/examples/llama/README.rst +++ b/examples/llama/README.rst @@ -16,7 +16,7 @@ Assume following packages have been installed in the environment. :: nnscaler transformers==4.40.0 - datasets==2.20.0 + datasets==3.6.0 apex flash-attn diff --git a/examples/llama/customized_ops/ring_attention/__init__.py b/examples/llama/customized_ops/ring_attention/__init__.py index 405f1207..e54f5bc1 100644 --- a/examples/llama/customized_ops/ring_attention/__init__.py +++ b/examples/llama/customized_ops/ring_attention/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from .ring_attn_varlen import wrap_ring_attn_varlen_func from .zigzag_attn import wrap_zigzag_attn_func diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index f23e777d..705dba10 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,3 +1,3 @@ transformers==4.40.0 -datasets==2.20.0 +datasets==3.6.0 tensorboard diff --git a/examples/llama3_demo/requirements.txt b/examples/llama3_demo/requirements.txt index 5fd38d67..71f990ca 100644 --- a/examples/llama3_demo/requirements.txt +++ b/examples/llama3_demo/requirements.txt @@ -1,4 +1,4 @@ -datasets==2.21.0 +datasets==3.6.0 tensorboard torch==2.3.1 transformers==4.42.4 diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index 7849a67f..d7354617 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import types import torch from typing import Any, Optional diff --git a/tests/graph/tracer/test_dict_iter.py b/tests/graph/tracer/test_dict_iter.py index 9cb60613..1029ae51 100644 --- a/tests/graph/tracer/test_dict_iter.py +++ b/tests/graph/tracer/test_dict_iter.py @@ -1,4 +1,5 @@ -import pytest +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from nnscaler.graph.tracer import concrete_trace, wrap_utils from nnscaler.graph.tracer.metadata import DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE From 385e7aa3a85b9071df380513c5869dc41dac1b71 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 13 Aug 2025 03:13:14 +0000 Subject: [PATCH 1827/1892] Merged PR 2390: [BugFix] Add functools.cache support We will track all functions decorated with `functools.cache/lru_cache`, and clear the cache after tracing to avoid memory leak and potential tracing error (when run tracer multiple times) --- nnscaler/graph/parser/external/__init__.py | 3 +- nnscaler/graph/parser/external/einops.py | 18 +++++ nnscaler/graph/tracer/concrete_tracer.py | 31 +++++++++ nnscaler/graph/tracer/operator_patcher.py | 1 + requirements-dev.txt | 1 + tests/parallel_module/test_gencode_einops.py | 67 +++++++++++++++++++ .../test_gencode_torch_compile.py | 3 + 7 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 nnscaler/graph/parser/external/einops.py create mode 100644 tests/parallel_module/test_gencode_einops.py diff --git a/nnscaler/graph/parser/external/__init__.py b/nnscaler/graph/parser/external/__init__.py index 5c71d8f9..5a628d8f 100644 --- a/nnscaler/graph/parser/external/__init__.py +++ b/nnscaler/graph/parser/external/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .apex import * \ No newline at end of file +from .apex import * +from .einops import * diff --git a/nnscaler/graph/parser/external/einops.py b/nnscaler/graph/parser/external/einops.py new file mode 100644 index 00000000..91845fe8 --- /dev/null +++ b/nnscaler/graph/parser/external/einops.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging + +import torch + + +_logger = logging.getLogger(__name__) + +try: + import einops + + # trigger einops initialization + einops.rearrange(torch.arange(1), '(a b c) -> a b c', a=1, b=1, c=1) +except ImportError as e: + _logger.debug("Einops is not installed") + pass diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index 9ec92c27..dda70381 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -16,6 +16,7 @@ from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType from typing import Any, Dict, Optional, Set, Tuple, Type, List, Callable, Union, Literal from contextlib import contextmanager +import weakref import torch from torch._C import ScriptObject @@ -100,6 +101,16 @@ def __init__(self, strategy, record_frames = False): self.need_revert_functions = set() self.need_revert_wrapped_functions = set() + # Save functions decorated with functools.cache/lru_cache + # We need to clear up caches after tracing to avoid memory leak or tracing error. + # TODO: currently only functions/methods are tracked. + # Cached Properties (via @property @cache or @cached_property) are not tracked + # The reason is: + # 1. Cached properties is rare to cause problem as they have no arguments (no ConcrateProxy object will pass to it) + # 2. We need to patch all getattr (`a.b``) to support this scenario, which is too expensive + # Currently only function calls (`f(a,b)`) are patched and tracked. (See `operator_patcher`) + self.cached_function = weakref.WeakSet() + self.temp_call_origin = False def add_need_revert_function(self, func, wrapped_func): @@ -109,6 +120,22 @@ def add_need_revert_function(self, func, wrapped_func): def need_revert(self, func): return func in self.need_revert_functions or func in self.need_revert_wrapped_functions + @classmethod + def _is_cache_wrapped_function(cls, func): + return callable(func) \ + and hasattr(func, 'cache_clear') \ + and hasattr(func, 'cache_info') \ + and hasattr(func, 'cache_parameters') \ + and hasattr(func, '__wrapped__') \ + and callable(func.__wrapped__) + + def on_function_call(self, func): + while func is not None: + if self._is_cache_wrapped_function(func): + self.cached_function.add(func) + break + func = getattr(func, '__wrapped__', None) + @contextmanager def do_temp_call_origin(self): temp_call_origin = self.temp_call_origin @@ -692,6 +719,10 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): {}, type_expr=fn.__annotations__.get('return', None), node_result=node_result) finally: _retain_weight_consistency(self.root) + # clean up caches + for func in self.cached_function: + if func is not None: + func.cache_clear() return self.graph diff --git a/nnscaler/graph/tracer/operator_patcher.py b/nnscaler/graph/tracer/operator_patcher.py index 235da0b5..5149d9e5 100644 --- a/nnscaler/graph/tracer/operator_patcher.py +++ b/nnscaler/graph/tracer/operator_patcher.py @@ -350,5 +350,6 @@ def patch_run(func, *args, **kwargs): assert OperatorPatcherContext.ctx_tracer is not None assert OperatorPatcherContext.ctx_patcher is not None with wrap_utils.do_temp_call_origin(): + OperatorPatcherContext.ctx_tracer.on_function_call(func) new_func = OperatorPatcherContext.ctx_patcher.patch_func_or_module(func) return new_func(*args, **kwargs) diff --git a/requirements-dev.txt b/requirements-dev.txt index 7d181749..bd2f7fdf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,3 +19,4 @@ wandb tensorboard mosaicml-streaming cppimport +einops diff --git a/tests/parallel_module/test_gencode_einops.py b/tests/parallel_module/test_gencode_einops.py new file mode 100644 index 00000000..bea1c75a --- /dev/null +++ b/tests/parallel_module/test_gencode_einops.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile +import functools +from einops import rearrange +import torch + +from nnscaler import parallelize, ComputeConfig +from nnscaler.graph import parser +from nnscaler.graph.tracer import ConcreteTracer + +from tests.utils import replace_all_device_with +from .test_gencode import _gencode_contains, print_gencode + + +class RearrangeModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x, y): + return self.linear(x) + rearrange(y, '(h w) -> h w', h=3, w=3) + f(3) + + +def log_f(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + print(f"Function '{func.__name__}' called") + return func(*args, **kwargs) + return wrapper + + +@log_f +@functools.cache +def f(x: int) -> int: + return x * 2 + + +@replace_all_device_with('cpu') +def test_trace_rearrange(): + import gc + def _convert(): + model = RearrangeModule() + parser.to_fx_graph(model, {'x': torch.randn(3, 3), 'y': torch.randn(9)}) + gc.collect() + + _convert() + for obj in gc.get_objects(): + # einops is using functools.cache + # will leak memory if not properly handle it. + assert not isinstance(obj, ConcreteTracer) + + +@replace_all_device_with('cpu') +def test_codegen_rearrange(): + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + RearrangeModule(), + {'x': torch.randn(3, 3), 'y': torch.randn(9)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False + ) + # parallelize will succeed. + assert True diff --git a/tests/parallel_module/test_gencode_torch_compile.py b/tests/parallel_module/test_gencode_torch_compile.py index 700bac9d..62ba534c 100644 --- a/tests/parallel_module/test_gencode_torch_compile.py +++ b/tests/parallel_module/test_gencode_torch_compile.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import tempfile import pytest import torch From cbf14fc084c7e03fbe0c3a0f941819d2410e167d Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 14 Aug 2025 08:42:13 +0000 Subject: [PATCH 1828/1892] Merged PR 2393: [Tracer] add torch compile check Tracing into @torch.compile will lead to confusing error. This PR will detect this case, and raise a clear error. --- nnscaler/graph/tracer/concrete_tracer.py | 25 ++++++++++++++++++- .../test_gencode_torch_compile.py | 4 +-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index dda70381..f4731300 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -129,13 +129,36 @@ def _is_cache_wrapped_function(cls, func): and hasattr(func, '__wrapped__') \ and callable(func.__wrapped__) - def on_function_call(self, func): + def _track_cache_wrapped_function(self, func): while func is not None: if self._is_cache_wrapped_function(func): self.cached_function.add(func) break func = getattr(func, '__wrapped__', None) + @classmethod + def _is_torch_compile_function(cls, func): + return callable(func) \ + and hasattr(func, '__wrapped__') \ + and hasattr(func, '_torchdynamo_orig_callable') + + def _check_torch_compile_function(self, func): + outmost_func = func + while func is not None: + if self._is_torch_compile_function(func): + # If func is registered, run this func will be in a reverted context. + if not self.need_revert(outmost_func): + raise RuntimeError( + f"@torch.compile decorated function `{outmost_func.__module__}.{outmost_func.__qualname__}` is not registered. " + f"You must register it to avoid tracing failure." + ) + break + func = getattr(func, '__wrapped__', None) + + def on_function_call(self, func): + self._track_cache_wrapped_function(func) + self._check_torch_compile_function(func) + @contextmanager def do_temp_call_origin(self): temp_call_origin = self.temp_call_origin diff --git a/tests/parallel_module/test_gencode_torch_compile.py b/tests/parallel_module/test_gencode_torch_compile.py index 62ba534c..4865783c 100644 --- a/tests/parallel_module/test_gencode_torch_compile.py +++ b/tests/parallel_module/test_gencode_torch_compile.py @@ -9,7 +9,7 @@ from nnscaler import parallelize, ComputeConfig, register_op -from tests.utils import replace_all_device_with +from tests.utils import raises_with_cause, replace_all_device_with from .test_gencode import _gencode_contains, print_gencode @@ -182,7 +182,7 @@ def forward(self, x): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_codegen_compile_failed_g(): - with pytest.raises(RuntimeError), tempfile.TemporaryDirectory() as tempdir: + with raises_with_cause(RuntimeError, match=".*You must register it to avoid tracing failure..*"), tempfile.TemporaryDirectory() as tempdir: parallelize( Module2(), {'x': torch.randn(3, 3)}, From f0576f20e225205c89dc039f59260b4450ae4545 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 14 Aug 2025 09:31:11 +0000 Subject: [PATCH 1829/1892] Merged PR 2391: [Runtime] Refine dedup ckpt save and load This PR reduces the size of saved weights for `deduped`. This is achieved by computing the first occurrence of each sub-tensor and only store it at the corresponding rank. --- docs/source/parallel_module.md | 4 +- nnscaler/parallel.py | 83 +++++++++++++++++-- nnscaler/runtime/module.py | 61 +++++++++----- tests/parallel_module/test_attr_dedup.py | 11 ++- .../parallel_module/test_checkpoint_dedup.py | 2 - 5 files changed, 125 insertions(+), 36 deletions(-) diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index f68ec024..4b910a20 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -473,7 +473,7 @@ Please note the `device` parameter. If it is None, we will use `torch.cuda.curre #### `deduped_state_dicts` -In parallel training, a lot of weights/state in the module and the optimizer will be the same in the ranks. So we can save a lot of space by deduping the state dicts before saving them to the disk. +In parallel training, a lot of weights/state in the module and the optimizer will be the same in the ranks. So we can save a lot of space by deduping the state dicts before saving them to the disk. Note each part of a logical tensor is saved at the first rank it appears. ```python def deduped_state_dict( @@ -484,6 +484,8 @@ def deduped_state_dict( #### `load_deduped_state_dict` +This is a reverse process of `deduped_state_dicts`. It assumes the distributed plan is unchanged. For weights, the loading process is divided into 3 steps: 1. each rank read its own state_dict 2. replicated weight is broadcasted inside the first scale unit so that it contains the full parameters 3. the 1st scale unit broadcast weights to other units. + ```python def load_deduped_state_dict( module: torch.nn.Module, diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index b728269b..6bd2f5a3 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -14,6 +14,7 @@ import logging import copy import os +from collections import OrderedDict import torch import torch.distributed @@ -43,7 +44,7 @@ from nnscaler.runtime.adapter.reducer import Reducer from nnscaler.runtime.device import DeviceGroup from nnscaler.runtime.gnorm import calcuate_gnorm, clip_grads -from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState +from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState, dedup_attrs from nnscaler.flags import CompileFlag, RuntimeFlag import nnscaler.policies as policies @@ -2202,6 +2203,34 @@ def _broadcast_gen_files( torch.distributed.barrier() +def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Dict[str, Any]: + """ + A helper function that computes the deduplicated attribute information from all ranks. + Note that this function may be removed in the future and dedup information are computed + directly at the compilation stage. + """ + dedup_group_size = None + for prefix, parallel_module in parallel_modules.items(): + if dedup_group_size is None: + dedup_group_size = parallel_module.module_dedup_group_size + else: + assert dedup_group_size == parallel_module.module_dedup_group_size, \ + f'dedup_group_size mismatch {dedup_group_size} vs {parallel_module.module_dedup_group_size}' + dedup_group_size = dedup_group_size or 1 + + world_size = torch.distributed.get_world_size() + local_fullmaps = {prefix: m.fullmap for prefix, m in parallel_modules.items()} + global_fullmaps = [None for _ in range(world_size)] + torch.distributed.all_gather_object(global_fullmaps, local_fullmaps) + # `dedup_attrs` is a deterministic algorithm, so it produces same results across different ranks + rank2deduped_fullmap = dedup_attrs(OrderedDict(list(enumerate(global_fullmaps)))) + + for rank in range(dedup_group_size, world_size): + assert len(rank2deduped_fullmap[rank]) == 0, f'Rank {rank} has non-empty deduped_fullmap: {rank2deduped_fullmap[rank]}' + + return rank2deduped_fullmap, dedup_group_size + + @torch.no_grad() def deduped_state_dict( module: torch.nn.Module, @@ -2224,6 +2253,9 @@ def deduped_state_dict( module_state_dict, opt_state_dict = None, None parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} + rank2deduped_fullmap, _ = _collect_dedup_info(parallel_modules) + cur_deduped_fullmap = rank2deduped_fullmap[cur_rank] + # The reason we use `Module.state_dict` on the whole to get the complete state dict # instead of call `Module.state_dict` on each submodule # is to make sure the hooks to state_dict are called. @@ -2231,11 +2263,13 @@ def deduped_state_dict( for key in list(module_state_dict.keys()): if key.endswith(ParallelModule.EXTRA_STATE_KEY): # never remove extra state continue - prefix = '.'.join(key.split('.')[:-1]) # remove the last part of the key - dedup_group_size = parallel_modules[prefix].module_dedup_group_size \ - if prefix in parallel_modules else 1 - # only keep the first `dedup_group_size` ranks' state - if cur_rank >= dedup_group_size: + split_names = key.split('.') + prefix = '.'.join(split_names[:-1]) # remove the last part of the key + if prefix in parallel_modules: + if prefix not in cur_deduped_fullmap or split_names[-1] not in cur_deduped_fullmap[prefix]: + module_state_dict.pop(key, None) + # since replicated non-parallel modules, we only keep weights on rank 0 + elif cur_rank >= 1: module_state_dict.pop(key, None) if optimizer is not None: @@ -2285,13 +2319,44 @@ def load_deduped_state_dict( None """ device = device or torch.cuda.current_device() + cur_rank = torch.distributed.get_rank() - # only load partial state for all ranks except rank 0 + # step 1: load deduped state dict at each rank module.load_state_dict(module_state_dict, strict=False) - module.to(device) + torch.distributed.barrier() + logger.debug(f'at rank {cur_rank}, state_dict keys: {module_state_dict.keys()}') + + # step 2: broadcast deduped weights inside 1st scale unit + parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} + rank2deduped_fullmap, dedup_group_size = _collect_dedup_info(parallel_modules) + broadcast_group = DeviceGroup().get_group(list(range(dedup_group_size))) + logger.debug(f'at rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}') + if cur_rank < dedup_group_size: + for rank, deduped_fullmap in rank2deduped_fullmap.items(): + logger.debug(f'at rank {cur_rank}, process rank: {rank}') + for prefix, fullmap in deduped_fullmap.items(): + for local_name, _ in fullmap.items(): + key = f'{prefix}.{local_name}' if prefix else local_name + if rank == cur_rank: + assert key in module_state_dict, f'expect {key} in {module_state_dict.keys()}' + object_list = [module_state_dict[key]] + logger.debug(f'at rank {cur_rank}, broadcast: {key} from {cur_rank}') + else: + object_list = [None] + torch.distributed.broadcast_object_list(object_list, src=rank, group=broadcast_group) + if rank != cur_rank: + tensor = object_list[0] + logger.debug(f'at rank {cur_rank}, try to load: {key} to rank {cur_rank}') + assert prefix in parallel_modules, f'prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}' + pm = parallel_modules[prefix] + # in pipeline parallelism, the local_name may not be found in the module + if hasattr(pm, local_name): + attr = getattr(pm, local_name) + attr.data.copy_(tensor) torch.distributed.barrier() - # broadcast weights + module.to(device) + # step 3: broadcast weights from 1st scale unit to other units broadcast_weights(module) if optimizer is not None: diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 0e26d483..35ad0244 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -48,22 +48,33 @@ class AttrMeta: val_chunks: int -def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, AttrMeta]]) -> Dict[int, Dict[str, AttrMeta]]: +def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, Dict[str, AttrMeta]]]) -> Dict[int, Dict[str, Dict[str, AttrMeta]]]: ''' Deduplicate the attributes according to `rank2attr_area_map`. - For each `slicers` of a full tensor with the name `orig_name`, we only store its first appearance - in the `rank2attr_area_map`. + For each `slicers` of a full tensor identified by its full qualified name, we only store its first appearance + in the `rank2attr_area_map`. In nnscaler, this dedup process leads to: + - If an attribute is not within the first scale unit, it will be deduplicated. + - If an attribute is shared by different operators, it will be deduplicated. + - If an attribute is replicated across several devices, we only save it at the devices with the smallest rank. + - If an attribute is partitioned across several devices, all these sub tensors will be saved. + - Note that nnscaler supports partition an operator across multiple dimensions, attributes in the operator may + be saved at a subset of related devices. + - Pipeline parallelism is supported since it is composed of different segments in nnscaler, which are different + parallel modules with their own attribute maps at runtime. In addition, we will check - the shape of the full tensor is consistent across different ranks - the slicers of the full tensor are not intersected with each other - the slicers of the full tensor can cover the full tensor - The input and output attribute area map's key is the local attribute name. Args: - rank2attr_area_map (Dict[int, Dict[str, AttrMeta]]): the mapping from rank to the attribute area map + rank2attr_area_map ( + Dict[int, # rank id + Dict[str, # submodule prefix + Dict[str, # attribute name in parallel module (not original name) + AttrMeta]]]): fullmap information for all parallel modules in all ranks. Returns: - Dict[int, Dict[str, AttrMeta]]: the deduplicated attribute area map + Dict[int, Dict[str, Dict[str, AttrMeta]]]: the deduplicated fullmap info, the structure is the same as the input. ''' # assume ranks in rank2attr_area_map are in increasing order ranks = list(rank2attr_area_map.keys()) @@ -87,26 +98,32 @@ def need_save(slicers: Tuple[slice, ...], saved_slicers_list: List[Tuple[slice, return True ret = dict() - for rank, attr_area_map in rank2attr_area_map.items(): - dedup_attr_area_map = dict() - for attr, attr_meta in attr_area_map.items(): - assert attr_meta.val_chunks == 1, 'not support partitioning on value dimension' - if attr_meta.orig_name not in orig_name2shape: - orig_name2shape[attr_meta.orig_name] = attr_meta.shape - else: - assert orig_name2shape[attr_meta.orig_name] == attr_meta.shape, \ - f'unmatched shape {orig_name2shape[attr_meta.orig_name]} vs {attr_meta.shape}' - if need_save(attr_meta.slicers, orig_name2slice_info[attr_meta.orig_name]): - orig_name2slice_info[attr_meta.orig_name].append(attr_meta.slicers) - dedup_attr_area_map[attr] = attr_meta - ret[rank] = dedup_attr_area_map + for rank, module_fullmaps in rank2attr_area_map.items(): + dedup_module_fullmaps = dict() + for module_name, attr_area_map in module_fullmaps.items(): + dedup_attr_area_map = dict() + for attr, attr_meta in attr_area_map.items(): + assert attr_meta.val_chunks == 1, 'not support partitioning on value dimension' + # use module_name.orig_name as the unique identifier for full tensor + full_tensor_name = f"{module_name}.{attr_meta.orig_name}" + if full_tensor_name not in orig_name2shape: + orig_name2shape[full_tensor_name] = attr_meta.shape + else: + assert orig_name2shape[full_tensor_name] == attr_meta.shape, \ + f'unmatched shape {orig_name2shape[full_tensor_name]} vs {attr_meta.shape}' + if need_save(attr_meta.slicers, orig_name2slice_info[full_tensor_name]): + orig_name2slice_info[full_tensor_name].append(attr_meta.slicers) + dedup_attr_area_map[attr] = attr_meta + if dedup_attr_area_map: # only add non-empty maps + dedup_module_fullmaps[module_name] = dedup_attr_area_map + ret[rank] = dedup_module_fullmaps # since we # - skip saving when there are identical weights # - assert the slicers are disjoint # we can use the sum of the sub-slicers to verify the full tensor is covered - for orig_name, slicerss in orig_name2slice_info.items(): - shape = orig_name2shape[orig_name] + for full_tensor_name, slicerss in orig_name2slice_info.items(): + shape = orig_name2shape[full_tensor_name] full_size = 1 for s in shape: full_size *= s @@ -116,7 +133,7 @@ def need_save(slicers: Tuple[slice, ...], saved_slicers_list: List[Tuple[slice, for s in slicers: size *= s.stop - s.start covered_size += size - assert full_size == covered_size, f'uncovered size for {orig_name} with shape {shape}, slicerss {slicerss}' + assert full_size == covered_size, f'uncovered size for {full_tensor_name} with shape {shape}, slicerss {slicerss}' return ret diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py index 64bf490a..b1a5a95b 100644 --- a/tests/parallel_module/test_attr_dedup.py +++ b/tests/parallel_module/test_attr_dedup.py @@ -22,6 +22,7 @@ from ..launch_torchrun import launch_torchrun from ..utils import clear_dir_on_rank0 + class Net(torch.nn.Module): def __init__(self): super().__init__() @@ -38,6 +39,7 @@ def forward(self, x): x = self.buffer + x return x + def pas(graph: IRGraph, config: ComputeConfig): fw_nodes = graph.select(ntype=IRFwOperation) assert len(fw_nodes) == 4 @@ -50,6 +52,7 @@ def pas(graph: IRGraph, config: ComputeConfig): _replica(graph, fw_nodes[3], devs=devs) return graph + def _gpu_worker_spmd(cc: ComputeConfig): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_dedup_attr') as tempdir: @@ -65,13 +68,17 @@ def _gpu_worker_spmd(cc: ComputeConfig): world_size = torch.distributed.get_world_size() attr_area_maps = [None for _ in range(world_size)] curr_rank = torch.distributed.get_rank() - torch.distributed.all_gather_object(attr_area_maps, module.fullmap) + # Construct the three-level nested structure: rank -> module_name -> fullmap + # In this test case, we have only one module instance 'attr_dedup' + module_fullmap = {'attr_dedup': module.fullmap} + torch.distributed.all_gather_object(attr_area_maps, module_fullmap) rank2attr_area_map = {} for i, attr_area_map in enumerate(attr_area_maps): rank2attr_area_map[i] = attr_area_map torch.distributed.barrier() dedup_meta_info = dedup_attrs(rank2attr_area_map) - dedup_area_map = list(dedup_meta_info[curr_rank].items()) + # Access the deduped fullmap for the specific module + dedup_area_map = list(dedup_meta_info[curr_rank]['attr_dedup'].items()) if curr_rank == 0: assert len(dedup_area_map) == 4 assert dedup_area_map[0][1].orig_name == 'fc1.weight' diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index 339ea888..a5fec814 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -139,8 +139,6 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): if not isinstance(model, ParallelModule): # in this case, non parallel module is removed, so it should have less keys assert len(parallel_modules) < len(dedupped_model_state_dict) < len(model_state_dict) - else: - assert len(dedupped_model_state_dict) == len(model_state_dict) for k, v in dedupped_model_state_dict.items(): assert_equal(v, model_state_dict[k]) From ac1c72f60759c712fc9d5902eb36107e364820c1 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 20 Aug 2025 05:28:48 +0000 Subject: [PATCH 1830/1892] Merged PR 2394: [BugFix] fix bug in dedup load: handle persistent buffer correctly & use broadcast instead of broadcast_object_list to save time --- nnscaler/parallel.py | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 6bd2f5a3..ce5a0d0f 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2220,15 +2220,26 @@ def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Dict[str world_size = torch.distributed.get_world_size() local_fullmaps = {prefix: m.fullmap for prefix, m in parallel_modules.items()} + local_tensor_meta = dict() + for prefix, m in parallel_modules.items(): + module_meta = {} + for local_name in m.fullmap.keys(): + assert hasattr(m, local_name), f'Module {prefix} does not have attribute {local_name}' + tensor = getattr(m, local_name) + module_meta[local_name] = (tuple(tensor.shape), tensor.dtype) + local_tensor_meta[prefix] = module_meta global_fullmaps = [None for _ in range(world_size)] torch.distributed.all_gather_object(global_fullmaps, local_fullmaps) # `dedup_attrs` is a deterministic algorithm, so it produces same results across different ranks rank2deduped_fullmap = dedup_attrs(OrderedDict(list(enumerate(global_fullmaps)))) + global_tensor_meta = [None for _ in range(world_size)] + torch.distributed.all_gather_object(global_tensor_meta, local_tensor_meta) + global_tensor_meta = OrderedDict(list(enumerate(global_tensor_meta))) for rank in range(dedup_group_size, world_size): assert len(rank2deduped_fullmap[rank]) == 0, f'Rank {rank} has non-empty deduped_fullmap: {rank2deduped_fullmap[rank]}' - return rank2deduped_fullmap, dedup_group_size + return rank2deduped_fullmap, dedup_group_size, global_tensor_meta @torch.no_grad() @@ -2253,7 +2264,7 @@ def deduped_state_dict( module_state_dict, opt_state_dict = None, None parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} - rank2deduped_fullmap, _ = _collect_dedup_info(parallel_modules) + rank2deduped_fullmap, _, _ = _collect_dedup_info(parallel_modules) cur_deduped_fullmap = rank2deduped_fullmap[cur_rank] # The reason we use `Module.state_dict` on the whole to get the complete state dict @@ -2323,39 +2334,39 @@ def load_deduped_state_dict( # step 1: load deduped state dict at each rank module.load_state_dict(module_state_dict, strict=False) + module.to(device) torch.distributed.barrier() logger.debug(f'at rank {cur_rank}, state_dict keys: {module_state_dict.keys()}') # step 2: broadcast deduped weights inside 1st scale unit parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} - rank2deduped_fullmap, dedup_group_size = _collect_dedup_info(parallel_modules) + rank2deduped_fullmap, dedup_group_size, global_tensor_meta = _collect_dedup_info(parallel_modules) broadcast_group = DeviceGroup().get_group(list(range(dedup_group_size))) logger.debug(f'at rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}') if cur_rank < dedup_group_size: for rank, deduped_fullmap in rank2deduped_fullmap.items(): logger.debug(f'at rank {cur_rank}, process rank: {rank}') for prefix, fullmap in deduped_fullmap.items(): - for local_name, _ in fullmap.items(): + for local_name, attr_meta in fullmap.items(): key = f'{prefix}.{local_name}' if prefix else local_name + assert prefix in parallel_modules, f'prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}' + pm = parallel_modules[prefix] + shape, dtype = global_tensor_meta[rank][prefix][local_name] if rank == cur_rank: - assert key in module_state_dict, f'expect {key} in {module_state_dict.keys()}' - object_list = [module_state_dict[key]] - logger.debug(f'at rank {cur_rank}, broadcast: {key} from {cur_rank}') + assert hasattr(pm, local_name), f'local_name {local_name} not found in {pm}' + broadcast_tensor = getattr(pm, local_name) + logger.info(f'at rank {cur_rank}, broadcast: {key} from {cur_rank}') else: - object_list = [None] - torch.distributed.broadcast_object_list(object_list, src=rank, group=broadcast_group) + broadcast_tensor = torch.empty(shape, device=device, requires_grad=False, dtype=dtype) + torch.distributed.broadcast(broadcast_tensor, src=rank, group=broadcast_group) if rank != cur_rank: - tensor = object_list[0] - logger.debug(f'at rank {cur_rank}, try to load: {key} to rank {cur_rank}') - assert prefix in parallel_modules, f'prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}' - pm = parallel_modules[prefix] + logger.info(f'at rank {cur_rank}, try to load: {key} to rank {cur_rank}') # in pipeline parallelism, the local_name may not be found in the module if hasattr(pm, local_name): attr = getattr(pm, local_name) - attr.data.copy_(tensor) + attr.data.copy_(broadcast_tensor) torch.distributed.barrier() - module.to(device) # step 3: broadcast weights from 1st scale unit to other units broadcast_weights(module) From 1474acd6f9415d19c72aba067b92ebdb6331014b Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 21 Aug 2025 03:19:44 +0000 Subject: [PATCH 1831/1892] Merged PR 2396: [AutoDist] Refine error handling and logging - if the plan_ngpus is greater than existing communication profile data, we will throw a warning and use previous data - add more information when `gen_masks` fails - when allowed dims in provided partition constraints are not correct, will generate warning, exception maybe throwed if there are no valid partitions --- nnscaler/autodist/autodist_config.py | 2 +- nnscaler/autodist/cost_database.py | 34 ++++++++++++++++++++-------- nnscaler/autodist/spmd_solver.py | 13 +++++++---- nnscaler/graph/function/dimops.py | 9 ++++++-- 4 files changed, 41 insertions(+), 17 deletions(-) diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 796d35e2..b790e459 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -244,7 +244,7 @@ def _validate_config(self): scale_factor = self.world_size // self.mesh_desc.ngpus if scale_factor % self.zero_ngroups != 0: raise ValueError( - f'world size {self.world_size} must be divisible by zero num groups {self.zero_ngroups}' + f'scale_factor {scale_factor} must be divisible by zero num groups {self.zero_ngroups}' ) if not self.solver in [ diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 2e003359..97c446c3 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -119,27 +119,37 @@ def _load_comm_data(profile_dir: Path, plan_ngpus: int) -> Dict[str, Dict[str, L load intra_2.json, intra_4.json, and intra_8.json from the profile directory. If any of the files is not found, we will use the default data as well. ''' - def loader(path: Path): + def loader(path: Path, strict: bool): if not os.path.exists(path): return False, None info = {} dev = 2 + prev_info = None while dev <= plan_ngpus: fname = f'intra_{dev}.json' if not (path / fname).exists(): - return False, None - with open(path / fname, 'r') as f: - info[fname] = json.load(f) + if strict or prev_info is None: + return False, None + else: + content = prev_info + _logger.warning(f'{dev} devices communication profile data not found, using previous data') + else: + with open(path / fname, 'r') as f: + content = json.load(f) + prev_info = content + info[fname] = content dev *= 2 return True, info comm_path = profile_dir / 'comm' - success, comm_info = loader(comm_path) + success, comm_info = loader(comm_path, strict=True) if not success: + # When communication profile data is not found, use the default data. If the input `plan_ngpus` is greater + # than the devices in the profile data, the data with largest device count (16 for mi200) will be used. This + # is helpful when user wants to generate a distributed plan spanning over multiple nodes. _logger.warning(f'Communication profile data not found, using default data at {_DEFAULT_COMM_DATA_PATH}') - success, comm_info = loader(Path(_DEFAULT_COMM_DATA_PATH)) - if not success: - raise RuntimeError(f'Communication profile data is not compatible with plan_ngpus {plan_ngpus}') + success, comm_info = loader(Path(_DEFAULT_COMM_DATA_PATH), strict=False) + assert success, f'Failed to load default communication profile data from {_DEFAULT_COMM_DATA_PATH}, please check nnscaler\'s installation' return comm_info @@ -337,10 +347,14 @@ def query_single_mem(self, obj, memory_type, round=True) -> int: from .op_partition import OpPartition from .cube_operator import CubeOperator if isinstance(obj, OpPartition): - masks = self.gen_masks(obj.operator) + query_obj = obj.operator else: assert isinstance(obj, CubeOperator) - masks = self.gen_masks(obj) + query_obj = obj + try: + masks = self.gen_masks(query_obj) + except Exception as e: + raise RuntimeError(f"Failed to generate masks for {query_obj} with {self.query_profiled_metrics(query_obj)}: {e}") if memory_type == 'full_weight' and isinstance(obj, OpPartition): profiled_metrics = self.query_profiled_metrics(obj.operator) else: diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 1a1e0ec8..9decc8a8 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -358,10 +358,15 @@ def is_valid_partition(operator: CubeOperator, p_ids: List[Any], if not selected_pc.replica_allowed: return False else: - allowed_pids = [ - operator.pos2dim_id(pos) - for pos in selected_pc.allowed_partition_dims - ] + allowed_pids = list() + for pos in selected_pc.allowed_partition_dims: + # When allowed dims in provided partition constraints are not correct generate warning + # If there is no valid partitions for the operator, the solver will throw exception later. + try: + cur_allowed_pid = operator.pos2dim_id(pos) + allowed_pids.append(cur_allowed_pid) + except Exception as e: + _logger.warning(f"Failed to get allowed partition id for {selected_pc}'s {pos}: {e}") if u not in allowed_pids: return False diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index 333b01ad..aaa2f5e3 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -853,7 +853,7 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict op_anno.reset_identifiers() identifier_values: Dict[str, int] = dict() - for ashape, itensor in zip(op_anno.inputs(), inputs): + for idx, (ashape, itensor) in enumerate(zip(op_anno.inputs(), inputs)): if not isinstance(itensor, IRTensor) or ashape.ignore: continue if ashape.ndims != len(itensor.shape): @@ -861,7 +861,12 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict for adim, dimlen in zip(ashape.dims, itensor.shape): if len(adim.identifiers) == 1: if adim.identifiers[0] in identifier_values and identifier_values[adim.identifiers[0]] != dimlen: - raise RuntimeError(f'the exist identifier value {identifier_values[adim.identifiers[0]]} is not equal to the new value {dimlen}') + error_msg = ( + f"at {signature} with {op_anno} the exist identifier {adim.identifiers[0]} value " + f"{identifier_values[adim.identifiers[0]]} is not equal to the new value {dimlen}, " + f"error idx {idx}, input tensors {inputs}" + ) + raise RuntimeError(error_msg) identifier_values[adim.identifiers[0]] = dimlen # check dimension consistency From 401bea6bb49403f93d6f5205f63334474494432b Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 27 Aug 2025 04:46:03 +0000 Subject: [PATCH 1832/1892] Merged PR 2397: [Runtime] Support offload params for parallel module --- ...armup_schedular.py => warmup_scheduler.py} | 3 +- nnscaler/runtime/adapter/reducer.py | 149 ++++++++---- nnscaler/runtime/module.py | 98 ++++++++ tests/parallel_module/test_offload_params.py | 213 ++++++++++++++++++ 4 files changed, 418 insertions(+), 45 deletions(-) rename examples/{warmup_schedular.py => warmup_scheduler.py} (97%) create mode 100644 tests/parallel_module/test_offload_params.py diff --git a/examples/warmup_schedular.py b/examples/warmup_scheduler.py similarity index 97% rename from examples/warmup_schedular.py rename to examples/warmup_scheduler.py index b066c385..54e8aa7f 100644 --- a/examples/warmup_schedular.py +++ b/examples/warmup_scheduler.py @@ -30,12 +30,11 @@ def __init__( T_max: int, eta_min=0.0, last_epoch=-1, - verbose="deprecated", ): # noqa: D107 self.warmup_steps = warmup_steps self.T_max = T_max - warmup_steps + 1 self.eta_min = eta_min - super().__init__(optimizer, last_epoch, verbose) + super().__init__(optimizer, last_epoch) def get_lr(self): """Retrieve the learning rate of each parameter group.""" diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index be83fdc3..679cdde9 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -168,17 +168,7 @@ def _group_reduce_scatter(self): partial_tensor, self._contiguous_grads, op=self._reduce_op, group=self._zero_subgroup) - def build(self): - """ - Build offset for each parameter - This should only be called once during the construction of bucket. - """ - ofst = 0 - for param in self._params: - self._pofset[param] = ofst - ofst += _aligned_nelement(param.nelement(), param.element_size(), self._align_size) - # build parameter for optimizer (shared storage). - # Its gradient will be updated everytime calling `self.sync_grads()` + def _get_opt_param_data(self): if not self._zero: opt = self._contiguous_params else: @@ -190,7 +180,20 @@ def build(self): # the calculation of gnorm is not affected as long as the paddings are all 0. # So for now, it looks harmless. opt = self._contiguous_params.chunk(self._zgroup_sz)[rank] - self._param_for_optimizer = torch.nn.Parameter(opt) + return opt + + def build(self): + """ + Build offset for each parameter + This should only be called once during the construction of bucket. + """ + ofst = 0 + for param in self._params: + self._pofset[param] = ofst + ofst += _aligned_nelement(param.nelement(), param.element_size(), self._align_size) + # build parameter for optimizer (shared storage). + # Its gradient will be updated everytime calling `self.sync_grads()` + self._param_for_optimizer = torch.nn.Parameter(self._get_opt_param_data()) def register_hooks(self): """ @@ -363,6 +366,31 @@ def reset(self): self._async_param_cnt = 0 self._async_handle = None + def sleep(self): + """ + release reference to contiguous buffer in reducer + """ + cpu = torch.device('cpu') + self._param_for_optimizer.data = self._param_for_optimizer.data.to(cpu) + # set none to release memory + self._contiguous_params = None + self._contiguous_grads = None + + def wake_up(self, param_buffer, grad_buffer): + """ + re-attach to the contiguous buffer and re-build hooks + """ + self._contiguous_params = param_buffer + self._contiguous_grads = grad_buffer + self._param_for_optimizer.data = self._get_opt_param_data() + + # TODO(yizhu1): seems moving attributes to cpu will make hooks invalid. + # The reason is that torch's autograd will reset the AccumulateGrad object if the data is set: + # https://github.com/pytorch/pytorch/blob/38a492d40d7ebb2856cb120df337c6cdac244528/torch/csrc/autograd/variable.cpp#L473 + # To make the resuming process safe, re-register them here. + self._hooks = [] + self.register_hooks() + class Reducer: # the default bucket cap for async reducer in megabytes @@ -419,6 +447,13 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None self._contiguous_params: torch.Tensor = None self._contiguous_grads: torch.Tensor = None + # record following variables for params offload + # items in the bucket is params list + self.seq_buckets: List[List[torch.nn.Parameter]] = [] + # bucket start and stop pos in buffer + self.starts, self.stops = [], [] + self.buffer_length: int = 0 + # build the subgroup of zero the current rank belongs to. # When zero_ngroups is larger than 1, the number of ranks # will be divided by zero_ngroups into sub rank groups, @@ -506,6 +541,27 @@ def add_param(self, param: torch.nn.Parameter): self._param_ids.add(param.data.data_ptr()) self._numel += param.numel() + def _allocate_buffers(self): + # gradient buffer + self._contiguous_grads: torch.Tensor = torch.zeros( + (self.buffer_length,), dtype=self._params[0].dtype, + device=torch.cuda.current_device(), requires_grad=False) + # parameter buffer + self._contiguous_params: torch.Tensor = torch.zeros( + (self.buffer_length,), dtype=self._params[0].dtype, + device=torch.cuda.current_device(), requires_grad=False) + + def _bind_params(self): + for params, start, stop in zip(self.seq_buckets, self.starts, self.stops): + # replace underlying parameter content using shared storage from parameter + ofst = start + for param in params: + with torch.no_grad(): + self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) + param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) + aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) + ofst += aligned_nelements + def build_buckets(self): """ Build buckets the reducer. @@ -526,8 +582,6 @@ def build_buckets(self): # (used in pytorch, with a couple percentage improvement) bucket_size = self._numel * 8 + 1 if not self._bucket_size else self._bucket_size - # items in the bucket is params list - seq_buckets: List[List[torch.nn.Parameter]] = [] last_bucket_size = None assert len(set(p.dtype for p in self._params)) == 1, ( @@ -540,51 +594,37 @@ def build_buckets(self): # It will go the `else` branch # and finish the current bucket and start a new bucket. # This new bucket will be sealed in the next iteration - if len(seq_buckets) == 0: - seq_buckets.append([param]) + if len(self.seq_buckets) == 0: + self.seq_buckets.append([param]) last_bucket_size = cur_byte_size elif last_bucket_size + cur_byte_size <= bucket_size: - seq_buckets[-1].append(param) + self.seq_buckets[-1].append(param) last_bucket_size += cur_byte_size else: - seq_buckets.append([param]) + self.seq_buckets.append([param]) last_bucket_size = cur_byte_size # step 2: build meta data for the offset of each bucket # the start of each bucket will be padded to the next multiple of `len(self.ranks)` - buffer_length: int = 0 - starts, stops = [], [] - for params in seq_buckets: - starts.append(buffer_length) + for params in self.seq_buckets: + self.starts.append(self.buffer_length) numel = sum(_aligned_nelement(p.nelement(), p.element_size(), self._align_size) for p in params) # this pad is for zero, which needs numels in each Bucket can be divided by the number of ranks in this group * _align_size # so that each chunck during zero can be divided by _align_size align_nelements = self._align_size // params[0].element_size() * len(self._ranks) padding = (align_nelements - numel % align_nelements) % len(self._ranks) - buffer_length += numel + padding - stops.append(buffer_length) + self.buffer_length += numel + padding + self.stops.append(self.buffer_length) - # step3: allocate memory - # gradient buffer - self._contiguous_grads: torch.Tensor = torch.zeros( - (buffer_length,), dtype=self._params[0].dtype, - device=torch.cuda.current_device(), requires_grad=False) - # parameter buffer - self._contiguous_params: torch.Tensor = torch.zeros( - (buffer_length,), dtype=self._params[0].dtype, - device=torch.cuda.current_device(), requires_grad=False) + # step 3: allocate memory + self._allocate_buffers() - # step 4: build buckets + # step 4: bind parameters + self._bind_params() + + # step 5: build buckets buckets: List[Bucket] = [] - for params, start, stop in zip(seq_buckets, starts, stops): - # replace underlying parameter content using shared storage from parameter - ofst = start - for param in params: - with torch.no_grad(): - self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) - param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) - aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) - ofst += aligned_nelements + for params, start, stop in zip(self.seq_buckets, self.starts, self.stops): # initialize buckets bucket = Bucket( params, @@ -723,3 +763,26 @@ def clear_post_hooks(self): """Clear all post hooks.""" for bucket in self._buckets: bucket.clear_post_hooks() + + def sleep(self): + """ + release contiguous buffers on the device to save memory + """ + for bucket in self._buckets: + bucket.sleep() + + self._contiguous_params = None + self._contiguous_grads = None + + def wake_up(self): + """ + reallocate contiguous buffers and related objects + """ + self._allocate_buffers() + self._bind_params() + + for start, stop, bucket in zip(self.starts, self.stops, self._buckets): + bucket.wake_up( + self._contiguous_params[start:stop], + self._contiguous_grads[start:stop], + ) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 35ad0244..d892336f 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -2,15 +2,19 @@ # Licensed under the MIT License. from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union +from typing_extensions import Self import logging import os import sys +import gc +import warnings from pathlib import Path from dataclasses import dataclass, asdict from collections import defaultdict import torch import torch.distributed as dist +from torch import device from nnscaler.graph.parser import FxModuleParser @@ -735,6 +739,100 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): 'optim_state_dict': merged_optimizer_state_dict }, filename_prefix + '.full.ckpt') + def sleep(self): + """ + Move attributes (buffer and param) to cpu and release contiguous buffer in reducers. Different from + nn.Module's cpu() method, references to attributes are unchanged. + """ + for name, param in self.named_parameters(): + assert param.grad is None, f'expect {name} with shape {param.shape} has no grad' + + for reducer in self._reducers: + reducer.zero_grad() + + # we want attribute references are unchanged, so super().cpu() is not used here + cpu = torch.device('cpu') + for buffer in self.buffers(): + buffer.data = buffer.data.to(cpu) + + for param in self.parameters(): + param.data = param.data.to(cpu) + + for reducer in self._reducers: + reducer.sleep() + + gc.collect() + torch.cuda.empty_cache() + return self + + def wake_up(self, device: Optional[Union[int, device]] = None) -> Self: + """ + Move attributes (buffer and param) back to gpu and reallocate memories in reducers. It is a reverse + operation of `self.sleep()`. + """ + gpu = torch.cuda.current_device() + if device is not None: + if isinstance(device, int): + index = device + elif isinstance(device, torch.device): + index = device.index + else: + raise RuntimeError(f'unexpected device type {type(device)}') + assert gpu == index, f'nnscaler module does not support cross gpu transport, expect {gpu} but got {index}' + + for name, param in self.named_parameters(): + assert param.grad is None, f'expect {name} with shape {param.shape} has no grad' + + # we want attribute references are unchanged, so super().gpu() is not used here + for buffer in self.buffers(): + buffer.data = buffer.data.to(gpu) + + for param in self.parameters(): + param.data = param.data.to(gpu) + + for reducer in self._reducers: + reducer.wake_up() + + gc.collect() + torch.cuda.empty_cache() + return self + + def to(self, *args, **kwargs): + """ + Override nn.Module's to function, currently we only allow transfer data from host and device + + Args: + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module + dtype (:class:`torch.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module + tensor (torch.Tensor): Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + memory_format (:class:`torch.memory_format`): the desired memory + format for 4D parameters and buffers in this module (keyword + only argument) + + Returns: + Module: self + """ + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) + if dtype is not None: + raise ValueError(f'nnscaler does not support passing dtype {dtype} to to()') + if convert_to_format is not None: + raise ValueError(f'nnscaler does not support passing convert_to_format {convert_to_format} to to()') + if non_blocking is not None: + warnings.warn(f'nnscaler moves tensors in a blocking approach currently') + + # after _parse_to `device` must in type of torch.device + if device.type == 'cpu': + return self.cpu() + elif device.type == 'cuda': + return self.cuda(device) + else: + raise ValueError(f'unsupported device type {device}') + @dataclass class OriginModuleMetadata: diff --git a/tests/parallel_module/test_offload_params.py b/tests/parallel_module/test_offload_params.py new file mode 100644 index 00000000..81e5d305 --- /dev/null +++ b/tests/parallel_module/test_offload_params.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile +from pathlib import Path +import pytest +from typing import Dict, Tuple, List, Any + +import torch +from torch import nn + +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.graph import IRGraph + +from .common import PASMegatron, CubeLinear, init_random, init_distributed, assert_equal +from ..launch_torchrun import launch_torchrun +from ..utils import clear_dir_on_rank0 + + +class SimpleMLP(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super(SimpleMLP, self).__init__() + init_random() + self.register_buffer('buffer', torch.zeros(hidden_dim,)) + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = x + self.buffer + x = torch.relu(x) + x = self.fc2(x) + return x + + +def get_tensor_bytesize(t: torch.Tensor) -> int: + return t.numel() * t.element_size() + + +def pas_test_offload(graph: IRGraph, cfg: ComputeConfig): + ngpus = cfg.plan_ngpus + auto_multiref(graph) + + batch_dim = 0 + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, list(range(ngpus))) + + found_linear = False + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if 'linear' in node.signature and not found_linear: + found_linear = True + algo = node.algorithm('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=1, num=ngpus) + else: + sub_nodes = graph.replicate(node, ngpus) + + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def _mem_worker(): + init_distributed() + bsz, dim = 32, 1024 + compute_config = ComputeConfig( + plan_ngpus=1, + runtime_ngpus=2, + ) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_offload_mem') as tempdir: + module = SimpleMLP(dim, dim, dim) + p_module = parallelize( + module, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + ) + + before_mem = torch.cuda.memory_allocated() + size_to_free = 0 + for reducer in p_module.reducers: + assert get_tensor_bytesize(reducer._contiguous_params) == get_tensor_bytesize(reducer._contiguous_grads) + size_to_free += get_tensor_bytesize(reducer._contiguous_params) + + for buffer in p_module.buffers(): + size_to_free += get_tensor_bytesize(buffer) + + for param in p_module.parameters(): + size_to_free += get_tensor_bytesize(param) + + p_module.sleep() + torch.distributed.barrier() + after_mem = torch.cuda.memory_allocated() + print(f"Memory before offload: {before_mem}, after offload: {after_mem}, freed: {before_mem - after_mem}") + print(f"Total size to free: {size_to_free}") + + assert size_to_free == before_mem - after_mem, f"Expected {size_to_free}, but got {before_mem - after_mem}" + + +def _correctness_worker(): + init_distributed() + bsz, dim, num_steps = 32, 1024, 5 + compute_config = ComputeConfig( + plan_ngpus=1, + runtime_ngpus=2, + ) + + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_offload_correctness') as tempdir: + # Create test data + torch.manual_seed(42 + torch.distributed.get_rank()) + test_data = [torch.randn(bsz, dim).cuda() for _ in range(num_steps)] + + # Test 1: Normal execution without offload/load + init_random() + module1 = SimpleMLP(dim, dim, dim) + p_module1 = parallelize( + module1, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + instance_name='normal' + ) + optimizer1 = build_optimizer(p_module1, torch.optim.Adam, lr=0.01) + + results_normal = [] + for step, x in enumerate(test_data): + p_module1.train() + output = p_module1(x) + loss = output.sum() + loss.backward() + optimizer1.step() + optimizer1.zero_grad() + + # Save intermediate results for comparison + results_normal.append({ + 'loss': loss.detach().cpu(), + 'output': output.detach().cpu(), + 'params': {name: param.detach().cpu().clone() for name, param in p_module1.named_parameters()} + }) + + torch.distributed.barrier() + + # Test 2: Execution with offload/load + init_random() + module2 = SimpleMLP(dim, dim, dim) + p_module2 = parallelize( + module2, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + instance_name='offload' + ) + optimizer2 = build_optimizer(p_module2, torch.optim.Adam, lr=0.01) + + # First offload to initialize the buffer_shape + p_module2.sleep() + + results_offload = [] + for step, x in enumerate(test_data): + # Load params at the beginning of each step + p_module2.wake_up() + + p_module2.train() + output = p_module2(x) + loss = output.sum() + loss.backward() + optimizer2.step() + optimizer2.zero_grad() + + # Save intermediate results for comparison + results_offload.append({ + 'loss': loss.detach().cpu(), + 'output': output.detach().cpu(), + 'params': {name: param.detach().cpu().clone() for name, param in p_module2.named_parameters()} + }) + + # Offload params at the end of each step + p_module2.sleep() + + torch.distributed.barrier() + + # Compare results + for step in range(num_steps): + normal_result = results_normal[step] + offload_result = results_offload[step] + + # Compare loss + assert torch.equal(normal_result['loss'], offload_result['loss']), \ + f"Loss mismatch at step {step}: {normal_result['loss']} vs {offload_result['loss']}" + + # Compare output + assert torch.equal(normal_result['output'], offload_result['output']), \ + f"Output mismatch at step {step}" + + # Compare parameters + for param_name in normal_result['params']: + assert torch.equal(normal_result['params'][param_name], + offload_result['params'][param_name]), \ + f"Parameter {param_name} mismatch at step {step}" + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_offload_params_mem(): + launch_torchrun(2, _mem_worker) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_offload_params_correctness(): + launch_torchrun(2, _correctness_worker) From 47ea46869c9df5207c1703e6b0ce2a9f5448db31 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 27 Aug 2025 07:56:21 +0000 Subject: [PATCH 1833/1892] Merged PR 2398: [Feature] Add multiple-optimizer/param groups support 1. Classify parameters belonging to different optimizers/param groups into different buckets. 2. Put all states from different optimizers/param groups into one HybridOptimizer to reuse existing merge logic. --- nnscaler/__init__.py | 2 + nnscaler/cli/arg_parser.py | 4 + nnscaler/cli/mixed_module.py | 11 +- nnscaler/cli/train_hook.py | 28 ++ nnscaler/cli/trainer.py | 31 +- nnscaler/cli/trainer_args.py | 3 + nnscaler/codegen/module/module.py | 6 +- nnscaler/parallel.py | 108 +++++-- nnscaler/runtime/adapter/reducer.py | 70 +++- nnscaler/runtime/hybrid_optimizer.py | 299 ++++++++++++++++++ nnscaler/runtime/module.py | 53 +++- nnscaler/utils.py | 29 +- tests/cli/test_arg_parser.py | 22 ++ tests/cli/test_hooks.py | 30 ++ tests/cli/test_trainer.py | 79 ++++- tests/parallel_module/common.py | 4 +- tests/runtime/test_hybrid_optimizer.py | 132 ++++++++ .../test_hybrid_optimizer_trainer_args.yaml | 76 +++++ 18 files changed, 921 insertions(+), 66 deletions(-) create mode 100644 nnscaler/runtime/hybrid_optimizer.py create mode 100644 tests/cli/test_hooks.py create mode 100644 tests/runtime/test_hybrid_optimizer.py create mode 100644 tests/runtime/test_hybrid_optimizer_trainer_args.yaml diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index 2bf5867a..4cf4896a 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -24,6 +24,8 @@ no_constant_folding, fold_constant, ) +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdam, MixedPrecisionAdamW +from nnscaler.runtime.hybrid_optimizer import HybridLRScheduler, HybridOptimizer def init(): diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index f5caab0d..1adf6b9c 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -3,6 +3,7 @@ import os import copy +import logging from typing import List, Optional, Tuple, Dict, Any, Union from dataclasses import dataclass, field, is_dataclass, asdict @@ -16,6 +17,7 @@ except ImportError: UnionType = None # for python < 3.10 +logger = logging.getLogger(__name__) _TYPE_KEY = '__type' _VALUE_TYPE_KEY = '__value_type' @@ -390,6 +392,8 @@ def _deserialize_object(value, value_type): else: raise ValueError(f"Failed to deserialize {value} to {value_type}") if _is_primitive_type(value_type): + if callable(value): + logger.warning(f'{value} is callable, converting to {value_type} may not work as expected.') return value_type(value) except Exception as ex: raise ValueError(f"Failed to deserialize {value} to {value_type}") from ex diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index d7354617..69487fb2 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -197,6 +197,7 @@ def resolve_compute_config(self): def parallelize(self, dummy_input: Optional[dict[str, Any]] = None, *, load_module: bool = True, + build_buckets: bool = True, module_args: Optional[tuple[tuple, dict]] = None ): pmodel_class = nnscaler.parallelize( @@ -212,7 +213,7 @@ def parallelize(self, load_module=load_module, ) if load_module: - return pmodel_class() + return pmodel_class(build_buckets=build_buckets) return pmodel_class @@ -279,7 +280,7 @@ def parameters_for_calc_gnorm(self): return model -def parallelize_model(trainer_args: TrainerArgs, dummy_input: dict[str, Any], load_module: bool): +def parallelize_model(trainer_args: TrainerArgs, dummy_input: dict[str, Any], load_module: bool, build_buckets: bool): tracing_weights = None if trainer_args.tracing_from_weights: tracing_weights = torch.load(trainer_args.tracing_from_weights) @@ -292,11 +293,11 @@ def _new_adapter(parallel_module=None): if not trainer_args.model.parallel_modules: # parallelize the whole model - return _new_adapter().parallelize(dummy_input, load_module=load_module) + return _new_adapter().parallelize(dummy_input, load_module=load_module, build_buckets=build_buckets) if not load_module and all(pm.args is not None for pm in trainer_args.model.parallel_modules): for m in trainer_args.model.parallel_modules: - _new_adapter(m).parallelize(dummy_input, load_module=False) + _new_adapter(m).parallelize(dummy_input, load_module=False, build_buckets=build_buckets) return parallel_sub_modules = { @@ -346,7 +347,7 @@ def __parallel__new__(cls, *args, **kwargs): # This is a trade-off to make sure the parallelized module is consistent. # Maybe we can use torch.distributed.broadcast to sync the random state in all devices. with fork_rng(): - return adapter.parallelize(dummy_input, load_module=load_module, module_args=(args, kwargs)) + return adapter.parallelize(dummy_input, load_module=load_module, build_buckets=build_buckets, module_args=(args, kwargs)) finally: _patch_new() diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index 76abeb8b..7848ae31 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -333,3 +333,31 @@ def after_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: for hook in self.hooks: hook.on_save_checkpoint(trainer, checkpoint) + + +class TrainHookHost: + def _get_hook_objects(self) -> List[Any]: + """ + Return a list of objects that can be hooks (but not necessarily hooks) + """ + ... + + def get_hooks(self) -> List[TrainHook]: + """ + Return a list of TrainHook objects + """ + hooks = {} + visited = set() + def _get_hooks(obj): + if id(obj) in visited: + return + visited.add(id(obj)) + + if isinstance(obj, TrainHook): + hooks[id(obj)] = obj + if isinstance(obj, TrainHookHost): + for o in obj._get_hook_objects(): + _get_hooks(o) + + _get_hooks(self) + return list(hooks.values()) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index d167b95f..c35ad677 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -25,7 +25,7 @@ from nnscaler.utils import enforce_zero_num_worker, is_running_distributed from .trainer_args import AggregatedOutputs, TrainerArgs, fix_input -from .train_hook import AggregatedTrainHook, TrainHook +from .train_hook import AggregatedTrainHook, TrainHook, TrainHookHost from .mixed_module import parallelize_model, mixin_module @@ -159,7 +159,11 @@ def _setup(self): self.dummy_input = self._load_dummy_input() self.dummy_input = self._fix_input(self.dummy_input) - pmodel = parallelize_model(self.train_args, self.dummy_input, load_module=not compile_only) + pmodel = parallelize_model( + self.train_args, self.dummy_input, + load_module=not compile_only, + build_buckets=not self.train_args.is_hybrid_optimizer() + ) if compile_only: return @@ -216,6 +220,7 @@ def _setup(self): def reducer_pre_hook(reducer, grad): grad.div_(self.train_args.scaling_factor) self.optimizer.register_reducer_pre_hook(reducer_pre_hook) + # Currently we never pass `last_epoch` to its constructor self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) self.loggers = self.train_args.create_loggers() @@ -224,8 +229,18 @@ def reducer_pre_hook(reducer, grad): self.optimizer, self.lr_scheduler, ] + component_hooks = [] + for component in supported_hook_components: + if isinstance(component, TrainHook): + component_hooks.append(component) + if isinstance(component, TrainHookHost): + component_hooks.extend(component.get_hooks()) + + # dedup hooks + component_hooks = list({id(hook): hook for hook in component_hooks}.values()) + self.hook = AggregatedTrainHook( - [x for x in supported_hook_components if isinstance(x, TrainHook)] + component_hooks + [self.train_args.create_hook()] ) @@ -252,7 +267,7 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): global_keys = { 'model', 'optimizer', 'train_args', - 'train_status', 'lr_scheduler', 'rank' + 'train_status', 'lr_scheduler', 'rank', 'nnscaler' } # for extra keys (including `dataloader` and `rng_states`), we will not merge them. # Intead we will collect them from all state_dicts @@ -271,6 +286,7 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): 'lr_scheduler': state_dicts[0].get('lr_scheduler', None), 'train_status': state_dicts[0]['train_status'], 'train_args': train_args, + 'nnscaler': state_dicts[0]['nnscaler'], **extra_keys, } return merged_state_dict @@ -389,6 +405,11 @@ def _load_checkpoint(self): else: resume_from = resume_from / f'{self.rank}.ckpt' state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + if state_dict['train_args']['compute_config'] != asdict(self.train_args.compute_config): + logger.warning( + f"compute_config is changed, and loading checkpoint may fail. " + f"If it fails, please try with merged checkpoint." + ) self.hook.on_load_checkpoint(self, state_dict) # if it is not a well-formed state_dict (from third party) @@ -977,7 +998,7 @@ def _train_epoch(self, epoch: int) -> None: step_stat.gnorm = step_stat.gnorm.item() # update parameters - step_stat.lr = self.optimizer.param_groups[0]['lr'] + step_stat.lr = self.optimizer.param_groups[0]['lr'] # only log the first group's lr self.hook.before_optimizer_step(self) self.optimizer.step() self.hook.after_optimizer_step(self) diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 2fb99daa..3381c70d 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -863,6 +863,9 @@ def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model.args) return self.model_type(**kwargs) + def is_hybrid_optimizer(self) -> bool: + return getattr(load_type(self.optimizer.type), 'is_hybrid', False) + def create_parallel_optimizer(self, parallel_model: torch.nn.Module): kwargs = self.create_kwarg(self.optimizer.args) optimizer_class = load_type(self.optimizer.type) diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 32cd75fe..fe1b5130 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -506,15 +506,17 @@ def forward(self, x, y=None, z=None): args=[ 'self', 'init_params=True', - '*', + 'build_buckets=True', + '*args', f'async_op={CompileFlag.async_reducer}', f'max_bucket_size_bytes={CompileFlag.max_reducer_bucket}', f'zero_use_reduce_scatter={CompileFlag.zero_use_reduce_scatter}', + f'**kwargs', ] ) as ib: ib.insert_body(self.model_init_statements) ib.insert_body('') - ib.insert_body('self._post_init(init_params)') + ib.insert_body('self._post_init(init_params, build_buckets)') else: with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.model_init_statements) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index ce5a0d0f..c7b99132 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -49,7 +49,7 @@ from nnscaler.flags import CompileFlag, RuntimeFlag import nnscaler.policies as policies from nnscaler.program import disable_global_graph -from nnscaler.utils import get_member_by_name, setup_stride_broadcast_group, get_shared_params +from nnscaler.utils import get_member_by_name, load_type, setup_stride_broadcast_group, get_shared_params logger = logging.getLogger(__name__) @@ -913,6 +913,7 @@ def parallelize( module_dtype: Optional[torch.dtype] = None, module_fn: Optional[Callable[[], torch.nn.Module]] = None, init_module_params: bool = True, + build_module_buckets: bool = True, broadcast_strategy: Union[str, BroadcastGenFilesStrategy] = 'none', ) -> Union[None, ParallelModule, Type[ParallelModule]]: """ @@ -983,6 +984,12 @@ def __init__(self, init_params=True): Otherwise, they will be empty tensor. This parameter will be passed to the module constructor, so it is only used when module_or_module_class is a module object, and load_module is true. + build_module_buckets (bool): For parallel module, parameters that needs to synchronize will be grouped into buckets for more efficient communication. + If true, grouping process will be done in `__init__` + If false, you should do this by yourself. + This parameter will be passed to the module constructor, + so it is only used when module_or_module_class is a module object, and load_module is true. + Please leave it to true until you have a good reason to change it. module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for generated files. @@ -1116,7 +1123,7 @@ def __init__(self, init_params=True): if is_module_class: return parallel_module_class else: - parallel_module = parallel_module_class(init_module_params) + parallel_module = parallel_module_class(init_module_params, build_module_buckets) parallel_module.train(module_or_module_class.training) # set training state to the same as original module return parallel_module @@ -1245,6 +1252,27 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] OptimizerT = TypeVar('OptimizerT', bound=torch.optim.Optimizer) +HybridOptimizerT = TypeVar('HybridOptimizer', bound=torch.optim.Optimizer) + + +def hybrid( + params: list[torch.nn.Parameter], + param_clss_fn: Callable[[str], tuple[int, int]], + **kwargs, +) -> HybridOptimizerT: + """ + Stub for hybrid optimizer creation. + Signature of Hybrid optimizer constructor: + ``` + def __init__(self, params, param_clss, **kwargs): + ... + ``` + But when you pass arguments to `build_optimizer` + You must replace `param_clss` with `param_clss_fn`, + And `build_optimizer` will automatically replace `param_clss_fn` with the actual `param_clss`. + """ + ... +hybrid.is_hybrid = True # mark this function as hybrid optimizer factory def build_optimizer( @@ -1287,6 +1315,8 @@ def build_optimizer( Please note the type annotation of the returned optimizer (`Union[OptimizerT, ParallelOptimizer]`) is just for intellisense. """ + PARAM_CLSS_FN_NAME = 'param_clss_fn' + if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): raise RuntimeError("Old style CubeModule is not supported") @@ -1294,12 +1324,21 @@ def build_optimizer( if any(m != module and isinstance(m, ParallelModule) and m.compute_config.use_end2end for m in module.modules()): raise RuntimeError("End2End module cannot be nested in another module") + is_hybrid = False + if getattr(optimizer_fn, 'is_hybrid', False): + if PARAM_CLSS_FN_NAME not in kwargs: + raise ValueError("param_clss_fn must be provided when using hybrid optimizer") + # syntax sugar + kwargs[PARAM_CLSS_FN_NAME] = load_type(kwargs[PARAM_CLSS_FN_NAME]) + is_hybrid = True + RuntimeFlag.skip_reducer = True RuntimeFlag.skip_zero_grad = False non_parallel_module_reducer = None non_parallel_modules = [m for m in module.modules() if not isinstance(m, ParallelModule)] parallel_modules = [m for m in module.modules() if isinstance(m, ParallelModule)] + parallel_modules_prefix = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} if not parallel_modules: raise RuntimeError("No ParallelModule found in the module. Please make sure you have called parallelize() before build_optimizer().") @@ -1311,6 +1350,22 @@ def build_optimizer( non_parallel_parameters_dict[param] = None non_parallel_parameters = list(non_parallel_parameters_dict.keys()) + param_original_names = {} + for n, p in module.named_parameters(): + nparts = n.split('.') + module_prefix = '.'.join(nparts[:-1]) + if module_prefix in parallel_modules_prefix: + name_mapping = parallel_modules_prefix[module_prefix].get_full_map() + original_name = name_mapping[nparts[-1]].orig_name + param_original_names[p] = \ + f'{module_prefix}.{original_name}' if module_prefix else original_name + else: + param_original_names[p] = n + if is_hybrid: + param_clss = {p: kwargs[PARAM_CLSS_FN_NAME](n) for p, n in param_original_names.items()} + else: + param_clss = {} + # check if all ParallelModules have the same gpu_config compute_configs = [m.compute_config for m in parallel_modules] for i in range(1, len(compute_configs)): @@ -1340,7 +1395,13 @@ def build_optimizer( non_parallel_module_reducer = Reducer(group, **reducer_config) for param in non_parallel_parameters: non_parallel_module_reducer.add_param(param) - non_parallel_module_reducer.build_buckets() + non_parallel_module_reducer.build_buckets(param_clss=param_clss) + + if is_hybrid: + for pm in parallel_modules: + pm.build_buckets(param_clss=param_clss) + for reducer in pm.reducers: + param_clss.update(reducer.get_opt_params()) opt_module_locs: Dict[str, ModuleParameterLocation] = {} def _local_parameters(module: torch.nn.Module): @@ -1373,7 +1434,13 @@ def _local_parameters(module: torch.nn.Module): opt_module_locs[name].count += 1 yield param - optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), **kwargs) + if is_hybrid: + optimizer = optimizer_fn(_local_parameters(module), + param_clss=param_clss, + **{k: v for k, v in kwargs.items() if k != PARAM_CLSS_FN_NAME} + ) + else: + optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), **kwargs) optimizer._non_parallel_module_reducer = non_parallel_module_reducer optimizer._extra_state = OptimizerExtraState( rank=torch.distributed.get_rank(), @@ -1386,21 +1453,21 @@ def _local_parameters(module: torch.nn.Module): } ) - def _step_pre_hook(opt, *args, **kwargs): - opt.sync_shard_grad() - - def _step_post_hook(opt, *args, **kwargs): + orig_step = optimizer.step + def _patched_step(self, closure=None): + # Please note: + # when closure is used in optimizer.step() + # the backward is done in closure, + # and it is useless to sync grad because grad is still unavailable there + # so you must call sync_shard_grad() manually in this case. + if closure is None: + self.sync_shard_grad() + orig_step(closure=closure) for m in parallel_modules: m.gather_params() if non_parallel_module_reducer: non_parallel_module_reducer.gather_params() - - # Please note: - # register_step_pre_hook doesn't work expectly - # when closure is used in optimizer.step() - # in that case, you must call sync_shard_grad() manually - optimizer.register_step_pre_hook(_step_pre_hook) - optimizer.register_step_post_hook(_step_post_hook) + optimizer.step = types.MethodType(_patched_step, optimizer) orig_zero_grad = optimizer.zero_grad def _patched_zero_grad(self, set_to_none: bool = True): @@ -1575,6 +1642,11 @@ def _get_parallel_module_state_dict_info( return pm_extra_states, pm_state_dicts, non_pm_state_dict +def _is_supported_optimizer(name: str): + from nnscaler.runtime.hybrid_optimizer import HybridOptimizer + return ('adam' in name.lower()) or name == HybridOptimizer.__name__ + + def _get_optimizer_state_dict_info( optimizer_state_dicts: List[Dict[str, Any]] ) -> Tuple[ @@ -1631,7 +1703,7 @@ def _get_optimizer_state_dict_info( ] = {} for opt_state_dict in optimizer_state_dicts: opt_extra_state = OptimizerExtraState(**opt_state_dict[ParallelModule.EXTRA_STATE_KEY]) - if 'adam' not in opt_extra_state.name.lower(): + if not _is_supported_optimizer(opt_extra_state.name): raise ValueError("Only Adam-like optimizers are supported.") opt_extra_states[opt_extra_state.rank] = opt_extra_state @@ -1870,7 +1942,7 @@ def load_merged_state_dict( module.to(device) if optimizer is not None and optimizer_state_dict is not None: - if 'adam' not in optimizer._extra_state.name.lower(): + if not _is_supported_optimizer(optimizer._extra_state.name): raise ValueError("Only Adam-like optimizers are supported.") # handle non-paralleled module parameters @@ -2371,7 +2443,7 @@ def load_deduped_state_dict( broadcast_weights(module) if optimizer is not None: - if 'adam' not in optimizer._extra_state.name.lower(): + if not _is_supported_optimizer(optimizer._extra_state.name): raise ValueError("Only Adam-like optimizers are supported.") if optimizer_state_dict is None: raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index 679cdde9..d28cdf52 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -16,7 +16,7 @@ # According to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#device-memory-accesses -# Any address of a variable residing in global memory or returned by one of the memory allocation +# Any address of a variable residing in global memory or returned by one of the memory allocation # routines from the driver or runtime API is always aligned to at least 256 bytes. # But in our practice, we found that 16 bytes alignment is enough, it can be modified if unaligned access is detected. ALIGNED_BYTES = 16 @@ -68,6 +68,7 @@ def __init__(self, params: List[torch.nn.Parameter], zero_crossgroup: torch.distributed.ProcessGroup = None, zero_use_reduce_scatter: bool = False, align_size: int = ALIGNED_BYTES, + param_cls: Any = None, ): """ Create a communication unit for parameter allreduce. @@ -87,9 +88,11 @@ def __init__(self, params: List[torch.nn.Parameter], zero_crossgroup (torch.distributed.ProcessGroup): the communication group for cross zero group allreduce when reduce scatter is enabled zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization align_size (int): the alignment size in bytes for each parameter + param_cls (Any): the class of the parameters """ self._params: List[torch.nn.Parameter] = params + self._param_cls: Any = param_cls self._pofset: Dict[torch.nn.Parameter, int] = {} self._reduce_op = reduce_op self._group = group @@ -137,6 +140,11 @@ def params(self) -> List[torch.nn.Parameter]: """Parameter list""" return self._params + @property + def param_cls(self) -> Any: + """Class of the parameters in the bucket""" + return self._param_cls + @property def zero(self) -> bool: """Whether enable zero for this bucket""" @@ -257,6 +265,14 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): # grad_acc must keep, otherwise the hook won't take effect self._hooks.append((grad_acc, hook)) + def unregister_hooks(self): + """ + Unregister all post-backward hook to parameters. + """ + for _, hook in self._hooks: + hook.remove() + self._hooks.clear() + def sync_grads(self): """ Wait until allreduce finished (async), or perform allreduce (sync). @@ -422,7 +438,9 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization align_size (int): the alignment size in bytes for each parameter """ + # the parameters with same class will be consecutive in the list. self._params: List[torch.nn.Parameter] = list() + self._param_clss: Dict[torch.nn.Parameter, Any] = dict() # the class of each parameter, used for sorting self._param_ids: Set[int] = set() self._numel: int = 0 self._ranks = ranks @@ -562,15 +580,31 @@ def _bind_params(self): aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) ofst += aligned_nelements - def build_buckets(self): + def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None): """ Build buckets the reducer. - The parameters in each bucket have consistent data types, + The parameters in each bucket have consistent data types and classes, and each bucket contains at least one parameter. If the bucket contains more than 2 parameters, than the total size is samller than the max_bucket_size_bytes. + + You can call this method multiple times to rebuild the buckets. + Typically this will be called when building optimizer when multiple optimizers/param groups are used. + And we will put parameters with different optimizer or different param groups into different buckets. """ + self._param_clss = param_clss or {} + # sort parameters by their class + # which can help bucket building + if self._param_clss: + self._params.sort(key=lambda p: self._param_clss[p]) + for bucket in self._buckets: + # rebuild bucket should be done before any hooks registered. + if bucket._pre_hooks or bucket._post_hooks: + raise RuntimeError("Cannot rebuild buckets while pre/post hooks are registered.") + bucket.unregister_hooks() + self._buckets.clear() + # step 1: build bucket for overlapping gradient synchronization # self._numel * 8 + 1 here is to make sure # the bucket size is larger than the total size of all parameters @@ -582,7 +616,9 @@ def build_buckets(self): # (used in pytorch, with a couple percentage improvement) bucket_size = self._numel * 8 + 1 if not self._bucket_size else self._bucket_size + seq_buckets_cls: List[Any] = [] last_bucket_size = None + last_bucket_cls = None assert len(set(p.dtype for p in self._params)) == 1, ( "All parameters in the reducer should have the same data type" @@ -597,12 +633,17 @@ def build_buckets(self): if len(self.seq_buckets) == 0: self.seq_buckets.append([param]) last_bucket_size = cur_byte_size - elif last_bucket_size + cur_byte_size <= bucket_size: + last_bucket_cls = self._param_clss.get(param, None) + seq_buckets_cls.append(last_bucket_cls) + elif last_bucket_size + cur_byte_size <= bucket_size \ + and last_bucket_cls == self._param_clss.get(param, None): self.seq_buckets[-1].append(param) last_bucket_size += cur_byte_size else: self.seq_buckets.append([param]) last_bucket_size = cur_byte_size + last_bucket_cls = self._param_clss.get(param, None) + seq_buckets_cls.append(last_bucket_cls) # step 2: build meta data for the offset of each bucket # the start of each bucket will be padded to the next multiple of `len(self.ranks)` @@ -624,7 +665,7 @@ def build_buckets(self): # step 5: build buckets buckets: List[Bucket] = [] - for params, start, stop in zip(self.seq_buckets, self.starts, self.stops): + for params, param_cls, start, stop in zip(self.seq_buckets, seq_buckets_cls, self.starts, self.stops): # initialize buckets bucket = Bucket( params, @@ -638,6 +679,7 @@ def build_buckets(self): self._zero_crossgroup, self._zero_use_reduce_scatter, self._align_size, + param_cls=param_cls, ) buckets.append(bucket) torch.cuda.empty_cache() @@ -692,9 +734,23 @@ def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: Returns: List[torch.nn.Parameter]: parameters for optimizer """ - params = [] + return list(self.get_opt_params().keys()) + + def get_opt_params(self) -> dict[torch.nn.Parameter, Any]: + """ + Get parameters and their classes for optimizers + Please note for ZeRO optimization, + the returned parameters are not the same as the original parameters, + and can have paddings (with value 0.0) both at the end and in the middle of paramters data. + + the calculation of gnorm is not affected as paddings are all 0. + + Returns: + List[torch.nn.Parameter]: parameters for optimizer + """ + params = {} for bucket in self._buckets: - params.append(bucket._param_for_optimizer) + params[bucket._param_for_optimizer] = bucket.param_cls return params def broadcast_params(self): diff --git a/nnscaler/runtime/hybrid_optimizer.py b/nnscaler/runtime/hybrid_optimizer.py new file mode 100644 index 00000000..e860f56d --- /dev/null +++ b/nnscaler/runtime/hybrid_optimizer.py @@ -0,0 +1,299 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, Iterable, Type, Union + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.hooks import RemovableHandle + +from nnscaler.cli.arg_parser import deserialize_dataclass +from nnscaler.cli.train_hook import TrainHookHost, TrainHook +from nnscaler.utils import fn_field, OptStateDict + + +@dataclass +class HybridSubOptParamGroupConfig: + options: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HybridSubOptConfig: + type: Union[Type[Optimizer], Callable[..., Optimizer]] = fn_field(default=None) + options: dict[str, Any] = field(default_factory=dict) + param_groups: list[HybridSubOptParamGroupConfig] = field(default_factory=list) + + def __post_init__(self): + if not self.type: + raise ValueError("Optimizer type must be specified in HybridSubOptConfig") + + +@dataclass +class HybridOptConfig: + optimizers: list[HybridSubOptConfig] = field(default_factory=list) + + def __post_init__(self): + if not self.optimizers: + raise ValueError("At least one optimizer must be specified in HybridOptConfig") + + +class HybridRemovableHandle: + def __init__(self, removable_handles: list[RemovableHandle]): + self.removable_handles = removable_handles + + def remove(self): + for removable_handle in self.removable_handles: + removable_handle.remove() + + def __enter__(self) -> "HybridRemovableHandle": + return self + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + self.remove() + + +class HybridOptimizer(torch.optim.Optimizer, TrainHookHost): + """ + A hybrid optimizer that combines multiple optimizers/multiple param groups + into a single optimizer. + + Please note HybridOptimizer doesn't call super().__init__(), + So it is actually a duck type for optimizer. + """ + + # Identifier for hybrid optimizer + is_hybrid = True + + def __init__( + self, + params: Iterable[torch.nn.Parameter], + param_clss: dict[torch.nn.Parameter, tuple[int, int]], + config: Union[HybridOptConfig, dict[str, Any]] + ): + """ + Initialize the hybrid optimizer. + + Args: + params (Iterable[torch.nn.Parameter]): The parameters to optimize. + param_clss (dict[torch.nn.Parameter, tuple[int, int]]): The parameter classes for each parameter. + Please replace this argument with `param_clss_fn` (Callable[[str], tuple[int, int]]) + when you use creating it with `nnscaler.build_optimizer` (including cli trainer). + config (Union[HybridOptConfig, dict[str, Any]]): The configuration for the hybrid optimizer. + """ + params = list(params) + if isinstance(config, dict): + config = deserialize_dataclass(config, HybridOptConfig) + self.config = config + + self.optimizers = [] + classified_params = defaultdict(list) + # map from (optimizer_idx, pg_idx, param_pg_idx) to param global param index + param_loc = {} + + for idx, param in enumerate(params): + param_cls = param_clss[param] + assert param_cls[0] < len(self.config.optimizers) + classified_params[param_cls].append(param) + + loc = *param_cls, len(classified_params[param_cls]) - 1 + param_loc[loc] = idx + + # sort with key i.e. (optimizer idx, param group idx) + classified_params = dict(sorted(classified_params.items())) + + quick_param_groups = {param_cls: {"params": params} for param_cls, params in classified_params.items()} + opt_param_groups = defaultdict(dict) + for param_cls, group in quick_param_groups.items(): + opt_param_groups[param_cls[0]][param_cls[1]] = group + + for idx, opt_config in enumerate(config.optimizers): + param_groups = opt_param_groups[idx] + if len(param_groups) > 1: + if len(param_groups) != len(opt_config.param_groups): + raise ValueError(f"Expected {len(opt_config.param_groups)} param groups, got {len(param_groups)}") + # param group indices must be consecutive. + if max(param_groups.keys()) != len(opt_config.param_groups) - 1: + raise ValueError(f"Param group indices must be consecutive. We have {len(opt_config.param_groups)} groups, got max group id {max(param_groups.keys())}") + for param_group_idx, param_group in param_groups.items(): + param_group.update(opt_config.param_groups[param_group_idx].options) + else: + if len(opt_config.param_groups) > 1: + raise ValueError(f"Expected at most 1 param group, got {len(opt_config.param_groups)}") + if opt_config.param_groups: + param_groups[0].update(opt_config.param_groups[0].options) + optimizer = opt_config.type(param_groups.values(), **opt_config.options) + self.optimizers.append(optimizer) + + # map from param global index to (optimizer_idx, param_idx) + self._param_map: dict[int, tuple[int, int]] = {} + # map from (optimizer_idx, param_idx) to param global idx + self._reverse_param_map: dict[tuple[int, int], int] = {} + for opt_idx, optimizer in enumerate(self.optimizers): + state_dict: OptStateDict = optimizer.state_dict() + for pg_idx, pg in enumerate(state_dict['param_groups']): + for param_idx_in_pg, param_idx in enumerate(pg['params']): + # param_idx_in_pg is the index in this param group + # param_idx is the index in this optimizer + global_idx = param_loc[(opt_idx, pg_idx, param_idx_in_pg)] + self._param_map[global_idx] = (opt_idx, param_idx) + self._reverse_param_map[(opt_idx, param_idx)] = global_idx + + # Don't call base init + # So HybridOptimizer is a duck optimizer + # super().__init__(params, {}) + + # simulated param groups + self.param_groups = [] + for optimizer in self.optimizers: + self.param_groups.extend(optimizer.param_groups) + + def _get_hook_objects(self): + return self.optimizers + + def step(self, closure=None): + """ + Perform a single optimization step. + """ + assert closure is None, "Closure is not supported in HybridOptimizer" + for optimizer in self.optimizers: + optimizer.step(closure) + + def zero_grad(self, set_to_none: bool = False): + """ + Zero the gradients of all optimizers. + """ + for optimizer in self.optimizers: + optimizer.zero_grad(set_to_none=set_to_none) + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + " [\n" + format_string += ",\n".join(f"{repr(opt)}" for opt in self.optimizers) + format_string += "\n]" + return format_string + + def register_step_pre_hook(self, hook) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_step_pre_hook(hook) for opt in self.optimizers]) + + def register_step_post_hook(self, hook) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_step_post_hook(hook) for opt in self.optimizers]) + + def register_state_dict_pre_hook( + self, hook, prepend: bool = False + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_state_dict_pre_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def register_state_dict_post_hook( + self, + hook, + prepend: bool = False, + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_state_dict_post_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def state_dict(self): + state_dicts: list[OptStateDict] = [opt.state_dict() for opt in self.optimizers] + merged_state_dict: OptStateDict = {'state': {}, 'param_groups': [{'children': {}}]} + + for opt_idx, sd in enumerate(state_dicts): + for param_idx, s in sd['state'].items(): + merged_state_dict['state'][self._reverse_param_map[(opt_idx, param_idx)]] = s + merged_state_dict['param_groups'][0]['children'][opt_idx] = sd['param_groups'] + + merged_state_dict['param_groups'][0]['params'] = list(range(len(self._param_map))) + merged_state_dict['param_groups'][0]['param_map'] = self._param_map + merged_state_dict['param_groups'][0]['reverse_param_map'] = self._reverse_param_map + merged_state_dict['state'] = dict(sorted(merged_state_dict['state'].items())) + + return merged_state_dict + + def register_load_state_dict_pre_hook( + self, + hook, + prepend: bool = False, + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_load_state_dict_pre_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def register_load_state_dict_post_hook( + self, hook, prepend: bool = False + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_load_state_dict_post_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def load_state_dict(self, state_dict) -> None: + child_state_dicts = [{'state': {}, 'param_groups': []} for _ in self.optimizers] + + for idx, sd in enumerate(child_state_dicts): + # copy param groups from state dict + sd['param_groups'] = state_dict['param_groups'][0]['children'][idx] + if len(sd['param_groups']) != len(self.optimizers[idx].param_groups): + raise ValueError(f"Number of param groups mismatch. Expected {len(self.optimizers[idx].param_groups)} got {len(sd['param_groups'])}") + # param groups can be changed (for example, the compute config is changed) + # state_dict for HybridOptimizer is already well organized, + # here we will carefully dispatch parameters to each optimizer. + current_state_dict = self.optimizers[idx].state_dict() + for pg, current_pg in zip(sd['param_groups'], current_state_dict['param_groups']): + pg['params'] = current_pg['params'][:] # make a copy + + for param_idx, param_state in state_dict['state'].items(): + opt_idx, param_state_idx = self._param_map[param_idx] + child_state_dicts[opt_idx]['state'][param_state_idx] = param_state + + for child_state_dict, opt in zip(child_state_dicts, self.optimizers): + opt.load_state_dict(child_state_dict) + + def add_param_group(self, param_group: dict[str, Any]) -> None: + # no-op to avoid creating new parameter groups + # all parameter groups are managed by the individual optimizers + pass + + +@dataclass +class HybridSubLRSchedulerConfig: + type: Union[Type[LRScheduler], Callable[..., LRScheduler]] = fn_field(default=None) + options: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HybridLRSchedulerConfig: + schedulers: list[HybridSubLRSchedulerConfig] = field(default_factory=list) + + +class HybridLRScheduler(LRScheduler, TrainHookHost): + """ + A hybrid learning rate scheduler that combines multiple schedulers. + + Please note HybridLRScheduler doesn't call super().__init__(), + So it is actually a duck type for scheduler. + """ + + def __init__( + self, + optimizer: HybridOptimizer, + config: Union[HybridLRSchedulerConfig, dict[str, Any]], + last_epoch: int = -1, + ): + assert isinstance(optimizer, HybridOptimizer), "Optimizer must be an instance of HybridOptimizer" + if isinstance(config, dict): + config = deserialize_dataclass(config, HybridLRSchedulerConfig) + + if len(config.schedulers) == 1: + self.schedulers = [config.schedulers[0].type(optimizer, **config.schedulers[0].options)] + elif len(config.schedulers) == len(optimizer.optimizers): + self.schedulers = [sub_config.type(opt, **sub_config.options) for sub_config, opt in zip(config.schedulers, optimizer.optimizers)] + else: + raise ValueError(f"Expected {len(optimizer.optimizers)} or 1 schedulers, got {len(config.schedulers)}") + + def _get_hook_objects(self): + return self.schedulers + + def step(self, epoch=None): + for scheduler in self.schedulers: + scheduler.step(epoch) + + def state_dict(self): + return {idx: scheduler.state_dict() for idx, scheduler in enumerate(self.schedulers)} + + def load_state_dict(self, state_dict): + for idx, sd in state_dict.items(): + self.schedulers[idx].load_state_dict(sd) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index d892336f..6fa3145d 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -213,14 +213,29 @@ def zero_grad(self): def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: """Get parameter list for optimizer""" - params = [] + return list(self.get_opt_params().keys()) + + def get_opt_params(self, prefix='', classify_param_cls_fn: Callable[[str], Any]=None) -> dict[torch.nn.Parameter, Any]: + """ + Get all parameters and their classifications + + Args: + prefix (str): The prefix of this module, + which will be used to generate full names of parameters and further classify them. + classify_param_cls_fn (Callable[[str], Any], optional): A function to classify parameters by name. + + Returns: + dict[torch.nn.Parameter, Any]: A dictionary mapping parameters to their classifications. + + """ + params = {} reducer_pids = set() for reducer in self._reducers: - params += reducer.parameters_for_optimizer() + params.update(reducer.get_opt_params()) reducer_pids.update(id(p) for p in reducer.params) - for param in self.parameters(): + for name, param in self.named_parameters(prefix): if id(param) not in reducer_pids: - params.append(param) + params[param] = classify_param_cls_fn(name) if classify_param_cls_fn else None # print(f'> get out parameters: {sum(p.numel() for p in params)}') return params @@ -927,12 +942,14 @@ def _warn_uninitialized_non_persistent_buffers(self, raise_error = False): else: _logger.warning(_non_persistent_buffers_load_warning) - def _post_init(self, init_params=True): + def _post_init(self, init_params=True, build_buckets=True): """ This is post init function to further initialize the model. Should be called by subclass's __init__(). Args: init_params (bool): whether to load model init parameters. Default True. + build_buckets (bool): whether to build buckets for the model. Default True. + If it is False, you must manually call `build_buckets()` later before use this module. """ # Here we check the rank to load the module file name # Current we don't check rank when we are not in distributed mode @@ -957,11 +974,6 @@ def _post_init(self, init_params=True): ) self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}"), weights_only=False) - for reducer in self.reducers: - reducer.build_buckets() - - self._zero_metadata = self._get_zero_metadata() - # add state_dict hook to save extra state # Please note extra_state is only used for merging, not for loading # so we can safely remove it in load_state_dict pre hook @@ -969,6 +981,27 @@ def _post_init(self, init_params=True): # add load_state_dict pre hook to pop extra state to prevent warning self._register_load_state_dict_pre_hook(ParallelModule._pre_load_state_dict_hook, with_module=True) + if build_buckets: + self.build_buckets() + + def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None): + """ + Build buckets for the model reducers. + + You can call this method multiple times to rebuild the buckets. + Typically this will be called when building optimizer when multiple optimizers/param groups are used. + And we will put parameters with different optimizer or different param groups into different buckets. + + Currently we have done an optimization to make sure this is only called once even for hybrid optimizers + by + 1. setting `build_buckets=False` when calling constructor in `nnscaler.parallelize`. + 2. manually calling `build_buckets()` later in `nnscaler.build_optimizer` + """ + for reducer in self.reducers: + reducer.build_buckets(param_clss) + + self._zero_metadata = self._get_zero_metadata() + def forward(self, *args, **kwargs): self._warn_uninitialized_non_persistent_buffers(raise_error=True) if self.training: diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 310c6649..21b8d3f1 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -4,16 +4,16 @@ import builtins import importlib from contextlib import contextmanager -from functools import wraps +from functools import wraps, cache from typing import ( Generator, Optional, Tuple, Callable, Dict, List, Set, Any, - Iterable, Type, Union, Protocol, ClassVar, cast, TypeVar + Iterable, Type, TypedDict, Union, Protocol, ClassVar, cast, TypeVar ) import logging from pathlib import Path import sys from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field import inspect import os @@ -325,6 +325,7 @@ def fields(model: TDataClass, /) -> TDataClass: return cast(TDataClass, _GetFields(model)) +@cache def load_type(type_name: str): """ Load function/class from its full qualified name @@ -457,3 +458,25 @@ def steps(nsteps: int): RuntimeFlag.skip_reducer = (not (step == nsteps - 1)) yield step RuntimeFlag.skip_zero_grad, RuntimeFlag.skip_reducer = old + + +class AdamOptState(TypedDict): + step: int + exp_avg: torch.Tensor + exp_avg_sq: torch.Tensor + + +class OptStateParamGroup(TypedDict): + params: list[int] + lr: int + + +class OptStateDict(TypedDict): + state: dict[int, AdamOptState | dict[str, Any]] + param_groups: list[OptStateParamGroup | dict[str, Any]] + + +def fn_field(**kwargs): + metadata = kwargs.pop('metadata', {}) + metadata['deserialize'] = lambda t: None if t is None else load_type(t) + return field(**kwargs, metadata=metadata) diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 427c8f26..dffb0813 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -209,6 +209,28 @@ class A: assert y.p == {'value': 'auto'} +def test_merge_dict(): + a = { + 'compute_config': { + 'plan_ngpus': 1 + }, + 'optimizer': { + 'type': 'torch.nn.Adam', + 'args': { + 'lr': 0.001 + } + } + } + merge_args(a, ['--optimizer', { + 'type': 'torch.nn.AdamW', + 'args': { + 'hello': 'haha' + } + }]) + assert a['optimizer']['args']['lr'] == 0.001 + assert a['optimizer']['args']['hello'] == 'haha' + + def test_merge_list(): @dataclass class A: diff --git a/tests/cli/test_hooks.py b/tests/cli/test_hooks.py new file mode 100644 index 00000000..102a48bc --- /dev/null +++ b/tests/cli/test_hooks.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, List + +from nnscaler.cli.train_hook import TrainHook, TrainHookHost + + +class A(TrainHook): + pass + +class B(TrainHook): + pass + +class C(TrainHook, TrainHookHost): + def _get_hook_objects(self) -> List[Any]: + return [A(), B(), self] + + +class D(TrainHookHost): + def _get_hook_objects(self) -> List[Any]: + return [self, A(), C()] + +def test_hook(): + hooks = D().get_hooks() + assert len(hooks) == 4 + assert isinstance(hooks[0], A) + assert isinstance(hooks[1], C) + assert isinstance(hooks[2], A) + assert isinstance(hooks[3], B) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 2dc1d252..0ae31da9 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -614,7 +614,7 @@ def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): torch.distributed.barrier() -def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): +def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False, hybrid_opt=False): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) gen_savedir = save_dir / 'gen' @@ -680,6 +680,47 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): else: raise ValueError(f'parallel_type {parallel_type} is not supported') + + def param_clss_fn(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'mlp0.' in param_name: + return 0, 0 + elif 'mlp1.' in param_name: + return 0, 1 + else: + return 1, 0 + + optimizer_config = { + 'type': 'nnscaler.HybridOptimizer', + 'args': { + 'param_clss_fn': param_clss_fn, + 'config': { + 'optimizers':[ + { + 'type': torch.optim.Adam, + 'options': { + 'lr': 0.01, + }, + 'param_groups': [ + {}, + {} + ], + },{ + 'type': torch.optim.Adam, + 'options': { + 'lr': 0.01 + } + } + ] + } + } + } + + if hybrid_opt: + additional_args.extend(['--optimizer!', '--optimizer', optimizer_config]) + # train 4 epcho in one time trainer = Trainer([ '-f', config_path, @@ -716,9 +757,10 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): def trainer_correctness_worker_aggregate(tmp_path): for parallel_type in range(5): for async_reducer in [False, True]: - print(f'parallel_type={parallel_type}, async_reducer={async_reducer}') - save_dir = tmp_path/f'{parallel_type}-{async_reducer}' - trainer_correctness_worker(save_dir, parallel_type, async_reducer) + for hybrid_opt in [True, False]: + print(f'parallel_type={parallel_type}, async_reducer={async_reducer}, hybrid_opt={hybrid_opt}') + save_dir = tmp_path/f'{parallel_type}-{async_reducer}-{hybrid_opt}' + trainer_correctness_worker(save_dir, parallel_type, async_reducer, hybrid_opt) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') @@ -727,19 +769,28 @@ def test_trainer_correctness(tmp_path): merged_ckpts = {} for parallel_type in range(5): for async_reducer in [False, True]: - save_dir = tmp_path/f'{parallel_type}-{async_reducer}' - merged_ckpts[(parallel_type, async_reducer)] = torch.load(save_dir/'merged.pt') + for hybrid_opt in [True, False]: + save_dir = tmp_path/f'{parallel_type}-{async_reducer}-{hybrid_opt}' + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)] = torch.load(save_dir/'merged.pt') for parallel_type in range(5): for async_reducer in [False, True]: - assert_equal( - merged_ckpts[(parallel_type, async_reducer)]['model'], - merged_ckpts[(0, False)]['model'] - ) - assert_equal( - merged_ckpts[(parallel_type, async_reducer)]['optimizer'], - merged_ckpts[(0, False)]['optimizer'] - ) + for hybrid_opt in [True, False]: + assert_equal( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['model'], + merged_ckpts[(0, False, False)]['model'] + ) + if not hybrid_opt: + assert_equal( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['optimizer'], + merged_ckpts[(0, False, False)]['optimizer'] + ) + else: + # param_groups are different when using hybrid optimizer. + assert_equal( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['optimizer']['state'], + merged_ckpts[(0, False, False)]['optimizer']['state'] + ) def tracing_from_weights_worker(tmp_path): diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 0cfac768..aa177f8f 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -117,7 +117,7 @@ def init_distributed(): def assert_equal(a: Any, b: Any): assert type(a) == type(b) if isinstance(a, torch.Tensor): - assert torch.equal(a.cpu(), b.cpu()) + assert torch.equal(a.cpu(), b.cpu()), torch.max(torch.abs(a.cpu() - b.cpu())) elif isinstance(a, dict): assert len(a) == len(b) for k in a.keys(): @@ -127,7 +127,7 @@ def assert_equal(a: Any, b: Any): for i in range(len(a)): assert_equal(a[i], b[i]) else: - assert a == b + assert a == b, f"Values are not equal: {a} != {b}" def assert_close(a: Any, b: Any, atol=1e-6, rtol=1e-6): diff --git a/tests/runtime/test_hybrid_optimizer.py b/tests/runtime/test_hybrid_optimizer.py new file mode 100644 index 00000000..e22d8f72 --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path +import shutil + +import torch +import pytest +import torch.distributed + +from nnscaler.cli.trainer import Trainer +from tests.parallel_module.common import assert_close, assert_equal +from ..launch_torchrun import launch_torchrun + + +def param_clss_fn(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'layers.1.' in param_name or 'layers.10.' in param_name: + return 0, 0 + elif 'layers.2.' in param_name or 'layers.12.' in param_name: + return 0, 1 + else: + return 1, 0 + +_lr_history = [] +def on_train_step_start(trainer: 'Trainer', batches) -> None: + _lr_history.append(( + trainer.optimizer.optimizers[0].param_groups[0]['lr'], + trainer.optimizer.optimizers[0].param_groups[1]['lr'], + trainer.optimizer.optimizers[1].param_groups[0]['lr'], + )) + + +def trainer_worker(save_dir, use_zero): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('test_hybrid_optimizer_trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + _lr_history.clear() + + # train with a resume + ckpt0_savedir = save_dir / 'ckpt0' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + assert len(_lr_history) == 10 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--checkpoint.resume_from', 'last', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + + assert len(_lr_history) == 20 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + _lr_history.clear() + # train in one time + ckpt1_savedir = save_dir / 'ckpt1' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + assert len(_lr_history) == 20 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + + trainer = Trainer([ + '-f', config_path, + '--compute_config.plan_ngpus', '2', + '--pas_policy', 'tp', + '--max_train_steps', '30', + '--checkpoint.resume_from.checkpoint', 'last', + '--checkpoint.resume_from.with_merged', True, + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(not use_zero), + ]) + trainer.run() + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + r = trainer._merge_checkpoint([ckpt0_savedir / 'last' / f'{i}.ckpt' for i in range(2)]) + # should success + assert r + + torch.distributed.barrier() + + # trainer = Trainer([ + # '-f', config_path, + # '--compute_config.plan_ngpus', '1', + # '--max_train_steps', '40', + # '--checkpoint.resume_from.checkpoint', 'last', + # '--checkpoint.resume_from.with_merged', True, + # '--gen_savedir', str(gen_savedir), + # '--checkpoint.save_dir', str(ckpt0_savedir), + # ]) + # trainer.run() + # torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('use_zero', [True, False]) +def test_hybrid_optimizer(tmp_path, use_zero): + launch_torchrun(2, trainer_worker, tmp_path, use_zero) diff --git a/tests/runtime/test_hybrid_optimizer_trainer_args.yaml b/tests/runtime/test_hybrid_optimizer_trainer_args.yaml new file mode 100644 index 00000000..1484fe55 --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer_trainer_args.yaml @@ -0,0 +1,76 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_train_steps: 10 +enable_progress_bar: false +precision: bf16 +instance_name: p$(compute_config.plan_ngpus) + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: nnscaler.HybridOptimizer + args: + param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn + config: + optimizers: + - type: torch.optim.Adam + options: + lr: 0.02 + param_groups: + - options: + lr: 0.04 + - options: + lr: 0.06 + - type: torch.optim.AdamW + options: + lr: 0.04 + +lr_scheduler: + type: nnscaler.HybridLRScheduler + args: + config: + schedulers: + - type: torch.optim.lr_scheduler.ConstantLR + options: + factor: 0.5 + total_iters: 5 + - type: torch.optim.lr_scheduler.ConstantLR + options: + factor: 0.2 + total_iters: 5 + interval: step + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped + +hook: + on_train_step_start: tests.runtime.test_hybrid_optimizer.on_train_step_start From bf83275c856b71101cc934b4e114ce0fae6fa6c5 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 1 Sep 2025 03:48:57 +0000 Subject: [PATCH 1834/1892] Merged PR 2399: [Runtime] ParallelModule: move more instance member variable to class variable. --- nnscaler/codegen/module/module.py | 47 ++++++++--- nnscaler/parallel.py | 55 +++++++----- nnscaler/runtime/module.py | 108 ++++++++++++++++-------- tests/compiler/test_compile.py | 3 +- tests/parallel_module/test_broadcast.py | 4 +- 5 files changed, 148 insertions(+), 69 deletions(-) diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index fe1b5130..907323a9 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -8,6 +8,7 @@ import torch import numpy as np import inspect +import pickle from nnscaler.ir.cten import IRCell from nnscaler.ir.tensor import IRFullTensor, IRSubTensor @@ -317,7 +318,8 @@ def gen( *, as_parallel_module: bool = False, end2end_mode: bool = False, - forward_args: Optional[Dict[str, Any]] = None + forward_args: Optional[Dict[str, Any]] = None, + outfile_attr_meta_map: Optional[str] = None, ) -> str: """ Generate model implementation code based on the given graph. @@ -406,6 +408,7 @@ def forward(self, x, y=None, z=None): This is used only in parallel module. forward_args (Dict[str, Any]): argument names and their default values of forward function, if None, use node inputs. This is used only in parallel module. + outfile_attr_meta_map (str): output file path for parameter mapping. None if don't save Returns: generated code @@ -451,6 +454,7 @@ def forward(self, x, y=None, z=None): if k not in param_first_used_pos: param_first_used_pos[k] = (i, v) + attr_meta_map = {} # emit code for node in sequence: if isinstance(node, IRSegment): @@ -472,7 +476,7 @@ def forward(self, x, y=None, z=None): # emit node tensor declaration into `__init__` # typically it's about the `nn.Parameter` - self.init_attributes(node) + attr_meta_map.update(self.init_attributes(node)) # emit node code # codes : List[str] @@ -488,6 +492,10 @@ def forward(self, x, y=None, z=None): args.append(self.tensor_name(t)) node_args.append(args) + if outfile_attr_meta_map: + with open(outfile_attr_meta_map, 'wb') as f: + pickle.dump(attr_meta_map, f) + # generate full code with ClassBlock( class_name='GenModel', @@ -499,6 +507,7 @@ def forward(self, x, y=None, z=None): if as_parallel_module: cb.insert_body(f'rank = {device}') # save rank in class level + cb.insert_body(f'world_size = {self.runtime_ndevs}') # save world size in class level # async_op, max_bucket_size_bytes and zero_use_reduce_scatter # parameters are for testing purpose # and will not expose to user @@ -653,7 +662,7 @@ def emit_comm_groups(self): self.model_init_statements.append(code) self.model_init_statements.append(' ') - def init_attributes(self, node: IRCell): + def init_attributes(self, node: IRCell) -> dict[str, dict[str, Any]]: """ Emit tensor declaration code @@ -662,10 +671,18 @@ def init_attributes(self, node: IRCell): This method also populates `self.symbols : SymbolTable` to record the names of the variables for the tensors ever encountered. + + Returns: + dict[str, dict[str, Any]]: A mapping of tensor names to their attributes. """ + attr_meta_map = {} + self._init_attributes(node, attr_meta_map) + return attr_meta_map + + def _init_attributes(self, node: IRCell, attr_meta_map: Dict[str, Any]): psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}), persistent={persistent})" - map_sign = "self.add_full_map('{attr}', {tid}, {is_param}, '{orig_name}', {full_shape}, {slicers}, {val_chunks})" + map_sign = "self.add_full_map('{attr}', {tid}, {is_param}, '{orig_name}', {shape}, {slicers}, {val_chunks})" if not isinstance(node, IRSegment): for itensor in node.inputs(): name = self.tensor_name(itensor, prefix_attr='self.') @@ -693,14 +710,24 @@ def init_attributes(self, node: IRCell): assert len(slicers) == 1 and slicers[0] == slice(0, 1), f"Unexpected slicers {slicers} for scalar tensor." slicers = '...' # Ellipsis slicer for scalar tensor, x[...] is equivalent to x val_chunks = itensor.valmap[1] - code = map_sign.format( - attr=self.tensor_name(itensor), + attr_name = self.tensor_name(itensor) + attr_props = dict( tid=itensor.parent.tid, is_param=itensor.is_param(), orig_name=itensor.parent.name, - full_shape=tuple(itensor.parent.origin_shape), - slicers=str(slicers), - val_chunks=val_chunks + shape=tuple(itensor.parent.origin_shape), # full tensor shape + slicers=slicers, + val_chunks=val_chunks, + ) + attr_meta_map[attr_name] = dict( + **attr_props, + dtype=itensor.dtype, + sub_shape=tuple(itensor.shape) + ) + + code = map_sign.format( + attr=attr_name, + **attr_props ) self.model_init_statements.append(code) self.model_init_statements.append('') @@ -712,7 +739,7 @@ def init_attributes(self, node: IRCell): self.symbols.create(self.tensor_name(output, prefix_attr='self.')) else: for sub_node in node.nodes(): - self.init_attributes(sub_node) + self._init_attributes(sub_node, attr_meta_map) return def init_reducer(self, diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index c7b99132..1bc96c12 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -557,6 +557,10 @@ def _prepare_and_check_reusable( if reuse == ReuseType.MATCH or reuse == ReuseType.MOO: # check if the module is already generated expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] + expected_output_files.extend([ + outdir / ParallelModule.ATTR_META_FILE_TEMPLATE.format(rank) + for rank in range(compute_config.runtime_ngpus) + ]) expected_output_files.extend(trace_meta_files) expected_output_files.append(config_file) expected_output_files.append(outdir / _GRAPH_DUMP_FILE) @@ -840,12 +844,14 @@ def _gencode( sgener = ScheduleCodeGen(execplan, compute_config.runtime_ngpus) for rank in range(compute_config.runtime_ngpus): fname = outdir / _GENCODE_FILE_TEMPLATE.format(rank) + attr_meta_map_fname = outdir / ParallelModule.ATTR_META_FILE_TEMPLATE.format(rank) mgener.gen(rank, forward_args=forward_args, outfile=fname, attach=False, as_parallel_module=True, - end2end_mode=compute_config.use_end2end + end2end_mode=compute_config.use_end2end, + outfile_attr_meta_map=attr_meta_map_fname ) # generate temporal schedule code only for end2end module # because the code generated is wrong for non-end2end module. @@ -1979,7 +1985,7 @@ def load_merged_state_dict( else: # NNPN<[P]PP>N: the current parallel module # parallel module - pm_param_count = len(pm_modules[pm_cur]._orign_module_metadata.origin_param_names) + pm_param_count = len(pm_modules[pm_cur].origin_module_metadata.origin_param_names) # will map `pm_param_count` parameters in merge state dict # to `pm_locs[pm_cur].count` in optimizer state. cur_states = {} @@ -2011,7 +2017,7 @@ def _opt_load_merged_state_dict(module: ParallelModule, states: Dict[int, Dict[s # orig_name -> state orig_param_dict: Dict[str, Dict[str, Any]] = {} cnt = 0 - origin_param_names = module._orign_module_metadata.origin_param_names + origin_param_names = module.origin_module_metadata.origin_param_names for name in origin_param_names: if cnt in states: # some parameters may not in the sates when it is not used or requires_grad is False in training orig_param_dict[name] = states[cnt] @@ -2275,11 +2281,21 @@ def _broadcast_gen_files( torch.distributed.barrier() -def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Dict[str, Any]: +def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Tuple[ + Dict[int, Dict[str, Dict[str, AttrMeta]]], + int, + Dict[int, Dict[str, Dict[str, AttrMeta]]] +]: """ A helper function that computes the deduplicated attribute information from all ranks. Note that this function may be removed in the future and dedup information are computed directly at the compilation stage. + + Returns: + A tuple containing: + - rank2deduped_fullmap: a mapping from rank id to deduplicated attribute information + - dedup_group_size: the size of the deduplication group + - global_fullmaps: a mapping from rank id to full attribute information """ dedup_group_size = None for prefix, parallel_module in parallel_modules.items(): @@ -2291,27 +2307,23 @@ def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Dict[str dedup_group_size = dedup_group_size or 1 world_size = torch.distributed.get_world_size() - local_fullmaps = {prefix: m.fullmap for prefix, m in parallel_modules.items()} - local_tensor_meta = dict() - for prefix, m in parallel_modules.items(): - module_meta = {} - for local_name in m.fullmap.keys(): - assert hasattr(m, local_name), f'Module {prefix} does not have attribute {local_name}' - tensor = getattr(m, local_name) - module_meta[local_name] = (tuple(tensor.shape), tensor.dtype) - local_tensor_meta[prefix] = module_meta - global_fullmaps = [None for _ in range(world_size)] - torch.distributed.all_gather_object(global_fullmaps, local_fullmaps) + global_fullmaps: Dict[ + int, # rank id + Dict[str, # submodule prefix + Dict[str, # attribute name in parallel module + AttrMeta]] + ] = {} + for rank in range(world_size): + global_fullmaps[rank] = {} + for prefix, m in parallel_modules.items(): + global_fullmaps[rank][prefix] = m.get_attr_meta_map(rank) # `dedup_attrs` is a deterministic algorithm, so it produces same results across different ranks - rank2deduped_fullmap = dedup_attrs(OrderedDict(list(enumerate(global_fullmaps)))) - global_tensor_meta = [None for _ in range(world_size)] - torch.distributed.all_gather_object(global_tensor_meta, local_tensor_meta) - global_tensor_meta = OrderedDict(list(enumerate(global_tensor_meta))) + rank2deduped_fullmap = dedup_attrs(global_fullmaps) for rank in range(dedup_group_size, world_size): assert len(rank2deduped_fullmap[rank]) == 0, f'Rank {rank} has non-empty deduped_fullmap: {rank2deduped_fullmap[rank]}' - return rank2deduped_fullmap, dedup_group_size, global_tensor_meta + return rank2deduped_fullmap, dedup_group_size, global_fullmaps @torch.no_grad() @@ -2423,7 +2435,8 @@ def load_deduped_state_dict( key = f'{prefix}.{local_name}' if prefix else local_name assert prefix in parallel_modules, f'prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}' pm = parallel_modules[prefix] - shape, dtype = global_tensor_meta[rank][prefix][local_name] + attr_meta = global_tensor_meta[rank][prefix][local_name] + shape, dtype = attr_meta.sub_shape, attr_meta.dtype if rank == cur_rank: assert hasattr(pm, local_name), f'local_name {local_name} not found in {pm}' broadcast_tensor = getattr(pm, local_name) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 6fa3145d..4745da70 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union +import pickle +from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union, ClassVar from typing_extensions import Self import logging import os @@ -26,7 +27,7 @@ from nnscaler import __version__ as runtime_version from nnscaler.flags import CompileFlag -from nnscaler.utils import accum_mode +from nnscaler.utils import accum_mode, classproperty if TYPE_CHECKING: from nnscaler.parallel import ComputeConfig @@ -50,6 +51,11 @@ class AttrMeta: # the number of the partitioned values, usually 1 # (i.e., no partition on value -> no need to sum up) val_chunks: int + # data type of the full tensor and sub tensor + dtype: torch.dtype + # shape of the sub tensor + # it should be the shape of full_tensor[slicers] + sub_shape: Tuple[int, ...] def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, Dict[str, AttrMeta]]]) -> Dict[int, Dict[str, Dict[str, AttrMeta]]]: @@ -313,7 +319,8 @@ def add_full_map(self, attr: str, tid: int, is_param: bool, orig_name: str, shap val_chunks int: the number of value chunks. """ assert hasattr(self, attr), f"{attr} is not in the module" - meta = AttrMeta(tid, is_param, orig_name, shape, slicers, val_chunks) + attr_tensor: torch.Tensor = getattr(self, attr) + meta = AttrMeta(tid, is_param, orig_name, shape, slicers, val_chunks, attr_tensor.dtype, tuple(attr_tensor.shape)) self._fullmap[attr] = meta # TODO: remove this function, use the property instead @@ -366,7 +373,7 @@ def get_checkpoint(self, optimizer: torch.optim.Optimizer = None): # backward compatibility # in old version, dist_param_map is not loaded in constructor # so we will try to load it from file on the fly. - dist_param_map = getattr(self, '_dist_param_map', None) + dist_param_map = getattr(self, 'dist_param_map', None) if not dist_param_map: module_file = Path(sys.modules[self.__module__].__file__) # load from the same directory as the module file @@ -895,10 +902,27 @@ class ParallelModule(CubeModule): COMPUTE_CONFIG_FILE = 'compute_config.pt' ORIGIN_MODULE_METADATA_FILE = 'origin_module_metadata.pt' EXTRA_STATE_KEY = 'CUBE_EXTRA_STATE' + ATTR_META_FILE_PREFIX = 'attr_meta' + ATTR_META_FILE_TEMPLATE = ATTR_META_FILE_PREFIX + '{}.pkl' # 'attr_meta{}.pkl' + # the rank of the module, will be assigned in the generated subclasses rank: int + # the world size to run this module, will be assigned in the generated subclasses + world_size: int # the runtime version of the module when it is generated, will be assigned in the generated subclasses runtime_version: str + # mapping from the name of local attribute tensor + # to its corresponding fulltensor meta for all ranks. + # it is a list of dictionaries mapping from attribute names to their metadata + # and it is a replacement of `CubeModule.fullmap` + attr_meta_maps: list[dict[str, AttrMeta]] + # the directory of the module located + module_dir: Path + # The map is a dict mapping from the new parameter name (without tid suffix) in parallel module + # to the parameter name in original module. + dist_param_map: dict[str, str] + compute_config: 'ComputeConfig' + origin_module_metadata: OriginModuleMetadata def __init__(self): if self.__class__ == ParallelModule: # not init via super().__init__() @@ -921,6 +945,27 @@ def __init__(self): # track whether all the parames (especially the non-persistent buffers) have been initialized self._non_presistent_buffers_inited = False + def __init_subclass__(cls, **kwargs): + from nnscaler.parallel import ComputeConfig + + super().__init_subclass__(**kwargs) + cls.attr_meta_maps = [] + cls.module_dir = Path(sys.modules[cls.__module__].__file__).parent + + for rank in range(cls.world_size): + attr_map_file = cls.module_dir / cls.ATTR_META_FILE_TEMPLATE.format(rank) + with open(attr_map_file, 'rb') as f: + attr_meta_map = pickle.load(f) + attr_meta_map = {attr: AttrMeta(**meta) for attr, meta in attr_meta_map.items()} + cls.attr_meta_maps.append(attr_meta_map) + + cls.dist_param_map = torch.load(cls.module_dir / FxModuleParser.ATTR_MAP_FILE, weights_only=False) + cls.compute_config = ComputeConfig.safe_load_from_file( + cls.module_dir / cls.COMPUTE_CONFIG_FILE, + return_none_on_error=False + ) + cls.origin_module_metadata = torch.load(cls.module_dir / cls.ORIGIN_MODULE_METADATA_FILE, weights_only=False) + @property def non_presistent_buffers_inited(self): return self._non_presistent_buffers_inited @@ -957,23 +1002,14 @@ def _post_init(self, init_params=True, build_buckets=True): # TODO: re-enable this check # if dist.is_initialized() and self.rank != dist.get_rank(): # raise RuntimeError(f"The rank to load this module file name is expected to be {self._rank}, but got {dist.get_rank()}") - from nnscaler.parallel import ComputeConfig self._non_presistent_buffers_inited = init_params or not self._non_persistent_buffers_set module_file = Path(sys.modules[self.__module__].__file__) - self.module_dir = module_file.parent if init_params: self.load_attr_content(str(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE_STEM}"))) self._warn_uninitialized_non_persistent_buffers() - self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}"), weights_only=False) - self._compute_config: ComputeConfig = ComputeConfig.safe_load_from_file( - module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}"), - return_none_on_error=False - ) - self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}"), weights_only=False) - # add state_dict hook to save extra state # Please note extra_state is only used for merging, not for loading # so we can safely remove it in load_state_dict pre hook @@ -1002,6 +1038,21 @@ def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None self._zero_metadata = self._get_zero_metadata() + @classmethod + def get_attr_meta_map(cls, rank=None): + """ + Get the attribute meta map for the given rank. + If rank is None, return the attribute map for the current rank. + + This function is preferred over accessing `CubeModule.fullmap` in most cases, + since it doesn't need to instantiate the module. + """ + if rank is None: + rank = cls.rank + if rank < 0 or rank >= cls.world_size: + raise ValueError(f"Rank {rank} is out of range [0, {cls.world_size})") + return cls.attr_meta_maps[rank] + def forward(self, *args, **kwargs): self._warn_uninitialized_non_persistent_buffers(raise_error=True) if self.training: @@ -1163,19 +1214,6 @@ def infer_step(self, samples: List[Any]) -> List[Any]: outputs.append(output) return outputs - @property - def dist_param_map(self) -> Dict[str, str]: - """ - Get the parameter map of the model. - The map is a dict mapping from the new parameter name (without tid suffix) in parallel module - to the parameter name in original module. - """ - return self._dist_param_map - - @property - def compute_config(self) -> 'ComputeConfig': - return self._compute_config - def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]: """ Calculate the gradient norm and clip gradients. @@ -1298,11 +1336,11 @@ def _add_extra_state(self, state_dict, prefix) -> None: state_dict[f'{prefix}{self.EXTRA_STATE_KEY}'] = asdict( ExtraState( rank=self.rank, - compute_config=self._compute_config, - dist_param_map=self._dist_param_map, + compute_config=self.compute_config, + dist_param_map=self.dist_param_map, param_area_map=self._fullmap, cube_param_names=[name for name, _ in self.named_parameters()], - **asdict(self._orign_module_metadata), + **asdict(self.origin_module_metadata), **asdict(self._zero_metadata), ) ) @@ -1338,19 +1376,19 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if strict: missing_keys.extend(new_missing_keys) - @property - def module_dedup_group_size(self) -> int: + @classproperty + def module_dedup_group_size(cls) -> int: """ Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. """ - return self.compute_config.module_dedup_group_size + return cls.compute_config.module_dedup_group_size - @property - def optimizer_dedup_group_size(self) -> int: + @classproperty + def optimizer_dedup_group_size(cls) -> int: """ Get the size of the deduplication group of the optimizer state dict. """ - return self.compute_config.optimizer_dedup_group_size + return cls.compute_config.optimizer_dedup_group_size def _list_fullmodel_files(self) -> List[Path]: legacy_fullmodel_path = self.module_dir / FxModuleParser.ATTR_CONTENT_FILE_STEM diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index 851d36d6..c11d7a23 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -186,7 +186,8 @@ def train_iter(model, dataloader): # tensor parallelism + scale test test_tp2scale2 = partial(torchrun, 4, assert_parity, baseline, - partial(cube_run, 2, tp_policy) + partial(cube_run, 2, tp_policy), + 0.001, ) # pipeline parallelism test diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index c7bc814a..5a305499 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -64,12 +64,12 @@ def _gpu_worker(): p(tempdir, 'none', '_1') # case 2: broadcast only code, so only rank 0 can load the module - # rank 1 will raise RuntimeError because it will fail to load fullmodel.pt + # rank 1 will raise FileNotFoundError because it will fail to load attr_map files and more with tempfile.TemporaryDirectory() as tempdir: if torch.distributed.get_rank() == 0: p(tempdir, 'code', '_2') else: - with pytest.raises(RuntimeError, match='Cannot find file.*'): + with pytest.raises(FileNotFoundError): p(tempdir, 'code', '_2') # case 3: broadcast except weights, so only rank 0 can load the module From 5039b80979f11b7671ed615c461c3238e4d7483e Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 8 Sep 2025 07:57:40 +0000 Subject: [PATCH 1835/1892] Merged PR 2404: [Parser] Fix Stack and Add Dot --- nnscaler/graph/function/function.py | 31 +++++++++++++++++++------- nnscaler/graph/parser/mapping.py | 1 + tests/graph/function/test_functions.py | 20 +++++++++++++++++ 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index c083ec9c..9adff8fb 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -150,6 +150,16 @@ def Accum(*inputs, signature = None): return IRDimops(Cat, 'accum', signature, [anno], inputs) +def Dot(input, tensor, *, out=None, signature = None): + """ + torch.dot(input, tensor, *, out=None) -> Tensor + """ + assert out is None + signature = 'torch.dot' + annos = ['k+, k+ -> 1',] + return IRDimops(Dot, 'dot', signature, annos, [input, tensor]) + + def Linear(input, weight, bias=None, signature = None): signature = 'torch.nn.functional.linear' assert isinstance(input, IRTensor) and isinstance(weight, IRTensor) @@ -195,6 +205,7 @@ def CubeEinSum(*operands, equation=None, signature = None): anno = f'{lhs} -> {rhs}' return IRDimops(CubeEinSum, 'einsum', signature, [anno], operands, equation=equation) + def EinSum(equation: str, *operands, signature = None): return CubeEinSum(*operands, equation=equation, signature=signature) @@ -1809,14 +1820,18 @@ def CubeStack(*tensors, dim=0, signature=None): assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'but got {tensors}' assert isinstance(dim, int), f"but not {dim}" signature = 'nnscaler.runtime.function.stack' - iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] - oanno = [None for i in range(len(tensors[0].shape) + 1)] - oanno[dim] = f'{len(tensors)}^' - offset = 0 - for i in range(len(oanno)): - if oanno[i] is None: - oanno[i] = copy.copy(iannos[-1][offset]) - offset += 1 + if tensors[0].is_scalar_tensor(): + iannos = ['1' for _ in tensors] + oanno = [f'{len(tensors)}'] + else: + iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] + oanno = [None for i in range(len(tensors[0].shape) + 1)] + oanno[dim] = f'{len(tensors)}' + offset = 0 + for i in range(len(oanno)): + if oanno[i] is None: + oanno[i] = copy.copy(iannos[-1][offset]) + offset += 1 anno = OpAnno.create_op_str(iannos, [oanno]) return IRDimops(CubeStack, 'stack', signature, [anno], tensors, dim=dim) diff --git a/nnscaler/graph/parser/mapping.py b/nnscaler/graph/parser/mapping.py index 55c2792f..2c8f88ce 100644 --- a/nnscaler/graph/parser/mapping.py +++ b/nnscaler/graph/parser/mapping.py @@ -55,6 +55,7 @@ def exist(signature: str) -> bool: kOpMap = { # __tnmtemplate('Dropout'): function.nnDropout, + __ttemplate('dot'): function.Dot, __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 319e3b6f..0a8b2fc0 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -1147,3 +1147,23 @@ def test_dict_keys_values_items(): # key will never be wrapped with IRObject # IRFullTensor will be reconstructed, so their ids are different assert all(x[0] == y[0] and x[1].shape == y[1].shape and x[1] != y[1] for x, y in zip(r.output(0), d.items())) + +def test_Stack(): + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=0) + expected_annotation = 'a b, a b, a b -> 3 a b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=1) + expected_annotation = 'a b, a b, a b -> a 3 b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=2) + expected_annotation = 'a b, a b, a b -> a b 3' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + + op = F.Stack([IRTensor([]), IRTensor([]), IRTensor([])], dim=0) + expected_annotation = '1, 1, 1 -> 3' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + +def test_Dot(): + op = F.Dot(IRTensor([4]), IRTensor([4])) + expected_annotation = 'k+, k+ -> 1' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Dot." From 0ce3b73ba54212a87231bd7008dc4918fec03ebf Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 8 Sep 2025 08:06:46 +0000 Subject: [PATCH 1836/1892] Merged PR 2403: [Trainer] Refine logging related to dedup and dataloader --- nnscaler/cli/trainer.py | 2 ++ nnscaler/parallel.py | 54 ++++++++++++++++++++++++++++------------- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index c35ad677..e7c86960 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -655,6 +655,7 @@ def _expire_checkpoints(self): def _global_batch_iterator(self, num_skip_first=0, stage='train'): if stage == 'train': if self.dataloader_resumed or num_skip_first == 0: + logger.info(f'Trainer resumes dataloader directly.') # if the checkpoint stops at the end of an epoch, # the rng states must be resumed before creating iterator # because `DataLoader.__iter__()` uses the rng (dunno why), @@ -662,6 +663,7 @@ def _global_batch_iterator(self, num_skip_first=0, stage='train'): self._try_resume_rng_states() it = iter(self.dataloader[stage]) else: # dry run until reach the desired batch. + logger.info(f'Trainer try to resume dataloader for {stage} stage with {num_skip_first}.') it = iter(self.dataloader[stage]) for _ in range(num_skip_first * self.train_args.update_freq): _sample = next(it) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 1bc96c12..985ac732 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2263,7 +2263,7 @@ def _broadcast_gen_files( if curr_rank != 0: files = sent_obj[0] - logging.info(f'File list broadcasted ({len(files)} in total).') + logger.info(f'File list broadcasted ({len(files)} in total).') # send file content one by one for fname in files: if curr_rank == 0: @@ -2275,7 +2275,7 @@ def _broadcast_gen_files( if curr_rank != 0: with open(outdir / fname, 'wb') as f: f.write(data[0]) - logging.info(f'File {fname} broadcasted.') + logger.info(f'File {fname} broadcasted.') # wait for all nodes to finish torch.distributed.barrier() @@ -2417,42 +2417,62 @@ def load_deduped_state_dict( cur_rank = torch.distributed.get_rank() # step 1: load deduped state dict at each rank - module.load_state_dict(module_state_dict, strict=False) + missing_keys, unexpected_keys = module.load_state_dict(module_state_dict, strict=False) module.to(device) torch.distributed.barrier() - logger.debug(f'at rank {cur_rank}, state_dict keys: {module_state_dict.keys()}') + logger.debug(f'At rank {cur_rank}, state_dict keys: {module_state_dict.keys()}.') + logger.debug(f'At rank {cur_rank}, missing_keys: {missing_keys}, unexpected_keys: {unexpected_keys}.') # step 2: broadcast deduped weights inside 1st scale unit parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} rank2deduped_fullmap, dedup_group_size, global_tensor_meta = _collect_dedup_info(parallel_modules) broadcast_group = DeviceGroup().get_group(list(range(dedup_group_size))) - logger.debug(f'at rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}') + logger.debug(f'At rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}.') if cur_rank < dedup_group_size: + # broadcast weights in parallel modules for rank, deduped_fullmap in rank2deduped_fullmap.items(): - logger.debug(f'at rank {cur_rank}, process rank: {rank}') + logger.debug(f'At rank {cur_rank}, process rank: {rank}.') for prefix, fullmap in deduped_fullmap.items(): for local_name, attr_meta in fullmap.items(): key = f'{prefix}.{local_name}' if prefix else local_name - assert prefix in parallel_modules, f'prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}' + assert prefix in parallel_modules, f'Prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}.' pm = parallel_modules[prefix] attr_meta = global_tensor_meta[rank][prefix][local_name] shape, dtype = attr_meta.sub_shape, attr_meta.dtype if rank == cur_rank: - assert hasattr(pm, local_name), f'local_name {local_name} not found in {pm}' + assert hasattr(pm, local_name), f'Local name {local_name} not found in {pm}.' broadcast_tensor = getattr(pm, local_name) - logger.info(f'at rank {cur_rank}, broadcast: {key} from {cur_rank}') + logger.info(f'Broadcast: {key} from {cur_rank}.') else: broadcast_tensor = torch.empty(shape, device=device, requires_grad=False, dtype=dtype) torch.distributed.broadcast(broadcast_tensor, src=rank, group=broadcast_group) if rank != cur_rank: - logger.info(f'at rank {cur_rank}, try to load: {key} to rank {cur_rank}') # in pipeline parallelism, the local_name may not be found in the module if hasattr(pm, local_name): + logger.info(f'At rank {cur_rank}, try to load: {key} from rank {rank}.') attr = getattr(pm, local_name) - attr.data.copy_(broadcast_tensor) + if key in missing_keys: + attr.data.copy_(broadcast_tensor) + missing_keys.remove(key) + else: + assert torch.equal(attr, broadcast_tensor), \ + f'At rank {cur_rank}, the attribute {key} is already loaded, but not equal to the broadcasted tensor from rank {rank}.' + else: + logger.info(f'At rank {cur_rank}, skip to load: {key} from rank {rank}, not found in the module.') + + for key in missing_keys: + split_names = key.split('.') + prefix = '.'.join(split_names[:-1]) # remove the last part of the key + assert prefix not in parallel_modules, f'At rank {cur_rank}, the missing key {key} should be in non-parallel modules.' + + # At this point + # - All parallel modules in first scale unit should be complete. + # - Non-parallel modules in rank0 should be complete. The rest ranks will get the weights via broadcast_weights. torch.distributed.barrier() - # step 3: broadcast weights from 1st scale unit to other units + # step 3: + # - broadcast non-parallel module weights from 0th rank to other ranks + # - broadcast parallel modules weights from 1st scale unit to other units broadcast_weights(module) if optimizer is not None: @@ -2499,7 +2519,7 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g broadcast_group = setup_stride_broadcast_group(dedup_group_size) src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks - logging.info(f'Rank-{rank} is broadcasting optimizer states to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') + logger.info(f'Rank-{rank} is broadcasting optimizer states to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') # broadcast param groups and state keys/shapes/dtypes via broadcast_object_list if rank == src_rank: @@ -2585,7 +2605,7 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): broadcast_group = setup_stride_broadcast_group(stride_size) rank = torch.distributed.get_rank() src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks - logging.info(f'Rank-{rank} is broadcasting weights of {module.__class__.__name__} to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') + logger.info(f'Rank-{rank} is broadcasting weights of {module.__class__.__name__} to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') if isinstance(module, ParallelModule): if not _broadcast_single_value(src_rank, curr_parallel_group, module.non_presistent_buffers_inited): @@ -2593,15 +2613,15 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): # we have a special optimization for ParallelModule params = module.parameters_for_broadcast() if isinstance(module, ParallelModule) else list(module.parameters(False)) - logging.info(f'Inplace broadcasting {len(params)} parameters...') + logger.info(f'Inplace broadcasting {len(params)} parameters...') for i, param in enumerate(params): torch.distributed.broadcast(param.data, src=src_rank, group=curr_parallel_group) - logging.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') + logger.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') # NOTE: may batch buffers for efficient broadcast, # current implementation is the most memory efficient way. buffers = list(module.buffers(False)) - logging.info(f'Inplace broadcasting {len(buffers)} buffers...') + logger.info(f'Inplace broadcasting {len(buffers)} buffers...') for buffer in buffers: torch.distributed.broadcast(buffer.data, src=src_rank, group=curr_parallel_group) From 122fdf66a8a80deb55870d37865f603a3ae92b35 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 8 Sep 2025 08:24:37 +0000 Subject: [PATCH 1837/1892] Merged PR 2400: [Runtime] Refine f16 optimizer loading logic - It is possible that the model is trained in other precision in nnscaler before but want to train it in f16 optimizer this case, previous hook fails at this case. - It is better to provide message during loading. --- nnscaler/runtime/f16_optimizer.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py index 414361a5..908a4f3e 100644 --- a/nnscaler/runtime/f16_optimizer.py +++ b/nnscaler/runtime/f16_optimizer.py @@ -32,6 +32,10 @@ def __init__(self, *args, **kwargs): # forward __init__ call to the next class in mro(method resolution order) super().__init__(*args, **kwargs) self._multiply_factor = 1.0 + # This flag is used to indicate whether fp32_params are loaded from checkpoint. + # If not, we will sync from fp16 params to fp32 params in after_load_checkpoint. + # If the model is trained from scratch, this flag will be None. + self._fp32_params_loaded = None def after_setup(self, trainer: 'Trainer') -> None: """ @@ -111,12 +115,15 @@ def load_state_dict(self, state_dict): param.data = state_dict['state'][i]['fp32_params'].data.to(device) # pop to avoid store a redundant copy in the wrapped optimizer state_dict['state'][i].pop('fp32_params') + else: + logger.warning('fp32_params not found in state_dict, will sync from fp16 params to fp32 params') + self._sync_fp16_params_to_fp32() - if len(self.param_groups) != 1: - raise RuntimeError('only support one param group') - self.param_groups[0]['params'] = self.fp32_params + if len(self.param_groups) != 1: + raise RuntimeError('only support one param group') super().load_state_dict(state_dict) + self._fp32_params_loaded = True def _sync_f16_grads_to_fp32(self): # copy FP16 grads to FP32 @@ -148,10 +155,15 @@ def _sync_fp16_params_to_fp32(self): continue p32.data.copy_(p.data) + def on_load_checkpoint(self, trainer, checkpoint) -> None: + self._fp32_params_loaded = False + logger.info('Set _fp32_params_loaded to False in on_load_checkpoint hook') + def after_load_checkpoint(self, trainer, checkpoint) -> None: - if 'nnscaler' not in checkpoint: - # this checkpoint is not created by nnscaler. + if not self._fp32_params_loaded: + logger.info('fp32_params not loaded, will sync from fp16 params to fp32 params') self._sync_fp16_params_to_fp32() + self._fp32_params_loaded = True def overrided_scale_grads(self, scale: float): """ From 12c3f0e5f039443c202f9c0dea4ff1ec9bd1e882 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 9 Sep 2025 08:48:39 +0000 Subject: [PATCH 1838/1892] Merged PR 2401: [Runtime] Add option to load merged state dict with less memory usage. In previous solution, all ranks will load merged state dict at the same time, which can cause OOM when module is big. In this PR, only one rank loads the whole merged state dict, and then trim the state dict and send the trimmed version to each rank. --- nnscaler/__init__.py | 2 + nnscaler/cli/checkpoint.py | 153 +++++++ nnscaler/cli/trainer.py | 123 +++-- nnscaler/cli/trainer_args.py | 5 + nnscaler/parallel.py | 553 ++++++++++++++++++++--- nnscaler/runtime/adapter/reducer.py | 113 ++++- nnscaler/runtime/module.py | 129 +++++- nnscaler/utils.py | 37 +- tests/cli/test_trainer.py | 78 +++- tests/parallel_module/test_checkpoint.py | 28 +- tests/runtime/test_hybrid_optimizer.py | 49 +- tests/test_utils.py | 41 +- 12 files changed, 1156 insertions(+), 155 deletions(-) create mode 100644 nnscaler/cli/checkpoint.py diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index 4cf4896a..c9265af8 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -16,6 +16,8 @@ broadcast_weights, load_sharded_state_dict, sync_grad_when, + trimmed_broadcast_merged_state_dict, + load_merged_state_dict_from_rank, ) from nnscaler.graph.parser.register import register_op from nnscaler.runtime.function.function import ( diff --git a/nnscaler/cli/checkpoint.py b/nnscaler/cli/checkpoint.py new file mode 100644 index 00000000..3f596bbc --- /dev/null +++ b/nnscaler/cli/checkpoint.py @@ -0,0 +1,153 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Only for command line +""" + +import logging +import os +import sys +from pathlib import Path + +import torch.distributed + +import nnscaler +from nnscaler.cli.trainer import Trainer, TrainerArgs +from nnscaler.parallel import _trim_module_merged_state_dict, _trim_optimizer_merged_state_dict + + +logger = logging.getLogger(__name__) + + +def _patch_distributed(): + groups = {} + + def is_initialized(): + return bool(groups) + + torch.distributed.is_initialized = is_initialized + + def init_process_group(*args, **kwargs): + world_size = int(os.environ['WORLD_SIZE']) + groups[None] = list(range(world_size)) + + def get_rank(group=None): + if group not in groups: + raise ValueError(f"Unknown group: {group}") + try: + return groups[group].index(int(os.environ['RANK'])) + except ValueError: + return -1 + + def get_world_size(group=None): + if group not in groups: + raise ValueError(f"Unknown group: {group}") + return len(groups[group]) + + def new_group(ranks=None, *args, **kwargs): + world_size = int(os.environ['WORLD_SIZE']) + if ranks is None or len(ranks) == world_size: + return + group_id = tuple(sorted(ranks)) + if group_id in groups: + return group_id + groups[group_id] = ranks + return group_id + + torch.distributed.get_rank = get_rank + torch.distributed.get_world_size = get_world_size + torch.distributed.init_process_group = init_process_group + torch.distributed.destroy_process_group = lambda: None + torch.distributed.new_group = new_group + torch.distributed.barrier = lambda *args, **kwargs: None + torch.distributed.all_gather = lambda *args, **kwargs: None + torch.distributed.broadcast_object_list = lambda *args, **kwargs: None + + +def _trim_merged_checkpoint(train_args: TrainerArgs, merged_state_dict, rank: int): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = '0' + os.environ['WORLD_SIZE'] = str(train_args.compute_config.runtime_ngpus) + os.environ['GROUP_RANK'] = str(rank) + os.environ['LOCAL_WORLD_SIZE'] = '1' + os.environ['TORCHELASTIC_RUN_ID'] = '0' # fake torchrun env + + sharded_state_dict = {k: v for k, v in merged_state_dict.items()} + + trainer = Trainer(train_args=train_args) + # enforce run mode to load module and optimizer + trainer.train_args.run_mode = 'run' + trainer._setup() + + sharded_state_dict['model'] = _trim_module_merged_state_dict( + trainer.model, merged_state_dict['model'], + device='cpu' + ) + sharded_state_dict['optimizer'] = _trim_optimizer_merged_state_dict( + trainer.model, trainer.optimizer._extra_state, merged_state_dict['optimizer'], + device='cpu' + ) + sharded_state_dict['train_args'] = train_args.to_dict() + sharded_state_dict['train_args'].setdefault('checkpoint', {})['save_type'] = 'sharded' + # discard rng_states for merged state dict + sharded_state_dict.pop('rng_states', None) + if 'dataloader' in sharded_state_dict and sharded_state_dict['dataloader'] is not None: + # keep dataloader state only when all ranks have the same state + dataloader_states = sharded_state_dict['dataloader'] + if all(dataloader_states[i] == dataloader_states[0] for i in range(1, len(dataloader_states))): + sharded_state_dict['dataloader'] = dataloader_states[0] + else: + sharded_state_dict.pop('dataloader') + + # make it sharded checkpoint + for module_path, m in trainer.model.named_modules(): + prefix = module_path + '.' if module_path else '' + if isinstance(m, nnscaler.ParallelModule): + m._add_extra_state(sharded_state_dict['model'], prefix) + return sharded_state_dict + + +def _distribute_checkpoint(train_args: TrainerArgs, from_: str, to_: str): + nnscaler.utils.set_default_logger_level(level=logging.INFO) + _patch_distributed() + resume_from = Path(from_) + save_to = Path(to_) + save_to.mkdir(parents=True, exist_ok=True) + + if resume_from.is_file(): + state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + if convert_fn := train_args.checkpoint.resolved_convert_fn: + state_dict = convert_fn(state_dict) + else: + ckpt_files = list(resume_from.glob('*.ckpt')) + rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} + if set(rank_ckpt_files.keys()) != set(range(len(rank_ckpt_files))): + raise ValueError(f"Checkpoint files in {resume_from} are not complete: {rank_ckpt_files.keys()}") + state_dict = Trainer._merge_checkpoint(list(rank_ckpt_files.values())) + + for i in range(train_args.compute_config.runtime_ngpus): + sharded_state_dict = _trim_merged_checkpoint(train_args, state_dict, i) + torch.save(sharded_state_dict, save_to / f"{i}.ckpt") + + +if __name__ == '__main__': + argv = sys.argv[1:] + if len(argv) == 0: + raise ValueError("No command specified. Expected `distribute -f `") + if argv[0] == 'distribute': + if len(argv) < 5: + raise ValueError("Not enough arguments. Expected at least `distribute -f `") + from_ = argv[1] + to_ = argv[2] + train_args = TrainerArgs.from_cli(argv[3:]) + # never broadcast generated files. + train_args.broadcast_strategy = 'none' + train_args.checkpoint.resume_from = None + _distribute_checkpoint(train_args, from_, to_) + else: + raise ValueError(f"Unknown command: {argv[0]}") +else: + # we have patched too many things. + # please run this script with `python -m nnscaler.cli.checkpoint` + raise ImportError("checkpoint.py should be run as a script.") diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index e7c86960..994932b2 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -22,6 +22,7 @@ from tqdm import tqdm import nnscaler +from nnscaler.runtime.device import DeviceGroup from nnscaler.utils import enforce_zero_num_worker, is_running_distributed from .trainer_args import AggregatedOutputs, TrainerArgs, fix_input @@ -142,7 +143,7 @@ def _load_dummy_input(self): def _setup(self): if is_running_distributed(): nnscaler.init() - if torch.distributed.get_rank() == 0: + if DeviceGroup().local_rank == 0: logging.getLogger().setLevel(logging.INFO) else: logging.getLogger().setLevel(logging.WARNING) @@ -190,6 +191,17 @@ def _setup(self): self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) self.local_rank = int(os.environ.get('LOCAL_RANK')) self.node_rank = int(os.environ.get('GROUP_RANK')) + assert self.rank // self.local_world_size == self.node_rank + self.local_ranks = list( + range( + self.node_rank * self.local_world_size, + (self.node_rank + 1) * self.local_world_size + ) + ) + self.local_rank0 = self.local_ranks[0] + # create local process groups + for local_rank0 in range(0, self.world_size, self.local_world_size): + DeviceGroup().get_group(list(range(local_rank0, local_rank0 + self.local_world_size))) self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq if len(self.dataloader['train']) % self.train_args.update_freq != 0: @@ -291,48 +303,57 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): } return merged_state_dict - def _broadcast_merged_state_dict(self, state_dict: Dict[str, Any]): + def _broadcast_merged_state_dict( + self, + state_dict: Dict[str, Any], + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + ): """ Broadcast the merged state dict to all ranks. We can't broadcast the whole state_dict at once, because it may be too large, and leads to OOM. Here we will break the model and optimizer state_dict into smaller pieces and broadcast them one by one. Please note we use `torch.distributed.broadcast_object_list` to broadcast the state_dict (including tensors inside). """ + dst_ranks = dst_ranks or list(range(torch.distributed.get_world_size())) + if src_rank not in dst_ranks or self.rank not in dst_ranks: + raise ValueError(f"src_rank and current rank must be in dst_ranks: {dst_ranks}") + pg = DeviceGroup().get_group(dst_ranks) + + if self.rank == src_rank: + if state_dict is None: + raise ValueError("state_dict should not be None in rank 0 when broadcasting") + else: + if state_dict is not None: + raise ValueError("state_dict should be None in other ranks when broadcasting") + state_dict = {} def _broadcast_keys(sdict: Dict[str, Any], set_keys=True): - if self.rank == 0: + if self.rank == src_rank: state_keys = list(sdict.keys()) else: state_keys = None state_key_list = [state_keys] - torch.distributed.broadcast_object_list(state_key_list, src=0) + torch.distributed.broadcast_object_list(state_key_list, src=src_rank, group=pg) state_keys = state_key_list[0] - if set_keys and self.rank != 0: + if set_keys and self.rank != src_rank: for key in state_keys: sdict[key] = {} # assume the values are empty dicts return state_keys def _broadcast_value(sdict, key): - if self.rank == 0: + if self.rank == src_rank: value_list = [sdict[key]] else: value_list = [None] - torch.distributed.broadcast_object_list(value_list, src=0) - if self.rank != 0: + torch.distributed.broadcast_object_list(value_list, src=src_rank, group=pg) + if self.rank != src_rank: sdict[key] = value_list[0] def _broadcast_values(sdict, keys): for key in keys: _broadcast_value(sdict, key) - if self.rank == 0: - if state_dict is None: - raise ValueError("state_dict should not be None in rank 0 when broadcasting") - else: - if state_dict is not None: - raise ValueError("state_dict should be None in other ranks when broadcasting") - state_dict = {} - state_keys = _broadcast_keys(state_dict) for skey in state_keys: @@ -377,11 +398,19 @@ def _load_checkpoint(self): if not resume_from: return logger.info(f"Resuming from {resume_from}") + trimmed_broadcast_required = False + load_from_merged = False + if resume_from.is_file(): - resume_from = resume_from # when we load from merged checkpoint - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) - if convert_fn := self.train_args.checkpoint.resolved_convert_fn: - state_dict = convert_fn(state_dict) + # when we load from merged checkpoint + load_from_merged = True + trimmed_broadcast_required = self.train_args.checkpoint.resume_from.save_memory + if not self.train_args.checkpoint.resume_from.save_memory or self.local_rank == 0: + state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + if convert_fn := self.train_args.checkpoint.resolved_convert_fn: + state_dict = convert_fn(state_dict) + else: + state_dict = None else: ckpt_files = list(resume_from.glob('*.ckpt')) rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} @@ -394,14 +423,20 @@ def _load_checkpoint(self): if len(rank_ckpt_files) != self.world_size or self.train_args.checkpoint.resume_from.with_merged: # merge the checkpoint files from all ranks and broadcast to all ranks torch.distributed.barrier() - if self.rank == 0: + if self.local_rank == 0: logger.info(f"Merging checkpoint files from {resume_from}") state_dict = self._merge_checkpoint(list(rank_ckpt_files.values())) else: state_dict = None - logger.info(f"Broadcasting merged checkpoint to all ranks.") - state_dict = self._broadcast_merged_state_dict(state_dict) - logger.info(f"Broadcasted merged checkpoint to all ranks.") + + load_from_merged = True + trimmed_broadcast_required = self.train_args.checkpoint.resume_from.save_memory + if not self.train_args.checkpoint.resume_from.save_memory: + logger.info(f"Broadcasting merged checkpoint to all ranks.") + state_dict = self._broadcast_merged_state_dict( + state_dict, src_rank=self.local_rank0, dst_ranks=self.local_ranks + ) + logger.info(f"Broadcasted merged checkpoint to all ranks.") else: resume_from = resume_from / f'{self.rank}.ckpt' state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) @@ -411,12 +446,37 @@ def _load_checkpoint(self): f"If it fails, please try with merged checkpoint." ) + if trimmed_broadcast_required: + logger.info("Broadcasting trimmed checkpoint to all ranks.") + state_dict = state_dict or {} + state_dict['model'], state_dict['optimizer'] = nnscaler.trimmed_broadcast_merged_state_dict( + self.model, + state_dict['model'] if self.local_rank == 0 else None, + self.optimizer, + state_dict['optimizer'] if self.local_rank == 0 else None, + src_rank=self.local_rank0, + dst_ranks=self.local_ranks, + ) + remaining_state_dict = self._broadcast_merged_state_dict( + {k: v for k, v in state_dict.items() if k not in ('model', 'optimizer')} + if self.local_rank == 0 else None, + src_rank=self.local_rank0, + dst_ranks=self.local_ranks, + ) + if self.local_rank != 0: + state_dict.update(remaining_state_dict) + logger.info("Broadcasted trimmed checkpoint to all ranks.") + + # trimmed checkpoint is sharded + ckpt_save_type = 'sharded' + else: + # if it is not a well-formed state_dict (from third party) + # we will treat it as a merged state_dict + ckpt_save_type = state_dict.get('train_args', {}) \ + .get('checkpoint', {}) \ + .get('save_type', 'merged') + self.hook.on_load_checkpoint(self, state_dict) - # if it is not a well-formed state_dict (from third party) - # we will treat it as a merged state_dict - ckpt_save_type = state_dict.get('train_args', {}) \ - .get('checkpoint', {}) \ - .get('save_type', 'merged') if ckpt_save_type == 'merged': # it is a merged state dict nnscaler.load_merged_state_dict( @@ -441,10 +501,11 @@ def _load_checkpoint(self): raise ValueError("lr_scheduler is not set in the current trainer") if self.lr_scheduler: self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + if 'dataloader' in state_dict and state_dict['dataloader'] is not None: if not self._is_resumable_dataloader(): raise ValueError("dataloader is not resumable, but checkpoint contains dataloader state") - if ckpt_save_type == 'merged': + if load_from_merged: dataloader_states = state_dict['dataloader'] # only load dataloader state when all ranks have the same state # TODO: is this reasonable? @@ -464,7 +525,7 @@ def _load_checkpoint(self): self.train_status = TrainStatus(**state_dict['train_status']) # we don't resume rng states when loading merged checkpoint, - if ckpt_save_type != 'merged': + if not load_from_merged: self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() else: logger.warning("RNG states are not resumed when loading merged checkpoint.") diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 3381c70d..0c368aab 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -403,6 +403,11 @@ class ResumeOptions: # `None` means will load the sharded checkpoint files if the world size is not changed. # and will load merged checkpoint if the world size is changed. with_merged: Optional[bool] = None + # If the memory is limited, we can save memory by only loading merged state dict in GPU 0 of each node + # and broadcast trimmed state dict to other ranks in the same node + # although this will be slower + # Only used when resuming from a merged checkpoint. + save_memory: bool = True @dataclass diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 985ac732..d6a170be 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -14,7 +14,7 @@ import logging import copy import os -from collections import OrderedDict +from collections import OrderedDict, defaultdict import torch import torch.distributed @@ -49,7 +49,14 @@ from nnscaler.flags import CompileFlag, RuntimeFlag import nnscaler.policies as policies from nnscaler.program import disable_global_graph -from nnscaler.utils import get_member_by_name, load_type, setup_stride_broadcast_group, get_shared_params +from nnscaler.utils import ( + get_member_by_name, + load_type, + set_member_by_name, + setup_stride_broadcast_group, + get_shared_params, + OptStateDict, +) logger = logging.getLogger(__name__) @@ -1948,68 +1955,94 @@ def load_merged_state_dict( module.to(device) if optimizer is not None and optimizer_state_dict is not None: - if not _is_supported_optimizer(optimizer._extra_state.name): - raise ValueError("Only Adam-like optimizers are supported.") + new_optimizer_state_dict = _trim_optimizer_merged_state_dict(module, optimizer._extra_state, optimizer_state_dict, device=device) + optimizer.load_state_dict(new_optimizer_state_dict) - # handle non-paralleled module parameters - # make sure the order of the parameters - pm_name_locs: Dict[str, ModuleParameterLocation] = dict(sorted(optimizer._extra_state.parallel_module_locs.items(), key=lambda x: x[1].offset)) - pm_modules: List[torch.nn.Module] = [] - pm_locs = list(pm_name_locs.values()) - for name in pm_name_locs: - m = get_member_by_name(module, name) - if not isinstance(m, ParallelModule): - raise ValueError(f"Module {name} is not a ParallelModule") - pm_modules.append(m) - - merged_cur = 0 # the current index of the merged state dict - pm_cur = 0 # the current index of the parallel module in pm_locs - new_states: Dict[int, Dict[str, Any]] = {} - new_cur = 0 # the current index of the new state dict - assert len(optimizer_state_dict['param_groups']) == 1 - effective_state_len = len(optimizer_state_dict['param_groups'][0]['params']) - while merged_cur < effective_state_len: - # N: non-paralleled module parameters, P: paralleled module (will have multiple parameters) - # The parameter list would look like: NNPNPPPN - # []: the current processing parameter - # <>: the current processing parallel module - if ( - pm_cur >= len(pm_modules) # NNPNPPP[N]: the ending parameters, no current parallel module - or new_cur < pm_locs[pm_cur].offset # [N]N

NPPPN: other parameters - ): - # non-parallel module - if merged_cur in optimizer_state_dict['state']: - new_states[new_cur] = optimizer_state_dict['state'][merged_cur] - merged_cur += 1 - new_cur += 1 - else: - # NNPN<[P]PP>N: the current parallel module - # parallel module - pm_param_count = len(pm_modules[pm_cur].origin_module_metadata.origin_param_names) - # will map `pm_param_count` parameters in merge state dict - # to `pm_locs[pm_cur].count` in optimizer state. - cur_states = {} - for i in range(pm_param_count): - if merged_cur + i in optimizer_state_dict['state']: - cur_states[i] =optimizer_state_dict['state'][merged_cur + i] - pm_new_states = _opt_load_merged_state_dict(pm_modules[pm_cur], cur_states) - for idx, value in pm_new_states.items(): - new_states[new_cur + idx] = value - new_cur += pm_locs[pm_cur].count - merged_cur += pm_param_count - pm_cur += 1 - - # move the new states to the device if needed - for idx, state in new_states.items(): - for key, value in state.items(): - if isinstance(value, torch.Tensor): - new_states[idx][key] = value.to(device) - new_optimizer_state_dict = {} - new_optimizer_state_dict['state'] = new_states - new_optimizer_state_dict['param_groups'] = copy.deepcopy(optimizer_state_dict['param_groups']) - new_optimizer_state_dict['param_groups'][0]['params'] = list(range(new_cur)) - optimizer.load_state_dict(new_optimizer_state_dict) +def _trim_optimizer_merged_state_dict( + module: torch.nn.Module, + opt_extra_state: OptimizerExtraState, + optimizer_state_dict: Dict[str, Any], + *, + device: Union[str, torch.device] = None +) -> Dict[str, Any]: + """ + Trim the merged state dict to only keep the states needed for the optimizer. + + Args: + module (torch.nn.Module): the module to be loaded + opt_extra_state (OptimizerExtraState): the extra state of the optimizer + optimizer_state_dict (Dict[str, Any]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the optimizer state dict. + + Returns: + Dict[str, Any]: the trimmed optimizer state dict + """ + if not _is_supported_optimizer(opt_extra_state.name): + raise ValueError("Only Adam-like optimizers are supported.") + + device = device or torch.cuda.current_device() + + # handle non-paralleled module parameters + # make sure the order of the parameters + pm_name_locs: Dict[str, ModuleParameterLocation] = dict(sorted(opt_extra_state.parallel_module_locs.items(), key=lambda x: x[1].offset)) + pm_modules: List[ParallelModule] = [] + pm_locs = list(pm_name_locs.values()) + for name in pm_name_locs: + m = get_member_by_name(module, name) + if not isinstance(m, ParallelModule): + raise ValueError(f"Module {name} is not a ParallelModule") + pm_modules.append(m) + + merged_cur = 0 # the current index of the merged state dict + pm_cur = 0 # the current index of the parallel module in pm_locs + new_states: Dict[int, Dict[str, Any]] = {} + new_cur = 0 # the current index of the new state dict + assert len(optimizer_state_dict['param_groups']) == 1 + effective_state_len = len(optimizer_state_dict['param_groups'][0]['params']) + while merged_cur < effective_state_len: + # N: non-paralleled module parameters, P: paralleled module (will have multiple parameters) + # The parameter list would look like: NNPNPPPN + # []: the current processing parameter + # <>: the current processing parallel module + if ( + pm_cur >= len(pm_modules) # NNPNPPP[N]: the ending parameters, no current parallel module + or new_cur < pm_locs[pm_cur].offset # [N]N

NPPPN: other parameters + ): + # non-parallel module + if merged_cur in optimizer_state_dict['state']: + new_states[new_cur] = optimizer_state_dict['state'][merged_cur] + merged_cur += 1 + new_cur += 1 + else: + # NNPN<[P]PP>N: the current parallel module + # parallel module + pm_param_count = len(pm_modules[pm_cur].origin_module_metadata.origin_param_names) + # will map `pm_param_count` parameters in merge state dict + # to `pm_locs[pm_cur].count` in optimizer state. + cur_states = {} + for i in range(pm_param_count): + if merged_cur + i in optimizer_state_dict['state']: + cur_states[i] =optimizer_state_dict['state'][merged_cur + i] + pm_new_states = _opt_load_merged_state_dict(pm_modules[pm_cur], cur_states) + for idx, value in pm_new_states.items(): + new_states[new_cur + idx] = value + new_cur += pm_locs[pm_cur].count + merged_cur += pm_param_count + pm_cur += 1 + + # move the new states to the device if needed + for idx, state in new_states.items(): + for key, value in state.items(): + if isinstance(value, torch.Tensor): + new_states[idx][key] = value.to(device) + + new_optimizer_state_dict = {} + new_optimizer_state_dict['state'] = new_states + new_optimizer_state_dict['param_groups'] = copy.deepcopy(optimizer_state_dict['param_groups']) + new_optimizer_state_dict['param_groups'][0]['params'] = list(range(new_cur)) + + return new_optimizer_state_dict def _opt_load_merged_state_dict(module: ParallelModule, states: Dict[int, Dict[str, Any]]): @@ -2475,11 +2508,9 @@ def load_deduped_state_dict( # - broadcast parallel modules weights from 1st scale unit to other units broadcast_weights(module) - if optimizer is not None: + if optimizer is not None and optimizer_state_dict is not None: if not _is_supported_optimizer(optimizer._extra_state.name): raise ValueError("Only Adam-like optimizers are supported.") - if optimizer_state_dict is None: - raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") for idx, state in optimizer_state_dict['state'].items(): for key, value in state.items(): @@ -2658,9 +2689,11 @@ def load_sharded_state_dict( device = device or torch.cuda.current_device() module.load_state_dict(module_state_dict) module.to(device) - if optimizer: - if optimizer_state_dict is None: - raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") + if optimizer and optimizer_state_dict: + for idx, state in optimizer_state_dict.get('state', {}).items(): + for key, value in state.items(): + if isinstance(value, torch.Tensor): + optimizer_state_dict['state'][idx][key] = value.to(device) optimizer.load_state_dict(optimizer_state_dict) @@ -2694,3 +2727,387 @@ def sync_grad_when(cond: bool): cond (bool): whether to synchronize gradients. """ return _runtime_flags(skip_reducer=not cond) + + +def _construct_parallel_module_stub(metadata): + pmodules = {prefix: ParallelModule._unpack(minfo) for prefix, minfo in metadata.items()} + + # whole parallel module + if len(pmodules) == 1 and list(pmodules.keys())[0] == '': + module = pmodules[''] + else: + module = torch.nn.Module() + for prefix, pmodule in pmodules.items(): + set_member_by_name(module, prefix, pmodule) + + # mock `named_modules` to list parallel modules in stub module + def named_modules( + memo=None, + prefix: str = "", + remove_duplicate: bool = True, + ): + assert memo is None and prefix == '' and remove_duplicate is True, \ + "Only support default arguments" + return pmodules.items() + + module.named_modules = named_modules + + return module + + +def _trim_module_merged_state_dict( + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + *, + device: Union[str, torch.device] = None, +): + device = device or torch.cuda.current_device() + + parallel_modules = {module_path: m for module_path, m in module.named_modules() if isinstance(m, ParallelModule)} + + trimmed_state_dict = {} + # collect non-parallel module parameters + for key, tensor in module_state_dict.items(): + parts = key.split('.') + if not any('.'.join(parts[:i]) in parallel_modules for i in range(0, len(parts))): + trimmed_state_dict[key] = tensor.to(device) + + for module_path, pmodule in parallel_modules.items(): + prefix = module_path + '.' if module_path else '' + trimmed_state_dict.update( + pmodule.trim_merged_state_dict( + pmodule.rank, module_state_dict, prefix=prefix, + device=device + ) + ) + return trimmed_state_dict + + +def _send_trimmed_module_state_dict( + trimmed_state_dict: Dict[str, torch.Tensor], + group: torch.distributed.ProcessGroup, + dst_rank: int, +): + """ + Send the trimmed state dict to the specified destination rank. + + Args: + trimmed_state_dict (Dict[str, torch.Tensor]): the trimmed state dict to send. + dst_rank (int): the destination rank to send the state dict to. + """ + # send trimmed state dict to rank + # one tensor each time + keys = list(trimmed_state_dict.keys()) + shape_dtypes = [(tensor.shape, tensor.dtype) for tensor in trimmed_state_dict.values()] + torch.distributed.send_object_list([keys, shape_dtypes], group=group, dst=dst_rank) + for key in keys: + tensor = trimmed_state_dict[key] + # NOTE: send is broken if the tensor is not contiguous + torch.distributed.send(tensor.contiguous(), group=group, dst=dst_rank) + + +def _receive_trimmed_module_state_dict( + src_rank: int, + group: torch.distributed.ProcessGroup, + device: Union[str, torch.device] = None, +): + """ + Receive the trimmed state dict from the specified source rank. + + Args: + src_rank (int): the source rank to receive the state dict from. + """ + device = device or torch.cuda.current_device() + + # receive trimmed state dict from rank + # one at a time + keys_shape_dtypes=[None, None] + torch.distributed.recv_object_list(keys_shape_dtypes, group=group, src=src_rank) + keys: list[str] = keys_shape_dtypes[0] + shape_dtypes: list[tuple[torch.Size, torch.dtype]] = keys_shape_dtypes[1] + + trimmed_state_dict = {} + for key, shape_dtype in zip(keys, shape_dtypes): + tensor = torch.zeros(shape_dtype[0], dtype=shape_dtype[1], device=device) + torch.distributed.recv(tensor, group=group, src=src_rank) + trimmed_state_dict[key] = tensor + return trimmed_state_dict + + +def _send_trimmed_opt_state_dict( + trimmed_opt_state_dict: OptStateDict, + group: torch.distributed.ProcessGroup, + dst_rank: int, +): + """ + Send the trimmed optimizer state dict to the specified destination rank. + + Args: + trimmed_opt_state_dict (OptStateDict): the trimmed optimizer state dict to send. + dst_rank (int): the destination rank to send the state dict to. + """ + # send trimmed optimizer state dict to rank + # one tensor each time + + # broadcast param groups and state keys/shapes/dtypes via broadcast_object_list + state_info = {} + state_keys = list(trimmed_opt_state_dict['state'].keys()) + param_group = trimmed_opt_state_dict['param_groups'] + for idx in state_keys: + state_info[idx] = {key: (value.shape, value.dtype) for key, value in trimmed_opt_state_dict['state'][idx].items()} + sent = [state_keys, state_info, param_group] + torch.distributed.send_object_list(sent, group=group, dst=dst_rank) + + # broadcast step in stack + if 'step' in trimmed_opt_state_dict['state'][state_keys[0]]: + step_stack = torch.stack( + [trimmed_opt_state_dict['state'][k]['step'] for k in state_keys] + ) + torch.distributed.send(step_stack, group=group, dst=dst_rank) + + # broadcast other states + # TODO: can be slow? + for k in state_keys: + keys = sorted(trimmed_opt_state_dict['state'][k].keys()) + if 'step' in keys: + keys.remove('step') # we have done step in previous. + for key in keys: + value = trimmed_opt_state_dict['state'][k][key] + torch.distributed.send(value.data, group=group, dst=dst_rank) + + +def _receive_trimmed_opt_state_dict( + src_rank: int, + group: torch.distributed.ProcessGroup, + device: Union[str, torch.device] = None, + ) -> OptStateDict: + """ + Receive the trimmed optimizer state dict from the specified source rank. + + Args: + src_rank (int): the source rank to receive the state dict from. + """ + device = device or torch.cuda.current_device() + + # receive trimmed optimizer state dict from rank + # one at a time + state_dict_info = [None, None, None] + torch.distributed.recv_object_list(state_dict_info, group=group, src=src_rank) + state_keys: list[str] = state_dict_info[0] + state_info: list[tuple[torch.Size, torch.dtype]] = state_dict_info[1] + param_group = state_dict_info[2] + + trimmed_opt_state_dict = { + 'state': {}, + 'param_groups': param_group + } + for key in state_keys: + trimmed_opt_state_dict['state'][key] = { + k: torch.zeros(v[0], dtype=v[1], device=device) + for k, v in state_info[key].items() + } + + # receive steps + if 'step' in trimmed_opt_state_dict['state'][state_keys[0]]: + step_stack = torch.zeros( + len(state_keys), + dtype=trimmed_opt_state_dict['state'][state_keys[0]]['step'].dtype, + device=device + ) + torch.distributed.recv(step_stack, group=group, src=src_rank) + for k, v in zip(state_keys, step_stack): + trimmed_opt_state_dict['state'][k]['step'].copy_(v) + + # receive other states + for k in state_keys: + keys = sorted(trimmed_opt_state_dict['state'][k].keys()) + if 'step' in keys: + keys.remove('step') # we have done step in previous. + for key in keys: + value = trimmed_opt_state_dict['state'][k][key] + torch.distributed.recv(value.data, group=group, src=src_rank) + + return trimmed_opt_state_dict + + +def trimmed_broadcast_merged_state_dict( + module: torch.nn.Module, + module_state_dict: Optional[Dict[str, Any]] = None, + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + device: Union[str, torch.device] = None, +) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + """ + trim merged state dict and broadcast to each rank. + + Args: + module (torch.nn.Module): the module to be loaded + module_state_dict (Dict[str, Any]): the merged model state dict + optimizer (Optional[torch.optim.Optimizer]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. + src_rank (int): the source rank to load the merged state dict from. + dst_ranks (Optional[list[int]]): the destination ranks to load the merged state dict to. + + Returns: + Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + the trimmed state dicts for the module and optimizer + """ + device = device or torch.cuda.current_device() + world_size = torch.distributed.get_world_size() + dst_ranks = dst_ranks or list(range(world_size)) + cur_rank = torch.distributed.get_rank() + + if cur_rank not in dst_ranks or src_rank not in dst_ranks: + raise ValueError( + f"Invalid rank configuration. Both current rank ({cur_rank}) and source rank ({src_rank}) " + f"must be in the destination ranks {dst_ranks}." + ) + + pg = DeviceGroup().get_group(dst_ranks) + + if cur_rank == src_rank: + if optimizer_state_dict and not optimizer: + raise ValueError("Optimizer must be provided when loading optimizer state dict.") + else: + if optimizer_state_dict or module_state_dict: + raise ValueError("Only the source rank can provide the merged state dicts.") + + rank_metadata = ( + {module_path: m._pack() for module_path, m in module.named_modules() if isinstance(m, ParallelModule)}, + optimizer._extra_state if optimizer else None, + ) + + rank_metadatas = [None] * len(dst_ranks) if cur_rank == src_rank else None + torch.distributed.gather_object(rank_metadata, rank_metadatas, group=pg, dst=src_rank) + + if cur_rank == src_rank: + will_load_opt_state = [optimizer_state_dict is not None] + else: + will_load_opt_state = [None] + torch.distributed.broadcast_object_list(will_load_opt_state, group=pg, src=src_rank) + will_load_opt_state = will_load_opt_state[0] + if will_load_opt_state and not optimizer: + raise ValueError("Optimizer must be provided when loading optimizer state dict.") + + ret = None + + if cur_rank == src_rank: + pmodule_stubs = [_construct_parallel_module_stub(r[0]) for r in rank_metadatas] + opt_extra_states = [r[1] for r in rank_metadatas] + for rank in dst_ranks: + if rank != cur_rank: + logger.info(f'At rank {src_rank}: Trimming module state dict for rank {rank}') + trimmed_module_state_dict = _trim_module_merged_state_dict( + pmodule_stubs[rank], + module_state_dict, + device=device, + ) + logger.info(f'At rank {src_rank}: Sending trimmed module state dict for rank {rank}') + _send_trimmed_module_state_dict(trimmed_module_state_dict, dst_rank=rank, group=pg) + del trimmed_module_state_dict + + if will_load_opt_state: + logger.info(f'At rank {src_rank}: Trimming optimizer state dict for rank {rank}') + trimmed_opt_state_dict = _trim_optimizer_merged_state_dict( + pmodule_stubs[rank], + opt_extra_states[rank], + optimizer_state_dict, + device=device, + ) + logger.info(f'At rank {src_rank}: Sending trimmed optimizer state dict for rank {rank}') + _send_trimmed_opt_state_dict(trimmed_opt_state_dict, dst_rank=rank, group=pg) + del trimmed_opt_state_dict + + torch.distributed.barrier(group=pg) + + # load for self after state dict for all other ranks are sent + # this can lower gpu memory peak + logger.info(f'At rank {src_rank}: Trimming module state dict for self rank {cur_rank}') + trimmed_module_state_dict = _trim_module_merged_state_dict( + pmodule_stubs[cur_rank], + module_state_dict, + device=device, + ) + if will_load_opt_state: + logger.info(f'At rank {src_rank}: Trimming optimizer state dict for self rank {cur_rank}') + trimmed_opt_state_dict = _trim_optimizer_merged_state_dict( + pmodule_stubs[cur_rank], + opt_extra_states[cur_rank], + optimizer_state_dict, + device=device, + ) + else: + trimmed_opt_state_dict = None + ret = (trimmed_module_state_dict, trimmed_opt_state_dict) + else: + for rank in dst_ranks: + if rank == cur_rank: + # receive state dict from src_rank + logger.info(f'At rank {cur_rank}: Receiving trimmed module state dict from rank {src_rank}') + trimmed_module_state_dict = _receive_trimmed_module_state_dict(src_rank, group=pg) + + if will_load_opt_state: + logger.info(f'At rank {cur_rank}: Receiving trimmed optimizer state dict from rank {src_rank}') + trimmed_opt_state_dict = _receive_trimmed_opt_state_dict(src_rank, group=pg) + else: + trimmed_opt_state_dict = None + + ret = (trimmed_module_state_dict, trimmed_opt_state_dict) + + torch.distributed.barrier(group=pg) + + assert ret is not None + # make it a sharded state dict. + for module_path, m in module.named_modules(): + prefix = module_path + '.' if module_path else '' + if isinstance(m, ParallelModule): + m._add_extra_state(ret[0], prefix) + return ret + + +def load_merged_state_dict_from_rank( + module: torch.nn.Module, + module_state_dict: Optional[Dict[str, Any]] = None, + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + device: Union[str, torch.device] = None, +): + """ + load the merged state dict from rank. + + Only src_rank will load merged state dict to memory (for saving memory), + and dst_ranks will receive the sharded state dict from src_rank via communication. + + Args: + module (torch.nn.Module): the module to be loaded + module_state_dict (Dict[str, Any]): the merged model state dict + optimizer (Optional[torch.optim.Optimizer]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. + src_rank (int): the source rank to load the merged state dict from. + dst_ranks (Optional[list[int]]): the destination ranks to load the merged state dict to. + + Returns: + None + """ + trimmed_module_state_dict, trimmed_opt_state_dict = trimmed_broadcast_merged_state_dict( + module, + module_state_dict, + optimizer, + optimizer_state_dict, + device=device, + src_rank=src_rank, + dst_ranks=dst_ranks, + ) + module.load_state_dict(trimmed_module_state_dict) + if trimmed_opt_state_dict: + optimizer.load_state_dict(trimmed_opt_state_dict) diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index d28cdf52..d842b702 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -11,6 +11,7 @@ from nnscaler.runtime.device import DeviceGroup from nnscaler.profiler.timer import CudaTimer from nnscaler.flags import RuntimeFlag +from nnscaler.utils import unchecked_fields _logger = logging.getLogger(__name__) @@ -407,6 +408,52 @@ def wake_up(self, param_buffer, grad_buffer): self._hooks = [] self.register_hooks() + def _pack( + self, + param_map: dict[torch.nn.Parameter, torch.nn.Parameter], + ): + """ + Get the information of the bucket. + """ + state = self.__dict__.copy() + + fields = unchecked_fields(self) + state[fields._params] = [param_map[p] for p in self._params] + state[fields._pofset] = {param_map[p]: ofst for p, ofst in self._pofset.items()} + state[fields._param_for_optimizer] = torch.nn.Parameter(torch.empty_like(self._param_for_optimizer, device='meta')) + state[fields._contiguous_params] = torch.empty_like(self._contiguous_params, device='meta') + state[fields._contiguous_grads] = torch.empty_like(self._contiguous_grads, device='meta') + + # remove torch handles + state.pop(fields._group, None) + state.pop(fields._async_handle, None) + state.pop(fields._async_param_cnt, None) + state.pop(fields._zero_subgroup, None) + state.pop(fields._zero_crossgroup, None) + + # remove hooks + state.pop(fields._hooks, None) + state.pop(fields._pre_hooks, None) + state.pop(fields._post_hooks, None) + + return state + + @classmethod + def _unpack(cls, state: dict): + """ + Return a fake bucket that carries the same information. + """ + bucket = object.__new__(cls) + bucket.__dict__.update(state) + + for param in bucket._params: + assert param.device.type == 'meta' + assert bucket._contiguous_grads.device.type == 'meta' + assert bucket._contiguous_grads.device.type == 'meta' + assert bucket._param_for_optimizer.device.type == 'meta' + + return bucket + class Reducer: # the default bucket cap for async reducer in megabytes @@ -593,11 +640,14 @@ def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None Typically this will be called when building optimizer when multiple optimizers/param groups are used. And we will put parameters with different optimizer or different param groups into different buckets. """ - self._param_clss = param_clss or {} - # sort parameters by their class - # which can help bucket building - if self._param_clss: + self._param_clss = {} + if param_clss: + # only keep parameters that are in self._params + self._param_clss = {p: param_clss[p] for p in self._params} + # sort parameters by their class + # which can help bucket building self._params.sort(key=lambda p: self._param_clss[p]) + for bucket in self._buckets: # rebuild bucket should be done before any hooks registered. if bucket._pre_hooks or bucket._post_hooks: @@ -842,3 +892,58 @@ def wake_up(self): self._contiguous_params[start:stop], self._contiguous_grads[start:stop], ) + + def _pack( + self, + param_map: dict[torch.nn.Parameter, torch.nn.Parameter], + ): + """ + Get the information of the bucket. + """ + state = self.__dict__.copy() + fields = unchecked_fields(self) + + state[fields._params] = [param_map[p] for p in self._params] + state[fields._param_clss] = {param_map[p]: param_cls for p, param_cls in self._param_clss.items()} + state[fields._contiguous_params] = torch.empty_like(self._contiguous_params, device='meta') + state[fields._contiguous_grads] = torch.empty_like(self._contiguous_grads, device='meta') + + state[fields._buckets] = [ + bucket._pack(param_map) + for bucket in self._buckets + ] + + # remove torch handles + state.pop(fields._group, None) + state.pop(fields._zero_subgroup, None) + state.pop(fields._zero_crossgroup, None) + + # remove unuseful information + state.pop(fields._param_ids, None) + state.pop(fields.seq_buckets, None) + + return state + + @classmethod + def _unpack(cls, state: dict): + """ + Return a fake bucket that carries the same information. + """ + reducer = object.__new__(cls) + fields = unchecked_fields(reducer) + + buckets = state.pop(fields._buckets) + reducer._buckets = [ + Bucket._unpack(bucket) for bucket in buckets + ] + reducer.__dict__.update(state) + for param in reducer._params: + assert param.device.type == 'meta' + + for param in reducer._param_clss.keys(): + assert param.device.type == 'meta' + + assert reducer._contiguous_grads.device.type == 'meta' + assert reducer._contiguous_params.device.type == 'meta' + + return reducer diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 4745da70..58512adf 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -27,7 +27,7 @@ from nnscaler import __version__ as runtime_version from nnscaler.flags import CompileFlag -from nnscaler.utils import accum_mode, classproperty +from nnscaler.utils import accum_mode, classproperty, unchecked_fields if TYPE_CHECKING: from nnscaler.parallel import ComputeConfig @@ -223,7 +223,7 @@ def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: def get_opt_params(self, prefix='', classify_param_cls_fn: Callable[[str], Any]=None) -> dict[torch.nn.Parameter, Any]: """ - Get all parameters and their classifications + Get all parameters and their classifications. Parameters in reducers come first. Args: prefix (str): The prefix of this module, @@ -232,7 +232,6 @@ def get_opt_params(self, prefix='', classify_param_cls_fn: Callable[[str], Any]= Returns: dict[torch.nn.Parameter, Any]: A dictionary mapping parameters to their classifications. - """ params = {} reducer_pids = set() @@ -945,7 +944,12 @@ def __init__(self): # track whether all the parames (especially the non-persistent buffers) have been initialized self._non_presistent_buffers_inited = False - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, skip_init=False, **kwargs): + # special case when we just fake a ParallelModule class + # In this case, you should also use object.__new__ instead of __init__ + if skip_init: + return + from nnscaler.parallel import ComputeConfig super().__init_subclass__(**kwargs) @@ -1423,18 +1427,60 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s Raises: RuntimeError: if strict=True and there are missing keys. """ - - dist2param = self.dist_param_map - orig_param_names = list(dist2param.values()) # param names in original module (without prefix) non_persistent_buffers = self.get_non_persistent_buffers() with torch.no_grad(): # avoid checking the non-persistent buffers attr_names = set([attr for attr in self._fullmap.keys() if attr not in non_persistent_buffers]) - origname_tid_map = {meta.orig_name: meta.tid for meta in self._fullmap.values()} + for prefix_attr, content in self.trim_merged_state_dict(self.rank, state_dict, prefix).items(): + attr = prefix_attr[len(prefix):] + tensor: torch.Tensor = getattr(self, attr) + tensor.copy_(content) + attr_names.remove(attr) + + missing_keys = [prefix + self._fullmap[attr].orig_name for attr in attr_names] + if len(attr_names) != 0: + erro_msg = f'Missing key(s) in state_dict: {missing_keys}.' + if strict: + raise RuntimeError(erro_msg) + else: + _logger.warning(erro_msg) + + self._warn_uninitialized_non_persistent_buffers() + return missing_keys + + @classmethod + def trim_merged_state_dict( + cls, + rank, + state_dict: Dict[str, Any], + prefix: str = '', + *, + device=None, + ) -> Dict[str, Any]: + """ + Trim the merged state dict to only keep the parameters needed for the module. + Please note we don't check missing/unexpected keys. + + Args: + state_dict (Dict[str, Any]): the merged state dict + prefix (str): the prefix of the model state dict in the merged state dict + + Returns: + Dict[str, Any]: the trimmed state dict + """ + device = device or torch.cuda.current_device() + trimmed_state_dict = {} + + dist2param = cls.dist_param_map + orig_param_names = list(dist2param.values()) # param names in original module (without prefix) + attr_meta_map = cls.get_attr_meta_map(rank) + with torch.no_grad(): + # avoid checking the non-persistent buffers + origname_tid_map = {meta.orig_name: meta.tid for meta in attr_meta_map.values()} tid_info = defaultdict(list) - for attr, meta in self._fullmap.items(): + for attr, meta in attr_meta_map.items(): tid_info[meta.tid].append((attr, meta.slicers, meta.val_chunks)) # multiple params may share the same tid for orig_param_name in orig_param_names: @@ -1447,20 +1493,61 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s param_value = state_dict[orig_param_name_with_prefix] tid = origname_tid_map[orig_param_name] for attr, slicer, nchunks in tid_info[tid]: - tensor: torch.Tensor = getattr(self, attr) content = param_value[slicer] if nchunks != 1: content = content / nchunks - tensor.copy_(content) - attr_names.remove(attr) + trimmed_state_dict[prefix + attr] = content.to(device) - missing_keys = [prefix + self._fullmap[attr].orig_name for attr in attr_names] - if len(attr_names) != 0: - erro_msg = f'Missing key(s) in state_dict: {missing_keys}.' - if strict: - raise RuntimeError(erro_msg) - else: - _logger.warning(erro_msg) + return trimmed_state_dict - self._warn_uninitialized_non_persistent_buffers() - return missing_keys + def _pack( + self, + ): + """ + Get a packed information of the ParallelModule, so it can be sent to other ranks. + """ + param_map: dict[torch.nn.Parameter, torch.nn.Parameter] = {} + for p in self.parameters(): + param_map[p] = torch.nn.Parameter( + torch.empty_like(p, device='meta')) if p is not None else None + for b in self.buffers(): + param_map[b] = torch.empty_like( + b, device='meta') if b is not None else None + state = {} + fields = unchecked_fields(self) + state[fields._parameters] = {n: param_map[p] for n, p in self._parameters.items()} + state[fields._buffers] = {n: param_map[b] for n, b in self._buffers.items()} + state[fields._reducers] = [reducer._pack(param_map) for reducer in self._reducers] + state[fields._zero_metadata] = self._zero_metadata + state[fields._fullmap] = self._fullmap + + for cv in ParallelModule.__annotations__: + state[cv] = getattr(self, cv) + return state + + @classmethod + def _unpack(cls, state: dict): + """ + Unpack the information and return a fake ParallelModule that carries the same information. + """ + class GenModelX(ParallelModule, skip_init=True): + pass + pm = object.__new__(GenModelX) + fields = unchecked_fields(pm) + object.__setattr__(pm, fields._parameters, state[fields._parameters]) + object.__setattr__(pm, fields._buffers, state[fields._buffers]) + object.__setattr__(pm, fields._reducers, [Reducer._unpack(reducer) for reducer in state[fields._reducers]]) + object.__setattr__(pm, fields._zero_metadata, state[fields._zero_metadata]) + object.__setattr__(pm, fields._fullmap, state[fields._fullmap]) + + def named_parameters( + prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ): + assert prefix == "" and recurse is True, "Only support default arguments" + return pm._parameters.items() + + pm.named_parameters = named_parameters + + for cv in ParallelModule.__annotations__: + setattr(GenModelX, cv, state[cv]) + return pm diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 21b8d3f1..4f254c7c 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -112,6 +112,27 @@ def get_member_by_name(model: torch.nn.Module, name: str) -> Any: return model_attr +def set_member_by_name(model: Any, name: str, value: Any) -> None: + """ + Set the member of the model by its full name. + """ + if not name: + raise ValueError("Name cannot be empty") + class _ValueHolder: + """ + A value holder. + In python you can't call `setattr` on object, but you can call it on its subclasses. + """ + pass + sliced_names = name.split(".") + model_attr = model + for sliced_name in sliced_names[:-1]: + if not hasattr(model_attr, sliced_name): + setattr(model_attr, sliced_name, _ValueHolder()) + model_attr = getattr(model_attr, sliced_name) + setattr(model_attr, sliced_names[-1], value) + + def get_shared_params(model: torch.nn.Module) -> List[List[str]]: paramid2name = defaultdict(set) for name in model.state_dict().keys(): @@ -325,6 +346,20 @@ def fields(model: TDataClass, /) -> TDataClass: return cast(TDataClass, _GetFields(model)) +class _UncheckedFields: + def __getattr__(self, item: str) -> Any: + return item + + +TUncheckedClass = TypeVar("TAnyClass") +def unchecked_fields(_: TUncheckedClass, /) -> TUncheckedClass: + """ + This function is used to get the field names(in str) of any object without checking + This is a workaround for the lack of `__name__` of member. + """ + return cast(TUncheckedClass, _UncheckedFields()) + + @cache def load_type(type_name: str): """ @@ -461,7 +496,7 @@ def steps(nsteps: int): class AdamOptState(TypedDict): - step: int + step: torch.Tensor exp_avg: torch.Tensor exp_avg_sq: torch.Tensor diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 0ae31da9..298652ad 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -987,19 +987,45 @@ def trainer_resumable_dataloader(save_dir): torch.distributed.barrier() # resume for merged - trainer = Trainer([ - '-f', config_path_streaming, - '--precision', 'bf16', - '--optimizer.type', optimizer_type, - '--enable_progress_bar', 'false', - '--gen_savedir', str(gen_savedir), - '--checkpoint.save_type', save_type, - '--checkpoint.save_dir', str(ckpt2_savedir), - '--checkpoint.resume_from', str(ckpt2_savedir / 'merged.pt'), - '--checkpoint.keep_last_n_checkpoints', '30', - ]) - trainer.run() - assert trainer.dataloader_resumed + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt2_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt2_savedir / 'merged.pt'), + '--checkpoint.resume_from.save_memory', False, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' not in log.getvalue() # no warning about dataloader states + + torch.distributed.barrier() + + + ckpt2_1_savedir = save_dir / 'ckpt2_1' + ckpt2_1_savedir.mkdir(parents=True, exist_ok=True) + # resume for merged + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt2_1_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt2_savedir / 'merged.pt'), + '--checkpoint.resume_from.save_memory', True, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states torch.distributed.barrier() @@ -1032,20 +1058,44 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.save_dir', str(ckpt4_savedir), '--checkpoint.resume_from.checkpoint', str(ckpt1_savedir / '0002-0035'), '--checkpoint.resume_from.with_merged', True, + '--checkpoint.resume_from.save_memory', False, '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.run() assert trainer.dataloader_resumed assert 'Broadcasting merged checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states + # resume from auto-merged with save_memory + ckpt5_savedir = save_dir / 'ckpt5' + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt5_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt1_savedir / '0002-0035'), + '--checkpoint.resume_from.with_merged', True, + '--checkpoint.resume_from.save_memory', True, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states + + if torch.distributed.get_rank() == 0: for i in range(4): g = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) z = torch.load(ckpt2_savedir / 'last' / f'{i}.ckpt', weights_only=False) + z_1 = torch.load(ckpt2_1_savedir / 'last' / f'{i}.ckpt', weights_only=False) w = torch.load(ckpt3_savedir / 'last' / f'{i}.ckpt', weights_only=False) v = torch.load(ckpt4_savedir / 'last' / f'{i}.ckpt', weights_only=False) + u = torch.load(ckpt5_savedir / 'last' / f'{i}.ckpt', weights_only=False) assert 'dataloader' not in g assert 'dataloader' in x for key in ['model', 'optimizer', 'lr_scheduler', 'dataloader']: @@ -1053,6 +1103,8 @@ def trainer_resumable_dataloader(save_dir): assert_equal(x[key], z[key]) assert_equal(x[key], w[key]) assert_equal(x[key], v[key]) + assert_equal(x[key], u[key]) + assert_equal(x[key], z_1[key]) if key != 'dataloader': assert_equal(g[key], x[key]) diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 6faee640..59510180 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -17,11 +17,18 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dict +from nnscaler.parallel import ( + ComputeConfig, parallelize, + build_optimizer, + merge_state_dicts, + load_merged_state_dict, + load_merged_state_dict_from_rank, + trimmed_broadcast_merged_state_dict, +) from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import CubeLinear, init_random, init_distributed, PASMegatron +from .common import CubeLinear, init_random, init_distributed, PASMegatron, assert_equal from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively from ..utils import replace_all_device_with, clear_dir_on_rank0 @@ -345,6 +352,23 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf optimizer_from_merged, merged_opt_state_dict, ) + model_from_merged_rank = type(model)() + optimizer_from_merged_rank = build_optimizer(model_from_merged_rank, torch.optim.Adam, lr=0.01) + load_merged_state_dict_from_rank( + model_from_merged_rank, merged_model_state_dict if torch.distributed.get_rank() == 0 else None, + optimizer_from_merged_rank, merged_opt_state_dict if torch.distributed.get_rank() == 0 else None, + ) + assert_equal(model_from_merged.state_dict(), model_from_merged_rank.state_dict()) + assert_equal(optimizer_from_merged.state_dict(), optimizer_from_merged_rank.state_dict()) + + trimmed_model_state_dict, trimmed_opt_state_dict = trimmed_broadcast_merged_state_dict( + model_from_merged_rank, merged_model_state_dict if torch.distributed.get_rank() == 0 else None, + optimizer_from_merged_rank, merged_opt_state_dict if torch.distributed.get_rank() == 0 else None, + ) + assert_equal(dict(model_from_merged.state_dict()), trimmed_model_state_dict) + assert_equal(optimizer_from_merged.state_dict()['state'], trimmed_opt_state_dict['state']) + assert_equal(optimizer_from_merged.state_dict()['param_groups'], trimmed_opt_state_dict['param_groups']) + # check merged model result_orig_model_state_dict = model.state_dict() result_merged_model_state_dict = model_from_merged.state_dict() diff --git a/tests/runtime/test_hybrid_optimizer.py b/tests/runtime/test_hybrid_optimizer.py index e22d8f72..65b6b067 100644 --- a/tests/runtime/test_hybrid_optimizer.py +++ b/tests/runtime/test_hybrid_optimizer.py @@ -93,17 +93,19 @@ def trainer_worker(save_dir, use_zero): assert_equal(x['model'], y['model']) assert_equal(x['optimizer'], y['optimizer']) - trainer = Trainer([ + # train with different config + trainer_config = [ '-f', config_path, '--compute_config.plan_ngpus', '2', '--pas_policy', 'tp', '--max_train_steps', '30', '--checkpoint.resume_from.checkpoint', 'last', - '--checkpoint.resume_from.with_merged', True, + '--checkpoint.resume_from.with_merged', str(True), '--gen_savedir', str(gen_savedir), '--checkpoint.save_dir', str(ckpt0_savedir), '--compute_config.use_zero', str(not use_zero), - ]) + ] + trainer = Trainer(trainer_config) trainer.run() torch.distributed.barrier() if torch.distributed.get_rank() == 0: @@ -113,17 +115,36 @@ def trainer_worker(save_dir, use_zero): torch.distributed.barrier() - # trainer = Trainer([ - # '-f', config_path, - # '--compute_config.plan_ngpus', '1', - # '--max_train_steps', '40', - # '--checkpoint.resume_from.checkpoint', 'last', - # '--checkpoint.resume_from.with_merged', True, - # '--gen_savedir', str(gen_savedir), - # '--checkpoint.save_dir', str(ckpt0_savedir), - # ]) - # trainer.run() - # torch.distributed.barrier() + from subprocess import check_call as _call + from functools import partial + call = partial(_call, shell=True) + + if torch.distributed.get_rank() == 0: + call(f"python -m nnscaler.cli.checkpoint distribute {ckpt1_savedir}/last {ckpt1_savedir}/sharded {' '.join(trainer_config)} --compute_config.runtime_ngpus {torch.distributed.get_world_size()}") + + torch.distributed.barrier() + + trainer = Trainer([ + '-f', config_path, + '--compute_config.plan_ngpus', '2', + '--pas_policy', 'tp', + '--max_train_steps', '30', + '--checkpoint.resume_from.checkpoint', f'{ckpt1_savedir}/sharded', + '--checkpoint.resume_from.with_merged', str(False), + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--compute_config.use_zero', str(not use_zero), + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + + torch.distributed.barrier() @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') diff --git a/tests/test_utils.py b/tests/test_utils.py index 7fa7d80a..a92c36c1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,8 +3,9 @@ from dataclasses import dataclass import pytest +import torch -from nnscaler.utils import select_many, classproperty, fields +from nnscaler.utils import select_many, classproperty, fields, set_member_by_name, unchecked_fields def test_select_many(): @@ -53,3 +54,41 @@ class A: assert fields(A).y == 'y' with pytest.raises(AttributeError): fields(A).z + + assert unchecked_fields(A).x == 'x' + assert unchecked_fields(A).y == 'y' + assert unchecked_fields(A).z == 'z' + + a = A(x=0, y=0) + assert unchecked_fields(a).x == 'x' + assert unchecked_fields(a).y == 'y' + assert unchecked_fields(a).z == 'z' + + class B: + def __init__(self): + self.a = A(x=1, y=2) + + assert unchecked_fields(B).x == 'x' + b = B() + assert unchecked_fields(b).x == 'x' + assert unchecked_fields(b.a).x == 'x' + + +def test_set_member_by_name(): + model = torch.nn.Module() + set_member_by_name(model, "x", 42) + assert model.x == 42 + with pytest.raises(AttributeError): + set_member_by_name(model, 'x.y.z', 43) + + set_member_by_name(model, 'a.b.c', 44) + assert model.a.b.c == 44 + + model = torch.nn.Module() + child_module = torch.nn.Module() + set_member_by_name(model, "x.y", child_module) + assert model.x.y == child_module + + set_member_by_name(model, 'x.y.z', 45) + assert model.x.y == child_module + assert model.x.y.z == 45 From e15b7f1b995f25ef956ee67981b936ef54fc7228 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 12 Sep 2025 03:37:44 +0000 Subject: [PATCH 1839/1892] Merged PR 2406: [Runtime] Allow custom parameter bucketing Promote `param_clss_fn` to optimizer config 1. Allow users to customize how to group parameters into buckets. This is useful in some cases, for example, in MoE, the user can put experts into different buckets 2. Unify the interface of hybrid optimizer --- nnscaler/cli/trainer.py | 2 +- nnscaler/cli/trainer_args.py | 13 +++++-- nnscaler/parallel.py | 38 +++++++++---------- nnscaler/runtime/hybrid_optimizer.py | 2 - tests/cli/test_trainer.py | 2 +- .../test_hybrid_optimizer_trainer_args.yaml | 2 +- 6 files changed, 31 insertions(+), 28 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 994932b2..41b1132f 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -163,7 +163,7 @@ def _setup(self): pmodel = parallelize_model( self.train_args, self.dummy_input, load_module=not compile_only, - build_buckets=not self.train_args.is_hybrid_optimizer() + build_buckets=not self.train_args.should_delay_bucket_building() ) if compile_only: return diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 0c368aab..6580d0f9 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -20,7 +20,7 @@ import torch import nnscaler -from nnscaler.utils import fields, transform_recursively, load_type +from nnscaler.utils import fields, fn_field, transform_recursively, load_type from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule @@ -314,6 +314,7 @@ class OptimizerConfig: args: Dict[str, Any] = field(default_factory=dict) clip_gnorm: float = 0.0 + param_clss_fn: Optional[Callable[[str], Any]] = fn_field(default=None) # loss reduction method # mean: average the loss over all micro-batches # sum: sum the loss of all micro-batches @@ -868,13 +869,17 @@ def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model.args) return self.model_type(**kwargs) - def is_hybrid_optimizer(self) -> bool: - return getattr(load_type(self.optimizer.type), 'is_hybrid', False) + def should_delay_bucket_building(self) -> bool: + return self.optimizer.param_clss_fn is not None def create_parallel_optimizer(self, parallel_model: torch.nn.Module): kwargs = self.create_kwarg(self.optimizer.args) optimizer_class = load_type(self.optimizer.type) - return build_optimizer(parallel_model, optimizer_class, self.compute_config, **kwargs) + return build_optimizer( + parallel_model, optimizer_class, self.compute_config, + self.optimizer.param_clss_fn, + **kwargs + ) def create_dataset(self, stage='train'): dataset_args = getattr(self.dataset, f'{stage}_args') diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index d6a170be..e6e75259 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1270,7 +1270,7 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] def hybrid( params: list[torch.nn.Parameter], - param_clss_fn: Callable[[str], tuple[int, int]], + param_clss: dict[torch.nn.Parameter, tuple[int, int]], **kwargs, ) -> HybridOptimizerT: """ @@ -1280,9 +1280,9 @@ def hybrid( def __init__(self, params, param_clss, **kwargs): ... ``` - But when you pass arguments to `build_optimizer` - You must replace `param_clss` with `param_clss_fn`, - And `build_optimizer` will automatically replace `param_clss_fn` with the actual `param_clss`. + When you pass arguments to `build_optimizer` + You must pass `param_clss_fn`, + and `build_optimizer` will automatically pass `param_clss` to its constructor. """ ... hybrid.is_hybrid = True # mark this function as hybrid optimizer factory @@ -1292,6 +1292,7 @@ def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], compute_config: Optional[ComputeConfig] = None, + param_clss_fn: Optional[Callable[[str], Any]] = None, **kwargs, ) -> Union[OptimizerT, ParallelOptimizer]: """ @@ -1319,6 +1320,11 @@ def build_optimizer( compute_config (Optional[ComputeConfig]): The config will be used to generate communication reducer. If it is None, Default configuration will be used when creating reducer for non-parallel modules. + param_clss_fn (Optional[Callable[[str], Any]]): + A function that maps original full qualified parameter names to their class IDs. + If you are using a hybrid optimizer, + you must specify this function + and the return value of this function must be a tuple[int, int] of (optimizer_index, param_group_index). **kwargs: the kwargs for optimizer constructor Returns: @@ -1327,9 +1333,6 @@ def build_optimizer( and will be patched with the methods in ParallelModule class to support parallelized module. Please note the type annotation of the returned optimizer (`Union[OptimizerT, ParallelOptimizer]`) is just for intellisense. """ - - PARAM_CLSS_FN_NAME = 'param_clss_fn' - if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): raise RuntimeError("Old style CubeModule is not supported") @@ -1337,13 +1340,9 @@ def build_optimizer( if any(m != module and isinstance(m, ParallelModule) and m.compute_config.use_end2end for m in module.modules()): raise RuntimeError("End2End module cannot be nested in another module") - is_hybrid = False - if getattr(optimizer_fn, 'is_hybrid', False): - if PARAM_CLSS_FN_NAME not in kwargs: - raise ValueError("param_clss_fn must be provided when using hybrid optimizer") - # syntax sugar - kwargs[PARAM_CLSS_FN_NAME] = load_type(kwargs[PARAM_CLSS_FN_NAME]) - is_hybrid = True + is_hybrid = getattr(optimizer_fn, 'is_hybrid', False) + if is_hybrid and param_clss_fn is None: + raise ValueError("param_clss_fn must be provided when using hybrid optimizer") RuntimeFlag.skip_reducer = True RuntimeFlag.skip_zero_grad = False @@ -1374,8 +1373,9 @@ def build_optimizer( f'{module_prefix}.{original_name}' if module_prefix else original_name else: param_original_names[p] = n - if is_hybrid: - param_clss = {p: kwargs[PARAM_CLSS_FN_NAME](n) for p, n in param_original_names.items()} + + if param_clss_fn: + param_clss = {p: param_clss_fn(n) for p, n in param_original_names.items()} else: param_clss = {} @@ -1410,7 +1410,7 @@ def build_optimizer( non_parallel_module_reducer.add_param(param) non_parallel_module_reducer.build_buckets(param_clss=param_clss) - if is_hybrid: + if param_clss_fn: for pm in parallel_modules: pm.build_buckets(param_clss=param_clss) for reducer in pm.reducers: @@ -1449,8 +1449,8 @@ def _local_parameters(module: torch.nn.Module): if is_hybrid: optimizer = optimizer_fn(_local_parameters(module), - param_clss=param_clss, - **{k: v for k, v in kwargs.items() if k != PARAM_CLSS_FN_NAME} + param_clss, + **kwargs ) else: optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), **kwargs) diff --git a/nnscaler/runtime/hybrid_optimizer.py b/nnscaler/runtime/hybrid_optimizer.py index e860f56d..cd0379a4 100644 --- a/nnscaler/runtime/hybrid_optimizer.py +++ b/nnscaler/runtime/hybrid_optimizer.py @@ -79,8 +79,6 @@ def __init__( Args: params (Iterable[torch.nn.Parameter]): The parameters to optimize. param_clss (dict[torch.nn.Parameter, tuple[int, int]]): The parameter classes for each parameter. - Please replace this argument with `param_clss_fn` (Callable[[str], tuple[int, int]]) - when you use creating it with `nnscaler.build_optimizer` (including cli trainer). config (Union[HybridOptConfig, dict[str, Any]]): The configuration for the hybrid optimizer. """ params = list(params) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 298652ad..a837db29 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -694,8 +694,8 @@ def param_clss_fn(param_name: str) -> tuple[int, int]: optimizer_config = { 'type': 'nnscaler.HybridOptimizer', + 'param_clss_fn': param_clss_fn, 'args': { - 'param_clss_fn': param_clss_fn, 'config': { 'optimizers':[ { diff --git a/tests/runtime/test_hybrid_optimizer_trainer_args.yaml b/tests/runtime/test_hybrid_optimizer_trainer_args.yaml index 1484fe55..b84c4870 100644 --- a/tests/runtime/test_hybrid_optimizer_trainer_args.yaml +++ b/tests/runtime/test_hybrid_optimizer_trainer_args.yaml @@ -21,8 +21,8 @@ model: optimizer: type: nnscaler.HybridOptimizer + param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn args: - param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn config: optimizers: - type: torch.optim.Adam From e71c869e7b9eff6440b527c39a58f18bad70e446 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 15 Sep 2025 02:01:56 +0000 Subject: [PATCH 1840/1892] Merged PR 2407: [BwCompat]: add backward compatibiity for cli checkpoint config Allow checkpoint.resume_from to be str --- nnscaler/cli/trainer_args.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 6580d0f9..2bf23bf2 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -458,6 +458,10 @@ def resolved_convert_fn(self) -> Optional[Callable[[Dict[str, Any]], Dict[str, A return load_type(self.resume_from.convert_fn) def __post_init__(self): + # backward compatibility + if isinstance(self.resume_from, str): + self.resume_from = ResumeOptions(checkpoint=self.resume_from) + if self.resume_from and self.resume_from.checkpoint: if self.resume_from.checkpoint in ['last', 'best']: if not self.save_dir: From fa5ec8b471370c3f9ac564fa58b071b69cb13525 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Wed, 17 Sep 2025 12:16:00 +0000 Subject: [PATCH 1841/1892] Merged PR 2408: [Runtime] Bugfix: calculate gnorm correctly when existing different precision params --- nnscaler/runtime/gnorm.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nnscaler/runtime/gnorm.py b/nnscaler/runtime/gnorm.py index eb6a3e5b..5f69b059 100644 --- a/nnscaler/runtime/gnorm.py +++ b/nnscaler/runtime/gnorm.py @@ -40,7 +40,7 @@ class TidReplicaInfo: def _calc_grad_shape(slicers_list): - # caculate the shape of each full parameters/grads + # calculate the shape of each full parameters/grads tid2shape = {} for rank_slicers in slicers_list: for tid, slicers in rank_slicers.items(): @@ -50,7 +50,7 @@ def _calc_grad_shape(slicers_list): # slicer: (start, end, step) if slicer.stop > tid2shape[tid][i]: tid2shape[tid][i] = slicer.stop - # caculate the number of replicas of each model parameter + # calculate the number of replicas of each model parameter tid2nreplicas = {} for rank_slicers in slicers_list: for tid, slicers in rank_slicers.items(): @@ -117,7 +117,7 @@ def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int Returns: tid2nreplicas: dict, tid -> TidReplicaInfo """ - # caculate the number of replicas of each model parameter + # calculate the number of replicas of each model parameter tid2nreplicas = {} tid2ranksset = defaultdict(set) for tid2ranks in tid2ranks_list: @@ -241,7 +241,8 @@ def grad_exists(p): elif len(grads) == 1: total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) else: - if multi_tensor_l2norm_available: + dtypes = set([g.dtype for g in grads]) + if multi_tensor_l2norm_available and len(dtypes) == 1: total_norm = _multi_tensor_total_norm(grads).to(device) else: # torch.nn.utils.clip_grad_norm_ way to calculate the norm From eeae4850028bd373ed45ddb0e99712fa50938efa Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 19 Sep 2025 00:33:49 +0000 Subject: [PATCH 1842/1892] Merged PR 2409: [Refine] Parallelize: skip graph.forward call to simplify code There are two things that graph.forward does for parallelize api: 1. input naming (from node.name to node.target). This part moves to parser. 2. Force input gradient. This is already done in parser. So we don't need to call it again. --- nnscaler/graph/parser/parser.py | 10 +++++--- nnscaler/parallel.py | 35 +++------------------------- tests/compiler/test_compile.py | 3 +-- tests/graph/parser/test_converter.py | 4 ++-- tests/utils.py | 6 ++--- 5 files changed, 16 insertions(+), 42 deletions(-) diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 02c52611..3cb92257 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -111,7 +111,7 @@ def parse(module: torch.fx.GraphModule, # it should be wrapped into an IRObject for idx, placeholder in enumerate(placeholders): if not isinstance(inputs[idx], IRObject): - obj = IRObject(name=placeholder.name, value=inputs[idx], is_constant=False) + obj = IRObject(name=placeholder.target, value=inputs[idx], is_constant=False) inputs[idx] = obj frame.set_var(placeholder.name, obj) @@ -160,9 +160,13 @@ def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, assert hasattr(node, 'meta') and 'tensor_meta' in node.meta, f"Node {node} should have tensor_meta" meta = node.meta['tensor_meta'] - val = IR.new(node.name, meta, + val = IR.new( + # node.target is necesssary for input + # its name will be used to align with model forward args when generating code. + node.target if node.op == 'placeholder' else node.name, + meta, tensor_types=(TensorMetadata,), - is_constant=is_constant + is_constant=is_constant, ) frame.add_var(node.name, val) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index e6e75259..06a6af84 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -665,51 +665,22 @@ def _gen_graph( node.target: forward_args_default.get(node.target, inspect.Parameter.empty) for node in fx_input_nodes } - ir_dummy_inputs = [] - for node in fx_input_nodes: - if node.target.startswith('*'): # *args or **kwargs - if node.target.strip('*') in dummy_forward_args: - raise ValueError(f"Input {node.target}: *args or **kwargs is not suppported") - ir_dummy_inputs.append(None) # always set None to *args/**kwargs - elif node.target in dummy_forward_args: - ir_dummy_inputs.append(dummy_forward_args[node.target]) - elif forward_args[node.target] is not inspect.Parameter.empty: - ir_dummy_inputs.append(forward_args[node.target]) - else: - raise ValueError(f"Input {node.target} not in dummy forward args, nor has default value.") - for i in range(len(ir_dummy_inputs)): - # note: we will always set tensor to require gradient, which may - # generate backward communications in adapter. However, as long as - # the data doesn't require gradient in real runtime, the backward - # communication will not be triggered. - ir_dummy_inputs[i] = IR.new( - fx_input_nodes[i].target, ir_dummy_inputs[i], - requires_grad=True, - tosub=True, - is_constant=False, - ) - # if the input is a complex type, we should wrap it with IRObject - if not isinstance(ir_dummy_inputs[i], IRObject): - ir_dummy_inputs[i] = IRObject(fx_input_nodes[i].target, value=ir_dummy_inputs[i], is_constant=False) - # generate complete ir graph - ir_dummy_outputs = graph(*ir_dummy_inputs) if end2end_mode: # in end2end mode, we must use dataloader as the first argument of forward # we assume the first argument of forward is the data sample (which is a requirement in our doc) graph.use_dataloader_input() # we require the first output is the loss - if isinstance(ir_dummy_outputs, (list, tuple)): - ir_loss = ir_dummy_outputs[0] - else: - ir_loss = ir_dummy_outputs + ir_loss = graph.output(0) if not isinstance(ir_loss, IRTensor) or ir_loss.shape != (1,): # internally scalar tensor will be reshaped to (1,) in IRGraph raise RuntimeError(f"Loss can only be scalar tensor but got {ir_loss.shape if isinstance(ir_loss, IRTensor) else ir_loss}") else: ir_loss = None + # we generate backward nodes and setup gradient tensors here + # forward nodes are done when we trace the model if not inference_only: graph.backward(ir_loss) else: diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index c11d7a23..851d36d6 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -186,8 +186,7 @@ def train_iter(model, dataloader): # tensor parallelism + scale test test_tp2scale2 = partial(torchrun, 4, assert_parity, baseline, - partial(cube_run, 2, tp_policy), - 0.001, + partial(cube_run, 2, tp_policy) ) # pipeline parallelism test diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 20f2bcff..f15fcc29 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -52,9 +52,9 @@ def forward(self, x, **kwargs): assert ir_graph.name == 'MyModule' inputs = ir_graph.inputs() assert len(inputs) == 2 - assert inputs[0].name == nodes[0].name + assert inputs[0].name == nodes[0].target assert isinstance(inputs[0], IRTensor) - assert inputs[1].name == nodes[1].name + assert inputs[1].name == nodes[1].target assert isinstance(inputs[1], IRObject) outputs = ir_graph.outputs() diff --git a/tests/utils.py b/tests/utils.py index 22036f42..07e88497 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -58,7 +58,7 @@ def init_random(seed: int = 1): torch.cuda.manual_seed(seed) -def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-4) -> bool: +def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-3) -> bool: """Compare the output of baseline_fn and compile_fn Error will raise if the output of two functions are not the same. @@ -92,10 +92,10 @@ def assert_same_complex(gt, out): assert_same_complex(gt[key], out[key]) elif isinstance(gt, torch.Tensor): assert isinstance(out, torch.Tensor) - assert torch.allclose(gt, out, atol=atol), f'mismatched: {gt} != {out}' + assert torch.allclose(gt, out, atol=atol), f'mismatched (with atol {atol}): {gt} != {out}' elif isinstance(gt, float): assert isinstance(out, float) - assert math.isclose(gt, out, abs_tol=atol), f'mismatched: {gt} != {out}' + assert math.isclose(gt, out, abs_tol=atol), f'mismatched (with atol {atol}): {gt} != {out}' else: assert gt == out, f'mismatched: {gt} != {out}' assert_same_complex(baseline_outputs, compile_outputs) From 9c9b58f0fad71c641cb2b6a02d2f11b8ee64ddc6 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 13 Oct 2025 06:06:31 +0000 Subject: [PATCH 1843/1892] Merged PR 2411: [Feature] add value tracker add value tracker for IRObject and IRTensor dims. Here is an example of nagogpt (with constant_folding=False): ![image.png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2411/attachments/image.png) --- nnscaler/graph/function/dimops.py | 4 +- nnscaler/graph/function/function.py | 31 +++- nnscaler/graph/graph.py | 5 +- nnscaler/graph/parser/parser.py | 52 +++++-- nnscaler/graph/parser/value_tracker.py | 171 +++++++++++++++++++++++ nnscaler/ir/cten.py | 129 +++++++++++++++-- nnscaler/ir/tensor.py | 19 ++- nnscaler/ir/unique.py | 14 +- nnscaler/utils.py | 5 + tests/graph/parser/test_converter.py | 33 ++++- tests/graph/parser/test_parser.py | 2 +- tests/graph/parser/test_value_tracker.py | 105 ++++++++++++++ tests/parallel_module/test_gencode.py | 13 +- utility/visualize_value_tracks.py | 158 +++++++++++++++++++++ 14 files changed, 696 insertions(+), 45 deletions(-) create mode 100644 nnscaler/graph/parser/value_tracker.py create mode 100644 tests/graph/parser/test_value_tracker.py create mode 100644 utility/visualize_value_tracks.py diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index aaa2f5e3..827c3ed2 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -72,7 +72,7 @@ import logging from itertools import dropwhile -from nnscaler.ir.cten import IRTensor, IRObject +from nnscaler.ir.cten import IRTensor, IRObject, ValueTrack from nnscaler.ir.operator import IRFwOperation @@ -753,7 +753,7 @@ def ianno(self, index: int) -> ShapeAnno: @return dim_annos ShapeAnno: a tuple that each element is a dimension annotation """ assert index < len(self.inputs()), "index out of boudary" - return tuple(self._iannos[index]) + return self._iannos[index] def oanno(self, index: int) -> ShapeAnno: """! diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 9adff8fb..66cfff2a 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -34,7 +34,7 @@ import logging from collections.abc import Iterable -from nnscaler.ir.cten import IRTensor, IRObject, IR +from nnscaler.ir.cten import IRTensor, IRObject, IR, ValueTrack from nnscaler.ir.tensor import IRSubTensor, IRFullTensor from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule @@ -270,7 +270,21 @@ def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Uni size = (math.ceil((end_val-start_val)/step_val),) anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), False) - return IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) + + # Output will be replaced in Parser, + # Here we just pass the value tracks out + output = IRFullTensor(size) + if not isinstance(start, IRObject) and start == 0 \ + and not isinstance(step, IRObject) and step == 1 \ + and isinstance(end, IRObject): + # a special case for arange(0, end), which is very common in practice + # we can directly use end's value track + output.dim_tracks = [end.value_track] + else: + output.dim_tracks = [ValueTrack.new([start, end, step])] + ret = IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) + ret.set_output(0, output) + return ret def Arange(*args, start=None, end=None, step=None, out=None, dtype=None, layout=None, @@ -2355,12 +2369,15 @@ def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: torch.Tensor.size(tensor, dim=None) """ assert isinstance(tensor, IRTensor) - val = tensor.shape[dim] if isinstance(dim, int) else tensor.shape - assert val is not None + if isinstance(dim, int): + val = IRObject(name='size', value=tensor.shape[dim], value_track=tensor.dim_tracks[dim]) + else: + val = tuple(IRObject('size', value=s, value_track=t) for s, t in zip(tensor.shape, tensor.dim_tracks)) + if dim is None: - return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)]) + return IRPyFunc(signature, [tensor], [val]) else: - return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)], dim=dim) + return IRPyFunc(signature, [tensor], [val], dim=dim) def Dim(tensor, signature=None) -> Union[List[int], IRPyFunc]: @@ -2617,7 +2634,7 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], if isinstance(obj, IRTensor): if name == 'shape': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - shape = IRObject('shape', value=obj.shape) + shape = tuple(IRObject('shape', value=s, value_track=t) for s, t in zip(obj.shape, obj.dim_tracks)) return IRPyFunc(signature, [instance, field], [shape]) if name == 'dtype': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 550b21f1..dce640e3 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -65,11 +65,10 @@ def __call__(self, *args): """ return self.forward(*args) - def forward(self, *args: Tuple[IRObject]) -> Union[IRTensor, Tuple[IRTensor]]: + def forward(self, *args: IRObject) -> Union[IRTensor, Tuple[IRTensor]]: """Forward the IRGraph to add model nodes into program. - Args: - args (Tuple[IRObject]): input IRObjects + args (Tuple[IRObject, ...]): input IRObjects Returns: Any: output that can be nested structure of IRObjects diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 3cb92257..a766c7d9 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -11,8 +11,9 @@ from nnscaler.graph.tracer.metadata import OpContext from nnscaler.ir.operator import IRFwOperation from nnscaler.ir.tensor import IRFullTensor -from nnscaler.ir.cten import IRObject, IRCell, IRTensor, IR +from nnscaler.ir.cten import IRObject, IRCell, IRTensor, IR, ValueTrack from nnscaler.graph.parser.frame import Frame +from nnscaler.graph.parser.value_tracker import ValueTracker from nnscaler.graph.parser.mapping import SignFx2Op from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import IRDimops @@ -104,15 +105,20 @@ def parse(module: torch.fx.GraphModule, tuple_val = tuple(IRObject(name=node.name, value=v, is_constant=val.is_constant) for v in val.value) frame.set_var(node.name, tuple_val) + value_tracker = ValueTracker() + # get graph inputs placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] inputs = [frame.get_var(n.name) for n in placeholders] + value_tracker.track_values(inputs) # - if the graph inputs contain nested strcuture, # it should be wrapped into an IRObject for idx, placeholder in enumerate(placeholders): if not isinstance(inputs[idx], IRObject): obj = IRObject(name=placeholder.target, value=inputs[idx], is_constant=False) + obj.value_track.with_no_dep() inputs[idx] = obj + value_tracker.track_values([obj]) frame.set_var(placeholder.name, obj) # parse graph nodes @@ -121,6 +127,8 @@ def parse(module: torch.fx.GraphModule, ir_nodes = FxModuleParser.parse_node(node, module, constant_folding, frame) all_ir_nodes += ir_nodes + value_tracker.track_nodes(all_ir_nodes) + # get graph outputs outputs = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] # currently fx graph always has only one output @@ -168,6 +176,18 @@ def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, tensor_types=(TensorMetadata,), is_constant=is_constant, ) + + if node.op == 'placeholder': + def set_no_dep(x: IRObject): + if isinstance(x, IRTensor): + # let's the value_track of tensor stay None(unknown) + # because we don't care about it. + for dt in x.dim_tracks: + dt.with_no_dep() + else: + x.value_track.with_no_dep() + IR.modify_objects(val, set_no_dep) + frame.add_var(node.name, val) @staticmethod @@ -315,7 +335,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # if the function has no output, just return return [ir_node] - vals = frame.get_var(node.name) + vals: Union[Any, IRObject, List[IRObject], IRTensor, List[IRTensor]] = frame.get_var(node.name) if len(ir_node.outputs()) == 1: vals = [vals] elif IR.is_object(vals): @@ -341,6 +361,13 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # 1. output tensors are not set in function.py # 2. IRObject output from some functions (registered functions/getattr) are not set # For above two cases, we need to set them with values from frame. + if isinstance(ir_node.output(i), IRTensor): + assert isinstance(vals[i], IRTensor), f'Expect tensor for output {i}, but got {type(vals[i])}' + assert ir_node.output(i).shape == vals[i].shape, f'Expect shape {ir_node.output(i).shape} for output {i}, but got {vals[i].shape}' + # We need to copy dim tracks + # As we will use frame version as node output, instead of the placeholder created in function.py + for dim in range(len(vals[i].shape)): + vals[i].dim_tracks[dim].merge_deps(ir_node.output(i).dim_tracks[dim]) ir_node.set_output(i, vals[i]) # update frame with ir output @@ -382,16 +409,13 @@ def _is_primitive_type(val): # use a white list instead of a black list return isinstance(val, (int, float, bool, type(None), str, type(Ellipsis))) - # Note when it is not IRObject as a whole, we will not fold it if constant_folding and ir_node.constant_foldable \ and len(ir_node.outputs()) == 1 \ - and isinstance(ir_node.output(0), IRObject) \ - and not isinstance(ir_node.output(0), IRTensor) \ and not contains_undefined_output \ and not ir_node.signature.startswith(nnscaler.runtime.function.__name__ + '.')\ - and ir_node.output(0).is_constant \ - and _is_primitive_type(ir_node.output(0).value): - frame.set_var(node.name, ir_node.output(0).value) + and not IR.contains_object(ir_node.output(0), lambda x: isinstance(x, IRTensor) or not x.is_constant) \ + and _is_primitive_type(cval := IR.try_unwrap(ir_node.output(0))): + frame.set_var(node.name, cval) return [] else: return [ir_node] @@ -414,7 +438,7 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, exist_tensor = frame.get_attr_var(concrete_value) # the case that the parameter is the first time used by getattr if not exist_tensor: - tensor = frame.get_var(node.name) + tensor: IRFullTensor = frame.get_var(node.name) # set tensor name same with the name in original model tensor.name = node.target if tensor.requires_grad: @@ -426,6 +450,11 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, direct_module = getattr(direct_module, name) persistent = full_qualified_name[-1] not in direct_module._non_persistent_buffers_set tensor.as_buffer(persistent=persistent) + + # Parameters and buffers have no dependency on other values + for dt in tensor.dim_tracks: + dt.with_no_dep() + frame.add_attr(tensor, concrete_value, node.target) # the case that the parameter is consumed multiple times and registered previously else: @@ -435,7 +464,10 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, # in sub modules, the target is full qualified name (for example `embeddings.dropout.training`) if node.target.split('.')[-1] == 'training': # Let's just support `self.training` and ignore all other cases for now - output = IRObject(name=node.name, value=frame.get_var(node.name), is_constant=False) + if isinstance(output := frame.get_var(node.name), IRObject): + output.is_constant = False + else: + output = IRObject(name=node.name, value=output, is_constant=False) ir_node = IRPyFunc(SELF_GETATTR_SIG, ['training'], [output]) FxModuleParser._set_node_meta(node, ir_node) frame.set_var(node.name, output) diff --git a/nnscaler/graph/parser/value_tracker.py b/nnscaler/graph/parser/value_tracker.py new file mode 100644 index 00000000..6419ca15 --- /dev/null +++ b/nnscaler/graph/parser/value_tracker.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import defaultdict +from typing import Any +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir.cten import IR, IRObject, IRTensor, ValueTrack +from nnscaler.ir.operator import IRFwOperation + + +class ValueTracker: + def __init__(self): + # value_id -> ValueTrack + # Please note some ValueTracks may be merged together (from annotation) + # So the key can be different from the id of the ValueTrack + self._vtm: dict[int, ValueTrack] = {} + self._equiv_value_ids: dict[int, set] = {} + + def track_values(self, objs: list[Any]): + for obj in objs: + self.track_value(obj) + + def track_value(self, obj: Any): + for item in IR.get_objects(obj): + if isinstance(item, IRTensor): + for dt in item.dim_tracks: + self._vtm[dt.value_id] = dt + elif isinstance(item, IRObject): + self._vtm[item.value_track.value_id] = item.value_track + + def _update_track_value(self, obj: Any): + if isinstance(obj, IRTensor): + new_dim_tracks = [] + for dt in obj.dim_tracks: + new_dim_tracks.append(self._vtm[dt.value_id]) + obj.dim_tracks = new_dim_tracks + elif isinstance(obj, IRObject): + obj.value_track = self._vtm[obj.value_track.value_id] + + def track_nodes(self, nodes: list[IRFwOperation]): + """ + Track the value tracks of the input and output objects in the given nodes. + Here we assume the nodes are topologically sorted. + """ + # collect all value tracks from nodes + for node in nodes: + for obj in node.iobjs(): + self.track_value(obj) + for obj in node.oobjs(): + self.track_value(obj) + + # init equivalence classes + for vt in self._vtm.values(): + self._equiv_value_ids[vt.value_id] = {vt.value_id} + + # collect extra value tracks from dimops + for node in nodes: + if isinstance(node, IRDimops): + self._track_dims(node) + + # merge equivalent value tracks together + for value_id, equiv_ids in self._equiv_value_ids.items(): + min_value_id = min(equiv_ids) + if value_id != min_value_id: + continue + + # use the smallest id as the representative + rep_one = self._vtm[min_value_id] + for vid in equiv_ids: + if vid == min_value_id: + continue + # TODO: how we merge dependencies? + # current we take union (Union may be too strict) + if rep_one.deps is None: + rep_one.deps = self._vtm[vid].deps + elif self._vtm[vid].deps is not None: + rep_one.deps = list(set(rep_one.deps).union(set(self._vtm[vid].deps))) + self._vtm[vid] = rep_one + + # dedup dependencies + # Here we will replace dependencies with their representative value tracks + # which can introduce some duplicates + for vt in self._vtm.values(): + if vt.deps is not None: + vt.deps = list(set(self._vtm[d].value_id for d in vt.deps)) + + # propagate the merged value tracks back to nodes + for node in nodes: + for obj in node.iobjs(): + self._update_track_value(obj) + for obj in node.oobjs(): + self._update_track_value(obj) + + def _track_dims(self, node: IRDimops): + """ + Track the dimension values of output tensors according to input tensors. + This function should be called after shape inference. + """ + # align the dim_ids of output with inputs + # not-hidden-dimension means the identifier is all for this dimension + # for example, in `l (2 h) m`, + # l and m are not-hidden-dimension identifiers, h is hidden-dimension identifier + # + # If the annotation is `l (2 h) m -> l h (m 2 h)` + # We will get the following relations (nhd->not-hidden-dimension, hd->hidden-dimension): + # 1. for `l`: `input.dim_tracks[0] is output.dim_tracks[0]` # both nhd, equality + # 2. for `m`: `input.dim_tracks[2].value_id in output.dim_tracks[2].deps` # one is hd, depencency + # 3. for `h`: `input.dim_tracks[1].value_id in output.dim_tracks[2].deps` # one is hd, depencency + # `input.dim_tracks[1] in output.dim_tracks[1].deps` # one is hd, depencency + + # TODO: We can handle more complex cases in the future if needed. + # In current version, we don't handle the case like + # 1. `(2 h) -> (2 h)`: input.dim_tracks[0] should be equal to output.dim_tracks[0]? (2 can be a runtime number, so we cannot be sure) + # 2. `(l m) -> (l m)`: input.dim_tracks[0] should be equal to output.dim_tracks[0]. + + # ivt => identifier_value_track_map + hidden_ivt: dict[str, list[ValueTrack]] = defaultdict(list) + non_hidden_ivt: dict[str, list[ValueTrack]] = defaultdict(list) + + for i, input_tensor in enumerate(node.inputs()): + if not isinstance(input_tensor, IRTensor) or node.ianno(i).ignore: + continue + + ianno = node.ianno(i) + for dim, dim_track in zip(ianno.dims, input_tensor.dim_tracks): + identifiers = [i for i in dim.identifiers if not str.isdecimal(i)] + if len(identifiers) == 1 and len(dim.identifiers) == 1: + # not hidden dimension + non_hidden_ivt[identifiers[0]].append(dim_track) + else: + for iden in identifiers: + hidden_ivt[iden].append(dim_track) + + for iden, iden_infos in non_hidden_ivt.items(): + # merge all not-hidden-dimension infos together + first = iden_infos[0] + for info in iden_infos[1:]: + self._add_equiv_value(first.value_id, info.value_id) + + for i, output_tensor in enumerate(node.outputs()): + if not isinstance(output_tensor, IRTensor) or node.oanno(i).ignore: + continue + + oanno = node.oanno(i) + for dim, dim_track in zip(oanno.dims, output_tensor.dim_tracks): + # find the first identifier that is not a number + identifiers = [i for i in dim.identifiers if not str.isdecimal(i)] + if len(identifiers) == 1 and len(dim.identifiers) == 1: + ident = identifiers[0] + if ident in non_hidden_ivt: + first = non_hidden_ivt[ident][0] + self._add_equiv_value(first.value_id, dim_track.value_id) + else: + # this identifier is used together with other identifiers + # so it is just a dependency. + dim_track.deps = dim_track.deps or [] + dim_track.deps.extend(v.value_id for v in hidden_ivt[ident]) + dim_track.deps = list(set(dim_track.deps)) # deduplicate + else: + dim_track.deps = dim_track.deps or [] + for ident in identifiers: + if ident in hidden_ivt: + dim_track.deps.extend(v.value_id for v in hidden_ivt[ident]) + if ident in non_hidden_ivt: + first = non_hidden_ivt[ident][0] + dim_track.deps.append(first.value_id) + + def _add_equiv_value(self, value_id, other_value_id): + self._equiv_value_ids[value_id].update(self._equiv_value_ids[other_value_id]) + for vid in self._equiv_value_ids[other_value_id]: + self._equiv_value_ids[vid] = self._equiv_value_ids[value_id] diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index f359fec3..bc8978ea 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -18,8 +18,9 @@ from __future__ import annotations +from dataclasses import dataclass, field from functools import lru_cache -from typing import ClassVar, List, Tuple, Union, Optional, Any, Dict, Callable +from typing import ClassVar, Iterable, List, Set, Tuple, Union, Optional, Any, Dict, Callable from collections import OrderedDict import copy import torch @@ -29,7 +30,7 @@ from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE -NestedVarOrStatic = Any +NestedVarOrStatic = Union[Any, 'IRObject', List['IRObject'], 'IRTensor'] class IRCell: @@ -459,6 +460,73 @@ def modify_objects_of_complex(val: Any, modifier: Callable[['IRObject'], 'IRObje return val +@dataclass +class ValueTrack: + """ + Track the value of an IRObject or a dimension of IRTensor. + Currently only implemented for dimension via IRDimops annotation. + + Example: + `l (2 h) m -> l h (2 m)`: + Input Tensor Tracks (2/5 is external dependencies for illustration): + dim 0: ValueTrack(value_id=10, dependencies=[]) # l + dim 1: ValueTrack(value_id=20, dependencies=[]) # (2 h) + dim 2: ValueTrack(value_id=30, dependencies=[2, 5]) # m + Then we can infer the output Tensor Tracks: + Output Tensor Tracks: + dim 0: ValueTrack(value_id=10, dependencies=[]) # reuse input dim 0, since they are the same + dim 1: ValueTrack(value_id=40, dependencies=[20]) # it depends on input dim 1: (2 h) + dim 2: ValueTrack(value_id=50, dependencies=[30]) # it depends on input dim 2: m + """ + value_id: int = field(default_factory=IDGenerator().gen_value_id) + # None: unknown dependencies + # []: no dependencies + deps: Optional[list[int]] = None + + def with_no_dep(self) -> 'ValueTrack': + """ + Initialize this ValueTrack with no dependencies. + """ + self.with_dep(None) + return self + + def with_dep(self, dep: Union[None, 'ValueTrack', 'IRObject'] = None) -> 'ValueTrack': + """ + Initialize or add a dependency to the ValueTrack. + If dep is None, just initialize an empty dependency list, which means no dependencies. + If dep is not IRObject or ValueTrack, do nothing. + """ + if self.deps is None: + self.deps = [] + + if not isinstance(dep, (ValueTrack, IRObject)): + return self + + if isinstance(dep, IRTensor): + raise TypeError("Cannot directly add IRTensor as dependency.") + + dep = dep.value_track if isinstance(dep, IRObject) else dep + dep_value_id = dep.value_id + if dep_value_id not in self.deps: + self.deps.append(dep_value_id) + + return self + + def merge_deps(self, other: ValueTrack) -> 'ValueTrack': + if self.deps is None: + self.deps = other.deps + else: + self.deps.extend(other.deps or []) + self.deps = list(set(self.deps)) + + @classmethod + def new(cls, deps: Iterable[Union[Any, 'ValueTrack', 'IRObject']]) -> 'ValueTrack': + vt = cls() + for dep in deps: + vt.with_dep(dep) + return vt + + class IRObject: """ IRObject serves as general data of IRGraph edge @@ -466,7 +534,15 @@ class IRObject: # will be set after class definition missing: ClassVar['IRObject'] = None - def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: Optional[None] = None, is_constant: bool = True): + def __init__( + self, + name: Optional[str] = None, + tid: Optional[int] = None, + value: Optional[None] = None, + is_constant: bool = True, + *, + value_track: Optional[ValueTrack] = None, + ) -> None: """ Args: name (str): object name @@ -486,6 +562,7 @@ def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: self._is_attr: bool = False self._value: Optional[Any] = value self._is_constant: bool = is_constant + self._value_track: ValueTrack = value_track or ValueTrack() def __hash__(self) -> int: return self._id @@ -538,6 +615,15 @@ def value(self) -> Any: """Get example value""" return self._value + @property + def value_track(self) -> ValueTrack: + """Get value track info""" + return self._value_track + + @value_track.setter + def value_track(self, val: ValueTrack): + self._value_track = val + @property def is_constant(self) -> bool: return self._is_constant @@ -555,7 +641,7 @@ def __copy__(self): """Copy this object but remove the cell information""" if self is IRObject.missing: # missing object is singleton return IRObject.missing - return IRObject(self.name, self._id, self._value, self._is_constant) + return IRObject(self.name, self._id, self._value, self._is_constant, value_track=self._value_track) def as_attr(self): """ @@ -651,7 +737,10 @@ def _inner(obj) -> Tuple[Any, bool]: new_ir_tensor._value = obj.value return new_ir_tensor, True else: - return IRObject(name, value=obj.value, is_constant=is_constant), False + return IRObject( + name, value=obj.value, + is_constant=is_constant, value_track=obj.value_track + ), False if isinstance(obj, tensor_types): if requires_grad is None: @@ -907,11 +996,12 @@ class IRTensor(IRObject): You can get the original shape with `origin_shape` property. """ def __init__(self, shape=None, name='tensor', dtype=None, tid=None, *, - is_attr=False, is_grad=False, requires_grad=False, persistent=False + is_attr=False, is_grad=False, requires_grad=False, persistent=False, ): super().__init__(name, tid, is_constant=False) self._is_scalar_tensor: bool = True - self._shape: Tuple[int] = () + self._shape: Tuple[int, ...] = () + self._dim_tracks: Tuple[ValueTrack, ...] = () self._dtype: Optional[torch.dtype] = None # tensor gradient self._is_grad: bool = False @@ -946,7 +1036,9 @@ def _update( if shape is not None: self._is_scalar_tensor = not shape # will always convert scalar tensor to 1-d tensor - self._shape: Tuple[int] = (1,) if not shape else tuple(shape) + self._shape: Tuple[int, ...] = (1,) if not shape else tuple(shape) + # reset dim tracks + self._dim_tracks = tuple(ValueTrack() for _ in self._shape) if name is not None or self.name is None: self.name = name if dtype is not None: @@ -1039,12 +1131,31 @@ def origin_shape(self) -> Tuple[int]: return self.shape if not self.is_scalar_tensor() else () @property - def shape(self) -> Tuple[int]: + def shape(self) -> Tuple[int, ...]: # NOTE: here return a tuple but not a real torch.Size obj may have risk, here is an example: # (torch.Size + tuple -> torch.Size) will change to (tuple + tuple -> tuple), is ok. # (torch.Size + list -> torch.Size) will change to (tuple + list -> error), is wrong. return self._shape + @property + def dim_tracks(self) -> Tuple[ValueTrack, ...]: + """ + Get the track of each dimension + """ + return self._dim_tracks + + @dim_tracks.setter + def dim_tracks(self, val: Tuple[Optional[ValueTrack], ...]): + """ + Set the unique id of each dimension + """ + if not isinstance(val, (list, tuple)): + raise ValueError("dim_tracks must be a list or tuple") + if len(val) != len(self._shape): + raise ValueError("dim_tracks length must be equal to shape length") + # None means starting a new dim track + self._dim_tracks = tuple(v if v is not None else ValueTrack() for v in val) + def nelement(self) -> int: """ Get total number of element in the tensor. diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index 6546720f..b24f45ca 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -27,10 +27,10 @@ 3) gradient of parameters """ -from typing import List, Optional, Union, Tuple, NewType, Dict, Any +from typing import List, Optional, Set, Union, Tuple, NewType, Dict, Any import torch -from nnscaler.ir.cten import IRTensor +from nnscaler.ir.cten import IRTensor, ValueTrack StartEnd = NewType('[start:end)', Tuple[int, int]) IdxChunk = NewType('(index, chunks)', Tuple[int, int]) @@ -260,14 +260,17 @@ class IRFullTensor(IRTensor): """ def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=None, *, - is_attr=False, is_grad=False, persistent=False, is_loss=False + is_attr=False, is_grad=False, persistent=False, is_loss=False, ): self._is_loss: bool = False # record all created sub_tensors self._subtensors : Dict[(ValueMap, IndexMap), int] = dict() self._grad: Optional[IRFullTensor] = None - super().__init__(shape, name, dtype, requires_grad=requires_grad, is_attr=is_attr, is_grad=is_grad, persistent=persistent) + super().__init__( + shape, name, dtype, requires_grad=requires_grad, + is_attr=is_attr, is_grad=is_grad, persistent=persistent, + ) self._update( is_loss=is_loss, ) @@ -334,6 +337,7 @@ def like(self): self.origin_shape, self.name, self._requires_grad, self._dtype, is_loss=self._is_loss ) + tensor.dim_tracks = self.dim_tracks return tensor def like_grad(self): @@ -346,6 +350,7 @@ def like_grad(self): self.origin_shape, 'g' + self.name, requires_grad=False, dtype=self.dtype ).as_grad(self._is_attr) + grad.dim_tracks = self.dim_tracks return grad @property @@ -363,6 +368,7 @@ def grad(self, val: Optional[IRTensor]): assert self._requires_grad, f"Cannot assign {val} to no grad-required tensor" assert val.origin_shape == self.origin_shape assert val.is_attr() == self.is_attr() + val.dim_tracks = self.dim_tracks # TODO: we should check the grad-required here # it is very common in current code that we assign None to grad # so currently it is impossible to check the grad-required here @@ -507,6 +513,7 @@ def __init__(self, ftensor: IRFullTensor, del self._is_grad del self._requires_grad del self._persistent + del self._dim_tracks self.cell = None # the index from full_tensor @@ -677,6 +684,10 @@ def dtype(self) -> Optional[torch.dtype]: """Tensor data type""" return self.parent.dtype + @property + def dim_tracks(self) -> Tuple[ValueTrack, ...]: + return self.parent.dim_tracks + @IRTensor.shape.setter def shape(self, val: Tuple[int]): # TODO: remove this function diff --git a/nnscaler/ir/unique.py b/nnscaler/ir/unique.py index dde3ceb2..72338ee5 100644 --- a/nnscaler/ir/unique.py +++ b/nnscaler/ir/unique.py @@ -5,14 +5,14 @@ class IDGenerator: """ Tensor / Operator manager. To guarantee that each IRTensor / IROperator id is unique and progressively increases. - + This class is designed in singleton pattern. """ class __IDGenerator: def __init__(self): - self._tensor_id = 0 self._cell_id = 0 + self._value_id = 0 instance = None @@ -31,13 +31,19 @@ def gen_cell_id(self): self.instance._cell_id += 1 return self.instance._cell_id + def gen_value_id(self): + self.instance._value_id += 1 + return self.instance._value_id + def get_states(self): - return (self._tensor_id, self._cell_id) - + return (self._tensor_id, self._cell_id, self._value_id) + def load_states(self, states: tuple): IDGenerator.instance._tensor_id = states[0] IDGenerator.instance._cell_id = states[1] + IDGenerator.instance._value_id = states[2] def clear(self): self.instance._tensor_id = 0 self.instance._cell_id = 0 + self.instance._value_id = 0 diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 4f254c7c..f7bd8954 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -232,6 +232,7 @@ def wrapped_fn(*args, **kwargs): _DICT_ITEMS_TYPE = type({}.items()) _DICT_KEYS_TYPE = type({}.keys()) _DICT_VALUES_TYPE = type({}.values()) +TRANSFORM_SUPPORTED_COLLECTION_TYPES = (tuple, list, dict, set, slice, _DICT_ITEMS_TYPE, _DICT_KEYS_TYPE, _DICT_VALUES_TYPE) def transform_recursively(data: Any, fn: Callable[[Any], Any], @@ -240,14 +241,18 @@ def transform_recursively(data: Any, fn: Callable[[Any], Any], ) -> Any: """ Transform the data with the given function, will recursively apply the function to the nested data. + Currently supported collection types is SUPPORTED_COLLECTION_TYPES. Args: data: the data to be transformed. fn: the function to apply. target_types: the target types to apply the function. collection_types: the collection types to apply the function to the nested data. + Will handle all supported types if None. skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. """ + if collection_types is None: + collection_types = TRANSFORM_SUPPORTED_COLLECTION_TYPES if isinstance(data, collection_types): if isinstance(data, tuple): return tuple(transform_recursively(t, fn, target_types, collection_types) for t in data) diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index f15fcc29..aed04fe6 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -44,7 +44,6 @@ def forward(self, x, **kwargs): assert any(node.op == 'call_function' and node.target == torch.nn.functional.linear for node in nodes) with tempfile.TemporaryDirectory() as tempdir: - to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) assert ir_graph is not None assert (Path(tempdir) / FxModuleParser.ATTR_MAP_FILE).exists() @@ -54,11 +53,43 @@ def forward(self, x, **kwargs): assert len(inputs) == 2 assert inputs[0].name == nodes[0].target assert isinstance(inputs[0], IRTensor) + assert inputs[0].value_track.deps == None + # inputs has no dependency + assert all(dt.deps == [] for dt in inputs[0].dim_tracks) assert inputs[1].name == nodes[1].target assert isinstance(inputs[1], IRObject) + assert inputs[1].value_track.deps == [] + + assert len(ir_graph.nodes()) == 1 + linear_node = ir_graph.nodes()[0] + assert len(linear_node.inputs()) == 3 # x, weight, bias + + assert all(isinstance(i, IRTensor) for i in linear_node.inputs()) + # from its annotation, a k^, n k^, n -> a n + # we can check the value_track and dim_track dependencies + + # the same with graph inputs + assert all(linear_node.input(0).dim_tracks[i] is inputs[0].dim_tracks[i] for i in range(len(inputs[0].dim_tracks))) + # weights has no dependency + assert linear_node.input(1).dim_tracks[0].deps == [] + # the `k` dimension + assert linear_node.input(1).dim_tracks[1] is inputs[0].dim_tracks[1] + # the `n` dimension + assert linear_node.input(2).dim_tracks[0] is linear_node.input(1).dim_tracks[0] + + assert len(linear_node.outputs()) == 1 + assert isinstance(linear_node.outputs()[0], IRTensor) + # `a` + assert linear_node.output(0).dim_tracks[0] is inputs[0].dim_tracks[0] + # `n` + assert linear_node.output(0).dim_tracks[1] is linear_node.input(1).dim_tracks[0] outputs = ir_graph.outputs() assert len(outputs) == 1 + # `a` + assert outputs[0].dim_tracks[0] is inputs[0].dim_tracks[0] + # `n` + assert outputs[0].dim_tracks[1] is linear_node.input(1).dim_tracks[0] nodes = list(ir_graph.nodes()) assert any(node.signature == 'torch.nn.functional.linear' for node in nodes) diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index a0bc33b8..176cba07 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -166,7 +166,7 @@ def forward(self, x): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) print(ir_graph.extra_repr()) - assert len(ir_graph.nodes()) == 5 + assert len(ir_graph.nodes()) == 4 assert len(ir_graph.nodes()[0].outputs()) == 3 assert len(ir_graph.outputs()) == 1 assert isinstance(ir_graph.output(0), list) diff --git a/tests/graph/parser/test_value_tracker.py b/tests/graph/parser/test_value_tracker.py new file mode 100644 index 00000000..86f8661d --- /dev/null +++ b/tests/graph/parser/test_value_tracker.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile + +import torch + +from nnscaler.graph.parser.converter import convert_model + +from ...utils import replace_all_device_with + + +@replace_all_device_with('cpu') +def test_hidden_dim(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return x.repeat(4, 1) + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 1 + node = ir_graph.node(0) + assert str(node.anno) == 'a^ b -> (4^ a^) b' + dim0_vi = node.input(0).dim_tracks[0].value_id + dim1_vi = node.input(0).dim_tracks[1].value_id + + assert node.output(0).dim_tracks[0].value_id != dim0_vi + assert node.output(0).dim_tracks[0].deps == [dim0_vi] + assert node.output(0).dim_tracks[1].value_id == dim1_vi + assert node.output(0).dim_tracks[1].deps == [] + + +@replace_all_device_with('cpu') +def test_equiv_class(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + x = x + 1 + y = y * 2 + return x@y + + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 'y': torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 3 + x_node = ir_graph.node(0) + y_node = ir_graph.node(1) + assert x_node.input(0).dim_tracks[0] is x_node.output(0).dim_tracks[0] + assert x_node.input(0).dim_tracks[1] is x_node.output(0).dim_tracks[1] + + assert y_node.input(0).dim_tracks[0] is y_node.output(0).dim_tracks[0] + assert y_node.input(0).dim_tracks[1] is y_node.output(0).dim_tracks[1] + + node = ir_graph.node(-1) + assert str(node.anno) == 'm k+, k+ n -> m n' + # the `k` dimension of input 1 should be the same as input 0 + # they are in the same equivalence class + assert node.input(0).dim_tracks[0] is x_node.input(0).dim_tracks[0] + assert node.input(0).dim_tracks[1] is x_node.input(0).dim_tracks[1] + assert node.input(1).dim_tracks[0] is node.input(0).dim_tracks[1] + assert node.input(1).dim_tracks[1] is y_node.input(0).dim_tracks[1] + + assert node.output(0).dim_tracks[0] is node.input(0).dim_tracks[0] + assert node.output(0).dim_tracks[1] is node.input(1).dim_tracks[1] + + +@replace_all_device_with('cpu') +def test_size(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + x = x + 1 + s = x.size() + y = torch.randn(s) + return x + y + + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 4 + size_node = ir_graph.node(1) + randn_node = ir_graph.node(2) + + assert size_node.output(0)[0].value_track is ir_graph.inputs()[0].dim_tracks[0] + assert size_node.output(0)[1].value_track is ir_graph.inputs()[0].dim_tracks[1] + + # dim tracks of randn node is from equivalence class originally from torch.add + assert randn_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[0] + assert randn_node.output(0).dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[1] diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 4a286dd7..c4306c2c 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -453,8 +453,13 @@ def test_codegen_getitem(): gen_savedir=tempdir, load_module=False, ) - assert _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') - assert _gencode_contains(tempdir, GetItemModule, 1, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') + + assert _gencode_contains(tempdir, GetItemModule, 0, r"_operator.getitem\(batched_data.*, 'x'\)") + assert _gencode_contains(tempdir, GetItemModule, 1, r"_operator.getitem\(batched_data.*, 'x'\)") + # data_x.size() will be expanded to a list of ir objects, + # so no slice operation will be generated. + assert not _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') + assert not _gencode_contains(tempdir, GetItemModule, 1, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') assert m_new is None @@ -1658,7 +1663,7 @@ def check_op(*names): for name in names: code = add_codes.pop(0) if name in not_folded_names: - assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, getitem_.*, alpha=1\)', code) + assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, get.*, alpha=1\)', code) else: assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, 2, alpha=1\)', code) @@ -1727,7 +1732,7 @@ def test_fold_constant(tmp_path, fold_input): else: # add_27 = torch.add(linear_30, getitem_20, alpha=1) assert _gencode_contains(tmp_path, CCFModule2, 0, - r'add_.* = torch\.add\(linear_.*, getitem_.*, alpha=1\)') + r'add_.* = torch\.add\(linear_.*, get.*, alpha=1\)') # b = b * ashape3 # mul_2_51 = torch.mul(mul_1_57, add_38) assert _gencode_contains(tmp_path, CCFModule2, 0, diff --git a/utility/visualize_value_tracks.py b/utility/visualize_value_tracks.py new file mode 100644 index 00000000..164b838e --- /dev/null +++ b/utility/visualize_value_tracks.py @@ -0,0 +1,158 @@ +import argparse +import matplotlib.pyplot as plt +from nnscaler.graph import IRGraph +from matplotlib.patches import FancyArrowPatch +from nnscaler.ir.cten import IR, IRTensor, IRObject + + +class Visualizer: + NUM_ROWS_PER_OP = 3 + TEXT_HEIGHT_IN_INCH = 0.4 + PER_OP_GAP_IN_INCH = 0.2 + PER_ROW_HEIGHT_IN_INCH = TEXT_HEIGHT_IN_INCH * 1.1 + PER_OP_HEIGHT_IN_INCH = PER_ROW_HEIGHT_IN_INCH * NUM_ROWS_PER_OP + PER_INOUT_GAP = 0.01 + + INIT_Y = 0.001 + INIT_X = 0.001 + + def __init__(self, graph): + self.graph = graph + self.value_loc = {} + self.ops = [node for node in self.graph.nodes() if node.isfw()] + + self.fig_heigth_in_inch = ( + self.PER_OP_HEIGHT_IN_INCH + self.PER_OP_GAP_IN_INCH + ) * (len(self.ops) + 1) + self.coord_per_inch = 1.0 / self.fig_heigth_in_inch + self.per_op_height = self.PER_OP_HEIGHT_IN_INCH * self.coord_per_inch + self.per_row_height = self.per_op_height / self.NUM_ROWS_PER_OP + self.per_op_gap = self.PER_OP_GAP_IN_INCH * self.coord_per_inch + + self.fig, self.ax = plt.subplots(figsize=(30, self.fig_heigth_in_inch)) + self.ax.axis('off') + self.ax.invert_yaxis() + + def draw_value(self, value, value_track, cur_x, cur_y, previous_value_loc): + t = self.ax.text(cur_x, cur_y, str(value), + fontsize=14, ha="left", va="top") + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + if value_track is not None: + if value_track.value_id in previous_value_loc: + prev_x, prev_y = previous_value_loc[value_track.value_id] + arrow = FancyArrowPatch( + (prev_x, prev_y), + (cur_x + bbox.width/2, cur_y), + arrowstyle="Simple,tail_width=0.25,head_width=1,head_length=1", + mutation_scale=6, + color="#2c7bb6", + linewidth=0.02, + connectionstyle="arc3,rad=0", + alpha=0.5, + zorder=4 + ) + self.ax.add_patch(arrow) + self.value_loc[value_track.value_id] = (cur_x + bbox.width/2, cur_y) + + cur_x += bbox.width + self.PER_INOUT_GAP/2 + return cur_x + + def draw_obj(self, obj, cur_x, cur_y, previous_value_loc): + if isinstance(obj, IRTensor): + cur_x = self.draw_value('T(', None, cur_x, cur_y, previous_value_loc) + for i, d in enumerate(obj.shape): + if i > 0: + cur_x = self.draw_value(',', None, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(str(d), obj.dim_tracks[i], cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(')', None, cur_x, cur_y, previous_value_loc) + else: + assert isinstance(obj, IRObject) + cur_x = self.draw_value('O(', None, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(str(obj.value), obj.value_track, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(')', None, cur_x, cur_y, previous_value_loc) + cur_x += self.PER_INOUT_GAP + return cur_x + + def draw_objs(self, objs, cur_x, cur_y): + previous_value_loc = dict(self.value_loc) + for inp in objs: + cur_x = self.draw_obj(inp, cur_x, cur_y, previous_value_loc) + + def draw_graph_inputs(self, g, cur_x, cur_y): + label = "GRAPH IN: " + t = self.ax.text(cur_x, cur_y, label, + fontsize=14, fontweight="bold", ha="left", va="top") + + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + cur_x = cur_x + bbox.width + self.PER_INOUT_GAP + + ir_objs = [] + for inp in g.inputs(): + if isinstance(inp, (IRObject, IRTensor)): + ir_objs.append(inp) + elif isinstance(inp, IRObject): + sub_objs = IR.get_objects(inp.value) + if sub_objs: + ir_objs.extend(sub_objs) + else: + ir_objs.append(inp) + + self.draw_objs(ir_objs, cur_x, cur_y) + + def draw_inout(self, node, cur_y, is_in): + if is_in: + ir_objs = node.iobjs() + label = "IN: " + cur_y += self.per_row_height + else: + ir_objs = node.oobjs() + label = "OU: " + cur_y += self.per_row_height * 2 + + t = self.ax.text(self.INIT_X, cur_y, label, + fontsize=14, fontweight="bold", ha="left", va="top") + + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + cur_x = self.INIT_X + bbox.width + self.PER_INOUT_GAP + + self.draw_objs(ir_objs, cur_x, cur_y) + + def visualize(self): + self.draw_graph_inputs(self.graph, self.INIT_X, self.INIT_Y) + cur_y = self.INIT_Y + (self.per_op_height + self.per_op_gap)/2 + + for node in self.ops: + op_name = node.name + self.ax.text(self.INIT_X, cur_y, op_name + ":", + fontsize=16, fontweight="bold", ha="left", va="top") + + self.draw_inout(node, cur_y, is_in=True) + self.draw_inout(node, cur_y, is_in=False) + + cur_y += self.per_op_height + self.per_op_gap + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + 'graphfile', + type=str, + help="Graph dump file" + ) + parser.add_argument( + 'imagefile', + type=str, + nargs='?', + default=None, + help="Save generated image to file" + ) + args = parser.parse_args() + g = IRGraph.load(args.graphfile) + visualizer = Visualizer(g) + visualizer.visualize() + if args.imagefile: + plt.savefig(args.imagefile, bbox_inches='tight', dpi=100) + plt.show() From 5d0db5a68b4606d341e1d98fb2299cc6a1437e96 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 13 Oct 2025 06:20:46 +0000 Subject: [PATCH 1844/1892] Merged PR 2413: [BwCompat] add send/recv_object_list for pytorch < 2.4 torch.distributed.send/recv_object_list are only availbe in pytorch >= 2.4. Here we copy the implementation to nnscaler so we can use it in any pytorch > 2.0 --- nnscaler/graph/parser/converter.py | 5 +- nnscaler/runtime/__init__.py | 1 + nnscaler/runtime/_patch_torch.py | 104 +++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 nnscaler/runtime/_patch_torch.py diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index a30dfa23..2b97c314 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -30,8 +30,11 @@ class no_save_tensor_hook(saved_tensors_hooks): """skip saving tensors for backward since tracer only traces forward""" def __init__(self): def pack(x): - return None + return (x.shape, x.dtype, x.device) def unpack(x): + # in pytorch 2.4.0-, torch.compile will call backward when tracing graph + if torch.__version__ < (2, 4, 0): + return torch.empty(x[0], dtype=x[1], device=x[2]) raise RuntimeError("not expecting backward to be called on this tensor") super().__init__(pack, unpack) diff --git a/nnscaler/runtime/__init__.py b/nnscaler/runtime/__init__.py index d0171757..46be9e99 100644 --- a/nnscaler/runtime/__init__.py +++ b/nnscaler/runtime/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from nnscaler.runtime import _patch_torch from nnscaler.runtime import executor from nnscaler.runtime import device from nnscaler.runtime import adapter diff --git a/nnscaler/runtime/_patch_torch.py b/nnscaler/runtime/_patch_torch.py new file mode 100644 index 00000000..53ab7438 --- /dev/null +++ b/nnscaler/runtime/_patch_torch.py @@ -0,0 +1,104 @@ +# The following code is copied from torch.distributed.distributed_c10d in PyTorch 2.4.0 +# For copyright, see pytorch/LICENSE +# https://github.com/pytorch/pytorch/blob/main/LICENSE + + +import torch +import torch.distributed + + +if torch.__version__ < (2, 4, 0): + # send_object_list and recv_object_list only available in PyTorch 2.4.0+ + + import torch.distributed.distributed_c10d as dist_c10d + + + if torch.__version__ < (2, 3, 0): + def _object_to_tensor(obj, device, group): + return dist_c10d._object_to_tensor(obj, device) + else: + def _object_to_tensor(obj, device, group): + return dist_c10d._object_to_tensor(obj, device, group) + + + if torch.__version__ < (2, 3, 0): + def _tensor_to_object(tensor, size, group): + return dist_c10d._tensor_to_object(tensor, size) + else: + def _tensor_to_object(tensor, size, group): + return dist_c10d._tensor_to_object(tensor, size, group) + + + def send_object_list(object_list, dst, group=None, device=None): + if torch.distributed.get_rank() == dst: + raise ValueError( + "Invalid destination rank: destination rank should not be the same as " + "the rank of the current process." + ) + + if dist_c10d._rank_not_in_group(group): + dist_c10d._warn_not_in_group("send_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # sent to this device. + current_device = device or torch.device("cuda", torch.cuda.current_device()) + # Serialize object_list elements to tensors on src rank. + tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + + # Send object sizes + torch.distributed.send(object_sizes_tensor, dst=dst, group=group) + + # Concatenate and send serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + + torch.distributed.send(object_tensor, dst=dst, group=group) + + + def recv_object_list(object_list, src=None, group=None, device=None): + if dist_c10d._rank_not_in_group(group): + dist_c10d._warn_not_in_group("recv_object_list") + return -1 + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # received to this device. + current_device = device or torch.device("cuda", torch.cuda.current_device()) + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + + # Receive object sizes + rank_sizes = torch.distributed.recv(object_sizes_tensor, src=src, group=group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device + ) + + rank_objects = torch.distributed.recv(object_tensor, src=src, group=group) + assert rank_sizes == rank_objects, "Mismatch in return ranks for object sizes and objects." + # Deserialize objects using their stored sizes. + offset = 0 + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + return rank_objects + + torch.distributed.send_object_list = send_object_list + torch.distributed.recv_object_list = recv_object_list From e6f2587fdc91d57afcd766326f148b8b95e12f58 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 17 Oct 2025 07:20:56 +0000 Subject: [PATCH 1845/1892] Merged PR 2414: [Refine] Make parser a real class Old parser has only staticmethod, and make it behave like a namespace. In this PR, we change common arguments of staticmethods as member variable to simplify the code. --- nnscaler/graph/parser/__init__.py | 2 +- nnscaler/graph/parser/converter.py | 4 +- nnscaler/graph/parser/parser.py | 247 +++++++++++++++++------------ 3 files changed, 145 insertions(+), 108 deletions(-) diff --git a/nnscaler/graph/parser/__init__.py b/nnscaler/graph/parser/__init__.py index 1dea36e7..e7fa0900 100644 --- a/nnscaler/graph/parser/__init__.py +++ b/nnscaler/graph/parser/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from nnscaler.graph.parser.parser import FxModuleParser +from nnscaler.graph.parser.parser import FxModuleParser, parse_fx_module from nnscaler.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph from nnscaler.graph.parser.register import register from nnscaler.graph.parser.external import * diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 2b97c314..ae338b25 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -12,7 +12,7 @@ from nnscaler.graph import IRGraph from nnscaler.flags import CompileFlag -from nnscaler.graph.parser import FxModuleParser +from nnscaler.graph.parser import parse_fx_module from nnscaler.graph.tracer import concrete_trace from nnscaler.graph.tracer.wrap_utils import Location, is_autograd_apply, LeafWrapInfo from nnscaler.graph.tracer.torch_fx_patcher import side_effectful_inplace_ops @@ -149,7 +149,7 @@ def to_ir_graph( _logger.info(f"constant folding {'enabled' if constant_folding else 'disabled'} to parse graph") with no_save_tensor_hook(): - inputs, nodes, outputs = FxModuleParser.parse( + inputs, nodes, outputs = parse_fx_module( traced_model, dummy_input, attr_savedir=attr_savedir, constant_folding=constant_folding, diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index a766c7d9..7c6ca0b0 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -39,14 +39,14 @@ class FxModuleParser: ATTR_CONTENT_FILE_FORMAT = '{stem}.{idx}' ATTR_MAP_FILE = 'dist_param_map.pt' - @staticmethod - def parse(module: torch.fx.GraphModule, + def __init__(self, + module: torch.fx.GraphModule, dummy_inputs: Dict[str, Any], attr_savedir='./', *, save_content: bool = True, constant_folding: bool = False - ) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + ): """Parse torch.fx module into cube IR The overall entry to parse a torch.fx graph module @@ -57,6 +57,24 @@ def parse(module: torch.fx.GraphModule, attr_savedir (str): the directory to save the attribute content save_content (bool): whether to save the content of the module constant_folding (bool): whether to parse the module with constant folding + """ + + self.module = module + + self.dummy_inputs = dummy_inputs + assert isinstance(dummy_inputs, dict), f"Expected dummy inputs to parse module, but got {dummy_inputs} of type {type(dummy_inputs)}" + + self.attr_savedir = attr_savedir + self.save_content = save_content + self.constant_folding = constant_folding + + self.frame = Frame() + self.value_tracker = ValueTracker() + + def parse(self) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + """Parse torch.fx module into cube IR + + The overall entry to parse a torch.fx graph module Returns: inputs (List[IRObject]): the input IRObjects @@ -68,12 +86,10 @@ def parse(module: torch.fx.GraphModule, # (Those ops creators include user registered function, all functions returning tensors and more) # We will connect the real op outputs (saved in frame) to all ir op outputs and inputs later. - frame = Frame() - frame.push_var() + self.frame.push_var() # shape propagation - assert isinstance(dummy_inputs, dict), f"Expected dummy inputs to parse module, but got {dummy_inputs} of type {type(dummy_inputs)}" - output_nodes = [node for node in module.graph.nodes if node.op == 'output'] + output_nodes = [node for node in self.module.graph.nodes if node.op == 'output'] # currently fx graph always has only one output # even if a tuple/list is returned, it is still just one output assert len(output_nodes) == 1, f"Expect only one output, but got {len(output_nodes)}" @@ -82,11 +98,11 @@ def parse(module: torch.fx.GraphModule, assert len(output_node.args) == 1 and len(output_node.kwargs) == 0 # create IRObjects and IRTensors - for node in module.graph.nodes: + for node in self.module.graph.nodes: if node.op == 'placeholder': - FxModuleParser.init_objects(node, module, frame, is_constant=False) + self._init_objects(node, is_constant=False) else: - FxModuleParser.init_objects(node, module, frame, is_constant=True) + self._init_objects(node, is_constant=True) # note the output node will be reset later by `parse_prim_output_node` # with the help of `parse_complex` @@ -99,18 +115,16 @@ def parse(module: torch.fx.GraphModule, # to make sure the IRGraph has the correct output number # see `IRGrpah.from_logic_graph` - val = frame.get_var(node.name) + val = self.frame.get_var(node.name) if node == output_node.args[0] \ and IR.is_object(val) and isinstance(val.value, tuple): tuple_val = tuple(IRObject(name=node.name, value=v, is_constant=val.is_constant) for v in val.value) - frame.set_var(node.name, tuple_val) - - value_tracker = ValueTracker() + self.frame.set_var(node.name, tuple_val) # get graph inputs - placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] - inputs = [frame.get_var(n.name) for n in placeholders] - value_tracker.track_values(inputs) + placeholders = [n for n in self.module.graph.nodes if n.op == 'placeholder'] + inputs = [self.frame.get_var(n.name) for n in placeholders] + self.value_tracker.track_values(inputs) # - if the graph inputs contain nested strcuture, # it should be wrapped into an IRObject for idx, placeholder in enumerate(placeholders): @@ -118,52 +132,49 @@ def parse(module: torch.fx.GraphModule, obj = IRObject(name=placeholder.target, value=inputs[idx], is_constant=False) obj.value_track.with_no_dep() inputs[idx] = obj - value_tracker.track_values([obj]) - frame.set_var(placeholder.name, obj) + self.value_tracker.track_values([obj]) + self.frame.set_var(placeholder.name, obj) # parse graph nodes all_ir_nodes = [] - for node in module.graph.nodes: - ir_nodes = FxModuleParser.parse_node(node, module, constant_folding, frame) + for node in self.module.graph.nodes: + ir_nodes = self._parse_node(node) all_ir_nodes += ir_nodes - value_tracker.track_nodes(all_ir_nodes) + self.value_tracker.track_nodes(all_ir_nodes) # get graph outputs - outputs = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + outputs = [self.frame.get_var(node.name) for node in self.module.graph.nodes if node.op == 'output'] # currently fx graph always has only one output # even if a tuple/list is returned, it is still just one output assert len(outputs) == 1, f"Expect only one output, but got {len(outputs)}" - if save_content: - attr_savedir = Path(attr_savedir) - frame.save_attr_content(attr_savedir / FxModuleParser.ATTR_CONTENT_FILE_STEM) - frame.save_attr_map(attr_savedir / FxModuleParser.ATTR_MAP_FILE) + if self.save_content: + attr_savedir = Path(self.attr_savedir) + self.frame.save_attr_content(attr_savedir / self.ATTR_CONTENT_FILE_STEM) + self.frame.save_attr_map(attr_savedir / self.ATTR_MAP_FILE) - frame.pop_var() + self.frame.pop_var() return inputs, all_ir_nodes, outputs - @staticmethod - def parse_node(node: torch.fx.Node, module, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: + def _parse_node(self, node: torch.fx.Node) -> List[IRFwOperation]: """ Parse the node and return the IRFwOperation nodes """ if node.op == 'placeholder': return [] if node.op == 'output': - return FxModuleParser.parse_prim_output_node(node, module, frame) + return self._parse_prim_output_node(node) if node.op in ('call_function', 'call_method'): - return FxModuleParser.parse_prim_function_method(node, module, constant_folding, frame) + return self._parse_prim_function_method(node) if node.op == 'get_attr': - return FxModuleParser.parse_prim_get_attr_node(node, module, frame) + return self._parse_prim_get_attr_node(node) if node.op == 'call_module': - return FxModuleParser.parse_prim_module(node, module, frame) + return self._parse_prim_module(node) else: raise TypeError(f"Unknown node kind {node.op}") - @staticmethod - def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, - frame: Frame, is_constant: bool = True): + def _init_objects(self, node: torch.fx.Node, is_constant: bool = True): assert isinstance(node, torch.fx.Node) assert hasattr(node, 'meta') and 'tensor_meta' in node.meta, f"Node {node} should have tensor_meta" @@ -188,10 +199,9 @@ def set_no_dep(x: IRObject): x.value_track.with_no_dep() IR.modify_objects(val, set_no_dep) - frame.add_var(node.name, val) + self.frame.add_var(node.name, val) - @staticmethod - def parse_complex(val: Any, frame: Frame) -> Any: + def _parse_complex(self, val: Any) -> Any: """parse complex fx.Node into IRObject The val is usually from a node's input or output, can be fx.Node nested @@ -207,28 +217,28 @@ def parse_complex(val: Any, frame: Frame) -> Any: # to support more nested types, we can refer to the implementation of # https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py if isinstance(val, tuple): - return tuple(FxModuleParser.parse_complex(t, frame) for t in val) + return tuple(self._parse_complex(t) for t in val) if isinstance(val, list): - return list(FxModuleParser.parse_complex(t, frame) for t in val) + return list(self._parse_complex(t) for t in val) if isinstance(val, dict): - return {key: FxModuleParser.parse_complex(val, frame) for key, val in val.items()} + return {key: self._parse_complex(val) for key, val in val.items()} # TODO: Currently slice/DICT_VALUES_TYPE/DICT_ITEMS_TYPE cases are never found. # We need to find some examples to test them. if isinstance(val, slice): - return slice(FxModuleParser.parse_complex(val.start, frame), - FxModuleParser.parse_complex(val.stop, frame), - FxModuleParser.parse_complex(val.step, frame)) + return slice(self._parse_complex(val.start), + self._parse_complex(val.stop), + self._parse_complex(val.step)) # because fx node cannot be a dict key, so skip DICT_KEYS_TYPE here if isinstance(val, DICT_VALUES_TYPE): - return tuple(FxModuleParser.parse_complex(x, frame) for x in val) + return tuple(self._parse_complex(x) for x in val) if isinstance(val, DICT_ITEMS_TYPE): - return tuple((i, FxModuleParser.parse_complex(x, frame)) for i, x in val) + return tuple((i, self._parse_complex(x)) for i, x in val) if isinstance(val, torch.fx.Node): - return frame.get_var(val.name) + return self.frame.get_var(val.name) return val - @staticmethod - def fetch_attr(mod: torch.fx.GraphModule, target: str): + @classmethod + def _fetch_attr(cls, mod: torch.fx.GraphModule, target: str): target_atoms = target.split('.') attr_itr = mod for i, atom in enumerate(target_atoms): @@ -237,23 +247,19 @@ def fetch_attr(mod: torch.fx.GraphModule, target: str): attr_itr = getattr(attr_itr, atom) return attr_itr - @staticmethod - def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: - prim_module = FxModuleParser.fetch_attr(module, node.target) + def _parse_prim_module(self, node: torch.fx.Node) -> List[IRFwOperation]: + prim_module = self._fetch_attr(self.module, node.target) if prim_module.__class__.__module__.startswith('torch.nn.modules'): raise RuntimeError(f'{prim_module.__class__.__module__} can not be parsed as leaf nodes') else: raise RuntimeError(f'unknown module: {prim_module.__class__.__module__}') - @staticmethod - def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: + def _parse_prim_function_method(self, node: torch.fx.Node) -> List[IRFwOperation]: """ Convert `call_function`/`call_method` op to IRFwOperation. Args: node (torch.fx.Node): the node to be parsed - module (torch.fx.GraphModule): the module containing the node - constant_folding (bool): global setting of whether to fold the constant Returns: List[IRFwOperation]: the IRFwOperation nodes. @@ -262,10 +268,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule """ # get signature - fsig = FxModuleParser._get_qualified_name(node.target, node) + fsig = self._get_qualified_name(node.target, node) # get inputs - input_vals = FxModuleParser.parse_complex(list(node.args), frame) - kwargs = FxModuleParser.parse_complex(node.kwargs, frame) + input_vals = self._parse_complex(list(node.args)) + kwargs = self._parse_complex(node.kwargs) # use context constant_folding if set # Please note constant_folding only controls the output of the op @@ -273,6 +279,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # when we enter the code block with different constant folding setting # as a workaround, # you can use `nnscaler.runtime.function.fold_constant` to fold inputs if needed + constant_folding = self.constant_folding op_context: Optional[Dict[str, Any]] = node.meta.get('op_context') if op_context is not None and op_context.get(fields(OpContext).constant_folding) is not None: constant_folding = op_context[fields(OpContext).constant_folding] @@ -282,12 +289,12 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule else: # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator - if FxModuleParser._is_torch_autograd_op(node, frame, fsig): + if self._is_torch_autograd_op(node, fsig): _logger.warning(f'Find unknown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fsig ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) # case2: custom autograd function - elif FxModuleParser._is_custom_autograd_op(node): + elif self._is_custom_autograd_op(node): # custom autograd function _logger.warning(f'Find unknown custom autograd operation: {fsig}. You should register it with nnscaler.register_op') ir_node = IRFwOperation(fsig, fsig, input_vals, 1, **kwargs) @@ -300,7 +307,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule 'You can register it as a customized function using nnscaler.register_op to remove this warning' _logger.warning(warning_msg) is_constant = False - output = frame.get_var(node.name) + output = self.frame.get_var(node.name) if not isinstance(output, IRObject): # avoid nested IRObject output = IRObject(name=node.name, value=output, is_constant=is_constant) @@ -316,10 +323,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # As node is deleted, we must set concrete value or IRTensor/IRObject into framework. # TODO: check the value saved in frame should equal to the value returned by the op - frame.set_var(node.name, ir_node) + self.frame.set_var(node.name, ir_node) return [] - FxModuleParser._set_node_meta(node, ir_node) + self._set_node_meta(node, ir_node) # step 1: align the node output with the value in frame @@ -331,11 +338,11 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # but its output is used in other nodes. # By removing from frame, # we can catch the case earlier - frame.del_val(node.name) + self.frame.del_val(node.name) # if the function has no output, just return return [ir_node] - vals: Union[Any, IRObject, List[IRObject], IRTensor, List[IRTensor]] = frame.get_var(node.name) + vals: Union[Any, IRObject, List[IRObject], IRTensor, List[IRTensor]] = self.frame.get_var(node.name) if len(ir_node.outputs()) == 1: vals = [vals] elif IR.is_object(vals): @@ -348,7 +355,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule if not isinstance(vals, (list, tuple)): raise RuntimeError(f'Expect list or tuple for multiple outputs, but got {type(vals)}') vals = type(vals)(IRObject(name=node.name, value=v, is_constant=is_constant) for v in vals) - frame.set_var(node.name, vals) + self.frame.set_var(node.name, vals) # verify the inferred shape are consistent with actual output if isinstance(ir_node, IRFwOperation): @@ -372,7 +379,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # update frame with ir output # Please note when there is only one output, we will unwrap it from `ir_node.outputs()` here - frame.set_var( + self.frame.set_var( node.name, type(vals)(ir_node.outputs()) if len(ir_node.outputs()) > 1 else ir_node.output(0) ) @@ -380,6 +387,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # update the name of output tensors # Note assignment is not allowed in lambda # so we use a helper function to update the name + def _update_name(x: IRObject): x.name = node.name IR.modify_objects_inplace(ir_node.outputs(), _update_name) @@ -415,13 +423,12 @@ def _is_primitive_type(val): and not ir_node.signature.startswith(nnscaler.runtime.function.__name__ + '.')\ and not IR.contains_object(ir_node.output(0), lambda x: isinstance(x, IRTensor) or not x.is_constant) \ and _is_primitive_type(cval := IR.try_unwrap(ir_node.output(0))): - frame.set_var(node.name, cval) + self.frame.set_var(node.name, cval) return [] else: return [ir_node] - @staticmethod - def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + def _parse_prim_get_attr_node(self, node: torch.fx.Node) -> List[IRFwOperation]: """ There are two types of get_attr, one is `FxNodeKind.PrimGetAttr` which is dealt with in this function. The other is `FxNodeKind.PrimCallFunction ` (i.e., ) @@ -431,20 +438,20 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, node.target is the attribute name of the object. """ ir_nodes = [] - concrete_value = FxModuleParser.fetch_attr(module, node.target) + concrete_value = self._fetch_attr(self.module, node.target) if isinstance(concrete_value, torch.Tensor): assert isinstance(concrete_value, torch.Tensor), \ f"GetAttrPrim: expect tensor but got {type(concrete_value)}" - exist_tensor = frame.get_attr_var(concrete_value) + exist_tensor = self.frame.get_attr_var(concrete_value) # the case that the parameter is the first time used by getattr if not exist_tensor: - tensor: IRFullTensor = frame.get_var(node.name) + tensor: IRFullTensor = self.frame.get_var(node.name) # set tensor name same with the name in original model tensor.name = node.target if tensor.requires_grad: tensor.as_param() else: - direct_module = module + direct_module = self.module full_qualified_name = node.target.split('.') for name in full_qualified_name[:-1]: # last one is the attribute name direct_module = getattr(direct_module, name) @@ -455,38 +462,37 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, for dt in tensor.dim_tracks: dt.with_no_dep() - frame.add_attr(tensor, concrete_value, node.target) + self.frame.add_attr(tensor, concrete_value, node.target) # the case that the parameter is consumed multiple times and registered previously else: - frame.set_var(node.name, exist_tensor) + self.frame.set_var(node.name, exist_tensor) else: assert isinstance(node.target, str), f"GetAttrPrim: expect `node.target` to be str but got {type(node.target)}" # in sub modules, the target is full qualified name (for example `embeddings.dropout.training`) if node.target.split('.')[-1] == 'training': # Let's just support `self.training` and ignore all other cases for now - if isinstance(output := frame.get_var(node.name), IRObject): + if isinstance(output := self.frame.get_var(node.name), IRObject): output.is_constant = False else: output = IRObject(name=node.name, value=output, is_constant=False) ir_node = IRPyFunc(SELF_GETATTR_SIG, ['training'], [output]) - FxModuleParser._set_node_meta(node, ir_node) - frame.set_var(node.name, output) + self._set_node_meta(node, ir_node) + self.frame.set_var(node.name, output) # never fold the IRPyFunc node ir_nodes.append(ir_node) else: - frame.set_var(node.name, concrete_value) + self.frame.set_var(node.name, concrete_value) return ir_nodes - @staticmethod - def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: + def _parse_prim_output_node(self, node: torch.fx.Node) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 - output = FxModuleParser.parse_complex(node.args[0], frame) - frame.set_var(node.name, output) + output = self._parse_complex(node.args[0]) + self.frame.set_var(node.name, output) return [] - @staticmethod - def _set_node_meta(node: torch.fx.Node, ir_node: Union[IRCell, Any]): + @classmethod + def _set_node_meta(cls, node: torch.fx.Node, ir_node: Union[IRCell, Any]): if not isinstance(ir_node, IRCell): return @@ -497,16 +503,16 @@ def _set_node_meta(node: torch.fx.Node, ir_node: Union[IRCell, Any]): if comment: ir_node.comment = comment - @staticmethod - def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: + @classmethod + def _get_qualified_name(cls, node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: if isinstance(node_target, str): assert node is not None - return FxModuleParser._get_qualified_name_of_call_method(node_target, node) + return cls._get_qualified_name_of_call_method(node_target, node) else: - return FxModuleParser._get_qualified_name_of_call_function(node_target) + return cls._get_qualified_name_of_call_function(node_target) - @staticmethod - def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str: + @classmethod + def _get_qualified_name_of_call_function(cls, node_target: Callable[..., Any]) -> str: """ The target field of call_function node must be an callable object. """ @@ -516,12 +522,12 @@ def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str # TODO(yizhu1): find a general solution assert callable(node_target) name = node_target.__name__ - module = FxModuleParser._find_module_of_method(node_target) + module = cls._find_module_of_method(node_target) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module return f'{module}.{name}' - @staticmethod - def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> str: + @classmethod + def _get_qualified_name_of_call_method(cls, node_target: str, node: torch.fx.Node) -> str: """ The target field of call_method node must be a string. """ @@ -549,8 +555,8 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> else: return f'{in_type.__module__}.{in_type.__name__}.{node_target}' - @staticmethod - def _find_module_of_method(orig_method: Callable[..., Any]) -> str: + @classmethod + def _find_module_of_method(cls, orig_method: Callable[..., Any]) -> str: if getattr(orig_method, '__name__', None) == 'apply' and isinstance(getattr(orig_method, '__self__', None), Type) \ and issubclass(orig_method.__self__, torch.autograd.Function): # for torch.autograd.Function @@ -583,18 +589,49 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str: return guess.__name__ raise RuntimeError(f'cannot find module for {orig_method}') - @staticmethod - def _is_torch_autograd_op(node: torch.fx.Node, frame: Frame, signature: str) -> bool: + def _is_torch_autograd_op(self, node: torch.fx.Node, signature: str) -> bool: """Check whether the node is of a pytorch autograd operation.""" # note: some python operations like torch.Tensor.size() doesn't return # an IRTensor, thus cannot be considered as a pytorch autograd operator. return signature.startswith('torch.') and \ - isinstance(frame.get_var(node.name), IRFullTensor) + isinstance(self.frame.get_var(node.name), IRFullTensor) - @staticmethod - def _is_custom_autograd_op(node: torch.fx.Node) -> bool: + @classmethod + def _is_custom_autograd_op(cls, node: torch.fx.Node) -> bool: node_target = node.target return callable(node_target) \ and getattr(node_target, '__name__', None) == 'apply' \ and isinstance(getattr(node_target, '__self__', None), Type) \ and issubclass(node_target.__self__, torch.autograd.Function) + + +def parse_fx_module( + module: torch.fx.GraphModule, + dummy_inputs: Dict[str, Any], + attr_savedir='./', + *, + save_content: bool = True, + constant_folding: bool = False +) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + """Parse torch.fx module into cube IR + + The overall entry to parse a torch.fx graph module + + Args: + module (torch.fx.GraphModule): the torch.fx module + dummy_inputs (Dict[str, Any]): the dummy inputs to run the module + attr_savedir (str): the directory to save the attribute content + constant_folding (bool): whether to parse the module with constant folding + + Returns: + inputs (List[IRObject]): the input IRObjects + all_ir_nodes (List[IRFwOperation]): the IRFwOperation nodes + outputs (List[IRObject]): the output IRObjects + """ + return FxModuleParser( + module, + dummy_inputs, + attr_savedir, + save_content=save_content, + constant_folding=constant_folding + ).parse() From 9d4b8bab58cb89c7b9f6a5e52bd9e45e4c324711 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 27 Oct 2025 08:38:33 +0000 Subject: [PATCH 1846/1892] Merged PR 2417: [Runtime] Refine ring attention related implementation - add test - add benchmark - integrate transformer engine for varlen --- benchmark/README.md | 357 +++++++++++++++ benchmark/benchmark_base.py | 426 ++++++++++++++++++ benchmark/benchmark_ring_attn.py | 206 +++++++++ benchmark/benchmark_ring_attn_varlen.py | 222 +++++++++ benchmark/benchmark_zigzag_attn.py | 243 ++++++++++ examples/llama/customized_ops/README.md | 22 - .../ring_attention/ring_attn_varlen.py | 145 ------ .../llama/customized_ops/test_ring_attn.py | 123 ----- .../customized_ops/test_ring_attn_varlen.py | 129 ------ .../llama/customized_ops/test_zigzag_attn.py | 123 ----- nnscaler/customized_ops/__init__.py | 0 .../customized_ops/ring_attention/README.md | 219 +++++++++ .../customized_ops/ring_attention/__init__.py | 0 .../core/ring_attn_implementation.py | 62 ++- .../core/ring_attn_varlen_implementation.py | 16 +- .../ring_attention/core/utils.py | 30 +- .../core/zigzag_attn_implementation.py | 0 .../ring_attention/ring_attn.py | 6 +- .../ring_attention/ring_attn_varlen.py | 308 +++++++++++++ .../ring_attention/varlen_utils.py | 182 ++++++++ .../ring_attention/zigzag_attn.py | 0 tests/customized_ops/__init__.py | 4 + tests/customized_ops/ring_attn/configs.py | 268 +++++++++++ .../ring_attn/ring_attn_runner.py | 107 +++++ .../ring_attn/ring_attn_varlen_runner.py | 96 ++++ tests/customized_ops/ring_attn/runner_base.py | 290 ++++++++++++ tests/customized_ops/ring_attn/test_base.py | 211 +++++++++ .../ring_attn/test_ring_attn.py | 56 +++ .../ring_attn/test_ring_attn_varlen.py | 53 +++ .../ring_attn/test_zigzag_attn.py | 73 +++ .../ring_attn/zigzag_attn_runner.py | 96 ++++ 31 files changed, 3498 insertions(+), 575 deletions(-) create mode 100644 benchmark/README.md create mode 100644 benchmark/benchmark_base.py create mode 100644 benchmark/benchmark_ring_attn.py create mode 100644 benchmark/benchmark_ring_attn_varlen.py create mode 100644 benchmark/benchmark_zigzag_attn.py delete mode 100644 examples/llama/customized_ops/README.md delete mode 100644 examples/llama/customized_ops/ring_attention/ring_attn_varlen.py delete mode 100644 examples/llama/customized_ops/test_ring_attn.py delete mode 100644 examples/llama/customized_ops/test_ring_attn_varlen.py delete mode 100644 examples/llama/customized_ops/test_zigzag_attn.py create mode 100644 nnscaler/customized_ops/__init__.py create mode 100644 nnscaler/customized_ops/ring_attention/README.md rename {examples/llama => nnscaler}/customized_ops/ring_attention/__init__.py (100%) rename {examples/llama => nnscaler}/customized_ops/ring_attention/core/ring_attn_implementation.py (78%) rename {examples/llama => nnscaler}/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py (96%) rename {examples/llama => nnscaler}/customized_ops/ring_attention/core/utils.py (87%) rename {examples/llama => nnscaler}/customized_ops/ring_attention/core/zigzag_attn_implementation.py (100%) rename {examples/llama => nnscaler}/customized_ops/ring_attention/ring_attn.py (95%) create mode 100644 nnscaler/customized_ops/ring_attention/ring_attn_varlen.py create mode 100644 nnscaler/customized_ops/ring_attention/varlen_utils.py rename {examples/llama => nnscaler}/customized_ops/ring_attention/zigzag_attn.py (100%) create mode 100644 tests/customized_ops/__init__.py create mode 100644 tests/customized_ops/ring_attn/configs.py create mode 100644 tests/customized_ops/ring_attn/ring_attn_runner.py create mode 100644 tests/customized_ops/ring_attn/ring_attn_varlen_runner.py create mode 100644 tests/customized_ops/ring_attn/runner_base.py create mode 100644 tests/customized_ops/ring_attn/test_base.py create mode 100644 tests/customized_ops/ring_attn/test_ring_attn.py create mode 100644 tests/customized_ops/ring_attn/test_ring_attn_varlen.py create mode 100644 tests/customized_ops/ring_attn/test_zigzag_attn.py create mode 100644 tests/customized_ops/ring_attn/zigzag_attn_runner.py diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000..1cd09a68 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,357 @@ +# Ring Attention Performance Benchmarks + +This directory contains a unified performance benchmarking framework for all Ring Attention variants, built using a shared architecture that eliminates code duplication and provides consistent interfaces. + +## 🏗️ Architecture + +The benchmark framework consists of: + +### Core Framework +- **`benchmark_base.py`**: Shared benchmark framework extending the test framework +- **Configuration System**: Unified configuration management via `../tests/customized_ops/ring_attn/configs.py` + +### Attention Implementations +- **`benchmark_ring_attn.py`**: Standard Ring Attention benchmarks +- **`benchmark_ring_attn_varlen.py`**: Variable Length Ring Attention benchmarks +- **`benchmark_zigzag_attn.py`**: Zigzag Ring Attention benchmarks (causal-only) + +## 🚀 Quick Start + +### 1. List Available Configurations + +```bash +cd benchmark + +# List configurations for any benchmark variant +python benchmark_ring_attn_varlen.py --list-configs +python benchmark_ring_attn.py --list-configs +python benchmark_zigzag_attn.py --list-configs +``` + +### 2. Run Basic Benchmarks + +```bash +# Ring Attention Variable Length +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium + +# Standard Ring Attention +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config small + +# Zigzag Ring Attention (causal-only) +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config tiny +``` + +### 3. Advanced Usage + +```bash +# Custom timing parameters +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium --timing-method warmup --warmup-runs 5 --timing-runs 10 + +# Detailed profiling +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config large --timing-method profiler + +# Custom configurations (legacy support) +torchrun --nproc_per_node=2 benchmark_ring_attn.py --seqlen 8192 --nheads 16 --head-dim 128 --batch-size 4 +``` + +## 📋 Available Configurations + +The benchmark framework uses a comprehensive configuration system with predefined configurations for different testing scenarios. + +### Configuration Categories + +#### Small Configs (Quick Testing) +- **`tiny`**: 2×8×64, seq=1024, tokens=1K, bf16 [Causal] +- **`small`**: 4×12×128, seq=4096, tokens=4K, bf16 [Causal] +- **`small_fp16`**: 4×12×128, seq=4096, tokens=4K, fp16 [Non-causal] +- **`small_window`**: 4×12×128, seq=4096, tokens=4K, bf16 [Causal] [Window=512,0] + +#### Medium Configs (Standard Testing) +- **`medium`**: 4×24×128, seq=8192, tokens=8K, bf16 [Causal] +- **`medium_large_head`**: 4×12×256, seq=8192, tokens=8K, bf16 [Non-causal] +- **`medium_many_heads`**: 4×32×128, seq=8192, tokens=8K, bf16 [Causal] +- **`medium_fp16`**: 4×24×128, seq=8192, tokens=8K, fp16 [Causal] +- **`medium_window`**: 4×24×128, seq=8192, tokens=8K, bf16 [Causal] [Window=512,0] + +#### Large Configs (Performance Testing) +- **`large`**: 4×32×128, seq=16384, tokens=16K, bf16 [Causal] +- **`large_seq`**: 4×24×128, seq=32768, tokens=32K, bf16 [Causal] +- **`large_head`**: 4×24×256, seq=16384, tokens=16K, bf16 [Non-causal] +- **`xlarge`**: 8×32×128, seq=32768, tokens=32K, bf16 [Causal] +- **`large_window`**: 4×32×128, seq=16384, tokens=16K, bf16 [Causal] [Window=512,0] + +#### GQA Configs (Grouped Query Attention) +- **`qwen3_235b_a22b`**: 2×64×64, seq=16384, tokens=16K, bf16 (GQA 64→4) [Causal] +- **`qwen3_30b_a3b`**: 4×32×64, seq=16384, tokens=16K, bf16 (GQA 32→4) [Causal] +- **`qwen3_4b`**: 4×32×80, seq=16384, tokens=16K, bf16 (GQA 32→4) [Causal] +- **`qwen3_32b`**: 2×64×128, seq=16384, tokens=16K, bf16 (GQA 64→8) [Causal] +- **`qwen3_14b`**: 4×40×128, seq=16384, tokens=16K, bf16 (GQA 40→8) [Causal] + +#### Zigzag Configs (Causal-Only) +- **`zigzag_tiny`**: 2×8×64, seq=1024, tokens=1K, bf16 [Causal] +- **`zigzag_small`**: 4×12×128, seq=4096, tokens=4K, bf16 [Causal] +- **`zigzag_medium`**: 4×24×128, seq=8192, tokens=8K, bf16 [Causal] +- **`zigzag_large`**: 4×32×128, seq=16384, tokens=16K, bf16 [Causal] +- **`zigzag_fp16`**: 4×12×128, seq=4096, tokens=4K, fp16 [Causal] +- **`zigzag_gqa`**: 4×32×128, seq=8192, tokens=8K, bf16 (GQA 32→8) [Causal] + +### Default Configuration Sets +- **Correctness Testing**: `["tiny", "small", "medium"]` +- **Performance Testing**: `["medium", "large"]` +- **Multi-GPU Testing**: `["small", "medium"]` +- **GQA Testing**: `["qwen3_4b", "qwen3_14b", "qwen3_32b"]` +- **Zigzag Testing**: `["zigzag_tiny", "zigzag_small", "zigzag_medium"]` + +## 🔧 Features + +### Unified Framework +- **Shared Base Class**: All benchmarks extend `RingAttnBenchmarkBase` for consistency +- **Code Reuse**: Leverages test framework components (`test_base.py`, `runner_base.py`) +- **Consistent Interface**: Same command-line options across all attention variants + +### Multiple Timing Methods +- **`simple`**: Basic CUDA timing measurements (fastest) +- **`warmup`**: Multiple runs with warm-up (recommended for accurate results) +- **`profiler`**: torch.profiler with detailed kernel analysis + +### Comprehensive Metrics +- **Performance**: Forward/backward timing, throughput (tokens/sec) +- **Scalability**: Speedup analysis, parallel efficiency +- **Memory**: GPU memory usage tracking +- **Comparative**: Single vs. parallel mode analysis + +### Configuration Support +- **Predefined Configs**: 20+ predefined configurations covering different scales +- **Legacy Parameters**: Backward compatibility with custom parameters +- **Attention Variants**: Support for standard, variable-length, and zigzag attention +- **GQA Support**: Grouped Query Attention configurations based on Qwen models + +## 🧪 Usage Examples + +### Basic Performance Testing +```bash +# Quick benchmarks with different attention types +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config tiny --timing-method simple +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config small --timing-method warmup +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config medium --dtype fp16 +``` + +### Comparative Analysis +```bash +# Compare different attention mechanisms on same config +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --timing-method warmup +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium --timing-method warmup +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config medium --timing-method warmup +``` + +### Advanced Profiling +```bash +# Detailed profiler analysis +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config large --timing-method profiler + +# Custom timing parameters for high precision +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --timing-method warmup --warmup-runs 10 --timing-runs 20 +``` + +### GQA Performance Testing +```bash +# Test Grouped Query Attention configurations +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config qwen3_4b --timing-method warmup +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config qwen3_14b --timing-method warmup +``` + +### Legacy Support (Custom Parameters) +```bash +# Override specific parameters while using predefined base +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --seqlen 16384 --nheads 32 + +# Full custom configuration +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --seqlen 8192 --nheads 16 --head-dim 128 --batch-size 4 --dtype bf16 +``` + +## 📈 Output Interpretation + +The benchmark framework provides comprehensive performance analysis: + +### Performance Metrics +``` +================================================================================ +RING ATTENTION VARIABLE LENGTH PERFORMANCE BENCHMARK (WARMUP METHOD) +Configuration: medium - medium + Sequence length: 8192 + Batch size: 4 + Heads: 24 + Head dim: 128 + Data type: bf16 + World size: 2 GPUs + Total tokens: 8,192 + (Warmup runs: 3, Timing runs: 5) +================================================================================ +Single Mode: + Forward time: 0.001234 seconds + Backward time: 0.002345 seconds + Total time: 0.003579 seconds + Throughput: 2288764 tokens/sec + +Parallel Mode: + Forward time: 0.000987 seconds + Backward time: 0.001654 seconds + Total time: 0.002641 seconds + Throughput: 3102234 tokens/sec + +Speedup: + Forward speedup: 1.25x + Backward speedup: 1.42x + Total speedup: 1.35x + Throughput improvement: 1.35x + +Efficiency: + Theoretical speedup: 2x + Actual speedup: 1.35x + Parallel efficiency: 67.7% +================================================================================ +``` + +### Key Metrics Explained +- **Forward/Backward Time**: Separate timing for forward and backward passes +- **Throughput**: Tokens processed per second (higher = better) +- **Speedup**: Performance ratio vs single GPU (higher = better) +- **Parallel Efficiency**: Actual speedup / theoretical speedup (closer to 100% = better) + +### Profiler Output (when using `--timing-method profiler`) +When using the profiler method, you get additional detailed analysis: +- Kernel-level timing breakdown +- Memory bandwidth utilization +- CUDA kernel execution patterns +- Optimization recommendations + +## 🎯 Attention Variant Characteristics + +### Ring Attention (`benchmark_ring_attn.py`) +- **Format**: Standard batch format `[batch_size, seq_len, num_heads, head_dim]` +- **Use Case**: General purpose attention for standard transformer models +- **Constraints**: Supports both causal and non-causal attention, sliding windows + +### Ring Attention Variable Length (`benchmark_ring_attn_varlen.py`) +- **Format**: Packed format `[total_tokens, num_heads, head_dim]` with `cu_seqlens` +- **Use Case**: Optimized for variable-length sequences, eliminates padding waste +- **Constraints**: Supports causal/non-causal attention, sliding windows + +### Zigzag Attention (`benchmark_zigzag_attn.py`) +- **Format**: Standard batch format `[batch_size, seq_len, num_heads, head_dim]` +- **Use Case**: Specialized for causal attention with optimized communication pattern +- **Constraints**: **Only supports causal=True and window_size=(-1,-1)** + +## 🔗 Integration with Test Framework + +The benchmark framework is tightly integrated with the correctness test framework: + +### Shared Components +- **Configuration System**: Same `configs.py` used for both correctness and performance testing +- **Base Classes**: Reuses `RingAttnRunnerBase` from `runner_base.py` +- **Distributed Setup**: Shared GPU detection and distributed initialization +- **Error Handling**: Consistent tolerance and validation logic + +### Workflow Integration +```bash +# 1. Run correctness tests first +cd /path/to/MagicCube +pytest tests/customized_ops/ring_attn/test_ring_attn_varlen.py --config tiny + +# 2. Then run performance benchmarks +cd benchmark +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config tiny +``` + +## ⚠️ Requirements & Setup + +### System Requirements +- **Multi-GPU Setup**: Most benchmarks require 2+ GPUs (use `torchrun --nproc_per_node=N`) +- **GPU Memory**: Large configs may require high-memory GPUs (A100, H100 recommended) +- **CUDA**: Compatible CUDA installation (11.8+ recommended) +- **Python Environment**: PyTorch with NCCL support for distributed training + +### Optional Components +- **TransformerEngine**: Install TE 2.2.0+ for optimal performance (auto-detected) +- **Flash Attention**: Required for base attention implementations +- **InfiniBand**: Recommended for multi-node setups (reduces communication latency) + +### Environment Setup +```bash +# From MagicCube root directory +cd benchmark + +# Verify imports work correctly +python -c " +from benchmark_base import RingAttnBenchmarkBase +print('✓ Benchmark framework ready') +" + +# Test configuration system +python benchmark_ring_attn_varlen.py --list-configs +``` + +## 🚨 Troubleshooting + +### Common Issues + +#### GPU/Memory Issues +```bash +# OOM errors: Use smaller configs or reduce batch size +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config tiny # Instead of large + +# Insufficient GPUs: Check available GPUs +python -c "import torch; print(f'Available GPUs: {torch.cuda.device_count()}')" +``` + +#### Import/Path Issues +```bash +# Import errors: Ensure running from correct directory +cd /path/to/MagicCube/benchmark +python benchmark_ring_attn.py --help + +# Configuration import errors +python -c " +import sys, os +sys.path.insert(0, '../tests/customized_ops/ring_attn') +from configs import get_config +print('✓ Config system working') +" +``` + +#### Distributed Training Issues +```bash +# NCCL errors: Check GPU compatibility and CUDA setup +export NCCL_DEBUG=INFO # For detailed NCCL debugging + +# Port conflicts: Use different port +torchrun --master_port=29501 --nproc_per_node=2 benchmark_ring_attn.py --config tiny +``` + +### Performance Debugging +```bash +# Test basic functionality without distributed training +CUDA_VISIBLE_DEVICES=0 python -c " +from benchmark_ring_attn import RingAttnBenchmark +print('✓ Benchmark classes load correctly') +" + +# Verify attention implementations work +cd ../tests/customized_ops/ring_attn +pytest test_ring_attn.py::TestRingAttn::test_ring_attn_tiny -v +``` + +**Note**: Actual efficiency depends on hardware, network, and system configuration. + +## 📚 Related Documentation + +### Core Documentation +- **Ring Attention Implementation**: `../nnscaler/customized_ops/ring_attention/README.md` +- **Test Framework**: `../tests/customized_ops/ring_attn/README.md` +- **Development Guide**: `../dev_docs/README_refactoring.md` +- **Testing Results**: `../dev_docs/benchmark_testing_results.md` + +--- + +**For implementation details**: See `../nnscaler/customized_ops/ring_attention/` +**For correctness testing**: See `../tests/customized_ops/ring_attn/` \ No newline at end of file diff --git a/benchmark/benchmark_base.py b/benchmark/benchmark_base.py new file mode 100644 index 00000000..4226d2ed --- /dev/null +++ b/benchmark/benchmark_base.py @@ -0,0 +1,426 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base benchmark framework for ring attention performance tests. +This module extends the test framework to support performance benchmarking. +""" + +import os +import sys +import time +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Tuple, Callable + +import torch +import torch.distributed as dist +from torch.profiler import profile, ProfilerActivity + +# Add tests directory to path to import test framework +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) + +from runner_base import RingAttnRunnerBase +from configs import get_config, get_configs_by_category, DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnBenchmarkBase(RingAttnRunnerBase): + """Base class for ring attention performance benchmarks""" + + def __init__(self): + super().__init__() + self.timing_method = "warmup" + self.warmup_runs = 3 + self.timing_runs = 5 + + @abstractmethod + def get_benchmark_name(self) -> str: + """Return the benchmark name for display""" + pass + + def run_timing_with_warmup(self, forward_fn: Callable, backward_fn: Callable, + warmup_runs: int = None, timing_runs: int = None) -> Tuple[float, float, Any]: + """Run timing with warm-up runs to get accurate measurements.""" + warmup_runs = warmup_runs or self.warmup_runs + timing_runs = timing_runs or self.timing_runs + + # Warm-up runs + for _ in range(warmup_runs): + torch.cuda.synchronize() + output = forward_fn() + torch.cuda.synchronize() + backward_fn(output) + torch.cuda.synchronize() + + # Timing runs + forward_times = [] + backward_times = [] + + for _ in range(timing_runs): + # Forward timing + torch.cuda.synchronize() + start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_time = time.perf_counter() - start + forward_times.append(forward_time) + + # Backward timing + torch.cuda.synchronize() + start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_time = time.perf_counter() - start + backward_times.append(backward_time) + + # Return average times + avg_forward = sum(forward_times) / len(forward_times) + avg_backward = sum(backward_times) / len(backward_times) + return avg_forward, avg_backward, output + + def run_timing_with_profiler(self, forward_fn: Callable, backward_fn: Callable, + rank_id: int = 0) -> Tuple[float, float, Any]: + """Run timing using torch.profiler for detailed analysis.""" + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + # Run profiler with timing + torch.cuda.synchronize() + + with profile(activities=activities, record_shapes=True, with_stack=True) as prof: + torch.cuda.synchronize() + forward_start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_end = time.perf_counter() + + torch.cuda.synchronize() + backward_start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_end = time.perf_counter() + + torch.cuda.synchronize() + + # Calculate timing from our measurements + forward_time = forward_end - forward_start + backward_time = backward_end - backward_start + + if rank_id == 0: + self._print_profiler_results(prof) + + return forward_time, backward_time, output + + def run_timing_simple(self, forward_fn: Callable, backward_fn: Callable) -> Tuple[float, float, Any]: + """Run simple timing without warmup or profiling.""" + torch.cuda.synchronize() + forward_start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_time = time.perf_counter() - forward_start + + torch.cuda.synchronize() + backward_start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_time = time.perf_counter() - backward_start + + return forward_time, backward_time, output + + def _print_profiler_results(self, prof): + """Print profiler results with fallback for different PyTorch versions.""" + print("\n" + "="*60) + print("TORCH PROFILER RESULTS") + print("="*60) + + try: + # Try the most common sorting options + events = prof.key_averages() + table_str = events.table(sort_by="self_cuda_time_total", row_limit=20) + print(table_str) + except Exception as e1: + try: + table_str = events.table(sort_by="cuda_time_total", row_limit=20) + print(table_str) + except Exception as e2: + try: + table_str = events.table(sort_by="self_cpu_time_total", row_limit=20) + print(table_str) + except Exception as e3: + print(f"Warning: Could not generate profiler table due to API differences") + print(f"Errors: {e1}, {e2}, {e3}") + + # Fallback: print basic event info + print("Available profiler events:") + for i, event in enumerate(events): + if i >= 10: # Limit output + break + try: + print(f" {event.key}: CPU time = {getattr(event, 'cpu_time_total', 'N/A')} us") + except: + print(f" {event.key}: [timing info unavailable]") + + print("="*60 + "\n") + + def create_timing_functions(self, inputs, config, dout_tensor): + """Create timing functions for single and parallel execution.""" + # Single mode functions + def single_forward(): + single_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + single_inputs[k] = v.detach().clone().requires_grad_() + else: + single_inputs[k] = v.detach().clone() + else: + single_inputs[k] = v + + # Run single GPU reference + output, grad_tensors = self.run_single_gpu_reference(single_inputs, config) + return output, (single_inputs, grad_tensors) + + def single_backward(outputs): + output, (single_inputs, grad_tensors) = outputs + output.backward(dout_tensor) + return dout_tensor + + # Parallel mode functions + model = self.create_test_module(config) + dummy_args = self.get_dummy_forward_args(inputs) + + from nnscaler.parallel import parallelize, ComputeConfig, ReuseType + world_size = dist.get_world_size() + + parallel_model = parallelize( + model, + dummy_forward_args=dummy_args, + pas_policy=self.create_policy(), + compute_config=ComputeConfig(world_size, world_size), + reuse=ReuseType.OVERRIDE + ) + parallel_model = parallel_model.cuda() + parallel_model.train() + + def parallel_forward(): + para_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + para_inputs[k] = v.detach().clone().requires_grad_() + else: + para_inputs[k] = v.detach().clone() + else: + para_inputs[k] = v + + output = parallel_model(**para_inputs) + return output, para_inputs + + def parallel_backward(outputs): + output, para_inputs = outputs + output.backward(dout_tensor) + parallel_model.sync_grad() + return dout_tensor + + return single_forward, single_backward, parallel_forward, parallel_backward + + def calculate_throughput_metrics(self, config, forward_time: float, backward_time: float) -> Dict[str, float]: + """Calculate throughput and efficiency metrics.""" + total_time = forward_time + backward_time + + # Calculate total tokens processed + if hasattr(config, 'total_tokens'): + total_tokens = config.total_tokens + else: + total_tokens = config.batch_size * config.max_seqlen + + throughput = total_tokens / total_time if total_time > 0 else 0 + + return { + 'total_tokens': total_tokens, + 'throughput_tokens_per_sec': throughput, + 'total_time': total_time, + 'forward_time': forward_time, + 'backward_time': backward_time + } + + def print_benchmark_results(self, config_name: str, config, dtype: str, + single_metrics: Dict[str, float], + parallel_metrics: Dict[str, float], + world_size: int, rank_id: int): + """Print comprehensive benchmark results.""" + if rank_id != 0: + return + + print("\n" + "="*80) + print(f"{self.get_benchmark_name().upper()} PERFORMANCE BENCHMARK ({self.timing_method.upper()} METHOD)") + print(f"Configuration: {config_name} - {config.name}") + print(f" Sequence length: {config.max_seqlen}") + print(f" Batch size: {config.batch_size}") + print(f" Heads: {config.num_heads}") + print(f" Head dim: {config.head_dim}") + print(f" Data type: {dtype}") + print(f" World size: {world_size} GPUs") + print(f" Total tokens: {single_metrics['total_tokens']:,}") + + if self.timing_method == "warmup": + print(f" (Warmup runs: {self.warmup_runs}, Timing runs: {self.timing_runs})") + print("="*80) + + # Timing results + print(f"Single Mode:") + print(f" Forward time: {single_metrics['forward_time']:.6f} seconds") + print(f" Backward time: {single_metrics['backward_time']:.6f} seconds") + print(f" Total time: {single_metrics['total_time']:.6f} seconds") + print(f" Throughput: {single_metrics['throughput_tokens_per_sec']:.0f} tokens/sec") + + print(f"\nParallel Mode:") + print(f" Forward time: {parallel_metrics['forward_time']:.6f} seconds") + print(f" Backward time: {parallel_metrics['backward_time']:.6f} seconds") + print(f" Total time: {parallel_metrics['total_time']:.6f} seconds") + print(f" Throughput: {parallel_metrics['throughput_tokens_per_sec']:.0f} tokens/sec") + + # Speedup calculations + forward_speedup = single_metrics['forward_time'] / parallel_metrics['forward_time'] if parallel_metrics['forward_time'] > 0 else 0 + backward_speedup = single_metrics['backward_time'] / parallel_metrics['backward_time'] if parallel_metrics['backward_time'] > 0 else 0 + total_speedup = single_metrics['total_time'] / parallel_metrics['total_time'] if parallel_metrics['total_time'] > 0 else 0 + throughput_improvement = parallel_metrics['throughput_tokens_per_sec'] / single_metrics['throughput_tokens_per_sec'] if single_metrics['throughput_tokens_per_sec'] > 0 else 0 + + print(f"\nSpeedup:") + print(f" Forward speedup: {forward_speedup:.2f}x") + print(f" Backward speedup: {backward_speedup:.2f}x") + print(f" Total speedup: {total_speedup:.2f}x") + print(f" Throughput improvement: {throughput_improvement:.2f}x") + + # Efficiency metrics + theoretical_speedup = world_size + efficiency = total_speedup / theoretical_speedup * 100 if theoretical_speedup > 0 else 0 + print(f"\nEfficiency:") + print(f" Theoretical speedup: {theoretical_speedup:.0f}x") + print(f" Actual speedup: {total_speedup:.2f}x") + print(f" Parallel efficiency: {efficiency:.1f}%") + print("="*80 + "\n") + + def run_performance_benchmark(self, config_name: str = None, dtype: str = "bf16", + timing_method: str = "warmup", warmup_runs: int = 3, + timing_runs: int = 5, **legacy_kwargs): + """Run performance benchmark for the specific attention implementation.""" + # Setup timing parameters + self.timing_method = timing_method + self.warmup_runs = warmup_runs + self.timing_runs = timing_runs + + # Initialize distributed environment + world_size, rank = self.initialize_distributed() + rank_id = dist.get_rank() + + # Get configuration + config = get_config(config_name) if config_name else self._create_legacy_config(**legacy_kwargs) + torch_dtype = torch.bfloat16 if dtype == "bf16" else torch.float16 + + if rank_id == 0: + print(f"Running {self.get_benchmark_name()} performance benchmark...") + print(f"Configuration: {config.name if hasattr(config, 'name') else 'custom'}") + + # Prepare inputs + device = torch.device(f"cuda:{rank_id}") + inputs = self.prepare_inputs(config, device, torch_dtype) + + # Broadcast inputs to ensure consistency + for tensor in inputs.values(): + if isinstance(tensor, torch.Tensor): + dist.broadcast(tensor, src=0) + dist.barrier() + + # Pre-generate dout tensor for timing consistency + with torch.no_grad(): + dummy_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + dummy_inputs[k] = v.detach() + else: + dummy_inputs[k] = v + dummy_out, _ = self.run_single_gpu_reference(dummy_inputs, config) + dout_tensor = torch.randn_like(dummy_out, device=device, dtype=torch_dtype) + dist.broadcast(dout_tensor, src=0) + + # Create timing functions + single_forward, single_backward, parallel_forward, parallel_backward = self.create_timing_functions( + inputs, config, dout_tensor + ) + + if rank_id == 0: + print(f"Running performance benchmark using {timing_method} method...", end="") + + # Run timing based on method + if timing_method == "profiler": + single_forward_time, single_backward_time, _ = self.run_timing_with_profiler( + single_forward, single_backward, rank_id + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_with_profiler( + parallel_forward, parallel_backward, rank_id + ) + elif timing_method == "warmup": + single_forward_time, single_backward_time, _ = self.run_timing_with_warmup( + single_forward, single_backward, warmup_runs, timing_runs + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_with_warmup( + parallel_forward, parallel_backward, warmup_runs, timing_runs + ) + else: # simple + single_forward_time, single_backward_time, _ = self.run_timing_simple( + single_forward, single_backward + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_simple( + parallel_forward, parallel_backward + ) + + if rank_id == 0: + print(" Done!") + + # Calculate metrics and print results + single_metrics = self.calculate_throughput_metrics(config, single_forward_time, single_backward_time) + parallel_metrics = self.calculate_throughput_metrics(config, parallel_forward_time, parallel_backward_time) + + self.print_benchmark_results( + config_name or "custom", config, dtype, + single_metrics, parallel_metrics, world_size, rank_id + ) + + # Cleanup + dist.destroy_process_group() + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters.""" + class LegacyConfig: + def __init__(self, **kwargs): + self.name = "legacy_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + # Add other default attributes as needed + + return LegacyConfig(**kwargs) + + def list_configurations(self): + """List all available configurations for benchmarking.""" + print("Available Ring Attention Configurations:") + print("=" * 50) + + for category in ["small", "medium", "large", "gqa"]: + print(f"\n{category.upper()} CONFIGS:") + configs = get_configs_by_category(category) + if configs: + for name, config in configs.items(): + tokens_k = config.total_tokens // 1000 + gqa_info = f" (GQA {config.num_heads}->{config.num_kv_heads})" if config.is_gqa else "" + causal_info = " [Causal]" if config.causal else " [Non-causal]" + window_info = f" [Window={config.window_size[0]},{config.window_size[1]}]" if config.window_size != (-1, -1) else "" + print(f" {name:20s} - {config.batch_size}x{config.num_heads}x{config.head_dim}, seq={config.max_seqlen}, tokens={tokens_k}K, {config.dtype}{gqa_info}{causal_info}{window_info}") + else: + print(" No configurations in this category") + + print(f"\nDEFAULT PERFORMANCE CONFIGS: {DEFAULT_PERFORMANCE_CONFIGS}") + print(f"\nUsage: Use --config to specify a configuration") \ No newline at end of file diff --git a/benchmark/benchmark_ring_attn.py b/benchmark/benchmark_ring_attn.py new file mode 100644 index 00000000..50929729 --- /dev/null +++ b/benchmark/benchmark_ring_attn.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import ring attention implementation +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnBenchmark(RingAttnBenchmarkBase): + """Benchmark for standard Ring Attention""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn.wrap_ring_attn_func' + + @property + def function_name(self) -> str: + return "ring_attn" + + def get_benchmark_name(self) -> str: + return "Ring Attention" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for standard ring attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v): + return wrap_ring_attn_func( + q, k, v, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for standard ring attention.""" + set_seed(42) + + # Create input tensors with standard batch format + q = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + k = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + v = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + + output = wrap_ring_attn_func( + q, k, v, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for standard ring attention.""" + class LegacyRingAttnConfig: + def __init__(self, **kwargs): + self.name = "legacy_ring_attn_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + self.causal = True + self.window_size = (-1, -1) + + return LegacyRingAttnConfig(**kwargs) + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Ring Attention Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = RingAttnBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/benchmark_ring_attn_varlen.py b/benchmark/benchmark_ring_attn_varlen.py new file mode 100644 index 00000000..97c4c6fc --- /dev/null +++ b/benchmark/benchmark_ring_attn_varlen.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import ring attention implementation +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnVarlenBenchmark(RingAttnBenchmarkBase): + """Benchmark for Ring Attention Variable Length""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func' + + @property + def function_name(self) -> str: + return "ring_attn_varlen" + + def get_benchmark_name(self) -> str: + return "Ring Attention Variable Length" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for variable length ring attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v, cu_seqlens_q, cu_seqlens_k): + return wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for variable length sequence attention.""" + set_seed(42) + + # Get cu_seqlens from config or create default + if hasattr(config, 'cu_seqlens'): + cu_seqlens = config.cu_seqlens + else: + # Create default variable length sequences + seqlen = config.max_seqlen + cu_seqlens = [0, seqlen // 8, seqlen // 4, seqlen // 2, seqlen] + + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + total_tokens = cu_seqlens[-1] + + # Create input tensors + q = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + k = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + v = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v, + 'cu_seqlens_q': cu_seqlens_tensor, + 'cu_seqlens_k': cu_seqlens_tensor + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + cu_seqlens_q = inputs['cu_seqlens_q'] + cu_seqlens_k = inputs['cu_seqlens_k'] + + output = wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for varlen.""" + class LegacyVarlenConfig: + def __init__(self, **kwargs): + self.name = "legacy_varlen_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.dtype = "bf16" + self.causal = True + self.window_size = (-1, -1) + + # Create variable length sequences + seqlen = self.max_seqlen + self.cu_seqlens = kwargs.get('cu_seqlens', [0, seqlen // 8, seqlen // 4, seqlen // 2, seqlen]) + self.total_tokens = self.cu_seqlens[-1] + + return LegacyVarlenConfig(**kwargs) + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Ring Attention Variable Length Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Total sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (number of sequences) (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = RingAttnVarlenBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/benchmark_zigzag_attn.py b/benchmark/benchmark_zigzag_attn.py new file mode 100644 index 00000000..94e99521 --- /dev/null +++ b/benchmark/benchmark_zigzag_attn.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag Attention Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import zigzag attention implementation +from nnscaler.customized_ops.ring_attention import wrap_zigzag_attn_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS, ZIGZAG_CONFIGS + + +class ZigzagAttnBenchmark(RingAttnBenchmarkBase): + """Benchmark for Zigzag Attention""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.zigzag_attn.wrap_zigzag_attn_func' + + @property + def function_name(self) -> str: + return "zigzag_attn" + + def get_benchmark_name(self) -> str: + return "Zigzag Attention" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for zigzag attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v): + # Zigzag attention only supports causal=True and window_size=(-1,-1) + return wrap_zigzag_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for zigzag attention.""" + set_seed(42) + + # Create input tensors with standard batch format + q = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + k = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + v = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + + # Zigzag attention constraints + output = wrap_zigzag_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for zigzag attention.""" + class LegacyZigzagAttnConfig: + def __init__(self, **kwargs): + self.name = "legacy_zigzag_attn_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + # Zigzag attention constraints + self.causal = True + self.window_size = (-1, -1) + + return LegacyZigzagAttnConfig(**kwargs) + + def run_performance_benchmark(self, config_name: str = None, dtype: str = "bf16", + timing_method: str = "warmup", warmup_runs: int = 3, + timing_runs: int = 5, **legacy_kwargs): + """Override to validate zigzag attention constraints.""" + # Validate configuration for zigzag constraints + if config_name: + from configs import get_config + config = get_config(config_name) + if not config.causal: + print(f"WARNING: Config '{config_name}' has causal=False, but zigzag attention requires causal=True") + print("Proceeding with causal=True for zigzag attention...") + if config.window_size != (-1, -1): + print(f"WARNING: Config '{config_name}' has window_size={config.window_size}, but zigzag attention requires (-1, -1)") + print("Proceeding with window_size=(-1, -1) for zigzag attention...") + + # Call parent implementation + super().run_performance_benchmark( + config_name=config_name, dtype=dtype, timing_method=timing_method, + warmup_runs=warmup_runs, timing_runs=timing_runs, **legacy_kwargs + ) + + def list_configurations(self): + """List configurations suitable for zigzag attention.""" + print("Available Zigzag Attention Configurations:") + print("=" * 50) + print("NOTE: Zigzag attention only supports causal=True and window_size=(-1,-1)") + print("Configurations listed below will be automatically adjusted for these constraints.\n") + + # Call parent method but with zigzag-specific note + super().list_configurations() + + print(f"\nZIGZAG-SPECIFIC CONFIGS: {list(ZIGZAG_CONFIGS.keys())}") + print("These configs are specifically designed for zigzag attention.") + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Zigzag Attention Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = ZigzagAttnBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/llama/customized_ops/README.md b/examples/llama/customized_ops/README.md deleted file mode 100644 index a20398c8..00000000 --- a/examples/llama/customized_ops/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# ring attention - -Tensor parallel (partition head) is a widely used distributed plan to train large language models. Computation and memory are -distributed evenly across devices. However, when the sequence length is extremely long (e.g., 1M), the partition degree of -tensor parallel is constrained by the number of kv heads, which means that the maximum number of devices in a data parallel -unit is no more than the number of kv heads. As a result, tensor parallel fails to scale a model with long sequence length. - -[ring attention](https://arxiv.org/abs/2310.01889) is proposed to address this issue. It partitions q, k and v along the -sequence dimension and passes the partitioned q, k and v through a ring of devices. [ring flash attention](https://github.com/zhuzilin/ring-flash-attention) -implements a high-performance version in PyTorch. This example attempts to integrate the causal version of ring attention -(zigzag ring attention) into nnScaler. - -The interface is wrapped in `zigzag_attn.py`. [flash attention](https://github.com/Dao-AILab/flash-attention) is required for this example. - -In addition to the zigzag version, we also include a implementation based on [llama 3.1](https://ai.meta.com/research/publications/the-llama-3-herd-of-models/)'s technical report. This version uses `all_gather` and `reduce_scatter` to collect and distribute the kv values and gradients. You can check the code in `ring_attn.py`. - -Test can be run with the following command: -```bash -torchrun --nproc_per_node 4 test_ring_attn.py -torchrun --nproc_per_node 4 test_zigzag_attn.py -torchrun --nproc_per_node 4 test_ring_attn_varlen.py -``` \ No newline at end of file diff --git a/examples/llama/customized_ops/ring_attention/ring_attn_varlen.py b/examples/llama/customized_ops/ring_attention/ring_attn_varlen.py deleted file mode 100644 index 869dc0bd..00000000 --- a/examples/llama/customized_ops/ring_attention/ring_attn_varlen.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from typing import Tuple, List, Dict -from torch import Tensor -import torch.distributed as dist - -from nnscaler.graph.parser.register import register_op -from nnscaler.ir.operator import IRFwOperation -from flash_attn import flash_attn_varlen_func -from .core.ring_attn_varlen_implementation import llama3_flash_attn_prepare_cu_seqlens, llama3_flash_attn_varlen_func -from .core.utils import gen_head_anno - -from nnscaler.runtime.device import DeviceGroup - - -def wrap_ring_attn_varlen_func( - q: Tensor, - k: Tensor, - v: Tensor, - cu_seqlens_q: Tensor, - cu_seqlens_k: Tensor, - dropout_p: float = 0.0, - softmax_scale: Tensor = None, - causal: bool = False, - window_size: Tuple[int] = (-1, -1), - alibi_slopes: Tensor = None, - deterministic: bool = False, - return_attn_probs: bool = False, - process_group: Tuple[int] = None, -): - ''' - wrap the ring_attn_varlen_func to support the distributed training in nnScaler. - most of the arguments are the same as the original flash_attn_varlen_func. - `process_group` should be none in the user code since nnScaler accepts the - program defined for the single device and will automatically generate the - required communications. - ''' - assert not return_attn_probs, "return_attn_probs is not supported in ring-attention" - assert alibi_slopes is None, "alibi_slopes is not supported in ring-attention" - max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() - max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() - - if process_group is None or len(process_group) == 1: - output = flash_attn_varlen_func( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p=dropout_p, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=False, - ) - return output - - assert len(q.shape) == 3, "q must have shape [total_q, qh, dim]" - assert len(k.shape) == 3, "k must have shape [total_k, kh, dim]" - assert len(v.shape) == 3, "v must have shape [total_k, vh, dim]" - total_q, qheads, qdim = q.shape - total_k, kheads, kdim = k.shape - total_v, vheads, vdim = v.shape - assert total_q == total_k == total_v, "total_q, total_k and total_v must be the same" - assert kheads == vheads, "number of k and v heads must be the same" - assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" - assert qdim == kdim == vdim, "dimension must be the same" - - local_process_group = DeviceGroup().get_group(process_group) - local_rank = dist.get_rank(local_process_group) - local_world_size = dist.get_world_size(local_process_group) - - ( - local_cu_seqlens_q, - local_cu_seqlens_k, - local_max_seqlen_q, - local_max_seqlen_k, - local_k_slice, - ) = llama3_flash_attn_prepare_cu_seqlens( - cu_seqlens_q, - causal=causal, - rank=local_rank, - world_size=local_world_size, - ) - - output = llama3_flash_attn_varlen_func( - q, - k, - v, - local_cu_seqlens_q, - local_cu_seqlens_k, - local_max_seqlen_q, - local_max_seqlen_k, - heads_k_stride=1, - local_k_slice=local_k_slice, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - group=local_process_group, - ) - - return output - - -def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: - """Special rule to generate ring_attn node""" - - signature = node.signature - - offset = (runtime_devid // plan_ndevs) * plan_ndevs - scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] - - kw_pairs = list() - for key, val in kwargs.items(): - code = f'{key}={val}' - kw_pairs.append(code) - - sub_input = node.inputs()[0] - full_input = sub_input.parent - partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] - assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" - if not partition_dims: - kw_pairs.append("process_group=None") - else: - if partition_dims[0] == 0: # partition on sequence dim - # the synchronization should occur across scaleunits - kw_pairs.append(f"process_group={scale_unit_dev_ids}") - elif partition_dims[0] == 1: - # partition the head dim, use local flash_attn_func - kw_pairs.append("process_group=None") - else: - raise ValueError(f'unsupported partition dim: {partition_dims[0]}') - - args = ", ".join(list(args) + kw_pairs) - return f"{signature}({args})" - - -def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: - q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states) - return f'l {q_anno} hd^, l {kv_anno} hd^, l {kv_anno} vd^, e^, e^ -> l {q_anno} vd^' - - -register_op(flash_attention_anno, emit_fn=emit_ring)(wrap_ring_attn_varlen_func) diff --git a/examples/llama/customized_ops/test_ring_attn.py b/examples/llama/customized_ops/test_ring_attn.py deleted file mode 100644 index fc72a91a..00000000 --- a/examples/llama/customized_ops/test_ring_attn.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch -import argparse -import nnscaler -from nnscaler.graph import IRGraph -from nnscaler.ir.operator import IRFwOperation -from nnscaler.parallel import parallelize, ComputeConfig, ReuseType -import torch.distributed as dist - -import nnscaler.graph -import nnscaler.graph.function -from ring_attention import wrap_ring_attn_func -from ring_attention.core.utils import set_seed, log - - -class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - - def forward(self, q, k, v): - out = wrap_ring_attn_func(q, k, v) - return out - - -def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: - ngpus = resource.plan_ngpus - partitioned = False - for idx, node in enumerate(graph.select(ntype=IRFwOperation)): - if not partitioned and node.signature == 'ring_attention.ring_attn.wrap_ring_attn_func': - print('\nPartitioned node: ', node, '\n') - sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) - partitioned = True - else: - sub_nodes = graph.replicate(node, times=ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - assert partitioned, f'expect ring_attn_func in graph, but not found.' - return graph - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dtype", - type=str, - default="bf16", - choices=["fp16", "bf16"], - help="Data type for inputs", - ) - args = parser.parse_args() - - nnscaler.init() - rank_id = torch.distributed.get_rank() - world_size = dist.get_world_size() - - set_seed(rank_id) - bsz = 1 - seqlen = 8192 - nheads = 24 - d = 128 - - device = torch.device(f"cuda:{rank_id}") - dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 - - q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) - k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) - v = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) - - dist.broadcast(q, src=0) - dist.broadcast(k, src=0) - dist.broadcast(v, src=0) - dist.barrier() - - q.requires_grad = True - k.requires_grad = True - v.requires_grad = True - - single_out = wrap_ring_attn_func(q, k, v) - single_out.retain_grad() - single_loss = single_out.sum() - single_loss.backward() - - model = TestModule() - - qq = q.detach().clone().requires_grad_() - kk = k.detach().clone().requires_grad_() - vv = v.detach().clone().requires_grad_() - - parallel_model = parallelize(model, dummy_forward_args={"q": qq, "k": kk, "v": vv}, pas_policy=policy, - compute_config=ComputeConfig(world_size, world_size), reuse=ReuseType.OVERRIDE) - parallel_model = parallel_model.cuda() - - - parallel_model.train() - - qq = q.detach().clone().requires_grad_() - kk = k.detach().clone().requires_grad_() - vv = v.detach().clone().requires_grad_() - - para_out = parallel_model(qq, kk, vv) - para_loss = para_out.sum() - para_loss.backward() - parallel_model.sync_grad() - - log("single out", single_out, rank0_only=True) - log("multi out", para_out, rank0_only=True) - log("out diff", single_out - para_out, rank0_only=True) - - log("single dq", q.grad, rank0_only=True) - log("multi dq", qq.grad, rank0_only=True) - log("dq diff", q.grad - qq.grad, rank0_only=True) - - log("single dk", k.grad, rank0_only=True) - log("multi dk", kk.grad, rank0_only=True) - log("dk diff", k.grad - kk.grad, rank0_only=True) - - log("single dv", v.grad, rank0_only=True) - log("multi dv", vv.grad, rank0_only=True) - log("dv diff", v.grad - vv.grad, rank0_only=True) - - dist.destroy_process_group() diff --git a/examples/llama/customized_ops/test_ring_attn_varlen.py b/examples/llama/customized_ops/test_ring_attn_varlen.py deleted file mode 100644 index 0df2e863..00000000 --- a/examples/llama/customized_ops/test_ring_attn_varlen.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch -import nnscaler -import argparse -from nnscaler.graph import IRGraph -from nnscaler.ir.operator import IRFwOperation -from nnscaler.parallel import parallelize, ComputeConfig, ReuseType -import torch.distributed as dist - -import nnscaler.graph -import nnscaler.graph.function -from ring_attention import wrap_ring_attn_varlen_func -from ring_attention.core.utils import set_seed, log - - -class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - - def forward(self, q, k, v, cu_seqlens_q, cu_seqlens_k): - out = wrap_ring_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k) - return out - - -def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: - ngpus = resource.plan_ngpus - partitioned = False - for idx, node in enumerate(graph.select(ntype=IRFwOperation)): - if not partitioned and node.signature == 'ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func': - print('\nPartitioned node: ', node, '\n') - sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=0, num=ngpus) - partitioned = True - else: - sub_nodes = graph.replicate(node, times=ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - assert partitioned, f'expect ring_attn_varlen_func in graph, but not found.' - return graph - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dtype", - type=str, - default="bf16", - choices=["fp16", "bf16"], - help="Data type for inputs", - ) - args = parser.parse_args() - - nnscaler.init() - rank_id = torch.distributed.get_rank() - world_size = dist.get_world_size() - - set_seed(rank_id) - seqlen = 8192 - nheads = 24 - d = 128 - - device = torch.device(f"cuda:{rank_id}") - dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 - cu_seqlens = [0, 120, 1248, 4232, 8192] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - - q = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) - k = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) - v = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) - - dist.broadcast(q, src=0) - dist.broadcast(k, src=0) - dist.broadcast(v, src=0) - dist.barrier() - - q.requires_grad = True - k.requires_grad = True - v.requires_grad = True - - single_out = wrap_ring_attn_varlen_func(q, k, v, cu_seqlens_tensor, cu_seqlens_tensor) - single_out.retain_grad() - single_loss = single_out.sum() - single_loss.backward() - - model = TestModule() - - qq = q.detach().clone().requires_grad_() - kk = k.detach().clone().requires_grad_() - vv = v.detach().clone().requires_grad_() - - parallel_model = parallelize( - model, - dummy_forward_args={"q": qq, "k": kk, "v": vv, 'cu_seqlens_q': cu_seqlens_tensor, 'cu_seqlens_k': cu_seqlens_tensor}, - pas_policy=policy, - compute_config=ComputeConfig(world_size, world_size), - reuse=ReuseType.OVERRIDE - ) - parallel_model = parallel_model.cuda() - - - parallel_model.train() - - qq = q.detach().clone().requires_grad_() - kk = k.detach().clone().requires_grad_() - vv = v.detach().clone().requires_grad_() - - para_out = parallel_model(qq, kk, vv, cu_seqlens_tensor, cu_seqlens_tensor) - para_loss = para_out.sum() - para_loss.backward() - parallel_model.sync_grad() - - log("single out", single_out, rank0_only=True) - log("multi out", para_out, rank0_only=True) - log("out diff", single_out - para_out, rank0_only=True) - - log("single dq", q.grad, rank0_only=True) - log("multi dq", qq.grad, rank0_only=True) - log("dq diff", q.grad - qq.grad, rank0_only=True) - - log("single dk", k.grad, rank0_only=True) - log("multi dk", kk.grad, rank0_only=True) - log("dk diff", k.grad - kk.grad, rank0_only=True) - - log("single dv", v.grad, rank0_only=True) - log("multi dv", vv.grad, rank0_only=True) - log("dv diff", v.grad - vv.grad, rank0_only=True) - - dist.destroy_process_group() diff --git a/examples/llama/customized_ops/test_zigzag_attn.py b/examples/llama/customized_ops/test_zigzag_attn.py deleted file mode 100644 index 0e0d8069..00000000 --- a/examples/llama/customized_ops/test_zigzag_attn.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch -import argparse -import nnscaler -from nnscaler.graph import IRGraph -from nnscaler.ir.operator import IRFwOperation -from nnscaler.parallel import parallelize, ComputeConfig, ReuseType -import torch.distributed as dist - -import nnscaler.graph -import nnscaler.graph.function -from ring_attention import wrap_zigzag_attn_func -from ring_attention.core.utils import set_seed, log - - -class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - - def forward(self, q, k, v): - out = wrap_zigzag_attn_func(q, k, v) - return out - - -def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: - ngpus = resource.plan_ngpus - partitioned = False - for idx, node in enumerate(graph.select(ntype=IRFwOperation)): - if not partitioned and node.signature == 'ring_attention.zigzag_attn.wrap_zigzag_attn_func': - print('\nPartitioned node: ', node, '\n') - sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=1, num=ngpus) - partitioned = True - else: - sub_nodes = graph.replicate(node, times=ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - assert partitioned, f'expect zigzag_attn_func in graph, but not found.' - return graph - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dtype", - type=str, - default="bf16", - choices=["fp16", "bf16"], - help="Data type for inputs", - ) - args = parser.parse_args() - - nnscaler.init() - rank_id = torch.distributed.get_rank() - world_size = dist.get_world_size() - - set_seed(rank_id) - bsz = 1 - seqlen = 8192 - nheads = 24 - d = 128 - - device = torch.device(f"cuda:{rank_id}") - dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 - - q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) - k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) - v = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype) - - dist.broadcast(q, src=0) - dist.broadcast(k, src=0) - dist.broadcast(v, src=0) - dist.barrier() - - q.requires_grad = True - k.requires_grad = True - v.requires_grad = True - - single_out = wrap_zigzag_attn_func(q, k, v) - single_out.retain_grad() - single_loss = single_out.sum() - single_loss.backward() - - model = TestModule() - - qq = q.detach().clone().requires_grad_() - kk = k.detach().clone().requires_grad_() - vv = v.detach().clone().requires_grad_() - - parallel_model = parallelize(model, dummy_forward_args={"q": qq, "k": kk, "v": vv}, pas_policy=policy, - compute_config=ComputeConfig(world_size, world_size), reuse=ReuseType.OVERRIDE) - parallel_model = parallel_model.cuda() - - - parallel_model.train() - - qq = q.detach().clone().requires_grad_() - kk = k.detach().clone().requires_grad_() - vv = v.detach().clone().requires_grad_() - - para_out = parallel_model(qq, kk, vv) - para_loss = para_out.sum() - para_loss.backward() - parallel_model.sync_grad() - - log("single out", single_out, rank0_only=True) - log("multi out", para_out, rank0_only=True) - log("out diff", single_out - para_out, rank0_only=True) - - log("single dq", q.grad, rank0_only=True) - log("multi dq", qq.grad, rank0_only=True) - log("dq diff", q.grad - qq.grad, rank0_only=True) - - log("single dk", k.grad, rank0_only=True) - log("multi dk", kk.grad, rank0_only=True) - log("dk diff", k.grad - kk.grad, rank0_only=True) - - log("single dv", v.grad, rank0_only=True) - log("multi dv", vv.grad, rank0_only=True) - log("dv diff", v.grad - vv.grad, rank0_only=True) - - dist.destroy_process_group() diff --git a/nnscaler/customized_ops/__init__.py b/nnscaler/customized_ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nnscaler/customized_ops/ring_attention/README.md b/nnscaler/customized_ops/ring_attention/README.md new file mode 100644 index 00000000..38fbec5f --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/README.md @@ -0,0 +1,219 @@ +# Ring Attention Implementation + +High-performance ring attention mechanisms for nnscaler, supporting multiple attention variants and distributed training. + +## 📖 Overview + +This module implements multiple efficient attention mechanisms designed to distribute computation evenly in long sequence processing: + +- **Ring Attention**: Standard ring attention supporting arbitrary sequence lengths +- **Ring Attention Variable Length**: Variable-length sequence optimized ring attention +- **Zigzag Attention**: Zigzag pattern ring attention optimized for causal attention + +All implementations are deeply integrated with nnscaler's parallel computing framework, supporting automatic distributed training. + +## 🏗️ Architecture Design + +``` +nnscaler/customized_ops/ring_attention/ +├── __init__.py # Package import interface +├── ring_attn.py # Standard ring attention +├── ring_attn_varlen.py # Variable length ring attention +├── zigzag_attn.py # Zigzag ring attention +├── varlen_utils.py # Variable length utility functions +└── core/ # Core implementations + ├── ring_attn_implementation.py # Standard ring attention core + ├── ring_attn_varlen_implementation.py # Variable length core implementation + ├── zigzag_attn_implementation.py # Zigzag attention core implementation + └── utils.py # Common utility functions +``` + +## 🚀 Quick Start + +### Standard Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + +# Basic usage +output = wrap_ring_attn_func( + q, # [batch_size, seq_len, num_heads, head_dim] + k, # [batch_size, seq_len, num_heads, head_dim] + v, # [batch_size, seq_len, num_heads, head_dim] + causal=True, # Causal attention mask + window_size=(-1, -1), # Sliding window size, (-1,-1) means global attention + softmax_scale=None, # Softmax scale factor, defaults to 1/sqrt(head_dim) + dropout_p=0.0 # Dropout probability +) +``` + +### Variable Length Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func + +# Variable length sequence attention +output = wrap_ring_attn_varlen_func( + q, # [total_tokens, num_heads, head_dim] + k, # [total_tokens, num_heads, head_dim] + v, # [total_tokens, num_heads, head_dim] + cu_seqlens_q, # Cumulative sequence lengths [batch_size + 1] + cu_seqlens_k, # Cumulative sequence lengths [batch_size + 1] + bias=None, # Optional attention bias + causal=True, # Causal attention mask + window_size=(-1, -1), # Sliding window size + softmax_scale=None, # Softmax scale factor + dropout_p=0.0 # Dropout probability +) +``` + +### Zigzag Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_zigzag_attn_func + +# Zigzag attention (causal attention only) +output = wrap_zigzag_attn_func( + q, # [batch_size, seq_len, num_heads, head_dim] + k, # [batch_size, seq_len, num_heads, head_dim] + v, # [batch_size, seq_len, num_heads, head_dim] + causal=True, # Must be True + window_size=(-1, -1), # Must be (-1, -1), sliding window not supported + softmax_scale=None, + dropout_p=0.0 +) +``` + +## 🔧 Core Features + +### Performance Optimization +- **Flash Attention integration**: Efficient implementation based on flash_attn +- **TransformerEngine support**: Automatic detection and usage of TE 2.2.0+ +- **CUDA kernel optimization**: GPU-optimized low-level implementations +- **Distributed friendly**: Seamless integration with torch.distributed + +### Flexible Configuration +- **Attention patterns**: Support for causal and non-causal attention +- **Sliding window**: Configurable local attention windows +- **GQA support**: Grouped Query Attention optimization +- **Custom scaling**: Flexible softmax scaling strategies + +## 🧮 Algorithm Principles + +### Ring Attention Mechanism + +Ring Attention decomposes attention computation into multiple blocks: + +1. **Sequence chunking**: Divide long sequences into blocks distributed across devices +2. **Ring communication**: Devices pass key/value blocks by all-gather and reduce-scatter +3. **Incremental computation**: Each device computes attention with received key/value blocks + +### Variable Length Optimization + +Special optimizations for variable length sequences: + +```python +# Cumulative sequence length example +cu_seqlens = [0, 128, 256, 512] # 3 sequences with lengths 128, 128, 256 +# Corresponding token tensor shape: [512, num_heads, head_dim] +``` + +### Zigzag Pattern + +Zigzag Attention uses a special communication pattern for higher efficiency in causal attention scenarios: + +- **Causal constraint**: Only supports causal=True cases +- **Optimized communication**: Ring communication optimized for causal masks +- **Memory friendly**: Further reduces unnecessary computation and communication + +## 🔗 nnscaler Integration + +### Automatic Parallelization + +```python +from nnscaler.parallel import parallelize, ComputeConfig +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + +class AttentionModel(torch.nn.Module): + def forward(self, q, k, v): + return wrap_ring_attn_func(q, k, v, causal=True) + +# nnscaler automatically handles distribution +config = ComputeConfig( + plan_ngpus=4, + runtime_ngpus=4 +) +parallel_model = parallelize(model, config=config) +``` + +### Computation Graph Optimization + +nnscaler automatically provides: +- **Communication optimization**: Minimize inter-device communication overhead +- **Memory planning**: Optimize memory usage patterns +- **Operator fusion**: Fuse with other operators for optimization +- **Gradient synchronization**: Automatic gradient communication in backward pass + +## 🧪 Testing Framework + +Comprehensive test coverage ensures implementation correctness and performance: + +```bash +# Run all attention tests +pytest tests/customized_ops/ring_attn/ -v + +# Specific attention variant tests +pytest tests/customized_ops/ring_attn/test_ring_attn.py -v +pytest tests/customized_ops/ring_attn/test_ring_attn_varlen.py -v +pytest tests/customized_ops/ring_attn/test_zigzag_attn.py -v +``` + +### Test Types + +- **Correctness tests**: Compare outputs with standard attention +- **Multi-GPU scalability**: Behavior validation across different device counts +- **GQA compatibility**: Grouped Query Attention correctness +- **Sliding window**: Local attention pattern validation +- **Edge cases**: Stability testing under extreme conditions + +## 🛠️ Development Guide + +### Adding New Attention Variants + +1. **Core implementation**: Add implementation file in `core/` directory +2. **Wrapper function**: Create corresponding wrap function +3. **Test coverage**: Add comprehensive test cases +4. **Documentation**: Update README and API documentation + +### Performance Optimization Tips + +- **TransformerEngine**: Install TE 2.2.0+ for optimal performance +- **CUDA version**: Use CUDA 11.8+ for latest optimizations +- **Memory configuration**: Adjust batch size and sequence length based on GPU memory +- **Communication optimization**: Use InfiniBand networks to reduce communication latency + +## 🚨 Known Limitations + +### Ring Attention +- **alibi_slopes**: ALiBi positional encoding not currently supported +- **return_attn_probs**: Returning attention weights not supported + +### Zigzag Attention +- **causal**: Only supports causal attention (causal=True) +- **window_size**: Sliding window not supported (must be (-1,-1)) + +### General Limitations +- **Dynamic shapes**: Sequence length cannot change dynamically during training +- **Mixed precision**: May require special handling in certain configurations + +## 📚 References + +- **Ring Attention Paper**: [Ring Attention with Blockwise Transformers](https://arxiv.org/abs/2310.01889) +- **Flash Attention**: [FlashAttention: Fast and Memory-Efficient Exact Attention](https://arxiv.org/abs/2205.14135) +- **Llama3 Paper**: [The Llama3 Herd of Models](https://arxiv.org/pdf/2407.21783) +- **nnscaler Documentation**: [nnscaler Parallel Computing Framework](https://github.com/microsoft/nnscaler) +- **TransformerEngine**: [NVIDIA TransformerEngine](https://github.com/NVIDIA/TransformerEngine) + +--- + +**Note**: This implementation is optimized for large-scale distributed training. For single-GPU scenarios, standard Flash Attention is recommended for optimal performance. \ No newline at end of file diff --git a/examples/llama/customized_ops/ring_attention/__init__.py b/nnscaler/customized_ops/ring_attention/__init__.py similarity index 100% rename from examples/llama/customized_ops/ring_attention/__init__.py rename to nnscaler/customized_ops/ring_attention/__init__.py diff --git a/examples/llama/customized_ops/ring_attention/core/ring_attn_implementation.py b/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py similarity index 78% rename from examples/llama/customized_ops/ring_attention/core/ring_attn_implementation.py rename to nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py index 85b374d8..b8bbc351 100644 --- a/examples/llama/customized_ops/ring_attention/core/ring_attn_implementation.py +++ b/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py @@ -4,7 +4,7 @@ import torch import torch.distributed as dist from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward -from .utils import shuffle_input, recover_output, GlobalMemoryBuffer, get_default_args +from .utils import shuffle_input, recover_output, GlobalMemoryBuffer, get_default_args, all_gather, reduce_scatter _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() @@ -120,8 +120,11 @@ def ring_flash_attn_backward( if causal: up_k = k[:, :(up_rank + 1) * block_len] up_v = v[:, :(up_rank + 1) * block_len] + up_dk = dk_buffer[:, :(up_rank + 1) * block_len] + up_dv = dv_buffer[:, :(up_rank + 1) * block_len] else: up_k, up_v = k, v + up_dk, up_dv = dk_buffer, dv_buffer params = get_default_args(_flash_attn_backward).copy() params.update( @@ -133,8 +136,8 @@ def ring_flash_attn_backward( "out": up_out, "softmax_lse": up_lse, "dq": dq[:, :block_len], - "dk": dk_buffer[:, :(up_rank + 1) * block_len], - "dv": dv_buffer[:, :(up_rank + 1) * block_len], + "dk": up_dk, + "dv": up_dv, "dropout_p": dropout_p, "softmax_scale": softmax_scale, "causal": causal, @@ -164,8 +167,11 @@ def ring_flash_attn_backward( if causal: down_k = k[:, :(down_rank + 1) * block_len] down_v = v[:, :(down_rank + 1) * block_len] + down_dk = down_dk_buffer[:, :(down_rank + 1) * block_len] + down_dv = down_dv_buffer[:, :(down_rank + 1) * block_len] else: down_k, down_v = k, v + down_dk, down_dv = down_dk_buffer, down_dv_buffer params = get_default_args(_flash_attn_backward).copy() params.update( @@ -177,8 +183,8 @@ def ring_flash_attn_backward( "out": down_out, "softmax_lse": down_lse, "dq": dq[:, block_len:], - "dk": down_dk_buffer[:, :(down_rank + 1) * block_len], - "dv": down_dv_buffer[:, :(down_rank + 1) * block_len], + "dk": down_dk, + "dv": down_dv, "dropout_p": dropout_p, "softmax_scale": softmax_scale, "causal": causal, @@ -199,12 +205,17 @@ def ring_flash_attn_backward( dk_buffer.add_(down_dk_buffer) dv_buffer.add_(down_dv_buffer) - dim_size = list(k.size()) - dim_size[1] = dim_size[1] // world_size - dk = torch.empty(dim_size, dtype=k.dtype, device=k.device) - dv = torch.empty(dim_size, dtype=v.dtype, device=v.device) - dist.reduce_scatter_tensor(dk, dk_buffer, group=process_group) - dist.reduce_scatter_tensor(dv, dv_buffer, group=process_group) + bsz = q.size(0) + if bsz == 1: + dim_size = list(k.size()) + dim_size[1] = dim_size[1] // world_size + dk = torch.empty(dim_size, dtype=k.dtype, device=k.device) + dv = torch.empty(dim_size, dtype=v.dtype, device=v.device) + dist.reduce_scatter_tensor(dk, dk_buffer, group=process_group) + dist.reduce_scatter_tensor(dv, dv_buffer, group=process_group) + else: + dk = reduce_scatter(dk_buffer, dim=1, process_group=process_group) + dv = reduce_scatter(dv_buffer, dim=1, process_group=process_group) return dq, dk, dv @@ -237,14 +248,22 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) assert alibi_slopes is None + bsz = q.size(0) q = shuffle_input(to_send=q, process_group=group) + k = k.contiguous() + v = v.contiguous() world_size = dist.get_world_size(group) dim_size = list(k.size()) dim_size[1] = dim_size[1] * world_size - k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") - v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") - torch.distributed.all_gather_into_tensor(k_buffer, k, group=group) - torch.distributed.all_gather_into_tensor(v_buffer, v, group=group) + if bsz == 1: + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + # torch.distributed._all_gather_base function requires that the k and v tensors are contiguous. + torch.distributed.all_gather_into_tensor(k_buffer, k, group=group) + torch.distributed.all_gather_into_tensor(v_buffer, v, group=group) + else: + k_buffer = all_gather(k, dim=1, process_group=group) + v_buffer = all_gather(v, dim=1, process_group=group) out, up_lse, down_lse = ring_flash_attn_forward( group, @@ -274,13 +293,18 @@ def forward( def backward(ctx, dout, *args): dout = shuffle_input(to_send=dout, process_group=ctx.group) q, k, v, out, up_lse, down_lse = ctx.saved_tensors + bsz = q.size(0) world_size = dist.get_world_size(ctx.group) dim_size = list(k.size()) dim_size[1] = dim_size[1] * world_size - k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") - v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") - torch.distributed.all_gather_into_tensor(k_buffer, k, group=ctx.group) - torch.distributed.all_gather_into_tensor(v_buffer, v, group=ctx.group) + if bsz == 1: + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed.all_gather_into_tensor(k_buffer, k, group=ctx.group) + torch.distributed.all_gather_into_tensor(v_buffer, v, group=ctx.group) + else: + k_buffer = all_gather(k, dim=1, process_group=ctx.group) + v_buffer = all_gather(v, dim=1, process_group=ctx.group) dq, dk, dv = ring_flash_attn_backward( ctx.group, diff --git a/examples/llama/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py b/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py similarity index 96% rename from examples/llama/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py rename to nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py index 924f4eda..d388d3b4 100644 --- a/examples/llama/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py +++ b/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py @@ -27,7 +27,7 @@ def llama3_flash_attn_prepare_cu_seqlens( that this may be longer than `total_seq_len // world_size`. """ total_length = cu_seqlens[-1].item() - assert total_length % world_size == 0 + assert total_length % world_size == 0, cu_seqlens length_per_rank = total_length // world_size left = torch.searchsorted(cu_seqlens, rank * length_per_rank) right = torch.searchsorted(cu_seqlens, (rank + 1) * length_per_rank) @@ -121,6 +121,10 @@ def llama3_flash_attn_varlen_forward( q_i = q[:, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] k_i = kv_buffer[0][local_k_slice] v_i = kv_buffer[1][local_k_slice] + if alibi_slopes is None: + cur_alibi_slopes = None + else: + cur_alibi_slopes = alibi_slopes[i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] params = get_default_args(_flash_attn_varlen_forward).copy() params.update( @@ -135,7 +139,7 @@ def llama3_flash_attn_varlen_forward( "dropout_p": dropout_p, "softmax_scale": softmax_scale, "causal": causal, - "alibi_slopes": alibi_slopes, + "alibi_slopes": cur_alibi_slopes, "return_softmax": True and dropout_p > 0, } ) @@ -251,6 +255,11 @@ def llama3_flash_attn_varlen_backward( dk_i = dkv_buffer[0][local_k_slice] dv_i = dkv_buffer[1][local_k_slice] + if alibi_slopes is None: + cur_alibi_slopes = None + else: + cur_alibi_slopes = alibi_slopes[i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + params = get_default_args(_flash_attn_varlen_backward).copy() params.update( { @@ -270,7 +279,7 @@ def llama3_flash_attn_varlen_backward( "dropout_p": dropout_p, "softmax_scale": softmax_scale, "causal": causal, - "alibi_slopes": alibi_slopes, + "alibi_slopes": cur_alibi_slopes, "deterministic": deterministic, } ) @@ -328,7 +337,6 @@ def forward( if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - assert alibi_slopes is None k = k.contiguous() v = v.contiguous() out, softmax_lse = llama3_flash_attn_varlen_forward( diff --git a/examples/llama/customized_ops/ring_attention/core/utils.py b/nnscaler/customized_ops/ring_attention/core/utils.py similarity index 87% rename from examples/llama/customized_ops/ring_attention/core/utils.py rename to nnscaler/customized_ops/ring_attention/core/utils.py index 57383516..6345b35b 100644 --- a/examples/llama/customized_ops/ring_attention/core/utils.py +++ b/nnscaler/customized_ops/ring_attention/core/utils.py @@ -48,11 +48,11 @@ def log(msg, a, rank0_only=False): dist.barrier() -def gen_head_anno(query_states, key_states, value_states): - if query_states.shape[2] != key_states.shape[2]: - assert query_states.shape[2] % key_states.shape[2] == 0 - group_size = query_states.shape[2] // key_states.shape[2] - assert query_states.shape[2] == value_states.shape[2] * group_size +def gen_head_anno(query_states, key_states, value_states, head_pos=2): + if query_states.shape[head_pos] != key_states.shape[head_pos]: + assert query_states.shape[head_pos] % key_states.shape[head_pos] == 0 + group_size = query_states.shape[head_pos] // key_states.shape[head_pos] + assert query_states.shape[head_pos] == value_states.shape[head_pos] * group_size q_anno = f'(group_num {group_size})' kv_anno = 'group_num' else: @@ -321,3 +321,23 @@ def recover_output(to_send: torch.Tensor, to_send_f[:, block_seq_len:] = res return to_send_f.contiguous() + + +def all_gather(tensor: torch.Tensor, dim: int, process_group: dist.ProcessGroup): + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + tensor_list[torch.distributed.get_rank(process_group)] = tensor.data + torch.distributed.all_gather(tensor_list, tensor, group=process_group) + otensor = torch.concat(tuple(tensor_list), dim=dim) + return otensor + + +def reduce_scatter(tensor: torch.Tensor, dim: int, process_group: dist.ProcessGroup): + world_size = dist.get_world_size(process_group) + itensors = list(tensor.chunk(world_size, dim)) + for idx, t in enumerate(itensors): + itensors[idx] = t.contiguous() if not t.is_contiguous() else t + otensor = torch.empty_like(itensors[0], requires_grad=False) + torch.distributed.reduce_scatter(otensor, itensors, group=process_group) + return otensor diff --git a/examples/llama/customized_ops/ring_attention/core/zigzag_attn_implementation.py b/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py similarity index 100% rename from examples/llama/customized_ops/ring_attention/core/zigzag_attn_implementation.py rename to nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py diff --git a/examples/llama/customized_ops/ring_attention/ring_attn.py b/nnscaler/customized_ops/ring_attention/ring_attn.py similarity index 95% rename from examples/llama/customized_ops/ring_attention/ring_attn.py rename to nnscaler/customized_ops/ring_attention/ring_attn.py index 6a24bb7f..e7a8a4b8 100644 --- a/examples/llama/customized_ops/ring_attention/ring_attn.py +++ b/nnscaler/customized_ops/ring_attention/ring_attn.py @@ -51,10 +51,6 @@ def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=N local_process_group = DeviceGroup().get_group(process_group) - # In the RingFlashAttnFunc.apply function, the torch.distributed._all_gather_base function - # requires that the k and v tensors be contiguous. - k = k.contiguous() - v = v.contiguous() output = RingFlashAttnFunc.apply( q, k, @@ -67,7 +63,7 @@ def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=N deterministic, return_attn_probs, local_process_group, - ).contiguous() + ) return output diff --git a/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py b/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py new file mode 100644 index 00000000..bb9ff54b --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Tuple, List, Dict, Optional +import torch +from torch import Tensor +import torch.distributed as dist +import warnings + +from nnscaler.graph.parser.register import register_op +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir import IRTensor +from nnscaler.runtime.device import DeviceGroup +from flash_attn import flash_attn_varlen_func +from .core.ring_attn_varlen_implementation import llama3_flash_attn_prepare_cu_seqlens, llama3_flash_attn_varlen_func +from .core.utils import gen_head_anno +from .varlen_utils import shuffle_varlen, unshuffle_varlen + +# Try to import TransformerEngine with version check +_HAS_TRANSFORMER_ENGINE = False +_TE_VERSION_OK = False +attn_forward_func_with_cp = None + +try: + import transformer_engine + _HAS_TRANSFORMER_ENGINE = True + + # Check version - require 2.2.0+ + try: + from packaging import version + te_version = version.parse(transformer_engine.__version__) + required_version = version.parse("2.2.0") + _TE_VERSION_OK = te_version >= required_version + + if _TE_VERSION_OK: + # Try different import paths for different versions + try: + # For v2.5.0+ + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import attn_forward_func_with_cp + except ImportError: + try: + # For v2.2.0-v2.4.x + from transformer_engine.pytorch.attention import attn_forward_func_with_cp + except ImportError: + warnings.warn( + "TransformerEngine attention module not available or incompatible. " + "Falling back to basic ring attention implementation." + ) + else: + warnings.warn( + f"TransformerEngine version {transformer_engine.__version__} is too old. " + f"Require 2.2.0+. Falling back to basic ring attention implementation." + ) + except ImportError: + # packaging not available, try to import anyway + try: + # Try different import paths for different versions + try: + # For v2.5.0+ + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import attn_forward_func_with_cp + except ImportError: + # For v2.2.0-v2.4.x + from transformer_engine.pytorch.attention import attn_forward_func_with_cp + _TE_VERSION_OK = True + except (ImportError, AttributeError): + warnings.warn( + "TransformerEngine attention module not available or incompatible. " + "Falling back to basic ring attention implementation." + ) + +except ImportError: + warnings.warn( + "TransformerEngine not found. Falling back to basic ring attention implementation. " + "For better performance with context parallelism, install TransformerEngine 2.2.0+." + ) + + +def get_transformer_engine_info() -> Dict[str, any]: + """Get information about TransformerEngine availability and version.""" + return { + "has_transformer_engine": _HAS_TRANSFORMER_ENGINE, + "version_ok": _TE_VERSION_OK, + "has_cp_function": attn_forward_func_with_cp is not None, + "version": getattr(transformer_engine, "__version__", None) if _HAS_TRANSFORMER_ENGINE else None, + "required_version": "2.2.0+", + } + + +def print_transformer_engine_status(): + """Print TransformerEngine status for debugging.""" + info = get_transformer_engine_info() + print("TransformerEngine Status:") + print(f" - Available: {info['has_transformer_engine']}") + if info['has_transformer_engine']: + print(f" - Version: {info['version']}") + print(f" - Version OK (>= 2.2.0): {info['version_ok']}") + print(f" - CP Function Available: {info['has_cp_function']}") + else: + print(f" - Required Version: {info['required_version']}") + print(f" - Will use TE CP: {info['has_transformer_engine'] and info['version_ok'] and info['has_cp_function']}") + + +def wrap_ring_attn_varlen_func( + q: Tensor, + k: Tensor, + v: Tensor, + cu_seqlens_q: Tensor, + cu_seqlens_k: Tensor, + alibi_slopes: Tensor, + dropout_p: float = 0.0, + softmax_scale: Tensor = None, + causal: bool = False, + window_size: Tuple[int] = (-1, -1), + deterministic: bool = False, + return_attn_probs: bool = False, + process_group: Tuple[int] = None, +): + ''' + wrap the ring_attn_varlen_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_varlen_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + assert not return_attn_probs, "return_attn_probs is not supported in ring-attention" + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + + if process_group is None or len(process_group) == 1: + output = flash_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=False, + ) + return output + + assert len(q.shape) == 3, "q must have shape [total_q, qh, dim]" + assert len(k.shape) == 3, "k must have shape [total_k, kh, dim]" + assert len(v.shape) == 3, "v must have shape [total_k, vh, dim]" + total_q, qheads, qdim = q.shape + total_k, kheads, kdim = k.shape + total_v, vheads, vdim = v.shape + assert total_q == total_k == total_v, "total_q, total_k and total_v must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + local_rank = dist.get_rank(local_process_group) + local_world_size = dist.get_world_size(local_process_group) + assert local_world_size == len(process_group), "local_world_size should be the same with process_group size" + + if local_process_group is None: + local_process_group = dist.group.WORLD + + if window_size == (-1, -1): + # Use TransformerEngine with context parallelism if available and version is OK + if _HAS_TRANSFORMER_ENGINE and _TE_VERSION_OK and attn_forward_func_with_cp is not None: + shuffled_q = shuffle_varlen(q, cu_seqlens_q, process_group, local_process_group) + shuffled_k = shuffle_varlen(k, cu_seqlens_k, process_group, local_process_group) + shuffled_v = shuffle_varlen(v, cu_seqlens_k, process_group, local_process_group) + + te_cu_seqlens_q = cu_seqlens_q.clone() + te_cu_seqlens_k = cu_seqlens_k.clone() + te_cu_seqlens_q = torch.cat( + [ + te_cu_seqlens_q, + torch.tensor([cu_seqlens_q[-1].item()], dtype=te_cu_seqlens_q.dtype, device=te_cu_seqlens_q.device) + ] + ) + te_cu_seqlens_k = torch.cat( + [ + te_cu_seqlens_k, + torch.tensor([cu_seqlens_k[-1].item()], dtype=te_cu_seqlens_k.dtype, device=te_cu_seqlens_k.device) + ] + ) + shuffled_output = attn_forward_func_with_cp( + True, + shuffled_q, + shuffled_k, + shuffled_v, + te_cu_seqlens_q, + te_cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + te_cu_seqlens_q, + te_cu_seqlens_k, + dropout_p, + local_process_group, + process_group, + # TODO: optimize the stream usage + torch.cuda.current_stream(), + "p2p", # "all_gather" version cannot work with thd format + qkv_format="thd", + attn_mask_type="padding_causal" if causal else "padding", + ) + output = unshuffle_varlen(shuffled_output, cu_seqlens_q, process_group, local_process_group) + return output + else: + # Fallback to basic ring attention implementation + warnings.warn( + "TransformerEngine not available or version incompatible. " + "Using basic ring attention implementation which may be slower." + ) + + ( + local_cu_seqlens_q, + local_cu_seqlens_k, + local_max_seqlen_q, + local_max_seqlen_k, + local_k_slice, + ) = llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens_q, + causal=causal, + rank=local_rank, + world_size=local_world_size, + ) + + output = llama3_flash_attn_varlen_func( + q, + k, + v, + local_cu_seqlens_q, + local_cu_seqlens_k, + local_max_seqlen_q, + local_max_seqlen_k, + heads_k_stride=1, + local_k_slice=local_k_slice, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=False, + group=local_process_group, + ) + + return output + + +def emit_ring(node: IRDimops, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate ring_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + remainder = runtime_devid % plan_ndevs + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [(i, f // s) for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + if partition_dims[0][0] == 0: # partition on sequence dim + # the synchronization should occur across scaleunits + num = partition_dims[0][1] + scale_unit_dev_ids = [local_rank + offset for local_rank in range(remainder // num * num, (remainder // num + 1) * num)] + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0][0] == 1: + # partition the head dim, use local flash_attn_func + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +def flash_attention_anno(query_states, key_states, value_states, cu_seqlens_q, cu_seqlens_k, alibi_slopes, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states, head_pos=1) + if isinstance(alibi_slopes, IRTensor): + return f'l {q_anno} hd^, l {kv_anno} hd^, l {kv_anno} vd^, e^, e^, {q_anno} -> l {q_anno} vd^' + else: + return f'l {q_anno} hd^, l {kv_anno} hd^, l {kv_anno} vd^, e^, e^, ? -> l {q_anno} vd^' + + +def input_gen_fn(node: IRDimops): + inputs = [] + device = torch.cuda.current_device() + seqlen = node.inputs()[0].shape[0] + for i, t in enumerate(node.inputs()): + if i < 3: # query, key, value + inputs.append(torch.randn(t.shape, dtype=t.dtype, device=device, requires_grad=t.requires_grad)) + elif i in [3, 4]: # cu_seqlens + inputs.append(torch.Tensor([0, seqlen]).to(torch.int32).to(device)) + elif i == 5: # optional alibi_slopes + if isinstance(t, IRTensor): + inputs.append(torch.randn(t.shape, dtype=t.dtype, device=device, requires_grad=t.requires_grad)) + else: + inputs.append(None) + else: # other kwargs, use defaults + break + return tuple(inputs) + + +register_op(flash_attention_anno, emit_fn=emit_ring, input_gen_fn=input_gen_fn)(wrap_ring_attn_varlen_func) diff --git a/nnscaler/customized_ops/ring_attention/varlen_utils.py b/nnscaler/customized_ops/ring_attention/varlen_utils.py new file mode 100644 index 00000000..bdd1f127 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/varlen_utils.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Utilities for variable-length sequence processing in ring attention. +Contains shuffle and unshuffle functions for context parallel processing. +""" + +from typing import List +import torch +from torch import Tensor +import torch.distributed as dist +from nnscaler.runtime.adapter.nn import allgather_reducescatter + + +def shuffle_varlen(t: Tensor, cu_seqlens_padded: Tensor, cp_ranks: List[int], cp_group: dist.ProcessGroup) -> Tensor: + """ + Shuffle tensor data for variable-length sequences in context parallel processing. + + Args: + t: Input tensor to shuffle (local portion from each rank) + cu_seqlens_padded: Cumulative sequence lengths (global) + cp_ranks: List of ranks in the context parallel group + cp_group: Process group for context parallel communication + + Returns: + Shuffled tensor + """ + # Get context parallel size and rank + cp_size = torch.distributed.get_world_size(group=cp_group) + assert cp_size > 1, "cp_size should be greater than 1" + cp_rank = torch.distributed.get_rank(group=cp_group) + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + full_tensor = allgather_reducescatter(t, 0, cp_ranks) + return process_tensor(full_tensor) + + +def unshuffle_varlen(t: Tensor, cu_seqlens_padded: Tensor, cp_ranks: List[int], cp_group: dist.ProcessGroup) -> Tensor: + """ + Unshuffle tensor data to restore original variable-length sequence order. + This is the reverse operation of shuffle_varlen. + + Args: + t: Shuffled tensor to unshuffle (local portion from each rank) + cu_seqlens_padded: Cumulative sequence lengths (global) + cp_ranks: List of ranks in the context parallel group + cp_group: Process group for context parallel communication + + Returns: + Unshuffled tensor (local portion for each rank) + """ + # reverse operation of shuffle_varlen + cp_size = torch.distributed.get_world_size(group=cp_group) + assert cp_size > 1, "cp_size should be greater than 1" + cp_rank = torch.distributed.get_rank(group=cp_group) + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + sum_len = cu_seqlens_padded[-1].item() + + def process_tensor(val): + if val is None: + return val + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) + else: + raise ValueError("Tensor must be at least 1D") + + cp_rank_slices = [] + for rank in range(cp_size): + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (rank * slice_size), + seq_start + ((rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - rank) * slice_size), + device=val.device, + ) + ) + perm = torch.cat(cp_rank_slices) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(sum_len, device=val.device) + + # Create a tensor to hold the unshuffled result + unshuffled = val.index_select(current_seq_dim, inv_perm) + local_tensor = torch.chunk(unshuffled, cp_size, dim=current_seq_dim)[cp_rank] + return local_tensor + + full_tensor = allgather_reducescatter(t, 0, cp_ranks) + return process_tensor(full_tensor) diff --git a/examples/llama/customized_ops/ring_attention/zigzag_attn.py b/nnscaler/customized_ops/ring_attention/zigzag_attn.py similarity index 100% rename from examples/llama/customized_ops/ring_attention/zigzag_attn.py rename to nnscaler/customized_ops/ring_attention/zigzag_attn.py diff --git a/tests/customized_ops/__init__.py b/tests/customized_ops/__init__.py new file mode 100644 index 00000000..78e3db5e --- /dev/null +++ b/tests/customized_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Ring Attention test module""" \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/configs.py b/tests/customized_ops/ring_attn/configs.py new file mode 100644 index 00000000..ebc7182c --- /dev/null +++ b/tests/customized_ops/ring_attn/configs.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Configuration file for ring attention tests. +This file contains predefined test configurations for both correctness and performance testing. +""" + +from dataclasses import dataclass +from typing import List, Tuple, Optional + + +@dataclass +class RingAttnConfig: + """Configuration for ring attention test cases""" + batch_size: int + num_heads: int + head_dim: int + max_seqlen: int + dtype: str = "bf16" + name: str = "" + num_kv_heads: Optional[int] = None # For GQA/MQA support + causal: bool = True # Most attention patterns are causal + window_size: Tuple[int, int] = (-1, -1) # Sliding window attention (-1, -1) means no window + + def __post_init__(self): + # Set num_kv_heads to num_heads if not specified (standard MHA) + if self.num_kv_heads is None: + self.num_kv_heads = self.num_heads + + if not self.name: + gqa_suffix = f"_gqa{self.num_kv_heads}" if self.num_kv_heads != self.num_heads else "" + causal_suffix = "" if self.causal else "_noncausal" + window_suffix = f"_w{self.window_size[0]}-{self.window_size[1]}" if self.window_size != (-1, -1) else "" + self.name = f"b{self.batch_size}_h{self.num_heads}_d{self.head_dim}_s{self.max_seqlen}_{self.dtype}{gqa_suffix}{causal_suffix}{window_suffix}" + + # Generate cu_seqlens for variable length sequences + # Create sequences with different lengths for more realistic testing + seq_lens = [ + self.max_seqlen // 8, # Short sequence + self.max_seqlen // 4, # Medium sequence + self.max_seqlen // 2, # Long sequence + self.max_seqlen - self.max_seqlen // 8 - self.max_seqlen // 4 - self.max_seqlen // 2 # Remaining + ] + self.cu_seqlens = [0] + for seq_len in seq_lens: + self.cu_seqlens.append(self.cu_seqlens[-1] + seq_len) + + @property + def total_tokens(self) -> int: + """Total number of tokens across all sequences""" + return self.cu_seqlens[-1] + + @property + def is_gqa(self) -> bool: + """Check if this is a GQA (Grouped Query Attention) configuration""" + return self.num_kv_heads < self.num_heads + + @property + def is_mqa(self) -> bool: + """Check if this is an MQA (Multi-Query Attention) configuration""" + return self.num_kv_heads == 1 + + @property + def num_groups(self) -> int: + """Number of query heads per KV head (group size)""" + return self.num_heads // self.num_kv_heads + + +# Small test cases for quick correctness validation +SMALL_CONFIGS = { + "tiny": RingAttnConfig(2, 8, 64, 1024, "bf16", "tiny", causal=True), + "small": RingAttnConfig(4, 12, 128, 4096, "bf16", "small", causal=True), + "small_fp16": RingAttnConfig(4, 12, 128, 4096, "fp16", "small_fp16", causal=False), # One non-causal config + "small_window": RingAttnConfig(4, 12, 128, 4096, "bf16", "small_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Medium test cases for standard testing +MEDIUM_CONFIGS = { + "medium": RingAttnConfig(4, 24, 128, 8192, "bf16", "medium", causal=True), + "medium_large_head": RingAttnConfig(4, 12, 256, 8192, "bf16", "medium_large_head", causal=False), # One non-causal config + "medium_many_heads": RingAttnConfig(4, 32, 128, 8192, "bf16", "medium_many_heads", causal=True), + "medium_fp16": RingAttnConfig(4, 24, 128, 8192, "fp16", "medium_fp16", causal=True), + "medium_window": RingAttnConfig(4, 24, 128, 8192, "bf16", "medium_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Large test cases for performance benchmarking +LARGE_CONFIGS = { + "large": RingAttnConfig(4, 32, 128, 16384, "bf16", "large", causal=True), + "large_seq": RingAttnConfig(4, 24, 128, 32768, "bf16", "large_seq", causal=True), + "large_head": RingAttnConfig(4, 24, 256, 16384, "bf16", "large_head", causal=False), # One non-causal config + "xlarge": RingAttnConfig(8, 32, 128, 32768, "bf16", "xlarge", causal=True), + "large_window": RingAttnConfig(4, 32, 128, 16384, "bf16", "large_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Realistic model configurations (kept minimal, most covered by medium/large configs) +MODEL_CONFIGS = { +} + +# GQA (Grouped Query Attention) configurations based on Qwen models +GQA_CONFIGS = { + # Qwen3-235B-A22B: 64 heads, 4 kv_heads, 128 head_dim + "qwen3_235b_a22b": RingAttnConfig( + batch_size=2, + num_heads=64, + head_dim=64, + max_seqlen=16384, + dtype="bf16", + name="qwen3_235b_a22b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-30B-A3B: 40 heads, 8 kv_heads, 128 head_dim + "qwen3_30b_a3b": RingAttnConfig( + batch_size=4, + num_heads=32, + head_dim=64, + max_seqlen=16384, + dtype="bf16", + name="qwen3_30b_a3b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-4B: 32 heads, 4 kv_heads, 80 head_dim + "qwen3_4b": RingAttnConfig( + batch_size=4, + num_heads=32, + head_dim=80, + max_seqlen=16384, + dtype="bf16", + name="qwen3_4b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-32B: 64 heads, 8 kv_heads, 128 head_dim + "qwen3_32b": RingAttnConfig( + batch_size=2, + num_heads=64, + head_dim=128, + max_seqlen=16384, + dtype="bf16", + name="qwen3_32b", + num_kv_heads=8, + causal=True + ), + + # Qwen3-14B: 40 heads, 8 kv_heads, 128 head_dim + "qwen3_14b": RingAttnConfig( + batch_size=4, + num_heads=40, + head_dim=128, + max_seqlen=16384, + dtype="bf16", + name="qwen3_14b", + num_kv_heads=8, + causal=True + ), +} + +# MQA is already covered by medium/large configs, so removed duplicate MQA_CONFIGS + +# Zigzag attention configurations (only supports causal=True and window_size=(-1, -1)) +ZIGZAG_CONFIGS = { + "zigzag_tiny": RingAttnConfig(2, 8, 64, 1024, "bf16", "zigzag_tiny", causal=True, window_size=(-1, -1)), + "zigzag_small": RingAttnConfig(4, 12, 128, 4096, "bf16", "zigzag_small", causal=True, window_size=(-1, -1)), + "zigzag_medium": RingAttnConfig(4, 24, 128, 8192, "bf16", "zigzag_medium", causal=True, window_size=(-1, -1)), + "zigzag_large": RingAttnConfig(4, 32, 128, 16384, "bf16", "zigzag_large", causal=True, window_size=(-1, -1)), + "zigzag_fp16": RingAttnConfig(4, 12, 128, 4096, "fp16", "zigzag_fp16", causal=True, window_size=(-1, -1)), + "zigzag_gqa": RingAttnConfig(4, 32, 128, 8192, "bf16", "zigzag_gqa", num_kv_heads=8, causal=True, window_size=(-1, -1)), +} + +# All configurations combined +ALL_CONFIGS = { + **SMALL_CONFIGS, + **MEDIUM_CONFIGS, + **LARGE_CONFIGS, + **MODEL_CONFIGS, + **GQA_CONFIGS, + **ZIGZAG_CONFIGS, +} + +# Default configurations for different test types +DEFAULT_CORRECTNESS_CONFIGS = ["tiny", "small", "medium"] +DEFAULT_PERFORMANCE_CONFIGS = ["medium", "large"] +DEFAULT_MULTI_GPU_CONFIGS = ["small", "medium"] +DEFAULT_GQA_CONFIGS = ["qwen3_4b", "qwen3_14b", "qwen3_32b"] +DEFAULT_ZIGZAG_CONFIGS = ["zigzag_tiny", "zigzag_small", "zigzag_medium"] + + +def get_config(name: str) -> RingAttnConfig: + """Get a configuration by name""" + if name in ALL_CONFIGS: + return ALL_CONFIGS[name] + else: + raise ValueError(f"Unknown configuration: {name}. Available: {list(ALL_CONFIGS.keys())}") + + +def list_configs(category: str = "all") -> List[str]: + """List available configurations by category""" + if category == "all": + return list(ALL_CONFIGS.keys()) + elif category == "small": + return list(SMALL_CONFIGS.keys()) + elif category == "medium": + return list(MEDIUM_CONFIGS.keys()) + elif category == "large": + return list(LARGE_CONFIGS.keys()) + elif category == "model": + return list(MODEL_CONFIGS.keys()) + elif category == "gqa": + return list(GQA_CONFIGS.keys()) + elif category == "zigzag": + return list(ZIGZAG_CONFIGS.keys()) + elif category == "correctness": + return DEFAULT_CORRECTNESS_CONFIGS + elif category == "performance": + return DEFAULT_PERFORMANCE_CONFIGS + elif category == "multi_gpu": + return DEFAULT_MULTI_GPU_CONFIGS + elif category == "gqa_default": + return DEFAULT_GQA_CONFIGS + elif category == "zigzag_default": + return DEFAULT_ZIGZAG_CONFIGS + else: + raise ValueError(f"Unknown category: {category}") + + +def get_configs_by_category(category: str) -> dict: + """Get all configurations in a category""" + config_names = list_configs(category) + return {name: get_config(name) for name in config_names} + + +def get_gqa_configs() -> dict: + """Get all GQA (Grouped Query Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if config.is_gqa and not config.is_mqa} + + +def get_mqa_configs() -> dict: + """Get all MQA (Multi-Query Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if config.is_mqa} + + +def get_mha_configs() -> dict: + """Get all MHA (Multi-Head Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if not config.is_gqa} + + +def get_zigzag_configs() -> dict: + """Get all Zigzag attention configurations""" + return ZIGZAG_CONFIGS + + +def filter_configs_by_attention_type(attention_type: str) -> dict: + """Filter configurations by attention type: 'mha', 'gqa', 'mqa', or 'zigzag'""" + if attention_type.lower() == "mha": + return get_mha_configs() + elif attention_type.lower() == "gqa": + return get_gqa_configs() + elif attention_type.lower() == "mqa": + return get_mqa_configs() # Will return empty dict since no dedicated MQA configs + elif attention_type.lower() == "zigzag": + return get_zigzag_configs() + else: + raise ValueError(f"Unknown attention type: {attention_type}. Supported: 'mha', 'gqa', 'mqa', 'zigzag'") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/ring_attn_runner.py b/tests/customized_ops/ring_attn/ring_attn_runner.py new file mode 100644 index 00000000..fafc442f --- /dev/null +++ b/tests/customized_ops/ring_attn/ring_attn_runner.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Correctness Test Runner Script + +This script runs ring attention correctness tests in a distributed environment. +It compares the outputs of single-GPU and multi-GPU ring attention to ensure correctness. +""" + +import sys +import torch + +from runner_base import RingAttnRunnerBase +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + + +class TestModule(torch.nn.Module): + """Test module for ring attention""" + def __init__(self, causal=True, window_size=(-1, -1)): + super(TestModule, self).__init__() + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v): + result = wrap_ring_attn_func( + q, k, v, + causal=self.causal, + window_size=self.window_size + ) + return result + + +class RingAttnRunner(RingAttnRunnerBase): + """Runner for ring attention tests""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn.wrap_ring_attn_func' + + @property + def function_name(self) -> str: + return 'wrap_ring_attn_func' + + def create_test_module(self, config) -> torch.nn.Module: + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare regular inputs with shape [batch_size, seq_len, num_heads, head_dim]""" + q = torch.randn( + config.batch_size, + config.max_seqlen, + config.num_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ) + + k = torch.randn( + config.batch_size, + config.max_seqlen, + config.num_kv_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ) + + v = torch.randn( + config.batch_size, + config.max_seqlen, + config.num_kv_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ) + + return {'q': q, 'k': k, 'v': v} + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + # Run single GPU version (this should call flash_attn internally when no process_group) + single_out = wrap_ring_attn_func( + inputs['q'], inputs['k'], inputs['v'], + causal=config.causal, + window_size=config.window_size + ) + return single_out, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization""" + return { + "q": inputs["q"], + "k": inputs["k"], + "v": inputs["v"], + } + + +def run_correctness_test(**kwargs): + """Legacy function for backward compatibility""" + runner = RingAttnRunner() + runner.run_correctness_test(**kwargs) + + +if __name__ == "__main__": + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + runner = RingAttnRunner() + runner.main(**kwargs) \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py b/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py new file mode 100644 index 00000000..806cd852 --- /dev/null +++ b/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Correctness Test Runner + +This script runs ring attention variable length correctness tests in a distributed environment. +It compares the outputs of single-GPU and multi-GPU ring attention to ensure correctness. +""" + +import sys +import torch + +from runner_base import RingAttnRunnerBase +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func + + +class TestModule(torch.nn.Module): + def __init__(self, causal=True, window_size=(-1, -1)): + super(TestModule, self).__init__() + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v, cu_seqlens_q, cu_seqlens_k): + out = wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=self.causal, + window_size=self.window_size + ) + return out + + +class RingAttnVarlenRunner(RingAttnRunnerBase): + """Runner for ring attention variable length tests""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func' + + @property + def function_name(self) -> str: + return 'ring_attn_varlen_func' + + def create_test_module(self, config) -> torch.nn.Module: + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare variable length inputs with cu_seqlens""" + cu_seqlens_tensor = torch.tensor(config.cu_seqlens, dtype=torch.int32, device=device) + total_seqlen = config.cu_seqlens[-1] + + # Create inputs with total sequence length (don't set requires_grad here, base class handles it) + q = torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + k = torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + v = torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v, + 'cu_seqlens_q': cu_seqlens_tensor, + 'cu_seqlens_k': cu_seqlens_tensor + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + single_out = wrap_ring_attn_varlen_func( + inputs['q'], inputs['k'], inputs['v'], + inputs['cu_seqlens_q'], inputs['cu_seqlens_k'], None, + causal=config.causal, + window_size=config.window_size + ) + single_out.retain_grad() + return single_out, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization""" + return { + "q": inputs["q"], + "k": inputs["k"], + "v": inputs["v"], + 'cu_seqlens_q': inputs['cu_seqlens_q'], + 'cu_seqlens_k': inputs['cu_seqlens_k'] + } + + +def run_ring_attn_correctness_test(**kwargs): + """Legacy function for backward compatibility""" + runner = RingAttnVarlenRunner() + runner.run_correctness_test(**kwargs) + + +if __name__ == "__main__": + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + runner = RingAttnVarlenRunner() + runner.main(**kwargs) diff --git a/tests/customized_ops/ring_attn/runner_base.py b/tests/customized_ops/ring_attn/runner_base.py new file mode 100644 index 00000000..2fcd7790 --- /dev/null +++ b/tests/customized_ops/ring_attn/runner_base.py @@ -0,0 +1,290 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base runner framework for ring attention correctness tests. +This module provides common functionality for both ring_attn and ring_attn_varlen test runners. +""" + +import os +import sys +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Tuple, Union + +import torch +import torch.distributed as dist +import nnscaler +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType + +from nnscaler.customized_ops.ring_attention.core.utils import set_seed, log +from configs import get_config + + +class RingAttnRunnerBase(ABC): + """Base class for ring attention test runners""" + + @property + @abstractmethod + def function_signature(self) -> str: + """Return the function signature to look for in the graph""" + pass + + @property + @abstractmethod + def function_name(self) -> str: + """Return the function name for partitioning""" + pass + + @abstractmethod + def create_test_module(self, config) -> torch.nn.Module: + """Create the test module with the appropriate configuration""" + pass + + @abstractmethod + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors based on the configuration and attention type""" + pass + + @abstractmethod + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + pass + + @abstractmethod + def get_dummy_forward_args(self, inputs) -> Dict[str, Any]: + """Get dummy forward arguments for model parallelization""" + pass + + def create_policy(self) -> callable: + """Create partitioning policy for the specific attention type""" + def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: + ngpus = resource.plan_ngpus + partitioned = False + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == self.function_signature: + print(f'\nPartitioned node: {node}\n') + sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=0, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + if not partitioned: + print(f"WARNING: No {self.function_name} found in graph for partitioning") + return graph + return policy + + def initialize_distributed(self): + """Initialize distributed environment""" + # Check CUDA availability first + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available") + sys.exit(1) + + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + # Check if we have enough GPUs + available_gpus = torch.cuda.device_count() + if available_gpus < world_size: + print(f"ERROR: Test requires {world_size} GPUs, but only {available_gpus} available") + sys.exit(1) + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + device_count = torch.cuda.device_count() + device = rank % device_count + try: + torch.cuda.set_device(device) + except Exception as e: + print(f"ERROR: Failed to set CUDA device {device}: {e}") + sys.exit(1) + + print(f"[INFO] world_size:{world_size}, rank:{rank}, available_gpus:{available_gpus}") + + try: + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + except Exception as e: + print(f"ERROR: Failed to initialize process group: {e}") + sys.exit(1) + + # Initialize nnscaler + nnscaler.init() + return world_size, rank + + def get_tolerances(self, dtype: str) -> Dict[str, float]: + """Get tolerance values based on data type""" + if dtype == "bf16": + return dict(atol=2.5e-2, rtol=2.5e-2) + elif dtype == "fp16": + return dict(atol=5e-3, rtol=5e-3) + else: + return dict(atol=2.5e-2, rtol=2.5e-2) + + def print_debug_info(self, single_out, para_out, single_grads, para_grads, rank_id): + """Print debug information when correctness test fails""" + if rank_id == 0: + print("✗ Correctness test FAILED!") + # Print detailed error information + log("single out", single_out, rank0_only=True) + log("multi out", para_out, rank0_only=True) + log("out diff", single_out - para_out, rank0_only=True) + + for i, (single_grad, para_grad, name) in enumerate(zip(single_grads, para_grads, ['q', 'k', 'v'])): + log(f"single d{name}", single_grad, rank0_only=True) + log(f"multi d{name}", para_grad, rank0_only=True) + log(f"d{name} diff", single_grad - para_grad, rank0_only=True) + + def print_success_info(self, rank_id, config_name=None): + """Print success information""" + if rank_id == 0: + config_suffix = f" for config '{config_name}'" if config_name else "" + print(f"✓ Correctness test PASSED{config_suffix}!") + + def run_correctness_test(self, config_name: str, dtype: str = "bf16", **kwargs): + """Run correctness test with the specific attention implementation""" + # Initialize distributed + world_size, rank = self.initialize_distributed() + rank_id = torch.distributed.get_rank() + + # Get configuration + config = get_config(config_name) + torch_dtype = torch.bfloat16 if dtype == "bf16" else torch.float16 + + if rank_id == 0: + print(f"Testing {self.function_name} correctness") + print(f"Configuration: {config.name}") + print(f" Batch size: {config.batch_size}") + print(f" Sequence length: {config.max_seqlen}") + print(f" Num heads: {config.num_heads}") + print(f" KV heads: {config.num_kv_heads}") + print(f" Head dim: {config.head_dim}") + print(f" Data type: {dtype}") + print(f" World size: {world_size}") + print("=" * 60) + + # Set seed for reproducibility + set_seed(42 + rank_id) + device = torch.device(f"cuda:{rank_id}") + + # Prepare inputs (implementation-specific) + inputs = self.prepare_inputs(config, device, torch_dtype) + + # Broadcast inputs to ensure consistency across ranks + for tensor in inputs.values(): + if isinstance(tensor, torch.Tensor): + dist.broadcast(tensor, src=0) + dist.barrier() + + # Setup models + model = self.create_test_module(config) + + # Create parallel model + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + + parallel_model = parallelize( + model, + dummy_forward_args=self.get_dummy_forward_args(dummy_args), + pas_policy=self.create_policy(), + compute_config=ComputeConfig(world_size, world_size), + reuse=ReuseType.OVERRIDE + ) + parallel_model = parallel_model.cuda() + parallel_model.train() + + # Run correctness test + print("Running correctness test..." if rank_id == 0 else "", end="") + + # Single mode for reference + single_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + single_inputs[k] = v.detach().clone().requires_grad_() + else: + single_inputs[k] = v.detach().clone() + else: + single_inputs[k] = v + + single_out, single_grad_tensors = self.run_single_gpu_reference(single_inputs, config) + + # Create gradient for backward pass + dout = torch.randn_like(single_out, device=device, dtype=torch_dtype) + # Ensure dout is consistent across all ranks + dist.broadcast(dout, src=0) + single_out.backward(dout) + + # Extract single gradients + single_grads = [tensor.grad for tensor in single_grad_tensors] + + # Parallel mode for correctness + para_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + para_inputs[k] = v.detach().clone().requires_grad_() + else: + para_inputs[k] = v.detach().clone() + else: + para_inputs[k] = v + + para_out = parallel_model(**para_inputs) + para_out.backward(dout) + parallel_model.sync_grad() + + # Extract gradients for q, k, v tensors + para_grads = [para_inputs[k].grad for k in ['q', 'k', 'v']] + + print(" Done!" if rank_id == 0 else "") + + # Check correctness with tolerances + tols = self.get_tolerances(dtype) + + # Verify outputs and gradients + try: + torch.testing.assert_close(single_out, para_out, **tols) + for single_grad, para_grad in zip(single_grads, para_grads): + torch.testing.assert_close(single_grad, para_grad, **tols) + + self.print_success_info(rank_id, config_name) + + except AssertionError as e: + self.print_debug_info(single_out, para_out, single_grads, para_grads, rank_id) + raise e + + dist.destroy_process_group() + + def main(self, **kwargs): + """Main entry point for the test runner""" + # Filter out torch.distributed.launch arguments + filtered_kwargs = {} + for k, v in kwargs.items(): + if k.startswith('--'): + # Remove leading '--' from argument names + k = k[2:].replace('-', '_') + if k not in ['local_rank', 'local-rank']: # Filter out torch.distributed.launch args + filtered_kwargs[k] = v + + # Convert string arguments back to appropriate types + for numeric_arg in ['batch_size', 'num_heads', 'head_dim', 'max_seqlen']: + if numeric_arg in filtered_kwargs and filtered_kwargs[numeric_arg] is not None: + filtered_kwargs[numeric_arg] = int(filtered_kwargs[numeric_arg]) + + for float_arg in ['rtol', 'atol']: + if float_arg in filtered_kwargs and filtered_kwargs[float_arg] is not None: + filtered_kwargs[float_arg] = float(filtered_kwargs[float_arg]) + + self.run_correctness_test(**filtered_kwargs) \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_base.py b/tests/customized_ops/ring_attn/test_base.py new file mode 100644 index 00000000..211ee567 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_base.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base test framework for ring attention tests. +This module provides common functionality for both ring_attn and ring_attn_varlen tests. +""" + +import os +import sys +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, Any, List + +import pytest +import torch + +from configs import ( + DEFAULT_CORRECTNESS_CONFIGS, + DEFAULT_MULTI_GPU_CONFIGS, + DEFAULT_GQA_CONFIGS, + get_config, + list_configs +) + + +class RingAttnTestBase(ABC): + """Base class for ring attention tests""" + + @property + @abstractmethod + def runner_script_name(self) -> str: + """Return the name of the runner script (e.g., 'run_correctness.py')""" + pass + + @property + @abstractmethod + def test_name_prefix(self) -> str: + """Return the prefix for test names (e.g., 'ring_attn' or 'ring_attn_varlen')""" + pass + + def _check_gpu_availability(self, required_gpus: int): + """Check if enough GPUs are available and skip test if not""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + available_gpus = torch.cuda.device_count() + if available_gpus < required_gpus: + pytest.skip(f"Test requires {required_gpus} GPUs, but only {available_gpus} available") + + def _get_project_root(self): + """Get the absolute path to nnscaler root directory""" + current_dir = os.path.dirname(__file__) # tests/customized_ops/ring_attn/ + return os.path.abspath(os.path.join(current_dir, "../../../")) + + def get_bash_arguments(self, num_gpus_per_node: int, **kwargs) -> List[str]: + """Generate command line arguments for running the test script""" + args = [ + "python3", + "-m", + "torch.distributed.launch", + "--nproc-per-node=" + str(num_gpus_per_node), + ] + + project_root = self._get_project_root() + script_path = os.path.join( + project_root, "tests", "customized_ops", "ring_attn", + self.runner_script_name + ) + args.append(script_path) + + for k, v in kwargs.items(): + args.append(f"{k}={v}") + return args + + def run_test_subprocess(self, num_gpus: int, **kwargs): + """Run test using subprocess with the configured runner script""" + # Check GPU availability before running subprocess + self._check_gpu_availability(num_gpus) + + subprocess.run( + self.get_bash_arguments( + num_gpus_per_node=num_gpus, + **kwargs + ), + check=True, + cwd=self._get_project_root() + ) + + # Common test methods that can be used by both ring_attn and ring_attn_varlen + + def run_correctness_basic(self, dtype: str, config_name: str): + """Test correctness with different configurations""" + num_gpus = 2 # Default to 2 GPUs for correctness tests + config = get_config(config_name) + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_multi_gpu_scaling(self, num_gpus: int, config_name: str): + """Test with different numbers of GPUs""" + self.run_test_subprocess( + num_gpus=num_gpus, + dtype="bf16", + config_name=config_name, + ) + + def run_comprehensive_configs(self, dtype: str): + """Test all available configurations (comprehensive test)""" + num_gpus = 2 + + # Test a selection of configurations + test_configs = ["tiny", "small", "medium"] + + for config_name in test_configs: + config = get_config(config_name) + # Skip very large configs for comprehensive test + if config.max_seqlen > 16384: + continue + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_gqa_correctness(self, dtype: str, config_name: str): + """Test GQA correctness with Qwen model configurations""" + num_gpus = 2 + config = get_config(config_name) + + # Ensure it's actually a GQA config + assert config.is_gqa, f"Configuration {config_name} should be GQA" + assert config.num_kv_heads < config.num_heads, f"Configuration {config_name} should have fewer KV heads" + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_sliding_window(self, dtype: str, config_name: str): + """Test with sliding window configurations""" + num_gpus = 2 + config = get_config(config_name) + + # Ensure it's actually a sliding window config + assert config.window_size != (-1, -1), f"Configuration {config_name} should have sliding window" + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + +def create_parametrized_tests(test_class: RingAttnTestBase): + """ + Factory function to create parametrized test methods for a test class. + This reduces code duplication between ring_attn and ring_attn_varlen tests. + """ + + # Correctness tests with different dtypes and configs + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + @pytest.mark.parametrize("config_name", DEFAULT_CORRECTNESS_CONFIGS) + def test_correctness(dtype, config_name): + """Test correctness with different configurations""" + instance = test_class() + instance.run_correctness_basic(dtype, config_name) + + # Multi-GPU tests + @pytest.mark.parametrize("num_gpus", [2, 4]) + @pytest.mark.parametrize("config_name", DEFAULT_MULTI_GPU_CONFIGS) + def test_multi_gpu(num_gpus, config_name): + """Test with different numbers of GPUs""" + instance = test_class() + instance.run_multi_gpu_scaling(num_gpus, config_name) + + # Comprehensive tests + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_all_configs(dtype): + """Test all available configurations (comprehensive test)""" + instance = test_class() + instance.run_comprehensive_configs(dtype) + + # GQA tests + @pytest.mark.parametrize("dtype", ["bf16"]) + @pytest.mark.parametrize("config_name", DEFAULT_GQA_CONFIGS) + def test_gqa_correctness(dtype, config_name): + """Test GQA correctness with Qwen model configurations""" + instance = test_class() + instance.run_gqa_correctness(dtype, config_name) + + # Sliding window tests + @pytest.mark.parametrize("dtype", ["bf16"]) + @pytest.mark.parametrize("config_name", ["small_window", "medium_window"]) + def test_sliding_window(dtype, config_name): + """Test with sliding window configurations""" + instance = test_class() + instance.run_sliding_window(dtype, config_name) + + return { + f'test_{test_class().test_name_prefix}_correctness': test_correctness, + f'test_{test_class().test_name_prefix}_multi_gpu': test_multi_gpu, + f'test_{test_class().test_name_prefix}_all_configs': test_all_configs, + f'test_{test_class().test_name_prefix}_gqa_correctness': test_gqa_correctness, + f'test_{test_class().test_name_prefix}_sliding_window': test_sliding_window, + } \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_ring_attn.py b/tests/customized_ops/ring_attn/test_ring_attn.py new file mode 100644 index 00000000..bcb47f16 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_ring_attn.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Correctness Tests + +This module tests the correctness of regular ring attention (non-variable length). +It uses the shared test base framework to avoid code duplication. +""" + +import pytest +import torch + +# Skip all tests if flash_attn_func is not available +try: + from flash_attn import flash_attn_func +except ImportError: + pytest.skip("flash_attn_func not available", allow_module_level=True) + +from test_base import RingAttnTestBase, create_parametrized_tests +from configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS + + +class RingAttnTest(RingAttnTestBase): + """Test class for regular ring attention""" + + @property + def runner_script_name(self) -> str: + return "ring_attn_runner.py" + + @property + def test_name_prefix(self) -> str: + return "ring_attn" + + +# Create parametrized test functions using the factory +test_functions = create_parametrized_tests(RingAttnTest) + +# Assign test functions to module globals for pytest discovery +test_ring_attn_correctness = test_functions['test_ring_attn_correctness'] +test_ring_attn_multi_gpu = test_functions['test_ring_attn_multi_gpu'] +test_ring_attn_all_configs = test_functions['test_ring_attn_all_configs'] +test_ring_attn_gqa_correctness = test_functions['test_ring_attn_gqa_correctness'] +test_ring_attn_sliding_window = test_functions['test_ring_attn_sliding_window'] + + +if __name__ == "__main__": + # Run specific test if called directly + test_instance = RingAttnTest() + test_instance.run_correctness_basic("bf16", "small") + + # Example of running GQA test + # test_instance.run_gqa_correctness("bf16", "qwen3_4b") + + # Example of running sliding window test + # test_instance.run_sliding_window("bf16", "small_window") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_ring_attn_varlen.py b/tests/customized_ops/ring_attn/test_ring_attn_varlen.py new file mode 100644 index 00000000..04c26035 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_ring_attn_varlen.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Correctness Tests + +This module tests the correctness of ring attention with variable length sequences. +It uses the shared test base framework to avoid code duplication. +""" + +import pytest +import torch + +# Skip all tests if flash_attn_varlen_func is not available +try: + from flash_attn import flash_attn_varlen_func +except ImportError: + pytest.skip("flash_attn_varlen_func not available", allow_module_level=True) + +from test_base import RingAttnTestBase, create_parametrized_tests +from configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS + + +class RingAttnVarlenTest(RingAttnTestBase): + """Test class for ring attention variable length""" + + @property + def runner_script_name(self) -> str: + return "ring_attn_varlen_runner.py" + + @property + def test_name_prefix(self) -> str: + return "ring_attn_varlen" + + +# Create parametrized test functions using the factory +test_functions = create_parametrized_tests(RingAttnVarlenTest) + +# Assign test functions to module globals for pytest discovery +test_ring_attn_varlen_correctness = test_functions['test_ring_attn_varlen_correctness'] +test_ring_attn_varlen_multi_gpu = test_functions['test_ring_attn_varlen_multi_gpu'] +test_ring_attn_varlen_all_configs = test_functions['test_ring_attn_varlen_all_configs'] +test_ring_attn_varlen_gqa_correctness = test_functions['test_ring_attn_varlen_gqa_correctness'] +test_ring_attn_varlen_sliding_window = test_functions['test_ring_attn_varlen_sliding_window'] + + +if __name__ == "__main__": + # Run specific test if called directly + test_instance = RingAttnVarlenTest() + test_instance.run_correctness_basic("bf16", "small") + + # Example of running GQA test + # test_instance.run_gqa_correctness("bf16", "qwen3_4b") diff --git a/tests/customized_ops/ring_attn/test_zigzag_attn.py b/tests/customized_ops/ring_attn/test_zigzag_attn.py new file mode 100644 index 00000000..6bb885bf --- /dev/null +++ b/tests/customized_ops/ring_attn/test_zigzag_attn.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag attention correctness tests. + +This module contains correctness tests for the zigzag attention implementation. +Note: Zigzag attention only supports causal=True and window_size=(-1, -1). + +Usage: + python -m pytest test_zigzag_attn.py -v + python -m pytest test_zigzag_attn.py::TestZigzagAttn::test_zigzag_attn_tiny_bf16 -v +""" + +import pytest + +# Skip all tests if flash_attn_func is not available +try: + from flash_attn import flash_attn_func +except ImportError: + pytest.skip("flash_attn_func not available", allow_module_level=True) + +from test_base import RingAttnTestBase + + +class TestZigzagAttn(RingAttnTestBase): + """Test class for zigzag attention correctness testing""" + + @property + def runner_script_name(self) -> str: + return "zigzag_attn_runner.py" + + @property + def test_name_prefix(self) -> str: + return "zigzag_attn" + + # Basic correctness tests + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_zigzag_attn_tiny(self, dtype): + """Test zigzag attention with tiny configuration""" + self.run_correctness_basic(dtype, "zigzag_tiny") + + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_zigzag_attn_small(self, dtype): + """Test zigzag attention with small configuration""" + self.run_correctness_basic(dtype, "zigzag_small") + + @pytest.mark.parametrize("dtype", ["bf16"]) + def test_zigzag_attn_medium(self, dtype): + """Test zigzag attention with medium configuration""" + self.run_correctness_basic(dtype, "zigzag_medium") + + # Multi-GPU tests + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_zigzag_attn_multi_gpu_small(self, num_gpus): + """Test zigzag attention with small config on multiple GPUs""" + self.run_multi_gpu_scaling(num_gpus, "zigzag_small") + + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_zigzag_attn_multi_gpu_medium(self, num_gpus): + """Test zigzag attention with medium config on multiple GPUs""" + self.run_multi_gpu_scaling(num_gpus, "zigzag_medium") + + # GQA test + def test_zigzag_attn_gqa(self): + """Test zigzag attention with GQA configuration""" + self.run_gqa_correctness("bf16", "zigzag_gqa") + + +if __name__ == "__main__": + # For direct execution, run a simple test + test_instance = TestZigzagAttn() + test_instance.run_correctness_basic("bf16", "zigzag_tiny") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/zigzag_attn_runner.py b/tests/customized_ops/ring_attn/zigzag_attn_runner.py new file mode 100644 index 00000000..5b1ba465 --- /dev/null +++ b/tests/customized_ops/ring_attn/zigzag_attn_runner.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag attention test runner implementation. +This module provides the specific runner for testing zigzag attention. +Note: Zigzag attention only supports causal=True and window_size=(-1, -1). +""" + +import os +import sys +from typing import Dict, Any + +import torch +import torch.nn as nn + +from nnscaler.customized_ops.ring_attention.zigzag_attn import wrap_zigzag_attn_func +from runner_base import RingAttnRunnerBase + + +class ZigzagAttnRunner(RingAttnRunnerBase): + """Zigzag attention test runner""" + + @property + def function_signature(self) -> str: + return "wrap_zigzag_attn_func" + + @property + def function_name(self) -> str: + return "wrap_zigzag_attn_func" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for zigzag attention""" + class TestModule(nn.Module): + def __init__(self, causal=True, window_size=(-1, -1)): + super().__init__() + # Zigzag attention only supports causal=True and window_size=(-1, -1) + assert causal is True, "Zigzag attention only supports causal=True" + assert window_size == (-1, -1), "Zigzag attention only supports window_size=(-1, -1)" + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v): + # Note: zigzag_attn always uses causal=True and window_size=(-1, -1) + return wrap_zigzag_attn_func(q, k, v, causal=self.causal, window_size=self.window_size) + + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare inputs for zigzag attention""" + batch_size = config.batch_size + max_seqlen = config.max_seqlen + num_heads = config.num_heads + num_kv_heads = config.num_kv_heads + head_dim = config.head_dim + + # Create input tensors + q = torch.randn(batch_size, max_seqlen, num_heads, head_dim, device=device, dtype=torch_dtype) + k = torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype) + v = torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + # Note: zigzag_attn always uses causal=True and window_size=(-1, -1) + output = wrap_zigzag_attn_func( + inputs['q'], inputs['k'], inputs['v'], + causal=config.causal, window_size=config.window_size) + output.retain_grad() + + return output, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs) -> Dict[str, Any]: + """Get dummy forward arguments for model parallelization""" + return { + 'q': inputs['q'], + 'k': inputs['k'], + 'v': inputs['v'] + } + + +def main(): + """Main entry point for command line execution""" + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + + runner = ZigzagAttnRunner() + runner.main(**kwargs) + + +if __name__ == "__main__": + main() \ No newline at end of file From b9c9dbe9296ab90f6666d19c874cd2064794e35c Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 29 Oct 2025 07:28:03 +0000 Subject: [PATCH 1847/1892] Merged PR 2418: [Feat] Add new-style policy API Add new-style policy API --- nnscaler/graph/parser/parser.py | 1 + nnscaler/graph/tracer/concrete_tracer.py | 16 +- nnscaler/graph/tracer/operator_patcher.py | 23 +- nnscaler/ir/cten.py | 67 +- nnscaler/ir/tensor.py | 2 +- nnscaler/parallel.py | 4 + nnscaler/policies.py | 433 ++++++++++- tests/graph/parser/test_ast_transformer.py | 12 +- tests/parallel_module/common.py | 13 + tests/test_policies.py | 817 ++++++++++++++++++++- 10 files changed, 1370 insertions(+), 18 deletions(-) diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 7c6ca0b0..8e0692af 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -499,6 +499,7 @@ def _set_node_meta(cls, node: torch.fx.Node, ir_node: Union[IRCell, Any]): ir_node.op_context = node.meta.get('op_context') module_stack = node.meta.get('nn_module_stack') ir_node.module_stack = module_stack + ir_node.call_expr = node.meta.get('call_expr') comment = str(node.meta.get('frame_record', '')) if comment: ir_node.comment = comment diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index f4731300..5a86969c 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -90,6 +90,7 @@ def __init__(self, strategy, record_frames = False): self.scope = Scope("", None) self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} + self.call_expr_stack = [] self.strategy = TRACE_STRATEGY[strategy](self) self.record_frames = record_frames self.patcher = FunctionPatcher() @@ -155,10 +156,14 @@ def _check_torch_compile_function(self, func): break func = getattr(func, '__wrapped__', None) - def on_function_call(self, func): + def on_function_call(self, func, expr): + self.call_expr_stack.append(expr) self._track_cache_wrapped_function(func) self._check_torch_compile_function(func) + def on_function_call_end(self): + self.call_expr_stack.pop() + @contextmanager def do_temp_call_origin(self): temp_call_origin = self.temp_call_origin @@ -209,6 +214,15 @@ def create_node(self, kind : str, target : Target, else: node.meta['nn_module_stack'] = collections.OrderedDict() + if self.call_expr_stack: + last_call_expr = None + for item in reversed(self.call_expr_stack): + # if not found, leave last_call_expr as None + if item: + last_call_expr = item + break + node.meta['call_expr'] = last_call_expr + def unwrap_nested_proxy(proxy: ep.ConcreteProxy): return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) diff --git a/nnscaler/graph/tracer/operator_patcher.py b/nnscaler/graph/tracer/operator_patcher.py index 5149d9e5..783d9627 100644 --- a/nnscaler/graph/tracer/operator_patcher.py +++ b/nnscaler/graph/tracer/operator_patcher.py @@ -171,7 +171,11 @@ def visit_Call(self, node: ast.Call): self.modified = True return self.generic_visit(ast.Call( func=ast.Name(id=self.proxy_call_name, ctx=ast.Load()), - args=[node.func, *node.args], + args=[ + node.func, + ast.fix_missing_locations(ast.Constant(value=ast.unparse(node))), + *node.args + ], keywords=node.keywords, )) else: @@ -311,7 +315,7 @@ def patch_func_helper(self, func): # use func.__code__.co_filename to make the new function easily debuggable. compile(new_tree, func_inner.__code__.co_filename, 'exec'), { - self.proxy_call_name: OperatorPatcherContext.patch_run, + self.proxy_call_name: OperatorPatcherContext._patch_run, **func_inner.__globals__, **closure_dict, }, @@ -346,10 +350,19 @@ def __exit__(self, exc_type, exc_value, tb): return exc_type is None @staticmethod - def patch_run(func, *args, **kwargs): + def _patch_run(func, expr, *args, **kwargs): assert OperatorPatcherContext.ctx_tracer is not None assert OperatorPatcherContext.ctx_patcher is not None with wrap_utils.do_temp_call_origin(): - OperatorPatcherContext.ctx_tracer.on_function_call(func) + OperatorPatcherContext.ctx_tracer.on_function_call(func, expr) new_func = OperatorPatcherContext.ctx_patcher.patch_func_or_module(func) - return new_func(*args, **kwargs) + + ret = new_func(*args, **kwargs) + + with wrap_utils.do_temp_call_origin(): + OperatorPatcherContext.ctx_tracer.on_function_call_end() + return ret + + @staticmethod + def patch_run(func, *args, **kwargs): + return OperatorPatcherContext._patch_run(func, '', *args, **kwargs) diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index bc8978ea..3c720743 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -27,7 +27,7 @@ from nnscaler.ir.unique import IDGenerator from nnscaler.ir.dtype import DTypeInfo -from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE +from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE, load_type NestedVarOrStatic = Union[Any, 'IRObject', List['IRObject'], 'IRTensor'] @@ -78,6 +78,13 @@ def __init__(self, self._comment: Optional[str] = None # the module stack that preserves the hierarchy information self._module_stack: Optional[OrderedDict[str, Any]] = None + # the original call expression + # Note: + # 1. some cells may not have call expression if the cell is not from function call (e.g., __getitem__) + # 2. call_expr can be inaccurate when function call happens + # inside pytorch official module (like in torch.nn namespace) forward, + # (e.g., F.linear inside nn.Linear), in this case, call_expr will be module call expression. + self._call_expr: Optional[str] = None # the operation context information self._op_context: Optional[Dict[str, Any]] = None @@ -378,6 +385,22 @@ def comment(self, info: str): @property def module_stack(self) -> Optional[OrderedDict[str, Any]]: + """ + Get the module stack, which preserves the hierarchy information + of modules this cell belongs to. + For example, if this cell is from model.submodule.layers.0.block0.conv2d, + then the module stack will be: + OrderedDict([ + ('model.submodule', ), + ('model.submodule.layers.0.block0', ), + ('model.submodule.layers.0.block0.conv2d', ), + ]) + + Please note + 1. Root module (e.g., model) is not included in the stack. + 2. Only modules that have `.forward` function are included in the stack, + so in above example, `torch.nn.ModuleList` is not included. + """ return self._module_stack @module_stack.setter @@ -387,6 +410,48 @@ def module_stack(self, stack: OrderedDict[str, Any]): """ self._module_stack = stack + @property + def module_class_chain(self) -> list[type[torch.nn.Module]]: + """ + Get the module chains the IRCell belongs to. + If module stack is None or empty, return []. + """ + if not self._module_stack: + return [] + return list(self._module_stack.values()) + + @property + def fqn(self) -> str: + """ + Get the fully qualified module name the IRCell belongs to. + If module stack is None or empty, return ''. + """ + if not self._module_stack: + return '' + return list(self._module_stack.keys())[-1] + + @property + def call_expr(self) -> Optional[str]: + return self._call_expr + + @call_expr.setter + def call_expr(self, expr: Optional[str]): + self._call_expr = expr + + @property + def fn(self) -> Optional[Callable]: + """ + Get the function of this cell based on its signature. + Return None if the function cannot be loaded. (e.g. virtual ops like `self_getattr`) + + Returns: + Callable: the function object + """ + try: + return load_type(self.signature) + except Exception as e: + return None + @property def op_context(self) -> Optional[Dict[str, Any]]: return self._op_context diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index b24f45ca..f3f2c9eb 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -563,7 +563,7 @@ def ndims(self) -> int: def as_attr(self): raise RuntimeError("as_attr is not allowed for SubTensor") - def splitdims(self) -> Tuple[int]: + def splitdims(self) -> Tuple[int, ...]: """! Get partitioned dimensions diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 06a6af84..7dc60e4c 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1001,6 +1001,10 @@ def __init__(self, init_params=True): if not pas_policy in _PREDEFINED_POLICIES: raise ValueError(f"Invalid pas_policy: {pas_policy}") pas_policy = _PREDEFINED_POLICIES[pas_policy] + else: + if not callable(pas_policy): + raise ValueError("pas_policy should be a callable or a predefined policy name") + pas_policy = partial(policies.fn, policy=pas_policy) is_module_class = inspect.isclass(module_or_module_class) module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ diff --git a/nnscaler/policies.py b/nnscaler/policies.py index f1db5858..cd3ed94a 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -18,8 +18,10 @@ IRDataOperation is recommended to be replicated to all devices. """ +import ast +from dataclasses import dataclass, field import logging -from typing import List, Optional, TYPE_CHECKING +from typing import List, Literal, Optional, TYPE_CHECKING, Callable, Iterable, Union import random import torch @@ -30,9 +32,12 @@ from nnscaler.graph import IRGraph from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.segment import IRSegment -from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from nnscaler.ir import IRCell, IRSubTensor, IRFullTensor +from nnscaler.ir.cten import IR +from nnscaler.runtime.function import identity, multiref if TYPE_CHECKING: @@ -347,3 +352,427 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: ) return parallelize_graph(graph, autodist_cfg) + + +@dataclass(unsafe_hash=True, frozen=True) +class OpPartition: + """ + OpPartition represents a partition plan for an operator dimension. + """ + input: int + dim: int + + +@dataclass +class OpPlan: + """ + OpPlan represents the distributed plan for an operator. + """ + op: IRFwOperation + recompute_id: int = -1 # -1 means no recompute + stage_id: int = -1 # pipeline stage id, -1 means following the previous op's stage + + # OpPartition: user specified partition plan + # You only need to specify one partition plan here. + # For example, torch.matmul has annotation of `m k+, k+ n -> m n`, + # If you want to partition the matmul on the k dimension, + # you can set OpPartition(input=0, dim=1) or OpPartition(input=1, dim=0). + # They are equivalent. + # None: replicated + # 'auto': auto partition based on the input tensor partition info + # 1. if any of the input tensors is value partitioned, we replicate the op + # TODO: is it too strict? + # 2. if any of the input tensors is partitioned on a dim, + # we will try to partition the op on the same dim first, + # if the partition is invalid, we replicate the op + # 3. if all the input tensor is replicated, we replicate the op + partition: OpPartition | None | Literal['auto'] = None # partition plan + # for future extension + # don't use it now. + partitions: List[OpPartition | None] = field(default_factory=list) # multiple partition plans + + def __post_init__(self): + if self.partition is not None and len(self.partitions) > 0: + raise ValueError("Only one of partition and partitions can be set") + + if len(self.partitions) > 1: + raise NotImplementedError("Multiple partitions are not supported yet") + + if len(self.partitions) == 1: + self.partition = self.partitions[0] + self.partitions = [] + + +def get_layer_index(fqn: str) -> int: + """ + Extract the layer index from full qualified name. + If there are multiple integers in the name, raise ValueError. + """ + nums = [int(s) for s in fqn.split(".") if s.isdigit()] + if len(nums) != 1: + raise ValueError(f"Name {fqn} should only contain one integer") + return nums[0] + + +def get_called_self_module_name(node_call_expr: str) -> str: + """ + Get the called module name from the node's call expr by ast. + For example: + self.up_proj(x) -> up_proj + self.act_fn(self.gate_proj(x)) -> act_fn + self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -> down_proj + torch.tanh(x) -> '' # because it's not called from self + self.up_proj(x).transpose() -> '' # because it's an attribute call + + Other cases return empty string. + + NOTE: regex is not easy to make it work + + """ + + if not node_call_expr: + return '' + call_expr: ast.Call = ast.parse(node_call_expr, mode='eval').body # type: ignore + if isinstance(call_expr, ast.Call): # self.up_proj(x) + if isinstance(call_expr.func, ast.Attribute): # self.up_proj + if isinstance(call_expr.func.value, ast.Name) and call_expr.func.value.id == 'self': + return call_expr.func.attr # up_proj + return '' + + +def get_pas_ops(graph: IRGraph) -> List[IRFwOperation]: + """ + Get all operators in the graph that can set operator plan. + When we write a policy, only ops returned from this function need to be considered. + + Args: + graph: the input IRGraph + + Returns: + List[IRFwOperation]: list of IRFwOperation nodes + """ + return graph.select(ntype=IRFwOperation) + + +def fn( + graph: IRGraph, cfg: 'ComputeConfig', + policy: Union[ + Callable[[IRGraph, 'ComputeConfig'], IRGraph], + Callable[[IRGraph, 'ComputeConfig'], Iterable[OpPlan]], + ] +) -> IRGraph: + """ + General policy function based on user-defined policy. + The user-defined policy can either return the final IRGraph, or + return a list of OpPlan to describe the distributed plan for each operator. + + To write a new-style policy, the most important part is to locate the operator node in the graph. + Here are some tips: + 1. use `node.name` to get the operator name. + 2. use `node.fn` to get the operator function. + 3. use `node.module_stack` to get the module stack info. + 4. use `node.module_class_chain` to get the module class chain. + 5. use `node.call_expr` to get the call expression string. And you can user `ast.parse` to parse it. + 6. use `get_layer_index` to get the layer index in a torch.nn.ModuleList. + 7. use `get_called_self_module_name` to get the called self module name from the call expression. + 8. use `node.inputs()` the get the input tensors of the operator. + We can further check whether the input tensor is a parameter by `tensor.is_param`, + or get the full name of the parameter by `tensor.name`, etc. + 9. insert anchors in code with `nnscaler.anchor` to help locate the operator (intrusive way). + + A good way to locate the operator will be like: + 1. Locate the module first by module_class_chain (`target_module in node.module_class_chain`) + 2. If the module are used multiple times (e.g., in ModuleList), + locate further by layer index (`get_layer_index`) or `node.fqn`. + 3. Once the module is located, + we can further locate the operator by + `node.name`,`node.call_expr`, `node.fn`, `node.inputs()` (especially the `is_param`/`name` of input) + or other properties. + + Args: + graph: the input IRGraph + cfg: the compute config + policy: the user-defined policy function. It can either return the final IRGraph, + or return an iterable of OpPlan for each operator. + + Returns: + the distributed IRGraph + """ + result = policy(graph, cfg) + if isinstance(result, IRGraph): # traditional policy + return result + + op_plans = {r.op: r for r in result} + ngpus: int = cfg.plan_ngpus + + recompute_groups: dict[int, list[IRFwOperation]] = {} + recompute_last_id: int = -1 + recompute_group_stages: dict[int, int] = {} + + pp_stages: list[list[IRFwOperation]] = [[]] + pp_cur_stage_id = 0 + + # key: IRFullTensor + # value: + # key: stage_id + # value: set of OpPartition in this stage + tensor_splits: dict[IRFullTensor, dict[int, set[OpPartition]]] = {} + # store the last split info for each tensor to help handle auto partition + # None: replicated + # 'value': value partitioned + # int: the partitioned dim + output_tensor_last_split: dict[IRFullTensor, int | None | Literal['value']] = {} + + fw_nodes = dict.fromkeys(graph.select(ntype=IRFwOperation)) + + for node in fw_nodes: + if node not in op_plans: + op_plans[node] = OpPlan(op=node) # default: no partition, stage 0, no recompute + + op_plan = op_plans[node] + + # set pipeline stage id if not set + if op_plan.stage_id == -1: + op_plan.stage_id = pp_cur_stage_id + + # currently we only support partition for IRDimops + if not isinstance(op_plan.op, IRDimops): + if op_plan.partition == 'auto': + op_plan.partition = None + if op_plan.partition is not None: + raise ValueError("Only IRDimops can be partitioned.") + + # list of partitions for the op + # [] means no partition(replicated) + op_partitions = [op_plan.partition] if op_plan.partition is not None else [] + + if op_partitions == ['auto']: + # auto partition based on input tensor partition info + op_partitions = [] # reset to collect partitions + for idx, input in enumerate(op_plan.op.inputs()): + if not isinstance(input, IRSubTensor): + continue + ftensor = input.parent + last_partition_dim = output_tensor_last_split.get(ftensor, None) + if last_partition_dim == 'value': + # value partitioned input, replicate the op + op_partitions = [] + break + elif last_partition_dim is not None: + op_partitions.append(OpPartition(input=idx, dim=last_partition_dim)) + + # final partition plan for the op + # key: input idx, value: partitioned dim + op_partition_map: dict[int, int] = {} + if op_partitions: + # we partition the op based on the first partition plan + # and then check the rest partitions are satisfied or not + op_first_partition = op_partitions[0] + partitioned_nodes = op_plan.op.algorithm('dim')\ + .instantiate(idx=op_first_partition.input, dim=op_first_partition.dim, num=ngpus) + subnode = partitioned_nodes[0] # first subnode carries all necessary partition info + + # collect input partition info + # key: input idx, value: partitioned dim + result_partitions: dict[int, int] = {} + for idx, input in enumerate(subnode.inputs()): + if not isinstance(input, IRSubTensor): + continue + split_dims = input.splitdims() + assert len(split_dims) <= 1, "Internal Error: multiple splitdims in one input" + if split_dims: + result_partitions[idx] = split_dims[0] + + # check the rest partitions + # Note if we only have one partition plan, the check is skipped, we can always partition it + # In fact, if `auto` is not specified, we always have at most one partition plan + for op_partition in op_partitions[1:]: + if op_partition.input not in result_partitions or \ + result_partitions[op_partition.input] != op_partition.dim: + _logger.warning( + f"Operator {op_plan.op} cannot be partitioned as specified: {op_partition}" + f", replicate it instead." + ) + op_partitions = [] + op_partition_map = {} + break + else: + # all partitions are satisfied + # then we can update input/output partition info + + # make sure the first item in op_partition_map is the first partition plan + op_partition_map[op_first_partition.input] = op_first_partition.dim + op_partition_map.update(result_partitions) + + for output in subnode.outputs(): + if not isinstance(output, IRSubTensor): + continue + ftensor = output.parent + if output.valmap != (0, 1): + output_tensor_last_split[ftensor] = 'value' + else: + split_dims = output.splitdims() + assert len(split_dims) <= 1, "Internal Error: multiple splitdims in one output" + if split_dims: + output_tensor_last_split[ftensor] = split_dims[0] + + if op_plan.partition == 'auto': + if not op_partition_map: + op_plan.partition = None + else: + # use the first partition plan, + # which is consistent with the logic above + first_input_idx = list(op_partition_map.keys())[0] + op_plan.partition = OpPartition( + input=first_input_idx, + dim=op_partition_map[first_input_idx] + ) + + # update tensor_splits for input tensors + for idx, input in enumerate(op_plan.op.inputs()): + if not isinstance(input, IRSubTensor): + continue + ftensor = input.parent + if ftensor not in tensor_splits: + tensor_splits[ftensor] = {} + if idx not in op_partition_map: + tensor_splits[ftensor].setdefault(op_plan.stage_id, set()).add(None) + else: + tensor_splits[ftensor].setdefault(op_plan.stage_id, set()).add( + OpPartition(input=idx, dim=op_partition_map[idx])) + + if op_plan.recompute_id != -1: + if op_plan.recompute_id in recompute_group_stages: + if recompute_group_stages[op_plan.recompute_id] != op_plan.stage_id: + raise ValueError("All ops in a recompute group must be in the same stage") + else: + recompute_group_stages[op_plan.recompute_id] = op_plan.stage_id + + if op_plan.recompute_id != recompute_last_id and op_plan.recompute_id in recompute_groups: + raise ValueError("Nodes in a recompute group must be continuous.") + + recompute_groups.setdefault(op_plan.recompute_id, []).append(op_plan.op) + + recompute_last_id = op_plan.recompute_id + + # update pipeline stages + if op_plan.stage_id == pp_cur_stage_id: + pp_stages[pp_cur_stage_id].append(op_plan.op) + elif op_plan.stage_id == pp_cur_stage_id + 1: + pp_cur_stage_id += 1 + pp_stages.append([op_plan.op]) + else: + raise ValueError("Pipeline stage ids must be continuous integers starting from 0") + + if len(op_plans) != len(fw_nodes): + assert len(op_plans) > len(fw_nodes) + for op_plan in op_plans.values(): + if op_plan.op not in fw_nodes: + raise ValueError(f"OpPlan contains operator {op_plan.op} not in the graph or not a forward operator") + + pp_segs = [graph] + nstages = len(pp_stages) + pp_enabled = nstages > 1 + # not all schedulers support pp_size < nstages + pp_size = cfg.pas_config.get('pipeline_size', nstages) + nmicros = cfg.pas_config.get('pipeline_nmicros', None) + scheduler = cfg.pas_config.get('pipeline_scheduler', '1f1b') + tp_size = ngpus // pp_size + + if pp_enabled: + if not cfg.use_end2end: + raise ValueError("Pipeline parallelism requires use_end2end to be True") + if pp_size <= 1: + raise ValueError("pipeline_size must be greater than 1 when pipeline is enabled") + if not nmicros: + raise ValueError("nmicros must be set when pipeline is enabled") + if nstages % pp_size != 0: + raise ValueError(f'invalid pipeline_size {pp_size} for nstages {nstages}') + if ngpus % pp_size != 0: + raise ValueError(f'invalid pipeline_size {pp_size} for ngpus {ngpus}') + else: + if pp_size != 1: + raise ValueError("pipeline_size must be 1 when pipeline is disabled") + + # set recompute groups + for group in recompute_groups.values(): + if len(group) <= 1: + continue + graph.recompute(group) + + # add multiref for shared parameters across stages + # note that we have constrained that shared parameters cannot be partitioned in SPMDSolver, other input tensors + # belonging to the same operator can be partitioned. For example, in some LLMs, the embedding matrix is shared + # with the output layer. In this case, the batch dim / seq dim of the activation tensor can be partitioned. + for ftensor, stage_info in tensor_splits.items(): + if not ftensor.is_param(): + continue + splits = set(k.dim if k is not None else None for v in stage_info.values() for k in v) + find_replicated = None in splits + splits = list(splits) + # For safety, we will add multiref when detecting shared param are all replicated for pipeline parallelism. + # The reason is that stages may have different number of devices, it is hard to synchronize gradients directly + # by inserting reducers although weights are all REPLICAED. + if len(splits) > 1 or (pp_enabled and find_replicated): + _logger.info(f'add multiref for shared param {ftensor}') + graph.multiref(ftensor, comment='shared param') + + # set pipeline stages + if pp_enabled: + graph.staging([s[0] for s in pp_stages]) + pp_segs: list[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + + for stage_id, stage in enumerate(pp_segs): + for node in stage.select(ntype=IRFwOperation): + if node in fw_nodes: + continue + if node.fn == multiref: # skip multiref nodes + continue + assert node.fn == identity, "Internal Error: non-identity node added in staging" + # force identity nodes to be replicated + # these nodes are usually added for data transfer between stages in graph.staging + # TODO: is it possible to have TP here? + op_plans[node] = OpPlan(op=node, stage_id=stage_id, partition=None) + + # add multiref to an activation tensor when the states of the tensor and its grad are different + # among consumers and current segment's outputs + for ftensor, stage_info in tensor_splits.items(): + # Parameter are already handled above + if ftensor.is_grad() or ftensor.is_param(): + continue + + # check if this tensor is in the output of each stage + is_seg_output: dict[int, bool] = {} + for idx, stage in enumerate(pp_segs): + is_seg_output[idx] = IR.contains_object( + stage.outputs(), + lambda x: isinstance(x, IRSubTensor) and x.parent == ftensor + ) + + for idx, splits in stage_info.items(): + stage = pp_segs[idx] + split_list = list(splits) + if len(split_list) > 1 or ( + is_seg_output[idx] and split_list[0] is not None # treat segment output as a consumer + ): + _logger.debug(f'add multiref for {ftensor} in stage {stage}') + stage.multiref(ftensor, comment='activation') + + # stage-wise tensor parallelism + curr_devices = list(range(ngpus)) + for op_plan in op_plans.values(): + idx = op_plan.stage_id % pp_size + devs = curr_devices[idx * tp_size: (idx + 1)* tp_size] + if op_plan.partition is not None: + _tp(graph, op_plan.op, devs, idx=op_plan.partition.input, dim=op_plan.partition.dim) + else: + _replica(graph, op_plan.op, devs) + + # replicate dataloader + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs=list(range(ngpus))) + + if pp_enabled: + cfg.apply_pipeline_scheduler(graph, nstages, nmicros, scheduler) + + return graph diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 51330fd9..548e9024 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -164,12 +164,12 @@ def f(self) -> None: assert modified assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' def f(func_name, type: int, /, *args, **kwargs): - return patched_run(func_name, type, *args, **kwargs) + return patched_run(func_name, 'func_name(type, *args, **kwargs)', type, *args, **kwargs) def g(): - return patched_run(x + y, a, b) + return patched_run(x + y, '(x + y)(a, b)', a, b) class A: def f(self) -> None: - patched_run(patched_run(super).f) + patched_run(patched_run(super, 'super()').f, 'super().f()') ''').strip() @@ -188,10 +188,10 @@ def __init__(self) -> None: modified, new_ast = transform(tree, transfomers) assert modified assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' - x = patched_run(not_, True) + x = patched_run(not_, 'not_(True)', True) def f(func_name, type: int, /, *args, **kwargs): - return patched_run(func_name, type, *args, **kwargs) + return patched_run(func_name, 'func_name(type, *args, **kwargs)', type, *args, **kwargs) class A: def __init__(self) -> None: - patched_run(super(self.__class__, self).__init__) + patched_run(super(self.__class__, self).__init__, 'super(self.__class__, self).__init__()') ''').strip() diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index aa177f8f..43e83178 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -107,6 +107,19 @@ def forward(self, x): return x +class FFN(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = torch.nn.Tanh() + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + def init_distributed(): torch.distributed.init_process_group(backend='nccl') rank = torch.distributed.get_rank() diff --git a/tests/test_policies.py b/tests/test_policies.py index 4f1fa7b0..c03e6673 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -8,9 +8,12 @@ import torch import torch.nn as nn -from nnscaler.parallel import ComputeConfig, parallelize +from nnscaler.parallel import ComputeConfig, _load_parallel_module_class, parallelize +from nnscaler.policies import get_called_self_module_name, get_pas_ops +from tests.parallel_module.common import FFN +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode -from .utils import init_random +from .utils import init_random, replace_all_device_with MBS = 2 DIM = 16 @@ -58,3 +61,813 @@ def test_autodist(): load_module=False ) assert m_new is None + + +def test_call_name(): + assert get_called_self_module_name('self.up_proj(x)') == 'up_proj' + assert get_called_self_module_name('self.act_fn(self.gate_proj(x))') == 'act_fn' + assert get_called_self_module_name('self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))') == 'down_proj' + assert get_called_self_module_name('torch.tanh(x)') == '' + assert get_called_self_module_name('x * y') == '' + assert get_called_self_module_name('self.up_proj(x).transpose()') == '' + + +class FnPolicyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = FFN(4, 8) + + def forward(self, x): + x = x * 2 + x = self.ffn(x) + x = x + 3 + return x + + +def megatron_ffn_policy(graph, cfg): + from nnscaler.ir import IRSubTensor + from nnscaler.policies import OpPlan, OpPartition + + for node in get_pas_ops(graph): + if FFN not in node.module_class_chain: # work on FFN module + continue + + if node.fn in [torch.tanh, torch.mul]: + yield OpPlan(node, partition=OpPartition(input=0, dim=1)) + continue + + assert node.fn == torch.nn.functional.linear + + input1: IRSubTensor = node.input(1) + if not input1.is_param(): # linear weight param + continue + + # we will partition gate_proj/up_proj with column parallelism (tp=ngpus) + # and partition down_proj with row parallelism (tp=ngpus) + + if input1.name.endswith('gate_proj.weight') or input1.name.endswith('up_proj.weight'): + # gate_proj/up_proj + # column parallelism + yield OpPlan(node, partition=OpPartition(input=1, dim=0)) + elif input1.name.endswith('down_proj.weight'): + # down_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + + +def megatron_ffn_policy_auto(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition + + linear_rank = 0 + for node in get_pas_ops(graph): + if FFN not in node.module_class_chain: # work on FFN module + continue + + if node.fn == torch.nn.functional.linear: + if linear_rank in [0, 1]: + # gate_proj/up_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=0)) + else: + assert linear_rank == 2 + # down_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + linear_rank += 1 + else: + # other ops + yield OpPlan(node, partition='auto') + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('policy', [megatron_ffn_policy, megatron_ffn_policy_auto]) +def test_codegen_fn(tmp_path, policy): + parallelize( + FnPolicyModule(), + {'x': torch.randn(2, 4)}, + policy, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicyModule, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + assert fullmap['ffn.gate_proj.weight'].shape == (8, 4) and fullmap['ffn.gate_proj.weight'].sub_shape == (4, 4) + assert fullmap['ffn.up_proj.weight'].shape == (8, 4) and fullmap['ffn.up_proj.weight'].sub_shape == (4, 4) + assert fullmap['ffn.down_proj.weight'].shape == (4, 8) and fullmap['ffn.down_proj.weight'].sub_shape == (4, 4) + + # will generate two communication ops + # one for ffn input + assert _gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + # one for ffn output + assert _gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.allreduce_identity') + + assert len(_gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.')) == 2 + + # Generated code of rank 0 should looks like: + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + + # self.register_parameter('ffn_gate_proj_weight_49', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_gate_proj_weight_49', 5, True, 'ffn.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_up_proj_weight_63', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_up_proj_weight_63', 11, True, 'ffn.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_down_proj_weight_77', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_down_proj_weight_77', 17, True, 'ffn.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def segment118(self, x_25): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1653, in forward, x = x * 2 + # mul_27 = torch.mul(x_25, 2) + # del x_25 + # mul_27 = nnscaler.runtime.adapter.nn.identity_allreduce(mul_27, ranks=[0, 1]) + # # created at IRAdapterGener:local_consumer_multiref + # mul_85, mul_89 = nnscaler.runtime.function.multiref(mul_27, times=2) + # del mul_27 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_51 = torch.nn.functional.linear(mul_85, self.ffn_gate_proj_weight_49, bias=None) + # del mul_85 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_59 = torch.tanh(linear_51) + # del linear_51 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_65 = torch.nn.functional.linear(mul_89, self.ffn_up_proj_weight_63, bias=None) + # del mul_89 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_73 = torch.mul(tanh_59, linear_1_65) + # del tanh_59, linear_1_65 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_79 = torch.nn.functional.linear(mul_1_73, self.ffn_down_proj_weight_77, bias=None) + # del mul_1_73 + # linear_2_35 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_2_79, ranks=[0, 1]) + # del linear_2_79 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1655, in forward, x = x + 3 + # add_26 = torch.add(linear_2_35, 3, alpha=1) + # del linear_2_35 + # return add_26 + + +class FFNDropout(torch.nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = torch.nn.Tanh() + self.dropout = torch.nn.Dropout(p=0.1) + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return self.dropout(down_proj) + + +class FnPolicyModuleList(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + + def forward(self, x): + x = x * 2 + for ffn in self.ffn: + x = ffn(x) + x = x + 3 + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + +def megatron_ffn_policy_list(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition, get_layer_index, get_called_self_module_name + + for node in get_pas_ops(graph): + if FFNDropout not in node.module_class_chain: # work on FFN module + continue + + ffn_idx = get_layer_index(node.fqn) + module_called = get_called_self_module_name(node.call_expr) + + if node.fn == torch.nn.functional.linear: + if module_called in ['gate_proj', 'up_proj']: + # gate_proj/up_proj + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition=OpPartition(input=1, dim=0)) + else: + # down_proj + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition=OpPartition(input=1, dim=1)) + else: + # other ops + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition='auto') + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline(tmp_path): + parallelize( + FnPolicyModuleList(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + 'pipeline_size': 2, + } + ), + gen_savedir=tmp_path, + load_module=False + ) + + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicyModuleList, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + tp_idx = rank // 2 + assert fullmap[f'ffn.{tp_idx}.gate_proj.weight'].shape == (8, 4) and fullmap[f'ffn.{tp_idx}.gate_proj.weight'].sub_shape == (4, 4) + assert fullmap[f'ffn.{tp_idx}.up_proj.weight'].shape == (8, 4) and fullmap[f'ffn.{tp_idx}.up_proj.weight'].sub_shape == (4, 4) + assert fullmap[f'ffn.{tp_idx}.down_proj.weight'].shape == (4, 8) and fullmap[f'ffn.{tp_idx}.down_proj.weight'].sub_shape == (4, 4) + + # will generate two communication ops + # one for ffn input + assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + # one for ffn output + assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.allreduce_identity') + + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.')) == 2 + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, r'ckpt.checkpoint\(recompute')) == 1 + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, r'def recompute\(')) == 1 + + + # Generated code of rank 0 looks like: + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 0 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('ffn_0_gate_proj_weight_168', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_gate_proj_weight_168', 5, True, 'ffn.0.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_up_proj_weight_182', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_up_proj_weight_182', 11, True, 'ffn.0.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_down_proj_weight_196', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_down_proj_weight_196', 17, True, 'ffn.0.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def segment79(self, x_49): + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 243, in forward, x = x * 2 + # mul_51 = torch.mul(x_49, 2) + # del x_49 + # mul_51 = nnscaler.runtime.adapter.nn.identity_allreduce(mul_51, ranks=[0, 1]) + + # def recompute(mul_51): + # # created at IRAdapterGener:local_consumer_multiref + # mul_246, mul_250 = nnscaler.runtime.function.multiref(mul_51, times=2) + # del mul_51 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_170 = torch.nn.functional.linear(mul_246, self.ffn_0_gate_proj_weight_168, bias=None) + # del mul_246 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_178 = torch.tanh(linear_170) + # del linear_170 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_184 = torch.nn.functional.linear(mul_250, self.ffn_0_up_proj_weight_182, bias=None) + # del mul_250 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_192 = torch.mul(tanh_178, linear_1_184) + # del tanh_178, linear_1_184 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_198 = torch.nn.functional.linear(mul_1_192, self.ffn_0_down_proj_weight_196, bias=None) + # del mul_1_192 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_21 = self.training + # linear_2_59 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_2_198, ranks=[0, 1]) + # del linear_2_198 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_60 = torch. + # nn.functional.dropout(linear_2_59, p=0.1, training=ffn_0_dropout_training_21, inplace=False) + # del linear_2_59 + # return dropout_60 + + # dropout_60 = ckpt.checkpoint(recompute, mul_51, use_reentrant=False) + # return dropout_60 + + # def adapter196(self, dropout_60): + # dropout_236 = nnscaler.runtime.adapter.chunk(dropout_60, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(dropout_236, shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # return + + # def adapter207(self): + # gdropout_242 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # gdropout_85 = nnscaler.runtime.adapter.all_gather(gdropout_242, dim=1, ranks=[0, 1]) + # return gdropout_85 + + # def adapter160(self): + # sum_1_50 = nnscaler.runtime.adapter.move((), shape=(), dtype=torch.float32, src=2, dst=0) + # return sum_1_50 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_71): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # x_49 = next(*(dataloader_71, )) + # dropout_60 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_49, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_60, ), requires_grad=False) + # x_278 = next(*(dataloader_71, )) + # dropout_286 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_278, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_286, ), requires_grad=False) + # gdropout_85 = nnscaler.runtime.executor.aexecute(model.adapter207, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gx_73 = nnscaler.runtime.executor.backward('segment79', (x_49, ), (dropout_60, ), (gdropout_85, )) + # del x_49, dropout_60, gdropout_85, gx_73 + # gdropout_287 = nnscaler.runtime.executor.aexecute(model.adapter207, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gx_279 = nnscaler.runtime.executor.backward('segment79', (x_278, ), (dropout_286, ), (gdropout_287, )) + # del x_278, dropout_286, gdropout_287, gx_279 + # sum_1_50 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=True) + # sum_1_306 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=True) + + # def _infer_step(model, dataloader_71): + # _ = None + # x_49 = next(*(dataloader_71, )) + # dropout_60 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_49, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_60, ), requires_grad=False) + # x_278 = next(*(dataloader_71, )) + # dropout_286 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_278, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_286, ), requires_grad=False) + # sum_1_50 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=False) + # sum_1_306 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=False) + # return sum_1_50, sum_1_306 + assert True + + +class FnPolicyModuleSharedWeight(torch.nn.Module): + def __init__(self): + super().__init__() + self.input_projection = torch.nn.Linear(4, 4, bias=False) + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + self.output_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection.weight = self.input_projection.weight # share weight + + def forward(self, x): + x = self.input_projection(x) + for ffn in self.ffn: + x = ffn(x) + x = self.output_projection(x) + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline_shared_weight(tmp_path): + parallelize( + FnPolicyModuleSharedWeight(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + 'pipeline_size': 2, + } + ), + gen_savedir=tmp_path, + load_module=False + ) + for rank in range(2): + # the input projection is multiref'ed + assert _gencode_contains(tmp_path, FnPolicyModuleSharedWeight, rank, r'nnscaler.runtime.function.multiref\(self.input_projection') + + for rank in range(2, 4): + # receive shared weight projection via identity + assert _gencode_contains(tmp_path, FnPolicyModuleSharedWeight, rank, r'nnscaler.runtime.function.identity\(input_projection') + + # Generated code of rank 0 looks like: + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 1 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('input_projection_weight_55', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('input_projection_weight_55', 3, True, 'input_projection.weight', (4, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_gate_proj_weight_189', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_gate_proj_weight_189', 7, True, 'ffn.0.gate_proj.weight', (8, 4), (slice(4, 8, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_up_proj_weight_203', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_up_proj_weight_203', 13, True, 'ffn.0.up_proj.weight', (8, 4), (slice(4, 8, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_down_proj_weight_217', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_down_proj_weight_217', 19, True, 'ffn.0.down_proj.weight', (4, 8), (slice(0, 4, None), slice(4, 8, None)), 1) + # self._post_init(init_params, build_buckets) + + # def segment83(self, x_53): + # # shared param + # input_projection_weight_173, input_projection_weight_174 = nnscaler.runtime.function.multiref(self.input_projection_weight_55, times=2) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 441, in forward, x = self.input_projection(x) + # linear_56 = torch.nn.functional.linear(x_53, input_projection_weight_173, bias=None) + # del x_53, input_projection_weight_173 + # linear_56 = nnscaler.runtime.adapter.nn.identity_allreduce(linear_56, ranks=[0, 1]) + + # def recompute(linear_56): + # # created at IRAdapterGener:local_consumer_multiref + # linear_278, linear_282 = nnscaler.runtime.function.multiref(linear_56, times=2) + # del linear_56 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_191 = torch.nn.functional.linear(linear_278, self.ffn_0_gate_proj_weight_189, bias=None) + # del linear_278 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_199 = torch.tanh(linear_1_191) + # del linear_1_191 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_205 = torch.nn.functional.linear(linear_282, self.ffn_0_up_proj_weight_203, bias=None) + # del linear_282 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_213 = torch.mul(tanh_199, linear_2_205) + # del tanh_199, linear_2_205 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_3_219 = torch.nn.functional.linear(mul_213, self.ffn_0_down_proj_weight_217, bias=None) + # del mul_213 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_23 = self.training + # linear_3_64 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_3_219, ranks=[0, 1]) + # del linear_3_219 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_65 = torch.nn.functional.dropout(linear_3_64, p=0.1, training=ffn_0_dropout_training_23, inplace=False) + # del linear_3_64 + # return dropout_65 + + # dropout_65 = ckpt.checkpoint(recompute, linear_56, use_reentrant=False) + # return dropout_65, input_projection_weight_174 + + # def adapter190(self, input_projection_weight_174): + # input_projection_weight_257 = nnscaler.runtime.adapter.chunk(input_projection_weight_174, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(input_projection_weight_257, shape=(4, 2), dtype=torch.float32, src=1, dst=3) + # return + + # def adapter234(self, dropout_65): + # dropout_265 = nnscaler.runtime.adapter.chunk(dropout_65, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(dropout_265, shape=(4, 2), dtype=torch.float32, src=1, dst=3) + # return + + # def adapter245(self): + # gdropout_267 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=3, dst=1) + # gdropout_92 = nnscaler.runtime.adapter.all_gather(gdropout_267, dim=1, ranks=[0, 1]) + # return gdropout_92 + + # def adapter201(self): + # ginput_projection_weight_263 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=3, dst=1) + # ginput_projection_weight_177 = nnscaler.runtime.adapter.all_gather(ginput_projection_weight_263, dim=1, ranks=[0, 1]) + # return ginput_projection_weight_177 + + # def adapter214(self): + # sum_1_54 = nnscaler.runtime.adapter.move((), shape=(), dtype=torch.float32, src=3, dst=1) + # return sum_1_54 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_76): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # x_53 = next(*(dataloader_76, )) + # dropout_65, input_projection_weight_174 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_53, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_174, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_65, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # dropout_310, input_projection_weight_314 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_302, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_314, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_310, ), requires_grad=False) + # gdropout_92 = nnscaler.runtime.executor.aexecute(model.adapter245, *(), requires_grad=False) + # ginput_projection_weight_177 = nnscaler.runtime.executor.aexecute(model.adapter201, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gx_78 = nnscaler.runtime.executor.backward('segment83', (x_53, ), (dropout_65, input_projection_weight_174, ), (gdropout_92, ginput_projection_weight_177, )) + # del x_53, dropout_65, input_projection_weight_174, gdropout_92, ginput_projection_weight_177, gx_78 + # gdropout_311 = nnscaler.runtime.executor.aexecute(model.adapter245, *(), requires_grad=False) + # ginput_projection_weight_315 = nnscaler.runtime.executor.aexecute(model.adapter201, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gx_303 = nnscaler.runtime.executor.backward('segment83', (x_302, ), (dropout_310, input_projection_weight_314, ), (gdropout_311, ginput_projection_weight_315, )) + # del x_302, dropout_310, input_projection_weight_314, gdropout_311, ginput_projection_weight_315, gx_303 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=True) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=True) + # return sum_1_54, sum_1_349 + + # def _infer_step(model, dataloader_76): + # _ = None + # x_53 = next(*(dataloader_76, )) + # dropout_65, input_projection_weight_174 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_53, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_174, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_65, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # dropout_310, input_projection_weight_314 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_302, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_314, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_310, ), requires_grad=False) + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=False) + # return sum_1_54, sum_1_349 + + # Generated code of rank 2 looks like: + + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 2 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('ffn_1_gate_proj_weight_222', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_gate_proj_weight_222', 26, True, 'ffn.1.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_1_up_proj_weight_236', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_up_proj_weight_236', 32, True, 'ffn.1.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_1_down_proj_weight_250', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_down_proj_weight_250', 38, True, 'ffn.1.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def adapter190(self): + # input_projection_weight_256 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # input_projection_weight_174 = nnscaler.runtime.adapter.all_gather(input_projection_weight_256, dim=1, ranks=[2, 3]) + # return input_projection_weight_174 + + # def adapter234(self): + # dropout_264 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # dropout_65 = nnscaler.runtime.adapter.all_gather(dropout_264, dim=1, ranks=[2, 3]) + # return dropout_65 + + # def segment93(self, dropout_65, input_projection_weight_174): + # input_projection_weight_184 = nnscaler.runtime.function.identity(input_projection_weight_174) + # del input_projection_weight_174 + # dropout_180 = nnscaler.runtime.function.identity(dropout_65) + # del dropout_65 + # dropout_180 = nnscaler.runtime.adapter.nn.identity_allreduce(dropout_180, ranks=[2, 3]) + + # def recompute(dropout_180): + # # created at IRAdapterGener:local_consumer_multiref + # dropout_286, dropout_290 = nnscaler.runtime.function.multiref(dropout_180, times=2) + # del dropout_180 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_4_224 = torch.nn.functional.linear(dropout_286, self.ffn_1_gate_proj_weight_222, bias=None) + # del dropout_286 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_1_232 = torch.tanh(linear_4_224) + # del linear_4_224 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_5_238 = torch.nn.functional.linear(dropout_290, self.ffn_1_up_proj_weight_236, bias=None) + # del dropout_290 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_246 = torch.mul(tanh_1_232, linear_5_238) + # del tanh_1_232, linear_5_238 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_6_252 = torch.nn.functional.linear(mul_1_246, self.ffn_1_down_proj_weight_250, bias=None) + # del mul_1_246 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_1_dropout_training_42 = self.training + # linear_6_73 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_6_252, ranks=[2, 3]) + # del linear_6_252 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_1_74 = torch.nn.functional.dropout(linear_6_73, p=0.1, training=ffn_1_dropout_training_42, inplace=False) + # del linear_6_73 + # return dropout_1_74 + + # dropout_1_74 = ckpt.checkpoint(recompute, dropout_180, use_reentrant=False) + # del dropout_180 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 444, in forward, x = self.output_projection(x) + # linear_7_75 = torch.nn.functional.linear(dropout_1_74, input_projection_weight_184, bias=None) + # del input_projection_weight_184, dropout_1_74 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 445, in forward, return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + # sum_1_54 = torch.sum(linear_7_75) + # del linear_7_75 + # return sum_1_54 + + # def adapter245(self, gdropout_92): + # gdropout_266 = nnscaler.runtime.adapter.chunk(gdropout_92, dim=1, ranks=[2, 3]) + # _ = nnscaler.runtime.adapter.move(gdropout_266, shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # return + + # def adapter201(self, ginput_projection_weight_177): + # ginput_projection_weight_262 = nnscaler.runtime.adapter.chunk(ginput_projection_weight_177, dim=1, ranks=[2, 3]) + # _ = nnscaler.runtime.adapter.move(ginput_projection_weight_262, shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # return + + # def adapter214(self, sum_1_54): + # _ = nnscaler.runtime.adapter.move(sum_1_54, shape=(), dtype=torch.float32, src=2, dst=0) + # return sum_1_54 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_76): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # input_projection_weight_174 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=True) + # dropout_65 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=True) + # sum_1_54 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_65, input_projection_weight_174, ), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gdropout_92, ginput_projection_weight_177 = nnscaler.runtime.executor.backward('segment93', (dropout_65, input_projection_weight_174, ), (sum_1_54, ), (None, )) + # sum_1_54 = sum_1_54.detach() + # input_projection_weight_314 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=True) + # dropout_310 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter245, *(gdropout_92, ), requires_grad=False) + # del dropout_65, gdropout_92 + # _ = nnscaler.runtime.executor.aexecute(model.adapter201, *(ginput_projection_weight_177, ), requires_grad=False) + # del input_projection_weight_174, ginput_projection_weight_177 + # sum_1_349 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_310, input_projection_weight_314, ), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gdropout_311, ginput_projection_weight_315 = nnscaler.runtime.executor.backward('segment93', (dropout_310, input_projection_weight_314, ), (sum_1_349, ), (None, )) + # sum_1_349 = sum_1_349.detach() + # _ = nnscaler.runtime.executor.aexecute(model.adapter245, *(gdropout_311, ), requires_grad=False) + # del dropout_310, gdropout_311 + # _ = nnscaler.runtime.executor.aexecute(model.adapter201, *(ginput_projection_weight_315, ), requires_grad=False) + # del input_projection_weight_314, ginput_projection_weight_315 + # x_302 = next(*(dataloader_76, )) + # del x_302 + # x_53 = next(*(dataloader_76, )) + # del x_53 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_54, ), requires_grad=True) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_349, ), requires_grad=True) + # return sum_1_54, sum_1_349 + + # def _infer_step(model, dataloader_76): + # _ = None + # input_projection_weight_174 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=False) + # dropout_65 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=False) + # sum_1_54 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_65, input_projection_weight_174, ), requires_grad=False) + # input_projection_weight_314 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=False) + # dropout_310 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_310, input_projection_weight_314, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # del x_302 + # x_53 = next(*(dataloader_76, )) + # del x_53 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_54, ), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_349, ), requires_grad=False) + # return sum_1_54, sum_1_349 + + +class FnPolicySharedWeightModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.input_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection.weight = self.input_projection.weight # share weight + + def forward(self, x): + x = self.input_projection(x) + x = self.output_projection(x) + return x + + +def shared_weight_different_partition_policy(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition, get_called_self_module_name + + for node in get_pas_ops(graph): + module_called = get_called_self_module_name(node.call_expr) + + if node.fn == torch.nn.functional.linear and module_called == 'output_projection': + # input_projection.weight is used two times with different partition + # x = self.input_projection(x) --> no partition + # x = self.output_projection(x) --> partition dim=1 + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + + +@replace_all_device_with('cpu') +def test_codegen_fn_shared_weight(tmp_path): + parallelize( + FnPolicySharedWeightModule(), + {'x': torch.randn(4, 4)}, + # 'pp', + shared_weight_different_partition_policy, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicySharedWeightModule, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + # the input projection is multiref'ed + assert _gencode_contains(tmp_path, FnPolicySharedWeightModule, rank, r'nnscaler.runtime.function.multiref\(self.input_projection') + # input_projection.weight will not be splitted + # because it is multiref'ed + assert fullmap['input_projection.weight'].shape == (4, 4) and fullmap['input_projection.weight'].sub_shape == (4, 4) + + # Generated code of rank 0 looks like: + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 2]) + # self.init_group(ranks=[1, 3]) + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('input_projection_weight_15', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('input_projection_weight_15', 3, True, 'input_projection.weight', (4, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.wreducer80 = nnscaler.runtime.adapter.Reducer(ranks=[0, 2], reduce_op='sum', async_op=async_op, zero=False, max_bucket_size_bytes=max_bucket_size_bytes, zero_use_reduce_scatter=zero_use_reduce_scatter, zero_ngroups=1) + # self.wreducer80.add_param(self.input_projection_weight_15) + # self.add_reducer(self.wreducer80) + + # self._post_init(init_params, build_buckets) + + # def segment76(self, x_13): + # # shared param + # input_projection_weight_32, input_projection_weight_33 = nnscaler.runtime.function.multiref(self.input_projection_weight_15, times=2) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 763, in forward, x = self.input_projection(x) + # linear_16 = torch.nn.functional.linear(x_13, input_projection_weight_32, bias=None) + # del x_13, input_projection_weight_32 + # linear_22 = nnscaler.runtime.adapter.nn.split_allgather(linear_16, dim=1, ranks=[0, 1]) + # del linear_16 + # input_projection_weight_37 = nnscaler.runtime.adapter.nn.split_allgather(input_projection_weight_33, dim=1, ranks=[0, 1]) + # del input_projection_weight_33 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 764, in forward, x = self.output_projection(x) + # linear_1_26 = torch.nn.functional.linear(linear_22, input_projection_weight_37, bias=None) + # del linear_22, input_projection_weight_37 + # linear_1_14 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_1_26, ranks=[0, 1]) + # del linear_1_26 + # return linear_1_14 + + +class FnPolicyModuleList2(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + + def forward(self, x): + x = x * 2 + for ffn in self.ffn: + x = ffn(x) + x = x + 3 + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline2(tmp_path): + parallelize( + FnPolicyModuleList2(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + # 4 stages, with pp=2 + 'pipeline_size': 2, + 'pipeline_scheduler': '1f1b_interleaved', + } + ), + gen_savedir=tmp_path, + load_module=False + ) + # should successfully generate code without error + assert True From f6e5d60ef9c67a58dea471675096d89cdc276aa0 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 29 Oct 2025 07:55:49 +0000 Subject: [PATCH 1848/1892] Merged PR 2415: [Refine] Refine dynamic shape support 1. Mark some dims of some inputs as dynamic (via `nnscaler.mark_dynamic, cli supported is also added). 2. Use dim tracking to track the flow of dims. 3. Propagate the dim dynamicness in the compute graph --- nnscaler/__init__.py | 4 + nnscaler/cli/mixed_module.py | 23 +-- nnscaler/cli/trainer.py | 4 +- nnscaler/cli/trainer_args.py | 44 +++--- nnscaler/graph/function/function.py | 82 ++++++++-- nnscaler/graph/graph.py | 1 + nnscaler/graph/parser/parser.py | 49 +++++- nnscaler/graph/parser/value_tracker.py | 182 +++++++++++++++++++---- nnscaler/graph/tracer/metadata.py | 6 +- nnscaler/ir/cten.py | 109 ++++++++++++-- nnscaler/parallel.py | 3 +- nnscaler/policies.py | 2 + nnscaler/utils.py | 34 +++++ tests/cli/common.py | 81 ++++++++++ tests/cli/test_trainer.py | 100 +++++++++++++ tests/cli/trainer_args_csa.yaml | 53 +++++++ tests/graph/function/test_functions.py | 36 +++++ tests/graph/parser/test_value_tracker.py | 93 +++++++++++- tests/ir/test_cten.py | 4 +- tests/parallel_module/test_gencode.py | 148 +++++++++++++++++- 20 files changed, 963 insertions(+), 95 deletions(-) create mode 100644 tests/cli/trainer_args_csa.yaml diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index c9265af8..b3a18165 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -28,6 +28,10 @@ ) from nnscaler.runtime.f16_optimizer import MixedPrecisionAdam, MixedPrecisionAdamW from nnscaler.runtime.hybrid_optimizer import HybridLRScheduler, HybridOptimizer +from nnscaler.utils import ( + mark_dynamic, + get_dynamic, +) def init(): diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index 69487fb2..eb5bb7ab 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -166,17 +166,22 @@ def create_model(self, module_args: Optional[tuple[tuple, dict]]=None) -> torch. def create_dummy_forward_args(self, dummy_input) -> dict[str, Any]: if self.parallel_module: - return self.fix_input( + forward_args = self.fix_input( self.parallel_module.create_dummy_forward_args(self.trainer_args) ) - - # forward args of whole model - arg_names = list( - inspect.signature( - inspect.unwrap(getattr(self.model_type, 'forward')) - ).parameters.keys() - ) - return {arg_names[1]: self.fix_input(dummy_input)} # arg_names[0] is self + if self.parallel_module.forward_args_post_process_fn: + forward_args = self.parallel_module.forward_args_post_process_fn(self.trainer_args, forward_args) + return forward_args + else: + # forward args of whole model + arg_names = list( + inspect.signature( + inspect.unwrap(getattr(self.model_type, 'forward')) + ).parameters.keys() + ) + # dummy input is already fixed and post processed by trainer + forward_args = {arg_names[1]: dummy_input} # arg_names[0] is self + return forward_args def resolve_compute_config(self): compute_config = copy.deepcopy(self.compute_config) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 41b1132f..ae65da81 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -128,7 +128,7 @@ def _fix_input(self, input): return fix_input(input, self.train_args.input_dtype) def _load_dummy_input(self): - if dummy_sample_gen_fn := self.train_args.resolved_dummy_sample_gen_fn: + if dummy_sample_gen_fn := self.train_args.dummy_sample_gen_fn: return dummy_sample_gen_fn(self.train_args) with enforce_zero_num_worker(DataLoader): @@ -159,6 +159,8 @@ def _setup(self): # load a dummy input from training dataset self.dummy_input = self._load_dummy_input() self.dummy_input = self._fix_input(self.dummy_input) + if self.train_args.dummy_sample_post_process_fn: + self.dummy_input = self.train_args.dummy_sample_post_process_fn(self.train_args, self.dummy_input) pmodel = parallelize_model( self.train_args, self.dummy_input, diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 2bf23bf2..d4bd6974 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -20,7 +20,7 @@ import torch import nnscaler -from nnscaler.utils import fields, fn_field, transform_recursively, load_type +from nnscaler.utils import fields, fn_field, transform_recursively, load_type, copy_dynamic from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule @@ -123,9 +123,9 @@ def fix_input(input, input_dtype=None): return tuple(fix_input(v, input_dtype) for v in input) elif isinstance(input, torch.Tensor): if input.is_floating_point() and input_dtype is not None: - return input.to(input_dtype).cuda() + return copy_dynamic(input, input.to(input_dtype).cuda()) else: - return input.cuda() + return copy_dynamic(input, input.cuda()) return input @@ -238,9 +238,17 @@ class ModuleParallelizeConfig: # we can parallelize submodules instead of creating whole model. # This is useful sometimes. args: Optional[Dict[str, Any]] = None - # the full qualified name of the function to generate dummy forward args - # Its type should be `Callable[[TrainerArgs],Dict[str, Any]]` - forward_args_gen_fn: str = None + # the full qualified name of the function to generate dummy inputs for forward + # Its type should be `Callable[[TrainerArgs], dict[str, Any]]` + # where the output dict is the kwargs for forward function of the module + # The tensors in the sample will be moved to GPU and converted to input_dtype by trainer. + forward_args_gen_fn: Optional[Callable[['TrainerArgs'], dict[str, Any]]] = fn_field(default=None) + # the full qualified name of the function to post process the dummy inputs for forward + # Note the tensors in the inputs have been moved to GPU and converted to input_dtype + # But you can still further process the sample, + # for example, mark some dims of tensors as dynamic + # (you can do it in `forward_args_gen_fn` as well) + forward_args_post_process_fn: Optional[Callable[['TrainerArgs', dict[str, Any]], dict[str, Any]]] = fn_field(default=None) # the model state dict file for tracing. # It is only used in tracing to serve as the initial state dict of the model. tracing_from_weights: str = None @@ -289,8 +297,7 @@ def create_model(self, trainer_args: 'TrainerArgs', module_args: Optional[tuple[ return self.model_type(*args, **kwargs) def create_dummy_forward_args(self, trainer_args: 'TrainerArgs') -> dict[str, Any]: - forward_args_gen_fn = load_type(self.forward_args_gen_fn) - return forward_args_gen_fn(trainer_args) + return self.forward_args_gen_fn(trainer_args) @dataclass @@ -605,9 +612,16 @@ class TrainerArgs(PrecisionMixin, PolicyMixin): # compile: compile the model but not training # run: compile and run the model run_mode: str = 'run' - # the full qualified name of the function to generate dummy sample for forward + # the full qualified name of the function to generate dummy sample # Its type should be `Callable[[TrainerArgs], Any]` - dummy_sample_gen_fn: str = None + # The tensors in the sample will be moved to GPU and converted to input_dtype by trainer. + dummy_sample_gen_fn: Optional[Callable[['TrainerArgs'], Any]] = fn_field(default=None) + # the full qualified name of the function to post process the dummy sample + # Note the tensors in the sample have been moved to GPU and converted to input_dtype + # But you can still further process the sample, + # for example, you can use this function to mark some dims of tensors as dynamic + # when you don't use `dummy_sample_gen_fn` or don't handle dynamic dims in it, + dummy_sample_post_process_fn: Optional[Callable[['TrainerArgs', Any], Any]] = fn_field(default=None) # the model state dict file for tracing. # It is only used in tracing to serve as the initial state dict of the model. tracing_from_weights: str = None @@ -821,12 +835,6 @@ def resolved_aggregate_outputs_fn(self): return None return load_type(self.optimizer.aggregate_outputs_fn) - @property - def resolved_dummy_sample_gen_fn(self): - if not self.dummy_sample_gen_fn: - return None - return load_type(self.dummy_sample_gen_fn) - @property def scaling_factor(self): return self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus @@ -856,7 +864,7 @@ def init_env(self, trainer: 'Trainer'): init_env_fn = load_type(self.init_env_fn) init_env_fn(trainer) - def get_resolved_var(self, fqn: str) -> Any: + def get_resolved_var(self, fqn: str, *, default: Any = None) -> Any: """ Get a resolved variable from the vars dictionary. The fqn is a full qualified name of the variable, e.g. 'x.y.z'. @@ -865,7 +873,7 @@ def get_resolved_var(self, fqn: str) -> Any: var = self._vars for part in parts: if part not in var: - raise ValueError(f"Variable {fqn} not found in vars") + return default var = var[part] return var diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 66cfff2a..54bd3fef 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -377,13 +377,38 @@ def creation_function_size_check(op_name, size, *arg_size) -> Tuple[Union[int, I raise ValueError(f"get illegal input size={size}, arg_size={arg_size} in {op_name}") # convert scalar to shape (1,) tensor, nnscaler don't support empty shape [] now. if len(size_val) == 0: - _logger.warn(f"detect tensor creation function {op_name} create a scalar, force it to create a shape [1] tensor instead") + _logger.warning(f"detect tensor creation function {op_name} create a scalar, force it to create a shape [1] tensor instead") size = (1,) else: raise ValueError(f"get unknown input type size={size} in {op_name}") return size +def creation_function_dim_track(resolved_size: Union[IRObject, tuple[Union[int, IRObject]]]) -> list[ValueTrack]: + if isinstance(resolved_size, IRObject): + assert isinstance(resolved_size.value, (tuple, list)) + # all dims dependent on resolved_size + return [ValueTrack.new([resolved_size]) for _ in resolved_size.value] + + dim_tracks = [] + for dim in resolved_size: + if isinstance(dim, IRObject): + dim_tracks.append(ValueTrack.new([dim])) + else: + # no dim dependency when dim is not IRObject + dim_tracks.append(ValueTrack.new([])) + return dim_tracks + + +def creation_function_set_dim_tracks(op: IRDimops, resolved_size: Union[IRObject, tuple[Union[int, IRObject]]]) -> IRDimops: + # Output will be replaced in Parser, + # Here we just pass the value tracks out + output = IRFullTensor(_unwrap_value(resolved_size)) + output.dim_tracks = creation_function_dim_track(resolved_size) + op.set_output(0, output) + return op + + def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): """ @@ -399,7 +424,10 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs), + size + ) def Zeros(size, *arg_size, out=None, dtype=None, layout=None, @@ -415,7 +443,10 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, size = creation_function_size_check('torch.zeros', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs), + size + ) def Ones(size, *arg_size, out=None, dtype=None, layout=None, @@ -431,7 +462,10 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, size = creation_function_size_check('torch.ones', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs), + size + ) def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, @@ -449,7 +483,10 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs), + size + ) def Randn(size, *arg_size, generator=None, out=None, dtype=None, layout=None, device=None, requires_grad=False, @@ -467,7 +504,10 @@ def Randn(size, *arg_size, generator=None, out=None, dtype=None, layout=None, de kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Randn, 'randn', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Randn, 'randn', signature, [anno], [], rules, **kwargs), + size + ) def Full(size, fill_value, *, out=None, dtype=None, layout=None, @@ -482,8 +522,11 @@ def Full(size, fill_value, *, out=None, dtype=None, layout=None, signature = 'nnscaler.runtime.function.full' size = creation_function_size_check('torch.full', size) anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Full, 'full', signature, [anno], [], rules, - size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad) + return creation_function_set_dim_tracks( + IRDimops(Full, 'full', signature, [anno], [], rules, + size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad), + size + ) def NewTensor(data, *, dtype=None, device=None, @@ -1863,7 +1906,7 @@ def Stack(tensors, dim=0, out=None, signature = None): return CubeStack(*tensors, dim=dim, signature=signature) -def Chunk(input, chunks, dim=0, signature = None): +def Chunk(input: IRTensor, chunks, dim=0, signature = None): """ torch.chunk(input, chunks, dim=0) """ @@ -1874,7 +1917,18 @@ def Chunk(input, chunks, dim=0, signature = None): for oanno in oannos: oanno[dim] = str(input.shape[dim] // chunks) anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(Chunk, 'chunk', signature, [anno], [input], chunks=chunks, dim=dim) + ret = IRDimops(Chunk, 'chunk', signature, [anno], [input], chunks=chunks, dim=dim) + + # set proper value tracks for outputs + output_shape = list(input.shape) + output_shape[dim] = input.shape[dim] // chunks + dim_vt = ValueTrack.new([chunks, input.dim_tracks[dim]]) + for d in range(chunks): + output = IRFullTensor(output_shape) + output.set_dim_track(dim, dim_vt) + ret.set_output(d, output) + + return ret def Select(input, dim, index, signature = None): @@ -3423,10 +3477,14 @@ def Item(input, signature = None): """ torch.Tensor.item() """ - # set output to IRObject.missing, + # set output value to IRObject.missing_value, # because the output is unknown here. # It will be filled with real value in parser. - return IRPyFunc(signature, inputs=[input], outputs=[IRObject.missing], constant_foldable=False) + return IRPyFunc( + signature, inputs=[input], + outputs=[IRObject('item', value=IRObject.missing_value, is_constant=False)], + constant_foldable=False + ) def DictKeys(o: Union[Dict, IRObject], signature=None): diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index dce640e3..5f412f25 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -287,6 +287,7 @@ def use_dataloader_input(self): # IRDataOperation. Since we already know the output of the dataloader, # we don't need to set the value for it. ir_root_obj = IRObject(name='dataloader', value=None, is_constant=False) + ir_root_obj.value_track.with_no_dep() data_op = IRDataOperation(ir_root_obj, self.inputs()) # add the data operation to the graph, which will use `next` to get data. self.insert(data_op, 0) diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 8e0692af..4fa263b9 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -130,7 +130,7 @@ def parse(self) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: for idx, placeholder in enumerate(placeholders): if not isinstance(inputs[idx], IRObject): obj = IRObject(name=placeholder.target, value=inputs[idx], is_constant=False) - obj.value_track.with_no_dep() + obj.value_track.mark_as_input() inputs[idx] = obj self.value_tracker.track_values([obj]) self.frame.set_var(placeholder.name, obj) @@ -141,7 +141,7 @@ def parse(self) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: ir_nodes = self._parse_node(node) all_ir_nodes += ir_nodes - self.value_tracker.track_nodes(all_ir_nodes) + self.value_tracker.complete_tracking(all_ir_nodes) # get graph outputs outputs = [self.frame.get_var(node.name) for node in self.module.graph.nodes if node.op == 'output'] @@ -189,15 +189,15 @@ def _init_objects(self, node: torch.fx.Node, is_constant: bool = True): ) if node.op == 'placeholder': - def set_no_dep(x: IRObject): + def mark_as_input(x: IRObject): if isinstance(x, IRTensor): # let's the value_track of tensor stay None(unknown) # because we don't care about it. for dt in x.dim_tracks: dt.with_no_dep() else: - x.value_track.with_no_dep() - IR.modify_objects(val, set_no_dep) + x.value_track.mark_as_input() + IR.modify_objects(val, mark_as_input) self.frame.add_var(node.name, val) @@ -374,8 +374,33 @@ def _parse_prim_function_method(self, node: torch.fx.Node) -> List[IRFwOperation # We need to copy dim tracks # As we will use frame version as node output, instead of the placeholder created in function.py for dim in range(len(vals[i].shape)): - vals[i].dim_tracks[dim].merge_deps(ir_node.output(i).dim_tracks[dim]) + vals[i].dim_tracks[dim].merge(ir_node.output(i).dim_tracks[dim]) ir_node.set_output(i, vals[i]) + elif isinstance(ir_node.output(i), IRObject) and ir_node.output(i).is_value_missing(): + # output is IRObject with missing value + # we need to set it with the value from frame + assert not IR.contains_object(vals[i], lambda x: isinstance(x, IRTensor)), \ + f'Output {i} of node {node} is expected to be IRObject, but got tensor: {vals[i]}' + ir_node.output(i).value = IR.try_unwrap(vals[i]) + else: + # Currently we don't support missing-value IRObject in tuple/list/dict/... + # TODO: add support when needed + assert not IR.contains_object(ir_node.output(i), lambda x: not isinstance(x, IRTensor) and x.is_value_missing()), \ + f'Output {i} of node {node} contains missing value: {ir_node.output(i)}' + + # per-op value tracking via its annotation + # TODO: + # This may be not accurate because many ops in function.py are not properly annotated their value deps + # Two ways to improve it: + # 1. add value deps annotation for those ops in function.py + # 2. use global data flow analysis to track value deps + # a. add all nodes without folding + # b. use value_tracker.track_nodes to analyze value deps for all nodes + # c. remove nodes that can be folded. + # It is not easy because some op logic in function.py works differently + # when its inputs are constant or not. + # For now, we just use per-op value tracking for simplicity. + self.value_tracker.track_nodes([ir_node]) # update frame with ir output # Please note when there is only one output, we will unwrap it from `ir_node.outputs()` here @@ -423,7 +448,17 @@ def _is_primitive_type(val): and not ir_node.signature.startswith(nnscaler.runtime.function.__name__ + '.')\ and not IR.contains_object(ir_node.output(0), lambda x: isinstance(x, IRTensor) or not x.is_constant) \ and _is_primitive_type(cval := IR.try_unwrap(ir_node.output(0))): + # TODO: + # This will break the value tracking graph + # for example, if not folded: + # value1 -> op1 -> value2 -> op2 -> value3 -> op3 + # if op2 is folded, then op3 will not know the value1 dependency + # So the value tracking becomes: + # value1 -> op1 value3 -> op3 + # In many cases, op1 and op3 can be connected by other ops, + # But when this becomes a problem, we need to fix it by using global data flow analysis. self.frame.set_var(node.name, cval) + self.value_tracker.untrack_node(ir_node) return [] else: return [ir_node] @@ -460,6 +495,7 @@ def _parse_prim_get_attr_node(self, node: torch.fx.Node) -> List[IRFwOperation]: # Parameters and buffers have no dependency on other values for dt in tensor.dim_tracks: + dt.is_constant = True dt.with_no_dep() self.frame.add_attr(tensor, concrete_value, node.target) @@ -483,6 +519,7 @@ def _parse_prim_get_attr_node(self, node: torch.fx.Node) -> List[IRFwOperation]: else: self.frame.set_var(node.name, concrete_value) + self.value_tracker.track_nodes(ir_nodes) return ir_nodes def _parse_prim_output_node(self, node: torch.fx.Node) -> List[IRCell]: diff --git a/nnscaler/graph/parser/value_tracker.py b/nnscaler/graph/parser/value_tracker.py index 6419ca15..45a3cf0f 100644 --- a/nnscaler/graph/parser/value_tracker.py +++ b/nnscaler/graph/parser/value_tracker.py @@ -9,49 +9,101 @@ class ValueTracker: + """ + Example: + >>> vt = ValueTracker() + >>> vt.track_value(input1) + >>> vt.track_value(input2) + >>> ... + >>> vt.track_nodes([node1]) + >>> vt.track_nodes([node2]) + >>> vt.untrack_node(node2) # when node2 is folded + >>> vt.track_nodes([node3]) + >>> ... + >>> vt.complete_tracking([node1, node3, ...]) # pass all tracked nodes here + """ def __init__(self): # value_id -> ValueTrack # Please note some ValueTracks may be merged together (from annotation) # So the key can be different from the id of the ValueTrack self._vtm: dict[int, ValueTrack] = {} - self._equiv_value_ids: dict[int, set] = {} + self._equiv_value_ids: dict[int, set[int]] = {} + # store removed value ids + # used to delay the removal of value tracks in deps + self._removed_value_ids: set[int] = set() - def track_values(self, objs: list[Any]): + def _add_track_value(self, value: ValueTrack): + if value.value_id not in self._vtm: + # always use the updated value track in self._vtm + self._vtm[value.value_id] = value + + if value.value_id not in self._equiv_value_ids: + self._equiv_value_ids[value.value_id] = {value.value_id} + + def track_values(self, objs: list[Any]) -> set[int]: + """ + Track the value tracks of the given objects. + Args: + objs (list[Any]): the objects to be tracked + Returns: + set[int]: the set of value ids tracked + """ + value_ids = set() for obj in objs: - self.track_value(obj) + value_ids.update(self._track_value(obj)) + return value_ids - def track_value(self, obj: Any): - for item in IR.get_objects(obj): - if isinstance(item, IRTensor): - for dt in item.dim_tracks: - self._vtm[dt.value_id] = dt - elif isinstance(item, IRObject): - self._vtm[item.value_track.value_id] = item.value_track + def _track_value(self, value: Any): + for obj in IR.get_objects(value): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + self._add_track_value(dt) + yield dt.value_id + else: + assert isinstance(obj, IRObject) + self._add_track_value(obj.value_track) + yield obj.value_track.value_id - def _update_track_value(self, obj: Any): + def _update_track_value(self, obj: IRObject): if isinstance(obj, IRTensor): new_dim_tracks = [] for dt in obj.dim_tracks: new_dim_tracks.append(self._vtm[dt.value_id]) obj.dim_tracks = new_dim_tracks - elif isinstance(obj, IRObject): + else: + assert isinstance(obj, IRObject) obj.value_track = self._vtm[obj.value_track.value_id] + def _update_constness(self, obj: IRObject): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + dt.is_constant = dt.is_constant and all(self._vtm[dep].is_constant for dep in dt.deps or []) + else: + assert isinstance(obj, IRObject) + obj.value_track.is_constant = obj.value_track.is_constant and all(self._vtm[dep].is_constant for dep in obj.value_track.deps or []) + def track_nodes(self, nodes: list[IRFwOperation]): """ Track the value tracks of the input and output objects in the given nodes. Here we assume the nodes are topologically sorted. + + Please note we only update the tracks of nodes in arguments. + For nodes not in arguments, their tracks are not updated. + + Args: + nodes (list[IRFwOperation]): the nodes to be tracked """ # collect all value tracks from nodes + if not nodes: + return + + # collect all involved value ids from nodes + node_value_ids = set() for node in nodes: for obj in node.iobjs(): - self.track_value(obj) + node_value_ids.update(self._track_value(obj)) for obj in node.oobjs(): - self.track_value(obj) - - # init equivalence classes - for vt in self._vtm.values(): - self._equiv_value_ids[vt.value_id] = {vt.value_id} + node_value_ids.update(self._track_value(obj)) # collect extra value tracks from dimops for node in nodes: @@ -59,31 +111,104 @@ def track_nodes(self, nodes: list[IRFwOperation]): self._track_dims(node) # merge equivalent value tracks together - for value_id, equiv_ids in self._equiv_value_ids.items(): + done_value_ids = set() + for value_id in node_value_ids: + equiv_ids = self._equiv_value_ids[value_id] + min_value_id = min(equiv_ids) - if value_id != min_value_id: + if min_value_id in done_value_ids: continue + done_value_ids.add(min_value_id) # use the smallest id as the representative rep_one = self._vtm[min_value_id] for vid in equiv_ids: - if vid == min_value_id: + if vid == min_value_id or self._vtm[vid] is rep_one: continue # TODO: how we merge dependencies? # current we take union (Union may be too strict) if rep_one.deps is None: rep_one.deps = self._vtm[vid].deps elif self._vtm[vid].deps is not None: - rep_one.deps = list(set(rep_one.deps).union(set(self._vtm[vid].deps))) + # deps can still have duplicates here + # because merging of the rest value tracks haven't been done yet + # NOTE: + # 1. this duplication is temporary, + # Duplicated value ids will be removed when we touch the same value track again + # in future track_nodes call. + # 2. duplication is not harmful for correctness + rep_one.deps = list( + set(rep_one.deps) + .union(self._vtm[vid].deps) + .difference(self._removed_value_ids) + ) self._vtm[vid] = rep_one - # dedup dependencies - # Here we will replace dependencies with their representative value tracks + self._propagate_tracks(nodes) + + def untrack_node(self, node: IRFwOperation): + """ + Untrack the value tracks of output objects in the given node. + This function is used when we fold a node from the graph. + + Args: + node (IRFwOperation): the node to be untracked + """ + input_value_ids = set() + for obj in node.iobjs(): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + input_value_ids.add(dt.value_id) + else: + assert isinstance(obj, IRObject) + input_value_ids.add(obj.value_track.value_id) + + for obj in node.oobjs(): + # we can only remove value tracks that are not used by inputs + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + if dt.value_id not in input_value_ids: + self._removed_value_ids.add(dt.value_id) + else: + assert isinstance(obj, IRObject) + if obj.value_track.value_id not in input_value_ids: + self._removed_value_ids.add(obj.value_track.value_id) + + def complete_tracking(self, nodes: list[IRFwOperation]): + """ + Complete the tracking process. + Should be called after all nodes are tracked. + """ + # remove all removed value ids for vtm + # note we don't remove them from equivalence classes + for removed_id in self._removed_value_ids: + if self._vtm[removed_id].value_id == removed_id \ + and (new_equiv_cls := self._equiv_value_ids[removed_id].difference(self._removed_value_ids)): + # change the representative value id of this equivalence class + # NOTE: + # In current usage, code should not reach here. + # As we remove value tracks only for constant irobjects, + # and all equivalent value tracks should be removed together. + self._vtm[removed_id].value_id = min(new_equiv_cls) + self._vtm.pop(removed_id, None) + + # replace dependencies with their representative value tracks # which can introduce some duplicates + # So we use `set` to further dedup dependencies for vt in self._vtm.values(): if vt.deps is not None: - vt.deps = list(set(self._vtm[d].value_id for d in vt.deps)) + vt.deps = list(set( + self._vtm[d].value_id for d in vt.deps + if d not in self._removed_value_ids + )) + self._propagate_tracks(nodes) + + def _propagate_tracks(self, nodes: list[IRFwOperation]): + """ + Update value tracks and constantness information of the input and output objects + in the given nodes. + """ # propagate the merged value tracks back to nodes for node in nodes: for obj in node.iobjs(): @@ -91,6 +216,13 @@ def track_nodes(self, nodes: list[IRFwOperation]): for obj in node.oobjs(): self._update_track_value(obj) + # propagate the constantness information back to nodes + for node in nodes: + for obj in node.iobjs(): + self._update_constness(obj) + for obj in node.oobjs(): + self._update_constness(obj) + def _track_dims(self, node: IRDimops): """ Track the dimension values of output tensors according to input tensors. diff --git a/nnscaler/graph/tracer/metadata.py b/nnscaler/graph/tracer/metadata.py index f6f898a2..75f4de9a 100644 --- a/nnscaler/graph/tracer/metadata.py +++ b/nnscaler/graph/tracer/metadata.py @@ -8,6 +8,7 @@ from torch.fx.node import Node from . import pytree_utils +from nnscaler.utils import get_dynamic DICT_KEYS_TYPE = type({}.keys()) DICT_VALUES_TYPE= type({}.values()) @@ -95,6 +96,9 @@ class TensorMetadata(NamedTuple): is_quantized : bool qparams: Dict[str, Any] + # all dynamic dimensions in shape + dynamic_dims: set[int] + def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ @@ -134,7 +138,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] - return TensorMetadata(shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + return TensorMetadata(shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams, get_dynamic(result)) def extract_metadata(results: Any, node: Node): diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 3c720743..8d843011 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -27,7 +27,7 @@ from nnscaler.ir.unique import IDGenerator from nnscaler.ir.dtype import DTypeInfo -from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE, load_type +from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE, load_type, get_dynamic NestedVarOrStatic = Union[Any, 'IRObject', List['IRObject'], 'IRTensor'] @@ -544,6 +544,10 @@ class ValueTrack: dim 2: ValueTrack(value_id=50, dependencies=[30]) # it depends on input dim 2: m """ value_id: int = field(default_factory=IDGenerator().gen_value_id) + # By default, we consider the value is constant + # unless it is set to not constant + # via mark_dynamic or it is from input or explicitly set in function.py + is_constant: bool = True # None: unknown dependencies # []: no dependencies deps: Optional[list[int]] = None @@ -552,13 +556,12 @@ def with_no_dep(self) -> 'ValueTrack': """ Initialize this ValueTrack with no dependencies. """ - self.with_dep(None) + self.deps = [] return self - def with_dep(self, dep: Union[None, 'ValueTrack', 'IRObject'] = None) -> 'ValueTrack': + def add_dep(self, dep: Union[Any, 'ValueTrack', 'IRObject']) -> 'ValueTrack': """ Initialize or add a dependency to the ValueTrack. - If dep is None, just initialize an empty dependency list, which means no dependencies. If dep is not IRObject or ValueTrack, do nothing. """ if self.deps is None: @@ -570,41 +573,76 @@ def with_dep(self, dep: Union[None, 'ValueTrack', 'IRObject'] = None) -> 'ValueT if isinstance(dep, IRTensor): raise TypeError("Cannot directly add IRTensor as dependency.") - dep = dep.value_track if isinstance(dep, IRObject) else dep + dep: ValueTrack = dep.value_track if isinstance(dep, IRObject) else dep dep_value_id = dep.value_id if dep_value_id not in self.deps: self.deps.append(dep_value_id) + self.is_constant = self.is_constant and dep.is_constant return self - def merge_deps(self, other: ValueTrack) -> 'ValueTrack': + def merge(self, other: ValueTrack) -> 'ValueTrack': + """ + Merge another ValueTrack into this one. + The merged ValueTrack will have dependencies from both ValueTracks. + """ if self.deps is None: self.deps = other.deps else: self.deps.extend(other.deps or []) + + if self.deps is not None: self.deps = list(set(self.deps)) + self.is_constant = self.is_constant and other.is_constant + return self + @classmethod - def new(cls, deps: Iterable[Union[Any, 'ValueTrack', 'IRObject']]) -> 'ValueTrack': + def new(cls, deps: Iterable[Union[Any, 'ValueTrack', 'IRObject']], is_constant: Optional[bool] = None) -> 'ValueTrack': vt = cls() + if is_constant is not None: + vt.is_constant = is_constant + vt.deps = [] for dep in deps: - vt.with_dep(dep) + vt.add_dep(dep) return vt + def mark_as_input(self) -> 'ValueTrack': + """ + Mark this ValueTrack as graph input, which should be not constant and have no dependencies. + """ + self.is_constant = False + self.deps = [] + return self + + +_missing_value = object() class IRObject: """ IRObject serves as general data of IRGraph edge + + There are two special IRObject for lazy evaluation: + 1. IRObject.missing: a singleton object to represent missing object + It is used to tell parser that we don't know the real object yet. + The parser is supposed to create a new IRObject to replace it. + For example, all custom ops will have missing outputs.It relies on parser to set them. + 2. IRObject(..., value=missing_value, ...): an object with unknown value + It is used to tell parser that we don't know the real value yet. + The parser is supposed to set the value. + We have this because we want ops to pass out `value_track` even when the value is unknown. + For example, `Item()` op in `function.py` will create such object. """ # will be set after class definition missing: ClassVar['IRObject'] = None + missing_value: ClassVar[object] = _missing_value def __init__( self, name: Optional[str] = None, tid: Optional[int] = None, - value: Optional[None] = None, - is_constant: bool = True, + value: Any = _missing_value, + is_constant: Optional[bool] = None, *, value_track: Optional[ValueTrack] = None, ) -> None: @@ -620,14 +658,19 @@ def __init__( 2. val is model input, or is the result of a non-torch operation on another not constant IRObject Please note is_constant flag is only used in parser, so after parser, you can totally ignore this flag. + We keep this flag in IRObject for backward compatibility. + If both is_constant and value_track are provided, + `value_track.is_constant` will be overrided by this flag. + value_track (ValueTrack): the value track info of this object """ self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() self.name: str = name if name else 'obj' self._cell: Optional[IRCell] = None self._is_attr: bool = False - self._value: Optional[Any] = value - self._is_constant: bool = is_constant + self._value: Any = value self._value_track: ValueTrack = value_track or ValueTrack() + if is_constant is not None: + self._value_track.is_constant = is_constant def __hash__(self) -> int: return self._id @@ -680,6 +723,14 @@ def value(self) -> Any: """Get example value""" return self._value + @value.setter + def value(self, val: Any): + self._value = val + + def is_value_missing(self) -> bool: + """Check if the value is missing""" + return self._value is IRObject.missing_value + @property def value_track(self) -> ValueTrack: """Get value track info""" @@ -691,11 +742,11 @@ def value_track(self, val: ValueTrack): @property def is_constant(self) -> bool: - return self._is_constant + return self._value_track.is_constant @is_constant.setter def is_constant(self, val: bool): - self._is_constant = val + self._value_track.is_constant = val def __eq__(self, obj) -> bool: if not isinstance(obj, IRObject): @@ -706,7 +757,7 @@ def __copy__(self): """Copy this object but remove the cell information""" if self is IRObject.missing: # missing object is singleton return IRObject.missing - return IRObject(self.name, self._id, self._value, self._is_constant, value_track=self._value_track) + return IRObject(self.name, self._id, self._value, self.is_constant, value_track=self._value_track) def as_attr(self): """ @@ -821,6 +872,10 @@ def _inner(obj) -> Tuple[Any, bool]: dtype=obj.dtype, requires_grad=rg, ) + + for dyn_idx in get_dynamic(obj): + tensor.dim_tracks[dyn_idx].is_constant = False + if tosub: tensor = tensor.tosub() tensor._value = obj # is required in SemanticModel.forward @@ -1221,6 +1276,30 @@ def dim_tracks(self, val: Tuple[Optional[ValueTrack], ...]): # None means starting a new dim track self._dim_tracks = tuple(v if v is not None else ValueTrack() for v in val) + def set_dim_track(self, dim: int, track: ValueTrack): + """ + Set the track of a specific dimension + """ + if dim < 0 or dim >= len(self._shape): + raise IndexError("dim out of range") + dim_tracks = list(self._dim_tracks) + dim_tracks[dim] = track + self._dim_tracks = tuple(dim_tracks) + + def dim_constant(self, dim: int) -> bool: + """ + Check if a dim is constant + """ + if dim < 0 or dim >= len(self._shape): + raise IndexError("dim out of range") + return self._dim_tracks[dim].is_constant + + def dims_constant(self) -> bool: + """ + Check if all dims are constant + """ + return all(track.is_constant for track in self._dim_tracks) + def nelement(self) -> int: """ Get total number of element in the tensor. diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 7dc60e4c..6001f9e1 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -56,6 +56,7 @@ setup_stride_broadcast_group, get_shared_params, OptStateDict, + copy_dynamic ) logger = logging.getLogger(__name__) @@ -357,7 +358,7 @@ def _to_cpu(val: Any): return {_to_cpu(t) for t in val} if isinstance(val, torch.Tensor): requires_grad = val.is_floating_point() or val.is_complex() - return val.detach().clone().cpu().requires_grad_(requires_grad) + return copy_dynamic(val, val.detach().clone().cpu().requires_grad_(requires_grad)) return val diff --git a/nnscaler/policies.py b/nnscaler/policies.py index cd3ed94a..ba05293c 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -121,6 +121,8 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): random.shuffle(configs) for (idx, dim) in configs: if node.input(idx).shape[dim] % len(devs) != 0: continue + # only partition when all input tensors are constant on this dim + if not node.input(idx).dim_tracks[dim].is_constant: continue if node.algorithm('dim').satisfy(idx=idx, dim=dim, num=len(devs)): _tp(graph, node, devs, idx, dim) break diff --git a/nnscaler/utils.py b/nnscaler/utils.py index f7bd8954..b20ab317 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -520,3 +520,37 @@ def fn_field(**kwargs): metadata = kwargs.pop('metadata', {}) metadata['deserialize'] = lambda t: None if t is None else load_type(t) return field(**kwargs, metadata=metadata) + + +TENSOR_DYNAMIC_DIMS_FIELD_NAME = '_nnscaler_dynamic_dims' +# for nnscaler custom class (TensorMetadata) +NNSCALER_DYNAMIC_DIMS_NAME = 'dynamic_dims' + + +def mark_dynamic(tensor: torch.Tensor, dims: int | list[int] | tuple[int]) -> torch.Tensor: + """ + Mark the dim of a tensor as dynamic, which means it can be changed in the future. + This is the same with `torch._dynamo.mark_dynamic` + """ + setattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, set(dims) if dims else set()) + return tensor + + +def copy_dynamic(src: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor: + """ + Copy the dynamic dims from src to tensor, and return the tensor. + """ + if hasattr(src, TENSOR_DYNAMIC_DIMS_FIELD_NAME): + setattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, getattr(src, TENSOR_DYNAMIC_DIMS_FIELD_NAME)) + return tensor + + +def get_dynamic(tensor: Any) -> set[int]: + """ + Get the dynamic dims of a tensor. + It also works when tensor is not an instance of torch.Tensor + """ + if isinstance(tensor, torch.Tensor): + return getattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, set()) + else: + return getattr(tensor, NNSCALER_DYNAMIC_DIMS_NAME, set()) diff --git a/tests/cli/common.py b/tests/cli/common.py index f5be1c9d..02f8b197 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -1,6 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# CausalSelfAttention is copied from https://github.com/karpathy/nanoGPT/blob/master/model.py +# with minor modifications. +# See the original license in the file https://github.com/karpathy/nanoGPT/blob/master/LICENSE + from pathlib import Path import torch from torch import nn @@ -9,11 +13,88 @@ from streaming import MDSWriter, StreamingDataset, StreamingDataLoader +import nnscaler from nnscaler.cli.trainer_args import TrainerArgs from tests.parallel_module.test_end2end import MLP from tests.utils import init_random as init_random_fn + +class CausalSelfAttention(nn.Module): + def __init__(self, n_embd: int, n_head: int, dropout: float): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd, bias=True) + # regularization + self.attn_dropout = nn.Dropout(dropout) + self.resid_dropout = nn.Dropout(dropout) + self.n_head = n_head + self.n_embd = n_embd + self.dropout = dropout + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class SimpleTransformerModel(nn.Module): + def __init__(self, n_embd: int, n_head: int, dropout: float, nlayers: int, vocab_size: int): + super().__init__() + + self.layers = nn.ModuleList([]) + self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) + for _ in range(nlayers): + self.layers.append(CausalSelfAttention(n_embd, n_head, dropout)) + + def forward(self, data): + x = data['input'] + target = data['target'] + for layer in self.layers: + x = layer(x) + logits = self.lm_head(x) + loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1), ignore_index=-1) + return loss + + +def csa_forward_args_gen_fn(trainer_args: TrainerArgs): + seq_len = 128 # dynamicness is controlled by trainer_args.vars['dynamic_dims'] + + return { + 'x': torch.randn(1, seq_len, trainer_args.model.args['n_embd']), + } + + +def post_csa_forward_args_gen_fn(trainer_args: TrainerArgs, args): + dynamic_dims = trainer_args.get_resolved_var('dynamic_dims', default=[]) + nnscaler.mark_dynamic(args['x'], dynamic_dims) + return args + + +def transformer_dummy_sample_gen_fn(trainer_args: TrainerArgs): + seq_len = 128 # dynamicness is controlled by trainer_args.vars['dynamic_dims'] + dynamic_dims = trainer_args.get_resolved_var('dynamic_dims', default=[]) + return { + 'input': nnscaler.mark_dynamic(torch.randn(1, seq_len, trainer_args.model.args['n_embd']), dynamic_dims), + 'target': nnscaler.mark_dynamic(torch.randint(0, trainer_args.model.args['vocab_size'], (1, seq_len)), dynamic_dims), + } + + class MixModuleMLP(nn.Module): def __init__(self, dim: int, nlayers: int, init_random: bool = True): super().__init__() diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index a837db29..48107c2a 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from pathlib import Path +import re import shutil import torch @@ -1112,3 +1113,102 @@ def trainer_resumable_dataloader(save_dir): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_trainer_resumable_dataloader(tmp_path): launch_torchrun(4, trainer_resumable_dataloader, tmp_path) + + +@replace_all_device_with('cpu') +def test_trainer_dynamic_worker(tmp_path): + + def check_match(code_dir: Path, should_exist: bool): + gencode_files = list(code_dir.glob('**/*.py')) + assert set(f.name for f in gencode_files) == set(['gencode0.py', 'gencode1.py', 'gencode2.py', 'gencode3.py']) + for gencode_file in gencode_files: + filecontent = gencode_file.read_text() + matches = re.findall(r'B, T, C = x\.size\(\)', filecontent) + if should_exist: + assert matches + else: + assert not matches + + shutil.rmtree(code_dir) + + save_dir = Path(tmp_path) + config_path = str(Path(__file__).with_name('trainer_args_csa.yaml').resolve()) + gen_savedir = save_dir / 'gen' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[1]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + ]) + trainer.run() + check_match(gen_savedir, should_exist=True) + + gen_savedir = save_dir / 'gen0' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + ]) + trainer.run() + check_match(gen_savedir, should_exist=False) + + # mixed compile + gen_savedir = save_dir / 'gen1' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[1]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + + '--model.parallel_modules.0.type', 'tests.cli.common.CausalSelfAttention', + '--model.parallel_modules.0.args.n_embd', '$(model.args.n_embd)', + '--model.parallel_modules.0.args.n_head', '$(model.args.n_head)', + '--model.parallel_modules.0.args.dropout', '$(model.args.dropout)', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.csa_forward_args_gen_fn', + '--model.parallel_modules.0.forward_args_post_process_fn', 'tests.cli.common.post_csa_forward_args_gen_fn', + ]) + trainer.run() + check_match(gen_savedir, should_exist=True) + + # mixed compile + gen_savedir = save_dir / 'gen2' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + + '--model.parallel_modules.0.type', 'tests.cli.common.CausalSelfAttention', + '--model.parallel_modules.0.args.n_embd', '$(model.args.n_embd)', + '--model.parallel_modules.0.args.n_head', '$(model.args.n_head)', + '--model.parallel_modules.0.args.dropout', '$(model.args.dropout)', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.csa_forward_args_gen_fn', + '--model.parallel_modules.0.forward_args_post_process_fn', 'tests.cli.common.post_csa_forward_args_gen_fn', + ]) + trainer.run() + check_match(gen_savedir, should_exist=False) diff --git a/tests/cli/trainer_args_csa.yaml b/tests/cli/trainer_args_csa.yaml new file mode 100644 index 00000000..1a18c6e3 --- /dev/null +++ b/tests/cli/trainer_args_csa.yaml @@ -0,0 +1,53 @@ +vars: + dynamic_dims: [1] + dim: 16 + drop_last: true +compute_config: + plan_ngpus: 4 + runtime_ngpus: 100 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: tp +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 +max_train_steps: 100 +seed: 0 +dummy_sample_gen_fn: tests.cli.common.transformer_dummy_sample_gen_fn + +model: + type: tests.cli.common.SimpleTransformerModel + args: + n_embd: 1024 + n_head: 8 + dropout: 0.001 + nlayers: 2 + vocab_size: 10000 + +optimizer: + type: torch.optim.Adam + args: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: $(vars.dim) + size: 100 + val_args: + dim: $(vars.dim) + size: 10 + +dataloader: + train_args: + drop_last: $(vars.drop_last) + val_args: + drop_last: $(vars.drop_last) + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 0a8b2fc0..e6a12c55 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -7,6 +7,7 @@ from operator import add from nnscaler.graph.function.dimops import IRDimops, OpAnno import nnscaler.graph.function.function as F +from nnscaler.graph.parser.value_tracker import ValueTracker from nnscaler.ir.cten import IR, IRObject, IRTensor import pytest @@ -47,6 +48,21 @@ def test_Full(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 1' +def test_Randn(): + op = F.Randn(IRObject(value=[2, 3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 2 3 4' + + for dim_track in op.output(0).dim_tracks: + assert dim_track.deps == [op.kwargs['size'].value_track.value_id] + + op = F.Randn(2, IRObject(value=3), IRObject(value=4)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 2 3 4' + + assert op.output(0).dim_tracks[0].deps == [] + assert op.output(0).dim_tracks[1].deps == [op.kwargs['size'][1].value_track.value_id] + assert op.output(0).dim_tracks[2].deps == [op.kwargs['size'][2].value_track.value_id] + + def test_Expand(): inp = IRTensor([10, 1]) out = IRTensor([10, 2]) @@ -1163,7 +1179,27 @@ def test_Stack(): expected_annotation = '1, 1, 1 -> 3' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + def test_Dot(): op = F.Dot(IRTensor([4]), IRTensor([4])) expected_annotation = 'k+, k+ -> 1' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Dot." + + +def test_chunk(): + op = F.Chunk(IRTensor([8, 10]), chunks=4, dim=0) + expected_annotation = '8 b -> 2 b, 2 b, 2 b, 2 b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation + value_tracker = ValueTracker() + value_tracker.track_nodes([op]) + value_tracker.complete_tracking([op]) + input_dim_tracks = op.input(0).dim_tracks + output_dim_tracks = [out.dim_tracks for out in op.outputs()] + # all dim 1 tracks should be the same + assert output_dim_tracks[0][1] is input_dim_tracks[1] + # output dim 0 tracks should depend on input dim 0 track + assert output_dim_tracks[0][0].deps == [input_dim_tracks[0].value_id] + for output_dim_track in output_dim_tracks[1:]: + assert output_dim_track[0] is output_dim_tracks[0][0] + assert output_dim_track[1] is output_dim_tracks[0][1] + assert True diff --git a/tests/graph/parser/test_value_tracker.py b/tests/graph/parser/test_value_tracker.py index 86f8661d..ff45eacf 100644 --- a/tests/graph/parser/test_value_tracker.py +++ b/tests/graph/parser/test_value_tracker.py @@ -3,9 +3,11 @@ import tempfile +import pytest import torch from nnscaler.graph.parser.converter import convert_model +from nnscaler import register_op, mark_dynamic from ...utils import replace_all_device_with @@ -76,7 +78,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): @replace_all_device_with('cpu') -def test_size(): +@pytest.mark.parametrize('dynamic_dim', [True, False]) +def test_size(dynamic_dim): class MyModule(torch.nn.Module): def __init__(self): super().__init__() @@ -87,7 +90,7 @@ def forward(self, x: torch.Tensor): y = torch.randn(s) return x + y - dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + dummy_input = {'x': mark_dynamic(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), [0] if dynamic_dim else [])} module = MyModule() with tempfile.TemporaryDirectory() as tempdir: @@ -98,8 +101,94 @@ def forward(self, x: torch.Tensor): randn_node = ir_graph.node(2) assert size_node.output(0)[0].value_track is ir_graph.inputs()[0].dim_tracks[0] + assert size_node.output(0)[0].value_track.is_constant != (dynamic_dim is True) assert size_node.output(0)[1].value_track is ir_graph.inputs()[0].dim_tracks[1] + assert size_node.output(0)[1].value_track.is_constant is True # dim tracks of randn node is from equivalence class originally from torch.add assert randn_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[0] + assert randn_node.output(0).dim_tracks[0].is_constant != (dynamic_dim is True) assert randn_node.output(0).dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[1] + assert randn_node.output(0).dim_tracks[1].is_constant is True + + +# Note: the custom op here is just for testing purpose +@register_op('l (2 m) n -> n (2 l) m') +def my_op(x: torch.Tensor) -> torch.Tensor: + return torch.randn(x.size(2), x.size(0) * 2, x.size(1) // 2) + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dim', [True, False]) +def test_custom_op(dynamic_dim): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + x = my_op(x) + s = x.size() + y = torch.randn(s) + return x + y + + dummy_input = {'x': mark_dynamic(torch.randn(2, 2, 2), [0, 2] if dynamic_dim else [])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 4 + my_op_node = ir_graph.node(0) + size_node = ir_graph.node(1) + randn_node = ir_graph.node(2) + + assert [t.is_constant for t in ir_graph.inputs()[0].dim_tracks] == [not dynamic_dim, True, not dynamic_dim] + + assert my_op_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[2] + assert ir_graph.inputs()[0].dim_tracks[0].value_id in my_op_node.output(0).dim_tracks[1].deps + assert ir_graph.inputs()[0].dim_tracks[1].value_id in my_op_node.output(0).dim_tracks[2].deps + + assert [t.is_constant for t in my_op_node.outputs()[0].dim_tracks] == [not dynamic_dim, not dynamic_dim, True] + + assert size_node.output(0)[0].value_track is my_op_node.output(0).dim_tracks[0] + assert size_node.output(0)[1].value_track is my_op_node.output(0).dim_tracks[1] + assert size_node.output(0)[2].value_track is my_op_node.output(0).dim_tracks[2] + + assert [t.value_track.is_constant for t in size_node.output(0)] == [not dynamic_dim, not dynamic_dim, True] + + # dim tracks of randn node is from equivalence class originally from torch.add + assert randn_node.output(0).dim_tracks[0] is my_op_node.output(0).dim_tracks[0] + # assert randn_node.output(0).dim_tracks[0].is_constant != (dynamic_dim is True) + assert randn_node.output(0).dim_tracks[1] is my_op_node.output(0).dim_tracks[1] + assert randn_node.output(0).dim_tracks[2] is my_op_node.output(0).dim_tracks[2] + + assert [t.is_constant for t in randn_node.outputs()[0].dim_tracks] == [not dynamic_dim, not dynamic_dim, True] + + +# Note: the custom op here is just for testing purpose +@register_op('l l -> l l') +def my_identity(x: torch.Tensor) -> torch.Tensor: + return x + + +@replace_all_device_with('cpu') +def test_custom_op2(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return my_identity(x) + + dummy_input = {'x': torch.randn(2, 2)} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 1 + my_op_node = ir_graph.node(0) + + assert ir_graph.inputs()[0].dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[0] + assert my_op_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[0] + assert my_op_node.output(0).dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[1] diff --git a/tests/ir/test_cten.py b/tests/ir/test_cten.py index c406f7bc..42f7d115 100644 --- a/tests/ir/test_cten.py +++ b/tests/ir/test_cten.py @@ -120,9 +120,9 @@ def test_from_complex(tosub, requires_grad): assert type(obj[2]) == tensor_type and obj[2].parent.tid != obj_tensor_item.tid t1 = TensorMetadata(shape=(), dtype=torch.float, requires_grad=False, - stride=None, memory_format=None, is_quantized=None, qparams=None) + stride=None, memory_format=None, is_quantized=None, qparams=None, dynamic_dims=set()) t2 = TensorMetadata(shape=(2,), dtype=torch.float, requires_grad=True, - stride=None, memory_format=None, is_quantized=None, qparams=None) + stride=None, memory_format=None, is_quantized=None, qparams=None, dynamic_dims=set()) obj = IR.new('n', {'a': t1, 'b': t2}.values(), tensor_types=(TensorMetadata,), diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index c4306c2c..7e2e3ba0 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -6,6 +6,7 @@ import re from contextlib import nullcontext from typing import Union +from functools import partial import torch import torch.nn.functional as F @@ -17,7 +18,8 @@ from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.parser.mapping import SignFx2Op from nnscaler.ir.cten import IR, IRObject -from nnscaler.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph +from nnscaler.parallel import _load_parallel_module_class, parallelize, ComputeConfig, CubeModule, _gen_graph +from nnscaler.utils import mark_dynamic from .common import init_distributed from ..launch_torchrun import launch_torchrun @@ -394,11 +396,11 @@ def print_gencode(cubesave_dir, module_class, index=0): print(filecontent) -def _gencode_contains(cubesave_dir, module_class, index, search_re): +def _gencode_contains(cubesave_dir, module_class, index, search_re, *, instance_name=None): from nnscaler.parallel import _PARALLEL_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME from pathlib import Path import re - namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' + namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{instance_name or _DEFAULT_INSTANCE_NAME}' outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) filecontent = (outdir /f'gencode{index}.py').read_text() matches = re.findall(search_re, filecontent) @@ -659,6 +661,37 @@ def test_codegen_dictget(): assert m_new is None +class NonConstModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + shape = x.shape + z = torch.randn(shape) + shape = z.shape + return z + shape[0] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dims', [[], [0]]) +def test_codegen_nonconst(dynamic_dims): + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + NonConstModule(), + {'x': mark_dynamic(torch.tensor([[[1.0], [2.0], [3.0], [6.0]]]), dynamic_dims)}, # shape 1/4/1 + 'dp', + ComputeConfig(1, 1, constant_folding=True), + gen_savedir=tempdir, + load_module=False + ) + if not dynamic_dims: + # shape[0] is constant 1, so can be folded to constant 1 + assert _gencode_contains(tempdir, NonConstModule, 0, r'torch.add\(.*, 1, alpha=1\)') + else: + # shape[0] is dynamic, so cannot be folded to constant 1 + assert not _gencode_contains(tempdir, NonConstModule, 0, r'torch.add\(.*, 1, alpha=1\)') + + class CloneModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1969,3 +2002,112 @@ def test_codegen_forward_error_compile(tmp_path): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of GPU devices') def test_codegen_forward_error(tmp_path): launch_torchrun(2, _gencode_forward_error_worker, tmp_path) + + +class WeightModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, input): + input = input + self.weights + out = input @ self.weights + return out + + +class WeightModel2(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = WeightModel() + + def forward(self, input): + return self.weights(input) + + +def pas_weight(graph, cfg, with_auto_multiref=True): + from nnscaler.ir import IRFwOperation, IRDataOperation + from nnscaler.policies import _tp, _replica, auto_multiref + ngpus = cfg.plan_ngpus + if with_auto_multiref: + auto_multiref(graph) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'add': + _tp(graph, node, list(range(ngpus)), 1, 0) + elif node.name == 'matmul': + _tp(graph, node, list(range(ngpus)), 1, 0) + else: + _replica(graph, node, list(range(ngpus))) + return graph + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('with_auto_multiref', [True, False]) +def test_weight_partition(tmp_path, with_auto_multiref): + """ + If auto_multiref is not applied, the weight will correctly partitioned + If auto_multiref is applied, the weight will be replicated as a whole + """ + input = torch.randn((4, 4)) + instance_name = f'with_auto_multiref_{with_auto_multiref}' + + dummy_input = {'input': input} + + m = WeightModel2() + m.train() + + parallelize( + m, + dummy_input, + partial(pas_weight, with_auto_multiref=with_auto_multiref), + ComputeConfig(2, 2), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + instance_name=instance_name, + ) + + module_class = _load_parallel_module_class(WeightModel2, gen_savedir=tmp_path, instance_name=instance_name, rank=0) + + if with_auto_multiref: + for rank in range(2): + fullmap = module_class.attr_meta_maps[rank] + assert fullmap[list(fullmap.keys())[0]].sub_shape == (4, 4) + else: + for rank in range(2): + fullmap = module_class.attr_meta_maps[rank] + assert fullmap[list(fullmap.keys())[0]].sub_shape == (2, 4) + +class DynamicInputModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = torch.nn.Parameter(torch.randn(1, 1)) + + def forward(self, input): + return input + self.weights + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dims', [[], [0, 1]]) +def test_dynamic_dim_partition(tmp_path, dynamic_dims): + input = mark_dynamic(torch.randn((4, 4)), dynamic_dims) + dummy_input = {'input': input} + instance_name=f'{"no" if not dynamic_dims else ""}_dynamic_dims' + + m = DynamicInputModel() + m.train() + + parallelize( + m, + dummy_input, + 'tp', + ComputeConfig(2, 2), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + instance_name=instance_name, + ) + if dynamic_dims: + # no partition for dynamic input + assert not _gencode_contains(tmp_path, DynamicInputModel, 0, r'nnscaler.runtime.adapter.nn.split_allgather', instance_name=instance_name) + else: + assert _gencode_contains(tmp_path, DynamicInputModel, 0, r'nnscaler.runtime.adapter.nn.split_allgather', instance_name=instance_name) From 9d303b2f51cb614a7670b4b181f35539f2dee125 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 3 Nov 2025 06:04:26 +0000 Subject: [PATCH 1849/1892] Merged PR 2420: [Runtime] Add runtime pre-hook and post-hook for node Add runtime pre-hook and post-hook for node --- nnscaler/codegen/emit.py | 43 ++++++++++- nnscaler/codegen/module/module.py | 12 ++++ nnscaler/graph/graph.py | 6 +- nnscaler/ir/cten.py | 38 ++++++++++ nnscaler/policies.py | 27 ++++++- tests/test_policies.py | 114 +++++++++++++++++++++++++++++- 6 files changed, 233 insertions(+), 7 deletions(-) diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index fc197b71..00c73060 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import inspect from typing import Generator, Iterable, List, Any, Optional, Tuple, Dict import logging @@ -225,8 +226,35 @@ def emit_fnode(self, node: IRFwOperation, runtime_devid: int, plan_ndevs: int, r emit_rule = self._emit_rules.map(signature) body = emit_rule(node, inputs, kwargs, runtime_devid, plan_ndevs, runtime_ndevs) + def _to_tuple_str(names: List[str]) -> str: + if len(names) == 1: + return f'({names[0]}, )' + return '(' + ', '.join(names) + ')' + + def _insert_hook(outputs=None, is_pre: bool=False, output_len: int = 0): + hook = node.pre_hook if is_pre else node.post_hook + if not hook: + return + module_path = inspect.getmodule(hook).__name__ + fsig = f'{module_path}.{hook.__name__}' + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + codes.append( + f'{fsig}(self, ' + + repr(node.hook_meta) + ', ' + + f"{_to_tuple_str(inputs)}, " + + f"dict({', '.join(kw_pairs)})" + + ('' if is_pre else ', ' + outputs) + + ')' + ) + + _insert_hook(is_pre=True) + if len(node.outputs()) == 0: codes.append(body) + _insert_hook(is_pre=False, outputs='None') else: irobj_path = {} def r(t, current_path): @@ -245,8 +273,12 @@ def r(t, current_path): if all(len(x) == 1 for x in irobj_path.values()): # if all IRObjects are leafs, we can directly assign the output outputs = [self.tensor_name(t) for t in node.outputs()] - outputs = ', '.join(outputs) - codes.append(f'{outputs} = {body}') + outputs_str = ', '.join(outputs) + codes.append(f'{outputs_str} = {body}') + _insert_hook( + outputs=outputs_str if len(node.outputs()) == 1 else _to_tuple_str(outputs), + is_pre=False + ) else: outputs = [] im_outputs = [] @@ -258,7 +290,12 @@ def r(t, current_path): im_ouptut = self.tensor_name(IRObject('im_output')) im_outputs.append(im_ouptut) outputs.append(im_ouptut) - codes.append(f'{", ".join(outputs)} = {body}') + outputs_str = ', '.join(outputs) + codes.append(f'{outputs_str} = {body}') + _insert_hook( + outputs=outputs_str if len(node.outputs()) == 1 else _to_tuple_str(outputs), + is_pre=False + ) for t, path in irobj_path.items(): if len(path) == 1: # immediate output, skip diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 907323a9..36ad0a83 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -139,6 +139,18 @@ def __init__( # self.init_code.append('@torch.jit.script') self.init_code.append(op_impl) self.init_code += [''] + + # hooks + hook_imports = set() + for node in execplan.graph.select(ntype=IRFwOperation): + if node.pre_hook is not None: + hook_imports.add(inspect.getmodule(node.pre_hook).__name__) + if node.post_hook is not None: + hook_imports.add(inspect.getmodule(node.post_hook).__name__) + for modname in hook_imports: + self.init_code.append(f'import {modname}') + self.init_code += [''] + # module init code self.model_init_statements: List[str] = list() # module method bodies for forward computations, e.g. Segments, Adapters. diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 5f412f25..6cf7372f 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -1212,7 +1212,8 @@ def checksum(self, strict: bool = True) -> str: def copy_node_meta_info(src_node: Union[IRFwOperation, IRDataOperation], dest_node: Union[IRFwOperation, IRDataOperation]): """ Copy meta information from src_node to dest_node. - Current copy fields: ['recompute', 'comment', 'op_context', 'module_stack', 'device'] + Current copy fields: ['recompute', 'comment', 'op_context', 'module_stack', 'device', + 'hook_meta', 'pre_hook', 'post_hook'] """ if isinstance(src_node, IRFwOperation): dest_node.recompute = src_node.recompute @@ -1222,3 +1223,6 @@ def copy_node_meta_info(src_node: Union[IRFwOperation, IRDataOperation], dest_no dest_node.op_context = src_node.op_context dest_node.module_stack = src_node.module_stack dest_node.device = src_node.device + dest_node.hook_meta = src_node.hook_meta + dest_node.pre_hook = src_node.pre_hook + dest_node.post_hook = src_node.post_hook diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 8d843011..e2c2c063 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -88,6 +88,20 @@ def __init__(self, # the operation context information self._op_context: Optional[Dict[str, Any]] = None + # function to be called before the op is executed + # which will be inserted in the runtime code before the op call. + # op's inputs will be passed to the hook. + # The signature will be like + # def pre_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None: + self._pre_hook: Optional[Callable[..., None]] = None + # function to be called after the op is executed + # which will be inserted in the runtime code after the op call. + # op's inputs and outputs will be passed to the hook. + # the signature will be like + # def post_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any], output: Any) -> None: + self._post_hook: Optional[Callable[..., None]] = None + self._hook_meta: Any = None + @property def cid(self) -> int: """ @@ -452,6 +466,30 @@ def fn(self) -> Optional[Callable]: except Exception as e: return None + @property + def pre_hook(self) -> Optional[Callable[..., None]]: + return self._pre_hook + + @pre_hook.setter + def pre_hook(self, hook: Optional[Callable[..., None]]): + self._pre_hook = hook + + @property + def post_hook(self) -> Optional[Callable[..., None]]: + return self._post_hook + + @post_hook.setter + def post_hook(self, hook: Optional[Callable[..., None]]): + self._post_hook = hook + + @property + def hook_meta(self) -> Any: + return self._hook_meta + + @hook_meta.setter + def hook_meta(self, meta: Any): + self._hook_meta = meta + @property def op_context(self) -> Optional[Dict[str, Any]]: return self._op_context diff --git a/nnscaler/policies.py b/nnscaler/policies.py index ba05293c..9558da34 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -21,7 +21,7 @@ import ast from dataclasses import dataclass, field import logging -from typing import List, Literal, Optional, TYPE_CHECKING, Callable, Iterable, Union +from typing import Any, List, Literal, Optional, TYPE_CHECKING, Callable, Iterable, Union import random import torch @@ -41,7 +41,7 @@ if TYPE_CHECKING: - from nnscaler.parallel import ComputeConfig + from nnscaler.parallel import ComputeConfig, ParallelModule _logger = logging.getLogger(__name__) @@ -374,6 +374,25 @@ class OpPlan: recompute_id: int = -1 # -1 means no recompute stage_id: int = -1 # pipeline stage id, -1 means following the previous op's stage + # user defined meta data for hooks + # which will be passed to the pre_hook and post_hook functions + # Note: Only types that can be safely `repr`-ed can be used here. (e.g., str, int, float, tuple, list, dict) + hook_meta: Any = None + + # function to be called before the op is executed + # which will be inserted in the runtime code before the op call. + # op's inputs will be passed to the hook. + # The signature will be like + # def pre_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None: + pre_hook: Optional[Callable[['ParallelModule', Any, tuple[Any, ...], dict[str, Any]], None]] = None + + # function to be called after the op is executed + # which will be inserted in the runtime code after the op call. + # op's inputs and outputs will be passed to the hook. + # the signature will be like + # def post_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any], output: Any) -> None: + post_hook: Optional[Callable[['ParallelModule', Any, tuple[Any, ...], dict[str, Any], Any], None]] = None + # OpPartition: user specified partition plan # You only need to specify one partition plan here. # For example, torch.matmul has annotation of `m k+, k+ n -> m n`, @@ -531,6 +550,10 @@ def fn( if node not in op_plans: op_plans[node] = OpPlan(op=node) # default: no partition, stage 0, no recompute + node.hook_meta = op_plans[node].hook_meta + node.pre_hook = op_plans[node].pre_hook + node.post_hook = op_plans[node].post_hook + op_plan = op_plans[node] # set pipeline stage id if not set diff --git a/tests/test_policies.py b/tests/test_policies.py index c03e6673..06e31deb 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -10,7 +10,8 @@ from nnscaler.parallel import ComputeConfig, _load_parallel_module_class, parallelize from nnscaler.policies import get_called_self_module_name, get_pas_ops -from tests.parallel_module.common import FFN +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.common import FFN, init_distributed from tests.parallel_module.test_gencode import _gencode_contains, print_gencode from .utils import init_random, replace_all_device_with @@ -871,3 +872,114 @@ def test_codegen_fn_pipeline2(tmp_path): ) # should successfully generate code without error assert True + + +class HookModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + a, b = x.size()[:2] + r = torch.randn(a * 2, b) + r = r.chunk(2, dim=0)[0] + return self.linear(x) + r + + + +def hello(module, meta, *args, **kwargs): + print(f'hello: {meta}') + + +def policy_hook(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition + # add hook to all ops + def _hook(op_plan: OpPlan): + op_plan.pre_hook = hello + op_plan.post_hook = hello + op_plan.hook_meta = op_plan.op.name + return op_plan + + for node in get_pas_ops(graph): + if node.fn == torch.nn.functional.linear: + yield _hook(OpPlan(node, partition=OpPartition(input=1, dim=0))) + else: + yield _hook(OpPlan(node)) + + +@replace_all_device_with('cpu') +def test_codegen_fn_with_hook(tmp_path): + parallelize( + HookModule(), + {'x': torch.randn(4, 4)}, + policy_hook, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + # should successfully generate code without error + # and hooks are inserted + for rank in range(4): + assert _gencode_contains(tmp_path, HookModule, rank, r'tests.test_policies.hello\(self,') + + # Generated code of rank 0 looks like: + # def segment64(self, x_32): + # x_32 = nnscaler.runtime.adapter.nn.identity_allreduce(x_32, ranks=[0, 1]) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 883, in forward, a, b = x.size()[:2] + # tests.test_policies.hello(self, 'size', (x_32, ), dict()) + # im_output_63 = torch.Tensor.size(x_32) + # tests.test_policies.hello(self, 'size', (x_32, ), dict(), im_output_63) + # size_26 = im_output_63[0] + # size_27 = im_output_63[1] + # del im_output_63 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 884, in forward, r = torch.randn(a * 2, b) + # tests.test_policies.hello(self, 'mul', (size_26, 2), dict()) + # mul_28 = _operator.mul(size_26, 2) + # tests.test_policies.hello(self, 'mul', (size_26, 2), dict(), mul_28) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 884, in forward, r = torch.randn(a * 2, b) + # tests.test_policies.hello(self, 'randn', (), dict(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False)) + # randn_34 = nnscaler.runtime.function.randn(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False) + # tests.test_policies.hello(self, 'randn', (), dict(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False), randn_34) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 885, in forward, r = r.chunk(2, dim=0)[0] + # tests.test_policies.hello(self, 'chunk', (randn_34, ), dict(chunks=2, dim=0)) + # chunk_35, chunk_36 = torch.chunk(randn_34, chunks=2, dim=0) + # tests.test_policies.hello(self, 'chunk', (randn_34, ), dict(chunks=2, dim=0), (chunk_35, chunk_36)) + # del randn_34, chunk_36 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 886, in forward, return self.linear(x) + r + # tests.test_policies.hello(self, 'linear', (x_32, self.linear_weight_45, self.linear_bias_47), dict()) + # linear_49 = torch.nn.functional.linear(x_32, self.linear_weight_45, self.linear_bias_47) + # tests.test_policies.hello(self, 'linear', (x_32, self.linear_weight_45, self.linear_bias_47), dict(), linear_49) + # del x_32 + # linear_39 = nnscaler.runtime.adapter.nn.allgather_split(linear_49, dim=1, ranks=[0, 1]) + # del linear_49 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 886, in forward, return self.linear(x) + r + # tests.test_policies.hello(self, 'add', (linear_39, chunk_35), dict(alpha=1)) + # add_33 = torch.add(linear_39, chunk_35, alpha=1) + # tests.test_policies.hello(self, 'add', (linear_39, chunk_35), dict(alpha=1), add_33) + # del chunk_35, linear_39 + # return add_33 + + +def _gencode_unused_args_worker(tempdir): + init_distributed() + m_new = parallelize( + HookModule(), + {'x': torch.randn(4, 4)}, + policy_hook, + ComputeConfig(2, 2), + gen_savedir=tempdir, + load_module=True + ) + assert m_new is not None + m_new(torch.randn(4, 4)) + # should successfully run without error + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_run_codegen_fn_with_hook(): + """ + Verify the generated code can run correctly. + """ + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(2, _gencode_unused_args_worker, tempdir) From 1826f1218f5752306264bb4ddc636a22622e4858 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 20 Nov 2025 04:57:29 +0000 Subject: [PATCH 1850/1892] Merged PR 2419: [Test] Improve attn test coverage --- .../core/ring_attn_implementation.py | 4 +- .../core/ring_attn_varlen_implementation.py | 4 +- .../core/zigzag_attn_implementation.py | 4 +- tests/customized_ops/ring_attn/__init__.py | 0 .../ring_attn/ring_attn_runner.py | 23 +- .../ring_attn/ring_attn_varlen_runner.py | 17 +- tests/customized_ops/ring_attn/runner_base.py | 29 ++- tests/customized_ops/ring_attn/test_base.py | 60 +++-- .../ring_attn/test_ring_attn.py | 8 +- .../ring_attn/test_ring_attn_varlen.py | 8 +- .../ring_attn/test_shuffle_varlen.py | 213 ++++++++++++++++++ .../ring_attn/test_zigzag_attn.py | 6 +- .../ring_attn/zigzag_attn_runner.py | 22 +- 13 files changed, 348 insertions(+), 50 deletions(-) create mode 100644 tests/customized_ops/ring_attn/__init__.py create mode 100644 tests/customized_ops/ring_attn/test_shuffle_varlen.py diff --git a/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py b/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py index b8bbc351..39a3885d 100644 --- a/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py +++ b/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py @@ -98,7 +98,7 @@ def ring_flash_attn_backward( window_size=(-1, -1), alibi_slopes=None, deterministic=False, -): +): # pragma: no cover block_len = q.size(1) // 2 curr_rank = dist.get_rank(process_group) world_size = dist.get_world_size(process_group) @@ -290,7 +290,7 @@ def forward( return out @staticmethod - def backward(ctx, dout, *args): + def backward(ctx, dout, *args): # pragma: no cover dout = shuffle_input(to_send=dout, process_group=ctx.group) q, k, v, out, up_lse, down_lse = ctx.saved_tensors bsz = q.size(0) diff --git a/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py b/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py index d388d3b4..709ffe09 100644 --- a/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py +++ b/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py @@ -186,7 +186,7 @@ def llama3_flash_attn_varlen_backward( window_size=(-1, -1), alibi_slopes=None, deterministic=False, -): +): # pragma: no cover nheads = q.shape[1] total_k, nheads_k, head_dim = k.shape assert nheads_k % heads_k_stride == 0 @@ -373,7 +373,7 @@ def forward( return out if not return_softmax else (out, softmax_lse, None) @staticmethod - def backward(ctx, dout, *args): + def backward(ctx, dout, *args): # pragma: no cover q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors dq, dk, dv = llama3_flash_attn_varlen_backward( ctx.group, diff --git a/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py b/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py index fddf74d7..7a643d59 100644 --- a/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py +++ b/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py @@ -254,7 +254,7 @@ def zigzag_ring_flash_attn_backward( window_size=(-1, -1), alibi_slopes=None, deterministic=False, -): +): # pragma: no cover assert causal == True, "zigzag ring is meaningless for causal=False" kv_comm = RingComm(process_group) d_kv_comm = RingComm(process_group) @@ -411,7 +411,7 @@ def forward( return out if not return_softmax else (out, softmax_lse, None) @staticmethod - def backward(ctx, dout, *args): + def backward(ctx, dout, *args): # pragma: no cover dout = shuffle_input(to_send=dout, process_group=ctx.group) q, k, v, out, softmax_lse = ctx.saved_tensors dq, dk, dv = zigzag_ring_flash_attn_backward( diff --git a/tests/customized_ops/ring_attn/__init__.py b/tests/customized_ops/ring_attn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/customized_ops/ring_attn/ring_attn_runner.py b/tests/customized_ops/ring_attn/ring_attn_runner.py index fafc442f..405c61b1 100644 --- a/tests/customized_ops/ring_attn/ring_attn_runner.py +++ b/tests/customized_ops/ring_attn/ring_attn_runner.py @@ -9,6 +9,7 @@ """ import sys +from typing import Tuple import torch from runner_base import RingAttnRunnerBase @@ -38,6 +39,10 @@ class RingAttnRunner(RingAttnRunnerBase): def function_signature(self) -> str: return 'nnscaler.customized_ops.ring_attention.ring_attn.wrap_ring_attn_func' + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 1 + @property def function_name(self) -> str: return 'wrap_ring_attn_func' @@ -47,32 +52,32 @@ def create_test_module(self, config) -> torch.nn.Module: def prepare_inputs(self, config, device, torch_dtype): """Prepare regular inputs with shape [batch_size, seq_len, num_heads, head_dim]""" - q = torch.randn( + q = torch.clamp(torch.randn( config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype - ) + ), min=-1, max=1) - k = torch.randn( + k = torch.clamp(torch.randn( config.batch_size, config.max_seqlen, config.num_kv_heads, config.head_dim, device=device, dtype=torch_dtype - ) + ), min=-1, max=1) - v = torch.randn( + v = torch.clamp(torch.randn( config.batch_size, config.max_seqlen, config.num_kv_heads, config.head_dim, device=device, dtype=torch_dtype - ) + ), min=-1, max=1) return {'q': q, 'k': k, 'v': v} @@ -95,6 +100,12 @@ def get_dummy_forward_args(self, inputs): } +def ring_attn_test(dtype="bf16", config_name="tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = RingAttnRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + def run_correctness_test(**kwargs): """Legacy function for backward compatibility""" runner = RingAttnRunner() diff --git a/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py b/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py index 806cd852..2ca40313 100644 --- a/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py +++ b/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py @@ -9,6 +9,7 @@ """ import sys +from typing import Tuple import torch from runner_base import RingAttnRunnerBase @@ -37,6 +38,10 @@ class RingAttnVarlenRunner(RingAttnRunnerBase): def function_signature(self) -> str: return 'nnscaler.customized_ops.ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func' + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 0 + @property def function_name(self) -> str: return 'ring_attn_varlen_func' @@ -50,9 +55,9 @@ def prepare_inputs(self, config, device, torch_dtype): total_seqlen = config.cu_seqlens[-1] # Create inputs with total sequence length (don't set requires_grad here, base class handles it) - q = torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) - k = torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) - v = torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + q = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + k = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + v = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) return { 'q': q, @@ -84,6 +89,12 @@ def get_dummy_forward_args(self, inputs): } +def ring_attn_varlen_test(dtype="bf16", config_name="tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = RingAttnVarlenRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + def run_ring_attn_correctness_test(**kwargs): """Legacy function for backward compatibility""" runner = RingAttnVarlenRunner() diff --git a/tests/customized_ops/ring_attn/runner_base.py b/tests/customized_ops/ring_attn/runner_base.py index 2fcd7790..e7d9a32f 100644 --- a/tests/customized_ops/ring_attn/runner_base.py +++ b/tests/customized_ops/ring_attn/runner_base.py @@ -32,6 +32,12 @@ def function_signature(self) -> str: """Return the function signature to look for in the graph""" pass + @property + @abstractmethod + def partition_position(self) -> Tuple[int, int]: + """Return the partition position (idx, dim)""" + pass + @property @abstractmethod def function_name(self) -> str: @@ -66,14 +72,16 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: for idx, node in enumerate(graph.select(ntype=IRFwOperation)): if not partitioned and node.signature == self.function_signature: print(f'\nPartitioned node: {node}\n') - sub_nodes = graph.partition(node, node.algorithm('dim'), idx=0, dim=0, num=ngpus) + idx, dim = self.partition_position + sub_nodes = graph.partition(node, node.algorithm('dim'), idx=idx, dim=dim, num=ngpus) partitioned = True else: sub_nodes = graph.replicate(node, times=ngpus) for idx, sub_node in enumerate(sub_nodes): graph.assign(sub_node, idx) if not partitioned: - print(f"WARNING: No {self.function_name} found in graph for partitioning") + signatures = [node.signature for node in graph.select(ntype=IRFwOperation)] + raise RuntimeError(f"Failed to find the target function '{self.function_signature}' in {signatures}") return graph return policy @@ -83,10 +91,10 @@ def initialize_distributed(self): if not torch.cuda.is_available(): print("ERROR: CUDA is not available") sys.exit(1) - + rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) - + # Check if we have enough GPUs available_gpus = torch.cuda.device_count() if available_gpus < world_size: @@ -106,7 +114,7 @@ def initialize_distributed(self): sys.exit(1) print(f"[INFO] world_size:{world_size}, rank:{rank}, available_gpus:{available_gpus}") - + try: dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) except Exception as e: @@ -117,10 +125,13 @@ def initialize_distributed(self): nnscaler.init() return world_size, rank - def get_tolerances(self, dtype: str) -> Dict[str, float]: + def get_tolerances(self, dtype: str, num_heads: int, num_kv_heads: int) -> Dict[str, float]: """Get tolerance values based on data type""" if dtype == "bf16": - return dict(atol=2.5e-2, rtol=2.5e-2) + if num_heads == num_kv_heads: + return dict(atol=2.5e-2, rtol=2.5e-2) + else: + return dict(atol=3.5e-2, rtol=3.5e-2) elif dtype == "fp16": return dict(atol=5e-3, rtol=5e-3) else: @@ -222,7 +233,7 @@ def run_correctness_test(self, config_name: str, dtype: str = "bf16", **kwargs): single_out, single_grad_tensors = self.run_single_gpu_reference(single_inputs, config) # Create gradient for backward pass - dout = torch.randn_like(single_out, device=device, dtype=torch_dtype) + dout = torch.clamp(torch.randn_like(single_out, device=device, dtype=torch_dtype), min=-1, max=1) # Ensure dout is consistent across all ranks dist.broadcast(dout, src=0) single_out.backward(dout) @@ -251,7 +262,7 @@ def run_correctness_test(self, config_name: str, dtype: str = "bf16", **kwargs): print(" Done!" if rank_id == 0 else "") # Check correctness with tolerances - tols = self.get_tolerances(dtype) + tols = self.get_tolerances(dtype, config.num_heads, config.num_kv_heads) # Verify outputs and gradients try: diff --git a/tests/customized_ops/ring_attn/test_base.py b/tests/customized_ops/ring_attn/test_base.py index 211ee567..44870792 100644 --- a/tests/customized_ops/ring_attn/test_base.py +++ b/tests/customized_ops/ring_attn/test_base.py @@ -8,14 +8,14 @@ import os import sys -import subprocess from abc import ABC, abstractmethod -from typing import Dict, Any, List +from typing import Dict, Any, List, Tuple +from functools import partial import pytest import torch -from configs import ( +from .configs import ( DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS, @@ -23,6 +23,8 @@ list_configs ) +from ...launch_torchrun import torchrun + class RingAttnTestBase(ABC): """Base class for ring attention tests""" @@ -39,11 +41,17 @@ def test_name_prefix(self) -> str: """Return the prefix for test names (e.g., 'ring_attn' or 'ring_attn_varlen')""" pass + @property + @abstractmethod + def test_function_name(self) -> str: + """Return the name of the test function to import (e.g., 'zigzag_attn_test')""" + pass + def _check_gpu_availability(self, required_gpus: int): """Check if enough GPUs are available and skip test if not""" if not torch.cuda.is_available(): pytest.skip("CUDA is not available") - + available_gpus = torch.cuda.device_count() if available_gpus < required_gpus: pytest.skip(f"Test requires {required_gpus} GPUs, but only {available_gpus} available") @@ -54,7 +62,11 @@ def _get_project_root(self): return os.path.abspath(os.path.join(current_dir, "../../../")) def get_bash_arguments(self, num_gpus_per_node: int, **kwargs) -> List[str]: - """Generate command line arguments for running the test script""" + """Generate command line arguments for running the test script + + Deprecated: This method is kept for backward compatibility. + The new implementation uses launch_torchrun directly. + """ args = [ "python3", "-m", @@ -73,19 +85,37 @@ def get_bash_arguments(self, num_gpus_per_node: int, **kwargs) -> List[str]: args.append(f"{k}={v}") return args + def _get_test_function(self): + """Get the test function for this test""" + # Add the script directory to sys.path to allow imports + project_root = self._get_project_root() + script_dir = os.path.join(project_root, "tests", "customized_ops", "ring_attn") + if script_dir not in sys.path: + sys.path.insert(0, script_dir) + + # Import the module and get the test function + module_name = self.runner_script_name.replace('.py', '') + module = __import__(module_name) + + if hasattr(module, self.test_function_name): + return getattr(module, self.test_function_name) + else: + raise ImportError(f"Could not find function '{self.test_function_name}' in {module_name}") + def run_test_subprocess(self, num_gpus: int, **kwargs): - """Run test using subprocess with the configured runner script""" - # Check GPU availability before running subprocess + """Run test using torchrun with the configured test function""" + # Check GPU availability before running test self._check_gpu_availability(num_gpus) - subprocess.run( - self.get_bash_arguments( - num_gpus_per_node=num_gpus, - **kwargs - ), - check=True, - cwd=self._get_project_root() - ) + # Get the test function and use torchrun to execute it + test_function = self._get_test_function() + + # Extract common parameters + dtype = kwargs.get('dtype', 'bf16') + config_name = kwargs.get('config_name', 'tiny') + + # Use partial with positional arguments like test_gnorm.py + return partial(torchrun, num_gpus, test_function, dtype, config_name)() # Common test methods that can be used by both ring_attn and ring_attn_varlen diff --git a/tests/customized_ops/ring_attn/test_ring_attn.py b/tests/customized_ops/ring_attn/test_ring_attn.py index bcb47f16..e1378101 100644 --- a/tests/customized_ops/ring_attn/test_ring_attn.py +++ b/tests/customized_ops/ring_attn/test_ring_attn.py @@ -17,8 +17,8 @@ except ImportError: pytest.skip("flash_attn_func not available", allow_module_level=True) -from test_base import RingAttnTestBase, create_parametrized_tests -from configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS +from .test_base import RingAttnTestBase, create_parametrized_tests +from .configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS class RingAttnTest(RingAttnTestBase): @@ -28,6 +28,10 @@ class RingAttnTest(RingAttnTestBase): def runner_script_name(self) -> str: return "ring_attn_runner.py" + @property + def test_function_name(self) -> str: + return "ring_attn_test" + @property def test_name_prefix(self) -> str: return "ring_attn" diff --git a/tests/customized_ops/ring_attn/test_ring_attn_varlen.py b/tests/customized_ops/ring_attn/test_ring_attn_varlen.py index 04c26035..f86fc276 100644 --- a/tests/customized_ops/ring_attn/test_ring_attn_varlen.py +++ b/tests/customized_ops/ring_attn/test_ring_attn_varlen.py @@ -17,8 +17,8 @@ except ImportError: pytest.skip("flash_attn_varlen_func not available", allow_module_level=True) -from test_base import RingAttnTestBase, create_parametrized_tests -from configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS +from .test_base import RingAttnTestBase, create_parametrized_tests +from .configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS class RingAttnVarlenTest(RingAttnTestBase): @@ -28,6 +28,10 @@ class RingAttnVarlenTest(RingAttnTestBase): def runner_script_name(self) -> str: return "ring_attn_varlen_runner.py" + @property + def test_function_name(self) -> str: + return "ring_attn_varlen_test" + @property def test_name_prefix(self) -> str: return "ring_attn_varlen" diff --git a/tests/customized_ops/ring_attn/test_shuffle_varlen.py b/tests/customized_ops/ring_attn/test_shuffle_varlen.py new file mode 100644 index 00000000..8f7271a4 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_shuffle_varlen.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Simple test for shuffle_varlen and unshuffle_varlen functions. +""" + +import pytest +import torch +import torch.distributed as dist +from dataclasses import dataclass +from typing import List +from functools import partial + +from tests.launch_torchrun import torchrun + + +@dataclass +class ShuffleVarlenConfig: + """Simple test configuration""" + name: str + batch_size: int + seq_lens: List[int] + hidden_dim: int + + +# Test configurations +CONFIGS = { + "tiny": ShuffleVarlenConfig("tiny", 2, [512, 768], 64), + "small": ShuffleVarlenConfig("small", 2, [1024, 1536], 128), + "medium": ShuffleVarlenConfig("medium", 2, [1024, 1536], 256), + "uneven": ShuffleVarlenConfig("uneven", 3, [256, 768, 1024], 128), +} + + +def shuffle_varlen_test(config_name="tiny", dtype="float32", world_size=2): + """Test shuffle_varlen and unshuffle_varlen functions""" + + if not dist.is_initialized(): + dist.init_process_group(backend='nccl') + + rank = dist.get_rank() + world_size_actual = dist.get_world_size() + device = torch.device(f'cuda:{rank}') + torch.cuda.set_device(device) + + if rank == 0: + print(f"Testing shuffle_varlen and unshuffle_varlen functions") + print(f"Configuration: {config_name}") + print(f"World size: {world_size_actual}") + print(f"Data type: {dtype}") + print("=" * 60) + + # Get configuration + config = CONFIGS[config_name] + + # Set up process group for context parallel + cp_ranks = list(range(world_size_actual)) + cp_group = dist.new_group(cp_ranks) + + # Create cumulative sequence lengths (padded to be divisible by 2*world_size) + cu_seqlens = torch.zeros(config.batch_size + 1, dtype=torch.int32, device=device) + total_slices_per_seq = 2 * world_size_actual + + for i, seq_len in enumerate(config.seq_lens): + # Pad sequence length to be divisible by total_slices_per_seq + padded_seq_len = ((seq_len + total_slices_per_seq - 1) // total_slices_per_seq) * total_slices_per_seq + cu_seqlens[i + 1] = cu_seqlens[i] + padded_seq_len + + total_seq_len = cu_seqlens[len(config.seq_lens)].item() # Use len(config.seq_lens) instead of -1 + + # Convert dtype string to torch dtype + torch_dtype = getattr(torch, dtype) + + # Import functions from varlen_utils + from nnscaler.customized_ops.ring_attention.varlen_utils import shuffle_varlen, unshuffle_varlen + + if rank == 0: + print("Running shuffle/unshuffle correctness tests...") + + tolerance = 1e-5 if torch_dtype == torch.float32 else 1e-2 + + # Test 1: 1D tensor (like position_ids) + if rank == 0: + print(" Test: 1D tensor (total_seq_len,)...") + + try: + # Create full tensor first (on rank 0) + if rank == 0: + full_tensor_1d = torch.arange(total_seq_len, dtype=torch_dtype, device=device) + else: + full_tensor_1d = torch.empty(total_seq_len, dtype=torch_dtype, device=device) + + # Broadcast full tensor to all ranks for reference + dist.broadcast(full_tensor_1d, src=0, group=cp_group) + + # Split tensor for local input (each rank gets a chunk) + chunk_size = total_seq_len // world_size_actual + start_idx = rank * chunk_size + end_idx = start_idx + chunk_size if rank < world_size_actual - 1 else total_seq_len + local_tensor_1d = full_tensor_1d[start_idx:end_idx].clone() + + # Test shuffle -> unshuffle + shuffled = shuffle_varlen(local_tensor_1d, cu_seqlens, cp_ranks, cp_group) + unshuffled = unshuffle_varlen(shuffled, cu_seqlens, cp_ranks, cp_group) + + # Compare with original local chunk + if torch.allclose(local_tensor_1d, unshuffled, atol=tolerance): + if rank == 0: + print(" ✓ 1D tensor test passed") + else: + if rank == 0: + print(" ✗ 1D tensor test FAILED") + raise AssertionError("1D tensor test failed") + + except Exception as e: + if rank == 0: + print(f" ✗ 1D tensor test FAILED with error: {e}") + raise e + + # Test 2: 2D tensor (total_seq_len, hidden_dim) + if rank == 0: + print(" Test: 2D tensor (total_seq_len, hidden_dim)...") + + try: + # Create full tensor first (on rank 0) + if rank == 0: + full_tensor_2d = torch.randn(total_seq_len, config.hidden_dim, dtype=torch_dtype, device=device) + else: + full_tensor_2d = torch.empty(total_seq_len, config.hidden_dim, dtype=torch_dtype, device=device) + + # Broadcast full tensor to all ranks for reference + dist.broadcast(full_tensor_2d, src=0, group=cp_group) + + # Split tensor for local input (each rank gets a chunk) + chunk_size = total_seq_len // world_size_actual + start_idx = rank * chunk_size + end_idx = start_idx + chunk_size if rank < world_size_actual - 1 else total_seq_len + local_tensor_2d = full_tensor_2d[start_idx:end_idx].clone() + + # Test shuffle -> unshuffle + shuffled = shuffle_varlen(local_tensor_2d, cu_seqlens, cp_ranks, cp_group) + unshuffled = unshuffle_varlen(shuffled, cu_seqlens, cp_ranks, cp_group) + + # Compare with original local chunk + if torch.allclose(local_tensor_2d, unshuffled, atol=tolerance): + if rank == 0: + print(" ✓ 2D tensor test passed") + else: + if rank == 0: + print(" ✗ 2D tensor test FAILED") + raise AssertionError("2D tensor test failed") + + except Exception as e: + if rank == 0: + print(f" ✗ 2D tensor test FAILED with error: {e}") + raise e + + dist.barrier() + + if rank == 0: + print("✓ All shuffle/unshuffle tests PASSED!") + + dist.destroy_process_group() + + +class TestShuffleVarlen: + """Simple test class for shuffle/unshuffle varlen""" + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_tiny(self, dtype): + """Test shuffle/unshuffle varlen with tiny configuration""" + partial(torchrun, 2, shuffle_varlen_test, "tiny", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_small(self, dtype): + """Test shuffle/unshuffle varlen with small configuration""" + partial(torchrun, 2, shuffle_varlen_test, "small", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_medium(self, dtype): + """Test shuffle/unshuffle varlen with medium configuration""" + partial(torchrun, 2, shuffle_varlen_test, "medium", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_uneven(self, dtype): + """Test shuffle/unshuffle varlen with uneven sequence lengths""" + partial(torchrun, 2, shuffle_varlen_test, "uneven", dtype)() + + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_shuffle_varlen_multi_gpu(self, num_gpus): + """Test shuffle/unshuffle varlen on multiple GPUs""" + partial(torchrun, num_gpus, shuffle_varlen_test, "tiny", "float32")() + + +# Standalone test functions for pytest discovery +@pytest.mark.parametrize("config,dtype", [ + ("tiny", "float32"), ("tiny", "float16"), + ("small", "float32"), ("small", "float16"), + ("uneven", "float32"), ("uneven", "float16"), +]) +def test_shuffle_varlen_correctness(config, dtype): + """Test shuffle/unshuffle varlen correctness""" + partial(torchrun, 2, shuffle_varlen_test, config, dtype)() + + +@pytest.mark.parametrize("config,num_gpus", [ + ("tiny", 2), ("tiny", 4), + ("small", 2), ("small", 4), +]) +def test_shuffle_varlen_multi_gpu(config, num_gpus): + """Test shuffle/unshuffle varlen on multiple GPUs""" + partial(torchrun, num_gpus, shuffle_varlen_test, config, "float32")() \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_zigzag_attn.py b/tests/customized_ops/ring_attn/test_zigzag_attn.py index 6bb885bf..3bca5792 100644 --- a/tests/customized_ops/ring_attn/test_zigzag_attn.py +++ b/tests/customized_ops/ring_attn/test_zigzag_attn.py @@ -20,7 +20,7 @@ except ImportError: pytest.skip("flash_attn_func not available", allow_module_level=True) -from test_base import RingAttnTestBase +from .test_base import RingAttnTestBase class TestZigzagAttn(RingAttnTestBase): @@ -30,6 +30,10 @@ class TestZigzagAttn(RingAttnTestBase): def runner_script_name(self) -> str: return "zigzag_attn_runner.py" + @property + def test_function_name(self) -> str: + return "zigzag_attn_test" + @property def test_name_prefix(self) -> str: return "zigzag_attn" diff --git a/tests/customized_ops/ring_attn/zigzag_attn_runner.py b/tests/customized_ops/ring_attn/zigzag_attn_runner.py index 5b1ba465..6e557e2a 100644 --- a/tests/customized_ops/ring_attn/zigzag_attn_runner.py +++ b/tests/customized_ops/ring_attn/zigzag_attn_runner.py @@ -9,7 +9,7 @@ import os import sys -from typing import Dict, Any +from typing import Dict, Any, Tuple import torch import torch.nn as nn @@ -23,7 +23,11 @@ class ZigzagAttnRunner(RingAttnRunnerBase): @property def function_signature(self) -> str: - return "wrap_zigzag_attn_func" + return "nnscaler.customized_ops.ring_attention.zigzag_attn.wrap_zigzag_attn_func" + + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 1 @property def function_name(self) -> str: @@ -55,9 +59,9 @@ def prepare_inputs(self, config, device, torch_dtype): head_dim = config.head_dim # Create input tensors - q = torch.randn(batch_size, max_seqlen, num_heads, head_dim, device=device, dtype=torch_dtype) - k = torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype) - v = torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype) + q = torch.clamp(torch.randn(batch_size, max_seqlen, num_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + k = torch.clamp(torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + v = torch.clamp(torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) return { 'q': q, @@ -84,10 +88,16 @@ def get_dummy_forward_args(self, inputs) -> Dict[str, Any]: } +def zigzag_attn_test(dtype="bf16", config_name="zigzag_tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = ZigzagAttnRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + def main(): """Main entry point for command line execution""" kwargs = dict(arg.split("=") for arg in sys.argv[1:]) - + runner = ZigzagAttnRunner() runner.main(**kwargs) From b4ca05c3bef4f52572557c251dc6b055421d24db Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 20 Nov 2025 06:38:46 +0000 Subject: [PATCH 1851/1892] Merged PR 2423: [BugFix] Handle trim correctly when multi-node --- nnscaler/parallel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 6001f9e1..02459418 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1765,10 +1765,10 @@ def _sort_state_dicts(state_dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]] sorted_state_dicts =[None] * len(state_dicts) for state_dict in state_dicts: rank = _get_state_dict_rank(state_dict) + if rank >= len(state_dicts): + raise ValueError(f"Invalid rank {rank} in state_dicts with length {len(state_dicts)}.") if sorted_state_dicts[rank] is not None: raise ValueError(f"Duplicate rank {rank} in state_dicts.") - if rank >= len(state_dicts): - raise ValueError(f"Invalid rank {rank} in state_dicts.") sorted_state_dicts[rank] = state_dict return sorted_state_dicts @@ -2973,8 +2973,8 @@ def trimmed_broadcast_merged_state_dict( ret = None if cur_rank == src_rank: - pmodule_stubs = [_construct_parallel_module_stub(r[0]) for r in rank_metadatas] - opt_extra_states = [r[1] for r in rank_metadatas] + pmodule_stubs = {rank : _construct_parallel_module_stub(r[0]) for rank, r in zip(dst_ranks, rank_metadatas)} + opt_extra_states = {rank : r[1] for rank, r in zip(dst_ranks, rank_metadatas)} for rank in dst_ranks: if rank != cur_rank: logger.info(f'At rank {src_rank}: Trimming module state dict for rank {rank}') From 85bd572eb665bff8772ce1566f0ed46fd9cf7d78 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 9 Dec 2025 03:19:00 +0000 Subject: [PATCH 1852/1892] Merged PR 2425: [Feat] cli: Add safetensors support 1. Add custom state dict save/load support 2. Add safetensors support. Currently no lazy loading support in cli to simplify the logic. --- nnscaler/cli/__init__.py | 2 + nnscaler/cli/mixed_module.py | 12 +- nnscaler/cli/serialization.py | 195 ++++++++++++++++++ nnscaler/cli/trainer.py | 64 +++--- nnscaler/cli/trainer_args.py | 12 ++ nnscaler/runtime/module.py | 3 +- nnscaler/runtime/serialization.py | 249 +++++++++++++++++++++++ nnscaler/utils.py | 203 ++++++++++++++---- requirements.txt | 1 + tests/cli/test_trainer.py | 155 ++++++++++++-- tests/parallel_module/common.py | 3 +- tests/parallel_module/test_checkpoint.py | 19 ++ tests/runtime/test_serialization.py | 68 +++++++ tests/test_utils.py | 88 +++++++- 14 files changed, 982 insertions(+), 92 deletions(-) create mode 100644 nnscaler/cli/serialization.py create mode 100644 nnscaler/runtime/serialization.py create mode 100644 tests/runtime/test_serialization.py diff --git a/nnscaler/cli/__init__.py b/nnscaler/cli/__init__.py index 958e874f..d218f6f9 100644 --- a/nnscaler/cli/__init__.py +++ b/nnscaler/cli/__init__.py @@ -17,4 +17,6 @@ AggregatedOutputs, ) +from nnscaler.cli.serialization import register_format + from nnscaler.parallel import ComputeConfig diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index eb5bb7ab..81deb84e 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -19,6 +19,7 @@ TrainerArgs, PrecisionMixin, PolicyMixin, ModuleParallelizeConfig, ComputeConfig, load_type ) +from .serialization import Checkpointer logger = logging.getLogger(__name__) @@ -132,10 +133,10 @@ def load_tracing_weights(self) -> Optional[dict[str, Any]]: # try to reuse the weights from the tracing weights tracing_weights = self.tracing_weights if self.tracing_from_weights and tracing_weights is None: - tracing_weights = torch.load(self.tracing_from_weights) + tracing_weights = Checkpointer.load(self.tracing_from_weights) else: if self.tracing_from_weights: - tracing_weights = torch.load(self.tracing_from_weights) + tracing_weights = Checkpointer.load(self.tracing_from_weights) elif self.parallel_module.tracing_from_weights_prefix: leading_key = self.parallel_module.tracing_from_weights_prefix + '.' tracing_weights = {} @@ -185,7 +186,10 @@ def create_dummy_forward_args(self, dummy_input) -> dict[str, Any]: def resolve_compute_config(self): compute_config = copy.deepcopy(self.compute_config) - compute_config.pas_config['__pas_name'] = self.pas_policy + compute_config.pas_config['__pas_name'] = \ + self.pas_policy \ + if not callable(self.pas_policy) \ + else f'{self.pas_policy.__module__}.{self.pas_policy.__qualname__}' # autodist configs compute_config.pas_config['update_freq'] = self.trainer_args.update_freq compute_config.pas_config['use_bf16'] = self.param_dtype == torch.bfloat16 @@ -288,7 +292,7 @@ def parameters_for_calc_gnorm(self): def parallelize_model(trainer_args: TrainerArgs, dummy_input: dict[str, Any], load_module: bool, build_buckets: bool): tracing_weights = None if trainer_args.tracing_from_weights: - tracing_weights = torch.load(trainer_args.tracing_from_weights) + tracing_weights = Checkpointer.load(trainer_args.tracing_from_weights) def _new_adapter(parallel_module=None): return ModuleParallelizeConfigAdapter( diff --git a/nnscaler/cli/serialization.py b/nnscaler/cli/serialization.py new file mode 100644 index 00000000..bc90e59a --- /dev/null +++ b/nnscaler/cli/serialization.py @@ -0,0 +1,195 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, Callable, Protocol +from pathlib import Path + +import torch + +from nnscaler.runtime.serialization import load, save + + +class _LoadProc(Protocol): + def __call__(self, f: str | Path, *, device='cpu') -> Any: ... + + +class _SaveProc(Protocol): + def __call__(self, obj: Any, f: str | Path) -> None: ... + + +class Checkpointer: + # the format of the checkpoint file + # keys: epoch, step, rank + # currently it is not configurable + # TODO: make it configurable + CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/{rank}{suffix}' + CHECKPOINT_LAST_DIR_NAME: str = 'last' + CHECKPOINT_BEST_DIR_NAME: str = 'best' + CHECKPOINT_MERGED_FILE_NAME: str = 'merged{suffix}' + CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}{suffix}' + CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}{suffix}' + SUFFIX_MAP: dict[str, str] = { + 'pt': '.ckpt', + 'safetensors': '.safetensors' + } + # will use torch.load and torch.save for other suffixes + SUFFIX_HANDLERS: dict[str, tuple[_LoadProc, _SaveProc]] = { + '.safetensors': (load, save), + } + + def __init__(self, format: str = 'pt'): + if format not in self.SUFFIX_MAP: + raise ValueError(f"Unsupported checkpoint format: {format}") + self.format = format + self.suffix = self.SUFFIX_MAP[format] + + def get_checkpoint_file_path(self, epoch: int, step: int, rank: int) -> str: + return self.CHECKPOINT_FILE_FORMAT.format(epoch=epoch, step=step, rank=rank, suffix=self.suffix) + + def get_last_checkpoint_file_path(self, rank: int) -> str: + return self.CHECKPOINT_LAST_FILE_FORMAT.format(rank=rank, suffix=self.suffix) + + def get_best_checkpoint_file_path(self, rank: int) -> str: + return self.CHECKPOINT_BEST_FILE_FORMAT.format(rank=rank, suffix=self.suffix) + + def get_merged_checkpoint_file_name(self) -> str: + return self.CHECKPOINT_MERGED_FILE_NAME.format(suffix=self.suffix) + + def get_last_dir_name(self) -> str: + return self.CHECKPOINT_LAST_DIR_NAME + + def get_best_dir_name(self) -> str: + return self.CHECKPOINT_BEST_DIR_NAME + + @classmethod + def load(cls, f: str | Path, *, device='cpu') -> Any: + """ + Loads a checkpoint file + + Args: + f: filename of the checkpoint file. + if the suffix is .safetensors, it will be loaded as safetensors file. + otherwise, it will be loaded as a PyTorch checkpoint file. + device (`str`, *optional*, defaults to `"cpu"`): + The device on which you want the tensors. + """ + suffix = Path(f).suffix + if suffix in cls.SUFFIX_HANDLERS: + load_func, _ = cls.SUFFIX_HANDLERS[suffix] + return load_func(f, device=device) + else: + return torch.load(f, map_location=device, weights_only=False) + + @classmethod + def save(cls, obj: Any, f: str | Path) -> None: + """ + Saves a checkpoint file + + Args: + obj (`Any`): + The object to save. + f: filename of the checkpoint file. + if the suffix is .safetensors, it will be saved as safetensors file. + otherwise, it will be saved as a PyTorch checkpoint file. + """ + suffix = Path(f).suffix + if suffix in cls.SUFFIX_HANDLERS: + _, save_func = cls.SUFFIX_HANDLERS[suffix] + save_func(obj, f) + else: + torch.save(obj, f) + + @classmethod + def load_for_rank(cls, dir: str | Path, rank: int, device='cpu') -> Any: + """ + Loads a checkpoint file for a specific rank + + Args: + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to load. + device (`str`, `int`, *optional*): + The device on which you want the tensors. + """ + for suffix in cls.SUFFIX_MAP.values(): + f = Path(dir) / f"{rank}{suffix}" + if f.exists(): + return cls.load(f, device=device) + raise FileNotFoundError(f"No checkpoint file found for rank {rank} in directory {dir}") + + def save_for_rank(self, obj: Any, dir: str | Path, rank: int) -> None: + """ + Saves a checkpoint file for a specific rank + + Args: + obj (`Any`): + The object to save. + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to save. + """ + f = Path(dir) / f"{rank}{self.suffix}" + self.save(obj, f) + + @classmethod + def remove_for_rank(cls, dir: str | Path, rank: int) -> None: + """ + Removes a checkpoint file for a specific rank + + Args: + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to remove. + """ + for suffix in cls.SUFFIX_MAP.values(): + f = Path(dir) / f"{rank}{suffix}" + if f.exists(): + f.unlink() + + @classmethod + def list_checkpoints(cls, dir: str | Path) -> list[Path]: + """ + List the checkpoint files in a directory + Args: + dir (`str`): + The directory where the checkpoint files are stored. + Returns: + (`list[Path]`): + The list of checkpoint files in the directory. + """ + p = Path(dir) + files = [] + for suffix in cls.SUFFIX_MAP.values(): + fs = list(p.glob(f"*{suffix}")) + if fs: + if files: + raise ValueError(f"Mixed checkpoint file formats in directory {dir}") + else: + files.extend(fs) + return files + + +def register_format( + name: str, + suffix: str, + save_func: _SaveProc, + load_func: _LoadProc, + ) -> None: + """ + Registers a new serialization format. + Args: + name (`str`): + The name of the format. + suffix (`str`): + The file suffix of the format. + load_func: + The function to load the format. + save_func: + The function to save the format. + """ + suffix = '.' + suffix.lstrip('.') + Checkpointer.SUFFIX_MAP[name] = suffix + Checkpointer.SUFFIX_HANDLERS[suffix] = (load_func, save_func) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index ae65da81..b1768924 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -28,23 +28,12 @@ from .trainer_args import AggregatedOutputs, TrainerArgs, fix_input from .train_hook import AggregatedTrainHook, TrainHook, TrainHookHost from .mixed_module import parallelize_model, mixin_module +from .serialization import Checkpointer logger = logging.getLogger(__name__) -# the format of the checkpoint file -# keys: epoch, step, rank -# currently it is not configurable -# TODO: make it configurable -CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/{rank}.ckpt' -CHECKPOINT_LAST_DIR_NAME: str = 'last' -CHECKPOINT_BEST_DIR_NAME: str = 'best' -CHECKPOINT_MERGED_FILE_NAME: str = 'merged.ckpt' -CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}.ckpt' -CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}.ckpt' - - @dataclass class TrainStatus: best_loss = float('inf') @@ -103,6 +92,7 @@ def __init__(self, self.max_train_steps = None self.loggers = [] self.hook = None + self.checkpointer = None # RNG states pending resume; reset to None after resuming self.rng_states_from_resume: dict[str, torch.Tensor] | None = None @@ -237,6 +227,7 @@ def reducer_pre_hook(reducer, grad): # Currently we never pass `last_epoch` to its constructor self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) self.loggers = self.train_args.create_loggers() + self.checkpointer = self.train_args.create_checkpointer() supported_hook_components = [ self.model, @@ -265,7 +256,7 @@ def reducer_pre_hook(reducer, grad): @classmethod def _merge_checkpoint(cls, checkpoint_files: List[str]): - state_dicts = [torch.load(f, map_location='cpu', weights_only=False) for f in checkpoint_files] + state_dicts = [Checkpointer.load(f) for f in checkpoint_files] for i in range(1, len(state_dicts)): if state_dicts[i]['train_args'] != state_dicts[0]['train_args']: raise ValueError(f"train_args in {checkpoint_files[i]} is different from {checkpoint_files[0]}") @@ -380,7 +371,7 @@ def _broadcast_values(sdict, keys): @classmethod def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): merged_state_dict = cls._merge_checkpoint(checkpoint_files) - torch.save(merged_state_dict, output_file) + Checkpointer.save(merged_state_dict, output_file) def _log_finalize(self): for logger in self.loggers: @@ -408,13 +399,13 @@ def _load_checkpoint(self): load_from_merged = True trimmed_broadcast_required = self.train_args.checkpoint.resume_from.save_memory if not self.train_args.checkpoint.resume_from.save_memory or self.local_rank == 0: - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + state_dict = self.checkpointer.load(resume_from) if convert_fn := self.train_args.checkpoint.resolved_convert_fn: state_dict = convert_fn(state_dict) else: state_dict = None else: - ckpt_files = list(resume_from.glob('*.ckpt')) + ckpt_files = self.checkpointer.list_checkpoints(resume_from) rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} if set(rank_ckpt_files.keys()) != set(range(len(rank_ckpt_files))): raise ValueError(f"Checkpoint files in {resume_from} are not complete: {rank_ckpt_files.keys()}") @@ -440,8 +431,7 @@ def _load_checkpoint(self): ) logger.info(f"Broadcasted merged checkpoint to all ranks.") else: - resume_from = resume_from / f'{self.rank}.ckpt' - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + state_dict = self.checkpointer.load_for_rank(resume_from, self.rank) if state_dict['train_args']['compute_config'] != asdict(self.train_args.compute_config): logger.warning( f"compute_config is changed, and loading checkpoint may fail. " @@ -625,26 +615,29 @@ def _save_checkpoint(self, loss): self.hook.on_save_checkpoint(self, state_dict) - ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( + ckpt_file = save_dir / self.checkpointer.get_checkpoint_file_path( epoch=current_epoch, step=self.train_status.finished_train_steps, rank=self.rank, ) logger.info(f"Saving checkpoint to {str(ckpt_file.parent)}") ckpt_file.parent.mkdir(parents=True, exist_ok=True) - torch.save(state_dict, ckpt_file) + self.checkpointer.save(state_dict, ckpt_file) # save last if checkpoint_config.save_last: logger.info(f"Saving checkpoint as the last checkpoint.") - last_file = save_dir / CHECKPOINT_LAST_FILE_FORMAT.format( + + # remove the old symlink or file + self.checkpointer.remove_for_rank( + save_dir / self.checkpointer.get_last_dir_name(), + self.rank + ) + last_file = save_dir / self.checkpointer.get_last_checkpoint_file_path( rank=self.rank ) last_file.parent.mkdir(parents=True, exist_ok=True) if checkpoint_config.symlink_best_and_last: - # remove the old symlink or file - if last_file.is_symlink() or last_file.exists(): - last_file.unlink() # symblink as relative path last_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) # last_file.symlink_to(ckpt_file) @@ -655,16 +648,18 @@ def _save_checkpoint(self, loss): if checkpoint_config.save_best and loss <= self.train_status.best_loss: logger.info(f"Best loss updated: {self.train_status.best_loss:.3f} -> {loss:.3f}") logger.info(f"Saving checkpoint as the best checkpoint.") - best_file = save_dir / CHECKPOINT_BEST_FILE_FORMAT.format( - epoch=current_epoch, - step=self.train_status.finished_train_steps, + + # remove the old symlink or file + self.checkpointer.remove_for_rank( + save_dir / self.checkpointer.get_best_dir_name(), + self.rank + ) + best_file = save_dir / self.checkpointer.get_best_checkpoint_file_path( rank=self.rank, ) best_file.parent.mkdir(parents=True, exist_ok=True) if checkpoint_config.symlink_best_and_last: # symblink as relative path - if best_file.is_symlink() or best_file.exists(): - best_file.unlink() best_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) # best_file.symlink_to(ckpt_file) else: @@ -688,7 +683,10 @@ def _expire_checkpoints(self): save_dir = Path(self.train_args.checkpoint.save_dir) checkpoints = [ p.name for p in save_dir.glob('*') - if p.is_dir() and p.name not in [CHECKPOINT_BEST_DIR_NAME, CHECKPOINT_LAST_DIR_NAME] + if p.is_dir() and p.name not in [ + self.checkpointer.get_best_dir_name(), + self.checkpointer.get_last_dir_name() + ] ] if len(checkpoints) <= self.train_args.checkpoint.keep_last_n_checkpoints: return @@ -698,12 +696,12 @@ def _expire_checkpoints(self): checkpoint_info.sort() expire_list = [c[1] for c in checkpoint_info[:-self.train_args.checkpoint.keep_last_n_checkpoints]] - best_ckpt = save_dir / CHECKPOINT_BEST_DIR_NAME - last_ckpt = save_dir / CHECKPOINT_LAST_DIR_NAME + best_ckpt = save_dir / self.checkpointer.get_best_dir_name() + last_ckpt = save_dir / self.checkpointer.get_last_dir_name() for ckpt_dir in [best_ckpt, last_ckpt]: if not ckpt_dir.exists(): continue - for p in ckpt_dir.glob('*.ckpt'): + for p in self.checkpointer.list_checkpoints(ckpt_dir): if p.is_symlink(): ckpt_name = p.resolve().parent.name if ckpt_name in expire_list: diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index d4bd6974..29e42044 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -32,6 +32,7 @@ ) from .loggers.logger_base import LoggerBase from .train_hook import TrainHook +from .serialization import Checkpointer if TYPE_CHECKING: from .trainer import Trainer @@ -423,6 +424,11 @@ class CheckpointConfig: save_dir: str = './checkpoints' no_save: bool = False + # `"pt"`: PyTorch native format + # `"safetensors"`: Safetensors format + # You can also register new formats via `nnscaler.cli.serialization.register_format` + format: str = 'pt' + # `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is # a folder with as many files as the world size. # `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is @@ -485,6 +491,9 @@ def __post_init__(self): if not self.save_dir: raise ValueError("save_dir is required") + if self.format not in Checkpointer.SUFFIX_MAP: + raise ValueError(f"Invalid format {self.format}") + if self.every_n_epochs is not None and self.every_n_train_steps is not None: raise ValueError("Cannot specify both every_n_epochs and every_n_train_steps") if self.every_n_epochs is None and self.every_n_train_steps is None: @@ -972,3 +981,6 @@ def create_hook(self) -> TrainHook: return ArgsTrainHook(hook_config) else: raise ValueError(f"Invalid hook_config {hook_config}") + + def create_checkpointer(self) -> Checkpointer: + return Checkpointer(self.checkpoint.format) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 58512adf..68a811e7 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -618,7 +618,8 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] if 'step' in bucket_states[0]: - opt_states['step'] = bucket_states[0]['step'] + # make sure all steps are different tensors (with same value) + opt_states['step'] = bucket_states[0]['step'].clone() return opt_states def _merge_opt_zero(worker_idx, param_idx): diff --git a/nnscaler/runtime/serialization.py b/nnscaler/runtime/serialization.py new file mode 100644 index 00000000..0adc6b78 --- /dev/null +++ b/nnscaler/runtime/serialization.py @@ -0,0 +1,249 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, TypedDict +import pickle +import base64 +import copy + +import torch +from safetensors.torch import save_file +from safetensors import safe_open + +from nnscaler.utils import transform_recursively, check_recursively +from nnscaler.version import __version__ + + +class MetadataDict(TypedDict): + obj: str + nnscaler: str + + +class _Index: + def __init__(self, index: int): + self.index = index + + def __repr__(self): + return f"_Index({self.index})" + + +def save(obj: Any, f, *, format="safetensors") -> None: + """ + Saves an object containing tensors into a safetensors file. + Args: + obj (`Any`): + The object you want to save. It can be a nested structure containing + tensors, lists, tuples, and dictionaries. + f: + The file-like object or filename where to save the safetensors file. + format (`str`, *optional*, defaults to `"safetensors"`): + The format to save the object. Currently `"safetensors"` and `"pt"` is supported. + """ + if format == 'pt': + torch.save(obj, f) + return + + if format != 'safetensors': + raise ValueError(f"Unsupported format: {format}") + + index = 0 + + # all tensors to be saved + tensors = {} + # detect shared tensors + # because safetensors does not support shared tensors, we need to + # save shared tensors only once and replace other occurrences + # TODO: Currently we only detect shared tensors that are exactly the same + # (i.e., share the same data_ptr and shape and stride). + # We may improve it in the future if needed. + # key: (tensor.data_ptr(), tensor.shape, tensor.stride()), value: _Index + tensor_ids: dict[tuple[int, tuple[int, ...], tuple[int, ...]], _Index] = {} + def transform_fn(o: Any) -> Any: + nonlocal index + if isinstance(o, torch.Tensor): + key = (o.data_ptr(), o.shape, o.stride()) + if key in tensor_ids: + idx = tensor_ids[key] + else: + idx = _Index(index) + tensor_ids[key] = idx + tensors[f'{index}'] = o + index += 1 + return idx + return o + metadata = transform_recursively(obj, transform_fn, target_types=(torch.Tensor,)) + save_file(tensors, f, metadata={ + 'obj': base64.b64encode(pickle.dumps(metadata)).decode('utf-8'), + 'nnscaler': __version__ + }) + + +class _LazyContainer: + """ + Mock class for dictionary, list, and tuple that loads tensors lazily from safetensors file. + """ + def __init__(self, data: dict | tuple | list, tensors: safe_open): + self.data = data + self.tensors = tensors + + def __getitem__(self, key): + return self._v(self.data[key]) + + def __setitem__(self, key, value): + raise NotImplementedError("Lazy containers are read-only.") + + def __delitem__(self, key): + raise NotImplementedError("Lazy containers are read-only.") + + def pop(self, key, default=None): + raise NotImplementedError("Lazy containers are read-only.") + + def __len__(self): + return len(self.data) + + def __contains__(self, item): + return self.data.__contains__(item) + + def get(self, key, default=None): + return self._v(self.data.get(key, default)) + + def keys(self): + return self.data.keys() + + def values(self): + return map(self._v, self.data.values()) + + def items(self): + return ((k, self._v(v)) for k, v in self.data.items()) + + def _v(self, v): + return _wrap_value(v, self.tensors) + + def load_all(self): + def _load(v): + if isinstance(v, _Index): + return self.tensors.get_tensor(f'{v.index}') + return v + return transform_recursively(self.data, _load, target_types=(_Index,)) + + def __copy__(self): + return copy.copy(self.load_all()) + + def __deepcopy__(self, memo): + return copy.deepcopy(self.load_all(), memo) + + def __iter__(self): + return iter(self.data) + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.data)})" + + +class _LazyList(_LazyContainer, list): + pass + + +class _LazyDict(_LazyContainer, dict): + pass + + +class _LazyTuple(_LazyContainer, tuple): + # tuple is immutable, so we need to override __new__ + def __new__(cls, *args, **kwargs): + return tuple.__new__(cls, ()) + + +def _wrap_value(v: Any, tensors: safe_open) -> Any: + if isinstance(v, _Index): + return tensors.get_tensor(f'{v.index}') + if not check_recursively(v, lambda k: isinstance(k, _Index)): + return v + if isinstance(v, dict): + return _LazyDict(v, tensors) + if isinstance(v, list): + return _LazyList(v, tensors) + if isinstance(v, tuple): + return _LazyTuple(v, tensors) + # should not reach here + return v + + +class LazyLoader: + def __init__(self, filename, device="cpu"): + self.filename = filename + self.device = device + self.tensor_loader = safe_open(self.filename, framework="pt", device=self.device) + self.tensors = None + self.data = None + + def __enter__(self): + self.tensors = self.tensor_loader.__enter__() + metadata: MetadataDict = self.tensors.metadata() + metadata_obj_b64 = metadata['obj'] + self.data = pickle.loads(base64.b64decode(metadata_obj_b64.encode('utf-8'))) + return self + + def __exit__(self, _exc_type, _exc_value, _traceback): + self.tensor_loader.__exit__(_exc_type, _exc_value, _traceback) + + def get_lazy_data(self) -> _LazyContainer | Any: + if self.tensors is None: + raise RuntimeError("LazyLoader context is not entered.") + return _wrap_value(self.data, self.tensors) + + +def load(f, *, device="cpu", format="safetensors", lazy=False) -> LazyLoader | Any: + """ + Loads an object containing tensors from a safetensors file lazily. + Args: + f: The file-like object or filename from which to load the safetensors file. + device (`str`, *optional*, defaults to `"cpu"`): + The device where the tensors will be loaded. + lazy (`bool`, *optional*, defaults to `False`): + If set to `False`, loads all tensors into memory eagerly. + Returns: + (`LazyLoader` | `Any`): + The lazy loader object that can be used to access the data. + If `lazy` is set to `False`, returns the loaded object with all tensors + loaded into memory. + """ + if format == 'pt': + return torch.load(f, map_location=device, weights_only=False) + if format != 'safetensors': + raise ValueError(f"Unsupported format: {format}") + + if not lazy: + with LazyLoader(f, device=device) as loader: + data = loader.get_lazy_data() + assert isinstance(data, _LazyContainer) + return data.load_all() + return LazyLoader(f, device=device) + + +def convert(src: str, dst: str, *, src_format="safetensors", dst_format="pt", device="cpu") -> None: + """ + Converts a serialized file from one format to another. + Args: + src (`str`): + The source filename. + dst (`str`): + The destination filename. + src_format (`str`, *optional*, defaults to `"safetensors"`): + The format of the source file. Currently `"safetensors"` and `"pt"` is supported. + dst_format (`str`, *optional*, defaults to `"pt"`): + The format of the destination file. Currently `"safetensors"` and `"pt"` is supported. + device (`str`, *optional*, defaults to `"cpu"`): + The device where the tensors will be loaded. + + Returns: + (`None`): + This function does not return anything. + """ + if src_format == dst_format: + raise ValueError("Source and destination formats are the same.") + + save( + load(src, device=device, format=src_format, lazy=False), + dst, + format=dst_format + ) diff --git a/nnscaler/utils.py b/nnscaler/utils.py index b20ab317..053a0b60 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -235,6 +235,123 @@ def wrapped_fn(*args, **kwargs): TRANSFORM_SUPPORTED_COLLECTION_TYPES = (tuple, list, dict, set, slice, _DICT_ITEMS_TYPE, _DICT_KEYS_TYPE, _DICT_VALUES_TYPE) +def _transform_recursively(data: Any, fn: Callable[[Any], Any], + target_types: Union[Callable[[Any], bool], Type, Tuple[Type, ...]], + collection_types = (tuple, list, dict), skip_dict_keys = True +) -> tuple[bool, Any]: + if collection_types is None: + collection_types = TRANSFORM_SUPPORTED_COLLECTION_TYPES + if isinstance(data, collection_types): + if isinstance(data, tuple): + result = tuple(_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data) + changed = any(c for c, _ in result) + if changed: + return True, tuple(v for _, v in result) + else: + return False, data + if isinstance(data, list): + result = [_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data] + changed = any(c for c, _ in result) + if changed: + return True, [v for _, v in result] + else: + return False, data + if isinstance(data, set): + result = [_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data] + changed = any(c for c, _ in result) + if changed: + return True, {v for _, v in result} + else: + return False, data + if isinstance(data, dict): + if skip_dict_keys: + keys = {k: (False, k) for k in data.keys()} + else: + keys = { + k: _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k in data.keys() + } + changed = any(c for c, _ in keys.values()) + result = { + k: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) + for k, v in data.items() + } + changed = changed or any(c for c, _ in result.values()) + if changed: + return True, { + keys[k][1]: v for k, (_, v) in result.items() + } + else: + return False, data + if isinstance(data, _DICT_ITEMS_TYPE): + if skip_dict_keys: + keys = {k: (False, k) for k, _ in data} + else: + keys = { + k: _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k, _ in data + } + + changed = any(c for c, _ in keys.values()) + result = { + k: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) + for k, v in data + } + changed = changed or any(c for c, _ in result.values()) + if changed: + return True, { + keys[k][1]: v for k, (_, v) in result.items() + }.items() + else: + return False, data + if isinstance(data, _DICT_KEYS_TYPE): + result = [ + _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k in data + ] + changed = any(c for c, _ in result) + if changed: + return True, { + v: i for i, (_, v) in enumerate(result) + }.keys() + else: + return False, data + if isinstance(data, _DICT_VALUES_TYPE): + result = { + i: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) + for i, v in enumerate(data) + } + changed = any(c for c, _ in result.values()) + if changed: + return True, { + i: v for i, (_, v) in result.items() + }.values() + else: + return False, data + if isinstance(data, slice): + result = ( + _transform_recursively(data.start, fn, target_types, collection_types, skip_dict_keys), + _transform_recursively(data.stop, fn, target_types, collection_types, skip_dict_keys), + _transform_recursively(data.step, fn, target_types, collection_types, skip_dict_keys), + ) + if any(c for c, _ in result): + return True, slice( + result[0][1], + result[1][1], + result[2][1] + ) + else: + return False, data + raise ValueError(f"Unsupported collection type: {type(data)}") + elif isinstance(target_types, (tuple, list)) or inspect.isclass(target_types): + if isinstance(data, target_types): + return True, fn(data) + elif callable(target_types): # not a class, but callable. treat as a check function. + if target_types(data): + return True, fn(data) + return False, data + + def transform_recursively(data: Any, fn: Callable[[Any], Any], target_types: Union[Callable[[Any], bool], Type, Tuple[Type, ...]], collection_types = (tuple, list, dict), skip_dict_keys = True @@ -251,51 +368,61 @@ def transform_recursively(data: Any, fn: Callable[[Any], Any], skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. """ + _, result = _transform_recursively(data, fn, target_types, collection_types, skip_dict_keys) + return result + + +def check_recursively(data, fn: Callable[[Any], bool], + collection_types = (tuple, list, dict), + skip_dict_keys = True +) -> bool: + """ + Check the data with the given function, will recursively apply the function to the nested data. + Args: + data: the data to be checked. + fn: the function to check. + collection_types: the collection types to apply the function to the nested data. + skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). + _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. + + """ if collection_types is None: collection_types = TRANSFORM_SUPPORTED_COLLECTION_TYPES + if isinstance(data, collection_types): - if isinstance(data, tuple): - return tuple(transform_recursively(t, fn, target_types, collection_types) for t in data) - if isinstance(data, list): - return list(transform_recursively(t, fn, target_types, collection_types) for t in data) - if isinstance(data, set): - return set(transform_recursively(t, fn, target_types, collection_types) for t in data) + if isinstance(data, (list, tuple, set, _DICT_KEYS_TYPE, _DICT_VALUES_TYPE)): + return any(check_recursively(t, fn, collection_types) for t in data) if isinstance(data, dict): - return { - k if skip_dict_keys else transform_recursively(k, fn, target_types, collection_types): - transform_recursively(v, fn, target_types, collection_types) - for k, v in data.items() - } + if skip_dict_keys: + return any( + check_recursively(v, fn, collection_types) + for v in data.values() + ) + else: + return any( + check_recursively(k, fn, collection_types) or check_recursively(v, fn, collection_types) + for k, v in data.items() + ) if isinstance(data, _DICT_ITEMS_TYPE): - return { - k if skip_dict_keys else transform_recursively(k, fn, target_types, collection_types): - transform_recursively(v, fn, target_types, collection_types) - for k, v in data - }.items() - if isinstance(data, _DICT_KEYS_TYPE): - return { - transform_recursively(k, fn, target_types, collection_types): i - for i, k in enumerate(data) - }.keys() - if isinstance(data, _DICT_VALUES_TYPE): - return { - i: transform_recursively(v, fn, target_types, collection_types) - for i, v in enumerate(data) - }.values() + if skip_dict_keys: + return any( + check_recursively(v, fn, collection_types) + for _, v in data + ) + else: + return any( + check_recursively(k, fn, collection_types) or check_recursively(v, fn, collection_types) + for k, v in data + ) if isinstance(data, slice): - return slice( - transform_recursively(data.start, fn, target_types, collection_types), - transform_recursively(data.stop, fn, target_types, collection_types), - transform_recursively(data.step, fn, target_types, collection_types) - ) + return any(( + check_recursively(data.start, fn, collection_types), + check_recursively(data.stop, fn, collection_types), + check_recursively(data.step, fn, collection_types) + )) raise ValueError(f"Unsupported collection type: {type(data)}") - elif isinstance(target_types, (tuple, list)) or inspect.isclass(target_types): - if isinstance(data, target_types): - return fn(data) - elif callable(target_types): # not a class, but callable. treat as a check function. - if target_types(data): - return fn(data) - return data + + return fn(data) def is_running_distributed() -> bool: diff --git a/requirements.txt b/requirements.txt index 6ae52a83..41efcd26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ pybind11<3.0.0 pyyaml torch>=2.0,<=2.6 tqdm +safetensors diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 48107c2a..b57faded 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -4,12 +4,14 @@ from pathlib import Path import re import shutil +from typing import Any import torch import pytest import torch.distributed from nnscaler import merge_state_dicts +from nnscaler.cli.serialization import Checkpointer from nnscaler.cli.trainer import Trainer, logger from nnscaler.cli.trainer_args import AggregatedOutputs, TrainerArgs from tests.parallel_module.common import assert_equal, assert_close @@ -119,6 +121,11 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): if bf16 == 'Mixed' \ else 'torch.optim.Adam' use_zero = save_type == 'sharded' + format = 'safetensors' if parallel_type % 2 else 'pt' + rev_format = 'pt' if format == 'safetensors' else 'safetensors' + + def list_ckpt_files(dir): + return set(dir.glob('**/*.ckpt')) | set(dir.glob('**/*.safetensors')) if parallel_type == 0: additional_args = [] @@ -192,10 +199,11 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', format, *additional_args, ]) trainer.run() - ckpt_files = set(ckpt_savedir.glob('**/*.ckpt')) + ckpt_files = list_ckpt_files(ckpt_savedir) assert len(ckpt_files)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -216,10 +224,11 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', format, *additional_args, ]) trainer.run() - ckpt0_files0 = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} + ckpt0_files0 = {f: f.stat().st_mtime_ns for f in list_ckpt_files(ckpt0_savedir)} assert len(ckpt0_files0)/4 == min(30, trainer.total_train_steps_per_epoch * 2) + 2 # 2 for best/last # resume from last without update max_epochs @@ -237,18 +246,20 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', rev_format, *additional_args, ]) trainer.run() - ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} + ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in list_ckpt_files(ckpt0_savedir)} # nothing should be updated in this case. assert ckpt0_files0 == ckpt0_files0_x # create merged checkpoint ckpt1_savedir = save_dir / 'ckpt1' ckpt1_savedir.mkdir(parents=True, exist_ok=True) + merged_file_name = f'merged{Checkpointer.SUFFIX_MAP[format]}' if trainer.rank == 0: - Trainer.merge_checkpoint(list((ckpt0_savedir / 'last').glob('*.ckpt')), ckpt1_savedir / 'merged.pt') + Trainer.merge_checkpoint(Checkpointer.list_checkpoints(ckpt0_savedir / 'last'), ckpt1_savedir / merged_file_name) torch.distributed.barrier() # continue with the last two epochs (resume for sharded/deduped checkpoint) @@ -265,6 +276,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', + '--checkpoint.format', rev_format, '--checkpoint.keep_last_n_checkpoints', '30', *additional_args, ]) @@ -277,7 +289,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): for f, s in left_files.items(): # make sure the old checkpoints are not overwritten assert ckpt0_files0[f] == s - ckpt0_files1 = set(ckpt0_savedir.glob('**/*.ckpt')) + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) assert len(ckpt0_files1)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -294,7 +306,8 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--compute_config.use_zero', str(use_zero), '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt1_savedir), - '--checkpoint.resume_from', str(ckpt1_savedir / 'merged.pt'), + '--checkpoint.format', rev_format, + '--checkpoint.resume_from', str(ckpt1_savedir / merged_file_name), '--checkpoint.keep_last_n_checkpoints', '30', *additional_args, ]) @@ -307,7 +320,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): for f, s in left_files.items(): # make sure the old checkpoints are not overwritten assert ckpt0_files0[f] == s - ckpt0_files1 = set(ckpt0_savedir.glob('**/*.ckpt')) + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) assert len(ckpt0_files1)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -315,9 +328,9 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): if torch.distributed.get_rank() == 0: assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} for i in range(4): - x = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) - y = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) - z = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + x = Checkpointer.load_for_rank(ckpt_savedir / 'last', i) + y = Checkpointer.load_for_rank(ckpt0_savedir / 'last', i) + z = Checkpointer.load_for_rank(ckpt1_savedir / 'last', i) assert_equal(x['model'], y['model']) assert_equal(x['optimizer'], y['optimizer']) assert_equal(x['lr_scheduler'], y['lr_scheduler']) @@ -325,12 +338,13 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): assert_equal(x['optimizer'], z['optimizer']) assert_equal(x['lr_scheduler'], z['lr_scheduler']) + suffix = Checkpointer.SUFFIX_MAP[format] if save_type == 'deduped': - assert (ckpt_savedir / 'last/0.ckpt').stat().st_size > (ckpt_savedir / 'last/2.ckpt').stat().st_size - assert (ckpt_savedir / 'last/1.ckpt').stat().st_size > (ckpt_savedir / 'last/3.ckpt').stat().st_size + assert (ckpt_savedir / f'last/0{suffix}').stat().st_size > (ckpt_savedir / f'last/2{suffix}').stat().st_size + assert (ckpt_savedir / f'last/1{suffix}').stat().st_size > (ckpt_savedir / f'last/3{suffix}').stat().st_size else: - assert (ckpt_savedir / 'last/0.ckpt').stat().st_size == (ckpt_savedir / 'last/2.ckpt').stat().st_size - assert (ckpt_savedir / 'last/1.ckpt').stat().st_size == (ckpt_savedir / 'last/3.ckpt').stat().st_size + assert (ckpt_savedir / f'last/0{suffix}').stat().st_size == (ckpt_savedir / f'last/2{suffix}').stat().st_size + assert (ckpt_savedir / f'last/1{suffix}').stat().st_size == (ckpt_savedir / f'last/3{suffix}').stat().st_size @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @@ -1212,3 +1226,116 @@ def check_match(code_dir: Path, should_exist: bool): ]) trainer.run() check_match(gen_savedir, should_exist=False) + + +def trainer_checkpointer_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + from nnscaler.cli import register_format + + load_triggered = False + + def save(obj: Any, f: Path) -> None: + obj['test'] = True + return torch.save(obj, f) + + def load(f: str | Path, *, device='cpu') -> Any: + x = torch.load(f, map_location=device, weights_only=False) + assert x['test'] is True + nonlocal load_triggered + load_triggered = True + return x + + register_format('test_format', '.testpt', save, load) + + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '1', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.format', 'test_format', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + + files0 = list(ckpt_savedir.glob('**/*.testpt')) + assert files0, 'No checkpoint files saved with custom format.' + + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.format', 'test_format', + '--checkpoint.resume_from', 'last', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + assert load_triggered, 'Custom load function not triggered when resuming.' + + files1 = list(ckpt_savedir.glob('**/*.testpt')) + assert len(files1) > len(files0), 'Checkpoint files not updated after resuming.' + assert all(f in files1 for f in files0), 'Some checkpoint files missing after resuming.' + assert files1, 'No checkpoint files saved with custom format.' + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_custom_checkpointer(tmp_path): + launch_torchrun(1, trainer_checkpointer_worker, tmp_path) + + +def trainer_pas_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + from nnscaler.policies import pas_dp + from nnscaler.cli import TrainerArgs + called = False + + def custom_pas(graph, cfg): + nonlocal called + called = True + return pas_dp(graph, cfg) + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--max_epochs', '1', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + args.pas_policy = custom_pas + # train 1 epcho in one time + trainer = Trainer(train_args=args) + trainer.run() + + assert called, 'Custom PAS policy not called.' + + gen_savedir = save_dir / 'gen2' + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--pas-policy', 'nnscaler.policies.pas_dp', # use full qualified name of pas policy + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_custom_pas(tmp_path): + launch_torchrun(1, trainer_pas_worker, tmp_path) diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 43e83178..99416a3b 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -128,7 +128,8 @@ def init_distributed(): def assert_equal(a: Any, b: Any): - assert type(a) == type(b) + # treat dict and OrderedDict as same for comparison + assert type(a) == type(b) or (isinstance(a, dict) and isinstance(b, dict)), f'{type(a)} != {type(b)}' if isinstance(a, torch.Tensor): assert torch.equal(a.cpu(), b.cpu()), torch.max(torch.abs(a.cpu() - b.cpu())) elif isinstance(a, dict): diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 59510180..4d9b5323 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -449,6 +449,25 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf 'model': merged_model_state_dicts, 'optimizer': merged_optimizer_state_dict }, ckpt_merged_file) + from nnscaler.runtime.serialization import convert, load + from contextlib import ExitStack + ckpt_st_file_template = 'ckpt_{rank}_{start}.safetensors' + ckpt_st_files = [ckpt_dir / ckpt_st_file_template.format(rank=i, start=end) for i in range(torch.distributed.get_world_size())] + for pt, st in zip(ckpt_files, ckpt_st_files): + convert(pt, st, src_format='pt', dst_format='safetensors') + ckpt_st_state_dict_loaders = [load(f, lazy=True) for f in ckpt_st_files] + with ExitStack() as stack: + ckpt_st_state_dicts = [] + for f in ckpt_st_state_dict_loaders: + ckpt_st_state_dicts.append(stack.enter_context(f).get_lazy_data()) + model_st_state_dicts = [ckpt['model'] for ckpt in ckpt_st_state_dicts] + optimizer_st_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_st_state_dicts] + merged_model_st_state_dicts, merged_optimizer_st_state_dict = merge_state_dicts( + model_st_state_dicts, optimizer_st_state_dicts + ) + assert_equal(merged_model_state_dicts, merged_model_st_state_dicts) + assert_equal(merged_optimizer_state_dict, merged_optimizer_st_state_dict) + torch.distributed.barrier() return results diff --git a/tests/runtime/test_serialization.py b/tests/runtime/test_serialization.py new file mode 100644 index 00000000..8b8cab92 --- /dev/null +++ b/tests/runtime/test_serialization.py @@ -0,0 +1,68 @@ +import torch +import pytest + +from nnscaler.runtime.serialization import load, save, convert + +from tests.parallel_module.common import assert_equal + + +def test_normal(tmp_path): + a = torch.randn((2, 2), device='cpu') + b = torch.randn((2, 3), device='cpu') + c = torch.randn((4, 4), device='cpu') + tensors = { + "embedding": a, + "attention": b, + "fc": a, # shared tensor + "bias": {'inner': b, 'outer': {'deep': c}} + } + save(tensors, tmp_path / "model.safetensors") + loaded = load(tmp_path / "model.safetensors", lazy=False) + assert_equal(tensors, loaded) + convert(tmp_path / "model.safetensors", tmp_path / "model.pt") + loaded_pt = torch.load(tmp_path / "model.pt") + assert_equal(tensors, loaded_pt) + + +def test_shared_params(tmp_path): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 4) + self.fc2 = torch.nn.Linear(4, 4) + # share the same weight + self.fc2.weight = self.fc1.weight + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = Model() + save(model.state_dict(), tmp_path / "model.safetensors") + loaded = load(tmp_path / "model.safetensors", lazy=False) + assert_equal(model.state_dict(), loaded) + convert(tmp_path / "model.safetensors", tmp_path / "model.pt") + loaded_pt = torch.load(tmp_path / "model.pt") + assert_equal(model.state_dict(), loaded_pt) + + +def test_bad_shared_params(tmp_path): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 4) + self.fc2 = torch.nn.Linear(4, 4) + # share the same weight + # This case is not common, + # so we don't support it currently. + self.fc2.weight.data = self.fc1.weight.reshape(-1) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = Model() + with pytest.raises(RuntimeError): + save(model.state_dict(), tmp_path / "model.safetensors") diff --git a/tests/test_utils.py b/tests/test_utils.py index a92c36c1..864cab67 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from collections import OrderedDict from dataclasses import dataclass import pytest import torch -from nnscaler.utils import select_many, classproperty, fields, set_member_by_name, unchecked_fields +from nnscaler.utils import ( + select_many, classproperty, fields, set_member_by_name, unchecked_fields, + transform_recursively, +) def test_select_many(): @@ -92,3 +96,85 @@ def test_set_member_by_name(): set_member_by_name(model, 'x.y.z', 45) assert model.x.y == child_module assert model.x.y.z == 45 + + +def test_transform_recursively(): + data = { + 'a': torch.tensor([1]), + 'b': [torch.tensor(4), {'c': torch.tensor([5])}], + 'd': (7, torch.tensor(8)), + 'e': {1: 9, 2: torch.tensor(10)}.keys(), + 'f': {1: 9, 2: torch.tensor(11)}.items(), + 'g': {1: 9, 2: torch.tensor(12)}.values(), + 'h': {1: 9, 2: torch.tensor(13)}, + 'i': slice(0, 10, None), + 'j': torch.Size([11, 12]), + 'k': OrderedDict({1: 9, 2: 10}), + 'l': {1: 9, 2: 10}.values(), + 'm': [1, 2, 3], + 'n': slice(0, 10, torch.tensor(2)), + 'o': {torch.tensor(1): 9, torch.tensor(2): 10}, + 'p': {torch.tensor(1): 9, torch.tensor(2): 10}.items(), + 'q': {torch.tensor(1): 9, torch.tensor(2): 10}.keys() + } + + def fn(x): + if isinstance(x, torch.Tensor): + return x.item() + return x + + result1 = transform_recursively( + data, fn, + target_types=torch.Tensor, + collection_types=None, + skip_dict_keys=True, + ) + + result2 = transform_recursively( + data, fn, + target_types=torch.Tensor, + collection_types=None, + skip_dict_keys=False, + ) + target = { + 'a': 1, + 'b': [4, {'c': 5}], + 'd': (7, 8), + 'e': {1: 1, 2: 2}.keys(), + 'f': dict([(1, 9), (2, 11)]).items(), + 'g': {1: 9, 2: 12}.values(), + 'h': {1: 9, 2: 13}, + 'i': slice(0, 10, None), + 'j': torch.Size([11, 12]), + 'k': OrderedDict({1: 9, 2: 10}), + 'l': data['l'], + 'm': [1, 2, 3], + 'n': slice(0, 10, 2), + } + # dict values are not comparable. + assert list(target['g']) == list(result1.pop('g')) + assert list(target['g']) == list(result2.pop('g')) + target.pop('g') + + + skip_key_target = { + **target, + 'o': {torch.tensor(1): 9, torch.tensor(2): 10}, + 'p': {torch.tensor(1): 9, torch.tensor(2): 10}.items(), + 'q': {1: 9, 2: 10}.keys() + } + noskip_key_target = { + **target, + 'o': {1: 9, 2: 10}, + 'p': dict([(1, 9), (2, 10)]).items(), + 'q': {1: 9, 2: 10}.keys() + } + + from tests.parallel_module.common import assert_equal + + assert_equal(list(skip_key_target.pop('o')), list(result1.pop('o'))) + assert_equal(list(skip_key_target.pop('p')), list(result1.pop('p'))) + assert_equal(list(skip_key_target.pop('q')), list(result1.pop('q'))) + + assert_equal(result1, skip_key_target) + assert_equal(result2, noskip_key_target) From 0b6735fea5af7d719591a9bd83c91148b8954173 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 15 Dec 2025 06:14:50 +0000 Subject: [PATCH 1853/1892] Merged PR 2421: [Feat] Add zero3 support This PR add a naive implementation of ZeRO3. To improve its performance, we will do the following in future. 1. Async prefetch support for both forward and backward 2. Async gradient reduce-scatter inside zero-subgroup 3. Per-bucket gradient reduce-scatter (current implementation is per-parameter) --- nnscaler/codegen/module/module.py | 89 ++++++- nnscaler/flags.py | 2 +- nnscaler/graph/function/function.py | 2 +- nnscaler/ir/cten.py | 29 ++- nnscaler/parallel.py | 200 +++++++++++--- nnscaler/policies.py | 48 ++++ nnscaler/runtime/adapter/reducer.py | 247 +++++++++++++++--- nnscaler/runtime/function/function.py | 24 +- nnscaler/runtime/gnorm.py | 4 +- nnscaler/runtime/module.py | 318 ++++++++++++++++++++--- nnscaler/runtime/utils.py | 14 +- tests/cli/test_trainer.py | 208 +++++++++++++-- tests/cli/test_trainer2.py | 73 ++++++ tests/cli/trainer_args.yaml | 1 + tests/cli/trainer_args_mixed_bf16.yaml | 36 +++ tests/parallel_module/common.py | 6 +- tests/parallel_module/test_checkpoint.py | 6 +- tests/parallel_module/test_ddp.py | 32 ++- tests/parallel_module/test_gencode.py | 200 +++++++++++++- tests/parallel_module/test_init.py | 8 +- tests/runtime/test_gnorm.py | 2 +- tests/runtime/test_hybrid_optimizer.py | 6 +- tests/test_policies.py | 118 +++++++++ tests/utils.py | 9 + 24 files changed, 1506 insertions(+), 176 deletions(-) create mode 100644 tests/cli/test_trainer2.py create mode 100644 tests/cli/trainer_args_mixed_bf16.yaml diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 36ad0a83..94bc2f32 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -10,7 +10,7 @@ import inspect import pickle -from nnscaler.ir.cten import IRCell +from nnscaler.ir.cten import IRCell, IRTensor from nnscaler.ir.tensor import IRFullTensor, IRSubTensor from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from nnscaler.ir.adapter import IRWeightReducer, IRAdapter @@ -24,7 +24,7 @@ from nnscaler.execplan.execplan import ExeReuseCell from nnscaler.codegen.syntax.symtable import SymbolTable -from nnscaler.codegen.syntax.blocks import ClassBlock, FunctionBlock, Block +from nnscaler.codegen.syntax.blocks import ClassBlock, ForBlock, FunctionBlock, Block from nnscaler.codegen.emit import FuncEmission from nnscaler.codegen.module.autograd import AutogradAdapterCodeGen @@ -127,6 +127,7 @@ def __init__( 'from pathlib import Path', 'import torch', 'import torch.utils.checkpoint as ckpt', 'import nnscaler', 'import nnscaler.flags', + 'import nnscaler.runtime.function', 'import _operator', 'from numpy import inf', 'import builtins', '', f'runtime_version = {runtime_version!r}', '', '' ] @@ -551,7 +552,10 @@ def forward(self, x, y=None, z=None): if isinstance(node, IRSegment): segment_idxs.append(idx) - with FunctionBlock(func_name=name, args=input_args) as fb: + saved_tensors_hooks_needed = isinstance(node, IRSegment) and CompileFlag.use_zero > 1 + func_name = name + '_impl' if saved_tensors_hooks_needed else name + + with FunctionBlock(func_name=func_name, args=input_args) as fb: fb.insert_body(forward_code) # generate output outputs = [self.tensor_name(t) for t in node.outputs()] @@ -564,6 +568,16 @@ def forward(self, x, y=None, z=None): cb.insert_body('@torch.jit.script_method') cb.insert_body(fb.code) + if saved_tensors_hooks_needed: + with FunctionBlock(func_name=name, args=input_args) as fb: + # call segment under save_params_hooks context + save_context_code = f'with self.save_params_hooks():' + with Block(save_context_code) as cblock: + cblock.insert_body(f'return self.{func_name}({", ".join(node_args[idx])})') + fb.insert_body(cblock.code) + cb.insert_body('') + cb.insert_body(fb.code) + if as_parallel_module: if not segment_idxs: raise RuntimeError("The graph has no segment, forward code cannot be generated.") @@ -650,8 +664,11 @@ def _get_resolved_arg(arg_name, default_value): outputs = self.return_name(node.outputs(), skip_attr=True) call_code = f'{outputs} = self.{self.node_name(node)}({", ".join(inputs)})' # be sure the user doesn't specify unused args. - for unused_arg in unused_args: - fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') + # but sometimes this can cause issues + # (for example, the value is used in an `if` condition in the original forward function), + # so we disable it for now. + # for unused_arg in unused_args: + # fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') fb.insert_body(call_code) return_code = f'return {self.return_name_complex(self.execplan.graph.outputs())}' fb.insert_body(return_code) @@ -667,6 +684,11 @@ def emit_comm_groups(self): - `model_init_statements` """ sign = 'self.init_group(ranks={ranks})' + # create single rank communication group + self.model_init_statements.append('# single rank communication groups') + with ForBlock(var='rank', iters=f'range({self.runtime_ndevs})') as fb: + fb.insert_body(sign.format(ranks='[rank]')) + self.model_init_statements.extend(fb.code) # create communication group self.model_init_statements.append('# communication groups') for ranks in self.comm_groups: @@ -915,12 +937,57 @@ def emit_context_manager(node: IRCell): code = "with " + ", ".join(ctx_managers) + ":" return code - def emit_node(node): + def emit_node(node, node_idx): node_code = [] # execute if isinstance(node, IRFwOperation): + param_inputs = [ + self.tensor_name(t, prefix_attr='self.') for t in node.iobjs() + if isinstance(t, IRTensor) and t.is_param() + ] code = self.emit_fnode(node, runtime_devid=runtime_devid, plan_ndevs=len(self.devices), runtime_ndevs=self.runtime_ndevs, prefix_attr='self.') - node_code += code + + if not param_inputs or CompileFlag.use_zero <= 1: + node_code += code + else: + activation_inputs = [ + self.tensor_name(t) for t in node.iobjs() + if isinstance(t, IRTensor) and not t.is_attr() and t.requires_grad + ] + activation_outputs = [ + self.tensor_name(t) for t in node.oobjs() + if isinstance(t, IRTensor) and t.requires_grad + ] + + # insert param prefetch before each fnode for zero3 + for t in param_inputs: + node_code.append(f'self.prefetch_param({t})') + # The backward hook here is not reliable, + # 1. there can be no activation input requiring grad, + # 2. some inputs may not be used. + # so, to maximize the chance of triggering backward hook + # let's hook to every input requiring grad + # We also add evict logic in AccumulateGrad hook in bucket implementation, + # which can make sure params are evicted after backward use. + for q in activation_inputs: + node_code.append(f'{q} = self.backward_postevict_param({q}, {t}, {node_idx})') + + node_code += code + + # insert zero param release after each fnode + for t in param_inputs: + node_code.append(f'self.postevict_param({t})') + + # insert backward hook for activation outputs to fetch params in backward + for t in activation_outputs: + # we don't know which activation output will be used in backward + # (DCE may not work 100% correctly), + # so we add hook to all activation outputs for all input params + for p in param_inputs: + node_code.append( + f'{t} = self.backward_prefetch_param({t}, {p}, {node_idx})' + ) + elif isinstance(node, IRAdapter): # for adapters inside an IRSegment, we don't apply async communication to it # as it is mostly in critical path. @@ -946,15 +1013,15 @@ def insert_codes_under_ctx(ctx_code, codes): node_codes = [] current_context_manager_code = "" current_codes = [] - for node in nodes: + for node_idx, node in enumerate(nodes): if has_op_context_info(node): new_context_manager_code = emit_context_manager(node) if current_context_manager_code != new_context_manager_code: node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) - current_codes = emit_node(node) + current_codes = emit_node(node, node_idx) current_context_manager_code = new_context_manager_code else: - current_codes.extend(emit_node(node)) + current_codes.extend(emit_node(node, node_idx)) else: # Node without op context infortmation means it is inserted by nnscaler, not convert from original fx graph, # for example, multiref node and adapter node, currently for nodes inserted by nnscaler we have the following assumption: @@ -1008,7 +1075,7 @@ def insert_codes_under_ctx(ctx_code, codes): # # TODO: all inserted nodes should have its op context field. node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) - node_codes += emit_node(node) + node_codes += emit_node(node, node_idx) current_codes = [] node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) diff --git a/nnscaler/flags.py b/nnscaler/flags.py index 77333987..af903b91 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -47,7 +47,7 @@ class CompileFlag: # use zero optimization on optimizer status. # to cooperate with zero, user needs to call `model.parameters_for_optimizer()` # to get parameters for optimizer, and `model.gather_params()` after `optimizer.step()` - use_zero = _to_bool('USE_ZERO') + use_zero = _to_int('USE_ZERO') # use async communication to overlap gradient synchronization and backward computation async_reducer = _to_bool('ASYNC_REDUCER') # use async reducer # maximal reducer weight bytes for one allreduce (only effective for async): diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 54bd3fef..7c686d76 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -3516,7 +3516,7 @@ def DictValues(o: Union[Dict, IRObject], signature=None): def DictItems(o: Union[Dict, IRObject], signature=None): - signature = 'nnscaler.runtime.function.dict_values' + signature = 'nnscaler.runtime.function.dict_items' if not isinstance(o, dict) and not (isinstance(o, IRObject) and isinstance(o.value, dict)): raise ValueError(f'the input should be a dict or an IRObject with dict value, but get {o}') diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index e2c2c063..42a05eaf 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -20,7 +20,7 @@ from dataclasses import dataclass, field from functools import lru_cache -from typing import ClassVar, Iterable, List, Set, Tuple, Union, Optional, Any, Dict, Callable +from typing import ClassVar, Iterable, List, Set, Tuple, Type, Union, Optional, Any, Dict, Callable from collections import OrderedDict import copy import torch @@ -444,6 +444,31 @@ def fqn(self) -> str: return '' return list(self._module_stack.keys())[-1] + def get_module_fqn( + self, module_class: Type[torch.nn.Module], + *, + include_subclass: bool = False + ) -> str: + """ + Get the first fully qualified module name for the given module class + in the module stack. If not found, return ''. + + Args: + module_class (Type[torch.nn.Module]): the module class to find + include_subclass (bool): whether to include subclass of the module_class + + Returns: + str: the fully qualified module name + """ + if not self._module_stack: + return '' + for fqn, mod_cls in self._module_stack.items(): + if mod_cls == module_class or ( + include_subclass and issubclass(mod_cls, module_class) + ): + return fqn + return '' + @property def call_expr(self) -> Optional[str]: return self._call_expr @@ -1223,7 +1248,7 @@ def dtype(self) -> Optional[torch.dtype]: def is_param(self) -> bool: """! - Check if the tensor is parameter + Check if the tensor is parameter (with requires_grad = True). @return is_param boolean: True if is parameter. """ diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 02459418..2eaf8d04 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -41,10 +41,10 @@ from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.unique import IDGenerator -from nnscaler.runtime.adapter.reducer import Reducer +from nnscaler.runtime.adapter.reducer import Bucket, Reducer from nnscaler.runtime.device import DeviceGroup from nnscaler.runtime.gnorm import calcuate_gnorm, clip_grads -from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState, dedup_attrs +from nnscaler.runtime.module import AttrMeta, Zero3AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState, dedup_attrs from nnscaler.flags import CompileFlag, RuntimeFlag import nnscaler.policies as policies @@ -87,11 +87,17 @@ class ComputeConfig: # how to execute the functions during trace trace_strategy: str = 'cuda_run_cpu_offload' - use_zero: bool = False + # Only support 0/1/3 for now + # If you set use_zero to 2, ZeRO stage 3 will be used internally. + # 0: no zero + # 1: ZeRO stage 1 + # 2: ZeRO stage 3 + # 3: ZeRO stage 3 + use_zero: int = 0 zero_ngroups: int = 1 # whether to use reduce scatter for zero # Please note - # 1. this only works when `use_zero` is True and `zero_ngroups` is 1. + # 1. this only works when `use_zero` is not 0 and `zero_ngroups` is 1. # 2. In some cases, it can introduce parity issue. So use it with caution. zero_use_reduce_scatter: bool = False @@ -158,16 +164,34 @@ def __post_init__(self): raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be > 0") if self.runtime_ngpus % self.plan_ngpus != 0: raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be a multiple of plan_ngpus {self.plan_ngpus}") + + if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: + raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") + + # for backward compatibility, convert bool to int + super().__setattr__('use_zero', int(self.use_zero)) + if self.use_zero not in (0, 1, 2, 3): + raise ValueError(f"use_zero {self.use_zero} must be 0, 1, 2 or 3.") + if self.use_zero == 2: + logger.warning("use_zero=2 is not supported. ZeRO stage 3 will be used instead.") + super().__setattr__('use_zero', 3) + + num_scale_units = self.runtime_ngpus // self.plan_ngpus + if self.use_zero: + if num_scale_units % self.zero_ngroups != 0: + raise ValueError(f"zero_ngroups {self.zero_ngroups} must be a divisor of runtime_ngpus/plan_ngpus {num_scale_units}.") + if num_scale_units == self.zero_ngroups: + logger.warning(f"zero_ngroups {self.zero_ngroups} equals to runtime_ngpus/plan_ngpus {num_scale_units}. Zero optimization is disabled.") + super().__setattr__('use_zero', 0) + if self.use_zero and self.zero_ngroups <= 0: raise ValueError(f"zero_ngroups {self.zero_ngroups} must be > 0") + if not self.use_zero and self.zero_ngroups != 1: logger.warning(f"use_zero is False, but zero_ngroups is {self.zero_ngroups}. Will set zero_ngroups to 1.") # have to use __setattr__ for frozen dataclass super().__setattr__('zero_ngroups', 1) - if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: - raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") - # TODO: Please note in current implementation of Bucket, # zero_use_reduce_scatter still works when zero_ngroups > 1 in sync mode # Let's hide this feature for now for consistency. @@ -224,7 +248,11 @@ def module_dedup_group_size(self) -> int: """ Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. """ - return self.plan_ngpus + if self.use_zero > 1: + # for zero3 + return self.runtime_ngpus // self.zero_ngroups + else: + return self.plan_ngpus @property def optimizer_dedup_group_size(self) -> int: @@ -388,7 +416,7 @@ def _add_gen_savedir_to_syspath(gen_savedir: str) -> Path: gen_savedir = Path(gen_savedir).resolve() gen_savedir.mkdir(parents=True, exist_ok=True) if str(gen_savedir) not in sys.path: - sys.path.append(str(gen_savedir)) + sys.path.insert(0, str(gen_savedir)) return gen_savedir @@ -1001,7 +1029,7 @@ def __init__(self, init_params=True): if isinstance(pas_policy, str): if not pas_policy in _PREDEFINED_POLICIES: raise ValueError(f"Invalid pas_policy: {pas_policy}") - pas_policy = _PREDEFINED_POLICIES[pas_policy] + pas_policy = partial(policies.fn, policy=_PREDEFINED_POLICIES[pas_policy]) else: if not callable(pas_policy): raise ValueError("pas_policy should be a callable or a predefined policy name") @@ -1376,7 +1404,9 @@ def build_optimizer( if compute_config: reducer_config = { 'async_op': compute_config.use_async_reducer, - 'zero': compute_config.use_zero, + # zero3 can't be used in non-parallel module reducer + # because we are unable to insert hooks to prefetch/postevict params + 'zero': 1 if compute_config.use_zero else 0, 'max_bucket_size_bytes': compute_config.max_bucket_size_bytes, 'zero_use_reduce_scatter': compute_config.zero_use_reduce_scatter, 'zero_ngroups': compute_config.zero_ngroups, @@ -1800,7 +1830,7 @@ def _sort_state_dicts(state_dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]] module_prefix = '.'.join(k) opt_state_dicts_for_merge = None if opt_state_dicts is None else opt_state_dicts[module_prefix] - merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks) for e in extra_states] + merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks, e.zero) for e in extra_states] if not extra_states[0].compute_config.use_zero: # all ranks should have the same use_zero merge_partial_states_zero_idx_maps = None merged_state_dict, merged_opt_state_dict = ParallelModule.merge_state_dicts( @@ -2022,8 +2052,16 @@ def _trim_optimizer_merged_state_dict( def _opt_load_merged_state_dict(module: ParallelModule, states: Dict[int, Dict[str, Any]]): + """ + Args: + module (ParallelModule): the parallel module + states (Dict[int, Dict[str, Any]]): the merged optimizer state dict for a parallel module + key: optimizer parameter index in the merged state dict + value: the state dict for each attribute, e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys + """ with torch.no_grad(): # orig_name -> state + # state: Dict[str, Any], e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys orig_param_dict: Dict[str, Dict[str, Any]] = {} cnt = 0 origin_param_names = module.origin_module_metadata.origin_param_names @@ -2032,16 +2070,80 @@ def _opt_load_merged_state_dict(module: ParallelModule, states: Dict[int, Dict[s orig_param_dict[name] = states[cnt] cnt = cnt + 1 - if module.compute_config.use_zero: + if module.compute_config.use_zero == 1: return _construct_optim_state_zero(module, orig_param_dict) + elif module.compute_config.use_zero > 1: + return _construct_optim_state_zero3(module, orig_param_dict) else: return _construct_optim_state_nonzero(module, orig_param_dict) +def _construct_optim_state_zero3( + module: ParallelModule, + orig_param_dict: Dict[str, Dict[str, Any]] +): + # state for each parameter in the parallel module + new_states = _construct_optim_state_nonzero(module, orig_param_dict) + param_state_map = {p: new_states[idx] for idx, p in enumerate(module.parameters())} + + state_dict, opt_param_idx = {}, 0 + opt_param = module.parameters_for_optimizer() + # first load the params' optimizer state for the reducers's flattened params + for reducer in module.reducers: + for bucket in reducer.buckets: + bucket: Bucket + # one bucket corresponds to one flattened param + assert len(opt_param[opt_param_idx].shape) == 1 + chunk_size = bucket._contiguous_params.shape[0] + opt_states = {} + offset = 0 + for param in bucket.params: + sliced_new_val = param_state_map[param] + param_numel = bucket.get_aligned_numel(param) + # init the optimizer state + if not opt_states: + for key in sliced_new_val.keys(): + if key == 'step': + opt_states[key] = sliced_new_val[key] + else: + opt_states[key] = torch.zeros( + [chunk_size], dtype=sliced_new_val[key].dtype, + device=sliced_new_val[key].device, requires_grad=False + ) + # copy the param's slices to the optimizer's chunk + for key in opt_states.keys(): + if key == 'step': + continue + opt_states[key][offset:offset+sliced_new_val[key].numel()] = sliced_new_val[key] + + offset += param_numel + state_dict[opt_param_idx] = opt_states + opt_param_idx += 1 + + # load the params' optimizer state that are not in reducers + reducer_pids = set() + for reducer in module.reducers: + reducer_pids.update(id(p) for p in reducer.params) + for param in module.parameters(): + if id(param) not in reducer_pids: + state_dict[opt_param_idx] = param_state_map[param] + opt_param_idx += 1 + + return state_dict + + def _construct_optim_state_zero( module: ParallelModule, orig_param_dict: Dict[str, Dict[str, Any]], ): + """ + Construct the optimizer state for a ParallelModule with ZeRO optimization. + Args: + module (ParallelModule): the parallel module + orig_param_dict (Dict[str, Dict[str, Any]]): the original parameter optimizer state + key: original parameter name + value: the state dict for each attribute, e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys + """ dist_param_map = module.dist_param_map # name in parallel module (without tid suffix) -> name in origin module param_area_map = module.fullmap # str -> AttrMeta def _get_optimizer_state_of_param(param, param_ids, local_names): @@ -2158,9 +2260,12 @@ def _construct_optim_state_nonzero( dist_param_map = module.dist_param_map # name in parallel module (without tid suffix) -> name in origin module param_area_map = module.fullmap # str -> AttrMeta - new_states = {} + new_states: dict[int, dict[str, torch.Tensor]] = {} for index, (local_name, _) in enumerate(module.named_parameters()): - new_states[index] = _extract_new_state(local_name, orig_param_dict, dist_param_map, param_area_map) + new_states[index] = _extract_new_state( + local_name, orig_param_dict, dist_param_map, param_area_map, + module.get_zero3_attr_meta(local_name) + ) return new_states @@ -2170,7 +2275,8 @@ def _extract_new_state( orig_param_dict: Dict[str, Dict[str, Any]], dist_param_map: Dict[str, str], param_area_map: Dict[str, AttrMeta], -): + zero3_info: Optional[Zero3AttrMeta] = None +) -> Dict[str, torch.Tensor]: name = '_'.join(local_name.split('_')[:-1]) # remove the integer suffix assert name in dist_param_map attr_meta = param_area_map[local_name] @@ -2181,6 +2287,18 @@ def _extract_new_state( sliced_new_val[key] = new_val[key] else: sliced_new_val[key] = new_val[key][attr_meta.slicers] / attr_meta.val_chunks + if zero3_info is not None: + sliced_new_val[key] = sliced_new_val[key].view(-1)[zero3_info.start:zero3_info.end] + if sliced_new_val[key].numel() < zero3_info.chunk_size: + # padding if needed + sliced_new_val[key] = torch.cat( + [sliced_new_val[key], + torch.zeros( + zero3_info.chunk_size - sliced_new_val[key].numel(), + dtype=sliced_new_val[key].dtype, + device=sliced_new_val[key].device + )], dim=0 + ) return sliced_new_val @@ -2292,7 +2410,7 @@ def _broadcast_gen_files( def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Tuple[ Dict[int, Dict[str, Dict[str, AttrMeta]]], - int, + Dict[str, int], Dict[int, Dict[str, Dict[str, AttrMeta]]] ]: """ @@ -2303,17 +2421,12 @@ def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Tuple[ Returns: A tuple containing: - rank2deduped_fullmap: a mapping from rank id to deduplicated attribute information - - dedup_group_size: the size of the deduplication group + - dedup_group_size: the size of the deduplication group for each parallel module - global_fullmaps: a mapping from rank id to full attribute information """ - dedup_group_size = None + dedup_group_size = {} for prefix, parallel_module in parallel_modules.items(): - if dedup_group_size is None: - dedup_group_size = parallel_module.module_dedup_group_size - else: - assert dedup_group_size == parallel_module.module_dedup_group_size, \ - f'dedup_group_size mismatch {dedup_group_size} vs {parallel_module.module_dedup_group_size}' - dedup_group_size = dedup_group_size or 1 + dedup_group_size[prefix] = parallel_module.module_dedup_group_size world_size = torch.distributed.get_world_size() global_fullmaps: Dict[ @@ -2329,8 +2442,9 @@ def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Tuple[ # `dedup_attrs` is a deterministic algorithm, so it produces same results across different ranks rank2deduped_fullmap = dedup_attrs(global_fullmaps) - for rank in range(dedup_group_size, world_size): - assert len(rank2deduped_fullmap[rank]) == 0, f'Rank {rank} has non-empty deduped_fullmap: {rank2deduped_fullmap[rank]}' + for prefix, group_size in dedup_group_size.items(): + for rank in range(group_size, world_size): + assert len(rank2deduped_fullmap[rank].get(prefix, {})) == 0, f'Rank {rank} has non-empty deduped_fullmap: {rank2deduped_fullmap[rank]}' return rank2deduped_fullmap, dedup_group_size, global_fullmaps @@ -2370,7 +2484,12 @@ def deduped_state_dict( split_names = key.split('.') prefix = '.'.join(split_names[:-1]) # remove the last part of the key if prefix in parallel_modules: - if prefix not in cur_deduped_fullmap or split_names[-1] not in cur_deduped_fullmap[prefix]: + if parallel_modules[prefix].compute_config.use_zero > 1: + # for zero3, we don't use advanced deduplication. + # TODO: handle zero3 case + if cur_rank >= parallel_modules[prefix].module_dedup_group_size: + module_state_dict.pop(key, None) + elif prefix not in cur_deduped_fullmap or split_names[-1] not in cur_deduped_fullmap[prefix]: module_state_dict.pop(key, None) # since replicated non-parallel modules, we only keep weights on rank 0 elif cur_rank >= 1: @@ -2432,16 +2551,25 @@ def load_deduped_state_dict( logger.debug(f'At rank {cur_rank}, state_dict keys: {module_state_dict.keys()}.') logger.debug(f'At rank {cur_rank}, missing_keys: {missing_keys}, unexpected_keys: {unexpected_keys}.') - # step 2: broadcast deduped weights inside 1st scale unit - parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} - rank2deduped_fullmap, dedup_group_size, global_tensor_meta = _collect_dedup_info(parallel_modules) - broadcast_group = DeviceGroup().get_group(list(range(dedup_group_size))) - logger.debug(f'At rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}.') - if cur_rank < dedup_group_size: + # step 2: broadcast deduped weights inside 1st scale unit for non-zero3 parallel modules + # for zero3 modules, the weights are already complete after step 1 + # TODO: refine zero3 modules support + parallel_modules = { + prefix: m + for prefix, m in module.named_modules() + if isinstance(m, ParallelModule) and m.compute_config.use_zero <= 1 + } + if parallel_modules: + rank2deduped_fullmap, dedup_group_size, global_tensor_meta = _collect_dedup_info(parallel_modules) + logger.debug(f'At rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}.') + # broadcast weights in parallel modules for rank, deduped_fullmap in rank2deduped_fullmap.items(): logger.debug(f'At rank {cur_rank}, process rank: {rank}.') for prefix, fullmap in deduped_fullmap.items(): + if cur_rank >= dedup_group_size[prefix]: + break + broadcast_group = DeviceGroup().get_group(list(range(dedup_group_size[prefix]))) for local_name, attr_meta in fullmap.items(): key = f'{prefix}.{local_name}' if prefix else local_name assert prefix in parallel_modules, f'Prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}.' @@ -2472,7 +2600,7 @@ def load_deduped_state_dict( for key in missing_keys: split_names = key.split('.') prefix = '.'.join(split_names[:-1]) # remove the last part of the key - assert prefix not in parallel_modules, f'At rank {cur_rank}, the missing key {key} should be in non-parallel modules.' + assert prefix not in parallel_modules or cur_rank >= dedup_group_size[prefix], f'At rank {cur_rank}, the missing key {key} should be in non-parallel modules.' # At this point # - All parallel modules in first scale unit should be complete. @@ -2752,7 +2880,7 @@ def _trim_module_merged_state_dict( prefix = module_path + '.' if module_path else '' trimmed_state_dict.update( pmodule.trim_merged_state_dict( - pmodule.rank, module_state_dict, prefix=prefix, + module_state_dict, prefix=prefix, device=device ) ) diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 9558da34..b061adca 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -38,6 +38,7 @@ from nnscaler.ir import IRCell, IRSubTensor, IRFullTensor from nnscaler.ir.cten import IR from nnscaler.runtime.function import identity, multiref +from nnscaler.utils import load_type if TYPE_CHECKING: @@ -801,3 +802,50 @@ def fn( cfg.apply_pipeline_scheduler(graph, nstages, nmicros, scheduler) return graph + + +def pas_fsdp(graph, cfg: 'ComputeConfig'): + """ + A simple FSDP policy: + 1. all operators are replicated + 2. user specified modules with `cfg.pas_config.recompute_modules` are recomputed + 3. shard policy is configured in cfg.use_zero and cfg.zero_ngroups + 4. CPU offload is not supported + """ + if cfg.plan_ngpus != 1: + raise ValueError("FSDP policy only supports 1 plan GPU") + if not cfg.use_zero: + raise ValueError("FSDP policy requires use_zero to be 1/3") + + recompute_modules = cfg.pas_config.get('recompute_modules', '') + # parse recompute_modules + # user can also provide a list of Module classes. + if isinstance(recompute_modules, str): + recompute_modules = recompute_modules.strip() + if not recompute_modules: + recompute_modules = [] + else: + recompute_modules = [m.strip() for m in recompute_modules.split(',')] + + if recompute_modules: + recompute_modules = [load_type(rm) for rm in recompute_modules] + else: + recompute_modules = [] + + cur_recompute_id = -1 + cur_recompute_module_fqn = None + for node in get_pas_ops(graph): + recompute_module: torch.nn.Module + for rm in recompute_modules: + if rm in node.module_class_chain: + recompute_module = rm + break + else: + cur_recompute_module_fqn = None + continue + + mod_fqn = node.get_module_fqn(recompute_module) + if cur_recompute_module_fqn is None or cur_recompute_module_fqn != mod_fqn: + cur_recompute_id += 1 + cur_recompute_module_fqn = mod_fqn + yield OpPlan(node, recompute_id=cur_recompute_id) diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index d842b702..bf795d43 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -3,6 +3,7 @@ from typing import List, Dict, Tuple, Any, Callable, Optional, Set, Sequence from functools import partial +from dataclasses import dataclass import math import logging import torch @@ -60,11 +61,26 @@ def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: raise KeyError(f"Unsupported reduce op {reduce_op}. Supported reduce op: {supported}") +@dataclass +class _Z3ParamInfo: + shape: torch.Size # original shape of the parameter + start: int + end: int + param_buffer_start: int = -1 + param_buffer_end: int = -1 + + def numel(self) -> int: + return self.end - self.start + + def numel_with_padding(self) -> int: + return self.param_buffer_end - self.param_buffer_start + + class Bucket: - def __init__(self, params: List[torch.nn.Parameter], + def __init__(self, reducer: 'Reducer', params: List[torch.nn.Parameter], param_buffer: torch.Tensor, grad_buffer: torch.Tensor, reduce_op: torch.distributed.ReduceOp, - group: torch.distributed.ProcessGroup, async_op: bool, zero: bool, + group: torch.distributed.ProcessGroup, async_op: bool, zero: int, zero_subgroup: torch.distributed.ProcessGroup = None, zero_crossgroup: torch.distributed.ProcessGroup = None, zero_use_reduce_scatter: bool = False, @@ -84,7 +100,8 @@ def __init__(self, params: List[torch.nn.Parameter], reduce_op (torch.distributed.ReduceOp): the reduce op used by collectives group (torch.distributed.ProcessGroup): communication group async_op (bool): whether to use asynchronous operation - zero (bool): whether to use zero optimization on gradients + zero (int): whether to use zero optimization on gradients, currently only 0/1/3 are supported + zero=2 will be treated as zero=3 zero_subgroup (torch.distributed.ProcessGroup): the subgroup for zero optimization the current rank belongs to zero_crossgroup (torch.distributed.ProcessGroup): the communication group for cross zero group allreduce when reduce scatter is enabled zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization @@ -103,7 +120,7 @@ def __init__(self, params: List[torch.nn.Parameter], self._hooks: List[Tuple[Any, RemovableHandle]] = [] self._async: bool = async_op - self._zero: bool = zero + self._zero: int = zero self._zero_use_reduce_scatter = zero_use_reduce_scatter self._contiguous_params = param_buffer self._contiguous_grads = grad_buffer @@ -127,6 +144,9 @@ def __init__(self, params: List[torch.nn.Parameter], self._pre_hooks: List[Callable] = [] self._post_hooks: List[Callable] = [] + self._z3 = self._zero > 1 + self._reducer = reducer + # only async will enable contiguous gradient self.build() self.register_hooks() @@ -151,6 +171,11 @@ def zero(self) -> bool: """Whether enable zero for this bucket""" return self._zero + @property + def zero3(self) -> bool: + """Whether enable zero3 for this bucket""" + return self._z3 + def get_aligned_numel(self, param) -> int: """ Get the aligned number of elements for a parameter @@ -178,9 +203,11 @@ def _group_reduce_scatter(self): op=self._reduce_op, group=self._zero_subgroup) def _get_opt_param_data(self): - if not self._zero: + if not self._zero or self._zero > 1: + # when zero3 is used, the parameters are already sharded in reducer opt = self._contiguous_params else: + assert self._zero == 1 rank = torch.distributed.get_rank(group=self._zero_subgroup) assert len(self._contiguous_params) % self._zgroup_sz == 0 # Note: @@ -217,13 +244,41 @@ def register_hooks(self): """ @torch.no_grad() - def post_grad_hook(param: torch.nn.Parameter, *unused): + def post_grad_hook(param: torch.nn.Parameter, *unused): # pragma: no cover # stream = DeviceGroup().get_stream('reducer') ofst = self._pofset[param] + rank = torch.distributed.get_rank() # TODO: need to handle sparse gradients in torch.nn.Embedding - self._contiguous_grads[ofst:ofst+param.numel()].add_(param.grad.data.view(-1)) + if self._z3: + z3_info = self._reducer.get_z3_info(param) + grad = param.grad.data.view(-1) + padded_numel = z3_info.numel_with_padding() * self._zgroup_sz + if grad.numel() < padded_numel: + # add padding + grad = torch.cat( + [grad, + torch.zeros(padded_numel - grad.numel(), device=grad.device, dtype=grad.dtype)] + ) + output = torch.zeros(z3_info.numel_with_padding(), device=grad.device, dtype=grad.dtype) + torch.distributed.reduce_scatter_tensor( + output, + grad, + op=self._reduce_op, + group=self._zero_subgroup + ) + # accumulate the param grad in zero3 way + self._contiguous_grads[ofst:ofst+z3_info.numel()]\ + .add_(output[0:z3_info.end-z3_info.start]) + else: + self._contiguous_grads[ofst:ofst+param.numel()].add_(param.grad.data.view(-1)) + param.grad = None + if self._z3: + # in most cases, it is not necessary to post-evict here, + # let's add it for safety + self._reducer.postevict_param(param) + if RuntimeFlag.skip_reducer: return self._async_param_cnt += 1 @@ -237,7 +292,9 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): # apply pre hooks self._apply_pre_hooks() # communication - if self._zero and self._zero_use_reduce_scatter: + if self._zero == 1 and self._zero_use_reduce_scatter: + # when zero3 is used, the parameters and gradients are already sharded in reducer + # so only allreduce is needed if self._zgroup_sz == self._wsz: rank = torch.distributed.get_rank(group=self._group) shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) @@ -248,9 +305,13 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): group=self._group, async_op=True) else: assert False, "group zero + reducescatter is not supported in async mode, " \ - "because the two steps (allreduce, reducescatter) use " \ - "two communication groups, which may induce deadlock." + "because the two steps (allreduce, reducescatter) use " \ + "two communication groups, which may induce deadlock." self._group_reduce_scatter() + elif self._zero > 1: + self._async_handle = torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, + group=self._zero_crossgroup, async_op=True) else: self._async_handle = torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, @@ -259,20 +320,23 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): for param in self._params: # same trick with FSDP and Megatron # reference: https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/fully_sharded_data_parallel.py#L3177-L3188 - param_tmp = param.expand_as(param) + if self._z3: + old_param_data = param.data + # here we need the full parameter to build the computation graph + # let's create a temporary parameter with full shape to fake it. + param.data = torch.empty(self._reducer.get_z3_info(param).shape, dtype=param.dtype, device=param.device) + param_tmp = param.expand_as(param) + param.data = old_param_data + else: + param_tmp = param.expand_as(param) + # gets its AccumulateGrad object grad_acc = param_tmp.grad_fn.next_functions[0][0] hook = grad_acc.register_hook(partial(post_grad_hook, param)) # grad_acc must keep, otherwise the hook won't take effect self._hooks.append((grad_acc, hook)) - def unregister_hooks(self): - """ - Unregister all post-backward hook to parameters. - """ - for _, hook in self._hooks: - hook.remove() - self._hooks.clear() + torch.cuda.empty_cache() def sync_grads(self): """ @@ -294,8 +358,14 @@ def sync_grads(self): # apply pre-hooks self._apply_pre_hooks() # synchrnoize gradients - if self._zero and self._zero_use_reduce_scatter: + if self._zero == 1 and self._zero_use_reduce_scatter: self._group_reduce_scatter() + elif self._zero > 1: + torch.distributed.all_reduce( + self._contiguous_grads, + op=self._reduce_op, + group=self._zero_crossgroup + ) else: torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, group=self._group) @@ -304,10 +374,16 @@ def sync_grads(self): for param in self._params: assert param.grad is None pofst = self._pofset[param] + if self._z3: + z3_info = self._reducer.get_z3_info(param) + # the param should have been evicted + assert z3_info.numel_with_padding() == param.numel() and len(param.shape) == 1, \ + f"internal error: zero3 param size mismatch, " \ + f"expect {[z3_info.numel_with_padding()]} got {param.shape}" param.grad = self._contiguous_grads[pofst:pofst+param.numel()].view(param.size()) # setup gradient for optimizer parameters - if self._zero: + if self._zero == 1: rank = torch.distributed.get_rank(group=self._zero_subgroup) grad = self._contiguous_grads.chunk(self._zgroup_sz, dim=0)[rank] self._param_for_optimizer.grad = grad @@ -321,7 +397,7 @@ def gather_params(self): """ All-gather parameters """ - assert self._zero, "gathering paramters is only for zero optimization." + assert self._zero == 1, "gathering paramters is only for zero1 optimization." rank = torch.distributed.get_rank(group=self._zero_subgroup) CudaTimer().start(field_name='comm', predefined=True) src_tensor = self._contiguous_params.chunk(self._zgroup_sz, dim=0)[rank] @@ -436,15 +512,19 @@ def _pack( state.pop(fields._pre_hooks, None) state.pop(fields._post_hooks, None) + # remove reducer reference + state.pop(fields._reducer, None) + return state @classmethod - def _unpack(cls, state: dict): + def _unpack(cls, state: dict, reducer: 'Reducer'): """ Return a fake bucket that carries the same information. """ bucket = object.__new__(cls) bucket.__dict__.update(state) + bucket._reducer = reducer for param in bucket._params: assert param.device.type == 'meta' @@ -461,11 +541,13 @@ class Reducer: # https://github.com/pytorch/pytorch/blob/4fd16dd8aa259cd75c9a6d2ddcd8171cd1ee8e28/torch/nn/parallel/distributed.py#L548 _DEFAULT_BUCKET_CAP_MB = 25 # 25MB, the same as pytorch - def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None, - reduce_op: str = 'sum', async_op: bool = False, - zero: bool = False, zero_ngroups: int = 1, - zero_use_reduce_scatter: bool = False, - align_size: int = ALIGNED_BYTES + def __init__(self, ranks: List[int], + *, + max_bucket_size_bytes: Optional[int] = None, + reduce_op: str = 'sum', async_op: bool = False, + zero: int = 0, zero_ngroups: int = 1, + zero_use_reduce_scatter: bool = False, + align_size: int = ALIGNED_BYTES, ): """ Create a reducer applied on a set of weights for weight reduction @@ -480,7 +562,8 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None Default is `None` reduce_op (str): reduce operation, can be 'sum', 'avg', 'max' or 'min' (default 'sum') async_op (bool): whether to overlap with backward computation (default False) - zero (bool): whether to apply ZeRO optimization on gradients + zero (int): whether to use zero optimization on gradients, currently only 0/1/3 are supported + zero=2 will be treated as zero=3 zero_ngroups (int): number of ZeRO subgroups in the original ZeRO group zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization align_size (int): the alignment size in bytes for each parameter @@ -502,7 +585,7 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None # buckets stands for a transission unit self._buckets: List[Bucket] = list() self._async: bool = async_op - self._zero: bool = zero + self._zero: int = int(zero) self._zero_use_reduce_scatter = zero_use_reduce_scatter self._align_size: int = align_size if self._align_size % ALIGNED_BYTES != 0: @@ -554,9 +637,18 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None else: assert zero_ngroups == 1, f"ZeRO number of groups must be 1, but got {zero_ngroups}" self._zero_subgroup = self._group - self._zero_crossgroup = None + # trivial crossgroup for single rank + self._zero_crossgroup = DeviceGroup().get_group([torch.distributed.get_rank()]) + self._zero_ngroups = zero_ngroups + self._z3_size = torch.distributed.get_world_size(group=self._zero_subgroup) + if self._z3_size == 1: + self._zero = 0 # disable zero when only one rank in subgroup + self._z3 = self._zero > 1 + self._z3_rank = torch.distributed.get_rank(group=self._zero_subgroup) + self._z3_params_info: dict[torch.nn.Parameter, _Z3ParamInfo] = dict() + @property def zero_ngroups(self) -> int: return self._zero_ngroups @@ -579,6 +671,11 @@ def zero(self) -> bool: """Whether to apply zero optimization on gradients""" return self._zero + @property + def zero3(self) -> bool: + """Whether to apply ZeRO3""" + return self._zero > 1 + @property def buckets(self) -> Tuple[Bucket, ...]: return tuple(self._buckets) @@ -623,7 +720,12 @@ def _bind_params(self): for param in params: with torch.no_grad(): self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) - param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) + if self._z3: + param.data = self._contiguous_params[ofst:ofst+param.numel()] + self._z3_params_info[param].param_buffer_start = ofst + self._z3_params_info[param].param_buffer_end = ofst + param.numel() + else: + param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) ofst += aligned_nelements @@ -635,10 +737,6 @@ def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None and each bucket contains at least one parameter. If the bucket contains more than 2 parameters, than the total size is samller than the max_bucket_size_bytes. - - You can call this method multiple times to rebuild the buckets. - Typically this will be called when building optimizer when multiple optimizers/param groups are used. - And we will put parameters with different optimizer or different param groups into different buckets. """ self._param_clss = {} if param_clss: @@ -648,12 +746,29 @@ def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None # which can help bucket building self._params.sort(key=lambda p: self._param_clss[p]) - for bucket in self._buckets: - # rebuild bucket should be done before any hooks registered. - if bucket._pre_hooks or bucket._post_hooks: - raise RuntimeError("Cannot rebuild buckets while pre/post hooks are registered.") - bucket.unregister_hooks() - self._buckets.clear() + # step 0: param split for zero3 + if self._z3: + for param in self._params: + if not param.requires_grad: + continue + + chunk_size = (param.numel() + self._z3_size - 1) // self._z3_size + start = self._z3_rank * chunk_size + end = min(start + chunk_size, param.numel()) + self._z3_params_info[param] = _Z3ParamInfo(shape=param.shape, start=start, end=end) + # clone the data so original param can be released + # this padding is required + # to make sure all ranks in the zero subgroup have the same bucket layout. + if end - start < chunk_size: + padding = chunk_size - (end - start) + param.data = torch.cat([ + param.data.view(-1)[start:end].clone(), + torch.zeros(padding, dtype=param.dtype, device=param.device) + ], dim=0) + else: + param.data = param.data.view(-1)[start:end].clone() + + torch.cuda.empty_cache() # step 1: build bucket for overlapping gradient synchronization # self._numel * 8 + 1 here is to make sure @@ -718,6 +833,7 @@ def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None for params, param_cls, start, stop in zip(self.seq_buckets, seq_buckets_cls, self.starts, self.stops): # initialize buckets bucket = Bucket( + self, params, self._contiguous_params[start:stop], self._contiguous_grads[start:stop], @@ -749,12 +865,58 @@ def sync_grads(self): for bucket in self._buckets: bucket.sync_grads() + def get_z3_info(self, param: torch.nn.Parameter) -> _Z3ParamInfo: + """ + Get zero3 param info + if the param is not in zero3, return None + """ + return self._z3_params_info.get(param, None) + + @torch.no_grad() + def prefetch_param(self, param: torch.nn.Parameter): + """Prefetch parameter before forward and backward. + + This is required when zero3 is used. + """ + if not self._z3: + raise RuntimeError("postevict_param is only for zero3 optimization.") + if param not in self._z3_params_info: + raise ValueError(f"parameter {param} not found in zero3 params info.") + + info = self._z3_params_info[param] + if param.shape == info.shape: + # no need to gather + return + + full_data = torch.zeros(info.numel_with_padding() * self._z3_size, dtype=param.dtype, + device=torch.cuda.current_device()) + torch.distributed.all_gather_into_tensor( + full_data, + param.data, + group=self._zero_subgroup + ) + param.data = full_data[0:math.prod(info.shape)].view(info.shape).contiguous() + + @torch.no_grad() + def postevict_param(self, param: torch.nn.Parameter): + """Release parameter after forward and backward. + + This is required when zero3 is used. + """ + if not self._z3: + raise RuntimeError("postevict_param is only for zero3 optimization.") + if param not in self._z3_params_info: + raise ValueError(f"parameter {param} not found in zero3 params info.") + info = self._z3_params_info[param] + param.data = self._contiguous_params[info.param_buffer_start:info.param_buffer_end] + def gather_params(self): """Gather parameters with Zero optimizations after `optimizer.step()`. This is required when zero optimization is turned on. """ if not self._zero: return + if self._z3: return # in zero3 mode, no need to gather params for bucket in self._buckets: bucket.gather_params() @@ -904,6 +1066,7 @@ def _pack( fields = unchecked_fields(self) state[fields._params] = [param_map[p] for p in self._params] + state[fields._z3_params_info] = {param_map[p]: info for p, info in self._z3_params_info.items()} state[fields._param_clss] = {param_map[p]: param_cls for p, param_cls in self._param_clss.items()} state[fields._contiguous_params] = torch.empty_like(self._contiguous_params, device='meta') state[fields._contiguous_grads] = torch.empty_like(self._contiguous_grads, device='meta') @@ -934,7 +1097,7 @@ def _unpack(cls, state: dict): buckets = state.pop(fields._buckets) reducer._buckets = [ - Bucket._unpack(bucket) for bucket in buckets + Bucket._unpack(bucket, reducer) for bucket in buckets ] reducer.__dict__.update(state) for param in reducer._params: diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 51ef947c..dbcb295f 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -8,7 +8,7 @@ """ from contextlib import contextmanager -from typing import Optional, List, Tuple, Union, Any +from typing import Callable, Optional, List, Tuple, Union, Any import torch import torch.nn.functional as TorchF import operator @@ -366,4 +366,24 @@ def print_time(content: str): rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1 if torch.cuda.is_available(): torch.cuda.synchronize() - print(f"line timer: {rank} - {datetime.datetime.now()} - {content}") \ No newline at end of file + print(f"line timer: {rank} - {datetime.datetime.now()} - {content}") + + +class _BackwardHook(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, backward_hook: Callable[[], None]): + ctx.save_for_backward() + ctx.backward_hook = backward_hook + return x + + @staticmethod + def backward(ctx, grad_output): + ctx.backward_hook() + return grad_output, None + + +def insert_backward_hook(x: torch.Tensor, backward_hook: Optional[Callable[[], None]]) -> torch.Tensor: + if backward_hook is None: + # no need to add hook + return x + return _BackwardHook.apply(x, backward_hook) diff --git a/nnscaler/runtime/gnorm.py b/nnscaler/runtime/gnorm.py index 5f69b059..36d46c38 100644 --- a/nnscaler/runtime/gnorm.py +++ b/nnscaler/runtime/gnorm.py @@ -135,7 +135,7 @@ def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int return tid2nreplicas -def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, List[torch.nn.Parameter]]: +def prepare_for_grad_clip(cube_model: 'CubeModule', use_zero: int) -> Dict[int, List[torch.nn.Parameter]]: params_info_for_gnorm = cube_model.parameters_for_calc_gnorm() tid2ranks = {} tid2info_list_seq = {} @@ -174,7 +174,7 @@ def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, # multiplied by the number of ZeRO groups. Multiplying the number of pure replicated is easy # to understand. Multiplying the number of ZeRO groups is because the gradients of each ZeRO group # are full model gradients, so the number of ZeRO groups is the number of gradient replicas of the full model. - if not is_zero: + if not use_zero: nreplicas = replicated_info.nranks else: nreplicas = replicated_info.nreplicated * params_info.zero_ngroups diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 68a811e7..550632d9 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import functools import pickle from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union, ClassVar from typing_extensions import Self @@ -16,6 +17,7 @@ import torch import torch.distributed as dist from torch import device +from torch.autograd.graph import saved_tensors_hooks from nnscaler.graph.parser import FxModuleParser @@ -24,6 +26,7 @@ from nnscaler.runtime.executor import Executor from nnscaler.runtime.gnorm import ParamsInfo from nnscaler.runtime.utils import microbatches +from nnscaler.runtime.function import insert_backward_hook from nnscaler import __version__ as runtime_version from nnscaler.flags import CompileFlag @@ -58,6 +61,23 @@ class AttrMeta: sub_shape: Tuple[int, ...] +@dataclass +class Zero3AttrMeta: + """ + Used for loading merged state dict + """ + # original name in the module + orig_name: str + # name in the module + attr_name: str + # start index of the sub tensor + start: int + # end index of the sub tensor + end: int + # chunk size of the sub tensor, can be bigger than end - start due to padding + chunk_size: int + + def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, Dict[str, AttrMeta]]]) -> Dict[int, Dict[str, Dict[str, AttrMeta]]]: ''' Deduplicate the attributes according to `rank2attr_area_map`. @@ -417,6 +437,7 @@ def merge_model_state_dicts( fullmaps: List[Dict[str, AttrMeta]] ): """Merge model states from multiple shard into a single-model state. + Here we assume the order of state_dicts and fullmaps are aligned, and is the same as the rank order. Note: Users only need to provide as fewer local model states as necessary to @@ -438,6 +459,11 @@ def merge_model_state_dicts( # Here we expand slice to (start, step, stop) tuple, # because before python 3.12, slice object is not hashable state_dict_merge_track: Dict[str, Set[Tuple[Tuple[Any, Any, Any], ...]]] = {} + # the fill progress of zero3 parameters + # key: param name + # value: Dict[ tuple(start, step, stop) , filled size] + # used to track how many elements have been filled for each zero3 parameter + zero3_current_filled: Dict[str, Dict[Tuple[Tuple[int, int, int], ...], int]] = {} # gather param/buffer full tensor for rank, (model_state_dict, local_fullmap) in enumerate(zip(state_dicts, fullmaps)): for local_name, meta in local_fullmap.items(): @@ -457,13 +483,40 @@ def merge_model_state_dicts( raise NotImplementedError("Not support of partitioning parameter / buffer at value dimension") state_dict_merge_track_id = tuple((i.start, i.step, i.stop) for i in meta.slicers) - if state_dict_merge_track_id in state_dict_merge_track[meta.orig_name]: - if not CubeModule._safe_tensor_equal(full_model_state_dict[meta.orig_name][meta.slicers], partial_tensor): - raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") + dest_tensor = full_model_state_dict[meta.orig_name][meta.slicers] + if dest_tensor.shape == partial_tensor.shape and state_dict_merge_track_id in state_dict_merge_track[meta.orig_name]: + if not CubeModule._safe_tensor_equal(dest_tensor, partial_tensor): + raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") _logger.debug(f'rank {rank}: skip merging duplicated model state for param {meta.orig_name} with slicers {meta.slicers}') else: state_dict_merge_track[meta.orig_name].add(state_dict_merge_track_id) - full_model_state_dict[meta.orig_name][meta.slicers] = partial_tensor + if dest_tensor.shape == partial_tensor.shape: + dest_tensor.copy_(partial_tensor) + else: + # we assume zero3 is on when dest_tensor.shape != partial_tensor.shape + if len(partial_tensor.shape) != 1: + raise ValueError("Invalid tensor as a ZeRO3 parameter, expected a 1D tensor.") + fill_start = zero3_current_filled.setdefault(meta.orig_name, {}).setdefault(state_dict_merge_track_id, 0) + fill_len = partial_tensor.numel() + if fill_start >= dest_tensor.numel(): + # already filled, let's check consistency + fill_start = fill_start % dest_tensor.numel() + if fill_start + fill_len > dest_tensor.numel(): + # remove padding part + fill_len = dest_tensor.numel() - fill_start + if not CubeModule._safe_tensor_equal(dest_tensor.view(-1)[fill_start: fill_start + fill_len], partial_tensor[0:fill_len]): + raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") + else: + if fill_start + fill_len > dest_tensor.numel(): + # remove padding part + fill_len = dest_tensor.numel() - fill_start + old_shape = dest_tensor.shape + dest_tensor = dest_tensor.reshape(-1) + dest_tensor[fill_start: fill_start + fill_len] = partial_tensor[0: fill_len] + full_model_state_dict[meta.orig_name][meta.slicers] = dest_tensor.view(old_shape) + + zero3_current_filled[meta.orig_name][state_dict_merge_track_id] += fill_len + return full_model_state_dict @staticmethod @@ -584,7 +637,7 @@ def _check_state_size(opt_state_keys, bucket_state): return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape for key in opt_state_keys) - def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): + def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size, zero_version): assert bucket_size % len(bucket_states) == 0 opt_state_keys = list(bucket_states[0].keys()) if 'step' in bucket_states[0]: @@ -593,37 +646,65 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): # NOTE: only support adam for now assert 'exp_avg' in opt_state_keys assert 'exp_avg_sq' in opt_state_keys - chunk_size = bucket_size // len(bucket_states) - start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size - end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + opt_states, opt_states_1d = {}, {} for key in opt_state_keys: opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, device=bucket_states[0][key].device, requires_grad=False) opt_states_1d[key] = opt_states[key].view(-1) - if start_rank_id == end_rank_id: - for key in opt_state_keys: - opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] - else: - offset = chunk_size-start_offset - for key in opt_state_keys: - opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] - for i in range(start_rank_id+1, end_rank_id): + if zero_version == 1: + chunk_size = bucket_size // len(bucket_states) + start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size + end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + if start_rank_id == end_rank_id: for key in opt_state_keys: - opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] - offset += chunk_size - if end_offset: # skip if end_offset == 0, because it is a no-op + opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] + else: + offset = chunk_size-start_offset for key in opt_state_keys: - opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] + for i in range(start_rank_id+1, end_rank_id): + for key in opt_state_keys: + opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] + offset += chunk_size + if end_offset: # skip if end_offset == 0, because it is a no-op + for key in opt_state_keys: + opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + else: # zero_version == 3 + assert zero_version > 1, f'unsupported zero version {zero_version}' + for key in opt_state_keys: + fill_start = 0 + fill_len = pend - pstart + param_numel = opt_states_1d[key].numel() + for bstate in bucket_states: + if fill_start >= param_numel: + # from current implementation, code never goes here + # because we have used model_idx2opt_idx to filter out unnecessary ranks + # but let's keep the logic here for safety + fill_start = fill_start % param_numel + if fill_start + fill_len > param_numel: + fill_len = param_numel - fill_start + # check consistency for the already filled part + if not CubeModule._safe_tensor_equal( + opt_states_1d[key][fill_start: fill_start + fill_len], + bstate[key][pstart: pstart+fill_len] + ): + raise ValueError(f"Conflict in merging optimizer state for param with shape {pshape}") + else: + if fill_start + fill_len > param_numel: + fill_len = param_numel - fill_start + # remove padding part + opt_states_1d[key][fill_start: fill_start + fill_len] = bstate[key][pstart: pstart+fill_len] + fill_start += fill_len if 'step' in bucket_states[0]: # make sure all steps are different tensors (with same value) opt_states['step'] = bucket_states[0]['step'].clone() return opt_states - def _merge_opt_zero(worker_idx, param_idx): - model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[worker_idx] + def _merge_opt_zero(param_shape, worker_idx, param_idx): + model_idx2opt_idx, opt_idx2ranks, zero_version = zero_idx_maps[worker_idx] opt_idx = model_idx2opt_idx[param_idx] if isinstance(opt_idx, int): # the param without reducer @@ -632,14 +713,19 @@ def _merge_opt_zero(worker_idx, param_idx): else: # the param in reducer bucket opt_idx, pstart, pend, pshape = opt_idx + if zero_version == 1: + assert param_shape == pshape, f'param shape {param_shape} vs pshape {pshape}' ranks, bucket_size = opt_idx2ranks[opt_idx] + # parameters in reducer come first, so we can directly use opt_idx to index. bucket_states = [optim_state_dicts[rank]['state'][opt_idx] for rank in ranks] return _retrieve_param_opt_state( bucket_states, pstart, pend, - pshape, - bucket_size) + param_shape, + bucket_size, + zero_version + ) # full_index: param IDs in the full optimizer state for full_index, param_name in enumerate(origin_parameter_names): @@ -682,7 +768,7 @@ def _merge_opt_zero(worker_idx, param_idx): # As ZeRO is applied, the optimizer state of this parameter (a shard) # may not be stored locally in its optimizer state. # _merge_opt_zero is for recovering the optimizer state corresponding to this parameter shard. - states: Dict[str, torch.Tensor] = _merge_opt_zero(work_idx, local_index) + states: Dict[str, torch.Tensor] = _merge_opt_zero(meta.sub_shape, work_idx, local_index) zero_done_track.add(track_id) else: _logger.debug(f'rank {work_idx}: skip merging duplicated optimizer state for param {full_index} with slicers {meta.slicers}') @@ -870,6 +956,11 @@ class ZeroMetadata: model_idx2opt_idx: Optional[Dict] = None # a mapping from optimizer_index to the related bucket information (sub_ranks, bucket_size) opt_idx2ranks: Optional[Dict] = None + # the level of zero optimization + # 0: no zero optimization + # 1: zero1 + # > 1: zero3 + zero: int = 0 @dataclass @@ -944,6 +1035,16 @@ def __init__(self): self._nreplicas2localparams: Optional[Dict[int, List[torch.nn.Parameter]]] = None # track whether all the parames (especially the non-persistent buffers) have been initialized self._non_presistent_buffers_inited = False + # track the params that have been prefetched in backward + # this is only used for zero3 + # The reason is the eviction of prefetched params in backward + # relies on the input.requires_grad flag to be True + # If all the inputs do not require grad, + # the eviction logic will not be triggered + # In that case, we will delay the eviction until next backward hook. + self._backward_prefetched_params: dict[torch.nn.Parameter, int] = {} + # the params that have been prefetched in forward + self._forward_prefetched_params: set[torch.nn.Parameter] = set() def __init_subclass__(cls, skip_init=False, **kwargs): # special case when we just fake a ParallelModule class @@ -1029,7 +1130,7 @@ def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None """ Build buckets for the model reducers. - You can call this method multiple times to rebuild the buckets. + You should call this method exactly once before using this module. Typically this will be called when building optimizer when multiple optimizers/param groups are used. And we will put parameters with different optimizer or different param groups into different buckets. @@ -1038,11 +1139,123 @@ def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None 1. setting `build_buckets=False` when calling constructor in `nnscaler.parallelize`. 2. manually calling `build_buckets()` later in `nnscaler.build_optimizer` """ - for reducer in self.reducers: + # needs all parameters to be in cuda memory before building buckets + self.cuda() + self._param_reducer_map: dict[torch.nn.Parameter, int] = {} + model_params = {p: n for n, p in self.named_parameters()} + # key: attr name of the parameter + # value: Zero3AttrMeta + self._zero3_param_metadata: dict[str, Zero3AttrMeta] = {} + for idx, reducer in enumerate(self.reducers): reducer.build_buckets(param_clss) + for param in reducer.params: + self._param_reducer_map[param] = idx + attr_name = model_params[param] + param_attr = self._fullmap[attr_name] + zero3_info = reducer.get_z3_info(param) + self._zero3_param_metadata[attr_name] = Zero3AttrMeta( + attr_name=attr_name, + orig_name=param_attr.orig_name, + start = zero3_info.start, + end = zero3_info.end, + chunk_size=zero3_info.numel_with_padding(), + ) if zero3_info is not None else None self._zero_metadata = self._get_zero_metadata() + def get_zero3_attr_meta(self, attr_name: str) -> Optional[Zero3AttrMeta]: + """ + Get the Zero3AttrMeta for the given attribute name. + + Args: + attr_name (str): the attribute name of the parameter + Returns: + Optional[Zero3AttrMeta]: the Zero3AttrMeta for the given attribute name + """ + return self._zero3_param_metadata.get(attr_name, None) + + @torch.no_grad() + def prefetch_param(self, param: torch.nn.Parameter): + """ + Gather the full parameter tensor for FSDP. + + Args: + param (torch.nn.Parameter): the local parameter to gather + """ + reducer = self._reducers[self._param_reducer_map[param]] + reducer.prefetch_param(param) + self._forward_prefetched_params.add(param) + + @torch.no_grad() + def postevict_param(self, param: torch.nn.Parameter): + """ + Release the full parameter tensor for zero3. + + Args: + param (torch.nn.Parameter): the local parameter + """ + reducer = self._reducers[self._param_reducer_map[param]] + reducer.postevict_param(param) + self._forward_prefetched_params.discard(param) + + def _backward_evict_leftover_params(self, order: int): + for p in [p for p, o in self._backward_prefetched_params.items() if o > order]: + self.postevict_param(p) + self._backward_prefetched_params.pop(p, None) + + def backward_postevict_param(self, input: torch.Tensor, param: torch.nn.Parameter, order: int): + """ + Here we need an input tensor to register the backward hook. + """ + if not input.requires_grad: + # if input does not require grad, we cannot register backward hook on it + return input + + @torch.no_grad() + def _postevict_param(param): # pragma: no cover + self.postevict_param(param) + self._backward_prefetched_params.pop(param, None) + self._backward_evict_leftover_params(order) + + return insert_backward_hook(input, functools.partial(_postevict_param, param)) + + def backward_prefetch_param(self, activation: torch.Tensor, param: torch.nn.Parameter, order: int): + """ + Here we need an activation tensor to register the backward hook. + """ + if not activation.requires_grad: + # if activation does not require grad, we cannot register backward hook on it + return activation + + @torch.no_grad() + def _prefetch_param(param): # pragma: no cover + self.prefetch_param(param) + self._backward_prefetched_params[param] = order + self._backward_evict_leftover_params(order) + + return insert_backward_hook(activation, functools.partial(_prefetch_param, param)) + + def save_params_hooks(self) -> saved_tensors_hooks: + """ + A hook to save tensors during forward pass. + This is used to avoid parameters being saved for activation checkpointing. + + Returns: + saved_tensors_hooks: the saved tensors hooks + """ + def pack(x: torch.Tensor): + for param in self._forward_prefetched_params: + if x.untyped_storage() == param.untyped_storage(): + return (param, x.shape, x.stride(), x.storage_offset()) + return x + + def unpack(x): + if isinstance(x, tuple) and len(x) == 4: + return torch.as_strided(x[0], x[1], x[2], x[3]) + return x + + return saved_tensors_hooks(pack, unpack) + @classmethod def get_attr_meta_map(cls, rank=None): """ @@ -1062,7 +1275,22 @@ def forward(self, *args, **kwargs): self._warn_uninitialized_non_persistent_buffers(raise_error=True) if self.training: self._sync_grad_required = True # mark sync_grad() can be called again - return self._forward_impl(*args, **kwargs) + # all prefetched params should have been evicted + # please note the param can be evicted in Reducer, + # which is not tracked in self._backward_prefetched_params + # so we just check the shape to make sure the param is evicted + for param in self._backward_prefetched_params.keys(): + old_shape = param.shape + self.postevict_param(param) + assert param.shape == old_shape, \ + f'Param {param} is not properly evicted in backward' + self._backward_prefetched_params.clear() + + ret = self._forward_impl(*args, **kwargs) + + assert not self._forward_prefetched_params, \ + f'All forward prefetched params should have been evicted in forward' + return ret def _forward_impl(self, *args, **kwargs): """ @@ -1307,6 +1535,7 @@ def _get_zero_metadata(self) -> ZeroMetadata: return ZeroMetadata( model_idx2opt_idx=model_idx2opt_idx, opt_idx2ranks=opt_idx2ranks, + zero=self.compute_config.use_zero ) def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: @@ -1434,7 +1663,7 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s # avoid checking the non-persistent buffers attr_names = set([attr for attr in self._fullmap.keys() if attr not in non_persistent_buffers]) - for prefix_attr, content in self.trim_merged_state_dict(self.rank, state_dict, prefix).items(): + for prefix_attr, content in self.trim_merged_state_dict(state_dict, prefix).items(): attr = prefix_attr[len(prefix):] tensor: torch.Tensor = getattr(self, attr) tensor.copy_(content) @@ -1451,10 +1680,8 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s self._warn_uninitialized_non_persistent_buffers() return missing_keys - @classmethod def trim_merged_state_dict( - cls, - rank, + self, state_dict: Dict[str, Any], prefix: str = '', *, @@ -1474,9 +1701,9 @@ def trim_merged_state_dict( device = device or torch.cuda.current_device() trimmed_state_dict = {} - dist2param = cls.dist_param_map + dist2param = self.dist_param_map orig_param_names = list(dist2param.values()) # param names in original module (without prefix) - attr_meta_map = cls.get_attr_meta_map(rank) + attr_meta_map = self.get_attr_meta_map(self.rank) with torch.no_grad(): # avoid checking the non-persistent buffers origname_tid_map = {meta.orig_name: meta.tid for meta in attr_meta_map.values()} @@ -1494,10 +1721,23 @@ def trim_merged_state_dict( param_value = state_dict[orig_param_name_with_prefix] tid = origname_tid_map[orig_param_name] for attr, slicer, nchunks in tid_info[tid]: - content = param_value[slicer] + content: torch.Tensor = param_value[slicer] if nchunks != 1: content = content / nchunks - trimmed_state_dict[prefix + attr] = content.to(device) + if self.compute_config.use_zero <= 1 or self._zero3_param_metadata.get(attr, None) is None: + trimmed_state_dict[prefix + attr] = content.to(device) + else: + z3_info = self._zero3_param_metadata[attr] + start, end, chunk_size = z3_info.start, z3_info.end, z3_info.chunk_size + if end - start < chunk_size: + # need padding + padding = chunk_size - (end - start) + trimmed_state_dict[prefix + attr] = torch.cat([ + content.view(-1)[start:end], + torch.zeros(padding, dtype=content.dtype, device=content.device) + ], dim=0).to(device) + else: + trimmed_state_dict[prefix + attr] = content.reshape(-1)[start:end].to(device) return trimmed_state_dict @@ -1521,6 +1761,10 @@ def _pack( state[fields._reducers] = [reducer._pack(param_map) for reducer in self._reducers] state[fields._zero_metadata] = self._zero_metadata state[fields._fullmap] = self._fullmap + state[fields._param_reducer_map] = { + param_map[p]: rid for p, rid in self._param_reducer_map.items() + } + state[fields._zero3_param_metadata] = self._zero3_param_metadata for cv in ParallelModule.__annotations__: state[cv] = getattr(self, cv) @@ -1540,6 +1784,8 @@ class GenModelX(ParallelModule, skip_init=True): object.__setattr__(pm, fields._reducers, [Reducer._unpack(reducer) for reducer in state[fields._reducers]]) object.__setattr__(pm, fields._zero_metadata, state[fields._zero_metadata]) object.__setattr__(pm, fields._fullmap, state[fields._fullmap]) + object.__setattr__(pm, fields._param_reducer_map, state[fields._param_reducer_map]) + object.__setattr__(pm, fields._zero3_param_metadata, state[fields._zero3_param_metadata]) def named_parameters( prefix: str = "", recurse: bool = True, remove_duplicate: bool = True diff --git a/nnscaler/runtime/utils.py b/nnscaler/runtime/utils.py index b15748ea..43ec0656 100644 --- a/nnscaler/runtime/utils.py +++ b/nnscaler/runtime/utils.py @@ -13,7 +13,7 @@ class MicroBatchDataLoader: """ MicroBatchDataLoader is used for scenarios of gradient accumulation, where a training iteration will have multiple data samples and perform - multiple forward and backward on each sample (i.e., each refers to + multiple forward and backward on each sample (i.e., each refers to as a micro-batch). To support more flexible training patterns, e.g., pipeline parallelism, @@ -25,7 +25,7 @@ class MicroBatchDataLoader: ```python # compilation phase dataloader = MicroBatchDataLoader([(input1,),]) # only need one micro-batch - + @nnscaler.compile(model, dataloader, ...) def train_iter(model, dataloader): input1 = next(dataloader) @@ -36,9 +36,9 @@ def train_iter(model, dataloader): ... # runtime phase - + for mini_batch_samples in iter(dataloader): - # mini_batch_samples are sample list for + # mini_batch_samples are sample list for # all micro-batches in one iteration. dl = MicroBatchDataLoader(mini_batch_samples) loss =train_iter(model, dl) @@ -68,7 +68,7 @@ def __init__(self, samples: List[Any], cycle: bool = False): def __iter__(self): self._idx = 0 return self - + def __next__(self): if self._idx == self.nmicros: raise StopIteration @@ -77,10 +77,10 @@ def __next__(self): if self.cycle: self._idx = self._idx % self.nmicros return batch - + def __len__(self): return self.nmicros - + def get_micro_batch(self, idx: int): idx = idx % self.nmicros if self.cycle else idx return self.samples[idx] diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index b57faded..e70528d5 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -6,12 +6,15 @@ import shutil from typing import Any +from mock import PropertyMock import torch import pytest import torch.distributed +from unittest.mock import patch from nnscaler import merge_state_dicts from nnscaler.cli.serialization import Checkpointer +import nnscaler from nnscaler.cli.trainer import Trainer, logger from nnscaler.cli.trainer_args import AggregatedOutputs, TrainerArgs from tests.parallel_module.common import assert_equal, assert_close @@ -120,7 +123,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' \ if bf16 == 'Mixed' \ else 'torch.optim.Adam' - use_zero = save_type == 'sharded' + use_zero = 1 if save_type == 'sharded' else 0 format = 'safetensors' if parallel_type % 2 else 'pt' rev_format = 'pt' if format == 'safetensors' else 'safetensors' @@ -155,7 +158,7 @@ def list_ckpt_files(dir): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False' if use_zero else 'True', + '--model.parallel_modules.0.compute_config.use_zero', str(use_zero), '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -172,7 +175,7 @@ def list_ckpt_files(dir): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False' if use_zero else 'True', + '--model.parallel_modules.0.compute_config.use_zero', str(use_zero), '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -608,7 +611,7 @@ def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): gen_savedir = save_dir / 'gen' ckpt_savedir = save_dir / 'ckpt' optimizer_type = 'torch.optim.Adam' - use_zero = False if zero_ngroups is None else True + use_zero = 0 if zero_ngroups is None else 1 zero_ngroups = '1' if zero_ngroups is None else zero_ngroups trainer = Trainer([ @@ -629,7 +632,7 @@ def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): torch.distributed.barrier() -def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False, hybrid_opt=False): +def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False, hybrid_opt=False, use_zero=0): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) gen_savedir = save_dir / 'gen' @@ -666,7 +669,7 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False, h '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False', + '--model.parallel_modules.0.compute_config.use_zero', '0', '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -683,7 +686,7 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False, h '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False', + '--model.parallel_modules.0.compute_config.use_zero', '0', '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -743,7 +746,7 @@ def param_clss_fn(param_name: str) -> tuple[int, int]: '--max_epochs', '2', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), - '--compute_config.use_zero', 'False', + '--compute_config.use_zero', str(use_zero), '--compute_config.plan_ngpus', '1', '--compute_config.runtime_ngpus', '2', '--compute_config.use_async_reducer', str(async_reducer), @@ -769,18 +772,19 @@ def param_clss_fn(param_name: str) -> tuple[int, int]: torch.distributed.barrier() -def trainer_correctness_worker_aggregate(tmp_path): +def trainer_correctness_worker_aggregate(tmp_path, use_zero): for parallel_type in range(5): for async_reducer in [False, True]: for hybrid_opt in [True, False]: print(f'parallel_type={parallel_type}, async_reducer={async_reducer}, hybrid_opt={hybrid_opt}') save_dir = tmp_path/f'{parallel_type}-{async_reducer}-{hybrid_opt}' - trainer_correctness_worker(save_dir, parallel_type, async_reducer, hybrid_opt) + trainer_correctness_worker(save_dir, parallel_type, async_reducer, hybrid_opt, use_zero) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') -def test_trainer_correctness(tmp_path): - launch_torchrun(2, trainer_correctness_worker_aggregate, tmp_path) +@pytest.mark.parametrize('use_zero', [0, 1, 3]) +def test_trainer_correctness(tmp_path, use_zero): + launch_torchrun(2, trainer_correctness_worker_aggregate, tmp_path, use_zero) merged_ckpts = {} for parallel_type in range(5): for async_reducer in [False, True]: @@ -788,21 +792,26 @@ def test_trainer_correctness(tmp_path): save_dir = tmp_path/f'{parallel_type}-{async_reducer}-{hybrid_opt}' merged_ckpts[(parallel_type, async_reducer, hybrid_opt)] = torch.load(save_dir/'merged.pt') + if use_zero == 3: + assert_fn = assert_close + else: + assert_fn = assert_equal + for parallel_type in range(5): for async_reducer in [False, True]: for hybrid_opt in [True, False]: - assert_equal( + assert_fn( merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['model'], merged_ckpts[(0, False, False)]['model'] ) if not hybrid_opt: - assert_equal( + assert_fn( merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['optimizer'], merged_ckpts[(0, False, False)]['optimizer'] ) else: # param_groups are different when using hybrid optimizer. - assert_equal( + assert_fn( merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['optimizer']['state'], merged_ckpts[(0, False, False)]['optimizer']['state'] ) @@ -1228,6 +1237,175 @@ def check_match(code_dir: Path, should_exist: bool): check_match(gen_savedir, should_exist=False) +def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + + zero_ngroups = runtime_ngpus // plan_ngpus // 2 + if zero_ngroups < 1: + zero_ngroups = 1 + policy = 'dp' if plan_ngpus == 1 else 'tp' + + gen3_savedir = save_dir / 'gen3' + ckpt3_savedir = save_dir / 'ckpt3' + # train 1 epcho in one time with zero3 + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '5', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + torch.distributed.barrier() + + # load from sharded + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', 'last', + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + # load from deduped + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '15', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', 'last', + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + + torch.distributed.barrier() + + # load from merged (from deduped) + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', str(ckpt3_savedir / 'merged.pt'), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + + with ( + patch('nnscaler.ComputeConfig.module_dedup_group_size', new_callable=PropertyMock) as mock_dgs, + patch('nnscaler.ComputeConfig.optimizer_dedup_group_size', new_callable=PropertyMock) as mock_dgs2 + ): + # to mock the case where we have duplicated data in merging + mock_dgs.return_value = runtime_ngpus + mock_dgs2.return_value = runtime_ngpus + + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged2.pt') + zero3_merged_state_dict2 = torch.load(ckpt3_savedir / 'merged2.pt') + zero3_merged_state_dict = torch.load(ckpt3_savedir / 'merged.pt') + assert_equal(zero3_merged_state_dict, zero3_merged_state_dict2) + + torch.distributed.barrier() + + # load from merged (from sharded) + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '25', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', str(ckpt3_savedir / 'merged.pt'), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + torch.distributed.barrier() + + gen1_savedir = save_dir / 'gen1' + ckpt1_savedir = save_dir / 'ckpt1' + + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '25', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen1_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '1', + '--checkpoint.save_dir', str(ckpt1_savedir), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt1_savedir / 'last').glob('*.ckpt')), ckpt1_savedir / 'merged.pt') + zero1_merged_state_dict = torch.load(ckpt1_savedir / 'merged.pt') + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + zero3_merged_state_dict = torch.load(ckpt3_savedir / 'merged.pt') + assert_equal(zero1_merged_state_dict['model'], zero3_merged_state_dict['model']) + assert_equal(zero1_merged_state_dict['optimizer'], zero3_merged_state_dict['optimizer']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_trainer_zero3(tmp_path): + launch_torchrun(2, trainer_zero3, 16, tmp_path, 1, 2) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_trainer_zero3_tp(tmp_path): + launch_torchrun(4, trainer_zero3, 16, tmp_path, 2, 4) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_trainer_zero3_ngroup(tmp_path): + # dim that needs padding + launch_torchrun(4, trainer_zero3, 13, tmp_path, 1, 4) + + def trainer_checkpointer_worker(save_dir): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) diff --git a/tests/cli/test_trainer2.py b/tests/cli/test_trainer2.py new file mode 100644 index 00000000..9b565322 --- /dev/null +++ b/tests/cli/test_trainer2.py @@ -0,0 +1,73 @@ +from pathlib import Path +import pytest +import torch +from torch.utils.data import Dataset + +from nnscaler.cli import TrainerArgs, Trainer +from tests.launch_torchrun import launch_torchrun + + +class NanoGptDataset(Dataset): + def __init__(self, *args, **kwargs): + pass + + def __getitems__(self, indices): + return [torch.randint(0, 151936, (1, 4096), dtype=torch.int64) for _ in indices] + + def __len__(self): + return 10000 + + +def gen_args(trainer_args: 'TrainerArgs'): + src_token = torch.randint(0, 151936, (1, 4096), dtype=torch.int64) + ret = dict( + input_ids=src_token, # torch.Size([1, 4096]) torch.int64 + ) + return ret + + +class WrappedSubModel(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + self.embedding = torch.nn.Embedding(151936, 1536) + + def forward(self, input_ids): + x = self.embedding(input_ids) + return x + + +class WrapperModel(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + self.model = WrappedSubModel() + + def forward(self, src_tokens): + # the logic is from task.train_step + logits = self.model( + src_tokens + ) + return torch.sum(logits) + + +def trainer_mixed_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args_mixed_bf16.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer = Trainer(train_args=args) + trainer.run() + # should reach here without error + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_mixed_bf16_model(tmp_path): + launch_torchrun(2, trainer_mixed_worker, tmp_path) diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml index db05b3e6..5006b87a 100644 --- a/tests/cli/trainer_args.yaml +++ b/tests/cli/trainer_args.yaml @@ -1,6 +1,7 @@ vars: dim: 16 drop_last: true + compute_config: plan_ngpus: 4 runtime_ngpus: 100 diff --git a/tests/cli/trainer_args_mixed_bf16.yaml b/tests/cli/trainer_args_mixed_bf16.yaml new file mode 100644 index 00000000..3ba80cbb --- /dev/null +++ b/tests/cli/trainer_args_mixed_bf16.yaml @@ -0,0 +1,36 @@ +compute_config: + plan_ngpus: 1 + runtime_ngpus: 2 + constant_folding: false + use_zero: 3 + use_end2end: true + +run_mode: run +pas_policy: dp +micro_batch_size: 1 +grad_accumulation_steps: 4 +max_train_steps: 10 +enable_progress_bar: false +log_progress_every_n_train_steps: 10 +precision: bf16 +seed: 1 + +model: + type: tests.cli.test_trainer2.WrapperModel + + parallel_modules: + - type: tests.cli.test_trainer2.WrappedSubModel + forward_args_gen_fn: tests.cli.test_trainer2.gen_args + +optimizer: + type: torch.optim.AdamW + args: + betas: (0.9, 0.95) + eps: 1e-08 + weight_decay: 0.1 + lr: 0.0001 + fused: true + clip_gnorm: 2.0 + +dataset: + type: tests.cli.test_trainer2.NanoGptDataset diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 99416a3b..cc3b22e3 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -151,10 +151,10 @@ def assert_close(a: Any, b: Any, atol=1e-6, rtol=1e-6): elif isinstance(a, dict): assert len(a) == len(b) for k in a.keys(): - assert_close(a[k], b[k]) + assert_close(a[k], b[k], atol=atol, rtol=rtol) elif isinstance(a, (list, tuple)): assert len(a) == len(b) for i in range(len(a)): - assert_close(a[i], b[i]) + assert_close(a[i], b[i], atol=atol, rtol=rtol) else: - raise ValueError(f'unsupported type {type(a)}') \ No newline at end of file + assert a == b, f"Values are not equal: {a} != {b}" diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 4d9b5323..d84d7503 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -593,7 +593,7 @@ def _gpu_merge_worker(): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_merge') as tempdir: compiled_module = _create_cube_module('data', - ComputeConfig(2, 2, use_zero=True), + ComputeConfig(2, 4, use_zero=True), tempdir, 'whole', ) @@ -608,6 +608,6 @@ def _gpu_merge_worker(): ) -@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 24, reason='lack of gpu devices') def test_checkpoint_merge(): - launch_torchrun(2, _gpu_merge_worker) + launch_torchrun(4, _gpu_merge_worker) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index b43e0d20..400e3b18 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -17,7 +17,7 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts from nnscaler.runtime.module import ParallelModule from nnscaler.runtime.gnorm import calcuate_gnorm @@ -261,6 +261,36 @@ def _compare_weights(orig0, compiled0, compiled1, fc1_fullmap, fc2_fullmap, fc1_ # print(f'key: {k}, max diff: {torch.max(torch.abs(orig0[k] - v))}') assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('update_freq', [1, 2, 4]) +def test_zero3(update_freq): + zero3_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 3) + zero1_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 1) + no_zero_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 0) + + zero3_results0: List[StepResult] + zero3_results1: List[StepResult] + zero1_results0: List[StepResult] + zero1_results1: List[StepResult] + no_zero_results0: List[StepResult] + no_zero_results1: List[StepResult] + + zero3_results0, zero3_results1 = zero3_results[0][0], zero3_results[1][0] + zero1_results0, zero1_results1 = zero1_results[0][0], zero1_results[1][0] + no_zero_results0, no_zero_results1 = no_zero_results[0][0], no_zero_results[1][0] + + for r0, r1 in [ + (zero3_results0, zero1_results0), (zero1_results0, no_zero_results0), + (zero3_results1, zero1_results1), (zero1_results1, no_zero_results1), + ]: + # have the same input + assert len(r0) == len(r1) # iteration count + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.pred, b.pred) # pred + assert torch.equal(a.loss, b.loss) # loss + assert torch.equal(a.gnorm, b.gnorm) # gnorm + @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @pytest.mark.parametrize('update_freq', [1, 2, 4]) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 7e2e3ba0..dbca3de0 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -69,6 +69,7 @@ def __init__(self): def forward(self, x): return x[:2] + @replace_all_device_with('cpu') def test_codegen_slice(): with tempfile.TemporaryDirectory() as tempdir: @@ -210,9 +211,8 @@ def _gencode_unused_args_worker(tempdir): m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), m=1) ) - with pytest.raises(ValueError): - # y must be None - m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) + # if y is not None, we will not raise error now. + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @@ -258,9 +258,6 @@ def _gencode_unused_args_worker2(tempdir): with pytest.raises(TypeError, match='.*must be Tensor, not NoneType.*'): # raise by torch.add, as m is None m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - with pytest.raises(ValueError): - # y must be None - m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @@ -2111,3 +2108,194 @@ def test_dynamic_dim_partition(tmp_path, dynamic_dims): assert not _gencode_contains(tmp_path, DynamicInputModel, 0, r'nnscaler.runtime.adapter.nn.split_allgather', instance_name=instance_name) else: assert _gencode_contains(tmp_path, DynamicInputModel, 0, r'nnscaler.runtime.adapter.nn.split_allgather', instance_name=instance_name) + + +@replace_all_device_with('cpu') +def test_zero3_normal(tmp_path): + from tests.parallel_module.test_end2end import MLP + m = MLP(2, 2) + dummy_input = { + 'data': torch.randn( + 2, 2), + 'target': torch.rand( + 2, 2) + } + m.train() + parallelize( + m, + {'data': dummy_input}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # code looks like: + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.prefetch_param\(self\.layers_0_weight_\d+\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.postevict_param\(self\.layers_0_weight_\d+\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.backward_postevict_param\(.*, self\.layers_0_weight_\d+, 1\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.backward_prefetch_param\(.*, self\.layers_0_weight_\d+, 1\)') + + # def segment35_impl(self, data_23): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 46, in forward, x = data['data'] + # getitem_25 = _operator.getitem(data_23, 'data') + # self.prefetch_param(self.layers_0_weight_26) + # getitem_25 = self.backward_postevict_param(getitem_25, self.layers_0_weight_26, 1) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 48, in forward, x = layer(x) + # linear_27 = torch.nn.functional.linear(getitem_25, self.layers_0_weight_26, bias=None) + # self.postevict_param(self.layers_0_weight_26) + # linear_27 = self.backward_prefetch_param(linear_27, self.layers_0_weight_26, 1) + # del getitem_25 + # self.prefetch_param(self.layers_1_weight_28) + # linear_27 = self.backward_postevict_param(linear_27, self.layers_1_weight_28, 2) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 48, in forward, x = layer(x) + # linear_1_29 = torch.nn.functional.linear(linear_27, self.layers_1_weight_28, bias=None) + # self.postevict_param(self.layers_1_weight_28) + # linear_1_29 = self.backward_prefetch_param(linear_1_29, self.layers_1_weight_28, 2) + # del linear_27 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 49, in forward, x = torch.sigmoid(x) + # sigmoid_30 = torch.sigmoid(linear_1_29) + # del linear_1_29 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 50, in forward, loss = self.loss_fn(x, data['target']) + # getitem_1_31 = _operator.getitem(data_23, 'target') + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 50, in forward, loss = self.loss_fn(x, data['target']) + # binary_cross_entropy_24 = torch.nn.functional.binary_cross_entropy(sigmoid_30, getitem_1_31, weight=None, reduction='mean') + # del sigmoid_30, getitem_1_31 + # return binary_cross_entropy_24 + + # def segment35(self, data_23): + # with self.save_params_hooks(): + # return self.segment35_impl(data_23) + + +class SoloOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + self.p = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x): + x = self.scale.sum() + x + self.p + return torch.sum(x) + + +def launch_zero3_run_solo_param(tmp_path): + init_distributed() + m = SoloOpModule() + dummy_input = torch.randn(4, 4) + m.train() + m_new = parallelize( + m, + {'x': dummy_input}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=True, + reuse='override', + ) + loss = m_new(dummy_input) + loss.backward() + # scale can't be evicited with backward hook + assert len(m_new._backward_prefetched_params) == 1 + # but it should have been evicted in reducer. + assert list(m_new._backward_prefetched_params.keys())[0].shape == (8,) + assert not _gencode_contains(tmp_path, SoloOpModule, 0, + r'self\.backward_postevict_param\(.*, self\.scale_\d+, \d+\)') + assert _gencode_contains(tmp_path, SoloOpModule, 0, + r'self\.backward_postevict_param\(.*, self\.p_\d+, \d+\)') + # code looks like: + # def segment32_impl(self, x_17): + # self.prefetch_param(self.scale_19) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # sum_1_20 = torch.sum(self.scale_19) + # self.postevict_param(self.scale_19) + # sum_1_20 = self.backward_prefetch_param(sum_1_20, self.scale_19, 0) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # add_21 = torch.add(sum_1_20, x_17, alpha=1) + # del x_17, sum_1_20 + # self.prefetch_param(self.p_22) + # add_21 = self.backward_postevict_param(add_21, self.p_22, 2) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # add_1_23 = torch.add(add_21, self.p_22, alpha=1) + # self.postevict_param(self.p_22) + # add_1_23 = self.backward_prefetch_param(add_1_23, self.p_22, 2) + # del add_21 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2187, in forward, return torch.sum(x) + # sum_2_18 = torch.sum(add_1_23) + # del add_1_23 + # return sum_2_18 + + # def segment32(self, x_17): + # with self.save_params_hooks(): + # return self.segment32_impl(x_17) + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of GPU devices') +def test_zero3_run_solo_param(tmp_path): + launch_torchrun(2, launch_zero3_run_solo_param, tmp_path) + + +@nnscaler.register_op('*, *, *, * -> *, *, *, *') +def _zero3_multi_inout(x, y, z, w): + return x + 1, y + 1, z + 1, w + 1 + + +class Zero3MultiInoutModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = torch.nn.Parameter(torch.randn(4, 4)) + self.q = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, y): + return _zero3_multi_inout(x, y, self.p, self.q) + + +@replace_all_device_with('cpu') +def test_zero3_multi_inout(tmp_path): + m = Zero3MultiInoutModule() + m.train() + m_new = parallelize( + m, + {'x': torch.randn(4, 4), 'y': torch.randn(4, 4)}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert len(_gencode_contains(tmp_path, Zero3MultiInoutModule, 0, + 'self.backward_prefetch_param')) == 8 + assert len(_gencode_contains(tmp_path, Zero3MultiInoutModule, 0, + 'self.backward_postevict_param')) == 4 + # code looks like: + # def segment34_impl(self, x_25, y_26): + # self.prefetch_param(self.p_31) + # x_25 = self.backward_postevict_param(x_25, self.p_31, 0) + # y_26 = self.backward_postevict_param(y_26, self.p_31, 0) + # self.prefetch_param(self.q_32) + # x_25 = self.backward_postevict_param(x_25, self.q_32, 0) + # y_26 = self.backward_postevict_param(y_26, self.q_32, 0) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2259, in forward, return _zero3_multi_inout(x, y, self.p, self.q) + # _zero3_multi_inout_27, _zero3_multi_inout_28, _zero3_multi_inout_29, _zero3_multi_inout_30 = tests.parallel_module.test_gencode._zero3_multi_inout(x_25, y_26, self.p_31, self.q_32) + # self.postevict_param(self.p_31) + # self.postevict_param(self.q_32) + # _zero3_multi_inout_27 = self.backward_prefetch_param(_zero3_multi_inout_27, self.p_31, 0) + # _zero3_multi_inout_27 = self.backward_prefetch_param(_zero3_multi_inout_27, self.q_32, 0) + # _zero3_multi_inout_28 = self.backward_prefetch_param(_zero3_multi_inout_28, self.p_31, 0) + # _zero3_multi_inout_28 = self.backward_prefetch_param(_zero3_multi_inout_28, self.q_32, 0) + # _zero3_multi_inout_29 = self.backward_prefetch_param(_zero3_multi_inout_29, self.p_31, 0) + # _zero3_multi_inout_29 = self.backward_prefetch_param(_zero3_multi_inout_29, self.q_32, 0) + # _zero3_multi_inout_30 = self.backward_prefetch_param(_zero3_multi_inout_30, self.p_31, 0) + # _zero3_multi_inout_30 = self.backward_prefetch_param(_zero3_multi_inout_30, self.q_32, 0) + # del x_25, y_26 + # return _zero3_multi_inout_27, _zero3_multi_inout_28, _zero3_multi_inout_29, _zero3_multi_inout_30 + + # def segment34(self, x_25, y_26): + # with self.save_params_hooks(): + # return self.segment34_impl(x_25, y_26) + assert True diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 3d046393..3278cc0f 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -72,13 +72,13 @@ def test_empty_weights(model_class, tp): model_class, {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, 'tp', - ComputeConfig(2, 4, use_zero=True, zero_ngroups=2), + ComputeConfig(2, 8, use_zero=True, zero_ngroups=2), gen_savedir=tempdir, reuse='match', load_module=False, instance_name=instance_name, ) - for i in range(4): + for i in range(8): module_class = _load_parallel_module_class(model_class, gen_savedir=tempdir, instance_name=instance_name, rank=i) m = new_empty(module_class) assert m.rank == i @@ -86,9 +86,9 @@ def test_empty_weights(model_class, tp): assert p.device == torch.device('meta') for r in m.reducers: if tp: - assert r.ranks == ((0, 2) if i in (0, 2) else (1, 3)) + assert r.ranks == ((0, 2, 4, 6) if i in (0, 2, 4, 6) else (1, 3, 5, 7)) else: - assert r.ranks == (0, 1, 2, 3) + assert r.ranks == (0, 1, 2, 3, 4, 5, 6, 7) assert len(r.buckets) == 1 assert r.zero assert r.zero_ngroups == 2 diff --git a/tests/runtime/test_gnorm.py b/tests/runtime/test_gnorm.py index 80025906..1de3ee35 100644 --- a/tests/runtime/test_gnorm.py +++ b/tests/runtime/test_gnorm.py @@ -57,7 +57,7 @@ def cal_wnorm_cube(model: CubeModule): for p in model.parameters_for_optimizer(): p.grad = p.data # p.grad.copy_(p.data) - nreplicas2localparams = prepare_for_grad_clip(model, is_zero=CompileFlag.use_zero) + nreplicas2localparams = prepare_for_grad_clip(model, use_zero=CompileFlag.use_zero) wnorm, _ = clip_gnorm(nreplicas2localparams, None) # maps = {tid: [t.size() for t in ts] for tid, ts in nreplicas2localparams.items()} # print(f'cube nrepicas len: {maps}') diff --git a/tests/runtime/test_hybrid_optimizer.py b/tests/runtime/test_hybrid_optimizer.py index 65b6b067..b4238246 100644 --- a/tests/runtime/test_hybrid_optimizer.py +++ b/tests/runtime/test_hybrid_optimizer.py @@ -103,7 +103,7 @@ def trainer_worker(save_dir, use_zero): '--checkpoint.resume_from.with_merged', str(True), '--gen_savedir', str(gen_savedir), '--checkpoint.save_dir', str(ckpt0_savedir), - '--compute_config.use_zero', str(not use_zero), + '--compute_config.use_zero', str(1 - use_zero), ] trainer = Trainer(trainer_config) trainer.run() @@ -133,7 +133,7 @@ def trainer_worker(save_dir, use_zero): '--checkpoint.resume_from.with_merged', str(False), '--gen_savedir', str(gen_savedir), '--checkpoint.save_dir', str(ckpt1_savedir), - '--compute_config.use_zero', str(not use_zero), + '--compute_config.use_zero', str(1 - use_zero), ]) trainer.run() @@ -148,6 +148,6 @@ def trainer_worker(save_dir, use_zero): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') -@pytest.mark.parametrize('use_zero', [True, False]) +@pytest.mark.parametrize('use_zero', [0, 1]) def test_hybrid_optimizer(tmp_path, use_zero): launch_torchrun(2, trainer_worker, tmp_path, use_zero) diff --git a/tests/test_policies.py b/tests/test_policies.py index 06e31deb..393ee4cc 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -983,3 +983,121 @@ def test_run_codegen_fn_with_hook(): """ with tempfile.TemporaryDirectory() as tempdir: launch_torchrun(2, _gencode_unused_args_worker, tempdir) + + +@replace_all_device_with('cpu') +def test_codegen_fsdp(tmp_path): + parallelize( + FnPolicyModuleList(), + {'x': torch.randn(4, 4)}, + 'fsdp', + ComputeConfig( + 1, 2, + use_end2end=True, + use_zero=3, + pas_config={ + 'recompute_modules': [FFNDropout], + } + ), + gen_savedir=tmp_path, + load_module=False + ) + # code should look like: + # def segment105_impl(self, x_49): + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 239, in forward, x = x * 2 + # mul_51 = torch.mul(x_49, 2) + # del x_49 + + # def recompute(mul_51): + # # created at IRAdapterGener:local_consumer_multiref + # mul_100, mul_104 = nnscaler.runtime.function.multiref(mul_51, times=2) + # del mul_51 + # self.prefetch_param(self.ffn_0_gate_proj_weight_52) + # mul_100 = self.backward_postevict_param(mul_100, self.ffn_0_gate_proj_weight_52, 1) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_53 = torch.nn.functional.linear(mul_100, self.ffn_0_gate_proj_weight_52, bias=None) + # self.postevict_param(self.ffn_0_gate_proj_weight_52) + # linear_53 = self.backward_prefetch_param(linear_53, self.ffn_0_gate_proj_weight_52, 1) + # del mul_100 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_54 = torch.tanh(linear_53) + # del linear_53 + # self.prefetch_param(self.ffn_0_up_proj_weight_55) + # mul_104 = self.backward_postevict_param(mul_104, self.ffn_0_up_proj_weight_55, 3) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_56 = torch.nn.functional.linear(mul_104, self.ffn_0_up_proj_weight_55, bias=None) + # self.postevict_param(self.ffn_0_up_proj_weight_55) + # linear_1_56 = self.backward_prefetch_param(linear_1_56, self.ffn_0_up_proj_weight_55, 3) + # del mul_104 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_57 = torch.mul(tanh_54, linear_1_56) + # del tanh_54, linear_1_56 + # self.prefetch_param(self.ffn_0_down_proj_weight_58) + # mul_1_57 = self.backward_postevict_param(mul_1_57, self.ffn_0_down_proj_weight_58, 5) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_59 = torch.nn.functional.linear(mul_1_57, self.ffn_0_down_proj_weight_58, bias=None) + # self.postevict_param(self.ffn_0_down_proj_weight_58) + # linear_2_59 = self.backward_prefetch_param(linear_2_59, self.ffn_0_down_proj_weight_58, 5) + # del mul_1_57 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_21 = self.training + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # dropout_60 = torch.nn.functional.dropout(linear_2_59, p=0.1, training=ffn_0_dropout_training_21, inplace=False) + # del linear_2_59 + # return dropout_60 + + # dropout_60 = ckpt.checkpoint(recompute, mul_51, use_reentrant=False) + # del mul_51 + + # def recompute(dropout_60): + # # created at IRAdapterGener:local_consumer_multiref + # dropout_108, dropout_112 = nnscaler.runtime.function.multiref(dropout_60, times=2) + # del dropout_60 + # self.prefetch_param(self.ffn_1_gate_proj_weight_61) + # dropout_108 = self.backward_postevict_param(dropout_108, self.ffn_1_gate_proj_weight_61, 1) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_3_62 = torch.nn.functional.linear(dropout_108, self.ffn_1_gate_proj_weight_61, bias=None) + # self.postevict_param(self.ffn_1_gate_proj_weight_61) + # linear_3_62 = self.backward_prefetch_param(linear_3_62, self.ffn_1_gate_proj_weight_61, 1) + # del dropout_108 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_1_63 = torch.tanh(linear_3_62) + # del linear_3_62 + # self.prefetch_param(self.ffn_1_up_proj_weight_64) + # dropout_112 = self.backward_postevict_param(dropout_112, self.ffn_1_up_proj_weight_64, 3) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_4_65 = torch.nn.functional.linear(dropout_112, self.ffn_1_up_proj_weight_64, bias=None) + # self.postevict_param(self.ffn_1_up_proj_weight_64) + # linear_4_65 = self.backward_prefetch_param(linear_4_65, self.ffn_1_up_proj_weight_64, 3) + # del dropout_112 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_2_66 = torch.mul(tanh_1_63, linear_4_65) + # del tanh_1_63, linear_4_65 + # self.prefetch_param(self.ffn_1_down_proj_weight_67) + # mul_2_66 = self.backward_postevict_param(mul_2_66, self.ffn_1_down_proj_weight_67, 5) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_5_68 = torch.nn.functional.linear(mul_2_66, self.ffn_1_down_proj_weight_67, bias=None) + # self.postevict_param(self.ffn_1_down_proj_weight_67) + # linear_5_68 = self.backward_prefetch_param(linear_5_68, self.ffn_1_down_proj_weight_67, 5) + # del mul_2_66 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # ffn_1_dropout_training_40 = self.training + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # dropout_1_69 = torch.nn.functional.dropout(linear_5_68, p=0.1, training=ffn_1_dropout_training_40, inplace=False) + # del linear_5_68 + # return dropout_1_69 + + # dropout_1_69 = ckpt.checkpoint(recompute, dropout_60, use_reentrant=False) + # del dropout_60 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 242, in forward, x = x + 3 + # add_70 = torch.add(dropout_1_69, 3, alpha=1) + # del dropout_1_69 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 243, in forward, return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + # sum_1_50 = torch.sum(add_70) + # del add_70 + # return sum_1_50 + + # def segment105(self, x_49): + # with self.save_params_hooks(): + # return self.segment105_impl(x_49) + assert True diff --git a/tests/utils.py b/tests/utils.py index 07e88497..d29e4168 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -114,6 +114,7 @@ def replace_all_device_with(device='cpu', force=False): orig_to = torch.Tensor.to orig_cuda = torch.Tensor.cuda orig_cpu = torch.Tensor.cpu + orig_is_cuda = torch.Tensor.is_cuda def patch_tensor_constructor(fn): orig_func = getattr(fn, '__cube_orig_func__', fn) # to support nested patching @@ -158,6 +159,8 @@ def wrapper(*args, **kwargs): } def patched_to(self, *args, **kwargs): + if device == 'meta': + return self if len(args) > 0 and isinstance(args[0], (torch.device, str)): return orig_to(self, device, *args[1:], **kwargs) if 'device' in kwargs: @@ -166,15 +169,20 @@ def patched_to(self, *args, **kwargs): return orig_to(self, *args, **kwargs) def patched_cuda(self, *args, **kwargs): + if device == 'meta': + return self return orig_to(self, device) def patched_cpu(self, *args, **kwargs): + if device == 'meta': + return self return orig_to(self, device) try: torch.Tensor.to = patched_to torch.Tensor.cuda = patched_cuda torch.Tensor.cpu = patched_cpu + torch.Tensor.is_cuda = property(lambda self: True) # patch tensor constructors for tf_name, fn in old_tensor_constructors.items(): setattr(torch, tf_name, patched_tensor_constructors[tf_name]) @@ -205,6 +213,7 @@ def patched_cpu(self, *args, **kwargs): torch.Tensor.to = orig_to torch.Tensor.cuda = orig_cuda torch.Tensor.cpu = orig_cpu + torch.Tensor.is_cuda = orig_is_cuda # mock process group is from pytorch testing code From bb6a8191e7dbe7f7e69abaf32c0a1e5896ca37d2 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 15 Dec 2025 07:23:30 +0000 Subject: [PATCH 1854/1892] Merged PR 2424: [Misc] Refine trainer interfaces - refine autodist's check: only double check the follow relation when source and target op have backward data flow - update interfaces for hooks: on_train_step_start and on_train_step_end to enable user to change contents in inputs and outpus, which is useful in moe training - add `keep_opt` option in merge_checkpoints: when evaluating a large model (like 30B), user does not want the large optimizer part. This option helps to save the cpu memory and execution time. --- nnscaler/autodist/spmd_solver.py | 9 ++++----- nnscaler/cli/trainer.py | 20 ++++++++++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 9decc8a8..eedef569 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -686,11 +686,10 @@ def calc_partition_cost(self, op_idx: int, partition_idx: int): bw_comm_time = 0 intra_time = micro_batch_num * (fw_comm_time + bw_comm_time) # double check the follow chain - if self.get_father_id(op_idx) == self.get_father_id( - producer) and intra_time == 0: - if src_p.operator.ir_cell.mirror is not None: - if self.p_fathers[op_idx][ - partition_idx] != self.p_fathers[producer][k]: + # if `intra_time` (forward + backward) is 0, we assume both partitions are in the same follow chain + if self.get_father_id(op_idx) == self.get_father_id(producer) and intra_time == 0: + if src_p.operator.ir_cell.mirror is not None and tgt_p.operator.ir_cell.mirror is not None: + if self.p_fathers[op_idx][partition_idx] != self.p_fathers[producer][k]: _logger.warning( f'Unexpected comm cost, set to inf: {src_p.ir_cell} to {tgt_p.ir_cell}' ) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index b1768924..e92a26de 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -255,8 +255,14 @@ def reducer_pre_hook(reducer, grad): self.hook.after_setup(self) @classmethod - def _merge_checkpoint(cls, checkpoint_files: List[str]): - state_dicts = [Checkpointer.load(f) for f in checkpoint_files] + def _merge_checkpoint(cls, checkpoint_files: List[str], *, model_only: bool = False): + state_dicts = [] + for f in checkpoint_files: + state_dict = Checkpointer.load(f) + if model_only: + # we pop optimizer state to save cpu memory + state_dict.pop('optimizer', None) + state_dicts.append(state_dict) for i in range(1, len(state_dicts)): if state_dicts[i]['train_args'] != state_dicts[0]['train_args']: raise ValueError(f"train_args in {checkpoint_files[i]} is different from {checkpoint_files[0]}") @@ -265,8 +271,10 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): module_state_dict, opt_state_dict = nnscaler.merge_state_dicts( [s['model'] for s in state_dicts], - [s['optimizer'] for s in state_dicts] + [s['optimizer'] for s in state_dicts] if not model_only else None, ) + if model_only: + return {'model': module_state_dict} train_args = copy.deepcopy(state_dicts[0]['train_args']) train_args['checkpoint']['save_type'] = 'merged' @@ -369,8 +377,8 @@ def _broadcast_values(sdict, keys): return state_dict @classmethod - def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): - merged_state_dict = cls._merge_checkpoint(checkpoint_files) + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str, *, model_only: bool = False): + merged_state_dict = cls._merge_checkpoint(checkpoint_files, model_only=model_only) Checkpointer.save(merged_state_dict, output_file) def _log_finalize(self): @@ -857,7 +865,7 @@ def _train(self): torch.cuda.reset_peak_memory_stats() if self.train_status.finished_train_steps >= self.max_train_steps: - logger.info(f"Training is skipped: already done.") + logger.info(f"Training is skipped: already done, finished_train_steps={self.train_status.finished_train_steps} >= max_train_steps={self.max_train_steps}.") return start_epoch = self.train_status.finished_train_steps // self.total_train_steps_per_epoch From bb67d75ee32252fc7ac1bfa1c2d5bed8d7a33c7b Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 15 Dec 2025 07:36:06 +0000 Subject: [PATCH 1855/1892] Merged PR 2428: [BugFix] don't set requires_grad on end2end model Setting requires_grad to True can confuse scheduler, which may generate wrong code. --- nnscaler/parallel.py | 32 ++++++++++++++++++++++++-------- tests/graph/test_segment.py | 6 +++++- tests/test_policies.py | 10 ++++++++-- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 2eaf8d04..6dd57aaf 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -374,18 +374,28 @@ def _runtime_flags(**kwargs): return _flags(RuntimeFlag, **kwargs) -def _to_cpu(val: Any): - """Complex to CPU""" +def _to_cpu(val: Any, requires_grad: Optional[bool] = None) -> Any: + """ + Complex to CPU + Recursively move the input to CPU. + Args: + val (Any): the input value + requires_grad (Optional[bool]): whether the returned tensor requires grad. + If it is None, will keep the same as the input tensor. + """ if isinstance(val, tuple): - return tuple(_to_cpu(t) for t in val) + return tuple(_to_cpu(t, requires_grad) for t in val) if isinstance(val, list): - return list(_to_cpu(t) for t in val) + return list(_to_cpu(t, requires_grad) for t in val) if isinstance(val, dict): - return {_to_cpu(key):_to_cpu(val) for key, val in val.items()} + return {_to_cpu(key, requires_grad):_to_cpu(val, requires_grad) for key, val in val.items()} if isinstance(val, set): - return {_to_cpu(t) for t in val} + return {_to_cpu(t, requires_grad) for t in val} if isinstance(val, torch.Tensor): - requires_grad = val.is_floating_point() or val.is_complex() + if requires_grad is None: + requires_grad = val.requires_grad + else: + requires_grad = requires_grad and (val.is_floating_point() or val.is_complex()) return copy_dynamic(val, val.detach().clone().cpu().requires_grad_(requires_grad)) return val @@ -677,7 +687,13 @@ def _gen_graph( raise ValueError(f"Default value type {type(v)} of forward args is not supported.") # generate fx graph - dummy_forward_args = _to_cpu(dummy_forward_args) + dummy_forward_args = _to_cpu( + dummy_forward_args, + # in end2end mode, we don't need gradients for inputs + # in normal mode, we assume all inputs require gradients + # so it can connect to other parts of the graph correctly + requires_grad=not end2end_mode + ) fx_graph = parser.to_fx_graph(module, dummy_forward_args) # generate ir logic graph diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py index 108b9e3e..57c10809 100644 --- a/tests/graph/test_segment.py +++ b/tests/graph/test_segment.py @@ -167,7 +167,11 @@ def policy_nograd(graph: IRGraph, cfg: ComputeConfig) -> IRGraph: else: fc1_node = graph.nodes()[0] func_node = graph.nodes()[1] - assert fc1_node.inputs()[0].requires_grad and fc1_node.inputs()[0].grad + if cfg.use_end2end: + assert not fc1_node.inputs()[0].requires_grad and not fc1_node.inputs()[0].grad + else: + assert fc1_node.inputs()[0].requires_grad and fc1_node.inputs()[0].grad + assert fc1_node.inputs()[1].requires_grad and fc1_node.inputs()[1].grad assert fc1_node.outputs()[0].requires_grad and fc1_node.outputs()[0].grad assert func_node.inputs()[0].requires_grad and not func_node.inputs()[0].grad diff --git a/tests/test_policies.py b/tests/test_policies.py index 393ee4cc..ecf136dd 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -293,11 +293,17 @@ def test_codegen_fn_pipeline(tmp_path): # will generate two communication ops # one for ffn input - assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + if tp_idx == 0: + assert not _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + else: + assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') # one for ffn output assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.allreduce_identity') - assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.')) == 2 + if tp_idx == 0: + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.')) == 1 + else: + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.')) == 2 assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, r'ckpt.checkpoint\(recompute')) == 1 assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, r'def recompute\(')) == 1 From 244214c8d400d7f24cb0ac0ccbfe74fa6e20c71c Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 16 Dec 2025 06:33:13 +0000 Subject: [PATCH 1856/1892] Merged PR 2431: [BugFix] Handle graph output correctly in dynamic programming solver --- nnscaler/autodist/dp_solver.cpp | 79 +++++++++++++++++++++++++++----- tests/autodist/test_dp_solver.py | 26 +++++++++++ 2 files changed, 94 insertions(+), 11 deletions(-) diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index f0b8e6d7..3b8bc015 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -67,7 +67,7 @@ void ThreadPool::waitFinished() { cv_finished.wait(lock, [this]() { return tasks.empty() && (busy == 0); }); } -const int MAX_CONCURRENCY = std::thread::hardware_concurrency(); +int MAX_CONCURRENCY = std::thread::hardware_concurrency(); ThreadPool pool(MAX_CONCURRENCY); std::vector> split_work(int num, int base) { @@ -118,6 +118,11 @@ class DPSolver { queries.clear(); id2node.clear(); search_results.clear(); + if (verbose) { + MAX_CONCURRENCY = 1; + std::cout << "set MAX_CONCURRENCY to 1 for verbose mode" + << std::endl; + } } void add_interval(int start, int end) { @@ -230,6 +235,31 @@ class DPSolver { } } + int encode_ir(const std::vector> &cur_ir) { + int val = 0; + for (std::size_t j = 0; j < cur_ir.size(); ++j) { + val += cur_ir[j].second; + if (j + 1 < cur_ir.size()) { + val *= id2node[cur_ir[j + 1].first]->p_num; + } + } + return val; + } + + void print_ir(const std::vector> &cur_ir) { + for (std::size_t j = 0; j < cur_ir.size(); ++j) { + std::cout << "(" << cur_ir[j].first << ", " << cur_ir[j].second << ") "; + } + std::cout << std::endl; + } + + void print_states(DPNode *dp_node) { + for (std::size_t i = 0; i < dp_node->state.size(); ++i) { + UnitDPState state = dp_node->state[i]; + std::cout << "state " << i << ": " << state.to_string() << std::endl; + } + } + // lazy build edge void buildInEdges(DPNode *dp_node) { if (!dp_node->in_edges.empty()) { @@ -361,15 +391,14 @@ class DPSolver { break; } } + bool need_add_pre_node = false; if (!find_pre_id) { Node *pre_node = id2node[node->id - 1]; if (pre_node->father_id != node->father_id) { - // do nothing, means the pre_node's output is not used - // we select the 1st partition of the pre_node - // need to be careful when the graph has multiple outputs if (!has_found_follow && !follow_candidates.empty()) { cur_ir.push_back(*follow_candidates.rbegin()); } + need_add_pre_node = true; } else if (pre_node->father_id == pre_node->id) { assert(follow_candidates.rbegin()->first == pre_node->id); cur_ir.push_back(*follow_candidates.rbegin()); @@ -391,15 +420,36 @@ class DPSolver { } } std::sort(cur_ir.begin(), cur_ir.end()); - val = 0; - for (std::size_t j = 0; j < cur_ir.size(); ++j) { - val += cur_ir[j].second; - if (j + 1 < cur_ir.size()) { - val *= id2node[cur_ir[j + 1].first]->p_num; + if (verbose) { + std::cout << "need_add_pre_node: " << need_add_pre_node << std::endl; + } + if (need_add_pre_node) { + // means the pre_node's output is not used by later nodes, + // so we need to enumerate all the partition states of pre_node + if (verbose) { + std::cout << "p_num " << id2node[node->id - 1]->p_num << std::endl; + } + for (int pred_p = 0; pred_p < id2node[node->id - 1]->p_num; + ++pred_p) { + cur_ir.push_back(std::make_pair(node->id - 1, pred_p)); + int val = encode_ir(cur_ir); + dp_node->in_edges.push_back( + std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); + if (verbose) { + print_ir(cur_ir); + print_states(id2node[node->id - 1]->dp_nodes[val]); + } + cur_ir.pop_back(); + } + } else { + int val = encode_ir(cur_ir); + dp_node->in_edges.push_back( + std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); + if (verbose) { + print_ir(cur_ir); + print_states(id2node[node->id - 1]->dp_nodes[val]); } } - dp_node->in_edges.push_back( - std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); } } @@ -414,6 +464,9 @@ class DPSolver { return; } + if (verbose) { + std::cout << "before update, cur_p " << cur_p << std::endl; + } // storing edges takes space, so we build edges when needed buildInEdges(dp_node); if (dp_node->in_edges.empty()) { @@ -468,6 +521,10 @@ class DPSolver { } } } + if (verbose) { + std::cout << "after update" << std::endl; + print_states(dp_node); + } } void do_dp(int start_level, int end_level) { diff --git a/tests/autodist/test_dp_solver.py b/tests/autodist/test_dp_solver.py index 846e1c05..c6b172a4 100644 --- a/tests/autodist/test_dp_solver.py +++ b/tests/autodist/test_dp_solver.py @@ -37,6 +37,7 @@ def test_dp_solver(): # the optimal plan is each operator's first partition assert best.path == [(0, 0), (1, 0), (2, 0)] + def test_dp_solver_mem(): solver = dp_solver.DPSolver(True, 100, 1) solver.add_interval(0, 4) @@ -73,6 +74,7 @@ def test_dp_solver_mem(): assert best.path == [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)] assert best.memory == 71 + def test_dp_solver_build_in_edges(): # mock following code # dropout_rate = self.attention_dropout if self.training else 0.0 @@ -102,6 +104,7 @@ def test_dp_solver_build_in_edges(): best = ans[0] assert best.path == [(0, 0), (1, 0), (2, 0)] + def test_dp_solver_mem_bound(): solver = dp_solver.DPSolver(True, 10, 1) solver.add_interval(0, 2) @@ -119,3 +122,26 @@ def test_dp_solver_mem_bound(): ans = solver.get_results(0, 2) assert len(ans) == 0 + + +def test_dp_solver_output(): + solver = dp_solver.DPSolver(True, 1024, 1) + solver.add_interval(0, 2) + + solver.add_node(0, 0, [0], [], 2, False, False, False) + solver.add_partition(0, 0, 10, 16, 0, 0, 0, 0, 0, [[]]) + solver.add_partition(0, 1, 5, 8, 0, 0, 0, 0, 1, [[]]) + + solver.add_node(1, 1, [0, 1], [], 2, False, False, False) + solver.add_partition(1, 0, 4, 6, 0, 0, 0, 0, 0, [[]]) + solver.add_partition(1, 1, 2, 3, 0, 0, 0, 0, 1, [[]]) + + solver.add_node(2, 2, [2], [0], 1, False, False, False) + solver.add_partition(2, 0, 0, 0, 0, 0, 0, 0, 0, [[0, 0]]) + + solver.solve() + ans = solver.get_results(0, 2) + best = ans[0] + assert best.all_time == 7 + assert best.path == [(0, 1), (1, 1), (2, 0)] + assert best.memory == 11 From 137e28b25950dbf16a0cfb445e493f89c8f2f2c9 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 16 Dec 2025 07:49:07 +0000 Subject: [PATCH 1857/1892] Merged PR 2432: [Bugfix] return cloned tensor when a parameter is multiref'ed in zero3 The original multiref will return the parameter itself, which doesn't work when zero3 is enabled. Because after multiref returns, the parameter will be evicted from memory, so will the return values. To fix this problem, we add a new argument(clone_level) to multiref, and return a clone for parameters when zero3 is enable. --- nnscaler/codegen/module/module.py | 7 +++ nnscaler/policies.py | 5 +- nnscaler/runtime/function/function.py | 19 +++++-- nnscaler/utils.py | 1 + tests/cli/test_trainer2.py | 61 ++++++++++++++++++++++ tests/cli/trainer_args_shared_weights.yaml | 32 ++++++++++++ tests/test_policies.py | 2 +- 7 files changed, 121 insertions(+), 6 deletions(-) create mode 100644 tests/cli/trainer_args_shared_weights.yaml diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 94bc2f32..8856149d 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -945,6 +945,13 @@ def emit_node(node, node_idx): self.tensor_name(t, prefix_attr='self.') for t in node.iobjs() if isinstance(t, IRTensor) and t.is_param() ] + + # for multiref node under zero3, we need to clone the params to avoid in-place modification issue + if param_inputs and CompileFlag.use_zero > 1 and node.name == 'multiref': + _logger.warning(f'Node {node} is a multiref node with param inputs under ZeRO-3, ' + f'we set clone_level=1 to avoid in-place modification issue.') + node.kwargs['clone_level'] = 1 + code = self.emit_fnode(node, runtime_devid=runtime_devid, plan_ndevs=len(self.devices), runtime_ndevs=self.runtime_ndevs, prefix_attr='self.') if not param_inputs or CompileFlag.use_zero <= 1: diff --git a/nnscaler/policies.py b/nnscaler/policies.py index b061adca..29b2e509 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -816,8 +816,9 @@ def pas_fsdp(graph, cfg: 'ComputeConfig'): raise ValueError("FSDP policy only supports 1 plan GPU") if not cfg.use_zero: raise ValueError("FSDP policy requires use_zero to be 1/3") - - recompute_modules = cfg.pas_config.get('recompute_modules', '') + # use 'recomputes' instead of 'recompute_modules' + # to avoid confliction with autodist config + recompute_modules = cfg.pas_config.get('recomputes', '') # parse recompute_modules # user can also provide a list of Module classes. if isinstance(recompute_modules, str): diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index dbcb295f..cba44779 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -81,11 +81,24 @@ def fold_constant(a: Any) -> Any: return a -def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: +def multiref(tensor: torch.Tensor, times: int, *, clone_level: int = 0) -> Union[torch.Tensor, Tuple[torch.Tensor]]: """ identity forward. Create multiple same tensor. - """ - return tensor if times == 1 else tuple([tensor] * times) + Args: + tensor (torch.Tensor): input tensor + times (int): number of same tensor to create + clone_level (int): 0: no clone, 1: clone once for all, 2: clone each time + Returns: + Union[torch.Tensor, Tuple[torch.Tensor]]: + if times==1, return tensor; else return tuple of tensors + """ + if clone_level == 0: + return tensor if times == 1 else tuple([tensor] * times) + elif clone_level == 1: + cloned_tensor = tensor.clone() + return cloned_tensor if times == 1 else tuple([cloned_tensor] * times) + else: # clone_level == 2 + return tensor.clone() if times == 1 else tuple([tensor.clone() for _ in range(times)]) def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) -> torch.Tensor: diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 053a0b60..a88f4997 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -659,6 +659,7 @@ def mark_dynamic(tensor: torch.Tensor, dims: int | list[int] | tuple[int]) -> to Mark the dim of a tensor as dynamic, which means it can be changed in the future. This is the same with `torch._dynamo.mark_dynamic` """ + dims = [dims] if isinstance(dims, int) else dims setattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, set(dims) if dims else set()) return tensor diff --git a/tests/cli/test_trainer2.py b/tests/cli/test_trainer2.py index 9b565322..34007486 100644 --- a/tests/cli/test_trainer2.py +++ b/tests/cli/test_trainer2.py @@ -5,6 +5,7 @@ from nnscaler.cli import TrainerArgs, Trainer from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode class NanoGptDataset(Dataset): @@ -71,3 +72,63 @@ def trainer_mixed_worker(save_dir): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_mixed_bf16_model(tmp_path): launch_torchrun(2, trainer_mixed_worker, tmp_path) + + +class SharedWeightsDataset(Dataset): + def __init__(self, *args, **kwargs): + pass + + def __getitems__(self, indices): + return [torch.randn(4, 4) for _ in indices] + + def __len__(self): + return 10000 + + +class SharedWeightsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + self.linear2 = torch.nn.Linear(4, 4, bias=False) + self.linear2.weight = self.linear.weight # share weight + + def forward(self, x): + y = x * 2 + z = x + 2 + r = self.linear2(y) + r = r + self.linear(z) + return torch.sum(r) + + +def trainer_zero3_shared_weights_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args_shared_weights.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer = Trainer(train_args=args) + trainer.run() + # weight sharing multiref should have clone_level=1 in gencode + assert _gencode_contains( + gen_savedir, + SharedWeightsModule, + torch.distributed.get_rank(), + r'linear_weight_\d+, linear_weight_\d+ = nnscaler.runtime.function.multiref\(self.linear_weight_\d+, times=2, clone_level=1\)' + ) + # non-weight tensor multiref should not have clone_level + assert _gencode_contains( + gen_savedir, + SharedWeightsModule, + torch.distributed.get_rank(), + r'x_\d+, x_\d+ = nnscaler.runtime.function.multiref\(x_\d+, times=2\)' + ) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_zero3_shared_weights(tmp_path): + launch_torchrun(4, trainer_zero3_shared_weights_worker, tmp_path) diff --git a/tests/cli/trainer_args_shared_weights.yaml b/tests/cli/trainer_args_shared_weights.yaml new file mode 100644 index 00000000..2dbd6f7a --- /dev/null +++ b/tests/cli/trainer_args_shared_weights.yaml @@ -0,0 +1,32 @@ +compute_config: + plan_ngpus: 2 + runtime_ngpus: 4 + constant_folding: false + use_zero: 3 + use_end2end: true + +run_mode: run +pas_policy: tp +micro_batch_size: 1 +grad_accumulation_steps: 4 +max_train_steps: 10 +enable_progress_bar: false +log_progress_every_n_train_steps: 10 +precision: bf16 +seed: 1 + +model: + type: tests.cli.test_trainer2.SharedWeightsModule + +optimizer: + type: torch.optim.AdamW + args: + betas: (0.9, 0.95) + eps: 1e-08 + weight_decay: 0.1 + lr: 0.0001 + fused: true + clip_gnorm: 2.0 + +dataset: + type: tests.cli.test_trainer2.SharedWeightsDataset diff --git a/tests/test_policies.py b/tests/test_policies.py index ecf136dd..003adc0c 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1002,7 +1002,7 @@ def test_codegen_fsdp(tmp_path): use_end2end=True, use_zero=3, pas_config={ - 'recompute_modules': [FFNDropout], + 'recomputes': [FFNDropout], } ), gen_savedir=tmp_path, From 26744c06c7bcbe8c6763ced02cd0d65104d09396 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 16 Dec 2025 08:38:10 +0000 Subject: [PATCH 1858/1892] Merged PR 2427: [BugFix] auto pack kwargs in dummy input --- nnscaler/graph/tracer/concrete_tracer.py | 8 +++++- tests/graph/tracer/test_pack_kwargs.py | 33 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 tests/graph/tracer/test_pack_kwargs.py diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index 5a86969c..85f7c8f6 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -461,7 +461,6 @@ def proxy_placeholder(name: str): return self.create_proxy('placeholder', name, default_arg, {}) args.extend(proxy_placeholder(names) for names in arg_names) - if hasattr(co, 'co_kwonlyargcount') and ( co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF): # TODO: type annotations for *args and **kwargs @@ -471,6 +470,13 @@ def proxy_placeholder(name: str): more_args = proxy_placeholder(name) if co.co_flags & inspect.CO_VARKEYWORDS: name = '**' + next(names_iter) + if name not in concrete_args: + # auto pack the additional kwargs + kwargs_val = {} + for cc_name in concrete_args: + if cc_name not in arg_names and not cc_name.startswith('*'): + kwargs_val[cc_name] = concrete_args[cc_name] + concrete_args[name] = kwargs_val default_args[name] = {} kwargs = proxy_placeholder(name) diff --git a/tests/graph/tracer/test_pack_kwargs.py b/tests/graph/tracer/test_pack_kwargs.py new file mode 100644 index 00000000..4db75d5d --- /dev/null +++ b/tests/graph/tracer/test_pack_kwargs.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from torch.nn import Module + +from nnscaler.graph.tracer import concrete_trace +from ...utils import replace_all_device_with + + +class Model(Module): + def __init__(self): + super(Model, self).__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, **kwargs): + return self.linear(kwargs['input']) + + +@replace_all_device_with('cpu') +def test_pack_kwargs(): + model = Model() + example_inputs = {'input': torch.randn(1, 10)} + traced_model = concrete_trace(model, example_inputs) + assert list(traced_model.graph.nodes)[0].target == '**kwargs' + + +@replace_all_device_with('cpu') +def test_direct_kwargs(): + model = Model() + example_inputs = {'**kwargs': {'input': torch.randn(1, 10)}} + traced_model = concrete_trace(model, example_inputs) + assert list(traced_model.graph.nodes)[0].target == '**kwargs' From 1c1c8813b802b3a4ff8f83c7a2a6c0cf7e73dbb3 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 18 Dec 2025 08:30:42 +0000 Subject: [PATCH 1859/1892] Merged PR 2435: [Bugfix] Fix the generated code when **kwargs are used in module.forward Two cases when generating code for `**kwargs`: 1. As function argument(function declaration): keep `**` 2. As function parameter(calling function): remove `**` --- docs/source/parallel_module.md | 2 +- nnscaler/codegen/emit.py | 26 ++- nnscaler/codegen/module/module.py | 4 +- tests/parallel_module/test_gencode_kwargs.py | 189 +++++++++++++++++++ 4 files changed, 211 insertions(+), 10 deletions(-) create mode 100644 tests/parallel_module/test_gencode_kwargs.py diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 4b910a20..7ec8a0a5 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -7,7 +7,7 @@ Currently, we support three kinds of parallelism: data parallelism, tensor paral Data parallelism and tensor parallelism can be supported for any module, but pipeline parallelism is only supported for end2end modules for scheduling reason. An end2end module is a module which satisfies: -- the first argument of `module.forward` is the data sample, and every other argument should have default value, and use its default value in `module.forward` function. +- the first argument of `module.forward` is the data sample, and other arguments should have default value, and should never be used in `module.forward` function. - the first return value of `module.forward` is the loss (scalar tensor) The above restrictions are necessary for the pipeline parallelism to work. Of course, you can still use the parallel module without pipeline parallelism for end2end modules. diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index 00c73060..3f6e8652 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -36,7 +36,10 @@ def __repr__(self): return self.name -def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: +def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None, + *, + strip_star: bool = True, +) -> Any: """ Return repr-able value of a tensor or value. For tensor, return IRValue({prefix}{tensor.name}_{tensor.tid}) @@ -45,6 +48,7 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: Args: val (Any): tensor or non-tensor value prefix_attr (str): prefix to the tensor name if the tensor is an attribute + strip_star (bool): whether to strip leading * for *args and **kwargs Returns: the val that can be repr safely """ @@ -52,20 +56,22 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: return val if isinstance(val, IRObject): tensor_name = val.name + if strip_star: + tensor_name = tensor_name.lstrip('*') tensor_name = tensor_name.replace('.', '_') name = '_'.join([tensor_name, str(val.tid)]) if prefix_attr is not None and val.is_attr(): name = prefix_attr + name return IRValue(name) elif isinstance(val, slice): - return slice(_safe_repr_value(val.start, prefix_attr), _safe_repr_value(val.stop, prefix_attr), _safe_repr_value(val.step, prefix_attr)) + return slice(_safe_repr_value(val.start, prefix_attr, strip_star=strip_star), _safe_repr_value(val.stop, prefix_attr, strip_star=strip_star), _safe_repr_value(val.step, prefix_attr, strip_star=strip_star)) elif isinstance(val, dict): - return {_safe_repr_value(k, prefix_attr): _safe_repr_value(v, prefix_attr) for k, v in val.items()} + return {_safe_repr_value(k, prefix_attr, strip_star=strip_star): _safe_repr_value(v, prefix_attr, strip_star=strip_star) for k, v in val.items()} elif isinstance(val, list): - return [_safe_repr_value(v, prefix_attr) for v in val] + return [_safe_repr_value(v, prefix_attr, strip_star=strip_star) for v in val] elif isinstance(val, tuple): # TODO: support subclasses of tuple, like torch.Size? - return tuple(_safe_repr_value(v, prefix_attr) for v in val) + return tuple(_safe_repr_value(v, prefix_attr, strip_star=strip_star) for v in val) elif isinstance(val, (int, str, bool, float, type(None), bytes, type(Ellipsis), torch.dtype)): return val elif isinstance(val, torch.device): @@ -90,7 +96,10 @@ class CodeEmission: def node_name(self, node: IRCell) -> str: return f"{node.name}{node.cid}" - def tensor_name(self, val: Any, prefix_attr: Optional[str] = None) -> str: + def tensor_name(self, val: Any, prefix_attr: Optional[str] = None, + *, + strip_star: bool = True, + ) -> str: """ Return representation of a value or a tensor. For tensor, return the {prefix}{tensor.name}_{tensor.tid} @@ -99,10 +108,13 @@ def tensor_name(self, val: Any, prefix_attr: Optional[str] = None) -> str: Args: val (Any): tensor or non-tensor value prefix_attr (Optional[str]): prefix to the tensor name if the tensor is an attribute + strip_star (bool): whether to strip leading * for *args and **kwargs + You should set it to False when you want to generate code for + function arguments. Returns: representation of the val in str """ - return repr(_safe_repr_value(val, prefix_attr)) + return repr(_safe_repr_value(val, prefix_attr, strip_star=strip_star)) def complex_name(self, val: Any, prefix_attr: Optional[str]=None) -> str: """ diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 8856149d..c3bc8c34 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -500,9 +500,9 @@ def forward(self, x, y=None, z=None): for t in node.inputs(): if isinstance(t, IRSubTensor): if not t.is_attr(): - args.append(self.tensor_name(t)) + args.append(self.tensor_name(t, strip_star=False)) else: - args.append(self.tensor_name(t)) + args.append(self.tensor_name(t, strip_star=False)) node_args.append(args) if outfile_attr_meta_map: diff --git a/tests/parallel_module/test_gencode_kwargs.py b/tests/parallel_module/test_gencode_kwargs.py new file mode 100644 index 00000000..85a1b3c5 --- /dev/null +++ b/tests/parallel_module/test_gencode_kwargs.py @@ -0,0 +1,189 @@ +import nnscaler +from nnscaler import parallelize, ComputeConfig + +import torch + +from .test_gencode import _gencode_contains, replace_all_device_with, print_gencode + + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='kw_operator') +def kw_operator(x: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='kw_operator2') +def kw_operator2(x: torch.Tensor, y: torch.Tensor, kwargs) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + + +class KwargsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, **kwargs): + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + c = kwargs['c'] + return kw_operator(x, self.scale, **kwargs) \ + + kw_operator2(x, self.scale, kwargs) + a + b + c + + +@replace_all_device_with('cpu') +def test_kwargs(tmp_path): + m = KwargsModule() + m.train() + parallelize( + m, + {'x': torch.randn(4, 4), 'a': 3, 'c': 4, 'd': 5}, + 'dp', + ComputeConfig(1, 1, constant_folding=False), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'a\'\)') + assert not _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'b\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'c\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'d\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'tests.parallel_module.test_gencode_kwargs.kw_operator\(x_\d+, self.scale_\d+, a=getitem_[\d_]+, c=getitem_[\d_]+, d=getitem_[\d_]+\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r"tests.parallel_module.test_gencode_kwargs.kw_operator2\(x_\d+, self.scale_\d+, kwargs=\{'a': getitem_[\d_]+, 'c': getitem_[\d_]+, 'd': getitem_[\d_]+\}\)") + # code looks like: + # def segment49(self, x_31, **kwargs_6): + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_28 = _operator.getitem(kwargs_6, 'a') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_1_29 = _operator.getitem(kwargs_6, 'c') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_2_30 = _operator.getitem(kwargs_6, 'd') + # # created at IRAdapterGener:local_consumer_multiref + # x_52, x_56 = nnscaler.runtime.function.multiref(x_31, times=2) + # del x_31 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # kw_operator_34 = tests.parallel_module.test_gencode_kwargs.kw_operator(x_52, self.scale_33, a=getitem_28, c=getitem_1_29, d=getitem_2_30) + # del x_52 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # kw_operator2_35 = tests.parallel_module.test_gencode_kwargs.kw_operator2(x_56, self.scale_33, kwargs={'a': getitem_28, 'c': getitem_1_29, 'd': getitem_2_30}) + # del x_56 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_36 = torch.add(kw_operator_34, kw_operator2_35, alpha=1) + # del kw_operator_34, kw_operator2_35 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_1_37 = torch.add(add_36, getitem_28, alpha=1) + # del add_36 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_2_38 = torch.add(add_1_37, 2, alpha=1) + # del add_1_37 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_3_32 = torch.add(add_2_38, getitem_1_29, alpha=1) + # del add_2_38 + # return add_3_32 + + # def _forward_impl(self, x, **kwargs): + # add_3_32 = self.segment49(x, **kwargs) + # return add_3_32 + assert True + + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='dict_operator') +def dict_operator(x: torch.Tensor, y: torch.Tensor, kwargs: dict) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + + +class DictargsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, kwargs: dict): + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + c = kwargs['c'] + return dict_operator(x, self.scale, kwargs) \ + + kw_operator(x, self.scale, **kwargs) + a + b + c + + +@replace_all_device_with('cpu') +def test_dictargs(tmp_path): + m = DictargsModule() + m.train() + parallelize( + m, + {'x': torch.randn(4, 4), 'kwargs': {'a': 3, 'c': 4, 'd': 5}}, + 'dp', + ComputeConfig(1, 1, constant_folding=False), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"builtins.dict.get\(kwargs_.*, 'a', 1\)") + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"builtins.dict.get\(kwargs_.*, 'b', 2\)") + assert len(_gencode_contains(tmp_path, DictargsModule, 0, + r"_operator\.getitem\(kwargs_.*, 'a'\)")) == 1 + assert len(_gencode_contains(tmp_path, DictargsModule, 0, + r"_operator\.getitem\(kwargs_.*, 'c'\)")) == 2 + assert _gencode_contains(tmp_path, DictargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'d\'\)') + assert _gencode_contains(tmp_path, DictargsModule, 0, + r'tests.parallel_module.test_gencode_kwargs.kw_operator\(x_\d+, self.scale_\d+, a=getitem_[\d_]+, c=getitem_[\d_]+, d=getitem_[\d_]+\)') + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"tests.parallel_module.test_gencode_kwargs.dict_operator\(x_\d+, self.scale_\d+, kwargs=kwargs_\d+\)") + # code looks like: + # def segment52(self, x_35, kwargs_6): + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2411, in forward, a = kwargs.get('a', 1) + # get_7 = builtins.dict.get(kwargs_6, 'a', 1) + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2412, in forward, b = kwargs.get('b', 2) + # get_1_8 = builtins.dict.get(kwargs_6, 'b', 2) + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2413, in forward, c = kwargs['c'] + # getitem_31 = _operator.getitem(kwargs_6, 'c') + # # created at IRAdapterGener:local_consumer_multiref + # x_56, x_60 = nnscaler.runtime.function.multiref(x_35, times=2) + # del x_35 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # dict_operator_38 = tests.parallel_module.test_gencode_kwargs.dict_operator(x_56, self.scale_37, kwargs=kwargs_6) + # del x_56 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_1_32 = _operator.getitem(kwargs_6, 'a') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_2_33 = _operator.getitem(kwargs_6, 'c') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_3_34 = _operator.getitem(kwargs_6, 'd') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # kw_operator_39 = tests.parallel_module.test_gencode_kwargs.kw_operator(x_60, self.scale_37, a=getitem_1_32, c=getitem_2_33, d=getitem_3_34) + # del x_60 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_40 = torch.add(dict_operator_38, kw_operator_39, alpha=1) + # del dict_operator_38, kw_operator_39 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_1_41 = torch.add(add_40, get_7, alpha=1) + # del add_40 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_2_42 = torch.add(add_1_41, get_1_8, alpha=1) + # del add_1_41 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_3_36 = torch.add(add_2_42, getitem_31, alpha=1) + # del add_2_42 + # return add_3_36 + + # def _forward_impl(self, x, kwargs): + # add_3_36 = self.segment52(x, kwargs) + # return add_3_36 From c6dd773b00cecae1bbea4483236042b41af33de1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 23 Dec 2025 07:03:08 +0000 Subject: [PATCH 1860/1892] Merged PR 2438: [BwCompat] Fix backward compatiblity when metadata lacks zero version Fix backward compatiblity (found in nightly test for fairseq) when metadata lacks zero version, in which case zero1 is assumably used. --- nnscaler/runtime/module.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 550632d9..9d286489 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -704,7 +704,12 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size, return opt_states def _merge_opt_zero(param_shape, worker_idx, param_idx): - model_idx2opt_idx, opt_idx2ranks, zero_version = zero_idx_maps[worker_idx] + if len(zero_idx_maps[worker_idx]) == 3: + model_idx2opt_idx, opt_idx2ranks, zero_version = zero_idx_maps[worker_idx] + else: # backward compatibility + assert len(zero_idx_maps[worker_idx]) == 2 + model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[worker_idx] + zero_version = 1 # default to ZeRO-1 opt_idx = model_idx2opt_idx[param_idx] if isinstance(opt_idx, int): # the param without reducer From 1d88d04556671b70fcef7ce528bc22860aa831ba Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 25 Dec 2025 00:20:45 +0000 Subject: [PATCH 1861/1892] Merged PR 2442: [BugFix] Rebuild param groups after resuming in HybridOptimizer --- nnscaler/runtime/hybrid_optimizer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/nnscaler/runtime/hybrid_optimizer.py b/nnscaler/runtime/hybrid_optimizer.py index cd0379a4..e120c3df 100644 --- a/nnscaler/runtime/hybrid_optimizer.py +++ b/nnscaler/runtime/hybrid_optimizer.py @@ -240,6 +240,13 @@ def load_state_dict(self, state_dict) -> None: for child_state_dict, opt in zip(child_state_dicts, self.optimizers): opt.load_state_dict(child_state_dict) + # after loading from state dict, the param_groups of optimizers are reassigned + # (instead of updated inplace), so we need to gather them again (as we have done + # in the constructor). + self.param_groups = [] + for optimizer in self.optimizers: + self.param_groups.extend(optimizer.param_groups) + def add_param_group(self, param_group: dict[str, Any]) -> None: # no-op to avoid creating new parameter groups # all parameter groups are managed by the individual optimizers From dbcc36dc98ed488fd0f44e8365114ac4a1e6c44b Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 29 Dec 2025 01:24:34 +0000 Subject: [PATCH 1862/1892] Merged PR 2443: [Feat] Add option disable_shared_param_constraint and recompute_ratio to autodist --- nnscaler/autodist/autodist_config.py | 14 ++++++++++++++ nnscaler/autodist/model_graph.py | 3 +++ nnscaler/autodist/spmd_solver.py | 3 +++ nnscaler/graph/parser/external/apex.py | 3 ++- nnscaler/policies.py | 12 ++++++++++++ 5 files changed, 34 insertions(+), 1 deletion(-) diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index b790e459..de3eb351 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -72,6 +72,12 @@ class AutoDistConfig: `x.module1` will match `x.module1` but not `y.module1`. Due to constraint of the tracer, you can pass `ROOT` to recompute_modules if you want the whole module to be recomputed. + - recompute_ratio ('float`, *optional*, defaults to `1.0`): + When `recompute_modules` only contains one name (excluding `ROOT`), this specify the ratio of modules + to be recomputed. For example, if `module1` is specified in `recompute_modules` and `recompute_ratio` is `0.8`, + only 80% of `module1` instances will be recomputed. + If there are multiple module names in `recompute_modules`, this field will be ignored and all specified modules + will be recomputed. - memory_constraint (`float`, *optional*, defaults to `32`): The memory constraint in each device in GB. - memory_granularity (`int`, *optional*, defaults to `1`): @@ -115,6 +121,10 @@ class AutoDistConfig: `transient_mem_size = opt_transient_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)`. This formula is useful in many cases, but it may be too strict when some operators consume or generate a large tensor (>= 4GB). In this case, you can set `transient_mem_coef` to a smaller value to relax the constraint. + - disable_shared_param_constraint (`bool`, *optional*, defaults to `False`): + Whether to disable the shared parameter constraint in spmd solver. When a parameter is shared by multiple modules, + the spmd solver will force the parameter to be replicated to complicated adapter generation. However, user can disable + it and provide customized partition constraints for those shared parameters. """ def __init__(self, @@ -133,6 +143,7 @@ def __init__(self, mesh_row=1, mesh_col=1, recompute_modules='', + recompute_ratio=1.0, memory_constraint=32, memory_granularity=1, micro_batch_size=1, @@ -150,6 +161,7 @@ def __init__(self, solver='dp', parallel_profile=True, transient_mem_coef=2, + disable_shared_param_constraint=False, **kwargs): self.pc_path = partition_constraints_path self.profile_dir = profile_dir @@ -166,6 +178,7 @@ def __init__(self, self.is_train = is_train self.mesh_desc = MeshDesc(mesh_row, mesh_col) self.recompute_modules = recompute_modules + self.recompute_ratio = recompute_ratio # from GB to Byte self.memory_constraint = int(memory_constraint * 1024 * 1024 * 1024) self.memory_granularity = memory_granularity @@ -192,6 +205,7 @@ def __init__(self, self.solver = 'dp' self.parallel_profile = parallel_profile self.transient_mem_coef = transient_mem_coef + self.disable_shared_param_constraint = disable_shared_param_constraint ignored_keys = list(kwargs.keys()) if ignored_keys: diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 44be8ca2..7c9eec82 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -718,6 +718,9 @@ def fetch_module(scope_node: ScopeNode, prefix: List[str]): modules = [self.scope_tree_root] else: modules = fetch_module(self.scope_tree_root, []) + if len(recompute_modules) == 1 and self.autodist_config.recompute_ratio < 1.0: + boundary = max(1, int(len(modules) * self.autodist_config.recompute_ratio)) + modules = modules[:boundary] train_mem = 0 for module in modules: diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index eedef569..608e1fd2 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -261,6 +261,9 @@ def should_force_replica(operator: CubeOperator) -> bool: if len(consumers) == 1: continue _logger.info(f'find shared parameter {param} in {consumers}') + if self.autodist_config.disable_shared_param_constraint: + _logger.info(f'disable shared parameter constraint for {param}') + continue for consumer in consumers: if not isinstance(consumer, IRDimops): # always replicate non-dimops diff --git a/nnscaler/graph/parser/external/apex.py b/nnscaler/graph/parser/external/apex.py index 94b22209..e7d321d1 100644 --- a/nnscaler/graph/parser/external/apex.py +++ b/nnscaler/graph/parser/external/apex.py @@ -83,6 +83,7 @@ def apex_fused_rms_norm_affine_anno(input, weight, normalized_shape, eps, *args, parser.register(apex_fused_layer_norm_affine_anno)(fused_layer_norm_affine) parser.register(apex_fused_rms_norm_anno)(fused_rms_norm) parser.register(apex_fused_rms_norm_affine_anno)(fused_rms_norm_affine) + _logger.info("apex ops registered successfully.") except: - _logger.warning('skip apex ops as it is not installed.') + _logger.debug('skip apex ops as it is not installed.') diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 29b2e509..b4cf13af 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -227,6 +227,8 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: + from nnscaler.autodist.util import get_default_profile_path + pas_cfg = cfg.pas_config update_freq = pas_cfg.get('update_freq', 1) @@ -274,18 +276,24 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: use_memory_efficient_bf16 = pas_cfg.get('use_memory_efficient_bf16', False) use_fp16 = pas_cfg.get('use_fp16', use_memory_efficient_fp16) use_bf16 = pas_cfg.get('use_bf16', use_memory_efficient_bf16) + profile_dir = pas_cfg.get('profile_dir', None) + if profile_dir is None: + profile_dir = get_default_profile_path() re_profile = pas_cfg.get('re_profile', False) verbose = pas_cfg.get('verbose', False) load_plan_path = pas_cfg.get('load_plan_path', None) save_plan_path = pas_cfg.get('save_plan_path', None) partition_constraints_path = pas_cfg.get('partition_constraints_path', '') recompute_modules = pas_cfg.get('recompute_modules', '') + recompute_ratio = pas_cfg.get('recompute_ratio', 1.0) pipeline_pivots = pas_cfg.get('pipeline_pivots', '') max_pipeline_bubble_ratio = pas_cfg.get('max_pipeline_bubble_ratio', 0.2) max_pipeline_unbalance_ratio = pas_cfg.get('max_pipeline_unbalance_ratio', 0.5) use_apex_fused_adam_v2 = pas_cfg.get('use_apex_fused_adam_v2', False) parallel_profile = pas_cfg.get('parallel_profile', True) transient_mem_coef = pas_cfg.get('transient_mem_coef', 2) + disable_shared_param_constraint = pas_cfg.get('disable_shared_param_constraint', False) + solver = pas_cfg.get('solver', 'dp') task_name = f'{task_name}_{cfg.plan_ngpus}gpus_{update_freq}update_freq' if memory_constraint == -1: @@ -340,8 +348,10 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: opt_transient_coef=opt_transient_coef, verbose=verbose, re_profile=re_profile, + profile_dir=profile_dir, world_size=cfg.runtime_ngpus, recompute_modules=recompute_modules, + recompute_ratio=recompute_ratio, zero_stage=zero_stage, zero_ngroups=zero_ngroups, load_plan_path=load_plan_path, @@ -352,6 +362,8 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: max_pipeline_unbalance_ratio=max_pipeline_unbalance_ratio, parallel_profile=parallel_profile, transient_mem_coef=transient_mem_coef, + disable_shared_param_constraint=disable_shared_param_constraint, + solver=solver, ) return parallelize_graph(graph, autodist_cfg) From d1d89016bdd17e0b107a596128cbb443729207c9 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 30 Dec 2025 00:17:31 +0000 Subject: [PATCH 1863/1892] Merged PR 2437: [Feat] cli: add customized serialization runner Add serialization runner to support different serialization strategy, like chunking, async, etc. --- nnscaler/cli/checkpoint.py | 20 +- nnscaler/cli/mixed_module.py | 22 +- nnscaler/cli/serialization.py | 367 ++++++++++++++++++++++++---- nnscaler/cli/trainer.py | 69 ++++-- nnscaler/cli/trainer_args.py | 41 +++- nnscaler/runtime/serialization.py | 7 +- tests/cli/test_serialization.py | 245 +++++++++++++++++++ tests/cli/test_trainer.py | 38 +-- tests/runtime/test_serialization.py | 7 + 9 files changed, 710 insertions(+), 106 deletions(-) create mode 100644 tests/cli/test_serialization.py diff --git a/nnscaler/cli/checkpoint.py b/nnscaler/cli/checkpoint.py index 3f596bbc..4d080620 100644 --- a/nnscaler/cli/checkpoint.py +++ b/nnscaler/cli/checkpoint.py @@ -2,7 +2,14 @@ # Licensed under the MIT License. """ -Only for command line +This script provides functionality to convert a merged checkpoint or directory containing per-rank checkpoints into sharded checkpoints +suitable for distributed training with multiple GPUs. +Run this script with: + python -m nnscaler.cli.checkpoint distribute -f +where is the path to the merged checkpoint file or directory containing per-rank checkpoints, +and is the directory to save the sharded checkpoints. + +This script only for command line. """ import logging @@ -114,21 +121,24 @@ def _distribute_checkpoint(train_args: TrainerArgs, from_: str, to_: str): resume_from = Path(from_) save_to = Path(to_) save_to.mkdir(parents=True, exist_ok=True) + checkpointer = train_args.create_checkpointer() if resume_from.is_file(): - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + state_dict = checkpointer.load(resume_from, device='cpu') if convert_fn := train_args.checkpoint.resolved_convert_fn: state_dict = convert_fn(state_dict) else: - ckpt_files = list(resume_from.glob('*.ckpt')) + ckpt_files = checkpointer.list_checkpoints(resume_from) rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} if set(rank_ckpt_files.keys()) != set(range(len(rank_ckpt_files))): raise ValueError(f"Checkpoint files in {resume_from} are not complete: {rank_ckpt_files.keys()}") - state_dict = Trainer._merge_checkpoint(list(rank_ckpt_files.values())) + state_dict = Trainer._merge_checkpoint(list(rank_ckpt_files.values()), checkpointer=checkpointer) for i in range(train_args.compute_config.runtime_ngpus): sharded_state_dict = _trim_merged_checkpoint(train_args, state_dict, i) - torch.save(sharded_state_dict, save_to / f"{i}.ckpt") + checkpointer.save_for_rank(sharded_state_dict, save_to, i) + + checkpointer.flush() if __name__ == '__main__': diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index 81deb84e..ef4268ca 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -42,7 +42,8 @@ class ModuleParallelizeConfigAdapter(PrecisionMixin, PolicyMixin): def __init__( self, trainer_args: TrainerArgs, parallel_module: Optional[ModuleParallelizeConfig] = None, - tracing_weights: Optional[dict[str, Any]] = None + tracing_weights: Optional[dict[str, Any]] = None, + checkpointer: Optional[Checkpointer] = None, ): """ Args: @@ -53,6 +54,7 @@ def __init__( self.trainer_args = trainer_args self.parallel_module = parallel_module self.tracing_weights = tracing_weights + self.checkpointer = checkpointer or Checkpointer() # we don't want to load the tracing weights every time # It should be loaded only once outside, and passed to the adapter @@ -133,10 +135,10 @@ def load_tracing_weights(self) -> Optional[dict[str, Any]]: # try to reuse the weights from the tracing weights tracing_weights = self.tracing_weights if self.tracing_from_weights and tracing_weights is None: - tracing_weights = Checkpointer.load(self.tracing_from_weights) + tracing_weights = self.checkpointer.load(self.tracing_from_weights) else: if self.tracing_from_weights: - tracing_weights = Checkpointer.load(self.tracing_from_weights) + tracing_weights = self.checkpointer.load(self.tracing_from_weights) elif self.parallel_module.tracing_from_weights_prefix: leading_key = self.parallel_module.tracing_from_weights_prefix + '.' tracing_weights = {} @@ -289,15 +291,23 @@ def parameters_for_calc_gnorm(self): return model -def parallelize_model(trainer_args: TrainerArgs, dummy_input: dict[str, Any], load_module: bool, build_buckets: bool): +def parallelize_model( + trainer_args: TrainerArgs, + dummy_input: dict[str, Any], + load_module: bool, + build_buckets: bool, + checkpointer: Checkpointer +): tracing_weights = None + checkpointer = checkpointer or Checkpointer() if trainer_args.tracing_from_weights: - tracing_weights = Checkpointer.load(trainer_args.tracing_from_weights) + tracing_weights = checkpointer.load(trainer_args.tracing_from_weights) def _new_adapter(parallel_module=None): return ModuleParallelizeConfigAdapter( trainer_args, parallel_module, - tracing_weights=tracing_weights + tracing_weights=tracing_weights, + checkpointer=checkpointer, ) if not trainer_args.model.parallel_modules: diff --git a/nnscaler/cli/serialization.py b/nnscaler/cli/serialization.py index bc90e59a..c499740f 100644 --- a/nnscaler/cli/serialization.py +++ b/nnscaler/cli/serialization.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Any, Callable, Protocol +from typing import Any, Callable, Protocol, Type from pathlib import Path +import shutil import torch @@ -17,31 +18,189 @@ class _SaveProc(Protocol): def __call__(self, obj: Any, f: str | Path) -> None: ... +class CheckpointFormat(Protocol): + """ + A placeholder class for new serialization formats. + """ + name: str + suffix: str + + @classmethod + def load(cls, f: str | Path, *, device='cpu') -> Any: + ... + + @classmethod + def save(cls, obj: Any, f: str | Path) -> None: + ... + + +class SerializationRunner(Protocol): + name: str + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + ... + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + ... + + def flush(self) -> None: + """ + Flushes any pending operations for saving. + Loading operations are assumed to be synchronous. + """ + ... + + +class _DefaultSerializationRunner: + name: str = '' + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + return load_func(f, device=device) + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + save_func(obj, f) + + def flush(self) -> None: + pass + + +def make_hybrid_serialization_runner( + load_serializer: Type[SerializationRunner], + save_serializer: Type[SerializationRunner] +) -> Type[SerializationRunner]: + """ + Creates a hybrid serialization runner that uses different runners for loading and saving. + """ + class HybridSerializationRunner(SerializationRunner): + name = f"{load_serializer.name}:{save_serializer.name}" + + def __init__(self, load_args=None, save_args=None): + self._load_runner = load_serializer(**(load_args or {})) + self._save_runner = save_serializer(**(save_args or {})) + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + return self._load_runner.run_load(load_func, f, device=device) + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + self._save_runner.run_save(save_func, obj, f) + + def flush(self) -> None: + self._save_runner.flush() + + return HybridSerializationRunner + + +def _torch_load(f: str | Path, *, device='cpu') -> Any: + return torch.load(f, map_location=device, weights_only=False) + + +def _torch_save(obj: Any, f: str | Path) -> None: + torch.save(obj, f) + + class Checkpointer: # the format of the checkpoint file # keys: epoch, step, rank # currently it is not configurable # TODO: make it configurable - CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/{rank}{suffix}' + CHECKPOINT_FILE_NAME_FORMAT: str = '{rank}{suffix}' + CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/' + CHECKPOINT_FILE_NAME_FORMAT CHECKPOINT_LAST_DIR_NAME: str = 'last' CHECKPOINT_BEST_DIR_NAME: str = 'best' CHECKPOINT_MERGED_FILE_NAME: str = 'merged{suffix}' CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}{suffix}' CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}{suffix}' - SUFFIX_MAP: dict[str, str] = { + NAME_MAP: dict[str, str] = { 'pt': '.ckpt', 'safetensors': '.safetensors' } + SUFFIX_MAP: dict[str, str] = {v: k for k, v in NAME_MAP.items()} # will use torch.load and torch.save for other suffixes SUFFIX_HANDLERS: dict[str, tuple[_LoadProc, _SaveProc]] = { '.safetensors': (load, save), } + REGISTERED_RUNNERS: dict[str, Type[SerializationRunner]] = { + '': _DefaultSerializationRunner, + } - def __init__(self, format: str = 'pt'): - if format not in self.SUFFIX_MAP: + def __init__(self, format: str = 'pt', serializer: str = None, serializer_args: dict[str, Any] = None): + """ + Args: + format (`str`, *optional*, defaults to `"pt"`): + The checkpoint format to use. Builtin formats are: + - `"pt"`: PyTorch checkpoint format. + - `"safetensors"`: Safetensors format. + serializer (`str`, *optional*): + The serialization runner to use. Builtin runners are: + - `""` (empty string): Default runner that directly uses the load and save functions. + You can also specify a hybrid runner by using the format `load_serializer:save_serializer`, + e.g., `"split:async"`. + serializer_args (`dict`, *optional*): + args for the serialization runner. + """ + if format not in self.NAME_MAP: raise ValueError(f"Unsupported checkpoint format: {format}") self.format = format - self.suffix = self.SUFFIX_MAP[format] + self.suffix = self.NAME_MAP[format] + + self.runner: SerializationRunner + serializer = serializer or '' + + if ':' in serializer: + parts = serializer.split(':') + if len(parts) != 2: + raise ValueError(f"Invalid hybrid serialization runner: {serializer}") + load_serializer_name = parts[0] + save_serializer_name = parts[1] + if load_serializer_name not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {load_serializer_name}") + if save_serializer_name not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {save_serializer_name}") + load_serializer_type = self.REGISTERED_RUNNERS[load_serializer_name] + save_serializer_type = self.REGISTERED_RUNNERS[save_serializer_name] + runner_cls = make_hybrid_serialization_runner( + load_serializer_type, + save_serializer_type + ) + else: + if serializer not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {serializer}") + runner_cls = self.REGISTERED_RUNNERS[serializer] + + self.runner = runner_cls(**(serializer_args or {})) def get_checkpoint_file_path(self, epoch: int, step: int, rank: int) -> str: return self.CHECKPOINT_FILE_FORMAT.format(epoch=epoch, step=step, rank=rank, suffix=self.suffix) @@ -61,8 +220,7 @@ def get_last_dir_name(self) -> str: def get_best_dir_name(self) -> str: return self.CHECKPOINT_BEST_DIR_NAME - @classmethod - def load(cls, f: str | Path, *, device='cpu') -> Any: + def load(self, f: str | Path, *, device='cpu') -> Any: """ Loads a checkpoint file @@ -74,14 +232,14 @@ def load(cls, f: str | Path, *, device='cpu') -> Any: The device on which you want the tensors. """ suffix = Path(f).suffix - if suffix in cls.SUFFIX_HANDLERS: - load_func, _ = cls.SUFFIX_HANDLERS[suffix] - return load_func(f, device=device) + if suffix in self.SUFFIX_HANDLERS: + load_func, _ = self.SUFFIX_HANDLERS[suffix] else: - return torch.load(f, map_location=device, weights_only=False) + load_func = _torch_load - @classmethod - def save(cls, obj: Any, f: str | Path) -> None: + return self.runner.run_load(load_func, f, device=device) + + def save(self, obj: Any, f: str | Path) -> None: """ Saves a checkpoint file @@ -93,14 +251,14 @@ def save(cls, obj: Any, f: str | Path) -> None: otherwise, it will be saved as a PyTorch checkpoint file. """ suffix = Path(f).suffix - if suffix in cls.SUFFIX_HANDLERS: - _, save_func = cls.SUFFIX_HANDLERS[suffix] - save_func(obj, f) + if suffix in self.SUFFIX_HANDLERS: + _, save_func = self.SUFFIX_HANDLERS[suffix] else: - torch.save(obj, f) + save_func = _torch_save - @classmethod - def load_for_rank(cls, dir: str | Path, rank: int, device='cpu') -> Any: + self.runner.run_save(save_func, obj, f) + + def load_for_rank(self, dir: str | Path, rank: int, device='cpu') -> Any: """ Loads a checkpoint file for a specific rank @@ -112,10 +270,10 @@ def load_for_rank(cls, dir: str | Path, rank: int, device='cpu') -> Any: device (`str`, `int`, *optional*): The device on which you want the tensors. """ - for suffix in cls.SUFFIX_MAP.values(): + for suffix in self.NAME_MAP.values(): f = Path(dir) / f"{rank}{suffix}" if f.exists(): - return cls.load(f, device=device) + return self.load(f, device=device) raise FileNotFoundError(f"No checkpoint file found for rank {rank} in directory {dir}") def save_for_rank(self, obj: Any, dir: str | Path, rank: int) -> None: @@ -130,29 +288,72 @@ def save_for_rank(self, obj: Any, dir: str | Path, rank: int) -> None: rank (`int`): The rank of the checkpoint file to save. """ - f = Path(dir) / f"{rank}{self.suffix}" + f = Path(dir) / self.CHECKPOINT_FILE_NAME_FORMAT.format(rank=rank, suffix=self.suffix) self.save(obj, f) - @classmethod - def remove_for_rank(cls, dir: str | Path, rank: int) -> None: + def remove_for_rank(self, dir: str | Path, rank: int) -> None: """ - Removes a checkpoint file for a specific rank - + Removes a checkpoint file for a specific rank. Args: dir (`str`): The directory where the checkpoint files are stored. rank (`int`): The rank of the checkpoint file to remove. """ - for suffix in cls.SUFFIX_MAP.values(): + self.flush() + + for suffix in self.NAME_MAP.values(): f = Path(dir) / f"{rank}{suffix}" if f.exists(): f.unlink() + for extra_file in Path(dir).glob(f"{rank}{suffix}.*"): + extra_file.unlink() - @classmethod - def list_checkpoints(cls, dir: str | Path) -> list[Path]: + def copy_for_rank(self, src: str | Path, dst: str | Path, rank: int, symlink: bool = False) -> None: """ - List the checkpoint files in a directory + Copies a checkpoint file for a specific rank from one directory to another. + Args: + src (`str`): + The source directory where the checkpoint files are stored. + dst (`str`): + The destination directory where the checkpoint files will be copied. + rank (`int`): + The rank of the checkpoint file to copy. + symlink (`bool`, *optional*, defaults to `False`): + Whether to create a symbolic link instead of copying the file. + """ + + self.flush() + src = Path(src).resolve() + dst = Path(dst).resolve() + dst.mkdir(parents=True, exist_ok=True) + + src_f = Path(src) / f"{rank}{self.suffix}" + dst_f = Path(dst) / f"{rank}{self.suffix}" + + if not src_f.exists(): + raise FileNotFoundError(f"No checkpoint file found for rank {rank} in directory {src}") + + if symlink: + # this restricts symlink creation within the same directory + # so we can create relative symlink safely + if src.parent != dst.parent: + raise ValueError("Cannot create symlink when source and destination are not in the same directory.") + + if symlink: + dst_f.symlink_to(Path('..') / src.name / src_f.name) + for extra_file in src.glob(f"{rank}{self.suffix}.*"): + dst_extra_file = Path(dst) / extra_file.name + dst_extra_file.symlink_to(Path('..') / src.name / extra_file.name) + else: + shutil.copy2(src_f, dst_f) + for extra_file in src.glob(f"{rank}{self.suffix}.*"): + dst_extra_file = Path(dst) / extra_file.name + shutil.copy2(extra_file, dst_extra_file) + + def list_checkpoints(self, dir: str | Path) -> list[Path]: + """ + List the main checkpoint files in a directory Args: dir (`str`): The directory where the checkpoint files are stored. @@ -160,9 +361,11 @@ def list_checkpoints(cls, dir: str | Path) -> list[Path]: (`list[Path]`): The list of checkpoint files in the directory. """ + self.flush() + p = Path(dir) files = [] - for suffix in cls.SUFFIX_MAP.values(): + for suffix in self.NAME_MAP.values(): fs = list(p.glob(f"*{suffix}")) if fs: if files: @@ -171,25 +374,91 @@ def list_checkpoints(cls, dir: str | Path) -> list[Path]: files.extend(fs) return files + def flush(self) -> None: + """ + Flushes any pending operations. + """ + self.runner.flush() -def register_format( - name: str, - suffix: str, - save_func: _SaveProc, - load_func: _LoadProc, - ) -> None: + @classmethod + def get_format(cls, suffix: str) -> str: + """ + Gets the format name from the suffix. + """ + suffix = '.' + suffix.lstrip('.') + if suffix not in Checkpointer.SUFFIX_MAP: + raise ValueError(f"Unsupported checkpoint suffix: {suffix}") + return Checkpointer.SUFFIX_MAP[suffix] + + +def register_format(format: Type[CheckpointFormat]) -> None: """ Registers a new serialization format. + """ + suffix = '.' + format.suffix.lstrip('.') + Checkpointer.NAME_MAP[format.name] = suffix + Checkpointer.SUFFIX_MAP[suffix] = format.name + Checkpointer.SUFFIX_HANDLERS[suffix] = (format.load, format.save) + + +def register_serialization_runner(runner: Type[SerializationRunner]) -> None: + """ + Register a new serialization runner, which can intercept the load and save process. + For example, file redirection, chunking, asynchronous IO or other logic. + + Please note if you create extra files during saving, + you must make sure + 1. the suffix of the main checkpoint file must match registered formats. + 2. the name of extra files should start with the main checkpoint file name + '.', + but the suffix should not conflict with registered formats. + + For example, if the input checkpoint file is `model.ckpt`, + you must create a file called 'model.ckpt', + and you can use extra file names like 'model.ckpt.1', 'model.ckpt.meta', 'model.ckpt.opt' etc. + """ + if ':' in runner.name: + raise ValueError("Serialization runner name cannot contain ':'") + Checkpointer.REGISTERED_RUNNERS[runner.name] = runner + + +def convert_format( + src: str | Path, + dst: str | Path, + *, + src_serializer: str = None, + src_serializer_args: dict = None, + dst_serializer: str = None, + dst_serializer_args: dict = None, + device: str = 'cpu' +) -> None: + """ + Converts a checkpoint file from one format to another. + Args: - name (`str`): - The name of the format. - suffix (`str`): - The file suffix of the format. - load_func: - The function to load the format. - save_func: - The function to save the format. + src (`str` or `Path`): + The input checkpoint file. + dst (`str` or `Path`): + The output checkpoint file. + src_serializer (`str`, *optional*): + The serialization runner of the input checkpoint file. + src_serializer_args (`dict`, *optional*): + The arguments for the serialization runner of the input checkpoint file. + dst_serializer (`str`, *optional*): + The serialization runner of the output checkpoint file. + dst_serializer_args (`dict`, *optional*): + The arguments for the serialization runner of the output checkpoint file. + device (`str`, *optional*, defaults to `"cpu"`): + The device on which you want the tensors. """ - suffix = '.' + suffix.lstrip('.') - Checkpointer.SUFFIX_MAP[name] = suffix - Checkpointer.SUFFIX_HANDLERS[suffix] = (load_func, save_func) + src_format = Checkpointer.get_format(Path(src).suffix) + dst_format = Checkpointer.get_format(Path(dst).suffix) + + if src_format == dst_format and src_serializer == dst_serializer: + raise ValueError("Input and output formats and serializers are the same, no conversion needed.") + + src_checkpointer = Checkpointer(format=src_format, serializer=src_serializer, serializer_args=src_serializer_args) + dst_checkpointer = Checkpointer(format=dst_format, serializer=dst_serializer, serializer_args=dst_serializer_args) + + obj = src_checkpointer.load(src, device=device) + dst_checkpointer.save(obj, dst) + dst_checkpointer.flush() diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index e92a26de..c5d8289e 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -102,6 +102,8 @@ def run(self): if not self.train_args.compile_mode: self._train() finally: + if self.checkpointer: + self.checkpointer.flush() for stage in ['train', 'val', 'test']: if self.dataloader[stage] is not None and (close_fn := getattr(self.dataloader[stage], 'close', None)): close_fn() @@ -139,6 +141,7 @@ def _setup(self): logging.getLogger().setLevel(logging.WARNING) self.train_args.init_env(self) + self.checkpointer = self.train_args.create_checkpointer() # make sure all ranks are synchronized after init_env if is_running_distributed(): @@ -155,7 +158,8 @@ def _setup(self): pmodel = parallelize_model( self.train_args, self.dummy_input, load_module=not compile_only, - build_buckets=not self.train_args.should_delay_bucket_building() + build_buckets=not self.train_args.should_delay_bucket_building(), + checkpointer=self.checkpointer, ) if compile_only: return @@ -227,7 +231,6 @@ def reducer_pre_hook(reducer, grad): # Currently we never pass `last_epoch` to its constructor self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) self.loggers = self.train_args.create_loggers() - self.checkpointer = self.train_args.create_checkpointer() supported_hook_components = [ self.model, @@ -255,10 +258,15 @@ def reducer_pre_hook(reducer, grad): self.hook.after_setup(self) @classmethod - def _merge_checkpoint(cls, checkpoint_files: List[str], *, model_only: bool = False): + def _merge_checkpoint(cls, checkpoint_files: List[str], + *, + model_only: bool = False, + checkpointer: Optional[Checkpointer] = None, + ): + checkpointer = checkpointer or Checkpointer() state_dicts = [] for f in checkpoint_files: - state_dict = Checkpointer.load(f) + state_dict = checkpointer.load(f) if model_only: # we pop optimizer state to save cpu memory state_dict.pop('optimizer', None) @@ -377,9 +385,26 @@ def _broadcast_values(sdict, keys): return state_dict @classmethod - def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str, *, model_only: bool = False): - merged_state_dict = cls._merge_checkpoint(checkpoint_files, model_only=model_only) - Checkpointer.save(merged_state_dict, output_file) + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str, + *, + model_only: bool = False, + checkpointer: Optional[Checkpointer] = None, + serializer: Optional[str] = None, + serializer_args: Optional[dict[str, Any]] = None, + ): + if checkpointer is not None: + if serializer is not None or serializer_args is not None: + raise ValueError("serializer and serializer_args should not be specified when checkpointer is given") + else: + checkpointer = Checkpointer(serializer=serializer, serializer_args=serializer_args) + + merged_state_dict = cls._merge_checkpoint( + checkpoint_files, + model_only=model_only, + checkpointer=checkpointer, + ) + checkpointer.save(merged_state_dict, output_file) + checkpointer.flush() def _log_finalize(self): for logger in self.loggers: @@ -426,7 +451,7 @@ def _load_checkpoint(self): torch.distributed.barrier() if self.local_rank == 0: logger.info(f"Merging checkpoint files from {resume_from}") - state_dict = self._merge_checkpoint(list(rank_ckpt_files.values())) + state_dict = self._merge_checkpoint(list(rank_ckpt_files.values()), checkpointer=self.checkpointer) else: state_dict = None @@ -641,16 +666,12 @@ def _save_checkpoint(self, loss): save_dir / self.checkpointer.get_last_dir_name(), self.rank ) - last_file = save_dir / self.checkpointer.get_last_checkpoint_file_path( - rank=self.rank + self.checkpointer.copy_for_rank( + ckpt_file.parent, + save_dir / self.checkpointer.get_last_dir_name(), + self.rank, + checkpoint_config.symlink_best_and_last ) - last_file.parent.mkdir(parents=True, exist_ok=True) - if checkpoint_config.symlink_best_and_last: - # symblink as relative path - last_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) - # last_file.symlink_to(ckpt_file) - else: - shutil.copy(ckpt_file, last_file) # save best if checkpoint_config.save_best and loss <= self.train_status.best_loss: @@ -662,16 +683,12 @@ def _save_checkpoint(self, loss): save_dir / self.checkpointer.get_best_dir_name(), self.rank ) - best_file = save_dir / self.checkpointer.get_best_checkpoint_file_path( - rank=self.rank, + self.checkpointer.copy_for_rank( + ckpt_file.parent, + save_dir / self.checkpointer.get_best_dir_name(), + self.rank, + checkpoint_config.symlink_best_and_last ) - best_file.parent.mkdir(parents=True, exist_ok=True) - if checkpoint_config.symlink_best_and_last: - # symblink as relative path - best_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) - # best_file.symlink_to(ckpt_file) - else: - shutil.copy(ckpt_file, best_file) torch.distributed.barrier() # remove old checkpoints diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 29e42044..5cd811fa 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass, field, replace import importlib -from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union, TypeVar +from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Type, Union, TypeVar from typing_extensions import get_args from pathlib import Path import logging @@ -419,6 +419,23 @@ class ResumeOptions: save_memory: bool = True +@dataclass +class SerializerOptions: + # the serialization runner to be used + # It should be a name of registered SerializationRunners + name: str = '' + + # the full qualified name of the function to create the serialization runner + # Currently we do not support this way + # to make sure all serialization runners are registered and can be used in other places + # (like nnscaler.cli.Trainer.merge_checkpoint) + # type: str = None + + # arguments for the serialization runner + # Note You should be able to load for any arguments + args: Dict[str, Any] = field(default_factory=dict) + + @dataclass class CheckpointConfig: save_dir: str = './checkpoints' @@ -427,8 +444,16 @@ class CheckpointConfig: # `"pt"`: PyTorch native format # `"safetensors"`: Safetensors format # You can also register new formats via `nnscaler.cli.serialization.register_format` + # or specify a custom format here by providing a CheckpointFormat subclass format: str = 'pt' + # the serialization runner to be used + # It should be a name of registered SerializationRunners + # If None, the default serializer will be used + serializer: Optional[SerializerOptions] = field(default=None, metadata={ + 'normalize': lambda x: {'name': x} if isinstance(x, str) else x + }) + # `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is # a folder with as many files as the world size. # `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is @@ -475,6 +500,9 @@ def __post_init__(self): if isinstance(self.resume_from, str): self.resume_from = ResumeOptions(checkpoint=self.resume_from) + if isinstance(self.serializer, str): + self.serializer = SerializerOptions(name=self.serializer) + if self.resume_from and self.resume_from.checkpoint: if self.resume_from.checkpoint in ['last', 'best']: if not self.save_dir: @@ -491,9 +519,12 @@ def __post_init__(self): if not self.save_dir: raise ValueError("save_dir is required") - if self.format not in Checkpointer.SUFFIX_MAP: + if self.format not in Checkpointer.NAME_MAP: raise ValueError(f"Invalid format {self.format}") + if self.serializer and self.serializer.name not in Checkpointer.REGISTERED_RUNNERS: + raise ValueError(f"Invalid Serialization runner {self.serializer.name}") + if self.every_n_epochs is not None and self.every_n_train_steps is not None: raise ValueError("Cannot specify both every_n_epochs and every_n_train_steps") if self.every_n_epochs is None and self.every_n_train_steps is None: @@ -983,4 +1014,10 @@ def create_hook(self) -> TrainHook: raise ValueError(f"Invalid hook_config {hook_config}") def create_checkpointer(self) -> Checkpointer: + if self.checkpoint.serializer: + return Checkpointer( + self.checkpoint.format, + self.checkpoint.serializer.name, + self.checkpoint.serializer.args + ) return Checkpointer(self.checkpoint.format) diff --git a/nnscaler/runtime/serialization.py b/nnscaler/runtime/serialization.py index 0adc6b78..ff492c26 100644 --- a/nnscaler/runtime/serialization.py +++ b/nnscaler/runtime/serialization.py @@ -215,8 +215,11 @@ def load(f, *, device="cpu", format="safetensors", lazy=False) -> LazyLoader | A if not lazy: with LazyLoader(f, device=device) as loader: data = loader.get_lazy_data() - assert isinstance(data, _LazyContainer) - return data.load_all() + if isinstance(data, _LazyContainer): + return data.load_all() + else: + # pure data without any tensors + return data return LazyLoader(f, device=device) diff --git a/tests/cli/test_serialization.py b/tests/cli/test_serialization.py new file mode 100644 index 00000000..cd5ae9b4 --- /dev/null +++ b/tests/cli/test_serialization.py @@ -0,0 +1,245 @@ +import pytest +import torch +from pathlib import Path + +from nnscaler.cli.serialization import ( + convert_format, SerializationRunner, register_serialization_runner, + Checkpointer +) +from nnscaler.cli.trainer import Trainer +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.common import assert_equal + + +def test_runner(tmp_path): + + class SplitSerializationRunner(SerializationRunner): + name: str = 'split' + + def run_load(self, load_func, f, *, device='cpu'): + model_state_dict = load_func(f, device=device) + opt_state_dict = load_func(str(f) + '.opt', device=device) + return { + 'model': model_state_dict, + 'optimizer': opt_state_dict + } + + def run_save(self, save_func, state_dict, f): + save_func(state_dict['model'], f) + save_func(state_dict['optimizer'], str(f) + '.opt') + + register_serialization_runner(SplitSerializationRunner) + + a = torch.randn((2, 2), device='cpu') + b = torch.randn((2, 3), device='cpu') + c = torch.randn((4, 4), device='cpu') + d = torch.randn((3, 3), device='cpu') + tensors = { + "model": { + "embedding": a, + "attention": b, + }, + "optimizer": { + "state": { + 0: { + "exp_avg": c, + "exp_avg_sq": d, + } + } + } + } + checkpointer = Checkpointer() + checkpointer.save(tensors, tmp_path / "model.ckpt") + checkpointer.flush() + + convert_format( + src=str(tmp_path / "model.ckpt"), + dst=str(tmp_path / "model_split.ckpt"), + dst_serializer='split', + ) + + assert Path(tmp_path / "model_split.ckpt").exists() + assert Path(tmp_path / "model_split.ckpt.opt").exists() + tensor3 = Checkpointer(serializer='split').load(tmp_path / "model_split.ckpt") + assert_equal(tensors, tensor3) + + checkpointer2 = Checkpointer(serializer=':split') + tensor2 = checkpointer2.load(tmp_path / "model.ckpt") + assert_equal(tensors, tensor2) + + checkpointer2.save(tensor2, tmp_path / "model_split2.ckpt") + checkpointer2.flush() + assert Path(tmp_path / "model_split2.ckpt").exists() + assert Path(tmp_path / "model_split2.ckpt.opt").exists() + + tensor4 = Checkpointer(serializer='split').load(tmp_path / "model_split2.ckpt") + assert_equal(tensors, tensor4) + + +def trainer_split_serializer_worker(tmp_path, symblink): + save_dir = Path(tmp_path) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' + use_zero = 1 + format = 'safetensors' + rev_format = 'pt' if format == 'safetensors' else 'safetensors' + + def list_ckpt_files(dir): + return set(dir.glob('**/*.ckpt')) | set(dir.glob('**/*.safetensors')) + + + class SplitSerializationRunner(SerializationRunner): + name: str = 'split' + + def __init__(self, mark=''): + self.mark = mark + + def run_load(self, load_func, f, *, device='cpu'): + other_state_dict = load_func(f, device=device) + opt_state_dict = load_func(str(f) + '.opt', device=device) + model_state_dict = load_func(str(f) + '.model', device=device) + return { + 'model': model_state_dict, + 'optimizer': opt_state_dict, + **other_state_dict + } + + def run_save(self, save_func, state_dict, f): + save_func(state_dict['model'], str(f) + '.model') + save_func(state_dict['optimizer'], str(f) + '.opt') + other_state_dict = {k: v for k, v in state_dict.items() if k not in ['model', 'optimizer']} + other_state_dict['mark'] = self.mark + save_func(other_state_dict, f) + + register_serialization_runner(SplitSerializationRunner) + + # train 4 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.format', format, + '--checkpoint.serializer.name', 'split', + '--checkpoint.serializer.args.mark', 'hello', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + ckpt_files = list_ckpt_files(ckpt_savedir) + assert len(ckpt_files)/4 == min(10, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last + + for f in ckpt_files: + assert trainer.checkpointer.load(f)['mark'] == 'hello' + assert Path(str(f) + '.opt').exists() + assert Path(str(f) + '.model').exists() + + torch.distributed.barrier() + # train 4 epcho two times (resume from last) + ckpt0_savedir = save_dir / 'ckpt0' + # first two epochs + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '2', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.format', format, + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + + # create merged checkpoint + ckpt1_savedir = save_dir / 'ckpt1' + ckpt1_savedir.mkdir(parents=True, exist_ok=True) + merged_file_name = f'merged{Checkpointer.NAME_MAP[format]}' + if trainer.rank == 0: + Trainer.merge_checkpoint(trainer.checkpointer.list_checkpoints(ckpt0_savedir / 'last'), ckpt1_savedir / merged_file_name, serializer='split') + assert Path(str(ckpt1_savedir / merged_file_name) + '.opt').exists() + assert Path(str(ckpt1_savedir / merged_file_name) + '.model').exists() + + torch.distributed.barrier() + # continue with the last two epochs (resume for sharded/deduped checkpoint) + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.format', rev_format, + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + + torch.distributed.barrier() + + # continue with the last two epochs (resume for merged) + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt1_savedir), + '--checkpoint.format', rev_format, + '--checkpoint.resume_from', str(ckpt1_savedir / merged_file_name), + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) + + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} + for i in range(4): + x = trainer.checkpointer.load_for_rank(ckpt_savedir / 'last', i) + y = trainer.checkpointer.load_for_rank(ckpt0_savedir / 'last', i) + z = trainer.checkpointer.load_for_rank(ckpt1_savedir / 'last', i) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + assert_equal(x['lr_scheduler'], y['lr_scheduler']) + assert_equal(x['model'], z['model']) + assert_equal(x['optimizer'], z['optimizer']) + assert_equal(x['lr_scheduler'], z['lr_scheduler']) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('symblink', [True, False]) +def test_trainer_split_serializer(tmp_path, symblink): + launch_torchrun(4, trainer_split_serializer_worker, tmp_path, symblink) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index e70528d5..307adbef 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -260,9 +260,9 @@ def list_ckpt_files(dir): # create merged checkpoint ckpt1_savedir = save_dir / 'ckpt1' ckpt1_savedir.mkdir(parents=True, exist_ok=True) - merged_file_name = f'merged{Checkpointer.SUFFIX_MAP[format]}' + merged_file_name = f'merged{Checkpointer.NAME_MAP[format]}' if trainer.rank == 0: - Trainer.merge_checkpoint(Checkpointer.list_checkpoints(ckpt0_savedir / 'last'), ckpt1_savedir / merged_file_name) + Trainer.merge_checkpoint(trainer.checkpointer.list_checkpoints(ckpt0_savedir / 'last'), ckpt1_savedir / merged_file_name) torch.distributed.barrier() # continue with the last two epochs (resume for sharded/deduped checkpoint) @@ -331,9 +331,9 @@ def list_ckpt_files(dir): if torch.distributed.get_rank() == 0: assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} for i in range(4): - x = Checkpointer.load_for_rank(ckpt_savedir / 'last', i) - y = Checkpointer.load_for_rank(ckpt0_savedir / 'last', i) - z = Checkpointer.load_for_rank(ckpt1_savedir / 'last', i) + x = trainer.checkpointer.load_for_rank(ckpt_savedir / 'last', i) + y = trainer.checkpointer.load_for_rank(ckpt0_savedir / 'last', i) + z = trainer.checkpointer.load_for_rank(ckpt1_savedir / 'last', i) assert_equal(x['model'], y['model']) assert_equal(x['optimizer'], y['optimizer']) assert_equal(x['lr_scheduler'], y['lr_scheduler']) @@ -341,7 +341,7 @@ def list_ckpt_files(dir): assert_equal(x['optimizer'], z['optimizer']) assert_equal(x['lr_scheduler'], z['lr_scheduler']) - suffix = Checkpointer.SUFFIX_MAP[format] + suffix = Checkpointer.NAME_MAP[format] if save_type == 'deduped': assert (ckpt_savedir / f'last/0{suffix}').stat().st_size > (ckpt_savedir / f'last/2{suffix}').stat().st_size assert (ckpt_savedir / f'last/1{suffix}').stat().st_size > (ckpt_savedir / f'last/3{suffix}').stat().st_size @@ -1415,18 +1415,24 @@ def trainer_checkpointer_worker(save_dir): load_triggered = False - def save(obj: Any, f: Path) -> None: - obj['test'] = True - return torch.save(obj, f) + class TestFormat: + name: str = 'test_format' + suffix: str = '.testpt' - def load(f: str | Path, *, device='cpu') -> Any: - x = torch.load(f, map_location=device, weights_only=False) - assert x['test'] is True - nonlocal load_triggered - load_triggered = True - return x + @classmethod + def save(cls, obj: Any, f: Path) -> None: + obj['test'] = True + return torch.save(obj, f) - register_format('test_format', '.testpt', save, load) + @classmethod + def load(cls, f: str | Path, *, device='cpu') -> Any: + x = torch.load(f, map_location=device, weights_only=False) + assert x['test'] is True + nonlocal load_triggered + load_triggered = True + return x + + register_format(TestFormat) # train 1 epcho in one time trainer = Trainer([ diff --git a/tests/runtime/test_serialization.py b/tests/runtime/test_serialization.py index 8b8cab92..aaee1eae 100644 --- a/tests/runtime/test_serialization.py +++ b/tests/runtime/test_serialization.py @@ -2,6 +2,7 @@ import pytest from nnscaler.runtime.serialization import load, save, convert +from nnscaler.cli.serialization import convert_format from tests.parallel_module.common import assert_equal @@ -20,8 +21,14 @@ def test_normal(tmp_path): loaded = load(tmp_path / "model.safetensors", lazy=False) assert_equal(tensors, loaded) convert(tmp_path / "model.safetensors", tmp_path / "model.pt") + convert_format( + src=str(tmp_path / "model.safetensors"), + dst=str(tmp_path / "model2.ckpt"), + ) loaded_pt = torch.load(tmp_path / "model.pt") assert_equal(tensors, loaded_pt) + loaded_pt2 = torch.load(tmp_path / "model2.ckpt") + assert_equal(tensors, loaded_pt2) def test_shared_params(tmp_path): From 91ccb4a06313b3edc9d6a349b6207ff23c3cdf5f Mon Sep 17 00:00:00 2001 From: nnScaler Date: Wed, 31 Dec 2025 21:12:02 +0800 Subject: [PATCH 1864/1892] Fix & Update --- README.md | 1 - azure-pipelines.yml | 21 + benchmark/README.md | 357 ++++++ benchmark/benchmark_base.py | 426 +++++++ benchmark/benchmark_ring_attn.py | 206 ++++ benchmark/benchmark_ring_attn_varlen.py | 222 ++++ benchmark/benchmark_zigzag_attn.py | 243 ++++ dev.md | 2 +- docs/source/installation.rst | 2 +- docs/source/parallel_module.md | 6 +- docs/source/quickstart.rst | 2 +- docs/source/quickstart_internal.rst | 116 ++ examples/warmup_scheduler.py | 81 ++ nnscaler/__init__.py | 8 + nnscaler/autodist/autodist_config.py | 16 +- nnscaler/autodist/cost_database.py | 34 +- nnscaler/autodist/dp_solver.cpp | 79 +- nnscaler/autodist/model_graph.py | 3 + nnscaler/autodist/spmd_solver.py | 25 +- nnscaler/cli/__init__.py | 2 + nnscaler/cli/arg_parser.py | 4 + nnscaler/cli/checkpoint.py | 163 +++ nnscaler/cli/mixed_module.py | 60 +- nnscaler/cli/serialization.py | 464 ++++++++ nnscaler/cli/train_hook.py | 28 + nnscaler/cli/trainer.py | 293 +++-- nnscaler/cli/trainer_args.py | 114 +- nnscaler/codegen/emit.py | 69 +- nnscaler/codegen/module/module.py | 165 ++- nnscaler/customized_ops/__init__.py | 0 .../customized_ops/ring_attention/README.md | 219 ++++ .../customized_ops/ring_attention/__init__.py | 8 + .../core/ring_attn_implementation.py | 326 +++++ .../core/ring_attn_varlen_implementation.py | 516 ++++++++ .../ring_attention/core/utils.py | 343 ++++++ .../core/zigzag_attn_implementation.py | 516 ++++++++ .../ring_attention/ring_attn.py | 113 ++ .../ring_attention/ring_attn_varlen.py | 308 +++++ .../ring_attention/varlen_utils.py | 182 +++ .../ring_attention/zigzag_attn.py | 117 ++ nnscaler/flags.py | 2 +- nnscaler/graph/function/dimops.py | 13 +- nnscaler/graph/function/function.py | 146 ++- nnscaler/graph/graph.py | 12 +- nnscaler/graph/parser/__init__.py | 2 +- nnscaler/graph/parser/converter.py | 9 +- nnscaler/graph/parser/external/__init__.py | 3 +- nnscaler/graph/parser/external/apex.py | 3 +- nnscaler/graph/parser/external/einops.py | 18 + nnscaler/graph/parser/mapping.py | 1 + nnscaler/graph/parser/parser.py | 329 +++-- nnscaler/graph/parser/value_tracker.py | 303 +++++ nnscaler/graph/tracer/concrete_tracer.py | 76 +- nnscaler/graph/tracer/metadata.py | 6 +- nnscaler/graph/tracer/operator_patcher.py | 22 +- nnscaler/ir/cten.py | 350 +++++- nnscaler/ir/tensor.py | 21 +- nnscaler/ir/unique.py | 14 +- nnscaler/parallel.py | 1038 +++++++++++++--- nnscaler/policies.py | 521 +++++++- nnscaler/runtime/__init__.py | 1 + nnscaler/runtime/_patch_torch.py | 104 ++ nnscaler/runtime/adapter/reducer.py | 527 +++++++-- nnscaler/runtime/f16_optimizer.py | 22 +- nnscaler/runtime/function/function.py | 41 +- nnscaler/runtime/gnorm.py | 13 +- nnscaler/runtime/hybrid_optimizer.py | 304 +++++ nnscaler/runtime/module.py | 749 ++++++++++-- nnscaler/runtime/serialization.py | 252 ++++ nnscaler/runtime/utils.py | 14 +- nnscaler/utils.py | 305 ++++- nnscaler/version.py | 2 +- pipelines/nightly-build.yaml | 25 + pipelines/release.yaml | 39 + pipelines/scripts/update_version.py | 71 ++ requirements-dev.txt | 1 + requirements.txt | 3 +- tests/autodist/test_dp_solver.py | 26 + tests/cli/common.py | 81 ++ tests/cli/test_arg_parser.py | 22 + tests/cli/test_hooks.py | 30 + tests/cli/test_serialization.py | 245 ++++ tests/cli/test_trainer.py | 616 +++++++++- tests/cli/test_trainer2.py | 134 +++ tests/cli/trainer_args.yaml | 1 + tests/cli/trainer_args_csa.yaml | 53 + tests/cli/trainer_args_mixed_bf16.yaml | 36 + tests/cli/trainer_args_shared_weights.yaml | 32 + tests/customized_ops/__init__.py | 4 + tests/customized_ops/ring_attn/__init__.py | 0 tests/customized_ops/ring_attn/configs.py | 268 +++++ .../ring_attn/ring_attn_runner.py | 118 ++ .../ring_attn/ring_attn_varlen_runner.py | 107 ++ tests/customized_ops/ring_attn/runner_base.py | 301 +++++ tests/customized_ops/ring_attn/test_base.py | 241 ++++ .../ring_attn/test_ring_attn.py | 60 + .../ring_attn/test_ring_attn_varlen.py | 57 + .../ring_attn/test_shuffle_varlen.py | 213 ++++ .../ring_attn/test_zigzag_attn.py | 77 ++ .../ring_attn/zigzag_attn_runner.py | 106 ++ tests/graph/function/test_functions.py | 56 + tests/graph/parser/test_ast_transformer.py | 12 +- tests/graph/parser/test_converter.py | 37 +- tests/graph/parser/test_parser.py | 2 +- tests/graph/parser/test_value_tracker.py | 194 +++ tests/graph/test_segment.py | 6 +- tests/graph/tracer/test_pack_kwargs.py | 33 + tests/ir/test_cten.py | 4 +- tests/parallel_module/common.py | 26 +- tests/parallel_module/test_attr_dedup.py | 11 +- tests/parallel_module/test_broadcast.py | 4 +- tests/parallel_module/test_checkpoint.py | 53 +- .../parallel_module/test_checkpoint_dedup.py | 2 - tests/parallel_module/test_ddp.py | 32 +- tests/parallel_module/test_gencode.py | 362 +++++- tests/parallel_module/test_gencode_einops.py | 67 ++ tests/parallel_module/test_gencode_kwargs.py | 189 +++ .../test_gencode_torch_compile.py | 4 +- tests/parallel_module/test_init.py | 8 +- tests/parallel_module/test_offload_params.py | 213 ++++ tests/runtime/test_gnorm.py | 2 +- tests/runtime/test_hybrid_optimizer.py | 153 +++ .../test_hybrid_optimizer_trainer_args.yaml | 76 ++ tests/runtime/test_serialization.py | 75 ++ tests/test_policies.py | 1053 ++++++++++++++++- tests/test_utils.py | 127 +- tests/utils.py | 15 +- utility/aggregate.sh | 15 + utility/broadcast.sh | 15 + utility/comm_profile.py | 108 ++ utility/dgx1_reorder_gpu.py | 119 ++ utility/keep.py | 72 ++ utility/prim_profiler.py | 52 + utility/test_rvd_prim.py | 137 +++ utility/verify_ops/verify_dimops.py | 470 ++++++++ utility/verify_ops/verify_graph_operations.py | 161 +++ utility/visualize_value_tracks.py | 158 +++ 137 files changed, 17269 insertions(+), 904 deletions(-) create mode 100644 azure-pipelines.yml create mode 100644 benchmark/README.md create mode 100644 benchmark/benchmark_base.py create mode 100644 benchmark/benchmark_ring_attn.py create mode 100644 benchmark/benchmark_ring_attn_varlen.py create mode 100644 benchmark/benchmark_zigzag_attn.py create mode 100644 docs/source/quickstart_internal.rst create mode 100644 examples/warmup_scheduler.py create mode 100644 nnscaler/cli/checkpoint.py create mode 100644 nnscaler/cli/serialization.py create mode 100644 nnscaler/customized_ops/__init__.py create mode 100644 nnscaler/customized_ops/ring_attention/README.md create mode 100644 nnscaler/customized_ops/ring_attention/__init__.py create mode 100644 nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py create mode 100644 nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py create mode 100644 nnscaler/customized_ops/ring_attention/core/utils.py create mode 100644 nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py create mode 100644 nnscaler/customized_ops/ring_attention/ring_attn.py create mode 100644 nnscaler/customized_ops/ring_attention/ring_attn_varlen.py create mode 100644 nnscaler/customized_ops/ring_attention/varlen_utils.py create mode 100644 nnscaler/customized_ops/ring_attention/zigzag_attn.py create mode 100644 nnscaler/graph/parser/external/einops.py create mode 100644 nnscaler/graph/parser/value_tracker.py create mode 100644 nnscaler/runtime/_patch_torch.py create mode 100644 nnscaler/runtime/hybrid_optimizer.py create mode 100644 nnscaler/runtime/serialization.py create mode 100644 pipelines/nightly-build.yaml create mode 100644 pipelines/release.yaml create mode 100644 pipelines/scripts/update_version.py create mode 100644 tests/cli/test_hooks.py create mode 100644 tests/cli/test_serialization.py create mode 100644 tests/cli/test_trainer2.py create mode 100644 tests/cli/trainer_args_csa.yaml create mode 100644 tests/cli/trainer_args_mixed_bf16.yaml create mode 100644 tests/cli/trainer_args_shared_weights.yaml create mode 100644 tests/customized_ops/__init__.py create mode 100644 tests/customized_ops/ring_attn/__init__.py create mode 100644 tests/customized_ops/ring_attn/configs.py create mode 100644 tests/customized_ops/ring_attn/ring_attn_runner.py create mode 100644 tests/customized_ops/ring_attn/ring_attn_varlen_runner.py create mode 100644 tests/customized_ops/ring_attn/runner_base.py create mode 100644 tests/customized_ops/ring_attn/test_base.py create mode 100644 tests/customized_ops/ring_attn/test_ring_attn.py create mode 100644 tests/customized_ops/ring_attn/test_ring_attn_varlen.py create mode 100644 tests/customized_ops/ring_attn/test_shuffle_varlen.py create mode 100644 tests/customized_ops/ring_attn/test_zigzag_attn.py create mode 100644 tests/customized_ops/ring_attn/zigzag_attn_runner.py create mode 100644 tests/graph/parser/test_value_tracker.py create mode 100644 tests/graph/tracer/test_pack_kwargs.py create mode 100644 tests/parallel_module/test_gencode_einops.py create mode 100644 tests/parallel_module/test_gencode_kwargs.py create mode 100644 tests/parallel_module/test_offload_params.py create mode 100644 tests/runtime/test_hybrid_optimizer.py create mode 100644 tests/runtime/test_hybrid_optimizer_trainer_args.yaml create mode 100644 tests/runtime/test_serialization.py create mode 100644 utility/aggregate.sh create mode 100644 utility/broadcast.sh create mode 100644 utility/comm_profile.py create mode 100644 utility/dgx1_reorder_gpu.py create mode 100644 utility/keep.py create mode 100644 utility/prim_profiler.py create mode 100644 utility/test_rvd_prim.py create mode 100644 utility/verify_ops/verify_dimops.py create mode 100644 utility/verify_ops/verify_graph_operations.py create mode 100644 utility/visualize_value_tracks.py diff --git a/README.md b/README.md index 1b99ef0d..4062c2fa 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,6 @@ nnScaler is a parallelization engine that compiles a Deep neural network (DNN) m # Latest News nnScaler (also known as CUBE as code name) has been adopted by multiple product and research projects, this section includes some of the latest news from the team and partner projects. -* **2025-08-12** nnScaler 0.8 released: https://github.com/microsoft/nnscaler/releases/tag/0.8 * **2025-02-12** nnScaler 0.7 released: https://github.com/microsoft/nnscaler/releases/tag/0.7 * **2024-10-07** Diff-Transformer utilizes nnScaler for differential attention mechanism: [DIFFERENTIAL TRANSFORMER](https://arxiv.org/abs/2410.05258) * **2024-05-09** YOCO utilizes nnScaler for long-sequence training: [(YOCO)You only cache once: Decoder-decoder architectures for language models](https://arxiv.org/abs/2405.05254) diff --git a/azure-pipelines.yml b/azure-pipelines.yml new file mode 100644 index 00000000..f8a2f6de --- /dev/null +++ b/azure-pipelines.yml @@ -0,0 +1,21 @@ +# Starter pipeline +# Start with a minimal pipeline that you can customize to build and deploy your code. +# Add steps that build, run tests, deploy, and more: +# https://aka.ms/yaml + +trigger: +- main + +pool: + vmImage: ubuntu-latest + +steps: +- script: | + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main; + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r; + pip install tox + pip install tox-conda + displayName: 'Install tox' +- script: | + tox + displayName: 'Run unit tests' diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000..1cd09a68 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,357 @@ +# Ring Attention Performance Benchmarks + +This directory contains a unified performance benchmarking framework for all Ring Attention variants, built using a shared architecture that eliminates code duplication and provides consistent interfaces. + +## 🏗️ Architecture + +The benchmark framework consists of: + +### Core Framework +- **`benchmark_base.py`**: Shared benchmark framework extending the test framework +- **Configuration System**: Unified configuration management via `../tests/customized_ops/ring_attn/configs.py` + +### Attention Implementations +- **`benchmark_ring_attn.py`**: Standard Ring Attention benchmarks +- **`benchmark_ring_attn_varlen.py`**: Variable Length Ring Attention benchmarks +- **`benchmark_zigzag_attn.py`**: Zigzag Ring Attention benchmarks (causal-only) + +## 🚀 Quick Start + +### 1. List Available Configurations + +```bash +cd benchmark + +# List configurations for any benchmark variant +python benchmark_ring_attn_varlen.py --list-configs +python benchmark_ring_attn.py --list-configs +python benchmark_zigzag_attn.py --list-configs +``` + +### 2. Run Basic Benchmarks + +```bash +# Ring Attention Variable Length +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium + +# Standard Ring Attention +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config small + +# Zigzag Ring Attention (causal-only) +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config tiny +``` + +### 3. Advanced Usage + +```bash +# Custom timing parameters +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium --timing-method warmup --warmup-runs 5 --timing-runs 10 + +# Detailed profiling +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config large --timing-method profiler + +# Custom configurations (legacy support) +torchrun --nproc_per_node=2 benchmark_ring_attn.py --seqlen 8192 --nheads 16 --head-dim 128 --batch-size 4 +``` + +## 📋 Available Configurations + +The benchmark framework uses a comprehensive configuration system with predefined configurations for different testing scenarios. + +### Configuration Categories + +#### Small Configs (Quick Testing) +- **`tiny`**: 2×8×64, seq=1024, tokens=1K, bf16 [Causal] +- **`small`**: 4×12×128, seq=4096, tokens=4K, bf16 [Causal] +- **`small_fp16`**: 4×12×128, seq=4096, tokens=4K, fp16 [Non-causal] +- **`small_window`**: 4×12×128, seq=4096, tokens=4K, bf16 [Causal] [Window=512,0] + +#### Medium Configs (Standard Testing) +- **`medium`**: 4×24×128, seq=8192, tokens=8K, bf16 [Causal] +- **`medium_large_head`**: 4×12×256, seq=8192, tokens=8K, bf16 [Non-causal] +- **`medium_many_heads`**: 4×32×128, seq=8192, tokens=8K, bf16 [Causal] +- **`medium_fp16`**: 4×24×128, seq=8192, tokens=8K, fp16 [Causal] +- **`medium_window`**: 4×24×128, seq=8192, tokens=8K, bf16 [Causal] [Window=512,0] + +#### Large Configs (Performance Testing) +- **`large`**: 4×32×128, seq=16384, tokens=16K, bf16 [Causal] +- **`large_seq`**: 4×24×128, seq=32768, tokens=32K, bf16 [Causal] +- **`large_head`**: 4×24×256, seq=16384, tokens=16K, bf16 [Non-causal] +- **`xlarge`**: 8×32×128, seq=32768, tokens=32K, bf16 [Causal] +- **`large_window`**: 4×32×128, seq=16384, tokens=16K, bf16 [Causal] [Window=512,0] + +#### GQA Configs (Grouped Query Attention) +- **`qwen3_235b_a22b`**: 2×64×64, seq=16384, tokens=16K, bf16 (GQA 64→4) [Causal] +- **`qwen3_30b_a3b`**: 4×32×64, seq=16384, tokens=16K, bf16 (GQA 32→4) [Causal] +- **`qwen3_4b`**: 4×32×80, seq=16384, tokens=16K, bf16 (GQA 32→4) [Causal] +- **`qwen3_32b`**: 2×64×128, seq=16384, tokens=16K, bf16 (GQA 64→8) [Causal] +- **`qwen3_14b`**: 4×40×128, seq=16384, tokens=16K, bf16 (GQA 40→8) [Causal] + +#### Zigzag Configs (Causal-Only) +- **`zigzag_tiny`**: 2×8×64, seq=1024, tokens=1K, bf16 [Causal] +- **`zigzag_small`**: 4×12×128, seq=4096, tokens=4K, bf16 [Causal] +- **`zigzag_medium`**: 4×24×128, seq=8192, tokens=8K, bf16 [Causal] +- **`zigzag_large`**: 4×32×128, seq=16384, tokens=16K, bf16 [Causal] +- **`zigzag_fp16`**: 4×12×128, seq=4096, tokens=4K, fp16 [Causal] +- **`zigzag_gqa`**: 4×32×128, seq=8192, tokens=8K, bf16 (GQA 32→8) [Causal] + +### Default Configuration Sets +- **Correctness Testing**: `["tiny", "small", "medium"]` +- **Performance Testing**: `["medium", "large"]` +- **Multi-GPU Testing**: `["small", "medium"]` +- **GQA Testing**: `["qwen3_4b", "qwen3_14b", "qwen3_32b"]` +- **Zigzag Testing**: `["zigzag_tiny", "zigzag_small", "zigzag_medium"]` + +## 🔧 Features + +### Unified Framework +- **Shared Base Class**: All benchmarks extend `RingAttnBenchmarkBase` for consistency +- **Code Reuse**: Leverages test framework components (`test_base.py`, `runner_base.py`) +- **Consistent Interface**: Same command-line options across all attention variants + +### Multiple Timing Methods +- **`simple`**: Basic CUDA timing measurements (fastest) +- **`warmup`**: Multiple runs with warm-up (recommended for accurate results) +- **`profiler`**: torch.profiler with detailed kernel analysis + +### Comprehensive Metrics +- **Performance**: Forward/backward timing, throughput (tokens/sec) +- **Scalability**: Speedup analysis, parallel efficiency +- **Memory**: GPU memory usage tracking +- **Comparative**: Single vs. parallel mode analysis + +### Configuration Support +- **Predefined Configs**: 20+ predefined configurations covering different scales +- **Legacy Parameters**: Backward compatibility with custom parameters +- **Attention Variants**: Support for standard, variable-length, and zigzag attention +- **GQA Support**: Grouped Query Attention configurations based on Qwen models + +## 🧪 Usage Examples + +### Basic Performance Testing +```bash +# Quick benchmarks with different attention types +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config tiny --timing-method simple +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config small --timing-method warmup +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config medium --dtype fp16 +``` + +### Comparative Analysis +```bash +# Compare different attention mechanisms on same config +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --timing-method warmup +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium --timing-method warmup +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config medium --timing-method warmup +``` + +### Advanced Profiling +```bash +# Detailed profiler analysis +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config large --timing-method profiler + +# Custom timing parameters for high precision +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --timing-method warmup --warmup-runs 10 --timing-runs 20 +``` + +### GQA Performance Testing +```bash +# Test Grouped Query Attention configurations +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config qwen3_4b --timing-method warmup +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config qwen3_14b --timing-method warmup +``` + +### Legacy Support (Custom Parameters) +```bash +# Override specific parameters while using predefined base +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --seqlen 16384 --nheads 32 + +# Full custom configuration +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --seqlen 8192 --nheads 16 --head-dim 128 --batch-size 4 --dtype bf16 +``` + +## 📈 Output Interpretation + +The benchmark framework provides comprehensive performance analysis: + +### Performance Metrics +``` +================================================================================ +RING ATTENTION VARIABLE LENGTH PERFORMANCE BENCHMARK (WARMUP METHOD) +Configuration: medium - medium + Sequence length: 8192 + Batch size: 4 + Heads: 24 + Head dim: 128 + Data type: bf16 + World size: 2 GPUs + Total tokens: 8,192 + (Warmup runs: 3, Timing runs: 5) +================================================================================ +Single Mode: + Forward time: 0.001234 seconds + Backward time: 0.002345 seconds + Total time: 0.003579 seconds + Throughput: 2288764 tokens/sec + +Parallel Mode: + Forward time: 0.000987 seconds + Backward time: 0.001654 seconds + Total time: 0.002641 seconds + Throughput: 3102234 tokens/sec + +Speedup: + Forward speedup: 1.25x + Backward speedup: 1.42x + Total speedup: 1.35x + Throughput improvement: 1.35x + +Efficiency: + Theoretical speedup: 2x + Actual speedup: 1.35x + Parallel efficiency: 67.7% +================================================================================ +``` + +### Key Metrics Explained +- **Forward/Backward Time**: Separate timing for forward and backward passes +- **Throughput**: Tokens processed per second (higher = better) +- **Speedup**: Performance ratio vs single GPU (higher = better) +- **Parallel Efficiency**: Actual speedup / theoretical speedup (closer to 100% = better) + +### Profiler Output (when using `--timing-method profiler`) +When using the profiler method, you get additional detailed analysis: +- Kernel-level timing breakdown +- Memory bandwidth utilization +- CUDA kernel execution patterns +- Optimization recommendations + +## 🎯 Attention Variant Characteristics + +### Ring Attention (`benchmark_ring_attn.py`) +- **Format**: Standard batch format `[batch_size, seq_len, num_heads, head_dim]` +- **Use Case**: General purpose attention for standard transformer models +- **Constraints**: Supports both causal and non-causal attention, sliding windows + +### Ring Attention Variable Length (`benchmark_ring_attn_varlen.py`) +- **Format**: Packed format `[total_tokens, num_heads, head_dim]` with `cu_seqlens` +- **Use Case**: Optimized for variable-length sequences, eliminates padding waste +- **Constraints**: Supports causal/non-causal attention, sliding windows + +### Zigzag Attention (`benchmark_zigzag_attn.py`) +- **Format**: Standard batch format `[batch_size, seq_len, num_heads, head_dim]` +- **Use Case**: Specialized for causal attention with optimized communication pattern +- **Constraints**: **Only supports causal=True and window_size=(-1,-1)** + +## 🔗 Integration with Test Framework + +The benchmark framework is tightly integrated with the correctness test framework: + +### Shared Components +- **Configuration System**: Same `configs.py` used for both correctness and performance testing +- **Base Classes**: Reuses `RingAttnRunnerBase` from `runner_base.py` +- **Distributed Setup**: Shared GPU detection and distributed initialization +- **Error Handling**: Consistent tolerance and validation logic + +### Workflow Integration +```bash +# 1. Run correctness tests first +cd /path/to/MagicCube +pytest tests/customized_ops/ring_attn/test_ring_attn_varlen.py --config tiny + +# 2. Then run performance benchmarks +cd benchmark +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config tiny +``` + +## ⚠️ Requirements & Setup + +### System Requirements +- **Multi-GPU Setup**: Most benchmarks require 2+ GPUs (use `torchrun --nproc_per_node=N`) +- **GPU Memory**: Large configs may require high-memory GPUs (A100, H100 recommended) +- **CUDA**: Compatible CUDA installation (11.8+ recommended) +- **Python Environment**: PyTorch with NCCL support for distributed training + +### Optional Components +- **TransformerEngine**: Install TE 2.2.0+ for optimal performance (auto-detected) +- **Flash Attention**: Required for base attention implementations +- **InfiniBand**: Recommended for multi-node setups (reduces communication latency) + +### Environment Setup +```bash +# From MagicCube root directory +cd benchmark + +# Verify imports work correctly +python -c " +from benchmark_base import RingAttnBenchmarkBase +print('✓ Benchmark framework ready') +" + +# Test configuration system +python benchmark_ring_attn_varlen.py --list-configs +``` + +## 🚨 Troubleshooting + +### Common Issues + +#### GPU/Memory Issues +```bash +# OOM errors: Use smaller configs or reduce batch size +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config tiny # Instead of large + +# Insufficient GPUs: Check available GPUs +python -c "import torch; print(f'Available GPUs: {torch.cuda.device_count()}')" +``` + +#### Import/Path Issues +```bash +# Import errors: Ensure running from correct directory +cd /path/to/MagicCube/benchmark +python benchmark_ring_attn.py --help + +# Configuration import errors +python -c " +import sys, os +sys.path.insert(0, '../tests/customized_ops/ring_attn') +from configs import get_config +print('✓ Config system working') +" +``` + +#### Distributed Training Issues +```bash +# NCCL errors: Check GPU compatibility and CUDA setup +export NCCL_DEBUG=INFO # For detailed NCCL debugging + +# Port conflicts: Use different port +torchrun --master_port=29501 --nproc_per_node=2 benchmark_ring_attn.py --config tiny +``` + +### Performance Debugging +```bash +# Test basic functionality without distributed training +CUDA_VISIBLE_DEVICES=0 python -c " +from benchmark_ring_attn import RingAttnBenchmark +print('✓ Benchmark classes load correctly') +" + +# Verify attention implementations work +cd ../tests/customized_ops/ring_attn +pytest test_ring_attn.py::TestRingAttn::test_ring_attn_tiny -v +``` + +**Note**: Actual efficiency depends on hardware, network, and system configuration. + +## 📚 Related Documentation + +### Core Documentation +- **Ring Attention Implementation**: `../nnscaler/customized_ops/ring_attention/README.md` +- **Test Framework**: `../tests/customized_ops/ring_attn/README.md` +- **Development Guide**: `../dev_docs/README_refactoring.md` +- **Testing Results**: `../dev_docs/benchmark_testing_results.md` + +--- + +**For implementation details**: See `../nnscaler/customized_ops/ring_attention/` +**For correctness testing**: See `../tests/customized_ops/ring_attn/` \ No newline at end of file diff --git a/benchmark/benchmark_base.py b/benchmark/benchmark_base.py new file mode 100644 index 00000000..4226d2ed --- /dev/null +++ b/benchmark/benchmark_base.py @@ -0,0 +1,426 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base benchmark framework for ring attention performance tests. +This module extends the test framework to support performance benchmarking. +""" + +import os +import sys +import time +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Tuple, Callable + +import torch +import torch.distributed as dist +from torch.profiler import profile, ProfilerActivity + +# Add tests directory to path to import test framework +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) + +from runner_base import RingAttnRunnerBase +from configs import get_config, get_configs_by_category, DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnBenchmarkBase(RingAttnRunnerBase): + """Base class for ring attention performance benchmarks""" + + def __init__(self): + super().__init__() + self.timing_method = "warmup" + self.warmup_runs = 3 + self.timing_runs = 5 + + @abstractmethod + def get_benchmark_name(self) -> str: + """Return the benchmark name for display""" + pass + + def run_timing_with_warmup(self, forward_fn: Callable, backward_fn: Callable, + warmup_runs: int = None, timing_runs: int = None) -> Tuple[float, float, Any]: + """Run timing with warm-up runs to get accurate measurements.""" + warmup_runs = warmup_runs or self.warmup_runs + timing_runs = timing_runs or self.timing_runs + + # Warm-up runs + for _ in range(warmup_runs): + torch.cuda.synchronize() + output = forward_fn() + torch.cuda.synchronize() + backward_fn(output) + torch.cuda.synchronize() + + # Timing runs + forward_times = [] + backward_times = [] + + for _ in range(timing_runs): + # Forward timing + torch.cuda.synchronize() + start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_time = time.perf_counter() - start + forward_times.append(forward_time) + + # Backward timing + torch.cuda.synchronize() + start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_time = time.perf_counter() - start + backward_times.append(backward_time) + + # Return average times + avg_forward = sum(forward_times) / len(forward_times) + avg_backward = sum(backward_times) / len(backward_times) + return avg_forward, avg_backward, output + + def run_timing_with_profiler(self, forward_fn: Callable, backward_fn: Callable, + rank_id: int = 0) -> Tuple[float, float, Any]: + """Run timing using torch.profiler for detailed analysis.""" + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + # Run profiler with timing + torch.cuda.synchronize() + + with profile(activities=activities, record_shapes=True, with_stack=True) as prof: + torch.cuda.synchronize() + forward_start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_end = time.perf_counter() + + torch.cuda.synchronize() + backward_start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_end = time.perf_counter() + + torch.cuda.synchronize() + + # Calculate timing from our measurements + forward_time = forward_end - forward_start + backward_time = backward_end - backward_start + + if rank_id == 0: + self._print_profiler_results(prof) + + return forward_time, backward_time, output + + def run_timing_simple(self, forward_fn: Callable, backward_fn: Callable) -> Tuple[float, float, Any]: + """Run simple timing without warmup or profiling.""" + torch.cuda.synchronize() + forward_start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_time = time.perf_counter() - forward_start + + torch.cuda.synchronize() + backward_start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_time = time.perf_counter() - backward_start + + return forward_time, backward_time, output + + def _print_profiler_results(self, prof): + """Print profiler results with fallback for different PyTorch versions.""" + print("\n" + "="*60) + print("TORCH PROFILER RESULTS") + print("="*60) + + try: + # Try the most common sorting options + events = prof.key_averages() + table_str = events.table(sort_by="self_cuda_time_total", row_limit=20) + print(table_str) + except Exception as e1: + try: + table_str = events.table(sort_by="cuda_time_total", row_limit=20) + print(table_str) + except Exception as e2: + try: + table_str = events.table(sort_by="self_cpu_time_total", row_limit=20) + print(table_str) + except Exception as e3: + print(f"Warning: Could not generate profiler table due to API differences") + print(f"Errors: {e1}, {e2}, {e3}") + + # Fallback: print basic event info + print("Available profiler events:") + for i, event in enumerate(events): + if i >= 10: # Limit output + break + try: + print(f" {event.key}: CPU time = {getattr(event, 'cpu_time_total', 'N/A')} us") + except: + print(f" {event.key}: [timing info unavailable]") + + print("="*60 + "\n") + + def create_timing_functions(self, inputs, config, dout_tensor): + """Create timing functions for single and parallel execution.""" + # Single mode functions + def single_forward(): + single_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + single_inputs[k] = v.detach().clone().requires_grad_() + else: + single_inputs[k] = v.detach().clone() + else: + single_inputs[k] = v + + # Run single GPU reference + output, grad_tensors = self.run_single_gpu_reference(single_inputs, config) + return output, (single_inputs, grad_tensors) + + def single_backward(outputs): + output, (single_inputs, grad_tensors) = outputs + output.backward(dout_tensor) + return dout_tensor + + # Parallel mode functions + model = self.create_test_module(config) + dummy_args = self.get_dummy_forward_args(inputs) + + from nnscaler.parallel import parallelize, ComputeConfig, ReuseType + world_size = dist.get_world_size() + + parallel_model = parallelize( + model, + dummy_forward_args=dummy_args, + pas_policy=self.create_policy(), + compute_config=ComputeConfig(world_size, world_size), + reuse=ReuseType.OVERRIDE + ) + parallel_model = parallel_model.cuda() + parallel_model.train() + + def parallel_forward(): + para_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + para_inputs[k] = v.detach().clone().requires_grad_() + else: + para_inputs[k] = v.detach().clone() + else: + para_inputs[k] = v + + output = parallel_model(**para_inputs) + return output, para_inputs + + def parallel_backward(outputs): + output, para_inputs = outputs + output.backward(dout_tensor) + parallel_model.sync_grad() + return dout_tensor + + return single_forward, single_backward, parallel_forward, parallel_backward + + def calculate_throughput_metrics(self, config, forward_time: float, backward_time: float) -> Dict[str, float]: + """Calculate throughput and efficiency metrics.""" + total_time = forward_time + backward_time + + # Calculate total tokens processed + if hasattr(config, 'total_tokens'): + total_tokens = config.total_tokens + else: + total_tokens = config.batch_size * config.max_seqlen + + throughput = total_tokens / total_time if total_time > 0 else 0 + + return { + 'total_tokens': total_tokens, + 'throughput_tokens_per_sec': throughput, + 'total_time': total_time, + 'forward_time': forward_time, + 'backward_time': backward_time + } + + def print_benchmark_results(self, config_name: str, config, dtype: str, + single_metrics: Dict[str, float], + parallel_metrics: Dict[str, float], + world_size: int, rank_id: int): + """Print comprehensive benchmark results.""" + if rank_id != 0: + return + + print("\n" + "="*80) + print(f"{self.get_benchmark_name().upper()} PERFORMANCE BENCHMARK ({self.timing_method.upper()} METHOD)") + print(f"Configuration: {config_name} - {config.name}") + print(f" Sequence length: {config.max_seqlen}") + print(f" Batch size: {config.batch_size}") + print(f" Heads: {config.num_heads}") + print(f" Head dim: {config.head_dim}") + print(f" Data type: {dtype}") + print(f" World size: {world_size} GPUs") + print(f" Total tokens: {single_metrics['total_tokens']:,}") + + if self.timing_method == "warmup": + print(f" (Warmup runs: {self.warmup_runs}, Timing runs: {self.timing_runs})") + print("="*80) + + # Timing results + print(f"Single Mode:") + print(f" Forward time: {single_metrics['forward_time']:.6f} seconds") + print(f" Backward time: {single_metrics['backward_time']:.6f} seconds") + print(f" Total time: {single_metrics['total_time']:.6f} seconds") + print(f" Throughput: {single_metrics['throughput_tokens_per_sec']:.0f} tokens/sec") + + print(f"\nParallel Mode:") + print(f" Forward time: {parallel_metrics['forward_time']:.6f} seconds") + print(f" Backward time: {parallel_metrics['backward_time']:.6f} seconds") + print(f" Total time: {parallel_metrics['total_time']:.6f} seconds") + print(f" Throughput: {parallel_metrics['throughput_tokens_per_sec']:.0f} tokens/sec") + + # Speedup calculations + forward_speedup = single_metrics['forward_time'] / parallel_metrics['forward_time'] if parallel_metrics['forward_time'] > 0 else 0 + backward_speedup = single_metrics['backward_time'] / parallel_metrics['backward_time'] if parallel_metrics['backward_time'] > 0 else 0 + total_speedup = single_metrics['total_time'] / parallel_metrics['total_time'] if parallel_metrics['total_time'] > 0 else 0 + throughput_improvement = parallel_metrics['throughput_tokens_per_sec'] / single_metrics['throughput_tokens_per_sec'] if single_metrics['throughput_tokens_per_sec'] > 0 else 0 + + print(f"\nSpeedup:") + print(f" Forward speedup: {forward_speedup:.2f}x") + print(f" Backward speedup: {backward_speedup:.2f}x") + print(f" Total speedup: {total_speedup:.2f}x") + print(f" Throughput improvement: {throughput_improvement:.2f}x") + + # Efficiency metrics + theoretical_speedup = world_size + efficiency = total_speedup / theoretical_speedup * 100 if theoretical_speedup > 0 else 0 + print(f"\nEfficiency:") + print(f" Theoretical speedup: {theoretical_speedup:.0f}x") + print(f" Actual speedup: {total_speedup:.2f}x") + print(f" Parallel efficiency: {efficiency:.1f}%") + print("="*80 + "\n") + + def run_performance_benchmark(self, config_name: str = None, dtype: str = "bf16", + timing_method: str = "warmup", warmup_runs: int = 3, + timing_runs: int = 5, **legacy_kwargs): + """Run performance benchmark for the specific attention implementation.""" + # Setup timing parameters + self.timing_method = timing_method + self.warmup_runs = warmup_runs + self.timing_runs = timing_runs + + # Initialize distributed environment + world_size, rank = self.initialize_distributed() + rank_id = dist.get_rank() + + # Get configuration + config = get_config(config_name) if config_name else self._create_legacy_config(**legacy_kwargs) + torch_dtype = torch.bfloat16 if dtype == "bf16" else torch.float16 + + if rank_id == 0: + print(f"Running {self.get_benchmark_name()} performance benchmark...") + print(f"Configuration: {config.name if hasattr(config, 'name') else 'custom'}") + + # Prepare inputs + device = torch.device(f"cuda:{rank_id}") + inputs = self.prepare_inputs(config, device, torch_dtype) + + # Broadcast inputs to ensure consistency + for tensor in inputs.values(): + if isinstance(tensor, torch.Tensor): + dist.broadcast(tensor, src=0) + dist.barrier() + + # Pre-generate dout tensor for timing consistency + with torch.no_grad(): + dummy_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + dummy_inputs[k] = v.detach() + else: + dummy_inputs[k] = v + dummy_out, _ = self.run_single_gpu_reference(dummy_inputs, config) + dout_tensor = torch.randn_like(dummy_out, device=device, dtype=torch_dtype) + dist.broadcast(dout_tensor, src=0) + + # Create timing functions + single_forward, single_backward, parallel_forward, parallel_backward = self.create_timing_functions( + inputs, config, dout_tensor + ) + + if rank_id == 0: + print(f"Running performance benchmark using {timing_method} method...", end="") + + # Run timing based on method + if timing_method == "profiler": + single_forward_time, single_backward_time, _ = self.run_timing_with_profiler( + single_forward, single_backward, rank_id + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_with_profiler( + parallel_forward, parallel_backward, rank_id + ) + elif timing_method == "warmup": + single_forward_time, single_backward_time, _ = self.run_timing_with_warmup( + single_forward, single_backward, warmup_runs, timing_runs + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_with_warmup( + parallel_forward, parallel_backward, warmup_runs, timing_runs + ) + else: # simple + single_forward_time, single_backward_time, _ = self.run_timing_simple( + single_forward, single_backward + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_simple( + parallel_forward, parallel_backward + ) + + if rank_id == 0: + print(" Done!") + + # Calculate metrics and print results + single_metrics = self.calculate_throughput_metrics(config, single_forward_time, single_backward_time) + parallel_metrics = self.calculate_throughput_metrics(config, parallel_forward_time, parallel_backward_time) + + self.print_benchmark_results( + config_name or "custom", config, dtype, + single_metrics, parallel_metrics, world_size, rank_id + ) + + # Cleanup + dist.destroy_process_group() + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters.""" + class LegacyConfig: + def __init__(self, **kwargs): + self.name = "legacy_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + # Add other default attributes as needed + + return LegacyConfig(**kwargs) + + def list_configurations(self): + """List all available configurations for benchmarking.""" + print("Available Ring Attention Configurations:") + print("=" * 50) + + for category in ["small", "medium", "large", "gqa"]: + print(f"\n{category.upper()} CONFIGS:") + configs = get_configs_by_category(category) + if configs: + for name, config in configs.items(): + tokens_k = config.total_tokens // 1000 + gqa_info = f" (GQA {config.num_heads}->{config.num_kv_heads})" if config.is_gqa else "" + causal_info = " [Causal]" if config.causal else " [Non-causal]" + window_info = f" [Window={config.window_size[0]},{config.window_size[1]}]" if config.window_size != (-1, -1) else "" + print(f" {name:20s} - {config.batch_size}x{config.num_heads}x{config.head_dim}, seq={config.max_seqlen}, tokens={tokens_k}K, {config.dtype}{gqa_info}{causal_info}{window_info}") + else: + print(" No configurations in this category") + + print(f"\nDEFAULT PERFORMANCE CONFIGS: {DEFAULT_PERFORMANCE_CONFIGS}") + print(f"\nUsage: Use --config to specify a configuration") \ No newline at end of file diff --git a/benchmark/benchmark_ring_attn.py b/benchmark/benchmark_ring_attn.py new file mode 100644 index 00000000..50929729 --- /dev/null +++ b/benchmark/benchmark_ring_attn.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import ring attention implementation +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnBenchmark(RingAttnBenchmarkBase): + """Benchmark for standard Ring Attention""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn.wrap_ring_attn_func' + + @property + def function_name(self) -> str: + return "ring_attn" + + def get_benchmark_name(self) -> str: + return "Ring Attention" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for standard ring attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v): + return wrap_ring_attn_func( + q, k, v, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for standard ring attention.""" + set_seed(42) + + # Create input tensors with standard batch format + q = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + k = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + v = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + + output = wrap_ring_attn_func( + q, k, v, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for standard ring attention.""" + class LegacyRingAttnConfig: + def __init__(self, **kwargs): + self.name = "legacy_ring_attn_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + self.causal = True + self.window_size = (-1, -1) + + return LegacyRingAttnConfig(**kwargs) + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Ring Attention Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = RingAttnBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/benchmark_ring_attn_varlen.py b/benchmark/benchmark_ring_attn_varlen.py new file mode 100644 index 00000000..97c4c6fc --- /dev/null +++ b/benchmark/benchmark_ring_attn_varlen.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import ring attention implementation +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnVarlenBenchmark(RingAttnBenchmarkBase): + """Benchmark for Ring Attention Variable Length""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func' + + @property + def function_name(self) -> str: + return "ring_attn_varlen" + + def get_benchmark_name(self) -> str: + return "Ring Attention Variable Length" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for variable length ring attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v, cu_seqlens_q, cu_seqlens_k): + return wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for variable length sequence attention.""" + set_seed(42) + + # Get cu_seqlens from config or create default + if hasattr(config, 'cu_seqlens'): + cu_seqlens = config.cu_seqlens + else: + # Create default variable length sequences + seqlen = config.max_seqlen + cu_seqlens = [0, seqlen // 8, seqlen // 4, seqlen // 2, seqlen] + + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + total_tokens = cu_seqlens[-1] + + # Create input tensors + q = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + k = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + v = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v, + 'cu_seqlens_q': cu_seqlens_tensor, + 'cu_seqlens_k': cu_seqlens_tensor + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + cu_seqlens_q = inputs['cu_seqlens_q'] + cu_seqlens_k = inputs['cu_seqlens_k'] + + output = wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for varlen.""" + class LegacyVarlenConfig: + def __init__(self, **kwargs): + self.name = "legacy_varlen_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.dtype = "bf16" + self.causal = True + self.window_size = (-1, -1) + + # Create variable length sequences + seqlen = self.max_seqlen + self.cu_seqlens = kwargs.get('cu_seqlens', [0, seqlen // 8, seqlen // 4, seqlen // 2, seqlen]) + self.total_tokens = self.cu_seqlens[-1] + + return LegacyVarlenConfig(**kwargs) + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Ring Attention Variable Length Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Total sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (number of sequences) (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = RingAttnVarlenBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/benchmark_zigzag_attn.py b/benchmark/benchmark_zigzag_attn.py new file mode 100644 index 00000000..94e99521 --- /dev/null +++ b/benchmark/benchmark_zigzag_attn.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag Attention Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import zigzag attention implementation +from nnscaler.customized_ops.ring_attention import wrap_zigzag_attn_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS, ZIGZAG_CONFIGS + + +class ZigzagAttnBenchmark(RingAttnBenchmarkBase): + """Benchmark for Zigzag Attention""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.zigzag_attn.wrap_zigzag_attn_func' + + @property + def function_name(self) -> str: + return "zigzag_attn" + + def get_benchmark_name(self) -> str: + return "Zigzag Attention" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for zigzag attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v): + # Zigzag attention only supports causal=True and window_size=(-1,-1) + return wrap_zigzag_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for zigzag attention.""" + set_seed(42) + + # Create input tensors with standard batch format + q = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + k = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + v = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + + # Zigzag attention constraints + output = wrap_zigzag_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for zigzag attention.""" + class LegacyZigzagAttnConfig: + def __init__(self, **kwargs): + self.name = "legacy_zigzag_attn_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + # Zigzag attention constraints + self.causal = True + self.window_size = (-1, -1) + + return LegacyZigzagAttnConfig(**kwargs) + + def run_performance_benchmark(self, config_name: str = None, dtype: str = "bf16", + timing_method: str = "warmup", warmup_runs: int = 3, + timing_runs: int = 5, **legacy_kwargs): + """Override to validate zigzag attention constraints.""" + # Validate configuration for zigzag constraints + if config_name: + from configs import get_config + config = get_config(config_name) + if not config.causal: + print(f"WARNING: Config '{config_name}' has causal=False, but zigzag attention requires causal=True") + print("Proceeding with causal=True for zigzag attention...") + if config.window_size != (-1, -1): + print(f"WARNING: Config '{config_name}' has window_size={config.window_size}, but zigzag attention requires (-1, -1)") + print("Proceeding with window_size=(-1, -1) for zigzag attention...") + + # Call parent implementation + super().run_performance_benchmark( + config_name=config_name, dtype=dtype, timing_method=timing_method, + warmup_runs=warmup_runs, timing_runs=timing_runs, **legacy_kwargs + ) + + def list_configurations(self): + """List configurations suitable for zigzag attention.""" + print("Available Zigzag Attention Configurations:") + print("=" * 50) + print("NOTE: Zigzag attention only supports causal=True and window_size=(-1,-1)") + print("Configurations listed below will be automatically adjusted for these constraints.\n") + + # Call parent method but with zigzag-specific note + super().list_configurations() + + print(f"\nZIGZAG-SPECIFIC CONFIGS: {list(ZIGZAG_CONFIGS.keys())}") + print("These configs are specifically designed for zigzag attention.") + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Zigzag Attention Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = ZigzagAttnBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/dev.md b/dev.md index f42fcb04..0960bde9 100644 --- a/dev.md +++ b/dev.md @@ -92,4 +92,4 @@ Another trick is, if you want to step into pakcage source code, you can add the ### Write Unit Tests 1. If you need to use torchrun, please refer to `unit_test/launch_torchrun.py`, and you can find examples in `unit_tests/runtime/test_runtime_collectives.py`. Please note that `torchrun` is very slow, you should reduce its usage as possible. 2. If you want to mock up any functions/methods, please use pytest-mock. -3. **NOTE**: The name of test files and test functions must start with `test_` +3. **NOTE**: The name of test files and test functions must start with `test_` \ No newline at end of file diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 5e5deb7d..2e26a9af 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -12,7 +12,7 @@ The wheel package is hosted on `GitHub release `_. + +If you are familiar with Azure stuffs, you can follow DevOps' guide to set up the repository. + +Or if you prefer the simpler way, download the ``.whl`` file in the "Files" section of the website, +and install it locally: + +:: + + python -m pip install nnscaler-*.whl + +********** +Quickstart +********** + +The next step depends on your choice of the training framework. + +- **No framework**: if you write your own training code and do not use a framework, + see :ref:`Parallelize API` section. +- **Fairseq**: if you use fairseq, see :ref:`Fairseq` section. +- **Lightning**: TODO + +.. _Parallelize API: + +Parallelize API +=============== + +TODO: write a hello world example, assigned to Zhe Liu + +If you write your own training code, you can use the *parallelize* API to make your model parallel: + +.. code-block:: python + + import torch + from nnscaler import parallelize, ComputeConfig, build_optimizer + + class LLM(torch.nn.Module): + def __init__(self, ...): + ... + def forward(self, x): + ... + + llm_sample_input = ... # dummpy input will be used to do tracing + pas_policy = ... # the PAS policy, you can use autodist pas + compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + ..., + ) # compute environment config + ParallelizedLLM = parallelize( + LLM, + {'x': llm_sample_input}, + pas_policy, + compute_config, + ) + +Example +------- + +An example of the parallelize API is provided in the repo: +`train.py `_ + +You can download and try it: :: + + torchrun --nproc_per_node=4 --nnodes=1 train.py + +Documentation +------------- + +If the example works for you, you can now follow the documentation to parallelize your model: +:doc:`parallel_module` + +.. _Fairseq: + +Fairseq (To be retired) +======================= + +.. TODO: + + nnScaler provides `fairseq integration `_. + + TODO: refine the example (and its doc), assigned to Youshan Miao + + TODO (long term): write an example using unmodified fairseq + + Installation + ------------ + + To use fairseq, clone the fork and install it: :: + + python -m pip uninstall fairseq + + git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Fairseq + cd Fairseq + python -m pip install -e . + + Example + ------- + + Follow the example + `here `_. + diff --git a/examples/warmup_scheduler.py b/examples/warmup_scheduler.py new file mode 100644 index 00000000..54e8aa7f --- /dev/null +++ b/examples/warmup_scheduler.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import math +from torch.optim.lr_scheduler import LRScheduler, Optimizer, _warn_get_lr_called_within_step + + +class WarmupCosineAnnealingLR(LRScheduler): + r""" + torch.optim.lr_scheduler.CosineAnnealingLR with warmup. + + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_steps (int): Number of warmup steps. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + + def __init__( + self, + optimizer: Optimizer, + warmup_steps: int, + T_max: int, + eta_min=0.0, + last_epoch=-1, + ): # noqa: D107 + self.warmup_steps = warmup_steps + self.T_max = T_max - warmup_steps + 1 + self.eta_min = eta_min + super().__init__(optimizer, last_epoch) + + def get_lr(self): + """Retrieve the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + last_epoch_wo_warmup = self.last_epoch - self.warmup_steps + 1 + if last_epoch_wo_warmup < 0: + return [base_lr * (self.last_epoch + 1) / self.warmup_steps for base_lr in self.base_lrs] + elif last_epoch_wo_warmup == 0: + return [base_lr for base_lr in self.base_lrs] + elif self._step_count == 1 and last_epoch_wo_warmup > 0: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos((last_epoch_wo_warmup) * math.pi / self.T_max)) + / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif (last_epoch_wo_warmup - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * last_epoch_wo_warmup / self.T_max)) + / (1 + math.cos(math.pi * (last_epoch_wo_warmup - 1) / self.T_max)) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + last_epoch_wo_warmup = self.last_epoch - self.warmup_steps + 1 + if last_epoch_wo_warmup < 0: + return [base_lr * (self.last_epoch + 1) / self.warmup_steps for base_lr in self.base_lrs] + else: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * last_epoch_wo_warmup / self.T_max)) + / 2 + for base_lr in self.base_lrs + ] diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index 2bf5867a..b3a18165 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -16,6 +16,8 @@ broadcast_weights, load_sharded_state_dict, sync_grad_when, + trimmed_broadcast_merged_state_dict, + load_merged_state_dict_from_rank, ) from nnscaler.graph.parser.register import register_op from nnscaler.runtime.function.function import ( @@ -24,6 +26,12 @@ no_constant_folding, fold_constant, ) +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdam, MixedPrecisionAdamW +from nnscaler.runtime.hybrid_optimizer import HybridLRScheduler, HybridOptimizer +from nnscaler.utils import ( + mark_dynamic, + get_dynamic, +) def init(): diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 796d35e2..de3eb351 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -72,6 +72,12 @@ class AutoDistConfig: `x.module1` will match `x.module1` but not `y.module1`. Due to constraint of the tracer, you can pass `ROOT` to recompute_modules if you want the whole module to be recomputed. + - recompute_ratio ('float`, *optional*, defaults to `1.0`): + When `recompute_modules` only contains one name (excluding `ROOT`), this specify the ratio of modules + to be recomputed. For example, if `module1` is specified in `recompute_modules` and `recompute_ratio` is `0.8`, + only 80% of `module1` instances will be recomputed. + If there are multiple module names in `recompute_modules`, this field will be ignored and all specified modules + will be recomputed. - memory_constraint (`float`, *optional*, defaults to `32`): The memory constraint in each device in GB. - memory_granularity (`int`, *optional*, defaults to `1`): @@ -115,6 +121,10 @@ class AutoDistConfig: `transient_mem_size = opt_transient_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)`. This formula is useful in many cases, but it may be too strict when some operators consume or generate a large tensor (>= 4GB). In this case, you can set `transient_mem_coef` to a smaller value to relax the constraint. + - disable_shared_param_constraint (`bool`, *optional*, defaults to `False`): + Whether to disable the shared parameter constraint in spmd solver. When a parameter is shared by multiple modules, + the spmd solver will force the parameter to be replicated to complicated adapter generation. However, user can disable + it and provide customized partition constraints for those shared parameters. """ def __init__(self, @@ -133,6 +143,7 @@ def __init__(self, mesh_row=1, mesh_col=1, recompute_modules='', + recompute_ratio=1.0, memory_constraint=32, memory_granularity=1, micro_batch_size=1, @@ -150,6 +161,7 @@ def __init__(self, solver='dp', parallel_profile=True, transient_mem_coef=2, + disable_shared_param_constraint=False, **kwargs): self.pc_path = partition_constraints_path self.profile_dir = profile_dir @@ -166,6 +178,7 @@ def __init__(self, self.is_train = is_train self.mesh_desc = MeshDesc(mesh_row, mesh_col) self.recompute_modules = recompute_modules + self.recompute_ratio = recompute_ratio # from GB to Byte self.memory_constraint = int(memory_constraint * 1024 * 1024 * 1024) self.memory_granularity = memory_granularity @@ -192,6 +205,7 @@ def __init__(self, self.solver = 'dp' self.parallel_profile = parallel_profile self.transient_mem_coef = transient_mem_coef + self.disable_shared_param_constraint = disable_shared_param_constraint ignored_keys = list(kwargs.keys()) if ignored_keys: @@ -244,7 +258,7 @@ def _validate_config(self): scale_factor = self.world_size // self.mesh_desc.ngpus if scale_factor % self.zero_ngroups != 0: raise ValueError( - f'world size {self.world_size} must be divisible by zero num groups {self.zero_ngroups}' + f'scale_factor {scale_factor} must be divisible by zero num groups {self.zero_ngroups}' ) if not self.solver in [ diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 2e003359..97c446c3 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -119,27 +119,37 @@ def _load_comm_data(profile_dir: Path, plan_ngpus: int) -> Dict[str, Dict[str, L load intra_2.json, intra_4.json, and intra_8.json from the profile directory. If any of the files is not found, we will use the default data as well. ''' - def loader(path: Path): + def loader(path: Path, strict: bool): if not os.path.exists(path): return False, None info = {} dev = 2 + prev_info = None while dev <= plan_ngpus: fname = f'intra_{dev}.json' if not (path / fname).exists(): - return False, None - with open(path / fname, 'r') as f: - info[fname] = json.load(f) + if strict or prev_info is None: + return False, None + else: + content = prev_info + _logger.warning(f'{dev} devices communication profile data not found, using previous data') + else: + with open(path / fname, 'r') as f: + content = json.load(f) + prev_info = content + info[fname] = content dev *= 2 return True, info comm_path = profile_dir / 'comm' - success, comm_info = loader(comm_path) + success, comm_info = loader(comm_path, strict=True) if not success: + # When communication profile data is not found, use the default data. If the input `plan_ngpus` is greater + # than the devices in the profile data, the data with largest device count (16 for mi200) will be used. This + # is helpful when user wants to generate a distributed plan spanning over multiple nodes. _logger.warning(f'Communication profile data not found, using default data at {_DEFAULT_COMM_DATA_PATH}') - success, comm_info = loader(Path(_DEFAULT_COMM_DATA_PATH)) - if not success: - raise RuntimeError(f'Communication profile data is not compatible with plan_ngpus {plan_ngpus}') + success, comm_info = loader(Path(_DEFAULT_COMM_DATA_PATH), strict=False) + assert success, f'Failed to load default communication profile data from {_DEFAULT_COMM_DATA_PATH}, please check nnscaler\'s installation' return comm_info @@ -337,10 +347,14 @@ def query_single_mem(self, obj, memory_type, round=True) -> int: from .op_partition import OpPartition from .cube_operator import CubeOperator if isinstance(obj, OpPartition): - masks = self.gen_masks(obj.operator) + query_obj = obj.operator else: assert isinstance(obj, CubeOperator) - masks = self.gen_masks(obj) + query_obj = obj + try: + masks = self.gen_masks(query_obj) + except Exception as e: + raise RuntimeError(f"Failed to generate masks for {query_obj} with {self.query_profiled_metrics(query_obj)}: {e}") if memory_type == 'full_weight' and isinstance(obj, OpPartition): profiled_metrics = self.query_profiled_metrics(obj.operator) else: diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index f0b8e6d7..3b8bc015 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -67,7 +67,7 @@ void ThreadPool::waitFinished() { cv_finished.wait(lock, [this]() { return tasks.empty() && (busy == 0); }); } -const int MAX_CONCURRENCY = std::thread::hardware_concurrency(); +int MAX_CONCURRENCY = std::thread::hardware_concurrency(); ThreadPool pool(MAX_CONCURRENCY); std::vector> split_work(int num, int base) { @@ -118,6 +118,11 @@ class DPSolver { queries.clear(); id2node.clear(); search_results.clear(); + if (verbose) { + MAX_CONCURRENCY = 1; + std::cout << "set MAX_CONCURRENCY to 1 for verbose mode" + << std::endl; + } } void add_interval(int start, int end) { @@ -230,6 +235,31 @@ class DPSolver { } } + int encode_ir(const std::vector> &cur_ir) { + int val = 0; + for (std::size_t j = 0; j < cur_ir.size(); ++j) { + val += cur_ir[j].second; + if (j + 1 < cur_ir.size()) { + val *= id2node[cur_ir[j + 1].first]->p_num; + } + } + return val; + } + + void print_ir(const std::vector> &cur_ir) { + for (std::size_t j = 0; j < cur_ir.size(); ++j) { + std::cout << "(" << cur_ir[j].first << ", " << cur_ir[j].second << ") "; + } + std::cout << std::endl; + } + + void print_states(DPNode *dp_node) { + for (std::size_t i = 0; i < dp_node->state.size(); ++i) { + UnitDPState state = dp_node->state[i]; + std::cout << "state " << i << ": " << state.to_string() << std::endl; + } + } + // lazy build edge void buildInEdges(DPNode *dp_node) { if (!dp_node->in_edges.empty()) { @@ -361,15 +391,14 @@ class DPSolver { break; } } + bool need_add_pre_node = false; if (!find_pre_id) { Node *pre_node = id2node[node->id - 1]; if (pre_node->father_id != node->father_id) { - // do nothing, means the pre_node's output is not used - // we select the 1st partition of the pre_node - // need to be careful when the graph has multiple outputs if (!has_found_follow && !follow_candidates.empty()) { cur_ir.push_back(*follow_candidates.rbegin()); } + need_add_pre_node = true; } else if (pre_node->father_id == pre_node->id) { assert(follow_candidates.rbegin()->first == pre_node->id); cur_ir.push_back(*follow_candidates.rbegin()); @@ -391,15 +420,36 @@ class DPSolver { } } std::sort(cur_ir.begin(), cur_ir.end()); - val = 0; - for (std::size_t j = 0; j < cur_ir.size(); ++j) { - val += cur_ir[j].second; - if (j + 1 < cur_ir.size()) { - val *= id2node[cur_ir[j + 1].first]->p_num; + if (verbose) { + std::cout << "need_add_pre_node: " << need_add_pre_node << std::endl; + } + if (need_add_pre_node) { + // means the pre_node's output is not used by later nodes, + // so we need to enumerate all the partition states of pre_node + if (verbose) { + std::cout << "p_num " << id2node[node->id - 1]->p_num << std::endl; + } + for (int pred_p = 0; pred_p < id2node[node->id - 1]->p_num; + ++pred_p) { + cur_ir.push_back(std::make_pair(node->id - 1, pred_p)); + int val = encode_ir(cur_ir); + dp_node->in_edges.push_back( + std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); + if (verbose) { + print_ir(cur_ir); + print_states(id2node[node->id - 1]->dp_nodes[val]); + } + cur_ir.pop_back(); + } + } else { + int val = encode_ir(cur_ir); + dp_node->in_edges.push_back( + std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); + if (verbose) { + print_ir(cur_ir); + print_states(id2node[node->id - 1]->dp_nodes[val]); } } - dp_node->in_edges.push_back( - std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); } } @@ -414,6 +464,9 @@ class DPSolver { return; } + if (verbose) { + std::cout << "before update, cur_p " << cur_p << std::endl; + } // storing edges takes space, so we build edges when needed buildInEdges(dp_node); if (dp_node->in_edges.empty()) { @@ -468,6 +521,10 @@ class DPSolver { } } } + if (verbose) { + std::cout << "after update" << std::endl; + print_states(dp_node); + } } void do_dp(int start_level, int end_level) { diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 44be8ca2..7c9eec82 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -718,6 +718,9 @@ def fetch_module(scope_node: ScopeNode, prefix: List[str]): modules = [self.scope_tree_root] else: modules = fetch_module(self.scope_tree_root, []) + if len(recompute_modules) == 1 and self.autodist_config.recompute_ratio < 1.0: + boundary = max(1, int(len(modules) * self.autodist_config.recompute_ratio)) + modules = modules[:boundary] train_mem = 0 for module in modules: diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 1a1e0ec8..608e1fd2 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -261,6 +261,9 @@ def should_force_replica(operator: CubeOperator) -> bool: if len(consumers) == 1: continue _logger.info(f'find shared parameter {param} in {consumers}') + if self.autodist_config.disable_shared_param_constraint: + _logger.info(f'disable shared parameter constraint for {param}') + continue for consumer in consumers: if not isinstance(consumer, IRDimops): # always replicate non-dimops @@ -358,10 +361,15 @@ def is_valid_partition(operator: CubeOperator, p_ids: List[Any], if not selected_pc.replica_allowed: return False else: - allowed_pids = [ - operator.pos2dim_id(pos) - for pos in selected_pc.allowed_partition_dims - ] + allowed_pids = list() + for pos in selected_pc.allowed_partition_dims: + # When allowed dims in provided partition constraints are not correct generate warning + # If there is no valid partitions for the operator, the solver will throw exception later. + try: + cur_allowed_pid = operator.pos2dim_id(pos) + allowed_pids.append(cur_allowed_pid) + except Exception as e: + _logger.warning(f"Failed to get allowed partition id for {selected_pc}'s {pos}: {e}") if u not in allowed_pids: return False @@ -681,11 +689,10 @@ def calc_partition_cost(self, op_idx: int, partition_idx: int): bw_comm_time = 0 intra_time = micro_batch_num * (fw_comm_time + bw_comm_time) # double check the follow chain - if self.get_father_id(op_idx) == self.get_father_id( - producer) and intra_time == 0: - if src_p.operator.ir_cell.mirror is not None: - if self.p_fathers[op_idx][ - partition_idx] != self.p_fathers[producer][k]: + # if `intra_time` (forward + backward) is 0, we assume both partitions are in the same follow chain + if self.get_father_id(op_idx) == self.get_father_id(producer) and intra_time == 0: + if src_p.operator.ir_cell.mirror is not None and tgt_p.operator.ir_cell.mirror is not None: + if self.p_fathers[op_idx][partition_idx] != self.p_fathers[producer][k]: _logger.warning( f'Unexpected comm cost, set to inf: {src_p.ir_cell} to {tgt_p.ir_cell}' ) diff --git a/nnscaler/cli/__init__.py b/nnscaler/cli/__init__.py index 958e874f..d218f6f9 100644 --- a/nnscaler/cli/__init__.py +++ b/nnscaler/cli/__init__.py @@ -17,4 +17,6 @@ AggregatedOutputs, ) +from nnscaler.cli.serialization import register_format + from nnscaler.parallel import ComputeConfig diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index f5caab0d..1adf6b9c 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -3,6 +3,7 @@ import os import copy +import logging from typing import List, Optional, Tuple, Dict, Any, Union from dataclasses import dataclass, field, is_dataclass, asdict @@ -16,6 +17,7 @@ except ImportError: UnionType = None # for python < 3.10 +logger = logging.getLogger(__name__) _TYPE_KEY = '__type' _VALUE_TYPE_KEY = '__value_type' @@ -390,6 +392,8 @@ def _deserialize_object(value, value_type): else: raise ValueError(f"Failed to deserialize {value} to {value_type}") if _is_primitive_type(value_type): + if callable(value): + logger.warning(f'{value} is callable, converting to {value_type} may not work as expected.') return value_type(value) except Exception as ex: raise ValueError(f"Failed to deserialize {value} to {value_type}") from ex diff --git a/nnscaler/cli/checkpoint.py b/nnscaler/cli/checkpoint.py new file mode 100644 index 00000000..4d080620 --- /dev/null +++ b/nnscaler/cli/checkpoint.py @@ -0,0 +1,163 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This script provides functionality to convert a merged checkpoint or directory containing per-rank checkpoints into sharded checkpoints +suitable for distributed training with multiple GPUs. +Run this script with: + python -m nnscaler.cli.checkpoint distribute -f +where is the path to the merged checkpoint file or directory containing per-rank checkpoints, +and is the directory to save the sharded checkpoints. + +This script only for command line. +""" + +import logging +import os +import sys +from pathlib import Path + +import torch.distributed + +import nnscaler +from nnscaler.cli.trainer import Trainer, TrainerArgs +from nnscaler.parallel import _trim_module_merged_state_dict, _trim_optimizer_merged_state_dict + + +logger = logging.getLogger(__name__) + + +def _patch_distributed(): + groups = {} + + def is_initialized(): + return bool(groups) + + torch.distributed.is_initialized = is_initialized + + def init_process_group(*args, **kwargs): + world_size = int(os.environ['WORLD_SIZE']) + groups[None] = list(range(world_size)) + + def get_rank(group=None): + if group not in groups: + raise ValueError(f"Unknown group: {group}") + try: + return groups[group].index(int(os.environ['RANK'])) + except ValueError: + return -1 + + def get_world_size(group=None): + if group not in groups: + raise ValueError(f"Unknown group: {group}") + return len(groups[group]) + + def new_group(ranks=None, *args, **kwargs): + world_size = int(os.environ['WORLD_SIZE']) + if ranks is None or len(ranks) == world_size: + return + group_id = tuple(sorted(ranks)) + if group_id in groups: + return group_id + groups[group_id] = ranks + return group_id + + torch.distributed.get_rank = get_rank + torch.distributed.get_world_size = get_world_size + torch.distributed.init_process_group = init_process_group + torch.distributed.destroy_process_group = lambda: None + torch.distributed.new_group = new_group + torch.distributed.barrier = lambda *args, **kwargs: None + torch.distributed.all_gather = lambda *args, **kwargs: None + torch.distributed.broadcast_object_list = lambda *args, **kwargs: None + + +def _trim_merged_checkpoint(train_args: TrainerArgs, merged_state_dict, rank: int): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = '0' + os.environ['WORLD_SIZE'] = str(train_args.compute_config.runtime_ngpus) + os.environ['GROUP_RANK'] = str(rank) + os.environ['LOCAL_WORLD_SIZE'] = '1' + os.environ['TORCHELASTIC_RUN_ID'] = '0' # fake torchrun env + + sharded_state_dict = {k: v for k, v in merged_state_dict.items()} + + trainer = Trainer(train_args=train_args) + # enforce run mode to load module and optimizer + trainer.train_args.run_mode = 'run' + trainer._setup() + + sharded_state_dict['model'] = _trim_module_merged_state_dict( + trainer.model, merged_state_dict['model'], + device='cpu' + ) + sharded_state_dict['optimizer'] = _trim_optimizer_merged_state_dict( + trainer.model, trainer.optimizer._extra_state, merged_state_dict['optimizer'], + device='cpu' + ) + sharded_state_dict['train_args'] = train_args.to_dict() + sharded_state_dict['train_args'].setdefault('checkpoint', {})['save_type'] = 'sharded' + # discard rng_states for merged state dict + sharded_state_dict.pop('rng_states', None) + if 'dataloader' in sharded_state_dict and sharded_state_dict['dataloader'] is not None: + # keep dataloader state only when all ranks have the same state + dataloader_states = sharded_state_dict['dataloader'] + if all(dataloader_states[i] == dataloader_states[0] for i in range(1, len(dataloader_states))): + sharded_state_dict['dataloader'] = dataloader_states[0] + else: + sharded_state_dict.pop('dataloader') + + # make it sharded checkpoint + for module_path, m in trainer.model.named_modules(): + prefix = module_path + '.' if module_path else '' + if isinstance(m, nnscaler.ParallelModule): + m._add_extra_state(sharded_state_dict['model'], prefix) + return sharded_state_dict + + +def _distribute_checkpoint(train_args: TrainerArgs, from_: str, to_: str): + nnscaler.utils.set_default_logger_level(level=logging.INFO) + _patch_distributed() + resume_from = Path(from_) + save_to = Path(to_) + save_to.mkdir(parents=True, exist_ok=True) + checkpointer = train_args.create_checkpointer() + + if resume_from.is_file(): + state_dict = checkpointer.load(resume_from, device='cpu') + if convert_fn := train_args.checkpoint.resolved_convert_fn: + state_dict = convert_fn(state_dict) + else: + ckpt_files = checkpointer.list_checkpoints(resume_from) + rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} + if set(rank_ckpt_files.keys()) != set(range(len(rank_ckpt_files))): + raise ValueError(f"Checkpoint files in {resume_from} are not complete: {rank_ckpt_files.keys()}") + state_dict = Trainer._merge_checkpoint(list(rank_ckpt_files.values()), checkpointer=checkpointer) + + for i in range(train_args.compute_config.runtime_ngpus): + sharded_state_dict = _trim_merged_checkpoint(train_args, state_dict, i) + checkpointer.save_for_rank(sharded_state_dict, save_to, i) + + checkpointer.flush() + + +if __name__ == '__main__': + argv = sys.argv[1:] + if len(argv) == 0: + raise ValueError("No command specified. Expected `distribute -f `") + if argv[0] == 'distribute': + if len(argv) < 5: + raise ValueError("Not enough arguments. Expected at least `distribute -f `") + from_ = argv[1] + to_ = argv[2] + train_args = TrainerArgs.from_cli(argv[3:]) + # never broadcast generated files. + train_args.broadcast_strategy = 'none' + train_args.checkpoint.resume_from = None + _distribute_checkpoint(train_args, from_, to_) + else: + raise ValueError(f"Unknown command: {argv[0]}") +else: + # we have patched too many things. + # please run this script with `python -m nnscaler.cli.checkpoint` + raise ImportError("checkpoint.py should be run as a script.") diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index d7354617..ef4268ca 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -19,6 +19,7 @@ TrainerArgs, PrecisionMixin, PolicyMixin, ModuleParallelizeConfig, ComputeConfig, load_type ) +from .serialization import Checkpointer logger = logging.getLogger(__name__) @@ -41,7 +42,8 @@ class ModuleParallelizeConfigAdapter(PrecisionMixin, PolicyMixin): def __init__( self, trainer_args: TrainerArgs, parallel_module: Optional[ModuleParallelizeConfig] = None, - tracing_weights: Optional[dict[str, Any]] = None + tracing_weights: Optional[dict[str, Any]] = None, + checkpointer: Optional[Checkpointer] = None, ): """ Args: @@ -52,6 +54,7 @@ def __init__( self.trainer_args = trainer_args self.parallel_module = parallel_module self.tracing_weights = tracing_weights + self.checkpointer = checkpointer or Checkpointer() # we don't want to load the tracing weights every time # It should be loaded only once outside, and passed to the adapter @@ -132,10 +135,10 @@ def load_tracing_weights(self) -> Optional[dict[str, Any]]: # try to reuse the weights from the tracing weights tracing_weights = self.tracing_weights if self.tracing_from_weights and tracing_weights is None: - tracing_weights = torch.load(self.tracing_from_weights) + tracing_weights = self.checkpointer.load(self.tracing_from_weights) else: if self.tracing_from_weights: - tracing_weights = torch.load(self.tracing_from_weights) + tracing_weights = self.checkpointer.load(self.tracing_from_weights) elif self.parallel_module.tracing_from_weights_prefix: leading_key = self.parallel_module.tracing_from_weights_prefix + '.' tracing_weights = {} @@ -166,21 +169,29 @@ def create_model(self, module_args: Optional[tuple[tuple, dict]]=None) -> torch. def create_dummy_forward_args(self, dummy_input) -> dict[str, Any]: if self.parallel_module: - return self.fix_input( + forward_args = self.fix_input( self.parallel_module.create_dummy_forward_args(self.trainer_args) ) - - # forward args of whole model - arg_names = list( - inspect.signature( - inspect.unwrap(getattr(self.model_type, 'forward')) - ).parameters.keys() - ) - return {arg_names[1]: self.fix_input(dummy_input)} # arg_names[0] is self + if self.parallel_module.forward_args_post_process_fn: + forward_args = self.parallel_module.forward_args_post_process_fn(self.trainer_args, forward_args) + return forward_args + else: + # forward args of whole model + arg_names = list( + inspect.signature( + inspect.unwrap(getattr(self.model_type, 'forward')) + ).parameters.keys() + ) + # dummy input is already fixed and post processed by trainer + forward_args = {arg_names[1]: dummy_input} # arg_names[0] is self + return forward_args def resolve_compute_config(self): compute_config = copy.deepcopy(self.compute_config) - compute_config.pas_config['__pas_name'] = self.pas_policy + compute_config.pas_config['__pas_name'] = \ + self.pas_policy \ + if not callable(self.pas_policy) \ + else f'{self.pas_policy.__module__}.{self.pas_policy.__qualname__}' # autodist configs compute_config.pas_config['update_freq'] = self.trainer_args.update_freq compute_config.pas_config['use_bf16'] = self.param_dtype == torch.bfloat16 @@ -197,6 +208,7 @@ def resolve_compute_config(self): def parallelize(self, dummy_input: Optional[dict[str, Any]] = None, *, load_module: bool = True, + build_buckets: bool = True, module_args: Optional[tuple[tuple, dict]] = None ): pmodel_class = nnscaler.parallelize( @@ -212,7 +224,7 @@ def parallelize(self, load_module=load_module, ) if load_module: - return pmodel_class() + return pmodel_class(build_buckets=build_buckets) return pmodel_class @@ -279,24 +291,32 @@ def parameters_for_calc_gnorm(self): return model -def parallelize_model(trainer_args: TrainerArgs, dummy_input: dict[str, Any], load_module: bool): +def parallelize_model( + trainer_args: TrainerArgs, + dummy_input: dict[str, Any], + load_module: bool, + build_buckets: bool, + checkpointer: Checkpointer +): tracing_weights = None + checkpointer = checkpointer or Checkpointer() if trainer_args.tracing_from_weights: - tracing_weights = torch.load(trainer_args.tracing_from_weights) + tracing_weights = checkpointer.load(trainer_args.tracing_from_weights) def _new_adapter(parallel_module=None): return ModuleParallelizeConfigAdapter( trainer_args, parallel_module, - tracing_weights=tracing_weights + tracing_weights=tracing_weights, + checkpointer=checkpointer, ) if not trainer_args.model.parallel_modules: # parallelize the whole model - return _new_adapter().parallelize(dummy_input, load_module=load_module) + return _new_adapter().parallelize(dummy_input, load_module=load_module, build_buckets=build_buckets) if not load_module and all(pm.args is not None for pm in trainer_args.model.parallel_modules): for m in trainer_args.model.parallel_modules: - _new_adapter(m).parallelize(dummy_input, load_module=False) + _new_adapter(m).parallelize(dummy_input, load_module=False, build_buckets=build_buckets) return parallel_sub_modules = { @@ -346,7 +366,7 @@ def __parallel__new__(cls, *args, **kwargs): # This is a trade-off to make sure the parallelized module is consistent. # Maybe we can use torch.distributed.broadcast to sync the random state in all devices. with fork_rng(): - return adapter.parallelize(dummy_input, load_module=load_module, module_args=(args, kwargs)) + return adapter.parallelize(dummy_input, load_module=load_module, build_buckets=build_buckets, module_args=(args, kwargs)) finally: _patch_new() diff --git a/nnscaler/cli/serialization.py b/nnscaler/cli/serialization.py new file mode 100644 index 00000000..c499740f --- /dev/null +++ b/nnscaler/cli/serialization.py @@ -0,0 +1,464 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, Callable, Protocol, Type +from pathlib import Path +import shutil + +import torch + +from nnscaler.runtime.serialization import load, save + + +class _LoadProc(Protocol): + def __call__(self, f: str | Path, *, device='cpu') -> Any: ... + + +class _SaveProc(Protocol): + def __call__(self, obj: Any, f: str | Path) -> None: ... + + +class CheckpointFormat(Protocol): + """ + A placeholder class for new serialization formats. + """ + name: str + suffix: str + + @classmethod + def load(cls, f: str | Path, *, device='cpu') -> Any: + ... + + @classmethod + def save(cls, obj: Any, f: str | Path) -> None: + ... + + +class SerializationRunner(Protocol): + name: str + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + ... + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + ... + + def flush(self) -> None: + """ + Flushes any pending operations for saving. + Loading operations are assumed to be synchronous. + """ + ... + + +class _DefaultSerializationRunner: + name: str = '' + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + return load_func(f, device=device) + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + save_func(obj, f) + + def flush(self) -> None: + pass + + +def make_hybrid_serialization_runner( + load_serializer: Type[SerializationRunner], + save_serializer: Type[SerializationRunner] +) -> Type[SerializationRunner]: + """ + Creates a hybrid serialization runner that uses different runners for loading and saving. + """ + class HybridSerializationRunner(SerializationRunner): + name = f"{load_serializer.name}:{save_serializer.name}" + + def __init__(self, load_args=None, save_args=None): + self._load_runner = load_serializer(**(load_args or {})) + self._save_runner = save_serializer(**(save_args or {})) + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + return self._load_runner.run_load(load_func, f, device=device) + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + self._save_runner.run_save(save_func, obj, f) + + def flush(self) -> None: + self._save_runner.flush() + + return HybridSerializationRunner + + +def _torch_load(f: str | Path, *, device='cpu') -> Any: + return torch.load(f, map_location=device, weights_only=False) + + +def _torch_save(obj: Any, f: str | Path) -> None: + torch.save(obj, f) + + +class Checkpointer: + # the format of the checkpoint file + # keys: epoch, step, rank + # currently it is not configurable + # TODO: make it configurable + CHECKPOINT_FILE_NAME_FORMAT: str = '{rank}{suffix}' + CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/' + CHECKPOINT_FILE_NAME_FORMAT + CHECKPOINT_LAST_DIR_NAME: str = 'last' + CHECKPOINT_BEST_DIR_NAME: str = 'best' + CHECKPOINT_MERGED_FILE_NAME: str = 'merged{suffix}' + CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}{suffix}' + CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}{suffix}' + NAME_MAP: dict[str, str] = { + 'pt': '.ckpt', + 'safetensors': '.safetensors' + } + SUFFIX_MAP: dict[str, str] = {v: k for k, v in NAME_MAP.items()} + # will use torch.load and torch.save for other suffixes + SUFFIX_HANDLERS: dict[str, tuple[_LoadProc, _SaveProc]] = { + '.safetensors': (load, save), + } + REGISTERED_RUNNERS: dict[str, Type[SerializationRunner]] = { + '': _DefaultSerializationRunner, + } + + def __init__(self, format: str = 'pt', serializer: str = None, serializer_args: dict[str, Any] = None): + """ + Args: + format (`str`, *optional*, defaults to `"pt"`): + The checkpoint format to use. Builtin formats are: + - `"pt"`: PyTorch checkpoint format. + - `"safetensors"`: Safetensors format. + serializer (`str`, *optional*): + The serialization runner to use. Builtin runners are: + - `""` (empty string): Default runner that directly uses the load and save functions. + You can also specify a hybrid runner by using the format `load_serializer:save_serializer`, + e.g., `"split:async"`. + serializer_args (`dict`, *optional*): + args for the serialization runner. + """ + if format not in self.NAME_MAP: + raise ValueError(f"Unsupported checkpoint format: {format}") + self.format = format + self.suffix = self.NAME_MAP[format] + + self.runner: SerializationRunner + serializer = serializer or '' + + if ':' in serializer: + parts = serializer.split(':') + if len(parts) != 2: + raise ValueError(f"Invalid hybrid serialization runner: {serializer}") + load_serializer_name = parts[0] + save_serializer_name = parts[1] + if load_serializer_name not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {load_serializer_name}") + if save_serializer_name not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {save_serializer_name}") + load_serializer_type = self.REGISTERED_RUNNERS[load_serializer_name] + save_serializer_type = self.REGISTERED_RUNNERS[save_serializer_name] + runner_cls = make_hybrid_serialization_runner( + load_serializer_type, + save_serializer_type + ) + else: + if serializer not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {serializer}") + runner_cls = self.REGISTERED_RUNNERS[serializer] + + self.runner = runner_cls(**(serializer_args or {})) + + def get_checkpoint_file_path(self, epoch: int, step: int, rank: int) -> str: + return self.CHECKPOINT_FILE_FORMAT.format(epoch=epoch, step=step, rank=rank, suffix=self.suffix) + + def get_last_checkpoint_file_path(self, rank: int) -> str: + return self.CHECKPOINT_LAST_FILE_FORMAT.format(rank=rank, suffix=self.suffix) + + def get_best_checkpoint_file_path(self, rank: int) -> str: + return self.CHECKPOINT_BEST_FILE_FORMAT.format(rank=rank, suffix=self.suffix) + + def get_merged_checkpoint_file_name(self) -> str: + return self.CHECKPOINT_MERGED_FILE_NAME.format(suffix=self.suffix) + + def get_last_dir_name(self) -> str: + return self.CHECKPOINT_LAST_DIR_NAME + + def get_best_dir_name(self) -> str: + return self.CHECKPOINT_BEST_DIR_NAME + + def load(self, f: str | Path, *, device='cpu') -> Any: + """ + Loads a checkpoint file + + Args: + f: filename of the checkpoint file. + if the suffix is .safetensors, it will be loaded as safetensors file. + otherwise, it will be loaded as a PyTorch checkpoint file. + device (`str`, *optional*, defaults to `"cpu"`): + The device on which you want the tensors. + """ + suffix = Path(f).suffix + if suffix in self.SUFFIX_HANDLERS: + load_func, _ = self.SUFFIX_HANDLERS[suffix] + else: + load_func = _torch_load + + return self.runner.run_load(load_func, f, device=device) + + def save(self, obj: Any, f: str | Path) -> None: + """ + Saves a checkpoint file + + Args: + obj (`Any`): + The object to save. + f: filename of the checkpoint file. + if the suffix is .safetensors, it will be saved as safetensors file. + otherwise, it will be saved as a PyTorch checkpoint file. + """ + suffix = Path(f).suffix + if suffix in self.SUFFIX_HANDLERS: + _, save_func = self.SUFFIX_HANDLERS[suffix] + else: + save_func = _torch_save + + self.runner.run_save(save_func, obj, f) + + def load_for_rank(self, dir: str | Path, rank: int, device='cpu') -> Any: + """ + Loads a checkpoint file for a specific rank + + Args: + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to load. + device (`str`, `int`, *optional*): + The device on which you want the tensors. + """ + for suffix in self.NAME_MAP.values(): + f = Path(dir) / f"{rank}{suffix}" + if f.exists(): + return self.load(f, device=device) + raise FileNotFoundError(f"No checkpoint file found for rank {rank} in directory {dir}") + + def save_for_rank(self, obj: Any, dir: str | Path, rank: int) -> None: + """ + Saves a checkpoint file for a specific rank + + Args: + obj (`Any`): + The object to save. + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to save. + """ + f = Path(dir) / self.CHECKPOINT_FILE_NAME_FORMAT.format(rank=rank, suffix=self.suffix) + self.save(obj, f) + + def remove_for_rank(self, dir: str | Path, rank: int) -> None: + """ + Removes a checkpoint file for a specific rank. + Args: + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to remove. + """ + self.flush() + + for suffix in self.NAME_MAP.values(): + f = Path(dir) / f"{rank}{suffix}" + if f.exists(): + f.unlink() + for extra_file in Path(dir).glob(f"{rank}{suffix}.*"): + extra_file.unlink() + + def copy_for_rank(self, src: str | Path, dst: str | Path, rank: int, symlink: bool = False) -> None: + """ + Copies a checkpoint file for a specific rank from one directory to another. + Args: + src (`str`): + The source directory where the checkpoint files are stored. + dst (`str`): + The destination directory where the checkpoint files will be copied. + rank (`int`): + The rank of the checkpoint file to copy. + symlink (`bool`, *optional*, defaults to `False`): + Whether to create a symbolic link instead of copying the file. + """ + + self.flush() + src = Path(src).resolve() + dst = Path(dst).resolve() + dst.mkdir(parents=True, exist_ok=True) + + src_f = Path(src) / f"{rank}{self.suffix}" + dst_f = Path(dst) / f"{rank}{self.suffix}" + + if not src_f.exists(): + raise FileNotFoundError(f"No checkpoint file found for rank {rank} in directory {src}") + + if symlink: + # this restricts symlink creation within the same directory + # so we can create relative symlink safely + if src.parent != dst.parent: + raise ValueError("Cannot create symlink when source and destination are not in the same directory.") + + if symlink: + dst_f.symlink_to(Path('..') / src.name / src_f.name) + for extra_file in src.glob(f"{rank}{self.suffix}.*"): + dst_extra_file = Path(dst) / extra_file.name + dst_extra_file.symlink_to(Path('..') / src.name / extra_file.name) + else: + shutil.copy2(src_f, dst_f) + for extra_file in src.glob(f"{rank}{self.suffix}.*"): + dst_extra_file = Path(dst) / extra_file.name + shutil.copy2(extra_file, dst_extra_file) + + def list_checkpoints(self, dir: str | Path) -> list[Path]: + """ + List the main checkpoint files in a directory + Args: + dir (`str`): + The directory where the checkpoint files are stored. + Returns: + (`list[Path]`): + The list of checkpoint files in the directory. + """ + self.flush() + + p = Path(dir) + files = [] + for suffix in self.NAME_MAP.values(): + fs = list(p.glob(f"*{suffix}")) + if fs: + if files: + raise ValueError(f"Mixed checkpoint file formats in directory {dir}") + else: + files.extend(fs) + return files + + def flush(self) -> None: + """ + Flushes any pending operations. + """ + self.runner.flush() + + @classmethod + def get_format(cls, suffix: str) -> str: + """ + Gets the format name from the suffix. + """ + suffix = '.' + suffix.lstrip('.') + if suffix not in Checkpointer.SUFFIX_MAP: + raise ValueError(f"Unsupported checkpoint suffix: {suffix}") + return Checkpointer.SUFFIX_MAP[suffix] + + +def register_format(format: Type[CheckpointFormat]) -> None: + """ + Registers a new serialization format. + """ + suffix = '.' + format.suffix.lstrip('.') + Checkpointer.NAME_MAP[format.name] = suffix + Checkpointer.SUFFIX_MAP[suffix] = format.name + Checkpointer.SUFFIX_HANDLERS[suffix] = (format.load, format.save) + + +def register_serialization_runner(runner: Type[SerializationRunner]) -> None: + """ + Register a new serialization runner, which can intercept the load and save process. + For example, file redirection, chunking, asynchronous IO or other logic. + + Please note if you create extra files during saving, + you must make sure + 1. the suffix of the main checkpoint file must match registered formats. + 2. the name of extra files should start with the main checkpoint file name + '.', + but the suffix should not conflict with registered formats. + + For example, if the input checkpoint file is `model.ckpt`, + you must create a file called 'model.ckpt', + and you can use extra file names like 'model.ckpt.1', 'model.ckpt.meta', 'model.ckpt.opt' etc. + """ + if ':' in runner.name: + raise ValueError("Serialization runner name cannot contain ':'") + Checkpointer.REGISTERED_RUNNERS[runner.name] = runner + + +def convert_format( + src: str | Path, + dst: str | Path, + *, + src_serializer: str = None, + src_serializer_args: dict = None, + dst_serializer: str = None, + dst_serializer_args: dict = None, + device: str = 'cpu' +) -> None: + """ + Converts a checkpoint file from one format to another. + + Args: + src (`str` or `Path`): + The input checkpoint file. + dst (`str` or `Path`): + The output checkpoint file. + src_serializer (`str`, *optional*): + The serialization runner of the input checkpoint file. + src_serializer_args (`dict`, *optional*): + The arguments for the serialization runner of the input checkpoint file. + dst_serializer (`str`, *optional*): + The serialization runner of the output checkpoint file. + dst_serializer_args (`dict`, *optional*): + The arguments for the serialization runner of the output checkpoint file. + device (`str`, *optional*, defaults to `"cpu"`): + The device on which you want the tensors. + """ + src_format = Checkpointer.get_format(Path(src).suffix) + dst_format = Checkpointer.get_format(Path(dst).suffix) + + if src_format == dst_format and src_serializer == dst_serializer: + raise ValueError("Input and output formats and serializers are the same, no conversion needed.") + + src_checkpointer = Checkpointer(format=src_format, serializer=src_serializer, serializer_args=src_serializer_args) + dst_checkpointer = Checkpointer(format=dst_format, serializer=dst_serializer, serializer_args=dst_serializer_args) + + obj = src_checkpointer.load(src, device=device) + dst_checkpointer.save(obj, dst) + dst_checkpointer.flush() diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index 76abeb8b..7848ae31 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -333,3 +333,31 @@ def after_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: for hook in self.hooks: hook.on_save_checkpoint(trainer, checkpoint) + + +class TrainHookHost: + def _get_hook_objects(self) -> List[Any]: + """ + Return a list of objects that can be hooks (but not necessarily hooks) + """ + ... + + def get_hooks(self) -> List[TrainHook]: + """ + Return a list of TrainHook objects + """ + hooks = {} + visited = set() + def _get_hooks(obj): + if id(obj) in visited: + return + visited.add(id(obj)) + + if isinstance(obj, TrainHook): + hooks[id(obj)] = obj + if isinstance(obj, TrainHookHost): + for o in obj._get_hook_objects(): + _get_hooks(o) + + _get_hooks(self) + return list(hooks.values()) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index d167b95f..c5d8289e 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -22,28 +22,18 @@ from tqdm import tqdm import nnscaler +from nnscaler.runtime.device import DeviceGroup from nnscaler.utils import enforce_zero_num_worker, is_running_distributed from .trainer_args import AggregatedOutputs, TrainerArgs, fix_input -from .train_hook import AggregatedTrainHook, TrainHook +from .train_hook import AggregatedTrainHook, TrainHook, TrainHookHost from .mixed_module import parallelize_model, mixin_module +from .serialization import Checkpointer logger = logging.getLogger(__name__) -# the format of the checkpoint file -# keys: epoch, step, rank -# currently it is not configurable -# TODO: make it configurable -CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/{rank}.ckpt' -CHECKPOINT_LAST_DIR_NAME: str = 'last' -CHECKPOINT_BEST_DIR_NAME: str = 'best' -CHECKPOINT_MERGED_FILE_NAME: str = 'merged.ckpt' -CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}.ckpt' -CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}.ckpt' - - @dataclass class TrainStatus: best_loss = float('inf') @@ -102,6 +92,7 @@ def __init__(self, self.max_train_steps = None self.loggers = [] self.hook = None + self.checkpointer = None # RNG states pending resume; reset to None after resuming self.rng_states_from_resume: dict[str, torch.Tensor] | None = None @@ -111,6 +102,8 @@ def run(self): if not self.train_args.compile_mode: self._train() finally: + if self.checkpointer: + self.checkpointer.flush() for stage in ['train', 'val', 'test']: if self.dataloader[stage] is not None and (close_fn := getattr(self.dataloader[stage], 'close', None)): close_fn() @@ -127,7 +120,7 @@ def _fix_input(self, input): return fix_input(input, self.train_args.input_dtype) def _load_dummy_input(self): - if dummy_sample_gen_fn := self.train_args.resolved_dummy_sample_gen_fn: + if dummy_sample_gen_fn := self.train_args.dummy_sample_gen_fn: return dummy_sample_gen_fn(self.train_args) with enforce_zero_num_worker(DataLoader): @@ -142,12 +135,13 @@ def _load_dummy_input(self): def _setup(self): if is_running_distributed(): nnscaler.init() - if torch.distributed.get_rank() == 0: + if DeviceGroup().local_rank == 0: logging.getLogger().setLevel(logging.INFO) else: logging.getLogger().setLevel(logging.WARNING) self.train_args.init_env(self) + self.checkpointer = self.train_args.create_checkpointer() # make sure all ranks are synchronized after init_env if is_running_distributed(): @@ -158,8 +152,15 @@ def _setup(self): # load a dummy input from training dataset self.dummy_input = self._load_dummy_input() self.dummy_input = self._fix_input(self.dummy_input) - - pmodel = parallelize_model(self.train_args, self.dummy_input, load_module=not compile_only) + if self.train_args.dummy_sample_post_process_fn: + self.dummy_input = self.train_args.dummy_sample_post_process_fn(self.train_args, self.dummy_input) + + pmodel = parallelize_model( + self.train_args, self.dummy_input, + load_module=not compile_only, + build_buckets=not self.train_args.should_delay_bucket_building(), + checkpointer=self.checkpointer, + ) if compile_only: return @@ -186,6 +187,17 @@ def _setup(self): self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) self.local_rank = int(os.environ.get('LOCAL_RANK')) self.node_rank = int(os.environ.get('GROUP_RANK')) + assert self.rank // self.local_world_size == self.node_rank + self.local_ranks = list( + range( + self.node_rank * self.local_world_size, + (self.node_rank + 1) * self.local_world_size + ) + ) + self.local_rank0 = self.local_ranks[0] + # create local process groups + for local_rank0 in range(0, self.world_size, self.local_world_size): + DeviceGroup().get_group(list(range(local_rank0, local_rank0 + self.local_world_size))) self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq if len(self.dataloader['train']) % self.train_args.update_freq != 0: @@ -216,6 +228,7 @@ def _setup(self): def reducer_pre_hook(reducer, grad): grad.div_(self.train_args.scaling_factor) self.optimizer.register_reducer_pre_hook(reducer_pre_hook) + # Currently we never pass `last_epoch` to its constructor self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) self.loggers = self.train_args.create_loggers() @@ -224,8 +237,18 @@ def reducer_pre_hook(reducer, grad): self.optimizer, self.lr_scheduler, ] + component_hooks = [] + for component in supported_hook_components: + if isinstance(component, TrainHook): + component_hooks.append(component) + if isinstance(component, TrainHookHost): + component_hooks.extend(component.get_hooks()) + + # dedup hooks + component_hooks = list({id(hook): hook for hook in component_hooks}.values()) + self.hook = AggregatedTrainHook( - [x for x in supported_hook_components if isinstance(x, TrainHook)] + component_hooks + [self.train_args.create_hook()] ) @@ -235,8 +258,19 @@ def reducer_pre_hook(reducer, grad): self.hook.after_setup(self) @classmethod - def _merge_checkpoint(cls, checkpoint_files: List[str]): - state_dicts = [torch.load(f, map_location='cpu', weights_only=False) for f in checkpoint_files] + def _merge_checkpoint(cls, checkpoint_files: List[str], + *, + model_only: bool = False, + checkpointer: Optional[Checkpointer] = None, + ): + checkpointer = checkpointer or Checkpointer() + state_dicts = [] + for f in checkpoint_files: + state_dict = checkpointer.load(f) + if model_only: + # we pop optimizer state to save cpu memory + state_dict.pop('optimizer', None) + state_dicts.append(state_dict) for i in range(1, len(state_dicts)): if state_dicts[i]['train_args'] != state_dicts[0]['train_args']: raise ValueError(f"train_args in {checkpoint_files[i]} is different from {checkpoint_files[0]}") @@ -245,14 +279,16 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): module_state_dict, opt_state_dict = nnscaler.merge_state_dicts( [s['model'] for s in state_dicts], - [s['optimizer'] for s in state_dicts] + [s['optimizer'] for s in state_dicts] if not model_only else None, ) + if model_only: + return {'model': module_state_dict} train_args = copy.deepcopy(state_dicts[0]['train_args']) train_args['checkpoint']['save_type'] = 'merged' global_keys = { 'model', 'optimizer', 'train_args', - 'train_status', 'lr_scheduler', 'rank' + 'train_status', 'lr_scheduler', 'rank', 'nnscaler' } # for extra keys (including `dataloader` and `rng_states`), we will not merge them. # Intead we will collect them from all state_dicts @@ -271,52 +307,62 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): 'lr_scheduler': state_dicts[0].get('lr_scheduler', None), 'train_status': state_dicts[0]['train_status'], 'train_args': train_args, + 'nnscaler': state_dicts[0]['nnscaler'], **extra_keys, } return merged_state_dict - def _broadcast_merged_state_dict(self, state_dict: Dict[str, Any]): + def _broadcast_merged_state_dict( + self, + state_dict: Dict[str, Any], + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + ): """ Broadcast the merged state dict to all ranks. We can't broadcast the whole state_dict at once, because it may be too large, and leads to OOM. Here we will break the model and optimizer state_dict into smaller pieces and broadcast them one by one. Please note we use `torch.distributed.broadcast_object_list` to broadcast the state_dict (including tensors inside). """ + dst_ranks = dst_ranks or list(range(torch.distributed.get_world_size())) + if src_rank not in dst_ranks or self.rank not in dst_ranks: + raise ValueError(f"src_rank and current rank must be in dst_ranks: {dst_ranks}") + pg = DeviceGroup().get_group(dst_ranks) + + if self.rank == src_rank: + if state_dict is None: + raise ValueError("state_dict should not be None in rank 0 when broadcasting") + else: + if state_dict is not None: + raise ValueError("state_dict should be None in other ranks when broadcasting") + state_dict = {} def _broadcast_keys(sdict: Dict[str, Any], set_keys=True): - if self.rank == 0: + if self.rank == src_rank: state_keys = list(sdict.keys()) else: state_keys = None state_key_list = [state_keys] - torch.distributed.broadcast_object_list(state_key_list, src=0) + torch.distributed.broadcast_object_list(state_key_list, src=src_rank, group=pg) state_keys = state_key_list[0] - if set_keys and self.rank != 0: + if set_keys and self.rank != src_rank: for key in state_keys: sdict[key] = {} # assume the values are empty dicts return state_keys def _broadcast_value(sdict, key): - if self.rank == 0: + if self.rank == src_rank: value_list = [sdict[key]] else: value_list = [None] - torch.distributed.broadcast_object_list(value_list, src=0) - if self.rank != 0: + torch.distributed.broadcast_object_list(value_list, src=src_rank, group=pg) + if self.rank != src_rank: sdict[key] = value_list[0] def _broadcast_values(sdict, keys): for key in keys: _broadcast_value(sdict, key) - if self.rank == 0: - if state_dict is None: - raise ValueError("state_dict should not be None in rank 0 when broadcasting") - else: - if state_dict is not None: - raise ValueError("state_dict should be None in other ranks when broadcasting") - state_dict = {} - state_keys = _broadcast_keys(state_dict) for skey in state_keys: @@ -339,9 +385,26 @@ def _broadcast_values(sdict, keys): return state_dict @classmethod - def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): - merged_state_dict = cls._merge_checkpoint(checkpoint_files) - torch.save(merged_state_dict, output_file) + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str, + *, + model_only: bool = False, + checkpointer: Optional[Checkpointer] = None, + serializer: Optional[str] = None, + serializer_args: Optional[dict[str, Any]] = None, + ): + if checkpointer is not None: + if serializer is not None or serializer_args is not None: + raise ValueError("serializer and serializer_args should not be specified when checkpointer is given") + else: + checkpointer = Checkpointer(serializer=serializer, serializer_args=serializer_args) + + merged_state_dict = cls._merge_checkpoint( + checkpoint_files, + model_only=model_only, + checkpointer=checkpointer, + ) + checkpointer.save(merged_state_dict, output_file) + checkpointer.flush() def _log_finalize(self): for logger in self.loggers: @@ -361,13 +424,21 @@ def _load_checkpoint(self): if not resume_from: return logger.info(f"Resuming from {resume_from}") + trimmed_broadcast_required = False + load_from_merged = False + if resume_from.is_file(): - resume_from = resume_from # when we load from merged checkpoint - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) - if convert_fn := self.train_args.checkpoint.resolved_convert_fn: - state_dict = convert_fn(state_dict) + # when we load from merged checkpoint + load_from_merged = True + trimmed_broadcast_required = self.train_args.checkpoint.resume_from.save_memory + if not self.train_args.checkpoint.resume_from.save_memory or self.local_rank == 0: + state_dict = self.checkpointer.load(resume_from) + if convert_fn := self.train_args.checkpoint.resolved_convert_fn: + state_dict = convert_fn(state_dict) + else: + state_dict = None else: - ckpt_files = list(resume_from.glob('*.ckpt')) + ckpt_files = self.checkpointer.list_checkpoints(resume_from) rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} if set(rank_ckpt_files.keys()) != set(range(len(rank_ckpt_files))): raise ValueError(f"Checkpoint files in {resume_from} are not complete: {rank_ckpt_files.keys()}") @@ -378,24 +449,59 @@ def _load_checkpoint(self): if len(rank_ckpt_files) != self.world_size or self.train_args.checkpoint.resume_from.with_merged: # merge the checkpoint files from all ranks and broadcast to all ranks torch.distributed.barrier() - if self.rank == 0: + if self.local_rank == 0: logger.info(f"Merging checkpoint files from {resume_from}") - state_dict = self._merge_checkpoint(list(rank_ckpt_files.values())) + state_dict = self._merge_checkpoint(list(rank_ckpt_files.values()), checkpointer=self.checkpointer) else: state_dict = None - logger.info(f"Broadcasting merged checkpoint to all ranks.") - state_dict = self._broadcast_merged_state_dict(state_dict) - logger.info(f"Broadcasted merged checkpoint to all ranks.") + + load_from_merged = True + trimmed_broadcast_required = self.train_args.checkpoint.resume_from.save_memory + if not self.train_args.checkpoint.resume_from.save_memory: + logger.info(f"Broadcasting merged checkpoint to all ranks.") + state_dict = self._broadcast_merged_state_dict( + state_dict, src_rank=self.local_rank0, dst_ranks=self.local_ranks + ) + logger.info(f"Broadcasted merged checkpoint to all ranks.") else: - resume_from = resume_from / f'{self.rank}.ckpt' - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + state_dict = self.checkpointer.load_for_rank(resume_from, self.rank) + if state_dict['train_args']['compute_config'] != asdict(self.train_args.compute_config): + logger.warning( + f"compute_config is changed, and loading checkpoint may fail. " + f"If it fails, please try with merged checkpoint." + ) + + if trimmed_broadcast_required: + logger.info("Broadcasting trimmed checkpoint to all ranks.") + state_dict = state_dict or {} + state_dict['model'], state_dict['optimizer'] = nnscaler.trimmed_broadcast_merged_state_dict( + self.model, + state_dict['model'] if self.local_rank == 0 else None, + self.optimizer, + state_dict['optimizer'] if self.local_rank == 0 else None, + src_rank=self.local_rank0, + dst_ranks=self.local_ranks, + ) + remaining_state_dict = self._broadcast_merged_state_dict( + {k: v for k, v in state_dict.items() if k not in ('model', 'optimizer')} + if self.local_rank == 0 else None, + src_rank=self.local_rank0, + dst_ranks=self.local_ranks, + ) + if self.local_rank != 0: + state_dict.update(remaining_state_dict) + logger.info("Broadcasted trimmed checkpoint to all ranks.") + + # trimmed checkpoint is sharded + ckpt_save_type = 'sharded' + else: + # if it is not a well-formed state_dict (from third party) + # we will treat it as a merged state_dict + ckpt_save_type = state_dict.get('train_args', {}) \ + .get('checkpoint', {}) \ + .get('save_type', 'merged') self.hook.on_load_checkpoint(self, state_dict) - # if it is not a well-formed state_dict (from third party) - # we will treat it as a merged state_dict - ckpt_save_type = state_dict.get('train_args', {}) \ - .get('checkpoint', {}) \ - .get('save_type', 'merged') if ckpt_save_type == 'merged': # it is a merged state dict nnscaler.load_merged_state_dict( @@ -420,10 +526,11 @@ def _load_checkpoint(self): raise ValueError("lr_scheduler is not set in the current trainer") if self.lr_scheduler: self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + if 'dataloader' in state_dict and state_dict['dataloader'] is not None: if not self._is_resumable_dataloader(): raise ValueError("dataloader is not resumable, but checkpoint contains dataloader state") - if ckpt_save_type == 'merged': + if load_from_merged: dataloader_states = state_dict['dataloader'] # only load dataloader state when all ranks have the same state # TODO: is this reasonable? @@ -443,7 +550,7 @@ def _load_checkpoint(self): self.train_status = TrainStatus(**state_dict['train_status']) # we don't resume rng states when loading merged checkpoint, - if ckpt_save_type != 'merged': + if not load_from_merged: self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() else: logger.warning("RNG states are not resumed when loading merged checkpoint.") @@ -541,50 +648,47 @@ def _save_checkpoint(self, loss): self.hook.on_save_checkpoint(self, state_dict) - ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( + ckpt_file = save_dir / self.checkpointer.get_checkpoint_file_path( epoch=current_epoch, step=self.train_status.finished_train_steps, rank=self.rank, ) logger.info(f"Saving checkpoint to {str(ckpt_file.parent)}") ckpt_file.parent.mkdir(parents=True, exist_ok=True) - torch.save(state_dict, ckpt_file) + self.checkpointer.save(state_dict, ckpt_file) # save last if checkpoint_config.save_last: logger.info(f"Saving checkpoint as the last checkpoint.") - last_file = save_dir / CHECKPOINT_LAST_FILE_FORMAT.format( - rank=self.rank + + # remove the old symlink or file + self.checkpointer.remove_for_rank( + save_dir / self.checkpointer.get_last_dir_name(), + self.rank + ) + self.checkpointer.copy_for_rank( + ckpt_file.parent, + save_dir / self.checkpointer.get_last_dir_name(), + self.rank, + checkpoint_config.symlink_best_and_last ) - last_file.parent.mkdir(parents=True, exist_ok=True) - if checkpoint_config.symlink_best_and_last: - # remove the old symlink or file - if last_file.is_symlink() or last_file.exists(): - last_file.unlink() - # symblink as relative path - last_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) - # last_file.symlink_to(ckpt_file) - else: - shutil.copy(ckpt_file, last_file) # save best if checkpoint_config.save_best and loss <= self.train_status.best_loss: logger.info(f"Best loss updated: {self.train_status.best_loss:.3f} -> {loss:.3f}") logger.info(f"Saving checkpoint as the best checkpoint.") - best_file = save_dir / CHECKPOINT_BEST_FILE_FORMAT.format( - epoch=current_epoch, - step=self.train_status.finished_train_steps, - rank=self.rank, + + # remove the old symlink or file + self.checkpointer.remove_for_rank( + save_dir / self.checkpointer.get_best_dir_name(), + self.rank + ) + self.checkpointer.copy_for_rank( + ckpt_file.parent, + save_dir / self.checkpointer.get_best_dir_name(), + self.rank, + checkpoint_config.symlink_best_and_last ) - best_file.parent.mkdir(parents=True, exist_ok=True) - if checkpoint_config.symlink_best_and_last: - # symblink as relative path - if best_file.is_symlink() or best_file.exists(): - best_file.unlink() - best_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) - # best_file.symlink_to(ckpt_file) - else: - shutil.copy(ckpt_file, best_file) torch.distributed.barrier() # remove old checkpoints @@ -604,7 +708,10 @@ def _expire_checkpoints(self): save_dir = Path(self.train_args.checkpoint.save_dir) checkpoints = [ p.name for p in save_dir.glob('*') - if p.is_dir() and p.name not in [CHECKPOINT_BEST_DIR_NAME, CHECKPOINT_LAST_DIR_NAME] + if p.is_dir() and p.name not in [ + self.checkpointer.get_best_dir_name(), + self.checkpointer.get_last_dir_name() + ] ] if len(checkpoints) <= self.train_args.checkpoint.keep_last_n_checkpoints: return @@ -614,12 +721,12 @@ def _expire_checkpoints(self): checkpoint_info.sort() expire_list = [c[1] for c in checkpoint_info[:-self.train_args.checkpoint.keep_last_n_checkpoints]] - best_ckpt = save_dir / CHECKPOINT_BEST_DIR_NAME - last_ckpt = save_dir / CHECKPOINT_LAST_DIR_NAME + best_ckpt = save_dir / self.checkpointer.get_best_dir_name() + last_ckpt = save_dir / self.checkpointer.get_last_dir_name() for ckpt_dir in [best_ckpt, last_ckpt]: if not ckpt_dir.exists(): continue - for p in ckpt_dir.glob('*.ckpt'): + for p in self.checkpointer.list_checkpoints(ckpt_dir): if p.is_symlink(): ckpt_name = p.resolve().parent.name if ckpt_name in expire_list: @@ -634,6 +741,7 @@ def _expire_checkpoints(self): def _global_batch_iterator(self, num_skip_first=0, stage='train'): if stage == 'train': if self.dataloader_resumed or num_skip_first == 0: + logger.info(f'Trainer resumes dataloader directly.') # if the checkpoint stops at the end of an epoch, # the rng states must be resumed before creating iterator # because `DataLoader.__iter__()` uses the rng (dunno why), @@ -641,6 +749,7 @@ def _global_batch_iterator(self, num_skip_first=0, stage='train'): self._try_resume_rng_states() it = iter(self.dataloader[stage]) else: # dry run until reach the desired batch. + logger.info(f'Trainer try to resume dataloader for {stage} stage with {num_skip_first}.') it = iter(self.dataloader[stage]) for _ in range(num_skip_first * self.train_args.update_freq): _sample = next(it) @@ -773,7 +882,7 @@ def _train(self): torch.cuda.reset_peak_memory_stats() if self.train_status.finished_train_steps >= self.max_train_steps: - logger.info(f"Training is skipped: already done.") + logger.info(f"Training is skipped: already done, finished_train_steps={self.train_status.finished_train_steps} >= max_train_steps={self.max_train_steps}.") return start_epoch = self.train_status.finished_train_steps // self.total_train_steps_per_epoch @@ -977,7 +1086,7 @@ def _train_epoch(self, epoch: int) -> None: step_stat.gnorm = step_stat.gnorm.item() # update parameters - step_stat.lr = self.optimizer.param_groups[0]['lr'] + step_stat.lr = self.optimizer.param_groups[0]['lr'] # only log the first group's lr self.hook.before_optimizer_step(self) self.optimizer.step() self.hook.after_optimizer_step(self) diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 2fb99daa..5cd811fa 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass, field, replace import importlib -from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union, TypeVar +from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Type, Union, TypeVar from typing_extensions import get_args from pathlib import Path import logging @@ -20,7 +20,7 @@ import torch import nnscaler -from nnscaler.utils import fields, transform_recursively, load_type +from nnscaler.utils import fields, fn_field, transform_recursively, load_type, copy_dynamic from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule @@ -32,6 +32,7 @@ ) from .loggers.logger_base import LoggerBase from .train_hook import TrainHook +from .serialization import Checkpointer if TYPE_CHECKING: from .trainer import Trainer @@ -123,9 +124,9 @@ def fix_input(input, input_dtype=None): return tuple(fix_input(v, input_dtype) for v in input) elif isinstance(input, torch.Tensor): if input.is_floating_point() and input_dtype is not None: - return input.to(input_dtype).cuda() + return copy_dynamic(input, input.to(input_dtype).cuda()) else: - return input.cuda() + return copy_dynamic(input, input.cuda()) return input @@ -238,9 +239,17 @@ class ModuleParallelizeConfig: # we can parallelize submodules instead of creating whole model. # This is useful sometimes. args: Optional[Dict[str, Any]] = None - # the full qualified name of the function to generate dummy forward args - # Its type should be `Callable[[TrainerArgs],Dict[str, Any]]` - forward_args_gen_fn: str = None + # the full qualified name of the function to generate dummy inputs for forward + # Its type should be `Callable[[TrainerArgs], dict[str, Any]]` + # where the output dict is the kwargs for forward function of the module + # The tensors in the sample will be moved to GPU and converted to input_dtype by trainer. + forward_args_gen_fn: Optional[Callable[['TrainerArgs'], dict[str, Any]]] = fn_field(default=None) + # the full qualified name of the function to post process the dummy inputs for forward + # Note the tensors in the inputs have been moved to GPU and converted to input_dtype + # But you can still further process the sample, + # for example, mark some dims of tensors as dynamic + # (you can do it in `forward_args_gen_fn` as well) + forward_args_post_process_fn: Optional[Callable[['TrainerArgs', dict[str, Any]], dict[str, Any]]] = fn_field(default=None) # the model state dict file for tracing. # It is only used in tracing to serve as the initial state dict of the model. tracing_from_weights: str = None @@ -289,8 +298,7 @@ def create_model(self, trainer_args: 'TrainerArgs', module_args: Optional[tuple[ return self.model_type(*args, **kwargs) def create_dummy_forward_args(self, trainer_args: 'TrainerArgs') -> dict[str, Any]: - forward_args_gen_fn = load_type(self.forward_args_gen_fn) - return forward_args_gen_fn(trainer_args) + return self.forward_args_gen_fn(trainer_args) @dataclass @@ -314,6 +322,7 @@ class OptimizerConfig: args: Dict[str, Any] = field(default_factory=dict) clip_gnorm: float = 0.0 + param_clss_fn: Optional[Callable[[str], Any]] = fn_field(default=None) # loss reduction method # mean: average the loss over all micro-batches # sum: sum the loss of all micro-batches @@ -403,6 +412,28 @@ class ResumeOptions: # `None` means will load the sharded checkpoint files if the world size is not changed. # and will load merged checkpoint if the world size is changed. with_merged: Optional[bool] = None + # If the memory is limited, we can save memory by only loading merged state dict in GPU 0 of each node + # and broadcast trimmed state dict to other ranks in the same node + # although this will be slower + # Only used when resuming from a merged checkpoint. + save_memory: bool = True + + +@dataclass +class SerializerOptions: + # the serialization runner to be used + # It should be a name of registered SerializationRunners + name: str = '' + + # the full qualified name of the function to create the serialization runner + # Currently we do not support this way + # to make sure all serialization runners are registered and can be used in other places + # (like nnscaler.cli.Trainer.merge_checkpoint) + # type: str = None + + # arguments for the serialization runner + # Note You should be able to load for any arguments + args: Dict[str, Any] = field(default_factory=dict) @dataclass @@ -410,6 +441,19 @@ class CheckpointConfig: save_dir: str = './checkpoints' no_save: bool = False + # `"pt"`: PyTorch native format + # `"safetensors"`: Safetensors format + # You can also register new formats via `nnscaler.cli.serialization.register_format` + # or specify a custom format here by providing a CheckpointFormat subclass + format: str = 'pt' + + # the serialization runner to be used + # It should be a name of registered SerializationRunners + # If None, the default serializer will be used + serializer: Optional[SerializerOptions] = field(default=None, metadata={ + 'normalize': lambda x: {'name': x} if isinstance(x, str) else x + }) + # `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is # a folder with as many files as the world size. # `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is @@ -452,6 +496,13 @@ def resolved_convert_fn(self) -> Optional[Callable[[Dict[str, Any]], Dict[str, A return load_type(self.resume_from.convert_fn) def __post_init__(self): + # backward compatibility + if isinstance(self.resume_from, str): + self.resume_from = ResumeOptions(checkpoint=self.resume_from) + + if isinstance(self.serializer, str): + self.serializer = SerializerOptions(name=self.serializer) + if self.resume_from and self.resume_from.checkpoint: if self.resume_from.checkpoint in ['last', 'best']: if not self.save_dir: @@ -468,6 +519,12 @@ def __post_init__(self): if not self.save_dir: raise ValueError("save_dir is required") + if self.format not in Checkpointer.NAME_MAP: + raise ValueError(f"Invalid format {self.format}") + + if self.serializer and self.serializer.name not in Checkpointer.REGISTERED_RUNNERS: + raise ValueError(f"Invalid Serialization runner {self.serializer.name}") + if self.every_n_epochs is not None and self.every_n_train_steps is not None: raise ValueError("Cannot specify both every_n_epochs and every_n_train_steps") if self.every_n_epochs is None and self.every_n_train_steps is None: @@ -595,9 +652,16 @@ class TrainerArgs(PrecisionMixin, PolicyMixin): # compile: compile the model but not training # run: compile and run the model run_mode: str = 'run' - # the full qualified name of the function to generate dummy sample for forward + # the full qualified name of the function to generate dummy sample # Its type should be `Callable[[TrainerArgs], Any]` - dummy_sample_gen_fn: str = None + # The tensors in the sample will be moved to GPU and converted to input_dtype by trainer. + dummy_sample_gen_fn: Optional[Callable[['TrainerArgs'], Any]] = fn_field(default=None) + # the full qualified name of the function to post process the dummy sample + # Note the tensors in the sample have been moved to GPU and converted to input_dtype + # But you can still further process the sample, + # for example, you can use this function to mark some dims of tensors as dynamic + # when you don't use `dummy_sample_gen_fn` or don't handle dynamic dims in it, + dummy_sample_post_process_fn: Optional[Callable[['TrainerArgs', Any], Any]] = fn_field(default=None) # the model state dict file for tracing. # It is only used in tracing to serve as the initial state dict of the model. tracing_from_weights: str = None @@ -811,12 +875,6 @@ def resolved_aggregate_outputs_fn(self): return None return load_type(self.optimizer.aggregate_outputs_fn) - @property - def resolved_dummy_sample_gen_fn(self): - if not self.dummy_sample_gen_fn: - return None - return load_type(self.dummy_sample_gen_fn) - @property def scaling_factor(self): return self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus @@ -846,7 +904,7 @@ def init_env(self, trainer: 'Trainer'): init_env_fn = load_type(self.init_env_fn) init_env_fn(trainer) - def get_resolved_var(self, fqn: str) -> Any: + def get_resolved_var(self, fqn: str, *, default: Any = None) -> Any: """ Get a resolved variable from the vars dictionary. The fqn is a full qualified name of the variable, e.g. 'x.y.z'. @@ -855,7 +913,7 @@ def get_resolved_var(self, fqn: str) -> Any: var = self._vars for part in parts: if part not in var: - raise ValueError(f"Variable {fqn} not found in vars") + return default var = var[part] return var @@ -863,10 +921,17 @@ def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model.args) return self.model_type(**kwargs) + def should_delay_bucket_building(self) -> bool: + return self.optimizer.param_clss_fn is not None + def create_parallel_optimizer(self, parallel_model: torch.nn.Module): kwargs = self.create_kwarg(self.optimizer.args) optimizer_class = load_type(self.optimizer.type) - return build_optimizer(parallel_model, optimizer_class, self.compute_config, **kwargs) + return build_optimizer( + parallel_model, optimizer_class, self.compute_config, + self.optimizer.param_clss_fn, + **kwargs + ) def create_dataset(self, stage='train'): dataset_args = getattr(self.dataset, f'{stage}_args') @@ -947,3 +1012,12 @@ def create_hook(self) -> TrainHook: return ArgsTrainHook(hook_config) else: raise ValueError(f"Invalid hook_config {hook_config}") + + def create_checkpointer(self) -> Checkpointer: + if self.checkpoint.serializer: + return Checkpointer( + self.checkpoint.format, + self.checkpoint.serializer.name, + self.checkpoint.serializer.args + ) + return Checkpointer(self.checkpoint.format) diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index fc197b71..3f6e8652 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import inspect from typing import Generator, Iterable, List, Any, Optional, Tuple, Dict import logging @@ -35,7 +36,10 @@ def __repr__(self): return self.name -def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: +def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None, + *, + strip_star: bool = True, +) -> Any: """ Return repr-able value of a tensor or value. For tensor, return IRValue({prefix}{tensor.name}_{tensor.tid}) @@ -44,6 +48,7 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: Args: val (Any): tensor or non-tensor value prefix_attr (str): prefix to the tensor name if the tensor is an attribute + strip_star (bool): whether to strip leading * for *args and **kwargs Returns: the val that can be repr safely """ @@ -51,20 +56,22 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: return val if isinstance(val, IRObject): tensor_name = val.name + if strip_star: + tensor_name = tensor_name.lstrip('*') tensor_name = tensor_name.replace('.', '_') name = '_'.join([tensor_name, str(val.tid)]) if prefix_attr is not None and val.is_attr(): name = prefix_attr + name return IRValue(name) elif isinstance(val, slice): - return slice(_safe_repr_value(val.start, prefix_attr), _safe_repr_value(val.stop, prefix_attr), _safe_repr_value(val.step, prefix_attr)) + return slice(_safe_repr_value(val.start, prefix_attr, strip_star=strip_star), _safe_repr_value(val.stop, prefix_attr, strip_star=strip_star), _safe_repr_value(val.step, prefix_attr, strip_star=strip_star)) elif isinstance(val, dict): - return {_safe_repr_value(k, prefix_attr): _safe_repr_value(v, prefix_attr) for k, v in val.items()} + return {_safe_repr_value(k, prefix_attr, strip_star=strip_star): _safe_repr_value(v, prefix_attr, strip_star=strip_star) for k, v in val.items()} elif isinstance(val, list): - return [_safe_repr_value(v, prefix_attr) for v in val] + return [_safe_repr_value(v, prefix_attr, strip_star=strip_star) for v in val] elif isinstance(val, tuple): # TODO: support subclasses of tuple, like torch.Size? - return tuple(_safe_repr_value(v, prefix_attr) for v in val) + return tuple(_safe_repr_value(v, prefix_attr, strip_star=strip_star) for v in val) elif isinstance(val, (int, str, bool, float, type(None), bytes, type(Ellipsis), torch.dtype)): return val elif isinstance(val, torch.device): @@ -89,7 +96,10 @@ class CodeEmission: def node_name(self, node: IRCell) -> str: return f"{node.name}{node.cid}" - def tensor_name(self, val: Any, prefix_attr: Optional[str] = None) -> str: + def tensor_name(self, val: Any, prefix_attr: Optional[str] = None, + *, + strip_star: bool = True, + ) -> str: """ Return representation of a value or a tensor. For tensor, return the {prefix}{tensor.name}_{tensor.tid} @@ -98,10 +108,13 @@ def tensor_name(self, val: Any, prefix_attr: Optional[str] = None) -> str: Args: val (Any): tensor or non-tensor value prefix_attr (Optional[str]): prefix to the tensor name if the tensor is an attribute + strip_star (bool): whether to strip leading * for *args and **kwargs + You should set it to False when you want to generate code for + function arguments. Returns: representation of the val in str """ - return repr(_safe_repr_value(val, prefix_attr)) + return repr(_safe_repr_value(val, prefix_attr, strip_star=strip_star)) def complex_name(self, val: Any, prefix_attr: Optional[str]=None) -> str: """ @@ -225,8 +238,35 @@ def emit_fnode(self, node: IRFwOperation, runtime_devid: int, plan_ndevs: int, r emit_rule = self._emit_rules.map(signature) body = emit_rule(node, inputs, kwargs, runtime_devid, plan_ndevs, runtime_ndevs) + def _to_tuple_str(names: List[str]) -> str: + if len(names) == 1: + return f'({names[0]}, )' + return '(' + ', '.join(names) + ')' + + def _insert_hook(outputs=None, is_pre: bool=False, output_len: int = 0): + hook = node.pre_hook if is_pre else node.post_hook + if not hook: + return + module_path = inspect.getmodule(hook).__name__ + fsig = f'{module_path}.{hook.__name__}' + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + codes.append( + f'{fsig}(self, ' + + repr(node.hook_meta) + ', ' + + f"{_to_tuple_str(inputs)}, " + + f"dict({', '.join(kw_pairs)})" + + ('' if is_pre else ', ' + outputs) + + ')' + ) + + _insert_hook(is_pre=True) + if len(node.outputs()) == 0: codes.append(body) + _insert_hook(is_pre=False, outputs='None') else: irobj_path = {} def r(t, current_path): @@ -245,8 +285,12 @@ def r(t, current_path): if all(len(x) == 1 for x in irobj_path.values()): # if all IRObjects are leafs, we can directly assign the output outputs = [self.tensor_name(t) for t in node.outputs()] - outputs = ', '.join(outputs) - codes.append(f'{outputs} = {body}') + outputs_str = ', '.join(outputs) + codes.append(f'{outputs_str} = {body}') + _insert_hook( + outputs=outputs_str if len(node.outputs()) == 1 else _to_tuple_str(outputs), + is_pre=False + ) else: outputs = [] im_outputs = [] @@ -258,7 +302,12 @@ def r(t, current_path): im_ouptut = self.tensor_name(IRObject('im_output')) im_outputs.append(im_ouptut) outputs.append(im_ouptut) - codes.append(f'{", ".join(outputs)} = {body}') + outputs_str = ', '.join(outputs) + codes.append(f'{outputs_str} = {body}') + _insert_hook( + outputs=outputs_str if len(node.outputs()) == 1 else _to_tuple_str(outputs), + is_pre=False + ) for t, path in irobj_path.items(): if len(path) == 1: # immediate output, skip diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 32cd75fe..c3bc8c34 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -8,8 +8,9 @@ import torch import numpy as np import inspect +import pickle -from nnscaler.ir.cten import IRCell +from nnscaler.ir.cten import IRCell, IRTensor from nnscaler.ir.tensor import IRFullTensor, IRSubTensor from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from nnscaler.ir.adapter import IRWeightReducer, IRAdapter @@ -23,7 +24,7 @@ from nnscaler.execplan.execplan import ExeReuseCell from nnscaler.codegen.syntax.symtable import SymbolTable -from nnscaler.codegen.syntax.blocks import ClassBlock, FunctionBlock, Block +from nnscaler.codegen.syntax.blocks import ClassBlock, ForBlock, FunctionBlock, Block from nnscaler.codegen.emit import FuncEmission from nnscaler.codegen.module.autograd import AutogradAdapterCodeGen @@ -126,6 +127,7 @@ def __init__( 'from pathlib import Path', 'import torch', 'import torch.utils.checkpoint as ckpt', 'import nnscaler', 'import nnscaler.flags', + 'import nnscaler.runtime.function', 'import _operator', 'from numpy import inf', 'import builtins', '', f'runtime_version = {runtime_version!r}', '', '' ] @@ -138,6 +140,18 @@ def __init__( # self.init_code.append('@torch.jit.script') self.init_code.append(op_impl) self.init_code += [''] + + # hooks + hook_imports = set() + for node in execplan.graph.select(ntype=IRFwOperation): + if node.pre_hook is not None: + hook_imports.add(inspect.getmodule(node.pre_hook).__name__) + if node.post_hook is not None: + hook_imports.add(inspect.getmodule(node.post_hook).__name__) + for modname in hook_imports: + self.init_code.append(f'import {modname}') + self.init_code += [''] + # module init code self.model_init_statements: List[str] = list() # module method bodies for forward computations, e.g. Segments, Adapters. @@ -317,7 +331,8 @@ def gen( *, as_parallel_module: bool = False, end2end_mode: bool = False, - forward_args: Optional[Dict[str, Any]] = None + forward_args: Optional[Dict[str, Any]] = None, + outfile_attr_meta_map: Optional[str] = None, ) -> str: """ Generate model implementation code based on the given graph. @@ -406,6 +421,7 @@ def forward(self, x, y=None, z=None): This is used only in parallel module. forward_args (Dict[str, Any]): argument names and their default values of forward function, if None, use node inputs. This is used only in parallel module. + outfile_attr_meta_map (str): output file path for parameter mapping. None if don't save Returns: generated code @@ -451,6 +467,7 @@ def forward(self, x, y=None, z=None): if k not in param_first_used_pos: param_first_used_pos[k] = (i, v) + attr_meta_map = {} # emit code for node in sequence: if isinstance(node, IRSegment): @@ -472,7 +489,7 @@ def forward(self, x, y=None, z=None): # emit node tensor declaration into `__init__` # typically it's about the `nn.Parameter` - self.init_attributes(node) + attr_meta_map.update(self.init_attributes(node)) # emit node code # codes : List[str] @@ -483,11 +500,15 @@ def forward(self, x, y=None, z=None): for t in node.inputs(): if isinstance(t, IRSubTensor): if not t.is_attr(): - args.append(self.tensor_name(t)) + args.append(self.tensor_name(t, strip_star=False)) else: - args.append(self.tensor_name(t)) + args.append(self.tensor_name(t, strip_star=False)) node_args.append(args) + if outfile_attr_meta_map: + with open(outfile_attr_meta_map, 'wb') as f: + pickle.dump(attr_meta_map, f) + # generate full code with ClassBlock( class_name='GenModel', @@ -499,6 +520,7 @@ def forward(self, x, y=None, z=None): if as_parallel_module: cb.insert_body(f'rank = {device}') # save rank in class level + cb.insert_body(f'world_size = {self.runtime_ndevs}') # save world size in class level # async_op, max_bucket_size_bytes and zero_use_reduce_scatter # parameters are for testing purpose # and will not expose to user @@ -506,15 +528,17 @@ def forward(self, x, y=None, z=None): args=[ 'self', 'init_params=True', - '*', + 'build_buckets=True', + '*args', f'async_op={CompileFlag.async_reducer}', f'max_bucket_size_bytes={CompileFlag.max_reducer_bucket}', f'zero_use_reduce_scatter={CompileFlag.zero_use_reduce_scatter}', + f'**kwargs', ] ) as ib: ib.insert_body(self.model_init_statements) ib.insert_body('') - ib.insert_body('self._post_init(init_params)') + ib.insert_body('self._post_init(init_params, build_buckets)') else: with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.model_init_statements) @@ -528,7 +552,10 @@ def forward(self, x, y=None, z=None): if isinstance(node, IRSegment): segment_idxs.append(idx) - with FunctionBlock(func_name=name, args=input_args) as fb: + saved_tensors_hooks_needed = isinstance(node, IRSegment) and CompileFlag.use_zero > 1 + func_name = name + '_impl' if saved_tensors_hooks_needed else name + + with FunctionBlock(func_name=func_name, args=input_args) as fb: fb.insert_body(forward_code) # generate output outputs = [self.tensor_name(t) for t in node.outputs()] @@ -541,6 +568,16 @@ def forward(self, x, y=None, z=None): cb.insert_body('@torch.jit.script_method') cb.insert_body(fb.code) + if saved_tensors_hooks_needed: + with FunctionBlock(func_name=name, args=input_args) as fb: + # call segment under save_params_hooks context + save_context_code = f'with self.save_params_hooks():' + with Block(save_context_code) as cblock: + cblock.insert_body(f'return self.{func_name}({", ".join(node_args[idx])})') + fb.insert_body(cblock.code) + cb.insert_body('') + cb.insert_body(fb.code) + if as_parallel_module: if not segment_idxs: raise RuntimeError("The graph has no segment, forward code cannot be generated.") @@ -627,8 +664,11 @@ def _get_resolved_arg(arg_name, default_value): outputs = self.return_name(node.outputs(), skip_attr=True) call_code = f'{outputs} = self.{self.node_name(node)}({", ".join(inputs)})' # be sure the user doesn't specify unused args. - for unused_arg in unused_args: - fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') + # but sometimes this can cause issues + # (for example, the value is used in an `if` condition in the original forward function), + # so we disable it for now. + # for unused_arg in unused_args: + # fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') fb.insert_body(call_code) return_code = f'return {self.return_name_complex(self.execplan.graph.outputs())}' fb.insert_body(return_code) @@ -644,6 +684,11 @@ def emit_comm_groups(self): - `model_init_statements` """ sign = 'self.init_group(ranks={ranks})' + # create single rank communication group + self.model_init_statements.append('# single rank communication groups') + with ForBlock(var='rank', iters=f'range({self.runtime_ndevs})') as fb: + fb.insert_body(sign.format(ranks='[rank]')) + self.model_init_statements.extend(fb.code) # create communication group self.model_init_statements.append('# communication groups') for ranks in self.comm_groups: @@ -651,7 +696,7 @@ def emit_comm_groups(self): self.model_init_statements.append(code) self.model_init_statements.append(' ') - def init_attributes(self, node: IRCell): + def init_attributes(self, node: IRCell) -> dict[str, dict[str, Any]]: """ Emit tensor declaration code @@ -660,10 +705,18 @@ def init_attributes(self, node: IRCell): This method also populates `self.symbols : SymbolTable` to record the names of the variables for the tensors ever encountered. + + Returns: + dict[str, dict[str, Any]]: A mapping of tensor names to their attributes. """ + attr_meta_map = {} + self._init_attributes(node, attr_meta_map) + return attr_meta_map + + def _init_attributes(self, node: IRCell, attr_meta_map: Dict[str, Any]): psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}), persistent={persistent})" - map_sign = "self.add_full_map('{attr}', {tid}, {is_param}, '{orig_name}', {full_shape}, {slicers}, {val_chunks})" + map_sign = "self.add_full_map('{attr}', {tid}, {is_param}, '{orig_name}', {shape}, {slicers}, {val_chunks})" if not isinstance(node, IRSegment): for itensor in node.inputs(): name = self.tensor_name(itensor, prefix_attr='self.') @@ -691,14 +744,24 @@ def init_attributes(self, node: IRCell): assert len(slicers) == 1 and slicers[0] == slice(0, 1), f"Unexpected slicers {slicers} for scalar tensor." slicers = '...' # Ellipsis slicer for scalar tensor, x[...] is equivalent to x val_chunks = itensor.valmap[1] - code = map_sign.format( - attr=self.tensor_name(itensor), + attr_name = self.tensor_name(itensor) + attr_props = dict( tid=itensor.parent.tid, is_param=itensor.is_param(), orig_name=itensor.parent.name, - full_shape=tuple(itensor.parent.origin_shape), - slicers=str(slicers), - val_chunks=val_chunks + shape=tuple(itensor.parent.origin_shape), # full tensor shape + slicers=slicers, + val_chunks=val_chunks, + ) + attr_meta_map[attr_name] = dict( + **attr_props, + dtype=itensor.dtype, + sub_shape=tuple(itensor.shape) + ) + + code = map_sign.format( + attr=attr_name, + **attr_props ) self.model_init_statements.append(code) self.model_init_statements.append('') @@ -710,7 +773,7 @@ def init_attributes(self, node: IRCell): self.symbols.create(self.tensor_name(output, prefix_attr='self.')) else: for sub_node in node.nodes(): - self.init_attributes(sub_node) + self._init_attributes(sub_node, attr_meta_map) return def init_reducer(self, @@ -874,12 +937,64 @@ def emit_context_manager(node: IRCell): code = "with " + ", ".join(ctx_managers) + ":" return code - def emit_node(node): + def emit_node(node, node_idx): node_code = [] # execute if isinstance(node, IRFwOperation): + param_inputs = [ + self.tensor_name(t, prefix_attr='self.') for t in node.iobjs() + if isinstance(t, IRTensor) and t.is_param() + ] + + # for multiref node under zero3, we need to clone the params to avoid in-place modification issue + if param_inputs and CompileFlag.use_zero > 1 and node.name == 'multiref': + _logger.warning(f'Node {node} is a multiref node with param inputs under ZeRO-3, ' + f'we set clone_level=1 to avoid in-place modification issue.') + node.kwargs['clone_level'] = 1 + code = self.emit_fnode(node, runtime_devid=runtime_devid, plan_ndevs=len(self.devices), runtime_ndevs=self.runtime_ndevs, prefix_attr='self.') - node_code += code + + if not param_inputs or CompileFlag.use_zero <= 1: + node_code += code + else: + activation_inputs = [ + self.tensor_name(t) for t in node.iobjs() + if isinstance(t, IRTensor) and not t.is_attr() and t.requires_grad + ] + activation_outputs = [ + self.tensor_name(t) for t in node.oobjs() + if isinstance(t, IRTensor) and t.requires_grad + ] + + # insert param prefetch before each fnode for zero3 + for t in param_inputs: + node_code.append(f'self.prefetch_param({t})') + # The backward hook here is not reliable, + # 1. there can be no activation input requiring grad, + # 2. some inputs may not be used. + # so, to maximize the chance of triggering backward hook + # let's hook to every input requiring grad + # We also add evict logic in AccumulateGrad hook in bucket implementation, + # which can make sure params are evicted after backward use. + for q in activation_inputs: + node_code.append(f'{q} = self.backward_postevict_param({q}, {t}, {node_idx})') + + node_code += code + + # insert zero param release after each fnode + for t in param_inputs: + node_code.append(f'self.postevict_param({t})') + + # insert backward hook for activation outputs to fetch params in backward + for t in activation_outputs: + # we don't know which activation output will be used in backward + # (DCE may not work 100% correctly), + # so we add hook to all activation outputs for all input params + for p in param_inputs: + node_code.append( + f'{t} = self.backward_prefetch_param({t}, {p}, {node_idx})' + ) + elif isinstance(node, IRAdapter): # for adapters inside an IRSegment, we don't apply async communication to it # as it is mostly in critical path. @@ -905,15 +1020,15 @@ def insert_codes_under_ctx(ctx_code, codes): node_codes = [] current_context_manager_code = "" current_codes = [] - for node in nodes: + for node_idx, node in enumerate(nodes): if has_op_context_info(node): new_context_manager_code = emit_context_manager(node) if current_context_manager_code != new_context_manager_code: node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) - current_codes = emit_node(node) + current_codes = emit_node(node, node_idx) current_context_manager_code = new_context_manager_code else: - current_codes.extend(emit_node(node)) + current_codes.extend(emit_node(node, node_idx)) else: # Node without op context infortmation means it is inserted by nnscaler, not convert from original fx graph, # for example, multiref node and adapter node, currently for nodes inserted by nnscaler we have the following assumption: @@ -967,7 +1082,7 @@ def insert_codes_under_ctx(ctx_code, codes): # # TODO: all inserted nodes should have its op context field. node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) - node_codes += emit_node(node) + node_codes += emit_node(node, node_idx) current_codes = [] node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) diff --git a/nnscaler/customized_ops/__init__.py b/nnscaler/customized_ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nnscaler/customized_ops/ring_attention/README.md b/nnscaler/customized_ops/ring_attention/README.md new file mode 100644 index 00000000..38fbec5f --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/README.md @@ -0,0 +1,219 @@ +# Ring Attention Implementation + +High-performance ring attention mechanisms for nnscaler, supporting multiple attention variants and distributed training. + +## 📖 Overview + +This module implements multiple efficient attention mechanisms designed to distribute computation evenly in long sequence processing: + +- **Ring Attention**: Standard ring attention supporting arbitrary sequence lengths +- **Ring Attention Variable Length**: Variable-length sequence optimized ring attention +- **Zigzag Attention**: Zigzag pattern ring attention optimized for causal attention + +All implementations are deeply integrated with nnscaler's parallel computing framework, supporting automatic distributed training. + +## 🏗️ Architecture Design + +``` +nnscaler/customized_ops/ring_attention/ +├── __init__.py # Package import interface +├── ring_attn.py # Standard ring attention +├── ring_attn_varlen.py # Variable length ring attention +├── zigzag_attn.py # Zigzag ring attention +├── varlen_utils.py # Variable length utility functions +└── core/ # Core implementations + ├── ring_attn_implementation.py # Standard ring attention core + ├── ring_attn_varlen_implementation.py # Variable length core implementation + ├── zigzag_attn_implementation.py # Zigzag attention core implementation + └── utils.py # Common utility functions +``` + +## 🚀 Quick Start + +### Standard Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + +# Basic usage +output = wrap_ring_attn_func( + q, # [batch_size, seq_len, num_heads, head_dim] + k, # [batch_size, seq_len, num_heads, head_dim] + v, # [batch_size, seq_len, num_heads, head_dim] + causal=True, # Causal attention mask + window_size=(-1, -1), # Sliding window size, (-1,-1) means global attention + softmax_scale=None, # Softmax scale factor, defaults to 1/sqrt(head_dim) + dropout_p=0.0 # Dropout probability +) +``` + +### Variable Length Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func + +# Variable length sequence attention +output = wrap_ring_attn_varlen_func( + q, # [total_tokens, num_heads, head_dim] + k, # [total_tokens, num_heads, head_dim] + v, # [total_tokens, num_heads, head_dim] + cu_seqlens_q, # Cumulative sequence lengths [batch_size + 1] + cu_seqlens_k, # Cumulative sequence lengths [batch_size + 1] + bias=None, # Optional attention bias + causal=True, # Causal attention mask + window_size=(-1, -1), # Sliding window size + softmax_scale=None, # Softmax scale factor + dropout_p=0.0 # Dropout probability +) +``` + +### Zigzag Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_zigzag_attn_func + +# Zigzag attention (causal attention only) +output = wrap_zigzag_attn_func( + q, # [batch_size, seq_len, num_heads, head_dim] + k, # [batch_size, seq_len, num_heads, head_dim] + v, # [batch_size, seq_len, num_heads, head_dim] + causal=True, # Must be True + window_size=(-1, -1), # Must be (-1, -1), sliding window not supported + softmax_scale=None, + dropout_p=0.0 +) +``` + +## 🔧 Core Features + +### Performance Optimization +- **Flash Attention integration**: Efficient implementation based on flash_attn +- **TransformerEngine support**: Automatic detection and usage of TE 2.2.0+ +- **CUDA kernel optimization**: GPU-optimized low-level implementations +- **Distributed friendly**: Seamless integration with torch.distributed + +### Flexible Configuration +- **Attention patterns**: Support for causal and non-causal attention +- **Sliding window**: Configurable local attention windows +- **GQA support**: Grouped Query Attention optimization +- **Custom scaling**: Flexible softmax scaling strategies + +## 🧮 Algorithm Principles + +### Ring Attention Mechanism + +Ring Attention decomposes attention computation into multiple blocks: + +1. **Sequence chunking**: Divide long sequences into blocks distributed across devices +2. **Ring communication**: Devices pass key/value blocks by all-gather and reduce-scatter +3. **Incremental computation**: Each device computes attention with received key/value blocks + +### Variable Length Optimization + +Special optimizations for variable length sequences: + +```python +# Cumulative sequence length example +cu_seqlens = [0, 128, 256, 512] # 3 sequences with lengths 128, 128, 256 +# Corresponding token tensor shape: [512, num_heads, head_dim] +``` + +### Zigzag Pattern + +Zigzag Attention uses a special communication pattern for higher efficiency in causal attention scenarios: + +- **Causal constraint**: Only supports causal=True cases +- **Optimized communication**: Ring communication optimized for causal masks +- **Memory friendly**: Further reduces unnecessary computation and communication + +## 🔗 nnscaler Integration + +### Automatic Parallelization + +```python +from nnscaler.parallel import parallelize, ComputeConfig +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + +class AttentionModel(torch.nn.Module): + def forward(self, q, k, v): + return wrap_ring_attn_func(q, k, v, causal=True) + +# nnscaler automatically handles distribution +config = ComputeConfig( + plan_ngpus=4, + runtime_ngpus=4 +) +parallel_model = parallelize(model, config=config) +``` + +### Computation Graph Optimization + +nnscaler automatically provides: +- **Communication optimization**: Minimize inter-device communication overhead +- **Memory planning**: Optimize memory usage patterns +- **Operator fusion**: Fuse with other operators for optimization +- **Gradient synchronization**: Automatic gradient communication in backward pass + +## 🧪 Testing Framework + +Comprehensive test coverage ensures implementation correctness and performance: + +```bash +# Run all attention tests +pytest tests/customized_ops/ring_attn/ -v + +# Specific attention variant tests +pytest tests/customized_ops/ring_attn/test_ring_attn.py -v +pytest tests/customized_ops/ring_attn/test_ring_attn_varlen.py -v +pytest tests/customized_ops/ring_attn/test_zigzag_attn.py -v +``` + +### Test Types + +- **Correctness tests**: Compare outputs with standard attention +- **Multi-GPU scalability**: Behavior validation across different device counts +- **GQA compatibility**: Grouped Query Attention correctness +- **Sliding window**: Local attention pattern validation +- **Edge cases**: Stability testing under extreme conditions + +## 🛠️ Development Guide + +### Adding New Attention Variants + +1. **Core implementation**: Add implementation file in `core/` directory +2. **Wrapper function**: Create corresponding wrap function +3. **Test coverage**: Add comprehensive test cases +4. **Documentation**: Update README and API documentation + +### Performance Optimization Tips + +- **TransformerEngine**: Install TE 2.2.0+ for optimal performance +- **CUDA version**: Use CUDA 11.8+ for latest optimizations +- **Memory configuration**: Adjust batch size and sequence length based on GPU memory +- **Communication optimization**: Use InfiniBand networks to reduce communication latency + +## 🚨 Known Limitations + +### Ring Attention +- **alibi_slopes**: ALiBi positional encoding not currently supported +- **return_attn_probs**: Returning attention weights not supported + +### Zigzag Attention +- **causal**: Only supports causal attention (causal=True) +- **window_size**: Sliding window not supported (must be (-1,-1)) + +### General Limitations +- **Dynamic shapes**: Sequence length cannot change dynamically during training +- **Mixed precision**: May require special handling in certain configurations + +## 📚 References + +- **Ring Attention Paper**: [Ring Attention with Blockwise Transformers](https://arxiv.org/abs/2310.01889) +- **Flash Attention**: [FlashAttention: Fast and Memory-Efficient Exact Attention](https://arxiv.org/abs/2205.14135) +- **Llama3 Paper**: [The Llama3 Herd of Models](https://arxiv.org/pdf/2407.21783) +- **nnscaler Documentation**: [nnscaler Parallel Computing Framework](https://github.com/microsoft/nnscaler) +- **TransformerEngine**: [NVIDIA TransformerEngine](https://github.com/NVIDIA/TransformerEngine) + +--- + +**Note**: This implementation is optimized for large-scale distributed training. For single-GPU scenarios, standard Flash Attention is recommended for optimal performance. \ No newline at end of file diff --git a/nnscaler/customized_ops/ring_attention/__init__.py b/nnscaler/customized_ops/ring_attention/__init__.py new file mode 100644 index 00000000..e54f5bc1 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .ring_attn_varlen import wrap_ring_attn_varlen_func + +from .zigzag_attn import wrap_zigzag_attn_func + +from .ring_attn import wrap_ring_attn_func \ No newline at end of file diff --git a/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py b/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py new file mode 100644 index 00000000..39a3885d --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py @@ -0,0 +1,326 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from .utils import shuffle_input, recover_output, GlobalMemoryBuffer, get_default_args, all_gather, reduce_scatter + + +_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + def forward(q, k, v, causal): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + up_q = q[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + + up_out, up_lse = forward(up_q, up_k, up_v, causal) + + down_q = q[:, block_len:] + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + down_out, down_lse = forward(down_q, down_k, down_v, causal) + + out = torch.cat([up_out, down_out], dim=1) + return out, up_lse, down_lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + up_lse, + down_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): # pragma: no cover + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + dq = torch.zeros_like(q) + dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_dk") + dk_buffer.zero_() + dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_dv") + dv_buffer.zero_() + + up_q = q[:, :block_len] + up_out = out[:, :block_len] + up_dout = dout[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + up_dk = dk_buffer[:, :(up_rank + 1) * block_len] + up_dv = dv_buffer[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + up_dk, up_dv = dk_buffer, dv_buffer + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": up_dout, + "q": up_q, + "k": up_k, + "v": up_v, + "out": up_out, + "softmax_lse": up_lse, + "dq": dq[:, :block_len], + "dk": up_dk, + "dv": up_dv, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + + down_q = q[:, block_len:] + down_out = out[:, block_len:] + down_dout = dout[:, block_len:] + # TODO: optimize the buffer allocation + down_dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_down_dk") + down_dk_buffer.zero_() + down_dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_down_dv") + down_dv_buffer.zero_() + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + down_dk = down_dk_buffer[:, :(down_rank + 1) * block_len] + down_dv = down_dv_buffer[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + down_dk, down_dv = down_dk_buffer, down_dv_buffer + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": down_dout, + "q": down_q, + "k": down_k, + "v": down_v, + "out": down_out, + "softmax_lse": down_lse, + "dq": dq[:, block_len:], + "dk": down_dk, + "dv": down_dv, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + dk_buffer.add_(down_dk_buffer) + dv_buffer.add_(down_dv_buffer) + + bsz = q.size(0) + if bsz == 1: + dim_size = list(k.size()) + dim_size[1] = dim_size[1] // world_size + dk = torch.empty(dim_size, dtype=k.dtype, device=k.device) + dv = torch.empty(dim_size, dtype=v.dtype, device=v.device) + dist.reduce_scatter_tensor(dk, dk_buffer, group=process_group) + dist.reduce_scatter_tensor(dv, dv_buffer, group=process_group) + else: + dk = reduce_scatter(dk_buffer, dim=1, process_group=process_group) + dv = reduce_scatter(dv_buffer, dim=1, process_group=process_group) + + return dq, dk, dv + + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, all gather k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, reduce scatter dk, dv +''' +class RingFlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + bsz = q.size(0) + q = shuffle_input(to_send=q, process_group=group) + k = k.contiguous() + v = v.contiguous() + world_size = dist.get_world_size(group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + if bsz == 1: + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + # torch.distributed._all_gather_base function requires that the k and v tensors are contiguous. + torch.distributed.all_gather_into_tensor(k_buffer, k, group=group) + torch.distributed.all_gather_into_tensor(v_buffer, v, group=group) + else: + k_buffer = all_gather(k, dim=1, process_group=group) + v_buffer = all_gather(v, dim=1, process_group=group) + + out, up_lse, down_lse = ring_flash_attn_forward( + group, + q, + k_buffer, + v_buffer, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, up_lse, down_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_output(out, process_group=group) + return out + + @staticmethod + def backward(ctx, dout, *args): # pragma: no cover + dout = shuffle_input(to_send=dout, process_group=ctx.group) + q, k, v, out, up_lse, down_lse = ctx.saved_tensors + bsz = q.size(0) + world_size = dist.get_world_size(ctx.group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + if bsz == 1: + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed.all_gather_into_tensor(k_buffer, k, group=ctx.group) + torch.distributed.all_gather_into_tensor(v_buffer, v, group=ctx.group) + else: + k_buffer = all_gather(k, dim=1, process_group=ctx.group) + v_buffer = all_gather(v, dim=1, process_group=ctx.group) + + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k_buffer, + v_buffer, + out, + up_lse, + down_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + dq = recover_output(dq, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None diff --git a/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py b/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py new file mode 100644 index 00000000..709ffe09 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py @@ -0,0 +1,516 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: Most of code is copied from project https://github.com/zhuzilin/ring-flash-attention + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, +) +from .utils import get_default_args, AllGatherComm as Comm + + +def llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens: torch.Tensor, causal: bool, rank: int, world_size: int +): + """ + Args: + cu_seqlens: torch.Tensor, the cu_seqlens of all the sequences across the ring process group. + + Returns: + cu_seqlens_q: torch.Tensor, the cu_seqlens of the q slice for this rank. + cu_seqlens_k: torch.Tensor, the cu_seqlens of the k slice that the local q need. Note + that this may be longer than `total_seq_len // world_size`. + local_k_slice: slice, the slice of the k that the local q need. Note + that this may be longer than `total_seq_len // world_size`. + """ + total_length = cu_seqlens[-1].item() + assert total_length % world_size == 0, cu_seqlens + length_per_rank = total_length // world_size + left = torch.searchsorted(cu_seqlens, rank * length_per_rank) + right = torch.searchsorted(cu_seqlens, (rank + 1) * length_per_rank) + + # after this, cu_seqlens[left:right + 1] contains all the sequence for this rank + if cu_seqlens[left] != rank * length_per_rank: + left -= 1 + left = left.item() + right = right.item() + + # q is always the same. just calculate the cu_seqlens for the local slice + cu_seqlens_q = cu_seqlens[left : right + 1].clone() + cu_seqlens_q -= rank * length_per_rank + cu_seqlens_q[0] = 0 + cu_seqlens_q[-1] = length_per_rank + + cu_seqlens_k = cu_seqlens[left : right + 1].clone() + if causal: + # when causal, we hope + # - the last k seq is of the same length as the last q seq + slice_right = (rank + 1) * length_per_rank + cu_seqlens_k[-1] = slice_right + else: + # when not causal, we hope + # - the last k is full seq + slice_right = cu_seqlens[right].item() + + slice_left = cu_seqlens[left].item() + cu_seqlens_k -= slice_left + + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + local_k_slice = slice(slice_left, slice_right) + return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, local_k_slice + + +def llama3_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + out_list = [] + lse_list = [] + + nheads = q.shape[1] + total_k, nheads_k, head_dim = k.shape + assert nheads_k % heads_k_stride == 0 + + world_size = dist.get_world_size(process_group) + kv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + kv_buffer_copy = torch.empty_like(kv_buffer) + + k_0 = k[:, :heads_k_stride].contiguous() + v_0 = v[:, :heads_k_stride].contiguous() + comm = Comm(process_group) + + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + for i in range(0, nheads_k, heads_k_stride): + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + if i < nheads_k - heads_k_stride: + # all_gather the next kv slice + kv_slice_left = i + heads_k_stride + kv_slice_right = kv_slice_left + heads_k_stride + send_k = k[:, kv_slice_left:kv_slice_right].contiguous() + send_v = v[:, kv_slice_left:kv_slice_right].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + q_i = q[:, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + k_i = kv_buffer[0][local_k_slice] + v_i = kv_buffer[1][local_k_slice] + if alibi_slopes is None: + cur_alibi_slopes = None + else: + cur_alibi_slopes = alibi_slopes[i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + + params = get_default_args(_flash_attn_varlen_forward).copy() + params.update( + { + "q": q_i, + "k": k_i, + "v": v_i, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": cur_alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_varlen_forward(**params) + if len(outputs) == 8: + out, _, _, _, _, lse, _, _ = outputs + else: + assert len(outputs) == 4 + out, lse, _, _ = outputs + out_list.append(out) + lse_list.append(lse) + + out = torch.cat(out_list, dim=1) + lse = torch.cat(lse_list, dim=-2) + return out, lse + + +def llama3_flash_attn_varlen_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): # pragma: no cover + nheads = q.shape[1] + total_k, nheads_k, head_dim = k.shape + assert nheads_k % heads_k_stride == 0 + + world_size = dist.get_world_size(process_group) + kv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + kv_buffer_copy = torch.empty_like(kv_buffer) + + dkv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + if heads_k_stride != nheads_k: + kv_contiguous_buffer = torch.empty( + (2, total_k, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + comm = Comm(process_group) + + k_0 = k[:, :heads_k_stride].contiguous() + v_0 = v[:, :heads_k_stride].contiguous() + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + for i in range(0, nheads_k, heads_k_stride): + dkv_buffer.zero_() + + q_slice = slice( + i * nheads // nheads_k, (i + heads_k_stride) * nheads // nheads_k + ) + q_i = q[:, q_slice] + dout_i = dout[:, q_slice] + out_i = out[:, q_slice] + dq_i = dq[:, q_slice] + if softmax_lse.dim() == 3: + lse_i = softmax_lse[:, q_slice].contiguous() + else: + lse_i = softmax_lse[q_slice] + + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + if i < nheads_k - heads_k_stride: + # all_gather the next kv slice + kv_slice_left = i + heads_k_stride + kv_slice_right = kv_slice_left + heads_k_stride + send_k = k[:, kv_slice_left:kv_slice_right].contiguous() + send_v = v[:, kv_slice_left:kv_slice_right].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + k_i = kv_buffer[0][local_k_slice] + v_i = kv_buffer[1][local_k_slice] + dk_i = dkv_buffer[0][local_k_slice] + dv_i = dkv_buffer[1][local_k_slice] + + if alibi_slopes is None: + cur_alibi_slopes = None + else: + cur_alibi_slopes = alibi_slopes[i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + + params = get_default_args(_flash_attn_varlen_backward).copy() + params.update( + { + "dout": dout_i, + "q": q_i, + "k": k_i, + "v": v_i, + "out": out_i, + "softmax_lse": lse_i, + "dq": dq_i, + "dk": dk_i, + "dv": dv_i, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": cur_alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_varlen_backward(**params) + + if heads_k_stride != nheads_k: + # reduce_scatter needs contiguous buffer + dk_i = kv_contiguous_buffer[0] + dv_i = kv_contiguous_buffer[1] + else: + dk_i = dk + dv_i = dv + + dist.reduce_scatter_tensor(dk_i, dkv_buffer[0], group=process_group) + dist.reduce_scatter_tensor(dv_i, dkv_buffer[1], group=process_group) + + if heads_k_stride != nheads_k: + dk[:, i : i + heads_k_stride] = dk_i + dv[:, i : i + heads_k_stride] = dv_i + + return dq, dk, dv + + +class Llama3FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = llama3_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.heads_k_stride = heads_k_stride + ctx.local_k_slice = local_k_slice + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): # pragma: no cover + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = llama3_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.heads_k_stride, + ctx.local_k_slice, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return (dq, dk, dv) + (None,) * 15 + + +def llama3_flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def llama3_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def llama3_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/nnscaler/customized_ops/ring_attention/core/utils.py b/nnscaler/customized_ops/ring_attention/core/utils.py new file mode 100644 index 00000000..6345b35b --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/core/utils.py @@ -0,0 +1,343 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention + +from typing import Optional, Tuple +from functools import reduce +import operator +import inspect +from functools import cache +import random + +import torch +import torch.distributed as dist + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def gen_head_anno(query_states, key_states, value_states, head_pos=2): + if query_states.shape[head_pos] != key_states.shape[head_pos]: + assert query_states.shape[head_pos] % key_states.shape[head_pos] == 0 + group_size = query_states.shape[head_pos] // key_states.shape[head_pos] + assert query_states.shape[head_pos] == value_states.shape[head_pos] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + return q_anno, kv_anno + + +# copied from project https://github.com/zhuzilin/ring-flash-attention +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +class AllGatherComm: + def __init__(self, group=None) -> None: + self.group = group + self.handles = [] + + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): + handle = dist.all_gather_into_tensor( + output_tensor, input_tensor, group=self.group, async_op=True + ) + self.handles.append(handle) + + def wait(self): + for handle in self.handles: + handle.wait() + self.handles = [] + + +# copy from megatron/core/utils.py +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if ( + self.buffer.get((name, dtype), None) is None + or self.buffer[(name, dtype)].numel() < required_len + ): + self.buffer[(name, dtype)] = torch.empty( + required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + @torch.jit.script + def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + + out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + + lse = new_lse + return out, lse + + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + parts = self.world_size // 2 + self.ring_list = [] + for i in range(parts): + self.ring_list.extend([i, self.world_size - i - 1]) + + self.revert_rank = self.ring_list.index(self.rank) + + offset = ((dist.get_rank() // self.world_size) * self.world_size) + self.send_rank = self.ring_list[(self.revert_rank + 1) % self.world_size] + offset + self.recv_rank = self.ring_list[(self.revert_rank - 1) % self.world_size] + offset + + def send_recv( + self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group + ) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + + +def shuffle_input(to_send: torch.Tensor, + process_group: dist.ProcessGroup = None): + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. + to_send_f = torch.zeros_like(to_send) + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + block_seq_len = to_send.shape[1] // 2 + to_send_slice = to_send[:, block_seq_len:].contiguous() + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[:, block_seq_len:] = to_send[:, :block_seq_len] + to_send_f[:, :block_seq_len, ...] = res + else: # A: 0 1, -> 0 7 + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len] + to_send_f[:, block_seq_len:, ...] = res + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + + return to_send_f + + +def recover_output(to_send: torch.Tensor, + process_group: dist.ProcessGroup = None): + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + to_send_f = torch.zeros_like(to_send) + + block_seq_len = to_send.shape[1] // 2 + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if rank >= world_size // 2: + to_send_slice = to_send[:, :block_seq_len, ...].contiguous() + else: + to_send_slice = to_send[:, block_seq_len:, ...].contiguous() + res = torch.zeros_like(to_send_slice) + + assert to_send_slice.is_contiguous() + assert res.is_contiguous() + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: + to_send_f[:, :block_seq_len] = to_send[:, block_seq_len:, ...] + to_send_f[:, block_seq_len:] = res + else: + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len, ...] + to_send_f[:, block_seq_len:] = res + + return to_send_f.contiguous() + + +def all_gather(tensor: torch.Tensor, dim: int, process_group: dist.ProcessGroup): + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + tensor_list[torch.distributed.get_rank(process_group)] = tensor.data + torch.distributed.all_gather(tensor_list, tensor, group=process_group) + otensor = torch.concat(tuple(tensor_list), dim=dim) + return otensor + + +def reduce_scatter(tensor: torch.Tensor, dim: int, process_group: dist.ProcessGroup): + world_size = dist.get_world_size(process_group) + itensors = list(tensor.chunk(world_size, dim)) + for idx, t in enumerate(itensors): + itensors[idx] = t.contiguous() if not t.is_contiguous() else t + otensor = torch.empty_like(itensors[0], requires_grad=False) + torch.distributed.reduce_scatter(otensor, itensors, group=process_group) + return otensor diff --git a/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py b/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py new file mode 100644 index 00000000..7a643d59 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py @@ -0,0 +1,516 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from .utils import RingComm, update_out_and_lse, shuffle_input, recover_output, get_default_args + +''' +Assume we have 4 GPUs A, B, C, D. +The sequence is represented as [0 1 2 3 4 5 6 7]. + +The P2P communication ring is A -> D -> B -> C -> A +The initial status of the attention computation is +X +X X +X X X +X X X X +X X X X X +X X X X X X +X X X X X X X +X X X X X X X X +Note: +- the computation in the diagonal is `causal=True` +- the computation in the off-diagonal is `causal=False` +We consider a `X` with `causal=True` as a unit computation block. +In this example, there are 4 steps. Each device is responsible for 2 unit computation blocks in each step. + +q status is same across all steps (q is not transmitted): +GPU A: [0 7] +GPU B: [2 5] +GPU C: [3 4] +GPU D: [1 6] + +Step 0, kv status: +GPU A: [0 7] +GPU B: [2 5] +GPU C: [3 4] +GPU D: [1 6] +Computation status: +A +X D +X X B +X X X C +X X X C C +X X B X X B +X D X X X X D +A X X X X X X A + +Step 1, kv status: +GPU A: [3 4] +GPU B: [1 6] +GPU C: [2 5] +GPU D: [0 7] +Computation status: +X +D X +X B X +X X C X +X X C X X +X B X X X X +D X X X X X X +X X X A A X X X + +Step 2, kv status: +GPU A: [2 5] +GPU B: [0 7] +GPU C: [1 6] +GPU D: [3 4] +Computation status: +X +X X +B X X +X C X X +X C X X X +B X X X X X +X X X D D X X +X X A X X A X X + +Step 3, kv status: +GPU A: [1 6] +GPU B: [3 4] +GPU C: [0 7] +GPU D: [2 5] +Computation status: +X +X X +X X X +C X X X +C X X X X +X X X B B X +X X D X X D X +X A X X X X A X + +From this example, we can conclude the key insight of zigzag ring flash attention is: +- split the sequence into fine-grained blocks to achieve balance across steps and gpus +- schedule the computation in a zigzag pattern to minimize the communication overhead + +To be more specific, if the sequence length is L=4n, the total computation cost of flash attention +with causal=True is 1/2 L^2 = 8n^2. Each device needs to compute 4n. Each step needs to compute 2. + +Computation task assigned for each GPU: + +GPU 0: (0, 4n-1) +GPU 1: (2, 4n-3) +... +GPU n-1: (2n-2, 2n+1) +GPU n: (2n-1, 2n) +GPU n+1: (2n-3, 2n+2) +... +GPU 2n-1: (1, 4n-2) + +Dependence of kv (required kv range) for each device: +GPU 0: [0, 4n-1] +GPU 1: [0, 4n-3] +... +GPU n-1: [0, 2n+1] +GPU n: [0, 2n] +GPU n+1: [0, 2n+2] +... +GPU 2n-1: [0, 4n-2] + +In general, if there are 2n GPUs, the ring is 0 -> 2n-1 -> 1 -> 2n-2 -> ... -> n -> n+1 -> 0 + +For each device, the 2n steps is divided into 3 parts: +1. compute the local attention with `causal=True` +2. if current step is less or equal to its relative rank in the ring, select the first half + of the received kv to compute the attention with `causal=False`. In the example above, each + device computes to `left` of its corresponding rows in the status matrix. +3. if current step is greater than its relative rank in the ring, select the second half of + local q and full received kv to compute the attention with `causal=False`. In the example + above, each device fills the remaining part of its lower row in the status matrix. +''' + +def zigzag_ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group) + + block_seq_len = q.shape[1] // 2 + q1 = q[:, block_seq_len:] + + out = None + lse = None + next_k, next_v = None, None + + def forward(q, k, v, causal): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step == 0: + block_out, block_lse = forward(q, k, v, causal=True) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + block_out, block_lse = forward(q, k0, v0, causal=False) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + out, lse = update_out_and_lse( + out, + lse, + block_out, + block_lse, + slice_=(slice(None), slice(block_seq_len, None)), + ) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +''' +In the backward pass, we assume q, k, v and out are saved in the shuffled order. +In addition, the backward pass requires a shuffled dout as input and generates +a shuffled dq, dk, dv as output. Note that out is a sum of all step outputs, so +we can directly pass dout to each step's backward block to compute the local gradient +according to the differiential chain rule. + +Similar to the forward pass, in the backward pass, the 2n steps are divided into 3 parts. + +Different from the forward pass, we need to communicate the gradient of kv in a ring as well. +To be more specific, each device calculates the local gradients of dq, dk, dv. In the following +steps, dq will be accumulated in the initial device, while dk and dv will be transmitted to the +next consumer device, then accumulated in the consumer device. In the end, the dk and dv will be +transmitted back to the initial device. + +In addition, to be compatible with the flash-attn's interface and reduce the precision loss, +we will accumulate and transmit the gradients in float32. They will be converted back to the +original dtype at the end of the backward pass. +''' +def zigzag_ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): # pragma: no cover + assert causal == True, "zigzag ring is meaningless for causal=False" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=1)[1] + q1 = q.chunk(2, dim=1)[1] + out1 = out.chunk(2, dim=1)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() + block_seq_len = q.shape[1] // 2 + + # repeatly allocating buffer may be slow... + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout, q, k, v, out, softmax_lse, causal): + seqlen_q = q.shape[1] + seqlen_kv = k.shape[1] + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout, + "q": q, + "k": k, + "v": v, + "out": out, + "softmax_lse": softmax_lse, + "dq": dq_buffer[:, :seqlen_q], + "dk": dk_buffer[:, :seqlen_kv], + "dv": dv_buffer[:, :seqlen_kv], + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + backward(dout, q, k0, v0, out, softmax_lse, causal=False) + dq += dq_buffer + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) + # always use the first half in dq_buffer. + dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.revert_rank: + dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] + dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] + else: + dk += dk_buffer + dv += dv_buffer + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, zigzag ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, dk, dv +''' +class ZigZagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + q = shuffle_input(to_send=q, process_group=group) + k = shuffle_input(to_send=k, process_group=group) + v = shuffle_input(to_send=v, process_group=group) + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = zigzag_ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_output(out, process_group=group) + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): # pragma: no cover + dout = shuffle_input(to_send=dout, process_group=ctx.group) + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = zigzag_ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + dq = recover_output(dq, ctx.group) + dk = recover_output(dk, ctx.group) + dv = recover_output(dv, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/nnscaler/customized_ops/ring_attention/ring_attn.py b/nnscaler/customized_ops/ring_attention/ring_attn.py new file mode 100644 index 00000000..e7a8a4b8 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/ring_attn.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Tuple, List, Dict +from torch import Tensor + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir.operator import IRFwOperation +from .core.ring_attn_implementation import RingFlashAttnFunc +from .core.utils import gen_head_anno +from flash_attn import flash_attn_func + +from nnscaler.runtime.device import DeviceGroup + + +def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, + dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None) -> Tensor: + ''' + wrap the ring_attn_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + + assert alibi_slopes is None, "alibi_slopes is not supported in ring_attn_func" + assert return_attn_probs is False, "return_attn_probs is not supported in ring_attn_func" + + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, softmax_scale=softmax_scale, causal=causal, window_size=window_size,) + return output + + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + + output = RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ) + + return output + + +def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate ring_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states) + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + + +register_op(flash_attention_anno, emit_fn=emit_ring)(wrap_ring_attn_func) diff --git a/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py b/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py new file mode 100644 index 00000000..bb9ff54b --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Tuple, List, Dict, Optional +import torch +from torch import Tensor +import torch.distributed as dist +import warnings + +from nnscaler.graph.parser.register import register_op +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir import IRTensor +from nnscaler.runtime.device import DeviceGroup +from flash_attn import flash_attn_varlen_func +from .core.ring_attn_varlen_implementation import llama3_flash_attn_prepare_cu_seqlens, llama3_flash_attn_varlen_func +from .core.utils import gen_head_anno +from .varlen_utils import shuffle_varlen, unshuffle_varlen + +# Try to import TransformerEngine with version check +_HAS_TRANSFORMER_ENGINE = False +_TE_VERSION_OK = False +attn_forward_func_with_cp = None + +try: + import transformer_engine + _HAS_TRANSFORMER_ENGINE = True + + # Check version - require 2.2.0+ + try: + from packaging import version + te_version = version.parse(transformer_engine.__version__) + required_version = version.parse("2.2.0") + _TE_VERSION_OK = te_version >= required_version + + if _TE_VERSION_OK: + # Try different import paths for different versions + try: + # For v2.5.0+ + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import attn_forward_func_with_cp + except ImportError: + try: + # For v2.2.0-v2.4.x + from transformer_engine.pytorch.attention import attn_forward_func_with_cp + except ImportError: + warnings.warn( + "TransformerEngine attention module not available or incompatible. " + "Falling back to basic ring attention implementation." + ) + else: + warnings.warn( + f"TransformerEngine version {transformer_engine.__version__} is too old. " + f"Require 2.2.0+. Falling back to basic ring attention implementation." + ) + except ImportError: + # packaging not available, try to import anyway + try: + # Try different import paths for different versions + try: + # For v2.5.0+ + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import attn_forward_func_with_cp + except ImportError: + # For v2.2.0-v2.4.x + from transformer_engine.pytorch.attention import attn_forward_func_with_cp + _TE_VERSION_OK = True + except (ImportError, AttributeError): + warnings.warn( + "TransformerEngine attention module not available or incompatible. " + "Falling back to basic ring attention implementation." + ) + +except ImportError: + warnings.warn( + "TransformerEngine not found. Falling back to basic ring attention implementation. " + "For better performance with context parallelism, install TransformerEngine 2.2.0+." + ) + + +def get_transformer_engine_info() -> Dict[str, any]: + """Get information about TransformerEngine availability and version.""" + return { + "has_transformer_engine": _HAS_TRANSFORMER_ENGINE, + "version_ok": _TE_VERSION_OK, + "has_cp_function": attn_forward_func_with_cp is not None, + "version": getattr(transformer_engine, "__version__", None) if _HAS_TRANSFORMER_ENGINE else None, + "required_version": "2.2.0+", + } + + +def print_transformer_engine_status(): + """Print TransformerEngine status for debugging.""" + info = get_transformer_engine_info() + print("TransformerEngine Status:") + print(f" - Available: {info['has_transformer_engine']}") + if info['has_transformer_engine']: + print(f" - Version: {info['version']}") + print(f" - Version OK (>= 2.2.0): {info['version_ok']}") + print(f" - CP Function Available: {info['has_cp_function']}") + else: + print(f" - Required Version: {info['required_version']}") + print(f" - Will use TE CP: {info['has_transformer_engine'] and info['version_ok'] and info['has_cp_function']}") + + +def wrap_ring_attn_varlen_func( + q: Tensor, + k: Tensor, + v: Tensor, + cu_seqlens_q: Tensor, + cu_seqlens_k: Tensor, + alibi_slopes: Tensor, + dropout_p: float = 0.0, + softmax_scale: Tensor = None, + causal: bool = False, + window_size: Tuple[int] = (-1, -1), + deterministic: bool = False, + return_attn_probs: bool = False, + process_group: Tuple[int] = None, +): + ''' + wrap the ring_attn_varlen_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_varlen_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + assert not return_attn_probs, "return_attn_probs is not supported in ring-attention" + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + + if process_group is None or len(process_group) == 1: + output = flash_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=False, + ) + return output + + assert len(q.shape) == 3, "q must have shape [total_q, qh, dim]" + assert len(k.shape) == 3, "k must have shape [total_k, kh, dim]" + assert len(v.shape) == 3, "v must have shape [total_k, vh, dim]" + total_q, qheads, qdim = q.shape + total_k, kheads, kdim = k.shape + total_v, vheads, vdim = v.shape + assert total_q == total_k == total_v, "total_q, total_k and total_v must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + local_rank = dist.get_rank(local_process_group) + local_world_size = dist.get_world_size(local_process_group) + assert local_world_size == len(process_group), "local_world_size should be the same with process_group size" + + if local_process_group is None: + local_process_group = dist.group.WORLD + + if window_size == (-1, -1): + # Use TransformerEngine with context parallelism if available and version is OK + if _HAS_TRANSFORMER_ENGINE and _TE_VERSION_OK and attn_forward_func_with_cp is not None: + shuffled_q = shuffle_varlen(q, cu_seqlens_q, process_group, local_process_group) + shuffled_k = shuffle_varlen(k, cu_seqlens_k, process_group, local_process_group) + shuffled_v = shuffle_varlen(v, cu_seqlens_k, process_group, local_process_group) + + te_cu_seqlens_q = cu_seqlens_q.clone() + te_cu_seqlens_k = cu_seqlens_k.clone() + te_cu_seqlens_q = torch.cat( + [ + te_cu_seqlens_q, + torch.tensor([cu_seqlens_q[-1].item()], dtype=te_cu_seqlens_q.dtype, device=te_cu_seqlens_q.device) + ] + ) + te_cu_seqlens_k = torch.cat( + [ + te_cu_seqlens_k, + torch.tensor([cu_seqlens_k[-1].item()], dtype=te_cu_seqlens_k.dtype, device=te_cu_seqlens_k.device) + ] + ) + shuffled_output = attn_forward_func_with_cp( + True, + shuffled_q, + shuffled_k, + shuffled_v, + te_cu_seqlens_q, + te_cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + te_cu_seqlens_q, + te_cu_seqlens_k, + dropout_p, + local_process_group, + process_group, + # TODO: optimize the stream usage + torch.cuda.current_stream(), + "p2p", # "all_gather" version cannot work with thd format + qkv_format="thd", + attn_mask_type="padding_causal" if causal else "padding", + ) + output = unshuffle_varlen(shuffled_output, cu_seqlens_q, process_group, local_process_group) + return output + else: + # Fallback to basic ring attention implementation + warnings.warn( + "TransformerEngine not available or version incompatible. " + "Using basic ring attention implementation which may be slower." + ) + + ( + local_cu_seqlens_q, + local_cu_seqlens_k, + local_max_seqlen_q, + local_max_seqlen_k, + local_k_slice, + ) = llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens_q, + causal=causal, + rank=local_rank, + world_size=local_world_size, + ) + + output = llama3_flash_attn_varlen_func( + q, + k, + v, + local_cu_seqlens_q, + local_cu_seqlens_k, + local_max_seqlen_q, + local_max_seqlen_k, + heads_k_stride=1, + local_k_slice=local_k_slice, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=False, + group=local_process_group, + ) + + return output + + +def emit_ring(node: IRDimops, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate ring_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + remainder = runtime_devid % plan_ndevs + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [(i, f // s) for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + if partition_dims[0][0] == 0: # partition on sequence dim + # the synchronization should occur across scaleunits + num = partition_dims[0][1] + scale_unit_dev_ids = [local_rank + offset for local_rank in range(remainder // num * num, (remainder // num + 1) * num)] + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0][0] == 1: + # partition the head dim, use local flash_attn_func + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +def flash_attention_anno(query_states, key_states, value_states, cu_seqlens_q, cu_seqlens_k, alibi_slopes, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states, head_pos=1) + if isinstance(alibi_slopes, IRTensor): + return f'l {q_anno} hd^, l {kv_anno} hd^, l {kv_anno} vd^, e^, e^, {q_anno} -> l {q_anno} vd^' + else: + return f'l {q_anno} hd^, l {kv_anno} hd^, l {kv_anno} vd^, e^, e^, ? -> l {q_anno} vd^' + + +def input_gen_fn(node: IRDimops): + inputs = [] + device = torch.cuda.current_device() + seqlen = node.inputs()[0].shape[0] + for i, t in enumerate(node.inputs()): + if i < 3: # query, key, value + inputs.append(torch.randn(t.shape, dtype=t.dtype, device=device, requires_grad=t.requires_grad)) + elif i in [3, 4]: # cu_seqlens + inputs.append(torch.Tensor([0, seqlen]).to(torch.int32).to(device)) + elif i == 5: # optional alibi_slopes + if isinstance(t, IRTensor): + inputs.append(torch.randn(t.shape, dtype=t.dtype, device=device, requires_grad=t.requires_grad)) + else: + inputs.append(None) + else: # other kwargs, use defaults + break + return tuple(inputs) + + +register_op(flash_attention_anno, emit_fn=emit_ring, input_gen_fn=input_gen_fn)(wrap_ring_attn_varlen_func) diff --git a/nnscaler/customized_ops/ring_attention/varlen_utils.py b/nnscaler/customized_ops/ring_attention/varlen_utils.py new file mode 100644 index 00000000..bdd1f127 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/varlen_utils.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Utilities for variable-length sequence processing in ring attention. +Contains shuffle and unshuffle functions for context parallel processing. +""" + +from typing import List +import torch +from torch import Tensor +import torch.distributed as dist +from nnscaler.runtime.adapter.nn import allgather_reducescatter + + +def shuffle_varlen(t: Tensor, cu_seqlens_padded: Tensor, cp_ranks: List[int], cp_group: dist.ProcessGroup) -> Tensor: + """ + Shuffle tensor data for variable-length sequences in context parallel processing. + + Args: + t: Input tensor to shuffle (local portion from each rank) + cu_seqlens_padded: Cumulative sequence lengths (global) + cp_ranks: List of ranks in the context parallel group + cp_group: Process group for context parallel communication + + Returns: + Shuffled tensor + """ + # Get context parallel size and rank + cp_size = torch.distributed.get_world_size(group=cp_group) + assert cp_size > 1, "cp_size should be greater than 1" + cp_rank = torch.distributed.get_rank(group=cp_group) + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + full_tensor = allgather_reducescatter(t, 0, cp_ranks) + return process_tensor(full_tensor) + + +def unshuffle_varlen(t: Tensor, cu_seqlens_padded: Tensor, cp_ranks: List[int], cp_group: dist.ProcessGroup) -> Tensor: + """ + Unshuffle tensor data to restore original variable-length sequence order. + This is the reverse operation of shuffle_varlen. + + Args: + t: Shuffled tensor to unshuffle (local portion from each rank) + cu_seqlens_padded: Cumulative sequence lengths (global) + cp_ranks: List of ranks in the context parallel group + cp_group: Process group for context parallel communication + + Returns: + Unshuffled tensor (local portion for each rank) + """ + # reverse operation of shuffle_varlen + cp_size = torch.distributed.get_world_size(group=cp_group) + assert cp_size > 1, "cp_size should be greater than 1" + cp_rank = torch.distributed.get_rank(group=cp_group) + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + sum_len = cu_seqlens_padded[-1].item() + + def process_tensor(val): + if val is None: + return val + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) + else: + raise ValueError("Tensor must be at least 1D") + + cp_rank_slices = [] + for rank in range(cp_size): + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (rank * slice_size), + seq_start + ((rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - rank) * slice_size), + device=val.device, + ) + ) + perm = torch.cat(cp_rank_slices) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(sum_len, device=val.device) + + # Create a tensor to hold the unshuffled result + unshuffled = val.index_select(current_seq_dim, inv_perm) + local_tensor = torch.chunk(unshuffled, cp_size, dim=current_seq_dim)[cp_rank] + return local_tensor + + full_tensor = allgather_reducescatter(t, 0, cp_ranks) + return process_tensor(full_tensor) diff --git a/nnscaler/customized_ops/ring_attention/zigzag_attn.py b/nnscaler/customized_ops/ring_attention/zigzag_attn.py new file mode 100644 index 00000000..2373f9d0 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/zigzag_attn.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Tuple, List, Dict +import torch +from torch import Tensor +import torch.distributed + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir.operator import IRFwOperation +from .core.zigzag_attn_implementation import ZigZagRingFlashAttnFunc +from .core.utils import gen_head_anno +from flash_attn import flash_attn_func + +import torch.distributed as dist +from nnscaler.runtime.device import DeviceGroup + + +def wrap_zigzag_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, + dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None) -> Tensor: + ''' + wrap the zigzag_attn_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + + assert window_size == (-1, -1), "window_size is not supported in zigzag-attention" + assert not return_attn_probs, "return_attn_probs is not supported in zigzag-attention" + assert alibi_slopes is None, "alibi_slopes is not supported in zigzag-attention" + + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + return output + + assert causal == True, "zigzag_ring is meaningless for causal=False" + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + + output = ZigZagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ).contiguous() + + return output + +def emit_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states) + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + + +register_op(flash_attention_anno, emit_fn=emit_zigzag)(wrap_zigzag_attn_func) diff --git a/nnscaler/flags.py b/nnscaler/flags.py index 77333987..af903b91 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -47,7 +47,7 @@ class CompileFlag: # use zero optimization on optimizer status. # to cooperate with zero, user needs to call `model.parameters_for_optimizer()` # to get parameters for optimizer, and `model.gather_params()` after `optimizer.step()` - use_zero = _to_bool('USE_ZERO') + use_zero = _to_int('USE_ZERO') # use async communication to overlap gradient synchronization and backward computation async_reducer = _to_bool('ASYNC_REDUCER') # use async reducer # maximal reducer weight bytes for one allreduce (only effective for async): diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index 333b01ad..827c3ed2 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -72,7 +72,7 @@ import logging from itertools import dropwhile -from nnscaler.ir.cten import IRTensor, IRObject +from nnscaler.ir.cten import IRTensor, IRObject, ValueTrack from nnscaler.ir.operator import IRFwOperation @@ -753,7 +753,7 @@ def ianno(self, index: int) -> ShapeAnno: @return dim_annos ShapeAnno: a tuple that each element is a dimension annotation """ assert index < len(self.inputs()), "index out of boudary" - return tuple(self._iannos[index]) + return self._iannos[index] def oanno(self, index: int) -> ShapeAnno: """! @@ -853,7 +853,7 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict op_anno.reset_identifiers() identifier_values: Dict[str, int] = dict() - for ashape, itensor in zip(op_anno.inputs(), inputs): + for idx, (ashape, itensor) in enumerate(zip(op_anno.inputs(), inputs)): if not isinstance(itensor, IRTensor) or ashape.ignore: continue if ashape.ndims != len(itensor.shape): @@ -861,7 +861,12 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict for adim, dimlen in zip(ashape.dims, itensor.shape): if len(adim.identifiers) == 1: if adim.identifiers[0] in identifier_values and identifier_values[adim.identifiers[0]] != dimlen: - raise RuntimeError(f'the exist identifier value {identifier_values[adim.identifiers[0]]} is not equal to the new value {dimlen}') + error_msg = ( + f"at {signature} with {op_anno} the exist identifier {adim.identifiers[0]} value " + f"{identifier_values[adim.identifiers[0]]} is not equal to the new value {dimlen}, " + f"error idx {idx}, input tensors {inputs}" + ) + raise RuntimeError(error_msg) identifier_values[adim.identifiers[0]] = dimlen # check dimension consistency diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index c083ec9c..7c686d76 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -34,7 +34,7 @@ import logging from collections.abc import Iterable -from nnscaler.ir.cten import IRTensor, IRObject, IR +from nnscaler.ir.cten import IRTensor, IRObject, IR, ValueTrack from nnscaler.ir.tensor import IRSubTensor, IRFullTensor from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule @@ -150,6 +150,16 @@ def Accum(*inputs, signature = None): return IRDimops(Cat, 'accum', signature, [anno], inputs) +def Dot(input, tensor, *, out=None, signature = None): + """ + torch.dot(input, tensor, *, out=None) -> Tensor + """ + assert out is None + signature = 'torch.dot' + annos = ['k+, k+ -> 1',] + return IRDimops(Dot, 'dot', signature, annos, [input, tensor]) + + def Linear(input, weight, bias=None, signature = None): signature = 'torch.nn.functional.linear' assert isinstance(input, IRTensor) and isinstance(weight, IRTensor) @@ -195,6 +205,7 @@ def CubeEinSum(*operands, equation=None, signature = None): anno = f'{lhs} -> {rhs}' return IRDimops(CubeEinSum, 'einsum', signature, [anno], operands, equation=equation) + def EinSum(equation: str, *operands, signature = None): return CubeEinSum(*operands, equation=equation, signature=signature) @@ -259,7 +270,21 @@ def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Uni size = (math.ceil((end_val-start_val)/step_val),) anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), False) - return IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) + + # Output will be replaced in Parser, + # Here we just pass the value tracks out + output = IRFullTensor(size) + if not isinstance(start, IRObject) and start == 0 \ + and not isinstance(step, IRObject) and step == 1 \ + and isinstance(end, IRObject): + # a special case for arange(0, end), which is very common in practice + # we can directly use end's value track + output.dim_tracks = [end.value_track] + else: + output.dim_tracks = [ValueTrack.new([start, end, step])] + ret = IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) + ret.set_output(0, output) + return ret def Arange(*args, start=None, end=None, step=None, out=None, dtype=None, layout=None, @@ -352,13 +377,38 @@ def creation_function_size_check(op_name, size, *arg_size) -> Tuple[Union[int, I raise ValueError(f"get illegal input size={size}, arg_size={arg_size} in {op_name}") # convert scalar to shape (1,) tensor, nnscaler don't support empty shape [] now. if len(size_val) == 0: - _logger.warn(f"detect tensor creation function {op_name} create a scalar, force it to create a shape [1] tensor instead") + _logger.warning(f"detect tensor creation function {op_name} create a scalar, force it to create a shape [1] tensor instead") size = (1,) else: raise ValueError(f"get unknown input type size={size} in {op_name}") return size +def creation_function_dim_track(resolved_size: Union[IRObject, tuple[Union[int, IRObject]]]) -> list[ValueTrack]: + if isinstance(resolved_size, IRObject): + assert isinstance(resolved_size.value, (tuple, list)) + # all dims dependent on resolved_size + return [ValueTrack.new([resolved_size]) for _ in resolved_size.value] + + dim_tracks = [] + for dim in resolved_size: + if isinstance(dim, IRObject): + dim_tracks.append(ValueTrack.new([dim])) + else: + # no dim dependency when dim is not IRObject + dim_tracks.append(ValueTrack.new([])) + return dim_tracks + + +def creation_function_set_dim_tracks(op: IRDimops, resolved_size: Union[IRObject, tuple[Union[int, IRObject]]]) -> IRDimops: + # Output will be replaced in Parser, + # Here we just pass the value tracks out + output = IRFullTensor(_unwrap_value(resolved_size)) + output.dim_tracks = creation_function_dim_track(resolved_size) + op.set_output(0, output) + return op + + def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): """ @@ -374,7 +424,10 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs), + size + ) def Zeros(size, *arg_size, out=None, dtype=None, layout=None, @@ -390,7 +443,10 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, size = creation_function_size_check('torch.zeros', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs), + size + ) def Ones(size, *arg_size, out=None, dtype=None, layout=None, @@ -406,7 +462,10 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, size = creation_function_size_check('torch.ones', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs), + size + ) def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, @@ -424,7 +483,10 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs), + size + ) def Randn(size, *arg_size, generator=None, out=None, dtype=None, layout=None, device=None, requires_grad=False, @@ -442,7 +504,10 @@ def Randn(size, *arg_size, generator=None, out=None, dtype=None, layout=None, de kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Randn, 'randn', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Randn, 'randn', signature, [anno], [], rules, **kwargs), + size + ) def Full(size, fill_value, *, out=None, dtype=None, layout=None, @@ -457,8 +522,11 @@ def Full(size, fill_value, *, out=None, dtype=None, layout=None, signature = 'nnscaler.runtime.function.full' size = creation_function_size_check('torch.full', size) anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Full, 'full', signature, [anno], [], rules, - size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad) + return creation_function_set_dim_tracks( + IRDimops(Full, 'full', signature, [anno], [], rules, + size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad), + size + ) def NewTensor(data, *, dtype=None, device=None, @@ -1809,14 +1877,18 @@ def CubeStack(*tensors, dim=0, signature=None): assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'but got {tensors}' assert isinstance(dim, int), f"but not {dim}" signature = 'nnscaler.runtime.function.stack' - iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] - oanno = [None for i in range(len(tensors[0].shape) + 1)] - oanno[dim] = f'{len(tensors)}^' - offset = 0 - for i in range(len(oanno)): - if oanno[i] is None: - oanno[i] = copy.copy(iannos[-1][offset]) - offset += 1 + if tensors[0].is_scalar_tensor(): + iannos = ['1' for _ in tensors] + oanno = [f'{len(tensors)}'] + else: + iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] + oanno = [None for i in range(len(tensors[0].shape) + 1)] + oanno[dim] = f'{len(tensors)}' + offset = 0 + for i in range(len(oanno)): + if oanno[i] is None: + oanno[i] = copy.copy(iannos[-1][offset]) + offset += 1 anno = OpAnno.create_op_str(iannos, [oanno]) return IRDimops(CubeStack, 'stack', signature, [anno], tensors, dim=dim) @@ -1834,7 +1906,7 @@ def Stack(tensors, dim=0, out=None, signature = None): return CubeStack(*tensors, dim=dim, signature=signature) -def Chunk(input, chunks, dim=0, signature = None): +def Chunk(input: IRTensor, chunks, dim=0, signature = None): """ torch.chunk(input, chunks, dim=0) """ @@ -1845,7 +1917,18 @@ def Chunk(input, chunks, dim=0, signature = None): for oanno in oannos: oanno[dim] = str(input.shape[dim] // chunks) anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(Chunk, 'chunk', signature, [anno], [input], chunks=chunks, dim=dim) + ret = IRDimops(Chunk, 'chunk', signature, [anno], [input], chunks=chunks, dim=dim) + + # set proper value tracks for outputs + output_shape = list(input.shape) + output_shape[dim] = input.shape[dim] // chunks + dim_vt = ValueTrack.new([chunks, input.dim_tracks[dim]]) + for d in range(chunks): + output = IRFullTensor(output_shape) + output.set_dim_track(dim, dim_vt) + ret.set_output(d, output) + + return ret def Select(input, dim, index, signature = None): @@ -2340,12 +2423,15 @@ def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: torch.Tensor.size(tensor, dim=None) """ assert isinstance(tensor, IRTensor) - val = tensor.shape[dim] if isinstance(dim, int) else tensor.shape - assert val is not None + if isinstance(dim, int): + val = IRObject(name='size', value=tensor.shape[dim], value_track=tensor.dim_tracks[dim]) + else: + val = tuple(IRObject('size', value=s, value_track=t) for s, t in zip(tensor.shape, tensor.dim_tracks)) + if dim is None: - return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)]) + return IRPyFunc(signature, [tensor], [val]) else: - return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)], dim=dim) + return IRPyFunc(signature, [tensor], [val], dim=dim) def Dim(tensor, signature=None) -> Union[List[int], IRPyFunc]: @@ -2602,7 +2688,7 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], if isinstance(obj, IRTensor): if name == 'shape': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - shape = IRObject('shape', value=obj.shape) + shape = tuple(IRObject('shape', value=s, value_track=t) for s, t in zip(obj.shape, obj.dim_tracks)) return IRPyFunc(signature, [instance, field], [shape]) if name == 'dtype': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" @@ -3391,10 +3477,14 @@ def Item(input, signature = None): """ torch.Tensor.item() """ - # set output to IRObject.missing, + # set output value to IRObject.missing_value, # because the output is unknown here. # It will be filled with real value in parser. - return IRPyFunc(signature, inputs=[input], outputs=[IRObject.missing], constant_foldable=False) + return IRPyFunc( + signature, inputs=[input], + outputs=[IRObject('item', value=IRObject.missing_value, is_constant=False)], + constant_foldable=False + ) def DictKeys(o: Union[Dict, IRObject], signature=None): @@ -3426,7 +3516,7 @@ def DictValues(o: Union[Dict, IRObject], signature=None): def DictItems(o: Union[Dict, IRObject], signature=None): - signature = 'nnscaler.runtime.function.dict_values' + signature = 'nnscaler.runtime.function.dict_items' if not isinstance(o, dict) and not (isinstance(o, IRObject) and isinstance(o.value, dict)): raise ValueError(f'the input should be a dict or an IRObject with dict value, but get {o}') diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 550b21f1..6cf7372f 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -65,11 +65,10 @@ def __call__(self, *args): """ return self.forward(*args) - def forward(self, *args: Tuple[IRObject]) -> Union[IRTensor, Tuple[IRTensor]]: + def forward(self, *args: IRObject) -> Union[IRTensor, Tuple[IRTensor]]: """Forward the IRGraph to add model nodes into program. - Args: - args (Tuple[IRObject]): input IRObjects + args (Tuple[IRObject, ...]): input IRObjects Returns: Any: output that can be nested structure of IRObjects @@ -288,6 +287,7 @@ def use_dataloader_input(self): # IRDataOperation. Since we already know the output of the dataloader, # we don't need to set the value for it. ir_root_obj = IRObject(name='dataloader', value=None, is_constant=False) + ir_root_obj.value_track.with_no_dep() data_op = IRDataOperation(ir_root_obj, self.inputs()) # add the data operation to the graph, which will use `next` to get data. self.insert(data_op, 0) @@ -1212,7 +1212,8 @@ def checksum(self, strict: bool = True) -> str: def copy_node_meta_info(src_node: Union[IRFwOperation, IRDataOperation], dest_node: Union[IRFwOperation, IRDataOperation]): """ Copy meta information from src_node to dest_node. - Current copy fields: ['recompute', 'comment', 'op_context', 'module_stack', 'device'] + Current copy fields: ['recompute', 'comment', 'op_context', 'module_stack', 'device', + 'hook_meta', 'pre_hook', 'post_hook'] """ if isinstance(src_node, IRFwOperation): dest_node.recompute = src_node.recompute @@ -1222,3 +1223,6 @@ def copy_node_meta_info(src_node: Union[IRFwOperation, IRDataOperation], dest_no dest_node.op_context = src_node.op_context dest_node.module_stack = src_node.module_stack dest_node.device = src_node.device + dest_node.hook_meta = src_node.hook_meta + dest_node.pre_hook = src_node.pre_hook + dest_node.post_hook = src_node.post_hook diff --git a/nnscaler/graph/parser/__init__.py b/nnscaler/graph/parser/__init__.py index 1dea36e7..e7fa0900 100644 --- a/nnscaler/graph/parser/__init__.py +++ b/nnscaler/graph/parser/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from nnscaler.graph.parser.parser import FxModuleParser +from nnscaler.graph.parser.parser import FxModuleParser, parse_fx_module from nnscaler.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph from nnscaler.graph.parser.register import register from nnscaler.graph.parser.external import * diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index a30dfa23..ae338b25 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -12,7 +12,7 @@ from nnscaler.graph import IRGraph from nnscaler.flags import CompileFlag -from nnscaler.graph.parser import FxModuleParser +from nnscaler.graph.parser import parse_fx_module from nnscaler.graph.tracer import concrete_trace from nnscaler.graph.tracer.wrap_utils import Location, is_autograd_apply, LeafWrapInfo from nnscaler.graph.tracer.torch_fx_patcher import side_effectful_inplace_ops @@ -30,8 +30,11 @@ class no_save_tensor_hook(saved_tensors_hooks): """skip saving tensors for backward since tracer only traces forward""" def __init__(self): def pack(x): - return None + return (x.shape, x.dtype, x.device) def unpack(x): + # in pytorch 2.4.0-, torch.compile will call backward when tracing graph + if torch.__version__ < (2, 4, 0): + return torch.empty(x[0], dtype=x[1], device=x[2]) raise RuntimeError("not expecting backward to be called on this tensor") super().__init__(pack, unpack) @@ -146,7 +149,7 @@ def to_ir_graph( _logger.info(f"constant folding {'enabled' if constant_folding else 'disabled'} to parse graph") with no_save_tensor_hook(): - inputs, nodes, outputs = FxModuleParser.parse( + inputs, nodes, outputs = parse_fx_module( traced_model, dummy_input, attr_savedir=attr_savedir, constant_folding=constant_folding, diff --git a/nnscaler/graph/parser/external/__init__.py b/nnscaler/graph/parser/external/__init__.py index 5c71d8f9..5a628d8f 100644 --- a/nnscaler/graph/parser/external/__init__.py +++ b/nnscaler/graph/parser/external/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .apex import * \ No newline at end of file +from .apex import * +from .einops import * diff --git a/nnscaler/graph/parser/external/apex.py b/nnscaler/graph/parser/external/apex.py index 94b22209..e7d321d1 100644 --- a/nnscaler/graph/parser/external/apex.py +++ b/nnscaler/graph/parser/external/apex.py @@ -83,6 +83,7 @@ def apex_fused_rms_norm_affine_anno(input, weight, normalized_shape, eps, *args, parser.register(apex_fused_layer_norm_affine_anno)(fused_layer_norm_affine) parser.register(apex_fused_rms_norm_anno)(fused_rms_norm) parser.register(apex_fused_rms_norm_affine_anno)(fused_rms_norm_affine) + _logger.info("apex ops registered successfully.") except: - _logger.warning('skip apex ops as it is not installed.') + _logger.debug('skip apex ops as it is not installed.') diff --git a/nnscaler/graph/parser/external/einops.py b/nnscaler/graph/parser/external/einops.py new file mode 100644 index 00000000..91845fe8 --- /dev/null +++ b/nnscaler/graph/parser/external/einops.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging + +import torch + + +_logger = logging.getLogger(__name__) + +try: + import einops + + # trigger einops initialization + einops.rearrange(torch.arange(1), '(a b c) -> a b c', a=1, b=1, c=1) +except ImportError as e: + _logger.debug("Einops is not installed") + pass diff --git a/nnscaler/graph/parser/mapping.py b/nnscaler/graph/parser/mapping.py index 55c2792f..2c8f88ce 100644 --- a/nnscaler/graph/parser/mapping.py +++ b/nnscaler/graph/parser/mapping.py @@ -55,6 +55,7 @@ def exist(signature: str) -> bool: kOpMap = { # __tnmtemplate('Dropout'): function.nnDropout, + __ttemplate('dot'): function.Dot, __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 02c52611..4fa263b9 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -11,8 +11,9 @@ from nnscaler.graph.tracer.metadata import OpContext from nnscaler.ir.operator import IRFwOperation from nnscaler.ir.tensor import IRFullTensor -from nnscaler.ir.cten import IRObject, IRCell, IRTensor, IR +from nnscaler.ir.cten import IRObject, IRCell, IRTensor, IR, ValueTrack from nnscaler.graph.parser.frame import Frame +from nnscaler.graph.parser.value_tracker import ValueTracker from nnscaler.graph.parser.mapping import SignFx2Op from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import IRDimops @@ -38,14 +39,14 @@ class FxModuleParser: ATTR_CONTENT_FILE_FORMAT = '{stem}.{idx}' ATTR_MAP_FILE = 'dist_param_map.pt' - @staticmethod - def parse(module: torch.fx.GraphModule, + def __init__(self, + module: torch.fx.GraphModule, dummy_inputs: Dict[str, Any], attr_savedir='./', *, save_content: bool = True, constant_folding: bool = False - ) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + ): """Parse torch.fx module into cube IR The overall entry to parse a torch.fx graph module @@ -56,6 +57,24 @@ def parse(module: torch.fx.GraphModule, attr_savedir (str): the directory to save the attribute content save_content (bool): whether to save the content of the module constant_folding (bool): whether to parse the module with constant folding + """ + + self.module = module + + self.dummy_inputs = dummy_inputs + assert isinstance(dummy_inputs, dict), f"Expected dummy inputs to parse module, but got {dummy_inputs} of type {type(dummy_inputs)}" + + self.attr_savedir = attr_savedir + self.save_content = save_content + self.constant_folding = constant_folding + + self.frame = Frame() + self.value_tracker = ValueTracker() + + def parse(self) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + """Parse torch.fx module into cube IR + + The overall entry to parse a torch.fx graph module Returns: inputs (List[IRObject]): the input IRObjects @@ -67,12 +86,10 @@ def parse(module: torch.fx.GraphModule, # (Those ops creators include user registered function, all functions returning tensors and more) # We will connect the real op outputs (saved in frame) to all ir op outputs and inputs later. - frame = Frame() - frame.push_var() + self.frame.push_var() # shape propagation - assert isinstance(dummy_inputs, dict), f"Expected dummy inputs to parse module, but got {dummy_inputs} of type {type(dummy_inputs)}" - output_nodes = [node for node in module.graph.nodes if node.op == 'output'] + output_nodes = [node for node in self.module.graph.nodes if node.op == 'output'] # currently fx graph always has only one output # even if a tuple/list is returned, it is still just one output assert len(output_nodes) == 1, f"Expect only one output, but got {len(output_nodes)}" @@ -81,11 +98,11 @@ def parse(module: torch.fx.GraphModule, assert len(output_node.args) == 1 and len(output_node.kwargs) == 0 # create IRObjects and IRTensors - for node in module.graph.nodes: + for node in self.module.graph.nodes: if node.op == 'placeholder': - FxModuleParser.init_objects(node, module, frame, is_constant=False) + self._init_objects(node, is_constant=False) else: - FxModuleParser.init_objects(node, module, frame, is_constant=True) + self._init_objects(node, is_constant=True) # note the output node will be reset later by `parse_prim_output_node` # with the help of `parse_complex` @@ -98,76 +115,93 @@ def parse(module: torch.fx.GraphModule, # to make sure the IRGraph has the correct output number # see `IRGrpah.from_logic_graph` - val = frame.get_var(node.name) + val = self.frame.get_var(node.name) if node == output_node.args[0] \ and IR.is_object(val) and isinstance(val.value, tuple): tuple_val = tuple(IRObject(name=node.name, value=v, is_constant=val.is_constant) for v in val.value) - frame.set_var(node.name, tuple_val) + self.frame.set_var(node.name, tuple_val) # get graph inputs - placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] - inputs = [frame.get_var(n.name) for n in placeholders] + placeholders = [n for n in self.module.graph.nodes if n.op == 'placeholder'] + inputs = [self.frame.get_var(n.name) for n in placeholders] + self.value_tracker.track_values(inputs) # - if the graph inputs contain nested strcuture, # it should be wrapped into an IRObject for idx, placeholder in enumerate(placeholders): if not isinstance(inputs[idx], IRObject): - obj = IRObject(name=placeholder.name, value=inputs[idx], is_constant=False) + obj = IRObject(name=placeholder.target, value=inputs[idx], is_constant=False) + obj.value_track.mark_as_input() inputs[idx] = obj - frame.set_var(placeholder.name, obj) + self.value_tracker.track_values([obj]) + self.frame.set_var(placeholder.name, obj) # parse graph nodes all_ir_nodes = [] - for node in module.graph.nodes: - ir_nodes = FxModuleParser.parse_node(node, module, constant_folding, frame) + for node in self.module.graph.nodes: + ir_nodes = self._parse_node(node) all_ir_nodes += ir_nodes + self.value_tracker.complete_tracking(all_ir_nodes) + # get graph outputs - outputs = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + outputs = [self.frame.get_var(node.name) for node in self.module.graph.nodes if node.op == 'output'] # currently fx graph always has only one output # even if a tuple/list is returned, it is still just one output assert len(outputs) == 1, f"Expect only one output, but got {len(outputs)}" - if save_content: - attr_savedir = Path(attr_savedir) - frame.save_attr_content(attr_savedir / FxModuleParser.ATTR_CONTENT_FILE_STEM) - frame.save_attr_map(attr_savedir / FxModuleParser.ATTR_MAP_FILE) + if self.save_content: + attr_savedir = Path(self.attr_savedir) + self.frame.save_attr_content(attr_savedir / self.ATTR_CONTENT_FILE_STEM) + self.frame.save_attr_map(attr_savedir / self.ATTR_MAP_FILE) - frame.pop_var() + self.frame.pop_var() return inputs, all_ir_nodes, outputs - @staticmethod - def parse_node(node: torch.fx.Node, module, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: + def _parse_node(self, node: torch.fx.Node) -> List[IRFwOperation]: """ Parse the node and return the IRFwOperation nodes """ if node.op == 'placeholder': return [] if node.op == 'output': - return FxModuleParser.parse_prim_output_node(node, module, frame) + return self._parse_prim_output_node(node) if node.op in ('call_function', 'call_method'): - return FxModuleParser.parse_prim_function_method(node, module, constant_folding, frame) + return self._parse_prim_function_method(node) if node.op == 'get_attr': - return FxModuleParser.parse_prim_get_attr_node(node, module, frame) + return self._parse_prim_get_attr_node(node) if node.op == 'call_module': - return FxModuleParser.parse_prim_module(node, module, frame) + return self._parse_prim_module(node) else: raise TypeError(f"Unknown node kind {node.op}") - @staticmethod - def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, - frame: Frame, is_constant: bool = True): + def _init_objects(self, node: torch.fx.Node, is_constant: bool = True): assert isinstance(node, torch.fx.Node) assert hasattr(node, 'meta') and 'tensor_meta' in node.meta, f"Node {node} should have tensor_meta" meta = node.meta['tensor_meta'] - val = IR.new(node.name, meta, + val = IR.new( + # node.target is necesssary for input + # its name will be used to align with model forward args when generating code. + node.target if node.op == 'placeholder' else node.name, + meta, tensor_types=(TensorMetadata,), - is_constant=is_constant + is_constant=is_constant, ) - frame.add_var(node.name, val) - @staticmethod - def parse_complex(val: Any, frame: Frame) -> Any: + if node.op == 'placeholder': + def mark_as_input(x: IRObject): + if isinstance(x, IRTensor): + # let's the value_track of tensor stay None(unknown) + # because we don't care about it. + for dt in x.dim_tracks: + dt.with_no_dep() + else: + x.value_track.mark_as_input() + IR.modify_objects(val, mark_as_input) + + self.frame.add_var(node.name, val) + + def _parse_complex(self, val: Any) -> Any: """parse complex fx.Node into IRObject The val is usually from a node's input or output, can be fx.Node nested @@ -183,28 +217,28 @@ def parse_complex(val: Any, frame: Frame) -> Any: # to support more nested types, we can refer to the implementation of # https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py if isinstance(val, tuple): - return tuple(FxModuleParser.parse_complex(t, frame) for t in val) + return tuple(self._parse_complex(t) for t in val) if isinstance(val, list): - return list(FxModuleParser.parse_complex(t, frame) for t in val) + return list(self._parse_complex(t) for t in val) if isinstance(val, dict): - return {key: FxModuleParser.parse_complex(val, frame) for key, val in val.items()} + return {key: self._parse_complex(val) for key, val in val.items()} # TODO: Currently slice/DICT_VALUES_TYPE/DICT_ITEMS_TYPE cases are never found. # We need to find some examples to test them. if isinstance(val, slice): - return slice(FxModuleParser.parse_complex(val.start, frame), - FxModuleParser.parse_complex(val.stop, frame), - FxModuleParser.parse_complex(val.step, frame)) + return slice(self._parse_complex(val.start), + self._parse_complex(val.stop), + self._parse_complex(val.step)) # because fx node cannot be a dict key, so skip DICT_KEYS_TYPE here if isinstance(val, DICT_VALUES_TYPE): - return tuple(FxModuleParser.parse_complex(x, frame) for x in val) + return tuple(self._parse_complex(x) for x in val) if isinstance(val, DICT_ITEMS_TYPE): - return tuple((i, FxModuleParser.parse_complex(x, frame)) for i, x in val) + return tuple((i, self._parse_complex(x)) for i, x in val) if isinstance(val, torch.fx.Node): - return frame.get_var(val.name) + return self.frame.get_var(val.name) return val - @staticmethod - def fetch_attr(mod: torch.fx.GraphModule, target: str): + @classmethod + def _fetch_attr(cls, mod: torch.fx.GraphModule, target: str): target_atoms = target.split('.') attr_itr = mod for i, atom in enumerate(target_atoms): @@ -213,23 +247,19 @@ def fetch_attr(mod: torch.fx.GraphModule, target: str): attr_itr = getattr(attr_itr, atom) return attr_itr - @staticmethod - def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: - prim_module = FxModuleParser.fetch_attr(module, node.target) + def _parse_prim_module(self, node: torch.fx.Node) -> List[IRFwOperation]: + prim_module = self._fetch_attr(self.module, node.target) if prim_module.__class__.__module__.startswith('torch.nn.modules'): raise RuntimeError(f'{prim_module.__class__.__module__} can not be parsed as leaf nodes') else: raise RuntimeError(f'unknown module: {prim_module.__class__.__module__}') - @staticmethod - def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: + def _parse_prim_function_method(self, node: torch.fx.Node) -> List[IRFwOperation]: """ Convert `call_function`/`call_method` op to IRFwOperation. Args: node (torch.fx.Node): the node to be parsed - module (torch.fx.GraphModule): the module containing the node - constant_folding (bool): global setting of whether to fold the constant Returns: List[IRFwOperation]: the IRFwOperation nodes. @@ -238,10 +268,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule """ # get signature - fsig = FxModuleParser._get_qualified_name(node.target, node) + fsig = self._get_qualified_name(node.target, node) # get inputs - input_vals = FxModuleParser.parse_complex(list(node.args), frame) - kwargs = FxModuleParser.parse_complex(node.kwargs, frame) + input_vals = self._parse_complex(list(node.args)) + kwargs = self._parse_complex(node.kwargs) # use context constant_folding if set # Please note constant_folding only controls the output of the op @@ -249,6 +279,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # when we enter the code block with different constant folding setting # as a workaround, # you can use `nnscaler.runtime.function.fold_constant` to fold inputs if needed + constant_folding = self.constant_folding op_context: Optional[Dict[str, Any]] = node.meta.get('op_context') if op_context is not None and op_context.get(fields(OpContext).constant_folding) is not None: constant_folding = op_context[fields(OpContext).constant_folding] @@ -258,12 +289,12 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule else: # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator - if FxModuleParser._is_torch_autograd_op(node, frame, fsig): + if self._is_torch_autograd_op(node, fsig): _logger.warning(f'Find unknown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fsig ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) # case2: custom autograd function - elif FxModuleParser._is_custom_autograd_op(node): + elif self._is_custom_autograd_op(node): # custom autograd function _logger.warning(f'Find unknown custom autograd operation: {fsig}. You should register it with nnscaler.register_op') ir_node = IRFwOperation(fsig, fsig, input_vals, 1, **kwargs) @@ -276,7 +307,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule 'You can register it as a customized function using nnscaler.register_op to remove this warning' _logger.warning(warning_msg) is_constant = False - output = frame.get_var(node.name) + output = self.frame.get_var(node.name) if not isinstance(output, IRObject): # avoid nested IRObject output = IRObject(name=node.name, value=output, is_constant=is_constant) @@ -292,10 +323,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # As node is deleted, we must set concrete value or IRTensor/IRObject into framework. # TODO: check the value saved in frame should equal to the value returned by the op - frame.set_var(node.name, ir_node) + self.frame.set_var(node.name, ir_node) return [] - FxModuleParser._set_node_meta(node, ir_node) + self._set_node_meta(node, ir_node) # step 1: align the node output with the value in frame @@ -307,11 +338,11 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # but its output is used in other nodes. # By removing from frame, # we can catch the case earlier - frame.del_val(node.name) + self.frame.del_val(node.name) # if the function has no output, just return return [ir_node] - vals = frame.get_var(node.name) + vals: Union[Any, IRObject, List[IRObject], IRTensor, List[IRTensor]] = self.frame.get_var(node.name) if len(ir_node.outputs()) == 1: vals = [vals] elif IR.is_object(vals): @@ -324,7 +355,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule if not isinstance(vals, (list, tuple)): raise RuntimeError(f'Expect list or tuple for multiple outputs, but got {type(vals)}') vals = type(vals)(IRObject(name=node.name, value=v, is_constant=is_constant) for v in vals) - frame.set_var(node.name, vals) + self.frame.set_var(node.name, vals) # verify the inferred shape are consistent with actual output if isinstance(ir_node, IRFwOperation): @@ -337,11 +368,43 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # 1. output tensors are not set in function.py # 2. IRObject output from some functions (registered functions/getattr) are not set # For above two cases, we need to set them with values from frame. + if isinstance(ir_node.output(i), IRTensor): + assert isinstance(vals[i], IRTensor), f'Expect tensor for output {i}, but got {type(vals[i])}' + assert ir_node.output(i).shape == vals[i].shape, f'Expect shape {ir_node.output(i).shape} for output {i}, but got {vals[i].shape}' + # We need to copy dim tracks + # As we will use frame version as node output, instead of the placeholder created in function.py + for dim in range(len(vals[i].shape)): + vals[i].dim_tracks[dim].merge(ir_node.output(i).dim_tracks[dim]) ir_node.set_output(i, vals[i]) + elif isinstance(ir_node.output(i), IRObject) and ir_node.output(i).is_value_missing(): + # output is IRObject with missing value + # we need to set it with the value from frame + assert not IR.contains_object(vals[i], lambda x: isinstance(x, IRTensor)), \ + f'Output {i} of node {node} is expected to be IRObject, but got tensor: {vals[i]}' + ir_node.output(i).value = IR.try_unwrap(vals[i]) + else: + # Currently we don't support missing-value IRObject in tuple/list/dict/... + # TODO: add support when needed + assert not IR.contains_object(ir_node.output(i), lambda x: not isinstance(x, IRTensor) and x.is_value_missing()), \ + f'Output {i} of node {node} contains missing value: {ir_node.output(i)}' + + # per-op value tracking via its annotation + # TODO: + # This may be not accurate because many ops in function.py are not properly annotated their value deps + # Two ways to improve it: + # 1. add value deps annotation for those ops in function.py + # 2. use global data flow analysis to track value deps + # a. add all nodes without folding + # b. use value_tracker.track_nodes to analyze value deps for all nodes + # c. remove nodes that can be folded. + # It is not easy because some op logic in function.py works differently + # when its inputs are constant or not. + # For now, we just use per-op value tracking for simplicity. + self.value_tracker.track_nodes([ir_node]) # update frame with ir output # Please note when there is only one output, we will unwrap it from `ir_node.outputs()` here - frame.set_var( + self.frame.set_var( node.name, type(vals)(ir_node.outputs()) if len(ir_node.outputs()) > 1 else ir_node.output(0) ) @@ -349,6 +412,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # update the name of output tensors # Note assignment is not allowed in lambda # so we use a helper function to update the name + def _update_name(x: IRObject): x.name = node.name IR.modify_objects_inplace(ir_node.outputs(), _update_name) @@ -378,22 +442,28 @@ def _is_primitive_type(val): # use a white list instead of a black list return isinstance(val, (int, float, bool, type(None), str, type(Ellipsis))) - # Note when it is not IRObject as a whole, we will not fold it if constant_folding and ir_node.constant_foldable \ and len(ir_node.outputs()) == 1 \ - and isinstance(ir_node.output(0), IRObject) \ - and not isinstance(ir_node.output(0), IRTensor) \ and not contains_undefined_output \ and not ir_node.signature.startswith(nnscaler.runtime.function.__name__ + '.')\ - and ir_node.output(0).is_constant \ - and _is_primitive_type(ir_node.output(0).value): - frame.set_var(node.name, ir_node.output(0).value) + and not IR.contains_object(ir_node.output(0), lambda x: isinstance(x, IRTensor) or not x.is_constant) \ + and _is_primitive_type(cval := IR.try_unwrap(ir_node.output(0))): + # TODO: + # This will break the value tracking graph + # for example, if not folded: + # value1 -> op1 -> value2 -> op2 -> value3 -> op3 + # if op2 is folded, then op3 will not know the value1 dependency + # So the value tracking becomes: + # value1 -> op1 value3 -> op3 + # In many cases, op1 and op3 can be connected by other ops, + # But when this becomes a problem, we need to fix it by using global data flow analysis. + self.frame.set_var(node.name, cval) + self.value_tracker.untrack_node(ir_node) return [] else: return [ir_node] - @staticmethod - def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + def _parse_prim_get_attr_node(self, node: torch.fx.Node) -> List[IRFwOperation]: """ There are two types of get_attr, one is `FxNodeKind.PrimGetAttr` which is dealt with in this function. The other is `FxNodeKind.PrimCallFunction ` (i.e., ) @@ -403,74 +473,84 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, node.target is the attribute name of the object. """ ir_nodes = [] - concrete_value = FxModuleParser.fetch_attr(module, node.target) + concrete_value = self._fetch_attr(self.module, node.target) if isinstance(concrete_value, torch.Tensor): assert isinstance(concrete_value, torch.Tensor), \ f"GetAttrPrim: expect tensor but got {type(concrete_value)}" - exist_tensor = frame.get_attr_var(concrete_value) + exist_tensor = self.frame.get_attr_var(concrete_value) # the case that the parameter is the first time used by getattr if not exist_tensor: - tensor = frame.get_var(node.name) + tensor: IRFullTensor = self.frame.get_var(node.name) # set tensor name same with the name in original model tensor.name = node.target if tensor.requires_grad: tensor.as_param() else: - direct_module = module + direct_module = self.module full_qualified_name = node.target.split('.') for name in full_qualified_name[:-1]: # last one is the attribute name direct_module = getattr(direct_module, name) persistent = full_qualified_name[-1] not in direct_module._non_persistent_buffers_set tensor.as_buffer(persistent=persistent) - frame.add_attr(tensor, concrete_value, node.target) + + # Parameters and buffers have no dependency on other values + for dt in tensor.dim_tracks: + dt.is_constant = True + dt.with_no_dep() + + self.frame.add_attr(tensor, concrete_value, node.target) # the case that the parameter is consumed multiple times and registered previously else: - frame.set_var(node.name, exist_tensor) + self.frame.set_var(node.name, exist_tensor) else: assert isinstance(node.target, str), f"GetAttrPrim: expect `node.target` to be str but got {type(node.target)}" # in sub modules, the target is full qualified name (for example `embeddings.dropout.training`) if node.target.split('.')[-1] == 'training': # Let's just support `self.training` and ignore all other cases for now - output = IRObject(name=node.name, value=frame.get_var(node.name), is_constant=False) + if isinstance(output := self.frame.get_var(node.name), IRObject): + output.is_constant = False + else: + output = IRObject(name=node.name, value=output, is_constant=False) ir_node = IRPyFunc(SELF_GETATTR_SIG, ['training'], [output]) - FxModuleParser._set_node_meta(node, ir_node) - frame.set_var(node.name, output) + self._set_node_meta(node, ir_node) + self.frame.set_var(node.name, output) # never fold the IRPyFunc node ir_nodes.append(ir_node) else: - frame.set_var(node.name, concrete_value) + self.frame.set_var(node.name, concrete_value) + self.value_tracker.track_nodes(ir_nodes) return ir_nodes - @staticmethod - def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: + def _parse_prim_output_node(self, node: torch.fx.Node) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 - output = FxModuleParser.parse_complex(node.args[0], frame) - frame.set_var(node.name, output) + output = self._parse_complex(node.args[0]) + self.frame.set_var(node.name, output) return [] - @staticmethod - def _set_node_meta(node: torch.fx.Node, ir_node: Union[IRCell, Any]): + @classmethod + def _set_node_meta(cls, node: torch.fx.Node, ir_node: Union[IRCell, Any]): if not isinstance(ir_node, IRCell): return ir_node.op_context = node.meta.get('op_context') module_stack = node.meta.get('nn_module_stack') ir_node.module_stack = module_stack + ir_node.call_expr = node.meta.get('call_expr') comment = str(node.meta.get('frame_record', '')) if comment: ir_node.comment = comment - @staticmethod - def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: + @classmethod + def _get_qualified_name(cls, node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: if isinstance(node_target, str): assert node is not None - return FxModuleParser._get_qualified_name_of_call_method(node_target, node) + return cls._get_qualified_name_of_call_method(node_target, node) else: - return FxModuleParser._get_qualified_name_of_call_function(node_target) + return cls._get_qualified_name_of_call_function(node_target) - @staticmethod - def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str: + @classmethod + def _get_qualified_name_of_call_function(cls, node_target: Callable[..., Any]) -> str: """ The target field of call_function node must be an callable object. """ @@ -480,12 +560,12 @@ def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str # TODO(yizhu1): find a general solution assert callable(node_target) name = node_target.__name__ - module = FxModuleParser._find_module_of_method(node_target) + module = cls._find_module_of_method(node_target) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module return f'{module}.{name}' - @staticmethod - def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> str: + @classmethod + def _get_qualified_name_of_call_method(cls, node_target: str, node: torch.fx.Node) -> str: """ The target field of call_method node must be a string. """ @@ -513,8 +593,8 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> else: return f'{in_type.__module__}.{in_type.__name__}.{node_target}' - @staticmethod - def _find_module_of_method(orig_method: Callable[..., Any]) -> str: + @classmethod + def _find_module_of_method(cls, orig_method: Callable[..., Any]) -> str: if getattr(orig_method, '__name__', None) == 'apply' and isinstance(getattr(orig_method, '__self__', None), Type) \ and issubclass(orig_method.__self__, torch.autograd.Function): # for torch.autograd.Function @@ -547,18 +627,49 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str: return guess.__name__ raise RuntimeError(f'cannot find module for {orig_method}') - @staticmethod - def _is_torch_autograd_op(node: torch.fx.Node, frame: Frame, signature: str) -> bool: + def _is_torch_autograd_op(self, node: torch.fx.Node, signature: str) -> bool: """Check whether the node is of a pytorch autograd operation.""" # note: some python operations like torch.Tensor.size() doesn't return # an IRTensor, thus cannot be considered as a pytorch autograd operator. return signature.startswith('torch.') and \ - isinstance(frame.get_var(node.name), IRFullTensor) + isinstance(self.frame.get_var(node.name), IRFullTensor) - @staticmethod - def _is_custom_autograd_op(node: torch.fx.Node) -> bool: + @classmethod + def _is_custom_autograd_op(cls, node: torch.fx.Node) -> bool: node_target = node.target return callable(node_target) \ and getattr(node_target, '__name__', None) == 'apply' \ and isinstance(getattr(node_target, '__self__', None), Type) \ and issubclass(node_target.__self__, torch.autograd.Function) + + +def parse_fx_module( + module: torch.fx.GraphModule, + dummy_inputs: Dict[str, Any], + attr_savedir='./', + *, + save_content: bool = True, + constant_folding: bool = False +) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + """Parse torch.fx module into cube IR + + The overall entry to parse a torch.fx graph module + + Args: + module (torch.fx.GraphModule): the torch.fx module + dummy_inputs (Dict[str, Any]): the dummy inputs to run the module + attr_savedir (str): the directory to save the attribute content + constant_folding (bool): whether to parse the module with constant folding + + Returns: + inputs (List[IRObject]): the input IRObjects + all_ir_nodes (List[IRFwOperation]): the IRFwOperation nodes + outputs (List[IRObject]): the output IRObjects + """ + return FxModuleParser( + module, + dummy_inputs, + attr_savedir, + save_content=save_content, + constant_folding=constant_folding + ).parse() diff --git a/nnscaler/graph/parser/value_tracker.py b/nnscaler/graph/parser/value_tracker.py new file mode 100644 index 00000000..45a3cf0f --- /dev/null +++ b/nnscaler/graph/parser/value_tracker.py @@ -0,0 +1,303 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import defaultdict +from typing import Any +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir.cten import IR, IRObject, IRTensor, ValueTrack +from nnscaler.ir.operator import IRFwOperation + + +class ValueTracker: + """ + Example: + >>> vt = ValueTracker() + >>> vt.track_value(input1) + >>> vt.track_value(input2) + >>> ... + >>> vt.track_nodes([node1]) + >>> vt.track_nodes([node2]) + >>> vt.untrack_node(node2) # when node2 is folded + >>> vt.track_nodes([node3]) + >>> ... + >>> vt.complete_tracking([node1, node3, ...]) # pass all tracked nodes here + """ + def __init__(self): + # value_id -> ValueTrack + # Please note some ValueTracks may be merged together (from annotation) + # So the key can be different from the id of the ValueTrack + self._vtm: dict[int, ValueTrack] = {} + self._equiv_value_ids: dict[int, set[int]] = {} + # store removed value ids + # used to delay the removal of value tracks in deps + self._removed_value_ids: set[int] = set() + + def _add_track_value(self, value: ValueTrack): + if value.value_id not in self._vtm: + # always use the updated value track in self._vtm + self._vtm[value.value_id] = value + + if value.value_id not in self._equiv_value_ids: + self._equiv_value_ids[value.value_id] = {value.value_id} + + def track_values(self, objs: list[Any]) -> set[int]: + """ + Track the value tracks of the given objects. + Args: + objs (list[Any]): the objects to be tracked + Returns: + set[int]: the set of value ids tracked + """ + value_ids = set() + for obj in objs: + value_ids.update(self._track_value(obj)) + return value_ids + + def _track_value(self, value: Any): + for obj in IR.get_objects(value): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + self._add_track_value(dt) + yield dt.value_id + else: + assert isinstance(obj, IRObject) + self._add_track_value(obj.value_track) + yield obj.value_track.value_id + + def _update_track_value(self, obj: IRObject): + if isinstance(obj, IRTensor): + new_dim_tracks = [] + for dt in obj.dim_tracks: + new_dim_tracks.append(self._vtm[dt.value_id]) + obj.dim_tracks = new_dim_tracks + else: + assert isinstance(obj, IRObject) + obj.value_track = self._vtm[obj.value_track.value_id] + + def _update_constness(self, obj: IRObject): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + dt.is_constant = dt.is_constant and all(self._vtm[dep].is_constant for dep in dt.deps or []) + else: + assert isinstance(obj, IRObject) + obj.value_track.is_constant = obj.value_track.is_constant and all(self._vtm[dep].is_constant for dep in obj.value_track.deps or []) + + def track_nodes(self, nodes: list[IRFwOperation]): + """ + Track the value tracks of the input and output objects in the given nodes. + Here we assume the nodes are topologically sorted. + + Please note we only update the tracks of nodes in arguments. + For nodes not in arguments, their tracks are not updated. + + Args: + nodes (list[IRFwOperation]): the nodes to be tracked + """ + # collect all value tracks from nodes + if not nodes: + return + + # collect all involved value ids from nodes + node_value_ids = set() + for node in nodes: + for obj in node.iobjs(): + node_value_ids.update(self._track_value(obj)) + for obj in node.oobjs(): + node_value_ids.update(self._track_value(obj)) + + # collect extra value tracks from dimops + for node in nodes: + if isinstance(node, IRDimops): + self._track_dims(node) + + # merge equivalent value tracks together + done_value_ids = set() + for value_id in node_value_ids: + equiv_ids = self._equiv_value_ids[value_id] + + min_value_id = min(equiv_ids) + if min_value_id in done_value_ids: + continue + done_value_ids.add(min_value_id) + + # use the smallest id as the representative + rep_one = self._vtm[min_value_id] + for vid in equiv_ids: + if vid == min_value_id or self._vtm[vid] is rep_one: + continue + # TODO: how we merge dependencies? + # current we take union (Union may be too strict) + if rep_one.deps is None: + rep_one.deps = self._vtm[vid].deps + elif self._vtm[vid].deps is not None: + # deps can still have duplicates here + # because merging of the rest value tracks haven't been done yet + # NOTE: + # 1. this duplication is temporary, + # Duplicated value ids will be removed when we touch the same value track again + # in future track_nodes call. + # 2. duplication is not harmful for correctness + rep_one.deps = list( + set(rep_one.deps) + .union(self._vtm[vid].deps) + .difference(self._removed_value_ids) + ) + self._vtm[vid] = rep_one + + self._propagate_tracks(nodes) + + def untrack_node(self, node: IRFwOperation): + """ + Untrack the value tracks of output objects in the given node. + This function is used when we fold a node from the graph. + + Args: + node (IRFwOperation): the node to be untracked + """ + input_value_ids = set() + for obj in node.iobjs(): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + input_value_ids.add(dt.value_id) + else: + assert isinstance(obj, IRObject) + input_value_ids.add(obj.value_track.value_id) + + for obj in node.oobjs(): + # we can only remove value tracks that are not used by inputs + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + if dt.value_id not in input_value_ids: + self._removed_value_ids.add(dt.value_id) + else: + assert isinstance(obj, IRObject) + if obj.value_track.value_id not in input_value_ids: + self._removed_value_ids.add(obj.value_track.value_id) + + def complete_tracking(self, nodes: list[IRFwOperation]): + """ + Complete the tracking process. + Should be called after all nodes are tracked. + """ + # remove all removed value ids for vtm + # note we don't remove them from equivalence classes + for removed_id in self._removed_value_ids: + if self._vtm[removed_id].value_id == removed_id \ + and (new_equiv_cls := self._equiv_value_ids[removed_id].difference(self._removed_value_ids)): + # change the representative value id of this equivalence class + # NOTE: + # In current usage, code should not reach here. + # As we remove value tracks only for constant irobjects, + # and all equivalent value tracks should be removed together. + self._vtm[removed_id].value_id = min(new_equiv_cls) + self._vtm.pop(removed_id, None) + + # replace dependencies with their representative value tracks + # which can introduce some duplicates + # So we use `set` to further dedup dependencies + for vt in self._vtm.values(): + if vt.deps is not None: + vt.deps = list(set( + self._vtm[d].value_id for d in vt.deps + if d not in self._removed_value_ids + )) + + self._propagate_tracks(nodes) + + def _propagate_tracks(self, nodes: list[IRFwOperation]): + """ + Update value tracks and constantness information of the input and output objects + in the given nodes. + """ + # propagate the merged value tracks back to nodes + for node in nodes: + for obj in node.iobjs(): + self._update_track_value(obj) + for obj in node.oobjs(): + self._update_track_value(obj) + + # propagate the constantness information back to nodes + for node in nodes: + for obj in node.iobjs(): + self._update_constness(obj) + for obj in node.oobjs(): + self._update_constness(obj) + + def _track_dims(self, node: IRDimops): + """ + Track the dimension values of output tensors according to input tensors. + This function should be called after shape inference. + """ + # align the dim_ids of output with inputs + # not-hidden-dimension means the identifier is all for this dimension + # for example, in `l (2 h) m`, + # l and m are not-hidden-dimension identifiers, h is hidden-dimension identifier + # + # If the annotation is `l (2 h) m -> l h (m 2 h)` + # We will get the following relations (nhd->not-hidden-dimension, hd->hidden-dimension): + # 1. for `l`: `input.dim_tracks[0] is output.dim_tracks[0]` # both nhd, equality + # 2. for `m`: `input.dim_tracks[2].value_id in output.dim_tracks[2].deps` # one is hd, depencency + # 3. for `h`: `input.dim_tracks[1].value_id in output.dim_tracks[2].deps` # one is hd, depencency + # `input.dim_tracks[1] in output.dim_tracks[1].deps` # one is hd, depencency + + # TODO: We can handle more complex cases in the future if needed. + # In current version, we don't handle the case like + # 1. `(2 h) -> (2 h)`: input.dim_tracks[0] should be equal to output.dim_tracks[0]? (2 can be a runtime number, so we cannot be sure) + # 2. `(l m) -> (l m)`: input.dim_tracks[0] should be equal to output.dim_tracks[0]. + + # ivt => identifier_value_track_map + hidden_ivt: dict[str, list[ValueTrack]] = defaultdict(list) + non_hidden_ivt: dict[str, list[ValueTrack]] = defaultdict(list) + + for i, input_tensor in enumerate(node.inputs()): + if not isinstance(input_tensor, IRTensor) or node.ianno(i).ignore: + continue + + ianno = node.ianno(i) + for dim, dim_track in zip(ianno.dims, input_tensor.dim_tracks): + identifiers = [i for i in dim.identifiers if not str.isdecimal(i)] + if len(identifiers) == 1 and len(dim.identifiers) == 1: + # not hidden dimension + non_hidden_ivt[identifiers[0]].append(dim_track) + else: + for iden in identifiers: + hidden_ivt[iden].append(dim_track) + + for iden, iden_infos in non_hidden_ivt.items(): + # merge all not-hidden-dimension infos together + first = iden_infos[0] + for info in iden_infos[1:]: + self._add_equiv_value(first.value_id, info.value_id) + + for i, output_tensor in enumerate(node.outputs()): + if not isinstance(output_tensor, IRTensor) or node.oanno(i).ignore: + continue + + oanno = node.oanno(i) + for dim, dim_track in zip(oanno.dims, output_tensor.dim_tracks): + # find the first identifier that is not a number + identifiers = [i for i in dim.identifiers if not str.isdecimal(i)] + if len(identifiers) == 1 and len(dim.identifiers) == 1: + ident = identifiers[0] + if ident in non_hidden_ivt: + first = non_hidden_ivt[ident][0] + self._add_equiv_value(first.value_id, dim_track.value_id) + else: + # this identifier is used together with other identifiers + # so it is just a dependency. + dim_track.deps = dim_track.deps or [] + dim_track.deps.extend(v.value_id for v in hidden_ivt[ident]) + dim_track.deps = list(set(dim_track.deps)) # deduplicate + else: + dim_track.deps = dim_track.deps or [] + for ident in identifiers: + if ident in hidden_ivt: + dim_track.deps.extend(v.value_id for v in hidden_ivt[ident]) + if ident in non_hidden_ivt: + first = non_hidden_ivt[ident][0] + dim_track.deps.append(first.value_id) + + def _add_equiv_value(self, value_id, other_value_id): + self._equiv_value_ids[value_id].update(self._equiv_value_ids[other_value_id]) + for vid in self._equiv_value_ids[other_value_id]: + self._equiv_value_ids[vid] = self._equiv_value_ids[value_id] diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index 9ec92c27..85f7c8f6 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -16,6 +16,7 @@ from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType from typing import Any, Dict, Optional, Set, Tuple, Type, List, Callable, Union, Literal from contextlib import contextmanager +import weakref import torch from torch._C import ScriptObject @@ -89,6 +90,7 @@ def __init__(self, strategy, record_frames = False): self.scope = Scope("", None) self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} + self.call_expr_stack = [] self.strategy = TRACE_STRATEGY[strategy](self) self.record_frames = record_frames self.patcher = FunctionPatcher() @@ -100,6 +102,16 @@ def __init__(self, strategy, record_frames = False): self.need_revert_functions = set() self.need_revert_wrapped_functions = set() + # Save functions decorated with functools.cache/lru_cache + # We need to clear up caches after tracing to avoid memory leak or tracing error. + # TODO: currently only functions/methods are tracked. + # Cached Properties (via @property @cache or @cached_property) are not tracked + # The reason is: + # 1. Cached properties is rare to cause problem as they have no arguments (no ConcrateProxy object will pass to it) + # 2. We need to patch all getattr (`a.b``) to support this scenario, which is too expensive + # Currently only function calls (`f(a,b)`) are patched and tracked. (See `operator_patcher`) + self.cached_function = weakref.WeakSet() + self.temp_call_origin = False def add_need_revert_function(self, func, wrapped_func): @@ -109,6 +121,49 @@ def add_need_revert_function(self, func, wrapped_func): def need_revert(self, func): return func in self.need_revert_functions or func in self.need_revert_wrapped_functions + @classmethod + def _is_cache_wrapped_function(cls, func): + return callable(func) \ + and hasattr(func, 'cache_clear') \ + and hasattr(func, 'cache_info') \ + and hasattr(func, 'cache_parameters') \ + and hasattr(func, '__wrapped__') \ + and callable(func.__wrapped__) + + def _track_cache_wrapped_function(self, func): + while func is not None: + if self._is_cache_wrapped_function(func): + self.cached_function.add(func) + break + func = getattr(func, '__wrapped__', None) + + @classmethod + def _is_torch_compile_function(cls, func): + return callable(func) \ + and hasattr(func, '__wrapped__') \ + and hasattr(func, '_torchdynamo_orig_callable') + + def _check_torch_compile_function(self, func): + outmost_func = func + while func is not None: + if self._is_torch_compile_function(func): + # If func is registered, run this func will be in a reverted context. + if not self.need_revert(outmost_func): + raise RuntimeError( + f"@torch.compile decorated function `{outmost_func.__module__}.{outmost_func.__qualname__}` is not registered. " + f"You must register it to avoid tracing failure." + ) + break + func = getattr(func, '__wrapped__', None) + + def on_function_call(self, func, expr): + self.call_expr_stack.append(expr) + self._track_cache_wrapped_function(func) + self._check_torch_compile_function(func) + + def on_function_call_end(self): + self.call_expr_stack.pop() + @contextmanager def do_temp_call_origin(self): temp_call_origin = self.temp_call_origin @@ -159,6 +214,15 @@ def create_node(self, kind : str, target : Target, else: node.meta['nn_module_stack'] = collections.OrderedDict() + if self.call_expr_stack: + last_call_expr = None + for item in reversed(self.call_expr_stack): + # if not found, leave last_call_expr as None + if item: + last_call_expr = item + break + node.meta['call_expr'] = last_call_expr + def unwrap_nested_proxy(proxy: ep.ConcreteProxy): return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) @@ -397,7 +461,6 @@ def proxy_placeholder(name: str): return self.create_proxy('placeholder', name, default_arg, {}) args.extend(proxy_placeholder(names) for names in arg_names) - if hasattr(co, 'co_kwonlyargcount') and ( co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF): # TODO: type annotations for *args and **kwargs @@ -407,6 +470,13 @@ def proxy_placeholder(name: str): more_args = proxy_placeholder(name) if co.co_flags & inspect.CO_VARKEYWORDS: name = '**' + next(names_iter) + if name not in concrete_args: + # auto pack the additional kwargs + kwargs_val = {} + for cc_name in concrete_args: + if cc_name not in arg_names and not cc_name.startswith('*'): + kwargs_val[cc_name] = concrete_args[cc_name] + concrete_args[name] = kwargs_val default_args[name] = {} kwargs = proxy_placeholder(name) @@ -692,6 +762,10 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): {}, type_expr=fn.__annotations__.get('return', None), node_result=node_result) finally: _retain_weight_consistency(self.root) + # clean up caches + for func in self.cached_function: + if func is not None: + func.cache_clear() return self.graph diff --git a/nnscaler/graph/tracer/metadata.py b/nnscaler/graph/tracer/metadata.py index f6f898a2..75f4de9a 100644 --- a/nnscaler/graph/tracer/metadata.py +++ b/nnscaler/graph/tracer/metadata.py @@ -8,6 +8,7 @@ from torch.fx.node import Node from . import pytree_utils +from nnscaler.utils import get_dynamic DICT_KEYS_TYPE = type({}.keys()) DICT_VALUES_TYPE= type({}.values()) @@ -95,6 +96,9 @@ class TensorMetadata(NamedTuple): is_quantized : bool qparams: Dict[str, Any] + # all dynamic dimensions in shape + dynamic_dims: set[int] + def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ @@ -134,7 +138,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] - return TensorMetadata(shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + return TensorMetadata(shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams, get_dynamic(result)) def extract_metadata(results: Any, node: Node): diff --git a/nnscaler/graph/tracer/operator_patcher.py b/nnscaler/graph/tracer/operator_patcher.py index 235da0b5..783d9627 100644 --- a/nnscaler/graph/tracer/operator_patcher.py +++ b/nnscaler/graph/tracer/operator_patcher.py @@ -171,7 +171,11 @@ def visit_Call(self, node: ast.Call): self.modified = True return self.generic_visit(ast.Call( func=ast.Name(id=self.proxy_call_name, ctx=ast.Load()), - args=[node.func, *node.args], + args=[ + node.func, + ast.fix_missing_locations(ast.Constant(value=ast.unparse(node))), + *node.args + ], keywords=node.keywords, )) else: @@ -311,7 +315,7 @@ def patch_func_helper(self, func): # use func.__code__.co_filename to make the new function easily debuggable. compile(new_tree, func_inner.__code__.co_filename, 'exec'), { - self.proxy_call_name: OperatorPatcherContext.patch_run, + self.proxy_call_name: OperatorPatcherContext._patch_run, **func_inner.__globals__, **closure_dict, }, @@ -346,9 +350,19 @@ def __exit__(self, exc_type, exc_value, tb): return exc_type is None @staticmethod - def patch_run(func, *args, **kwargs): + def _patch_run(func, expr, *args, **kwargs): assert OperatorPatcherContext.ctx_tracer is not None assert OperatorPatcherContext.ctx_patcher is not None with wrap_utils.do_temp_call_origin(): + OperatorPatcherContext.ctx_tracer.on_function_call(func, expr) new_func = OperatorPatcherContext.ctx_patcher.patch_func_or_module(func) - return new_func(*args, **kwargs) + + ret = new_func(*args, **kwargs) + + with wrap_utils.do_temp_call_origin(): + OperatorPatcherContext.ctx_tracer.on_function_call_end() + return ret + + @staticmethod + def patch_run(func, *args, **kwargs): + return OperatorPatcherContext._patch_run(func, '', *args, **kwargs) diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index f359fec3..42a05eaf 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -18,18 +18,19 @@ from __future__ import annotations +from dataclasses import dataclass, field from functools import lru_cache -from typing import ClassVar, List, Tuple, Union, Optional, Any, Dict, Callable +from typing import ClassVar, Iterable, List, Set, Tuple, Type, Union, Optional, Any, Dict, Callable from collections import OrderedDict import copy import torch from nnscaler.ir.unique import IDGenerator from nnscaler.ir.dtype import DTypeInfo -from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE +from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE, load_type, get_dynamic -NestedVarOrStatic = Any +NestedVarOrStatic = Union[Any, 'IRObject', List['IRObject'], 'IRTensor'] class IRCell: @@ -77,9 +78,30 @@ def __init__(self, self._comment: Optional[str] = None # the module stack that preserves the hierarchy information self._module_stack: Optional[OrderedDict[str, Any]] = None + # the original call expression + # Note: + # 1. some cells may not have call expression if the cell is not from function call (e.g., __getitem__) + # 2. call_expr can be inaccurate when function call happens + # inside pytorch official module (like in torch.nn namespace) forward, + # (e.g., F.linear inside nn.Linear), in this case, call_expr will be module call expression. + self._call_expr: Optional[str] = None # the operation context information self._op_context: Optional[Dict[str, Any]] = None + # function to be called before the op is executed + # which will be inserted in the runtime code before the op call. + # op's inputs will be passed to the hook. + # The signature will be like + # def pre_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None: + self._pre_hook: Optional[Callable[..., None]] = None + # function to be called after the op is executed + # which will be inserted in the runtime code after the op call. + # op's inputs and outputs will be passed to the hook. + # the signature will be like + # def post_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any], output: Any) -> None: + self._post_hook: Optional[Callable[..., None]] = None + self._hook_meta: Any = None + @property def cid(self) -> int: """ @@ -377,6 +399,22 @@ def comment(self, info: str): @property def module_stack(self) -> Optional[OrderedDict[str, Any]]: + """ + Get the module stack, which preserves the hierarchy information + of modules this cell belongs to. + For example, if this cell is from model.submodule.layers.0.block0.conv2d, + then the module stack will be: + OrderedDict([ + ('model.submodule', ), + ('model.submodule.layers.0.block0', ), + ('model.submodule.layers.0.block0.conv2d', ), + ]) + + Please note + 1. Root module (e.g., model) is not included in the stack. + 2. Only modules that have `.forward` function are included in the stack, + so in above example, `torch.nn.ModuleList` is not included. + """ return self._module_stack @module_stack.setter @@ -386,6 +424,97 @@ def module_stack(self, stack: OrderedDict[str, Any]): """ self._module_stack = stack + @property + def module_class_chain(self) -> list[type[torch.nn.Module]]: + """ + Get the module chains the IRCell belongs to. + If module stack is None or empty, return []. + """ + if not self._module_stack: + return [] + return list(self._module_stack.values()) + + @property + def fqn(self) -> str: + """ + Get the fully qualified module name the IRCell belongs to. + If module stack is None or empty, return ''. + """ + if not self._module_stack: + return '' + return list(self._module_stack.keys())[-1] + + def get_module_fqn( + self, module_class: Type[torch.nn.Module], + *, + include_subclass: bool = False + ) -> str: + """ + Get the first fully qualified module name for the given module class + in the module stack. If not found, return ''. + + Args: + module_class (Type[torch.nn.Module]): the module class to find + include_subclass (bool): whether to include subclass of the module_class + + Returns: + str: the fully qualified module name + """ + if not self._module_stack: + return '' + for fqn, mod_cls in self._module_stack.items(): + if mod_cls == module_class or ( + include_subclass and issubclass(mod_cls, module_class) + ): + return fqn + return '' + + @property + def call_expr(self) -> Optional[str]: + return self._call_expr + + @call_expr.setter + def call_expr(self, expr: Optional[str]): + self._call_expr = expr + + @property + def fn(self) -> Optional[Callable]: + """ + Get the function of this cell based on its signature. + Return None if the function cannot be loaded. (e.g. virtual ops like `self_getattr`) + + Returns: + Callable: the function object + """ + try: + return load_type(self.signature) + except Exception as e: + return None + + @property + def pre_hook(self) -> Optional[Callable[..., None]]: + return self._pre_hook + + @pre_hook.setter + def pre_hook(self, hook: Optional[Callable[..., None]]): + self._pre_hook = hook + + @property + def post_hook(self) -> Optional[Callable[..., None]]: + return self._post_hook + + @post_hook.setter + def post_hook(self, hook: Optional[Callable[..., None]]): + self._post_hook = hook + + @property + def hook_meta(self) -> Any: + return self._hook_meta + + @hook_meta.setter + def hook_meta(self, meta: Any): + self._hook_meta = meta + @property def op_context(self) -> Optional[Dict[str, Any]]: return self._op_context @@ -459,14 +588,127 @@ def modify_objects_of_complex(val: Any, modifier: Callable[['IRObject'], 'IRObje return val +@dataclass +class ValueTrack: + """ + Track the value of an IRObject or a dimension of IRTensor. + Currently only implemented for dimension via IRDimops annotation. + + Example: + `l (2 h) m -> l h (2 m)`: + Input Tensor Tracks (2/5 is external dependencies for illustration): + dim 0: ValueTrack(value_id=10, dependencies=[]) # l + dim 1: ValueTrack(value_id=20, dependencies=[]) # (2 h) + dim 2: ValueTrack(value_id=30, dependencies=[2, 5]) # m + Then we can infer the output Tensor Tracks: + Output Tensor Tracks: + dim 0: ValueTrack(value_id=10, dependencies=[]) # reuse input dim 0, since they are the same + dim 1: ValueTrack(value_id=40, dependencies=[20]) # it depends on input dim 1: (2 h) + dim 2: ValueTrack(value_id=50, dependencies=[30]) # it depends on input dim 2: m + """ + value_id: int = field(default_factory=IDGenerator().gen_value_id) + # By default, we consider the value is constant + # unless it is set to not constant + # via mark_dynamic or it is from input or explicitly set in function.py + is_constant: bool = True + # None: unknown dependencies + # []: no dependencies + deps: Optional[list[int]] = None + + def with_no_dep(self) -> 'ValueTrack': + """ + Initialize this ValueTrack with no dependencies. + """ + self.deps = [] + return self + + def add_dep(self, dep: Union[Any, 'ValueTrack', 'IRObject']) -> 'ValueTrack': + """ + Initialize or add a dependency to the ValueTrack. + If dep is not IRObject or ValueTrack, do nothing. + """ + if self.deps is None: + self.deps = [] + + if not isinstance(dep, (ValueTrack, IRObject)): + return self + + if isinstance(dep, IRTensor): + raise TypeError("Cannot directly add IRTensor as dependency.") + + dep: ValueTrack = dep.value_track if isinstance(dep, IRObject) else dep + dep_value_id = dep.value_id + if dep_value_id not in self.deps: + self.deps.append(dep_value_id) + self.is_constant = self.is_constant and dep.is_constant + + return self + + def merge(self, other: ValueTrack) -> 'ValueTrack': + """ + Merge another ValueTrack into this one. + The merged ValueTrack will have dependencies from both ValueTracks. + """ + if self.deps is None: + self.deps = other.deps + else: + self.deps.extend(other.deps or []) + + if self.deps is not None: + self.deps = list(set(self.deps)) + + self.is_constant = self.is_constant and other.is_constant + return self + + @classmethod + def new(cls, deps: Iterable[Union[Any, 'ValueTrack', 'IRObject']], is_constant: Optional[bool] = None) -> 'ValueTrack': + vt = cls() + if is_constant is not None: + vt.is_constant = is_constant + vt.deps = [] + for dep in deps: + vt.add_dep(dep) + return vt + + def mark_as_input(self) -> 'ValueTrack': + """ + Mark this ValueTrack as graph input, which should be not constant and have no dependencies. + """ + self.is_constant = False + self.deps = [] + return self + + +_missing_value = object() + class IRObject: """ IRObject serves as general data of IRGraph edge + + There are two special IRObject for lazy evaluation: + 1. IRObject.missing: a singleton object to represent missing object + It is used to tell parser that we don't know the real object yet. + The parser is supposed to create a new IRObject to replace it. + For example, all custom ops will have missing outputs.It relies on parser to set them. + 2. IRObject(..., value=missing_value, ...): an object with unknown value + It is used to tell parser that we don't know the real value yet. + The parser is supposed to set the value. + We have this because we want ops to pass out `value_track` even when the value is unknown. + For example, `Item()` op in `function.py` will create such object. """ # will be set after class definition missing: ClassVar['IRObject'] = None - - def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: Optional[None] = None, is_constant: bool = True): + missing_value: ClassVar[object] = _missing_value + + def __init__( + self, + name: Optional[str] = None, + tid: Optional[int] = None, + value: Any = _missing_value, + is_constant: Optional[bool] = None, + *, + value_track: Optional[ValueTrack] = None, + ) -> None: """ Args: name (str): object name @@ -479,13 +721,19 @@ def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: 2. val is model input, or is the result of a non-torch operation on another not constant IRObject Please note is_constant flag is only used in parser, so after parser, you can totally ignore this flag. + We keep this flag in IRObject for backward compatibility. + If both is_constant and value_track are provided, + `value_track.is_constant` will be overrided by this flag. + value_track (ValueTrack): the value track info of this object """ self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() self.name: str = name if name else 'obj' self._cell: Optional[IRCell] = None self._is_attr: bool = False - self._value: Optional[Any] = value - self._is_constant: bool = is_constant + self._value: Any = value + self._value_track: ValueTrack = value_track or ValueTrack() + if is_constant is not None: + self._value_track.is_constant = is_constant def __hash__(self) -> int: return self._id @@ -538,13 +786,30 @@ def value(self) -> Any: """Get example value""" return self._value + @value.setter + def value(self, val: Any): + self._value = val + + def is_value_missing(self) -> bool: + """Check if the value is missing""" + return self._value is IRObject.missing_value + + @property + def value_track(self) -> ValueTrack: + """Get value track info""" + return self._value_track + + @value_track.setter + def value_track(self, val: ValueTrack): + self._value_track = val + @property def is_constant(self) -> bool: - return self._is_constant + return self._value_track.is_constant @is_constant.setter def is_constant(self, val: bool): - self._is_constant = val + self._value_track.is_constant = val def __eq__(self, obj) -> bool: if not isinstance(obj, IRObject): @@ -555,7 +820,7 @@ def __copy__(self): """Copy this object but remove the cell information""" if self is IRObject.missing: # missing object is singleton return IRObject.missing - return IRObject(self.name, self._id, self._value, self._is_constant) + return IRObject(self.name, self._id, self._value, self.is_constant, value_track=self._value_track) def as_attr(self): """ @@ -651,7 +916,10 @@ def _inner(obj) -> Tuple[Any, bool]: new_ir_tensor._value = obj.value return new_ir_tensor, True else: - return IRObject(name, value=obj.value, is_constant=is_constant), False + return IRObject( + name, value=obj.value, + is_constant=is_constant, value_track=obj.value_track + ), False if isinstance(obj, tensor_types): if requires_grad is None: @@ -667,6 +935,10 @@ def _inner(obj) -> Tuple[Any, bool]: dtype=obj.dtype, requires_grad=rg, ) + + for dyn_idx in get_dynamic(obj): + tensor.dim_tracks[dyn_idx].is_constant = False + if tosub: tensor = tensor.tosub() tensor._value = obj # is required in SemanticModel.forward @@ -907,11 +1179,12 @@ class IRTensor(IRObject): You can get the original shape with `origin_shape` property. """ def __init__(self, shape=None, name='tensor', dtype=None, tid=None, *, - is_attr=False, is_grad=False, requires_grad=False, persistent=False + is_attr=False, is_grad=False, requires_grad=False, persistent=False, ): super().__init__(name, tid, is_constant=False) self._is_scalar_tensor: bool = True - self._shape: Tuple[int] = () + self._shape: Tuple[int, ...] = () + self._dim_tracks: Tuple[ValueTrack, ...] = () self._dtype: Optional[torch.dtype] = None # tensor gradient self._is_grad: bool = False @@ -946,7 +1219,9 @@ def _update( if shape is not None: self._is_scalar_tensor = not shape # will always convert scalar tensor to 1-d tensor - self._shape: Tuple[int] = (1,) if not shape else tuple(shape) + self._shape: Tuple[int, ...] = (1,) if not shape else tuple(shape) + # reset dim tracks + self._dim_tracks = tuple(ValueTrack() for _ in self._shape) if name is not None or self.name is None: self.name = name if dtype is not None: @@ -973,7 +1248,7 @@ def dtype(self) -> Optional[torch.dtype]: def is_param(self) -> bool: """! - Check if the tensor is parameter + Check if the tensor is parameter (with requires_grad = True). @return is_param boolean: True if is parameter. """ @@ -1039,12 +1314,55 @@ def origin_shape(self) -> Tuple[int]: return self.shape if not self.is_scalar_tensor() else () @property - def shape(self) -> Tuple[int]: + def shape(self) -> Tuple[int, ...]: # NOTE: here return a tuple but not a real torch.Size obj may have risk, here is an example: # (torch.Size + tuple -> torch.Size) will change to (tuple + tuple -> tuple), is ok. # (torch.Size + list -> torch.Size) will change to (tuple + list -> error), is wrong. return self._shape + @property + def dim_tracks(self) -> Tuple[ValueTrack, ...]: + """ + Get the track of each dimension + """ + return self._dim_tracks + + @dim_tracks.setter + def dim_tracks(self, val: Tuple[Optional[ValueTrack], ...]): + """ + Set the unique id of each dimension + """ + if not isinstance(val, (list, tuple)): + raise ValueError("dim_tracks must be a list or tuple") + if len(val) != len(self._shape): + raise ValueError("dim_tracks length must be equal to shape length") + # None means starting a new dim track + self._dim_tracks = tuple(v if v is not None else ValueTrack() for v in val) + + def set_dim_track(self, dim: int, track: ValueTrack): + """ + Set the track of a specific dimension + """ + if dim < 0 or dim >= len(self._shape): + raise IndexError("dim out of range") + dim_tracks = list(self._dim_tracks) + dim_tracks[dim] = track + self._dim_tracks = tuple(dim_tracks) + + def dim_constant(self, dim: int) -> bool: + """ + Check if a dim is constant + """ + if dim < 0 or dim >= len(self._shape): + raise IndexError("dim out of range") + return self._dim_tracks[dim].is_constant + + def dims_constant(self) -> bool: + """ + Check if all dims are constant + """ + return all(track.is_constant for track in self._dim_tracks) + def nelement(self) -> int: """ Get total number of element in the tensor. diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index 6546720f..f3f2c9eb 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -27,10 +27,10 @@ 3) gradient of parameters """ -from typing import List, Optional, Union, Tuple, NewType, Dict, Any +from typing import List, Optional, Set, Union, Tuple, NewType, Dict, Any import torch -from nnscaler.ir.cten import IRTensor +from nnscaler.ir.cten import IRTensor, ValueTrack StartEnd = NewType('[start:end)', Tuple[int, int]) IdxChunk = NewType('(index, chunks)', Tuple[int, int]) @@ -260,14 +260,17 @@ class IRFullTensor(IRTensor): """ def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=None, *, - is_attr=False, is_grad=False, persistent=False, is_loss=False + is_attr=False, is_grad=False, persistent=False, is_loss=False, ): self._is_loss: bool = False # record all created sub_tensors self._subtensors : Dict[(ValueMap, IndexMap), int] = dict() self._grad: Optional[IRFullTensor] = None - super().__init__(shape, name, dtype, requires_grad=requires_grad, is_attr=is_attr, is_grad=is_grad, persistent=persistent) + super().__init__( + shape, name, dtype, requires_grad=requires_grad, + is_attr=is_attr, is_grad=is_grad, persistent=persistent, + ) self._update( is_loss=is_loss, ) @@ -334,6 +337,7 @@ def like(self): self.origin_shape, self.name, self._requires_grad, self._dtype, is_loss=self._is_loss ) + tensor.dim_tracks = self.dim_tracks return tensor def like_grad(self): @@ -346,6 +350,7 @@ def like_grad(self): self.origin_shape, 'g' + self.name, requires_grad=False, dtype=self.dtype ).as_grad(self._is_attr) + grad.dim_tracks = self.dim_tracks return grad @property @@ -363,6 +368,7 @@ def grad(self, val: Optional[IRTensor]): assert self._requires_grad, f"Cannot assign {val} to no grad-required tensor" assert val.origin_shape == self.origin_shape assert val.is_attr() == self.is_attr() + val.dim_tracks = self.dim_tracks # TODO: we should check the grad-required here # it is very common in current code that we assign None to grad # so currently it is impossible to check the grad-required here @@ -507,6 +513,7 @@ def __init__(self, ftensor: IRFullTensor, del self._is_grad del self._requires_grad del self._persistent + del self._dim_tracks self.cell = None # the index from full_tensor @@ -556,7 +563,7 @@ def ndims(self) -> int: def as_attr(self): raise RuntimeError("as_attr is not allowed for SubTensor") - def splitdims(self) -> Tuple[int]: + def splitdims(self) -> Tuple[int, ...]: """! Get partitioned dimensions @@ -677,6 +684,10 @@ def dtype(self) -> Optional[torch.dtype]: """Tensor data type""" return self.parent.dtype + @property + def dim_tracks(self) -> Tuple[ValueTrack, ...]: + return self.parent.dim_tracks + @IRTensor.shape.setter def shape(self, val: Tuple[int]): # TODO: remove this function diff --git a/nnscaler/ir/unique.py b/nnscaler/ir/unique.py index dde3ceb2..72338ee5 100644 --- a/nnscaler/ir/unique.py +++ b/nnscaler/ir/unique.py @@ -5,14 +5,14 @@ class IDGenerator: """ Tensor / Operator manager. To guarantee that each IRTensor / IROperator id is unique and progressively increases. - + This class is designed in singleton pattern. """ class __IDGenerator: def __init__(self): - self._tensor_id = 0 self._cell_id = 0 + self._value_id = 0 instance = None @@ -31,13 +31,19 @@ def gen_cell_id(self): self.instance._cell_id += 1 return self.instance._cell_id + def gen_value_id(self): + self.instance._value_id += 1 + return self.instance._value_id + def get_states(self): - return (self._tensor_id, self._cell_id) - + return (self._tensor_id, self._cell_id, self._value_id) + def load_states(self, states: tuple): IDGenerator.instance._tensor_id = states[0] IDGenerator.instance._cell_id = states[1] + IDGenerator.instance._value_id = states[2] def clear(self): self.instance._tensor_id = 0 self.instance._cell_id = 0 + self.instance._value_id = 0 diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index b728269b..6dd57aaf 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -14,6 +14,7 @@ import logging import copy import os +from collections import OrderedDict, defaultdict import torch import torch.distributed @@ -40,15 +41,23 @@ from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.unique import IDGenerator -from nnscaler.runtime.adapter.reducer import Reducer +from nnscaler.runtime.adapter.reducer import Bucket, Reducer from nnscaler.runtime.device import DeviceGroup from nnscaler.runtime.gnorm import calcuate_gnorm, clip_grads -from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState +from nnscaler.runtime.module import AttrMeta, Zero3AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState, dedup_attrs from nnscaler.flags import CompileFlag, RuntimeFlag import nnscaler.policies as policies from nnscaler.program import disable_global_graph -from nnscaler.utils import get_member_by_name, setup_stride_broadcast_group, get_shared_params +from nnscaler.utils import ( + get_member_by_name, + load_type, + set_member_by_name, + setup_stride_broadcast_group, + get_shared_params, + OptStateDict, + copy_dynamic +) logger = logging.getLogger(__name__) @@ -78,11 +87,17 @@ class ComputeConfig: # how to execute the functions during trace trace_strategy: str = 'cuda_run_cpu_offload' - use_zero: bool = False + # Only support 0/1/3 for now + # If you set use_zero to 2, ZeRO stage 3 will be used internally. + # 0: no zero + # 1: ZeRO stage 1 + # 2: ZeRO stage 3 + # 3: ZeRO stage 3 + use_zero: int = 0 zero_ngroups: int = 1 # whether to use reduce scatter for zero # Please note - # 1. this only works when `use_zero` is True and `zero_ngroups` is 1. + # 1. this only works when `use_zero` is not 0 and `zero_ngroups` is 1. # 2. In some cases, it can introduce parity issue. So use it with caution. zero_use_reduce_scatter: bool = False @@ -149,16 +164,34 @@ def __post_init__(self): raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be > 0") if self.runtime_ngpus % self.plan_ngpus != 0: raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be a multiple of plan_ngpus {self.plan_ngpus}") + + if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: + raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") + + # for backward compatibility, convert bool to int + super().__setattr__('use_zero', int(self.use_zero)) + if self.use_zero not in (0, 1, 2, 3): + raise ValueError(f"use_zero {self.use_zero} must be 0, 1, 2 or 3.") + if self.use_zero == 2: + logger.warning("use_zero=2 is not supported. ZeRO stage 3 will be used instead.") + super().__setattr__('use_zero', 3) + + num_scale_units = self.runtime_ngpus // self.plan_ngpus + if self.use_zero: + if num_scale_units % self.zero_ngroups != 0: + raise ValueError(f"zero_ngroups {self.zero_ngroups} must be a divisor of runtime_ngpus/plan_ngpus {num_scale_units}.") + if num_scale_units == self.zero_ngroups: + logger.warning(f"zero_ngroups {self.zero_ngroups} equals to runtime_ngpus/plan_ngpus {num_scale_units}. Zero optimization is disabled.") + super().__setattr__('use_zero', 0) + if self.use_zero and self.zero_ngroups <= 0: raise ValueError(f"zero_ngroups {self.zero_ngroups} must be > 0") + if not self.use_zero and self.zero_ngroups != 1: logger.warning(f"use_zero is False, but zero_ngroups is {self.zero_ngroups}. Will set zero_ngroups to 1.") # have to use __setattr__ for frozen dataclass super().__setattr__('zero_ngroups', 1) - if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: - raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") - # TODO: Please note in current implementation of Bucket, # zero_use_reduce_scatter still works when zero_ngroups > 1 in sync mode # Let's hide this feature for now for consistency. @@ -215,7 +248,11 @@ def module_dedup_group_size(self) -> int: """ Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. """ - return self.plan_ngpus + if self.use_zero > 1: + # for zero3 + return self.runtime_ngpus // self.zero_ngroups + else: + return self.plan_ngpus @property def optimizer_dedup_group_size(self) -> int: @@ -337,19 +374,29 @@ def _runtime_flags(**kwargs): return _flags(RuntimeFlag, **kwargs) -def _to_cpu(val: Any): - """Complex to CPU""" +def _to_cpu(val: Any, requires_grad: Optional[bool] = None) -> Any: + """ + Complex to CPU + Recursively move the input to CPU. + Args: + val (Any): the input value + requires_grad (Optional[bool]): whether the returned tensor requires grad. + If it is None, will keep the same as the input tensor. + """ if isinstance(val, tuple): - return tuple(_to_cpu(t) for t in val) + return tuple(_to_cpu(t, requires_grad) for t in val) if isinstance(val, list): - return list(_to_cpu(t) for t in val) + return list(_to_cpu(t, requires_grad) for t in val) if isinstance(val, dict): - return {_to_cpu(key):_to_cpu(val) for key, val in val.items()} + return {_to_cpu(key, requires_grad):_to_cpu(val, requires_grad) for key, val in val.items()} if isinstance(val, set): - return {_to_cpu(t) for t in val} + return {_to_cpu(t, requires_grad) for t in val} if isinstance(val, torch.Tensor): - requires_grad = val.is_floating_point() or val.is_complex() - return val.detach().clone().cpu().requires_grad_(requires_grad) + if requires_grad is None: + requires_grad = val.requires_grad + else: + requires_grad = requires_grad and (val.is_floating_point() or val.is_complex()) + return copy_dynamic(val, val.detach().clone().cpu().requires_grad_(requires_grad)) return val @@ -379,7 +426,7 @@ def _add_gen_savedir_to_syspath(gen_savedir: str) -> Path: gen_savedir = Path(gen_savedir).resolve() gen_savedir.mkdir(parents=True, exist_ok=True) if str(gen_savedir) not in sys.path: - sys.path.append(str(gen_savedir)) + sys.path.insert(0, str(gen_savedir)) return gen_savedir @@ -556,6 +603,10 @@ def _prepare_and_check_reusable( if reuse == ReuseType.MATCH or reuse == ReuseType.MOO: # check if the module is already generated expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] + expected_output_files.extend([ + outdir / ParallelModule.ATTR_META_FILE_TEMPLATE.format(rank) + for rank in range(compute_config.runtime_ngpus) + ]) expected_output_files.extend(trace_meta_files) expected_output_files.append(config_file) expected_output_files.append(outdir / _GRAPH_DUMP_FILE) @@ -636,7 +687,13 @@ def _gen_graph( raise ValueError(f"Default value type {type(v)} of forward args is not supported.") # generate fx graph - dummy_forward_args = _to_cpu(dummy_forward_args) + dummy_forward_args = _to_cpu( + dummy_forward_args, + # in end2end mode, we don't need gradients for inputs + # in normal mode, we assume all inputs require gradients + # so it can connect to other parts of the graph correctly + requires_grad=not end2end_mode + ) fx_graph = parser.to_fx_graph(module, dummy_forward_args) # generate ir logic graph @@ -653,51 +710,22 @@ def _gen_graph( node.target: forward_args_default.get(node.target, inspect.Parameter.empty) for node in fx_input_nodes } - ir_dummy_inputs = [] - for node in fx_input_nodes: - if node.target.startswith('*'): # *args or **kwargs - if node.target.strip('*') in dummy_forward_args: - raise ValueError(f"Input {node.target}: *args or **kwargs is not suppported") - ir_dummy_inputs.append(None) # always set None to *args/**kwargs - elif node.target in dummy_forward_args: - ir_dummy_inputs.append(dummy_forward_args[node.target]) - elif forward_args[node.target] is not inspect.Parameter.empty: - ir_dummy_inputs.append(forward_args[node.target]) - else: - raise ValueError(f"Input {node.target} not in dummy forward args, nor has default value.") - for i in range(len(ir_dummy_inputs)): - # note: we will always set tensor to require gradient, which may - # generate backward communications in adapter. However, as long as - # the data doesn't require gradient in real runtime, the backward - # communication will not be triggered. - ir_dummy_inputs[i] = IR.new( - fx_input_nodes[i].target, ir_dummy_inputs[i], - requires_grad=True, - tosub=True, - is_constant=False, - ) - # if the input is a complex type, we should wrap it with IRObject - if not isinstance(ir_dummy_inputs[i], IRObject): - ir_dummy_inputs[i] = IRObject(fx_input_nodes[i].target, value=ir_dummy_inputs[i], is_constant=False) - # generate complete ir graph - ir_dummy_outputs = graph(*ir_dummy_inputs) if end2end_mode: # in end2end mode, we must use dataloader as the first argument of forward # we assume the first argument of forward is the data sample (which is a requirement in our doc) graph.use_dataloader_input() # we require the first output is the loss - if isinstance(ir_dummy_outputs, (list, tuple)): - ir_loss = ir_dummy_outputs[0] - else: - ir_loss = ir_dummy_outputs + ir_loss = graph.output(0) if not isinstance(ir_loss, IRTensor) or ir_loss.shape != (1,): # internally scalar tensor will be reshaped to (1,) in IRGraph raise RuntimeError(f"Loss can only be scalar tensor but got {ir_loss.shape if isinstance(ir_loss, IRTensor) else ir_loss}") else: ir_loss = None + # we generate backward nodes and setup gradient tensors here + # forward nodes are done when we trace the model if not inference_only: graph.backward(ir_loss) else: @@ -839,12 +867,14 @@ def _gencode( sgener = ScheduleCodeGen(execplan, compute_config.runtime_ngpus) for rank in range(compute_config.runtime_ngpus): fname = outdir / _GENCODE_FILE_TEMPLATE.format(rank) + attr_meta_map_fname = outdir / ParallelModule.ATTR_META_FILE_TEMPLATE.format(rank) mgener.gen(rank, forward_args=forward_args, outfile=fname, attach=False, as_parallel_module=True, - end2end_mode=compute_config.use_end2end + end2end_mode=compute_config.use_end2end, + outfile_attr_meta_map=attr_meta_map_fname ) # generate temporal schedule code only for end2end module # because the code generated is wrong for non-end2end module. @@ -912,6 +942,7 @@ def parallelize( module_dtype: Optional[torch.dtype] = None, module_fn: Optional[Callable[[], torch.nn.Module]] = None, init_module_params: bool = True, + build_module_buckets: bool = True, broadcast_strategy: Union[str, BroadcastGenFilesStrategy] = 'none', ) -> Union[None, ParallelModule, Type[ParallelModule]]: """ @@ -982,6 +1013,12 @@ def __init__(self, init_params=True): Otherwise, they will be empty tensor. This parameter will be passed to the module constructor, so it is only used when module_or_module_class is a module object, and load_module is true. + build_module_buckets (bool): For parallel module, parameters that needs to synchronize will be grouped into buckets for more efficient communication. + If true, grouping process will be done in `__init__` + If false, you should do this by yourself. + This parameter will be passed to the module constructor, + so it is only used when module_or_module_class is a module object, and load_module is true. + Please leave it to true until you have a good reason to change it. module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for generated files. @@ -1008,7 +1045,11 @@ def __init__(self, init_params=True): if isinstance(pas_policy, str): if not pas_policy in _PREDEFINED_POLICIES: raise ValueError(f"Invalid pas_policy: {pas_policy}") - pas_policy = _PREDEFINED_POLICIES[pas_policy] + pas_policy = partial(policies.fn, policy=_PREDEFINED_POLICIES[pas_policy]) + else: + if not callable(pas_policy): + raise ValueError("pas_policy should be a callable or a predefined policy name") + pas_policy = partial(policies.fn, policy=pas_policy) is_module_class = inspect.isclass(module_or_module_class) module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ @@ -1115,7 +1156,7 @@ def __init__(self, init_params=True): if is_module_class: return parallel_module_class else: - parallel_module = parallel_module_class(init_module_params) + parallel_module = parallel_module_class(init_module_params, build_module_buckets) parallel_module.train(module_or_module_class.training) # set training state to the same as original module return parallel_module @@ -1244,12 +1285,34 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] OptimizerT = TypeVar('OptimizerT', bound=torch.optim.Optimizer) +HybridOptimizerT = TypeVar('HybridOptimizer', bound=torch.optim.Optimizer) + + +def hybrid( + params: list[torch.nn.Parameter], + param_clss: dict[torch.nn.Parameter, tuple[int, int]], + **kwargs, +) -> HybridOptimizerT: + """ + Stub for hybrid optimizer creation. + Signature of Hybrid optimizer constructor: + ``` + def __init__(self, params, param_clss, **kwargs): + ... + ``` + When you pass arguments to `build_optimizer` + You must pass `param_clss_fn`, + and `build_optimizer` will automatically pass `param_clss` to its constructor. + """ + ... +hybrid.is_hybrid = True # mark this function as hybrid optimizer factory def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], compute_config: Optional[ComputeConfig] = None, + param_clss_fn: Optional[Callable[[str], Any]] = None, **kwargs, ) -> Union[OptimizerT, ParallelOptimizer]: """ @@ -1277,6 +1340,11 @@ def build_optimizer( compute_config (Optional[ComputeConfig]): The config will be used to generate communication reducer. If it is None, Default configuration will be used when creating reducer for non-parallel modules. + param_clss_fn (Optional[Callable[[str], Any]]): + A function that maps original full qualified parameter names to their class IDs. + If you are using a hybrid optimizer, + you must specify this function + and the return value of this function must be a tuple[int, int] of (optimizer_index, param_group_index). **kwargs: the kwargs for optimizer constructor Returns: @@ -1285,7 +1353,6 @@ def build_optimizer( and will be patched with the methods in ParallelModule class to support parallelized module. Please note the type annotation of the returned optimizer (`Union[OptimizerT, ParallelOptimizer]`) is just for intellisense. """ - if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): raise RuntimeError("Old style CubeModule is not supported") @@ -1293,12 +1360,17 @@ def build_optimizer( if any(m != module and isinstance(m, ParallelModule) and m.compute_config.use_end2end for m in module.modules()): raise RuntimeError("End2End module cannot be nested in another module") + is_hybrid = getattr(optimizer_fn, 'is_hybrid', False) + if is_hybrid and param_clss_fn is None: + raise ValueError("param_clss_fn must be provided when using hybrid optimizer") + RuntimeFlag.skip_reducer = True RuntimeFlag.skip_zero_grad = False non_parallel_module_reducer = None non_parallel_modules = [m for m in module.modules() if not isinstance(m, ParallelModule)] parallel_modules = [m for m in module.modules() if isinstance(m, ParallelModule)] + parallel_modules_prefix = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} if not parallel_modules: raise RuntimeError("No ParallelModule found in the module. Please make sure you have called parallelize() before build_optimizer().") @@ -1310,6 +1382,23 @@ def build_optimizer( non_parallel_parameters_dict[param] = None non_parallel_parameters = list(non_parallel_parameters_dict.keys()) + param_original_names = {} + for n, p in module.named_parameters(): + nparts = n.split('.') + module_prefix = '.'.join(nparts[:-1]) + if module_prefix in parallel_modules_prefix: + name_mapping = parallel_modules_prefix[module_prefix].get_full_map() + original_name = name_mapping[nparts[-1]].orig_name + param_original_names[p] = \ + f'{module_prefix}.{original_name}' if module_prefix else original_name + else: + param_original_names[p] = n + + if param_clss_fn: + param_clss = {p: param_clss_fn(n) for p, n in param_original_names.items()} + else: + param_clss = {} + # check if all ParallelModules have the same gpu_config compute_configs = [m.compute_config for m in parallel_modules] for i in range(1, len(compute_configs)): @@ -1331,7 +1420,9 @@ def build_optimizer( if compute_config: reducer_config = { 'async_op': compute_config.use_async_reducer, - 'zero': compute_config.use_zero, + # zero3 can't be used in non-parallel module reducer + # because we are unable to insert hooks to prefetch/postevict params + 'zero': 1 if compute_config.use_zero else 0, 'max_bucket_size_bytes': compute_config.max_bucket_size_bytes, 'zero_use_reduce_scatter': compute_config.zero_use_reduce_scatter, 'zero_ngroups': compute_config.zero_ngroups, @@ -1339,7 +1430,13 @@ def build_optimizer( non_parallel_module_reducer = Reducer(group, **reducer_config) for param in non_parallel_parameters: non_parallel_module_reducer.add_param(param) - non_parallel_module_reducer.build_buckets() + non_parallel_module_reducer.build_buckets(param_clss=param_clss) + + if param_clss_fn: + for pm in parallel_modules: + pm.build_buckets(param_clss=param_clss) + for reducer in pm.reducers: + param_clss.update(reducer.get_opt_params()) opt_module_locs: Dict[str, ModuleParameterLocation] = {} def _local_parameters(module: torch.nn.Module): @@ -1372,7 +1469,13 @@ def _local_parameters(module: torch.nn.Module): opt_module_locs[name].count += 1 yield param - optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), **kwargs) + if is_hybrid: + optimizer = optimizer_fn(_local_parameters(module), + param_clss, + **kwargs + ) + else: + optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), **kwargs) optimizer._non_parallel_module_reducer = non_parallel_module_reducer optimizer._extra_state = OptimizerExtraState( rank=torch.distributed.get_rank(), @@ -1385,21 +1488,21 @@ def _local_parameters(module: torch.nn.Module): } ) - def _step_pre_hook(opt, *args, **kwargs): - opt.sync_shard_grad() - - def _step_post_hook(opt, *args, **kwargs): + orig_step = optimizer.step + def _patched_step(self, closure=None): + # Please note: + # when closure is used in optimizer.step() + # the backward is done in closure, + # and it is useless to sync grad because grad is still unavailable there + # so you must call sync_shard_grad() manually in this case. + if closure is None: + self.sync_shard_grad() + orig_step(closure=closure) for m in parallel_modules: m.gather_params() if non_parallel_module_reducer: non_parallel_module_reducer.gather_params() - - # Please note: - # register_step_pre_hook doesn't work expectly - # when closure is used in optimizer.step() - # in that case, you must call sync_shard_grad() manually - optimizer.register_step_pre_hook(_step_pre_hook) - optimizer.register_step_post_hook(_step_post_hook) + optimizer.step = types.MethodType(_patched_step, optimizer) orig_zero_grad = optimizer.zero_grad def _patched_zero_grad(self, set_to_none: bool = True): @@ -1574,6 +1677,11 @@ def _get_parallel_module_state_dict_info( return pm_extra_states, pm_state_dicts, non_pm_state_dict +def _is_supported_optimizer(name: str): + from nnscaler.runtime.hybrid_optimizer import HybridOptimizer + return ('adam' in name.lower()) or name == HybridOptimizer.__name__ + + def _get_optimizer_state_dict_info( optimizer_state_dicts: List[Dict[str, Any]] ) -> Tuple[ @@ -1630,7 +1738,7 @@ def _get_optimizer_state_dict_info( ] = {} for opt_state_dict in optimizer_state_dicts: opt_extra_state = OptimizerExtraState(**opt_state_dict[ParallelModule.EXTRA_STATE_KEY]) - if 'adam' not in opt_extra_state.name.lower(): + if not _is_supported_optimizer(opt_extra_state.name): raise ValueError("Only Adam-like optimizers are supported.") opt_extra_states[opt_extra_state.rank] = opt_extra_state @@ -1703,10 +1811,10 @@ def _sort_state_dicts(state_dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]] sorted_state_dicts =[None] * len(state_dicts) for state_dict in state_dicts: rank = _get_state_dict_rank(state_dict) + if rank >= len(state_dicts): + raise ValueError(f"Invalid rank {rank} in state_dicts with length {len(state_dicts)}.") if sorted_state_dicts[rank] is not None: raise ValueError(f"Duplicate rank {rank} in state_dicts.") - if rank >= len(state_dicts): - raise ValueError(f"Invalid rank {rank} in state_dicts.") sorted_state_dicts[rank] = state_dict return sorted_state_dicts @@ -1738,7 +1846,7 @@ def _sort_state_dicts(state_dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]] module_prefix = '.'.join(k) opt_state_dicts_for_merge = None if opt_state_dicts is None else opt_state_dicts[module_prefix] - merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks) for e in extra_states] + merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks, e.zero) for e in extra_states] if not extra_states[0].compute_config.use_zero: # all ranks should have the same use_zero merge_partial_states_zero_idx_maps = None merged_state_dict, merged_opt_state_dict = ParallelModule.merge_state_dicts( @@ -1869,91 +1977,189 @@ def load_merged_state_dict( module.to(device) if optimizer is not None and optimizer_state_dict is not None: - if 'adam' not in optimizer._extra_state.name.lower(): - raise ValueError("Only Adam-like optimizers are supported.") + new_optimizer_state_dict = _trim_optimizer_merged_state_dict(module, optimizer._extra_state, optimizer_state_dict, device=device) + optimizer.load_state_dict(new_optimizer_state_dict) - # handle non-paralleled module parameters - # make sure the order of the parameters - pm_name_locs: Dict[str, ModuleParameterLocation] = dict(sorted(optimizer._extra_state.parallel_module_locs.items(), key=lambda x: x[1].offset)) - pm_modules: List[torch.nn.Module] = [] - pm_locs = list(pm_name_locs.values()) - for name in pm_name_locs: - m = get_member_by_name(module, name) - if not isinstance(m, ParallelModule): - raise ValueError(f"Module {name} is not a ParallelModule") - pm_modules.append(m) - - merged_cur = 0 # the current index of the merged state dict - pm_cur = 0 # the current index of the parallel module in pm_locs - new_states: Dict[int, Dict[str, Any]] = {} - new_cur = 0 # the current index of the new state dict - assert len(optimizer_state_dict['param_groups']) == 1 - effective_state_len = len(optimizer_state_dict['param_groups'][0]['params']) - while merged_cur < effective_state_len: - # N: non-paralleled module parameters, P: paralleled module (will have multiple parameters) - # The parameter list would look like: NNPNPPPN - # []: the current processing parameter - # <>: the current processing parallel module - if ( - pm_cur >= len(pm_modules) # NNPNPPP[N]: the ending parameters, no current parallel module - or new_cur < pm_locs[pm_cur].offset # [N]N

NPPPN: other parameters - ): - # non-parallel module - if merged_cur in optimizer_state_dict['state']: - new_states[new_cur] = optimizer_state_dict['state'][merged_cur] - merged_cur += 1 - new_cur += 1 - else: - # NNPN<[P]PP>N: the current parallel module - # parallel module - pm_param_count = len(pm_modules[pm_cur]._orign_module_metadata.origin_param_names) - # will map `pm_param_count` parameters in merge state dict - # to `pm_locs[pm_cur].count` in optimizer state. - cur_states = {} - for i in range(pm_param_count): - if merged_cur + i in optimizer_state_dict['state']: - cur_states[i] =optimizer_state_dict['state'][merged_cur + i] - pm_new_states = _opt_load_merged_state_dict(pm_modules[pm_cur], cur_states) - for idx, value in pm_new_states.items(): - new_states[new_cur + idx] = value - new_cur += pm_locs[pm_cur].count - merged_cur += pm_param_count - pm_cur += 1 - - # move the new states to the device if needed - for idx, state in new_states.items(): - for key, value in state.items(): - if isinstance(value, torch.Tensor): - new_states[idx][key] = value.to(device) - new_optimizer_state_dict = {} - new_optimizer_state_dict['state'] = new_states - new_optimizer_state_dict['param_groups'] = copy.deepcopy(optimizer_state_dict['param_groups']) - new_optimizer_state_dict['param_groups'][0]['params'] = list(range(new_cur)) - optimizer.load_state_dict(new_optimizer_state_dict) +def _trim_optimizer_merged_state_dict( + module: torch.nn.Module, + opt_extra_state: OptimizerExtraState, + optimizer_state_dict: Dict[str, Any], + *, + device: Union[str, torch.device] = None +) -> Dict[str, Any]: + """ + Trim the merged state dict to only keep the states needed for the optimizer. + + Args: + module (torch.nn.Module): the module to be loaded + opt_extra_state (OptimizerExtraState): the extra state of the optimizer + optimizer_state_dict (Dict[str, Any]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the optimizer state dict. + + Returns: + Dict[str, Any]: the trimmed optimizer state dict + """ + if not _is_supported_optimizer(opt_extra_state.name): + raise ValueError("Only Adam-like optimizers are supported.") + + device = device or torch.cuda.current_device() + + # handle non-paralleled module parameters + # make sure the order of the parameters + pm_name_locs: Dict[str, ModuleParameterLocation] = dict(sorted(opt_extra_state.parallel_module_locs.items(), key=lambda x: x[1].offset)) + pm_modules: List[ParallelModule] = [] + pm_locs = list(pm_name_locs.values()) + for name in pm_name_locs: + m = get_member_by_name(module, name) + if not isinstance(m, ParallelModule): + raise ValueError(f"Module {name} is not a ParallelModule") + pm_modules.append(m) + + merged_cur = 0 # the current index of the merged state dict + pm_cur = 0 # the current index of the parallel module in pm_locs + new_states: Dict[int, Dict[str, Any]] = {} + new_cur = 0 # the current index of the new state dict + assert len(optimizer_state_dict['param_groups']) == 1 + effective_state_len = len(optimizer_state_dict['param_groups'][0]['params']) + while merged_cur < effective_state_len: + # N: non-paralleled module parameters, P: paralleled module (will have multiple parameters) + # The parameter list would look like: NNPNPPPN + # []: the current processing parameter + # <>: the current processing parallel module + if ( + pm_cur >= len(pm_modules) # NNPNPPP[N]: the ending parameters, no current parallel module + or new_cur < pm_locs[pm_cur].offset # [N]N

NPPPN: other parameters + ): + # non-parallel module + if merged_cur in optimizer_state_dict['state']: + new_states[new_cur] = optimizer_state_dict['state'][merged_cur] + merged_cur += 1 + new_cur += 1 + else: + # NNPN<[P]PP>N: the current parallel module + # parallel module + pm_param_count = len(pm_modules[pm_cur].origin_module_metadata.origin_param_names) + # will map `pm_param_count` parameters in merge state dict + # to `pm_locs[pm_cur].count` in optimizer state. + cur_states = {} + for i in range(pm_param_count): + if merged_cur + i in optimizer_state_dict['state']: + cur_states[i] =optimizer_state_dict['state'][merged_cur + i] + pm_new_states = _opt_load_merged_state_dict(pm_modules[pm_cur], cur_states) + for idx, value in pm_new_states.items(): + new_states[new_cur + idx] = value + new_cur += pm_locs[pm_cur].count + merged_cur += pm_param_count + pm_cur += 1 + + # move the new states to the device if needed + for idx, state in new_states.items(): + for key, value in state.items(): + if isinstance(value, torch.Tensor): + new_states[idx][key] = value.to(device) + + new_optimizer_state_dict = {} + new_optimizer_state_dict['state'] = new_states + new_optimizer_state_dict['param_groups'] = copy.deepcopy(optimizer_state_dict['param_groups']) + new_optimizer_state_dict['param_groups'][0]['params'] = list(range(new_cur)) + + return new_optimizer_state_dict def _opt_load_merged_state_dict(module: ParallelModule, states: Dict[int, Dict[str, Any]]): + """ + Args: + module (ParallelModule): the parallel module + states (Dict[int, Dict[str, Any]]): the merged optimizer state dict for a parallel module + key: optimizer parameter index in the merged state dict + value: the state dict for each attribute, e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys + """ with torch.no_grad(): # orig_name -> state + # state: Dict[str, Any], e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys orig_param_dict: Dict[str, Dict[str, Any]] = {} cnt = 0 - origin_param_names = module._orign_module_metadata.origin_param_names + origin_param_names = module.origin_module_metadata.origin_param_names for name in origin_param_names: if cnt in states: # some parameters may not in the sates when it is not used or requires_grad is False in training orig_param_dict[name] = states[cnt] cnt = cnt + 1 - if module.compute_config.use_zero: + if module.compute_config.use_zero == 1: return _construct_optim_state_zero(module, orig_param_dict) + elif module.compute_config.use_zero > 1: + return _construct_optim_state_zero3(module, orig_param_dict) else: return _construct_optim_state_nonzero(module, orig_param_dict) +def _construct_optim_state_zero3( + module: ParallelModule, + orig_param_dict: Dict[str, Dict[str, Any]] +): + # state for each parameter in the parallel module + new_states = _construct_optim_state_nonzero(module, orig_param_dict) + param_state_map = {p: new_states[idx] for idx, p in enumerate(module.parameters())} + + state_dict, opt_param_idx = {}, 0 + opt_param = module.parameters_for_optimizer() + # first load the params' optimizer state for the reducers's flattened params + for reducer in module.reducers: + for bucket in reducer.buckets: + bucket: Bucket + # one bucket corresponds to one flattened param + assert len(opt_param[opt_param_idx].shape) == 1 + chunk_size = bucket._contiguous_params.shape[0] + opt_states = {} + offset = 0 + for param in bucket.params: + sliced_new_val = param_state_map[param] + param_numel = bucket.get_aligned_numel(param) + # init the optimizer state + if not opt_states: + for key in sliced_new_val.keys(): + if key == 'step': + opt_states[key] = sliced_new_val[key] + else: + opt_states[key] = torch.zeros( + [chunk_size], dtype=sliced_new_val[key].dtype, + device=sliced_new_val[key].device, requires_grad=False + ) + # copy the param's slices to the optimizer's chunk + for key in opt_states.keys(): + if key == 'step': + continue + opt_states[key][offset:offset+sliced_new_val[key].numel()] = sliced_new_val[key] + + offset += param_numel + state_dict[opt_param_idx] = opt_states + opt_param_idx += 1 + + # load the params' optimizer state that are not in reducers + reducer_pids = set() + for reducer in module.reducers: + reducer_pids.update(id(p) for p in reducer.params) + for param in module.parameters(): + if id(param) not in reducer_pids: + state_dict[opt_param_idx] = param_state_map[param] + opt_param_idx += 1 + + return state_dict + + def _construct_optim_state_zero( module: ParallelModule, orig_param_dict: Dict[str, Dict[str, Any]], ): + """ + Construct the optimizer state for a ParallelModule with ZeRO optimization. + Args: + module (ParallelModule): the parallel module + orig_param_dict (Dict[str, Dict[str, Any]]): the original parameter optimizer state + key: original parameter name + value: the state dict for each attribute, e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys + """ dist_param_map = module.dist_param_map # name in parallel module (without tid suffix) -> name in origin module param_area_map = module.fullmap # str -> AttrMeta def _get_optimizer_state_of_param(param, param_ids, local_names): @@ -2070,9 +2276,12 @@ def _construct_optim_state_nonzero( dist_param_map = module.dist_param_map # name in parallel module (without tid suffix) -> name in origin module param_area_map = module.fullmap # str -> AttrMeta - new_states = {} + new_states: dict[int, dict[str, torch.Tensor]] = {} for index, (local_name, _) in enumerate(module.named_parameters()): - new_states[index] = _extract_new_state(local_name, orig_param_dict, dist_param_map, param_area_map) + new_states[index] = _extract_new_state( + local_name, orig_param_dict, dist_param_map, param_area_map, + module.get_zero3_attr_meta(local_name) + ) return new_states @@ -2082,7 +2291,8 @@ def _extract_new_state( orig_param_dict: Dict[str, Dict[str, Any]], dist_param_map: Dict[str, str], param_area_map: Dict[str, AttrMeta], -): + zero3_info: Optional[Zero3AttrMeta] = None +) -> Dict[str, torch.Tensor]: name = '_'.join(local_name.split('_')[:-1]) # remove the integer suffix assert name in dist_param_map attr_meta = param_area_map[local_name] @@ -2093,6 +2303,18 @@ def _extract_new_state( sliced_new_val[key] = new_val[key] else: sliced_new_val[key] = new_val[key][attr_meta.slicers] / attr_meta.val_chunks + if zero3_info is not None: + sliced_new_val[key] = sliced_new_val[key].view(-1)[zero3_info.start:zero3_info.end] + if sliced_new_val[key].numel() < zero3_info.chunk_size: + # padding if needed + sliced_new_val[key] = torch.cat( + [sliced_new_val[key], + torch.zeros( + zero3_info.chunk_size - sliced_new_val[key].numel(), + dtype=sliced_new_val[key].dtype, + device=sliced_new_val[key].device + )], dim=0 + ) return sliced_new_val @@ -2184,7 +2406,7 @@ def _broadcast_gen_files( if curr_rank != 0: files = sent_obj[0] - logging.info(f'File list broadcasted ({len(files)} in total).') + logger.info(f'File list broadcasted ({len(files)} in total).') # send file content one by one for fname in files: if curr_rank == 0: @@ -2196,12 +2418,53 @@ def _broadcast_gen_files( if curr_rank != 0: with open(outdir / fname, 'wb') as f: f.write(data[0]) - logging.info(f'File {fname} broadcasted.') + logger.info(f'File {fname} broadcasted.') # wait for all nodes to finish torch.distributed.barrier() +def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Tuple[ + Dict[int, Dict[str, Dict[str, AttrMeta]]], + Dict[str, int], + Dict[int, Dict[str, Dict[str, AttrMeta]]] +]: + """ + A helper function that computes the deduplicated attribute information from all ranks. + Note that this function may be removed in the future and dedup information are computed + directly at the compilation stage. + + Returns: + A tuple containing: + - rank2deduped_fullmap: a mapping from rank id to deduplicated attribute information + - dedup_group_size: the size of the deduplication group for each parallel module + - global_fullmaps: a mapping from rank id to full attribute information + """ + dedup_group_size = {} + for prefix, parallel_module in parallel_modules.items(): + dedup_group_size[prefix] = parallel_module.module_dedup_group_size + + world_size = torch.distributed.get_world_size() + global_fullmaps: Dict[ + int, # rank id + Dict[str, # submodule prefix + Dict[str, # attribute name in parallel module + AttrMeta]] + ] = {} + for rank in range(world_size): + global_fullmaps[rank] = {} + for prefix, m in parallel_modules.items(): + global_fullmaps[rank][prefix] = m.get_attr_meta_map(rank) + # `dedup_attrs` is a deterministic algorithm, so it produces same results across different ranks + rank2deduped_fullmap = dedup_attrs(global_fullmaps) + + for prefix, group_size in dedup_group_size.items(): + for rank in range(group_size, world_size): + assert len(rank2deduped_fullmap[rank].get(prefix, {})) == 0, f'Rank {rank} has non-empty deduped_fullmap: {rank2deduped_fullmap[rank]}' + + return rank2deduped_fullmap, dedup_group_size, global_fullmaps + + @torch.no_grad() def deduped_state_dict( module: torch.nn.Module, @@ -2224,6 +2487,9 @@ def deduped_state_dict( module_state_dict, opt_state_dict = None, None parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} + rank2deduped_fullmap, _, _ = _collect_dedup_info(parallel_modules) + cur_deduped_fullmap = rank2deduped_fullmap[cur_rank] + # The reason we use `Module.state_dict` on the whole to get the complete state dict # instead of call `Module.state_dict` on each submodule # is to make sure the hooks to state_dict are called. @@ -2231,11 +2497,18 @@ def deduped_state_dict( for key in list(module_state_dict.keys()): if key.endswith(ParallelModule.EXTRA_STATE_KEY): # never remove extra state continue - prefix = '.'.join(key.split('.')[:-1]) # remove the last part of the key - dedup_group_size = parallel_modules[prefix].module_dedup_group_size \ - if prefix in parallel_modules else 1 - # only keep the first `dedup_group_size` ranks' state - if cur_rank >= dedup_group_size: + split_names = key.split('.') + prefix = '.'.join(split_names[:-1]) # remove the last part of the key + if prefix in parallel_modules: + if parallel_modules[prefix].compute_config.use_zero > 1: + # for zero3, we don't use advanced deduplication. + # TODO: handle zero3 case + if cur_rank >= parallel_modules[prefix].module_dedup_group_size: + module_state_dict.pop(key, None) + elif prefix not in cur_deduped_fullmap or split_names[-1] not in cur_deduped_fullmap[prefix]: + module_state_dict.pop(key, None) + # since replicated non-parallel modules, we only keep weights on rank 0 + elif cur_rank >= 1: module_state_dict.pop(key, None) if optimizer is not None: @@ -2285,20 +2558,79 @@ def load_deduped_state_dict( None """ device = device or torch.cuda.current_device() + cur_rank = torch.distributed.get_rank() - # only load partial state for all ranks except rank 0 - module.load_state_dict(module_state_dict, strict=False) + # step 1: load deduped state dict at each rank + missing_keys, unexpected_keys = module.load_state_dict(module_state_dict, strict=False) module.to(device) torch.distributed.barrier() + logger.debug(f'At rank {cur_rank}, state_dict keys: {module_state_dict.keys()}.') + logger.debug(f'At rank {cur_rank}, missing_keys: {missing_keys}, unexpected_keys: {unexpected_keys}.') + + # step 2: broadcast deduped weights inside 1st scale unit for non-zero3 parallel modules + # for zero3 modules, the weights are already complete after step 1 + # TODO: refine zero3 modules support + parallel_modules = { + prefix: m + for prefix, m in module.named_modules() + if isinstance(m, ParallelModule) and m.compute_config.use_zero <= 1 + } + if parallel_modules: + rank2deduped_fullmap, dedup_group_size, global_tensor_meta = _collect_dedup_info(parallel_modules) + logger.debug(f'At rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}.') + + # broadcast weights in parallel modules + for rank, deduped_fullmap in rank2deduped_fullmap.items(): + logger.debug(f'At rank {cur_rank}, process rank: {rank}.') + for prefix, fullmap in deduped_fullmap.items(): + if cur_rank >= dedup_group_size[prefix]: + break + broadcast_group = DeviceGroup().get_group(list(range(dedup_group_size[prefix]))) + for local_name, attr_meta in fullmap.items(): + key = f'{prefix}.{local_name}' if prefix else local_name + assert prefix in parallel_modules, f'Prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}.' + pm = parallel_modules[prefix] + attr_meta = global_tensor_meta[rank][prefix][local_name] + shape, dtype = attr_meta.sub_shape, attr_meta.dtype + if rank == cur_rank: + assert hasattr(pm, local_name), f'Local name {local_name} not found in {pm}.' + broadcast_tensor = getattr(pm, local_name) + logger.info(f'Broadcast: {key} from {cur_rank}.') + else: + broadcast_tensor = torch.empty(shape, device=device, requires_grad=False, dtype=dtype) + torch.distributed.broadcast(broadcast_tensor, src=rank, group=broadcast_group) + if rank != cur_rank: + # in pipeline parallelism, the local_name may not be found in the module + if hasattr(pm, local_name): + logger.info(f'At rank {cur_rank}, try to load: {key} from rank {rank}.') + attr = getattr(pm, local_name) + if key in missing_keys: + attr.data.copy_(broadcast_tensor) + missing_keys.remove(key) + else: + assert torch.equal(attr, broadcast_tensor), \ + f'At rank {cur_rank}, the attribute {key} is already loaded, but not equal to the broadcasted tensor from rank {rank}.' + else: + logger.info(f'At rank {cur_rank}, skip to load: {key} from rank {rank}, not found in the module.') + + for key in missing_keys: + split_names = key.split('.') + prefix = '.'.join(split_names[:-1]) # remove the last part of the key + assert prefix not in parallel_modules or cur_rank >= dedup_group_size[prefix], f'At rank {cur_rank}, the missing key {key} should be in non-parallel modules.' + + # At this point + # - All parallel modules in first scale unit should be complete. + # - Non-parallel modules in rank0 should be complete. The rest ranks will get the weights via broadcast_weights. + torch.distributed.barrier() - # broadcast weights + # step 3: + # - broadcast non-parallel module weights from 0th rank to other ranks + # - broadcast parallel modules weights from 1st scale unit to other units broadcast_weights(module) - if optimizer is not None: - if 'adam' not in optimizer._extra_state.name.lower(): + if optimizer is not None and optimizer_state_dict is not None: + if not _is_supported_optimizer(optimizer._extra_state.name): raise ValueError("Only Adam-like optimizers are supported.") - if optimizer_state_dict is None: - raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") for idx, state in optimizer_state_dict['state'].items(): for key, value in state.items(): @@ -2338,7 +2670,7 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g broadcast_group = setup_stride_broadcast_group(dedup_group_size) src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks - logging.info(f'Rank-{rank} is broadcasting optimizer states to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') + logger.info(f'Rank-{rank} is broadcasting optimizer states to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') # broadcast param groups and state keys/shapes/dtypes via broadcast_object_list if rank == src_rank: @@ -2424,7 +2756,7 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): broadcast_group = setup_stride_broadcast_group(stride_size) rank = torch.distributed.get_rank() src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks - logging.info(f'Rank-{rank} is broadcasting weights of {module.__class__.__name__} to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') + logger.info(f'Rank-{rank} is broadcasting weights of {module.__class__.__name__} to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') if isinstance(module, ParallelModule): if not _broadcast_single_value(src_rank, curr_parallel_group, module.non_presistent_buffers_inited): @@ -2432,15 +2764,15 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): # we have a special optimization for ParallelModule params = module.parameters_for_broadcast() if isinstance(module, ParallelModule) else list(module.parameters(False)) - logging.info(f'Inplace broadcasting {len(params)} parameters...') + logger.info(f'Inplace broadcasting {len(params)} parameters...') for i, param in enumerate(params): torch.distributed.broadcast(param.data, src=src_rank, group=curr_parallel_group) - logging.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') + logger.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') # NOTE: may batch buffers for efficient broadcast, # current implementation is the most memory efficient way. buffers = list(module.buffers(False)) - logging.info(f'Inplace broadcasting {len(buffers)} buffers...') + logger.info(f'Inplace broadcasting {len(buffers)} buffers...') for buffer in buffers: torch.distributed.broadcast(buffer.data, src=src_rank, group=curr_parallel_group) @@ -2477,9 +2809,11 @@ def load_sharded_state_dict( device = device or torch.cuda.current_device() module.load_state_dict(module_state_dict) module.to(device) - if optimizer: - if optimizer_state_dict is None: - raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") + if optimizer and optimizer_state_dict: + for idx, state in optimizer_state_dict.get('state', {}).items(): + for key, value in state.items(): + if isinstance(value, torch.Tensor): + optimizer_state_dict['state'][idx][key] = value.to(device) optimizer.load_state_dict(optimizer_state_dict) @@ -2513,3 +2847,387 @@ def sync_grad_when(cond: bool): cond (bool): whether to synchronize gradients. """ return _runtime_flags(skip_reducer=not cond) + + +def _construct_parallel_module_stub(metadata): + pmodules = {prefix: ParallelModule._unpack(minfo) for prefix, minfo in metadata.items()} + + # whole parallel module + if len(pmodules) == 1 and list(pmodules.keys())[0] == '': + module = pmodules[''] + else: + module = torch.nn.Module() + for prefix, pmodule in pmodules.items(): + set_member_by_name(module, prefix, pmodule) + + # mock `named_modules` to list parallel modules in stub module + def named_modules( + memo=None, + prefix: str = "", + remove_duplicate: bool = True, + ): + assert memo is None and prefix == '' and remove_duplicate is True, \ + "Only support default arguments" + return pmodules.items() + + module.named_modules = named_modules + + return module + + +def _trim_module_merged_state_dict( + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + *, + device: Union[str, torch.device] = None, +): + device = device or torch.cuda.current_device() + + parallel_modules = {module_path: m for module_path, m in module.named_modules() if isinstance(m, ParallelModule)} + + trimmed_state_dict = {} + # collect non-parallel module parameters + for key, tensor in module_state_dict.items(): + parts = key.split('.') + if not any('.'.join(parts[:i]) in parallel_modules for i in range(0, len(parts))): + trimmed_state_dict[key] = tensor.to(device) + + for module_path, pmodule in parallel_modules.items(): + prefix = module_path + '.' if module_path else '' + trimmed_state_dict.update( + pmodule.trim_merged_state_dict( + module_state_dict, prefix=prefix, + device=device + ) + ) + return trimmed_state_dict + + +def _send_trimmed_module_state_dict( + trimmed_state_dict: Dict[str, torch.Tensor], + group: torch.distributed.ProcessGroup, + dst_rank: int, +): + """ + Send the trimmed state dict to the specified destination rank. + + Args: + trimmed_state_dict (Dict[str, torch.Tensor]): the trimmed state dict to send. + dst_rank (int): the destination rank to send the state dict to. + """ + # send trimmed state dict to rank + # one tensor each time + keys = list(trimmed_state_dict.keys()) + shape_dtypes = [(tensor.shape, tensor.dtype) for tensor in trimmed_state_dict.values()] + torch.distributed.send_object_list([keys, shape_dtypes], group=group, dst=dst_rank) + for key in keys: + tensor = trimmed_state_dict[key] + # NOTE: send is broken if the tensor is not contiguous + torch.distributed.send(tensor.contiguous(), group=group, dst=dst_rank) + + +def _receive_trimmed_module_state_dict( + src_rank: int, + group: torch.distributed.ProcessGroup, + device: Union[str, torch.device] = None, +): + """ + Receive the trimmed state dict from the specified source rank. + + Args: + src_rank (int): the source rank to receive the state dict from. + """ + device = device or torch.cuda.current_device() + + # receive trimmed state dict from rank + # one at a time + keys_shape_dtypes=[None, None] + torch.distributed.recv_object_list(keys_shape_dtypes, group=group, src=src_rank) + keys: list[str] = keys_shape_dtypes[0] + shape_dtypes: list[tuple[torch.Size, torch.dtype]] = keys_shape_dtypes[1] + + trimmed_state_dict = {} + for key, shape_dtype in zip(keys, shape_dtypes): + tensor = torch.zeros(shape_dtype[0], dtype=shape_dtype[1], device=device) + torch.distributed.recv(tensor, group=group, src=src_rank) + trimmed_state_dict[key] = tensor + return trimmed_state_dict + + +def _send_trimmed_opt_state_dict( + trimmed_opt_state_dict: OptStateDict, + group: torch.distributed.ProcessGroup, + dst_rank: int, +): + """ + Send the trimmed optimizer state dict to the specified destination rank. + + Args: + trimmed_opt_state_dict (OptStateDict): the trimmed optimizer state dict to send. + dst_rank (int): the destination rank to send the state dict to. + """ + # send trimmed optimizer state dict to rank + # one tensor each time + + # broadcast param groups and state keys/shapes/dtypes via broadcast_object_list + state_info = {} + state_keys = list(trimmed_opt_state_dict['state'].keys()) + param_group = trimmed_opt_state_dict['param_groups'] + for idx in state_keys: + state_info[idx] = {key: (value.shape, value.dtype) for key, value in trimmed_opt_state_dict['state'][idx].items()} + sent = [state_keys, state_info, param_group] + torch.distributed.send_object_list(sent, group=group, dst=dst_rank) + + # broadcast step in stack + if 'step' in trimmed_opt_state_dict['state'][state_keys[0]]: + step_stack = torch.stack( + [trimmed_opt_state_dict['state'][k]['step'] for k in state_keys] + ) + torch.distributed.send(step_stack, group=group, dst=dst_rank) + + # broadcast other states + # TODO: can be slow? + for k in state_keys: + keys = sorted(trimmed_opt_state_dict['state'][k].keys()) + if 'step' in keys: + keys.remove('step') # we have done step in previous. + for key in keys: + value = trimmed_opt_state_dict['state'][k][key] + torch.distributed.send(value.data, group=group, dst=dst_rank) + + +def _receive_trimmed_opt_state_dict( + src_rank: int, + group: torch.distributed.ProcessGroup, + device: Union[str, torch.device] = None, + ) -> OptStateDict: + """ + Receive the trimmed optimizer state dict from the specified source rank. + + Args: + src_rank (int): the source rank to receive the state dict from. + """ + device = device or torch.cuda.current_device() + + # receive trimmed optimizer state dict from rank + # one at a time + state_dict_info = [None, None, None] + torch.distributed.recv_object_list(state_dict_info, group=group, src=src_rank) + state_keys: list[str] = state_dict_info[0] + state_info: list[tuple[torch.Size, torch.dtype]] = state_dict_info[1] + param_group = state_dict_info[2] + + trimmed_opt_state_dict = { + 'state': {}, + 'param_groups': param_group + } + for key in state_keys: + trimmed_opt_state_dict['state'][key] = { + k: torch.zeros(v[0], dtype=v[1], device=device) + for k, v in state_info[key].items() + } + + # receive steps + if 'step' in trimmed_opt_state_dict['state'][state_keys[0]]: + step_stack = torch.zeros( + len(state_keys), + dtype=trimmed_opt_state_dict['state'][state_keys[0]]['step'].dtype, + device=device + ) + torch.distributed.recv(step_stack, group=group, src=src_rank) + for k, v in zip(state_keys, step_stack): + trimmed_opt_state_dict['state'][k]['step'].copy_(v) + + # receive other states + for k in state_keys: + keys = sorted(trimmed_opt_state_dict['state'][k].keys()) + if 'step' in keys: + keys.remove('step') # we have done step in previous. + for key in keys: + value = trimmed_opt_state_dict['state'][k][key] + torch.distributed.recv(value.data, group=group, src=src_rank) + + return trimmed_opt_state_dict + + +def trimmed_broadcast_merged_state_dict( + module: torch.nn.Module, + module_state_dict: Optional[Dict[str, Any]] = None, + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + device: Union[str, torch.device] = None, +) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + """ + trim merged state dict and broadcast to each rank. + + Args: + module (torch.nn.Module): the module to be loaded + module_state_dict (Dict[str, Any]): the merged model state dict + optimizer (Optional[torch.optim.Optimizer]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. + src_rank (int): the source rank to load the merged state dict from. + dst_ranks (Optional[list[int]]): the destination ranks to load the merged state dict to. + + Returns: + Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + the trimmed state dicts for the module and optimizer + """ + device = device or torch.cuda.current_device() + world_size = torch.distributed.get_world_size() + dst_ranks = dst_ranks or list(range(world_size)) + cur_rank = torch.distributed.get_rank() + + if cur_rank not in dst_ranks or src_rank not in dst_ranks: + raise ValueError( + f"Invalid rank configuration. Both current rank ({cur_rank}) and source rank ({src_rank}) " + f"must be in the destination ranks {dst_ranks}." + ) + + pg = DeviceGroup().get_group(dst_ranks) + + if cur_rank == src_rank: + if optimizer_state_dict and not optimizer: + raise ValueError("Optimizer must be provided when loading optimizer state dict.") + else: + if optimizer_state_dict or module_state_dict: + raise ValueError("Only the source rank can provide the merged state dicts.") + + rank_metadata = ( + {module_path: m._pack() for module_path, m in module.named_modules() if isinstance(m, ParallelModule)}, + optimizer._extra_state if optimizer else None, + ) + + rank_metadatas = [None] * len(dst_ranks) if cur_rank == src_rank else None + torch.distributed.gather_object(rank_metadata, rank_metadatas, group=pg, dst=src_rank) + + if cur_rank == src_rank: + will_load_opt_state = [optimizer_state_dict is not None] + else: + will_load_opt_state = [None] + torch.distributed.broadcast_object_list(will_load_opt_state, group=pg, src=src_rank) + will_load_opt_state = will_load_opt_state[0] + if will_load_opt_state and not optimizer: + raise ValueError("Optimizer must be provided when loading optimizer state dict.") + + ret = None + + if cur_rank == src_rank: + pmodule_stubs = {rank : _construct_parallel_module_stub(r[0]) for rank, r in zip(dst_ranks, rank_metadatas)} + opt_extra_states = {rank : r[1] for rank, r in zip(dst_ranks, rank_metadatas)} + for rank in dst_ranks: + if rank != cur_rank: + logger.info(f'At rank {src_rank}: Trimming module state dict for rank {rank}') + trimmed_module_state_dict = _trim_module_merged_state_dict( + pmodule_stubs[rank], + module_state_dict, + device=device, + ) + logger.info(f'At rank {src_rank}: Sending trimmed module state dict for rank {rank}') + _send_trimmed_module_state_dict(trimmed_module_state_dict, dst_rank=rank, group=pg) + del trimmed_module_state_dict + + if will_load_opt_state: + logger.info(f'At rank {src_rank}: Trimming optimizer state dict for rank {rank}') + trimmed_opt_state_dict = _trim_optimizer_merged_state_dict( + pmodule_stubs[rank], + opt_extra_states[rank], + optimizer_state_dict, + device=device, + ) + logger.info(f'At rank {src_rank}: Sending trimmed optimizer state dict for rank {rank}') + _send_trimmed_opt_state_dict(trimmed_opt_state_dict, dst_rank=rank, group=pg) + del trimmed_opt_state_dict + + torch.distributed.barrier(group=pg) + + # load for self after state dict for all other ranks are sent + # this can lower gpu memory peak + logger.info(f'At rank {src_rank}: Trimming module state dict for self rank {cur_rank}') + trimmed_module_state_dict = _trim_module_merged_state_dict( + pmodule_stubs[cur_rank], + module_state_dict, + device=device, + ) + if will_load_opt_state: + logger.info(f'At rank {src_rank}: Trimming optimizer state dict for self rank {cur_rank}') + trimmed_opt_state_dict = _trim_optimizer_merged_state_dict( + pmodule_stubs[cur_rank], + opt_extra_states[cur_rank], + optimizer_state_dict, + device=device, + ) + else: + trimmed_opt_state_dict = None + ret = (trimmed_module_state_dict, trimmed_opt_state_dict) + else: + for rank in dst_ranks: + if rank == cur_rank: + # receive state dict from src_rank + logger.info(f'At rank {cur_rank}: Receiving trimmed module state dict from rank {src_rank}') + trimmed_module_state_dict = _receive_trimmed_module_state_dict(src_rank, group=pg) + + if will_load_opt_state: + logger.info(f'At rank {cur_rank}: Receiving trimmed optimizer state dict from rank {src_rank}') + trimmed_opt_state_dict = _receive_trimmed_opt_state_dict(src_rank, group=pg) + else: + trimmed_opt_state_dict = None + + ret = (trimmed_module_state_dict, trimmed_opt_state_dict) + + torch.distributed.barrier(group=pg) + + assert ret is not None + # make it a sharded state dict. + for module_path, m in module.named_modules(): + prefix = module_path + '.' if module_path else '' + if isinstance(m, ParallelModule): + m._add_extra_state(ret[0], prefix) + return ret + + +def load_merged_state_dict_from_rank( + module: torch.nn.Module, + module_state_dict: Optional[Dict[str, Any]] = None, + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + device: Union[str, torch.device] = None, +): + """ + load the merged state dict from rank. + + Only src_rank will load merged state dict to memory (for saving memory), + and dst_ranks will receive the sharded state dict from src_rank via communication. + + Args: + module (torch.nn.Module): the module to be loaded + module_state_dict (Dict[str, Any]): the merged model state dict + optimizer (Optional[torch.optim.Optimizer]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. + src_rank (int): the source rank to load the merged state dict from. + dst_ranks (Optional[list[int]]): the destination ranks to load the merged state dict to. + + Returns: + None + """ + trimmed_module_state_dict, trimmed_opt_state_dict = trimmed_broadcast_merged_state_dict( + module, + module_state_dict, + optimizer, + optimizer_state_dict, + device=device, + src_rank=src_rank, + dst_ranks=dst_ranks, + ) + module.load_state_dict(trimmed_module_state_dict) + if trimmed_opt_state_dict: + optimizer.load_state_dict(trimmed_opt_state_dict) diff --git a/nnscaler/policies.py b/nnscaler/policies.py index f1db5858..b4cf13af 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -18,8 +18,10 @@ IRDataOperation is recommended to be replicated to all devices. """ +import ast +from dataclasses import dataclass, field import logging -from typing import List, Optional, TYPE_CHECKING +from typing import Any, List, Literal, Optional, TYPE_CHECKING, Callable, Iterable, Union import random import torch @@ -30,13 +32,17 @@ from nnscaler.graph import IRGraph from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.segment import IRSegment -from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from nnscaler.ir import IRCell, IRSubTensor, IRFullTensor +from nnscaler.ir.cten import IR +from nnscaler.runtime.function import identity, multiref +from nnscaler.utils import load_type if TYPE_CHECKING: - from nnscaler.parallel import ComputeConfig + from nnscaler.parallel import ComputeConfig, ParallelModule _logger = logging.getLogger(__name__) @@ -116,6 +122,8 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): random.shuffle(configs) for (idx, dim) in configs: if node.input(idx).shape[dim] % len(devs) != 0: continue + # only partition when all input tensors are constant on this dim + if not node.input(idx).dim_tracks[dim].is_constant: continue if node.algorithm('dim').satisfy(idx=idx, dim=dim, num=len(devs)): _tp(graph, node, devs, idx, dim) break @@ -219,6 +227,8 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: + from nnscaler.autodist.util import get_default_profile_path + pas_cfg = cfg.pas_config update_freq = pas_cfg.get('update_freq', 1) @@ -266,18 +276,24 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: use_memory_efficient_bf16 = pas_cfg.get('use_memory_efficient_bf16', False) use_fp16 = pas_cfg.get('use_fp16', use_memory_efficient_fp16) use_bf16 = pas_cfg.get('use_bf16', use_memory_efficient_bf16) + profile_dir = pas_cfg.get('profile_dir', None) + if profile_dir is None: + profile_dir = get_default_profile_path() re_profile = pas_cfg.get('re_profile', False) verbose = pas_cfg.get('verbose', False) load_plan_path = pas_cfg.get('load_plan_path', None) save_plan_path = pas_cfg.get('save_plan_path', None) partition_constraints_path = pas_cfg.get('partition_constraints_path', '') recompute_modules = pas_cfg.get('recompute_modules', '') + recompute_ratio = pas_cfg.get('recompute_ratio', 1.0) pipeline_pivots = pas_cfg.get('pipeline_pivots', '') max_pipeline_bubble_ratio = pas_cfg.get('max_pipeline_bubble_ratio', 0.2) max_pipeline_unbalance_ratio = pas_cfg.get('max_pipeline_unbalance_ratio', 0.5) use_apex_fused_adam_v2 = pas_cfg.get('use_apex_fused_adam_v2', False) parallel_profile = pas_cfg.get('parallel_profile', True) transient_mem_coef = pas_cfg.get('transient_mem_coef', 2) + disable_shared_param_constraint = pas_cfg.get('disable_shared_param_constraint', False) + solver = pas_cfg.get('solver', 'dp') task_name = f'{task_name}_{cfg.plan_ngpus}gpus_{update_freq}update_freq' if memory_constraint == -1: @@ -332,8 +348,10 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: opt_transient_coef=opt_transient_coef, verbose=verbose, re_profile=re_profile, + profile_dir=profile_dir, world_size=cfg.runtime_ngpus, recompute_modules=recompute_modules, + recompute_ratio=recompute_ratio, zero_stage=zero_stage, zero_ngroups=zero_ngroups, load_plan_path=load_plan_path, @@ -344,6 +362,503 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: max_pipeline_unbalance_ratio=max_pipeline_unbalance_ratio, parallel_profile=parallel_profile, transient_mem_coef=transient_mem_coef, + disable_shared_param_constraint=disable_shared_param_constraint, + solver=solver, ) return parallelize_graph(graph, autodist_cfg) + + +@dataclass(unsafe_hash=True, frozen=True) +class OpPartition: + """ + OpPartition represents a partition plan for an operator dimension. + """ + input: int + dim: int + + +@dataclass +class OpPlan: + """ + OpPlan represents the distributed plan for an operator. + """ + op: IRFwOperation + recompute_id: int = -1 # -1 means no recompute + stage_id: int = -1 # pipeline stage id, -1 means following the previous op's stage + + # user defined meta data for hooks + # which will be passed to the pre_hook and post_hook functions + # Note: Only types that can be safely `repr`-ed can be used here. (e.g., str, int, float, tuple, list, dict) + hook_meta: Any = None + + # function to be called before the op is executed + # which will be inserted in the runtime code before the op call. + # op's inputs will be passed to the hook. + # The signature will be like + # def pre_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None: + pre_hook: Optional[Callable[['ParallelModule', Any, tuple[Any, ...], dict[str, Any]], None]] = None + + # function to be called after the op is executed + # which will be inserted in the runtime code after the op call. + # op's inputs and outputs will be passed to the hook. + # the signature will be like + # def post_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any], output: Any) -> None: + post_hook: Optional[Callable[['ParallelModule', Any, tuple[Any, ...], dict[str, Any], Any], None]] = None + + # OpPartition: user specified partition plan + # You only need to specify one partition plan here. + # For example, torch.matmul has annotation of `m k+, k+ n -> m n`, + # If you want to partition the matmul on the k dimension, + # you can set OpPartition(input=0, dim=1) or OpPartition(input=1, dim=0). + # They are equivalent. + # None: replicated + # 'auto': auto partition based on the input tensor partition info + # 1. if any of the input tensors is value partitioned, we replicate the op + # TODO: is it too strict? + # 2. if any of the input tensors is partitioned on a dim, + # we will try to partition the op on the same dim first, + # if the partition is invalid, we replicate the op + # 3. if all the input tensor is replicated, we replicate the op + partition: OpPartition | None | Literal['auto'] = None # partition plan + # for future extension + # don't use it now. + partitions: List[OpPartition | None] = field(default_factory=list) # multiple partition plans + + def __post_init__(self): + if self.partition is not None and len(self.partitions) > 0: + raise ValueError("Only one of partition and partitions can be set") + + if len(self.partitions) > 1: + raise NotImplementedError("Multiple partitions are not supported yet") + + if len(self.partitions) == 1: + self.partition = self.partitions[0] + self.partitions = [] + + +def get_layer_index(fqn: str) -> int: + """ + Extract the layer index from full qualified name. + If there are multiple integers in the name, raise ValueError. + """ + nums = [int(s) for s in fqn.split(".") if s.isdigit()] + if len(nums) != 1: + raise ValueError(f"Name {fqn} should only contain one integer") + return nums[0] + + +def get_called_self_module_name(node_call_expr: str) -> str: + """ + Get the called module name from the node's call expr by ast. + For example: + self.up_proj(x) -> up_proj + self.act_fn(self.gate_proj(x)) -> act_fn + self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -> down_proj + torch.tanh(x) -> '' # because it's not called from self + self.up_proj(x).transpose() -> '' # because it's an attribute call + + Other cases return empty string. + + NOTE: regex is not easy to make it work + + """ + + if not node_call_expr: + return '' + call_expr: ast.Call = ast.parse(node_call_expr, mode='eval').body # type: ignore + if isinstance(call_expr, ast.Call): # self.up_proj(x) + if isinstance(call_expr.func, ast.Attribute): # self.up_proj + if isinstance(call_expr.func.value, ast.Name) and call_expr.func.value.id == 'self': + return call_expr.func.attr # up_proj + return '' + + +def get_pas_ops(graph: IRGraph) -> List[IRFwOperation]: + """ + Get all operators in the graph that can set operator plan. + When we write a policy, only ops returned from this function need to be considered. + + Args: + graph: the input IRGraph + + Returns: + List[IRFwOperation]: list of IRFwOperation nodes + """ + return graph.select(ntype=IRFwOperation) + + +def fn( + graph: IRGraph, cfg: 'ComputeConfig', + policy: Union[ + Callable[[IRGraph, 'ComputeConfig'], IRGraph], + Callable[[IRGraph, 'ComputeConfig'], Iterable[OpPlan]], + ] +) -> IRGraph: + """ + General policy function based on user-defined policy. + The user-defined policy can either return the final IRGraph, or + return a list of OpPlan to describe the distributed plan for each operator. + + To write a new-style policy, the most important part is to locate the operator node in the graph. + Here are some tips: + 1. use `node.name` to get the operator name. + 2. use `node.fn` to get the operator function. + 3. use `node.module_stack` to get the module stack info. + 4. use `node.module_class_chain` to get the module class chain. + 5. use `node.call_expr` to get the call expression string. And you can user `ast.parse` to parse it. + 6. use `get_layer_index` to get the layer index in a torch.nn.ModuleList. + 7. use `get_called_self_module_name` to get the called self module name from the call expression. + 8. use `node.inputs()` the get the input tensors of the operator. + We can further check whether the input tensor is a parameter by `tensor.is_param`, + or get the full name of the parameter by `tensor.name`, etc. + 9. insert anchors in code with `nnscaler.anchor` to help locate the operator (intrusive way). + + A good way to locate the operator will be like: + 1. Locate the module first by module_class_chain (`target_module in node.module_class_chain`) + 2. If the module are used multiple times (e.g., in ModuleList), + locate further by layer index (`get_layer_index`) or `node.fqn`. + 3. Once the module is located, + we can further locate the operator by + `node.name`,`node.call_expr`, `node.fn`, `node.inputs()` (especially the `is_param`/`name` of input) + or other properties. + + Args: + graph: the input IRGraph + cfg: the compute config + policy: the user-defined policy function. It can either return the final IRGraph, + or return an iterable of OpPlan for each operator. + + Returns: + the distributed IRGraph + """ + result = policy(graph, cfg) + if isinstance(result, IRGraph): # traditional policy + return result + + op_plans = {r.op: r for r in result} + ngpus: int = cfg.plan_ngpus + + recompute_groups: dict[int, list[IRFwOperation]] = {} + recompute_last_id: int = -1 + recompute_group_stages: dict[int, int] = {} + + pp_stages: list[list[IRFwOperation]] = [[]] + pp_cur_stage_id = 0 + + # key: IRFullTensor + # value: + # key: stage_id + # value: set of OpPartition in this stage + tensor_splits: dict[IRFullTensor, dict[int, set[OpPartition]]] = {} + # store the last split info for each tensor to help handle auto partition + # None: replicated + # 'value': value partitioned + # int: the partitioned dim + output_tensor_last_split: dict[IRFullTensor, int | None | Literal['value']] = {} + + fw_nodes = dict.fromkeys(graph.select(ntype=IRFwOperation)) + + for node in fw_nodes: + if node not in op_plans: + op_plans[node] = OpPlan(op=node) # default: no partition, stage 0, no recompute + + node.hook_meta = op_plans[node].hook_meta + node.pre_hook = op_plans[node].pre_hook + node.post_hook = op_plans[node].post_hook + + op_plan = op_plans[node] + + # set pipeline stage id if not set + if op_plan.stage_id == -1: + op_plan.stage_id = pp_cur_stage_id + + # currently we only support partition for IRDimops + if not isinstance(op_plan.op, IRDimops): + if op_plan.partition == 'auto': + op_plan.partition = None + if op_plan.partition is not None: + raise ValueError("Only IRDimops can be partitioned.") + + # list of partitions for the op + # [] means no partition(replicated) + op_partitions = [op_plan.partition] if op_plan.partition is not None else [] + + if op_partitions == ['auto']: + # auto partition based on input tensor partition info + op_partitions = [] # reset to collect partitions + for idx, input in enumerate(op_plan.op.inputs()): + if not isinstance(input, IRSubTensor): + continue + ftensor = input.parent + last_partition_dim = output_tensor_last_split.get(ftensor, None) + if last_partition_dim == 'value': + # value partitioned input, replicate the op + op_partitions = [] + break + elif last_partition_dim is not None: + op_partitions.append(OpPartition(input=idx, dim=last_partition_dim)) + + # final partition plan for the op + # key: input idx, value: partitioned dim + op_partition_map: dict[int, int] = {} + if op_partitions: + # we partition the op based on the first partition plan + # and then check the rest partitions are satisfied or not + op_first_partition = op_partitions[0] + partitioned_nodes = op_plan.op.algorithm('dim')\ + .instantiate(idx=op_first_partition.input, dim=op_first_partition.dim, num=ngpus) + subnode = partitioned_nodes[0] # first subnode carries all necessary partition info + + # collect input partition info + # key: input idx, value: partitioned dim + result_partitions: dict[int, int] = {} + for idx, input in enumerate(subnode.inputs()): + if not isinstance(input, IRSubTensor): + continue + split_dims = input.splitdims() + assert len(split_dims) <= 1, "Internal Error: multiple splitdims in one input" + if split_dims: + result_partitions[idx] = split_dims[0] + + # check the rest partitions + # Note if we only have one partition plan, the check is skipped, we can always partition it + # In fact, if `auto` is not specified, we always have at most one partition plan + for op_partition in op_partitions[1:]: + if op_partition.input not in result_partitions or \ + result_partitions[op_partition.input] != op_partition.dim: + _logger.warning( + f"Operator {op_plan.op} cannot be partitioned as specified: {op_partition}" + f", replicate it instead." + ) + op_partitions = [] + op_partition_map = {} + break + else: + # all partitions are satisfied + # then we can update input/output partition info + + # make sure the first item in op_partition_map is the first partition plan + op_partition_map[op_first_partition.input] = op_first_partition.dim + op_partition_map.update(result_partitions) + + for output in subnode.outputs(): + if not isinstance(output, IRSubTensor): + continue + ftensor = output.parent + if output.valmap != (0, 1): + output_tensor_last_split[ftensor] = 'value' + else: + split_dims = output.splitdims() + assert len(split_dims) <= 1, "Internal Error: multiple splitdims in one output" + if split_dims: + output_tensor_last_split[ftensor] = split_dims[0] + + if op_plan.partition == 'auto': + if not op_partition_map: + op_plan.partition = None + else: + # use the first partition plan, + # which is consistent with the logic above + first_input_idx = list(op_partition_map.keys())[0] + op_plan.partition = OpPartition( + input=first_input_idx, + dim=op_partition_map[first_input_idx] + ) + + # update tensor_splits for input tensors + for idx, input in enumerate(op_plan.op.inputs()): + if not isinstance(input, IRSubTensor): + continue + ftensor = input.parent + if ftensor not in tensor_splits: + tensor_splits[ftensor] = {} + if idx not in op_partition_map: + tensor_splits[ftensor].setdefault(op_plan.stage_id, set()).add(None) + else: + tensor_splits[ftensor].setdefault(op_plan.stage_id, set()).add( + OpPartition(input=idx, dim=op_partition_map[idx])) + + if op_plan.recompute_id != -1: + if op_plan.recompute_id in recompute_group_stages: + if recompute_group_stages[op_plan.recompute_id] != op_plan.stage_id: + raise ValueError("All ops in a recompute group must be in the same stage") + else: + recompute_group_stages[op_plan.recompute_id] = op_plan.stage_id + + if op_plan.recompute_id != recompute_last_id and op_plan.recompute_id in recompute_groups: + raise ValueError("Nodes in a recompute group must be continuous.") + + recompute_groups.setdefault(op_plan.recompute_id, []).append(op_plan.op) + + recompute_last_id = op_plan.recompute_id + + # update pipeline stages + if op_plan.stage_id == pp_cur_stage_id: + pp_stages[pp_cur_stage_id].append(op_plan.op) + elif op_plan.stage_id == pp_cur_stage_id + 1: + pp_cur_stage_id += 1 + pp_stages.append([op_plan.op]) + else: + raise ValueError("Pipeline stage ids must be continuous integers starting from 0") + + if len(op_plans) != len(fw_nodes): + assert len(op_plans) > len(fw_nodes) + for op_plan in op_plans.values(): + if op_plan.op not in fw_nodes: + raise ValueError(f"OpPlan contains operator {op_plan.op} not in the graph or not a forward operator") + + pp_segs = [graph] + nstages = len(pp_stages) + pp_enabled = nstages > 1 + # not all schedulers support pp_size < nstages + pp_size = cfg.pas_config.get('pipeline_size', nstages) + nmicros = cfg.pas_config.get('pipeline_nmicros', None) + scheduler = cfg.pas_config.get('pipeline_scheduler', '1f1b') + tp_size = ngpus // pp_size + + if pp_enabled: + if not cfg.use_end2end: + raise ValueError("Pipeline parallelism requires use_end2end to be True") + if pp_size <= 1: + raise ValueError("pipeline_size must be greater than 1 when pipeline is enabled") + if not nmicros: + raise ValueError("nmicros must be set when pipeline is enabled") + if nstages % pp_size != 0: + raise ValueError(f'invalid pipeline_size {pp_size} for nstages {nstages}') + if ngpus % pp_size != 0: + raise ValueError(f'invalid pipeline_size {pp_size} for ngpus {ngpus}') + else: + if pp_size != 1: + raise ValueError("pipeline_size must be 1 when pipeline is disabled") + + # set recompute groups + for group in recompute_groups.values(): + if len(group) <= 1: + continue + graph.recompute(group) + + # add multiref for shared parameters across stages + # note that we have constrained that shared parameters cannot be partitioned in SPMDSolver, other input tensors + # belonging to the same operator can be partitioned. For example, in some LLMs, the embedding matrix is shared + # with the output layer. In this case, the batch dim / seq dim of the activation tensor can be partitioned. + for ftensor, stage_info in tensor_splits.items(): + if not ftensor.is_param(): + continue + splits = set(k.dim if k is not None else None for v in stage_info.values() for k in v) + find_replicated = None in splits + splits = list(splits) + # For safety, we will add multiref when detecting shared param are all replicated for pipeline parallelism. + # The reason is that stages may have different number of devices, it is hard to synchronize gradients directly + # by inserting reducers although weights are all REPLICAED. + if len(splits) > 1 or (pp_enabled and find_replicated): + _logger.info(f'add multiref for shared param {ftensor}') + graph.multiref(ftensor, comment='shared param') + + # set pipeline stages + if pp_enabled: + graph.staging([s[0] for s in pp_stages]) + pp_segs: list[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + + for stage_id, stage in enumerate(pp_segs): + for node in stage.select(ntype=IRFwOperation): + if node in fw_nodes: + continue + if node.fn == multiref: # skip multiref nodes + continue + assert node.fn == identity, "Internal Error: non-identity node added in staging" + # force identity nodes to be replicated + # these nodes are usually added for data transfer between stages in graph.staging + # TODO: is it possible to have TP here? + op_plans[node] = OpPlan(op=node, stage_id=stage_id, partition=None) + + # add multiref to an activation tensor when the states of the tensor and its grad are different + # among consumers and current segment's outputs + for ftensor, stage_info in tensor_splits.items(): + # Parameter are already handled above + if ftensor.is_grad() or ftensor.is_param(): + continue + + # check if this tensor is in the output of each stage + is_seg_output: dict[int, bool] = {} + for idx, stage in enumerate(pp_segs): + is_seg_output[idx] = IR.contains_object( + stage.outputs(), + lambda x: isinstance(x, IRSubTensor) and x.parent == ftensor + ) + + for idx, splits in stage_info.items(): + stage = pp_segs[idx] + split_list = list(splits) + if len(split_list) > 1 or ( + is_seg_output[idx] and split_list[0] is not None # treat segment output as a consumer + ): + _logger.debug(f'add multiref for {ftensor} in stage {stage}') + stage.multiref(ftensor, comment='activation') + + # stage-wise tensor parallelism + curr_devices = list(range(ngpus)) + for op_plan in op_plans.values(): + idx = op_plan.stage_id % pp_size + devs = curr_devices[idx * tp_size: (idx + 1)* tp_size] + if op_plan.partition is not None: + _tp(graph, op_plan.op, devs, idx=op_plan.partition.input, dim=op_plan.partition.dim) + else: + _replica(graph, op_plan.op, devs) + + # replicate dataloader + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs=list(range(ngpus))) + + if pp_enabled: + cfg.apply_pipeline_scheduler(graph, nstages, nmicros, scheduler) + + return graph + + +def pas_fsdp(graph, cfg: 'ComputeConfig'): + """ + A simple FSDP policy: + 1. all operators are replicated + 2. user specified modules with `cfg.pas_config.recompute_modules` are recomputed + 3. shard policy is configured in cfg.use_zero and cfg.zero_ngroups + 4. CPU offload is not supported + """ + if cfg.plan_ngpus != 1: + raise ValueError("FSDP policy only supports 1 plan GPU") + if not cfg.use_zero: + raise ValueError("FSDP policy requires use_zero to be 1/3") + # use 'recomputes' instead of 'recompute_modules' + # to avoid confliction with autodist config + recompute_modules = cfg.pas_config.get('recomputes', '') + # parse recompute_modules + # user can also provide a list of Module classes. + if isinstance(recompute_modules, str): + recompute_modules = recompute_modules.strip() + if not recompute_modules: + recompute_modules = [] + else: + recompute_modules = [m.strip() for m in recompute_modules.split(',')] + + if recompute_modules: + recompute_modules = [load_type(rm) for rm in recompute_modules] + else: + recompute_modules = [] + + cur_recompute_id = -1 + cur_recompute_module_fqn = None + for node in get_pas_ops(graph): + recompute_module: torch.nn.Module + for rm in recompute_modules: + if rm in node.module_class_chain: + recompute_module = rm + break + else: + cur_recompute_module_fqn = None + continue + + mod_fqn = node.get_module_fqn(recompute_module) + if cur_recompute_module_fqn is None or cur_recompute_module_fqn != mod_fqn: + cur_recompute_id += 1 + cur_recompute_module_fqn = mod_fqn + yield OpPlan(node, recompute_id=cur_recompute_id) diff --git a/nnscaler/runtime/__init__.py b/nnscaler/runtime/__init__.py index d0171757..46be9e99 100644 --- a/nnscaler/runtime/__init__.py +++ b/nnscaler/runtime/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from nnscaler.runtime import _patch_torch from nnscaler.runtime import executor from nnscaler.runtime import device from nnscaler.runtime import adapter diff --git a/nnscaler/runtime/_patch_torch.py b/nnscaler/runtime/_patch_torch.py new file mode 100644 index 00000000..53ab7438 --- /dev/null +++ b/nnscaler/runtime/_patch_torch.py @@ -0,0 +1,104 @@ +# The following code is copied from torch.distributed.distributed_c10d in PyTorch 2.4.0 +# For copyright, see pytorch/LICENSE +# https://github.com/pytorch/pytorch/blob/main/LICENSE + + +import torch +import torch.distributed + + +if torch.__version__ < (2, 4, 0): + # send_object_list and recv_object_list only available in PyTorch 2.4.0+ + + import torch.distributed.distributed_c10d as dist_c10d + + + if torch.__version__ < (2, 3, 0): + def _object_to_tensor(obj, device, group): + return dist_c10d._object_to_tensor(obj, device) + else: + def _object_to_tensor(obj, device, group): + return dist_c10d._object_to_tensor(obj, device, group) + + + if torch.__version__ < (2, 3, 0): + def _tensor_to_object(tensor, size, group): + return dist_c10d._tensor_to_object(tensor, size) + else: + def _tensor_to_object(tensor, size, group): + return dist_c10d._tensor_to_object(tensor, size, group) + + + def send_object_list(object_list, dst, group=None, device=None): + if torch.distributed.get_rank() == dst: + raise ValueError( + "Invalid destination rank: destination rank should not be the same as " + "the rank of the current process." + ) + + if dist_c10d._rank_not_in_group(group): + dist_c10d._warn_not_in_group("send_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # sent to this device. + current_device = device or torch.device("cuda", torch.cuda.current_device()) + # Serialize object_list elements to tensors on src rank. + tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + + # Send object sizes + torch.distributed.send(object_sizes_tensor, dst=dst, group=group) + + # Concatenate and send serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + + torch.distributed.send(object_tensor, dst=dst, group=group) + + + def recv_object_list(object_list, src=None, group=None, device=None): + if dist_c10d._rank_not_in_group(group): + dist_c10d._warn_not_in_group("recv_object_list") + return -1 + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # received to this device. + current_device = device or torch.device("cuda", torch.cuda.current_device()) + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + + # Receive object sizes + rank_sizes = torch.distributed.recv(object_sizes_tensor, src=src, group=group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device + ) + + rank_objects = torch.distributed.recv(object_tensor, src=src, group=group) + assert rank_sizes == rank_objects, "Mismatch in return ranks for object sizes and objects." + # Deserialize objects using their stored sizes. + offset = 0 + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + return rank_objects + + torch.distributed.send_object_list = send_object_list + torch.distributed.recv_object_list = recv_object_list diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index be83fdc3..bf795d43 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -3,6 +3,7 @@ from typing import List, Dict, Tuple, Any, Callable, Optional, Set, Sequence from functools import partial +from dataclasses import dataclass import math import logging import torch @@ -11,12 +12,13 @@ from nnscaler.runtime.device import DeviceGroup from nnscaler.profiler.timer import CudaTimer from nnscaler.flags import RuntimeFlag +from nnscaler.utils import unchecked_fields _logger = logging.getLogger(__name__) # According to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#device-memory-accesses -# Any address of a variable residing in global memory or returned by one of the memory allocation +# Any address of a variable residing in global memory or returned by one of the memory allocation # routines from the driver or runtime API is always aligned to at least 256 bytes. # But in our practice, we found that 16 bytes alignment is enough, it can be modified if unaligned access is detected. ALIGNED_BYTES = 16 @@ -59,15 +61,31 @@ def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: raise KeyError(f"Unsupported reduce op {reduce_op}. Supported reduce op: {supported}") +@dataclass +class _Z3ParamInfo: + shape: torch.Size # original shape of the parameter + start: int + end: int + param_buffer_start: int = -1 + param_buffer_end: int = -1 + + def numel(self) -> int: + return self.end - self.start + + def numel_with_padding(self) -> int: + return self.param_buffer_end - self.param_buffer_start + + class Bucket: - def __init__(self, params: List[torch.nn.Parameter], + def __init__(self, reducer: 'Reducer', params: List[torch.nn.Parameter], param_buffer: torch.Tensor, grad_buffer: torch.Tensor, reduce_op: torch.distributed.ReduceOp, - group: torch.distributed.ProcessGroup, async_op: bool, zero: bool, + group: torch.distributed.ProcessGroup, async_op: bool, zero: int, zero_subgroup: torch.distributed.ProcessGroup = None, zero_crossgroup: torch.distributed.ProcessGroup = None, zero_use_reduce_scatter: bool = False, align_size: int = ALIGNED_BYTES, + param_cls: Any = None, ): """ Create a communication unit for parameter allreduce. @@ -82,14 +100,17 @@ def __init__(self, params: List[torch.nn.Parameter], reduce_op (torch.distributed.ReduceOp): the reduce op used by collectives group (torch.distributed.ProcessGroup): communication group async_op (bool): whether to use asynchronous operation - zero (bool): whether to use zero optimization on gradients + zero (int): whether to use zero optimization on gradients, currently only 0/1/3 are supported + zero=2 will be treated as zero=3 zero_subgroup (torch.distributed.ProcessGroup): the subgroup for zero optimization the current rank belongs to zero_crossgroup (torch.distributed.ProcessGroup): the communication group for cross zero group allreduce when reduce scatter is enabled zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization align_size (int): the alignment size in bytes for each parameter + param_cls (Any): the class of the parameters """ self._params: List[torch.nn.Parameter] = params + self._param_cls: Any = param_cls self._pofset: Dict[torch.nn.Parameter, int] = {} self._reduce_op = reduce_op self._group = group @@ -99,7 +120,7 @@ def __init__(self, params: List[torch.nn.Parameter], self._hooks: List[Tuple[Any, RemovableHandle]] = [] self._async: bool = async_op - self._zero: bool = zero + self._zero: int = zero self._zero_use_reduce_scatter = zero_use_reduce_scatter self._contiguous_params = param_buffer self._contiguous_grads = grad_buffer @@ -123,6 +144,9 @@ def __init__(self, params: List[torch.nn.Parameter], self._pre_hooks: List[Callable] = [] self._post_hooks: List[Callable] = [] + self._z3 = self._zero > 1 + self._reducer = reducer + # only async will enable contiguous gradient self.build() self.register_hooks() @@ -137,11 +161,21 @@ def params(self) -> List[torch.nn.Parameter]: """Parameter list""" return self._params + @property + def param_cls(self) -> Any: + """Class of the parameters in the bucket""" + return self._param_cls + @property def zero(self) -> bool: """Whether enable zero for this bucket""" return self._zero + @property + def zero3(self) -> bool: + """Whether enable zero3 for this bucket""" + return self._z3 + def get_aligned_numel(self, param) -> int: """ Get the aligned number of elements for a parameter @@ -168,6 +202,22 @@ def _group_reduce_scatter(self): partial_tensor, self._contiguous_grads, op=self._reduce_op, group=self._zero_subgroup) + def _get_opt_param_data(self): + if not self._zero or self._zero > 1: + # when zero3 is used, the parameters are already sharded in reducer + opt = self._contiguous_params + else: + assert self._zero == 1 + rank = torch.distributed.get_rank(group=self._zero_subgroup) + assert len(self._contiguous_params) % self._zgroup_sz == 0 + # Note: + # There may be paddings both in the middle and at the end of the contiguous buffer + # When there are paddings in the middle or end of the contiguous buffer, + # the calculation of gnorm is not affected as long as the paddings are all 0. + # So for now, it looks harmless. + opt = self._contiguous_params.chunk(self._zgroup_sz)[rank] + return opt + def build(self): """ Build offset for each parameter @@ -179,18 +229,7 @@ def build(self): ofst += _aligned_nelement(param.nelement(), param.element_size(), self._align_size) # build parameter for optimizer (shared storage). # Its gradient will be updated everytime calling `self.sync_grads()` - if not self._zero: - opt = self._contiguous_params - else: - rank = torch.distributed.get_rank(group=self._zero_subgroup) - assert len(self._contiguous_params) % self._zgroup_sz == 0 - # Note: - # There may be paddings both in the middle and at the end of the contiguous buffer - # When there are paddings in the middle or end of the contiguous buffer, - # the calculation of gnorm is not affected as long as the paddings are all 0. - # So for now, it looks harmless. - opt = self._contiguous_params.chunk(self._zgroup_sz)[rank] - self._param_for_optimizer = torch.nn.Parameter(opt) + self._param_for_optimizer = torch.nn.Parameter(self._get_opt_param_data()) def register_hooks(self): """ @@ -205,13 +244,41 @@ def register_hooks(self): """ @torch.no_grad() - def post_grad_hook(param: torch.nn.Parameter, *unused): + def post_grad_hook(param: torch.nn.Parameter, *unused): # pragma: no cover # stream = DeviceGroup().get_stream('reducer') ofst = self._pofset[param] + rank = torch.distributed.get_rank() # TODO: need to handle sparse gradients in torch.nn.Embedding - self._contiguous_grads[ofst:ofst+param.numel()].add_(param.grad.data.view(-1)) + if self._z3: + z3_info = self._reducer.get_z3_info(param) + grad = param.grad.data.view(-1) + padded_numel = z3_info.numel_with_padding() * self._zgroup_sz + if grad.numel() < padded_numel: + # add padding + grad = torch.cat( + [grad, + torch.zeros(padded_numel - grad.numel(), device=grad.device, dtype=grad.dtype)] + ) + output = torch.zeros(z3_info.numel_with_padding(), device=grad.device, dtype=grad.dtype) + torch.distributed.reduce_scatter_tensor( + output, + grad, + op=self._reduce_op, + group=self._zero_subgroup + ) + # accumulate the param grad in zero3 way + self._contiguous_grads[ofst:ofst+z3_info.numel()]\ + .add_(output[0:z3_info.end-z3_info.start]) + else: + self._contiguous_grads[ofst:ofst+param.numel()].add_(param.grad.data.view(-1)) + param.grad = None + if self._z3: + # in most cases, it is not necessary to post-evict here, + # let's add it for safety + self._reducer.postevict_param(param) + if RuntimeFlag.skip_reducer: return self._async_param_cnt += 1 @@ -225,7 +292,9 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): # apply pre hooks self._apply_pre_hooks() # communication - if self._zero and self._zero_use_reduce_scatter: + if self._zero == 1 and self._zero_use_reduce_scatter: + # when zero3 is used, the parameters and gradients are already sharded in reducer + # so only allreduce is needed if self._zgroup_sz == self._wsz: rank = torch.distributed.get_rank(group=self._group) shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) @@ -236,9 +305,13 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): group=self._group, async_op=True) else: assert False, "group zero + reducescatter is not supported in async mode, " \ - "because the two steps (allreduce, reducescatter) use " \ - "two communication groups, which may induce deadlock." + "because the two steps (allreduce, reducescatter) use " \ + "two communication groups, which may induce deadlock." self._group_reduce_scatter() + elif self._zero > 1: + self._async_handle = torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, + group=self._zero_crossgroup, async_op=True) else: self._async_handle = torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, @@ -247,13 +320,24 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): for param in self._params: # same trick with FSDP and Megatron # reference: https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/fully_sharded_data_parallel.py#L3177-L3188 - param_tmp = param.expand_as(param) + if self._z3: + old_param_data = param.data + # here we need the full parameter to build the computation graph + # let's create a temporary parameter with full shape to fake it. + param.data = torch.empty(self._reducer.get_z3_info(param).shape, dtype=param.dtype, device=param.device) + param_tmp = param.expand_as(param) + param.data = old_param_data + else: + param_tmp = param.expand_as(param) + # gets its AccumulateGrad object grad_acc = param_tmp.grad_fn.next_functions[0][0] hook = grad_acc.register_hook(partial(post_grad_hook, param)) # grad_acc must keep, otherwise the hook won't take effect self._hooks.append((grad_acc, hook)) + torch.cuda.empty_cache() + def sync_grads(self): """ Wait until allreduce finished (async), or perform allreduce (sync). @@ -274,8 +358,14 @@ def sync_grads(self): # apply pre-hooks self._apply_pre_hooks() # synchrnoize gradients - if self._zero and self._zero_use_reduce_scatter: + if self._zero == 1 and self._zero_use_reduce_scatter: self._group_reduce_scatter() + elif self._zero > 1: + torch.distributed.all_reduce( + self._contiguous_grads, + op=self._reduce_op, + group=self._zero_crossgroup + ) else: torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, group=self._group) @@ -284,10 +374,16 @@ def sync_grads(self): for param in self._params: assert param.grad is None pofst = self._pofset[param] + if self._z3: + z3_info = self._reducer.get_z3_info(param) + # the param should have been evicted + assert z3_info.numel_with_padding() == param.numel() and len(param.shape) == 1, \ + f"internal error: zero3 param size mismatch, " \ + f"expect {[z3_info.numel_with_padding()]} got {param.shape}" param.grad = self._contiguous_grads[pofst:pofst+param.numel()].view(param.size()) # setup gradient for optimizer parameters - if self._zero: + if self._zero == 1: rank = torch.distributed.get_rank(group=self._zero_subgroup) grad = self._contiguous_grads.chunk(self._zgroup_sz, dim=0)[rank] self._param_for_optimizer.grad = grad @@ -301,7 +397,7 @@ def gather_params(self): """ All-gather parameters """ - assert self._zero, "gathering paramters is only for zero optimization." + assert self._zero == 1, "gathering paramters is only for zero1 optimization." rank = torch.distributed.get_rank(group=self._zero_subgroup) CudaTimer().start(field_name='comm', predefined=True) src_tensor = self._contiguous_params.chunk(self._zgroup_sz, dim=0)[rank] @@ -363,6 +459,81 @@ def reset(self): self._async_param_cnt = 0 self._async_handle = None + def sleep(self): + """ + release reference to contiguous buffer in reducer + """ + cpu = torch.device('cpu') + self._param_for_optimizer.data = self._param_for_optimizer.data.to(cpu) + # set none to release memory + self._contiguous_params = None + self._contiguous_grads = None + + def wake_up(self, param_buffer, grad_buffer): + """ + re-attach to the contiguous buffer and re-build hooks + """ + self._contiguous_params = param_buffer + self._contiguous_grads = grad_buffer + self._param_for_optimizer.data = self._get_opt_param_data() + + # TODO(yizhu1): seems moving attributes to cpu will make hooks invalid. + # The reason is that torch's autograd will reset the AccumulateGrad object if the data is set: + # https://github.com/pytorch/pytorch/blob/38a492d40d7ebb2856cb120df337c6cdac244528/torch/csrc/autograd/variable.cpp#L473 + # To make the resuming process safe, re-register them here. + self._hooks = [] + self.register_hooks() + + def _pack( + self, + param_map: dict[torch.nn.Parameter, torch.nn.Parameter], + ): + """ + Get the information of the bucket. + """ + state = self.__dict__.copy() + + fields = unchecked_fields(self) + state[fields._params] = [param_map[p] for p in self._params] + state[fields._pofset] = {param_map[p]: ofst for p, ofst in self._pofset.items()} + state[fields._param_for_optimizer] = torch.nn.Parameter(torch.empty_like(self._param_for_optimizer, device='meta')) + state[fields._contiguous_params] = torch.empty_like(self._contiguous_params, device='meta') + state[fields._contiguous_grads] = torch.empty_like(self._contiguous_grads, device='meta') + + # remove torch handles + state.pop(fields._group, None) + state.pop(fields._async_handle, None) + state.pop(fields._async_param_cnt, None) + state.pop(fields._zero_subgroup, None) + state.pop(fields._zero_crossgroup, None) + + # remove hooks + state.pop(fields._hooks, None) + state.pop(fields._pre_hooks, None) + state.pop(fields._post_hooks, None) + + # remove reducer reference + state.pop(fields._reducer, None) + + return state + + @classmethod + def _unpack(cls, state: dict, reducer: 'Reducer'): + """ + Return a fake bucket that carries the same information. + """ + bucket = object.__new__(cls) + bucket.__dict__.update(state) + bucket._reducer = reducer + + for param in bucket._params: + assert param.device.type == 'meta' + assert bucket._contiguous_grads.device.type == 'meta' + assert bucket._contiguous_grads.device.type == 'meta' + assert bucket._param_for_optimizer.device.type == 'meta' + + return bucket + class Reducer: # the default bucket cap for async reducer in megabytes @@ -370,11 +541,13 @@ class Reducer: # https://github.com/pytorch/pytorch/blob/4fd16dd8aa259cd75c9a6d2ddcd8171cd1ee8e28/torch/nn/parallel/distributed.py#L548 _DEFAULT_BUCKET_CAP_MB = 25 # 25MB, the same as pytorch - def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None, - reduce_op: str = 'sum', async_op: bool = False, - zero: bool = False, zero_ngroups: int = 1, - zero_use_reduce_scatter: bool = False, - align_size: int = ALIGNED_BYTES + def __init__(self, ranks: List[int], + *, + max_bucket_size_bytes: Optional[int] = None, + reduce_op: str = 'sum', async_op: bool = False, + zero: int = 0, zero_ngroups: int = 1, + zero_use_reduce_scatter: bool = False, + align_size: int = ALIGNED_BYTES, ): """ Create a reducer applied on a set of weights for weight reduction @@ -389,12 +562,15 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None Default is `None` reduce_op (str): reduce operation, can be 'sum', 'avg', 'max' or 'min' (default 'sum') async_op (bool): whether to overlap with backward computation (default False) - zero (bool): whether to apply ZeRO optimization on gradients + zero (int): whether to use zero optimization on gradients, currently only 0/1/3 are supported + zero=2 will be treated as zero=3 zero_ngroups (int): number of ZeRO subgroups in the original ZeRO group zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization align_size (int): the alignment size in bytes for each parameter """ + # the parameters with same class will be consecutive in the list. self._params: List[torch.nn.Parameter] = list() + self._param_clss: Dict[torch.nn.Parameter, Any] = dict() # the class of each parameter, used for sorting self._param_ids: Set[int] = set() self._numel: int = 0 self._ranks = ranks @@ -409,7 +585,7 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None # buckets stands for a transission unit self._buckets: List[Bucket] = list() self._async: bool = async_op - self._zero: bool = zero + self._zero: int = int(zero) self._zero_use_reduce_scatter = zero_use_reduce_scatter self._align_size: int = align_size if self._align_size % ALIGNED_BYTES != 0: @@ -419,6 +595,13 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None self._contiguous_params: torch.Tensor = None self._contiguous_grads: torch.Tensor = None + # record following variables for params offload + # items in the bucket is params list + self.seq_buckets: List[List[torch.nn.Parameter]] = [] + # bucket start and stop pos in buffer + self.starts, self.stops = [], [] + self.buffer_length: int = 0 + # build the subgroup of zero the current rank belongs to. # When zero_ngroups is larger than 1, the number of ranks # will be divided by zero_ngroups into sub rank groups, @@ -454,9 +637,18 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None else: assert zero_ngroups == 1, f"ZeRO number of groups must be 1, but got {zero_ngroups}" self._zero_subgroup = self._group - self._zero_crossgroup = None + # trivial crossgroup for single rank + self._zero_crossgroup = DeviceGroup().get_group([torch.distributed.get_rank()]) + self._zero_ngroups = zero_ngroups + self._z3_size = torch.distributed.get_world_size(group=self._zero_subgroup) + if self._z3_size == 1: + self._zero = 0 # disable zero when only one rank in subgroup + self._z3 = self._zero > 1 + self._z3_rank = torch.distributed.get_rank(group=self._zero_subgroup) + self._z3_params_info: dict[torch.nn.Parameter, _Z3ParamInfo] = dict() + @property def zero_ngroups(self) -> int: return self._zero_ngroups @@ -479,6 +671,11 @@ def zero(self) -> bool: """Whether to apply zero optimization on gradients""" return self._zero + @property + def zero3(self) -> bool: + """Whether to apply ZeRO3""" + return self._zero > 1 + @property def buckets(self) -> Tuple[Bucket, ...]: return tuple(self._buckets) @@ -506,15 +703,73 @@ def add_param(self, param: torch.nn.Parameter): self._param_ids.add(param.data.data_ptr()) self._numel += param.numel() - def build_buckets(self): + def _allocate_buffers(self): + # gradient buffer + self._contiguous_grads: torch.Tensor = torch.zeros( + (self.buffer_length,), dtype=self._params[0].dtype, + device=torch.cuda.current_device(), requires_grad=False) + # parameter buffer + self._contiguous_params: torch.Tensor = torch.zeros( + (self.buffer_length,), dtype=self._params[0].dtype, + device=torch.cuda.current_device(), requires_grad=False) + + def _bind_params(self): + for params, start, stop in zip(self.seq_buckets, self.starts, self.stops): + # replace underlying parameter content using shared storage from parameter + ofst = start + for param in params: + with torch.no_grad(): + self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) + if self._z3: + param.data = self._contiguous_params[ofst:ofst+param.numel()] + self._z3_params_info[param].param_buffer_start = ofst + self._z3_params_info[param].param_buffer_end = ofst + param.numel() + else: + param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) + aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) + ofst += aligned_nelements + + def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None): """ Build buckets the reducer. - The parameters in each bucket have consistent data types, + The parameters in each bucket have consistent data types and classes, and each bucket contains at least one parameter. If the bucket contains more than 2 parameters, than the total size is samller than the max_bucket_size_bytes. """ + self._param_clss = {} + if param_clss: + # only keep parameters that are in self._params + self._param_clss = {p: param_clss[p] for p in self._params} + # sort parameters by their class + # which can help bucket building + self._params.sort(key=lambda p: self._param_clss[p]) + + # step 0: param split for zero3 + if self._z3: + for param in self._params: + if not param.requires_grad: + continue + + chunk_size = (param.numel() + self._z3_size - 1) // self._z3_size + start = self._z3_rank * chunk_size + end = min(start + chunk_size, param.numel()) + self._z3_params_info[param] = _Z3ParamInfo(shape=param.shape, start=start, end=end) + # clone the data so original param can be released + # this padding is required + # to make sure all ranks in the zero subgroup have the same bucket layout. + if end - start < chunk_size: + padding = chunk_size - (end - start) + param.data = torch.cat([ + param.data.view(-1)[start:end].clone(), + torch.zeros(padding, dtype=param.dtype, device=param.device) + ], dim=0) + else: + param.data = param.data.view(-1)[start:end].clone() + + torch.cuda.empty_cache() + # step 1: build bucket for overlapping gradient synchronization # self._numel * 8 + 1 here is to make sure # the bucket size is larger than the total size of all parameters @@ -526,9 +781,9 @@ def build_buckets(self): # (used in pytorch, with a couple percentage improvement) bucket_size = self._numel * 8 + 1 if not self._bucket_size else self._bucket_size - # items in the bucket is params list - seq_buckets: List[List[torch.nn.Parameter]] = [] + seq_buckets_cls: List[Any] = [] last_bucket_size = None + last_bucket_cls = None assert len(set(p.dtype for p in self._params)) == 1, ( "All parameters in the reducer should have the same data type" @@ -540,53 +795,45 @@ def build_buckets(self): # It will go the `else` branch # and finish the current bucket and start a new bucket. # This new bucket will be sealed in the next iteration - if len(seq_buckets) == 0: - seq_buckets.append([param]) + if len(self.seq_buckets) == 0: + self.seq_buckets.append([param]) last_bucket_size = cur_byte_size - elif last_bucket_size + cur_byte_size <= bucket_size: - seq_buckets[-1].append(param) + last_bucket_cls = self._param_clss.get(param, None) + seq_buckets_cls.append(last_bucket_cls) + elif last_bucket_size + cur_byte_size <= bucket_size \ + and last_bucket_cls == self._param_clss.get(param, None): + self.seq_buckets[-1].append(param) last_bucket_size += cur_byte_size else: - seq_buckets.append([param]) + self.seq_buckets.append([param]) last_bucket_size = cur_byte_size + last_bucket_cls = self._param_clss.get(param, None) + seq_buckets_cls.append(last_bucket_cls) # step 2: build meta data for the offset of each bucket # the start of each bucket will be padded to the next multiple of `len(self.ranks)` - buffer_length: int = 0 - starts, stops = [], [] - for params in seq_buckets: - starts.append(buffer_length) + for params in self.seq_buckets: + self.starts.append(self.buffer_length) numel = sum(_aligned_nelement(p.nelement(), p.element_size(), self._align_size) for p in params) # this pad is for zero, which needs numels in each Bucket can be divided by the number of ranks in this group * _align_size # so that each chunck during zero can be divided by _align_size align_nelements = self._align_size // params[0].element_size() * len(self._ranks) padding = (align_nelements - numel % align_nelements) % len(self._ranks) - buffer_length += numel + padding - stops.append(buffer_length) + self.buffer_length += numel + padding + self.stops.append(self.buffer_length) - # step3: allocate memory - # gradient buffer - self._contiguous_grads: torch.Tensor = torch.zeros( - (buffer_length,), dtype=self._params[0].dtype, - device=torch.cuda.current_device(), requires_grad=False) - # parameter buffer - self._contiguous_params: torch.Tensor = torch.zeros( - (buffer_length,), dtype=self._params[0].dtype, - device=torch.cuda.current_device(), requires_grad=False) + # step 3: allocate memory + self._allocate_buffers() + + # step 4: bind parameters + self._bind_params() - # step 4: build buckets + # step 5: build buckets buckets: List[Bucket] = [] - for params, start, stop in zip(seq_buckets, starts, stops): - # replace underlying parameter content using shared storage from parameter - ofst = start - for param in params: - with torch.no_grad(): - self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) - param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) - aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) - ofst += aligned_nelements + for params, param_cls, start, stop in zip(self.seq_buckets, seq_buckets_cls, self.starts, self.stops): # initialize buckets bucket = Bucket( + self, params, self._contiguous_params[start:stop], self._contiguous_grads[start:stop], @@ -598,6 +845,7 @@ def build_buckets(self): self._zero_crossgroup, self._zero_use_reduce_scatter, self._align_size, + param_cls=param_cls, ) buckets.append(bucket) torch.cuda.empty_cache() @@ -617,12 +865,58 @@ def sync_grads(self): for bucket in self._buckets: bucket.sync_grads() + def get_z3_info(self, param: torch.nn.Parameter) -> _Z3ParamInfo: + """ + Get zero3 param info + if the param is not in zero3, return None + """ + return self._z3_params_info.get(param, None) + + @torch.no_grad() + def prefetch_param(self, param: torch.nn.Parameter): + """Prefetch parameter before forward and backward. + + This is required when zero3 is used. + """ + if not self._z3: + raise RuntimeError("postevict_param is only for zero3 optimization.") + if param not in self._z3_params_info: + raise ValueError(f"parameter {param} not found in zero3 params info.") + + info = self._z3_params_info[param] + if param.shape == info.shape: + # no need to gather + return + + full_data = torch.zeros(info.numel_with_padding() * self._z3_size, dtype=param.dtype, + device=torch.cuda.current_device()) + torch.distributed.all_gather_into_tensor( + full_data, + param.data, + group=self._zero_subgroup + ) + param.data = full_data[0:math.prod(info.shape)].view(info.shape).contiguous() + + @torch.no_grad() + def postevict_param(self, param: torch.nn.Parameter): + """Release parameter after forward and backward. + + This is required when zero3 is used. + """ + if not self._z3: + raise RuntimeError("postevict_param is only for zero3 optimization.") + if param not in self._z3_params_info: + raise ValueError(f"parameter {param} not found in zero3 params info.") + info = self._z3_params_info[param] + param.data = self._contiguous_params[info.param_buffer_start:info.param_buffer_end] + def gather_params(self): """Gather parameters with Zero optimizations after `optimizer.step()`. This is required when zero optimization is turned on. """ if not self._zero: return + if self._z3: return # in zero3 mode, no need to gather params for bucket in self._buckets: bucket.gather_params() @@ -652,9 +946,23 @@ def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: Returns: List[torch.nn.Parameter]: parameters for optimizer """ - params = [] + return list(self.get_opt_params().keys()) + + def get_opt_params(self) -> dict[torch.nn.Parameter, Any]: + """ + Get parameters and their classes for optimizers + Please note for ZeRO optimization, + the returned parameters are not the same as the original parameters, + and can have paddings (with value 0.0) both at the end and in the middle of paramters data. + + the calculation of gnorm is not affected as paddings are all 0. + + Returns: + List[torch.nn.Parameter]: parameters for optimizer + """ + params = {} for bucket in self._buckets: - params.append(bucket._param_for_optimizer) + params[bucket._param_for_optimizer] = bucket.param_cls return params def broadcast_params(self): @@ -723,3 +1031,82 @@ def clear_post_hooks(self): """Clear all post hooks.""" for bucket in self._buckets: bucket.clear_post_hooks() + + def sleep(self): + """ + release contiguous buffers on the device to save memory + """ + for bucket in self._buckets: + bucket.sleep() + + self._contiguous_params = None + self._contiguous_grads = None + + def wake_up(self): + """ + reallocate contiguous buffers and related objects + """ + self._allocate_buffers() + self._bind_params() + + for start, stop, bucket in zip(self.starts, self.stops, self._buckets): + bucket.wake_up( + self._contiguous_params[start:stop], + self._contiguous_grads[start:stop], + ) + + def _pack( + self, + param_map: dict[torch.nn.Parameter, torch.nn.Parameter], + ): + """ + Get the information of the bucket. + """ + state = self.__dict__.copy() + fields = unchecked_fields(self) + + state[fields._params] = [param_map[p] for p in self._params] + state[fields._z3_params_info] = {param_map[p]: info for p, info in self._z3_params_info.items()} + state[fields._param_clss] = {param_map[p]: param_cls for p, param_cls in self._param_clss.items()} + state[fields._contiguous_params] = torch.empty_like(self._contiguous_params, device='meta') + state[fields._contiguous_grads] = torch.empty_like(self._contiguous_grads, device='meta') + + state[fields._buckets] = [ + bucket._pack(param_map) + for bucket in self._buckets + ] + + # remove torch handles + state.pop(fields._group, None) + state.pop(fields._zero_subgroup, None) + state.pop(fields._zero_crossgroup, None) + + # remove unuseful information + state.pop(fields._param_ids, None) + state.pop(fields.seq_buckets, None) + + return state + + @classmethod + def _unpack(cls, state: dict): + """ + Return a fake bucket that carries the same information. + """ + reducer = object.__new__(cls) + fields = unchecked_fields(reducer) + + buckets = state.pop(fields._buckets) + reducer._buckets = [ + Bucket._unpack(bucket, reducer) for bucket in buckets + ] + reducer.__dict__.update(state) + for param in reducer._params: + assert param.device.type == 'meta' + + for param in reducer._param_clss.keys(): + assert param.device.type == 'meta' + + assert reducer._contiguous_grads.device.type == 'meta' + assert reducer._contiguous_params.device.type == 'meta' + + return reducer diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py index 414361a5..908a4f3e 100644 --- a/nnscaler/runtime/f16_optimizer.py +++ b/nnscaler/runtime/f16_optimizer.py @@ -32,6 +32,10 @@ def __init__(self, *args, **kwargs): # forward __init__ call to the next class in mro(method resolution order) super().__init__(*args, **kwargs) self._multiply_factor = 1.0 + # This flag is used to indicate whether fp32_params are loaded from checkpoint. + # If not, we will sync from fp16 params to fp32 params in after_load_checkpoint. + # If the model is trained from scratch, this flag will be None. + self._fp32_params_loaded = None def after_setup(self, trainer: 'Trainer') -> None: """ @@ -111,12 +115,15 @@ def load_state_dict(self, state_dict): param.data = state_dict['state'][i]['fp32_params'].data.to(device) # pop to avoid store a redundant copy in the wrapped optimizer state_dict['state'][i].pop('fp32_params') + else: + logger.warning('fp32_params not found in state_dict, will sync from fp16 params to fp32 params') + self._sync_fp16_params_to_fp32() - if len(self.param_groups) != 1: - raise RuntimeError('only support one param group') - self.param_groups[0]['params'] = self.fp32_params + if len(self.param_groups) != 1: + raise RuntimeError('only support one param group') super().load_state_dict(state_dict) + self._fp32_params_loaded = True def _sync_f16_grads_to_fp32(self): # copy FP16 grads to FP32 @@ -148,10 +155,15 @@ def _sync_fp16_params_to_fp32(self): continue p32.data.copy_(p.data) + def on_load_checkpoint(self, trainer, checkpoint) -> None: + self._fp32_params_loaded = False + logger.info('Set _fp32_params_loaded to False in on_load_checkpoint hook') + def after_load_checkpoint(self, trainer, checkpoint) -> None: - if 'nnscaler' not in checkpoint: - # this checkpoint is not created by nnscaler. + if not self._fp32_params_loaded: + logger.info('fp32_params not loaded, will sync from fp16 params to fp32 params') self._sync_fp16_params_to_fp32() + self._fp32_params_loaded = True def overrided_scale_grads(self, scale: float): """ diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 51ef947c..cba44779 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -8,7 +8,7 @@ """ from contextlib import contextmanager -from typing import Optional, List, Tuple, Union, Any +from typing import Callable, Optional, List, Tuple, Union, Any import torch import torch.nn.functional as TorchF import operator @@ -81,11 +81,24 @@ def fold_constant(a: Any) -> Any: return a -def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: +def multiref(tensor: torch.Tensor, times: int, *, clone_level: int = 0) -> Union[torch.Tensor, Tuple[torch.Tensor]]: """ identity forward. Create multiple same tensor. + Args: + tensor (torch.Tensor): input tensor + times (int): number of same tensor to create + clone_level (int): 0: no clone, 1: clone once for all, 2: clone each time + Returns: + Union[torch.Tensor, Tuple[torch.Tensor]]: + if times==1, return tensor; else return tuple of tensors """ - return tensor if times == 1 else tuple([tensor] * times) + if clone_level == 0: + return tensor if times == 1 else tuple([tensor] * times) + elif clone_level == 1: + cloned_tensor = tensor.clone() + return cloned_tensor if times == 1 else tuple([cloned_tensor] * times) + else: # clone_level == 2 + return tensor.clone() if times == 1 else tuple([tensor.clone() for _ in range(times)]) def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) -> torch.Tensor: @@ -366,4 +379,24 @@ def print_time(content: str): rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1 if torch.cuda.is_available(): torch.cuda.synchronize() - print(f"line timer: {rank} - {datetime.datetime.now()} - {content}") \ No newline at end of file + print(f"line timer: {rank} - {datetime.datetime.now()} - {content}") + + +class _BackwardHook(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, backward_hook: Callable[[], None]): + ctx.save_for_backward() + ctx.backward_hook = backward_hook + return x + + @staticmethod + def backward(ctx, grad_output): + ctx.backward_hook() + return grad_output, None + + +def insert_backward_hook(x: torch.Tensor, backward_hook: Optional[Callable[[], None]]) -> torch.Tensor: + if backward_hook is None: + # no need to add hook + return x + return _BackwardHook.apply(x, backward_hook) diff --git a/nnscaler/runtime/gnorm.py b/nnscaler/runtime/gnorm.py index eb6a3e5b..36d46c38 100644 --- a/nnscaler/runtime/gnorm.py +++ b/nnscaler/runtime/gnorm.py @@ -40,7 +40,7 @@ class TidReplicaInfo: def _calc_grad_shape(slicers_list): - # caculate the shape of each full parameters/grads + # calculate the shape of each full parameters/grads tid2shape = {} for rank_slicers in slicers_list: for tid, slicers in rank_slicers.items(): @@ -50,7 +50,7 @@ def _calc_grad_shape(slicers_list): # slicer: (start, end, step) if slicer.stop > tid2shape[tid][i]: tid2shape[tid][i] = slicer.stop - # caculate the number of replicas of each model parameter + # calculate the number of replicas of each model parameter tid2nreplicas = {} for rank_slicers in slicers_list: for tid, slicers in rank_slicers.items(): @@ -117,7 +117,7 @@ def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int Returns: tid2nreplicas: dict, tid -> TidReplicaInfo """ - # caculate the number of replicas of each model parameter + # calculate the number of replicas of each model parameter tid2nreplicas = {} tid2ranksset = defaultdict(set) for tid2ranks in tid2ranks_list: @@ -135,7 +135,7 @@ def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int return tid2nreplicas -def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, List[torch.nn.Parameter]]: +def prepare_for_grad_clip(cube_model: 'CubeModule', use_zero: int) -> Dict[int, List[torch.nn.Parameter]]: params_info_for_gnorm = cube_model.parameters_for_calc_gnorm() tid2ranks = {} tid2info_list_seq = {} @@ -174,7 +174,7 @@ def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, # multiplied by the number of ZeRO groups. Multiplying the number of pure replicated is easy # to understand. Multiplying the number of ZeRO groups is because the gradients of each ZeRO group # are full model gradients, so the number of ZeRO groups is the number of gradient replicas of the full model. - if not is_zero: + if not use_zero: nreplicas = replicated_info.nranks else: nreplicas = replicated_info.nreplicated * params_info.zero_ngroups @@ -241,7 +241,8 @@ def grad_exists(p): elif len(grads) == 1: total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) else: - if multi_tensor_l2norm_available: + dtypes = set([g.dtype for g in grads]) + if multi_tensor_l2norm_available and len(dtypes) == 1: total_norm = _multi_tensor_total_norm(grads).to(device) else: # torch.nn.utils.clip_grad_norm_ way to calculate the norm diff --git a/nnscaler/runtime/hybrid_optimizer.py b/nnscaler/runtime/hybrid_optimizer.py new file mode 100644 index 00000000..e120c3df --- /dev/null +++ b/nnscaler/runtime/hybrid_optimizer.py @@ -0,0 +1,304 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, Iterable, Type, Union + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.hooks import RemovableHandle + +from nnscaler.cli.arg_parser import deserialize_dataclass +from nnscaler.cli.train_hook import TrainHookHost, TrainHook +from nnscaler.utils import fn_field, OptStateDict + + +@dataclass +class HybridSubOptParamGroupConfig: + options: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HybridSubOptConfig: + type: Union[Type[Optimizer], Callable[..., Optimizer]] = fn_field(default=None) + options: dict[str, Any] = field(default_factory=dict) + param_groups: list[HybridSubOptParamGroupConfig] = field(default_factory=list) + + def __post_init__(self): + if not self.type: + raise ValueError("Optimizer type must be specified in HybridSubOptConfig") + + +@dataclass +class HybridOptConfig: + optimizers: list[HybridSubOptConfig] = field(default_factory=list) + + def __post_init__(self): + if not self.optimizers: + raise ValueError("At least one optimizer must be specified in HybridOptConfig") + + +class HybridRemovableHandle: + def __init__(self, removable_handles: list[RemovableHandle]): + self.removable_handles = removable_handles + + def remove(self): + for removable_handle in self.removable_handles: + removable_handle.remove() + + def __enter__(self) -> "HybridRemovableHandle": + return self + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + self.remove() + + +class HybridOptimizer(torch.optim.Optimizer, TrainHookHost): + """ + A hybrid optimizer that combines multiple optimizers/multiple param groups + into a single optimizer. + + Please note HybridOptimizer doesn't call super().__init__(), + So it is actually a duck type for optimizer. + """ + + # Identifier for hybrid optimizer + is_hybrid = True + + def __init__( + self, + params: Iterable[torch.nn.Parameter], + param_clss: dict[torch.nn.Parameter, tuple[int, int]], + config: Union[HybridOptConfig, dict[str, Any]] + ): + """ + Initialize the hybrid optimizer. + + Args: + params (Iterable[torch.nn.Parameter]): The parameters to optimize. + param_clss (dict[torch.nn.Parameter, tuple[int, int]]): The parameter classes for each parameter. + config (Union[HybridOptConfig, dict[str, Any]]): The configuration for the hybrid optimizer. + """ + params = list(params) + if isinstance(config, dict): + config = deserialize_dataclass(config, HybridOptConfig) + self.config = config + + self.optimizers = [] + classified_params = defaultdict(list) + # map from (optimizer_idx, pg_idx, param_pg_idx) to param global param index + param_loc = {} + + for idx, param in enumerate(params): + param_cls = param_clss[param] + assert param_cls[0] < len(self.config.optimizers) + classified_params[param_cls].append(param) + + loc = *param_cls, len(classified_params[param_cls]) - 1 + param_loc[loc] = idx + + # sort with key i.e. (optimizer idx, param group idx) + classified_params = dict(sorted(classified_params.items())) + + quick_param_groups = {param_cls: {"params": params} for param_cls, params in classified_params.items()} + opt_param_groups = defaultdict(dict) + for param_cls, group in quick_param_groups.items(): + opt_param_groups[param_cls[0]][param_cls[1]] = group + + for idx, opt_config in enumerate(config.optimizers): + param_groups = opt_param_groups[idx] + if len(param_groups) > 1: + if len(param_groups) != len(opt_config.param_groups): + raise ValueError(f"Expected {len(opt_config.param_groups)} param groups, got {len(param_groups)}") + # param group indices must be consecutive. + if max(param_groups.keys()) != len(opt_config.param_groups) - 1: + raise ValueError(f"Param group indices must be consecutive. We have {len(opt_config.param_groups)} groups, got max group id {max(param_groups.keys())}") + for param_group_idx, param_group in param_groups.items(): + param_group.update(opt_config.param_groups[param_group_idx].options) + else: + if len(opt_config.param_groups) > 1: + raise ValueError(f"Expected at most 1 param group, got {len(opt_config.param_groups)}") + if opt_config.param_groups: + param_groups[0].update(opt_config.param_groups[0].options) + optimizer = opt_config.type(param_groups.values(), **opt_config.options) + self.optimizers.append(optimizer) + + # map from param global index to (optimizer_idx, param_idx) + self._param_map: dict[int, tuple[int, int]] = {} + # map from (optimizer_idx, param_idx) to param global idx + self._reverse_param_map: dict[tuple[int, int], int] = {} + for opt_idx, optimizer in enumerate(self.optimizers): + state_dict: OptStateDict = optimizer.state_dict() + for pg_idx, pg in enumerate(state_dict['param_groups']): + for param_idx_in_pg, param_idx in enumerate(pg['params']): + # param_idx_in_pg is the index in this param group + # param_idx is the index in this optimizer + global_idx = param_loc[(opt_idx, pg_idx, param_idx_in_pg)] + self._param_map[global_idx] = (opt_idx, param_idx) + self._reverse_param_map[(opt_idx, param_idx)] = global_idx + + # Don't call base init + # So HybridOptimizer is a duck optimizer + # super().__init__(params, {}) + + # simulated param groups + self.param_groups = [] + for optimizer in self.optimizers: + self.param_groups.extend(optimizer.param_groups) + + def _get_hook_objects(self): + return self.optimizers + + def step(self, closure=None): + """ + Perform a single optimization step. + """ + assert closure is None, "Closure is not supported in HybridOptimizer" + for optimizer in self.optimizers: + optimizer.step(closure) + + def zero_grad(self, set_to_none: bool = False): + """ + Zero the gradients of all optimizers. + """ + for optimizer in self.optimizers: + optimizer.zero_grad(set_to_none=set_to_none) + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + " [\n" + format_string += ",\n".join(f"{repr(opt)}" for opt in self.optimizers) + format_string += "\n]" + return format_string + + def register_step_pre_hook(self, hook) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_step_pre_hook(hook) for opt in self.optimizers]) + + def register_step_post_hook(self, hook) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_step_post_hook(hook) for opt in self.optimizers]) + + def register_state_dict_pre_hook( + self, hook, prepend: bool = False + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_state_dict_pre_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def register_state_dict_post_hook( + self, + hook, + prepend: bool = False, + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_state_dict_post_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def state_dict(self): + state_dicts: list[OptStateDict] = [opt.state_dict() for opt in self.optimizers] + merged_state_dict: OptStateDict = {'state': {}, 'param_groups': [{'children': {}}]} + + for opt_idx, sd in enumerate(state_dicts): + for param_idx, s in sd['state'].items(): + merged_state_dict['state'][self._reverse_param_map[(opt_idx, param_idx)]] = s + merged_state_dict['param_groups'][0]['children'][opt_idx] = sd['param_groups'] + + merged_state_dict['param_groups'][0]['params'] = list(range(len(self._param_map))) + merged_state_dict['param_groups'][0]['param_map'] = self._param_map + merged_state_dict['param_groups'][0]['reverse_param_map'] = self._reverse_param_map + merged_state_dict['state'] = dict(sorted(merged_state_dict['state'].items())) + + return merged_state_dict + + def register_load_state_dict_pre_hook( + self, + hook, + prepend: bool = False, + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_load_state_dict_pre_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def register_load_state_dict_post_hook( + self, hook, prepend: bool = False + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_load_state_dict_post_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def load_state_dict(self, state_dict) -> None: + child_state_dicts = [{'state': {}, 'param_groups': []} for _ in self.optimizers] + + for idx, sd in enumerate(child_state_dicts): + # copy param groups from state dict + sd['param_groups'] = state_dict['param_groups'][0]['children'][idx] + if len(sd['param_groups']) != len(self.optimizers[idx].param_groups): + raise ValueError(f"Number of param groups mismatch. Expected {len(self.optimizers[idx].param_groups)} got {len(sd['param_groups'])}") + # param groups can be changed (for example, the compute config is changed) + # state_dict for HybridOptimizer is already well organized, + # here we will carefully dispatch parameters to each optimizer. + current_state_dict = self.optimizers[idx].state_dict() + for pg, current_pg in zip(sd['param_groups'], current_state_dict['param_groups']): + pg['params'] = current_pg['params'][:] # make a copy + + for param_idx, param_state in state_dict['state'].items(): + opt_idx, param_state_idx = self._param_map[param_idx] + child_state_dicts[opt_idx]['state'][param_state_idx] = param_state + + for child_state_dict, opt in zip(child_state_dicts, self.optimizers): + opt.load_state_dict(child_state_dict) + + # after loading from state dict, the param_groups of optimizers are reassigned + # (instead of updated inplace), so we need to gather them again (as we have done + # in the constructor). + self.param_groups = [] + for optimizer in self.optimizers: + self.param_groups.extend(optimizer.param_groups) + + def add_param_group(self, param_group: dict[str, Any]) -> None: + # no-op to avoid creating new parameter groups + # all parameter groups are managed by the individual optimizers + pass + + +@dataclass +class HybridSubLRSchedulerConfig: + type: Union[Type[LRScheduler], Callable[..., LRScheduler]] = fn_field(default=None) + options: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HybridLRSchedulerConfig: + schedulers: list[HybridSubLRSchedulerConfig] = field(default_factory=list) + + +class HybridLRScheduler(LRScheduler, TrainHookHost): + """ + A hybrid learning rate scheduler that combines multiple schedulers. + + Please note HybridLRScheduler doesn't call super().__init__(), + So it is actually a duck type for scheduler. + """ + + def __init__( + self, + optimizer: HybridOptimizer, + config: Union[HybridLRSchedulerConfig, dict[str, Any]], + last_epoch: int = -1, + ): + assert isinstance(optimizer, HybridOptimizer), "Optimizer must be an instance of HybridOptimizer" + if isinstance(config, dict): + config = deserialize_dataclass(config, HybridLRSchedulerConfig) + + if len(config.schedulers) == 1: + self.schedulers = [config.schedulers[0].type(optimizer, **config.schedulers[0].options)] + elif len(config.schedulers) == len(optimizer.optimizers): + self.schedulers = [sub_config.type(opt, **sub_config.options) for sub_config, opt in zip(config.schedulers, optimizer.optimizers)] + else: + raise ValueError(f"Expected {len(optimizer.optimizers)} or 1 schedulers, got {len(config.schedulers)}") + + def _get_hook_objects(self): + return self.schedulers + + def step(self, epoch=None): + for scheduler in self.schedulers: + scheduler.step(epoch) + + def state_dict(self): + return {idx: scheduler.state_dict() for idx, scheduler in enumerate(self.schedulers)} + + def load_state_dict(self, state_dict): + for idx, sd in state_dict.items(): + self.schedulers[idx].load_state_dict(sd) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 0e26d483..9d286489 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -1,16 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union +import functools +import pickle +from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union, ClassVar +from typing_extensions import Self import logging import os import sys +import gc +import warnings from pathlib import Path from dataclasses import dataclass, asdict from collections import defaultdict import torch import torch.distributed as dist +from torch import device +from torch.autograd.graph import saved_tensors_hooks from nnscaler.graph.parser import FxModuleParser @@ -19,10 +26,11 @@ from nnscaler.runtime.executor import Executor from nnscaler.runtime.gnorm import ParamsInfo from nnscaler.runtime.utils import microbatches +from nnscaler.runtime.function import insert_backward_hook from nnscaler import __version__ as runtime_version from nnscaler.flags import CompileFlag -from nnscaler.utils import accum_mode +from nnscaler.utils import accum_mode, classproperty, unchecked_fields if TYPE_CHECKING: from nnscaler.parallel import ComputeConfig @@ -46,24 +54,57 @@ class AttrMeta: # the number of the partitioned values, usually 1 # (i.e., no partition on value -> no need to sum up) val_chunks: int + # data type of the full tensor and sub tensor + dtype: torch.dtype + # shape of the sub tensor + # it should be the shape of full_tensor[slicers] + sub_shape: Tuple[int, ...] -def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, AttrMeta]]) -> Dict[int, Dict[str, AttrMeta]]: +@dataclass +class Zero3AttrMeta: + """ + Used for loading merged state dict + """ + # original name in the module + orig_name: str + # name in the module + attr_name: str + # start index of the sub tensor + start: int + # end index of the sub tensor + end: int + # chunk size of the sub tensor, can be bigger than end - start due to padding + chunk_size: int + + +def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, Dict[str, AttrMeta]]]) -> Dict[int, Dict[str, Dict[str, AttrMeta]]]: ''' Deduplicate the attributes according to `rank2attr_area_map`. - For each `slicers` of a full tensor with the name `orig_name`, we only store its first appearance - in the `rank2attr_area_map`. + For each `slicers` of a full tensor identified by its full qualified name, we only store its first appearance + in the `rank2attr_area_map`. In nnscaler, this dedup process leads to: + - If an attribute is not within the first scale unit, it will be deduplicated. + - If an attribute is shared by different operators, it will be deduplicated. + - If an attribute is replicated across several devices, we only save it at the devices with the smallest rank. + - If an attribute is partitioned across several devices, all these sub tensors will be saved. + - Note that nnscaler supports partition an operator across multiple dimensions, attributes in the operator may + be saved at a subset of related devices. + - Pipeline parallelism is supported since it is composed of different segments in nnscaler, which are different + parallel modules with their own attribute maps at runtime. In addition, we will check - the shape of the full tensor is consistent across different ranks - the slicers of the full tensor are not intersected with each other - the slicers of the full tensor can cover the full tensor - The input and output attribute area map's key is the local attribute name. Args: - rank2attr_area_map (Dict[int, Dict[str, AttrMeta]]): the mapping from rank to the attribute area map + rank2attr_area_map ( + Dict[int, # rank id + Dict[str, # submodule prefix + Dict[str, # attribute name in parallel module (not original name) + AttrMeta]]]): fullmap information for all parallel modules in all ranks. Returns: - Dict[int, Dict[str, AttrMeta]]: the deduplicated attribute area map + Dict[int, Dict[str, Dict[str, AttrMeta]]]: the deduplicated fullmap info, the structure is the same as the input. ''' # assume ranks in rank2attr_area_map are in increasing order ranks = list(rank2attr_area_map.keys()) @@ -87,26 +128,32 @@ def need_save(slicers: Tuple[slice, ...], saved_slicers_list: List[Tuple[slice, return True ret = dict() - for rank, attr_area_map in rank2attr_area_map.items(): - dedup_attr_area_map = dict() - for attr, attr_meta in attr_area_map.items(): - assert attr_meta.val_chunks == 1, 'not support partitioning on value dimension' - if attr_meta.orig_name not in orig_name2shape: - orig_name2shape[attr_meta.orig_name] = attr_meta.shape - else: - assert orig_name2shape[attr_meta.orig_name] == attr_meta.shape, \ - f'unmatched shape {orig_name2shape[attr_meta.orig_name]} vs {attr_meta.shape}' - if need_save(attr_meta.slicers, orig_name2slice_info[attr_meta.orig_name]): - orig_name2slice_info[attr_meta.orig_name].append(attr_meta.slicers) - dedup_attr_area_map[attr] = attr_meta - ret[rank] = dedup_attr_area_map + for rank, module_fullmaps in rank2attr_area_map.items(): + dedup_module_fullmaps = dict() + for module_name, attr_area_map in module_fullmaps.items(): + dedup_attr_area_map = dict() + for attr, attr_meta in attr_area_map.items(): + assert attr_meta.val_chunks == 1, 'not support partitioning on value dimension' + # use module_name.orig_name as the unique identifier for full tensor + full_tensor_name = f"{module_name}.{attr_meta.orig_name}" + if full_tensor_name not in orig_name2shape: + orig_name2shape[full_tensor_name] = attr_meta.shape + else: + assert orig_name2shape[full_tensor_name] == attr_meta.shape, \ + f'unmatched shape {orig_name2shape[full_tensor_name]} vs {attr_meta.shape}' + if need_save(attr_meta.slicers, orig_name2slice_info[full_tensor_name]): + orig_name2slice_info[full_tensor_name].append(attr_meta.slicers) + dedup_attr_area_map[attr] = attr_meta + if dedup_attr_area_map: # only add non-empty maps + dedup_module_fullmaps[module_name] = dedup_attr_area_map + ret[rank] = dedup_module_fullmaps # since we # - skip saving when there are identical weights # - assert the slicers are disjoint # we can use the sum of the sub-slicers to verify the full tensor is covered - for orig_name, slicerss in orig_name2slice_info.items(): - shape = orig_name2shape[orig_name] + for full_tensor_name, slicerss in orig_name2slice_info.items(): + shape = orig_name2shape[full_tensor_name] full_size = 1 for s in shape: full_size *= s @@ -116,7 +163,7 @@ def need_save(slicers: Tuple[slice, ...], saved_slicers_list: List[Tuple[slice, for s in slicers: size *= s.stop - s.start covered_size += size - assert full_size == covered_size, f'uncovered size for {orig_name} with shape {shape}, slicerss {slicerss}' + assert full_size == covered_size, f'uncovered size for {full_tensor_name} with shape {shape}, slicerss {slicerss}' return ret @@ -192,14 +239,28 @@ def zero_grad(self): def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: """Get parameter list for optimizer""" - params = [] + return list(self.get_opt_params().keys()) + + def get_opt_params(self, prefix='', classify_param_cls_fn: Callable[[str], Any]=None) -> dict[torch.nn.Parameter, Any]: + """ + Get all parameters and their classifications. Parameters in reducers come first. + + Args: + prefix (str): The prefix of this module, + which will be used to generate full names of parameters and further classify them. + classify_param_cls_fn (Callable[[str], Any], optional): A function to classify parameters by name. + + Returns: + dict[torch.nn.Parameter, Any]: A dictionary mapping parameters to their classifications. + """ + params = {} reducer_pids = set() for reducer in self._reducers: - params += reducer.parameters_for_optimizer() + params.update(reducer.get_opt_params()) reducer_pids.update(id(p) for p in reducer.params) - for param in self.parameters(): + for name, param in self.named_parameters(prefix): if id(param) not in reducer_pids: - params.append(param) + params[param] = classify_param_cls_fn(name) if classify_param_cls_fn else None # print(f'> get out parameters: {sum(p.numel() for p in params)}') return params @@ -277,7 +338,8 @@ def add_full_map(self, attr: str, tid: int, is_param: bool, orig_name: str, shap val_chunks int: the number of value chunks. """ assert hasattr(self, attr), f"{attr} is not in the module" - meta = AttrMeta(tid, is_param, orig_name, shape, slicers, val_chunks) + attr_tensor: torch.Tensor = getattr(self, attr) + meta = AttrMeta(tid, is_param, orig_name, shape, slicers, val_chunks, attr_tensor.dtype, tuple(attr_tensor.shape)) self._fullmap[attr] = meta # TODO: remove this function, use the property instead @@ -330,7 +392,7 @@ def get_checkpoint(self, optimizer: torch.optim.Optimizer = None): # backward compatibility # in old version, dist_param_map is not loaded in constructor # so we will try to load it from file on the fly. - dist_param_map = getattr(self, '_dist_param_map', None) + dist_param_map = getattr(self, 'dist_param_map', None) if not dist_param_map: module_file = Path(sys.modules[self.__module__].__file__) # load from the same directory as the module file @@ -375,6 +437,7 @@ def merge_model_state_dicts( fullmaps: List[Dict[str, AttrMeta]] ): """Merge model states from multiple shard into a single-model state. + Here we assume the order of state_dicts and fullmaps are aligned, and is the same as the rank order. Note: Users only need to provide as fewer local model states as necessary to @@ -396,6 +459,11 @@ def merge_model_state_dicts( # Here we expand slice to (start, step, stop) tuple, # because before python 3.12, slice object is not hashable state_dict_merge_track: Dict[str, Set[Tuple[Tuple[Any, Any, Any], ...]]] = {} + # the fill progress of zero3 parameters + # key: param name + # value: Dict[ tuple(start, step, stop) , filled size] + # used to track how many elements have been filled for each zero3 parameter + zero3_current_filled: Dict[str, Dict[Tuple[Tuple[int, int, int], ...], int]] = {} # gather param/buffer full tensor for rank, (model_state_dict, local_fullmap) in enumerate(zip(state_dicts, fullmaps)): for local_name, meta in local_fullmap.items(): @@ -415,13 +483,40 @@ def merge_model_state_dicts( raise NotImplementedError("Not support of partitioning parameter / buffer at value dimension") state_dict_merge_track_id = tuple((i.start, i.step, i.stop) for i in meta.slicers) - if state_dict_merge_track_id in state_dict_merge_track[meta.orig_name]: - if not CubeModule._safe_tensor_equal(full_model_state_dict[meta.orig_name][meta.slicers], partial_tensor): - raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") + dest_tensor = full_model_state_dict[meta.orig_name][meta.slicers] + if dest_tensor.shape == partial_tensor.shape and state_dict_merge_track_id in state_dict_merge_track[meta.orig_name]: + if not CubeModule._safe_tensor_equal(dest_tensor, partial_tensor): + raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") _logger.debug(f'rank {rank}: skip merging duplicated model state for param {meta.orig_name} with slicers {meta.slicers}') else: state_dict_merge_track[meta.orig_name].add(state_dict_merge_track_id) - full_model_state_dict[meta.orig_name][meta.slicers] = partial_tensor + if dest_tensor.shape == partial_tensor.shape: + dest_tensor.copy_(partial_tensor) + else: + # we assume zero3 is on when dest_tensor.shape != partial_tensor.shape + if len(partial_tensor.shape) != 1: + raise ValueError("Invalid tensor as a ZeRO3 parameter, expected a 1D tensor.") + fill_start = zero3_current_filled.setdefault(meta.orig_name, {}).setdefault(state_dict_merge_track_id, 0) + fill_len = partial_tensor.numel() + if fill_start >= dest_tensor.numel(): + # already filled, let's check consistency + fill_start = fill_start % dest_tensor.numel() + if fill_start + fill_len > dest_tensor.numel(): + # remove padding part + fill_len = dest_tensor.numel() - fill_start + if not CubeModule._safe_tensor_equal(dest_tensor.view(-1)[fill_start: fill_start + fill_len], partial_tensor[0:fill_len]): + raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") + else: + if fill_start + fill_len > dest_tensor.numel(): + # remove padding part + fill_len = dest_tensor.numel() - fill_start + old_shape = dest_tensor.shape + dest_tensor = dest_tensor.reshape(-1) + dest_tensor[fill_start: fill_start + fill_len] = partial_tensor[0: fill_len] + full_model_state_dict[meta.orig_name][meta.slicers] = dest_tensor.view(old_shape) + + zero3_current_filled[meta.orig_name][state_dict_merge_track_id] += fill_len + return full_model_state_dict @staticmethod @@ -542,7 +637,7 @@ def _check_state_size(opt_state_keys, bucket_state): return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape for key in opt_state_keys) - def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): + def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size, zero_version): assert bucket_size % len(bucket_states) == 0 opt_state_keys = list(bucket_states[0].keys()) if 'step' in bucket_states[0]: @@ -551,36 +646,70 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): # NOTE: only support adam for now assert 'exp_avg' in opt_state_keys assert 'exp_avg_sq' in opt_state_keys - chunk_size = bucket_size // len(bucket_states) - start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size - end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + opt_states, opt_states_1d = {}, {} for key in opt_state_keys: opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, device=bucket_states[0][key].device, requires_grad=False) opt_states_1d[key] = opt_states[key].view(-1) - if start_rank_id == end_rank_id: - for key in opt_state_keys: - opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] - else: - offset = chunk_size-start_offset - for key in opt_state_keys: - opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] - for i in range(start_rank_id+1, end_rank_id): + if zero_version == 1: + chunk_size = bucket_size // len(bucket_states) + start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size + end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + if start_rank_id == end_rank_id: for key in opt_state_keys: - opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] - offset += chunk_size - if end_offset: # skip if end_offset == 0, because it is a no-op + opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] + else: + offset = chunk_size-start_offset for key in opt_state_keys: - opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] + for i in range(start_rank_id+1, end_rank_id): + for key in opt_state_keys: + opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] + offset += chunk_size + if end_offset: # skip if end_offset == 0, because it is a no-op + for key in opt_state_keys: + opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + else: # zero_version == 3 + assert zero_version > 1, f'unsupported zero version {zero_version}' + for key in opt_state_keys: + fill_start = 0 + fill_len = pend - pstart + param_numel = opt_states_1d[key].numel() + for bstate in bucket_states: + if fill_start >= param_numel: + # from current implementation, code never goes here + # because we have used model_idx2opt_idx to filter out unnecessary ranks + # but let's keep the logic here for safety + fill_start = fill_start % param_numel + if fill_start + fill_len > param_numel: + fill_len = param_numel - fill_start + # check consistency for the already filled part + if not CubeModule._safe_tensor_equal( + opt_states_1d[key][fill_start: fill_start + fill_len], + bstate[key][pstart: pstart+fill_len] + ): + raise ValueError(f"Conflict in merging optimizer state for param with shape {pshape}") + else: + if fill_start + fill_len > param_numel: + fill_len = param_numel - fill_start + # remove padding part + opt_states_1d[key][fill_start: fill_start + fill_len] = bstate[key][pstart: pstart+fill_len] + fill_start += fill_len if 'step' in bucket_states[0]: - opt_states['step'] = bucket_states[0]['step'] + # make sure all steps are different tensors (with same value) + opt_states['step'] = bucket_states[0]['step'].clone() return opt_states - def _merge_opt_zero(worker_idx, param_idx): - model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[worker_idx] + def _merge_opt_zero(param_shape, worker_idx, param_idx): + if len(zero_idx_maps[worker_idx]) == 3: + model_idx2opt_idx, opt_idx2ranks, zero_version = zero_idx_maps[worker_idx] + else: # backward compatibility + assert len(zero_idx_maps[worker_idx]) == 2 + model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[worker_idx] + zero_version = 1 # default to ZeRO-1 opt_idx = model_idx2opt_idx[param_idx] if isinstance(opt_idx, int): # the param without reducer @@ -589,14 +718,19 @@ def _merge_opt_zero(worker_idx, param_idx): else: # the param in reducer bucket opt_idx, pstart, pend, pshape = opt_idx + if zero_version == 1: + assert param_shape == pshape, f'param shape {param_shape} vs pshape {pshape}' ranks, bucket_size = opt_idx2ranks[opt_idx] + # parameters in reducer come first, so we can directly use opt_idx to index. bucket_states = [optim_state_dicts[rank]['state'][opt_idx] for rank in ranks] return _retrieve_param_opt_state( bucket_states, pstart, pend, - pshape, - bucket_size) + param_shape, + bucket_size, + zero_version + ) # full_index: param IDs in the full optimizer state for full_index, param_name in enumerate(origin_parameter_names): @@ -639,7 +773,7 @@ def _merge_opt_zero(worker_idx, param_idx): # As ZeRO is applied, the optimizer state of this parameter (a shard) # may not be stored locally in its optimizer state. # _merge_opt_zero is for recovering the optimizer state corresponding to this parameter shard. - states: Dict[str, torch.Tensor] = _merge_opt_zero(work_idx, local_index) + states: Dict[str, torch.Tensor] = _merge_opt_zero(meta.sub_shape, work_idx, local_index) zero_done_track.add(track_id) else: _logger.debug(f'rank {work_idx}: skip merging duplicated optimizer state for param {full_index} with slicers {meta.slicers}') @@ -718,6 +852,100 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): 'optim_state_dict': merged_optimizer_state_dict }, filename_prefix + '.full.ckpt') + def sleep(self): + """ + Move attributes (buffer and param) to cpu and release contiguous buffer in reducers. Different from + nn.Module's cpu() method, references to attributes are unchanged. + """ + for name, param in self.named_parameters(): + assert param.grad is None, f'expect {name} with shape {param.shape} has no grad' + + for reducer in self._reducers: + reducer.zero_grad() + + # we want attribute references are unchanged, so super().cpu() is not used here + cpu = torch.device('cpu') + for buffer in self.buffers(): + buffer.data = buffer.data.to(cpu) + + for param in self.parameters(): + param.data = param.data.to(cpu) + + for reducer in self._reducers: + reducer.sleep() + + gc.collect() + torch.cuda.empty_cache() + return self + + def wake_up(self, device: Optional[Union[int, device]] = None) -> Self: + """ + Move attributes (buffer and param) back to gpu and reallocate memories in reducers. It is a reverse + operation of `self.sleep()`. + """ + gpu = torch.cuda.current_device() + if device is not None: + if isinstance(device, int): + index = device + elif isinstance(device, torch.device): + index = device.index + else: + raise RuntimeError(f'unexpected device type {type(device)}') + assert gpu == index, f'nnscaler module does not support cross gpu transport, expect {gpu} but got {index}' + + for name, param in self.named_parameters(): + assert param.grad is None, f'expect {name} with shape {param.shape} has no grad' + + # we want attribute references are unchanged, so super().gpu() is not used here + for buffer in self.buffers(): + buffer.data = buffer.data.to(gpu) + + for param in self.parameters(): + param.data = param.data.to(gpu) + + for reducer in self._reducers: + reducer.wake_up() + + gc.collect() + torch.cuda.empty_cache() + return self + + def to(self, *args, **kwargs): + """ + Override nn.Module's to function, currently we only allow transfer data from host and device + + Args: + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module + dtype (:class:`torch.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module + tensor (torch.Tensor): Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + memory_format (:class:`torch.memory_format`): the desired memory + format for 4D parameters and buffers in this module (keyword + only argument) + + Returns: + Module: self + """ + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) + if dtype is not None: + raise ValueError(f'nnscaler does not support passing dtype {dtype} to to()') + if convert_to_format is not None: + raise ValueError(f'nnscaler does not support passing convert_to_format {convert_to_format} to to()') + if non_blocking is not None: + warnings.warn(f'nnscaler moves tensors in a blocking approach currently') + + # after _parse_to `device` must in type of torch.device + if device.type == 'cpu': + return self.cpu() + elif device.type == 'cuda': + return self.cuda(device) + else: + raise ValueError(f'unsupported device type {device}') + @dataclass class OriginModuleMetadata: @@ -733,6 +961,11 @@ class ZeroMetadata: model_idx2opt_idx: Optional[Dict] = None # a mapping from optimizer_index to the related bucket information (sub_ranks, bucket_size) opt_idx2ranks: Optional[Dict] = None + # the level of zero optimization + # 0: no zero optimization + # 1: zero1 + # > 1: zero3 + zero: int = 0 @dataclass @@ -765,10 +998,27 @@ class ParallelModule(CubeModule): COMPUTE_CONFIG_FILE = 'compute_config.pt' ORIGIN_MODULE_METADATA_FILE = 'origin_module_metadata.pt' EXTRA_STATE_KEY = 'CUBE_EXTRA_STATE' + ATTR_META_FILE_PREFIX = 'attr_meta' + ATTR_META_FILE_TEMPLATE = ATTR_META_FILE_PREFIX + '{}.pkl' # 'attr_meta{}.pkl' + # the rank of the module, will be assigned in the generated subclasses rank: int + # the world size to run this module, will be assigned in the generated subclasses + world_size: int # the runtime version of the module when it is generated, will be assigned in the generated subclasses runtime_version: str + # mapping from the name of local attribute tensor + # to its corresponding fulltensor meta for all ranks. + # it is a list of dictionaries mapping from attribute names to their metadata + # and it is a replacement of `CubeModule.fullmap` + attr_meta_maps: list[dict[str, AttrMeta]] + # the directory of the module located + module_dir: Path + # The map is a dict mapping from the new parameter name (without tid suffix) in parallel module + # to the parameter name in original module. + dist_param_map: dict[str, str] + compute_config: 'ComputeConfig' + origin_module_metadata: OriginModuleMetadata def __init__(self): if self.__class__ == ParallelModule: # not init via super().__init__() @@ -790,6 +1040,42 @@ def __init__(self): self._nreplicas2localparams: Optional[Dict[int, List[torch.nn.Parameter]]] = None # track whether all the parames (especially the non-persistent buffers) have been initialized self._non_presistent_buffers_inited = False + # track the params that have been prefetched in backward + # this is only used for zero3 + # The reason is the eviction of prefetched params in backward + # relies on the input.requires_grad flag to be True + # If all the inputs do not require grad, + # the eviction logic will not be triggered + # In that case, we will delay the eviction until next backward hook. + self._backward_prefetched_params: dict[torch.nn.Parameter, int] = {} + # the params that have been prefetched in forward + self._forward_prefetched_params: set[torch.nn.Parameter] = set() + + def __init_subclass__(cls, skip_init=False, **kwargs): + # special case when we just fake a ParallelModule class + # In this case, you should also use object.__new__ instead of __init__ + if skip_init: + return + + from nnscaler.parallel import ComputeConfig + + super().__init_subclass__(**kwargs) + cls.attr_meta_maps = [] + cls.module_dir = Path(sys.modules[cls.__module__].__file__).parent + + for rank in range(cls.world_size): + attr_map_file = cls.module_dir / cls.ATTR_META_FILE_TEMPLATE.format(rank) + with open(attr_map_file, 'rb') as f: + attr_meta_map = pickle.load(f) + attr_meta_map = {attr: AttrMeta(**meta) for attr, meta in attr_meta_map.items()} + cls.attr_meta_maps.append(attr_meta_map) + + cls.dist_param_map = torch.load(cls.module_dir / FxModuleParser.ATTR_MAP_FILE, weights_only=False) + cls.compute_config = ComputeConfig.safe_load_from_file( + cls.module_dir / cls.COMPUTE_CONFIG_FILE, + return_none_on_error=False + ) + cls.origin_module_metadata = torch.load(cls.module_dir / cls.ORIGIN_MODULE_METADATA_FILE, weights_only=False) @property def non_presistent_buffers_inited(self): @@ -812,12 +1098,14 @@ def _warn_uninitialized_non_persistent_buffers(self, raise_error = False): else: _logger.warning(_non_persistent_buffers_load_warning) - def _post_init(self, init_params=True): + def _post_init(self, init_params=True, build_buckets=True): """ This is post init function to further initialize the model. Should be called by subclass's __init__(). Args: init_params (bool): whether to load model init parameters. Default True. + build_buckets (bool): whether to build buckets for the model. Default True. + If it is False, you must manually call `build_buckets()` later before use this module. """ # Here we check the rank to load the module file name # Current we don't check rank when we are not in distributed mode @@ -825,28 +1113,14 @@ def _post_init(self, init_params=True): # TODO: re-enable this check # if dist.is_initialized() and self.rank != dist.get_rank(): # raise RuntimeError(f"The rank to load this module file name is expected to be {self._rank}, but got {dist.get_rank()}") - from nnscaler.parallel import ComputeConfig self._non_presistent_buffers_inited = init_params or not self._non_persistent_buffers_set module_file = Path(sys.modules[self.__module__].__file__) - self.module_dir = module_file.parent if init_params: self.load_attr_content(str(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE_STEM}"))) self._warn_uninitialized_non_persistent_buffers() - self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}"), weights_only=False) - self._compute_config: ComputeConfig = ComputeConfig.safe_load_from_file( - module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}"), - return_none_on_error=False - ) - self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}"), weights_only=False) - - for reducer in self.reducers: - reducer.build_buckets() - - self._zero_metadata = self._get_zero_metadata() - # add state_dict hook to save extra state # Please note extra_state is only used for merging, not for loading # so we can safely remove it in load_state_dict pre hook @@ -854,11 +1128,174 @@ def _post_init(self, init_params=True): # add load_state_dict pre hook to pop extra state to prevent warning self._register_load_state_dict_pre_hook(ParallelModule._pre_load_state_dict_hook, with_module=True) + if build_buckets: + self.build_buckets() + + def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None): + """ + Build buckets for the model reducers. + + You should call this method exactly once before using this module. + Typically this will be called when building optimizer when multiple optimizers/param groups are used. + And we will put parameters with different optimizer or different param groups into different buckets. + + Currently we have done an optimization to make sure this is only called once even for hybrid optimizers + by + 1. setting `build_buckets=False` when calling constructor in `nnscaler.parallelize`. + 2. manually calling `build_buckets()` later in `nnscaler.build_optimizer` + """ + # needs all parameters to be in cuda memory before building buckets + self.cuda() + self._param_reducer_map: dict[torch.nn.Parameter, int] = {} + model_params = {p: n for n, p in self.named_parameters()} + # key: attr name of the parameter + # value: Zero3AttrMeta + self._zero3_param_metadata: dict[str, Zero3AttrMeta] = {} + for idx, reducer in enumerate(self.reducers): + reducer.build_buckets(param_clss) + for param in reducer.params: + self._param_reducer_map[param] = idx + attr_name = model_params[param] + param_attr = self._fullmap[attr_name] + zero3_info = reducer.get_z3_info(param) + self._zero3_param_metadata[attr_name] = Zero3AttrMeta( + attr_name=attr_name, + orig_name=param_attr.orig_name, + start = zero3_info.start, + end = zero3_info.end, + chunk_size=zero3_info.numel_with_padding(), + ) if zero3_info is not None else None + + self._zero_metadata = self._get_zero_metadata() + + def get_zero3_attr_meta(self, attr_name: str) -> Optional[Zero3AttrMeta]: + """ + Get the Zero3AttrMeta for the given attribute name. + + Args: + attr_name (str): the attribute name of the parameter + Returns: + Optional[Zero3AttrMeta]: the Zero3AttrMeta for the given attribute name + """ + return self._zero3_param_metadata.get(attr_name, None) + + @torch.no_grad() + def prefetch_param(self, param: torch.nn.Parameter): + """ + Gather the full parameter tensor for FSDP. + + Args: + param (torch.nn.Parameter): the local parameter to gather + """ + reducer = self._reducers[self._param_reducer_map[param]] + reducer.prefetch_param(param) + self._forward_prefetched_params.add(param) + + @torch.no_grad() + def postevict_param(self, param: torch.nn.Parameter): + """ + Release the full parameter tensor for zero3. + + Args: + param (torch.nn.Parameter): the local parameter + """ + reducer = self._reducers[self._param_reducer_map[param]] + reducer.postevict_param(param) + self._forward_prefetched_params.discard(param) + + def _backward_evict_leftover_params(self, order: int): + for p in [p for p, o in self._backward_prefetched_params.items() if o > order]: + self.postevict_param(p) + self._backward_prefetched_params.pop(p, None) + + def backward_postevict_param(self, input: torch.Tensor, param: torch.nn.Parameter, order: int): + """ + Here we need an input tensor to register the backward hook. + """ + if not input.requires_grad: + # if input does not require grad, we cannot register backward hook on it + return input + + @torch.no_grad() + def _postevict_param(param): # pragma: no cover + self.postevict_param(param) + self._backward_prefetched_params.pop(param, None) + self._backward_evict_leftover_params(order) + + return insert_backward_hook(input, functools.partial(_postevict_param, param)) + + def backward_prefetch_param(self, activation: torch.Tensor, param: torch.nn.Parameter, order: int): + """ + Here we need an activation tensor to register the backward hook. + """ + if not activation.requires_grad: + # if activation does not require grad, we cannot register backward hook on it + return activation + + @torch.no_grad() + def _prefetch_param(param): # pragma: no cover + self.prefetch_param(param) + self._backward_prefetched_params[param] = order + self._backward_evict_leftover_params(order) + + return insert_backward_hook(activation, functools.partial(_prefetch_param, param)) + + def save_params_hooks(self) -> saved_tensors_hooks: + """ + A hook to save tensors during forward pass. + This is used to avoid parameters being saved for activation checkpointing. + + Returns: + saved_tensors_hooks: the saved tensors hooks + """ + def pack(x: torch.Tensor): + for param in self._forward_prefetched_params: + if x.untyped_storage() == param.untyped_storage(): + return (param, x.shape, x.stride(), x.storage_offset()) + return x + + def unpack(x): + if isinstance(x, tuple) and len(x) == 4: + return torch.as_strided(x[0], x[1], x[2], x[3]) + return x + + return saved_tensors_hooks(pack, unpack) + + @classmethod + def get_attr_meta_map(cls, rank=None): + """ + Get the attribute meta map for the given rank. + If rank is None, return the attribute map for the current rank. + + This function is preferred over accessing `CubeModule.fullmap` in most cases, + since it doesn't need to instantiate the module. + """ + if rank is None: + rank = cls.rank + if rank < 0 or rank >= cls.world_size: + raise ValueError(f"Rank {rank} is out of range [0, {cls.world_size})") + return cls.attr_meta_maps[rank] + def forward(self, *args, **kwargs): self._warn_uninitialized_non_persistent_buffers(raise_error=True) if self.training: self._sync_grad_required = True # mark sync_grad() can be called again - return self._forward_impl(*args, **kwargs) + # all prefetched params should have been evicted + # please note the param can be evicted in Reducer, + # which is not tracked in self._backward_prefetched_params + # so we just check the shape to make sure the param is evicted + for param in self._backward_prefetched_params.keys(): + old_shape = param.shape + self.postevict_param(param) + assert param.shape == old_shape, \ + f'Param {param} is not properly evicted in backward' + self._backward_prefetched_params.clear() + + ret = self._forward_impl(*args, **kwargs) + + assert not self._forward_prefetched_params, \ + f'All forward prefetched params should have been evicted in forward' + return ret def _forward_impl(self, *args, **kwargs): """ @@ -1015,19 +1452,6 @@ def infer_step(self, samples: List[Any]) -> List[Any]: outputs.append(output) return outputs - @property - def dist_param_map(self) -> Dict[str, str]: - """ - Get the parameter map of the model. - The map is a dict mapping from the new parameter name (without tid suffix) in parallel module - to the parameter name in original module. - """ - return self._dist_param_map - - @property - def compute_config(self) -> 'ComputeConfig': - return self._compute_config - def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]: """ Calculate the gradient norm and clip gradients. @@ -1116,6 +1540,7 @@ def _get_zero_metadata(self) -> ZeroMetadata: return ZeroMetadata( model_idx2opt_idx=model_idx2opt_idx, opt_idx2ranks=opt_idx2ranks, + zero=self.compute_config.use_zero ) def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: @@ -1150,11 +1575,11 @@ def _add_extra_state(self, state_dict, prefix) -> None: state_dict[f'{prefix}{self.EXTRA_STATE_KEY}'] = asdict( ExtraState( rank=self.rank, - compute_config=self._compute_config, - dist_param_map=self._dist_param_map, + compute_config=self.compute_config, + dist_param_map=self.dist_param_map, param_area_map=self._fullmap, cube_param_names=[name for name, _ in self.named_parameters()], - **asdict(self._orign_module_metadata), + **asdict(self.origin_module_metadata), **asdict(self._zero_metadata), ) ) @@ -1190,19 +1615,19 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if strict: missing_keys.extend(new_missing_keys) - @property - def module_dedup_group_size(self) -> int: + @classproperty + def module_dedup_group_size(cls) -> int: """ Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. """ - return self.compute_config.module_dedup_group_size + return cls.compute_config.module_dedup_group_size - @property - def optimizer_dedup_group_size(self) -> int: + @classproperty + def optimizer_dedup_group_size(cls) -> int: """ Get the size of the deduplication group of the optimizer state dict. """ - return self.compute_config.optimizer_dedup_group_size + return cls.compute_config.optimizer_dedup_group_size def _list_fullmodel_files(self) -> List[Path]: legacy_fullmodel_path = self.module_dir / FxModuleParser.ATTR_CONTENT_FILE_STEM @@ -1237,18 +1662,58 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s Raises: RuntimeError: if strict=True and there are missing keys. """ - - dist2param = self.dist_param_map - orig_param_names = list(dist2param.values()) # param names in original module (without prefix) non_persistent_buffers = self.get_non_persistent_buffers() with torch.no_grad(): # avoid checking the non-persistent buffers attr_names = set([attr for attr in self._fullmap.keys() if attr not in non_persistent_buffers]) - origname_tid_map = {meta.orig_name: meta.tid for meta in self._fullmap.values()} + for prefix_attr, content in self.trim_merged_state_dict(state_dict, prefix).items(): + attr = prefix_attr[len(prefix):] + tensor: torch.Tensor = getattr(self, attr) + tensor.copy_(content) + attr_names.remove(attr) + + missing_keys = [prefix + self._fullmap[attr].orig_name for attr in attr_names] + if len(attr_names) != 0: + erro_msg = f'Missing key(s) in state_dict: {missing_keys}.' + if strict: + raise RuntimeError(erro_msg) + else: + _logger.warning(erro_msg) + + self._warn_uninitialized_non_persistent_buffers() + return missing_keys + + def trim_merged_state_dict( + self, + state_dict: Dict[str, Any], + prefix: str = '', + *, + device=None, + ) -> Dict[str, Any]: + """ + Trim the merged state dict to only keep the parameters needed for the module. + Please note we don't check missing/unexpected keys. + + Args: + state_dict (Dict[str, Any]): the merged state dict + prefix (str): the prefix of the model state dict in the merged state dict + + Returns: + Dict[str, Any]: the trimmed state dict + """ + device = device or torch.cuda.current_device() + trimmed_state_dict = {} + + dist2param = self.dist_param_map + orig_param_names = list(dist2param.values()) # param names in original module (without prefix) + attr_meta_map = self.get_attr_meta_map(self.rank) + with torch.no_grad(): + # avoid checking the non-persistent buffers + origname_tid_map = {meta.orig_name: meta.tid for meta in attr_meta_map.values()} tid_info = defaultdict(list) - for attr, meta in self._fullmap.items(): + for attr, meta in attr_meta_map.items(): tid_info[meta.tid].append((attr, meta.slicers, meta.val_chunks)) # multiple params may share the same tid for orig_param_name in orig_param_names: @@ -1261,20 +1726,80 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s param_value = state_dict[orig_param_name_with_prefix] tid = origname_tid_map[orig_param_name] for attr, slicer, nchunks in tid_info[tid]: - tensor: torch.Tensor = getattr(self, attr) - content = param_value[slicer] + content: torch.Tensor = param_value[slicer] if nchunks != 1: content = content / nchunks - tensor.copy_(content) - attr_names.remove(attr) + if self.compute_config.use_zero <= 1 or self._zero3_param_metadata.get(attr, None) is None: + trimmed_state_dict[prefix + attr] = content.to(device) + else: + z3_info = self._zero3_param_metadata[attr] + start, end, chunk_size = z3_info.start, z3_info.end, z3_info.chunk_size + if end - start < chunk_size: + # need padding + padding = chunk_size - (end - start) + trimmed_state_dict[prefix + attr] = torch.cat([ + content.view(-1)[start:end], + torch.zeros(padding, dtype=content.dtype, device=content.device) + ], dim=0).to(device) + else: + trimmed_state_dict[prefix + attr] = content.reshape(-1)[start:end].to(device) - missing_keys = [prefix + self._fullmap[attr].orig_name for attr in attr_names] - if len(attr_names) != 0: - erro_msg = f'Missing key(s) in state_dict: {missing_keys}.' - if strict: - raise RuntimeError(erro_msg) - else: - _logger.warning(erro_msg) + return trimmed_state_dict - self._warn_uninitialized_non_persistent_buffers() - return missing_keys + def _pack( + self, + ): + """ + Get a packed information of the ParallelModule, so it can be sent to other ranks. + """ + param_map: dict[torch.nn.Parameter, torch.nn.Parameter] = {} + for p in self.parameters(): + param_map[p] = torch.nn.Parameter( + torch.empty_like(p, device='meta')) if p is not None else None + for b in self.buffers(): + param_map[b] = torch.empty_like( + b, device='meta') if b is not None else None + state = {} + fields = unchecked_fields(self) + state[fields._parameters] = {n: param_map[p] for n, p in self._parameters.items()} + state[fields._buffers] = {n: param_map[b] for n, b in self._buffers.items()} + state[fields._reducers] = [reducer._pack(param_map) for reducer in self._reducers] + state[fields._zero_metadata] = self._zero_metadata + state[fields._fullmap] = self._fullmap + state[fields._param_reducer_map] = { + param_map[p]: rid for p, rid in self._param_reducer_map.items() + } + state[fields._zero3_param_metadata] = self._zero3_param_metadata + + for cv in ParallelModule.__annotations__: + state[cv] = getattr(self, cv) + return state + + @classmethod + def _unpack(cls, state: dict): + """ + Unpack the information and return a fake ParallelModule that carries the same information. + """ + class GenModelX(ParallelModule, skip_init=True): + pass + pm = object.__new__(GenModelX) + fields = unchecked_fields(pm) + object.__setattr__(pm, fields._parameters, state[fields._parameters]) + object.__setattr__(pm, fields._buffers, state[fields._buffers]) + object.__setattr__(pm, fields._reducers, [Reducer._unpack(reducer) for reducer in state[fields._reducers]]) + object.__setattr__(pm, fields._zero_metadata, state[fields._zero_metadata]) + object.__setattr__(pm, fields._fullmap, state[fields._fullmap]) + object.__setattr__(pm, fields._param_reducer_map, state[fields._param_reducer_map]) + object.__setattr__(pm, fields._zero3_param_metadata, state[fields._zero3_param_metadata]) + + def named_parameters( + prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ): + assert prefix == "" and recurse is True, "Only support default arguments" + return pm._parameters.items() + + pm.named_parameters = named_parameters + + for cv in ParallelModule.__annotations__: + setattr(GenModelX, cv, state[cv]) + return pm diff --git a/nnscaler/runtime/serialization.py b/nnscaler/runtime/serialization.py new file mode 100644 index 00000000..ff492c26 --- /dev/null +++ b/nnscaler/runtime/serialization.py @@ -0,0 +1,252 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, TypedDict +import pickle +import base64 +import copy + +import torch +from safetensors.torch import save_file +from safetensors import safe_open + +from nnscaler.utils import transform_recursively, check_recursively +from nnscaler.version import __version__ + + +class MetadataDict(TypedDict): + obj: str + nnscaler: str + + +class _Index: + def __init__(self, index: int): + self.index = index + + def __repr__(self): + return f"_Index({self.index})" + + +def save(obj: Any, f, *, format="safetensors") -> None: + """ + Saves an object containing tensors into a safetensors file. + Args: + obj (`Any`): + The object you want to save. It can be a nested structure containing + tensors, lists, tuples, and dictionaries. + f: + The file-like object or filename where to save the safetensors file. + format (`str`, *optional*, defaults to `"safetensors"`): + The format to save the object. Currently `"safetensors"` and `"pt"` is supported. + """ + if format == 'pt': + torch.save(obj, f) + return + + if format != 'safetensors': + raise ValueError(f"Unsupported format: {format}") + + index = 0 + + # all tensors to be saved + tensors = {} + # detect shared tensors + # because safetensors does not support shared tensors, we need to + # save shared tensors only once and replace other occurrences + # TODO: Currently we only detect shared tensors that are exactly the same + # (i.e., share the same data_ptr and shape and stride). + # We may improve it in the future if needed. + # key: (tensor.data_ptr(), tensor.shape, tensor.stride()), value: _Index + tensor_ids: dict[tuple[int, tuple[int, ...], tuple[int, ...]], _Index] = {} + def transform_fn(o: Any) -> Any: + nonlocal index + if isinstance(o, torch.Tensor): + key = (o.data_ptr(), o.shape, o.stride()) + if key in tensor_ids: + idx = tensor_ids[key] + else: + idx = _Index(index) + tensor_ids[key] = idx + tensors[f'{index}'] = o + index += 1 + return idx + return o + metadata = transform_recursively(obj, transform_fn, target_types=(torch.Tensor,)) + save_file(tensors, f, metadata={ + 'obj': base64.b64encode(pickle.dumps(metadata)).decode('utf-8'), + 'nnscaler': __version__ + }) + + +class _LazyContainer: + """ + Mock class for dictionary, list, and tuple that loads tensors lazily from safetensors file. + """ + def __init__(self, data: dict | tuple | list, tensors: safe_open): + self.data = data + self.tensors = tensors + + def __getitem__(self, key): + return self._v(self.data[key]) + + def __setitem__(self, key, value): + raise NotImplementedError("Lazy containers are read-only.") + + def __delitem__(self, key): + raise NotImplementedError("Lazy containers are read-only.") + + def pop(self, key, default=None): + raise NotImplementedError("Lazy containers are read-only.") + + def __len__(self): + return len(self.data) + + def __contains__(self, item): + return self.data.__contains__(item) + + def get(self, key, default=None): + return self._v(self.data.get(key, default)) + + def keys(self): + return self.data.keys() + + def values(self): + return map(self._v, self.data.values()) + + def items(self): + return ((k, self._v(v)) for k, v in self.data.items()) + + def _v(self, v): + return _wrap_value(v, self.tensors) + + def load_all(self): + def _load(v): + if isinstance(v, _Index): + return self.tensors.get_tensor(f'{v.index}') + return v + return transform_recursively(self.data, _load, target_types=(_Index,)) + + def __copy__(self): + return copy.copy(self.load_all()) + + def __deepcopy__(self, memo): + return copy.deepcopy(self.load_all(), memo) + + def __iter__(self): + return iter(self.data) + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.data)})" + + +class _LazyList(_LazyContainer, list): + pass + + +class _LazyDict(_LazyContainer, dict): + pass + + +class _LazyTuple(_LazyContainer, tuple): + # tuple is immutable, so we need to override __new__ + def __new__(cls, *args, **kwargs): + return tuple.__new__(cls, ()) + + +def _wrap_value(v: Any, tensors: safe_open) -> Any: + if isinstance(v, _Index): + return tensors.get_tensor(f'{v.index}') + if not check_recursively(v, lambda k: isinstance(k, _Index)): + return v + if isinstance(v, dict): + return _LazyDict(v, tensors) + if isinstance(v, list): + return _LazyList(v, tensors) + if isinstance(v, tuple): + return _LazyTuple(v, tensors) + # should not reach here + return v + + +class LazyLoader: + def __init__(self, filename, device="cpu"): + self.filename = filename + self.device = device + self.tensor_loader = safe_open(self.filename, framework="pt", device=self.device) + self.tensors = None + self.data = None + + def __enter__(self): + self.tensors = self.tensor_loader.__enter__() + metadata: MetadataDict = self.tensors.metadata() + metadata_obj_b64 = metadata['obj'] + self.data = pickle.loads(base64.b64decode(metadata_obj_b64.encode('utf-8'))) + return self + + def __exit__(self, _exc_type, _exc_value, _traceback): + self.tensor_loader.__exit__(_exc_type, _exc_value, _traceback) + + def get_lazy_data(self) -> _LazyContainer | Any: + if self.tensors is None: + raise RuntimeError("LazyLoader context is not entered.") + return _wrap_value(self.data, self.tensors) + + +def load(f, *, device="cpu", format="safetensors", lazy=False) -> LazyLoader | Any: + """ + Loads an object containing tensors from a safetensors file lazily. + Args: + f: The file-like object or filename from which to load the safetensors file. + device (`str`, *optional*, defaults to `"cpu"`): + The device where the tensors will be loaded. + lazy (`bool`, *optional*, defaults to `False`): + If set to `False`, loads all tensors into memory eagerly. + Returns: + (`LazyLoader` | `Any`): + The lazy loader object that can be used to access the data. + If `lazy` is set to `False`, returns the loaded object with all tensors + loaded into memory. + """ + if format == 'pt': + return torch.load(f, map_location=device, weights_only=False) + if format != 'safetensors': + raise ValueError(f"Unsupported format: {format}") + + if not lazy: + with LazyLoader(f, device=device) as loader: + data = loader.get_lazy_data() + if isinstance(data, _LazyContainer): + return data.load_all() + else: + # pure data without any tensors + return data + return LazyLoader(f, device=device) + + +def convert(src: str, dst: str, *, src_format="safetensors", dst_format="pt", device="cpu") -> None: + """ + Converts a serialized file from one format to another. + Args: + src (`str`): + The source filename. + dst (`str`): + The destination filename. + src_format (`str`, *optional*, defaults to `"safetensors"`): + The format of the source file. Currently `"safetensors"` and `"pt"` is supported. + dst_format (`str`, *optional*, defaults to `"pt"`): + The format of the destination file. Currently `"safetensors"` and `"pt"` is supported. + device (`str`, *optional*, defaults to `"cpu"`): + The device where the tensors will be loaded. + + Returns: + (`None`): + This function does not return anything. + """ + if src_format == dst_format: + raise ValueError("Source and destination formats are the same.") + + save( + load(src, device=device, format=src_format, lazy=False), + dst, + format=dst_format + ) diff --git a/nnscaler/runtime/utils.py b/nnscaler/runtime/utils.py index b15748ea..43ec0656 100644 --- a/nnscaler/runtime/utils.py +++ b/nnscaler/runtime/utils.py @@ -13,7 +13,7 @@ class MicroBatchDataLoader: """ MicroBatchDataLoader is used for scenarios of gradient accumulation, where a training iteration will have multiple data samples and perform - multiple forward and backward on each sample (i.e., each refers to + multiple forward and backward on each sample (i.e., each refers to as a micro-batch). To support more flexible training patterns, e.g., pipeline parallelism, @@ -25,7 +25,7 @@ class MicroBatchDataLoader: ```python # compilation phase dataloader = MicroBatchDataLoader([(input1,),]) # only need one micro-batch - + @nnscaler.compile(model, dataloader, ...) def train_iter(model, dataloader): input1 = next(dataloader) @@ -36,9 +36,9 @@ def train_iter(model, dataloader): ... # runtime phase - + for mini_batch_samples in iter(dataloader): - # mini_batch_samples are sample list for + # mini_batch_samples are sample list for # all micro-batches in one iteration. dl = MicroBatchDataLoader(mini_batch_samples) loss =train_iter(model, dl) @@ -68,7 +68,7 @@ def __init__(self, samples: List[Any], cycle: bool = False): def __iter__(self): self._idx = 0 return self - + def __next__(self): if self._idx == self.nmicros: raise StopIteration @@ -77,10 +77,10 @@ def __next__(self): if self.cycle: self._idx = self._idx % self.nmicros return batch - + def __len__(self): return self.nmicros - + def get_micro_batch(self, idx: int): idx = idx % self.nmicros if self.cycle else idx return self.samples[idx] diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 310c6649..a88f4997 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -4,16 +4,16 @@ import builtins import importlib from contextlib import contextmanager -from functools import wraps +from functools import wraps, cache from typing import ( Generator, Optional, Tuple, Callable, Dict, List, Set, Any, - Iterable, Type, Union, Protocol, ClassVar, cast, TypeVar + Iterable, Type, TypedDict, Union, Protocol, ClassVar, cast, TypeVar ) import logging from pathlib import Path import sys from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field import inspect import os @@ -112,6 +112,27 @@ def get_member_by_name(model: torch.nn.Module, name: str) -> Any: return model_attr +def set_member_by_name(model: Any, name: str, value: Any) -> None: + """ + Set the member of the model by its full name. + """ + if not name: + raise ValueError("Name cannot be empty") + class _ValueHolder: + """ + A value holder. + In python you can't call `setattr` on object, but you can call it on its subclasses. + """ + pass + sliced_names = name.split(".") + model_attr = model + for sliced_name in sliced_names[:-1]: + if not hasattr(model_attr, sliced_name): + setattr(model_attr, sliced_name, _ValueHolder()) + model_attr = getattr(model_attr, sliced_name) + setattr(model_attr, sliced_names[-1], value) + + def get_shared_params(model: torch.nn.Module) -> List[List[str]]: paramid2name = defaultdict(set) for name in model.state_dict().keys(): @@ -211,65 +232,197 @@ def wrapped_fn(*args, **kwargs): _DICT_ITEMS_TYPE = type({}.items()) _DICT_KEYS_TYPE = type({}.keys()) _DICT_VALUES_TYPE = type({}.values()) +TRANSFORM_SUPPORTED_COLLECTION_TYPES = (tuple, list, dict, set, slice, _DICT_ITEMS_TYPE, _DICT_KEYS_TYPE, _DICT_VALUES_TYPE) -def transform_recursively(data: Any, fn: Callable[[Any], Any], +def _transform_recursively(data: Any, fn: Callable[[Any], Any], target_types: Union[Callable[[Any], bool], Type, Tuple[Type, ...]], collection_types = (tuple, list, dict), skip_dict_keys = True -) -> Any: - """ - Transform the data with the given function, will recursively apply the function to the nested data. - Args: - data: the data to be transformed. - fn: the function to apply. - target_types: the target types to apply the function. - collection_types: the collection types to apply the function to the nested data. - skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). - _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. - """ +) -> tuple[bool, Any]: + if collection_types is None: + collection_types = TRANSFORM_SUPPORTED_COLLECTION_TYPES if isinstance(data, collection_types): if isinstance(data, tuple): - return tuple(transform_recursively(t, fn, target_types, collection_types) for t in data) + result = tuple(_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data) + changed = any(c for c, _ in result) + if changed: + return True, tuple(v for _, v in result) + else: + return False, data if isinstance(data, list): - return list(transform_recursively(t, fn, target_types, collection_types) for t in data) + result = [_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data] + changed = any(c for c, _ in result) + if changed: + return True, [v for _, v in result] + else: + return False, data if isinstance(data, set): - return set(transform_recursively(t, fn, target_types, collection_types) for t in data) + result = [_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data] + changed = any(c for c, _ in result) + if changed: + return True, {v for _, v in result} + else: + return False, data if isinstance(data, dict): - return { - k if skip_dict_keys else transform_recursively(k, fn, target_types, collection_types): - transform_recursively(v, fn, target_types, collection_types) + if skip_dict_keys: + keys = {k: (False, k) for k in data.keys()} + else: + keys = { + k: _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k in data.keys() + } + changed = any(c for c, _ in keys.values()) + result = { + k: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) for k, v in data.items() - } + } + changed = changed or any(c for c, _ in result.values()) + if changed: + return True, { + keys[k][1]: v for k, (_, v) in result.items() + } + else: + return False, data if isinstance(data, _DICT_ITEMS_TYPE): - return { - k if skip_dict_keys else transform_recursively(k, fn, target_types, collection_types): - transform_recursively(v, fn, target_types, collection_types) + if skip_dict_keys: + keys = {k: (False, k) for k, _ in data} + else: + keys = { + k: _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k, _ in data + } + + changed = any(c for c, _ in keys.values()) + result = { + k: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) for k, v in data - }.items() + } + changed = changed or any(c for c, _ in result.values()) + if changed: + return True, { + keys[k][1]: v for k, (_, v) in result.items() + }.items() + else: + return False, data if isinstance(data, _DICT_KEYS_TYPE): - return { - transform_recursively(k, fn, target_types, collection_types): i - for i, k in enumerate(data) - }.keys() + result = [ + _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k in data + ] + changed = any(c for c, _ in result) + if changed: + return True, { + v: i for i, (_, v) in enumerate(result) + }.keys() + else: + return False, data if isinstance(data, _DICT_VALUES_TYPE): - return { - i: transform_recursively(v, fn, target_types, collection_types) + result = { + i: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) for i, v in enumerate(data) - }.values() + } + changed = any(c for c, _ in result.values()) + if changed: + return True, { + i: v for i, (_, v) in result.items() + }.values() + else: + return False, data if isinstance(data, slice): - return slice( - transform_recursively(data.start, fn, target_types, collection_types), - transform_recursively(data.stop, fn, target_types, collection_types), - transform_recursively(data.step, fn, target_types, collection_types) + result = ( + _transform_recursively(data.start, fn, target_types, collection_types, skip_dict_keys), + _transform_recursively(data.stop, fn, target_types, collection_types, skip_dict_keys), + _transform_recursively(data.step, fn, target_types, collection_types, skip_dict_keys), ) + if any(c for c, _ in result): + return True, slice( + result[0][1], + result[1][1], + result[2][1] + ) + else: + return False, data raise ValueError(f"Unsupported collection type: {type(data)}") elif isinstance(target_types, (tuple, list)) or inspect.isclass(target_types): if isinstance(data, target_types): - return fn(data) + return True, fn(data) elif callable(target_types): # not a class, but callable. treat as a check function. if target_types(data): - return fn(data) - return data + return True, fn(data) + return False, data + + +def transform_recursively(data: Any, fn: Callable[[Any], Any], + target_types: Union[Callable[[Any], bool], Type, Tuple[Type, ...]], + collection_types = (tuple, list, dict), skip_dict_keys = True +) -> Any: + """ + Transform the data with the given function, will recursively apply the function to the nested data. + Currently supported collection types is SUPPORTED_COLLECTION_TYPES. + Args: + data: the data to be transformed. + fn: the function to apply. + target_types: the target types to apply the function. + collection_types: the collection types to apply the function to the nested data. + Will handle all supported types if None. + skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). + _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. + """ + _, result = _transform_recursively(data, fn, target_types, collection_types, skip_dict_keys) + return result + + +def check_recursively(data, fn: Callable[[Any], bool], + collection_types = (tuple, list, dict), + skip_dict_keys = True +) -> bool: + """ + Check the data with the given function, will recursively apply the function to the nested data. + Args: + data: the data to be checked. + fn: the function to check. + collection_types: the collection types to apply the function to the nested data. + skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). + _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. + + """ + if collection_types is None: + collection_types = TRANSFORM_SUPPORTED_COLLECTION_TYPES + + if isinstance(data, collection_types): + if isinstance(data, (list, tuple, set, _DICT_KEYS_TYPE, _DICT_VALUES_TYPE)): + return any(check_recursively(t, fn, collection_types) for t in data) + if isinstance(data, dict): + if skip_dict_keys: + return any( + check_recursively(v, fn, collection_types) + for v in data.values() + ) + else: + return any( + check_recursively(k, fn, collection_types) or check_recursively(v, fn, collection_types) + for k, v in data.items() + ) + if isinstance(data, _DICT_ITEMS_TYPE): + if skip_dict_keys: + return any( + check_recursively(v, fn, collection_types) + for _, v in data + ) + else: + return any( + check_recursively(k, fn, collection_types) or check_recursively(v, fn, collection_types) + for k, v in data + ) + if isinstance(data, slice): + return any(( + check_recursively(data.start, fn, collection_types), + check_recursively(data.stop, fn, collection_types), + check_recursively(data.step, fn, collection_types) + )) + raise ValueError(f"Unsupported collection type: {type(data)}") + + return fn(data) def is_running_distributed() -> bool: @@ -325,6 +478,21 @@ def fields(model: TDataClass, /) -> TDataClass: return cast(TDataClass, _GetFields(model)) +class _UncheckedFields: + def __getattr__(self, item: str) -> Any: + return item + + +TUncheckedClass = TypeVar("TAnyClass") +def unchecked_fields(_: TUncheckedClass, /) -> TUncheckedClass: + """ + This function is used to get the field names(in str) of any object without checking + This is a workaround for the lack of `__name__` of member. + """ + return cast(TUncheckedClass, _UncheckedFields()) + + +@cache def load_type(type_name: str): """ Load function/class from its full qualified name @@ -457,3 +625,60 @@ def steps(nsteps: int): RuntimeFlag.skip_reducer = (not (step == nsteps - 1)) yield step RuntimeFlag.skip_zero_grad, RuntimeFlag.skip_reducer = old + + +class AdamOptState(TypedDict): + step: torch.Tensor + exp_avg: torch.Tensor + exp_avg_sq: torch.Tensor + + +class OptStateParamGroup(TypedDict): + params: list[int] + lr: int + + +class OptStateDict(TypedDict): + state: dict[int, AdamOptState | dict[str, Any]] + param_groups: list[OptStateParamGroup | dict[str, Any]] + + +def fn_field(**kwargs): + metadata = kwargs.pop('metadata', {}) + metadata['deserialize'] = lambda t: None if t is None else load_type(t) + return field(**kwargs, metadata=metadata) + + +TENSOR_DYNAMIC_DIMS_FIELD_NAME = '_nnscaler_dynamic_dims' +# for nnscaler custom class (TensorMetadata) +NNSCALER_DYNAMIC_DIMS_NAME = 'dynamic_dims' + + +def mark_dynamic(tensor: torch.Tensor, dims: int | list[int] | tuple[int]) -> torch.Tensor: + """ + Mark the dim of a tensor as dynamic, which means it can be changed in the future. + This is the same with `torch._dynamo.mark_dynamic` + """ + dims = [dims] if isinstance(dims, int) else dims + setattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, set(dims) if dims else set()) + return tensor + + +def copy_dynamic(src: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor: + """ + Copy the dynamic dims from src to tensor, and return the tensor. + """ + if hasattr(src, TENSOR_DYNAMIC_DIMS_FIELD_NAME): + setattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, getattr(src, TENSOR_DYNAMIC_DIMS_FIELD_NAME)) + return tensor + + +def get_dynamic(tensor: Any) -> set[int]: + """ + Get the dynamic dims of a tensor. + It also works when tensor is not an instance of torch.Tensor + """ + if isinstance(tensor, torch.Tensor): + return getattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, set()) + else: + return getattr(tensor, NNSCALER_DYNAMIC_DIMS_NAME, set()) diff --git a/nnscaler/version.py b/nnscaler/version.py index ae87ac5e..84bc6647 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -__version__ = '0.8' +__version__ = '0.7' diff --git a/pipelines/nightly-build.yaml b/pipelines/nightly-build.yaml new file mode 100644 index 00000000..da792cf6 --- /dev/null +++ b/pipelines/nightly-build.yaml @@ -0,0 +1,25 @@ +trigger: +- main + +pool: + vmImage: ubuntu-latest + +steps: +- task: TwineAuthenticate@1 + inputs: + artifactFeed: SuperScaler/nightly + +- script: | + python -m pip install --upgrade build twine + displayName: prepare environment + +- script: | + python pipelines/scripts/update_version.py --nightly + python -m build + displayName: build wheel + +- script: | + number_of_wheels=`ls dist/*.whl | wc -l` + test $number_of_wheels -eq 1 + python -m twine upload -r nightly --config-file $(PYPIRC_PATH) dist/*.whl + displayName: upload nightly wheel diff --git a/pipelines/release.yaml b/pipelines/release.yaml new file mode 100644 index 00000000..b7a06433 --- /dev/null +++ b/pipelines/release.yaml @@ -0,0 +1,39 @@ +# depends on two variables: +# +# - version +# must be set on devops website for each run +# the value should be something like "0.1" or "v0.1a1" (w/ or w/o leading v) +# +# - test_pypi_token +# secret, should never expire +# to view it or to update it, check onenote accounts/pypi page (test.pypi token) + +trigger: none +pr: none + +pool: + vmImage: ubuntu-latest + +steps: +- task: TwineAuthenticate@1 + inputs: + artifactFeed: SuperScaler/release + +- script: | + python -m pip install --upgrade build twine + displayName: prepare environment + +- script: | + python pipelines/scripts/update_version.py $(version) + python -m build + number_of_wheels=`ls dist/*.whl | wc -l` + test $number_of_wheels -eq 1 + displayName: build wheel + +- script: | + python -m twine upload -r release --config-file $(PYPIRC_PATH) dist/*.whl + displayName: upload to artifact + +- script: | + python -m twine upload -r testpypi -p $(test_pypi_token) dist/*.whl + displayName: upload to testpypi diff --git a/pipelines/scripts/update_version.py b/pipelines/scripts/update_version.py new file mode 100644 index 00000000..98b4f289 --- /dev/null +++ b/pipelines/scripts/update_version.py @@ -0,0 +1,71 @@ +""" +Update "nnscaler/version.py" before building the wheel. + +Usage 1: + + python update_version.py --nightly + +Update version.py to "X.Y.dev{TIMESTAMP}+{GIT_COMMIT}". + +Usage 2: + + python update_version.py 1.2 + python update_version.py v1.2b3 + +Update version.py to the specified version (normalized, leading "v" removed). +It will verify that the release part matches the old version. +""" + +import argparse +from datetime import datetime +from pathlib import Path +import subprocess + +from packaging.version import Version + +project_dir = Path(__file__).parents[2] + +def main(): + parser = argparse.ArgumentParser(add_help=False) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--nightly', action='store_true') + group.add_argument('version', nargs='?') + args = parser.parse_args() + + version_file = Path(project_dir, 'nnscaler/version.py') + file_content = version_file.read_text() + version_str = file_content.split('=')[-1].strip()[1:-1] # "version = 'x'" -> "x" + repo_version = Version(version_str) + + if args.nightly: + timestamp = datetime.now().strftime('%y%m%d%H%M') + + r = subprocess.run( + 'git rev-parse --short HEAD'.split(), + stdout=subprocess.PIPE, + cwd=project_dir, + text=True, + ) + if r.returncode != 0: + print('[error] failed to get git commit hash') + exit(1) + commit = r.stdout.strip() + + new_version_str = f'{repo_version.base_version}.dev{timestamp}+{commit}' + + else: + arg_version = Version(args.version) + + if repo_version.release != arg_version.release: + print('[error] version not match') + print(f' repo: {version_str} -> {repo_version}') + print(f' arg: {args.version} -> {arg_version}') + exit(1) + + new_version_str = str(arg_version) # normalize + + file_content = file_content.replace(version_str, new_version_str) + version_file.write_text(file_content) + +if __name__ == '__main__': + main() diff --git a/requirements-dev.txt b/requirements-dev.txt index 7d181749..bd2f7fdf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,3 +19,4 @@ wandb tensorboard mosaicml-streaming cppimport +einops diff --git a/requirements.txt b/requirements.txt index 6ae52a83..03833f53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,6 @@ psutil pulp pybind11<3.0.0 pyyaml -torch>=2.0,<=2.6 +torch>=2.0 tqdm +safetensors diff --git a/tests/autodist/test_dp_solver.py b/tests/autodist/test_dp_solver.py index 846e1c05..c6b172a4 100644 --- a/tests/autodist/test_dp_solver.py +++ b/tests/autodist/test_dp_solver.py @@ -37,6 +37,7 @@ def test_dp_solver(): # the optimal plan is each operator's first partition assert best.path == [(0, 0), (1, 0), (2, 0)] + def test_dp_solver_mem(): solver = dp_solver.DPSolver(True, 100, 1) solver.add_interval(0, 4) @@ -73,6 +74,7 @@ def test_dp_solver_mem(): assert best.path == [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)] assert best.memory == 71 + def test_dp_solver_build_in_edges(): # mock following code # dropout_rate = self.attention_dropout if self.training else 0.0 @@ -102,6 +104,7 @@ def test_dp_solver_build_in_edges(): best = ans[0] assert best.path == [(0, 0), (1, 0), (2, 0)] + def test_dp_solver_mem_bound(): solver = dp_solver.DPSolver(True, 10, 1) solver.add_interval(0, 2) @@ -119,3 +122,26 @@ def test_dp_solver_mem_bound(): ans = solver.get_results(0, 2) assert len(ans) == 0 + + +def test_dp_solver_output(): + solver = dp_solver.DPSolver(True, 1024, 1) + solver.add_interval(0, 2) + + solver.add_node(0, 0, [0], [], 2, False, False, False) + solver.add_partition(0, 0, 10, 16, 0, 0, 0, 0, 0, [[]]) + solver.add_partition(0, 1, 5, 8, 0, 0, 0, 0, 1, [[]]) + + solver.add_node(1, 1, [0, 1], [], 2, False, False, False) + solver.add_partition(1, 0, 4, 6, 0, 0, 0, 0, 0, [[]]) + solver.add_partition(1, 1, 2, 3, 0, 0, 0, 0, 1, [[]]) + + solver.add_node(2, 2, [2], [0], 1, False, False, False) + solver.add_partition(2, 0, 0, 0, 0, 0, 0, 0, 0, [[0, 0]]) + + solver.solve() + ans = solver.get_results(0, 2) + best = ans[0] + assert best.all_time == 7 + assert best.path == [(0, 1), (1, 1), (2, 0)] + assert best.memory == 11 diff --git a/tests/cli/common.py b/tests/cli/common.py index f5be1c9d..02f8b197 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -1,6 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# CausalSelfAttention is copied from https://github.com/karpathy/nanoGPT/blob/master/model.py +# with minor modifications. +# See the original license in the file https://github.com/karpathy/nanoGPT/blob/master/LICENSE + from pathlib import Path import torch from torch import nn @@ -9,11 +13,88 @@ from streaming import MDSWriter, StreamingDataset, StreamingDataLoader +import nnscaler from nnscaler.cli.trainer_args import TrainerArgs from tests.parallel_module.test_end2end import MLP from tests.utils import init_random as init_random_fn + +class CausalSelfAttention(nn.Module): + def __init__(self, n_embd: int, n_head: int, dropout: float): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd, bias=True) + # regularization + self.attn_dropout = nn.Dropout(dropout) + self.resid_dropout = nn.Dropout(dropout) + self.n_head = n_head + self.n_embd = n_embd + self.dropout = dropout + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class SimpleTransformerModel(nn.Module): + def __init__(self, n_embd: int, n_head: int, dropout: float, nlayers: int, vocab_size: int): + super().__init__() + + self.layers = nn.ModuleList([]) + self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) + for _ in range(nlayers): + self.layers.append(CausalSelfAttention(n_embd, n_head, dropout)) + + def forward(self, data): + x = data['input'] + target = data['target'] + for layer in self.layers: + x = layer(x) + logits = self.lm_head(x) + loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1), ignore_index=-1) + return loss + + +def csa_forward_args_gen_fn(trainer_args: TrainerArgs): + seq_len = 128 # dynamicness is controlled by trainer_args.vars['dynamic_dims'] + + return { + 'x': torch.randn(1, seq_len, trainer_args.model.args['n_embd']), + } + + +def post_csa_forward_args_gen_fn(trainer_args: TrainerArgs, args): + dynamic_dims = trainer_args.get_resolved_var('dynamic_dims', default=[]) + nnscaler.mark_dynamic(args['x'], dynamic_dims) + return args + + +def transformer_dummy_sample_gen_fn(trainer_args: TrainerArgs): + seq_len = 128 # dynamicness is controlled by trainer_args.vars['dynamic_dims'] + dynamic_dims = trainer_args.get_resolved_var('dynamic_dims', default=[]) + return { + 'input': nnscaler.mark_dynamic(torch.randn(1, seq_len, trainer_args.model.args['n_embd']), dynamic_dims), + 'target': nnscaler.mark_dynamic(torch.randint(0, trainer_args.model.args['vocab_size'], (1, seq_len)), dynamic_dims), + } + + class MixModuleMLP(nn.Module): def __init__(self, dim: int, nlayers: int, init_random: bool = True): super().__init__() diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 427c8f26..dffb0813 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -209,6 +209,28 @@ class A: assert y.p == {'value': 'auto'} +def test_merge_dict(): + a = { + 'compute_config': { + 'plan_ngpus': 1 + }, + 'optimizer': { + 'type': 'torch.nn.Adam', + 'args': { + 'lr': 0.001 + } + } + } + merge_args(a, ['--optimizer', { + 'type': 'torch.nn.AdamW', + 'args': { + 'hello': 'haha' + } + }]) + assert a['optimizer']['args']['lr'] == 0.001 + assert a['optimizer']['args']['hello'] == 'haha' + + def test_merge_list(): @dataclass class A: diff --git a/tests/cli/test_hooks.py b/tests/cli/test_hooks.py new file mode 100644 index 00000000..102a48bc --- /dev/null +++ b/tests/cli/test_hooks.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, List + +from nnscaler.cli.train_hook import TrainHook, TrainHookHost + + +class A(TrainHook): + pass + +class B(TrainHook): + pass + +class C(TrainHook, TrainHookHost): + def _get_hook_objects(self) -> List[Any]: + return [A(), B(), self] + + +class D(TrainHookHost): + def _get_hook_objects(self) -> List[Any]: + return [self, A(), C()] + +def test_hook(): + hooks = D().get_hooks() + assert len(hooks) == 4 + assert isinstance(hooks[0], A) + assert isinstance(hooks[1], C) + assert isinstance(hooks[2], A) + assert isinstance(hooks[3], B) diff --git a/tests/cli/test_serialization.py b/tests/cli/test_serialization.py new file mode 100644 index 00000000..cd5ae9b4 --- /dev/null +++ b/tests/cli/test_serialization.py @@ -0,0 +1,245 @@ +import pytest +import torch +from pathlib import Path + +from nnscaler.cli.serialization import ( + convert_format, SerializationRunner, register_serialization_runner, + Checkpointer +) +from nnscaler.cli.trainer import Trainer +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.common import assert_equal + + +def test_runner(tmp_path): + + class SplitSerializationRunner(SerializationRunner): + name: str = 'split' + + def run_load(self, load_func, f, *, device='cpu'): + model_state_dict = load_func(f, device=device) + opt_state_dict = load_func(str(f) + '.opt', device=device) + return { + 'model': model_state_dict, + 'optimizer': opt_state_dict + } + + def run_save(self, save_func, state_dict, f): + save_func(state_dict['model'], f) + save_func(state_dict['optimizer'], str(f) + '.opt') + + register_serialization_runner(SplitSerializationRunner) + + a = torch.randn((2, 2), device='cpu') + b = torch.randn((2, 3), device='cpu') + c = torch.randn((4, 4), device='cpu') + d = torch.randn((3, 3), device='cpu') + tensors = { + "model": { + "embedding": a, + "attention": b, + }, + "optimizer": { + "state": { + 0: { + "exp_avg": c, + "exp_avg_sq": d, + } + } + } + } + checkpointer = Checkpointer() + checkpointer.save(tensors, tmp_path / "model.ckpt") + checkpointer.flush() + + convert_format( + src=str(tmp_path / "model.ckpt"), + dst=str(tmp_path / "model_split.ckpt"), + dst_serializer='split', + ) + + assert Path(tmp_path / "model_split.ckpt").exists() + assert Path(tmp_path / "model_split.ckpt.opt").exists() + tensor3 = Checkpointer(serializer='split').load(tmp_path / "model_split.ckpt") + assert_equal(tensors, tensor3) + + checkpointer2 = Checkpointer(serializer=':split') + tensor2 = checkpointer2.load(tmp_path / "model.ckpt") + assert_equal(tensors, tensor2) + + checkpointer2.save(tensor2, tmp_path / "model_split2.ckpt") + checkpointer2.flush() + assert Path(tmp_path / "model_split2.ckpt").exists() + assert Path(tmp_path / "model_split2.ckpt.opt").exists() + + tensor4 = Checkpointer(serializer='split').load(tmp_path / "model_split2.ckpt") + assert_equal(tensors, tensor4) + + +def trainer_split_serializer_worker(tmp_path, symblink): + save_dir = Path(tmp_path) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' + use_zero = 1 + format = 'safetensors' + rev_format = 'pt' if format == 'safetensors' else 'safetensors' + + def list_ckpt_files(dir): + return set(dir.glob('**/*.ckpt')) | set(dir.glob('**/*.safetensors')) + + + class SplitSerializationRunner(SerializationRunner): + name: str = 'split' + + def __init__(self, mark=''): + self.mark = mark + + def run_load(self, load_func, f, *, device='cpu'): + other_state_dict = load_func(f, device=device) + opt_state_dict = load_func(str(f) + '.opt', device=device) + model_state_dict = load_func(str(f) + '.model', device=device) + return { + 'model': model_state_dict, + 'optimizer': opt_state_dict, + **other_state_dict + } + + def run_save(self, save_func, state_dict, f): + save_func(state_dict['model'], str(f) + '.model') + save_func(state_dict['optimizer'], str(f) + '.opt') + other_state_dict = {k: v for k, v in state_dict.items() if k not in ['model', 'optimizer']} + other_state_dict['mark'] = self.mark + save_func(other_state_dict, f) + + register_serialization_runner(SplitSerializationRunner) + + # train 4 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.format', format, + '--checkpoint.serializer.name', 'split', + '--checkpoint.serializer.args.mark', 'hello', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + ckpt_files = list_ckpt_files(ckpt_savedir) + assert len(ckpt_files)/4 == min(10, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last + + for f in ckpt_files: + assert trainer.checkpointer.load(f)['mark'] == 'hello' + assert Path(str(f) + '.opt').exists() + assert Path(str(f) + '.model').exists() + + torch.distributed.barrier() + # train 4 epcho two times (resume from last) + ckpt0_savedir = save_dir / 'ckpt0' + # first two epochs + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '2', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.format', format, + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + + # create merged checkpoint + ckpt1_savedir = save_dir / 'ckpt1' + ckpt1_savedir.mkdir(parents=True, exist_ok=True) + merged_file_name = f'merged{Checkpointer.NAME_MAP[format]}' + if trainer.rank == 0: + Trainer.merge_checkpoint(trainer.checkpointer.list_checkpoints(ckpt0_savedir / 'last'), ckpt1_savedir / merged_file_name, serializer='split') + assert Path(str(ckpt1_savedir / merged_file_name) + '.opt').exists() + assert Path(str(ckpt1_savedir / merged_file_name) + '.model').exists() + + torch.distributed.barrier() + # continue with the last two epochs (resume for sharded/deduped checkpoint) + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.format', rev_format, + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + + torch.distributed.barrier() + + # continue with the last two epochs (resume for merged) + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt1_savedir), + '--checkpoint.format', rev_format, + '--checkpoint.resume_from', str(ckpt1_savedir / merged_file_name), + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) + + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} + for i in range(4): + x = trainer.checkpointer.load_for_rank(ckpt_savedir / 'last', i) + y = trainer.checkpointer.load_for_rank(ckpt0_savedir / 'last', i) + z = trainer.checkpointer.load_for_rank(ckpt1_savedir / 'last', i) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + assert_equal(x['lr_scheduler'], y['lr_scheduler']) + assert_equal(x['model'], z['model']) + assert_equal(x['optimizer'], z['optimizer']) + assert_equal(x['lr_scheduler'], z['lr_scheduler']) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('symblink', [True, False]) +def test_trainer_split_serializer(tmp_path, symblink): + launch_torchrun(4, trainer_split_serializer_worker, tmp_path, symblink) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 2dc1d252..307adbef 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -2,13 +2,19 @@ # Licensed under the MIT License. from pathlib import Path +import re import shutil +from typing import Any +from mock import PropertyMock import torch import pytest import torch.distributed +from unittest.mock import patch from nnscaler import merge_state_dicts +from nnscaler.cli.serialization import Checkpointer +import nnscaler from nnscaler.cli.trainer import Trainer, logger from nnscaler.cli.trainer_args import AggregatedOutputs, TrainerArgs from tests.parallel_module.common import assert_equal, assert_close @@ -117,7 +123,12 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' \ if bf16 == 'Mixed' \ else 'torch.optim.Adam' - use_zero = save_type == 'sharded' + use_zero = 1 if save_type == 'sharded' else 0 + format = 'safetensors' if parallel_type % 2 else 'pt' + rev_format = 'pt' if format == 'safetensors' else 'safetensors' + + def list_ckpt_files(dir): + return set(dir.glob('**/*.ckpt')) | set(dir.glob('**/*.safetensors')) if parallel_type == 0: additional_args = [] @@ -147,7 +158,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False' if use_zero else 'True', + '--model.parallel_modules.0.compute_config.use_zero', str(use_zero), '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -164,7 +175,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False' if use_zero else 'True', + '--model.parallel_modules.0.compute_config.use_zero', str(use_zero), '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -191,10 +202,11 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', format, *additional_args, ]) trainer.run() - ckpt_files = set(ckpt_savedir.glob('**/*.ckpt')) + ckpt_files = list_ckpt_files(ckpt_savedir) assert len(ckpt_files)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -215,10 +227,11 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', format, *additional_args, ]) trainer.run() - ckpt0_files0 = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} + ckpt0_files0 = {f: f.stat().st_mtime_ns for f in list_ckpt_files(ckpt0_savedir)} assert len(ckpt0_files0)/4 == min(30, trainer.total_train_steps_per_epoch * 2) + 2 # 2 for best/last # resume from last without update max_epochs @@ -236,18 +249,20 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', rev_format, *additional_args, ]) trainer.run() - ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} + ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in list_ckpt_files(ckpt0_savedir)} # nothing should be updated in this case. assert ckpt0_files0 == ckpt0_files0_x # create merged checkpoint ckpt1_savedir = save_dir / 'ckpt1' ckpt1_savedir.mkdir(parents=True, exist_ok=True) + merged_file_name = f'merged{Checkpointer.NAME_MAP[format]}' if trainer.rank == 0: - Trainer.merge_checkpoint(list((ckpt0_savedir / 'last').glob('*.ckpt')), ckpt1_savedir / 'merged.pt') + Trainer.merge_checkpoint(trainer.checkpointer.list_checkpoints(ckpt0_savedir / 'last'), ckpt1_savedir / merged_file_name) torch.distributed.barrier() # continue with the last two epochs (resume for sharded/deduped checkpoint) @@ -264,6 +279,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', + '--checkpoint.format', rev_format, '--checkpoint.keep_last_n_checkpoints', '30', *additional_args, ]) @@ -276,7 +292,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): for f, s in left_files.items(): # make sure the old checkpoints are not overwritten assert ckpt0_files0[f] == s - ckpt0_files1 = set(ckpt0_savedir.glob('**/*.ckpt')) + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) assert len(ckpt0_files1)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -293,7 +309,8 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--compute_config.use_zero', str(use_zero), '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt1_savedir), - '--checkpoint.resume_from', str(ckpt1_savedir / 'merged.pt'), + '--checkpoint.format', rev_format, + '--checkpoint.resume_from', str(ckpt1_savedir / merged_file_name), '--checkpoint.keep_last_n_checkpoints', '30', *additional_args, ]) @@ -306,7 +323,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): for f, s in left_files.items(): # make sure the old checkpoints are not overwritten assert ckpt0_files0[f] == s - ckpt0_files1 = set(ckpt0_savedir.glob('**/*.ckpt')) + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) assert len(ckpt0_files1)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -314,9 +331,9 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): if torch.distributed.get_rank() == 0: assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} for i in range(4): - x = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) - y = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) - z = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + x = trainer.checkpointer.load_for_rank(ckpt_savedir / 'last', i) + y = trainer.checkpointer.load_for_rank(ckpt0_savedir / 'last', i) + z = trainer.checkpointer.load_for_rank(ckpt1_savedir / 'last', i) assert_equal(x['model'], y['model']) assert_equal(x['optimizer'], y['optimizer']) assert_equal(x['lr_scheduler'], y['lr_scheduler']) @@ -324,12 +341,13 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): assert_equal(x['optimizer'], z['optimizer']) assert_equal(x['lr_scheduler'], z['lr_scheduler']) + suffix = Checkpointer.NAME_MAP[format] if save_type == 'deduped': - assert (ckpt_savedir / 'last/0.ckpt').stat().st_size > (ckpt_savedir / 'last/2.ckpt').stat().st_size - assert (ckpt_savedir / 'last/1.ckpt').stat().st_size > (ckpt_savedir / 'last/3.ckpt').stat().st_size + assert (ckpt_savedir / f'last/0{suffix}').stat().st_size > (ckpt_savedir / f'last/2{suffix}').stat().st_size + assert (ckpt_savedir / f'last/1{suffix}').stat().st_size > (ckpt_savedir / f'last/3{suffix}').stat().st_size else: - assert (ckpt_savedir / 'last/0.ckpt').stat().st_size == (ckpt_savedir / 'last/2.ckpt').stat().st_size - assert (ckpt_savedir / 'last/1.ckpt').stat().st_size == (ckpt_savedir / 'last/3.ckpt').stat().st_size + assert (ckpt_savedir / f'last/0{suffix}').stat().st_size == (ckpt_savedir / f'last/2{suffix}').stat().st_size + assert (ckpt_savedir / f'last/1{suffix}').stat().st_size == (ckpt_savedir / f'last/3{suffix}').stat().st_size @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @@ -593,7 +611,7 @@ def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): gen_savedir = save_dir / 'gen' ckpt_savedir = save_dir / 'ckpt' optimizer_type = 'torch.optim.Adam' - use_zero = False if zero_ngroups is None else True + use_zero = 0 if zero_ngroups is None else 1 zero_ngroups = '1' if zero_ngroups is None else zero_ngroups trainer = Trainer([ @@ -614,7 +632,7 @@ def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): torch.distributed.barrier() -def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): +def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False, hybrid_opt=False, use_zero=0): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) gen_savedir = save_dir / 'gen' @@ -651,7 +669,7 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False', + '--model.parallel_modules.0.compute_config.use_zero', '0', '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -668,7 +686,7 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False', + '--model.parallel_modules.0.compute_config.use_zero', '0', '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -680,6 +698,47 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): else: raise ValueError(f'parallel_type {parallel_type} is not supported') + + def param_clss_fn(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'mlp0.' in param_name: + return 0, 0 + elif 'mlp1.' in param_name: + return 0, 1 + else: + return 1, 0 + + optimizer_config = { + 'type': 'nnscaler.HybridOptimizer', + 'param_clss_fn': param_clss_fn, + 'args': { + 'config': { + 'optimizers':[ + { + 'type': torch.optim.Adam, + 'options': { + 'lr': 0.01, + }, + 'param_groups': [ + {}, + {} + ], + },{ + 'type': torch.optim.Adam, + 'options': { + 'lr': 0.01 + } + } + ] + } + } + } + + if hybrid_opt: + additional_args.extend(['--optimizer!', '--optimizer', optimizer_config]) + # train 4 epcho in one time trainer = Trainer([ '-f', config_path, @@ -687,7 +746,7 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): '--max_epochs', '2', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), - '--compute_config.use_zero', 'False', + '--compute_config.use_zero', str(use_zero), '--compute_config.plan_ngpus', '1', '--compute_config.runtime_ngpus', '2', '--compute_config.use_async_reducer', str(async_reducer), @@ -713,33 +772,49 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): torch.distributed.barrier() -def trainer_correctness_worker_aggregate(tmp_path): +def trainer_correctness_worker_aggregate(tmp_path, use_zero): for parallel_type in range(5): for async_reducer in [False, True]: - print(f'parallel_type={parallel_type}, async_reducer={async_reducer}') - save_dir = tmp_path/f'{parallel_type}-{async_reducer}' - trainer_correctness_worker(save_dir, parallel_type, async_reducer) + for hybrid_opt in [True, False]: + print(f'parallel_type={parallel_type}, async_reducer={async_reducer}, hybrid_opt={hybrid_opt}') + save_dir = tmp_path/f'{parallel_type}-{async_reducer}-{hybrid_opt}' + trainer_correctness_worker(save_dir, parallel_type, async_reducer, hybrid_opt, use_zero) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') -def test_trainer_correctness(tmp_path): - launch_torchrun(2, trainer_correctness_worker_aggregate, tmp_path) +@pytest.mark.parametrize('use_zero', [0, 1, 3]) +def test_trainer_correctness(tmp_path, use_zero): + launch_torchrun(2, trainer_correctness_worker_aggregate, tmp_path, use_zero) merged_ckpts = {} for parallel_type in range(5): for async_reducer in [False, True]: - save_dir = tmp_path/f'{parallel_type}-{async_reducer}' - merged_ckpts[(parallel_type, async_reducer)] = torch.load(save_dir/'merged.pt') + for hybrid_opt in [True, False]: + save_dir = tmp_path/f'{parallel_type}-{async_reducer}-{hybrid_opt}' + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)] = torch.load(save_dir/'merged.pt') + + if use_zero == 3: + assert_fn = assert_close + else: + assert_fn = assert_equal for parallel_type in range(5): for async_reducer in [False, True]: - assert_equal( - merged_ckpts[(parallel_type, async_reducer)]['model'], - merged_ckpts[(0, False)]['model'] - ) - assert_equal( - merged_ckpts[(parallel_type, async_reducer)]['optimizer'], - merged_ckpts[(0, False)]['optimizer'] - ) + for hybrid_opt in [True, False]: + assert_fn( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['model'], + merged_ckpts[(0, False, False)]['model'] + ) + if not hybrid_opt: + assert_fn( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['optimizer'], + merged_ckpts[(0, False, False)]['optimizer'] + ) + else: + # param_groups are different when using hybrid optimizer. + assert_fn( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['optimizer']['state'], + merged_ckpts[(0, False, False)]['optimizer']['state'] + ) def tracing_from_weights_worker(tmp_path): @@ -936,19 +1011,45 @@ def trainer_resumable_dataloader(save_dir): torch.distributed.barrier() # resume for merged - trainer = Trainer([ - '-f', config_path_streaming, - '--precision', 'bf16', - '--optimizer.type', optimizer_type, - '--enable_progress_bar', 'false', - '--gen_savedir', str(gen_savedir), - '--checkpoint.save_type', save_type, - '--checkpoint.save_dir', str(ckpt2_savedir), - '--checkpoint.resume_from', str(ckpt2_savedir / 'merged.pt'), - '--checkpoint.keep_last_n_checkpoints', '30', - ]) - trainer.run() - assert trainer.dataloader_resumed + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt2_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt2_savedir / 'merged.pt'), + '--checkpoint.resume_from.save_memory', False, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' not in log.getvalue() # no warning about dataloader states + + torch.distributed.barrier() + + + ckpt2_1_savedir = save_dir / 'ckpt2_1' + ckpt2_1_savedir.mkdir(parents=True, exist_ok=True) + # resume for merged + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt2_1_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt2_savedir / 'merged.pt'), + '--checkpoint.resume_from.save_memory', True, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states torch.distributed.barrier() @@ -981,20 +1082,44 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.save_dir', str(ckpt4_savedir), '--checkpoint.resume_from.checkpoint', str(ckpt1_savedir / '0002-0035'), '--checkpoint.resume_from.with_merged', True, + '--checkpoint.resume_from.save_memory', False, '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.run() assert trainer.dataloader_resumed assert 'Broadcasting merged checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states + # resume from auto-merged with save_memory + ckpt5_savedir = save_dir / 'ckpt5' + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt5_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt1_savedir / '0002-0035'), + '--checkpoint.resume_from.with_merged', True, + '--checkpoint.resume_from.save_memory', True, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states + + if torch.distributed.get_rank() == 0: for i in range(4): g = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) z = torch.load(ckpt2_savedir / 'last' / f'{i}.ckpt', weights_only=False) + z_1 = torch.load(ckpt2_1_savedir / 'last' / f'{i}.ckpt', weights_only=False) w = torch.load(ckpt3_savedir / 'last' / f'{i}.ckpt', weights_only=False) v = torch.load(ckpt4_savedir / 'last' / f'{i}.ckpt', weights_only=False) + u = torch.load(ckpt5_savedir / 'last' / f'{i}.ckpt', weights_only=False) assert 'dataloader' not in g assert 'dataloader' in x for key in ['model', 'optimizer', 'lr_scheduler', 'dataloader']: @@ -1002,6 +1127,8 @@ def trainer_resumable_dataloader(save_dir): assert_equal(x[key], z[key]) assert_equal(x[key], w[key]) assert_equal(x[key], v[key]) + assert_equal(x[key], u[key]) + assert_equal(x[key], z_1[key]) if key != 'dataloader': assert_equal(g[key], x[key]) @@ -1009,3 +1136,390 @@ def trainer_resumable_dataloader(save_dir): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_trainer_resumable_dataloader(tmp_path): launch_torchrun(4, trainer_resumable_dataloader, tmp_path) + + +@replace_all_device_with('cpu') +def test_trainer_dynamic_worker(tmp_path): + + def check_match(code_dir: Path, should_exist: bool): + gencode_files = list(code_dir.glob('**/*.py')) + assert set(f.name for f in gencode_files) == set(['gencode0.py', 'gencode1.py', 'gencode2.py', 'gencode3.py']) + for gencode_file in gencode_files: + filecontent = gencode_file.read_text() + matches = re.findall(r'B, T, C = x\.size\(\)', filecontent) + if should_exist: + assert matches + else: + assert not matches + + shutil.rmtree(code_dir) + + save_dir = Path(tmp_path) + config_path = str(Path(__file__).with_name('trainer_args_csa.yaml').resolve()) + gen_savedir = save_dir / 'gen' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[1]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + ]) + trainer.run() + check_match(gen_savedir, should_exist=True) + + gen_savedir = save_dir / 'gen0' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + ]) + trainer.run() + check_match(gen_savedir, should_exist=False) + + # mixed compile + gen_savedir = save_dir / 'gen1' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[1]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + + '--model.parallel_modules.0.type', 'tests.cli.common.CausalSelfAttention', + '--model.parallel_modules.0.args.n_embd', '$(model.args.n_embd)', + '--model.parallel_modules.0.args.n_head', '$(model.args.n_head)', + '--model.parallel_modules.0.args.dropout', '$(model.args.dropout)', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.csa_forward_args_gen_fn', + '--model.parallel_modules.0.forward_args_post_process_fn', 'tests.cli.common.post_csa_forward_args_gen_fn', + ]) + trainer.run() + check_match(gen_savedir, should_exist=True) + + # mixed compile + gen_savedir = save_dir / 'gen2' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + + '--model.parallel_modules.0.type', 'tests.cli.common.CausalSelfAttention', + '--model.parallel_modules.0.args.n_embd', '$(model.args.n_embd)', + '--model.parallel_modules.0.args.n_head', '$(model.args.n_head)', + '--model.parallel_modules.0.args.dropout', '$(model.args.dropout)', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.csa_forward_args_gen_fn', + '--model.parallel_modules.0.forward_args_post_process_fn', 'tests.cli.common.post_csa_forward_args_gen_fn', + ]) + trainer.run() + check_match(gen_savedir, should_exist=False) + + +def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + + zero_ngroups = runtime_ngpus // plan_ngpus // 2 + if zero_ngroups < 1: + zero_ngroups = 1 + policy = 'dp' if plan_ngpus == 1 else 'tp' + + gen3_savedir = save_dir / 'gen3' + ckpt3_savedir = save_dir / 'ckpt3' + # train 1 epcho in one time with zero3 + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '5', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + torch.distributed.barrier() + + # load from sharded + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', 'last', + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + # load from deduped + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '15', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', 'last', + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + + torch.distributed.barrier() + + # load from merged (from deduped) + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', str(ckpt3_savedir / 'merged.pt'), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + + with ( + patch('nnscaler.ComputeConfig.module_dedup_group_size', new_callable=PropertyMock) as mock_dgs, + patch('nnscaler.ComputeConfig.optimizer_dedup_group_size', new_callable=PropertyMock) as mock_dgs2 + ): + # to mock the case where we have duplicated data in merging + mock_dgs.return_value = runtime_ngpus + mock_dgs2.return_value = runtime_ngpus + + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged2.pt') + zero3_merged_state_dict2 = torch.load(ckpt3_savedir / 'merged2.pt') + zero3_merged_state_dict = torch.load(ckpt3_savedir / 'merged.pt') + assert_equal(zero3_merged_state_dict, zero3_merged_state_dict2) + + torch.distributed.barrier() + + # load from merged (from sharded) + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '25', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', str(ckpt3_savedir / 'merged.pt'), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + torch.distributed.barrier() + + gen1_savedir = save_dir / 'gen1' + ckpt1_savedir = save_dir / 'ckpt1' + + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '25', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen1_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '1', + '--checkpoint.save_dir', str(ckpt1_savedir), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt1_savedir / 'last').glob('*.ckpt')), ckpt1_savedir / 'merged.pt') + zero1_merged_state_dict = torch.load(ckpt1_savedir / 'merged.pt') + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + zero3_merged_state_dict = torch.load(ckpt3_savedir / 'merged.pt') + assert_equal(zero1_merged_state_dict['model'], zero3_merged_state_dict['model']) + assert_equal(zero1_merged_state_dict['optimizer'], zero3_merged_state_dict['optimizer']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_trainer_zero3(tmp_path): + launch_torchrun(2, trainer_zero3, 16, tmp_path, 1, 2) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_trainer_zero3_tp(tmp_path): + launch_torchrun(4, trainer_zero3, 16, tmp_path, 2, 4) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_trainer_zero3_ngroup(tmp_path): + # dim that needs padding + launch_torchrun(4, trainer_zero3, 13, tmp_path, 1, 4) + + +def trainer_checkpointer_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + from nnscaler.cli import register_format + + load_triggered = False + + class TestFormat: + name: str = 'test_format' + suffix: str = '.testpt' + + @classmethod + def save(cls, obj: Any, f: Path) -> None: + obj['test'] = True + return torch.save(obj, f) + + @classmethod + def load(cls, f: str | Path, *, device='cpu') -> Any: + x = torch.load(f, map_location=device, weights_only=False) + assert x['test'] is True + nonlocal load_triggered + load_triggered = True + return x + + register_format(TestFormat) + + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '1', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.format', 'test_format', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + + files0 = list(ckpt_savedir.glob('**/*.testpt')) + assert files0, 'No checkpoint files saved with custom format.' + + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.format', 'test_format', + '--checkpoint.resume_from', 'last', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + assert load_triggered, 'Custom load function not triggered when resuming.' + + files1 = list(ckpt_savedir.glob('**/*.testpt')) + assert len(files1) > len(files0), 'Checkpoint files not updated after resuming.' + assert all(f in files1 for f in files0), 'Some checkpoint files missing after resuming.' + assert files1, 'No checkpoint files saved with custom format.' + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_custom_checkpointer(tmp_path): + launch_torchrun(1, trainer_checkpointer_worker, tmp_path) + + +def trainer_pas_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + from nnscaler.policies import pas_dp + from nnscaler.cli import TrainerArgs + called = False + + def custom_pas(graph, cfg): + nonlocal called + called = True + return pas_dp(graph, cfg) + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--max_epochs', '1', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + args.pas_policy = custom_pas + # train 1 epcho in one time + trainer = Trainer(train_args=args) + trainer.run() + + assert called, 'Custom PAS policy not called.' + + gen_savedir = save_dir / 'gen2' + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--pas-policy', 'nnscaler.policies.pas_dp', # use full qualified name of pas policy + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_custom_pas(tmp_path): + launch_torchrun(1, trainer_pas_worker, tmp_path) diff --git a/tests/cli/test_trainer2.py b/tests/cli/test_trainer2.py new file mode 100644 index 00000000..34007486 --- /dev/null +++ b/tests/cli/test_trainer2.py @@ -0,0 +1,134 @@ +from pathlib import Path +import pytest +import torch +from torch.utils.data import Dataset + +from nnscaler.cli import TrainerArgs, Trainer +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode + + +class NanoGptDataset(Dataset): + def __init__(self, *args, **kwargs): + pass + + def __getitems__(self, indices): + return [torch.randint(0, 151936, (1, 4096), dtype=torch.int64) for _ in indices] + + def __len__(self): + return 10000 + + +def gen_args(trainer_args: 'TrainerArgs'): + src_token = torch.randint(0, 151936, (1, 4096), dtype=torch.int64) + ret = dict( + input_ids=src_token, # torch.Size([1, 4096]) torch.int64 + ) + return ret + + +class WrappedSubModel(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + self.embedding = torch.nn.Embedding(151936, 1536) + + def forward(self, input_ids): + x = self.embedding(input_ids) + return x + + +class WrapperModel(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + self.model = WrappedSubModel() + + def forward(self, src_tokens): + # the logic is from task.train_step + logits = self.model( + src_tokens + ) + return torch.sum(logits) + + +def trainer_mixed_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args_mixed_bf16.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer = Trainer(train_args=args) + trainer.run() + # should reach here without error + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_mixed_bf16_model(tmp_path): + launch_torchrun(2, trainer_mixed_worker, tmp_path) + + +class SharedWeightsDataset(Dataset): + def __init__(self, *args, **kwargs): + pass + + def __getitems__(self, indices): + return [torch.randn(4, 4) for _ in indices] + + def __len__(self): + return 10000 + + +class SharedWeightsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + self.linear2 = torch.nn.Linear(4, 4, bias=False) + self.linear2.weight = self.linear.weight # share weight + + def forward(self, x): + y = x * 2 + z = x + 2 + r = self.linear2(y) + r = r + self.linear(z) + return torch.sum(r) + + +def trainer_zero3_shared_weights_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args_shared_weights.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer = Trainer(train_args=args) + trainer.run() + # weight sharing multiref should have clone_level=1 in gencode + assert _gencode_contains( + gen_savedir, + SharedWeightsModule, + torch.distributed.get_rank(), + r'linear_weight_\d+, linear_weight_\d+ = nnscaler.runtime.function.multiref\(self.linear_weight_\d+, times=2, clone_level=1\)' + ) + # non-weight tensor multiref should not have clone_level + assert _gencode_contains( + gen_savedir, + SharedWeightsModule, + torch.distributed.get_rank(), + r'x_\d+, x_\d+ = nnscaler.runtime.function.multiref\(x_\d+, times=2\)' + ) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_zero3_shared_weights(tmp_path): + launch_torchrun(4, trainer_zero3_shared_weights_worker, tmp_path) diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml index db05b3e6..5006b87a 100644 --- a/tests/cli/trainer_args.yaml +++ b/tests/cli/trainer_args.yaml @@ -1,6 +1,7 @@ vars: dim: 16 drop_last: true + compute_config: plan_ngpus: 4 runtime_ngpus: 100 diff --git a/tests/cli/trainer_args_csa.yaml b/tests/cli/trainer_args_csa.yaml new file mode 100644 index 00000000..1a18c6e3 --- /dev/null +++ b/tests/cli/trainer_args_csa.yaml @@ -0,0 +1,53 @@ +vars: + dynamic_dims: [1] + dim: 16 + drop_last: true +compute_config: + plan_ngpus: 4 + runtime_ngpus: 100 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: tp +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 +max_train_steps: 100 +seed: 0 +dummy_sample_gen_fn: tests.cli.common.transformer_dummy_sample_gen_fn + +model: + type: tests.cli.common.SimpleTransformerModel + args: + n_embd: 1024 + n_head: 8 + dropout: 0.001 + nlayers: 2 + vocab_size: 10000 + +optimizer: + type: torch.optim.Adam + args: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: $(vars.dim) + size: 100 + val_args: + dim: $(vars.dim) + size: 10 + +dataloader: + train_args: + drop_last: $(vars.drop_last) + val_args: + drop_last: $(vars.drop_last) + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/cli/trainer_args_mixed_bf16.yaml b/tests/cli/trainer_args_mixed_bf16.yaml new file mode 100644 index 00000000..3ba80cbb --- /dev/null +++ b/tests/cli/trainer_args_mixed_bf16.yaml @@ -0,0 +1,36 @@ +compute_config: + plan_ngpus: 1 + runtime_ngpus: 2 + constant_folding: false + use_zero: 3 + use_end2end: true + +run_mode: run +pas_policy: dp +micro_batch_size: 1 +grad_accumulation_steps: 4 +max_train_steps: 10 +enable_progress_bar: false +log_progress_every_n_train_steps: 10 +precision: bf16 +seed: 1 + +model: + type: tests.cli.test_trainer2.WrapperModel + + parallel_modules: + - type: tests.cli.test_trainer2.WrappedSubModel + forward_args_gen_fn: tests.cli.test_trainer2.gen_args + +optimizer: + type: torch.optim.AdamW + args: + betas: (0.9, 0.95) + eps: 1e-08 + weight_decay: 0.1 + lr: 0.0001 + fused: true + clip_gnorm: 2.0 + +dataset: + type: tests.cli.test_trainer2.NanoGptDataset diff --git a/tests/cli/trainer_args_shared_weights.yaml b/tests/cli/trainer_args_shared_weights.yaml new file mode 100644 index 00000000..2dbd6f7a --- /dev/null +++ b/tests/cli/trainer_args_shared_weights.yaml @@ -0,0 +1,32 @@ +compute_config: + plan_ngpus: 2 + runtime_ngpus: 4 + constant_folding: false + use_zero: 3 + use_end2end: true + +run_mode: run +pas_policy: tp +micro_batch_size: 1 +grad_accumulation_steps: 4 +max_train_steps: 10 +enable_progress_bar: false +log_progress_every_n_train_steps: 10 +precision: bf16 +seed: 1 + +model: + type: tests.cli.test_trainer2.SharedWeightsModule + +optimizer: + type: torch.optim.AdamW + args: + betas: (0.9, 0.95) + eps: 1e-08 + weight_decay: 0.1 + lr: 0.0001 + fused: true + clip_gnorm: 2.0 + +dataset: + type: tests.cli.test_trainer2.SharedWeightsDataset diff --git a/tests/customized_ops/__init__.py b/tests/customized_ops/__init__.py new file mode 100644 index 00000000..78e3db5e --- /dev/null +++ b/tests/customized_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Ring Attention test module""" \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/__init__.py b/tests/customized_ops/ring_attn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/customized_ops/ring_attn/configs.py b/tests/customized_ops/ring_attn/configs.py new file mode 100644 index 00000000..ebc7182c --- /dev/null +++ b/tests/customized_ops/ring_attn/configs.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Configuration file for ring attention tests. +This file contains predefined test configurations for both correctness and performance testing. +""" + +from dataclasses import dataclass +from typing import List, Tuple, Optional + + +@dataclass +class RingAttnConfig: + """Configuration for ring attention test cases""" + batch_size: int + num_heads: int + head_dim: int + max_seqlen: int + dtype: str = "bf16" + name: str = "" + num_kv_heads: Optional[int] = None # For GQA/MQA support + causal: bool = True # Most attention patterns are causal + window_size: Tuple[int, int] = (-1, -1) # Sliding window attention (-1, -1) means no window + + def __post_init__(self): + # Set num_kv_heads to num_heads if not specified (standard MHA) + if self.num_kv_heads is None: + self.num_kv_heads = self.num_heads + + if not self.name: + gqa_suffix = f"_gqa{self.num_kv_heads}" if self.num_kv_heads != self.num_heads else "" + causal_suffix = "" if self.causal else "_noncausal" + window_suffix = f"_w{self.window_size[0]}-{self.window_size[1]}" if self.window_size != (-1, -1) else "" + self.name = f"b{self.batch_size}_h{self.num_heads}_d{self.head_dim}_s{self.max_seqlen}_{self.dtype}{gqa_suffix}{causal_suffix}{window_suffix}" + + # Generate cu_seqlens for variable length sequences + # Create sequences with different lengths for more realistic testing + seq_lens = [ + self.max_seqlen // 8, # Short sequence + self.max_seqlen // 4, # Medium sequence + self.max_seqlen // 2, # Long sequence + self.max_seqlen - self.max_seqlen // 8 - self.max_seqlen // 4 - self.max_seqlen // 2 # Remaining + ] + self.cu_seqlens = [0] + for seq_len in seq_lens: + self.cu_seqlens.append(self.cu_seqlens[-1] + seq_len) + + @property + def total_tokens(self) -> int: + """Total number of tokens across all sequences""" + return self.cu_seqlens[-1] + + @property + def is_gqa(self) -> bool: + """Check if this is a GQA (Grouped Query Attention) configuration""" + return self.num_kv_heads < self.num_heads + + @property + def is_mqa(self) -> bool: + """Check if this is an MQA (Multi-Query Attention) configuration""" + return self.num_kv_heads == 1 + + @property + def num_groups(self) -> int: + """Number of query heads per KV head (group size)""" + return self.num_heads // self.num_kv_heads + + +# Small test cases for quick correctness validation +SMALL_CONFIGS = { + "tiny": RingAttnConfig(2, 8, 64, 1024, "bf16", "tiny", causal=True), + "small": RingAttnConfig(4, 12, 128, 4096, "bf16", "small", causal=True), + "small_fp16": RingAttnConfig(4, 12, 128, 4096, "fp16", "small_fp16", causal=False), # One non-causal config + "small_window": RingAttnConfig(4, 12, 128, 4096, "bf16", "small_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Medium test cases for standard testing +MEDIUM_CONFIGS = { + "medium": RingAttnConfig(4, 24, 128, 8192, "bf16", "medium", causal=True), + "medium_large_head": RingAttnConfig(4, 12, 256, 8192, "bf16", "medium_large_head", causal=False), # One non-causal config + "medium_many_heads": RingAttnConfig(4, 32, 128, 8192, "bf16", "medium_many_heads", causal=True), + "medium_fp16": RingAttnConfig(4, 24, 128, 8192, "fp16", "medium_fp16", causal=True), + "medium_window": RingAttnConfig(4, 24, 128, 8192, "bf16", "medium_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Large test cases for performance benchmarking +LARGE_CONFIGS = { + "large": RingAttnConfig(4, 32, 128, 16384, "bf16", "large", causal=True), + "large_seq": RingAttnConfig(4, 24, 128, 32768, "bf16", "large_seq", causal=True), + "large_head": RingAttnConfig(4, 24, 256, 16384, "bf16", "large_head", causal=False), # One non-causal config + "xlarge": RingAttnConfig(8, 32, 128, 32768, "bf16", "xlarge", causal=True), + "large_window": RingAttnConfig(4, 32, 128, 16384, "bf16", "large_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Realistic model configurations (kept minimal, most covered by medium/large configs) +MODEL_CONFIGS = { +} + +# GQA (Grouped Query Attention) configurations based on Qwen models +GQA_CONFIGS = { + # Qwen3-235B-A22B: 64 heads, 4 kv_heads, 128 head_dim + "qwen3_235b_a22b": RingAttnConfig( + batch_size=2, + num_heads=64, + head_dim=64, + max_seqlen=16384, + dtype="bf16", + name="qwen3_235b_a22b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-30B-A3B: 40 heads, 8 kv_heads, 128 head_dim + "qwen3_30b_a3b": RingAttnConfig( + batch_size=4, + num_heads=32, + head_dim=64, + max_seqlen=16384, + dtype="bf16", + name="qwen3_30b_a3b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-4B: 32 heads, 4 kv_heads, 80 head_dim + "qwen3_4b": RingAttnConfig( + batch_size=4, + num_heads=32, + head_dim=80, + max_seqlen=16384, + dtype="bf16", + name="qwen3_4b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-32B: 64 heads, 8 kv_heads, 128 head_dim + "qwen3_32b": RingAttnConfig( + batch_size=2, + num_heads=64, + head_dim=128, + max_seqlen=16384, + dtype="bf16", + name="qwen3_32b", + num_kv_heads=8, + causal=True + ), + + # Qwen3-14B: 40 heads, 8 kv_heads, 128 head_dim + "qwen3_14b": RingAttnConfig( + batch_size=4, + num_heads=40, + head_dim=128, + max_seqlen=16384, + dtype="bf16", + name="qwen3_14b", + num_kv_heads=8, + causal=True + ), +} + +# MQA is already covered by medium/large configs, so removed duplicate MQA_CONFIGS + +# Zigzag attention configurations (only supports causal=True and window_size=(-1, -1)) +ZIGZAG_CONFIGS = { + "zigzag_tiny": RingAttnConfig(2, 8, 64, 1024, "bf16", "zigzag_tiny", causal=True, window_size=(-1, -1)), + "zigzag_small": RingAttnConfig(4, 12, 128, 4096, "bf16", "zigzag_small", causal=True, window_size=(-1, -1)), + "zigzag_medium": RingAttnConfig(4, 24, 128, 8192, "bf16", "zigzag_medium", causal=True, window_size=(-1, -1)), + "zigzag_large": RingAttnConfig(4, 32, 128, 16384, "bf16", "zigzag_large", causal=True, window_size=(-1, -1)), + "zigzag_fp16": RingAttnConfig(4, 12, 128, 4096, "fp16", "zigzag_fp16", causal=True, window_size=(-1, -1)), + "zigzag_gqa": RingAttnConfig(4, 32, 128, 8192, "bf16", "zigzag_gqa", num_kv_heads=8, causal=True, window_size=(-1, -1)), +} + +# All configurations combined +ALL_CONFIGS = { + **SMALL_CONFIGS, + **MEDIUM_CONFIGS, + **LARGE_CONFIGS, + **MODEL_CONFIGS, + **GQA_CONFIGS, + **ZIGZAG_CONFIGS, +} + +# Default configurations for different test types +DEFAULT_CORRECTNESS_CONFIGS = ["tiny", "small", "medium"] +DEFAULT_PERFORMANCE_CONFIGS = ["medium", "large"] +DEFAULT_MULTI_GPU_CONFIGS = ["small", "medium"] +DEFAULT_GQA_CONFIGS = ["qwen3_4b", "qwen3_14b", "qwen3_32b"] +DEFAULT_ZIGZAG_CONFIGS = ["zigzag_tiny", "zigzag_small", "zigzag_medium"] + + +def get_config(name: str) -> RingAttnConfig: + """Get a configuration by name""" + if name in ALL_CONFIGS: + return ALL_CONFIGS[name] + else: + raise ValueError(f"Unknown configuration: {name}. Available: {list(ALL_CONFIGS.keys())}") + + +def list_configs(category: str = "all") -> List[str]: + """List available configurations by category""" + if category == "all": + return list(ALL_CONFIGS.keys()) + elif category == "small": + return list(SMALL_CONFIGS.keys()) + elif category == "medium": + return list(MEDIUM_CONFIGS.keys()) + elif category == "large": + return list(LARGE_CONFIGS.keys()) + elif category == "model": + return list(MODEL_CONFIGS.keys()) + elif category == "gqa": + return list(GQA_CONFIGS.keys()) + elif category == "zigzag": + return list(ZIGZAG_CONFIGS.keys()) + elif category == "correctness": + return DEFAULT_CORRECTNESS_CONFIGS + elif category == "performance": + return DEFAULT_PERFORMANCE_CONFIGS + elif category == "multi_gpu": + return DEFAULT_MULTI_GPU_CONFIGS + elif category == "gqa_default": + return DEFAULT_GQA_CONFIGS + elif category == "zigzag_default": + return DEFAULT_ZIGZAG_CONFIGS + else: + raise ValueError(f"Unknown category: {category}") + + +def get_configs_by_category(category: str) -> dict: + """Get all configurations in a category""" + config_names = list_configs(category) + return {name: get_config(name) for name in config_names} + + +def get_gqa_configs() -> dict: + """Get all GQA (Grouped Query Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if config.is_gqa and not config.is_mqa} + + +def get_mqa_configs() -> dict: + """Get all MQA (Multi-Query Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if config.is_mqa} + + +def get_mha_configs() -> dict: + """Get all MHA (Multi-Head Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if not config.is_gqa} + + +def get_zigzag_configs() -> dict: + """Get all Zigzag attention configurations""" + return ZIGZAG_CONFIGS + + +def filter_configs_by_attention_type(attention_type: str) -> dict: + """Filter configurations by attention type: 'mha', 'gqa', 'mqa', or 'zigzag'""" + if attention_type.lower() == "mha": + return get_mha_configs() + elif attention_type.lower() == "gqa": + return get_gqa_configs() + elif attention_type.lower() == "mqa": + return get_mqa_configs() # Will return empty dict since no dedicated MQA configs + elif attention_type.lower() == "zigzag": + return get_zigzag_configs() + else: + raise ValueError(f"Unknown attention type: {attention_type}. Supported: 'mha', 'gqa', 'mqa', 'zigzag'") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/ring_attn_runner.py b/tests/customized_ops/ring_attn/ring_attn_runner.py new file mode 100644 index 00000000..405c61b1 --- /dev/null +++ b/tests/customized_ops/ring_attn/ring_attn_runner.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Correctness Test Runner Script + +This script runs ring attention correctness tests in a distributed environment. +It compares the outputs of single-GPU and multi-GPU ring attention to ensure correctness. +""" + +import sys +from typing import Tuple +import torch + +from runner_base import RingAttnRunnerBase +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + + +class TestModule(torch.nn.Module): + """Test module for ring attention""" + def __init__(self, causal=True, window_size=(-1, -1)): + super(TestModule, self).__init__() + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v): + result = wrap_ring_attn_func( + q, k, v, + causal=self.causal, + window_size=self.window_size + ) + return result + + +class RingAttnRunner(RingAttnRunnerBase): + """Runner for ring attention tests""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn.wrap_ring_attn_func' + + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 1 + + @property + def function_name(self) -> str: + return 'wrap_ring_attn_func' + + def create_test_module(self, config) -> torch.nn.Module: + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare regular inputs with shape [batch_size, seq_len, num_heads, head_dim]""" + q = torch.clamp(torch.randn( + config.batch_size, + config.max_seqlen, + config.num_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ), min=-1, max=1) + + k = torch.clamp(torch.randn( + config.batch_size, + config.max_seqlen, + config.num_kv_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ), min=-1, max=1) + + v = torch.clamp(torch.randn( + config.batch_size, + config.max_seqlen, + config.num_kv_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ), min=-1, max=1) + + return {'q': q, 'k': k, 'v': v} + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + # Run single GPU version (this should call flash_attn internally when no process_group) + single_out = wrap_ring_attn_func( + inputs['q'], inputs['k'], inputs['v'], + causal=config.causal, + window_size=config.window_size + ) + return single_out, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization""" + return { + "q": inputs["q"], + "k": inputs["k"], + "v": inputs["v"], + } + + +def ring_attn_test(dtype="bf16", config_name="tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = RingAttnRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + +def run_correctness_test(**kwargs): + """Legacy function for backward compatibility""" + runner = RingAttnRunner() + runner.run_correctness_test(**kwargs) + + +if __name__ == "__main__": + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + runner = RingAttnRunner() + runner.main(**kwargs) \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py b/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py new file mode 100644 index 00000000..2ca40313 --- /dev/null +++ b/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Correctness Test Runner + +This script runs ring attention variable length correctness tests in a distributed environment. +It compares the outputs of single-GPU and multi-GPU ring attention to ensure correctness. +""" + +import sys +from typing import Tuple +import torch + +from runner_base import RingAttnRunnerBase +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func + + +class TestModule(torch.nn.Module): + def __init__(self, causal=True, window_size=(-1, -1)): + super(TestModule, self).__init__() + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v, cu_seqlens_q, cu_seqlens_k): + out = wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=self.causal, + window_size=self.window_size + ) + return out + + +class RingAttnVarlenRunner(RingAttnRunnerBase): + """Runner for ring attention variable length tests""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func' + + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 0 + + @property + def function_name(self) -> str: + return 'ring_attn_varlen_func' + + def create_test_module(self, config) -> torch.nn.Module: + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare variable length inputs with cu_seqlens""" + cu_seqlens_tensor = torch.tensor(config.cu_seqlens, dtype=torch.int32, device=device) + total_seqlen = config.cu_seqlens[-1] + + # Create inputs with total sequence length (don't set requires_grad here, base class handles it) + q = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + k = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + v = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + + return { + 'q': q, + 'k': k, + 'v': v, + 'cu_seqlens_q': cu_seqlens_tensor, + 'cu_seqlens_k': cu_seqlens_tensor + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + single_out = wrap_ring_attn_varlen_func( + inputs['q'], inputs['k'], inputs['v'], + inputs['cu_seqlens_q'], inputs['cu_seqlens_k'], None, + causal=config.causal, + window_size=config.window_size + ) + single_out.retain_grad() + return single_out, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization""" + return { + "q": inputs["q"], + "k": inputs["k"], + "v": inputs["v"], + 'cu_seqlens_q': inputs['cu_seqlens_q'], + 'cu_seqlens_k': inputs['cu_seqlens_k'] + } + + +def ring_attn_varlen_test(dtype="bf16", config_name="tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = RingAttnVarlenRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + +def run_ring_attn_correctness_test(**kwargs): + """Legacy function for backward compatibility""" + runner = RingAttnVarlenRunner() + runner.run_correctness_test(**kwargs) + + +if __name__ == "__main__": + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + runner = RingAttnVarlenRunner() + runner.main(**kwargs) diff --git a/tests/customized_ops/ring_attn/runner_base.py b/tests/customized_ops/ring_attn/runner_base.py new file mode 100644 index 00000000..e7d9a32f --- /dev/null +++ b/tests/customized_ops/ring_attn/runner_base.py @@ -0,0 +1,301 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base runner framework for ring attention correctness tests. +This module provides common functionality for both ring_attn and ring_attn_varlen test runners. +""" + +import os +import sys +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Tuple, Union + +import torch +import torch.distributed as dist +import nnscaler +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType + +from nnscaler.customized_ops.ring_attention.core.utils import set_seed, log +from configs import get_config + + +class RingAttnRunnerBase(ABC): + """Base class for ring attention test runners""" + + @property + @abstractmethod + def function_signature(self) -> str: + """Return the function signature to look for in the graph""" + pass + + @property + @abstractmethod + def partition_position(self) -> Tuple[int, int]: + """Return the partition position (idx, dim)""" + pass + + @property + @abstractmethod + def function_name(self) -> str: + """Return the function name for partitioning""" + pass + + @abstractmethod + def create_test_module(self, config) -> torch.nn.Module: + """Create the test module with the appropriate configuration""" + pass + + @abstractmethod + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors based on the configuration and attention type""" + pass + + @abstractmethod + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + pass + + @abstractmethod + def get_dummy_forward_args(self, inputs) -> Dict[str, Any]: + """Get dummy forward arguments for model parallelization""" + pass + + def create_policy(self) -> callable: + """Create partitioning policy for the specific attention type""" + def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: + ngpus = resource.plan_ngpus + partitioned = False + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == self.function_signature: + print(f'\nPartitioned node: {node}\n') + idx, dim = self.partition_position + sub_nodes = graph.partition(node, node.algorithm('dim'), idx=idx, dim=dim, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + if not partitioned: + signatures = [node.signature for node in graph.select(ntype=IRFwOperation)] + raise RuntimeError(f"Failed to find the target function '{self.function_signature}' in {signatures}") + return graph + return policy + + def initialize_distributed(self): + """Initialize distributed environment""" + # Check CUDA availability first + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available") + sys.exit(1) + + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + # Check if we have enough GPUs + available_gpus = torch.cuda.device_count() + if available_gpus < world_size: + print(f"ERROR: Test requires {world_size} GPUs, but only {available_gpus} available") + sys.exit(1) + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + device_count = torch.cuda.device_count() + device = rank % device_count + try: + torch.cuda.set_device(device) + except Exception as e: + print(f"ERROR: Failed to set CUDA device {device}: {e}") + sys.exit(1) + + print(f"[INFO] world_size:{world_size}, rank:{rank}, available_gpus:{available_gpus}") + + try: + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + except Exception as e: + print(f"ERROR: Failed to initialize process group: {e}") + sys.exit(1) + + # Initialize nnscaler + nnscaler.init() + return world_size, rank + + def get_tolerances(self, dtype: str, num_heads: int, num_kv_heads: int) -> Dict[str, float]: + """Get tolerance values based on data type""" + if dtype == "bf16": + if num_heads == num_kv_heads: + return dict(atol=2.5e-2, rtol=2.5e-2) + else: + return dict(atol=3.5e-2, rtol=3.5e-2) + elif dtype == "fp16": + return dict(atol=5e-3, rtol=5e-3) + else: + return dict(atol=2.5e-2, rtol=2.5e-2) + + def print_debug_info(self, single_out, para_out, single_grads, para_grads, rank_id): + """Print debug information when correctness test fails""" + if rank_id == 0: + print("✗ Correctness test FAILED!") + # Print detailed error information + log("single out", single_out, rank0_only=True) + log("multi out", para_out, rank0_only=True) + log("out diff", single_out - para_out, rank0_only=True) + + for i, (single_grad, para_grad, name) in enumerate(zip(single_grads, para_grads, ['q', 'k', 'v'])): + log(f"single d{name}", single_grad, rank0_only=True) + log(f"multi d{name}", para_grad, rank0_only=True) + log(f"d{name} diff", single_grad - para_grad, rank0_only=True) + + def print_success_info(self, rank_id, config_name=None): + """Print success information""" + if rank_id == 0: + config_suffix = f" for config '{config_name}'" if config_name else "" + print(f"✓ Correctness test PASSED{config_suffix}!") + + def run_correctness_test(self, config_name: str, dtype: str = "bf16", **kwargs): + """Run correctness test with the specific attention implementation""" + # Initialize distributed + world_size, rank = self.initialize_distributed() + rank_id = torch.distributed.get_rank() + + # Get configuration + config = get_config(config_name) + torch_dtype = torch.bfloat16 if dtype == "bf16" else torch.float16 + + if rank_id == 0: + print(f"Testing {self.function_name} correctness") + print(f"Configuration: {config.name}") + print(f" Batch size: {config.batch_size}") + print(f" Sequence length: {config.max_seqlen}") + print(f" Num heads: {config.num_heads}") + print(f" KV heads: {config.num_kv_heads}") + print(f" Head dim: {config.head_dim}") + print(f" Data type: {dtype}") + print(f" World size: {world_size}") + print("=" * 60) + + # Set seed for reproducibility + set_seed(42 + rank_id) + device = torch.device(f"cuda:{rank_id}") + + # Prepare inputs (implementation-specific) + inputs = self.prepare_inputs(config, device, torch_dtype) + + # Broadcast inputs to ensure consistency across ranks + for tensor in inputs.values(): + if isinstance(tensor, torch.Tensor): + dist.broadcast(tensor, src=0) + dist.barrier() + + # Setup models + model = self.create_test_module(config) + + # Create parallel model + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + + parallel_model = parallelize( + model, + dummy_forward_args=self.get_dummy_forward_args(dummy_args), + pas_policy=self.create_policy(), + compute_config=ComputeConfig(world_size, world_size), + reuse=ReuseType.OVERRIDE + ) + parallel_model = parallel_model.cuda() + parallel_model.train() + + # Run correctness test + print("Running correctness test..." if rank_id == 0 else "", end="") + + # Single mode for reference + single_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + single_inputs[k] = v.detach().clone().requires_grad_() + else: + single_inputs[k] = v.detach().clone() + else: + single_inputs[k] = v + + single_out, single_grad_tensors = self.run_single_gpu_reference(single_inputs, config) + + # Create gradient for backward pass + dout = torch.clamp(torch.randn_like(single_out, device=device, dtype=torch_dtype), min=-1, max=1) + # Ensure dout is consistent across all ranks + dist.broadcast(dout, src=0) + single_out.backward(dout) + + # Extract single gradients + single_grads = [tensor.grad for tensor in single_grad_tensors] + + # Parallel mode for correctness + para_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + para_inputs[k] = v.detach().clone().requires_grad_() + else: + para_inputs[k] = v.detach().clone() + else: + para_inputs[k] = v + + para_out = parallel_model(**para_inputs) + para_out.backward(dout) + parallel_model.sync_grad() + + # Extract gradients for q, k, v tensors + para_grads = [para_inputs[k].grad for k in ['q', 'k', 'v']] + + print(" Done!" if rank_id == 0 else "") + + # Check correctness with tolerances + tols = self.get_tolerances(dtype, config.num_heads, config.num_kv_heads) + + # Verify outputs and gradients + try: + torch.testing.assert_close(single_out, para_out, **tols) + for single_grad, para_grad in zip(single_grads, para_grads): + torch.testing.assert_close(single_grad, para_grad, **tols) + + self.print_success_info(rank_id, config_name) + + except AssertionError as e: + self.print_debug_info(single_out, para_out, single_grads, para_grads, rank_id) + raise e + + dist.destroy_process_group() + + def main(self, **kwargs): + """Main entry point for the test runner""" + # Filter out torch.distributed.launch arguments + filtered_kwargs = {} + for k, v in kwargs.items(): + if k.startswith('--'): + # Remove leading '--' from argument names + k = k[2:].replace('-', '_') + if k not in ['local_rank', 'local-rank']: # Filter out torch.distributed.launch args + filtered_kwargs[k] = v + + # Convert string arguments back to appropriate types + for numeric_arg in ['batch_size', 'num_heads', 'head_dim', 'max_seqlen']: + if numeric_arg in filtered_kwargs and filtered_kwargs[numeric_arg] is not None: + filtered_kwargs[numeric_arg] = int(filtered_kwargs[numeric_arg]) + + for float_arg in ['rtol', 'atol']: + if float_arg in filtered_kwargs and filtered_kwargs[float_arg] is not None: + filtered_kwargs[float_arg] = float(filtered_kwargs[float_arg]) + + self.run_correctness_test(**filtered_kwargs) \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_base.py b/tests/customized_ops/ring_attn/test_base.py new file mode 100644 index 00000000..44870792 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_base.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base test framework for ring attention tests. +This module provides common functionality for both ring_attn and ring_attn_varlen tests. +""" + +import os +import sys +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Tuple +from functools import partial + +import pytest +import torch + +from .configs import ( + DEFAULT_CORRECTNESS_CONFIGS, + DEFAULT_MULTI_GPU_CONFIGS, + DEFAULT_GQA_CONFIGS, + get_config, + list_configs +) + +from ...launch_torchrun import torchrun + + +class RingAttnTestBase(ABC): + """Base class for ring attention tests""" + + @property + @abstractmethod + def runner_script_name(self) -> str: + """Return the name of the runner script (e.g., 'run_correctness.py')""" + pass + + @property + @abstractmethod + def test_name_prefix(self) -> str: + """Return the prefix for test names (e.g., 'ring_attn' or 'ring_attn_varlen')""" + pass + + @property + @abstractmethod + def test_function_name(self) -> str: + """Return the name of the test function to import (e.g., 'zigzag_attn_test')""" + pass + + def _check_gpu_availability(self, required_gpus: int): + """Check if enough GPUs are available and skip test if not""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + available_gpus = torch.cuda.device_count() + if available_gpus < required_gpus: + pytest.skip(f"Test requires {required_gpus} GPUs, but only {available_gpus} available") + + def _get_project_root(self): + """Get the absolute path to nnscaler root directory""" + current_dir = os.path.dirname(__file__) # tests/customized_ops/ring_attn/ + return os.path.abspath(os.path.join(current_dir, "../../../")) + + def get_bash_arguments(self, num_gpus_per_node: int, **kwargs) -> List[str]: + """Generate command line arguments for running the test script + + Deprecated: This method is kept for backward compatibility. + The new implementation uses launch_torchrun directly. + """ + args = [ + "python3", + "-m", + "torch.distributed.launch", + "--nproc-per-node=" + str(num_gpus_per_node), + ] + + project_root = self._get_project_root() + script_path = os.path.join( + project_root, "tests", "customized_ops", "ring_attn", + self.runner_script_name + ) + args.append(script_path) + + for k, v in kwargs.items(): + args.append(f"{k}={v}") + return args + + def _get_test_function(self): + """Get the test function for this test""" + # Add the script directory to sys.path to allow imports + project_root = self._get_project_root() + script_dir = os.path.join(project_root, "tests", "customized_ops", "ring_attn") + if script_dir not in sys.path: + sys.path.insert(0, script_dir) + + # Import the module and get the test function + module_name = self.runner_script_name.replace('.py', '') + module = __import__(module_name) + + if hasattr(module, self.test_function_name): + return getattr(module, self.test_function_name) + else: + raise ImportError(f"Could not find function '{self.test_function_name}' in {module_name}") + + def run_test_subprocess(self, num_gpus: int, **kwargs): + """Run test using torchrun with the configured test function""" + # Check GPU availability before running test + self._check_gpu_availability(num_gpus) + + # Get the test function and use torchrun to execute it + test_function = self._get_test_function() + + # Extract common parameters + dtype = kwargs.get('dtype', 'bf16') + config_name = kwargs.get('config_name', 'tiny') + + # Use partial with positional arguments like test_gnorm.py + return partial(torchrun, num_gpus, test_function, dtype, config_name)() + + # Common test methods that can be used by both ring_attn and ring_attn_varlen + + def run_correctness_basic(self, dtype: str, config_name: str): + """Test correctness with different configurations""" + num_gpus = 2 # Default to 2 GPUs for correctness tests + config = get_config(config_name) + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_multi_gpu_scaling(self, num_gpus: int, config_name: str): + """Test with different numbers of GPUs""" + self.run_test_subprocess( + num_gpus=num_gpus, + dtype="bf16", + config_name=config_name, + ) + + def run_comprehensive_configs(self, dtype: str): + """Test all available configurations (comprehensive test)""" + num_gpus = 2 + + # Test a selection of configurations + test_configs = ["tiny", "small", "medium"] + + for config_name in test_configs: + config = get_config(config_name) + # Skip very large configs for comprehensive test + if config.max_seqlen > 16384: + continue + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_gqa_correctness(self, dtype: str, config_name: str): + """Test GQA correctness with Qwen model configurations""" + num_gpus = 2 + config = get_config(config_name) + + # Ensure it's actually a GQA config + assert config.is_gqa, f"Configuration {config_name} should be GQA" + assert config.num_kv_heads < config.num_heads, f"Configuration {config_name} should have fewer KV heads" + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_sliding_window(self, dtype: str, config_name: str): + """Test with sliding window configurations""" + num_gpus = 2 + config = get_config(config_name) + + # Ensure it's actually a sliding window config + assert config.window_size != (-1, -1), f"Configuration {config_name} should have sliding window" + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + +def create_parametrized_tests(test_class: RingAttnTestBase): + """ + Factory function to create parametrized test methods for a test class. + This reduces code duplication between ring_attn and ring_attn_varlen tests. + """ + + # Correctness tests with different dtypes and configs + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + @pytest.mark.parametrize("config_name", DEFAULT_CORRECTNESS_CONFIGS) + def test_correctness(dtype, config_name): + """Test correctness with different configurations""" + instance = test_class() + instance.run_correctness_basic(dtype, config_name) + + # Multi-GPU tests + @pytest.mark.parametrize("num_gpus", [2, 4]) + @pytest.mark.parametrize("config_name", DEFAULT_MULTI_GPU_CONFIGS) + def test_multi_gpu(num_gpus, config_name): + """Test with different numbers of GPUs""" + instance = test_class() + instance.run_multi_gpu_scaling(num_gpus, config_name) + + # Comprehensive tests + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_all_configs(dtype): + """Test all available configurations (comprehensive test)""" + instance = test_class() + instance.run_comprehensive_configs(dtype) + + # GQA tests + @pytest.mark.parametrize("dtype", ["bf16"]) + @pytest.mark.parametrize("config_name", DEFAULT_GQA_CONFIGS) + def test_gqa_correctness(dtype, config_name): + """Test GQA correctness with Qwen model configurations""" + instance = test_class() + instance.run_gqa_correctness(dtype, config_name) + + # Sliding window tests + @pytest.mark.parametrize("dtype", ["bf16"]) + @pytest.mark.parametrize("config_name", ["small_window", "medium_window"]) + def test_sliding_window(dtype, config_name): + """Test with sliding window configurations""" + instance = test_class() + instance.run_sliding_window(dtype, config_name) + + return { + f'test_{test_class().test_name_prefix}_correctness': test_correctness, + f'test_{test_class().test_name_prefix}_multi_gpu': test_multi_gpu, + f'test_{test_class().test_name_prefix}_all_configs': test_all_configs, + f'test_{test_class().test_name_prefix}_gqa_correctness': test_gqa_correctness, + f'test_{test_class().test_name_prefix}_sliding_window': test_sliding_window, + } \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_ring_attn.py b/tests/customized_ops/ring_attn/test_ring_attn.py new file mode 100644 index 00000000..e1378101 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_ring_attn.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Correctness Tests + +This module tests the correctness of regular ring attention (non-variable length). +It uses the shared test base framework to avoid code duplication. +""" + +import pytest +import torch + +# Skip all tests if flash_attn_func is not available +try: + from flash_attn import flash_attn_func +except ImportError: + pytest.skip("flash_attn_func not available", allow_module_level=True) + +from .test_base import RingAttnTestBase, create_parametrized_tests +from .configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS + + +class RingAttnTest(RingAttnTestBase): + """Test class for regular ring attention""" + + @property + def runner_script_name(self) -> str: + return "ring_attn_runner.py" + + @property + def test_function_name(self) -> str: + return "ring_attn_test" + + @property + def test_name_prefix(self) -> str: + return "ring_attn" + + +# Create parametrized test functions using the factory +test_functions = create_parametrized_tests(RingAttnTest) + +# Assign test functions to module globals for pytest discovery +test_ring_attn_correctness = test_functions['test_ring_attn_correctness'] +test_ring_attn_multi_gpu = test_functions['test_ring_attn_multi_gpu'] +test_ring_attn_all_configs = test_functions['test_ring_attn_all_configs'] +test_ring_attn_gqa_correctness = test_functions['test_ring_attn_gqa_correctness'] +test_ring_attn_sliding_window = test_functions['test_ring_attn_sliding_window'] + + +if __name__ == "__main__": + # Run specific test if called directly + test_instance = RingAttnTest() + test_instance.run_correctness_basic("bf16", "small") + + # Example of running GQA test + # test_instance.run_gqa_correctness("bf16", "qwen3_4b") + + # Example of running sliding window test + # test_instance.run_sliding_window("bf16", "small_window") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_ring_attn_varlen.py b/tests/customized_ops/ring_attn/test_ring_attn_varlen.py new file mode 100644 index 00000000..f86fc276 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_ring_attn_varlen.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Correctness Tests + +This module tests the correctness of ring attention with variable length sequences. +It uses the shared test base framework to avoid code duplication. +""" + +import pytest +import torch + +# Skip all tests if flash_attn_varlen_func is not available +try: + from flash_attn import flash_attn_varlen_func +except ImportError: + pytest.skip("flash_attn_varlen_func not available", allow_module_level=True) + +from .test_base import RingAttnTestBase, create_parametrized_tests +from .configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS + + +class RingAttnVarlenTest(RingAttnTestBase): + """Test class for ring attention variable length""" + + @property + def runner_script_name(self) -> str: + return "ring_attn_varlen_runner.py" + + @property + def test_function_name(self) -> str: + return "ring_attn_varlen_test" + + @property + def test_name_prefix(self) -> str: + return "ring_attn_varlen" + + +# Create parametrized test functions using the factory +test_functions = create_parametrized_tests(RingAttnVarlenTest) + +# Assign test functions to module globals for pytest discovery +test_ring_attn_varlen_correctness = test_functions['test_ring_attn_varlen_correctness'] +test_ring_attn_varlen_multi_gpu = test_functions['test_ring_attn_varlen_multi_gpu'] +test_ring_attn_varlen_all_configs = test_functions['test_ring_attn_varlen_all_configs'] +test_ring_attn_varlen_gqa_correctness = test_functions['test_ring_attn_varlen_gqa_correctness'] +test_ring_attn_varlen_sliding_window = test_functions['test_ring_attn_varlen_sliding_window'] + + +if __name__ == "__main__": + # Run specific test if called directly + test_instance = RingAttnVarlenTest() + test_instance.run_correctness_basic("bf16", "small") + + # Example of running GQA test + # test_instance.run_gqa_correctness("bf16", "qwen3_4b") diff --git a/tests/customized_ops/ring_attn/test_shuffle_varlen.py b/tests/customized_ops/ring_attn/test_shuffle_varlen.py new file mode 100644 index 00000000..8f7271a4 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_shuffle_varlen.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Simple test for shuffle_varlen and unshuffle_varlen functions. +""" + +import pytest +import torch +import torch.distributed as dist +from dataclasses import dataclass +from typing import List +from functools import partial + +from tests.launch_torchrun import torchrun + + +@dataclass +class ShuffleVarlenConfig: + """Simple test configuration""" + name: str + batch_size: int + seq_lens: List[int] + hidden_dim: int + + +# Test configurations +CONFIGS = { + "tiny": ShuffleVarlenConfig("tiny", 2, [512, 768], 64), + "small": ShuffleVarlenConfig("small", 2, [1024, 1536], 128), + "medium": ShuffleVarlenConfig("medium", 2, [1024, 1536], 256), + "uneven": ShuffleVarlenConfig("uneven", 3, [256, 768, 1024], 128), +} + + +def shuffle_varlen_test(config_name="tiny", dtype="float32", world_size=2): + """Test shuffle_varlen and unshuffle_varlen functions""" + + if not dist.is_initialized(): + dist.init_process_group(backend='nccl') + + rank = dist.get_rank() + world_size_actual = dist.get_world_size() + device = torch.device(f'cuda:{rank}') + torch.cuda.set_device(device) + + if rank == 0: + print(f"Testing shuffle_varlen and unshuffle_varlen functions") + print(f"Configuration: {config_name}") + print(f"World size: {world_size_actual}") + print(f"Data type: {dtype}") + print("=" * 60) + + # Get configuration + config = CONFIGS[config_name] + + # Set up process group for context parallel + cp_ranks = list(range(world_size_actual)) + cp_group = dist.new_group(cp_ranks) + + # Create cumulative sequence lengths (padded to be divisible by 2*world_size) + cu_seqlens = torch.zeros(config.batch_size + 1, dtype=torch.int32, device=device) + total_slices_per_seq = 2 * world_size_actual + + for i, seq_len in enumerate(config.seq_lens): + # Pad sequence length to be divisible by total_slices_per_seq + padded_seq_len = ((seq_len + total_slices_per_seq - 1) // total_slices_per_seq) * total_slices_per_seq + cu_seqlens[i + 1] = cu_seqlens[i] + padded_seq_len + + total_seq_len = cu_seqlens[len(config.seq_lens)].item() # Use len(config.seq_lens) instead of -1 + + # Convert dtype string to torch dtype + torch_dtype = getattr(torch, dtype) + + # Import functions from varlen_utils + from nnscaler.customized_ops.ring_attention.varlen_utils import shuffle_varlen, unshuffle_varlen + + if rank == 0: + print("Running shuffle/unshuffle correctness tests...") + + tolerance = 1e-5 if torch_dtype == torch.float32 else 1e-2 + + # Test 1: 1D tensor (like position_ids) + if rank == 0: + print(" Test: 1D tensor (total_seq_len,)...") + + try: + # Create full tensor first (on rank 0) + if rank == 0: + full_tensor_1d = torch.arange(total_seq_len, dtype=torch_dtype, device=device) + else: + full_tensor_1d = torch.empty(total_seq_len, dtype=torch_dtype, device=device) + + # Broadcast full tensor to all ranks for reference + dist.broadcast(full_tensor_1d, src=0, group=cp_group) + + # Split tensor for local input (each rank gets a chunk) + chunk_size = total_seq_len // world_size_actual + start_idx = rank * chunk_size + end_idx = start_idx + chunk_size if rank < world_size_actual - 1 else total_seq_len + local_tensor_1d = full_tensor_1d[start_idx:end_idx].clone() + + # Test shuffle -> unshuffle + shuffled = shuffle_varlen(local_tensor_1d, cu_seqlens, cp_ranks, cp_group) + unshuffled = unshuffle_varlen(shuffled, cu_seqlens, cp_ranks, cp_group) + + # Compare with original local chunk + if torch.allclose(local_tensor_1d, unshuffled, atol=tolerance): + if rank == 0: + print(" ✓ 1D tensor test passed") + else: + if rank == 0: + print(" ✗ 1D tensor test FAILED") + raise AssertionError("1D tensor test failed") + + except Exception as e: + if rank == 0: + print(f" ✗ 1D tensor test FAILED with error: {e}") + raise e + + # Test 2: 2D tensor (total_seq_len, hidden_dim) + if rank == 0: + print(" Test: 2D tensor (total_seq_len, hidden_dim)...") + + try: + # Create full tensor first (on rank 0) + if rank == 0: + full_tensor_2d = torch.randn(total_seq_len, config.hidden_dim, dtype=torch_dtype, device=device) + else: + full_tensor_2d = torch.empty(total_seq_len, config.hidden_dim, dtype=torch_dtype, device=device) + + # Broadcast full tensor to all ranks for reference + dist.broadcast(full_tensor_2d, src=0, group=cp_group) + + # Split tensor for local input (each rank gets a chunk) + chunk_size = total_seq_len // world_size_actual + start_idx = rank * chunk_size + end_idx = start_idx + chunk_size if rank < world_size_actual - 1 else total_seq_len + local_tensor_2d = full_tensor_2d[start_idx:end_idx].clone() + + # Test shuffle -> unshuffle + shuffled = shuffle_varlen(local_tensor_2d, cu_seqlens, cp_ranks, cp_group) + unshuffled = unshuffle_varlen(shuffled, cu_seqlens, cp_ranks, cp_group) + + # Compare with original local chunk + if torch.allclose(local_tensor_2d, unshuffled, atol=tolerance): + if rank == 0: + print(" ✓ 2D tensor test passed") + else: + if rank == 0: + print(" ✗ 2D tensor test FAILED") + raise AssertionError("2D tensor test failed") + + except Exception as e: + if rank == 0: + print(f" ✗ 2D tensor test FAILED with error: {e}") + raise e + + dist.barrier() + + if rank == 0: + print("✓ All shuffle/unshuffle tests PASSED!") + + dist.destroy_process_group() + + +class TestShuffleVarlen: + """Simple test class for shuffle/unshuffle varlen""" + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_tiny(self, dtype): + """Test shuffle/unshuffle varlen with tiny configuration""" + partial(torchrun, 2, shuffle_varlen_test, "tiny", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_small(self, dtype): + """Test shuffle/unshuffle varlen with small configuration""" + partial(torchrun, 2, shuffle_varlen_test, "small", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_medium(self, dtype): + """Test shuffle/unshuffle varlen with medium configuration""" + partial(torchrun, 2, shuffle_varlen_test, "medium", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_uneven(self, dtype): + """Test shuffle/unshuffle varlen with uneven sequence lengths""" + partial(torchrun, 2, shuffle_varlen_test, "uneven", dtype)() + + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_shuffle_varlen_multi_gpu(self, num_gpus): + """Test shuffle/unshuffle varlen on multiple GPUs""" + partial(torchrun, num_gpus, shuffle_varlen_test, "tiny", "float32")() + + +# Standalone test functions for pytest discovery +@pytest.mark.parametrize("config,dtype", [ + ("tiny", "float32"), ("tiny", "float16"), + ("small", "float32"), ("small", "float16"), + ("uneven", "float32"), ("uneven", "float16"), +]) +def test_shuffle_varlen_correctness(config, dtype): + """Test shuffle/unshuffle varlen correctness""" + partial(torchrun, 2, shuffle_varlen_test, config, dtype)() + + +@pytest.mark.parametrize("config,num_gpus", [ + ("tiny", 2), ("tiny", 4), + ("small", 2), ("small", 4), +]) +def test_shuffle_varlen_multi_gpu(config, num_gpus): + """Test shuffle/unshuffle varlen on multiple GPUs""" + partial(torchrun, num_gpus, shuffle_varlen_test, config, "float32")() \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_zigzag_attn.py b/tests/customized_ops/ring_attn/test_zigzag_attn.py new file mode 100644 index 00000000..3bca5792 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_zigzag_attn.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag attention correctness tests. + +This module contains correctness tests for the zigzag attention implementation. +Note: Zigzag attention only supports causal=True and window_size=(-1, -1). + +Usage: + python -m pytest test_zigzag_attn.py -v + python -m pytest test_zigzag_attn.py::TestZigzagAttn::test_zigzag_attn_tiny_bf16 -v +""" + +import pytest + +# Skip all tests if flash_attn_func is not available +try: + from flash_attn import flash_attn_func +except ImportError: + pytest.skip("flash_attn_func not available", allow_module_level=True) + +from .test_base import RingAttnTestBase + + +class TestZigzagAttn(RingAttnTestBase): + """Test class for zigzag attention correctness testing""" + + @property + def runner_script_name(self) -> str: + return "zigzag_attn_runner.py" + + @property + def test_function_name(self) -> str: + return "zigzag_attn_test" + + @property + def test_name_prefix(self) -> str: + return "zigzag_attn" + + # Basic correctness tests + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_zigzag_attn_tiny(self, dtype): + """Test zigzag attention with tiny configuration""" + self.run_correctness_basic(dtype, "zigzag_tiny") + + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_zigzag_attn_small(self, dtype): + """Test zigzag attention with small configuration""" + self.run_correctness_basic(dtype, "zigzag_small") + + @pytest.mark.parametrize("dtype", ["bf16"]) + def test_zigzag_attn_medium(self, dtype): + """Test zigzag attention with medium configuration""" + self.run_correctness_basic(dtype, "zigzag_medium") + + # Multi-GPU tests + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_zigzag_attn_multi_gpu_small(self, num_gpus): + """Test zigzag attention with small config on multiple GPUs""" + self.run_multi_gpu_scaling(num_gpus, "zigzag_small") + + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_zigzag_attn_multi_gpu_medium(self, num_gpus): + """Test zigzag attention with medium config on multiple GPUs""" + self.run_multi_gpu_scaling(num_gpus, "zigzag_medium") + + # GQA test + def test_zigzag_attn_gqa(self): + """Test zigzag attention with GQA configuration""" + self.run_gqa_correctness("bf16", "zigzag_gqa") + + +if __name__ == "__main__": + # For direct execution, run a simple test + test_instance = TestZigzagAttn() + test_instance.run_correctness_basic("bf16", "zigzag_tiny") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/zigzag_attn_runner.py b/tests/customized_ops/ring_attn/zigzag_attn_runner.py new file mode 100644 index 00000000..6e557e2a --- /dev/null +++ b/tests/customized_ops/ring_attn/zigzag_attn_runner.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag attention test runner implementation. +This module provides the specific runner for testing zigzag attention. +Note: Zigzag attention only supports causal=True and window_size=(-1, -1). +""" + +import os +import sys +from typing import Dict, Any, Tuple + +import torch +import torch.nn as nn + +from nnscaler.customized_ops.ring_attention.zigzag_attn import wrap_zigzag_attn_func +from runner_base import RingAttnRunnerBase + + +class ZigzagAttnRunner(RingAttnRunnerBase): + """Zigzag attention test runner""" + + @property + def function_signature(self) -> str: + return "nnscaler.customized_ops.ring_attention.zigzag_attn.wrap_zigzag_attn_func" + + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 1 + + @property + def function_name(self) -> str: + return "wrap_zigzag_attn_func" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for zigzag attention""" + class TestModule(nn.Module): + def __init__(self, causal=True, window_size=(-1, -1)): + super().__init__() + # Zigzag attention only supports causal=True and window_size=(-1, -1) + assert causal is True, "Zigzag attention only supports causal=True" + assert window_size == (-1, -1), "Zigzag attention only supports window_size=(-1, -1)" + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v): + # Note: zigzag_attn always uses causal=True and window_size=(-1, -1) + return wrap_zigzag_attn_func(q, k, v, causal=self.causal, window_size=self.window_size) + + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare inputs for zigzag attention""" + batch_size = config.batch_size + max_seqlen = config.max_seqlen + num_heads = config.num_heads + num_kv_heads = config.num_kv_heads + head_dim = config.head_dim + + # Create input tensors + q = torch.clamp(torch.randn(batch_size, max_seqlen, num_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + k = torch.clamp(torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + v = torch.clamp(torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + # Note: zigzag_attn always uses causal=True and window_size=(-1, -1) + output = wrap_zigzag_attn_func( + inputs['q'], inputs['k'], inputs['v'], + causal=config.causal, window_size=config.window_size) + output.retain_grad() + + return output, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs) -> Dict[str, Any]: + """Get dummy forward arguments for model parallelization""" + return { + 'q': inputs['q'], + 'k': inputs['k'], + 'v': inputs['v'] + } + + +def zigzag_attn_test(dtype="bf16", config_name="zigzag_tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = ZigzagAttnRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + +def main(): + """Main entry point for command line execution""" + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + + runner = ZigzagAttnRunner() + runner.main(**kwargs) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 319e3b6f..e6a12c55 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -7,6 +7,7 @@ from operator import add from nnscaler.graph.function.dimops import IRDimops, OpAnno import nnscaler.graph.function.function as F +from nnscaler.graph.parser.value_tracker import ValueTracker from nnscaler.ir.cten import IR, IRObject, IRTensor import pytest @@ -47,6 +48,21 @@ def test_Full(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 1' +def test_Randn(): + op = F.Randn(IRObject(value=[2, 3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 2 3 4' + + for dim_track in op.output(0).dim_tracks: + assert dim_track.deps == [op.kwargs['size'].value_track.value_id] + + op = F.Randn(2, IRObject(value=3), IRObject(value=4)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 2 3 4' + + assert op.output(0).dim_tracks[0].deps == [] + assert op.output(0).dim_tracks[1].deps == [op.kwargs['size'][1].value_track.value_id] + assert op.output(0).dim_tracks[2].deps == [op.kwargs['size'][2].value_track.value_id] + + def test_Expand(): inp = IRTensor([10, 1]) out = IRTensor([10, 2]) @@ -1147,3 +1163,43 @@ def test_dict_keys_values_items(): # key will never be wrapped with IRObject # IRFullTensor will be reconstructed, so their ids are different assert all(x[0] == y[0] and x[1].shape == y[1].shape and x[1] != y[1] for x, y in zip(r.output(0), d.items())) + +def test_Stack(): + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=0) + expected_annotation = 'a b, a b, a b -> 3 a b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=1) + expected_annotation = 'a b, a b, a b -> a 3 b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=2) + expected_annotation = 'a b, a b, a b -> a b 3' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + + op = F.Stack([IRTensor([]), IRTensor([]), IRTensor([])], dim=0) + expected_annotation = '1, 1, 1 -> 3' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + + +def test_Dot(): + op = F.Dot(IRTensor([4]), IRTensor([4])) + expected_annotation = 'k+, k+ -> 1' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Dot." + + +def test_chunk(): + op = F.Chunk(IRTensor([8, 10]), chunks=4, dim=0) + expected_annotation = '8 b -> 2 b, 2 b, 2 b, 2 b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation + value_tracker = ValueTracker() + value_tracker.track_nodes([op]) + value_tracker.complete_tracking([op]) + input_dim_tracks = op.input(0).dim_tracks + output_dim_tracks = [out.dim_tracks for out in op.outputs()] + # all dim 1 tracks should be the same + assert output_dim_tracks[0][1] is input_dim_tracks[1] + # output dim 0 tracks should depend on input dim 0 track + assert output_dim_tracks[0][0].deps == [input_dim_tracks[0].value_id] + for output_dim_track in output_dim_tracks[1:]: + assert output_dim_track[0] is output_dim_tracks[0][0] + assert output_dim_track[1] is output_dim_tracks[0][1] + assert True diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 51330fd9..548e9024 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -164,12 +164,12 @@ def f(self) -> None: assert modified assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' def f(func_name, type: int, /, *args, **kwargs): - return patched_run(func_name, type, *args, **kwargs) + return patched_run(func_name, 'func_name(type, *args, **kwargs)', type, *args, **kwargs) def g(): - return patched_run(x + y, a, b) + return patched_run(x + y, '(x + y)(a, b)', a, b) class A: def f(self) -> None: - patched_run(patched_run(super).f) + patched_run(patched_run(super, 'super()').f, 'super().f()') ''').strip() @@ -188,10 +188,10 @@ def __init__(self) -> None: modified, new_ast = transform(tree, transfomers) assert modified assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' - x = patched_run(not_, True) + x = patched_run(not_, 'not_(True)', True) def f(func_name, type: int, /, *args, **kwargs): - return patched_run(func_name, type, *args, **kwargs) + return patched_run(func_name, 'func_name(type, *args, **kwargs)', type, *args, **kwargs) class A: def __init__(self) -> None: - patched_run(super(self.__class__, self).__init__) + patched_run(super(self.__class__, self).__init__, 'super(self.__class__, self).__init__()') ''').strip() diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 20f2bcff..aed04fe6 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -44,7 +44,6 @@ def forward(self, x, **kwargs): assert any(node.op == 'call_function' and node.target == torch.nn.functional.linear for node in nodes) with tempfile.TemporaryDirectory() as tempdir: - to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) assert ir_graph is not None assert (Path(tempdir) / FxModuleParser.ATTR_MAP_FILE).exists() @@ -52,13 +51,45 @@ def forward(self, x, **kwargs): assert ir_graph.name == 'MyModule' inputs = ir_graph.inputs() assert len(inputs) == 2 - assert inputs[0].name == nodes[0].name + assert inputs[0].name == nodes[0].target assert isinstance(inputs[0], IRTensor) - assert inputs[1].name == nodes[1].name + assert inputs[0].value_track.deps == None + # inputs has no dependency + assert all(dt.deps == [] for dt in inputs[0].dim_tracks) + assert inputs[1].name == nodes[1].target assert isinstance(inputs[1], IRObject) + assert inputs[1].value_track.deps == [] + + assert len(ir_graph.nodes()) == 1 + linear_node = ir_graph.nodes()[0] + assert len(linear_node.inputs()) == 3 # x, weight, bias + + assert all(isinstance(i, IRTensor) for i in linear_node.inputs()) + # from its annotation, a k^, n k^, n -> a n + # we can check the value_track and dim_track dependencies + + # the same with graph inputs + assert all(linear_node.input(0).dim_tracks[i] is inputs[0].dim_tracks[i] for i in range(len(inputs[0].dim_tracks))) + # weights has no dependency + assert linear_node.input(1).dim_tracks[0].deps == [] + # the `k` dimension + assert linear_node.input(1).dim_tracks[1] is inputs[0].dim_tracks[1] + # the `n` dimension + assert linear_node.input(2).dim_tracks[0] is linear_node.input(1).dim_tracks[0] + + assert len(linear_node.outputs()) == 1 + assert isinstance(linear_node.outputs()[0], IRTensor) + # `a` + assert linear_node.output(0).dim_tracks[0] is inputs[0].dim_tracks[0] + # `n` + assert linear_node.output(0).dim_tracks[1] is linear_node.input(1).dim_tracks[0] outputs = ir_graph.outputs() assert len(outputs) == 1 + # `a` + assert outputs[0].dim_tracks[0] is inputs[0].dim_tracks[0] + # `n` + assert outputs[0].dim_tracks[1] is linear_node.input(1).dim_tracks[0] nodes = list(ir_graph.nodes()) assert any(node.signature == 'torch.nn.functional.linear' for node in nodes) diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index a0bc33b8..176cba07 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -166,7 +166,7 @@ def forward(self, x): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) print(ir_graph.extra_repr()) - assert len(ir_graph.nodes()) == 5 + assert len(ir_graph.nodes()) == 4 assert len(ir_graph.nodes()[0].outputs()) == 3 assert len(ir_graph.outputs()) == 1 assert isinstance(ir_graph.output(0), list) diff --git a/tests/graph/parser/test_value_tracker.py b/tests/graph/parser/test_value_tracker.py new file mode 100644 index 00000000..ff45eacf --- /dev/null +++ b/tests/graph/parser/test_value_tracker.py @@ -0,0 +1,194 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile + +import pytest +import torch + +from nnscaler.graph.parser.converter import convert_model +from nnscaler import register_op, mark_dynamic + +from ...utils import replace_all_device_with + + +@replace_all_device_with('cpu') +def test_hidden_dim(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return x.repeat(4, 1) + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 1 + node = ir_graph.node(0) + assert str(node.anno) == 'a^ b -> (4^ a^) b' + dim0_vi = node.input(0).dim_tracks[0].value_id + dim1_vi = node.input(0).dim_tracks[1].value_id + + assert node.output(0).dim_tracks[0].value_id != dim0_vi + assert node.output(0).dim_tracks[0].deps == [dim0_vi] + assert node.output(0).dim_tracks[1].value_id == dim1_vi + assert node.output(0).dim_tracks[1].deps == [] + + +@replace_all_device_with('cpu') +def test_equiv_class(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + x = x + 1 + y = y * 2 + return x@y + + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 'y': torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 3 + x_node = ir_graph.node(0) + y_node = ir_graph.node(1) + assert x_node.input(0).dim_tracks[0] is x_node.output(0).dim_tracks[0] + assert x_node.input(0).dim_tracks[1] is x_node.output(0).dim_tracks[1] + + assert y_node.input(0).dim_tracks[0] is y_node.output(0).dim_tracks[0] + assert y_node.input(0).dim_tracks[1] is y_node.output(0).dim_tracks[1] + + node = ir_graph.node(-1) + assert str(node.anno) == 'm k+, k+ n -> m n' + # the `k` dimension of input 1 should be the same as input 0 + # they are in the same equivalence class + assert node.input(0).dim_tracks[0] is x_node.input(0).dim_tracks[0] + assert node.input(0).dim_tracks[1] is x_node.input(0).dim_tracks[1] + assert node.input(1).dim_tracks[0] is node.input(0).dim_tracks[1] + assert node.input(1).dim_tracks[1] is y_node.input(0).dim_tracks[1] + + assert node.output(0).dim_tracks[0] is node.input(0).dim_tracks[0] + assert node.output(0).dim_tracks[1] is node.input(1).dim_tracks[1] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dim', [True, False]) +def test_size(dynamic_dim): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + x = x + 1 + s = x.size() + y = torch.randn(s) + return x + y + + dummy_input = {'x': mark_dynamic(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), [0] if dynamic_dim else [])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 4 + size_node = ir_graph.node(1) + randn_node = ir_graph.node(2) + + assert size_node.output(0)[0].value_track is ir_graph.inputs()[0].dim_tracks[0] + assert size_node.output(0)[0].value_track.is_constant != (dynamic_dim is True) + assert size_node.output(0)[1].value_track is ir_graph.inputs()[0].dim_tracks[1] + assert size_node.output(0)[1].value_track.is_constant is True + + # dim tracks of randn node is from equivalence class originally from torch.add + assert randn_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[0] + assert randn_node.output(0).dim_tracks[0].is_constant != (dynamic_dim is True) + assert randn_node.output(0).dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[1] + assert randn_node.output(0).dim_tracks[1].is_constant is True + + +# Note: the custom op here is just for testing purpose +@register_op('l (2 m) n -> n (2 l) m') +def my_op(x: torch.Tensor) -> torch.Tensor: + return torch.randn(x.size(2), x.size(0) * 2, x.size(1) // 2) + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dim', [True, False]) +def test_custom_op(dynamic_dim): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + x = my_op(x) + s = x.size() + y = torch.randn(s) + return x + y + + dummy_input = {'x': mark_dynamic(torch.randn(2, 2, 2), [0, 2] if dynamic_dim else [])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 4 + my_op_node = ir_graph.node(0) + size_node = ir_graph.node(1) + randn_node = ir_graph.node(2) + + assert [t.is_constant for t in ir_graph.inputs()[0].dim_tracks] == [not dynamic_dim, True, not dynamic_dim] + + assert my_op_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[2] + assert ir_graph.inputs()[0].dim_tracks[0].value_id in my_op_node.output(0).dim_tracks[1].deps + assert ir_graph.inputs()[0].dim_tracks[1].value_id in my_op_node.output(0).dim_tracks[2].deps + + assert [t.is_constant for t in my_op_node.outputs()[0].dim_tracks] == [not dynamic_dim, not dynamic_dim, True] + + assert size_node.output(0)[0].value_track is my_op_node.output(0).dim_tracks[0] + assert size_node.output(0)[1].value_track is my_op_node.output(0).dim_tracks[1] + assert size_node.output(0)[2].value_track is my_op_node.output(0).dim_tracks[2] + + assert [t.value_track.is_constant for t in size_node.output(0)] == [not dynamic_dim, not dynamic_dim, True] + + # dim tracks of randn node is from equivalence class originally from torch.add + assert randn_node.output(0).dim_tracks[0] is my_op_node.output(0).dim_tracks[0] + # assert randn_node.output(0).dim_tracks[0].is_constant != (dynamic_dim is True) + assert randn_node.output(0).dim_tracks[1] is my_op_node.output(0).dim_tracks[1] + assert randn_node.output(0).dim_tracks[2] is my_op_node.output(0).dim_tracks[2] + + assert [t.is_constant for t in randn_node.outputs()[0].dim_tracks] == [not dynamic_dim, not dynamic_dim, True] + + +# Note: the custom op here is just for testing purpose +@register_op('l l -> l l') +def my_identity(x: torch.Tensor) -> torch.Tensor: + return x + + +@replace_all_device_with('cpu') +def test_custom_op2(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return my_identity(x) + + dummy_input = {'x': torch.randn(2, 2)} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 1 + my_op_node = ir_graph.node(0) + + assert ir_graph.inputs()[0].dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[0] + assert my_op_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[0] + assert my_op_node.output(0).dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[1] diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py index 108b9e3e..57c10809 100644 --- a/tests/graph/test_segment.py +++ b/tests/graph/test_segment.py @@ -167,7 +167,11 @@ def policy_nograd(graph: IRGraph, cfg: ComputeConfig) -> IRGraph: else: fc1_node = graph.nodes()[0] func_node = graph.nodes()[1] - assert fc1_node.inputs()[0].requires_grad and fc1_node.inputs()[0].grad + if cfg.use_end2end: + assert not fc1_node.inputs()[0].requires_grad and not fc1_node.inputs()[0].grad + else: + assert fc1_node.inputs()[0].requires_grad and fc1_node.inputs()[0].grad + assert fc1_node.inputs()[1].requires_grad and fc1_node.inputs()[1].grad assert fc1_node.outputs()[0].requires_grad and fc1_node.outputs()[0].grad assert func_node.inputs()[0].requires_grad and not func_node.inputs()[0].grad diff --git a/tests/graph/tracer/test_pack_kwargs.py b/tests/graph/tracer/test_pack_kwargs.py new file mode 100644 index 00000000..4db75d5d --- /dev/null +++ b/tests/graph/tracer/test_pack_kwargs.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from torch.nn import Module + +from nnscaler.graph.tracer import concrete_trace +from ...utils import replace_all_device_with + + +class Model(Module): + def __init__(self): + super(Model, self).__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, **kwargs): + return self.linear(kwargs['input']) + + +@replace_all_device_with('cpu') +def test_pack_kwargs(): + model = Model() + example_inputs = {'input': torch.randn(1, 10)} + traced_model = concrete_trace(model, example_inputs) + assert list(traced_model.graph.nodes)[0].target == '**kwargs' + + +@replace_all_device_with('cpu') +def test_direct_kwargs(): + model = Model() + example_inputs = {'**kwargs': {'input': torch.randn(1, 10)}} + traced_model = concrete_trace(model, example_inputs) + assert list(traced_model.graph.nodes)[0].target == '**kwargs' diff --git a/tests/ir/test_cten.py b/tests/ir/test_cten.py index c406f7bc..42f7d115 100644 --- a/tests/ir/test_cten.py +++ b/tests/ir/test_cten.py @@ -120,9 +120,9 @@ def test_from_complex(tosub, requires_grad): assert type(obj[2]) == tensor_type and obj[2].parent.tid != obj_tensor_item.tid t1 = TensorMetadata(shape=(), dtype=torch.float, requires_grad=False, - stride=None, memory_format=None, is_quantized=None, qparams=None) + stride=None, memory_format=None, is_quantized=None, qparams=None, dynamic_dims=set()) t2 = TensorMetadata(shape=(2,), dtype=torch.float, requires_grad=True, - stride=None, memory_format=None, is_quantized=None, qparams=None) + stride=None, memory_format=None, is_quantized=None, qparams=None, dynamic_dims=set()) obj = IR.new('n', {'a': t1, 'b': t2}.values(), tensor_types=(TensorMetadata,), diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 0cfac768..cc3b22e3 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -107,6 +107,19 @@ def forward(self, x): return x +class FFN(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = torch.nn.Tanh() + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + def init_distributed(): torch.distributed.init_process_group(backend='nccl') rank = torch.distributed.get_rank() @@ -115,9 +128,10 @@ def init_distributed(): def assert_equal(a: Any, b: Any): - assert type(a) == type(b) + # treat dict and OrderedDict as same for comparison + assert type(a) == type(b) or (isinstance(a, dict) and isinstance(b, dict)), f'{type(a)} != {type(b)}' if isinstance(a, torch.Tensor): - assert torch.equal(a.cpu(), b.cpu()) + assert torch.equal(a.cpu(), b.cpu()), torch.max(torch.abs(a.cpu() - b.cpu())) elif isinstance(a, dict): assert len(a) == len(b) for k in a.keys(): @@ -127,7 +141,7 @@ def assert_equal(a: Any, b: Any): for i in range(len(a)): assert_equal(a[i], b[i]) else: - assert a == b + assert a == b, f"Values are not equal: {a} != {b}" def assert_close(a: Any, b: Any, atol=1e-6, rtol=1e-6): @@ -137,10 +151,10 @@ def assert_close(a: Any, b: Any, atol=1e-6, rtol=1e-6): elif isinstance(a, dict): assert len(a) == len(b) for k in a.keys(): - assert_close(a[k], b[k]) + assert_close(a[k], b[k], atol=atol, rtol=rtol) elif isinstance(a, (list, tuple)): assert len(a) == len(b) for i in range(len(a)): - assert_close(a[i], b[i]) + assert_close(a[i], b[i], atol=atol, rtol=rtol) else: - raise ValueError(f'unsupported type {type(a)}') \ No newline at end of file + assert a == b, f"Values are not equal: {a} != {b}" diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py index 64bf490a..b1a5a95b 100644 --- a/tests/parallel_module/test_attr_dedup.py +++ b/tests/parallel_module/test_attr_dedup.py @@ -22,6 +22,7 @@ from ..launch_torchrun import launch_torchrun from ..utils import clear_dir_on_rank0 + class Net(torch.nn.Module): def __init__(self): super().__init__() @@ -38,6 +39,7 @@ def forward(self, x): x = self.buffer + x return x + def pas(graph: IRGraph, config: ComputeConfig): fw_nodes = graph.select(ntype=IRFwOperation) assert len(fw_nodes) == 4 @@ -50,6 +52,7 @@ def pas(graph: IRGraph, config: ComputeConfig): _replica(graph, fw_nodes[3], devs=devs) return graph + def _gpu_worker_spmd(cc: ComputeConfig): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_dedup_attr') as tempdir: @@ -65,13 +68,17 @@ def _gpu_worker_spmd(cc: ComputeConfig): world_size = torch.distributed.get_world_size() attr_area_maps = [None for _ in range(world_size)] curr_rank = torch.distributed.get_rank() - torch.distributed.all_gather_object(attr_area_maps, module.fullmap) + # Construct the three-level nested structure: rank -> module_name -> fullmap + # In this test case, we have only one module instance 'attr_dedup' + module_fullmap = {'attr_dedup': module.fullmap} + torch.distributed.all_gather_object(attr_area_maps, module_fullmap) rank2attr_area_map = {} for i, attr_area_map in enumerate(attr_area_maps): rank2attr_area_map[i] = attr_area_map torch.distributed.barrier() dedup_meta_info = dedup_attrs(rank2attr_area_map) - dedup_area_map = list(dedup_meta_info[curr_rank].items()) + # Access the deduped fullmap for the specific module + dedup_area_map = list(dedup_meta_info[curr_rank]['attr_dedup'].items()) if curr_rank == 0: assert len(dedup_area_map) == 4 assert dedup_area_map[0][1].orig_name == 'fc1.weight' diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index c7bc814a..5a305499 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -64,12 +64,12 @@ def _gpu_worker(): p(tempdir, 'none', '_1') # case 2: broadcast only code, so only rank 0 can load the module - # rank 1 will raise RuntimeError because it will fail to load fullmodel.pt + # rank 1 will raise FileNotFoundError because it will fail to load attr_map files and more with tempfile.TemporaryDirectory() as tempdir: if torch.distributed.get_rank() == 0: p(tempdir, 'code', '_2') else: - with pytest.raises(RuntimeError, match='Cannot find file.*'): + with pytest.raises(FileNotFoundError): p(tempdir, 'code', '_2') # case 3: broadcast except weights, so only rank 0 can load the module diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 6faee640..d84d7503 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -17,11 +17,18 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dict +from nnscaler.parallel import ( + ComputeConfig, parallelize, + build_optimizer, + merge_state_dicts, + load_merged_state_dict, + load_merged_state_dict_from_rank, + trimmed_broadcast_merged_state_dict, +) from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import CubeLinear, init_random, init_distributed, PASMegatron +from .common import CubeLinear, init_random, init_distributed, PASMegatron, assert_equal from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively from ..utils import replace_all_device_with, clear_dir_on_rank0 @@ -345,6 +352,23 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf optimizer_from_merged, merged_opt_state_dict, ) + model_from_merged_rank = type(model)() + optimizer_from_merged_rank = build_optimizer(model_from_merged_rank, torch.optim.Adam, lr=0.01) + load_merged_state_dict_from_rank( + model_from_merged_rank, merged_model_state_dict if torch.distributed.get_rank() == 0 else None, + optimizer_from_merged_rank, merged_opt_state_dict if torch.distributed.get_rank() == 0 else None, + ) + assert_equal(model_from_merged.state_dict(), model_from_merged_rank.state_dict()) + assert_equal(optimizer_from_merged.state_dict(), optimizer_from_merged_rank.state_dict()) + + trimmed_model_state_dict, trimmed_opt_state_dict = trimmed_broadcast_merged_state_dict( + model_from_merged_rank, merged_model_state_dict if torch.distributed.get_rank() == 0 else None, + optimizer_from_merged_rank, merged_opt_state_dict if torch.distributed.get_rank() == 0 else None, + ) + assert_equal(dict(model_from_merged.state_dict()), trimmed_model_state_dict) + assert_equal(optimizer_from_merged.state_dict()['state'], trimmed_opt_state_dict['state']) + assert_equal(optimizer_from_merged.state_dict()['param_groups'], trimmed_opt_state_dict['param_groups']) + # check merged model result_orig_model_state_dict = model.state_dict() result_merged_model_state_dict = model_from_merged.state_dict() @@ -425,6 +449,25 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf 'model': merged_model_state_dicts, 'optimizer': merged_optimizer_state_dict }, ckpt_merged_file) + from nnscaler.runtime.serialization import convert, load + from contextlib import ExitStack + ckpt_st_file_template = 'ckpt_{rank}_{start}.safetensors' + ckpt_st_files = [ckpt_dir / ckpt_st_file_template.format(rank=i, start=end) for i in range(torch.distributed.get_world_size())] + for pt, st in zip(ckpt_files, ckpt_st_files): + convert(pt, st, src_format='pt', dst_format='safetensors') + ckpt_st_state_dict_loaders = [load(f, lazy=True) for f in ckpt_st_files] + with ExitStack() as stack: + ckpt_st_state_dicts = [] + for f in ckpt_st_state_dict_loaders: + ckpt_st_state_dicts.append(stack.enter_context(f).get_lazy_data()) + model_st_state_dicts = [ckpt['model'] for ckpt in ckpt_st_state_dicts] + optimizer_st_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_st_state_dicts] + merged_model_st_state_dicts, merged_optimizer_st_state_dict = merge_state_dicts( + model_st_state_dicts, optimizer_st_state_dicts + ) + assert_equal(merged_model_state_dicts, merged_model_st_state_dicts) + assert_equal(merged_optimizer_state_dict, merged_optimizer_st_state_dict) + torch.distributed.barrier() return results @@ -550,7 +593,7 @@ def _gpu_merge_worker(): init_distributed() with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_merge') as tempdir: compiled_module = _create_cube_module('data', - ComputeConfig(2, 2, use_zero=True), + ComputeConfig(2, 4, use_zero=True), tempdir, 'whole', ) @@ -565,6 +608,6 @@ def _gpu_merge_worker(): ) -@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 24, reason='lack of gpu devices') def test_checkpoint_merge(): - launch_torchrun(2, _gpu_merge_worker) + launch_torchrun(4, _gpu_merge_worker) diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index 339ea888..a5fec814 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -139,8 +139,6 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): if not isinstance(model, ParallelModule): # in this case, non parallel module is removed, so it should have less keys assert len(parallel_modules) < len(dedupped_model_state_dict) < len(model_state_dict) - else: - assert len(dedupped_model_state_dict) == len(model_state_dict) for k, v in dedupped_model_state_dict.items(): assert_equal(v, model_state_dict[k]) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index b43e0d20..400e3b18 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -17,7 +17,7 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts from nnscaler.runtime.module import ParallelModule from nnscaler.runtime.gnorm import calcuate_gnorm @@ -261,6 +261,36 @@ def _compare_weights(orig0, compiled0, compiled1, fc1_fullmap, fc2_fullmap, fc1_ # print(f'key: {k}, max diff: {torch.max(torch.abs(orig0[k] - v))}') assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('update_freq', [1, 2, 4]) +def test_zero3(update_freq): + zero3_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 3) + zero1_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 1) + no_zero_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 0) + + zero3_results0: List[StepResult] + zero3_results1: List[StepResult] + zero1_results0: List[StepResult] + zero1_results1: List[StepResult] + no_zero_results0: List[StepResult] + no_zero_results1: List[StepResult] + + zero3_results0, zero3_results1 = zero3_results[0][0], zero3_results[1][0] + zero1_results0, zero1_results1 = zero1_results[0][0], zero1_results[1][0] + no_zero_results0, no_zero_results1 = no_zero_results[0][0], no_zero_results[1][0] + + for r0, r1 in [ + (zero3_results0, zero1_results0), (zero1_results0, no_zero_results0), + (zero3_results1, zero1_results1), (zero1_results1, no_zero_results1), + ]: + # have the same input + assert len(r0) == len(r1) # iteration count + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.pred, b.pred) # pred + assert torch.equal(a.loss, b.loss) # loss + assert torch.equal(a.gnorm, b.gnorm) # gnorm + @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @pytest.mark.parametrize('update_freq', [1, 2, 4]) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index b7117563..dbca3de0 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -6,6 +6,7 @@ import re from contextlib import nullcontext from typing import Union +from functools import partial import torch import torch.nn.functional as F @@ -17,7 +18,8 @@ from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.parser.mapping import SignFx2Op from nnscaler.ir.cten import IR, IRObject -from nnscaler.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph +from nnscaler.parallel import _load_parallel_module_class, parallelize, ComputeConfig, CubeModule, _gen_graph +from nnscaler.utils import mark_dynamic from .common import init_distributed from ..launch_torchrun import launch_torchrun @@ -67,6 +69,7 @@ def __init__(self): def forward(self, x): return x[:2] + @replace_all_device_with('cpu') def test_codegen_slice(): with tempfile.TemporaryDirectory() as tempdir: @@ -208,9 +211,8 @@ def _gencode_unused_args_worker(tempdir): m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), m=1) ) - with pytest.raises(ValueError): - # y must be None - m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) + # if y is not None, we will not raise error now. + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @@ -256,9 +258,6 @@ def _gencode_unused_args_worker2(tempdir): with pytest.raises(TypeError, match='.*must be Tensor, not NoneType.*'): # raise by torch.add, as m is None m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - with pytest.raises(ValueError): - # y must be None - m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @@ -394,11 +393,11 @@ def print_gencode(cubesave_dir, module_class, index=0): print(filecontent) -def _gencode_contains(cubesave_dir, module_class, index, search_re): +def _gencode_contains(cubesave_dir, module_class, index, search_re, *, instance_name=None): from nnscaler.parallel import _PARALLEL_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME from pathlib import Path import re - namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' + namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{instance_name or _DEFAULT_INSTANCE_NAME}' outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) filecontent = (outdir /f'gencode{index}.py').read_text() matches = re.findall(search_re, filecontent) @@ -453,8 +452,13 @@ def test_codegen_getitem(): gen_savedir=tempdir, load_module=False, ) - assert _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') - assert _gencode_contains(tempdir, GetItemModule, 1, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') + + assert _gencode_contains(tempdir, GetItemModule, 0, r"_operator.getitem\(batched_data.*, 'x'\)") + assert _gencode_contains(tempdir, GetItemModule, 1, r"_operator.getitem\(batched_data.*, 'x'\)") + # data_x.size() will be expanded to a list of ir objects, + # so no slice operation will be generated. + assert not _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') + assert not _gencode_contains(tempdir, GetItemModule, 1, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') assert m_new is None @@ -654,6 +658,37 @@ def test_codegen_dictget(): assert m_new is None +class NonConstModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + shape = x.shape + z = torch.randn(shape) + shape = z.shape + return z + shape[0] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dims', [[], [0]]) +def test_codegen_nonconst(dynamic_dims): + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + NonConstModule(), + {'x': mark_dynamic(torch.tensor([[[1.0], [2.0], [3.0], [6.0]]]), dynamic_dims)}, # shape 1/4/1 + 'dp', + ComputeConfig(1, 1, constant_folding=True), + gen_savedir=tempdir, + load_module=False + ) + if not dynamic_dims: + # shape[0] is constant 1, so can be folded to constant 1 + assert _gencode_contains(tempdir, NonConstModule, 0, r'torch.add\(.*, 1, alpha=1\)') + else: + # shape[0] is dynamic, so cannot be folded to constant 1 + assert not _gencode_contains(tempdir, NonConstModule, 0, r'torch.add\(.*, 1, alpha=1\)') + + class CloneModule(torch.nn.Module): def __init__(self): super().__init__() @@ -733,6 +768,7 @@ def __init__(self): def forward(self, a, b): return torch.min(a, b) + def _gencode_min_function_worker(tempdir): init_distributed() m_new = parallelize( @@ -1657,7 +1693,7 @@ def check_op(*names): for name in names: code = add_codes.pop(0) if name in not_folded_names: - assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, getitem_.*, alpha=1\)', code) + assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, get.*, alpha=1\)', code) else: assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, 2, alpha=1\)', code) @@ -1726,7 +1762,7 @@ def test_fold_constant(tmp_path, fold_input): else: # add_27 = torch.add(linear_30, getitem_20, alpha=1) assert _gencode_contains(tmp_path, CCFModule2, 0, - r'add_.* = torch\.add\(linear_.*, getitem_.*, alpha=1\)') + r'add_.* = torch\.add\(linear_.*, get.*, alpha=1\)') # b = b * ashape3 # mul_2_51 = torch.mul(mul_1_57, add_38) assert _gencode_contains(tmp_path, CCFModule2, 0, @@ -1963,3 +1999,303 @@ def test_codegen_forward_error_compile(tmp_path): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of GPU devices') def test_codegen_forward_error(tmp_path): launch_torchrun(2, _gencode_forward_error_worker, tmp_path) + + +class WeightModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, input): + input = input + self.weights + out = input @ self.weights + return out + + +class WeightModel2(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = WeightModel() + + def forward(self, input): + return self.weights(input) + + +def pas_weight(graph, cfg, with_auto_multiref=True): + from nnscaler.ir import IRFwOperation, IRDataOperation + from nnscaler.policies import _tp, _replica, auto_multiref + ngpus = cfg.plan_ngpus + if with_auto_multiref: + auto_multiref(graph) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'add': + _tp(graph, node, list(range(ngpus)), 1, 0) + elif node.name == 'matmul': + _tp(graph, node, list(range(ngpus)), 1, 0) + else: + _replica(graph, node, list(range(ngpus))) + return graph + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('with_auto_multiref', [True, False]) +def test_weight_partition(tmp_path, with_auto_multiref): + """ + If auto_multiref is not applied, the weight will correctly partitioned + If auto_multiref is applied, the weight will be replicated as a whole + """ + input = torch.randn((4, 4)) + instance_name = f'with_auto_multiref_{with_auto_multiref}' + + dummy_input = {'input': input} + + m = WeightModel2() + m.train() + + parallelize( + m, + dummy_input, + partial(pas_weight, with_auto_multiref=with_auto_multiref), + ComputeConfig(2, 2), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + instance_name=instance_name, + ) + + module_class = _load_parallel_module_class(WeightModel2, gen_savedir=tmp_path, instance_name=instance_name, rank=0) + + if with_auto_multiref: + for rank in range(2): + fullmap = module_class.attr_meta_maps[rank] + assert fullmap[list(fullmap.keys())[0]].sub_shape == (4, 4) + else: + for rank in range(2): + fullmap = module_class.attr_meta_maps[rank] + assert fullmap[list(fullmap.keys())[0]].sub_shape == (2, 4) + +class DynamicInputModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = torch.nn.Parameter(torch.randn(1, 1)) + + def forward(self, input): + return input + self.weights + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dims', [[], [0, 1]]) +def test_dynamic_dim_partition(tmp_path, dynamic_dims): + input = mark_dynamic(torch.randn((4, 4)), dynamic_dims) + dummy_input = {'input': input} + instance_name=f'{"no" if not dynamic_dims else ""}_dynamic_dims' + + m = DynamicInputModel() + m.train() + + parallelize( + m, + dummy_input, + 'tp', + ComputeConfig(2, 2), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + instance_name=instance_name, + ) + if dynamic_dims: + # no partition for dynamic input + assert not _gencode_contains(tmp_path, DynamicInputModel, 0, r'nnscaler.runtime.adapter.nn.split_allgather', instance_name=instance_name) + else: + assert _gencode_contains(tmp_path, DynamicInputModel, 0, r'nnscaler.runtime.adapter.nn.split_allgather', instance_name=instance_name) + + +@replace_all_device_with('cpu') +def test_zero3_normal(tmp_path): + from tests.parallel_module.test_end2end import MLP + m = MLP(2, 2) + dummy_input = { + 'data': torch.randn( + 2, 2), + 'target': torch.rand( + 2, 2) + } + m.train() + parallelize( + m, + {'data': dummy_input}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # code looks like: + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.prefetch_param\(self\.layers_0_weight_\d+\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.postevict_param\(self\.layers_0_weight_\d+\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.backward_postevict_param\(.*, self\.layers_0_weight_\d+, 1\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.backward_prefetch_param\(.*, self\.layers_0_weight_\d+, 1\)') + + # def segment35_impl(self, data_23): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 46, in forward, x = data['data'] + # getitem_25 = _operator.getitem(data_23, 'data') + # self.prefetch_param(self.layers_0_weight_26) + # getitem_25 = self.backward_postevict_param(getitem_25, self.layers_0_weight_26, 1) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 48, in forward, x = layer(x) + # linear_27 = torch.nn.functional.linear(getitem_25, self.layers_0_weight_26, bias=None) + # self.postevict_param(self.layers_0_weight_26) + # linear_27 = self.backward_prefetch_param(linear_27, self.layers_0_weight_26, 1) + # del getitem_25 + # self.prefetch_param(self.layers_1_weight_28) + # linear_27 = self.backward_postevict_param(linear_27, self.layers_1_weight_28, 2) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 48, in forward, x = layer(x) + # linear_1_29 = torch.nn.functional.linear(linear_27, self.layers_1_weight_28, bias=None) + # self.postevict_param(self.layers_1_weight_28) + # linear_1_29 = self.backward_prefetch_param(linear_1_29, self.layers_1_weight_28, 2) + # del linear_27 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 49, in forward, x = torch.sigmoid(x) + # sigmoid_30 = torch.sigmoid(linear_1_29) + # del linear_1_29 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 50, in forward, loss = self.loss_fn(x, data['target']) + # getitem_1_31 = _operator.getitem(data_23, 'target') + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 50, in forward, loss = self.loss_fn(x, data['target']) + # binary_cross_entropy_24 = torch.nn.functional.binary_cross_entropy(sigmoid_30, getitem_1_31, weight=None, reduction='mean') + # del sigmoid_30, getitem_1_31 + # return binary_cross_entropy_24 + + # def segment35(self, data_23): + # with self.save_params_hooks(): + # return self.segment35_impl(data_23) + + +class SoloOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + self.p = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x): + x = self.scale.sum() + x + self.p + return torch.sum(x) + + +def launch_zero3_run_solo_param(tmp_path): + init_distributed() + m = SoloOpModule() + dummy_input = torch.randn(4, 4) + m.train() + m_new = parallelize( + m, + {'x': dummy_input}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=True, + reuse='override', + ) + loss = m_new(dummy_input) + loss.backward() + # scale can't be evicited with backward hook + assert len(m_new._backward_prefetched_params) == 1 + # but it should have been evicted in reducer. + assert list(m_new._backward_prefetched_params.keys())[0].shape == (8,) + assert not _gencode_contains(tmp_path, SoloOpModule, 0, + r'self\.backward_postevict_param\(.*, self\.scale_\d+, \d+\)') + assert _gencode_contains(tmp_path, SoloOpModule, 0, + r'self\.backward_postevict_param\(.*, self\.p_\d+, \d+\)') + # code looks like: + # def segment32_impl(self, x_17): + # self.prefetch_param(self.scale_19) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # sum_1_20 = torch.sum(self.scale_19) + # self.postevict_param(self.scale_19) + # sum_1_20 = self.backward_prefetch_param(sum_1_20, self.scale_19, 0) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # add_21 = torch.add(sum_1_20, x_17, alpha=1) + # del x_17, sum_1_20 + # self.prefetch_param(self.p_22) + # add_21 = self.backward_postevict_param(add_21, self.p_22, 2) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # add_1_23 = torch.add(add_21, self.p_22, alpha=1) + # self.postevict_param(self.p_22) + # add_1_23 = self.backward_prefetch_param(add_1_23, self.p_22, 2) + # del add_21 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2187, in forward, return torch.sum(x) + # sum_2_18 = torch.sum(add_1_23) + # del add_1_23 + # return sum_2_18 + + # def segment32(self, x_17): + # with self.save_params_hooks(): + # return self.segment32_impl(x_17) + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of GPU devices') +def test_zero3_run_solo_param(tmp_path): + launch_torchrun(2, launch_zero3_run_solo_param, tmp_path) + + +@nnscaler.register_op('*, *, *, * -> *, *, *, *') +def _zero3_multi_inout(x, y, z, w): + return x + 1, y + 1, z + 1, w + 1 + + +class Zero3MultiInoutModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = torch.nn.Parameter(torch.randn(4, 4)) + self.q = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, y): + return _zero3_multi_inout(x, y, self.p, self.q) + + +@replace_all_device_with('cpu') +def test_zero3_multi_inout(tmp_path): + m = Zero3MultiInoutModule() + m.train() + m_new = parallelize( + m, + {'x': torch.randn(4, 4), 'y': torch.randn(4, 4)}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert len(_gencode_contains(tmp_path, Zero3MultiInoutModule, 0, + 'self.backward_prefetch_param')) == 8 + assert len(_gencode_contains(tmp_path, Zero3MultiInoutModule, 0, + 'self.backward_postevict_param')) == 4 + # code looks like: + # def segment34_impl(self, x_25, y_26): + # self.prefetch_param(self.p_31) + # x_25 = self.backward_postevict_param(x_25, self.p_31, 0) + # y_26 = self.backward_postevict_param(y_26, self.p_31, 0) + # self.prefetch_param(self.q_32) + # x_25 = self.backward_postevict_param(x_25, self.q_32, 0) + # y_26 = self.backward_postevict_param(y_26, self.q_32, 0) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2259, in forward, return _zero3_multi_inout(x, y, self.p, self.q) + # _zero3_multi_inout_27, _zero3_multi_inout_28, _zero3_multi_inout_29, _zero3_multi_inout_30 = tests.parallel_module.test_gencode._zero3_multi_inout(x_25, y_26, self.p_31, self.q_32) + # self.postevict_param(self.p_31) + # self.postevict_param(self.q_32) + # _zero3_multi_inout_27 = self.backward_prefetch_param(_zero3_multi_inout_27, self.p_31, 0) + # _zero3_multi_inout_27 = self.backward_prefetch_param(_zero3_multi_inout_27, self.q_32, 0) + # _zero3_multi_inout_28 = self.backward_prefetch_param(_zero3_multi_inout_28, self.p_31, 0) + # _zero3_multi_inout_28 = self.backward_prefetch_param(_zero3_multi_inout_28, self.q_32, 0) + # _zero3_multi_inout_29 = self.backward_prefetch_param(_zero3_multi_inout_29, self.p_31, 0) + # _zero3_multi_inout_29 = self.backward_prefetch_param(_zero3_multi_inout_29, self.q_32, 0) + # _zero3_multi_inout_30 = self.backward_prefetch_param(_zero3_multi_inout_30, self.p_31, 0) + # _zero3_multi_inout_30 = self.backward_prefetch_param(_zero3_multi_inout_30, self.q_32, 0) + # del x_25, y_26 + # return _zero3_multi_inout_27, _zero3_multi_inout_28, _zero3_multi_inout_29, _zero3_multi_inout_30 + + # def segment34(self, x_25, y_26): + # with self.save_params_hooks(): + # return self.segment34_impl(x_25, y_26) + assert True diff --git a/tests/parallel_module/test_gencode_einops.py b/tests/parallel_module/test_gencode_einops.py new file mode 100644 index 00000000..bea1c75a --- /dev/null +++ b/tests/parallel_module/test_gencode_einops.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile +import functools +from einops import rearrange +import torch + +from nnscaler import parallelize, ComputeConfig +from nnscaler.graph import parser +from nnscaler.graph.tracer import ConcreteTracer + +from tests.utils import replace_all_device_with +from .test_gencode import _gencode_contains, print_gencode + + +class RearrangeModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x, y): + return self.linear(x) + rearrange(y, '(h w) -> h w', h=3, w=3) + f(3) + + +def log_f(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + print(f"Function '{func.__name__}' called") + return func(*args, **kwargs) + return wrapper + + +@log_f +@functools.cache +def f(x: int) -> int: + return x * 2 + + +@replace_all_device_with('cpu') +def test_trace_rearrange(): + import gc + def _convert(): + model = RearrangeModule() + parser.to_fx_graph(model, {'x': torch.randn(3, 3), 'y': torch.randn(9)}) + gc.collect() + + _convert() + for obj in gc.get_objects(): + # einops is using functools.cache + # will leak memory if not properly handle it. + assert not isinstance(obj, ConcreteTracer) + + +@replace_all_device_with('cpu') +def test_codegen_rearrange(): + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + RearrangeModule(), + {'x': torch.randn(3, 3), 'y': torch.randn(9)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False + ) + # parallelize will succeed. + assert True diff --git a/tests/parallel_module/test_gencode_kwargs.py b/tests/parallel_module/test_gencode_kwargs.py new file mode 100644 index 00000000..85a1b3c5 --- /dev/null +++ b/tests/parallel_module/test_gencode_kwargs.py @@ -0,0 +1,189 @@ +import nnscaler +from nnscaler import parallelize, ComputeConfig + +import torch + +from .test_gencode import _gencode_contains, replace_all_device_with, print_gencode + + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='kw_operator') +def kw_operator(x: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='kw_operator2') +def kw_operator2(x: torch.Tensor, y: torch.Tensor, kwargs) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + + +class KwargsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, **kwargs): + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + c = kwargs['c'] + return kw_operator(x, self.scale, **kwargs) \ + + kw_operator2(x, self.scale, kwargs) + a + b + c + + +@replace_all_device_with('cpu') +def test_kwargs(tmp_path): + m = KwargsModule() + m.train() + parallelize( + m, + {'x': torch.randn(4, 4), 'a': 3, 'c': 4, 'd': 5}, + 'dp', + ComputeConfig(1, 1, constant_folding=False), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'a\'\)') + assert not _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'b\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'c\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'d\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'tests.parallel_module.test_gencode_kwargs.kw_operator\(x_\d+, self.scale_\d+, a=getitem_[\d_]+, c=getitem_[\d_]+, d=getitem_[\d_]+\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r"tests.parallel_module.test_gencode_kwargs.kw_operator2\(x_\d+, self.scale_\d+, kwargs=\{'a': getitem_[\d_]+, 'c': getitem_[\d_]+, 'd': getitem_[\d_]+\}\)") + # code looks like: + # def segment49(self, x_31, **kwargs_6): + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_28 = _operator.getitem(kwargs_6, 'a') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_1_29 = _operator.getitem(kwargs_6, 'c') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_2_30 = _operator.getitem(kwargs_6, 'd') + # # created at IRAdapterGener:local_consumer_multiref + # x_52, x_56 = nnscaler.runtime.function.multiref(x_31, times=2) + # del x_31 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # kw_operator_34 = tests.parallel_module.test_gencode_kwargs.kw_operator(x_52, self.scale_33, a=getitem_28, c=getitem_1_29, d=getitem_2_30) + # del x_52 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # kw_operator2_35 = tests.parallel_module.test_gencode_kwargs.kw_operator2(x_56, self.scale_33, kwargs={'a': getitem_28, 'c': getitem_1_29, 'd': getitem_2_30}) + # del x_56 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_36 = torch.add(kw_operator_34, kw_operator2_35, alpha=1) + # del kw_operator_34, kw_operator2_35 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_1_37 = torch.add(add_36, getitem_28, alpha=1) + # del add_36 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_2_38 = torch.add(add_1_37, 2, alpha=1) + # del add_1_37 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_3_32 = torch.add(add_2_38, getitem_1_29, alpha=1) + # del add_2_38 + # return add_3_32 + + # def _forward_impl(self, x, **kwargs): + # add_3_32 = self.segment49(x, **kwargs) + # return add_3_32 + assert True + + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='dict_operator') +def dict_operator(x: torch.Tensor, y: torch.Tensor, kwargs: dict) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + + +class DictargsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, kwargs: dict): + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + c = kwargs['c'] + return dict_operator(x, self.scale, kwargs) \ + + kw_operator(x, self.scale, **kwargs) + a + b + c + + +@replace_all_device_with('cpu') +def test_dictargs(tmp_path): + m = DictargsModule() + m.train() + parallelize( + m, + {'x': torch.randn(4, 4), 'kwargs': {'a': 3, 'c': 4, 'd': 5}}, + 'dp', + ComputeConfig(1, 1, constant_folding=False), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"builtins.dict.get\(kwargs_.*, 'a', 1\)") + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"builtins.dict.get\(kwargs_.*, 'b', 2\)") + assert len(_gencode_contains(tmp_path, DictargsModule, 0, + r"_operator\.getitem\(kwargs_.*, 'a'\)")) == 1 + assert len(_gencode_contains(tmp_path, DictargsModule, 0, + r"_operator\.getitem\(kwargs_.*, 'c'\)")) == 2 + assert _gencode_contains(tmp_path, DictargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'d\'\)') + assert _gencode_contains(tmp_path, DictargsModule, 0, + r'tests.parallel_module.test_gencode_kwargs.kw_operator\(x_\d+, self.scale_\d+, a=getitem_[\d_]+, c=getitem_[\d_]+, d=getitem_[\d_]+\)') + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"tests.parallel_module.test_gencode_kwargs.dict_operator\(x_\d+, self.scale_\d+, kwargs=kwargs_\d+\)") + # code looks like: + # def segment52(self, x_35, kwargs_6): + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2411, in forward, a = kwargs.get('a', 1) + # get_7 = builtins.dict.get(kwargs_6, 'a', 1) + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2412, in forward, b = kwargs.get('b', 2) + # get_1_8 = builtins.dict.get(kwargs_6, 'b', 2) + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2413, in forward, c = kwargs['c'] + # getitem_31 = _operator.getitem(kwargs_6, 'c') + # # created at IRAdapterGener:local_consumer_multiref + # x_56, x_60 = nnscaler.runtime.function.multiref(x_35, times=2) + # del x_35 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # dict_operator_38 = tests.parallel_module.test_gencode_kwargs.dict_operator(x_56, self.scale_37, kwargs=kwargs_6) + # del x_56 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_1_32 = _operator.getitem(kwargs_6, 'a') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_2_33 = _operator.getitem(kwargs_6, 'c') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_3_34 = _operator.getitem(kwargs_6, 'd') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # kw_operator_39 = tests.parallel_module.test_gencode_kwargs.kw_operator(x_60, self.scale_37, a=getitem_1_32, c=getitem_2_33, d=getitem_3_34) + # del x_60 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_40 = torch.add(dict_operator_38, kw_operator_39, alpha=1) + # del dict_operator_38, kw_operator_39 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_1_41 = torch.add(add_40, get_7, alpha=1) + # del add_40 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_2_42 = torch.add(add_1_41, get_1_8, alpha=1) + # del add_1_41 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_3_36 = torch.add(add_2_42, getitem_31, alpha=1) + # del add_2_42 + # return add_3_36 + + # def _forward_impl(self, x, kwargs): + # add_3_36 = self.segment52(x, kwargs) + # return add_3_36 diff --git a/tests/parallel_module/test_gencode_torch_compile.py b/tests/parallel_module/test_gencode_torch_compile.py index 62ba534c..4865783c 100644 --- a/tests/parallel_module/test_gencode_torch_compile.py +++ b/tests/parallel_module/test_gencode_torch_compile.py @@ -9,7 +9,7 @@ from nnscaler import parallelize, ComputeConfig, register_op -from tests.utils import replace_all_device_with +from tests.utils import raises_with_cause, replace_all_device_with from .test_gencode import _gencode_contains, print_gencode @@ -182,7 +182,7 @@ def forward(self, x): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_codegen_compile_failed_g(): - with pytest.raises(RuntimeError), tempfile.TemporaryDirectory() as tempdir: + with raises_with_cause(RuntimeError, match=".*You must register it to avoid tracing failure..*"), tempfile.TemporaryDirectory() as tempdir: parallelize( Module2(), {'x': torch.randn(3, 3)}, diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 3d046393..3278cc0f 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -72,13 +72,13 @@ def test_empty_weights(model_class, tp): model_class, {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, 'tp', - ComputeConfig(2, 4, use_zero=True, zero_ngroups=2), + ComputeConfig(2, 8, use_zero=True, zero_ngroups=2), gen_savedir=tempdir, reuse='match', load_module=False, instance_name=instance_name, ) - for i in range(4): + for i in range(8): module_class = _load_parallel_module_class(model_class, gen_savedir=tempdir, instance_name=instance_name, rank=i) m = new_empty(module_class) assert m.rank == i @@ -86,9 +86,9 @@ def test_empty_weights(model_class, tp): assert p.device == torch.device('meta') for r in m.reducers: if tp: - assert r.ranks == ((0, 2) if i in (0, 2) else (1, 3)) + assert r.ranks == ((0, 2, 4, 6) if i in (0, 2, 4, 6) else (1, 3, 5, 7)) else: - assert r.ranks == (0, 1, 2, 3) + assert r.ranks == (0, 1, 2, 3, 4, 5, 6, 7) assert len(r.buckets) == 1 assert r.zero assert r.zero_ngroups == 2 diff --git a/tests/parallel_module/test_offload_params.py b/tests/parallel_module/test_offload_params.py new file mode 100644 index 00000000..81e5d305 --- /dev/null +++ b/tests/parallel_module/test_offload_params.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile +from pathlib import Path +import pytest +from typing import Dict, Tuple, List, Any + +import torch +from torch import nn + +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.graph import IRGraph + +from .common import PASMegatron, CubeLinear, init_random, init_distributed, assert_equal +from ..launch_torchrun import launch_torchrun +from ..utils import clear_dir_on_rank0 + + +class SimpleMLP(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super(SimpleMLP, self).__init__() + init_random() + self.register_buffer('buffer', torch.zeros(hidden_dim,)) + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = x + self.buffer + x = torch.relu(x) + x = self.fc2(x) + return x + + +def get_tensor_bytesize(t: torch.Tensor) -> int: + return t.numel() * t.element_size() + + +def pas_test_offload(graph: IRGraph, cfg: ComputeConfig): + ngpus = cfg.plan_ngpus + auto_multiref(graph) + + batch_dim = 0 + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, list(range(ngpus))) + + found_linear = False + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if 'linear' in node.signature and not found_linear: + found_linear = True + algo = node.algorithm('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=1, num=ngpus) + else: + sub_nodes = graph.replicate(node, ngpus) + + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def _mem_worker(): + init_distributed() + bsz, dim = 32, 1024 + compute_config = ComputeConfig( + plan_ngpus=1, + runtime_ngpus=2, + ) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_offload_mem') as tempdir: + module = SimpleMLP(dim, dim, dim) + p_module = parallelize( + module, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + ) + + before_mem = torch.cuda.memory_allocated() + size_to_free = 0 + for reducer in p_module.reducers: + assert get_tensor_bytesize(reducer._contiguous_params) == get_tensor_bytesize(reducer._contiguous_grads) + size_to_free += get_tensor_bytesize(reducer._contiguous_params) + + for buffer in p_module.buffers(): + size_to_free += get_tensor_bytesize(buffer) + + for param in p_module.parameters(): + size_to_free += get_tensor_bytesize(param) + + p_module.sleep() + torch.distributed.barrier() + after_mem = torch.cuda.memory_allocated() + print(f"Memory before offload: {before_mem}, after offload: {after_mem}, freed: {before_mem - after_mem}") + print(f"Total size to free: {size_to_free}") + + assert size_to_free == before_mem - after_mem, f"Expected {size_to_free}, but got {before_mem - after_mem}" + + +def _correctness_worker(): + init_distributed() + bsz, dim, num_steps = 32, 1024, 5 + compute_config = ComputeConfig( + plan_ngpus=1, + runtime_ngpus=2, + ) + + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_offload_correctness') as tempdir: + # Create test data + torch.manual_seed(42 + torch.distributed.get_rank()) + test_data = [torch.randn(bsz, dim).cuda() for _ in range(num_steps)] + + # Test 1: Normal execution without offload/load + init_random() + module1 = SimpleMLP(dim, dim, dim) + p_module1 = parallelize( + module1, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + instance_name='normal' + ) + optimizer1 = build_optimizer(p_module1, torch.optim.Adam, lr=0.01) + + results_normal = [] + for step, x in enumerate(test_data): + p_module1.train() + output = p_module1(x) + loss = output.sum() + loss.backward() + optimizer1.step() + optimizer1.zero_grad() + + # Save intermediate results for comparison + results_normal.append({ + 'loss': loss.detach().cpu(), + 'output': output.detach().cpu(), + 'params': {name: param.detach().cpu().clone() for name, param in p_module1.named_parameters()} + }) + + torch.distributed.barrier() + + # Test 2: Execution with offload/load + init_random() + module2 = SimpleMLP(dim, dim, dim) + p_module2 = parallelize( + module2, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + instance_name='offload' + ) + optimizer2 = build_optimizer(p_module2, torch.optim.Adam, lr=0.01) + + # First offload to initialize the buffer_shape + p_module2.sleep() + + results_offload = [] + for step, x in enumerate(test_data): + # Load params at the beginning of each step + p_module2.wake_up() + + p_module2.train() + output = p_module2(x) + loss = output.sum() + loss.backward() + optimizer2.step() + optimizer2.zero_grad() + + # Save intermediate results for comparison + results_offload.append({ + 'loss': loss.detach().cpu(), + 'output': output.detach().cpu(), + 'params': {name: param.detach().cpu().clone() for name, param in p_module2.named_parameters()} + }) + + # Offload params at the end of each step + p_module2.sleep() + + torch.distributed.barrier() + + # Compare results + for step in range(num_steps): + normal_result = results_normal[step] + offload_result = results_offload[step] + + # Compare loss + assert torch.equal(normal_result['loss'], offload_result['loss']), \ + f"Loss mismatch at step {step}: {normal_result['loss']} vs {offload_result['loss']}" + + # Compare output + assert torch.equal(normal_result['output'], offload_result['output']), \ + f"Output mismatch at step {step}" + + # Compare parameters + for param_name in normal_result['params']: + assert torch.equal(normal_result['params'][param_name], + offload_result['params'][param_name]), \ + f"Parameter {param_name} mismatch at step {step}" + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_offload_params_mem(): + launch_torchrun(2, _mem_worker) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_offload_params_correctness(): + launch_torchrun(2, _correctness_worker) diff --git a/tests/runtime/test_gnorm.py b/tests/runtime/test_gnorm.py index 80025906..1de3ee35 100644 --- a/tests/runtime/test_gnorm.py +++ b/tests/runtime/test_gnorm.py @@ -57,7 +57,7 @@ def cal_wnorm_cube(model: CubeModule): for p in model.parameters_for_optimizer(): p.grad = p.data # p.grad.copy_(p.data) - nreplicas2localparams = prepare_for_grad_clip(model, is_zero=CompileFlag.use_zero) + nreplicas2localparams = prepare_for_grad_clip(model, use_zero=CompileFlag.use_zero) wnorm, _ = clip_gnorm(nreplicas2localparams, None) # maps = {tid: [t.size() for t in ts] for tid, ts in nreplicas2localparams.items()} # print(f'cube nrepicas len: {maps}') diff --git a/tests/runtime/test_hybrid_optimizer.py b/tests/runtime/test_hybrid_optimizer.py new file mode 100644 index 00000000..b4238246 --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer.py @@ -0,0 +1,153 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path +import shutil + +import torch +import pytest +import torch.distributed + +from nnscaler.cli.trainer import Trainer +from tests.parallel_module.common import assert_close, assert_equal +from ..launch_torchrun import launch_torchrun + + +def param_clss_fn(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'layers.1.' in param_name or 'layers.10.' in param_name: + return 0, 0 + elif 'layers.2.' in param_name or 'layers.12.' in param_name: + return 0, 1 + else: + return 1, 0 + +_lr_history = [] +def on_train_step_start(trainer: 'Trainer', batches) -> None: + _lr_history.append(( + trainer.optimizer.optimizers[0].param_groups[0]['lr'], + trainer.optimizer.optimizers[0].param_groups[1]['lr'], + trainer.optimizer.optimizers[1].param_groups[0]['lr'], + )) + + +def trainer_worker(save_dir, use_zero): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('test_hybrid_optimizer_trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + _lr_history.clear() + + # train with a resume + ckpt0_savedir = save_dir / 'ckpt0' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + assert len(_lr_history) == 10 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--checkpoint.resume_from', 'last', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + + assert len(_lr_history) == 20 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + _lr_history.clear() + # train in one time + ckpt1_savedir = save_dir / 'ckpt1' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + assert len(_lr_history) == 20 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + + # train with different config + trainer_config = [ + '-f', config_path, + '--compute_config.plan_ngpus', '2', + '--pas_policy', 'tp', + '--max_train_steps', '30', + '--checkpoint.resume_from.checkpoint', 'last', + '--checkpoint.resume_from.with_merged', str(True), + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(1 - use_zero), + ] + trainer = Trainer(trainer_config) + trainer.run() + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + r = trainer._merge_checkpoint([ckpt0_savedir / 'last' / f'{i}.ckpt' for i in range(2)]) + # should success + assert r + + torch.distributed.barrier() + + from subprocess import check_call as _call + from functools import partial + call = partial(_call, shell=True) + + if torch.distributed.get_rank() == 0: + call(f"python -m nnscaler.cli.checkpoint distribute {ckpt1_savedir}/last {ckpt1_savedir}/sharded {' '.join(trainer_config)} --compute_config.runtime_ngpus {torch.distributed.get_world_size()}") + + torch.distributed.barrier() + + trainer = Trainer([ + '-f', config_path, + '--compute_config.plan_ngpus', '2', + '--pas_policy', 'tp', + '--max_train_steps', '30', + '--checkpoint.resume_from.checkpoint', f'{ckpt1_savedir}/sharded', + '--checkpoint.resume_from.with_merged', str(False), + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--compute_config.use_zero', str(1 - use_zero), + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('use_zero', [0, 1]) +def test_hybrid_optimizer(tmp_path, use_zero): + launch_torchrun(2, trainer_worker, tmp_path, use_zero) diff --git a/tests/runtime/test_hybrid_optimizer_trainer_args.yaml b/tests/runtime/test_hybrid_optimizer_trainer_args.yaml new file mode 100644 index 00000000..b84c4870 --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer_trainer_args.yaml @@ -0,0 +1,76 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_train_steps: 10 +enable_progress_bar: false +precision: bf16 +instance_name: p$(compute_config.plan_ngpus) + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: nnscaler.HybridOptimizer + param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn + args: + config: + optimizers: + - type: torch.optim.Adam + options: + lr: 0.02 + param_groups: + - options: + lr: 0.04 + - options: + lr: 0.06 + - type: torch.optim.AdamW + options: + lr: 0.04 + +lr_scheduler: + type: nnscaler.HybridLRScheduler + args: + config: + schedulers: + - type: torch.optim.lr_scheduler.ConstantLR + options: + factor: 0.5 + total_iters: 5 + - type: torch.optim.lr_scheduler.ConstantLR + options: + factor: 0.2 + total_iters: 5 + interval: step + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped + +hook: + on_train_step_start: tests.runtime.test_hybrid_optimizer.on_train_step_start diff --git a/tests/runtime/test_serialization.py b/tests/runtime/test_serialization.py new file mode 100644 index 00000000..aaee1eae --- /dev/null +++ b/tests/runtime/test_serialization.py @@ -0,0 +1,75 @@ +import torch +import pytest + +from nnscaler.runtime.serialization import load, save, convert +from nnscaler.cli.serialization import convert_format + +from tests.parallel_module.common import assert_equal + + +def test_normal(tmp_path): + a = torch.randn((2, 2), device='cpu') + b = torch.randn((2, 3), device='cpu') + c = torch.randn((4, 4), device='cpu') + tensors = { + "embedding": a, + "attention": b, + "fc": a, # shared tensor + "bias": {'inner': b, 'outer': {'deep': c}} + } + save(tensors, tmp_path / "model.safetensors") + loaded = load(tmp_path / "model.safetensors", lazy=False) + assert_equal(tensors, loaded) + convert(tmp_path / "model.safetensors", tmp_path / "model.pt") + convert_format( + src=str(tmp_path / "model.safetensors"), + dst=str(tmp_path / "model2.ckpt"), + ) + loaded_pt = torch.load(tmp_path / "model.pt") + assert_equal(tensors, loaded_pt) + loaded_pt2 = torch.load(tmp_path / "model2.ckpt") + assert_equal(tensors, loaded_pt2) + + +def test_shared_params(tmp_path): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 4) + self.fc2 = torch.nn.Linear(4, 4) + # share the same weight + self.fc2.weight = self.fc1.weight + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = Model() + save(model.state_dict(), tmp_path / "model.safetensors") + loaded = load(tmp_path / "model.safetensors", lazy=False) + assert_equal(model.state_dict(), loaded) + convert(tmp_path / "model.safetensors", tmp_path / "model.pt") + loaded_pt = torch.load(tmp_path / "model.pt") + assert_equal(model.state_dict(), loaded_pt) + + +def test_bad_shared_params(tmp_path): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 4) + self.fc2 = torch.nn.Linear(4, 4) + # share the same weight + # This case is not common, + # so we don't support it currently. + self.fc2.weight.data = self.fc1.weight.reshape(-1) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = Model() + with pytest.raises(RuntimeError): + save(model.state_dict(), tmp_path / "model.safetensors") diff --git a/tests/test_policies.py b/tests/test_policies.py index 4f1fa7b0..003adc0c 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -8,9 +8,13 @@ import torch import torch.nn as nn -from nnscaler.parallel import ComputeConfig, parallelize +from nnscaler.parallel import ComputeConfig, _load_parallel_module_class, parallelize +from nnscaler.policies import get_called_self_module_name, get_pas_ops +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.common import FFN, init_distributed +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode -from .utils import init_random +from .utils import init_random, replace_all_device_with MBS = 2 DIM = 16 @@ -58,3 +62,1048 @@ def test_autodist(): load_module=False ) assert m_new is None + + +def test_call_name(): + assert get_called_self_module_name('self.up_proj(x)') == 'up_proj' + assert get_called_self_module_name('self.act_fn(self.gate_proj(x))') == 'act_fn' + assert get_called_self_module_name('self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))') == 'down_proj' + assert get_called_self_module_name('torch.tanh(x)') == '' + assert get_called_self_module_name('x * y') == '' + assert get_called_self_module_name('self.up_proj(x).transpose()') == '' + + +class FnPolicyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = FFN(4, 8) + + def forward(self, x): + x = x * 2 + x = self.ffn(x) + x = x + 3 + return x + + +def megatron_ffn_policy(graph, cfg): + from nnscaler.ir import IRSubTensor + from nnscaler.policies import OpPlan, OpPartition + + for node in get_pas_ops(graph): + if FFN not in node.module_class_chain: # work on FFN module + continue + + if node.fn in [torch.tanh, torch.mul]: + yield OpPlan(node, partition=OpPartition(input=0, dim=1)) + continue + + assert node.fn == torch.nn.functional.linear + + input1: IRSubTensor = node.input(1) + if not input1.is_param(): # linear weight param + continue + + # we will partition gate_proj/up_proj with column parallelism (tp=ngpus) + # and partition down_proj with row parallelism (tp=ngpus) + + if input1.name.endswith('gate_proj.weight') or input1.name.endswith('up_proj.weight'): + # gate_proj/up_proj + # column parallelism + yield OpPlan(node, partition=OpPartition(input=1, dim=0)) + elif input1.name.endswith('down_proj.weight'): + # down_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + + +def megatron_ffn_policy_auto(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition + + linear_rank = 0 + for node in get_pas_ops(graph): + if FFN not in node.module_class_chain: # work on FFN module + continue + + if node.fn == torch.nn.functional.linear: + if linear_rank in [0, 1]: + # gate_proj/up_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=0)) + else: + assert linear_rank == 2 + # down_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + linear_rank += 1 + else: + # other ops + yield OpPlan(node, partition='auto') + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('policy', [megatron_ffn_policy, megatron_ffn_policy_auto]) +def test_codegen_fn(tmp_path, policy): + parallelize( + FnPolicyModule(), + {'x': torch.randn(2, 4)}, + policy, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicyModule, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + assert fullmap['ffn.gate_proj.weight'].shape == (8, 4) and fullmap['ffn.gate_proj.weight'].sub_shape == (4, 4) + assert fullmap['ffn.up_proj.weight'].shape == (8, 4) and fullmap['ffn.up_proj.weight'].sub_shape == (4, 4) + assert fullmap['ffn.down_proj.weight'].shape == (4, 8) and fullmap['ffn.down_proj.weight'].sub_shape == (4, 4) + + # will generate two communication ops + # one for ffn input + assert _gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + # one for ffn output + assert _gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.allreduce_identity') + + assert len(_gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.')) == 2 + + # Generated code of rank 0 should looks like: + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + + # self.register_parameter('ffn_gate_proj_weight_49', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_gate_proj_weight_49', 5, True, 'ffn.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_up_proj_weight_63', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_up_proj_weight_63', 11, True, 'ffn.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_down_proj_weight_77', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_down_proj_weight_77', 17, True, 'ffn.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def segment118(self, x_25): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1653, in forward, x = x * 2 + # mul_27 = torch.mul(x_25, 2) + # del x_25 + # mul_27 = nnscaler.runtime.adapter.nn.identity_allreduce(mul_27, ranks=[0, 1]) + # # created at IRAdapterGener:local_consumer_multiref + # mul_85, mul_89 = nnscaler.runtime.function.multiref(mul_27, times=2) + # del mul_27 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_51 = torch.nn.functional.linear(mul_85, self.ffn_gate_proj_weight_49, bias=None) + # del mul_85 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_59 = torch.tanh(linear_51) + # del linear_51 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_65 = torch.nn.functional.linear(mul_89, self.ffn_up_proj_weight_63, bias=None) + # del mul_89 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_73 = torch.mul(tanh_59, linear_1_65) + # del tanh_59, linear_1_65 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_79 = torch.nn.functional.linear(mul_1_73, self.ffn_down_proj_weight_77, bias=None) + # del mul_1_73 + # linear_2_35 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_2_79, ranks=[0, 1]) + # del linear_2_79 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1655, in forward, x = x + 3 + # add_26 = torch.add(linear_2_35, 3, alpha=1) + # del linear_2_35 + # return add_26 + + +class FFNDropout(torch.nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = torch.nn.Tanh() + self.dropout = torch.nn.Dropout(p=0.1) + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return self.dropout(down_proj) + + +class FnPolicyModuleList(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + + def forward(self, x): + x = x * 2 + for ffn in self.ffn: + x = ffn(x) + x = x + 3 + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + +def megatron_ffn_policy_list(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition, get_layer_index, get_called_self_module_name + + for node in get_pas_ops(graph): + if FFNDropout not in node.module_class_chain: # work on FFN module + continue + + ffn_idx = get_layer_index(node.fqn) + module_called = get_called_self_module_name(node.call_expr) + + if node.fn == torch.nn.functional.linear: + if module_called in ['gate_proj', 'up_proj']: + # gate_proj/up_proj + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition=OpPartition(input=1, dim=0)) + else: + # down_proj + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition=OpPartition(input=1, dim=1)) + else: + # other ops + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition='auto') + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline(tmp_path): + parallelize( + FnPolicyModuleList(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + 'pipeline_size': 2, + } + ), + gen_savedir=tmp_path, + load_module=False + ) + + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicyModuleList, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + tp_idx = rank // 2 + assert fullmap[f'ffn.{tp_idx}.gate_proj.weight'].shape == (8, 4) and fullmap[f'ffn.{tp_idx}.gate_proj.weight'].sub_shape == (4, 4) + assert fullmap[f'ffn.{tp_idx}.up_proj.weight'].shape == (8, 4) and fullmap[f'ffn.{tp_idx}.up_proj.weight'].sub_shape == (4, 4) + assert fullmap[f'ffn.{tp_idx}.down_proj.weight'].shape == (4, 8) and fullmap[f'ffn.{tp_idx}.down_proj.weight'].sub_shape == (4, 4) + + # will generate two communication ops + # one for ffn input + if tp_idx == 0: + assert not _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + else: + assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + # one for ffn output + assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.allreduce_identity') + + if tp_idx == 0: + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.')) == 1 + else: + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.')) == 2 + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, r'ckpt.checkpoint\(recompute')) == 1 + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, r'def recompute\(')) == 1 + + + # Generated code of rank 0 looks like: + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 0 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('ffn_0_gate_proj_weight_168', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_gate_proj_weight_168', 5, True, 'ffn.0.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_up_proj_weight_182', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_up_proj_weight_182', 11, True, 'ffn.0.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_down_proj_weight_196', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_down_proj_weight_196', 17, True, 'ffn.0.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def segment79(self, x_49): + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 243, in forward, x = x * 2 + # mul_51 = torch.mul(x_49, 2) + # del x_49 + # mul_51 = nnscaler.runtime.adapter.nn.identity_allreduce(mul_51, ranks=[0, 1]) + + # def recompute(mul_51): + # # created at IRAdapterGener:local_consumer_multiref + # mul_246, mul_250 = nnscaler.runtime.function.multiref(mul_51, times=2) + # del mul_51 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_170 = torch.nn.functional.linear(mul_246, self.ffn_0_gate_proj_weight_168, bias=None) + # del mul_246 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_178 = torch.tanh(linear_170) + # del linear_170 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_184 = torch.nn.functional.linear(mul_250, self.ffn_0_up_proj_weight_182, bias=None) + # del mul_250 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_192 = torch.mul(tanh_178, linear_1_184) + # del tanh_178, linear_1_184 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_198 = torch.nn.functional.linear(mul_1_192, self.ffn_0_down_proj_weight_196, bias=None) + # del mul_1_192 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_21 = self.training + # linear_2_59 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_2_198, ranks=[0, 1]) + # del linear_2_198 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_60 = torch. + # nn.functional.dropout(linear_2_59, p=0.1, training=ffn_0_dropout_training_21, inplace=False) + # del linear_2_59 + # return dropout_60 + + # dropout_60 = ckpt.checkpoint(recompute, mul_51, use_reentrant=False) + # return dropout_60 + + # def adapter196(self, dropout_60): + # dropout_236 = nnscaler.runtime.adapter.chunk(dropout_60, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(dropout_236, shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # return + + # def adapter207(self): + # gdropout_242 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # gdropout_85 = nnscaler.runtime.adapter.all_gather(gdropout_242, dim=1, ranks=[0, 1]) + # return gdropout_85 + + # def adapter160(self): + # sum_1_50 = nnscaler.runtime.adapter.move((), shape=(), dtype=torch.float32, src=2, dst=0) + # return sum_1_50 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_71): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # x_49 = next(*(dataloader_71, )) + # dropout_60 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_49, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_60, ), requires_grad=False) + # x_278 = next(*(dataloader_71, )) + # dropout_286 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_278, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_286, ), requires_grad=False) + # gdropout_85 = nnscaler.runtime.executor.aexecute(model.adapter207, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gx_73 = nnscaler.runtime.executor.backward('segment79', (x_49, ), (dropout_60, ), (gdropout_85, )) + # del x_49, dropout_60, gdropout_85, gx_73 + # gdropout_287 = nnscaler.runtime.executor.aexecute(model.adapter207, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gx_279 = nnscaler.runtime.executor.backward('segment79', (x_278, ), (dropout_286, ), (gdropout_287, )) + # del x_278, dropout_286, gdropout_287, gx_279 + # sum_1_50 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=True) + # sum_1_306 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=True) + + # def _infer_step(model, dataloader_71): + # _ = None + # x_49 = next(*(dataloader_71, )) + # dropout_60 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_49, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_60, ), requires_grad=False) + # x_278 = next(*(dataloader_71, )) + # dropout_286 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_278, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_286, ), requires_grad=False) + # sum_1_50 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=False) + # sum_1_306 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=False) + # return sum_1_50, sum_1_306 + assert True + + +class FnPolicyModuleSharedWeight(torch.nn.Module): + def __init__(self): + super().__init__() + self.input_projection = torch.nn.Linear(4, 4, bias=False) + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + self.output_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection.weight = self.input_projection.weight # share weight + + def forward(self, x): + x = self.input_projection(x) + for ffn in self.ffn: + x = ffn(x) + x = self.output_projection(x) + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline_shared_weight(tmp_path): + parallelize( + FnPolicyModuleSharedWeight(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + 'pipeline_size': 2, + } + ), + gen_savedir=tmp_path, + load_module=False + ) + for rank in range(2): + # the input projection is multiref'ed + assert _gencode_contains(tmp_path, FnPolicyModuleSharedWeight, rank, r'nnscaler.runtime.function.multiref\(self.input_projection') + + for rank in range(2, 4): + # receive shared weight projection via identity + assert _gencode_contains(tmp_path, FnPolicyModuleSharedWeight, rank, r'nnscaler.runtime.function.identity\(input_projection') + + # Generated code of rank 0 looks like: + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 1 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('input_projection_weight_55', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('input_projection_weight_55', 3, True, 'input_projection.weight', (4, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_gate_proj_weight_189', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_gate_proj_weight_189', 7, True, 'ffn.0.gate_proj.weight', (8, 4), (slice(4, 8, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_up_proj_weight_203', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_up_proj_weight_203', 13, True, 'ffn.0.up_proj.weight', (8, 4), (slice(4, 8, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_down_proj_weight_217', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_down_proj_weight_217', 19, True, 'ffn.0.down_proj.weight', (4, 8), (slice(0, 4, None), slice(4, 8, None)), 1) + # self._post_init(init_params, build_buckets) + + # def segment83(self, x_53): + # # shared param + # input_projection_weight_173, input_projection_weight_174 = nnscaler.runtime.function.multiref(self.input_projection_weight_55, times=2) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 441, in forward, x = self.input_projection(x) + # linear_56 = torch.nn.functional.linear(x_53, input_projection_weight_173, bias=None) + # del x_53, input_projection_weight_173 + # linear_56 = nnscaler.runtime.adapter.nn.identity_allreduce(linear_56, ranks=[0, 1]) + + # def recompute(linear_56): + # # created at IRAdapterGener:local_consumer_multiref + # linear_278, linear_282 = nnscaler.runtime.function.multiref(linear_56, times=2) + # del linear_56 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_191 = torch.nn.functional.linear(linear_278, self.ffn_0_gate_proj_weight_189, bias=None) + # del linear_278 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_199 = torch.tanh(linear_1_191) + # del linear_1_191 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_205 = torch.nn.functional.linear(linear_282, self.ffn_0_up_proj_weight_203, bias=None) + # del linear_282 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_213 = torch.mul(tanh_199, linear_2_205) + # del tanh_199, linear_2_205 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_3_219 = torch.nn.functional.linear(mul_213, self.ffn_0_down_proj_weight_217, bias=None) + # del mul_213 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_23 = self.training + # linear_3_64 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_3_219, ranks=[0, 1]) + # del linear_3_219 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_65 = torch.nn.functional.dropout(linear_3_64, p=0.1, training=ffn_0_dropout_training_23, inplace=False) + # del linear_3_64 + # return dropout_65 + + # dropout_65 = ckpt.checkpoint(recompute, linear_56, use_reentrant=False) + # return dropout_65, input_projection_weight_174 + + # def adapter190(self, input_projection_weight_174): + # input_projection_weight_257 = nnscaler.runtime.adapter.chunk(input_projection_weight_174, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(input_projection_weight_257, shape=(4, 2), dtype=torch.float32, src=1, dst=3) + # return + + # def adapter234(self, dropout_65): + # dropout_265 = nnscaler.runtime.adapter.chunk(dropout_65, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(dropout_265, shape=(4, 2), dtype=torch.float32, src=1, dst=3) + # return + + # def adapter245(self): + # gdropout_267 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=3, dst=1) + # gdropout_92 = nnscaler.runtime.adapter.all_gather(gdropout_267, dim=1, ranks=[0, 1]) + # return gdropout_92 + + # def adapter201(self): + # ginput_projection_weight_263 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=3, dst=1) + # ginput_projection_weight_177 = nnscaler.runtime.adapter.all_gather(ginput_projection_weight_263, dim=1, ranks=[0, 1]) + # return ginput_projection_weight_177 + + # def adapter214(self): + # sum_1_54 = nnscaler.runtime.adapter.move((), shape=(), dtype=torch.float32, src=3, dst=1) + # return sum_1_54 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_76): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # x_53 = next(*(dataloader_76, )) + # dropout_65, input_projection_weight_174 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_53, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_174, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_65, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # dropout_310, input_projection_weight_314 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_302, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_314, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_310, ), requires_grad=False) + # gdropout_92 = nnscaler.runtime.executor.aexecute(model.adapter245, *(), requires_grad=False) + # ginput_projection_weight_177 = nnscaler.runtime.executor.aexecute(model.adapter201, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gx_78 = nnscaler.runtime.executor.backward('segment83', (x_53, ), (dropout_65, input_projection_weight_174, ), (gdropout_92, ginput_projection_weight_177, )) + # del x_53, dropout_65, input_projection_weight_174, gdropout_92, ginput_projection_weight_177, gx_78 + # gdropout_311 = nnscaler.runtime.executor.aexecute(model.adapter245, *(), requires_grad=False) + # ginput_projection_weight_315 = nnscaler.runtime.executor.aexecute(model.adapter201, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gx_303 = nnscaler.runtime.executor.backward('segment83', (x_302, ), (dropout_310, input_projection_weight_314, ), (gdropout_311, ginput_projection_weight_315, )) + # del x_302, dropout_310, input_projection_weight_314, gdropout_311, ginput_projection_weight_315, gx_303 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=True) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=True) + # return sum_1_54, sum_1_349 + + # def _infer_step(model, dataloader_76): + # _ = None + # x_53 = next(*(dataloader_76, )) + # dropout_65, input_projection_weight_174 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_53, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_174, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_65, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # dropout_310, input_projection_weight_314 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_302, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_314, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_310, ), requires_grad=False) + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=False) + # return sum_1_54, sum_1_349 + + # Generated code of rank 2 looks like: + + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 2 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('ffn_1_gate_proj_weight_222', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_gate_proj_weight_222', 26, True, 'ffn.1.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_1_up_proj_weight_236', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_up_proj_weight_236', 32, True, 'ffn.1.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_1_down_proj_weight_250', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_down_proj_weight_250', 38, True, 'ffn.1.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def adapter190(self): + # input_projection_weight_256 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # input_projection_weight_174 = nnscaler.runtime.adapter.all_gather(input_projection_weight_256, dim=1, ranks=[2, 3]) + # return input_projection_weight_174 + + # def adapter234(self): + # dropout_264 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # dropout_65 = nnscaler.runtime.adapter.all_gather(dropout_264, dim=1, ranks=[2, 3]) + # return dropout_65 + + # def segment93(self, dropout_65, input_projection_weight_174): + # input_projection_weight_184 = nnscaler.runtime.function.identity(input_projection_weight_174) + # del input_projection_weight_174 + # dropout_180 = nnscaler.runtime.function.identity(dropout_65) + # del dropout_65 + # dropout_180 = nnscaler.runtime.adapter.nn.identity_allreduce(dropout_180, ranks=[2, 3]) + + # def recompute(dropout_180): + # # created at IRAdapterGener:local_consumer_multiref + # dropout_286, dropout_290 = nnscaler.runtime.function.multiref(dropout_180, times=2) + # del dropout_180 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_4_224 = torch.nn.functional.linear(dropout_286, self.ffn_1_gate_proj_weight_222, bias=None) + # del dropout_286 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_1_232 = torch.tanh(linear_4_224) + # del linear_4_224 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_5_238 = torch.nn.functional.linear(dropout_290, self.ffn_1_up_proj_weight_236, bias=None) + # del dropout_290 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_246 = torch.mul(tanh_1_232, linear_5_238) + # del tanh_1_232, linear_5_238 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_6_252 = torch.nn.functional.linear(mul_1_246, self.ffn_1_down_proj_weight_250, bias=None) + # del mul_1_246 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_1_dropout_training_42 = self.training + # linear_6_73 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_6_252, ranks=[2, 3]) + # del linear_6_252 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_1_74 = torch.nn.functional.dropout(linear_6_73, p=0.1, training=ffn_1_dropout_training_42, inplace=False) + # del linear_6_73 + # return dropout_1_74 + + # dropout_1_74 = ckpt.checkpoint(recompute, dropout_180, use_reentrant=False) + # del dropout_180 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 444, in forward, x = self.output_projection(x) + # linear_7_75 = torch.nn.functional.linear(dropout_1_74, input_projection_weight_184, bias=None) + # del input_projection_weight_184, dropout_1_74 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 445, in forward, return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + # sum_1_54 = torch.sum(linear_7_75) + # del linear_7_75 + # return sum_1_54 + + # def adapter245(self, gdropout_92): + # gdropout_266 = nnscaler.runtime.adapter.chunk(gdropout_92, dim=1, ranks=[2, 3]) + # _ = nnscaler.runtime.adapter.move(gdropout_266, shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # return + + # def adapter201(self, ginput_projection_weight_177): + # ginput_projection_weight_262 = nnscaler.runtime.adapter.chunk(ginput_projection_weight_177, dim=1, ranks=[2, 3]) + # _ = nnscaler.runtime.adapter.move(ginput_projection_weight_262, shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # return + + # def adapter214(self, sum_1_54): + # _ = nnscaler.runtime.adapter.move(sum_1_54, shape=(), dtype=torch.float32, src=2, dst=0) + # return sum_1_54 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_76): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # input_projection_weight_174 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=True) + # dropout_65 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=True) + # sum_1_54 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_65, input_projection_weight_174, ), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gdropout_92, ginput_projection_weight_177 = nnscaler.runtime.executor.backward('segment93', (dropout_65, input_projection_weight_174, ), (sum_1_54, ), (None, )) + # sum_1_54 = sum_1_54.detach() + # input_projection_weight_314 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=True) + # dropout_310 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter245, *(gdropout_92, ), requires_grad=False) + # del dropout_65, gdropout_92 + # _ = nnscaler.runtime.executor.aexecute(model.adapter201, *(ginput_projection_weight_177, ), requires_grad=False) + # del input_projection_weight_174, ginput_projection_weight_177 + # sum_1_349 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_310, input_projection_weight_314, ), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gdropout_311, ginput_projection_weight_315 = nnscaler.runtime.executor.backward('segment93', (dropout_310, input_projection_weight_314, ), (sum_1_349, ), (None, )) + # sum_1_349 = sum_1_349.detach() + # _ = nnscaler.runtime.executor.aexecute(model.adapter245, *(gdropout_311, ), requires_grad=False) + # del dropout_310, gdropout_311 + # _ = nnscaler.runtime.executor.aexecute(model.adapter201, *(ginput_projection_weight_315, ), requires_grad=False) + # del input_projection_weight_314, ginput_projection_weight_315 + # x_302 = next(*(dataloader_76, )) + # del x_302 + # x_53 = next(*(dataloader_76, )) + # del x_53 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_54, ), requires_grad=True) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_349, ), requires_grad=True) + # return sum_1_54, sum_1_349 + + # def _infer_step(model, dataloader_76): + # _ = None + # input_projection_weight_174 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=False) + # dropout_65 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=False) + # sum_1_54 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_65, input_projection_weight_174, ), requires_grad=False) + # input_projection_weight_314 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=False) + # dropout_310 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_310, input_projection_weight_314, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # del x_302 + # x_53 = next(*(dataloader_76, )) + # del x_53 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_54, ), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_349, ), requires_grad=False) + # return sum_1_54, sum_1_349 + + +class FnPolicySharedWeightModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.input_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection.weight = self.input_projection.weight # share weight + + def forward(self, x): + x = self.input_projection(x) + x = self.output_projection(x) + return x + + +def shared_weight_different_partition_policy(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition, get_called_self_module_name + + for node in get_pas_ops(graph): + module_called = get_called_self_module_name(node.call_expr) + + if node.fn == torch.nn.functional.linear and module_called == 'output_projection': + # input_projection.weight is used two times with different partition + # x = self.input_projection(x) --> no partition + # x = self.output_projection(x) --> partition dim=1 + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + + +@replace_all_device_with('cpu') +def test_codegen_fn_shared_weight(tmp_path): + parallelize( + FnPolicySharedWeightModule(), + {'x': torch.randn(4, 4)}, + # 'pp', + shared_weight_different_partition_policy, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicySharedWeightModule, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + # the input projection is multiref'ed + assert _gencode_contains(tmp_path, FnPolicySharedWeightModule, rank, r'nnscaler.runtime.function.multiref\(self.input_projection') + # input_projection.weight will not be splitted + # because it is multiref'ed + assert fullmap['input_projection.weight'].shape == (4, 4) and fullmap['input_projection.weight'].sub_shape == (4, 4) + + # Generated code of rank 0 looks like: + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 2]) + # self.init_group(ranks=[1, 3]) + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('input_projection_weight_15', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('input_projection_weight_15', 3, True, 'input_projection.weight', (4, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.wreducer80 = nnscaler.runtime.adapter.Reducer(ranks=[0, 2], reduce_op='sum', async_op=async_op, zero=False, max_bucket_size_bytes=max_bucket_size_bytes, zero_use_reduce_scatter=zero_use_reduce_scatter, zero_ngroups=1) + # self.wreducer80.add_param(self.input_projection_weight_15) + # self.add_reducer(self.wreducer80) + + # self._post_init(init_params, build_buckets) + + # def segment76(self, x_13): + # # shared param + # input_projection_weight_32, input_projection_weight_33 = nnscaler.runtime.function.multiref(self.input_projection_weight_15, times=2) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 763, in forward, x = self.input_projection(x) + # linear_16 = torch.nn.functional.linear(x_13, input_projection_weight_32, bias=None) + # del x_13, input_projection_weight_32 + # linear_22 = nnscaler.runtime.adapter.nn.split_allgather(linear_16, dim=1, ranks=[0, 1]) + # del linear_16 + # input_projection_weight_37 = nnscaler.runtime.adapter.nn.split_allgather(input_projection_weight_33, dim=1, ranks=[0, 1]) + # del input_projection_weight_33 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 764, in forward, x = self.output_projection(x) + # linear_1_26 = torch.nn.functional.linear(linear_22, input_projection_weight_37, bias=None) + # del linear_22, input_projection_weight_37 + # linear_1_14 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_1_26, ranks=[0, 1]) + # del linear_1_26 + # return linear_1_14 + + +class FnPolicyModuleList2(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + + def forward(self, x): + x = x * 2 + for ffn in self.ffn: + x = ffn(x) + x = x + 3 + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline2(tmp_path): + parallelize( + FnPolicyModuleList2(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + # 4 stages, with pp=2 + 'pipeline_size': 2, + 'pipeline_scheduler': '1f1b_interleaved', + } + ), + gen_savedir=tmp_path, + load_module=False + ) + # should successfully generate code without error + assert True + + +class HookModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + a, b = x.size()[:2] + r = torch.randn(a * 2, b) + r = r.chunk(2, dim=0)[0] + return self.linear(x) + r + + + +def hello(module, meta, *args, **kwargs): + print(f'hello: {meta}') + + +def policy_hook(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition + # add hook to all ops + def _hook(op_plan: OpPlan): + op_plan.pre_hook = hello + op_plan.post_hook = hello + op_plan.hook_meta = op_plan.op.name + return op_plan + + for node in get_pas_ops(graph): + if node.fn == torch.nn.functional.linear: + yield _hook(OpPlan(node, partition=OpPartition(input=1, dim=0))) + else: + yield _hook(OpPlan(node)) + + +@replace_all_device_with('cpu') +def test_codegen_fn_with_hook(tmp_path): + parallelize( + HookModule(), + {'x': torch.randn(4, 4)}, + policy_hook, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + # should successfully generate code without error + # and hooks are inserted + for rank in range(4): + assert _gencode_contains(tmp_path, HookModule, rank, r'tests.test_policies.hello\(self,') + + # Generated code of rank 0 looks like: + # def segment64(self, x_32): + # x_32 = nnscaler.runtime.adapter.nn.identity_allreduce(x_32, ranks=[0, 1]) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 883, in forward, a, b = x.size()[:2] + # tests.test_policies.hello(self, 'size', (x_32, ), dict()) + # im_output_63 = torch.Tensor.size(x_32) + # tests.test_policies.hello(self, 'size', (x_32, ), dict(), im_output_63) + # size_26 = im_output_63[0] + # size_27 = im_output_63[1] + # del im_output_63 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 884, in forward, r = torch.randn(a * 2, b) + # tests.test_policies.hello(self, 'mul', (size_26, 2), dict()) + # mul_28 = _operator.mul(size_26, 2) + # tests.test_policies.hello(self, 'mul', (size_26, 2), dict(), mul_28) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 884, in forward, r = torch.randn(a * 2, b) + # tests.test_policies.hello(self, 'randn', (), dict(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False)) + # randn_34 = nnscaler.runtime.function.randn(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False) + # tests.test_policies.hello(self, 'randn', (), dict(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False), randn_34) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 885, in forward, r = r.chunk(2, dim=0)[0] + # tests.test_policies.hello(self, 'chunk', (randn_34, ), dict(chunks=2, dim=0)) + # chunk_35, chunk_36 = torch.chunk(randn_34, chunks=2, dim=0) + # tests.test_policies.hello(self, 'chunk', (randn_34, ), dict(chunks=2, dim=0), (chunk_35, chunk_36)) + # del randn_34, chunk_36 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 886, in forward, return self.linear(x) + r + # tests.test_policies.hello(self, 'linear', (x_32, self.linear_weight_45, self.linear_bias_47), dict()) + # linear_49 = torch.nn.functional.linear(x_32, self.linear_weight_45, self.linear_bias_47) + # tests.test_policies.hello(self, 'linear', (x_32, self.linear_weight_45, self.linear_bias_47), dict(), linear_49) + # del x_32 + # linear_39 = nnscaler.runtime.adapter.nn.allgather_split(linear_49, dim=1, ranks=[0, 1]) + # del linear_49 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 886, in forward, return self.linear(x) + r + # tests.test_policies.hello(self, 'add', (linear_39, chunk_35), dict(alpha=1)) + # add_33 = torch.add(linear_39, chunk_35, alpha=1) + # tests.test_policies.hello(self, 'add', (linear_39, chunk_35), dict(alpha=1), add_33) + # del chunk_35, linear_39 + # return add_33 + + +def _gencode_unused_args_worker(tempdir): + init_distributed() + m_new = parallelize( + HookModule(), + {'x': torch.randn(4, 4)}, + policy_hook, + ComputeConfig(2, 2), + gen_savedir=tempdir, + load_module=True + ) + assert m_new is not None + m_new(torch.randn(4, 4)) + # should successfully run without error + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_run_codegen_fn_with_hook(): + """ + Verify the generated code can run correctly. + """ + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(2, _gencode_unused_args_worker, tempdir) + + +@replace_all_device_with('cpu') +def test_codegen_fsdp(tmp_path): + parallelize( + FnPolicyModuleList(), + {'x': torch.randn(4, 4)}, + 'fsdp', + ComputeConfig( + 1, 2, + use_end2end=True, + use_zero=3, + pas_config={ + 'recomputes': [FFNDropout], + } + ), + gen_savedir=tmp_path, + load_module=False + ) + # code should look like: + # def segment105_impl(self, x_49): + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 239, in forward, x = x * 2 + # mul_51 = torch.mul(x_49, 2) + # del x_49 + + # def recompute(mul_51): + # # created at IRAdapterGener:local_consumer_multiref + # mul_100, mul_104 = nnscaler.runtime.function.multiref(mul_51, times=2) + # del mul_51 + # self.prefetch_param(self.ffn_0_gate_proj_weight_52) + # mul_100 = self.backward_postevict_param(mul_100, self.ffn_0_gate_proj_weight_52, 1) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_53 = torch.nn.functional.linear(mul_100, self.ffn_0_gate_proj_weight_52, bias=None) + # self.postevict_param(self.ffn_0_gate_proj_weight_52) + # linear_53 = self.backward_prefetch_param(linear_53, self.ffn_0_gate_proj_weight_52, 1) + # del mul_100 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_54 = torch.tanh(linear_53) + # del linear_53 + # self.prefetch_param(self.ffn_0_up_proj_weight_55) + # mul_104 = self.backward_postevict_param(mul_104, self.ffn_0_up_proj_weight_55, 3) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_56 = torch.nn.functional.linear(mul_104, self.ffn_0_up_proj_weight_55, bias=None) + # self.postevict_param(self.ffn_0_up_proj_weight_55) + # linear_1_56 = self.backward_prefetch_param(linear_1_56, self.ffn_0_up_proj_weight_55, 3) + # del mul_104 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_57 = torch.mul(tanh_54, linear_1_56) + # del tanh_54, linear_1_56 + # self.prefetch_param(self.ffn_0_down_proj_weight_58) + # mul_1_57 = self.backward_postevict_param(mul_1_57, self.ffn_0_down_proj_weight_58, 5) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_59 = torch.nn.functional.linear(mul_1_57, self.ffn_0_down_proj_weight_58, bias=None) + # self.postevict_param(self.ffn_0_down_proj_weight_58) + # linear_2_59 = self.backward_prefetch_param(linear_2_59, self.ffn_0_down_proj_weight_58, 5) + # del mul_1_57 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_21 = self.training + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # dropout_60 = torch.nn.functional.dropout(linear_2_59, p=0.1, training=ffn_0_dropout_training_21, inplace=False) + # del linear_2_59 + # return dropout_60 + + # dropout_60 = ckpt.checkpoint(recompute, mul_51, use_reentrant=False) + # del mul_51 + + # def recompute(dropout_60): + # # created at IRAdapterGener:local_consumer_multiref + # dropout_108, dropout_112 = nnscaler.runtime.function.multiref(dropout_60, times=2) + # del dropout_60 + # self.prefetch_param(self.ffn_1_gate_proj_weight_61) + # dropout_108 = self.backward_postevict_param(dropout_108, self.ffn_1_gate_proj_weight_61, 1) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_3_62 = torch.nn.functional.linear(dropout_108, self.ffn_1_gate_proj_weight_61, bias=None) + # self.postevict_param(self.ffn_1_gate_proj_weight_61) + # linear_3_62 = self.backward_prefetch_param(linear_3_62, self.ffn_1_gate_proj_weight_61, 1) + # del dropout_108 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_1_63 = torch.tanh(linear_3_62) + # del linear_3_62 + # self.prefetch_param(self.ffn_1_up_proj_weight_64) + # dropout_112 = self.backward_postevict_param(dropout_112, self.ffn_1_up_proj_weight_64, 3) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_4_65 = torch.nn.functional.linear(dropout_112, self.ffn_1_up_proj_weight_64, bias=None) + # self.postevict_param(self.ffn_1_up_proj_weight_64) + # linear_4_65 = self.backward_prefetch_param(linear_4_65, self.ffn_1_up_proj_weight_64, 3) + # del dropout_112 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_2_66 = torch.mul(tanh_1_63, linear_4_65) + # del tanh_1_63, linear_4_65 + # self.prefetch_param(self.ffn_1_down_proj_weight_67) + # mul_2_66 = self.backward_postevict_param(mul_2_66, self.ffn_1_down_proj_weight_67, 5) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_5_68 = torch.nn.functional.linear(mul_2_66, self.ffn_1_down_proj_weight_67, bias=None) + # self.postevict_param(self.ffn_1_down_proj_weight_67) + # linear_5_68 = self.backward_prefetch_param(linear_5_68, self.ffn_1_down_proj_weight_67, 5) + # del mul_2_66 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # ffn_1_dropout_training_40 = self.training + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # dropout_1_69 = torch.nn.functional.dropout(linear_5_68, p=0.1, training=ffn_1_dropout_training_40, inplace=False) + # del linear_5_68 + # return dropout_1_69 + + # dropout_1_69 = ckpt.checkpoint(recompute, dropout_60, use_reentrant=False) + # del dropout_60 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 242, in forward, x = x + 3 + # add_70 = torch.add(dropout_1_69, 3, alpha=1) + # del dropout_1_69 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 243, in forward, return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + # sum_1_50 = torch.sum(add_70) + # del add_70 + # return sum_1_50 + + # def segment105(self, x_49): + # with self.save_params_hooks(): + # return self.segment105_impl(x_49) + assert True diff --git a/tests/test_utils.py b/tests/test_utils.py index 7fa7d80a..864cab67 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from collections import OrderedDict from dataclasses import dataclass import pytest +import torch -from nnscaler.utils import select_many, classproperty, fields +from nnscaler.utils import ( + select_many, classproperty, fields, set_member_by_name, unchecked_fields, + transform_recursively, +) def test_select_many(): @@ -53,3 +58,123 @@ class A: assert fields(A).y == 'y' with pytest.raises(AttributeError): fields(A).z + + assert unchecked_fields(A).x == 'x' + assert unchecked_fields(A).y == 'y' + assert unchecked_fields(A).z == 'z' + + a = A(x=0, y=0) + assert unchecked_fields(a).x == 'x' + assert unchecked_fields(a).y == 'y' + assert unchecked_fields(a).z == 'z' + + class B: + def __init__(self): + self.a = A(x=1, y=2) + + assert unchecked_fields(B).x == 'x' + b = B() + assert unchecked_fields(b).x == 'x' + assert unchecked_fields(b.a).x == 'x' + + +def test_set_member_by_name(): + model = torch.nn.Module() + set_member_by_name(model, "x", 42) + assert model.x == 42 + with pytest.raises(AttributeError): + set_member_by_name(model, 'x.y.z', 43) + + set_member_by_name(model, 'a.b.c', 44) + assert model.a.b.c == 44 + + model = torch.nn.Module() + child_module = torch.nn.Module() + set_member_by_name(model, "x.y", child_module) + assert model.x.y == child_module + + set_member_by_name(model, 'x.y.z', 45) + assert model.x.y == child_module + assert model.x.y.z == 45 + + +def test_transform_recursively(): + data = { + 'a': torch.tensor([1]), + 'b': [torch.tensor(4), {'c': torch.tensor([5])}], + 'd': (7, torch.tensor(8)), + 'e': {1: 9, 2: torch.tensor(10)}.keys(), + 'f': {1: 9, 2: torch.tensor(11)}.items(), + 'g': {1: 9, 2: torch.tensor(12)}.values(), + 'h': {1: 9, 2: torch.tensor(13)}, + 'i': slice(0, 10, None), + 'j': torch.Size([11, 12]), + 'k': OrderedDict({1: 9, 2: 10}), + 'l': {1: 9, 2: 10}.values(), + 'm': [1, 2, 3], + 'n': slice(0, 10, torch.tensor(2)), + 'o': {torch.tensor(1): 9, torch.tensor(2): 10}, + 'p': {torch.tensor(1): 9, torch.tensor(2): 10}.items(), + 'q': {torch.tensor(1): 9, torch.tensor(2): 10}.keys() + } + + def fn(x): + if isinstance(x, torch.Tensor): + return x.item() + return x + + result1 = transform_recursively( + data, fn, + target_types=torch.Tensor, + collection_types=None, + skip_dict_keys=True, + ) + + result2 = transform_recursively( + data, fn, + target_types=torch.Tensor, + collection_types=None, + skip_dict_keys=False, + ) + target = { + 'a': 1, + 'b': [4, {'c': 5}], + 'd': (7, 8), + 'e': {1: 1, 2: 2}.keys(), + 'f': dict([(1, 9), (2, 11)]).items(), + 'g': {1: 9, 2: 12}.values(), + 'h': {1: 9, 2: 13}, + 'i': slice(0, 10, None), + 'j': torch.Size([11, 12]), + 'k': OrderedDict({1: 9, 2: 10}), + 'l': data['l'], + 'm': [1, 2, 3], + 'n': slice(0, 10, 2), + } + # dict values are not comparable. + assert list(target['g']) == list(result1.pop('g')) + assert list(target['g']) == list(result2.pop('g')) + target.pop('g') + + + skip_key_target = { + **target, + 'o': {torch.tensor(1): 9, torch.tensor(2): 10}, + 'p': {torch.tensor(1): 9, torch.tensor(2): 10}.items(), + 'q': {1: 9, 2: 10}.keys() + } + noskip_key_target = { + **target, + 'o': {1: 9, 2: 10}, + 'p': dict([(1, 9), (2, 10)]).items(), + 'q': {1: 9, 2: 10}.keys() + } + + from tests.parallel_module.common import assert_equal + + assert_equal(list(skip_key_target.pop('o')), list(result1.pop('o'))) + assert_equal(list(skip_key_target.pop('p')), list(result1.pop('p'))) + assert_equal(list(skip_key_target.pop('q')), list(result1.pop('q'))) + + assert_equal(result1, skip_key_target) + assert_equal(result2, noskip_key_target) diff --git a/tests/utils.py b/tests/utils.py index 22036f42..d29e4168 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -58,7 +58,7 @@ def init_random(seed: int = 1): torch.cuda.manual_seed(seed) -def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-4) -> bool: +def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-3) -> bool: """Compare the output of baseline_fn and compile_fn Error will raise if the output of two functions are not the same. @@ -92,10 +92,10 @@ def assert_same_complex(gt, out): assert_same_complex(gt[key], out[key]) elif isinstance(gt, torch.Tensor): assert isinstance(out, torch.Tensor) - assert torch.allclose(gt, out, atol=atol), f'mismatched: {gt} != {out}' + assert torch.allclose(gt, out, atol=atol), f'mismatched (with atol {atol}): {gt} != {out}' elif isinstance(gt, float): assert isinstance(out, float) - assert math.isclose(gt, out, abs_tol=atol), f'mismatched: {gt} != {out}' + assert math.isclose(gt, out, abs_tol=atol), f'mismatched (with atol {atol}): {gt} != {out}' else: assert gt == out, f'mismatched: {gt} != {out}' assert_same_complex(baseline_outputs, compile_outputs) @@ -114,6 +114,7 @@ def replace_all_device_with(device='cpu', force=False): orig_to = torch.Tensor.to orig_cuda = torch.Tensor.cuda orig_cpu = torch.Tensor.cpu + orig_is_cuda = torch.Tensor.is_cuda def patch_tensor_constructor(fn): orig_func = getattr(fn, '__cube_orig_func__', fn) # to support nested patching @@ -158,6 +159,8 @@ def wrapper(*args, **kwargs): } def patched_to(self, *args, **kwargs): + if device == 'meta': + return self if len(args) > 0 and isinstance(args[0], (torch.device, str)): return orig_to(self, device, *args[1:], **kwargs) if 'device' in kwargs: @@ -166,15 +169,20 @@ def patched_to(self, *args, **kwargs): return orig_to(self, *args, **kwargs) def patched_cuda(self, *args, **kwargs): + if device == 'meta': + return self return orig_to(self, device) def patched_cpu(self, *args, **kwargs): + if device == 'meta': + return self return orig_to(self, device) try: torch.Tensor.to = patched_to torch.Tensor.cuda = patched_cuda torch.Tensor.cpu = patched_cpu + torch.Tensor.is_cuda = property(lambda self: True) # patch tensor constructors for tf_name, fn in old_tensor_constructors.items(): setattr(torch, tf_name, patched_tensor_constructors[tf_name]) @@ -205,6 +213,7 @@ def patched_cpu(self, *args, **kwargs): torch.Tensor.to = orig_to torch.Tensor.cuda = orig_cuda torch.Tensor.cpu = orig_cpu + torch.Tensor.is_cuda = orig_is_cuda # mock process group is from pytorch testing code diff --git a/utility/aggregate.sh b/utility/aggregate.sh new file mode 100644 index 00000000..19185564 --- /dev/null +++ b/utility/aggregate.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# gather the folder to all workers to node-0 under the same workspace + +set -ex + +WORKSPACE=/workspace +FOLDER=MagicCube + +WORKER_PREFIX=node- +WORKER_NUM=2 + +for ((i=1; i<${WORKER_NUM}; i++)); do + WORKER=${WORKER_PREFIX}${i} + scp -r ${WORKER}:${WORKSPACE}/${FOLDER} ${WORKSPACE}/${FOLDER}-${WORKER} +done diff --git a/utility/broadcast.sh b/utility/broadcast.sh new file mode 100644 index 00000000..dbb77c7a --- /dev/null +++ b/utility/broadcast.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# broadcast the folder to all workers under the same workspace + +set -ex + +WORKSPACE=/workspace +FOLDER=MagicCube + +WORKER_PREFIX=node- +WORKER_NUM=2 + +for ((i=1; i<=${WORKER_NUM}; i++)); do + WORKER=${WORKER_PREFIX}${i} + scp -r ${WORKSPACE}/${SYNC_FOLDER} ${WORKER}:${WORKSPACE} +done diff --git a/utility/comm_profile.py b/utility/comm_profile.py new file mode 100644 index 00000000..5f767c24 --- /dev/null +++ b/utility/comm_profile.py @@ -0,0 +1,108 @@ +import argparse +import json +import torch +from pathlib import Path +import os +from typing import Tuple, List, Dict + +import nnscaler +from nnscaler.runtime.adapter.collectives import all_gather, all_reduce, all_to_all, reduce_scatter +from nnscaler.profiler import CudaTimer +from nnscaler.runtime.device import DeviceGroup +from nnscaler.autodist.util import get_node_arch, get_default_profile_path + + +class CommProfiler: + + def __init__(self, + nranks: int, + warmup_times: int = 10, + profile_times: int = 10) -> None: + self.nranks = nranks + self.warmup_times = warmup_times + self.profile_times = profile_times + self.ranks = tuple(range(self.nranks)) + + def collect_profile_info(self, + primitive: str) -> Tuple[List[float], List[float]]: + + b_size = 16 + sequence_len = 16 + quarter_mb_size_list = [ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 + ] + model_dim_list = [ + mem * 256 * 256 // b_size // sequence_len + for mem in quarter_mb_size_list + ] + sizes_in_mb = [0.25 * val for val in quarter_mb_size_list] + times_in_s = [] + for cur_sz, d_size in zip(sizes_in_mb, model_dim_list): + assert d_size % self.nranks == 0 + if primitive in ['all gather', 'all to all']: + d_size = d_size // self.nranks + tensor = torch.rand([b_size, sequence_len, d_size], + dtype=torch.float32, + device=torch.cuda.current_device()) + if primitive == 'all gather': + func = all_gather + kwargs = {'tensor': tensor, 'dim': 2, 'ranks': self.ranks} + elif primitive == 'all reduce': + func = all_reduce + kwargs = {'tensor': tensor, 'ranks': self.ranks} + elif primitive == 'reduce scatter': + func = reduce_scatter + kwargs = {'tensor': tensor, 'dim': 2, 'ranks': self.ranks} + elif primitive == 'all to all': + func = all_to_all + kwargs = { + 'tensor': tensor, + 'idim': 0, + 'odim': 2, + 'ranks': self.ranks + } + else: + raise ValueError('Unknown primitive: {}'.format(primitive)) + for _ in range(self.warmup_times): + func(**kwargs) + CudaTimer().clear() + for _ in range(self.profile_times): + otensor = func(**kwargs) + cur_t = CudaTimer().instance.field_data['comm'] / self.profile_times + times_in_s.append(cur_t) + return sizes_in_mb, times_in_s + + def profile(self) -> Dict[str, Tuple[List[float], List[float]]]: + profile_info = {} + for primitive in [ + 'all gather', 'all reduce', 'reduce scatter', 'all to all' + ]: + profile_info[primitive] = self.collect_profile_info( + primitive=primitive) + return profile_info + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + description='Profile runtime communication cost') + parser.add_argument('--comm_profile_dir', + type=str, + default=get_default_profile_path() / get_node_arch() / 'comm', + help='autodist comm profile folder') + args = parser.parse_args() + + nnscaler.init() + + CudaTimer(enable=True, predefined=True) + world_size = DeviceGroup().world_size + comm_profiler = CommProfiler(nranks=world_size) + + profile_info = comm_profiler.profile() + + if torch.distributed.get_rank() == 0: + dir_path = Path(args.comm_profile_dir) + if not dir_path.exists(): + dir_path.mkdir(parents=True, exist_ok=True) + file_name = dir_path / f'intra_{world_size}.json' + with open(file_name, 'w') as f: + json.dump(profile_info, f, indent=2) diff --git a/utility/dgx1_reorder_gpu.py b/utility/dgx1_reorder_gpu.py new file mode 100644 index 00000000..aa312587 --- /dev/null +++ b/utility/dgx1_reorder_gpu.py @@ -0,0 +1,119 @@ +""" +Reorder GPU index by finding DGX-1 topology Find dgx topology + +┌───────────┐ +1 = 0 = 4 = 5 +‖ x | | x ‖ +2 = 3 = 7 = 6 +└───────────┘ + +""" +from typing import List +import subprocess +import numpy as np + +_kConnType = { + "NV1": 1, + "NV2": 2, + "NODE": 3, + "X": -1, +} + +_kConnTypeStr = {val: key for key, val in _kConnType.items()} + + + +def get_topology(): + cmds = [ + 'nvidia-smi', + 'topo', + '-m', + ] + + proc = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = proc.communicate() + outputs = stdout.decode('utf-8').split('\n') + + outputs = [out for out in outputs if out.startswith('GPU')] + ngpus = len(outputs) + print(f'Detected GPU number: {ngpus}') + + topology = np.empty((ngpus, ngpus), dtype=int) + for src, output in enumerate(outputs): + connections = output.split('\t')[1:1+ngpus] + for dst, link in enumerate(connections): + link = link.replace(" ", "") + assert link in _kConnType, f"Find link not in DGX-1 topology: {link}" + topology[src, dst] = _kConnType[link] + return topology + + +def topology_repr(topology: np.ndarray, reorder: List[int]): + reorder = list(reorder) + ngpus = topology.shape[0] + reorder_topo = np.empty((ngpus, ngpus), dtype=object) + for src in range(ngpus): + for dst in range(ngpus): + link = _kConnTypeStr[topology[src, dst]] + reorder_topo[reorder.index(src), reorder.index(dst)] = link + maxlen = max(len(key) for key in _kConnType) + dscp = '' + for gidx, line in enumerate(reorder_topo): + dscp += f'GPU{gidx}: '+ ' '.join(link.ljust(maxlen) for link in line) + '\n' + return dscp + + +def reorder(topology: np.ndarray) -> np.ndarray: + """ + Reorder GPU according to DGX-1 topology + + ┌───────────┐ + 1 = 0 = 4 = 5 + ‖ x | | x ‖ + 2 = 3 = 7 = 6 + └───────────┘ + """ + ngpus = topology.shape[0] + # find NV2 ring + ring = [0] + while len(ring) < ngpus: + nv2s = np.where(topology[ring[-1]] == _kConnType['NV2'])[0] + find_next = False + for gid in nv2s: + if gid not in ring: + ring.append(gid) + find_next = True + break + assert find_next + ring = np.array(ring, dtype=int) + print(f'Get ring: {ring}') + # find fc + for idx, src in enumerate(ring): + is_fc = True + pairs = [ + (src, ring[(idx + 3) % len(ring)]), + (src, ring[(idx + 2) % len(ring)]), + (ring[(idx+1) % len(ring)], ring[(idx+3) % len(ring)]) + ] + for src, dst in pairs: + if topology[src, dst] != _kConnType['NV1']: + is_fc = False + break + if is_fc: + break + assert is_fc, f"Cannot find FC group." + ring = np.roll(ring, 0-idx) + return ring + + +if __name__ == '__main__': + topology = get_topology() + print('original topology:') + print(topology_repr(topology, list(range(topology.shape[0])))) + reorder = reorder(topology) + print('reorder topology:') + print(topology_repr(topology, reorder)) + print( + f"Command need to be added into environment:\n" + f"export CUDA_VISIBLE_DEVICES={','.join(str(gid) for gid in reorder)}" + ) diff --git a/utility/keep.py b/utility/keep.py new file mode 100644 index 00000000..d45a02fc --- /dev/null +++ b/utility/keep.py @@ -0,0 +1,72 @@ +import torch +import time +import argparse + +import subprocess +import re + +def get_gpu_util(rank): + from shutil import which + smi = None + if which('nvidia-smi') is not None: + smi = 'nvidia-smi' + elif which('rocm-smi') is not None: + smi = 'rocm-smi' + else: + raise Exception('Cannot find either nvidia-smi or rocm-smi!') + + cmds = [ + smi, + '-i', + str(rank), + ] + proc = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = proc.communicate() + outputs = stdout.decode('utf-8').split('\n') + + util = 0 + for output in outputs[::-1]: + # switch to performance line + if 'Default' in output: + # match all the numbers and return the last one + util = re.findall(r'\d+', output)[-1] + util = int(util) + break + else: + print("rank {}: couldn't match any, check GPU status!".format(rank)) + return util + + +def keep(rank, args): + + torch.cuda.set_device(rank) + a = torch.rand((8192, 8192)).cuda() + b = torch.rand((8192, 8192)).cuda() + + print(f'benchmarking {args.gpus} gpus...') + while True: + tic = time.time() + for _ in range(5000): + c = a * b + torch.cuda.synchronize() + toc = time.time() + # if rank == 0: + # print('benchmark 8K matmul: time span: {}ms'.format((toc - tic) * 1000 / 5000)) + time.sleep(args.interval) + while True: + util = get_gpu_util(rank) + if util <= 10: + break + # print('rank {}: find gpu busy, keep sleeping...'.format(rank)) + time.sleep(args.interval) + # print('rank {} gets up'.format(rank)) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--interval', type=int, default=2) + parser.add_argument('--gpus', type=int, default=1) + args = parser.parse_args() + + torch.multiprocessing.spawn(keep, args=(args,), nprocs=args.gpus, join=True) diff --git a/utility/prim_profiler.py b/utility/prim_profiler.py new file mode 100644 index 00000000..6daef5ff --- /dev/null +++ b/utility/prim_profiler.py @@ -0,0 +1,52 @@ +import torch +import os +import sys +import shutil +from datetime import datetime +import subprocess +import torch +import logging +from pathlib import Path +from nnscaler.autodist.util import get_node_arch, get_default_profile_path + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("nnscaler.comm_profiler") + + +def main(): + default_path = get_default_profile_path() + + if not default_path.is_dir(): + default_path.mkdir(parents=True) + logger.info(f'create folder: {default_path}') + else: + logger.info(f'folder already exists: {default_path}') + + comm_path = default_path / 'comm' + + if comm_path.is_dir(): + logger.info(f'back up legacy comm info: {comm_path}') + shutil.move( + comm_path, + default_path / f'comm_back_{str(datetime.now().timestamp())}') + comm_path.mkdir(parents=True, exist_ok=True) + + logger.info(f'CUDA device num: {torch.cuda.device_count()}') + profiler_fname = Path(__file__).parent / 'comm_profile.py' + device_num = 2 + while device_num <= torch.cuda.device_count(): + command = f'torchrun --master_port 21212 --nproc_per_node={device_num} {profiler_fname} --comm_profile_dir={comm_path}' + output = subprocess.check_output(command, shell=True, text=True) + device_num = device_num * 2 + + logger.info(f'comm profile done') + + +if __name__ == '__main__': + main() diff --git a/utility/test_rvd_prim.py b/utility/test_rvd_prim.py new file mode 100644 index 00000000..7e8251c2 --- /dev/null +++ b/utility/test_rvd_prim.py @@ -0,0 +1,137 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + utility/test_rvd_prim.py --prims allreduce + +OMP_NUM_THREADS=4 torchrun \ + --nnode=2 --node_rank=$NODE_RANK --master_addr=node-0 \ + --nproc_per_node=8 \ + utility/test_rvd_prim.py --prims all +""" + +from typing import Callable +import nnscaler +import torch +import time +import argparse +from nnscaler.profiler.timer import CudaTimer, print_each_rank + +from nnscaler.runtime.adapter.collectives import all_reduce, all_gather, reduce_scatter, all_to_all +from nnscaler.runtime.device import DeviceGroup + + +def prim_allreduce(itensor, ranks, dim0=None, dim1=None): + return all_reduce(itensor, ranks) + + +def bw_allreduce(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * 2 * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_allgather(itensor, ranks, dim0=0, dim1=None): + return all_gather(itensor, dim0, ranks) + + +def bw_allgather(itensor: torch.Tensor, ranks, sec_per_call: float): + ndevs = len(ranks) + msg_size = itensor.nelement() * 4 / 1e9 * ndevs + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_reducescatter(itensor, ranks, dim0=0, dim1=None): + return reduce_scatter(itensor, dim0, ranks) + + +def bw_reducescatter(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_alltoall(itensor, ranks, dim0=0, dim1=1): + return all_to_all(itensor, dim0, dim1, ranks) + + +def bw_alltoall(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_bw(prim: Callable, bandwidth: Callable, ranks, size, warmup=100, profile=100): + if 'allgather' in prim.__name__: + size = size // len(ranks) + tensor: torch.Tensor = torch.zeros(size, device=torch.cuda.current_device()) + tensor = tensor.view(256, -1).contiguous() + torch.distributed.barrier() + # warm up + for _ in range(warmup): + _ = prim(tensor, ranks) + # profile + torch.cuda.synchronize() + torch.distributed.barrier() + tic = time.perf_counter() + for _ in range(profile): + _ = prim(tensor, ranks) + torch.cuda.synchronize() + toc = time.perf_counter() + + span = (toc - tic) / profile # seconds + msg_size = tensor.nelement() * 4 // 1024 // 1024 # MB + if 'allgather' in prim.__name__: + msg_size = len(ranks) * tensor.nelement() * 4 // 1024 // 1024 # MB + algo_bw, bus_bw = bandwidth(tensor, ranks, span) + print_each_rank( + '{} msg {} MB | wall-time(ms) algo-bw(GB/s) bus-bw(GB/s) {:.2f} {:.2f} {:.2f}'.format( + prim.__name__, msg_size, span*1000, algo_bw, bus_bw + ), rank_only=0 + ) + + +if __name__ == '__main__': + + nnscaler.init() + + parser = argparse.ArgumentParser(description='comm primitive') + parser.add_argument('--prims', type=str, nargs='+', action='append', + help='prims: all, allreduce, reducescatter, allgather, alltoall') + parser.add_argument('--begin', type=int, default=1, + help='start message size in MB') + parser.add_argument('--end', type=int, default=256, + help='end message size in MB') + args = parser.parse_args() + args.prims = args.prims[0] + + prims, bws = [], [] + if 'allreduce' in args.prims or 'all' in args.prims: + prims.append(prim_allreduce) + bws.append(bw_allreduce) + if 'allgather' in args.prims or 'all' in args.prims: + prims.append(prim_allgather) + bws.append(bw_allgather) + if 'reducescatter' in args.prims or 'all' in args.prims: + prims.append(prim_reducescatter) + bws.append(bw_reducescatter) + if 'alltoall' in args.prims or 'all' in args.prims: + prims.append(prim_alltoall) + bws.append(bw_alltoall) + + ranks = tuple(range(DeviceGroup().world_size)) + CudaTimer(enable=False) + for prim, bw in zip(prims, bws): + print_each_rank(f'====> test start {prim.__name__}', rank_only=0) + size = args.begin + while size <= args.end: + prim_bw(prim, bw, ranks, size * 1024 * 1024 // 4) + size *= 2 + print_each_rank(f'====> test finish {prim.__name__}', rank_only=0) diff --git a/utility/verify_ops/verify_dimops.py b/utility/verify_ops/verify_dimops.py new file mode 100644 index 00000000..bb88bc72 --- /dev/null +++ b/utility/verify_ops/verify_dimops.py @@ -0,0 +1,470 @@ +""" +This test verifies the correctness of an operator's annotation by running its distributed versions. +The processing pipeline is: +1. generate the input and calculate the output for the operator on a single device +2. construct the partition search space based on its annotation +3. for each partition choice, nnscaler will generate runnable code with communication adapters automatically +4. compare each distributed result with single device version, the difference should be less than a threshold +NOTE: only consider partitioning along one dimension currently +""" + +import os +from typing import Dict, List, Tuple, Any, Union +from dataclasses import dataclass, field +import logging +import subprocess +import torch + +from nnscaler.graph.function.dimops import IRDimops, OpAnno, DimAnno +from nnscaler.ir.cten import IRTensor, IRObject + + +logger = logging.getLogger(__name__) + + +_SINGLE_GPU_TEST_FILE = "single_gpu_test.py" +_TWO_GPUS_TEST_FILE = "two_gpus_test.py" + +module_template_common = """ +import os +import numpy +import sys +import torch +import nnscaler + +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType +{import_cumsomized_func} + +import nnscaler.graph +import nnscaler.graph.function +import nnscaler.graph.function.wrapnn + +import torch +import numpy as np +import random + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, {args}): + # Add clone to resolve the issue: + # a leaf Variable that requires grad is being used in an in-place operation. + {clone_args} + + {func_sig_call} + + out = 0 + for one_out in [{outputs}]: + if not isinstance(one_out, torch.Tensor): + continue + out += torch.sum(one_out) + return out + +model = TestModule() #.to(torch.float16) +""" + +module_template_single_main = """ +# Load inputs from file, ensuring inputs.pt is always a tuple, even when there's only one input +{args}, = torch.load('{func_sig}_inputs.pt', map_location=torch.device('cuda:0')) + +model = model.cuda() + +single_loss = model({args}) +single_loss.backward() + +grad_tensors = {grad_tensors} +torch.save([grad_tensors, single_loss], '{func_sig}_loss_single.pt') +print('single gpu loss: ', single_loss) +""" + +module_template_single = module_template_common + module_template_single_main + +module_template_parallel_main = """ +nnscaler.init() +rank_id = torch.distributed.get_rank() + +{args}, = torch.load('{func_sig}_inputs.pt', map_location=torch.device(f'cuda:{{rank_id}}')) + +def policy(graph: IRGraph, resource) -> IRGraph: + ngpus = 2 + partitioned = False + + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == '{func_sig}': + print('Partitioned node: ', node) + sub_nodes = graph.partition( + node, node.algorithm('dim'), idx={idx}, dim={dim}, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + + assert partitioned, f'No node is partitioned for {func_sig}.' + return graph + +parallel_model = parallelize( + model, + dummy_forward_args={dummy_input_str}, + pas_policy=policy, + compute_config=ComputeConfig(2, 2), + reuse=ReuseType.OVERRIDE +) + +parallel_model.train() + +parallel_loss = parallel_model({args}) +parallel_loss.backward() + +grad_tensors = {grad_tensors} +torch.save([grad_tensors, parallel_loss], '{func_sig}_loss_para_'+str(rank_id)+'.pt') +print('two gpus loss: ', parallel_loss) +""" + +module_template_parallel = module_template_common + module_template_parallel_main + + +@dataclass +class TensorInfo: + value_form: str # 'shape' or 'value' + value: Union[Tuple[int], Any] + dtype: torch.dtype = torch.float32 + requires_grad: bool = True + + # make TensorInfo hashable + def __hash__(self): + value = self.value + if isinstance(value, slice): + value = (value.start, value.stop, value.step) + return hash((self.value_form, value)) + + +@dataclass +class VerifyConfig: + fsig: str + args: List[TensorInfo] + kwargs: Dict[str, Any] + noutputs: int + parti_options: List[Dict[str, int]] + import_customized_func: str = "" + non_grad_indices: List[int] = field(default_factory=list) + + +def _complex(val: Any): + """ + Convert IRObject to concrete value + NOTE: only used for handling kwargs + """ + if isinstance(val, tuple): + return tuple(_complex(t) for t in val) + if isinstance(val, list): + return list(_complex(t) for t in val) + if isinstance(val, dict): + return {_complex(key): _complex(val) for key, val in val.items()} + if isinstance(val, slice): + return slice(_complex(val.start), _complex(val.stop), _complex(val.step)) + if isinstance(val, IRObject): + assert not isinstance(val, IRTensor), "IRTensor should not be in kwargs" + return _complex(val.value) + return val + + +def get_candidate_options( + anno: OpAnno, ins_outs_shape: List[TensorInfo], npartitions: int = 2 +) -> List[Dict[str, int]]: + """ + Get all the feasible partitions specified by the annotation of an operator. + Checks whether the dimension can be divided, and also checks whether the size of the dimension can be evenly divided by the number of partitions + Args: + anno (OpAnno): operator annotation + ins_outs_shape (List[TensorInfo]): input and output shapes + npartitions (int, optional): number of partitions. Defaults to 2. + Returns: + List[Dict[str, int]]: a list of feasible partitions + + """ + all_configs = anno.transform_space() + + candidate_partitions = [] + for idx, dim in all_configs: + if ( + ins_outs_shape[idx].value_form == "shape" + and ins_outs_shape[idx].value[dim] % npartitions == 0 + ): + candidate_partitions.append({"idx": idx, "dim": dim}) + + return candidate_partitions + + +def handle_buffer_parameters(inputs, non_grad_indices): + """ + Detach specified buffer parameters from the computational graph and disable their gradient computation. + This is necessary for parameters that should not participate in the backward pass, + such as statistical parameters in certain layers (e.g., running_mean in normalization layers). + + Args: + inputs (List[torch.Tensor]): The list of input tensors. + non_grad_indices (List[int]): The indices of buffer parameters in the input list. + """ + for idx in non_grad_indices: + if inputs[idx] is not None: + inputs[idx] = inputs[idx].detach() + inputs[idx].requires_grad = False + + +def _create_op_inputs(verify_config: VerifyConfig) -> List[Any]: + """ + Create input tensors/non-tensors for the operator. + The input tensors/non-tensors are only for args, not for kwargs. + Args: + verify_config (VerifyConfig): configuration for verifying the partitions + Returns: + List[Any]: input tensors + """ + torch.manual_seed(0) + inputs = [] + + def process_slice(slice_obj): + start = ( + slice_obj.start.value + if isinstance(slice_obj.start, IRObject) + else slice_obj.start + ) + stop = ( + slice_obj.stop.value + if isinstance(slice_obj.stop, IRObject) + else slice_obj.stop + ) + step = slice_obj.step + return slice(start, stop, step) + + for i, tensor_info in enumerate(verify_config.args): + if tensor_info.value_form == "shape": + # Special handling: For torch. rsqrt, generate random integers between 1 and 10 to avoid invalid values + if verify_config.fsig == "torch.rsqrt": + inputs.append( + torch.randint( + 1, + 10, + tensor_info.value, + dtype=tensor_info.dtype, + requires_grad=tensor_info.requires_grad, + ) + ) + # Special handling: for the first parameter of torch.where which is a boolean mask + elif verify_config.fsig == "torch.where" and i == 0: + inputs.append( + torch.rand( + *tensor_info.value, dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad + ) + > 0.5 + ) + elif verify_config.fsig == "torch.add" and tensor_info.value == (1,): + # Special handling:add in the model generates values that cannot be partitioned + inputs.append(torch.randn(4, dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad)) + else: + if tensor_info.value == (): + inputs.append( + torch.randn( + (), dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad + ).squeeze() + ) + else: + inputs.append( + torch.randn( + *tensor_info.value, + dtype=tensor_info.dtype, + requires_grad=tensor_info.requires_grad, + ) + ) + elif tensor_info.value_form == "value" and isinstance(tensor_info.value, slice): + inputs.append(process_slice(tensor_info.value)) + else: + inputs.append(tensor_info.value) + if verify_config.non_grad_indices: + handle_buffer_parameters(inputs, verify_config.non_grad_indices) + return inputs + + +def verify_partition_options(verify_config: VerifyConfig) -> bool: + errors = [] + try: + logger.info(f"Verifying partitions of {verify_config.fsig}...") + inputs = _create_op_inputs(verify_config) + torch.save(inputs, f"{verify_config.fsig}_inputs.pt") + logger.info(f"Input tensors saved to {verify_config.fsig}_inputs.pt") + + outputs_str = ", ".join([f"_out{i}" for i in range(verify_config.noutputs)]) + + kwargs_str = ", ".join( + [ + f'{k}="{v}"' if isinstance(v, str) else f"{k}={_complex(v)}" + for k, v in verify_config.kwargs.items() + ] + ) + + func_sig_call = verify_config.fsig + args_str = ", ".join([f"_in{i}" for i in range(len(verify_config.args))]) + tensor_member_methods_prefix = 'torch.Tensor.' + if func_sig_call.startswith(tensor_member_methods_prefix): + # workaround because tracer does not support tensor member methods + func_sig_call = f'_in0.' + func_sig_call[len(tensor_member_methods_prefix):] + func_args_str = ", ".join([f"_in{i}" for i in range(1, len(verify_config.args))]) + else: + func_args_str = args_str + + if func_args_str: + func_call = f"{outputs_str} = {func_sig_call}({func_args_str}, {kwargs_str})" + else: + func_call = f"{outputs_str} = {func_sig_call}({kwargs_str})" + + clone_args_right = ", ".join( + [ + f"_in{i}.clone()" + for i, tinfo in enumerate(verify_config.args) + if tinfo.value_form == "shape" + ] + ) + if clone_args_right: + clone_args_left = ", ".join( + [ + f"_in{i}" + for i, tinfo in enumerate(verify_config.args) + if tinfo.value_form == "shape" + ] + ) + clone_args = f"{clone_args_left} = {clone_args_right}" + else: + clone_args = "" + + dummy_input_str = ( + "{" + + ", ".join([f'"_in{i}": _in{i}' for i in range(len(verify_config.args))]) + + "}" + ) + + grad_tensors = ( + "[" + + ", ".join( + [ + f"_in{i}.grad" + for i in range(len(verify_config.args)) + if i not in verify_config.non_grad_indices + and verify_config.args[i].value_form == "shape" + ] + ) + + "]" + ) + module_single_str = module_template_single.format( + import_cumsomized_func=verify_config.import_customized_func, + clone_args=clone_args, + args=args_str, + kwargs=kwargs_str, + func_sig=verify_config.fsig, + func_sig_call=func_call, + outputs=outputs_str, + grad_tensors=grad_tensors, + ) + with open(_SINGLE_GPU_TEST_FILE, "w") as f: + f.write(module_single_str) + logger.info("Generated test code for single gpu and running...") + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_single.pt"]) + subprocess.run(["python", _SINGLE_GPU_TEST_FILE]) + logger.info( + f"Single GPU test completed. Output saved to {verify_config.fsig}_loss_single.pt" + ) + logger.info(f"verify_config: {verify_config}") + logger.info(f"verify_config.parti_options: {verify_config.parti_options}") + + for poption in verify_config.parti_options: + try: + logger.info(f"Verifying the partition {poption}...") + module_para_str = module_template_parallel.format( + import_cumsomized_func=verify_config.import_customized_func, + clone_args=clone_args, + args=args_str, + kwargs=kwargs_str, + func_sig=verify_config.fsig, + func_sig_call=func_call, + outputs=outputs_str, + dummy_input_str=dummy_input_str, + grad_tensors=grad_tensors, + idx=poption["idx"], + dim=poption["dim"], + ) + with open(_TWO_GPUS_TEST_FILE, "w") as f: + f.write(module_para_str) + logger.info("Generated test code for two gpus.") + + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_para_0.pt"]) + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_para_1.pt"]) + subprocess.run( + [ + "torchrun", + "--nproc_per_node=2", + "--nnodes=1", + "--rdzv-endpoint=localhost:23457", + _TWO_GPUS_TEST_FILE, + ] + ) + logger.info( + f"Two GPU test completed. Outputs saved to {verify_config.fsig}_loss_para_0.pt and {verify_config.fsig}_loss_para_1.pt" + ) + single = torch.load(f"{verify_config.fsig}_loss_single.pt") + logger.info( + f"Loading single loss from: {verify_config.fsig}_loss_single.pt" + ) + para0 = torch.load(f"{verify_config.fsig}_loss_para_0.pt") + para1 = torch.load(f"{verify_config.fsig}_loss_para_1.pt") + + logger.info(f"Single loss: {single[1]}") + logger.info(f"Multi-GPU loss (para0): {para0[1]}") + logger.info(f"Multi-GPU loss (para1): {para1[1]}") + + assert torch.allclose( + single[1], para0[1], rtol=1e-3, atol=1e-5 + ), f"Loss mismatch between single and multi-GPU (para0)" + assert torch.equal( + para0[1], para1[1].to(para0[1]) + ), f"Loss mismatch between multi-GPU (para0 and para1)" + + for i in range(len(single[0])): + if single[0][i] is None or para0[0][i] is None: + logger.debug( + f"Skipping comparison for index {i} because it is None" + ) + continue + logger.debug(f"Absolute error: {single[0][i] - para0[0][i]}") + logger.debug( + f"Relative error: {(single[0][i] - para0[0][i]) / single[0][i]}" + ) + assert torch.allclose( + single[0][i], para0[0][i], rtol=1e-3, atol=1e-5 + ), f"Gradient mismatch between single and multi-GPU (para0)" + assert torch.equal( + para0[0][i], para1[0][i].to(para0[0][i]) + ), f"Gradient mismatch between multi-GPU (para0 and para1)" + + logger.info( + f"{verify_config.fsig} of partition {poption} passed the allclose comparison." + ) + except Exception as e: + error_message = f"Partition {poption} failed with error: {str(e)}" + logger.error(error_message) + errors.append(error_message) + if errors: + logger.error("Some partitions failed:") + for error in errors: + logger.error(error) + return False + else: + logger.info( + f"Verified all the partitions of {verify_config.fsig} successfully." + ) + return True + except Exception as e: + logger.exception("Exception occurred during verification process") + raise e diff --git a/utility/verify_ops/verify_graph_operations.py b/utility/verify_ops/verify_graph_operations.py new file mode 100644 index 00000000..680d3fc1 --- /dev/null +++ b/utility/verify_ops/verify_graph_operations.py @@ -0,0 +1,161 @@ +import argparse +import os +import sys +import torch +from nnscaler.graph.function.dimops import DimAnno, IRDimops, OpAnno +from nnscaler.graph.graph import IRGraph +from nnscaler.ir.cten import IRObject, IRTensor +from pathlib import Path +import logging + +from verify_dimops import TensorInfo, get_candidate_options + +_VERIFIED_OPS_FILE_NAME = "verified_ops.pt" +_DEFAULT_CACHE_DIR = Path(os.path.expanduser("~/.cache/nnscaler")) + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger(__name__) + + +def load_verified_ops(outdir: Path): + verified_ops_file = outdir / _VERIFIED_OPS_FILE_NAME + if verified_ops_file.exists(): + logger.info(f"{verified_ops_file} exists, load it.") + return torch.load(verified_ops_file) + else: + logger.info(f"{verified_ops_file} does not exist, start from scratch.") + return set() + + +def save_verified_ops(outdir: Path, verified_ops: set): + verified_ops_file = outdir / _VERIFIED_OPS_FILE_NAME + torch.save(verified_ops, verified_ops_file) + logger.info(f"Verification results saved to {verified_ops_file}") + + +def verify_op_partitions(graph: IRGraph, outdir: Path): + """ + Test if the partitioned ops in the graph are computationally correct. + + Args: + graph (IRGraph): the graph to be verified + outdir (Path): the directory to save the verified ops + + Returns: + None + """ + from verify_dimops import ( + VerifyConfig, + TensorInfo, + verify_partition_options, + ) + + verified_ops = load_verified_ops(outdir) + skipped_nodes = [] + + gnodes = graph.nodes(flatten=True) + for idx, node in enumerate(gnodes): + logger.info(f"node: {node}") + logger.info(f"Verification progress: {idx} / {len(gnodes)}") + if node.isfw() and isinstance(node, IRDimops): + ins_info = [ + ( + TensorInfo("shape", _input.shape) + if isinstance(_input, IRTensor) + else TensorInfo( + "value", + _input.value if isinstance(_input, IRObject) else _input, + ) + ) + for _input in node.inputs() + ] + if not ins_info: + skipped_nodes.append(f"{node.signature} (type: {type(node)})") + logger.info(f"ins_info is empty for node: {node.signature}, skipping.") + continue + + outs_info = [ + ( + TensorInfo("shape", output.shape) + if isinstance(output, IRTensor) + else TensorInfo( + "value", + output.value if isinstance(output, IRObject) else output, + ) + ) + for output in node.outputs() + ] + if (node.signature, tuple(ins_info + outs_info)) in verified_ops: + logger.info(f"{node.signature} has been verified before, skip.") + continue + + logger.info(f"Node annos: {node.signature}, {node.anno}") + + parti_options = get_candidate_options(node.anno, ins_info + outs_info) + + logger.info(f"Candidate partition options: {parti_options}") + + verify_config = VerifyConfig( + fsig=node.signature, + args=ins_info, + kwargs=node.kwargs, + noutputs=len(node.outputs()), + parti_options=parti_options, + ) + try: + iscorrect = verify_partition_options(verify_config) + except Exception as e: + logger.warning( + f"Verification failed for {node.signature}, {e}, please manually verify." + ) + iscorrect = True # fake true to skip this node + if not iscorrect: + logger.warning(f"Verification failed for {node.signature}, continuing execution.") + continue + + verified_ops.add((node.signature, tuple(ins_info + outs_info))) + save_verified_ops(outdir, verified_ops) + + if skipped_nodes: + logger.info("Skipped the following nodes due to empty ins_info:") + for node_info in skipped_nodes: + logger.info(f" - {node_info}") + +def main(): + parser = argparse.ArgumentParser( + description="Verify partitions of operations in an IRGraph." + ) + parser.add_argument( + "--graph", type=str, required=True, help="Path to the graph file." + ) + parser.add_argument( + "--outdir", + type=str, + help="Optional directory to save the verified operations. If not provided, results will be saved to the default cache directory.", + ) + + args = parser.parse_args() + + graph_path = Path(args.graph) + if not graph_path.exists(): + raise FileNotFoundError(f"Graph file {graph_path} does not exist.") + + graph = IRGraph.load(graph_path) + + if args.outdir: + outdir = Path(args.outdir) + else: + outdir = _DEFAULT_CACHE_DIR + + outdir.mkdir(parents=True, exist_ok=True) + verify_op_partitions(graph, outdir) + + +if __name__ == "__main__": + main() diff --git a/utility/visualize_value_tracks.py b/utility/visualize_value_tracks.py new file mode 100644 index 00000000..164b838e --- /dev/null +++ b/utility/visualize_value_tracks.py @@ -0,0 +1,158 @@ +import argparse +import matplotlib.pyplot as plt +from nnscaler.graph import IRGraph +from matplotlib.patches import FancyArrowPatch +from nnscaler.ir.cten import IR, IRTensor, IRObject + + +class Visualizer: + NUM_ROWS_PER_OP = 3 + TEXT_HEIGHT_IN_INCH = 0.4 + PER_OP_GAP_IN_INCH = 0.2 + PER_ROW_HEIGHT_IN_INCH = TEXT_HEIGHT_IN_INCH * 1.1 + PER_OP_HEIGHT_IN_INCH = PER_ROW_HEIGHT_IN_INCH * NUM_ROWS_PER_OP + PER_INOUT_GAP = 0.01 + + INIT_Y = 0.001 + INIT_X = 0.001 + + def __init__(self, graph): + self.graph = graph + self.value_loc = {} + self.ops = [node for node in self.graph.nodes() if node.isfw()] + + self.fig_heigth_in_inch = ( + self.PER_OP_HEIGHT_IN_INCH + self.PER_OP_GAP_IN_INCH + ) * (len(self.ops) + 1) + self.coord_per_inch = 1.0 / self.fig_heigth_in_inch + self.per_op_height = self.PER_OP_HEIGHT_IN_INCH * self.coord_per_inch + self.per_row_height = self.per_op_height / self.NUM_ROWS_PER_OP + self.per_op_gap = self.PER_OP_GAP_IN_INCH * self.coord_per_inch + + self.fig, self.ax = plt.subplots(figsize=(30, self.fig_heigth_in_inch)) + self.ax.axis('off') + self.ax.invert_yaxis() + + def draw_value(self, value, value_track, cur_x, cur_y, previous_value_loc): + t = self.ax.text(cur_x, cur_y, str(value), + fontsize=14, ha="left", va="top") + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + if value_track is not None: + if value_track.value_id in previous_value_loc: + prev_x, prev_y = previous_value_loc[value_track.value_id] + arrow = FancyArrowPatch( + (prev_x, prev_y), + (cur_x + bbox.width/2, cur_y), + arrowstyle="Simple,tail_width=0.25,head_width=1,head_length=1", + mutation_scale=6, + color="#2c7bb6", + linewidth=0.02, + connectionstyle="arc3,rad=0", + alpha=0.5, + zorder=4 + ) + self.ax.add_patch(arrow) + self.value_loc[value_track.value_id] = (cur_x + bbox.width/2, cur_y) + + cur_x += bbox.width + self.PER_INOUT_GAP/2 + return cur_x + + def draw_obj(self, obj, cur_x, cur_y, previous_value_loc): + if isinstance(obj, IRTensor): + cur_x = self.draw_value('T(', None, cur_x, cur_y, previous_value_loc) + for i, d in enumerate(obj.shape): + if i > 0: + cur_x = self.draw_value(',', None, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(str(d), obj.dim_tracks[i], cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(')', None, cur_x, cur_y, previous_value_loc) + else: + assert isinstance(obj, IRObject) + cur_x = self.draw_value('O(', None, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(str(obj.value), obj.value_track, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(')', None, cur_x, cur_y, previous_value_loc) + cur_x += self.PER_INOUT_GAP + return cur_x + + def draw_objs(self, objs, cur_x, cur_y): + previous_value_loc = dict(self.value_loc) + for inp in objs: + cur_x = self.draw_obj(inp, cur_x, cur_y, previous_value_loc) + + def draw_graph_inputs(self, g, cur_x, cur_y): + label = "GRAPH IN: " + t = self.ax.text(cur_x, cur_y, label, + fontsize=14, fontweight="bold", ha="left", va="top") + + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + cur_x = cur_x + bbox.width + self.PER_INOUT_GAP + + ir_objs = [] + for inp in g.inputs(): + if isinstance(inp, (IRObject, IRTensor)): + ir_objs.append(inp) + elif isinstance(inp, IRObject): + sub_objs = IR.get_objects(inp.value) + if sub_objs: + ir_objs.extend(sub_objs) + else: + ir_objs.append(inp) + + self.draw_objs(ir_objs, cur_x, cur_y) + + def draw_inout(self, node, cur_y, is_in): + if is_in: + ir_objs = node.iobjs() + label = "IN: " + cur_y += self.per_row_height + else: + ir_objs = node.oobjs() + label = "OU: " + cur_y += self.per_row_height * 2 + + t = self.ax.text(self.INIT_X, cur_y, label, + fontsize=14, fontweight="bold", ha="left", va="top") + + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + cur_x = self.INIT_X + bbox.width + self.PER_INOUT_GAP + + self.draw_objs(ir_objs, cur_x, cur_y) + + def visualize(self): + self.draw_graph_inputs(self.graph, self.INIT_X, self.INIT_Y) + cur_y = self.INIT_Y + (self.per_op_height + self.per_op_gap)/2 + + for node in self.ops: + op_name = node.name + self.ax.text(self.INIT_X, cur_y, op_name + ":", + fontsize=16, fontweight="bold", ha="left", va="top") + + self.draw_inout(node, cur_y, is_in=True) + self.draw_inout(node, cur_y, is_in=False) + + cur_y += self.per_op_height + self.per_op_gap + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + 'graphfile', + type=str, + help="Graph dump file" + ) + parser.add_argument( + 'imagefile', + type=str, + nargs='?', + default=None, + help="Save generated image to file" + ) + args = parser.parse_args() + g = IRGraph.load(args.graphfile) + visualizer = Visualizer(g) + visualizer.visualize() + if args.imagefile: + plt.savefig(args.imagefile, bbox_inches='tight', dpi=100) + plt.show() From 3eb1bc249cc38723af0d7c864760142010026cbf Mon Sep 17 00:00:00 2001 From: youshan Date: Wed, 31 Dec 2025 21:23:57 +0800 Subject: [PATCH 1865/1892] Bump version to 0.8 --- nnscaler/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnscaler/version.py b/nnscaler/version.py index 84bc6647..ae87ac5e 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -__version__ = '0.7' +__version__ = '0.8' From 91ef006e4a12991d54f330a5e0576d08f6522df9 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 4 Jan 2026 01:46:28 +0000 Subject: [PATCH 1866/1892] Merged PR 2450: [AutoDist] Add dynamic constraint --- nnscaler/autodist/cube_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py index f25f1ab5..b7d3c2dc 100644 --- a/nnscaler/autodist/cube_operator.py +++ b/nnscaler/autodist/cube_operator.py @@ -108,7 +108,7 @@ def collect_anno_info(self): for idx_dim, dim_anno in enumerate(shape_anno.dims): for idx_id, identifier in enumerate(dim_anno.identifiers): reduce_type = dim_anno.reduces[idx_id] - if reduce_type != DimAnno.ReduceType.Freeze: + if reduce_type != DimAnno.ReduceType.Freeze and self.ir_cell.input(idx_shape).dim_tracks[idx_dim].is_constant: self.parallelable_dims.add(identifier) if reduce_type == DimAnno.ReduceType.Sum: self._has_sum_dim = True From 9834d5e9c7bf0b7065dfdb7212048beb1a3be434 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 4 Jan 2026 09:02:51 +0000 Subject: [PATCH 1867/1892] Merged PR 2451: [Parser] Support `torch.eye` and `.T` --- nnscaler/graph/function/function.py | 20 ++++++++++++++++++++ nnscaler/graph/parser/mapping.py | 1 + nnscaler/runtime/function/function.py | 4 ++++ tests/graph/function/test_functions.py | 14 ++++++++++++++ tests/graph/parser/test_parser.py | 20 ++++++++++++++++++++ 5 files changed, 59 insertions(+) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 7c686d76..f6af51ae 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -560,6 +560,22 @@ def NewTensor(data, *, dtype=None, device=None, return IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) +def Eye(n: int, m: Optional[int] = None, *, dtype=None, device=None, + requires_grad=False, signature=None): + """ + torch.eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor + """ + dtype = dtype if dtype is not None else torch.get_default_dtype() + creation_function_args_check('torch.eye', dtype=dtype, device=device) + + signature = 'nnscaler.runtime.function.eye' + if m is None: + m = n + kwargs = {'n': n, 'm': m, 'requires_grad': requires_grad, 'dtype': dtype} + anno, rules = _get_creator_anno_rules((_unwrap_value(n), _unwrap_value(m)), False) + return IRDimops(Eye, 'eye', signature, [anno], [], rules, **kwargs) + + def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: """Create shape annotations for element wise operator following broadcastable rules: https://pytorch.org/docs/stable/notes/broadcasting.html @@ -2702,6 +2718,10 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" _logger.warning("getattr of 'layout' will always return torch.strided") return torch.strided + if name == 'T': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + assert len(obj.shape) == 2, "only 2-dim tensor support .T operation" + return Transpose(obj, 0, 1, signature='torch.transpose') if isinstance(obj, torch.finfo): return getattr(obj, name) return IRPyFunc(signature, [instance, field], [IRObject.missing]) diff --git a/nnscaler/graph/parser/mapping.py b/nnscaler/graph/parser/mapping.py index 2c8f88ce..56c2908f 100644 --- a/nnscaler/graph/parser/mapping.py +++ b/nnscaler/graph/parser/mapping.py @@ -181,6 +181,7 @@ def exist(signature: str) -> bool: __ttemplate('rand_like'): function.RandLike, __ttemplate('randn'): function.Randn, __ttemplate('randn_like'): function.RandnLike, + __ttemplate('eye'): function.Eye, __ttemplate('clone'): function.Clone, '_operator.is_': function.Is, diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index cba44779..8bda380a 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -325,6 +325,10 @@ def linspace(start: Union[int, torch.Tensor], end: Union[int, torch.Tensor], device=torch.cuda.current_device()) +def eye(n: int, m: Optional[int]=None, requires_grad=False, dtype: torch.dtype=torch.float32) -> torch.Tensor: + return torch.eye(n, m=m, dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) + + def index_select(input: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor: return torch.index_select(input, dim, index) diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index e6a12c55..8a32e20d 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -63,6 +63,20 @@ def test_Randn(): assert op.output(0).dim_tracks[2].deps == [op.kwargs['size'][2].value_track.value_id] +def test_Eye(): + op = F.Eye(IRObject(value=3)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 3^ 3^' + + op = F.Eye(IRObject(value=3), IRObject(value=4)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 3^ 4^' + + op = F.Eye(3, IRObject(value=4)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 3^ 4^' + + op = F.Eye(IRObject(value=3), 4) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 3^ 4^' + + def test_Expand(): inp = IRTensor([10, 1]) out = IRTensor([10, 2]) diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 176cba07..c468417a 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -324,3 +324,23 @@ def forward(self, x): # so the output number is 1 for now. # Will be fixed later. assert len(ir_graph.outputs()) == 1 + + +@replace_all_device_with('cpu') +def test_T(tmp_path): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.matmul(x, x.T) + + dummy_input = {'x': torch.randn(4, 8)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) + print(ir_graph.extra_repr()) + + assert ir_graph.nodes()[0].signature == 'torch.transpose' From ccb7fa1bb7bd9ce64e5b974d64bbe005ebba077b Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sun, 4 Jan 2026 14:27:55 +0000 Subject: [PATCH 1868/1892] Merged PR 2441: [Feat] Add option: `reducer_pre_divisor` --- nnscaler/cli/trainer.py | 4 ++-- nnscaler/cli/trainer_args.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index c5d8289e..07a43a3f 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -226,7 +226,7 @@ def _setup(self): # (see `train_args.optimizer.grad_reduction`` handling in `train_epoch`). # This is useful to avoid overflow when the gradients are large. def reducer_pre_hook(reducer, grad): - grad.div_(self.train_args.scaling_factor) + grad.div_(self.train_args.optimizer.grad_reduce_divisor or self.train_args.scaling_factor) self.optimizer.register_reducer_pre_hook(reducer_pre_hook) # Currently we never pass `last_epoch` to its constructor self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) @@ -1057,7 +1057,7 @@ def _train_epoch(self, epoch: int) -> None: self.hook.after_sync_grad(self) # scale gradients - multiplier = self.train_args.scaling_factor + multiplier = self.train_args.optimizer.grad_reduce_divisor or self.train_args.scaling_factor if self.train_args.optimizer.grad_reduction == 'sum': # do nothing. `multiplier` is already correct pass diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 5cd811fa..20fdda77 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -337,6 +337,13 @@ class OptimizerConfig: # per-token-mean: average the gradients over all tokens # you must specify `aggregate_outputs_fn` and return the number of tokens grad_reduction: str = 'mean' + # the divisor applied to gradients before all-reduce. If not set, the default + # divisor is `runtime_ngpus / plan_ngpus`. We divide the gradients to avoid overflow. + # However, if the gradients are in high precision or the user has known the range of + # the gradients, he/she can set a smaller divisor to improve the accuracy. Note that + # the gradients will be recovered by multiplying the divisor after all-reduce and before + # optimizer step. + grad_reduce_divisor: Optional[float] = None # the function to aggregate the outputs from all micro-batches # inputs: (list of local outputs, torch group) # output: AggregateOutputs From 4a78f4af9f4eb82eaa36f506dd5a2a5f67dbe09f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 5 Jan 2026 03:08:11 +0000 Subject: [PATCH 1869/1892] Merged PR 2447: [Refine] move dummy_input from trainer to trainer args and polish checkpointer 1. Move dummy_input from trainer to trainer args so `ModuleParallelizeConfig.forward_args_gen_fn` have a way to create args from dummy input 2. Polish serializer to better support async checkpointer and blobfuse storage (path.unlinke is not reliable in blobfuse). --- nnscaler/cli/serialization.py | 41 +++++++++++++++++------ nnscaler/cli/train_hook.py | 17 ++++++++++ nnscaler/cli/trainer.py | 59 ++++++++++++++------------------- nnscaler/cli/trainer_args.py | 31 +++++++++++++++-- tests/cli/test_serialization.py | 6 ++++ 5 files changed, 106 insertions(+), 48 deletions(-) diff --git a/nnscaler/cli/serialization.py b/nnscaler/cli/serialization.py index c499740f..3fb281f9 100644 --- a/nnscaler/cli/serialization.py +++ b/nnscaler/cli/serialization.py @@ -4,12 +4,17 @@ from typing import Any, Callable, Protocol, Type from pathlib import Path import shutil +import time +import logging import torch from nnscaler.runtime.serialization import load, save +logger = logging.getLogger(__name__) + + class _LoadProc(Protocol): def __call__(self, f: str | Path, *, device='cpu') -> Any: ... @@ -300,14 +305,12 @@ def remove_for_rank(self, dir: str | Path, rank: int) -> None: rank (`int`): The rank of the checkpoint file to remove. """ - self.flush() - - for suffix in self.NAME_MAP.values(): + suffixes = set(list(self.NAME_MAP.values()) + [self.suffix]) + for suffix in suffixes: f = Path(dir) / f"{rank}{suffix}" - if f.exists(): - f.unlink() + f.unlink(missing_ok=True) for extra_file in Path(dir).glob(f"{rank}{suffix}.*"): - extra_file.unlink() + extra_file.unlink(missing_ok=True) def copy_for_rank(self, src: str | Path, dst: str | Path, rank: int, symlink: bool = False) -> None: """ @@ -322,8 +325,8 @@ def copy_for_rank(self, src: str | Path, dst: str | Path, rank: int, symlink: bo symlink (`bool`, *optional*, defaults to `False`): Whether to create a symbolic link instead of copying the file. """ + self.remove_for_rank(dst, rank) - self.flush() src = Path(src).resolve() dst = Path(dst).resolve() dst.mkdir(parents=True, exist_ok=True) @@ -341,16 +344,34 @@ def copy_for_rank(self, src: str | Path, dst: str | Path, rank: int, symlink: bo raise ValueError("Cannot create symlink when source and destination are not in the same directory.") if symlink: - dst_f.symlink_to(Path('..') / src.name / src_f.name) + self._create_symlink_with_retry(Path('..') / src.name / src_f.name, dst_f) for extra_file in src.glob(f"{rank}{self.suffix}.*"): dst_extra_file = Path(dst) / extra_file.name - dst_extra_file.symlink_to(Path('..') / src.name / extra_file.name) + self._create_symlink_with_retry(Path('..') / src.name / extra_file.name, dst_extra_file) else: shutil.copy2(src_f, dst_f) for extra_file in src.glob(f"{rank}{self.suffix}.*"): dst_extra_file = Path(dst) / extra_file.name shutil.copy2(extra_file, dst_extra_file) + @classmethod + def _create_symlink_with_retry(cls, src: str | Path, dst: str | Path) -> None: + dst = Path(dst) + dst.unlink(missing_ok=True) + + # deletion in blobfuse is not immediate sometimes + # so we retry until success + while True: + try: + dst.symlink_to(Path(src)) + break + except FileExistsError: + logger.warning(f"Creating symlink {dst} failed. Retrying...") + dst.unlink(missing_ok=True) + time.sleep(0.1) + + logger.info(f"Symlink {dst} created.") + def list_checkpoints(self, dir: str | Path) -> list[Path]: """ List the main checkpoint files in a directory @@ -361,8 +382,6 @@ def list_checkpoints(self, dir: str | Path) -> list[Path]: (`list[Path]`): The list of checkpoint files in the directory. """ - self.flush() - p = Path(dir) files = [] for suffix in self.NAME_MAP.values(): diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index 7848ae31..afccf2da 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from typing import Any, Dict, List, TYPE_CHECKING, Literal, TypedDict, Optional +from pathlib import Path import torch @@ -213,6 +214,18 @@ def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> checkpoint: the checkpoint to be saved """ + def on_expire_checkpoint(self, trainer: 'Trainer', step: int, checkpoint_dir: Path) -> None: + """ + Called before expiring (deleting) checkpoint. + If you want to do something before a checkpoint is deleted, you can do it here. + + Note: only local-rank 0 will call this hook. + + Args: + step: the overall training step of the checkpoint to be expired + checkpoint_dir: the directory that holds the checkpoint to be expired + """ + class AggregatedTrainHook(TrainHook): def __init__(self, hooks: List[TrainHook]): @@ -334,6 +347,10 @@ def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> for hook in self.hooks: hook.on_save_checkpoint(trainer, checkpoint) + def on_expire_checkpoint(self, trainer: 'Trainer', step: int, checkpoint_dir: Path) -> None: + for hook in self.hooks: + hook.on_expire_checkpoint(trainer, step, checkpoint_dir) + class TrainHookHost: def _get_hook_objects(self) -> List[Any]: diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 07a43a3f..a781a2c7 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -119,19 +119,6 @@ def run(self): def _fix_input(self, input): return fix_input(input, self.train_args.input_dtype) - def _load_dummy_input(self): - if dummy_sample_gen_fn := self.train_args.dummy_sample_gen_fn: - return dummy_sample_gen_fn(self.train_args) - - with enforce_zero_num_worker(DataLoader): - dataset = self.train_args.create_dataset('train') - dataloader = self.train_args.create_dataloader('train', dataset) - assert dataloader.num_workers == 0, "The dataloader must have `num_workers=0`." - value = next(iter(dataloader)) - if close_fn := getattr(dataloader, 'close', None): - close_fn() - return value - def _setup(self): if is_running_distributed(): nnscaler.init() @@ -150,10 +137,7 @@ def _setup(self): compile_only = self.train_args.compile_mode # load a dummy input from training dataset - self.dummy_input = self._load_dummy_input() - self.dummy_input = self._fix_input(self.dummy_input) - if self.train_args.dummy_sample_post_process_fn: - self.dummy_input = self.train_args.dummy_sample_post_process_fn(self.train_args, self.dummy_input) + self.dummy_input = self.train_args.dummy_input pmodel = parallelize_model( self.train_args, self.dummy_input, @@ -236,6 +220,7 @@ def reducer_pre_hook(reducer, grad): self.model, self.optimizer, self.lr_scheduler, + self.checkpointer, ] component_hooks = [] for component in supported_hook_components: @@ -661,11 +646,6 @@ def _save_checkpoint(self, loss): if checkpoint_config.save_last: logger.info(f"Saving checkpoint as the last checkpoint.") - # remove the old symlink or file - self.checkpointer.remove_for_rank( - save_dir / self.checkpointer.get_last_dir_name(), - self.rank - ) self.checkpointer.copy_for_rank( ckpt_file.parent, save_dir / self.checkpointer.get_last_dir_name(), @@ -678,11 +658,6 @@ def _save_checkpoint(self, loss): logger.info(f"Best loss updated: {self.train_status.best_loss:.3f} -> {loss:.3f}") logger.info(f"Saving checkpoint as the best checkpoint.") - # remove the old symlink or file - self.checkpointer.remove_for_rank( - save_dir / self.checkpointer.get_best_dir_name(), - self.rank - ) self.checkpointer.copy_for_rank( ckpt_file.parent, save_dir / self.checkpointer.get_best_dir_name(), @@ -701,6 +676,14 @@ def _save_checkpoint(self, loss): torch.distributed.barrier() + @classmethod + def _get_dependent_dirs(cls, ckpt_dir): + target_dirs = set() + for p in Path(ckpt_dir).glob('*'): + if p.is_symlink(): + target_dirs.add(p.resolve().parent.name) + return target_dirs + def _expire_checkpoints(self): if not self.train_args.checkpoint.keep_last_n_checkpoints: # keep all return @@ -718,6 +701,8 @@ def _expire_checkpoints(self): # (step, ckpt_name) pairs checkpoint_info = [(int(p.split('-')[1]), p) for p in checkpoints] + # map from ckpt_name to step + checkpoint_info_map = {p[1]: p[0] for p in checkpoint_info} checkpoint_info.sort() expire_list = [c[1] for c in checkpoint_info[:-self.train_args.checkpoint.keep_last_n_checkpoints]] @@ -726,17 +711,21 @@ def _expire_checkpoints(self): for ckpt_dir in [best_ckpt, last_ckpt]: if not ckpt_dir.exists(): continue - for p in self.checkpointer.list_checkpoints(ckpt_dir): - if p.is_symlink(): - ckpt_name = p.resolve().parent.name - if ckpt_name in expire_list: - expire_list.remove(ckpt_name) - logger.info('Keep old checkpoint `%s` because it is symbol linked in best or last.', ckpt_name) - break # just check the first file is enough + for ckpt_name in self._get_dependent_dirs(ckpt_dir): + if ckpt_name in expire_list: + expire_list.remove(ckpt_name) + logger.info('Keep old checkpoint `%s` because it is symbol linked in best or last.', ckpt_name) for ckpt_name in expire_list: logger.info('Removing old checkpoint: %s', ckpt_name) - shutil.rmtree(save_dir / ckpt_name) + self.hook.on_expire_checkpoint(self, checkpoint_info_map[ckpt_name], save_dir / ckpt_name) + try: + shutil.rmtree(save_dir / ckpt_name) + except FileNotFoundError: + # may have been removed by other processes (when the storage is shared) + pass + except Exception as e: + logger.warning('Error when expiring checkpoint `%s`: %s. Will try later.', ckpt_name, e) def _global_batch_iterator(self, num_skip_first=0, stage='train'): if stage == 'train': diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 20fdda77..bfa86cf7 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -15,12 +15,12 @@ import torch import torch.utils import torch.utils.data -import torch.utils.data.dataloader +from torch.utils.data.dataloader import DataLoader import yaml import torch import nnscaler -from nnscaler.utils import fields, fn_field, transform_recursively, load_type, copy_dynamic +from nnscaler.utils import enforce_zero_num_worker, fields, fn_field, transform_recursively, load_type, copy_dynamic from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule @@ -615,6 +615,7 @@ class HookMapConfig: on_load_checkpoint: str = None after_load_checkpoint: str = None on_save_checkpoint: str = None + on_expire_checkpoint: str = None class ArgsTrainHook(TrainHook): @@ -794,6 +795,10 @@ def __post_init__(self): ) self._vars = self.create_kwarg(self.vars) + # will be initialized lazily + # because it is heavy, and may not be used in some cases + # and it looks weird to initialize it eagerly in __post_init__ + self._dummy_input = None @classmethod def from_cli(cls, argv: List[str]) -> 'TrainerArgs': @@ -924,6 +929,28 @@ def get_resolved_var(self, fqn: str, *, default: Any = None) -> Any: var = var[part] return var + @property + def dummy_input(self): + if self._dummy_input is None: + self._dummy_input = self._load_dummy_input() + self._dummy_input = fix_input(self._dummy_input, self.input_dtype) + if self.dummy_sample_post_process_fn: + self._dummy_input = self.dummy_sample_post_process_fn(self, self._dummy_input) + return self._dummy_input + + def _load_dummy_input(self): + if self.dummy_sample_gen_fn: + return self.dummy_sample_gen_fn(self) + + with enforce_zero_num_worker(DataLoader): + dataset = self.create_dataset('train') + dataloader = self.create_dataloader('train', dataset) + assert dataloader.num_workers == 0, "The dataloader must have `num_workers=0`." + value = next(iter(dataloader)) + if close_fn := getattr(dataloader, 'close', None): + close_fn() + return value + def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model.args) return self.model_type(**kwargs) diff --git a/tests/cli/test_serialization.py b/tests/cli/test_serialization.py index cd5ae9b4..e4120cfd 100644 --- a/tests/cli/test_serialization.py +++ b/tests/cli/test_serialization.py @@ -137,6 +137,8 @@ def run_save(self, save_func, state_dict, f): '--checkpoint.symlink_best_and_last', str(symblink), ]) trainer.run() + torch.distributed.barrier() + ckpt_files = list_ckpt_files(ckpt_savedir) assert len(ckpt_files)/4 == min(10, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last @@ -169,6 +171,7 @@ def run_save(self, save_func, state_dict, f): ]) trainer.run() + torch.distributed.barrier() # create merged checkpoint ckpt1_savedir = save_dir / 'ckpt1' ckpt1_savedir.mkdir(parents=True, exist_ok=True) @@ -221,6 +224,9 @@ def run_save(self, save_func, state_dict, f): '--checkpoint.symlink_best_and_last', str(symblink), ]) trainer.run() + + torch.distributed.barrier() + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) torch.distributed.barrier() From 442b2438c52f583cb38a0b8a6033bdbdaf53329d Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 5 Jan 2026 07:54:13 +0000 Subject: [PATCH 1870/1892] Merged PR 2444: [Tracer] Fix for torch 2.8 --- nnscaler/graph/tracer/torch_fx_patcher.py | 35 ++++++++++++++++++++++- requirements.txt | 2 +- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/nnscaler/graph/tracer/torch_fx_patcher.py b/nnscaler/graph/tracer/torch_fx_patcher.py index 8affab89..af358275 100644 --- a/nnscaler/graph/tracer/torch_fx_patcher.py +++ b/nnscaler/graph/tracer/torch_fx_patcher.py @@ -191,7 +191,7 @@ def format_import_statement_new(name: str, obj: Any, importer) -> str: return TorchFXPatcher.format_import_statement_ori(name, obj, importer) @staticmethod - def is_impure_new(node: fx_node.Node): + def is_impure_new(node: fx_node.Node, impure_random: bool = True) -> bool: """ Returns whether this op is impure, i.e. if its op is a placeholder or output, or if a call_function or call_module which is impure. @@ -208,6 +208,39 @@ def is_impure_new(node: fx_node.Node): # Check if an impure function. if node.op == "call_function": + schema = getattr(node.target, "_schema", None) + if schema is not None and schema.is_mutable: + # impure since it mutates inputs + return True + + if impure_random: + if getattr(node.target, "_nondeterministic_seeded", False): + # impure since it mutates RNG state + return True + + # Handle Python random functions that don't have _nondeterministic_seeded + # but still affect global RNG state (issue #151524) + # These should be impure regardless of impure_random setting to maintain + # consistency between eager and compiled execution + _random_functions = { + torch.rand, + torch.randn, + torch.randint, + torch.randperm, + torch.rand_like, + torch.randn_like, + torch.randint_like, + torch.normal, + torch.poisson, + torch.bernoulli, + torch.multinomial, + } + + if node.target in _random_functions: + # All random operations are impure to ensure consistent behavior + # between eager and compiled execution, regardless of generator usage + return True + return node.target in _side_effectful_functions # NOTE by nnscaler: we assume all method end with "_" is inplace operation, diff --git a/requirements.txt b/requirements.txt index 41efcd26..0d337d5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,6 @@ psutil pulp pybind11<3.0.0 pyyaml -torch>=2.0,<=2.6 +torch>=2.0,<=2.8 tqdm safetensors From 5a902bca374030808014fcecbbaef15926aac693 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 5 Jan 2026 07:58:54 +0000 Subject: [PATCH 1871/1892] Merged PR 2452: [HotFix] Merge z3 model states correctly --- nnscaler/cli/trainer.py | 2 +- nnscaler/parallel.py | 2 +- nnscaler/policies.py | 4 ++- nnscaler/profiler/database.py | 5 ++- nnscaler/runtime/module.py | 57 +++++++++++++++++++++-------------- tests/cli/test_trainer.py | 6 ++++ 6 files changed, 50 insertions(+), 26 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index a781a2c7..48e8730d 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -606,7 +606,7 @@ def _save_checkpoint(self, loss): current_epoch -= 1 if checkpoint_config.save_type == 'sharded': - model_state_dict= self.model.state_dict() + model_state_dict = self.model.state_dict() optimizer_state_dict = self.optimizer.state_dict() elif checkpoint_config.save_type == 'deduped': model_state_dict, optimizer_state_dict = nnscaler.deduped_state_dict( diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 6dd57aaf..bda51909 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1846,7 +1846,7 @@ def _sort_state_dicts(state_dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]] module_prefix = '.'.join(k) opt_state_dicts_for_merge = None if opt_state_dicts is None else opt_state_dicts[module_prefix] - merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks, e.zero) for e in extra_states] + merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks, e.zero, e.zero3_param_metadata) for e in extra_states] if not extra_states[0].compute_config.use_zero: # all ranks should have the same use_zero merge_partial_states_zero_idx_maps = None merged_state_dict, merged_opt_state_dict = ParallelModule.merge_state_dicts( diff --git a/nnscaler/policies.py b/nnscaler/policies.py index b4cf13af..2768e0a6 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -100,6 +100,8 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): random tensor parallelism inside a scale unit, and dp across scale units """ ngpus = cfg.plan_ngpus + pas_cfg = cfg.pas_config + enable_random_replicated = pas_cfg.get('enable_random_replicated', False) # get the current random state state = random.getstate() @@ -114,7 +116,7 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): continue if isinstance(node, IRDimops): configs = node.transform_space() - if len(configs) == 0: + if len(configs) == 0 or (enable_random_replicated and random.random() < 0.5): _replica(graph, node, devs) else: configs = sorted(configs, reverse=True, diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 27c7f54d..61597b2a 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -511,7 +511,10 @@ def load_ops(self, folder: str): if filename.endswith('.json'): with open(os.path.join(folder, filename)) as f: signature = filename[:-len('.json')] - loaded_json = json.load(f) + try: + loaded_json = json.load(f) + except json.JSONDecodeError: + raise RuntimeError(f'fail to load profiling data from {filename}, please check the file content') self._data[signature] = {key: ProfiledMetrics(**value) for key, value in loaded_json.items()} def __repr__(self) -> str: diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 9d286489..05b9ea03 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -434,7 +434,8 @@ def _safe_tensor_equal(cls, tensor1: Any, tensor2: Any): @staticmethod def merge_model_state_dicts( state_dicts: List[Dict], - fullmaps: List[Dict[str, AttrMeta]] + fullmaps: List[Dict[str, AttrMeta]], + zero_idx_maps: Optional[List[Dict]] = None ): """Merge model states from multiple shard into a single-model state. Here we assume the order of state_dicts and fullmaps are aligned, and is the same as the rank order. @@ -446,6 +447,7 @@ def merge_model_state_dicts( Args: state_dicts (List[Dict[str, torch.Tensor]]): per-rank local model state dict from model.state_dict() fullmaps (List[Dict[str, AttrMeta]]): per-rank fullmap + zero_idx_maps (Optional[List[Dict]]): zero information for the model, `None` if zero is not enabled Returns: full_state_dicts (List[Dict[str, torch.Tensor]]): Full model state dict @@ -461,11 +463,12 @@ def merge_model_state_dicts( state_dict_merge_track: Dict[str, Set[Tuple[Tuple[Any, Any, Any], ...]]] = {} # the fill progress of zero3 parameters # key: param name - # value: Dict[ tuple(start, step, stop) , filled size] + # value: Dict[ tuple(start, step, stop) , filled chunk] # used to track how many elements have been filled for each zero3 parameter - zero3_current_filled: Dict[str, Dict[Tuple[Tuple[int, int, int], ...], int]] = {} + zero3_current_filled: Dict[str, Dict[Tuple[Tuple[int, int, int], ...], List[Tuple[int, int]]]] = {} + zero3_param_metadatas = [info[-1] for info in zero_idx_maps] if zero_idx_maps is not None else [None] * len(state_dicts) # gather param/buffer full tensor - for rank, (model_state_dict, local_fullmap) in enumerate(zip(state_dicts, fullmaps)): + for rank, (model_state_dict, local_fullmap, zero3_param_metadata) in enumerate(zip(state_dicts, fullmaps, zero3_param_metadatas)): for local_name, meta in local_fullmap.items(): if local_name not in model_state_dict: # the parameter may not in model_state_dict (deduped with optimization) @@ -496,27 +499,33 @@ def merge_model_state_dicts( # we assume zero3 is on when dest_tensor.shape != partial_tensor.shape if len(partial_tensor.shape) != 1: raise ValueError("Invalid tensor as a ZeRO3 parameter, expected a 1D tensor.") - fill_start = zero3_current_filled.setdefault(meta.orig_name, {}).setdefault(state_dict_merge_track_id, 0) - fill_len = partial_tensor.numel() - if fill_start >= dest_tensor.numel(): + curr_filled = zero3_current_filled.setdefault(meta.orig_name, {}).setdefault(state_dict_merge_track_id, []) + curr_z3_info = zero3_param_metadata[local_name] + curr_start, curr_end = curr_z3_info['start'], curr_z3_info['end'] + fill_len = curr_end - curr_start + if (curr_start, curr_end) in curr_filled: # already filled, let's check consistency - fill_start = fill_start % dest_tensor.numel() - if fill_start + fill_len > dest_tensor.numel(): - # remove padding part - fill_len = dest_tensor.numel() - fill_start - if not CubeModule._safe_tensor_equal(dest_tensor.view(-1)[fill_start: fill_start + fill_len], partial_tensor[0:fill_len]): + if not CubeModule._safe_tensor_equal(dest_tensor.view(-1)[curr_start: curr_end], partial_tensor[0:fill_len]): raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") else: - if fill_start + fill_len > dest_tensor.numel(): - # remove padding part - fill_len = dest_tensor.numel() - fill_start old_shape = dest_tensor.shape dest_tensor = dest_tensor.reshape(-1) - dest_tensor[fill_start: fill_start + fill_len] = partial_tensor[0: fill_len] + dest_tensor[curr_start: curr_end] = partial_tensor[0: fill_len] full_model_state_dict[meta.orig_name][meta.slicers] = dest_tensor.view(old_shape) - - zero3_current_filled[meta.orig_name][state_dict_merge_track_id] += fill_len - + zero3_current_filled[meta.orig_name][state_dict_merge_track_id].append((curr_start, curr_end)) + + if zero3_current_filled: + # verify all zero3 parameters are fully filled + for param_name, slicers2filled in zero3_current_filled.items(): + for slicers, filled_chunks in slicers2filled.items(): + full_size = 1 + for s in slicers: + full_size *= s[-1] - s[0] + covered_size = 0 + for start, end in filled_chunks: + covered_size += end - start + if full_size != covered_size: + raise ValueError(f'Uncovered ZeRO3 parameter {param_name} with slicers {slicers}, full size {full_size}, covered size {covered_size}') return full_model_state_dict @staticmethod @@ -599,7 +608,7 @@ def merge_state_dicts( # help understand the whole logic. In other words, the real plan_ngpus is <= len(model_state_dicts). plan_ngpus = len(model_state_dicts) # gather model states - full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps[0: plan_ngpus]) + full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps[0: plan_ngpus], zero_idx_maps) _logger.info('finish merge model states') if optim_state_dicts is None: return full_model_state_dict, None @@ -704,7 +713,9 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size, return opt_states def _merge_opt_zero(param_shape, worker_idx, param_idx): - if len(zero_idx_maps[worker_idx]) == 3: + if len(zero_idx_maps[worker_idx]) == 4: + model_idx2opt_idx, opt_idx2ranks, zero_version, _ = zero_idx_maps[worker_idx] + elif len(zero_idx_maps[worker_idx]) == 3: # backward compatibility model_idx2opt_idx, opt_idx2ranks, zero_version = zero_idx_maps[worker_idx] else: # backward compatibility assert len(zero_idx_maps[worker_idx]) == 2 @@ -966,6 +977,7 @@ class ZeroMetadata: # 1: zero1 # > 1: zero3 zero: int = 0 + zero3_param_metadata: Optional[Dict[str, Dict]] = None @dataclass @@ -1540,7 +1552,8 @@ def _get_zero_metadata(self) -> ZeroMetadata: return ZeroMetadata( model_idx2opt_idx=model_idx2opt_idx, opt_idx2ranks=opt_idx2ranks, - zero=self.compute_config.use_zero + zero=self.compute_config.use_zero, + zero3_param_metadata=self._zero3_param_metadata, ) def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 307adbef..4e3ffea5 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -1258,6 +1258,7 @@ def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): '--compute_config.runtime_ngpus', f'{runtime_ngpus}', '--compute_config.zero_ngroups', f'{zero_ngroups}', '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', '--checkpoint.save_dir', str(ckpt3_savedir), '--pas_policy', f'{policy}', '--checkpoint.save_type', 'sharded', @@ -1276,6 +1277,7 @@ def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): '--compute_config.runtime_ngpus', f'{runtime_ngpus}', '--compute_config.zero_ngroups', f'{zero_ngroups}', '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', '--checkpoint.save_dir', str(ckpt3_savedir), '--checkpoint.resume_from', 'last', '--pas_policy', f'{policy}', @@ -1293,6 +1295,7 @@ def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): '--compute_config.runtime_ngpus', f'{runtime_ngpus}', '--compute_config.zero_ngroups', f'{zero_ngroups}', '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', '--checkpoint.save_dir', str(ckpt3_savedir), '--checkpoint.resume_from', 'last', '--pas_policy', f'{policy}', @@ -1315,6 +1318,7 @@ def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): '--compute_config.runtime_ngpus', f'{runtime_ngpus}', '--compute_config.zero_ngroups', f'{zero_ngroups}', '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', '--checkpoint.save_dir', str(ckpt3_savedir), '--checkpoint.resume_from', str(ckpt3_savedir / 'merged.pt'), '--pas_policy', f'{policy}', @@ -1350,6 +1354,7 @@ def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): '--compute_config.runtime_ngpus', f'{runtime_ngpus}', '--compute_config.zero_ngroups', f'{zero_ngroups}', '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', '--checkpoint.save_dir', str(ckpt3_savedir), '--checkpoint.resume_from', str(ckpt3_savedir / 'merged.pt'), '--pas_policy', f'{policy}', @@ -1371,6 +1376,7 @@ def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): '--compute_config.runtime_ngpus', f'{runtime_ngpus}', '--compute_config.zero_ngroups', f'{zero_ngroups}', '--compute_config.use_zero', '1', + '--compute_config.pas_config.enable_random_replicated', 'True', '--checkpoint.save_dir', str(ckpt1_savedir), '--pas_policy', f'{policy}', '--checkpoint.save_type', 'sharded', From 527c5b7bc096a443d1096b54a5663e5006cb2a77 Mon Sep 17 00:00:00 2001 From: nnScaler Date: Mon, 5 Jan 2026 19:21:20 +0800 Subject: [PATCH 1872/1892] nit fix --- README.md | 1 + docs/source/quickstart.rst | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4062c2fa..1b99ef0d 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ nnScaler is a parallelization engine that compiles a Deep neural network (DNN) m # Latest News nnScaler (also known as CUBE as code name) has been adopted by multiple product and research projects, this section includes some of the latest news from the team and partner projects. +* **2025-08-12** nnScaler 0.8 released: https://github.com/microsoft/nnscaler/releases/tag/0.8 * **2025-02-12** nnScaler 0.7 released: https://github.com/microsoft/nnscaler/releases/tag/0.7 * **2024-10-07** Diff-Transformer utilizes nnScaler for differential attention mechanism: [DIFFERENTIAL TRANSFORMER](https://arxiv.org/abs/2410.05258) * **2024-05-09** YOCO utilizes nnScaler for long-sequence training: [(YOCO)You only cache once: Decoder-decoder architectures for language models](https://arxiv.org/abs/2405.05254) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 85eb5059..5179d9b5 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -10,7 +10,7 @@ nnScaler can be installed from GitHub: .. code-block:: bash - pip install https://github.com/microsoft/nnscaler/releases/download/0.7/nnscaler-0.7-py3-none-any.whl + pip install https://github.com/microsoft/nnscaler/releases/download/0.8/nnscaler-0.8-py3-none-any.whl # You may also want to clone the repo to try out the examples git clone --recursive https://github.com/microsoft/nnscaler From 8ba2509c1be6194a5d92717ab14ba97f9bdb0676 Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Wed, 14 Jan 2026 14:27:43 +0800 Subject: [PATCH 1873/1892] save work --- nnscaler/parallel.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index bda51909..b7bdcc79 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -180,9 +180,13 @@ def __post_init__(self): if self.use_zero: if num_scale_units % self.zero_ngroups != 0: raise ValueError(f"zero_ngroups {self.zero_ngroups} must be a divisor of runtime_ngpus/plan_ngpus {num_scale_units}.") - if num_scale_units == self.zero_ngroups: - logger.warning(f"zero_ngroups {self.zero_ngroups} equals to runtime_ngpus/plan_ngpus {num_scale_units}. Zero optimization is disabled.") - super().__setattr__('use_zero', 0) + # NOTE: + # we can't disable zero optimization when num_scale_units == zero_ngroups here + # because some ops are replicated inside a scale unit, + # and those ops can still utilize zero optimization. + # if num_scale_units == self.zero_ngroups: + # logger.warning(f"zero_ngroups {self.zero_ngroups} equals to runtime_ngpus/plan_ngpus {num_scale_units}. Zero optimization is disabled.") + # super().__setattr__('use_zero', 0) if self.use_zero and self.zero_ngroups <= 0: raise ValueError(f"zero_ngroups {self.zero_ngroups} must be > 0") From 8f5366fdaf6a12bc60dd230e9569859a4bb99d4a Mon Sep 17 00:00:00 2001 From: XU Weijiang <90586345+0xWJ@users.noreply.github.com> Date: Thu, 15 Jan 2026 08:42:23 +0800 Subject: [PATCH 1874/1892] [Refine] Improve generated file broadcast with multithread writing and small file batching (#2) 1. batch small files for better performance, especially when the world size is big (a lot of small files will be generated) 2. Multi-threading IO has better performance on SSD, especially on NVMe SSD with PCIe. 3. [TODO] Chunking (+ pin memory) can have an even better performance for huge files. Currently the max size of weight files in default setting is 1b*dtype_size (2GB for bf16). This hasn't done in this PR. --------- Co-authored-by: Hangbo Bao <10023639+addf400@users.noreply.github.com> Co-authored-by: addf400 --- nnscaler/parallel.py | 94 +++++++++--------- nnscaler/utils.py | 108 +++++++++++++++++++++ tests/conftest.py | 9 ++ tests/parallel_module/test_broadcast.py | 123 +++++++++++++----------- 4 files changed, 230 insertions(+), 104 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index b7bdcc79..ccab6a8c 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -56,7 +56,8 @@ setup_stride_broadcast_group, get_shared_params, OptStateDict, - copy_dynamic + copy_dynamic, + broadcast_files, ) logger = logging.getLogger(__name__) @@ -2373,56 +2374,49 @@ def _broadcast_gen_files( return curr_rank = torch.distributed.get_rank() - ranks = list(range(0, world_size, local_world_size)) - group = DeviceGroup().get_group(ranks) - - # use the first rank of each node to broadcast - if curr_rank % local_world_size == 0: - _, outdir = _prepare_namespace(gen_savedir, module_class, instance_name) - files: List[str] = [] - # send file list - if curr_rank == 0: - for file in outdir.glob('*'): - if file.is_file() and ( - broadcast_strategy == BroadcastGenFilesStrategy.ALL or - ( - broadcast_strategy == BroadcastGenFilesStrategy.NO_WEIGHTS - and not file.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) - ) or - ( - # broadcast code files and compute config file - # please note the compute config file can be updated - # even when the graph is reused. - broadcast_strategy == BroadcastGenFilesStrategy.CODE - and (file.suffix == '.py' or file.name == ParallelModule.COMPUTE_CONFIG_FILE) - ) - ): - files.append(file.name) - sent_obj = [files] + + # use all ranks of each node to broadcast + _, outdir = _prepare_namespace(gen_savedir, module_class, instance_name) + files: List[str] = [] + # send file list + if curr_rank == 0: + for file in outdir.glob('*'): + if file.is_file() and ( + broadcast_strategy == BroadcastGenFilesStrategy.ALL or + ( + broadcast_strategy == BroadcastGenFilesStrategy.NO_WEIGHTS + and not file.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) + ) or + ( + # broadcast code files and compute config file + # please note the compute config file can be updated + # even when the graph is reused. + broadcast_strategy == BroadcastGenFilesStrategy.CODE + and (file.suffix == '.py' or file.name == ParallelModule.COMPUTE_CONFIG_FILE) + ) + ): + files.append(file.name) + sent_obj = [files] + else: + sent_obj = [None] + torch.distributed.broadcast_object_list( + sent_obj, + src=0, + ) + # get file list + if curr_rank != 0: + files = sent_obj[0] + + logger.info(f'File list broadcasted ({len(files)} in total).') + + grouped_files = [[]] # 0th groups for small files (attribute content files excluded) + for fname in files: + if not fname.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM): + grouped_files[0].append(outdir / fname) else: - sent_obj = [None] - torch.distributed.broadcast_object_list( - sent_obj, - src=0, - group=group, - ) - # get file list - if curr_rank != 0: - files = sent_obj[0] - - logger.info(f'File list broadcasted ({len(files)} in total).') - # send file content one by one - for fname in files: - if curr_rank == 0: - with open(outdir / fname, 'rb') as f: - data = [f.read()] - else: - data = [None] - torch.distributed.broadcast_object_list(data, src=0, group=group) - if curr_rank != 0: - with open(outdir / fname, 'wb') as f: - f.write(data[0]) - logger.info(f'File {fname} broadcasted.') + grouped_files.append([outdir / fname]) + + broadcast_files(grouped_files) # wait for all nodes to finish torch.distributed.barrier() diff --git a/nnscaler/utils.py b/nnscaler/utils.py index a88f4997..7f87b6bd 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -16,12 +16,17 @@ from dataclasses import dataclass, field import inspect import os +import warnings +from concurrent.futures import ThreadPoolExecutor +import itertools +import numpy as np import nnscaler from nnscaler.flags import RuntimeFlag, CompileFlag import torch + _logger = logging.getLogger(__name__) @@ -682,3 +687,106 @@ def get_dynamic(tensor: Any) -> set[int]: return getattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, set()) else: return getattr(tensor, NNSCALER_DYNAMIC_DIMS_NAME, set()) + + +@contextmanager +def suppress_warnings(message): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message=message) + yield + + +def broadcast_files( + file_groups: List[List[Union[str, Path]]], + *, + max_workers: int = 8, +): + """Broadcast files from src to all other nodes. Files are grouped into file_groups, + and each group of files are broadcasted together to get better performance. + + Args: + files (List[List[str | Path]]): List of file groups to be broadcasted. + Note that the file names should be the same across all ranks. + """ + from nnscaler.runtime.device import DeviceGroup + + # filter out empty file groups + file_groups = [ + fg for fg in file_groups if fg + ] + + curr_rank = torch.distributed.get_rank() + local_world_size = DeviceGroup().local_world_size + world_size = torch.distributed.get_world_size() + local_rank = curr_rank % local_world_size + + # create groups, make sure all groups are created correctly + for i in range(local_world_size): + group_ranks = list(range(i, world_size, local_world_size)) + DeviceGroup().get_group(group_ranks) + + # collect file sizes and broadcast + if curr_rank == 0: + file_group_sizes: List[List[int]] = [ + [os.path.getsize(file) for file in files] for files in file_groups + ] + exchange_objects = [file_group_sizes] + else: + exchange_objects = [None] + + torch.distributed.broadcast_object_list(exchange_objects, src=0) + file_group_sizes = exchange_objects[0] + + # sort file_groups by size descending to improve overlapping + file_groups_sizes_pairs = list(zip(file_groups, file_group_sizes)) + file_groups_sizes_pairs.sort(key=lambda x: sum(x[1]), reverse=True) + file_groups = [pair[0] for pair in file_groups_sizes_pairs] + file_group_sizes = [pair[1] for pair in file_groups_sizes_pairs] + + def _write_file(file: Union[str, Path], buffer, start, size): + _logger.info(f'Rank {curr_rank}: Writing file {file} of size {size} bytes.') + # have better performance than open + write + buffer[start: start + size].numpy().tofile(file) + + def _read_file(file, buffer, start, size): + _logger.info(f'Rank {curr_rank}: Reading file {file} of size {size} bytes.') + # slightly faster than open + read + buffer[start: start + size] = torch.from_numpy(np.fromfile(file, dtype=np.uint8)) + + def _write_files(buffer, files, file_sizes): + buffer = buffer.cpu() + file_starts = itertools.accumulate([0] + file_sizes[:-1]) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + executor.map( + lambda args: _write_file(args[0], buffer, args[1], args[2]), + zip(files, file_starts, file_sizes) + ) + + def _send_file_group(src, files, file_sizes): + total_size = sum(file_sizes) + + ranks = list(range(src, world_size, local_world_size)) + group = DeviceGroup().get_group(ranks) + file_buffer = torch.empty(total_size, dtype=torch.uint8, device='cpu').pin_memory() + + if curr_rank < local_world_size: + file_starts = itertools.accumulate([0] + file_sizes[:-1]) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + executor.map( + lambda args: _read_file(args[0], file_buffer, args[1], args[2]), + zip(files, file_starts, file_sizes) + ) + broadcast_tensor = file_buffer.cuda() + else: + broadcast_tensor = torch.empty(total_size, dtype=torch.uint8, device='cuda') + + torch.distributed.broadcast(broadcast_tensor, src=src, group=group) + + if curr_rank >= local_world_size: + file_buffer.copy_(broadcast_tensor) + _write_files(file_buffer, files, file_sizes) + + # we split the file groups among local ranks + # each local rank sends its assigned file groups (in round robin fashion) + for i in range(local_rank, len(file_groups), local_world_size): + _send_file_group(local_rank, file_groups[i], file_group_sizes[i]) diff --git a/tests/conftest.py b/tests/conftest.py index b581d41c..126b05f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,3 +29,12 @@ def clean_generated_files(): f.unlink() for f in basedir.glob('gencode*.py'): f.unlink() + + +def pytest_collection_modifyitems(session, config, items): + def policy_first(item): + # it is very easy to break policy related tests, so run them first + if item.fspath.basename == 'test_policies.py': + return 0 + return 1 + items.sort(key=policy_first) diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index 5a305499..e7552ab8 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -8,7 +8,7 @@ import pytest import torch -from nnscaler.parallel import ComputeConfig, parallelize, broadcast_weights +from nnscaler.parallel import ComputeConfig, _prepare_namespace, parallelize, broadcast_weights from .common import init_distributed from ..launch_torchrun import launch_torchrun @@ -41,13 +41,24 @@ def _to_cube_model(module, compute_config, cube_savedir, ) -def _gpu_worker(): +def _gpu_worker(tmp_path): init_distributed() + world_size = torch.distributed.get_world_size() + local_world_size = world_size // 2 # fake two machines, as we use different cube_savedir for each worker - os.environ['LOCAL_WORLD_SIZE'] = '1' + os.environ['LOCAL_WORLD_SIZE'] = str(local_world_size) + tempdir = tmp_path / f'worker_{torch.distributed.get_rank() // local_world_size}' + node_rank = torch.distributed.get_rank() // local_world_size + + # from nnscaler.runtime.device import DeviceGroup + # # create groups + # for i in range(local_world_size): + # group_ranks = list(range(i, world_size, local_world_size)) + # DeviceGroup().get_group(group_ranks) + p = lambda t, b, i, load_module=True, **kwargs: _to_cube_model( Module(), - ComputeConfig(1, 2), + ComputeConfig(1, world_size), t, load_module=load_module, broadcast_strategy=b, @@ -56,74 +67,78 @@ def _gpu_worker(): ) # case 1: no broadcast, so only rank 0 can load the module # rank 1 will raise ModuleNotFoundError - with tempfile.TemporaryDirectory() as tempdir: - if torch.distributed.get_rank() == 0: - p(tempdir, 'none', '_1') - else: - with pytest.raises(ModuleNotFoundError): - p(tempdir, 'none', '_1') + # this will hang forever due to the distributed group creation in generated code. + # if node_rank == 0: + # p(tempdir, 'none', '_1') + # else: + # with pytest.raises(ModuleNotFoundError): + # p(tempdir, 'none', '_1') # case 2: broadcast only code, so only rank 0 can load the module # rank 1 will raise FileNotFoundError because it will fail to load attr_map files and more - with tempfile.TemporaryDirectory() as tempdir: - if torch.distributed.get_rank() == 0: + if node_rank == 0: + p(tempdir, 'code', '_2') + else: + with pytest.raises(FileNotFoundError): p(tempdir, 'code', '_2') - else: - with pytest.raises(FileNotFoundError): - p(tempdir, 'code', '_2') # case 3: broadcast except weights, so only rank 0 can load the module # rank 1 will raise RuntimeError because it will fail to load fullmodel.pt - with tempfile.TemporaryDirectory() as tempdir: - if torch.distributed.get_rank() == 0: + if node_rank == 0: + p(tempdir, 'no_weights', '_3') + else: + with pytest.raises(RuntimeError, match='Cannot find file.*'): p(tempdir, 'no_weights', '_3') - else: - with pytest.raises(RuntimeError, match='Cannot find file.*'): - p(tempdir, 'no_weights', '_3') # case 4: broadcast except weights, every rank can succeed if don't lood init params - with tempfile.TemporaryDirectory() as tempdir: - m = p(tempdir, 'no_weights', '_4', - init_module_params=torch.distributed.get_rank() == 0 - ) - if torch.distributed.get_rank() == 0: - for n, pa in m.named_parameters(): - if n.startswith('linear_weight'): - pa.data.fill_(1.0) - else: - for n, pa in m.named_parameters(): - if n.startswith('linear_weight'): - assert not torch.equal(pa.data, torch.ones_like(pa.data)) - broadcast_weights(m) - # check if broadcast_weights works + m = p(tempdir, 'no_weights', '_4', + init_module_params=torch.distributed.get_rank() == 0 + ) + if node_rank == 0: + for n, pa in m.named_parameters(): + if n.startswith('linear_weight'): + pa.data.fill_(1.0) + else: for n, pa in m.named_parameters(): if n.startswith('linear_weight'): - assert torch.equal(pa.data, torch.ones_like(pa.data)) + assert not torch.equal(pa.data, torch.ones_like(pa.data)) + broadcast_weights(m) + # check if broadcast_weights works + for n, pa in m.named_parameters(): + if n.startswith('linear_weight'): + assert torch.equal(pa.data, torch.ones_like(pa.data)) # case 5: broadcast all, all ranks will succeed - with tempfile.TemporaryDirectory() as tempdir: - p(tempdir, 'all', '_5') + p(tempdir, 'all', '_5') # case 6: test incremental broadcast - with tempfile.TemporaryDirectory() as tempdir: - # generate without broadcasting - m = p(tempdir, 'none', '_6', load_module=False) - if torch.distributed.get_rank() != 0: - assert list(Path(tempdir).glob('*')) == [] - - # case 6.1: broadcast code even we set broadcast_strategy to `all` - # because only code is new generated. - m = p(tempdir, 'all', '_6', load_module=False, reuse='graph') - if torch.distributed.get_rank() != 0: - # only python files are broadcasted - assert set(f.name for f in Path(tempdir).glob('**/*') if f.is_file()) == set(['gencode0.py', 'gencode1.py', 'compute_config.pt']) + # generate without broadcasting + _, outdir6 = _prepare_namespace(tempdir, Module, '_6') + m = p(tempdir, 'none', '_6', load_module=False) + if node_rank != 0: + assert list(Path(outdir6).glob('*')) == [] + + # case 6.1: broadcast code even we set broadcast_strategy to `all` + # because only code is new generated. + m = p(tempdir, 'all', '_6', load_module=False, reuse='graph') + if node_rank != 0: + # only python files are broadcasted + assert set(f.name for f in Path(outdir6).glob('**/*') if f.is_file()) == set( + [f'gencode{i}.py' for i in range(world_size)] + ['compute_config.pt'] + ) - # case 6.2: everything should be broadcasted, including weights - # so the load_module will succeed. - m = p(tempdir, 'all', '_6', load_module=True, reuse='override') + torch.distributed.barrier() + # case 6.2: everything should be broadcasted, including weights + # so the load_module will succeed. + m = p(tempdir, 'all', '_6', load_module=True, reuse='override') @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') -def test_broadcast(): - launch_torchrun(2, _gpu_worker) +def test_broadcast(tmp_path): + launch_torchrun(2, _gpu_worker, tmp_path) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_broadcast4(tmp_path): + launch_torchrun(4, _gpu_worker, tmp_path) From ee5eb8f5626d14249405c8687d940382aec16226 Mon Sep 17 00:00:00 2001 From: XU Weijiang <90586345+0xWJ@users.noreply.github.com> Date: Mon, 19 Jan 2026 08:28:19 +0800 Subject: [PATCH 1875/1892] [Feat] Add support for gathering full model state from all ranks (#4) Add support for gathering full model state from all ranks. A potential usage is when we use nnscaler in RL, and need to sync weights to rollout engine(like vllm) --- nnscaler/cli/trainer.py | 1 + nnscaler/parallel.py | 47 ++++++ nnscaler/utils.py | 194 +++++++++++++++++++++++ tests/parallel_module/test_checkpoint.py | 36 ++++- 4 files changed, 277 insertions(+), 1 deletion(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 48e8730d..4014aced 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -308,6 +308,7 @@ def _broadcast_merged_state_dict( We can't broadcast the whole state_dict at once, because it may be too large, and leads to OOM. Here we will break the model and optimizer state_dict into smaller pieces and broadcast them one by one. Please note we use `torch.distributed.broadcast_object_list` to broadcast the state_dict (including tensors inside). + TODO: optimize the broadcast by sending tensors separately. """ dst_ranks = dst_ranks or list(range(torch.distributed.get_world_size())) if src_rank not in dst_ranks or self.rank not in dst_ranks: diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index ccab6a8c..49376685 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -58,6 +58,8 @@ OptStateDict, copy_dynamic, broadcast_files, + broadcast_mixed_data, + gather_mixed_data, ) logger = logging.getLogger(__name__) @@ -3229,3 +3231,48 @@ def load_merged_state_dict_from_rank( module.load_state_dict(trimmed_module_state_dict) if trimmed_opt_state_dict: optimizer.load_state_dict(trimmed_opt_state_dict) + + +@torch.no_grad() +def gather_full_model_state_dict( + module: torch.nn.Module, +) -> Dict[str, Any]: + """ + Gather model state dicts from all ranks, + And merge them into a single merged model state dict in all ranks. + + Args: + module (torch.nn.Module): the module to gather state dicts from + + Returns: + Dict[str, Any]: the merged model state dict + """ + + rank = torch.distributed.get_rank() + parallel_modules = [m for m in module.modules() if isinstance(m, ParallelModule)] + if not parallel_modules: + raise ValueError("No ParallelModule found in the module.") + parallel_module = parallel_modules[0] + compute_config = parallel_module.compute_config + num_involved_ranks = compute_config.module_dedup_group_size + involved_group = DeviceGroup().get_group(list(range(num_involved_ranks))) + + logger.info(f'Gathering full model state dict from ranks {list(range(num_involved_ranks))}') + + if rank < num_involved_ranks: + local_state_dict, _ = deduped_state_dict(module, optimizer=None) + logger.info(f'Rank {rank}: gathering state dict') + state_dicts = gather_mixed_data(local_state_dict, src_rank=0, group=involved_group, device='cpu') + if rank == 0: + logger.info(f'Rank {rank}: merging gathered state dicts') + merge_state_dict = merge_state_dicts(state_dicts) + else: + merge_state_dict = None + else: + merge_state_dict = None + + logger.info(f'Rank {rank}: Broadcasting merged state dict to all ranks') + merge_state_dict = broadcast_mixed_data(merge_state_dict, src_rank=0) + logger.info(f'Rank {rank}: Finished gathering full model state dict') + torch.distributed.barrier() + return merge_state_dict diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 7f87b6bd..a4375fa8 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -790,3 +790,197 @@ def _send_file_group(src, files, file_sizes): # each local rank sends its assigned file groups (in round robin fashion) for i in range(local_rank, len(file_groups), local_world_size): _send_file_group(local_rank, file_groups[i], file_group_sizes[i]) + + +class _TensorIndex: + def __init__(self, index: int): + self.index = index + + def __repr__(self): + return f"_TensorIndex({self.index})" + + +def extract_tensors(data: Dict[str, Any]) -> Tuple[Dict[str, Any], List[torch.Tensor]]: + """ + Extract tensors from a collection, and return the skeleton (by replacing tensors with _TensorIndex) and the list of tensors. + Args: + data (Dict[str, Any]): The collection to be extracted. + Returns: + Tuple[Dict[str, Any], List[torch.Tensor]]: The skeleton and the list of tensors. + """ + tensors = [] + + # used to deduplicate tensors + # TODO: Consider more robust way to identify tensors + # key: (tensor.data_ptr(), tensor.shape, tensor.stride()), value: _Index + tensor_ids: dict[tuple[int, tuple[int, ...], tuple[int, ...]], _TensorIndex] = {} + def transform_fn(o: torch.Tensor) -> Any: + key = (o.data_ptr(), o.shape, o.stride()) + if key in tensor_ids: + idx = tensor_ids[key] + else: + idx = _TensorIndex(len(tensors)) + tensor_ids[key] = idx + tensors.append(o) + return idx + skeleton = transform_recursively(data, transform_fn, target_types=(torch.Tensor,)) + + return skeleton, tensors + + +def refill_tensors(skeleton: Dict[str, Any], tensors: List[torch.Tensor]) -> Dict[str, Any]: + """ + Refill tensors into the skeleton, and return the data. + This is the inverse operation of `extract_tensors`. + + Args: + skeleton (Dict[str, Any]): The skeleton to be refilled. + tensors (List[torch.Tensor]): The list of tensors to be refilled. + Returns: + Dict[str, Any]: The data. + """ + def transform_fn(o: _TensorIndex) -> Any: + return tensors[o.index] + state_dict = transform_recursively(skeleton, transform_fn, target_types=_TensorIndex) + return state_dict + + +def broadcast_mixed_data( + data: Optional[dict] = None, + *, + src_rank: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, +): + """ + Broadcast the data (containing tensors) from src_rank to all other ranks. + + Args: + data (Optional[dict]): The data to be broadcasted. + for non-src ranks, this must be None. + src_rank (int): The source rank to broadcast from. Default: 0. + group (torch.distributed.ProcessGroup, optional): The process group to use for broadcasting. + If None, the default process group will be used. Default: None. + + Returns: + dict: The broadcasted data. + For src_rank, it is the same as the input data. + For non-src ranks, it is the broadcasted data. the device of tensors will be cuda. + """ + rank = torch.distributed.get_rank(group=group) + + # share the structure and tensor shapes + if rank == src_rank: + if data is None: + raise ValueError("data must not be None in src_rank") + skeleton, tensors = extract_tensors(data) + meta_tensors = [t.to('meta') for t in tensors] + sent = [(skeleton, meta_tensors)] + else: + if data is not None: + raise ValueError("data must be None in non-src ranks") + skeleton, tensors, meta_tensors = None, None, None + sent = [None] + + torch.distributed.broadcast_object_list(sent, src=src_rank, group=group) + skeleton, meta_tensors = sent[0] + if rank != src_rank: + tensors = [torch.empty_like(mt, device='cuda') for mt in meta_tensors] + else: + # make sure tensors are in cuda + tensors = [t.cuda() for t in tensors] + + # broadcast tensor data + for i in range(len(tensors)): + torch.distributed.broadcast(tensors[i], src=src_rank, group=group) + + # refill tensors + if rank != src_rank: + return refill_tensors(skeleton, tensors) + else: + return data + + +def gather_mixed_data( + data: dict, + *, + src_rank: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, + device: Optional[Union[str, torch.device]] = None, +): + """ + Gather the data (containing tensors) from all ranks to src_rank. + + Args: + data (dict): The data to be gathered. + src_rank (int): The source rank to gather to. Default: 0. + group (torch.distributed.ProcessGroup, optional): The process group to use for gathering. + If None, the default process group will be used. Default: None. + device (str or torch.device, optional): The device to use for receiving tensors on src_rank. + If None, the current cuda device will be used. Default: None. + If you want to save memory, you can set it to 'cpu' to move tensors to cpu after receiving. + Returns: + dict: The gathered data. + For src_rank, it is the gathered data from all ranks. + For non-src ranks, it is None. + """ + device = torch.cuda.current_device() if device is None else device + + rank = torch.distributed.get_rank(group=group) + world_size = torch.distributed.get_world_size(group=group) + result = [None] * world_size + result[rank] = data + + skeleton, tensors = extract_tensors(data) + sent = (skeleton, [t.to('meta') for t in tensors]) + + # Gather metadata from all ranks + gathered_sent = [None for _ in range(world_size)] + torch.distributed.all_gather_object(gathered_sent, sent, group=group) + + def _send_recv_tensors( + sender: int, + skel: Dict[str, Any], + tensors: list[torch.Tensor] + ) -> Dict[str, Any]: + if rank == src_rank: + assert all(tensor.device.type == 'meta' for tensor in tensors), \ + "Tensors should be on meta device on rank 0." + if rank != src_rank: + assert all(tensor.device.type != 'meta' for tensor in tensors), \ + f"Tensors should not be on meta device on rank {rank}." + + if rank == src_rank: + cuda_tensors = [torch.empty_like(tensor, device='cuda') for tensor in tensors] + else: + cuda_tensors = [tensor.cuda() for tensor in tensors] + + for i in range(len(tensors)): + if rank == src_rank: + torch.distributed.recv(cuda_tensors[i], group_src=sender, group=group) + else: + torch.distributed.send(cuda_tensors[i], group_dst=src_rank, group=group) + + if rank == src_rank: + tensors = [tensor.to(device, non_blocking=True) for tensor in cuda_tensors] + return transform_recursively( + skel, + lambda idx: tensors[idx.index], + target_types=_TensorIndex, + ) + else: + return None # only rank 0 needs the recovered state dict + + # TODO: It may have performance issue if the number of ranks is large + for i in range(0, world_size): + if i == src_rank: + continue + if rank == src_rank: + result[i] = _send_recv_tensors(i, gathered_sent[i][0], gathered_sent[i][1]) + elif rank == i: + _send_recv_tensors(rank, skeleton, tensors) + torch.distributed.barrier(group=group) + + if rank == src_rank: + return result + else: + return None diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index d84d7503..51a6328d 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -608,6 +608,40 @@ def _gpu_merge_worker(): ) -@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 24, reason='lack of gpu devices') +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_checkpoint_merge(): launch_torchrun(4, _gpu_merge_worker) + + +def _gather_full_model_state_dict_worker(tmp_path, use_zero): + from .test_end2end import MLP, dummy_data + from nnscaler.parallel import gather_full_model_state_dict, merge_state_dicts + init_distributed() + + model = MLP() + model = parallelize( + model, + {'data': dummy_data()}, + pas_policy='tp', + compute_config= ComputeConfig( + 2, 4, + use_end2end=True, + use_zero=use_zero, + ), + gen_savedir=tmp_path + ) + model.cuda() + rank = torch.distributed.get_rank() + torch.save(model.state_dict(), tmp_path / f'{rank}.pt') + torch.distributed.barrier() + merged_state_dict = merge_state_dicts( + [torch.load(tmp_path / f'{i}.pt', weights_only=False) for i in range(torch.distributed.get_world_size())] + ) + full_state_dict = gather_full_model_state_dict(model) + assert_equal(merged_state_dict, full_state_dict) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('use_zero', [0, 1, 3]) +def test_gather_full_model_state_dict(tmp_path, use_zero): + launch_torchrun(4, _gather_full_model_state_dict_worker, tmp_path, use_zero) From fd8f704dd05fa34295fcf2fa4d4ab365cf68d548 Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Mon, 19 Jan 2026 11:20:04 +0800 Subject: [PATCH 1876/1892] [Refine] Reduce memory fragment when resuming --- nnscaler/parallel.py | 41 ++++++++++++++++++++++++++++------------- nnscaler/utils.py | 6 +++--- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 49376685..cf3d4166 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2540,7 +2540,7 @@ def load_deduped_state_dict( module: torch.nn.Module, module_state_dict: Dict[str, Any], optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, - optimizer_state_dict: Optional[Dict[str, Any]] = None, + optimizer_state_dict: Optional[OptStateDict] = None, *, device: Union[str, torch.device] = None ) -> None: @@ -2597,21 +2597,30 @@ def load_deduped_state_dict( broadcast_tensor = getattr(pm, local_name) logger.info(f'Broadcast: {key} from {cur_rank}.') else: - broadcast_tensor = torch.empty(shape, device=device, requires_grad=False, dtype=dtype) - torch.distributed.broadcast(broadcast_tensor, src=rank, group=broadcast_group) - if rank != cur_rank: # in pipeline parallelism, the local_name may not be found in the module + existing_tensor = None if hasattr(pm, local_name): logger.info(f'At rank {cur_rank}, try to load: {key} from rank {rank}.') attr = getattr(pm, local_name) + + broadcast_tensor = attr.data if key in missing_keys: - attr.data.copy_(broadcast_tensor) missing_keys.remove(key) else: - assert torch.equal(attr, broadcast_tensor), \ - f'At rank {cur_rank}, the attribute {key} is already loaded, but not equal to the broadcasted tensor from rank {rank}.' + # the tensor is already loaded, we need to check if they are equal after broadcast + existing_tensor = broadcast_tensor.cpu() else: logger.info(f'At rank {cur_rank}, skip to load: {key} from rank {rank}, not found in the module.') + # we still need to create a tensor to receive the broadcasted data + # TODO: this rank should be removed from the broadcast group + broadcast_tensor = torch.empty(shape, device=device, requires_grad=False, dtype=dtype) + + torch.distributed.broadcast(broadcast_tensor, src=rank, group=broadcast_group) + + if rank != cur_rank: + if existing_tensor is not None: + assert torch.equal(existing_tensor, broadcast_tensor.cpu()), \ + f'At rank {cur_rank}, the attribute {key} is already loaded, but not equal to the broadcasted tensor from rank {rank}.' for key in missing_keys: split_names = key.split('.') @@ -2632,11 +2641,6 @@ def load_deduped_state_dict( if not _is_supported_optimizer(optimizer._extra_state.name): raise ValueError("Only Adam-like optimizers are supported.") - for idx, state in optimizer_state_dict['state'].items(): - for key, value in state.items(): - if isinstance(value, torch.Tensor): - optimizer_state_dict['state'][idx][key] = value.to(device) - # get the locations of non-parallel module parameters # by removing the parallel module locations non_parallel_module_locs: Set[int] = set(optimizer_state_dict['param_groups'][0]['params']) @@ -2657,12 +2661,19 @@ def load_deduped_state_dict( for bg in opt_broadcast_groups: _broadcast_opt_state(optimizer_state_dict, *bg) + + # make sure all tensors are in the target device + for idx, state in optimizer_state_dict['state'].items(): + for key, value in state.items(): + if isinstance(value, torch.Tensor): + optimizer_state_dict['state'][idx][key] = value.to(device) + optimizer.load_state_dict(optimizer_state_dict) torch.distributed.barrier() -def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_group_size: int): +def _broadcast_opt_state(optimizer_state_dict: OptStateDict, state_indexes: List[int], dedup_group_size: int): if not state_indexes: return @@ -2691,6 +2702,10 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g key: torch.zeros(value[0], dtype=value[1], device=torch.cuda.current_device()) for key, value in v.items() } + else: + for idx in state_indexes: + for key, value in optimizer_state_dict['state'][idx].items(): + optimizer_state_dict['state'][idx][key] = optimizer_state_dict['state'][idx][key].cuda() # broadcast step # step is too small, so we can just broadcast all of them all together diff --git a/nnscaler/utils.py b/nnscaler/utils.py index a4375fa8..fceb454a 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -640,12 +640,12 @@ class AdamOptState(TypedDict): class OptStateParamGroup(TypedDict): params: list[int] - lr: int + lr: float class OptStateDict(TypedDict): - state: dict[int, AdamOptState | dict[str, Any]] - param_groups: list[OptStateParamGroup | dict[str, Any]] + state: dict[int, AdamOptState | dict[str, Union[Any, torch.Tensor]]] + param_groups: list[OptStateParamGroup | dict[str, Union[Any, torch.Tensor]]] def fn_field(**kwargs): From e584751c4cc7af99319fb1ddee67e48639dffdfc Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Mon, 19 Jan 2026 11:26:47 +0800 Subject: [PATCH 1877/1892] refine comment --- nnscaler/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index cf3d4166..e6d95f90 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2607,7 +2607,7 @@ def load_deduped_state_dict( if key in missing_keys: missing_keys.remove(key) else: - # the tensor is already loaded, we need to check if they are equal after broadcast + # the tensor is already loaded, we need to check if they are equal existing_tensor = broadcast_tensor.cpu() else: logger.info(f'At rank {cur_rank}, skip to load: {key} from rank {rank}, not found in the module.') From 8024fad03b65b0710a5c25469bbcf1f7125df34b Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Tue, 20 Jan 2026 10:43:19 +0800 Subject: [PATCH 1878/1892] code refine --- nnscaler/parallel.py | 100 +++++++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 42 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index e6d95f90..94c83c29 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2559,6 +2559,7 @@ def load_deduped_state_dict( """ device = device or torch.cuda.current_device() cur_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() # step 1: load deduped state dict at each rank missing_keys, unexpected_keys = module.load_state_dict(module_state_dict, strict=False) @@ -2570,62 +2571,77 @@ def load_deduped_state_dict( # step 2: broadcast deduped weights inside 1st scale unit for non-zero3 parallel modules # for zero3 modules, the weights are already complete after step 1 # TODO: refine zero3 modules support - parallel_modules = { + no_zero3_pms = { prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule) and m.compute_config.use_zero <= 1 } - if parallel_modules: - rank2deduped_fullmap, dedup_group_size, global_tensor_meta = _collect_dedup_info(parallel_modules) + if no_zero3_pms: + rank2deduped_fullmap, dedup_group_size, _ = _collect_dedup_info(no_zero3_pms) logger.debug(f'At rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}.') + # collect dedup info from dedup information + # Key: (prefix, local_name) + # Value: list[rank]: a list of ranks that have the local_name + local_name2rank_attr_map: Dict[tuple[str, str], list[int]] = {} + for rank in range(world_size): + for prefix, m in no_zero3_pms.items(): + if rank >= dedup_group_size[prefix]: + continue + for local_name, _ in m.get_attr_meta_map(rank).items(): + key = (prefix, local_name) + if key not in local_name2rank_attr_map: + local_name2rank_attr_map[key] = [] + local_name2rank_attr_map[key].append(rank) + + # create process groups for broadcasting + for key, ranks in local_name2rank_attr_map.items(): + if len(ranks) <= 1: + continue + ranks.sort() + DeviceGroup().get_group(ranks) + # broadcast weights in parallel modules - for rank, deduped_fullmap in rank2deduped_fullmap.items(): - logger.debug(f'At rank {cur_rank}, process rank: {rank}.') - for prefix, fullmap in deduped_fullmap.items(): - if cur_rank >= dedup_group_size[prefix]: - break - broadcast_group = DeviceGroup().get_group(list(range(dedup_group_size[prefix]))) - for local_name, attr_meta in fullmap.items(): - key = f'{prefix}.{local_name}' if prefix else local_name - assert prefix in parallel_modules, f'Prefix {prefix} not found in parallel_modules: {list(parallel_modules.keys())}.' - pm = parallel_modules[prefix] - attr_meta = global_tensor_meta[rank][prefix][local_name] - shape, dtype = attr_meta.sub_shape, attr_meta.dtype - if rank == cur_rank: - assert hasattr(pm, local_name), f'Local name {local_name} not found in {pm}.' - broadcast_tensor = getattr(pm, local_name) - logger.info(f'Broadcast: {key} from {cur_rank}.') + for key_name, ranks in local_name2rank_attr_map.items(): + if len(ranks) <= 1: + continue + prefix, local_name = key_name + if cur_rank in ranks: + key = f'{prefix}.{local_name}' if prefix else local_name + broadcast_group = DeviceGroup().get_group(ranks) + assert prefix in no_zero3_pms, f'Prefix {prefix} not found in parallel_modules: {list(no_zero3_pms.keys())}.' + pm = no_zero3_pms[prefix] + assert hasattr(pm, local_name), f'Local name {local_name} not found in {pm}.' + # the shared tensor will always store in the smallest rank in the dedup group + if cur_rank == ranks[0]: + broadcast_tensor = getattr(pm, local_name) + logger.info(f'Broadcast: {key} from {cur_rank}.') + else: + existing_tensor = None + logger.info(f'At rank {cur_rank}, try to load: {key} from rank {ranks[0]}.') + attr = getattr(pm, local_name) + + broadcast_tensor = attr.data + if key in missing_keys: + missing_keys.remove(key) else: - # in pipeline parallelism, the local_name may not be found in the module - existing_tensor = None - if hasattr(pm, local_name): - logger.info(f'At rank {cur_rank}, try to load: {key} from rank {rank}.') - attr = getattr(pm, local_name) - - broadcast_tensor = attr.data - if key in missing_keys: - missing_keys.remove(key) - else: - # the tensor is already loaded, we need to check if they are equal - existing_tensor = broadcast_tensor.cpu() - else: - logger.info(f'At rank {cur_rank}, skip to load: {key} from rank {rank}, not found in the module.') - # we still need to create a tensor to receive the broadcasted data - # TODO: this rank should be removed from the broadcast group - broadcast_tensor = torch.empty(shape, device=device, requires_grad=False, dtype=dtype) + # the tensor is already loaded, we need to check if they are equal + existing_tensor = broadcast_tensor.cpu() + + torch.distributed.broadcast(broadcast_tensor, src=ranks[0], group=broadcast_group) - torch.distributed.broadcast(broadcast_tensor, src=rank, group=broadcast_group) + if cur_rank != ranks[0]: + if existing_tensor is not None: + assert torch.equal(existing_tensor, broadcast_tensor.cpu()), \ + f'At rank {cur_rank}, the attribute {key} is already loaded, ' \ + f'but not equal to the broadcasted tensor from rank {ranks[0]}.' - if rank != cur_rank: - if existing_tensor is not None: - assert torch.equal(existing_tensor, broadcast_tensor.cpu()), \ - f'At rank {cur_rank}, the attribute {key} is already loaded, but not equal to the broadcasted tensor from rank {rank}.' + torch.distributed.barrier() for key in missing_keys: split_names = key.split('.') prefix = '.'.join(split_names[:-1]) # remove the last part of the key - assert prefix not in parallel_modules or cur_rank >= dedup_group_size[prefix], f'At rank {cur_rank}, the missing key {key} should be in non-parallel modules.' + assert prefix not in no_zero3_pms or cur_rank >= dedup_group_size[prefix], f'At rank {cur_rank}, the missing key {key} should be in non-parallel modules.' # At this point # - All parallel modules in first scale unit should be complete. From 2ff7e545a20af1d532bdc5a55c581547b758841c Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Tue, 20 Jan 2026 10:57:37 +0800 Subject: [PATCH 1879/1892] refine code --- nnscaler/parallel.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 94c83c29..a7d5fe3a 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2583,26 +2583,27 @@ def load_deduped_state_dict( # collect dedup info from dedup information # Key: (prefix, local_name) # Value: list[rank]: a list of ranks that have the local_name - local_name2rank_attr_map: Dict[tuple[str, str], list[int]] = {} - for rank in range(world_size): - for prefix, m in no_zero3_pms.items(): - if rank >= dedup_group_size[prefix]: - continue + local_name2ranks: Dict[tuple[str, str], list[int]] = {} + + for prefix, m in no_zero3_pms.items(): + for rank in range(dedup_group_size[prefix]): for local_name, _ in m.get_attr_meta_map(rank).items(): key = (prefix, local_name) - if key not in local_name2rank_attr_map: - local_name2rank_attr_map[key] = [] - local_name2rank_attr_map[key].append(rank) + if key not in local_name2ranks: + local_name2ranks[key] = [] + local_name2ranks[key].append(rank) # create process groups for broadcasting - for key, ranks in local_name2rank_attr_map.items(): + for key, ranks in local_name2ranks.items(): if len(ranks) <= 1: continue + # should have sorted. ranks.sort() DeviceGroup().get_group(ranks) - # broadcast weights in parallel modules - for key_name, ranks in local_name2rank_attr_map.items(): + # broadcast weights in parallel modules inside dedup group (most time it is the 1st scale unit) + # Implementation of `deduped_state_dict` can guarantee that the first rank in each rank group always has the weights + for key_name, ranks in local_name2ranks.items(): if len(ranks) <= 1: continue prefix, local_name = key_name @@ -2631,6 +2632,8 @@ def load_deduped_state_dict( torch.distributed.broadcast(broadcast_tensor, src=ranks[0], group=broadcast_group) if cur_rank != ranks[0]: + # it should not come here if _collect_dedup_info is strict + # anyway, we add an assertion here to make sure if existing_tensor is not None: assert torch.equal(existing_tensor, broadcast_tensor.cpu()), \ f'At rank {cur_rank}, the attribute {key} is already loaded, ' \ From 0014141b0c7cf301a1c4a8ee2c67b691f03af137 Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Tue, 20 Jan 2026 12:20:59 +0800 Subject: [PATCH 1880/1892] refine comments --- nnscaler/parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index a7d5fe3a..b704ac88 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2580,7 +2580,7 @@ def load_deduped_state_dict( rank2deduped_fullmap, dedup_group_size, _ = _collect_dedup_info(no_zero3_pms) logger.debug(f'At rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}.') - # collect dedup info from dedup information + # collect dedup info from attr meta maps # Key: (prefix, local_name) # Value: list[rank]: a list of ranks that have the local_name local_name2ranks: Dict[tuple[str, str], list[int]] = {} @@ -2627,6 +2627,7 @@ def load_deduped_state_dict( missing_keys.remove(key) else: # the tensor is already loaded, we need to check if they are equal + # it should not come here if _collect_dedup_info is strict existing_tensor = broadcast_tensor.cpu() torch.distributed.broadcast(broadcast_tensor, src=ranks[0], group=broadcast_group) From 9d5b02e731cbada1912ccb24956df47288057fd8 Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Wed, 21 Jan 2026 09:20:37 +0800 Subject: [PATCH 1881/1892] add more debug info --- nnscaler/parallel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index b704ac88..0b8c699a 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2599,6 +2599,7 @@ def load_deduped_state_dict( continue # should have sorted. ranks.sort() + logger.debug(f'At rank {cur_rank}, create groups for ranks: {ranks}.') DeviceGroup().get_group(ranks) # broadcast weights in parallel modules inside dedup group (most time it is the 1st scale unit) @@ -2630,6 +2631,7 @@ def load_deduped_state_dict( # it should not come here if _collect_dedup_info is strict existing_tensor = broadcast_tensor.cpu() + logger.debug(f'At rank {cur_rank}, broadcast from {ranks[0]} to {ranks} for `{key}`.') torch.distributed.broadcast(broadcast_tensor, src=ranks[0], group=broadcast_group) if cur_rank != ranks[0]: From c799251814e45083be0615b7ce182d247a88cd4b Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Wed, 21 Jan 2026 14:27:02 +0800 Subject: [PATCH 1882/1892] refine comment --- nnscaler/parallel.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 0b8c699a..bd10f5a5 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2684,12 +2684,6 @@ def load_deduped_state_dict( for bg in opt_broadcast_groups: _broadcast_opt_state(optimizer_state_dict, *bg) - # make sure all tensors are in the target device - for idx, state in optimizer_state_dict['state'].items(): - for key, value in state.items(): - if isinstance(value, torch.Tensor): - optimizer_state_dict['state'][idx][key] = value.to(device) - optimizer.load_state_dict(optimizer_state_dict) torch.distributed.barrier() @@ -2847,10 +2841,6 @@ def load_sharded_state_dict( module.load_state_dict(module_state_dict) module.to(device) if optimizer and optimizer_state_dict: - for idx, state in optimizer_state_dict.get('state', {}).items(): - for key, value in state.items(): - if isinstance(value, torch.Tensor): - optimizer_state_dict['state'][idx][key] = value.to(device) optimizer.load_state_dict(optimizer_state_dict) From 943b154b09d438cfecb82701473d962bb09a0361 Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Fri, 23 Jan 2026 09:38:21 +0800 Subject: [PATCH 1883/1892] refine code --- nnscaler/parallel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index bd10f5a5..407ec480 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2559,7 +2559,6 @@ def load_deduped_state_dict( """ device = device or torch.cuda.current_device() cur_rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() # step 1: load deduped state dict at each rank missing_keys, unexpected_keys = module.load_state_dict(module_state_dict, strict=False) @@ -2686,7 +2685,7 @@ def load_deduped_state_dict( optimizer.load_state_dict(optimizer_state_dict) - torch.distributed.barrier() + torch.distributed.barrier() def _broadcast_opt_state(optimizer_state_dict: OptStateDict, state_indexes: List[int], dedup_group_size: int): From 52d93222a0476952df4cca84afa4756d2ea24172 Mon Sep 17 00:00:00 2001 From: XU Weijiang Date: Sat, 24 Jan 2026 15:39:51 +0800 Subject: [PATCH 1884/1892] add barrier --- nnscaler/parallel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 407ec480..52f9f74a 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2601,6 +2601,8 @@ def load_deduped_state_dict( logger.debug(f'At rank {cur_rank}, create groups for ranks: {ranks}.') DeviceGroup().get_group(ranks) + torch.distributed.barrier() + # broadcast weights in parallel modules inside dedup group (most time it is the 1st scale unit) # Implementation of `deduped_state_dict` can guarantee that the first rank in each rank group always has the weights for key_name, ranks in local_name2ranks.items(): From 68c9cbee5b8356939aacb9bbac9f0b72a073cf72 Mon Sep 17 00:00:00 2001 From: XU Weijiang <90586345+0xWJ@users.noreply.github.com> Date: Wed, 28 Jan 2026 08:14:23 +0800 Subject: [PATCH 1885/1892] [Tracer] Provide better einops tracing by skipping tracing some internal einops functions. Tracing einops Functions are challenging due to their dynamic nature and heavy reliance on string-based patterns and runtime shape manipulations. To make things easier, we skip tracing the internal logic of einops functions and directly use the resolved transformation recipes. --- docs/source/einops.md | 38 +++++++++++ nnscaler/graph/parser/external/einops.py | 17 ++++- nnscaler/graph/tracer/concrete_tracer.py | 8 +++ requirements-dev.txt | 2 +- .../ring_attn/test_shuffle_varlen.py | 7 ++ tests/parallel_module/test_gencode_einops.py | 64 +++++++++++++++++++ 6 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 docs/source/einops.md diff --git a/docs/source/einops.md b/docs/source/einops.md new file mode 100644 index 00000000..e8103add --- /dev/null +++ b/docs/source/einops.md @@ -0,0 +1,38 @@ +# einops Support in NnScaler +================================= + +Tracing einops Functions are challenging due to their dynamic nature and heavy reliance on string-based patterns and runtime shape manipulations. It is challenging to statically analyze and trace these operations accurately, because tracing doesn't work well with complex python logic (e.g. string parsing, dynamic shape computations, loops, etc) involved in einops functions. + +To make things easier, we skip tracing the internal logic of einops functions and directly use the resolved transformation recipes. + +This is done by skipping tracing internal einops function: `_prepare_transformation_recipe`. In future, if einops changes their internal implementation, we may need to update our patching logic accordingly. + +For nnscaler, we may skip more functions in the future if needed. For exmaple, `_reconstruct_from_shape_uncached` and `_reconstruct_from_shape` are also candidates for skipping tracing, but currently we haven't found issues without skipping them. Once we find issues related to them, we will skip tracing them as well. + +As a result, when you use einops functions in your model, we can't guarantee that the traced recipe will be valid when their parameters are changed (e.g. input shapes or pattern strings. `compute_config.constant_folding=False` doesn't help here). + +Currently we haven't encountered problems in our tests, but it's still possible in some corner cases. If you encounter any problems, please report an issue to us. + +Here is an example of using einops in a model with NnScaler: + +```python +import torch +import torch.nn as nn +import einops +from nnscaler import nnscaler, ComputeConfig + +class EinopsModel(nn.Module): + def __init__(self): + ... + + def forward(self, x): + # this is good, because the pattern and the input shape is static (h/w/c are fixed) + x = einops.rearrange(x, 'b (h w c) -> b c h w', h=4, w=4, c=1) + ... + y = ... + # this depends on y + # although dependence maintains properly if you set `compute_config.constant_folding=False`, + # This can be changed in future. So be cautious when using such patterns. + x = einops.rearrange(x, 'b c h w -> b (h w c)', b=y) + ... +``` diff --git a/nnscaler/graph/parser/external/einops.py b/nnscaler/graph/parser/external/einops.py index 91845fe8..fb24f38e 100644 --- a/nnscaler/graph/parser/external/einops.py +++ b/nnscaler/graph/parser/external/einops.py @@ -10,9 +10,24 @@ try: import einops - # trigger einops initialization einops.rearrange(torch.arange(1), '(a b c) -> a b c', a=1, b=1, c=1) + + from nnscaler.graph.tracer.wrap_utils import default_never_wrap_function, LeafWrapInfo, Location + + default_never_wrap_function[einops.einops._prepare_transformation_recipe] = \ + LeafWrapInfo([Location(einops.einops, '_prepare_transformation_recipe')], False, None) + + # we comment out these two functions + # because it looks not necessary for now. + # and they also introduce some problems, + # i.e. dynamic shape will be lost even with `compute_config.constant_folding=False` + + # default_never_wrap_function[einops.einops._reconstruct_from_shape_uncached] = \ + # LeafWrapInfo([Location(einops.einops, '_reconstruct_from_shape_uncached')], False, None) + # default_never_wrap_function[einops.einops._reconstruct_from_shape] = \ + # LeafWrapInfo([Location(einops.einops, '_reconstruct_from_shape')], False, None) + except ImportError as e: _logger.debug("Einops is not installed") pass diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index 85f7c8f6..72c2acf7 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -30,6 +30,8 @@ from torch.fx.proxy import TracerBase, Scope from torch.fx.operator_schemas import check_for_mutable_operation +from nnscaler.utils import transform_recursively + dict_keys_type = type(dict().keys()) dict_values_type = type(dict().values()) dict_items_type = type(dict().items()) @@ -717,8 +719,14 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): def wrap_never_wrap_function(func, *args, **kwargs): if self.patcher.patch_mode: with self.patcher.revert(): + # unwrap all proxy in args/kwargs + args = transform_recursively(args, lambda x: x.value, target_types=ep.ConcreteProxy) + kwargs = transform_recursively(kwargs, lambda x: x.value, target_types=ep.ConcreteProxy) return func(*args, **kwargs) else: + # unwrap all proxy in args/kwargs + args = transform_recursively(args, lambda x: x.value, target_types=ep.ConcreteProxy) + kwargs = transform_recursively(kwargs, lambda x: x.value, target_types=ep.ConcreteProxy) return func(*args, **kwargs) try: diff --git a/requirements-dev.txt b/requirements-dev.txt index bd2f7fdf..5be4b7d4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,7 +8,7 @@ pytest pytest-cov pytest-mock scikit-learn -lightning +lightning==2.5.1.post0 sphinx sphinxcontrib-napoleon tabulate diff --git a/tests/customized_ops/ring_attn/test_shuffle_varlen.py b/tests/customized_ops/ring_attn/test_shuffle_varlen.py index 8f7271a4..a8d6b3c9 100644 --- a/tests/customized_ops/ring_attn/test_shuffle_varlen.py +++ b/tests/customized_ops/ring_attn/test_shuffle_varlen.py @@ -15,6 +15,13 @@ from tests.launch_torchrun import torchrun +# Skip all tests if flash_attn_func is not available +try: + from flash_attn import flash_attn_func +except ImportError: + pytest.skip("flash_attn_func not available", allow_module_level=True) + + @dataclass class ShuffleVarlenConfig: """Simple test configuration""" diff --git a/tests/parallel_module/test_gencode_einops.py b/tests/parallel_module/test_gencode_einops.py index bea1c75a..2a9d2aa3 100644 --- a/tests/parallel_module/test_gencode_einops.py +++ b/tests/parallel_module/test_gencode_einops.py @@ -5,6 +5,7 @@ import functools from einops import rearrange import torch +import pytest from nnscaler import parallelize, ComputeConfig from nnscaler.graph import parser @@ -65,3 +66,66 @@ def test_codegen_rearrange(): ) # parallelize will succeed. assert True + + +class RearrangeModule2(torch.nn.Module): + def __init__(self, shape: tuple[int, ...]): + super().__init__() + self.shape = shape + self.weight = torch.nn.Parameter(torch.ones(self.shape)) + + def forward(self, x): + bsz = x.size(0) + x = rearrange(x, 'n l d -> (n l) d', n=bsz) + return x + self.weight + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize("constant_folding", [True, False]) +def test_rearrange2(tmp_path, constant_folding): + parallelize( + RearrangeModule2(4), + {'x': torch.randn(4, 4, 4)}, + 'dp', + ComputeConfig(1, 1, constant_folding=constant_folding), + gen_savedir=tmp_path, + load_module=False + ) + # parallelize will succeed. + assert True + + # code will look like this when constant_folding=True + # def segment22(self, x_25): + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/_backends.py", line 93, in reshape, return x.reshape(shape) + # reshape_27 = torch.Tensor.reshape(x_25, shape=(16, 4)) + # del x_25 + # # File "/data/weijiangxu/nnscaler/tests/parallel_module/test_gencode_einops.py", line 80, in forward, return x + self.weight + # add_26 = torch.add(reshape_27, self.weight_28, alpha=1) + # del reshape_27 + # return add_26 + + + # code will look like this when constant_folding=False + # def segment25(self, x_28): + # # File "/data/weijiangxu/nnscaler/tests/parallel_module/test_gencode_einops.py", line 78, in forward, bsz = x.size(0) + # size_21 = torch.Tensor.size(x_28, dim=0) + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/_backends.py", line 90, in shape, return x.shape + # im_output_36 = builtins.getattr(x_28, 'shape') + # getattr_3_22 = im_output_36[0] + # getattr_3_23 = im_output_36[1] + # getattr_3_24 = im_output_36[2] + # del im_output_36 + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/einops.py", line 33, in _product, result *= element + # mul_1_25 = _operator.mul(1, size_21) + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/einops.py", line 33, in _product, result *= element + # imul_26 = _operator.imul(mul_1_25, getattr_3_23) + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/einops.py", line 33, in _product, result *= element + # mul_2_27 = _operator.mul(1, getattr_3_24) + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/_backends.py", line 93, in reshape, return x.reshape(shape) + # reshape_30 = torch.Tensor.reshape(x_28, shape=(imul_26, mul_2_27)) + # del x_28 + # # File "/data/weijiangxu/nnscaler/tests/parallel_module/test_gencode_einops.py", line 80, in forward, return x + self.weight + # add_29 = torch.add(reshape_30, self.weight_31, alpha=1) + # del reshape_30 + # return add_29 + From e39d68b45c8e0cf0f3e194289245519f13ed032e Mon Sep 17 00:00:00 2001 From: XU Weijiang <90586345+0xWJ@users.noreply.github.com> Date: Fri, 30 Jan 2026 14:24:15 +0800 Subject: [PATCH 1886/1892] [Refine] Normalize device handling in state dicts and more (#9) 1. Normalize state_dict device handling 2. replace torch.cat with F.pad --- nnscaler/cli/trainer.py | 54 +--------------- nnscaler/parallel.py | 63 ++++++++++--------- nnscaler/runtime/adapter/reducer.py | 18 +++--- nnscaler/runtime/module.py | 22 ++++--- nnscaler/utils.py | 34 +++++++--- tests/parallel_module/test_checkpoint.py | 2 +- tests/parallel_module/test_end2end.py | 52 +++++++-------- .../test_end2end_mix_precision.py | 10 +-- .../test_gencode_ctx_manager.py | 36 +++++------ 9 files changed, 135 insertions(+), 156 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 4014aced..828498d7 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -23,7 +23,7 @@ import nnscaler from nnscaler.runtime.device import DeviceGroup -from nnscaler.utils import enforce_zero_num_worker, is_running_distributed +from nnscaler.utils import broadcast_mixed_data, is_running_distributed from .trainer_args import AggregatedOutputs, TrainerArgs, fix_input from .train_hook import AggregatedTrainHook, TrainHook, TrainHookHost @@ -305,10 +305,6 @@ def _broadcast_merged_state_dict( ): """ Broadcast the merged state dict to all ranks. - We can't broadcast the whole state_dict at once, because it may be too large, and leads to OOM. - Here we will break the model and optimizer state_dict into smaller pieces and broadcast them one by one. - Please note we use `torch.distributed.broadcast_object_list` to broadcast the state_dict (including tensors inside). - TODO: optimize the broadcast by sending tensors separately. """ dst_ranks = dst_ranks or list(range(torch.distributed.get_world_size())) if src_rank not in dst_ranks or self.rank not in dst_ranks: @@ -321,54 +317,8 @@ def _broadcast_merged_state_dict( else: if state_dict is not None: raise ValueError("state_dict should be None in other ranks when broadcasting") - state_dict = {} - def _broadcast_keys(sdict: Dict[str, Any], set_keys=True): - if self.rank == src_rank: - state_keys = list(sdict.keys()) - else: - state_keys = None - state_key_list = [state_keys] - torch.distributed.broadcast_object_list(state_key_list, src=src_rank, group=pg) - state_keys = state_key_list[0] - if set_keys and self.rank != src_rank: - for key in state_keys: - sdict[key] = {} # assume the values are empty dicts - return state_keys - - def _broadcast_value(sdict, key): - if self.rank == src_rank: - value_list = [sdict[key]] - else: - value_list = [None] - torch.distributed.broadcast_object_list(value_list, src=src_rank, group=pg) - if self.rank != src_rank: - sdict[key] = value_list[0] - - def _broadcast_values(sdict, keys): - for key in keys: - _broadcast_value(sdict, key) - - state_keys = _broadcast_keys(state_dict) - - for skey in state_keys: - logger.info(f"Broadcasting {skey}.") - if skey == 'optimizer': - opt_keys = _broadcast_keys(state_dict['optimizer']) - opt_keys_without_state = [ - k for k in opt_keys if k != 'state' - ] - _broadcast_values(state_dict['optimizer'], opt_keys_without_state) - idxs = _broadcast_keys(state_dict['optimizer']['state']) - for idx in idxs: - idx_keys = _broadcast_keys(state_dict['optimizer']['state'][idx]) - _broadcast_values(state_dict['optimizer']['state'][idx], idx_keys) - elif skey == 'model': - model_keys = _broadcast_keys(state_dict['model']) - _broadcast_values(state_dict['model'], model_keys) - else: - _broadcast_value(state_dict, skey) - return state_dict + return broadcast_mixed_data(state_dict, src_rank=src_rank, group=pg, device='cpu') @classmethod def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str, diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 52f9f74a..0b5dbde6 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1790,11 +1790,15 @@ def merge_state_dicts( Please Note: We don't garantee the devices of tensors are the same in the merged state dict. - You can assume the device of the tensors in the merged state dict can be one of the following: + You can assume the device of the tensors in the merged state dict + can be 'cpu' or the device of the tensor in the original state dict. - 1. the current device when running this function - 2. the current cuda device when running this function - 3. the device of the tensor in the original state dict + Quick Explanation: + In current implementation, + For non-parallel modules, we directly take the tensor from input state dicts + For parallel modules, we will create new tensors from cpu, and copy/merge the tensors from input state dicts to it. + (this may be optimized later as we can avoid copying for replicated tensors.) + So in summary, the devices of the tensors in output state dicts can be either 'cpu' or the device in original state dict. When you load the state dict from file, you can just use `torch.load(..., map_location='...')` to unify the device of the tensors. @@ -1971,6 +1975,8 @@ def load_merged_state_dict( """ device = device or torch.cuda.current_device() + module.to(device) + # non ParallelModule parameters will be loaded here # there will be mismatched keys if the module is a ParallelModule or contains ParallelModule # so we need to ignore the mismatched keys @@ -1981,10 +1987,8 @@ def load_merged_state_dict( prefix = name + '.' if name else '' child_module.load_merged_state_dict(module_state_dict, prefix=prefix) - module.to(device) - if optimizer is not None and optimizer_state_dict is not None: - new_optimizer_state_dict = _trim_optimizer_merged_state_dict(module, optimizer._extra_state, optimizer_state_dict, device=device) + new_optimizer_state_dict = _trim_optimizer_merged_state_dict(module, optimizer._extra_state, optimizer_state_dict, device='cpu') optimizer.load_state_dict(new_optimizer_state_dict) @@ -2131,7 +2135,7 @@ def _construct_optim_state_zero3( else: opt_states[key] = torch.zeros( [chunk_size], dtype=sliced_new_val[key].dtype, - device=sliced_new_val[key].device, requires_grad=False + device='cpu', requires_grad=False ) # copy the param's slices to the optimizer's chunk for key in opt_states.keys(): @@ -2212,7 +2216,7 @@ def _get_optimizer_state_of_param(param, param_ids, local_names): opt_state_keys.remove('step') for key in opt_state_keys: opt_states[key] = torch.zeros([chunk_size], dtype=sliced_new_val[key].dtype, - device=sliced_new_val[key].device, requires_grad=False) + device='cpu', requires_grad=False) # copy the param's slices to the optimizer's chunk for key in opt_state_keys: sliced_new_val[key] = sliced_new_val[key].view(-1) @@ -2314,13 +2318,11 @@ def _extract_new_state( sliced_new_val[key] = sliced_new_val[key].view(-1)[zero3_info.start:zero3_info.end] if sliced_new_val[key].numel() < zero3_info.chunk_size: # padding if needed - sliced_new_val[key] = torch.cat( - [sliced_new_val[key], - torch.zeros( - zero3_info.chunk_size - sliced_new_val[key].numel(), - dtype=sliced_new_val[key].dtype, - device=sliced_new_val[key].device - )], dim=0 + sliced_new_val[key] = torch.nn.functional.pad( + sliced_new_val[key].cpu(), + (0, zero3_info.chunk_size - sliced_new_val[key].numel()), + mode='constant', + value=0.0 ) return sliced_new_val @@ -2560,9 +2562,10 @@ def load_deduped_state_dict( device = device or torch.cuda.current_device() cur_rank = torch.distributed.get_rank() + module.to(device) + # step 1: load deduped state dict at each rank missing_keys, unexpected_keys = module.load_state_dict(module_state_dict, strict=False) - module.to(device) torch.distributed.barrier() logger.debug(f'At rank {cur_rank}, state_dict keys: {module_state_dict.keys()}.') logger.debug(f'At rank {cur_rank}, missing_keys: {missing_keys}, unexpected_keys: {unexpected_keys}.') @@ -2839,8 +2842,9 @@ def load_sharded_state_dict( """ device = device or torch.cuda.current_device() - module.load_state_dict(module_state_dict) module.to(device) + + module.load_state_dict(module_state_dict) if optimizer and optimizer_state_dict: optimizer.load_state_dict(optimizer_state_dict) @@ -2951,7 +2955,7 @@ def _send_trimmed_module_state_dict( for key in keys: tensor = trimmed_state_dict[key] # NOTE: send is broken if the tensor is not contiguous - torch.distributed.send(tensor.contiguous(), group=group, dst=dst_rank) + torch.distributed.send(tensor.cuda().contiguous(), group=group, dst=dst_rank) def _receive_trimmed_module_state_dict( @@ -2976,9 +2980,9 @@ def _receive_trimmed_module_state_dict( trimmed_state_dict = {} for key, shape_dtype in zip(keys, shape_dtypes): - tensor = torch.zeros(shape_dtype[0], dtype=shape_dtype[1], device=device) + tensor = torch.zeros(shape_dtype[0], dtype=shape_dtype[1], device='cuda') torch.distributed.recv(tensor, group=group, src=src_rank) - trimmed_state_dict[key] = tensor + trimmed_state_dict[key] = tensor.to(device) return trimmed_state_dict @@ -3011,7 +3015,7 @@ def _send_trimmed_opt_state_dict( step_stack = torch.stack( [trimmed_opt_state_dict['state'][k]['step'] for k in state_keys] ) - torch.distributed.send(step_stack, group=group, dst=dst_rank) + torch.distributed.send(step_stack.cuda(), group=group, dst=dst_rank) # broadcast other states # TODO: can be slow? @@ -3021,7 +3025,7 @@ def _send_trimmed_opt_state_dict( keys.remove('step') # we have done step in previous. for key in keys: value = trimmed_opt_state_dict['state'][k][key] - torch.distributed.send(value.data, group=group, dst=dst_rank) + torch.distributed.send(value.data.cuda(), group=group, dst=dst_rank) def _receive_trimmed_opt_state_dict( @@ -3060,7 +3064,7 @@ def _receive_trimmed_opt_state_dict( step_stack = torch.zeros( len(state_keys), dtype=trimmed_opt_state_dict['state'][state_keys[0]]['step'].dtype, - device=device + device='cuda' ) torch.distributed.recv(step_stack, group=group, src=src_rank) for k, v in zip(state_keys, step_stack): @@ -3072,8 +3076,9 @@ def _receive_trimmed_opt_state_dict( if 'step' in keys: keys.remove('step') # we have done step in previous. for key in keys: - value = trimmed_opt_state_dict['state'][k][key] - torch.distributed.recv(value.data, group=group, src=src_rank) + value = trimmed_opt_state_dict['state'][k][key].cuda() + torch.distributed.recv(value, group=group, src=src_rank) + trimmed_opt_state_dict['state'][k][key] = value.to(device) return trimmed_opt_state_dict @@ -3247,12 +3252,14 @@ def load_merged_state_dict_from_rank( Returns: None """ + device = device or torch.cuda.current_device() + module.to(device) trimmed_module_state_dict, trimmed_opt_state_dict = trimmed_broadcast_merged_state_dict( module, module_state_dict, optimizer, optimizer_state_dict, - device=device, + device='cpu', src_rank=src_rank, dst_ranks=dst_ranks, ) @@ -3300,7 +3307,7 @@ def gather_full_model_state_dict( merge_state_dict = None logger.info(f'Rank {rank}: Broadcasting merged state dict to all ranks') - merge_state_dict = broadcast_mixed_data(merge_state_dict, src_rank=0) + merge_state_dict = broadcast_mixed_data(merge_state_dict, src_rank=0, device='cpu') logger.info(f'Rank {rank}: Finished gathering full model state dict') torch.distributed.barrier() return merge_state_dict diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index bf795d43..19c3c8b0 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -255,9 +255,11 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): # pragma: no cover padded_numel = z3_info.numel_with_padding() * self._zgroup_sz if grad.numel() < padded_numel: # add padding - grad = torch.cat( - [grad, - torch.zeros(padded_numel - grad.numel(), device=grad.device, dtype=grad.dtype)] + grad = torch.nn.functional.pad( + grad, + (0, padded_numel - grad.numel()), + mode='constant', + value=0.0, ) output = torch.zeros(z3_info.numel_with_padding(), device=grad.device, dtype=grad.dtype) torch.distributed.reduce_scatter_tensor( @@ -761,10 +763,12 @@ def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None # to make sure all ranks in the zero subgroup have the same bucket layout. if end - start < chunk_size: padding = chunk_size - (end - start) - param.data = torch.cat([ - param.data.view(-1)[start:end].clone(), - torch.zeros(padding, dtype=param.dtype, device=param.device) - ], dim=0) + param.data = torch.nn.functional.pad( + param.data.view(-1)[start:end], + (0, padding), + mode='constant', + value=0.0, + ) else: param.data = param.data.view(-1)[start:end].clone() diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 05b9ea03..9f57f835 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -479,7 +479,7 @@ def merge_model_state_dicts( partial_tensor = model_state_dict[local_name] if meta.orig_name not in full_model_state_dict: full_model_state_dict[meta.orig_name] = torch.empty( - meta.shape, dtype=partial_tensor.dtype) + meta.shape, dtype=partial_tensor.dtype, device='cpu') state_dict_merge_track[meta.orig_name] = set() # assign partial tensor if meta.val_chunks > 1: @@ -659,7 +659,7 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size, opt_states, opt_states_1d = {}, {} for key in opt_state_keys: opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, - device=bucket_states[0][key].device, requires_grad=False) + device='cpu', requires_grad=False) opt_states_1d[key] = opt_states[key].view(-1) if zero_version == 1: @@ -709,7 +709,7 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size, if 'step' in bucket_states[0]: # make sure all steps are different tensors (with same value) - opt_states['step'] = bucket_states[0]['step'].clone() + opt_states['step'] = bucket_states[0]['step'].cpu().clone() return opt_states def _merge_opt_zero(param_shape, worker_idx, param_idx): @@ -798,7 +798,7 @@ def _merge_opt_zero(param_shape, worker_idx, param_idx): if not CubeModule._safe_tensor_equal(full_states[full_index][state_name], value): raise ValueError(f"Conflict in merging {param_name}.{state_name} from rank {work_idx}") else: - full_states[full_index][state_name] = value + full_states[full_index][state_name] = value.cpu() continue # for non-tensor states @@ -813,7 +813,7 @@ def _merge_opt_zero(param_shape, worker_idx, param_idx): else: # create optimizer state tensor if state_name not in full_states[full_index]: - full_states[full_index][state_name] = torch.empty(meta.shape, dtype=value.dtype) + full_states[full_index][state_name] = torch.empty(meta.shape, dtype=value.dtype, device='cpu') if track_id in state_merge_track: if not CubeModule._safe_tensor_equal(full_states[full_index][state_name][meta.slicers], value): @@ -1681,7 +1681,7 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s # avoid checking the non-persistent buffers attr_names = set([attr for attr in self._fullmap.keys() if attr not in non_persistent_buffers]) - for prefix_attr, content in self.trim_merged_state_dict(state_dict, prefix).items(): + for prefix_attr, content in self.trim_merged_state_dict(state_dict, prefix, device='cpu').items(): attr = prefix_attr[len(prefix):] tensor: torch.Tensor = getattr(self, attr) tensor.copy_(content) @@ -1750,10 +1750,12 @@ def trim_merged_state_dict( if end - start < chunk_size: # need padding padding = chunk_size - (end - start) - trimmed_state_dict[prefix + attr] = torch.cat([ - content.view(-1)[start:end], - torch.zeros(padding, dtype=content.dtype, device=content.device) - ], dim=0).to(device) + trimmed_state_dict[prefix + attr] = torch.nn.functional.pad( + content.view(-1)[start:end].to(device), + (0, padding), + mode='constant', + value=0.0, + ) else: trimmed_state_dict[prefix + attr] = content.reshape(-1)[start:end].to(device) diff --git a/nnscaler/utils.py b/nnscaler/utils.py index fceb454a..f4d68cb1 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -850,6 +850,7 @@ def broadcast_mixed_data( *, src_rank: int = 0, group: Optional[torch.distributed.ProcessGroup] = None, + device: Optional[Union[str, torch.device]] = None, ): """ Broadcast the data (containing tensors) from src_rank to all other ranks. @@ -860,12 +861,18 @@ def broadcast_mixed_data( src_rank (int): The source rank to broadcast from. Default: 0. group (torch.distributed.ProcessGroup, optional): The process group to use for broadcasting. If None, the default process group will be used. Default: None. + device (str or torch.device, optional): The device to use for receiving tensors on non-src ranks. + If None, the current cuda device will be used. Default: None. Returns: dict: The broadcasted data. For src_rank, it is the same as the input data. For non-src ranks, it is the broadcasted data. the device of tensors will be cuda. """ + device = device or torch.cuda.current_device() + if isinstance(device, str): + # need to compare device later, so convert to torch.device + device = torch.device(device) rank = torch.distributed.get_rank(group=group) # share the structure and tensor shapes @@ -884,20 +891,29 @@ def broadcast_mixed_data( torch.distributed.broadcast_object_list(sent, src=src_rank, group=group) skeleton, meta_tensors = sent[0] if rank != src_rank: - tensors = [torch.empty_like(mt, device='cuda') for mt in meta_tensors] - else: - # make sure tensors are in cuda - tensors = [t.cuda() for t in tensors] + tensors = [None] * len(meta_tensors) # broadcast tensor data for i in range(len(tensors)): - torch.distributed.broadcast(tensors[i], src=src_rank, group=group) + if rank != src_rank: + tensor = torch.empty_like(meta_tensors[i], device='cuda') + else: + # make sure tensors are in cuda + tensor = tensors[i].cuda() + + torch.distributed.broadcast(tensor, src=src_rank, group=group) + + if rank != src_rank: + tensors[i] = tensor.to(device, non_blocking=True) + else: + # try to reuse the existing tensors if device matches + if tensor.device == device: + tensors[i] = tensor + else: + tensors[i] = tensors[i].to(device, non_blocking=True) # refill tensors - if rank != src_rank: - return refill_tensors(skeleton, tensors) - else: - return data + return refill_tensors(skeleton, tensors) def gather_mixed_data( diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 51a6328d..fc8388fb 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -382,7 +382,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf assert set(result_orig_opt_state_dict['state']) == set(result_merged_opt_state_dict['state']) for index in result_orig_opt_state_dict['state']: for key in ('step', 'exp_avg', 'exp_avg_sq'): - assert torch.equal(result_orig_opt_state_dict['state'][index][key], result_merged_opt_state_dict['state'][index][key]) + assert_equal(result_orig_opt_state_dict['state'][index][key], result_merged_opt_state_dict['state'][index][key]) torch.distributed.barrier() data = gendata(model, DATA_SIZE, start, end, rank, num_replicas) results = [] diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index c38d4403..924c62c1 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -204,14 +204,14 @@ def allclose(a, b, atol=1e-6, rtol=1e-6): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_end2end(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') - model = MLP() - ga4_result = _train_ga(model, 4) # micro_batch_size = 4 - assert len(ga4_result) == 16 - # will be used for comparision when zero_use_reduce_scatter is True - ga4_result_without_grads = [] - for i in range(len(ga4_result)): - ga4_result_without_grads.append([ga4_result[i][1], ga4_result[i][2]]) + with torch.device('cuda:0'): + model = MLP() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 + # will be used for comparision when zero_use_reduce_scatter is True + ga4_result_without_grads = [] + for i in range(len(ga4_result)): + ga4_result_without_grads.append([ga4_result[i][1], ga4_result[i][2]]) cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid') # micro_batch_size = 4 for _, v in cube2_results.items(): @@ -311,14 +311,14 @@ def __init__(self): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_pipeline_shared(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') - model = MLPShared() - ga4_result = _train_ga(model, 4) # micro_batch_size = 4 - assert len(ga4_result) == 16 - for step in range(len(ga4_result)): - # fake shared weights for later compare - ga4_result[step][0]['layers.5.weight'] = ga4_result[step][0]['layers.0.weight'] - ga4_result[step][1]['layers.5.weight'] = ga4_result[step][1]['layers.0.weight'] + with torch.device('cuda:0'): + model = MLPShared() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 + for step in range(len(ga4_result)): + # fake shared weights for later compare + ga4_result[step][0]['layers.5.weight'] = ga4_result[step][0]['layers.0.weight'] + ga4_result[step][1]['layers.5.weight'] = ga4_result[step][1]['layers.0.weight'] with pytest.raises(ValueError, match='is not supported in training mode'): ComputeConfig( @@ -356,10 +356,10 @@ def test_pipeline_shared(): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 8, reason='lack of gpu devices') def test_pipeline(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') - model = MLP() - ga4_result = _train_ga(model, 4) # micro_batch_size = 4 - assert len(ga4_result) == 16 + with torch.device('cuda:0'): + model = MLP() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 # pp_size = 2 # tp_size = 2 @@ -441,12 +441,12 @@ def gpu_worker_cube_one_sample(): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_loss_scaling(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') - model = MLP() - ga4_result = _train_ga(model, 1, 1) - assert len(ga4_result) == 1 - ga4_grads = ga4_result[0][0] - scaled_ga4_grads = {n: g * 2.0 for n, g in ga4_grads.items()} + with torch.device('cuda:0'): + model = MLP() + ga4_result = _train_ga(model, 1, 1) + assert len(ga4_result) == 1 + ga4_grads = ga4_result[0][0] + scaled_ga4_grads = {n: g * 2.0 for n, g in ga4_grads.items()} cube2_results = launch_torchrun(2, gpu_worker_cube_one_sample) cube2_result = merge_cube_result({k: v for k, v in cube2_results.items()}) diff --git a/tests/parallel_module/test_end2end_mix_precision.py b/tests/parallel_module/test_end2end_mix_precision.py index c1922b5b..dd399182 100644 --- a/tests/parallel_module/test_end2end_mix_precision.py +++ b/tests/parallel_module/test_end2end_mix_precision.py @@ -171,12 +171,12 @@ def gpu_worker_cube(use_zero=False, async_reducer=False, use_bucket=False): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_mixed_precision(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') init_random() - model = MPModule() - torch.save(model.state_dict(), 'model.pth') - ga4_result = _train_ga(model, 4) # micro_batch_size = 4 - assert len(ga4_result) == 4 + with torch.device('cuda:0'): + model = MPModule() + torch.save(model.state_dict(), 'model.pth') + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 4 cube2_results_non_pipeline = {} for use_async_reducer in [False, True]: diff --git a/tests/parallel_module/test_gencode_ctx_manager.py b/tests/parallel_module/test_gencode_ctx_manager.py index eba20786..42bd7181 100644 --- a/tests/parallel_module/test_gencode_ctx_manager.py +++ b/tests/parallel_module/test_gencode_ctx_manager.py @@ -77,7 +77,7 @@ def check_ctx_manager_codegen(tempdir): # use_scheduler = False # nmicros_per_scheduler_step = 1 # rank = 0 - + # def __init__(self, init_params=True, *, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False): # super().__init__() # # communication groups @@ -85,17 +85,17 @@ def check_ctx_manager_codegen(tempdir): # self.init_group(ranks=[1, 3]) # self.init_group(ranks=[0, 1]) # self.init_group(ranks=[2, 3]) - + # self.register_parameter('param_1_62', torch.nn.Parameter(torch.empty((16, 16), dtype=torch.float32))) # self.add_full_map('param_1_62', 5, True, 'param_1', (16, 16), (slice(0, 16, None), slice(0, 16, None)), 1) - - + + # self.wreducer312 = nnscaler.runtime.adapter.Reducer(ranks=[0, 2], reduce_op='sum', async_op=async_op, zero=False, max_bucket_size_bytes=max_bucket_size_bytes, zero_use_reduce_scatter=zero_use_reduce_scatter, zero_ngroups=1) # self.wreducer312.add_param(self.param_1_62) # self.add_reducer(self.wreducer312) - + # self._post_init(init_params) - + # def segment308(self, x_75, y_78): # # auto_multiref # param_1_106, param_1_107, param_1_108, param_1_109, param_1_110 = nnscaler.runtime.function.multiref(self.param_1_62, times=5) @@ -117,12 +117,12 @@ def check_ctx_manager_codegen(tempdir): # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref # matmul_1_202, matmul_1_228 = nnscaler.runtime.function.multiref(matmul_1_182, times=2) # del matmul_1_182 - + # with torch.no_grad(): # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 24, in forward, r_3 = torch.matmul(r_1, self.param_1) # matmul_2_196 = torch.matmul(matmul_194, param_1_106) # del param_1_106, matmul_194 - + # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref # matmul_2_216, matmul_2_242 = nnscaler.runtime.function.multiref(matmul_2_196, times=2) # del matmul_2_196 @@ -133,12 +133,12 @@ def check_ctx_manager_codegen(tempdir): # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref # matmul_3_252, matmul_3_218 = nnscaler.runtime.function.multiref(matmul_3_204, times=2) # del matmul_3_204 - + # with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True, cache_enabled=True): # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 28, in forward, r_5 = r_3 * r_4 # mul_220 = torch.mul(matmul_2_216, matmul_3_218) # del matmul_2_216, matmul_3_218 - + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 29, in forward, r = r_1 * r_2 * r_3 * r_4 * r_5 # mul_1_230 = torch.mul(matmul_226, matmul_1_228) # del matmul_226, matmul_1_228 @@ -161,11 +161,11 @@ def check_ctx_manager_codegen(tempdir): # norm_61 = torch.norm(matmul_4_72, p='fro', dim=None, keepdim=False, out=None, dtype=None) # del matmul_4_72 # return norm_61 - + # def reducer312(self): # self.wreducer312.sync_grads() - # return - + # return + # def _forward_impl(self, x, y): # norm_61 = self.segment308(x, y) # return norm_61 @@ -301,12 +301,12 @@ def _train_ga(model, update_freq, data_size): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_loss_scaling(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') init_random() - model = CtxManagerModel() - ga4_result = _train_ga(model, 1, 1) - assert len(ga4_result) == 1 - ga4_grads = ga4_result[0][0] + with torch.device('cuda:0'): + model = CtxManagerModel() + ga4_result = _train_ga(model, 1, 1) + assert len(ga4_result) == 1 + ga4_grads = ga4_result[0][0] cube2_results = launch_torchrun(2, gpu_worker_cube_one_sample) cube2_result = merge_cube_result({k: v for k, v in cube2_results.items()}) From cc97940741b3ab49a2f289c562edad77d21571f5 Mon Sep 17 00:00:00 2001 From: yyl9510 <584540273@qq.com> Date: Fri, 30 Jan 2026 14:29:36 +0800 Subject: [PATCH 1887/1892] Add Doc Autodist Constraints Guide (#5) * Add Doc for Autodist Constraints Guide * Revise the description of autodist in the documentation. * fix comment * polish doc --------- Co-authored-by: yileiyang --- docs/source/partition_constraints_guide.md | 128 +++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 docs/source/partition_constraints_guide.md diff --git a/docs/source/partition_constraints_guide.md b/docs/source/partition_constraints_guide.md new file mode 100644 index 00000000..c4aabb45 --- /dev/null +++ b/docs/source/partition_constraints_guide.md @@ -0,0 +1,128 @@ +# Partition Constraints Guide + +Nnscaler allows users to guide the parallelization strategy by specifying constraints. This is useful when you have specific knowledge about how certain operators should be partitioned or if you want to enforce specific behaviors like recomputation or pipeline stages. + +There are two scenarios for providing constraints, they cannot be used at the same time: +1. **Partition constraints in Autodist**: When using autodist, you can use **Autodist YAML Configuration** to define valid partition dimensions for specific operators. +2. **Partition constraints in Function Policy**: You can provide a **Function Policy** that yields `OpPlan` objects to explicitly define the plan for operators. There are no auto search for partitioning except for your definition. + +## Method 1: Partition constraints in Autodist + +You can use a **Autodist YAML Configuration** file to specify which dimensions are allowed for partitioning for specific operators. This is often used to prevent Autodist from partitioning certain operators in ways that are known to be inefficient or problematic (e.g., forcing replication). + +### Configuration Format + +The configuration is a list of dictionaries, each describing a constraint rule. + +```yaml +- allowed_partition_dims: + - 0,0 # List of allowed (input_index, dim_index) pairs + name: torch.sum + parent_module: 'MoE' # Optional: Filter by parent module class name + replica_allowed: false +``` + +### Fields + +* **`name`** (required): The fully qualified name or signature of the operator (e.g., `torch.sum`, `arch.ffn.ffn_func`). +* **`allowed_partition_dims`** (required): A list of strings representing allowed partition strategies. + * Format: `"input_idx,dim_idx"`. + * Example: `"0,0"` means the operator can be partitioned along dimension 0 of input 0. + * If the list is empty, the operator might be forced to replicate (depending on `replica_allowed`). +* **`parent_module`** (optional): If specified, the constraint only applies to operators that are children of a module with this class name. This is useful for targeting specific parts of the model (e.g., only `torch.sum` inside `MoE` layer). +* **`replica_allowed`** (optional, default: `true`): Whether replication is a valid strategy. If `false`, Autodist *must* find a partition strategy from `allowed_partition_dims`. + +### Example + +Below is an example of a custom operator: + +```yaml + +# Constraint for a custom op +- allowed_partition_dims: + - 0,0 + name: arch.all2all_moe.nnscaler_all2all_moe_gmm + parent_module: 'MoE' + replica_allowed: false + +``` + +To use this file, pass its path to `AutoDistConfig`: + +```python +cfg = AutoDistConfig( + ..., + partition_constraints_path='/path/to/constraints.yaml' +) +``` + + +## Method 2: Partition constraints in Function Policy + +For fine-grained control, you can provide a **Function Policy** (the `pas_policy` argument in `parallelize` or `pas_policy` argument in `TrainerArgs`). This function yields `OpPlan` objects which explicitly specify the partitioning strategy for specific nodes. + +### Usage + +Define a function `policy(graph, cfg)` that iterates over the graph nodes and yields `OpPlan` objects. + +**Important Considerations:** +If you choose to manually partition operators (especially for complex communication patterns), you often need to define `OpPlan` for **all connected operators** that share the partition logic, this means: +1. you must define OpPlans for all ops you want to partition. +2. the default OpPlans is replicated if you don't define `OpPlan` for ops. +3. The only exception is when you define `OpPlan.partition` to `auto`, which will try to partition the op based on its input partitions. + +### `OpPlan` Parameters + +The `OpPlan` class defines the strategy for a single operator. + +```python +class OpPlan: + def __init__(self, op, partition='auto', recompute_id=-1, stage_id=-1, ...): + ... +``` + +* **`op`**: The graph node (`IRFwOperation`) this plan applies to. +* **`partition`**: define the partitioning strategy. + * `OpPartition(input=i, dim=d)`: Partition the operator based on the `d`-th dimension of its `i`-th input tensor. + * `'auto'` (default): Tries to follow the partition of its inputs. If no input is partitioned, it just replicate the operator. + * `None`: Force the operator to be replicated (no partitioning). +* **`recompute_id`** (default: -1): + * Used to group operators for Recompute (Gradient Checkpointing). + * Operators with the same non-negative `recompute_id` will be grouped into a single recomputation block. + * These operators with the same `recompute_id` should be consecutive in the graph. +* **`stage_id`** (default: -1): + * Used for Pipeline Parallelism assignment. + * These operators with the same `stage_id` should be consecutive in the graph. +* **`pre_hook` / `post_hook`**: + * You can attach custom Python functions to be executed before or after the operator. See source code for signature details. + +### Example: Custom Partitioning and Recomputation + +This example demonstrates how to: +1. Use helper functions: `get_pas_ops` (filter for relevant ops), `get_layer_index` (extract layer ID from name), and `get_called_self_module_name` (identify sub-module names like `gate_proj`). +2. Filter operations by the module class chain (e.g., targeting `FFNDropout`). +3. Assign `recompute_id` and `stage_id` dynamically based on the model's layer index. +4. Apply different partition strategies based on the specific module being called. + +```python +from nnscaler.policies import OpPlan, OpPartition, get_layer_index, get_called_self_module_name, get_pas_ops +import torch + +def custom_policy(graph): + for node in get_pas_ops(graph): + if FFNDropout not in node.module_class_chain: # work only on FFN module + continue + + ffn_idx = get_layer_index(node.fqn) + module_called = get_called_self_module_name(node.call_expr) + + if node.fn == torch.nn.functional.linear: + if module_called in ['gate_proj', 'up_proj']: + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition=OpPartition(input=1, dim=0)) + else: + # down_proj + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition=OpPartition(input=1, dim=1)) + else: + # other ops + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition='auto') +``` From cbbe521007deb73479c7f4bdfdf3aaf29f8ff79e Mon Sep 17 00:00:00 2001 From: yyl9510 <584540273@qq.com> Date: Fri, 30 Jan 2026 15:20:47 +0800 Subject: [PATCH 1888/1892] CI/CD (#11) * add ci.yml for continuous integration and continuous development * branch self test switch back to conda from uv change back to python3.10 from python 3.12 * add conda-forge channel for tox-conda package * add py lib for tox-conda * change conda to uv because tox-conda is too old * using uv tool for fixing setup issue * pip is needed for tox 3.0, while azuer use 3.0 don't need this permission * add all possible commands to allowlist * change back to main branch * back to conda with fixed tox and tox-conda version * change back to uv --------- Co-authored-by: yileiyang --- .github/workflows/ci.yml | 32 ++++++++++++++++++++++++++++++++ requirements-dev.txt | 2 +- tox.ini | 6 ++++-- 3 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..abe93c90 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI Pipeline + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + name: Run Tests + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install system build tools + run: | + sudo apt-get update + sudo apt-get install -y build-essential python3-dev + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - name: Install tox + run: uv tool install tox --with tox-uv + + - name: Run unit tests + run: tox \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 5be4b7d4..df05299d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,7 +13,7 @@ sphinx sphinxcontrib-napoleon tabulate tox -tox-conda +tox-uv yapf wandb tensorboard diff --git a/tox.ini b/tox.ini index 04e0c3ae..0cfe015c 100644 --- a/tox.ini +++ b/tox.ini @@ -8,9 +8,11 @@ envlist = py310 skipsdist = True [testenv] -allowlist_externals = rm +allowlist_externals = + rm + uv passenv = * -install_command = pip install {opts} {packages} +install_command = uv pip install {opts} {packages} deps = -rrequirements.txt -rrequirements-dev.txt From 8ce8b09fb86bd7ff26cd5342e5b5211db9f55520 Mon Sep 17 00:00:00 2001 From: yyl9510 <584540273@qq.com> Date: Tue, 3 Feb 2026 17:39:42 +0800 Subject: [PATCH 1889/1892] Add nightly test to the repo (#12) * Add nightly test to the repo * make parity alignment * fix incoordination shutil.rmtree of rank0 --------- Co-authored-by: yileiyang --- examples/llama/create_mini_model.py | 2 + tests/utils.py | 1 + utility/nightly_test/nightly_test.py | 228 ++++++++++++++++++ .../parity_alert_examples/parity_alert.sh | 164 +++++++++++++ .../parity_alert_examples/parity_check.py | 48 ++++ .../test_cases/llama/config.yaml | 16 ++ .../test_cases/llama3_demo/config.yaml | 12 + .../test_cases/nanogpt/config.yaml | 13 + .../parity_alert_examples/train.py | 60 +++++ utility/nightly_test/test_utils.py | 183 ++++++++++++++ 10 files changed, 727 insertions(+) create mode 100644 utility/nightly_test/nightly_test.py create mode 100644 utility/nightly_test/parity_alert_examples/parity_alert.sh create mode 100644 utility/nightly_test/parity_alert_examples/parity_check.py create mode 100644 utility/nightly_test/parity_alert_examples/test_cases/llama/config.yaml create mode 100644 utility/nightly_test/parity_alert_examples/test_cases/llama3_demo/config.yaml create mode 100644 utility/nightly_test/parity_alert_examples/test_cases/nanogpt/config.yaml create mode 100644 utility/nightly_test/parity_alert_examples/train.py create mode 100644 utility/nightly_test/test_utils.py diff --git a/examples/llama/create_mini_model.py b/examples/llama/create_mini_model.py index 1151771e..514ac3ac 100644 --- a/examples/llama/create_mini_model.py +++ b/examples/llama/create_mini_model.py @@ -3,9 +3,11 @@ import argparse from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +import torch def main(args): + torch.manual_seed(0) # Ensure deterministic initialization config = AutoConfig.from_pretrained(args.model_id) config.num_hidden_layers = 4 config.use_cache = False diff --git a/tests/utils.py b/tests/utils.py index d29e4168..4e3bd85a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -382,6 +382,7 @@ def catch_stdout(): def clear_dir_on_rank0(tempdir): if torch.distributed.get_rank() == 0 and tempdir.exists(): shutil.rmtree(tempdir) + torch.distributed.barrier() yield tempdir torch.distributed.barrier() if torch.distributed.get_rank() == 0 and tempdir.exists(): diff --git a/utility/nightly_test/nightly_test.py b/utility/nightly_test/nightly_test.py new file mode 100644 index 00000000..99eb2b2d --- /dev/null +++ b/utility/nightly_test/nightly_test.py @@ -0,0 +1,228 @@ +from test_utils import TestUtils +from azure.communication.email import EmailClient +from subprocess import CalledProcessError +import subprocess +from datetime import datetime, timedelta +from pathlib import Path +import argparse +import base64 +import zipfile +import json +import os +import sys + +sender_address = "DoNotReply@ca1e34f6-1a6d-4181-8b16-692dbe193525.azurecomm.net" + +def zip_folder(folder_path, output_path): + """ Zip the folder to the output path + Args: + folder_path: the folder path + output_path: the output path + """ + with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in os.walk(folder_path): + for file in files: + relative_path = os.path.relpath(os.path.join(root, file), os.path.dirname(folder_path)) + zipf.write(os.path.join(root, file), arcname=relative_path) + +def get_branch_commit(repo_path, branch_name = None, days_ago = 0): + """ Get the branch name or commit ID of the branch_name that is days_ago + Args: + repo_path: the path of the git repo + branch_name: the branch name, if not provided return the branch name + days_ago: the days ago, 0 means get current commit ID of the branch + Returns: + The branch name of the commit ID + """ + if branch_name is None: + git_command = 'git rev-parse --abbrev-ref HEAD' + return TestUtils.execute_command(git_command, repo_path) + elif days_ago == 0: + git_command = 'git rev-parse HEAD' + return TestUtils.execute_command(git_command, repo_path) + else: + before_date = (datetime.now() - timedelta(days=int(days_ago)) + timedelta(hours=15)).strftime('%Y-%m-%d %H:%M:%S') # add 15 hours to align with Beijing time + git_command = 'git fetch && git rev-list -n 1 --before="{}" {}'.format(before_date, branch_name) + return TestUtils.execute_command(git_command, repo_path) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description='Running Nightly Test') + parser.add_argument('-w', '--workspace', required=True, help='workspace for nightly test') + parser.add_argument('-d', '--data-path', required=True, help='dataset path') + + parser.add_argument('-n', '--nnscaler-commit-id', help='nnscaler commit id, decide the version of nnscaler for unit test and example parity-check') + + parser.add_argument('-u', '--unit-test', default=False, action=argparse.BooleanOptionalAction, help='unit test for nnscaler') + + parser.add_argument('-ep', '--example-parity-check', default=False, action=argparse.BooleanOptionalAction, help='example parity check for nnscaler. It will compare nnscaler or main with or main') + + # Keeping old argument name for compatibility if needed, but help text updated + parser.add_argument('-p2', '--parity-check2', dest='example_parity_check', action='store_true', help='Alias for --example-parity-check') + + parser.add_argument('-pb', '--parity-check-conda-base', help='base conda environment for parity check, needed if example-parity-check is True') + parser.add_argument('-ngt', '--cube-branch-gt', default='main', help='cube branch for ground truth, default is main') + + parser.add_argument('-e', '--email-connect-string', help='email connect string for sending email address') + parser.add_argument('-et', '--email-to', action='append', default=[], help='multiple -et will be combined') + parser.add_argument('-ec', '--email-cc', action='append', default=[], help='multiple -ec will be combined') + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_arguments() + workspace = Path(args.workspace).expanduser() + data_path = Path(args.data_path).expanduser() + if not workspace.exists(): + raise ValueError(f"Invalid workspace path: {workspace}") + if not data_path.exists(): + raise ValueError(f"Invalid data_path path: {data_path}") + log_folder = TestUtils.gen_log_folder(workspace) + + # Assuming nnscaler is cloned as "nnscaler" in the workspace + nnscaler_repo_path = workspace / "nnscaler" + pytest_dir = nnscaler_repo_path / "tests" + + script_dir = Path(__file__).parent.absolute() + parity_alert_script = script_dir / "parity_alert_examples/parity_alert.sh" + parity_check_cases_dir = script_dir / "parity_alert_examples/test_cases" + + if not pytest_dir.exists(): + raise ValueError(f"Invalid pytest_dir path: {pytest_dir}") + if not parity_alert_script.exists(): + raise ValueError(f"Invalid parity_alert_script path: {parity_alert_script}") + + if args.nnscaler_commit_id: + cmd = f"parallel-ssh -x -q -h ~/.pssh_hosts_files git -C {nnscaler_repo_path} checkout {args.nnscaler_commit_id}" + TestUtils.call([cmd]) + + with open(TestUtils.gen_log_folder(workspace) / "nightly_test.log", 'a') as nightly_test_file: + nnscaler_branch = get_branch_commit(nnscaler_repo_path) + nnscaler_commit_id = get_branch_commit(nnscaler_repo_path, nnscaler_branch) + nightly_test_file.write(f"nnscaler on branch {nnscaler_branch}, commit ID {nnscaler_commit_id}\n\n") + nightly_test_file.flush() + + # Run Unit Test + pytest_output = "" + if args.unit_test: + pytest_cmd = f"{sys.executable} -m pytest {pytest_dir}" + try: + pytest_log_file = log_folder / "pytest.log" + with open(pytest_log_file, 'w') as f: + # Run pytest from inside nnscaler repo + result = subprocess.run([sys.executable, '-m', 'pytest', '-v', str(pytest_dir)], stdout=f, stderr=f, cwd=nnscaler_repo_path) + if result.returncode != 0: + pytest_output = f"NNScaler Unit test didn't pass, see {pytest_log_file.name} for more details." + else: + pytest_output = "NNScaler Unit test passed" + except CalledProcessError as e: + pytest_output = f"Command {pytest_cmd} failed with error code {e.returncode}" + finally: + nightly_test_file.write(pytest_output + "\n") + + # Run Example Parity Check + parity_alert_output = "" + if args.example_parity_check: + tmp_parity_check = workspace / 'tmp_example_parity_check' + if os.path.isdir(tmp_parity_check): + import shutil + shutil.rmtree(tmp_parity_check) + + if not args.nnscaler_commit_id: + # If not specified, get the current one for consistency in logging/checking + args.nnscaler_commit_id = get_branch_commit(nnscaler_repo_path, "origin/main", 0) + + nightly_test_file.write(f"Example Parity check:\nnnscaler commit ID: {args.nnscaler_commit_id}" + "\n") + + parity_check_cmd = f"bash {parity_alert_script} {tmp_parity_check} {data_path} {parity_check_cases_dir} --cube-branch {args.nnscaler_commit_id} --cube-branch-gt {args.cube_branch_gt} --conda-base {args.parity_check_conda_base}" + + env = os.environ.copy() + # Assuming we might need to set PYTHONPATH if needed for some scripts, but usually the parity script handles env setup + # But let's keep consistency if we copied parity_alert_examples which relies on some imports + try: + parity_log_file = log_folder / "example_parity_check.log" + with open(parity_log_file, 'w') as f: + # CWD to the directory of parity_alert_examples for any relative path assumptions inside train.py potentially + cwd_path = script_dir / "parity_alert_examples" + result = subprocess.run(parity_check_cmd, stdout=f, stderr=f, shell=True, env=env, cwd=cwd_path) + if result.returncode != 0: + parity_alert_output = f"Example Parity Check didn't pass, see {parity_log_file.name} for more details." + else: + parity_alert_output = "Example Parity Check passed" + except CalledProcessError as e: + parity_alert_output = f"Command {parity_check_cmd} failed with error code {e.returncode}" + finally: + nightly_test_file.write(parity_alert_output + "\n") + + nightly_test_file.flush() + + # Send email + if args.email_connect_string: + if not args.email_to: + raise ValueError(f"Invalid email_to: {args.email_to}") + zip_output = log_folder.parent / 'nightly_test_logs.zip' + zip_folder(log_folder, zip_output) + with open(zip_output, "rb") as file: + zip_b64encoded = base64.b64encode(file.read()) + + html_output = """ + + + Test Results + + + + """ + + if args.unit_test: + pytest_html_message = f"""

NNScaler Unit Test

{pytest_output}

""" + html_output += pytest_html_message + + if args.example_parity_check: + parity_html_message = f"""

Example Parity Check

{parity_alert_output}

""" + html_output += parity_html_message + + html_output +="""""" + + message = { + "senderAddress": sender_address, + "recipients": { + "to": [{ "address": t } for t in args.email_to], + "cc": [{ "address": t } for t in args.email_cc] + }, + "content": { + "subject": "Nightly Test Notification", + "html": html_output + }, + "attachments": [ + { + "name": "attachment.zip", + "contentType": "application/zip", + "contentInBase64": zip_b64encoded.decode() + } + ] + } + + try: + POLLER_WAIT_TIME = 10 + client = EmailClient.from_connection_string(args.email_connect_string) + poller = client.begin_send(message) + time_elapsed = 0 + while not poller.done(): + poller.wait(POLLER_WAIT_TIME) + time_elapsed += POLLER_WAIT_TIME + if time_elapsed > 18 * POLLER_WAIT_TIME: + raise RuntimeError("Polling timed out.") + if poller.result()["status"] == "Succeeded": + nightly_test_file.write(f"Successfully sent the email (operation id: {poller.result()['id']})") + else: + raise RuntimeError(str(poller.result()["error"])) + except Exception as ex: + nightly_test_file.write(str(ex)) + else: + nightly_test_file.write("No email connection string provided, skip sending email") diff --git a/utility/nightly_test/parity_alert_examples/parity_alert.sh b/utility/nightly_test/parity_alert_examples/parity_alert.sh new file mode 100644 index 00000000..7df42bf2 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/parity_alert.sh @@ -0,0 +1,164 @@ +#!/bin/bash + +# For parity check. +# Example: +# bash parity_alert.sh [] +# : the workspace where all codes are stored. +# : the folder when the train data for torchscale is stored. +# : the definition of parity check. +# Default value is ${the dir of the current script}/test_cases/ +# Options: +# --cube-branch-gt : default is main +# --cube-branch : default is main +# --conda-base : default is base +# --test-cases : default is all +# The test cases are listed under (`test_cases/`) folder, e.g., pasdata, dp2, tp2, hybrid2. +# +# Currently the workspace is not cleared after execution, so it can help fix the parity problem if any. +# To clean the workspace +# 1. run `rm -rf ` to clean the cloned source code. +# 2. run `conda env remove -n parity` to remove conda env. + +set -e + +export NCCL_DEBUG=WARN + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +POSITIONAL_ARGS=() + +CUBE_BRANCH_GT=main + +CUBE_BRANCH_NEW=main + +CONDA_ENV_BASE=base +TEST_CASES= + +while [[ $# -gt 0 ]]; do + case $1 in + --cube-branch-gt) + CUBE_BRANCH_GT="$2" + shift # past argument + shift # past value + ;; + --cube-branch) + CUBE_BRANCH_NEW="$2" + shift # past argument + shift # past value + ;; + --conda-base) + CONDA_ENV_BASE="$2" + shift # past argument + shift # past value + ;; + --test-cases) + TEST_CASES="$2" + shift # past argument + shift # past value + ;; + -*|--*) + echo "Unknown option $1" + exit 1 + ;; + *) + POSITIONAL_ARGS+=("$1") # save positional arg + shift # past argument + ;; + esac +done + +set -- "${POSITIONAL_ARGS[@]}" # restore positional parameters + +OPERATION=$1 + +if [[ $# -ne 2 ]] && [[ $# -ne 3 ]]; then + echo "usage: $0 WORKSPACE TRAIN_DATA_DIR [PARITY_CHECK_DATA_DIR]" + echo " [--cube-branch-gt ]" + echo " [--cube-branch ]" + echo " [--conda-base ]" + echo " [--test-cases ]" + exit 1 +fi + + +WORKSPACE=$1 +TRAIN_DATA_DIR=$2 +PARITY_CHECK_DATA_DIR=${3:-${SCRIPT_DIR}/test_cases} + +if [[ -d $WORKSPACE ]]; then + echo "Error: $WORKSPACE has existed, please remove the folder before running the test(s)." + exit 2 +fi + + +ENV_NAME=parity_$(echo $RANDOM | md5sum | head -c 10) +TMP_SETUP_ENV_SH=tmp_setup_env.sh +TMP_SWITCH_BRANCH_SH=tmp_switch_branch.sh +TMP_MODEL_DIR=result_models # will not be removed after execution +# get an unused port +UNUSED_PORT=`python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()'` + +conda create -y -n ${ENV_NAME} --clone ${CONDA_ENV_BASE} + +LIBSTDC_PATH=$(conda env list | grep ${ENV_NAME} | awk '{print $NF}')/lib/libstdc++.so.6 +rm -f ${LIBSTDC_PATH} + +trap "rm -rf tmp_* && conda env remove -n ${ENV_NAME} -y" EXIT + +cat > ${TMP_SETUP_ENV_SH} << EOF +#!/bin/bash + +set -e + +# init python env +pip install build + +mkdir -p ${WORKSPACE} +cd ${WORKSPACE} + +git clone --recursive "https://github.com/msrasys/nnscaler.git" -b $CUBE_BRANCH_GT +cd nnscaler +# Rename directory to match expected 'MagicCube' or just adapt strict usage. +# The original script used 'MagicCube' directory name. Let's stick effectively to cloning nnscaler. +# However, train.py and others might expect import structure. + +pip install -e . + +python -c 'import os,sys,nnscaler,cppimport.import_hook ; sys.path.append(os.path.dirname(nnscaler.__path__[0])) ; import nnscaler.autodist.dp_solver' +cd .. + +# verify installation +python -c 'import torch; import nnscaler; print(torch.__path__, nnscaler.__path__)' + +EOF + +cat > ${TMP_SWITCH_BRANCH_SH} << EOF +#!/bin/bash + +set -e + +cd ${WORKSPACE} + +cd nnscaler +git checkout $CUBE_BRANCH_NEW + +pip install -e . + +python -c 'import os,sys,nnscaler,cppimport.import_hook ; sys.path.append(os.path.dirname(nnscaler.__path__[0])) ; import nnscaler.autodist.dp_solver' + +cd .. +EOF + +export TEST_CASES="$TEST_CASES" +export TRAIN_DATA_DIR="$TRAIN_DATA_DIR" +export UNUSED_PORT="$UNUSED_PORT" +export DETERMINISTIC=1 + +conda run --no-capture-output -n ${ENV_NAME} bash ${TMP_SETUP_ENV_SH} + +conda run --no-capture-output -n ${ENV_NAME} python ${SCRIPT_DIR}/train.py ${WORKSPACE} ${PARITY_CHECK_DATA_DIR} ${TMP_MODEL_DIR}/gt + +conda run --no-capture-output -n ${ENV_NAME} bash ${TMP_SWITCH_BRANCH_SH} + +conda run --no-capture-output -n ${ENV_NAME} python ${SCRIPT_DIR}/train.py ${WORKSPACE} ${PARITY_CHECK_DATA_DIR} ${TMP_MODEL_DIR}/new + +conda run --no-capture-output -n ${ENV_NAME} python ${SCRIPT_DIR}/parity_check.py ${TMP_MODEL_DIR}/gt ${TMP_MODEL_DIR}/new diff --git a/utility/nightly_test/parity_alert_examples/parity_check.py b/utility/nightly_test/parity_alert_examples/parity_check.py new file mode 100644 index 00000000..94d852f8 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/parity_check.py @@ -0,0 +1,48 @@ +import os +from pathlib import Path +import sys + +import torch + + +def parity_check(task_name, ground_truth_model_file, new_model_file): + gt_ckpt = torch.load(ground_truth_model_file, map_location='cpu', weights_only=False) + new_ckpt = torch.load(new_model_file, map_location='cpu', weights_only=False) + if 'model' in gt_ckpt: + gt_model = gt_ckpt['model'] + new_model = new_ckpt['model'] + elif 'state_dict' in gt_ckpt: + gt_model = gt_ckpt['state_dict'] + new_model = new_ckpt['state_dict'] + for name in gt_model: + if not torch.allclose(gt_model[name], new_model[name], rtol=1e-06, atol=1e-06): + raise Exception(f'{task_name} failed: {name} mismatch (rtol=1e-06, atol=1e-06)') + print('All weights match (rtol=1e-06, atol=1e-06)') + + +def main(gt_dir: str, new_dir: str): + new_dir = Path(new_dir).absolute() + + test_cases = os.getenv('TEST_CASES') + if test_cases: + test_cases = test_cases.split(',') + print(f'Check test cases: {test_cases}') + else: + test_cases = None + print('Check all test cases') + passed = [] + for d in Path(gt_dir).glob('*'): + if not d.is_dir(): + continue + if not test_cases or d.name in test_cases: + print(f'Checking for {d.name}...') + parity_check(d.name, d / 'model.pt', new_dir / d.name / 'model.pt') + passed.append(d.name) + print(f'All passed: {passed}') + + +if __name__ == '__main__': + if len(sys.argv) !=3: + print('Usage: python check.py ') + exit(1) + main(sys.argv[1], sys.argv[2]) diff --git a/utility/nightly_test/parity_alert_examples/test_cases/llama/config.yaml b/utility/nightly_test/parity_alert_examples/test_cases/llama/config.yaml new file mode 100644 index 00000000..ffa2133e --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/test_cases/llama/config.yaml @@ -0,0 +1,16 @@ +# NOTE: +# Must set HF_TOKEN +# Must install apex and flash-attn manually + +name: Llama 3 8B 128K +train: + path: nnscaler/examples/llama + output: ./merged.ckpt + commands: + - rm -rf .nnscaler ./checkpoints ./merged.ckpt + - pip install -r requirements.txt + - python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 + - python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini + - python train.py --run_mode compile --plan_ngpus 4 --runtime_ngpus 4 --name llama3_debug --model_id ./llama3_mini --attn_implementation sdpa --dataset_path ./bookcorpus_llama3_4K --max_train_steps 50 --pipeline_pivots LlamaDecoderLayer --pipeline_nstages 2 + - torchrun --nproc_per_node=4 train.py --plan_ngpus 4 --runtime_ngpus 4 --name llama3_debug --model_id ./llama3_mini --attn_implementation sdpa --dataset_path ./bookcorpus_llama3_4K --max_train_steps 50 --pipeline_pivots LlamaDecoderLayer --pipeline_nstages 2 + - python ckpt_merger.py --ckpt_dir ./checkpoints/last --output_fname ./merged.ckpt diff --git a/utility/nightly_test/parity_alert_examples/test_cases/llama3_demo/config.yaml b/utility/nightly_test/parity_alert_examples/test_cases/llama3_demo/config.yaml new file mode 100644 index 00000000..8dbcb120 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/test_cases/llama3_demo/config.yaml @@ -0,0 +1,12 @@ +# NOTE: Must set HF_TOKEN + +name: Llama 3 demo +train: + path: nnscaler/examples/llama3_demo + output: checkpoints/merged.ckpt + commands: + - rm -rf .nnscaler ./checkpoints + - pip install -r requirements.txt + - python train.py --prepare_data --mini + - torchrun --nproc_per_node=4 train.py --mini --max_train_steps=50 + - python train.py --merge_checkpoint=./checkpoints/last diff --git a/utility/nightly_test/parity_alert_examples/test_cases/nanogpt/config.yaml b/utility/nightly_test/parity_alert_examples/test_cases/nanogpt/config.yaml new file mode 100644 index 00000000..6f1df5f1 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/test_cases/nanogpt/config.yaml @@ -0,0 +1,13 @@ +name: nanoGPT lightning +train: + path: nnscaler/examples/nanogpt + output: _merge/merged.pt + commands: + - rm -rf .nnscaler + - rm -rf lightning_logs _merge + - pip install -r requirements.txt + - python nanoGPT/data/shakespeare_char/prepare.py + - torchrun --nproc_per_node=2 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --max_iters=100 + - mkdir _merge + - cp lightning_logs/version_0/checkpoints/*/*.pt _merge + - python -c "from nnscaler.integration.lightning.pytorch import NnScalerStrategy ; NnScalerStrategy.merge_checkpoint(['_merge/0.pt', '_merge/1.pt'], '_merge/merged.pt')" diff --git a/utility/nightly_test/parity_alert_examples/train.py b/utility/nightly_test/parity_alert_examples/train.py new file mode 100644 index 00000000..c96e4b10 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/train.py @@ -0,0 +1,60 @@ +from pathlib import Path +from functools import partial +from subprocess import check_call as _call, check_output +import os +import sys + +import shutil +import yaml + +call = partial(_call, shell=True) + + +def train_model(config_dir: Path, save_dir: Path): + save_dir = save_dir / config_dir.name + save_dir.mkdir(parents=True, exist_ok=True) + config_file = config_dir / 'config.yaml' + + with open(config_file) as f: + config = yaml.safe_load(f) + + path = Path(config['train']['path']).absolute() + new_model = path / config['train']['output'] + env = {} + env.update(os.environ) + env.update({ + 'TRAIN_DATA_DIR': str(Path(os.getenv('TRAIN_DATA_DIR'))), + 'CONFIG_DIR': str(config_dir), + 'SAVE_DIR': str(save_dir), + 'RDZV_ENDPOINT': 'localhost:' + os.getenv('UNUSED_PORT'), + }) + env.update(config['train'].get('envs', {})) + for command in config['train']['commands']: + call(command, env=env, cwd=path) + shutil.copy2(new_model, save_dir / 'model.pt') + + +def main(workspace: str, parity_check_dir: str, parity_save_dir: str): + parity_check_root = Path(parity_check_dir).absolute() + parity_save_root = Path(parity_save_dir).absolute() + os.chdir(workspace) + test_cases = os.getenv('TEST_CASES') + if test_cases: + test_cases = test_cases.split(',') + print(f'Run test cases: {test_cases}') + else: + test_cases = None + print('Run all test cases') + for d in parity_check_root.glob('*'): + if not d.is_dir(): + continue + if not test_cases or d.name in test_cases: + print(f'Training for {d.name}...') + train_model(d, parity_save_root) + + +if __name__ == '__main__': + if len(sys.argv) !=4: + print('Usage: python train.py ') + exit(1) + main(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/utility/nightly_test/test_utils.py b/utility/nightly_test/test_utils.py new file mode 100644 index 00000000..bec73fcf --- /dev/null +++ b/utility/nightly_test/test_utils.py @@ -0,0 +1,183 @@ +from subprocess import CalledProcessError +import subprocess +import asyncio +import logging +import logging.handlers +from pathlib import Path +import os +import copy +import yaml + +logging.basicConfig( + filemode='a', + format="%(asctime)s | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + +logger = logging.getLogger("CubeSystemTest") +result_logger = logging.getLogger("cube_system_test_results") +warning_logger = logging.getLogger("CompareInterface") + +# smaller buffer to quick output +buffer_handler = logging.handlers.BufferingHandler(10) +logger.addHandler(buffer_handler) +result_logger.addHandler(buffer_handler) + +global_time = None + +class TestUtils: + @staticmethod + def execute_command(cmd: str, cwd: str): + """Execute a command and log the output""" + try: + result = subprocess.check_output(cmd, shell=True, cwd=cwd).decode('utf-8').strip() + return result + except subprocess.CalledProcessError as e: + print("An error occurred while trying to execute:", cmd) + return None + + @staticmethod + def call(cmds): + """Call commands async and log the output""" + + if isinstance(cmds, str): + cmds = [cmds] + try: + results = asyncio.run(TestUtils.run_commands_async(cmds)) + for result in results: + stdout, stderr = result + if stdout: + logger.info(f'{stdout.decode()}') + if stderr: + err_msg = stderr.decode().strip() + if ("Traceback (most recent call last):" in err_msg): + result_logger.error(f'{err_msg}') + else: + logger.error(f'{err_msg}') + except CalledProcessError as e: + result_logger.error(f"Commands {cmds} failed with error code {e.returncode}") + raise + + @staticmethod + async def run_command_async(cmd: str): + """run a command async and return the output of stdout and stderr""" + logger.info(f"Running command: {cmd}") + proc = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + stdout, stderr = await proc.communicate() + return stdout, stderr + + @staticmethod + async def run_commands_async(cmds: list): + """run commands async and return the output of stdout and stderr""" + tasks = [asyncio.ensure_future(TestUtils.run_command_async(cmd)) for cmd in cmds] + results = await asyncio.gather(*tasks) + return results + + @staticmethod + def get_ipv4_address(): + import re + interface_name = subprocess.check_output("route -n | grep '^0.0.0.0' | awk '{print $8}'", shell=True).decode().strip() + ifconfig_output = subprocess.check_output(f"ifconfig {interface_name}", shell=True).decode() + ip_address_match = re.search(r'inet (\S+)', ifconfig_output) + ip_pattern = re.compile( + r'^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.' + r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.' + r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.' + r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$' + ) + if ip_address_match: + ip_addr = ip_address_match.group(1) + if ip_addr and ip_pattern.match(ip_addr): + return ip_addr + + if os.getenv("CUBE_MASTER_ADDR"): + ip_addr = os.getenv("CUBE_MASTER_ADDR").strip() + if ip_addr and ip_pattern.match(ip_addr): + return ip_addr + + raise RuntimeError(f"cannot get ip address for interface {interface_name}, you can set master_addr manually by setting the environment variable CUBE_MASTER_ADDR") + + @staticmethod + def gen_log_folder(workspace): + global global_time + if global_time is None: + from datetime import datetime + now = datetime.now() + global_time = now.strftime("%Y%m%d_%H%M%S") + log_folder = Path(workspace) / 'cube_test_logs' / global_time + if not log_folder.exists(): + log_folder.mkdir(parents=True, exist_ok=True) + return log_folder + + @staticmethod + def parse_hosts_file(file_path): + file_path = Path(file_path).expanduser() + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + with open(file_path, 'r') as file: + lines = file.readlines() + ssh_host_list = [line.strip() for line in lines] + return ssh_host_list + + @staticmethod + def logger_redirect(logger1, log_folder, filename) -> tuple[str, logging.FileHandler]: + import logging.handlers + file_path = f"{log_folder}/{filename}.log" + result_handler = logging.FileHandler(file_path, 'a') + formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s") + result_handler.setFormatter(formatter) + logger1.addHandler(result_handler) + return file_path, result_handler + + @staticmethod + def merge_dict(dict_a, dict_b): + a = copy.deepcopy(dict_a) + b = copy.deepcopy(dict_b) + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + a[key] = TestUtils.merge_dict(a[key], b[key]) + elif b[key] is None or (b[key] == {}): + continue + else: + a[key] = b[key] + else: + a[key] = b[key] + return a + + @staticmethod + def merge_dicts(*dicts): + result = {} + for current_dict in dicts: + result = TestUtils.merge_dict(result, current_dict) + return result + + @staticmethod + def load_yaml_file(file_path) -> dict: + with open(file_path, 'r') as f: + element = yaml.safe_load(f) + if isinstance(element, dict): + TestUtils.recursive_replace_keys(element, '_', '-', 'fairseq') + TestUtils.recursive_replace_keys(element, '-', '_', 'torchrun') + TestUtils.recursive_replace_keys(element, '-', '_', 'envs') + return element + else: + raise ValueError(f"Invalid config_file {file_path}") + + @staticmethod + def recursive_replace_keys(d, old_char, new_char, target_key): + if target_key in d: + target_dict = d[target_key] + d[target_key] = {k.replace(old_char, new_char): v for k, v in target_dict.items()} + else: + for _, value in d.items(): + if isinstance(value, dict): + TestUtils.recursive_replace_keys(value, old_char, new_char, target_key) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + TestUtils.recursive_replace_keys(item, old_char, new_char, target_key) \ No newline at end of file From b9e79934a70139bd2050b755e3bcbc4912be83f8 Mon Sep 17 00:00:00 2001 From: XU Weijiang <90586345+0xWJ@users.noreply.github.com> Date: Wed, 4 Feb 2026 13:49:04 +0800 Subject: [PATCH 1890/1892] [BugFix] Refine HybridOptimizer to support mixed precision optimizer (#10) 1. Add a new Mixin (ScaleDelayedOptimizerMixin) to support MixedPrecisionAdam like optimizers 2. Refine HybridOptimizer to support ScaleDelayedOptimizerMixin. --------- Co-authored-by: zyeric --- nnscaler/runtime/f16_optimizer.py | 76 +++++----- nnscaler/runtime/hybrid_optimizer.py | 121 +++++++++++++++- tests/runtime/test_hybrid_optimizer.py | 130 ++++++++++++++++++ ...ptimizer_trainer_args_mixed_precision.yaml | 54 ++++++++ ...timizer_trainer_args_mixed_precision2.yaml | 54 ++++++++ 5 files changed, 402 insertions(+), 33 deletions(-) create mode 100644 tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision.yaml create mode 100644 tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision2.yaml diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py index 908a4f3e..bee85b89 100644 --- a/nnscaler/runtime/f16_optimizer.py +++ b/nnscaler/runtime/f16_optimizer.py @@ -4,11 +4,12 @@ # CREDITS: This implementation is inspired by Fairseq https://github.com/facebookresearch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py import logging -from typing import Optional, TYPE_CHECKING +import types +from typing import TYPE_CHECKING import torch -from nnscaler.cli.train_hook import TrainHook +from nnscaler.runtime.hybrid_optimizer import ScaleDelayedOptimizerMixin if TYPE_CHECKING: from nnscaler.cli.trainer import Trainer @@ -16,7 +17,7 @@ logger = logging.getLogger(__name__) -class MixedPrecisionF16OptimizerMixin(TrainHook): +class MixedPrecisionF16OptimizerMixin(ScaleDelayedOptimizerMixin): """ A mixin class for mixed precision optimizer. Support both FP16 and BF16 parameters. @@ -31,7 +32,6 @@ class MixedPrecisionF16OptimizerMixin(TrainHook): def __init__(self, *args, **kwargs): # forward __init__ call to the next class in mro(method resolution order) super().__init__(*args, **kwargs) - self._multiply_factor = 1.0 # This flag is used to indicate whether fp32_params are loaded from checkpoint. # If not, we will sync from fp16 params to fp32 params in after_load_checkpoint. # If the model is trained from scratch, this flag will be None. @@ -48,17 +48,25 @@ def after_setup(self, trainer: 'Trainer') -> None: Assumption: `clip_gnorm` is called immediately after `scale_grads` in training loop. """ - trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm - trainer.optimizer.clip_gnorm = self.overrided_clip_gnorm - trainer.optimizer._scale_grads = trainer.optimizer.scale_grads - trainer.optimizer.scale_grads = self.overrided_scale_grads + if trainer.optimizer is self: + # don't override when using HybridOptimizer + trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm + trainer.optimizer.clip_gnorm = self.overrided_clip_gnorm + trainer.optimizer._scale_grads = trainer.optimizer.scale_grads + trainer.optimizer.scale_grads = self.overrided_scale_grads + + # step method is overrided below to apply the scaling factor @classmethod - def build_fp32_params(cls, params): + def build_fp32_params(cls, params: list[torch.nn.Parameter]) -> list[torch.nn.Parameter]: # create FP32 copy of parameters and grads fp32_params = [] for p in params: - p32 = torch.nn.Parameter(p.data.float()) + if p.data.dtype != torch.float32: + p32 = torch.nn.Parameter(p.data.float()) + else: + # make sure the storage is not shared with original parameter + p32 = torch.nn.Parameter(p.data.clone()) p32.grad = torch.zeros_like(p32.data) fp32_params.append(p32) return fp32_params @@ -74,18 +82,22 @@ def step(self, closure=None): def zero_grad(self, set_to_none: bool = True): """ Clears the gradients of all optimized parameters. - Will ignore `set_to_none` and always set fp16 grads to None, and fp32 grads to zero. + Will ignore `set_to_none` and always set fp16 grads and fp32 grads to None. """ for p in self.f16_params: p.grad = None for p32 in self.fp32_params: - if p32.grad is not None: - p32.grad.zero_() + p32.grad = None def state_dict(self): """Return the optimizer's state dict.""" state_dict = super().state_dict() + # called from hybrid optimizer before call `.step` (to get the param_groups of the wrapped optimizer) + # In this case, state_dict['state'] is empty. + if not state_dict['state']: + return state_dict + # move fp32_params to the same level with 'exp_avg' and 'exp_avg_sq' # we do this to handle the merge of sharded checkpoint in nnscaler assert 'state' in state_dict, f'state not found in state_dict: {state_dict.keys()}' @@ -165,34 +177,36 @@ def after_load_checkpoint(self, trainer, checkpoint) -> None: self._sync_fp16_params_to_fp32() self._fp32_params_loaded = True - def overrided_scale_grads(self, scale: float): - """ - Scale the gradients by a factor. - Will override the original scale_grads method in ParallelOptimizer. - """ - self._multiply_factor *= scale + def _unfold_params(self, params) -> tuple[list[torch.nn.Parameter], dict]: + params = list(params) + if not params: + raise ValueError("optimizer got an empty parameter list") - def overrided_clip_gnorm(self, max_norm: Optional[float] = None) -> float: - """ - Will override the original clip_gnorm method in ParallelOptimizer. - """ - # self._clip_gnorm() is ParallelOptimizer.clip_gnorm - grad_norm = self._multiply_factor * self._clip_gnorm() - if max_norm is not None and max_norm > 0.0: - clip_coef = (max_norm / (grad_norm + 1e-6)).clamp(max=1.0) - self._multiply_factor *= clip_coef - return grad_norm + if isinstance(params[0], dict): + if len(params) > 1: + raise ValueError("MixedPrecisionF16OptimizerMixin only supports one param group") + unfolded_params = list(params[0]['params']) + unfolded_kwargs = {k: v for k, v in params[0].items() if k != 'params'} + else: + if not all(isinstance(p, torch.nn.Parameter) for p in params): + raise ValueError("optimizer params should be either a list of Parameters or a dict with 'params' key") + unfolded_params = params + unfolded_kwargs = {} + + return unfolded_params, unfolded_kwargs class MixedPrecisionAdam(MixedPrecisionF16OptimizerMixin, torch.optim.Adam): def __init__(self, params, **kwargs): - self.f16_params = list(params) + self.f16_params, unfolded_kwargs = self._unfold_params(params) self.fp32_params = self.build_fp32_params(self.f16_params) + kwargs = {**unfolded_kwargs, **kwargs} super().__init__(self.fp32_params, **kwargs) class MixedPrecisionAdamW(MixedPrecisionF16OptimizerMixin, torch.optim.AdamW): def __init__(self, params, **kwargs): - self.f16_params = list(params) + self.f16_params, unfolded_kwargs = self._unfold_params(params) self.fp32_params = self.build_fp32_params(self.f16_params) + kwargs = {**unfolded_kwargs, **kwargs} super().__init__(self.fp32_params, **kwargs) diff --git a/nnscaler/runtime/hybrid_optimizer.py b/nnscaler/runtime/hybrid_optimizer.py index e120c3df..be8a5c85 100644 --- a/nnscaler/runtime/hybrid_optimizer.py +++ b/nnscaler/runtime/hybrid_optimizer.py @@ -3,7 +3,8 @@ from collections import defaultdict from dataclasses import dataclass, field -from typing import Any, Callable, Iterable, Type, Union +import types +from typing import Any, Callable, Iterable, Type, Union, TYPE_CHECKING, Optional import torch from torch.optim import Optimizer @@ -14,6 +15,9 @@ from nnscaler.cli.train_hook import TrainHookHost, TrainHook from nnscaler.utils import fn_field, OptStateDict +if TYPE_CHECKING: + from nnscaler.cli.trainer import Trainer + @dataclass class HybridSubOptParamGroupConfig: @@ -55,7 +59,77 @@ def __exit__(self, type: Any, value: Any, tb: Any) -> None: self.remove() -class HybridOptimizer(torch.optim.Optimizer, TrainHookHost): +class ScaleDelayedOptimizerMixin(TrainHook): + """ + A mixin class to add scale-delayed optimization support to an optimizer. + This mixin overrides the `scale_grads`, `clip_gnorm`, and `step` methods + of the optimizer to delay the scaling of gradients until the `step` method is called. + """ + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in mro(method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + def after_setup(self, trainer: 'Trainer') -> None: + if trainer.optimizer is self: + # do nothing if we are in the hybrid optimizer, + # who is responsible for overriding these methods. + trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm + trainer.optimizer.clip_gnorm = self.overrided_clip_gnorm + trainer.optimizer._scale_grads = trainer.optimizer.scale_grads + trainer.optimizer.scale_grads = self.overrided_scale_grads + + # we need to override the step method to apply the scaling factor + # hybrid optimizer will also call `step` of child optimizers, + self._step = self.step + self.step = self.override_step + + def overrided_scale_grads(self, scale: float): + """ + Scale the gradients by a factor. + Will override the original scale_grads method in ParallelOptimizer. + """ + self._multiply_factor *= scale + + def overrided_clip_gnorm(self, max_norm: Optional[float] = None) -> float: + """ + Will override the original clip_gnorm method in ParallelOptimizer. + """ + # self._clip_gnorm() is ParallelOptimizer.clip_gnorm + grad_norm = self._multiply_factor * self._clip_gnorm() + if max_norm is not None and max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp(max=1.0) + self._multiply_factor *= clip_coef + return grad_norm + + def override_step(self, closure=None): + """ + Performs a single optimization step. + """ + # apply the accumulated multiply factor to grads + if self._multiply_factor != 1.0: + for pg_idx in range(len(self.param_groups)): + for p in self.param_groups[pg_idx]['params']: + if p.grad is not None: + p.grad.mul_(self._multiply_factor) + self._multiply_factor = 1.0 + # can't use super() here because we need to support applying this mixin to existing optimizers + self._step(closure) + + @classmethod + def apply_mixin(cls, obj: Any) -> Any: + """Apply this mixin to an existing object.""" + obj._multiply_factor = 1.0 + # bind the new methods + obj.after_setup = types.MethodType(cls.after_setup, obj) + obj.overrided_scale_grads = types.MethodType(cls.overrided_scale_grads, obj) + obj.overrided_clip_gnorm = types.MethodType(cls.overrided_clip_gnorm, obj) + obj.override_step = types.MethodType(cls.override_step, obj) + + return obj + + +class HybridOptimizer(torch.optim.Optimizer, TrainHookHost, TrainHook): """ A hybrid optimizer that combines multiple optimizers/multiple param groups into a single optimizer. @@ -148,6 +222,49 @@ def __init__( for optimizer in self.optimizers: self.param_groups.extend(optimizer.param_groups) + # to support scale-delayed optimizers like mixed-precision f16 optimizer + self._has_scale_delayed = any(isinstance(opt, ScaleDelayedOptimizerMixin) for opt in self.optimizers) + + def after_setup(self, trainer: 'Trainer') -> None: + if not self._has_scale_delayed: + return + + assert trainer.optimizer is self, "HybridOptimizer should not be nested inside another optimizer" + trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm + trainer.optimizer._scale_grads = trainer.optimizer.scale_grads + + # if any one of the optimizers is scale-delayed, + # we must apply the mixin to make sure all optimizers are scale-delayed + # this is the only way to calculate gnorm correctly. + for opt in self.optimizers: + if not isinstance(opt, ScaleDelayedOptimizerMixin): + ScaleDelayedOptimizerMixin.apply_mixin(opt) + # after_setup of non-scale-delayed optimizers can't be called automatically by Trainer + # we need to call it here manually + # For consistency, let's call all optimizers' after_setup manually here (including scale-delayed ones) + opt.after_setup(trainer) + # disable after_setup for sub optimizers + # as we have already handled it here + opt.after_setup = lambda *args, **kwargs: None + + def overrided_scale_grads(self, scale: float) -> None: + for optimizer in self.optimizers: + optimizer.overrided_scale_grads(scale) + + self.scale_grads = types.MethodType(overrided_scale_grads, self) + + def override_clip_gnorm(self, max_norm: Optional[float] = None) -> float: + # self._clip_gnorm() is ParallelOptimizer.clip_gnorm + # all optimizers have the same `multiply_factor` + grad_norm = self.optimizers[0]._multiply_factor * self._clip_gnorm() + if max_norm is not None and max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp(max=1.0) + # will update all optimizers' multiply_factor + self.scale_grads(clip_coef) + return grad_norm + + self.clip_gnorm = types.MethodType(override_clip_gnorm, self) + def _get_hook_objects(self): return self.optimizers diff --git a/tests/runtime/test_hybrid_optimizer.py b/tests/runtime/test_hybrid_optimizer.py index b4238246..72494837 100644 --- a/tests/runtime/test_hybrid_optimizer.py +++ b/tests/runtime/test_hybrid_optimizer.py @@ -9,6 +9,7 @@ import torch.distributed from nnscaler.cli.trainer import Trainer +from nnscaler.runtime.hybrid_optimizer import ScaleDelayedOptimizerMixin from tests.parallel_module.common import assert_close, assert_equal from ..launch_torchrun import launch_torchrun @@ -151,3 +152,132 @@ def trainer_worker(save_dir, use_zero): @pytest.mark.parametrize('use_zero', [0, 1]) def test_hybrid_optimizer(tmp_path, use_zero): launch_torchrun(2, trainer_worker, tmp_path, use_zero) + + +def param_clss_fn_mp(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'layers.1.' in param_name or 'layers.10.' in param_name: + return 0, 0 + else: + return 1, 0 + + +def trainer_worker_mp(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('test_hybrid_optimizer_trainer_args_mixed_precision.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + # train with a hybrid optimizer + ckpt0_savedir = save_dir / 'ckpt0' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # resume training with hybrid optimizer + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--checkpoint.resume_from', 'last', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # train with normal optimizer + ckpt1_savedir = save_dir / 'ckpt1' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--optimizer.args.config!', + '--optimizer.type', 'nnscaler.MixedPrecisionAdamW', + '--optimizer.args.lr', '0.02', + ]) + trainer.run() + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer']['state'], y['optimizer']['state']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_hybrid_optimizer_mp(tmp_path): + launch_torchrun(2, trainer_worker_mp, tmp_path) + + + +class ScaleDelayedAdamW(ScaleDelayedOptimizerMixin, torch.optim.AdamW): + pass + + +def trainer_worker_mp2(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('test_hybrid_optimizer_trainer_args_mixed_precision2.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + # train with a hybrid optimizer + ckpt0_savedir = save_dir / 'ckpt0' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # resume training with hybrid optimizer + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--checkpoint.resume_from', 'last', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # train with normal optimizer + ckpt1_savedir = save_dir / 'ckpt1' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--optimizer.args.config.optimizers.1.type', 'tests.runtime.test_hybrid_optimizer.ScaleDelayedAdamW', + ]) + trainer.run() + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer']['state'], y['optimizer']['state']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_hybrid_optimizer_mp2(tmp_path): + """ + Demonstrate that ScaleDelayedOptimizerMixin that is applied to existing optimizers + are equivalent to defining new optimizers that inherit from the mixin. + """ + launch_torchrun(2, trainer_worker_mp2, tmp_path) diff --git a/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision.yaml b/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision.yaml new file mode 100644 index 00000000..971ada07 --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision.yaml @@ -0,0 +1,54 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_train_steps: 10 +enable_progress_bar: false +precision: bf16 +instance_name: p$(compute_config.plan_ngpus) + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: nnscaler.HybridOptimizer + param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn_mp + args: + config: + optimizers: + - type: nnscaler.MixedPrecisionAdamW + options: + lr: 0.02 + - type: nnscaler.MixedPrecisionAdamW + options: + lr: 0.02 + clip_gnorm: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision2.yaml b/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision2.yaml new file mode 100644 index 00000000..8601c18e --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision2.yaml @@ -0,0 +1,54 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_train_steps: 10 +enable_progress_bar: false +precision: bf16 +instance_name: p$(compute_config.plan_ngpus) + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: nnscaler.HybridOptimizer + param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn_mp + args: + config: + optimizers: + - type: nnscaler.MixedPrecisionAdamW + options: + lr: 0.02 + - type: torch.optim.AdamW + options: + lr: 0.02 + clip_gnorm: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped From 166635aeecad50e2e9125fab78fa283b525e2517 Mon Sep 17 00:00:00 2001 From: Xu Weijiang Date: Wed, 4 Feb 2026 13:59:50 +0800 Subject: [PATCH 1891/1892] [Feat] Add Muon Support (dp without zero) --- nnscaler/graph/tracer/concrete_tracer.py | 12 +- nnscaler/parallel.py | 56 ++++++--- nnscaler/runtime/utils.py | 112 +++++++++++++++++ tests/cli/test_trainer_muon.py | 114 ++++++++++++++++++ tests/cli/trainer_args_muon.yaml | 51 ++++++++ tests/cli/trainer_args_muon_hybrid.yaml | 60 +++++++++ tests/graph/schedule/test_interleaved_1f1b.py | 4 +- tests/graph/test_segment.py | 6 +- tests/launch_torchrun.py | 5 +- tests/parallel_module/test_async.py | 6 +- tests/parallel_module/test_attr_dedup.py | 4 +- tests/parallel_module/test_checkpoint.py | 6 +- .../parallel_module/test_checkpoint_buffer.py | 6 +- .../parallel_module/test_checkpoint_dedup.py | 6 +- .../parallel_module/test_checkpoint_shared.py | 4 +- .../parallel_module/test_checkpoint_unused.py | 4 +- tests/parallel_module/test_ddp.py | 4 +- tests/parallel_module/test_e2e_detach_loss.py | 6 +- tests/parallel_module/test_end2end.py | 6 +- .../test_end2end_mix_precision.py | 4 +- .../test_gencode_ctx_manager.py | 4 +- tests/parallel_module/test_inference.py | 4 +- tests/parallel_module/test_line_timer.py | 4 +- tests/parallel_module/test_offload_params.py | 40 +++--- tests/parallel_module/test_reducer_hook.py | 4 +- tests/parallel_module/test_scale_grads.py | 4 +- .../test_shared_param_pipeline.py | 4 +- tests/parallel_module/test_submodule.py | 4 +- tests/parallel_module/test_wholemodule.py | 4 +- tests/runtime/test_utils.py | 62 ++++++++++ tests/utils.py | 4 + 31 files changed, 525 insertions(+), 89 deletions(-) create mode 100644 tests/cli/test_trainer_muon.py create mode 100644 tests/cli/trainer_args_muon.yaml create mode 100644 tests/cli/trainer_args_muon_hybrid.yaml create mode 100644 tests/runtime/test_utils.py diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index 72c2acf7..d2471d36 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -542,7 +542,10 @@ def get_wrapped_leaves(self, leaf_functions: Dict[Callable, wrap_utils.LeafWrapI method_name=func.__name__, ) elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn' \ - and not func.__qualname__.startswith('PyCapsule'): + and not func.__qualname__.startswith('PyCapsule') \ + and not func.__qualname__.startswith('pybind11_detail_function_'): + # this branch is for method/functions originally not defined in module level. + # in torch >= 2.9, we found pybind11_builtins are included in torch namespace. # method # in torch >= 2.2, we found two functions under torch._C has no __module__: # @@ -552,8 +555,11 @@ def get_wrapped_leaves(self, leaf_functions: Dict[Callable, wrap_utils.LeafWrapI path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) else: path = sys.modules[func.__module__] - path = getattr(path, func.__qualname__.split('.')[0]) - locations = (*locations, wrap_utils.Location(path, func.__name__)) + try: + path = getattr(path, func.__qualname__.split('.')[0]) + locations = (*locations, wrap_utils.Location(path, func.__name__)) + except AttributeError: + _logger.warning(f'Can not get the class path of method {func} {func.__qualname__}!') if len(locations) == 0: _logger.warning(f'Can not find location of {func}, skip wrap it.') continue diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 0b5dbde6..e8c2cb67 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1686,7 +1686,9 @@ def _get_parallel_module_state_dict_info( def _is_supported_optimizer(name: str): from nnscaler.runtime.hybrid_optimizer import HybridOptimizer - return ('adam' in name.lower()) or name == HybridOptimizer.__name__ + return ('adam' in name.lower()) \ + or ('muon' in name.lower()) \ + or name == HybridOptimizer.__name__ def _get_optimizer_state_dict_info( @@ -1746,7 +1748,7 @@ def _get_optimizer_state_dict_info( for opt_state_dict in optimizer_state_dicts: opt_extra_state = OptimizerExtraState(**opt_state_dict[ParallelModule.EXTRA_STATE_KEY]) if not _is_supported_optimizer(opt_extra_state.name): - raise ValueError("Only Adam-like optimizers are supported.") + raise ValueError("Only Adam-like or Muon-like optimizers are supported.") opt_extra_states[opt_extra_state.rank] = opt_extra_state for module_prefix, loc in opt_extra_state.parallel_module_locs.items(): @@ -2012,7 +2014,7 @@ def _trim_optimizer_merged_state_dict( Dict[str, Any]: the trimmed optimizer state dict """ if not _is_supported_optimizer(opt_extra_state.name): - raise ValueError("Only Adam-like optimizers are supported.") + raise ValueError("Only Adam-like or Muon-like optimizers are supported.") device = device or torch.cuda.current_device() @@ -2665,7 +2667,7 @@ def load_deduped_state_dict( if optimizer is not None and optimizer_state_dict is not None: if not _is_supported_optimizer(optimizer._extra_state.name): - raise ValueError("Only Adam-like optimizers are supported.") + raise ValueError("Only Adam-like or Muon-like optimizers are supported.") # get the locations of non-parallel module parameters # by removing the parallel module locations @@ -2731,20 +2733,28 @@ def _broadcast_opt_state(optimizer_state_dict: OptStateDict, state_indexes: List # step is too small, so we can just broadcast all of them all together # some adam/adamw optimizers may not have step in their state dict # so we need to check if 'step' is in the state dict - if 'step' in optimizer_state_dict['state'][state_indexes[0]]: + step_state_indexes = [k for k in state_indexes if 'step' in optimizer_state_dict['state'][k]] + if step_state_indexes: + assert all( + optimizer_state_dict['state'][k]['step'].dtype == + optimizer_state_dict['state'][step_state_indexes[0]]['step'].dtype and + optimizer_state_dict['state'][k]['step'].shape == + optimizer_state_dict['state'][step_state_indexes[0]]['step'].shape + for k in step_state_indexes + ) if rank == src_rank: step_stack = torch.stack( - [optimizer_state_dict['state'][k]['step'] for k in state_indexes] + [optimizer_state_dict['state'][k]['step'] for k in step_state_indexes] ) else: step_stack = torch.zeros( - len(state_indexes), - dtype=optimizer_state_dict['state'][state_indexes[0]]['step'].dtype, + len(step_state_indexes), + dtype=optimizer_state_dict['state'][step_state_indexes[0]]['step'].dtype, device=torch.cuda.current_device() ) torch.distributed.broadcast(step_stack, src=src_rank, group=curr_parallel_group) if rank != src_rank: - for k, v in zip(state_indexes, step_stack): + for k, v in zip(step_state_indexes, step_stack): optimizer_state_dict['state'][k]['step'].copy_(v) # broadcast other states @@ -3011,9 +3021,17 @@ def _send_trimmed_opt_state_dict( torch.distributed.send_object_list(sent, group=group, dst=dst_rank) # broadcast step in stack - if 'step' in trimmed_opt_state_dict['state'][state_keys[0]]: + step_state_keys = [k for k in state_keys if 'step' in trimmed_opt_state_dict['state'][k]] + if step_state_keys: + assert all( + trimmed_opt_state_dict['state'][k]['step'].dtype == + trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].dtype and + trimmed_opt_state_dict['state'][k]['step'].shape == + trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].shape + for k in step_state_keys + ) step_stack = torch.stack( - [trimmed_opt_state_dict['state'][k]['step'] for k in state_keys] + [trimmed_opt_state_dict['state'][k]['step'] for k in step_state_keys] ) torch.distributed.send(step_stack.cuda(), group=group, dst=dst_rank) @@ -3060,14 +3078,22 @@ def _receive_trimmed_opt_state_dict( } # receive steps - if 'step' in trimmed_opt_state_dict['state'][state_keys[0]]: + step_state_keys = [k for k in state_keys if 'step' in trimmed_opt_state_dict['state'][k]] + if step_state_keys: + assert all( + trimmed_opt_state_dict['state'][k]['step'].dtype == + trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].dtype and + trimmed_opt_state_dict['state'][k]['step'].shape == + trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].shape + for k in step_state_keys + ) step_stack = torch.zeros( - len(state_keys), - dtype=trimmed_opt_state_dict['state'][state_keys[0]]['step'].dtype, + len(step_state_keys), + dtype=trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].dtype, device='cuda' ) torch.distributed.recv(step_stack, group=group, src=src_rank) - for k, v in zip(state_keys, step_stack): + for k, v in zip(step_state_keys, step_stack): trimmed_opt_state_dict['state'][k]['step'].copy_(v) # receive other states diff --git a/nnscaler/runtime/utils.py b/nnscaler/runtime/utils.py index 43ec0656..e6cb1bd9 100644 --- a/nnscaler/runtime/utils.py +++ b/nnscaler/runtime/utils.py @@ -5,6 +5,7 @@ from typing import Any, List import logging +import heapq _logger = logging.getLogger(__name__) @@ -104,3 +105,114 @@ def microbatches(samples: List[Any], cycle: bool = False) -> MicroBatchDataLoade MicroBatchDataLoader: a micro-batch data loader. """ return MicroBatchDataLoader(samples, cycle=cycle) + + +def split_array_min_max(nums: list[int], g: int, *, keep_order: bool = True) -> tuple[list[list[int]], list[list[int]]]: + """ + Split the array nums into g continuous subarrays such that the maximum sum + of the subarrays is minimized. + + Args: + nums (list[int]): The input array of integers. + g (int): The number of groups to split the array into. + keep_order (bool): Whether to keep the order of elements in the subarrays. + If True, the order of elements in the original array is preserved + in the subarrays. If False, the order can be changed. + Returns: + tuple[list[list[int]], list[list[int]]]: + A tuple containing a list of g subarrays and their corresponding indices. + """ + if g <= 0 or g > len(nums): + raise ValueError("g must be in the range [1, len(nums)]") + + if not keep_order: + return _split_array_min_max_out_of_order(nums, g) + + def _check(limit): + count = 1 + count_sum = nums[0] + for x in nums[1:]: + if count_sum + x > limit: + count += 1 + count_sum = x + else: + count_sum += x + return count <= g + + # 1. Binary search to find the "minimum maximum sum" (Target Limit) + left = max(nums) + right = sum(nums) + target_limit = right + + while left <= right: + mid = (left + right) // 2 + if _check(mid): + target_limit = mid + right = mid - 1 + else: + left = mid + 1 + + # 2. Reconstruct the result based on the calculated target_limit + # Note: A special greedy strategy is needed here to ensure exactly g groups + # A simple greedy approach may result in fewer than g groups (although the maximum sum meets the condition, the number of groups is insufficient) + + result = [[nums[0]]] + result_idx = [[0]] + current_sum = nums[0] + + # We process in forward order, or forcefully reserve enough elements for the remaining groups during forward processing + # Here we use forward iteration with a "remaining quota" check + for i, x in enumerate(nums[1:], start=1): + # Remaining groups needed + groups_needed = g - len(result) + # Remaining elements not yet processed + elements_left = len(nums) - i + if elements_left == groups_needed: + # Each element must form a separate group + result.append([x]) + result_idx.append([i]) + current_sum = x + continue + + if current_sum + x > target_limit: + result.append([x]) + result_idx.append([i]) + current_sum = x + else: + result[-1].append(x) + result_idx[-1].append(i) + current_sum += x + + return result, result_idx + + +def _split_array_min_max_out_of_order(nums: list[int], g: int) -> tuple[list[list[int]], list[list[int]]]: + """ + Split the array nums into g subarrays (order of elements can be changed) + This problem (multi-way number partitioning) is NP-hard. We use a greedy approximation algorithm here. + """ + # 1. Sort numbers in descending order + nums_with_indices = list((nun, i) for i, nun in enumerate(nums)) + sorted_nums = sorted(nums_with_indices, reverse=True) + + # 2. Initialize heap + heap = [(0, i) for i in range(g)] + + # groups to save results + groups = [[] for _ in range(g)] + group_idx = [[] for _ in range(g)] + + # 3. greedy assignment + for num, idx in sorted_nums: + # Pop the bucket with the smallest current sum + current_sum, gidx = heapq.heappop(heap) + + # Add the number to this bucket + groups[gidx].append(num) + group_idx[gidx].append(idx) + + # Update the sum of this bucket and push it back to the heap + new_sum = current_sum + num + heapq.heappush(heap, (new_sum, gidx)) + + return groups, group_idx diff --git a/tests/cli/test_trainer_muon.py b/tests/cli/test_trainer_muon.py new file mode 100644 index 00000000..dbb4af3b --- /dev/null +++ b/tests/cli/test_trainer_muon.py @@ -0,0 +1,114 @@ +from pathlib import Path +import shutil + +import torch + +import pytest + +from nnscaler.cli.trainer import Trainer +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.common import assert_equal + + +try: + from torch.optim import Muon +except ImportError: + pytest.skip("Muon not available", allow_module_level=True) + + + +def trainer_muon_worker(save_dir, config_file): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name(config_file).resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + # train first epoch + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '1', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # create merged checkpoint + if trainer.rank == 0: + Trainer.merge_checkpoint(list((ckpt_savedir / 'last').glob('*.ckpt')), ckpt_savedir / 'merged.pt') + + torch.distributed.barrier() + + # train 2nd epoch, resume from merged checkpoint + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', str(ckpt_savedir / 'merged.pt'), + ]) + trainer.run() + + torch.distributed.barrier() + + # train 3rd epoch, resume from deduped checkpoint + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '3', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + torch.distributed.barrier() + + # train 4th epoch, resume from sharded checkpoint + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '4', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + ]) + trainer.run() + + torch.distributed.barrier() + + ckpt1_savedir = save_dir / 'ckpt1' + # train 4 epoch without resuming + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '4', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + ]) + trainer.run() + + torch.distributed.barrier() + + if trainer.rank == 0: + for i in range(2): + x = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + for key in ['model', 'optimizer']: + assert_equal(x[key], y[key]) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('config_file', ['trainer_args_muon.yaml', 'trainer_args_muon_hybrid.yaml']) +def test_trainer_muon_resume_correctness(tmp_path, config_file): + launch_torchrun(2, trainer_muon_worker, tmp_path, config_file) + + +def param_clss_fn(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'layers.1.' in param_name or 'layers.10.' in param_name: + return 0, 0 + else: + return 1, 0 diff --git a/tests/cli/trainer_args_muon.yaml b/tests/cli/trainer_args_muon.yaml new file mode 100644 index 00000000..ae1ef334 --- /dev/null +++ b/tests/cli/trainer_args_muon.yaml @@ -0,0 +1,51 @@ +vars: + dim: 16 + drop_last: true + +compute_config: + plan_ngpus: 1 + runtime_ngpus: 2 + constant_folding: true + use_zero: 0 + use_end2end: true + +run_mode: run +pas_policy: dp +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 +max_train_steps: 100 +seed: 0 +precision: fp32 +enable_progress_bar: false + +model: + type: tests.cli.common.MLP + args: + dim: $(vars.dim) + nlayers: 16 + +optimizer: + type: torch.optim.Muon + args: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: $(vars.dim) + size: 100 + val_args: + dim: $(vars.dim) + size: 10 + +dataloader: + train_args: + drop_last: $(vars.drop_last) + val_args: + drop_last: $(vars.drop_last) + +checkpoint: + keep_last_n_checkpoints: 5 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/cli/trainer_args_muon_hybrid.yaml b/tests/cli/trainer_args_muon_hybrid.yaml new file mode 100644 index 00000000..8e2f8987 --- /dev/null +++ b/tests/cli/trainer_args_muon_hybrid.yaml @@ -0,0 +1,60 @@ +vars: + dim: 16 + drop_last: true + +compute_config: + plan_ngpus: 1 + runtime_ngpus: 2 + constant_folding: true + use_zero: 0 + use_end2end: true + +run_mode: run +pas_policy: dp +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 +max_train_steps: 100 +seed: 0 +precision: bf16 +enable_progress_bar: false + +model: + type: tests.cli.common.MLP + args: + dim: $(vars.dim) + nlayers: 16 + + +optimizer: + type: nnscaler.HybridOptimizer + param_clss_fn: tests.cli.test_trainer_muon.param_clss_fn + args: + config: + optimizers: + - type: nnscaler.runtime.f16_optimizer.MixedPrecisionAdamW + options: + lr: 0.01 + - type: torch.optim.Muon + options: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: $(vars.dim) + size: 100 + val_args: + dim: $(vars.dim) + size: 10 + +dataloader: + train_args: + drop_last: $(vars.drop_last) + val_args: + drop_last: $(vars.drop_last) + +checkpoint: + keep_last_n_checkpoints: 5 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/graph/schedule/test_interleaved_1f1b.py b/tests/graph/schedule/test_interleaved_1f1b.py index 593ff23f..ed779ba3 100644 --- a/tests/graph/schedule/test_interleaved_1f1b.py +++ b/tests/graph/schedule/test_interleaved_1f1b.py @@ -21,7 +21,7 @@ from nnscaler.ir.operator import IRFwOperation, IRDataOperation from nnscaler.graph.segment import IRSegment from nnscaler.graph.schedule.predefined import PredefinedSched -from tests.utils import clear_dir_on_rank0, init_random +from tests.utils import clear_dir_on_rank0, init_random, PYTEST_RUN_ID from tests.launch_torchrun import torchrun from tests.parallel_module.common import assert_equal from tests.parallel_module.test_gencode import _gencode_contains @@ -131,7 +131,7 @@ def worker_pipeline_2(n_micro_batches): trace_data = torch.randn([2, 32], dtype=torch.float32, device=torch.cuda.current_device()) cfg = ComputeConfig(2, 2, use_end2end=True, pas_config=dict(n_micro_batches=n_micro_batches)) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_1f1b_interleaved') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_1f1b_interleaved_{PYTEST_RUN_ID}') as tempdir: pm_1f1b = parallelize( m, {'x': trace_data}, diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py index 57c10809..32e36bf6 100644 --- a/tests/graph/test_segment.py +++ b/tests/graph/test_segment.py @@ -18,7 +18,7 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.ir.operator import IRFwOperation, IRDataOperation from tests.parallel_module.test_gencode import _gencode_contains, print_gencode -from ..utils import replace_all_device_with, clear_dir_on_rank0, init_random +from ..utils import replace_all_device_with, clear_dir_on_rank0, init_random, PYTEST_RUN_ID from ..launch_torchrun import torchrun @@ -119,7 +119,7 @@ def worker_a(): m.train() trace_data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_infer_grad_pyfunc') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_infer_grad_pyfunc_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'q': trace_data,}, @@ -211,7 +211,7 @@ def worker_b(use_end2end): init_random() trace_data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_infer_grad_no_grad') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_infer_grad_no_grad_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'q': trace_data,}, diff --git a/tests/launch_torchrun.py b/tests/launch_torchrun.py index fce62ed2..27933ecc 100644 --- a/tests/launch_torchrun.py +++ b/tests/launch_torchrun.py @@ -4,11 +4,12 @@ from typing import Callable import uuid import torch +import os from torch.distributed.run import elastic_launch, LaunchConfig from torch.distributed.elastic.multiprocessing.errors import ChildFailedError -from .utils import retry +from .utils import retry, MASTER_PORT @retry(ChildFailedError, delay=10, match='The server socket has failed to listen on any local network address.') @@ -18,7 +19,7 @@ def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): max_nodes=1, nproc_per_node=nproc_per_node, rdzv_backend = "c10d", - rdzv_endpoint = "localhost:29401", + rdzv_endpoint = f"localhost:{MASTER_PORT}", run_id = str(uuid.uuid4()), monitor_interval=0.1, max_restarts=0, diff --git a/tests/parallel_module/test_async.py b/tests/parallel_module/test_async.py index 97b770b5..af9d75f0 100644 --- a/tests/parallel_module/test_async.py +++ b/tests/parallel_module/test_async.py @@ -13,7 +13,7 @@ from tests.launch_torchrun import launch_torchrun from tests.launch_torchrun import clone_to_cpu_recursively from tests.parallel_module.common import assert_equal, init_distributed -from tests.utils import clear_dir_on_rank0, init_random +from tests.utils import clear_dir_on_rank0, init_random, PYTEST_RUN_ID from .test_wholemodule import FcRelu_4_4 @@ -88,7 +88,7 @@ def _train(model: ParallelModule, update_freq): def _gpu_worker(pas, ngpus, update_freq): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_async') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_async_{PYTEST_RUN_ID}') as tempdir: whole_module_async, sub_module_async = _create_modules( pas, ComputeConfig( 1, ngpus, use_async_reducer=True, @@ -203,7 +203,7 @@ def _train_pp(model: ParallelModule, num_replicas, rank): def _gpu_worker_pp(pas, pp_ngpus, runtime_ngpus, update_freq): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_pp_async') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_pp_async_{PYTEST_RUN_ID}') as tempdir: init_random() whole_module_async = parallelize( OrigModuleEnd2End(), { diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py index b1a5a95b..99d18443 100644 --- a/tests/parallel_module/test_attr_dedup.py +++ b/tests/parallel_module/test_attr_dedup.py @@ -20,7 +20,7 @@ from .common import init_distributed, assert_equal from ..launch_torchrun import launch_torchrun -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class Net(torch.nn.Module): @@ -55,7 +55,7 @@ def pas(graph: IRGraph, config: ComputeConfig): def _gpu_worker_spmd(cc: ComputeConfig): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_dedup_attr') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'nnscaler_test_dedup_attr_{PYTEST_RUN_ID}') as tempdir: module = parallelize( Net(), {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index fc8388fb..565ad8b1 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -30,7 +30,7 @@ from .common import CubeLinear, init_random, init_distributed, PASMegatron, assert_equal from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import replace_all_device_with, clear_dir_on_rank0 +from ..utils import replace_all_device_with, clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -475,7 +475,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_count, check_module=None): init_distributed() compiled_results = [] - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_{PYTEST_RUN_ID}') as tempdir: for i in range(resume_count): start = i * per_resume_update_count end = (i + 1) * per_resume_update_count @@ -591,7 +591,7 @@ def test_checkpoint_intra_reducer(module_type, use_zero): def _gpu_merge_worker(): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_merge') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_merge_{PYTEST_RUN_ID}') as tempdir: compiled_module = _create_cube_module('data', ComputeConfig(2, 4, use_zero=True), tempdir, diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index cde6f864..c81c11a2 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -12,7 +12,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun -from ..utils import catch_log, clear_dir_on_rank0 +from ..utils import catch_log, clear_dir_on_rank0, PYTEST_RUN_ID class Net1(torch.nn.Module): @@ -63,7 +63,7 @@ def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_sh def _gpu_worker(): init_distributed() compute_config = ComputeConfig(1, 1, use_zero=False) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_{PYTEST_RUN_ID}') as tempdir: net1 = _to_cube_model(Net1(), compute_config, tempdir, 'net1', (128, 64)) cube_state_dict = net1.state_dict() assert not any(key.startswith('buffer') for key in cube_state_dict) @@ -129,7 +129,7 @@ def _gpu_worker_broadcast(): init_distributed() compute_config = ComputeConfig(1, 2, use_zero=False) rank = torch.distributed.get_rank() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_broadcast_fail') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_broadcast_fail_{PYTEST_RUN_ID}') as tempdir: net1 = _to_cube_model(Net1(), compute_config, tempdir, 'net1', (128, 64), init_module_params=False) with pytest.raises(RuntimeError, match="Non-persistent buffers haven't been initialized."): broadcast_weights(net1) diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index a5fec814..6ffdafdd 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -17,7 +17,7 @@ from .common import PASMegatron, CubeLinear, init_random, init_distributed, assert_equal from ..launch_torchrun import launch_torchrun from .test_checkpoint import gendata, train_step, End2EndMLP, End2EndMLPWithUnusedAndShared -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -179,7 +179,7 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): def _gpu_worker(pas, cc1, cc2): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_compact') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_compact_{PYTEST_RUN_ID}') as tempdir: _train(_create_cube_module(pas, cc1, cc2, tempdir), tempdir) torch.distributed.barrier() _check_deduped( @@ -202,7 +202,7 @@ def test_checkpoint_compact(use_zero): def _gpu_worker_pipeline(cc): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_compact_pipeline') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_compact_pipeline_{PYTEST_RUN_ID}') as tempdir: for model_cls in [End2EndMLP, End2EndMLPWithUnusedAndShared]: pipeline_moule_cls = model_cls.to_pipeline_module(cc, tempdir) _train(pipeline_moule_cls().cuda(), tempdir) diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index 8b5e91b1..efceb594 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -14,7 +14,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun from .test_checkpoint import End2EndMLP, train_step, gendata -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcReluWithShared(nn.Module): @@ -194,7 +194,7 @@ def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus): # d. compare the full state dict in step a and the merged state dict in step c. They should be the same. init_distributed() compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_{PYTEST_RUN_ID}') as tempdir: if torch.distributed.get_rank() == 0: tempdir.mkdir(parents=True, exist_ok=True) _train_raw(_create_cube_module(pas, compute_config, tempdir, f'{module_type}/raw'), tempdir) diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py index ca8ea820..c0aec698 100644 --- a/tests/parallel_module/test_checkpoint_unused.py +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -24,7 +24,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively from .test_checkpoint_shared import _train_raw, _load_merged -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcReluWithUnused(nn.Module): @@ -113,7 +113,7 @@ def _gpu_worker(use_zero, pas, plan_ngpus, runtime_ngpus): # d. compare the full state dict in step a and the merged state dict in step c. They should be the same. init_distributed() compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_{PYTEST_RUN_ID}') as tempdir: if torch.distributed.get_rank() == 0: tempdir.mkdir(parents=True, exist_ok=True) _train_raw(_create_cube_module(pas, compute_config, tempdir, 'raw'), tempdir) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index 400e3b18..9a4a179a 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -23,7 +23,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -206,7 +206,7 @@ def _gpu_worker_ddp(update_freq): def _gpu_worker_cube(pas, plan_ngpus, runtime_ngpus, update_freq, use_zero): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_{PYTEST_RUN_ID}') as tempdir: compiled_module = _create_cube_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero), tempdir diff --git a/tests/parallel_module/test_e2e_detach_loss.py b/tests/parallel_module/test_e2e_detach_loss.py index f2284af9..4b4cc914 100644 --- a/tests/parallel_module/test_e2e_detach_loss.py +++ b/tests/parallel_module/test_e2e_detach_loss.py @@ -19,7 +19,7 @@ from nnscaler.ir.operator import IRFwOperation, IRDataOperation from nnscaler.graph.segment import IRSegment from nnscaler.graph.schedule.predefined import PredefinedSched -from tests.utils import clear_dir_on_rank0, init_random +from tests.utils import clear_dir_on_rank0, init_random, PYTEST_RUN_ID from tests.launch_torchrun import torchrun from tests.parallel_module.test_gencode import _gencode_contains @@ -94,7 +94,7 @@ def worker_pipeline_2x2(model_cls): torch.cuda.manual_seed(0) trace_data = torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2x2') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_detach_loss_pp_2x2_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'x': trace_data}, @@ -159,7 +159,7 @@ def worker_pipeline_2(model_cls): torch.cuda.manual_seed(0) trace_data = torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_detach_loss_pp_2_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'x': trace_data}, diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index 924c62c1..a9b1f8d9 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -23,7 +23,7 @@ from nnscaler.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts from .common import assert_equal, init_distributed, PASMegatron, init_random from ..launch_torchrun import clone_to_cpu_recursively, launch_torchrun -from ..utils import replace_all_device_with, clear_dir_on_rank0 +from ..utils import replace_all_device_with, clear_dir_on_rank0, PYTEST_RUN_ID from .test_checkpoint import End2EndMLP @@ -111,7 +111,7 @@ def gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages=None, nmi init_random() nstages = nstages or plan_ngpus nmicros = nmicros or plan_ngpus - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_end2end_{PYTEST_RUN_ID}') as tempdir: init_random() model = model_cls() model = parallelize( @@ -416,7 +416,7 @@ def _train_cube_one_sample(model: ParallelModule, mbs): def gpu_worker_cube_one_sample(): init_distributed() init_random() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_end2end_{PYTEST_RUN_ID}') as tempdir: init_random() model = MLP() model = parallelize( diff --git a/tests/parallel_module/test_end2end_mix_precision.py b/tests/parallel_module/test_end2end_mix_precision.py index dd399182..f1785acd 100644 --- a/tests/parallel_module/test_end2end_mix_precision.py +++ b/tests/parallel_module/test_end2end_mix_precision.py @@ -26,7 +26,7 @@ from .test_checkpoint import End2EndMLP from .test_end2end import allclose, merge_cube_result -from ..utils import init_parameter, clear_dir_on_rank0 +from ..utils import init_parameter, clear_dir_on_rank0, PYTEST_RUN_ID DATA_SIZE = 16 @@ -136,7 +136,7 @@ def gpu_worker_cube(use_zero=False, async_reducer=False, use_bucket=False): plan_ngpus = 2 runtime_ngpus = 4 nmicros = plan_ngpus - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end_mp') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_end2end_mp_{PYTEST_RUN_ID}') as tempdir: init_random() model = MPModule() model = parallelize( diff --git a/tests/parallel_module/test_gencode_ctx_manager.py b/tests/parallel_module/test_gencode_ctx_manager.py index 42bd7181..fc4f07c1 100644 --- a/tests/parallel_module/test_gencode_ctx_manager.py +++ b/tests/parallel_module/test_gencode_ctx_manager.py @@ -12,7 +12,7 @@ from .common import init_distributed, init_random from .test_end2end import merge_cube_result from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class CtxManagerModel(torch.nn.Module): @@ -257,7 +257,7 @@ def _train_cube_one_sample(model: ParallelModule, mbs): def gpu_worker_cube_one_sample(): init_distributed() init_random() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ctx_manager') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ctx_manager_{PYTEST_RUN_ID}') as tempdir: init_random() model = CtxManagerModel() model = parallelize( diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 6b15d8a1..ce9adbdd 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -13,7 +13,7 @@ from .common import CubeLinear, init_distributed, init_random from ..launch_torchrun import torchrun -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -62,7 +62,7 @@ def _inference_worker(ngpus, inference_only): init_distributed() init_random() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_inference_test') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_inference_test_{PYTEST_RUN_ID}') as tempdir: model = Module() model.eval() diff --git a/tests/parallel_module/test_line_timer.py b/tests/parallel_module/test_line_timer.py index 90483848..28948a38 100644 --- a/tests/parallel_module/test_line_timer.py +++ b/tests/parallel_module/test_line_timer.py @@ -13,7 +13,7 @@ from .common import init_distributed from ..launch_torchrun import launch_torchrun -from ..utils import catch_stdout, clear_dir_on_rank0 +from ..utils import catch_stdout, clear_dir_on_rank0, PYTEST_RUN_ID class Net(torch.nn.Module): @@ -43,7 +43,7 @@ def _gpu_worker(): compute_config = ComputeConfig(1, 1, use_zero=False) try: CompileFlag.line_timer = True - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_line_timer') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_line_timer_{PYTEST_RUN_ID}') as tempdir: net = _to_cube_model(Net(), compute_config, tempdir, 'net', (128, 64)) x = torch.randn(128, 64).cuda() diff --git a/tests/parallel_module/test_offload_params.py b/tests/parallel_module/test_offload_params.py index 81e5d305..2e96a123 100644 --- a/tests/parallel_module/test_offload_params.py +++ b/tests/parallel_module/test_offload_params.py @@ -14,7 +14,7 @@ from .common import PASMegatron, CubeLinear, init_random, init_distributed, assert_equal from ..launch_torchrun import launch_torchrun -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class SimpleMLP(nn.Module): @@ -68,7 +68,7 @@ def _mem_worker(): plan_ngpus=1, runtime_ngpus=2, ) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_offload_mem') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'nnscaler_test_offload_mem_{PYTEST_RUN_ID}') as tempdir: module = SimpleMLP(dim, dim, dim) p_module = parallelize( module, @@ -106,12 +106,12 @@ def _correctness_worker(): plan_ngpus=1, runtime_ngpus=2, ) - - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_offload_correctness') as tempdir: + + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'nnscaler_test_offload_correctness_{PYTEST_RUN_ID}') as tempdir: # Create test data torch.manual_seed(42 + torch.distributed.get_rank()) test_data = [torch.randn(bsz, dim).cuda() for _ in range(num_steps)] - + # Test 1: Normal execution without offload/load init_random() module1 = SimpleMLP(dim, dim, dim) @@ -124,7 +124,7 @@ def _correctness_worker(): instance_name='normal' ) optimizer1 = build_optimizer(p_module1, torch.optim.Adam, lr=0.01) - + results_normal = [] for step, x in enumerate(test_data): p_module1.train() @@ -133,16 +133,16 @@ def _correctness_worker(): loss.backward() optimizer1.step() optimizer1.zero_grad() - + # Save intermediate results for comparison results_normal.append({ 'loss': loss.detach().cpu(), 'output': output.detach().cpu(), 'params': {name: param.detach().cpu().clone() for name, param in p_module1.named_parameters()} }) - + torch.distributed.barrier() - + # Test 2: Execution with offload/load init_random() module2 = SimpleMLP(dim, dim, dim) @@ -155,50 +155,50 @@ def _correctness_worker(): instance_name='offload' ) optimizer2 = build_optimizer(p_module2, torch.optim.Adam, lr=0.01) - + # First offload to initialize the buffer_shape p_module2.sleep() - + results_offload = [] for step, x in enumerate(test_data): # Load params at the beginning of each step p_module2.wake_up() - + p_module2.train() output = p_module2(x) loss = output.sum() loss.backward() optimizer2.step() optimizer2.zero_grad() - + # Save intermediate results for comparison results_offload.append({ 'loss': loss.detach().cpu(), 'output': output.detach().cpu(), 'params': {name: param.detach().cpu().clone() for name, param in p_module2.named_parameters()} }) - + # Offload params at the end of each step p_module2.sleep() - + torch.distributed.barrier() - + # Compare results for step in range(num_steps): normal_result = results_normal[step] offload_result = results_offload[step] - + # Compare loss assert torch.equal(normal_result['loss'], offload_result['loss']), \ f"Loss mismatch at step {step}: {normal_result['loss']} vs {offload_result['loss']}" - + # Compare output assert torch.equal(normal_result['output'], offload_result['output']), \ f"Output mismatch at step {step}" - + # Compare parameters for param_name in normal_result['params']: - assert torch.equal(normal_result['params'][param_name], + assert torch.equal(normal_result['params'][param_name], offload_result['params'][param_name]), \ f"Parameter {param_name} mismatch at step {step}" diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index 851dfad0..eb6807ae 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -14,7 +14,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -118,7 +118,7 @@ def post_hook(reducer, grad): def _gpu_worker(pas, plan_ngpus, runtime_ngpus=None): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_hook') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_hook_{PYTEST_RUN_ID}') as tempdir: compiled_module = _create_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus or plan_ngpus), tempdir) _train(compiled_module) diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py index 255869cf..ce7b9b45 100644 --- a/tests/parallel_module/test_scale_grads.py +++ b/tests/parallel_module/test_scale_grads.py @@ -23,7 +23,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -130,7 +130,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, scale_grads: bool): def _gpu_worker(pas, plan_ngpus, runtime_ngpus, scale_grads: bool): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_scale_grads') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_scale_grads_{PYTEST_RUN_ID}') as tempdir: compiled_module = _create_cube_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=True), tempdir diff --git a/tests/parallel_module/test_shared_param_pipeline.py b/tests/parallel_module/test_shared_param_pipeline.py index 86dc4380..3d6159e7 100644 --- a/tests/parallel_module/test_shared_param_pipeline.py +++ b/tests/parallel_module/test_shared_param_pipeline.py @@ -19,7 +19,7 @@ from nnscaler.ir.operator import IRFwOperation, IRDataOperation from nnscaler.graph.segment import IRSegment from nnscaler.graph.schedule.predefined import PredefinedSched -from tests.utils import clear_dir_on_rank0, init_random, raises_with_cause +from tests.utils import clear_dir_on_rank0, init_random, raises_with_cause, PYTEST_RUN_ID from tests.launch_torchrun import torchrun from tests.parallel_module.test_gencode import _gencode_contains, print_gencode @@ -264,7 +264,7 @@ def worker_pipeline(model_cls, pas, plan_ngpus, checker): torch.cuda.manual_seed(0) trace_data = torch.randn([2, 16], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_detach_loss_pp_2_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'x': trace_data}, diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 047f9f5b..628071ea 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -17,7 +17,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -118,7 +118,7 @@ def _train(model, update_freq, is_cube): def _gpu_worker(pas, ngpus, update_freq): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_{PYTEST_RUN_ID}') as tempdir: orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) orig_results = _train(orig_module, update_freq, False) compiled_results = _train(compiled_module, update_freq, True) diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 64ec5ff0..3d385d19 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -17,7 +17,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -106,7 +106,7 @@ def _train(model, is_cube): def _gpu_worker(pas, ngpus): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_{PYTEST_RUN_ID}') as tempdir: orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) orig_results = _train(orig_module, False) compiled_results = _train(compiled_module, True) diff --git a/tests/runtime/test_utils.py b/tests/runtime/test_utils.py new file mode 100644 index 00000000..4e736649 --- /dev/null +++ b/tests/runtime/test_utils.py @@ -0,0 +1,62 @@ +from nnscaler.runtime.utils import split_array_min_max + + +def test_split_array_min_max(): + nums = [1, 2, 3, 4, 5, 6, 7, 8, 9] + g = 3 + groups, group_idx = split_array_min_max(nums, g, keep_order=True) + assert groups == [[1, 2, 3, 4, 5], [6, 7], [8, 9]] + assert group_idx == [[0, 1, 2, 3, 4], [5, 6], [7, 8]] + + groups, group_idx = split_array_min_max(nums, g, keep_order=False) + assert groups == [[9, 4, 3], [8, 5, 2], [7, 6, 1]] + assert group_idx == [[8, 3, 2], [7, 4, 1], [6, 5, 0]] + + nums = [10, 10, 10, 10, 10, 10] + g = 3 + groups, group_idx = split_array_min_max(nums, g, keep_order=True) + assert groups == [[10, 10], [10, 10], [10, 10]] + assert group_idx == [[0, 1], [2, 3], [4, 5]] + + groups, group_idx = split_array_min_max(nums, g, keep_order=False) + assert groups == [[10, 10], [10, 10], [10, 10]] + assert group_idx == [[5, 2], [4, 1], [3, 0]] + + nums = [ + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 1310720, 1310720 + ] + g = 8 + best_sum = sum(nums) // g + + groups, group_idx = split_array_min_max(nums, g, keep_order=True) + max_sum = max(sum(group) for group in groups) + assert len(groups) == 8 + assert list(j for k in group_idx for j in k) == list(range(len(nums))) + + groups, group_idx = split_array_min_max(nums, g, keep_order=False) + assert len(groups) == 8 + max_sum2 = max(sum(group) for group in groups) + assert list(j for k in group_idx for j in k) != list(range(len(nums))) + + assert best_sum< max_sum2 < max_sum + print(f'best_sum: {best_sum}, keep_order: {max_sum}, not keep_order: {max_sum2}') diff --git a/tests/utils.py b/tests/utils.py index 4e3bd85a..0aae6c4e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,6 +26,10 @@ from nnscaler.runtime.device import DeviceGroup, CompileFlag +MASTER_PORT = os.environ.get("MASTER_PORT", "29401") +PYTEST_RUN_ID = MASTER_PORT + + def init_parameter(model: torch.nn.Module, seed: int = 0): """ Initialize a model's parameters with truncated normal distribution. From 1a9198c8b6a0a7c82fbb218b408fb1ef75009d92 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Fri, 6 Feb 2026 15:36:08 +0800 Subject: [PATCH 1892/1892] Warmup triggers 150K times, put too much pressure on memory and lead to segmentation fault --- nnscaler/profiler/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 61597b2a..23ba4cd0 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -260,7 +260,7 @@ def unpack_hook(x): # warmup warmup_cnt = 0 tic = time.perf_counter() - while time.perf_counter() - tic < warmup_sec: + while time.perf_counter() - tic < warmup_sec and warmup_cnt < prof_times: run_step(func, tensors, train_kwargs, backward=require_backward) torch.cuda.synchronize() warmup_cnt += 1

7Odl_H0-HAh72t)Wa34U*!TF_@#95Vc573Z8v&|? zZ0ZFwMmgbkBznH={p%I(TX~8ouTam4JY!^)=tE_VPI4>%9<`W zF~es_Q?Lea11^#I(fS_kLb8AL>A_YyqQCtegJ!+JOD-#DY&P!x zgh8`$jbDJIyx8QWMr|KNh9S7VUB`q^E8B=$OrvCvObNqZ@iZgonX$iOCN?L7Po5~n zDPQ!mvfL{C=W+G276cD?k?WR^Y0T7r^V2V+pCI15LD8OItfJ5-zMVreRtLkp1HwF= z3jSVw8KZMToVs4wG7s)7sQV|n!+$!6zpC8!vYORJqhaq7ulYV$*zhpk`T{1;<9dkO zuh{^o&)q_c2a!~(c*L`|E@Hgn(-{v6oxlg5H5zZUhV$6J>L6l$9+%g6QlbkU;Rf-# zM5cSBHL||Qqkc6UZD<+dd|mV{!BgnL?mcZe$6C6t>JQ4e(X zh(Kyy-ZQUK_i#hgi}9wf`uY5O4i%E=e()gy?uVJ+{IOv8k(Gtp51{m*ZE@}RRO%QWZ zi9R8pev(cs-8%*jNtw9pN0^n-@xIpGT!Ho3v@bMDLb32EpEf%kiTMyTh9Zp-lJ2fH zPOE>4cz)`x5qyWWR>arb)ZsB-M(mfj96KBZ&PuA*4RoPpCnA%BHIQ0QNAc+FQG(Q* zL+DmuakfZ2&fa;cEvs#2VZuiu#@dbH<9ndft5?^p=J}JhQCEdxCWNYmS(IZ8r1V#B z{ua#_;3Iy8R_g?T&uVuc)HCXV?VfMkN-z>kJ$_)QW2rj|4SjuSisY(XLsxM{v^w`z zU;Eu7rLW1cewcj@3F9@{kTKzf!4t|vzW0WC#dw{`*9=x$G07eclC&?nqbh|T(O$HV zqG7o-j)RhFpKGjG5P~XDBqp7=dCoeszkFz>D5bFfppz*HIRO8(AGhd(f2;e0OG~V| zBif-VJ2f%z=k*b_+)ScGraZ~PQ1L~t^BnWE&BN6_B2zrJHMB6;Y({2f?2iy67u0s^ z)i)(3^o{Ix?F$LXu9kwx24D3JP>;^S0u?o_=d9zHd(lT9UhQhXLuY%6an&(fQf-4pB2*MWvB4bOmsS_dTN|dmo)deX z>fteIo9t72MtHCX{a&1{ZfbkTmbk`PQtn#$RAV&+Nf_oiQ-#k&lR%@*a9uZHgGy5u zW_Y=bk+%r87BkAwd+O4as)(#D(~gc4t!udV`L67h1uX)_=vni1iI7udiiqG5lzaxnWh%v(>Z(O*hvKZb^70I!{G;ut#H8<>9g{=wF<>UOxbHiYg zt376^?{i=Of;W(R*(d>cG3G7D2I)PZA#qCO^*~%Z;eOwsWJ!^V-FLrT{r_$+C!E*3#x9Q3Jeu#tes4CuFUw3plsEd z;!|v>ep=%NI$>gU&?5=yWpUx> zaz=Et7k0?waY5+8ZNcuuuhLAy%u67-UX6VQZ&zau4t%Wi2pSvKD_MNv+&gA&Xi7WO zd)b4;Y^=CfHbF+A@@QUGfoG0Xx5Q*9P+0d-Jv_pg#~YBwCL1BxZWjquwxk79xi6xI#Gb>iJYJ)_>61$Y04r^J&NS zn2mvC&fu&q!DgF9`#m-)hZod*+frqR0>Ox^N)OgGMK!$MhsgisHzk6BXBVe%D>U%{Y5zlxu;koZ|ULgF=n~qKER0I$tgqqjHJ!nX%iaZe-Oh=~=2e;7OcBo_V{q z2}<)J-=?V$W85n?c~+4+uZ?lF?JkKf`@8@r`YK%Wv#>7v(gl-lGoamo!hLCx+&FQuV0OPXtF9mg zSDqJCJp9|jk)ydN{#v=itl4Eu=9GOBH@a9$RIwR~g-}n7$!g!sdJn>M)eZ7n(|op6 zwbF;&`@6j4`hhXldjm8v#=F`fL`qUkqp4ht?prtw`KgRgRFS8}R-uvz5`B0&p8|0( znW(qFX)jdwso}L*5;Gy0eE=CWgInKx&dsyhxCFVI!h{?PHE%YW>Gs~}(cik=VvcZ8 zC#c*-JDT#f@orp`$Bk|AEax(`)X$KcUO&f zOFamF5FK`N<{{@TAC_X)l!=-dKaQmDJRx&BH^tx|UFUdVDRfIw??Xg{HxsLtbHJG6 zHOGL~&sXZ<-=|r7n}w{YYsH-Y=XAmb-Z(FOS9Z5f#=8_1JB-Xw>&$pWzIsOG5z|>Q}+1(gEF7}smwd7eNnPfm*iwRcK}4^<$Q-mrN8D$L$L%K6>m;z z$@RnT%K}n34gdo;`o5C$A1n0He=#n{3HbU0wftnjSQ75h7U6+fKo5e%sFFpzEP3{b z7Qed$?DN#P8|;W?Y!GBQJE@cf5>%z1a|~)8frf$s^WbPC*WF*=zys-StUva@0NzJ` ztSJxXsP!zcq0_R&;y*!k0$Ow(;OpGi@jTWQR1!NdaoloF$*ACkN?v-#&wcA}*=izha6w-6LQL(%xMBQ`~ z|K9U`qR@we{Q$m&jcavtPrDok&FRz@pN#sh^oaCU-ed(L7;8yz4)0ZuVB>AyF+X}J zQiRtZ*r}8nLAksJgc8XHRC8^5nZE-_vb0|Sa*CAsWA}%s9;-sdNqJytM9nJ%K?)A5 zQG6Ax8xp9Y?ft1}VZ07$fXZ$NRn@^UQ`b2kdCuv41YG(LJ!ELr8{JQ$KU|&sLu3`b zW4G9${d-ayh&IhHk5&kpWyakVLmMy8ar+HB^1WSY1lC{B@5z?IOn|@|Hm8&FH?QBu z7^q8fZ!Cb@rya97W!dgW0i}j%%gDVDcK$Zy77YHXWX{!8Ir#(#BrJq3cg6z1Gex}_ zzilX-C!0Rd2P^8?$w+)+Shxx3lnQ&)A(;!n|wT_%_EbI0tV!k z8P0Bb*WzF|hO4`!rN`6QR`Y@)?B2|Q7;2kEZbU-8RzFB(`xK~(!*r}cO-#?!^@qg zuJSISV5P9u>BUMFqW@qev;VM?mU3vpXHyQ5)0g8n?M3Hw@R9Bm^qeLk3~|>bkEtuC z^mq)yQ`XLvydvVXAkh$xnhVY z8WdmR1j0cGpuvGJl<;T<#zXU7>6ijtnvvIJvu-FFVqdi;7IO?^Gt9ttm(&VS|K<*{ zDJ&)b^oI+oe3s@h2ly@_l{+U8#I5@b5O%K9nTAy~d!&{hJrqxf8y?TA*pW(4hQP_; zauEq0eWv(PpH+WZM#E8yhyJvT%*7*yDf4w;10zil@73 zlz+k}c1z`hF$54$t7?dX(7c1M?@lX$e)m9U?Eg~WbpX(SLG)z2U-%6F()yBpQ00Zv zR;#&7^N>&36J}y|3{@wbrJdonr>=SWANE{ej*@QKKe1m$8Qy#31Y_P-uIMYVrc`(~ z{5t?Ty5_AoU2LL^I$eITsc+~wyP1Oq-A>}i-foBSZ=JsDb#`O&9j{i(9U#!ev@~!f zx;0}#Bb>cD*gR3Yoso+J?!1M7OPjYi=1lPJFCr5KjYz5tE;KETU1NHkvor^Zm#pm` z6^uUqJ2U{X{{tFW(cM)qa4m6aG``IWLEOOMx=IJb0MF4`sY+R4oA0wTbvIYZ2ku}f&n*sm{U{j8ii4F?le931U=|=n{TE`0a0$3r+ zOXH`!j)Wj@<&u&_x&DkV)n{5K_=bJOKW3AOy2`x!Pql78%0I7|4kb2lmM$*H)m6R~ zb9{C?_6`qFUC|_WdG8SG7RiuNg#Y{x`SY=rWyf!z1@*n4l)ORJ;xSgv^|N5>OtOV(K5 zc0&Nz-)>ogd`8G4kZciM)#@^2;ElGrM46OB*|wNdNrcNv_@iFgY{P%dGA1v>YSpG0 zU3Nv)`LLkE)AEP@Qi%ihLOj=O%ARa|ggAZTp6Oe{Oj zIjCAI2whY|BrztDnFYh7LZxQ`|0%HQ6GIv|MiLnnRiFb}g>Ol|ZoIqc&j}~uj}rYB z#ps@IHy;I@ydeLjZ@^v|vXi5%u{(Xho-;{t1tRur{nZ=v0gpSk#A95ZPF%MH| z$JB>iX8f_+ z9?&N6E~XrIdK?P?fW#GiK#BP!cy&VCc}?6A=5l!)cKH@V5a;5|e67+xW&}D8X1C3j z3i7iy#W%97NfL_7W=~=2Zr>-3zz2Im%s4CzYFC{-k*J2F@wNu{2%iIHO>`5L%Zz0mTMZNj6d z7Q3zm7&^yU?{wrRj2=t;PNlB6O8cN?v4hd5z&!o8sNDZF&3jRipQ5a-)vuAvnS7Y7 zbMwD>N&+^Fq*R_fUppj-0N82u35~&Q#rbyFYsCV*8sJ9uTN*bjvTL$4msZuBQR`Ch zXaeF7+zG-2+0Ck}DmP%ytQHPY@{f+&(+7gzmkaQ}q@Pt_c`o|(vPO@F#f*H1+asTo-69zOiW*(Z=uLb``=D-}UE!D8^#kTA^U2^Q-MDnsmfP3IkX3x7 z?a7Y9$WkZ|8lm~c75C`%P2x*$19-gfDK~%pSrPk!+q@G|CW=h?+EKGb7;I0mmWt`_ zm#?vBGpp=lJ#TwGFy8Apec+G^_$#>DaCeSvG6Q;Bk-S3Cfl#%5V8b1g_cJU@KU@#g z)`${`g@t;eJ3g1rMxnSbW=N>I+v*qX-;XfB9nc8uE|;i5QnhZEAng%AB_FUNOP?Av z0zKM+_If{nZkEB>GHHtkOYKlXUT0yF#11q(TIs*`LFLbVh-Ukzeb{4iE;_iT{pfd* zmVCwt&>pUuU+Em=3+H8|_6@x?u$gR`4Ej|*y*MWFR#Jr z#m~xfXkS$1b38xSy!}{#4;&OQ!&HaA)dKsmu;DQ47&2O%N}q?hSWDw0QHW=ubSL+j zfXg1o(>Qufr)m3b#=)CG(^*=|5F&OYDb?q#?Br^=jmpl*(9kb0vM(Ga4i)xdw5GMU z64cP+pse`a_gRh7*8Fj`={kQi;@cqhCBE)O{i)AVTiOXMuJo(YDn5^+Q>!v zwKRy5A?6_0wZb*$>%x%T_p!gR_7EWYnhz0xBz9h{AZ>4}1~>ewvy}Zl#@(#a0gTeV zNDAMyVMoJMF3YBOD(P3<9w>Rjav8~x(Y2R8HkaMQBwIpHqVBC;9L?KFFXok$Mq6)> z_85Pso*k?`2owIz{X`BF ztrYlgGo{Ym^B79{PPdB~Pr^SolMhdXcFelL7CQ(v>uLyn!yUq+2h$CeNsFhZkm#mR z7A%auXK$}D%yf{~(!CA#yYwFSMC-EzJ zR$B@I>uXZt8h0~cZ5R7D30%hX?$lF$E}C2A*nYwBbVv1+J_QiVfZ7D+vvZ~H`9(;V zBkdA2Y`-9GWvBKGY90EKw1IvAo%D%hyIA8*Mqp=8PWy|zfEES(uz&~W3w%uhEKRzPQ((q8&>k9S#Mdp3y zDPEC4E?EM|dCPM5pivcb9t)-1Iotpq`{yerr;B`*Nxhf4EAJSf)pn2KiBico$`}?O zT!thC7KxY5OUC8vJXc#4G|o3}AbFAqLQk)wq~iTPQkRXbWG<~m{yJgEJNBE_^9Z1o zJ;d@Hm}_9QIZSmwv<0|fgMfr21oRdR)FZ*fqUBS87CUNNw=DX%y$GhnRHkiEfktu? zLv#|#{wZp5N?@1T2~C7{6)yo`Y=z^cGk&A;PS=c)u2>6qHfk^hCUD{$oAc_Jh|#ll z;c9IT&U&@(uoDXw;N-urOr*)OUXJ3Qu?Q%-&dAqj*g;=vKpFgf=Y&s@($i~-TSc=7 zv(>&5`N3`#nPnR%p+;u5AN?0srmSHEN*{f03@Ts38N%bWzgzC)20)+6g55O76AA{T zB{jNs0H9XuapR!^?I~L7e=-i}k%<-E;8mn6WJymm@fW7vUf1Rz0~pO3JS~dV?Uwxn zK(#W-pAji~=Mh$V#hEdytkcF;K4}f~$HX|lHJfO2Go-bjZEoy7ES|g~Q^p7~UU@h5 z4XPF3@D=s$!|tS^)+q^UpX;g13H|=_eUMaN^l?yVu)vwNeD%3OH||V5WzoZSC|@+( zfuW(L!rpRdhkRU0Q2-Ff;a~5Gi5HoH_o)0vl0eA2#7aS4oP6VW-P5lCfBDYQk>e2@ z+yXF09cc8c155)Mud*hstE&?wlcpt9Q7=B%vTk1VZ9@qr!W5xRRZ%ZFXQ~Rwe|T?T zR)aEs*;gD-V^&_J4f|N`=!WN#|MlDG0J4oTk12!eQxO*nw^MWXu6R>RcGpVq=%X-= zHMN3ZAnz(&PtcM`j`=o_lJ|5W$~lUo(F|H;~jC5pt&R$rBkRJDm-=#>RKER0-or}IzghMTbh(p5|X zjr&B`8~4mKb;%Y-z3jPYnw@X)M}=6{)Yft<6{-jNB4O#AZcikDb)S5dNs}hA);ajG z=+*W#7=+0{YoT&o%%OA8x){`k70irjIXX{50#s?J3}ybJC<2C~0WTKE)kxl_R+Sv- z{X}XpRR>c7KPq!9T0+}s)&T9yXA)N6d7b?kdB^_>Pq$aK7ZF(^6wyHP^~ZYdGZv5S zrkRX1A~C{?D)V_{V!LG2wqUF-yq8I*ln^gYpi2%TE=YQNnlj1tzMFB&n$_3@XE4u& zgQA8`-*^M3fbk>LHmkH3iI$mT(($3wE-nBVxwDYy_p$@S#)%<wX4S}5*u2IeSuV3 zoll$Nb9y}ejJ3PC-sXZ3%GvfecS0BK{);IceD|&~PHpj8sc@!mCIp|?gex|$@>2p% zrg^(IRG^9f74%MFS;RHj{%A?6v)p(`%0fB3Ck6m& zM$15~%FLE8`v?R5;W73zC@4!h9*M!OC9hd1do-k0eTLC&j^WQPj&WG!Sq`Md6Hnjb zzZsZQP5;!>^%-9 z3Kdy~B9YGmL+(?iX8vp|UU)mwgsD5lBHT^PB9 zO0%q32UM$zOU7dtqQdSGYBU%bV^fmpP@Vfxw!wg*&iyFP!+*uI2l1Of`OCb~O{&X@ z$Uz3oE8Zu4t8?ZXAF5J&5RDfuf|bwfO_XDbhTse>(_R8q9dd;ed81ZiJ-rQ2&mC0w zfbYN~i$8nWttd|Z4_W>7gL;<2GhPq>KO_c#%L;lJFeVF_2ZaJeuMmhwu^dLSn?vAp z#&2Y7Rhcp4^HFj--V)-0DD+h1Y(FyPtd^un92w9vS28Yc{M*qti{*DF&3?o`jI z+UQ7i%LW2fZ@ir=ylpnIa$>YHR9UTuF>>O6{h&-0`#i0-tu35&jkm4X1k#+$Zdr8# zUR+!Rt_67l!NbPXxXN;tg6l_S)W6E-USJQR#XYs#tN(IvOq6=kw7eU?&bLV1^o~ud zqt@7UJ#W7)x7t@obd7DcHBJoz5UB~u*qzveNOfOI_xe^u zgk^_Ij%5HSGGp62TPjT9oASPJP$kU&HW-bqC;c;8eO8_p5Ypeob}vXkqEySg;oc84?V-P#Vwnf(A6H`NZca`sc) zGTf%dp+haH*wp*9Aa(gLRi|3~8(Ca-jGT^>I2HL{Kjg)T6B}CC9O(8s$$;}`q5JCt zr0Uf+k3WC@e2-!bdJ;U#caw#mCMclesmqC*S&31VG?8buK8cS($J3HC-diyG(`w(Y z`t$^oQ{McdI5X3L3|)Ei8rd{}h<}+O37PDJ^eo;eB&~NKVCXuV990m~Jlmott(E2A zK|A(*#w85%MqM;+=qkB=-(@AtXviG>&bTx9C)A2#m152W$LF9wlk9Alf3+bvu3shc zb?3R2#m8m2Z&fZho#IzC6qcdV$GC{k{FS|-HSYS1WBl**08#Uyn%m@RX&O|WpW3cF zNg_8ejNp<|be$zW>qR_qj^HFFCo`5tNDhvxaYPF}5FX4hYb+lLgf37np7F91T%@9(mt2Atmwh1*+@Ct3F<0JX1GzM=qmZm);x;ZNh2 zi69@dQ{HrT-|ZRt%jyW)yvLPDJr!@f^j z$lhPIT}K$w-y6Tr9U&8_l`f4VzAWCj{TlSi<81K!Cf>jbrGqC>wW3TY#X#udw;yzl zDum95|KMbn(s7YW%YB{xUO=r5L+RP%N5h`x=V>jAMJ-eiSl#rh&s)4O99=+1slJHc zG#HJV#~9h9#PB&iU5~d%;B@x-t-uumHb4{x!pFZx;op%3h&2Y%TEDfk1Nw^OFWvi7 z83BO-M|cP&%YR_J<9m!J4keP4h@(tK%8?ZT3P zPx5`{2{|Og{xG2WwOk#}J(BTg&mupacGH~()JY#uC$xc~@|}Z}lxH7urC{cfJ8YCp zlq&ENZP$5YmiIN2*zA_ye#s%F$~q*ulJ+Cr+*uJToe(*Y4ITr}*!d?p0I$zPDGIQ= zgOq>0xt3gD526p6gw*fo7~ywsX^YqP=R^w@aVdP?m01u)zt>eb)j-k)E+p8`Wu%oghSTW>*bJ;M?U6`XVNwuVi}K7fbe@ea!C@5!FIYeeB$3hZ(Dv3^pt{xFNSzKx;c2~yEH1}cbP4$-MJ zUdVfbNu4*9?kd~Mq31@We4U~ekY=4qct~@IqMo#97~a9zaVNyVVPKhy>gKY+3=b$n zPXbh;J&Bt47aRcoQSD#x3e;P z&+A3K`}BP8SRMjfS}6S3nmvTw9BGlCbpzPTN8xEaOS5~i!STWFgk)Q?@w;ED$c>l3 z0u?m6J^BU^!MjN5*J@d%Mp59BC}v7wQu`m4KTSNE6Gnhp*uv5|iooqn6P=F<9hX;N z&;2dcM@~EAmp8cM+A3Za&o7y3;bvrS= zn_9{9%W8P!?N!)l86VZr-e{rZpZ|07()dWRxkHKkb2*J4@z#ZKu~7yayuaVbkFgIE zwBbxM4>a*e`Xg%#E?GBip6#R6m>_eD3w(trKG@ zkrr=s|2F&rxP-7aXT4GKRa`~~aGR)r_0j(e&IQ1x4`8Tn|KhcOIica-P6)xcV|~nq zysyZvQY>d|{5{Ks$oJLkEt%GU>JoLhg zewij6v^0Q5;O2D31WP``y#MEqTtOt+L&3WmNI;-IEHY=wf@h7J7GuqDW9r@ z*1$>sJk>7=??U?}9NI@yME1Sv-zD(dAHVjUHer3n^1-V;Ka`y^XPY2`S(7sa+W46L zi_QT_52_mXwHH1$$}@J&uX}miex*y)#Nn#muKKmQrIrY1HK;er+NvXuewwhzk{e40 zG{t^6kyX+47sPPZJ&M!f2@U*lB2@~`l+-W+`V5&}^4*ktee(p;%+J1euvb-v(2HnS z129wM9XiqxT&pOn@G~misk}tF4r$>D-tTKl(ELh+xzi=tFjy{vnf26{>7|3%9G zb4DkB)JcwQtTZ-@Gx+YC7@#re;~?z5nRI@TxI~Idsk$;3ByH$Rd;htNG$C+vLj)7z zCH}Y=r|NgG4x0> z^xgV1;69OsO6r)cPcz4{a9!{>wz<^mqVGDC6pQ(niXPZ{9 zxIgPYL@DDxd#4wZ&WjkyDS$80#w`SV|1j8+g)%3ya+O7$8T`bj`HO~hPI;qtu4UQo zTqMG#3bOqsPhh*qwt$&Q3_lJ;kW8?~4XNNO8W-nd%E`G}>xV{V{JPTyB`fmi!99qJ zD%7i;&{vIs6|BMof$V>6F$PNl>%-=mMu*{gytyi+c$dhD{T?QRLw3lCAY|agx_KtO z@g>gd%MTIxLaC-FozuXn-C-J(rwmte72c0$6`uQg+=(|e$qYngZCRLuraoTwj6Owt z390Q)s!1z=&4J{VaQ;;A3F?dEHoTiUb0YJoTh=g%KC=T4t>?oU1g&*JL3n$hZoTVwc6@~`@`jEkpV6;DkLyg4kBpJl8j+eAw>&l z>Q*W+B+5xj;g*fx24cFG7fI*waCsyk0c11mE0VC!YRAM($)hp&^6{|!d0~Cthl~U& zbpIMxxLUFv#PE1FApwhm8?sC)pV___USl_9F#xr*D-KXQufjP678BT$KsM}5U|K9d|)r)X5D%n*|b@* zQwAHZA;pNWHa7WED>q1Ld)yHLuC;^wBLxBhp^g+T$GEMJPeymFw*vg%4ysy$YM@>z6 zUlkHU03@=2Us{fzUWGT!Jtq5qAQ-tp==}w#hLKBI9Z3LHv+3agJ|AHD(O~@1Q60L{ z9scwZ(PG2zoV1qDdaPs-gB_B|FHxlg-LA?KK^ohHez~d`WsRrpB}SzUt{2XVfJIu;y`ysW^0L`#&b3=5+ENHJrI5J-ZzECzV4|UGuMumD7=P&_87Irb50MfD7 zmByLbGVy4ODOmlQ_*$3~b1g3AJd;y^y&+Da=b2q!k`?dkY2~7|tLc(n;}>2IS^25~ zQL8@RXqtb4Mlev)c9E3{%tJL02M$TofsDmCUL{pdJaC07D@LVIjvDZrg1|QL{U)pa zAY*`P7#|grKL~u5d-?@^EuIe|h-|U@F=%HiWj^`Q&Ugb{zp$qNSiu8fC)(z)l-kxx zHBlmIN9{4cHV*t36g403eR4Rut{@l0t8A$ScTrX{j*A{R8gQ8@P2ePUrK(w$7fLCj za+OLw4f|2Duo9VE(0JFb(%QB@DW@&%h_T~ey@H=U_)@?y`%N9=X(;T%xct4!7x6*p zJ~fOen8v`LOAS$SpQCFpZ!9?7cU@K%uC|DqR3+23AptOyu&l)tZ7@X8M2C+R2+Uw8 zj>hu|MO@dJ*1$Q!zYWNWP;-2%Q;2 z)TO}r9a(q)4Z}(VbwGc-f(TCjE6o607e!;D%#`_O((nFGDsf!WZ$6Y5<(}qiy~yf^ zqCMH8;Hx=SdlXHqDp@=Q-w3PZOa-egBX;^n}9j@=?{0!3Giq2_!kb;+mciNnh zcr?!np@og!J**zVeVNOOg4IJ4=_+WG42p9B9?l9NRQFCx>EcxVoO`Z^5D@Xlzc?O% z^$}i$zfb%w|6*O?Bmsr?MAP{u#{cVf=wQUJlv&f6hO%`h0jYn5c|rRpy2y1~YQqqw z|B@iU3utJl?F%2gDGM~lZ`$R{6 zW2P*oY_iNJ@}wrdfa<0>kYE@qpboATuwEXcXszq|%xF@P=GQ@rcoP*3xC*Ln zfSSVxR9Les2}H}S`NF#g@tza5|Lr1v`3mYRK%L0|joXO+i@asLm%m^@7R7H`OjPKe zbF_BfV5SUM_oI2@q5D)Z&cv$z6~HUR;`AW$!uHr>CxWm6 z;dFdPXuWb;`dM>=aE1hv7jhZr6E2+lwEp18$8z)dSo9VB4)-VjY56(r@kuhG;mN~< zMv*jHY1Q{t^;3AnHM+t3NQ;`Jc0fF{x*Qbx%~xcOLSNM{hh8X0il9{*5?T1Xh1|JU zRUtmhEzbel$SgLHIAeati`?BptLsD^yBHG<1&l3!AK+0<0A;%dp7kfm_U`Wcs|SS^}`d>}FgHjh_4 z1Ti=XM1T3ia>I5TfyW2cfyBzKouvn78pZTfc)_iVhK)Y);avpdklQjY(LobOgEszG z)a_3X>s@pfsAZ-DQCkv2ndd9|hQ|BlmRf))2e+UnjV8e~nVC_cqj6%n-?H~{!DHBV zD%fr4Crc6TB$3NPPDz(n9T4#LfwO2J!j>r1G1FQ~N2tE%V9Vfgr#|KO<0N7Qn(w2U zny&b&UXj8a$rV0C$cra|C>Q`n(dOPQ00d@nS)$zic^BXS{K-?C(OKYnH5~6F${reEN7JL6VkoTmJpf zt^6^Ku33E4&q<6HtH#%bMMj4wUU-GMD#Rc7$0kW6 zfo6$D<45ROZCS6M#=NnhA0_zd`ZmD`zl)WR52_LrxpKH)l>B&tpsdDS_2q5pQ~YjY z?X!=Dx|$N90MQSG+-|E?pK(ziu3N9|nsGXn3Y67;wNN2Ki~TfKG;nA~XT)0C@!XkpC^tFrGoldz4yd`f5je{p zu9h|biLPFgU!=EBV(=KI3`8a_hMH;-EAz|cyk{T2>x(RO@fNTAB(AH4ATyA}YZ-mp z2HMnq$&PQ>Yv%@6F=aCy=CF1YZPo5{Xw!eVB$>o1HNpC%^u?2!3esF~Gd|A3&tiWe z>oP6~<-_wu9YcL(uZX44pY2JZX>F)0CR7zn{QYsA#DG`wF?r!OBSg3S2d~@M` z_SxsW`<(Aw@Adwn*Ch*KtvTm!j(gnW9{2DqW7d2)#YdH-=Qh>&BAjRBAs$0Tq#(Gu ziV(8bO^`9X`G1#S4FY&@Y~tRGr0gJ}(cd%+d${8d84F$9C=qhh&`8(I>J0r(mD@vh zwEn>PhKE{zEM>xP26d!uoipydD`PypMMK|vCft?qJehQ2VlHo|(;0mzM`0Ka<-MNU z4OR0moaSF@W^Qia7t&cE)VC_57n&vBxxe)*1_Y^OP7;HL z<4zZs@ciaxF!|Q#Y5EezEJvjM#7p|mVGz-rJwuQq)19~02rM5Goz}( z8a<-Mlj7|}Gqn5)FB<%GqGS;79&>yHuUiPBOa~CMx|1AD1tfhk@?J&!c1h_-g#IJb z2$Yler6ryHGW(73vUyMbny~ZYsDr6D$_6c-4tkZa_ z9Sg4R%s}sev?`~O%kmKxaHNFjKz{R%J054VVJA$1vsDyOfWgBe!s(8F?Uf94TKW7o z!;vGZbAO6TCmj1Z>M;>#CEdmiv(MU$aA3|ZC zG38-aAv+3tdUhK8PXB^6wSw_Ovp`>;xD4+s+rB4)J?QzzZBOOwSBDS|#C~uG>F_Gv zA4WI|M9$k(AH>c7TPU59*@ybuOLmdKyrO>=<}iY7O}rYE%)t0T)iO|$)7;97IX?FQP(22J_N9EBTO|LVUxY?*=v~=jM!*Tw)(hU*(T@K!nyC_;6b?F))_e%wQtChm zycauRyVu28gmRX^FlfQ9j%Ohj3xJwQGLHoNa^Yw~oN)s=`cWUt5@0aY(aElpq=a-_ zjxT`k&on}yW1-H^%6n!m_O-)1Q1)O96XaPnnu?t)G~6v4-d-&Yw&4(1|GS2)YJ z%Nxb~?IJYzY*9KY8N(?xoznVtBn;!AkfoJmwijVcA%mkttISgmam@OmgNX?m?0Ldc za3yi&BK+BO<1k-oG;la}F+ujj9FMa~Om2umta!J2l)>btg5x}#FhT;yGhI92vg>me zU(c@y5rEVFVE(JL+bdVAzGXd4R&K65v(G`)lm6~o+*fkCuu@O9--5oabUH5Lt7@;02YIyhiwsVw$uupU7g*jOvDvs6zfZ{&yd{F%{tGwk^duB)hKV#}4Rd zvOnoQffI6=fNK%P4Ng-aM09-5k&!TsGb#Od+*Fkm>Za@dl@_*hg-ugmiI#=$w&A9A z&R$wHBn=^VW8;89C}vSfU5wiby7V<_r^wgcx0|;xbbxDQdFLOgR+j_(K4a(`5Il41 zL3aK(OOCP*RCYyre<$bQFhv7G&n*3YVjX~aplMJ6|Gy#%YF?0BZy-Cay(7lDux(i@ zJ-0%VVKwr8SYwTDmykrFl*1e6??A;v$Xexk(OOihE1bR6YwQ@7`t>Rc*Ozbtt7;>{ zmXjp~OS7lmf$E8J_La)Lec)5(U4`*Ud7hWa!C3r{#A zu%+X1j!8p+9NM4lIw6`GxQJoNi;8>Vm8hfyUbHmee+I}37*7G|b0lEuJ}vd3s^K4i zEU$wVn2*t#R5URtFgHlvb1NyKKC~?%#9Jy9Y z`NKr^a^>`)04ei+PTEyDLHdeprKS>ikV)vG0QTtz&*XUR$43^I-s~#djYNa86{O;q z!&>yIXm?8~OoBZMyGF1mO}N?m%y)cY28^{XmUOYabcuk4kS|8KMnn*P=y~ z@HhDVjWEOk0UlCe#o}=2D!BL&q{UNMApmr5|Cts4FV?F!8V(gu3?|w`FUc*VSIJ?< zug!ZosmFW2fGGI5-T8TbcjP$Z@v{HVhS*PF&+;$8kxb}ECAar5?AgIWjPreY@@aQ! zIk0Okx7blfJi03<@SB3J$G5NFgZoOL5W^iq!1>TT$yB``A)DuKUs{s+eDnvuiNAlc zkL5$#_f`)35f2YBo6`)rcgxcX)zHfeQhK(XReqJDO<;!Tra>s)(wSYe=-9m|Wa9p| z^nTAZCJwxho-G7w#|71YoAwX<5=%`+iun>1HkyW=?V0!z-@{*;KvV65iOnYHjW(?d zO?ywW4Nl%zhH2tK3A{vboc&@t>cCHNlza6*Gv)t_i-*-25OE(C&~Ml=6?209g%jGl zu^^!_uWo3BJ!mT;A0T4Y4lamQU59{Soo{FA41?1sj*6z7E}O)WE~r_CVm7~$I8oG+ zO2FX{2Z9N4KwIz1a9zk?$~GrDIH9-UjN`+xZ|=k>DrATiSVcYTOu#a7FjW)69iL27 z2$-L&qalSANe4M zhv-Vb2G|Ja4{)b9l?zVMBz`kWrI5mA1o?PP7HNboXO8REXj>7D?rpYoDi;sRb~ty;}NE^wqnJRYz;$#fsU=qbHdc z?q3B0eRBi8DHwYoCo!Km{6Hl|w@AIJu52g_fxO=_*6>|uk&h()ng=nY`bSBoCmtvL z;YcB08H492VSAFLhtNs#2VrjO^bHStyR(3-?KuUVT+yKpxnN%H?WJ$Z#<)Wkd!9S} zI{M1acNYeGjPyYAF&XebBp)BZ-2w~{IftIh>H!`F!PdL0xjKaz)y)1wQ$X7 zrm+2{x6#j*Qc$ECaiWMnfl@7^VYc`s{5}XsegNyZkja9gu~7_8i&}=hBD2bG7`T)E=NfPI72%DHRC+u_H0vTMD@{Luva0|dhpT8Vac$q5VLNEa z_ZYYT=uKVO6EvaAEa*KF4A418b7dgrVYQ|3SBuGXTayIwGduf6r zLz?D!0@8+7#p~yz+E#S4xSc0ftU#r?4rRAPuhiqj-L3i4oFVE(H}$iARNzY!-zFtB z>eVFYW4|Hh5hevLx-%NnHK^6y16ir|zUx>}Gx=vR<=^48CF<~7pNx%f{XcYRC$?8{ z-)`bJW^GXll2-#DZu?rw*)MePS=6Ldh=!zt_49wO-K_*Qp-n@|Wn&)Q$Nr4{W=1MK z8nQG6R!+YLBdt&6=`4ORY43Mmta1i{#6Cg4e=R^8ipULBEjQd34jo_7fo?B6rX2 zjfl3z>8cM73hOVTo*>JfJARlCYL$i+nO?C74qXXk!tYBAW(Go&*TmBpGKW^#;tNOJ~YE*jzyJ+;O zAYFA=Hps5lE1o*EQr;&+Yep`$t)K^+xe5QuR`D4}8{Vw-kStEeL<4Za;(TkRC^Oy_ z%R|gS?Me&jLj_)*ydlM_qg6X@KJun)Hy5{yzBl}hKDDBCkW|-k!9*LE^6d>uJ_E5i zs9~=}w&RtGC+FX%^&r7J3~m3HIyIltL~M5-Y8Y_op!pE@)rAdoppMks9#LaClpy;N zK=523`RjL5*MCT$>Vr$cCm3}Fh>eYiQBVz~erG--^;~;W@!cJvf zZ8J_fwx04su8_E|Dqm|j_3GXZ?pT6tRQ`L(i!!sJ#h+pE-3!Cu)|LOpLq~$uei{(s z{KcQ3%IN1Xb9~JT(|G7sr!_^3qz>9Dv!4GQ%pB5R-ZlifwvGo})x?H& zA6+OX7WeQ={m3F%f^w&OkO&-cuiPDc}Pmt(>>Vr z)Ff|NQF)ADGQr^uQ}b8WQ+3d8qVZwvMNpavQ?Me#84WLcy3ZS24Cwop^A95C<0MW9 zdX;H^0^)NJGsasrEd0&05+Ouny-lq**y$$Ip4W%`sn zlz>uEs$k<(ln=VGtPCjEub(Jj9T+S8BP2WS+O>J) zwcc?jhY3BT9qvq}pPk#T@2mNJTezmt+*;8^nUcC9RzAt`rJ4(6Uo~OCic@^}W$F z&wLoTs}Tc2=)N>~m6L>R#X=+`LGFP4?dofnU-JZW$nx!6gY?jL zbfoF6&|Mcj|2M?O;kN8S6^7d-`IC-IG%WK~_X{&Q7x^4$Do{>vg8z@Tp7hE$!QqXv zuBuk^)Rsp*Kcp07t^POxvM=W(GlSNi@+!9xo0ZGon_ z4#F&t_iD{sEj`RJHs>F$(8+ba56<{@L5vqLbAfL#66fKhW3GoK0A2 z+v(8hduKM5rKPyE^U#*peW_{R(&1!2LVeCE$ARGC~V<{HryBVN55T$`L z0_=jl?AO>tkZ2U=VIQlhgCI*CTjn7~9XhJh_RF5ybS=iJMD6{k2{KOcID&3pI76PW zhwJ0^oK(m`JHP+^m=gARJPmuo9`Ise)l(Zx06-vKkVyU{ju$CniX&j2Hjb3M*bm=e zH{QR0_-iXz&ij?&bb2`6>JRKf>Di8-myzsqyI}eV- z7?Jaj-4ACf?$!6M1}4_Otf19s!ZvaKC%EbU18ypsf%h%#ghLvz3=avC&&#^+{r1rp zu!UA~Rd~U8V=j%_!nw#h)@CzydF@R`M1Y74$@Un>M;Z*oQkaSpi_~kRVr5*<@Bgr= zxm9&m+>=JH)i(_|tBI;=4B?+bU>cYIxd(zG4* zD4C;?9r&K?Z5vwbtEv^~=8x^K^kALc{0w+EQa2=K)NKh&EE)rf%?ikWS1$t(@U!_I zLmv_Z=6^#I1fWULNyQ6Cs!^p#jfN~&NFAYd6keTIbgd9hWnwrH`sS{_P;0f1 zsT*VGSVT=J^#XNhXK_<~JY(a@608Kq^==>+2?6rNVK{Wm8aY)1e=hP`u%A6x+jouQ zWhE0@IX|z@?+UUz>PH*jVZl>(I?-z0XR*A)|8n{KDjL$fZ5^GnOGARDDzQ<_SY%>+ zUu-8HOTqtdYIzCzCsmnQ1diBhG~MN2G7?DR82CU|EW;{H;10;4(HVY0ZqNZyK$59}6v zyHCo%<(A8J`(iu}p-nfb;VgPh3z2Bu zw$h$SJu?BlMdCAiC-RzU!&fYYf;*HG6=o6yt0Pxx-HLCY!l<@a1bgK>blVasVs;a@ zD|-mGDwc5LN-v=hil|k=1aMp)bz=^X5!X2^E_-X@x-8xCV-@@?i##CmV)5*iJYM){A$vhL zpI@Ey1Q;r|y2kzt(+>BM;8s}w9;P~ECU3Sd9eISg|C|{FG*4r-7{AjJ6>MmHe&x+% z4pGVXO+{ZlwgtC8NcSWux4iV}d#_f~4Bt^P<*%q-n5HL~{OeGb0M~WpbnFP{b*g+s zVgk%X65A3K5Kc6D?CYK{kC?bT0l59UW2M zoUz&Gq4vYWZwUS6!6R_)VkovMe;8P?czccLr*pa(*zYuGzaubH69p=`@Sv4IZf z2P>RtOZOTA)Ra+lC&;-jl7bdzzP;P_^6B{c^xBI$8V0}L(MFR-yGN_i_rwyX>sk9H zUo--IftcG;O0c7l{rmfi;?~i!8?1)7BT7NNKlfWN?4iS9jpQs!xkW4cpBHGrak@lY zeDMt&v($D_e)qLU3%ww56qQFb#Cm1omUs;uJUZS=hAwO2aRqFo9w~?pIRE5x*__7D zZ0Je5xGE4E*MDU+g<3Hv#D|^#;qrlWM)Er82nlE<%7#fixUdQ>O0ur?e-iF2YO4Rl zAFK4z0^ak61H?$m&%|GQC*Cxzx5;X*dhYYI6kamsrZ0b?OUs(%uMy;Rm?Wr(FMA`M z3RlqIY+$Omah_}Pe&_wE*g>xngSi(zhH2#)o$M50jV)M=#hn02n{~Nr99#(4P*BTh z#*(PrX(lURwtSOP<=@X?<=TdshZgs0UKGe4|MiH0s*Fs*ip<~4q!0&E%Rv-+E+?23 zZ@ho!(+gU2jG~cxHCqo|!3g&|^$MSb+I8(2$l_1Ne(A>{b>;HtjyLP@8_ypBG_{GwXND`DubGuoc>#j$M#4u_G{fI3SahUQT-^n1$Fv?LQvH^#SmmJe`u^RZJR<3gY#bO}AKtoj zs7q1wt_!A633@R2NSAq~KR?!A3_@BdroIgMd97o&e~9OCZr9#rhxbYPL|a1JRJg(c z((55()z*g^^W&sRXs(e{{cogp-%Reu^2PT2?kq7DtN*b?%lfV=R0nTrXP>)@nJm)m zh#^hmz0Nr|+1S*sMDt*pB&a?he`l_;BMFlPSnnSx*Yuy1i}fuUE;D=Jz}+OScrHtMy@QpD>sk(hw9 zjQ!L~LjTZ8fv6F;2i0?%H|pL8Tr=x4I(s>fwIPgvNntpz5_A|>Oh4uIim;L-0uK8O z=F=8;a9RFR^vQLl5PXRUYGM_EoO`ohSql%Mez1jZpLm~9oDx>vpWDwhJaP3T!Q~`x zv+rorFE@ijEoSFAM1 z?c^&;Pzf+lv25(7hR4J6!{L<-k!Qv9z|v$j1_p zdfldS24T+pMvnz#z^s7vN-f?NW3^ihgg6;my52HIL6!j%j^G1;F1tT@-MWv)B5T*1 z?eL~=sUwq~YMb0&3>F5;e~xV1Umkl?tS%>m$T z{lHgLI%Saw1+X^g+W^-KzEp)wmY-ZcG~Wuan`)?d7kLXGm3(aj?T3-)xy#}H?9cv* zt|f_Ysly5hFnd6@VC|K${hl@l7UJ;*TCalIRA0qM&SK+M4r>#GyhHhaG39>uWWM0g zAaunjtk+J@e(7IW!xu^`w6MATVcG4jbBnAj_5#dp~-5Lj8 z0Pd%N5Kr~f?GD*(pq%eikzW!xv0t44=tAo&Kha@I_zE( z4&>&!SIseG(eb%m)^aE2gqU^y7rT1reZJJbz+Z_05vLH)f;sCt{jiK!RchbCEa68j< zKtc6Rz?s3Rc+4uM8nX0|XLJxNCKw`mU}}B;=kk2njbq)eMPLgNH(G^uI4lDJI*UDI!e!G^~!J*aC9^{`%WUjTG?TC$k;E^|s~ITRBH18MLZWbbo@Uw2CxQ2)L?wUpiO*UoR+wnrT)|@p>JdWNQE(wY``)^`famz8OJDnRdiz?h% z;A+NqHACJnbjIK{FV31NP0M6^?`4gxrZR}-di*)QWP+>sa_Laag(dy{+)45y1~ny- ziIcll z+2~yDEI2FT)B`g4?`Ey=HghA}ZLeK1REeCp5|a(&euT$VP2qh~PR7K0<&nrDBH&0 zV1#iu68c-;(^UxA0_{oUAQ8?s(v9zos5^TUXR$KCRm2a@SS>_cFa#(1g4Q?ZEpjiC zm?US7evx(ePH=*NJ&BX99|aSl8Tf)7Xd4qFx;4`-BH1@mDgS)3pWFM|o(SQ4FSbuCBxUa*Za z+PzNm2tz~qU3{aSe7H`GsCqFbJItEdhbrvJ>a?c@Z2t>-f@vOgG{TIZV@UR)rs++x znLLE^*|u^<*{1RFTcem^IE&dt7wHPmog@Zm`Uf_9Y6Is053C7i z%p}jO=hMUfMJHb-H%v$n%m3!NzUnR**|-O$-u$P7eZ=%qwL2{|Y4fzESp&m$`kmwf zJk>aAxGPGC*U3y(R5=e1DG$3sqYdg=}+RM=OgS_`7^0Gv~RBZaPll7lOnAg0@ zK!bS*_v!0Zx|oJqiT$GtYNLa0bsY(P*?ZA$I@Rwa<|hW$z3bNvUyEV#X7FA@`EAG z4d)7*UycZRP>vN%yMsigjEdL@rOa|{;bmGLd z%jJ+E3R@GSFo*cw8{hn(YHc1vSs}H6PjdoRANE%@dHcC>UohX-=95Ktog{VBpzL*5ek9x*|)blT7`ZtaccisOSC5+PIFLMlOHr=N+7M zOl|_S8(v~qU@B68z12pp$usX+rFfh?TDZv32_=j@{cwTc4?!egspW)8^2N9mYtQ*eu@oj?$+N^*8#8UbOqUCBy36 zpQ;_Yazuc|C;X0foPgVPJ0jr+&>5=4gFaqN2hmcc2M-!+%TG7gzx^y1CTz8WFZo1G z#w0?I%x|4CLb?E~arxPTX|q5aCPD1KgzwYKSL8;@4{cAv+CUOr7LCLI2j(ULHYUXO zzx%X_!B`ChM`PFzMvq2py^muDZt?0?wPpnNVAMcF&0J+Fu%y=b{U{VnnMBMF^}(L` zW<<=o;y-&G3e_rrQmFI8ghU(L~x3{ zzLO<8kKuVCtFbN`rE9pyWD%|To~);OH1eowfMlcj^i2?oX7XKjgx^d3ghg)OBgnm@Z3Gw@ZP~)6+vb<^EN1hmOdM1@M>K(Di zN#qjFosGyX9ved?0>-0!dAgz6cQ zf%xikTs1^APaYl<4Tt9L$pdS1dMRFa-iRmc#Z|drCXJd0>AIbl7SYmWa!f%CdOL7{ z`*iz1Xb}u^#UXjETeQux83%q}JHs@X6>@LApAJvfja~W665*v^;eLM4$1(`TzqJ*^ z%~>FGe)uT#o8eJ*$EifFpLmFy1_Wq@lPjt0Mkw==InYQON5o>v@1M&iqQ(6h;3IwJ zb)^7vO$ba7;!m}|u#4v!NBz`1l#Z-RZYc$7nMahtJRlIVBA7SLjc|S8*6KE25MpduV92zSYUDT zOmp!)fkug%0~{9AH03JR%Wv|nFHfG=1*5(AIs1rkB8Qno;J9UTYCBju-K!BSp8Jl7 zlNl6@s&!ba4=A=~IaaA&S$$scIHtTA?i-K`A&yl2Wvcp_hpvm{AdP7r$yg<*g78N^ zIljOfDq*c6-K^+#FAO`{qm%TjM{L$aI&rq#Z*0GL`mksStEe8l2bVEHYK04f2sHA+#0dB`nVV9 z1_r~;;uKZ9{8_BsQH|#5r0_@Q2*5Guc;_NdmA5PT+>FpV#_}5%A204%4?H}$T=TTE zs}nioSJ%z^l)*aDXb_`H$hP$tpsD&B6HAH~tnyUy4P}bTzpjCH#0mMoL@Q=vuz(wkmx{jPKLCS2!1x&#_K5ca z*n0cg(MwTqX#De2SLXM1Zh4~xNXJ-HOR1*4OWCR2%2=HBEqe@BI2~7C*QTL8O16}D zJ#VNrVuRwmx%q;23(1!Qr?@{fy)Vx#FIv~1**+L1EM-7pty&2dot;E0n(zv-qdL!k z6Vx(7->}yhr$E}X zK2MVq^^tsvDS&k4<5vA4<3vY9aP7Q(S)xf}bku5B_TW*x(h~aI8_SRF6pX$;ZuAEvuVpbXfU)gFo@i~>Blo6cGdvM(tL*cEG(s8 z5g6s0lE7l!B73fXS+X66_2cZxfzgzo>4-5)- z&o7v@!#SZ=v|6{?{PoZ+DsryK#563%r-!SaNrS)`Rv4B@-cl(Q$(?}1!`tgmQk~nn z#s%cUF=Wnn8GrS}av96ZUmc$yF`dZh(m;E84`c8rW&6Pdz;?t`Oy-_u8Ogvsh4ek7 z234}WUKks}mhfPp;1TUIO9GAzHCSdK+lepYi`HJbIQh1nQvhvR{w!48Rlq62tzO zIEwBR=~@|sNg>yL) zB{`|x3nT(9Q^gOSx@FM0x!z!E#oV*Po_d#|m*fC{>Y3aIhpyMsr46lHWtZSfL^U%y zoGVa&A)gBpkJ3&*G_P?BS|;2m&?gd0$yTKZ6?emBm>?ytBx&QED;7G`4tw8d;x83N zJ{KF{?A;%u=PY*ktF|LE^=;U~AkB9rwAz=aoAtp&$%oSP2(`8JxfYSxLyoq~reDaX zYHO#_5Yd5n-~Zm}?xp|T@~jm!V&a_OvmnXv;unCVhNY<7I~!x*!_5y#V02qnGvOlJ zboO`dUGB=T9#pI26+6wO7{)pH#Z6a4ldX|imd>r_ObZ)T5igLg1X~e_YACI zuqGMgzM{|LXK0*DVR|3g^bJ)}V*>wYIWQ7W8mG+Z6t?@VEqxyHu?e|ni$9Ih)CgH# zp8eY@!jrbsw%}oFmn^(vXr4m?8pQsoW+dc&bo8*5oPoS^sZ#+{L|{LpO-MFIZPMwX zufXp(bwkr8k;|;vaIYn;i_K^3kYWbsyT78O3L)=>B(=e8x9yiVd|6b#wD^N!)OmEw zCzV4|zHdj7A%@D4+(SqW^Lr@lPmI4iS}ozh#x#FfT6I5l~KEVJdrLN{hBF_GItZTu%)dMa*lY z0_a8Y7?~f>Ke^{l46^+i11t66*j0Zura~}Hj&*;NYpezwr~LOUcQMa%dpowD*RyZC z^2gI>|I&|g^B*<_mixa-`OFTGDhhnG3>|yFFqun>76tKXtlgDviN)h+tIJB+|HkL& z><42f5BX@fAeQ?KDksr3t0lQg@S~m?p9YBbbB2`J+utFb^V_t=8p}=1{L9M z3QqFzwABrdIp%c(hWRt7+wuGD(Ao(2TA(oCKMVdWJ|`RNP<6>hQo3O_RBwhS)7W$k za+d5*Swl7?Ycd!}62_9HgD%Dy_XC4c9+xp@lz~AKUI7V5nqEa0Q%e2nlRI}IaK;o6 z!2buK`A70&7Yt{0V%p2RFy=p(>>WnV?pQVhnsrUhH#r89DzP%A6F$dZqsv=+9Ejiy zOrlL+tfIjT-^svXBXG8m)V7Ke7%zIHNTlyv^p{O;5fUVkzC2|#QZR^7fgYUrf4bS# zUPai*r5(oO03|{QT7F%X>Eq$3KD-tg)6HMn_!)iAsn#%k)e(sT9X1bY8U<0DSq z#eFk}iu_nFPMP|+^5QWOa8O#ad|mPF2khZmxTDBML~j~{8d2{rbljz|`GCo&I^Wh; z?z0_PU5xTv!hUQAv{H;}&<^D^AaSHN*S3LKqK$onU;wGmUj?d?j~zu*TK z5p9m&NYuRBlzP>R(NqKD?FDRj7rv9InF3P?B%x#2>$J}4z}~6vq}LSh-8)B09=*E2^{eP_>T4g7DR+6|7DB4eAA|3FSNqFX zFGl(9b;eTI(LszJ&b=yMMnqxtzwfwtDgq`uvB`m33H#5j6y(NYF9^y zf#y`#pi%^}^w*mSJHU(G{EUM*G)mTT27$3wq`c0O$r;7NNsd!mYfQpHef(XOU7c3; z-uvM>PPO4)xuAoynI~EesBmi8jPG*R6_qDtcI6wwreY$vn_j2B0Vt{(-yu*E8aF>(>JWwyc|HsI=l zp3k!q(qjq2a}?Z_$Tz00o%#9iB<`?xyx1~#T~GXyA_uf+JZJigMt%u~lOwB2Esx7$ zrtMZ`*eC3TRo-s{-boSF8SybTI%9D5hy-l zU9J%s9}~erU9ZpHBEwVG|x%CJ`CVcbDT^A8&R);jI) zri%Z_MOFjyMLTDI1ft$dGY{XSQXmp(N?XFu$xeDHQ*TT1!g#gaMYuOy7`nkw>M&=# z3YP5I4PnqX%M+~yj?=4@b_EBhmnss5e8`RvNoVD2#DvJ%5u-r%<_5Y|5br9|z>O83 z78vtVu0QcIr=9Kk>ua`RrG-8d6T*9B0)uTsV0R&_nczZ%?y$3KPrY<@kNH{Cw#^i{iuVL z*e&EW2i;1Rcb8pnpWW@qZbr&P*=c+U(PdpEzEGu@CxWY1{>v zBm00is8@ExmLNSV!;Ye2^|hhO<|KBM8jHW@q8;=GC;UNp$p>IM`mG6;xg zZF(u({|4nG>SYc$9Vk#5-_*5JuCSH|ZLd8m-@@eQLOt3VF$NCJ<&L)w8nm2Pw1?dEZ^^M(9b00LYVXT` zlD3fxp7K9i+jPCTeh^g=pG>a2i@|jzgISxt_2^*^4D#Nam!q|*(skSk$uIhIEM#Qh z-rJ17vF3EdA)#90kHi~ox$Cm!!NjC<*=!jHI#1Mlrgco|A_ zEv3s%Y$qUa2Vbj9B>Nx7`2>gHAFkkuWZ>^-Z|&+{lza3brr}gw?9!YZdbG~xu=jku zu1#WhITmH1V0vXr%8uv^2y>^ER&_6Ba29)yuVHK?#}{HN?)y%SU`KWR!J?l-N|0f^ z>Z=o>WLnTan{+7^4xx~JnGjd$xYTC}tKY+YhWSR#GQ7V~noDlMi-rb~V$}u$TMFA- zi~WCTo6Iy#wZoVV8Vcn#=-IYTp9ui%#USx&T0}BsP3-Jg74YPlh^ks_et7TrT->r`2L|XU z#Ct~{DKS~>9CzwQtOgE8-$(852?Fk|?Co^|?|j-fQ+uZG=RA>%D-c=gU?FZ*;&kj* zj5@o>2~Cj`6Y<=OVD!FObUdHrSRqfv9 ziU^M)Aaf{A`rwSH5tl15c2rm*gz8Jln#9 zm0zJ{g9$a^bBKy{Gbjt3;=l3O`(58%qGqb?@%7xk#KT$2>!i$P;Ode=7CK)TT@ym2e97bunAb8tYclqSZ^k0`)jAfHrrNDF2w?Avj!h>ziCCyf`cW!y6K+e>(`$GvXd7{jH_ zTnxrg=GAWK1Sz%0G<5&d?wyj;$1_PiVjFe=hLHZj25?5^ns4F-T$k;(NLy)gbDkvl zni~4Cih!oeIa5OSQW7x0_)$09eaKcGrAd15{^?@YjoHWj#(qIQ@xwqW1{c|}eIYPL zvv@Jt&=I6~gm&goI1%uJNP@tnyX z^PfitZS2Yqz&PaVbQ7M;N|v|zdr- zSJmG>Wc3sAiO!q!LXJM2W{RHhWW+L8wu*We&wqLiv<6Hi*QN?(z!`=ShTw2WnJ5#% z;${8gwVr!NH=FLPjSXb%8Wq-6J;9TyZmyK!$Rrts;A{>I!p#Ue40o%)*W4y4T{k==9mLgaNe#Q#x27sj96KiT-{N?NNj{hV# zdmUEgVXH&clOq1x-+0E|_dH4bSzwu+(Ma8fWb+*c_T1JsV9%X;-LzIdpdG!IFc~O^`>*I2y~c&>D4=}3;W|1TW2va+ z!c@b+r#8k_0W^CF0HG6l74)g9{IM5w_kv2U_FK)E_Ah;UI789`HuGI1Sq7-)@?t5-mBGgnw*5-C*;VZ0-NKTCr8iYIom)@I(_SM zARwFmM(2+|4W(8VNG)fRR#5hPCBc`Fv8#@>)UVR2MpjjK9`;_TdGImx(*&4PGoA=9?1(o}d5&ODh^db{!Ce7K z${8it=-)@lsvycML+S0V2YHmm0m;GGe*cX5o<1d6Wr6Ba3rf69ShLVf_s}#^MWq{a zfs^&ZPrq+JWeZ4!e?VPv1`D?rs#H_L7atFobPj`skazuCv|)rd(`&>XFA+)PaA=l4 zl+!KmW=-w?m)7^ptqhD?I(=cHjbjO4(ylqV{tvADM-LbnoDB0j(VRGP!5JC2zv#^v zj##zeJ!oaY(pNzQY8Z{mQSkL3V~*`e7goBJ7ADWXYSO~#8XPi-x;Sa#|EcYzGj1lQ zw17t#)Bt@Jk7h98qKkZN)MV{R5jC2!5<-oUMCpyWPr`q4rvw-a3tY>+nTMx$>wo`q zv!}PexS6{H=tPCS4k4cKjo7DeR#X-0noy9eE;FhF_vW1w! z1-c?guf3CT1CqSeO{*13{25~iMyOi8;q0UiOxasc4oGkQWBo#}j&zaM4~U=4GFLS3 z?$D93GpW2;<HFQZE6yRcIn?er>Iz_n7rr7N z6&Nu?OmD$gQAOhsO~z&|SN;cotvHcXCBA5UlV!V3)QQw!nAp5N{N*Q@f1nU^@81@w%InIUisCfVF!DH(HxT37ZAu zdqfg;f<1 zl>xRSFrTPQ2gvl(LILYNIHRlZZh~!330S*9`+!r}Sf;5rx0}A~9=0=1Y<`sFW#sai z1}tW;fezzOum9U?aIIUT;f62CoX3seV{N1H*p5{U-G`E@>zHut%mLmwqJ)65f?_HtJ#@h^FUz!--*qfY5g%jNC_k~9KN z3jmrwDOh#ISb%)%ZC)dKeAN_CvqCgxdw`<)64@f5HI?2>;Aslj>0XgoEme{DIhOJC zz>4fX6IOcxqXk5QC=m-TlAWZ`al2>zAwFZ?`UD$i$ywP2#!p*p{>nFrm-? z{+yP|BTRSWe`0Z}0Q$mY3K{D?42-w5JIzs{9qnW z6oJ>cbrLDHg}fvh!>|6hs={dAKDA+>&JJqJTYHPJF6x1ir5^Q{y0><3H6I_Z#J>0V zduQvLZiLO-K)WQt^0)9;Cm{jF{B|Hg6jRQFo>|_}t|#1o6?5Hb+nS*J&6Q)3^zNM5 z)?Qd(OWpa43<}ZPcFY%GLZJ)#Lrp|_)OHS-+lAjQYyK(QWsUxLWltxaf-SYL2Bf5& zyQCdJ`I@)>taiU6hY33QR?#tY1Yu$9nVUQVNUy9Rq9&@>SA3%|86rS%-ye(CT1Xtw z&{TH&=XBDLkO0=J6r*Ls$p2dS&VgCiJv_4tJxj{TG?x2|Y-9JbToc$Q258O{ObjoB zduc9WCEn`*ngEr`{{Y~#s!;}}!Mz-yv2j}*#R6(5cXsd6M+-L6WGu(1kL7WTa{GV7iO#4852Aq%Bu0gz!7LytrKfgqzcJ#}&2{3TIoL zAS2PhEbGMAd%Rz0X>EereTsLUl1>aCdZua18h9U>)AOeWS630TcT0LJsbJ-&QslJZ zBq0eT?iu2%g#@Rbcn`{{FpnRelITWLTheCJE%*mD4q}r&m!Af9+4&nDzfJ)8SZ@~C zbp-3m0Q2XUl`7{%B{LxwyA_A&Ohc5^BmWgkX!(W|O3zSIc-w-Dc!aaxnDIT@5cu@v zpYO##7E(XZmVAipfoU_fo*YM1EXz{yX#@_u#x>E~sV#F!_qP|Xv^NO0C>1!WKX&d@ zGx?N93`o%z@g*s<5UJkaj@KWO_2n9E`w*{1i*r?sALrZO!h1u|#6KOngOMWJERA4; zbc`+}p$0;eiEe}1wy0>8#}&c^s!d6)TZ!7pOd6b!cTFrGvx9ayvd_uc4L?}mLM^{n zYz_okY2CHuGWiy%bYgxa$UyKg&2&=|UG;akZGJQxLT$+>(=;75Rup$P3w0J*rIHK9~*3bN6y3^3xM6v$u?sMC|SI*vIL}gjndd&T7@^Tfi;H2J~>POTgV*9^o zLI_T21AFCQ2-n;OI6rF6a>{2jhVR?|E#!EOe)Obu=RA8;mbJ99xR5S+Irlv~fUcI@ zlF4-KlGrMdj@S({2clzcdCE>?KgYh)s@!`XykFdw3{h&QXQd+mymHpp>z$7YNrk-r zS^R9@6}x=Dy5i7Sh>%~H8Q-YM3XJ;&0b=kvFg|jP*oN;$wa5RTNOLVPHGbiu z#9iLdb^=qaJ<_^cMb-AuazTOU?!qEcqMRF4%geY$UQTs{(;VMl71&dWTkf*;|Ib(T%wN2mEGJrFY)_t=EGXeW_UkRPy8sGz zFozWYQtca(7MH{^xmR*<|@VM z*h;CcEh zsj4XZ2|!`(<;@~T4LLLT38q$Zy7qBR8ShWjx{fk@isf=E%r`85Tfdj* z@)es#QtvV_2G;;*<0k#@Laf6ApIqcsv#7!UxylgY5u(Y>2PQj(0Y~~#7cMrGsk9AWQF#;r>h66i8IG7nF15YRWW3&3V|+3Us)WWgFE4M z#K!9-7;-nz3;S+KI?DUhl#d?D>L8L`^PRl%$y(Xqf>(1IPQ71(*{VC0d`baz#%{1F zwK1vkqaHm)Lvy}s+N4-7TGp~{V*^aV4{R>({TQv1o<68-`GJ4#hRSR2aaW_<)WLb# z)`U;cJOAzxy;|E7Nz#?AT{yojM(u0V&(eoXF;Y~nGSjwqS8Zig>x1`~>&`N(8b!^f zQvqh-mbaxRrw0#ThffS21=8DQ-h9`S>T{R)jncEmYe4issIL-RN6A~l7xdJ;|JPmF z$&AAwdwtH1?zciZ%h0dsXcopFp(-8aEWl?yaq^f!5j$VHUwvT&9cH+r*5bs{!FlxWs5WY=8S4l6f$hob$)_EMK8vID z;sHiEh>tMPbHs{SB9~9*X-B)0jqk-rd#BWVM|YKc>fV(9P}Np+;mht^xI+L8Cw z?4I`;ZZnx^e{GXQ^iyDpQj0VD=juOZ=9)jB{%NX z-yT3l(Rii7KLaW$3<8+<*d|dSUWwnXCS2t#RS5{2MCiEwBI}Q%W7w?|@+YE6?6Wxb zo{OuiqJAgJ*jM>)e8 z$chn8aUvJ>KqDV-2`Q|uwcj}J$E32jIVOgz5U`DBLPR_u8$^zRV# zGQ7tv*dCqUy3MX4UWlJUU9~ z%8m!_m9wH)X4@Kz{Vw5p)`0gL+_c`sfswXUB06;V;nu`RmI{vtLNT0W(j7f9TL{@KE6GlI&g)dj1klI ze~yL!))(C>7e5^5RhIAwA64N@cGpv5@t1LBi90``rm6O~V(sWZUzr(mHFQUItACN| z3nXcaIBv53oBGX}R=ZM@Nf~+28DX8{FgtAeZ$-S4ROie~TKJ-FKmbsssrLP3;zQCn z@nH%=TVMVS{f=1yC+al{A5Saf<|e-Pax$!^dADn-XnB#*lM%q(tM>K1Xe#XUqnaji z#EGreL@N1>7Hw{dOZp`ygRdMEn%rwQ4y||>Z?)0G)Zz#XOaZ|IFblwVKpE@rAk~7Xkp9oyqGX z`t^WXR5H!Y2M?JN-;Vt6Efw!x$ZZb37??BveDC__;@4b7^i|Q44V7t}@pS!!AI=V( ze9XmdVa8iQmPJx)5D&D(g^{zj$MD;eui_%ESjdWH#n~8#RV1&Es{^qTf7ywP)rv$C z@x2(pIqDAJMQ$w-LLNk49j_4*9(zr8RisYf{E?~kp#JS^!0Q%36H=gLo|LC|#6Dpu z8*Iy`cW{~UkYW4gab?;`*JdKscsc7mZZE0O?=X4%~V3?t~uS7`@WUq#;C6_%Qly!{C*qZCR`) zDQOUmj-FEScOa@(`4((A(krD+0>E^1Frc3v>XddS;Peo?PzgoxGE}JF^HSAF!i!|; zv7QCz_ePaZ;s_6Sn^vfmN1YZ{5lo8V-!5S7RXl(x|EzSYj-+OE*R{O0+dE*>)A99P za^H~U!>v@)@h8F08rR^G(o*wjcG56h?DqGPtVQkQ(ndw9T?&3myteNBBeaWiH1YZ! zyY?eLH%g2~)>s`}lbt$AHsv$N4Uh8^eJ#^o#INf{!^ag+qqmT+&T6{Tt|Cr^o(=iJ zT+}{1UM*u5U(o8OnX=P6~s%kHWu!#iR< z#v=u?f}G+2aC9Ym_7)@WG7%Zdw$BiFZQIO`daEf|@jMT4e`x{=xCJ=i$zT=Xrt%a{8gKvz;|+B;YvnFRVyW|HKVR#_nO}?|8w& z*kp3PxkK!3y+<-gE?hIcQ9RrUp1;JEuRbm=YZbjzn-y}+{vAJn$CoxledF&o|0_}* z7<1~`S8DrQ{J;dDsa`FJKC=&0D&u8F@>fZ_el8`)0Qj?f8oDFJ`9q`(AuB2kyHVwY3AbY&-r8FVq!ycZ*xW=gJuA^Tkf5YR+2> z7OnF6r@x@n*9xoSUo8E@NWz#7PHh@bX&8W}4&M~|*P4%C3RwxyX=PKadhcy&#!8;x zQrC*eJOJaE=VrVm$os?}+>xfmhZJ?%VMtKV2Jg1{c}$@NfF0JkUr+PM$1C7y?N7;R z0+Ff|iI9W(eQ58B>`qXv|HWoes~6514>+18DnlVsZcO&BJ8~!{8(4gK;RUNHjhnsp zBmh#E4ZL`^;t&%hg8AkC#3X*SE0QL=$R-gLSmIj7rw{pEAALPT>-Zkl@{K9MdS`5N z#ZH$^M|X--!1G<@FbJR)5fmNLh#PZpnzbJ^%#8<1L&OFdOcHcqgajc~K>H{GTfLjs z>hgy0N0v6~1&3aK$|~~Bc)zusrAOw2_0wcV_WPE5l%P~XhaI>kvcSLFaw4R$lZZmq zhCEG2HJ<<~shF;e>k1o%q|(oI3pj?os&6^^3%$*80opiq;QF zyK7r_CTjn+6UB^rQ9GnKN~RMW7Z=!?;u1|g&G$15Eb1;lQF=T&#f;p~6fE2CR-W5- zQ`nlMIS1h=t5Q(;VNL=XDaln;gC{%;mrc9Gh06?ZCq$u-8%a!cn_a*6oGFR$szu=U zlkVI|f*t}8H2G=?Q6^x7i3?%n5^l^{2Wj{=#%3+-V}t`4%PS8()}c-c`>vM+ej8d< z*`MoLzD1R|4&32BS#ry?Wm;s8N>A2m9;CU8&Sbe+prUN+MtJus!~6%a z%!~Zsi((dD+&EsX1KK1vK^)I+@@f1og{Eg#OBnGYd}5$}3T92-?~WO&*Y?h&sO&wL zrXDHU$f5!J+>X2GCHoigcAs${F-br+KT=${eNV;LvRrKtG8OFhSEDn*J!GFqdKq={ zAIY6G5pan2f@)r8?vca)?yTE$yy7ZBW(`Y6IM^B!ewltH$8WiShY^0W-Sz6>b&EoB zO-*0>C0Du8EHlcnsx;>h|J3s10Wxkz`+ds$KEGA*FbpDf$Db?gHBwKVyd}s@QD3tIQlsDG_ULuYln0q0_TF;3 zW=!KtY!1U1Z#)=r019PaUD$6p)!tShS}m%!2JJ0<};vsuN4W4N=ujqD~dI4P0{N72w)W~aWtMOIF~Dh~ z!%UuVO8@Frd)3 zLC&|`IX8hH9m+dOBIzsboVdOuc%^I~ICJ+n=PE)J5o@;m~MuYa0W~Xx3#}ICEp~MpN5$yjjV0jou8*Nxqaj_ zZTv|e`R^3-Dy8H=UvEJGw5OwB>z3*kz9+YQq zg9fzJ-hmo6o+B-QGn9y1xaFju)0;rUr+kP-3p&zSA>FT;a9H+YxshC%!2v6Bk)=`+ ziQb1!Rxna4L%y|hRoQ5vH2e1>-tbDv#?eu<1!|P(b68XOV~D)JYu;zwDQMntew1Dc zx#_F<;>Rq=YAr1&r|WU!7fYKSnpE`tiF8x7dxsiXW)$tcQhNd_(pwiSahcn{I;=FL~ z7<&D~D2O*PQ3!v?>gcqNVE=E`V1^;ahsS(M6SNNkr|ctoVe}5v$UMReHxrAvZ;N(= z(-t%igNFh>W-PB2zHcP9XtX}E7dcRfZ2A1Xc98CAW+XuP!53N`>tIVIfingwOeNRP z4~i0OK`Yl}&}`n>$gyvh-kQS**g|5$zW9*%-PwKlHe=QM<4#s2PaJj$5d|pt`Ek(? zlxM1&49Uh?NXLTxmJSxk^0pFkT(0gG+!uVFL-4$ZVp_g%9kv0s&j!8T&#tDtp_l`R zGZE3)1Jg$)Zng24yXKePb$Feo>8#cv2q7M)k-C3AR;bkLq`Yz31RaEIYu!W6dkUiw zZrfhM1koWp?f3NBcOVp!x~{8m1}Zq{IWNuDgN*SuoTaw`A`;Tiwlei1)&A=4+955X z$M)~Hs;!5W*p|-v9aU|sNs&LltHxlbdUpk5TzbMjRk$V;toB!kCuSog=tae0D{m|Y zL=7J#ji0Bf;iD10m5FS{U;NvFSne868&!z$GZpW3gw5d4kgf7BiLS%VLz$zRILp_63HUJ0bqWHZFFT z0F=EC)5h(6RhHu`OAPe~b5f(>3aCJ4F}3q}N1UrJUg%zZnSRVFB-P1{1b+G6NttX3 zhCP3X#@RqQ1wn}T!w6~gb+(3LxP@%Is@-d`?HA^q)0Ixr?EpbRcwd5lK6apF!U6n9dhgSaY?PqZj3qIG{A)-l4( z4X|-5iQc8*7}WgORimKlVg`9=qaby}pe1Kcnta}kK&oP0(hCv1;Cc6ui2V#}dx=e? zO{(r%t)?NRwzrBis}^o05+8aZ{QpwH0W7wpHf=}}{9zv7GDO;)4A#~5#PHkn>_&malhiuTVioPag z!N?YUhzsG?;PceG*40lfddz!{T8e*0eKjr8!ocQ*TxjCMscXSS4kszoV5&R~CzvwlHyRZ97!>0%yi^lY*yKfmjs>UOyrUX!0L z^LB_Dv}d2)ogg#}pGH6HoN0$=cOgTVq#^Ie?jXq+43>K7SvRj)FOc`M#a{|rAmt!9 zUHUuYbfQ_kizH%W&^v?LRDSeVeiH>z=?*!KjgYcM{q3%s%^7-IfZcSix9EXHlHGXo zZ}6@!#CS}XD}32qh0yV1AEGJxT2PrkiP~F|n}ET^aFtXkzX9*pAGxXJAl%I&OhQI~!FQW7a|o5c-Un$Q{A6wph1J`5+^NiSkif&AA;wkj{;Rd9Qc z+WEq2Z9UzxQ|k{vANmk|?**K*olVgv4bI6ip__nolKl&|#xl)2vRLd5%7Y})e=ENj zuwGmey}RafJcu-x)sF?e0??9a~>YxP&~G;+NW3Sz;N9YurFrxBih+K*1VrxioHiXi%!h7%EDu$cqml(ka^ZIIUyw> zI15t+{D+P6&MM4$e|@)8&MyB9o*@HANY~MM8aK9DZ}<>4quFm%UTZvd`v*2{NU-QK zo5b&lX5c7!7wLM^GZB#w4MkEG@xaf{dvtHbG|~4&Jl*es&gkFjgga9nKKg;ZR3`KU!wxKvgu6PermTC4jhTojPmR^tB{&NRuG4>)uji7C~rcD(;_I7(; zcEx4dDKVK+w$>9D9KKse#8 z=5ng+^P0U35lFO+ix}bU$bu>z&lb*`;RQVu^j=>)o?IUo?m=lo{3{*`xZ^?DV-?QN zWqmjKTp7Qv;4_Ek;QbVc>OO4yKWOq)@+zALY_k#RP|YedI3jwsV)ou*W(!h+h+7GK z1>;jQ$G=KmCsTkKvr1c=#?CB3W>*+wiR-*9t7@68D7Zvou6TaN`U~S+X?<{v=}*o1iT&InfKWOaY;36kx5BaLHT$-zHO31h z){p4JO|HO{jxKuhJHiHJ?~Oom0E~rW(>hTzFEhZ_Xed~Yr${m9_k_8!m-^KWgw`WTOVx`43-NYa&3zV1&ZZ{CmS2$Mr z9gpzA!|Tp-YZ=rl*I$OGG@HlU{M`N#LSF(QZgTzNkl@rTzN3TQur{*j;0mx;CRwGr z9%g<=G>76(>B&1SYFU+$(8FF_0H}me&&TLm&jPB?~ z6WDr>gwMKOO!!v8`LB|LVJA?pp4?TeYMkhZ3Sd)Mp&KV;Sk^tW)&vttGk|n^v`)7* zE55RMc&RMdCI~@{M{j5}UC# z;iJpefS1;qWx3Fq80<8|mvb$;fO9D(8mu|F07e2tt8xYxu~#nXLM4K<)Hv7Cw`sgsISYh_P)zT0zASpdrH+T+iyyl z1&|&;;k1~Rb??nrjpYd|-3-A@;;SN%?i zMf-<-L};gYV|w)}^Iu(W@ZJVJ2u@Ej{Zv`@W#_67=E=pr$BK@vaKD)!y0K*_^QD-B z;CEhG$6PYo_#dSYISvqHe93*+zU4QbGilw-t{DMPIRU$H6uPN$koJvsxVRCLy>Fc7 z$)=(%=~l8*$(`9%N)47@5KMS}Y2wv%&ga@4qH=yQ6y^cMe+MQ-I-3X<1_-gMf3rUJ z4EvY~fh4*tBg~d!RX-S5WJBm#Yj`g-;4EoYp7l)pr32|U2X+YER0e{2Nm3Yz5KIyl znjDk)*=uV`5?{*T_*puk<~ms>WALnEuF;v)nVBR2jCwiP=3UA0+d2*riv}ds(1IB| zKuD(2u3ZXtz4!*0#}iCf@{oImWYJ~Rc>_9m58m6YOpw~{Fu1fXe9&cA+?lkelTl78 zIbHSLT08w~gqG zYXZh=T+fxrGgRfdmzXX!2!-9fB1>^2HYCJIGeL7c-Qo`>(C`n z;5^8_BmdIaQxXSfHGb> zi7npwcvscKwddta<}fjUxLcm%YO&uW;~FmD&13Npfb9Ub*pOm`wGk}ew7+?zYdzE; z#cg3qrv8^G!{SM?n}M(-A+qh?cgo&jR$&0n@K&P} zT;|P_6X!*g&pbnH-NVwm_ik`JUqNLHm!oe;$gpx1xk2AN4!ioBbSxgAC!TqJMw8TP zj7=MP9B?tO@;Q6Zzi5(${PPs&rV{LkXPPOY(G>2@@&i>Fa5%sDg`yRl%uoki=|~@M z@iMx*NhY-RSE!H1c>X(+iM!2ftsMs4u^*Rtxq;tf9z7vd04%ysT<@88A`5842Gv(A1GAxbXRpK_|kaKw%+b5!hvfvXLLEKXv8BRjk!sjT;rcqYuZ$8MnbZCwu z^wg4(k#hMOq>&2WEaD1 zHEW=+EpP@=M5myvyPo(q(8tCnnd@5Ism-~I+Th85fDt`U-UzFoh>^tD=zAH*jU zwt_S(*GHb(8fWn_$(LSgviY>DfY8p_VQiO;$^FR|+F@DM80(hIW0Llx``qp1cCQfQ2t=%hw7Ho~I##jY#bX^j zreDJsw*o)lqJZxY580|d6=*x>Oo&GzMy$LGD5~oyqRKzl7c>C?3VL+UHLe!KqU~^E zZ}=}CZv;$^mwxPTcR>&EHVJBFJb=4TZNIi(SC(6}w3ZkVpoxm8hfvs}ih+5+|Fd&3 zi24zYXOO|~oU#So&%IwvmoGKZ)5ZD*dAfAI%ThuD`2U`IKlWZ}Y-FUK?a#?TxIcIp zcnSTk;_RyKb_>?(`#E{)?*}^V>ptrGS#96pYDKJzCN9g>JE}po?LT;yqJ@b@Xv(fh zM7VsDN#v&S(HBc-ICLdJ4k%59*BXazJ!#f_uiMJPlN@AP%NV#yA3SLax8cq``$(zz z4*L<6nGL#2Fo(|of-grOW4xQ;TpFKJjZAjx+K*7E3sh`!0&J0kmOPk!9zJ;%Qes)8 z25W}lJS5zRcBd8^Kc^|^VXQi8`j#ddNYtUC_uikEDt2uoePD1DhcV5{ISt(oAPzW< zI+;Q;++$g#EY+dY`AmTaqAt+>IJa+@)~MU#*K66DhiNkQB1?y9ao9(ZT?2@A!>vCQ z8Rt(zc$)mmd9GYs7uPKl>_x1@k?wTvA6STC0s$O+p5@m=p0_$^UG-Z#o=1EeVW{qp zcAt*WQ^N>T!4A(o0>nICNJK8_T_>cUjRfkRn$jOtZX(-!pg`M2AgJ^t(0)1vhrUj#;oS3-0& z9&r~6;td}?_urq$!YHC9`NGC%!tl`4Vt1aXDC|3gUOZXCzHNlgX}pB2R>GeP2McG_ z)0i~e>1L#21kFi7FzU&k)lQ(sc<(GJT0w}fMLcZX&bZgBY^XKpKpwn+F&W&GZp-#6 z;`#3bLBU}9<8C=Pr>2v*M8&(GQrbMi^c8)G{oLm<*pzK$nmS9cn>LVlf}H+=?Siqu zPi;s!K5$ZZap2n;9;Zn(H_Y17+8($$~YUjCZ<#%WEPRH>2ek25!5_ zth5iaOU+SKl0tg0W`bv&{EqeB-LnLcTsuq@HGKW&*VddW?7Z;sBOVg6- z*a-MWBxe{;jVIGYlj>lsH=c^4;ebVE+MmUzijRs7CWR7opaTEWEx23&CelQy7ZbB9 zFC@r)oo6yMYfvH4bO@VBQCwl8I^SP>5= zD%+vFvRy+Y-QG*&>D;&=MIq3UFr+-OS%8B$0%=P)-mCcRV^8a+B?Da|t!nr6a~0cU(ACTkvy5HwIT0$d@=&C!Bs9ASGLRWevLQ1 zKgCb#E%m#Ej(wK@{F!T5#>d=!b5@+fkmqy=l)=7Pi>c>bn6sG#lNC&Eb*G{9n|S1X zeB5qrsjpHEIB_jaO4?Z{b{2E}NqY+${%G=f6P*M})amiLV0Vc6VYYtd+_^gXAz13& zv6snaH!yfH*VBBq;A@@j29;WzAsF*~zqEw5o{YK;CtPROyT#EFmEK~ykj1=-f8W2_ z&mZ|{v2gJ3`)AsO&fa42cBpFgU0F5J6=cd^F46A{W6|xKl zZ_b`p=;%`lm(KV9#GZ-6{-eJV`fl0q_6DV+4C-iMw5)&fD#oDVGru2~9UdZ}vZ+>~NH9P#3brhPX8+7Q`@tL>PjQVz$ zB#_F5K3E5mRu9194QyU$ybUl#s+d~K6d|9Cv3Z0k04!waFd-yF`M2^g-8B8Qoc$ly z(0|>U0mP87i8W<50NY2{hWk%Ry&tY)$p3+PCPKrqbaoMD(5ErrrRKuKtRH1Sap(|a zFXFgDd>nB>&++LCeUO*R?%#(syj%C7u5gogG1|qVkrDvmHuJ)|^fwT3 zh|lCWe_$+IV6KHr1j@AcHC${r*wPaEg+5nSq)V0X4MSr&!K`|{iou3P0us#KiXtkT z!@(S1#P)Z7ew>@+k3>-THp05qkBa5uT!_WPs#pPwPh&l0#+?nT&-sawPCuRF# zr%H#A2XAiz@ZLhX5 zb@)@hf%uo0$U=v2zA9T4Y-+t50z2Ic7>~SVOanP!<|nF!Ja)6p@jQ_VTD@`3y5lXl zisnuiJ?8=6j!YF7IUVdMZZ6uqJG|9yTE=@1lb2u?#6)Zn%+?aNk9pHX55et-%#rZH zGRy*!Z(;FpIE`4akss27*U9c9m&C^Xs_4TDM0JAlt>V=vQ%TOl9a$3llU5)e^)Sz+NQSD3Rdoqn`4b1pi&%Vt9bIz(oRcL@Fd%sZ>jdF{KL5M?dqs(>F>lDVA&DUR^VO zvk%PaC>_OBl6LmW$Z^t4)-E4?!6{A)>e@&udFNRQ9n(Nw2^$>Gn1Kut6BfbIiMMP6| z8~Q%P*-?4%-*U!sTO!NS z#p|AjZ~m-awvTj|D(N786#DwmCftXH;`HE=(8l za&`NVD3X{dyB{~USD8`s&`Ab*h*=FQ{BTzT_McV9<0}gD636T z>cFd-?iyt8tM$KYmaJeV^j~n5@sk|}plZ@MYyh&v5D6#dca?1A<&%xaB>moB&=-K>r zjC{WM{lero4IE!=N|4!PRz)#B<)H^i1FC^ zCj>u!X;o$DYH)@5KGBzyq`e^w8ql=wsFyDlOeY6=SDu1~xNeB!9l1z&WUQg$N{9lX z(=`*qmQWsH@;_I9W<3*^cKITBe`aXmJb%9-gzo%~DU-n$U&siHl5=mBE1>{AZg9Qo z&h^??v}Sz#g4iAhLI5`{x)eN2QWHm;(w9eLZ#gHfb5&6ivyYgS&(PuZ-F=4C{~+p81AaELy9()O5cUM@?x1eP_eP763cf4B?%INs_J##EW} z*%&z1FQkmO;M#$1J%}ZqvL_%(IllcYIAurITY=h& z5Ghh6Q3ytG;W@(SCp_nFQ~6c;6>jW%MN+m33dmj)kpJta>n<_d9&^lX+=ZE+?n)ac zIF%0#ssv%w8zaHTV+P2-JqW)}sFTFA5Vu?HZLTLfimV7g9phH(w6_g=Qz*42b2XIb z$GIGJhQTT#vIQ6PQPbihLGJ?RcOnUVoTzMg^*haKTpl(6*0-=`28mIwOm_cz%u-Od zXu$=rq5<*6xs|GJMr31&-{;6r(JxIk7suLY7hQh6>17^xWLT{weqVjhA2N@*xpUtZ zVp{OUUI#rcI`U8#%6gJg$F|*{df%5((*iu9s_fDwqf`?4 zD?KGg+||1mkDI8%?jQhu)A@&*+iXVmi|bR+%4u%pA2{Ci;}#F4o-Oh^4zXNwlp-Us zAB9&~a<28U$W&obi$K0NBfOQ7g01Vmo%9Z&C+c(E3%qH_srXp|R(~|%bP)O7|Gv`9 zefko#^|MDNli~ziTO8j71fz4$uVYtxrwl+L_4V}~bXCLarSsM^SKnn57VqIS(FybFEye)c}+?Du=m_a8rgz;(0MTyu>%=NRKc zce_1Ajkl?PU$_J5rBUgl{OZbar-`l5`S#7?v&P0-gbEs3nrBg9xE!*4dV>cBgZgGc z-E7JkE4YrtPWX*9BndfzW1?u?swSWJ@RsDm+gTp=AbzJb|Be`g&?>7X&L9Rev?o3! zjqTdG1_&mmJnJuGuc?)4e{dF`);z;Q5p)6N9IQvXjXRlR$y+$_=k+bZj7eEeBBzXI z8OqUETAO_>WUJNaYTAND{VJ24BAFgm2H4|}K)0jH7I*Bzsh2i5^HKqCc4ducw9e717a zlkPd|h)mCYdk^_o)@KpmbvB8p=HFM34!7ZBytEkwsEj=}d4^RfH z>1JwfkSY{upoLS=%=e}(yC?w+&fmzcE_@W%^hm&>_^@?`Q(b#swGi2UZ zVD3p7i}i#Q4Di1I(BuI!1rQ*Q{0n~aOb!{}!?0a&seQ_BVinRreF&Oihu_bW&a}H# zm11hZ1j$W|iHH2-UDHh(<38i0qa5>_x1d<$9|=)&sg+wxImBMjj}LGd-dkdkJe_qt)TQzULXFUEq6)Gc8Nl6!W)0nyvyF7Vy+ZDWUFcI6cG{ILZunqy?kefj#H#0_Z z287QN*ruJs=prImb7^QPbw`~cCI?g*_H7IUA53G_(IX5vM=d9@j&l3E7JhhO?C!x= zYea&-y?AvE#?Na-nv_u=L4!l^&eIKfd2}PXXa$L%%f_T9EFwigEv?dC1R9?qkE9fe{x2*x3xp3Js+M65n3vifyA*3bf7Hk~D)N1s z&PcCop=eCaOSBsSbks?|3Tti%yh~`r8Ke7-D37Km_PC%qFF(|bD-*nb>lG9{DrQc# zV{d+W-z(7GFGl{SK-^}Wc>u?NJOCj(V2d%)M(@8oNl$+H8+ceEt^_P&R;|E3;fuxk zC|OKJgJJZ9o~B`X8mCiisPYj1nqur5peowxb7U(-W0iBN6IUn@x#K(Z^1U>m4gJ6; znprIgH6i?x&Dn1{PnG?Ca7pdW_XiP2E)%Beo+nP7QT<=QDa;6d9hC)}yM3{5Cq_&f zF)0V7yWA0$d-{P=&t8PJ-;wEoKlHH^l(?e^|4L3@oMXsG5QON|8f9dZI0+j}yUPn9 z@ZZ;XBjcGRQ;C7nj{Ok_arS{yrcl3T_}uuKh@HVPmsCL+I*;!&G>dM;W-Z6A$sLBW zHR*dU1V4X3!O8QNMRpQ1ajnA;v2H{kR|C(tY$~A!-_km7LZ%Y~_d4zQF8Fk(_Ro8bpz~0+Opo{y#ZzL@9g;O*Gtsq$UuL-jH65k1m#JxC|cC70frl)2l>C;ajDZH zYg7KqvaUBYf5U<2WbD{@>C<`hil7G(LaAuqCdvY9qy66N3Y~UO9$Hp?0MH`j*&8`m zNDY|w;)Is^*0a|8Ay*Q@m`5Fm1XMDOiR{#3fL;w|0Xy$TLO#o{s@VETq@nJt5`XG4Gn@+f#4yN}=1AOwnD%%q8tbd2 zP2_70fu_w0L_GcNdW3-m9htoZIw5A7zDqbXr6l(vAqpD)M%c8^aoJ}^DeRHPt!pmI zd@(}TWjQ)}+|pj`I&(j2=&iDnqa8ud;&A0VB8#2n4}A8mouHfQxYI6y)sydHol_(w zNzTG~vRwL>d4H1GGc&kCcBwTk5$Grf>+qRP;LSnUMY{AGIZLr0#AXVvAuMBGtiaFy zqguq!)UUR0E3um3)WcXUne(rfzAA|xx7egCfxahHMq}8O?*T}iZZmDO zn!{T|VhHLvG18|nG9kH%45w!KN#fc-E@%SzY*Y4GG~{H$>VW$BSLOk4Bh4XWKK~6m`D6o ztnS+=mMI($KPANppH^NdJPQ-_1yhF#BG%k|+EdX^#D>@CIE2ZKt&(tj%lavOuW9s1xTKVI8_+$N_ z{Z9 zQ*&j47?FHk%+&O%0sIr4TykFY?Z98NYuL+RE;7y5pk(e=Y*;u-?7!CEb}ddhxO2rn zY!W_tTBM?O5_V7nvJ@J?y-n;q3!W|Jqy4dPIMgX!JYbVwh1VHn+CLWkBAu4ALai6% zbLm>soE@X zw3-=0jupGcTF6Ba9g`MAz1@yuu>`RIpIHiVI`_tOC0;~MU%l4-iWy82lJ5E}9 z{*HW9UAu`t{4T6o<((8oKe70Ft`m2Y9^cnX2{IYFfylkq&ra6+Gj8#}jG{yclil9Y zLxi~sf>~=K499(JZ92#+(?@nTcbw-nC6eiREZX#t4*-Tu;EjF63a<@k;A*fvKgZqP zpM+1_GVW#1PuAJHoc}hCYbjJtB*?`=3wc$jAeWx}^!(#;fP=|emJ&S-5-R_vb)#)@ zM$pm?J#MV%sy5R*cW0$e&C+|v9=6L9R4e}7(+FWwl$6_V|M5=w;Ak)sfAEARoZtoNRfD`5 zV;4q<#Z;bN`(ahhQHy7NZ160C8Xpc3mt|Mlnf&5(P=Oj(31%n(yDj}Vt_7}KD>b@& z16%!%DTdt%z9KS-GgJRo|6si%IuWnQMRb7$R@}b zxN4RTZW66#18elAGR?Rhdn$9i&q;}6BCC(0xn}3iB=!pz5+~|WFHbgo(^M_n_He7Y zsA1(B@N_Tozv4d=z+0jo5E;zDXyC0vNx~fv9b91zobtS{mm~p+6+YO`K9GUaYQ}2P zc>9cK+`26*nB3LTkrNkX>;0|L9KMhl?jz=)j)q3My|23gYty&*GZm#EyB^!&@L z*L?By2jQ0}h3_~|prAQF_0g2a_rer}4%vq|ti&jg?7Z~e@%Un9B){Yl-e2iPWcg;| z%7c%RGk7HhNJ$4JZrTkFIM!h*dSm@M-#FLt`qa3we(*Yrj_^EWi_f@Y9cMH7`J{4Z z2!5_=JKSABYshA)Hb1OXJpQ#6f{7QiXZ^D3h0R}S)lxuVASE$;5l+C~g)e9@f5*pQ zulsO#o+)Ct7dPG7TJqwD<|km{rXAU2<){x~NoyY!d<%2TuO~?R11s_Z89q}g9vcdA zm*mC|S0>vSRu%6B41%3q{o8y(+d(7Ih`{eY>3>mV$)9}z6WqPhubrbd6%>KzvKd3I zt*yG3UNLYhyaUh@k6Z|Pl!3aFna<|kf;Y`7+ihbtF3jP&F75QY%f?C=tgxg=y?E?7 zS?G&eFMLw#-S8fQ<3s0^EUe^?(aE6K`CQ+q=w+HqkWPMKcx?wa?UTHHT&sj1mGy}< z%}zsKU^v0@@mb4!J|x~^v$>fY&WYHud(N9X_w4}fj9f4_9Z8mTTFEZA)(e(LLdR{s z-c&KKSWdU*P0cO3H69!A#svtc!MF2g?Xd^~5Y7h*+kgRaO|}`^oCRpJ?2UFp9szFn z*N|c*ou4_ZC+Y<@>lLFdp)Xd&uUW5~N*D@_*drg72v>64%5fPo`m^oMx|03k?Dx{- zUF{1Uwz(u;=gojOHyoAeatf%(vP0RZ7=v!kC=V69|KzDDnTI_Xwq*KF!EOSs;rUG3 z%6Z6Hv+T-m;CauMy^o%I9S2Bc_mtRRdw?5T@hEAg5!>Gn!>u2kw$~-~aP-NK^HW;n z!C`MyBa)b}ao&u*im?Bx`nswYz2bCU;#!T|V|0(gKMSUE@$TfaLiXLm;t)$0Q@Bu` z&FH4@5^h8wS^=S#0z#6{vhZSwkK0@+W6Uquc#2Hl_-1r4hdGGiVwl2X#TviP5q)gC z5utg%n{)NUFG11Tv_bbtiEuR!LKzrA;&h5$)njpM8q`X&-pu~<#zYs5N$*pB>TvUr zaCcN<2DwhH5i8Zwt*~u(f%4l*)!U?hh2^h)pWAN`ZTF3?CdwvrR(#28`Gp@aKGQ>WCpQ?d6O2eiG2gxOLX;oEQP^wz$Aa=oHO!SoGdt$8DD>iBuc%R+b5Yv#*aao$|C z&)IV}^#jTAWHw^m(tY`#9>b-efmZJKN1)XCf-PtJU4ztMY`1OXN@=dn;lY&$?tPb^ zXOgAUVq|CTaky}!qsE4dJ0gl!6O4Sowf!7CIWl2Z;_=nHgc8itYl=x|o=ws+Y=I3? za(Vh~zEO>X=A`!->!}MJZJ0P*a4hY4yR9pHT4%i^xoblyUSVxW*MWfZG`k>VVmYAZ zW@?n|9ttCa;0dz#y8#7F9SCwA5L@1~GzxMUo^&xuhF5C(L*B9zl-vAwF#)a_Mlslf z*|b4W>v4l!K3Cdz$a&sg##O{km+)xn40--^ zf9S^;Kq+kBMl^s%n?2rAyQr&GmMC$1lk>|o=`Flf-Q)ygu6a~NvThxCS!6!d4V;}B z`;?{mg>Le09K5>~FT3V8j@RSbl{Xjr3(oY&nW?T6rSyk>Z+8GWL^8 zg8Pw>8Xbj`l$Vk!*dtQn(kRV)3$*%_Dqjx1=3oqbYXs*29XF*AP@tTHiv! zk6!7eHF6M@jF{@tj*yKn=g3i9RU>Z>2wDy`!`NkNT65zMXSMU*xt{KTUUV=qc<-uw z1!T`ZI3sAZ%js9r;`>)3tgF~#Kb-=g3WN@i!h2tJ(9?9R%ozv%-;+t+W~%gwG>87w zmo*D5$Gs2RyaI>E+>v2rqpU*eN1o@8G~KbIO&lHQr;)Rwgu>OkuAiQ<<*6d#vqyGM zxv_9|Oq@Lk3wrAxCWAJrfnh!AjMG?{{T`YD=PuPHH^w1P7mM{~QQ6a_cdR}0>6Qs(CY0|;=N?R_vdw-o-d6|ww;6dBpqTf3UsfsQ+IE z61|8v^2Z_sY*!*FA6yxY)xMeDMM*L&51GXj{CspFcrTnCgXL$dPw7q)u!_+!d3hkL zxzaqBx*kzJ?3Ht&N1du!d2<#c#Q!BX2vC^HTXLgXop}b3Mgde=lAJD0FLqKWx7^cHh%|gj2Db5)kBD4Id-D&%C4YNA zD{y#cz~L1Fppgqur;tHq1p32Za{pI5sRP0>P+bGLmIOL(qR=SbU_P*s0GzFE$i6Pz zd=Sy}ppH}_UtME$Q@FA_$Ddh-ua7_P+nxV)Bps@~TT`OGByA`5QIHDF$>%_jMqXXwS>i9|4!WKB^O?(M6=^ zY+>x*92^twKz;CcDgP1d2L}A4k)KKe7K}4I-8Q;tjRlVz{RQufhEbqg!^^NRW;(u@ zTpAS^;jC!jnIH%}8M1HN`XL)FTO@jGZLuFEt{HPi-ma)z7uDTnX$ho~(A5eqd;>a~ zhS^bA^(uGl^@SiVT&qr;gWH=!VR;^MbzK~uRw}d`K~KawqnJyVS?7J=$GG}IXD%`p zOKZ?P&ZiIJlAHDU)9BY#Zhp==QlKhEyW_M72Ll=iXwx}?d+`>uobDH$1uVKyEEX*d zLK6zl0;rXout?j^t1ax>lEqoj`bXIkLZ;$@NQ*WiUnQEg?0`;%S#-dX*YpJk^wt)v^!3s#QmqqMX`x-F zXyrW7ezaa}h}R|TED3xnW-MTV4v=BrN<$8SJ+=_|!2`gCtOqd` z3sxs6c^@|cZ(I^-!$q{f4iA%52K7j}cIgFNAS$ono;xg8PO}gQ+71(C7sEL3G#`OM zG8!O_NbE}d#ST+>d2L!1@q*@lcEYGX<$L*LYsw@Bu4R1Pi?OX{X9K3Ma0xcz2>tv* zGTo6-k8`NM`zZfw z#q_EMP{iw%dwXd}dZMNSUC6f2I#+}A3s1H1E4lFXNv>BoPzJLv(WaP?KaXn-J&Ri% z`AgTu8Hx3tl zpzNx;>Lk2UW!ht~uk!{A?02t5+3xJKSDM0J`Sz5X(O8olzu45bCr8>~Flgjpd+HSn zjhcZcSX!pP@FFNd?yr0=aOGeADLHbw5xF2P64VEphVy3Ap*cYp2Mw*Yk5A_F?Vv6) zWZ$PU&Ev{CTs(x9hqq#iwiM4bcAIT+U*xg0tUF~GFSfyi-}Ngt**1)qHG>|(!@pJ= zz5jK=BY#eTFV^mQ)2HhXCXGIMcO;}DlnPNyiPKzPx^45$N2JU)&-W;=oU*~qo)Aq` z-eAOA^F2Is7+AInUb}t&h7bnnn8z2(*h)T_RN^W#a*$WGD;>*6+)8MLgP;>tBipP& z{l9?4wv~V*S3Gtkr(}*nqdU?hzrL_gJ=eN=N@>0i98`Ssm7X{?AO+V*3Ed*=9_t-_ z#`Uwc^Hi^qn)5a2MqHE3g80J~cqf`s!v!p^#va;^t;aOaJmP<@NqKiK8H)DCGwWLp z@F64aSlpxo!Y1J&DtLp}GAW0W%Rj2O5$C4->-lx zdZXQIMPU&Ky?#6(D*4TEORre)AjwX%Tl$U5&`TXxd=(uwi>s}7j5~@?=qd{ooUh%I zL#Mq3u)TWl(`CnzMp&r#zT_d?JAK$M2UOQMS<)h{qtXh}zt*;(;eGdR8?9#)e3`APqd_)!C#DY1Qq=eT-e zjiCTEf+S>9CkHTY2*u$^yX-0qh;?QO&8p;Rxi-FBbAjj<=JEZlDx3HSwjuL90Y9Yp zSW_5AboKcdL;-tsZu5E8ZP}d;l|{)-;*8jy_%~hQOIwn;OBgW z2Ab5_4ZIUXbJ_8`3`W+Yd&5LWTJQ#1nj3Qo6sSG%Jve(DCqKY_;|#1Y$71=&Dh)HKK%VcE-)|^@KMPp@KnMd7UR-GF5ZDg7zy3m#k z+c}cr8l4E*Imri*F}yY{^*%vE`ekap`)cyLyc~j`d>8+)hnY6YEQaAUE-QS=yX>>` z@^;4mA)NN7mPXd!N?m&Cm_LfVIoSBJW;1LPyr&dGXzcCCgW)@O+Yk_AggR)n6^hzS zG$l4hA5j}sR+TgNnvkU2h2L=UlnB@H9LasdMIoO4hIM@{!Zzgv%+98F#2XE!WSIzP zVVH4OP0i!Eym1QgxU{k}RZuKE1QYKG@AOo$fk*D#V_E_8c{X-#Z?S4jQxqI-y{0*A zOr*1u(+5mqv0^7B(Us_crGy}v2GHs&HK3QiTpCticuFGU$r_T~t%4()=em&2j-M5! zrvXPHZghOa*`NE8r`7#0jxRannH^}m2Gd}&7SR}6QIQGRND5d6D)Do>Pb8n1@J$HOb)Fyq`?g3&`#Nti(ue&*FTjD{rb7E8{B9c z3WWzxT|G`Esr(;y!dri6l4lER&nR10A%Brv8FN(x$|F2Ijy0gT zuE9p6^nsx@=gO{I0igt1>Nl78krf|Fw=V{x;?xdcy{VX3d%y0fL8?SG7P4W92rSgf zejbJDJ+0|z6Y>wQwXPsu-H3mM1h2gat#w54(-T<}_OzAKZ#oej&Sbvaqg5!L2! zA@`)*tT$ay^B`ygf?B1db@v=8urHF4a$rMZWgg;>=RSHK>Yl$*9Rp>fk*-dPiW7#a zuTUCj!L?yhNv3@=Q3lBBcmEUgk2xYEn8UCOnz~TT(xKCN$w?j_ls2_=+y{Aw-bcdJ zYn$HeKa@!YBE?vK>h|eyFKz|LO*MYI?T~fV;hHn)`MG3z7x$2^sK4F%Q{}&J@Q}TT zA?fA9?e`zt6*iRpk5j|hel!?q>rDEm-&k{Rc_9p7kM%}5>|QJ=uCwvi8Zf3x{0l#S zN>cCqYML3=UHZ)9`a$Y!#dN9l?DR3T(LpdI`p)ZiVvZB55)J1YEBm%T#18R(gTMlH zUORMR=_&aLIGJ*hPT#6Ems!oJP@4aSl`BV|u(lZhdXE|qk^6S9ivAa5NzAjrw43}n zL|T$Lcxsy zA%~Lr$%AjGo?r9M@jklVWM!h&7W&WLRpZr{gO?n2as%4SB zubRU;+`fpkAKdk=Ob^Z!HIVgIt<9K5>do#R)Nk)7Lm{xA+naB->1-`G(~TNX$=siI zY$IKq-?#q1;qfd~gyQ_hhEIGzZBWhG*OEM&GF19Q9)H#YL4|oHVoo&Tiqf;%AaBd? zdU~A!RGs+~&J-&0pF>5jeuNg1(_mEEQW&4^hTfJ;=_-uhs+D(>xcF1-wa1Q-PV8)k zrc8NZ+A%+9uB+A8sAg8H^E{H=rGyAEa&w2I`iaxYT-1s5g%Ib&PAT<4@)l8_~t%zat~ z%o!iouh1{TGpy;dd?;IdaJ##Lu12%ey6F6Qb^24eOCbn?XwARc*&w1s$F~LW#{JXJ zPG~A1F^GqLbJL3Wa|KN9GB?(%>|9#tV3;q0)@l*NFde`Be1`METEZ(Yhp&#)gSr7W zpCfn#_|oOK(pcTsUnd_bX=qd7pE>t+Wi}0STaJ)sWjmWh!Y>WHlgT-Fj&9O`#d!p7 z=qo?aXcOfDlfrqKk67LwmqfVN(U_4qCM`7aVZ~X8$84wa7o{A)GQuWIj z&!F~ttpnM}!*%cZg3c!HnKBA18lrC_TJ<&QmZCmw- zztUfY#%JVD`SA7gtIV=IF-U*+E6CXiPIvReIu*%c@>T+Cu?Atyn;U{%;|hB47yd!A z3}Jeu4Wv#DJd;E#%I3P=N@zE^{k~_dO#5ev9|(RRQBFNiM` zvpv;Z;Z&<}$GMKO>^Q)I8T=|mZ|d}U$ov;(gKEQ0-u+by9wG6k1t!z59MJMk?erZp zw;B7ODXP@xVNCafyWy?0tV>$h$EoeTI|7QsmK#qDh!uQ>=`QdwyAt$FKU6}eHc)Jx zodblmo5cQ5)L$OoG)~Ttt)RliZkHil)y2Nls7=B0-kv+5Ji~lna`G;nEw*!8G_!^< zbe|2pE3EeebNf|>iiE#oGn%b)>@9G6hQRTt891h&&XyIeTf^GA?@eP=kU zlox;Cr_s5U+d|2)O#{2x*9xvquB4Z^uV7N9_p${ypG6jlZ$6S`%@p2Ev=?b1u5=!1 z$=GIHn#<5uQyY5luHQIuq?DpuYDsi$9+qZ=d&s0PveTmv_Hn^Ys@y*s=I&ZSujAM< zjb7m&yRo`@wHcd%{QWPzL`9JH`MgII`l*C?w=cn~q~3VXMQEc2>58ib9D38Sf!sxA zV0;E=T7S;(XCtW7N6!L%N$5GE$0;HgBsJx4>oUcD z%n&n4%h(|-H^|N1N~P=;o>-5C?w8XwzxUzYP8da;MDX^H`1#M*HVCah%i>nbk_)s{ zwzzZR<)8tj3feTpbdhb z^8-8hV*uEQg)oBdzDn& zwU9;mJCiT(b7~>+e&CX}TD<4?fbLmb*6=xayyhY^4J>7)4e+Kz%jX$Z^Ng52GatY{ zE_{oe(p8&e{1mKY49TV0{N8(uz~XMjVunaf#)bW>>`mL)$eWoOEXApC zdp|hBcM|Mhmd(npMHtU8i;~mDgEYR@7}OQc;0sya&OX);s$Dc?$fp*$FRq|4&R&pmhtk+40(yy>+}+w;7R6^9pHNJV&}|Nh=xpd_$L zn{Dcr?aCwjC!s>q9qGlMNs$J5lw4TlTj*k)cykyx@8gJJvnA?3)NR%9gNsQUPK`Ry zOr9lM6U}b+AJfw%6{y@0^wcW#@{y?T1FQr<5&cIyi=Rih@m^P4p=p>)dMH0ZuoEI2R%Gc@+JSgc2O>EbHuj=u&C#C5}+ zp<+ldp!=yM$V9`hGe-Q&X}HF`QbAAa@uDv1%Yxiw`On;xo*cdZe>)I>S@InabuHoV z8JZxuQ0)S$q_w4XxPS~~-gy)S5z0TuhlV<0kzw~W(8;8O)5vaqW%Vm9b!b&w*F0XB z0e`X@DxH4JJ!#XD67a#AQJ5I2IMO}uW?XTyO6;WXo)0g9?QHqJ~7rEEb+ z>iJCwwwja4ugmYG2A6gc`W9sU6U>#BTb+!%RP@W4t-*G+{Rk{}`tOLC@WGj1j%)74 z8_AvAo+R782jbbF*!W-kmminufJIcn9{u*rH>3m~MCFOQ3(ynak@{-ic4fh5=h^;< zSjY`&kLlDa5`@6-ys8`3uq-Ty;&>ER%i_uh#eV+s({l~E8Ug7L7`h=xiAoF*tyqRX z>AT_nolI$Vl?PP0S5nEzPiOMhSVwJ#U#3+c@<)bEB~1nqcTDnv005!3`%FU4mBbwZ znSn<%sT>w+BS>VR6L7{~IZ)_-pIu-(fw=O<8rD42JE4ReLi}ROMd{9~CX)$Crb|&q z4&y~)A2W+#aPm6)o$IJ8mVGD7SEuNAG`nH;ZEIk*{{I;B&%k5>-`nf3TNw{s2&4Gh z-4PhDPQ?II9IjLnMXA!RaWAgLh_>4S(2ND3BHroT!N~*r$qRTxTHnaVF8!Nnw+T+G zXA-0R=y)Y)>RUfVY8s0DN1dJDWe=Igl&N_g?@f~wRl`1A95fx3EAWNVk$pKOj*=A` z2%O4KHC1JV(Nf9_`XmumN*LajAz-p>;Hf32&=dt|7^MR#o!_;oq9uxa9d_y0e|pfa zqU=US8a<~J|3X2VBm+MA=TD$O?HQ5vQecb2t*aCt%{q1w9rk@?##{e4g=P zi`?zhALOvixr3Ux$o}0kop5})=VUCnaqo%3^b7~l6y4;gSk41g7ch5Y(~XED(s09{K7uugK45fCC7Y{ugA*bE{L?1E2&~Y z**k{yl;9MOvy8cugL#fq7-m*9W+06SgYvlU3(3+Fh-n|7p}_Fk{5qvj$^21!Q*~(R zsnKcZ`dS@~{Og@X(62>2Xy5|?ya?{)Bl-D5$O(6oH{jw?zs3lW*G}e%B$AX*ByARQ=IDnl$#}F%%BiDG zu?rUQ?-BugTGx9;6c|kGEc78F{oB753zzUkkxlE#igEO6em8vD?Wtm z06H~QZ|C~`&9uEfD?bTnI*pDtR)Q{dr|-ix+m!aeNrHV z7Ja!?^I^n}{u!$?&q5P=zYTs^0nB7pGkef#$%LQe*^Qmi6uYHXb?0_XZpVS$wSzS! zr1liry=%p4_|m2?r5Q!FG}vpfU*kTLK6yst33aQr;3iz{^_NajT^zu_60z&8;w(2s zr5!RT`gMBGC%ef=O^$wz)qw0^zAW0Rlw!AT+2@5R`EKZ$KW>nSfGwwaeBe1q*#9yE zDyYk=vDGNhE|B=AFzxRz%;U^g7I(U0u||Q4b}^H^W;Fb_9unqt$v@hq&j^!x5y!a} z*W7vLpLwnQG)1+I&igg?GOdgeTA$Y90XPXFYkDugvqLw^oj|navJ8B|2T6X8cbCWc zf`(M2lA-V}E(Rx2wV9+g(k7DYop%-}-Rdsv%+nG4srO@A=Vt)VY|ksO_wUf2y(?*b z;IDGbCh**61V%gI2huG%nt@AorWA{OX^CiPtem}OiOfSsk_n8*Ov?dhc(?{)7miTD zA{p{1@*$TBpDEPFUZV87LWAS7oQ@qk*oB^fxrL(cvZdH3a-eHKh+jV5@)1jzEz%4u z&Hc=c@k!<5r7+&_X>30=qgvwM64@CqNn7F%KIl|7%jng@+gh{}=K-CVrLY97{>QDQ z;7v?2X$0!MRnTdkB3t_-3F0qo53G=K!El`5M;@6(z@n89@E-Y__VrI7_GiCT3RHlV z1<+l0i2oxs#KI+l;Q6B+kz(hJ!y+YZhfsri(?ldC@n{c;>MDG5Q`a$Xql z&^uo{M6=w}#8r3XHN4Ox-0dxYP5V`jIf`9_X*gl;fy*M~?)~S5 zhyQjxH9ELCoUz*CRzp2tJ|AU!_RutO=IM!pED8^Kf7H^ZV>9@{*m(} z%-$CO&FF6fwE!^#;B@;VgPGQRklZEd;(R32^}^AMa(R?M!8nm`R@$d|_7|a>lSXz4 z2igLz)m&PVU+zYgF@T_QXFV{DoFreEPRpYo%jv83_OY!5cqFkFMn9}0jwo9z>|{S4 zU(Hf>CPATcU%0-(&>~gGhKm4BG+aGMF*~yKm(Jv05+(9_^7x9FO9D>Bl zyKC0<@{A+*IrPij)G_t+7uZm-J&bFjCL}M1jdWS%#|$EjmpknSYC1wBmB#jMNjE|n z(b{3*sJV+7E3OV#4|((fN5F^JM*dWs#~+TaF!K2;tX%$yu#^>N^N z5Cypqfzp2WrOEyh?#7=mxk3-U( z2e;8ymGn)HP+_bG*y^i9TwO@2{1~$E(cs`ByXz`MtU0Ovx#-0i{*POg+LDGQV199* zg4J{+<6^s{#pI5J#hiCu&H4Y8jlx(Ehm_EQu>8g`BOB3niwS9V!stL#5?BK-7M z@8GArt1LZ8OUFhfd95mw_s8<5FsW$`Jk|()o=Air8j@nA*plwc z>EzA6BS1XmvIkw-N8f9o1mVMHNe|}1F6E;xRRrb+or=9p>%Qq59QRDJst@q{QL5;% zv6q=gXUh9|ZDqIG@5l(Uy{%)hL~Mn_xgwA?in(uur+jP%#HT7JO?k*kqn2h;-r`5g z6WVCtVVBH3JW?%gjqCL#n+?9d=)to82f8U>h;8A3kQKjIG5>v8{MI#tEadm~PL4;I z3Dm#L_riHvRA!DW_LfmFF}SxU!4SzT1J-&WFgwr5-8FPq*6kr}2lc-DmPwczUd$-d z23inuV-Z_xQ*wNYm4JQm^)KqkUQMUmP6OpTpd0)*1t>;VTjq{($Vik7y)3+4Jy~ZI z*U4T|-mKV3Ix6!mk?!;&lk4?-`&9*S_?bXU>Ooux8)6Fk6}ulgL@vMf*0rgC4H$Nw z7hY*Bm80Uic{I>S6Hu(cnxRy^h|Cba!iEkuM&K*HBx9oKSh5Q-BGz`V*ME6JY#T`| zsUc~|{N~{foV@s;Q5kGl=Z!i!fc5{8Quz;z^Z$C*`FO|?XzGlEL-ht|4c1WdQ-oVr0NrazQG%Z_w;T`jG^9 z_U>1)G-B}V``~&vYS*^^+!2`{53M!vltYocUUB}_!CPck4dw*>z(1GdZzA)5Bc{CM zAzx4|R4ydASG_70zAXfakRIqm?C!zfAK?2i7uhMQ$DACkD!Z5rzM4y!LK#3sUckc+I3uUrw?aZiF}hls zgn!{2urM2L6RotlCc#6F;zFGe=+7QRPajTRR{4s@pBi`%cP+UR)jfviViEsV+dygW zLZ;%KMye6<;tLa~r3aAiVc}n_KadE=fVvIkMgqy#dq3H$Wx%pD-LS z^;kO)G(_3V9j16&JBv_f)x?4lHGUJK#VgRXTOpiUe*117VN|-nSI|V%HJPg!i4Vo6 zwptwX{)4(=hbAfaw5w-H-Z~@icUw~XyR<3~Q6$M-Mv^emTwm?8#{aYp8moOqO?D7v z&Tx2p$nS`)8jH7%(Cnd4hs4I=Er{~Vg;S#5XK$}1Kp)ODnU2QOdg6^721S3LwCzx? zSk8jxevju|JJ!SuJ%gi?5ZNl-$q~5MzZFK?moMb~)ONKQJH26HU`tO21~-5ly$6rd zRjrS*V&>PF&<7p#7o9j4HD1r2OHs5J;V10^48r9Vs?cNmc9|pRnBEdoB_+9YlnYQ` zDe&*BAx2#=V7fnlqu0CxuAkd;U|6G|6HuIt9u>!TzzV*N3ew|$uhGK)m^0D;1ep#8 z6sV>xXR><*&k_ZtnI^D6c>Buf;9yCu@ z_@-jD%yIZAH33uT$ z7I_8-Et9lgqzRSPH}!&RQS`toOTICsqd6!Wg$BuBBmki_?%<9WaomhDSkC`1^YV~Jd>GmRE_ zeE%!Kg8NCcus`i3PIW+kA%T6U)(&*xe0N){poL17T$(?>D4sX{4V@C43mDHF5hoW zXB|+7_p>C0RylBy^RweGXc`;H4WO-$T+)P-@?zyR_~BSPpvm*A8te?F;!i_3;2d|N zsF0oKt@ne0yO~?7%`R>D-Y9P4luK8inRN~6J|<7=0G6Z{yq@6|=lWa@3vRzl*9}I+ z4gGzeLGLXO`F|!d#6o@bg#`T>ir%qRm`EO5&5%-FMz!~o%(`9(^BHKnTD3wH8`;Fd zqjp8-SDg_CfMNVKc|;zEf1TdTBxn3F7nJE4ul*j&=6lfl|GQ>7as#x^gH=4Ydo1jVm^cD+tbmQqI_1U*g|mK zo(@U(FZyO>2@`7z49H8EApw)d>QC-rl2ut@GR9_LAPy>UGYR6kp{EEFhDq_qO3uCu zJ)*pP{jT^2AroVv~17)7s1!hXCuEwfZ)UmjRUWLi)q4}Q+ZtI%QBJ9ktQecu( zPzKF`zkfQS1bg?5Z}R>XG~=8+ILp(rgc~yEy4`2wKA-Ml8g5Vo_WCrKVS=f=WpGd#};p@U=9J=KjUT&!yrDORBLDZhgFA+s- zV*dYhxwE5!DoOj}5?aw?TMXk^IcWC$lU~p@*7lS{lYxv+pzu`9w?1RFONHfM)dm0C zU^nL?4M^ifS7p4VD*fBbttarA!Eu3ori!YnRzr?s(()%8b0>*CatGN`!ZPYB^sDbQ z>prN*N}fy~m1l#PJ1dQ;VtQW@aZs~^(z|pLj>xm0ewMn{@aRj`;nwl`fljV!q#Bmy zzj>(%@AHVAd+4iPWlH4?k$Y57xl&i7x3XYaP*fZX?!^}2=CVvh1q+j(AFWFiTPnid zE55>xJLr5y%0!(FRNU}PtIYOggMT6L5Kvg5x6N3M43>(SXMNM9rV%4f@s0oNG-y&7 z!gl1m=&eAT>S5CbjWmi8&l8R3RkrIz>ICD;)8La8SX3o$9!kxNuM?0FJeGmL#ey{v z6xuIO2;Uhysz1@SB(uk0FwYHESqizXf@^iSd)0uz6Q&7l6Fp)^hn;C+H);{(WxKqq zJZQ~pOfzJapZFVl%x5GXP<9hN9GT1z0Urpw@+e?&EPOBKU7jQ)RG`ct4_!L|ZgKbY zGrkrljAI>&Denm_#p^B=V3P5-OyqyNk(<@Tkt#dEw77R4O+*eh4~4Y^Aq3=2@CV1K zN=eHJ!*BU?2iL`ORU_0sQFg(sU>zq65G#WHGnkzb^qr!qCf}csn{L7_o}jKZ+9{6% z+qQinGr}NQ{6TgZ)=Goj`6N^w;Up;{o>>Dw^sLLRw!03^|8PCSj^2Nxo-y77K_K$S z$qQg1;ZHr%u(HIj`4h>_5%fKt0_F=nAWIFt5X60~u2aL+ z^C%^Q6}qdSZA`%ivWm}y$Uo-jCQIiAr)N?{_ro@U!#S2Qr0i8|h9P3HS@S?uf`^bX9 z3A+M{J?Xv3E4&7%U-#C_21=fd4M-lcQIw!W5UTMB7Acmi(Lzq&krigNi{wdH2L3BVrb zkLl$9wIf>Q!js{+G5mkp`|@xo+xPFM%@SoRYhmn>Ws;O_vM-HniI8QG>|2PzWM9Uv z46@5oME2|@Lb7Gwx0&o^-;Lje@AFi@=ldSV`~Uko{4tJu?&~_8_vb$M>pWKo*E8;5 z;JU|Vc>$v!jg?$!KEqTp5)ypoHplGhoe$7$-({Qj=p!_1QRL&MkrTtR!TE&7Oy+&| zv=;b|v9&i8mR%_9$>(nornexP53R!V!oa{-hq|H$6+4<<>9JCV1@gqS4f@; z2u`=>y#7}ugPHIBWRm?}T8wpIoZ;z>tEZ)3C6K?PO%#md^-80?>`9KGD$($cIft~Lb00uG198Lp1a3S9n$r82|1}xE zRAY$FjAO|ssOta$`m}9htdu|3t%-rj%x_PD!}$nAP1;{5Z6>}4b~3+B0POjg&;5n| zOl@s5%4C)ogT`CFrkQ4YEmI2vMJq_9^|>Bp-5U7+m5~=*-9mhk`(6Yr`%_)rFXl#S z^amu}Ows~nPFLZRk zsl)V5FqA6GaQ#BYd&KIF@tgL9cNQ`BaZ#`XHzT*tit5Io*twMxce*p6NWtSKNd>RQ z+$YHUoA202=?UG1N9E@}fyPxi-2p`<^&fAyl&X8*KqyFfJWp71ce2Dre%9!HBVHg? zPO*ZaoPEU2JyRW{;qIOzZg#)c0R6+97Y8$W4LJU8>UA9F z(8&OS`V?@+j*{jur*|9#HneCM#re9e&owi3H-=pnv_6_-2&s1y^IXBWUvt2vnI0ik z@X5>tEKkoD8x?N%ur<0HWDPte(U@*kH|5w@t%)5}E4#h^)&a`-v;{XvVvx*^ALX{D ztezoKp=v<^3FP)qPg~r4Y{)f8lL`T~6vN5)8D@*!+JM$?(psecDb$`m(!fEML=EnpV(|){W4{U>p;8Ji(P^N`g>JhC5G2h_x?0%#mKE#5Dy_-SQ z^Xu8f$CdiULPJ?n^!R?VbGo4d;zRX}fYp+7p9tmB?kk;p@?frR@tNxwNrO+%Wwmc= zb7QU|A9SgSr0kqz&fK(IYZ|dTZxbNLwCOaXc**J|8G&+f0<(#~9~juA{NUu2aB`>( z4|^SwanP)%+1lfIOrV0vVS1y4{#+f|VgNUFUtGQEs+ahRSvG3vVpq{VuHfl1qWnA5 zZv1vv3_sOjLf8$!km_!N`Yc%N-&TcJF z;m9mH+lh}bUvNfrbNdVD*SnX8-)a+?MDbZ-=(2Wdp5kI=dss?u?@3(Md@HC^%8Gup~ERVa4ulKgb`8XN7rD;mi zm|rrHqUw`=!*6_OeBjAjHUl#VJzd=u3qG{5iKRc{dv3 zNL)^SC(wW}_NG#~C0T=T@6vw4_k!-OPN@5N+4_i5X8KM2ka@L5cw`vr=?53CB)JMQ z4c$=4DemjZ-X)7Gw{7`26@r$WKYiVijd<|&4f(+66;X>4N7)WKv~f;ED@kv&#}531 zY?q+&!iinbSjl&1vb#Ss`}qLHV3WjaD6{lO9opJ0a{c%beYU@wK0Wg2|JCyeiaiZY z`cZLCcAjN+g51gKoI054=$R+FR6p5&nEKnd-!sYH(=YzwQHk(MRMgIcm4{4fil(?W z+Gm>Ww7fEI3|m3AWGcUe14aDb9ufQ24r}?*GYT?+aQX%WunGvM>c4z6TIH%;(EQLI zBolurmfwQtcM0^L??%q3q!|nhF^k(z`tC_Se5)hWMSYyB$qm;&r3rfM6`4jOBqItu z9;qaA)^bdmKnukA@nI1ei50Mo&y|Em!TA)%Lwlav7izdiC_{(BZ&WMQ$1L1>s^h65 zQq}gbp#Sl%?__gl;?3%|b2UTa%i|gv?Sx4maMtto?)%Pq>*wy<7{_3;&I1^WY$-I> z%pXkb>|11PY*wq5=reT{!57Xzk#`T0qAwT}vE7-W8y+kt<{ZAk+|eccxsXH2)e`|b z*<0D{vE`LfyebJg!b-X%=#=S;lnN)k5t9xyt#F}7+s>_BKa`YUB<(Kjc}~yev64&K z+9qa{F(sA5xb`;SuPC>#VH9V-W#HlivYXwTLElxyTab-oId;@}pSwqCDL)*CH*R-Q zVP61)Dt@94&@}w|zXaQVdPsAf6DV=uzy=S8%b+Ot zmDcz>+ZRkDf*V;SvJ(>%ag8!ZwrC(PIVT%K@ue^+Q{AcN9>HKg$9Mf9)z2GSy*ZbY`uKtD!Gc(S5L((I0e2jZ3md@aO8)8hEM%U zDB&J>2g5f3~LQ${)5K$#mZV%t%667d)>%K_o&mfFPY>f!%)l77)LI(3@UxusvFfJY_Dg8O!x|E zY`g^nHo0o_L#zu;Mk&h;3q5(z?-iT9(#>#vMjKmag3riv{BX5fAw7$+^bt`em_i*) zHC)+{!*~j+7Py%;^o1$9D|xO`?U24Sb1Mz1UivYDyl2}%Cd*USSiYp%;4MnE84`QL z0RMGOV%@aenuqqS?Qw0D zE?`ja9eJ5_lcH^4DP^mQ!$}nHI ztgYDZSeyf75@R+hc^-;um&D|$ovVZw?_U+UmfVIOH!QY-p1u`)y&+#1 zB`980pQK&vkhNE{+TU)lEM!8Ngp~Nqj<|2#ZwojjZa+I`?YfcoeMD{o(}=l%{_!=d z2sf_3E!>j=MZ?<{FfCE+cd$D0;f&%yp6UHJd4}6t%mQfItjJ&r__{$rK1(KO6Wj2|wBW493HSHtt zduP;P4}x0NQMc^IpADmAV3hlcYdoE`GO(g7$qEOY()OuJ)0UXOjlce|kt9xBv^nqK z`I!7kEpF;-bJNa&!f*N(InczCMt$&nf&xPQy~_}iv*l=^Pmply;C!ZZ&rulZrv$ps zIp=rJCvzUya4NQs@gMlmpWAH_TvSyJMES;cl(q_#8`zGx4zJ}>B0upjMeWfq>KDi^ zI{^+_qkCK0t!G@7VmH22@CIP}vNhYG7cNkeLpZQpgBMB|$fRWIVVP#@Eg7EsI}UG+ zTVi;DBf*!Pi)6g{p~8&vI2~^hOB|#=qS4(FEj`jjSqYjcOe*ezw&9%l=;gT{o`DS_ zGBQlEq0zIzE8O+bHI{vR3;hMB0E|BMzSX6Z3!5QJGYG^(c7kU3)5ZevtUE`_;RtZz ziCFTkyvHRv@87>~3J*Vqn{Ltchb!B`Z7dfD3u!HJdp|cFwTqtk>22|;mY!=kEh-O2))H`Ri06VzHn|&!q~W7jWY5x)5{RMj!^P_@{!S_!P{lSd;c)@f zavG@C(?}Y*RJpFo>9b0U71ZsKYiI=iEcvdLa*`lI`{Kj)D&^PvI@kU z8?Gt#B*v+49)@3ey27b4I3S34xK)0qlvQ_vgX=-LcyGS8pO%aF1p@;PBiAR}*B{;8 zeGubx_Ju9FI^gchBoBlwkz}fD*7C}(7Mj|}esRgZL5j`sJc=;DuysMRNjI%rOtdX; zpHsP=rNE3je~^9i=jv`4S4|##k80fhw&&eGD-lDtm>}^Y>ppudU+2kfrLubOb%vYh zI{C@Vb7(i~klL4nPhP^IAAKt;EA^(bHFILjNsLCK6|JHiN4hRm#!#{FvJ~9#kM) zKCgZlhz`djcOhPgHgTd9S0+|UMAeG8+%+rMKZhosI7mlc0kxK>y-C7+@gFmtmZXIq zBeZcOH>5Vbik2I@Xfn+b`gtbxzU;gWh|H29aTP1JxOJMa*Q$}p?9>I4XajXT6}3MP zVYPtASFsDcWA{BhQ_5Z=C@G7q&7$~>;QZHDyt{IeR)!B z@#9XZJo~HqRsF?&?|wA1aWq%|Fb^96(-T&`xZz~biw*OA`5F8I6on^NCL2h$6=}BT z^J{52Bc& znmET~WYkI@%uCLUxLhJ2i}Qs@a|oWL6MIVOwoZp8P~QEbfPvlbl8fi4vn_RIzm!*N*DUhDzg~b!e>GSEG`HeE9GeWay#df|aY{p+J7%RGU zy}&qo_q)kO^HlXRaf!R`xIX*HdG#aJ`jxTlwkIFiC2>nPXGI6e1lZ=vImvlGDqNo` z8!GPp;|?1l5z^Q|n@@VRu}>Jq=-#@8Ne<{dOiMyXX^I+ubRct;gHW~t)}-{!PmGT< z)2Jw3&qZDfQgp~e#vAuJcr!bs1YhA|AlRuO~oHi<^&%m^_c$r@-|r3<$IUCK?TKwvqDnsA06TC+=_XD zKi?sopHWa@T8b22M31i4Ba^qR4k*=?cHmTFt0J3;VSYSfc7q3LdS=W_*R1*l%1b-* z9*TBRBhJq!!4Xjj^GT8YhvO4eOt6NbAy1>|I3%DyVUZqRl9pFiqF^5c&DyTkb{^}7 zV;{mQK-)V-vSp$$iazEsb_Qq{|S3`m-y@E*g=17rev;96`bjE zUw2g;vXOF8SFxwrS&~isgDj&MCZRNZdfIq$7e_}yYxd(H&Cw?P!3vU+@CgaGjwfOy zr0Q4e1v_2mm;#OV^>EW6h?5WlYBV)(lqEgT??ToP`-Z~I27zQH2lfF_Tj@>=be}Yy znIBA=wUVSbCiln&w`zge;?Yoz`(GM1Y%QLA?3JBgX1S?dAaKAw?W!tpgV#>F{7qs+ zxgo2oA`jn?bQ8uWR$l^z-*Zr&8vp2%`PP@562w7#`=NF2+y*sOdqxsw=le`qe-Y01 zdWodU_lpyKD4^4Irk`AaeKOjuahd*{V?bDrZ^;n9Bey@BYUyE}=E zD%F`R$O#Sys6N|JUua~$3<;KhUZmx?+O${-vA>u$Ju_5fHFaNXnh%<*Q54Zx(eU;e@i{0!QzY3nEqG*-Z4 zk%nT~HW7*vys>yQ_>PwC)IKt)hz)xOw>VU`@~u0~{|I0ArHIu0L9(OySz|a8tH;jT zO2BThdDovCfd3{ctTSD{r34MEy~%+hFoevPJPjODl!_t-CQLMJ{PLwscFU8pU0}W6 zIN(b@L`6$Wlz5l3k55coTp5tf<_x$N74mqXFk=s5@aEo?D&%Hc@r&PqO9ZKR24(?) z0cyWQuV`ktyy%Dzpl7--8!HP7OIoiDCDX>iPcUG+i}8Q$H@}KH*hD=ZsVLFQ;nK=W zxXMi~r2vk)-d4Ca)s`sQ6an$sONBjMGW-{%^IhrkWaV;^yt+5U+_yRoNs<7C08&v9 zplW@6e0=sMVCLW0{(T?Reao)=#sn55qf(d;F9$Jzq9eu3g|S+gk}9pJbM7Cg0L%=m z2sF)Xuxd-=n3$Of1TluV4M62EC>d9{_E}GLrY}Ev5BKa(mdNnp_cskd{=^TwTx^}- z7yw%w>*+I4kEkv6j!b2SD!277#Q}fb!!`jduS?1=Z@jQ1f32S3jlp2d$soS zNqiLZw=!FUy=mA1bn$ob00`N>&Tmj7)0g)!eeRrQ&i#gD6yALB4iSVn0N^c|!gjRk z;|H%(sXZ1re$wsL>BiZvEM?^HxCWd8uNQD&Nn6uCr*zE4Rf;*+ONEq`xi z=mOjS&N9%duH{;SIPy?>O9h<293hSblO0bim*wQ}B#F5qDgbk|kMNVa2VyYMS=RJ7 z?U)^=+b&Pt2?~rgf`G@Gt)3AqHdU7*=_BB|_Y@v??F`t{l-A$qk6(@;LDvxy *, *, *, *', name='calc_qkvg') -def calc_qkvg(x: torch.Tensor, qkv_proj: torch.Tensor, gate_proj: torch.Tensor, - bs: int, s: int, r: int, head: int, c: int): - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - - gate = gate.reshape(bs, s, r, head, c).transpose(2, 3) - q = q.reshape(bs, s, r, head, c).transpose(2, 3) - k = k.reshape(bs, s, r, head, c).transpose(2, 3) - v = v.reshape(bs, s, r, head, c).transpose(2, 3) - return q, k, v, gate - - -""" -[bs, s, r, cm] -> [bs, s, r, cm] - -used as column-wise gated self-attention -""" - - -@nnscaler.register_op('N S R^ M^, M^ E^, M^ F^, E^ M^ -> N S R^ M^', - name='MSAAttention') -@torch.jit.ignore -def MSAAttention(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, - c: int, scale: float, chunk_size: int, is_train: bool): - bs, s, r, cm = x.size() - - if chunk_size == -1: - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - - gate = gate.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - q = q.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - k = k.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, - c).transpose(1, 2) - v = v.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - - sim = torch.bmm(q, k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - - attend = torch.bmm(sim, v) * gate - - out = attend.reshape(bs, s, head, r, - c).transpose(2, 3).reshape(bs, s, r, cm) - out = torch.matmul(out, out_proj) - else: - if is_train: - q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, - bs, s, r, head, c) - else: - q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, - c) - assert s % chunk_size == 0 - out_chunks = [] - - def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - gate: torch.Tensor, start: int): - cur_q = q[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_k = k[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c).transpose(1, 2) - cur_v = v[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - sim = torch.bmm(cur_q, cur_k) * 0.125 - sim = torch.nn.functional.softmax(sim, dim=-1) - attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, c).transpose( - 2, 3).reshape(bs, chunk_size, r, cm) - return attend - - for start in range(0, s, chunk_size): - if is_train: - attend = ckpt.checkpoint(attention, q, k, v, gate, start) - else: - attend = attention(q, k, v, gate, start) - out_chunks.append(attend) - - out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) - return out - - -@nnscaler.register_op('N S R^ M^, M^ E^, M^ F^, E^ M^, N 1^ 8^ R^ R^ -> N S R^ M^', - name='MSAAttentionWithBias') -@torch.jit.ignore -def MSAAttentionWithBias(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias: torch.Tensor, head: int, c: int, scale: float, - chunk_size: int, is_train: bool): - bs, s, r, cm = x.size() - assert cm % head == 0 - c = cm // head - - if chunk_size == -1: - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - - gate = gate.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - q = q.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - k = k.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, - c).transpose(1, 2) - v = v.reshape(bs, s, r, head, - c).transpose(2, 3).reshape(bs * s * head, r, c) - - sim = torch.bmm(q, k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - - sim = sim.reshape(bs, s, head, r, r) + bias - sim = sim.reshape(bs * s * head, r, r) - - attend = torch.bmm(sim, v) * gate - - out = attend.reshape(bs, s, head, r, - c).transpose(2, 3).reshape(bs, s, r, cm) - out = torch.matmul(out, out_proj) - else: - if is_train: - q, k, v, gate = ckpt.checkpoint(calc_qkvg, x, qkv_proj, gate_proj, - bs, s, r, head, c) - else: - q, k, v, gate = calc_qkvg(x, qkv_proj, gate_proj, bs, s, r, head, - c) - - assert s % chunk_size == 0 - out_chunks = [] - - def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - gate: torch.Tensor, bias: torch.Tensor, start: int): - cur_q = q[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_k = k[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c).transpose(1, 2) - cur_v = v[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - - sim = torch.bmm(cur_q, cur_k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - sim = sim.reshape(bs, chunk_size, head, r, r) + bias - sim = sim.reshape(bs * chunk_size * head, r, r) - - attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, c).transpose( - 2, 3).reshape(bs, chunk_size, r, cm) - return attend - - for start in range(0, s, chunk_size): - if is_train: - attend = ckpt.checkpoint(attention_bias, q, k, v, gate, bias, - start) - else: - attend = attention_bias(q, k, v, gate, bias, start) - - out_chunks.append(attend) - out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) - return out - - -""" -([bs, s, r, cm], [bs, r, r, cz]) -> [bs, s, r, cm] -""" - - -# note: code not reused constrained by cube's interface -@nnscaler.register_op('N S R^ M^, N R^ R^ Z^, M^ E^, M^ F^, E^ M^, Z^ H^ -> N S R^ M^', - name='MSARowAttentionWithPairBias') -def MSARowAttentionWithPairBias(msa_repr: torch.Tensor, - pair_repr: torch.Tensor, - gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias_proj: torch.Tensor, head: int, c: int, - scale: float, chunk_size: int, is_train: bool): - # call: MSAAttentionWithBias - bs, s, r, cm = msa_repr.size() - - bias = torch.matmul(pair_repr, - bias_proj).permute(0, 3, 1, - 2).reshape(bs, 1, head, r, r) - - return MSAAttentionWithBias(msa_repr, gate_proj, qkv_proj, out_proj, bias, - head, c, scale, chunk_size, is_train) - - -@nnscaler.register_op('N S^ R M^, M^ E^, M^ F^, E^ M^ -> N S^ R M^', - name='MSAColAttention') -def MSAColAttention(msa_repr: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, - c: int, scale: float, chunk_size: int, is_train: bool): - # call: MSAAttention - return MSAAttention(msa_repr.permute(0, 2, 1, 3), gate_proj, qkv_proj, - out_proj, head, c, scale, chunk_size, - is_train).permute(0, 2, 1, 3) - - -@nnscaler.register_op('N S^ R^ M^, M^ M^, M^ E^, M^ E^, M^ M^, M^ M^ -> N S^ R^ M^', - name='MSAColGlobalAttention') -def MSAColGlobalAttention(msa_repr: torch.Tensor, q_proj: torch.Tensor, - k_proj: torch.Tensor, v_proj: torch.Tensor, - gate_proj: torch.Tensor, out_proj: torch.Tensor, - head: int, c: int, scale: float): - # [N R S M] - msa_repr = msa_repr.transpose(-2, -3) - - # [N R M] - q = torch.sum(msa_repr, dim=-2) - # [N R M] - q = torch.matmul(q, q_proj) * scale - # [N R H E] - q = q.view(q.shape[:-1] + (head, -1)) - - # [N R S E] - k, v = torch.matmul(msa_repr, k_proj), torch.matmul(msa_repr, v_proj) - - # [N R H S] - a = torch.matmul(q, k.transpose(-1, -2)) - a = torch.nn.functional.softmax(a, dim=-1) - # [N R H E] - o = torch.matmul(a, v) - - # [N R S M] - g = torch.sigmoid(torch.matmul(msa_repr, gate_proj)) - # [N R S H E] - g = g.view(g.shape[:-1] + (head, -1)) - - # [N R 1 H E] - o = o.unsqueeze(-3) * g - # [N R S M] - o = o.reshape(o.shape[:-2] + (-1, )) - - return torch.matmul(o, out_proj).transpose(-2, -3) - - -""" -[bs, s, r, cm] -> [bs, s, r, cm] -""" - - -@nnscaler.register_op('N S R M^, M^ E^, E^ M^ -> N S R M^', - name='MSATransition') -def MSATransition(msa_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - return torch.matmul( - torch.nn.functional.relu(torch.matmul(msa_repr, proj1)), proj2) - - -@nnscaler.register_op('N S R M^, M^ C^ -> N S R C^', name='OPMLeftProj') -def OPMLeftProj(msa_repr: torch.Tensor, proj: torch.Tensor): - return torch.matmul(msa_repr, proj) - - -@nnscaler.register_op('N S R M^, M^ C^ -> N S R C^', name='OPMRightProj') -def OPMRightProj(msa_repr: torch.Tensor, proj: torch.Tensor): - return torch.matmul(msa_repr, proj) - - -""" -[bs, s, r, cm] -> [bs, r, r, cz] -""" - - -@nnscaler.register_op('N S^ R M^, N S^ T^ M^, F^ Z^ -> N R^ T Z^', - name='OuterProductMean') -@torch.jit.ignore -def OuterProductMean(left_act: torch.Tensor, right_act: torch.Tensor, - out_proj: torch.Tensor, chunk_size: int, is_train: bool): - bs, s, r, c = left_act.size() - t = right_act.size(2) - - a = left_act.transpose(-2, -3) - b = right_act.transpose(-2, -3) - - if chunk_size == -1: - outer = torch.einsum('...bac,...dae->...bdce', a, - b).reshape(bs, r, t, c * c) - outer = torch.matmul(outer, out_proj) - else: - out_chunks = [] - - def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): - lhs_slice = lhs[:, start:start + chunk_size, :, :] - out = torch.einsum('...bac,...dae->...bdce', lhs_slice, - rhs).reshape(bs, chunk_size, t, c * c) - out = torch.matmul(out, out_proj) - return out - - for start in range(0, r, chunk_size): - if is_train: - ret = ckpt.checkpoint(opm, a, b, start) - else: - ret = opm(a, b, start) - out_chunks.append(ret) - outer = torch.cat(out_chunks, dim=1) - return outer - - -@nnscaler.register_op('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', name='TMOLeftProj') -def TMOLeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - b = a * torch.matmul(pair_repr, proj2) - return b - - -@nnscaler.register_op('N S R^ Z^, Z^ E^, Z^ E^ -> N S R^ E^', - name='TMORightProj') -def TMORightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - b = a * torch.matmul(pair_repr, proj2) - return b - - -@nnscaler.register_op('N S T^ Z^, Z^ Z^ -> N S T^ Z^', name='TMOGate') -def TMOGate(pair_repr: torch.Tensor, proj: torch.Tensor): - return torch.sigmoid(torch.matmul(pair_repr, proj)) - - -@nnscaler.register_op('N S R^ E^, N T^ R^ E^, N S T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', - name='TriangleMultiplicationOut') -def TriangleMultiplicationOut(a: torch.Tensor, b: torch.Tensor, - g: torch.Tensor, - tri_mul_norm2_weight: torch.Tensor, - tri_mul_norm2_bias: torch.Tensor, - tri_mul_proj5: torch.Tensor, cz: int): - a = a.permute(0, 3, 1, 2) - b = b.permute(0, 3, 2, 1) - - p = torch.matmul(a, b).permute(0, 2, 3, 1) - p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, - tri_mul_norm2_bias) - p = torch.matmul(p, tri_mul_proj5) - return p * g - - -@nnscaler.register_op('N R^ S Z^, Z^ E^, Z^ E^ -> N R^ S E^', name='TMILeftProj') -def TMILeftProj(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - a = a * torch.matmul(pair_repr, proj2) - return a - - -@nnscaler.register_op('N R^ T Z^, Z^ E^, Z^ E^ -> N R^ T E^', - name='TMIRightProj') -def TMIRightProj(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - a = torch.sigmoid(torch.matmul(pair_repr, proj1)) - a = a * torch.matmul(pair_repr, proj2) - return a - - -@nnscaler.register_op('N S^ T Z^, Z^ Z^ -> N S^ T Z^', name='TMIGate') -def TMIGate(pair_repr: torch.Tensor, proj: torch.Tensor): - return torch.sigmoid(torch.matmul(pair_repr, proj)) - - -@nnscaler.register_op('N R^ S E^, N R^ T^ E^, N T^ S Z^, E^, E^, E^ Z^ -> N T^ S Z^', - name='TriangleMultiplicationIn') -def TriangleMultiplicationIn(a: torch.Tensor, b: torch.Tensor, g: torch.Tensor, - tri_mul_norm2_weight: torch.Tensor, - tri_mul_norm2_bias: torch.Tensor, - tri_mul_proj5: torch.Tensor, cz: int): - a = a.permute(0, 3, 2, 1) - b = b.permute(0, 3, 1, 2) - - p = torch.matmul(a, b).permute(0, 2, 3, 1) - p = torch.nn.functional.layer_norm(p, (128, ), tri_mul_norm2_weight, - tri_mul_norm2_bias) - p = torch.matmul(p, tri_mul_proj5) - return p.permute(0, 2, 1, 3) * g - - -@nnscaler.register_op('N S R^ C^, C^ D^ -> N S R^ D^', name='TANSBias') -def TANSBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): - return torch.matmul(pair_repr, bias_proj) - - -@nnscaler.register_op('N S R^ Z^, Z^ E^, Z^ F^, E^ Z^, N T^ R^ G^ -> N S R^ Z^', - name='TriangleAttentionNodeStart') -def TriangleAttentionNodeStart(pair_repr: torch.Tensor, - gate_proj: torch.Tensor, qkv_proj: torch.Tensor, - out_proj: torch.Tensor, bias: torch.Tensor, - head: int, c: int, scale: float, - chunk_size: int, is_train: bool): - # call: MSAAttentionWithBias - bias = bias.permute(0, 3, 1, 2).unsqueeze(1) - - return MSAAttentionWithBias(pair_repr, gate_proj, qkv_proj, out_proj, bias, - head, c, scale, chunk_size, is_train) - - -@nnscaler.register_op('N S^ R C^, C^ D^ -> N S^ R D^', name='TANEBias') -def TANEBias(pair_repr: torch.Tensor, bias_proj: torch.Tensor): - return torch.matmul(pair_repr, bias_proj) - - -@nnscaler.register_op('N R^ S Z^, Z^ E^, Z^ F^, E^ Z^, N R^ T^ G^ -> N R^ S Z^', - name='TriangleAttentionNodeEnd') -def TriangleAttentionNodeEnd(pair_repr: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias: torch.Tensor, head: int, c: int, - scale: float, chunk_size: int, is_train: bool): - # call: TriangleAttentionNodeStart - pair_repr = pair_repr.permute(0, 2, 1, 3) - bias = bias.permute(0, 2, 1, 3) - out = TriangleAttentionNodeStart(pair_repr, gate_proj, qkv_proj, out_proj, - bias, head, c, scale, chunk_size, - is_train) - return out.permute(0, 2, 1, 3) - - -@nnscaler.register_op('N R T^ Z^, Z^ E^, E^ Z^ -> N R T^ Z^', - name='PairTransition') -def PairTransition(pair_repr: torch.Tensor, proj1: torch.Tensor, - proj2: torch.Tensor): - return torch.matmul( - torch.nn.functional.relu(torch.matmul(pair_repr, proj1)), proj2) - - -@nnscaler.register_op('* -> *, *', name='multi2ref') -def multi2ref(x: torch.Tensor): - return (x, x) diff --git a/examples/alphafold2/policy/spmd.py b/examples/alphafold2/policy/spmd.py deleted file mode 100644 index 1cc25fd0..00000000 --- a/examples/alphafold2/policy/spmd.py +++ /dev/null @@ -1,274 +0,0 @@ -from typing import List - -from numpy import TooHardError -from nnscaler.graph import IRGraph -from nnscaler.ir.operator import IRDataOperation, IRFwOperation, IRBpOperation -from nnscaler.graph.function.anchor import IRGraphAnchor - - -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for dev_id, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, dev_id) - return sub_nodes - - -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], idx: int, - dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=len(devs)) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def _tps(graph: IRGraph, nodes: List[IRFwOperation], devs: List[int], idx: int, - dim: int): - sub_nodes = [] - for node in nodes: - sub_nodes = sub_nodes + _tp(graph, node, devs, idx, dim) - return sub_nodes - - -def _coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], - colocate: int, idx: int, dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=idx, - dim=dim, - num=colocate * len(devs)) - assert sub_nodes is not None - graph.recompute(sub_nodes) - for devid in devs: - for coid in range(colocate): - sub_node = sub_nodes[devid * colocate + coid] - graph.assign(sub_node, devid) - return sub_nodes - - -def PASSingleInference(graph: IRGraph, resource): - assert resource.ngpus == 1 - - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - - return graph - - -def PASSingle(graph: IRGraph, resource): - assert resource.ngpus == 1 - - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - - fnodes = graph.nodes() - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - - indices = [ - fnodes.index(anchor) for anchor in anchors - if anchor.name == 'One Layer Evoformer Start' - or anchor.name == 'One Layer Evoformer End' - ] - assert len(indices) % 2 == 0 - for i in range(len(indices) // 2): - lhs = indices[2 * i] - rhs = indices[2 * i + 1] - - # deepmind's default recompute strategy - graph.recompute(fnodes[lhs + 1:rhs]) - - # another strategy - # sub_indices = [] - # for j in range(lhs + 1, rhs): - # if isinstance(fnodes[j], IRGraphAnchor): - # sub_indices.append(j) - # sub_indices.append(rhs) - # for j in range(len(sub_indices) - 1): - # graph.recompute(fnodes[sub_indices[j] + 1:sub_indices[j + 1]]) - - return graph - - -def PASExtraSingle(graph: IRGraph, resource): - assert resource.ngpus == 1 - - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - - fnodes = graph.nodes() - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - - indices = [ - fnodes.index(anchor) for anchor in anchors - if anchor.name == 'MSACol' or anchor.name == 'One Layer Evoformer End' - ] - assert len(indices) % 2 == 0 - for i in range(len(indices) // 2): - lhs = indices[2 * i] - rhs = indices[2 * i + 1] - - graph.recompute(fnodes[lhs + 1:rhs]) - return graph - - -def PASData(graph: IRGraph, resource): - devs = list(range(resource.ngpus)) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - algo = node.algorithms('data') - sub_nodes = graph.partition(node, algo, num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - batch_dim = node.get_batch_dims()[0] - - for node in graph.nodes(): - if isinstance(node, IRFwOperation): - if node.name == 'mul': - sub_nodes = graph.replicate(node, times=resource.ngpus) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - continue - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, - algo, - idx=0, - dim=batch_dim, - num=resource.ngpus) - for idx, sub_node in enumerate(sub_nodes): - graph.assign(sub_node, idx) - return graph - - -def PASDAP(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - - fnodes = graph.nodes() - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - - indices = [ - fnodes.index(anchor) for anchor in anchors - if anchor.name == 'One Layer Evoformer Start' - or anchor.name == 'One Layer Evoformer End' - ] - assert len(indices) % 2 == 0 - - for i in range(indices[0]): - if isinstance(fnodes[i], IRDataOperation) or isinstance( - fnodes[i], IRFwOperation): - _replica(graph, fnodes[i], tp_devs) - - for i in range(len(indices) // 2): - lhs, rhs = indices[2 * i], indices[2 * i + 1] - sub_indices = [] - for j in range(lhs + 1, rhs): - if isinstance(fnodes[j], IRGraphAnchor): - sub_indices.append(j) - sub_indices.append(rhs) - graph.recompute(fnodes[lhs:rhs]) - for j in range(len(sub_indices) - 1): - sub_l, sub_r = sub_indices[j], sub_indices[j + 1] - names = [] - for k in range(sub_l + 1, sub_r): - names.append(fnodes[k].name) - names = set(names) - nodes = fnodes[sub_l + 1:sub_r] - # DO NOT USE THIS - # graph.recompute(nodes) - - if 'MSARowAttentionWithPairBias' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) - elif 'MSAColAttention' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'MSATransition' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'OuterProductMean' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'TriangleMultiplicationOut' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) - elif 'TriangleMultiplicationIn' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'TriangleAttentionNodeStart' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) - elif 'TriangleAttentionNodeEnd' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'PairTransition' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) - else: - assert False, names - - for i in range(indices[-1] + 1, len(fnodes)): - if isinstance(fnodes[i], IRDataOperation) or isinstance( - fnodes[i], IRFwOperation): - _replica(graph, fnodes[i], tp_devs) - - return graph - - -def PASDAPInference(graph: IRGraph, resource): - tp_size = resource.ngpus - tp_devs = list(range(tp_size)) - - fnodes = graph.nodes() - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - - indices = [ - fnodes.index(anchor) for anchor in anchors - if anchor.name == 'One Layer Evoformer Start' - or anchor.name == 'One Layer Evoformer End' - ] - assert len(indices) % 2 == 0 - - for i in range(indices[0]): - if isinstance(fnodes[i], IRDataOperation) or isinstance( - fnodes[i], IRFwOperation): - _replica(graph, fnodes[i], tp_devs) - - for i in range(len(indices) // 2): - lhs, rhs = indices[2 * i], indices[2 * i + 1] - sub_indices = [] - for j in range(lhs + 1, rhs): - if isinstance(fnodes[j], IRGraphAnchor): - sub_indices.append(j) - sub_indices.append(rhs) - for j in range(len(sub_indices) - 1): - sub_l, sub_r = sub_indices[j], sub_indices[j + 1] - names = [] - for k in range(sub_l + 1, sub_r): - names.append(fnodes[k].name) - names = set(names) - nodes = fnodes[sub_l + 1:sub_r] - - if 'MSARowAttentionWithPairBias' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) - elif 'MSAColAttention' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'MSATransition' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'OuterProductMean' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'TriangleMultiplicationOut' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) - elif 'TriangleMultiplicationIn' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'TriangleAttentionNodeStart' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) - elif 'TriangleAttentionNodeEnd' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 2) - elif 'PairTransition' in names: - sub_nodes = _tps(graph, nodes, tp_devs, 0, 1) - else: - assert False, names - - for i in range(indices[-1] + 1, len(fnodes)): - if isinstance(fnodes[i], IRDataOperation) or isinstance( - fnodes[i], IRFwOperation): - _replica(graph, fnodes[i], tp_devs) - - return graph diff --git a/examples/megatron_gpt/.gitignore b/examples/megatron_gpt/.gitignore deleted file mode 100644 index cf862087..00000000 --- a/examples/megatron_gpt/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -*log.txt -*.pt -*.cube \ No newline at end of file diff --git a/examples/megatron_gpt/README.md b/examples/megatron_gpt/README.md deleted file mode 100644 index 4a20fdfc..00000000 --- a/examples/megatron_gpt/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Train Megatron-GPT with Cube - -This example demonstrates how to train a GPT model from Megatron-ML using Cube. The process consists of three main steps: -1. Instantiate the model and trace it to an fx.Graph. Then, convert the fx.Graph to a Cube graph. -2. Compile the Cube graph into Python code by **data parallel** on 2 devices. -3. Train the GPT model using the compiled code in Fairseq. - -At first, clone the Megatron-LM and checkpoint to the devcube branch, gpt model in this branch is a single device version. - -```console -git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Megatron-LM -cd Megatron-LM -git checkout devcube -# cd MagicCube dir -cd ../MagicCube/examples/megatron_gpt -# download gpt2-vocab.json and gpt2-merges.txt -wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -``` - -The following three commands correspond to the above three steps: - -```console -bash run.sh trace -bash run.sh compile -bash run.sh run -``` diff --git a/examples/megatron_gpt/convert.py b/examples/megatron_gpt/convert.py deleted file mode 100644 index ca5f1dbc..00000000 --- a/examples/megatron_gpt/convert.py +++ /dev/null @@ -1,76 +0,0 @@ -# 1. build model -from gpt_model import build_model, GeLUFunction -model = build_model() - -# 2. register customized op -from nnscaler import register_op -register_op('* h, h -> * h')(GeLUFunction.apply) - -# 3. build semantic model -from nnscaler import SemanticModel -smodel = SemanticModel(model) - -# 4. set dummy input -import torch -batch_size = 16 -seq_len = 128 -dict_len = 50000 -smodel.dummy_input={ - 'src_tokens': torch.randint(0, dict_len, (batch_size, seq_len)), - 'target': torch.randint(0, dict_len, (batch_size, seq_len)), - 'ntokens': 128, -} - -from nnscaler.graph.function import IRObject -from nnscaler.ir import IRFullTensor - -src_tokens = IRFullTensor(shape=[batch_size, seq_len], - name='src_tokens', - dtype=torch.int).tosub() - -target = IRFullTensor(shape=[batch_size, seq_len], - name='target', - dtype=torch.int).tosub() - -ntokens = IRObject(name='ntokens') - -# 5. convert to graph -from nnscaler.graph.segment import IRSegment -from nnscaler.program import Program - -from torch.autograd.graph import saved_tensors_hooks - -class no_save_tensor_hook(saved_tensors_hooks): - def __init__(self): - - def pack(x): - return None - - def unpack(x): - raise RuntimeError("not expecting backward to be called on this tensor") - - super().__init__(pack, unpack) - -Program().clear() - -with no_save_tensor_hook(): - outputs = smodel(src_tokens, target, ntokens) -outputs[0].backward() - -Program().finalize() -Program().set_input([src_tokens, target, ntokens]) - -if outputs is None: - outputs = [] -elif not (isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = [outputs] -Program().set_output(outputs) - -graph = Program().get_graph() - -# 6. save graph -graph.dump('megatron_gpt2.cube') - -for node in graph._nodes: - if isinstance(node, IRSegment): - print(node.debug_tensor_map_str()) diff --git a/examples/megatron_gpt/gpt_model.py b/examples/megatron_gpt/gpt_model.py deleted file mode 100644 index aea1f412..00000000 --- a/examples/megatron_gpt/gpt_model.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch - -from megatron import initialize_megatron -from megatron.training import get_args, ModelType -from megatron.arguments import core_transformer_config_from_args -from megatron.model import GPTModel -from megatron.model.fused_bias_gelu import GeLUFunction -from megatron.core.tensor_parallel import VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear - -class GPT2Model(GPTModel): - def __init__(self, config, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True): - super().__init__(config, num_tokentypes, parallel_output, pre_process, post_process) - - def forward(self, src_tokens, target, ntokens): - position_ids = torch.arange(0, src_tokens.shape[1], 1).unsqueeze(0).expand_as(src_tokens) - attention_mask = (torch.tril(torch.ones(1, 1, src_tokens.shape[1], src_tokens.shape[1])) < 0.5).bool() - res = super().forward(src_tokens, position_ids, attention_mask, labels=target) - return res, ntokens, {'loss': res, 'ntokens': ntokens, 'nsentences': src_tokens.shape[0], 'sample_size': ntokens} - - -def build_model() -> GPT2Model: - initialize_megatron(extra_args_provider=None, - args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) - get_args().model_type = ModelType.encoder_or_decoder - config = core_transformer_config_from_args(get_args()) - model = GPT2Model( - config, - num_tokentypes=0, - parallel_output=True, - pre_process=True, - post_process=True - ) - - return model diff --git a/examples/megatron_gpt/parallel.py b/examples/megatron_gpt/parallel.py deleted file mode 100644 index 7de6570d..00000000 --- a/examples/megatron_gpt/parallel.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -plan_ngpus = int(os.environ['PLAN_NGPUS']) -runtime_ngpus = int(os.environ['CUBE_SCALING_FACTOR']) * plan_ngpus - -# 1. load graph -from nnscaler.graph import IRGraph -graph = IRGraph.load('megatron_gpt2.cube') - -# 2. register customized op -from gpt_model import GeLUFunction -from nnscaler import register_op -register_op('* h, h -> * h')(GeLUFunction.apply) - -# 3. parallel model -from fairseq.nnscaler.pas_policies import PASData, PASRandomSPMD -graph = PASData(graph, plan_ngpus) - -for node in graph.nodes(flatten=True): - from nnscaler.graph.function.anchor import IRGraphAnchor - from nnscaler.graph.function.pyfunc import IRPyFunc - # skip graph anchor and multiref: they will be removed or replaced by system - if isinstance(node, IRGraphAnchor) or node.name == 'multiref': - graph.assign(node, 0) - if isinstance(node, IRPyFunc): - graph.assign(node, 0) - if len(node.device) == 0: - raise RuntimeError(f"Node {node} device is not set") -from nnscaler.graph.gener.gen import IRAdapterGener -graph = IRAdapterGener.gen(graph, cost_fn=None) -if graph.sched is not None: - graph.sched.apply() - print(graph.sched) - -from nnscaler.graph.schedule.schedplan import SchedulePlan -from nnscaler.execplan import ExecutionPlan -if isinstance(graph.sched, SchedulePlan): - execplan = ExecutionPlan.from_schedplan(graph.sched) -else: - execplan = ExecutionPlan.from_graph(graph) -# execplan.visualize('plan.png') -from nnscaler.execplan.planpass.fusion import DiffFusion -execplan = DiffFusion.apply(execplan) -# plan pass for computation grouping -from nnscaler.execplan.planpass.grouping import Grouping -if not graph.sched: - execplan = Grouping.apply(execplan) - -# 4. generate code -from nnscaler.codegen import ModuleCodeGen, ScheduleCodeGen -filename = 'gencode{}.py' -_runtime_ngpus = None if plan_ngpus == runtime_ngpus else runtime_ngpus -assert len(execplan.graph.device) == plan_ngpus, f"{execplan.graph.device}" -mgener = ModuleCodeGen(execplan, scale_ndevs=_runtime_ngpus) -sgener = ScheduleCodeGen(execplan, scale_ndevs=_runtime_ngpus) -for rank in range(runtime_ngpus): - fname = filename.format(rank) - # generate spatial module code - mgener.gen(rank, outfile=fname, attach=False) - # generate temporal schedule code - sgener.gen( - device = rank, - outfile = fname, - attach=True - ) diff --git a/examples/megatron_gpt/run.sh b/examples/megatron_gpt/run.sh deleted file mode 100644 index a8b0f6af..00000000 --- a/examples/megatron_gpt/run.sh +++ /dev/null @@ -1,108 +0,0 @@ -# Usage: bash run.sh mode = {trace, compile, run, all} - -MEGATRON_PATH=/home/ningshang/Megatron-LM -TENSORBOARD_DIR=/data/ningshang/megatron_gpt -DATA_PATH=/data/ningshang/torchscale_data -TORCHSCALE_PATH=/home/ningshang/anaconda3/envs/cube/lib/python3.10/site-packages/examples/fairseq -FAIRSEQ_PATH=/home/ningshang/Fairseq - -export USE_TORCHFX=1 -export LOG_PARSER=1 -export DISABLE_CODE_LINE_INFO=0 - -PLAN_NGPUS=1 -CUBE_SCALING_FACTOR=2 - - -# check arg num -if [ $# -ne 1 ] -then - echo "Usage: bash run.sh mode = {trace, compile, run}" - exit 1 -fi - -MODE=$1 - -if [ $MODE = "trace" ] -then - VOCAB_FILE=./gpt2-vocab.json - MERGE_FILE=./gpt2-merges.txt - GPT_ARGS=" - --num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --seq-length 128 \ - --max-position-embeddings 128 - " - USELESS_ARGS=" - --micro-batch-size 4 \ - --global-batch-size 8 \ - --lr 0.00015 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 - " - DATA_ARGS=" - --data-path $DATA_PATH/train \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --data-impl mmap \ - --split 949,50,1 - " - OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 - " - PYTHONPATH=.:PYTHONPATH:$TORCHSCALE_PATH:$MEGATRON_PATH CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --nnodes 1 convert.py $GPT_ARGS $DATA_ARGS $OUTPUT_ARGS $USELESS_ARGS >trace_log.txt 2>&1 -elif [ $MODE = "compile" ] -then - PLAN_NGPUS=$PLAN_NGPUS CUBE_SCALING_FACTOR=$CUBE_SCALING_FACTOR PYTHONPATH=.:PYTHONPATH:$TORCHSCALE_PATH:$MEGATRON_PATH:$FAIRSEQ_PATH python parallel.py >compile_log.txt 2>&1 -elif [ $MODE = "run" ] -then - PLAN_NGPUS=$PLAN_NGPUS PYTHONPATH=.:PYTHONPATH:$TORCHSCALE_PATH:$MEGATRON_PATH torchrun \ - --nproc_per_node=2 \ - --nnodes=1 \ - $TORCHSCALE_PATH/train.py $DATA_PATH \ - --num-workers 2 \ - --activation-fn gelu \ - --share-decoder-input-output-embed \ - --arch lm_base_125M \ - --validate-interval-updates 1000 \ - --save-interval-updates 1000 \ - --log-interval 1 \ - --task language_modeling \ - --sample-break-mode none \ - --tokens-per-sample 128 \ - --optimizer adam \ - --adam-betas "(0.9,0.999)" \ - --adam-eps 1e-08 \ - --clip-norm 1.0 \ - --lr 6.0e-4 \ - --lr-scheduler polynomial_decay \ - --warmup-updates 230 \ - --dropout 0.0 \ - --attention-dropout 0.0 \ - --weight-decay 0.01 \ - --batch-size 16 \ - --update-freq 1 \ - --required-batch-size-multiple 1 \ - --total-num-update 5000 \ - --max-update 5000 \ - --seed 1234 \ - --ddp-backend=legacy_ddp \ - --cube-scaling-factor $CUBE_SCALING_FACTOR \ - --subln --xpos-rel-pos \ - --parallel-backend=cube \ - --compile=run_only \ - --tensorboard-logdir $TENSORBOARD_DIR \ - --save-dir=/data/ningshang/checkpoint >run_log.txt 2>&1 -else - echo "Usage: bash run.sh mode = {trace, compile, run}" - exit 1 -fi diff --git a/examples/openfold/blocks/attention.py b/examples/openfold/blocks/attention.py deleted file mode 100644 index 5913f791..00000000 --- a/examples/openfold/blocks/attention.py +++ /dev/null @@ -1,360 +0,0 @@ -""" -Attention Module for MSA Attention and Pair Attention in Evoformer -""" - -import nnscaler -import torch -import torch.utils.checkpoint as ckpt - - -@nnscaler.register_op('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S R^ M^', name='msa_attn') -@torch.jit.ignore -def msa_attn(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, - c: int, scale: float, chunk_size: int, is_train: bool): - # nnscaler.profiler.CudaTimer().start('msa_attn') - bs, s, r, cm = x.size() - - if chunk_size == -1: - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - gate = gate.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) - q = q.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) - k = k.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c).transpose(1, 2) - v = v.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) - sim = torch.bmm(q, k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - attend = torch.bmm(sim, v) * gate - out = attend.reshape(bs, s, head, r, c).transpose(2, 3).reshape(bs, s, r, head * c) - out = torch.matmul(out, out_proj) - else: - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - gate = gate.reshape(bs, s, r, head, c).transpose(2, 3) - q = q.reshape(bs, s, r, head, c).transpose(2, 3) - k = k.reshape(bs, s, r, head, c).transpose(2, 3) - v = v.reshape(bs, s, r, head, c).transpose(2, 3) - assert s % chunk_size == 0 - out_chunks = [] - - def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - gate: torch.Tensor, start: int): - cur_q = q[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_k = k[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c).transpose(1, 2) - cur_v = v[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - sim = torch.bmm(cur_q, cur_k) * 0.125 - sim = torch.nn.functional.softmax(sim, dim=-1) - attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, c).transpose( - 2, 3).reshape(bs, chunk_size, r, head * c) - return attend - - for start in range(0, s, chunk_size): - attend = ckpt.checkpoint(attention, q, k, v, gate, start) - # attend = attention(q, k, v, gate, start) - out_chunks.append(attend) - - out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) - # nnscaler.profiler.CudaTimer().stop('msa_attn') - return out - - -@nnscaler.register_op('N S R^ M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, N 1 head+ R^ R^ -> N S R^ M^', name='msa_attn_bias') -@torch.jit.ignore -def msa_attn_bias(x: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias: torch.Tensor, head: int, c: int, scale: float, - chunk_size: int, is_train: bool): - # nnscaler.profiler.CudaTimer().start('msa_attn_bias') - bs, s, r, cm = x.size() - assert gate_proj.size(1) % head == 0 - c = gate_proj.size(1) // head - - if chunk_size == -1: - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) # N S R (head dim) - gate = gate.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) # (N S head) r dim - q = q.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) # (N S head) r dim - k = k.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c).transpose(1, 2) # (N S head) dim r - v = v.reshape(bs, s, r, head, c).transpose(2, 3).reshape(bs * s * head, r, c) - sim = torch.bmm(q, k) * scale # (N S head) r r - sim = torch.nn.functional.softmax(sim, dim=-1) # (N S head) r r - sim = sim.reshape(bs, s, head, r, r) + bias # N S head r r, N S 1 r r - sim = sim.reshape(bs * s * head, r, r) # (N S head) r r - attend = torch.bmm(sim, v) * gate # (N S head) r dim - out = attend.reshape(bs, s, head, r, c).transpose(2, 3).reshape(bs, s, r, head * c) - out = torch.matmul(out, out_proj) - else: - gate = torch.sigmoid(torch.matmul(x, gate_proj)) - q, k, v = torch.matmul(x, qkv_proj).chunk(3, dim=-1) - gate = gate.reshape(bs, s, r, head, c).transpose(2, 3) - q = q.reshape(bs, s, r, head, c).transpose(2, 3) - k = k.reshape(bs, s, r, head, c).transpose(2, 3) - v = v.reshape(bs, s, r, head, c).transpose(2, 3) - assert s % chunk_size == 0 - out_chunks = [] - - def attention_bias(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - gate: torch.Tensor, bias: torch.Tensor, start: int): - cur_q = q[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_k = k[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c).transpose(1, 2) - cur_v = v[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - cur_gate = gate[:, start:start + chunk_size, :, :, :].reshape( - bs * chunk_size * head, r, c) - - sim = torch.bmm(cur_q, cur_k) * scale - sim = torch.nn.functional.softmax(sim, dim=-1) - sim = sim.reshape(bs, chunk_size, head, r, r) + bias - sim = sim.reshape(bs * chunk_size * head, r, r) - - attend = torch.bmm(sim, cur_v) * cur_gate - attend = attend.reshape(bs, chunk_size, head, r, c).transpose( - 2, 3).reshape(bs, chunk_size, r, cm) - return attend - - for start in range(0, s, chunk_size): - if is_train: - attend = ckpt.checkpoint(attention_bias, q, k, v, gate, bias, start) - else: - attend = attention_bias(q, k, v, gate, bias, start) - out_chunks.append(attend) - out = torch.matmul(torch.cat(out_chunks, dim=1), out_proj) - # nnscaler.profiler.CudaTimer().stop('msa_attn_bias') - return out - - -# note: code not reused constrained by cube's interface -@nnscaler.register_op('N S R^ M^, N R^ R^ Z^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^, Z^ head+ -> N S R^ M^', name='row_attn') -def row_attn(msa_repr: torch.Tensor, pair_repr: torch.Tensor, - gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, - bias_proj: torch.Tensor, head: int, c: int, - scale: float, chunk_size: int, is_train: bool): - # call: MSAAttentionWithBias - bs, s, r, cm = msa_repr.size() - # N R R Z, Z h -> N R R h -> N h S R -> N 1 h S R - bias = torch.matmul(pair_repr, bias_proj).permute(0, 3, 1, 2).reshape(bs, 1, head, r, r) - - return msa_attn_bias(msa_repr, gate_proj, qkv_proj, out_proj, bias, - head, c, scale, chunk_size, is_train) - - -@nnscaler.register_op('N S^ R M^, M^ (head+ dim), M^ (head+ dim 3), (head+ dim) M^ -> N S^ R M^', name='col_attn') -def col_attn(msa_repr: torch.Tensor, gate_proj: torch.Tensor, - qkv_proj: torch.Tensor, out_proj: torch.Tensor, head: int, - c: int, scale: float, chunk_size: int, is_train: bool): - # call: MSAAttention - msa_repr = msa_repr.permute(0, 2, 1, 3) - out = msa_attn( - msa_repr, gate_proj, qkv_proj, out_proj, - head, c, scale, chunk_size, is_train) - out = out.permute(0, 2, 1, 3) - return out - - -# @nnscaler.register_op('N S^ R^ M^, M^ (head+ dim^), M^ E^, M^ E^, M^ (head+ dim^), (head+ dim) M^ -> N S^ R^ M^', name='MSAColGlobalAttention') -def global_attn(msa_repr: torch.Tensor, q_proj: torch.Tensor, - k_proj: torch.Tensor, v_proj: torch.Tensor, - gate_proj: torch.Tensor, out_proj: torch.Tensor, - head: int, c: int, scale: float): - # [N R S M] - msa_repr = msa_repr.transpose(-2, -3) - - # [N R M] - q = torch.sum(msa_repr, dim=-2) - # [N R M] - q = torch.matmul(q, q_proj) * scale - # [N R H E] - q = q.view(q.shape[:-1] + (head, -1)) - - # [N R S E] - k, v = torch.matmul(msa_repr, k_proj), torch.matmul(msa_repr, v_proj) - - # N R H E, N R E S -> N R H S - a = torch.matmul(q, k.transpose(-1, -2)) - a = torch.nn.functional.softmax(a, dim=-1) - # [N R H E] - o = torch.matmul(a, v) - - # [N R S M] - g = torch.sigmoid(torch.matmul(msa_repr, gate_proj)) - # [N R S H E] - g = g.view(g.shape[:-1] + (head, -1)) - - # [N R 1 H E] - o = o.unsqueeze(-3) * g - # [N R S M] - o = o.reshape(o.shape[:-2] + (-1, )) - - return torch.matmul(o, out_proj).transpose(-2, -3) - - -@nnscaler.register_op('N S R M^, M^ E+, E+ M^ -> N S R M^', name='feedforward') -@torch.jit.ignore -def feedforward(msa_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): - """ - MSA transition - """ - # nnscaler.profiler.CudaTimer().start('ffn') - x = torch.matmul(msa_repr, proj1) - x = torch.nn.functional.relu(x) - x = torch.matmul(x, proj2) - # nnscaler.profiler.CudaTimer().stop('ffn') - return x - - -@nnscaler.register_op('N S R^ Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N R^ R^ head+ -> N S R^ Z^', name='tri_attn_start') -def tri_attn_start(pair_repr: torch.Tensor, - gate: torch.Tensor, qkv: torch.Tensor, - out: torch.Tensor, bias: torch.Tensor, - head: int, c: int, scale: float, - chunk_size: int, is_train: bool): - # bias = torch.matmul(pair_repr, bias).permute(0, 3, 1, 2).unsqueeze(1) - bias = bias.permute(0, 3, 1, 2).unsqueeze(1) - out = msa_attn_bias(pair_repr, gate, qkv, out, bias, - head, c, scale, chunk_size, is_train) - return out - - -@nnscaler.register_op('N S^ R Z^, Z^ (head+ dim), Z^ (head+ dim 3), (head+ dim) Z^, N S^ S^ head+ -> N S^ R Z^', name='tri_attn_end') -def tri_attn_end(pair_repr: torch.Tensor, - gate: torch.Tensor, qkv: torch.Tensor, - out: torch.Tensor, bias: torch.Tensor, - head: int, c: int, scale: float, chunk_size: int, is_train: bool): - # bias = torch.matmul(pair_repr, bias).permute(0, 3, 2, 1).unsqueeze(1) - bias = bias.permute(0, 3, 2, 1).unsqueeze(1) - pair_repr = pair_repr.permute(0, 2, 1, 3) - out = msa_attn_bias(pair_repr, gate, qkv, out, bias, - head, c, scale, chunk_size, is_train) - return out.permute(0, 2, 1, 3) - - -class MSARowAttention(torch.nn.Module): - """ - MSA Row Attention with Pair Bias - """ - - def __init__(self, hidden: int, heads: int, z: int, scale: float, chunk_size: int = -1): - super().__init__() - assert hidden % heads == 0 - self.heads = heads - self.dhead = hidden // heads - self.chunk_size = chunk_size - self.scale = scale - self.bias = torch.nn.Parameter(torch.empty(z, heads)) - self.gate = torch.nn.Parameter(torch.empty(hidden, hidden)) - self.qkv = torch.nn.Parameter(torch.empty(hidden, hidden * 3)) - self.out = torch.nn.Parameter(torch.empty(hidden, hidden)) - - def forward(self, msa_repr: torch.Tensor, pair_repr: torch.Tensor) -> torch.Tensor: - """ - msa_repr: [N S R M] - pair_repr: [N R R Z] - """ - out = row_attn( - msa_repr, pair_repr, self.gate, self.qkv, self.out, self.bias, - self.heads, self.dhead, self.scale, self.chunk_size, self.training - ) - return out - - -class MSAColAttention(torch.nn.Module): - """ - MSA Coloumn Attention (no bias) - """ - def __init__(self, hidden: int, heads: int, scale: float, chunk_size: int = -1) -> None: - super().__init__() - assert hidden % heads == 0 - self.heads = heads - self.dhead = hidden // heads - self.chunk_size = chunk_size - self.scale = scale - self.gate = torch.nn.Parameter(torch.empty(hidden, hidden)) - self.qkv = torch.nn.Parameter(torch.empty(hidden, hidden * 3)) - self.out = torch.nn.Parameter(torch.empty(hidden, hidden)) - - def forward(self, msa_repr: torch.Tensor) -> torch.Tensor: - """ - msa_repr: [N S R M] - """ - out = col_attn( - msa_repr, self.gate, self.qkv, self.out, - self.heads, self.dhead, self.scale,self.chunk_size, self.training - ) - return out - - -class Transition(torch.nn.Module): - """ - Feedforward for msa_repr and pair_repr - """ - def __init__(self, hidden: int, ff_mult: int = 4) -> None: - super().__init__() - self.proj1 = torch.nn.Parameter(torch.empty(hidden, ff_mult * hidden)) - self.proj2 = torch.nn.Parameter(torch.empty(ff_mult * hidden, hidden)) - - def forward(self, msa_repr: torch.Tensor) -> torch.Tensor: - """ - msa_repr: [N S R M] - """ - return feedforward(msa_repr, self.proj1, self.proj2) - - -class TriangleAttentionNodeStart(torch.nn.Module): - - def __init__(self, cz: int, pair_head: int, c: int, scale: float, chunk_size=-1) -> None: - super().__init__() - self.heads = pair_head - self.c = c - self.scale = scale - self.chunk_size = chunk_size - self.layer_norm = torch.nn.LayerNorm(cz) - self.gate = torch.nn.Parameter(torch.empty(cz, pair_head * c)) - self.qkv = torch.nn.Parameter(torch.empty(cz, 3 * pair_head * c)) - self.out = torch.nn.Parameter(torch.empty(pair_head * c, cz)) - self.bias = torch.nn.Parameter(torch.empty(cz, pair_head)) - - def forward(self, pair_repr: torch.Tensor): - """ - pair_repr: N R R cz - """ - pair_repr = self.layer_norm(pair_repr) - bias = torch.matmul(pair_repr, self.bias) - pair_repr = tri_attn_start( - pair_repr, self.gate, self.qkv, self.out, bias, - self.heads, self.c, self.scale, self.chunk_size, self.training - ) - return pair_repr - - -class TriangleAttentionNodeEnd(torch.nn.Module): - - def __init__(self, cz: int, pair_head: int, c: int, scale: float, chunk_size=-1) -> None: - super().__init__() - self.heads = pair_head - self.c = c - self.scale = scale - self.chunk_size = chunk_size - self.layer_norm = torch.nn.LayerNorm(cz) - self.gate = torch.nn.Parameter(torch.empty(cz, pair_head * c)) - self.qkv = torch.nn.Parameter(torch.empty(cz, 3 * pair_head * c)) - self.out = torch.nn.Parameter(torch.empty(pair_head * c, cz)) - self.bias = torch.nn.Parameter(torch.empty(cz, pair_head)) - - def forward(self, pair_repr: torch.Tensor): - pair_repr = self.layer_norm(pair_repr) - bias = torch.matmul(pair_repr, self.bias) - pair_repr = tri_attn_end( - pair_repr, self.gate, self.qkv, self.out, bias, - self.heads, self.c, self.scale, self.chunk_size, self.training - ) - return pair_repr - diff --git a/examples/openfold/blocks/embedder.py b/examples/openfold/blocks/embedder.py deleted file mode 100644 index 0a5f7080..00000000 --- a/examples/openfold/blocks/embedder.py +++ /dev/null @@ -1,230 +0,0 @@ -import torch -import torch.nn as nn - -from typing import Tuple, Optional - -import nnscaler - - - -@nnscaler.register_op('N res, cz nobins, cz -> N res res cz', name='relpos') -def input_embedder_pair_emb(ri: torch.Tensor, - tf_emb_i: torch.Tensor, tf_emb_j: torch.Tensor, - w_relpos: torch.Tensor, b_relpos: torch.Tensor, - relpos_k) -> torch.Tensor: - - ri = ri.type(tf_emb_i.dtype) - d = ri[..., None] - ri[..., None, :] - boundaries = torch.arange( - start=-relpos_k, end=relpos_k + 1, device=torch.cuda.current_device() - ) - reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),)) - d = d[..., None] - reshaped_bins - d = torch.abs(d) - d = torch.argmin(d, dim=-1) - d = nn.functional.one_hot(d, num_classes=len(boundaries)).float() - d = d.to(ri.dtype) - pair_emb = torch.nn.functional.linear(d, w_relpos, b_relpos) - - pair_emb = pair_emb + tf_emb_i[..., None, :] - pair_emb = pair_emb + tf_emb_j[..., None, :, :] - - return pair_emb - - -@nnscaler.register_op('N res tfdim^, cm tfdim^, cm -> N nclust^, res, cm') -def input_embedder_tf_m(tf: torch.Tensor, w_tf_m: torch.Tensor, b_tf_m: torch.Tensor, nclust: int) -> torch.Tensor: - tf_m = torch.nn.linear(tf, w_tf_m, b_tf_m) - tf_m = tf_m.unsqueeze(-3).expand(((-1,) * len(tf.shape[:-2]) + (nclust, -1, -1))) - return tf_m - - -class InputEmbedder(nn.Module): - """ - Embeds a subset of the input features. - - Implements Algorithms 3 (InputEmbedder) and 4 (relpos). - """ - - def __init__(self, tf_dim: int, msa_dim: int, c_z: int, c_m: int, relpos_k: int): - """ - Args: - tf_dim: - Final dimension of the target features - msa_dim: - Final dimension of the MSA features - c_z: - Pair embedding dimension - c_m: - MSA embedding dimension - relpos_k: - Window size used in relative positional encoding - """ - super().__init__() - - self.tf_dim = tf_dim - self.msa_dim = msa_dim - - self.c_z = c_z - self.c_m = c_m - - self.linear_tf_z_i = nn.Linear(tf_dim, c_z) - self.linear_tf_z_j = nn.Linear(tf_dim, c_z) - # self.linear_tf_m = nn.Linear(tf_dim, c_m) - self.w_tf_m = torch.nn.Parameter(torch.empty((c_m, tf_dim))) - self.b_tf_m = torch.nn.Parameter(torch.empty((c_m))) - self.linear_msa_m = nn.Linear(msa_dim, c_m) - self.w_tf_m - - # RPE stuff - self.relpos_k = relpos_k - self.no_bins = 2 * relpos_k + 1 - self.w_linear_relpos = torch.nn.Parameter(torch.empty((c_z, self.no_bins))) - self.b_linear_relpos = torch.nn.Parameter(torch.empty((c_z,))) - - def forward(self, tf: torch.Tensor, ri: torch.Tensor, msa: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - tf: - "target_feat" features of shape [*, N_res, tf_dim] - ri: - "residue_index" features of shape [*, N_res] - msa: - "msa_feat" features of shape [*, N_clust, N_res, msa_dim] - Returns: - msa_emb: - [*, N_clust, N_res, C_m] MSA embedding - pair_emb: - [*, N_res, N_res, C_z] pair embedding - - """ - # [*, N_res, c_z] - tf_emb_i = self.linear_tf_z_i(tf) - tf_emb_j = self.linear_tf_z_j(tf) - - # [*, N_res, N_res, c_z] - pair_emb = input_embedder_pair_emb( - ri, tf_emb_i, tf_emb_j, - self.w_linear_relpos, self.b_linear_relpos - ) - # pair_emb = relpos(ri.type(tf_emb_i.dtype)) - # pair_emb = pair_emb + tf_emb_i[..., None, :] - # pair_emb = pair_emb + tf_emb_j[..., None, :, :] - - # [*, N_clust, N_res, c_m] - tf_m = input_embedder_tf_m(tf, self.w_tf_m, self.b_tf_m) - msa_emb = self.linear_msa_m(msa) + tf_m - - return msa_emb, pair_emb - - - -@nnscaler.register_op() -def sum_d(x: torch.Tensor, bins: torch.Tensor, inf: float) -> torch.Tensor: - squared_bins = bins ** 2 - upper = torch.cat( - [squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1 - ) - d = torch.sum( - (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True - ) - d = ((d > squared_bins) * (d < upper)).type(x.dtype) - return d - - -class RecyclingEmbedder(nn.Module): - """ - Embeds the output of an iteration of the model for recycling. - - Implements Algorithm 32. - """ - def __init__(self, c_m: int, c_z: int, - min_bin: float, max_bin: float, no_bins: int, - inf: float = 1e8): - """ - Args: - c_m: - MSA channel dimension - c_z: - Pair embedding channel dimension - min_bin: - Smallest distogram bin (Angstroms) - max_bin: - Largest distogram bin (Angstroms) - no_bins: - Number of distogram bins - """ - super().__init__() - - self.c_m = c_m - self.c_z = c_z - self.min_bin = min_bin - self.max_bin = max_bin - self.no_bins = no_bins - self.inf = inf - - self.linear = nn.Linear(self.no_bins, self.c_z) - self.layer_norm_m = nn.LayerNorm(self.c_m) - self.layer_norm_z = nn.LayerNorm(self.c_z) - - bins = torch.linspace(self.min_bin, self.max_bin, self.no_bins, requires_grad=False) - self.register_buffer('bins', bins) - - def forward(self, m: torch.Tensor, z: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - m: - First row of the MSA embedding. [*, N_res, C_m] - z: - [*, N_res, N_res, C_z] pair embedding - x: - [*, N_res, 3] predicted C_beta coordinates - Returns: - m: - [*, N_res, C_m] MSA embedding update - z: - [*, N_res, N_res, C_z] pair embedding update - """ - m = self.layer_norm_m(m) - z = self.layer_norm_z(z) - d = sum_d(x, self.bins, self.inf) - d = self.linear(d) - z = z + d - return m, z - - -class TemplateAngleEmbedder(nn.Module): - """ - Embeds the "template_angle_feat" feature. - - Implements Algorithm 2, line 7. - """ - def __init__(self, c_in: int, c_out: int): - """ - Args: - c_in: - Final dimension of "template_angle_feat" - c_out: - Output channel dimension - """ - super(TemplateAngleEmbedder, self).__init__() - - self.c_out = c_out - self.c_in = c_in - - self.linear_1 = nn.Linear(self.c_in, self.c_out, init="relu") - self.relu = nn.ReLU() - self.linear_2 = nn.Linear(self.c_out, self.c_out, init="relu") - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: [*, N_templ, N_res, c_in] "template_angle_feat" features - Returns: - x: [*, N_templ, N_res, C_out] embedding - """ - x = self.linear_1(x) - x = self.relu(x) - x = self.linear_2(x) - return x - diff --git a/examples/openfold/blocks/evoformer.py b/examples/openfold/blocks/evoformer.py deleted file mode 100644 index c3b63977..00000000 --- a/examples/openfold/blocks/evoformer.py +++ /dev/null @@ -1,177 +0,0 @@ -from typing import Tuple -import torch -from examples.openfold.blocks.attention import MSARowAttention, MSAColAttention, Transition, TriangleAttentionNodeStart, TriangleAttentionNodeEnd -from examples.openfold.blocks.tmu import TriangleMultiplicativeUpdate -from examples.openfold.blocks.opm import OuterProducterMean -from examples.openfold.blocks.utils import multi2ref - -import math -import nnscaler - - -# @nnscaler.register_op('N S^ R^ cm^, N R^ R^ cz^ -> N out^') -# @torch.jit.ignore -# def input_packing(msa: torch.Tensor, pair: torch.Tensor, out: int) -> torch.Tensor: -# buffer = torch.cat((torch.flatten(msa, start_dim=1), torch.flatten(pair, start_dim=1))) -# return buffer -# -# -# @nnscaler.register_op('N out^ -> N S^ R^ cm^, N R^ R^ cz^', name='input_unflatten') -# @torch.jit.ignore -# def input_unpacking(buffer: torch.Tensor, -# S: int, R: int, cm: int, cz: int) -> Tuple[torch.Tensor, torch.Tensor]: -# msa_nele = S * R * cm -# msa = buffer[:,:msa_nele].reshape(buffer.size(0), S, R, cm) -# pair = buffer[:,msa_nele:].reshape(buffer.size(0), R, R, cz) -# return msa, pair - - -class Evoformer(torch.nn.Module): - """ - Simulate execution of evoformer in alphafold. - - The mask and dropout is ommited for simplicity. - """ - - def __init__(self, s: int, r: int, cm: int, cz: int, - use_chunk=False, is_train=True, - c=32, msa_head=8, pair_head=4, - c_tri_mult=128, ff_mult=4): - super().__init__() - - self.s, self.r, self.cm, self.cz, self.c = s, r, cm, cz, c - self.fout = self.s * self.r * self.cm + self.r * self.r * self.cz - self.msa_head, self.pair_head = msa_head, pair_head - self.c_tri_mult, self.ff_mult = c_tri_mult, ff_mult - self.scale = 1.0 / math.sqrt(c) - - self.is_train = is_train - - self.msa_row_chunk = 4 if use_chunk else -1 - self.msa_col_chunk = -1 - self.opm_chunk = self.tans_chunk = self.tane_chunk = -1 - - # MSA row-wise gated self-attention with pair bias - self.row_norm_m = torch.nn.LayerNorm(cm) - self.row_norm_z = torch.nn.LayerNorm(cz) - self.row_attn = MSARowAttention(cm, msa_head, cz, self.scale, self.msa_row_chunk) - - # MSA column-wise gated self-attention - self.col_norm = torch.nn.LayerNorm(cm) - self.col_attn = MSAColAttention(cm, msa_head, self.scale, self.msa_col_chunk) - - # MSA transition - self.msa_transition_norm = torch.nn.LayerNorm(cm) - self.msa_transition = Transition(cm, ff_mult) - - # Outer product mean - self.outer_norm = torch.nn.LayerNorm(cm) - self.outer_prod_mean = OuterProducterMean(cm, c, cz, self.opm_chunk) - - # Triangular multiplicative update using outgoing edges - self.tmo = TriangleMultiplicativeUpdate(cz, c_tri_mult, outgoing=True) - - # Triangular multiplicative update using incoming edges - self.tmi = TriangleMultiplicativeUpdate(cz, c_tri_mult, outgoing=False) - - # Triangular gated self-attention around starting node - self.tri_attn_node_start = TriangleAttentionNodeStart(cz, pair_head, c, self.scale, self.tans_chunk) - - # Triangular gated self-attention around ending node - self.tri_attn_node_end = TriangleAttentionNodeEnd(cz, pair_head, c, self.scale, self.tane_chunk) - - # Transition in the pair stack - self.pair_transition_norm = torch.nn.LayerNorm(cz) - self.pair_transition = Transition(cz, ff_mult) - - def forward(self, msa_repr, pair_repr): - - nnscaler.runtime.function.anchor('MSARow') - pair_repr, dummy_pair_repr = multi2ref(pair_repr) - residual = msa_repr - msa_repr = self.row_norm_m(msa_repr) - dummy_pair_repr = self.row_norm_z(dummy_pair_repr) - msa_repr = residual + self.row_attn(msa_repr, dummy_pair_repr) - - nnscaler.runtime.function.anchor('MSACol') - residual = msa_repr - msa_repr = self.col_norm(msa_repr) - msa_repr = residual + self.col_attn(msa_repr) - - # nnscaler.runtime.function.anchor('MSATrans') - residual = msa_repr - msa_repr = self.msa_transition_norm(msa_repr) - msa_repr = self.msa_transition(msa_repr) - msa_repr = residual + msa_repr - succ_msa_repr, msa_repr = multi2ref(msa_repr) - - nnscaler.runtime.function.anchor('OPM') - msa_repr = self.outer_norm(msa_repr) - pair_repr = pair_repr + self.outer_prod_mean(msa_repr) - - nnscaler.runtime.function.anchor('TMO') - pair_repr = self.tmo(pair_repr) - - nnscaler.runtime.function.anchor('TMI') - pair_repr = self.tmi(pair_repr) - - nnscaler.runtime.function.anchor('TANS') - residual = pair_repr - pair_repr = self.tri_attn_node_start(pair_repr) - pair_repr = residual + pair_repr - - nnscaler.runtime.function.anchor('TANE') - residual = pair_repr - pair_repr = self.tri_attn_node_end(pair_repr) - pair_repr = residual + pair_repr - - nnscaler.runtime.function.anchor('PairTrans') - residual = pair_repr - pair_repr = self.pair_transition_norm(pair_repr) - pair_repr = self.pair_transition(pair_repr) - pair_repr = residual + pair_repr - - return succ_msa_repr, pair_repr - - def tflops(self, n_seq: int, n_res: int) -> float: - """ - Single sample tflops - """ - msa_size = n_seq * n_res * self.cm - pair_size = n_seq * n_res * self.cz - flops = 0 - - # msa layer norm - flops += 4 * (msa_size * 4) - # pair layer norm - flops += 2 * (pair_size * 4) - - # attention: gate + qkv + q@k (N S head r c, N S head c r) + k@v + dense - msa_attn = n_seq * n_res * self.cm * self.cm + \ - 3 * n_seq * n_res * self.cm * self.cm + \ - n_seq * (self.cm // self.c) * n_res * n_res * self.c + \ - n_seq * (self.cm // self.c) * n_res * n_res * self.c + \ - n_seq * n_res * self.cm * self.cm - - pair_attn = n_res * n_res * self.cz * self.cz + \ - 3 * n_res * n_res * self.cz * self.cz + \ - n_res * (self.cz // self.c) * n_res * n_res * self.c + \ - n_res * (self.cz // self.c) * n_res * n_res * self.c + \ - n_res * n_res * self.cz * self.cz - - # row and col end attention - flops += 2 * msa_attn - # tirangle start and triangle end - flops += 2 * pair_attn - # msa and pair transition flops - flops += 8 * n_seq * n_res * (self.cm ** 2) + \ - 8 * n_res * n_res * (self.cz ** 2) - # pair_repr tmi and tmo: projection + gate + 2 matmul - flops += 2 * (n_res * n_res * self.cz * self.c_tri_mult) + \ - n_res * n_res * self.cz * self.cz + \ - self.c_tri_mult * n_res * n_res * n_res + n_res * n_res * self.c_tri_mult * self.cz - # opm: left + right + opm - flops += 2 * n_seq * n_res * self.cm * self.cz + \ - n_res * n_res * n_seq * self.c * self.c + \ - n_res * n_res * self.c * self.c * self.cz - return flops / 1e12 diff --git a/examples/openfold/blocks/opm.py b/examples/openfold/blocks/opm.py deleted file mode 100644 index 8d4e8271..00000000 --- a/examples/openfold/blocks/opm.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -Outer Product Mean module for Evoformer -""" - -import nnscaler -import torch -import torch.utils.checkpoint as ckpt - - -# @nnscaler.register_op('N S+ R^ M^, M^ c^, M^ c^, (c^ c^) cz^ -> N R^ R^ cz^', name='outer_prod_mean') -@torch.jit.ignore -def outer_prod_mean(msa_repr: torch.Tensor, left_proj: torch.Tensor, right_proj: torch.Tensor, - out_proj: torch.Tensor, chunk_size: int, training: bool): - # nnscaler.profiler.CudaTimer().start('opm') - # N S R M, M c -> N S R c - opm_left = torch.matmul(msa_repr, left_proj) - # N S T M, M c -> N S T c - opm_right = torch.matmul(msa_repr, right_proj) - bs, s, r, c = opm_left.size() - t = opm_right.size(2) - - # N S R M -> N R S M - a = opm_left.transpose(-2, -3) - # N S T M -> N T S M - b = opm_right.transpose(-2, -3) - - if chunk_size == -1: - # N R S M, N T S M -> N R T M M -> N R T (M M) - outer = torch.einsum('...bac,...dae->...bdce', a, - b).reshape(bs, r, t, c * c) - # N R T (M M), (M M) Z -> N R T Z - outer = torch.matmul(outer, out_proj) - else: - out_chunks = [] - - def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): - lhs_slice = lhs[:, start:start + chunk_size, :, :] - out = torch.einsum('...bac,...dae->...bdce', lhs_slice, - rhs).reshape(bs, chunk_size, t, c * c) - out = torch.matmul(out, out_proj) - return out - - for start in range(0, r, chunk_size): - ret = ckpt.checkpoint(opm, a, b, start) - ret = opm(a, b, start) - out_chunks.append(ret) - outer = torch.cat(out_chunks, dim=1) - # nnscaler.profiler.CudaTimer().stop('opm') - return outer - - -@nnscaler.register_op('N S R M+, M+ C -> N S R C', name='opm_projection') -def opm_projection(msa_repr: torch.Tensor, proj1: torch.Tensor): - x = torch.matmul(msa_repr, proj1) - return x - - -@nnscaler.register_op('N S^ R C^, N S^ T^ C^, F^ Z^ -> N R T^ Z^') -@torch.jit.ignore -def opm(left: torch.Tensor, right: torch.Tensor, out_proj: torch.Tensor, - chunk_size: int, training: bool): - bs, s, r, c = left.size() - t = right.size(2) - # N S R C -> N R S C - a = left.transpose(-2, -3) - # N S T C -> N T S C - b = right.transpose(-2, -3) - - if chunk_size == -1: - # N R S M, N T S M -> N R T M M -> N R T (M M) - outer = torch.einsum('...bac,...dae->...bdce', a, - b).reshape(bs, r, t, c * c) - # N R T (M M), (M M) Z -> N R T Z - outer = torch.matmul(outer, out_proj) - else: - out_chunks = [] - - def opm(lhs: torch.Tensor, rhs: torch.Tensor, start: int): - lhs_slice = lhs[:, start:start + chunk_size, :, :] - out = torch.einsum('...bac,...dae->...bdce', lhs_slice, - rhs).reshape(bs, chunk_size, t, c * c) - out = torch.matmul(out, out_proj) - return out - - for start in range(0, r, chunk_size): - ret = ckpt.checkpoint(opm, a, b, start) - ret = opm(a, b, start) - out_chunks.append(ret) - outer = torch.cat(out_chunks, dim=1) - # nnscaler.profiler.CudaTimer().stop('opm') - return outer - - -class OuterProducterMean(torch.nn.Module): - - def __init__(self, cm: int, c: int, cz: int, chunk_size: int) -> None: - super().__init__() - self.left = torch.nn.Parameter(torch.empty(cm, c)) - self.right = torch.nn.Parameter(torch.empty(cm, c)) - self.out = torch.nn.Parameter(torch.empty(c * c, cz)) - self.chunk_size = chunk_size - - def forward(self, msa_repr: torch.Tensor): - """ - msa_repr: [N S R M] - """ - left = opm_projection(msa_repr, self.left) - right = opm_projection(msa_repr, self.right) - out = opm(left, right, self.out, self.chunk_size, self.training) - return out - # return outer_prod_mean( - # msa_repr, self.left, self.right, self.out, - # self.chunk_size, self.training - # ) diff --git a/examples/openfold/blocks/tmu.py b/examples/openfold/blocks/tmu.py deleted file mode 100644 index 32d0550e..00000000 --- a/examples/openfold/blocks/tmu.py +++ /dev/null @@ -1,98 +0,0 @@ -import nnscaler -import torch -from examples.openfold.blocks.utils import multi2ref - - -# @nnscaler.register_op('N S R Z^, Z^ E, Z^ E -> N S R E') -# def tmu_projection(pair_repr: torch.Tensor, proj1: torch.Tensor, proj2: torch.Tensor): -# x = torch.matmul(pair_repr, proj1) -# x = torch.sigmoid(x) -# x = x * torch.matmul(pair_repr, proj2) -# -# -# @nnscaler.register_op('N S R Z+, Z+ E-> N S R E') -# def tmu_gate(pair_repr: torch.Tensor, proj: torch.Tensor): -# return torch.sigmoid(torch.matmul(pair_repr, proj)) - - -@nnscaler.register_op('N S R Z^, Z^ E^, Z^ E^, Z^ E, Z^ E^, Z^ Z^ -> N S R E, N S R E^, N S R Z^', name='tmu_projection') -def tmu_projection(pair_repr: torch.Tensor, - left1: torch.Tensor, left2: torch.Tensor, - right1: torch.Tensor, right2: torch.Tensor, - gate: torch.Tensor): - # left - left = torch.matmul(pair_repr, left1) - left = torch.sigmoid(left) - left = left * torch.matmul(pair_repr, left2) - # right - right = torch.matmul(pair_repr, right1) - right = torch.sigmoid(right) - right = right * torch.matmul(pair_repr, right2) - # gate - gate = torch.sigmoid(torch.matmul(pair_repr, gate)) - - return left, right, gate - - -@nnscaler.register_op('N S R^ E, N T^ R^ E^, N S^ T^ Z^, E^, E^, E^ Z^ -> N S T^ Z^', name='tmo') -def tmo(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, - norm_w: torch.Tensor, norm_b: torch.Tensor, out: torch.Tensor): - a = left.permute(0, 3, 1, 2) - b = right.permute(0, 3, 2, 1) - p = torch.matmul(a, b).permute(0, 2, 3, 1) - p = torch.nn.functional.layer_norm(p, (128, ), norm_w, norm_b) - p = torch.matmul(p, out) - p = p * gate - return p - - -@nnscaler.register_op('N R^ S E, N R^ T^ E^, N T^ S^ Z^, E^, E^, E^ Z^ -> N T^ S Z^', name='tmi') -def tmi(left: torch.Tensor, right: torch.Tensor, gate: torch.Tensor, - norm_w: torch.Tensor, norm_b: torch.Tensor, out: torch.Tensor): - a = left.permute(0, 3, 2, 1) - b = right.permute(0, 3, 1, 2) - p = torch.matmul(a, b).permute(0, 2, 3, 1) - p = torch.nn.functional.layer_norm(p, (128, ), norm_w, norm_b) - p = torch.matmul(p, out) - p = p.permute(0, 2, 1, 3) * gate - return p - - -class TriangleMultiplicativeUpdate(torch.nn.Module): - - def __init__(self, cz: int, mult: int, outgoing: bool) -> None: - super().__init__() - self.layer_norm = torch.nn.LayerNorm((cz,)) - - self.left1 = torch.nn.Parameter(torch.empty(cz, mult)) - self.left2 = torch.nn.Parameter(torch.empty(cz, mult)) - self.right1 = torch.nn.Parameter(torch.empty(cz, mult)) - self.right2 = torch.nn.Parameter(torch.empty(cz, mult)) - - # self.norm = torch.nn.LayerNorm(mult) - self.normw = torch.nn.Parameter(torch.empty(mult)) - self.normb = torch.nn.Parameter(torch.empty(mult)) - - self.out = torch.nn.Parameter(torch.empty(mult, cz)) - self.gate = torch.nn.Parameter(torch.empty(cz, cz)) - self.outgoing = outgoing - - def forward(self, pair_repr: torch.Tensor): - """ - pair_repr: [N S R Z] - """ - residual = pair_repr - pair_repr = self.layer_norm(pair_repr) - - left, right, gate = tmu_projection(pair_repr, - self.left1, self.left2, - self.right1, self.right2, self.gate - ) - - if self.outgoing: - pair_repr = tmo(left, right, gate, self.normw, self.normb, self.out) - else: - pair_repr = tmi(left, right, gate, self.normw, self.normb, self.out) - - pair_repr = residual + pair_repr - return pair_repr diff --git a/examples/openfold/blocks/utils.py b/examples/openfold/blocks/utils.py deleted file mode 100644 index 7bb61f49..00000000 --- a/examples/openfold/blocks/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -import nnscaler -import torch - - -@nnscaler.register_op('* -> *, *', name='multi2ref') -def multi2ref(x: torch.Tensor): - return (x, x) \ No newline at end of file diff --git a/examples/openfold/model.py b/examples/openfold/model.py deleted file mode 100644 index 3403b3c2..00000000 --- a/examples/openfold/model.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Alphafold 2, using implementation similar with OpenFold. -""" -import torch -import torch.nn as nn - -# from examples.openfold.blocks.embedder import InputEmbedder, RecyclingEmbedder, TemplateAngleEmbedder -from examples.openfold.blocks.evoformer import Evoformer -# from examples.openfold.blocks.evoformer import input_packing, input_unpacking - -from dataclasses import dataclass - -import nnscaler - - -@dataclass -class Config: - - # input_embedder - # input_embedder_cm = 256 - # input_embedder_cz = 128 - # input_embedder_msa_dim = 49 - # input_embedder_relpos_k = 32 - # input_embedder_tf_dim = 22 - - # recycling embedder - # recycling_embedder_cm = 256 - # recycling_embedder_cz = 128 - # recycling_embedder_inf = 1000000000.0 - # recycling_embedder_maxbin = 20.75 - # recycling_embedder_minbin = 3.25 - # recycling_embedder_nobins = 15 - - # templates - # template_angle_embedder_cin = 57 - # template_angle_embedder_cout = 256 - # template_pair_embedder_cin = 88 - # template_pair_embedder_cout = 64 - # template_pair_stack_hidden_tri_att = 16 - # template_pair_stack_hidden_tri_mul = 64 - # template_pair_stack_ct = 64 - # template_pair_stack_dp = 0.25 - # template_pair_stack_inf = 100000000.0 - # template_pair_stack_noblocks = 2 - # template_pair_stack_noheads = 4 - # template_pair_stack_pair_transition_n = 2 - # template_pointwise_attention_hidden = 16 - # template_pointwise_attention_ct = 64 - # template_pointwise_attention_cz = 128 - # template_pointwise_inf = 1000000000.0 - # template_pointwise_noheads = 4 - - # extra msa - # extra_msa_embedder_cin = 25 - # extra_msa_embedder_cout = 64 - # extra_msa_stack_hidden_att = 8 - # extra_msa_stack_hidden_mul = 128 - # extra_msa_stack_opm = 32 - # extra_msa_stack_pair_att = 32 - # extra_msa_stack_cm = 64 - # extra_msa_stack_cz = 128 - # extra_msa_stack_eps = 1e-8 - # extra_msa_stack_inf = 1000000000.0 - # extra_msa_dp = 0.15 - # extra_msa_stack_noblocks = 4 - # extra_msa_stack_no_heads_msa = 8 - # extra_msa_stack_no_heads_pair = 4 - # extra_msa_stack_pair_dp = 0.25 - # extra_msa_stack_transition_n = 4 - - # evoformer - evoformer_s: int = 128 - evoformer_r: int = 256 - evoformer_cm: int = 256 - evoformer_cz: int = 128 - evoformer_c: int = 32 - evoformer_use_chunk: bool = False - evoformer_is_extra: bool = False - evoformer_nlayers: int = 4 - - # batch size - bs: int = 1 - - -class AlphaFold(nn.Module): - - - def __init__(self, cfg: Config = Config()) -> None: - super().__init__() - self.cfg = cfg - - # self.input_embedder = InputEmbedder( - # cfg.input_embedder_tf_dim, cfg.input_embedder_msa_dim, - # cfg.input_embedder_cz, cfg.input_embedder_cm, - # cfg.input_embedder_relpos_k - # ) - # self.recycling_embedder = RecyclingEmbedder( - # cfg.recycling_embedder_cm, cfg.recycling_embedder_cz, - # cfg.recycling_embedder_minbin, cfg.recycling_embedder_maxbin, - # cfg.recycling_embedder_nobins, cfg.recycling_embedder_inf - # ) - - # template config - # self.template_angle_embedder = TemplateAngleEmbedder( - # cfg.template_angle_embedder_cin, - # cfg.template_angle_embedder_cout - # ) - # self.template_pair_embedder = nn.Linear( - # cfg.template_pair_embedder_cin, - # cfg.template_pair_embedder_cout, - # ) - self.template_pair_stack = None # TemplatePairStack() - self.template_pointwise_att = None # TemplatePointwiseAttention() - - # extra msa - # self.extra_msa_embedder = nn.Linear( - # cfg.extra_msa_embedder_cin, cfg.extra_msa_embedder_cout - # ) - - self.extra_msa_stack = None # ExtraMSAStack() - - # evoformer - self.s, self.r, self.cm, self.cz = cfg.evoformer_s, cfg.evoformer_r, cfg.evoformer_cm, cfg.evoformer_cz - self.c = self.cfg.evoformer_c - assert self.cm % self.c == 0 and self.cz % self.c == 0 - - self.msa_norm = torch.nn.LayerNorm(cfg.evoformer_cm) - self.pair_norm = torch.nn.LayerNorm(cfg.evoformer_cz) - self.evoformers = nn.ModuleList( - [Evoformer( - self.s, self.r, self.cm, self.cz, - c=self.c, msa_head=self.cm // self.c, pair_head=self.cz // self.c, - ) for _ in range(cfg.evoformer_nlayers)] - ) - - self.structure_module = None # StructureModule() - self.aux_heads = None # AuxiliaryHeads() - - def forward(self, msa, pair): - """ - msa: [N S R cm] - pair: [N R R cz] - """ - msa = self.msa_norm(msa) - pair = self.pair_norm(pair) - # nnscaler.runtime.function.anchor('PackingRegion') - # x = input_packing(msa, pair, self.fout) - for evoformer in self.evoformers: - nnscaler.runtime.function.anchor('Evoformer Start') - msa, pair = evoformer(msa, pair) - # x = evoformer(x) - # msa, pair = input_unpacking(x, self.s, self.r, self.cm, self.cz) - loss = torch.sum(msa) * torch.sum(pair) - return loss - - - def tflops(self) -> float: - """ - TFLOPs for one sample - """ - tflops = 0. - for layer in self.evoformers: - tflops += layer.tflops(self.s, self.r) - return tflops diff --git a/examples/openfold/policy/mpmd.py b/examples/openfold/policy/mpmd.py deleted file mode 100644 index 26353f1c..00000000 --- a/examples/openfold/policy/mpmd.py +++ /dev/null @@ -1,313 +0,0 @@ -from typing import List - -from nnscaler.graph import IRGraph -from nnscaler.ir.cten import IRCell -from nnscaler.graph.function.anchor import IRGraphAnchor -from nnscaler.graph.segment import IRSegment -from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation -from nnscaler.graph.schedule.schednf1b import IRScheduleNF1B -from nnscaler.graph.schedule.sched1f1b import IRSchedule1F1B - -import more_itertools -import numpy as np - - -def _group_to_evoformers(fnodes) -> List[List[IRCell]]: - # group to evoformer layers - evoformers: List[List[IRFwOperation]] = [] - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor) and node.name == 'Evoformer Start'] - indices = [fnodes.index(anchor) for anchor in anchors] - for lid, idx in enumerate(indices): - # get first forward op - for fnode in fnodes[idx+1:]: - if not isinstance(fnode, IRGraphAnchor): break - fnode.comment = f'===> start of evoformer layer {lid}' - start = idx if lid != 0 else 0 - end = indices[lid+1] if lid + 1 < len(anchors) else len(fnodes) - evoformers.append(fnodes[start:end]) - print(f'find {len(indices)} evoformer layers') - return evoformers - -# ========================= parallelisms ================================= - -# tensor parallelism -def _tp(graph: IRGraph, node: IRFwOperation, devs: List[int], tag='dim', **config): - if len(devs) == 1: - sub_nodes = [node] - else: - algo = node.algorithms(tag) - sub_nodes = graph.partition(node, algo, num=len(devs), **config) - assert sub_nodes is not None - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, int(devid)) - return sub_nodes - -# replicate -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]): - if len(devs) == 1: - sub_nodes = [node] - else: - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, int(devid)) - return sub_nodes - - -# ========================= policies ================================= - - -def PASSingle(graph: IRGraph, resource): - assert resource.ngpus == 1 - # print(graph.extra_repr()) - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - return graph - - -def PASDP(graph: IRGraph, resource): - dp_size = resource.ngpus - dp_devs = list(range(dp_size)) - - dataloader = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[0] - - # partition dataloader - dls = graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) - for devid, dl in enumerate(dls): - graph.assign(dl, devid) - - # partition forward operators - for node in graph.select(ntype=IRFwOperation): - if len(node.inputs()) == 0: continue - #FIXME: a workaround to find batch dimension - batch_dim = node.input(0).shape.index(bs) - _tp(graph, node, dp_devs, idx=0, dim=batch_dim) - - return graph - - -def PASDAP(graph: IRGraph, resource, tp: int): - - assert resource.ngpus % tp == 0 - dp = resource.ngpus // tp - - devmesh = np.arange(resource.ngpus).reshape(dp, tp) - tp_devs = list(range(tp)) - - # grouping - evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) - for layer in evoformers: - graph.recompute(layer) - - dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[dataloader.get_batch_dims()[0]] - print(f'> get batch size: {bs}') - dls: List[IRDataOperation] = _replica(graph, dataloader, tp_devs) - for tp_idx, dl in enumerate(dls): - dp_devs = devmesh[:,tp_idx] - _tp(graph, dl, dp_devs, 'data') - - - fnodes = graph.select(ntype=IRFwOperation) - fnodes = [fnode for fnode in fnodes if fnode.name != 'Evoformer Start'] - - node_groups = more_itertools.split_at(fnodes, lambda n: isinstance(n, IRGraphAnchor)) - - for nodes in node_groups: - # tensor parallelism - names = set(n.name for n in nodes) - subnodes = [] - if len(names) == 1 or 'mul' in names: # for first layer norm operators - for node in nodes: - subnodes.append(_replica(graph, node, tp_devs)) - # elif 'input_packing' in names: - # for node in nodes: - # subnodes.append(_replica(graph, node, tp_devs)) - elif 'row_attn' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) - elif 'col_attn' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) - elif 'opm' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) - elif 'tmo' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) - elif 'tmi' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) - elif 'tri_attn_start' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) - elif 'tri_attn_end' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) - elif 'feedforward' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) - else: - assert False, names - # data parallelism - for ns in subnodes: - for tp_idx, subnode in enumerate(ns): - dp_devs = devmesh[:,tp_idx] - if bs in subnode.input(0).shape: - dim = subnode.input(0).shape.index(bs) - _tp(graph, subnode, dp_devs, idx=0, dim=dim) - else: - print(f'replicate op on data parallel group: {node.name}') - _replica(graph, subnode, dp_devs) - - return graph - - -def PASRoundRobin(graph: IRGraph, resource): - - pp_size = resource.ngpus - - # grouping - evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) - for layer in evoformers: - graph.recompute(layer) - - - fstages = [[] for _ in range(pp_size)] - nlayer_per_stage = len(evoformers) // pp_size - for lid, fnodes in enumerate(evoformers): - sid = min(lid // nlayer_per_stage, pp_size - 1) - fstages[sid] += fnodes - graph.staging(tuple(stages[0] for stages in fstages)) - - dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] - graph.assign(dataloader, 0) - - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - for sid, fstage in enumerate(fstages): - graph.assign(fstage, sid) - - return graph - - -def PASNF1B(graph: IRGraph, resource, mbs: int, gbs: int, recycle: int): - - assert gbs % mbs == 0 - nmbs = gbs // mbs - pp_size = resource.ngpus - - # grouping - evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) - assert len(evoformers) % pp_size == 0 - for layer in evoformers: - graph.recompute(layer) - - fstages = [[] for _ in range(pp_size)] - nlayer_per_stage = len(evoformers) // pp_size - for lid, fnodes in enumerate(evoformers): - sid = min(lid // nlayer_per_stage, pp_size - 1) - fstages[sid] += fnodes - graph.staging(tuple(stages[0] for stages in fstages)) - - dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] - graph.assign(dataloader, 0) - - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - for sid, fstage in enumerate(fstages): - graph.assign(fstage, sid) - - strategy = IRSchedule1F1B(graph, nmbs) - graph.predef_sched(strategy) - - return graph - - -def PASDAPPipe(graph: IRGraph, resource, mbs: int, gbs: int, tp: int, pp: int, recycle: int): - - assert gbs % mbs == 0 - assert resource.ngpus % (pp * tp) == 0 - dp = resource.ngpus // (pp * tp) - nmbs = gbs // mbs - - devmesh = np.arange(resource.ngpus, dtype=int).reshape(dp, pp, tp) - tp_devs = [0] * tp # dummy device, which will be reset at dp - - - # grouping - evoformers = _group_to_evoformers(graph.select(ntype=IRFwOperation)) - assert len(evoformers) % pp == 0 - for layer in evoformers: - graph.recompute(layer) - - fstages = [[] for _ in range(pp)] - nlayer_per_stage = len(evoformers) // pp - for lid, fnodes in enumerate(evoformers): - sid = min(lid // nlayer_per_stage, pp - 1) - fstages[sid] += fnodes - graph.staging(tuple(stages[0] for stages in fstages)) - - # setup dataloader - dataloader: IRDataOperation = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[dataloader.get_batch_dims()[0]] - print(f'> get batch size: {bs}') - dls: List[IRDataOperation] = _replica(graph, dataloader, tp_devs) - for tp_idx, dl in enumerate(dls): - dp_devs = devmesh[:, 0, tp_idx] - _tp(graph, dl, dp_devs, 'data') - - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - assert len(fstages) > 0 - for sid, fstage in enumerate(fstages): - fnodes = fstage.select(ntype=IRFwOperation) - fnodes = [fnode for fnode in fnodes if fnode.name != 'Evoformer Start'] - node_groups = more_itertools.split_at(fnodes, lambda n: isinstance(n, IRGraphAnchor)) - for nodes in node_groups: - # tensor parallelism - names = set(n.name for n in nodes) - subnodes = [] - if len(names) == 1 or 'mul' in names: # for first layer norm operators - for node in nodes: - subnodes.append(_replica(graph, node, tp_devs)) - elif 'row_attn' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) - elif 'col_attn' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) - elif 'opm' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) - elif 'tmo' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) - elif 'tmi' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) - elif 'tri_attn_start' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) - elif 'tri_attn_end' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=2)) - elif 'feedforward' in names: - for node in nodes: - subnodes.append(_tp(graph, node, tp_devs, idx=0, dim=1)) - else: - assert False, names - # data parallelism - for ns in subnodes: - for tp_idx, subnode in enumerate(ns): - dp_devs = devmesh[:, sid, tp_idx] - if bs in subnode.input(0).shape: - dim = subnode.input(0).shape.index(bs) - _tp(graph, subnode, dp_devs, idx=0, dim=dim) - else: - print(f'replicate op on data parallel group: {node.name}') - _replica(graph, subnode, dp_devs) - - strategy = IRScheduleNF1B(graph, nmbs, recycle) - # strategy = IRSchedule1F1B(graph, nmbs) - graph.predef_sched(strategy) - - return graph diff --git a/examples/openfold/train.py b/examples/openfold/train.py deleted file mode 100644 index 61417954..00000000 --- a/examples/openfold/train.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=1 \ - examples/openfold/train.py --fp16 --layers 24 --gbs 1 --recycle 2 -""" - - -import torch -from examples.openfold.model import AlphaFold, Config - -import nnscaler -from nnscaler.compiler import compile, SemanticModel -from nnscaler.profiler.timer import CudaTimer, print_each_rank -from nnscaler.profiler.memory import memory_summary -from examples.openfold.policy.mpmd import PASDAP, PASRoundRobin, PASNF1B, PASDAPPipe - -import argparse -from functools import partial - - -nnscaler.init() - -parser = argparse.ArgumentParser(description='AlphaFold Train') -parser.add_argument('--fp16', action='store_true', default=False, - help='use fp16 for the training') -parser.add_argument('--layers', type=int, default=4, - help='evoformer layer number') -parser.add_argument('--msa-hidden', type=int, default=256, - help='cm value') -parser.add_argument('--pair-hidden', type=int, default=128, - help='cz value') -parser.add_argument('--head-dim', type=int, default=32, - help='c value') -parser.add_argument('--mbs', type=int, default=1, - help='micro batch size') -parser.add_argument('--gbs', type=int, default=1, - help='global batch size') -parser.add_argument('--tp', type=int, default=1, - help='tensor parallelism size') -parser.add_argument('--pp', type=int, default=1, - help='data parallelism size') -parser.add_argument('--recycle', type=int, default=2, - help='data parallelism size') - -args = parser.parse_args() -dp = nnscaler.runtime.device.DeviceGroup().world_size // (args.tp * args.pp) -assert args.gbs % args.mbs == 0 -assert args.mbs % dp == 0 -assert args.msa_hidden % args.head_dim == 0 -assert args.pair_hidden % args.head_dim == 0 - - -# PASDAP = partial(PASDAP, tp=args.tp) -PASNF1B = partial(PASNF1B, mbs=args.mbs, gbs=args.gbs, recycle=1) -PASDAPPipe = partial(PASDAPPipe, mbs=args.mbs, gbs=args.gbs, tp=args.tp, pp=args.pp, recycle=args.recycle) - - -def nparams(model) -> int: - cnt = 0 - for param in model.parameters(): - cnt += param.nelement() - return cnt - - -def train(): - - cfg = Config(evoformer_cm=args.msa_hidden, evoformer_cz=args.pair_hidden, - evoformer_c=args.head_dim, evoformer_nlayers=args.layers, - bs=args.mbs) - print_each_rank(cfg, rank_only=0) - - model = AlphaFold(cfg) - print_each_rank(f'iteration total TFLOPs: {model.tflops() * (args.recycle + 1 + 2)}') - if args.fp16: - model = model.half() - - dtype = torch.float16 if args.fp16 else torch.float32 - dataloader = nnscaler.runtime.syndata.SynDataLoader( - shapes=([cfg.bs, cfg.evoformer_s, cfg.evoformer_r, cfg.evoformer_cm], - [cfg.bs, cfg.evoformer_r, cfg.evoformer_r, cfg.evoformer_cz]), - dtypes=(dtype, dtype), - batch_dims=(0, 0) - ) - - print_each_rank(f'before partitioned model parameter: {nparams(model)}') - - model = SemanticModel(model) - @compile(model, dataloader, PAS=PASDAPPipe, override=True, load_content=True) - def train_iter(model, dataloader): - input_ids, position_ids = next(dataloader) - loss = model(input_ids, position_ids) - loss.backward() - model = model.get_gen_module() - - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - print_each_rank(f'after partitioned model parameter: {nparams(model)}') - - torch.distributed.barrier() - print_each_rank('model weight consumpition:', rank_only=0) - memory_summary() - - CudaTimer(enable=False).warmup() - iter_num, warmup = 5, 2 - for step in range(iter_num): - if step == warmup: - CudaTimer(enable=True, predefined=True).start('e2e') - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - - if step == 0: - print_each_rank('passed first iteration') - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - CudaTimer().stop('e2e') - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) - - memory_summary() - -train() \ No newline at end of file diff --git a/examples/policies/__init__.py b/examples/policies/__init__.py deleted file mode 100644 index 749781c0..00000000 --- a/examples/policies/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from examples.policies.gshard import PASGShard -from examples.policies.random_spmd import PASRandomSPMD -from examples.policies.alpa import PASAlpa \ No newline at end of file diff --git a/examples/policies/alpa/README.md b/examples/policies/alpa/README.md deleted file mode 100644 index 5c847cea..00000000 --- a/examples/policies/alpa/README.md +++ /dev/null @@ -1,26 +0,0 @@ - -# Alpa Implementation - -## Prerequisite - -```sh -pip install pulp -``` - -## Implementation Notes - -* The implementation doesn't support auto_layer construction, and relies on the `nnscaler.runtime.function.anchor` as stage division candidates. - -* The implementation doesn't support `follow`, which relies on the user customized operator to achieve manual fusion. - -* For computation cost: - - * we assume the full efficiency, which is calculated by `cost/tp/dp` - - * Similar with Alpa, we force computation-intensive operators to be partitioned, and allow computation-light operators to be replicated. The computation-intensive operators are defined as operators that require weight for input (usually are customized operators). - -* For communication cost: - - * Similar with Alpa, we calculate the cost of communication by `bytes / bandwidth`. - - diff --git a/examples/policies/alpa/__init__.py b/examples/policies/alpa/__init__.py deleted file mode 100644 index 7b012bc9..00000000 --- a/examples/policies/alpa/__init__.py +++ /dev/null @@ -1,240 +0,0 @@ -from typing import List, Optional -from functools import partial -import warnings -import torch - -from nnscaler.graph.function.anchor import IRGraphAnchor -from nnscaler.graph.function.dimops import IRDimops, TransformRule, DimopSplit -from nnscaler.graph.graph import IRGraph -from nnscaler.graph.segment import IRSegment -from nnscaler.ir.operator import IRFwOperation, IRDataOperation -from nnscaler.ir.tensor import IRFullTensor -from nnscaler.graph.schedule.predefined import PredefinedSched -from nnscaler.runtime.device import DeviceGroup - -from examples.policies.alpa.plan import ParallelSpec -from examples.policies.alpa.inter_op import inter_op -from examples.policies.alpa.intra_op import intra_op -from examples.policies.alpa.layer_op import annotate_structure -from examples.policies.alpa.cost_model import CostModel -from examples.policies.alpa.estimator import Estimator - - -def _replica(graph: IRGraph, node: IRFwOperation, devs: List[int]) -> List[IRDimops]: - """Replicate a node""" - sub_nodes = [node] if len(devs) == 1 else graph.replicate(node, len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def _tp(graph: IRGraph, node: IRDimops, devs: List[int], **configs) -> List[IRDimops]: - """Tensor parallelism on a node""" - sub_nodes = [node] if len(devs) == 1 \ - else graph.partition(node, node.algorithms('dim'), **configs) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def _auto_multiref(graph: IRGraph, plan: ParallelSpec): - """ - Apply automated multiref on tensors that are partitioned differently by different nodes - """ - # get parallel strategy - specs = dict() - for stage in plan.stages: - for cid, spec in stage.tp_spec.items(): - specs[cid] = spec - - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): continue - if len(graph.consumers(ftensor)) <= 1: continue - consumers, ctensors = graph.consumers(ftensor), graph.ctensors(ftensor) - splits = set() - for consumer, ctensor in zip(consumers, ctensors): - spec = specs[consumer.cid] - if spec is None: - splits.add(DimopSplit.R()) - else: - idx, dim = spec - rule: TransformRule = consumer.algorithms('dim').infer(idx, dim, 1) - split = rule.inputs()[consumer.inputs().index(ctensor)] - splits.add(split) - if len(splits) > 1: - print(f"> detected a(n) {'activation' if not ftensor.is_attr() else 'parameter'}: " - f"{ftensor.name}({ftensor.tid}) is partitioned differently. Apply multierf...") - graph.multiref(ftensor) - - -def PASAlpa(graph: IRGraph, resource, - recompute: bool = False, - nmicros: int = 1, - db_cache: str = 'db_train.json', - load_spec_file: Optional[str] = None, - save_spec_file: Optional[str] = None, - max_pp_size: Optional[int] = None, - max_tp_size: Optional[int] = None, - max_layer_number: int = 12) -> IRGraph: - """ - Alpa policy examples. - - Require user to manually add cune.runtime.anchor inside model - for AutoLayer partition position - - @param graph IRGraph: model graph - @param resource Resource: resource - @param recompute bool: whether to enable recompute on each layer - @param nmicros int: number of micro-batches - @param db_cache str: database cache file - @param load_spec_file str: reuse spec file - @param save_spec_file str: save spec file - @param max_pp_size Optional[int]: limit the maximum number of pipeline parallelism size - @param max_tp_size Optional[int]: limit the maximum number of tensor parallelism size - @param max_layer_number Optional[int]: maximum number of layers to search - """ - # recompute granularity will follow original anchor scope - layers = annotate_structure(graph) - if recompute: - for layer in layers: - graph.recompute(layer) - - anchors = graph.select(ntype=IRGraphAnchor) - nlayers = len(anchors) + 1 - removed = 0 - while removed < nlayers - max_layer_number: - for anchor in list(anchors[::2]): - graph.remove(anchor) - anchors.remove(anchor) - removed += 1 - if removed >= nlayers - max_layer_number: break - anchors = graph.select(ntype=IRGraphAnchor) - if removed > 0: - print(f'> shrink search space to {len(anchors)+1} layers') - - # enable this will follow alpa's policy: recompute on auto-layer granularity - # layers = annotate_structure(graph) - # if recompute: - # for layer in layers: - # graph.recompute(layer) - nodes = tuple(graph.select(ntype=IRFwOperation)) - - dl: IRDataOperation = graph.select(ntype=IRDataOperation)[0] - mbs: int = dl.output(0).shape[dl.get_batch_dims()[0]] - - # reserve 2GB memory for nccl - mem_limit = resource.gpus[0].memory - 2 * 1024 * 1024 * 1024 - print(f'> search [constraints]: device limitied memory: {mem_limit}') - # profile - print(f'> profiling model...') - estimator = Estimator(db_cache) - latency, memory = estimator(nodes, train=graph.train) - print(f'> search [estimation]: single device latency: {latency} ms, memory: {memory/1024/1024/1024} GB') - if DeviceGroup().rank == 0: - print(f'> search [dump]: saving profiled database...') - estimator.save() - # build cost model - print(f'> building cost model...') - cost_model = CostModel(graph, estimator) - - # alpa search -- only apply on rank 0 to ensure deterministic - if DeviceGroup().rank == 0: - if isinstance(load_spec_file, str): - print(f'loading spec from {load_spec_file}...') - config = ParallelSpec.load(load_spec_file, graph) - else: - print(f'> start searching...') - intra_solver = partial(intra_op, recompute=recompute, memory_limit=mem_limit, cost_model=cost_model) - config = inter_op(nodes, resource.ngpus, intra_solver, mbs, - max_p=max_pp_size, max_t=max_tp_size) - print(f'> parallel spec results:\n{config}') - - if isinstance(save_spec_file, str): - print(f'> saving spec to {save_spec_file}...') - config.save(save_spec_file) - - state: str = config.getstate() - state = torch.tensor([ord(c) for c in state], dtype=torch.int, device=torch.cuda.current_device()) - # notify -suppose each node has 8 gpus - for rank in range(8, DeviceGroup().world_size, 8): - print(f'> notify rank {rank} has finished searching...') - torch.distributed.send(torch.tensor([state.size(0)], device=torch.cuda.current_device()), dst=rank) - torch.distributed.send(state, dst=rank) - - else: - print('> waiting for rank 0 to finish searching...') - length = torch.tensor([0], device=torch.cuda.current_device()) - torch.distributed.recv(length, src=0) - state = torch.empty(length.item(), dtype=torch.int, device=torch.cuda.current_device()) - torch.distributed.recv(state, src=0) - state = ''.join([chr(c) for c in state.tolist()]) - config = ParallelSpec.loadstate(state) - print(f'> parallel spec results:\n{config}') - - print(f'> instantiate plan...') - # print(graph.extra_repr()) - - # auto-multiref - _auto_multiref(graph, config) - - # staging - cid2node = {n.cid : n for n in nodes} - leading_cids = [list(stage.tp_spec.keys())[0] for stage in config.stages] - leading_nodes = [cid2node[cid] for cid in leading_cids] - graph.staging(leading_nodes) - segments = graph.select(ntype=IRSegment, flatten=False) - fsegments = [seg for seg in segments if seg.isfw()] - assert len(fsegments) == len(config.stages) - - # replicate data loader - devices = list(range(resource.ngpus)) - _replica(graph, dl, devices) - - # partition - # TODO: make data parallel to be outside of pipeline parallelism - for sidx, stage in enumerate(config.stages): - tp, dp = stage.tp_size, stage.dp_size - spec = stage.tp_spec - stage_devices, devices = devices[:tp*dp], devices[tp*dp:] - print(f'> applying spec: tp={tp}, dp={dp} for stage {sidx}...') - for node in fsegments[sidx].nodes(): - if isinstance(node, IRGraphAnchor) or node.name == 'multiref': - continue - if node.cid not in spec: - print(f'warning: node {node.name}({node.cid}) not in spec, replicate') - _replica(graph, node, stage_devices) - continue - if mbs not in node.input(0).shape: - if dp > 1: - print(f'warning: cannot find batch dimension of {node.name}({node.cid}), assuming idx=0, dim=0') - batch_dim = 0 - else: - batch_dim = node.input(0).shape.index(mbs) - strategy = spec[node.cid] if node.cid in spec else None - # data parallel - if not isinstance(node, IRDimops): - warnings.warn(f'detected a node {node.name} is not IRDimops, replicate for data parallel') - dp_nodes = [node] if dp == 1 else graph.replicate(node, times=dp) - else: - dp_nodes = [node] if dp == 1 else \ - graph.partition(node, node.algorithms('dim'), idx=0, dim=batch_dim, num=dp) - # tensor parallelism - tp_nodes = [] - for dp_node in dp_nodes: - if strategy is None: - ts = [dp_node] if tp == 1 else graph.replicate(dp_node, times=tp) - else: - idx, dim = strategy - ts = [dp_node] if tp == 1 else \ - graph.partition(dp_node, dp_node.algorithms('dim'), idx=idx, dim=dim, num=tp) - assert len(ts) == tp, f"got tp nodes: {ts} | partition {dp_node} with {strategy}" - tp_nodes += ts - for devid, tp_node in zip(stage_devices, tp_nodes): - graph.assign(tp_node, devid) - # print(graph.extra_repr()) - # setup schedule - if graph.train: - sched = PredefinedSched.sched_1f1b(graph, nmicros, len(config.stages)) - else: - sched = PredefinedSched.sched_infer_pipe(graph, nmicros, len(config.stages)) - return graph diff --git a/examples/policies/alpa/cost_model.py b/examples/policies/alpa/cost_model.py deleted file mode 100644 index 91ba02a0..00000000 --- a/examples/policies/alpa/cost_model.py +++ /dev/null @@ -1,227 +0,0 @@ -""" -Cost model for intra-op plan search -""" -from typing import List, Callable, Tuple, Dict -import numpy as np - -from nnscaler.graph import IRGraph -from nnscaler.ir.cten import IRTensor -from nnscaler.ir.operator import IRFwOperation -from nnscaler.graph.function.anchor import IRGraphAnchor -from nnscaler.graph.function.dimops import IRDimops, TransformRule, DimopSplit - - -DistSpec = Dict[int, Tuple[Tuple[int, int]]] - - -class CommCost: - """ - Get communication cost in milliseconds - """ - @staticmethod - def get_bandwidth(ranks: List[int]): - """ - TODO: support with real runtime information - """ - if len(ranks) < 8: - return 150 * 1e9 # 150 GB/s for intra-node (NVLink) - else: - return 12.5 * 1e9 # 12.5 GB/s for inter-node (IB) - - @staticmethod - def allreduce_cost(tensor: IRTensor, num_devices: int) -> float: - bandwidth = CommCost.get_bandwidth(list(range(num_devices))) - return 2 * (num_devices - 1) * tensor.byte_size() / num_devices / bandwidth * 1000 - - @staticmethod - def alltoall_cost(tensor: IRTensor, num_devices: int) -> float: - # bandwidth in all-to-all is really worse (1GB/s) and should not use - return 1e6 - bandwidth = CommCost.get_bandwidth(list(range(num_devices))) - return tensor.byte_size() / num_devices / num_devices * (num_devices - 1) / bandwidth * 1000 - - @staticmethod - def allgather_cost(tensor: IRTensor, num_devices: int) -> float: - # bandwidth in allgather can only be half due to torch implementation issues - # return 1e6 - bandwidth = CommCost.get_bandwidth(list(range(num_devices))) / 2.98 - return tensor.byte_size() / num_devices * (num_devices - 1) / bandwidth * 1000 - - @staticmethod - def reducescatter_cost(tensor: IRTensor, num_devices: int) -> float: - # bandwidth in reduce-scatter can only be half due to torch implementation issues - # return 1e6 - bandwidth = CommCost.get_bandwidth(list(range(num_devices))) / 2.38 - return tensor.byte_size() / num_devices * (num_devices - 1) / bandwidth * 1000 - - -class CostModel: - - def __init__(self, graph: IRGraph, estimator: Callable): - - self.graph = graph - self.estimator = estimator - - # node property - self.comp_cost = {} - self.mem_cost = {} - - self.edges: Dict[int, List[int]] = {} - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): continue - for producer in graph.producers(ftensor): - if not isinstance(producer, IRFwOperation): continue - for consumer in graph.consumers(ftensor): - if not isinstance(consumer, IRFwOperation): continue - self.edges.setdefault(producer.cid, []).append(consumer.cid) - - # node.cid -> ((idx, dim),) - self.partition_algos: Dict[int, Tuple[int, int]] = {} - - fnodes = graph.select(ntype=IRFwOperation) - fnodes = [n for n in fnodes if not (isinstance(n, IRGraphAnchor) or n.name == 'multiref')] - - for fnode in fnodes: - latency, memory = self.estimator((fnode,)) - self.comp_cost[fnode.cid] = latency - self.mem_cost[fnode.cid] = memory - self.partition_algos[fnode.cid] = self.get_transform_space(fnode) - - def get_transform_space(self, node: IRFwOperation) -> List[Tuple[int, int]]: - """ - Get the transform space of a node - - None indicates replicate - """ - light_op_names = ('add', 'sub', 'mul', 'layernorm') - # light_op_names = () - if isinstance(node, IRDimops): - params = [t for t in node.inputs() if isinstance(t, IRTensor) and t.is_attr()] - # must be partitioned for computation-intensive ops - if len(params) > 0 and node.name not in light_op_names: # not node.signature.startswith('torch.'): - return list(node.transform_space()) - # can be partitioned or replicated for computation-light ops - else: - return [None] + node.transform_space() - return [None] - - def get_memory_cost(self, fnode: IRFwOperation) -> int: - if fnode.cid not in self.mem_cost: - if not (isinstance(fnode, IRGraphAnchor) or fnode.name == 'multiref'): - print(f'warning: cannot find memory cost for node {fnode.name}({fnode.cid})') - return 0 - return self.mem_cost[fnode.cid] - - def get_comp_cost(self, fnode: IRFwOperation, num_devices: int) -> np.ndarray: - """ - Get computation cost related to different partition strategies - """ - return np.zeros(len(self.partition_algos[fnode.cid]), dtype=float) - # cost = [] - # original_cost = self.comp_cost[fnode.cid] - # for strategy in self.partition_algos[fnode.cid]: - # if strategy is None: - # cost.append(original_cost) - # else: - # # computation efficiency simulation - # efficiency = 1 - (num_devices-1)*0.1/2 - # cost.append(original_cost / num_devices / efficiency) - # return np.array(cost, dtype=float) - - def get_comm_cost(self, fnode: IRFwOperation, num_devices) -> np.ndarray: - """ - Get communication cost for a node given a strategy - - This only calucates the cases for partitioning on value dimension - - @return cost: np.ndarray: 1-D array of the cost on allreduce - """ - cost = [] - for strategy in self.partition_algos[fnode.cid]: - if strategy is None: - cost.append(0.) - continue - s_cost = 0 - idx, dim = strategy - rule: TransformRule = fnode.algorithms('dim').infer(idx, dim, num_devices) - for idx, output in enumerate(rule.outputs()): - if output.isV(): - s_cost += CommCost.allreduce_cost(fnode.output(idx), num_devices) - cost.append(s_cost) - return np.array(cost, dtype=float) - - def get_pair_reshard_cost(self, fnode_src: IRFwOperation, fnode_dst: IRFwOperation, - num_devices: int) -> np.ndarray: - """ - Get cost of resharding between two nodes - @return cost: np.ndarray: 1-D tensor of (nsrc * ndst,) shape, - nsrc is the number of partitioned ways of the source node - ndst is the number of partitioned ways of the destination node - """ - nsrc = len(self.partition_algos[fnode_src.cid]) - ndst = len(self.partition_algos[fnode_dst.cid]) - cost = np.zeros((nsrc, ndst), dtype=float) - - def comm_cost(tensor: IRTensor, num_devices: int, - src_split: DimopSplit, dst_split: DimopSplit, dst_replica: bool): - # note for data parallel, we don't consider allreduce cost as it - # will only be performed at the last of iteration. - if tensor.is_attr(): return 0.0 - if src_split.isV() or src_split.isR(): - # identity-allreduce or identity-identity - if dst_split.isR(): - return 0.0 if dst_replica else CommCost.allreduce_cost(tensor, num_devices) - # split-allgather - if dst_split.isD(): - return CommCost.allgather_cost(tensor, num_devices) - if src_split.isD(): - # allgahter-reducescatter or allgather-split - if dst_split.isR(): - return CommCost.allgather_cost(tensor, num_devices) if dst_replica else \ - CommCost.allgather_cost(tensor, num_devices) + CommCost.reducescatter_cost(tensor, num_devices) - # all2all-all2all or identity-identity - if dst_split.isD(): - return 0.0 if src_split == dst_split else 2 * CommCost.alltoall_cost(tensor, num_devices) - raise NotImplementedError(f"Unknown split type: {src_split} -> {dst_split}") - - # FIXME: need consider cases that an operator has multiple **same** inputs - tensors: Dict[IRTensor, Tuple[int, int]] = {} - for idx, output in enumerate(fnode_src.outputs()): - tensors[output.parent] = [idx] - for idx, input in enumerate(fnode_dst.inputs()): - if not isinstance(input, IRTensor): continue - tensors.setdefault(input.parent, []).append(idx) - tensors = {t: tuple(v) for t, v in tensors.items() if len(v) == 2} - - for i, strategy_src in enumerate(self.partition_algos[fnode_src.cid]): - - rule_src = None - if strategy_src is not None: - idx, dim = strategy_src - rule_src = fnode_src.algorithms('dim').infer(idx, dim, num_devices) - - for j, strategy_dst in enumerate(self.partition_algos[fnode_dst.cid]): - rule_dst = None - if strategy_dst is not None: - idx, dim = strategy_dst - rule_dst = fnode_dst.algorithms('dim').infer(idx, dim, num_devices) - - for tensor, (idx_src, idx_dst) in tensors.items(): - cost[i, j] += comm_cost( - tensor, num_devices, - rule_src.outputs()[idx_src] if rule_src is not None else DimopSplit(r=True), - rule_dst.inputs()[idx_dst] if rule_dst is not None else DimopSplit(r=True), - strategy_dst is None - ) - return cost - - def get_edges(self, nodes: List[IRFwOperation]) -> Dict[IRFwOperation, Tuple[IRFwOperation]]: - """ - Get edges of a subgraph - """ - edges: Dict[IRFwOperation, List[IRFwOperation]] = {} - cid2nodes: Dict[int, IRFwOperation] = {n.cid : n for n in nodes} - for node in nodes: - if node.cid in self.edges: - edges[node] = [cid2nodes[cid] for cid in self.edges[node.cid] if cid in cid2nodes] - return edges diff --git a/examples/policies/alpa/estimator.py b/examples/policies/alpa/estimator.py deleted file mode 100644 index 07391383..00000000 --- a/examples/policies/alpa/estimator.py +++ /dev/null @@ -1,402 +0,0 @@ -from typing import Callable, Tuple, Union, Optional, Dict, NewType, List -import time -import os -import json - -# ===== neccesaary for profiling ===== -import nnscaler -import torch -# ==================================== - -from nnscaler.ir.cten import IRTensor, IRObject, IRCell -from nnscaler.ir.operator import IRFwOperation -from nnscaler.graph.parser.register import CustomizedOps -from nnscaler.graph.segment import IRSegment -from nnscaler.graph.function.dimops import IRDimops -from nnscaler.graph.function import IRGraphAnchor - - -Shapes = NewType('Shapes', Tuple[Tuple[int]]) -DTypes = NewType('DTypes', Tuple[torch.dtype]) -ShapesDTypes = NewType('ShapesDTypes', Tuple[Shapes, DTypes]) -NameOrFunc = Union[str, Callable] - - -_train_module_ref: torch.nn.Module = torch.nn.Module().train() -_eval_module_ref: torch.nn.Module = torch.nn.Module().eval() - - -class CompProfiler: - - @staticmethod - def profile(node: IRCell, train: bool = True, - warmup_sec: float = 2, prof_times: int = 50) -> Tuple[float, float, int, Tuple[int]]: - """ - Profile a function - - @param func Callable: the callable function, e.g., torch.nn.functional.linear - @param warmup_sec float: warmup seconds - @param prof_times int: profile times - - @return latency float: average latency in ms - @return memory int: average memory in bytes - """ - torch.cuda.empty_cache() - # print(f'current GPU memory: {torch.cuda.memory_allocated() / 1024 / 1024 / 1024} GB') - - func: Callable = CompProfiler.get_func(node) - args, kwargs = CompProfiler.get_inputs(node, train=train) - - # prepare gradients - with torch.no_grad(): - outputs = func(*args, **kwargs) - outputs = (outputs,) if torch.is_tensor(outputs) else outputs - assert all(torch.is_tensor(otensor) for otensor in outputs), \ - f"{func.__name__}: require all the outputs to be tensors" - grads = tuple(torch.zeros_like(otensor) for otensor in outputs) - del outputs - - def run_step(func, tensors, kwargs, backward: bool): - if not backward: - with torch.no_grad(): - outputs = func(*tensors, **kwargs) - else: - outputs = func(*tensors, **kwargs) - torch.autograd.backward(outputs, grads) - - # memory - torch.cuda.synchronize() - torch.cuda.empty_cache() - mtic = torch.cuda.max_memory_allocated() # in bytes - memory = 0 - if train: - used_tensor = set() - def pack_hook(x): - nonlocal memory, used_tensor - if x.storage().data_ptr() not in used_tensor: - used_tensor.add(x.storage().data_ptr()) - byte_size = x.element_size() - for dim in list(x.size()): - byte_size = byte_size * dim - memory += byte_size - return x - def unpack_hook(x): return x - - with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): - run_step(func, args, kwargs, backward=True) - torch.cuda.synchronize() - del used_tensor - else: - run_step(func, args, kwargs, backward=False) - torch.cuda.synchronize() - mtoc = torch.cuda.max_memory_allocated() - memory = mtoc - mtic - - # warmup - torch.cuda.synchronize() - tic = time.time() - while time.time() - tic < warmup_sec: - run_step(func, args, kwargs, backward=train) - torch.cuda.synchronize() - - torch.cuda.synchronize() - tic = time.perf_counter() - for _ in range(prof_times): - run_step(func, args, kwargs, backward=train) - torch.cuda.synchronize() - toc = time.perf_counter() - latency = (toc - tic) / prof_times * 1000 # in milliseconds - - return latency, memory - - @staticmethod - def get_inputs(node: IRFwOperation, train: bool) -> Tuple[List, Dict]: - # create data - def dummy_torch_tensor(tensor: IRTensor): - """Generate dummy input tenosrs""" - dtype = tensor.dtype - constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand - return constructor(tuple(tensor.shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=tensor.requires_grad) - - args = [dummy_torch_tensor(t) if isinstance(t, IRTensor) else t for t in node.inputs()] - # replace kwargs starting with 'self.xxx' - kwargs = {} - for name, value in node.kwargs.items(): - if isinstance(value, str) and value.startswith('self.'): - value = getattr(_train_module_ref, value[5:]) if train else getattr(_eval_module_ref, value[5:]) - kwargs[name] = value - - return args, kwargs - - @staticmethod - def get_func(node: IRFwOperation) -> Callable: - """ - Get function call - """ - assert isinstance(node, IRFwOperation), f"Only support profiling forward operation but got {type(node)}" - - def get_dep_names(sign: str): - ret = [] - code_impl = CustomizedOps.kOpCodeDef[sign] - for code_line in code_impl.split('\n'): - idx = code_line.find('# call: ') - if idx != -1: - dep_name = code_line[idx + 8:] - assert dep_name in CustomizedOps.kOpCodeDef, dep_name - ret = ret + get_dep_names(dep_name) - ret.append(dep_name) - return ret - - if node.signature in CustomizedOps.kOpCodeDef: - code_impl: str = CustomizedOps.kOpCodeDef[node.signature] - local = {} - exec(code_impl, globals(), local) - fn = list(local.values())[0] - else: - fn = eval(node.signature) - return fn - - -class ProfileDataBase: - - def __init__(self, filename: Optional[str] = None) -> None: - """! - Create a database for profiling result - """ - self._data: Dict[str, Dict[str, Tuple[float, float, int]]] = dict() - if filename is not None: - self.load(filename) - - def profile(self, node: IRFwOperation, train: bool = True, device: Optional[int] = None): - """ - Profile a forward node in IRGraph on a specific device (default current device) - - @param node IRFwOperation: node of IRGraph - @param device int: the device that the node will execute on - - @return latency float: average latency in ms - @return memory int: average memory in bytes - """ - if self.exist(node): - return self.query(node) - - if isinstance(device, int): - orig_device = torch.cuda.current_device() - torch.cuda.set_device(device) - - color, default = '\033[31m', '\033[0m' - - #FIXME: OOM will increase cuda allocated memory - try: - latency, memory = CompProfiler.profile(node, train) - # log to database - self.insert(node, latency, memory) - except Exception as e: - err = f'{color}profil error:\n {str(e)}{default}' - print(err) - latency, memory = e, e - - shapes = tuple(t.shape if isinstance(t, IRTensor) else None for t in node.inputs()) - dtypes = tuple(t.dtype if isinstance(t, IRTensor) else None for t in node.inputs()) - error = f'{color}None{default}' - print( - f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} | train {train} => " - f"latency: {round(latency, 2) if isinstance(latency, float) else error} ms | " - f"memory {memory if isinstance(memory, int) else None} bytes") - - if isinstance(device, int): - torch.cuda.set_device(orig_device) - return latency, memory - - def insert(self, node: IRCell, latency: float, memory: int): - """ - log (reset) the span of a node with key - - @param node IRCell - @param latency float: inference time in milliseconds - @param memory int: inference peak memory in bytes - """ - name = node.signature - key = self._serialize(node) - assert isinstance(name, str) and isinstance(key, str) - if name not in self._data: - self._data[name] = dict() - latency = latency if isinstance(latency, float) else None - memory = memory if isinstance(memory, int) else None - self._data[name][key] = (latency, memory) - - def exist(self, node: IRFwOperation) -> bool: - """ - Check if the node has the performance recorded in the database - - @param node IRFwOperation: forward operation - - @return exist bool: True if the performance is recorded, else False - """ - key = self._serialize(node) - if node.signature not in self._data: - return False - if key not in self._data[node.signature]: - return False - return True - - def query(self, node: IRFwOperation) -> Tuple[Tuple[int], Tuple[int], float, float, int, Tuple[int]]: - """! - Get the performance number of a node in IRGraph - - @param node IRFwOperation: node in IRGraph - - @return latency float: average latency in ms - @return memory int: average memory in bytes - """ - key = self._serialize(node) - if node.signature not in self._data: - return None - if key not in self._data[node.signature]: - return None - return self._data[node.signature][key] - - def _serialize(self, node: IRFwOperation) -> str: - """ - Serialize the shapes, dtypes and kwargs into a string - - e.g., - shapes: ((1024,), (1024,1024)) - dtypes: (torch.float32, torch.float32) - => ((1024,), (1024,1024)) : (torch.float32, torch.float32) - - @param shapes Tuple[Tuple[int]]: the shape of each tensor - @param dtypes Tuple[torch.dtype]: the dtype of each tensor - - @return key str: the serialized string - """ - shapes, dtypes = [], [] - for t in node.inputs(): - if isinstance(t, IRTensor): - shapes.append(t.shape) - dtypes.append(t.dtype) - elif isinstance(t, IRObject): - raise RuntimeError('IRObject has not been supported in _serialize') - else: - shapes.append(None) - dtypes.append(type(t)) - shapes = str(tuple(shapes)) - dtypes= str(tuple(dtypes)) - return shapes + ' : ' + dtypes - - def _deserialize(self, key: str) -> ShapesDTypes: - """ - De-serialize the key string to shapes and dtypes - - e.g., (1024,)-(1024,1024)=torch.float32-torch.float32 - => shapes: ((1024,), (1024,1024)) - dtypes: (torch.float32, torch.float32) - - @param key str: the serialized string - @return shapes_and_dtypes ShapesDTypes: shapes and dtypes - """ - shapes, dtypes = key.split(' : ') - shapes = eval(shapes) - dtypes = eval(dtypes) - # shapes = tuple(eval(shape) for shape in shapes.split('-')) - # dtypes = tuple(eval(dtype) for dtype in dtypes.split('-')) - return shapes, dtypes - - def dump(self, file: str, override=False): - """! - dump the profiled data into json format - - @param file str: the file name - @param override bool: True if the existed can be overrided else False - """ - if os.path.exists(file): - assert override, f"File {file} exists. Set override = True to force dump." - with open(file, 'w') as f: - json.dump(self._data, f) - - def load(self, file: str): - """! - load the profiled data into data base. The original existed one will be - overrided by the loaded data. - - @param file str: the file name - """ - with open(file, 'r') as f: - self._data = json.load(f) - - def __repr__(self) -> str: - data = [] - for signature in self._data: - for key in self._data[signature]: - shapes, dtypes = self._deserialize(key) - latency, memory = self._data[signature][key] - data.append(f'{signature}: shapes={shapes}, dtypes={dtypes}, latency {latency:.2f} msm, memory {memory} bytes') - data = '\n'.join(data) - return data - - -class Estimator: - """ - Estimator to measture the computation / memory cost of a subgraph - """ - def __init__(self, cache='./profile_database.json'): - - self.cache_file = cache - reload = cache if os.path.exists(cache) else None - self.database = ProfileDataBase(reload) - - def profile(self, node: IRFwOperation, train: bool) -> Tuple[float, int]: - if node.name == 'multiref' or isinstance(node, IRGraphAnchor): return 0.0, 0 - trials = Estimator.special_rules(node, [None]) - for config in trials: - if config is None: - num = 1 - latency, memory = self.database.profile(node, train) - else: - idx, dim, num = config - print(f'> ... try node {node.name} with idx={idx}, dim={dim}, num={num}') - sub_node = node.algorithms('dim').instantiate(idx=idx, dim=dim, num=num)[0] - latency, memory = self.database.profile(sub_node, train) - if isinstance(latency, float): break - if isinstance(latency, float): break - assert isinstance(latency, float), f"Failed to profile: {node}" - latency, memory = latency * num, memory * num - self.database.insert(node, latency, memory) - return latency, memory - - def __call__(self, nodes_or_segment: Union[Tuple[IRFwOperation], IRSegment], - train: bool = True): - """ - Profile the computation cost of a subgraph - - @param nodes_or_segment Tuple[IRFwOperation] | IRSegment - - @return latency float: latency in ms - @return memory int: memory in bytes - """ - nodes = nodes_or_segment.nodes() if isinstance(nodes_or_segment, IRSegment) else nodes_or_segment - memory, latency = 0.0, 0.0 - for node in nodes: - if self.database.exist(node): - node_latency, node_memory = self.database.query(node) - else: - node_latency, node_memory = self.profile(node, train) - if train: - memory += node_memory - latency += node_latency - else: - memory = max(memory, node_memory) - latency += node_latency - return latency, memory - - def save(self): - self.database.dump(self.cache_file, override=True) - - def special_rules(node, trials): - # if node.name == 'embedding': # for GPT - # trials = [(1, 0, 4),] - # if node.name == 'self_attention': # for GPT - # trials = [(1, 0, 4),] - # if node.name == 'window_attn': # for Swin - # trials = [(1, 0, 4),] - return trials diff --git a/examples/policies/alpa/inter_op.py b/examples/policies/alpa/inter_op.py deleted file mode 100644 index e8cb2bc4..00000000 --- a/examples/policies/alpa/inter_op.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -Piper policy - -https://openreview.net/attachment?id=-U9I0f2S7W&name=supplementary_material - -The implementation is a little bit adapted to fit with cube's view -""" -from typing import List, Callable, Tuple, Dict, Optional -import time - -from nnscaler.ir.operator import IRFwOperation -from examples.policies.alpa.layer_op import IRLayerOp, cluster_to_layer_ops -from examples.policies.alpa.plan import StageSpec, ParallelSpec - - -def iter_subgraph(nodes: Tuple[IRLayerOp], s: int): - """ - Iterate sub-graphs of the nodes - - @param nodes Tuple[IRFwOperation] - @param s int: number of stages - - @return (sub_graph1, sub_graph2) Tuple[Tuple[IRFwOp], Tuple[IRFwOp]] - """ - assert s > 0 - if s > 1: - # don't consider the head and tail to be anchor - assert len(nodes) >= s - 1, f"layer op: {len(nodes)}, stage: {s}" - for idx in range(len(nodes)): - remain_nodes = len(nodes) - (idx + 1) - # sub-problem of iter(sub_graph2, s-1) must iterable - if remain_nodes < s - 2: continue - sub_graph1, sub_graph2 = nodes[:idx+1], nodes[idx+1:] - yield sub_graph1, sub_graph2 - else: - # s == 1, take all - yield nodes, () - - -def DP(nodes: Tuple[IRLayerOp], k: int, s: int, intra_solver: Callable, - mbs: int, max_d: Optional[int] = None, max_t: Optional[int] = None, - _cost : Dict[Tuple, float] = None, - _config : Dict[Tuple, List[StageSpec]] = None, - _intra_cache = None) -> Tuple[Dict, Dict]: - """ - DP algorithm to search for balanced pipeline stage divisions by considering - tensor parallelism and pipeline parallelism. - - cost[D][k][s] = min_{D' \in D} min_{t, d where t*d<=k} max( - TPS(D\D',t,d,s), cost[D'][k-d*t][s-1] ) - - D: subgraph - K: number of devices - t: tensor parallelism size - d: data parallelism size - s: number of pipeline stages - - @param nodes Tuple[IRFwOperation]: sub-graph - @param k int: number of devices - @param s: number of pipeline stages - @param intra_solver: - which takes nodes, tensor parallelism size, data parallelism size - and in-flight number of microbatches, and outputs the - @param mbs: micro-batch size - @param max_d int: maximal data parallelism size constraint - @param max_t int: maximal tensor parallelism size constraint - - @return costs Dict[( (IRCell,), k, s ), latency] - @return config Dict[( (IRCell,), k, s ), [(IRCell,),] ] - """ - nodes = nodes if isinstance(nodes, tuple) else tuple(nodes) - key = (nodes, k, s) - - # initialize: dp[((), k, s)] = 0 for every k and s - _cost = dict() if _cost is None else _cost - _config = dict() if _config is None else _config - _intra_cache = dict() if _intra_cache is None else _intra_cache - max_d = k if max_d is None else max_d - max_t = k if max_t is None else max_t - if key in _cost: return _cost, _config - - # dp tatble boundary - if len(nodes) == 0: - _cost[key], _config[key] = 0, [] - return _cost, _config - - assert not (k == 0 or s == 0), \ - f"Illegal configuration: nodes: {len(nodes)} k={k}, s={s}: device number (k) cannot be smaller than pipeline stages (s)" - assert k >= s, f"Expected k >= s but got k={k}, s={s}" - - # True for 1,2,4,8,16,... - is_of_power2 = lambda n: (n & (n-1) == 0) and n != 0 - - # construct dynamic programming table - min_val = None # None means no solution - for sub1, sub2 in iter_subgraph(nodes, s): - for d in range(1, min(k + 1, max_d + 1)): - if mbs % d != 0: continue - for t in range(1, min(k // d + 1, max_t + 1)): - # constraints: all devices must be used - if s == 1 and d * t != k: continue - # only search for gpu# of power of 2 - if not is_of_power2(t * d): continue - # guarantee sub-problem searchable - if k - d * t < s - 1: continue - # constraints: every device must be used - if s - 1 > 0 and len(sub2) == 0: continue - # sub2 cost - DP(sub2, k-d*t, s-1, intra_solver, mbs, max_d, max_t, - _cost, _config, _intra_cache) - sub2_cost = _cost[(sub2, k-d*t, s-1)] - if sub2_cost is None: continue - # sub1 cost: s is also the in-flight microbatch number - sub1_config = intra_solver(sub1, d, t, s, _cache=_intra_cache) - if sub1_config is None: continue - sub1_cost = sub1_config.est_latency - # pipeline cost - cost = max(sub1_cost, sub2_cost) - config = [sub1_config] + _config[(sub2, k-d*t, s-1)] - # update - if min_val is None or cost < min_val: - min_val = cost - _config[(nodes, k, s)] = config - - _cost[key] = min_val - return _cost, _config - - -def inter_op(nodes: Tuple[IRFwOperation], ndevs: int, intra_solver: Callable, mbs: int, - max_d: Optional[int]=None, max_t: Optional[int]=None, max_p: Optional[int]=None) -> ParallelSpec: - """ - DP algorithm to search for balanced pipeline stage divisions by considering - tensor parallelism and pipeline parallelism. - - @param nodes List[IRFwOperation]: graph - @param ndevs int: number of devices - @param intra_solver Callable: estimator - which takes nodes, tensor parallelism size, data parallelism size - and in-flight number of microbatches, and outputs of - cost (latency in ms) and config (intra-tp config) - @param mbs: micro-batch size - @param max_d int: maximal data parallelism size constraint - @param max_t int: maximal tensor parallelism size constraint - - @return best_config - """ - nodes: List[IRLayerOp] = cluster_to_layer_ops(nodes) - nodes = tuple(nodes) - print(f'> search [search]: constructing dp tables ({len(nodes)} layer ops)...') - tic = time.time() - max_d = mbs if max_d is None else max_d - max_d = min(max_d, mbs, ndevs) - max_t = ndevs if max_t is None else max_t - max_t = min(max_t, ndevs) - max_p = ndevs if max_p is None else min(max_p, ndevs) - max_p = min(len(nodes), max_p) - cost, config = None, None - for nstages in range(1, max_p+1): - cost, config = DP(nodes, ndevs, nstages, intra_solver, mbs, - max_d, max_t, cost, config) - print(f'> search [search]: getting optimal results...') - min_cost, best_config = None, None - for nstages in range(1, max_p+1): - tcost = cost[(nodes, ndevs, nstages)] - if tcost is None: continue - if min_cost is None or tcost < min_cost: - min_cost = tcost - best_config = config[(nodes, ndevs, nstages)] - assert best_config is not None, f"no solution" - toc = time.time() - span = toc - tic - print(f'> search [finish]: searching time: {span} s') - print(f'> search [result]: minimal latency per microbatch {min_cost} ms') - assert all(isinstance(config, StageSpec) for config in best_config) - spec = ParallelSpec(stages=best_config) - return spec diff --git a/examples/policies/alpa/intra_op.py b/examples/policies/alpa/intra_op.py deleted file mode 100644 index 7f2af325..00000000 --- a/examples/policies/alpa/intra_op.py +++ /dev/null @@ -1,230 +0,0 @@ - -from typing import List, Tuple, Dict, Optional -import multiprocessing -import numpy as np -import warnings -import time - -from nnscaler.ir.cten import IRTensor -from nnscaler.ir.operator import IRFwOperation -from nnscaler.graph.function.anchor import IRGraphAnchor - -from examples.policies.alpa.layer_op import IRLayerOp -from examples.policies.alpa.cost_model import CostModel -from examples.policies.alpa.plan import StageSpec - -# ILP solver -import pulp -from pulp import LpVariable, LpProblem, LpMinimize, LpStatus, lpSum, lpDot, LpStatus - - -def intra_op(layer_nodes: List[IRLayerOp], dp_size: int, tp_size: int, - inflights: int, recompute: bool, memory_limit: int, - cost_model: CostModel, _cache: Dict = None) -> Optional[StageSpec]: - """ - Search for the best intra-op parallelism configuration given device mesh. - The search is only suitable for training. - """ - key = (layer_nodes, dp_size, tp_size) - if isinstance(_cache, dict) and key in _cache: return _cache[key] - - tic = time.time() - - fnodes: List[IRFwOperation] = [] - for layer_op in layer_nodes: - for node in layer_op.nodes: - if isinstance(node, IRGraphAnchor) or node.name == 'multiref': continue - fnodes.append(node) - - # search for tp configuration - - # create variables (nodes) - s, d, c = {}, {}, {} # partition index, computation cost, communication cost - e, r = [], [] # inter-node resharding cost - - num_nodes = 0 - for fnode in fnodes: - cid = fnode.cid - npartitions = len(cost_model.partition_algos[fnode.cid]) - s[cid] = LpVariable.matrix(f's[{num_nodes}]', (range(npartitions),), cat='Binary') - d[cid] = cost_model.get_comp_cost(fnode, tp_size).flatten() / dp_size - c[cid] = cost_model.get_comm_cost(fnode, tp_size).flatten() / dp_size - # setup initial value - for pidx, strategy in enumerate(cost_model.partition_algos[fnode.cid]): - if strategy is None: continue - idx, dim = strategy - identifier = fnode.anno.input(idx)[dim].identifiers[0] - if fnode.anno.getlen(identifier) % (tp_size * dp_size) != 0: - # print(f'remove transform choice on {fnode.name}({fnode.cid}) ' - # f'of strategy: {strategy} for tp={tp_size}, dp={dp_size}') - s[cid][pidx].setInitialValue(False) - s[cid][pidx].fixValue() - num_nodes += 1 - - edges = cost_model.get_edges(fnodes) - num_edges = 0 - for src, dsts in edges.items(): - for dst in dsts: - nsrc = len(cost_model.partition_algos[src.cid]) - ndst = len(cost_model.partition_algos[dst.cid]) - e.append(LpVariable.matrix(f"e[{src.cid}, {dst.cid}]", - (range(nsrc * ndst),), - cat='Binary')) - r.append(cost_model.get_pair_reshard_cost(src, dst, tp_size).flatten()) - num_edges += 1 - - # initial value: --skip - - # objective - prob = LpProblem('intra_op', LpMinimize) - # computation cost - obj = 0 - for fnode in fnodes: - cid = fnode.cid - obj += lpDot(s[cid], c[cid]) + lpDot(s[cid], d[cid]) - # communication cost - for i in range(num_edges): - obj += lpDot(e[i], r[i]) - - prob += obj - - # constraints - - # a) only one partition can be selected - for fnode in fnodes: - prob += lpSum(s[fnode.cid]) == 1 - for i in range(num_edges): - prob += lpSum(e[i]) == 1 - - # e_src_dst[i][j] = 1 => s_src[i] == 1 and s_dst[j] == 1 - eidx = 0 - for src, dsts in edges.items(): - for dst in dsts: - for row in range(len(s[src.cid])): - C = len(s[dst.cid]) - prob += lpSum( - e[eidx][row * C + col] for col in range(0, C)) <= s[src.cid][row] - for col in range(len(s[dst.cid])): - R = len(s[src.cid]) - C = len(s[dst.cid]) - prob += lpSum( - e[eidx][row * C + col] for row in range(0, R)) <= s[dst.cid][col] - eidx += 1 - - # b) memory constraint --skip - - assert "PULP_CBC_CMD" in pulp.listSolvers(onlyAvailable=True), ( - "Please install ILP solvers by 'sudo apt install coinor-cbc' or 'pip install pulp'") - - time_limit = 600 - solver = pulp.PULP_CBC_CMD( - mip=True, msg=0, - timeLimit=time_limit, - threads=multiprocessing.cpu_count()) - prob.solve(solver) - - status = prob.status - objective = pulp.value(prob.objective) - objective = float(objective) if objective is not None else -1.0 - # print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}") - # print(f"#nodes: {num_nodes}, #edges: {num_edges}") - # print(f'ILP search time: {time.time() - tic:.2f} seconds') - - # reshard_cost = 0 - # for i in range(num_edges): - # reshard_cost += lpDot(e[i], r[i]) - # reshard_cost = pulp.value(reshard_cost) - # print(f'debug info: reshard cost: {reshard_cost}') - - if prob.status in [pulp.LpStatusInfeasible]: - raise RuntimeError("Cannot run the function under the given memory budget.") - - def get_non_zero_index(binary_vector): - """Get the index of non-zero item in a vector.""" - ct = 0 - ret = None - for i, elem in enumerate(binary_vector): - if pulp.value(elem): - ret = i - ct += 1 - - assert ct == 1 - return ret - - tp_spec: Dict[int, int] = {} - for fnode in fnodes: - index = get_non_zero_index(s[fnode.cid]) - tp_spec[fnode.cid] = index - - # check results - e_val = np.full((num_edges,), -1, dtype=np.int32) - eidx = 0 - for (src, dsts) in edges.items(): - for dst in dsts: - e_val[eidx] = get_non_zero_index(e[eidx]) - src_spec_index = e_val[eidx] // len(s[dst.cid]) - dst_spec_index = e_val[eidx] % len(s[dst.cid]) - assert src_spec_index == tp_spec[src.cid] - assert dst_spec_index == tp_spec[dst.cid] - eidx += 1 - - if objective > 1e13: - warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") - - # estimate activation memory - non_recompute_mem = 0 - recompute_mem, curr_recomp_id = [0], None - for node in fnodes: - strat = cost_model.partition_algos[node.cid][tp_spec[node.cid]] - op_tp_size = 1 if strat is None else tp_size - node_mem = cost_model.get_memory_cost(node) // (dp_size * op_tp_size) - if node.recompute != curr_recomp_id: - recompute_mem.append(0) - curr_recomp_id = node.recompute - if node.recompute is None: - non_recompute_mem += node_mem - else: - recompute_mem[-1] += node_mem - act_memory = non_recompute_mem * inflights + max(recompute_mem) - - # estimate parameter memory - param_mem = 0 - pids = set() - for node in fnodes: - attrs = [t for t in node.inputs() if \ - isinstance(t, IRTensor) and t.is_attr()] - for attr in attrs: - if attr.tid in pids: continue - opt = 4 if attr.is_param() else 1 - # we estimate parameter size by assuming it will partition on weight - param_mem += opt * attr.byte_size() // tp_size - pids.add(attr.tid) - - # print(f'debug: inflights: {inflights}, act memory: {act_memory/1024/1024/1024}, param mem: {param_mem/1024/1024/1024}') - mem_cost = act_memory + param_mem - if mem_cost > memory_limit: - print(f'searching results of {len(tp_spec)} nodes: tp={tp_size}, dp={dp_size}: no solution (memory: {mem_cost/1024/1024/1024} GB)') - return None - - # get tensor parallelism spec - stage_tp_spec = {} - names = {} - for fnode in fnodes: - strategy = None if tp_size == 1 else \ - cost_model.partition_algos[fnode.cid][tp_spec[fnode.cid]] - stage_tp_spec[fnode.cid] = strategy - names[fnode.cid] = fnode.name - - config = StageSpec( - est_latency=objective / 3 * 4 if recompute else objective, - est_memory=mem_cost, - tp_size=tp_size, - dp_size=dp_size, - tp_spec=stage_tp_spec, - names=names, - ) - print(f'searching results of {len(stage_tp_spec)} nodes: tp={tp_size}, dp={dp_size} ' - f'latency={objective}, memory={mem_cost/1024/1024/1024} GB') - if isinstance(_cache, dict): _cache[key] = config - # print(config) - return config diff --git a/examples/policies/alpa/layer_op.py b/examples/policies/alpa/layer_op.py deleted file mode 100644 index bf220456..00000000 --- a/examples/policies/alpa/layer_op.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import List, Dict, Tuple -import more_itertools - -from nnscaler.ir.cten import IRCell -from nnscaler.ir.operator import IRFwOperation -from nnscaler.graph.graph import IRGraph -from nnscaler.graph.function.anchor import IRGraphAnchor - - -class IRLayerOp(IRCell): - - def __init__(self, nodes: List[IRCell], layer_id: int = None): - super().__init__('layer_op', 'layer_op', 0, 0) - self.nodes = nodes - self.layer_id : int = layer_id - - -def cluster_to_layer_ops(nodes: List[IRFwOperation]) -> List[IRLayerOp]: - layer_ops: List[IRLayerOp] = [] - ops = [] - for node in nodes: - if isinstance(node, IRGraphAnchor): - if len(ops) != 0: - layer_ops.append(IRLayerOp(ops, layer_id=len(layer_ops))) - ops = [node] - elif isinstance(node, IRFwOperation): - ops.append(node) - if len(ops) != 0: - layer_ops.append(IRLayerOp(ops, layer_id=len(layer_ops))) - return layer_ops - - -def annotate_structure(graph: IRGraph) -> List[Tuple[IRFwOperation]]: - """Annotate graph stucture in generated code""" - anchors = graph.select(ntype=IRGraphAnchor) - for idx, anchor in enumerate(anchors): - nidx = graph.index(anchor) - graph.node(nidx + 1).comment = f'===> split position {idx}: {anchor.name}' - fnodes = graph.select(ntype=IRFwOperation) - subgraphs = more_itertools.split_before(fnodes, lambda n: isinstance(n, IRGraphAnchor)) - return list(subgraphs) - \ No newline at end of file diff --git a/examples/policies/alpa/plan.py b/examples/policies/alpa/plan.py deleted file mode 100644 index 291ee7c8..00000000 --- a/examples/policies/alpa/plan.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Dict, Tuple, Optional -from dataclasses import dataclass -import json - -from nnscaler.ir.operator import IRFwOperation -from nnscaler.graph.graph import IRGraph - -@dataclass -class StageSpec: - # estimation - est_latency: float # in milliseconds - est_memory: float # in types - # config - tp_size: int - dp_size: int - # node.cid -> (idx, num) | None - tp_spec: Dict[int, Optional[Tuple[int, int]]] - # node.cid -> node.name - names: Dict[int, str] - - def __repr__(self) -> str: - dscp = '' - for cid, strategy in self.tp_spec.items(): - strategy = 'Replicate' if strategy is None else f"idx={strategy[0]}, dim={strategy[1]}, num={self.tp_size}" - dscp += f' {self.names[cid]}({cid}): {strategy}\n' - return dscp - - def to_dict(self) -> Dict: - return { - 'est_latency': self.est_latency, - 'est_memory': self.est_memory, - 'tp_size': self.tp_size, - 'dp_size': self.dp_size, - 'tp_spec': self.tp_spec, - 'names': self.names - } - - @staticmethod - def from_dict(d: Dict): - tp_spec = {int(cid): spec for cid, spec in d['tp_spec'].items()} - names = {int(cid): name for cid, name in d['names'].items()} - return StageSpec( - est_latency=d['est_latency'], - est_memory=d['est_memory'], - tp_size=d['tp_size'], - dp_size=d['dp_size'], - tp_spec=tp_spec, - names=names - ) - - -@dataclass -class ParallelSpec: - stages: Tuple[StageSpec] - - @property - def est_latency(self) -> float: - return max(s.est_latency for s in self.stages) - - def save(self, filename: str): - """ - Save plan into json file - """ - with open(filename, 'w') as f: - json.dump([s.to_dict() for s in self.stages], f) - - def getstate(self) -> str: - """ - Get plan state as json string - """ - return json.dumps([s.to_dict() for s in self.stages]) - - @staticmethod - def loadstate(state: str): - """ - Load plan from json string - """ - stages = json.loads(state) - return ParallelSpec(tuple(StageSpec.from_dict(s) for s in stages)) - - @staticmethod - def load(filename: str, check_graph_consistent: IRGraph = None): - """ - Load plan from json file - """ - with open(filename, 'r') as f: - stages = json.load(f) - spec = ParallelSpec(tuple(StageSpec.from_dict(s) for s in stages)) - if check_graph_consistent is not None: - graph = check_graph_consistent - cid2name = {n.cid: n.name for n in graph.select(ntype=IRFwOperation)} - for stage in spec.stages: - for cid, name in stage.names.items(): - assert cid in cid2name, f'graph is not consistent with plan: node cid {cid}:{name} not found in graph' - assert cid2name[cid] == name, f'graph is not consistent with plan: cid {cid}:{name} name mismatch' - return spec - - def __repr__(self) -> str: - dscp = f'nstages: {len(self.stages)} | latency: {self.est_latency} ms' - for sidx, stage in enumerate(self.stages): - tp, dp = stage.tp_size, stage.dp_size - latency, memory = stage.est_latency, stage.est_memory / 1024 / 1024 / 1024 - dscp += f'\nStage {sidx} (tp={tp}, dp={dp}, latency={latency} ms, memory={memory}):\n' - dscp += f'{stage}' - return dscp diff --git a/examples/policies/gshard.py b/examples/policies/gshard.py deleted file mode 100644 index f4378f6e..00000000 --- a/examples/policies/gshard.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Policy example following GShard -""" - -from typing import List - -from nnscaler.ir.tensor import IRSubTensor -from nnscaler.ir.operator import IRDataOperation, IRFwOperation -from nnscaler.graph.graph import IRGraph -from nnscaler.graph.function.dimops import IRDimops -from nnscaler.graph.function.anchor import IRGraphAnchor - - -def follow(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int, - nodes: List[IRDimops]) -> List[IRDimops]: - """ - Partition nodes along one tensor dimension - - @param node IRDimops: the entry node - @param devs List[int]: the devices - @param idx int: entry node partition config idx - @param dim int: entry node partition config dim - @param nodes List[IRDimops]: partition node scopes - - @return remain_nodes List[IRDimops]: remaining nodes that are not partitioned - """ - assert node in nodes - algo = node.algorithms('dim') - if not algo.satisfy(idx=idx, dim=dim, num=len(devs)): return nodes - # tensor parallelism - sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - # partition successors - nodes.remove(node) - for oidx, tensor in enumerate(node.outputs()): - if not isinstance(tensor, IRSubTensor): continue - ftensor = tensor.parent - for pdim in range(len(ftensor.shape)): - if sub_nodes[0].output(oidx).shape[pdim] != ftensor.shape[pdim]: - break - else: - continue - for consumer, ctensor in zip(graph.consumers(ftensor), graph.ctensors(ftensor)): - if not isinstance(consumer, IRDimops): continue - if isinstance(consumer, IRGraphAnchor) or consumer.name == 'multiref': continue - if consumer in nodes: - cidx = consumer.inputs().index(ctensor) - follow(graph, consumer, devs, cidx, pdim, nodes) - return nodes - - -def PASGShard(graph: IRGraph, resource): - - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): continue - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor) - - devs = list(range(resource.ngpus)) - - def replicate(node): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - - # print(graph.extra_repr()) - - fwops = graph.select(ntype=(IRDataOperation, IRFwOperation)) - print(f'> total fwops: {len(fwops)}') - while len(fwops) > 0: - fwop = fwops[0] - if isinstance(fwop, IRGraphAnchor) or fwop.name == 'multiref': - fwops.pop(0) - continue - # replicate if the node is not IRDimops - if not isinstance(fwop, IRDimops): - replicate(fwop) - fwops.pop(0) - continue - # partition along the longest dimension - configs = fwop.transform_space() - configs = sorted(configs, reverse=True, - key=lambda config: fwop.input(config[0]).shape[config[1]]) - for (idx, dim) in configs: - if fwop.input(idx).shape[dim] % len(devs) != 0: continue - if fwop.algorithms('dim').satisfy(idx=idx, dim=dim, num=len(devs)): - print(f'> policy partition: entry Fwop{fwop.cid}: {fwop.name} idx={idx}, dim={dim}') - follow(graph, fwop, devs, idx, dim, fwops) - print(f'> remaining fwops: {len(fwops)}') - break - else: - replicate(fwop) - fwops.pop(0) - return graph diff --git a/examples/policies/random_spmd.py b/examples/policies/random_spmd.py deleted file mode 100644 index 7b176f8e..00000000 --- a/examples/policies/random_spmd.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Random SPMD policy -""" -from typing import List, Optional -from nnscaler.graph.graph import IRGraph -from nnscaler.graph.function.dimops import IRDimops -from nnscaler.ir.operator import IRDataOperation, IRFwOperation -from nnscaler.graph.function.anchor import IRGraphAnchor -from datetime import datetime - -import random - - -def _tp(graph: IRGraph, node: IRDimops, devs: List[int], idx: int, dim: int): - sub_nodes = graph.partition( - node, node.algorithms('dim'), idx=idx, dim=dim, num=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def _replica(graph: IRGraph, node, devs: List[int]): - sub_nodes = graph.replicate(node, times=len(devs)) - for devid, sub_node in zip(devs, sub_nodes): - graph.assign(sub_node, devid) - return sub_nodes - - -def PASRandomSPMD(graph: IRGraph, resource, seed: Optional[int] = None): - """ - Random SPMD policy - """ - # get the current random state - state = random.getstate() - - seed = int(datetime.now().timestamp()) if seed is None else seed - print(f'> set random SPDM policy seed to {seed}') - random.seed(seed) - devs = list(range(resource.ngpus)) - - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): continue - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor) - - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if node.name == 'multiref' or isinstance(node, IRGraphAnchor): - continue - if isinstance(node, IRDimops): - configs = node.transform_space() - if len(configs) == 0: - _replica(graph, node, devs) - else: - configs = sorted(configs, reverse=True, - key=lambda config: node.input(config[0]).shape[config[1]]) - random.shuffle(configs) - for (idx, dim) in configs: - if node.input(idx).shape[dim] % len(devs) != 0: continue - if node.algorithms('dim').satisfy(idx=idx, dim=dim, num=len(devs)): - print(f'> partition node {node.name} ({node.cid}) with config idx={idx}, dim={dim}') - _tp(graph, node, devs, idx, dim) - break - else: - _replica(graph, node, devs) - else: - _replica(graph, node, devs) - - # restore the random state - random.setstate(state) - return graph diff --git a/examples/utils.py b/examples/utils.py index f08b6191..74192693 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,6 +1,9 @@ from typing import List, Union, Callable, Optional, Tuple import logging +import torch + +import nnscaler from nnscaler.graph import IRGraph from nnscaler.graph.segment import IRSegment from nnscaler.graph.function.dimops import IRDimops @@ -9,6 +12,7 @@ from nnscaler.ir.cten import IRCell from nnscaler.ir.tensor import IRFullTensor from nnscaler.graph.function.anchor import IRGraphAnchor +import nnscaler.runtime from nnscaler.utils import print_each_rank import numpy as np @@ -24,7 +28,7 @@ def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: The product of group_num should be same with total devices. e.g., 6 device to 2 x 3 mesh will results [dim][group_id] = tuple[int]: - ( + ( ( (0,1,2), (3,4,5) ), ( (0,3), (2,5), (3,6) ), ) @@ -66,7 +70,7 @@ def _tp_autoplace(segment: IRSegment, ftensor: IRFullTensor, producers: List[IRFwOperation], devs: List[int], sub_nodes: List[IRFwOperation]) -> List[int]: """decide the devices of the partitioned `sub-nodes` to achieve optimal communication - + Args: segment (IRSegment): segment of the ftensor ftensor (IRFullTensor): the tensor to be partitioned @@ -87,7 +91,7 @@ def _tp_autoplace(segment: IRSegment, ftensor: IRFullTensor, return devs # tensor parallelism -def tensor_parallelism(graph: IRGraph, node: IRDimops, +def tensor_parallelism(graph: IRGraph, node: IRDimops, idx: int, dim: int, devs: List[int], autoplace: bool = False) -> List[IRDimops]: """Apply tensor parallelism of a node to devs""" @@ -111,7 +115,7 @@ def tensor_parallelism(graph: IRGraph, node: IRDimops, # replica -def replica(graph: IRGraph, node: Union[IRFwOperation, IRDataOperation], +def replica(graph: IRGraph, node: Union[IRFwOperation, IRDataOperation], devs: List[int]) -> List[Union[IRFwOperation, IRDataOperation]]: """Replicate a forward node or dataloader to devs""" if len(devs) == 1: @@ -131,7 +135,7 @@ def get_policy(modules: List, name: str) -> Callable: Args: modules (List): list of modules name (str): name of policy - + Returns: Callable: policy """ @@ -142,4 +146,11 @@ def get_policy(modules: List, name: str) -> Callable: policies = [] for module in modules: policies += list(policy for policy in module.__dict__.keys() if policy.startswith('PAS')) - raise ValueError(f"policy {name} not found. Candidates: {policies}") \ No newline at end of file + raise ValueError(f"policy {name} not found. Candidates: {policies}") + + +def init_random(): + np.random.seed(1) + torch.manual_seed(1) + if torch.cuda.is_available(): + torch.cuda.manual_seed(1) diff --git a/examples/openfold/blocks/__init__.py b/examples/vision/swin/__init__.py similarity index 100% rename from examples/openfold/blocks/__init__.py rename to examples/vision/swin/__init__.py diff --git a/examples/vision/swin/baseline.py b/examples/vision/swin/baseline.py index 0678a105..f7976e41 100644 --- a/examples/vision/swin/baseline.py +++ b/examples/vision/swin/baseline.py @@ -7,6 +7,7 @@ import math +from typing import List, Tuple import warnings import torch @@ -628,7 +629,47 @@ def flops(self): return flops -class ImageDataLoader(nnscaler.runtime.syndata.CubeDataLoader): +class CubeDataLoader: + r""" + Cube Dataloader + """ + def __init__(self, shapes: Tuple[List[int]], dtypes: Tuple[torch.dtype], batch_dims: Tuple[int] = None): + """ + shapes Tuple[Tuple[int]]: + The shape for each data + dtypes Tuple[torch.dtype]: + The dtype for each data + batch_dims Tuple[int]: + The batch dimension of each data + """ + if not all(isinstance(shape, list) for shape in shapes): + raise TypeError("Expected each shape in shapes to be a list") + if len(shapes) != len(batch_dims) or len(shapes) != len(dtypes): + raise TypeError("Expected number batch dim and dtypes to len(shapes)") + self.shapes = tuple([list(shape) for shape in shapes]) + self.dtypes = dtypes + self.batch_dims = (0,) * len(self.shapes) if batch_dims is None else batch_dims + + def get_batch_size(self) -> int: + """ + get batch size + """ + all_batch_size = set([shape[dim] for shape, dim in zip(self.shapes, self.batch_dims)]) + if len(all_batch_size) != 1: + raise ValueError("Heterogenous batch size in dataloader") + return list(all_batch_size)[0] + + def set_batch_size(self, batch_size: int): + """ + set batch size + """ + self.batch_size = batch_size + for shape, dim in zip(self.shapes, self.batch_dims): + shape[dim] = batch_size + print(f'> data loader output shape change to: {self.shapes}') + + +class ImageDataLoader(CubeDataLoader): def __init__(self, batch_size: int, img_size: int, num_classes: int): @@ -643,7 +684,7 @@ def __init__(self, batch_size: int, img_size: int, num_classes: int): batch_dims=(0, 0) ) self.samples = [self.random_sample()] - + def random_sample(self): torch.manual_seed(0) img = torch.rand( @@ -658,7 +699,7 @@ def random_sample(self): device=torch.cuda.current_device() ) return (img, labels) - + def __iter__(self): return self @@ -684,7 +725,7 @@ class Config: drop_path_rate = 0.2 drop_rate = 0.2 - + # 224 x 224 img_size = 224 @@ -759,7 +800,7 @@ def train_iter(model, dataloader): if step == 0: print_each_rank('passed first iteration') - + if (step + 1) % 2 == 0: print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) diff --git a/examples/vision/swin/blocks/attention.py b/examples/vision/swin/blocks/attention.py index 07375626..72f4c3d8 100644 --- a/examples/vision/swin/blocks/attention.py +++ b/examples/vision/swin/blocks/attention.py @@ -7,7 +7,7 @@ # this cannot partition on head dimension # as the head dimension is a secondary hidden dimension in (3 head dim_head). # To make partition work (correctness guarantee), the dimension is swapped as (head dim_head 3) -@nnscaler.register_op('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), nw N^ N^ -> B N^ C^') +@nnscaler.register_op('B N^ C^, (h+ dh^ 3) C^, (h+ dh^ 3), (wh^ ww^) (wh^ ww^), t^ h+, C^ (h+ dh^), ? -> B N^ C^') def window_attn(x: torch.Tensor, qkv_w: torch.Tensor, qkv_bias: torch.Tensor, relative_position_index: torch.Tensor, diff --git a/examples/vision/swin/blocks/patch.py b/examples/vision/swin/blocks/patch.py index 3d8a124d..8c48f551 100644 --- a/examples/vision/swin/blocks/patch.py +++ b/examples/vision/swin/blocks/patch.py @@ -6,7 +6,7 @@ import nnscaler -@nnscaler.register_op('B (2 h^ 2 w^) C^ -> B (h w) (4 C)') +@nnscaler.register_op('B (2 h^ 2 w^) C^ -> B (h^ w^) (4 C^)') def patch_merge(x: torch.Tensor, h: int, w: int): B, L, C = x.shape H = 2 * h @@ -22,6 +22,7 @@ def patch_merge(x: torch.Tensor, h: int, w: int): x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C return x + @nnscaler.register_op('B ic+ (ps^ w^) (ps^ h^), oc ic+ k^ k^, oc -> B oc w^ h^') def patch(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, ps: int): """ diff --git a/examples/vision/swin/blocks/transformer.py b/examples/vision/swin/blocks/transformer.py index bfd520b1..946948a2 100644 --- a/examples/vision/swin/blocks/transformer.py +++ b/examples/vision/swin/blocks/transformer.py @@ -20,7 +20,7 @@ def drop_path(x: torch.Tensor, drop_prob: float, training: bool): return output -@nnscaler.register_op('B (nh ws) (nw ws) C -> (B nh nw) ws ws C') +@nnscaler.register_op('B (nh^ ws^) (nw^ ws^) C -> (B nh^ nw^) ws^ ws^ C') def window_partition(x: torch.Tensor, ws: int): """ Args: @@ -36,7 +36,7 @@ def window_partition(x: torch.Tensor, ws: int): return windows -@nnscaler.register_op('(B nh nw) ws ws C -> B (nh ws) (nw ws) C') +@nnscaler.register_op('(B nh^ nw^) ws^ ws^ C -> B (nh^ ws^) (nw^ ws^) C') def window_reverse(windows: torch.Tensor, ws: int, nh: int, nw: int): """ Args: diff --git a/examples/vision/swin/model.py b/examples/vision/swin/model.py index a2853ca8..c28b3778 100644 --- a/examples/vision/swin/model.py +++ b/examples/vision/swin/model.py @@ -44,7 +44,7 @@ class Config: drop_path_rate = 0.2 drop_rate = 0.2 attn_drop_rate = 0.0 - + # dataloader # 224 x 224 @@ -146,7 +146,7 @@ def __init__(self): # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=cfg.img_size, - patch_size=self.patch_size, + patch_size=self.patch_size, in_chans=3, embed_dim=cfg.embed_dim, norm_layer=nn.LayerNorm ) @@ -221,7 +221,7 @@ def flops(self): # =========================== Data Loader ======================= -def dummy_data(batch_size: int, +def dummy_data(batch_size: int, dtype: torch.dtype, cfg: Config): input_ids = torch.randn( [batch_size, 3, cfg.img_size, cfg.img_size], diff --git a/examples/vision/swin/policy/gallery.py b/examples/vision/swin/policy/gallery.py index bdc16cfa..380c30b6 100644 --- a/examples/vision/swin/policy/gallery.py +++ b/examples/vision/swin/policy/gallery.py @@ -1,8 +1,13 @@ from typing import List +import more_itertools as mitr +import itertools + +from nnscaler import ComputeConfig from nnscaler.graph import IRGraph from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.schedule.predefined import PredefinedSched +from nnscaler.graph.segment import IRSegment from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from examples.utils import tensor_parallelism, replica, group_to_layers @@ -24,42 +29,19 @@ def coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, return sub_nodes -def PASSingle(graph: IRGraph, resource, **kwargs): - assert resource.ngpus == 1 - # print(graph.extra_repr()) - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - return graph - - -def PASData(graph: IRGraph, resource, **kwargs): - """Data parallelism""" - devs = list(range(resource.ngpus)) - dataloader = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[0] - # replicate dataloader - replica(graph, dataloader, devs) - # partition forward operators - for node in graph.select(ntype=IRFwOperation): - if isinstance(node, IRGraphAnchor): continue - try: - tensor_parallelism(graph, node, idx=0, dim=0, devs=devs) - except Exception as e: - _logger.warning(f'fail to partition node {node.name} at idx=0, using replica') - replica(graph, node, devs) - return graph +def pas_megatron(graph: IRGraph, cfg: ComputeConfig): + """Megatron-way tensor parallelism""" + devs = list(range(cfg.plan_ngpus)) + # skip mutliref because the partition of tensors in transformer are the same for all tensors -def PASMegatronTP(graph: IRGraph, resource, **kwargs): - """Megatron-way tensor parallelism""" - devs = list(range(resource.ngpus)) # attention for attn in graph.select(name='window_attn'): tensor_parallelism(graph, attn, idx=1, dim=0, devs=devs) # feedforward for ffn in graph.select(name='feedforward'): tensor_parallelism(graph, ffn, idx=1, dim=0, devs=devs) + # replicate other nodes for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): if len(node.device) == 0: @@ -67,9 +49,19 @@ def PASMegatronTP(graph: IRGraph, resource, **kwargs): return graph -def PASMeshShard(graph: IRGraph, resource, **kwargs): - """Coshard policy example""" - devs = list(range(resource.ngpus)) +def pas_mesh_shard(graph: IRGraph, cfg: ComputeConfig): + """ + Coshard policy example + + It will partition a tensor `colocate*plan_ngpus` subtensors, + and each device will have `colocate` subtensors. + + This can save GPU memory when work with recompute + """ + devs = list(range(cfg.plan_ngpus)) + + # skip mutliref because the partition of tensors in transformer are the same for all tensors + # attention for attn in graph.select(name='window_attn'): # _tp(graph, attn, tp_devs, idx=1, dim=0) @@ -78,6 +70,7 @@ def PASMeshShard(graph: IRGraph, resource, **kwargs): for ffn in graph.select(name='feedforward'): # _tp(graph, ffn, tp_devs, idx=1, dim=0) coshard(graph, ffn, devs, colocate=4, idx=1, dim=0) + # replicate other nodes for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): if len(node.device) == 0: @@ -85,22 +78,29 @@ def PASMeshShard(graph: IRGraph, resource, **kwargs): return graph -def PAS1F1B(graph: IRGraph, resource, nmicros: int, **kwargs): +def pas_1f1b(graph: IRGraph, cfg: ComputeConfig): """1F1B schedule""" - num_stages = resource.ngpus - num_microbatch = nmicros + num_stages = cfg.pipeline_nstages + if num_stages != cfg.plan_ngpus: + raise ValueError('1F1B schedule requires num_stages == plan_ngpus') + # group to transformer layers transformers = group_to_layers(graph.select(ntype=IRFwOperation)) + stages = mitr.divide(cfg.pipeline_nstages, transformers) + stages = [list(itertools.chain(*s)) for s in stages] + graph.staging([t[0] for t in stages]) + # staging - nlayer_per_stage = (len(transformers) // resource.ngpus) - for lid, fnodes in enumerate(transformers): - stage_id = min(lid // nlayer_per_stage, num_stages-1) - _logger.info(f'assigning {lid}-th transformer layter to stage {stage_id}') - for fnode in fnodes: - graph.assign(fnode, stage_id) + stages: List[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + stages = [s for s in stages if s.isfw()] + + for idx, stage in enumerate(stages): + for fnode in stage.nodes(): + graph.assign(fnode, idx) + # replicate dataloader for node in graph.select(ntype=IRDataOperation): - replica(graph, node, list(range(resource.ngpus))) + replica(graph, node, list(range(cfg.plan_ngpus))) # apply 1f1b schedule - PredefinedSched.sched_1f1b(graph, num_microbatch, num_stages) + cfg.apply_pipeline_scheduler(graph) return graph diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index ba109c59..a9cfac66 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -1,84 +1,103 @@ """ example: -PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - examples/vision/swin/train.py --policy PASMegatronTP --fp16 +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/vision/swin/train.py --policy pp --pp_size 4 --gbs 16 --fp16 + +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/vision/swin/train.py --policy 1f1b --pp_size 4 --gbs 16 --fp16 + +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/vision/swin/train.py --policy megatron --tp_size 4 --gbs 16 --fp16 + +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/vision/swin/train.py --policy megatron --tp_size 2 --dp_size 2 --gbs 16 --fp16 + +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/vision/swin/train.py --policy mesh_shard --tp_size 4 --gbs 16 --fp16 + +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=4 \ + examples/vision/swin/train.py --policy autodist --tp_size 2 --dp_size 2 --gbs 16 --fp16 """ -import math +import logging +import itertools + import torch from functools import partial -from examples.vision.swin.blocks.attention import init_relative_position_index +from examples.utils import init_random from examples.vision.swin.model import Config, SwinTransformer, dummy_data import nnscaler -from nnscaler.compiler import compile from nnscaler.profiler.timer import CudaTimer, print_each_rank from nnscaler.profiler.memory import memory_summary -from nnscaler.runtime.utils import microbatches import examples.vision.swin.policy.gallery as gallery -from examples.utils import get_policy import argparse -parser = argparse.ArgumentParser(description='GPT Train') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') -parser.add_argument('--fp16', action='store_true', default=False, - help='use fp16 for the training') -parser.add_argument('--dp', type=int, default=1, - help='data parallel size, only for megatron') -parser.add_argument('--tp', type=int, default=1, - help='tensor parallel size, only for megatron') -# training -parser.add_argument('--gbs', type=int, default=4, help='global batch size') -parser.add_argument('--mbs', type=int, default=4, help='micro batch size') - -args = parser.parse_args() -nnscaler.init() +import nnscaler.utils -# get policy -policy = get_policy([gallery], args.policy) -policy = partial(policy, - nmicros=args.gbs//args.mbs, - dp_size=args.dp, - tp_size=args.tp -) +def src_hash(): + import hashlib + from pathlib import Path + h = hashlib.md5() + nnscaler_dir = Path(nnscaler.__file__).parent + example_dir = nnscaler_dir.with_name('examples') + for f in itertools.chain(nnscaler_dir.glob('**/*.py'), example_dir.glob('**/*.py')): + h.update(f.stat().st_mtime_ns.to_bytes(8, 'little')) + return h.hexdigest() -def train(): - - batch_size = args.mbs - load_content: bool = False +def train(args, compute_config: nnscaler.ComputeConfig): + nnscaler.utils.set_default_logger_level(logging.INFO) cfg = Config() model = SwinTransformer() model = model.half() if args.fp16 else model - - dtype = torch.float16 if args.fp16 else torch.float32 - - - gen_data = partial(dummy_data, args.mbs, torch.float16, cfg) - dataloader = microbatches((gen_data(),)) - - @compile(model, dataloader, PAS=policy, load_content=load_content) - def train_iter(model, dataloader): - imgs = next(dataloader) - loss = model(imgs) - loss.backward() - model = nnscaler.utils.load_model() - - if not load_content: - for name, buffer in model.named_buffers(): - if 'rp_index' in name: - window_size = int(math.sqrt(buffer.size(0))) - buffer.copy_(init_relative_position_index(window_size).cuda()) - - optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999)) + gen_data = partial(dummy_data, args.mbs, torch.float16 if args.fp16 else torch.float32, cfg) + + init_random() + DATA_SIZE = 1024 + data = [] + for _ in range(DATA_SIZE): + data.append(gen_data()) + + num_replicas = compute_config.runtime_ngpus // compute_config.plan_ngpus + rank = torch.distributed.get_rank() // compute_config.plan_ngpus + data = [data[i] for i in range(rank, len(data), num_replicas)] + chunk_size = args.gbs // args.mbs + data = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)] + + # get policy + prefix = 'pas_' + policy_name = prefix + args.policy + if policy_name in gallery.__dict__: + policy = gallery.__dict__[policy_name] + else: + policy = args.policy # use the builtin policies + + model: nnscaler.ParallelModule = nnscaler.parallelize( + model, + dummy_input={'x': gen_data()}, + pas_policy=policy, + compute_config=compute_config, + reuse='moo', + instance_name=args.policy + ) + model.cuda() + + optimizer = nnscaler.build_optimizer(model, torch.optim.Adam, lr=5e-4, betas=(0.9, 0.999)) torch.distributed.barrier() + print_each_rank('model weight consumpition:') memory_summary() nparams = 0 @@ -86,32 +105,71 @@ def train_iter(model, dataloader): nparams += param.nelement() print_each_rank(f'model parameter: {nparams}') - CudaTimer().warmup() - iter_num, warmup = 5, 2 - for step in range(iter_num): - if step == warmup: - CudaTimer(enable=True).start('e2e') + iter_num = 5 + for idx in range(iter_num): + model.train() # collect data - samples = [gen_data() for _ in range(args.gbs // args.mbs)] - dataloader = microbatches(samples, dtype=dtype) - # train iteration - train_iter(model, dataloader) + samples = data[idx] + + model.train_step(samples) optimizer.step() optimizer.zero_grad() - if step == 0: - print_each_rank('passed first iteration') - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - CudaTimer().stop('e2e') - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) memory_summary() if __name__ == '__main__': - - train() + parser = argparse.ArgumentParser(description='swin Train') + parser.add_argument('--policy', type=str, help='PAS policy choice') + parser.add_argument('--fp16', action='store_true', default=False, + help='use fp16 for the training') + parser.add_argument('--dp_size', type=int, default=1, + help='size of data parallelism') + parser.add_argument('--pp_size', type=int, default=1, + help='size of pipeline parallelism') + parser.add_argument('--tp_size', type=int, default=1, + help='size of tensor parallelism') + parser.add_argument('--zero', action='store_true', default=False, + help='use zero1 for the training') + parser.add_argument('--mbs', type=int, default=4, help='micro batch size') + parser.add_argument('--gbs', type=int, default=4, help='global batch size') + + args = parser.parse_args() + + nnscaler.init() + + if torch.distributed.get_world_size() != args.dp_size * args.pp_size * args.tp_size: + raise ValueError('world size should be equal to dp_size * pp_size * tp_size') + if args.gbs % args.mbs != 0: + raise ValueError('global batch size should be divisible by micro batch size') + + compute_config=nnscaler.ComputeConfig( + plan_ngpus=args.pp_size * args.tp_size, + runtime_ngpus=torch.distributed.get_world_size(), + use_zero=args.zero, + use_end2end=True, + dynamic_shape=False, + use_pipeline=args.pp_size > 1, + pipeline_nmicros=args.gbs // args.mbs, + pipeline_nstages=args.pp_size, + user_config=nnscaler.UserConfig( + graph={ + 'mbs': args.mbs, + 'fp16': args.fp16, + 'src_hash': src_hash(), + }, + code={ + 'pas_name': args.policy, + 'gbs': args.gbs, + 'pp_size': args.pp_size, + 'tp_size': args.tp_size, + 'dp_size': args.dp_size, + 'pas': { + 'update_freq': args.gbs // args.mbs, # for autodist only + } + } + ) + ) + + train(args, compute_config) diff --git a/nnscaler/algorithm/ops/dimops.py b/nnscaler/algorithm/ops/dimops.py index 9c200ad1..794e4a72 100644 --- a/nnscaler/algorithm/ops/dimops.py +++ b/nnscaler/algorithm/ops/dimops.py @@ -241,7 +241,7 @@ def modify(kwargs: Dict, idx: int, dim: int, num: int): return TransformRule(itransform, otransform, modify) -def collect_split_info(node: IRFwOperation): +def collect_split_info(node: IRDimops): """ Collect the split information of the node. Args: @@ -258,6 +258,7 @@ def collect_split_info(node: IRFwOperation): split_info = {} for idx_shape, shape_anno in enumerate(anno.inputs()): + if shape_anno.ignore: continue if not isinstance(node.inputs()[idx_shape], IRSubTensor): continue for idx_dim, dim_anno in enumerate(shape_anno.dims): diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 56bdbc7d..1286b140 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -3,7 +3,7 @@ from .autodist_config import AutoDistConfig from .model_graph import ModelGraph, estimate_mem_lower_bound from .descs import * -from .util import get_node_arch, replica, partition_node +from .util import replica, partition_node from nnscaler.graph import IRGraph from nnscaler.graph.segment import IRSegment @@ -29,7 +29,7 @@ def check_env(autodist_config: AutoDistConfig): - arch_dir = autodist_config.profile_dir / get_node_arch() + arch_dir = Path(autodist_config.profile_dir) if not arch_dir.exists(): _logger.info(f'create folder: {arch_dir}') arch_dir.mkdir(parents=True, exist_ok=True) diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index b261de58..87d165be 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -44,7 +44,7 @@ class AutoDistConfig: - fp16 & bf16 training w/ memory efficient adam w/o inkernal cast: (2 + 2) (fp32 weight + fp32 gradient) - partition_constraints_path (`str`, *optional*, defaults to `''`): The path to the partition constraints file. Details can be found in docs/solver_interface/partition_constraints.md - - profile_dir (`str`, *optional*, defaults to `~/.cache/nnscaler/autodist`): + - profile_dir (`str`, *optional*, defaults to `~/.cache/nnscaler/autodist/1.0/get_node_arch()`): The directory to store the profiling results. - load_plan_path (`str`, *optional*, defaults to `''`): The path to the plan file to load. If specified, the plan will be loaded from the file instead of searching. diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index d34125a3..2408b469 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -15,7 +15,6 @@ from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.function.dimops import DimopSplit, IRDimops -from .util import get_node_arch from .autodist_config import AutoDistConfig _logger = logging.getLogger(__name__) @@ -63,7 +62,7 @@ def __init__(self, graph: IRGraph, config: AutoDistConfig): self.graph = graph self.autodist_config = config - self.profile_dir = Path(config.profile_dir) / get_node_arch() + self.profile_dir = Path(config.profile_dir) self.db = ProfileDataBase() self.comp_profile_path = self.profile_dir / 'comp' if not self.comp_profile_path.exists(): diff --git a/nnscaler/autodist/util.py b/nnscaler/autodist/util.py index 7a632046..c256c328 100644 --- a/nnscaler/autodist/util.py +++ b/nnscaler/autodist/util.py @@ -23,13 +23,18 @@ def double2byte(val): def double4byte(val): return struct.unpack('d', val)[0] + def get_default_profile_path(): - return Path.home() / '.cache' / 'nnscaler' / 'autodist' / '1.0' + return Path.home() / '.cache' / 'nnscaler' / 'autodist' / '1.0' / get_node_arch() + def get_node_arch(): import torch - return torch.cuda.get_device_name(torch.cuda.current_device()).replace( - ' ', '_') + if torch.cuda.is_available(): + return torch.cuda.get_device_name(torch.cuda.current_device()).replace( + ' ', '_') + else: + return 'cpu' # although we don't support cpu now, we still need to return something for testing # tensor parallelism diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index eb58f7c3..dea2160d 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -67,6 +67,7 @@ import re import string import logging +from itertools import dropwhile from nnscaler.ir.cten import IRTensor, IRObject from nnscaler.ir.operator import IRFwOperation @@ -149,7 +150,7 @@ def parse(anno: Union[str, Tuple[str]]) -> Tuple[Tuple[str], Tuple[ReduceType]]: assert str.isdecimal(identifier) or str.isidentifier(identifier) or identifier in _kSpecialIdentifiers, \ f"identifier can only be integer or python identifier but got {identifier}" # integer will always have stay reduction type - if str.isdecimal(identifier): + if str.isdecimal(identifier) or identifier == '?': reduce = DimAnno.ReduceType.Freeze identifiers.append(identifier) reduces.append(reduce) @@ -481,6 +482,45 @@ def create_op_str(ins: Tuple[Tuple[Union[str, Tuple[str]]]], ou_annos.append(' '.join(flatten)) return ', '.join(in_annos) + ' -> ' + ', '.join(ou_annos) + def transform_space(self) -> List[Tuple[int, int]]: + """ + Get transformation space of the operator, the transformation space + represents all configurations that can be segmented + + @return List[Tuple[int, int]]: list of (idx, dim) + """ + # only the first identifier in a dim anno is partitionable + # eg. (a b c) x -> (b x) + # b, c or x can't be partitioned, because they are not in the first position + # a special case is when the leading identifiers are '1' + # for example + # (1 a b) -> b or (1 1 a b) -> b + # in both cases, a can be partitioned, but b can't + + # collect all unpartitioned identifiers that are not in first position + nonleading_ids = set() + for shape in self.inputs() + self.outputs(): + for dim, dim_anno in enumerate(shape.dims): + for identifier in list(dropwhile(lambda x: x == '1', dim_anno.identifiers))[1:]: + if not str.isdecimal(identifier): + nonleading_ids.add(identifier) + + visited : Set[str] = set() # to remove equavalent configurations + configs = [] + shapes = self.inputs() + for idx, shape in enumerate(shapes): + if shape.ignore: continue + for dim, edim in enumerate(shape.dims): + # this for loop just checks the first element. + for identifier, reduce in dropwhile(lambda x: x[0] == '1', zip(edim.identifiers, edim.reduces)): + if identifier in visited: continue + visited.add(identifier) + if reduce != DimAnno.ReduceType.Freeze and identifier not in nonleading_ids: + configs.append((idx, dim)) + break + + return configs + class DimopSplit: """ @@ -774,8 +814,8 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict for ashape, itensor in zip(op_anno.inputs(), inputs): if itensor is None: continue - if not (isinstance(itensor, IRTensor) ^ ashape.ignore): - return False + if ashape.ignore: + continue if not isinstance(itensor, IRTensor): continue if ashape.ndims != len(itensor.shape): @@ -841,24 +881,9 @@ def algorithms(self, tag: Optional[str] = None): def transform_space(self) -> List[Tuple[int, int]]: """ - Get transformation space of the operator, the transformation space + Get transformation space of the operator, the transformation space represents all configurations that can be segmented @return List[Tuple[int, int]]: list of (idx, dim) """ - visited : Set[str] = set() - configs = [] - ashapes = self.anno.inputs() + self.anno.outputs() - for idx, eshape in enumerate(ashapes): - if eshape.ignore: continue - if idx < len(self.inputs()): - if not isinstance(self.input(idx), IRTensor): continue - for dim, edim in enumerate(eshape.dims): - for identifier, reduce in zip(edim.identifiers, edim.reduces): - if identifier in visited: continue - visited.add(identifier) - if identifier == '1' or self.anno.getlen(identifier) == 1: continue - if reduce == DimAnno.ReduceType.Freeze: break - configs.append((idx, dim)) - break - return configs + return self.anno.transform_space() diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index d7c381c3..3ff700e5 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -194,7 +194,7 @@ def udfop(*args, signature=None, **kwargs): tensors = args[:ninputs] for idx, t in enumerate(tensors): # argument check - if str(anno.input(idx)) != '?': + if not anno.input(idx).ignore: if not isinstance(t, IRTensor): raise ValueError( f"{idx}-th input needs IRTensor, but got {type(t)}: {t}\n" diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 041681e0..e45354a9 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -4,7 +4,8 @@ Users can write the policy following the steps: 1. Apply multiref -2. Apply recompute + If all consumers of a full tensor consume the same subtensor (the partitions are exactly the same), we can skip this step. +2. Apply recompute (if needed) 3. Graph staging (pipeline only) 4. Graph partition & assign 5. Apply schedule (pipeline only) @@ -63,14 +64,8 @@ def pas_dp(graph: IRGraph, cfg: 'ComputeConfig'): if ngpus != 1: raise ValueError("Data parallelism only supports 1 plan GPU") - for ftensor in graph.full_tensors(): - if ftensor.is_grad(): continue - if len(graph.consumers(ftensor)) > 1: - graph.multiref(ftensor) - + # no partition is done, so we can skip multiref safely for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if node.name == 'multiref' or isinstance(node, IRGraphAnchor): - continue _replica(graph, node, [0]) return graph @@ -101,7 +96,7 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): _replica(graph, node, devs) else: configs = sorted(configs, reverse=True, - key=lambda config: node.input(config[0]).shape[config[1]]) + key=lambda config: node.input(config[0]).shape[config[1]]) random.shuffle(configs) for (idx, dim) in configs: if node.input(idx).shape[dim] % len(devs) != 0: continue diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index d7927bfd..7f5117f6 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -14,6 +14,7 @@ import _operator # required by eval() import nnscaler # required by eval() +from nnscaler.graph.function.dimops import IRDimops from nnscaler.ir.cten import IRTensor, IRObject from nnscaler.ir.operator import IRFwOperation from nnscaler.graph.parser.register import CustomizedOps @@ -279,7 +280,7 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b if isinstance(device, int): orig_device = torch.cuda.current_device() torch.cuda.set_device(device) - + in_mem_info, param_mem_info, buffer_mem_info = [], [], [] for t in node.inputs(): if isinstance(t, IRTensor) and t.is_param(): diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index b0f7867b..4a17635f 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -48,7 +48,7 @@ def fullslice(input: torch.Tensor, *slicers: Union[None, slice, int, torch.Tenso 1) `None` will always extend a dimension at current position. 2) `slice(None, None, None)` equals to `:`, meaning select every element at its dimension. - + Args: input (torch.Tensor): input tensor slicers (Union[None | slicer | int | torch.Tensor]): slicers for input @@ -181,13 +181,13 @@ def rand(size: Tuple[int], dtype=None, requires_grad=False): def full(size: Tuple[int], fill_value, dtype=None, requires_grad=False): return torch.full( - size, fill_value, dtype=dtype, requires_grad=requires_grad, + size, fill_value, dtype=dtype, requires_grad=requires_grad, device=torch.cuda.current_device() ) def arange(start: int, end: int, step: int, dtype: torch.dtype, requires_grad=False): - return torch.arange(start=start, end=end, step=step, + return torch.arange(start=start, end=end, step=step, dtype=dtype, requires_grad=requires_grad, device=torch.cuda.current_device()) diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index 6643628e..e992d485 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -82,7 +82,8 @@ def test_follow_rope(): in future, we may add follow chains for binary ops, like mul, add, etc. ''' - cfg = AutoDistConfig(mesh_col=2, re_profile=True) + profile_dir = Path(os.path.dirname(__file__)) / './test_follow_rope_profile' + cfg = AutoDistConfig(mesh_col=2, profile_dir=profile_dir) model_graph = ModelGraph(ir_graph, cfg) spmd_solver = SPMDSolver( @@ -208,7 +209,8 @@ def test_follow_attention(): ''' pc_path = Path(os.path.dirname(__file__)) / 'test_attention_follow.yaml' - cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2, re_profile=True) + profile_dir = Path(os.path.dirname(__file__)) / './test_follow_attention_profile' + cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2, profile_dir=profile_dir) model_graph = ModelGraph(ir_graph, cfg) spmd_solver = SPMDSolver( diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.contiguous.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.contiguous.json new file mode 100644 index 00000000..fb371404 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.contiguous.json @@ -0,0 +1,62 @@ +{ + "(2, 128, 8, 64)-(2, 128, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005054101347923279, + "bw_span": 0.029171817004680634, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 8, 64)-(1, 128, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006699748337268829, + "bw_span": 0.025911815464496613, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 8, 64)-(2, 64, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006890296936035156, + "bw_span": 0.03253184258937836, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 4, 64)-(2, 128, 4, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006747245788574219, + "bw_span": 0.025579333305358887, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 8, 32)-(2, 128, 8, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.0067526474595069885, + "bw_span": 0.026317313313484192, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.reshape.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.reshape.json new file mode 100644 index 00000000..b3936a04 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.reshape.json @@ -0,0 +1,50 @@ +{ + "(2, 128, 8, 64)-(2, 128, 512) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009675882756710052, + "bw_span": 0.048818811774253845, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 8, 64)-(1, 128, 512) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010307878255844116, + "bw_span": 0.035110488533973694, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 8, 64)-(2, 64, 512) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009638629853725433, + "bw_span": 0.036632269620895386, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 4, 64)-(2, 128, 256) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009684823453426361, + "bw_span": 0.0355524942278862, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.view.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.view.json new file mode 100644 index 00000000..152252b5 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.Tensor.view.json @@ -0,0 +1,50 @@ +{ + "(2, 128, 512)-(2, 128, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009916722774505615, + "bw_span": 0.049952976405620575, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 512)-(1, 128, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009810179471969604, + "bw_span": 0.037309154868125916, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 512)-(2, 64, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009626708924770355, + "bw_span": 0.03499723970890045, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 256)-(2, 128, 4, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010732002556324005, + "bw_span": 0.03535933792591095, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.div.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.div.json new file mode 100644 index 00000000..85b00b0c --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.div.json @@ -0,0 +1,62 @@ +{ + "(2, 8, 128, 128)-(2, 8, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1048576 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01583155244588852, + "bw_span": 0.0686870887875557, + "infer_memory": 1048576, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 128)-(1, 8, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.023482181131839752, + "bw_span": 0.08055288344621658, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 128)-(2, 4, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.02271253615617752, + "bw_span": 0.07915832102298737, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 128)-(2, 8, 64, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.015473924577236176, + "bw_span": 0.06928052753210068, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 64)-(2, 8, 128, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.024519674479961395, + "bw_span": 0.07612667977809906, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.matmul.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.matmul.json new file mode 100644 index 00000000..242863a4 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.matmul.json @@ -0,0 +1,211 @@ +{ + "(2, 8, 128, 64)-(2, 8, 64, 128)-(2, 8, 128, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 524288, + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.024786032736301422, + "bw_span": 0.15842635184526443, + "infer_memory": 1048576, + "train_mem_info": [ + 524288, + 524288 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 128, 128)-(2, 8, 128, 64)-(2, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 1048576, + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03702659159898758, + "bw_span": 0.17407722771167755, + "infer_memory": 524288, + "train_mem_info": [ + 524288, + 1048576 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(1, 8, 128, 64)-(1, 8, 64, 128)-(1, 8, 128, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.036241114139556885, + "bw_span": 0.18119476735591888, + "infer_memory": 524288, + "train_mem_info": [ + 262144, + 262144 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 4, 128, 64)-(2, 4, 64, 128)-(2, 4, 128, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.025830231606960297, + "bw_span": 0.16476567834615707, + "infer_memory": 524288, + "train_mem_info": [ + 262144, + 262144 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 64, 64)-(2, 8, 64, 128)-(2, 8, 64, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.036576949059963226, + "bw_span": 0.18481463193893433, + "infer_memory": 524288, + "train_mem_info": [ + 524288, + 262144 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 128, 32)-(2, 8, 32, 128)-(2, 8, 128, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03641769289970398, + "bw_span": 0.18008965998888016, + "infer_memory": 1048576, + "train_mem_info": [ + 262144, + 262144 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 128, 64)-(2, 8, 64, 64)-(2, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 524288, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.025328807532787323, + "bw_span": 0.15807561576366425, + "infer_memory": 524288, + "train_mem_info": [ + 262144, + 524288 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(1, 8, 128, 128)-(1, 8, 128, 64)-(1, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 524288, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03785807639360428, + "bw_span": 0.17732437700033188, + "infer_memory": 262144, + "train_mem_info": [ + 262144, + 524288 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 4, 128, 128)-(2, 4, 128, 64)-(2, 4, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 524288, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03658011555671692, + "bw_span": 0.24166293442249298, + "infer_memory": 262144, + "train_mem_info": [ + 262144, + 524288 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 64, 128)-(2, 8, 128, 64)-(2, 8, 64, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 524288, + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.0364525243639946, + "bw_span": 0.1794666051864624, + "infer_memory": 262144, + "train_mem_info": [ + 524288, + 524288 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 128, 128)-(2, 8, 128, 32)-(2, 8, 128, 32) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 1048576, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.035658665001392365, + "bw_span": 0.16031749546527863, + "infer_memory": 262144, + "train_mem_info": [ + 262144, + 1048576 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.dropout.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.dropout.json new file mode 100644 index 00000000..e496b684 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.dropout.json @@ -0,0 +1,62 @@ +{ + "(2, 8, 128, 128)-(2, 8, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1048576 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006994791328907013, + "bw_span": 0.030942633748054504, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 128)-(1, 8, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01101568341255188, + "bw_span": 0.033733807504177094, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 128)-(2, 4, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009799189865589142, + "bw_span": 0.029200315475463867, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 128)-(2, 8, 64, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00997837632894516, + "bw_span": 0.0340314581990242, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 64)-(2, 8, 128, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009815767407417297, + "bw_span": 0.034654513001441956, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.linear.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.linear.json new file mode 100644 index 00000000..cb3b2bbc --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.linear.json @@ -0,0 +1,182 @@ +{ + "(2, 128, 512)-(512, 512)-(2, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.036911480128765106, + "bw_span": 0.13526901602745056, + "infer_memory": 524288, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 512)-(512, 512)-(2, 128, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.024814344942569733, + "bw_span": 0.16298573464155197, + "infer_memory": 524288, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 128, 512)-(512, 512)-(1, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.040480680763721466, + "bw_span": 0.14026649296283722, + "infer_memory": 262144, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 64, 512)-(512, 512)-(2, 64, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.04645884037017822, + "bw_span": 0.13854112476110458, + "infer_memory": 262144, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 256)-(512, 256)-(2, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.036505237221717834, + "bw_span": 0.13390127569437027, + "infer_memory": 524288, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 512)-(256, 512)-(2, 128, 256) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.04011690616607666, + "bw_span": 0.12174341827630997, + "infer_memory": 262144, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 128, 512)-(512, 512)-(1, 128, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.03942213952541351, + "bw_span": 0.16794409602880478, + "infer_memory": 262144, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 64, 512)-(512, 512)-(2, 64, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.0401325523853302, + "bw_span": 0.1276616007089615, + "infer_memory": 262144, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 256)-(512, 256)-(2, 128, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.03376603126525879, + "bw_span": 0.11620894074440002, + "infer_memory": 524288, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 512)-(256, 512)-(2, 128, 256) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.04029721021652222, + "bw_span": 0.15293508768081665, + "infer_memory": 262144, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + 0 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.softmax.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.softmax.json new file mode 100644 index 00000000..6d9cc888 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.nn.functional.softmax.json @@ -0,0 +1,66 @@ +{ + "(2, 8, 128, 128)-(2, 8, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 1048576 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010657869279384613, + "bw_span": 0.05731713026762009, + "infer_memory": 1048576, + "train_mem_info": [ + 1048576 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(1, 8, 128, 128)-(1, 8, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013040006160736084, + "bw_span": 0.0829434022307396, + "infer_memory": 524288, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(2, 4, 128, 128)-(2, 4, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.0180685892701149, + "bw_span": 0.08466336876153946, + "infer_memory": 524288, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(2, 8, 64, 128)-(2, 8, 64, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012842938303947449, + "bw_span": 0.07357802242040634, + "infer_memory": 524288, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + -1 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.sum.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.sum.json new file mode 100644 index 00000000..986493b4 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.sum.json @@ -0,0 +1,50 @@ +{ + "(2, 128, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01636110246181488, + "bw_span": 0.06344523280858994, + "infer_memory": 1536, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012609921395778656, + "bw_span": 0.034259818494319916, + "infer_memory": 512, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012742914259433746, + "bw_span": 0.03195144236087799, + "infer_memory": 512, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 256)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012966431677341461, + "bw_span": 0.03117881715297699, + "infer_memory": 512, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.transpose.json b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.transpose.json new file mode 100644 index 00000000..e7116ccc --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_attention_profile/comp/torch.transpose.json @@ -0,0 +1,182 @@ +{ + "(2, 128, 8, 64)-(2, 8, 128, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00963062047958374, + "bw_span": 0.045523419976234436, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 64)-(2, 8, 64, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006907619535923004, + "bw_span": 0.039892829954624176, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 64)-(2, 128, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010016374289989471, + "bw_span": 0.04464220255613327, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 8, 64)-(1, 8, 128, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009498931467533112, + "bw_span": 0.03776978701353073, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 8, 64)-(2, 8, 64, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012565404176712036, + "bw_span": 0.053708069026470184, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 4, 64)-(2, 4, 128, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009473785758018494, + "bw_span": 0.043984316289424896, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 8, 32)-(2, 8, 128, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009542889893054962, + "bw_span": 0.04641469568014145, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 64)-(1, 8, 64, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009560585021972656, + "bw_span": 0.047761574387550354, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 64)-(2, 4, 64, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009697489440441132, + "bw_span": 0.05306694656610489, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 64)-(2, 8, 64, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006787292659282684, + "bw_span": 0.039852969348430634, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 32)-(2, 8, 32, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009612180292606354, + "bw_span": 0.04738885909318924, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 64)-(1, 128, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009609386324882507, + "bw_span": 0.03564096987247467, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 64)-(2, 128, 4, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00969376415014267, + "bw_span": 0.03511160612106323, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 64)-(2, 64, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00937674194574356, + "bw_span": 0.03548767417669296, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 32)-(2, 128, 8, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00964682549238205, + "bw_span": 0.03964081406593323, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/_operator.neg.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/_operator.neg.json new file mode 100644 index 00000000..2f1a93e2 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/_operator.neg.json @@ -0,0 +1,50 @@ +{ + "(2, 1, 128, 256)-(2, 1, 128, 256) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009768269956111908, + "bw_span": 0.0, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 1, 128, 256)-(1, 1, 128, 256) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010193884372711182, + "bw_span": 0.0, + "infer_memory": 131072, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 64, 256)-(2, 1, 64, 256) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009867176413536072, + "bw_span": 0.0011801719665527344, + "infer_memory": 131072, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 128, 128)-(2, 1, 128, 128) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009940750896930695, + "bw_span": 0.0, + "infer_memory": 131072, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.cat.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.cat.json new file mode 100644 index 00000000..781eeb37 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.cat.json @@ -0,0 +1,41 @@ +{ + "(2, 1, 128, 256)-(2, 1, 128, 256)-(2, 1, 128, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012085027992725372, + "bw_span": 0.0, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 1, 128, 256)-(1, 1, 128, 256)-(1, 1, 128, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 131072, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012713484466075897, + "bw_span": 0.0, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 64, 256)-(2, 1, 64, 256)-(2, 1, 64, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 131072, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012474879622459412, + "bw_span": 0.0, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.fullslice.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.fullslice.json new file mode 100644 index 00000000..97036417 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.fullslice.json @@ -0,0 +1,77 @@ +{ + "(128, 512)-(128,)-(128, 512) : torch.float32-torch.int64-torch.float32 : False-False-False": { + "in_mem_info": [ + 262144, + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.014503486454486847, + "bw_span": 0.0, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 128, 512)-(2, 1, 128, 256) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00962521880865097, + "bw_span": 0.0, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(128, 256)-(128,)-(128, 256) : torch.float32-torch.int64-torch.float32 : False-False-False": { + "in_mem_info": [ + 131072, + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.014260224997997284, + "bw_span": 0.0, + "infer_memory": 131072, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(128, 512)-(64,)-(64, 512) : torch.float32-torch.int64-torch.float32 : False-False-False": { + "in_mem_info": [ + 262144, + 512 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.0144224613904953, + "bw_span": 0.0, + "infer_memory": 131072, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 1, 128, 512)-(1, 1, 128, 256) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009587407112121582, + "bw_span": 0.0, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 64, 512)-(2, 1, 64, 256) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009298883378505707, + "bw_span": 0.0, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.add.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.add.json new file mode 100644 index 00000000..4e0f6057 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.add.json @@ -0,0 +1,67 @@ +{ + "(2, 128, 128, 512)-(2, 128, 128, 512)-(2, 128, 128, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 67108864, + 67108864 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.29664188623428345, + "bw_span": 0.0, + "infer_memory": 67108864, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 128, 512)-(1, 128, 128, 512)-(1, 128, 128, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 33554432, + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.1504974439740181, + "bw_span": 0.0, + "infer_memory": 33554432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 128, 512)-(2, 64, 128, 512)-(2, 64, 128, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 33554432, + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.15051700174808502, + "bw_span": 0.0, + "infer_memory": 33554432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 64, 512)-(2, 128, 64, 512)-(2, 128, 64, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 33554432, + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.15072356909513474, + "bw_span": 0.0, + "infer_memory": 33554432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 128, 256)-(2, 128, 128, 256)-(2, 128, 128, 256) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 33554432, + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.15038345009088516, + "bw_span": 0.0, + "infer_memory": 33554432, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.mul.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.mul.json new file mode 100644 index 00000000..3e48818b --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.mul.json @@ -0,0 +1,67 @@ +{ + "(2, 1, 128, 512)-(128, 1, 512)-(2, 128, 128, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 524288, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.11223778128623962, + "bw_span": 0.0, + "infer_memory": 67108864, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 1, 128, 512)-(128, 1, 512)-(1, 128, 128, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.05703773349523544, + "bw_span": 0.0, + "infer_memory": 33554432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 64, 512)-(128, 1, 512)-(2, 128, 64, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.059348903596401215, + "bw_span": 0.0, + "infer_memory": 33554432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 128, 256)-(128, 1, 256)-(2, 128, 128, 256) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 262144, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.057874247431755066, + "bw_span": 0.0, + "infer_memory": 33554432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 128, 512)-(64, 1, 512)-(2, 64, 128, 512) : torch.float32-torch.float32-torch.float32 : False-False-False": { + "in_mem_info": [ + 524288, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.057921744883060455, + "bw_span": 0.0, + "infer_memory": 33554432, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.sum.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.sum.json new file mode 100644 index 00000000..0d67bd7d --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.sum.json @@ -0,0 +1,62 @@ +{ + "(2, 128, 128, 512)-(1,) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 67108864 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.10211896151304245, + "bw_span": 0.0, + "infer_memory": 2048, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 128, 512)-(1,) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.05497559905052185, + "bw_span": 0.0, + "infer_memory": 2048, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 128, 512)-(1,) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.05486011505126953, + "bw_span": 0.0, + "infer_memory": 2048, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 64, 512)-(1,) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.054844655096530914, + "bw_span": 0.0, + "infer_memory": 2048, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 128, 256)-(1,) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.055054761469364166, + "bw_span": 0.0, + "infer_memory": 2048, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.unsqueeze.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.unsqueeze.json new file mode 100644 index 00000000..dafafa89 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.unsqueeze.json @@ -0,0 +1,38 @@ +{ + "(128, 512)-(128, 1, 512) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005235709249973297, + "bw_span": 0.0, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(64, 512)-(64, 1, 512) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005417689681053162, + "bw_span": 0.0, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(128, 256)-(128, 1, 256) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005025602877140045, + "bw_span": 0.0, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_partition_constraint.py b/tests/autodist/spmd_solver/test_partition_constraint.py index da51bbe1..d1cee788 100644 --- a/tests/autodist/spmd_solver/test_partition_constraint.py +++ b/tests/autodist/spmd_solver/test_partition_constraint.py @@ -75,7 +75,8 @@ def test_partition_constraint(): pc_path = Path(os.path.dirname( os.path.realpath(__file__))) / 'test_pc.yaml' - cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2, re_profile=True) + profile_dir = Path(os.path.dirname(os.path.realpath(__file__))) / 'test_partition_constraint_profile' + cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2, profile_dir=profile_dir) model_graph = ModelGraph(ir_graph, cfg) spmd_solver = SPMDSolver( diff --git a/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.matmul.json b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.matmul.json new file mode 100644 index 00000000..e6e6abbd --- /dev/null +++ b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.matmul.json @@ -0,0 +1,192 @@ +{ + "(2, 128, 768)-(2, 768, 128)-(2, 128, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 786432, + 786432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.0893130898475647, + "bw_span": 0.1256939023733139, + "infer_memory": 131072, + "train_mem_info": [ + 786432, + 786432 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 128, 128)-(2, 128, 768)-(2, 128, 768) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 131072, + 786432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.035924091935157776, + "bw_span": 0.16662534326314926, + "infer_memory": 786432, + "train_mem_info": [ + 786432, + 131072 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(1, 128, 768)-(1, 768, 128)-(1, 128, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 393216, + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.046819448471069336, + "bw_span": 0.18264036625623703, + "infer_memory": 65536, + "train_mem_info": [ + 393216, + 393216 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 64, 768)-(2, 768, 128)-(2, 64, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 393216, + 786432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.08911415934562683, + "bw_span": 0.12745074927806854, + "infer_memory": 65536, + "train_mem_info": [ + 786432, + 393216 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 128, 384)-(2, 384, 128)-(2, 128, 128) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 393216, + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.05186498165130615, + "bw_span": 0.16340836882591248, + "infer_memory": 131072, + "train_mem_info": [ + 393216, + 393216 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 128, 768)-(2, 768, 64)-(2, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 786432, + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.08894223719835281, + "bw_span": 0.044196657836437225, + "infer_memory": 65536, + "train_mem_info": [ + 393216, + 786432 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(1, 128, 128)-(1, 128, 768)-(1, 128, 768) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 65536, + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03444403409957886, + "bw_span": 0.16248486936092377, + "infer_memory": 393216, + "train_mem_info": [ + 393216, + 65536 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 64, 128)-(2, 128, 768)-(2, 64, 768) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 65536, + 786432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03483705222606659, + "bw_span": 0.1537581905722618, + "infer_memory": 393216, + "train_mem_info": [ + 786432, + 65536 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 128, 64)-(2, 64, 768)-(2, 128, 768) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 65536, + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.03511644899845123, + "bw_span": 0.15609469264745712, + "infer_memory": 786432, + "train_mem_info": [ + 393216, + 65536 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 128, 128)-(2, 128, 384)-(2, 128, 384) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 131072, + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.034658610820770264, + "bw_span": 0.1550775021314621, + "infer_memory": 393216, + "train_mem_info": [ + 393216, + 131072 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.linear.json b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.linear.json new file mode 100644 index 00000000..d394cc28 --- /dev/null +++ b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.linear.json @@ -0,0 +1,182 @@ +{ + "(2, 128, 768)-(768, 768)-(2, 128, 768) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 786432 + ], + "param_mem_info": [ + 2359296 + ], + "buffer_mem_info": [], + "fw_span": 0.043386779725551605, + "bw_span": 0.13291947543621063, + "infer_memory": 786432, + "train_mem_info": [ + 786432 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 768)-(768, 768)-(2, 128, 768) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 786432 + ], + "param_mem_info": [ + 2359296 + ], + "buffer_mem_info": [], + "fw_span": 0.044011883437633514, + "bw_span": 0.18772706389427185, + "infer_memory": 786432, + "train_mem_info": [ + 786432 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 128, 768)-(768, 768)-(1, 128, 768) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [ + 2359296 + ], + "buffer_mem_info": [], + "fw_span": 0.04044007509946823, + "bw_span": 0.1281624659895897, + "infer_memory": 393216, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 64, 768)-(768, 768)-(2, 64, 768) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [ + 2359296 + ], + "buffer_mem_info": [], + "fw_span": 0.03905370831489563, + "bw_span": 0.1338878646492958, + "infer_memory": 393216, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 384)-(768, 384)-(2, 128, 768) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [ + 1179648 + ], + "buffer_mem_info": [], + "fw_span": 0.03585759550333023, + "bw_span": 0.1320449635386467, + "infer_memory": 786432, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 768)-(384, 768)-(2, 128, 384) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 786432 + ], + "param_mem_info": [ + 1179648 + ], + "buffer_mem_info": [], + "fw_span": 0.044881924986839294, + "bw_span": 0.1383565366268158, + "infer_memory": 393216, + "train_mem_info": [ + 786432 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 128, 768)-(768, 768)-(1, 128, 768) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [ + 2359296 + ], + "buffer_mem_info": [], + "fw_span": 0.03957469016313553, + "bw_span": 0.158010795712471, + "infer_memory": 393216, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 64, 768)-(768, 768)-(2, 64, 768) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [ + 2359296 + ], + "buffer_mem_info": [], + "fw_span": 0.04173126071691513, + "bw_span": 0.15814080834388733, + "infer_memory": 393216, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 384)-(768, 384)-(2, 128, 768) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [ + 1179648 + ], + "buffer_mem_info": [], + "fw_span": 0.041070207953453064, + "bw_span": 0.16851909458637238, + "infer_memory": 786432, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 768)-(384, 768)-(2, 128, 384) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 786432 + ], + "param_mem_info": [ + 1179648 + ], + "buffer_mem_info": [], + "fw_span": 0.039997510612010956, + "bw_span": 0.0878872349858284, + "infer_memory": 393216, + "train_mem_info": [ + 786432 + ], + "train_mem2in_idx": [ + 0 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.relu.json b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.relu.json new file mode 100644 index 00000000..5507fc56 --- /dev/null +++ b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.relu.json @@ -0,0 +1,66 @@ +{ + "(2, 128, 768)-(2, 128, 768) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 786432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01955125480890274, + "bw_span": 0.08346512913703918, + "infer_memory": 786432, + "train_mem_info": [ + 786432 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(1, 128, 768)-(1, 128, 768) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.018684566020965576, + "bw_span": 0.06451644003391266, + "infer_memory": 393216, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(2, 64, 768)-(2, 64, 768) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.018518418073654175, + "bw_span": 0.06296299397945404, + "infer_memory": 393216, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(2, 128, 384)-(2, 128, 384) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01855306327342987, + "bw_span": 0.06243288516998291, + "infer_memory": 393216, + "train_mem_info": [ + 393216 + ], + "train_mem2in_idx": [ + -1 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.softmax.json b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.softmax.json new file mode 100644 index 00000000..beeb07ed --- /dev/null +++ b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.nn.functional.softmax.json @@ -0,0 +1,50 @@ +{ + "(2, 128, 128)-(2, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01788698136806488, + "bw_span": 0.07958952337503433, + "infer_memory": 131072, + "train_mem_info": [ + 131072 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(1, 128, 128)-(1, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 65536 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.017834454774856567, + "bw_span": 0.06955564022064209, + "infer_memory": 65536, + "train_mem_info": [ + 65536 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(2, 64, 128)-(2, 64, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 65536 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.018155574798583984, + "bw_span": 0.0684056431055069, + "infer_memory": 65536, + "train_mem_info": [ + 65536 + ], + "train_mem2in_idx": [ + -1 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.sum.json b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.sum.json new file mode 100644 index 00000000..1b74a33c --- /dev/null +++ b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.sum.json @@ -0,0 +1,50 @@ +{ + "(2, 128, 768)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 786432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.026979856193065643, + "bw_span": 0.07029566913843155, + "infer_memory": 1536, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 768)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013635493814945221, + "bw_span": 0.03176294267177582, + "infer_memory": 512, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 768)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013938359916210175, + "bw_span": 0.03203488886356354, + "infer_memory": 512, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 384)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013530440628528595, + "bw_span": 0.03103390336036682, + "infer_memory": 512, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.transpose.json b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.transpose.json new file mode 100644 index 00000000..61d5c797 --- /dev/null +++ b/tests/autodist/spmd_solver/test_partition_constraint_profile/comp/torch.transpose.json @@ -0,0 +1,50 @@ +{ + "(2, 128, 768)-(2, 768, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 786432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009663403034210205, + "bw_span": 0.04680529236793518, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 768)-(1, 768, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009761378169059753, + "bw_span": 0.0479038804769516, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 768)-(2, 768, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009763240814208984, + "bw_span": 0.05019083619117737, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 384)-(2, 384, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 393216 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009500794112682343, + "bw_span": 0.04594437777996063, + "infer_memory": 0, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index cbfb230d..b516fe12 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -6,7 +6,7 @@ from functools import partial import nnscaler.graph.function as F -from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.function.dimops import IRDimops, OpAnno from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.cten import IRObject @@ -34,7 +34,7 @@ def partitionable(node: IRDimops, **config): test_view2 = partial(partitionable, create_op(F.Reshape, [(2048, 8, 64),], shape=[2048, 1, 512]), - idx=0, dim=1, num=2, + idx=0, dim=1, num=2, ) def UDFOp1(input, weight, signature='test_udf_op1'): @@ -52,7 +52,7 @@ def test_no_return_op(): def NoReturnOp(input, weight, signature='no_return_op'): anno = 'a b, b c -> ?' return IRDimops(NoReturnOp, 'no_return_op', signature, [anno], [input, weight]) - + op = create_op(NoReturnOp, [(1024, 512), (512, 1024)]) assert len(op.outputs()) == 1 and isinstance(op.output(0), IRObject) and (not isinstance(op.output(0), IRFullTensor)) @@ -84,3 +84,14 @@ def TestFunc(input, weight, bias, number=128, signature='test_func'): op = create_op(TestFunc, [(1024,), (2048,), (128,)], number=IRObject(value=128)) partitionable(op, idx=0, dim=0, num=2) + + +def test_transform_space(): + assert OpAnno('a b, b c -> a c').transform_space() == [(0, 0), (0, 1), (1, 1)] + assert OpAnno('a^ b, b c -> a^ c').transform_space() == [(0, 1), (1, 1)] + assert OpAnno('a b, (b n) c -> a (n c)').transform_space() == [(0, 0), (0, 1)] + assert OpAnno('a b, (b n) c -> a (n b c)').transform_space() == [(0, 0)] + assert OpAnno('a b, (b n) c -> a (1 b c) n').transform_space() == [(0, 0), (0, 1)] + assert OpAnno('a b, (b n) c -> a (1 1 1 b c) n').transform_space() == [(0, 0), (0, 1)] + assert OpAnno('a b, (b n) c^ -> a (1 1 1 b) n c^').transform_space() == [(0, 0), (0, 1)] + assert OpAnno('a b, (d^ n) c -> a (c n) d^').transform_space() == [(0, 0), (0, 1), (1,1)] From 4100bbecaad8b909f1e7b830b2180552454b32db Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 15 May 2024 08:35:40 +0000 Subject: [PATCH 1637/1892] Merged PR 2142: parallel module: add default argument support for tracing parallel module: add default argument support for tracing --- nnscaler/parallel.py | 40 +++++++++++++++++---------- nnscaler/runtime/module.py | 8 +++--- tests/parallel_module/test_gencode.py | 37 ++++++++++++++++++++++--- 3 files changed, 62 insertions(+), 23 deletions(-) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 3912fb9c..7ac09008 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -53,7 +53,7 @@ _PREDEFINE_SCHED_NAME_PREFIX = 'sched_' for k, v in PredefinedSched.__dict__.items(): if isinstance(v, staticmethod) and k.startswith(_PREDEFINE_SCHED_NAME_PREFIX): - _PREDEFINE_SCHEDS[k[len(_PREDEFINE_SCHED_NAME_PREFIX):]] = v + _PREDEFINE_SCHEDS[k[len(_PREDEFINE_SCHED_NAME_PREFIX):]] = getattr(PredefinedSched, k) # be compatible with python 3.8 _PREDEFINED_POLICIES: Dict[str, Callable[[IRGraph, 'ComputeConfig'], IRGraph]] = {} _PREDEFINED_POLICIES_NAME_PREFIX = 'pas_' @@ -174,6 +174,7 @@ def __post_init__(self): def apply_pipeline_scheduler(self, graph: IRGraph) -> Optional[SchedulePlan]: """ + Apply the pipeline scheduler to the graph. Do nothing if not use_pipeline """ if self.use_pipeline: @@ -317,7 +318,7 @@ def _add_cube_savedir_to_syspath(cube_savedir: str) -> Path: def _is_any_gencode_loaded(namespace: str) -> bool: """Check if a module is loaded""" - for m in sys.modules.values(): + for m in list(sys.modules.values()): # list() to avoid mulitple thread confliction # m.__name__ doesn't always work as some module doesn't have __name__ attribute. if getattr(m, '__name__', '').startswith(namespace + '.' + _GENCODE_FILE_PREFIX): return True @@ -553,6 +554,7 @@ def _gen_graph( dynamic_shape: bool, end2end_mode: bool = False, inference_only: bool = False, + use_pipeline: bool = False, ): # reset environment program = Program() @@ -591,8 +593,10 @@ def _gen_graph( ir_dummy_inputs.append(None) # always set None to *args/**kwargs elif node.target in dummy_input: ir_dummy_inputs.append(dummy_input[node.target]) + elif forward_args[node.target] is not inspect.Parameter.empty: + ir_dummy_inputs.append(forward_args[node.target]) else: - raise ValueError(f"Input {node.target} not in dummy input. Default value is not supported.") + raise ValueError(f"Input {node.target} not in dummy input, nor has default value.") for i in range(len(ir_dummy_inputs)): ir_dummy_inputs[i] = to_ir_input(ir_dummy_inputs[i], fx_input_nodes[i].target) # if the input is not a tensor, we should wrap it with IRObject @@ -634,8 +638,8 @@ def _gen_graph( if ir_dummy_outputs is None: ir_dummy_outputs = [] elif not isinstance(ir_dummy_outputs, (tuple, list)): ir_dummy_outputs = [ir_dummy_outputs] - if _contains_uncommutable_data(ir_dummy_outputs): - raise RuntimeError(f"Communication generation error: some of outputs are not commutable between gpus.") + if use_pipeline and _contains_uncommutable_data(ir_dummy_outputs): + raise RuntimeError(f"Communication generation error: some of outputs are not commutable between gpus, which is not supported in pipeline parallelism.") program.set_output(ir_dummy_outputs) program.finalize() @@ -714,7 +718,8 @@ def _gencode( graph, forward_args = _gen_graph( module, dummy_input, outdir, dynamic_shape=compute_config.dynamic_shape, end2end_mode=compute_config.use_end2end, - inference_only=compute_config.inference_only + inference_only=compute_config.inference_only, + use_pipeline=compute_config.use_pipeline, ) graph.dump(graph_ckp) torch.save(forward_args, forward_args_ckp) @@ -761,7 +766,9 @@ def _gencode( # code generation assert len(execplan.graph.device) == compute_config.plan_ngpus, f"{execplan.graph.device}" mgener = ModuleCodeGen(execplan, compute_config.runtime_ngpus) - sgener = ScheduleCodeGen(execplan, compute_config.runtime_ngpus) + sgener = None + if compute_config.use_end2end: + sgener = ScheduleCodeGen(execplan, compute_config.runtime_ngpus) for rank in range(compute_config.runtime_ngpus): fname = outdir / _GENCODE_FILE_TEMPLATE.format(rank) mgener.gen(rank, @@ -771,12 +778,14 @@ def _gencode( as_parallel_module=True, end2end_mode=compute_config.use_end2end ) - # generate temporal schedule code - sgener.gen( - device=rank, - outfile=fname, - attach=True - ) + # generate temporal schedule code only for end2end module + # because the code generated is wrong for non-end2end module. + if compute_config.use_end2end: + sgener.gen( + device=rank, + outfile=fname, + attach=True + ) return ret @@ -815,8 +824,9 @@ def _load_cube_module_class( cube_module_class.__qualname__ = module_class.__qualname__ # cube_module_class.__module__ = module_class.__module__ cube_module_class.__orig_module_class__ = module_class # save the original module class - cube_module_class._train_step = gen_imported._train_step - cube_module_class._infer_step = gen_imported._infer_step + # override train_step and infer_step only if they are defined in the generated module (end2end module only) + cube_module_class._train_step = getattr(gen_imported, '_train_step', cube_module_class._train_step) + cube_module_class._infer_step = getattr(gen_imported, '_infer_step', cube_module_class._infer_step) return cube_module_class diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 5bc36a6d..e839f682 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -708,23 +708,23 @@ def sync_grad(self): def _train_step(self, dataloader) -> Union[List[Any], Any]: """ - This function is assigned automatically when loading module class + This function is assigned automatically when loading end2end module class Returns: Union[List[Any], Any]: the output of the training step, In Pipeline mode, it should return a list of outputs for each sample Otherwise, it should return a single output """ - ... + raise NotImplementedError def _infer_step(self, dataloader) -> Union[List[Any], Any]: """ - This function is assigned automatically when loading module class + This function is assigned automatically when loading end2end module class Returns: Union[List[Any], Any]: the output of the training step, In Pipeline mode, it should return a list of outputs for each sample Otherwise, it should return a single output """ - ... + raise NotImplementedError def _scale_loss(self, is_dummy_batch: Optional[List[bool]], scale_fn: Optional[Callable[[torch.Tensor], torch.Tensor]]) -> None: """Setup cube backward hook for loss scale and dummy batch. diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index d98c2433..325c90d9 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -34,7 +34,9 @@ def _gencode_worker(tempdir): init_distributed() m = Module0() with pytest.raises(RuntimeError): # config mismatch - _to_cube_model(m, ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True) + pm = _to_cube_model(m, ComputeConfig(1, 1), cube_savedir=tempdir, load_module=True) + with pytest.raises(NotImplementedError): + pm._train_step(None) # for non-end2end parallel module, _train_step is not implemented @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @@ -196,6 +198,30 @@ def test_codegen_unused_args2(): launch_torchrun(1, _gencode_unused_args_worker2, tempdir) +class DefaultArgsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, m=0, n=None): + return self.linear(x) + m + + +@replace_all_device_with('cpu') +def test_codegen_default_args(): + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + DefaultArgsModule(), + {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, + 'dp', + ComputeConfig(1, 1), + cube_savedir=tempdir, + load_module=False + ) + # parallelize will succeed. + assert True + + class AttrModule(torch.nn.Module): def __init__(self): super().__init__() @@ -684,7 +710,7 @@ def p(cube_dir, use_pipeline, dynamic_shape, return_type, inference_only=False): parallelize( m, {'data': torch.randn(batch_size, dim), 'return_type': return_type}, - 'data', + 'data' if not use_pipeline else 'pp', compute_config= ComputeConfig( 4, 4, inference_only=inference_only, @@ -709,8 +735,11 @@ def p(cube_dir, use_pipeline, dynamic_shape, return_type, inference_only=False): r"self\.register_parameter" ) p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=0) # should success - with pytest.raises(RuntimeError, match='.*Communication generation.*'): - # fail for non-tensor IRObject return + if use_pipeline: + with pytest.raises(RuntimeError, match='.*Communication generation.*'): + # fail for non-tensor IRObject return in pipeline mode + p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=1) + else: p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=1) p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=1) # should success p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=2) # should success From 7e6169fa3c72bd1910e9d1089dbbb74c1969e970 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Wed, 15 May 2024 09:18:30 +0000 Subject: [PATCH 1638/1892] Merged PR 2123: Sphinx doc skeleton This PR sets up the sphinx doc framework. It currently contains 3 demonstrative doc pages: 1. `parallel_module.rst`, which is ported from the former `parallel_module.md` 2. `parallel.rst`, which is auto-generated from `nnscaler/parallel.py` docstr 3. `quickstart.rst`, an unfinished quickstart guide. It's committed for discussion. I can remove it from the PR. To build the doc: ``` cd docs make html ``` The render result can be previewed here: (corpnet required) http://10.150.240.223:10780/index.html --- .gitignore | 2 + docs/Makefile | 20 +++ .../{ => source}/autodist/interface_design.md | 0 .../solver_interface/partition_constraint.md | 0 .../solver_interface/pc_examples/moe_pc.yaml | 0 .../pc_examples/retnet_dp2_pc.yaml | 0 .../pc_examples/retnet_hybrid2_pc.yaml | 0 .../pc_examples/retnet_mp2_pc.yaml | 0 docs/source/conf.py | 33 +++++ docs/source/index.rst | 23 ++++ docs/source/parallel.rst | 5 + docs/{ => source}/parallel_module.md | 0 docs/source/quickstart.rst | 114 ++++++++++++++++++ docs/{ => source}/register_custom_op.md | 0 nnscaler/parallel.py | 88 ++++++++------ requirements-dev.txt | 4 + 16 files changed, 254 insertions(+), 35 deletions(-) create mode 100644 docs/Makefile rename docs/{ => source}/autodist/interface_design.md (100%) rename docs/{ => source}/autodist/solver_interface/partition_constraint.md (100%) rename docs/{ => source}/autodist/solver_interface/pc_examples/moe_pc.yaml (100%) rename docs/{ => source}/autodist/solver_interface/pc_examples/retnet_dp2_pc.yaml (100%) rename docs/{ => source}/autodist/solver_interface/pc_examples/retnet_hybrid2_pc.yaml (100%) rename docs/{ => source}/autodist/solver_interface/pc_examples/retnet_mp2_pc.yaml (100%) create mode 100644 docs/source/conf.py create mode 100644 docs/source/index.rst create mode 100644 docs/source/parallel.rst rename docs/{ => source}/parallel_module.md (100%) create mode 100644 docs/source/quickstart.rst rename docs/{ => source}/register_custom_op.md (100%) diff --git a/.gitignore b/.gitignore index b6fbec72..f76d9cb7 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,8 @@ fullmodel.pt fullmodel.pt.* dist_param_map.pt +docs/build/ + ## autodist ## # Python cache diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..d0c3cbf1 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/autodist/interface_design.md b/docs/source/autodist/interface_design.md similarity index 100% rename from docs/autodist/interface_design.md rename to docs/source/autodist/interface_design.md diff --git a/docs/autodist/solver_interface/partition_constraint.md b/docs/source/autodist/solver_interface/partition_constraint.md similarity index 100% rename from docs/autodist/solver_interface/partition_constraint.md rename to docs/source/autodist/solver_interface/partition_constraint.md diff --git a/docs/autodist/solver_interface/pc_examples/moe_pc.yaml b/docs/source/autodist/solver_interface/pc_examples/moe_pc.yaml similarity index 100% rename from docs/autodist/solver_interface/pc_examples/moe_pc.yaml rename to docs/source/autodist/solver_interface/pc_examples/moe_pc.yaml diff --git a/docs/autodist/solver_interface/pc_examples/retnet_dp2_pc.yaml b/docs/source/autodist/solver_interface/pc_examples/retnet_dp2_pc.yaml similarity index 100% rename from docs/autodist/solver_interface/pc_examples/retnet_dp2_pc.yaml rename to docs/source/autodist/solver_interface/pc_examples/retnet_dp2_pc.yaml diff --git a/docs/autodist/solver_interface/pc_examples/retnet_hybrid2_pc.yaml b/docs/source/autodist/solver_interface/pc_examples/retnet_hybrid2_pc.yaml similarity index 100% rename from docs/autodist/solver_interface/pc_examples/retnet_hybrid2_pc.yaml rename to docs/source/autodist/solver_interface/pc_examples/retnet_hybrid2_pc.yaml diff --git a/docs/autodist/solver_interface/pc_examples/retnet_mp2_pc.yaml b/docs/source/autodist/solver_interface/pc_examples/retnet_mp2_pc.yaml similarity index 100% rename from docs/autodist/solver_interface/pc_examples/retnet_mp2_pc.yaml rename to docs/source/autodist/solver_interface/pc_examples/retnet_mp2_pc.yaml diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 00000000..0a698442 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,33 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +from datetime import datetime + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'nnScaler' +copyright = f'{datetime.now().year}, Microsoft' +author = 'Microsoft' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + 'myst_parser', + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', +] + +# templates_path = ['_templates'] +exclude_patterns = [] + + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'furo' +# html_static_path = ['_static'] diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 00000000..7c5d2678 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,23 @@ +.. nnScaler documentation master file, created by + sphinx-quickstart on Fri Apr 19 15:38:29 2024. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to nnScaler's documentation! +==================================== + +.. toctree:: + :maxdepth: 1 + :caption: Contents: + + quickstart + parallel_module + register_custom_op + parallel + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/source/parallel.rst b/docs/source/parallel.rst new file mode 100644 index 00000000..086bf11f --- /dev/null +++ b/docs/source/parallel.rst @@ -0,0 +1,5 @@ +``nnscaler.parallel`` +===================== + +.. automodule:: nnscaler.parallel + :members: diff --git a/docs/parallel_module.md b/docs/source/parallel_module.md similarity index 100% rename from docs/parallel_module.md rename to docs/source/parallel_module.md diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst new file mode 100644 index 00000000..6a6f6181 --- /dev/null +++ b/docs/source/quickstart.rst @@ -0,0 +1,114 @@ +########### +Get Started +########### + +Repo address: https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube + +The nnScaler repo is currently internal. +If you do not have access, please contact cubedev@microsoft.com + +************ +Installation +************ + +To get started, install the latest wheel from +`DevOps Artifacts `_. + +If you are familiar with Azure stuffs, you can follow DevOps' guide to set up the repository. + +Or if you prefer the simpler way, download the ``.whl`` file in the "Files" section of the website, +and install it locally: + +:: + + python -m pip install nnscaler-*.whl + +********** +Quickstart +********** + +The next step depends on your choice of the training framework. + +- **No framework**: if you write your own training code and do not use a framework, + see :ref:`Parallelize API` section. +- **Fairseq**: if you use fairseq, see :ref:`Fairseq` section. +- **Lightning**: TODO + +.. _Parallelize API: + +Parallelize API +=============== + +TODO: write a hello world example, assigned to Zhe Liu + +If you write your own training code, you can use the *parallelize* API to make your model parallel: + +:: + + import torch + from nnscaler import parallelize, ComputeConfig, build_optimizer + + class LLM(torch.nn.Module): + def __init__(self, ...): + ... + def forward(self, x): + ... + + llm_sample_input = ... # dummpy input will be used to do tracing + pas_policy = ... # the PAS policy, you can use autodist pas + compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + ..., + ) # compute environment config + ParallelizedLLM = parallelize( + LLM, + {'x': llm_sample_input}, + pas_policy, + compute_config, + ) + +Example +------- + +An example of the parallelize API is provided in the repo: +`train.py `_ + +You can download and try it: :: + + torchrun --nproc_per_node=4 --nnodes=1 train.py + +Documentation +------------- + +If the example works for you, you can now follow the documentation to parallelize your model: +:doc:`parallel_module` + +.. _Fairseq: + +Fairseq +======= + +nnScaler provides `fairseq integration `_. + +TODO: refine the example (and its doc), assigned to Youshan Miao + +TODO (long term): write an example using unmodified fairseq + +Installation +------------ + +To use fairseq, clone the fork and install it: :: + + python -m pip uninstall fairseq + + git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Fairseq + cd Fairseq + python -m pip install -e . + +Example +------- + +Follow the example +`here `_. diff --git a/docs/register_custom_op.md b/docs/source/register_custom_op.md similarity index 100% rename from docs/register_custom_op.md rename to docs/source/register_custom_op.md diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 7ac09008..175ea4f9 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -370,11 +370,13 @@ class BroadcastGenFilesStrategy(Enum): The broadcast strategy for generated files. Only new generated files can be broadcasted. The files includes: + 1. config file: compute config (compute_config.pt) 2. trace files: graph dump (graph.ckp), forward args dump(forward_args.pkl), - origin module metadata (origin_module_metadata.pt), init weights file(fullmodel.pt.*), - param name mapping (dist_param_map.pt) + origin module metadata (origin_module_metadata.pt), init weights file(fullmodel.pt.*), + param name mapping (dist_param_map.pt) 3. code: generated code files (gencode*.py) + Reused files will not be broadcasted with any of the following options. """ @@ -859,8 +861,8 @@ def parallelize( the generated code in outdir will be removed EVEN IF the code generation fails in this call. if the input is a module object. - The module object will be copied to cpu to handle possible insufficient gpu memory. - The training flag will be the same as the original module + * The module object will be copied to cpu to handle possible insufficient gpu memory. + * The training flag will be the same as the original module This function can be used to convert both module object and module class to cube module or cube module class. Among key-value arguments, @@ -868,31 +870,36 @@ def parallelize( whereas init_module_params controls how to load cube module object after conversion is done. 1. If the input is a module object, it will return a CubeModule object if load_module is True. - This is useful when the module is created by a factory function. - a. module_fn is ignored. - b. module_dtype is used to control the dtype of the input module. - c. init_module_params is used to control whether to initialize the cube module parameters when load it. + This is useful when the module is created by a factory function. + + a. module_fn is ignored. + b. module_dtype is used to control the dtype of the input module. + c. init_module_params is used to control whether to initialize the cube module parameters when load it. 2. If the input is a module class, it will return a CubeModule class if load_module is True. - a. module_fn is used to create the module object, or module's__init__ if not prent. - b. module_dtype is used to control the dtype of the created module (by constructor or module_fn). - Of course, it can be merged into module_fn. - c. init_module_params is ignored. + + a. module_fn is used to create the module object, or module's__init__ if not prent. + b. module_dtype is used to control the dtype of the created module (by constructor or module_fn). + Of course, it can be merged into module_fn. + c. init_module_params is ignored. After the module is converted, you can use it to create module object by calling it like a module class. The module class is defined like: - ``` - class GenModule(nnscaler.runtime.module.ParallelModule): - def __init__(self, init_params=True): - super().__init__() + + :: + + class GenModule(nnscaler.runtime.module.ParallelModule): + def __init__(self, init_params=True): + super().__init__() + ... ... - ... - ``` + So you can use `init_params` in `__init__` to control whether to initialize the module parameters. For example, if you don't want to initialize module params: - ``` - module = GenModule(init_params=False) - ``` + + :: + + module = GenModule(init_params=False) Args: module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled @@ -1047,10 +1054,14 @@ class OptimizerExtraState: the key is the module prefix of the parallel module. A module prefix is the same prefix used when you call `module.state_dict()` without the ending dot. For example, if you have a module + + :: + module submodule1_1 submodule2_1 submodule1_2 + then the prefix of `module` itself is `` (empty str). the prefix of `submodule1_1` is `submodule1_1`. the prefix of `submodule2_1` is `submodule1_1.submodule2_1`. @@ -1155,17 +1166,18 @@ def build_optimizer( Build an optimizer for a module. To support parallelized module (CubeModule), we hook 4 places in this function: + 1. optimizer constructor: - the parameters of optimizer will not be the same with the parameters of the module if we use zero - so we need to replace the parameters of optimizer with CubeModule.parameters_for_optimizer - It is impossible to make this change transparent to end users. + the parameters of optimizer will not be the same with the parameters of the module if we use zero + so we need to replace the parameters of optimizer with CubeModule.parameters_for_optimizer + It is impossible to make this change transparent to end users. 2. optimizer.step(): - we need to call optimizer.sync_shard_grad() to sync the gradients of the module before optimizer.step(). - In zero mode, we have to call CubeModule.gather_params() after optimizer.step() + we need to call optimizer.sync_shard_grad() to sync the gradients of the module before optimizer.step(). + In zero mode, we have to call CubeModule.gather_params() after optimizer.step() 3. optimizer.zero_grad(): - We need to call CubeModule.zero_grad() after optimizer.zero_grad() + We need to call CubeModule.zero_grad() after optimizer.zero_grad() 4. backward(): - you need to call optimizer.sync_shard_grad() manually if you want to read the gradients of the module before optimizer.step(). + you need to call optimizer.sync_shard_grad() manually if you want to read the gradients of the module before optimizer.step(). Args: module (torch.nn.Module): the module to be optimized @@ -1509,12 +1521,16 @@ def merge_state_dicts( Note: Only Adam-like optimizers are supported for merging Please Note: - We don't garantee the devices of tensors are the same in the merged state dict. - You can assume the device of the tensors in the merged state dict can be one of the following: - 1. the current device when running this function - 2. the current cuda device when running this function - 3. the device of the tensor in the original state dict - When you load the state dict from file, you can just use `torch.load(..., map_location='...')` to unify the device of the tensors. + + We don't garantee the devices of tensors are the same in the merged state dict. + You can assume the device of the tensors in the merged state dict can be one of the following: + + 1. the current device when running this function + 2. the current cuda device when running this function + 3. the device of the tensor in the original state dict + + When you load the state dict from file, you can just use `torch.load(..., map_location='...')` to unify the device of the tensors. + Args: model_state_dicts (List[Dict[str, Any]]): the model state dicts from each rank optimizer_state_dicts (Optional[List[Dict[str, Any]]]): the optimizer state dicts from each rank @@ -1652,6 +1668,7 @@ def load_merged_state_dicts( ): """ Load the merged state dicts to the module, and optionally the optimizer to a specified device. + Args: module (torch.nn.Module): the module to be loaded module_state_dict (Dict[str, Any]): the merged model state dict @@ -1659,6 +1676,7 @@ def load_merged_state_dicts( optimizer_state_dict (Optional[Dict[str, Any]]): the merged optimizer state dict device (Union[str, torch.device]): the device to put the module and optimizer state dicts. Use torch.cuda.current_device() if it is None. + Returns: None """ @@ -2011,7 +2029,7 @@ def deduped_state_dict( """ Return the state dict only for the ranks that is necessary. For details, see `ComputeConfig.optimizer_dedup_group_size` - and `ComputeConfig.module_dedup_group_size`. + and `ComputeConfig.module_dedup_group_size`. Args: module (torch.nn.Module): the module to get state dict diff --git a/requirements-dev.txt b/requirements-dev.txt index 826839f5..3f203d24 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,13 @@ coverage +furo mock +myst-parser pre-commit pytest pytest-cov pytest-mock +sphinx +sphinxcontrib-napoleon tabulate tox tox-conda From 77e56cdd2315f9b00fa9aa0079d118e01a58e001 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 16 May 2024 03:37:38 +0000 Subject: [PATCH 1639/1892] Merged PR 2141: Use multiprocessing in profiler use multiprocess to speed up profiling in autodist. parity check passed --- nnscaler/autodist/cost_database.py | 109 +++++-- nnscaler/autodist/spmd_solver.py | 4 +- nnscaler/graph/graph.py | 20 +- nnscaler/profiler/database.py | 461 +++++++++++++++------------- tests/graph/parser/test_register.py | 8 +- tests/profiler/test_op_profile.py | 6 +- 6 files changed, 344 insertions(+), 264 deletions(-) diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 2408b469..c1a99b52 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -5,6 +5,8 @@ from os import listdir from pathlib import Path import logging +import multiprocessing +import torch from nnscaler.graph import IRGraph from nnscaler.ir.cten import IRTensor @@ -19,6 +21,9 @@ _logger = logging.getLogger(__name__) +import nnscaler +_DEFAULT_COMM_DATA_PATH = Path(nnscaler.__file__).parent.parent / 'data/profile/mi200/comm' + def _piecewise_estimator(xs: List[float], ys: List[float], x: float) -> float: """ @@ -51,8 +56,53 @@ def _piecewise_estimator(xs: List[float], ys: List[float], x: float) -> float: xs[i]) raise RuntimeError(f'x={x}, xs={xs}, ys={ys}, should not reach here') -import nnscaler -_DEFAULT_COMM_DATA_PATH = Path(nnscaler.__file__).parent.parent / 'data/profile/mi200/comm' + +def _filter_and_group_nodes(graph: IRGraph, db: ProfileDataBase) -> List[List[IRFwOperation]]: + visited_nodes = set() + node_to_profile = list() + for node in graph.select(ntype=IRFwOperation): + if isinstance(node, (IRGraphAnchor, IRPyFunc)): + continue + hash_code = node.signature + ' : ' + db._serialize(node) + if hash_code in visited_nodes: + continue + node_to_profile.append(node) + visited_nodes.add(hash_code) + + dev_num = torch.cuda.device_count() + + # divide `node_to_profile` into `dev_num` groups + node_groups = [[] for _ in range(dev_num)] + for i, node in enumerate(node_to_profile): + node_groups[i % dev_num].append(node) + return node_groups + + +def _profile_nodes(dilled_info: str, dev_id: int, partition_degree: int, re_profile: bool, comp_profile_path: str, result: multiprocessing.Queue): + import dill + torch.cuda.set_device(dev_id) + + id_state, dilled_graph = dill.loads(dilled_info) + graph = IRGraph.from_dill(id_state, dilled_graph) + db = ProfileDataBase() + db.load_ops(comp_profile_path) + nodes = _filter_and_group_nodes(graph, db)[dev_id] + + ret = list() + for node in nodes: + if isinstance(node, IRDimops): + partition_nodes = gen_partitions(node, + partition_degree, + base=partition_degree, + depth=1) + else: + partition_nodes = [node] + for partition_node in partition_nodes: + profiled_metrics: ProfiledMetrics = db.profile(partition_node, override=re_profile) + ret.append((partition_node.signature, db._serialize(partition_node), profiled_metrics)) + _logger.info(f'device {dev_id} finished profiling {len(nodes)} nodes') + result.put(ret) + class CostDatabase: @@ -81,35 +131,32 @@ def __init__(self, graph: IRGraph, config: AutoDistConfig): self.ignore_small_tensor_threshold = self.autodist_config.ignore_small_tensor_threshold def profile_comp(self, partition_degree: int): - visited_nodes = set() - for node in self.graph.select(ntype=IRFwOperation): - if isinstance(node, (IRGraphAnchor, IRPyFunc)): - continue - hash_code = node.signature + ' : ' + self.db._serialize(node) - if hash_code in visited_nodes: - continue - if hasattr(node, 'anno'): - partition_nodes = gen_partitions(node, - partition_degree, - base=partition_degree, - depth=1) - else: - _logger.info(f'only profile replicated for {node}') - partition_nodes = [node] - for partition_node in partition_nodes: - # the returned schema may change over time, we re-profile - # if encountered an exception - try: - profiled_metrics: ProfiledMetrics = self.db.profile( - partition_node, - override=self.autodist_config.re_profile) - except Exception: - profiled_metrics: ProfiledMetrics = self.db.profile( - partition_node, override=True) - self.db.dump_op(self.comp_profile_path, - node.signature, - override=True) - visited_nodes.add(hash_code) + + # use spawn to make sure the profiling process is independent from each other + # and the main process, this is also required by torch + mp_context = multiprocessing.get_context('spawn') + + results = mp_context.Queue() + processes = [] + for i in range(torch.cuda.device_count()): + p = mp_context.Process(target=_profile_nodes, + args=(self.graph.dumps(), i, partition_degree, self.autodist_config.re_profile, self.comp_profile_path, results)) + processes.append(p) + p.start() + + # put queue.get() before join to avoid deadlock + for p in processes: + ret = results.get() + for sign, serialized, profiled_metrics in ret: + _logger.debug(f'profiled {sign} in {serialized} with {profiled_metrics}') + if not self.db.exist_serialized(sign, serialized): + self.db.insert(sign, serialized, profiled_metrics) + results.close() + + for p in processes: + p.join() + + self.db.dump_ops(self.comp_profile_path, override=True) def exist(self, node: IRFwOperation) -> bool: return self.db.exist(node) diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 8f0e5cb9..37e60131 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -653,7 +653,7 @@ def calc_partition_info(self): self.partition_info: List[List[PartitionCostDesc]] = list() for i in range(self.graph.op_num): cur_info = [] - _logger.info(f'calc partition info for {self.get_operator(i)}') + _logger.debug(f'calc partition info for {self.get_operator(i)}') for j in range(self.get_op_partition_count(i)): cost_desc = self.calc_partition_cost(i, j) if cost_desc.comp_time == float('inf'): @@ -662,7 +662,7 @@ def calc_partition_info(self): ) cost_desc.comp_time = 0.0 cur_info.append(cost_desc) - _logger.info(f'{self._op_partitions[i][j]} {cost_desc}') + _logger.debug(f'{self._op_partitions[i][j]} {cost_desc}') self.partition_info.append(cur_info) _logger.info('finish spmd solver initializetion') diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 23c4a587..59958d5d 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -1042,11 +1042,9 @@ def recompute(self, nodes: Union[IRSegment, List[IRFwOperation]]) -> bool: # =================== Helpers ==================== - def dump(self, filename: str) -> None: + def dumps(self) -> str: """ - Dump the graph into pickled format - - @param filename str + Dump the graph into binary by dill """ # FIXME: dump doesn't support customized op class PicklingContextSave: @@ -1056,9 +1054,17 @@ def __exit__(self, exc_type, exc_value, traceback): IRObject.__getstate__ = lambda self: self.__dict__.copy() with PicklingContextSave(): - with open(filename, 'wb') as f: - save = (IDGenerator().get_states(), self) - dill.dump(save, f) + save = (IDGenerator().get_states(), self) + return dill.dumps(save) + + def dump(self, filename: str) -> None: + """ + Dump the graph into pickled format + + @param filename str + """ + with open(filename, 'wb') as f: + f.write(self.dumps()) @staticmethod def from_dill(id_state, graph): diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 7f5117f6..bfff1905 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -11,9 +11,10 @@ import math import logging from dataclasses import dataclass, asdict +from pathlib import Path import _operator # required by eval() -import nnscaler # required by eval() +import nnscaler # required by eval() from nnscaler.graph.function.dimops import IRDimops from nnscaler.ir.cten import IRTensor, IRObject from nnscaler.ir.operator import IRFwOperation @@ -30,6 +31,9 @@ _train_module_ref: torch.nn.Module = torch.nn.Module().train() _eval_module_ref: torch.nn.Module = torch.nn.Module().eval() +# when profiling fails, we use the long default value as a penalty +_FAIL_FW_SPAN = 1000 * 1000 # 1000 seconds + @dataclass class ProfiledMetrics: @@ -54,158 +58,209 @@ class ProfiledMetrics: # the index of the tensor saved for backward in `node.inputs()` list train_mem2in_idx: Tuple[int] + def __repr__(self) -> str: + contents = dict() + for key, value in self.__dict__.items(): + if key in ('in_mem_info', 'param_mem_info', 'buffer_mem_info', 'train_mem_info'): + contents[key] = [f'{v / 1024 / 1024:.2f} MB' for v in value] + elif key == 'infer_memory': + contents[key] = f'{value / 1024 / 1024:.2f} MB' + elif key in ('fw_span', 'bw_span'): + contents[key] = f'{value:.2f} ms' + else: + contents[key] = value + return str(contents) -class CompProfiler: - @staticmethod - def profile(node: IRFwOperation, func: Callable, shapes: Shapes, dtypes: DTypes, - requires_grads: Tuple[bool], values: Tuple[Any], - warmup_sec: float = 2, prof_times: int = 20, max_prof_sec: float = 20, - **kwargs) -> Tuple[float, float, int, Tuple[int]]: - """ - Profile a function +def get_func(node: IRFwOperation) -> Tuple[Callable, Shapes, DTypes, Dict]: + """ + Get function call and its arguments from a cude IRGraph node + """ + assert isinstance(node, IRFwOperation), f"Only support profiling forward operation but got {type(node)}" + + if node.signature in CustomizedOps.kOpRuntime: + fn = CustomizedOps.kOpRuntime[node.signature] + else: + fn = eval(node.signature) + shapes, dtypes, requires_grads, values = [], [], [], [] + + # TODO: this function should rewrite with pytree + def extract_val(val: Union[IRObject, Any]) -> Any: + if isinstance(val, IRObject): + return extract_val(val.value) + elif isinstance(val, tuple): + return tuple([extract_val(v) for v in val]) + elif isinstance(val, list): + return list([extract_val(v) for v in val]) + elif isinstance(val, dict): + return {k: extract_val(v) for k, v in val.items()} + elif isinstance(val, slice): + return slice(extract_val(val.start), extract_val(val.stop), extract_val(val.step)) + else: + return val + + for t in node.inputs(): + if isinstance(t, IRTensor): + shapes.append(t.shape) + dtypes.append(t.dtype) + requires_grads.append(t.requires_grad) + values.append(t) + else: + shapes.append(None) + dtypes.append(None) + requires_grads.append(None) + values.append(extract_val(t)) + return fn, shapes, dtypes, requires_grads, values, extract_val(node.kwargs) - Args: - node IRFwOperation: the node in IRGraph - func Callable: the callable function, e.g., torch.nn.functional.linear - shapes Tuple[Tuple[int]]: the shapes of each input tensor - dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 - requires_grads Tuple[bool]: whether the input tensor requires gradient - values Tuple[Any]: the values of the inputs that are not IRTensor - warmup_sec float: warmup seconds - prof_times int: number of execution for profiling an operator - max_prof_sec float: max seconds for profiling an operator's forward or backward - kwargs Dict: other keyword argument for func call. - Returns: - fw_span float: the time in milliseconds for forward time - bw_span float: the time in milliseconds for backward time - infer_mem int: the peak memory in bytes after inference of the function - train_mem_info Tuple[int]: byte sizes of activation tensors saved for backward - """ - assert len(shapes) == len(dtypes), \ - f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" - # create data - assert dtypes is not None - def gen_torch_tensors(shape, dtype, requires_grad): - """Generate dummy input tenosrs""" - constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand - return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) - - tensors = tuple( - gen_torch_tensors(shape, dtype, requires_grad) if isinstance(value, IRTensor) else value \ - for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) - ) - require_backward = any([t.requires_grad for t in tensors if hasattr(t, 'requires_grad')]) - # FIXME: reconsidering requires_grad - if func.__name__ in ('type_as'): - require_backward = False - # repalce kwargs starting with 'self.xxx' - train_kwargs, eval_kwargs = {}, {} - for name, value in kwargs.items(): - if isinstance(value, str) and value.startswith('self.'): - train_val = getattr(_train_module_ref, value[5:]) - eval_val = getattr(_eval_module_ref, value[5:]) - else: - train_val = eval_val = value - train_kwargs[name] = train_val - eval_kwargs[name] = eval_val - - # run one sample - outputs = func(*tensors, **train_kwargs) - # only profile IRDimops currently, which has at least one tensor output and - # may have non-tensor outputs (like list, tuple, dict, etc.). In addition, - # we assume that non-tensor outputs will not be used in backward. +def profile(node: IRFwOperation, func: Callable, shapes: Shapes, dtypes: DTypes, + requires_grads: Tuple[bool], values: Tuple[Any], + warmup_sec: float = 2, prof_times: int = 20, max_prof_sec: float = 20, + **kwargs) -> Tuple[float, float, int, Tuple[int]]: + """ + Profile a function + + Args: + node IRFwOperation: the node in IRGraph + func Callable: the callable function, e.g., torch.nn.functional.linear + shapes Tuple[Tuple[int]]: the shapes of each input tensor + dtypes Optional[Tuple[torch.dtype]]: the dtype of each input tensor. Default will use torch.float32 + requires_grads Tuple[bool]: whether the input tensor requires gradient + values Tuple[Any]: the values of the inputs that are not IRTensor + warmup_sec float: warmup seconds + prof_times int: number of execution for profiling an operator + max_prof_sec float: max seconds for profiling an operator's forward or backward + kwargs Dict: other keyword argument for func call. + + Returns: + fw_span float: the time in milliseconds for forward time + bw_span float: the time in milliseconds for backward time + infer_mem int: the peak memory in bytes after inference of the function + train_mem_info Tuple[int]: byte sizes of activation tensors saved for backward + """ + assert len(shapes) == len(dtypes), \ + f"func {func.__name__}: expected each shape has a corresponding dtype, but got {shapes} and {dtypes}" + # create data + assert dtypes is not None + def gen_torch_tensors(shape, dtype, requires_grad): + """Generate dummy input tenosrs""" + constructor = torch.zeros if dtype in (torch.int64, torch.int32, torch.bool) else torch.rand + return constructor(tuple(shape), dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) + + tensors = tuple( + gen_torch_tensors(shape, dtype, requires_grad) if isinstance(value, IRTensor) else value \ + for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) + ) + require_backward = any([t.requires_grad for t in tensors if hasattr(t, 'requires_grad')]) + # FIXME: reconsidering requires_grad + if func.__name__ in ('type_as'): + require_backward = False + # repalce kwargs starting with 'self.xxx' + train_kwargs, eval_kwargs = {}, {} + for name, value in kwargs.items(): + if isinstance(value, str) and value.startswith('self.'): + train_val = getattr(_train_module_ref, value[5:]) + eval_val = getattr(_eval_module_ref, value[5:]) + else: + train_val = eval_val = value + train_kwargs[name] = train_val + eval_kwargs[name] = eval_val + + # run one sample + outputs = func(*tensors, **train_kwargs) + # only profile IRDimops currently, which has at least one tensor output and + # may have non-tensor outputs (like list, tuple, dict, etc.). In addition, + # we assume that non-tensor outputs will not be used in backward. + outputs = (outputs,) if torch.is_tensor(outputs) else outputs + outputs = tuple(filter(lambda x: torch.is_tensor(x) and x.requires_grad, outputs)) + assert all(torch.is_tensor(otensor) for otensor in outputs), \ + f"{func.__name__}: require all the outputs to be tensors" + grads = tuple(torch.zeros_like(otensor) for otensor in outputs) + + def run_step(func, tensors, kwargs, backward: bool): + outputs = func(*tensors, **kwargs) outputs = (outputs,) if torch.is_tensor(outputs) else outputs outputs = tuple(filter(lambda x: torch.is_tensor(x) and x.requires_grad, outputs)) - assert all(torch.is_tensor(otensor) for otensor in outputs), \ - f"{func.__name__}: require all the outputs to be tensors" - grads = tuple(torch.zeros_like(otensor) for otensor in outputs) - - def run_step(func, tensors, kwargs, backward: bool): - outputs = func(*tensors, **kwargs) - outputs = (outputs,) if torch.is_tensor(outputs) else outputs - outputs = tuple(filter(lambda x: torch.is_tensor(x) and x.requires_grad, outputs)) - if backward: - torch.autograd.backward(outputs, grads) - return outputs - - # profile inference peak memory + if backward: + torch.autograd.backward(outputs, grads) + return outputs + + # profile inference peak memory + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + mtic = torch.cuda.max_memory_allocated() # in bytes + with torch.no_grad(): + run_step(func, tensors, eval_kwargs, backward=False) + mtoc = torch.cuda.max_memory_allocated() # in bytes + infer_memory = mtoc - mtic + + train_mem_info = [] + train_mem2in_idx = [] + used_tensor = set() + # ref torch/utils/checkpoint.py/_checkpoint_without_reentrant + def pack_hook(x): + nonlocal train_mem_info, used_tensor + if x.untyped_storage().data_ptr() not in used_tensor: + used_tensor.add(x.untyped_storage().data_ptr()) + byte_size = x.element_size() + for dim in list(x.size()): + byte_size = byte_size * dim + idx = -1 + is_attr = False + for i, t in enumerate(tensors): + if not isinstance(t, torch.Tensor): + continue + if t.untyped_storage().data_ptr() == x.untyped_storage().data_ptr(): + if node.inputs()[i].is_attr(): + is_attr = True + idx = i + break + if not is_attr: + train_mem_info.append(byte_size) + train_mem2in_idx.append(idx) + return x + + def unpack_hook(x): + return x + + torch.cuda.synchronize() + torch.cuda.empty_cache() + with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + outs = run_step(func, tensors, train_kwargs, backward=require_backward) + + # warmup + warmup_cnt = 0 + tic = time.perf_counter() + while time.perf_counter() - tic < warmup_sec: + run_step(func, tensors, train_kwargs, backward=require_backward) torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - mtic = torch.cuda.max_memory_allocated() # in bytes + warmup_cnt += 1 + toc = time.perf_counter() + func_duration = (toc - tic) / warmup_cnt + real_prof_times = max(1, min(prof_times, math.ceil(max_prof_sec / func_duration))) + + # profile forward only + torch.cuda.synchronize() + tic = time.perf_counter() + for _ in range(real_prof_times): with torch.no_grad(): run_step(func, tensors, eval_kwargs, backward=False) - mtoc = torch.cuda.max_memory_allocated() # in bytes - infer_memory = mtoc - mtic - - train_mem_info = [] - train_mem2in_idx = [] - used_tensor = set() - # ref torch/utils/checkpoint.py/_checkpoint_without_reentrant - def pack_hook(x): - nonlocal train_mem_info, used_tensor - if x.untyped_storage().data_ptr() not in used_tensor: - used_tensor.add(x.untyped_storage().data_ptr()) - byte_size = x.element_size() - for dim in list(x.size()): - byte_size = byte_size * dim - idx = -1 - is_attr = False - for i, t in enumerate(tensors): - if not isinstance(t, torch.Tensor): - continue - if t.untyped_storage().data_ptr() == x.untyped_storage().data_ptr(): - if node.inputs()[i].is_attr(): - is_attr = True - idx = i - break - if not is_attr: - train_mem_info.append(byte_size) - train_mem2in_idx.append(idx) - return x - - def unpack_hook(x): - return x + torch.cuda.synchronize() + toc = time.perf_counter() + fw_span = (toc - tic) / real_prof_times * 1000 # in milliseconds - torch.cuda.synchronize() - torch.cuda.empty_cache() - with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): - outs = run_step(func, tensors, train_kwargs, backward=require_backward) - - # warmup - warmup_cnt = 0 - tic = time.perf_counter() - while time.perf_counter() - tic < warmup_sec: - run_step(func, tensors, train_kwargs, backward=require_backward) - torch.cuda.synchronize() - warmup_cnt += 1 - toc = time.perf_counter() - func_duration = (toc - tic) / warmup_cnt - real_prof_times = max(1, min(prof_times, math.ceil(max_prof_sec / func_duration))) - - # profile forward only - torch.cuda.synchronize() - tic = time.perf_counter() - for _ in range(real_prof_times): - with torch.no_grad(): - run_step(func, tensors, eval_kwargs, backward=False) - torch.cuda.synchronize() - toc = time.perf_counter() - fw_span = (toc - tic) / real_prof_times * 1000 # in milliseconds + # profile forward + backward + torch.cuda.synchronize() + tic = time.perf_counter() + for _ in range(real_prof_times): + run_step(func, tensors, train_kwargs, backward=require_backward) + torch.cuda.synchronize() + toc = time.perf_counter() + fwbw_span = (toc - tic) / real_prof_times * 1000 # in milliseconds + bw_span = max(fwbw_span - fw_span, 0.0) - # profile forward + backward - torch.cuda.synchronize() - tic = time.perf_counter() - for _ in range(real_prof_times): - run_step(func, tensors, train_kwargs, backward=require_backward) - torch.cuda.synchronize() - toc = time.perf_counter() - fwbw_span = (toc - tic) / real_prof_times * 1000 # in milliseconds - bw_span = max(fwbw_span - fw_span, 0.0) - - return fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx + return fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx class ProfileDataBase: @@ -219,68 +274,22 @@ def __init__(self, filename: Optional[str] = None) -> None: if filename is not None: self.load(filename) - @staticmethod - def get_func(node: IRFwOperation) -> Tuple[Callable, Shapes, DTypes, Dict]: - """ - Get function call and its arguments from a cude IRGraph node + def profile(self, node: IRFwOperation, override: bool = False) -> ProfiledMetrics: """ - assert isinstance(node, IRFwOperation), f"Only support profiling forward operation but got {type(node)}" - - if node.signature in CustomizedOps.kOpRuntime: - fn = CustomizedOps.kOpRuntime[node.signature] - else: - fn = eval(node.signature) - shapes, dtypes, requires_grads, values = [], [], [], [] - - # TODO: this function should rewrite with pytree - def extract_val(val: Union[IRObject, Any]) -> Any: - if isinstance(val, IRObject): - return extract_val(val.value) - elif isinstance(val, tuple): - return tuple([extract_val(v) for v in val]) - elif isinstance(val, list): - return list([extract_val(v) for v in val]) - elif isinstance(val, dict): - return {k: extract_val(v) for k, v in val.items()} - elif isinstance(val, slice): - return slice(extract_val(val.start), extract_val(val.stop), extract_val(val.step)) - else: - return val - - for t in node.inputs(): - if isinstance(t, IRTensor): - shapes.append(t.shape) - dtypes.append(t.dtype) - requires_grads.append(t.requires_grad) - values.append(t) - else: - shapes.append(None) - dtypes.append(None) - requires_grads.append(None) - values.append(extract_val(t)) - return fn, shapes, dtypes, requires_grads, values, extract_val(node.kwargs) - - def profile(self, node: IRFwOperation, device: Optional[int] = None, override: bool = False) -> ProfiledMetrics: - """ - Profile a forward node in IRGraph on a specific device (default current device) + Profile a forward node in IRGraph Args: node IRFwOperation: node of IRGraph - device int: the device that the node will execute on override bool: True if the existed can be overrided else False Returns: profiled_metrics ProfiledMetrics: the profiling data """ - fn, shapes, dtypes, requires_grads, values, kwargs = ProfileDataBase.get_func(node) - if not override and self.exist(node): return self.query(node) - if isinstance(device, int): - orig_device = torch.cuda.current_device() - torch.cuda.set_device(device) - + fn, shapes, dtypes, requires_grads, values, kwargs = get_func(node) + in_mem_info, param_mem_info, buffer_mem_info = [], [], [] for t in node.inputs(): if isinstance(t, IRTensor) and t.is_param(): @@ -290,15 +299,15 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b elif hasattr(t, 'byte_size'): in_mem_info.append(t.byte_size()) else: - _logger.warning(f'node {node}: skip input {t}') + _logger.debug(f'node {node}: skip input {t}') # run profiling try: fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx = \ - CompProfiler.profile(node, fn, shapes, dtypes, requires_grads, values, **kwargs) + profile(node, fn, shapes, dtypes, requires_grads, values, **kwargs) except Exception: _logger.exception(f'fail to profile {node}, use default values') - fw_span, bw_span = 0, 0 + fw_span, bw_span = _FAIL_FW_SPAN, 2 * _FAIL_FW_SPAN infer_memory = 0 for t in node.outputs(): if isinstance(t, IRTensor): @@ -306,20 +315,9 @@ def profile(self, node: IRFwOperation, device: Optional[int] = None, override: b # by default, we assume that all the input tensors are saved for backward train_mem_info = copy.deepcopy(in_mem_info) train_mem2in_idx = list(range(len(in_mem_info))) - # log to database - key = self._serialize(node) profiled_metrics = ProfiledMetrics(in_mem_info, param_mem_info, buffer_mem_info, fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx) - self.insert(node.signature, key, profiled_metrics) - _logger.info( - f"profiled {node.signature} | shapes: {shapes} | dtypes: {dtypes} | requires_grads: {requires_grads} | " - f"=> in mem info: {in_mem_info} | param mem info: {param_mem_info} | " - f"buffer mem info: {buffer_mem_info} | fw: {round(fw_span, 2)} ms | bw: {round(bw_span, 2)} ms | " - f"infer mem: {infer_memory} | train mem info: {train_mem_info} | idx: {train_mem2in_idx}") - - if isinstance(device, int): - torch.cuda.set_device(orig_device) return profiled_metrics def insert(self, name: str, key: str, profiled_metrics: ProfiledMetrics): @@ -345,9 +343,22 @@ def exist(self, node: IRFwOperation) -> bool: @return exist bool: True if the performance is recorded, else False """ key = self._serialize(node) - if node.signature not in self._data: + return self.exist_serialized(node.signature, key) + + def exist_serialized(self, signature: str, key: str) -> bool: + """ + Check if the node has the performance recorded in the database + + Args: + signature str: the signature of the function + key str: the serialized key + + Returns: + exist bool: True if the performance is recorded, else False + """ + if signature not in self._data: return False - if key not in self._data[node.signature]: + if key not in self._data[signature]: return False return True @@ -433,12 +444,22 @@ def dump(self, file: str, override=False): with open(file, 'w') as f: json.dump(self._data, f) - def dump_op(self, file: str, signature, override=False): - assert signature in self._data.keys(), f'{signature} has not been profiled' - file_n = os.path.join(file, signature +'.json') - with open(file_n, 'w') as f: - to_dump = {key: asdict(value) for key, value in self._data[signature].items()} - json.dump(to_dump, f, indent=2) + def dump_ops(self, folder: str, override=False): + """ + dump the profiled data into json format, each operator is saved in a separate file with the signature as the file name + + Args: + folder str: the folder name + override bool: True if the existed can be overrided else False + """ + folder = Path(folder) + for signature in self._data: + fname = folder / (signature + '.json') + if fname.exists(): + assert override, f"File {fname} exists. Set override = True to force dump." + with open(fname, 'w') as f: + to_dump = {key: asdict(value) for key, value in self._data[signature].items()} + json.dump(to_dump, f, indent=2) def load(self, file: str): """! @@ -450,10 +471,16 @@ def load(self, file: str): with open(file, 'r') as f: self._data = json.load(f) - def load_ops(self, file: str): - for filename in os.listdir(file): + def load_ops(self, folder: str): + """ + load the profiled data from json files in a folder. Each operator is saved in a separate file with the signature as the file name + + Args: + folder str: the folder name + """ + for filename in os.listdir(folder): if filename.endswith('.json'): - with open(os.path.join(file, filename)) as f: + with open(os.path.join(folder, filename)) as f: signature = filename[:-len('.json')] loaded_json = json.load(f) self._data[signature] = {key: ProfiledMetrics(**value) for key, value in loaded_json.items()} diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index f439082d..d5e6ed4d 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -1,6 +1,6 @@ import nnscaler from nnscaler.graph.parser.converter import convert_model -from nnscaler.profiler.database import ProfileDataBase +from nnscaler.profiler.database import get_func import tempfile import torch @@ -83,7 +83,7 @@ def test_common_register(): # test profiler.database for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'mock_add']): - profile_name = ProfileDataBase.get_func(node)[0].__qualname__ + profile_name = get_func(node)[0].__qualname__ assert profile_name == p_name, f'{profile_name} should be {p_name}' @@ -95,7 +95,7 @@ def test_common_register2(): # test profiler.database for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'mock_add2']): - profile_name = ProfileDataBase.get_func(node)[0].__qualname__ + profile_name = get_func(node)[0].__qualname__ assert profile_name == p_name, f'{profile_name} should be {p_name}' @@ -107,7 +107,7 @@ def test_autograd_register(): # test profiler.database for node, p_name in zip(ir_graph.nodes(), ['linear', 'linear', 'Function.apply']): - profile_name = ProfileDataBase.get_func(node)[0].__qualname__ + profile_name = get_func(node)[0].__qualname__ assert profile_name == p_name, f'{profile_name} should be {p_name}' diff --git a/tests/profiler/test_op_profile.py b/tests/profiler/test_op_profile.py index 88105464..9af6ac11 100644 --- a/tests/profiler/test_op_profile.py +++ b/tests/profiler/test_op_profile.py @@ -7,7 +7,7 @@ from nnscaler.parallel import _gen_graph from nnscaler.ir.tensor import IRTensor from nnscaler.ir.operator import IRFwOperation -from nnscaler.profiler.database import CompProfiler, ProfileDataBase +from nnscaler.profiler.database import get_func, profile, ProfileDataBase class NaiveFFN(torch.nn.Module): @@ -29,9 +29,9 @@ def test_op_profile_times(): with tempfile.TemporaryDirectory() as tempdir: graph, _ = _gen_graph(NaiveFFN(), {'x': torch.randn(2, 128, 1024)}, tempdir, False) fc1, relu, fc2 = graph.select(ntype=IRFwOperation) - fn, shapes, dtypes, requires_grads, values, kwargs = ProfileDataBase.get_func(fc1) + fn, shapes, dtypes, requires_grads, values, kwargs = get_func(fc1) tic = time.perf_counter() - CompProfiler.profile(fc1, fn, shapes, dtypes, requires_grads, values, **kwargs) + profile(fc1, fn, shapes, dtypes, requires_grads, values, **kwargs) toc = time.perf_counter() # this is always true because the op is very small. assert toc - tic < 20, f'op profile time is too long {toc - tic}' From 9ca92d30d5342e8ee861a60aa3612e335bb4d511 Mon Sep 17 00:00:00 2001 From: Zhiqi Lin Date: Thu, 16 May 2024 08:58:31 +0000 Subject: [PATCH 1640/1892] Merged PR 2136: dependency track for kwargs This PR supports the dependency track for kwargs in IRFwOperation. consider an operator is defined with kwargs including tensors/IRObjects: `CusOp(t, xx=s)`, where `s` is a tensor or dynamic IRObjects. The model has definition likes: ```python class Model(nn.Module): def __init__(self): xxx def forward(self, t, s): return CusOp(t, xx=s) ``` in this case, the graph IR will have kwargs of `xx:s` in the operator. Since the graph doesn't track the `s` tensor previously, we will not generate communication once `s` is partitioned or placed on different devices. This PR supports tracking IRObjects in kwargs to make the graph complete and generate complete tensors. --- nnscaler/graph/gener/gen.py | 29 ++--- nnscaler/graph/gener/utils.py | 9 +- nnscaler/graph/graph.py | 33 ++--- nnscaler/graph/segment.py | 121 +++++------------- nnscaler/ir/cten.py | 146 +++++++++++++++++++++- nnscaler/ir/operator.py | 11 +- tests/graph/gener/test_producer_fusion.py | 64 ++++++++++ tests/graph/test_graph.py | 24 ++++ 8 files changed, 305 insertions(+), 132 deletions(-) create mode 100644 tests/graph/gener/test_producer_fusion.py diff --git a/nnscaler/graph/gener/gen.py b/nnscaler/graph/gener/gen.py index 070c83d3..0ee20e83 100644 --- a/nnscaler/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -518,11 +518,9 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens itensor = new_ftensor.select(ctensor.indmap, ctensor.valmap) igrad = new_ftensor.grad.select(ctensor.grad.indmap, ctensor.grad.valmap) with graph.update(consumer) as consumer: - idx = consumer.inputs().index(ctensor) - consumer.set_input(idx, itensor) + consumer.replace_input(ctensor, itensor) with graph.mirror.update(consumer.mirror) as bconsumer: - idx = bconsumer.outputs().index(ctensor.grad) - bconsumer.set_output(idx, igrad) + bconsumer.replace_output(ctensor.grad, igrad) for devid in devtensors: indmaps = [t.indmap for t in devtensors[devid]] @@ -621,14 +619,13 @@ def local_producer_fusion(graph: IRSegment, ftensor: IRFullTensor) -> IRFullTens if node is None: for ptensor, producer in zip(devtensors[devid], devops[devid]): otensor = new_ftensor.select(ptensor.indmap, ptensor.valmap) - ograd = new_ftensor.grad.select(otensor.grad.indmap, otensor.grad.valmap) + ograd = new_ftensor.grad.select(ptensor.grad.indmap, ptensor.grad.valmap) with graph.update(producer): - idx = producer.outputs().index(ptensor) - producer.set_input(idx, otensor) - producer.input(idx).grad = ograd + producer.replace_output(ptensor, otensor) + for t in producer.find(otensor): + t.grad = ograd with graph.mirror.update(producer.mirror) as bproducer: - idx = bproducer.inputs().index(otensor.grad) - bproducer.set_input(idx, ograd) + bproducer.replace_input(ptensor.grad, ograd) else: node.device = devid node.recompute = rcid @@ -712,14 +709,12 @@ def local_consumer_multiref(graph: IRSegment, ftensor: IRFullTensor): # set corresponding consumer input and its backward consumer = devops[devid][idx] with graph.update(consumer): - while ctensor in consumer.inputs(): - fidx = consumer.inputs().index(ctensor) - consumer.set_input(fidx, otensor) - consumer.input(fidx).grad = new_ftensor.grad.select(ctensor.indmap, (0,1)) + consumer.replace_input(ctensor, otensor) + for t in consumer.find(otensor): + t.grad = new_ftensor.grad.select(ctensor.indmap, (0,1)) with graph.mirror.update(consumer.mirror) as bconsumer: - while ctensor.grad in bconsumer.outputs(): - bidx = bconsumer.outputs().index(ctensor.grad) - bconsumer.set_output(bidx, new_ftensor.grad.select(ctensor.indmap, (0,1))) + bconsumer.replace_output( + ctensor.grad, new_ftensor.grad.select(ctensor.indmap, (0,1))) # insert multiref multiref.device = devid min_fidx = min(graph.index(consumer) for consumer in devops[devid]) diff --git a/nnscaler/graph/gener/utils.py b/nnscaler/graph/gener/utils.py index 9fd36a96..068583fe 100644 --- a/nnscaler/graph/gener/utils.py +++ b/nnscaler/graph/gener/utils.py @@ -106,11 +106,12 @@ def flatten_grad(graph: IRSegment, ftensor: IRFullTensor): valmap = curr_valmap.map((0, 2)) if cidx != len(consumers) - 1 else curr_valmap grad = ftensor.grad.select(ctensor.indmap, valmap) # update consumer and its mirror node - fidx = consumer.inputs().index(ctensor) assert consumer.mirror is not None, consumer - bidx = consumer.mirror.outputs().index(consumer.input(fidx).grad) - consumer.input(fidx).grad = grad + with graph.update(consumer) as fnode: + for t in fnode.find(ctensor): + old_grad = t.grad + t.grad = grad with graph.mirror.update(consumer.mirror) as bnode: - bnode.set_output(bidx, grad) + bnode.replace_output(old_grad, grad) # update current valmap curr_valmap = curr_valmap.map((1, 2)) if cidx != len(consumers) - 1 else curr_valmap diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 59958d5d..79f57f07 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -87,25 +87,20 @@ def forward(self, *args: Tuple[IRObject]) -> Union[IRTensor, Tuple[IRTensor]]: else: raise RuntimeError('len(args) < len(itensors)') + # replace the graph input tensors with provided input tensors + # this is to connect outside-model operators like dataloader. for idx, (iobj, arg) in enumerate(zip(iobjs, args)): # reset input self.set_input(idx, arg) - # replace node inputs + # replace node inputs, kwargs and outputs for producer in self.producers(iobj.parent): with self.update(producer): - while iobj in producer.outputs(): - oidx = producer.outputs().index(iobj) - producer.set_output(oidx, arg) + producer.replace_output(iobj, arg) for consumer in self.consumers(iobj.parent): with self.update(consumer): - while iobj in consumer.inputs(): - iidx = consumer.inputs().index(iobj) - consumer.set_input(iidx, arg) + consumer.replace_input(iobj, arg) # reset output - for oidx, output in enumerate(self.outputs()): - output = IRGraph.modify_objects_of_complex( - self.output(oidx), lambda t: t if t != iobj else arg) - self.set_output(oidx, output) + self.replace_output(iobj, arg) from nnscaler.program import Program Program().add_nodes(self.nodes()) @@ -178,19 +173,17 @@ def from_logic_graph(nodes: List[IRCell], """ modifier = lambda t: t.tosub() if isinstance(t, IRFullTensor) else t # input / output - inputs = [IRGraph.modify_objects_of_complex(t, modifier) for t in inputs] - outputs = [IRGraph.modify_objects_of_complex(t, modifier) for t in outputs] + inputs = [IRCell.modify_objects_of_complex(t, modifier) for t in inputs] + outputs = [IRCell.modify_objects_of_complex(t, modifier) for t in outputs] # nodes for node in nodes: for idx, ftensor in enumerate(node.inputs()): - if isinstance(ftensor, IRObject): - subtensor = ftensor.tosub() if isinstance(ftensor, IRFullTensor) else ftensor - node.set_input(idx, subtensor) + subtensor = IRCell.modify_objects_of_complex(ftensor, modifier) + node.set_input(idx, subtensor) for idx, ftensor in enumerate(node.outputs()): - if isinstance(ftensor, IRObject): - subtensor = ftensor.tosub() if isinstance(ftensor, IRFullTensor) else ftensor - node.set_output(idx, subtensor) - node.kwargs.update(IRSegment.modify_objects_of_complex(node.kwargs, modifier)) + subtensor = IRCell.modify_objects_of_complex(ftensor, modifier) + node.set_output(idx, subtensor) + node.kwargs.update(IRCell.modify_objects_of_complex(node.kwargs, modifier)) graph = IRGraph(nodes, inputs, outputs, module_name) # check IRPyFunc diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index b9c33e37..feeb0934 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -359,17 +359,17 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: for ptensor, producer in zip(self.ptensors(ftensor), self.producers(ftensor)): # filter out non-autograd operators of IRPyFunc if isinstance(producer, IRPyFunc): continue - idx = producer.outputs().index(ptensor) grad = None if fgrad is None else fgrad.select(ptensor.indmap, (0, 1)) - producer.output(idx).grad = grad + for t in producer.find(ptensor): + t.grad = grad # set for consumers consumers, ctensors = [], [] # consumers that require gradient for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): + itensors = consumer.find(ctensor) # set by default None - for t in consumer.inputs(): # consider an op can have multiple same-tensor inputs - if isinstance(t, IRSubTensor) and t == ctensor: - t.grad = None + for itensor in itensors: + itensor.grad = None # filter out non-autograd operators if fgrad is None: continue if isinstance(consumer, IRPyFunc): continue @@ -383,9 +383,8 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: valmap = curr_valmap.map((0, 2)) if cidx != nconsumers - 1 else curr_valmap grad = fgrad.select(ctensor.indmap, valmap) curr_valmap = curr_valmap.map((1, 2)) if cidx != nconsumers - 1 else curr_valmap - for t in consumer.inputs(): - if isinstance(t, IRSubTensor) and t == ctensor: - t.grad = grad + for t in consumer.find(ctensor): + t.grad = grad def debug_tensor_map_str(self, ftensor: Optional[IRFullTensor] = None) -> str: dscp : str = '' @@ -412,8 +411,8 @@ def create_bwop(self, fwop: IRFwOperation) -> IRBpOperation: @return bwop IRBpOperation: the created backward operation """ assert isinstance(fwop, IRFwOperation), "Expected IRFwOperation" - fins = [t for t in fwop.inputs() if isinstance(t, IRSubTensor)] - fous = [t for t in fwop.outputs() if isinstance(t, IRSubTensor)] + fins = [t for t in fwop.iobjs() if isinstance(t, IRSubTensor)] + fous = [t for t in fwop.oobjs() if isinstance(t, IRSubTensor)] igrads = [t.grad for t in fins if t.grad is not None] # note not all output tensors will be consumed by nodes, e.g., chunk. # for these cases, the backward op doesn't have exactly the same number of @@ -466,22 +465,22 @@ def _reorder_producer_consumer(self): self._consumers, self._ctensors = dict(), dict() # set input and output - for obj in IRSegment.get_objects_from_complex(self.inputs()): + for obj in self.iobjs(): self._add_ftensor(obj.parent) - for obj in IRSegment.get_objects_from_complex(self.outputs()): + for obj in self.oobjs(): self._add_ftensor(obj.parent) # set producer and consumer + # NOTE: we use `dict.fromkeys` to remove duplicate tensors + # as well as keep the order of tensors for node in self._nodes: if isinstance(node, IRAdapter): continue - itensors = set(t for t in node.inputs() if isinstance(t, IRObject)) - for itensor in itensors: + for itensor in dict.fromkeys(node.iobjs()): ftensor = itensor.parent self._add_ftensor(ftensor) self._consumers[ftensor].append(node) self._ctensors[ftensor].append(itensor) - otensors = set(t for t in node.outputs() if isinstance(t, IRObject)) - for otensor in otensors: + for otensor in dict.fromkeys(node.oobjs()): ftensor = otensor.parent self._add_ftensor(ftensor) self._producers[ftensor].append(node) @@ -493,11 +492,9 @@ def insert(self, node: IRCell, index: Union[int, CellPosition]): """ Insert a node at index. - TODO: dataflow dependency update - TODO: input / output check - - @param node IRCell: the inserted node - @param index int: the index + Args: + node (IRCell): the inserted node + index (int or CellPosition): the index """ pos = CellPosition((index,)) if isinstance(index, int) else index @@ -507,18 +504,18 @@ def insert(self, node: IRCell, index: Union[int, CellPosition]): index = pos[0] # insert node self._nodes.insert(index, node) - # update producer and consumer if isinstance(node, IRAdapter): return - # consumer - itensors = set(t for t in node.inputs() if isinstance(t, IRObject)) - for itensor in itensors: + # update producer and consumer + # NOTE: we use `dict.fromkeys` to remove duplicate tensors + # as well as keep the order of tensors + # - consumer + for itensor in dict.fromkeys(node.iobjs()): ftensor = itensor.parent self._add_ftensor(ftensor) self._consumers[ftensor].append(node) self._ctensors[ftensor].append(itensor) - # producer - otensors = set(t for t in node.outputs() if isinstance(t, IRObject)) - for otensor in otensors: + # - producer + for otensor in dict.fromkeys(node.oobjs()): ftensor = otensor.parent self._add_ftensor(ftensor) self._producers[ftensor].append(node) @@ -533,12 +530,12 @@ def remove(self, node: IRCell, _pos: Union[int, CellPosition] = None) -> CellPos """ Remove a node at index - # TODO: check input and output - - @param node IRCell: the removed node - @param _pos Optional[Union[int, CellPosition]: help to save cost if provide node position. + Args: + node (IRCell): the removed node + _pos (Optional[Union[int, CellPosition]): help to save cost if provide node position. - @return index CellPosition: the removed index + Returns: + CellPosition: the removed index """ pos = self.index(node) if _pos is None else _pos assert self.node(pos) == node, \ @@ -551,8 +548,7 @@ def remove(self, node: IRCell, _pos: Union[int, CellPosition] = None) -> CellPos # update producer and consumer if isinstance(node, IRAdapter): return pos # consumer - itensors = set(t for t in node.inputs() if isinstance(t, IRObject)) - for itensor in itensors: + for itensor in dict.fromkeys(node.iobjs()): ftensor = itensor.parent idx = self._consumers[ftensor].index(node) self._consumers[ftensor].pop(idx) @@ -560,8 +556,7 @@ def remove(self, node: IRCell, _pos: Union[int, CellPosition] = None) -> CellPos if len(self._consumers[ftensor]) == 0 and len(self._producers[ftensor]) == 0: self._remove_ftensor(ftensor) # producer - otensors = set(t for t in node.outputs() if isinstance(t, IRObject)) - for otensor in otensors: + for otensor in dict.fromkeys(node.oobjs()): ftensor = otensor.parent idx = self._producers[ftensor].index(node) self._producers[ftensor].pop(idx) @@ -1098,55 +1093,3 @@ def extra_repr(self) -> str: dscp += f"\nOutputs: {self.outputs()}\n{'=' * len(self.name)}\n" return dscp - @staticmethod - def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[IRObject]: - """Get all IRObjects from a complex data structure - - Supported complex of types: List, Tuple, Dict, IRTensor, IRObject - - Args: - val (Any): the complex data structure to be modified - _objects (List[IRObject] | None): - if provided, the objects will be appened into this - - @return _objects List[IRObject]: all IRObject - """ - _objects = [] if _objects is None else _objects - if isinstance(val, (tuple, list)): - for item in val: - IRSegment.get_objects_from_complex(item, _objects) - if isinstance(val, dict): - for key, value in val.items(): - IRSegment.get_objects_from_complex(key, _objects) - IRSegment.get_objects_from_complex(value, _objects) - if isinstance(val, slice): - IRSegment.get_objects_from_complex([val.start, val.stop, val.step], _objects) - if isinstance(val, IRObject): - _objects.append(val) - return _objects - - @staticmethod - def modify_objects_of_complex(val: Any, modifier: Callable) -> Any: - """Return a complex data structure with modified IRObjects - - Supported complex of types: List, Tuple, Dict, IRTensor, IRObject - - Args: - val (Any): the complex data structure to be modified - modifier (Callable): a modifier that takes an IRObject and return a new one. - - Return: - new_val (Any): complex data structure with modified IRObjects - """ - rcall = IRSegment.modify_objects_of_complex - if isinstance(val, tuple): - return tuple(rcall(item, modifier) for item in val) - if isinstance(val, list): - return list(rcall(item, modifier) for item in val) - if isinstance(val, dict): - return {rcall(key, modifier):rcall(value, modifier) for key, value in val.items()} - if isinstance(val, slice): - return slice(rcall(val.start, modifier), rcall(val.stop, modifier), rcall(val.step, modifier)) - if isinstance(val, IRObject): - return modifier(val) - return val diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 1089f419..5bfa4bd8 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -16,7 +16,7 @@ from __future__ import annotations from functools import lru_cache -from typing import List, Tuple, Union, Optional, Any, Dict +from typing import List, Tuple, Union, Optional, Any, Dict, Callable from collections import OrderedDict import copy import torch @@ -41,6 +41,13 @@ def __init__(self, """ Create a node with name (variable name) and module type (module_name) + Notes: setting input / kwarg / output IRObject will cause a copy-on-write + behavior, where the IRObject will be copied (see `set_input` and `set_output`) + to prevent users accidently updating it outside. + + To update the IRObject in input / kwarg / output, please use `find`, `input(s)`, + and `output(s)` to get the real instance tensor in the IRCell. + Args: name (str): the cell name signature (str): the cell function signature, @@ -164,6 +171,18 @@ def inputs(self) -> Tuple[NestedVarOrStatic]: Tuple[NestedVarOrStatic] """ return tuple(self._inputs) + + @lru_cache(maxsize=None) + def iobjs(self) -> Tuple[IRObject]: + """ + Get all IRObject in the inputs and kwargs. + + The order follows the inputs order then kwargs order. + + Returns: + Tuple[IRObject]: all IRObject in the inputs + """ + return tuple(IRCell.get_objects_from_complex([self._inputs, self._kwargs])) def output(self, index: int) -> NestedVarOrStatic: """Get the index-th output value @@ -175,6 +194,16 @@ def output(self, index: int) -> NestedVarOrStatic: NestedVarOrStatic: (nested) IRObject or any static value (int, bool, str, etc) """ return self._outputs[index] + + @lru_cache(maxsize=None) + def oobjs(self) -> Tuple[IRObject]: + """ + Get all IRObjects in the outputs. + + Returns: + Tuple[IRObject]: all IRObject in the outputs + """ + return tuple(IRCell.get_objects_from_complex(self._outputs)) # 'maxsize=None' set no limit on cache growth, but it's ok since we have no args @lru_cache(maxsize=None) @@ -192,6 +221,7 @@ def reset_inputs(self, length:int) -> None: """ self._inputs = [None] * length self.inputs.cache_clear() + self.iobjs.cache_clear() def set_input(self, index: int, val: NestedVarOrStatic) -> NestedVarOrStatic: """Set the index-th input @@ -208,6 +238,7 @@ def set_input(self, index: int, val: NestedVarOrStatic) -> NestedVarOrStatic: val.cell = self self._inputs[index] = val self.inputs.cache_clear() + self.iobjs.cache_clear() return val def reset_outputs(self, length:int) -> None: @@ -216,6 +247,7 @@ def reset_outputs(self, length:int) -> None: """ self._outputs = [None] * length self.outputs.cache_clear() + self.oobjs.cache_clear() def set_output(self, index: int, val: NestedVarOrStatic): """ @@ -232,7 +264,66 @@ def set_output(self, index: int, val: NestedVarOrStatic): val.cell = self self._outputs[index] = val self.outputs.cache_clear() + self.oobjs.cache_clear() return val + + def replace_input(self, old: IRObject, new: IRObject): + """Replace the old input (including kwargs) with the new input + + Args: + old (IRObject): the old input + new (IRObject): the new input + """ + def replace(obj): + val = copy.copy(new) if obj == old else obj + val.cell = self + return val + + self._inputs = IRCell.modify_objects_of_complex(self._inputs, replace) + self._kwargs = IRCell.modify_objects_of_complex(self._kwargs, replace) + self.inputs.cache_clear() + self.iobjs.cache_clear() + + def replace_output(self, old: IRObject, new: IRObject): + """Replace the old output with the new output + + Args: + old (IRObject): the old output + new (IRObject): the new output + """ + def replace(obj): + val = copy.copy(new) if obj == old else obj + val.cell = self + return val + + self._outputs = IRCell.modify_objects_of_complex(self._outputs, replace) + self.outputs.cache_clear() + self.oobjs.cache_clear() + + def replace(self, old: IRObject, new: IRObject): + """Replace the old object with the new object in inputs, kwargs, and outputs + + Args: + old (IRObject): the old object + new (IRObject): the new object + """ + self.replace_input(old, new) + self.replace_output(old, new) + + def find(self, obj: IRObject) -> Tuple[IRObject]: + """Find all the objects equal to `obj` in inputs, kwargs, or outputs + + Args: + obj (IRObject): the object to find + + Returns: + Tuple[IRObject]: all the objects equal to `obj` + """ + outs = [] + IRCell.get_objects_from_complex(self._inputs, outs) + IRCell.get_objects_from_complex(self._kwargs, outs) + IRCell.get_objects_from_complex(self._outputs, outs) + return tuple(out for out in outs if out == obj) @property def comment(self) -> Optional[str]: @@ -266,6 +357,59 @@ def __repr__(self) -> str: f"inputs={ins}, " f"outputs={self.outputs()})") return dscp + + @staticmethod + def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[IRObject]: + """Get all IRObjects from a complex data structure + + Supported complex of types: List, Tuple, Dict, IRTensor, IRObject + + Args: + val (Any): the complex data structure to be modified + _objects (List[IRObject] | None): + if provided, the objects will be appened into this + + @return _objects List[IRObject]: all IRObject + """ + _objects = [] if _objects is None else _objects + if isinstance(val, (tuple, list)): + for item in val: + IRCell.get_objects_from_complex(item, _objects) + elif isinstance(val, dict): + for key, value in val.items(): + IRCell.get_objects_from_complex(key, _objects) + IRCell.get_objects_from_complex(value, _objects) + elif isinstance(val, slice): + IRCell.get_objects_from_complex([val.start, val.stop, val.step], _objects) + elif isinstance(val, IRObject): + _objects.append(val) + return _objects + + @staticmethod + def modify_objects_of_complex(val: Any, modifier: Callable) -> Any: + """Return a complex data structure with modified IRObjects + + Supported complex of types: List, Tuple, Dict, IRTensor, IRObject + + Args: + val (Any): the complex data structure to be modified + modifier (Callable): a modifier that takes an IRObject and return a new one. + + Return: + new_val (Any): complex data structure with modified IRObjects + """ + rcall = IRCell.modify_objects_of_complex + if isinstance(val, tuple): + return tuple(rcall(item, modifier) for item in val) + if isinstance(val, list): + return list(rcall(item, modifier) for item in val) + if isinstance(val, dict): + return {rcall(key, modifier):rcall(value, modifier) for key, value in val.items()} + if isinstance(val, slice): + return slice(rcall(val.start, modifier), rcall(val.stop, modifier), rcall(val.step, modifier)) + if isinstance(val, IRObject): + return modifier(val) + return val class IRObject: diff --git a/nnscaler/ir/operator.py b/nnscaler/ir/operator.py index 6433772c..17217618 100644 --- a/nnscaler/ir/operator.py +++ b/nnscaler/ir/operator.py @@ -31,7 +31,16 @@ def __init__(self, name: str, signature: str, for idx, input in enumerate(inputs): self.set_input(idx, input) - # additional argument + # setup kwargs + # similar with set_input and set_output, the IRObject + # in kwargs will be set with copy-on-write to avoid + # potential modifications outside. + def replace(t: IRObject): + t = copy.copy(t) + t.cell = self + return t + + kwargs = IRCell.modify_objects_of_complex(kwargs, replace) self.kwargs.update(kwargs) # default infer rule diff --git a/tests/graph/gener/test_producer_fusion.py b/tests/graph/gener/test_producer_fusion.py new file mode 100644 index 00000000..72791c8e --- /dev/null +++ b/tests/graph/gener/test_producer_fusion.py @@ -0,0 +1,64 @@ + +from nnscaler.ir.tensor import IRFullTensor +import nnscaler.graph.function.function as F +from nnscaler.graph import IRGraph + +from nnscaler.graph.gener.gen import IRAdapterGener + + +def _tensor(shape, requires_grad=True): + return IRFullTensor(shape, requires_grad=requires_grad).tosub() + + +def test_gener_producer_fusion_replicate(): + + data = _tensor([128, 128], False) + w1 = _tensor([128, 128]) + out1 = _tensor([128, 128]) + l1 = F.Linear(data, w1) + l1.set_output(0, out1) + + w2 = _tensor([128, 128]) + out2 = _tensor([128, 128]) + l2 = F.Linear(l1.output(0), w2) + l2.set_output(0, out2) + + loss = _tensor([1]) + sum = F.Sum(l2.output(0)) + sum.set_output(0, loss) + + nodes = [l1, l2, sum] + graph = IRGraph(nodes, [data], [loss], 'genmodel') + graph.backward(loss) + + graph.assign(l1, 0) + + s1, s2 = graph.partition(l2, l2.algorithms('dim'), idx=0, dim=0, num=2) + r1, r2 = graph.replicate(s1, 2) + graph.assign(r1, 0) + graph.assign(r2, 0) + s3, s4 = graph.partition(s2, s2.algorithms('dim'), idx=0, dim=1, num=2) + graph.assign(s3, 1) + graph.assign(s4, 1) + + graph.assign(sum, 0) + + # print(graph.extra_repr()) + IRAdapterGener.local_producer_fusion(graph, out2.parent) + # print(graph.extra_repr()) + + assert len(graph.select(name='accum')) == 1 + accum = graph.select(name='accum')[0] + + mms = graph.select(name='linear')[1:] + assert len(mms) == 4 + + new_t = mms[0].output(0) + old_t = l2.output(0) + assert new_t.parent != old_t.parent + for mm in mms: + if mm.device == (0,): + assert mm.output(0).parent == new_t.parent + if mm.device == (1,): + assert mm.output(0).parent == old_t.parent + assert accum.output(0).parent == new_t.parent diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 5de8015c..392421cd 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -26,3 +26,27 @@ def test_graph_from_logic(): assert isinstance(node.kwargs['kw']['a'][0], IRSubTensor) assert isinstance(node.kwargs['kw']['b'], IRSubTensor) assert isinstance(node.kwargs['t'], IRSubTensor) + + +def test_graph_kwargs_track(): + + node = IRFwOperation("test", "test", + inputs=[IRFullTensor([256, 256])], + num_outputs=1, + # kwargs + kw={ + 'a':[IRFullTensor([128, 256]),], + 'b':IRFullTensor([128, 128]) + }, + t=IRFullTensor([128, 256])) + output = IRFullTensor([128, 256]) + node.set_output(0, output) + graph = IRGraph.from_logic_graph([node], [node.input(0), node.kwargs['t']], [output], 'GenModule') + assert len(graph.nodes()) == 1 + assert len(graph.full_tensors()) == 5 + args = [IRFullTensor([256, 256]).tosub(), IRFullTensor([128, 256]).tosub()] + # forward replace + graph(*args) + assert graph.input(0) == args[0] + assert graph.input(1) == args[1] + assert graph.node(0).kwargs['t'] == args[1] From 37b7c7fb4328d591183be421d0bfb46fbc3d8da1 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 20 May 2024 06:20:36 +0000 Subject: [PATCH 1641/1892] Merged PR 2145: parallel module: ComputeConfig: move pas config to top level --- docs/source/parallel_module.md | 4 +- examples/vision/swin/train.py | 7 +-- nnscaler/parallel.py | 63 +++++++++++++++++++++----- nnscaler/policies.py | 4 +- nnscaler/runtime/module.py | 6 ++- tests/parallel_module/test_override.py | 8 ++-- tests/test_policies.py | 12 ++--- tests/utils.py | 3 +- 8 files changed, 75 insertions(+), 32 deletions(-) diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 22965cbf..555f4898 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -204,6 +204,7 @@ class ComputeConfig: pipeline_nstages: int = 1 pipeline_scheduler: Optional[str] = None + pas_config: Dict[str, Any] = field(default_factory=dict) user_config: UserConfig = field(default_factory=UserConfig) ``` We can categorize the fields into 4 categories: @@ -225,6 +226,7 @@ We can categorize the fields into 4 categories: - `pipeline_nmicros`: the number of microbatches in the pipeline. - `pipeline_nstages`: the number of stages in the pipeline. - `pipeline_scheduler`: the scheduler name for the pipeline. Current we support four schedulers in training `1f1b`/`1f1b_plus`/`gpipe`/`chimera_direct` (4 stages pipeline only), and one scheduler in inference `infer_pipe`. + - `pas_config`: the configuration for the PAS policy. It is a dictionary, and will be used by the PAS policy. Please note different PAS will have different configurations, and please check the PAS policy for details. 4. User configuration - user_config: the user configuration,which is used to decide whether skipping compiling and reusing the previously compiled parallel module. It has two categories of configuration: - `graph`: the graph related configuration, which is used to decide whether skipping graph generation only. @@ -524,7 +526,7 @@ The input is a list of samples, and returns a list of outputs for the samples. I Writing a pas policy can be very hard and error-prone. So we provide 6 builtin PAS policies to help you. `dp`, `tp`, `pp`, `data`, `hybrid`, and `autodist`. Please note only `autodist` policy is the recommended policy for most cases, and all other PAS policies are mainly test purpose only. -The configuration of the PAS policy should be passed in the `user_config.code['pas']` of `ComputeConfig` as a dictionary. +The configuration of the PAS policy should be passed in the `pas_config` of `ComputeConfig` as a dictionary. 1. `dp`: data parallelism. It will replicate the module across all devices, and run data parallelism across all devices. It requires the `plan_ngpus` must be 1 and no configurations diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index a9cfac66..2895dbcf 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -153,6 +153,10 @@ def train(args, compute_config: nnscaler.ComputeConfig): use_pipeline=args.pp_size > 1, pipeline_nmicros=args.gbs // args.mbs, pipeline_nstages=args.pp_size, + pas_config={ # for autodist only + 'update_freq': args.gbs // args.mbs, + 'use_fp16': args.fp16, + }, user_config=nnscaler.UserConfig( graph={ 'mbs': args.mbs, @@ -165,9 +169,6 @@ def train(args, compute_config: nnscaler.ComputeConfig): 'pp_size': args.pp_size, 'tp_size': args.tp_size, 'dp_size': args.dp_size, - 'pas': { - 'update_freq': args.gbs // args.mbs, # for autodist only - } } ) ) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 175ea4f9..9b5c0384 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -93,16 +93,8 @@ class UserConfig: # ``` graph: Dict[str, Any] = field(default_factory=dict) # you can put any configuration that may affect the generated code (but not affect the traced graph) here. - # For example, extra arguments of your PAS function can put here. - # For all builtin pas, we will put PAS config in `code['pas']`. code: Dict[str, Any] = field(default_factory=dict) - def get_pas_config(self) -> Dict[str, Any]: - """ - All builtin pas will read their config here. - """ - return self.code.get('pas', {}) - @dataclass(frozen=True) class ComputeConfig: @@ -133,8 +125,9 @@ class ComputeConfig: pipeline_nstages: int = -1 # it is pas's responsibility to apply the scheduler pipeline_scheduler: str = '1f1b' + # PAS policy settings + pas_config: Dict[str, Any] = field(default_factory=dict) # the customized configs from user that can affect the graph and code generation. - # for example, module configuration or PAS policy settings. user_config: UserConfig = field(default_factory=UserConfig) def __post_init__(self): @@ -219,6 +212,52 @@ def optimizer_dedup_group_size(self) -> int: else: return self.plan_ngpus + @classmethod + def safe_dump_to_file(cls, cfg: 'ComputeConfig', file: Union[str, Path]) -> None: + """ + torch.save(cfg) is not safe when we change the fields of ComputeConfig. + So we should use this method to save the config. + """ + torch.save(asdict(cfg), file) + + @classmethod + def safe_load_from_file(cls, file: Union[str, Path], return_none_on_error=True) -> Optional['ComputeConfig']: + """ + Load the config from file. + `return_none_on_error` controls the behaivor when the file not exists or failed to load. + If `return_none_on_error` is True, will return None when failed to load. + If `return_none_on_error` is False, will raise when failed to load. + """ + if Path(file).exists(): + try: + cfg = torch.load(file) + if isinstance(cfg, dict): # in old version, we save the object directly (not save as dict) + # this can raise if cfg has extra keys. + # which means some fields of ComputeConfig has been removed(we should avoid this). + # in this case, we just return None. + return cls(**cfg) + return cfg + except Exception as e: + if not return_none_on_error: + raise + logger.warning(f"Failed to load ComputeConfig with error {str(e)}.") + elif not return_none_on_error: + raise FileNotFoundError(f"Failed to load compute config from {file}. File not found.") + return None + + @classmethod + def safe_equals(cls, a: Optional['ComputeConfig'], b: Optional['ComputeConfig']) -> bool: + """ + Return False if a and b are from incompatible version of ComputeConfig + This is only for backward compatibility, and will be removed in future + and can use `==` when we save dict version of ComputeConfig to file. + """ + try: + return a == b + except AttributeError: + logger.warning("Failed to compare ComputeConfig. They are incompatible.") + return False + @contextmanager def _flags(flags, /, **kwargs): @@ -477,8 +516,8 @@ def _prepare_and_check_reusable( # you can take it as a continous operation after a failed generation. reusable = False config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE - old_config: ComputeConfig = torch.load(config_file) if config_file.exists() else None - is_config_match = old_config == compute_config + old_config: Optional[ComputeConfig] = ComputeConfig.safe_load_from_file(config_file) + is_config_match = ComputeConfig.safe_equals(old_config, compute_config) is_graph_config_match = old_config is not None and old_config.graph_config == compute_config.graph_config trace_meta_files = [ outdir / FxModuleParser.ATTR_CONTENT_FILE_0, # just check the first is good enough @@ -953,7 +992,7 @@ def __init__(self, init_params=True): outdir, reusable = _prepare_and_check_reusable(cube_savedir, module_class, compute_config, instance_name, reuse) if not reusable: config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE - torch.save(compute_config, config_file) # always refresh compute config + ComputeConfig.safe_dump_to_file(compute_config, config_file) # always refresh compute config with _compile_flags(compute_config): regen_status = _gencode( module_or_module_class, diff --git a/nnscaler/policies.py b/nnscaler/policies.py index e45354a9..1aaf6ff3 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -78,7 +78,7 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): # get the current random state state = random.getstate() - seed = cfg.user_config.get_pas_config().get('seed', 1) # by default we fix the seed for test reproducibility + seed = cfg.pas_config.get('seed', 1) # by default we fix the seed for test reproducibility random.seed(seed) devs = list(range(ngpus)) @@ -197,7 +197,7 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: - pas_cfg = cfg.user_config.get_pas_config() + pas_cfg = cfg.pas_config # required parameters update_freq = pas_cfg['update_freq'] diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index e839f682..0e613bb2 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -660,6 +660,7 @@ def _post_init(self, init_params=True): # TODO: re-enable this check # if dist.is_initialized() and self.rank != dist.get_rank(): # raise RuntimeError(f"The rank to load this module file name is expected to be {self._rank}, but got {dist.get_rank()}") + from nnscaler.parallel import ComputeConfig self._non_presistent_buffers_inited = init_params or not self._non_persistent_buffers_set module_file = Path(sys.modules[self.__module__].__file__) @@ -670,7 +671,10 @@ def _post_init(self, init_params=True): self._warn_uninitialized_non_persistent_buffers() self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}")) - self._compute_config: 'ComputeConfig' = torch.load(module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}")) + self._compute_config: ComputeConfig = ComputeConfig.safe_load_from_file( + module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}"), + return_none_on_error=False + ) self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}")) for reducer in self.reducers: diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index 622c5139..be07b2dc 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -180,21 +180,21 @@ def test_override(): # Graph | graph match | generate _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'graph', 'g8', False) g8_module_path = module_path.with_name('g8') - assert torch.load(g8_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(1, 1) + assert ComputeConfig.safe_load_from_file(g8_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(1, 1) graph_stat = (g8_module_path / 'graph.ckp').stat() args_stat = (g8_module_path / 'forward_args.pkl').stat() _to_cube_model(MyModule, ComputeConfig(2, 2), tempdir, 'graph', 'g8', False) assert (g8_module_path / 'graph.ckp').stat().st_mtime_ns == graph_stat.st_mtime_ns assert (g8_module_path / 'forward_args.pkl').stat().st_mtime_ns == args_stat.st_mtime_ns - assert torch.load(g8_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(2, 2) + assert ComputeConfig.safe_load_from_file(g8_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(2, 2) # MOO | graph match | generate code only _to_cube_model(MyModule, ComputeConfig(1, 1), tempdir, 'moo', 'g9', False) g9_module_path = module_path.with_name('g9') - assert torch.load(g9_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(1, 1) + assert ComputeConfig.safe_load_from_file(g9_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(1, 1) graph_stat = (g9_module_path / 'graph.ckp').stat() args_stat = (g9_module_path / 'forward_args.pkl').stat() _to_cube_model(MyModule, ComputeConfig(2, 2), tempdir, 'moo', 'g9', False) assert (g9_module_path / 'graph.ckp').stat().st_mtime_ns == graph_stat.st_mtime_ns assert (g9_module_path / 'forward_args.pkl').stat().st_mtime_ns == args_stat.st_mtime_ns - assert torch.load(g9_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(2, 2) + assert ComputeConfig.safe_load_from_file(g9_module_path / ParallelModule.COMPUTE_CONFIG_FILE) == ComputeConfig(2, 2) diff --git a/tests/test_policies.py b/tests/test_policies.py index b0a9f2f9..230aa58d 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -47,14 +47,10 @@ def test_autodist(): MLP(), {'data': dummy_data()}, 'autodist', - ComputeConfig(2, 4, user_config=UserConfig( - code={ - 'pas': { - 'update_freq': 1, - 'task_name': 'test_autodist', - } - } - )), + ComputeConfig(2, 4, pas_config={ + 'update_freq': 1, + 'task_name': 'test_autodist', + }), cube_savedir=tempdir, load_module=False ) diff --git a/tests/utils.py b/tests/utils.py index b21abd87..bfc7ee70 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,6 +16,7 @@ import torch.distributed as dist import torch.distributed.distributed_c10d as c10d +from nnscaler.parallel import ComputeConfig from nnscaler.runtime.module import ParallelModule from nnscaler.runtime.device import DeviceGroup, CompileFlag @@ -291,7 +292,7 @@ def new_empty(cube_module_cls: Type[ParallelModule], device='meta', init_params= This is useful when you want to get model information (e.g. fullmap/zero) without allocating memory. """ module_file = Path(sys.modules[cube_module_cls.__module__].__file__) - compute_config = torch.load(module_file.with_name(f"{cube_module_cls.COMPUTE_CONFIG_FILE}")) + compute_config = ComputeConfig.safe_load_from_file(module_file.with_name(f"{cube_module_cls.COMPUTE_CONFIG_FILE}")) with replace_all_device_with(device, True), mock_cube_env(cube_module_cls, compute_config), mock_dist(cube_module_cls.rank, compute_config.runtime_ngpus): return cube_module_cls(init_params=init_params) From 3eed171c87ec8f9849cc4321a110edc277881917 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 20 May 2024 06:46:43 +0000 Subject: [PATCH 1642/1892] Merged PR 2147: Add graph and plan analysis in autodist Example output for input graph analysis: ![image.png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2147/attachments/image.png) Example output for spmd plan analysis: ![image (2).png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2147/attachments/image%20%282%29.png) ![image (3).png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2147/attachments/image%20%283%29.png) In addition, some info level logging is changed to debug to make the output more clean. --- nnscaler/autodist/apis.py | 17 +- nnscaler/autodist/cost_database.py | 3 +- nnscaler/autodist/descs.py | 5 +- nnscaler/autodist/model_graph.py | 71 +++++++- nnscaler/autodist/spmd_solver.py | 252 +++++++++++++++++++++++++---- nnscaler/graph/parser/fx/parser.py | 2 +- nnscaler/profiler/database.py | 3 +- utility/prim_profiler.py | 3 +- 8 files changed, 309 insertions(+), 47 deletions(-) diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 1286b140..60833f1c 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -1,4 +1,4 @@ -from .spmd_solver import calc_optimal_spmd_plan +from .spmd_solver import calc_optimal_spmd_plan, analysis_pretty_printer from .pipeline_solver import calc_optimal_pp_plan from .autodist_config import AutoDistConfig from .model_graph import ModelGraph, estimate_mem_lower_bound @@ -122,10 +122,7 @@ def parallelize_graph(graph: IRGraph, with open(autodist_config.save_plan_path, 'w') as f: json.dump(search_out.to_json(), f, indent=2) - _logger.info(f'use plan with e2e time/s {search_out.e2e_time}s,' + - f'stage mems/GB {search_out.stage_mems}, ' + - f'stage all times/s {search_out.stage_all_times}, ' + - f'stage comp times/s {search_out.stage_comp_times}') + _logger.info(f'use plan with e2e time/s {1000 * search_out.e2e_time:.2f}ms') pp_desc = search_out.desc cid2node: Dict[int, IRFwOperation] = dict() @@ -240,9 +237,11 @@ def parallelize_graph(graph: IRGraph, # partition and assign nodes to devices # TODO(yizhu1): network topo aware device map offset = 0 - for spmd_desc, stage in zip(pp_desc.spmd_descs, stages): + for idx, (spmd_desc, stage) in enumerate(zip(pp_desc.spmd_descs, stages)): cur_ngpus = spmd_desc.mesh_desc.ngpus dev = [offset + i for i in range(cur_ngpus)] + stage_info_str = f'stage {idx} on devices {dev} with mem {search_out.stage_mems[idx]:.2f} GB' + _logger.info(f'\nautodist plan analysis for {stage_info_str}:\n\n{analysis_pretty_printer(spmd_desc.analysis)}') offset += cur_ngpus for node in stage.nodes(): if isinstance(node, IRFwOperation): @@ -254,16 +253,16 @@ def parallelize_graph(graph: IRGraph, p_desc = spmd_desc.partition_descs[node.cid] partition_node(node, graph, dev, p_desc) if isinstance(node, IRDimops): - _logger.info( + _logger.debug( f'apply {node} with {node.anno} at {node.comment}, plan: {p_desc}' ) else: - _logger.info( + _logger.debug( f'replicate non-IRDimops {node.signature} with {node.comment}' ) else: replica(graph, node, dev) - _logger.info( + _logger.debug( f'NOT included in plan, replicate {node.signature} with {node.comment}' ) diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index c1a99b52..c9900c7a 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -1,5 +1,4 @@ from typing import List, Tuple, Union, Callable, Dict -import numpy as np import json import os from os import listdir @@ -241,7 +240,7 @@ def get_mem_and_buffer(self, op_partition, is_train: bool, stage_num: int): node_mem = opt_resident_mem + memory_results[ 'train'] + 2 * memory_results['param'] + memory_results['buffer'] node_mem = node_mem + (stage_num - 1) * activation_mem \ - if is_train else memory_results['param'] + if is_train else node_mem node_buffer = max(memory_results.values()) \ if is_train else memory_results['infer'] diff --git a/nnscaler/autodist/descs.py b/nnscaler/autodist/descs.py index f6fcab20..9d9bb9e3 100644 --- a/nnscaler/autodist/descs.py +++ b/nnscaler/autodist/descs.py @@ -35,6 +35,7 @@ class TensorParallelDesc: partition_descs: Dict[int, NodePartitionDesc] recompute_groups: List[List[int]] mesh_desc: MeshDesc + analysis: Dict[str, Any] def to_json(self): ret = {} @@ -42,6 +43,7 @@ def to_json(self): ret['partition_descs'] = descs_list ret['recompute_groups'] = self.recompute_groups ret['mesh_desc'] = self.mesh_desc.to_json() + ret['analysis'] = self.analysis return ret @staticmethod @@ -51,7 +53,8 @@ def from_json(ret): partition_descs[k] = NodePartitionDesc(v) return TensorParallelDesc(partition_descs, copy.deepcopy(ret['recompute_groups']), - MeshDesc.from_json(ret['mesh_desc'])) + MeshDesc.from_json(ret['mesh_desc']), + ret['analysis']) @dataclass diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 44d4c240..246e0f45 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -107,8 +107,7 @@ def estimate_mem_lower_bound( else: raise RuntimeError(f'invalid zero stage {cfg.zero_stage}') - min_single_dev_mem = max(opt_transient_mem, activation_mem) \ - + 2 * param_mem + buffer_mem + opt_resident_mem + min_single_dev_mem = max(opt_transient_mem, activation_mem) + 2 * param_mem + buffer_mem + opt_resident_mem return min_single_dev_mem @@ -274,6 +273,11 @@ def __init__(self, self.start = start self.end = end + def get_full_name(self): + if self.module_type is None: + return self.name + return f'{self.name}, {self.module_type.__name__}' + def insert(self, node: IRFwOperation, module_info: List[Tuple[str, Any]], flops: int, fw_span: float, idx: int): self.leaf_size += 1 @@ -423,6 +427,66 @@ def __repr__(self): return desc +def collect_depth2scope_nodes(root: ScopeNode) -> Dict[int, List[ScopeNode]]: + depth2scope_nodes: Dict[int, List[ScopeNode]] = dict() + + def dfs(node: ScopeNode): + if node.depth not in depth2scope_nodes: + depth2scope_nodes[node.depth] = [] + depth2scope_nodes[node.depth].append(node) + for child in node.children: + dfs(child) + + dfs(root) + return depth2scope_nodes + + +def analyze_base_graph(root: ScopeNode) -> None: + ''' + Analyze the input graph's structure and statistics based on profiling results. + NOTE: if the input graph contains operators that consumes or generates extremely + large tensors, the profiling result may be incorrect. User should check the + partition plan's analysis later. + ''' + depth2scope_nodes = collect_depth2scope_nodes(root) + + # Similar to deepspeed profiler, we list top3 modules in terms of + # params, buffers, activation mem and fw_span + show_num = 3 + def get_val(node: ScopeNode, key: str): + # pretty print the memory size in MB and span in ms for ScopeNode + val = getattr(node, key) + if 'mem' in key: + return f'{val / 1024 / 1024:.2f} MB' + elif 'span' in key: + return f'{val:.2f} ms' + else: + raise RuntimeError(f'invalid key {key}') + + def build_info(nodes: List[ScopeNode], key: str): + info = list() + sorted_nodes = sorted(nodes, key=lambda x: getattr(x, key), reverse=True) + for node in sorted_nodes[:min(show_num, len(sorted_nodes))]: + info.append((node.get_full_name(), get_val(node, key))) + return info + + visual_contents = dict() + for depth, scope_nodes in depth2scope_nodes.items(): + # ignore the root node, since it doesn't have module info + if depth == 0: + continue + visual_contents[depth] = dict() + for key in ['param_mem', 'fw_span', 'train_mem', 'buffer_mem']: + visual_contents[depth][key] = build_info(scope_nodes, key) + + ret = '-' * 25 + 'nnScaler Graph Profiling Result' + '-' * 25 + '\n\n' + for depth, contents in visual_contents.items(): + ret += f'depth {depth}\n' + for key, info in contents.items(): + ret += f' {key} - {info}\n' + return ret + + # a class to store statistics of a continuous sub-sequence # in the initial graph's topology sequence @dataclass @@ -499,7 +563,8 @@ def reconstruct_scope_tree(self): root.insert(node, module_info, calc_flops(node), fw_span, idx=i) root.pull_up(db) - _logger.info('\n' + root.__repr__()) + _logger.debug('\n' + root.__repr__()) + _logger.info('\n' + analyze_base_graph(root)) return root diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 37e60131..02a9ee73 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -1,4 +1,4 @@ -from .model_graph import ModelGraph +from .model_graph import ModelGraph, collect_depth2scope_nodes from .cube_operator import CubeOperator from .descs import * from .cost_database import CostDatabase @@ -14,6 +14,7 @@ import yaml import numpy import logging +from dataclasses import dataclass, asdict from collections import defaultdict from pathlib import Path from typing import Dict, Tuple, List, Set, Any @@ -21,10 +22,16 @@ __all__ = [ 'SPMDSolver', 'calc_optimal_spmd_plan', + 'analysis_pretty_printer', ] _logger = logging.getLogger(__name__) +_PLAN_ANALYSIS_LIST_TIME_TOP_NUM = 10 +_PLAN_ANALYSIS_LIST_PARTITIONS_TOP_NUM = 2 +_PLAN_ANALYSIS_MODULE_MAX_DEPTH = 4 +_PLAN_ANALYSIS_MODULE_TOP_NUM = 3 + @dataclass class PartitionCostDesc: @@ -55,7 +62,7 @@ class PartitionCostDesc: def __repr__(self): contents = dict() - for k, v in self.__dict__.items(): + for k, v in asdict(self).items(): if 'mem' in k: k_in_mb = k + ' (MB)' contents[k_in_mb] = v // 1024 // 1024 @@ -63,6 +70,22 @@ def __repr__(self): contents[k] = v return str(contents) +@dataclass +class ModuleMemCostDesc: + total_cost: int + resident_mem: int + activation_mem: int + opt_transient_mem: int + recompute_mem: int + transient_mem: int + + def __repr__(self): + contents = dict() + for k, v in asdict(self).items(): + k_in_mb = k + ' (MB)' + contents[k_in_mb] = v // 1024 // 1024 + return str(contents) + class SPMDSolver: @@ -366,9 +389,9 @@ def build_op_partitions(operator: CubeOperator) -> List[OpPartition]: operator.ir_cell) if replicated_ops: for signature, ops in replicated_ops.items(): - _logger.info(f'find {len(ops)} replicated {signature}') + _logger.debug(f'find {len(ops)} replicated {signature}') for op in ops: - _logger.info(f'\t{op}\n\t{op.comment}\n\n') + _logger.debug(f'\t{op}\n\t{op.comment}\n\n') if self.non_used_pcs: _logger.warning( f'find unused partition constraints {self.non_used_pcs}') @@ -724,7 +747,7 @@ def gen_min_mem_plan_greedy(self, start: int, plan.append((i, cur_mem.index(min(cur_mem)))) return plan - def calc_mem_cost(self, plan: List[Tuple[int, int]]) -> int: + def calc_mem_cost(self, plan: List[Tuple[int, int]]) -> ModuleMemCostDesc: ''' calculate the memory cost of the plan @@ -732,12 +755,9 @@ def calc_mem_cost(self, plan: List[Tuple[int, int]]) -> int: plan (List[Tuple[int, int]]): the plan to be evaluated Returns: - int: the memory cost of the plan in bytes + ModuleMemCostDesc: the memory cost of the plan in bytes ''' - def to_mb(size: int) -> int: - return size // 1024 // 1024 - mem, act_mem, opt_transient_mem, transient_mem = 0, 0, 0, [] for op_idx, p_idx in plan: desc = self.partition_info[op_idx][p_idx] @@ -750,9 +770,6 @@ def to_mb(size: int) -> int: act_mem += desc.activation_mem opt_transient_mem += desc.opt_transient_mem transient_mem.append(desc.transient_mem) - _logger.info(f'resident mem: {to_mb(mem)} MB') - _logger.info(f'activation mem: {to_mb(act_mem)} MB') - _logger.info(f'opt transient mem: {to_mb(opt_transient_mem)} MB') cost = mem - act_mem + max(act_mem, opt_transient_mem) start, end = plan[0][0], plan[-1][0] @@ -769,26 +786,23 @@ def to_mb(size: int) -> int: cur_recompute_mem_cost += p_cost_desc.activation_mem recompute_mem_cost = max(recompute_mem_cost, cur_recompute_mem_cost) - _logger.info(f'recompute mem: {to_mb(recompute_mem_cost)} MB') cost += recompute_mem_cost # A heuristic that helps to estimate the memory cost accurately. # It is hard to fully reuse large memory blocks in the cached allocator. - # - in training, use the maximum 2 transient memory - # - in inference, use the largest transient memory + # In training and inference, we use the top 2 largest inference transient + # memory cost. In training, we double the cost as a result of the backward pass. if transient_mem: transient_mem.sort() transient_mem.reverse() - if len(transient_mem) == 1 or not self.autodist_config.is_train: - cost += transient_mem[0] - _logger.info(f'transient mem: {to_mb(transient_mem[0])} MB') + if len(transient_mem) == 1: + transient_mem_cost = transient_mem[0] else: - cost += transient_mem[0] + transient_mem[1] - _logger.info( - f'transient mem: {to_mb(transient_mem[0])} MB, {to_mb(transient_mem[1])} MB' - ) - _logger.info(f'total mem cost: {to_mb(cost)} MB') - return cost + transient_mem_cost = transient_mem[0] + transient_mem[1] + if self.autodist_config.is_train: + transient_mem_cost *= 2 + cost += transient_mem_cost + return ModuleMemCostDesc(cost, mem, act_mem, opt_transient_mem, recompute_mem_cost, transient_mem_cost) def calc_inner_time_cost(self, plan: List[Tuple[int, int]]) -> float: ''' @@ -990,9 +1004,9 @@ def _solve_by_ilp(self, start: int, end: int) -> SPMDSearchOutput: prob += act_mem <= max_act_opt_transient prob += opt_transient_mem <= max_act_opt_transient if self.autodist_config.is_train: - transient_coef = 2 + transient_coef = 4 else: - transient_coef = 1 + transient_coef = 2 prob += mem - act_mem + max_act_opt_transient + transient_coef * max_transient + recompute_mem <= self.mem_bound # 4.3. constraint over e @@ -1080,7 +1094,7 @@ def get_non_zero_index(binary_vector): plans.append((i, s_val[i - start])) p_cost_desc = self.partition_info[i][s_val[i - start]] inner_time_cost += p_cost_desc.comp_time + p_cost_desc.weight_update_time - mem_cost = self.calc_mem_cost(plans) + mem_cost = self.calc_mem_cost(plans).total_cost return SPMDSearchOutput(self.partition_path2desc(plans), mem_cost / 1024 / 1024 / 1024, all_time_cost, inner_time_cost) @@ -1130,6 +1144,149 @@ def do_dp(self, intervals: List[Tuple[int, int]], ret.append(descs) return ret + def analyze_plan(self, plan: List[Tuple[int, int]]) -> Dict[str, Any]: + """ + Analyze the given plan and return the analysis results. + The analysis includes: + - Computation Related + - the total computation time + - the top-10 operators that consume the most computation time + - detailed partition plans for the top-2 operators that consume the most computation time + - Communication Related + - the total communication time + - the top-10 operators that consume the most communication time + - Memory Related + - the top-3 modules that consume the most memory in depth 1-4 + - Detailed Partition Plans for each IRDimops + + Args: + plan (List[Tuple[int, int]]): the plan to be analyzed + + Returns: + Dict[str, Any]: the analysis results + """ + ret = dict() + start, end = plan[0][0], plan[-1][0] + + # top 10 operators grouped by signature that: + # - consume the most computation time + # - consume the most communication time + sig2comp_time = dict() + sig2comm_time = dict() + op_idx2comp_time = dict() + op_idx2comm_time = dict() + dimops_split_info = list() + sig2split_info = dict() + comp_time_sum, comm_time_sum = 0, 0 + for op_idx, p_idx in plan: + desc = self.partition_info[op_idx][p_idx] + node = self.graph.operator_list[op_idx].ir_cell + sig = node.signature + if sig not in sig2comp_time: + sig2comp_time[sig] = 0 + if sig not in sig2split_info: + sig2split_info[sig] = [] + sig2comp_time[sig] += desc.comp_time + op_idx2comp_time[op_idx] = desc.comp_time + comm_cost = 0 + for k, comm_vec in enumerate(desc.comm_time): + producer = self.producers[op_idx][k] + # do not consider the communication cost between the node in the interval + # to its producer outside the interval currently + if start <= producer <= end: + producer_p_idx = plan[producer - start][1] + comm_cost += comm_vec[producer_p_idx] + op_idx2comm_time[op_idx] = comm_cost + if isinstance(node, IRDimops): + partition_repr = (repr(node), repr(node.anno), node.comment, repr(self._op_partitions[op_idx][p_idx])) + split_info = (partition_repr, desc.comp_time, comm_cost) + dimops_split_info.append(split_info) + sig2split_info[sig].append(split_info) + if comm_cost == 0: + continue + if sig not in sig2comm_time: + sig2comm_time[sig] = 0 + sig2comm_time[sig] += comm_cost + comp_time_sum += desc.comp_time + comm_time_sum += comm_cost + + sig2comp_time = sorted(sig2comp_time.items(), key=lambda x: x[1], reverse=True) + comp_sig_num = min(_PLAN_ANALYSIS_LIST_TIME_TOP_NUM, len(sig2comp_time)) + sig2comp_time = sig2comp_time[:comp_sig_num] + top_comp_time_sum = sum([x[1] for x in sig2comp_time]) + ret['comp_time_sum'] = comp_time_sum + ret[f'top_comp_time'] = sig2comp_time + ret[f'top_comp_time_sum'] = top_comp_time_sum + + # in addition list partition plans for top comp time operators + top_op_split_info = {} + for sig, _ in sig2comp_time[:min(_PLAN_ANALYSIS_LIST_PARTITIONS_TOP_NUM, len(sig2comp_time))]: + top_op_split_info[sig] = sig2split_info[sig] + ret['top_op_split_info'] = top_op_split_info + + sig2comm_time = sorted(sig2comm_time.items(), key=lambda x: x[1], reverse=True) + comm_sig_num = min(_PLAN_ANALYSIS_LIST_TIME_TOP_NUM, len(sig2comm_time)) + sig2comm_time = sig2comm_time[:comm_sig_num] + top_comm_time_sum = sum([x[1] for x in sig2comm_time]) + ret['comm_time_sum'] = comm_time_sum + ret[f'top_comm_time'] = sig2comm_time + ret[f'top_comm_time_sum'] = top_comm_time_sum + + # similar to analysis in the raw graph, we list the top-3 modules that: + # - consume the most computation time + # - consume the most communication time + # - consume the most memory + # to reduce the complexity, we only consider the modules: + # - 1 <= depth <= _PLAN_ANALYSIS_MODULE_MAX_DEPTH + # - in the interval [start, end] + # - composed of more than one operator + ret['module_analysis'] = {} + op_idx2plan_offset = {op_idx: i for i, (op_idx, _) in enumerate(plan)} + depth2scope_nodes = collect_depth2scope_nodes(self.graph.scope_tree_root) + for depth, scope_nodes in depth2scope_nodes.items(): + if depth == 0 or depth > _PLAN_ANALYSIS_MODULE_MAX_DEPTH: + continue + content = {'comp_time': [], 'comm_time': [], 'mem': []} + info = list() + for scope_node in scope_nodes: + # currently do not consider the module that is not in the interval + if scope_node.start < start or scope_node.end > end: + continue + # skip modules composed of only one operator, since they are covered + # at the operator level analysis + if scope_node.start == scope_node.end: + continue + comp_time, comm_time = 0, 0 + for op_idx in range(scope_node.start, scope_node.end + 1): + comp_time += op_idx2comp_time[op_idx] + comm_time += op_idx2comm_time[op_idx] + sub_plan_start = op_idx2plan_offset[scope_node.start] + sub_plan_end = op_idx2plan_offset[scope_node.end] + sub_plan = plan[sub_plan_start:sub_plan_end + 1] + mem_cost = self.calc_mem_cost(sub_plan) + info.append((scope_node.get_full_name(), comp_time, comm_time, mem_cost)) + # sort by comp_time + info.sort(key=lambda x: x[1], reverse=True) + for i in range(min(_PLAN_ANALYSIS_MODULE_TOP_NUM, len(info))): + name, comp_time, _, _ = info[i] + content['comp_time'].append((name, comp_time)) + # sort by comm_time + info.sort(key=lambda x: x[2], reverse=True) + for i in range(min(_PLAN_ANALYSIS_MODULE_TOP_NUM, len(info))): + name, _, comm_time, _ = info[i] + content['comm_time'].append((name, comm_time)) + # sort by mem + info.sort(key=lambda x: x[3].total_cost, reverse=True) + for i in range(min(_PLAN_ANALYSIS_MODULE_TOP_NUM, len(info))): + name, _, _, mem = info[i] + content['mem'].append((name, repr(mem))) + ret['module_analysis'][depth] = content + + # TODO: generate a visualization of the plan like torch.profiler + ret['dimops_split_info'] = dimops_split_info + + return ret + def solve(self, intervals: List[Tuple[int, int]], topk: int) -> List[SPMDSearchOutput]: ''' @@ -1176,7 +1333,46 @@ def partition_path2desc( return TensorParallelDesc(partition_descs=partition_descs, mesh_desc=self.mesh_desc, - recompute_groups=[]) + recompute_groups=[], + analysis=self.analyze_plan(plans)) + + +def analysis_pretty_printer(analysis: Dict[str, Any]) -> str: + ret = '' + ret += f'Total computation time: {1000.0 * analysis["comp_time_sum"]:.2f} ms\n' + ret += f'Top {_PLAN_ANALYSIS_LIST_TIME_TOP_NUM} of operators that consume the most computation time:\n' + for sig, time in analysis[f'top_comp_time']: + ret += f' {sig}: {1000.0 * time:.2f} ms\n' + ret += f'Top {_PLAN_ANALYSIS_LIST_TIME_TOP_NUM} of operators computation time sum: {1000.0 * analysis["top_comp_time_sum"]:.2f} ms\n' + ret += '\n' + ret += f'Top {_PLAN_ANALYSIS_LIST_PARTITIONS_TOP_NUM} operators split info:\n' + for sig, split_info in analysis[f'top_op_split_info'].items(): + ret += f' {sig}:\n' + for partition_repr, comp_time, comm_time in split_info: + node_repr, anno, comment, partition_info = partition_repr + ret += f' {node_repr}\n' + ret += f' {comment}\n' + ret += f' {anno}, {partition_info}, comp_time: {1000.0 * comp_time:.2f} ms, comm_time: {1000.0 * comm_time:.2f} ms\n\n' + ret += '\n' + ret += f'Total communication time: {1000.0 * analysis["comm_time_sum"]:.2f} ms\n' + ret += f'Top {_PLAN_ANALYSIS_LIST_TIME_TOP_NUM} operators that consume the most communication time:\n' + for sig, time in analysis[f'top_comm_time']: + ret += f' {sig}: {1000.0 * time:.2f} ms\n' + ret += f'Top {_PLAN_ANALYSIS_LIST_TIME_TOP_NUM} of operators communication time sum: {1000.0 * analysis[f"top_comm_time_sum"]:.2f} ms\n' + ret += '\n' + ret += 'Module analysis:\n' + for depth, content in analysis['module_analysis'].items(): + ret += f'Depth {depth}:\n' + ret += f' Top {_PLAN_ANALYSIS_MODULE_TOP_NUM} modules that consume the most computation time:\n' + for name, time in content['comp_time']: + ret += f' {name}: {1000.0 * time:.2f} ms\n' + ret += f' Top {_PLAN_ANALYSIS_MODULE_TOP_NUM} modules that consume the most communication time:\n' + for name, time in content['comm_time']: + ret += f' {name}: {1000.0 * time:.2f} ms\n' + ret += f' Top {_PLAN_ANALYSIS_MODULE_TOP_NUM} modules that consume the most memory:\n' + for name, mem_desc in content['mem']: + ret += f' {name}: {mem_desc}\n' + return ret def calc_optimal_spmd_plan( diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index 45948d81..37327823 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -256,7 +256,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule else: frame.set_var(node.name, ir_node) - _logger.info(f'parsing result: {ir_node}') + _logger.debug(f'parsing result: {ir_node}') return ir_nodes @staticmethod diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index bfff1905..938babca 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -151,6 +151,7 @@ def gen_torch_tensors(shape, dtype, requires_grad): gen_torch_tensors(shape, dtype, requires_grad) if isinstance(value, IRTensor) else value \ for shape, dtype, requires_grad, value in zip(shapes, dtypes, requires_grads, values) ) + total_input_size = sum(t.numel() * t.element_size() for t in tensors if torch.is_tensor(t)) require_backward = any([t.requires_grad for t in tensors if hasattr(t, 'requires_grad')]) # FIXME: reconsidering requires_grad if func.__name__ in ('type_as'): @@ -193,7 +194,7 @@ def run_step(func, tensors, kwargs, backward: bool): with torch.no_grad(): run_step(func, tensors, eval_kwargs, backward=False) mtoc = torch.cuda.max_memory_allocated() # in bytes - infer_memory = mtoc - mtic + infer_memory = mtoc - mtic + total_input_size train_mem_info = [] train_mem2in_idx = [] diff --git a/utility/prim_profiler.py b/utility/prim_profiler.py index e68f5bf9..6daef5ff 100644 --- a/utility/prim_profiler.py +++ b/utility/prim_profiler.py @@ -20,8 +20,7 @@ def main(): - base_path = get_default_profile_path() - default_path = base_path / get_node_arch() + default_path = get_default_profile_path() if not default_path.is_dir(): default_path.mkdir(parents=True) From c7b6817c6bf9365d4187497424538540d1c15ba0 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 20 May 2024 11:14:32 +0000 Subject: [PATCH 1643/1892] Merged PR 2143: support script function & type dict_keys dict_values dict_items --- nnscaler/graph/function/dimops.py | 13 ++--- nnscaler/graph/function/function.py | 15 ++++++ .../fx/concrete_trace_utils/concrete_proxy.py | 8 ++- .../concrete_trace_utils/concrete_tracer.py | 38 +++++++++++--- .../parser/fx/concrete_trace_utils/utils.py | 34 +++++++++++- nnscaler/graph/parser/fx/mapping.py | 3 ++ nnscaler/graph/parser/fx/parser.py | 36 ++++++++++--- nnscaler/graph/parser/register.py | 13 +++++ nnscaler/profiler/database.py | 3 +- tests/graph/function/__init__.py | 0 tests/graph/function/helper.py | 16 ++++++ tests/graph/function/test_dict_values.py | 30 +++++++++++ tests/graph/function/test_script_func.py | 27 ++++++++++ tests/graph/tracer/test_pytree.py | 52 +++++++++++++++++++ 14 files changed, 261 insertions(+), 27 deletions(-) create mode 100644 tests/graph/function/__init__.py create mode 100644 tests/graph/function/helper.py create mode 100644 tests/graph/function/test_dict_values.py create mode 100644 tests/graph/function/test_script_func.py diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index dea2160d..9904bcae 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -685,24 +685,24 @@ def anno(self) -> OpAnno: def transform_rules(self) -> Tuple[TransformRule]: return self._trans_rules - def ianno(self, index: int) -> Tuple[DimAnno]: + def ianno(self, index: int) -> ShapeAnno: """! Get index-th input tensor shape annotation @param index int: the input index - @return dim_annos Tuple[DimAnno]: a tuple that each element is a dimension annotation + @return dim_annos ShapeAnno: a tuple that each element is a dimension annotation """ assert index < len(self.inputs()), "index out of boudary" return tuple(self._iannos[index]) - def oanno(self, index: int) -> Tuple[DimAnno]: + def oanno(self, index: int) -> ShapeAnno: """! Get index-th output tensor shape annotation @param index int: the output index - @return dim_annos Tuple[DimAnno]: a tuple that each element is a dimension annotation + @return dim_annos ShapeAnno: a tuple that each element is a dimension annotation """ assert index < len(self.outputs()), "index out of boudary" return self._oannos[index] @@ -715,11 +715,8 @@ def infer_shape(self) -> bool: """ for oidx, otensor in enumerate(self.outputs()): shape_anno = self.oanno(oidx) - if str(shape_anno) == '?': + if shape_anno.ignore: assert isinstance(otensor, IRObject), f"expect IRObject for unknown shape, get {otensor}" - _logger.warning( - 'detect IRObject output in a IRDimops, please ensure the annotation is ' - 'correct w.r.t the partition policy.') continue shape = [] for odim in range(shape_anno.ndims): diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 794c9e42..a523a9b8 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -2631,3 +2631,18 @@ def Sigmoid(input, *, out=None, signature=None): raise ValueError("Expected 'out' to be None") annos = ['* -> *'] return IRDimops(Sigmoid, 'sigmoid', signature, annos, [input]) + + +def Dictkeys(o: Union[Dict, IRObject], signature=None): + assert isinstance(o, dict) or isinstance(o.value, dict), f'the input should be a dict or an IRObject with dict value, but get {o}' + return IRPyFunc(signature, inputs=[o], outputs=[IRObject(name='dictkeys', value=o.value.keys(), is_constant=o.is_constant)]) + + +def DictValues(o: Union[Dict, IRObject], signature=None): + assert isinstance(o, dict) or isinstance(o.value, dict), f'the input should be a dict or an IRObject with dict value, but get {o}' + return IRPyFunc(signature, inputs=[o], outputs=[IRObject(name='dictvalues', value=o.value.values(), is_constant=o.is_constant)]) + + +def DictItems(o: Union[Dict, IRObject], signature=None): + assert isinstance(o, dict) or isinstance(o.value, dict), f'the input should be a dict or an IRObject with dict value, but get {o}' + return IRPyFunc(signature, inputs=[o], outputs=[IRObject(name='dictitems', value=o.value.items(), is_constant=o.is_constant)]) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 6926bc95..58ed679f 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -235,8 +235,14 @@ def keys(self): return self.tracer.create_proxy('call_method', 'keys', (self,), {}) @compatibility(is_backward_compatible=True) + @property def values(self): - return self.tracer.create_proxy('call_method', 'values', (self,), {}) + if callable(self.value.values): + def _values(): + return self.tracer.create_proxy('call_method', 'values', (self,), {}) + return _values + else: + return ConcreteAttrProxy(self, 'values') @compatibility(is_backward_compatible=True) def items(self): diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 45884a31..b1efc0a3 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -19,7 +19,7 @@ from contextlib import contextmanager import torch -from torch._C import ScriptObject +from torch._C import ScriptObject, ScriptFunction from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict from torch.utils._pytree import tree_flatten, tree_unflatten @@ -132,12 +132,12 @@ def __exit__(self, *args): side_effectful_inplace_ops, ) from .utils import ( - FrameRecord, ExtraSEFPatcher, EmptyResult, extract_results_metadata, flatten_trees_with_func, flatten_trees_with_func_and_spec, + get_common_spec, map_trees_with_func, get_frame_record, ) @@ -162,8 +162,8 @@ class Location: class LeafFnWrapInfo: """ extra_locs: The place the function is imported. - is_force_trace: If set to false, the function will only be traced if input relates to concrete_args. - Such as 'torch.rand', we should trace it even if it doesn't relate to concrete_args. + is_force_trace: If set to false, the function will only be traced if inputs include proxy. + Such as 'torch.rand', we should trace it even if it doesn't have proxy as input, so it should be force traced. replace_fn: If not `None`, we will use it to replace the original function in traced code. Such as ModuleList.__getitem__, we can use operator.getitem to replace it. """ @@ -1035,6 +1035,24 @@ def torch_assert_wrapper(condition, message): if func.__self__ not in self.agfunc_dict: self.agfunc_dict[func.__self__] = _create_wrapped_leaf_func(self, func, func) wrapped = self.agfunc_dict[func.__self__] + elif isinstance(func, ScriptFunction): + # if it is a script function, + # here will wrap the origin function location and forward the script function to the origin one. + # _torchdynamo_inline is introduced in pytorch 2.0, it is the original function of the script function. + inner_func = func._torchdynamo_inline + # some `func.__module__` may have additional `_` compare with its import path in user code, + # for example, `operator.add.__module__` is `_operator` and `_operator` is a built-in module and we don't want to touch it, + # we assume user won't import function from module named with prefix `_`, + # here we only wrap the function under no prefix `_` module, i.e. functions under `operator`. + if inner_func.__module__.startswith('_') and inner_func.__module__ != '__main__': + path = sys.modules.get(inner_func.__module__[1:], sys.modules[inner_func.__module__]) + else: + path = sys.modules[inner_func.__module__] + locations = (*locations, Location(path, inner_func.__name__)) + if wrap_info.is_force_trace: + wrapped = _create_wrapped_leaf_func(self, func, inner_func, (self,)) + else: + wrapped = _create_wrapped_leaf_func(self, func, inner_func) else: if func.__qualname__.startswith('_TensorBase'): locations = (*locations, Location(torch.Tensor, func.__name__)) @@ -1057,7 +1075,7 @@ def torch_assert_wrapper(condition, message): # if func.__module__ is not None: if func.__module__.startswith('_') and func.__module__ != '__main__': - path = sys.modules[func.__module__[1:]] + path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) else: path = sys.modules[func.__module__] path = getattr(path, func.__qualname__.split('.')[0]) @@ -1073,7 +1091,7 @@ def torch_assert_wrapper(condition, message): # if func.__module__ is not None: if func.__module__.startswith('_') and func.__module__ != '__main__': - path = sys.modules[func.__module__[1:]] + path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) else: path = sys.modules[func.__module__] locations = (*locations, Location(path, func.__name__)) @@ -1099,7 +1117,7 @@ def torch_assert_wrapper(condition, message): } for clz, wrap_info in self.autowrap_leaf_class.items(): if clz.__module__.startswith('_') and clz.__module__ != '__main__': - path = sys.modules[clz.__module__[1:]] + path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) else: path = sys.modules[clz.__module__] if wrap_info.is_iterable: @@ -1384,7 +1402,11 @@ def update_tree_proxy_value(dst_pytree, src_pytree): copy the value from src_pytree to dst_pytree with the dst_pytree spec, if the leaf is proxy, only replace the proxy.value, not replace the proxy. """ - _, spec = tree_flatten(dst_pytree) + # consider about this case: + # dst_pytree: {'a': [1, 2, 3]} + # src_pytree: {'a': [1, 2, 3, 4]} + # then the public spec is {'a': *}, we don't want to flatten the list here. + spec = get_common_spec(tree_flatten(dst_pytree)[1], tree_flatten(src_pytree)[1]) def update_proxy_value(a, b): if isinstance(a, ep.ConcreteProxy): diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py index dec5178e..c1ba419b 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py @@ -17,6 +17,11 @@ from . import concrete_proxy as ep +DICT_KEYS_TYPE = type({}.keys()) +DICT_VALUES_TYPE= type({}.values()) +DICT_ITEMS_TYPE= type({}.items()) + + # These need to run in global scope to handle nested calls correctly _orig_module_call: Callable = torch.nn.Module.__call__ _orig_module_getattr: Callable = torch.nn.Module.__getattr__ @@ -97,7 +102,24 @@ def _get_node_type(pytree: Any) -> Any: torch_pytree._get_node_type = _get_node_type -def flatten_trees_with_func(fn, pytrees): +def get_common_spec(dst_spec: TreeSpec, src_sepc: TreeSpec) -> TreeSpec: + """ + Return the common part of two treespec. + For example: + dst_spec is {'a': [*,], 'b': [*, *]} + src_sepc is {'a': [*,], 'b': [*, *, *]} + common spec is {'a': [*,], 'b': *} + """ + if isinstance(dst_spec, LeafSpec) or isinstance(src_sepc, LeafSpec): + return LeafSpec() + if dst_spec.type == src_sepc.type and dst_spec.context == src_sepc.context: + if len(dst_spec.children_specs) == len(src_sepc.children_specs): + children_specs = [get_common_spec(dst, src) for dst, src in zip(dst_spec.children_specs, src_sepc.children_specs)] + return TreeSpec(type=dst_spec.type, context=dst_spec.context, children_specs=children_specs) + return LeafSpec() + + +def flatten_trees_with_func(fn, pytrees) -> Tuple[List[Any], TreeSpec]: """ Each pytree in pytrees should have the same structure. @@ -272,7 +294,15 @@ def extract_tensor_metadata(obj: Any): def extract_results_metadata(results: Any, node: Node): if results is not EmptyResult: - meta = map_aggregate(results, extract_tensor_metadata) + res = tuple(results) if isinstance(results, (DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE)) else results + meta = map_aggregate(res, extract_tensor_metadata) + # we should get the meta info of the inner element of these type obj + if isinstance(results, DICT_KEYS_TYPE): + meta = {i: m for i, m in enumerate(meta)}.keys() + if isinstance(results, DICT_VALUES_TYPE): + meta = {i: m for i, m in enumerate(meta)}.values() + if isinstance(results, DICT_ITEMS_TYPE): + meta = {i: m for i, m in meta}.items() node.meta['tensor_meta'] = meta node.meta['type'] = type(results) diff --git a/nnscaler/graph/parser/fx/mapping.py b/nnscaler/graph/parser/fx/mapping.py index a8efb69b..b416d0b0 100644 --- a/nnscaler/graph/parser/fx/mapping.py +++ b/nnscaler/graph/parser/fx/mapping.py @@ -145,6 +145,9 @@ def exist(signature: str) -> bool: 'builtins.list': function.MakeList, 'builtins.slice': function.MakeSlice, 'builtins.len': function.Len, + 'builtins.dict.keys': function.Dictkeys, + 'builtins.dict.values': function.DictValues, + 'builtins.dict.items': function.DictItems, # # torch nn functional '_operator.matmul': function.Matmul, diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index 37327823..cc30dfcd 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -14,6 +14,7 @@ import torch.fx from .concrete_trace_utils import TensorMetadata +from .concrete_trace_utils.utils import DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE _logger = logging.getLogger(__name__) @@ -134,6 +135,10 @@ def meta2var(meta: Any) -> Any: if not all(isinstance(key, str) for key in meta.keys()): raise TypeError(f"only support dict type with str key, but got {meta.keys()}.\n{node}") return {key : meta2var(value) for key, value in meta.items()} + if isinstance(meta, DICT_VALUES_TYPE): + return {key : meta2var(value) for key, value in enumerate(meta)}.values() + if isinstance(meta, DICT_ITEMS_TYPE): + return {key : meta2var(value) for key, value in meta}.items() # TODO: data type check, with cases like {'a': 1.2, 'b': torch.Tensor} return IRObject(name=node.name, value=meta, is_constant=is_constant) @@ -169,6 +174,11 @@ def parse_complex(val: Any, frame: Frame) -> Any: return list(FxModuleParser.parse_complex(t, frame) for t in val) if isinstance(val, dict): return {key: FxModuleParser.parse_complex(val, frame) for key, val in val.items()} + # because fx node cannot be a dict key, so skip DICT_KEYS_TYPE here + if isinstance(val, DICT_VALUES_TYPE): + return {i: FxModuleParser.parse_complex(x, frame) for i, x in enumerate(val)}.values() + if isinstance(val, DICT_ITEMS_TYPE): + return {i: FxModuleParser.parse_complex(x, frame) for i, x in val}.items() if isinstance(val, torch.fx.Node): return frame.get_var(val.name) return val @@ -235,7 +245,9 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule elif not isinstance(ir_node.output(0), IRTensor) and ir_node.output(0).value is not None: if dynamic_shape or \ any_ir_object_satisfy(ir_node.output(0), lambda a: not a.is_constant) or \ - any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, IRTensor)): + any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, IRTensor)) or \ + any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, (DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE))): + # type of return values of dict.keys, dict.values and dict.items can not be repr, so we must take it as a node frame.set_var(node.name, ir_node.output(0)) ir_node.output(0).name = node.name else: @@ -327,17 +339,27 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> """ if not isinstance(node_target, str): raise ValueError(f'node_target must be a string, but got {type(node_target)} with value {node_target}') - for module, module_name in [(torch, 'torch'), (torch.Tensor, 'torch.Tensor')]: - lib_func = getattr(module, node_target, None) - if lib_func is not None and callable(lib_func): - return f'{module_name}.{node_target}' + # NOTE(nishang): seems that we don't need to guess the method sig? + # for module, module_name in [(torch, 'torch'), (torch.Tensor, 'torch.Tensor')]: + # lib_func = getattr(module, node_target, None) + # if lib_func is not None and callable(lib_func): + # return f'{module_name}.{node_target}' assert len(node.args) > 0, 'Expect an object as the first argument of call_method' # example node.args[0].meta is {'type': } in_type = node.args[0].meta['type'] assert node_target in in_type().__dir__(), f'node_target = {node_target}, in_type().__dir__() = {in_type().__dir__()}' - sig = f'{in_type.__name__}.{node_target}' - return sig + # TODO: for the history issue (please see the comment out lines after NOTE), + # we should forward the torch.Tensor.xxx to torch.xxx if xxx existed under torch, + # because many torch.Tensor functions are not included in the mapping.py, + # we should add torch.Tensor.xxx in mapping.py + if issubclass(in_type, torch.Tensor) and getattr(torch, node_target, None) and callable(getattr(torch, node_target)): + return f'torch.{node.target}' + # here forward torch.nn.Parameter.xxx to torch.Tensor.xxx + elif issubclass(in_type, torch.Tensor) and getattr(torch.Tensor, node_target, None) and callable(getattr(torch.Tensor, node_target)): + return f'torch.Tensor.{node.target}' + else: + return f'{in_type.__module__}.{in_type.__name__}.{node_target}' @staticmethod def _find_module_of_method(orig_method: Callable[..., Any]) -> str: diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index 3ff700e5..2e479314 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -7,6 +7,8 @@ import inspect import logging +from torch import ScriptFunction + from nnscaler.graph.function.dimops import IRDimops, OpAnno from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply from nnscaler.ir.operator import IRTensor @@ -136,6 +138,9 @@ def decorator(fn: Callable): def get_import_path(fn: Callable) -> str: if is_autograd_apply(fn): import_path = inspect.getmodule(fn.__self__).__name__ + elif isinstance(fn, ScriptFunction): + # fn._torchdynamo_inline is the original function + import_path = inspect.getmodule(fn._torchdynamo_inline).__name__ else: import_path = inspect.getmodule(fn).__name__ return import_path @@ -151,6 +156,12 @@ def get_import_path(fn: Callable) -> str: op_name = name if name is not None else fn.__self__.__name__ args = inspect.signature(fn.__self__.forward) arg_names = list(args.parameters.keys())[1:] + elif isinstance(fn, ScriptFunction): + # fn._torchdynamo_inline is the original function + fsig = f'{import_path}.{fn._torchdynamo_inline.__name__}' + op_name = name if name is not None else fn.name + args = inspect.signature(fn._torchdynamo_inline) + arg_names = list(args.parameters.keys()) else: fsig = f'{import_path}.{fn.__name__}' op_name = name if name is not None else fn.__name__ @@ -162,6 +173,8 @@ def get_source_code(fn: Callable) -> str: if is_autograd_apply(fn): code = inspect.getsource(fn.__self__) code = code[code.index(f'class {fn.__self__.__name__}'):] + elif isinstance(fn, ScriptFunction): + raise NotImplementedError('Do not support get source code for ScriptFunction.') else: code = inspect.getsource(fn) code = code[code.index('def'):] diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 938babca..7cf3cd37 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -154,7 +154,8 @@ def gen_torch_tensors(shape, dtype, requires_grad): total_input_size = sum(t.numel() * t.element_size() for t in tensors if torch.is_tensor(t)) require_backward = any([t.requires_grad for t in tensors if hasattr(t, 'requires_grad')]) # FIXME: reconsidering requires_grad - if func.__name__ in ('type_as'): + # the __name__ of function with type of torch.ScriptFunction is None + if hasattr(func, '__name__') and func.__name__ in ('type_as'): require_backward = False # repalce kwargs starting with 'self.xxx' train_kwargs, eval_kwargs = {}, {} diff --git a/tests/graph/function/__init__.py b/tests/graph/function/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/graph/function/helper.py b/tests/graph/function/helper.py new file mode 100644 index 00000000..1bc5257f --- /dev/null +++ b/tests/graph/function/helper.py @@ -0,0 +1,16 @@ +import torch +from nnscaler import register_op + + +@torch.jit.script +def cus_add(a, b): + return a + b + +register_op('*, * -> *')(cus_add) + + +@torch.jit.script +def cus_sub(a, b): + return a - b + +register_op('*, * -> *')(cus_sub) diff --git a/tests/graph/function/test_dict_values.py b/tests/graph/function/test_dict_values.py new file mode 100644 index 00000000..79a2cceb --- /dev/null +++ b/tests/graph/function/test_dict_values.py @@ -0,0 +1,30 @@ +import tempfile +import torch +from nnscaler.parallel import parallelize, ComputeConfig + +from ...utils import replace_all_device_with + + +class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + k = list(x.keys())[0] + v = x[k] + y = list(x.values())[0] + z = list(x.items())[0][1] + return torch.sum(v + y + z) + + +@replace_all_device_with('cpu') +def test_script_func(): + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + Model(), + {'x': {'a': torch.rand(10)}}, + 'tp', + ComputeConfig(2, 2), + cube_savedir=tempdir, + load_module=False + ) diff --git a/tests/graph/function/test_script_func.py b/tests/graph/function/test_script_func.py new file mode 100644 index 00000000..b2e1460f --- /dev/null +++ b/tests/graph/function/test_script_func.py @@ -0,0 +1,27 @@ +import tempfile +import torch +from nnscaler.parallel import parallelize, ComputeConfig + +from .helper import cus_add, cus_sub +from ...utils import replace_all_device_with + + +class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + return cus_add(a, b) + cus_sub(a, b) + + +@replace_all_device_with('cpu') +def test_script_func(): + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + Model(), + {'a': torch.rand(10), 'b': torch.rand(10)}, + 'tp', + ComputeConfig(2, 2), + cube_savedir=tempdir, + load_module=False + ) diff --git a/tests/graph/tracer/test_pytree.py b/tests/graph/tracer/test_pytree.py index 043cb322..24d3035f 100644 --- a/tests/graph/tracer/test_pytree.py +++ b/tests/graph/tracer/test_pytree.py @@ -2,6 +2,13 @@ from torch.utils._pytree import tree_flatten +from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_proxy import ( + ConcreteProxy, + Node, +) +from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import ( + update_tree_proxy_value +) from nnscaler.graph.parser.fx.concrete_trace_utils.utils import ( flatten_tree_with_spec, flatten_trees_with_func, @@ -57,3 +64,48 @@ def test_map_trees_with_func(): pytree_2 = [5, (6, {3: 5})] assert map_trees_with_func(lambda a, b: a + b, [pytree_1, pytree_2]) == [6, (8, {3: 9})] + + +def test_update_tree_proxy_value(): + class DummyNode: + def __init__(self, name): + self.name = name + self.graph = None + + pytree_1 = ConcreteProxy(node=DummyNode('test_node'), value={'a': {'b': [1, 2]}}, tracer=None) + pytree_2 = {'a': {'b': [1, 3]}} + new_pytree = update_tree_proxy_value(pytree_1, pytree_2) + assert str(new_pytree) == "ConcreteProxy(test_node, {'a': {'b': [1, 3]}})" + + pytree_1 = {'a': ConcreteProxy(node=DummyNode('test_node'), value={'b': [1, 2]}, tracer=None)} + pytree_2 = {'a': {'b': [1, 3]}} + new_pytree = update_tree_proxy_value(pytree_1, pytree_2) + assert str(new_pytree) == "{'a': ConcreteProxy(test_node, {'b': [1, 3]})}" + + pytree_1 = ConcreteProxy( + node=DummyNode('t1'), + value={'a': ConcreteProxy( + node=DummyNode('t2'), + value={'b': ConcreteProxy( + node=DummyNode('t3'), + value=[1, ConcreteProxy( + node=DummyNode('t4'), + value=2, + tracer=None + )], + tracer=None) + }, + tracer=None) + }, + tracer=None + ) + pytree_2 = {'a': {'b': [1, 3]}} + new_pytree = update_tree_proxy_value(pytree_1, pytree_2) + assert str(new_pytree) == "ConcreteProxy(t1, {'a': ConcreteProxy(t2, {'b': ConcreteProxy(t3, [1, ConcreteProxy(t4, 3)])})})" + + pytree_1 = {'a': ConcreteProxy(node=DummyNode('test_node'), value={'b': [1, 2]}, tracer=None)} + pytree_2 = {'b': {'a': [1, 3]}} + new_pytree = update_tree_proxy_value(pytree_1, pytree_2) + # because the spec of pytree_1 - {'a': {'b': *}} - and pytree_2 - {'b': {'a': *}} - is completely differet, + # the result is directly pytree_2 + assert str(new_pytree) == "{'b': {'a': [1, 3]}}" From 0ddc7e57e6e116b2a394984b10253e8cb45fc8b6 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 20 May 2024 12:11:05 +0000 Subject: [PATCH 1644/1892] Merged PR 2148: Set input tensor's requires_grad to false when not diffirentiable --- nnscaler/parallel.py | 8 +++-- tests/autodist/pas/all_replicated_pp.json | 30 ++++++++++++++--- .../pas/replicated_and_partition.json | 30 ++++++++++++++--- tests/parallel_module/test_embedding.py | 33 +++++++++++++++++++ 4 files changed, 89 insertions(+), 12 deletions(-) create mode 100644 tests/parallel_module/test_embedding.py diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 9b5c0384..27cb64aa 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -313,14 +313,18 @@ def to_ir_input(sample, name): # generate backward communications in adapter. However, as long as # the data doesn't require gradient in real runtime, the backward # communication will not be triggered. + # PyTorch only supports floating point and complex tensors for autograd. + # To align with PyTorch, we set requires_grad to False for other types. + requires_grad = sample.is_floating_point() or sample.is_complex() tensor = IRFullTensor( shape=sample.size(), name=name, - requires_grad=True, + requires_grad=requires_grad, dtype=sample.dtype ).tosub() tensor._value = sample - tensor.grad = tensor.parent.grad.tosub() + if requires_grad: + tensor.grad = tensor.parent.grad.tosub() return tensor return IRObject(name, value=sample, is_constant=False) diff --git a/tests/autodist/pas/all_replicated_pp.json b/tests/autodist/pas/all_replicated_pp.json index 285edb7c..8b72449d 100644 --- a/tests/autodist/pas/all_replicated_pp.json +++ b/tests/autodist/pas/all_replicated_pp.json @@ -32,7 +32,17 @@ "mesh_desc": [ 1, 2 - ] + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } }, { "partition_descs": [ @@ -65,7 +75,17 @@ "mesh_desc": [ 1, 2 - ] + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } } ], "recompute_groups": [], @@ -76,12 +96,12 @@ }, "e2e_time": 0.0, "stage_mems": [ - 0.0 + 0.0, 0.0 ], "stage_all_times": [ - 0.0 + 0.0, 0.0 ], "stage_comp_times": [ - 0.0 + 0.0, 0.0 ] } diff --git a/tests/autodist/pas/replicated_and_partition.json b/tests/autodist/pas/replicated_and_partition.json index 6a133938..a47261b1 100644 --- a/tests/autodist/pas/replicated_and_partition.json +++ b/tests/autodist/pas/replicated_and_partition.json @@ -32,7 +32,17 @@ "mesh_desc": [ 1, 2 - ] + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } }, { "partition_descs": [ @@ -65,7 +75,17 @@ "mesh_desc": [ 1, 2 - ] + ], + "analysis": { + "comp_time_sum": 0.0, + "top_comp_time_sum": 0.0, + "top_comp_time": [], + "top_op_split_info": {}, + "comm_time_sum": 0.0, + "top_comm_time_sum": 0.0, + "top_comm_time": [], + "module_analysis": {} + } } ], "recompute_groups": [], @@ -76,12 +96,12 @@ }, "e2e_time": 0.0, "stage_mems": [ - 0.0 + 0.0, 0.0 ], "stage_all_times": [ - 0.0 + 0.0, 0.0 ], "stage_comp_times": [ - 0.0 + 0.0, 0.0 ] } diff --git a/tests/parallel_module/test_embedding.py b/tests/parallel_module/test_embedding.py new file mode 100644 index 00000000..6192ddb9 --- /dev/null +++ b/tests/parallel_module/test_embedding.py @@ -0,0 +1,33 @@ +import torch +import tempfile +import pytest +from nnscaler.parallel import _gen_graph, ComputeConfig +from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed = torch.nn.Embedding(10, 20) + + def forward(self, x): + return self.embed(x).sum() + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_requires_grad(): + model = Model() + model.train() + + dummy_input = {'x': torch.randint(0, 10, (10, 10))} + + with tempfile.TemporaryDirectory() as tempdir: + + graph, _ = _gen_graph( + model, + dummy_input, + outdir=tempdir, + dynamic_shape=False, + end2end_mode=True, + ) + embed_op = graph.nodes()[1] + assert embed_op.inputs()[0].requires_grad == False + assert embed_op.inputs()[1].requires_grad == True From 75266740fd1fdf4d47c202f23428b5999e039fdc Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 27 May 2024 10:21:33 +0000 Subject: [PATCH 1645/1892] Merged PR 2146: add & refine functions support add: ``` torch.any torch.isnan torch.isinf torch.svd torch.diag torch.randn torch.rand_like torch.randn_like torch.nn.functional.l1_loss ``` refine: ``` setitem torch.bitwise_or torch.gather torch.nn.functional.nll_loss ``` --- nnscaler/graph/function/function.py | 365 ++++++++++++++++++++----- nnscaler/graph/parser/fx/mapping.py | 11 + nnscaler/runtime/function/function.py | 49 +++- tests/graph/function/test_functions.py | 76 ++++- 4 files changed, 415 insertions(+), 86 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index a523a9b8..c66fe961 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -255,80 +255,133 @@ def Linspace(start, end, steps, *, out=None, dtype=None, return CubeLinspace(start, end, steps, dtype, requires_grad=requires_grad) +def creation_function_args_check(op_name, *, generator=None, dtype=None, layout=None, device=None, memory_format=None): + if generator is not None: + raise ValueError(f"not support non-default generator for {op_name}") + if dtype is not None and not isinstance(dtype, torch.dtype): + raise ValueError(f"only supports torch.dtype for {op_name} but got {dtype}") + if layout not in (None, torch.strided): + raise ValueError(f"not support non-default layout for {op_name}") + if memory_format is not None: + raise ValueError(f"not support non-default memory_format for {op_name}") + if device is not None: + _logger.warning(f"not support manual device in {op_name}, the device will be ignored") + + +def creation_function_size_check(op_name, size, *arg_size) -> Tuple[Union[int, IRObject]]: + size_val = _unwrap_value(size) + if isinstance(size_val, int): + size = (size, *arg_size) + elif isinstance(size_val, (tuple, list)): + if len(arg_size) > 0: + raise ValueError(f"get illegal input size={size}, arg_size={arg_size} in {op_name}") + # convert scalar to shape (1,) tensor, nnscaler don't support empty shape [] now. + if len(size_val) == 0: + _logger.warn(f"detect tensor creation function {op_name} create a scalar, force it to create a shape [1] tensor instead") + size = (1,) + else: + raise ValueError(f"get unknown input type size={size} in {op_name}") + return size + + def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): - # note: device is ignored - assert layout in (None, torch.strided) and memory_format is None, f"Not support for non-default memory_format and layout" + """ + torch.empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, + memory_format=torch.contiguous_format) → Tensor + """ dtype = dtype if dtype is not None else torch.get_default_dtype() - assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + creation_function_args_check('torch.empty', dtype=dtype, layout=layout, device=device, memory_format=memory_format) + + # using nnscaler runtime function is because we need set device on the correct device during runtime signature = 'nnscaler.runtime.function.empty' - size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) - size: Tuple[Union[int, IRObject]] = size + arg_size + size = creation_function_size_check('torch.empty', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} - anno, rules = _get_creator_anno_rules( - tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) + anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) return IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) def Zeros(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): - # note: device is ignored - assert layout in (None, torch.strided), f"Not support for non-strided layout, get {layout}" + """ + torch.zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor + """ dtype = dtype if dtype is not None else torch.get_default_dtype() - assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + creation_function_args_check('torch.zeros', dtype=dtype, layout=layout, device=device) + + # using nnscaler runtime function is because we need set device on the correct device during runtime signature = 'nnscaler.runtime.function.zeros' - size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) - size: Tuple[Union[int, IRObject]] = size + arg_size + size = creation_function_size_check('torch.zeros', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} - anno, rules = _get_creator_anno_rules( - tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) + anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) return IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) def Ones(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): - # note: device is ignored - assert layout in (None, torch.strided), f"Not support for non-strided layout, get {layout}" + """ + torch.ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor + """ dtype = dtype if dtype is not None else torch.get_default_dtype() - assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + creation_function_args_check('torch.ones', dtype=dtype, layout=layout, device=device) + + # using nnscaler runtime function is because we need set device on the correct device during runtime signature = 'nnscaler.runtime.function.ones' - size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) - size: Tuple[Union[int, IRObject]] = size + arg_size + size = creation_function_size_check('torch.ones', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} - anno, rules = _get_creator_anno_rules( - tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) + anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) return IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): - # note: device is ignored - assert layout in (None, torch.strided) and memory_format is None, f"Not support for non-default memory_format and layout" + """ + torch.rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, + requires_grad=False, pin_memory=False) → Tensor + """ dtype = dtype if dtype is not None else torch.get_default_dtype() - assert isinstance(dtype, torch.dtype), f"only supports torch.dtype but got {dtype}" + creation_function_args_check('torch.rand', dtype=dtype, layout=layout, device=device, memory_format=memory_format) + + # using nnscaler runtime function is because we need set device on the correct device during runtime signature = 'nnscaler.runtime.function.rand' - size = (size,) if isinstance(size, (int, IRObject)) else tuple(size) - size: Tuple[Union[int, IRObject]] = size + arg_size + size = creation_function_size_check('torch.rand', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} - anno, rules = _get_creator_anno_rules( - tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) + anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) return IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) +def Randn(size, *arg_size, generator=None, out=None, dtype=None, layout=None, device=None, requires_grad=False, + pin_memory=False, memory_format=None, signature=None): + """ + torch.randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, + requires_grad=False, pin_memory=False) → Tensor + """ + dtype = dtype if dtype is not None else torch.get_default_dtype() + creation_function_args_check('torch.randn', generator=generator, dtype=dtype, layout=layout, device=device, memory_format=memory_format) + + # using nnscaler runtime function is because we need set device on the correct device during runtime + signature = 'nnscaler.runtime.function.randn' + size = creation_function_size_check('torch.randn', size, *arg_size) + kwargs = {'size': size, 'requires_grad': requires_grad, + 'dtype': dtype, 'pin_memory': pin_memory} + anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) + return IRDimops(Randn, 'randn', signature, [anno], [], rules, **kwargs) + + def Full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, requires_grad=False, signature=None): """ torch.full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) """ - assert layout in (None, torch.strided), f"Not support for non-default layout" dtype = dtype if dtype is not None else torch.get_default_dtype() + creation_function_args_check('torch.full', dtype=dtype, layout=layout, device=device) + + # using nnscaler runtime function is because we need set device on the correct device during runtime signature = 'nnscaler.runtime.function.full' - # cube treat scalar as size (1,) tensor now, scalar support will in another pr if necessary - size = tuple(size) if size else (1,) - anno, rules = _get_creator_anno_rules( - tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), True) + size = creation_function_size_check('torch.full', size) + anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) return IRDimops(Full, 'full', signature, [anno], [], rules, size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad) @@ -338,6 +391,9 @@ def NewTensor(data, *, dtype=None, device=None, """ torch.tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) """ + creation_function_args_check('torch.ones', device=device) + + # using nnscaler runtime function is because we need set device on the correct device during runtime signature = 'nnscaler.runtime.function.tensor' val = data @@ -512,7 +568,8 @@ def BitwiseOr(input, other, *, out=None, signature=None): if (not isinstance(input, IRObject)) and (not isinstance(other, IRObject)): return input | other assert isinstance(input, IRTensor) and isinstance(other, IRTensor) - annos = ['*, * -> *'] + lshape, rshape, oshape = _handle_broadcast(input, other) + annos = [OpAnno.create_op_str([lshape, rshape], [oshape])] return IRDimops(BitwiseOr, 'bitwise_or', signature, annos, [input, other]) @@ -525,6 +582,20 @@ def BitwiseNot(input, *, out=None, signature=None): return IRDimops(BitwiseNot, 'bitwise_not', signature, annos, [input]) +def IsNan(input, *, signature=None): + """ + torch.isnan(input) → Tensor + """ + return IRDimops(IsNan, 'isnan', signature, ['* -> *'], [input]) + + +def IsInf(input, *, signature=None): + """ + torch.isinf(input) → Tensor + """ + return IRDimops(IsInf, 'isinf', signature, ['* -> *'], [input]) + + # TODO: this function should rewrite with pytree def _unwrap_value(obj: Union[IRObject, Any]): if isinstance(obj, IRObject): @@ -1033,6 +1104,27 @@ def Sum(input, dim=None, keepdim=False, *, dtype=None, signature = None): return IRDimops(Sum, 'sum', signature, [anno], [input], dim=dim, keepdim=keepdim) +def TorchAny(input, dim=None, keepdim=False, *, out=None, signature = None): + """ + torch.any(input) -> Tensor + torch.any(input, dim, keepdim=False, *, out=None) -> Tensor + """ + einput = ShapeAnno.create_shape_str(input.shape, '^') + dim_value = _unwrap_value(dim) + if dim_value is None: + anno = OpAnno.create_op_str([einput], [['1']]) + return IRDimops(TorchAny, 'any', signature, [anno], [input]) + else: + eoutput = copy.copy(einput) + keepdim_value = _unwrap_value(keepdim) + if keepdim_value: + eoutput[dim] = '1' + else: + eoutput.pop(dim) + anno = OpAnno.create_op_str([einput], [eoutput]) + return IRDimops(TorchAny, 'any', signature, [anno], [input], dim=dim, keepdim=keepdim) + + def Mean(input, dim=None, keepdim=False, *, dtype=None, signature = None): """ torch.mean(input, *, dtype=None) @@ -2134,10 +2226,24 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: return obj[index] -def SetItem(__a: Any, __b: Any, __c: Any, signature = None) -> Union[Any, IRPyFunc]: - """_operator.setitem(__a, __b, __c) / nnscaler.runtime.function.setitem(__a, __b, __c)""" +def SetItem(__a: Any, __b: Any, __c: Any, *additonal, signature = None) -> Union[Any, IRPyFunc]: + """ + _operator.setitem(__a, __b, __c) / nnscaler.runtime.function.setitem(__a, *__bc) + + If __a is a IRTensor and __b is a tuple, __b will be flatten to ensure we can give each element an annotation, + and the returned value is a IRDimops. + If __a is a IRObject, the returned value is a IRPyFunc. + + Note that in IRDimops.new, __c might not the original __c of the setitem during parse, it may be one of the elements of the flatten __b, + in this case, original __c is the last element in additonal, original __b is (__b, __c, *additonal[:-1]). + """ signature = 'nnscaler.runtime.function.setitem' - obj, index, val = __a, __b, __c + # additional is used to receive additional parameters due to __b flatten + # unflatten __b here if additional is not empty + if len(additonal) > 0: + obj, index, val = __a, (__b, __c, *additonal[:-1]), additonal[-1] + else: + obj, index, val = __a, __b, __c if isinstance(obj, IRTensor): # TODO: move to some function like FullSlice when ready # TODO: give a IRTensor as return value or return a IRDimops @@ -2149,13 +2255,15 @@ def SetItem(__a: Any, __b: Any, __c: Any, signature = None) -> Union[Any, IRPyFu edim_ins = [edim_obj] # index annotation - if isinstance(index, IRTensor): - edim_index = ShapeAnno.create_shape_str(index.shape, '^', iterator=gener) - edim_ins.append(edim_index) - elif isinstance(index, IRObject) and any_ir_object_satisfy(index, lambda a: isinstance(a, IRTensor)): - raise RuntimeError(f"setitem did not support slicers include tensor now, got {index}") - else: - edim_ins.append(['?']) + idxes = index if isinstance(index, tuple) else (index,) + for idx in idxes: + if isinstance(idx, IRTensor): + edim_index = ShapeAnno.create_shape_str(idx.shape, '^', iterator=gener) + edim_ins.append(edim_index) + elif isinstance(idx, IRObject) and any_ir_object_satisfy(idx, lambda a: isinstance(a, IRTensor)): + raise RuntimeError(f"setitem did not support slicers include tensor now, got {idx}") + else: + edim_ins.append(['?']) # value annotation if isinstance(val, IRTensor): @@ -2165,7 +2273,8 @@ def SetItem(__a: Any, __b: Any, __c: Any, signature = None) -> Union[Any, IRPyFu edim_ins.append(['?']) anno = OpAnno.create_op_str(edim_ins, [edim_out]) - return IRDimops(SetItem, 'setitem', signature, [anno], [obj, index, val]) + # because we cannot annotate the tensor inside tuple/dict, so here we flatten the idxes. + return IRDimops(SetItem, 'setitem', signature, [anno], [obj, *idxes, val]) is_constant = not ir_object_contains_dynamic(index) index = _unwrap_value(index) @@ -2233,12 +2342,30 @@ def NLLLoss(input, target, weight=None, size_average=None, torch.nn.functional.nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean') """ - assert weight is None - annos = [ - 'C^, N -> 1', - 'N+ C, N+ -> 1', - 'N+ C *, N+ * -> 1' - ] + if weight is not None: + raise NotImplementedError("weight has not support for torch.nn.functional.nll_loss") + if _unwrap_value(reduction) == 'none': + annos = [ + 'C^, N -> N', + 'N C^, N -> N', + 'N C^ *, N * -> N *' + ] + elif _unwrap_value(reduction) == 'sum': + annos = [ + 'C^, N -> 1', + 'N+ C^, N+ -> 1', + 'N+ C^ *, N+ * -> 1' + ] + elif _unwrap_value(reduction) == 'mean': + # TODO(nishang): here should consider about the ignore idx and the scale of the result if we apply tp + # for now, we give '^' to all anno, only replicated is allowed for mean reduction. + annos = [ + 'C^, N^ -> 1', + 'N^ C^, N^ -> 1', + 'N^ C^ *, N^ * -> 1' + ] + else: + raise NotImplementedError(f'unknow reduction in torch.nn.functional.nll_loss: {reduction}') return IRDimops( NLLLoss, 'nll_loss', signature, annos, [input, target], @@ -2246,6 +2373,30 @@ def NLLLoss(input, target, weight=None, size_average=None, reduce=reduce, reduction=reduction) +def L1Loss(input, target, size_average=None, reduce=None, reduction='mean', signature=None): + """ + torch.nn.functional.l1_loss(input, target, size_average=None, reduce=None, reduction='mean') + """ + if not isinstance(input, IRTensor) or not isinstance(target, IRTensor): + raise ValueError(f"expect input and target are IRTensor, but get input={input} and target={target}") + if input.shape != target.shape: + raise ValueError(f"shape mismatched, input shape is {input.shape}, target shape is {target.shape}") + if _unwrap_value(reduction) == 'none': + annos = ['*, * -> *'] + elif _unwrap_value(reduction) == 'sum': + edim_in = ShapeAnno.create_shape_str(input.shape, '+') + annos = [OpAnno.create_op_str([edim_in, edim_in], ['1'])] + elif _unwrap_value(reduction) == 'mean': + # TODO(nishang): I don't know how to give a correct tp anno, the result of loss will be scaled by tp number if we apply tp + # for now, we give '^' to all anno, only replicated is allowed for mean reduction. + edim_in = ShapeAnno.create_shape_str(input.shape, '^') + annos = [OpAnno.create_op_str([edim_in, edim_in], ['1'])] + else: + raise NotImplementedError(f'unknow reduction in torch.nn.functional.l1_loss: {reduction}') + return IRDimops(L1Loss, 'l1_loss', signature, annos, [input, target], + size_average=size_average, reduce=reduce, reduction=reduction) + + def MakeTuple(inputs: Iterable, signature=None): return tuple(inputs) @@ -2283,7 +2434,7 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, dropout_p=0.0, For a common attention, the generated anno is like (a e d^, a b^ d^, a b^ c -> a e c). """ if not isinstance(query, IRTensor) or not isinstance(key, IRTensor) or not isinstance(value, IRTensor): - raise RuntimeError(f'query: {query}, key: {key}, value: {value} should be IRTensor, something went wrong.') + raise ValueError(f'query: {query}, key: {key}, value: {value} should be IRTensor, something went wrong.') gener = iter(string.ascii_lowercase) value_anno = ShapeAnno.create_shape_str(value.shape, iterator=gener) value_anno[-2] += '^' @@ -2297,9 +2448,9 @@ def ScaledDotProductAttention(query, key, value, attn_mask=None, dropout_p=0.0, out_anno[-1] = value_anno[-1] if attn_mask is not None: if not isinstance(attn_mask, IRTensor): - raise RuntimeError(f'attn_mask: {attn_mask} should be IRTensor, something went wrong.') + raise ValueError(f'attn_mask: {attn_mask} should be IRTensor, something went wrong.') if len(attn_mask.shape) < 2 or len(attn_mask.shape) > len(query.shape): - raise RuntimeError(f'attn_mask shape {attn_mask.shape} is not supported, while query shape is {query.shape}') + raise ValueError(f'attn_mask shape {attn_mask.shape} is not supported, while query shape is {query.shape}') attn_mask_anno = [] # the anno of attn_mask will conbine query and attn_mask shape except last dimension, # the last dimension of the attn_mask anno will be the same as key penultimate dimension @@ -2378,8 +2529,7 @@ def FullLike(input, fill_value, *, dtype=None, layout=None, """ torch.full_like(input, fill_value, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor """ - if not (layout in (None, torch.strided) and memory_format is None): - raise ValueError("Not support for non-default memory_format and layout") + creation_function_args_check('torch.full_like', dtype=dtype, layout=layout, memory_format=memory_format) kwargs = {'fill_value': fill_value, 'requires_grad': requires_grad,'dtype': dtype} shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] @@ -2390,30 +2540,44 @@ def ZerosLike(input, *, dtype=None, layout=None, device=None, requires_grad=Fals """ torch.zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor """ - if not (layout in (None, torch.strided) and memory_format is None): - raise ValueError("Not support for non-default memory_format and layout") - dtype = dtype if dtype is not None else torch.get_default_dtype() - if not isinstance(dtype, torch.dtype): - raise TypeError("only supports torch.dtype but got {}".format(dtype)) + creation_function_args_check('torch.zeros_like', dtype=dtype, layout=layout, memory_format=memory_format) kwargs = {'requires_grad': requires_grad, 'dtype': dtype} shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] - return IRDimops(ZerosLike, 'zeros_like', signature, annos,[input],**kwargs) + return IRDimops(ZerosLike, 'zeros_like', signature, annos, [input], **kwargs) def OnesLike(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None, signature=None): """ torch.ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor """ - if not (layout in (None, torch.strided) and memory_format is None): - raise ValueError("Not support for non-default memory_format and layout") - dtype = dtype if dtype is not None else torch.get_default_dtype() - if not isinstance(dtype, torch.dtype): - raise TypeError("only supports torch.dtype but got {}".format(dtype)) + creation_function_args_check('torch.ones_like', dtype=dtype, layout=layout, memory_format=memory_format) kwargs = {'requires_grad': requires_grad, 'dtype': dtype} shape = ShapeAnno.create_shape_str(input.shape) annos = [OpAnno.create_op_str([shape], [shape])] - return IRDimops(OnesLike, 'onesLike', signature, annos,[input],**kwargs) + return IRDimops(OnesLike, 'ones_like', signature, annos, [input], **kwargs) + + +def RandLike(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None, signature=None): + """ + torch.rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor + """ + creation_function_args_check('torch.rand_like', dtype=dtype, layout=layout, memory_format=memory_format) + kwargs = {'requires_grad': requires_grad, 'dtype': dtype} + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(RandLike, 'rand_like', signature, annos, [input], **kwargs) + + +def RandnLike(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None, signature=None): + """ + torch.randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor + """ + creation_function_args_check('torch.randn_like', dtype=dtype, layout=layout, memory_format=memory_format) + kwargs = {'requires_grad': requires_grad, 'dtype': dtype} + shape = ShapeAnno.create_shape_str(input.shape) + annos = [OpAnno.create_op_str([shape], [shape])] + return IRDimops(RandnLike, 'randn_like', signature, annos, [input], **kwargs) def Addmm(input: IRTensor, mat1: IRTensor, mat2: IRTensor, *, beta=1, alpha=1, out=None, signature = None): @@ -2551,16 +2715,69 @@ def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: stride=stride, padding=padding, dilation=dilation, groups=ori_groups) +def SVD(input, some=True, compute_uv=True, *, out=None, signature=None): + """ + torch.svd(input, some=True, compute_uv=True, *, out=None) + + NOTE: the signature of torch.linalg.svd is different with torch.svd, don't forward torch.linalg.svd to this function + """ + if not isinstance(input, IRTensor): + raise ValueError(f"expect input is an IRTensor, but get input={input}") + if len(input.shape) < 2: + raise ValueError(f"expect input at least a 2-D tensor, but get input with shape {input.shape}") + + some_value = _unwrap_value(some) + compute_uv_value = _unwrap_value(compute_uv) + + in_shape = ShapeAnno.create_shape_str(input.shape, '^') + m, n = input.shape[-2:] + # for the some is False or compute_uv is False + o1_shape = copy.copy(in_shape) + o1_shape[-1] = o1_shape[-2] + o2_shape = [in_shape[-1] if m > n else in_shape[-2]] + o3_shape = copy.copy(in_shape) + o3_shape[-2] = o3_shape[-1] + + if some_value and compute_uv_value: + o1_shape[-1] = in_shape[-2] if m < n else in_shape[-1] + o3_shape[-1] = in_shape[-2] if m < n else in_shape[-1] + + annos = [OpAnno.create_op_str([in_shape], [o1_shape, o2_shape, o3_shape])] + return IRDimops(SVD, 'svd', signature, annos, [input], some=some, compute_uv=compute_uv) + + +def Diag(input, diagonal=0, *, out=None, signature=None): + """ + torch.diag(input, diagonal=0, *, out=None) -> Tensor + """ + assert isinstance(input, IRTensor) + diagonal_value = _unwrap_value(diagonal) + if len(input.shape) == 1: + dim_len = input.shape[0] + odim_len = dim_len + abs(diagonal_value) + anno = f'{dim_len} -> {odim_len} {odim_len}' + else: + # TODO: in fact, we can partition with modifier here, will do it latter + if diagonal_value >= 0: + outlen = min(input.shape[0], input.shape[1] - diagonal_value) + else: + outlen = min(input.shape[0] + diagonal_value, input.shape[1]) + anno = f'{input.shape[0]} {input.shape[1]} -> {max(0, outlen)}' + return IRDimops(Diag, 'diag', signature, [anno], [input], diagonal=diagonal) + + def Gather(input: IRTensor, dim, index: IRTensor, sparse_grad=False, out=None, signature=None): """ torch.gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor """ - if not (0 <= dim < len(input.shape)): - raise ValueError(f"Dimension {dim} is out of bounds for input with {len(input.shape)} dimensions.") + dim_value = _unwrap_value(dim) + if not (-len(input.shape) <= dim_value < len(input.shape)): + raise ValueError(f"Dimension {dim_value} is out of bounds for input with {len(input.shape)} dimensions.") + dim_value = (dim_value + len(input.shape)) % len(input.shape) if len(input.shape) != len(index.shape): raise ValueError("The dimensions of 'input' and 'index' must be the same.") for i, (dim_input, dim_index) in enumerate(zip(input.shape, index.shape)): - if i != dim and dim_index > dim_input: + if i != dim_value and dim_index > dim_input: raise ValueError(f"Index size {dim_index} at dimension {i} exceeds input size {dim_input} at the same dimension.") gener = iter(string.ascii_lowercase) input_anno = ShapeAnno.create_shape_str(input.shape, iterator=gener) @@ -2569,7 +2786,7 @@ def Gather(input: IRTensor, dim, index: IRTensor, sparse_grad=False, out=None, s if dim_input != dim_index: input_anno[i] += '^' index_anno[i] += '^' - elif i == dim: + elif i == dim_value: index_anno[i] = input_anno[i] input_anno[i] += '^' index_anno[i] += '^' @@ -2578,8 +2795,8 @@ def Gather(input: IRTensor, dim, index: IRTensor, sparse_grad=False, out=None, s # When dynamic shape is enabled, this partition may be incorrect. # We keep the partition here for now, and consider reporting errors that cannot be partitioned at run time in future. index_anno[i] = input_anno[i] - anno = OpAnno.create_op_str([input_anno, index_anno], [index_anno]) - return IRDimops(Gather, 'gather', signature, [anno], [input, index], dim=dim) + anno = OpAnno.create_op_str([input_anno, '?', index_anno], [index_anno]) + return IRDimops(Gather, 'gather', signature, [anno], [input, dim, index]) def Ceil(input: IRTensor, out=None, signature=None): diff --git a/nnscaler/graph/parser/fx/mapping.py b/nnscaler/graph/parser/fx/mapping.py index b416d0b0..8efc5b2e 100644 --- a/nnscaler/graph/parser/fx/mapping.py +++ b/nnscaler/graph/parser/fx/mapping.py @@ -65,6 +65,8 @@ def exist(signature: str) -> bool: __ttemplate('sqrt'): function.Sqrt, 'math.sqrt': function.Sqrt, __ttemplate('log'): function.Log, + __ttemplate('svd'): function.SVD, + __ttemplate('diag'): function.Diag, 'math.log': function.Log, __ttemplate('rsqrt'): function.RSqrt, __ttemplate('clamp'): function.Clamp, @@ -88,6 +90,7 @@ def exist(signature: str) -> bool: __ttemplate('max'): function.Max, __ttemplate('min'): function.Min, __ttemplate('where'): function.Where, + __ttemplate('nonzero'): function.Nonzero, __ttemplate('nan_to_num') : function.NanToNum, __tttemplate('type'): function.Type, __tttemplate('long'): function.Long, @@ -129,6 +132,8 @@ def exist(signature: str) -> bool: 'torch.functional.einsum': function.EinSum, __ftemplate('unfold'): function.Unfold, __ftemplate('nll_loss') : function.NLLLoss, + __ftemplate('l1_loss') : function.L1Loss, + __ttemplate('norm'): function.Norm, 'torch.functional.norm': function.Norm, __ftemplate('layer_norm'): function.LayerNorm, __ftemplate('scaled_dot_product_attention'): function.ScaledDotProductAttention, @@ -167,10 +172,16 @@ def exist(signature: str) -> bool: __ttemplate('full'): function.Full, __ttemplate('full_like'): function.FullLike, __ttemplate('rand'): function.Rand, + __ttemplate('rand_like'): function.RandLike, + __ttemplate('randn'): function.Randn, + __ttemplate('randn_like'): function.RandnLike, __ttemplate('clone'): function.Clone, '_operator.is_': function.Is, '_operator.is_not': function.IsNot, + __ttemplate('isnan'): function.IsNan, + __ttemplate('isinf'): function.IsInf, + __ttemplate('any'): function.TorchAny, __ttemplate('add') : function.Add, '_operator.add': function.Add, __ttemplate('addmm'): function.Addmm, diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 4a17635f..0104d4c0 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -141,6 +141,9 @@ def select_scatter(input:torch.Tensor, src:torch.Tensor, dim:int, index:int): def tensor(data, *, dtype=None, requires_grad=False, pin_memory=False): + """ + force set the device to torch.cuda.current_device() + """ return torch.tensor( data, dtype=dtype, device=torch.cuda.current_device(), @@ -149,6 +152,9 @@ def tensor(data, *, dtype=None, requires_grad=False, pin_memory=False): def empty(size: Tuple[int], dtype=None, requires_grad=False, pin_memory=False): + """ + force set the device to torch.cuda.current_device() + """ return torch.empty( size, dtype=torch.get_default_dtype() if dtype is None else dtype, device=torch.cuda.current_device(), @@ -157,6 +163,9 @@ def empty(size: Tuple[int], dtype=None, requires_grad=False, pin_memory=False): def zeros(size: Tuple[int], dtype=None, requires_grad=False): + """ + force set the device to torch.cuda.current_device() + """ return torch.zeros( size, dtype=torch.get_default_dtype() if dtype is None else dtype, device=torch.cuda.current_device(), @@ -165,6 +174,9 @@ def zeros(size: Tuple[int], dtype=None, requires_grad=False): def ones(size: Tuple[int], dtype=None, requires_grad=False): + """ + force set the device to torch.cuda.current_device() + """ return torch.ones( size, dtype=torch.get_default_dtype() if dtype is None else dtype, device=torch.cuda.current_device(), @@ -172,14 +184,34 @@ def ones(size: Tuple[int], dtype=None, requires_grad=False): ) -def rand(size: Tuple[int], dtype=None, requires_grad=False): +def rand(size: Tuple[int], dtype=None, requires_grad=False, pin_memory=False): + """ + force set the device to torch.cuda.current_device() + """ return torch.rand( size, dtype=torch.get_default_dtype() if dtype is None else dtype, device=torch.cuda.current_device(), - requires_grad=requires_grad + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +def randn(size: Tuple[int], dtype=None, requires_grad=False, pin_memory=False): + """ + force set the device to torch.cuda.current_device() + """ + return torch.randn( + size, dtype=torch.get_default_dtype() if dtype is None else dtype, + device=torch.cuda.current_device(), + pin_memory=pin_memory, + requires_grad=requires_grad, ) + def full(size: Tuple[int], fill_value, dtype=None, requires_grad=False): + """ + force set the device to torch.cuda.current_device() + """ return torch.full( size, fill_value, dtype=dtype, requires_grad=requires_grad, device=torch.cuda.current_device() @@ -218,7 +250,18 @@ def nndropout(input: torch.Tensor, p=0.5, inplace=False): return torch.nn.Dropout(p, inplace)(input) -def setitem(__a, __b, __c): +def setitem(__a, *__bc): + """ + If __bc has more than 2 elements, that means idxs are flatten becasue idxs contains tensor. + In this runtime function, idxs will be structured as a tuple if they are flatten, + and return __a to make this inplace operation trackable. + """ + if len(__bc) < 2: + raise ValueError(f'at least two arguments needed, but get __bc={__bc}') + elif len(__bc) == 2: + __b, __c = __bc[0], __bc[1] + else: + __b, __c = __bc[:-1], __bc[-1] operator.setitem(__a, __b, __c) return __a diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index d24fd357..40de23a0 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -406,6 +406,9 @@ def test_Setitem(): op = F.SetItem(IRTensor([3, 4, 5]), IRTensor([3, 4, 5]), IRObject(value=1.)) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, d^ e^ f^, ? -> a^ b^ c^' + op = F.SetItem(IRTensor([3, 4, 5]), IRTensor([3]), IRObject(value=0), IRObject(value=0), IRObject(value=1.)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, d^, ?, ?, ? -> a^ b^ c^' + def test_Len(): op = F.Len([1, 2, 3], signature='builtins.len') @@ -627,28 +630,31 @@ def test_Flatten(): def test_Gather(): op = F.Gather(IRTensor([2, 5, 3]), 2, IRTensor([2, 5, 1])) - expected_annotation = 'a b c^, a b f^ -> a b f^' + expected_annotation = 'a b c^, ?, a b f^ -> a b f^' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." op = F.Gather(IRTensor([2, 5, 3]), 2, IRTensor([2, 5, 3])) - expected_annotation = 'a b c^, a b c^ -> a b c^' + expected_annotation = 'a b c^, ?, a b c^ -> a b c^' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." op = F.Gather(IRTensor([2, 5, 3]), 2, IRTensor([2, 4, 3])) - expected_annotation = 'a b^ c^, a e^ c^ -> a e^ c^' + expected_annotation = 'a b^ c^, ?, a e^ c^ -> a e^ c^' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." op = F.Gather(IRTensor([2, 5, 3]), 2, IRTensor([1, 3, 1])) - expected_annotation = 'a^ b^ c^, d^ e^ f^ -> d^ e^ f^' + expected_annotation = 'a^ b^ c^, ?, d^ e^ f^ -> d^ e^ f^' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." op = F.Gather(IRTensor([2, 5, 3]), 1, IRTensor([2, 2, 3])) - expected_annotation = 'a b^ c, a e^ c -> a e^ c' + expected_annotation = 'a b^ c, ?, a e^ c -> a e^ c' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." op = F.Gather(IRTensor([2, 5, 3]), 0, IRTensor([1, 5, 3])) - expected_annotation = 'a^ b c, d^ b c -> d^ b c' + expected_annotation = 'a^ b c, ?, d^ b c -> d^ b c' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." op = F.Gather(IRTensor([2, 3]), 1, IRTensor([2, 1])) - expected_annotation = 'a b^, a d^ -> a d^' + expected_annotation = 'a b^, ?, a d^ -> a d^' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." op = F.Gather(IRTensor([2, 3]), 1, IRTensor([1, 1])) - expected_annotation = 'a^ b^, c^ d^ -> c^ d^' + expected_annotation = 'a^ b^, ?, c^ d^ -> c^ d^' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." + op = F.Gather(IRTensor([2, 3]), -1, IRTensor([1, 1])) + expected_annotation = 'a^ b^, ?, c^ d^ -> c^ d^' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Gather." @@ -686,4 +692,56 @@ def test_Unfold(): def test_Sigmoid(): op = F.Sigmoid(IRTensor([2, 3, 4])) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' \ No newline at end of file + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '* -> *' + + +def test_BitwiseOr(): + op = F.BitwiseOr(IRTensor([8, 10]), IRTensor([10])) + assert op._annos_candidates[0] == 'a b, b -> a b' + + +def test_TorchAny(): + op = F.TorchAny(IRTensor([10, 10])) + assert op._annos_candidates[0] == 'a^ b^ -> 1' + + op = F.TorchAny(IRTensor([10, 10]), dim=1) + assert op._annos_candidates[0] == 'a^ b^ -> a^' + + op = F.TorchAny(IRTensor([10, 10]), dim=1, keepdim=True) + assert op._annos_candidates[0] == 'a^ b^ -> a^ 1' + + +def test_L1Loss(): + op = F.L1Loss(IRTensor([8, 10]), IRTensor([8, 10]), reduction='sum') + assert op._annos_candidates[0] == 'a+ b+, a+ b+ -> 1' + + op = F.L1Loss(IRTensor([8, 10]), IRTensor([8, 10]), reduction='mean') + assert op._annos_candidates[0] == 'a^ b^, a^ b^ -> 1' + + +def test_SVD(): + op = F.SVD(IRTensor([3, 4])) + assert op._annos_candidates[0] == 'a^ b^ -> a^ a^, a^, b^ a^' + + op = F.SVD(IRTensor([4, 3])) + assert op._annos_candidates[0] == 'a^ b^ -> a^ b^, b^, b^ b^' + + op = F.SVD(IRTensor([4, 3]), False) + assert op._annos_candidates[0] == 'a^ b^ -> a^ a^, b^, b^ b^' + + +def test_Diag(): + op = F.Diag(IRTensor([5, 10]), 0) + assert op._annos_candidates[0] == '5 10 -> 5' + + op = F.Diag(IRTensor([5, 10]), 5) + assert op._annos_candidates[0] == '5 10 -> 5' + + op = F.Diag(IRTensor([5, 10]), 7) + assert op._annos_candidates[0] == '5 10 -> 3' + + op = F.Diag(IRTensor([5, 10]), 10) + assert op._annos_candidates[0] == '5 10 -> 0' + + op = F.Diag(IRTensor([5, 10]), -1) + assert op._annos_candidates[0] == '5 10 -> 4' From e2f5697212bc0523f26cdfa6254f7d32b02f4676 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 29 May 2024 05:21:38 +0000 Subject: [PATCH 1646/1892] Merged PR 2157: fix small bugs in autodist & tracer 1. node.inputs can have non-IRTensor 2. during tracing, input args move to cuda may OOM --- nnscaler/autodist/model_graph.py | 3 ++- .../graph/parser/fx/concrete_trace_utils/concrete_tracer.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 246e0f45..99b0ff3f 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -238,7 +238,8 @@ def aggregate_train_mem(sub_nodes: List[IRFwOperation], db) -> int: train_mem += mem else: t = node.inputs()[in_idx] - if t not in visited_tensors: + # `t` also can be any other unhashable var, if we set ? in annotation + if isinstance(t, IRTensor) and t not in visited_tensors: train_mem += mem visited_tensors.add(t) return train_mem diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index b1efc0a3..fc9fd1d4 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -408,10 +408,9 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] raise RuntimeError() return result - if self.cpu_offload: - args, kwargs = tree_to_cuda(args), tree_to_cuda(kwargs) - try: + if self.cpu_offload: + args, kwargs = tree_to_cuda(args), tree_to_cuda(kwargs) result = run(kind, target, args, kwargs) except torch.cuda.OutOfMemoryError: if self.cpu_offload: From 49433e5b9dcb78e9889d4bb4c0854d801ad3e954 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Wed, 29 May 2024 12:32:06 +0000 Subject: [PATCH 1647/1892] Merged PR 2150: fix fullslice fix fullslice --- nnscaler/graph/function/function.py | 23 ++++++++++++++++++----- tests/graph/function/test_functions.py | 18 ++++++++++++++++-- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index c66fe961..b81cf74a 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -953,12 +953,15 @@ def Topk(input, k, dim=None, largest=True, sorted=True, *, out=None, signature = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) """ edim_in = ShapeAnno.create_shape_str(input.shape) + edim_ou = copy.copy(edim_in) + k = _unwrap_value(k) if dim is None: edim_in[-1] += '^' + edim_ou[-1] = str(k) else: edim_in[dim] += '^' - edim_ou = [['?'], ['?']] - anno = OpAnno.create_op_str([edim_in], edim_ou) + edim_ou[dim] = str(k) + anno = OpAnno.create_op_str([edim_in], [edim_ou, edim_ou]) return IRDimops(Topk, 'topk', signature, [anno], [input], k=k, dim=dim, largest=largest, sorted=sorted) @@ -1369,6 +1372,8 @@ def View(input, size: Tuple[int], *arg_size, signature = None): """ out = torch.Tensor.view(tensor: torch.Tensor, *size) """ + if isinstance(size, torch.dtype): + raise ValueError(f"View by dtype is not supported: {size}") in_shape = list(input.shape) if isinstance(size, IRObject): assert size.value is not None, f"shape should have a reference value but got: {size}" @@ -1818,9 +1823,7 @@ def list_shape(lst): else: raise RuntimeError(f"Unsupported slicer {slicer}. you may need to wrap related logic in a Customized Op.") - if output_shape_unkonwn: - edim_ou = ['?'] - else: + if not output_shape_unkonwn: edim_ou += edim_in[in_idx:] if len(edim_ou) == 0: # special case for scalar = torch.Tensor([1,2,3])[0] @@ -1828,6 +1831,16 @@ def list_shape(lst): edim_in = [edim_in] edim_in.extend(edim_in_additional) + + if output_shape_unkonwn: + edim_ou = ['?'] + for i in range(len(edim_in)): + for j in range(len(edim_in[i])): + # current implementation doesn't use '()', so we don't consider it to simply the code + assert '(' not in edim_in[i][j], 'no () is supposed to be used' + if not edim_in[i][j].endswith(('^', '+', '?')): + edim_in[i][j] += '^' + anno = OpAnno.create_op_str(edim_in, [edim_ou]) return IRDimops(FullSlice, 'fullslice', signature, [anno], [tensor] + slicers) diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 40de23a0..39852909 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -181,6 +181,20 @@ def test_Repeat(): op = F.Repeat(inp, o(2)) +def test_Topk(): + op = F.Topk(IRTensor([3, 4, 5]), 3) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c^ -> a b 3, a b 3' + op = F.Topk(IRTensor([3, 4, 5]), 3, dim = 1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c -> a 3 c, a 3 c' + + +def test_Nonzero(): + op = F.Nonzero(IRTensor([3, 4, 5]), as_tuple=True) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> ?, ?, ?' + op = F.Nonzero(IRTensor([3, 4, 5]), as_tuple=False) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> ?' + + def test_Where(): op = F.Where(IRTensor([3, 4]), IRTensor([3, 4]), IRTensor([3, 4])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b, a b, a b -> a b' @@ -228,7 +242,7 @@ def test_FullSlice(): op = F.FullSlice(IRTensor([3, 4]), IRFullTensor([2,2])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b, c d -> c d b' op = F.FullSlice(IRTensor([3, 4]), [True, False, True]) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b, ? -> ?' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, ? -> ?' op = F.FullSlice(IRTensor([3, 4]), IRFullTensor([3], dtype=torch.bool), 0) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^, c^, ? -> ?' op = F.FullSlice(IRTensor([3, 4]), [True, False, True], [0,1]) @@ -264,7 +278,7 @@ def test_GetItem(): op = F.GetItem(IRTensor([3, 4, 2]), [slice(None), IRTensor([3, 5], dtype=torch.int64)]) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c, ?, d e -> a d e c' op = F.GetItem(IRTensor([3, 4, 2]), [slice(None), IRTensor([4, 2], dtype=torch.bool)]) - assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b^ c^, ?, d^ e^ -> ?' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^, ?, d^ e^ -> ?' # obj is IRObject op = F.GetItem(IRObject(value=[3, 4, 5], is_constant=False), IRObject(value=0, is_constant=False), signature='operator.getitem') From 8036b82bbea1af8f6f3fe73d7d0765d17d80844e Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 3 Jun 2024 07:06:07 +0000 Subject: [PATCH 1648/1892] Merged PR 2153: Add dedup_attrs to reduce model size in checkpointing Introduce an interface `dedup_attrs`. It deduplicate the attributes according to `rank2attr_area_map`. For each `slicers` of a full tensor with the name `orig_name`, we only store its first appearance in the `rank2attr_area_map`. In addition, we will check - the shape of the full tensor is consistent across different ranks - the slicers of the full tensor are not intersected with each other - the slicers of the full tensor can cover the full tensor --- nnscaler/runtime/module.py | 74 +++++++++++++++++++ tests/parallel_module/test_attr_dedup.py | 92 ++++++++++++++++++++++++ 2 files changed, 166 insertions(+) create mode 100644 tests/parallel_module/test_attr_dedup.py diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 0e613bb2..cba53b1e 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -41,6 +41,80 @@ class AttrMeta: # (i.e., no partition on value -> no need to sum up) val_chunks: int + +def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, AttrMeta]]) -> Dict[int, Dict[str, AttrMeta]]: + ''' + Deduplicate the attributes according to `rank2attr_area_map`. + For each `slicers` of a full tensor with the name `orig_name`, we only store its first appearance + in the `rank2attr_area_map`. + In addition, we will check + - the shape of the full tensor is consistent across different ranks + - the slicers of the full tensor are not intersected with each other + - the slicers of the full tensor can cover the full tensor + The input and output attribute area map's key is the local attribute name. + + Args: + rank2attr_area_map (Dict[int, Dict[str, AttrMeta]]): the mapping from rank to the attribute area map + + Returns: + Dict[int, Dict[str, AttrMeta]]: the deduplicated attribute area map + ''' + # assume ranks in rank2attr_area_map are in increasing order + ranks = list(rank2attr_area_map.keys()) + for i in range(1, len(ranks)): + assert ranks[i - 1] < ranks[i], f'rank {ranks[i - 1]} should be less than rank {ranks[i]}' + + orig_name2slice_info = defaultdict(list) + orig_name2shape = dict() + + def need_save(slicers: Tuple[slice, ...], saved_slicers_list: List[Tuple[slice, ...]]) -> bool: + for saved_slicers in saved_slicers_list: + assert len(slicers) == len(saved_slicers), f'If two slicers are related to one same full tensor, lengths should be equal, but get {slicers} vs {saved_slicers}' + if slicers == saved_slicers: + return False + # if slicers intersect with saved_slicers, raise error + for s, ss in zip(slicers, saved_slicers): + if s == ss: + continue + if s.start < ss.stop and s.stop > ss.start: + raise RuntimeError(f'intersected slicers {slicers} vs {saved_slicers}') + return True + + ret = dict() + for rank, attr_area_map in rank2attr_area_map.items(): + dedup_attr_area_map = dict() + for attr, attr_meta in attr_area_map.items(): + assert attr_meta.val_chunks == 1, 'not support partitioning on value dimension' + if attr_meta.orig_name not in orig_name2shape: + orig_name2shape[attr_meta.orig_name] = attr_meta.shape + else: + assert orig_name2shape[attr_meta.orig_name] == attr_meta.shape, \ + f'unmatched shape {orig_name2shape[attr_meta.orig_name]} vs {attr_meta.shape}' + if need_save(attr_meta.slicers, orig_name2slice_info[attr_meta.orig_name]): + orig_name2slice_info[attr_meta.orig_name].append(attr_meta.slicers) + dedup_attr_area_map[attr] = attr_meta + ret[rank] = dedup_attr_area_map + + # since we + # - skip saving when there are identical weights + # - assert the slicers are disjoint + # we can use the sum of the sub-slicers to verify the full tensor is covered + for orig_name, slicerss in orig_name2slice_info.items(): + shape = orig_name2shape[orig_name] + full_size = 1 + for s in shape: + full_size *= s + covered_size = 0 + for slicers in slicerss: + size = 1 + for s in slicers: + size *= s.stop - s.start + covered_size += size + assert full_size == covered_size, f'uncovered size for {orig_name} with shape {shape}, slicerss {slicerss}' + + return ret + + class CubeModule(torch.nn.Module): """ The module is responsible for parameter synchronization diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py new file mode 100644 index 00000000..8e5312bc --- /dev/null +++ b/tests/parallel_module/test_attr_dedup.py @@ -0,0 +1,92 @@ +import tempfile +from pathlib import Path +import pytest +from typing import Dict, Tuple, List, Any + +import torch +from torch import nn + +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, \ + merge_state_dicts, load_merged_state_dicts, \ + deduped_state_dict, load_deduped_state_dict +from nnscaler.runtime.module import ParallelModule +from nnscaler.graph.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.policies import _tp, _replica +from nnscaler.runtime.module import dedup_attrs + +from .common import init_distributed, clear_dir_on_rank0, assert_equal +from ..launch_torchrun import launch_torchrun + +class Net(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 4, bias=False) + self.fc2 = torch.nn.Linear(4, 4, bias=False) + self.fc3 = torch.nn.Linear(4, 4, bias=False) + # register a buffer + self.register_buffer('buffer', torch.zeros(4)) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + x = self.buffer + x + return x + +def pas(graph: IRGraph, config: ComputeConfig): + fw_nodes = graph.select(ntype=IRFwOperation) + assert len(fw_nodes) == 4 + devs = list(range(config.plan_ngpus)) + # partition the batch dim, weight is replicated + _tp(graph, fw_nodes[0], idx=0, dim=0, devs=devs) + # partition the weight, input is replicated + _tp(graph, fw_nodes[1], idx=1, dim=0, devs=devs) + _replica(graph, fw_nodes[2], devs=devs) + _replica(graph, fw_nodes[3], devs=devs) + return graph + +def _gpu_worker_spmd(cc: ComputeConfig): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_dedup_attr') as tempdir: + module = parallelize( + Net(), + {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, + pas, + cc, + cube_savedir=tempdir, + instance_name='attr_dedup' + ) + print(module.fullmap) + world_size = torch.distributed.get_world_size() + attr_area_maps = [None for _ in range(world_size)] + curr_rank = torch.distributed.get_rank() + torch.distributed.all_gather_object(attr_area_maps, module.fullmap) + rank2attr_area_map = {} + for i, attr_area_map in enumerate(attr_area_maps): + rank2attr_area_map[i] = attr_area_map + torch.distributed.barrier() + dedup_meta_info = dedup_attrs(rank2attr_area_map) + dedup_area_map = list(dedup_meta_info[curr_rank].items()) + if curr_rank == 0: + assert len(dedup_area_map) == 4 + assert dedup_area_map[0][1].orig_name == 'fc1.weight' + assert dedup_area_map[0][1].slicers == (slice(0, 4, None), slice(0, 4, None)) + assert dedup_area_map[1][1].orig_name == 'fc2.weight' + assert dedup_area_map[1][1].slicers == (slice(0, 2, None), slice(0, 4, None)) + assert dedup_area_map[2][1].orig_name == 'fc3.weight' + assert dedup_area_map[2][1].slicers == (slice(0, 4, None), slice(0, 4, None)) + assert dedup_area_map[3][1].orig_name == 'buffer' + assert dedup_area_map[3][1].slicers == (slice(0, 4, None),) + elif curr_rank == 1: + assert len(dedup_area_map) == 1 + assert dedup_area_map[0][1].orig_name == 'fc2.weight' + assert dedup_area_map[0][1].slicers == (slice(2, 4, None), slice(0, 4, None)) + else: + raise RuntimeError(f'Unexpected rank {curr_rank}') + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_attr_dedup(): + cc = ComputeConfig(2, 2, use_zero=False) + launch_torchrun(2, _gpu_worker_spmd, cc) From 79070fb7f7a499d23f62f707e265600c8915ae48 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 4 Jun 2024 01:28:11 +0000 Subject: [PATCH 1649/1892] Merged PR 2158: parallel module: rename cube related names 1. rename cube related names 2. move user_config.code to pas_config --- docs/source/parallel_module.md | 101 ++++++--- examples/vision/swin/train.py | 28 ++- nnscaler/__init__.py | 1 - nnscaler/parallel.py | 205 +++++++++--------- nnscaler/policies.py | 2 +- tests/graph/function/test_dict_values.py | 4 +- tests/graph/function/test_script_func.py | 2 +- tests/parallel_module/test_broadcast.py | 2 +- tests/parallel_module/test_checkpoint.py | 6 +- .../parallel_module/test_checkpoint_buffer.py | 2 +- .../parallel_module/test_checkpoint_dedup.py | 2 +- .../parallel_module/test_checkpoint_shared.py | 2 +- .../parallel_module/test_checkpoint_unused.py | 2 +- tests/parallel_module/test_ddp.py | 2 +- tests/parallel_module/test_end2end.py | 4 +- tests/parallel_module/test_gencode.py | 44 ++-- tests/parallel_module/test_inference.py | 2 +- tests/parallel_module/test_init.py | 8 +- tests/parallel_module/test_line_timer.py | 2 +- tests/parallel_module/test_nested.py | 4 +- tests/parallel_module/test_override.py | 8 +- tests/parallel_module/test_reducer_hook.py | 2 +- tests/parallel_module/test_scale_grads.py | 2 +- tests/parallel_module/test_submodule.py | 2 +- tests/parallel_module/test_wholemodule.py | 2 +- tests/test_policies.py | 4 +- 26 files changed, 240 insertions(+), 205 deletions(-) diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 555f4898..541687f0 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -1,6 +1,8 @@ # Parallel Module -Cube can parallelize a `torch.nn.Module` to a parallel module. A parallel module is a special `torch.nn.Module` but runs in multiple gpus/nodes. All the complexity of distributed training/inferring is hidden from the user. +nnScaler can parallelize a `torch.nn.Module` to a parallel module. +A parallel module is a special `torch.nn.Module` but runs in multiple gpus/nodes. +All the complexity of distributed training/inferring is hidden from the user. Currently we support three kinds of parallelism: data parallelism, tensor parallelism and pipeline parallelism (model parallelism). We can also combine them to get the best performance. @@ -181,11 +183,6 @@ def train(model: ParallelizedPipelinedLLM, data): The configuration of the compute environment. It is a dataclass with the following fields: ```python -@dataclass -class UserConfig: - graph: Dict[str, Any] = field(default_factory=dict) - code: Dict[str, Any] = field(default_factory=dict) - @dataclass(frozen=True) class ComputeConfig: plan_ngpus: int @@ -205,7 +202,7 @@ class ComputeConfig: pipeline_scheduler: Optional[str] = None pas_config: Dict[str, Any] = field(default_factory=dict) - user_config: UserConfig = field(default_factory=UserConfig) + user_config: Dict[str, Any] = field(default_factory=dict) ``` We can categorize the fields into 4 categories: @@ -226,11 +223,11 @@ We can categorize the fields into 4 categories: - `pipeline_nmicros`: the number of microbatches in the pipeline. - `pipeline_nstages`: the number of stages in the pipeline. - `pipeline_scheduler`: the scheduler name for the pipeline. Current we support four schedulers in training `1f1b`/`1f1b_plus`/`gpipe`/`chimera_direct` (4 stages pipeline only), and one scheduler in inference `infer_pipe`. - - `pas_config`: the configuration for the PAS policy. It is a dictionary, and will be used by the PAS policy. Please note different PAS will have different configurations, and please check the PAS policy for details. -4. User configuration - - user_config: the user configuration,which is used to decide whether skipping compiling and reusing the previously compiled parallel module. It has two categories of configuration: - - `graph`: the graph related configuration, which is used to decide whether skipping graph generation only. - - `code`: if it has changed, the code will be regenerated. + - `pas_config`: the configuration for the PAS policy (partition-assign-schedule policy, which describes how to place all computations across devices. For details, please refer to [PAS Policies](#pas-policies)). + It is a dictionary, and will be used by the PAS policy. + Please note different PAS will have different configurations, + You can also put any other settings that can affect code generation here. but please prefix the keys with `_` to avoid conflicts with PAS configurations. + - `user_config`: the user configuration, which is used to decide whether skipping compiling and reusing the previously traced graph. Note: 1. You can put any custom configurations in `user_config`. The assumption is different `user_config` should generate different graph/code. So if the user config is changed, we will regenerate the graph/code automatically. Here are some examples: @@ -245,7 +242,7 @@ Note: if module_config.use_3d: ... ``` - here we can set `user_config.graph` to `{'use_3d': module_config.use_3d}`, + here we can set `user_config` to `{'use_3d': module_config.use_3d}`, and we can be sure different use_3d config will never use the same graph (and eventually the generated code). - Example 2: save file stats @@ -259,13 +256,23 @@ Note: h.update(f.read()) compute_config = { ...., - user_config: UserConfig( - graph = { - 'files_md5': h.hexdigest() - } - ) + user_config: { + 'files_md5': h.hexdigest() + } } ``` +2. If some settings doesn't affect tracing/graph generation, but do affect code generation, you can put them in `pas_config`. Please prefix the keys with `_` to avoid conflicts with predefined PAS configurations. One typical example is you can put the name of selected PAS policy in `pas_config`, so changing PAS policy will regenerate code but the graph will be reused. + + ```python + compute_config = ComputeConfig( + ... + pas_config={ + '_pas_name': ..., + # PAS policy specific configurations + ... + }, + ) + ``` ### ReuseType @@ -280,10 +287,10 @@ class ReuseType(Enum): ``` We call it a `match` when the `ComputeConfig` is the same with the previous run. -1. MATCH: Reuse if match, error if not match, generate if no previous gerenated code exists. -2. OVERRIDE: Nothing will be reused. Everything will be regenerated. -3. MOO: MOO is short for 'match or override'. It will reuse if match, generate if not match or no previous generated code exists. -4. GRAPH: Reuse graph only if match, generate otherwise. +1. `MATCH`: Reuse if match, error if not match, generate if no previous gerenated code exists. +2. `OVERRIDE`: Nothing will be reused. Everything will be regenerated. +3. `MOO`: `MOO` is short for 'match or override'. It will reuse if match, generate if not match or no previous generated code exists. +4. `GRAPH`: Reuse graph only if match, generate otherwise. ### BroadcastGenFilesStrategy @@ -363,7 +370,7 @@ def parallelize( pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], compute_config: ComputeConfig, *, - cube_savedir: Union[str, Path] = './.cube', + gen_savedir: Union[str, Path] = './.nnscaler', reuse: Union[ReuseType, str] = ReuseType.MATCH, instance_name: Optional[str] = None, load_module: bool = True, @@ -379,13 +386,16 @@ It has the following parameters: - `dummy_input` (`dict`): the dummy input for the module. The keys are the argument names of `Module.forward` function, and the values are the dummy input for the arguments. The dummy input will be used to trace the module. Please note the module can't be parallelize if `Module.forward` has positional-only arguments. -- `pas_policy` (`Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]`): the pas (partition-assign-schedule) policy, which describes how to place all computations across devices. You need either pass a builtin PAS policy name or a a custom policy function which should take an `IRGraph` and a `ComputeConfig` as input, and return a new `IRGraph` with the PAS policy applied. We have 6 builtin PAS policies: `dp`, `tp`, `pp`, `data`, `hybrid`, and `autodist`. Please note all builtin PAS policies except `autodist` are only for test purpose. The `autodist` policy is the recommended policy for most cases. For details, please refer to `PAS Policies` section. +- `pas_policy` (`Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]`): the pas (partition-assign-schedule) policy, which describes how to place all computations across devices. +You need either pass a builtin PAS policy name or a a custom policy function which should take an `IRGraph` and a `ComputeConfig` as input, and return a new `IRGraph` with the PAS policy applied. + We have 6 builtin PAS policies: `dp`, `tp`, `pp`, `data`, `hybrid`, and `autodist`. Please note all builtin PAS policies except `autodist` are only for test purpose. The `autodist` policy is the recommended policy for most cases. + For details, please refer to [PAS Policies](#pas-policies) section. - `compute_config` (`ComputeConfig`): the environment resource - `reuse` (`ReuseType`): specify which part can be reused. -- `cube_savedir` (`Union[str, Path]`): the directory to save generated code +- `gen_savedir` (`Union[str, Path]`): the directory to save generated code - `instance_name` (`Optional[str]`): the instance name of the generated module. If it is `None`, will use the default name `_`. @@ -406,10 +416,10 @@ See more details in the `ParallelModule APIs` section. Note: -1. This function can be used to convert both module object and module class to cube module or cube module class. +1. This function can be used to convert both module object and module class to parallel module or parallel module class. Among key-value arguments, `module_fn` and `module_dtype` control how to create the module object. -whereas `init_module_params` controls how to load cube module object after parallelization is done. +whereas `init_module_params` controls how to load parallel module object after parallelization is done. 2. If you want to save multiple instances of the same module (with different configurations), you can specify the `instance_name` to distinguish them. @@ -445,14 +455,14 @@ To support distributed training, in the function we need to hook 4 places (which 1. optimizer constructor: the parameters of optimizer will not be the same with the parameters of the module if we use zero. - So we need to replace the parameters of optimizer with `CubeModule.parameters_for_optimizer`. + So we need to replace the parameters of optimizer with `ParallelModule.parameters_for_optimizer`. 2. `optimizer.step()`: we need to call `optimizer.sync_shard_grad()` to sync the gradients of the module before `optimizer.step()`. - In zero mode, we have to call `CubeModule.gather_params()` after `optimizer.step()` + In zero mode, we have to call `ParallelModule.gather_params()` after `optimizer.step()` 3. `optimizer.zero_grad()`: - We need to call `CubeModule.zero_grad()` after `optimizer.zero_grad()` + We need to call `ParallelModule.zero_grad()` after `optimizer.zero_grad()` `build_optimizer` will patch optimizer for you. Besides the above patches, we also add several utility functions/variables to optimizer: @@ -540,7 +550,7 @@ The configuration of the PAS policy should be passed in the `pas_config` of `Com 5. `hybrid`: pipeline parallelism + tensor parallelism + data parallelism. It will do model parallelism and tensor parallelism(on 0 dimension) inside a scale unit, and run data parallelism across scale units. It requires the `use_end2end` and `use_pipeline` to be true. It has no configurations. 6. `autodist`: the recommended policy for most cases. Currently it only support Adam-like optimizers. It will automatically choose the best partition for you by balancing the memory usage and speed. It has the following configurations. - - `update_freq (int)`: the update frequency when training the module. Required. + - `update_freq (int)`: the update frequency when training the module. Default is 1. Optional. - `mem_constraint (float)`: The memory constraint in each device in GB. Optional. - `task_name (str)`: The name of the current task to distinguish runs. Optional. - `use_fp16 (bool)`: Whether you use `fp16`. Default is `False`. Optional. @@ -558,6 +568,35 @@ The configuration of the PAS policy should be passed in the `pas_config` of `Com Please note all options to `autodist` are just suggestions. `autodist` will try to find the best partition for you, which may not be the same with your suggestions. + You can also put any other settings that can affect code generation here. but please prefix the keys with `_` to avoid conflicts with predefined keys. + +Here is an example: +```python +compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + pas_config={ + '__pas_name': ..., # addtional configurations that can affect code generation. + 'update_freq': ..., + 'mem_constraint': ..., + 'task_name': ..., + 'use_fp16': ..., + 'use_memory_efficient_fp16': ..., + 'use_bf16': ..., + 'use_memory_efficient_bf16': ..., + 're_profile': ..., + 'verbose': ..., + 'load_plan_path': ..., + 'save_plan_path': ..., + 'partition_constraints_path': ..., + 'recompute_modules': ..., + 'pipeline_pivots': ..., + 'use_apex_fused_adam_v2': ..., + }, +) +``` + ### Checkpoint support You can save/load the checkpoints for parallel modules. diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 2895dbcf..c81cc285 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -153,24 +153,22 @@ def train(args, compute_config: nnscaler.ComputeConfig): use_pipeline=args.pp_size > 1, pipeline_nmicros=args.gbs // args.mbs, pipeline_nstages=args.pp_size, - pas_config={ # for autodist only + pas_config={ + # customized settings that can affect code generation. + '_pas_name': args.policy, + '_gbs': args.gbs, + '_pp_size': args.pp_size, + '_tp_size': args.tp_size, + '_dp_size': args.dp_size, + # for autodist only 'update_freq': args.gbs // args.mbs, 'use_fp16': args.fp16, }, - user_config=nnscaler.UserConfig( - graph={ - 'mbs': args.mbs, - 'fp16': args.fp16, - 'src_hash': src_hash(), - }, - code={ - 'pas_name': args.policy, - 'gbs': args.gbs, - 'pp_size': args.pp_size, - 'tp_size': args.tp_size, - 'dp_size': args.dp_size, - } - ) + user_config={ + 'mbs': args.mbs, + 'fp16': args.fp16, + 'src_hash': src_hash(), + } ) train(args, compute_config) diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index cb9e60ef..63ba391f 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -1,7 +1,6 @@ from .version import __version__ from .parallel import ( ParallelModule, - UserConfig, ComputeConfig, ReuseType, BroadcastGenFilesStrategy, diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 27cb64aa..2e2d73a6 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -62,40 +62,6 @@ _PREDEFINED_POLICIES[k[len(_PREDEFINED_POLICIES_NAME_PREFIX):]] = v -@dataclass -class UserConfig: - # you should put any configuration that may affect the traced graph here. - # So we can track the changes and make sure the generated code is correct. - # Example 1: save module configuration - # ```python - # class MyModule(torch.nn.Module): - # def __init__(self): - # super().__init__() - # def forward(self, x): - # ... - # if module_config.use_3d: - # ... - # ``` - # here we can set `graph={'use_3d': module_config.use_3d}`, - # and we can be sure different use_3d will never use the same generated code. - # Example 2: save file stats - # If you want to track all related file stats (just like traditional compilers do), - # you can save the md5 of the files to save some bytes: - # ```python - # import hashlib - # h = hashlib.md5() - # for f in Path('./src').glob('**/*.py'): - # with open(f, 'rb') as f: - # h.update(f.read()) - # graph = { - # 'files_md5': h.hexdigest() - # } - # ``` - graph: Dict[str, Any] = field(default_factory=dict) - # you can put any configuration that may affect the generated code (but not affect the traced graph) here. - code: Dict[str, Any] = field(default_factory=dict) - - @dataclass(frozen=True) class ComputeConfig: plan_ngpus: int @@ -126,9 +92,38 @@ class ComputeConfig: # it is pas's responsibility to apply the scheduler pipeline_scheduler: str = '1f1b' # PAS policy settings + # you can also put any other settings that can affect code generation here. + # but please prefix the keys with `_` to avoid conflicts with predefined keys. pas_config: Dict[str, Any] = field(default_factory=dict) # the customized configs from user that can affect the graph and code generation. - user_config: UserConfig = field(default_factory=UserConfig) + # you should put any configuration that may affect the traced graph here. + # So we can track the changes and make sure the generated code is correct. + # Example 1: save module configuration + # ```python + # class MyModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # def forward(self, x): + # ... + # if module_config.use_3d: + # ... + # ``` + # here we can set `graph={'use_3d': module_config.use_3d}`, + # and we can be sure different use_3d will never use the same generated code. + # Example 2: save file stats + # If you want to track all related file stats (just like traditional compilers do), + # you can save the md5 of the files to save some bytes: + # ```python + # import hashlib + # h = hashlib.md5() + # for f in Path('./src').glob('**/*.py'): + # with open(f, 'rb') as f: + # h.update(f.read()) + # graph = { + # 'files_md5': h.hexdigest() + # } + # ``` + user_config: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): if self.plan_ngpus <= 0: @@ -162,9 +157,6 @@ def __post_init__(self): if not self.inference_only and self.pipeline_scheduler in _PREDEFINED_INFERENCE_SCHEDS: raise ValueError(f"pipeline_scheduler {self.pipeline_scheduler} is not supported in training mode.") - if isinstance(self.user_config, dict): - super().__setattr__('user_config', UserConfig(**self.user_config)) - def apply_pipeline_scheduler(self, graph: IRGraph) -> Optional[SchedulePlan]: """ Apply the pipeline scheduler to the graph. @@ -185,7 +177,7 @@ def gpu_config(self) -> Dict[str, int]: def graph_config(self) -> Dict[str, Any]: return { 'dynamic_shape': self.dynamic_shape, - 'graph_user_config': self.user_config.graph, + 'user_config': self.user_config, 'inference_only': self.inference_only, # there will be no backward nodes in the graph in inference mode 'use_pipeline': self.use_pipeline, # pipeline option can affect the graph generation. 'end2end_mode': self.use_end2end, # end2end_mode can affect the graph generation. @@ -351,12 +343,12 @@ def _get_full_qualified_name(obj: Any) -> str: return obj.__module__ + '.' + obj.__class__.__qualname__ -def _add_cube_savedir_to_syspath(cube_savedir: str) -> Path: - cube_savedir = Path(cube_savedir).resolve() - cube_savedir.mkdir(parents=True, exist_ok=True) - if str(cube_savedir) not in sys.path: - sys.path.append(str(cube_savedir)) - return cube_savedir +def _add_gen_savedir_to_syspath(gen_savedir: str) -> Path: + gen_savedir = Path(gen_savedir).resolve() + gen_savedir.mkdir(parents=True, exist_ok=True) + if str(gen_savedir) not in sys.path: + sys.path.append(str(gen_savedir)) + return gen_savedir def _is_any_gencode_loaded(namespace: str) -> bool: @@ -395,7 +387,7 @@ def _broadcast_single_value(src_rank, group, obj=None): _DEFAULT_INSTANCE_NAME = '_' _GENCODE_FILE_PREFIX = 'gencode' _GENCODE_FILE_TEMPLATE = _GENCODE_FILE_PREFIX + '{}.py' # 'gencode{}.py' -_CUBE_MODULE_NAMESPACE = '_cube_modules' +_PARALLEL_MODULE_NAMESPACE = '_parallel_modules' _GRAPH_DUMP_FILE = 'graph.ckp' _FORWARD_ARGS_DUMP_FILE = 'forward_args.pkl' @@ -453,25 +445,25 @@ class RegenStatus(Enum): def _prepare_namespace( - cube_savedir: str, + gen_savedir: str, module_or_module_class: Union[Type[torch.nn.Module], torch.nn.Module], instance_name: Optional[str] = None, ) -> Tuple[str, Path]: - cube_savedir = _add_cube_savedir_to_syspath(cube_savedir) + gen_savedir = _add_gen_savedir_to_syspath(gen_savedir) instance_name = instance_name or _DEFAULT_INSTANCE_NAME instance_name = instance_name.strip('.') if instance_name else '' instance_namespace = f'.{instance_name}' if instance_name else '' - namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_or_module_class)}{instance_namespace}' + namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_or_module_class)}{instance_namespace}' - outdir = cube_savedir / Path(namespace.replace('.', '/').strip('/')) + outdir = gen_savedir / Path(namespace.replace('.', '/').strip('/')) outdir.mkdir(parents=True, exist_ok=True) return namespace, outdir def _prepare_and_check_reusable( - cube_savedir: str, + gen_savedir: str, module_or_module_class: Union[Type[torch.nn.Module], torch.nn.Module], compute_config: ComputeConfig, instance_name: Optional[str] = None, @@ -481,7 +473,7 @@ def _prepare_and_check_reusable( Prepare the output directory for code generation, and also check if the existing code is reusable. Args: - cube_savedir (str): the directory to save generated code + gen_savedir (str): the directory to save generated code module_or_module_class (Union[Type[torch.nn.Module], torch.nn.Module]): the original module or module class compute_config (ComputeConfig): the environment resource instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. @@ -494,7 +486,7 @@ def _prepare_and_check_reusable( RuntimeError: if the existing code is not reusable, will raise RuntimeError if the code is not reusable but the module is already loaded. """ - namespace, outdir = _prepare_namespace(cube_savedir, module_or_module_class, instance_name) + namespace, outdir = _prepare_namespace(gen_savedir, module_or_module_class, instance_name) # decision matrix for code generation # reuse flag | dir condition(imported, empty, match, unmatched) | action @@ -702,14 +694,14 @@ def _gencode( module_fn: Optional[Callable[[], torch.nn.Module]] = None, ) -> RegenStatus: """ - Generate cube module source code from a torch module, and save it to file. + Generate parallel module source code from a torch module, and save it to file. Generated module will be save according to its full qualified name. If you want to save multiple instances of the same module, you can specify the instance_name to distingish them. For example, if the module is `torchscale.x.y`, then the generated module will be save to - `cube_savedir/_cube_modules/torchscale/x/y/instance_name`. + `gen_savedir/_parallel_modules/torchscale/x/y/instance_name`. Args: module (torch.nn.Module): the module to be compiled @@ -750,7 +742,7 @@ def _gencode( module = module.to(dtype=module_dtype) if any(isinstance(m, CubeModule) for m in module.modules()): - raise RuntimeError('CubeModule can not be nested.') + raise RuntimeError('Parallel modules can not be nested.') # save origin module metadata meta_info = OriginModuleMetadata( @@ -835,22 +827,22 @@ def _gencode( return ret -def _load_cube_module_class( +def _load_parallel_module_class( module_class: Type[torch.nn.Module], *, - cube_savedir: Union[str, Path] = './.cube', + gen_savedir: Union[str, Path] = './.nnscaler', instance_name: Optional[str] = None, rank: Optional[int] = None, ) -> Type[ParallelModule]: """ - Load the generated cube module class, with train_step and infer_step assigned as member function.. + Load the generated parallel module class, with train_step and infer_step assigned as member function.. - Please note that the cube module class should be generated beforehand by _gencode(). + Please note that the parallel module class should be generated beforehand by _gencode(). Args: module_class (Type[torch.nn.Module]): the original module class - cube_savedir (Union[str, Path]): the directory to load generated code + gen_savedir (Union[str, Path]): the directory to load generated code instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. rank (Optional[int]): the rank of the module. If it is None, will get the rank from torch.distributed.get_rank(). This option is only useful for debugging or writing pre/post-processing tools. @@ -859,20 +851,20 @@ def _load_cube_module_class( Type[ParallelModule]: the generated module class """ rank = torch.distributed.get_rank() if rank is None else rank - namespace, _ = _prepare_namespace(cube_savedir, module_class, instance_name) + namespace, _ = _prepare_namespace(gen_savedir, module_class, instance_name) gen_imported = importlib.import_module( f'{namespace}.{Path(_GENCODE_FILE_TEMPLATE.format(rank)).stem}' ) - cube_module_class = gen_imported.GenModel + parallel_module_class = gen_imported.GenModel # rewrite class name and module name - cube_module_class.__name__ = module_class.__name__ - cube_module_class.__qualname__ = module_class.__qualname__ - # cube_module_class.__module__ = module_class.__module__ - cube_module_class.__orig_module_class__ = module_class # save the original module class + parallel_module_class.__name__ = module_class.__name__ + parallel_module_class.__qualname__ = module_class.__qualname__ + # parallel_module_class.__module__ = module_class.__module__ + parallel_module_class.__orig_module_class__ = module_class # save the original module class # override train_step and infer_step only if they are defined in the generated module (end2end module only) - cube_module_class._train_step = getattr(gen_imported, '_train_step', cube_module_class._train_step) - cube_module_class._infer_step = getattr(gen_imported, '_infer_step', cube_module_class._infer_step) - return cube_module_class + parallel_module_class._train_step = getattr(gen_imported, '_train_step', parallel_module_class._train_step) + parallel_module_class._infer_step = getattr(gen_imported, '_infer_step', parallel_module_class._infer_step) + return parallel_module_class def parallelize( @@ -881,7 +873,7 @@ def parallelize( pas_policy: Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]], compute_config: ComputeConfig, *, - cube_savedir: Union[str, Path] = './.cube', + gen_savedir: Union[str, Path] = './.nnscaler', reuse: Union[ReuseType, str] = ReuseType.MATCH, instance_name: Optional[str] = None, load_module: bool = True, @@ -891,7 +883,7 @@ def parallelize( broadcast_strategy: Union[str, BroadcastGenFilesStrategy] = 'none', ) -> Union[None, ParallelModule, Type[ParallelModule]]: """ - Convert a torch.nn.Module object or class to CubeModule object or class. + Convert a torch.nn.Module object or class to ParallelModule object or class. If you want to save multiple instances of the same module, you can specify the instance_name to distinguish them. @@ -907,19 +899,19 @@ def parallelize( * The module object will be copied to cpu to handle possible insufficient gpu memory. * The training flag will be the same as the original module - This function can be used to convert both module object and module class to cube module or cube module class. + This function can be used to convert both module object and module class to parallel module or parallel module class. Among key-value arguments, module_fn and module_dtype control how to create the module object. - whereas init_module_params controls how to load cube module object after conversion is done. + whereas init_module_params controls how to load parallel module object after conversion is done. - 1. If the input is a module object, it will return a CubeModule object if load_module is True. + 1. If the input is a module object, it will return a ParallelModule object if load_module is True. This is useful when the module is created by a factory function. a. module_fn is ignored. b. module_dtype is used to control the dtype of the input module. - c. init_module_params is used to control whether to initialize the cube module parameters when load it. + c. init_module_params is used to control whether to initialize the parallel module parameters when load it. - 2. If the input is a module class, it will return a CubeModule class if load_module is True. + 2. If the input is a module class, it will return a ParallelModule sub class if load_module is True. a. module_fn is used to create the module object, or module's__init__ if not prent. b. module_dtype is used to control the dtype of the created module (by constructor or module_fn). @@ -951,7 +943,7 @@ def __init__(self, init_params=True): it can be a name of builtin policies, or a custom policy function. compute_config (ComputeConfig): the environment resource reuse (ReuseType): specify which part can be reused. - cube_savedir (Union[str, Path]): the directory to save generated code + gen_savedir (Union[str, Path]): the directory to save generated code instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. load_module (bool): whether to load the generated module or module class after conversion is done. init_module_params (bool): If true, when we construct the module, all its parameters are initialized with the same value with when we traced. @@ -964,15 +956,22 @@ def __init__(self, init_params=True): Please note that the broadcasting will only be done in torchrun environment, and will throw an error if torch.distributed is not initialized and broadcast_strategy is not NONE. Returns: - Union[CubeModule, Type[CubeModule], None]: - if load_module flag is set, return the converted CubeModule object or class + Union[ParallelModule, Type[ParallelModule], None]: + if load_module flag is set, return the converted ParallelModule object or class if load_module flag is not set, return None """ + if ( + isinstance(module_or_module_class, ParallelModule) or + (inspect.isclass(module_or_module_class) and issubclass(module_or_module_class, ParallelModule)) + ): + # already done + return module_or_module_class if load_module else None + if ( isinstance(module_or_module_class, CubeModule) or (inspect.isclass(module_or_module_class) and issubclass(module_or_module_class, CubeModule)) ): - return module_or_module_class if load_module else None + raise RuntimeError("Old style CubeModule is not supported") if isinstance(pas_policy, str): if not pas_policy in _PREDEFINED_POLICIES: @@ -993,7 +992,7 @@ def __init__(self, init_params=True): # generate code only in node0 # if it is not in a torchrun environment, just generate. if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - outdir, reusable = _prepare_and_check_reusable(cube_savedir, module_class, compute_config, instance_name, reuse) + outdir, reusable = _prepare_and_check_reusable(gen_savedir, module_class, compute_config, instance_name, reuse) if not reusable: config_file = outdir / ParallelModule.COMPUTE_CONFIG_FILE ComputeConfig.safe_dump_to_file(compute_config, config_file) # always refresh compute config @@ -1051,26 +1050,26 @@ def __init__(self, init_params=True): if broadcast_strategy != BroadcastGenFilesStrategy.NONE: _broadcast_gen_files( module_class, - cube_savedir=cube_savedir, + gen_savedir=gen_savedir, instance_name=instance_name, broadcast_strategy=broadcast_strategy, ) if load_module: if not torch.distributed.is_initialized(): # we only support loading in torchrun environment - raise RuntimeError("Load CubeModule failed: torch.distributed is not initialized.") + raise RuntimeError("Load ParallelModule failed: torch.distributed is not initialized.") torch.distributed.barrier() - cube_module_class = _load_cube_module_class( + parallel_module_class = _load_parallel_module_class( module_class, - cube_savedir=cube_savedir, + gen_savedir=gen_savedir, instance_name=instance_name, ) if is_module_class: - return cube_module_class + return parallel_module_class else: - cube_module = cube_module_class(init_module_params) - cube_module.train(module_or_module_class.training) # set training state to the same as original module - return cube_module + parallel_module = parallel_module_class(init_module_params) + parallel_module.train(module_or_module_class.training) # set training state to the same as original module + return parallel_module @dataclass(unsafe_hash=True) @@ -1208,17 +1207,17 @@ def build_optimizer( """ Build an optimizer for a module. - To support parallelized module (CubeModule), we hook 4 places in this function: + To support parallelized module (ParallelModule), we hook 4 places in this function: 1. optimizer constructor: the parameters of optimizer will not be the same with the parameters of the module if we use zero - so we need to replace the parameters of optimizer with CubeModule.parameters_for_optimizer + so we need to replace the parameters of optimizer with ParallelModule.parameters_for_optimizer It is impossible to make this change transparent to end users. 2. optimizer.step(): we need to call optimizer.sync_shard_grad() to sync the gradients of the module before optimizer.step(). - In zero mode, we have to call CubeModule.gather_params() after optimizer.step() + In zero mode, we have to call ParallelModule.gather_params() after optimizer.step() 3. optimizer.zero_grad(): - We need to call CubeModule.zero_grad() after optimizer.zero_grad() + We need to call ParallelModule.zero_grad() after optimizer.zero_grad() 4. backward(): you need to call optimizer.sync_shard_grad() manually if you want to read the gradients of the module before optimizer.step(). @@ -1276,24 +1275,24 @@ def build_optimizer( opt_module_locs: Dict[str, ModuleParameterLocation] = {} def _local_parameters(module: torch.nn.Module): - cube_suffix = "_CUBE_SUFFIX" + pm_suffix = "_PARALLEL_MODULE_PARAM_SUFFIX" gen = module._named_members( lambda m: [ - (cube_suffix, p) # (cube_suffix, p) to meet _named_members requirement + (pm_suffix, p) # (pm_suffix, p) to meet _named_members requirement for p in ( m.parameters_for_optimizer() if m.compute_config.use_zero - else m.parameters() # `CubeModule.merge_partial_states` supports parameters_for_optimizer() only in zero mode + else m.parameters() # `ParallelModule.merge_partial_states` supports parameters_for_optimizer() only in zero mode ) ] if isinstance(m, ParallelModule) else m._parameters.items() ) for idx, (name, param) in enumerate(gen): - if name.endswith(cube_suffix): # is a parameter of ParallelModule + if name.endswith(pm_suffix): # is a parameter of ParallelModule # -1 for removing the dot # please note when the whole module is a ParallelModule, # the name will be empty after removing the suffix - name = name[:-len(cube_suffix) - 1] + name = name[:-len(pm_suffix) - 1] if name not in opt_module_locs: opt_module_locs[name] = ModuleParameterLocation(idx, 1) else: @@ -1656,7 +1655,7 @@ def merge_state_dicts( pm_orig_param_names: Dict[str, List[str]] = {} for k, extra_states in pm_extra_states.items(): module_prefix = '.'.join(k) - pm_orig_param_names[module_prefix] = CubeModule.get_origin_parameter_names([e.param_area_map for e in extra_states]) + pm_orig_param_names[module_prefix] = ParallelModule.get_origin_parameter_names([e.param_area_map for e in extra_states]) # now we can construct the merged state of optimizer from any rank # as said previously, the merge will be based on rank0's data orig_states: Dict[int, Any] = optimizer_state_dicts[0]['state'] @@ -1978,7 +1977,7 @@ def _get_valid_name_from_merged_model( def _broadcast_gen_files( module_class: Type[torch.nn.Module], *, - cube_savedir: Union[str, Path] = './.cube', + gen_savedir: Union[str, Path] = './.nnscaler', instance_name: Optional[str] = None, broadcast_strategy: Union[str, BroadcastGenFilesStrategy], ): @@ -1987,7 +1986,7 @@ def _broadcast_gen_files( Args: module_class (Type[torch.nn.Module]): the original torch module class - cube_savedir (Union[str, Path]): the directory to save generated code + gen_savedir (Union[str, Path]): the directory to save generated code instance_name (Optional[str]): the instance name of the generated module. If it is None, will use the default name. broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for generated files. @@ -2014,7 +2013,7 @@ def _broadcast_gen_files( # use the first rank of each node to broadcast if curr_rank % local_world_size == 0: - _, outdir = _prepare_namespace(cube_savedir, module_class, instance_name) + _, outdir = _prepare_namespace(gen_savedir, module_class, instance_name) files: List[str] = [] # send file list if curr_rank == 0: diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 1aaf6ff3..37418dda 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -200,7 +200,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: pas_cfg = cfg.pas_config # required parameters - update_freq = pas_cfg['update_freq'] + update_freq = pas_cfg.get('update_freq', 1) if isinstance(update_freq, (tuple, list)): update_freq = update_freq[0] if cfg.use_pipeline and update_freq != cfg.pipeline_nmicros: diff --git a/tests/graph/function/test_dict_values.py b/tests/graph/function/test_dict_values.py index 79a2cceb..193f329f 100644 --- a/tests/graph/function/test_dict_values.py +++ b/tests/graph/function/test_dict_values.py @@ -8,7 +8,7 @@ class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() - + def forward(self, x): k = list(x.keys())[0] v = x[k] @@ -25,6 +25,6 @@ def test_script_func(): {'x': {'a': torch.rand(10)}}, 'tp', ComputeConfig(2, 2), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) diff --git a/tests/graph/function/test_script_func.py b/tests/graph/function/test_script_func.py index b2e1460f..f5d6ee4e 100644 --- a/tests/graph/function/test_script_func.py +++ b/tests/graph/function/test_script_func.py @@ -22,6 +22,6 @@ def test_script_func(): {'a': torch.rand(10), 'b': torch.rand(10)}, 'tp', ComputeConfig(2, 2), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index f7a3a36e..a732f860 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -30,7 +30,7 @@ def _to_cube_model(module, compute_config, cube_savedir, {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, 'tp', compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name, load_module=load_module, broadcast_strategy=broadcast_strategy, diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 78dfb28d..103c4d10 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -14,7 +14,7 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts, UserConfig +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm @@ -50,7 +50,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name, dum dummy_input if dummy_input is not None else {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) @@ -99,7 +99,7 @@ def to_pipeline_module(cls, compute_config: ComputeConfig, cube_savedir, {'data': pipeline_dummy_data()}, PASMegatron, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index 8aec6d4b..b304dc92 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -51,7 +51,7 @@ def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_sh {'x': torch.randn(input_shape)}, 'tp', compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name, init_module_params=init_module_params ) diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index 2b8a0ab1..ee7e9edd 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -42,7 +42,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index 0165f024..afcc9517 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -41,7 +41,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py index 8e90041b..d49a5dc8 100644 --- a/tests/parallel_module/test_checkpoint_unused.py +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -51,7 +51,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index ec256176..57509ea6 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -46,7 +46,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index 86f94370..de18cbc9 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -120,7 +120,7 @@ def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, use_pipeline, nstages=Non use_pipeline=use_pipeline, pipeline_nmicros=nmicros, pipeline_nstages=nstages, pipeline_scheduler=pipeline_scheduler ), - cube_savedir=tempdir + gen_savedir=tempdir ) model.cuda() train_result = _train_cube(model, nmicros, runtime_ngpus // plan_ngpus, torch.distributed.get_rank() // plan_ngpus) @@ -330,7 +330,7 @@ def gpu_worker_cube_one_sample(): use_pipeline=True, pipeline_nmicros=2, pipeline_nstages=2, pipeline_scheduler='1f1b' ), - cube_savedir=tempdir + gen_savedir=tempdir ) model.cuda() train_result = _train_cube_one_sample(model, 2) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 325c90d9..c455f6db 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -17,7 +17,7 @@ def _to_cube_model(module, compute_config, cube_savedir, load_module): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'data', compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, load_module=load_module ) @@ -63,7 +63,7 @@ def test_codegen_slice(): {'x': torch.tensor([1.0, 2.0, 3.0, 6.0])}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) assert m_new is None @@ -91,7 +91,7 @@ def test_codegen_args(): }, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=True ) @@ -118,7 +118,7 @@ def _gencode_unused_args_worker(tempdir): }, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=True ) assert m_new is not None @@ -170,7 +170,7 @@ def _gencode_unused_args_worker2(tempdir): }, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=True ) assert m_new is not None @@ -215,7 +215,7 @@ def test_codegen_default_args(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) # parallelize will succeed. @@ -231,10 +231,10 @@ def forward(self, x, attr): def _gencode_contains(cubesave_dir, module_class, index, search_re): - from nnscaler.parallel import _CUBE_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME + from nnscaler.parallel import _PARALLEL_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME from pathlib import Path import re - namespace = f'{_CUBE_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' + namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) filecontent = (outdir /f'gencode{index}.py').read_text() matches = re.findall(search_re, filecontent) @@ -253,7 +253,7 @@ def test_codegen_attr(): {'x': torch.tensor([1.0, 2.0, 3.0, 6.0]), 'attr': AttrHelper()}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) # in old version, all 'forward' functions will patched to a function named 'new_func' @@ -285,7 +285,7 @@ def test_codegen_getitem(): {'batched_data': {'x': torch.tensor([[[1.0], [2.0], [3.0], [6.0]]])}}, 'tp', ComputeConfig(2, 2), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False, ) assert _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') @@ -315,7 +315,7 @@ def test_codegen_training_flag(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) @@ -374,7 +374,7 @@ def test_codegen_iter(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) @@ -404,7 +404,7 @@ def test_codegen_const(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) assert not _gencode_contains(tempdir, ConstantModule, 0, r'\s+5 = builtins.int') @@ -443,7 +443,7 @@ def test_codegen_tensor_slice(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False, reuse='override', ) @@ -454,7 +454,7 @@ def test_codegen_tensor_slice(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False, reuse='override', ) @@ -481,7 +481,7 @@ def test_codegen_dictget(): }}, 'tp', ComputeConfig(2, 2), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False, ) assert _gencode_contains(tempdir, DictGetModule, 0, r"dict.get\(\w+, 'y', \w+\)") @@ -526,7 +526,7 @@ def _gencode_min_function_worker(tempdir): }, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=True ) assert m_new is not None @@ -562,7 +562,7 @@ def _gencode_max_function(tempdir): }, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=True ) assert m_new is not None @@ -603,7 +603,7 @@ def test_codegen_shared_parameter(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False, reuse='override', ) @@ -640,7 +640,7 @@ def test_codegen_buffer(): {'x': torch.randn(128, 64)}, 'dp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False, reuse='override', ) @@ -684,7 +684,7 @@ def test_codegen_inference(): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, 'dp', ComputeConfig(1, 1, inference_only=True), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) assert _gencode_contains(tempdir, Module0, 0, @@ -721,7 +721,7 @@ def p(cube_dir, use_pipeline, dynamic_shape, return_type, inference_only=False): pipeline_nstages=4, pipeline_scheduler='infer_pipe' if inference_only else '1f1b' ), - cube_savedir=cube_dir, + gen_savedir=cube_dir, load_module=False, reuse='override', ) diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 417eb17a..5dfb2de9 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -50,7 +50,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 6340850b..91322dc8 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -3,7 +3,7 @@ import torch -from nnscaler.parallel import _load_cube_module_class, parallelize, ComputeConfig +from nnscaler.parallel import _load_parallel_module_class, parallelize, ComputeConfig from ..launch_torchrun import launch_torchrun from .common import CubeLinear, init_distributed, init_random, clear_dir_on_rank0 @@ -26,7 +26,7 @@ def _init_params_worker(): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, 'tp', ComputeConfig(1, 1), - cube_savedir=tempdir, + gen_savedir=tempdir, reuse='match', ) module1 = cube_module() @@ -69,12 +69,12 @@ def test_empty_weights(model_class, tp): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, 'tp', ComputeConfig(2, 4, use_zero=True, zero_ngroups=2), - cube_savedir=tempdir, + gen_savedir=tempdir, reuse='match', load_module=False, ) for i in range(4): - module_class = _load_cube_module_class(model_class, cube_savedir=tempdir, rank=i) + module_class = _load_parallel_module_class(model_class, gen_savedir=tempdir, rank=i) m = new_empty(module_class) assert m.rank == i for p in m.parameters(): diff --git a/tests/parallel_module/test_line_timer.py b/tests/parallel_module/test_line_timer.py index bc92dfc0..44c865f9 100644 --- a/tests/parallel_module/test_line_timer.py +++ b/tests/parallel_module/test_line_timer.py @@ -29,7 +29,7 @@ def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_sh {'x': torch.randn(input_shape)}, 'tp', compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name, init_module_params=init_module_params ) diff --git a/tests/parallel_module/test_nested.py b/tests/parallel_module/test_nested.py index 27018788..1c215949 100644 --- a/tests/parallel_module/test_nested.py +++ b/tests/parallel_module/test_nested.py @@ -14,7 +14,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir): {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}, pas, compute_config, - cube_savedir=cube_savedir + gen_savedir=cube_savedir ) class Module0(torch.nn.Module): @@ -44,7 +44,7 @@ def __init__(self) -> None: def forward(self, x): return self.module1(x) - with pytest.raises(RuntimeError, match='CubeModule can not be nested.'): + with pytest.raises(RuntimeError, match='Parallel modules can not be nested.'): _to_cube_model(Module2(), 'data', ComputeConfig(1, 1), cube_savedir=tempdir) diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index be07b2dc..c8b9dc8f 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -7,7 +7,7 @@ import shutil from nnscaler.graph.parser.fx.parser import FxModuleParser -from nnscaler.parallel import ReuseType, parallelize, ComputeConfig, _load_cube_module_class +from nnscaler.parallel import ReuseType, parallelize, ComputeConfig, _load_parallel_module_class from nnscaler.runtime.module import ParallelModule from ..utils import new_empty, replace_all_device_with @@ -20,14 +20,14 @@ def _to_cube_model(model_class, compute_config, cube_savedir, reuse, instance_na 'data', compute_config, reuse=reuse, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name, load_module=False, ) if load_module: - module_class = _load_cube_module_class( + module_class = _load_parallel_module_class( model_class, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name, rank=0 ) diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index fc6e2234..2c145c3d 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -34,7 +34,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py index 109bd6b0..0af1a934 100644 --- a/tests/parallel_module/test_scale_grads.py +++ b/tests/parallel_module/test_scale_grads.py @@ -46,7 +46,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 45d7584a..40e76df3 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -42,7 +42,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index bb9b44f1..84c4d886 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -42,7 +42,7 @@ def _to_cube_model(module, pas, compute_config, cube_savedir, instance_name): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, compute_config, - cube_savedir=cube_savedir, + gen_savedir=cube_savedir, instance_name=instance_name ) diff --git a/tests/test_policies.py b/tests/test_policies.py index 230aa58d..5f56c24a 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn -from nnscaler.parallel import ComputeConfig, UserConfig, parallelize +from nnscaler.parallel import ComputeConfig, parallelize from .utils import init_random @@ -51,7 +51,7 @@ def test_autodist(): 'update_freq': 1, 'task_name': 'test_autodist', }), - cube_savedir=tempdir, + gen_savedir=tempdir, load_module=False ) assert m_new is None From 73710215efd4596086a83ee1f6bc03b88e3f45c2 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 4 Jun 2024 05:00:22 +0000 Subject: [PATCH 1650/1892] Merged PR 2159: Integrate zigzag attention in nnscaler NOTE: this PR is a temporary solution for customized functions that have intra communications. The inside communication cost is not considered in the profiler, leading to the sub-optimal generated plan. end2end parity verified on YOCO-3B, 4XA6000 --- examples/zigzag_ring_attention/README.md | 18 + .../zigzag_ring_attention/test_zigzag_attn.py | 136 +++++ examples/zigzag_ring_attention/zigzag_attn.py | 102 ++++ .../zigzag_attn_implementation.py | 482 ++++++++++++++++++ .../zigzag_utils/zigzag_utils.py | 207 ++++++++ nnscaler/codegen/emit.py | 14 +- nnscaler/codegen/frontend_mapping.py | 14 +- nnscaler/codegen/module/module.py | 16 +- nnscaler/graph/parser/register.py | 24 +- tests/codegen/test_emit.py | 2 +- tests/graph/parser/test_register.py | 31 ++ 11 files changed, 1025 insertions(+), 21 deletions(-) create mode 100644 examples/zigzag_ring_attention/README.md create mode 100644 examples/zigzag_ring_attention/test_zigzag_attn.py create mode 100644 examples/zigzag_ring_attention/zigzag_attn.py create mode 100644 examples/zigzag_ring_attention/zigzag_utils/zigzag_attn_implementation.py create mode 100644 examples/zigzag_ring_attention/zigzag_utils/zigzag_utils.py diff --git a/examples/zigzag_ring_attention/README.md b/examples/zigzag_ring_attention/README.md new file mode 100644 index 00000000..f7cf2b28 --- /dev/null +++ b/examples/zigzag_ring_attention/README.md @@ -0,0 +1,18 @@ +# zigzag ring attention + +Tensor parallel (partition head) is a widely used distributed plan to train large language models. Computation and memory are +distributed evenly across devices. However, when the sequence length is extremely long (e.g., 1M), the partition degree of +tensor parallel is constrained by the number of kv heads, which means that the maximum number of devices in a data parallel +unit is no more than the number of kv heads. As a result, tensor parallel fails to scale a model with long sequence length. + +[ring attention](https://arxiv.org/abs/2310.01889) is proposed to address this issue. It partitions q, k and v along the +sequence dimension and passes the partitioned q, k and v through a ring of devices. [ring flash attention](https://github.com/zhuzilin/ring-flash-attention) +implements a high-performance version in PyTorch. This example attempts to integrate the causal version of ring attention +(zigzag ring attention) into nnScaler. + +The interface is wrapped in `zigzag_attn.py`. [flash attention](https://github.com/Dao-AILab/flash-attention) is required for this example. + +Test can be run with the following command: +```bash +torchrun --nproc_per_node 4 test_zigzag_attn.py +``` \ No newline at end of file diff --git a/examples/zigzag_ring_attention/test_zigzag_attn.py b/examples/zigzag_ring_attention/test_zigzag_attn.py new file mode 100644 index 00000000..62ae13dd --- /dev/null +++ b/examples/zigzag_ring_attention/test_zigzag_attn.py @@ -0,0 +1,136 @@ +import torch +import nnscaler +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType +import torch.distributed as dist +from flash_attn import flash_attn_func + +import nnscaler.graph +import nnscaler.graph.function +from examples.zigzag_ring_attention.zigzag_attn import wrap_zigzag_attn_func + +import random + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, _in0, _in1, _in2): + out = wrap_zigzag_attn_func(_in0, _in1, _in2) + return out + +def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: + ngpus = resource.plan_ngpus + partitioned = False + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == 'examples.zigzag_ring_attention.zigzag_attn.wrap_zigzag_attn_func': + print('Partitioned node: ', node) + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=0, dim=1, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + assert partitioned, f'expect zigzag_attn_func in graph, but not found.' + return graph + +if __name__ == "__main__": + nnscaler.init() + rank_id = torch.distributed.get_rank() + world_size = dist.get_world_size() + + set_seed(rank_id) + bsz = 1 + seqlen = 3824 + nheads = 5 + d = 128 + + device = torch.device(f"cuda:{rank_id}") + dtype = torch.float16 + # dtype = torch.bfloat16 + + q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.barrier() + + single_out = wrap_zigzag_attn_func(q, k, v) + single_out.retain_grad() + single_loss = single_out.sum() + single_loss.backward() + + model = TestModule() + + _in0 = q.detach().clone().requires_grad_() + _in1 = k.detach().clone().requires_grad_() + _in2 = v.detach().clone().requires_grad_() + + parallel_model = parallelize(model, dummy_input={"_in0": _in0, "_in1": _in1, "_in2": _in2}, pas_policy=policy, + compute_config=ComputeConfig(world_size, world_size), reuse=ReuseType.OVERRIDE) + parallel_model = parallel_model.cuda() + + + parallel_model.train() + + _in0 = q.detach().clone().requires_grad_() + _in1 = k.detach().clone().requires_grad_() + _in2 = v.detach().clone().requires_grad_() + + para_out = parallel_model(_in0, _in1, _in2) + para_loss = para_out.sum() + para_loss.backward() + parallel_model.sync_grad() + + log("single out", single_out, rank0_only=True) + log("multi out", para_out, rank0_only=True) + log("out diff", single_out - para_out, rank0_only=True) + + log("single dq", q.grad, rank0_only=True) + log("multi dq", _in0.grad, rank0_only=True) + log("dq diff", q.grad - _in0.grad, rank0_only=True) + + log("single dk", k.grad, rank0_only=True) + log("multi dk", _in1.grad, rank0_only=True) + log("dk diff", k.grad - _in1.grad, rank0_only=True) + + log("single dv", v.grad, rank0_only=True) + log("multi dv", _in2.grad, rank0_only=True) + log("dv diff", v.grad - _in2.grad, rank0_only=True) diff --git a/examples/zigzag_ring_attention/zigzag_attn.py b/examples/zigzag_ring_attention/zigzag_attn.py new file mode 100644 index 00000000..4e2a7bac --- /dev/null +++ b/examples/zigzag_ring_attention/zigzag_attn.py @@ -0,0 +1,102 @@ +from typing import Tuple, List, Dict +import torch +from torch import Tensor +import torch.distributed + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir.operator import IRFwOperation +from examples.zigzag_ring_attention.zigzag_utils.zigzag_attn_implementation import ZigZagRingFlashAttnFunc +from flash_attn import flash_attn_func + +import torch.distributed as dist +from nnscaler.runtime.device import DeviceGroup + +def wrap_zigzag_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, + dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None) -> Tensor: + ''' + wrap the zigzag_attn_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + return output + + assert causal == True, "zigzag_ring is meaningless for causal=False" + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + + output = ZigZagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ).contiguous() + + return output + +def emit_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +register_op('bs l h dim^, bs l h dim^, bs l h dim^ -> bs l h dim^', emit_fn=emit_zigzag)(wrap_zigzag_attn_func) diff --git a/examples/zigzag_ring_attention/zigzag_utils/zigzag_attn_implementation.py b/examples/zigzag_ring_attention/zigzag_utils/zigzag_attn_implementation.py new file mode 100644 index 00000000..1e473301 --- /dev/null +++ b/examples/zigzag_ring_attention/zigzag_utils/zigzag_attn_implementation.py @@ -0,0 +1,482 @@ +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from .zigzag_utils import RingComm, update_out_and_lse, shuffle_input, recover_output + +''' +Assume we have 4 GPUs A, B, C, D. +The sequence is represented as [0 1 2 3 4 5 6 7]. + +The P2P communication ring is A -> D -> B -> C -> A +The initial status of the attention computation is +X +X X +X X X +X X X X +X X X X X +X X X X X X +X X X X X X X +X X X X X X X X +Note: +- the computation in the diagonal is `causal=True` +- the computation in the off-diagonal is `causal=False` +We consider a `X` with `causal=True` as a unit computation block. +In this example, there are 4 steps. Each device is responsible for 2 unit computation blocks in each step. + +q status is same across all steps (q is not transmitted): +GPU A: [0 7] +GPU B: [2 5] +GPU C: [3 4] +GPU D: [1 6] + +Step 0, kv status: +GPU A: [0 7] +GPU B: [2 5] +GPU C: [3 4] +GPU D: [1 6] +Computation status: +A +X D +X X B +X X X C +X X X C C +X X B X X B +X D X X X X D +A X X X X X X A + +Step 1, kv status: +GPU A: [3 4] +GPU B: [1 6] +GPU C: [2 5] +GPU D: [0 7] +Computation status: +X +D X +X B X +X X C X +X X C X X +X B X X X X +D X X X X X X +X X X A A X X X + +Step 2, kv status: +GPU A: [2 5] +GPU B: [0 7] +GPU C: [1 6] +GPU D: [3 4] +Computation status: +X +X X +B X X +X C X X +X C X X X +B X X X X X +X X X D D X X +X X A X X A X X + +Step 3, kv status: +GPU A: [1 6] +GPU B: [3 4] +GPU C: [0 7] +GPU D: [2 5] +Computation status: +X +X X +X X X +C X X X +C X X X X +X X X B B X +X X D X X D X +X A X X X X A X + +From this example, we can conclude the key insight of zigzag ring flash attention is: +- split the sequence into fine-grained blocks to achieve balance across steps and gpus +- schedule the computation in a zigzag pattern to minimize the communication overhead + +To be more specific, if the sequence length is L=4n, the total computation cost of flash attention +with causal=True is 1/2 L^2 = 8n^2. Each device needs to compute 4n. Each step needs to compute 2. + +Computation task assigned for each GPU: + +GPU 0: (0, 4n-1) +GPU 1: (2, 4n-3) +... +GPU n-1: (2n-2, 2n+1) +GPU n: (2n-1, 2n) +GPU n+1: (2n-3, 2n+2) +... +GPU 2n-1: (1, 4n-2) + +Dependence of kv (required kv range) for each device: +GPU 0: [0, 4n-1] +GPU 1: [0, 4n-3] +... +GPU n-1: [0, 2n+1] +GPU n: [0, 2n] +GPU n+1: [0, 2n+2] +... +GPU 2n-1: [0, 4n-2] + +In general, if there are 2n GPUs, the ring is 0 -> 2n-1 -> 1 -> 2n-2 -> ... -> n -> n+1 -> 0 + +For each device, the 2n steps is divided into 3 parts: +1. compute the local attention with `causal=True` +2. if current step is less or equal to its relative rank in the ring, select the first half + of the received kv to compute the attention with `causal=False`. In the example above, each + device computes to `left` of its corresponding rows in the status matrix. +3. if current step is greater than its relative rank in the ring, select the second half of + local q and full received kv to compute the attention with `causal=False`. In the example + above, each device fills the remaining part of its lower row in the status matrix. +''' + +def zigzag_ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group) + + block_seq_len = q.shape[1] // 2 + q1 = q[:, block_seq_len:] + + out = None + lse = None + next_k, next_v = None, None + + def forward(q, k, v, causal): + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + return block_out, block_lse + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step == 0: + block_out, block_lse = forward(q, k, v, causal=True) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + block_out, block_lse = forward(q, k0, v0, causal=False) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + out, lse = update_out_and_lse( + out, + lse, + block_out, + block_lse, + slice_=(slice(None), slice(block_seq_len, None)), + ) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +''' +In the backward pass, we assume q, k, v and out are saved in the shuffled order. +In addition, the backward pass requires a shuffled dout as input and generates +a shuffled dq, dk, dv as output. Note that out is a sum of all step outputs, so +we can directly pass dout to each step's backward block to compute the local gradient +according to the differiential chain rule. + +Similar to the forward pass, in the backward pass, the 2n steps are divided into 3 parts. + +Different from the forward pass, we need to communicate the gradient of kv in a ring as well. +To be more specific, each device calculates the local gradients of dq, dk, dv. In the following +steps, dq will be accumulated in the initial device, while dk and dv will be transmitted to the +next consumer device, then accumulated in the consumer device. In the end, the dk and dv will be +transmitted back to the initial device. + +In addition, to be compatible with the flash-attn's interface and reduce the precision loss, +we will accumulate and transmit the gradients in float32. They will be converted back to the +original dtype at the end of the backward pass. +''' +def zigzag_ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=1)[1] + q1 = q.chunk(2, dim=1)[1] + out1 = out.chunk(2, dim=1)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() + block_seq_len = q.shape[1] // 2 + + # repeatly allocating buffer may be slow... + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout, q, k, v, out, softmax_lse, causal): + seqlen_q = q.shape[1] + seqlen_kv = k.shape[1] + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq_buffer[:, :seqlen_q], + dk_buffer[:, :seqlen_kv], + dv_buffer[:, :seqlen_kv], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + backward(dout, q, k0, v0, out, softmax_lse, causal=False) + dq += dq_buffer + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) + # always use the first half in dq_buffer. + dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.revert_rank: + dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] + dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] + else: + dk += dk_buffer + dv += dv_buffer + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, zigzag ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, dk, dv +''' +class ZigZagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + q = shuffle_input(to_send=q, process_group=group) + k = shuffle_input(to_send=k, process_group=group) + v = shuffle_input(to_send=v, process_group=group) + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = zigzag_ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_output(out, process_group=group) + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_input(to_send=dout, process_group=ctx.group) + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = zigzag_ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + dq = recover_output(dq, ctx.group) + dk = recover_output(dk, ctx.group) + dv = recover_output(dv, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/examples/zigzag_ring_attention/zigzag_utils/zigzag_utils.py b/examples/zigzag_ring_attention/zigzag_utils/zigzag_utils.py new file mode 100644 index 00000000..a576db2f --- /dev/null +++ b/examples/zigzag_ring_attention/zigzag_utils/zigzag_utils.py @@ -0,0 +1,207 @@ +from typing import Optional, Tuple + +import torch +import torch.distributed as dist + +__all__ = ["update_out_and_lse", "RingComm"] + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + @torch.jit.script + def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + + out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + + lse = new_lse + return out, lse + + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + parts = self.world_size // 2 + self.ring_list = [] + for i in range(parts): + self.ring_list.extend([i, self.world_size - i - 1]) + + self.revert_rank = self.ring_list.index(self.rank) + + offset = ((dist.get_rank() // self.world_size) * self.world_size) + self.send_rank = self.ring_list[(self.revert_rank + 1) % self.world_size] + offset + self.recv_rank = self.ring_list[(self.revert_rank - 1) % self.world_size] + offset + + def send_recv( + self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group + ) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + +def shuffle_input(to_send: torch.Tensor, + process_group: dist.ProcessGroup = None): + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. + to_send_f = torch.zeros_like(to_send) + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + block_seq_len = to_send.shape[1] // 2 + to_send_slice = to_send[:, block_seq_len:].contiguous() + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[:, block_seq_len:] = to_send[:, :block_seq_len] + to_send_f[:, :block_seq_len, ...] = res + else: # A: 0 1, -> 0 7 + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len] + to_send_f[:, block_seq_len:, ...] = res + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + + return to_send_f + +def recover_output(to_send: torch.Tensor, + process_group: dist.ProcessGroup = None): + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + to_send_f = torch.zeros_like(to_send) + + block_seq_len = to_send.shape[1] // 2 + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if rank >= world_size // 2: + to_send_slice = to_send[:, :block_seq_len, ...].contiguous() + else: + to_send_slice = to_send[:, block_seq_len:, ...].contiguous() + res = torch.zeros_like(to_send_slice) + + assert to_send_slice.is_contiguous() + assert res.is_contiguous() + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: + to_send_f[:, :block_seq_len] = to_send[:, block_seq_len:, ...] + to_send_f[:, block_seq_len:] = res + else: + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len, ...] + to_send_f[:, block_seq_len:] = res + + return to_send_f.contiguous() diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index 500b28e4..2bfb9885 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -157,7 +157,7 @@ def emit_dataloader(self, node: IRDataOperation) -> List[str]: outputs = self.return_name(node.outputs()) return [f'{outputs} = next({self.tensor_name(node.input(0))})'] - def emit_fnode(self, node: IRFwOperation, prefix_attr: str = None) -> List[str]: + def emit_fnode(self, node: IRFwOperation, runtime_devid: int, plan_ndevs: int, runtime_ndevs: int, prefix_attr: str = None) -> List[str]: """Emit forward node code The result will look like (the lines are split into `List[str]`) @@ -168,6 +168,16 @@ def emit_fnode(self, node: IRFwOperation, prefix_attr: str = None) -> List[str]: The fields storing intermediate codes that are populated by this method: - NONE + + Args: + node (IRFwOperation): the forward node to emit + runtime_devid (int): the device id at the runtime + plan_ndevs (int): the number of devices in the scale unit + runtime_ndevs (int): the number of devices at the runtime, which is a multiple of `plan_ndevs` + prefix_attr (str): prefix to the tensor name + + Returns: + List[str]: the lines of the statements of the final Python code """ assert isinstance(node, IRFwOperation) codes = [] @@ -190,7 +200,7 @@ def emit_fnode(self, node: IRFwOperation, prefix_attr: str = None) -> List[str]: kwargs = IRSegment.modify_objects_of_complex(kwargs, modifier) emit_rule = self._emit_rules.map(signature) - body = emit_rule(node, inputs, kwargs) + body = emit_rule(node, inputs, kwargs, runtime_devid, plan_ndevs, runtime_ndevs) if len(node.outputs()) == 0: code = body diff --git a/nnscaler/codegen/frontend_mapping.py b/nnscaler/codegen/frontend_mapping.py index 26d76549..973c321a 100644 --- a/nnscaler/codegen/frontend_mapping.py +++ b/nnscaler/codegen/frontend_mapping.py @@ -1,11 +1,12 @@ # Some operators should be specially handled during codegen to the frontend code, # here we define the customized rule for code emisson. -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple from nnscaler import ir from nnscaler.ir.cten import IRTensor from nnscaler.ir.operator import IRFwOperation +from nnscaler.graph.parser.register import CustomizedOps class Sign2EmitRule: @@ -27,9 +28,12 @@ def map(self, signature: str) -> Callable: Returns: Callable: emit rule that takes the node, args (List[str]) and kwargs (Dict[str, str]) as input """ - return self._sign2rule.get(signature, self.emit_common) + if signature in CustomizedOps.kOpEmit: + return CustomizedOps.kOpEmit[signature] + else: + return self._sign2rule.get(signature, self.emit_common) - def emit_common(self, node: IRFwOperation, args: List[str], kwargs: Dict[str, str]) -> str: + def emit_common(self, node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: """Default rule to join all args and kwargs""" signature = node.signature @@ -42,7 +46,7 @@ def emit_common(self, node: IRFwOperation, args: List[str], kwargs: Dict[str, st args = ", ".join(list(args) + kw_pairs) return f"{signature}({args})" - def emit_slice(self, node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + def emit_slice(self, node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: """Special rule for generating slice node The op is: @@ -71,7 +75,7 @@ def emit_slice(self, node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[st return f"{in_tensor_var}[{', '.join(subscript_components)}]" - def emit_setattr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str]) -> str: + def emit_setattr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: """Special rule for generating setattr node """ assert False, f"This emit rule is deprecated, please report if you reach here" diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index a79f01f3..7169b408 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -424,7 +424,7 @@ def forward(self, x, y=None, z=None): for node in sequence: if isinstance(node, IRSegment): if not node.isfw(): continue # skip backward segment - codes = self.emit_segment(node) + codes = self.emit_segment(node, device) elif isinstance(node, IRFwOperation): raise RuntimeError(f"Unexcepted global-level op call: {node}") elif isinstance(node, IRAdapter): @@ -700,7 +700,7 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: add_code = reducer_add.format(reducer=reducer_name) self.model_init_statements.append(add_code) - def emit_segment(self, segment: IRSegment) -> List[str]: + def emit_segment(self, segment: IRSegment, runtime_devid: int) -> List[str]: """ Emit IRSegment code. @@ -740,12 +740,12 @@ def emit_segment(self, segment: IRSegment) -> List[str]: assert len(rc_group) > 0 gid: Optional[int] = rc_group[0].recompute if gid is None: - codes += self._emit_nodes(rc_group, lifetime) + codes += self._emit_nodes(rc_group, lifetime, runtime_devid) else: # get recompute excution code rc_segment = segment.create_segment(rc_group) rc_codes = self._emit_recompute(rc_group, - rc_segment.inputs(), rc_segment.outputs(), lifetime) + rc_segment.inputs(), rc_segment.outputs(), lifetime, runtime_devid) codes += rc_codes # release input tensors after exiting a RC group: last_node = rc_group[-1] @@ -758,7 +758,7 @@ def emit_segment(self, segment: IRSegment) -> List[str]: return codes - def _emit_nodes(self, nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: + def _emit_nodes(self, nodes: List[IRCell], lifecycle: LifeCycle, runtime_devid: int) -> List[str]: """ Emit code to invoke operations and adapter, e.g. (the lines are split into `List[str]`) @@ -777,7 +777,7 @@ def _emit_nodes(self, nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: for node in nodes: # execute if isinstance(node, IRFwOperation): - code = self.emit_fnode(node, prefix_attr='self.') + code = self.emit_fnode(node, runtime_devid=runtime_devid, plan_ndevs=len(self.devices), runtime_ndevs=self.runtime_ndevs, prefix_attr='self.') node_codes += code elif isinstance(node, IRAdapter): # for adapters inside an IRSegment, we don't apply async communication to it @@ -794,7 +794,7 @@ def _emit_nodes(self, nodes: List[IRCell], lifecycle: LifeCycle) -> List[str]: return node_codes def _emit_recompute(self, nodes: Tuple[IRCell], inputs: List[IRSubTensor], outputs: List[IRSubTensor], - lifecycle: LifeCycle) -> List[str]: + lifecycle: LifeCycle, runtime_devid: int) -> List[str]: """ Emit code to define a Python function for Recomputing and invoke it e.g. (the lines are split into `List[str]`) @@ -845,7 +845,7 @@ def recompute(tensor_2222): # for ncode in ModuleCodeGen._emit_nodes(nodes, lifecycle): # fb.insert_body(ncode) - fb.insert_body(self._emit_nodes(nodes, lifecycle)) + fb.insert_body(self._emit_nodes(nodes, lifecycle, runtime_devid)) fb.insert_body(f'return {output_names_tuple}') codes = [''] + fb.code + [''] codes.append( diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index 2e479314..9cd50124 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -2,7 +2,7 @@ Register cutomized function """ -from typing import Dict, Callable, Optional, Union +from typing import Dict, Callable, Optional, Union, List from functools import partial import inspect import logging @@ -11,7 +11,7 @@ from nnscaler.graph.function.dimops import IRDimops, OpAnno from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply -from nnscaler.ir.operator import IRTensor +from nnscaler.ir.operator import IRTensor, IRFwOperation _logger = logging.getLogger(__name__) @@ -25,6 +25,10 @@ class CustomizedOps: kOpRuntime: Dict[str, Callable] = {} # signature -> runtime function implementation code kOpCodeDef: Dict[str, str] = {} + # signature -> special emit function, will not store if emit_fn is None + # It accepts the node, repred args, repred kwargs, runtime_devid, plan_ndevs, runtime_ndevs + # as input and returns the generated code. + kOpEmit: Dict[str, Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str]] = {} @staticmethod def map(signature: str) -> Callable: @@ -46,7 +50,8 @@ def exist(signature: str) -> bool: return signature in CustomizedOps.kOpMap @staticmethod - def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Callable): + def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Callable, + emit_fn: Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str] = None): """Register an operator Args: @@ -54,6 +59,9 @@ def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Call op_create_fn (Callable): IRDimops creation function code (str): runtime function implementation code runtime_fn (Callable): runtime function + emit_fn (Callable): special emit function for codegen, will use default emit function if emit_fn is None. + It accepts the node, repred args, repred kwargs, runtime_devid, plan_ndevs, runtime_ndevs + as input and returns the generated code. Returns: None @@ -65,10 +73,12 @@ def register(signature: str, op_create_fn: Callable, code: str, runtime_fn: Call CustomizedOps.kOpMap[signature] = op_create_fn CustomizedOps.kOpRuntime[signature] = runtime_fn CustomizedOps.kOpCodeDef[signature] = code + if emit_fn is not None: + CustomizedOps.kOpEmit[signature] = emit_fn def register_op(annotation: Union[str, Callable], name: Optional[str] = None, - code_impl_pattern: str = 'import') -> Callable: + code_impl_pattern: str = 'import', emit_fn: Callable[[IRFwOperation, List[str], Dict[str, str], int, int, int], str] = None) -> Callable: """ Register a function with IRDimops annotations. @@ -123,6 +133,10 @@ def anno_fn(*inputs, **kwargs): can only be 'import' or 'source'. If 'import', will generate code with import statement. If 'source', will take the source code directly. Default: 'import'. + emit_fn (Callable): special emit function for codegen, this emit accepts the node, repred args, repred kwargs, runtime_devid, + plan_ndevs, runtime_ndevs as input and returns the generated code. Check examples/zigzag_ring_attention/zigzag_attn.py + for more details. + Default: None. Returns: fn (Callable): the runtime function @@ -221,7 +235,7 @@ def udfop(*args, signature=None, **kwargs): # step 4. register in CustomizedOps _logger.info(f'registering op {fsig}...') - CustomizedOps.register(fsig, udfop, code, fn) + CustomizedOps.register(fsig, udfop, code, fn, emit_fn) return fn return decorator diff --git a/tests/codegen/test_emit.py b/tests/codegen/test_emit.py index fef41acd..cf814b5e 100644 --- a/tests/codegen/test_emit.py +++ b/tests/codegen/test_emit.py @@ -30,6 +30,6 @@ def test_tensor_name(): def test_emit_module_attr(): dropout = Dropout(IRFullTensor([1024, 1024], requires_grad=True), p=0.5, training='self.training', signature='torch.nn.functional.dropout') - code = FuncEmission().emit_fnode(dropout) + code = FuncEmission().emit_fnode(dropout, runtime_devid=0, plan_ndevs=1, runtime_ndevs=1) print(code) assert 'training=self.training' in code[0] diff --git a/tests/graph/parser/test_register.py b/tests/graph/parser/test_register.py index d5e6ed4d..a3b6dc70 100644 --- a/tests/graph/parser/test_register.py +++ b/tests/graph/parser/test_register.py @@ -1,6 +1,7 @@ import nnscaler from nnscaler.graph.parser.converter import convert_model from nnscaler.profiler.database import get_func +from nnscaler.codegen.emit import FuncEmission import tempfile import torch @@ -122,3 +123,33 @@ def test_autograd_register(): sub_nodes = ir_graph.partition(node, node.algorithms('dim'), idx=0, dim=0, num=2) for sub_node in sub_nodes: assert sub_node.kwargs['h'] == 2 + +def customized_add(x, y): + return x + y + +def emit_customized_add(node, args, kwargs, runtime_devid, plan_ndevs, runtime_ndevs): + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + args = ", ".join(list(args) + kw_pairs) + return f"torch.add({args})" + +nnscaler.register_op('*, * -> *', emit_fn=emit_customized_add)(customized_add) + +class ModelCustomizedAdd(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + return customized_add(x, y) + +@replace_all_device_with('cpu') +def test_customized_emit(): + model = ModelCustomizedAdd() + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(model, {'x': torch.rand(10, 10), 'y': torch.rand(10, 10)}, tempdir, False) + add_node = ir_graph.nodes()[0] + code = FuncEmission().emit_fnode(add_node, runtime_devid=0, plan_ndevs=1, runtime_ndevs=1) + assert 'torch.add' in code[-1] From 30ec31d2cca01d88c292474026acdac273c950a6 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 4 Jun 2024 06:02:56 +0000 Subject: [PATCH 1651/1892] Merged PR 2156: initialize lightning support add nnscaler strategy --- nnscaler/integration/__init__.py | 0 .../integration/lightning/fabric/__init__.py | 0 .../integration/lightning/pytorch/__init__.py | 2 + .../lightning/pytorch/precision.py | 140 +++++ .../integration/lightning/pytorch/strategy.py | 498 ++++++++++++++++++ nnscaler/integration/lightning/utils.py | 13 + requirements-dev.txt | 2 + tests/integration/__init__.py | 0 tests/integration/common.py | 10 + tests/integration/lightning/__init__.py | 0 tests/integration/lightning/datasets.py | 200 +++++++ .../integration/lightning/fabric/__init__.py | 0 .../integration/lightning/pytorch/__init__.py | 0 .../lightning/pytorch/simple_datamodules.py | 130 +++++ .../lightning/pytorch/simple_models.py | 211 ++++++++ .../lightning/pytorch/test_strategy.py | 132 +++++ 16 files changed, 1338 insertions(+) create mode 100644 nnscaler/integration/__init__.py create mode 100644 nnscaler/integration/lightning/fabric/__init__.py create mode 100644 nnscaler/integration/lightning/pytorch/__init__.py create mode 100644 nnscaler/integration/lightning/pytorch/precision.py create mode 100644 nnscaler/integration/lightning/pytorch/strategy.py create mode 100644 nnscaler/integration/lightning/utils.py create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/common.py create mode 100644 tests/integration/lightning/__init__.py create mode 100644 tests/integration/lightning/datasets.py create mode 100644 tests/integration/lightning/fabric/__init__.py create mode 100644 tests/integration/lightning/pytorch/__init__.py create mode 100644 tests/integration/lightning/pytorch/simple_datamodules.py create mode 100644 tests/integration/lightning/pytorch/simple_models.py create mode 100644 tests/integration/lightning/pytorch/test_strategy.py diff --git a/nnscaler/integration/__init__.py b/nnscaler/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nnscaler/integration/lightning/fabric/__init__.py b/nnscaler/integration/lightning/fabric/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nnscaler/integration/lightning/pytorch/__init__.py b/nnscaler/integration/lightning/pytorch/__init__.py new file mode 100644 index 00000000..80f37b9e --- /dev/null +++ b/nnscaler/integration/lightning/pytorch/__init__.py @@ -0,0 +1,2 @@ +from .precision import NnScalerPrecision +from .strategy import NnScalerStrategy diff --git a/nnscaler/integration/lightning/pytorch/precision.py b/nnscaler/integration/lightning/pytorch/precision.py new file mode 100644 index 00000000..895282bc --- /dev/null +++ b/nnscaler/integration/lightning/pytorch/precision.py @@ -0,0 +1,140 @@ +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, Union + +import torch +from torch import Tensor +from torch.optim import Optimizer +import torch.amp + +import lightning.pytorch as pl +from lightning_utilities import apply_to_collection +from typing_extensions import get_args, override + +from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager +from lightning.pytorch.plugins.precision.precision import Precision +from lightning.fabric.utilities.types import Steppable +from lightning.pytorch.utilities import GradClipAlgorithmType + + +_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"] + + +class NnScalerPrecision(Precision): + """Precision plugin for training with nnscaler. + + .. warning:: This is an :ref:`experimental ` feature. + + Args: + precision: Full precision (32-true), half precision (16-true, bf16-true) + + Raises: + ValueError: + If unsupported ``precision`` is provided. + + """ + + def __init__( + self, + precision: _PRECISION_INPUT, + scaler=None, + ) -> None: + """ + Args: + scaler: a torch.amp.GradScaler-like object, supporting + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + """ + supported_precision = get_args(_PRECISION_INPUT) + if precision not in supported_precision: + raise ValueError( + f"`precision={precision!r})` is not supported in nnScaler." + f" `precision` must be one of: {supported_precision}." + ) + + self.precision = precision + self.scaler = scaler + + precision_to_type = { + "bf16-true": torch.bfloat16, + "16-true": torch.float16, + "32-true": torch.float32, + } + self._desired_input_dtype = precision_to_type[self.precision] + + @override + def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: + return module.to(dtype=self._desired_input_dtype) + + @override + def convert_input(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype) + + @override + def tensor_init_context(self) -> ContextManager: + return _DtypeContextManager(self._desired_input_dtype) + + @override + def module_init_context(self) -> ContextManager: + return self.tensor_init_context() + + @override + def forward_context(self) -> ContextManager: + return _DtypeContextManager(self._desired_input_dtype) + + @override + def convert_input(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype) + + @override + def convert_output(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) + + @override + def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override] + if self.scaler is not None: + tensor = self.scaler.scale(tensor) + return super().pre_backward(tensor, module) + + @override + def optimizer_step( # type: ignore[override] + self, + optimizer: Steppable, + model: "pl.LightningModule", + closure: Callable[[], Any], + **kwargs: Any, + ) -> Any: + if self.scaler is None: + return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs) + + closure_result = closure() + + if not _optimizer_handles_unscaling(optimizer): + # Unscaling needs to be performed here in case we are going to apply gradient clipping. + # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam). + # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook. + self.scaler.unscale_(optimizer) # type: ignore[arg-type] + + self._after_closure(model, optimizer) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if not model.automatic_optimization or not skipped_backward: + # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found + step_output = self.scaler.step(optimizer, **kwargs) # type: ignore[arg-type] + self.scaler.update() + return step_output + return closure_result + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float] = 0.0, + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + ) -> None: + """Clips the gradients.""" + if clip_val <= 0: + return + if gradient_clip_algorithm == GradClipAlgorithmType.VALUE: + raise ValueError('nnscaler does not support clipping gradients by value.') + elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: + optimizer.clip_gnorm(clip_val) # define in nnscaler diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py new file mode 100644 index 00000000..c7fa3a42 --- /dev/null +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -0,0 +1,498 @@ +from contextlib import nullcontext +from functools import partial +import logging +from pathlib import Path +import os +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Dict, + Generator, + List, + Literal, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) + +import torch +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer +from typing_extensions import TypeGuard, override + +import lightning.pytorch as pl +from lightning.pytorch.accelerators import Accelerator, CUDAAccelerator +from lightning.pytorch.plugins.precision import Precision +from lightning.pytorch.trainer.states import TrainerFn +from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment +from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher +from lightning.pytorch.strategies.parallel import ParallelStrategy +from lightning.fabric.strategies.registry import _StrategyRegistry +from lightning.fabric.strategies.strategy import ( + TBroadcast, + _Sharded, +) +from lightning.fabric.utilities.distributed import ( + ReduceOp, + _distributed_is_initialized, + _sync_ddp_if_available, +) +from lightning.fabric.utilities.distributed import group as _group +from lightning.fabric.utilities.seed import reset_seed +from lightning.fabric.utilities.types import _PATH, _Stateful +from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers +from lightning.pytorch.utilities.types import LRSchedulerConfig, STEP_OUTPUT +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn +from lightning.pytorch.utilities import GradClipAlgorithmType + +import nnscaler +from nnscaler.integration.lightning.utils import inplace_optimizer_fn +from .precision import NnScalerPrecision + + +logger = logging.getLogger(__name__) + + +class NnScalerStrategy(ParallelStrategy): + r"""Strategy for nnscaler. + + .. warning:: This is an :ref:`experimental ` feature. + + Arguments: + state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint. + + - ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is + a folder with as many files as the world size. + - ``"deduped"``: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is + a folder with as many files as the world size. + """ + strategy_name = "nnscaler" + _registered_strategies: List[str] = [] + + def __init__( + self, + accelerator: Optional[Accelerator] = None, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + precision_plugin: Optional[Precision] = None, + compute_config: Optional[nnscaler.ComputeConfig] = None, + state_dict_type: Literal["deduped", "sharded"] = "sharded", + pas_policy: str = None, + gen_savedir: Union[str, Path] = './.nnscaler', + reuse: str = 'match', + instance_name: Optional[str] = None, + **kwargs + ) -> None: + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + precision_plugin=precision_plugin, + ) + self._forward_redirection = None + + self._num_nodes = 1 + self.compute_config = compute_config + self.pas_policy = pas_policy + self.gen_savedir = gen_savedir + self.reuse = reuse + self.instance_name = instance_name + if self.compute_config is None: + raise ValueError("The `compute_config` must be provided to the `NnScalerStrategy`.") + if self.pas_policy is None: + raise ValueError("The `pas_policy` must be provided to the `NnScalerStrategy`.") + + self._state_dict_type = state_dict_type + self._nnscaler_extra_state_key = 'nnscaler-extra-state' + self._state_dict_type_key = 'state-dict-type' + self._pl_module_name_key = 'pl_state_dict' # save some extra pl module states + self._pmodule_attr_name = 'nnscaler_pmodule' + self._module_name_key = 'state_dict' + self._opt_name_key = 'optimizer_states' + + @override + def setup_environment(self) -> None: + if not isinstance(self.accelerator, CUDAAccelerator): + raise RuntimeError( + f"The nnscaler strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" + " is used." + ) + super().setup_environment() + self._setup_distributed() + + def _setup_distributed(self) -> None: + assert self.parallel_devices is not None + self._validate_device_index_selection() + reset_seed() + self.set_world_ranks() + self._set_node_environment_variables() + nnscaler.init() + + def set_world_ranks(self) -> None: + if self.cluster_environment is not None: + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.compute_config.runtime_ngpus) + + def _set_node_environment_variables(self) -> None: + assert self.cluster_environment is not None + os.environ["MASTER_ADDR"] = self.cluster_environment.main_address + os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) + os.environ["RANK"] = str(self.global_rank) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["LOCAL_RANK"] = str(self.local_rank) + + def _validate_device_index_selection(self) -> None: + selected_device_indices = [device.index for device in self.parallel_devices] + expected_device_indices = list(range(len(self.parallel_devices))) + if selected_device_indices != expected_device_indices: + raise RuntimeError( + f"The selected device indices {selected_device_indices!r} don't match the local rank values of processes." + " If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable" + f" instead. For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`." + ) + + @property + @override + def restore_checkpoint_after_setup(self) -> bool: + return True + + @property + @override + def lightning_restore_optimizer(self) -> bool: + return False + + @property + def is_distributed(self) -> bool: + """ + Indicates we are running in distributed mode + And `distributed_sampler_kwargs` will be used to configure the sampler + """ + return True + + @override + def setup(self, trainer: "pl.Trainer") -> None: + super().setup(trainer) + assert self._lightning_module is not None + assert self._model is not None + + # nnscaler handles gradient clipping internally + if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule): + rank_zero_warn( + "Since nnscaler handles gradient clipping internally, the default" + " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients." + " The hook will still be called. Consider setting" + " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`" + " which will use the internal mechanism." + ) + + if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE: + raise MisconfigurationException("nnscaler does not support clipping gradients by value.") + + @override + def _setup_model(self, model: Module) -> Module: + """Set up a module for inference (no optimizers). + """ + if getattr(model, 'dummy_input', None) is None: + raise ValueError("The `dummy_input` must be defined as a property in the module.") + if not isinstance(model.dummy_input, dict): + raise ValueError("The `dummy_input` must be a dictionary with forward arguments names as keys.") + + old_training_flag = model.training + if not old_training_flag: + logger.warning("The model is not in training mode. Setting it to training mode for parallelizing.") + model.train() # always use the model in training mode + pmodule = nnscaler.parallelize( + model, + self.precision_plugin.convert_input(model.dummy_input), + self.pas_policy, + self.compute_config, + gen_savedir=self.gen_savedir, + reuse=self.reuse, + instance_name=self.instance_name, + broadcast_strategy='all' + ) + model.train(old_training_flag) + pmodule.to(self.root_device) + + # update the device of the module + model._device = self.root_device + + # set all module parameters of original model to None + # to reduce the memory usage + # In return, the original model will not be able to access the parameters anymore + # but the forward will be redirected to the parallelized model + # TODO: this doesn't work for pipeline because fullmap is not complete + for attr in pmodule.fullmap.values(): + attr_name = attr.orig_name.split('.')[0] + setattr(model, attr_name, None) + + # torch.nn.Module will add new attributes to the model automatically. + setattr(model, self._pmodule_attr_name, pmodule) + model.to(self.root_device) + # rewrite model forward to parallelized model forward + model.forward = pmodule.forward + + return model + + @override + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + assert self.lightning_module is not None + + # If we're setting up for evaluation after fitting, we need to discard the optimizers + # since we're rewrapping the model, otherwise optimizer param references are no longer valid + # and subsequent checkpoint saving can fail + self._reset_optimizers_and_schedulers() + + optimizer, lr_scheduler = self._init_optimizers() + if len(optimizer.param_groups) != 1: + raise MisconfigurationException( + "nnscaler currently only supports single optimizer with a single param group." + ) + new_optimizer = nnscaler.build_optimizer( + getattr(trainer.model, self._pmodule_attr_name), + partial(inplace_optimizer_fn, optimizer) + ) + # the lr_scheduler doesn't need to update when we change the optimizer's param_graups[0]['params'] + self.optimizers, self.lr_scheduler_configs = [new_optimizer], ([lr_scheduler] if lr_scheduler else []) + + def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig]]: + assert self.lightning_module is not None + optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) + if len(optimizers) > 1 or len(lr_schedulers) > 1: + raise MisconfigurationException( + "nnscaler currently only supports single optimizer, single optional scheduler." + ) + return optimizers[0], lr_schedulers[0] if lr_schedulers else None + + @property + @override + def root_device(self) -> torch.device: + assert self.parallel_devices is not None + return self.parallel_devices[self.local_rank] + + @property + def num_nodes(self) -> int: + return self._num_nodes + + @num_nodes.setter + def num_nodes(self, num_nodes: int) -> None: + self._num_nodes = num_nodes + + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + + @property + @override + def distributed_sampler_kwargs(self) -> Dict[str, Any]: + return { + "num_replicas": self.compute_config.runtime_ngpus//self.compute_config.plan_ngpus, + "rank": self.global_rank // self.compute_config.plan_ngpus + } + + @property + @override + def precision_plugin(self) -> NnScalerPrecision: + plugin = self._precision_plugin + if plugin is not None: + assert isinstance(plugin, NnScalerPrecision) + return plugin + return NnScalerPrecision("32-true") + + @precision_plugin.setter + @override + def precision_plugin(self, precision: Optional[NnScalerPrecision]) -> None: + if precision is not None and not isinstance(precision, NnScalerPrecision): + raise TypeError(f"The nnscaler strategy can only work with the `NnScalerPrecision` plugin, found {precision}") + self._precision_plugin = precision + + @override + def _configure_launcher(self) -> None: + assert self.cluster_environment is not None + if not self.cluster_environment.creates_processes_externally: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + + @override + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Optimizers can only be set up jointly with the model in this strategy. + + Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together. + + """ + raise NotImplementedError(self._err_msg_joint_setup_required()) + + @override + def model_to_device(self) -> None: + assert self.model is not None + self.model.to(self.root_device) + + @override + def barrier(self, name: Optional[str] = None) -> None: + if not _distributed_is_initialized(): + return + if torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=[self.root_device.index]) + else: + torch.distributed.barrier() + + @override + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + if not _distributed_is_initialized(): + return obj + + obj = [obj] + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] + + @override + def reduce( + self, + tensor: Union[Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Tensor: + """Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + + """ + if isinstance(tensor, Tensor): + return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + @override + def lightning_module_state_dict(self) -> Dict[str, Any]: + assert self.model is not None + # do it in `save_checkpoint` + return {} + + @override + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + assert self.optimizers + # do it in `save_checkpoint` + return {} + + @override + def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: + # Override to do nothing, already loaded the states in `load_checkpoint()` + pass + + @override + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + # Override to do nothing, already loaded the states in `load_checkpoint()` + pass + + @override + def save_checkpoint( + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + ) -> None: + """Save model, optimizer, and other state to a checkpoint on disk. + """ + if storage_options is not None: + raise TypeError( + "`NnScalerStrategy.save_checkpoint(..., storage_options=...)` is not supported because" + " `NnScalerStrategy` does not use the `CheckpointIO`." + ) + + # broadcast the path from rank 0 to ensure all the states are saved in a common path + path = Path(self.broadcast(filepath)) + path.mkdir(parents=True, exist_ok=True) + + nnscaler_pmodule = getattr(self._lightning_module, self._pmodule_attr_name) + pl_module_state_dict = self._lightning_module.state_dict() + # remove the parallelized module state from it + for key in list(pl_module_state_dict.keys()): + if key.startswith(self._pmodule_attr_name + '.'): + pl_module_state_dict.pop(key) + + nnscaler_extra_state = { + self._state_dict_type_key: self._state_dict_type, + self._pl_module_name_key: pl_module_state_dict + } + checkpoint[self._nnscaler_extra_state_key] = nnscaler_extra_state + + if self._state_dict_type == "deduped": + module_state, opt_state = nnscaler.deduped_state_dict( + nnscaler_pmodule, + self.optimizers[0] if self.optimizers else None + ) + else: + module_state = nnscaler_pmodule.state_dict() + if self.optimizers: + opt_state = self.optimizers[0].state_dict() + else: + opt_state = None + checkpoint[self._module_name_key] = module_state + if opt_state: + checkpoint[self._opt_name_key] = [opt_state] + + torch.save(checkpoint, path / f'{self.global_rank}.pt') + + @override + def load_checkpoint( + self, checkpoint_path: _PATH + ) -> Dict[str, Any]: + """ + Load the contents from a checkpoint and restore the state of the given objects. + """ + # broadcast the path from rank 0 to ensure all the states are loaded from a common path + path = Path(self.broadcast(checkpoint_path)) + assert self.model is not None + assert self.lightning_module is not None + + state_dict: dict = torch.load(path / f'{self.global_rank}.pt') + nnscaler_extra_state = state_dict.pop(self._nnscaler_extra_state_key) + # load the extra states of the pl module + self._lightning_module.load_state_dict(nnscaler_extra_state[self._pl_module_name_key], strict=False) + + module_dict = state_dict[self._module_name_key] + state_dict[self._module_name_key] = {} + optimizer_dict = None + if self._opt_name_key in state_dict: + optimizer_dict = state_dict[self._opt_name_key][0] + state_dict[self._opt_name_key] = [{}] + + state_dict_type = nnscaler_extra_state[self._state_dict_type_key] + + module = getattr(self._lightning_module, self._pmodule_attr_name) + optimizer = self.optimizers[0] if self.optimizers else None + + if state_dict_type == "deduped": + nnscaler.load_deduped_state_dict(module, module_dict, optimizer, optimizer_dict) + else: + module.load_state_dict(module_dict) + if optimizer_dict is not None: + optimizer.load_state_dict(optimizer_dict) + + return state_dict + + @classmethod + @override + def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: + if not torch.distributed.is_available(): + return + + strategy_registry.register( + "nnscaler", + cls, + description="nnscaler training", + ) + cls._registered_strategies.append("nnscaler") + + def _get_process_group_backend(self) -> str: + return 'nccl' # nnscaler only support nccl diff --git a/nnscaler/integration/lightning/utils.py b/nnscaler/integration/lightning/utils.py new file mode 100644 index 00000000..bc105233 --- /dev/null +++ b/nnscaler/integration/lightning/utils.py @@ -0,0 +1,13 @@ +import torch +from torch.optim.lbfgs import LBFGS + +from lightning.pytorch.utilities.exceptions import MisconfigurationException + + +def inplace_optimizer_fn(optimizer, params): + # hack to replace the optimizer's param_groups with the new params + optimizer.param_groups[0]['params'] = list(params) + # handle special cases. e.g. LBFGS + if isinstance(optimizer, LBFGS): + raise MisconfigurationException("LBFGS optimizer is not supported.") + return optimizer diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f203d24..ae21ff91 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,6 +6,8 @@ pre-commit pytest pytest-cov pytest-mock +scikit-learn +lightning sphinx sphinxcontrib-napoleon tabulate diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/common.py b/tests/integration/common.py new file mode 100644 index 00000000..4e8b5a83 --- /dev/null +++ b/tests/integration/common.py @@ -0,0 +1,10 @@ +import torch + +class BoringModel(nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2, bias=False) + + def forward(self, x): + x = self.layer(x) + return torch.nn.functional.mse_loss(x, torch.ones_like(x)) diff --git a/tests/integration/lightning/__init__.py b/tests/integration/lightning/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/lightning/datasets.py b/tests/integration/lightning/datasets.py new file mode 100644 index 00000000..3769e76d --- /dev/null +++ b/tests/integration/lightning/datasets.py @@ -0,0 +1,200 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import random +import time +import urllib.request +from typing import Optional, Sequence, Tuple + +import torch +from torch import Tensor +from torch.utils.data import Dataset + + +class MNIST(Dataset): + """Customized `MNIST `_ dataset for testing PyTorch Lightning without the + torchvision dependency. + + Part of the code was copied from + https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py + + Args: + root: Root directory of dataset where ``MNIST/processed/training.pt`` + and ``MNIST/processed/test.pt`` exist. + train: If ``True``, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + normalize: mean and std deviation of the MNIST dataset. + download: If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + + RESOURCES = ( + "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt", + "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt", + ) + + TRAIN_FILE_NAME = "training.pt" + TEST_FILE_NAME = "test.pt" + cache_folder_name = "complete" + + def __init__( + self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs + ): + super().__init__() + self.root = root + self.train = train # training set or test set + self.normalize = normalize + + self.prepare_data(download) + + data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME + self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) + + def __getitem__(self, idx: int) -> Tuple[Tensor, int]: + img = self.data[idx].float().unsqueeze(0) + target = int(self.targets[idx]) + + if self.normalize is not None and len(self.normalize) == 2: + img = self.normalize_tensor(img, *self.normalize) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + @property + def cached_folder_path(self) -> str: + return os.path.join(self.root, "MNIST", self.cache_folder_name) + + def _check_exists(self, data_folder: str) -> bool: + existing = True + for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): + existing = existing and os.path.isfile(os.path.join(data_folder, fname)) + return existing + + def prepare_data(self, download: bool = True): + if download and not self._check_exists(self.cached_folder_path): + self._download(self.cached_folder_path) + if not self._check_exists(self.cached_folder_path): + raise RuntimeError("Dataset not found.") + + def _download(self, data_folder: str) -> None: + os.makedirs(data_folder, exist_ok=True) + for url in self.RESOURCES: + logging.info(f"Downloading {url}") + fpath = os.path.join(data_folder, os.path.basename(url)) + urllib.request.urlretrieve(url, fpath) # noqa: S310 + + @staticmethod + def _try_load(path_data, trials: int = 30, delta: float = 1.0): + """Resolving loading from the same time from multiple concurrent processes.""" + res, exception = None, None + assert trials, "at least some trial has to be set" + assert os.path.isfile(path_data), f"missing file: {path_data}" + for _ in range(trials): + try: + res = torch.load(path_data) + # todo: specify the possible exception + except Exception as ex: + exception = ex + time.sleep(delta * random.random()) + else: + break + if exception is not None: + # raise the caught exception + raise exception + return res + + @staticmethod + def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor: + mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) + std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) + return tensor.sub(mean).div(std) + + +class TrialMNIST(MNIST): + """Constrained MNIST dataset. + + Args: + num_samples: number of examples per selected class/digit + digits: list selected MNIST digits/classes + kwargs: Same as MNIST + + """ + + def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): + # number of examples per class + self.num_samples = num_samples + # take just a subset of MNIST dataset + self.digits = sorted(digits) if digits else list(range(10)) + + self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}" + + super().__init__(root, normalize=(0.5, 1.0), **kwargs) + + @staticmethod + def _prepare_subset(full_data: Tensor, full_targets: Tensor, num_samples: int, digits: Sequence): + classes = {d: 0 for d in digits} + indexes = [] + for idx, target in enumerate(full_targets): + label = target.item() + if classes.get(label, float("inf")) >= num_samples: + continue + indexes.append(idx) + classes[label] += 1 + if all(classes[k] >= num_samples for k in classes): + break + data = full_data[indexes] + targets = full_targets[indexes] + return data, targets + + def _download(self, data_folder: str) -> None: + super()._download(data_folder) + for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): + path_fname = os.path.join(self.cached_folder_path, fname) + assert os.path.isfile(path_fname), f"Missing cached file: {path_fname}" + data, targets = self._try_load(path_fname) + data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits) + torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) + + +class AverageDataset(Dataset): + def __init__(self, dataset_len=300, sequence_len=100): + self.dataset_len = dataset_len + self.sequence_len = sequence_len + self.input_seq = torch.randn(dataset_len, sequence_len, 10) + top, bottom = self.input_seq.chunk(2, -1) + self.output_seq = top + bottom.roll(shifts=1, dims=-1) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, item): + return self.input_seq[item], self.output_seq[item] + + +class SklearnDataset(Dataset): + def __init__(self, x, y, x_type, y_type): + self.x = x + self.y = y + self._x_type = x_type + self._y_type = y_type + + def __getitem__(self, idx): + return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type) + + def __len__(self): + return len(self.y) diff --git a/tests/integration/lightning/fabric/__init__.py b/tests/integration/lightning/fabric/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/lightning/pytorch/__init__.py b/tests/integration/lightning/pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/lightning/pytorch/simple_datamodules.py b/tests/integration/lightning/pytorch/simple_datamodules.py new file mode 100644 index 00000000..0ead1445 --- /dev/null +++ b/tests/integration/lightning/pytorch/simple_datamodules.py @@ -0,0 +1,130 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from lightning.pytorch.core.datamodule import LightningDataModule +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import DataLoader + +from ..datasets import MNIST, SklearnDataset, TrialMNIST + +_SKLEARN_AVAILABLE = RequirementCache("scikit-learn") + + +class MNISTDataModule(LightningDataModule): + def __init__(self, data_dir: str = "./", batch_size: int = 32, use_trials: bool = False) -> None: + super().__init__() + + self.data_dir = data_dir + self.batch_size = batch_size + + # TrialMNIST is a constrained MNIST dataset + self.dataset_cls = TrialMNIST if use_trials else MNIST + + def prepare_data(self): + # download only + self.dataset_cls(self.data_dir, train=True, download=True) + self.dataset_cls(self.data_dir, train=False, download=True) + + def setup(self, stage: str): + if stage == "fit": + self.mnist_train = self.dataset_cls(self.data_dir, train=True) + if stage == "test": + self.mnist_test = self.dataset_cls(self.data_dir, train=False) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=False) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False) + + +class SklearnDataModule(LightningDataModule): + def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 10): + if not _SKLEARN_AVAILABLE: + raise ImportError(str(_SKLEARN_AVAILABLE)) + + super().__init__() + self.batch_size = batch_size + self._x, self._y = sklearn_dataset + self._split_data() + self._x_type = x_type + self._y_type = y_type + + def _split_data(self): + from sklearn.model_selection import train_test_split + + self.x_train, self.x_test, self.y_train, self.y_test = train_test_split( + self._x, self._y, test_size=0.20, random_state=42 + ) + self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split( + self.x_train, self.y_train, test_size=0.40, random_state=42 + ) + + def train_dataloader(self): + return DataLoader( + SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type), + batch_size=self.batch_size, + ) + + def val_dataloader(self): + return DataLoader( + SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), batch_size=self.batch_size + ) + + def test_dataloader(self): + return DataLoader( + SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size + ) + + def predict_dataloader(self): + return DataLoader( + SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size + ) + + @property + def sample(self): + return torch.tensor([self._x[0]], dtype=self._x_type) + + +class ClassifDataModule(SklearnDataModule): + def __init__( + self, num_features=32, length=800, num_classes=3, batch_size=10, n_clusters_per_class=1, n_informative=2 + ): + if not _SKLEARN_AVAILABLE: + raise ImportError(str(_SKLEARN_AVAILABLE)) + + from sklearn.datasets import make_classification + + data = make_classification( + n_samples=length, + n_features=num_features, + n_classes=num_classes, + n_clusters_per_class=n_clusters_per_class, + n_informative=n_informative, + random_state=42, + ) + super().__init__(data, x_type=torch.float32, y_type=torch.long, batch_size=batch_size) + + +class RegressDataModule(SklearnDataModule): + def __init__(self, num_features=16, length=800, batch_size=10): + if not _SKLEARN_AVAILABLE: + raise ImportError(str(_SKLEARN_AVAILABLE)) + + from sklearn.datasets import make_regression + + x, y = make_regression(n_samples=length, n_features=num_features, random_state=42) + y = [[v] for v in y] + super().__init__((x, y), x_type=torch.float32, y_type=torch.float32, batch_size=batch_size) diff --git a/tests/integration/lightning/pytorch/simple_models.py b/tests/integration/lightning/pytorch/simple_models.py new file mode 100644 index 00000000..2d124a96 --- /dev/null +++ b/tests/integration/lightning/pytorch/simple_models.py @@ -0,0 +1,211 @@ +# copied from lightning tests + +from typing import Any, Dict, Iterator, List, Optional, Tuple +import torch +from torch.optim.lr_scheduler import LRScheduler +import torch.nn.functional as F +from lightning.pytorch import LightningModule +from torch import Tensor, nn +from torchmetrics import Accuracy +from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset + + +from lightning.pytorch.utilities.types import STEP_OUTPUT + + +class ClassificationModel(LightningModule): + def __init__(self, num_features=32, num_classes=3, batch_size=10, lr=0.01): + super().__init__() + + self.lr = lr + self.num_features = num_features + self.num_classes = num_classes + self.batch_size = batch_size + for i in range(3): + setattr(self, f"layer_{i}", nn.Linear(num_features, num_features)) + setattr(self, f"layer_{i}a", torch.nn.ReLU()) + setattr(self, "layer_end", nn.Linear(num_features, 3)) + + acc = Accuracy(task="multiclass", num_classes=num_classes) + self.train_acc = acc.clone() + self.valid_acc = acc.clone() + self.test_acc = acc.clone() + + @property + def dummy_input(self): + return {'x': torch.randn(self.batch_size, self.num_features)} + + def forward(self, x): + x = self.layer_0(x) + x = self.layer_0a(x) + x = self.layer_1(x) + x = self.layer_1a(x) + x = self.layer_2(x) + x = self.layer_2a(x) + x = self.layer_end(x) + return F.softmax(x, dim=1) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return [optimizer], [] + + def training_step(self, batch, batch_idx): + assert self.training + x, y = batch + logits = self.forward(x) + loss = F.cross_entropy(logits, y) + self.log("train_loss", loss, prog_bar=True) + self.log("train_acc", self.train_acc(logits, y), prog_bar=True) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + assert not self.training + x, y = batch + logits = self.forward(x) + self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False) + self.log("val_acc", self.valid_acc(logits, y), prog_bar=True) + + def test_step(self, batch, batch_idx): + assert not self.training + x, y = batch + logits = self.forward(x) + self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False) + self.log("test_acc", self.test_acc(logits, y), prog_bar=True) + + def predict_step(self, batch, batch_idx): + assert not self.training + x, _ = batch + return self.forward(x) + + +class RandomDictDataset(Dataset): + """ + .. warning:: This is meant for testing/debugging and is experimental. + """ + + def __init__(self, size: int, length: int): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + a = self.data[index] + b = a + 2 + return {"a": a, "b": b} + + def __len__(self) -> int: + return self.len + + +class RandomDataset(Dataset): + """ + .. warning:: This is meant for testing/debugging and is experimental. + """ + + def __init__(self, size: int, length: int): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index: int) -> Tensor: + return self.data[index] + + def __len__(self) -> int: + return self.len + + +class RandomIterableDataset(IterableDataset): + """ + .. warning:: This is meant for testing/debugging and is experimental. + """ + + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self) -> Iterator[Tensor]: + for _ in range(self.count): + yield torch.randn(self.size) + + +class RandomIterableDatasetWithLen(IterableDataset): + """ + .. warning:: This is meant for testing/debugging and is experimental. + """ + + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self) -> Iterator[Tensor]: + for _ in range(len(self)): + yield torch.randn(self.size) + + def __len__(self) -> int: + return self.count + + +class BoringModel(LightningModule): + """Testing PL Module. + + Use as follows: + - subclass + - modify the behavior for what you want + + .. warning:: This is meant for testing/debugging and is experimental. + + Example:: + + class TestModel(BoringModel): + def training_step(self, ...): + ... # do your own thing + + """ + + def __init__(self) -> None: + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + @property + def dummy_input(self): + return {'x': torch.randn(32)} + + def forward(self, x: Tensor) -> Tensor: + return self.layer(x) + + def loss(self, preds: Tensor, labels: Optional[Tensor] = None) -> Tensor: + if labels is None: + labels = torch.ones_like(preds) + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(preds, labels) + + def step(self, batch: Any) -> Tensor: + output = self(batch) + return self.loss(output) + + def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: + assert self.training + return {"loss": self.step(batch)} + + def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: + assert not self.training + return {"x": self.step(batch)} + + def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: + assert not self.training + return {"y": self.step(batch)} + + def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[LRScheduler]]: + optimizer = torch.optim.SGD(self.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self) -> DataLoader: + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self) -> DataLoader: + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self) -> DataLoader: + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self) -> DataLoader: + return DataLoader(RandomDataset(32, 64)) diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py new file mode 100644 index 00000000..10ac539d --- /dev/null +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -0,0 +1,132 @@ +import os +from pathlib import Path +import math + +import torch +from lightning import Trainer +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, Timer +from lightning.fabric.utilities.cloud_io import _load as pl_load + +import pytest +from unittest.mock import Mock, patch + +from nnscaler.parallel import ComputeConfig +from nnscaler.integration.lightning.pytorch import NnScalerStrategy, NnScalerPrecision + +from ....launch_torchrun import launch_torchrun +from .simple_datamodules import ClassifDataModule +from .simple_models import BoringModel, ClassificationModel + + +def fit_worker(tmp_path): + dm = ClassifDataModule() + model = ClassificationModel() + compute_config=ComputeConfig(2, 2) + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=2, + accelerator="gpu", devices=2, + gradient_clip_val=2.0, + strategy=NnScalerStrategy(compute_config=compute_config, pas_policy='tp', gen_savedir=tmp_path), + plugins=[NnScalerPrecision('32-true')] + ) + trainer.fit(model, datamodule=dm) + trainer.validate(model, datamodule=ClassifDataModule()) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_multi_gpu_model_only(tmp_path): + launch_torchrun(2, fit_worker, tmp_path) + + +def ckpt_path_epoch_restored_worker(tmp_path): + """Verify resuming from checkpoint runs the right number of epochs.""" + + class TestModel(BoringModel): + # Model that tracks epochs and batches seen + num_epochs_end_seen = 0 + num_batches_seen = 0 + num_on_load_checkpoint_called = 0 + + def on_train_epoch_end(self): + self.num_epochs_end_seen += 1 + + def on_train_batch_start(self, *_): + self.num_batches_seen += 1 + + def on_load_checkpoint(self, _): + self.num_on_load_checkpoint_called += 1 + + model = TestModel() + max_epochs = 2 + compute_config=ComputeConfig(2, 2) + trainer = Trainer( + max_epochs=max_epochs, + limit_train_batches=0.65, + limit_val_batches=1, + callbacks=ModelCheckpoint(dirpath=tmp_path, save_top_k=-1), + default_root_dir=tmp_path, + val_check_interval=1.0, + enable_progress_bar=False, + logger=False, + enable_model_summary=False, + strategy=NnScalerStrategy(compute_config=compute_config, pas_policy='tp', gen_savedir=tmp_path), + plugins=[NnScalerPrecision('32-true')] + ) + trainer.fit(model) + + assert model.num_epochs_end_seen == max_epochs + assert model.num_batches_seen == trainer.num_training_batches * max_epochs == trainer.global_step + assert model.num_on_load_checkpoint_called == 0 + + checkpoints = sorted(list(set(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt")))) + + assert len(checkpoints) == max_epochs + for ckpt in checkpoints: + model = TestModel() + state = pl_load(ckpt / '0.pt') + # Resume training + trainer = Trainer( + default_root_dir=tmp_path, max_epochs=2, enable_progress_bar=False, + strategy=NnScalerStrategy( + compute_config=compute_config, + pas_policy='tp', + gen_savedir=tmp_path + ), + plugins=[NnScalerPrecision('32-true')] + ) + trainer.fit(model, ckpt_path=ckpt) + assert state["global_step"] + model.num_batches_seen == trainer.global_step + assert model.num_on_load_checkpoint_called == 1 + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_ckpt_path_epoch_restored(tmp_path): + launch_torchrun(2, ckpt_path_epoch_restored_worker, tmp_path) + + +def trainer_accumulate_grad_batches_zero_grad(tmp_path, accumulate_grad_batches): + with patch("torch.optim.SGD.zero_grad") as sgd_zero_grad: + model = BoringModel() + trainer = Trainer( + num_nodes=1, + devices=2, + default_root_dir=tmp_path, + num_sanity_val_steps=0, + limit_train_batches=20, + limit_val_batches=1, + max_epochs=1, + enable_model_summary=False, + accumulate_grad_batches=accumulate_grad_batches, + strategy=NnScalerStrategy(compute_config=ComputeConfig(1, 2), pas_policy='tp', gen_savedir=tmp_path), + plugins=[NnScalerPrecision('32-true')] + ) + assert trainer.accumulate_grad_batches == accumulate_grad_batches + trainer.fit(model) + assert sgd_zero_grad.call_count == math.ceil(trainer.limit_train_batches / accumulate_grad_batches) + + +@pytest.mark.parametrize("accumulate_grad_batches", [1, 2, 3]) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_trainer_accumulate_grad_batches_zero_grad(tmp_path, accumulate_grad_batches): + launch_torchrun(2, trainer_accumulate_grad_batches_zero_grad, tmp_path, accumulate_grad_batches) From 16e99c3554ca01d899f6d8a3f3d23c8a303ea4f6 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 5 Jun 2024 06:14:42 +0000 Subject: [PATCH 1652/1892] Merged PR 2166: rename dynamic_shape to constant_folding parity check pass unit test pass --- .gitignore | 1 + docs/source/parallel_module.md | 21 +++++++++- examples/llama/generation.py | 2 +- examples/vision/swin/train.py | 2 +- nnscaler/compiler.py | 6 +-- nnscaler/graph/parser/converter.py | 18 ++++---- nnscaler/graph/parser/fx/parser.py | 16 ++++---- nnscaler/ir/cten.py | 41 +++++++++++-------- nnscaler/parallel.py | 12 +++--- nnscaler/program.py | 14 +++---- tests/autodist/graph/test_calc_flops.py | 2 +- tests/autodist/graph/test_recompute.py | 2 +- .../pas/test_shared_param_pipeline.py | 2 +- .../spmd_solver/test_cube_operator.py | 2 +- tests/autodist/spmd_solver/test_follow.py | 4 +- .../spmd_solver/test_partition_constraint.py | 2 +- .../autodist/spmd_solver/test_shared_param.py | 2 +- tests/graph/function/test_dimops.py | 4 +- tests/graph/gener/test_reducer_gen.py | 2 +- tests/graph/parser/test_converter.py | 4 +- tests/graph/parser/test_ir_obj_constant.py | 2 +- tests/graph/parser/test_parser.py | 8 ++-- tests/parallel_module/test_attr_dedup.py | 4 +- tests/parallel_module/test_embedding.py | 10 ++--- tests/parallel_module/test_gencode.py | 28 ++++++------- tests/parallel_module/test_override.py | 4 +- tests/test_program.py | 2 +- 27 files changed, 120 insertions(+), 97 deletions(-) diff --git a/.gitignore b/.gitignore index f76d9cb7..855c7e42 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ shelf # cppimport generated file .rendered.*.cpp +.nnscaler/ diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 541687f0..2421892f 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -188,7 +188,7 @@ class ComputeConfig: plan_ngpus: int runtime_ngpus: int - dynamic_shape: bool = True + constant_folding: bool = False use_zero: bool = False zero_ngroups: int = 1 @@ -207,7 +207,24 @@ class ComputeConfig: We can categorize the fields into 4 categories: 1. Trace configuration - - `dynamic_shape`: whether to use dynamic shape or static shape. + - `constant_folding`: whether to enable constant folding when generating code. + When it is true, all non-tensor non-input values will be folded into the generated code. + + For example, if user's code contains following snippet, and `bsz=1`, `num_heads=32`, `len=1024`, `hidden_dim=128` at tracing. + ```python + bsz, num_heads, len, hidden_dim = x.size() + x = x.view(bsz * num_heads, len, hidden_dim) + ``` + The code (graph) is folded into the following format + + ```python + y = x.view(32, 1024, 128) + ``` + + Constant folding is helpful to simplify the input program, + and can make the compiling process faster and reduce the communication cost at runtime. + However, user should make sure that inputs at runtime share a same schema (including shape) with tracing and correspond to a same computation graph. + Errors may be raised at runtime when this assumption is broken. 2. Compute environment configuration - `plan_ngpus`: the number of gpus to be used as a unit. The model is partitioned (TP or PP) within a unit, and then data parallelism is applied across multiple units. So every `plan_ngpus` devices holds the whole model. Furthermore, assume we have two workers, and their ranks are `rank1` and `rank2`: 1. if `rank1 // plan_gpus == rank2 // plan_ngpus`, then they are in the same unit. diff --git a/examples/llama/generation.py b/examples/llama/generation.py index 52b12b0b..d603da48 100644 --- a/examples/llama/generation.py +++ b/examples/llama/generation.py @@ -116,7 +116,7 @@ def policy(graph, resource): return graph @compile(self.model, sample_tokens, 0, - PAS=policy, model_dynamic_shape=True) + PAS=policy, model_constant_folding=False) def infer(model: torch.nn.Module, tokens: torch.Tensor, prev_pos: int): logits = model(tokens, prev_pos) return logits diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index c81cc285..867c9c8c 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -149,7 +149,7 @@ def train(args, compute_config: nnscaler.ComputeConfig): runtime_ngpus=torch.distributed.get_world_size(), use_zero=args.zero, use_end2end=True, - dynamic_shape=False, + constant_folding=True, use_pipeline=args.pp_size > 1, pipeline_nmicros=args.gbs // args.mbs, pipeline_nstages=args.pp_size, diff --git a/nnscaler/compiler.py b/nnscaler/compiler.py index 22d7dfb6..128fb363 100644 --- a/nnscaler/compiler.py +++ b/nnscaler/compiler.py @@ -38,7 +38,7 @@ def compile(model: Union[torch.nn.Module, SemanticModel], *args, PAS: Union[Callable, Tuple[Callable, Callable, Callable]] = None, - model_dynamic_shape: bool = False, + model_constant_folding: bool = True, load_graph_file: Optional[str] = None, save_graph_file: Optional[str] = None, comm_cost_fn: Optional[Callable] = None, @@ -61,7 +61,7 @@ def train_iter(model, dataloader): model (SemanticModel | torch.nn.Module): single-device model args (Tuple[Any]): compile function example inputs PAS (Callable | Tuple[Callable, Callable, Callable]): policy to transform and schedule graph - model_dynamic_shape (bool): whether to compile model with dynamic shape + model_constant_folding (bool): whether to compile model with constant folding load_graph_file (str | None): load cached graph. This will skip parsing the function and model. Note the user should keep correct `fullmodel.pt` if load_content is True. @@ -90,7 +90,7 @@ def train_iter(model, dataloader): model = SemanticModel(model) assert isinstance(model, SemanticModel), f'Require nnscaler.SemanticModel or torch.nn.Module, but got model: {type(model)}' model.save_content = load_content - model.dynamic_shape = model_dynamic_shape + model.constant_folding = model_constant_folding inputs = [model] for arg in args: diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index f614af72..30b37e17 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -104,7 +104,7 @@ def to_ir_graph( traced_model: torch.fx.GraphModule, dummy_input: Dict[str, Any], attr_savedir: Union[str, Path], - dynamic_shape: bool = False, + constant_folding: bool = True, ) -> IRGraph: """Convert torch.fx.GraphModule based model into IRGraph @@ -112,20 +112,20 @@ def to_ir_graph( traced_model (torch.fx.GraphModule): single-device model description in fx format dummy_input (Dict[str, Any]): dummy input of model, the keys are the names of forward arguments. - dynamic_shape (bool): - whether to use dynamic shape. Default False. + constant_folding (bool): + whether to enable constant folding. Default True. attr_savedir (Union[str, Path]): directory to save content (attribtes) Returns: IRGraph: IRGraph of model """ - _logger.info(f"use {'dynamic' if dynamic_shape else 'static'} shape to parse graph") + _logger.info(f"constant folding {'enabled' if constant_folding else 'disabled'} to parse graph") with no_save_tensor_hook(): inputs, nodes, outputs = FxModuleParser.parse( traced_model, dummy_input, attr_savedir=attr_savedir, - dynamic_shape=dynamic_shape, + constant_folding=constant_folding, save_content=True, ) module_name = traced_model.__class__.__name__ @@ -142,7 +142,7 @@ def convert_model( model: torch.nn.Module, dummy_input: Dict[str, Any], attr_savedir: Union[str, Path], - dynamic_shape: bool = False + constant_folding: bool = True ) -> IRGraph: """Convert torch.nn.Module based model into IRGraph @@ -150,8 +150,8 @@ def convert_model( model (torch.nn.Module): single-device model description dummy_input (Dict[str, Any]): dummy input of model, the keys are the names of forward arguments. - dynamic_shape (bool): - whether to use dynamic shape. Default False. + constant_folding (bool): + whether to use constant folding. Default True. attr_save_dir (Union[str, Path]): directory to save content (attribtes) Returns: @@ -159,5 +159,5 @@ def convert_model( """ traced_model = to_fx_graph(model, dummy_input) _logger.debug(f'the traced model is:\n{traced_model}') - graph = to_ir_graph(traced_model, dummy_input, attr_savedir, dynamic_shape) + graph = to_ir_graph(traced_model, dummy_input, attr_savedir, constant_folding) return graph diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index cc30dfcd..d9464ed5 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -35,7 +35,7 @@ def parse(module: torch.fx.GraphModule, attr_savedir='./', *, save_content: bool = True, - dynamic_shape: bool = True + constant_folding: bool = False ) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: """Parse torch.fx module into cube IR @@ -46,7 +46,7 @@ def parse(module: torch.fx.GraphModule, dummy_inputs (Dict[str, Any]): the dummy inputs to run the module attr_savedir (str): the directory to save the attribute content save_content (bool): whether to save the content of the module - dynamic_shape (bool): whether to parse the module with dynamic shape + constant_folding (bool): whether to parse the module with constant folding Returns: inputs (List[IRObject]): the input IRObjects @@ -64,7 +64,7 @@ def parse(module: torch.fx.GraphModule, if node.op == 'placeholder': FxModuleParser.init_objects(node, module, frame, dummy_inputs.get(node.name), is_constant=False) else: - FxModuleParser.init_objects(node, module, frame, None, is_constant=True) + FxModuleParser.init_objects(node, module, frame, None, is_constant=True) # get graph inputs placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] @@ -80,7 +80,7 @@ def parse(module: torch.fx.GraphModule, # parse graph nodes all_ir_nodes = [] for node in module.graph.nodes: - ir_nodes = FxModuleParser.parse_node(node, module, dynamic_shape, frame) + ir_nodes = FxModuleParser.parse_node(node, module, constant_folding, frame) all_ir_nodes += ir_nodes # get graph outputs @@ -95,7 +95,7 @@ def parse(module: torch.fx.GraphModule, return inputs, all_ir_nodes, outputs @staticmethod - def parse_node(node: torch.fx.Node, module, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: + def parse_node(node: torch.fx.Node, module, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: """ Parse the node and return the IRFwOperation nodes """ @@ -104,7 +104,7 @@ def parse_node(node: torch.fx.Node, module, dynamic_shape: bool, frame: Frame) - if node.op == 'output': return FxModuleParser.parse_prim_output_node(node, module, frame) if node.op in ('call_function', 'call_method'): - return FxModuleParser.parse_prim_function_method(node, module, dynamic_shape, frame) + return FxModuleParser.parse_prim_function_method(node, module, constant_folding, frame) if node.op == 'get_attr': return FxModuleParser.parse_prim_get_attr_node(node, module, frame) if node.op == 'call_module': @@ -202,7 +202,7 @@ def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: raise RuntimeError(f'unknown module: {prim_module.__class__.__module__}') @staticmethod - def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, dynamic_shape: bool, frame: Frame) -> List[IRFwOperation]: + def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: # get signature fsig = FxModuleParser._get_qualified_name(node.target, node) # get inputs @@ -243,7 +243,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule for i in range(len(vals)): ir_node.set_output(i, vals[i]) elif not isinstance(ir_node.output(0), IRTensor) and ir_node.output(0).value is not None: - if dynamic_shape or \ + if not constant_folding or \ any_ir_object_satisfy(ir_node.output(0), lambda a: not a.is_constant) or \ any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, IRTensor)) or \ any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, (DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE))): diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 5bfa4bd8..a822a99e 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -46,7 +46,7 @@ def __init__(self, to prevent users accidently updating it outside. To update the IRObject in input / kwarg / output, please use `find`, `input(s)`, - and `output(s)` to get the real instance tensor in the IRCell. + and `output(s)` to get the real instance tensor in the IRCell. Args: name (str): the cell name @@ -106,7 +106,7 @@ def dispatch(self, device: int): For single operators, the mirror node will be reserved. For nodes that cover multiple devices, e.g., IRSegment and IRAdapter, the mirror node will be removed and require additional `make_pair` elsewhere. - + @param device int: device id @return dispatched_node IRCell: the node that only has one device placement. """ @@ -171,7 +171,7 @@ def inputs(self) -> Tuple[NestedVarOrStatic]: Tuple[NestedVarOrStatic] """ return tuple(self._inputs) - + @lru_cache(maxsize=None) def iobjs(self) -> Tuple[IRObject]: """ @@ -194,7 +194,7 @@ def output(self, index: int) -> NestedVarOrStatic: NestedVarOrStatic: (nested) IRObject or any static value (int, bool, str, etc) """ return self._outputs[index] - + @lru_cache(maxsize=None) def oobjs(self) -> Tuple[IRObject]: """ @@ -266,7 +266,7 @@ def set_output(self, index: int, val: NestedVarOrStatic): self.outputs.cache_clear() self.oobjs.cache_clear() return val - + def replace_input(self, old: IRObject, new: IRObject): """Replace the old input (including kwargs) with the new input @@ -283,7 +283,7 @@ def replace(obj): self._kwargs = IRCell.modify_objects_of_complex(self._kwargs, replace) self.inputs.cache_clear() self.iobjs.cache_clear() - + def replace_output(self, old: IRObject, new: IRObject): """Replace the old output with the new output @@ -299,7 +299,7 @@ def replace(obj): self._outputs = IRCell.modify_objects_of_complex(self._outputs, replace) self.outputs.cache_clear() self.oobjs.cache_clear() - + def replace(self, old: IRObject, new: IRObject): """Replace the old object with the new object in inputs, kwargs, and outputs @@ -335,7 +335,7 @@ def comment(self, info: str): Tag an info to the cell """ assert isinstance(info, str), "comment only allowed to be string" - self._comment = info + self._comment = info @property def module_stack(self) -> Optional[OrderedDict[str, Any]]: @@ -357,13 +357,13 @@ def __repr__(self) -> str: f"inputs={ins}, " f"outputs={self.outputs()})") return dscp - + @staticmethod def get_objects_from_complex(val: Any, _objects: List[IRObject] = None) -> List[IRObject]: """Get all IRObjects from a complex data structure Supported complex of types: List, Tuple, Dict, IRTensor, IRObject - + Args: val (Any): the complex data structure to be modified _objects (List[IRObject] | None): @@ -419,10 +419,15 @@ class IRObject: def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: Optional[None] = None, is_constant: bool = True): """ - @param name str: object name - @param tid int: object unique id - @param val any: the value of this object - @param is_constant bool: if the value is a constant during the whole training / inference + Args: + name (str): object name + tid (int): object unique id + val (Any): the value of this object + is_constant (bool): if the value is a constant during the whole training / inference + This flag is only used in constant_folding mode, to prevent the object from being folded. + An IROject is considered constant only when: + 1. val is not a tensor + 2. val is model input, or is the result of a non-torch operation on another constant IRObject """ self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() self.name: str = name if name else 'obj' @@ -458,7 +463,7 @@ def tid(self) -> int: @property def cell(self) -> IRCell: return self._cell - + @cell.setter def cell(self, val: Optional[IRCell]): assert isinstance(val, IRCell) or val is None, "Expected cell to be Optional[IRCell]" @@ -476,7 +481,7 @@ def device(self, val: Union[int, List[int]]): raise RuntimeError( "IRObject placement is not allowed to set manually" ) - + @property def parent(self): """Get parent""" @@ -533,7 +538,7 @@ class IRTensor(IRObject): IRTensor serves as tensor data of IRGraph edge Note by setting IRTensor name to "None" indicates this tensor holds nothing - and will be translated to None in code generation. + and will be translated to None in code generation. """ _meta = ['name', '_is_attr', '_is_grad', '_requires_grad', '_dtype', '_persistent'] @@ -582,7 +587,7 @@ def is_buffer(self) -> bool: @return is_buffer boolean: True if is buffer. """ return self._is_attr and not self.requires_grad - + def is_persistent(self) -> bool: """! Check if the tensor is persistent buffer. diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 2e2d73a6..99a4304d 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -67,8 +67,8 @@ class ComputeConfig: plan_ngpus: int runtime_ngpus: int - # whether to use dynamic shape to generate code - dynamic_shape: bool = True + # whether to fold constant when generating code + constant_folding: bool = False use_zero: bool = False zero_ngroups: int = 1 @@ -176,7 +176,7 @@ def gpu_config(self) -> Dict[str, int]: @property def graph_config(self) -> Dict[str, Any]: return { - 'dynamic_shape': self.dynamic_shape, + 'constant_folding': self.constant_folding, 'user_config': self.user_config, 'inference_only': self.inference_only, # there will be no backward nodes in the graph in inference mode 'use_pipeline': self.use_pipeline, # pipeline option can affect the graph generation. @@ -588,7 +588,7 @@ def _gen_graph( module: torch.nn.Module, dummy_input: dict, outdir: Path, - dynamic_shape: bool, + constant_folding: bool, end2end_mode: bool = False, inference_only: bool = False, use_pipeline: bool = False, @@ -610,7 +610,7 @@ def _gen_graph( # generate ir logic graph ir_graph = parser.to_ir_graph( - fx_graph, dummy_input, outdir, dynamic_shape + fx_graph, dummy_input, outdir, constant_folding ) # generate dummy inputs for logic graph @@ -754,7 +754,7 @@ def _gencode( graph, forward_args = _gen_graph( module, dummy_input, outdir, - dynamic_shape=compute_config.dynamic_shape, end2end_mode=compute_config.use_end2end, + constant_folding=compute_config.constant_folding, end2end_mode=compute_config.use_end2end, inference_only=compute_config.inference_only, use_pipeline=compute_config.use_pipeline, ) diff --git a/nnscaler/program.py b/nnscaler/program.py index 03b6abfd..bfc1e6f2 100644 --- a/nnscaler/program.py +++ b/nnscaler/program.py @@ -87,7 +87,7 @@ def __init__(self, dataloader: MicroBatchDataLoader): """Create semantic dataloader representing the dataloader in training iteration. Calling `next(SemanticDataLoader)` will generate an IRDataOperation in graph, - which takes the `self.irobj` (i.e., reperesenting the non-tensor value of real + which takes the `self.irobj` (i.e., reperesenting the non-tensor value of real dataloader instance) as input and produces outputs that are converted to IRObject or IRTensor. The IRDataOperation will be added to the final graph and generate code like `data = next(dataloader)` @@ -100,7 +100,7 @@ def __init__(self, dataloader: MicroBatchDataLoader): self.dataloader: data.DataLoader = dataloader # the IRObject representing the `dataloader` instance, which is only used by the # IRDataOperation. Since we already know the output of the dataloader, - # we don't need to set the value for it. + # we don't need to set the value for it. self.irobj = IRObject(name='dataloader', value=None) def __iter__(self): @@ -141,7 +141,7 @@ class SemanticModel: def __init__(self, model: Optional[torch.nn.Module], save_content: bool = True, - dynamic_shape: bool = False, + constant_folding: bool = True, attr_savedir: str = './', ): """ @@ -152,8 +152,8 @@ def __init__(self, model: Optional[torch.nn.Module], single-device model description, only required for rank 0 save_content (bool): whether to save the content of model and load it into generated model. Default True. - dynamic_shape (bool): - whether to use dynamic shape. Default False. + constant_folding (bool): + whether to enable constant folding. Default True. attr_savedir (str): directory to save content (attribtes) """ @@ -165,7 +165,7 @@ def __init__(self, model: Optional[torch.nn.Module], self._loaded_module: CubeModule = None # parser configuration self.save_content: bool = save_content - self.dynamic_shape: bool = dynamic_shape + self.constant_folding: bool = constant_folding self.attr_savedir: str = attr_savedir @property @@ -244,6 +244,6 @@ def __call__(self, *args): self.model, dummy_input=self.dummy_input, attr_savedir=self.attr_savedir, - dynamic_shape=self.dynamic_shape + constant_folding=self.constant_folding ) return self._ir_graph(*args) diff --git a/tests/autodist/graph/test_calc_flops.py b/tests/autodist/graph/test_calc_flops.py index dac7e72c..1f1d8295 100644 --- a/tests/autodist/graph/test_calc_flops.py +++ b/tests/autodist/graph/test_calc_flops.py @@ -43,7 +43,7 @@ def test_calc_flops(): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, - dynamic_shape=True) + constant_folding=False) nodes = ir_graph.select(ntype=IRFwOperation) assert calc_flops( nodes[0]) == 2 * batch_size * hidden_dim * hidden_dim * hidden_dim diff --git a/tests/autodist/graph/test_recompute.py b/tests/autodist/graph/test_recompute.py index a75d4ab4..8fe551ba 100644 --- a/tests/autodist/graph/test_recompute.py +++ b/tests/autodist/graph/test_recompute.py @@ -88,7 +88,7 @@ def test_recompute(): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, - dynamic_shape=True) + constant_folding=False) config = AutoDistConfig(recompute_modules='Decoder.Layer') model_graph = ModelGraph(ir_graph, config) diff --git a/tests/autodist/pas/test_shared_param_pipeline.py b/tests/autodist/pas/test_shared_param_pipeline.py index 2f784287..482f9365 100644 --- a/tests/autodist/pas/test_shared_param_pipeline.py +++ b/tests/autodist/pas/test_shared_param_pipeline.py @@ -53,7 +53,7 @@ def test_shared_param_pipeline(): smodel = SemanticModel(model, attr_savedir=tempdir) smodel.dummy_input = {'x': torch.randn(bsz, hidden_dim)} - smodel.dynamic_shape = False + smodel.constant_folding = True program.set_input([dataloader.irobj]) ir_dummy_input = next(dataloader) outputs = smodel(ir_dummy_input) diff --git a/tests/autodist/spmd_solver/test_cube_operator.py b/tests/autodist/spmd_solver/test_cube_operator.py index 0ce7c21c..29cab169 100644 --- a/tests/autodist/spmd_solver/test_cube_operator.py +++ b/tests/autodist/spmd_solver/test_cube_operator.py @@ -42,7 +42,7 @@ def test_cube_operator(): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, - dynamic_shape=False) + constant_folding=True) cfg = AutoDistConfig(mesh_col=2) model_graph = ModelGraph(ir_graph, cfg) mock_attention_op = model_graph.operator_list[0] diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index e992d485..bd9a06c3 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -54,7 +54,7 @@ def test_follow_rope(): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, - dynamic_shape=False) + constant_folding=True) ''' the computation graph is as follows: getitem getitem @@ -167,7 +167,7 @@ def test_follow_attention(): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, - dynamic_shape=False) + constant_folding=True) print(ir_graph.nodes()) ''' the computation graph is as follows: diff --git a/tests/autodist/spmd_solver/test_partition_constraint.py b/tests/autodist/spmd_solver/test_partition_constraint.py index d1cee788..95bacaef 100644 --- a/tests/autodist/spmd_solver/test_partition_constraint.py +++ b/tests/autodist/spmd_solver/test_partition_constraint.py @@ -71,7 +71,7 @@ def test_partition_constraint(): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, - dynamic_shape=False) + constant_folding=True) pc_path = Path(os.path.dirname( os.path.realpath(__file__))) / 'test_pc.yaml' diff --git a/tests/autodist/spmd_solver/test_shared_param.py b/tests/autodist/spmd_solver/test_shared_param.py index f1312644..920d93a1 100644 --- a/tests/autodist/spmd_solver/test_shared_param.py +++ b/tests/autodist/spmd_solver/test_shared_param.py @@ -42,7 +42,7 @@ def test_shared_param(): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, - dynamic_shape=False) + constant_folding=True) cfg = AutoDistConfig(mesh_col=4) model_graph = ModelGraph(ir_graph, cfg) diff --git a/tests/graph/function/test_dimops.py b/tests/graph/function/test_dimops.py index b516fe12..900215f4 100644 --- a/tests/graph/function/test_dimops.py +++ b/tests/graph/function/test_dimops.py @@ -76,8 +76,8 @@ def TestFunc(input, weight, number=128, signature='test_func'): partitionable(op, idx=0, dim=0, num=2) -def test_dynamic_shape_infer(): - # TODO: please note that this test should be rewritten after we can fully support dynamic shape +def test_constant_folding_infer(): + # TODO: please note that this test should be rewritten after we can fully support constant folding def TestFunc(input, weight, bias, number=128, signature='test_func'): anno = '(a number), (b number), (a b) -> 1' return IRDimops(TestFunc, 'test_func', signature, [anno], [input, weight, bias], number=number) diff --git a/tests/graph/gener/test_reducer_gen.py b/tests/graph/gener/test_reducer_gen.py index 332336b7..c318d39b 100644 --- a/tests/graph/gener/test_reducer_gen.py +++ b/tests/graph/gener/test_reducer_gen.py @@ -43,7 +43,7 @@ def build_graph(): model, {'x': torch.randn([128, 128], dtype=torch.float16)}, attr_savedir=tempdir, - dynamic_shape=False + constant_folding=True ) graph.backward(graph.output(0)) return graph diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index e960488d..65e4e905 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -41,8 +41,8 @@ def forward(self, x, **kwargs): assert any(node.op == 'call_function' and node.target == torch.nn.functional.linear for node in nodes) with tempfile.TemporaryDirectory() as tempdir: - to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) - ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) assert ir_graph is not None assert (Path(tempdir) / FxModuleParser.ATTR_MAP_FILE).exists() assert (Path(tempdir) / FxModuleParser.ATTR_CONTENT_FILE_0).exists() diff --git a/tests/graph/parser/test_ir_obj_constant.py b/tests/graph/parser/test_ir_obj_constant.py index d493e2fb..0263d824 100644 --- a/tests/graph/parser/test_ir_obj_constant.py +++ b/tests/graph/parser/test_ir_obj_constant.py @@ -30,7 +30,7 @@ def forward(self, sample): return self.fc(sample['x']), res with tempfile.TemporaryDirectory() as tempdir: - cube_graph = convert_model(SimpleModel(), {'sample': {'x': torch.rand(4, 10), 'y': 10}}, tempdir, dynamic_shape=True) + cube_graph = convert_model(SimpleModel(), {'sample': {'x': torch.rand(4, 10), 'y': 10}}, tempdir, constant_folding=False) # check input is not constant assert not cube_graph.input(0).value['y'].is_constant for i, name in enumerate(['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow', 'sub', 'neg', 'exp', 'sqrt']): diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 9c6f6d6d..2b67a7e2 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -31,7 +31,7 @@ def forward(self, x): print(fx_graph.graph) with tempfile.TemporaryDirectory() as tempdir: - ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) assert ir_graph is not None assert len(ir_graph.attributes()) == 2 # param1 and param2 assert len(ir_graph.full_tensors()) == 8 @@ -60,7 +60,7 @@ def forward(self, x: dict): print(fx_graph.graph) with tempfile.TemporaryDirectory() as tempdir: - ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) print(ir_graph.extra_repr()) assert len(ir_graph.inputs()) == 1 @@ -88,7 +88,7 @@ def forward(self, x): print(fx_graph.graph) with tempfile.TemporaryDirectory() as tempdir: - ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) print(ir_graph.extra_repr()) assert isinstance(ir_graph.output(0), IRTensor) @@ -111,7 +111,7 @@ def forward(self, x): print(fx_graph.graph) with tempfile.TemporaryDirectory() as tempdir: - ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, dynamic_shape=True) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) print(ir_graph.extra_repr()) assert isinstance(ir_graph.output(0), IRTensor) diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py index 8e5312bc..146a9c98 100644 --- a/tests/parallel_module/test_attr_dedup.py +++ b/tests/parallel_module/test_attr_dedup.py @@ -54,7 +54,7 @@ def _gpu_worker_spmd(cc: ComputeConfig): {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, pas, cc, - cube_savedir=tempdir, + gen_savedir=tempdir, instance_name='attr_dedup' ) print(module.fullmap) @@ -84,7 +84,7 @@ def _gpu_worker_spmd(cc: ComputeConfig): assert dedup_area_map[0][1].slicers == (slice(2, 4, None), slice(0, 4, None)) else: raise RuntimeError(f'Unexpected rank {curr_rank}') - + @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_attr_dedup(): diff --git a/tests/parallel_module/test_embedding.py b/tests/parallel_module/test_embedding.py index 6192ddb9..21d1a8d1 100644 --- a/tests/parallel_module/test_embedding.py +++ b/tests/parallel_module/test_embedding.py @@ -11,21 +11,21 @@ def __init__(self): def forward(self, x): return self.embed(x).sum() - + @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_requires_grad(): model = Model() model.train() - + dummy_input = {'x': torch.randint(0, 10, (10, 10))} - + with tempfile.TemporaryDirectory() as tempdir: - + graph, _ = _gen_graph( model, dummy_input, outdir=tempdir, - dynamic_shape=False, + constant_folding=True, end2end_mode=True, ) embed_op = graph.nodes()[1] diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index c455f6db..5efff8ab 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -704,7 +704,7 @@ def test_codegen_end2end(): dim = 1024 nlayers = 16 batch_size = 64 - def p(cube_dir, use_pipeline, dynamic_shape, return_type, inference_only=False): + def p(cube_dir, use_pipeline, constant_folding, return_type, inference_only=False): m = End2EndModule(dim, nlayers) m.train() parallelize( @@ -714,7 +714,7 @@ def p(cube_dir, use_pipeline, dynamic_shape, return_type, inference_only=False): compute_config= ComputeConfig( 4, 4, inference_only=inference_only, - dynamic_shape=dynamic_shape, + constant_folding=constant_folding, use_end2end=True, use_pipeline=use_pipeline, pipeline_nmicros=4, @@ -727,33 +727,33 @@ def p(cube_dir, use_pipeline, dynamic_shape, return_type, inference_only=False): ) with tempfile.TemporaryDirectory() as tempdir: for use_pipeline in [True, False]: - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=0) # should success + p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=0) # should success assert not _gencode_contains(tempdir, End2EndModule, 0, r"self\.register_buffer" ) assert _gencode_contains(tempdir, End2EndModule, 0, r"self\.register_parameter" ) - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=0) # should success + p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=0) # should success if use_pipeline: with pytest.raises(RuntimeError, match='.*Communication generation.*'): # fail for non-tensor IRObject return in pipeline mode - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=1) + p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=1) else: - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=1) - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=1) # should success - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=2) # should success - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=2) # should success + p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=1) + p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=1) # should success + p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=2) # should success + p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=2) # should success with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=3) + p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=3) with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=3) + p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=3) with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=4) + p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=4) with pytest.raises(RuntimeError, match='.*Loss can only be scalar tensor.*'): - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=False, return_type=4) + p(tempdir, use_pipeline=use_pipeline, constant_folding=True, return_type=4) - p(tempdir, use_pipeline=use_pipeline, dynamic_shape=True, return_type=0, inference_only=True) # should success + p(tempdir, use_pipeline=use_pipeline, constant_folding=False, return_type=0, inference_only=True) # should success assert not _gencode_contains(tempdir, End2EndModule, 0, r"self\.register_parameter" ) diff --git a/tests/parallel_module/test_override.py b/tests/parallel_module/test_override.py index c8b9dc8f..90da6b6d 100644 --- a/tests/parallel_module/test_override.py +++ b/tests/parallel_module/test_override.py @@ -83,7 +83,7 @@ def test_override(): # MOO | unmatch | generate _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o1', load_module=False) - _to_cube_model(MyModule, ComputeConfig(2, 2, dynamic_shape=False),tempdir, ReuseType.MOO, 'o1') + _to_cube_model(MyModule, ComputeConfig(2, 2, constant_folding=True),tempdir, ReuseType.MOO, 'o1') # MOO | imported | raise error _to_cube_model(MyModule, ComputeConfig(1, 1),tempdir, ReuseType.MOO, 'o2', load_module=True) @@ -173,7 +173,7 @@ def test_override(): g7_module_path = module_path.with_name('g7') graph_stat = (g7_module_path / 'graph.ckp').stat() args_stat = (g7_module_path / 'forward_args.pkl').stat() - _to_cube_model(MyModule, ComputeConfig(2, 2, dynamic_shape=False), tempdir, 'graph', 'g7', False) + _to_cube_model(MyModule, ComputeConfig(2, 2, constant_folding=True), tempdir, 'graph', 'g7', False) assert (g7_module_path / 'graph.ckp').stat().st_mtime_ns != graph_stat.st_mtime_ns assert (g7_module_path / 'forward_args.pkl').stat().st_mtime_ns != args_stat.st_mtime_ns diff --git a/tests/test_program.py b/tests/test_program.py index 53b34b3c..8ed79062 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -30,7 +30,7 @@ def forward(self, x: dict): dummy_input = {'x': {'data': torch.randn(4, 4)}} module = MyModule() - model = SemanticModel(module, save_content=False, dynamic_shape=False) + model = SemanticModel(module, save_content=False, constant_folding=True) obj = IRObject(value=dummy_input['x']) model(obj) From 906a57427be975b080f2dfdc7317bd0e06270e47 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 13 Jun 2024 06:37:38 +0000 Subject: [PATCH 1653/1892] Merged PR 2173: Refine infer batch dim in autodist Assume operators with parameters consume and generate tensors with batch dim. A search is followed to propagate the possible batch dim to the whole graph. Add a test to check autodist will generate data parallel plan. parity check passed --- nnscaler/autodist/model_graph.py | 19 +- nnscaler/autodist/spmd_solver.py | 2 +- nnscaler/compiler.py | 1 - nnscaler/profiler/memory.py | 1 - tests/autodist/graph/test_recompute.py | 4 + tests/autodist/spmd_solver/test_follow.py | 72 ++++++ .../comp/torch.Tensor.contiguous.json | 62 +++++ .../comp/torch.Tensor.reshape.json | 50 ++++ .../comp/torch.Tensor.view.json | 50 ++++ .../comp/torch.div.json | 62 +++++ .../comp/torch.matmul.json | 230 ++++++++++++++++++ .../comp/torch.nn.functional.dropout.json | 62 +++++ .../comp/torch.nn.functional.linear.json | 182 ++++++++++++++ .../comp/torch.nn.functional.softmax.json | 66 +++++ .../comp/torch.sum.json | 50 ++++ .../comp/torch.transpose.json | 182 ++++++++++++++ 16 files changed, 1081 insertions(+), 14 deletions(-) create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.contiguous.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.reshape.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.view.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.div.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.matmul.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.dropout.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.linear.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.softmax.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.sum.json create mode 100644 tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.transpose.json diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 99b0ff3f..23d95a08 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -882,20 +882,17 @@ def init_operators(self): src_op.add_consumer(dst_op) dst_op.add_producer(src_op) - # infer batch dims + # Infer batch dims + # Assume operators with parameters consume and generate tensors + # with batch dim. A search is followed to propagate the possible + # batch dim to the whole graph. seed_ops = [] visited = set() for op in operator_list: - if len(op.producers) == 0 and len(op.in_tensors) > 0: - contain_non_param = False - for t in op.in_tensors: - if not t.is_attr(): - contain_non_param = True - break - if contain_non_param: - _logger.info(f'add seed op {op.ir_cell}') - seed_ops.append(op) - visited.add(op.ir_cell.cid) + if any([t.is_param() for t in op.in_tensors]): + _logger.debug(f'add seed op {op.ir_cell}') + seed_ops.append(op) + visited.add(op.ir_cell.cid) dq = deque(seed_ops) while len(dq) > 0: op = dq.popleft() diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 02a9ee73..b2027e2e 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -1188,6 +1188,7 @@ def analyze_plan(self, plan: List[Tuple[int, int]]) -> Dict[str, Any]: sig2split_info[sig] = [] sig2comp_time[sig] += desc.comp_time op_idx2comp_time[op_idx] = desc.comp_time + comp_time_sum += desc.comp_time comm_cost = 0 for k, comm_vec in enumerate(desc.comm_time): producer = self.producers[op_idx][k] @@ -1207,7 +1208,6 @@ def analyze_plan(self, plan: List[Tuple[int, int]]) -> Dict[str, Any]: if sig not in sig2comm_time: sig2comm_time[sig] = 0 sig2comm_time[sig] += comm_cost - comp_time_sum += desc.comp_time comm_time_sum += comm_cost sig2comp_time = sorted(sig2comp_time.items(), key=lambda x: x[1], reverse=True) diff --git a/nnscaler/compiler.py b/nnscaler/compiler.py index 128fb363..066d9204 100644 --- a/nnscaler/compiler.py +++ b/nnscaler/compiler.py @@ -33,7 +33,6 @@ _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) def compile(model: Union[torch.nn.Module, SemanticModel], *args, diff --git a/nnscaler/profiler/memory.py b/nnscaler/profiler/memory.py index b7e82cd6..eeda2b0c 100644 --- a/nnscaler/profiler/memory.py +++ b/nnscaler/profiler/memory.py @@ -4,7 +4,6 @@ import torch _logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) def memory_summary(): diff --git a/tests/autodist/graph/test_recompute.py b/tests/autodist/graph/test_recompute.py index 8fe551ba..f1dbaccc 100644 --- a/tests/autodist/graph/test_recompute.py +++ b/tests/autodist/graph/test_recompute.py @@ -143,3 +143,7 @@ def test_recompute(): assert model_graph.min_recompute_mem == layer_node.train_mem fnodes = ir_graph.select(ntype=IRFwOperation) assert model_graph.recompute_groups == [fnodes[5 * (num_layers + i) : 5 * (num_layers + i) + 5] for i in range(num_layers)] + + # will label operator like GELU and add with `has_batch_dim=True` + for op in model_graph.operator_list: + assert op.has_batch_dim, f'{op} does not have batch dim' diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index bd9a06c3..03221091 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -283,3 +283,75 @@ def helper(search_out): ilp_spmd_outs = spmd_solver.do_ilp([(0, model_graph.op_num - 1)], 1) assert helper(dp_spmd_outs) == expected_out assert helper(ilp_spmd_outs) == expected_out + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') +def test_solver_data_parallel(): + from nnscaler.ir.unique import IDGenerator + IDGenerator().clear() + bsz, seq_len, hidden_dim, num_heads = 2, 2048, 512, 8 + dummy_input = { + 'x': torch.rand(bsz, seq_len, hidden_dim), + } + model = AttentionModel(hidden_dim, num_heads) + model.train() + fx_graph = to_fx_graph(model, dummy_input) + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = to_ir_graph(fx_graph, + dummy_input, + attr_savedir=tempdir, + constant_folding=True) + print(ir_graph.nodes()) + + profile_dir = Path(os.path.dirname(__file__)) / './test_solver_data_parallel' + cfg = AutoDistConfig(mesh_col=2, profile_dir=profile_dir) + model_graph = ModelGraph(ir_graph, cfg) + + spmd_solver = SPMDSolver( + graph=model_graph, + mesh_desc=cfg.mesh_desc, + autodist_config=cfg, + stage_num=1, + micro_batch_num=cfg.update_freq, + ) + + partition_counts = [ + spmd_solver.get_op_partition_count(i) + for i in range(model_graph.op_num) + ] + print(partition_counts) + assert partition_counts == [ + 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 6, 4, 4, 4, 6, 4, 4, 4, 5, 4 + ] + # should generate a pure data parallel plan, e.g., partition the batch dim + expected_out = [ + (2, (((0, 0), 2),)), + (3, (((0, 0), 2),)), + (4, (((0, 0), 2),)), + (5, (((0, 0), 2),)), + (6, (((0, 0), 2),)), + (7, (((0, 0), 2),)), + (8, (((0, 0), 2),)), + (9, (((0, 0), 2),)), + (10, (((0, 0), 2),)), + (11, (((0, 0), 2),)), + (12, (((0, 0), 2),)), + (13, (((0, 0), 2),)), + (14, (((0, 0), 2),)), + (15, (((0, 0), 2),)), + (16, (((0, 0), 2),)), + (17, (((0, 0), 2),)), + (18, (((0, 0), 2),)), + (19, (((0, 0), 2),)), + (20, (((0, 0), 2),)), + (21, (((0, 0), 2),)) + ] + + def helper(search_out): + return search_out[0][0].to_json()['desc']['partition_descs'] + + dp_spmd_outs = spmd_solver.do_dp([(0, model_graph.op_num - 1)], 1) + ilp_spmd_outs = spmd_solver.do_ilp([(0, model_graph.op_num - 1)], 1) + assert helper(dp_spmd_outs) == expected_out + assert helper(ilp_spmd_outs) == expected_out + diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.contiguous.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.contiguous.json new file mode 100644 index 00000000..66990601 --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.contiguous.json @@ -0,0 +1,62 @@ +{ + "(2, 2048, 8, 64)-(2, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00392962247133255, + "bw_span": 0.0361565500497818, + "infer_memory": 8388608, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 2048, 8, 64)-(1, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005548819899559021, + "bw_span": 0.022963620722293854, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1024, 8, 64)-(2, 1024, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006691738963127136, + "bw_span": 0.0254802405834198, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 2048, 4, 64)-(2, 2048, 4, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006204098463058472, + "bw_span": 0.023501552641391754, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 2048, 8, 32)-(2, 2048, 8, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.004957988858222961, + "bw_span": 0.01678112894296646, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.reshape.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.reshape.json new file mode 100644 index 00000000..f0933e95 --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.reshape.json @@ -0,0 +1,50 @@ +{ + "(2, 2048, 8, 64)-(2, 2048, 512) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.007220730185508728, + "bw_span": 0.026119686663150787, + "infer_memory": 8388608, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 2048, 8, 64)-(1, 2048, 512) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009209662675857544, + "bw_span": 0.0355185940861702, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1024, 8, 64)-(2, 1024, 512) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00962037593126297, + "bw_span": 0.03517419099807739, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 2048, 4, 64)-(2, 2048, 256) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009422190487384796, + "bw_span": 0.03640800714492798, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.view.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.view.json new file mode 100644 index 00000000..68f18636 --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.Tensor.view.json @@ -0,0 +1,50 @@ +{ + "(2, 2048, 512)-(2, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009936653077602386, + "bw_span": 0.033933669328689575, + "infer_memory": 8388608, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 2048, 512)-(1, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010452233254909515, + "bw_span": 0.036592036485672, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1024, 512)-(2, 1024, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010039843618869781, + "bw_span": 0.04428252577781677, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 2048, 256)-(2, 2048, 4, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009671784937381744, + "bw_span": 0.03549344837665558, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.div.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.div.json new file mode 100644 index 00000000..e2375b2c --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.div.json @@ -0,0 +1,62 @@ +{ + "(2, 8, 2048, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 268435456 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.7881719619035721, + "bw_span": 1.9588613882660866, + "infer_memory": 536870912, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 2048, 2048)-(1, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.3966972231864929, + "bw_span": 0.9810343384742737, + "infer_memory": 268435456, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 2048, 2048)-(2, 4, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.3962540999054909, + "bw_span": 0.9834336116909981, + "infer_memory": 268435456, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 1024, 2048)-(2, 8, 1024, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.3961294889450073, + "bw_span": 0.9815100580453873, + "infer_memory": 268435456, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 2048, 1024)-(2, 8, 2048, 1024) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.39616115391254425, + "bw_span": 0.9814225137233734, + "infer_memory": 268435456, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.matmul.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.matmul.json new file mode 100644 index 00000000..aa662f0b --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.matmul.json @@ -0,0 +1,230 @@ +{ + "(2, 8, 2048, 64)-(2, 8, 64, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 8388608, + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.5448272451758385, + "bw_span": 2.108476497232914, + "infer_memory": 285212672, + "train_mem_info": [ + 8388608, + 8388608 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 2048, 2048)-(2, 8, 2048, 64)-(2, 8, 2048, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 268435456, + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.9633848443627357, + "bw_span": 2.7184441685676575, + "infer_memory": 285212672, + "train_mem_info": [ + 8388608, + 268435456 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(1, 8, 2048, 64)-(1, 8, 64, 2048)-(1, 8, 2048, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 4194304, + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.2818003296852112, + "bw_span": 1.0446002706885338, + "infer_memory": 142606336, + "train_mem_info": [ + 4194304, + 4194304 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 4, 2048, 64)-(2, 4, 64, 2048)-(2, 4, 2048, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 4194304, + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.28081703931093216, + "bw_span": 1.0572420433163643, + "infer_memory": 142606336, + "train_mem_info": [ + 4194304, + 4194304 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 1024, 64)-(2, 8, 64, 2048)-(2, 8, 1024, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 4194304, + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.28518345206975937, + "bw_span": 1.0718883946537971, + "infer_memory": 146800640, + "train_mem_info": [ + 8388608, + 4194304 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 2048, 32)-(2, 8, 32, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 4194304, + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.4970652982592583, + "bw_span": 1.9734794273972511, + "infer_memory": 276824064, + "train_mem_info": [ + 4194304, + 4194304 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 2048, 64)-(2, 8, 64, 1024)-(2, 8, 2048, 1024) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 8388608, + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.2919815480709076, + "bw_span": 1.0712673887610435, + "infer_memory": 146800640, + "train_mem_info": [ + 4194304, + 8388608 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(1, 8, 2048, 2048)-(1, 8, 2048, 64)-(1, 8, 2048, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 134217728, + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.4901982843875885, + "bw_span": 1.4053288847208023, + "infer_memory": 142606336, + "train_mem_info": [ + 4194304, + 134217728 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 4, 2048, 2048)-(2, 4, 2048, 64)-(2, 4, 2048, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 134217728, + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.4908433184027672, + "bw_span": 1.399235613644123, + "infer_memory": 142606336, + "train_mem_info": [ + 4194304, + 134217728 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 1024, 2048)-(2, 8, 2048, 64)-(2, 8, 1024, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 134217728, + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.4898613318800926, + "bw_span": 1.4226442202925682, + "infer_memory": 146800640, + "train_mem_info": [ + 8388608, + 134217728 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 2048, 1024)-(2, 8, 1024, 64)-(2, 8, 2048, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 134217728, + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.5006546154618263, + "bw_span": 1.410149410367012, + "infer_memory": 146800640, + "train_mem_info": [ + 4194304, + 134217728 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + }, + "(2, 8, 2048, 2048)-(2, 8, 2048, 32)-(2, 8, 2048, 32) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 268435456, + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.9725717827677727, + "bw_span": 2.6656288653612137, + "infer_memory": 276824064, + "train_mem_info": [ + 4194304, + 268435456 + ], + "train_mem2in_idx": [ + 1, + 0 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.dropout.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.dropout.json new file mode 100644 index 00000000..c7ca2767 --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.dropout.json @@ -0,0 +1,62 @@ +{ + "(2, 8, 2048, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 268435456 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.0060519203543663025, + "bw_span": 1.1683166027069092, + "infer_memory": 268435456, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 2048, 2048)-(1, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005933083593845367, + "bw_span": 0.5830274894833565, + "infer_memory": 134217728, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 2048, 2048)-(2, 4, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005894899368286133, + "bw_span": 0.5836460739374161, + "infer_memory": 134217728, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 1024, 2048)-(2, 8, 1024, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006268918514251709, + "bw_span": 0.5835466086864471, + "infer_memory": 134217728, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 2048, 1024)-(2, 8, 2048, 1024) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006314180791378021, + "bw_span": 0.5827156826853752, + "infer_memory": 134217728, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.linear.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.linear.json new file mode 100644 index 00000000..2882d7e6 --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.linear.json @@ -0,0 +1,182 @@ +{ + "(2, 2048, 512)-(512, 512)-(2, 2048, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.13340581208467484, + "bw_span": 0.11256430298089984, + "infer_memory": 17825792, + "train_mem_info": [ + 8388608 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 2048, 512)-(512, 512)-(2, 2048, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.12639649212360382, + "bw_span": 0.2559272572398186, + "infer_memory": 17825792, + "train_mem_info": [ + 8388608 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 2048, 512)-(512, 512)-(1, 2048, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.06875265389680862, + "bw_span": 0.08587203919887543, + "infer_memory": 9437184, + "train_mem_info": [ + 4194304 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 1024, 512)-(512, 512)-(2, 1024, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.06851088255643845, + "bw_span": 0.08921511471271515, + "infer_memory": 9437184, + "train_mem_info": [ + 4194304 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 2048, 256)-(512, 256)-(2, 2048, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.07312837988138199, + "bw_span": 0.09327642619609833, + "infer_memory": 13107200, + "train_mem_info": [ + 4194304 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 2048, 512)-(256, 512)-(2, 2048, 256) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.0704590231180191, + "bw_span": 0.08722972124814987, + "infer_memory": 13107200, + "train_mem_info": [ + 8388608 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 2048, 512)-(512, 512)-(1, 2048, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.0701218843460083, + "bw_span": 0.14707427471876144, + "infer_memory": 9437184, + "train_mem_info": [ + 4194304 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 1024, 512)-(512, 512)-(2, 1024, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.07040128111839294, + "bw_span": 0.14658160507678986, + "infer_memory": 9437184, + "train_mem_info": [ + 4194304 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 2048, 256)-(512, 256)-(2, 2048, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.07458627223968506, + "bw_span": 0.14493074268102646, + "infer_memory": 13107200, + "train_mem_info": [ + 4194304 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 2048, 512)-(256, 512)-(2, 2048, 256) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.070917047560215, + "bw_span": 0.16025099903345108, + "infer_memory": 13107200, + "train_mem_info": [ + 8388608 + ], + "train_mem2in_idx": [ + 0 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.softmax.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.softmax.json new file mode 100644 index 00000000..8bd4b216 --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.nn.functional.softmax.json @@ -0,0 +1,66 @@ +{ + "(2, 8, 2048, 2048)-(2, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 268435456 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 1.0575706139206886, + "bw_span": 3.7416458129882812, + "infer_memory": 536870912, + "train_mem_info": [ + 268435456 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(1, 8, 2048, 2048)-(1, 8, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.517265684902668, + "bw_span": 1.8769219517707825, + "infer_memory": 268435456, + "train_mem_info": [ + 134217728 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(2, 4, 2048, 2048)-(2, 4, 2048, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.5280718207359314, + "bw_span": 1.860242709517479, + "infer_memory": 268435456, + "train_mem_info": [ + 134217728 + ], + "train_mem2in_idx": [ + -1 + ] + }, + "(2, 8, 1024, 2048)-(2, 8, 1024, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 134217728 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.529155321419239, + "bw_span": 1.8756849691271782, + "infer_memory": 268435456, + "train_mem_info": [ + 134217728 + ], + "train_mem2in_idx": [ + -1 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.sum.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.sum.json new file mode 100644 index 00000000..f58f2764 --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.sum.json @@ -0,0 +1,50 @@ +{ + "(2, 2048, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.02030692994594574, + "bw_span": 0.03039538860321045, + "infer_memory": 8390656, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 2048, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.015971437096595764, + "bw_span": 0.033936649560928345, + "infer_memory": 4195840, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1024, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.0163249671459198, + "bw_span": 0.03274437040090561, + "infer_memory": 4195840, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 2048, 256)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01620892435312271, + "bw_span": 0.03477875143289566, + "infer_memory": 4195840, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.transpose.json b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.transpose.json new file mode 100644 index 00000000..9c71e6e1 --- /dev/null +++ b/tests/autodist/spmd_solver/test_solver_data_parallel/comp/torch.transpose.json @@ -0,0 +1,182 @@ +{ + "(2, 8, 2048, 64)-(2, 8, 64, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01199282705783844, + "bw_span": 0.03650356084108353, + "infer_memory": 8388608, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 2048, 8, 64)-(2, 8, 2048, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005709193646907806, + "bw_span": 0.023253075778484344, + "infer_memory": 8388608, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 2048, 64)-(2, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 8388608 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010336004197597504, + "bw_span": 0.03510527312755585, + "infer_memory": 8388608, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 2048, 64)-(1, 8, 64, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00971127301454544, + "bw_span": 0.03457833081483841, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 2048, 64)-(2, 4, 64, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00948682427406311, + "bw_span": 0.03553144633769989, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 1024, 64)-(2, 8, 64, 1024) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009621679782867432, + "bw_span": 0.03477856516838074, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 2048, 32)-(2, 8, 32, 2048) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009464845061302185, + "bw_span": 0.0379662960767746, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 2048, 8, 64)-(1, 8, 2048, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009648129343986511, + "bw_span": 0.03602709621191025, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1024, 8, 64)-(2, 8, 1024, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009131059050559998, + "bw_span": 0.03459863364696503, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 2048, 4, 64)-(2, 4, 2048, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00681169331073761, + "bw_span": 0.024368613958358765, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 2048, 8, 32)-(2, 8, 2048, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006612204015254974, + "bw_span": 0.024354644119739532, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 2048, 64)-(1, 2048, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009928643703460693, + "bw_span": 0.03518201410770416, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 2048, 64)-(2, 2048, 4, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005669891834259033, + "bw_span": 0.024593807756900787, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 1024, 64)-(2, 1024, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010106153786182404, + "bw_span": 0.03500021994113922, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 2048, 32)-(2, 2048, 8, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 4194304 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010841339826583862, + "bw_span": 0.034562498331069946, + "infer_memory": 4194304, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file From ebfe23030c781a5f3f85ba7f4310cbdb4e267131 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 13 Jun 2024 06:43:10 +0000 Subject: [PATCH 1654/1892] Merged PR 2175: change use_reentrant to False change use_reentrant to False because True is not stable, and in some torch version it may trigger bugs. --- nnscaler/codegen/module/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 7169b408..5a87abb0 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -849,7 +849,7 @@ def recompute(tensor_2222): fb.insert_body(f'return {output_names_tuple}') codes = [''] + fb.code + [''] codes.append( - f'{output_names_tuple} = ckpt.checkpoint(recompute, {input_names_tuple}, use_reentrant=True)' + f'{output_names_tuple} = ckpt.checkpoint(recompute, {input_names_tuple}, use_reentrant=False)' ) return codes From 8080e2a0a7d984229ea033480d16ca3ebc3adb54 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 14 Jun 2024 03:52:17 +0000 Subject: [PATCH 1655/1892] Merged PR 2172: refine IRObject handling refine IRObject handling 1. make tensor non-constant 2. don't trigger error on non-constant args for non-registered pyfunc. 3. set is_constant=False for all input objects. unit test pass parity check pass --- docs/source/register_custom_op.md | 27 ++++++++++++++ nnscaler/compiler.py | 2 +- nnscaler/execplan/execplan.py | 28 +++++++-------- nnscaler/graph/parser/fx/parser.py | 34 +++++++++--------- nnscaler/ir/cten.py | 10 +++--- nnscaler/parallel.py | 2 +- nnscaler/program.py | 2 +- tests/graph/parser/test_parser.py | 57 +++++++++++++++++++++++++++++- 8 files changed, 124 insertions(+), 38 deletions(-) diff --git a/docs/source/register_custom_op.md b/docs/source/register_custom_op.md index 85cb4928..6a7f120b 100644 --- a/docs/source/register_custom_op.md +++ b/docs/source/register_custom_op.md @@ -130,6 +130,33 @@ A `reduction` can be a set of {'', '+', '^'}: A dimension can also be annotated with inner-dimensions using brackets, i.e., '(' and ')'. The value of inner dimension needs to be inferrable, or indicated by function args (of same name). +Please be very careful when you use '?'. If it depends on the tensor input, +then the tensor input should be marked as non-partitionable. + +Example 1: +```python +@nnscaler.register_op('a^ b^ -> a^ b^, ?') +def op1(x: torch.Tensor): + x = ... + y = some_func(x) + return x, y +``` + +Example 2: +```python +@nnscaler.register_op('a b -> a b, ?') +def op1(x: torch.Tensor): + x = ... + y = 10 + return x, y +``` + +In Example 1, as `y` has dependency on `x`, its value will be wrong if we partition `x`. +So `x` should be marked as non-partitionable. + +In Example 2, `y` is a constant, and its value is independent of `x`. +So we can mark `x` partitioned. + ### Shape Annotation e.g., 'a (c+ d^) e' diff --git a/nnscaler/compiler.py b/nnscaler/compiler.py index 066d9204..4d99a815 100644 --- a/nnscaler/compiler.py +++ b/nnscaler/compiler.py @@ -103,7 +103,7 @@ def train_iter(model, dataloader): dtype=arg.dtype).tosub() arg._value = tensor else: - arg = IRObject('obj', value=arg) + arg = IRObject('obj', value=arg, is_constant=False) inputs.append(arg) myrank = DeviceGroup().rank diff --git a/nnscaler/execplan/execplan.py b/nnscaler/execplan/execplan.py index fd4ba6ea..387eaf47 100644 --- a/nnscaler/execplan/execplan.py +++ b/nnscaler/execplan/execplan.py @@ -37,19 +37,19 @@ def __init__(self, cell: IRCell, @property def device(self) -> int: return self._cell.device - + @property def cell(self) -> IRCell: return self._cell - + def isfw(self) -> bool: return self._cell.isfw() - + def dispatch(self, devid: int, _mirror = True): assert len(self.device) > 0 and devid in self.device, f"Cannot dispatch of ReuseCell {self} to device {devid}" if devid in self._cached_dispatched: return self._cached_dispatched[devid] - + inputs = [] for t, cell_t in zip(self._inputs, self._cell.inputs()): if isinstance(cell_t, IRSubTensor) and devid not in cell_t.device: @@ -67,7 +67,7 @@ def dispatch(self, devid: int, _mirror = True): IRCell.make_pair(reuse, mreuse) self._cached_dispatched[devid] = reuse return reuse - + def __repr__(self) -> str: return f'ReuseCell-{self.device}(name={self._cell.name}{self._cell.cid}, inputs={self.inputs()}, outputs={self.outputs()})' @@ -93,9 +93,9 @@ def from_schedplan(schedplan: SchedulePlan): A schedule plan has multiple micro-batches, where each micro-batch goes through the all operators in the model graph. So an operator will be executed multiple times with different data from different micro-batches. - - The IRGraph only contains operators / IRTensors / IRObjects of one micro-batch. - To represent data of a different micro-batch, we need to map the data in IRGraph to a + + The IRGraph only contains operators / IRTensors / IRObjects of one micro-batch. + To represent data of a different micro-batch, we need to map the data in IRGraph to a new one with different IDs. """ graph_inputs = schedplan.graph.inputs() @@ -121,7 +121,7 @@ def get(tensor: IRObject, micro_idx: int) -> IRObject: micro_objs.setdefault(micro_idx, {}).setdefault(tensor.parent.grad, fgrad) t.grad = fgrad.select(tensor.grad.indmap, tensor.grad.valmap) return t - + micro_fcells: Dict[(int, IRCell), ExeReuseCell] = {} def block2reuse(node: Block) -> ExeReuseCell: if node.content.isfw(): @@ -141,7 +141,7 @@ def block2reuse(node: Block) -> ExeReuseCell: else: mcell = block2reuse(Block(node.content.mirror, node.mid, node.span)) return mcell.mirror - + topo_seqs: List[IRCell] = [] for block in schedplan.nodes(): if isinstance(block, Block): @@ -288,7 +288,7 @@ def map2time(node): if isinstance(node, IRDataOperation): return 0 if isinstance(node, IRAdapter): return 0.25 return 1 if node.isfw() else 2 - + if map2mem is None: def map2mem(node): if isinstance(node, IRSegment): @@ -356,7 +356,7 @@ def depends(prev: IRCell, next: IRCell) -> bool: start_time = max(start_time, end_time) break start_times.append(start_time) - + start_time = max(start_times) for device in node.device: # time @@ -406,7 +406,7 @@ def depends(prev: IRCell, next: IRCell) -> bool: unwrap_node = node.cell if isinstance(node, ExeReuseCell) else node if end - start == 0: continue - # draw + # draw color = map2color(unwrap_node) rec = Rectangle((start, devid + 0.5), end-start, 1, color=color, ec='black', lw=1.5) @@ -425,7 +425,7 @@ def depends(prev: IRCell, next: IRCell) -> bool: break fontsize[0] = min(fontsize[0], fs) txts.append(txt) - + # set font size to same for txt in txts: txt.set_fontsize(fontsize[0]) diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index d9464ed5..43791e4d 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -62,9 +62,9 @@ def parse(module: torch.fx.GraphModule, # create IRObjects and IRTensors for node in module.graph.nodes: if node.op == 'placeholder': - FxModuleParser.init_objects(node, module, frame, dummy_inputs.get(node.name), is_constant=False) + FxModuleParser.init_objects(node, module, frame, is_constant=False) else: - FxModuleParser.init_objects(node, module, frame, None, is_constant=True) + FxModuleParser.init_objects(node, module, frame, is_constant=True) # get graph inputs placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] @@ -73,7 +73,7 @@ def parse(module: torch.fx.GraphModule, # it should be wrapped into an IRObject for idx, placeholder in enumerate(placeholders): if not isinstance(inputs[idx], IRObject): - obj = IRObject(name=placeholder.name, value=inputs[idx]) + obj = IRObject(name=placeholder.name, value=inputs[idx], is_constant=False) inputs[idx] = obj frame.set_var(placeholder.name, obj) @@ -85,6 +85,9 @@ def parse(module: torch.fx.GraphModule, # get graph outputs outputs = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + # currently fx graph always has only one output + # even if a tuple/list is returned, it is still just one output + assert len(outputs) == 1, f"Expect only one output, but got {len(outputs)}" if save_content: attr_savedir = Path(attr_savedir) @@ -114,7 +117,7 @@ def parse_node(node: torch.fx.Node, module, constant_folding: bool, frame: Frame @staticmethod def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, - frame: Frame, concrete_value: Optional[Any] = None, is_constant: bool = True): + frame: Frame, is_constant: bool = True): assert isinstance(node, torch.fx.Node) def meta2var(meta: Any) -> Any: @@ -142,14 +145,9 @@ def meta2var(meta: Any) -> Any: # TODO: data type check, with cases like {'a': 1.2, 'b': torch.Tensor} return IRObject(name=node.name, value=meta, is_constant=is_constant) - if hasattr(node, 'meta') and 'tensor_meta' in node.meta: - meta = node.meta['tensor_meta'] - val = meta2var(meta) - else: - # FIXME: double check: there should be a concrete value as example, - # otherwise, it may fail in parsing node like getattr - val = IRObject(name=node.name, value=concrete_value, is_constant=is_constant) - + assert hasattr(node, 'meta') and 'tensor_meta' in node.meta, f"Node {node} should have tensor_meta" + meta = node.meta['tensor_meta'] + val = meta2var(meta) frame.add_var(node.name, val) @staticmethod @@ -221,11 +219,13 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # case2: python runtime function else: _logger.warning(f'Set python runtime function: {fsig}') + is_constant = True if any_ir_object_satisfy(input_vals, lambda a: not a.is_constant): - err_msg = f'non register python runtime function {fsig} has a non constant input: {input_vals}, ' + \ - 'please register it as a customized function using nnscaler.graph.parser.register' - raise RuntimeError(err_msg) - ir_node = IRPyFunc(fsig, input_vals, [IRObject()], **kwargs) + warning_msg = f'non register python runtime function {fsig} has a non constant input: {input_vals}, ' + \ + 'You can register it as a customized function using nnscaler.register_op to remove this warning' + _logger.warning(warning_msg) + is_constant = False + ir_node = IRPyFunc(fsig, input_vals, [IRObject(frame.get_var(node.name), is_constant=is_constant)], **kwargs) if isinstance(ir_node, IRCell): module_stack = node.meta.get('nn_module_stack') @@ -266,6 +266,8 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule ) ir_node.set_output(0, output_val) else: + # SignFx2Op may return object that is not IRCell but a concrete value, for example Add. + # As node is deleted, we must set concrete value or IRTensor/IRObject into framework. frame.set_var(node.name, ir_node) _logger.debug(f'parsing result: {ir_node}') diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index a822a99e..94ecee5a 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -425,9 +425,11 @@ def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: val (Any): the value of this object is_constant (bool): if the value is a constant during the whole training / inference This flag is only used in constant_folding mode, to prevent the object from being folded. - An IROject is considered constant only when: - 1. val is not a tensor - 2. val is model input, or is the result of a non-torch operation on another constant IRObject + An IROject is considered not constant when either of two satisifies: + 1. val is a tensor + 2. val is model input, or is the result of a non-torch operation on another not constant IRObject + Please note is_constant flag is only used in parser, + so after parser, you can totally ignore this flag. """ self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() self.name: str = name if name else 'obj' @@ -545,7 +547,7 @@ class IRTensor(IRObject): def __init__(self, shape=None, name='tensor', dtype=None, tid=None): - super().__init__(name, tid) + super().__init__(name, tid, is_constant=False) self._shape: Tuple[int] = () if shape is None else tuple(shape) self._cell: Optional[IRCell] = None self._dtype: Optional[torch.dtype] = dtype diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 99a4304d..c3fc0d01 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -648,7 +648,7 @@ def _gen_graph( # the IRObject representing the `dataloader` instance, which is only used by the # IRDataOperation. Since we already know the output of the dataloader, # we don't need to set the value for it. - ir_root_obj = IRObject(name='dataloader', value=None) + ir_root_obj = IRObject(name='dataloader', value=None, is_constant=False) Program().set_input([ir_root_obj]) data_op = IRDataOperation(ir_root_obj, ir_dummy_inputs) # add the data operation to the graph, which will use `next` to get data. diff --git a/nnscaler/program.py b/nnscaler/program.py index bfc1e6f2..075e1f79 100644 --- a/nnscaler/program.py +++ b/nnscaler/program.py @@ -101,7 +101,7 @@ def __init__(self, dataloader: MicroBatchDataLoader): # the IRObject representing the `dataloader` instance, which is only used by the # IRDataOperation. Since we already know the output of the dataloader, # we don't need to set the value for it. - self.irobj = IRObject(name='dataloader', value=None) + self.irobj = IRObject(name='dataloader', value=None, is_constant=False) def __iter__(self): return self diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index 2b67a7e2..b4bdf553 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -1,5 +1,8 @@ import tempfile +import pytest import torch + +import nnscaler from nnscaler.ir.cten import IRObject, IRTensor from nnscaler.graph.parser.converter import to_fx_graph, to_ir_graph @@ -115,4 +118,56 @@ def forward(self, x): print(ir_graph.extra_repr()) assert isinstance(ir_graph.output(0), IRTensor) - assert ir_graph.output(0).shape == (10, 1) \ No newline at end of file + assert ir_graph.output(0).shape == (10, 1) + + +@nnscaler.register_op('m n -> m n, m n, ?') +def func_multi_outputs(x): + return x, x, 3 + + +@nnscaler.register_op('m n -> ?') +def func_output_list(x, factor=1): + x = x * factor + return [x, x] + + +@nnscaler.register_op('m n -> m n, m n') +def func_output_list2(x, factor=1): + x = x * factor + return [x, x] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('output_list', [True, False]) +def test_num_outputs(tmp_path, output_list): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + out = func_multi_outputs(x) + y, _, scalar = out + (sz, _) = y.shape + sz = sz + scalar + if output_list: + return func_output_list(y, factor=sz) + else: + return func_output_list2(y, factor=sz) + + dummy_input = {'x': torch.randn(4, 4)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) + print(ir_graph.extra_repr()) + + assert len(ir_graph.nodes()) == 5 + assert len(ir_graph.nodes()[0].outputs()) == 3 + assert isinstance(ir_graph.output(0), list) + assert len(ir_graph.outputs()) == 1 + if output_list: + assert len(ir_graph.nodes()[-1].outputs()) == 1 + else: + assert len(ir_graph.nodes()[-1].outputs()) == 2 From 9c57adb668215840056dbc20aebce61efcafd032 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Sat, 15 Jun 2024 01:18:11 +0000 Subject: [PATCH 1656/1892] Merged PR 2178: parallelize: rename dummy_input to dummy_forward_args --- docs/source/parallel_module.md | 12 +++++--- examples/vision/swin/train.py | 2 +- .../zigzag_ring_attention/test_zigzag_attn.py | 12 ++++---- .../integration/lightning/pytorch/strategy.py | 10 +++---- nnscaler/parallel.py | 28 +++++++++---------- .../lightning/pytorch/simple_models.py | 4 +-- 6 files changed, 36 insertions(+), 32 deletions(-) diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 2421892f..466cb65c 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -28,7 +28,7 @@ class LLM(torch.nn.Module): def forward(self, x): ... -llm_sample_input = ... # dummpy input will be used to do tracing +llm_sample_input = ... # dummy input will be used to do tracing pas_policy = ... # the PAS policy, you can use autodist pas compute_config = ComputeConfig( plan_ngpus=..., @@ -133,7 +133,7 @@ class End2EndMLP(nn.Module): loss = self.loss_fn(x, data['target']) return loss - llm_sample_input = {'data': ..., 'target': ...} # dummpy input will be used to do tracing + llm_sample_input = {'data': ..., 'target': ...} # dummy input will be used to do tracing pas_policy = ... # the PAS policy, you can use autodist pas compute_config = ComputeConfig( plan_ngpus=..., @@ -383,7 +383,7 @@ We have `parallelize` function to Convert a torch.nn.Module to a ParallelModule. ```python def parallelize( module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]], - dummy_input: dict, + dummy_forward_args: Dict[str, Any], pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], compute_config: ComputeConfig, *, @@ -401,7 +401,11 @@ It has the following parameters: - `module_or_module_class` (`Union[torch.nn.Module, Type[torch.nn.Module]]`): the module or module class to be compiled. Please note if the input is a module object, we will return a `ParallelModule` object. If the input is a module class, we will return a `ParallelModule` class. -- `dummy_input` (`dict`): the dummy input for the module. The keys are the argument names of `Module.forward` function, and the values are the dummy input for the arguments. The dummy input will be used to trace the module. Please note the module can't be parallelize if `Module.forward` has positional-only arguments. +- `dummy_forward_args` (`Dict[str, Any]`): the dummy input for the module forward. +The keys are the argument names of `Module.forward` function, +and the values are the dummy input for the arguments. +The dummy forward args will be used to trace the module. +Please note the module can't be parallelize if `Module.forward` has positional-only arguments. - `pas_policy` (`Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]`): the pas (partition-assign-schedule) policy, which describes how to place all computations across devices. You need either pass a builtin PAS policy name or a a custom policy function which should take an `IRGraph` and a `ComputeConfig` as input, and return a new `IRGraph` with the PAS policy applied. diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 867c9c8c..71974952 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -86,7 +86,7 @@ def train(args, compute_config: nnscaler.ComputeConfig): model: nnscaler.ParallelModule = nnscaler.parallelize( model, - dummy_input={'x': gen_data()}, + dummy_forward_args={'x': gen_data()}, pas_policy=policy, compute_config=compute_config, reuse='moo', diff --git a/examples/zigzag_ring_attention/test_zigzag_attn.py b/examples/zigzag_ring_attention/test_zigzag_attn.py index 62ae13dd..399db43b 100644 --- a/examples/zigzag_ring_attention/test_zigzag_attn.py +++ b/examples/zigzag_ring_attention/test_zigzag_attn.py @@ -5,7 +5,7 @@ from nnscaler.parallel import parallelize, ComputeConfig, ReuseType import torch.distributed as dist from flash_attn import flash_attn_func - + import nnscaler.graph import nnscaler.graph.function from examples.zigzag_ring_attention.zigzag_attn import wrap_zigzag_attn_func @@ -14,9 +14,9 @@ def set_seed(rank, seed=42): seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) def log(msg, a, rank0_only=False): @@ -47,7 +47,7 @@ def log(msg, a, rank0_only=False): class TestModule(torch.nn.Module): def __init__(self): super(TestModule, self).__init__() - + def forward(self, _in0, _in1, _in2): out = wrap_zigzag_attn_func(_in0, _in1, _in2) return out @@ -103,7 +103,7 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: _in1 = k.detach().clone().requires_grad_() _in2 = v.detach().clone().requires_grad_() - parallel_model = parallelize(model, dummy_input={"_in0": _in0, "_in1": _in1, "_in2": _in2}, pas_policy=policy, + parallel_model = parallelize(model, dummy_forward_args={"_in0": _in0, "_in1": _in1, "_in2": _in2}, pas_policy=policy, compute_config=ComputeConfig(world_size, world_size), reuse=ReuseType.OVERRIDE) parallel_model = parallel_model.cuda() diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py index c7fa3a42..274b151e 100644 --- a/nnscaler/integration/lightning/pytorch/strategy.py +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -200,10 +200,10 @@ def setup(self, trainer: "pl.Trainer") -> None: def _setup_model(self, model: Module) -> Module: """Set up a module for inference (no optimizers). """ - if getattr(model, 'dummy_input', None) is None: - raise ValueError("The `dummy_input` must be defined as a property in the module.") - if not isinstance(model.dummy_input, dict): - raise ValueError("The `dummy_input` must be a dictionary with forward arguments names as keys.") + if getattr(model, 'dummy_forward_args', None) is None: + raise ValueError("The `dummy_forward_args` must be defined as a property in the module.") + if not isinstance(model.dummy_forward_args, dict): + raise ValueError("The `dummy_forward_args` must be a dictionary with forward arguments names as keys.") old_training_flag = model.training if not old_training_flag: @@ -211,7 +211,7 @@ def _setup_model(self, model: Module) -> Module: model.train() # always use the model in training mode pmodule = nnscaler.parallelize( model, - self.precision_plugin.convert_input(model.dummy_input), + self.precision_plugin.convert_input(model.dummy_forward_args), self.pas_policy, self.compute_config, gen_savedir=self.gen_savedir, diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index c3fc0d01..bd23710f 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -586,7 +586,7 @@ def _prepare_and_check_reusable( def _gen_graph( module: torch.nn.Module, - dummy_input: dict, + dummy_forward_args: dict, outdir: Path, constant_folding: bool, end2end_mode: bool = False, @@ -605,12 +605,12 @@ def _gen_graph( raise ValueError(f"Default value type {type(v)} of forward args is not supported.") # generate fx graph - dummy_input = _to_cpu(dummy_input) - fx_graph = parser.to_fx_graph(module, dummy_input) + dummy_forward_args = _to_cpu(dummy_forward_args) + fx_graph = parser.to_fx_graph(module, dummy_forward_args) # generate ir logic graph ir_graph = parser.to_ir_graph( - fx_graph, dummy_input, outdir, constant_folding + fx_graph, dummy_forward_args, outdir, constant_folding ) # generate dummy inputs for logic graph @@ -625,15 +625,15 @@ def _gen_graph( ir_dummy_inputs = [] for node in fx_input_nodes: if node.target.startswith('*'): # *args or **kwargs - if node.target.strip('*') in dummy_input: + if node.target.strip('*') in dummy_forward_args: raise ValueError(f"Input {node.target}: *args or **kwargs is not suppported") ir_dummy_inputs.append(None) # always set None to *args/**kwargs - elif node.target in dummy_input: - ir_dummy_inputs.append(dummy_input[node.target]) + elif node.target in dummy_forward_args: + ir_dummy_inputs.append(dummy_forward_args[node.target]) elif forward_args[node.target] is not inspect.Parameter.empty: ir_dummy_inputs.append(forward_args[node.target]) else: - raise ValueError(f"Input {node.target} not in dummy input, nor has default value.") + raise ValueError(f"Input {node.target} not in dummy forward args, nor has default value.") for i in range(len(ir_dummy_inputs)): ir_dummy_inputs[i] = to_ir_input(ir_dummy_inputs[i], fx_input_nodes[i].target) # if the input is not a tensor, we should wrap it with IRObject @@ -685,7 +685,7 @@ def _gen_graph( def _gencode( module_or_module_class: torch.nn.Module, - dummy_input: dict, + dummy_forward_args: Dict[str, Any], pas_policy: Callable[[IRGraph, ComputeConfig], IRGraph], compute_config: ComputeConfig, outdir: Path, @@ -705,7 +705,7 @@ def _gencode( Args: module (torch.nn.Module): the module to be compiled - dummy_input (dict): the dummy input for the module + dummy_forward_args (Dict[str, Any]): the dummy input for the module forward pas_policy (Callable[[IRGraph, ComputeConfig], IRGraph]): the pas policy compute_config (ComputeConfig): the environment resource outdir (Path): the directory to save generated code @@ -753,7 +753,7 @@ def _gencode( torch.save(meta_info, origin_module_metadata_ckp) graph, forward_args = _gen_graph( - module, dummy_input, outdir, + module, dummy_forward_args, outdir, constant_folding=compute_config.constant_folding, end2end_mode=compute_config.use_end2end, inference_only=compute_config.inference_only, use_pipeline=compute_config.use_pipeline, @@ -869,7 +869,7 @@ def _load_parallel_module_class( def parallelize( module_or_module_class: Union[torch.nn.Module, Type[torch.nn.Module]], - dummy_input: dict, + dummy_forward_args: Dict[str, Any], pas_policy: Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]], compute_config: ComputeConfig, *, @@ -938,7 +938,7 @@ def __init__(self, init_params=True): Args: module_or_module_class (Union[torch.nn.Module, Type[torch.nn.Module]]): the module or module class to be compiled - dummy_input (dict): the dummy input for the module + dummy_forward_args (Dict[str, Any]): the dummy input for the module forward pas_policy (Union[str, Callable[[IRGraph, ComputeConfig], IRGraph]]): the pas policy, it can be a name of builtin policies, or a custom policy function. compute_config (ComputeConfig): the environment resource @@ -999,7 +999,7 @@ def __init__(self, init_params=True): with _compile_flags(compute_config): regen_status = _gencode( module_or_module_class, - dummy_input, + dummy_forward_args, pas_policy, compute_config, outdir, diff --git a/tests/integration/lightning/pytorch/simple_models.py b/tests/integration/lightning/pytorch/simple_models.py index 2d124a96..6081a716 100644 --- a/tests/integration/lightning/pytorch/simple_models.py +++ b/tests/integration/lightning/pytorch/simple_models.py @@ -32,7 +32,7 @@ def __init__(self, num_features=32, num_classes=3, batch_size=10, lr=0.01): self.test_acc = acc.clone() @property - def dummy_input(self): + def dummy_forward_args(self): return {'x': torch.randn(self.batch_size, self.num_features)} def forward(self, x): @@ -165,7 +165,7 @@ def __init__(self) -> None: self.layer = torch.nn.Linear(32, 2) @property - def dummy_input(self): + def dummy_forward_args(self): return {'x': torch.randn(32)} def forward(self, x: Tensor) -> Tensor: From 17877a96641be29541ec8cc30d54098f32cdbaef Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Sat, 15 Jun 2024 04:42:14 +0000 Subject: [PATCH 1657/1892] Merged PR 2177: Fix policy for autodist bug set `memory_constraint` correctly --- nnscaler/policies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nnscaler/policies.py b/nnscaler/policies.py index 37418dda..be94e414 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -210,7 +210,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: mesh_col = pas_cfg.get('mesh_col', cfg.plan_ngpus) if mesh_col != cfg.plan_ngpus: raise ValueError("mesh_col should be equal to plan_ngpus") - mem_constraint = pas_cfg.get('mem_constraint', -1) + memory_constraint = pas_cfg.get('mem_constraint', -1) task_name = pas_cfg.get('task_name', '_') use_memory_efficient_fp16 = pas_cfg.get('use_memory_efficient_fp16', False) use_memory_efficient_bf16 = pas_cfg.get('use_memory_efficient_bf16', False) @@ -228,7 +228,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: mesh_row = 1 ngpus = mesh_row * mesh_col task_name = f'{task_name}_{ngpus}gpus_{update_freq}update_freq' - if mem_constraint == -1: + if memory_constraint == -1: # consider memory fragmentation and other buffers, use 80% of the memory memory_constraint = int(0.8 * torch.cuda.mem_get_info()[1] / 1024 / 1024 / 1024) From 046e02016b0092d99a19329eba600a55bb3d196f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 17 Jun 2024 07:17:33 +0000 Subject: [PATCH 1658/1892] Merged PR 2176: add a mini-trainer --- docs/source/parallel_module.md | 2 +- examples/vision/swin/train.py | 6 +- nnscaler/cli/__init__.py | 0 nnscaler/cli/arg_parser.py | 230 +++++++++++++++++++++++++++++ nnscaler/cli/train.py | 8 ++ nnscaler/cli/trainer.py | 246 +++++++++++++++++++++++++++++++ nnscaler/cli/trainer_args.py | 255 +++++++++++++++++++++++++++++++++ tests/cli/__init__.py | 0 tests/cli/common.py | 18 +++ tests/cli/test_arg_parser.py | 140 ++++++++++++++++++ tests/cli/test_trainer.py | 28 ++++ tests/cli/trainer_args.yaml | 26 ++++ 12 files changed, 955 insertions(+), 4 deletions(-) create mode 100644 nnscaler/cli/__init__.py create mode 100644 nnscaler/cli/arg_parser.py create mode 100644 nnscaler/cli/train.py create mode 100644 nnscaler/cli/trainer.py create mode 100644 nnscaler/cli/trainer_args.py create mode 100644 tests/cli/__init__.py create mode 100644 tests/cli/common.py create mode 100644 tests/cli/test_arg_parser.py create mode 100644 tests/cli/test_trainer.py create mode 100644 tests/cli/trainer_args.yaml diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 466cb65c..b6b734b5 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -9,7 +9,7 @@ Currently we support three kinds of parallelism: data parallelism, tensor parall Data parallelism and tensor parallelism are support for all kinds of module, but pipeline parallelism is only supported for end2end modules for scheduling reason. An end2end module is a module which satisfies: -- the first argument of `module.forward` is the data sample +- the first argument of `module.forward` is the data sample, and every other argument should have default value, and use its default value in `module.forward` function. - the first return value of `module.forward` is the loss (scalar tensor) The above restrictions are necessary for the pipeline parallelism to work. Of course, you can still use the parallel module without pipeline parallelism for end2end modules. diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 71974952..190ea5a2 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -141,7 +141,7 @@ def train(args, compute_config: nnscaler.ComputeConfig): if torch.distributed.get_world_size() != args.dp_size * args.pp_size * args.tp_size: raise ValueError('world size should be equal to dp_size * pp_size * tp_size') - if args.gbs % args.mbs != 0: + if args.gbs % (args.mbs * args.dp_size) != 0: raise ValueError('global batch size should be divisible by micro batch size') compute_config=nnscaler.ComputeConfig( @@ -151,7 +151,7 @@ def train(args, compute_config: nnscaler.ComputeConfig): use_end2end=True, constant_folding=True, use_pipeline=args.pp_size > 1, - pipeline_nmicros=args.gbs // args.mbs, + pipeline_nmicros=args.gbs // args.mbs // args.dp_size, pipeline_nstages=args.pp_size, pas_config={ # customized settings that can affect code generation. @@ -161,7 +161,7 @@ def train(args, compute_config: nnscaler.ComputeConfig): '_tp_size': args.tp_size, '_dp_size': args.dp_size, # for autodist only - 'update_freq': args.gbs // args.mbs, + 'update_freq': args.gbs // args.mbs// args.dp_size, 'use_fp16': args.fp16, }, user_config={ diff --git a/nnscaler/cli/__init__.py b/nnscaler/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py new file mode 100644 index 00000000..927b9a4a --- /dev/null +++ b/nnscaler/cli/arg_parser.py @@ -0,0 +1,230 @@ +from typing import List, Tuple, Dict, Any, Union +from dataclasses import dataclass, is_dataclass, asdict +import enum +import ast + + +try: + from types import UnionType +except ImportError: + UnionType = None # for python < 3.10 + + +def parse_args(argv: List[str]) -> dict: + raw_args = {} + last_key = None + for v in argv: + if v.startswith('--'): + if '=' in v: + k, v = v[2:].split('=', 1) + raw_args[k] = v + last_key = None + else: + k = v[2:] + raw_args[k] = None + last_key = k + else: + if not last_key: + raise ValueError(f"invalid argument {v}") + raw_args[last_key] = v + last_key = None + + args = {} + for k, v in raw_args.items(): + k = k.replace('-', '_') + keys = k.split('.') + current = args + for key in keys[:-1]: + if key not in current: + current[key] = {} + current = current[key] + current[keys[-1]] = v + + return args + + +def merge_args(args: dict, new_args: dict): + for k, v in new_args.items(): + if k in args and isinstance(args[k], dict) and isinstance(v, dict): + merge_args(args[k], v) + else: + args[k] = v + + +def _fix_any(type_): + return None if type_ == Any else type_ + + +def _fix_optional(type_info): + if getattr(type_info, '__origin__', None) == Union \ + or (UnionType and isinstance(type_info, UnionType)): + args = getattr(type_info, '__args__', None) + if len(args) != 2 or (args[1] != type(None) and args[0] != type(None)): + raise ValueError(f"Invalid optional type {type_info}") + if args[1] == type(None): + return _fix_optional(args[0]) + else: + return args[1] + return type_info + + +def _fix_type(type_info, raise_on_nested=True): + type_info = _fix_optional(type_info) + type_info = _fix_any(type_info) + if raise_on_nested and getattr(type_info, '__args__', None): + raise ValueError(f"Nested type {type_info} is not allowed here.") + return type_info + + +@dataclass +class _TypeInfo: + type: Any = None + key_type: Any = None + value_type: Any = None + item_type: Any = None + + +def _get_type_info_from_annotation(type_info): + type_info = _fix_type(type_info, False) + if type_info is None or type_info == Any: + return _TypeInfo(type=None) + if type_info in (list, List): + return _TypeInfo(type=list) + if type_info in (dict, Dict): + return _TypeInfo(type=dict) + if type_info in (tuple, Tuple): + return _TypeInfo(type=tuple) + + origin = getattr(type_info, '__origin__', None) + args = getattr(type_info, '__args__', None) + + if origin in (List, list): + if len(args) != 1: + raise ValueError(f"Invalid list type {type_info}") + return _TypeInfo(type=list, item_type=_fix_type(args[0])) + elif origin in (Dict, dict): + if len(args) != 2: + raise ValueError(f"Invalid dict type {type_info}") + return _TypeInfo(type=dict, key_type=_fix_type(args[0]), value_type=_fix_type(args[1])) + elif origin in (Tuple, tuple): + if len(args) != 2 or args[1] != Ellipsis: + raise ValueError(f"Invalid tuple type {type_info}") + return _TypeInfo(type=tuple, item_type=_fix_type(args[0])) + else: + if type_info.__module__ == 'typing': + raise ValueError(f"Unsupported type {type_info}") + return _TypeInfo(type=type_info) + + +def _get_type_info(dataclass_type) -> Dict[str, _TypeInfo]: + if not is_dataclass(dataclass_type): + raise ValueError(f"{dataclass_type} is not a dataclass") + type_dict = {} + for k, v in dataclass_type.__dataclass_fields__.items(): + type_dict[k] = _get_type_info_from_annotation(v.type) + return type_dict + + +def _is_primitive_type(data_type): + """ + We only support int, str, bool, float as primitive types. + """ + return data_type in (int, str, bool, float) + + +def _guess_deserialize_object(value): + if isinstance(value, dict): + return {_guess_deserialize_object(k): _guess_deserialize_object(v) for k, v in value.items()} + if isinstance(value, list): + return [_guess_deserialize_object(v) for v in value] + if isinstance(value, tuple): + return tuple(_guess_deserialize_object(v) for v in value) + if isinstance(value, str): + try: + # try to parse as literal + # if failed, return as it is + # Please note that if there is no type annotation, + # you should provide the value in python code format + # for example `[a, b, c]` will return string as a whole + # but `['a', 'b', 'c']` will return a list of strings + return ast.literal_eval(value) + except Exception: + return value + return value + + +def _deserialize_object(value, value_type): + """ + deserialize object based on single value type: + 1. If no value type or collective type, try to guess its value type. + 2. if it is primitive types, return value_type(value) + 3. If it is a dataclass, ask dataclass to deserialize. + 4. Otherwise, we will return as it is + """ + if not value_type or value_type in (dict, list, tuple): + return _guess_deserialize_object(value) + try: + if value is None: + return value + if isinstance(value, value_type): + return value + if issubclass(value_type, enum.Enum): + try: + return value_type[value] # first treat as enum name + except KeyError: + return value_type(value) # then treat as enum value + if value_type == bool and isinstance(value, str): + if value.lower() in ('true', '1'): + return True + elif value.lower() in ('false', '0'): + return False + else: + raise ValueError(f"Failed to deserialize {value} to {value_type}") + if _is_primitive_type(value_type): + return value_type(value) + except Exception as ex: + raise ValueError(f"Failed to deserialize {value} to {value_type}") from ex + + if is_dataclass(value_type): + return deserialize_dataclass(value, value_type) + + return value + + +def deserialize_dataclass(value, value_type): + if not isinstance(value, dict): + raise ValueError(f"Expecting dict, but got {value}") + if not is_dataclass(value_type): + raise ValueError(f"{value_type} is not a dataclass") + + type_info = _get_type_info(value_type) + member_values = {} + for k, ti in type_info.items(): + if not k in value: + continue + v = value[k] + if ti.type is bool and v is None: + v = True # set bool to True if it shows up in cmd line + if v is None: + continue + + if ti.type in (list, tuple, dict, type(None)) and isinstance(v, str): + v = ast.literal_eval(v) + + if isinstance(v, (list, tuple, dict)) and not ti.type: + ti.type = type(v) + + if ti.item_type or ti.key_type or ti.value_type: + if ti.type == list: + v = [_deserialize_object(x, ti.item_type) for x in v] + elif ti.type == tuple: + v = tuple(_deserialize_object(x, ti.item_type) for x in v) + elif ti.type == dict: + v = {_deserialize_object(k, ti.key_type): _deserialize_object(v, ti.value_type) for k, v in v.items()} + else: + v = _deserialize_object(v, ti.type) + + if v is not None: # for none values, use default value. + member_values[k] = v + + return value_type(**member_values) diff --git a/nnscaler/cli/train.py b/nnscaler/cli/train.py new file mode 100644 index 00000000..bf7fd749 --- /dev/null +++ b/nnscaler/cli/train.py @@ -0,0 +1,8 @@ +from .trainer import Trainer + + +if __name__ == '__main__': + trainer = Trainer() + if trainer.train_args.run_mode == 'run': + trainer.train() + diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py new file mode 100644 index 00000000..9927965f --- /dev/null +++ b/nnscaler/cli/trainer.py @@ -0,0 +1,246 @@ +from dataclasses import dataclass, asdict +from typing import Any, Dict, List, Optional, Union +from pathlib import Path +import sys +import copy +import inspect + +import torch +import torch.distributed +import nnscaler + +from .trainer_args import TrainerArgs + + +@dataclass +class TrainStatus: + epoch: int = 0 + in_epoch_pos: int = 0 # the position inside an epoch, used for resuming training + + +class Trainer: + def __init__(self, + argv: Optional[List[str]] = None, + train_args: Optional[Union[Dict[str, Any], TrainerArgs]] = None + ): + """ + Args: + argv (Optional[List[str]]): command line arguments. If not specified, sys.argv[1:] will be used + train_args: a dict used to construct TrainerArgs or TrainerArgs object itself. + """ + if train_args is not None: + if argv is not None: + raise ValueError("argv and train_args can not be specified together") + if isinstance(train_args, TrainerArgs): + self.train_args = train_args + else: + if not isinstance(train_args, dict): + raise ValueError(f"train_args should be a dict or TrainerArgs, got {type(train_args)}") + self.train_args = TrainerArgs.from_dict(train_args) + else: + cli_args = argv or sys.argv[1:] # remve the leading script name from sys.argv + self.train_args = TrainerArgs.from_cli(cli_args) + + self.model = None + self.optimizer = None + self.dataset = {'train': None, 'val': None, 'test': None} + self.dataloader = {'train': None, 'val': None, 'test': None} + self.lr_scheduler = None + self.train_status = TrainStatus() + self.dummy_input = None + self._setup() + + def _fix_input(self, input): + if isinstance(input, dict): + return {k: self._fix_input(v) for k, v in input.items()} + elif isinstance(input, list): + return [self._fix_input(v) for v in input] + elif isinstance(input, tuple): + return tuple(self._fix_input(v) for v in input) + elif isinstance(input, torch.Tensor): + if self.train_args.fp16: + return input.half().cuda() + elif self.train_args.bf16: + return input.bfloat16().cuda() + else: + return input.cuda() + return input + + def _create_dummy_forward_args(self): + assert self.dummy_input is not None, "dummy_input is not set" + assert self.train_args.model_type is not None, "model_type is not set" + + arg_names = list( + inspect.signature( + inspect.unwrap(getattr(self.train_args.model_type, 'forward')) + ).parameters.keys() + ) + return {arg_names[1]: self.dummy_input} # arg_names[0] is self + + def _setup(self): + compile_only = self.train_args.run_mode == 'compile' + if not compile_only: + nnscaler.init() + + def _create_model(): + model = self.train_args.create_model() + if self.train_args.fp16: + model = model.half() + elif self.train_args.bf16: + model = model.bfloat16() + if self.train_args.ckpt_tracing: + model.load_state_dict(torch.load(self.train_args.ckpt_tracing)) + return model + + # load a dummy input from training dataset + if not compile_only: + for stage in ['train', 'val', 'test']: + self.dataset[stage] = self.train_args.create_dataset(stage) + self.dataloader[stage] = self.train_args.create_dataloader(stage, self.dataset[stage]) + + self.dummy_input = self.dataloader['train'].collate_fn( + [self.dataset['train'][idx] for idx in range(self.train_args.micro_batch_size)] + ) + else: + train_dataset = self.train_args.create_dataset('train') + self.dummy_input = self.train_args.collate_fn( + [train_dataset[idx] for idx in range(self.train_args.micro_batch_size)] + ) + del train_dataset + + self.dummy_input = self._fix_input(self.dummy_input) + + # setup compute config + compute_config = copy.deepcopy(self.train_args.compute_config) + compute_config.pas_config['__pas_name'] = self.train_args.pas_policy + compute_config.user_config['__from_trainer_args'] = { + 'mbs': self.train_args.micro_batch_size, + 'gbs': self.train_args.global_batch_size, + 'fp16': self.train_args.fp16, + 'bf16': self.train_args.bf16, + } + + # parallalize model + pmodel_class = nnscaler.parallelize( + self.train_args.model_type, + self._create_dummy_forward_args(), + self.train_args.pas_policy, + compute_config, + module_fn=_create_model, + gen_savedir=self.train_args.gen_savedir, + reuse='moo' if compile_only else 'match', + instance_name=self.train_args.instance_name, + broadcast_strategy='all', + load_module=not compile_only, + ) + if compile_only: + return + + torch.distributed.barrier() + + self.model = pmodel_class() + self.optimizer = self.train_args.create_parallel_optimizer(self.model) + self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) + self._load_checkpoint() + + def _load_checkpoint(self): + if not self.train_args.ckpt_load_file: + return + state_dict = torch.load(self.train_args.ckpt_load_file, map_location='cpu') + ckpt_save_type = state_dict.get('train_args', {}).get('ckpt_save_type', None) + + if not ckpt_save_type: # it is a merged state dict + nnscaler.load_merged_state_dicts( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) + if 'lr_scheduler' in state_dict: + self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + elif ckpt_save_type == 'sharded': + self.model.load_state_dict(state_dict['model']) + self.model.cuda() + self.optimizer.load_state_dict(state_dict['optimizer']) + if 'lr_scheduler' in state_dict: + self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + self.train_status = TrainStatus(**state_dict['train_status']) + elif ckpt_save_type == 'deduped': + nnscaler.load_deduped_state_dict( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) + if 'lr_scheduler' in state_dict: + self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + self.train_status = TrainStatus(**state_dict['train_status']) + else: + raise ValueError(f"Unknown checkpoint type: {ckpt_save_type}") + + def _save_checkpoint(self, from_end_of_epoch=True): + if not self.train_args.ckpt_save_dir: + return + save_dir = Path(self.train_args.ckpt_save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + if self.train_args.ckpt_save_type == 'sharded': + model_state_dict= self.model.state_dict() + optimizer_state_dict = self.optimizer.state_dict() + elif self.train_args.ckpt_save_type == 'deduped': + model_state_dict, optimizer_state_dict = nnscaler.deduped_state_dict( + self.model, self.optimizer + ) + else: + raise ValueError(f"Unknown checkpoint type: {self.train_args.ckpt_save_type}") + + train_status = copy.deepcopy(self.train_status) + if from_end_of_epoch: + train_status.in_epoch_pos = 0 + train_status.epoch += 1 + + state_dict = { + 'model': model_state_dict, + 'optimizer': optimizer_state_dict, + 'lr_scheduler': self.lr_scheduler.state_dict() if self.lr_scheduler else None, + 'train_status': asdict(train_status), + 'train_args': self.train_args.to_dict(), + } + torch.save(state_dict, save_dir / + f'ckpt_{train_status.epoch}_{train_status.in_epoch_pos}_rank{torch.distributed.get_rank()}.pt' + ) + + def _global_batch_iterator(self, num_skip_first = 0): + samples = [] + for idx, sample in enumerate(self.dataloader['train']): + if idx < num_skip_first * self.train_args.update_freq: + continue + sample = self._fix_input(sample) + samples.append(sample) + if len(samples) == self.train_args.update_freq: + yield samples + samples = [] + if samples: + yield samples + + def train(self): + num_skip_fist = self.train_status.in_epoch_pos + for epoch in range(self.train_status.epoch, self.train_args.max_epochs): + self.train_status.epoch = epoch + for idx, samples in enumerate(self._global_batch_iterator(num_skip_fist)): + self.train_status.in_epoch_pos = idx + is_dummy_batch = [False] * len(samples) + if len(samples) < self.train_args.update_freq: + gap = self.train_args.update_freq - len(samples) + is_dummy_batch += [True] * gap + samples += [self.dummy_input] * gap + + self.model.train() + self.optimizer.zero_grad() + losses = self.model.train_step(samples, is_dummy_batch) + if self.train_args.clip_gnorm: + self.optimizer.clip_gnorm(self.train_args.clip_gnorm) + self.optimizer.step() + + if self.lr_scheduler: + self.lr_scheduler.step(epoch) + + self._save_checkpoint(True) + + num_skip_fist = 0 diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py new file mode 100644 index 00000000..e6b3462d --- /dev/null +++ b/nnscaler/cli/trainer_args.py @@ -0,0 +1,255 @@ +from dataclasses import asdict, dataclass, field +import importlib +from typing import Any, Dict, List, Optional + +import torch +import torch.utils +import torch.utils.data +import torch.utils.data.dataloader +import yaml +import torch + +from nnscaler.parallel import ComputeConfig, build_optimizer +from nnscaler.runtime.module import ParallelModule + +from .arg_parser import deserialize_dataclass, merge_args, parse_args + + +def load_type(type_name: str): + parts = type_name.rsplit('.', 1) + if len(parts) == 1: + nm = __builtins__ + type_name = parts[0] + else: + namespace, type_name = parts + nm = importlib.import_module(namespace) + return getattr(nm, type_name) + + +@dataclass +class AggregatedOutputs: + """ + Aggregated outputs from all micro-batches + """ + loss: Optional[int] = None + num_samples: Optional[int] = None + num_tokens: Optional[int] = None + # any other custom outputs + aggregated_outputs: Any = None + + +@dataclass +class TrainerArgs: + compute_config: ComputeConfig = None + + gen_savedir: str = './.nnscaler' + pas_policy: str = 'autodist' + broadcast_strategy: str = 'all' + instance_name: str = None + # compile: compile the model but not training + # run: compile and run the model + run_mode: str = 'run' + # the model state dict for tracing. + ckpt_tracing: str = None + + model_class: str = None + model_args: Dict[str, Any] = field(default_factory=dict) + fp16: bool = False + bf16: bool = False + + optimizer_class: str = None + optimizer_args: Dict[str, Any] = field(default_factory=dict) + + dataset_class: str = None + train_dataset_args: Dict[str, Any] = field(default_factory=dict) + val_dataset_args: Dict[str, Any] = field(default_factory=dict) + test_dataset_args: Dict[str, Any] = field(default_factory=dict) + + dataloader_class: str = 'torch.utils.data.DataLoader' + train_dataloader_args: Dict[str, Any] = field(default_factory=dict) + # default to train_dataloader_args + val_dataloader_args: Dict[str, Any] = field(default_factory=dict) + # default to train_dataloader_args + test_dataloader_args: Dict[str, Any] = field(default_factory=dict) + + dataset_sampler_class: str = 'torch.utils.data.DistributedSampler' + train_dataset_sampler_args: Dict[str, Any] = field(default_factory=dict) + val_dataset_sampler_args: Dict[str, Any] = field(default_factory=dict) + test_dataset_sampler_args: Dict[str, Any] = field(default_factory=dict) + + lr_scheduler_class: str = None + lr_scheduler_args: Dict[str, Any] = field(default_factory=dict) + + micro_batch_size: int = 1 + global_batch_size: int = 1 + + max_epochs: int = 1000 + clip_gnorm: float = 0.0 + # TODO: support different ways of calculating grad and loss + # sum: sum the gradients of all micro-batches + # per-sample-mean: average the gradients over all micro-batches + # per-token-mean: average the gradients over all tokens + # you must specify `aggregate_outputs_fn` and return the number of tokens + gradient_accumulation: str = 'sum' + # the function to aggregate the outputs from all micro-batches + # inputs: (list of local outputs, torch group) + # output: AggregateOutputs + # you can use `torch.distributed.*` functions to do the work + aggregate_outputs_fn: str = None + + ckpt_save_dir: str = None + # `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is + # a folder with as many files as the world size. + # `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is + # a folder with as many files as the world size. + ckpt_save_type: str = 'sharded' + ckpt_load_file: str = None + + def __post_init__(self): + if not self.compute_config: + raise ValueError("compute_config is required") + if not self.compute_config.use_end2end: + raise ValueError("use_end2end must be True") + if self.global_batch_size % self.micro_batch_size != 0: + raise ValueError(f"global_batch_size {self.global_batch_size} is not divisible by micro_batch_size {self.micro_batch_size}") + if self.run_mode not in ('compile', 'run'): + raise ValueError(f"Invalid run_mode {self.run_mode}") + if self.ckpt_save_type not in ('sharded', 'deduped'): + raise ValueError(f"Invalid ckpt_save_type {self.ckpt_save_type}") + if self.fp16 and self.bf16: + raise ValueError("Cannot use both fp16 and bf16") + if not self.model_class: + raise ValueError("model_class is required") + if not self.optimizer_class: + raise ValueError("optimizer_class is required") + if not self.dataset_class: + raise ValueError("dataset_class is required") + if not self.dataloader_class: + raise ValueError("dataloader_class is required") + if not self.dataset_sampler_class: + raise ValueError("dataset_sampler_class is required") + + @classmethod + def from_cli(cls, argv: List[str]) -> 'TrainerArgs': + d = {} + if argv[0] == '-f': + with open(argv[1], 'r') as f: + d = yaml.safe_load(f) + argv = argv[2:] + + merge_args(d, parse_args(argv)) + return cls.from_dict(d) + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'TrainerArgs': + ta = deserialize_dataclass(d, TrainerArgs) + return ta + + def to_dict(self): + return asdict(self) + + @classmethod + def from_yaml(cls, path: str) -> 'TrainerArgs': + with open(path, 'r') as f: + return cls.from_dict(yaml.safe_load(f)) + + @classmethod + def create_kwarg(cls, value: dict): + for k, v in value.items(): + if isinstance(v, dict): + value[k] = cls.create_kwarg(v) + elif isinstance(v, list): + value[k] = [cls.create_kwarg(i) for i in v] + elif isinstance(v, tuple): + value[k] = tuple(cls.create_kwarg(i) for i in v) + + if '__type' in value: + value_type = load_type(value.pop('__type')) + return value_type(**value) + elif '__value_type' in value: + if 'value' not in value: + raise ValueError("value is required when __value_type is present") + value_type = value.pop('__value_type') + if value_type == 'function': # when type is function, the value should be the full qualified name of the function + return load_type(value['value']) + else: + # call its __init__ function + value_type = load_type(value_type) + return value_type(value['value']) + else: + return value + + @property + def model_type(self): + return load_type(self.model_class) + + @property + def collate_fn(self): + """ + Used to generate dummy input from dataset + """ + args = self.train_dataloader_args + if 'collate_fn' in args: + return load_type(args['collate_fn']) + # hack to get default collate_fn + return torch.utils.data.dataloader.default_collate + + @property + def scaling_factor(self): + return self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus + + @property + def update_freq(self): + return self.global_batch_size // self.micro_batch_size // self.scaling_factor + + def create_model(self) -> torch.nn.Module: + kwargs = self.create_kwarg(self.model_args) + return self.model_type(**kwargs) + + def create_parallel_optimizer(self, parallel_model: ParallelModule): + kwargs = self.create_kwarg(self.optimizer_args) + optimizer_class = load_type(self.optimizer_class) + return build_optimizer(parallel_model, optimizer_class, **kwargs) + + def create_dataset(self, stage='train'): + dataset_args = getattr(self, f'{stage}_dataset_args') + if not dataset_args: + return None + kwargs = self.create_kwarg(dataset_args) + dataset_class = load_type(self.dataset_class) + if issubclass(dataset_class, torch.utils.data.IterableDataset): + raise ValueError("IterableDataset is not supported") + return dataset_class(**kwargs) + + def create_sampler(self, dataset, stage='train'): + sampler_args = getattr(self, f'{stage}_dataset_sampler_args') + sampler_args = sampler_args or self.train_dataset_sampler_args + kwargs = self.create_kwarg(sampler_args) + kwargs['dataset'] = dataset + kwargs['num_replicas'] = self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus + kwargs['rank'] = torch.distributed.get_rank() // self.compute_config.plan_ngpus + sampler_class = load_type(self.dataset_sampler_class) + return sampler_class(**kwargs) + + def create_dataloader(self, stage='train', dataset=None): + dataloader_args = getattr(self, f'{stage}_dataloader_args') + dataloader_args = dataloader_args or self.train_dataloader_args + kwargs = self.create_kwarg(dataloader_args) + kwargs['dataset'] = dataset or self.create_dataset(stage) + if kwargs['dataset'] is None: + return None + if 'collate_fn' in kwargs: + # special handling for collate_fn as a function + # here we don't use self.collate_fn to avoid its implementation hacking + kwargs['collate_fn'] = load_type(kwargs['collate_fn']) + kwargs['batch_size'] = self.micro_batch_size + kwargs['sampler'] = self.create_sampler(kwargs['dataset'], stage) + dataloader_class = load_type(self.dataloader_class) + return dataloader_class(**kwargs) + + def create_lr_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.LRScheduler: + if not self.lr_scheduler_class: + return None + kwargs = self.create_kwarg(self.lr_scheduler_args) + lr_scheduler_class = load_type(self.lr_scheduler_class) + return lr_scheduler_class(optimizer, **kwargs) diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cli/common.py b/tests/cli/common.py new file mode 100644 index 00000000..6935a36e --- /dev/null +++ b/tests/cli/common.py @@ -0,0 +1,18 @@ +import torch +from torch.utils.data import DataLoader, Dataset + +from tests.parallel_module.test_end2end import MLP + +class SimpleDataset(Dataset): + def __init__(self, dim: int, size: int = 100): + self.data = torch.randn(size, dim) + self.target = torch.randn(size, dim) + + def __getitem__(self, idx: int): + return { + 'data': self.data[idx], + 'target': self.target[idx] + } + + def __len__(self): + return len(self.data) diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py new file mode 100644 index 00000000..69d06753 --- /dev/null +++ b/tests/cli/test_arg_parser.py @@ -0,0 +1,140 @@ +from dataclasses import asdict, dataclass +from typing import List, Optional, Tuple, Dict, Any, Union +import sys + +import pytest + +from nnscaler.cli.arg_parser import parse_args, deserialize_dataclass, _fix_type + + +def test_parse_args(): + assert parse_args(['--a-good=1', '--b', '2', '--c.d=3', '--c.e', '4', '--f.g.h=5']) == { + 'a_good': '1', + 'b': '2', + 'c': {'d': '3', 'e': '4'}, + 'f': {'g': {'h': '5'}} + } + parse_args(['--a=1', '--b', '--c.d=3', '--c.e', '4', '--f.g.h=5']) == { + 'a': '1', + 'b': None, + 'c': {'d': '3', 'e': '4'}, + 'f': {'g': {'h': '5'}} + } + + parse_args(['--a=1', '--b', '[1,2,3,4]']) == { + 'a': '1', + 'b': [1,2,3,4], + } + + with pytest.raises(ValueError): + parse_args(['--a=1', 'e', '--b', '--c.d=3', '--c.e', '4', '--f.g.h=5']) + + +def test_fix_type(): + assert _fix_type(int) == int + assert _fix_type(None) == None + assert _fix_type(Any) == None + assert _fix_type(Optional[bool]) == bool + assert _fix_type(Union[bool, None]) == bool + assert _fix_type(List[str], False) == List[str] + assert _fix_type(Optional[List[str]], False) == List[str] + + with pytest.raises(ValueError): + _fix_type(List[str], True) + + with pytest.raises(ValueError): + _fix_type(Union[bool, int]) + + with pytest.raises(ValueError): + _fix_type(Union[bool, int, None]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='| is not available as union type for python < 3.10') +def test_fix_type2(): + assert _fix_type(bool|None) == bool + assert _fix_type(list[str], False) == list[str] + assert _fix_type(Optional[list[str]], False) == list[str] + assert _fix_type(list[str]|None, False) == list[str] + + with pytest.raises(ValueError): + _fix_type(list[str], True) + + with pytest.raises(ValueError): + _fix_type(bool|int) + + with pytest.raises(ValueError): + _fix_type(bool|int|None) + + +@dataclass +class GConfig: + h: int + + +def test_deserialize(): + @dataclass + class C: + d: int + e: int + + @dataclass + class G: + h: int + + @dataclass + class F: + g: G + + @dataclass + class A: + a: int + b: bool + c: C + f: F + h: Tuple[int, ...] = None + g: List[str] = None + k: List[int] = None + w: Dict[str, int] = None + v: Dict[str, int] = None + x: Dict[str, Any] = None + y: List[F] = None + z: Dict[str, Any] = None + + x = parse_args(['--a=1', '--b', '--c.d=3', '--c.e', '4', '--f.g.h=5', '--v.a=10', '--v.b=20', '--k=[10,12]']) + y = deserialize_dataclass(x, A) + assert y == A(a=1, b=True, c=C(d=3, e=4), f=F(g=G(h=5)), k=[10, 12], v={'a': 10, 'b': 20}) + + x = parse_args(['--a=1', '--b', 'False', '--c.d=3', '--c.e', '4', '--f.g.h=5', '--v.a=10', '--v.b=20', '--k=[10,12]']) + y = deserialize_dataclass(x, A) + assert y == A(a=1, b=False, c=C(d=3, e=4), f=F(g=G(h=5)), k=[10, 12], v={'a': 10, 'b': 20}) + + x = parse_args(['--a=1', '--b', '0', '--c.d=3', '--c.e', '4', '--f.g.h=5', + '--v.a=10', '--v.b=20', + '--z.__type=tests.cli.test_arg_parser.GConfig', + '--z.h=6', '--z.y=hello', + '--z.x=True', + '--z.array=[1,2,3,4,5]', + '--z.badarry=[1,b]', + '--z.dict={"a": 1, "b": 2}', + '--z.baddict={a:1,b:2}', + '--z.nest_dict.__type=tests.cli.test_arg_parser.GConfig', + '--z.nest_dict.h=7', + ]) + y = deserialize_dataclass(x, A) + assert y == A(a=1, b=False, c=C(d=3, e=4), f=F(g=G(h=5)), v={'a': 10, 'b': 20}, + z={ + 'h': 6, + '__type': 'tests.cli.test_arg_parser.GConfig', + 'y': 'hello', + 'x': True, + 'array': [1, 2, 3, 4, 5], + 'badarry': '[1,b]', + 'dict': {'a': 1, 'b': 2}, + 'baddict': '{a:1,b:2}', + 'nest_dict': { + 'h': 7, + '__type': 'tests.cli.test_arg_parser.GConfig' + } + } + ) + assert deserialize_dataclass(asdict(y), A) == y diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py new file mode 100644 index 00000000..a8e2b179 --- /dev/null +++ b/tests/cli/test_trainer.py @@ -0,0 +1,28 @@ +from pathlib import Path + +import torch +import pytest + +from nnscaler.cli.trainer import Trainer +from ..launch_torchrun import launch_torchrun + + +def trainer_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + trainer = Trainer([ + '-f', config_path, + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--ckpt_save_type', 'sharded', + '--ckpt_save_dir', str(ckpt_savedir), + ]) + trainer.train() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_trainer(tmp_path): + launch_torchrun(4, trainer_worker, tmp_path) diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml new file mode 100644 index 00000000..5454597d --- /dev/null +++ b/tests/cli/trainer_args.yaml @@ -0,0 +1,26 @@ +compute_config: + plan_ngpus: 4 + runtime_ngpus: 100 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 + +model_class: tests.cli.common.MLP +model_args: + dim: 16 + nlayers: 16 + +optimizer_class: torch.optim.Adam +optimizer_args: + lr: 0.01 + +dataset_class: tests.cli.common.SimpleDataset +train_dataset_args: + dim: 16 + size: 100 From 8dfec28c3e09bc47070be8fdcb29bdde41e077eb Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Tue, 18 Jun 2024 02:51:31 +0000 Subject: [PATCH 1659/1892] Merged PR 2040: add interface for cube integration test and script for collect huggingface mo... add interface for cube integration test and script for collect huggingface models --- examples/huggingface_nlp/compile_hf.py | 434 ++++++++++++++++++ examples/huggingface_nlp/compile_interface.py | 166 +++++++ 2 files changed, 600 insertions(+) create mode 100644 examples/huggingface_nlp/compile_hf.py create mode 100644 examples/huggingface_nlp/compile_interface.py diff --git a/examples/huggingface_nlp/compile_hf.py b/examples/huggingface_nlp/compile_hf.py new file mode 100644 index 00000000..7e319ed3 --- /dev/null +++ b/examples/huggingface_nlp/compile_hf.py @@ -0,0 +1,434 @@ +from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM +from _collections_abc import MutableMapping +import os +import json +import sys +import subprocess +import traceback +import torch +import psutil +import logging +import time +import nnscaler +import inspect +import logging +import argparse +import warnings +import pathlib +from compile_interface import ModelCompiler, TraceCompileException, calcu_max_diff, logger + +INFO_FNAME = 'info' # info.log will log all the info and other loggers will be redirected to this file +TRIED_FNAME = 'tried' # tried.log will log all the model names that have been tried +LOADED_FNAME = 'loaded' # loaded.log will log all the model names that have been successful loaded from huggingface +ERROR_FNAME = 'error' # error.log will log all the error messages +EXPORT_FNAME = 'export' # models successful exported +EXPORT_ALIGNED_FNAME = 'export_aligned' # models exported and aligned with original model +TRACE_FNAME = 'trace' # models successful traced +TRACE_ALIGNED_FNAME = 'trace_aligned' # models traced and aligned with original model +COMPILE_FNAME = 'compile' # models successful compiled +COMPILE_ALIGNED_FNAME = 'compile_aligned' # models compiled and aligned with original model +TRAIN_FNAME = 'train' # models successful trained +TRAIN_ALIGNED_FNAME = 'train_aligned' # models trained and aligned with original model + +FXMODULE_PARSER_WARNING_FNAME = 'FxModuleParser_Warning.log' # log file for FxModuleParser warning +COMPILE_ERROR_JSON = 'error.json' # error.json will store the error summary for all the models + + +warnings.filterwarnings("ignore") +torch.set_printoptions(edgeitems = 2) +text: str = "Huggingface is a really excellent project!" +########## define logger ########## +loggers = {} + + +def setup_logger(log_file, level = logging.INFO, need_timestamp = True): + """Setup a logger for log_file + """ + logger = logging.getLogger(str(log_file)) + logger.setLevel(level) + # logger will only init once for one log_file + if not logger.handlers: + handler = logging.FileHandler(log_file, "a") + if need_timestamp: + formatter = logging.Formatter('%(asctime)s [PID %(process)d][%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + handler.setFormatter(formatter) + handler.setLevel(level) + logger.addHandler(handler) + return logger + + +def logger_redirect(logger, to_logger_file, prefix = '', need_timestamp=True, level = logging.INFO) -> logging.FileHandler: + """Add logger to another file + """ + result_handler = logging.FileHandler(to_logger_file, 'a') + if need_timestamp: + formatter = logging.Formatter(f'%(asctime)s [PID %(process)d][%(levelname)s]: {prefix} %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + else: + formatter = logging.Formatter(f'{prefix} %(message)s') + result_handler.setFormatter(formatter) + result_handler.setLevel(level) + logger.addHandler(result_handler) + return result_handler + + +def add_logger(log_dir, log_key, prefix = "", level = logging.INFO, need_timestamp = False): + """Init a logger, redirect it to INFO_FNAME and add it to global loggers""" + global loggers + if log_key in loggers and loggers['log_key'] is not None: + return + _logger = setup_logger(log_dir / f'{log_key}.log', level, need_timestamp) + logger_redirect(_logger, log_dir / f'{INFO_FNAME}.log', prefix = prefix) + loggers[log_key] = _logger + + +def setup_loggers(log_dir, level = logging.INFO): + """Setup loggers for compiling process""" + info_logger = setup_logger(log_dir / f'{INFO_FNAME}.log', level, need_timestamp = True) + loggers[INFO_FNAME] = info_logger + add_logger(log_dir, TRIED_FNAME, prefix="model tried: ", level = level, need_timestamp = False) + add_logger(log_dir, LOADED_FNAME, prefix="model loaded: ", level = level, need_timestamp = False) + add_logger(log_dir, ERROR_FNAME, prefix="", level = level, need_timestamp = False) + +######### define logger ########## + +def print_memory_usage(prefix : str = ""): + """Print current gpu memory usage""" + process = psutil.Process() + mem_info = process.memory_info() + loggers[INFO_FNAME].debug("When " + prefix + f": Current memory usage: {mem_info.rss / (1024 ** 3):.2f} GB") + try: + smi_output = subprocess.check_output( + ['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,nounits,noheader'], + encoding='utf-8' + ) + memory_info = smi_output.strip().split('\n') + gpu_mem_tuple = [] + for idx, mem in enumerate(memory_info): + used, total = mem.split(', ') + gpu_mem_tuple.append((idx, int(used) / 1024, int(total) / 1024)) + loggers[INFO_FNAME].debug(f"GPU memory usage (index, used-GB, total-GB): {gpu_mem_tuple}") + except subprocess.CalledProcessError as e: + loggers[INFO_FNAME].error("Can't execute nvidia-smi command:", e.output) + except FileNotFoundError: + loggers[INFO_FNAME].error("nvidia-smi command not found , make sure nvidia driver has been install successfully.") + + +def _prepare_hf_nlp_input(model, dummy_input): + """Preprocess dummy_input for huggingface nlp models""" + if isinstance(dummy_input, MutableMapping): + dummy_input = dict(dummy_input) + assert isinstance(dummy_input, dict) + forward_signature = inspect.signature(model.forward) + if 'decoder_input_ids' in forward_signature.parameters and 'decoder_input_ids' not in dummy_input: + dummy_input['decoder_input_ids'] = dummy_input.get('input_ids', None) + if 'token_type_ids' not in forward_signature.parameters and 'token_type_ids' in dummy_input: + dummy_input.pop('token_type_ids', None) + if 'attention_mask' in dummy_input: + dummy_input.pop('attention_mask', None) + return dummy_input + + +def dump_orged_errors(model_name, error_dict, log_path): + """Dump error_dict to json file log_path, error_dict is a summary for error:model_name pairs""" + exc_type, exc_value, exc_traceback = sys.exc_info() + first_line = f"{exc_type.__name__}: {exc_value}" + first_line = first_line.replace(model_name, r"{model_name}") + + if first_line in error_dict: + error_dict[first_line]['model_name'].append(model_name) + error_dict[first_line]['count'] += 1 + else: + error_dict[first_line] = {"count": 1, 'model_name': [model_name]} #, "example": exception_string + + error_dict = dict(sorted(error_dict.items(), key=lambda item: item[1]["count"], reverse=True)) + + with open(log_path, 'w') as json_file: + json.dump(error_dict, json_file, indent=4) + + +def load_error_summary(log_dir): + """Load error_dict from COMPILE_ERROR_JSON in log_dir, this is for resume""" + errors = {} + if os.path.exists(log_dir / COMPILE_ERROR_JSON): + with open(log_dir / COMPILE_ERROR_JSON, 'r') as json_file: + errors = json.load(json_file) + return errors + else: + return errors + + +def print_model_size(model, model_name): + """Print model size in MB""" + param_size = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + buffer_size = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + + size_all_mb = (param_size + buffer_size) / 1024**2 + loggers[INFO_FNAME].info('model size: {:.3f}MB'.format(size_all_mb)) + loggers[INFO_FNAME].info(f"{model_name} has parameter: {sum(p.numel() for p in model.parameters())}") + print_memory_usage(f"after load model {model_name}") + + +class HFModelLoader: + """Load huggingface model, tokenizer, config by model_name""" + def __init__(self, model_name, cache_dir, reduce=False): + self.model_name = model_name + self.cache_dir = cache_dir + self.reduce = reduce + + def _load_model_from_config(self, config): + torch.manual_seed(0) + def _get_auto_model_class(config): + try: + if config.architectures: + architecture = config.architectures[0] + if "CausalLM" in architecture: + return AutoModelForCausalLM + return AutoModel + except AttributeError: + return AutoModel + try: + model = _get_auto_model_class(config).from_config(config, trust_remote_code=True) + return model + except Exception as e: + raise e + + def _load_model_from_pretrain(self): + torch.manual_seed(0) + return AutoModelForCausalLM.from_pretrained(self.model_name, cache_dir=self.cache_dir, trust_remote_code=True, resume_download = True) + + def load_hf_config(self): + try: + config = AutoConfig.from_pretrained(self.model_name, cache_dir=self.cache_dir, trust_remote_code=True, resume_download = True) + return config + except Exception: + return None + + def load_hf_model(self, config): + """Load huggingface model by config or pretrained + """ + def _reduce_model_size(config): + params_to_reduce = ["n_layer", "n_layers", "num_hidden_layers", "num_decoder_layers", "num_heads"] + for param in params_to_reduce: + if hasattr(config, param) and getattr(config, param) is not None: + loggers[INFO_FNAME].info(f"set {param}: {getattr(config, param)} to 1") + setattr(config, param, 1) + return config + try: + if not self.reduce: + model = self._load_model_from_pretrain() + else: + config = _reduce_model_size(config) + model = self._load_model_from_config(config) + return model + except Exception: + if not self.reduce: + loggers[INFO_FNAME].info("load model from pretrain failed, try by config") + try: + model = self._load_model_from_config(config) + return model + except Exception: + raise + raise + + def load_hf_nlp_tokenizer(self): + """load huggingface nlp tokenizer by model_name""" + try: + tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_dir, trust_remote_code=True, resume_download = True) + return tokenizer + except (OSError, ValueError): + # The script uses just one of the seven tokenizers below, as we're only checking if the logits match for the same input. + # BertTokenizerFast, CamembertTokenizerFast tokenizer, XLMRobertaTokenizerFast tokenizer, DistilBertTokenizerFast tokenizer + # T5TokenizerFast tokenizer, RobertaTokenizerFast tokenizer, GPT2TokenizerFast tokenizer + loggers[INFO_FNAME].info("loading pretrained tokenizer failed, use bert-base-uncased tokenizer instead") + from transformers import BertTokenizerFast + return BertTokenizerFast.from_pretrained('bert-base-uncased', cache_dir=self.cache_dir, trust_remote_code=True, resume_download = True) + + +class HFCompiler: + def __init__(self, args): + self.model_name = args.model_name + self.cache_dir = args.cache_dir + self.trace = args.trace + self.compile = args.compile + self.train = args.train + self.reduce = args.reduce + self.export = args.export + self.log_dir = args.log_dir + self.model_loader = HFModelLoader(self.model_name, self.cache_dir, self.reduce) + + def load_resources(self): + self.config = self.model_loader.load_hf_config() + loggers[INFO_FNAME].info(f"config: {self.config}") + if self.config is not None: + loggers[INFO_FNAME].info(f"{self.model_name} config loaded") + self.tokenizer = self.model_loader.load_hf_nlp_tokenizer() + loggers[INFO_FNAME].info(f"{self.model_name} Tokenizer loaded") + self.model = self.model_loader.load_hf_model(self.config) + print_model_size(self.model, self.model_name) + loggers[LOADED_FNAME].info(f"{self.model_name}, {self.config.architectures if hasattr(self, 'config') and self.config else None}") + + def compile_hf_worker(self): + try: + start_time = time.time() + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + subprocess.run('rm -rf gencode*.py fullmodel.pt.* dist_param_map.pt', shell=True, check=True) + # load config, tokenizer, model + loggers[TRIED_FNAME].info(f"{self.model_name}") + model_name = self.model_name + + self.load_resources() + dummy_input = self.tokenizer(text, return_tensors="pt") + + # build dummy_input and forward + dummy_input = _prepare_hf_nlp_input(self.model, dummy_input) + self.model.eval() + + compiler = ModelCompiler(self.model, dummy_input, 'dp') + if self.export and torch.distributed.get_rank() == 0: + add_logger(self.log_dir, EXPORT_FNAME, prefix="model exported: ", level = logging.INFO, need_timestamp = False) + add_logger(self.log_dir, EXPORT_ALIGNED_FNAME, prefix="model export aligned: ", level = logging.INFO, need_timestamp = False) + emodel = compiler.export() + max_diff = compiler.forward_diff(emodel) + loggers[EXPORT_FNAME].info(f"{model_name}") + loggers[INFO_FNAME].info(f"export max diff: {max_diff}") + if max_diff <= 1e-5: + loggers[EXPORT_ALIGNED_FNAME].info(f"{model_name}") + else: + loggers[ERROR_FNAME].error(f"{model_name} not aligned before and after export, max diff:{max_diff}") + + if self.trace and torch.distributed.get_rank() == 0: + if self.export: + self.model = self.model_loader.load_hf_model(self.config) + print_model_size(self.model, self.model_name) + add_logger(self.log_dir, TRACE_FNAME, prefix="model traced: ", level = logging.INFO, need_timestamp = False) + add_logger(self.log_dir, TRACE_ALIGNED_FNAME, prefix="model trace aligned: ", level = logging.INFO, need_timestamp = False) + t_model = compiler.trace() + max_diff = compiler.forward_diff(t_model) + if t_model: + loggers[TRACE_FNAME].info(f"{model_name}") + loggers[INFO_FNAME].info(f"trace max diff: {max_diff}") + if max_diff <= 1e-5: + loggers[TRACE_ALIGNED_FNAME].info(f"{model_name}") + else: + loggers[ERROR_FNAME].error(f"{model_name} not aligned before and after trace, max diff:{max_diff}") + del self.model + torch.cuda.empty_cache() + + if self.compile or self.train: + if self.export or self.trace: + # this model should be load again if traced before because the model will be changed during trace + self.model = self.model_loader.load_hf_model(self.config) + print_model_size(self.model, self.model_name) + add_logger(self.log_dir, COMPILE_FNAME, prefix="model compiled: ", level = logging.INFO, need_timestamp = False) + add_logger(self.log_dir, COMPILE_ALIGNED_FNAME, prefix="model compile aligned: ", level = logging.INFO, need_timestamp = False) + p_model = compiler.parallel(self.model) + max_diff = compiler.forward_diff(p_model) + if p_model: + loggers[COMPILE_FNAME].info(f"{model_name}") + loggers[INFO_FNAME].info(f"compile max diff: {max_diff}") + if max_diff <= 1e-5: + loggers[COMPILE_ALIGNED_FNAME].info(f"{model_name}") + else: + loggers[ERROR_FNAME].error(f"{model_name} not aligned before and after compile, max diff:{max_diff}") + + if self.train: + add_logger(self.log_dir, TRAIN_FNAME, prefix="model trained: ", level = logging.INFO, need_timestamp = False) + add_logger(self.log_dir, TRAIN_ALIGNED_FNAME, prefix="model train aligned: ", level = logging.INFO, need_timestamp = False) + steps = 10 + compile_loss = compiler.train(p_model, steps = steps) + compile_logit = p_model(**compiler.dummy_input) + + origin_loss = compiler.train(self.model, steps = steps) + origin_logit = self.model(**compiler.dummy_input) + + loggers[TRAIN_FNAME].info(f"{model_name}") + + max_diff = calcu_max_diff(origin_logit, compile_logit) + if max_diff <= 1e-5: + loggers[TRAIN_ALIGNED_FNAME].info(f"{model_name}") + else: + loggers[ERROR_FNAME].error(f"{model_name} not aligned before and after train, max diff:{max_diff}") + except TraceCompileException as e: + # Exception will be cause from nnscaler.compile, or the program will be blocked + if torch.distributed.get_rank() == 0: + loggers[INFO_FNAME].error(f"fail when nnscaler.compile: {model_name}", exc_info=False) + error_message = traceback.format_exc().strip() + "\n" + loggers[ERROR_FNAME].error(f"{model_name}, {self.config.architectures if 'config' in locals() and self.config else None}, failed") + loggers[ERROR_FNAME].error(error_message) + dump_orged_errors(model_name, error_dict, self.log_dir / COMPILE_ERROR_JSON) + import glob + if not bool(glob.glob('gencode*.py')): + torch.distributed.barrier() + raise + except Exception as e: + if torch.distributed.get_rank() == 0: + loggers[INFO_FNAME].error(f"fail: {model_name}", exc_info=False) + + error_message = traceback.format_exc().strip() + "\n" + loggers[ERROR_FNAME].error(f"{model_name}, {self.config.architectures if 'config' in locals() and self.config else None}, failed") + loggers[ERROR_FNAME].error(error_message) + dump_orged_errors(model_name, error_dict, self.log_dir / COMPILE_ERROR_JSON) + raise + finally: + end_time = time.time() + loggers[INFO_FNAME].info(f"Finish trying model: {model_name}, time: {end_time - start_time:.2f} s") + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description='The script to compile huggingface nlp models') + parser.add_argument('-d', '--cache_dir', default='/tmp/hf_cache', help='cache directory for config, tokenizer, model') + parser.add_argument('-m', '--model_name', required=True, help='model name in huggingface') + parser.add_argument('-t', '--trace', default=False, action=argparse.BooleanOptionalAction, help='do trace') + parser.add_argument('-c', '--compile', default=True, action=argparse.BooleanOptionalAction, help='do compile') + parser.add_argument('-tr', '--train', default=True, action=argparse.BooleanOptionalAction, help='train steps') + parser.add_argument('-r', '--reduce', default=False, action=argparse.BooleanOptionalAction, help='whether reduce large models by setting layers to 1') + parser.add_argument('-e', '--export', default=False, action=argparse.BooleanOptionalAction, help='torch.export models') + parser.add_argument('-l', '--log_dir', default='~/hf_logs', help='log directory') + args = parser.parse_args() + return args + + +if __name__ == "__main__": + """ This script is for compile huggingface models and logs activities in argument of '--log_dir'. + It will first load configuration, tokenizer, model from huggingface, then do export, trace, compile, train one by one and check alignment at each step. + Among them, only "compile and train" is enabled by default, while "export" and "trace" are disabled by default. + Now it only supports huggingface nlp models. + usage: + torchrun --nproc_per_node=1 --nnodes=1 compile_hf.py -d -m -r -l + """ + args = parse_arguments() + if isinstance(args.log_dir, str): + args.log_dir = pathlib.Path(os.path.expanduser(args.log_dir)) + nnscaler.init() + if torch.distributed.get_rank() == 0: + if not args.log_dir.exists(): + args.log_dir.mkdir() + + # load error dict + error_dict = load_error_summary(args.log_dir) + + # block logs except from rank0 + if loggers is None or loggers == {}: + setup_loggers(args.log_dir, level = logging.INFO) + for tmp_logger in loggers.values(): + if torch.distributed.get_rank() != 0: + tmp_logger.setLevel(logging.WARNING) + for handler in loggers[INFO_FNAME].handlers: + logger.addHandler(handler) + + # add model name to FxModuleParser log + fxparser_warning_path = args.log_dir / FXMODULE_PARSER_WARNING_FNAME + file_handler = logging.FileHandler(fxparser_warning_path) + from nnscaler.graph.parser.fx.parser import _logger + _logger.addHandler(file_handler) + _logger.warning(f"\n{args.model_name}") + + # Instantiate and use the compiler + compiler = HFCompiler(args) + compiler.compile_hf_worker() + diff --git a/examples/huggingface_nlp/compile_interface.py b/examples/huggingface_nlp/compile_interface.py new file mode 100644 index 00000000..b3f5dbb3 --- /dev/null +++ b/examples/huggingface_nlp/compile_interface.py @@ -0,0 +1,166 @@ +import torch +from nnscaler.graph.parser.converter import to_fx_graph +import nnscaler +from nnscaler.runtime.utils import microbatches +from typing import Any, Dict +import inspect +import os +import logging + + +logger = logging.getLogger('compile_wrapper') + + +def prepare_dataloader(model, dummy_input): + forward_signature = inspect.signature(model.forward) + params_with_defaults = tuple( + v.default if k not in dummy_input else dummy_input[k].to(torch.cuda.current_device()) + for k, v in forward_signature.parameters.items() + ) + dataloader = microbatches([params_with_defaults] * 2) + return dataloader + + +def calcu_max_diff(before_trace, after_trace): + """Recursively calculate the max difference between two dicts or two tensors""" + max_diff = 0 + if isinstance(after_trace, torch.Tensor): + diff = torch.max(torch.abs(after_trace.to(torch.cuda.current_device()) - before_trace.to(torch.cuda.current_device()))) + if diff > max_diff: + max_diff = diff + elif isinstance(after_trace, dict): + for key in after_trace.keys(): + diff = calcu_max_diff(before_trace[key], after_trace[key]) + if diff > max_diff: + max_diff = diff + elif isinstance(after_trace, (list, tuple)): + for i in range(len(after_trace)): + diff = calcu_max_diff(before_trace[i], after_trace[i]) + if diff > max_diff: + max_diff = diff + else: + diff = calcu_max_diff(before_trace, after_trace) + if diff > max_diff: + max_diff = diff + return max_diff + + +def set_seed(seed): + import random + import numpy as np + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +class TraceCompileException(Exception): + """An exception that occurs during the model tracing or compilation process""" + def __init__(self, message, original_exception): + super().__init__(f"{message}: {str(original_exception)}") + self.original_exception = original_exception + + +class ModelCompiler: + def __init__(self, model: torch.nn.Module, dummy_input: Dict[str, Any], policy): + nnscaler.init() + self.model = model.to(torch.cuda.current_device()) + forward_signature = inspect.signature(model.forward) + self.dummy_input = { + k: v.default if k not in dummy_input + else dummy_input[k].to(torch.cuda.current_device()) + for k, v in forward_signature.parameters.items() + } + self.policy = policy + self.model.eval() + self.before_trace = self.model(**self.dummy_input) + + def forward_diff(self, model): + """Compute the model's output and compare it with the original model's output""" + if model is None: + raise RuntimeError("Model is None") + model.to(torch.cuda.current_device()) + model.eval() + _value = model(**self.dummy_input) + max_diff = calcu_max_diff(self.before_trace, _value) + return max_diff + + def trace(self): + """Trace model""" + try: + if torch.cuda.is_available(): + try: + traced_gm = to_fx_graph(self.model, self.dummy_input) + except: + raise + logger.info("Successfully traced with gpu") + return traced_gm + else: + raise RuntimeError("CUDA is not available") + except Exception as e: + raise TraceCompileException("An error occurred during trace the model.", e) + + def parallel(self, model): + """Compile model""" + from nnscaler.parallel import parallelize, ComputeConfig + try: + parallel_model = parallelize( + model, + self.dummy_input, + pas_policy=self.policy, + compute_config=ComputeConfig(1, 1, dynamic_shape=False), + reuse='override', + load_module=True + ) + return parallel_model + except Exception as e: + raise RuntimeError("An error occurred during the model compilation.", e) + + def train(self, model, steps = 1): + """Train model with dummy_input for steps""" + from torch.optim import SGD + set_seed(0) + model.to(torch.cuda.current_device()) + model.train() + optimizer = SGD(model.parameters(), 1e-3) + loss_fct = torch.nn.CrossEntropyLoss() + label = torch.zeros_like(self.dummy_input['input_ids']) + for _ in range(steps): + optimizer.zero_grad() + output = model(**self.dummy_input) + if isinstance(output, torch.Tensor): + loss = loss_fct(output, label) + elif isinstance(output, dict): + if 'logits' in output: + loss = loss_fct(output['logits'].view(-1, output['logits'].shape[-1]), label.view(-1)) + elif 'last_hidden_state' in output: + loss = loss_fct(output['last_hidden_state'].view(-1, output['last_hidden_state'].shape[-1]), label.view(-1)) + else: + raise RuntimeError(f"Output keys doesn't supported: {output.keys()}") + else: + raise RuntimeError(f"Output type doesn't supported: {type(output)}") + loss.backward() + optimizer.step() + return loss + + def export(self): + """Trace the model using torch.export, similar to trace""" + from torch.export import export + os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ['TORCH_LOGS'] = '+dynamo' + os.environ['TORCHDYNAMO_VERBOSE'] = '1' + try: + if torch.cuda.is_available(): + try: + dummy_inputs = tuple(self.dummy_input.values()) + exported_gm = export(self.model, self.dummy_input) + except: + raise + logger.info("Successfully export with gpu") + return exported_gm + else: + raise RuntimeError("CUDA is not available") + except Exception as e: + raise TraceCompileException("An error occurred during export and forward the model.", e) + From 77cabf5454e4e17ec904fa63a8e9a7ef1f8ef5af Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Tue, 18 Jun 2024 10:20:04 +0000 Subject: [PATCH 1660/1892] Merged PR 2181: quick fix compile huggingface create cache_dir if not exist refine ComputeConfig as it changes --- examples/huggingface_nlp/compile_hf.py | 2 ++ examples/huggingface_nlp/compile_interface.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/huggingface_nlp/compile_hf.py b/examples/huggingface_nlp/compile_hf.py index 7e319ed3..676d9bab 100644 --- a/examples/huggingface_nlp/compile_hf.py +++ b/examples/huggingface_nlp/compile_hf.py @@ -408,6 +408,8 @@ def parse_arguments() -> argparse.Namespace: if torch.distributed.get_rank() == 0: if not args.log_dir.exists(): args.log_dir.mkdir() + if not args.cache_dir.exists(): + args.cache_dir.mkdir() # load error dict error_dict = load_error_summary(args.log_dir) diff --git a/examples/huggingface_nlp/compile_interface.py b/examples/huggingface_nlp/compile_interface.py index b3f5dbb3..6d0a20f9 100644 --- a/examples/huggingface_nlp/compile_interface.py +++ b/examples/huggingface_nlp/compile_interface.py @@ -109,7 +109,7 @@ def parallel(self, model): model, self.dummy_input, pas_policy=self.policy, - compute_config=ComputeConfig(1, 1, dynamic_shape=False), + compute_config=ComputeConfig(1, 1), reuse='override', load_module=True ) From d0b7e5ef3084a0cb2acb4391958b5339995dfc90 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 19 Jun 2024 02:38:45 +0000 Subject: [PATCH 1661/1892] Merged PR 2111: refine optimizer state dict merge This PR is trying to reduce the memory usage when merging by combining zero and tp state together. --- nnscaler/runtime/module.py | 253 ++++++++++++++--------- tests/parallel_module/test_checkpoint.py | 46 ++++- 2 files changed, 200 insertions(+), 99 deletions(-) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index cba53b1e..3445da14 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -333,6 +333,17 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref 'optim_state_dict': optimizer_state_dict, }, filename) + @classmethod + def _safe_tensor_equal(cls, tensor1: torch.Tensor, tensor2: torch.Tensor): + if tensor1.shape != tensor2.shape: + return False + if tensor1.dtype != tensor2.dtype: + return False + if tensor1.device == tensor2.device: + return torch.equal(tensor1, tensor2) + else: + return torch.equal(tensor1.cpu(), tensor2.cpu()) + @staticmethod def merge_model_state_dicts( state_dicts: List[Dict], @@ -355,22 +366,37 @@ def merge_model_state_dicts( raise ValueError("Expected model state dicts to have the same length as fullmaps") full_model_state_dict: Dict[str, torch.Tensor] = {} + # used to track the merging status of each parameter to avoid inconsistence. + # key is the parameter name, value is a set of slicers + # Here we expand slice to (start, step, stop) tuple, + # because before python 3.12, slice object is not hashable + state_dict_merge_track: Dict[str, Set[Tuple[Tuple[Any, Any, Any], ...]]] = {} # gather param/buffer full tensor - for model_state_dict, local_fullmap in zip(state_dicts, fullmaps): + for rank, (model_state_dict, local_fullmap) in enumerate(zip(state_dicts, fullmaps)): for local_name, meta in local_fullmap.items(): if local_name not in model_state_dict: - # this is a non persistent buffer, skip - # non persistent buffer should be stored in the fullmap, but not in the model state dict + # the parameter may not in model_state_dict (deduped with optimization) + # Another casee is when this is a non persistent buffer, we should skip it. + # because non persistent buffer should be stored in the fullmap, but not in the model state dict continue # create full tensor on cpu partial_tensor = model_state_dict[local_name] if meta.orig_name not in full_model_state_dict: full_model_state_dict[meta.orig_name] = torch.empty( meta.shape, dtype=partial_tensor.dtype) + state_dict_merge_track[meta.orig_name] = set() # assign partial tensor if meta.val_chunks > 1: raise NotImplementedError("Not support of partitioning parameter / buffer at value dimension") - full_model_state_dict[meta.orig_name][meta.slicers] = partial_tensor + + state_dict_merge_track_id = tuple((i.start, i.step, i.stop) for i in meta.slicers) + if state_dict_merge_track_id in state_dict_merge_track[meta.orig_name]: + if not CubeModule._safe_tensor_equal(full_model_state_dict[meta.orig_name][meta.slicers], partial_tensor): + raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") + _logger.debug(f'rank {rank}: skip merging duplicated model state for param {meta.orig_name} with slicers {meta.slicers}') + else: + state_dict_merge_track[meta.orig_name].add(state_dict_merge_track_id) + full_model_state_dict[meta.orig_name][meta.slicers] = partial_tensor return full_model_state_dict @staticmethod @@ -461,93 +487,6 @@ def merge_state_dicts( # gather optimizer states full_optim_state_dict: Dict[str, Any] = {} # param_id -> Dict[state_name, value] - # at first, merge the partitioned optimizer states due to zero to the zero-disabled format - if zero_idx_maps is not None: - def _check_state_size(opt_state_keys, bucket_state): - """ - Check that all the keys except the scalar step for a - parameter in optimizer states have the same shaped tensor. - - For example, exp_avg, exp_avg_sq in Adam. - """ - if len(opt_state_keys) <= 1: - return True - return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape - for key in opt_state_keys) - - def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): - assert bucket_size % len(bucket_states) == 0 - opt_state_keys = list(bucket_states[0].keys()) - if 'step' in bucket_states[0]: - opt_state_keys.remove('step') - assert _check_state_size(opt_state_keys, bucket_states[0]), f'the keys {opt_state_keys} have different shape' - # NOTE: only support adam for now - assert 'exp_avg' in opt_state_keys - assert 'exp_avg_sq' in opt_state_keys - chunk_size = bucket_size // len(bucket_states) - start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size - end_rank_id, end_offset = pend // chunk_size, pend % chunk_size - opt_states, opt_states_1d = {}, {} - for key in opt_state_keys: - opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, - device=bucket_states[0][key].device, requires_grad=False) - opt_states_1d[key] = opt_states[key].view(-1) - - if start_rank_id == end_rank_id: - for key in opt_state_keys: - opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] - else: - offset = chunk_size-start_offset - for key in opt_state_keys: - opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] - for i in range(start_rank_id+1, end_rank_id): - for key in opt_state_keys: - opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] - offset += chunk_size - if end_offset: # skip if end_offset == 0, because it is a no-op - for key in opt_state_keys: - opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] - - if 'step' in bucket_states[0]: - opt_states['step'] = bucket_states[0]['step'] - return opt_states - - # Parameters are partitioned inside a scale unit composed of plan_ngpus GPUs. - # When ZeRO-1 is enabled, optimizer states (like exp_avg and exp_avg_sq) are partitioned within - # each ZeRO group. Since the training is done in a synchronized way, the optimizer states are - # identical across each ZeRO group. - # As a result, we can retrieve and merge the optimizer states in other scale units following the - # information stored in zero_idx_maps ONLY for the first scale unit. - opt_state_list = [] - for work_idx in range(plan_ngpus): - model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[work_idx] - opt_state = {} - for model_idx, opt_idx in model_idx2opt_idx.items(): - if isinstance(opt_idx, int): - # the param without reducer - assert opt_idx2ranks[opt_idx] is None - opt_state[model_idx] = optim_state_dicts[work_idx]['state'][opt_idx] - else: - # the param in reducer bucket - opt_idx, pstart, pend, pshape = opt_idx - ranks, bucket_size = opt_idx2ranks[opt_idx] - bucket_states = [optim_state_dicts[rank]['state'][opt_idx] for rank in ranks] - opt_state[model_idx] = _retrieve_param_opt_state( - bucket_states, - pstart, - pend, - pshape, - bucket_size) - _logger.info(f'finish handle optimizer state for worker {work_idx}') - opt_state_list.append(opt_state) - assert len(optim_state_dicts[work_idx]['param_groups']) == 1, 'only support param_groups to be one group' - - # assign opt_state to state_dicts, cannot be assigned in the above loop - for work_idx in range(plan_ngpus): - optim_state_dicts[work_idx]['state'] = opt_state_list[work_idx] - optim_state_dicts[work_idx]['param_groups'][0]['params'] = sorted(opt_state_list[work_idx].keys()) - _logger.info(f'finish assign optimizer state for worker {work_idx}') - # build parameter order to match with the optimizer state order # NOTE: the param IDs in optimizer typically follow the same order of # local `model.parameters()`. However, `state_dict.keys()` contains @@ -565,10 +504,95 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): # parameter in the local model state, and assign the slice to the position. full_optim_state_dict['state'] = {} full_states = full_optim_state_dict['state'] + + def _check_state_size(opt_state_keys, bucket_state): + """ + Check that all the keys except the scalar step for a + parameter in optimizer states have the same shaped tensor. + + For example, exp_avg, exp_avg_sq in Adam. + """ + if len(opt_state_keys) <= 1: + return True + return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape + for key in opt_state_keys) + + def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): + assert bucket_size % len(bucket_states) == 0 + opt_state_keys = list(bucket_states[0].keys()) + if 'step' in bucket_states[0]: + opt_state_keys.remove('step') + assert _check_state_size(opt_state_keys, bucket_states[0]), f'the keys {opt_state_keys} have different shape' + # NOTE: only support adam for now + assert 'exp_avg' in opt_state_keys + assert 'exp_avg_sq' in opt_state_keys + chunk_size = bucket_size // len(bucket_states) + start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size + end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + opt_states, opt_states_1d = {}, {} + for key in opt_state_keys: + opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, + device=bucket_states[0][key].device, requires_grad=False) + opt_states_1d[key] = opt_states[key].view(-1) + + if start_rank_id == end_rank_id: + for key in opt_state_keys: + opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] + else: + offset = chunk_size-start_offset + for key in opt_state_keys: + opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] + for i in range(start_rank_id+1, end_rank_id): + for key in opt_state_keys: + opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] + offset += chunk_size + if end_offset: # skip if end_offset == 0, because it is a no-op + for key in opt_state_keys: + opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + + if 'step' in bucket_states[0]: + opt_states['step'] = bucket_states[0]['step'] + return opt_states + + def _merge_opt_zero(worker_idx, param_idx): + model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[worker_idx] + opt_idx = model_idx2opt_idx[param_idx] + if isinstance(opt_idx, int): + # the param without reducer + assert opt_idx2ranks[opt_idx] is None + return optim_state_dicts[worker_idx]['state'][opt_idx] + else: + # the param in reducer bucket + opt_idx, pstart, pend, pshape = opt_idx + ranks, bucket_size = opt_idx2ranks[opt_idx] + bucket_states = [optim_state_dicts[rank]['state'][opt_idx] for rank in ranks] + return _retrieve_param_opt_state( + bucket_states, + pstart, + pend, + pshape, + bucket_size) + # full_index: param IDs in the full optimizer state for full_index, param_name in enumerate(origin_parameter_names): _logger.info(f'start to handle optimizer state for param {param_name} with full_index {full_index}') - for optim_state, fullmap in zip(optim_state_dicts[0 : plan_ngpus], fullmaps[0 : plan_ngpus]): + # zero_done_track is used to avoid re-merging the same parameter + # in the optimizer state + # zero_done_track_id: slicers + # Here we expand slice to (start, step, stop) tuple, + # because before python 3.12, slice object is not hashable + zero_done_track: Set[Tuple[Tuple[Any, Any, Any], ...]] = set() + # used to track the merging status of each parameter to avoid inconsistence. + # key is slicers + # please note this is only used for non-zero mode + # becase re-merging the same parameter slice (via _merge_opt_zero) is avoided in zero mode + state_merge_track: Set[Tuple[Tuple[Any, Any, Any], ...]] = set() + + # There is this for loop because a parameter may be sharded due to TP, + # consequently, the parameter's optimizer state is also sharded. + # This for loop is for merging the sharded parameter's optimizer state + # into its original full state (i.e., the non-partitioned one). + for work_idx, (optim_state, fullmap) in enumerate(zip(optim_state_dicts[0 : plan_ngpus], fullmaps[0 : plan_ngpus])): if 'state' not in optim_state: continue # adam-like optimizers have optim_state['state']={} before any optimizer.step() if not optim_state['state']: continue @@ -579,25 +603,56 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): for local_index, meta in enumerate(param_fullmap): if meta.orig_name != param_name: continue full_states.setdefault(full_index, {}) + # TODO: support customized param groups, where each parameter has IDs # specified from its own param_group - states: Dict[str, torch.Tensor] = optim_state['state'][local_index] + track_id = tuple((i.start, i.step, i.stop) for i in meta.slicers) + if zero_idx_maps is None: + states: Dict[str, torch.Tensor] = optim_state['state'][local_index] + else: + if track_id not in zero_done_track: + # As ZeRO is applied, the optimizer state of this parameter (a shard) + # may not be stored locally in its optimizer state. + # _merge_opt_zero is for recovering the optimizer state corresponding to this parameter shard. + states: Dict[str, torch.Tensor] = _merge_opt_zero(work_idx, local_index) + zero_done_track.add(track_id) + else: + _logger.debug(f'rank {work_idx}: skip merging duplicated optimizer state for param {full_index} with slicers {meta.slicers}') + continue + for state_name in states.keys(): value = states[state_name] # special handle for step: scalar tensor type if state_name == 'step': - full_states[full_index][state_name] = value + if state_name in full_states[full_index]: + if not CubeModule._safe_tensor_equal(full_states[full_index][state_name], value): + raise ValueError(f"Conflict in merging {param_name}.{state_name} from rank {work_idx}") + else: + full_states[full_index][state_name] = value continue + # for non-tensor states if not isinstance(value, torch.Tensor): - full_states[full_index][state_name] = value + if state_name in full_states[full_index]: + if full_states[full_index][state_name] != value: + raise ValueError(f"Conflict in merging {param_name}.{state_name} from rank {work_idx}") + else: + full_states[full_index][state_name] = value + _logger.debug(f'non-tensor state {state_name} is merged for {full_index}') # for tensor states, like 'exp_avg' else: # create optimizer state tensor if state_name not in full_states[full_index]: full_states[full_index][state_name] = torch.empty(meta.shape, dtype=value.dtype) - # assign with partial tensor - full_states[full_index][state_name][meta.slicers] = value + + if track_id in state_merge_track: + if not CubeModule._safe_tensor_equal(full_states[full_index][state_name][meta.slicers], value): + raise ValueError(f"Conflict in merging {param_name}.{state_name} from rank {work_idx}") + else: + # assign with partial tensor + full_states[full_index][state_name][meta.slicers] = value + + state_merge_track.add(track_id) # handle additional state dict keys for optim_state_dict in optim_state_dicts[0 : plan_ngpus]: @@ -609,6 +664,10 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): _logger.info(f'inherit optimizer state key {key}') full_optim_state_dict[key] = optim_state_dict[key] + # reset the param_groups params to the full parameter list + if 'param_groups' in full_optim_state_dict: # for backward compatibility + full_optim_state_dict['param_groups'][0]['params'] = list(range(len(origin_parameter_names))) + return full_model_state_dict, full_optim_state_dict @staticmethod diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 103c4d10..9a68cfd5 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -286,7 +286,7 @@ def assert_model_state_dict_equal(state_dict1: dict, state_dict2: dict): assert torch.equal(state_dict1[index].cpu(), state_dict2[index].cpu()) -def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inference_module: torch.nn.Module = None): +def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inference_module: torch.nn.Module = None, check_merge_log=False): ckpt_file_template = 'ckpt_{rank}_{start}.pth' ckpt_merged_file_template = 'ckpt_merged_{start}.pth' temp_inferenece_ckpt_file_template = 'inference-{rank}.pth' @@ -398,7 +398,25 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf ckpt_state_dicts = [torch.load(f) for f in ckpt_files] model_state_dicts = [ckpt['model'] for ckpt in ckpt_state_dicts] optimizer_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_state_dicts] - merged_model_state_dicts, merged_optimizer_state_dict = merge_state_dicts(model_state_dicts, optimizer_state_dicts) + if check_merge_log: + from nnscaler.runtime.module import _logger + import logging + from io import StringIO + string_stream = StringIO() + old = _logger.level + _logger.setLevel(logging.DEBUG) + handler = logging.StreamHandler(string_stream) + handler.setLevel(logging.DEBUG) + _logger.addHandler(handler) + merged_model_state_dicts, merged_optimizer_state_dict = merge_state_dicts(model_state_dicts, optimizer_state_dicts) + logs = string_stream.getvalue() + # check some zero merging is skipped due to replicate + assert 'skip merging duplicated optimizer state for param' in logs + assert 'skip merging duplicated model state for param' in logs + _logger.removeHandler(handler) + _logger.setLevel(old) + else: + merged_model_state_dicts, merged_optimizer_state_dict = merge_state_dicts(model_state_dicts, optimizer_state_dicts) torch.save({ 'model': merged_model_state_dicts, 'optimizer': merged_optimizer_state_dict @@ -521,3 +539,27 @@ def test_checkpoint_intra_reducer(module_type, use_zero): assert torch.equal(a.grads[k], b.grads[k]) for k in a.weights.keys(): # weights assert torch.equal(a.weights[k], b.weights[k]) + + +def _gpu_merge_worker(): + init_distributed() + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_merge') as tempdir: + compiled_module = _create_cube_module('data', + ComputeConfig(2, 2, use_zero=True), + tempdir, + 'whole', + ) + _train( + compiled_module, + 1, + 0, + 0, + 8, + tempdir, + check_merge_log=True + ) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_checkpoint_merge(): + launch_torchrun(2, _gpu_merge_worker) From 4dc11660eb271fb7fff059b7c57ceedb9cde8bd2 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 21 Jun 2024 08:44:10 +0000 Subject: [PATCH 1662/1892] Merged PR 2180: lightning: fix gradient sync and gradient averaging --- docs/source/parallel_module.md | 6 +- .../lightning/pytorch/precision.py | 65 ++++++--- .../integration/lightning/pytorch/strategy.py | 70 ++++++++- nnscaler/parallel.py | 4 + .../lightning/pytorch/simple_models.py | 27 +++- .../lightning/pytorch/test_strategy.py | 133 +++++++++++++++++- tests/parallel_module/common.py | 16 +++ 7 files changed, 289 insertions(+), 32 deletions(-) diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index b6b734b5..c7816683 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -487,7 +487,11 @@ To support distributed training, in the function we need to hook 4 places (which `build_optimizer` will patch optimizer for you. Besides the above patches, we also add several utility functions/variables to optimizer: -1. `sync_shard_grad`: Sync the shard gradients of the module from nodes with same shard to the optimizer. This function is called in optimizer's pre-step hook. But If you want to access the gradients before `optimizer.step()`(for example, you need gnorm), you need to call this function manually. +1. `sync_shard_grad`: Sync the shard gradients of the module from nodes with same shard to the optimizer. +Please note the gradients are `None` until `optimizer.sync_shard_grad()` is called. +This function is called in optimizer's pre-step hook. You need to manually call it in two cases: + - If you want to access the gradients before `optimizer.step()`. + - When closure is used in optimizer.step(). In this case, optimizer's pre-step hook will be called before `train_step`, so no gradients are synced. 2. `scale_grads`: Scale the gradients of the module by multiplying a factor. This function is useful to avoid overflow when the gradients are large. Please note you can only call this function **after** `sync_shard_grad`, because the gradients are `None` until `sync_shard_grad` is called. diff --git a/nnscaler/integration/lightning/pytorch/precision.py b/nnscaler/integration/lightning/pytorch/precision.py index 895282bc..5f43e03c 100644 --- a/nnscaler/integration/lightning/pytorch/precision.py +++ b/nnscaler/integration/lightning/pytorch/precision.py @@ -2,6 +2,7 @@ import torch from torch import Tensor +import torch.distributed from torch.optim import Optimizer import torch.amp @@ -16,7 +17,7 @@ from lightning.pytorch.utilities import GradClipAlgorithmType -_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"] +_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true", "16-mixed", "bf16-mixed"] class NnScalerPrecision(Precision): @@ -25,7 +26,8 @@ class NnScalerPrecision(Precision): .. warning:: This is an :ref:`experimental ` feature. Args: - precision: Full precision (32-true), half precision (16-true, bf16-true) + precision: Full precision (32-true), half precision (16-true, bf16-true) or + mixed precision (16-mixed, bf16-mixed). Raises: ValueError: @@ -56,6 +58,8 @@ def __init__( self.scaler = scaler precision_to_type = { + "bf16-mixed": torch.float32, + "16-mixed": torch.float32, "bf16-true": torch.bfloat16, "16-true": torch.float16, "32-true": torch.float32, @@ -80,7 +84,9 @@ def module_init_context(self) -> ContextManager: @override def forward_context(self) -> ContextManager: - return _DtypeContextManager(self._desired_input_dtype) + if "mixed" in self.precision: + return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) + return self.tensor_init_context() @override def convert_input(self, data: Any) -> Any: @@ -96,6 +102,17 @@ def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: tensor = self.scaler.scale(tensor) return super().pre_backward(tensor, module) + @override + def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable) -> None: + if self.scaler is None: # will be handled in optimizer_step instead of here when using scaler + self._sync_grad(model, optimizer) + super()._after_closure(model, optimizer) + + def _sync_grad(self, model: "pl.LightningModule", optimizer: Steppable): + optimizer.sync_shard_grad() # closure is used, so we have to sync gradients after closure + cf = model._trainer.strategy.compute_config + optimizer.scale_grads(cf.plan_ngpus / cf.runtime_ngpus) + @override def optimizer_step( # type: ignore[override] self, @@ -107,23 +124,26 @@ def optimizer_step( # type: ignore[override] if self.scaler is None: return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs) - closure_result = closure() + # TODO: test the following logic - if not _optimizer_handles_unscaling(optimizer): - # Unscaling needs to be performed here in case we are going to apply gradient clipping. - # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam). - # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook. - self.scaler.unscale_(optimizer) # type: ignore[arg-type] + closure_result = closure() + self._sync_grad(model, optimizer) + # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook. + # Unscaling needs to be performed after grad sync but before gradient clipping + self.scaler.unscale_(optimizer) self._after_closure(model, optimizer) - skipped_backward = closure_result is None - # in manual optimization, the closure does not return a value - if not model.automatic_optimization or not skipped_backward: - # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found - step_output = self.scaler.step(optimizer, **kwargs) # type: ignore[arg-type] - self.scaler.update() - return step_output - return closure_result + + if not model.automatic_optimization: + raise ValueError("nnscaler does not support manual optimization.") + if closure_result is None: + # in manual optimization, the closure does not return a value + raise ValueError("nnscaler does not support None as the return value of the closure.") + + # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found + step_output = self.scaler.step(optimizer, **kwargs) # type: ignore[arg-type] + self.scaler.update() + return step_output def clip_gradients( self, @@ -138,3 +158,14 @@ def clip_gradients( raise ValueError('nnscaler does not support clipping gradients by value.') elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: optimizer.clip_gnorm(clip_val) # define in nnscaler + + @override + def state_dict(self) -> Dict[str, Any]: + if self.scaler is not None: + return self.scaler.state_dict() + return {} + + @override + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + if self.scaler is not None: + self.scaler.load_state_dict(state_dict) diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py index 274b151e..48daf229 100644 --- a/nnscaler/integration/lightning/pytorch/strategy.py +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -1,8 +1,9 @@ -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from functools import partial import logging from pathlib import Path import os +import types from typing import ( TYPE_CHECKING, Any, @@ -22,9 +23,11 @@ import torch from torch import Tensor +import torch.distributed from torch.nn import Module from torch.optim import Optimizer from typing_extensions import TypeGuard, override +from torch.utils.data import DataLoader import lightning.pytorch as pl from lightning.pytorch.accelerators import Accelerator, CUDAAccelerator @@ -55,6 +58,7 @@ import nnscaler from nnscaler.integration.lightning.utils import inplace_optimizer_fn +from nnscaler.runtime.device import DeviceGroup from .precision import NnScalerPrecision @@ -196,14 +200,38 @@ def setup(self, trainer: "pl.Trainer") -> None: if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE: raise MisconfigurationException("nnscaler does not support clipping gradients by value.") + def _get_dummy_forward_args(self, model: pl.LightningModule) -> Dict[str, Any]: + # two options to set the dummy forward arguments + + if hasattr(model, 'dummy_forward_args'): + if not isinstance(model.dummy_forward_args, dict): + raise ValueError("The `dummy_forward_args` must be a dictionary with forward arguments names as keys.") + return model.dummy_forward_args + + if hasattr(model, 'dummy_forward_args_fn'): + dummy_forward_args_fn = getattr(model, 'dummy_forward_args_fn') + if not callable(dummy_forward_args_fn): + raise ValueError("The `dummy_forward_args_fn` must be a callable function.") + trainer = model._trainer + data_source = trainer.fit_loop._data_source + assert data_source is not None, "The `data_source` must be defined in the trainer." + assert data_source.instance is not None, "The `instance` must be defined in the data source." + with enforce_0_num_worker(DataLoader): + dataloader = data_source.dataloader() + assert dataloader.num_workers == 0, "The dataloader must have `num_workers=0`." + data = next(iter(dataloader)) + dummy_forward_args = dummy_forward_args_fn(data) + if not isinstance(dummy_forward_args, dict): + raise ValueError("The return value of `dummy_forward_args_fn` must be a dictionary with forward arguments names as keys.") + return dummy_forward_args + + raise ValueError("The `dummy_forward_args` or `dummy_forward_args_fn` must be defined in the module.") + @override def _setup_model(self, model: Module) -> Module: """Set up a module for inference (no optimizers). """ - if getattr(model, 'dummy_forward_args', None) is None: - raise ValueError("The `dummy_forward_args` must be defined as a property in the module.") - if not isinstance(model.dummy_forward_args, dict): - raise ValueError("The `dummy_forward_args` must be a dictionary with forward arguments names as keys.") + dummy_forward_args = self._get_dummy_forward_args(model) old_training_flag = model.training if not old_training_flag: @@ -211,7 +239,7 @@ def _setup_model(self, model: Module) -> Module: model.train() # always use the model in training mode pmodule = nnscaler.parallelize( model, - self.precision_plugin.convert_input(model.dummy_forward_args), + self.precision_plugin.convert_input(dummy_forward_args), self.pas_policy, self.compute_config, gen_savedir=self.gen_savedir, @@ -240,6 +268,24 @@ def _setup_model(self, model: Module) -> Module: # rewrite model forward to parallelized model forward model.forward = pmodule.forward + # patch log function to add sync_dist_group + rank = torch.distributed.get_rank() + # create all groups + plan_ngpus = self.compute_config.plan_ngpus + runtime_ngpus = self.compute_config.runtime_ngpus + for i in range(plan_ngpus): + DeviceGroup().get_group( + list(range(i, runtime_ngpus, plan_ngpus)) + ) + sync_group = list(range(rank % plan_ngpus, runtime_ngpus, plan_ngpus)) + + _old_log = model.log + def _new_log(self, *args, **kwargs) -> None: + kwargs['sync_dist_group'] = sync_group + _old_log(*args, **kwargs) + model.log = types.MethodType(_new_log, model) + model._old__log = _old_log + return model @override @@ -496,3 +542,15 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: def _get_process_group_backend(self) -> str: return 'nccl' # nnscaler only support nccl + + +@contextmanager +def enforce_0_num_worker(cls) -> Generator[None, None, None]: + """Context manager to enforce the number of workers to be 0 in DataLoader.""" + _old__init__ = cls.__init__ + def _new__init__(self, *args, **kwargs) -> None: + kwargs['num_workers'] = 0 + _old__init__(self, *args, **kwargs) + cls.__init__ = _new__init__ + yield + cls.__init__ = _old__init__ diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index bd23710f..544ec81a 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1319,6 +1319,10 @@ def _step_post_hook(opt, *args, **kwargs): for m in parallel_modules: m.gather_params() + # Please note: + # register_step_pre_hook doesn't work expectly + # when closure is used in optimizer.step() + # in that case, you must call sync_shard_grad() manually optimizer.register_step_pre_hook(_step_pre_hook) optimizer.register_step_post_hook(_step_post_hook) diff --git a/tests/integration/lightning/pytorch/simple_models.py b/tests/integration/lightning/pytorch/simple_models.py index 6081a716..d72a0c87 100644 --- a/tests/integration/lightning/pytorch/simple_models.py +++ b/tests/integration/lightning/pytorch/simple_models.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from lightning.pytorch import LightningModule from torch import Tensor, nn +from torch.optim.optimizer import Optimizer from torchmetrics import Accuracy from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset @@ -30,10 +31,12 @@ def __init__(self, num_features=32, num_classes=3, batch_size=10, lr=0.01): self.train_acc = acc.clone() self.valid_acc = acc.clone() self.test_acc = acc.clone() + self.dummy_forward_args_fn = lambda batch: {"x": batch[0]} + self.update_history = [] - @property - def dummy_forward_args(self): - return {'x': torch.randn(self.batch_size, self.num_features)} + # @property + # def dummy_forward_args(self): + # return {'x': torch.randn(self.batch_size, self.num_features)} def forward(self, x): x = self.layer_0(x) @@ -77,6 +80,24 @@ def predict_step(self, batch, batch_idx): x, _ = batch return self.forward(x) + def configure_gradient_clipping(self, optimizer: Optimizer, gradient_clip_val, gradient_clip_algorithm) -> None: + def _fix_name(name): + prefix = 'nnscaler_pmodule.' + if name.startswith(prefix): + return name[len(prefix):] + return name + grads = {_fix_name(n): p.grad.cpu() for n, p in self.named_parameters()} + weights = {_fix_name(n): p.data.cpu() for n, p in self.named_parameters()} + self.update_history.append((grads, weights)) + return super().configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm) + + +class ClassificationModelWithLRScheduler(ClassificationModel): + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [scheduler] + class RandomDictDataset(Dataset): """ diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index 10ac539d..a6112c8c 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -10,23 +10,27 @@ import pytest from unittest.mock import Mock, patch +import nnscaler from nnscaler.parallel import ComputeConfig from nnscaler.integration.lightning.pytorch import NnScalerStrategy, NnScalerPrecision +import nnscaler.runtime from ....launch_torchrun import launch_torchrun +from ....utils import init_random +from ....parallel_module.common import assert_close, assert_equal from .simple_datamodules import ClassifDataModule -from .simple_models import BoringModel, ClassificationModel +from .simple_models import BoringModel, ClassificationModel, ClassificationModelWithLRScheduler def fit_worker(tmp_path): dm = ClassifDataModule() model = ClassificationModel() - compute_config=ComputeConfig(2, 2) + compute_config=ComputeConfig(1, 1) trainer = Trainer( default_root_dir=tmp_path, max_epochs=2, - accelerator="gpu", devices=2, - gradient_clip_val=2.0, + accelerator="gpu", devices=1, + gradient_clip_val=None, strategy=NnScalerStrategy(compute_config=compute_config, pas_policy='tp', gen_savedir=tmp_path), plugins=[NnScalerPrecision('32-true')] ) @@ -36,7 +40,7 @@ def fit_worker(tmp_path): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_multi_gpu_model_only(tmp_path): - launch_torchrun(2, fit_worker, tmp_path) + launch_torchrun(1, fit_worker, tmp_path) def ckpt_path_epoch_restored_worker(tmp_path): @@ -130,3 +134,122 @@ def trainer_accumulate_grad_batches_zero_grad(tmp_path, accumulate_grad_batches) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_trainer_accumulate_grad_batches_zero_grad(tmp_path, accumulate_grad_batches): launch_torchrun(2, trainer_accumulate_grad_batches_zero_grad, tmp_path, accumulate_grad_batches) + + +def correctnes_worker_nnscaler(tmp_path, gradient_clip_val, with_lr_scheduler, + precision='32-true', + with_tp=False, with_empty_scaler=False +): + init_random() + dm = ClassifDataModule() + init_random() + if with_lr_scheduler: + model = ClassificationModelWithLRScheduler() + else: + model = ClassificationModel() + if with_tp: + compute_config=ComputeConfig(2, 4) + policy = 'tp' + devices = 4 + else: + compute_config=ComputeConfig(1, 2) + policy = 'dp' + devices = 2 + scaler = None + if with_empty_scaler or precision == '16-mixed': + scaler = torch.cuda.amp.GradScaler(enabled=(precision == '16-mixed')) + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=2, + accelerator="gpu", devices=devices, + gradient_clip_val=gradient_clip_val, + strategy=NnScalerStrategy( + compute_config=compute_config, pas_policy=policy, gen_savedir=tmp_path, + instance_name=policy + ), + plugins=[NnScalerPrecision(precision, scaler=scaler)] + ) + trainer.fit(model, datamodule=dm) + return model.update_history, model.nnscaler_pmodule.fullmap + + +def correctnes_worker_ddp(tmp_path, gradient_clip_val, with_lr_scheduler, precision='32-true'): + init_random() + dm = ClassifDataModule() + init_random() + if with_lr_scheduler: + model = ClassificationModelWithLRScheduler() + else: + model = ClassificationModel() + trainer = Trainer( + default_root_dir=tmp_path, + precision=precision, + max_epochs=2, + accelerator="gpu", devices=2, + gradient_clip_val=gradient_clip_val, + strategy='ddp', + ) + trainer.fit(model, datamodule=dm) + return model.update_history + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize("gradient_clip_val", [None, 0.1]) # 0.1 is chosen to only clip the first update +@pytest.mark.parametrize("with_lr_scheduler", [False, True]) +def test_correctness(tmp_path, gradient_clip_val, with_lr_scheduler): + def _merge_results(returns): + results = [returns[i][0] for i in range(len(returns))] + fullmaps = [returns[i][1] for i in range(len(returns))] + weight_results = [] + grad_results = [] + for i in range(len(results[0])): + weight_results.append( + nnscaler.runtime.module.ParallelModule.merge_state_dicts( + fullmaps, + [result[i][1] for result in results] + )[0] + ) + grad_results.append( + nnscaler.runtime.module.ParallelModule.merge_state_dicts( + fullmaps, + [result[i][0] for result in results] + )[0] + ) + return weight_results, grad_results + + # Test 16-mixed with and without gradient clipping + # when gradient clipping is on, the following check will fail + # TODO: fix the test when gradient clipping is on + if not gradient_clip_val: + ddp_results = launch_torchrun(2, correctnes_worker_ddp, tmp_path, gradient_clip_val, with_lr_scheduler, '16-mixed') + + nnscaler_returns = launch_torchrun(2, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler, '16-mixed', False, True) + nnscaler_merged_weight_results_fp16, nnscaler_merged_grad_results_fp16 = _merge_results(nnscaler_returns) + + for i in range(len(ddp_results[0])): + assert_close(nnscaler_merged_weight_results_fp16[i], ddp_results[0][i][1]) + assert_close(nnscaler_merged_grad_results_fp16[i], ddp_results[0][i][0]) + assert_equal(ddp_results[1][i], ddp_results[0][i]) + + nnscaler_returns = launch_torchrun(2, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler) + nnscaler_merged_weight_results, nnscaler_merged_grad_results = _merge_results(nnscaler_returns) + + nnscaler_returns = launch_torchrun(2, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler, '32-true', False, True) + nnscaler_merged_weight_results_scaler, nnscaler_merged_grad_results_scaler = _merge_results(nnscaler_returns) + for i in range(len(nnscaler_merged_weight_results_scaler)): + assert_equal(nnscaler_merged_weight_results[i], nnscaler_merged_weight_results_scaler[i]) + assert_equal(nnscaler_merged_grad_results[i], nnscaler_merged_grad_results_scaler[i]) + + ddp_results = launch_torchrun(2, correctnes_worker_ddp, tmp_path, gradient_clip_val, with_lr_scheduler) + for i in range(len(ddp_results[0])): + assert_close(nnscaler_merged_weight_results[i], ddp_results[0][i][1]) + assert_close(nnscaler_merged_grad_results[i], ddp_results[0][i][0]) + assert_equal(ddp_results[1][i], ddp_results[0][i]) + + if torch.cuda.device_count() >= 4: + nnscaler_returns = launch_torchrun(4, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler, '32-true', True) + nnscaler_merged_weight_results, nnscaler_merged_grad_results = _merge_results(nnscaler_returns) + + for i in range(len(ddp_results[0])): + assert_close(nnscaler_merged_weight_results[i], ddp_results[0][i][1]) + assert_close(nnscaler_merged_grad_results[i], ddp_results[0][i][0]) diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 5dbbb752..1ff3456e 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -133,3 +133,19 @@ def assert_equal(a: Any, b: Any): assert_equal(a[i], b[i]) else: assert a == b + + +def assert_close(a: Any, b: Any, atol=1e-6, rtol=1e-6): + assert type(a) == type(b) + if isinstance(a, torch.Tensor): + assert torch.allclose(a.cpu(), b.cpu(), atol=atol, rtol=rtol) + elif isinstance(a, dict): + assert len(a) == len(b) + for k in a.keys(): + assert_close(a[k], b[k]) + elif isinstance(a, (list, tuple)): + assert len(a) == len(b) + for i in range(len(a)): + assert_close(a[i], b[i]) + else: + raise ValueError(f'unsupported type {type(a)}') \ No newline at end of file From 98d57c86ac6731152b9198529376e95b79a08a9a Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Mon, 24 Jun 2024 03:23:54 +0000 Subject: [PATCH 1663/1892] Merged PR 2183: fix cache dir fix cache dir is a str, no exists() function --- examples/huggingface_nlp/compile_hf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/huggingface_nlp/compile_hf.py b/examples/huggingface_nlp/compile_hf.py index 676d9bab..c48a40bc 100644 --- a/examples/huggingface_nlp/compile_hf.py +++ b/examples/huggingface_nlp/compile_hf.py @@ -404,11 +404,12 @@ def parse_arguments() -> argparse.Namespace: args = parse_arguments() if isinstance(args.log_dir, str): args.log_dir = pathlib.Path(os.path.expanduser(args.log_dir)) + args.cache_dir = pathlib.Path(os.path.expanduser(args.cache_dir)) nnscaler.init() if torch.distributed.get_rank() == 0: - if not args.log_dir.exists(): + if args.log_dir and not args.log_dir.exists(): args.log_dir.mkdir() - if not args.cache_dir.exists(): + if args.cache_dir and not args.cache_dir.exists(): args.cache_dir.mkdir() # load error dict From 312539ed24c10835948ea5c82966376f67b96672 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 24 Jun 2024 07:18:23 +0000 Subject: [PATCH 1664/1892] Merged PR 2179: Refine follow logic in autodist After this PR, the `follow` logic in autodist is - the father (followed) op should not contain a sum dim (linear is defined as a sum op, since it has a sum dim in computation) - a unary op (like GeLU) will try to follow its producer - if a op's inputs are from multiple producers (like add, concat), it will follow the 1st producer if the producers are in a same `follow region`. Update the test case to elaborate this PR. Fix the bug in dp solver when computing the in edges for a dp node. --- nnscaler/autodist/autodist_config.py | 2 +- nnscaler/autodist/cube_operator.py | 15 +- nnscaler/autodist/dp_solver.cpp | 24 +-- nnscaler/autodist/spmd_solver.py | 66 ++++--- tests/autodist/spmd_solver/test_follow.py | 81 ++++---- .../comp/_operator.neg.json | 108 +++++++++++ .../comp/nnscaler.runtime.function.cat.json | 91 +++++++++ .../nnscaler.runtime.function.fullslice.json | 123 ++++++++++++ .../comp/torch.Tensor.view.json | 50 +++++ .../comp/torch.add.json | 130 +++++++++++++ .../comp/torch.mul.json | 170 ++++++++++++++++ .../comp/torch.nn.functional.linear.json | 182 ++++++++++++++++++ .../comp/torch.sum.json | 120 ++++++++++++ .../comp/torch.transpose.json | 62 ++++++ .../comp/torch.unsqueeze.json | 36 ++++ 15 files changed, 1183 insertions(+), 77 deletions(-) create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.Tensor.view.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.nn.functional.linear.json create mode 100644 tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.transpose.json diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 87d165be..c925890b 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -98,7 +98,7 @@ class AutoDistConfig: - max_pipeline_unbalance_ratio (`float`, *optional*, defaults to `0.5`): The maximum unbalance ratio in pipeline parallelism. The higher the ratio, the more unbalance is required, the smaller search space will be explored. - - solver (`str`, *optional*, defaults to `'dp'`): + - solver (`str`, *optional*, defaults to `'ilp'`): The solver to use in spmd parallelism. Currently only support `'dp'` (dynamic programming) `'ilp'` (integer linear programming). diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py index 24a987e9..fc69eb0e 100644 --- a/nnscaler/autodist/cube_operator.py +++ b/nnscaler/autodist/cube_operator.py @@ -29,11 +29,12 @@ def __init__(self, ir_cell: IRFwOperation): self.in_tensors, self.out_tensors = [], [] self.op_name = self.ir_cell.signature - self.producers: Set[CubeOperator] = set() - self.consumers: Set[CubeOperator] = set() + self.producers: List[CubeOperator] = list() + self.consumers: List[CubeOperator] = list() self.dim_info = {} self.parallelable_dims = set() + self._has_sum_dim = False self._recompute = False self._recompute_start_op = False @@ -55,6 +56,10 @@ def __init__(self, ir_cell: IRFwOperation): self.collect_anno_info() + @property + def has_sum_dim(self): + return self._has_sum_dim + @property def recompute(self): return self._recompute @@ -72,10 +77,10 @@ def recompute_start_op(self, value: bool): self._recompute_start_op = value def add_producer(self, producer: 'CubeOperator'): - self.producers.add(producer) + self.producers.append(producer) def add_consumer(self, consumer: 'CubeOperator'): - self.consumers.add(consumer) + self.consumers.append(consumer) def collect_anno_info(self): for idx_shape, shape_anno in enumerate(self.ir_cell.anno.inputs()): @@ -86,6 +91,8 @@ def collect_anno_info(self): reduce_type = dim_anno.reduces[idx_id] if reduce_type != DimAnno.ReduceType.Freeze: self.parallelable_dims.add(identifier) + if reduce_type == DimAnno.ReduceType.Sum: + self._has_sum_dim = True val = (idx_shape, idx_dim, idx_id, reduce_type) if identifier not in self.dim_info: self.dim_info[identifier] = val diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index 8e3826b8..51da4e76 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -110,13 +110,13 @@ void ThreadPool::waitFinished() { const int MAX_CONCURRENCY = std::thread::hardware_concurrency(); ThreadPool pool(MAX_CONCURRENCY); -std::vector> split_work(int num) { +std::vector> split_work(int num, int base) { std::vector work; - if (num < MAX_CONCURRENCY) { + if (num < base) { work = std::vector(num, 1); } else { - work = std::vector(MAX_CONCURRENCY, num / MAX_CONCURRENCY); - for (int i = 0; i < num % MAX_CONCURRENCY; ++i) { + work = std::vector(base, num / base); + for (int i = 0; i < num % base; ++i) { work[i] += 1; } } @@ -463,14 +463,11 @@ class DPSolver { find_existing_follow = true; // update if (tmp->id < producer->id) { - for (int _ = 0; _ < producer->p_num; ++_) { - if (producer->p_father[_] == - tmp->p_father[cur_ir[i].second]) { - // replace to align with the filter logic in python - // only the newest node in the follow chain is kept - cur_ir[i] = std::make_pair(producer->id, _); - break; - } + if (tmp->p_father[cur_ir[i].second] != + producer->p_father[producer_p]) { + is_legal = false; + } else { + cur_ir[i] = std::make_pair(producer_id, producer_p); } } break; @@ -642,8 +639,7 @@ class DPSolver { << ", state num: " << iter->second->dp_nodes.size() << std::endl; } - std::vector> split_info = - split_work(iter->second->dp_num); + std::vector> split_info = split_work(iter->second->dp_num, MAX_CONCURRENCY); for (const auto &item : split_info) { pool.enqueue([=] { for (int i = 0; i < item.second; ++i) { diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index b2027e2e..f491e8e4 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -14,6 +14,7 @@ import yaml import numpy import logging +import functools from dataclasses import dataclass, asdict from collections import defaultdict from pathlib import Path @@ -444,17 +445,26 @@ def build_following_relationships(self): self.follow_ids = list(range(self.graph.op_num)) self.father_ids = list(range(self.graph.op_num)) + def follow(op_idx: int, follow_idx: int): + # an operator is not allowed to be followed if it has a sum dimension + # this constraint is added to make sure the output shapes of the partitions + # are different. + if not self.get_operator(follow_idx).has_sum_dim: + self.follow_ids[op_idx] = follow_idx + self.father_ids[op_idx] = self.get_father_id(follow_idx) + for i, op in enumerate(self.graph.operator_list): - # - op consumes tensors from only one producer - # - op has only one input tensor - # - the producer has only one input tensor if len(self.producers[i]) == 1: + # consumes tensors from only one producer and has only one input tensor if len(op.in_tensors) == 1: - j = self.producers[i][0] - # constrain the following chain starts from a unary operator - if len(self.graph.operator_list[j].in_tensors) == 1: - self.follow_ids[i] = j - self.father_ids[i] = self.get_father_id(j) + follow(i, self.producers[i][0]) + elif not op.in_tensors and isinstance(op.ir_cell, IRDimops): + raise RuntimeError(f'find operator {op.ir_cell} has producer but no input tensor') + elif len(self.producers[i]) > 1: + producer_father_ids = [self.get_father_id(j) for j in self.producers[i]] + # all producers have the same father + if len(set(producer_father_ids)) == 1: + follow(i, self.producers[i][0]) _logger.info('finish building following relationships') @@ -513,20 +523,19 @@ def calc_father4op_partition(): self.get_op_partition_count(i)))) father_id2preserved_pids[i] = set(p_fathers[-1]) else: - cur_p_fathers = [-1] * self.get_op_partition_count(i) + # store the candidate father idxs for each partition + cur_p_fathers = [-1 for _ in range(self.get_op_partition_count(i))] for producer in self.producers[i]: if self.get_father_id(producer) != fi: continue # assume there is only one tensor from producer to consumer idx_map = find_idx_map(self.get_operator(producer), self.get_operator(i)) - if len(idx_map) != 1: - raise RuntimeError( - f'find multiple or no idx_map {idx_map}') + if not idx_map: + raise RuntimeError(f'find no idx_map {idx_map} between {self.get_operator(producer)} and {self.get_operator(i)}') u, v = idx_map[0] for j, tgt_p in enumerate(self._op_partitions[i]): - have_changed = False - p_father = -1 + p_father = [] for k, src_p in enumerate( self._op_partitions[producer]): # use shape to check follow relationship between partitions @@ -534,17 +543,18 @@ def calc_father4op_partition(): if src_p.ir_cell.outputs()[u].shape == tgt_p.ir_cell.inputs()[v].shape and \ not src_p.is_partial_val: p_producer = p_fathers[producer][k] - if p_producer == -1: - p_father = -1 - else: - if not have_changed: - p_father = p_producer - have_changed = True - # if p_father = -1, this partition will be filtered out - if cur_p_fathers[j] != -1: - assert p_father == cur_p_fathers[ - j], f'{i} {self.get_operator(i).ir_cell} {fi} {self.get_operator(fi).ir_cell}' - cur_p_fathers[j] = p_father + if p_producer != -1: + p_father.append(p_producer) + if cur_p_fathers[j] == -1: + cur_p_fathers[j] = p_father + else: + cur_p_fathers[j] = list(set(cur_p_fathers[j]).intersection(set(p_father))) + for j in range(self.get_op_partition_count(i)): + if not cur_p_fathers[j]: + cur_p_fathers[j] = -1 + else: + assert len(cur_p_fathers[j]) == 1, f'unexpected partition {self.get_operator(i).ir_cell}, {cur_p_fathers[j]}' + cur_p_fathers[j] = cur_p_fathers[j][0] p_fathers.append(cur_p_fathers) # -1 will be filtered out in the intersection operation below father_id2preserved_pids[fi] = father_id2preserved_pids[ @@ -674,6 +684,7 @@ def calc_partition_cost(self, op_idx: int, partition_idx: int): def calc_partition_info(self): self.partition_info: List[List[PartitionCostDesc]] = list() + state_num = 0 for i in range(self.graph.op_num): cur_info = [] _logger.debug(f'calc partition info for {self.get_operator(i)}') @@ -687,6 +698,11 @@ def calc_partition_info(self): cur_info.append(cost_desc) _logger.debug(f'{self._op_partitions[i][j]} {cost_desc}') self.partition_info.append(cur_info) + cut_partition_cnts = [self.get_op_partition_count(idx) for idx in self.cut_ops[i]] + cur_state_num = functools.reduce(lambda x, y: x * y, cut_partition_cnts, 1) + state_num += cur_state_num + _logger.debug(f'{i}-th operator follow {self.get_father_id(i)} with cut ops {self.cut_ops[i]}, {cut_partition_cnts}, {cur_state_num}') + _logger.info(f'total state num is {state_num}') _logger.info('finish spmd solver initializetion') def estimate_min_mem(self, start: int, end: int) -> int: diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index 03221091..5dea60ef 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -17,7 +17,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=0): cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) @@ -27,10 +27,20 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): class Model(torch.nn.Module): - def __init__(self): + def __init__(self, head_num, hidden_dim): super().__init__() + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.head_num = head_num + self.hidden_dim = hidden_dim + self.head_dim = hidden_dim // head_num - def forward(self, q, k, cos, sin, position_ids): + def forward(self, x, cos, sin, position_ids): + bsz, seq_len, hidden_dim = x.shape + q = self.q_proj(x) + k = self.k_proj(x) + q = q.view(bsz, seq_len, self.head_num, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.head_num, self.head_dim).transpose(1, 2) q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) out = q + k return out.sum() @@ -38,15 +48,17 @@ def forward(self, q, k, cos, sin, position_ids): @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA unavailable') def test_follow_rope(): - bsz, seq_len, hidden_dim = 2, 128, 512 + from nnscaler.ir.unique import IDGenerator + IDGenerator().clear() + bsz, seq_len, head_num, hidden_dim = 2, 128, 8, 512 + head_dim = hidden_dim // head_num dummy_input = { - 'q': torch.rand(bsz, 1, seq_len, hidden_dim), - 'k': torch.rand(bsz, 1, seq_len, hidden_dim), - 'cos': torch.rand(seq_len, hidden_dim), - 'sin': torch.rand(seq_len, hidden_dim), + 'x': torch.rand(bsz, seq_len, hidden_dim), + 'cos': torch.rand(seq_len, head_dim), + 'sin': torch.rand(seq_len, head_dim), 'position_ids': torch.arange(seq_len, dtype=torch.long), } - model = Model() + model = Model(head_num, hidden_dim) model.train() fx_graph = to_fx_graph(model, dummy_input) @@ -57,28 +69,31 @@ def test_follow_rope(): constant_folding=True) ''' the computation graph is as follows: - getitem getitem - | | - unsqueeze unsqueeze - | \ | - | -------------------------------------------------mul - | | | - mul fullsclie fullslice | fullsclie fullslice | - | \ | | \ | | - | \ neg | \ neg | - | \ | | \ | | - | concat | concat | - | | | | | - add-----------mul-------------------------mul----------add - | | - | | - ---------------------------add-------------------------- - | - sum + q_proj fullslice fullslice k_proj + | | | | + view unsqueeze unsqueeze view + | | \ | | + transpose | -------------------------------------------------mul----transpose + | | | | + ------------mul fullsclie fullslice | fullsclie fullslice | + | \ | | \ | | + | \ neg | \ neg | + | \ | | \ | | + | concat | concat | + | | | | | + add-----------mul-------------------------mul----------add + | | + | | + ---------------------------add-------------------------- + | + sum currently, the following chain is only composed of unary ops there are 2 chains in total: - 1. fullslice -> neg - 2. fullslice -> neg + 1. view -> transpose -> fullslice -> fullslice -> neg -> concat + 2. view -> transpose -> fullslice -> fullslice -> neg -> concat + 3. fullslice -> unsqueeze + 4. fullslice -> unsqueeze + 5. add -> sum in future, we may add follow chains for binary ops, like mul, add, etc. ''' @@ -95,14 +110,15 @@ def test_follow_rope(): ) assert spmd_solver.follow_ids == [ - 0, 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 11, 12, 13, 13, 15, 16, 17, 18, 19 + 0, 1, 2, 2, 4, 4, 6, 6, 8, 8, 10, 3, 3, 12, 11, 15, 16, 17, 5, 5, 19, 18, 22, 23, 24, 24 ] partition_counts = [ spmd_solver.get_op_partition_count(i) for i in range(model_graph.op_num) ] - assert partition_counts[6] == partition_counts[7] - assert partition_counts[13] == partition_counts[14] + chains = [[2, 3, 11, 12, 13, 14], [4, 5, 18, 19, 20, 21], [2, 3], [4, 5], [24, 25]] + for chain in chains: + assert all(partition_counts[i] == partition_counts[chain[0]] for i in chain) class Attention(torch.nn.Module): @@ -354,4 +370,3 @@ def helper(search_out): ilp_spmd_outs = spmd_solver.do_ilp([(0, model_graph.op_num - 1)], 1) assert helper(dp_spmd_outs) == expected_out assert helper(ilp_spmd_outs) == expected_out - diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/_operator.neg.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/_operator.neg.json index 2f1a93e2..edad6b2b 100644 --- a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/_operator.neg.json +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/_operator.neg.json @@ -46,5 +46,113 @@ "infer_memory": 131072, "train_mem_info": [], "train_mem2in_idx": [] + }, + "(2, 1, 128, 256)-(2, 1, 128, 256) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01745130866765976, + "bw_span": 0.05410369485616684, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 1, 128, 256)-(1, 1, 128, 256) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01782551407814026, + "bw_span": 0.06865318864583969, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 64, 256)-(2, 1, 64, 256) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.019250623881816864, + "bw_span": 0.03654696047306061, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 128, 128)-(2, 1, 128, 128) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01733209937810898, + "bw_span": 0.05305930972099304, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 32)-(2, 8, 128, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.017808005213737488, + "bw_span": 0.05735121667385101, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 32)-(1, 8, 128, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.017965584993362427, + "bw_span": 0.054376013576984406, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 32)-(2, 4, 128, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01771729439496994, + "bw_span": 0.053922832012176514, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 32)-(2, 8, 64, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.015980377793312073, + "bw_span": 0.044227391481399536, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 16)-(2, 8, 128, 16) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.017199106514453888, + "bw_span": 0.0558655709028244, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] } } \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.cat.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.cat.json index 781eeb37..bd739dc9 100644 --- a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.cat.json +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.cat.json @@ -37,5 +37,96 @@ "infer_memory": 262144, "train_mem_info": [], "train_mem2in_idx": [] + }, + "(2, 1, 128, 256)-(2, 1, 128, 256)-(2, 1, 128, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.02269688993692398, + "bw_span": 0.07137507200241089, + "infer_memory": 1048576, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 1, 128, 256)-(1, 1, 128, 256)-(1, 1, 128, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 131072, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.016112439334392548, + "bw_span": 0.07224995642900467, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 64, 256)-(2, 1, 64, 256)-(2, 1, 64, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 131072, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.015472806990146637, + "bw_span": 0.04730988293886185, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 32)-(2, 8, 128, 32)-(2, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01870095729827881, + "bw_span": 0.07958691567182541, + "infer_memory": 1048576, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 32)-(1, 8, 128, 32)-(1, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 131072, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.022658519446849823, + "bw_span": 0.07670298218727112, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 32)-(2, 4, 128, 32)-(2, 4, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 131072, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.022013112902641296, + "bw_span": 0.07071848958730698, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 32)-(2, 8, 64, 32)-(2, 8, 64, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 131072, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.027807801961898804, + "bw_span": 0.07438678294420242, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] } } \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.fullslice.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.fullslice.json index 97036417..2d982aa1 100644 --- a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.fullslice.json +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/nnscaler.runtime.function.fullslice.json @@ -73,5 +73,128 @@ "infer_memory": 0, "train_mem_info": [], "train_mem2in_idx": [] + }, + "(2, 1, 128, 512)-(2, 1, 128, 256) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.017194636166095734, + "bw_span": 0.12950021773576736, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 1, 128, 512)-(1, 1, 128, 256) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.017298385500907898, + "bw_span": 0.15945546329021454, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 1, 64, 512)-(2, 1, 64, 256) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01769084483385086, + "bw_span": 0.15969909727573395, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(128, 64)-(128,)-(128, 64) : torch.float32-torch.int64-torch.float32 : False-False-False": { + "in_mem_info": [ + 32768, + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.015087053179740906, + "bw_span": 0.0, + "infer_memory": 66560, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 64)-(2, 8, 128, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.016987882554531097, + "bw_span": 0.13937149196863174, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(128, 32)-(128,)-(128, 32) : torch.float32-torch.int64-torch.float32 : False-False-False": { + "in_mem_info": [ + 16384, + 1024 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.014596432447433472, + "bw_span": 0.0, + "infer_memory": 33792, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(128, 64)-(64,)-(64, 64) : torch.float32-torch.int64-torch.float32 : False-False-False": { + "in_mem_info": [ + 32768, + 512 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013903342187404633, + "bw_span": 0.0, + "infer_memory": 49664, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 64)-(1, 8, 128, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.018783099949359894, + "bw_span": 0.1173781231045723, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 64)-(2, 4, 128, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012349337339401245, + "bw_span": 0.09924005717039108, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 64)-(2, 8, 64, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.017327629029750824, + "bw_span": 0.14248937368392944, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] } } \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.Tensor.view.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.Tensor.view.json new file mode 100644 index 00000000..72b2b558 --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.Tensor.view.json @@ -0,0 +1,50 @@ +{ + "(2, 128, 512)-(2, 128, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.010309554636478424, + "bw_span": 0.03663599491119385, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 512)-(1, 128, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00978522002696991, + "bw_span": 0.03732983022928238, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 512)-(2, 64, 8, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.012000463902950287, + "bw_span": 0.037534162402153015, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 256)-(2, 128, 4, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.00963546335697174, + "bw_span": 0.034671276807785034, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.add.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.add.json index 4e0f6057..fe21819a 100644 --- a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.add.json +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.add.json @@ -63,5 +63,135 @@ "infer_memory": 33554432, "train_mem_info": [], "train_mem2in_idx": [] + }, + "(2, 128, 128, 512)-(2, 128, 128, 512)-(2, 128, 128, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 67108864, + 67108864 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.2963786944746971, + "bw_span": 0.5899650976061821, + "infer_memory": 201326592, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 128, 512)-(1, 128, 128, 512)-(1, 128, 128, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 33554432, + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.15056170523166656, + "bw_span": 0.29656700789928436, + "infer_memory": 100663296, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 128, 512)-(2, 64, 128, 512)-(2, 64, 128, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 33554432, + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.15017595142126083, + "bw_span": 0.2963421866297722, + "infer_memory": 100663296, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 64, 512)-(2, 128, 64, 512)-(2, 128, 64, 512) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 33554432, + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.1500900834798813, + "bw_span": 0.2969391644001007, + "infer_memory": 100663296, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 128, 256)-(2, 128, 128, 256)-(2, 128, 128, 256) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 33554432, + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.1501079648733139, + "bw_span": 0.2966005355119705, + "infer_memory": 100663296, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 64)-(2, 8, 128, 64)-(2, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 524288, + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01867692917585373, + "bw_span": 0.05651172250509262, + "infer_memory": 1572864, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 64)-(1, 8, 128, 64)-(1, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.011582672595977783, + "bw_span": 0.035241805016994476, + "infer_memory": 786432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 64)-(2, 4, 128, 64)-(2, 4, 128, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.01173894852399826, + "bw_span": 0.03551803529262543, + "infer_memory": 786432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 64)-(2, 8, 64, 64)-(2, 8, 64, 64) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.019761547446250916, + "bw_span": 0.05498770624399185, + "infer_memory": 786432, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 32)-(2, 8, 128, 32)-(2, 8, 128, 32) : torch.float32-torch.float32-torch.float32 : True-True-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.011792778968811035, + "bw_span": 0.03712214529514313, + "infer_memory": 786432, + "train_mem_info": [], + "train_mem2in_idx": [] } } \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.mul.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.mul.json index 3e48818b..4ad0aa0e 100644 --- a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.mul.json +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.mul.json @@ -63,5 +63,175 @@ "infer_memory": 33554432, "train_mem_info": [], "train_mem2in_idx": [] + }, + "(2, 1, 128, 512)-(128, 1, 512)-(2, 128, 128, 512) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 524288, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.10940209031105042, + "bw_span": 0.3049662336707115, + "infer_memory": 67895296, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(1, 1, 128, 512)-(128, 1, 512)-(1, 128, 128, 512) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.05617942661046982, + "bw_span": 0.157836452126503, + "infer_memory": 34078720, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(2, 1, 64, 512)-(128, 1, 512)-(2, 128, 64, 512) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 262144, + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.058712996542453766, + "bw_span": 0.15809480100870132, + "infer_memory": 34078720, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(2, 1, 128, 256)-(128, 1, 256)-(2, 128, 128, 256) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 262144, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.056887976825237274, + "bw_span": 0.15803445130586624, + "infer_memory": 33947648, + "train_mem_info": [ + 131072 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(2, 1, 128, 512)-(64, 1, 512)-(2, 64, 128, 512) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 524288, + 131072 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.05685463547706604, + "bw_span": 0.1590479165315628, + "infer_memory": 34209792, + "train_mem_info": [ + 131072 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(2, 8, 128, 64)-(1, 128, 64)-(2, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 524288, + 32768 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.02037864178419113, + "bw_span": 0.06314236670732498, + "infer_memory": 1081344, + "train_mem_info": [ + 32768 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(1, 8, 128, 64)-(1, 128, 64)-(1, 8, 128, 64) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 262144, + 32768 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.014657899737358093, + "bw_span": 0.042913854122161865, + "infer_memory": 557056, + "train_mem_info": [ + 32768 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(2, 4, 128, 64)-(1, 128, 64)-(2, 4, 128, 64) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 262144, + 32768 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.0200541689991951, + "bw_span": 0.06307531148195267, + "infer_memory": 557056, + "train_mem_info": [ + 32768 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(2, 8, 64, 64)-(1, 64, 64)-(2, 8, 64, 64) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 262144, + 16384 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.011976435780525208, + "bw_span": 0.03905370831489563, + "infer_memory": 540672, + "train_mem_info": [ + 16384 + ], + "train_mem2in_idx": [ + 1 + ] + }, + "(2, 8, 128, 32)-(1, 128, 32)-(2, 8, 128, 32) : torch.float32-torch.float32-torch.float32 : True-False-True": { + "in_mem_info": [ + 262144, + 16384 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.02042073756456375, + "bw_span": 0.0607198104262352, + "infer_memory": 540672, + "train_mem_info": [ + 16384 + ], + "train_mem2in_idx": [ + 1 + ] } } \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.nn.functional.linear.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.nn.functional.linear.json new file mode 100644 index 00000000..85ab103e --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.nn.functional.linear.json @@ -0,0 +1,182 @@ +{ + "(2, 1, 128, 512)-(512, 512)-(2, 1, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.025295652449131012, + "bw_span": 0.07937606424093246, + "infer_memory": 2097152, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 1, 128, 512)-(512, 512)-(1, 1, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.04148464649915695, + "bw_span": 0.1379936933517456, + "infer_memory": 1572864, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 1, 64, 512)-(512, 512)-(2, 1, 64, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.04674382507801056, + "bw_span": 0.12613851577043533, + "infer_memory": 1572864, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 1, 128, 256)-(512, 256)-(2, 1, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.03477931022644043, + "bw_span": 0.10643582791090012, + "infer_memory": 1310720, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 1, 128, 512)-(256, 512)-(2, 1, 128, 256) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.04062950611114502, + "bw_span": 0.12806560844182968, + "infer_memory": 1310720, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 512)-(512, 512)-(2, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.03484264016151428, + "bw_span": 0.10529998689889908, + "infer_memory": 2097152, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(1, 128, 512)-(512, 512)-(1, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.04092026501893997, + "bw_span": 0.14705508947372437, + "infer_memory": 1572864, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 64, 512)-(512, 512)-(2, 64, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 1048576 + ], + "buffer_mem_info": [], + "fw_span": 0.04442241042852402, + "bw_span": 0.11804904788732529, + "infer_memory": 1572864, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 256)-(512, 256)-(2, 128, 512) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.04016123712062836, + "bw_span": 0.1149836927652359, + "infer_memory": 1310720, + "train_mem_info": [ + 262144 + ], + "train_mem2in_idx": [ + 0 + ] + }, + "(2, 128, 512)-(256, 512)-(2, 128, 256) : torch.float32-torch.float32-torch.float32 : False-True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [ + 524288 + ], + "buffer_mem_info": [], + "fw_span": 0.03950390964746475, + "bw_span": 0.11718999594449997, + "infer_memory": 1310720, + "train_mem_info": [ + 524288 + ], + "train_mem2in_idx": [ + 0 + ] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.sum.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.sum.json index 0d67bd7d..c210e88f 100644 --- a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.sum.json +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.sum.json @@ -58,5 +58,125 @@ "infer_memory": 2048, "train_mem_info": [], "train_mem2in_idx": [] + }, + "(2, 128, 128, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 67108864 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.10197646915912628, + "bw_span": 0.19819512963294983, + "infer_memory": 67110912, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 128, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.05556941032409668, + "bw_span": 0.09961947798728943, + "infer_memory": 33556480, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 128, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.055540911853313446, + "bw_span": 0.09965654462575912, + "infer_memory": 33556480, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 64, 512)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.05510002374649048, + "bw_span": 0.1001240685582161, + "infer_memory": 33556480, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 128, 256)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 33554432 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.055274367332458496, + "bw_span": 0.10407418012619019, + "infer_memory": 33556480, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 64)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.016355328261852264, + "bw_span": 0.030515156686306, + "infer_memory": 525824, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 8, 128, 64)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013197772204875946, + "bw_span": 0.03109648823738098, + "infer_memory": 262656, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 4, 128, 64)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013214349746704102, + "bw_span": 0.031638890504837036, + "infer_memory": 262656, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 64, 64)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.021334923803806305, + "bw_span": 0.05695987492799759, + "infer_memory": 262656, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 8, 128, 32)-(1,) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.013319402933120728, + "bw_span": 0.03231540322303772, + "infer_memory": 262656, + "train_mem_info": [], + "train_mem2in_idx": [] } } \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.transpose.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.transpose.json new file mode 100644 index 00000000..6f2150bb --- /dev/null +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.transpose.json @@ -0,0 +1,62 @@ +{ + "(2, 128, 8, 64)-(2, 8, 128, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 524288 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006978213787078857, + "bw_span": 0.02502594143152237, + "infer_memory": 524288, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(1, 128, 8, 64)-(1, 8, 128, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006794556975364685, + "bw_span": 0.05167778581380844, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 64, 8, 64)-(2, 8, 64, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.006595253944396973, + "bw_span": 0.024466030299663547, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 4, 64)-(2, 4, 128, 64) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009507313370704651, + "bw_span": 0.03769509494304657, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(2, 128, 8, 32)-(2, 8, 128, 32) : torch.float32-torch.float32 : True-True": { + "in_mem_info": [ + 262144 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.009664148092269897, + "bw_span": 0.041610561311244965, + "infer_memory": 262144, + "train_mem_info": [], + "train_mem2in_idx": [] + } +} \ No newline at end of file diff --git a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.unsqueeze.json b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.unsqueeze.json index dafafa89..56237093 100644 --- a/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.unsqueeze.json +++ b/tests/autodist/spmd_solver/test_follow_rope_profile/comp/torch.unsqueeze.json @@ -34,5 +34,41 @@ "infer_memory": 0, "train_mem_info": [], "train_mem2in_idx": [] + }, + "(128, 64)-(1, 128, 64) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 32768 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005337037146091461, + "bw_span": 0.0, + "infer_memory": 32768, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(64, 64)-(1, 64, 64) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 16384 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005495548248291016, + "bw_span": 0.0, + "infer_memory": 16384, + "train_mem_info": [], + "train_mem2in_idx": [] + }, + "(128, 32)-(1, 128, 32) : torch.float32-torch.float32 : False-False": { + "in_mem_info": [ + 16384 + ], + "param_mem_info": [], + "buffer_mem_info": [], + "fw_span": 0.005407258868217468, + "bw_span": 0.0, + "infer_memory": 16384, + "train_mem_info": [], + "train_mem2in_idx": [] } } \ No newline at end of file From eeef286cf5975a25e20466bec446fd4cecb66ad9 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 25 Jun 2024 06:32:44 +0000 Subject: [PATCH 1665/1892] Merged PR 2186: hotfix: non-tensor support for consistence check in merging --- nnscaler/runtime/module.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 3445da14..b6285ef8 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -334,7 +334,13 @@ def save_checkpoint(self, optimizer: torch.optim.Optimizer = None, filename_pref }, filename) @classmethod - def _safe_tensor_equal(cls, tensor1: torch.Tensor, tensor2: torch.Tensor): + def _safe_tensor_equal(cls, tensor1: Any, tensor2: Any): + # in different versions, the data may be different types + # for example, step in optimizer.state_dict can be scalar tensor or int. + if type(tensor1) != type(tensor2): + return False + if not isinstance(tensor1, torch.Tensor): + return tensor1 == tensor2 if tensor1.shape != tensor2.shape: return False if tensor1.dtype != tensor2.dtype: From deb1d8497d21d8f7d3b38dbfb9e60c85601add36 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 25 Jun 2024 10:21:27 +0000 Subject: [PATCH 1666/1892] Merged PR 2187: Fix parity alert: forbidden to follow operators that contains attributes parity alert passed ![image.png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2187/attachments/image.png) --- nnscaler/autodist/cube_operator.py | 7 +++++++ nnscaler/autodist/spmd_solver.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py index fc69eb0e..ee867e90 100644 --- a/nnscaler/autodist/cube_operator.py +++ b/nnscaler/autodist/cube_operator.py @@ -37,6 +37,7 @@ def __init__(self, ir_cell: IRFwOperation): self._has_sum_dim = False self._recompute = False self._recompute_start_op = False + self._has_attr = False self.omit_recompute_in_idx = [] self.omit_train_idx = [] @@ -50,6 +51,8 @@ def __init__(self, ir_cell: IRFwOperation): for item in ir_cell.inputs(): if isinstance(item, IRTensor): self.in_tensors.append(item) + if item.is_attr(): + self._has_attr = True for item in ir_cell.outputs(): if isinstance(item, IRTensor): self.out_tensors.append(item) @@ -60,6 +63,10 @@ def __init__(self, ir_cell: IRFwOperation): def has_sum_dim(self): return self._has_sum_dim + @property + def has_attr(self): + return self._has_attr + @property def recompute(self): return self._recompute diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index f491e8e4..b4d417bc 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -449,7 +449,8 @@ def follow(op_idx: int, follow_idx: int): # an operator is not allowed to be followed if it has a sum dimension # this constraint is added to make sure the output shapes of the partitions # are different. - if not self.get_operator(follow_idx).has_sum_dim: + follow_op = self.get_operator(follow_idx) + if not follow_op.has_sum_dim and not follow_op.has_attr: self.follow_ids[op_idx] = follow_idx self.father_ids[op_idx] = self.get_father_id(follow_idx) From 6bb80de77014ddce0022db127312de94a52c4e9f Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Tue, 25 Jun 2024 11:56:26 +0000 Subject: [PATCH 1667/1892] Merged PR 2185: bugfix to train_mem2in_idx The index in `train_mem2in_idx` is the original index of the input of the operator, here add a mapping for the original index to the pure tensor index. This bug is found by the functions that didn't put tensor input in the front, i.e., `torch.gather`. --- nnscaler/autodist/model_graph.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 23d95a08..fc8b9792 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -745,10 +745,13 @@ def label_ops(self, operator_list: List[CubeOperator]): for i, idx in enumerate(train_mem2in_idx): if idx == -1: continue - if operator.in_tensors[idx].tid in counted_tensors: + tensor = operator.ir_cell.inputs()[idx] + assert isinstance(tensor, IRTensor), f'expect tensor, but get {type(tensor)}' + if tensor.tid in counted_tensors: operator.omit_train_idx.append(i) else: - counted_tensors.add(operator.in_tensors[idx].tid) + counted_tensors.add(tensor.tid) + # deduplicate parameter and buffer tensors # assume the traverse order of input tensors is the same as # the order in profiling From b943f8ea8f968dbeb6d63ef79f10c17803c536c2 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 26 Jun 2024 06:45:30 +0000 Subject: [PATCH 1668/1892] Merged PR 2184: parser: never fold getattr node 'self.training' parser: never fold getattr node 'self.training' unit test pass parity check pass --- docs/source/self_training.md | 102 +++++++++++ nnscaler/codegen/emit.py | 46 +++-- nnscaler/codegen/frontend_mapping.py | 20 ++- nnscaler/graph/function/function.py | 14 +- nnscaler/graph/parser/converter.py | 2 +- .../concrete_trace_utils/operator_patcher.py | 57 ++++++ nnscaler/graph/parser/fx/parser.py | 28 ++- nnscaler/runtime/function/function.py | 10 +- tests/codegen/test_emit.py | 6 +- tests/graph/parser/test_ast_transformer.py | 56 ++++++ tests/parallel_module/test_gencode.py | 170 ++++++++++++++++++ tests/utils.py | 4 + 12 files changed, 483 insertions(+), 32 deletions(-) create mode 100644 docs/source/self_training.md diff --git a/docs/source/self_training.md b/docs/source/self_training.md new file mode 100644 index 00000000..26bdf566 --- /dev/null +++ b/docs/source/self_training.md @@ -0,0 +1,102 @@ +# self.training support + +To parallelize the training process, we firstly need to trace the module and get a static computational graph. + +A common problem with static graph is that it is impossible to handle control flow. + +But on the other hand, `self.training` is very common used in module forward method. +So we add a very limited support for `self.training` in tracing. + +Please note that user code is flattened and transformed into a single `ParallelModule` at runtime, so `training` is a global module state, and we don't support the case that user want to set a sub-module's training to True but remaining modules to False. + +## `if` statement + +We don't support any control flow, so For the following code, we only put the `if` branch that is executed during tracing into the graph. + +```python +if self.training: + ... +else + ... +``` +The consequence is that model training/validation will use exactly the same code path. + +## `if` expression + +Some torch operations use `if` expression to select different parameters, for example + +```python +torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, + dropout_p=self.dropout if self.training else 0, + is_causal=self.is_causal +) +``` +To support that, we provide a limited `if` expression support, +by converting `if` expression to a function call. + +For example: + +We will convert + +```python +x = a if self.training else b +``` +to +```python +x = nnscaler.runtime.function.ifexpr(self.training, a, b) +``` + +This trick is not free. It will introduce two side effects: +1. Short-circuit evaluation is not supported. +Both branches will be evaluated, so you must make sure that both branches are valid, and have no side effect. +To reduce the side effect, we will check true expr/false expr, and requires both don't contain function calls. +so the following code will not be converted: + ```python + x = f(a) if self.training else b + ``` +2. We will convert `if` expression only if the condition is `self.training`. +So if a non-module class has a `training` attribute, the `if` expression in its member functions will also be converted if its condition is `self.training`. + +Please note you can always use `register_op` to define a custom op to handle the `if` expression. +For example, you can convert the above code to: +```python +import nnscaler +import torch + + +@nnscaler.register_op('?, ? -> ?') +def get_dropout(training, dropout): + return dropout if training else 0 + +torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, + dropout_p=get_dropout(self, self.dropout), + is_causal=self.is_causal +) +`` + +## self.training as a parameter + +If you use `self.training` as a parameter, it is well supported. + +For example: +```python +torch.nn.functional.dropout(x, 0.1, self.training) +# the generated code will be exactly the same as the original code: +# torch.nn.functional.dropout(x, 0.1, self.training) +``` + +But be careful, if you use `self.training` in a boolean operation, +the generated code may be not as you expected, because +1. We don't trace bool operations. +2. Boolean operations are short-circuit evaluated, so only one expression will be kept in generated code. + +For example: +```python +torch.nn.functional.dropout(x, 0.1, global_setting.enable_dropout or self.training) +# if global_setting.enable_dropout is True, the generated code will be +# torch.nn.functional.dropout(x, 0.1, True) +# if global_setting.enable_dropout is False, the generated code will be +# torch.nn.functional.dropout(x, 0.1, self.training) +``` diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index 2bfb9885..fec27406 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -1,6 +1,8 @@ -from typing import Generator, Iterable, List, Any, Optional, Tuple +from typing import Generator, Iterable, List, Any, Optional, Tuple, Dict import logging +import torch + from nnscaler.ir.cten import IRCell, IRTensor, IRObject from nnscaler.ir.tensor import IRSubTensor from nnscaler.ir.operator import IRDataOperation, IRFwOperation @@ -42,6 +44,8 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: Returns: the val that can be repr safely """ + if isinstance(val, IRValue): + return val if isinstance(val, IRObject): tensor_name = val.name tensor_name = tensor_name.replace('.', '_') @@ -58,11 +62,23 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: elif isinstance(val, tuple): # TODO: support subclasses of tuple, like torch.Size? return tuple(_safe_repr_value(v, prefix_attr) for v in val) - elif isinstance(val, (int, str, bool, float, type(None), bytes, type(Ellipsis))): + elif isinstance(val, (int, str, bool, float, type(None), bytes, type(Ellipsis), torch.dtype)): return val + elif isinstance(val, torch.device): + # use device string representation + # this should be rarely used + # as we will ignore device parameters. + return val.type if val.index is None else f'{val.type}:{val.index}' raise ValueError(f'Unsupported data type: {type(val)}') +def _safe_str_dict(val: Dict[str, Any], prefix_attr: Optional[str] = None) -> Dict[str, str]: + """ + Return str-able value of a dict of tensors or values. + """ + return {k: repr(_safe_repr_value(v, prefix_attr)) for k, v in val.items()} + + class CodeEmission: """ Basic emission @@ -91,6 +107,7 @@ def complex_name(self, val: Any, prefix_attr: Optional[str]=None) -> str: """ modifier = lambda t: IRValue(self.tensor_name(t, prefix_attr)) val = IRSegment.modify_objects_of_complex(val, modifier) + # TODO: use repr() instead of str() return str(val) def tuple_name(self, tensors: List[Any], @@ -109,6 +126,7 @@ def tuple_name(self, tensors: List[Any], if isinstance(t, IRTensor) and skip_attr and t.is_attr(): continue names.append(self.tensor_name(t, prefix_attr)) + # TODO: use repr() name = '(' + ', '.join(names + ['']) + ')' return name @@ -135,18 +153,26 @@ def return_name_complex(self, vals: List[Any], def kwargs_name(self, **kwargs) -> str: """Get kwarg name""" names = [] - # FIXME make the str include `""` - # for name, val in kwargs.items(): - # if isinstance(val, str) and not val.startswith('self.'): - # kwargs[name] = '"' + val + '"' # turn object into name modifier = lambda t: IRValue(self.tensor_name(t)) kwargs = IRSegment.modify_objects_of_complex(kwargs, modifier) for name, val in kwargs.items(): + # TODO: use repr() instead of str() + # names.append(f'{name}={repr(val)}') + # the problem here is current adapter prims use dtype as str for code generation + # It is too big change for now, and will fix it later. names.append(f'{name}={val}') name = ', '.join(names) return name + def kwargs_dict(self, **kwargs) -> Dict[str, str]: + """Get kwarg dict + Key is the orignial string + And value is the `repr` of the value, + so you can safely use it in the code generation + """ + return _safe_str_dict(kwargs) + class FuncEmission(CodeEmission): def __init__(self): @@ -191,13 +217,7 @@ def emit_fnode(self, node: IRFwOperation, runtime_devid: int, plan_ndevs: int, r # setup arg string inputs = [self.tensor_name(t, prefix_attr=prefix_attr) for t in node.inputs()] # setup kwarg string - kwargs = dict(**node.kwargs) - for name, val in kwargs.items(): - if isinstance(val, str) and not val.startswith('self.'): - kwargs[name] = '"' + val + '"' - # turn IRObject into name - modifier = lambda t: IRValue(self.tensor_name(t)) - kwargs = IRSegment.modify_objects_of_complex(kwargs, modifier) + kwargs = self.kwargs_dict(**node.kwargs) emit_rule = self._emit_rules.map(signature) body = emit_rule(node, inputs, kwargs, runtime_devid, plan_ndevs, runtime_ndevs) diff --git a/nnscaler/codegen/frontend_mapping.py b/nnscaler/codegen/frontend_mapping.py index 973c321a..61bd27e6 100644 --- a/nnscaler/codegen/frontend_mapping.py +++ b/nnscaler/codegen/frontend_mapping.py @@ -8,6 +8,8 @@ from nnscaler.ir.operator import IRFwOperation from nnscaler.graph.parser.register import CustomizedOps +from nnscaler.graph.parser.fx.parser import SELF_GETATTR_SIG + class Sign2EmitRule: """Emit rule for frontend PyTorch codegen""" @@ -16,7 +18,8 @@ def __init__(self) -> None: # the registered emit rules self._sign2rule = { 'torch.slice': self.emit_slice, - 'setattr': self.emit_setattr, + SELF_GETATTR_SIG: self.emit_self_getattr, + 'nnscaler.runtime.function.function.ifexpr': self.emit_ifexpr, } def map(self, signature: str) -> Callable: @@ -75,7 +78,18 @@ def emit_slice(self, node: IRFwOperation, arg_vars: List[str], kw_pairs: Dict[st return f"{in_tensor_var}[{', '.join(subscript_components)}]" - def emit_setattr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + def emit_self_getattr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: """Special rule for generating setattr node """ - assert False, f"This emit rule is deprecated, please report if you reach here" + assert len(node.inputs()) == 1, f"self_getattr should have 1 input, but got {len(node.inputs())}" + assert isinstance(node.input(0), str), f"self_getattr should have string input, but got {type(node.input(0))}" + # use node.input(0) instead of arg_vars[0] + # because we don't want to use it `repr` form + return f'self.{node.input(0)}' + + def emit_ifexpr(self, node, arg_vars: List[str], kw_pairs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule for generating setattr node + """ + assert len(node.inputs()) == 3, f"ifexpr should have 3 inputs, but got {len(node.inputs())}" + return f'{arg_vars[1]} if {arg_vars[0]} else {arg_vars[2]}' + diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index b81cf74a..d2bff238 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -860,7 +860,7 @@ def Dropout(input, p=0.5, training=True, inplace=False, signature = None): """ annos = ['* -> *'] return IRDimops(Dropout, 'dropout', signature, annos, [input], - p=p, training='self.training', inplace=inplace) + p=p, training=training, inplace=inplace) def nnDropout(input, p=0.5, inplace=False, signature=None): @@ -2242,7 +2242,7 @@ def GetItem(a: Any, b: Any, signature = None) -> Union[Any, IRPyFunc]: def SetItem(__a: Any, __b: Any, __c: Any, *additonal, signature = None) -> Union[Any, IRPyFunc]: """ _operator.setitem(__a, __b, __c) / nnscaler.runtime.function.setitem(__a, *__bc) - + If __a is a IRTensor and __b is a tuple, __b will be flatten to ensure we can give each element an annotation, and the returned value is a IRDimops. If __a is a IRObject, the returned value is a IRPyFunc. @@ -2627,7 +2627,7 @@ def Type(tensor: IRTensor, dtype: Optional[Union[str, torch.dtype, IRObject]] = raise ValueError("Expected 'out' to be None") annos = ['* -> *'] original_dtype = dtype - dtype = _unwrap_value(dtype) + dtype = _unwrap_value(dtype) if dtype is None: return IRPyFunc(signature,[tensor], [IRObject(value=str(tensor.dtype))]) else: @@ -2666,9 +2666,9 @@ def Conv1D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, """ torch.nn.functional.conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor """ - if isinstance(stride, int): + if isinstance(stride, int): stride = (stride,) - if isinstance(dilation, int): + if isinstance(dilation, int): dilation = (dilation,) if isinstance(padding, str): if padding == 'same': @@ -2804,7 +2804,7 @@ def Gather(input: IRTensor, dim, index: IRTensor, sparse_grad=False, out=None, s input_anno[i] += '^' index_anno[i] += '^' else: - # TODO: Currently, this only works in static cases. + # TODO: Currently, this only works in static cases. # When dynamic shape is enabled, this partition may be incorrect. # We keep the partition here for now, and consider reporting errors that cannot be partitioned at run time in future. index_anno[i] = input_anno[i] @@ -2839,7 +2839,7 @@ def Unfold(input: IRTensor, kernel_size, dilation=1, padding=0, stride=1, signat """ if not isinstance(input, IRTensor) or len(input.shape) != 4: raise ValueError("Input must be an IRTensor with 4 dimensions, [N, C, H, W].") - + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size dilation = (dilation, dilation) if isinstance(dilation, int) else dilation padding = (padding, padding) if isinstance(padding, int) else padding diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 30b37e17..27e46701 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -79,7 +79,7 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: leaf_functions = {func: LeafFnWrapInfo([], True, None) for func in autowrap_funcs if func is not None} # get cube runtime functions - cube_rt_funcs = [cube_rt_function.anchor] + cube_rt_funcs = [cube_rt_function.anchor, cube_rt_function.ifexpr] leaf_functions.update({ func: LeafFnWrapInfo([Location(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py index 953a439f..9f9e8a34 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -42,6 +42,56 @@ class OperatorTransformer(TrackedTransformer): ast.IsNot: 'is_not', # operator.is_not ast.In: 'contains', # operator.contains } + def visit_IfExp(self, node: ast.IfExp): + # only handle self.training case + # Attribute(value=Name(id='self', ctx=Load()), attr='training', ctx=Load()) + # And the body and orelse should not contain any function call + # because we can't handle the short-circuit evaluation in if-expression + # For example, + # `x[0] if x is not None else None` will raise an error + # if we convert it to `nnscaler.runtime.function.ifexpr(x is not None, x[0], None)` + if not _orig_isinstance(node.test, ast.Attribute) \ + or not _orig_isinstance(node.test.value, ast.Name) \ + or node.test.value.id != 'self' or node.test.attr != 'training'\ + or any(_orig_isinstance(n, ast.Call) for n in ast.walk(node.body)) \ + or any(_orig_isinstance(n, ast.Call) for n in ast.walk(node.orelse)): + return self.generic_visit(node) + + self.modified = True + # convert to nnscaler.runtime.function.ifexpr(condition, true_expr, false_expr) + # Please note short-circuit evaluation is not supported in this function. + # so it is not 100% equivalent to the original if-else expression. + # TODO: support short-circuit evaluation, + # which requires to expand the condition/true_expr/false_expr inplace + # For example, currently implementation will convert: + # x = f(m) if a else g(n) + # to: + # x = nnscaler.runtime.function.ifexpr(a, f(m), g(n)) + # And the generated code will be + # t0 = f(m) + # t1 = g(n) + # x = t0 if a else t1 + # The fix should remove t0/t1, and expand them in if-expression. + return self.generic_visit( + ast.Call( + func=ast.Attribute( + attr='ifexpr', + value=ast.Attribute( + attr='function', + value = ast.Attribute( + attr='runtime', + value=ast.Name(id='nnscaler', ctx=ast.Load()), + ctx=ast.Load(), + ), + ctx=ast.Load(), + ), + ctx=ast.Load(), + ), + args=[node.test, node.body, node.orelse], + keywords=[] + ) + ) + def visit_UnaryOp(self, node: ast.UnaryOp): if _orig_isinstance(node.op, ast.Not): self.modified = True @@ -232,6 +282,13 @@ def patch_func_helper(self, func): ], level=0 ), + # equals to + # import nnscaler + ast.Import( + names=[ + ast.alias(name='nnscaler') + ] + ), *body0.body ] body0.name = func_name diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index 43791e4d..94624eb9 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -19,6 +19,10 @@ _logger = logging.getLogger(__name__) +# virtual signature for `self.` +SELF_GETATTR_SIG = 'self_getattr' + + class FxModuleParser: """ torch.fx module parser @@ -225,7 +229,11 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule 'You can register it as a customized function using nnscaler.register_op to remove this warning' _logger.warning(warning_msg) is_constant = False - ir_node = IRPyFunc(fsig, input_vals, [IRObject(frame.get_var(node.name), is_constant=is_constant)], **kwargs) + output = frame.get_var(node.name) + if not isinstance(output, IRObject): + # avoid nested IRObject + output = IRObject(name=node.name, value=output, is_constant=is_constant) + ir_node = IRPyFunc(fsig, input_vals, [output], **kwargs) if isinstance(ir_node, IRCell): module_stack = node.meta.get('nn_module_stack') @@ -279,7 +287,11 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, There are two types of get_attr, one is `FxNodeKind.PrimGetAttr` which is dealt with in this function. The other is `FxNodeKind.PrimCallFunction ` (i.e., ) which is dealt with by parse_prim_function_method. + + The object of get_attr node is always the traced module or its sub modules. + node.target is the attribute name of the object. """ + ir_nodes = [] concrete_value = FxModuleParser.fetch_attr(module, node.target) if isinstance(concrete_value, torch.Tensor): assert isinstance(concrete_value, torch.Tensor), \ @@ -300,9 +312,17 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, else: frame.set_var(node.name, exist_tensor) else: - assert not isinstance(concrete_value, torch.Tensor), f"GetAttrPrim: unexpected parameter" - frame.set_var(node.name, concrete_value) - return [] + if node.target == 'training': + # Let's just support `self.training` and ignore all other cases for now + output = IRObject(name=node.name, value=frame.get_var(node.name), is_constant=False) + ir_node = IRPyFunc(SELF_GETATTR_SIG, ['training'], [output]) + frame.set_var(node.name, output) + # never fold the IRPyFunc node + ir_nodes.append(ir_node) + else: + frame.set_var(node.name, concrete_value) + + return ir_nodes @staticmethod def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 0104d4c0..018a3efc 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Tuple, Union +from typing import Optional, List, Tuple, Union, Any import torch import torch.nn.functional as TorchF import operator @@ -13,6 +13,14 @@ def identity(tensor: torch.Tensor) -> torch.Tensor: return tensor +def ifexpr(cond: bool, true_value: Any, false_value: Any) -> Any: + """ + if expression + Please note there is no short-circuit evaluation in this function. + """ + return true_value if cond else false_value + + def anchor(name: str): """ anchor operation for graph navigation diff --git a/tests/codegen/test_emit.py b/tests/codegen/test_emit.py index cf814b5e..9ec6b894 100644 --- a/tests/codegen/test_emit.py +++ b/tests/codegen/test_emit.py @@ -1,5 +1,5 @@ import pytest -from nnscaler.codegen.emit import CodeEmission +from nnscaler.codegen.emit import CodeEmission, IRValue from nnscaler.ir.cten import IRObject from nnscaler.codegen.emit import FuncEmission from nnscaler.graph.function import Dropout @@ -19,7 +19,7 @@ def test_tensor_name(): assert repr_expr({'a': 1, 'b': IRObject('name', 111, 'value')}, 'model.') == "{'a': 1, 'b': name_111}" assert repr_expr([1], 'model.') == '[1]' assert repr_expr((1,), 'model.') == '(1,)' - + assert repr_expr((1,...), ) == '(1, Ellipsis)' with pytest.raises(ValueError): @@ -29,7 +29,7 @@ def test_tensor_name(): def test_emit_module_attr(): - dropout = Dropout(IRFullTensor([1024, 1024], requires_grad=True), p=0.5, training='self.training', signature='torch.nn.functional.dropout') + dropout = Dropout(IRFullTensor([1024, 1024], requires_grad=True), p=0.5, training=IRValue('self.training'), signature='torch.nn.functional.dropout') code = FuncEmission().emit_fnode(dropout, runtime_devid=0, plan_ndevs=1, runtime_ndevs=1) print(code) assert 'training=self.training' in code[0] diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 59625cf9..429876a3 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -11,6 +11,62 @@ transform ) + +@pytest.mark.skipif(sys.version_info < (3, 9), reason='ast.unparse is not available in python3.8') +def test_ifexpr_transfomer(): + # x = ast.parse('nnscaler.runtime.ifexpr(1, 2, 3)') + + tree = ast.parse(dedent(''' + x = 0.1 if self.training else 0.2 + ''').strip()) + transformers = [OperatorTransformer()] + modified, new_ast = transform(tree, transformers) + assert modified + assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' + x = nnscaler.runtime.function.ifexpr(self.training, 0.1, 0.2) + ''').strip() + + tree = ast.parse(dedent(''' + x = x.p if self.training else 0.2 + 0.3 + ''').strip()) + transformers = [OperatorTransformer()] + modified, new_ast = transform(tree, transformers) + assert modified + assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' + x = nnscaler.runtime.function.ifexpr(self.training, x.p, 0.2 + 0.3) + ''').strip() + + tree = ast.parse(dedent(''' + x = x.p if self.training else 0.2 + f(0.3) + ''').strip()) + transformers = [OperatorTransformer()] + modified, new_ast = transform(tree, transformers) + assert not modified + + tree = ast.parse(dedent(''' + x = f(x) if self.training else 0.2 + ''').strip()) + transformers = [OperatorTransformer()] + modified, new_ast = transform(tree, transformers) + assert not modified + + tree = ast.parse(dedent(''' + x = 0.1 if self.training else f(0.2) + ''').strip()) + transformers = [OperatorTransformer()] + modified, new_ast = transform(tree, transformers) + assert not modified + + tree = ast.parse(dedent(''' + x = f(0.1) if self.training else f(0.2) + ''').strip()) + transformers = [OperatorTransformer()] + modified, new_ast = transform(tree, transformers) + assert not modified + + + + @pytest.mark.skipif(sys.version_info < (3, 9), reason='ast.unparse is not available in python3.8') def test_op_transfomer(): tree = ast.parse(dedent(''' diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 5efff8ab..9ee06603 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -230,6 +230,16 @@ def forward(self, x, attr): return x + getattr(attr, 'a') +def print_gencode(cubesave_dir, module_class, index=0): + from nnscaler.parallel import _PARALLEL_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME + from pathlib import Path + import re + namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' + outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) + filecontent = (outdir /f'gencode{index}.py').read_text() + print(filecontent) + + def _gencode_contains(cubesave_dir, module_class, index, search_re): from nnscaler.parallel import _PARALLEL_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME from pathlib import Path @@ -240,6 +250,7 @@ def _gencode_contains(cubesave_dir, module_class, index, search_re): matches = re.findall(search_re, filecontent) return matches + class AttrHelper: def __init__(self) -> None: self.a = 2.0 @@ -760,3 +771,162 @@ def p(cube_dir, use_pipeline, constant_folding, return_type, inference_only=Fals assert _gencode_contains(tempdir, End2EndModule, 0, r"self\.register_buffer" ) + +from dataclasses import dataclass +@dataclass +class DataT: + x: int = 0 + y: int = 0 + + +class DropoutModule(torch.nn.Module): + def __init__(self): + super().__init__() + self._data = DataT() + + def forward(self, x): + x = x + self._data.x + return torch.nn.functional.dropout(x, 0.1 if self.training else 0.2, self.training) + + +@replace_all_device_with('cpu') +def test_codegen_dropout(): + """ + Test if self.training is correctly handled in the generated code + """ + with tempfile.TemporaryDirectory() as tempdir: + m = DropoutModule() + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False, + reuse='override', + ) + # it should looks like: + # add_17 = torch.add(x_20, 0, alpha=1) + # del x_20 + # training_9 = self.training + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 778, in forward, return torch.nn.functional.dropout(x, 0.1 if self.training else 0.2, self.training) + # ifexpr_4 = 0.1 if training_9 else 0.2 + # training_1_12 = self.training + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 778, in forward, return torch.nn.functional.dropout(x, 0.1 if self.training else 0.2, self.training) + # dropout_16 = torch.nn.functional.dropout(add_17, p=ifexpr_4, training=training_1_12, inplace=False) + # del add_17 + # return dropout_16 + assert _gencode_contains(tempdir, DropoutModule, 0, + r"ifexpr_\d+ = 0.1 if training_\d+ else 0.2" + ) + assert _gencode_contains(tempdir, DropoutModule, 0, + r" = torch.nn.functional.dropout\(add_\d+, p=ifexpr_\d+, training=training_1_\d+, inplace=False\)" + ) + + +@nnscaler.register_op('?, ? -> ?') +def get_dropout(training, dropout): + return dropout if training else 0.0 + + +class DropoutModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.dropout(x, get_dropout(self.training, 0.2), self.training) + + +@replace_all_device_with('cpu') +def test_codegen_dropout2(tmp_path): + """ + Test if register_op is correctly handled in the generated code + """ + m = DropoutModule2() + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # it should looks like: + # training_7 = self.training + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 838, in forward, return torch.nn.functional.dropout(x, get_dropout(self.training), self.training) + # get_dropout_3 = tests.parallel_module.test_gencode.get_dropout(training_7) + # training_1_11 = self.training + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 838, in forward, return torch.nn.functional.dropout(x, get_dropout(self.training), self.training) + # dropout_15 = torch.nn.functional.dropout(x_18, p=get_dropout_3, training=training_1_11, inplace=False) + # del x_18 + # return dropout_15 + assert _gencode_contains(tmp_path, DropoutModule2, 0, + r"= tests.parallel_module.test_gencode.get_dropout\(training_\d+" + ) + assert _gencode_contains(tmp_path, DropoutModule2, 0, + r"= torch.nn.functional.dropout\(x_\d+, p=get_dropout_\d+" + ) + + +class DictOutputModule(torch.nn.Module): + def forward(self, x): + return {'data': x + 10} + + +@replace_all_device_with('cpu') +def test_codegen_dictout(tmp_path): + m = DictOutputModule() + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # it should looks like: + # def segment9(self, x_9): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 819, in forward, return {'data': x + 10} + # add_6 = torch.add(x_9, 10, alpha=1) del x_9 + # return add_6 + + # def _forward_impl(self, x): add_6 = self.segment9(x) + # return {'data': add_6} + assert _gencode_contains(tmp_path, DictOutputModule, 0, + r"return {'data': add_\d+}" + ) + + +class KwargsModule(torch.nn.Module): + def forward(self, x): + return x + torch.zeros_like(x, dtype=torch.float32) + + +@replace_all_device_with('cpu') +def test_codegen_kwargs(tmp_path): + m = KwargsModule() + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # it should looks like: + # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 861, in forward, return x + torch.zeros_like(x, dtype=torch.float32) + # zeros_like_9 = torch.zeros_like(x_12, requires_grad=False, dtype=torch.float32) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 861, in forward, return x + torch.zeros_like(x, dtype=torch.float32) + # add_8 = torch.add(x_12, zeros_like_9, alpha=1) + # del x_12, zeros_like_9 + # return add_8 + assert _gencode_contains(tmp_path, KwargsModule, 0, + r"torch.zeros_like\(x_\d+, requires_grad=False, dtype=torch.float32\)" + ) diff --git a/tests/utils.py b/tests/utils.py index bfc7ee70..5c8f41e7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -117,6 +117,10 @@ def wrapper(*args, **kwargs): return orig_func(*args, **kwargs) wrapper.__name__ = orig_func.__name__ wrapper.__qualname__ = orig_func.__qualname__ + if hasattr(orig_func, '__module__'): + # torch.Tensor.new_empty, etc. don't have this attribute + # TODO: FxModuleParser._find_module_of_method will fail if __module__ is not set + wrapper.__module__ = orig_func.__module__ wrapper.__cube_orig_func__ = orig_func return wrapper # these constructors are enough for most cases From 04c608a4264def5f34935abf92f31bfb959f2a96 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 28 Jun 2024 01:47:15 +0000 Subject: [PATCH 1669/1892] Merged PR 2188: self.training in submodules: hotfix for nightly test self.training in submodules: hotfix for nightly test unit test pass parity check pass --- examples/huggingface_nlp/compile_hf.py | 7 ++- examples/huggingface_nlp/compile_interface.py | 11 ++-- examples/huggingface_nlp/requirements.txt | 1 + nnscaler/autodist/dp_solver.cpp | 1 + nnscaler/graph/parser/fx/parser.py | 24 ++++--- tests/autodist/spmd_solver/test_follow.py | 63 ++++++++++--------- tests/parallel_module/test_gencode.py | 44 +++++++++++++ 7 files changed, 105 insertions(+), 46 deletions(-) create mode 100644 examples/huggingface_nlp/requirements.txt diff --git a/examples/huggingface_nlp/compile_hf.py b/examples/huggingface_nlp/compile_hf.py index c48a40bc..55a01507 100644 --- a/examples/huggingface_nlp/compile_hf.py +++ b/examples/huggingface_nlp/compile_hf.py @@ -139,9 +139,9 @@ def dump_orged_errors(model_name, error_dict, log_path): error_dict[first_line]['count'] += 1 else: error_dict[first_line] = {"count": 1, 'model_name': [model_name]} #, "example": exception_string - + error_dict = dict(sorted(error_dict.items(), key=lambda item: item[1]["count"], reverse=True)) - + with open(log_path, 'w') as json_file: json.dump(error_dict, json_file, indent=4) @@ -337,6 +337,7 @@ def compile_hf_worker(self): loggers[ERROR_FNAME].error(f"{model_name} not aligned before and after compile, max diff:{max_diff}") if self.train: + self.model = self.model_loader.load_hf_model(self.config) add_logger(self.log_dir, TRAIN_FNAME, prefix="model trained: ", level = logging.INFO, need_timestamp = False) add_logger(self.log_dir, TRAIN_ALIGNED_FNAME, prefix="model train aligned: ", level = logging.INFO, need_timestamp = False) steps = 10 @@ -345,7 +346,7 @@ def compile_hf_worker(self): origin_loss = compiler.train(self.model, steps = steps) origin_logit = self.model(**compiler.dummy_input) - + loggers[TRAIN_FNAME].info(f"{model_name}") max_diff = calcu_max_diff(origin_logit, compile_logit) diff --git a/examples/huggingface_nlp/compile_interface.py b/examples/huggingface_nlp/compile_interface.py index 6d0a20f9..1e95cd66 100644 --- a/examples/huggingface_nlp/compile_interface.py +++ b/examples/huggingface_nlp/compile_interface.py @@ -50,9 +50,9 @@ def set_seed(seed): import numpy as np random.seed(seed) np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) class TraceCompileException(Exception): @@ -75,7 +75,7 @@ def __init__(self, model: torch.nn.Module, dummy_input: Dict[str, Any], policy): self.policy = policy self.model.eval() self.before_trace = self.model(**self.dummy_input) - + def forward_diff(self, model): """Compute the model's output and compare it with the original model's output""" if model is None: @@ -111,7 +111,7 @@ def parallel(self, model): pas_policy=self.policy, compute_config=ComputeConfig(1, 1), reuse='override', - load_module=True + load_module=True, ) return parallel_model except Exception as e: @@ -163,4 +163,3 @@ def export(self): raise RuntimeError("CUDA is not available") except Exception as e: raise TraceCompileException("An error occurred during export and forward the model.", e) - diff --git a/examples/huggingface_nlp/requirements.txt b/examples/huggingface_nlp/requirements.txt new file mode 100644 index 00000000..976a2b1f --- /dev/null +++ b/examples/huggingface_nlp/requirements.txt @@ -0,0 +1 @@ +transformers diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index 51da4e76..8fdfb74f 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -502,6 +502,7 @@ class DPSolver { // we select the 1st partition of the pre_node // need to be careful when the graph has multiple outputs // shall we constrain that the output of the graph is replicated? + cur_ir.push_back(*follow_candidates.rbegin()); } else if (pre_node->father_id == pre_node->id) { assert(follow_candidates.rbegin()->first == pre_node->id); cur_ir.push_back(*follow_candidates.rbegin()); diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index 94624eb9..bb7c20c6 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -235,12 +235,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule output = IRObject(name=node.name, value=output, is_constant=is_constant) ir_node = IRPyFunc(fsig, input_vals, [output], **kwargs) - if isinstance(ir_node, IRCell): - module_stack = node.meta.get('nn_module_stack') - ir_node.module_stack = module_stack - comment = str(node.meta.get('frame_record', '')) - if comment: - ir_node.comment = comment + FxModuleParser._set_node_meta(node, ir_node) ir_nodes = [] if isinstance(ir_node, IRCell): @@ -312,10 +307,13 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, else: frame.set_var(node.name, exist_tensor) else: - if node.target == 'training': + assert isinstance(node.target, str), f"GetAttrPrim: expect `node.target` to be str but got {type(node.target)}" + # in sub modules, the target is full qualified name (for example `embeddings.dropout.training`) + if node.target.split('.')[-1] == 'training': # Let's just support `self.training` and ignore all other cases for now output = IRObject(name=node.name, value=frame.get_var(node.name), is_constant=False) ir_node = IRPyFunc(SELF_GETATTR_SIG, ['training'], [output]) + FxModuleParser._set_node_meta(node, ir_node) frame.set_var(node.name, output) # never fold the IRPyFunc node ir_nodes.append(ir_node) @@ -331,6 +329,18 @@ def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, fr frame.set_var(node.name, output) return [] + @staticmethod + def _set_node_meta(node: torch.fx.Node, ir_node: Union[IRCell, Any]): + if not isinstance(ir_node, IRCell): + return + + module_stack = node.meta.get('nn_module_stack') + ir_node.module_stack = module_stack + comment = str(node.meta.get('frame_record', '')) + if comment: + ir_node.comment = comment + + @staticmethod def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: if isinstance(node_target, str): diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index 5dea60ef..fc1fc2ca 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -70,7 +70,7 @@ def test_follow_rope(): ''' the computation graph is as follows: q_proj fullslice fullslice k_proj - | | | | + | | | | view unsqueeze unsqueeze view | | \ | | transpose | -------------------------------------------------mul----transpose @@ -89,8 +89,8 @@ def test_follow_rope(): sum currently, the following chain is only composed of unary ops there are 2 chains in total: - 1. view -> transpose -> fullslice -> fullslice -> neg -> concat - 2. view -> transpose -> fullslice -> fullslice -> neg -> concat + 1. view -> transpose -> fullslice -> fullslice -> neg -> concat + 2. view -> transpose -> fullslice -> fullslice -> neg -> concat 3. fullslice -> unsqueeze 4. fullslice -> unsqueeze 5. add -> sum @@ -187,34 +187,34 @@ def test_follow_attention(): print(ir_graph.nodes()) ''' the computation graph is as follows: - linear linear linear + 2linear 3linear 4linear | | | - view view view + 5view 7view 9view | | | - transpose transpose transpose + 6transpose 8transpose 10transpose \ | | - | transpose | + | 11transpose | \ / | - matmul | + 12matmul | | | - div | + 13div | | | - softmax | - | | - dropout | +15training 14softmax | + | | | + \ ---- 16dropout | \ / \ / - matmul + 17matmul | - transpose + 18transpose | - contiguous + 19contiguous | - reshape + 20reshape | - linear + 21linear | - sum + 22sum the follow chain is as follows: 1. view -> transpose @@ -238,14 +238,14 @@ def test_follow_attention(): ) assert spmd_solver.follow_ids == [ - 0, 1, 2, 3, 3, 5, 5, 7, 7, 6, 10, 11, 11, 12, 14, 15, 15, 16, 18, 19 + 0, 1, 2, 3, 3, 5, 5, 7, 7, 6, 10, 11, 11, 13, 12, 15, 16, 16, 17, 19, 20 ] partition_counts = [ spmd_solver.get_op_partition_count(i) for i in range(model_graph.op_num) ] assert partition_counts == [ - 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 4, 4, 4, 2, 4 + 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 1, 2, 2, 4, 4, 4, 2, 4 ] # under the current partition constraints, the solver should generate # a Megatron-LM plan @@ -276,20 +276,22 @@ def test_follow_attention(): (13, (((0, 1), 2),)), # partition the head dim for softmax (14, (((0, 1), 2),)), + # replicate `training` + (15, (((-1, -1), 2),)), # partition the head dim for dropout - (15, (((0, 1), 2),)), - # partition the head dim for matmul(attn_weights, v) (16, (((0, 1), 2),)), - # partition the head dim for attn_out.transpose + # partition the head dim for matmul(attn_weights, v) (17, (((0, 1), 2),)), + # partition the head dim for attn_out.transpose + (18, (((0, 1), 2),)), # partition the head dim for contiguous - (18, (((0, 2), 2),)), - # partition the head dim for reshape (19, (((0, 2), 2),)), - # partition the input feature for o_proj + # partition the head dim for reshape (20, (((0, 2), 2),)), + # partition the input feature for o_proj + (21, (((0, 2), 2),)), # replicate the sum - (21, (((-1, -1), 2),)) + (22, (((-1, -1), 2),)) ] def helper(search_out): @@ -337,7 +339,7 @@ def test_solver_data_parallel(): ] print(partition_counts) assert partition_counts == [ - 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 6, 4, 4, 4, 6, 4, 4, 4, 5, 4 + 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 6, 4, 4, 1, 4, 6, 4, 4, 4, 5, 4 ] # should generate a pure data parallel plan, e.g., partition the batch dim expected_out = [ @@ -354,13 +356,14 @@ def test_solver_data_parallel(): (12, (((0, 0), 2),)), (13, (((0, 0), 2),)), (14, (((0, 0), 2),)), - (15, (((0, 0), 2),)), + (15, (((-1, -1), 2),)), (16, (((0, 0), 2),)), (17, (((0, 0), 2),)), (18, (((0, 0), 2),)), (19, (((0, 0), 2),)), (20, (((0, 0), 2),)), - (21, (((0, 0), 2),)) + (21, (((0, 0), 2),)), + (22, (((0, 0), 2),)) ] def helper(search_out): diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 9ee06603..dc4133a8 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -871,6 +871,50 @@ def test_codegen_dropout2(tmp_path): ) +class DropoutModuleNested(torch.nn.Module): + def __init__(self): + super().__init__() + self.dropout = DropoutModule2() + + def forward(self, x): + return self.dropout(x) + + +@replace_all_device_with('cpu') +def test_codegen_dropout_nested(tmp_path): + """ + Test if register_op is correctly handled in the generated code + """ + m = DropoutModuleNested() + m.train() + parallelize( + m, + {'x': torch.randn(128, 64)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # it should looks like: + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 838, in forward, return torch.nn.functional.dropout(x, get_dropout(self.training, 0.2), self.training) + # dropout_training_7 = self.training + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 838, in forward, return torch.nn.functional.dropout(x, get_dropout(self.training, 0.2), self.training) + # get_dropout_3 = tests.parallel_module.test_gencode.get_dropout(dropout_training_7, 0.2) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 838, in forward, return torch.nn.functional.dropout(x, get_dropout(self.training, 0.2), self.training) + # dropout_training_1_11 = self.training + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 838, in forward, return torch.nn.functional.dropout(x, get_dropout(self.training, 0.2), self.training) + # dropout_15 = torch.nn.functional.dropout(x_18, p=get_dropout_3, training=dropout_training_1_11, inplace=False) + # del x_18 + # return dropout_15 + assert _gencode_contains(tmp_path, DropoutModuleNested, 0, + r"= tests.parallel_module.test_gencode.get_dropout\(dropout_training_\d+" + ) + assert _gencode_contains(tmp_path, DropoutModuleNested, 0, + r"= torch.nn.functional.dropout\(x_\d+, p=get_dropout_\d+" + ) + + class DictOutputModule(torch.nn.Module): def forward(self, x): return {'data': x + 10} From ef2586ef7ed57840d4a92e5f7e2ec5d0a9b0aa89 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 2 Jul 2024 02:01:25 +0000 Subject: [PATCH 1670/1892] Merged PR 2189: Lightning: refine code/add more tests --- docs/source/pytorch_lightning.md | 68 +++++++++++++++++++ .../integration/lightning/pytorch/strategy.py | 35 ++++++---- nnscaler/parallel.py | 26 +++++-- nnscaler/runtime/gnorm.py | 3 + tests/cli/common.py | 2 +- .../lightning/pytorch/simple_models.py | 12 ++-- .../lightning/pytorch/test_strategy.py | 68 +++++++++++++++++++ tests/launch_torchrun.py | 2 +- 8 files changed, 188 insertions(+), 28 deletions(-) create mode 100644 docs/source/pytorch_lightning.md diff --git a/docs/source/pytorch_lightning.md b/docs/source/pytorch_lightning.md new file mode 100644 index 00000000..2dfdba85 --- /dev/null +++ b/docs/source/pytorch_lightning.md @@ -0,0 +1,68 @@ +# Pytorch Lightning support + +We support Pytorch Lightning by `NnScalerStrategy` and `NnScalerPrecision`. You can use `nnscaler` strategy in pytorch lightning like this: + +```python +compute_config=ComputeConfig(...) +policy = ... +trainer = Trainer( + ..., + strategy=NnScalerStrategy( + compute_config=compute_config, pas_policy=..., gen_savedir=..., + ... + ), + plugins=[NnScalerPrecision(precision, ...)], + ... +) +trainer.fit(...) +``` + +## Model + +### Dummy input + +We need a dummy input to trace the forward function. You can specify it in two ways: + +1. Add `dummy_forward_args` property to your model class, which should be a dictionary of forward inputs. +2. You can also add `dummy_forward_args_fn`, which will be used to convert the sample (loaded from train dataloader) to forward inputs. + +### Rewritten members + +We will rewrite two functions: +1. `forward` function: As we explained before, the `forward` function will be replaced with a distributed version. +2. `log` function: We will rewrite the `log` function to force the `sync_dist_group` to be set properly when `sync_dist=True`. + +We will also set all trainable modules to None to reduce memory usage. + +To make sure the model can be used with nnscaler strategy, you should follow these rules: + +1. All trainable parameters should only be used in forward function. +If it is used outside forward, it should be in torch.no_grad context. +Otherwise, as we don't create reduce-op outside forward, its gradient will be incorrect. +2. Train/Validate/Test should use exactly the same graph. +3. All functions replying on the trainable modules should be rewritten with forward function. +After our conversion, all those modules will be None. + +## Strategy + +The constructor argument of `NnScalerStrategy` is the combination of `Strategy`'s constructor and `nnscaler.parallize` function. You can refer to the documentation of `Strategy` and `nnscaler.parallize` for more details. + +One special argument is `state_dict_type`, which specify the format in which the state of the model and optimizers gets saved into the checkpoint. + +- `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. +The checkpoint is a folder with as many files as the local world size. +- `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the local world size. + +## Precision + +It has exactly the same constructor arguments as `Precision`'s constructor. + +Currently we support `32-true`, `16-true`, `bf16-true`, `16-mixed`, `bf16-mixed`. +You can specify a grad scaler when you use `16-true`. + + +## Limitation + +1. Only one optimizer is supported. +2. Only one lr scheduler is supported. +3. Only one parameter group is supported. diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py index 48daf229..320f3754 100644 --- a/nnscaler/integration/lightning/pytorch/strategy.py +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -269,15 +269,7 @@ def _setup_model(self, model: Module) -> Module: model.forward = pmodule.forward # patch log function to add sync_dist_group - rank = torch.distributed.get_rank() - # create all groups - plan_ngpus = self.compute_config.plan_ngpus - runtime_ngpus = self.compute_config.runtime_ngpus - for i in range(plan_ngpus): - DeviceGroup().get_group( - list(range(i, runtime_ngpus, plan_ngpus)) - ) - sync_group = list(range(rank % plan_ngpus, runtime_ngpus, plan_ngpus)) + _, sync_group = self.compute_config.create_sync_group() _old_log = model.log def _new_log(self, *args, **kwargs) -> None: @@ -384,13 +376,15 @@ def model_to_device(self) -> None: def barrier(self, name: Optional[str] = None) -> None: if not _distributed_is_initialized(): return - if torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=[self.root_device.index]) - else: - torch.distributed.barrier() + assert torch.distributed.get_backend() == "nccl", "nnscaler only supports nccl backend" + # https://github.com/pytorch/pytorch/issues/53658 + # It would be better to provide device_ids=[self.root_device.index] + torch.distributed.barrier(device_ids=[self.root_device.index]) @override def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + assert torch.distributed.get_backend() == "nccl", "nnscaler only supports nccl backend" + if not _distributed_is_initialized(): return obj @@ -417,8 +411,19 @@ def reduce( reduced value, except when the input was not a tensor the output remains is unchanged """ - if isinstance(tensor, Tensor): - return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + assert torch.distributed.get_backend() == "nccl", "nnscaler only supports nccl backend" + + if not _distributed_is_initialized() or not isinstance(tensor, Tensor): + return tensor + + op: Optional[ReduceOp] + if isinstance(reduce_op, str): + reduce_op = "avg" if reduce_op == "mean" else reduce_op + op = getattr(ReduceOp, reduce_op.upper()) + else: + op = reduce_op + + torch.distributed.all_reduce(tensor, op=op, group=group, async_op=False) return tensor @override diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 544ec81a..a14b7657 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -204,6 +204,26 @@ def optimizer_dedup_group_size(self) -> int: else: return self.plan_ngpus + def create_sync_group(self) -> Tuple[List[int], torch.distributed.ProcessGroup]: + """ + Create a sync group for the current rank. + The sync group is a group of ranks that have exactly the same weights, but different inputs, + so they should synchronize with each other to get the whole gradients/loss/etc. + + Returns: + Tuple[List[int], torch.distributed.ProcessGroup]: return the rank list of the group and its torch.distributed group + """ + rank = torch.distributed.get_rank() + # create all groups + plan_ngpus = self.plan_ngpus + runtime_ngpus = self.runtime_ngpus + for i in range(plan_ngpus): + DeviceGroup().get_group( + list(range(i, runtime_ngpus, plan_ngpus)) + ) + rank_list = list(range(rank % plan_ngpus, runtime_ngpus, plan_ngpus)) + return rank_list, DeviceGroup().get_group(rank_list) + @classmethod def safe_dump_to_file(cls, cfg: 'ComputeConfig', file: Union[str, Path]) -> None: """ @@ -1262,11 +1282,7 @@ def build_optimizer( # we need to add all parameters of non-parallel modules to a reducer to reduce grads # if there are non-parallel parameters if plan_ngpus != runtime_ngpus and non_parallel_modules and any(p.numel() for m in non_parallel_modules for p in m.parameters(False)): - rank = torch.distributed.get_rank() - # create all groups - for i in range(plan_ngpus): - DeviceGroup().get_group(list(range(i, runtime_ngpus, plan_ngpus))) - group = list(range(rank % plan_ngpus, runtime_ngpus, plan_ngpus)) + group, _ = compute_configs[0].create_sync_group() non_parallel_module_reducer = Reducer(group) for m in non_parallel_modules: for param in m.parameters(recurse=False): # only add leaf parameters to avoid duplicate diff --git a/nnscaler/runtime/gnorm.py b/nnscaler/runtime/gnorm.py index bb94b34b..8f073da3 100644 --- a/nnscaler/runtime/gnorm.py +++ b/nnscaler/runtime/gnorm.py @@ -238,6 +238,9 @@ def grad_exists(p): if multi_tensor_l2norm_available: total_norm = _multi_tensor_total_norm(grads).to(device) else: + # torch.nn.utils.clip_grad_norm_ way to calculate the norm + # norms = torch._foreach_norm(grads, 2.0) + # total_norm = torch.linalg.vector_norm(torch.stack([norm.to(device) for norm in norms]), 2.0) total_norm = torch.norm( torch.stack( [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads] diff --git a/tests/cli/common.py b/tests/cli/common.py index 6935a36e..6020579c 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -6,7 +6,7 @@ class SimpleDataset(Dataset): def __init__(self, dim: int, size: int = 100): self.data = torch.randn(size, dim) - self.target = torch.randn(size, dim) + self.target = torch.rand(size, dim) def __getitem__(self, idx: int): return { diff --git a/tests/integration/lightning/pytorch/simple_models.py b/tests/integration/lightning/pytorch/simple_models.py index d72a0c87..a481c2c5 100644 --- a/tests/integration/lightning/pytorch/simple_models.py +++ b/tests/integration/lightning/pytorch/simple_models.py @@ -57,23 +57,23 @@ def training_step(self, batch, batch_idx): x, y = batch logits = self.forward(x) loss = F.cross_entropy(logits, y) - self.log("train_loss", loss, prog_bar=True) - self.log("train_acc", self.train_acc(logits, y), prog_bar=True) + self.log("train_loss", loss, prog_bar=True, sync_dist=True) + self.log("train_acc", self.train_acc(logits, y), prog_bar=True, sync_dist=True) return {"loss": loss} def validation_step(self, batch, batch_idx): assert not self.training x, y = batch logits = self.forward(x) - self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False) - self.log("val_acc", self.valid_acc(logits, y), prog_bar=True) + self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False, sync_dist=True) + self.log("val_acc", self.valid_acc(logits, y), prog_bar=True, sync_dist=True) def test_step(self, batch, batch_idx): assert not self.training x, y = batch logits = self.forward(x) - self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False) - self.log("test_acc", self.test_acc(logits, y), prog_bar=True) + self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False, sync_dist=True) + self.log("test_acc", self.test_acc(logits, y), prog_bar=True, sync_dist=True) def predict_step(self, batch, batch_idx): assert not self.training diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index a6112c8c..fa00ee6c 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -173,6 +173,62 @@ def correctnes_worker_nnscaler(tmp_path, gradient_clip_val, with_lr_scheduler, return model.update_history, model.nnscaler_pmodule.fullmap +def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_scheduler, + precision='32-true', + with_tp=False, with_empty_scaler=False +): + init_random() + dm = ClassifDataModule() + init_random() + if with_lr_scheduler: + model = ClassificationModelWithLRScheduler() + state_dict_type = 'sharded' + else: + model = ClassificationModel() + state_dict_type = 'deduped' + if with_tp: + compute_config=ComputeConfig(2, 4) + policy = 'tp' + devices = 4 + else: + compute_config=ComputeConfig(1, 2) + policy = 'dp' + devices = 2 + scaler = None + if with_empty_scaler or precision == '16-mixed': + scaler = torch.cuda.amp.GradScaler(enabled=(precision == '16-mixed')) + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=1, + callbacks=[ModelCheckpoint(dirpath=tmp_path, save_top_k=1, save_last=True)], + accelerator="gpu", devices=devices, + gradient_clip_val=gradient_clip_val, + strategy=NnScalerStrategy( + compute_config=compute_config, pas_policy=policy, gen_savedir=tmp_path, + instance_name=policy + '_resume', + state_dict_type=state_dict_type + ), + plugins=[NnScalerPrecision(precision, scaler=scaler)] + ) + trainer.fit(model, datamodule=dm) + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=2, + callbacks=[ModelCheckpoint(dirpath=tmp_path, save_top_k=1, save_last=True)], + accelerator="gpu", devices=devices, + gradient_clip_val=gradient_clip_val, + strategy=NnScalerStrategy( + compute_config=compute_config, pas_policy=policy, gen_savedir=tmp_path, + instance_name=policy + '_resume', + state_dict_type=state_dict_type + ), + plugins=[NnScalerPrecision(precision, scaler=scaler)] + ) + trainer.fit(model, datamodule=dm, ckpt_path='last') + return model.update_history, model.nnscaler_pmodule.fullmap + + def correctnes_worker_ddp(tmp_path, gradient_clip_val, with_lr_scheduler, precision='32-true'): init_random() dm = ClassifDataModule() @@ -231,14 +287,26 @@ def _merge_results(returns): assert_close(nnscaler_merged_grad_results_fp16[i], ddp_results[0][i][0]) assert_equal(ddp_results[1][i], ddp_results[0][i]) + nnscaler_returns_ckpt = launch_torchrun(2, correctnes_worker_nnscaler_checkpoint, tmp_path, gradient_clip_val, with_lr_scheduler) + nnscaler_merged_weight_results_ckpt, nnscaler_merged_grad_results_ckpt = _merge_results(nnscaler_returns_ckpt) + nnscaler_returns = launch_torchrun(2, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler) nnscaler_merged_weight_results, nnscaler_merged_grad_results = _merge_results(nnscaler_returns) nnscaler_returns = launch_torchrun(2, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler, '32-true', False, True) nnscaler_merged_weight_results_scaler, nnscaler_merged_grad_results_scaler = _merge_results(nnscaler_returns) + + assert len(nnscaler_merged_weight_results) == len(nnscaler_merged_weight_results_ckpt) + assert len(nnscaler_merged_weight_results) == len(nnscaler_merged_weight_results_scaler) + + assert len(nnscaler_merged_grad_results) == len(nnscaler_merged_grad_results_ckpt) + assert len(nnscaler_merged_grad_results) == len(nnscaler_merged_grad_results_scaler) + for i in range(len(nnscaler_merged_weight_results_scaler)): assert_equal(nnscaler_merged_weight_results[i], nnscaler_merged_weight_results_scaler[i]) + assert_equal(nnscaler_merged_weight_results[i], nnscaler_merged_weight_results_ckpt[i]) assert_equal(nnscaler_merged_grad_results[i], nnscaler_merged_grad_results_scaler[i]) + assert_equal(nnscaler_merged_grad_results[i], nnscaler_merged_grad_results_ckpt[i]) ddp_results = launch_torchrun(2, correctnes_worker_ddp, tmp_path, gradient_clip_val, with_lr_scheduler) for i in range(len(ddp_results[0])): diff --git a/tests/launch_torchrun.py b/tests/launch_torchrun.py index 886928f4..50cc596e 100644 --- a/tests/launch_torchrun.py +++ b/tests/launch_torchrun.py @@ -8,7 +8,7 @@ from .utils import retry -@retry(ChildFailedError, delay=10, match='RuntimeError: The server socket has failed to listen on any local network address.') +@retry(ChildFailedError, delay=10, match='The server socket has failed to listen on any local network address.') def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): launch_config = LaunchConfig( min_nodes=1, From a182bcc3f3a3668453dddb440f3c259d2334de0d Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Tue, 2 Jul 2024 08:09:42 +0000 Subject: [PATCH 1671/1892] Merged PR 2144: Nightly build scripts Add a pipeline to nightly build wheel, and fix packaging for autodist profile data. pipeline: https://msrasrg.visualstudio.com/SuperScaler/_build?definitionId=114 repo: https://msrasrg.visualstudio.com/SuperScaler/_artifacts/feed/nightly --- nnscaler/autodist/cost_database.py | 4 +- nnscaler/resources/__init__.py | 25 +++++++ .../profile/mi200/comm/intra_16.json | 0 .../profile/mi200/comm/intra_2.json | 0 .../profile/mi200/comm/intra_4.json | 0 .../profile/mi200/comm/intra_8.json | 0 nnscaler/version.py | 2 +- pipelines/nightly-build.yaml | 25 +++++++ pipelines/scripts/update_version.py | 71 +++++++++++++++++++ pyproject.toml | 2 +- requirements-dev.txt | 1 + requirements.txt | 1 + 12 files changed, 127 insertions(+), 4 deletions(-) create mode 100644 nnscaler/resources/__init__.py rename {data => nnscaler/resources}/profile/mi200/comm/intra_16.json (100%) rename {data => nnscaler/resources}/profile/mi200/comm/intra_2.json (100%) rename {data => nnscaler/resources}/profile/mi200/comm/intra_4.json (100%) rename {data => nnscaler/resources}/profile/mi200/comm/intra_8.json (100%) create mode 100644 pipelines/nightly-build.yaml create mode 100644 pipelines/scripts/update_version.py diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index c9900c7a..ea292998 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -15,13 +15,13 @@ from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.function.dimops import DimopSplit, IRDimops +import nnscaler.resources from .autodist_config import AutoDistConfig _logger = logging.getLogger(__name__) -import nnscaler -_DEFAULT_COMM_DATA_PATH = Path(nnscaler.__file__).parent.parent / 'data/profile/mi200/comm' +_DEFAULT_COMM_DATA_PATH = nnscaler.resources.files() / 'profile/mi200/comm' def _piecewise_estimator(xs: List[float], ys: List[float], x: float) -> float: diff --git a/nnscaler/resources/__init__.py b/nnscaler/resources/__init__.py new file mode 100644 index 00000000..af4fe9c0 --- /dev/null +++ b/nnscaler/resources/__init__.py @@ -0,0 +1,25 @@ +""" +Pseudo module of resource files. +""" + +from __future__ import annotations + +__all__ = 'files' + +# TODO: when drop python 3.8 support, change it to `importlib.resources` +import importlib_resources +from importlib_resources.abc import Traversable + +def files() -> Traversable: + """ + Alias of ``importlib.resources.files('nnscaler.resources')``. + + Returns: + A ``Path``-like object. + + Example: + :: + import nnscaler.resources + (nnscaler.resources.files() / 'path/to/my_file.txt').read_text() + """ + return importlib_resources.files(__name__) diff --git a/data/profile/mi200/comm/intra_16.json b/nnscaler/resources/profile/mi200/comm/intra_16.json similarity index 100% rename from data/profile/mi200/comm/intra_16.json rename to nnscaler/resources/profile/mi200/comm/intra_16.json diff --git a/data/profile/mi200/comm/intra_2.json b/nnscaler/resources/profile/mi200/comm/intra_2.json similarity index 100% rename from data/profile/mi200/comm/intra_2.json rename to nnscaler/resources/profile/mi200/comm/intra_2.json diff --git a/data/profile/mi200/comm/intra_4.json b/nnscaler/resources/profile/mi200/comm/intra_4.json similarity index 100% rename from data/profile/mi200/comm/intra_4.json rename to nnscaler/resources/profile/mi200/comm/intra_4.json diff --git a/data/profile/mi200/comm/intra_8.json b/nnscaler/resources/profile/mi200/comm/intra_8.json similarity index 100% rename from data/profile/mi200/comm/intra_8.json rename to nnscaler/resources/profile/mi200/comm/intra_8.json diff --git a/nnscaler/version.py b/nnscaler/version.py index cce384d3..1d8a19d3 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1 +1 @@ -__version__ = '0.3' +__version__ = '0.3.dev0' diff --git a/pipelines/nightly-build.yaml b/pipelines/nightly-build.yaml new file mode 100644 index 00000000..da792cf6 --- /dev/null +++ b/pipelines/nightly-build.yaml @@ -0,0 +1,25 @@ +trigger: +- main + +pool: + vmImage: ubuntu-latest + +steps: +- task: TwineAuthenticate@1 + inputs: + artifactFeed: SuperScaler/nightly + +- script: | + python -m pip install --upgrade build twine + displayName: prepare environment + +- script: | + python pipelines/scripts/update_version.py --nightly + python -m build + displayName: build wheel + +- script: | + number_of_wheels=`ls dist/*.whl | wc -l` + test $number_of_wheels -eq 1 + python -m twine upload -r nightly --config-file $(PYPIRC_PATH) dist/*.whl + displayName: upload nightly wheel diff --git a/pipelines/scripts/update_version.py b/pipelines/scripts/update_version.py new file mode 100644 index 00000000..98b4f289 --- /dev/null +++ b/pipelines/scripts/update_version.py @@ -0,0 +1,71 @@ +""" +Update "nnscaler/version.py" before building the wheel. + +Usage 1: + + python update_version.py --nightly + +Update version.py to "X.Y.dev{TIMESTAMP}+{GIT_COMMIT}". + +Usage 2: + + python update_version.py 1.2 + python update_version.py v1.2b3 + +Update version.py to the specified version (normalized, leading "v" removed). +It will verify that the release part matches the old version. +""" + +import argparse +from datetime import datetime +from pathlib import Path +import subprocess + +from packaging.version import Version + +project_dir = Path(__file__).parents[2] + +def main(): + parser = argparse.ArgumentParser(add_help=False) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--nightly', action='store_true') + group.add_argument('version', nargs='?') + args = parser.parse_args() + + version_file = Path(project_dir, 'nnscaler/version.py') + file_content = version_file.read_text() + version_str = file_content.split('=')[-1].strip()[1:-1] # "version = 'x'" -> "x" + repo_version = Version(version_str) + + if args.nightly: + timestamp = datetime.now().strftime('%y%m%d%H%M') + + r = subprocess.run( + 'git rev-parse --short HEAD'.split(), + stdout=subprocess.PIPE, + cwd=project_dir, + text=True, + ) + if r.returncode != 0: + print('[error] failed to get git commit hash') + exit(1) + commit = r.stdout.strip() + + new_version_str = f'{repo_version.base_version}.dev{timestamp}+{commit}' + + else: + arg_version = Version(args.version) + + if repo_version.release != arg_version.release: + print('[error] version not match') + print(f' repo: {version_str} -> {repo_version}') + print(f' arg: {args.version} -> {arg_version}') + exit(1) + + new_version_str = str(arg_version) # normalize + + file_content = file_content.replace(version_str, new_version_str) + version_file.write_text(file_content) + +if __name__ == '__main__': + main() diff --git a/pyproject.toml b/pyproject.toml index 009d83ac..083e3ea0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,4 +27,4 @@ dynamic.dependencies.file = "requirements.txt" # the following part only affects wheel, not sdist # since our current plan is to use cppimport, sdist is not needed packages.find.include = ["nnscaler*"] -package-data.nnscaler = ["autodist/csrc/*"] +package-data = { nnscaler = ["resources/**", "autodist/*.cpp"] } diff --git a/requirements-dev.txt b/requirements-dev.txt index ae21ff91..d9c13b5e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ +build coverage furo mock diff --git a/requirements.txt b/requirements.txt index 55e0755d..cda12044 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ cppimport dill +importlib-resources matplotlib more-itertools numpy>=1.23.0 From 9eebf402e4c3df6762d64c365a4189d910d7fd41 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Tue, 2 Jul 2024 08:37:17 +0000 Subject: [PATCH 1672/1892] Merged PR 2194: Reset version to v0.1 and update email --- nnscaler/version.py | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nnscaler/version.py b/nnscaler/version.py index 1d8a19d3..2273bc24 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1 +1 @@ -__version__ = '0.3.dev0' +__version__ = '0.1.dev0' diff --git a/pyproject.toml b/pyproject.toml index 083e3ea0..4d02a157 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,13 +11,13 @@ readme = "README.md" requires-python = ">=3.8" # TODO: license authors = [ - {name = "nnScaler Team", email = "nnscaler@microsoft.com"} # FIXME: email + {name = "nnScaler Team", email = "nnscaler@service.microsoft.com"} ] # TODO: keywords # TODO: classifiers [project.urls] -Homepage = "https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube" +Homepage = "https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube" # TODO: github [tool.setuptools] dynamic.version.attr = "nnscaler.version.__version__" From 7a2485d0e5e23b2e3002ec5fdcd2d304c38d62e8 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 2 Jul 2024 08:53:31 +0000 Subject: [PATCH 1673/1892] Merged PR 2193: lightning: refine docs about checkpoint refine docs --- docs/source/pytorch_lightning.md | 21 +++++++++++++++++++ .../integration/lightning/pytorch/strategy.py | 2 +- nnscaler/parallel.py | 9 +++++--- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/docs/source/pytorch_lightning.md b/docs/source/pytorch_lightning.md index 2dfdba85..918d9473 100644 --- a/docs/source/pytorch_lightning.md +++ b/docs/source/pytorch_lightning.md @@ -60,6 +60,27 @@ It has exactly the same constructor arguments as `Precision`'s constructor. Currently we support `32-true`, `16-true`, `bf16-true`, `16-mixed`, `bf16-mixed`. You can specify a grad scaler when you use `16-true`. +## Checkpoint + +If this is the first time you run the model, and you have a pretrained model, you must load the pretrained model before you pass it to the `Trainer` constructor. The tracing process will use the pretrained model weights to trace the forward function. + +Just like other pytorch lightning strategy, +you can resume from a checkpoint by specifying the `ckpt_path` argument in the `Trainer.fit` function. +Please note when the parallel plan is changed (i.e you re-trace the model with different configurations), +the checkpoints become incompatible, and can't be loaded any more. +You must firstly merge the checkpoints to a merged checkpoint and load it as a pretrained model. + +You can also merge all checkpoints (saved by each rank) to a complete checkpoint by using the `nnscaler.merge_state_dicts` function. +```python +import nnscaler +from pathlib import Path +state_dicts = [] +CHECKPOINT_DIR = Path(...) +for rank in range(world_size): + state_dicts.append(torch.load(CHECKPOINT_DIR / f"{rank}.pt")['state_dict']) +merged_state_dict, _ = nnscaler.merge_state_dicts(state_dicts) +torch.save(merged_state_dict, CHECKPOINT_DIR / "merged.pt") +``` ## Limitation diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py index 320f3754..4aa0204d 100644 --- a/nnscaler/integration/lightning/pytorch/strategy.py +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -269,7 +269,7 @@ def _setup_model(self, model: Module) -> Module: model.forward = pmodule.forward # patch log function to add sync_dist_group - _, sync_group = self.compute_config.create_sync_group() + _, sync_group = self.compute_config.get_sync_group() _old_log = model.log def _new_log(self, *args, **kwargs) -> None: diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index a14b7657..72fc6c71 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -204,12 +204,15 @@ def optimizer_dedup_group_size(self) -> int: else: return self.plan_ngpus - def create_sync_group(self) -> Tuple[List[int], torch.distributed.ProcessGroup]: + def get_sync_group(self) -> Tuple[List[int], torch.distributed.ProcessGroup]: """ - Create a sync group for the current rank. + Get sync group for the current rank. The sync group is a group of ranks that have exactly the same weights, but different inputs, so they should synchronize with each other to get the whole gradients/loss/etc. + Please note if sync groups haven't been created, it will create them. + So it will deadlock if only some of ranks call this function. + Returns: Tuple[List[int], torch.distributed.ProcessGroup]: return the rank list of the group and its torch.distributed group """ @@ -1282,7 +1285,7 @@ def build_optimizer( # we need to add all parameters of non-parallel modules to a reducer to reduce grads # if there are non-parallel parameters if plan_ngpus != runtime_ngpus and non_parallel_modules and any(p.numel() for m in non_parallel_modules for p in m.parameters(False)): - group, _ = compute_configs[0].create_sync_group() + group, _ = compute_configs[0].get_sync_group() non_parallel_module_reducer = Reducer(group) for m in non_parallel_modules: for param in m.parameters(recurse=False): # only add leaf parameters to avoid duplicate From d6f6c09b842c3cf62bfd5e8f37fa930039380e40 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 3 Jul 2024 07:43:24 +0000 Subject: [PATCH 1674/1892] Merged PR 2196: never fold nnscaler runtime functions never fold nnscaler runtime functions --- nnscaler/graph/parser/fx/parser.py | 6 +++++- tests/parallel_module/test_gencode.py | 15 +++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index bb7c20c6..e1918242 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any, List, Tuple, Callable, Union, Dict, Type, Optional +import nnscaler from nnscaler.ir.operator import IRFwOperation from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.cten import IRObject, IRCell, IRTensor @@ -246,7 +247,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule for i in range(len(vals)): ir_node.set_output(i, vals[i]) elif not isinstance(ir_node.output(0), IRTensor) and ir_node.output(0).value is not None: + # never fold our own functions defined in `nnscaler.runtime.function` module. + # currently only `ifexpr` will go here, and it will never be folded. if not constant_folding or \ + ir_node.signature.startswith(nnscaler.runtime.function.__name__ + '.') or \ any_ir_object_satisfy(ir_node.output(0), lambda a: not a.is_constant) or \ any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, IRTensor)) or \ any_ir_object_satisfy(ir_node.output(0), lambda a: isinstance(a, (DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE))): @@ -303,7 +307,7 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, persistent = node.name not in module._non_persistent_buffers_set tensor.as_buffer(persistent=persistent) frame.add_attr(tensor, concrete_value, node.target) - # the case that the parameter is consumed multiple times and regisetered previously + # the case that the parameter is consumed multiple times and registered previously else: frame.set_var(node.name, exist_tensor) else: diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index dc4133a8..f186d26f 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -790,7 +790,8 @@ def forward(self, x): @replace_all_device_with('cpu') -def test_codegen_dropout(): +@pytest.mark.parametrize('constant_fold', [False, True]) +def test_codegen_dropout(constant_fold): """ Test if self.training is correctly handled in the generated code """ @@ -801,7 +802,7 @@ def test_codegen_dropout(): m, {'x': torch.randn(128, 64)}, 'dp', - ComputeConfig(1, 1), + ComputeConfig(1, 1, constant_folding=constant_fold), gen_savedir=tempdir, load_module=False, reuse='override', @@ -839,7 +840,8 @@ def forward(self, x): @replace_all_device_with('cpu') -def test_codegen_dropout2(tmp_path): +@pytest.mark.parametrize('constant_fold', [False, True]) +def test_codegen_dropout2(tmp_path, constant_fold): """ Test if register_op is correctly handled in the generated code """ @@ -849,7 +851,7 @@ def test_codegen_dropout2(tmp_path): m, {'x': torch.randn(128, 64)}, 'dp', - ComputeConfig(1, 1), + ComputeConfig(1, 1, constant_folding=constant_fold), gen_savedir=tmp_path, load_module=False, reuse='override', @@ -881,7 +883,8 @@ def forward(self, x): @replace_all_device_with('cpu') -def test_codegen_dropout_nested(tmp_path): +@pytest.mark.parametrize('constant_fold', [False, True]) +def test_codegen_dropout_nested(tmp_path, constant_fold): """ Test if register_op is correctly handled in the generated code """ @@ -891,7 +894,7 @@ def test_codegen_dropout_nested(tmp_path): m, {'x': torch.randn(128, 64)}, 'dp', - ComputeConfig(1, 1), + ComputeConfig(1, 1, constant_folding=constant_fold), gen_savedir=tmp_path, load_module=False, reuse='override', From d6c25a34cf78c3c49ba25eb30d33f06024c4269c Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 4 Jul 2024 09:19:10 +0000 Subject: [PATCH 1675/1892] Merged PR 2171: Refine dp solver in autodist - align the memory estimation in dp solver with ilp solver, check this [PR](https://dev.azure.com/msrasrg/SuperScaler/_git/MagicCube/pullrequest/2121) for more details - refine c++ code - have verified the search result compared to the ilp with & without recompute on retnet-3b NOTE: after this PR, more meta information are introduced in a dynamic programming state, resulting in the dp solver may be slower than ilp solver, which needs further optimization. --- nnscaler/autodist/cube_operator.py | 9 + nnscaler/autodist/dp_solver.cpp | 324 ++++------------------ nnscaler/autodist/dp_solver.h | 225 +++++++++++++++ nnscaler/autodist/model_graph.py | 1 + nnscaler/autodist/spmd_solver.py | 23 +- tests/autodist/spmd_solver/test_follow.py | 4 +- tests/autodist/test_dp_solver.py | 58 +++- 7 files changed, 355 insertions(+), 289 deletions(-) create mode 100644 nnscaler/autodist/dp_solver.h diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py index ee867e90..1ec6ed7c 100644 --- a/nnscaler/autodist/cube_operator.py +++ b/nnscaler/autodist/cube_operator.py @@ -37,6 +37,7 @@ def __init__(self, ir_cell: IRFwOperation): self._has_sum_dim = False self._recompute = False self._recompute_start_op = False + self._recompute_last_op = False self._has_attr = False self.omit_recompute_in_idx = [] @@ -83,6 +84,14 @@ def recompute_start_op(self): def recompute_start_op(self, value: bool): self._recompute_start_op = value + @property + def recompute_last_op(self): + return self._recompute_last_op + + @recompute_last_op.setter + def recompute_last_op(self, value: bool): + self._recompute_last_op = value + def add_producer(self, producer: 'CubeOperator'): self.producers.append(producer) diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index 8fdfb74f..5564c2fc 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -1,54 +1,10 @@ // cppimport +#include "dp_solver.h" #include #include namespace py = pybind11; -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { -public: - ThreadPool(unsigned int n = std::thread::hardware_concurrency()); - - template void enqueue(F &&f); - void waitFinished(); - ~ThreadPool(); - - unsigned int getProcessed() const { return processed; } - -private: - std::vector workers; - std::deque> tasks; - std::mutex queue_mutex; - std::condition_variable cv_task; - std::condition_variable cv_finished; - std::atomic_uint processed; - unsigned int busy; - bool stop; - - void thread_proc(); -}; - ThreadPool::ThreadPool(unsigned int n) : busy(), processed(), stop() { for (unsigned int i = 0; i < n; ++i) workers.emplace_back(std::bind(&ThreadPool::thread_proc, this)); @@ -129,104 +85,12 @@ std::vector> split_work(int num, int base) { return ret; } -struct DPNode; - -struct Node { - int id; - int father_id; - - int cut_len; - std::vector cut_nodes; - - int p_num; - std::vector p_time; - std::vector p_comp_mem; - std::vector p_buf_mem; - std::vector p_act_mem; - std::vector p_opt_mem; - std::vector p_father; - - int producer_num; - std::vector producers; - std::vector> comm_costs; - - // assume the number of combinations is less than 2e9 - int dp_num; - std::vector dp_nodes; -}; - -struct DPNode { - Node *graph_node; - int pg_id; - std::vector ir; - std::vector> in_edges; - // mem, time, activation_mem, optimzer_mem - std::vector> state; -}; - void resetNode(Node *node) { for (DPNode *dp_node : node->dp_nodes) { dp_node->state.clear(); } } -void printNode(Node *node) { - std::cout << "id: " << node->id << std::endl; - std::cout << "father_id: " << node->father_id << std::endl; - std::cout << "cut_len: " << node->cut_len << std::endl; - std::cout << "cut_nodes: "; - for (auto cut_node : node->cut_nodes) { - std::cout << cut_node->id << " "; - } - std::cout << std::endl; - std::cout << "p_num: " << node->p_num << std::endl; - std::cout << "p_time: "; - for (auto p_time : node->p_time) { - std::cout << p_time << " "; - } - std::cout << std::endl; - std::cout << "p_comp_mem: "; - for (auto p_comp_mem : node->p_comp_mem) { - std::cout << p_comp_mem << " "; - } - std::cout << std::endl; - std::cout << "p_buf_mem: "; - for (auto p_buf_mem : node->p_buf_mem) { - std::cout << p_buf_mem << " "; - } - std::cout << std::endl; - std::cout << "p_act_mem: "; - for (auto p_act_mem : node->p_act_mem) { - std::cout << p_act_mem << " "; - } - std::cout << std::endl; - std::cout << "p_opt_mem: "; - for (auto p_opt_mem : node->p_opt_mem) { - std::cout << p_opt_mem << " "; - } - std::cout << std::endl; - std::cout << "producer_num: " << node->producer_num << std::endl; - std::cout << "producers: "; - for (auto producer : node->producers) { - std::cout << producer->id << " "; - } - std::cout << std::endl; - std::cout << "p_father: "; - for (auto p_father : node->p_father) { - std::cout << p_father << " "; - } - std::cout << std::endl; - std::cout << "comm_costs: " << std::endl; - for (auto comm_cost : node->comm_costs) { - for (auto cost : comm_cost) { - std::cout << cost << " "; - } - std::cout << std::endl; - } - std::cout << "dp_num: " << node->dp_num << std::endl; - std::cout << std::endl; -} - // lazy decode // after decoding, ir stores the partition id of each cut node void decodePGID(DPNode *dp_node) { @@ -243,21 +107,10 @@ void decodePGID(DPNode *dp_node) { std::reverse(dp_node->ir.begin(), dp_node->ir.end()); } -struct SearchPlan { - double all_time; - double inner_time; - int memory; - std::vector> path; - - bool operator<(const SearchPlan &other) const { - return all_time < other.all_time; - } -}; - class DPSolver { public: - DPSolver(bool verbose, int mode, int mem_bound, int mem_div, int topk) : verbose(verbose), mode(mode), mem_bound(mem_bound), mem_div(mem_div), topk(topk) - { + DPSolver(bool verbose, int mem_bound, int topk) + : verbose(verbose), mem_bound(mem_bound), topk(topk) { queries.clear(); id2node.clear(); search_results.clear(); @@ -272,7 +125,8 @@ class DPSolver { } void add_node(int id, int father_id, std::vector cut_ids, - std::vector producers, int p_num) { + std::vector producers, int p_num, bool is_recompute, + bool is_recompute_in, bool is_recompute_last) { if (verbose) { std::cout << "id: " << id << ", father_id: " << father_id << ", cut_ids: "; @@ -289,9 +143,13 @@ class DPSolver { id2node[id] = node; node->id = id; node->p_num = p_num; + node->is_recompute = is_recompute; + node->is_recompute_in = is_recompute_in; + node->is_recompute_last = is_recompute_last; node->p_father.resize(p_num); node->p_time.resize(p_num); node->p_comp_mem.resize(p_num); + node->p_in_mem.resize(p_num); node->p_buf_mem.resize(p_num); node->p_act_mem.resize(p_num); node->p_opt_mem.resize(p_num); @@ -312,14 +170,15 @@ class DPSolver { } void add_partition(int node_id, int p_idx, double p_time, int p_comp_mem, - int p_buf_mem, int p_act_mem, int p_opt_mem, int p_father, + int p_in_mem, int p_buf_mem, int p_act_mem, int p_opt_mem, + int p_father, std::vector> comm_costs) { if (verbose) { std::cout << "node_id: " << node_id << ", p_idx: " << p_idx << ", p_time: " << p_time << ", p_comp_mem: " << p_comp_mem - << ", p_buf_mem: " << p_buf_mem << ", p_act_mem: " << p_act_mem - << ", p_opt_mem: " << p_opt_mem << ", p_father: " << p_father - << std::endl; + << ", p_in_mem: " << p_in_mem << ", p_buf_mem: " << p_buf_mem + << ", p_act_mem: " << p_act_mem << ", p_opt_mem: " << p_opt_mem + << ", p_father: " << p_father << std::endl; std::cout << "comm_costs: " << std::endl; for (std::size_t i = 0; i < comm_costs.size(); ++i) { for (std::size_t j = 0; j < comm_costs[i].size(); ++j) { @@ -331,6 +190,7 @@ class DPSolver { Node *node = id2node[node_id]; node->p_time[p_idx] = p_time; node->p_comp_mem[p_idx] = p_comp_mem; + node->p_in_mem[p_idx] = p_in_mem; node->p_buf_mem[p_idx] = p_buf_mem; node->p_act_mem[p_idx] = p_act_mem; node->p_opt_mem[p_idx] = p_opt_mem; @@ -540,83 +400,61 @@ class DPSolver { void update(DPNode *dp_node, int start_level) { Node *node = dp_node->graph_node; decodePGID(dp_node); - int cur_p_idx = *(dp_node->ir.rbegin()); + int cur_p = *(dp_node->ir.rbegin()); if (node->id == start_level) { - // each dp node maintains a list of states, each state is a tuple - // (mem, time, pred_dp_node, activation_mem, optimizer_mem) - dp_node->state.push_back(std::make_tuple( - node->p_comp_mem[cur_p_idx], node->p_time[cur_p_idx], nullptr, - node->p_act_mem[cur_p_idx], node->p_opt_mem[cur_p_idx])); + // each dp node maintains a list of UnitDPState + dp_node->state.push_back(UnitDPState(node, cur_p, nullptr, 0)); return; } // storing edges takes space, so we build edges when needed buildInEdges(dp_node); - int cur_p = *(dp_node->ir.rbegin()); if (dp_node->in_edges.empty()) { - dp_node->state.push_back(std::make_tuple( - 0, std::numeric_limits::infinity(), nullptr, 0, 0)); + // no in edges, means the node is not used + UnitDPState state = + UnitDPState(0, 0, 0, 0, 0, 0, 0, + std::numeric_limits::infinity(), nullptr, 0); + dp_node->state.push_back(state); return; } - // use a priority queue to maintain the best state, similar to the merge - // sort - double cur_p_time = node->p_time[cur_p]; - int cur_p_comp_mem = node->p_comp_mem[cur_p]; - int cur_p_act_mem = node->p_act_mem[cur_p]; - int cur_p_opt_mem = node->p_opt_mem[cur_p]; - std::priority_queue> pq; + // use a priority queue to maintain the best state like merge sort + std::priority_queue> pq; for (std::size_t i = 0; i < dp_node->in_edges.size(); ++i) { DPNode *pred = dp_node->in_edges[i].first; - int mem = cur_p_comp_mem + std::get<0>(pred->state[0]); - double cost = cur_p_time + dp_node->in_edges[i].second + - std::get<1>(pred->state[0]); - int act_mem = cur_p_act_mem + std::get<3>(pred->state[0]); - int opt_mem = cur_p_opt_mem + std::get<4>(pred->state[0]); - pq.push(std::make_tuple(-mem, -cost, i, -act_mem, -opt_mem)); + UnitDPState pred_state = pred->state[0]; + double transition_cost = dp_node->in_edges[i].second; + UnitDPState new_state = + pred_state.generate_new_state(node, cur_p, transition_cost, pred, 0); + pq.push(std::make_tuple(new_state, i)); } - std::vector lows(dp_node->in_edges.size(), 1); + std::vector lows(dp_node->in_edges.size(), 1); - int cur_mem; - double cur_cost; int pred_idx; - int cur_act_mem; - int cur_opt_mem; + UnitDPState cur_state; while (!pq.empty()) { - std::tie(cur_mem, cur_cost, pred_idx, cur_act_mem, cur_opt_mem) = - pq.top(); - cur_mem = -cur_mem; - cur_cost = -cur_cost; - cur_act_mem = -cur_act_mem; - cur_opt_mem = -cur_opt_mem; + std::tie(cur_state, pred_idx) = pq.top(); + DPNode *pred = dp_node->in_edges[pred_idx].first; pq.pop(); if (lows[pred_idx] < dp_node->in_edges[pred_idx].first->state.size()) { - DPNode *pred = dp_node->in_edges[pred_idx].first; - int mem = cur_p_comp_mem + std::get<0>(pred->state[lows[pred_idx]]); - double cost = cur_p_time + dp_node->in_edges[pred_idx].second + - std::get<1>(pred->state[lows[pred_idx]]); - int act_mem = cur_p_act_mem + std::get<3>(pred->state[lows[pred_idx]]); - int opt_mem = cur_p_opt_mem + std::get<4>(pred->state[lows[pred_idx]]); - pq.push(std::make_tuple(-mem, -cost, pred_idx, -act_mem, -opt_mem)); + UnitDPState pred_state = pred->state[lows[pred_idx]]; + double transition_cost = dp_node->in_edges[pred_idx].second; + UnitDPState new_state = pred_state.generate_new_state( + node, cur_p, transition_cost, pred, lows[pred_idx]); + pq.push(std::make_tuple(new_state, pred_idx)); ++lows[pred_idx]; } if (dp_node->state.empty()) { - dp_node->state.push_back(std::make_tuple( - cur_mem, cur_cost, dp_node->in_edges[pred_idx].first, cur_act_mem, - cur_opt_mem)); + dp_node->state.push_back(cur_state); } else { - int pre_mem = std::get<0>(dp_node->state[dp_node->state.size() - 1]); - double pre_cost = - std::get<1>(dp_node->state[dp_node->state.size() - 1]); - // if (cur_mem > pre_mem && cur_cost < pre_cost && - // cur_mem + cur_opt_mem <= mem_bound) { - if (cur_mem > pre_mem && cur_cost < pre_cost && - cur_mem - cur_act_mem + std::max(cur_act_mem, cur_opt_mem) <= - mem_bound) { - dp_node->state.push_back(std::make_tuple( - cur_mem, cur_cost, dp_node->in_edges[pred_idx].first, cur_act_mem, - cur_opt_mem)); + UnitDPState pre_state = dp_node->state[dp_node->state.size() - 1]; + int pre_mem = pre_state.total_mem; + double pre_cost = pre_state.time_cost; + int cur_mem = cur_state.total_mem; + double cur_cost = cur_state.time_cost; + if (cur_mem <= mem_bound && cur_mem > pre_mem && cur_cost < pre_cost) { + dp_node->state.push_back(cur_state); } } } @@ -655,65 +493,31 @@ class DPSolver { SearchPlan process_state(DPNode *dp_node, int idx) { // build the optimal path of each partition of last operator - // and return the best path + // and return the plan std::vector> path; DPNode *cur_dp_node = dp_node; - int cur_idx = idx; - int best_mem = std::get<0>(dp_node->state[idx]); - double best_time = std::get<1>(dp_node->state[idx]); - int act_mem = std::get<3>(dp_node->state[idx]); - int opt_mem = std::get<4>(dp_node->state[idx]); - double inner_time = 0; - int cur_best_mem = best_mem; - std::vector buffers; + UnitDPState best_state = dp_node->state[idx]; + UnitDPState cur_state = best_state; + DPNode *pred_dp_node = nullptr; while (true) { int cur_p = *(cur_dp_node->ir.rbegin()); Node *node = cur_dp_node->graph_node; path.push_back(std::make_pair(node->id, cur_p)); - buffers.push_back(node->p_buf_mem[cur_p]); - inner_time += node->p_time[cur_p]; - cur_best_mem -= node->p_comp_mem[cur_p]; - DPNode *pred_dp_node = std::get<2>(cur_dp_node->state[cur_idx]); + pred_dp_node = cur_state.pred_dp_node; if (pred_dp_node == nullptr) { break; } else { + cur_state = pred_dp_node->state[cur_state.pred_idx]; cur_dp_node = pred_dp_node; - cur_idx = std::lower_bound( - cur_dp_node->state.begin(), cur_dp_node->state.end(), - std::make_tuple(cur_best_mem, static_cast(-1), - static_cast(nullptr), -1, -1)) - - cur_dp_node->state.begin(); } } std::reverse(path.begin(), path.end()); - std::sort(buffers.begin(), buffers.end()); - long long ret_mem = static_cast(best_mem); - if (mode == 0) { - ret_mem += buffers[buffers.size() - 1] + buffers[buffers.size() - 2]; - } else if (mode == 1) { - ret_mem += buffers[buffers.size() - 1]; - } - ret_mem = ret_mem - act_mem + std::max(act_mem, opt_mem); - if (ret_mem > mem_bound) { - return SearchPlan{-1, -1, -1, std::vector>()}; - } - if (verbose) { - std::cout << "best time: " << best_time - << ", best mem: " << best_mem / 1024 / 1024 * mem_div << "MB, " - << "activation mem: " << act_mem / 1024 / 1024 * mem_div - << "MB, " - << "optimizer state mem: " << opt_mem / 1024 / 1024 * mem_div - << "MB" << std::endl; - } - return SearchPlan{best_time, inner_time, static_cast(ret_mem), path}; + return SearchPlan{best_state.time_cost, + best_state.total_mem, path}; } void post_process(int start_level, int end_level, int topk) { std::vector best_info; - double best_time; - double inner_time; - int best_mem; - std::vector> path; for (DPNode *dp_node : id2node[end_level]->dp_nodes) { int cnt = 0; for (std::size_t i = 0; i < dp_node->state.size(); ++i) { @@ -735,9 +539,7 @@ class DPSolver { if (verbose) { std::cout << "start to solve" << std::endl; std::cout << "verbose: " << verbose << std::endl; - std::cout << "mode: " << mode << std::endl; std::cout << "mem_bound: " << mem_bound << std::endl; - std::cout << "mem_div: " << mem_div << std::endl; std::cout << "topk: " << topk << std::endl; } init_dp_info(); @@ -764,7 +566,6 @@ class DPSolver { } long long state_cnt = 0; for (auto iter = id2node.begin(); iter != id2node.end(); ++iter) { - int cur_id = iter->first; Node *cur_node = iter->second; for (DPNode *dp_node : cur_node->dp_nodes) { state_cnt += dp_node->state.size(); @@ -778,8 +579,10 @@ class DPSolver { std::chrono::duration elapsed_seconds = end - start; - std::cout << "elapsed time: " << elapsed_seconds.count() << " s" - << std::endl; + if (verbose) { + std::cout << "elapsed time: " << elapsed_seconds.count() << " s" + << std::endl; + } } std::vector get_results(int start_level, int end_level) { @@ -787,13 +590,8 @@ class DPSolver { } bool verbose; - // mode = 0: training, use the sum of the two largest buffer sizes - // mode = 1: inference, use the largest buffer size - int mode; // mem_bound: the maximum memory usage, in bytes int mem_bound; - // mem_div: the memory divisor, to avoid overflow in int32 - int mem_div; int topk; std::unordered_map id2node; @@ -804,12 +602,11 @@ class DPSolver { PYBIND11_MODULE(dp_solver, m) { py::class_(m, "SearchPlan") .def_readonly("all_time", &SearchPlan::all_time) - .def_readonly("inner_time", &SearchPlan::inner_time) .def_readonly("memory", &SearchPlan::memory) .def_readonly("path", &SearchPlan::path); py::class_(m, "DPSolver") - .def(py::init()) + .def(py::init()) .def("add_interval", &DPSolver::add_interval) .def("add_node", &DPSolver::add_node) .def("add_partition", &DPSolver::add_partition) @@ -822,5 +619,6 @@ setup_pybind11(cfg) cfg['extra_compile_args'] = ['-std=c++11'] cfg['extra_compile_args'] = ['-O3'] cfg['extra_compile_args'] = ['-pthread'] +cfg['dependencies'] = ['dp_solver.h'] %> */ diff --git a/nnscaler/autodist/dp_solver.h b/nnscaler/autodist/dp_solver.h new file mode 100644 index 00000000..e6d5af3d --- /dev/null +++ b/nnscaler/autodist/dp_solver.h @@ -0,0 +1,225 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(unsigned int n = std::thread::hardware_concurrency()); + + template void enqueue(F &&f); + void waitFinished(); + ~ThreadPool(); + + unsigned int getProcessed() const { return processed; } + +private: + std::vector workers; + std::deque> tasks; + std::mutex queue_mutex; + std::condition_variable cv_task; + std::condition_variable cv_finished; + std::atomic_uint processed; + unsigned int busy; + bool stop; + + void thread_proc(); +}; + +struct DPNode; +struct UnitDPState; + +struct Node { + int id; + int father_id; + + int cut_len; + std::vector cut_nodes; + + // whether the node is in a recompute region + bool is_recompute; + // if the node is in a recompute region, whether it accetps input outside + // of the recompute region + bool is_recompute_in; + // if the node is in a recompute region, whether it is the last node in the + // topological sequence of the recompute region + bool is_recompute_last; + + int p_num; + std::vector p_time; + std::vector p_comp_mem; + std::vector p_in_mem; + std::vector p_buf_mem; + std::vector p_act_mem; + std::vector p_opt_mem; + std::vector p_father; + + int producer_num; + std::vector producers; + std::vector> comm_costs; + + // assume the number of combinations is less than 2e9 + int dp_num; + std::vector dp_nodes; +}; + +struct DPNode { + Node *graph_node; + // pg_id: partition group id, an equivalent representation of `ir` + int pg_id; + std::vector ir; + std::vector> in_edges; + // saved UnitDPStates are sorted by total mem cost and satisfy that + // if lhs.total_mem < rhs.total_mem, then lhs.time_cost > rhs.time_cost + std::vector state; +}; + +struct SearchPlan { + double all_time; + int memory; + std::vector> path; + + bool operator<(const SearchPlan &other) const { + return all_time < other.all_time; + } +}; + +struct UnitDPState { + int param_related_mem; + int activation_mem; + int opt_transient_mem; + int largest_transient_mem_1st; + int largest_transient_mem_2nd; + int max_recompute_mem; + int cur_recompute_mem; + int total_mem; + double time_cost; + + DPNode *pred_dp_node; + int pred_idx; + + UnitDPState() {} + + UnitDPState(int param_related_mem, int activation_mem, + int opt_transient_mem, int largest_transient_mem_1st, + int largest_transient_mem_2nd, int max_recompute_mem, + int cur_recompute_mem, double time_cost, + DPNode *pred_dp_node, int pred_idx) + : param_related_mem(param_related_mem), activation_mem(activation_mem), + opt_transient_mem(opt_transient_mem), + largest_transient_mem_1st(largest_transient_mem_1st), + largest_transient_mem_2nd(largest_transient_mem_2nd), + max_recompute_mem(max_recompute_mem), + cur_recompute_mem(cur_recompute_mem), time_cost(time_cost), + pred_dp_node(pred_dp_node), pred_idx(pred_idx) { + total_mem = calc_total_mem(); + } + + UnitDPState(Node *node, int partition_idx, DPNode *pred_dp_node, int pred_idx) + : pred_dp_node(pred_dp_node), pred_idx(pred_idx) { + double time_cost = node->p_time[partition_idx]; + int comp_mem = node->p_comp_mem[partition_idx]; + int in_mem = node->p_in_mem[partition_idx]; + int buf_mem = node->p_buf_mem[partition_idx]; + int act_mem = node->p_act_mem[partition_idx]; + int opt_mem = node->p_opt_mem[partition_idx]; + this->param_related_mem = comp_mem - act_mem; + if (node->is_recompute == true) { + if (node->is_recompute_in == true) { + this->activation_mem = in_mem; + } else { + this->activation_mem = 0; + } + this->cur_recompute_mem = act_mem; + if (node->is_recompute_last == true) { + this->max_recompute_mem = act_mem; + this->cur_recompute_mem = 0; + } else { + this->max_recompute_mem = 0; + } + } else { + this->activation_mem = act_mem; + this->max_recompute_mem = 0; + this->cur_recompute_mem = 0; + } + this->opt_transient_mem = opt_mem; + this->largest_transient_mem_1st = buf_mem; + this->largest_transient_mem_2nd = 0; + this->time_cost = time_cost; + this->total_mem = calc_total_mem(); + } + + int calc_total_mem() { + return param_related_mem + std::max(activation_mem, opt_transient_mem) + + largest_transient_mem_1st + largest_transient_mem_2nd + + max_recompute_mem; + } + + UnitDPState generate_new_state(Node *node, int partition_idx, + double transition_cost, DPNode *pred_dp_node, + int pred_idx) { + double time_cost = node->p_time[partition_idx]; + int comp_mem = node->p_comp_mem[partition_idx]; + int in_mem = node->p_in_mem[partition_idx]; + int buf_mem = node->p_buf_mem[partition_idx]; + int act_mem = node->p_act_mem[partition_idx]; + int opt_mem = node->p_opt_mem[partition_idx]; + int param_related_mem = comp_mem - act_mem + this->param_related_mem; + int opt_transient_mem = opt_mem + this->opt_transient_mem; + int largest_transient_mem_1st = this->largest_transient_mem_1st; + int largest_transient_mem_2nd = this->largest_transient_mem_2nd; + if (buf_mem > largest_transient_mem_1st) { + largest_transient_mem_2nd = largest_transient_mem_1st; + largest_transient_mem_1st = buf_mem; + } else if (buf_mem > largest_transient_mem_2nd) { + largest_transient_mem_2nd = buf_mem; + } + int max_recompute_mem = this->max_recompute_mem; + int cur_recompute_mem = this->cur_recompute_mem; + int activation_mem = act_mem + this->activation_mem; + if (node->is_recompute == true) { + if (node->is_recompute_in == true) { + activation_mem = in_mem + this->activation_mem; + } else { + activation_mem = this->activation_mem; + } + cur_recompute_mem += act_mem; + if (node->is_recompute_last == true) { + max_recompute_mem = std::max(max_recompute_mem, cur_recompute_mem); + cur_recompute_mem = 0; + } + } + return UnitDPState(param_related_mem, activation_mem, opt_transient_mem, + largest_transient_mem_1st, largest_transient_mem_2nd, + max_recompute_mem, cur_recompute_mem, + this->time_cost + transition_cost + time_cost, + pred_dp_node, pred_idx); + } + + // for priority queue, sort the state by total memory cost from small to large + // if the total memory cost is the same, sort by time cost from small to large + bool operator<(const UnitDPState &other) const { + if (total_mem != other.total_mem) { + return total_mem > other.total_mem; + } else { + return time_cost > other.time_cost; + } + } +}; \ No newline at end of file diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index fc8b9792..710e67be 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -926,6 +926,7 @@ def init_operators(self): if end - start + 1 != len(interval): raise RuntimeError('recompute nodes are not continuous') self._recompute_group_idxs.append(interval) + self.operator_list[end].recompute_last_op = True @property def recompute_group_idxs(self) -> List[List[int]]: diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index b4d417bc..5730d22b 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -1106,15 +1106,12 @@ def get_non_zero_index(binary_vector): offset += 1 plans = [] all_time_cost = objective - inner_time_cost = 0 for i in range(start, end + 1): plans.append((i, s_val[i - start])) - p_cost_desc = self.partition_info[i][s_val[i - start]] - inner_time_cost += p_cost_desc.comp_time + p_cost_desc.weight_update_time mem_cost = self.calc_mem_cost(plans).total_cost return SPMDSearchOutput(self.partition_path2desc(plans), mem_cost / 1024 / 1024 / 1024, all_time_cost, - inner_time_cost) + self.calc_inner_time_cost(plans)) def do_ilp(self, intervals: List[Tuple[int, int]], topk: int) -> List[List[SPMDSearchOutput]]: @@ -1135,20 +1132,22 @@ def do_dp(self, intervals: List[Tuple[int, int]], import cppimport.import_hook import nnscaler.autodist.dp_solver as dp_solver - mode = 0 if self.is_train else 1 - mem_div = 64 - mem_bound = int(self.mem_bound) // mem_div - solver = dp_solver.DPSolver(self.autodist_config.verbose, mode, mem_bound, mem_div, topk) + if self.autodist_config.memory_granularity < 1024: + raise RuntimeError('dp solver assumes the memory granularity is at least 1024 bytes') + buf_mul = 2 if self.is_train else 1 + mem_divisor = self.autodist_config.memory_granularity + solver = dp_solver.DPSolver(self.autodist_config.verbose, self.mem_bound // mem_divisor, topk) for start, end in intervals: solver.add_interval(start, end) for idx in range(self.graph.op_num): + op = self.graph.operator_list[idx] solver.add_node(idx, self.father_ids[idx], self.cut_ops[idx], - self.producers[idx], self.get_op_partition_count(idx)) + self.producers[idx], self.get_op_partition_count(idx), op.recompute, op.recompute_start_op, op.recompute_last_op) for i, partition in enumerate(self._op_partitions[idx]): p_cost_desc = self.partition_info[idx][i] solver.add_partition(idx, i, p_cost_desc.comp_time + p_cost_desc.weight_update_time, - p_cost_desc.mem // mem_div, p_cost_desc.transient_mem // mem_div, - p_cost_desc.activation_mem // mem_div, p_cost_desc.opt_transient_mem // mem_div, + p_cost_desc.mem // mem_divisor, p_cost_desc.in_mem // mem_divisor, buf_mul * p_cost_desc.transient_mem // mem_divisor, + p_cost_desc.activation_mem // mem_divisor, p_cost_desc.opt_transient_mem // mem_divisor, self.p_fathers[idx][i], p_cost_desc.comm_time) solver.solve() ret = [] @@ -1157,7 +1156,7 @@ def do_dp(self, intervals: List[Tuple[int, int]], descs = [] for result in cpp_results: desc = self.partition_path2desc(result.path) - descs.append(SPMDSearchOutput(desc, result.memory * mem_div / 1024 / 1024 / 1024, result.all_time, result.inner_time)) + descs.append(SPMDSearchOutput(desc, result.memory * mem_divisor / 1024 / 1024 / 1024, result.all_time, self.calc_inner_time_cost(result.path))) ret.append(descs) return ret diff --git a/tests/autodist/spmd_solver/test_follow.py b/tests/autodist/spmd_solver/test_follow.py index fc1fc2ca..7f37b312 100644 --- a/tests/autodist/spmd_solver/test_follow.py +++ b/tests/autodist/spmd_solver/test_follow.py @@ -226,7 +226,7 @@ def test_follow_attention(): pc_path = Path(os.path.dirname(__file__)) / 'test_attention_follow.yaml' profile_dir = Path(os.path.dirname(__file__)) / './test_follow_attention_profile' - cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2, profile_dir=profile_dir) + cfg = AutoDistConfig(partition_constraints_path=pc_path, mesh_col=2, profile_dir=profile_dir, memory_granularity=1024) model_graph = ModelGraph(ir_graph, cfg) spmd_solver = SPMDSolver( @@ -322,7 +322,7 @@ def test_solver_data_parallel(): print(ir_graph.nodes()) profile_dir = Path(os.path.dirname(__file__)) / './test_solver_data_parallel' - cfg = AutoDistConfig(mesh_col=2, profile_dir=profile_dir) + cfg = AutoDistConfig(mesh_col=2, profile_dir=profile_dir, memory_granularity=1024) model_graph = ModelGraph(ir_graph, cfg) spmd_solver = SPMDSolver( diff --git a/tests/autodist/test_dp_solver.py b/tests/autodist/test_dp_solver.py index b44abc79..aad7d400 100644 --- a/tests/autodist/test_dp_solver.py +++ b/tests/autodist/test_dp_solver.py @@ -9,20 +9,20 @@ # each operator has 2 partition options def test_dp_solver(): - solver = dp_solver.DPSolver(True, 0, 80 * 1024, 1, 1) + solver = dp_solver.DPSolver(True, 80 * 1024, 1) solver.add_interval(0, 2) - solver.add_node(0, 0, [0], [], 2) - solver.add_partition(0, 0, 1, 1, 1, 1, 1, 0, [[]]) - solver.add_partition(0, 1, 2, 2, 2, 2, 2, 1, [[]]) + solver.add_node(0, 0, [0], [], 2, False, False, False) + solver.add_partition(0, 0, 1, 1, 1, 1, 1, 1, 0, [[]]) + solver.add_partition(0, 1, 2, 2, 2, 2, 2, 2, 1, [[]]) - solver.add_node(1, 1, [1], [0], 2) - solver.add_partition(1, 0, 0.5, 1, 1, 1, 1, 0, [[0.1, 1]]) - solver.add_partition(1, 1, 1, 2, 2, 2, 2, 1, [[1, 0]]) + solver.add_node(1, 1, [1], [0], 2, False, False, False) + solver.add_partition(1, 0, 0.5, 1, 1, 1, 1, 1, 0, [[0.1, 1]]) + solver.add_partition(1, 1, 1, 2, 2, 2, 2, 2, 1, [[1, 0]]) - solver.add_node(2, 2, [2], [1], 2) - solver.add_partition(2, 0, 1, 1, 1, 1, 1, 0, [[0.2, 1]]) - solver.add_partition(2, 1, 2, 2, 2, 2, 2, 1, [[1, 0]]) + solver.add_node(2, 2, [2], [1], 2, False, False, False) + solver.add_partition(2, 0, 1, 1, 1, 1, 1, 1, 0, [[0.2, 1]]) + solver.add_partition(2, 1, 2, 2, 2, 2, 2, 2, 1, [[1, 0]]) solver.solve() @@ -32,7 +32,41 @@ def test_dp_solver(): # optimal all time 1 + 0.5 + 0.1 + 1 + 0.2 = 2.8 assert best.all_time == 2.8 - # optimal inner time 1 + 0.5 + 1 = 2.5 - assert best.inner_time == 2.5 # the optimal plan is each operator's first partition assert best.path == [(0, 0), (1, 0), (2, 0)] + +def test_dp_solver_mem(): + solver = dp_solver.DPSolver(True, 100, 1) + solver.add_interval(0, 4) + + solver.add_node(0, 0, [0], [], 1, True, True, False) + solver.add_partition(0, 0, 0.1, 10, 1, 1, 1, 1, 0, [[]]) + + solver.add_node(1, 1, [1], [0], 1, True, False, True) + solver.add_partition(1, 0, 0.2, 10, 2, 2, 2, 2, 0, [[0]]) + + solver.add_node(2, 2, [2], [1], 1, True, True, False) + solver.add_partition(2, 0, 0.3, 10, 3, 3, 3, 3, 0, [[0]]) + + solver.add_node(3, 3, [3], [2], 1, True, True, False) + solver.add_partition(3, 0, 0.4, 10, 4, 4, 4, 4, 0, [[0]]) + + solver.add_node(4, 4, [4], [3], 1, True, False, True) + solver.add_partition(4, 0, 0.5, 10, 5, 5, 5, 5, 0, [[0]]) + + # the total memory cost should be + # param: 10 - 1 + 10 - 2 + 10 - 3 + 10 - 4 + 10 - 5 = 35 + # buffer: 5 + 4 = 9 + # activation: 1 + 3 + 4 = 8 + # opt_transient_mem: 1 + 2 + 3 + 4 + 5 = 15 + # recompute: max(1 + 2, 3 + 4 + 5) = 12 + # in all: 35 + 9 + max(8, 15) + 12 = 71 + + solver.solve() + + ans = solver.get_results(0, 4) + + best = ans[0] + assert best.all_time == 1.5 + assert best.path == [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)] + assert best.memory == 71 From d1d123c0db035cdaae8e6e4b428b786354b93eee Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 8 Jul 2024 03:10:38 +0000 Subject: [PATCH 1676/1892] Merged PR 2198: Fix bugs in autodist 1. [PR](https://dev.azure.com/msrasrg/SuperScaler/_git/MagicCube/pullrequest/2185) ignores the case when profiling fails 2. dp solver bug 1: segment fault when `following candidates` is empty 2. dp solver bug 2: corner case, the new generated dp state can be illegal, need to check when adding it to new states tests added --- nnscaler/autodist/dp_solver.cpp | 14 +++++++--- nnscaler/autodist/dp_solver.h | 14 ++++++++++ nnscaler/profiler/database.py | 7 +++-- tests/autodist/test_dp_solver.py | 47 ++++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 7 deletions(-) diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index 5564c2fc..515b412d 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -267,8 +267,9 @@ class DPSolver { for (int idx = 0; idx < producer_comb_num; ++idx) { bool is_legal = true; int val = idx; - std::vector producer_ps(node->producer_num); // decode the producer partition combination + // continue if the partition states of producers are illegal + std::vector producer_ps(node->producer_num); for (int j = 0; j < node->producer_num; ++j) { int k = node->producer_num - 1 - j; producer_ps[k] = val % node->producers[k]->p_num; @@ -291,6 +292,7 @@ class DPSolver { if (!is_legal) { continue; } + // build the representation of the predecessor dp node // std::vector> cur_ir(node->cut_len - 1); bool has_found_follow = false; @@ -361,8 +363,9 @@ class DPSolver { // do nothing, means the pre_node's output is not used // we select the 1st partition of the pre_node // need to be careful when the graph has multiple outputs - // shall we constrain that the output of the graph is replicated? - cur_ir.push_back(*follow_candidates.rbegin()); + if (!has_found_follow && !follow_candidates.empty()) { + cur_ir.push_back(*follow_candidates.rbegin()); + } } else if (pre_node->father_id == pre_node->id) { assert(follow_candidates.rbegin()->first == pre_node->id); cur_ir.push_back(*follow_candidates.rbegin()); @@ -422,6 +425,9 @@ class DPSolver { std::priority_queue> pq; for (std::size_t i = 0; i < dp_node->in_edges.size(); ++i) { DPNode *pred = dp_node->in_edges[i].first; + if (pred->state.empty()) { + continue; + } UnitDPState pred_state = pred->state[0]; double transition_cost = dp_node->in_edges[i].second; UnitDPState new_state = @@ -445,7 +451,7 @@ class DPSolver { pq.push(std::make_tuple(new_state, pred_idx)); ++lows[pred_idx]; } - if (dp_node->state.empty()) { + if (dp_node->state.empty() && cur_state.total_mem <= mem_bound) { dp_node->state.push_back(cur_state); } else { UnitDPState pre_state = dp_node->state[dp_node->state.size() - 1]; diff --git a/nnscaler/autodist/dp_solver.h b/nnscaler/autodist/dp_solver.h index e6d5af3d..05bc2196 100644 --- a/nnscaler/autodist/dp_solver.h +++ b/nnscaler/autodist/dp_solver.h @@ -213,6 +213,20 @@ struct UnitDPState { pred_dp_node, pred_idx); } + std::string to_string() { + return "param_related_mem: " + std::to_string(param_related_mem) + + ", activation_mem: " + std::to_string(activation_mem) + + ", opt_transient_mem: " + std::to_string(opt_transient_mem) + + ", largest_transient_mem_1st: " + + std::to_string(largest_transient_mem_1st) + + ", largest_transient_mem_2nd: " + + std::to_string(largest_transient_mem_2nd) + + ", max_recompute_mem: " + std::to_string(max_recompute_mem) + + ", cur_recompute_mem: " + std::to_string(cur_recompute_mem) + + ", total_mem: " + std::to_string(total_mem) + + ", time_cost: " + std::to_string(time_cost); + } + // for priority queue, sort the state by total memory cost from small to large // if the total memory cost is the same, sort by time cost from small to large bool operator<(const UnitDPState &other) const { diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 7cf3cd37..d21c8c5b 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -292,14 +292,15 @@ def profile(self, node: IRFwOperation, override: bool = False) -> ProfiledMetric fn, shapes, dtypes, requires_grads, values, kwargs = get_func(node) - in_mem_info, param_mem_info, buffer_mem_info = [], [], [] - for t in node.inputs(): + in_mem_info, param_mem_info, buffer_mem_info, in_mem_idx = [], [], [], [] + for idx, t in enumerate(node.inputs()): if isinstance(t, IRTensor) and t.is_param(): param_mem_info.append(t.byte_size()) elif isinstance(t, IRTensor) and t.is_buffer(): buffer_mem_info.append(t.byte_size()) elif hasattr(t, 'byte_size'): in_mem_info.append(t.byte_size()) + in_mem_idx.append(idx) else: _logger.debug(f'node {node}: skip input {t}') @@ -316,7 +317,7 @@ def profile(self, node: IRFwOperation, override: bool = False) -> ProfiledMetric infer_memory += t.byte_size() # by default, we assume that all the input tensors are saved for backward train_mem_info = copy.deepcopy(in_mem_info) - train_mem2in_idx = list(range(len(in_mem_info))) + train_mem2in_idx = in_mem_idx profiled_metrics = ProfiledMetrics(in_mem_info, param_mem_info, buffer_mem_info, fw_span, bw_span, infer_memory, train_mem_info, train_mem2in_idx) diff --git a/tests/autodist/test_dp_solver.py b/tests/autodist/test_dp_solver.py index aad7d400..330e8630 100644 --- a/tests/autodist/test_dp_solver.py +++ b/tests/autodist/test_dp_solver.py @@ -70,3 +70,50 @@ def test_dp_solver_mem(): assert best.all_time == 1.5 assert best.path == [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)] assert best.memory == 71 + +def test_dp_solver_build_in_edges(): + # mock following code + # dropout_rate = self.attention_dropout if self.training else 0.0 + # attn_output = nnscaler_flash_attention_forward( + # query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, causal=causal + # ) + # 3 nodes will be generated, there are no following chains and tensors between them + # 1. self_getattr + # 2. ifexpr + # 3. nnscaler_flash_attention_forward + solver = dp_solver.DPSolver(True, 100, 1) + solver.add_interval(0, 2) + + solver.add_node(0, 0, [0], [], 1, False, False, False) + solver.add_partition(0, 0, 0, 0, 0, 0, 0, 0, 0, [[]]) + + solver.add_node(1, 1, [1], [], 1, False, False, False) + solver.add_partition(1, 0, 0, 0, 0, 0, 0, 0, 0, [[]]) + + solver.add_node(2, 2, [2], [], 1, False, False, False) + solver.add_partition(2, 0, 1, 0, 0, 0, 0, 0, 0, [[]]) + + solver.solve() + + ans = solver.get_results(0, 2) + + best = ans[0] + assert best.path == [(0, 0), (1, 0), (2, 0)] + +def test_dp_solver_mem_bound(): + solver = dp_solver.DPSolver(True, 10, 1) + solver.add_interval(0, 2) + + solver.add_node(0, 0, [0], [], 1, False, False, False) + solver.add_partition(0, 0, 0, 8, 0, 0, 0, 0, 0, [[]]) + + solver.add_node(1, 1, [1], [], 1, False, False, False) + solver.add_partition(1, 0, 0, 5, 0, 0, 0, 0, 0, [[]]) + + solver.add_node(2, 2, [2], [], 1, False, False, False) + solver.add_partition(2, 0, 1, 11, 0, 0, 0, 0, 0, [[]]) + + solver.solve() + + ans = solver.get_results(0, 2) + assert len(ans) == 0 From 42f64b1165c670b140fc177f17c7b66825eca866 Mon Sep 17 00:00:00 2001 From: "Xin Ji (CSI Interfusion Co Ltd)" Date: Thu, 11 Jul 2024 09:35:30 +0000 Subject: [PATCH 1677/1892] Merged PR 2169: support conv1d-2d support ConvTranspose1D,Conv2D,ConvTranspose2D --- nnscaler/graph/function/function.py | 328 +++++++++++++++++++++---- nnscaler/graph/parser/fx/mapping.py | 8 +- tests/graph/function/test_functions.py | 101 +++++++- tests/parallel_module/test_gencode.py | 131 ++++++++++ 4 files changed, 511 insertions(+), 57 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index d2bff238..5979b4d2 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -12,7 +12,7 @@ from nnscaler.ir.tensor import IRSubTensor, IRFullTensor from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule -from nnscaler.graph.function.conv import IRPad, IRConv2D, IRConv3D +from nnscaler.graph.function.conv import IRPad, IRConv3D from nnscaler.graph.function.anchor import IRGraphAnchor _logger = logging.getLogger(__name__) @@ -1580,19 +1580,6 @@ def Pad(input, pad, mode='constant', value=0.0, signature = None): # stride=stride, padding=padding, dilation=dilation, groups=groups) -def Conv2D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, signature = None): - """ - torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) - """ - if isinstance(padding, int): - padding = [padding] * 4 - elif len(padding) == 2: - padH, padW = padding - padding = [padH, padH, padW, padW] - return IRConv2D(signature, [input, weight, bias], 'conv2d', - stride=stride, padding=padding, dilation=dilation, groups=groups) - - def Conv3D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, signature = None): """ torch.nn.functional.conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) @@ -2662,42 +2649,50 @@ def Erf(input, *, out=None, signature=None): return IRDimops(Erf, 'erf', signature, annos, [input]) +def unwrap_if_irobject(x): + return x.value if isinstance(x, IRObject) else x + + def Conv1D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, signature=None): """ torch.nn.functional.conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor """ - if isinstance(stride, int): - stride = (stride,) - if isinstance(dilation, int): - dilation = (dilation,) - if isinstance(padding, str): - if padding == 'same': + if len(input.shape) not in [2, 3]: + raise ValueError(f"Expected input tensor to have 2 or 3 dimensions, but got {input.shape}") + stride_val = unwrap_if_irobject(stride) + padding_val = unwrap_if_irobject(padding) + dilation_val = unwrap_if_irobject(dilation) + groups_val = unwrap_if_irobject(groups) + if isinstance(stride_val, int): + stride_val = (stride_val,) + if isinstance(dilation_val, int): + dilation_val = (dilation_val,) + kW = weight.shape[-1] + effective_kernel_size = (kW - 1) * dilation_val[0] + if isinstance(padding_val, str): + if padding_val == 'same': # For 'same' padding, calculate padding needed to keep the output shape the same as input shape # this mode doesn’t support any stride values other than 1. - kW = weight.shape[2] - iW = input.shape[2] - effective_kernel_size = (kW - 1) * dilation[0] + 1 - total_padding = max(0, (iW - 1) * stride[0] + effective_kernel_size - iW) + iW = input.shape[-1] + total_padding = (iW - 1) * stride_val[0] + effective_kernel_size + 1 - iW pad_ = total_padding // 2 # NOTE: While we calculate padding for both sides, conv1d expects a single integer for symmetrical padding. - padding = (pad_, ) - elif padding == 'valid': - padding = (0, ) + padding_val = (pad_, ) + elif padding_val == 'valid': + padding_val = (0, ) else: - raise ValueError("Unsupported padding value: {}. Use 'valid', 'same', or an integer.".format(padding)) - elif isinstance(padding, int): - padding = (padding,) - elif not isinstance(padding, tuple): + raise ValueError("Unsupported padding value: {}. Use 'valid', 'same', or an integer.".format(padding_val)) + elif isinstance(padding_val, int): + padding_val = (padding_val,) + elif not isinstance(padding_val, tuple): raise ValueError("Padding must be a string ('valid', 'same'), an integer, or a tuple") - ori_groups = groups - if isinstance(groups, IRObject): groups = groups.value - _, iW = input.shape[1:3] - oC, iC, kW = weight.shape - oW = (iW + 2 * padding[0] - dilation[0] * (kW - 1) - 1) // stride[0] + 1 - if input.shape[1] // groups != weight.shape[1]: - raise ValueError(f'Input shape and weight shape are not compatible for the number of groups. input shape: {input.shape}, weight shape: {weight.shape}, groups: {groups}') - if weight.shape[0] % groups != 0: + iC, iW = input.shape[-2:] + oC, iCg, kW = weight.shape + oW = (iW + 2 * padding_val[0] - effective_kernel_size - 1) // stride_val[0] + 1 + if iC // groups_val != iCg: + raise ValueError(f'Input shape and weight shape are not compatible for the number of groups. input shape: {input.shape}, weight shape: {weight.shape}, groups: {groups_val}') + if oC % groups_val != 0: raise ValueError('The output channels of weight must be divisible by the number of groups.') def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: # only for partitioning groups @@ -2708,24 +2703,249 @@ def modifier(kwargs: Dict, idx, dim, num: int) -> Dict: kw_groups = kw_groups.value kwargs['groups'] = kw_groups // num return kwargs - if bias is None: - # NOTE: cannot support partitioning inchannel when groups>1 - if groups == 1: - annos = [f'n iC+ {iW}, oC iC+ {kW} -> n oC {oW}'] - rules = None + if len(input.shape) == 2: + if bias is None: + if groups_val == 1: + annos = [f'iC+ {iW}, oC iC+ {kW} -> oC {oW}'] + rules = None + else: + rules = [TransformRule([DimopSplit.D(0), DimopSplit.D(0)], [DimopSplit.D(0)], modifier)] + annos = [f'(g {iCg}) {iW}, (g {oC // groups_val}) {iCg} {kW} -> (g {oC // groups_val}) {oW}'] else: - rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0)], [DimopSplit.D(1)], modifier)] - annos = [f'n (g {iC}) {iW}, (g {oC//groups}) {iC} {kW} -> n (g {oC//groups}) {oW}'] - else: - # NOTE: not supported value partition of bias yet - if groups == 1: - annos = [f'n iC^ {iW}, oC iC^ {kW}, oC -> n oC {oW}'] - rules = None + if groups_val == 1: + annos = [f'iC^ {iW}, oC iC^ {kW}, oC -> oC {oW}'] + rules = None + else: + rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0), DimopSplit.D(0)], [DimopSplit.D(0)], modifier)] + annos = [f'(g {iCg}) {iW}, (g {oC // groups_val}) {iCg} {kW}, (g {oC // groups_val}) -> (g {oC // groups_val}) {oW}'] + elif len(input.shape) == 3: + if bias is None: + # NOTE: cannot support partitioning inchannel when groups>1 + if groups_val == 1: + annos = [f'n iC+ {iW}, oC iC+ {kW} -> n oC {oW}'] + rules = None + else: + rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0)], [DimopSplit.D(1)], modifier)] + annos = [f'n (g {iCg}) {iW}, (g {oC//groups_val}) {iCg} {kW} -> n (g {oC//groups_val}) {oW}'] else: - rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0), DimopSplit.D(0)], [DimopSplit.D(1)], modifier)] - annos = [f'n (g {iC}) {iW}, (g {oC//groups}) {iC} {kW}, (g {oC//groups}) -> n (g {oC//groups}) {oW}'] + # NOTE: not supported value partition of bias yet + if groups_val == 1: + annos = [f'n iC^ {iW}, oC iC^ {kW}, oC -> n oC {oW}'] + rules = None + else: + rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0), DimopSplit.D(0)], [DimopSplit.D(1)], modifier)] + annos = [f'n (g {iCg}) {iW}, (g {oC//groups_val}) {iCg} {kW}, (g {oC//groups_val}) -> n (g {oC//groups_val}) {oW}'] return IRDimops(Conv1D, 'conv1d', signature, annos, [input, weight, bias] if bias is not None else [input, weight], rules, - stride=stride, padding=padding, dilation=dilation, groups=ori_groups) + stride=stride, padding=padding, dilation=dilation, groups=groups) + + +def ConvTranspose1D(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, signature=None): + """ + torch.nn.functional.conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) + """ + if len(input.shape) not in [2, 3]: + raise ValueError(f"Expected input tensor to have 2 or 3 dimensions, but got {input.shape}") + stride_val = unwrap_if_irobject(stride) + padding_val = unwrap_if_irobject(padding) + output_padding_val = unwrap_if_irobject(output_padding) + dilation_val = unwrap_if_irobject(dilation) + groups_val = unwrap_if_irobject(groups) + if isinstance(stride_val, int): + stride_val = (stride_val,) + if isinstance(padding_val, int): + padding_val = (padding_val,) + if isinstance(output_padding_val, int): + output_padding_val = (output_padding_val,) + if isinstance(dilation_val, int): + dilation_val = (dilation_val,) + if not (len(stride_val) == 1 and len(padding_val) == 1 and len(output_padding_val) == 1 and len(dilation_val) == 1): + raise ValueError("stride, padding, output_padding, and dilation must have a length of 1") + if weight.shape[1] % groups_val != 0: + raise ValueError(f'Weight output channels must be divisible by groups. weight output channels: {weight.shape[1]}, groups: {groups_val}') + if input.shape[-2] != weight.shape[0]: + raise ValueError(f'Input channels and weight input channels must be the same. input channels: {input.shape[-2]}, weight input channels: {weight.shape[0]}') + if input.shape[-2] % groups_val != 0 or weight.shape[0] % groups_val != 0: + raise ValueError(f'Input shape and groups are not compatible. input shape: {input.shape}, weight shape: {weight.shape}, groups: {groups_val}') + iW = input.shape[-1] + kW = weight.shape[2] + oW = (iW - 1) * stride_val[0] - 2 * padding_val[0] + dilation_val[0] * (kW - 1) + output_padding_val[0] + 1 + # iC+ represents the merging of input channels + # Example: If the input is (batch_size, 3, 32), with three input channels + # Partition: The 3 input channels can be logically divided into 3 subsets (each subset contains 1 channel). + # In the convolution calculation, these three subsets are combined into a whole for processing, and the output result is a new feature graph. + if len(input.shape) == 2: + if bias is None: + annos = [f'iC+ {iW}, iC+ oC {kW} -> oC {oW}'] if groups_val == 1 else \ + [f'(groups group_size^) {iW}, (groups group_size^) oC {kW} -> (groups oC) {oW}'] + return IRDimops(ConvTranspose1D, 'conv_transpose1d', signature, annos, [input, weight], + bias=None, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + else: + annos = [f'iC+ {iW}, iC+ oC {kW}, oC -> oC {oW}'] if groups_val == 1 else \ + [f'(groups group_size^) {iW}, (groups group_size^) oC {kW}, oC -> (groups oC) {oW}'] + return IRDimops(ConvTranspose1D, 'conv_transpose1d', signature, annos, [input, weight, bias], + stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + if len(input.shape) == 3: + if bias is None: + annos = [f'n iC+ {iW}, iC+ oC {kW} -> n oC {oW}'] if groups_val == 1 else \ + [f'n (groups group_size^) {iW}, (groups group_size^) oC {kW} -> n (groups oC) {oW}'] + return IRDimops(ConvTranspose1D, 'conv_transpose1d', signature, annos, [input, weight], + bias=None, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + else: + annos = [f'n iC+ {iW}, iC+ oC {kW}, oC -> n oC {oW}'] if groups_val == 1 else \ + [f'n (groups group_size^) {iW}, (groups group_size^) oC {kW}, oC -> n (groups oC) {oW}'] + return IRDimops(ConvTranspose1D, 'conv_transpose1d', signature, annos, [input, weight, bias], + stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + + +def Conv2D(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, signature=None): + """ + torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) + + NOTE: the helo-exchange partitioning is supported in IRConv2D + TODO: partitioning groups or iC+ is possible, but need full fledged implementation of the annotation + """ + if len(input.shape) not in [3, 4]: + raise ValueError(f"Expected input tensor to have 3 or 4 dimensions, but got {input.shape}") + stride_val = unwrap_if_irobject(stride) + padding_val = unwrap_if_irobject(padding) + dilation_val = unwrap_if_irobject(dilation) + groups_val = unwrap_if_irobject(groups) + if isinstance(stride_val, int): + stride_val = (stride_val, stride_val) + if isinstance(dilation_val, int): + dilation_val = (dilation_val, dilation_val) + if isinstance(padding_val, str): + if padding_val == 'same': + kH, kW = weight.shape[2:4] + iH, iW = input.shape[-2:] + effective_kernel_size_h = (kH - 1) * dilation_val[0] + 1 + effective_kernel_size_w = (kW - 1) * dilation_val[1] + 1 + total_padding_h = (iH - 1) * stride_val[0] + effective_kernel_size_h - iH + total_padding_w = (iW - 1) * stride_val[1] + effective_kernel_size_w - iW + pad_h = total_padding_h // 2 + pad_w = total_padding_w // 2 + padding_val = (pad_h, pad_w) + elif padding_val == 'valid': + padding_val = (0, 0) + else: + raise ValueError("Unsupported padding value: {}. Use 'valid', 'same', or an integer.".format(padding_val)) + elif isinstance(padding_val, int): + padding_val = (padding_val, padding_val) + elif not isinstance(padding_val, tuple): + raise ValueError("Padding must be a string ('valid', 'same'), an integer, or a tuple") + iC, iH, iW = input.shape[-3:] + oC, iCg, kH, kW = weight.shape + oH = (iH + 2 * padding_val[0] - dilation_val[0] * (kH - 1) - 1) // stride_val[0] + 1 + oW = (iW + 2 * padding_val[1] - dilation_val[1] * (kW - 1) - 1) // stride_val[1] + 1 + + if iC // groups_val != iCg: + raise ValueError(f'Input shape and weight shape are not compatible for the number of groups. input shape: {input.shape}, weight shape: {weight.shape}, groups: {groups_val}') + if oC % groups_val != 0: + raise ValueError('The output channels of weight must be divisible by the number of groups.') + + def modifier(kwargs: dict, idx, dim, num: int) -> dict: + # only for partitioning groups + kwargs = dict(**kwargs) + kw_groups = kwargs['groups'] + if isinstance(kw_groups, IRObject): + kw_groups = kw_groups.value + kwargs['groups'] = kw_groups // num + return kwargs + + if len(input.shape) == 3: + if bias is None: + if groups_val == 1: + annos = [f'iC+ {iH} {iW}, oC iC+ {kH} {kW} -> oC {oH} {oW}'] + rules = None + else: + # NOTE: g can be partitioned only when rules are provided + annos = [f'(g {iCg}) {iH} {iW}, (g {oC // groups_val}) {iCg} {kH} {kW} -> (g {oC // groups_val}) {oH} {oW}'] + rules = [TransformRule([DimopSplit.D(0), DimopSplit.D(0)], [DimopSplit.D(0)], modifier)] + else: + # NOTE: not supported value partition of bias yet + if groups_val == 1: + annos = [f'iC^ {iH} {iW}, oC iC^ {kH} {kW}, oC -> oC {oH} {oW}'] + rules = None + else: + annos = [f'(g {iCg}) {iH} {iW}, (g {oC // groups_val}) {iCg} {kH} {kW}, (g {oC // groups_val}) -> (g {oC // groups_val}) {oH} {oW}'] + rules = [TransformRule([DimopSplit.D(0), DimopSplit.D(0), DimopSplit.D(0)], [DimopSplit.D(0)], modifier)] + elif len(input.shape) == 4: + if bias is None: + if groups_val == 1: + annos = [f'n iC+ {iH} {iW}, oC iC+ {kH} {kW} -> n oC {oH} {oW}'] + rules = None + else: + # NOTE: g can be partitioned only when rules are provided + annos = [f'n (g {iCg}) {iH} {iW}, (g {oC // groups_val}) {iCg} {kH} {kW} -> n (g {oC // groups_val}) {oH} {oW}'] + rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0)], [DimopSplit.D(1)], modifier)] + else: + # NOTE: not supported value partition of bias yet + if groups_val == 1: + annos = [f'n iC^ {iH} {iW}, oC iC^ {kH} {kW}, oC -> n oC {oH} {oW}'] + rules = None + else: + annos = [f'n (g {iCg}) {iH} {iW}, (g {oC // groups_val}) {iCg} {kH} {kW}, (g {oC // groups_val}) -> n (g {oC // groups_val}) {oH} {oW}'] + rules = [TransformRule([DimopSplit.D(1), DimopSplit.D(0), DimopSplit.D(0)], [DimopSplit.D(1)], modifier)] + + return IRDimops(Conv2D, 'conv2d', signature, annos, [input, weight, bias] if bias is not None else [input, weight], rules, + stride=stride, padding=padding, dilation=dilation, groups=groups) + + +def ConvTranspose2D(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, signature = None): + """ + torch.nn.functional.conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) + """ + if len(input.shape) not in [3, 4]: + raise ValueError(f"Expected input tensor to have 3 or 4 dimensions, but got {input.shape}") + stride_val = unwrap_if_irobject(stride) + padding_val = unwrap_if_irobject(padding) + output_padding_val = unwrap_if_irobject(output_padding) + dilation_val = unwrap_if_irobject(dilation) + groups_val = unwrap_if_irobject(groups) + if isinstance(stride_val, int): + stride_val = (stride_val, stride_val) + if isinstance(padding_val, int): + padding_val = (padding_val, padding_val) + if isinstance(output_padding_val, int): + output_padding_val = (output_padding_val, output_padding_val) + if isinstance(dilation_val, int): + dilation_val = (dilation_val, dilation_val) + if not (len(stride_val) == 2 and len(padding_val) == 2 and len(output_padding_val) == 2 and len(dilation_val) == 2): + raise ValueError("stride, padding, output_padding, and dilation must have a length of 2") + iH, iW = input.shape[-2:] + kH, kW = weight.shape[2:4] + oH = (iH - 1) * stride_val[0] - 2 * padding_val[0] + dilation_val[0] * (kH - 1) + output_padding_val[0] + 1 + oW = (iW - 1) * stride_val[1] - 2 * padding_val[1] + dilation_val[1] * (kW - 1) + output_padding_val[1] + 1 + if input.shape[-3] != weight.shape[0]: + raise ValueError(f'Input channels and weight input channels must be the same. input channels: {input.shape[-3]}, weight input channels: {weight.shape[0]}') + if input.shape[-3] % groups_val != 0: + raise ValueError(f'Input shape and groups are not compatible. input shape: {input.shape}, groups: {groups_val}') + if weight.shape[0] % groups_val != 0: + raise ValueError(f'Weight shape and groups are not compatible. weight shape: {weight.shape}, groups: {groups_val}') + # FIXME: inchannel is reduction dim or outchannel? + # iC+ represents the merging of input channels + if len(input.shape) == 3: + if bias is None: + annos = [f'iC+ {iH} {iW}, iC+ oC {kH} {kW} -> oC {oH} {oW}'] if groups_val == 1 else \ + [f'(groups group_size^) {iH} {iW}, (groups group_size^) oC {kH} {kW} -> (groups oC) {oH} {oW}'] + return IRDimops(ConvTranspose2D, 'conv_transpose2d', signature, annos, [input, weight], + bias=None, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + else: + annos = [f'iC+ {iH} {iW}, iC+ oC {kH} {kW}, oC -> oC {oH} {oW}'] if groups_val == 1 else \ + [f'(groups group_size^) {iH} {iW}, (groups group_size^) oC {kH} {kW}, oC -> (groups oC) {oH} {oW}'] + return IRDimops(ConvTranspose2D, 'conv_transpose2d', signature, annos, [input, weight, bias], + stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + if len(input.shape) == 4: + if bias is None: + annos = [f'n iC+ {iH} {iW}, iC+ oC {kH} {kW} -> n oC {oH} {oW}'] if groups_val == 1 else \ + [f'n (groups group_size^) {iH} {iW}, (groups group_size^) oC {kH} {kW} -> n (groups oC) {oH} {oW}'] + return IRDimops(ConvTranspose2D, 'conv_transpose2d', signature, annos, [input, weight], + bias=None, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + else: + annos = [f'n iC+ {iH} {iW}, iC+ oC {kH} {kW}, oC -> n oC {oH} {oW}'] if groups_val == 1 else \ + [f'n (groups group_size^) {iH} {iW}, (groups group_size^) oC {kH} {kW}, oC -> n (groups oC) {oH} {oW}'] + return IRDimops(ConvTranspose2D, 'conv_transpose2d', signature, annos, [input, weight, bias], + stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) def SVD(input, some=True, compute_uv=True, *, out=None, signature=None): diff --git a/nnscaler/graph/parser/fx/mapping.py b/nnscaler/graph/parser/fx/mapping.py index 8efc5b2e..92a3255d 100644 --- a/nnscaler/graph/parser/fx/mapping.py +++ b/nnscaler/graph/parser/fx/mapping.py @@ -222,8 +222,14 @@ def exist(signature: str) -> bool: __ttemplate('reshape'): function.Reshape, __ttemplate('conv1d'): function.Conv1D, + __ftemplate('conv1d'): function.Conv1D, + __ttemplate('conv_transpose1d'): function.ConvTranspose1D, + __ftemplate('conv_transpose1d'): function.ConvTranspose1D, # - # __ttemplate('conv2d'): function.Conv2D, + __ttemplate('conv2d'): function.Conv2D, + __ftemplate('conv2d'): function.Conv2D, + __ttemplate('conv_transpose2d'): function.ConvTranspose2D, + __ftemplate('conv_transpose2d'): function.ConvTranspose2D, # # __ttemplate('conv3d'): function.Conv3D, # diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 39852909..0072aac8 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -597,27 +597,36 @@ def test_Softmax(): def test_Conv1D(): + op = F.Conv1D(IRTensor([3, 4]), IRTensor([3, 3, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC+ 4, oC iC+ 1 -> oC 4' + op = F.Conv1D(IRTensor([3, 4]), IRTensor([3, 3, 1]), groups=1,padding="valid") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC+ 4, oC iC+ 1 -> oC 4' + op = F.Conv1D(input=IRTensor([8, 32]), weight=IRTensor([16, 8, 3]), bias=IRObject(value=16),groups=1,padding="same") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC^ 32, oC iC^ 3, oC -> oC 32' op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 4' op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), stride=2) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 2' op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), padding=1) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 6' + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), padding=IRObject(value=1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 6' op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), dilation=2) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 4' op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), groups=1) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 4' op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), groups=1,padding="valid") assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, oC iC+ 1 -> n oC 4' - op = F.Conv1D(input=IRTensor([4, 8, 32]), weight=IRTensor([16, 8, 3]), bias=IRTensor([16,]),groups=1,padding="same") + op = F.Conv1D(input=IRTensor([4, 8, 32]), weight=IRTensor([16, 8, 3]), bias=IRObject(value=16),groups=1,padding="same") assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC^ 32, oC iC^ 3, oC -> n oC 32' op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 1, 1]), groups=3) expected_annotation_for_groups = 'n (g 1) 4, (g 1) 1 1 -> n (g 1) 4' assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation_for_groups - op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), bias=IRTensor([3])) + op = F.Conv1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), bias=IRObject(value=3)) assert op._annos_candidates[0] == 'n iC^ 4, oC iC^ 1, oC -> n oC 4', "Annotation mismatch." + def test_Arange(): op = F.Arange(10) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 10^' and op.kwargs['dtype'] == torch.int64 @@ -759,3 +768,91 @@ def test_Diag(): op = F.Diag(IRTensor([5, 10]), -1) assert op._annos_candidates[0] == '5 10 -> 4' + + +def test_Conv2D(): + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, oC iC+ 1 1 -> n oC 4 4' + op = F.Conv2D(IRTensor([3, 4, 4]), IRTensor([3, 3, 1, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC+ 4 4, oC iC+ 1 1 -> oC 4 4' + op = F.Conv2D(IRTensor([3, 4, 4]), IRTensor([3, 3, 1, 1]), groups=1, padding="valid") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC+ 4 4, oC iC+ 1 1 -> oC 4 4' + op = F.Conv2D(input=IRTensor([8, 32, 32]), weight=IRTensor([16, 8, 3, 3]), bias=IRObject(value=16), groups=1, padding="same") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC^ 32 32, oC iC^ 3 3, oC -> oC 32 32' + op = F.Conv2D(input=IRTensor([8, 32, 32]), weight=IRTensor([16, 4, 3, 3]), bias=IRObject(value=16), groups=2, padding="same") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == '(g 4) 32 32, (g 8) 4 3 3, (g 8) -> (g 8) 32 32' + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), stride=2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, oC iC+ 1 1 -> n oC 2 2' + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), padding=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, oC iC+ 1 1 -> n oC 6 6' + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), padding=IRObject(value=1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, oC iC+ 1 1 -> n oC 6 6' + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), dilation=2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, oC iC+ 1 1 -> n oC 4 4' + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), groups=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, oC iC+ 1 1 -> n oC 4 4' + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), groups=1, padding="valid") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, oC iC+ 1 1 -> n oC 4 4' + op = F.Conv2D(input=IRTensor([4, 8, 32, 32]), weight=IRTensor([16, 8, 3, 3]), bias=IRObject(value=16), groups=1, padding="same") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC^ 32 32, oC iC^ 3 3, oC -> n oC 32 32' + op = F.Conv2D(input=IRTensor([4, 8, 32, 32]), weight=IRTensor([16, 4, 3, 3]), bias=IRObject(value=16), groups=2, padding="same") + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n (g 4) 32 32, (g 8) 4 3 3, (g 8) -> n (g 8) 32 32' + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 1, 1, 1]), groups=3) + expected_annotation_for_groups = 'n (g 1) 4 4, (g 1) 1 1 1 -> n (g 1) 4 4' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation_for_groups + op = F.Conv2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), bias=IRObject(value=3)) + assert op._annos_candidates[0] == 'n iC^ 4 4, oC iC^ 1 1, oC -> n oC 4 4', "Annotation mismatch." + + +def test_ConvTranspose2D(): + op = F.ConvTranspose2D(IRTensor([3, 4, 4]), IRTensor([3, 3, 1, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC+ 4 4, iC+ oC 1 1 -> oC 4 4' + op = F.ConvTranspose2D(IRTensor([3, 4, 4]), IRTensor([3, 3, 1, 1]), bias=IRObject(value=3)) + assert op._annos_candidates[0] == 'iC+ 4 4, iC+ oC 1 1, oC -> oC 4 4', "Annotation mismatch." + op = F.ConvTranspose2D(IRTensor([3, 4, 4]), IRTensor([3, 3, 1, 1]), padding=IRObject(value=1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC+ 4 4, iC+ oC 1 1 -> oC 2 2' + op = F.ConvTranspose2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, iC+ oC 1 1 -> n oC 4 4' + op = F.ConvTranspose2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), stride=2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, iC+ oC 1 1 -> n oC 7 7' + op = F.ConvTranspose2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), padding=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, iC+ oC 1 1 -> n oC 2 2' + op = F.ConvTranspose2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), padding=IRObject(value=1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, iC+ oC 1 1 -> n oC 2 2' + op = F.ConvTranspose2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), dilation=2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, iC+ oC 1 1 -> n oC 4 4' + op = F.ConvTranspose2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), groups=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4 4, iC+ oC 1 1 -> n oC 4 4' + op = F.ConvTranspose2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 1, 1, 1]), groups=3) + expected_annotation_for_groups = 'n (groups group_size^) 4 4, (groups group_size^) oC 1 1 -> n (groups oC) 4 4' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation_for_groups + op = F.ConvTranspose2D(IRTensor([2, 3, 4, 4]), IRTensor([3, 3, 1, 1]), bias=IRObject(value=3)) + assert op._annos_candidates[0] == 'n iC+ 4 4, iC+ oC 1 1, oC -> n oC 4 4', "Annotation mismatch." + + +def test_ConvTranspose1D(): + op = F.ConvTranspose1D(IRTensor([3, 4]), IRTensor([3, 3, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC+ 4, iC+ oC 1 -> oC 4' + op = F.ConvTranspose1D(IRTensor([3, 4]), IRTensor([3, 3, 1]), groups=3) + expected_annotation_for_groups = '(groups group_size^) 4, (groups group_size^) oC 1 -> (groups oC) 4' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation_for_groups + op = F.ConvTranspose1D(IRTensor([3, 4]), IRTensor([3, 3, 1]), bias=IRObject(value=3)) + assert op._annos_candidates[0] == 'iC+ 4, iC+ oC 1, oC -> oC 4', "Annotation mismatch." + op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, iC+ oC 1 -> n oC 4' + op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), stride=2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, iC+ oC 1 -> n oC 7' + op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), padding=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, iC+ oC 1 -> n oC 2' + op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), padding=IRObject(value=1)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, iC+ oC 1 -> n oC 2' + op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), dilation=2) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, iC+ oC 1 -> n oC 4' + op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), groups=1) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'n iC+ 4, iC+ oC 1 -> n oC 4' + op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), groups=3) + expected_annotation_for_groups = 'n (groups group_size^) 4, (groups group_size^) oC 1 -> n (groups oC) 4' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation_for_groups + op = F.ConvTranspose1D(IRTensor([2, 3, 4]), IRTensor([3, 3, 1]), bias=IRObject(value=3)) + assert op._annos_candidates[0] == 'n iC+ 4, iC+ oC 1, oC -> n oC 4', "Annotation mismatch." + diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index f186d26f..d6eefb4c 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -977,3 +977,134 @@ def test_codegen_kwargs(tmp_path): assert _gencode_contains(tmp_path, KwargsModule, 0, r"torch.zeros_like\(x_\d+, requires_grad=False, dtype=torch.float32\)" ) + + +class ConvTranspose1DModule(torch.nn.Module): + def __init__(self, weight, bias=None, stride=1, padding=0, output_padding=0, dilation=1, groups=1): + super().__init__() + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + self.stride = stride + self.padding = padding + self.output_padding = output_padding + self.dilation = dilation + self.groups = groups + + def forward(self, input, **kwargs): + groups = kwargs.get('groups', self.groups) + return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias, self.stride, self.padding, self.output_padding, groups, self.dilation) + + +def _gencode_conv_transpose1d_function(tempdir): + init_distributed() + weight = torch.randn(3, 3, 3) + bias = torch.randn(3) + m_new = parallelize( + ConvTranspose1DModule(weight, bias), + { + 'input': torch.randn(2, 3, 4), + 'groups': 1, + }, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=True + ) + assert m_new is not None + args = inspect.signature(m_new._forward_impl).parameters + assert len(args) == 2 + assert args['input'].default is inspect.Parameter.empty, "Expected 'input' to have no default value" + assert args['kwargs'].default == inspect.Parameter.empty, "Expected 'kwargs' to have no default value" + + input_tensor = torch.randn(2, 3, 4) + model = ConvTranspose1DModule(weight, bias) + expected_output = model(input_tensor, groups=1) + actual_output = m_new(input_tensor, groups=1) + assert torch.allclose(actual_output, expected_output, atol=1e-6), "Expected the output of ConvTranspose1DModule to match the expected output" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of GPU devices') +def test_codegen_conv_transpose1d(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, _gencode_conv_transpose1d_function, tempdir) + + +class Conv2DModule(torch.nn.Module): + def __init__(self, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + super().__init__() + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + + def forward(self, input, **kwargs): + groups = kwargs.get('groups', self.groups) + return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, groups) + + +def _gencode_conv2d_function(tempdir): + init_distributed() + weight = torch.randn(3, 3, 3, 3) + bias = torch.randn(3) + m_new = parallelize( + Conv2DModule(weight, bias), + { + 'input': torch.randn(2, 3, 32, 32), + 'groups': 1, + }, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=True + ) + assert m_new is not None + args = inspect.signature(m_new._forward_impl).parameters + assert len(args) == 2 + assert args['input'].default is inspect.Parameter.empty, "Expected 'input' to have no default value" + assert args['kwargs'].default == inspect.Parameter.empty, "Expected 'kwargs' to have no default value" + input_tensor = torch.randn(2, 3, 32, 32) + model = Conv2DModule(weight, bias) + expected_output = model(input_tensor, groups=1) + actual_output = m_new(input_tensor, groups=1) + assert torch.allclose(actual_output, expected_output, atol=1e-6), "Expected the output of Conv2DModule to match the expected output" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of GPU devices') +def test_codegen_conv2d(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, _gencode_conv2d_function, tempdir) + + +def _gencode_conv2d_function_(tempdir): + init_distributed() + weight = torch.randn(6, 3, 3, 3) + bias = torch.randn(6) + m_new = parallelize( + Conv2DModule(weight, bias, groups=2), + { + 'input': torch.randn(2, 6, 32, 32), + 'groups': 2, + }, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=True + ) + assert m_new is not None + args = inspect.signature(m_new._forward_impl).parameters + assert len(args) == 2 + assert args['input'].default is inspect.Parameter.empty, "Expected 'input' to have no default value" + assert args['kwargs'].default == inspect.Parameter.empty, "Expected 'kwargs' to have no default value" + input_tensor = torch.randn(2, 6, 32, 32) + model = Conv2DModule(weight, bias, groups=2) + expected_output = model(input_tensor, groups=2) + actual_output = m_new(input_tensor, groups=2) + assert torch.allclose(actual_output, expected_output, atol=1e-6), "Expected the output of Conv2DModule to match the expected output" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of GPU devices') +def test_codegen_conv2d_groups(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(1, _gencode_conv2d_function_, tempdir) From d82882e3eb8eebed7ed4982db55356033fec095d Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 12 Jul 2024 05:51:18 +0000 Subject: [PATCH 1678/1892] Merged PR 2200: TensorBase adaption in torch>=2.3 `_TensorBase ` rename to `TensorBase ` in torch 2.3, this will affect the functions call like `torch.Tensor.view(t, (1, -1))`, this pr fix this issue. --- .../graph/parser/fx/concrete_trace_utils/concrete_tracer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index fc9fd1d4..41613bc0 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -1053,7 +1053,10 @@ def torch_assert_wrapper(condition, message): else: wrapped = _create_wrapped_leaf_func(self, func, inner_func) else: - if func.__qualname__.startswith('_TensorBase'): + # for example, torch.Tensor.view.__qualname__ is 'TensorBase.view', + # should also add the location `Location(torch.Tensor, func.__name__)` for these methods. + # NOTE: `_TensorBase` is renamed to `TensorBase` in the latest pytorch version. + if func.__qualname__.startswith('_TensorBase') or func.__qualname__.startswith('TensorBase'): locations = (*locations, Location(torch.Tensor, func.__name__)) wrapped = _create_wrapped_leaf_method(self, getattr(torch.Tensor, func.__name__), func.__name__, wrap_info.replace_fn) elif func.__qualname__.startswith('_VariableFunctionsClass'): From 8ee468b28d727b98c81341bfe0ba111c16f1cd07 Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Fri, 12 Jul 2024 10:58:04 +0000 Subject: [PATCH 1679/1892] Merged PR 2197: update doc for v0.1 update doc for v0.1 --- docs/source/faq.rst | 50 +++++++++++++++++++++ docs/source/images/overview.png | Bin 0 -> 148362 bytes docs/source/index.rst | 3 ++ docs/source/quickstart.rst | 33 +++++++------- docs/source/readme.rst | 76 ++++++++++++++++++++++++++++++++ 5 files changed, 147 insertions(+), 15 deletions(-) create mode 100644 docs/source/faq.rst create mode 100644 docs/source/images/overview.png create mode 100644 docs/source/readme.rst diff --git a/docs/source/faq.rst b/docs/source/faq.rst new file mode 100644 index 00000000..0b40ce88 --- /dev/null +++ b/docs/source/faq.rst @@ -0,0 +1,50 @@ +Frequent asked questions +------------------------ + +**What is nnScaler?** + +The nnScaler is a system that takes a DNN model that is designed for a single device, e.g., GPU, and automatically converts it into a program that can execute concurrently on multiple devices. + +**What can nnScaler do?** + +Under the hood, nnScaler analyzes the given DNN models, plans for appropriate parallelization strategies, and generates corresponding execution code. With nnScaler, users can focus on single-device DNN model design, offload the complex parallelization work to nnScaler, and easily achieve high-performance parallel DNN execution. + +**What is/are nnScaler’s intended use(s)?** + +Due to high compatibility and extensibility, nnScaler can be used for the innovation of a wide range of new DNN models and DNN systems, including new model structures, training patterns, as well as new parallelization techniques that go beyond existing data-parallelism, tensor-parallelism, or pipeline parallelism. + +**How was nnScaler evaluated? What metrics are used to measure performance?** + +For execution performance, nnScaler can support new parallelisms that outperform existing parallel execution approaches: +1. Fitting larger DNN models given the same hardware. +2. Providing faster execution for the same model on the same hardware (included in our OSDI’24 paper). + +For compatibility, nnScaler can support paralleling new DNN models by providing user-defined functions (a few lines of code) for the new operators unrecognized by the nnScaler. + +**What are the limitations of nnScaler? How can users minimize the impact of nnScaler’s limitations when using the system?** + +- Certain DNN model architectures or execution patterns may violate the assumptions of nnScaler and, therefore, cannot be supported by nnScaler. +- The nnScaler does not guarantee the optimality of parallelization, so it is possible for nnScaler to miss the optimal parallelization strategy given DNN model and device settings, while only providing suboptimal solutions. +- Despite our best efforts to ensure the parallelization process is correct, it is possible for nnScaler to generate parallelized programs for concurrent execution that are inconsistent with the original DNN model for a single device. + +**What operational factors and settings allow for effective and responsible use of nnScaler?** + +- We provide documentation to guide users in the usage of the nnScaler. +- We provide parallelization examples that users can directly leverage for parallel execution if they intend to execute the same DNN models. +- We also provide certain cases of customization, including reconfiguring the device settings, adopting new DNN models in nnScaler, and supporting customized operators. + +**What are extensions(plugins) in nnScaler and how does nnScaler use them?** + +The nnScaler supports the extension with customized parallelization of DNN modules, allowing new DNN models to be parallelized. During this process, nnScaler will handle the new modules in the same way as those it already supports. + +**What can nnScaler provide to extensions(plugins)?** + +The nnScaler provides an easy-to-use interface so users can conveniently realize customized parallelization of certain DNN modules by only implementing a few user-defined functions. + +**What kinds of issues may arise when using nnScaler enabled with extensions(plugins)?** + +- When paralleling new DNN models, users may try some structures or execution patterns that violate the assumptions and fail to support. +- When adapting new DNN models for parallelization, users may incorrectly implement the user-defined function, causing nnScaler to produce incorrect parallelized programs. +- Certain unforeseen mistakes in nnScaler implementation may cause it to produce incorrect parallelized programs without warning, leading to incorrect execution. +- To mitigate unsupported issues, users may disable parallelization for the entire DNN model or certain parts of the model as a workaround. +- To mitigate incorrect execution, users may compare the parallelized programs and original DNN model execution on small datasets to confirm their consistency before deploying to large scale for long-term execution. \ No newline at end of file diff --git a/docs/source/images/overview.png b/docs/source/images/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..260ddd4c29b5c9bddc863e30000ae16d0afae87b GIT binary patch literal 148362 zcmZr&1yoes+D1h|5M&UBkWjjj4haQ@5NYWWkdiKGP-zC34@8gVd28Du&(;! z-T-%YWt)@1<%)~CoD^1ZAJq!@0o(Gi(qk+vBm)1;3YULJNT(^I`d4~?Q@?G@$AtQKJq+*sQrASsj%pI(6a;T3lx?B?vB5k4-!G{f zZFK+dCb)4Q+Z6IY?!P9wip2dN_aEWD(fMD`zZdw3^5%3M<{NkB+U=W+T-V13x7T1Hgw#`h+&ZN%!ca$7 zCFC|GD$jH%JZ)7{C6p-mEJ*AXOFZ{}W$9PhgO}fykhC7kgn$p)n5@ddxlN-3K1lha zs9NcZuYr@p>`lMU&W~onD>9xpzqOaFWL8doa5iqym#kaO*|2Md2D70Dp5kYy7|->w z+%JwpH`k&!amX9BjcT^xx{90s9@OcUDP*4R_r(lc z2u&fF%ilL6#Gr%VLt;i6yxk`%?UGQlcfgY#5&R?6vBfP6WD+SV-ul%XeBhe(&Yygm zz*pxxc#?F7Yv}Jhp%NU3NZ{|`6|rolwE&y zTtgjJ6w;1H5LZ*cpQc{&e7xL`Rv||1^W(j+DdzsoZ=}`d*klUxAkkN;ZIWPh8q?Uta#P*0NzE;@9C*Wz3yeVQYDOkgfn1 zv~MDN`T=?GSHsWS;3Im@o7|`TQkObc=RP(CO_J5Dt(nE5D#k#(v(p~k&MCnI6VsM} zZ)U`;j$f6d?Ux*7e5?fCd*VJlO%!VEqHKI=2!5RHi=W--6!f|@F+*$!b*>ydb-=C4 z8Z+>m$#(FsWdMVDXCbHd@*3(M4lAE!A3xRlm96H-w-NT$J}1tUJf_$OE5mYlC5b`~ z%CvsR_w9a^wnNXKeiYTjGe23P>x{f>39o;(ImK!xSDyTR)(@?u7bmSgQ({F?kR;?# z8Dub=t>ofGH7e{jR&1!w&!*S6+M6VlW$yX0(HD&uk6h*ML|P9sO*%bIm6#rx|1A5W zUb)HdJg{3wz?sUdGb#&xw58a+In$tzR{`tRZE%_oF6WTUNv7`Xh`ehzcoadwmrT>y z;C+;7ej;4^YJ<(7kL}U~yR{QSL&Igzk$)z9Tv4KwB-pa3BSC0mRcF~xlhf&Ru&QaT zD-m_&+6~-zg2^|)6qg*LIgc8Z_MUY4Lv7%G2ZxzKh=Hk!65r!o_>Rnip8K~ zHJIS8ZKu=zvf4zGe-ldb`1wC_>OCvZlP&-K|5ERR)Jx{tpny&4#8 zm{lZGj4(J3Ad z0D2-zwX5JG{{2OHVo;ElBKXA%4g?9m3+oK?MRNUfgA?={U&!4~t^Z`q869dsC`%Od z2)vD`<&j(asohAvFIlfznlRqy8`r;}x7&Z@sHEy$z^yhWD){}*JoL&e2tcM)PUH>c zzPLP^^^ar3!ule9EU@&mIe_2i*zx7}{Pz~B>DT}H!s{w2=ZaOAK1^q()R=)0wLEt^ z66(w|TuKSI%Jc|}aXnnujo_H0_s9%w4!{}iil(=tS9&1kj5|k!^;+l{l!!c#+ADQ3 z4eMUHo*g@d&`10;`F+_NljIJ{2>U$@De^*Re)VEg#wST!)Eg46&k}=oj9qU^nEWU; z9}3cKd{ppYo%KKIey_7<7zxqBep@AY`V zzi72E)hoHOdY-e7p6y>@42P6B|NL@gbtnt&YEtQoJ{&SWlJq|8ac0ND@*qozvK{#J zc%;eSf1=D1!GHc9`i~M$q?#<|88hRDwwpkj;fQ;jMHEkw`CAi$*^$CEVL7eiZ^3;F zariRabBlsE5NvI(cDn9(?wn2pt-d_~PqSY#!nSec`m&j%|AQ}Hzea%?`z-$%Yd9Hq z3cnEg*xF?|ULO%;93zdHcPa94qlU}d=Q|76-$J+@S4UV^XBz5A!UID>h9mC^ElVc5 zW4)%`*=3zx@ZIxgrgc1FK0(pgN1G+xeld>rUTn3-`C*pL_ontVHrSgywQD zW`NTp&b&KL4PcK4`M6kLwtlWy2|CO?1#^?-G~Y1@iXvxK+drYwm1eNyte|Fwt7dAV zxm-7YNu$UXeor`2aOs%eWk^bctJw7`j)**s9Cd&avUC=f2Q&81PF%V{)}m#O7lkqNd_OR4^Smj= zOqB9AF$IMJr+#HDS4WD3Z=zA7k97^0Wgq&5TjdJY7fEtK+Yco0&+**Gm1q2f|G;e9 z-cRDWlmOl+3!;0UKl-Um=z_@a{+~FwAF;PNzqC(W5h}MlR+73QwO*uO#UTkk>=NrA zDD-OLJ7yJiP}gnKDu=ft;q!#R5L)A_|{@fsPi{!;r{nrT)3WanFAEE zvB&QXHm}!C%lLKUdpD4{xYvutg;m})t;?PZjecs@ryM{PcrX8e=fvq(oSY76QP!1x zET1`^re5iwB}=|>k1EW^&yl?4c|B#3q(lhK`CiDPmYpzhg7@=#K@RzhMOBwA zvR|C2Bj_ULo# zPRD*jL)Om#`de`U(UXW`-8iDX?&*1cFFE!}(eStRj9>eT{Z6B&u^Z zjn0%on{}UlyhcmOFSO7qlwyn7RXgjCd z-KuZy{YZ_ae?Hl0#FO_$k3^go(@0n~SS98k?J32HT{rhhY;#k&WWn2sl76$6tW%)+?;kXFn-D7x%*=B&;t(WT;LpIy6+1KKCqo5T#C8;<(JPXoE8yg1H?r;Np2=g24@nhW}>Rxq{4Pm?W6 z#&Ky!2O|<>y#0k?WLsyPK2#=!ievWKm@?jYbxVEJ_Q%Tik-UwIhUFemuLi#=x(7~< zId{p;gM3)cnr9GAUo<@CPX?fCzQCGlOW;jROVdJ{I&?Bi13EZ5uxSeODL5~iLRarf&N(ub~V zoI!Ev8#cOSM)Xe@?!NCKUEjXT*!W3gyT|=l=RE9vTOuTdiTT5elX47O|a`c?t8d30E#n5YV{SYM|a!k-t>Mp zX|;Vkx!4`gDH(YUQwM~br*+!c5TB41y_DTr2lH05_2L=PK?z9QwN;*b&>jDOeOy;O zcj_G9*Wwby$$9L3N#h2&1{^{cG#$f!(MIiIX2w#nSvVy;r$jRIwkUi4w@V--FqDt%b(yqC|jBApwU*Y3F1k&K~4Y3!msw(Qq9%ju7ouZm%PR( zb$i6&zR6NEzgHBpQ*g1sItaw~r?74i-=%wB6k<7GY?u~Jp!3pdIaUP|%d$ETpCq?6 zR2M#7tmQWiMvlSHVBM;3u{U8xy`qGfYA⪻1dh1hwwpkznO>PR~4DZKc-s9go3V9 zclMN)AmPROLAu2}VJw3+hLq5&RsZJ7CkF#}{h+@~wf7SvI#=`n^11bMJ00W?=WT@fW^7qBs zBf9LTYZWr=bmB>WRM==sn*?U)1R~m(Lm4eP>56HtnL;jc{#5m|H;>@3U!3~VC6+x4 z+b&|as64~phm}FWzCUvi_H^+0H73b=y2e*-)p~`p$`S6`9Kk)o2%$hgZ-`)LJl*T1==M~B>XR?b~Uix;(*xt%>+2! zapOpLUwvO;Gp3Hom$YAsC~wM+81wGoGVWdqQtEwxz&|+M3Wc5hJ0+@Tsv3b}(feA6 zTJ)tDqw}7r;J%9YKJB_tYnONHJ8i$5q7dcb|$ZP6`I`i z)t&-?s0ZWJ!hWo7c&U##>j*Mr?M_w0cE7VanW7)9Hk6VshYV{6ZPm<>AVV$x3k%)z zc*H?G{u_e`ISD|K#H6I=WImtqp&SDh)CaXxiCz93sNm>m@t#naOsC_~reXI`-(wD9 z@;|WUwImS>Qoj4pkma$IuYY?olQa{XagK;v=m*#}3qRRt_S*x=)Js+?%}M3^0}cby zeUk_Hk^1KX1c~P&ZZsmP5>_KAy!FwXzsj<9`AXsrS|s*Qh<0Zj`%%VLr@J5dl^xY? zVSr|z-K00k`$AP|JUL=!1=$UBMEXG;VyI^JL8<%$fTOS*DF!!bA%;68Aw1K-AWY-`)?!8u*^>r!grd zs)9x*D=$BW(Zxa88`tFd2YzdHueGA8+@}e=Bv7a@%GhNj7J02D&SmO+9dca_zl^ds8v$+a zK-&D$3ZP3?mTde7mb_I40%69u+&+)N6b*c6kW7@0oIJ{b@QpsXY;c}Da@{}M_;6%6 z`!p|9r;QJ+STmt%^cR5Du&}DbD3G}HlqNDN&(n-T;OE(jK zwf@1Jpz|K|<}Hx^ER!F)M`bG|$!;}WHcqcy2_(wrG;S1gTOW^R(=L7*MqDG*@9{b5 z-s7wJhIN8=la;;PGv2Gyh zK`5YON&D?P*>lg?qpjKai&L!&mxfMrEe<--U;uSjC2pv6&;eln*-oE&!JyVPjE;^j z*JEd{(%gfRn&!GmZ=wJp!VG=5F;b|@*M82WRY*=iEjqH{Z?JQ+(8Ta-V^45#DFlJ& z=S*d8Q2PtId)~E$gShyzv+VCdqT^VrJ(qY1v_ zZ90(jf*t7U*ds0iR@NgW1N-5RkaA6rV}%unS1uS75|eL47|KNN3Ekwj{T%^fIb85; z`X&3pthyO$6Kt^8af1Q3QZ-d!Lw2zGUGX#}xv6_`%Rbio@0XHK z(>m6F(ifju^zyP*9lQW!0mo|bgWXYc7`^v)ke}rTQuqTB_#&i=|UJIzBdRWmyo2#NHi{#1RVXlZ=^$lobl zfrD%U)@Yb{OX+moThDnr`+@xHi;W0wWp$MEcUHf**e#v$A0N0}0eqYdBQ`0&(-A3|MK(DvkiHQc5nep*K2SsmKl5I9n z!>8)pRd^;s`8*kJ+f;waa3sGe5`j%Gwi+BgAJg5rY|B6*;v+eN$e=%QTKMbhd!2=a zukT_Xy4$yVc`cyZG$=&sR6=3({o$nS%b=S}onC{0mSyXk|H*GwzrbbrCj!~k&(par z8^@8j6`3&81KV>qCXvN`@VO_U$*RY4?zpA%1R&#;$#-P+7uwC6rw3uiEH0zA;Y$vy zI0i-Yrb5jOYY5CDDjDXYQYuv!vj{JYV$F`kzG`)ejKf zlY+F-HUBAZZ+{?hRW9#es`iOaq@3bgos?2x*SbM6o) zjhk}uVklE#q}*DgZr;-K3AB-RbH!ian*7B3GJqz9ao?ZXy8`@j|FSYH1N9~6%)3??Gx-)9OrICmD*!$nA?OB_iD-*O`N2iuM?)$f#* zQJKme@)#k8cl$CR+?ctwr6m{8g+m58RN6cHH;Gz)#`%w(xTp9Zxh!;<9HnGo&}E+M zLC1()4;^L08GB;Qfu$bh+hMn!a0%(?)b)&%`8f5HiQbj(+6N_)%sV@a-B7;Xn+dVy0xu!qypV>Gfq-H23EK9n6ZH`F<3z zs`EK1UCeR%4RDyj+;8XNt>R+=EvkI#CBjShpb<9dLSpOzWT8d!g>G+9FRhvfJ!cwy z>ns?WX(_3%Q>=Qtxkf0r^VZ5(Y&ST?n}MudRRy05jBr@JMq3c4&cj615DJIWK-AG?tEGQc)lfCCG-MOYphB zwan&P_aj)^3#?81;&;W3xl-4DC(0V>YyM%d3DjX^JFJ4w0PDfA}Xa$p|w*vd!Eex=4@*(2~ysw%yZ9z&A74N@Lsb&R;BBT zy4Tq;(){bsrJPj1bSS5uJiTsye`h)j4(Lj`IK}|lD9+*wzqrE`V-_jRuDyu$S+_hmFrTAP!s5CyTP+k8$v?wxfwu3VK zAlS~1q5WR^;tvJmS6Od(CgVz4qb3Gpm^!B{45PLt^XcEMY=0Nu-#;;q^QN8p1C4i% zbOIqtOKAdnVa%owS=U>7!^mSgD9Uu4A{`QO78ZHOmeW!Jio1;sgo>4F-}rDc*y&cWE)*b)MdIzzIU9 zm?lVKZ_GLh9n?-m+O`4(F92#=6CF+5GKi94(!njQR$7x)5W-$k2>v zp@m$zjRC%PIO&0eh%#hD!!@PSe){F|1zZ*%5+ow#7WoiO411u5#NA?VTvgPT%LvF- zn`QRfv@$%C;+PZh;Kw?!37hCi;8S!Wr-R_8b@`m^S~o_$_?B{)Y{+oCw+g=_Sz{56 z)Vz?|O7*`$)`#OrUJ$_}l0@ZP5-SmL8)=L<{-hF74}_DCSB{33<uda!D9oB9)OmTpIup8v9Eb9tgx#3pywOQXIYTxZWbN)YgV*sw z<6MJ681aMM73IrBT!unPoTPoKYNOkKij|;@o zZ8?$#9r*ou3|*8HP5vt;H?oT!dNS zB_{A$4uRFR)}?bjxj=6L0{uz23!nXzu-Ep_H|AkK=hqaC2e$2$g02LyCW_F7(M;p2 zy%izHF#A~c#4Jc$9Bk0dX>zcOY5`Q^b89tE@i@I{nS9}ho!k`S_T>{Pyt?eR$f_{R zmGC>-S?K5U3NG25#)a_=Qkxa@l5&!NF46dg`hP4LK z0!gY;k$G?L1ri0vwSRk-b`gzlHnS#(H!F1ZH9~hr{Y?Mm^r-lLcM-reu6&poeM188 z`M?t}kMn?KF^Urz7l(@cT+OepgpWmJx7+vJgd zcmGcBtncIF>gJ#PE>Trx5{QYAN??A-+7r4Yk?{AiTZk2j<=5q%MBbPr;TO2;{%L{L zIO~${WR?M-kZbxh1ce6QZ}9t^zC@uznsQUZv+b5>T+doKQQox_bWoo7N8zxJC~8Ga z>)B33B}bm{L~I0fy7Ziow0DU$8 ztWdyB<4(G1Gy!pxj%x8?2v5Ju)5mw7XxWVh5#y%d7m2#2@r}{qrb;$ex`h!84-rQMUcVpc zhpIOuvpSn_u-Zmv{hL(lVs(iZeurJ9FC2lS0*J}BbsZUMn_M$OfmQsgeTUT4 z)OJ;UZ?EIqu6fc!e*ic*Z1&e|lup~RueL)7_t13>j+QH-qn?jBOlg3|``u>>NUs=@ zqhWX7Z1w0?rEA3lz4EQZ?QUbMsqh0mHL>RFgK?bt!EMj@8>3(?iB)m3&T6Tbl>ylu zV*NZ`$}B$s^tPrXJ8(&p{di-4uL9k&geb8aSdfOTlZ;k!N2b(0_Go5PT^D`9)(|Za zvVAW1qR$7G_c-VA{$S~JxrbzS{1Id2D{>D59w*m)3#qvL#Rs47vkN7n_cdh!KWn2T34+mB6Pv9WS$pO9$ zBUAtxh}|viVkV!hb-tB51|EbwAwAOoBxim0<|1L&O;k3lVfVQ>o`}inW=(0LK1F2k z_??b@_d5GLeD6ujh{QNVw?;uOaeu|VeBnz;OXnTnK6w6c51yK6xPJ-CzwzTH` zp9<>-NvbAWCkOmK`ghtre-3JHiZz<((((V6bd?|QK~W|>BV_8p+GAFu8R6wAj+Stm z3!L^E{qRK}iAx*cC4K@h6?d#@G?=-%dQhtJ{D7=;radcrXwCpd`F@Ruam-!T@AW+l z-(#r))E&R`f;{TF3GRK81Xc|ul+Cx$_J-YdCmSnDCYfcfgvH!RQU@O&C*ODJm z+^}!4uG`}xf2US5z(`2vw$XJrA~N43t+CLBY05DTr>fBE#NiS&*3?JJ$_QyBhal3E zhBxHKI1~eK{M}@C(A79EmfD3ZHqE$KavlVcHK4WRZjKJZ>H)o#ZuWB!2I5O?E1*AY zZ~z6Ype^X!nOOjL*W-r6#bwCDVCm`UB;C9FPrwXfx6y6#`JTG=s{MvWpsHJqld!fK zcPURF({ln9A+apIUp2i=VWZDwbtErHI_$<>Fl$EMhE%Yfpqpa-(%qhQerYdQjVipk z7!*_9DjU=@Vj`};awAJ1$zAQJ`CYDb6O!LzbEyTcc~P9Ax-dPYY1LJ)gxSC$U<0eV z_JlV)D;^Drm49HKY~YVCJCyBj=;VV&q(B;TlJ1*ATy!gKDS$;xX}9^>W#c%=Njzy( za9d#@xnJ&un1>DgTs|nt6qfO$dCLP$;5ElP*qqkiJv~^9OYQhAUnf7;;z*w%EMSgN zEy-}7jps53DKdHhBC_;zzDfb`-kOUHO3JP&CbI2yb7N? zCW&SFa^W7u)oBwKzM*y*rDHcy&MHg#hNzao z6tYDTjtuo$NW|Dio6Zi)Lt&qQOf^%n_E%o)NOSpBV1&sTy~eD8!Q`EQq%yIZu$goA z5-2*pZJ~4Fx6;}tZH3`bj%XcLBXt<`52?a^4eC7xcf;-m;sMv*9eHHeFq!`_FfR<{ z%{FGF+WVPxQJ~T4$Ju(}@>h}ag{_srgUxXLz@^^kr7fx)$O6pZ6 zIg_#|?+9G<#xQNzjPRL)oHWVeL%SOXB&2xIf)4@zBg46P{UeUK50AoN#pix0OdBgt z&ILzR0|_^Ql0tQKzAGZ`c@b}Z`{D}5h~juh+cX+b+6vrX^wK&8(xlytA88wp2sooY zqVi;9$byT!M4=`5=I1NU;JL+8u8gnWKdSS84=9EY;el0K*wqk~kN#1`C-Q-g{9jGTk zgftK3=yOUwZ6o0<&W#S-p%gs+(L*GXjmN&h($Nrj6{v!u%ulFNoYhPKyet;{P;DGl4_V z2t+iDm~=WPV%+uXxy*~cWD!!@Y8WK}{Rr7+WT8zK&{)Flx(5C)K4BGU-@3L}v*hc#HrmT&KR5Y3Lsd9j zn4IRljB8UNdJf_yoNbzsk&)%XRqGEL(^fE9tbS+5;V!+6I~{k)H9>C$RBBcD!w>k_SxPdx>IjIy)bpa2aId-~9D(A>uihW*(NnW;5_18R<9MWEr z%;J>fOjlwi^XkRGcEs2b#wb+7f8Kqh{c50V`F7A!D>e5t`fEvc#-9#z&NlOUgR{Kz z_Rq(_d5|$v#jbYU_`_K3?U``cl=uyLz>&0AU8nb2FdM!Tu7mQ&w)*HFJ-Sh_6<-D> z61`}_UdNxgr3_X(J0LtiT2$&!TU>ewenYh32Pk9+- zE$+w*+h}0(@d0h1RCEIqS59nR`>d{0@_0Z_#?IifKG+Dqi|eNdj?#6sJmu$_C@|)) zksWeSl>CeS8aI9mxx1?lbFRY0wX540SBP)-;EFV*?u)t_dbnek%hDi0H_h3#f9tVS zVTHP79i;;Exkno>pQQ=ZV$&uGnL;`~nmv&E)JE3>d?NgEgpjZF2o9akU$XYH+%6V( z4;Oc$*pav{y9)|+ZXdrkvDZ5k<#`kWkACE2m4l5@&jx5hx97M6iB^Fj9niR>NYBW-p`0KyO#DauFxQrSSQ7Oh0h`j0(T#nJD;d>I^VNO@p z!KcpAI+5B}uewmkx8%vS$xtyL+F^c{6Eb z8y!`6p@eAA4Lcr(OABE?cez^hv39qbXlIYLmB(oU1^RTL-COc^CjzXIe|qQ|*A-%* zRx9p2)>B*~5fIDj1d?_Z(HBEuhR^n?IoM9j`Kp~)eG2=ydiYYrZV>YBQw2a4c&^~! zVpNu?Ir0=Km;_yLG(@G=ean85S@L!i^+Ssq810z>%B2+%W+nQe*sv~<`1|%=!cpt+ z)jV39DAf_fbsYZD_L1^5fQip{I|p-FMXn73vNW*A%i$O9$O4Jn!Vl%={YLgw~qnG`!yAK48|E zxGB)SSG07DV12nIfGTYzgK30o53|IjR-RNl>N+Hf`pUNR98fjg%GiFCyL*0W#{$%t z@LAG=4Ycgga!o4!Y-&A9MVl5EFqA|<$^y**4UmY^L7A9{asZ^z{=(Ng-9;(Ps-_Tk zRJ6j4d|4(QVz`g5H9NuV=rd$vlknEhd#C4HlX3KJ5iC&DW#jpBBA|z!w+0go9X!z@ zd4B)Tc_Z8_3n;`CkPl4pC~?y!!FW(tV{wt^4JTxSdW9(8?F0t+%uQNWRY7fT4G%w@z+)ji-91?;@)s>_XD;UB9gtwXX}mk44^mS z%(T`k(1yswCV5s$(ZxF`^gAIL4q}gcP#y$u0kcLq0NHkbd5|8A2nQLu2P9sBgvzRT zoiwTo*u92G%He@K@AKMaEAl5nhhzv+_`QR7}Zd%=cee)8gJ7iQyG;pfdQq*>! zDkXcFbaYoPOrQ_0mqM5&H61tNbZnbgxq{6pa$Mk;G%%wIhqctosuF)^kr40TY?o(% zw$>8lWPW|WbvfP1)qHZ3*rsCD?&z%Q6k|ks+2@5q$L7HPd_E41NG-6UKg*PdOz!wz zoGg*+wp{~|;mmnSIC5t236iT@<4cbOk+REe!u`bgR~0$iN)_u-=G}Si`(u)~Ipqoo zOrwEPgRymjah)OVs&Nc`HUTXukyO=xDj@&?{U?6C*|85y=-PMm&aGDL^?dH}18+C* zG4+S>y~~z|xlJ78T=(X>S%v=|EYGgd`x`(DG;QCd`s=xCw2q>t-3xa=N;it|LD<{2 zyt@dsiLn_~u|_R;-0$6A^7nd#uxl5~YsBdziddU7}~OvLl^ zo=UgZQiV^U*RsuTZV?~21sXmm1xeEyo5V8k3CHgLx6gs0-N))bI{C(4!qQ0ce6rjc z7@?$=EB{l_WB|Wi*7#nITtwS`h$oZ#c^xHOr#F$ER!QfGQ@IwzkAQi2$PES=-o&iaj zlHZDGx@pVD6^9*jqJY=@+jx}S@<7_FHCkcEM|hYcI!{7w(`H-tr|8=A08S2*AASey z>rV{kPj;>s-Zpk<>ezk8<$$_&DmZx%gs zM|&XMcP{kJ;-lBxuk!29=y;KaBd?Hn4vLtl*vN?&t<_Oyb2D~2cWRAN7;dsg@3XN zgVIhbJp&orGx^#EuY=Wc+aJ8`ycoMWCW-$Z3Q;V>(+{dUL9f66Zrc9GV19km9~W$S zGgZ5R4mtANe2J`CG`+-e*+1Q!uGrVyGd7&pN_V55Ij_R7l1t%bvmXUhO-Q4w@+9n`XA zjFQ9gc>ibfwS40y2~rl-FO@xYTaAjqc%}dao0> zsCv)Zd2vnq-6lPe>&s9bhH4p(s|Uh<8$z+b4%fcd_Of$(hu>mmspW`M1wiZVM;uc8 zz+h6R8Yp+A#ur^_ezP+foSMjgmS?J81=LD+oUSqv4fv+mn=&_{H ztei^9rTLHA%ip9WSUk`r5qXUOha}$nyeF806V#R=p}N402>ROrT7d2wd+TncwY?py z7Ajpz1LP|nZHVg7cb!1Z68&lefy6q#@|=S~paVLB&AM9g#d}NRYV#3OMWSnFG~d|G z*(xzc5PkF`X3;8==ObJ{f2v6l(n=J1>!kptKm$nIwrlRo1CRMG#_K)#K##3LsVURg z54@d&RkMBtzPO){bG-C>|FQE5~gfyk+3R$ zmt0pillWn2C7N``$Tm@_nJs%Kv2L$^qZ06_8k=t!ZL2JM-y>Gt$mRErtRxT5E)1&- zg&Xv1OXn_-KD6v3)(WWMF}2vfy*96h&@h9V)KzyX5*OP1wyBV|$_<7{&wkGwd5Y7| zN92@stFO~?zsqy*dbXN;i>?Z_JI%QyWrLMfo^u!+Hn0(}nxt2ojo7TbjI{0}TxuSB zosg>5aCj}#C2i(5EtzD>c}|fVWn=x$=;iWz?Kz31l?wDY{t%RHCu&h=_o`=0NOIHMMljM-xygRgGA?;b9*ws8L8DiGS!?Q@W*-n0upw(= z%C2=g0}LD&gh&=M2oG>DCWu`kgCNx*6OdDb#{MU>#$Ryg05!EfRrBuUXraP&d}z5{ znMrGKxy=}t;c&lBf~3F1doII?%19ow4*B6eE2P7$#7iKQk+Q0{RT6e1nn#QD4KI$I zKv%(cc#!6I*9O<@y0+obysOtf;!vat&?+_Yo98l@7e^)L+>L)bhOq#fdVfC`q=9m3 zF@awn=8ed}Dw!-x#c&g}>$H9Q+quFCfj_DND#B4fqLP_J85{R;*}+7J*|%8o1_<1d zp#_4AbEl+!H{8@;&v~E6^V~$wdEf-LsXc7*xu~=A3t72m>PdLKHa4rX$URl#5&J=Z zG4roGQJjvBOj|Zx^}V%RQQqvJ#mXE=nj%o?Ao17mo1*}A{ zG{eQUH>U03KPt;i+wpQHfsvG)fc(L0aHh^4;3RND2RKKQfu8q5S6%WA5nRl7M<{Z3 z`Y}0c^*qLV(Vvq>3it6h$N7CgrI;e6@!(p+*-_B{S}ly%s0NQj+b4{MMPu+D6ylHC zAozQcsyu4)!~g1kS+E^Oa-eCYoOT}=3*=><@n`kNYPvJf#miku0KFbkJbG2<#a!(c zRWO}x$xYhHO=dY(PxyP6>iRe_I=oKmzL_s}~zud1)YpfbMWe?@HNH`u# zGnWuGIwWj_e8kac9c^HBUNuM*vxI!CovXGa{m;rNuVm-S#C7z$arQf*AG|2sgvtTMd>E3FG8g6C`{^&w)4|$9Mo?p@)NG zSrP;@>9-iBNQVbL%ufXb_`KpWxcJQpZ^}j$JG%ro1oVV=L-`ZK-r>)uM8QnQ z)H*1rB%aw%MSDf)%LXjj$81X1H4n3>q#dnGg-e+i46ScOZ3SSk#$uLriCRm)tn*&I zzPi)Y5D$LeLj+3 z>IAmlc4KLoL|I2=b~64tVW9k)eLg@z-oV;nf?qjteo>}7`%OB4*^k3#mUpsz#aMgY z`*M))Ut)b46h2k&sc}m4sF6A@BN796?&gYk?dQ@B3b%M0J|@bGj1 zU0I?+79pg5CyQiWJ5sYOw?XNB5K^*{);=4RS%ZT^`Rcfm>DBP~J!qKQm+T3CUt7%0+-G7|>?I%`GZ64fH>~ps|{dFgNQEKq!T4M%31uttK!GD^oxF2-Tn+8&> zP?Wi4Zz2qw5h%3_ShRj2On0&+Ip1G&ur{V~vb$(Ex%#7wdAeckOsXjv9DjW>nWtNp zZ5K#fFkR;!Pxlw5%ZM9g`U~(HI)}at;P9awL;)<<`%Cm;hD{vf+iQM=#=kwL-hTtY zd)U8=9R~A~uxV&|JXS9OPIbUpEvo>(f-%jG4FC0SZT*coj5rQ|L%&SfJ(^e+2BjZ*&a4`<;6$Gxj1Mc^^`fHc8cV7XNvKMmbrKLT6Wi#UL)awTlyAoK!J<6t&&iSILqu}{tJ}UQ)Y0m1Srg(?>1)u zo@*cUQfo8~(!J%@!z9C9z_uAgBjICJ^PSxLa6RqsqDa~~o({tgkBZ9192t%y0gXIT z{sEw3TB82jXNqZu0gJRuQJ@tGCyZfGkOGJ(E1F&oFP^~SPnS5yUEctR4Hj_R?Sx4& zEjg+szig0A3I-BO6!t>&m2P_Ho*!I|9u6ZujBISOWGezDHLYgE?R4*3MW2va!@849 zGPqh95eB2}Nl+A<=Ay=V_DIOy~$&p#(5(ue1gMxG9)&kFZeE0Dm!g6jBj zWqWQAw0qUlS`NzxyUi$we(w9aUe{|VCt>p$ znbCl8$X4RJSf^_miBP69<{6iC&?F2Ce2_S0Of^T>2f#jQD412aQw0cjK!Gi`e$?Ze zQyTlTOi|nL;by5ou9!mK^xnjn$|GMVahw+K+=Y18b8NNWhbrq=+Mz4QusLWT)!q8| zxsl7|ZCl=`F_xrx=li>^Y`RlQ)Rzt(Q67w25^10V3K=PI^b<&Rbe3a|*#~wL3ku?d z9D{!5K1#~C=fGOR00lrhc1|qfl&L{f)@Zz~HXmF#&xRW0lm<`8POZHS#yC$hlw?FC zxafEJ%WQ5Uy@qX<3N@m+)(gTcPxoNSnJ&gV^?P*YS*Ppdnl*&;S^KeExZ9QqD@ngL zSIef)nov~_V%+$Nhj>?(QD(S%IvN?H_;K8DJ+&Pu{I``3iz;lA8` zO6ggh^rNdrEJsVcfQysTUP$?+hmSR+nl7tugItFqIiswiW;ah#e}>>o7}m0n*3;(CY!4g2c;cUA1b@g_pw z&m*_PrS*;(D$DOpG#A+^n1xSx5R7~I2_1JeUuA(j3t<{t^;q7z6;Jcb3})K@LNZ{W zpv+2rUcbHbBZdT~0=d1g@q zt(I-KAA?gbbTyfj@_5}lcUPP5hi4rY^(08JikJ;2+ml}3`TnaibndYtiopghMfb;I zwY9Fky5YjQc?zWfmO#~ZfV&DZ%REVpdyzfpZ>tRrGqk$}WQ3bWCCA*R!fp5?r>f1+ z^J6hyYGYqtr`QfwE+8v!8sG~39 zJU=rRFiX0B8%#d+-|uZDO)Y^89HB17XH$VD;N(_iH(j&*ukYgdA54v@!p9yg-99=@ z$&PyRdpN_sUD3nDj(TVZ)vACoRV{MclX3#mols4(Xjl+aynWW=+ZnZ}4rz-}>X5!b zjPT8gvzDjB-m%K180V#~oX^NG4$)f+P%U1aqirZ8mQ}HfZ(RpI6YYv3SFOd4t?u3Vdb13~JmV5emc;;CKXUN|*!g8u`8AzR1Em@1W= zQE1i`2dL-*JnsdBW7g2Arop7czg#65{z)w{t}blD{8RbMmUG02ncOYR)U*9%-By(^ z|KB~dRKu=qKT0&7>5Ti~QQPuOgM=K*K6=+n{_ycrDC?dVef}6>>cNfEF*-H=rFe{5 zD*wJBi`Z-p^X2<cqg`@AjLir&3-r<933|iuOSLxDWR;UWdC0&u zrl=lXxj6!b7CYbTx#)5u6m;J5#CTy)QvVOEHWUrTY*N&qm$>FTJ?)c}<}ZM2_vLbB zPdO>&hflSYtwX|U9>qX5q2@nx$*|PjdAkarvyRI_n72NT@y^*FMn`=cSKCm&vxXmZ zc8v36s?hm<8cx_M_J3c5Ai_vnmv)$W!h8L7#i{w{TcLwNmwL&9Bv=A$xt?V>AFOV> zBuDOx<23)gp$bLBg(TW}O>~7(`npS($#=Cc9x0p{q_$x|IuQ;{MxaDAt_c7JV$je2vKVkcyQAGA48@nmb=u1QBt#v#>xf2OGcaLs z`g-K$nUa<0Ashz>#{>XU;R(#gxqcHX$^%}SFF&-6n%8iur0DJTU$=Uzs_Y+yY^F_>XI9XUg&6*9^2(TeoJ%j=6>%fz7B>&F zcwN!Xz0*rS)zro0MT*d9ezQs7@Tkg)h7Yj9dFr*aV2o2(%wdrX0bg( zuM+7i$14hy$!>+2AEQ45`}RkVO<7oJ1CM4}q?JS-_)_G6 zlq4c4s4L6FPe>xr(G9@i1eh1=T3@IMpc62I?ikUCRSWupYY)JWjwHKy<3BwwM?mCo zPKm@s-|)+DiB%GD>Lw_|I5GV!e@_691r#JTgJZzXem#~z8a|!2nXh8TRGrA0L(TDA za>HP5v{2p!eU5iMaLBmkkwnDLya_%TL6ky>v`G6$#6`l;(@h6(Z!ngE4{!(v)N(kv zx)i5cTn9M6gz%=h0M3K!q7E+X)Q99jB#Dh#d0G z&u}_$>_Uvk@(ZEvjLmPN_kYSGw=XhY2)9!L!0q^o1GO*NuYzB-6N}?)V`FS!(*SIT zC)-<%xm0%UyHptfZjbHB@^^#dKNA=(WqKpc7+*>@e!UJY_?Gp57alG5XWJr8f<5%F zho7T$q?go<(n(Z#@iwpD(FI*`$8*Z6$Ao3=lO!VlG=A9!bzB=;GlSWGofs5W~=hrFtu1v{B-Asr{Y4 zuAqhz4oB~?HX5-AzEKTo>!hK$TQ#MUZ8j$4M#%$j;ZF_0x_@xr5&AQE88G~F=`0mm z!eze2Qv*CDn4=hz=iuTUmd@f$UHY_kGK;3A!-%-5W-E~V?=BD2bLY>M`A0cd%fqbv=htRdy>UGL3iTT`~!I7B-o>T+Lx`Hq-DHHXd}%nz9tk z*w0P=d%m~BI1E}ep2Yn_H!H_yQe2sGjSV6ei(}U9(Q0q+auhAZoiTg%a;*26;LgQ> z4ZYsv_xEi9c7-~7kJevA#Fq9D!Q#haS)>c#RH*~90HK3H;$LqTIZQ}{U1Xfs_F8Rr z&P7miCqbXUlZT+Z<{NUt?yF~W&0BrR==d#rlz<445F)p|J|zIlb~qF_lJv4al2zm; zBcuK1&u^Bli*02FB*08ySA6l+pv<#(zV$C}E1y07nD4Exaq-}*bw7)AN2Y_(B*(1z z5sdYJwB)7F`)$Ep*p(Lo^j8^T1c;uGG78DBPmynvEHT*>$NFU^j0%qBRLh6RqWf)&Wsb_jXe_$nm3-!OyT|OJ*}IdpJm8m z@B$l0K?N-TyP+G&TY`=a!A@$=UK~0*CKx*PX}2$-S$UBZ^8{C*!5FKUQyl)ThTTY? zgJzB(Rm~3Ltsoa0`E*=I-3_6Z^_0*l$`|<}KfW8#@e_96)dQ7};EGcJ^|e?ESQy~t zy+`-l0;fZz9gp>4g49)2KgSX<@@dH^+rIljgAERW2sy_Pf>K5Q99pm7=`OrmPmjP! z`E2@09%YtqUYlw%h9_lBlx5wq9LPL*(T;QiS7k@>tz37q_Eo@#iiIC_C4z+N1h=H| zwQFNpq4|TV?rB6e{ZjX?|Gv5BS){cTlQ{1k!MmL7pp*b4H&=2~Z%YZ##MzGEYG zjiLWD?Sc!}uLY;{4iDo#k|4$03Va-;Y1!1KyF&C;m(thO?Czf@7n;4gb$T1cII`+S z$gy0GqNCN@Xpse^Hhe8v0~Q-9a?#z`WgDj8E6fsFzsu`mWQbH>YID|mvG&?L&UrN=`GPWd%0i2PW3y!L6MgZ`^z7uljj#KG%=wcy+VvQmu=6W z1_cf+%2Khr&^$dv&M4d)!G!`3){tJ8um(a10OkT1=JYeqhF{Lk+7?-5#Q`-`(L;$v zJ6=BI!HzzV_TueuSiC92kZz-p5YqpwPIUqv*!;}<5^c{S@caJIZq|?@5J07GvmC$T z&MPkn+Abf`;iPZAIlo?cBhAJImTR(?- zqHnNPR99a2Muf@atGDwaX=(`XiaOLr2+imQun7EOVfY-P!{^5Rh(&Vw1;0|laVAeL zY_nOHm-}ba+H3b4>}Jceo{IBlO`=go{mxTf_>T!HFL)5dz?dP}*14YEvCW7RW#kgY zytE*sUxOB0iHN}t)jq(^f10g%iC5t?pAQv`k;*p<=BYqnIM5}Gz7=}Io*?Rs@uwp&-^GBwb5y*(1|g~HlM56hHbs?0H*g89;qM;sxE>|+1d*KYY^}uHo)89S zTm_^QwOaf1{LECaxbxd6n5n5uv?S1OFpE2%?FF>BK&k`*0ARZ#DDIvE(CHzEqYpMq zrS|Y7w_01zn7{N)-3exkOfVApNG0(D`h)xgg9xlhmbGDy zRk+{XcBb)FBEXUk0ebNTb?s1}>$ng2o<$9+aMBun&lb0K-}^CQ<;Z`pdr;Mg8V}rz zzxrHB&Frl6uhWxaw`n!ZzboLmb0k|`&Y_I#fAC(NyAh+?aKlELh5gR8@`K7Nh)feulrFY{2b1~bs>`|qSj{37WAiOt)dcT%G~DzApl_;0Md z5+d{Cx&L(P2~IXj4!g8^FQ=lGxx=bJz1BUu(OM|<`@yT^i-`_UY%BYG%j8OBOA^t|I=Cu5DU2mI0>U7-389**+4% znDf&+TVIHQH>|oJGdi-Ip3aAQxyAlM=|E}j_@%GEaBY8fiKbWK`sGyvMBOC5U-_}| zKOO&1oR7=e{HZUIf-+QIxarm;N`Q!{rI@v;H_?k+Ooa<&QsvR;S!s)IRPeR7F>l4^ zr-Odfv7~8^l-s?E4YqL>iU2+#kLT&AUw)k$gn8qz{F7{g)uJVVt|HAMne}6jlDbuG zVv-uESOP*)yf?qV4G7evSgi!3ieXli0^g44gu)t7YCerOP=v z1;4NDHvK88K-V~}==Iy%*3&75pf0N7k5I~bMTKc?KN`9ALO*`wda%OkTu$(6BAjDQ zn5_bQTG!kBUbjMEvN+NN3O~ovm@`P0n{(}ct_Y?eErH6)_S|xPi_40$>e0-%?2_K4 z*I*omYsf(>hXOXr<-*s5F|esv9I8IjK;13wGB3LxjbV+Nb&;|i*s}$gGBn*Vl>SdE z4GSGdDr)e#`wModq7EV=D#DR*fR%_4so3e zNevbi?!K4z?t=B82NR5H7ms7dq#a_MgAIs;*Jbo(S)@JpXLBNlFD{>wG^`sj8t}Mq zkQ{)c-vq2(kOrZx>n<{7@EXgGi{<7cO&{LOoaRGS#QnQkbQ;PmHAkhaM+M$#AZWVZpiEt0Us2&<}}%d-;PxUMm>a zjJq;2S)>h(xil%EgfN}+a;T67^_E#0q4|m1%%4_nS3P5)NQw^mjnnSY0_Uk(&@m;< zgznaWwY6+sY2Ruu^}BO;)Ebw7cXCak_t9V_=UW{h#Jf6$;?=xP2K_SQc_@;h0Cv8% zKjt6`EbUm|@3`WJwH=+QCEQ*6vwbc?0sRL9J+==iumM%4ARpmS$F(zldW3zcQoH}F z@u6lj2G`-k9{_foET|r@qo1sMjHPSVF{IJ7{Tb_2$imup!-FkI_QK}3MR40GDCR=l zL*X+wH&sYbFwViEhi*^h9Z6bctmS{^htZ;p{blvHXaQ*hzfrTb%91;YU|8HR}!+n-;eQ&o65V* zYRhulpORPA-$#1*%UQfA9`qO53UPcJ`Av8(O&!U;0K1HR9CMfVjM_l9G4 z5LL6#GkpnoJ7ImHXcy)sB>}twBQx@M>#5{!`*S) z#UZ4+-&6Kl8PoKIq?suzC#i4xrdqB2jjYTk)Gh zW-5fHSSKOsTsader^GK=Jb06wAxc$rg_lB{lb7REheldstb(kpkF7TGtqGOq(M7rQ zla9G^0*4RmhpyC*fAx{w4}l-E8IlVqUgiC!R+;+ov*2j`Pl5FCnDqEHqyf8vX8+MB zwG$7cM4fHk;9WYg5%j7p+1q=Jsm8MEqy)AT4;P7XcPw3M{4Oaiu;VD1t_mtU#P=3d zI#>HjETkkTe>bDQTY312SPy#d%;7%R-F_^-dAmDj$R*!pgnZiqs_^uAsqpr}hs0WW z2~of1o0V_j$!T2a!GPcbb#1Z<6_FS@I~za0Q18`i++pAUdO*gb>{!5RheC{pvPV){ z9KuV*G`=bHfMpFc^G~YeK$x4Zn$tNkjrfF@n>@C!YGN@rA z&B2XK1e8sbKQ98+X+Al_(idCC%#=L#XH852XPc8ETAKtIxjRR*s-AEA+H{`G1oi9?oow#=~EKtYs zvg-Rgii7L$!^2!?a(>uc2=(c;SU&qxTgU#%CTR;X7Qn)0POHd z3Sqcumkh;m){rts;PsxK08bC+=B9C0HFS1wG0-b-j^_JC@VU)WmTt~pGv{GkMiA9M z!ou1b=WYn5MjAgY zfO?GYabcSrZZ<|d-yqN7t!8@E&?CkMx_YRr<(iwD3-b~lJF>a!%}-B-_!S#ruM`I|T4;i6Nh)nj(!C~WS8zN!@*5757+qj9Paz~G?$+^$<_D&A zq#*Ax#Di7zK2V%Omv_d4&bW660!PLPQh^7_*oDWz+>p9GU018SG&NknFuKoUMZafm z%>-Sc-Z)E&ic=|P>D&^2&AMg6MlI}(We{1L>dcuQyWMiJgZcI31t}@Pm{QRl2K~a7 z^AcWbpY3GV*S6~^7#X%=XWa{i6WpZM+Czohdp|Me=$BfuW$&JVQdrNuc#|Vj2R(y% zN9g8rN6rp_z`I5C>r=zb>|X&KBrY_u%d2z4;kna;!tCH*P_ZQY6P@ zN!au{8BfX8N_IbbcgYc|C*6>tpxqJtIAZLwY@^BTFe0TE>-j0cbIecbV#L@*wg;!_l^KFznGYl_{iv7@je@7Jz zM=Vzw&|(hbm-B7<^L|!);cU&4Mbw+>9B$tVxm^3x+zjW%tPMpMbY8;qM^riD~`H*L! zQ#^lOr(o!Oq3u-G?7s-M{uoe|t+xOQ>RIR>0cvK&m0wa%@+4Lpf3Uu~1zA&Mm-9`+HsG+`9g^tG)82))5|1bh3LK4;6R1TgSV)^8*qEUb3fSFYjfi?fdg8mWGt&7a{sX;Fa%vndpqL?B(tUD19#*Yp!!@| z`Hgqb8M+6)qo<-(yhZ_?{zY)`$F3&i=L4mD6zTX7r3P0wBfaNSZzcVGUQ#dOyHrTr1urPRgeTfj@`^wAqEcrLw%XMr2 zBP_=Ps_5+{eQm7R6;nAxZN82yj41=jcDE4T8$3Q#su8?Lmc!T9=LDYemEv4n@WGL{ zm5;#(Tp`w5^iHI>OLniI{l@{b!l>kAuEDJv$`kBZmWQ*FX&kqb2UG6086Dnv{5Is6 zOnUUfNdjgy%T6eV=;+F{<}uV%sXaI|P7ko5jT0MYOgi1nl~GP8m%LwKlUDlX12`hV zND)xiI{pE0bbC%?o1FCKQz-ZIYq7q1yVtAyGtMW2CdYcQ>06+INcIk9$KYYOtL;U= zjBq?D;5GcbIJzZ`{?F7I6u_lzomui9c*v9GQz(-SzWLE~q|vfSxat4x!r(unTFZ_& zFd=L?U%dj`SRuO%cQJ%{QsO*UksUE$)U^^KWNZJr00hzgirX{04)LTE$Q66}0kKrL zao#Y}zI~GG{=iMyl6UW$F**Ek!aWhR6sOxl$gs?=i60SnklwlZI+V_a1_@R!OD+jrBc;)$QsOEOkqpj z7t727X+?($9wRFIIj(xchIC~ZXtV$3=;=6(^XNPCsJ?saaPscRXUB>o8l^BvKyDOh zp*gMAJ83t1y%qy!C3=qn0n5DzYOx4ZjVzkueq(kRmzjs94kXoh{znC z41j?e7j^u|d4u++0N%I|Qs@*m5I_~aRqvP0&t)tyYJBqJs_zP-7! zzxUpeT0v$#d^i@W9j@y#&taaNGPZgZ z%I~id1}R8C;ZA_LhVuE=R4oB>8l9>2Q`f@d!8N!Ag2m9+4$|#DKlN@h#;kqYTop8` z4z*BYT znV>4P6iB)hdQM#0|Gi2Te*IBuwD7@IRxqTP2mrMP2?O)-Is%7HUvVO{A1{=o)+S4b z=;It&Fal81p=-J4J?6*R2L}f$Zjb7$q+f7Eq;*>os28*zzP?}krDc9tg@d0ceTlKi z3PQy#x4+b7ef$Grb27ay{ogVCzZ#4oHy|a+#*&Kt<+7DBVDhCoMab-$1cp;^rs<~Z ztZe_%!<$AM6SP8>7Lh~Zf&IN9zimx~cWq7f2Co}O^{Hag>I9^-f{l+aWrIajE6ZMg zX-Z;k{j8tH+C^FQF<2MVE01ZH<`{x7U=24p$5qjbleH;`Wj*LCfSV*YYEwN1N)%+; zxXuIPt(drux^eno2OuUjUuhFO($G43l>Y4H0;QS0lIg_GtaroX?om_$$w?(f=Wo}` z!uS)_r^Tx1V>pvm`!*MJT3PF$-q)1mQ(O&H-r_lQ@AQ9pWdEw<=Oh3nV8+YJ$z3J- z<2p^&ic&!;=)6gg1H%;%ih~;=ZAjGsANlxZ@9^S$KV5jjqnBO=n~LC*lK|F*Twg1T zy}id*%DnB@Px2v0%YztV>6<%3AGr*Q3yeLma}U?qqe49HeT(x zyl7WxlJkkNg6HuA!-E^K;cNnHTPwA2>*W0QaxO>$A`QrGx*jvzEd>U~Vgc#^eQPGf zkA=udq})iM)&RoBIjFqu&m2dSFn#^y$B;#uN+FguFNh=Lg2^9HKJI{P-y_tYX|2s9 zs*+d`|zvJ1Y6#0?57m#Frpsk2jD zW&lDVjnG5%|KIly?OD1?JS?;Wg!jFn#Lx$y`oYG09U8c~0`t1kI@?NpVth*ObV?fBlrrExt0JPu-+DL_H2t<;}wyX_()3w~{ zKyEV}P@clEmqV1hhOV=^#6!Sd;ffKgk=@xLOINm3_gBoe!ely7TriGj)uoM}Z*!Ww z+hfL|Ddr^2Fo!WHu1J8dO0l@0bB zVZkr=?>bv2`;2DFHj+p!t(|4Wyb``w@rynP`lLvR-CD71jynh^_UHS3_2GO|x%Rd8 z38$5#_2vrhRSn!lpHyRYowYEB!Mi@0sxTurunL7nueZe4AL$ksB_a4CV&uvm%^0uq zad=Sd$x}UD$AOb8pJJ?fswN+!u7ef=(CU6aeTlUT_O0m#-*oH)QPjk@ERNm}_O*5v z14C%k6|T51TCR_%E@6MFj88OU^lU%nvo)tnWgO~VJg@n?zTzkou44?sfSTwXu3y!C zKnaa#B#TTOM`glYIgh2|e5BfRA_n(Vt?#gW>e;2Z9{GCj4C)CUpDo4)ferTyCE=b1 zyP5;<#IyZfqavYPIrx*5=BzMW={`xXT`sVAF{HX@beItCT{EA12bWVri{LAFe`NWKrz$kx`(9wc+!Ru(tO8C6M;F|mjvDRw zSPc*j7Y)M9od7;vi(TyJ*Vs<;uiLg{{Wq@Bs~@n~#i~V}@7b#;NpQeGfK=Z3TPmWG zgF#=iO@C=yo|P^P8rZ?sTK*(u`yhG|SfQ1o3@Fm@FUB>w~z5NeNH<{WAkhoa-4_6|2KKN#M6Ssx(8yhW->H ze}3Mr_J@|6A$v7m>JE?~G{i6e}Dl zoA2EstHyDH?_dj!YKp;a!-SZOK042}Y4bnYtUuTUBhc-TP>qcO=efAiA_bYJOpYFW8DZ~I6F z55|rp4#Fq#7?|Grs8n$}kA!52^>jU^H=SWlh}gLsQ=1`Rl|a=K1%}CGMI60u?ycEA zk-1=Tb47QOyF=IU7WLNJqEiF)^5(*jO;kbBdt;X6B3>~QYqz`FJVMYfZlWwsRNN!k z35(p6!86R8aqKy;fAX96PMqEh7XyIlT4@bLOZvl*&{@$Y|P!GQw)f`Li>?FIg z>Zu)ZBo81-&5)S=e@>S~1XkUUn1OfzTW#ofPXJbfmz^#;t@m|ueAEq8uT^;)w35T> zI(C;C!ks}w{zpXrZ#?yY8oGkRO33;|db|ivWC3M!`>3d31;R_D^ z4#b!hgQOOAMZV!Dmu2~8YRP7tKtrq+zwh#ilVw~sVH`(IPhxM}2<MF3mpH;|L)Y zs1^7DbEt^Kq(Mg)PU>(hc~*V?!jZC$`z{PiY1-x<=%qisOoO+*X*bg+ z0nuRJgD~`tb>(*cZL{xx>TL7_c^Z{+AQ#e*=m7dZ#P9Wqy6Wsi~tzRHxF%H z@@e?=WwWCvZ>5YoW=H;-AbC(HbOgK0qL!qPD*W}2g_GL9IxhHt&LX#asaHfTjiwmT z5+UU^^c}fgiEEhW){<6PM47;_Q~`A_8!Fb{4T@k(r5H&*i5=}`CoFuEwSytzi?ZzQ zVpuk2CtKW5J?V+!cV<00cCwGm4y=kzK6pUfzvQ(r&H@$YGh=_Kf4Qym>oqdZLn4mS zE8#2VtvqG^IrUfeu_d*`rA`(xx-nuIf2^5{*>!Y!ieaJ>s1*B;Ev*+DZlCx{sE(EK z;}R9sDP+?b3xF^8GjhCYLkL$<{~YjO2D6W?cmFj6Z=H+C%24sW)E5s1BqD)Quvm8j zuD$j8_8KMfBCiI_Deg(2w2NIi94GtyO4kQzAe%AYmI3<@%O1fypli1cgCGKD(I_bL z$SUzSkH`9Nz%p3Fg-WPsEwz1ma%WFm_5JH5xCxB95UU04NF>5ak7&HTz^p9j`x77kzTk@xPBxBw$MMynv84(E>Z=+qn zw#uM~81@OGj8+LOYXp@58gBML^rI`D*ZK7d-fXzpKw6=^EfD#g4@{Y{&r`D6r4#+f zOHNqQm_EeIOP&ta-7|ZsEo(0cv$$hG^s$H7?56WD?|W}dIjN&-G=E(53N8~&lH43E8gIR$R+PGa_2zWWW&IJIh z9#!m5ZFx$4`YL*>GzlpFzF5pu6aa$nmp~@7l$1*8Gc$SzaS}o6<^{rDJiP`!7!At% z{Q!0aF`?iN3RQ)Aco#^%=J5Bf*GI_#2ngiN1Rg$Ks8_^oYu&Lyc^E#q>wYp+tSK_v zu`=(@a8U8OAD#EGBQ<6zfdVO{KQTNhd+)~>6~9CC54?$_!Ka0VGW)(hB=w^bk0K>u zl1nxyDYOzRRy@C5{_(LLoSQjTYxi_k%4-WD!Lc_4(t2(S(=;FFH9Lsb3Gorf515RS zb$Q|B~;eiGmLP${so}=CYqnSrMCHXg+4zt zZrN4x<6{**W2$5I-%CASUbOsn!egePwKbGm-50hit;)Qk!uBt(KS4Qxr;XvJ+4mAG z3A*H@3ja=uNUWfehun5rHsLh8%<@>@`~eg5p1gFSiYKBdj?_g!Z_`I49RE2QEFBW_JHdD>;Vc6 zE!H=aUkX*r)xwfjNpe$qA~89$sk{}pLHN@InU~;PZj6Lkz}eP0aKhXB*qR@Cs(4f# zggm3Kty4k}RZF1mYUJKol4Ai}0E-ykl0C{BDRfPjGQp0dOvpu zZLC2P1fC&(bUXBGpFH388=WZ8Ej*WS{kg&1%rfZVRUzmjMW)Jj1h z`LBP|KEUyj?J_z}RwAq+&#LFBAM0d*jCc^)9tg1dvr4yGO#9WJv$CnebU3kVjksp; z!qAWx^)OoSBIB9@UJkxI#HxWVOMblDe)C*?IB!fW_Yhe`s+`Q2*Oh=}Lptl@>q?HB z<;yXp!Kmq?L<<*KJ|Lm*ebc(z~64T?7<4 zi^(BG0VHHltN27e?_^6Hn(2%^z?hNYn%0{rfOwJ=c#JdI({WkTZQr~_MNj&BzWsN| z*+=3gG!AMs>^l$2=i9)c9gkaE9VSfwIuda6z$z&CuC`^7AUVzY<^E>+T8UYEPaJ~> z0H2%R7E(LM|M~xii;rIS{cO{{#ngMZC337%3{$TH$I_|x)|NV!vkT_~&kkXiqh_-l znzz6|q*^?fJK>Ef=w<)eZHO}f>HR;iv35Wu9%^_zB^{b@u7VhJfCt8_?~9bt{aSkv z0l+E5d;rMR++PpB1eKk6Xb6mITVUxd=e-tQ8bJ~Evx6J5w|Yz5IvIR4q-Ou>V?4Q<$?Z@Y&yl@aj z8+76v3}2O8e{injwkw4h3qAx}y+%j@U`;7{R7d(DNJN{!E$|ekvkET$hygR{?bt`4 zeu^|EQ$1U;Lbb=n8z=i2oEVeA^2wJNB4pljbCTiMvt7TS7APR(1p2R~w3MLiU)Ahy z17&~9E-oVZ#*%NF_hkO#8YJ1)x=6vtGgvWNosOB5Mdv4N4d-_VrCrb%=z()mo zc|VZBboDmEsiR|k@S~XRF(Z%dxf{G(rD3*06_{J^?|g}GmL4g~K-ZpZf2#pm4xsLq z0=U?}WVMgT&7w)y<8(HyNK}KNrNaEPqW-W7@y#tOI2rOj2(dsqbt7=Z%(43A)D( zFZpVDnAyRKEz)j;S9ZSMQI9R=D*mMK|^XZP5fPzWiN?)+Z8Mb9c zD`&&a7ZY$=DfEEiSoyQt=)uA4w4yB1m>r7BJ8B`26Q|L=n75Vf{tjyCAYZ=kj0=(Y z#`kp~Xtru>cVKwmwD_ggDxW+uHyW>zhB?efINWcLcjtUm!7s6x>Z&H@-$W`T`}_Z) zq|gt6n6~8xdYDtz9nEhCPpoLe0FmaHQ5b$FV=c&7Q32G&`K4NqX zdK1`X*TugWn!lN*nJSrgdrkU|?2Vj3qj?BRLWWKf;9&5TohtQjjUMOf7sI%?p{g*Aw%^Cx{K>t$EAs><7K4%tzh5@ zTnHZu@yXJEdnEV{aO2w8!tZ>JG=jK{2}oG$-TD!;7nbzlzQ1=lR8C+A_|x#UCZd26 zP#l$tP`=QLuv|CjH$>~b6eb&Dn|E>44x%Wkndu1*0Z-t9{{JbB3K_T$GJ?=^x#x6v z)QN|`^tu=n6l>2X9|v$&E08@-D_C2H@>LrBDcN{XG=zv5G(6K6tbfcQqJ)?WbsHcG zohM)SNZXuc$73XQ@@mhq<4a9$xrlw0?1z|$C*ZhcH^O&)ZYV&?`|hW)0I6XSrrWaM zqd1X19}VZJtf$fHk>%;*U*_ z2r2}7Qb)>EeIX)0yiCp@bL`plSl{nCri6|O6&&0OroTUUCbKO6_7}WPhXqvDyup}H z*V9ljalK>$*rnM0*I89LFLGtFz4jhH8g1vB#ALK~{y_)niNW*Cwe-N`$C4?8UG!$J zsEgtiXt$8M+^zM?=^EKEZFFhk)nO&U8I4btQo{wY*Mkz%@ zI3<1&oNqw9R)gQd<*VnK)0yA8knT}kg9>i+OERH?|8?uj57WWvkdQHXOfcZkm$g^E zpL?$N)}Z6gL847uRehrpnEHkoAIEY8CYyv%GUlm&KNCJfM+i63e}^fXliUg#oYv|af&F{2#=elw}-+Y2JdzS;-y338+dW|0~u zS zK6ZAEa;%>)PgC!ylR1<(JljQ)3fK!RCq0#J5+fUW)gh)z7&%;eH#zU(ZO;YBmGkd1Yjl+5_7o8kFv}9#Zt@C zVVh&Sr2H<7-)EZ=#SEMrs7Z#H05$sluX`jPkqQ^BN_sRX3})Y{n`!@~&QsQaR@d*( zxg<^==vo+3)9Y&y%n0URv+^t)5kGT0zN$=_6*91{$?)D85I$*=_X@xIFd}C7gl7{T zBRoQ1w#?)gW6_ry7kK`g(UH=oVnalZSe?I544!Q(Mn$$dZRq6-~`_PuS~d{P6XcSF)^krDQw& zq!BA1gk{x0|GJf|a@{wK}k0{A<%Ng5u!i$X+aVt$=qfb~-B}0KrgU z3jKQn6p#~6u)ds6pXln)I$LLVtG#l zhYNVkH4xOFP@e+Z5k1d2;e#(9(g)OG>0Hh$?=_EJOCYgP1B4e}#;Mu!ruC-INac}( z|DA(Te+QNKgvho@(OB!^moq*McSXQ6q}Plwz2Z!Zbh#;b8=teX|2#(4ZpVI2)B26D zfVgq(X73P{*ho$A)i`lUjx2wKB=L6tvEVZ~zlfkv+k0M0{lGbGyU_#&Mq-Ko4P{Ja zZ2~w~L4?>FBe5i6$;Xe#DG1zk*x--@ZOhb5BjmL?=QrBX7EIS*xc`A$g_h5 zg7GM~PTI+X%FQhR+2(scL~Cbr4IRC|It!c;#EU2k7>>DauK@G&X^0{L~P{rka0^rHss@ICt&;8)z@qk)?;SAkT*m+!;xPBAieK)o5B zdDaPr);wg~aC8kFkvlQ|{loi4x%A6uc1@b|4wG$%i2;$fU`VTDcedqz(+!?ggH!cT zg53gNiN`XHZo2;+5+1g0hpS`p(m5Uv}*Hj8$4cImVa5Y&R| zW23s0b`~8gQtpe>2DII6e@T5tF%j3H|9e)3(8BlbIdBfHDH&%Pa*C59Y--XjWQ@GX zkL^=^LkNEEO?ap*S9X*_jHES}Nowr5Z215bAD4D^fq*5@kSFrz4oWG3)wca1UGz|+ zuQA#VJoU|8+gM0E-&k#Uz6o&JMnN#^KhrsRTiq;!Yt_|G_t(RpSKRK^U zw!a{)kX*yw&GEqbG|C0kJ4Itp#pZ8_s$ar4GC-RvL}MJH@Sd)DHcQ=XauGbSHz?U8 zbKHjx0nB({&p_iQL|^Z^5+lZCyYVRl$7HFDWx>5M4sMx&M({3c2YdpDi{y(ONU`G4 zcstU1vjT+t&~0%?TfPlu7O`6u3cVM3 z7C{cobSzJ^D<$y2qI{*%&-mOZswMdiip|!l$iEN%l`@OONO{8NvR$5ORB3so2V7J5 zq=LW;#g?=)VV;r_pzrMO`&RYp8o$6bgU*MSwFk(@Avg)DfYp~!%L{;)H~bo?i#?Vf z7Fv0;JMCC`?`NGBT!*bKWosY35xEXqhZV9r4Fi?00{(3v6TnumQw%%Tc;T_rX`-nov7B z^P7%ZKK*n2edxI$yY-qyhWD-`#$F&RqMeD*X{{ALZ(16M47qylL!c4^ph-m=!j8j) zirW&j0XauwtT9T60hl9wcDFvS< zylg%ADpI%U!Mcv1i^y{B3DPDbfk^>27Adkt@ zagMDAmB?XF-VO`%OF{l;664I94Mj%DNe-RQdqu;ada9N$rR9zavqj?^n|hnUh6y=1 zWTaX6486MZB`o5O=iGB!ZP+28npX`}dM;@a{_TNQ?kfiPBLUypqxC}6k%ye&oa{VM z*~4S@?Rnq@yS)IjiwLtKQ~3tIn;@RY+gEM_au=?Pl$Se5nhv~UsF8Ah16*Kpl80&N zj9;CrE~|`QejRuFu~z~cDO|J9L=|f%&qpFj4v1C>V57s6mk>_4 zRQmIpami_gXFg!DHJw*G_O(#LAQ5;oc>CIXIst=MEr~%zAHz;7GSo6@(MC`i2g&u9XZ%-!>&{EGBhrrp$#V@}w&iM5!Z|x!AV8S zM2mlz9@AkYzx3HvY4z?cI)JczL0iVSLa$)d>w6@<{Vip61j`D{>r^WFWmx!IPa20~ zw_V%)|1FC+OwJLnDi!MWss_cc6?EEjU4l;}*=FQjFx7kR>8Te9jY_#ai@8h*$#(*x z@6WtQKXg6*v7yM0CA-cgV{jjA-188uVVDE2um2?ZLe+}+BtUY3hR=CGcqmzVrJx2t zZhc}gE@ISVkg7lM`t7O}^Urt>3S?QL)l=p#;lzjLFsUWmM5O!fh{x9z+h@~&mO&Du*4T~nZL!OUzoGl z*%SM->dEnmvYqYsuNd|OB4%`OJPVV_M<}gDxV)<)TKF$>6cT`gbB@MkcAP%W%ARZR zF1iKDvzmTlOa-h}qKm&ApQ?o%S}jiYuuQEcL}2KbnQXq2kC80VjRghl;VSm^bn1@g z<%ks8$($zUsA%GG^SXt|+S*PcKJ$N8n1$*8qwBrnxo+S0aT!HIFSG2KO;$!$@-oZb zvuw&HBcnw2DkD34XC!2n5s{r_L_$_J+5FCD-S_+R{(L`=-#_=mEe~GL>v>)0b)Lt0 zoX6o#s}%bDL*$-SJkuj^)Ky-46HfFs_%hN=_I90*{=HXsD_zt zroniuV*zHXZ${s8+-&yvA^F_(>Y(!;bx*8H$>y=`q(=$qCKt^(o7vA#)dDkW1(R6D@&v%U4Fk@gGgSOlYN?`}C3#>BQ zwAyo^l#?KU{fG}@{Sr|{$iXGOD1kk!>(MXUHU*hC7jNsO=jB4qS#5cikhdKc;~NHytcSI^aqRt+AzT(Y#5D?n6x}*4d$j@}t!6D~5>!-YTB0 z(3Ka#s4&6kG%RWOiGx?XIoEw1SGnQO)${Kdx{R{zc2wJsUBQIbSOgCR-ASdqTiG;U zswuB3ikK>tYZKTLTjj3}m8zR)hxjF!$)(A?PS=NHWuUp>$`Y6;70zx%q4+VI=&>A@LEP(7gJcNEx3 zs|-v2O^F68X3aLbUyuZ!uj%{ul7hDYh$V0zZUYS^mekiHP!HQCOPGAR(C~9NP6{wc zYbFd?mgXQ$zC+l2a=k`4$&qX7N_Xt!!&p*vCK;6cgx$lMomi^l=TLqHWB%?R9 zUDvN)2Zj7S(E-rS>6w=W|4ITW)&|7GEZ#CZY#j%0z?8RmPIS|>LOeLn=Y)xS2TJq* zbqXY4GF=rMPE2PayVe$?vvnbfF6YS z&zj*Q<9o233Bo|QeORXJM_{GGZH2%uv?q zkJ%>}LLo$?*$L8|c1k?FS^X~t|DasO-SZhP4!ZMsW>bldwR%Oc^-0nVnu9a>uG z?Q-Z;R6EUHj1}Kh06zj86IK@|8w+B0AZD?o*2gM>CP& z>;Widn7`UczbmWn2d+Ua`9UpN!qxv;&R5PrELc2@03x5}8@u8AK*WeSIoeX~*1*6} z1rWw2iRa0ZCaGL5=2CdhZ%H%X(XT&(9H?Z&H-UO>NeScnOU6JAbM$wTq`6M=M!RauoE~sr z9a9>Q?|(W@M+bFl>-tPiyJ`TMN?Gkh#XaWn%^}5r@WuXROV62}7h6<^`r%lYxJU--!0_63FlT_Qzg%v21)}-9`kJq1+46#T z7ZZMCW`y@4q(8u;=&xI8A*0CL3pi_vy$9|T?d|!Q_)Rax><>nb%fpT=Oh382c>2K* zW*%RGYew!T(aGxD_00S?^X_Xk&)i9_-Vf~aqwoQy2(&dZGw0n%p^;+_Uq0IXF7-_7 z^nLXNj1Be!)rMOc3uEGOS((24e=$Uo7edjH_7&6`3m(Y5Y}rNoP3<)m!S-GqujzBa z$9PBPd7j)IRPlFZ8X13v0mJn0KsiSgv)=i(cSI>KSKqm( z5Cw!1gnb~rAd374=;Rol(VzvZQGNf%lxQ_b!M30-Vqkng?rjKzfRr`#Kvw$8ELJVW zTP_^Pg=*jxL^f~rqb*cKSXvm6qMdxbz+sZJwbwy9D^?GS0i=$aOa@$BfWapUpisf3 zJfpIv)PSdd%1BL00(8lh+p%F!drui>1E6;MLuc;;Lm8QQgzQl1IOHYHgINlOdC&yEt=YVE*LCJ24#*7NnwnLF8!PdzfRSS>#mX##9$YfRy}UGl zhE)vO!g50rHc(X3ue4DA)qLa0S;F1E8-YIKfa^L_Y66tUyH@sxwpd7%Cc!$po037| zjY*W4Vep|B0MhwBxU(W%=+CpI9bZWA9||B56Nfy zK9GpGk6~gSN=By(u&}%7M>TqRbbbp)gCKsd0q$$Wls1(c) zmBz~K`D)N4lCo>a?mBx7g0AXUWVb6tj%kK{_F)iOoiLaSPlQY}$N3)NetXcmS_B$M z>bC*0$ZFW1imaedYdWb%CVQzq0ZV&gf01l-ncNU;h6SCzc%$isKrh=p+UQo&)Ir9# z%^k4|IR%ZnWaEfhIU*f{D*=H4H;oDy-h2Ho{FOZx6&uY@?4iqITIYY-gpj?@ft#uG z_jhL&r@WWexc0>%he$AK^0Cz?#ul4jJyOW_s3LyK_^YbSd+@l~?==^}Cj_wss0(7n z`ZX}^m~8M0L0gmq zTm8ozqU8AJ$3hN7W4ZLf+v>}FV13_RfeW*Gva2~DeV@5i(+|qcEY=RkNeKs|3V4~9 z!u7y+NvegnvC*;|)OggWvmp%QkG61L=V zsUSo$BI2S(wT@<32K}mlTs=cC)AKthP?86FOusAgpSM7!JgcIoG;K~X&|9Y3|K5Cl zA<-~8Xl)T;6dJ)1`wj+CO+fZ>!rwW^vESLywHY{V7w2Wn z$^w$e6!t#@LdXMzC6ry>@}@kze&5Y4MseByCG$s(ue~^`#v~E3FBJaQ`GXyxm~!IY z?zc-P!Ss(=JtAy2|LN;U@PT~^Lj+ofHeJDAvb{<#-3-hftALUCvEF0N;2$h!h(Uc}4>77PGBhi7KWvb*vO^9o$|)G+7U()1oaiJK)bXNB z0ngp>YCdbQwsU!J1)08%^{WY%2I)Ca3IFf;Q8Ca@$X{zZi~?T(i!FdC5Ws8_$aGF2 zN3D3^%s>bi)(g;}>3~Wep|9$|gc?$}(37k42J5wXJP0?EUDLLe-mvIn2R8Hu0kmnjvqPnus0?*zmZW-$R=*J08$wb#z_}@^l|{wd$3yw4bTVkafu{p^J*~5@2kpL8B^@1 zLJIQ7zf;!}%bCIB77^`1&T@plD+jHq{7-J3bg{@6$^uOyXk;kVVI<{0ozM_8R~8Uk zV8ZC%H~(ECidKRnlAh7Ow;ObpN(9ja7)U~oQKQx@T@16n*^b!!)J>ghRKa{M@a;R2 z|8+Qr8PH%SO-eDHreCPRQXjtMZH~Ie#dWKqL|v5l($zbUoW^w;FC&t}`8C{*@qK49 ziUBcD2Vs$17*(hzfhj{kU?fnn3E)5Be2}{S6g&wKcPaaGw{xlY!}>YJ{7;|d-a?;R z1Jg!yE&{^NKNDg9=PiBb5O$o*0Sdg5*%B$#czCxRj+`G8nvUO`G8uK)TCjfaT<1_*gtM)}R=c(;J_JzW1+rC>I%0CLt5e%y8 z8)?zsHA>95mjwAAV8_Z>{-6br<^;T7`XI&&{d05wJ2UX#vsmA832*E@yu9yf!qaQH zUUB@16(+ZtmJYpyf?8{YGC^o4%bGXcAOYJ7377C(WBLGEyW4ehKPC$1-8}rTtKAFy zW75Bgel$qJAtIH$rcG5xiMzl&`15#h1AV@W1_6=+xY9nqGQR#1EG(7Bi;dn^GOONr ze$C!sd~w*?+M8_w9AI61~}!mX&P3 z0QxO3Ut=BD!8m!%F#U*cel1|V_6UyEe+^Lr*)rsw+pn*&h&+Jjq@nQ}J2DBc&9+Kr zTsKYrQNu)y@?Ft`l35k#?J(nw@K{d2mGB*?wC7f{+T zmn?Rzsdj#a==7iK@eQEfagcFLzwOUv!&Y~E-ik4ROVp5h6d+_eF>>8kAx+Hv7taG% zHTtn=ag%8jD@<(UJ#53<;k#zjFTGnw?YMsLxwyl$tsf#!b8IfVoLbPAHaTwr$pV=v z$GRI=)EA*TUa$4{^A%8!J?+gF$CzFuc(d9pMv6y7RAZ=O0i0H{<;tu9!v;(Gc1DB8 zAz#G6`w=2R`aog@E>!_jEcBxOkr{!Ig9sWJJt=GT-Dsy6+@%M$ZXV1PD^BtLnrx)X zRETl@@sr`e=Ar9dwicDI-$GXc zOEQ-+d_tQ$7{Q2Cn|cbw8XCQQ{;tnPBb-m&)Q1BKR8gOjg^Ir^P&A##dMRqv8FORf z#RHhA>Nn>AXnM&*EXFJrnKcwL*nSzS`~Wmzf8fG9_J5!89vEUb=ps{OT02xSx0)r0 z|28%~!P8_CS}h=8!FCOLl;rl;HHMSjVN&5>kWPsH{#9533ut?p#d2Hv%5&W zruS6J0$a?7Zg#g&EmJ1J@CWSHNBesZL7c7KQR3P)F)>P@r(OchG96m?n|>8z16o1` zMfvp*^}kZ4$;syw51sXDkxqLmcV0euJQ&iNBN#~0U|U`OuV&@n42SuI=6BU)eoOqX zTX2gZ^%(M9;+Y?s9oF71?vC#+^(SQ1j9C>%M}%@DyiiYOK`(uz_0vaIGiaDigyGV zhls|cTx@~<>nmol7&#V5METT(%d~+bcG3L>*1z{)WO3$5K$%VZ6`@k}#RBQ>BK& z?QEp?LHiA8=;ff7_vWF>rOWzlQ|TlS#`S0#eRzj$hobqddJ04Cr-tga?E;tS#0D}D z(~P|!LfY%{O!;Lm*DJ!=Tu$VsgE$b;tbc~9|M=eZ?$3NQ9y{zlYw-K{QgTy`@a^NG zbG|N~j5M{ic$p#5@S&&y!Mgi-s2Y1yOVDRS+s{;1v;XR!`0d!n2;za+pgX3>6m;n> zSR{s-FT;(y{gm5#oe4?e&?jD%fx#!j zLH|)pMkaSZ+wSe~tcnERHfH}PfBYZ)IcI{Co7mLd3_4DELifJ(BXzj) z8H|ZJ@Mz6Ot~SYszd!wzGz(H&hR!-o0kbRk1 z@tJIx9K;)zW0jhFBPx}kL_YXg2$a;xUk$x=wT|E*#)hA3IrW3sl zUf!gtWNSA1aMOj2Jp|YvQEd4iS=0Ya#X*O0;6GXe>`rn(I@tDodL(Kq#KCd?Jub!- zzWZOV8umZf-?aQY6-5njIqEh#TZl1Os+A&oc=skJ#`Gn*wMq;OU&LQ|wSMh>W{5Un z+qOfqcCk*FJ5=d~dWrH}?0V%d&62>K5nN7>E%jTFyaqXS551W2c5oTw&}_s;o31hv zG9l{645W9lM9C|}I1etop%IKn;KQ;4IZTb^?m}B6$9;SPJqHAXVgTtwD0n`mU*|>o z+?Ku1A}cL?HIrpjw$oRmAfz_G%XS5f)Of=u4t&+88oY)o1tFz{zg>gq9|G4$n``;L zm;aM!ZQmoGjH*AO&69(CTKFAE;g3~!Z<^XKmp=E0Axdvhk-oo7&SUg~$GFbotq?h_ zSxi1o%-WL`7}3>=PdNh-pAceOE1)ata5w*wd)l8pX?yR^9a~?Of_(Nth_Gk7?De!ns;J-ks9@6+SmYaf0!S=buZ7 z_u=j;H$ja)7~XV2g1m0#1{{=}kzx=x8hrN{D`@=jd#2`d1{pv|55>aqQ*jcUjqVG* z6ONvZGHFWW5L%LE*T~O3eJu{ZumM2v?@rLI)XH^cRZ6s?Tk@Bh%p|yf|j7CYz$1AwqF~C({y=kH47b1Ivx@d z(r*iCYIIs&0qmw$r61m^de|PV`!6l1~CpST3>0i!jceW~)?>bvCke_wM zSxciRh0c$R2SlK!B67~8*lAAn6_37Mv>**Y5kcu+8|0l6^CzDUzXtn6x1?1JpVtSE zF6F>;2OHi9WY_}cvC!@cXJlep$s#xz1+73X<;5qb-iA$Sz~{Lx{t5!#;KD%m%5P(s z|4Ih8R0+VjD|4VtiQ%D;q~Y^8{B-Gt+cbp?ZsaA7iV!Iz+JrP;{|6xhgW7!0jW*L0 z7~}jw1GlhhbuKJ0VrDSKzV%s{mooU#fuB6^;=AllLJbZ!peu5mK1kS4|*yE zufpuqJ8q*;w8)E?GV|fl0_LViYti1@q$iH)E-fHA_xH))y3UWaQ;K)tMqe|-K2hsJ zrT)=6E1lD^rsMtL9rK4>4k>Z)dhqPf(<$hCP`uU+j zxXrh3B5TUi0C?(7`U@Ed7z|joglApSzktAAwD^rXPXGUDC+IF+;AJ)KD@w<0HC(SK zxQVI;y^I7FYN$*ww)i;7^&zzAv;x9k8}#2dqNM=epS0u zLmP;^i)xEjd!^Mu5L_G$s@9U7QVe=0{SpQ(IxWUc5X!IIr-x}a|5;UR48Tt|xM6P@ z)>4LF70-X+V_bZ4m4$%r$@NU!zpKSWK0Un*JQIuYv*RFbBj&mO(h|?y2a16(^CeJA zo;HV;UgEYKm(&%c+#YM1TtQZ;zeHakjH2%Zpa0|-q7s7}@GmDxkWPZ>b=za2&kpN? z-b@c^z0vEVG8Rv7jTZQI+!#bpu$-HyxB12Umss}L5xkmNRLORI4U;Hu+f1sAm@xf2 zGM7BUzZf-8;vzoiuyDa(k4!T7(r|`(Y{E1NS@xIpiXJ!qcJ;!dNH#TyA%(3OKmblAUn6R2-jx4ymN2I%xxHZk!g);-%jy+y8O!h! z-7PvoM&o|(2w#wAv=gOFDqp=`BYzDKd<(zcvk*%<^ZJWs*XPCi%Lzr~a<6MLP%v@` z;%0y7vENglP}KClpeg+mVP+HSS^pa$Z2P7cI31p}X|KnvRdCVxtxJ1l-28NHHUI@J z=6sZ(y$2`kJ1xit8Od)J(0y^&_e;}n@_E~xi8};xGBudwqCO~j(ui_sHRVz8Qv1`( zz>Remw0K+eax6YRLqItPAo0PhCkznWK2$YSj;C^eiH%2sw^nb^6p~n@CyG|%NASaH ztiSgqbNO7Xw5Jls@=SuHrcUzQtW1-^q+C!PH?_&c{7BNj9E_$NmAheO8=@d49QZ_d z;N)fFd)n}Iv7PG|au65GFAm=lT`0IvTYicC)g?wu?-hPhJC5fyPP>YFw@<#jtj+oL z-(7O>cB2DF%*UR{_o3zWL3IUO{Y_f=QpA^2$Wyp+0#E4s<8N`-plGR_SfT{SBUZ?PQ^H8uv%d1wTlG)| zcOY{LtzaC$)@0}42qC9R(5wx2y4@6);q!m?NIng8qKEA%RDizNf2@0q=KTCJs&&2} zs5~}cK_14i2Sihl1AdvmBI=j)L&k5eJW9}U7bfBuMM*#EHUwNfQpn)r6k6FwMH;(5 zBB#3;twjnv0CvnyFvnm^Uz(B_63)2z3TuVn_aE;H!l@coM-dsar8iZF4Lhg5($2jg zu^f5oUl@L%D6b|ftHwT<`y^zWhYQU<_;8ITr^eV@zNi2y-dyfGxp-LtY{rY2VT0Hb zr|F}&9$0N&Nm8_Wc58EDWXDK}-}l+A2b0(5bHYvu-;Kyi32NS#Lx821fVNQ%K})hJhQ(21hp<{@-qp}b6sRLPu?2E z%M0#>)8l3d$DlD272?nbyKB71_9JySb?t_m@Q>X#h^BiSi%EJbI|GLg3(=>SyssQ8 zZ@+j+oWxd1Ph3BJ-Jo_Y%T+%HO&`8T*~$2ukf@yV++2N>AtRnemMxn7#EK@7LyqgL z;4eY0lf%G@$G>^ZDKx8%ywV17El|gyZu+u>s`(8G&Pz3P3a{O9)y8xLag=pi=3m?< z#;feC98R%W9n`M+`TR{4$X^{mi;!lj7HFk`oKMaYxadfUx&Dlp8IZkNM<7nDpf9Te zoiN+B5R}_&C#qxOVA{*m#6e>EHgPPLKK7&j_Z0dT02S~I;!r&UX*1Vb)7Qc1SdBa7 z9Q#qB79-*pZa0_{YL_Ak<5N$$Fu8&_=R*7#jzAzan~WfU-Vdiz;X83?Wyq&}s@)*r zMyz|tKm77-cKB?G?GxF2hUF!|D0#@oh z*KK=lOjjw`un=3IM2YD3{ryQmB~juLs1XI^;?GfJOLsDMm9MMUI2hYX$FW>Aky56- z-?(|In#7N$Y?!mH8ZUyki-AiM(ZvWI9VZmp^Xd^?XHo1&PozBJx~=)zP4xrMxyVJ+ zu2g=T`4E|Xj)+zyj~JG0Px=1+l5(;hfj;pkI)bBFVTH-uu)zfvYObp~g`cH20Elm;TWSp$dc1e|xyPzuJrn9pylSq^2lXdfWnRLRLS9#Hr-Jyl)IH;6 zbSLA_O8fC^2CuCm;~j zK)MDL%L{k{=tKG*gOZgrIcjR1P?k}VoR1@f4NIR{&@xth|Ge6l`Sg6~XlagZC)p@` zIzKI~eHBZ;t*LOtfgU7dCcJv>n#P}7t@Z2nWP3m7iDS$?82j>I;Z}0ZVXBk&>!F8B zMSHuhGVW&NsM?^;qgrf5%g>ttH)$dp&}YSQUl>?rMS6iXo>-9GiLov?XqWy{G%F zIVw;{&L1&(*za>F+FNUq@U^=C4bq7ZvD=q*PJBnQ#Hsp&l&1cc2Hqg@uT4IF# zNc~^?QNsd7yS~i-Ahxf0As8sKl_9Bx-BKikYC$se;a`OjmBE~_HB4_yA_vngt+5M5 zTRqx};2z79IvM6vOlgUq3ZZGJp?2%D<%F5`8&b>j@rA!MM1z6|69g>uTMhNfZG?TU z^?Y;Jc|)z9Y=cdB67h4Udqu^hY$#46CqjeUkcwP>tSy5?V%dc?e7&w~S&u?h>US%z zdpEnB8LNKI@*c~!#S-@W;O9H-;u`qE162wO*p;~JaHudIshx$d!r>?@dG2$0*yIZl z>tCF<1QffwVivn34(MV)rN7D%F7CbEnX4UZO-gHk-HPTmgFa^zF+vLq zSfU(-)qG-N$w*7@TXztMFFgv1VOAve{{|dQxkxo*>>C#|F!h<89&<4^P1A(P8cHm{+b#bPTbSb8lxVJ-LbBERVT+${rFH4$!NF^H-@ zIF-0?8NydbjQy1%#0h+!zT%gBEbl} zCTEFyY8eW`gUT5sn?5d{UM`jl*`Ly`>%QR@f3#Z6Rq21~XZN){({3C*g`{Dl@mJmI zWPEo4>DeUl7djez{Fp4nSp@gv4^8@TY@Qf&n{WeGECl^p;{X@>+oViC;0$y!Hl3{u zkr42lwDSH@*mmen?QoYRS%>PCR_4*sbYC)i1}%xoJ?*MWh4zInMSWKZyNSF^Ip=LS zw=TuM8jG=VxX#`qWpQGqje|$1OHepbPUBAdT7j7R>?Z?N?!ykXOnz-P{Y}y$+qX;S z!g12u&X*$olUUm0bcBxP@gG(S{es_bTOI4fU1=TSaXD8O&iS(TUS{h+7h?nO1g8AB zO8$dlNsZ$8XWC2n#gXgRIy*W`urP=)7AOp9 zh;rC_LHd=rL|-@(TQ{@^*UynFR*g?DOY_4&f9jynm$ibcLL-;Z%S05~Sh*lmbctL7 zcu2Wp34kUXGKZ&%dp#3#TX<=plPMEq1DN~M#B*?h(f99jvAAz1e#rxW?msbhT72K) z2YopI^AXy--Cbi+FvX6r9{V{10kas|4{)tiNKSKvJ6$Q6ROLSCP=moyc17Y14{XuN zHdy@_v=4y=>S12bcUz9PkBS{^46ACDf8OU++6)qR1j}g5lcdaeblhF`tkZ}$9B%{W zF1}h~kjjfgZ&xKOGTJx}mF#*ZFW)F1?qcj(7pspnU1DT#!ilh3b9)f0q*wl;dQpD3 zs>{X_Wx40+mQnu6(ik>rW`2fisrcsP;I}*-3cNV3EPICMwLHGuHnRq^n32*(8;@kU)WDiLHPLZn5;fBW-{xn*Ts($6CPqq z*M5)YYVkSV8TnjcE zmZCTukWb72#THS$?pnWx-+~)JF9!f8`LuOcKaQa-g9U1c-)@xlZqmb(Tg_1Z&yYy^ z!2R|j!l*9AuI3l_n|P;P;~P;5)bNY`W-`{v+ZLmY8xoGPtVx=E!T6vWn=08=)?>@de zA#dtM%CDs5t<0kA1hegNgZ+iX{7Mv zOVCaHwoxKUE9s*EI(ROlgtTfyuL&3-F&EIUb1NKQVuXHU{7L^X;aMi!fLdix@(Fy> zwG!Nyhx4%g2`Xg6ji%;nu8%};bMXYY&~4z+HA$1>`FOYc$dWDBkT%{^o5>g7ludu3d+ad8@C?1T6~XzmmW1Pope8_l z^*a8wA>sP6-|!E5aqC(PvV_VPxlY;VO9vY2$6cA>yQu@3+B3t?95n_lKBjHO7Y)*% zP`*^99wZU4KxKbCFr!rrvBJN?YUvrLHbyp)5}S3`|4QK)7=~O7hv@31w;5%u1C&0gbw)(yvl?`HeUCh zya8Xh;RjXsMtavDaOsu5Lz6CkDWw95@m{_lXIRA7K$49ZH@(oj-hwaawxE`H<}O@D z>QD{JC_XDtj&@bj4$a$lo(~c>M}d?;{$&>RIDy1g``y*4=CwUHB2x&%yhFlt`BB;A zYL7MIwODp`uqzd8;DRuF01>vw1ClC2#%JidD8nB+51-^$^JOUvD^=G|(GF~mT}ZpJ zC0_MxLm030-eBRE%db~MoD8CL$fv&uWJ`U`uW#=zX5lDcCREJ|)gkAre3jQYymh6d zqhZ&P-ZJ(=(rb;aCXypNs@C!1BO~RnCg!F0W8NXx_TLNK3Q7tMPD;f4gC_b4+@hXs zy86sKJW&WG50A8KUz96)Pvc|jEm=(b;NqzD>g5}p;p_ZoASX)TVtj4Ze!d780Ipsx zVjiovAO&~3HP;1TiuU+>GMR|`=SzmXGQz_Q}qv4HXp(s!r%lGuiWtmg4D0B1|&dX8}E7y z?GJ{0=#)H|mIRmhjdqIfh=PV-HjouXB!^q6IvFz)W0hwL9&}G?gkyOOcg1~t!9q1` zKNS6;r2jO0PYtf3t*r@sUc6#0g5=jI?j9MDF>+R&ZpPshH}yBa`$;^^uwluVs{;AT zv^xvt9@>2WMIdzYHK!%Y;o;V-rmfKOp^k`sZ8xL1>bEbx>~%92txY)8 zz`NPwI7+)wU%2_z2Gm;d&UF5_kutIItcb0E`0e8lO}q^bNi>~|^n+jPJ<7J9bk8RD z-J$K5bp9bLdqj`L|VTuNCG&!awJ7Zl~yYH-aYDEfIwXdfXOhK1Dl~zc@#I<&~uL+KleBr1esM2lW%@ zR~?DDpPd|Uv_9k2eOzMH_`t_*Ak@?m&(Yfu9wvQ=^CuIBer}VR9WjeXEdOV(Pltk1 zz3KdyVY#2zOY>#6O<4Y7tP;eZ-s1>g_rS{4i@EAZRJIjmP#sL^K~+fkOE>S%Vi!-m zJSKa%+vS=U96~mn+NMfBZgh=&P|^=%A+%?CK3PGBFXi8AOkOxGhcRv6r&*O6E;yJo zua{vs07;0h-mP0{lFyfpT=gt|N%>W**I7QmWC)-25|WaL7keY9(~2UX#GEl|@Km>x z=-r>s@TZQ(bR^}xa@wV|RV-EIv zdThm!f7ksP`o(pRRW&fZF$#;TcOt2_3~VftgH6Ey5Lkfowg4@ zSG_{`!?{Vy!`O3cd`X&ofP&UyMbPfdoX?I>O_`yn_=2$OxZ<>nj6@9jlWIk!Hd~VP zHM=hl#e2qs?kN;}(A5uoZXn>eXh(2-a;v|KE)}X5;}6_<+esNi+LnrO(%p<{9bCkI z!r@x_fdb$z@BKy=C;n;+_d_|JZ!4NYT!0rTC02wRV_)m#+&9gBOpfESHq0}C-O6YP&f^3G!HVL^gLVW1!A z-;;~1Y=WXp=1PR{p6=4D3PuiI89d!E4rWKoUg8KaP4VU?)=6COSU$sD)8qC-LqVTXJ z0WAVar+e%RAA5QDhKX0r8gt?oz0|B+&VlbAyMso!Un9taG&vt3r01CDkkydGKCgQB zZ9q_j+D%7%8VR`IADv>$kl(GHpN2`?MQ}%S6Mxnj!*Ek;{Hn*Q(arI9&A;ZBm_g^m zS=@{^1}#nBuX#7{Ij1fU+s~Zd#tNuVHjc` zCJB|n2o5@X(y%OH$0i|qN z!*E{n(i`4`fy7;FV9DUL$lOCsIZX;m+p7QVZ@bRHmZ$Z zbn<&qwG82CQcwH-aCgZ5V#Rm%1ndT2f;r2PdcdZty5*L>g^y!CLBAhBMKFWaXYaMe^ z$k`|qc4&iBJfms`gAd8?(&}T|6WS5UeFYgx_sc4(c`8)?4HLQ#f0)y-AD1(P2bzjS znkcXof6g{@{E{bnT?IcMZ|Cc6a>eI55^CaX(#_AVJ@XhNdKnk6bipXe>O%ra%hn@o zsrJc|n;%KZ?!p2;nYVwi&(h3%?>^t;SveJ`#kud;k2# zsSI({D$)vlhCk_%`r^fA)S?-ahLxPtR{q#J)JfI*B9R8<^2Q$;ROjrP zZWaz%ZhHTM8cvRgR@}+QELbWA9is*(ehZYYWR!9*&n5~WzU<}my3=o^1E`=@=c^;= z;lEmNVm8?l$6Qz!mLN5ZhD(T9fzA%kyniCn#I*Qv0y^vZg|MlPHl@p4RkhE_0S{&d zXq3a5-uGw8Mj6xYyh7ah9fN`~5;891kDq5Azeud0^Hqe4-#MMra%`JcnaSFh@xK;+ z-FFx4BQ=4go%=AUV=wp7OwTg?`pSWUae;p=Q|mtg}qr&e(C4f zqNzoHyci-&;EYgU>MDQ4sLafD-wDXJtU>KWma`pmi=MQT*O<&g_#1kZ_R;ud&l7o* zb*U7XFfOrvt?UF%2<_$hzAUw8e8^jHg^?Fzs(j>zGNky-PV9!&T!=CEh9b>nY4RH~ zd@<+}#oh@V{ufy}bf0u8*;Wcg6^E;?zS+E%d6dYT)H{e?=E9;Zu`i%K4Y~ znY$Bu;N@_&5C<1vW3=$RI zum|6U-Ct4R`6TbAEojVzJl}syQEevn<}j+n91itM$)Xa@zRAjiTef%t z_?HlJ4GYAeWgz2|u3^q9Wz0H9F1p=C_Wt~uFcI$S`(_qX+voXU6ZWf?TgLJ>?a913 z*`0jY7q=tjy{$kj0YxU`=wMSGbZ;P(s_E6au{R!VcE`wF#J&-bs!+_v2nY*aJERp? zw3lnN2wru0@j0p|FIqWK;W;2xw~YQ9#xkC}GD(PnHdJ$3^QnmfzmEHC-kX?{S^lCA zcQU-R@~#S{AK2OwoW6^VPi?@LyA!i0=rN){y$v8cKis%o&yI8Vnje!GawZfJb4&0f zo5Uq{5>h- z6~Hpp*GR$`Q(qVxgH;o+Xx`b!XDb!3&1H!la6}ZgNZXiGF_h=~ z9UGz4`v2J>BMd`YKOXZ>RNk95nTfx0@g9P+QEc4pdb$=kRM%Z#`H*eqbzgnwBG=h< z-ji>k1X+M>4JgrRpyTEqO6bU5e&a)A*_CX}!r%n`4{CC`DKFqLRw+xKczXWRGg;Ykk(MX< z;{6|&QT9*tr+ufZ1zqk@O|tsqpim7)jf1!t(znV1u`GUss0APeuV<=2cGM zB@*pzoOzU61RBB`fTyY1B8Lj2(767Z@6NC;iaSK+pEK3%%G)3d#q_yBQ8KFP%+c#A zczx%H8*Ul-zRPt%XA~COo}}ZqEPi9TA*Nn)9P^xE%k^_Xw-e>@dy|@YxuohdLORz!G>#cT$6ay+}J}e zhtv{v^ir!Xr|0>ZPf?iUmq7kLEAqZc&vIgC17~7l6rKUFRnTn+K^;EYfhewm% zO!YiwGvh-#famS4dsCf3j@=>__es!Ma@IIg+_<|H0lDXknNg}v!vqNmdJA_J18Sb2 zi3w{S4pjS0+s}1)`deHMY~9-Aul*6MZ5G11t-#dK_+I1#MxaI=?Ka4@$qecpmLO^hHZKevL|`yQfGo1r3N4ufkX6oiV0kohfDNXl*8~12XOCT_&_IsXw)RJQ1X@pvxydL?~;L?!)!bP`uG&FKQ5ngC$Os0dT;dzFMjiw$_3ur3q*vt zBNh{`;=fDpcIj$TmS8E}tkXDwMDvn7X3EI2Aj{k%pZ0}ow8grf=XE)}Rw51JK1E)& z5F`*I%tUgik#+?6I!~Ka*>q>&`UPbwXrZY!_C|Ixaw%3(U!*z}&1oJME_3}9#@nD% z{fH||Wjs^BIO=u!1AeX*wrKN@90P*B;U^c0pYx+h1`NQu|h!$4A)km=d00bqu0w%8R#6{olE+zMcu^Mz;r|0PxF=C`^yBNo= z{gJMlx$NsSPPD?$q5Fz@U~U1CR=OzqZZVs%!j7J( zY;VdysqK%YC)+B0e5AP|R&tiHwi_LZa}mGFNZ0lBgEY%;SyG;EjMN#tlqOfV{sYd9ih2i$ae5^D$BuuG5}_Qk$RE&s%Kg*l%G^RSU6rn zTwxb3^S7Z%FH8A}%c77Km%Z>$7QFnYx?Rm9Q;f_r-rzjGsQec?wCHthTe;rwW=VRx z{W(`aph;bnsSMOK%i^Z`rk`#iNMb5Eo=WoerIolhRqO_J!~}$!Ld5W4A90yMlQ;Hb zlO~nTT+Przc56kikN@I5@AdU`^YeFbV$h5+wY4HOWL$c3(%sL$&lCy~bc-*4DRsSt zCv9;?Up7*KG&Cye+Qf%Hdjo;D|K^^3-(L{1xsP4g_D|*@c7fy&_cB9ZmOWat&fsS9 z8-k2zD_E;Wjo11fN=PBOUo_=M-!u8DObp-5SMR5KwX(+3;w_k&3Xm3AZ!q@AnzTT2q(^|ZS_ z1PV%8rMP^1sUf{&=ePr;nV~<6ie=LPhVrFhWQjvW=)>tVFG0Usap!EQcBxPtj9jws zMC|vgb)(wq6iRIdnVWiRcRJ!!qw(_55m0?KZr+d5O)Yg1!LlmN$Jt9gA9}&GDs9lp z=Yy9+6}x634?(CgLjx@i|KCqO^_}Braj3q_P*q}w6vj8gPxj~OI^y1~9(iw3P`)V* zQX`$dH8q~%F!A=s3z&U^8@DBmEkKDZlR!7XnRQS4(wW$^7O~j$VaVsj0HypHxkbk@ zoNnELiQM9a#BP5)M?VD23xPj~ndlRwtPYs%%I&z1d%v~>-QJTMl<(?Jd&yv?Kt-Er z5%%SV4cc4byehzq>)aNJVcNHo z4+r;Ll*v)*Q%5g(wlJp;_KwOUr*#YPkUem-C;%28=Q7_*(~%tt;0DQFs}?i$%asU; z@4Qw%ud3#wu?Eh`5<^*>r2t5AF1{G@^wP|(Cm(#LZ?8#1Ya;7{MeIAv{v6Bg`_-x2 zde7QSmpEHn6xP`3vRf2zbD+f0K%@32UXn$MLSmU=W7uCsDHso;ZA@?|jB2wKdUZ`* zIm2>pCx+HJ9VY3&x0hn-0Fp&-tNicgDN3dPs#We{gr8+p&w=?no9pHRd#Sd4Oc{pi z`t%qrs8X&fJ@3&Uvft&BiLHDaRIDlrMb1=vz}R0fSbbeu%3s3Rvg9*pUw$VtlMTvZ2i8UU#c^ zE33ymT0+jThDw^rs)XFcA2X>keln|tg;OnUp@zNASwPI~N5FcrgB(2=Dg#_4?z*)q zM0u!xFO#h$>9C!?&TGC1&WsIxc;Fd$Jx=@2u+-a{LIPv!+?U1Dap)VMwGCrNk(czb zvp_-}$WYesqbhlhk}7F(+%JWfYqE5ewVbMPr^7Ja4R<3$&OEID$k=1(5py;@EWVVB zK$alopu*`TN8Hwfalbfzj`;Q3?!tj8&Yn?c1JB-=cR4-Z7da|2Z*(vSTC?_sN+Hke zG@i?PlDPMz7cPb$TWp=SgXV$~t3SQD^Sz(wFu2~nl{jrXT)iUdX5@KWm4iMIz<&rvUHv?Y}?Q@g)n_ek%)E@qSPm=!5qcUFx$mGai9q6kDggNLSKn94ifs38UkR5cMKS=4bwZ2S@4+mwVnFw; zBxMikkm-#mm3sT>x{ED!ZqIhaHvZ%Ux~o4&DA;DYDV%3$l;y#5^z1F;lO(4{T|9*b zWjHLmn+Kd~MZ{nO5Jmoy`?eat?-pKH+z!rSWA8lE&4hWIu@7|cf)C)nruAa?cg`0N zEnmP6A9dZ`Fq6fl@Q@wL2$V20P{V!GK`%3@Ff3Amf$-pRTMvejhcfC z<2zm{SIcO#fAfs~;SWWyy%+d<>tN_YH&9r7{n<|Y1e_}7mYJ6k{~ul79nbap{_kj@ zi6Se?%6KDt)9^M!Hd!TvGDG%?2oV{D$X+3`GwMwdGLtfsWbd8zyY9En`JB)9`~CUr zoF3=o^}1j8ecjjfyq?eJB~N*ygL$BFl|URUMuPkl~b%*5V6i*PVJLulRGerptYNqRFhmc7T4!k3i>66TGg1 zgM%GvW|Q0|*+J-X0(R{uUJ8P@(1P;k#RRa`O4!MUQVqL&xu>i7-K^pL`?FH@7(F>n zLD^;2gJ%-HR=by7O?c(*M0T$T0kwZ3E0BiZbtojOu*@NL(A-e?sKMNSr1dFFd|voz z*bWc$J7v-rw?gdk1t1A$ZAJhdcEIszNK48B893^vK!?)>>K$&4>Wfp4?=&0)==e@w z=_7!-nMMJQbdWeyrS?*nS20c?d(w#~utBWj;$H-JeK6!7<^$6g>z;6`t}hK&Ku*R1 zdaq)-Q3tXiT%IW?ntzB?)-2!-^JK+*16zOweggk24blo?l#mKFr9~ow6>QnL2inbiYF5 z*YKJ&IJQ@7Ol-$mg}SP zP>F87#7RR+g8B^Dla{U5dm6A{K6};E zS@oOk*YbSBRQ>o3DNC39%fUsDR*BEnUyslT-p-qMBR>RD5B4b)5u9ejBGgtY$qrzE(`f_>8E~b` zJS4|?M*Mx@z+(G7sm$Ai>VQ3AQ~WxDN_7mQ-vwbh%Qx9>amJx&H9eE*-xc5nZ>e3L z?&GN9SgLplU^7N01E;h*qh%KI(}#vvn4%ycp^ir@%ULKZPhk5eah=l|e{sTA&ai8y8YRjs;PWeuK5AlE?iRvh zq2ICI*L#pa{H^shnQknOmlXA+2?&<|(qk53SLZGyCoW5T_IjkaSwPk_!l;lzttqP^ zeS**Yqj}#5ksk9hpUAI$ZmDjU74sj*!vjBVRPN}E&g#jk5U_F{*FK>JJTbUHf6zvA zXa(;MvgVV+1#uZ5ugsf@q{ps2oKlhO(Y_nZ*!17fIPN6n!6=3VgVXhF2G9LjPK-LN z4`EX5u!p~kL!Nu$T0~88Dw@!2MU7b}ASHQaXaIKp$=AuaEdlY~2J&*4BtG{gNQBo; zu1CJAz<>sYS`)Q=M%(wrkk;H78?CtVlM0Y zmY8v%Q-<00qBN1F-maL!#SdwpQ#XrPe)hT(@;BC* zTxx3z5J6%;YRD*bnYBcIDs4Cp`?u@X^d+nQd8>|^;cRJzwCbh6w%mFG@Jh#XL(r?(z3X+*$W+y;?)X^qOQrXsi@P3w6$Pyz4i*s7w742#OE1HJ&|(H9i2AdS-n-j9 z3;5g^{V#U>H&(oY-6R(1(%cr=X+$=~0U41maR zyG3Djg)Bd>NGbx+Lh)qb-NQ71@ZguA1+~?2(yUL2Ou$cN+T|nx!KHnTj5mlw2%8R9 z77!Ael1=d{QFbKeQgU4rNrdc}6cbexLbj`czn0(3W(!Yyky?TV_S!i#Fv z&+NFXsgK>n`2FP6PE7RT4zx@3pJdLTxifvT5g6vMIi5Z%>s!fm`}C$G>m$GAkI;?f zZM~)1rAw5B2jQb2;%d8)=1=m>&5S^*P1{FwOt+I2 zXl8f@H2DJymg7Ju#D9F25R2Hlyg=aMNH@HOzu(XI+&{m{8-A5aIZl4qp_>hk0`BH8 zmw{T^t(L&z>*1c9!Hs`!d|dh?JQsev54#om zmtMVFOC%P0ftlA-z<~(DB*BHp9ZVsOz5*!ZtwR~o-u$A1Hv8%`j>`0iP`x6x0ikU+ z^8qy^5VCvb45u9b8D9bw_V!~lGEoXnZ2@IOJ0u_m>4^W?ro%LWSVCi{1|)$Hi~8s6 zEAC6blp9{ZJz*Xt4}CLUPH4Bm&j@_W0lRdv{F9xpmd@u~Enum>mO?lR-_I!bYRF^> zgZwAA@XQ;BrBe^M&loGnKnoB5A$1`(e0dR#{rU7PV%gvsqe3H~Z{vPtWMA$}aRg`H}g1`h}u>KB0A__Kyu z%&4_N7XQxAXAu`I>F(UX{{cRjq@ISNl$9|7-L$}>sQ>M8Z$ zEX~fsW++sw!x{d)A3)d8xR5Cc$HkxUF%}mJg?FfkUV>lLi;u68Fyl{}FbceYPH1f#M!%7#rYXUMQO<5F zi&U2h%t{4EW(GNM+kW_xw3lY08-UmZOUM;cwf-X{r-9TSMXG7QMb2O5ba{2*JLfxy zOTF?jm}kXC-u@AwRm@pED<;;v5K2?Qc({Q7F)rv?QINAud;Bs65Rx30FSkmZHd}gCio4VTDCHLb*TcR0fC}#1nZ{UYSNf6z=L$m7X!M z1vp`%dTsEoL~atYW!~M%tkDxTZ=q^<{cyg>_4J}ZmCfLT3)eF+I4#Yn<{$6mj))Nw zT$rC{4VfQvT$>#@`w;CLuSKv)Xa`H#@lAf6@$qa<$KSn%NDK&lEx~@-44r8?2qimT z1R$e_+X=sf;fP-4y3V#BMPW)dloBQDwmLmQ7gP=NgGUHP+(<$X&yer?y=NHyuG=m!^H%nTpYm)r{OipJYUI-zDU*4DlYAt85oDvX}#DO zf<91brY(t(D{H{UJ}l% z(GANqFi8iqASQI$9z!urWe;+}*u!Lp_t%hH=$QgO76p9TZZXlCAVwri^4M6+j_CrU zu7=lE{Ul6PTvf{2u=Q}8pev}^Ue(73U&2S)l)lU;80sE|B3y<&MANe0m!qm1D4VLu zVh8CG^GDyZ3rNpa*@-&9?gZjGI+D!APb78d72~B&NX84&nt9H;{{AuEhPz!!4d>h_ z_!N(dZOsO&;GM6e=mua0>aJ}4`}Y+FL0e}Z7X>{_A+EDFEpA38vLkvrek`F++5cn& zZAn5h@Gx7KW3Xr7gZj+bxs%!Oh&ud~&$*gRNWq~Qgt~4M?giRWFOve2e$vy5gIKZ``Jm9Gg8FQV7f7~@KR$p5QLe8x?q`qzL(eQ5VTnZ*-FYI6Ox zyQ_3!+*KLz5dC%l%0fOoKoseCaFBqs6Fwf3?vblTp+&daW?eoA6S9P~JmtO5XWIDU z+NlN|HUDh2)CX32U_z4)g@b4!#fS>ku=R*_G2v@7vKrTXgF&!$qHBYNAQqVFDTv*n z@<-q3!;On+5K*F}Jl)FFgWkvlWBA~JC?-mX!uTDf$$;&-+Fh~4hX&{c81?)al4>?Y z1ZRKKv#3U>J4uL0s2aX6MjOi}gM0Ui$b|s3yX5WQWYDWAG6KAfgdYD3IxZx`9Ax8B z+zBS~eT+rYg8>5V{J@cLHt=?#jHPV?W!gO)u`QEr?zg90A;lIQmr5jdgyGRigXaZf z1LA=Dspu`RBshhH3gA=R$y2+)OahgtfOM(F2h)H^^w-=fU!kx}1Au_`a zLPxtdJmyDPV>xwHM9#NXEq5wq%rvjg^ab4S07L8712rD2-)`Bg1rQBkm#B`;g}k}A z6Km}C&3V%S1hEZ~-BpWCoKkbmU?9P>V&k_h6m+bZVg^n4zF%aCeTCh`*WFILQETR z%tMdGkLNfPWKgn$8a0Jv@Qt(#6SgQ000Z(RDfTHe`Ep}oq4j&?8{NJA>Qub?lpQw0 zwXCp6K39DU3PA^x3TVxnVik=CT;`38zu}X(0EiKyRTLxyHRsA#BUxL&G(})D9XSOkWgzMORPBi&)IvJAEOOzyVVDX>Wf>pOjKJD$A zUsQvRt%Snfr*?+2{!x3=Y0Rzf$6q^Ax5}k(gnHM1sVMsprJapqk7O<4i|MEzfES|jlQKP&d zvJQ65e?8~|!19m>`Ge96=MV1T)rgS%n1Y$njZIvO%K3 zQ{J$PhiAG6U+(TDov@j|ATPLD(DqB5Iie>6xjYr*o;OMS;mSV0E|_sA4;M+Sb8HJT zIHqey7U}_MWCugfoSWI7fUePBQ-C{)c1Qt8U@zVkR?jDVj|6jqaN@52*T#5`F?qmT z^nZz}arBypvx6}-n^MHwA3M_b|wfOmq5}^ zFaj*SbMr3_;(9b^1E^?eISmWI$&ibLB3sB3q^rvcG8Z*pWd$F&U}^_-BwSegxb%wr zYc_t-D!6?Er2=#4D+x5OO`cKg`>NPhnQA1nf`3v;0#Kji>I?m0rvOQz(X!u~nMIW> zUCbFtf8n5@9#edkIwB5~rqFM=M!)-ME)jceTtMfOnFacYPM&*ecU$|qmPU6-vlw{Y^!4&E!l zPD~O>Z?vr1V~n`ed>1J=RlDOHGb-M+9?CqrvbXI?>sx@R{p`G_)KO~5hpY`B6h{tW zGUN^wVN9@Z;Oy@eoll{ZIYfyWclmG+^ubj80c7-!>fB6VDJNBCOcKS0(>)tv(MAjC z3d|1n;S6!H!s4urq2=%Y&X|DU)nmwFE3JD36v!<+>j05r`4Z}Gh-r*c+shiKB9O&)F^FehSx1l`fnGD$OKzO^ zx-{J6yjttP4kCFE)b5W@3J6ews2I!3pRSzr5D_+yX`V%yH+jO}EA`K1B<6C0OM*;Z z&wl|+^{wRgDbR&03U0&l(6vo6>821vLXF#0gDRtYaeZh#9bKXPa7K=w7IoO)bBe%6 z6>ODbA^Xet-P$xv$&PB#D+jjHjcdK4BwQV0NBZHfdUXejFJx5rz^gPl= zb@rI9dyR;8GdH{gV+b1X0Ho5i+busoK2a0hPjLMy9A0&(iVSPYK(9$2jlJT}S)O*~ zS?8!oESG)^{xah50OY|CAw-Yf)X1@LRAXT%Q5r9&RAzv?w}vvhX=y~<_d%C*a5;2y zoildj+ux84FJvzW#Uv$3sCA8Got>5k9`;+m)F!GW4NPfP+!#54Cz^*JZOw2jrRqL|t3rLvK4 zjb{HtU081pEn^)g-q?{40n#IS{$yHIWO+I_%a;DuocCnwrc9u=ijg zvk_*_Y5bAvf-02qRB*q}DLn{M#fI9jLZa_E~}=@0YzkD3swgR(y@m^6JP>l z7w9+XGy;wTbfAQ!XnKqFU(~)CHOaj$+f9jM)z9}>e)`?hTK4Zg2)>>rt|~yIr95$a3Acg zch>JKMzLK4Yy1eWtL@e9_aEb!Bj+Tb7?mKk^fZJ zjm_V6>fjA{3anRqZHkE0%*w|4=chH7>}sUmJer-#@*#&*ndCA`ceEM5qp~6H11$g+ zR53CVTI*~exV73`?`hp>hasLm!Hl6iMZ`(sdNyQe@^NJ=54`e=Q!#2V*^q;iI=LWeq31dl^!(KO;!0FjgKAvL3~ zCBw!~0c~jI4devWr)r_c*t-Kv9odX)JUUkoTw6ZBkEw`Qhx~>@z2PX;v9{xg!Q%7n zO1zM=E=VhZ%-cU81D!V1C+=O5hK>d>W`t=K~My%NIzldV6~$Us5@%8EiSEXLMEc!ZAOx2SVvm zxgNl}5mdkOiGy*;H?$%^?3b8tyw-miZndaC57CR#S{i41*_j7zDeSa0sZdKC%Dnvf z-_WZx1V}qbM5`f}JZmR@Y7}d9lFK;~I~`1~C$mu{WMJms#&!HB+L+qsp>|h{WPUg= zhNd}aQv5|gMF`SA!bV1UHn?H;b`>J{L3aoTJgZg2(I3ted_)Ow5%7FdhC1Dev`f0MjRTHDuv&*g^FIrIa;P1-q12STb9=19XvS)lq0c5Z5?tAgBRVy6E^*~7etC0ppy-?v18<0I>u;HLYFlK9 z!~i0Qawx_IfnRt*_dKZ}Cz^9Q7+<_Mu=A-%7 zLn;ss{=LS-uLL3iMxDSF06}hL9PFFyi{oK+Ni8yz<7E{N2sc8u=aU0lDG7=b5qj?} z`{BZxyJKNDgXckKNgfP;Ujd>PH)%b}$V~e5tMmX90uOi7KFteIaOy-~b7oM$)1FQs z*(72FlFq{Mghr4IvAw;0*sc{Tpm;0v#^jB9IN6+1&62F{yuVcGC!ASpr$ zzJ;T9oN(e$(^#_e@7`>TM1|EZB)yO4RcmnRS(kSemItr6mt_CvHRW9Zo$-TzmvqCY zerKrc1apC~nU$~M&<+GyIy$NZ5I5`q*!5!4es4*YnMMxg8&}k-7d(WpqoMwkKl}^UaeR4s{@cnTH>DzZc+Sm@r~K` zpBh{~rioLaoS;}DgNim3uw)V|;EXCRjc^l}@qR}AUPNB;Ut-F75eR!x#ym1pUcrpB z7JJ!5hq&N$@<|X zFMA?s9I38YA5_f*d%sQ)wOD{)4=8uo;SD+1+Jv!sN!yohO8<`h-KMKFBiEAr!BF;U zcbRzS;flj*MW!EmL*XbKccK-#FZ=0GH3yStL$zwPTwA8KF@s83SFU^O1 z=OQ?RcC&f!Y(+sME$9;fWaEHU+0jwFK#W)T_`fhlRk|FfeklRymj63tzh1pZaS`$E zHGN0!Wn`=$mJZ*KV2A%tJem}HPc3wZ>WUkIPaAA?I{1!Bc9B?h$+Ls8*G&=1-bV$E zR`P2T=dbfgE*TsfUei}IYSH%}wCW3gz|*S&a1e(?1KMNbfh>`Km^$*^mt~SRdJ*X{ z&XMj4{!D53MLSX_YFm#bUTz5JaRiOh{_}9F#mw%b87gVHY zL)jn+g;3!yx~DvY2fakGbN{NQmlxC=%R}DfP)g4~0{B{CNWTO*D*}@aG&cD4Moh2` zr5c5e;OOI4lO?nl?%sP&0lTij<{u}lZ-QD=CG9Ipi|Sgm?c>K+I|=0xaDOA|*j;DRkPSaZJ5`hznmA)$&dEMjQe7^S_ecfzERE3*f2hp{&c|5qBCDb z9jLaKJ->82lJUxVe}U9_ay&sLo=iXPDrl;-mX`I5Lq5z{a3&bFet&c{YpmBQXXBvt zOP-HsQYJ)CI6J?nS!mX7t)nh1VOI|IR!Iz7o4J4g#KhZ*(~fgVwWr+~;I<*TyZ#|} z@vuK4aY~_5lE9xs^zN^dO^m4YgCl&Q{;fB2PmWfIk^77wa^P52dWBb^-2v$}^E}Bp zz5NRik^LAXQ{+KL0O}hj@K-x9>?#9a9Vg5bJDX7?IGAx#EiqR$AQ+IM`Wh&QAht=u z7jDHq{vJ#NV6qhW`51?>*CoN?sDZOg*7bLO+zmHDXymxLYTXbu{AnZpoaJEb{eYr? zC_%plk5AX$-Mn~6M|5MjvuC+2b$-Ve++xlHcW$2O8XQ5-CfGBll@6E}f-T~3;g=^c zfhwz&3m=lqQHD(rWMX2%=DPT^we2l<1N3GbLNs4Uxeyr{D1vQiq2y&!Pc%|3b>e{0 zG{1%F$cQJblCoVRHu-%HIdWx(+?0@|<9i!$@6gUO_!)!{GBC!cEzHw#Id7*=SpmvC|z~b!FqJOzHHq@iZg{4r!Q}W632KF3XVIx&d9Gr*+Ys zz{4%4Iq*svuIH3fMlKhfY}7UY8}fOE`o}*0k+0R%)L?-D@Psm8jjrthD47Wv_9*Ut zPcA-*>3IFkK^-Ak4bmV})BwtNmq+n)gl-5>*I0-ns|Ce4%(Ho#NFhVkNj{fDkK#(S zU5*x^V;Az=luYd~g8e%IMVNTX=&@oF!T$G}OG1*@)9bVnxkhlfwPt13tyC*95f_)jT zaY=Nk-G@ZUzvm3<57vBDpakONp=dTW9Q$Rv=m?kw(7k}M=*oDMgH;96+62Unuv>LR zHqub(&8ZEolo04W((yDwIrGjKhS<*1k!IzO84Q#iUq+h!+gyVSmMWTy_KjG) z&0H~9)gN3_@_uk{>MR+6tpq#Y)x92mpTF5(Y6|Y3L62}_Edv`{PmigAj(644UZajBx8p&;bh%8*x+#?yUYB=uKP6a zu)XczJrkL#S=gGScRQW7$bTEhE#qz<3VHav;`FEKx237*76Wb#5<{wAsx||UBx(tl z^8@AfbqkpF@zn6FJKu_K@PypivFAOyRa~PO`8|iGf!j-ASK`pNVy^jHT;u0*dTwhM z^O-3*;aPtj!Ds)g1A+4gBqFy){>(5LC3d|Zr5Was0~1_70Q{(R^Tao8mn9*SpfYbw zV2@U!mog_u2&sxuC0EsiwG+I!JTxU&l1*)L#2cR#uAGZ?mE(StTKp_Axq;q26!E^E ze4$H$I;`@_tyr`K9G@*t?z%p|$Nq_gV^}4fly^^J4;LtC zvw${sCMRs4j!wmI9hVYw8Uc&PuYY&^?Rmr;D&p`n=v4rT zN!xtI&%1|U4N&Dn;6niqOc%i%BC+xFNsi~%PQMf0n|Gq95xDs5Qy~=-g(%Qmq>jMg zQWPpPlcG(@5 zr|nKkay)HrSl#@mx%TM3F$hjXNS_5=ypA9;rYz{A2Ytc`@g#}ch~J(1!qq?xv6P?% zF^u0Hnr0256L$`c6S7uB-CU4+QXw;|iOeLgl--@fDkQfW8Gq1z*Jh10Bg-2f=7P3t zS9zbjVj`nr9P7=PJEdR=0kR_U(c4#xOn=Pf>W<{|~R z!oG7{Cye3|6=@SUquc|ucb3oZ4};CnLsLD|PZ8A1J=={mu@a6qVwnTk!l9Uf$AlV@ z3`Koi-vLuH@qKzS9{c*6WQs8A8L~`@In*zC7TpA`_)QP4Bvr&*TaK4uS7((!LkGlt zFivJyR6Cn8_LBk`v;HCb)rkD&pV?eJM??80$}K!mHPyU;;+ST@HFVplWhau-PoW$Q zsSnzYqZb%geq=tNIiwo(&n-uTw=jv_Yj|9)E`ih)(r;g|?t|MOeq{Eeum@-evt5Gt zy)c9iMPAnuD}-(E?&$6hZBgB%$&Z;?Kndgchy5l?a5@(L2~DK=}cKk|?`M zu-u7w`?KPdQ3`@LOP^~L-k^=Eu1hR94Ufyb+zA47x2=zv2KmDOxM#^C_pH5VKBf0_ z>(LW|8*)UG6{T^Cj`Kj;Xi?~IOKKnR-f@0;7TMk6pA6YL5xgQ42Rjm?yU-SxNRKs2 zfsoTuF>|{oA5Trl-1rf~GY(DkO!$)kGm7x5kYm_sq#2m7klIM{-)gXowG3@!tTvHO ziBco6j`|x-t^e(Kn(t8x{i(7S)QH8Dl7p^Z6ZH%)x6n?>?U(a$T`#*x-E2S543{*u zn6^a+c4oB6w3&aB)k|_WvO0UNGGi0?IbUEOc)SY1@Qp3`az|g>g4WsHPVYK>$BE-Q zXPKU&osE9D_qMp9Sp*zH@a6j$|Fy6rEDp+>ql-tn}JQ*;r<7pk>d;=;Gn8 z9@o|e`hCor!<8oH14JsF)+$wJjB7;SL_e_h(sz;A`ZzJw04^_3Gwl)2*8^3NVIKOC zIfFkH*X8s~`$}Pz)2Ebgh=Ku_w{LQP1MrRI!}r~W8>at%GtBF&ezs@?MKUYr^xFL6 zSh9r9=9!ud*P(~Y^%|B2`Qp!Bn`e5BoFXaE(%!@3$fdrf7d58X%Bm%;l=!Fksp{RX zDk4^56!-f5RMy5b*6rEXV+}nXcGn5FUFYsKe3`hMgDHC756IY}LG^a-&7-RXd;9;w zesr|qDvZJc_deH3JoORSxB>$Z=;O>0iG77PiMBUw$G5IA8=yA;4hdUdyGuM`_t*C> zQ@?(!E}8E&*Gw(qet+I62tD_X>9m}&dT*$ne@~)Nbwg*PUJ`MEo;l!=fFZ(jV<|UA z2Iq>m-nvYd)r4KuooA6Xg^O96E>31!=xLc3KcBNZZRKi@ZCKuN*fVPj5CcS&&pJ;2 z=5<@1yWBuCjqWY@@gYW=A$NcZ=ThOVR)E(wVE~{y8PRRb5^L>7fqQ84^#nS>g&u{pS*mM;$FsLV1Fn3mJ9$Pn@uxIVeGj3&=*n6vk(!stt_SS+lvIGzJK2l z5n#oJ>><*uF$U%5fzYJt*8c=UPD?wrT5(}U<`LP7K_j>4iJdk-2^h_>Gn;M$JnZ7MuTo<^LssyEhI;{8w%(?YYi}|n?~SS z?m{*NF98omU+%*G&Pu1-Rvjq;N`5PcT?sC@To=}p;S7-QcxuNFaF+#Qd|C>gk5Ny% zLRZZ2mg~Z5j=m4zs&4PI^xHpM4g~(172`}00>rMqS2JsnZRfSn#a|j3D!GFKa%V88 z=nSGpZf0PeEWq89gf}!JFcd!v9*rbUdj6ai9A&nDSIMuxz~Ax)v$)jFv^@C5JjYeE zpAeaI3EVXv@-b__-kBJuMc65I9rgL! z)gdk>kUka)6-4LJ?drwPi@v<_dT_*YF#j&@EAaH_o=Ne*ZFAyhc)5; z0MH3Y;;tX~lkVhTaCqjvhNkQF8)-K2UMFAguyOU@>qy1miWE}w=hP;bpJ}E9J=bps z8KiD}|AAnVS{*0qO{q`e;ye=P;ByoIus!(c8=L2?4kPbG_hFVv(-`+8DdL<|pA>DbF6s+XL|dz^)wfuef5 z?!H6s99ySz+c%4;o};iR0JyI$^h62Wq>Pc2BXZ*PFFRG93TqbrD7z5R~ynV8H(auS;Oj{<3fm?)2UovHkF%PqZ~2%C@dY z56O4mkr_jk(Z#J#FWbD5K_*mb;0?a_DN?Cp|9R5LMxY=ndj4Kc22!Is*jJ9P4ON#B zMhrJS-`aQ7weR3Dx^cc;*pn>2sbDPX2qs}&bEy~qm9`RmKLI|DM6nef8H1s;!rb3y zFg2UA;%=kSItllAkJ1Pb%X8A%e0u`pWxE$5!@v3bwa)_d{OarBi^~&+o7g#k^)0&T zd~Qv%p@QOymZ6O8mwZFoVoHH4(DeX{aVAJNuTC8^%+}iEUhaA??Jr!Lr;^sp2xhSA zAok^a2b+(Ck_Vm=b;NKOb8PSjXj;h;od}ao$lD;S@R+^O_x|z+266 zhUpuZVv+Tf=YtZGl{LcJ4~bsCEe9x0^p;xF&MRKH1|uvYM7>99<}>cS0GOK>j7ZWh zndOYfpys`-Q{QsxQzyB)3=Z_^_qx3#+BiMRh^2Y5S*@oereu*EKJ(#T(ET1Z5JFbO zm8N7n;-vtwtC0EE<%P_FYt>V~age-gf+EHZoT;97A@&rja-2+=jcY93GO7b0r!L<7 zX?4x}P7U)Fs*g|GuBjuuHxghE|7Bz2KyA8&pAIcm|KxX@s^bbSf8O8cFj&Z*8r)h( z_Np2LAinBofU)43T+*S%(Zj*a)O@BFXlbsKKpqZcP{~p`lR}U|nqAC8hHK38zL+00 z`2x7;OQa3?;n%I2(rNqb|V*--_f{^5iG1!g*1}Wr=WX$8~zo7GUXw`>S>+jK6 zV}sdk4C)@^%*_ahXa>q%7z|jVnOqQdKog* z;C$e^Nb%xjo1`iTRWphw_&XjyrNH^-T!pWRY9L*b7)&I0U02S^zsw6*I?W;gzOE#u zGhq1%!yhZ>_-(!w1sHj6JB@8Ym)@Dv8(0#FplN5)qEO3!4XE~Vw~ldkP6X9^?xIFW z0zcoi7EFHm{q{dO{5uH7Rb$I5H<#zUSF?kT>{Y2}2q6^%Tz)x%&rGq~o`&E2ioB*J z!CM=sFYpL3VCfP3P8+vyf8|_-bAW$hUkqU=IfPv(N3wu46(#NMHO4p9ohz5*x%w=B zm7fn#=?u68MPqrPMU_1hD*ofgk1UWSYnI#H7?YrrA_6<-<&8NL@kYiq`oZMFKK;Wi z2fP>Gy?eMkR^14&@t;ycuz-4FVbRNN8M5m0Va9d+LUBx1+5h$E;OyML8||If0drd2 zW=eX-3mfR~=7;E4XyzL{gJ>r&y}R{#cEzFZk&cn;`w9&A#>9l++ijE4o?`_|U^?CE zTO3tn!dL$lx*WYsKK&3t9?=t_6Xu+woU(bjcLvALY>qT3v-rXaQMnrBMh- zcw&u|z>GO+=ey97r_KJqHwl!7djtaVmyj*VZh(9!eUpif$xr6|c@4^84*LG45||3z zI$CJ4&>_dr0WSG7PZb$ZEQJUq7W*1T>l8G9pG< zk;EYciWloObPafW$>rw6)_b^%3{IRlkt!AjF06hSa^`h|#sb3{%Z#zDyF8{FTX|rl zXAuo_$NjJwNkNdNyQdnHe+4y{KR+naVD4`w2$Cg1LRhu+leEceHQ4hsd)jM zn4Fkg;Vq&!WlNS1FpJ=%_Qq4vzL<}A=!!fTHdw_0kQgWRy~e7WW8y3O`;(G4vM<|~ z5L06Dcfr(m4#s}{dqC;F0llP0nPg0|<;4Zq z**V{+QWPYO`E7o3#n4!ZMFFnEf2t_ZaN=L`?E&zezZrm02iJ_)U%?q{UHA(f&*+$isuP;4JP zi4iI94Rol$+%f=L?U}+64(8sGw-gs7tmUB9WgtTxBLl_xk0N@hV~HG(5$^1fUoxlV z&FA^saiI1pAG%&tMLsfqa~Mr;*q>DAtMUQ!?T-F4Ak~!Y(VO-I&BRA-PyaRQemps} zwx|)llOczT>m8v|HWc3vECyDzk@XXdKB0mLQN$HR(1DUn};{G4Jx!$yN}(Q54T^Z53RbA zPx_-x(jWLF8nA)0q2-%njpm*s$4>e=;Htq({R^yZk>j*1wjDCE!t&j~?JsyHelOpW zOShQ(pc3co!PlQ2V>>gHm0+6vdCZ4b5@)Smryfya>R(!&?9Q!>;mP4Jk*RMljTZIV zDz~eKc&~Aq$VM_Xvp+{1H8a0M9JlN%%`+CR-$GODAhC&q&kErO64^6BV<_xD+z(EY zZ@=W~gdaM|_C29Y_7P3-jk<%i@5A5*5{>;xA-hhbv7(eY=Tpf^QYp@V2&Ep;mI1B)!k7H_v>rDZ@w9uZx=T$^d{rM3geHpCOnBzx z8OGu~vluRYR)b9K_3@s9q*+>OOsGX}B-(xE#mP5BgKS?N*j$PXD266fd3~4rS6~?N z@*uM1Zk2|2A`w$XNL2E-K2;JtLFk{v4aB%_*@n48JYb^m=L}_9t(9=7O*#Q)I#TkA ztnl8P5$X`%tpFtepK`128)q)mxUaa=nx&S?98rH7)meXI2*_G+>7fk3g~>N$%??nq zxqO3!^Qv`-&|vR?1*aEPgb0;2CykVI%0?=HTwUNfOTW{~(Bxl>qlxX>OjVv^IwAxo&$@9bobvms9g+pu_`VpQStdT^6#Q4 z7cXB=fIr&VT2jrPu|b8!lAU>5qB3~yIA0kR@AoT*l%REV(KBIhuFagCQzV3&+|p+U zVnJ8#Ql+{dhJDG4G%-S>IQ4Tz0#lCc;eRJ)J1|a1bOc&3d;-Luyd}KBx|936OTfm%uU4jJ~n{)UvH4ky1br zs4d5uN(DT+T87li-qIv`EGl$QO>Exn34j_5l>RcHZa^`Vd>ZB>#-z}zmhYqmdrOv{ z7J7omkzzEeObiucB!rppK6;z~^ODpV;>gFa zGDT)>$=4d7gHAV8?kny88dGH1#p(uJjOSf<9~?W)u|3{!q_~g&_7_YoOBCcIX?2h9X7(E-BO9085cqYFp-k zepC4q*?uj7kI%)+^gq_e0vLuE<}TX6fxHw5nX*E`6f~gaYb6fFE|Ef_yXQv;fIRb8 zHJG9hH@7R7oL$0M?u#Y|54E4{NWX59TYj#h{#7Ap z&u}p?Bwq6Fub$Pd^!$UaAlQ?aAr~f-$_%=WoxbcB#OEB1H7%ZsdJV%BVz(8a@hzB_ zQ-cws^EYYevlYf;2oX5UeDbESp6O&&x`6!$54TK1+nk;%+DA_duZdix;DC!{zjy>| z2(p_4K9f9ZSfWy`&u}R9n|S}6ZCXIM>h|jWtQRBe0>4J~EwJ8!nw1I}uc>z{?#hBFlOJUBs z;+kt=2xjqP$G6Yk_pd#k7tX?%sF*n4OP=yJbcD9gcLoOM|4e256T|mP84wW-d9Iy? z*%S&T2C0Q_FgU#4OB~ycZHdB`doO{8-q1!JMrB?RX=L;W+c<8SFT@BsBT-bjr7X%n z_CczX|JihdlT+R%4jxubnySR%;X(x`Ij{viHj35{u4F>VZaG{t>``g3WSMMp~Oh6ut# zfx|vDR?2Ij7SmVP*nGPq_|#^e3cd!22vT{bnuw1OYh2ngDM&cr+qqmk<tU(j-Cr#^W3EFR4}5`cvk+=%MmAg4Qu_*=(C#;A z%K~~ys$KHe7g&9~jXE%I`p2@1sr?&1=&QeH=(@(`-jQ%RGoTs!`JsTGj5yc1%DOX} zr|i>AM}KoSWO{A-W^t2e!1=*Z8=E56?NBOSKhRsxCdAe?|6tHHsCMg~mwe^A^eg*X z|NNq7UJFi|k3+xorBxr~aZKCeb&Z~!ZyFgjepG6-m2^X62hB`WDgV*@Y{=XHPqu>W zOg>2>3?Y;txaZP`&CTV_AP}b>3&idF*dSZM^{h(D4>@)m!ueQQ;962zouu=Ot zQB6*aXCIdiRNQUg#qOzq5atZ;*`NdFk1+nTvpl@^QZyM4ykMGaS3bDCZVxafaD-ej zxa)%mme8zPcTc$vomlIoaGWJJyW?fN5|%-el4NMV5+ytBZX(zJ?SB8a8yq&=h3tw) zG?ps=HxQ1FM1Mb<>Akz-xV3J7dF{FjfxccI%(G^ZI$aBep3>U=s62;t+f~>dG&}a; zhNclWPt^(-IshJXw`j6&!g15cq9pZCql5rjrY2ILA{4axeD&dR@Eh$~k-G<;Cuo=; z)Vg9xJ?<@`Z@o1w`v$0@CE?iN9`6nL_+O#l>=o-uau?|IFM%KXleR=4r@Q!ZFNqS+`Z>PPvRg8BNcU({QGf+PK z8~!%}W0kp9iNYPf(W28}?8MOfeTK!flNcK7di19p0vQV?1iNAD*x&H(7 z-BI;e?+HWM$`+I?y%OBnebWQboJsYIxhD68ERu)MLTvYUhpBvz8~2jwNUUjjT$A`* zRcfSw2*w8ndbUmtr30NeCheGO4dl69wzb>8;37Szniv&UT_i@byDbGmyh_i_(wiJN zfpoSIh#d+|shPT+z5lT0|50QT?Dx4EfrUO}Hq8;@Q+oJeiQ7LSQ+l127$kZ}%madw z^RB1pC_<;`fo8rICp{#e^w6TrStA&FiA)yeCF_LiiVH4tKbe}{HUqr1UifVu5-B6n zFmzp|Ah3dycHd>{#E-uiLytVGeMULGttFzG2IOxSXsVeoqq>Mujq=9o+>k0TIW<1g z(Cu`*ht-ulM2m@vg4|P76{e++_X^J4-{b=&SM+MS%fgQwr*K-~Fq?rzZV5C(=AT*+ zlnS83hKUD|S*(b$4ZsBYSZA7K=8*XCW~R0Q^O=rlhYIkb8Ol6Ajat z->#jQ=kMY^*Ybqy)$#@VRZTvjZqITM?hUX3`uxU1ir|C zSy021^bqhJBZOQ|pJ2W60svFuxx6dS#i9|eoZto6L?z2{8!b@Z(oHmoGૂ<>7 zjrUZ;m5I`EpCI4O2kFZ3UpCTxq>FvM=Y4Phjf_iV^uI=yO9T=3L2Ie~`1_m<6on_b zpVW@`WT{`G-8?yiSPRWxkx8Znmw_cL>A3lf; zp*iZn7g|E|%B;S+Qep4Wk$-lKK3KuCCtE;O2`rKsS(x+SC9~Z27VFiQvp0dc<_~73 zXJp|rqhbMq(}TJDkhx=7yc;J5tHlj#y#{#A7`H@*aw9VJVRGGUtr?J(Qgak>mN35d zvqN9frX6fTSZB77dG`l>{Pr>wfS$^*VFKA|PhPk)&p>Gx)ewcao<5E5?&$d+Af4ee zduIX&dFQ=wppj~Tx{w{fyPSzA5-mL$TFMLMn(W>@qvi2o;qSo9QG#B7#<}UnU#45@ z3z08y4dF`wJ9PnknI6y5xT4=2zGFY7)ITD2kz}av? z+%D+*r&+_*lGpJ6_&V#bs=BV-tEdPl3JB8Z2BcF&kWxCNyQQR4I#oha1StXOPC+^a zlvF}Ql-`7tboUtxpXWXA_g&{Y{^@0~_gZt#F~_*?-~E)e-Qwc^~#XM(lPG?}NgNa;Y!{kyL6i3a9mE^U8cImACIKLCWv)=2lZT1Q0IvL}atm z+P5#TwH5b|9u%6d*bA25nR-f6i{?XfMk5JH1LgyC$DV3R)56zI5kM(?EuNUuOK~-89s^us&0S^!W{iZa+R>%PUGq9bbddS; zZSi=?hT0tRz!AvK#)#>V1?VyT8Zr0kHiM5Ox4i$5op?rzFU|kDp=V!o$G}T9w=LtU z*;vtsyqR0D!`P#Ai6S)X-S0-T>5o-g@s2F0YJT1TyH~XL>gGazv2@SI_No`9fY8{4 z2hJfiRi;^U+q;#GIrFQ3V6@l(bFzDNgj`LNfLmW%EKVOvx{8a4xsR4%s2YI^am7gg zg$JAV2GmdRQi*cILH?yTRYY|_Y|~G>j$SNC!BT`el4mnO?6Lr%a2!NG?3J_HTAL83 zFHR(oc)q!1BE2QM0{2^Z10Z_F23headl>oB7H zziruN0(#91sh*OkyKPKSG(in3X>|o>KXCVYQuvMxExo?*VnMb>A(+0zbtq-(TKtEN zG7DZ=_ciX{&e*tPk?+G6ove*gj3Xd|q7ksY3?@;#HAS*boYCfoZTkJq^@jH1B-EvW zDsC;e$lDP;8@h~^ao}kL&t}#}+|H;prL0GIHl!*92_^40=zn^7aHWh3jLu3bp1^mW?WYEw-Y>)ESF7yHlr-4*%#pg)xBR<*S zR}&$g+UXz~pt7tfPB%*29x6st$5QJ>Sa`%j|H$??%c0o(r^y#I6z0uMA;pzQT}=5T{W zM#}58jq>N@%3C-Uy+9w!i^&a>BwC{hwu>sgXlq32Q(=GnMC*Oq2KPl_2>ZZit5Z^_ z{DY!XJ=_Y+w{^|uG>m3mZ`JxNjF9HebyJ=dQd{$4eV50USgM!M8!R>G%VhtJbaeWC6n*8{C-yIG+AA-MrNBSUJWCAMrZa{?=eGB{14M4G@>U3p@HQVQeTA6j2&Gle5dj#BR+sB*o^5Wf9xH$ z0rOFH3)7}(UY@X;2D$?N$y}`<^3%Y!un#O@4!nEURqKUxA`tGUXX`~4w1#LAhE#b6lf2v$~)hSph`h(^o9)W-D)GU=4QwxHEG4#)C` zD;gaWEx1=6RXiw-zR%1>PdJV8 zXLoEeK%sZZ6a?SxhhXV$_7L}Eld8Ug^oX*qT!3!Gh==P-o3VQFd&@c;QEvLUse+3Z zc{Et{kiNUuM$5I@zn`}!LW|Q#Dpxrf9z-cx+TJT0Ha9(B>)-ZO`Ph4v4mtv6W5ZM2 zze;s92T2=871yG*JVPO{lUEfb{A*Fk<#SrbD_n{Qq?`r+mpQzJm0@GrK(Gkpl^+pM2k67RwpOz_OfOxSo_$(lHNGmwr=Uv%1Ih@gc4AjLRAY7q6gb} zYiEQ`6m>t`?n@GQesWKpm3?Zw+`|93acZ*p_xQ_Xub2iBp(G2nVmWQK;w+)ASatbE zGrWj!LhWZ9;7c3vrrZZj^6%z*SE~hjU(Dp2Eq_zp>;gW-|7ZgT9zd9{ipR@LTqh^& zT#69!1+zP@@mwEA0AFt4e?ie?32C(ENJx%bl=&Uk`vN}jt9Gb6rYsh1>*MP1K=&no7FduXK{Ev|`w#N4(oe^N38j_NdtRG%b*XVum7 zYX^fuURu{d`t0=8(T1MU4iwuB=7&aVp3B=?9L7y*?u@#I1CqG5Q-3Iy!;La%68}e7 z2?V8Uulc-aJ-LE?_#5V|rv%VcZ+#ZX&;!X!CgZ#HeXvnE-G_-~dmYgdD#VWLt>Ent zQjp}nRmfw<|4_?xwB5!aG9W;sX&mAP5vcyrGCoz@6pRD$ewcXB(T4%CSFlc6js|NexC z{UuCAX)zB9+x3^xrGXWBl3#z1PQ|js@RYfw+dPflY*7lRxF4Db${d|?;e-22{J*_M z4a1W}U5E=TAf{BDUpE|sAtW#_n5N8CvC`AwMiG1DZQuPZl~U=l;U7|9A@OG`YNd|X zdwVXEDOWFqmT7QjD2)LhzZ)#GFT;hG@wVINOVwUjs>9Lu(8ZMLY2HykYC3~jcc%t@#y=(igeztY9(98wo=3HRJQFiuJ7P>J-yHZnxoRr zt+vf9n&wvaO4$IVb|rDp_nO~ppPXC?6rl%{x4ZGxcP&Y4aF@(vk<^V$;7YJIZOD7c z*ILS74W>DF1|F8p-?u4Q*vM8HxX9HX=68)*Geb1fJ#Q-}@7uE{R(B;HMN*d+DSFD; z9mMC7jrs^e%;VoAJ1yN`o{)JJFjV{9>c{9W+~CuwJ53i+o2v0^WFYQIscIS{BY|i# zAA?li6Ut&xQI>h?D;MeTGAg7zdwL*TUfQ16gBjodLL5(@G2|!3hv8m31Ky|FhbF51U5rJK*44!i<0Wg(M(1RSs^?E|2yu43gr)9ECFT zl+TMlUn*G~7fvr5@MZI-Q8-q(Z*kv&%%VaFh;)A!@=0RIf9_O^xtJCpd!+GqTrw#* zbawIIR%gYH&g;yXMZ_tOAkD|A@}QT%$0Q=S&zkv0o3Pm5P~SBb(pg6UVoxbL^af;J zFkvE%p&3y%#tLBm=G=Zcy6Pi~vpDOpkz`cw8-*A#jZFn?Ev(_~dOIzb`8(F-U35@@uJ@xh!PxD+@3>pKPeuhb$lq`(4wAedFB>T%~) z76iHhIw8m9SWWZMn!Cujp(Lgk{a>aV#B}gl?C@#`%EL!f0?<*zB>}xa*-kj?X&piw zBjM?v9|@RcGn|bR07#&~#04kSAJ8wJc>xTK#}|P>C_xj_-bWnEX10LJAUZ#%QvN^y>@+jTpqyVl;RyBEzKVB&eZQ z@g)YaD0ed7TOlx_e3h0Y8KtH{G%sjh-L+jK({;?jx)L7eWSC3c3$z;9j&)Bt_dg|{ zwwYw7c3ir<` zPQRL~U0Eaz0igWt%eJQ2Oo$!BaV`AL=5V2M0H=q(&%%d_+gloM=8-y_-7q)^Eod znwu&I%?DqG;Yr;@BrWFwj{Na5Kn}x!U;c{DoF>vF!m4_`&qxX%^W^Z#)uA6N>~TX^ zl4mNs!~vr5&mjVp-g9G!6%5d-wvC(AAzoBsf4EfjE%SVy>3UlPBK8xG^1;?z~X*U-5t}H z_6D~{`-f&i@mS%1+(YK9(8k*R4tQ8vw>UydkNty{-aLGRiy;jpVqSxT(or-wLwH(F z8W5CK@anU+K%no%Fr`~lw47{Tm`UN1AfkaE%WUSV3wmC(pX?Vm5vi;D%wZC$Sou7g z44;6|lkEsNCa!yGJ7Jh-&hP;==h+0<1h|0sK?oLux8QtW z^z!-f2IQxhOCt|quMq(7cmx-IZmDE;AC78c>ueN6CODB)#_jj$&>c0`iZx2*oo!Bf z1F(v-V@BUFqi!1+;v`tH&7=qAU;WnXb;EzLOysXzZaH0@$c@!hk8xl8ihFjO8AQqo zwnVlMuWg7iFy14Ux(T>T@v(a*mL$~34qP{>oj0>jX0w4`dz{nd&hs3`qxG*!1#nF# z66Ha6?N!VdBdU=nx#hQ8zrvEr9J1=9qJsBH8%~q0C%Q`87XKfSk%1V9df`n_w-P0v#);F~_{M4K| z@C9w%^*->wa}bQVy{p4WwLv7ff^dwO`}mH@Ep3%B;ZZS9N?ngQJ+|@T1$wY*HdJUHm#X^-%gOXR`h9#;mPI-D~iHMgQZz0AY%OlQ$ZP5?6Krx6-S>NXm48p!0}yi2P1 z%{XYBOba*ADfc3$*|NAQ-x)>Aw-GW^^wwFXi9TKSPIx|A21uCEY4@}bSHp#@@%wjd z0cL|~w?du61Hrcz0^)|{%+14bJLQ8%W34b}<;S`Sj9HCMR}Nb|n0I|ArI+?d@ z%rYHcq*d0j484tS3{}|0*2asv;c(sh&%R_1Xb)g0j92S}vZOwh%GnO@8C=+NWBR}@ zRO2uc3~K=92YN?m}m?w+UF5z51B!5}imoSlPIxN~2 z8QZ0FX62qx588-6Pc8Mjd6Ze^IlU~sT?`#wHi@D?X=08AyN4T7>lII3JwylC>*~2( z^_KX@?C645_ch~R_uM*ZG)v%FewB}3$c(}dbtQH48X7Do7c3qu)^}kZ5U9JnR44vw zE0OmpxeR{G1|!fSsg^7gW4jJ$*pg39WiQB6LTsF!C|Drb#)7!q7nF&3U6vDvr^4cr zp(^^U&-dBa1H?piIyu)NWQP4r99T`!&x$C)jchS-UG+jFUWiL?bA2P$Db4fmwAMm% zdh5Qd>Wi~<5T}--7>oDacApeoQ-u=!J{6RF7Ib_wV#9H~-ei1NV=L^F%*Z^}xNoc7 zLKaDb%^m5YUGhgiUs361-Jyo31wrv@6vNPf>|55aC9%NQi+y84kr)pZS> z+VS01gd&=TV{zly5mNcuu17Q>D29}#4nRDE{*N+0!}+(_&b4@PvCDQDsUm zYLBkfV}-T4@sxo;ADXW&40K7Re6DI&0>I= z$^pF+qG&tioti(#91s(xhG$8SIYsK?CC{Uf3E97ON<4~!B~P8;Tmm5yv=tdbLbWPh zCV5JGjpsU+2j~gS!Zt2BoMY~yRAV7x;DKub<&^n1OL8%R)KT1nPE^_+NH6qV3$K|# zbCnjSGfqZqB3hYR#sYYu~lg{Zw${J7=qy(dYRMRU7o*e-?A4vom z&KZWMDXOU}nx`^puzL5@){0Ei;5sk0uTU3Tofep6Lw)X{D$|)3E>Ze84s?2>J&TBhaoxs zvv6AFK7KCB#qxLS^Meu--$2nA!UL2D@!SOE5NsQG)B=Wn{tWxuaOYB$!cY7;syjbP zWjZ&0((E!$LUbyAzv!veGei7m&xB?4VAd9hKOldiz}Ov%A#)#c4`j49o56w)q8ID% zi`2HFx#47DL>+A`a5-%&%)Yh|Jk*G98}OYIs4GY$1WmSI^kMe`DoM4x*x`( z{R@Rz&4Jyie2M|ku0F2Q1L8O+&aSM)Dosup6=qvsX=&5qEv91-tEJ=}3Px(=k^ejo zXvicczdV#d?WsTRY>O^@@`{2_Tu~9rcSJ{*=I6l$1cx=QEZ>))npbPBeqM7LKTL-- zo0Oiiah(jtYH3}Lb87B8uC>yOxKB46S zE{zOH}j`v@G;`Lf=>c3n`w4q8Aw9KwSdjdoC=0Yc>?5>RXACgR@=Es{0dkLH5 zA_z=H?yBaeLqv9Y)J0v!U?W7qe^{uJdVr*@H8MXF`ZxYtY$e>(_W}%k((yRxH zh9AL8)01pM_I6zV?1NfXmo_yfF4ssZ;w9dvbg^bV(f6>#LY!+y>RGoMeh{Y@-@_Z) zn9Rk0=}Z+`BtML5eP@M?Dsr>#Xw8e%Qq~hb;z>QGy$lRC7>S-a!*oePGXu0oc|EN^ ze`-{su-*qCF>iU(>nav<7W?Rt6^Kw9D}DS~OL{)%@)U!XlxUI=6mO`RDi&p7)o8{r zzAB1KUt*Z31Y&3mdfz*wZ`))mSEQohSOy*TPa#-1;+a=*W$>gmAN6>`LA^EHW^2T_ z#Vt2JskcAjNK&hDH*PtRO%fsXf%xa_3FIt#3r-Xf^ zaJs29fNe!SViRr6tJ}tsN|&UhUb+ntFKg0oZhf#w(&M0-_g(>J0}tB_+j5<=LsgAr z@z3!3qVUR%0enK0Vu2WLqaPG_ay_7G^9B!KM~uKr3yNj$0<6%!(~j65B%nP!;C2Kn zgwZN~_~&qfyi$|H*NQPzaSwb(2K1wvdFd)OBK%V#uT>L5=U7VrmP)P?L|( z1zh>zEZrM24j-%8{HubmL#n{v*8xkcr{*%m)c{275%zpr9qJGo~R>dW4u+LG1>^!mgp3j;!(wPopebtNU9uS6Q~M@rlsVe z*_i1Ig=3JEov$CLXq0A+Fv^4?Jjcsx_3_?_uBlGejCH}F^E03kCq4$%&3WS=Dl=fN zUg~|T$Mn=rEWFYqG|Leo7)=SiG#sQNlxr)IhSrx2hrnt7rf$jC%PRBDO>Y|3kp8kk zjrb4!<-r%J?#1{=IRug0mR_0#iGUi#)7q<{hss3k9W8V3w(Y-wd&f zsEyq>~!5qx?3&u$GFF25n^Sk z?1y2sPWH6~OJ&R5&(-wo)cr*r^^U?CO9RdB+ggW<2F6S+cfI>nd?V&*nZv~uU$u$V z)9J&<`^eb2;G<1xJAkzkhrdKKc^@p@oA;P0F>0fizD4+65@?l_@j^a=nUWE%7(U>p z>$aeI-l~fe4_FVLtV>idz^Ri$IiNuH1`$URV%hN!ZaVs-W`tWl5wvwQ^p4-PT zOcSPiR@tQInUENvDT6c2|6ORvz2f(00QWo*Ln1WVloD?B=Z=S#b3Ojyp4Xoj&b3iwgiLd2TmAo}m*NQ=FqIzi0o2{i+q#jb8 zm<`pXq_y?_*Im*37ZJQy0k1=we5^*H2#G+cBp`!yKBZoM`$4IYs6NA*NhK@Tu`d37 zL7MrVA^4i0U#jbL6x9+f+^J&Ath$cg`e8NZ>oCLi8x>?HtroPsdmoZ2|K70wFoyu> zGyt*Ux&CN?tdyI~NYi}YgCf<OoIlDpG?8kTK(+D+2>eWPN$H(nV@bTRBEyeZTBwb@^ z8?|RY*lO7L=0;^*jfz7+)`{tUN8-0P!NB!o=bZ?R(w9<*e6Gn40FgA~P{D&HhxOj#}OV$znM zN#r5zm?N&UHt3nTYdo&j;B?XK$>ddX#Fbvk?yCA6CD`?JEjcV`fHjmR)g+{bGk_De zE%Q5h!_~w=IbSd1(4g#Fl{IdoHyBou;Rcysnjx3mVWK&tV*BIk%>CIA=4N}k&^al8 zcNV8|3*UckaUgR08AXwkdMBJ0=RHQ$)Z zx}~G7dRPBs%gO9hOjh$eFHqSTqjak*Iz6=UEA5-Rak);qQpK}TF~D5z8(pVMDc+{Ms6kd<9ELh2Nqu2vNTB`X!Y$B!92o6> zo)6=4RHh!_H>#{W^(j2M^&8Y%$bEjmlZrg?fzMbiI<7hT;Oa{6F2&Mcemu)@^`{@r z5$*qGRiWIAMTq)R3N=zGR7~9-?$u>)-#T5feaj#EoJhoK1+5dPtZ_(>EBVR{vD!ih zqc7DL-|cfu(vw(*CHs!uV0$qWbV#g=lSYL{7+&KIUaw5Q`tMNuLz9^^kkIslW@NPT zWzAaW!VjI1r3-d0WQJsLi(+fhfzV zXDWIW`zG<9eEdab)-9%qNP9UDnrffj!Qrb?*Il$!Je*E}y`I@-0Gohc{dmf87O z#ktNMLu6xQ9&24NmJyAkySLNwSa8_b&2*2&+@ROmP zez>y^Zn49Qk=%8~17q`#c=cIGeY&R?NsH5?%3l3?F}VEcE@tMoMCoe8&1}yV&Ue}G zxs-U1N#8aai$#B;W-FGSdi9-e+o)ybf}TxQ0b4IL++V40W$8Lmg-ER(MALe(&eomI zM7M9_BOMClg#--({VGXgn z_;pkg<4IvQUWl{nbpfp_iIc>0$|s|(NF*S%qxHfY11kS{Ul454dwgci=_Ti*T)hn& zE{pX?@jOO?&>lZrse?9ZlN?NNhRi@`CtWb{(_mNt;cl4x@mQ&s_#8~AyhJC8H~!BD zzDe31JDg>PIZ{r@R%SN%btA!FEr-3v3dW*cq@CshRk4@iuKb=%l7OQ^ieJ)}gzMfV?c2i)k|d7YQDR<&$knyb+QW_|kG zx5YaVaBO%0YuW-ZPd=7_C9cM$#~Wi&OEUuyjY#t3!10>kQ!?B@6p}V-i_kkjss6?R z-F)Gce~y}n;fVgbD!HS zWoZtGd9%{?1iVW68@zel_1T6Pxe_6CkS|HKrAtMc+_PX*cJZob0aD05Fy8olLmx<$ zSVw73r`vCp`vK+Cml;TlZ(0}GLU+Z#0h~(r#fru6e@T;h%$!~u_7=Q114tc7rIAR9 zpI65FkO54;xYFMN%;_r)>s0jI$g3>q#tWCAM$RdI@`1}G_>eAMT#YuT*L=92B3Y%lJ0Ednea z_izT=mwiXjVIsGU)^yd?Vh|y65nbK3)b86=(iQC~2|(1q#}SYZG%%Y(ZmWtRe6~Tn z?r^=U+w!4~4tWuV;K#NSXdSDRTd24w^k3EkPO+?`C*utRE22w>tBMwzu2Hc#ZnBA~vkDZBm!+)<--%R9={HsJj_(U&G z4Zi^QONS&$$-TwmM)L3gL4LV`R1r)@ovQZ=pp03a*oC&y3yV+Q@b_TL;$VMPEUNuW zm_*)pM%pu)Y!j7tt-1NRWcx&$Yi)Ud7koo$Q|*hB54~(xxtK4n^1Fvqv~;N#_2^oX z3IygEtyauM+~KwzW2T@=YI~z&uvvb4==PX|e6h!vX(4^{$@Xn@`!HK@?bOHNR6-#h z9rrH;%dg6lu&FzZ1iIE|tx$?_LN)l0{w6{gM4wa%k+S5C#3U z@F2`5EHz|&@ZUix+lcfN5GTB;RtAID+4K^LSal>3+wmXTRnQ-AxBh^w?;O8f!9D)o zg3ohj8bWi{eaZXWMPYXwHsjVx6~?{Vy^$jD@lt-%$c{R$%$38* zew*Y7qlqHyU~+-9m-u>>&LqrQt0fao43)_zU&Adg5RS>ata`%Xku-tZ|$y4QOG>?pY=U`M-9|9{Th22E04h2 z{LSggV=2EbOU4U$E@;aim{gwEwwoUSJW{pImx;!2W6KeU)sT{KJ9gz-+UDZm)fp*5 zFm(niWk}{QUIV1`X-=~d;Cg<-g}F(V@pFn75`((^O1XI+1rhhiMD9P}8bcS*MQ?ZO zkYoa=>D$o#Zm9oHts4zIO~A`4f$3nfmCj~9wN!Gzawfj_12 z25h*2`0fU|VqB{$4C~Jt*E#BkTQ?YI?~uxu`V+}FsBOD|Evb zuk-Jz9a~@=I*Pe)ea2|4xqs1rM}{VpI4Bz%!i0e>?em8*1lrmyZdc5}r3P6Z?Q5YP zRWAZ`f<8$$>n16vnaz(k`e=0FA!q?r#kCkfS^Y}uluOxN?_y(s3iVqFdcACCWgMd5 zuR_ZZkX!kSPusCiuVC6>>wYZ;xg!6Qo&XcbGpiug@BpLORbYNimTx!^p09`OJA+N> zYs@wD`DygYcS^16F1x6*!btp5y`e9@%Ry#kO~@r8+E8q6=4`n15c}v^idP1OhgAN! zBVE>}7~{OVRbW~#PD4ejMtOR6PEY|dvtf4XCp9P? z(qr}9JKVysTr_EvCW29&gB~z1GTU39QR>m!J6xxd)~>LM@P7E_mdM52^Gx2)Bio_v zkO&uj=pMZ?{c|>1Kc;oJHn)~OzkAivk14=4KFfNz;$bo1WB!7I^O%(~bev%Nv$h|? z&4u;9z6ucg--HlSLt87_kkuji0Eh8Vm9W-{RbkI7_rL(6E|dU7hOkh{4pHny$Uo;orx)R`)Jg;0+&Sz_tF8>OI6LxFsd zYTJEqCg+(vW(PWdDAews%vAAUKO5`^ib(k2lIYbZCf08Z?z{uW3V0Av00q@}1m#I=q1E@jFTRtyH%@E6tuC%>MTfe+0cXxUvEJ zHmJA~Rji)Q%O{7E9hULe@e14e>&M`pV~6VJZ?MdWVQ`xV@81MINy!fA3s|nwAjX8i zM{v6%E@W&Tpo!`S(w*o+BQ^@Bh?rXDOd@#TE;!kZchj3hfJwEcur4i8?yW+hxC2B3 zc1F^M&|EKhG%e+45Y;~mvG6%C`uO+Z{ly;2B<;FsA%-FsFt(LN07>~jz!zw8E2TQz z?R>)~L`bjj8iE-m8;|$%r^xOh`yujhPxY2J&z%}FAApDp&m{Uaz!>^lUM`5N3+Wp>}b|$J@f)t5J#)S3( zb`3|gNiYrO*eRwX{K{3mrvuRta(X}OTDsNOZz&Hk&}-9-*z5Mw=F@{0g}P^{)r)lU ztVh7jnEHdIZV8#b+3-H(3$Be9D?H7{y|c(l<+-~G!a93p0g`XAv`LN-H-&+--P zHoEz~Sw0ZNWhPU}Qwf3~MA~;4)Qg-(aNy%4t&G+C3HfQ-Ob{`{SXsSZ4%&u`YhD?& zIwp7~su2xd2Mha`$pt;DHk=5*ETDj74mdO_*j!omkG~M!cVEDM*Wj#;HYbBKNzO>natUP!Wg)^%yLR?eZkGZ# z7SI`VF6jaVMuOq_a664`{C8e~t7T*n#QaT>pJmwUr^&y#+@K8Ip~+bl?7;ZWjIo*` z5OM{u9(I~BiLlG@V(wIb>jXyTA7?&|u~kQ3VYkqzF4eT_(r}bgcVG`0IXF3USmZ>t*ohH})q3+1W9PvHScb3CrADT7;U>xczjw`*$2 zWk>4?0}UgibbL(s4IIhCctdpCXz9%SCV0g!HHuFUR_kffaC86Qy;?bn^<(x4%tZgD zB>llv0izfCl!TorIX|+4Q2td(T~Z*3Mj1>70sT&gO7T=K537p)apeBr;~c&qG;Rp- zhn*`aV~Z9%w4un_wvgbRYlp59Bhd!f?GjYOSV$XF@Co4 zQlK}XLHnS6_%-W0_iJVl&TjKL4Dg-zW0!31kLkM!FwfbUU}rjhOlz_%uK$ffK`V}c zM&r7oniteu$q>&rP`XRpezFWm8Tr{Yc#bONscTNWoLYN+Q-ES46y&G4svk33bU)ZZ z-%W_mX(2<)8QtukIb#u94O}zbXbvPWda4$R({Kk88D1a)I5lvC`4eK*nxrap_U1Ug zn!eOD_+^af&bI>G|Em~>`z4J>5q}i7*KK!hOYYA7k&VsIeaZi17oB;0ok-}(pCxK6 zYUbd?y(3Ir{MM2X{k}jJSsHsIX&slE-8S?69ibJWh2ixneA2R#2Cnf|2*en_dThr( zMOpL=LDsEpe2@4&>|!BsI#>P>(UtPhH&K{OJrqGBJiGK49*tzmMXAzPeGblkZk?&o zyr$+a$~H7p+Wrw*R=(TJdGy0Gb{gmH=?vSul%?@T0{odQ+PhSxX~~Pur6sblvUXSd z)Ez^yG{jq&Z9xvj&@rPY5YL*8GOkNU4b!WWL{vr$chFbb18lK*X6CvoHWPvUw%EEt6GgmuFYr8oV%h<*GC;>w&o=G8Wh53&YXHwSm`lyqp|Kn z9r=;5(h15U8_ZmiOpVDx>>#}=(f|jtIxHrD8`$A~e;R$Mj~fDdL}{b#MuKP&!D|Ga z#@MxJZq zyw!DMVrehwcHy=D8N_GAnI^fl&In}P6FOY=g)i$dBq%pQm);)E%l5Z5hj-9tRCC~r{r3(nUi}ZK zIGAK7_22Z7(^YNP#0qNsw3pdS{S_%|x1Hw1n0#Cpg}R0dtWIu`1`ewAJl6cETaIJm zfAtlgqMxD~Gki!p+G}nHu9|7)yH_zY$(z>Rw*N;+_1oIPr5K`LhZl!wF zFG{g%_lhzCPWw$w;epaIzO7P&;f$<2%(=~$i$46Xg^6^tGEZR*8n!#3&};5R&WqDC zW8#I`L?VF&Msg=M`AFJl<%4TPm!l#BNLb9Ou7oxkjY?VRMYgVfOA$>?Q57#yO;x|0!Gy$*4t|~^0_FW2FKUVI zjW~Z&P9RH?cWZznW@gJN^PL2POI$sFKz@m+m>6&JtkIW7D~Fx?@QUhCCl(BqI8Z6mCSLvBR$&==Ja=!OikoF2F`Rq8Q#h z9Y)S&dj(xhRo|yjY!n;5ya!aOw7fH^v%4_ZDeLigqZNhm6Toi}oJ!?cCwk7$0f7o) zXQPdBql`heh~CyH%}E&%SL`Crd^BnEP8vpywH4eA&66T{Dd_GQ>|Xs1F5y1pKwjT}O+Z0m zo#T(ke&hfDA>oMmc!94SbvUD1{cbGJou27QTsfPTrdW7cO0-%T#$xAeJr=!I$tPu1 zJaLB0a81_cNh!3Vo3tU#RG5@iFi?D09j3MS@Su)vP4qT+8oQg^6(Qlxpvgb~2YnAg z+q*fGElvp2*eC1&|Ld7n=}~;ph%#JQCpMY6DsV1XY3^tB)9k^i{3w~Kw1Div0?=yE z9+IL5mjO@9Is@Ukdis?g;Y7HZ3MKx3ibZ(O#1W|>5}XYm0e;ut| zp$bjj%Wb3|+Q|qWM;K_LScg-s3i}&E18DaHD&9SOqp2n9-6dwt?yJ; zH~WL`<<@nOb~4hW#c*`ax>41YFHcMrMd@v}(>A0PtT!2Oh*uxgu}D{o%Xbk6V1^9p3 zi}1sPrVs+}Vd{JODW2hg9R|9E8MWsyfiaR6t-5*Kqz}kL)0s<(s;>|N>Rw$X5m%=( z^BAKil(^sWqf!&aX-e6}O~OFToHu=AiJLL`MhA^n)QiiNJ^lSzpQ+ZTJEr)lq)UAJ zDqLVoN@L|nQOQ;6_EN*9ADTJZ9=5HC$u9VlN~lFVBj{@Ru%%7i`y}KO0gG0(F?fY3 z^*?~bZ}q@8JoQj8NzkjNZ6`CDA8>=sMqabs6m8cbn+%U8>VLP7kPqa18e5Jq@R5-q z;8yi8jcZ|{M2dOu&(k$}uQft;dNlK7A?jiaMn|VnEl3LoKJG)>TT$eNy7FknU^+lp zY|@{lXNu=C>wVyQx;MzXHqLlG)@&K}X3}m2JH!_;8ik57bOqCx)*gLqwgl_i4U+-f zSEZkEsBmWXzE$|(r4~|_Ldx4L2Zx8-a)X0VpmI`={Mn;KZL%rv<~xs9V}7GM?|V3XbW2Dx1WYNqYazj7A%lNl^MkW zU3v7HlH8P>;Mxy0ajdIpOqFX;`?3z3i4rNQ5%%4wCjkk4SG8ANW9L(R0)iobJ2ZSbxzA-W~hop?nHYy*U273%2 z zkA3MJN|rgW?Zr_}4*7-eJ?VxRcq*cwUIZI9>2+@`-AOPuy=^;m`HzL{UNEcOsr4W=gW*gelz5n@THCdV-r{-0}HPPdE7U1W`oJn(0~U~o6(Tv z5yC!urd{&MmpC-R(57B8*=hB3Mp6<52Hg8j3p^KHTXH4SZgkw|8G#mU+XGLY^_Mr^ zRlp)|Fi+qzHza@XgYkm1J?dx8uY9ZXT1hu{YVtUiTLr?T{wM%ut@a_I^p8u~*g?8P zeM{^emuqZn6h)uzuA);o{jmy{)7l&mWWek|9zAd-RtpnaS#0Orf6d^&DL{*54|vTA zoOBDFcs5!Om8V_|XX|nN0+hB{g4ZLZA!LNN{uEP5PXv}-2#X9Y4fMWf#+RlA zURT;cGe@5}kKO#@hTvWDFlajeCYxA$%zeCStL$>PQ#;z~8@T zs6Pg%915ie>)0aCrwRDBiMSZk{FEP3YZ~K4`#bvRIQEI*8qjNO9)6$N0%p{c_?w29 zK#{=UPwMC{ZGTa#|JpFKi&wDGkkB|H^h&9uL&mQe^$yKy<>C(Mn0%V2%dTn_e~!NJ zut<%489I!L%2S}7`EWOTT7K^p$$by6Z)(~onvpWIaS?}E*+RXa^ZR|J-h-H}^PG2( zGUzpkHlX@B^DKh~tS+1$+g8cOcv1*it#OdOgN~7-%yLj}sM3}VS{6GVtgq?U0mFA+ z^fHS)sntgyDEw|%(r@(kQo#B+Hr7g0PrrSHZ$>(*SNtet7(Xvl3dUkyi^MtDLD~=c z=2B{xTVUherQDDf7#qB*l1p=ayORWcTSE+Ylx}%T-A!arzUS=d#e)fRc12OCmGwsl zYW|oW32}pZ9xQVQ#$Vz%8a8;N4ofTtS=-zx?>o{nKScGJ;1Plkbe>?_)Nk2mWv%^7 z%pGXie!2H?>Me+JxT-UhI}cj-yohBNxXn$~T~{iRF9TdN_}<{Yl+TxsRKraV_e(nqnq&Fe&^XsS# zbs|M%5<)9;-?kJpI2cFE3)Q>6M9BL>4D&!}cf$7>#sWeMO~3Mp9R}n8wjpNJxZA? zxhExTqr!fv;&^OMkPk@jlDR0gu72xsDqjU_L#+y{c|J^G)m!u5GDB>&3VZPvJh#$R z?T`eDSi9M9x|;FPv{wvTPGT2W3U8*ZV6Bu06Lo-;g!A5g*SGf?H$0hCi%K2m1e9vR z;uuCHx+A-4$0qV$q*S}F(k$BB!WY9-y^+!3c|-pfpE)g_EqvJwT3;Bdx;3gtc@(1Q z9}+zz>R70Dn_X1;>71zTICpFDP1!g6j-Lfy`|lme1Vu>8tt^9L>&1KCT zV0NI#VDSbOU%7bOGWkCI)(04XM{t7Ud;Gy4Fo6sF*Ec=U2I1Yt0O#@!5^WE>mo6UW zaOqk^QG(jKf_Vmby#iN26NAtnmC}IXQiqga#?7~v|1sghH5o&!0CfJ^Djt)QfWge zYQUiIfm9o7;Ln{dmdcSA%#qMnGPW7gy*F~}Mb2DFAFX=^Hthwk{y43!kt(zN>5f=6 zubjcTz0!IADerz-8Q69cM595#?EN4(HuqXgl-7LW{o*+r-i%b)$tZbcb6foCwzuYR8GbrHMzsc< zWKmELouT3UyILkBQqJ~^l3vm+G4wC)`~=s9epQomkvJYWZy!fX-)#6fJYeSZMl=T- zA7`Xg^5(-&a3~u>R$g^UWlCQ#Dr@PA1W}*P9oJIZ$1Fan&83G;gT{*uY@xuo{pW+T zRvqR=tw?S-Ip<xV{^y+X#E#?PeTNKv8&EvBKd>M+CJhv0ARt76E2xe*jsf=VX`~ zGuGzYN8ge)6Ox@(0)n_0J<9&{D*qnctFsDve?ewrdsMkWrvH5ySuw~~MRm3@+Eh$=C#c(tsePy9EKVZD745dhbcqPr6@)) zU6%3EJ*4rXdvs;r6r3vXL<3+wnoFnI>U4Si*VsGPdPz*_(j6Y3BVwJrr#D~dGAds< z*d@S>7kBJkY2RBq^oe6mQqUmkqO&V-o+E0yF--gO zxB}sj80r!^4=R_Y=fff~5{d)D_7Kj6-27Zv@Tv4{$mMdrFn*uDQIhAl~T&<45``0Q8}tDinEGYZ{!T9w9%z(}=HAf1&L zUEll&dKx$BR8eTw}pZOVg{1Cd4g&7THY^we+T;Xfs~iiLpD0H19Sh z%o8$0Q<07m&HT5?u+hiR2$!ku!XXjv{#}shzzaTOIbQoLy&t+E@XlZ^kVm$F{OcH5 z2jd+^|3e-4?z9=+vY`R17Jk*l1#xT}R#!46C);$eh*^D5q@{sHJ+D^es^)nAkU1De z2gEQ|phb2+_v4utiB)2R!C1TuR7BzV=4q$&tUVWe7Jo1vaeUv4LX~FjMQvKHUyGy( z19qg)A8tN@$&_5Yu)*+(f@wD*|L)vyXE-=sCyJ9eVi$4@7h#NHTI4;8idB8kUPexLZvfoS91-zk`DCtG)8d7@ z!f$90JSRK;`kuRQF-8nxZUmjg>eO69W_~A;1$uU7Y7g_=b(D%VI38q-^&mI-&TiOw z6PbgeGFl6gO}7Ada7m2H|95#Zi1~WC0-h8qxxFfXWo9#BQ}$P@`~>6afm;iiq>Jhl zCi-kZsLm31)Wwr#kLt0#m@(!0ahtzcV$RT!BSYhf61mZMxdno z*_8T)SlKZH#Yp-7K{5_jYB>y4GJVfcuh-uu3BDD_e`_>MrJjo4G;$#+$7(D0xUOdX zP9=wl=qZHOg~bO!sb7kM?bp`}kL-W#I~s4~UWl+YP}x$v{3QhAQ~WcravxDIj7We6 zd-0b&U2#1XfzqDj%A_RkaH5I5M&A5yvpeki+lF<4Trp-v-kI|(`AY(mAF4GKS2+o6FR`&5lKXOB+(R$b!(#iTKLInx;I?MlorSdkAX#{?70v|#YExqByUmD zbtf=amRzPckyl9vMa*_lS%Iu&JBY1$h#Ku37ZMN#2+!4X^Q-V)2Mbds&Z1 zkkz<&FKRO_StTw*`(%(go8_}#5|l*(>w8I=*|S-VS9K8}?FJJ9cEShgLcY&e%ZT;3 zn(h-fidZii?u%TtXc4wkV5em#tVlLbc_Mj?vF=B=&aLx28)&6#ZQn*0srBj{nljBV zNEw)Zu$63Ak_5dA5CNnL6j|Hr`LrRH`atMjR~82gua=%C#fS$PPwirq?8(zhw+?pv zv`3S)o_W<1?a>cMW_05h?Lxlko-57hC#89iB)fekq>RA2w4+QNB7_c-8$Aa^XMPQt z4y0&y65|B0oI2_vf+q8XbxKKRr^o3z?idO*{GKH=f>8 zd?3~FHiBu#_iWLUb+7cTIjX|dan9x?nf;(8d!vJsNzoHB%#h3(c?C00sVK#Tm>gd; zn#&yw=C?iNY7z`%z>O9z7abF+LRax>t<FSK}03SS<# zRpw|o4Zw)I(^|ofS2aeHH1HYs=vZ*oueBLcb(dYzdDE9d$R$H{~9(5?0x(6%|c-hE>s zC>OUOr=KFwf@v{0)%1h@#m3?UO!M)YhIg0ypI-l<^ex^n>YWRVshZ6!t>$~#(B{1w z33tQXJ%Q8K;i~G zT6EhJrZ9<+KX@aVrt_)6toM_iup+mVg*RSYh;g9wt@z9$%(>(4Uh$=w3oQC2cXDG3 z4YJB!*t0l|vUuw&pjwFp9i_}Zy_SrZS19=kV)!DdvwZ0OK0CSJnMpMA8g-}dSKcwTPSc|~6uYD};$+OuP-{^Lj- zs_{n_o9V#4EhrA~D-T>o+BX9r%|lGk-@6N5C+z#@;XTumnJBrsMQrWLd#%7Z(U$B} zy=y(cUZ0v(xMt9m{COOEF{wW;*7MRU+U%RLLSS^TvCQ)6raJHMe|iAfTiaSamGas( zs`E<@cRmJuBIKRrv{5?ImO}GsJ=N)HUv%^F>ISd(!)r!*2yBX>&s;6~>9zQ9JnyoJ zfV_JDJtm2hsS>gM3b69z{qYYZznEAyP^~9f-?15__G?Lk91K&~AMHPD)qu~Xu~K>D zk!7s!^P83BFzPH4aSRoxw@rWDEz6p|NLnx zH0E1O=vz;KWPiifi7G21_;plMoMvP8uCsYE8hiI^m!cx<%2;(%JFVwsqwFTA&4@qu z#MhH&1&VM~H7J-fXFd^3|1#q0i{}aDE|>Wl!GKrz;?G&Ur=140E+*Q3_ogxoz2g6W zC&Wlrn1{HO<@Ir+MbGu!Ie4X;`16mcNtJdMsX&VROq*(=thi)SFxJnLhnJ@GA}mY~kIqwouml#w70XvY{brir&L1$@Gjtev_sHkC zq9-Pe%3EIqJ*r#JD~u=ILaNtLeZ%eda)`(Pp&$<~*)7sH#_ye;7S+`DKC3%LvQZKj zdWz@?iN=pU%B;84@%C|UUAApEa=koNKi;mRQliD&u6nap$>Oa-CB>s|qG2oT3NfcN zPAd^d{%L*USd08vgU%F96?{MSAu*Kdc)g1bjoE(AITAvn?!Yc2<|Ms271_hAE0_@F z%BYs+wEe}R4?_QeV38bWxf`C01F%F{AIZEEo74$k1R0iAs4aM|x0zWzV#qA=E|S7~H*@EQjblAI!uxz#F(a#{xWx|w_-m>5 zhgGxAVZ~OEq~Egu-&<#>H7-)!yL5zvMeF1YdRsg@SPdtX>0jeiH zara*zt9IT4vYLHjEzk!l10Y~2;~S6DV_1npw2RJL|0L#Y z>ONuKQ#7PGr;VNbHE;T=<=A+f{e(FKhGG7!^*~32oa$pD?v2<6+Rz2E3tfjB8g5P# z2%DC+rLhz?LrQ_G%3Vb+S@BD~S%;y48?QU3t(J<>>57f#vB7uJp%+`Wlqs=Vyi%&d zafFWbN*N5?7Y_D&W*Ic)hYq&)%6q>=C4!u&t6mi&>^MYGM^!%%6q$7302VyrfQ@S_ zJqeJB7M3#!q4hxVng~%hMFU9Ppnvr^&fD`?F=#r%EZ5CqJuY1#WKwKaabQ+qEyLOs zxr9Z|BQ6-DNyMOdp`e!%K~t}Ak)1Vo_mQA&1u#nJqhu6JjWa6=lerm06M~^ zQXC!eSJx3$tf>PWoaZ3idGtE<@H!a7tsOBDJTt(rH4Y&#Y_=h?eX#yT7b0)upGGsc zM~I=&?i=%H%mmZtoxRjLkm|Vo7mcbe+Vs0#e43rmFdM|ns3_9J5WYo#(zqwAQj^TI zq#?jU?PcKlWF}1ZNH^>WYL=H1K4jfbB^n45wXzB}{ChcrV z9PKlXCawgVCC~ulYF0&6OI6o!Rr@6+j5^K71oyu=cHi<(10D_zF{pX{^CI;WiLc{| zq}9ZGo^{0U)T$~W#gK`VNFA_rE)#wplKb>l@Ln)U^cM<`kSwJOn8!R4B5#E{lp7+` zTS!t*6rA)E4iI_QsKEqDo}iwY!!CLG%yV>OfCzINwQKk^nnE@Hg%M1)gq00zDB+A|__FRBR4dGebBk;ZuoqPA^)@TS9z9JD)7;?dWlm4r( z=7hJ>;7x_fLVxxLL~?Bp%bzSX7!H*=eYIjE!g_ES#RwQjMKFFYNI)mgUFWQS@$VY} zhVn}I5LrRqgAL|q`@oRvu|P##v`~i`dcdK>r711e;}Va_;D)<}qZty0U18E}ru302 zUI3;hmy^O(I*my6cCP~H3{qZ+c2!y-n~&R z=YhcH3jLV$l0AKDJykl7CVio?+c3|4Lem+S^xGJE?Y^CTs3cJD)7oMG*;or)CI3se zqj)``iU3w_&(A{1B&2}cYu@FH{rV#|6uPZ+X~`YE}J4RRSA-x_?>D8;=&l{slG z*Kb5?7hO^|k%SW)>z9lP5+6FFWK&NusFlA<1NOjyD#=}4cQ-mb^?cL z5^`V2V`DbmTLOOsMxoFT+}lsBah&##VH?`@v%^Ba3Si>kpF(q4)D&r2zp8vxNtokz zD^_9(KjT^#m3DM)Mwz@j&1{p$d6-Nd^U0ymGR}{Cw)^gpKDgh_?(1{?WjFP!F#b z)&FiNcr?k+{(iX48kUVpl`wdIxWZ_k@W{u#ZXt`5<+mYn1BgIRmfI)DGX+g2{5~!fK5bJHTer1; zt;I$Czv>pEsh;E3gCAF3z`&rRh~)&V6R_tF*9j^Pj(xRjB-vl(wgo$HC@YEoiPPtu z<~$k;CrrZ%^s|`&f3OJ4SiqSGa$(6bki!{IKpumyJS)UkvA|^zKw*`_^(j z_{M;I#czU8Vy4%C#bsV+aaiTfq<%55|29kBQ;9swRaRduXNj_bVly-5?H%jxqM(h8 zt!z00nc#qxF2iEskUWN($p9UaL3+lt0&VRL@3PbLFjLrAWG?dj+TdMCk>#;3b!&_a zjiNjEo(Z^e%%DXSB#1ejZR)xEzEV0kJjoK(ETr{WXBw%eVv?s|Qp2WknDuB*nm7sk z+!11Lc+5ic0Vj6O@T#U%*ZsBM(lvEaP-N!T+ISb_e5GlN9HQm_Sx@n8TW9Aqn!VM~ z?&#+9Bte2O1#k?S8-0uPuKPJ`d+RL=pbFo`uGkGaui%x$wU2lRz5CCr>ldnDrwE?~ z>2QdRZ@7EiW3rfj>d+tkJuQXsNsl1on9=}~3COOcuztwc;`D?>c@3==R#BM2^KoFY zXoL;5&J%`R#0#_S*HOwK<<}0+UouRObEYuK$JhVJ3oGB##BYSzjgMw$;!5 z9>m?2eb%HjgehGzs)DWt4_o{nFBt`b-IB8cqFSzYN|#aO3dJvNdM997FiGzJmQbnT z)KEor6HQq&+H2T+7jDseIP0W8UeVN^(qV(?>E0KA0>CzQ+Y_2~)~Yq#h2uL`#H$aGfPsJPsoVg1}JDPhB3~8#sY8*;~|Ay26osc z7)~haV7Od`YWNbtrioV%l?F4Mjqsg`Bs{39m(P+XOm&Z(;pGoUvI-Ok9V1cR4H|#`b(OEPR3hF) z!JvGKEaaIcC>}DMzsec}LRskNnWtoP%Jo2YvGg>V;dD0DVc z>P_!Tvo6rM=?%@Q^vwQl4vpk~l6mmHS@rr_so7P5wBKz)zc0^?AZn#vElRBDz(;VZR3yi2Gp#cdgg z813)#hFB<{gREckt5q$FcT%glk^*fJNQJJGw&U|MF zkKKH%@+?+OW!P!)_a!dIyFWG_ewofH{)uh2Txd;ZGiY*?J1fZeu2Zb{Rv{mYwINrL zSoewE{Gq*%REoaa(qu8Xju(IS1L|DuU+Ag?m7u)3;m^&UOTI^ z?}9eA7mEFfKR49^4mjHMdhE0B@I)3=AaL%lAE5y_6|pef>9168jep_~*wse&t6Dd? zt+o%mU=(iH3bYhBCcat2LX!}1!4?vWxJ0~QddYiSy*i1|kIhsoxJI=$3m9f>uviQ-hcfM>9o^l>ol{y2w&f0Xy*obev3@o!Mg)1EDKyL}92 zHD=W@N|D?di*5zyP`wm zV0GfN1|~>(PLp?_lmL5_ulB#bMC!p0{Yg`DIIk*k4xt>J3TV&e7jj3xyGXjibQb4~ zSKi0Eh#n2h4n~r-4*TGt=tp_&kM%4qe5+C+#zJw(?F$1bcucMd7&N$$53rU4>H5G? z`I=Tf!~nF<|dxo*1{W#GEh>``9)V_$1>e}{E-xIup3w4YYf z)%E9wUCg!JQaYjiGWMeKP?G&J#;?N*8RvF4+B?f*${mUaCJbTleg8!am$sodqjZdg zH|qN;amOV_1>}|fL&kdLQSiH4-B|-oQianvuU>)f1N+aZGZkWMhvs=o zX0A9U7?@cIH0_BcJeK#V!Q+T!aYj6HE$WNH3Zc}F>G@7CQFru)VNxSFXpkUEB?@MY5PI` z-hkH5+A{{NqBl$YA4gO260&3i!;k{#e{)}|MD=`A%F}+xIM0LKRe!xpIIlQzj)|7N z-13Pb4~MR0SwsmwKKtg6GUwvVH+dU~cZeb=YbbUOc zlZm`39nO&@;0k_s9|Zm>y#s`iA|-M@9f?eb$CA=o3lxjrrlj2$E5RaERe(j72F!*s zBGL|N3XKp39nl*Pnb!dS1)V%kbj=;jZ8dqVzxCz8-o{EZ@6ru4&Z{c0Y}X$1cJ<6}Rhr=G$=(dS6KQzmno&#D;*e7eQ_8BHqW;IOZ{bGCr%(%>DnJOs>N zPG^-*m!YnkNv^nSSa-K~t>SaF!Pz>-Dx!WF~r=)B9^M(+o7Ts$Pd;RshE4Fiv_$QgF08Jwj}j9@}Tyo-f- zwG=3H?HlIN7@q(Uj^5HDkChp_a#Ax#Jup-ru=o)F!50QJN*om(%^1s$;h?Vm35&j) z=cjt!t))VPhg*T)MTEjTsC+^MG}%xMOHaPF)`!(C$@5TH5(l~(lCi(N{D6%qjMjUN z;;v%R`gQ8_${C>%JtiS;F_oFWsjyZmm4pR&yyTJF1`csP$c>C5)8Xi zEJxh0Nds0W0`BzI>UiqgI{9L0y#y*fH?APzicR7#<$ zk_--hPAk!nrpSvFvOw7xX%-1a3TC0qWA>%|JHuQqpKAc5&C-P_yk)Od_UyRtmt5nq zPP)RVP`mAuq>zj$`Bia#1<606-GxP#!ffV$wAl$}Ll*-w-lZjb`Sg*)V!L1>XSFpE z+VQUX@}GCWh!eFL5<{GU22I7&CPpWpo>qNJBM-^s@{6W*J>TOBp0J{S`d1LrX6;QF z$UU_g2^Q7q>-8`9M}Kvzo7HhoMULIG-BvklZ-i)-*zeMG>OD5s{g;I!??}h74XlBn z-uXy_6X)L@$6w6T5B33Vvh5U<4>_+AXreK}y$l4SNg#Bh9a@r)woGlW(z>Z&{`En9 zx3l&&|Fe{DFXNk-I;~SP$mPWjZ^%ZMY@ym?o2r-6RNxFa06;}nnx73CadR$0JL?O& zpxV~0)9VY9pVbt_q0gTNGL{g)yX6p_l9ud!s8i7+Rg0q!;1hq!uL*{^KF@#o^v49T zI0>>-5WphP^f8wnC?aVfd~X&;7qx=n_4!ls!K7hpJqhT?+a+&Z+<7nG$D>qrX^*lN zSe*2TB8Z+iIvc=A2D#X`I$ON9^B#s_b?@r$fS!}>yjuv1dObAR&xTc8N1?~Tzx;Wa z8&KsD1EWS$mfZg(EUTg?iPo-m@~R-`pQlL*QUUg6`O@{^YoA$*J2s;TQ<17$9iuli z_lSU!)Wp3dDsc#}bfJ@E1?(Rw)A}i$+5k#<5PPdxnQV#(O;7RG@9Kg%pFHbAAb$V@ zN~oh5b^%b3VN)mSg)Ud)fr2)#WG}9xOAHT2!w2>dXs!6e9K{8tkv1U*nHoFN8gvo+ zRs;4@$W7H}H>``Yw3XA#JuV%!${H}kHEgom?4=kWf(Zq^A^SL(Vb5ta4hMFPnVW<0 zo&KEKjwKrw%2lq(+T~@kGR+s{CbC-u0ujB3bs^hy=GdK4FQ$xNxvh)9jBecvFh=Du z<3msCpLG(m|9hPZOCY1?wVQ2v?vz2nA%|d$e?QaHm2(U#Ke%6|8}wn*8#M|C|6VRt z7!RX#P8Fm=v=M0!kce4NJ@wq*^Z=ydCR82GL0**^F(#QOWuPl|0({0VK2ldC=tz+K ztccf6#M>7_A17Cr2;&R7c|ajWe6e8c5_mtKDOFcMRAWbleMvaU`fY!H{i`Lr4)>px zydQJ}oGsJ~Us@?Zd|nvoVT~a5k~K&f_G2uWp1h2O|xd37PkE7 zspXzMVB&i_n*2E$bN*IIE4||n0#BeI(|9o>1cEwon9f{mx0DCM+A=W$3;(zqShRl^ zY=-@1_;^$3u-QRKrkW%Y;rv-It3{E}I@Nc7{bF*ugyuF{WC=&$A}5NUML>@{tYt_rOT0l_;)Uy@?vA>EPmHRbPMkf{zvK zv6y4_xOS#kmi&azF#kh3R`khye&eG9b0ZF5Ou3wTG0_+L6JbBJAtZ~G0F=I&G8P)iq!YT0)Dv@{)H&v2ydv}>z`{lAb znV`#n1l4(b)8~}lHxtiMdk*=5Cp%n$e+)@L&K1V-G*U=4ysu?^$EB$8Ru5u3K|zOt_61t%1BcCi! zJInzZECsvZxVo*N?G_zjT2zD_=gn+|8_rhQ|uE6HkRVjVIEoel%;`g1z^XGc? zVvdO6#*9F5Un4IUv+U8dtAqOZysE@gm~250ut#pQ_wH6lpZaW`^-Q}6P)BaYyj<}- zim|E2Ks7hO^Glm zw?F(ScP;1qY8s>E?)Nn5U2EAU5F=`H?f^xXfF04QQ^5iQE8b@Zd)6I5 za&Fpzj-!Zm-2uido4)_~DOULRY;#MTIXeVwHCZ-s7>{#QT5Yr{3aCP8lAl0qY$G|v ztMdP6x9^)Yn+=iya;)zL{h=i6ep6$5@*TV7_ zDaYr+n|FNDC!Gm9N%$h#07F3a;jmWA5j6oQJP4sfs2t7k=L)C!9MX ze>wx{D(0Qpm2?Pq!|Op@+3lA-V%BVxNaK>YYh zn?9|ttl(Z}_FDzGUC?VN4aarTo=-*Lm;*$^56@9A2Z<<9S(-@}fW>yWf7SpQ9f5h{ zxIRz3URVw-lpsRpWx<`8B%J&7+p+Rh+<(341;k6x^`g+RRbIT^R)2u<23=(>`a-J+ zSHgQ!etNp+yO8i!I4sz|grQ0@;`FD89AhSS_Gw;%!{9<(WaE3U9`8X@Jd3K7FtCoC z!bvy{?6Q}fH?LLpt;rPZP7 zS7pgI-s%vfl7lGB+TMcZB%!j#GjY};nNmk+oDzf?#R}73@>SX-h(XX14UnIzeAPs# z_M-g|spfEI`F>xv0z0h36@D&6OUg*l$cQ*!KhD3<|2^H~^Wo590o`>IWDSpnbfEh^ z7(QsFXUx)=GqWTh4SVRChPf%bUnm8JR@rPuhP?{Z8a6eJ|bg#ZUu$570Rn z7bgs&UH2WzQ^+J}urObGmHhd8@uFX`g!Fe!`e-@|^+bSqTGxO|@ebz>xVhkc`hz&WL+ zf6C`LHx7>l81OMmxjCkZv}{Y>?%@b=BZ>ePCz)&hcu2g-i12W{LYR$q{S3@qtisAh zDaP!c@L5{)8%^v^hBkB!I!SJ5?V%1v2>$JMlfQHJ21I%RhlLsu5}@tDQMFQdZ4cep z$0Tly;NMaerac>s^B~uL8UH-|`8|lXG$6$8CtGn9VlXe0Y*o{4&}Id8-ZAkPeHaQ2 zihB1_`rhJ;l=zTwx8E6tUi;&V1cc%P3wjVk(tVbVS(h4V0V&HHS2fIE$0WtoiK)_n z1POA-;S=x$2{<6Tuhy-2jamtrqr&L^!FRxFCgrXrjbf0WWYDd&@k#T&&$bAt1#QS8fKZJGVhH z@I~O^afH2b^uRbc&t4sY0K^dunI$roQ?~1b8mQ(cJQ(qSik{74aji2d(qjTBeQ(+1R zCJsO{bN7fqa@BL%IVoLteyXSZV9!i&WD1Y+^A=AjIapnyy{0h3J>G zWPbwsVj*~dy>7*xOlS)(GvI*XiFs4Kt<$?xkyd^>`#wP;M+7dZDTaT7RT}0a7zDFC z)2Z`?^OfH4B{T5*v}4MH^bs48w<&IcF{q3H+H&A%0+4WIInw*IsIyIbpY5{N15kKh z6wQqF5YS3jee z;9minq+$(nUo{N#nm3K;k9ge`wS65Iv~d&Oq%r3ZC3AzQ`q z5>1Y5CR>`-OCNmz-a=i0_JnB1S%#WdMU&9V#$Hbbx7o{}cPQzDdA~`%Dz+M!UkhQ> z_j+N<2!L5*zj>^uFaP>L53QbaTAyXeYhJd|k*UUmCC77ip*p0bBPaadXKPfk@K;FE zc!FSeKPb#3f-{*3%(ku)io$8x6k+n}BuMDI zpnV+tPM8MHgOfuz^&>y~Xd8qX2V<*y@UMY}&BTWmQWIHr-TEtYC6{P@p|F;|0i#3t zVReD7$)TOoR!_62b0|`Ic=G~)*?VXyCUq-wf-&WM zqr2A53(B84uQlzUr;!rOf3U+G1)IwT3I(~IKYYQ?spqoMd@~Ll%bvN+2=$Rc<;zDi z2S&G-AOab!7Zv}ANUrQ)->EmpAP|vCLsWl043h?nTK6|ca*u^IKti1JTa-v#5++}a z7fwY79`+8NOa6C%QJ}|(;B%P58GdUo$Yr|C27?d|0MhQeya{_-CoscnioSB@SLuTE z=S4SEbNG($FgLWJ9+`>0TqEA*TZ7jToNP*!ekNFCUOxd1k*Xw=)W{l^Cx@JtF5Z-f zHEYhyTrw>jr3K(OUE>KDv6`1)pOssCByj(=SYaprcuW@_Yx0{2E73TV_o8}Eut7dQ zqq@E%giXDyh(7^xWBz!Z^4CpJvGH!m#2^|C$O;j+8@ zsO5?b1Vr|7bvjTKa%uNa@;qbhr>CqM4-JDDLDz6|{$~|$U!?~k4YUBa^3}YAp4w>( zP+q|0YkBt{cNg_%XP?JE%qM+tjyjx(+_5&+X`?=*G=|il9+Hu?KnQ+l$#({BZoSE2 zUQHBoaw_(YFh^<1P-OTPT2p$A!NtJ}+!pnvWonCYR5HmW#SkdZL zWAkX2ZG&W$tXQ3W&0^cS>#n;aVHKW$pL}`t@P1aMpG#h6x_ZwYGd^8B6A^_ReYxS6 zYW#=6KOZS+j+aMqQ#;5&A5&_!h}$WIl+167U7iSm_mN_ZlGGAJ=ZCivN!$hY7VUu6 zhL|Pf#k1I+_en5lCl`Ufen(c+Yxy`AfDzrv)Yw(S?mu-eU!+LfuB2gr1JGPC?(s<7gnN&v{2J>kSnXI$By%WOEb zZ1pkrbySERIZ(>BL&CbC+TDW`q*Vw70ppysX9*S(*x%HFBv1oxgkK0}QO41AVrTS9 z`l1sf3xqD=RF@wuw1e~@17g{_dW&BJvQfWtbRar?W^OMww(jIrjh`QsfkY^F^Phdc zF(O1n=?OK#$_5|84Y9rICR*i;Y_fh5#);H!%w0{h01E0C^#Nfwv)f0Jog4pw zBcEZ~urTj_Ev6~1a*d3|A@!n-S^6E^nA?8=lIzk4I^} z5jh#ecB#L(p^ExDf3{}4HvQ410PvXDHk;uroky*<9EmAZ2-=GD2`B#21H?_f#J;bb za?K{%X|eECu-GwAs_@5HkLJ0dy!RB2H<{0hPMHj899#^b(~19FOc?HkJ#ggx)>JU_!obhm_Ni6 z?=OWbrz$nw!x4D7X^yli^F29DrvZ`;q56`6piNcV7Wc?a$=m!q*|!q`;Eq)KpStb= z9<4h1nmnqRhM5Kf1@X6M3C6)hL^E{Q|9pKqUZc>`&%kE5WP3JMXk!W*LT2*)b4ZXV zVfQUwR2~c;ueN$XAVDN&eZ6lA2J)2-3v{inB83ojGiThQe7PxC;M_b^y5;}`b7W$l z<$AAPLz-&#u|-#B6FVpJLH|UP9SBj=tSU8aYN8>^0VyR)P{z-Cp7QaijUBmVX91x5 zA3YXWwHFyh%B*vpPb{Zoc61rQs?3Kavh>%#btlywc~_ZA_Z)r%#Q~BPr^_HXls5_( zBqyhVu@m;RG2(%Ab%2OW0ghM!G-L7wc~OYhbRi~Faz9T#-ZS7Z6=hjPs4NC{g{MTx&k}%*@qN>NeNm!a63M+!L0Ld+$p{bIB(ZNyqspDKTwc0 zeAm#le9soa9QBf4TjXafCq8_CDMTat<(4Hd41BdY@q#7h{oSd(r+JE*x~A^}J*3Mq z3N!^&AH3f$SD%#=^||(p-z-LZ0UYx+PDNDVY)Yz6y7n=VR4M%lZ}ZlA zM&MGoeD$JSOp08dj`OcabhYl8j?_}A6~=)cr{6?a?1T5oj=tw;iUbEM#AxJ}Z+1GS z_~?WRJP=VI*PkFvFN)i#*}3;mP<;Qy9|YGD36m_vf$x&|Ko@E6SpETlhxk5%iG*vm z3z=RkzCW4nRJ$1!kObhm7JM5m%%{u2m&tu6OsS?Stc#nT^J~ms&@tqi0B{HyaR$u4 zUp$2Ab6&@~E%#yRKA1mylp1t}NvW}b*$h7y>#@OUb3OU~v#WSgZBtxkLuAX`DHXN4 zp_(t!9zXY2BIP#!`KoJOq`yL#(fcNl5PXImMr1zLQ)#W`D)O9FHgD7aO%|+Cq56pS zu{)lVwX4x@e6s`$q#bzwA||DOUT@{aTAF&oVD4T*J@;=Ky%St$D8taas8>p349!`h z`%cFaj6h_9JAsrCCNvy1*nUb;rNIX&=PzpMZw+t7wF#L<*+T8XaxOi;9^}&)q262! zWs$f{P}4FRHZt0|>dbt3djOEW1W@@^k|ZlG_?Rw{QMPSZ2)!cX;QzBw?qaPexP2AE z>YwdH=QJmsHwN{BnAR5unJ$uW{5pH~6OI4<9|?ixC_gY8+bJGN9_F*13@z-*{da(u zVgN%$%y_NC5HP>X7NJ7pEs^p9+LS|^KV7zLn!oQ_?y@%yMe2ZImz^f|gjx_Uj|YET zX7hMt60$$Q{mIgA$ELqxE!%uPoHxt1W=1Gd(c8ipxVtlakf%m%Gt*v`*(U~yNT^>w zzjf~mLd5%H?!}`XD=4vqi0Nc<#5Q#Q<^!p`L&!3@%Fy`3dq+|dHr8s z4{AP#)gL2NJ0`u|=*N|rPUirPfprptAM^OhT00sHig|Cw8?^5uf+Bp)NM)1NVPz7@ zza+q|jStYJ$5%~g#_#1GkKegCraNEKkNRd(*Y<)ADyx`KR4@qgIQ<**iys=kLtQ^g znh@la{)Z$2hVfPWVQuiR{CgR{;qU&gKOMI4AMUq9i5Mw8z>}Qu1PuTT0?vz}0NG<8 z1fq{YtAvYf z$YcJIQGp-Cp<#w$Q~_ZBBu1#hsaA%T&3w7a$AFNCm}2DDJr9FwGNW?(qP;JuVLE0M zP?8{T0~``W77jkI-X=W$^uRBo?{Dt=QI+8Tee|Uvlib+0NCD@*e`Gx1hyMmrj=t)U z2jX%1|EPW(=FNe|Sh0g|)bUQ4F74olEnQX{bdfaVxFYg}b&*IqbD!@>vkrN<_>}}BAeuSq;>PZo&*?TG+X~A;KTpGKy4XG%_(~+1BuMg z=M)VkAnQwCO>|7R(w$6k^6fAz(%_PULZz8>#4T5$8E*a9M)eKQ#_*{JfGS?alMO@t z=#o*Xn%z8uRri2dWXR@as0Gx~k=+VFmWyL75twmiB#dCnOPWsA$wX%PK`@-SM1Y?9aQW^1e%hx_3^M%%=(w=Z<&b z0>vU5ExjUMTMIG7OpAA9WGfneaL=W;eCCZ9C6e1puB}|bBBvKaZ`AY{iKhVFq zTTl1*8=13VeQ4u$F55lnW@=^YwHWgTN&*Wp$qjkLpXtAK*o|Lc&@OxjIZ8v8tec2C<`Y3E zczIsm1{zf)->rc{)`TzD(3iv9@K8YKf42}KM@kG7OOX=`%1*uzLUuWu1tEw7YfL;H z3E{ z&OPUyY<*ApcB|GCVfi5&j|w9c_n9+d>ex( z4B_Bp{mTpZ``*+I6hKDF!af)BwHzA*Ck8sRs_%$pw7u>{LnD}bKL7K`TgCK{wy5$0 zJV>(#w3vi?rL{o8HLIj8-6-A8GV7f~4>7pC@ZGs#xHU~m- zCl0Kkaj3?xYFIqv_#19QG_L<=I~^Dkm(7vWTYyYX{LIQ6kBZZ-s~O%kEZHJo0-xf9 z=(j-?Xeh_f$ya;s{c40Z>*0+RD26O{4l2m2@;d{GejD!-jonk&p3k0(qZw^(ROC%t z<<@@vrn_VzW3fh0nFt;Z&uo8t$2FSC#N#ATOE)f7M)NteuinJD4;p!}PKO`ymP3(; z-Y7Y+E>O-unYIHZn^7U1f!lI&ttqv9R(5?4FBAR&LL8J##w>rucG~iyov#;_&QwOI zJ1JGFaeTUQc%C6WS91X^bKP4SZSZF>8w>o|zuD;w{B|TeZItGKBVtt>$Yr9{^mEz! znl&ewJLIj#;s}K#LQ}`}DM9$=OI^d(0HvP-|L5o{Omb?B3Te#H>52{=-fG>9qJSQd zN-&83`wN0lfg;_;Me3$n`KsF&B5rE#_@j|ie@TH>G8Qc&DplCKbaj8vG?I3;=QKqQ zIlEJo^{jx(Y^2j`oB0K|vI$0|z-X4q+PQ+b_?>Pd)6pgSoA zC+qhppAZowIda5}q4T{}BqdRp_sO{4#k3upz9#M}*G#Gpv4&+EHvtzLlO) zAo@;Pi*B{Yy*#51GfLaAnWvbocd!H1qyHrh0SEdhH2Y8AKZb>v7TED}3z6Fxc4gL@ zHAoUROW#RdGP>NMIWoSDW=|JV2&PHCbgSc4CHRpYgdyMOPU*p5? z27lhJUQo?TlMdwe@>zV6GU(MGUHOTKq)iHjPQfsvcuu3A$(!QS`p?sI!Xt28h7p`%7vpx%;d3Y*wk>im<JIW=Vrz7264fc6$5<+r)^X#nyINwJRL@lq0=Hg5!(W za{_2J)(K~SdCWk^mqvY-jZqVtOr;}Vg_lko&gNDeJXsgoTb8{p;tiw6r>y;Uyqcj- zJpvNclkhu8Az%f@>fM#V=?>1n&-0&dQ?ooERpfOFWj86@+iqQjsq+o3iL&dFCj%&n zP{(f0#gJARIgsDvo`0g|{7Qd>=FR1`FMPQNN%g+~A2pfp_fhprie<0+!c z*wLaGsZ9s<;~kvU^D=Rj0|!Fwi?8-0c&sO;{Ti=k7$lH^A8+}~LfLkq_XQsO!YcNO zok8lrnJr@Jb^PH?o}f5#l?P&si6)L0cjipi$%kn(W({yXhwXnPExCw1CmzP$t~$Ry z!qq@o`CQ@Bc2#>LMaHTOW4K9oq|FyW>w?CX#~t?ha?wzxfrQ~-(gdbVOZJ(e7TW+?|%+`l1|lEl{Z!=vr9k zEOJ!8qaHofXLV=4jVGN@M<1Yy?*3pFms%%Y{NO_-U(fz>QNeq~>~4yuY6SkVOespG zsdsvxbnXuvzcrrw68^hi+sI<|PBio_nZ0ELn~TSf<*y%qO~pDZ@EzOEHk?uL$6OWxmDoAL6o3AnBgjN*6m`#!@^EmSxNE0*v2p!X2Rbe z-iHP_ml^`C=S?x?uK^v zmCgm%H82Lwa55yD5Wjp}XWzz&tFbT#jJ?>Qa`seE%jY>s-fLc`3FWDb#d(k6L~#hqY( z1%j-41P(dl-uqh{4Nu`8m0U!rG@6Osb{Vfhzn!nd6%LAfhR1Lj1ET4uV)gr*G>?CZ zf)@20!Oxt*ncD_Bm#t605)&Y&S;F>)`o3aA=OjvoGE%vSF@AO@=;ahaN$;(GM1L8t zj&$J9IyvqwVD9i*8p*A`)4eR!+#sh_qphJqd`(vS-1}huW4X~{H&_NpdX_o;=~QBe zP2+oBL!Wf5!oj*afO33l7|(7;m6PsUR(b|~yye(7|rq`4PQ@Q`p z#r370Xz2G1+?rbVr@Y1R%0xmzHV!Q{;PLA9qoFH3<~I-oP+T9zC@_jatulB8WIMaF6?-UUwgM+H*^to##jXt1@e$pA(6X~1Y4OUmn9c{*|r>@wutKgYBNuPZonuaFLxge9)v1deX(krJ`*20V8A*2rk z)bz|2irMy)P&KkOEeWwZtgsuXU5aE%LGxIwtc0lv7BGkHktltnj&itaHKZo|aCoYg zv};aYk8HmPDH+E(l(Pd*@xPVB_${s`M>f~hUjxh&%I1nybY>t+d8e(H$y}C6)-jsf z$i)%`wD6E^9jS7L+z;c305Z3DudPY^w)~G<^Z9K^_-@F!oKbi)vKJiviTK@&f#rfB zaOPOyrLX1U;a`7{f5(nG8`xb|+s#j<*$r{}taePxkR3_4176V0={3a`F8vA+Ifwuz$vdM^jhhuzzhnn16~V~8&qxdW$=&>77V znx~@OSK04C#6`9yDrTY}!5Jlx6@{~tbry`Ane8g|jzQa4NqBi!YUQ`CHbgTn#LHOX zjO*voYJdOC%uJoi)Y(hnHe-8Jo|x{&lRF~T@=FG=^r^A=zo85)YR>Z}6X&lk$CdAn z@+5A{I@(~4YzBQ0G3-XxG63y&xL34rX(-GuqnJ+kgDn5|{X)$QlhND?P#LZ?kMvHA zvR2)Ytlt0P(j_qR=gbAaep-hN^;LXJgocmGv1DE$`)Rz2UIM5SzxKNF>N*7M_ zEXU7#3l7_fra<#vEB*fU>FJ56bmw34_za8coZYRa6dYJ;a8N5)Xp?&JK4zrXt~jNm z=VnI0O)IUM2(G6Eg?7_sJ{BoSl8g3&cgr@27dH9cB+X^?%EW-|7UQw-D6Y`NH>Y9} z4;hB+{``!c$Y|wv#t?3t97v);GWPyFHh4So3@snW-v<5LfRviL0az(@iqeeH%=d!{ zE~&UUuk?HP?ng(6f*wokbWoGPri|s@ceCQFOmWx2kISaFGU=@Tz@OWnGB7AEzqNx> z@~X_Yo8m_Qpkh?gCG;eXg6~RS)0)T{p;8X9JwECnE*mAI(S5edPEHN@Y}d1DRIcFOpbTMMi_#(TtR&?=_^w&W1y5`2$_AOt?xZw1WtaymU>(mZ&7U~crD$xdN!mDEmhJdj0}FLz&!@Ps%`IFVy(pw#6&S|y`! zQQTX%zf_X6OwjCO#8Vn6VW(BHn{JrX-5Bd*8YC@dvNPRTWOsM`HJ!+V;NYX=I6Pv+ zp^k8;Dqe6*2~MO9FB+DHW+YP91HSkod^Z1id;V#q!Gk@Zek0=emQ82YLSe}9t~Kfr zj8qa%*Y254DtRz?G95RJ);b0sP1?=}+mP{OOed5qpRJs~I@dT1AhmRp@4pa88>s$8pU3NZTp$_IxID_$pmfN_<_YkFfFIvjSs!}}U26@}KL5Y=n( z5JbK<0r7hhTy`6BI4?o|s^1vHJonmpti~V79&th4;XA}Fy8 zMyov89e=1IIU4bvMXRG@2#$Tk8I;@iYnr+@O`n{2*NZ3QzTpsGynO)M@q4o3>Wr!p zFcnThU`!a8UYa-@SKh-4J4awrwKYZje{>c9+Xz*to!6`#z3)Wp9|o1>@R=a!HYTw9 zT>4dHVIPe}RNlTFU8!@7fNk;-sAl-+NPrR3JntA${pMg=8BQBIxJtDx)p~eC8-Al6 zzO5`jUJAon63wtOWiP`rYpQfbfgRA@0cU7rVx3i!XTvcG-(`-!TQJ`h(IC=uh+8Af zsI&K3cEMxO9%A!Ux4v^#iXSY~hV3Q3UY7nW3vaGS{@M{7s>!|>EJKigF$oPFHT=TG zs1?3Mb;mESL==)vHUMHCPFK`UnELOZU!;Q@EM&Dn1AXuMX0~Y+NWgjCJWB?F_A%F$ zXE8mhPVEaq$iQ~-C$HUvA0&*oEsf@9WI8NEE4JxSLo(O^uBmjnxPmr6DF#nuHxF@L zF9gluqANG9GflXZ-Oq2g+LSv9AYIqMb(@S$`oay>#PgedCaUwLusTn`ymN!;Ixj=3 z%XD||*01qURNYd=ssg=6ZoEPN;!NqAlvXHFty?11{*E4={T)3t^Su!?m~;XGI0kuS z{5vN}eD;xU()oXQPNOb|cVOyI;}=Ilhx6huYi~w8cVrujnY*F~$Tq$< zng=s;Vk^StP#sA|pB74FxDls4*B6P!7pk9g4$zt@f15)E!8EJ?-^K}@t__(s>YJQo5(gFkL7ZVu^22VZDoR$Gz0^fl47g@a00ZtK&pS3s09wN?y zQlTgrS+7Bzu{*E2JdXHCK@S^t{Z88bj|Fq@??3f7-TF>VV)`XeT?&edr?z}gQYb&c z*p;PvZ{xKwy=jdOo>u;&;BlAd#JF4q1J7Ttfy(E8^J^Yfy*i21X;I#~I_AI%uuH2$ z^h;s54aTCMozHW-z}UY522X!>#y?WYQVG)b!o$}fN4zltm8MQ zk%4wKCb&)=3`D1OG4(&+F$P6*q##`cpIR0hKODFHAU_-TR+cf3P^aJBGG3GrEI>0G z<&5cABpj@hy8p0Lm#F;N_%Zj|?nYB78R4|}#yUj}6gTnsgGOkC_{cr zod=O;0*E#mAnD!+^j)2W<{H)Bb57t*EFm8UQmQZ*GkQG-M@#LE8se{@N2e) z;wx^;AD6|~b~`p#KKcHG>fr!ab9jYo!ss}Eu)64Z)$=zx`YPRvy8|Q)_`SmqnHunM zQ9M98xCBTC7aGeV{#4CI#%W9sqj*2%P=o{$Pe~)AiL)VN{gz6-MD8>|Gy> zvB?g1Nqs1OTq8TehEb6kY+kk+U(Znk-z)8l&g4AM;F<_{wSBK1EXW#y^eLxgG85)6 zwG{!i^--R8GvBZ7kxH$di<)!Pvz&f-^Ev;4Fn`(G>XMC(aTRNq5OC7G z+ZJj6W)IHR&k4eDXE#o8!E55cS+0Yd4HkKac=T0Tg|(~4Rqfh8dX@9O2jk-wIQ@iF z(vq)8f|;762}_L`31e2y(Hw9NmpQ@zbN>)2VUsdv55f~(>Z=Ko*>Nv>q&MRNF7pQ~ zWK-55UPM>!VB~lYM{`P7AsS^ZL)avgi!8K8*PtJ#dY4U2TZznV=DNXp3`jF>n&1+Z zGmW*D*p5$p#ukY>TKpLT@mXUcXcBraAeKvzeRXn(Kjbi|#Y$?2+kox=*bgA(-9aVKQbFWA?ZpUQ=-vGSpfL>RdElHDlj+sT| z6(&$_!kNG*KrBN1l@=~nJj$^kOYXgs!bSXmYvm=^fHOgdA_r)ES9ciBtzn(KHl z&+SAi__jz8B5X}{K}3k~%to9cj!-ekmmV|-KOx|=ON;k(8W1`j9XH~A;Mtb%{~v9I z-_zg-WGxzt5q5a&zPmfK74;3{r?&oCty@xMy5b+h z-4lBr=g~$=p~SU2!sbU4IVZ10y@5RU#%Y;>o1yQ+4C9PXAiOXW!P$;$EghnwB->@;z( za^MBrNd`4%xg`W=^y=$uSItTD#GBeBnM3XFzc}e%V_M4OVc!TlJ=bz_0J?7sYUqW& zx6o!KQF@K*&QKncjw@5!GJ0P)QAI<_mw_v&jw6bx6DBO?q9ej>7Y(M~#-D86<8f-` zudYK*WklxNL#+n73-M>{#@BR$t~@0hWT9WW=p0QpO#U%zn!1cB#ngM@ag=>WsUDFK z88Q;>$nPcE^e-CkXp8&0X*2Nnoiv`3ZHhx_qgRq7l=cl~)(<0;9k%dwJ4ckhd=XH* znMVtp%UtWY6*ZE%?43Ks&WoszU-G2s>!d^^1Ec9IQFUoiZB|_l4X`{<+&TNH@?_rk zfa47c1(^CY=}06vWtEBfCe$z0`tA`>UG|_i^;E|PGsSKLO8`PBC;E!DkKc%ZxQQF!H3+fx3RXv-r=}`33WqA- zw4n76nC9x)&;cyqInAH$ulbK|$#3aG!w^CnfbJ?PK3ZTQqi~apf0KmLr;arSM%;~s zy`xbB+Mm?keNKc&$WuaYmG~~Jc1nTnNNF(^QE|I=-Q07zm|A?jDIQI0#Ydp~IupgD zH)4HXr#RIJ23v3*VXw)r=1M5Tj(LuMj{ydC_RR)rtLNTdT#gXm4tbe~v5xQbZ3F?^ zo%F|dZCN>hTkB~_+Q#U<6=IA|rmxH2nv589v>E{6WXEwwXWXm7>boZHjv}R*jF-R_ZB`aVu8CtMe zDNu&m)_S96!0DDFx1*4neR*`T$E49;q2KS zK${122W`hM!wS+WMYnNFnI?C8H*3Sr%jRr;fH8MY+i^YX%^guAg07WFwnGJn)I@)t z!3!FXXLrDDN~2#@(58TzX9E}Gi{0eV(*jAuO?_}&IZ|Mo07*vmJ`wF*PM7>wYf5_SUxh>pa0eaTqNK=faZq4QU%9+uYEB<{(lRDj`;z{ zw%5T@dz5* z@j6Zcl`S3A0WnNh<*s!g1|F#3&egHWDzbEDk>3(#J#n7ud)fdbR>iL; zUKk;wRR`v($P1Y$hFH6hv?cowZVMDi@(sY(2Mg51K^i2*_k8dc!4p)dRAaweZs85ERM1ag&l7_~*M zW6o2+wq3w3*`;pjqkd^wBG$XBnTy?`)2{BAHTczuhA$PHTxH?%tk!c+kU|+zHKSm65va;03cGLILB-_`!ncJUc)7?L;$g z`Gcrcp$c^wurzQ&3$9rB&L_TuN_y>zsC6=VA;qUn%neAb3{8_NjNh0Qm4PDqUggup zP1xV1z~mpan!$4es=d&@gN3Y5U*P<);L6R9FaR*Vl`UcF1|3eq;VDhF8jea0_23iv zdt|`*bI+ANX!bAkajN#sak7npmeE5;JFH;H@OLEF0f z8Yu}GpQVv#|81QZ)lB6sjrBXek@p9`PL@4@kz|%=a?|1;sg<_M+B(8X31%^4j?U_i z&@Hm;m~`DNElPh`w%hui`spnLd}wkSdF=3o?S2o|JCRrP{rUM?%l64B)|p+e0P4?g zhe&Ly=lC4OV}+HWnf0Pwspz4aY3IrX!pyoZs%J9=G=gdZ4xT>Dy&%zEC(@(#ZB;OL z9PDz|dYO5eaWm^7OH#t40x3*=PlD*}FkMq>a zGiAEX1PrskE@WDtf1r0ZD375sP15hKshCb1(%Gpl;vTTdkPc+bnS4Rz)X(8EbbGY) zZBu6eS6~ak*!#O@k&9W*S+)TJYr&0t#Z`thmhpT}tYxwAgth*5Gnv^W_V$ z61jrB9SmDYz~~RxVR)xefuP&Ku|OV1a#z{KzD-{xItncghI6=d6Rd&P=rpB9 z_}n^O_6?mduMS0lhTZ?ZtEZrCjA3VpJVgo4F?Ey+6i<_2xL<>JP)}(JkPT&M3b}%Z z5^Q)4A^&B}cmNH_wEfaxkLuiFb^p3#!yWU=$~NIkaGCp+k}fb}fMJ|A$sBg~tu-*wpJpoN^7Wb~823 zh)1Z^NTa~HKL*n;)?c_pp8V&X8P)0-|b6YE5XaV4fDEvV-M@L z1YACiz1LAgRFTvMnEk5Zl@b464U|{kwng9iEZOif&%g-O@NC5Ot&5=9f#h4I6u=Qw z^_sQ6c|s%7hqi=v7+sS0eyQ;#?;ky=x!3Afv9-i}FW_oapp%trkB(Q!var(3(Iy3a z7SArzOC-GCT)92)n6;^CWTJjD1BAFCR^Q8%)V>Aj1^qX@zhBP~| z;xWqH)h?gtaKm!q!vEox9q~B5+v zM^tLq^m%7uug&YGS8N@-_uigJ8!mUvQ#iIAvONz`r7$9CpwpTFihpN`t-%X|l$1Jf z&XJz*-&{U)IKPcVAS8J~FhprYoU*Rm+$HIFsEFkbNpjc3;7f1^+sT#e2!tCU)PLD5 z4QF*hx0&P@lwRcTGNVPjy++wppL0|uMvxkXA3{XARbKyx7}^DRG^Mt?*BTN8LdR<( zrp8(;jmEu&*(21FlvH?+6(FDFP?X(10B+RF$``tOn<@&iiOyyr ziq(ibs7Dgk2sm#|4P)dlk%U&e1TdZGh)}|SHXZ2`J+11U?aq!xy9JhW1ve7zY%cK> zrP^xjy=SFWNBor9XP3-P5lYSMrJQ{{RU3g1?tP2Ew4 zZ2Vj>e@ zDVOE5>u$tbuBNQo{9qtDcI=LowAd|Gd}i&hJk=Kj064NCI0t>QzuVsRIOaeS-i^X* zOfQ{2HzoQP53%ZOcs;uIk5o@%bBm&E)rqOqZ|TJ$^@x;E(3@Rg$!oVTXT$~azVN%xg*U0!HJiw-Abp59()p1J%Uff z&U9@fr-$oFuO!n)cMe%*kMG+=$67`t=m~J zD=fy{5pY8%t?9#Zi2lA?*VtH(?cQk{ru^~QUZaK#+901J5qTe0l3jkhURuf0tCV)I zK2;Tm5LACQfP8f@aD5+n(x+L>Ev-6+wB8DfgO$R^z)`U*`1-l@i4I3L|BC~S>Bu5B z%SkHW6a~h8Rzuu>qWRW-naRe3U;dC_*T%~xO4FHg z{$lH-yx>r))I=Xx1?`(t7hHvRf8J?R1OB@ddGPnqmL@Sz7H|jk122+EGxMX5j z^vo>XVcNVmfgMdM=d;xa<~CC`vOGm}9;-8?zxQxDxk|*wjdq17UkzbOxe{2O9x{b= zL%}ZI07)s#vndk6L@OW9lDkaFP$_GPrC)WXA+%N{iW71#T8by+wRz)MRrP1E(lx8q+(hp!8-|!`(e7G>32fth7Pnh zTOyBtOtJ931h??d6sD11CARHH4_0J@>{3|gB!es|S0pMBpBf#4H6HuYl|^>C5yas__GviaXGcy);=Y8#&dFjp za2EeNgbf4B7w6vDP=}$4g14KH@xAp$F*xP8^deSbp2W(8?2dS%Q18b4)y-Q8n$x3b zl9+T=e8eWg&V^rGslGCh%i0)|O;LW&jbrfy13&;l7w?V0iAFs)6O7+cW^15huF%N* zO4N8V24ghr{u_@D_V<(l%U*}^k-nk5cW3a6808I#Llm+~K>-pv^o5{GjV? zIERDHfA> z93kQL(&kpvRBr%v=xZq|Dm) zM_}0tF}n<|(sgDP-?Pz@hxYUQkG93$O5xvUI#Lsj_mtRfO=4Rzc0u#L+KdViml&~U z0*fVgz$!1a1ymy;`{$umwP4Sp^*M79^X#Iv0P%2oS=%`pp|C7%JgoG-OgfY zaOZL9)_8Ae$p!>jI_~-Tr_##4@!Q~O>6GcMUXdsxolF*_!Zmb|TQDZ#F}-o%x%6^b zK%fVCX!?~Pr6UJ%)Exbx76br)h4XGy6k4;x)@L-)PR$ImtwM;toQ%+Y*YU=GqS=&% zR4>g$Y0G=bi<;bfHG5f5dd=fkciYNO?~&;A*GV@o9TFO;H{vR{h53q28TlBCRRr(l z&$96C_HWiM&1-_f{xEFrcUUv>lg}$6T=~JGVOp5Y*bhwhkO?(5qEp5WWj{^{;nIQWXAiw>Ye$FLay9hQ#v)*QwceJ$pyEF*oWm{mp;L zXrZ#LVJtYbLCsU5vEG!?=%ZyQ!P(Nos;>%uDX^EWWz$83r$~)Hxs#^%mc?$pySJXJ zSN9Gs76ep%0D!MYcj$knRv~XwO7j9P!$e84*AdFCxpCd+d6QPDy+BV#0wf|f0#ZH- zwIj^ac%GUKgJ3ElF?^G_#&-J0V&!D}+fl-{7ltj~qX)WHCAL`;Ew2?gCptCiym# z{0;hV9vkSGbD2=4PUnD{Ka(A~WyK<&7)(inqlk1gpCLeSWnTDb>SQknyhrjsrqx5- z+y9!*8YluyKE1m+CJz|Di3k=?qo2(>MxC$D8hziW#4N^9+0}t5pZ%o2JfJ)8XjP0G z#bGuc5ON#)=#E&1bu-XOhJ09fG~W<_D_#-D@0>T$lr&m0+`?kTdQ!42^7I@z_oc^& zm&zr^%I!AGk_`$-AKQA_DC3PPSL)9U0YF2!LVg75PA|khM*TZ;8$Q6W0w`{o`u7e- z_B4g}r!>Yd{%metbCso5BTA9o^uQK`5c$7HQU0wR-gPPZ?EG3$`v+j}5pcTEr zG$hr1_;aPKs>VlG(8e202kuM%xbCE)({86Tu$zrJxNeT4gGawCBbfdFHMqsVc}@yo z$I;|uZrfk54sah)ms$>bDrZ>mRVS`d^Qpjtt>rW)vGuKqJh)CEdos>*Cw=Kdas;jQ zL}yox?`uOyI1KhN$_Pq?Z3>spL|46pUZo1ttXufGWLLDbtgSw3p~+bHUI}svU^i~r ztDQV61B#nSU1MJjPrAEh7YV2k$b+uenwPqSw%R zl+$Mg-=W*JDKU!P=z0z2wZ{4GO~P1mI6KZW&-{ydyt0~e=ML^~x8(;Yp zoVzo0E)?ZeB<^k@*;d6hX5j7={;<}_msqGsCDxLIMdWU-Jm2OjzJq#rAg8=`L3kV^ z)lNmp8?~MH18t`Ry^@+M*N=~{SeNc-l^*o(txUbpf2&)ZA>Dd5#52;gH}7ESgA%=d zUWD}?;Fjs&pz>}yfH>tilnZ~D-N8ZXr+!@b;0x+!#H`fKjo<-*%AFbD1^C7{k`4#C6W-(Z0k zyr(r00}PM_>dqIH+9Yae8#S7p0>C#(s|9Dy;u@Dh3g- z-L@M%^Bpf&mzF$H(4gXcqKWJ50(RQyo9`j(wEjuv?i+@Fvw2^ThtmjB6~)UrTaqRLT{ET&HKG$5&1obu_Lkzq#+=)9Sg6U6gjS zYvp9R4!TOI_?>p1RUpaM*jf0dbhGtWO4$4d)759;vxRy{B>4D$OUSWZMWz?rGs0sTTYD6u~#34eaEo$?d z4!IZT)84ER-qZxt=zVxHI+hMkgSuXnQ{j5(u~EJ_NZUo;KKxL@Dz@r-Z1V zqOGg8YX95nD+jEE>pF6A`BF!bVn4-NWS2bX0@ZV(n+DCIqw+llZ(pHhW-XO_#H`7} zTH$nAnp#wqB<`kX?%F(usJ8BqHiO2yDS~D(yz@t>>nr?6z-r|7QG~3d;^8gUx}9R6 z*&m9u{x7q#5!8+mbv>==^5r2o*qJoS4wvnQV;0ygzMd-08g%Te=uTIze8f!e;PgH7 ztb_DVr0Vwi(rv#{I0z}&yJ|bkd6BP z`oWV3RP$*O!!cug7A7q;NiW{lUbCJnJ-ty)z*JTz}Nq2s-D`1pGVuk|>he z?_W>ed3i`^BjqW^|8rv#Dj-`kU=o^yu=jLiqFKhvV{&zhb3I_X3idvT=xc=2(P!EM zIFUC0*JsiWpNSgSj~2>5ddhE(*!_ZYAJ1n54GnM1G>+FZH(syAZ%KRmIk8n3l$$+Z zAwPCZ>(Ho;?CHm(%tkyn^&~?bTk=_}#Tki^&F}K2aK*s-|3vLnOu+65*LhUZyyOZ?;#@->g{R@%p21`*p9^_;;bhMsyx`=>90E$cC%u4 zrOB>Ko<}Q55jo?btF{|g<^E&0cGz916It2X1O$A*)x$R*F`SxGA6mUfbI^`UlUuPd zV9x%q=A8LWC(usX9r7LEIs^thZHi%zJbv;_J&5K|GE`Q78wm7#t2?7T@-|&9zuz#% zTzG54d21gZSNx{By0hps8Qeh2P-vK6*aphAl~sLiqfu*)VUs5kDzLT z*RPx={GGG_A4=F-{MLf?)-ETr>ESIbwGmWgoVG4EwG6cV2AFPAA|Nko{q0ZE{F>aXo@`n=_&I#j8#W-LQ60&N zaceKzA*tVQJ3t&tyHmh_<)4)3);#{iHU5A~oAxB{A;~^vOavS4P1Kq6s(C zEpwt{e)J4)J$Dw0S4I!^HDZ{wC{Lcj`dRq7vn@%I%vJ0V`8MzuAYe$izq{kTve)LV zr6zHsO`8H}YT;f`es3fwHc0C4%{!+iAW!_WFhjuJ3M6?cSrLhdH>zb)_y%4J*BIFs zRY=&9_1ak_7Yjf25U@<_fGZk(gQ5q9XgvBt1Mnjzpe4$q@b-_&yQcuCRxhy#*=u+c zB+mK)yUtX82NWgsaO@54a=qH%He}btTor>Co?8bNo}ez*Rj^ZR_ZTdXeR+VcLzINj z^5ar?+p*qWL=uo&2kjf=-_5!q2#HtRJ#%@0sMUwy`(qDF$+oux_MKspP&-?Zhpq#m zT)KX>SM4k(0$XC3ryB^79%PTp=$dgu=G8p~Wn|!Y=24h#;b;>a*9kRhit*LO#-)n_ zUlIF+TV{Rw%yJxf6m0pe5gjlxFv*#P79*L1Pytsr=o@lO5B8S4&+f^!um1v(FUr(W zNR^#Ke!2xvb4I8aiP;c3z3T3Tgbo4J|1^qj?62MdLqhJk>mS9aaStK=9c+3ZOv}D@ zA~o9HHL_cxYOD7Q6=HPXtX7;vie>1s-(6O})L%A^qvblN4%Y}Tpk~{Pzj}KN31Fb3 zhv58Wl_5!lk~-}9@%%hFnUyO9@YYs#7S^GhewShomz5x-eyF-FF2Z)BvtoBca%GvP znkd$~kz;2AsM)?qEmAnw6N@y=q2$EN(PV_PFc18DT@jni-}Q%`eCR!;B$aG#i%9IM z-3F++->Ig{e64d@+W(A0KA8W{VdeWZS4VM#Z#x=(AjV{ppZ@NfH@2?dGuj@1JVcW; zVMpcmt|ROOf>?Yj^_{^GP1XzdnJ&uD*%9 zrBS=r*BBg%kk@jqzSZ;Xo$1z-&QxcShSxl3_~N52)a!|#)1$;tr*pNwqv6v%Sr$IE zb+8%YvFrGYeLOdIG~x=4EMY9Jm05Q#qcw>zQW`O7dAff+Dz-+pcrk7NB+w2lJx!;d z+ut`1_r#ysS|ZB{$2Q~e} zBea7RRxGmJ3D;(J*gI^ZP5(nlg-yQ8q#K{Zm;h2#TJ71V3KuC>{&Ewm+d$5?m`Mf# zc)<%$QXm}bKA)#yc{ATO0Ig3475n=$pBGN<*HgAq1oHYYjbUN@IGhg2fjxPp(3qYC zYAkm4D?gyUzygT;skfU6jsUf_xMLUjQBjVKk@|U*EP@o}pvO&y)A1Hu|YVRpapcm`0dMV90n|~ z39f0pg8P~S?rV#|qL$5n#%*9NQ26T83iFr{lb@TFnAP^M zuX>@QZ4hE|bVgM`9~=Q2FyGi5blV=^|F`S(`=>y-9@W1;^53svK#B41|M~qt`XK`R zf4=tE0^}S2^EFZKzsuC0zoH=iYp(h8wXooS)%JhBbgW+D|9^Y`b@e`cLGCfAtw4}+ zZ>#!XM?_z{*oqku!X5@d&>)vi^bnx)??wDJRy1f839ORA>j5qt_EU0)Ks*`b?;nF^ z=Z|Fz?zPvp+iRw-zrMdPn(xa)Y}80F!JX%sFw@hu{ep||+g*_`D46FQ=G{)U{?2$8 zRsV}pMPBDk1VZ6U^VnKJ@c60caOj|-34_#wdZZH6((d{F?HEwd{rf(T^*@ICy;>Ri z=6*Cd?ug503&4bevd!Sqg6|S_LWYhjjg07em@ZKtVhWV%1OD?4F|>7B8exV35mx|q zc$!o=<>ARO=2hXLIK(dz8R|lg6BYC~U>aPeQfg;ZI*(~G{PgTRVtTBy=0HY~G6);G zF4b#q8{-Y>Un*KG<^OrRhWYY&?UsDaRHDiXpxp3S@7749F}6H>W>VgQLh?Z%Wd@RW zg2b5tmp9ee|1v9LQOdrMutNau03lU^1_)d``n_x7U|o@dEUq-qoeg%tv6(-f1i63G zv=;L5{9=d9hK5+yq|xAmHMnNThee@VkmWa9&e>UD)-4122vOfa-lt3vdq+>tVL`6yA(|#VEfk)7EhGYGlSn&;P3X!8q-fF{*97fu|J`jF2@p*QVA(_h?caSehPi1-MErNZ%P?fURB{PPqczU@+2QzpsY|n-0GOMOIXXG3 z;5T!>M{lbIE&0*X0mwp1yQ~Ml5XBY}Rfhyg*)l(Ds|_YT3Z=E2TAS@UOy2-HB`qLr zqu~dCT$MYznJL_WhopXv&vxt?=x?@v!Ayk}+@dk?9Pm%AN-Rx#BC%zB9|F}Q?DbdQ z?Gn?2-XohPXvVVU2s}sviUUs8S0f76Y^QSkReBy|_f0=CV8%GP}^R!5pr+lC1p= z!S3RyF#)nVMY@e22<4^w+jFQ5k&kvCPvie8%5^GnY-LkG>}~=h5To-NMz2tTQ!_;my|suQgHzDD;BFr+R4<0<9VW#XJ{x0iwU8 zK%oV{-LbH!mxkdsRyIF8K-mJ8y6|&UV9(U4?u$*w+@X%90!;PGp6d~>s{on&>3g)s zy=CGLQqKdfBmXphM&3SJXuekn#FCgZ&6T~@{F%b0oz&m}sval$5(NwPv+mD&vF}n| zRIa=Vz=D>Tqu=oX?u^^M7rfX}0^72Kio#oBhb@)(M^76lk$f3w@HTK5dAKYT_U80( zeOHfH3BPR=cYiRQ-B-bQJKZ*OM);t&dLOrE{ZS71CN|u38P&23cb)y{k&Lz-m^)jT z8UjsJIq#wPi)S3bGesOUsA)9|egNVJq0tc+PFg{(AuGU^qPnv+Wc{zuq-aYxK@Mo= z?@1mt4VX_VW_E#6Zr6;S1^lgM#0yCH=;2r-MzdD@yxrH0T$dS<3&Yg~?T+WT9)CMJ zTVYP}o9H8HT!k-c89is@*TLoC8^tw$Ub5`hlJ~(J?WUQkE&|nJP8ManUxK#v5oZjB zKp@VU^?Nr#Q+5L7g15N-Fb$js!fz~am>3*Pq6Jl=aVi%TLII!!XVlwd`I=y2Gv$OO zH)Y&R5qdsIPwyva=O>$Dd0G$P-aQLbmyTKZfW~ftx*jS`X8+DAEZm#xMCUq1@t&;x z25U(pksK?F2nbnYjVo-R}8E*+Lg zhG*e(f)ZNzm`pv8;zZZTw}yi(O!-8-9L%u@j}klRc2M=Y=ZDfX3c?we(h< zN@kz}B67^fY1Ck6NcrTqQ)|i_@dEe}os)}3vuOIyrORZDTGZ|Q&sIZe4(rJJOQ#~l zb4#a)fRrp(CWa#q!55QI%VW7L;6v~a9&d+gV!PmS&DG=RLFkb@9J=SemazxN3etS@ zYOp;FmbqHHp=?ss6sz{P&D*Z1APW^5C+DIOP{`B+l)?qHm1>ABIj&k{t5KJy|5ApiW z@kI+-l}_S~V{gpQwkP`_%5fL$Rgkbmf#zXbO|?j}SX%ts-A-u=WDk@*b-tJ(F+dpl z$WohgG#CCXG&O*G9!hzBL)8%nI2p{wL!onzKR~G_WPkbw2;Xz_cWm`u92(0s?ni*) zKYO&ia+~pjJ8?4T-WveqI^zz6%(mh)awiIjw&BK|Ky(=pWF-xvt*O3H{$pu7h64qr zJ)(C@n4d?%YWpiDojqcIo2GqtyJ-w#)UCqtCM%$E6ff9wi~L;9RM~J}1WhR811s13VzXpfkcqyA(yfn`L5S#LHHUCh&@EIcyUmmf2c`PJn?+OA z%LS=_k&Jq#erbp$3R?9gt{gSY3YgBlo3iO3yftyPE&q|~%?@|1MOW2pvm=%6eaf=X z;YYbEIQ>?!-c+y-*f(KO&A!Eh0Em&(@hzh$Q9_6sO8OQrcT^dQDqoS+67$CSmRp)8py!~-4Z!MgVsBeqeXrJRxtk%jmUyun1;X9JT60{w5V54SwyO_0o z+&<4g-2seyAEcpXc5Ae7Utw)(ns8^q@tA&Yb8ou3Y; zUFr^~B^?d{J`@+%jDI!9%kx~Xu?R#fXK6|Nf zC#SrON%!RQkvPg%`hgP*_abmqHWmgr5(kkQaTlr${f(9u0{lwipUsH~Z$`Vql5u+^Uj`KN;QkLW7mX$afviQtwGnq2ZJ~7MrIAb0pNT7{I_6IDY?J18 zPp%C}rS_g!Bqx}y#9vz7$gD|zi19~u?NXE1;ydeMjZJP#o$f8kDD$=btl3x6gH)|? z9CgYlpqt=Q+Tw8E=s0 zgFj-B#vZss(&5BCK-}XVv{-%M1^Ud;aZkf-Uo6hhvfY%SRx#Jw3y zTacO%Ksn$p($4kt(v?$l>$#(WAFFoMYR(mnd`yX!Slbr(%B7gu+iw2mZqwbobO7_K zi-$?c`klSKnNYb|B~(`RWc%k^NxY6;40cRkSGD)SV>P+?Q`syjJ9Q2r&mzPf1k`F< z{N~(Yvlp0+=H6HuwYInj6NR6TSJD#1!o}y6zF!X`$bppL4@+)&AP)3GX~~t)UhfYhdKUei{W0KvX+e<(G4>yE7cL3V*1WrEaSoTk*qyOhbpMfcqc+b z3uNLOD>r(-a`(G2Ct3WvSek3q6bLe3C7gNW2?d~S6M2T|JRDJ!hE`sf5(#KQAH9y* zE>kSETH*cD$#SNYr#0|;1V4Fa(F*J2-Wz+lV-lI`!4i0RlIa(iul}fdcVC+|_)^~~ zH&XHGF7M_+@yNT~x8=s3S1G*LuRpVcTTq@5yh+d>`<6##evTmLOHeDC*%zwRFRX^g zkA09riqMP9>!kBLD(7!lW?Gn*J-=`xaK?ET+6%V>Abz}|IZ;R;jqcrOKTI4q${(HM zaZ45H+swYia>^_=wc6vN8jElu3>I2T?h#dPu!RE%(Kkp`O;=Uo(%Y4|s@zNJFJqGH z+Npm$2NIMgOobn7UiB|Xjuvxxe~!?fThPj(Oz&# zxx!>m-Lodz>onXCfU&Ieb1kFB#jF{D)rtU{ypOwk9SC!d3edS@lG3-EXAyGc{%aI2 zj*urCLQ^2m=@ESGATKQmRYH!^+3s89ZH)N@xx0KOo=tAzcB#@IWLPkG#zfVc`Q7pr zwTXi`0l2bR8|^tH?;gA8AXa#LLVg zQO<{DH$Q5fUUOgsd(eFCbs1K_mp*_b*s~HK-MBM4zN8D~*c10Ioj8p`V?Hx1xp%Laxs{y_= z5pL8mpbW)Q4vAiFP^vYR>(_1hlshzSY(=d941A@=e)UL*-i8fBhY^mH!i$cLwbsT2 z7d>gwfKb=A0L)UH30!;ZsKw-}{*R^89QwPL!3%4Br1j_UnczNpDVody(~TBP^1QU| zL$`=y$2?#6x_a==w1P3F%>FOi13R{Atz)es32S@k=o$=QkL{3M1xC@(zFtc4XKzUJ zbl{NoTdM;P9QwH!8(+wril*`u4&k-R?oZze^Z`a6i8%KDy1rEja8{6{poTEi8*Ivq z#R1%I47Kj3%xB@xu(h6n%e5sHo*>6stSLTjEz@?na5x`gbXA%8j5Kf|`N>anzWusj zQiGE>pM~O?t#~Ih*W_<;ke;Pj-~Vve@EJPpg&a~CC6<6Dg(XL�CK_Y?k>AR^os1PxU!c0rKya zsejBnW;M{FaZm+=xK`+wHR(AM9e^?j%FtFkaJ$II&|9yqg2^-Ro^X*eV7hxIu&gaLe9tCM1UCIYVuQ z{`)fPDM96zf3E_6D-Z?uM)i3vnY^VD$lQKFor|Hoit6MIws_%uLl}>cvp30K>OMz` zKot{0rbPZ}KM)vwR1VfGpPGdd{;I(gOXPylNbltr* zaOLQMe`+LwJvc5nEG`v4WL!?0`vGOqCaIdu@5*>x4`fX<51~X*Boln5D{^-Vaa-Oo z8~~$@Pgo|)pI=93DqaQ4LXUI(n($>TcZCE>hr;Us!t5ZLrg+qR&~H%q$AcPC3{4$+=B0fGC3>i6_~M{O^K<+oj)>dpWJg@Am1w z2R;zy36IdFimyN;oz+Z&j+)@%Myb<~kCw?)%@3lw`41I22~eq1EVN1BI+0sRo%}Ee zk-)ywktfharVMm+quelup+XMONsQR7LEp&(D!&0ZnH^E_BW;6EQRY>zKRVY)-TZ-8 z@XpLPerXirE1#z52I+Et!D@kM5p#?wf~0N30}l|}fPbhQfs^nUzF{#5R3eWLISvlJ zTqU20`W|JlSn|k(u|9S=27%sR%_1O)ybv=`;V~6gpUs+NpM)J}&?Xt$3c=wq?H)LD zabt1ll{FO|4GZTK09iAzqK)A2pK>zDw8)*jvAOld$vjx!( z>p@-w|Iq(CF`*6{Cx*95-L8x@<}y5z1Je{Uw>jV6ycrt-pLLUY#8**BZ+{}|(5Y$S zAVdM5Gn5vS8p8wu0iEZ4OQsXqtoELD1Qk{3X zDa+YCy_nqsT0Lio_(LaeSh6r(3(^87VgJpa6xEGC3oKfY~2b zo1Pd5cBcXJD2`-L{;2eU;FQ-ZiH0z9oHQS43zBFNeYy?IlTw*TR6?h^0+F?+JcLeTL)%fOjB_c{(Gb$fzW}f$)A4%^W}`E3Gxe~WYciAAfM;*_AwIg6(s!Rd;!QV< z8^OtD(yzQmTeY2=BtA+8gzjvQ?JFr%ne_?-P;Chr_8cCrDZzV~EPnHMY2YZ>KgGS4WFNqqt8KDNu79jOU0 zV*NN?=(elR3T0p3hHSZ8*VhE`q1rwGQhXidtl&2kmi_kx(&dto)E17s{QEQxUc&Mo zljCJXXj3_d=28G-f6)T(zH;$V1FlHTX+k2rS>T0!Ht1LJtK{61to%_?30RUsG>v>? zhs70SY(d0C2`^ZMH22Wr!ffg$1doSpA2-#}DBHWogt6L4A;HZsZ_1WRky%7luWL_YC$3kZO*B^N?oU`!$sZ~XAW4<+4 zB-1T!vjl7~p$q1M?m_s#hy@Z}u*|k1Tei#!%frnf(+eq&aJqmF*8mi4RAhXzO;#Sn zS~qY%gpQhPj!Xp3SbEqv89Q)b!$QqUG$e3Z@Vah%_7(c;ti3MrN=FL8g873cUx?2D z^mWknImclgc%H0OhWFElA%HdXH}2_VONBC}6Ao?V+h<-vnb7V_KAgS(=;RXptJa)7 z4DLD2Qiy?#YIY?@$y)6RN2+sKDe(Y3)x71h(w!1W^S%$_q8A zY>vPp28~=EQ$3r6`~I6sN#;INqUe{Ig$0#l@837yhlw z@7EfFXe?y!<}&9$BTgZ{N10vf$sRkqPaqt6omFwmD0i+}zx> z80KDwCK%?s{9DavJv}|L16l5wa*4c86em~Y;?JS-A=~&ucWXDdeXwP=9zmy3wYc%1 zJ;mh!UXy$Q>|3m^UFV5iHNQBI;huU}K3Yn^S<-2lF+LEpjsJ12XnKRUjCH?v)L5r& zyse2Iil;-NX)8sd9%9H#gNF=MVjv@A&QHTg!O9 z(~b;mM*fP|IT)mVF^rM9qNlg_?zb7pyIm?NEiK*As;!RG5zA9}`P~~V*D^=+(>@b$ zRwvq?*JKcIfo2%@>V<;RU$uOcP+WL7OQvfmmX0n(Cyg1OpU@6%hND%Wwxm3@G4~j; zylXk**-~j6KmPu>{H@w!w(;vZQwA|4#ZYWfV=NSg^+4NAz0d68D+{7eGiSxml8f4; zgX2Lf!m1JuYfQe)XFl&7jW z_>_SAyoz|skqvl@@%!MkJe9cEgl1Np&B}2$DWY8S1PIUiZqJr8mjFou*+}NHDATaiqJdBOj||WdYf|cNeS6`1*O4MMo@mgc5JGbvZs<9s!|tOLu`kzf05UBV$LnkAE7;W!A7Y zy+8Y1<609M^40u%t8}|+Mnu0Xf$!)toU;_eu+va(r{in{X=?sIzm8c&=-FVN&=l8HG9U9Dp$`gF!Bd2E&QAkpB6$HL zE-p@96LDk4x5Nu+`!OgkQ5~&bS{I78{Q#+mc#Mhf^(|N$uXb9$vcnx}BNENa)u98VwdkjTX8L@QpCj`Txj;4H}DF#Oe z98@~;8_7g|8+u1!-!0gzhZ7gPr*1Cn({(M&ho42Rn&E`+l*9pUEc>`sEBnwqsTeMTDU zVN3@?L05q=h0;nIRaCM_HGW9{;FWA70f+EmC+?;7X?u@W6OoP^9=W6nmjM*J4v&Cc z#Ogz+pltP(y3> zQ?VmD-A*V4PD&cHPm)8E&rXwo(hE>CQw>(+9MIC0A48(+6ep%kH$OobksbdF5#k~J zGx>2RE;C8G7@4qj=TG@2+>qP3sP^>f)5jaGU0KCneou3}r>eU#}q3gCS@)$kPIq}MYX)|&7xz+BHP>ofL_tW5xuz9mA0`2obk#TZMyRnp-LA)UTu81i$_dJt!QQ;)X^(Z;L??X zX|2H(nmTG1%$Ph4Wh7?mRvG;njxnzwexgFPR+D-c&M_w^Z9Xf)Fn^oX={IDkpbEAL z?jR-^k7W!LloRSs93Ta4%Q?!I)7uOih)EQyK8r)v}Xx01xEf&962R+_{`A!M%qe2GXY?R4o?DFasa zme`(VIB~V;CM%1qW{Aku>$k91ZOEm<@1E(pFWGi4`WS@pgGBm+)(3Q8An0?oOIJ-M zw2!;rByL;ssRh;C?hsD?yXHlZK+P)o&>vJHcH>0`KDyQ5XGP&x;_k!;Y zvUI<_Xnr0t1~{Q&Q7-1tI*{D4B*tmZOUnHw>dViZG8e$x%Mt#Qq_p&1cVHD;pO_C{ zi`jsy%DQIiTVbEF3v!J8`AUBS@0+q*w9C^J%73QuN+GPwBKeC~#p2$ZjHO_it!rS- zV##E$Cvl;ti7g+wtpjAd4}U1Y7Cp?*T~^r~zhcbt(lkWxZNl5PZ_$h-B2m^z;iC)Y rU-f?<`rltgCyD<5`0b+^J-_MBNGZu>_&riZ=0pd3QM2@dP00TMt#mG2 literal 0 HcmV?d00001 diff --git a/docs/source/index.rst b/docs/source/index.rst index 7c5d2678..ca339793 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,10 +10,13 @@ Welcome to nnScaler's documentation! :maxdepth: 1 :caption: Contents: + readme quickstart + pytorch_lightning parallel_module register_custom_op parallel + faq Indices and tables ================== diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 6a6f6181..0c9732e7 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -87,28 +87,31 @@ If the example works for you, you can now follow the documentation to paralleliz .. _Fairseq: -Fairseq +Fairseq (TODO) ======= -nnScaler provides `fairseq integration `_. +.. TODO: -TODO: refine the example (and its doc), assigned to Youshan Miao + nnScaler provides `fairseq integration `_. -TODO (long term): write an example using unmodified fairseq + TODO: refine the example (and its doc), assigned to Youshan Miao -Installation ------------- + TODO (long term): write an example using unmodified fairseq -To use fairseq, clone the fork and install it: :: + Installation + ------------ - python -m pip uninstall fairseq + To use fairseq, clone the fork and install it: :: - git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Fairseq - cd Fairseq - python -m pip install -e . + python -m pip uninstall fairseq -Example -------- + git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Fairseq + cd Fairseq + python -m pip install -e . + + Example + ------- + + Follow the example + `here `_. -Follow the example -`here `_. diff --git a/docs/source/readme.rst b/docs/source/readme.rst new file mode 100644 index 00000000..11a93b10 --- /dev/null +++ b/docs/source/readme.rst @@ -0,0 +1,76 @@ +============================================================== +nnScaler: A Parallelization System for DNN Model Training +============================================================== + +Introduction +------------ +**nnScaler** is a parallelization system for deep neural network (DNN) model training. + + +nnScaler automatically parallelizes DNN models across multiple devices, enabling users to focus on model design. nnScaler supports new parallelisms that outperform existing parallel execution approaches. nnScaler supports extending DNN modules with new structures or execution patterns, enabling users to parallelize their own new DNN models. nnScaler can support paralleling new DNN models by providing user-defined functions for the new operators unrecognized by the nnScaler. + +Features +-------- +- **Automatic Parallelization**: nnScaler automatically parallelizes DNN models across multiple devices, enabling users to focus on model design. +- **High Performance**: nnScaler supports new parallelisms that outperform existing parallel execution approaches. +- **Extensibility**: nnScaler supports extending DNN modules with new structures or execution patterns, enabling users to parallelize their own new DNN models. +- **Compatibility**: nnScaler can support paralleling new DNN models by providing user-defined functions for the new operators unrecognized by the nnScaler. + +Overview +-------- + +Below is an overview of the nnScaler system. The nnScaler system consists of three main components: the parallelization compiler, the planner, and the interface. The parallelization compiler takes a DNN model as input, converts into intermediate representation (Graph IR) and generates execution for multiple devices. The parallelization planner will provide efficient strategies during parallelization. The nnScaler interface provides a set of parallelization APIs to support different trainers through certain adapters, as well as extending the nnScaler system. + +.. figure:: images/overview.png + :alt: overview + :figwidth: 80% + :align: center + + **nnScaler Overview** + +Outline +-------- +- **Quick Start**: Learn how to install and use nnScaler. + - **Installation**: Install nnScaler on your machine. + - **Get Started**: Started from a simple example. +- **User Guide**: Learn how to use nnScaler to parallelize a model. + - **Example**: Parallelize NanoGPT through PyTorch Lightning interface. +- **Developer Guide**: Find detailed information about nnScaler. + - **Extending nnScaler**: Learn how to extend nnScaler. +- **Frequently Asked Questions**: Find answers to common questions about nnScaler. + + +Reference +--------- +Please cite nnScaler in your publications if it helps your research: :: + + @inproceedings {nnscaler-osdi24, + author = {Zhiqi Lin and Youshan Miao and Quanlu Zhang and Fan Yang and Yi Zhu and Cheng Li and Saeed Maleki and Xu Cao and Ning Shang and Yilei Yang and Weijiang Xu and Mao Yang and Lintao Zhang and Lidong Zhou}, + title = {nnScaler: Constraint-Guided Parallelization Plan Generation for Deep Learning Training}, + booktitle = {18th {USENIX} Symposium on Operating Systems Design and Implementation ({OSDI} 24)}, + year = {2024}, + publisher = {{USENIX} Association}, + } + +Contributing +------------ +This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. + +When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the `Microsoft Open Source Code of Conduct `_. For more information, see the `Code of Conduct FAQ `_ or contact `opencode@microsoft.com `_ with any additional questions or comments. + +Trademarks +---------- +This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow `Microsoft's Trademark & Brand Guidelines `_. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third-party's policies. + +Contact +------- +You may find our public repo from `https://github.com/microsoft/nnscaler`_ or microsoft internal repo `https://aka.ms/ms-nnscaler`_. + +.. _`https://github.com/microsoft/nnscaler`: https://github.com/microsoft/nnscaler + +.. _`https://aka.ms/ms-nnscaler`: https://aka.ms/ms-nnscaler + +For any questions or inquiries, please contact us at nnscaler@service.microsoft.com. + From 0f9d8dad81f636c28cb70a0f6732d2d72e8cb18b Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Mon, 15 Jul 2024 01:53:58 +0000 Subject: [PATCH 1680/1892] Merged PR 2191: nanoGPT example Added nanoGPT example Known causes of parity mismatch: - grad clip - param group - unable to align rng (validation loop consumes random numbers, etc) Update: Added args `use_nnscaler`, `plan_ngpus`, `runtime_ngpus`. Precision will derive nanoGPT's arg `dtype`. Doc will be added in another PR. Other known issue: When using fp16, nanoGPT uses a scaler while this version does not. I think it's not a big deal because by default nanoGPT does not use fp16. --- .gitmodules | 3 + .../images/nanogpt-curves-deterministic.png | Bin 0 -> 92079 bytes docs/source/images/nanogpt-curves-dp2.png | Bin 0 -> 103662 bytes docs/source/images/nanogpt-curves-orig.png | Bin 0 -> 119671 bytes docs/source/images/nanogpt-curves.png | Bin 0 -> 97778 bytes docs/source/index.rst | 1 + docs/source/nanogpt_example.rst | 193 ++++++++++++ examples/nanogpt/.gitignore | 1 + examples/nanogpt/README.md | 14 + examples/nanogpt/README.rst | 1 + examples/nanogpt/nanoGPT | 1 + examples/nanogpt/requirements.txt | 8 + examples/nanogpt/train_nnscaler.py | 274 ++++++++++++++++++ 13 files changed, 496 insertions(+) create mode 100644 .gitmodules create mode 100644 docs/source/images/nanogpt-curves-deterministic.png create mode 100644 docs/source/images/nanogpt-curves-dp2.png create mode 100644 docs/source/images/nanogpt-curves-orig.png create mode 100644 docs/source/images/nanogpt-curves.png create mode 100644 docs/source/nanogpt_example.rst create mode 100644 examples/nanogpt/.gitignore create mode 100644 examples/nanogpt/README.md create mode 120000 examples/nanogpt/README.rst create mode 160000 examples/nanogpt/nanoGPT create mode 100644 examples/nanogpt/requirements.txt create mode 100644 examples/nanogpt/train_nnscaler.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..ad6f0da2 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "examples/nanogpt/nanoGPT"] + path = examples/nanogpt/nanoGPT + url = https://github.com/karpathy/nanoGPT.git diff --git a/docs/source/images/nanogpt-curves-deterministic.png b/docs/source/images/nanogpt-curves-deterministic.png new file mode 100644 index 0000000000000000000000000000000000000000..53e443600978ea9094ffea25dc22bfa7a4925292 GIT binary patch literal 92079 zcmcfpbyQW|`vs038tFqf2uOE#iGYB#fOLs~bT>$MN=btV3J3^DcPcF{-QC^4b$mbH zf9@amjyuL3cMsop$Z?;&)?V>E^OGo|d*2YJwS0)#fdizGzy!x%Hy@K|i>Ospy`s;hOVomhwXWK86HgB)55<;C5 zi^9e(&VG3>Xq$8E=|h&BgiUH^M-~>aXuxGeKunApA%m8nvo|`mTtfhbLa(C^zL&$p zwT;?^buI@;=LKjo6R=W)e@G~rvEfe}v_M2eq^4F~Ntjyt?|!_DLFA}z_)48QPv2+U zq_5Z;=aFM%X>EkT^Yi`8If+HN|7>sd(HrW)9@1)MXD4f%i%3&RYzl4+k$&nz|L?{^ zRYv35*|~?Nrk}8$`9DkTA6>9RfrQV_&I*lZgolUU#cDF2_uVP{?*`Yio8RkB?NwE8 zalwOOSHhynvdsTocHJEj;7k9mxyTT_|89Xo>j}^qU=6@7Cl*5G|GR%pi#7BA?(Zwp z5&XaV6Z9$_{<|Cezt;5s%LYm$^AK>LP|v^_o1UMxv&P)%UNt$=d2OFxCgo}}i+zj+ zOoQMc)}X{8LfR6c#FQ_?Q^eb8~!-aT21B`Iyz}` z*|+9Ds;Z8c!>BseozA_CipC;Hcq}C8a6PF>K*z5ezI>tADbxDz3zn9})Kg5ewf4Po zRKCF%7Z-+y?`vndoHvxfRC-X8c+$c~!SA<4)$;9OYnXD;>DI1gjqX6E5T2Bj6s2#L zgrdBBfIfTJjLpG>pHn?~L_`EDY(k2sZ1~CeZKeZ4Lqo-zwU#|f zM4mi(qRNPGX=Szj?_1!byw+j8PTwX znBf0=Xh>E~EkUmTOF<|mS?KXnUyq4bPd^t;9MYNiy3e%N=I`D;uEPsFeF~>-=vk;uc^eoA-*kJ>=pdX0 z323pe%#_QK1fq{XoqT=R#@vWkh z)6?PLM>yBz^qic>)F$tXz4lzGpr#+xv@|q3OkXu4i;IgX2em(@rlNpV79AZ;NdVu} zBFr~2F~P{ny64lmTRc8K-m0IvLk9hpEJdcc)=SYz!h z9#k#v_3PJ985mHo!VsQ5eF}zsWw_msCbb_*MYZ^*18n%a_38IDx)>zfpPhp>TwPtu zwIvVit*os4nNoiI_<@*S>I)%ICh*u@A-U3Cow(w>Rmk@8LCU=X1O#EM zX2ZC>y}hA0|Lb)k*j~{s(`wUvpylF%UM~7SKH-pLP|1$|mBirgcXxL>Oe-PLDzS3@ z#>-3id*eX^yR#rBDX;vNPD}ZYyUp|w&%nC|hcO)p!)^1I|K5R@kI!WGr)}2T9x(rT z>R)jvDTn@irRwVKUCs7B_yB^(0Y4sta#vRbo`{6sc42CT|Kr0$_|&`Sp&0EoQ--r| zv5MB#%%wJNPV|>G&xAMhuh|{uA1f{{crQhZlwxqoNbg0kA0fx{5?mLfY-u;+%1s{G zBPQW}Q^6r2%fDI5@r9AEb0vu0f&qJ=;pN3oHS((;;p#)VuV=)I+(cvWMtAWN8!>es z!^7%$MYZ4c%Fc^F$Bav|e5!M4-dzO;dM;8m!|H+vi=H_hyDU3T;_#L2=bX@;bkXw= z0x%UF6km|Cd+a$5M9xPD*m;r(y%8|}bw%EEH>u2PJ|7@pM}poK7|3DNld<@^4Fdu- zt30}BHPb&EeY5z!;^3I!7?ZutEN)nOuaZiF6S1}Rj=rUY6Cu29ypgKs?o5Lm>%l`+6XaatU@L9IFwfQer`zlN49;_#g4`&t zZC~|PwA{1Ux3le5IdqcM8ry6Zoro-W`y#pVyX<){47x9`46TkG+pBxGn|T=b9r`!P zc4^MtANgeNn9${akSm{}Wo0d0ygdK7FsJp?`_eAc?bEE?U-7p|3@SbGPoFnrRsBl3 zyi4hPL7HBd%s++*ele6zXmdFueDbOLZTx{K}m>LQpf^=F95bd{_e; znqrYcb`87r7E?O6oTL&22~s)IqK>(wde?-047#s#3dhxALPy&s$Hx(4WXFp1)|KBH zschT**sNc7<0@i(VTnjlj-Rl)x;hRa;pDDojPl#X2lGmu_>2tnu+9(B(HO;= z%u+*CV17eH+)azd;8`H!6kS7v=(6dMTwIaK^4>4|vz7PPT50Rd@>x6(Z=46$Wl>fR z_zt=3SgyQR)Az?bI*-eb4y%6;k4G4vKoBl$_2bTogGDHFLZ%C-$f8uwJfB$;baK+< z!`)&}n%t|*I&AglI4mBC`@uKV)LHc)Tvfh^EH*LYLPR+}=Gi@77O6TqbS*dja`ZdsY&~A`Pc2A5o)3fD zB%_1lW{ZGH4}wJg3t`lPBAMLfu$+pADD30)A34WESK8?=lO7Wqv-eCj9SrjX>0eKliCpuX7k?+N7)??@dk>z#Ejz*+k8 z-glm;cI6Ib>NyooL|##$l;+VB{?bRBL&GC9asO*k8KiK?)=_s1kHg}yY+n)&C)Yqg z4!><1(P3HFAeCRl?T=Gx_eTm7%QZ`Wx#m`0!>gKDPt?DE2YI^iId^M<)kkf69wiK9oT!iJ;m@I>FT$ z+v(#8Z;#dH@(QzAGeOL^9-D=(g3GTwU&RntTe3L)^bd(^{TrKEGez637K*L?__y~= zxU^{nLL*Di_A+}xP0@&?^7)Kg^gU@jsn&i&4m>-re<`Se953S^LK>5=jZIvJs2 z_I^m#=dP!BI<1OOfIg#gyR!GFvl3-8Ap+&+hVf-+`OyT2=k!k-;qkJJqiHjV%&RNB z1eKB9^DB_jMf{H?2*3sVd*0M8k{}7hyF8dD zXI{Vtc(C#AR`d016L6Pd<3L3&9ns1k=hv3AH+=kmGX=eZ{T^=KrO705xo0DI2{>t+S<>GiW^wPfl;5=t(FiyaDbK?au2YYAarnpp6r{|n4 zJ^4(}($-Alri{!&@z<|UcBjhaJUy?O9&JiWOS`Zo?gDyxWTnc>ot&MI$2C;)s;jGe zg}#-P9G{J;tb$Y;Hfnk1pzycFSA?0FSy(QYLf8crpp@;*%od3!JYQ^A20e_W9X7Wp zyGv6mtajtBE}!*c6pqIBrs=e)sR!iaiD~Z@*4_S)uv1BjCXLmEysF{z?bkZP(%N^T zKKt`bPFBk`KF$+zIjq{&wC#V@L;@M8(0Qybrz7<@swB`8Xbyf8nQiCH?du==3Hu>z zksO-lym=>Khx9SJqgg-&aUkP(KgEonQiAL>uHHyh;{6@-C|)^0+a1PO=D?x>NF z_SVOH`=*P(@0*!3V!_1i%o%MB2W?^FW_m7bGS2>7*^S{BG=LnD`3><)xeXQ8sWOLJ zPE%Z&k6HNm8PS`|9o8o)wM{S6gnx@lz)saeD_TjWGOiUPcFy*K6qv8jMos?Wx@qy| zeEzmOC9Wl)eH)RLQ0MV1GO_sm(d3s`(^n@ph$!K3QAj`SetMox+lGfDbGjQDJ@K?X z=^?XU0(jy1n>TODoi(+!&8|<@J@#v+jH_luKgKu3a7uI|AT-V?&_9Wl)z{(SXTN25 z@Eau|7I9-I*qE9~@0SWhgwx?j-h`~iFE%i-|x?hpbbAoFn{P|xmqwbkF}4JOinrPwlR ziIm`b1y%f$Z&hozvT3GsrpI&cW2g@LtW#eG~wt@73l1B7IU#=p?wvH*I$Ww zR@GU(wiWNxO+#6^u^dO+{y_;gyus`!)5AIWFVJ4lzLb4aZ+p!G-6%Sv%rz}xIcPf z>V*usVZ1b^9(gk^&$-i5ethe>xN<(Fy_CW67$hR$ffV1hoT%N}hR#zXHn;J)+$(a} zMDltl>&n@k0UbSPgAOmyJu}{mcN_tATS}9LbOe+KL*lnKTjw|85^|=dGywR{5q*n` zRZg3y$Wjioo&DLq?qu@k^UK~W%~>+~*&$F~nVHd6>$1_SWY3kL4`|HPwld^_2?VRt;tYgF?Xzk zYzl-*)IQHgsE_Ul)ioU|20okx?I(i(MQ+?npGpe}yMF`xCvC5e8JhQh&Tc5(us zc(z$=p#aqY05La?QCwwOELY~Updgle{v1L2`ExF-Kc5ZPg_3=5qx9K(=iL`$04ftC z$1=CDxS9OC+zWS3H&yHgm znM&X|k-l0n%zq5`>HAvBy~)Nho@LpwcX#IlvTK=hqe*Azxpt{xKbQhPKk+|fW8Pv< zuDw5?d;kwUFYE_v+wLUUC9?$L`-&IS{ z&!*g;5sN#UE#sGJLGrb#c@efwTC}+om<=4dv1J?N#~n+u=!j`8!~&+?i9y-|zhMp= zFH1axHQJE}CcL+omChy3m#{xw<iO%z~O4 zUai-#dfIgaohxU0{FLhz+?vm+kV~()-C|T+3p)bM7n2?b$gZ8Vt^yjn!FBJw#Yq=Q z-RAFvfKEtOUf68r0F!9%^l<@B!2Z=5!hFtyY*=8RgxKWg$Fa>@gqFn(NN8Mljy*da zC`Ud-*OL=Ht2P;P#$E}#;mmaZg%}>!WHvK5hepI39-Tt2F|w6wALvbwE@Aq6@ryLm z?@o6k-VWD+%!LbIs?}L;hH8zWc83Ji#RXvutFNaO(o?aqgKP4Fm+bmn3K-f|l*n~Q zgJqqM4TC}kwNC*N_TiUvND@7wHKbu)uVPab%KSLzNxoWL*$G{m4lfjFjnly%-&J!d#;6&T?0X%z^2OYEDv_ja zBiWV_Z{p7_o7+6oj2SmbFACqK5tU-MdwQQ+j)7=EGw; z5-tB2>&iN1#H36OF;)|roO={(nKKCKu%%i#Rvc>Q@%|g9#C<@SNOHN70)^}IJiwhh zxzNJb)Gn~ReP4I**Iku8C(!$mi*N-3wb1vX4D1&2IB)1Z?jB;+q+;&4a3@@z7e&D* zS&gU1HQOlk-?hk)J33g9__@4APjmA6)vH%mN1fP>`bK=byhrbzX8iyJ&?;cgdSKm+ zArr`NZZ?P~&u%PazhIsdbw+x)?mSHsS@%>j(u#E7aB1Dnon7c`RkT=*9nNW=kPt_Z z46IYk?wzl>Tq>k^(N~MmdF!C4V}#kT?|g^jhw|{SH7Y6@;Dci!-UwelB*cxO=)|9L zf5Z%5?gM#j^k&zhn$gA;z#H~!!FZ<#$z_<$l!-G$X`2yP4Xr*r-g#GRyz%Zy-SO_A z2cx;*!R-yLJkea>W0Mn~MT?X9;*sN3Yn;{!)vw+r}DFGBOmkzq0^McC}rU1A%CKQTm`Yo8-CGM**skW57kY`-xipp@R`=1_#~4AgH*x zxoy6%HqzSKwMRli^0=IL7fIa*!{eS1A0P1fGo4A0XMI=H@bkYVArAUZ&qAr1T`>dJ zN;i)d#3t$<3LlrV6{jGml9rRUIxg3nIwy`yXxI$_33e|7^u(_Tm)(uRdb9Nca%Mg;{{G<)qYTM+PP zW@a1&$+K?Bo9&LZLPA139`7$rYTQ`eTdgOU0gpn<$5*px?|imIuKYBu#)oKVlqzuH zi>x^>WT0K7aK1)acfiMl1uc7WHGTo!Q-{1`4dU?RVYGhGLA1IaADiNGwIJ3r20@W) zzvD4~r;en^?9H~wwEoTJy9Da>AM4Xz=@e`==awO)38U>-&JuJHEjS!DO*kAn!cHi~ zw9xyax2vF<($&v#;^3zcSWH;4y!vH?LLA%ZHyyV7es`1^^ZDGUmFAk6!={{wLs2tQ z&3TzaUndUH95o= zHXpnd6BEP8$VjYu7`d^z`S|QGTg-dz$EP?I#XFQ1wi`JC|EKMJdg48?Arum)qn^zU z0^hwi0Mq|S6jCMf;Bi!6iG*bwNWF5N|1i(o;|x(UmlmDBn+#|wvVo(w4Xz8tU7Ar`DH0| zfc%8myMrN~T2L6PoxwXRd+oq2qf+N;fG5pyE-R4hWZGBxMwp?G<%z7z+)D{O`M`qF zpuzCIu<3ob@yW>x%h|0hOUTyo1%F1_aDDMAO?4yu*G7V~j?T`6?CgiseLP%T(Ue0? zObTKA-N~}uM9;IW(WYtf*J5|+*pN!1g{Y#WPmXih+4`*PHr(|Hd^BbS~E^95R!ZUAbeK<93^Sj>~n87K=+wV*+><2*u%?57?WO zzHDV$yEPIYn`e@>4QEzu0pBRPq~)!sqSEU+We~>EdWXxNicVzU0!79orE_6z#pPZR z)fboE!UwbQ{HrGXR@&o97?t1A)_br!j&HtUvp1jpp|lrqT#+f)*)J0a#UKfMU*1dz z#vo|kuJ?K~Cl1CaYGI!3cSLwmLs#Do1n@?66D72!bDF@Pd5N8QHOm<>u(D#Lr>9f; zVv^P&o2tpm$he5E_QX8b(7+=l1I%rPH`vGAJUmED(7szwMYW!eUxzW>Z1$N~J<>EM zuhhTd6B+q%fk*X*S|}E(12Qtpr_&Fgk`k+4AFE8OzrGkq4}E0-zp{gE4`@d7j=if$Ofnh}pfw!(91xS#!0~JDTZE5M?DnNt?I0G;_SXvgMe)1>40-R%gAnB+JmH=Iur{N2 zcC7QvFyvu`KRY`i`0(@Cc7^+)A?!noDiXk(%fEUR(z%QyjCEZ*4C<;nq8|q79d7}O8)gx^@UP+F-eR?a z_)nkE3Pu@S$lDRoC<4HsH}!)x8eUq;3e+l8iV*OLrR3iV6NLQyd}`c?;8+^s*#EZz zT|wI84Zi|2xT@bI|!=k>7kcQpg~YkQ6*hmc!3szs0`NZ zn&*_EH(RDM1BgD4L#%>PRKO-KcZOYB<`^`(c?!N>Ed2br&AP_RRn}h1 z&@lGo#0fTugM)*`B_*)w?k=BcS*_e{1OYgY*5&;`8kgU_^0K7s*2q&J??jUe;eyQt z@`|9x5p_=tS+BneXpgG9V;k}B?CjH|U%6x3U_7JRgHUo@a-Kd9J$wFRp`^UL7krSA zmKF}8N8T!i;M=!vN33g5v9Jh*e+M{&XUi9=gI$J7@wUy%(sHY*JJqYGunUi0_&j5v~tSg%z`EXl{-DU_d7P2 zy{E|vo_jaB_4OpmMM>ZzAaKgiK;sOkz~Wo_IcvAd?7F4}2M1?qX(?%K%?xQX=Wbti zK+TKM8Df@xCon6 z>KOBqhQzLDfS!Rd_Z(_SpRf+;Uk@C$3*xrLS`6^o`E37A!vk@xK zTiV(6fLY)Iu^rL|{xZ6-5Zl)${~-x1G|vT=@}`d0JAGw4J3AwDbI~9@f+^8qB?NR3 zEjRbcXQDcA8zVFGPSNLwklqO%aG#)iVc1dTdGwDT+kJlRyqT z=__^vt)>Y&()_fU3B(q^fZ57BXnD+mzLr!_fn513mZF!@X@kvQz zp2`%04iK#p0)@iFw;()Re@;xqB^UL;Q7+QhVhmhHg_#U?RzO*_30pFN`{|l8@{a-k zQAqZpeMzE!wWFU4f?40xoX_c?exdRf1T2m<15hD9P3;#t37(UsPcm6(^0a*aKB(JN z3Y6F*31H)kGK+};DxRbN`;!;^_ge3@|h@QzrhYRJ6Lcz!;fBO6^;159`Ay39Sb6eYoxufif(1evT)Z@`V#j)Vj)~$n&$R0n zn@^=ZoLx?}%9e+W%y=BEkcyATA%t7wdc0kH#33LE0WwVXJJ^GUL{F8%yG>gu<`F9# z@+F`VUGa*%@yLuzPC}z|S}|F#?M-bu2Bqa_L6URorEqsh7CuQ!Ng)L3iZD*>f_Rt< zyy=@LLbj9J*E>3bqek)wn#36n=s#aKp|&A4jl*2=nm6YnT?ZCbGscxV9{OT!HvG@Qb^`F3qPqy_v!HOuuClYJoi*W3 z2fiJlcLo321X286sJ|nk;Td7XQz#3h&iKcK$-8qs&}Py|CD=z)Hgk1Q;BTqh`_X@Q zLlZG~*oiF>Y*^m=Hq~Hqwc{(ua$HgfTU2TDfPU+yR{w(T-v$%~4PzFO&8 z@aYq2ni4&%`AVH!a0>we(pODp7xJ7xfBpml>j~gs2hw>^a3f^4&VZk;3!_JtBM7RILDOIc;Y(3$zW;RqBx5ESH)ZRifSz%!zv49;pNXi^8rG5;! z5qXKSW=eCI6+$s!JJ{yYg7^=+Mj>q-U&#Qdt2^!?1JNA<0Y(QQ01tlm_ahBHLrjp5 z&XwfTXJ?51**X1s5Egu<0Lz8NqE#Buzm5xg1L$bL*>Nw`Do}YiZ1^)q)DSxXxKz{C z{ky)pv0(vV3Mglv;K@t=0lt!`R=8OfVXcxXp?{-Vt_NgJu#{n=Gn)eoJvU2bwF@xA z*xmIhz#}2Cw5R~kfq^+X>E~G8DJee~y=ZjXKXJ%!QGEIG!-Vy|t88NlL4*`h!SA0~ zt@zEzR;&2f&(7@D*Vp}oeh*T7%E(aD)@E>thU?F0E%y+iY;37tOH>4Vtm54;WqHPy z699UKsW7|T)-V_^>n+WbYsFOy^7^#u}QeQIwNkrw6sv7U{lHf zvLP`k$-Ab!e_iI|x=HE4J?=i7m(Fv)?&9M2VjssvKvf}CGC?CKsI$e0FqHQ^j5=8X z4AxVZk8nkKd2sw0n2wH)(sCpYy1KdmIeOjS#|HfT$*=F&O(3tRc>3G4my8ECL-#Yl z9QPlx68 zgSD!bfMZs%ZYpt)rk2tqI_+PI{b`A3bY8d7hpW)@h9 zaS~ATe##Wcb>KCm`X*?-KlRpix$D0r$n1hxj*p+)lXkTWdeLuLMRo)_ zLrmeG1&dI9gWp?#*51=2&AM?51K~2hj$^L`0FZD@7S7Mc0qtT%%4zMvuZ@?=IO`Ey zkn2W_I-&>1f)dG&_x&U*aIRdRe`AJyvuNc>N`B#(A&Lfp@10#YC-~)!#fTQle zs~5N@n`r`xpZ-8J@4lm87aaYg6_WYsGC1}HxB;OJg~}&D$a#aBkB`8u(O^pue7GhB z;`o#~F6Ng)@yOEY)=pFVj*Y~kc<1m!JW$Z@9PR z=!>o=tPfiZ%PWHy;ya{exzNokfPu`0^%TAM`1pSbT>N?sPbr-n3AhT`KqZVrMiy04 z-tuV^7z3|MK~zKm6&A2vO3Ohp&>41k(q)2p5)b9dxQH-&eEZlRiSj4X!?oU_QN)CY zE4Qb=%l-w!?chN$>eTBHLxI$H&J2kf=yJJLhP0KO__m zRAJ`e02(}SbNS#7zZsrxbT1!)`l!pwB15P-I8>JC7*dmyLjnRIwGte#PC&ax`jR;L zALxCYkKA}aogDAC`!C%b>-}2>GKLm_9@D9JdIk>%W`J7xmrN~$j+(UGGn)LQy3K6w_{_`!y|g}C_PQnbz1dpqR0B5> zkeAk+gbo@H#X&-b*&b$2{x{o^HpwDgk>_>m9iTyrl+TQs{ zdX(y)D+JT$J!Tbd$lRtI=oHCd!5ZIP9Rtulckls)vJC-SG@z+TG@6tza+3%QO*NCo zmw(lo{{paF*WyOOL(7KRW~@|b+0Dr>CcjHHR$-p^+*?*^6?|ED52x-ZPGbp)RRB=} zUa&z4R?c1X6DqH{(WpUqe@KA{M-eJ(h}hT}eRorSq;X8cONw4mX!8V8;R47>;mdd@NG7|pBU}!?S3K+Paq!8UfwIPCQ{30g^W5*wQpO?6);hW ze0!vYf01l3xT0BB@&;+K^%KrSkw z@N$M;P&@#8>VJV-W{tk7F!m-(o=!HL6*{I)1qUH7HHu`eP2pavez0hG-3;hFV&~1Fkg;w2|A8l$+x^;7@i*j>wAppUbFSeP1KBz0ua*DZfeVYBIYCji;JO&r8WXem7o%_RzMOmc@9Tap= zfGMG`TojnVVI9;-mN?8Rvftd0(MVgmHvSYj*1Z*Q*^ovIe|Y!Q95AM4W+Aas;#>Gjn=~(BA?Z@tiPo|hHS@2a7wxwnK>;a4b=-cP)AOLr zf#lZIB==6hDr)dkv}^(3?IQ;*o_l+X48IYm(k*kS*XJaS`Cl#uT(?$k7frc%C`(5T!WUkIDl%)Ac6H0hfN{73bQvr?nEw9A1OxcOn>OEyi`!lh83e(voIM}Oh*Hbf!;TSs zLhwp&*+JLjhd{@O)|1o{3jQz>-E|yG0xM~M0CfYcY;D&8tLH{SLXzcA$!m!NN^w;U z4d*|nr9cjVDQ^8L0@k*+K+c2$!bXgG!IMh<;$S{%dRhnQS;x~2YG7R}(c>XwQx2@| z$O7mCN*ZD=Q^e2eUq>8$6oAVoT9$Uef(N*=a@D^E$I^ikaLrjJA5#2zR5kPkwQg33 z3AAd|oT;NwXtmMqt#F`E zHoyw&?MtAA?YTkcVX?5ZG-7?<^9feK`}J{@g1Q#O8!!ibL=iG@stUAuA7yEz`N+9; zSgUp3JAhaNkt8ja{%B1%F6GEMvy;a|lvjML_i>9K#*I08tx&-GBg@EiprWCj>_*W# zwO(C$zQ%D-I);bTJJQ={2=X zi5%P(+j#q6onlP^C9Zx#G`;xV@a!TX=Y|z zu+k*cn*5l^`x75y#9}CAx7(^}cl8psukTkPS-!UcqyLhFAT07Q($WU~BZ-^9h!X{& zV26%c;{TWEwuloU&l07Jq@H5Rr8#k8jr*cC*Ao+};rWTpVJJ$HguWOG(c;p9WtV@K zVN?j-XbvKBA5PhjjDsqE)f0I` zX<;S_&NyG>3541IWg=*THoz$^xF1YOs{bf;$zVx;v?6wTzYsPt8fZ(v0oH;y2AUXx z+BYSc5}r5W|9uDg0kjl{7#@g|M;ubu=Ry~v+zpnkV}&2vz)$3Hl1phy;z*n~@W^40 zm|y=oL-C7ci$z~-O@nl73dnlybR2SWRu_Dbc=F+XnO3%$bA!VYH~Y4(;KZJnJS&<^ zI4fFYcofc8(ujfS_o`NXvVQlQBkAZSmMqVMwyDv`TBX#tT+Dm)gE)~CLC*2`-h818 zs4#p20!PY!RFZfa-+Mx^-A+zU{%toH<08{c_avTHoey5O8>cHQT2A*$37^j|PJKSE z;%9CC?|97SdQ{gkPvRU0GWF_r=oIPU#oFW6d_YuMY$1~aljZU)FenI$kH0oGr2%?+ z!zju60RRcKmxweUJGq!|6aa{TJ@GD-Lv{R$0?tLsqjYyHSJQ&cIl6;k!|aR8Hz;CaDX`#J!N=cdIBMDjg*j5={19Mo0B+{b zhyfDao|_0@>Yn08;80Moxn#B)&%idyLg`Q)HnKNU>3IY&0RhSAe0GYfg_(1Vu3 zC8rVIqw=264F_M~g2sVlxM@HI1Tyzn<*OND{J-)^7 z2ZQ?8_f9;5JRglHMYNh77oR**LRc^8$z`LXT9ICwjKMbbIFqn2DQ-lU{bCHjnfV_|4&<4}jc*D> zzwsPAZ+5ib>?O>gv}mLZKu_37u;io!JY{fU^OJzxx?$VOoU6k)G;GqK>1krYS%f<; zC?Qzul!TyhnG}etDlg4z+JxftV8EA4dXCZK^ZQ_CbfOXkn5QVqtUZHqx3JCv&g<(! zHI*xyf3)&932x?X%Xah1?d8A{MIm!5`@rrmK_wYK`c_5?1}B42Cr}Ylda{mcG&;c% zRpH}vH4tia&)iDA9)-JPdT-CkJp42S*Ecpyk5{@uxJ7`No0}VhIty^%{xg%toW|{e zNMe6eii^7BA{jipLR3+@fez&!uTh1&e)9^p8;1n3x1m)yJFB(R|1JA0dH2FUqfiBl z49XPRf+UDbtOED!gU6mu!nXwQb(&{qnL|Gq)%kE~5{MIiL$l(36`J5Icdf_k?1_LGUo@16>w;C0Y2Gh;n0+37x zz5InlWX6PCAquGv)0X%n0TxG{*O=dUt=q?! zu3utJf0^3`cYp7!EKHoa=i~X%R?fE zd*yX&C|xv9Kmc--J@E}!kLyVJ^KN8L>%$!~g>D%iv0VfH_A-$)Eft-MCe6unZ~`U% zmpQSxPu>5J|K(SK1G)JGH3iHKgsf*tGv5zclLp*;Q}efr!8Csk&c;lTpL@gk;+yb8 zxu7~7Q!LxZ&nEGT)s#$0Uf(`{Mg(?+=V4gqvsP$#`3}Nf@AWM#EH=#m2?TbJ*k6Bu z#p{5-zw&KAp;DIc1TD$0N5t?#)P!fIGUJ6{-g4`XQSf>FV!+7NSsPed&d2d2QW=BF z`u^ZQr4=m9@%QIF1*c4|N*xCFe;4b2WXyaCn(cGtmtWAgz;J(pRJ!#Fr|uO@NpoPg*uhjzKyJ0M5DAuine=PRM2+@mN?VNp@% zvV+Hbji$dTaDf1(kQ6v2yHm8{;*@dXH;e~Or$g+uYZQRvKL`tMo&x6+|NWYQ zy&96{^nZN*ZkSn7sDdk;v}VGzuO-q6mVdzSltEMsj9~RV8PhEiN!Y#Q?zt)E#k+!+cZvLGQoUBLmK9TU{wzbmmkl z0s;bq{J2smZ^bIndQ&G=uI3}2!)FpI7^4jxRfAE^K<y#<*k|bBZ>;*)Q6HT19$(JCqz{bAN(T0T8^n?`J;Eo}dnF%{9?%AP zLZc}c`I#0e@kFI?w6eCXQc1WTSdc*NY6gzr``vCtf|r0Cb>XQjfzwJdz-Sc^tEQru z!U{K1d3UwDltI+#|mb*!{?YJSzpv~Q{dq~!dc~#ij*wM zg!r+P?w5hALBn<@8r+8)9Q#s6eCroO&KfA`pg>y}1=S@E3CS<+r0^~LaWd@O_Btwd zLC4&Ks@VKrnh>Jc@5obXycB$*PV4|Z%M^#gK;pMt`EUy6kIBhqnsXlVg;L@|af$H< z&6kESMLysmF+DwkLc}e%sfnC4=1af)mw^GLk0~i`GY?T5EnABLGQD|Nk==Dy8F?u^ zFpq`W*bltcj}l=kJyh|8#{w9W4>8?%(E8UF7W5w#o}t4zl>p|n)UcWK{nYnvK%`Q; zkOOdB&_2y&`U{DHG#X{p_%C)rO-q%gur4nJK^|!=$UHbw(aUNSFEqe-(1hnJRwq@ zXu>SB4O5-{6vW)=79G(D7xC#seIy0cGY7`(flAx2Ihvz>5%|Zag#I=OytQy}-~ z0r>G4U|NFYe)$p#=mv5u6p*z_wMzbH6WnuaRLQA&W&6UR{Yg;G>Ld6Y=d`e0P}TIg zrL&9D%KV09U}5vF%)+mo)2Fyd)SBXmd4Ha&%6B})a~~-D?D$0*k?XIX@K3%rECehB z2tb8^&;UNAy9WGYL~s@t);DmD790n#w6?|-{__Obu)NnAh7Zvg@;@*$s>ZZ0Z_?-x zf)PIl@^Gu=kB=StQ>b+^Ao)0a20PQU5OCyJK%A*v@Ood)VHp${2!5M~=*IyGA#yCh z0#j;Z1EK`x)`2-8BULJAYyj{ScO2rMjT|-Q2t{Y-rMmWyYDP?<4N06B9Y6Lj%BwOL z)~etoSdOg3kqVYs#|Xy>Ft$Pi;*@;p2Ucnfhe!vPwxnN{;&y~ZRc5qFzq$m{=LWW2 zI_zq1yM(-%bp-3*ym}Qg0&RE;Yzd>5pxSKuAF$WZNYZZ7dnOByTbN4C$xZsYEJM|H zE*Q{zPwl3RZPvu3dvBu(hzvvzy@t)XU0X8bn$7q#fGG?HtpTwIoDaZMVR+5%1gvx; z(Bf2JfvT_=B%O2U#Oxay3I)Q9q9Qs-LkX#=VWp*|c-ik^IuEdW`I9&mOipTn+UhB+ z3}_uf`O)lsNdo{YFxHCxwA0kkKuG!(40vqp@YTP@qSKzY89;>lbRx2;$2~)8@`DfJ z(C)I9QdfZIO~s(6kXJ}yQc;D26A#Qw8f-VS?*@Dlpyn$d+}bbIdcWBS=}3}h`GlIP8M_gn zKyG)7l*~q!M-V3$Y*Rh5+#qP&H!%GB<7GJXV``D}y%=IIIkQw*v)tItDOs zois$zrS`A8rf+qDU&%rA`OB9TZ~~!2MdBVYcQsWGgjtYz0hB$`!lEHP5~e?Jqi*$M z(=Bs_=yOsqB`RIWrp0|!>J`?ddlsU+&zb^&xA2B;VyHWzBAQiirs8kf$D$7=fFM*Y znGoE}bjBgx%9)7r_nV~T1sNasGfRL<2F#d%;!K12808kr{s1qF+;+#`?C^>nYMT;E zxu9)G^p8a@&psee?BoL?5Of{LfiTe>Oy20^Bur3%Rdk?$0f%#!FB`={o#wF|LIn7N z`oY)Q1|cQi8@~0Moo)J_tZS zIDE?FqOi$Wh`N46i|65vuOJsi7eYs_VAX5Jdn9(;h4-J?ZQLD*>EW7jk`Ah(-!lAP zB%6qQ!)&E7bZ!6&o+Ts$NQTnV(t`iYt}3x9O4pbRY1u#M;E$-Weie>< z_F5GYKe1q49Z~U_gd@Nv(yWBH`&;w@fF=nbBFB$Q7){OZ){O!B*IxPU#}GVdR-(v3 zT@Vn|a9NjHm-P6X-?s=-u_yUU)vjo(!8LkcRzx)B&qgcj^O*DVm-_>AR<^Q z;=HGhnC}X$mi2O)vM(c_{S7R-okaQSDkVVNv59BK9U2mb&qEPAYpl-g!fm36OH;R$ zY?7g1B2M8tj0s}>e>P6G_un-;lQ9Tej(-lE4O`AY3M_@<>$mrS>q<7uz%&ZuKxa~h zIl`lyHhi)*Bnxt`avV1h>VwF80OUy@UibHa)TMp(b(Fu8QuP^F?_c}Z7k^W69j3p_ zoJ#yulsv5%uSo!ft)np_(D-&M+2NXWty%>{y zD-48#jB?C06r#S5z7fLG{^7-q!4ppav|L+x9TgL|>nX2W{nJrXyh_W49PUb;+m+jr z0S}ls==i8pWKvTi))E*mMuV(|fYGL>Z#W+=y0W^uv@@nmgmRN)o(0cD*j6vm<<_I^ z)e&r7ULx#%RMk9kXPbMdD-%kS?*Cx$5(=lyP4oi5@FX44+*6RyQ=mm4dpP$6>T4?i|ACk(oB9}^Ims>b4*$uiQPs8sE&A!G zI4L?&@s;xexD`JY(?E0szaSTC?=}1So%5&Qy>z?$>hWo>=;|r~Y^HcLGxTy=T3XjO zuHZkQ5H@u9-VrU#=4C%!5$lFF0FhB0+3b-t=;?kes~p?LEjLV^rauoIHVO3GDWJlq zINBdnj(F6iQ?+eZ^!R^9l8D)=v{vq7{m$J@sBfYmydXwt6^V;`C)q+@{gH=GH-jOF zjle7bh|OO$9{^Uc&OwiZKsF%jlVpa7okz_x-v$uymu-8qJAP_a#eUgBW;O>FxJSx9 zH^{n)`^!H7!%C7X%q-+iFY`?r>)HeobMyN1y-wmwMV@y*ep6~{YGT)$4*mT3GhhXv z!~XU*eO|9-iVG-32nYgsB7;4Z1}01rX;wh`V)SbJbkQdwJK7TtL9@6|2(v!tv;ME( z(l=nR&{9$i((_^5Wb82JB7jv{virRTHyY?fyLHvCH|2YB8ZdgCito`+ZnscZjV`+J zZAf!O9f~=LpM(_wbu1rL-Pc&U4ia_L(2EzJiwA2#@*dE682<~=1$KapM)79W5FH3_ z2onrqOqg-ZjU)p+&zDVHr+V0rdoG$*gprzYQW;tn3gLQ-dx_B0_9c)V^V;DJu|Mx+ zFKiJ|wyyvO)E+1~mAz%f#1Ks45Zl5T%LMdb@r{(9bCAnL0QQ7~owXLfZRh`Yi^4~E z>2HfdfT-zbU{wRCqcLDrr#FDEqz+UJ6C1(;RmQNW^*nxSXjFYf^J`N{GupBbmeU7k zT}N2~6Y!e^?ZUQP0>x(L9jSK8qzgZSP8XqF>@8Qkp5SMFDIhN9r6Y0sb9ewOSL6jL zoJbQHoE|zug&2^D(g4MblG9W?ZPkpAGBP%fd78$^&COlW;urDlDWm^<_*!tJo@91y zoP#KO!vvoozQ?UWRewapOzL-RQJ!rwg_1rY&;1qx$lx|}k}ykC)RD>b z{g7EH3)fjP#$qN~vCSdxz!pnMtVJg@$DKMQ<2XNcq2dquG!p|>4L8Xm5e*agGpwZ# zy^cU<_T5(nO+`jOA}SC>%EvPwvmkUy!`qWRvLWn54gY-13^ibO@Xv}Z>4v2i+jMP5W%o!(#6$TNyFW)>&3d~Pfs>S5*wd`W96a&mIwGgWw)K=BLO%&#+$$t!o;(luMjEhs|&u)ps@d=E<&On@kAKbjl=T?LZr;EanLlN zAix7bKtiS#2KOOM7DTsez5$yNHz~_w)ryat%=!c}vk!=1ps@yy%o0Gcs&29_i9g~1 zl9s-uBI^PhzESu}&-A6|@F6VcSZ$dyFQmGHpVXsVg)wD(z2YbrcH`{W;2k})_;wEj2_afWf$3}nGZmdE(0e+CHy0mZ_tMt?Tzx0GcsBbeG6 zE-wbKoqSQL$ng=&$R)KenT&D3)VIJ+Is_|oKnpmENZ6J-{u?wlE;c)gbE@w%W4V=s zO4n7E9~>9S0AQy-I6+Lw^gi!3I4(48ZA&iP9O3xpdH*ZOE=BnHwbqH}m>~%1Qn#r# zMv|IdDiS14F?|ER##4`493!hoFMP;McOg~g07{;QP9c$(4x8lw#D;oe#~Nc$=>O_5 zO4VwWC*r=#ODXSpQ9XPo%gygFme>k6@o7e!33IT{*wueKkU9h#j=6%JazX%{UOI{(}P}!uAINgYUMfX(l+}eQVe+zXr`dN&ZMm|6sa{&eYAM??YC} z6zpUm9cJ2fW6)ojLU40@Lc;f0Ep}BSqwH~i{NBFd7twf6M~#n*nh}%G|A~zw?FU9yOMYGu<$166gdj`ZYsfNBZ{;;$*$T&A2lTecHg^@> zVqa$6YapYSROV`kEh@QnO1Od)kZPP(+jw~oaCpL`qR=2yN!#d->~AH3AQl6@xxCv* zyI)Lrn-(hYrrMQr>v|fZme2mKa|c`u`j}}(cqq3okh;qdx>s}RosFEq`(lg%EE z1@Y`{4n%GW_O4!FwV#IrMSN%p7xU|6oiTZU57l)rgMM6_gnU|c$D+S*QSyT(r(s%9 zIVkMS1_G-7S1*Q^6G@O{We%s^NDoH{{0ZyJI;UPi9vdkCC}91{zZh|DT>yjOiFTL+ zfPrCI)^=qd$V!i#6v{$+Y<1Tv#!9qK#z$C$+jx`1c9GVV?3jI@gg`2xw-K` zIufs0@iF8VBx+Vbc{AdOh)JevO%>U4hdc#%mOx5@N8RgGpD=2Y%(LYgJ@h(*L-}U? zG%)Yj@`229cilI%GPa-pnIgUR64?y?O7GSAI7OTk()h|R7Kl5OB6X=slLFN>GcEvZX;>5;gSm31p0D$dv?eWNB#$nks2g z$TjPZT+7$iewwk_)cO4!d{3F_ZFz7DnDmc+7ATUe;Z}0OP2)q$@kScUuytjKNJzBo zG=cNRQaI12eF3T&lpVDAGwr?{U=R0yarSkXk^q)*qhQdT9^PB*@I0!k6WUg6)@5jJ z!F$pKmCxuh#>c;(LK-HR-C!mgm?=A!Gv)m@8`{wjmKqZi=`X9ZayFIab%o^-8+6sd zcNEqF&}Y5ifgBB#rwnEMcjaPKUuh2vV{#DtgVe?m*f%6S!r+0Sr>Cbf&!!?kAfS>( zl05-^2eBi_S4-LDPpaW2s&J*Kv{aFrKd?={R{CbS3-WsHTizUgpDAGm*zGSljKE?e ztM8kRm(i6EA42{-+PdF&4rEtsF+bN4Q5G2bP>YkgXP!cw5D}=a!OnR!;UuDp{X_Jr z;Q(d+qPy7*z1>(MAKK?71(XQ!v6A}%Vko@w=U}abyBtVh|6T${yO(w5{A?f1EG=g^ zXnyhV@ezhRoU50JXaN%unEy%H)T#A9hYZnBm0YE&W8S8A`7$T#Bwjc^J-t#E^f>EM z%hf^0*V2fM2?A$<33~J;qy~tZFk3#kPJnh=M+SfW_#Ah>ezlv|SXuB@!{ANzuea*L z<%>?6qinP(%&y-ef`TMqm5Y`C07AtJGMMXs-WXeWku(+dC&Yx!h4II9K`r{8k z5e`1S`pd}3C>m@6d+V%Zl#~Nt5wy-vh<@(1i+AJDWGHa*cv$a;9kJB?TR8pP2Q;5S<#+F9D(}oacWL1Q z7Z1R$#}X~sS}>MorF-n9^?h952i<5sJxs z7ZBJ#ResI)+iz7nGa?AlEz88-w}l{#WJH4)CATef+e0QCC&9m|-HQ#E+& zma@(eDU66nfPIB|{0*dP_8?+|?AboV3Jf4AQi%3NqA|EwLIIj8_5|&h#I3F774sb& zz+6A@nsR$K2wRwO_0h?)Z8NtM>PP~$F(c&PDWeu42)GB~gkXZQ*4EbYvo?*7k4Hnm z(sn-AzjS5*q{vu5d4Z_R;33k?kd)FSl2yA8-)V$Y zg3)t*S{cj)Snc?_I-!_rGzWPmzq4j+w(0NZb|*ip_TovpnQ+p5Z25U_mZQ_OGclIm z;a09b2JhwP_JvF^)l>#Pz>QjAh#Df zwOL%o5{1`d=}2zrd#~Ljy zF!4M2MU`RMMgsK=MlMSRs;*9aK2N;9Ty-qF7}?9jOB3V!6hcM!RYXRj^#(eNlL>QF;|T57 z$IDp5)qJjDu`v+B1Ol8M*y=$ttWhn!dO5>Cidc zl3*;y=}b0jL2L_$&enM74aK-Ba^9aDSc?|s6Iz;Qd117eeeMJ@_QUBQ%*t7qa>n;= zR~nBC5mwx?XM47L1BdZe`*lu0dds%yp%Uox@YRgjW`TKIW^j1w{hej%tgx~QActru z+^T+-Hm;37yzMor+0uzC(vi&?!BAbJl&D%}3g6*aHizjd&wI*cBTG&|vFXE&~Zlb@b>UkZj{#r!?!kOh5xSO@`ohJhUW{KR@!r#*bMc01mbS zbHpulBT}`B>sNuX{u-y)ck)aR^tKujS^?D8OG|ILw6uhPEw>yoqh?~E}%gppwkxhfm7kGJ;{({%AtgjIi zzCnCyCu+gqR8&!v{K)cOvTPa(1+iYoY9 zh&3SzhH-6dIhJ0EVktEfSJ0*~)9f=jbg?Dx@JHTb>l9gfV}{+lMl_6C$pXCr>_r8N zm!+VY&Cbb*ebeF!gR?lO8!Q%x4OKT^WcP}8ISvn?8Lm}|DT2dPfENU76U%muEY!z2 zWZ#y>^Khqq{&^Rc1kDP;OSf8)dK6Eq@UiEcAeUzjdgsE4X$A0zkXFP~Wn2v_FeVYW0=onWE~7Obuo?i*Skd6@1aCxl_h zvO6b^x;0@Ft^T5WFI4ZCZm4?XPA8SGkg#3-0?dikA7V5~cw5OXW%)=7LI`I-e0flo7#GP8}&qu1We8fA_%E#xEf>hjB6L$A zTwt9COwo|*_yp-+e;OA?(n;kiytI;X9yny1GGSM}jZWA(OGBXo@|4IZ#i!;Qnaia| zhx|SdKTCn#cGsgD$vol9W&!Kcf6F{kQ8k)LQWZ3azuQJGd!SsV!7dnd8zhEjU z9`rtdO3ZN1A@R(Y3?HAvJg$A@E+Lh#1kYKY&Dr$Rf|%l}G4g$99t3Z~84pXZ|3-gI zR_Vg4^#>uvUc>@qpoDQw)c&vd#NP~*Cjwe!XWh`0Qjz%ecR;9(j(-+jzcDN+DGB07 zE)fv}@X#6n=Ncr(*1fQ+3Q|raQ~E!IIC3Wlbq}zzH!P)S4rrlNuQK`g`n8B~@t;c2 z93YR{yL~PZ#xrWYqwN-yi!kDCeMUAjL0ywo(1oOb{}*))%K zx-hh2p!q}!)-gWw!!BrCz{wOozl(&MX!9VoeROnYxY~6c=e1HGjwFgspO^`HK2)t* zR9leulUQvP&Wvzd+)3*bgr*YpfL;@Di@Cc>?ZKPpVNZf$j$#H&TE@M{rumqV>dHM==c`PlGtfE6|ue`H{;TiWg5ph?VWT>Iv-Gs zZ4(O{e@HgshjSNXpTT#;BhMrC61mDOd-jz4NVR)y7FzbZUNR=K=`J01Odi^Nd#M|U zTQ66KH;7er|3zpE|IIKYKS%rAx!wXXw)Sn|{{$*SoR8X2GRy@~A*_slB)(PyI2>_= zZqsq_NAk}(4(F+dwot}3w!llq8h%ToF+f4(`yav*LPT4#keips9=dSm@qXV9PRgN5 zM}Pq6-MPVPbp4|=*XcuFVpG3uKeX=-WFVqXol2{KXe$NIy{!#LO)34}@6%e?=glcY zYhM4-iMw|!T(!{E(b+JOqA@-qfgOy`iF_qT6!AX+E0u8Uf}Zw82#wkPuf4&QUPZ@m zfT_EM#<-gVYLiwWJkU@h)h^7S?;p@*jtG(*ZZp3`)mZzTWL^KM_H?oTiqwM^EBIbx z??a{08x~8?%96=W0yXDepEz%WhE4pZ1eiCk4f6z%gAK8%J%mCkH|V-PNzpJ2tix#F zp@yAe5kr2e-9B4p@=xqPSVNIboko%g;W}0!=BdOEHrPrGaa_u5>($IUuH@E z8kDvkveE251cG#u+l5Oij1}>>jW<#{|IC2;`L7>0=oxsSX3P$Gbas_N(UJwbRg)JQ z-QE8+48xP$u}C6w+J{SRRq(dYpA$E9ZuZ1H7_*dv7cvnYGSh)~L8nyHeJYEXGPVDD zC6T}NYzDabkPn`|hnoPz_n&ioB&ya=4n*IvH2oTv7w~;co@PsRqd&O}%LrA4a1aLF zT?H{vrlxU#5|&j^(@+RRToy_7oL4p9bb&ovly{Y@m?4yIgb%H3cIgroM8$$F=AUYX zgb$M!eP_lK2lcSd;?`&f<1Y_owfF!gGJZ79T*y0bP?Kxe6QiS@YSVbSF zJh<--V}N*PZ4Az+L4YeWq1_&j5_O}}yH-7V78-rEKBVJvLe(zJ{W$n)FVgv22k8tkIlN(viyiH*po+Hgj& zcth!%*7U8Nm@Q-IOxO0aQ$vt^v`NN;F2DKKVvpp9**%q%NW2vstTt6il9>N%^(Qd1 z?sXmg`E{&(l-@f0pd$`DVTpL5me{CoY47}F@|5BEn73*Z~Ye&q!IV>S+_(#8a zuO$-$u;>04<`X!5`Y8xd;jp1pjWmYfho55!$ZnmlZ-QegK#}G!+u@G{Vf>O6!yXx* zbIiU~NSq-|0%&GAru`z4Bm!$Pk^*z-S!w^+(<~L>B5a26_1M3oOz88`JOTuW{(glc z;aabq3cLwK0*8v$tmw53olDTaA;lJ2iWDM1_o4g_R`U@+kt2upKsw_JHDS_N^m#00 zTorB=>ne=Y)KpT4aOz(TAOZiw(59L6QR<~kxKb&UJY2Rr>s$t}Zr^R)L@v5#fNj z1AaYs3C!BVwo)X_p2+I=?zRg(&~o!6P2>}$n2ffg;@edMc_s9gP@Mo*{f|fs^7#M? z0KX&5hwTR`{Zo^ZhY{QQXj%iUH{VxEpH&aSCSlwTVbKt$^XQ9poicFMDS-4K2L>QE z@`vf!a_qj|B5agI;Cz7@gJ_5#KniT`aBLHDbD-Aap*?qRR6bvlq$%@TV>!*~&rfF- zwmC4T22^-cw5|t`_r)fLi#7q~Z3Dh}#1GQ)hWy5hfxjX$5NN9E>7`pR$j=wRmVD{> zK_2p??y@;gf#QUTtv%bBj)v{^=DqGOdr0vl7Xv{f#3<45Z6{&NMil5lc!7~L=rHv1 z@&XXD3&2Dsa>6H1o`iX8`c_YUq-JI&g)nc(6B#@X?AhIZl%`>3M&GyjzN!5O-gN6> zEc4VE;yhlvRxcddzDa7d?srQUUt?R_IaxVugjUT|uS3(;%q}PxkZf}}E-73doY~QM zyQPkXyYB>EMNt8VmZ&xQbSC~CbTECn>tJHp@Cn8+|&6 zft4}|0RyDlrx7t1vQ}Voa&T`sPEY3W>Y{>#j~+on0N`XQ6%&idu?P9y{r&w-pdIXO zraO=a8)&WlT8~ph|5BbJKzwCT5&%@LntUZ2;<>JjUNcmJM2qAhUp9|FY_L12VZU4e zgap(-{ok)^VHqNo);!VW><^(+`X31LeW<+e;#jA`+f7ANXn3M@O6E|5WySOye(en@_E>tg>6LgFLj5nNheGh zcL`?wvC*T!Hma=-Ym=5Ff>0Fhzucz#BzV?-JK*}4|M?C<)7+;{joi!s#*04MS!;HO zV-A?;Lh%wYjmRfn25rUf$XvO|HJ)-DUSkQ8xc)H$6jTt#b-!p(;~ovFffsVKkpK&~ z=x4tBv^R9+pYi2NGB4Bl^vk-BIcTpCAP2Ag9Qvv$0~8TMoi-TTBoWkP*d(>9jic{_ zsJgF643A1CsDy|us=3uqDg5Kzggy!i5aynhm@?4$jUO~U zy_?i)g45MWzgztq>Ny+<5u%9haoXE$ZS&j=r`iyrR3rS>K8J9>%)vk8JWqm zKHyCKC+zPFIzJu}pj5hz#ejp;Hfdh7@DKm-mX2kVnQdl)(T`DGelXhnah7Y0OEG`} zmIE}^V(5AFwxE4#uQp{#;^M8a^%qYCrhGMZFSH?fKFQiDNjHwnq%@#1rm3^1?v2vo zpr#3@h1=hj1#HaD?0YWA4$CGRPps;*KDX~6>fyhNT6_7CgipB(l0!hnaZnI)YyxK^!hv>y*XhFb0D*e`J&Ma_7`jFUUnu`O zHtOmjQ56Z6J${d#M(gwd&?e0G*BjkkNM9M3DEp8ajNxzdTT+9+;A3uNG`_kFLV>?A zUrjSSS+~81dm)P9&anRF0#r<8 z5bb-m*JX7*u|i>g0!99slLjvhavW^Po@xc*H7Q#!Klm+Kbynn{<_MJQNz%@YVYqX53;@P^A$vWTyzF;C$uk?9e!n6i5AzQBEEF4$&FNRZTa;EFnI?fB_^`) zfj1ITbNmfBIE>!q#OYTVa&<3}!463NYV^|{jL~&V->zN6s|-PHW60s47J^mA+T9@4}kG*jM(*f1Uo1Y$FVL= zGZfHK;4k(eVInRY4j7^=dqJ`2mu(*$0zXcbIGcLRkUvL?dfXezQA#|tDGPbW%i%U2 zT&GxR9YSV54mE^RnJ0dBlU%I;H$3Kr0W9gjg%Z`>ax1H#4NpLbD=>yU@{XDP08X2D zND~J}=Yhgwm_vwB@wZ%#h+1JLa3D*YaqZ=x)I!;Plsvu%(_vGYnXPcwHB~mDU&mP@k*!C2 zA)21HgG1Fn+e#2fU5lp52`)0fToO9p+w4375qHRf)-6~J{iN7`1S~@5nb>Q$9&#bB zcij0MC_K2OrQ4(M=&*x=l}khEUCdTZQD8310OsJ(`aoPNS1H9d=}#KeKBCsL?^6!P z-nMx*K@Ox$3F~Lo$EQJZs%LnWe4@9aOyYzCpc$}_B~@Kxlz6+*RI#xzbnj~cr-DvN zjJsf#o#hHEKxI4i6f_hl!2Elmbp8FnZ~KO9F%p6YI)uqdlYcfN0LS~FN%)iR%?R>U z_+J2;L$vQAbTNYvk_bY|NSquxEWaSn5)->3V@Se>o6YgJqGSOjXZxp9t(%TqWW-=zrm8!j{}V@Y%YlUW!Op0$cXxw6D4GY!>>~? z;+$g-S9&DuRZr*?y)K)7fR|Db>-Uo2_m7_-6$*KHRfS!?tnh^k#J+>>90o-gBoaaz zwm)d(!fZR2y~BsvL5)knXaiI>56u0h+Wlef9a-n*AX=5Hn@SMwKuWZs4DraAN)M;O zVtoWjMNUgzkj^LLKGmlgzpUgf?&r8QR9XGIqd+#3CU%S z#pYfrH$$6}6ZrISJl)w+S-REJ+|H==o%zznB%2b({)Rw1~8gJ(GMFjzwyEYs-?Om zBwS?FFe(?i8wwdn3S}30_(@!whxRY8#*YenpTdtwdBFXaeo_K##-K1nselv%me6#- z&rm%;G(u*)-!N}J5Sdjw-JowE%_~igr>kO5fqXXmur)^ScAASyQt5DGnViLUq|6Wc}~3Qr)L_%yK9m44QlSj` z5OBdnBb$$Ou@rvbhGNUTew_!-9A>^{(l@H6Gh+ptGen4+f2xZlgJ@3m4snX6-U^&Q z+p0?MUs_z!(AFlt7_}kuwKE-Fma*(+@WWvwKa!qSomkxa{aE4+v+;Z@ z$uCdkk2VP+khwCI!_f=0td%5;p7#Hy>+HS;#Cs0NHNGd4Q8%ztKsp!203NKP37Hxz zQgATQJ`^xZ0pDJ6p;2J2CMHa{Nka2PW~Py$+v0T6)YF6FpLG*K)bjQD%Ttd|HFbY+ z1l$+4Hma#>97yq?h*aI?V+B5$x3WNi8R*m z^{``B#}2Nxg-jn@NajnvZD^*7xGkb>$`gh!Ql;7QL!5W}Nr0-6D00|PgFx5or>A?B z9$S}5az~m`IpOE4t1u0dAZ!j?)Dm98;kW=-$3mP1h$F-@U~nJYhU`bX56=sU$-gi7 zEB*OZ2%to{_DgMyKvMbS78zyRC$-cctlGHOA>$1+x)Wc#;&@bWawSOv&J~RNLsd+l zkz7PX6Yab6wI}9h8wyjP6o&Y&WF|l2L;fPnN68;AuZB69O@ZRLD2c{6pDPCxE;I+i z(2ilQR_17`iA><4iiDk^7FGkwY{Pn|vBE<({T1^NuQfSkX^K6z#R^!GgXksW9VAf@ zBr{7YziXg^OzbAim9$=Tny;BCfS^IuPxZe#0#7{fTQ$^H-D326y?Rj$Cl;zw%P%B}aGorCso3Yk2p{ z`1i~EnfLGC1@GufZ|~2I!DA2Vg1;(mp2wIurAB+<#_%|mEN~QDK&xX!)Ax21luk>p z%|F^(;pq(jdeW7;Fp^aq$R|$QANZ)adydLFa#$gIu#bsKNc2Oxts6VPTI%>jpcFgr zQQS(gk%HbEJR&mXpzj3)O1wrg?y_0_N)eAQ-s2+*ytFS%8QZ;Mrz1t9R@Vj6w>Ysx z<)cG7ho+hFB_@2{Vq&fp3dRRko^H|NUIJT7ccRTL&~*^4`|*sE)j^eDxE^f6#l|HdnK$uFE{i1 z>|h5?d#bL~o5BO5dwgoz8G)OxPD2bSlve6V4x-on(DB*Qi%vAC%cdNat4fI&nzZ5- zC!W>WghDKL(Mx8$)22%yeR9b44f-?-m}PtV@1MfN7KP`>nNfhNSMkxKN4G=ixOle$ znehjhL@;9`hH@ICQDUX0>eJJerV&0eeK~vYRkuYdwbi+#nl;!q?{5XNZlm8|N(JMy zi;#0u8(UlNGB!Ki4YBy&CywvxJd+(kon`Vvz202CT1TBbpK4;fmO~FCW#kdy zFCVEWzG|HOLWT-%dZzH@nIPdRKE^9IQEHJdY~^yzC?g{y-DW$)8cei?POVT;;AJX- zwfA~A4m;K9O}*&2!ZRp#y^K#3Yz_&!V^?tr6Ti>;#q)jbi>qYpD$?7mD4RkB{U{<* z%8Cr2R`a+CN4E4257%m5GMuPU=8AvBCBaRUtZ{#6SQi&nkjlE~^X}vi7YfuC_2MCp z(=$m0oi|wp>4ARD#Lya+QXy6`)+0+R zdWUGbJ_YVh^zTCpXz?Im9H~w=P$uAxIeCR`L7i<=YMi3hugNPmFZi)arua!3cP)Q@ zd%QbOgN?_k>8=;qC>SV4kFUvdTGVt4ERgHlYT=~k7{EQ9XuYjS$)_omtgoCUX(cnN zag_??*SeL|y&N^nf4PNF^o03~CsR@9s`-hpPNWG6^C}PHqe5JEd-14xzep^3yPV6tcr&qTc!E7WdNAoav_9+_U!YcxvZs zpv5*Ri=X3Ri+(Xt<%Doc?I>VyB}A)Ln^yWf;`-#gOCE|H{56UXmTpKL?r?+itjRw>Z-!{ z4j7keiLgGjX@q&)OJxCTi3>E_QWy=538CQl>oFAx$`33Lqv>l7pU^p4KXI{H$uV?A zcP~FPf3;%T%VVlDipQg?-WiCKsw+-WRu&jfy`2N_2Qr7-tp7cS1MGZ0kq0`i(T;Rj zN@>{M_<>w@tSZz`@r%VuZPmEj>Q>;%M5$dqE<@g@4g0~whVs#pPp>LC+_1mqMQQlZ zr%^Pdisb3=g(e{5_V@RVzVfL()6NgWji&#I7hk=??{(9~2IJu)V^T@wr&&~bC7^Lj zD|*>T#!4!-aELv`CGnh+^NJ>?MQ-LyI+^xx=Ow+IVU|BK9hs-%&IbGSU==kf*$9$3e9%?ZXi5>xO<@5lZh{ ze0mn#sb}w8=*PG@51DeFRRQ(z#4Xi0=B_$Yu2qdQ2iJE{d)GoZH6l1VpMl$-8!L1* zy5+&jwN7Pb*(g-#!_37PstZjP2?kO(MbhHcN^lCr*~-qzXA5^_T`x8g6@}NQCfpj= zP$t&!md6(u0$&5q@m(-I zGgB1aJE4FwaKUj~+2*3i=Y|=}cf5?ZKOcEabVq+ft?mjX&I0kp5JggoQh2QiS-W-x zW$ZcOgIheGsl``l!meLIo#zW)^msF?GC`+CEJ$@QT$6xGHBcJ9?yj9Tzm{X@i`FPw z@T}$JethEk8!HqFE?kiM$>_BdlYvWyz%T~Bi;dCo%oev8aq(;OI$ziPgD+!zyn#3V z3f*kfV+!}?LG7`O=!V+Rl&)e}ZMMSIhz|xhmL_gS<}$v4fxcACo@zZ+h9aB13ZWdz!LIHjT1sFvI&h({SR=4CoH zh`D$cAe0k4$bKtBSZ4VwI%F%Y%xRIVO&GY#jqIZnCu{sq`|v(bXIrv=S7r3Z#+g*y z37b7ss^;ZHrt2teSvZ~8G(L&KUaJ!l6ihi^%l$;9SQvNwg#%4a$}kdBh9kum%ax~@)n_|4A3hFm&Qd2o}e#uSW0SaKDgBPraUEPEZU5>{(M91y?bZy zaf?_BsMY{>sb+~dcJx^nE^x<6_Lo7R`yu5rtVc|(R`na;lZTOEG zVefY1N7Bl|3{A59>Su?q*LIY4upKEdCG1QDV(YI^=4v9FT;)g~_9hESm8Qqg__-5~ zuC7ED78dGBs&MDjCeE^Qa(ypf=E5_8KOJ;hwPRopu{l$-o<6ns4Ij#p=ZPK5@9bz!Ln{Jh}b%!{lo=s(J%f zv%0gX7ruFNnpXEeAHuo#L(%wkHQUC!z1W6Y#(f1Yo^6?xy2)(8ti+63#`FY+L?r^X z8uQcB{{>*2$qK^I`t|o9Ldwx7zr}|@dh`5N<7e_R6>?LoZIkm=wqNkmmlR$+N*)v> zTJoYvj*u$hTx9Wg4#=&A+J^J}4K?BBiOcEJdKd0{#dLQo|MTf4nn{#w;EGYt_&kSy zJoD;=b?MqI@@l>32-)x3W*6;H`{$hMxY0<$+`w$WkkPaJt<1<7t*|P}lUf6-fQI5b zN|>jl+%X3iWbjO^PO`APUQgpjo^iG`FGHHWi;E_>&)`8nW`39%3$^2~u_HCNNrNZO zwh38ReTl;~!$i80DvPO^$b_?NuxPnudbzOYYKk`~p(2G>j*KCf;|J+6YHdU0yVp-JD+xz@=v@t9~$grssN48mQ z3cJ&H~DY zVU`JSAcl;*g!A=JukjlqWdZWV#MfUnuz-IbGEwY&E=ttY--?HwXqq#lSXsq(^M1CV z1Pb-4I+%;$EeZu=aHWnM<67MNV*hNT#Lh5B#Gtg4(-1P8;iH6-X_4ZnA6AT*+;b&r z?>67I67}>jse65}QZSH3X-O#0+&CNdOO-b6ZD$}qv$T(1*Nx+E^=ut%e|`Z0<>&fI z;9H_@d(x2)Fs;-{#`E|j)H?^lQ<$g={^v5U`=oBi2H1dE-!I4aRU`}{tyjM z5+zdp>oGq)9st^0!CtZUr$e;7J`nIYEyU~G@LbOxx&Rbi(yUFy3*=gA&vZoc0i0$` z7yDPGKy8w2=|pzwnjG~vecic{b5&T?3iGanb!`n6*#I86ZSo9F$91E>@1s)b>d#wX zAUyk%kNpKQp)09F-(_aki)U9Q!xs-nE(^OW)M;G-;?B&?+c$B@weuBY+RiD;T#$;* zJ|ktcQmJ5RX&F7X_ImxC$nibR$5&PHre|b${xebH>4(jn$haR1q89BZFHSdm;7E@! zp#JDeCe*O@&OM<(p%UciL_X+G)O}g=YyV8{SY}>(Qw#39M;f(aU0rywOB@9=lr zI+&&@Lf%%TFqqpotgA-Hao~@&#T#!zFsJJ|dDfIv;HF{n5wkC=j+jBUoc1t<1q*AA zr4tc%j9?fwr^qHPkDtdtyUPA5AT_Y;KLUD>vTB-K~EekzsX?EyVHKMGqi{3|+ z6vMD8BVR4G0vW>D4xR5p)NF6qt+)qb8^Q>c42gB=q-6@u#<$}NB7wMsw?&EH`Lrxy? zyuT^50;6UMBUtvId*1*4JqRp;o1j^mU`}))t z6{f0WpsnLlcJ@^^onXqFmF&nOg+iHP+3>H>g}v#HGq@SZY0mS0`kH&Po-&Rpx4dOR zhw&MHynv7pi14aw3SB;a@AKE|>ra^pnP*&#Pe9fNHKF*Ls}?f_50>qpUEJzHhC<~B zrP#dMrOkbIuO~ORn!i$>y7s_NoJz%QGNf`C&1vea(qK-1*=DXTgt;N-e0x17n1uhe zN6q(u2hQ!Ey0h|U^Vwa@!-%PlZV4x2quVWZo*N`|Ww{rvqSmOeJY7jct;@uI2h<@B zXNU=YUQ4Q%6vtGT#&$MwU*E8Sjv%pi=uex2iMx~CI0}WiV!9P3>6Cu_?6S8tkKdke zMnkP@Q J(cyH?rgp^6}R`7{F>Zw`2(Z6#K)TNF59dK-)<%Cu$I&qk)vDmc!{g_ z*uVCsvMxVh1*6nGlp==G6T5S2lu>hHZgU8!5LK)Ms$wmL6VKN-w2GZ~Nf!gvM%_bkpxcdva?L`W!ITV!2Tz+h0=aZ87as+O zma(Pp*G?L(ed}2x@#l}^t_=Tn+xsQ^m-DvNx#76OouPjGK6p=_pU8B>z&{dl)(o=-<6w*`*EW|E6{=+TVNOu)o|H=$ zLj3J;K6N5x#k0DUuU3~|$ZM|+s-mOPPZ^&oOQ59%^{b}l)u zI37Dmx0}9W_oRJ7eiXLlTN>hR-2af*vi_F(B>wYr-!e?W??Y@M&5uBb8JAZJqvp1NZtp`K(^3529g zcO={6$B(!6_M&zvF@|8iv$3_ccXf4LH>xX!16vCU3VJ|tI9_HZ3ps8e)ZE(M77r32 z94Ei;-;>d?fEDT{9N4zV^)z48ZEe^)DXi~-w~n$J>JGb}L)_rb`e{SJ?vzvol z6L=!MBM(JrTYS|S4v0PS$T$Y)iSH`2hE&goU1yCpr7_Lsc}VQZ?J4=q%J_ZdhqDQK zeEk(#S;!VejNcl|fovMYf-Fd_LtJ|unnqfbA7ENcrg5riiKB5Bd2VbIJjld^T*SSp zBf^dTEdxdAqQeBa-f1|Ftv{~OOd)@su#`VGV+M~+;c-$)tu(6V^+!c0u#**nDbbTz z+$wV9d3<-Nx1SQqUOFD9YISsfFV~7(?!Lnr87q91Dr|q>jMdj~KPkI7{`?>5N_6!N5*HEtGCAEQx?kTmI{&Zf743v}yLg6en({d$nk@;>AREdeP+#>42K< zR`}WWP(yH|L&(rL?8eM<6vfE=ue<~D8zNm@-BzzR%>zG}?@5vHLl#Y;gzNj6CPx(N z^5x6O@s}VN(a5_5$p!bl-w}cfr^zxMrnuviWSZF~%?##*UwdEiP|RmN;z0`C`}Xb6 zhGH@p3rEwTVeL-_ud-h5AAQc^GLV@q^d6hXLWv`xeqVuP^uE$<=Ly!k=U*6)2oXaS z)phkM+X5a_Tql<9rta4z%)XOk0DGdQNWY!C$-@D@5{O14g4|zFgE3!@Ic{UN`Kt8_ zaSQbl;rjivvxhI`&D3A=Eqx#4h+2fd-4VD(d8eU442}pwP6hAqU)C&ZW;}E83BvBf znc$FI;s7B9h(qpeyKh#+?uI>J^`MDCeh_8X6O98WpYlZ>{oZT8Y7Ij|2QV?B62&#Y zC@nr4tq^AQDb0q9o%ild{X#|SijigWQ1liKeZ4h3G(HoI5BmxAmghX?Jv=)=q$}p3 zH-rc7JpCv*I#8@YZo!bE1X~b7%%x6MvtF=krpk0P2`1<~b)UGD z`V&TO({K1fc_X>!jEWQTYV8I84{Pro&*l674Zll;vIeucBYxu~cpTqOH!W9L6BS_x{_2)hO?$K`gsdYBOL z$mK`sDC)?IzkH#GL-UD@0rILf#m|lH$bga&TZ^0BUs`b*GdR(`p9O(7O{%0gY18wF?op zIa~!BckXAjJutOhTU(c}dP)A$!;KwPPGBv3y&v+tx6CQg)n~kBP-B>BudycOV9V_X zpC=S6a*DP+=G?V1H2RC7Mpe@^7s)CmM)q?bYPzoQF!EV-HzVt*ifkWFnIn;|XMwHE znyk)WI3I+pJW0ICZp(nwrH5-S9G`9<<3wW2uQx{LV!rcfx8+fV)W+B132_k)cR!y0 z1C1$gOJe^+rAQCS<%H-So0F51kDng~jIvXp*rc$u2z{)tXJh1%t!)wrS2n(U`Qze2 z*o7c=@s*UTbJ~7!#?>%4gQ|sR{@VeJ5_BiB`ckN>_#Pt3r z!YjkaoHhp;RgJV8_3G^y*s5jLKKI?wuB$N>F_+lhbSlQ#fG5rG&o_T+yn)OMDUtlD zx`$q$!~@qq^jLd*<=KYFKZj#9aE$%ttKMzd*0*AF{T$Ex-#I14SsNvO$}G%JZf|y; z<`|byt99!4)pX4F3#i^8a{v0iscyaIYt!51p4zN%mT^cbpy@sSwVjpyQT#!R=6`22A^B^dSdW_H6`jm5NE)y3>6OkW3evsA-zkdJDHc&#Z$7ASNawgT_ z;N-7H=9uHR<-mEbq@uDz+bfh(Jkv@wa4n}p7o(aOg*2-sNz1nU7N$5rk}o!lJ8`d~ zWU(w#rB7@a+hN87U;%g^;M33~YXi6mo<&0MKstP+OV_(LNuj=_v4Aqa`4;!W2x-K^VfLAO|=- zW!5|GT3M6vqdc_2#CE$e7gk4`yI*SEMA3 z1NVwM?FFpr`3vJJ64uQRmSI1HB{6W$#f0u^D@>B4jkD`2>%a)COFuJqQVLsJxzRbh zhuaX`Hoz|8`AE0%icr4(`cxjT1*PJ?uPAz*8no*HHRz6_r)QVHzeM?iTY}$r671|> zctX7*=*Uau(!3A5UT>)Pi_};rVaIjb!mgn6YO$+~;t~hl;nTlq#rC9AQ+#CG0wSwC z{FQ~gu{7_8<~fA%i{*}Q%vWe7@iwbAcvVb>owk@jNw%&9oo?+cr+}Q$KD||)Eo%x2 z>-cH@%AR9L*iF0h+i6a>1;EncwYsZuh3JsB&CWlscHC8!21Zrv+do^O{-N{7n?vpA zxhd;RweX?}l^^&l$1gsi-x3kBpqtMxv}?yWH_VF`by-D0Q)X9BQM6=*#;((Kf1c&v zrA1cuT{viiATY!z&r(#WK`b5KT4AXD>`lE4g$_wQukfsmjN;GI-&?vP4{sF%NyE@T zFdt};u2kOme7&pWpG`{N=1*I&QYdVJg>nHg2(}6v)7NiEx9&VytlQnaE#|)Cx)(iu zF-bEfcmR=%Q%fE(6pEHf1m8ia4MVB#e8bFmAJxs#O$q1R5VwkUC*BuOEEi&y7p7KL z^cEIFY7ucyRq)f+dtttw(N<6LdcCiAC=H&$)#XZRDHN$8Ui4t~`l*?0Bq;2{xAJ1( z@ziV9@cCHR=MS6MWvcS26$1U3MuQBp4ZCz>+RxXp*LQl9b=~t;{K`;t=a=23vu@II zUv)5&ax05%7ua-Fplxg_!_!T){`wa|x2cq^9FAHn-y3dym3JLPow^ zQDNx&^-JRFdX7YyY;8%8-{Z-H*H-hjtvWLhcPHkA+^!_e2d`as!ws)t;8a|OEI}?6 zg&bOnp>zNM3omUHXT?8pEdftf=%Ak~b(0L#_1pH{M6< z&t!*#kA;PW>9|&sC9V*f8u(5v7yY_SfRNbXx&JQ{spm76jb~c&7?^8Iowt@=H2GYP zkQgPUYl0JALt&tqvz2}M@}EGM;6XP@dkwzSd2_*!AzBeLU$Zs{? z$0(kxoSfpXDc!T0ipFT|j|};GWMpKte1biC>qMX4(caTzM)IM5XbwLwZz0Kt;PWwS zO%BSK%lo*NH_XmG`O#Fo%i+cHf9lo|INOKmd0%z3yf#MtvYY--80aTPxR9Qav1*t5 zQS;?T0A#+SRfyxC)tQgM0eye{&jIb$esh5y2#xH*s64XKP12m+%cvQX+~nj)8xI2m zL;FO+7720#iaOWO#Rj3tUdcunX+B4sJ7$3sUjF{nGBPsN&yTGmQLLT^ib6&wCcN?E zg@uLJD2S&N0?yAieZOFXf3g{JAUEy3Y4>>=0o6z9F(=)H2b_j^QP&DZ;GgcWQ@4cO#yXFY%mfUuf`fy^%TrOoXm7Cv>Py5QH-m$J z%scgc0&IimU z(Tq+f?+e(+)LR{GZQcV1a!>0Zr^0oOzr9fE;Sj?o)s%k!%v^_A>(BlN;c*EcK$GPC zpc8I6!H0o9;@h_`OT7ngZnc8CBxd%Sq_OKJza?>^c-gblJK~2cG3{4AL0zpS{5q!T z=_bF-5uyT~P{vM<4$InwiK;^oq5PTreoCn2L&wIA8~qrrlRNJYCmaMSN7!rpemq&CS-)n4QqGlSVO4;^UIRhI2JD?Bcl7x2;twBI5j8IY;((A4f4OL{ zh=>SsXb=?MdGH`!s`k;n;1dFv5Mwk3(XD}fuQ?0o5Tbzv zU&4pImFGIu8SI1&-(29y6Deu6yLuBj#Xd~b)YNd{?9xHz0fU()TeolDvVVVLwXwQN z?FFIZC{c};kH62db~S{Q&R&`d|<^Qw$4&VZc^W5KN_A{jEDZM!Nr^$sUgJ zVa&DgquX@DV{vW;<%5ZM@BfGxw9ppp=|MO9q!#k+drR@&j4dnzT_$=pa+FGxzA$6Z z4SKMt=?N>K%MHxrC2aTbtO#yM(Cf=N0zA7ia(btIkjW?Nbih6&RJ4hn*nFxW&VJm z)N&P(0|-kAWmK|p`QTT^<4x^F*;9_A(((`f{gm~eAzpqta*El{m;I;QHa~_^qh&VMhv}eJihdEQnTyzw&Hi`u*;o2(VG!U52}LAv z*?GU)cy|B+FmV?CA>6UI({`UtszYBkIak*layT&G8h@5C&jfpbKCt!!2ldv#ikN3U zf{TTBMu8w5Gcz-em}X9ZNFH+5|8a67fr!9+8kv}w5T~)>WY6d3=H_ZX7F6v1pJuCf z@n!0Qtmt7!lDVDC2g1b#$ohmvWZJalm!+yH(8k&KV#GU%2L3-JeL!mQYB9QL4bViI zRV(HFX`nkPttJH>q}eSwBU? z=ekiu|0Z&cv%9@mgwRbBxjws?5ye@C`}3DCH3goYIH@fA8y6PmXO=C1uCBcqkBswS z=I`KE6W$UB1|IK-NCXEF4g9xp;78k0Kyr_cx72m|*CtFK7_3Vw|DtMaZEau~DZYF6 z%Ba6>W*4E`2VFPTU$e#hBAy{J7THk;8up+YDA`+=7?#P_{#gtY>*(ph=oH=;ys}Y@ z==QB;{asv2s$OqEBjUB7D#sOD+r#J2cSo*1F$w*lSQ+~N7$YL#(rwH#!Bl}GF*T_~)7| zZqhJYbvqK)%Sa{CwRf++Y-Fc9*xPGd>a zw!}@awp8ZTUicJ&!a8efkIH9LW=5+U6|1+8Kla$8{JRkDprOh}ji{X=i}WE5#`=GKcPupYVv>@_AvuYVP!kN_ z*7}it?OIY;L23W|a8JBpfyeE98WKMw9>~#H1~Curx$qfKBNdn=HDLg&(vE}ec*{mX z*b7}>pBxw+8(R_TS-$|F7VULi=g|-XBG+|rX2&2z-?>cs{(K~6yb|_Ut3UU(K!W7K z$9O9zqh!CR3SKrhzXhXQ+}N01F&-gut{D3ddK@fyb;`K|KtSz7QgzKwzI;=c-_? z1AaI+Y=RtYY-C;~-DU}z;NW09Un!uY(@5|E4>ir zgn+~^;pg$59!Xp-KKIMvSHpU2YKRGaE6Ykti&|QqsPL{{wW_eDX1%z$xc(+nQ`0~5 z^D!r%NH-VUU{k{zRCLsMXMIlMQniFB6(;)KtgJL{hw@(f1r1;l)IS2p~Q*b+AK zu12H}+6tbD{%m}viE?=B2G*Yjwj5Da{`d^H>@_0Nf~kX2u>%7U-OV`P*|d3hzlqo8 z=H`ZYCaD%!DLRp$8QNY4b`3T3OT8NwMvo}&*hph~OpZ`40-jCMKKy&Fe%^C6SdJ;w zSJVl1OpdRB~XxCy0c{|L)nb##Qr#ruXS-yyi z``))4vrZ1xGp@W&p=RWM<?Of)l-gMtum|qDTG-2*%A3wqz(jiKf~iDk&+6h_J);OuV`v zu2R&{kX4bs`{4r2|i0!k>0|!}wHu^bcBz5 zU>Mk z+~L@aZEP4}HV~8}v4)F3)`0vD@ z)h%yqJ0SFWU@_jMO{N_D<{IN%AdfXA|no0U9{Fxp|j&9ofM`KxQ=46-s{( zir~$gJyP!1V`8?vP#b8Xdq~cijZTlQ>q|}=-(r5$*?mpAt8~wj$P&I2N!l*mjQE-` z{kwM|D#W`UIJX^cg7rXKELdZ?==LSeq&od^BKb)Oo|_p7L2M5tPknVX`-#i}VwQbE z$N9pQ5%eubb1&|8PR#mRkFsW}meVoN#CP40wO#8Cj~c{Z`b0bRnKmlpQZ=Lsnw z(m1@t8K<7wqk@==|IR&gM+-Gqyhp+CTR$!LpG{DQb;dx(M zZ7xvea=)^yQVXp|o0e5D*Lenc_2JH?CtWBz=A?e#x;IVj#pgW<0+r`9~5T&C3{`G8^Spy68s~|VT)<>!pC%o_f z_dXN$QBgAp*Jm^B=upr8-I5y-S$m&sX4bAnkncT@Z>4(w_lslyr4?%V4z$|`qfRWJ z#O41~{%@(uhRxCxT*BwxIWTB|fe=8{nmfim8u6>@lgBwi5)M`hp^c=2B6(guS0oiE z_Jn~In##O~6M+pKE$_PPUl^9i8>>beaCBw1zIFfUTd{TK(NI&U3*JHZf7mm?wKz|k zq$S8V*4kpPBbI;xKt8Q4YN+xMVt}7Nb1IuZEK8}F)JlOH@MF+Ywefq&mKJ%C(#zsn z&gGAv20(`1<3>T7r0>-~KfNRL-n~`-N&BGi);7cweKT1drEOF} zQs}<5OmO-N^MCCAu7qKGPfU>&jZ%rOrREYHmvZbRhCO%0gvb^noD58k0f}ykK(x898LiCXW_0YkE7oUN zw>)xNh_uPSuAAJ;ZJ6V*6D3QcdVrUr%f?#-;8IRz_3wBw?p zqR!6FygPQ#yt8iU?DeEYEQgjB*gs4Y5)Cuv9d2A7&ygBQPEU>|7L$b zKYRoc;xT*sbc0A??UA!_C+p@@3;?Nt6$*SclPKr3OQVr_Z|t3?PYniAV{v%C^mKQ( zJ`V<8U3ZSabFjYFKGV=U0XPEpfy<`_ms2f*J0&ac>>%Mc%GT5sAn7 z?-&6!BX2DG%GZw{iy{m?bo#yLP}NFOPE1J>O&Nw1!5gsvr4Z2Lbq*=(_UYeE05*?< z(HO-H1N(CDKn%z&*R!&g9;`z6rxjbkbUcKc93{0bEGa5|Ebua~VkR zxljTkA9wOeb+VomRw?fE^EW_dAh=~lLK}d`VwuOF$Zf3I2N75zMv0Z?kuwlrqJdum zB2IRx{of^bqSV0b2ItBCGXYU|NJGSggq`U2Z3m2(OuL&ODe#B=KFVK&Bx?-wh1-GQ zG@uFC$;bPy19(N^O7dv3xRRC)Uh(leK0n^iypB!esDpz;=4|lJ9YR9uNJ|QGQ+SBx zJU2Fkk_59^xA@us$hdWL)^%XN{DOjOalKMhHhpz3`xzlT#KYU`R&K|r(|bW0ayHkeV08q*BK}}nIGo0P~VMr=Xd!H(8ft&92IDA;KA;MUXdy zSfrJMHjX*;n~Aj-5g}_vJIBPt)IfV8SFe_n{RICYhB=+Xu9pzBxE{=O8fpFVyN$fv zhBWz+36s)Wtj(F4h`H?W!c5xwU~bX!P;h)N9XD zq2K%rov5>HJ8rCq*&T%3u^n0O1BuIO0X-x)1XaNgpGXn=Us4#1Gt2V)D@(*b6o7eU zT88ev*c4P1BdTkUtTJPY}*C2ZUgaG5Fz3b;-_s6j8E8H^9 zVt}C6h3{;*3Y4e*DpJGMlzkoxSq3eb3;ft3B1rtPto!CI$Ai4^QOaUY@*wdpxX~Ca z+A&#?JBsM^r+dE1ILPjII>KVg_|R845SYi+rgX#(;p|TSG(!Lqn{=RMt)S@lOA-hf zNVht4PW`ypH8D}piL-SJmu<)4Y5>>6ebChhK=XBRzep(?6PQr#G?Ud>Pk*H_0b|s6 z-3$q#ru+i;K!8RpukCR2{z;HyxGT}gE=<9^yG{GOb6V>$wUZ|qK+E_54$G`A4bvA< zFfcH1(O92Qwrb@HihA~e0oIR|m22UJiV*pnfuT3)8>@CNz=>Or1P)|9$OKW_o=R4T zG4-?UR|kocq#yrud-Gm7VsNK_>4=pPvzuXFxAr!N6xO%|=xtMn zP!PcgCOWQ-K23wT5ginKRW7hN;Io9-K^{egPr>M-|1D9WZP~MjIXXIextkFajggs| zMtp0g?k?H+Uwe@DLDTdMOTQ0Ni_BMAParw5pW18NUdZn$%%&lQ-h)L*MBp54!GL63 zs#!^HDypiS+!QpLkN4TFL?`0vq1~9icmCqVuBFBKx*Y2_>FKZ0Du#|-j3Q^;k?;tA z$7{VMnvYfucW7{c9srwx`lT#W>GpANG_asU@>Su=xOQ#l;G#F9Ea(ijlTT%I%w!`C z9d~6yq?Tu~6G0LMb7|N7%((2`RUXncQt%gmKHbvNybTi1=6rV!yxua#LknK;G-M71 zG3>~vir7owyzBQ_L;3_@tZ4U`q?%>Bmcy8W zY`w6lNzcX*EQ(_6H>sdjTU5fpF~vIZMIz+m;fazn<2mFJ`ix^-j70>%y70l^Gj>Ap zGVXIM;4+clew(N6X50%0T8i<;Yq>)}rcl2}GqWt`G5-2_A(#+UDO0@uVTeyvew@p5 zWgU4x!NAGIwSp3b&T{8>>!js9M!%{Mm}B8{IhcfbUquqrRM4O;O65j)?|Yh^@#&{- znbSB9{Q#xdE;Q7OvBb?6U!+(TRaHReiy(<*ZWsZ;3R60g)P#hbN^ARFrWbK7baJxxz)RF^AfBp$ z<*$LvFa?oav5~3i*?0rPIswF{ftF#2O@hSHKjpb`iC^;s%e8ZfZ(t>3;F~Vhr*K8Q z5beh;BDJKUp{Y*dSSAw*(B$MaqxHkBXasGRCIss)1kRsZ&&Ea#Y7g(O0acs(dv!o?EA;b{i*ek`6jf&w%52S1i?ntp?(PQ1|ftAWtSO8KgLTEe+@Bf|A ztNt|_bi(iP$;fb`@lte-P%QAnVgMrKV3+ta(nHd~D)pG;0Dw_o1UPr^-re5SRkE!O zuvkZX`!lDw@1H&q%h)lUtd;uIz|FzQsn1#&=u`}33(cl-(M%#DBBiYvVp3A&n408+ z=1rU7yN;bdzm?sTOqMdbe0f&454cZw5u6UuA}##s3OTlPb9<@{2JASGwy!|O8iZVe zp8rloHJ^qHtbHk-{59+nzyIs7YdPINKM}dxg$sTN_2mS^i$8q|RM{AZZvS7uo=kn> z>g-HH8KWS~A*UxNHTQm^-a8SgXeWs4A#5$*HV6Xk@slSik4N>S_%9+3A?o*F?t`~r z+=jNPt$WUW(kImX2x8Fj?vI$VjWDI(hyZptopJ(UcJkSZABm#^e@cR`d(Y&@qVQlGBMiiZ*k9KH(UNdoEogVGfj;+aR;p%zumqqi568#jJix z5KNZ~(a^X*_s&Q-{KJ5@KLjw>xjtxW3Wenc8*Gd)6zJ|Vi=P2gC>6!L^WeEBnO033vBG8OTt$Jjj!Hoz`%9!0QZt9ca4&G(P+tJ2u>-*hu~dE5 zy2MVA3V(VIKkO}g`86a4G5d7C!xOTECc2QYurO~}$3#`E3}d}e!KnJ9;W_K9Ek%uZGSelpdrZK%f#F%pms!hVDcX)(dnLDu{! zLNo@;aEw9kRl+TmN;)oS7V;Q4(YQnShXAAj>$fNbOK9`=!C?a^IgD@})1uZmySi>{ zDoNOFsB-2}{-Z09EFf5$e5(|z1{{%N$Medn`uca()m;BrZtDliMIU|Oq_Svx78VzQ zv_-$Oftdl{C%A9$mJ~+6p}M>sUEGDHRgw@gE>q;xZ&HT67N-BS7f}t1x2)BGwM~81 zl$eNAwYhtdk~I+oKO)rE*QXV~&Y|e-S^eDFgJ9ghelc6M@Xq9z-2}JBRkXywxluq@ zBQo>lm3G=nDeLBT*Uv03B6F3c7}=k-U3K`K{_OeSNrie5M>ekTV|KJ0C>F}A&xRvw zA9K!f8sTO)P1XP1ugW7D-~L$j(BdSsW{Q4-RNUD!@#U1Aw<}%y-U_0~pig#-_WP)@7_Py54DNw)M%pQKv&n<^N|^^UsK?g zz|>C)=!^WxQCA@L99)*i;uLuEw3PsPe07t{U@#K3LUQ)|IZOnum~}KYFnwIsT(3y(g1^J|t&aRYU=}`$ z!l~oOD}AlgI}!M;ppbcc2xT`jOx=Y6Ap&T0DR)zO3XtwAP1Da0(?v*%l;2kX|sFA&`X0a?x~73Cz~LNZUZ z`;hxZ&DVog1=MTWmhX=3L9PVG+Q`c4HpWxsy?x`c=<3j4y+VdN#3CZ+LmOb7^)M^s zrpZ-v9Exz&K+A}5B|E432w|P)5tiQ_+1|Z=?G+ZrtakS7R8FV@-9RbAw2r~dOf=>i|Z5P%xc(*|z znDYq87R+OQqP)1JPnBo}NEa`}yz90_l4S5-e)|k*H57c*@Er&O;UTi;-lnA;IeeJn zG26oiHWx`LYoZf)S~z4t>0Ky2*mv!Z80bcItD!0Q29gENc#vcl4g{XWAE(Y3vvdH2 zmi|x$*T~02Q=RALgjV`gXts`Ffu-Q>5aszI={U^G#Nv9+V_{5T*fAm4F&FkZZzvYh zg<`NYh-c)8u|mf&s157`s+1$;rFWmUS}t z34lO1w|RbMBTOa`UHa3-_&HK4WT-~s^0%!b1C#iKT>V*y0>23?Cbk_2?gtTGjyem{ z7?^}Fu(d})n2H${Qsn(L6G#e!@i?}sI1*X!$678V&?L?UAdiH_4?5t6%>{rY zY#qzl<)Og=A{K-UG zPrv2s(}R%~5ikb6s2ojm)Z@a{g7r4Peew1?RKj$db}u_uJS3UL^?CW1b8;+N-|yFQ zf=x(qV2jPk?G#1121!Vq_%hHS9i)$~3R~pEhsKL@Q~JAy(R}$*WBN8NDL4eI7&YZF z6gxn0o4azhc{zT}f;EEt=H8I_SUMCDWS;hq+`v~V(gUsr;nVMFYHBV?AV96vH24m? zndLM7O6Ls_uGKs|JQ7wv!H=xPR`Wi+gHZlR9B(1`Eu!1**83)9l-e|p+8Z>r2Uj)Ff@4MV`s#{iTh74OG`*w0vP z-)@L`i_p8FMG6(!OE^^I45El+^jG5^GN}fIF6083Bn#N8uuc^rb#~5W1nOjSAc8MK zxQa`bO0``pNuHZQ587v_0lw7qb{DpV9jb^3XAjzSw*xEt2PiStGc!{qYkt}$j1c=K zvU|i~BniZC+z1b6B^|>_0atlBHCZi4#l^{L`O2ZdU?bc?8SscK*C=U6^E*QOG&OTt z^+i9SMFaM5mc}nh>t-4pVW5LCfEq}EjOy5|qbteaU3kD9%zAD<%PgxAI;{B1)>a*Dz5ucMEOl(@Ws$P?w5?@?9jX}n4A zh<*7o{rUZVfw~auohoV#2v`D1IY?aaNkv7)v^-8^!^T!tUv@<=rQe-SKI!224lGz6 zW5M08hWvNHsP!j%XjCT+z7-T4d`xh9n}Wsrkui#W8diO(gKu--xGrH^V|7ZVx>vRS za@T1bLL+KMgtxwLRkm(8@kE*xcqP+XU;IIG!u?3fh5^1D=;4V+6130~paK-DS^E4~ zVAi@3>_TI_J^U_sY<9HZUqi?KEP^Hr#^NjD3$FNjxFQSnaDnPu1$(bMadl6;d7N=B6JZNM;Lwnuv{%JOWU%DS%=5=t^T zv}L5Npo;#i$fvHoMK~U#sHXxR@KiO;&k-S7)t%GVdD@q@lstPDy3I$X^=W&(Dy!jy33%Iw>}KW|!fQ?`CE&Lz~*6#UPQnWIk7>zw%T z61>wY3eMGawA`z`x$2Ay>6U;UeF9}uh=Yz}hvDpxZg%5S*CaX^mxg7yvZLMSugxWS>5Igb62s{Gy8ZdU({1?r!z&e7sP#gp3B$h;VITk*m(!p;4A| z0MIe5%&+QDlovedy`c#z+(_Mb453Yf3tWH%R^vjvipCg5Mlcd_B@j7_+L=?wfSYZC zaz+fN+`Ose6B98oaB`+d68FCKum6Pj1HA05Yr-h&-Tm-YLx(jiBeR_g*OAf2Osa%9@ zqON;8D7(V)0+?+X<6kI5U%&b>l3P~9lw6jojrS2-hW&dnxA(34H$)# zvgFQiWh{TyK3H!snZBheZ)i5)Ga8}L&H%%?1HtaY6KjBQWa`o1{z!e~7 znyakqzojfW@OI#St^i>>cQOKmMRrO?|12vjU1Ia^ZwM?k)4iP%fBySmodsA&ALy*q zFP*agw~iVQ($dB(R5$8u$`fpyob6*ioqN?v-k+#NL(LVPc9~(YasTPcx}>|R@F=v< zC22*|Zr}W`RgL#ftMmX@sbxjZy_1*2RH1W`>-tZRPq*A>2#?Z}l)m*}Q$8#OMLE(E zf?BddK%eEp{Fce^d(SdH3befF0792otyHqC7K1iq(DANn!Q%7*c2gh}4_u{}zlfy^ zKlNP7-V8n#Ttjm8)8)21P`~P(y?_2Y9GwxQbN$uJSayy%|5=>T@O713T$&ebXlQsQ?{6d@wT67FVk}C^HbR1M<)Z*{tGyOK2De4R zL}{iPoJB!alxGUb+<-u7MVAcvuwFfnFF;!un41sFL{p}@9oybrx~oV`hZSjFI6?Xu zg6opH7^ZD`13X=f`~20{0QkChlAH5U7X;u?o;xz_1G%n`Tt}eQ7qT0&va{`ys%Gcr zp2IUiOt=ccD?SgxWBCW^>50`YyIuZi4)TkidXZ~&tBREg{xTsT*2@b$xhtOvOHn&m z5us6Ns?FMUiQVzo*oCd9-#a;FNJ9Ch@u-;ob zsx@@+p%0S0Nh<;1uzgig3eZ$aI}OtT(UJGpFq3f}r7&{KR5xKBr=O(NPk?2^nZcQW z1WA<$-V$*qe-_f)nskta60!>!4^Sa)B#vgX8~Ip!Z~~p7v1K=HmIQzaS6l>l=5=Si z+QH`cF3bq-K*2%o0RqM~ga-~1hRYV6OmMRXvw0I358VOok$|Y+OH~T4-vZ-g2XaLp zoe!WU!S*;v=6GY|Q-5`$50dGBbi#2Pw=e-!!e{H&6U?9tN>Iou#=zH+ILa(PU)pm? z=EB#<0Z5Kb7ygXZE(vMhR@wLEfp|>CZx9Rx3*+ltC;P|lZy_XuAM~erQ+?ZB@IOA} z$&JVJ1yPC+bIOC)rs*v`x|gh?=;t_mJk6(h;Oxs2MKD=eJZ>n5`han1)tMGruMF}w zBRfX2{@{?OEI`e!kV8A{#_j??HN~59J@mkl1h|-Uto3JgFOtx)h1n~_ag$&PdB)X| z*4wCGA`BDUSzP^Xb!B9`9}wFGUi(XiGoZ%BC7P##^*6#{BKXR?sF;Z=*HK7Uko%DW zVtlU8!?L`zl$6V$7Y8Bys}nVnq;hJGj6VVTNb+p}(b-n0b@WA)4|_)>a>%2}cr6I8 z3FS`8kP_!tLJ>vg0)XI<3(-N<1t_pfc>BIT8u9&h^V3#D%nE?``sp1XkQkAW^jAwR zThz62+HH>owQsvC5Y;CVzy8*|19*9;h#5t)PO@j4y}<;cEMYwK0Wf9T;KT;CGe`%i zKqa8G2KP2_@ao*;Wy)bx_MoK2{N`xl6Rg{nk1r9qlSZ}dve)(za4$;sHZsKP;%K>k z16*S#=*448JGRckxe+*oVCJwN)D*%>Bic_Mbv=MRK`=oa2kRzmToR{~Sj%C(ir_); z7{iO-kyBtn$kaAcLY1NWTQ9IH%5$j*A^>r-%6oE7KWGzfNnia%1B8uIh(vwz=M>1{ z>=`X^bFHrS@1!fge(k-h3e!e*0(KxR3Ni?@4UH|{3dIXi6}eTYOJj$UHh?d?G4DZY z+Q7o20Rvrvx^8rp8&;Tr!XOBm=g%YNXGc_aU)1nqUR~>r&GjWp7i;6s^tTlpqVR%& z$5Z-)F+jzY80PV&SeO?XAIh;>T3UkHs2932GklH{1zL#E2&D#y+i+6}=ou1LqVzS- zZ3P9R#FC!v5$OI6CW0Ev2kCYo;ad@XQ^wmaLI}Q!LdUr;GRY~vy(RqSqusyOMx4EVRrc~8AorB)GOIdY3S%aV9rODE7Uw>QDMWx!4Zgd8zy*4TRXcIp?`BQ{E#z9oXCx#3zfroL?i-U z2(a=NnNy6eUcE{VTvqp2(7B{)Ic3H##RJUE7^FIUkl3PRiPXO#M$zp|vuVlkMa3E~ zLPY$W)D02HQ2NI?hk=~og8AWGQV)X7he&W0l)_M~q0Jalc%MKsq#{N^(cqj(#=0y` zZ+LT&k=>M(EaY7Mt$gsHqOV26Yd)Bh(#3WzY}BR28AG!ANR3O87my59!m@!0F9#O1 z+4j9|rg=1wsAXBhZbZsB^*#VlWoNsYMFKSLlH|GV+62ACW_tKXY6{U9EYfV(c%-;h z4z?a7?1}|AS6A9*Kp=;c>Yl^+#cL#O#Q8wLMbc-uQjt|AcsAcog)q5I+R;TyZpfa zbRg>;ot`cOV~bV?#;>oW?m(O`kCIJ5e0JEjqYjj9YjBi zyk!L@lx#RO<2Zj*L7@bgBf$TosA`7}ChcMsw{_MGECZom2r+;h6jIwGkWr&L1w){4 zlMs~!V58lr>gxh@hsW_oLZ9O>uo3|dM80GkRaMj>1t2h_=(FV8xA9#+OjyGoJ*qU{ z02kT|tbTg9CFxxrppefHzawfI2@)>@jk_-VyOkcM`?R(Z$}u=;UJq(i_HxvZe`#)M z$^HvF^e{U&_YLx{dlQ?O$EV#f0Fzu|H8u4ffDP=XmRN{px2=2jgo&3m;Ke+A_wHTm zd^!0R-P1u6aDkso zO1vS@NkeeBQRf1gJ$7(#X_xkEugjUbSK%Y&Qm=zlKj1jXg6m=5ShqzS8geL)XK-+2 z;pF^`J^0QYFNj=bc6QdT$#}Onn~He+uUckCMGA(lidXkuYNZAqFoo78PCrtFkE^+Ha>0wP;8q=FMLo0 zc|t&jyTOA`qZnxqBnXmNgS;%7gj;~=J`L8VR^bgaV&9U50+{*=bPZYg)ydeBi*mRg}FiR9x$|LQtyEn<5X&>_Jo%54kLWZe!G@ z0eQOBq&a%UaOMxfGi{i!hu9mrjzb%ve)$Czq3^UEejBIG3=4EDpAov_2Eg4q*xEWE z+jLMI#36373KMEf5aAf3q8v{RLZJjtooz~%X7oK zC8}@uGjh@iWB`lthdgDuZJXf)I?k)otV;G&oSay{WdSPBT@|fA74W=h`UxaBzbGnfc+*SD*@F zh#nGYm8z9<5M$nr_feAv5QmW(9};n7e4q4TJs>hj$#Cf0Aj(3iC(?*H09!V zqJzQKzh2cBgSN4xNeBbC%uQIbno()BSWFEK;Ls%XvtU6UA8=v-HV4R&yeWu#052j^ zuX(A;(VXkV2yx{N#jV8&XXBW7cno_-t|Lak#J5$*D*oUM5|`ZY^Cv%^8HMMmnj@l! zJBdY>iw=!^=8Cy(S*}xpGn=}BDMOf3*wVsXO8fTW21gCHH|la=KLg+aI)kEPj* z@M)Yf&P>oxeVgZe^X^?aPE#5Nnc?S0$RZ{1)FB$3q6z;ltX4?gHR6?G@#jAPTs zM-VQ7^g#l#1{x(~S6v9$0kRY20I5wPd<{?m_@iP3XH1-&bokGGM~gj zBXlY$^T+5jN+Gz4YT%CbvOj!inX3W+B8Ghu86hAjm{^gH0Noe8JywUELcdml}L zFmY6fE-5(P)QGTcyUOS(hf@vmvUs==smO!>A~XtOa|lfXBL5DkauUTJnchbr|B<{I zZmEpatN^PdLBu(LHOV$We*&UO!E%!Hl1PMt#I<#Gi)(7`G?4-foP4qt5H`TglVPz) z#|hm<@j`tXiZ%NW))Qfnl#~?7DzE{-Zr+VQIW>#O$Jxb&B>&(L2{(tdfIx}F_kuj_ z)ij1gLK@aN1d#GZ8Sd1E^B3M+_`oT^mjz3SEU$gT$nSK6?lm2=rhj=Q0NFp7l|`N6s0}5qKjKKqDH# zqz5sK#H;82CW%3Y1a4P-B+;=U8E5<0eG;pJ@|yZzE1z^lX}wx2C?I<51t>ZYb=gns zM=cl&#Cr(p2sH;0Y8n)2IQ-8*!#2CgNi8G^Pmi=QVw>PyJx7MwJQH}Bw-A75Qj6OV zoM1UlBLx*Guy#ZYjl{7WdUAhkDpE&)g_kp4qT7qFu)TT}Qm!t|3z7Tm_<^1WBN=@T ztI$3%OT?fC_QYwBzBaHDW%T84E)C}$#cbOTWAXW*&H{o8*%YFCe{XLpf(N180Pqj1 zkR~~H>oy8|0q8^V8P;9rygNDf2(}*Vwb%%sQw=T)jKr`Li3EZ9Q7FB=z0K7X5d#9u zrxX=Q$8h8-=)Fh#|B#i{UbL@)_f7?6rvpy?#&lpFx;A_mwY#wHM;C;$9ZdO z254^;W7VM1Be||k>_ucC93lqqD_TB!G-7vm^!6T4`f*1v6$?i!0KiOY5Lf7Eh1;1N zZiy6P21PAyTE;6W`L_9~)b8DP=i3S65pq^;rwv8C9?b{Vnw`(}^E^8uJCu^?5Vo~-v^y4UV zqIq5Mz=pH2Q@|Wd!D8AxKH&~!I??mUT>Ha>@{@nU3ul}&G_B5}Ew4}3^T&%C_?@Jg z)YhL(>%xi(Ak<6)T*gi*gt@0?;PlcKPO0uoMhs67+g~98ATwi-Ph^;LC5%IH;foE` z0Rowbd7(0gL5ql|c%cgN8sf(+aFy8tyBw}`kUDPU618q8AoBUdL3-AdK_4mD);Ao;*o1lxDfQ0EEOi>u8KboMb7R zA!+dkUrT6Id_BTKKpSe848xU6w@t6Tb^CTH>=;U*$BtXI;F)lkib&1@!`=?R30jNh0rcfHKlU`%##9< zy+P&+d;A>n@ulxC>50bcKZy!7-iOm@x1ZZkimZhwuO_oHGpB?`{5CqSI31pylVk4` zq?d8&`m0yM!e~qfT&P=sZuPhuAM0LP_cK;Qu$ua!@iXyi2Hef7arQ2fG^ zmw*%&!}h=t#~p`y)}V%l!O+mKHsze2OnG~F=<&rp-IRW<82tHa-RTZ9I=(T;1a$Y%7{}tUAHIq!MXygQy zu?X$$?PM&PHEY?8k^v|)Z^BtoW=BepeuJ^jsVztM7Fu>wY);n^5QZwlEaYhF>6_Us zHV7}pqW2dm{z_9iDCRnQ8>)B3>EA;W$qlA%@O5@5EajBF#$M=IGC+BTYGNGy&8Px} zG>aBRe6XYh@WtKKjc4+QpF(;-d;rO3u@6cCG`*_-w3C=g&qwn&LhHW+!z79Ia(N!_ zzJ1!MabEx{p$L=oV5m(s{F~`lGvI1Z@okH2ZEYpRwA#XDSZ{P7K$AsZd``=z!N3s* zny9Eqs0<;e=-Zj|N7NGTMV1v76|F!3Ly~qf&lAQ56L#U}{5f>v&W>)@tT#;tx>uJd zyY3N!fMA|O!4Oef{>-3-o$mXp6@A)hw7rN(FJOpp?x_)8ph()T;G-kah2qVMm5D)u z2cPrqgTV|igKe$V!{6f=@nxUC>`&r2S>89%F*LLe?Hdx*Z!L4y?}iY=(m+n`_xsLs zjjH<#n?aj`?5F1SZ1AyO*^K{}ibBd>U>nGmYwj*b+q3~}&6&Tb>FdYUXONLgN!hi* zC5d1CrHzEcZ+`a<>+}ydM!ujR;~72$QoUMe51=;eU1XlD_f%L(i_AFP+?g~N*W5{4yO9^R(m$%Na#a5d4 zuiVKUES+#zSkGdJfrxY=i2LAXF!AwM&rK@zz=|6l*}CC~VEA%UsF~E$Ymx(0uoc%p zvE1eeQNx#4l`1T-sv!o9>sn-#rJ=RKijb@XZ^9p1DO;)i5}Q@;jJ!hxkDTQS?R^jy z?a1F;+vIV(-soO{up$sXhIH5^%S3wY;+$$EaZHdI3c$Ppiec?;61fE|G-+WA5Qun( z{coVfE$SL0H3>u=04NRw14(M&f7f$z1|d<3#_=U7sEB?R`XnNdI)eJsYH&1!ML~81 zm@DyZ-UUl5D{^SS;}Z@A#+HPW0yEgCG9;(W8>)HkDuqfz`-uQlecYMniXLitf&aR> z!+rSUnRq26GAf3IT3$=HxG70JW}v{jQUlHqE31BU@f4B7A`C8#KNBMcogRPJmcI#O zwA(2}k}Q=$^i-7b?YEK;K)^G@3av)2g+egRiV06d2&tSG6cjZuM2KnF-`1aj3!;&% za%_A=7^`tAw{G3q4r@$K+KjUuvltQRfCbu?F9ow6LMk|GYgsToP#^ADwBRUMoXU9# z`ovRFvJDxRnE|O&Bf~*3f_Ubz@J3T$OystZF{5v_IHGHu0{oNh1oI-;0B6jwq!9Y) zu(~=6X{m%4E9X>?13)Sm=rhxDT_l=640Qf<) zt?j^zM`ibM__FMPQo|pR+|i+)C{kNy>sFs3tI5$0`IKgqJdwwQc|=nZ7LEvnHUJuE z|A|Ql=~pw-&=6nN(cga)21q~u+9seSq~MUK&jve=wr++SA^--;vmr}XxNt#076_Nf z09p*7STq3Zt)ZB^06^gkY`6U(5k9{kI141q#BCF^(U2C?>Qp6Ku&`1=K>@r#2eJzg zV&3Qyfn_3b@%iG*5Yp*?8-Vf>ZI1fC#Rj(8Eh0yRC@d>0i{RD((XPuyA11L5#2NJl z?9X<;$;kMCXd_D+I1JQ+YfvE%MEVxm58LKdLISi7nK60n3m5Ds1VxaVA(4X_FM?-- zl_xDPQFmyNYQQ0J$qNN-NTQLLnRb2%fGo-?^LNLpZo`ns2M^H4$QbH&WYpT(_S+DT zu13s((il?VhfECQaa-Uoob-TEL_g>#u@aQnK6H)y`PS;wwg6(fkplf)9PtEw;H8;# z3<+Yu!OxEJ{ay=O!-ky#t$T6-XJBAFu_gqofo5zJFx(l&IIVVQmv*;2W_cmja;zXZ zIr;h)z87=Kii%#z$-gaBh4p^UT)>0hQ#MOW^3d20MH^^u(8*qCVRXlg46zAU4tzYa zm5(pFIv4<0S=-`>%IknYvyGf}C}lSaP1J$Fv-^^5Ea1~U(~?HtGy_^U)D zL=RoExzFB@q{?|KXaM0i>Z_!Vi158w*tUJUyzoXgw}Yqeean|}Y!i*$WUk2tn%rAl z(=rcB{^I3JKrfGm@>0VdJ`6y%DK|e2p(IpLgnJ+vI)V|D?`REixWtA7{@Sx=4+6IP ziSFA31UBl-*w}yWyf>-K?EbH!{Tt`WRJi~FIY9#*52lFeJe{!C5`<^rj(~i1ewALl zY(Sl+t=}PZQiC-%VdTaKlQE%B1;gShYTqnWZ}OdK+WMjs{xNS_HrCzdN5;%((xBSb z4R%+g^MCb}THHF@ytKClbvKtQBeWLpPwIv}*J#Uk{g0Lgi~WH&-hX*chl`tA7i=%( zboY3{@;`S^A(L1Q=^b*yx~n&A<(B6j((F9q!Z7x4YJwTZxT2!QQTczfik3<8|Jm&8 z|91lY|2Nj2mY%DyX@YU7Pki;6D9$op8;+F&5wGo4S#pASR(mGZA6Ro=!Bqb5c(5o} zgl5VS^@uTcjXAH%s*+><)${vW_g`DlCV6fCl;Ksggm>c!78;Y*!fj^0#%mS@A5`98+@>e~p z)gX%qv(uEsG}nO}aW9Yb*JiS{(%SpQKiGYFIJ%|B7>}W-3(4de^-1U1uOWzp%U9PNI(C;5=Q~fnEAWlWLnZrm6HBuCSML zi%I-yWaM$MpnSJKl<@9E`0WRgcbGD)Z*53qQY`Q;iHj9xDmUUDme*6BX=MDr?l?U^ey%5hSIKae3Woh(T^!ewp7{#<}R$-cS6-g7OjCDkePMttoBS2xAKaCUYPW#DnA zs4iMN+g#F;c+I3fX25TM-cRPJ$rFLm%wD#?EzfS1#!^Y~CAY0^S+wF6?znPvoBPdA ze>scNI2vsSe$tOPTC8)Lcz$S)`nR|&=2H0yiWH4O%E)-A9jmA1&7(shvWK_0o!MD1 zZypNUqPgdnSxifd&S>%Lg+q)U#qXOJ8zSNi|FXx-+*cR4&h>b4Q>43C$KI|36|!UY z7d=whPS4WnRorOtewgB_+$S*K^`@$UM@mRA{@V_fyl~^FNYB?O$n<8FsoSn{V2{&` z!LONm_k^VVJ#{1695zmxhvFGI_NOzhOC0(7Ue~icrTelDChd&`=7 zOG}h?*F;P9#SfZ9KciD2}?|4*_ zM%0m4r%t=~4%a<$cM%Q`XEE-HD3hcgzrZ~_Y&prj>-)wlyD$G;2wCmwoaUE2_xgzF zVB8d2`P$l3N@hz!FWwj2bR3&Dnxe`{KjJN&?q@Lk>P%akwqii)m7Z;GuHng%%@LWS zb{B_-V;Z<`zxA08#l=&qWSUzV;}1SE8JhbzLZy9r;sSX&{qExCjRv=N z_vpl&y6ncu5Z47$4rlSdRj0(l zb24`Q;?w)q80omc_u_;52F5DWu0Bp9yTSU)x$A$G54V1PQ!n}}p_0kUgJ=FlJ{w1# z-beLE--g2~7^*m`dO2^Jn%|CYjm58WPd(I1HPn8kcCONO-X`jS+W*DbSH{)V18+hp zR@@z0+})jGEn3{&-QC^Y-QC?CF7CzQ;_mM7zkPRqyZd=R-kfA^W|ExA$z-1A(jlTM zjWPS(PI{;{<>ShAafZj&%bWI+?W8jJ_hvu;-K}x<1yNnrH!=k*6`5z>s-yGiPOmSj zDd0J^ej8uZAdnaq&46^t14`E1YoBJw6*@gXELJ?epKj!X-`-gdk>g7kj(s_2@cJ~N zE%o+VtTnP>&vk)Y?48eOO1BLq!*s2he~dAy^5Ro_Kj>P%Lh(|KfXbfeSP*fhRU3t0 zOx#S3*ZSm7eRg%LeYm|~Y^y&KEAy<4ya#oRA6#qM7n~r?#|vq(^l|#ESaFE8MlP6& zsMhH6rOR~^sy6!LK25v;cy^dENB!*m$2g=;c${Bv1NWH`>G8qx!U|`c=7Bfy3~+`x zv9uP@l;aUG8qP2>hveC}`tn>GOy>d0{d%os2mdD}&hnLoOTo@c8j!8EHdu^gI)Ll$ z&iEL&rWUQKG-dMCKg*zMuqkHwwvDeWCo6{mVl1!A)WoT1gQX}oR8y;GCMehoRKV82 z2C+a&$fhL?X>;V;D-;JRn4J{tBozU*J#xAV(mGyEi+#$Olaq=(Crzp^^jD_&cvF{p zqdv(P=h*9ofESgo(@aC!zSJ?Yx*vpe%PlIm+KF=YJ|>NZgxUEE^-0dB4ch8C!|tp+7l_=BB%>a}Vdw`efW!?4D;!j8jkalYnvWIRfu31pE> z$;}?5w~vno-ehE%q&{^>oqVcSoyf&J3vYY!BoF&#Sod=2%x9Rd-9^y1M0E+0Kldqe zJJo77^~0rm71cot%FW%iA(%iwUkBcDjCWl;CDm49?HvWqRc;9+F`Qh_$4fQmIYt10 zK6Ft^dG2tqdJOsbAHy&^6 zK9cNMHqHHE8$%_T)mQ-r@{vnG6Qb9@!x>S)q~VWO0-fa3+jP>b^Jg6eOzf@Ub)7g{ z3C)cdHdDsfca)3)vb#ABcw>Z4kGiB;H>n~cT8rUd2P!=f84i*uGrvq;1>p4W@qDJD zeP^v+d(Fn)qWy_lITSBq(C*-|)b=1cV9`a89Pzj&l>WC!DyF$7 zItI{{m^Ava75G&r+SyD8-`>|He{pW;-TLFziVZZ^Jz*VY=V+Suy1&1rQYo%}F>0=N z6}2N(ifE^&+IEbVMz}Gm>fs)p!(dSFx%rl{GL-@;!#QEEZOOwpmBV&rHsUWX zr3;bn!BP`L&A2_7NXw(4d|&NDlS;p>*yM5871oN!>3)Y#Xo#^~EYUyd)W8Tjosd#@ z1lzvh)LdypQ`yl5ss$BhoHAN~PA#$;VW9i>!F(&B#p@gnVR@&(?Fn;wr?SL#@CD$I zqAvz#q19qTh^Vc zgIc2kVt%bRIuP$mMrV~yt!ahS$OdXB=lfp1Z+Yr8x8pHcK}jV8KNxH~vzE}nNq1hR zK2O+B`?x@1EeLnEDv+uK5sYQaOBLyoy#Dhj@Kw-WFg#i=@*R5E%jqE_3gU69cdb7( zd7Z(TQ~0r0$Bsl*R8iCs&LLEuA@Mt47{OqtQSVh_#U#KUaLcf&^zzX=p0-KG{cYg2g|GtEuxsdfnGsXz@eFwj~Fg^cvM5xx|+B4sG z^r^B5Q!JQK0}RZ>{I`gZVoWrFkiJsIknw!&r!{u}lvV$-aWbV{q1wx4D(@#dZcpjy z(gmBB2c&nGb-<&zQ;3|M>|oNdoKzYuFK08HI8ovD^W^=L_c|x&(;3dfhBqW0{N^e^ zmDdd)5SV z5Q7);8}4yiFS>a|R#@qV&N7vSa2Fum-{12Q!9jDz>3qrUC%&ToYd}`-vAO_e^6e7) z`Bi%BL2_&Awx|71QUDh2vpqPl^X=9R6!cr} zI>p^-QYQtjM{RV_lPDKJdN{*TMCZ)&*a${bnD`82UO(4Ww;&sFZR^4x2ajt<_pE8KTW;)av>1?0p2ue>LheyEB_F`D;hdgxd zLNeZnqcsXujxNcEhYC_aNLJv3KYm2@an*92Ju!6S zSq022C&dgNQRpgN-^&N-4J~El5^^Iz5QrjUJ_ZrFfqxcblejK}al+h$IEKBE@9Zjj zie&Ofe~p5oz4=moJO%@oXf`QcSPf5Xx2;Yc84#l98^HMaEaIK|WM^?jBD?GA50zf7 z7*!2XyECZ8+i8bQ^V%^`TKG`@Z9q1GLkNVn48@I$rLzQfrHOc|oW-PX%SapO|7MlX zZi@0zIzY*6Iob}xut2!uf=Ph}9c2!%k{72XdlC1vp_3JO0)-dVYp58z^gX@ix ziqHhJ-aQn~n_Z44&o@rbE?ovf!?L`Eb8H_xikKX6(_rq;4u%ihhYv4ErW=v#c z?*xu%w#as|yT?HYPwg%?v>q^ffc4(z+m){5)OrKpX}7{k2WFvyC%AT zaUrZkCz##Ud%#i#J&7FRAVH=t{eTOxR?Q!78L2bh2ci!@4STR_aVC>r9-R%8=+t0T zd53l%ZzP6aMm!w}L4jW4BJ|he6aCd%!XrK_XnJ6KOsnR`DHEaB>izKazdmuwSU%x) zK}!z~x^U~B0Xe~NVAi^jh25pVRIH53TT&iJ`}6G1XFS%E-%?rV4JKQ&P49brVx!Rm zvjt*qW0=7WcV?e&If>GJieW>Nbpet0TjuH*;i8v~!?9fAC`f{YZAkO(&HWx0-#T>{ zwA`UcOA&Mgp%r?RKOm}4;#^wRl?zK*{Y@e(WiveQ(pw1)9T`%4Pv z7vGX;RIqLu^QAAYwf&_Va-pltU&SDMO0y|Q5FTtRUUydh;`YwsR6{C2GGxi5y|t0?D$NEODgZD{9Z_+ zs{iu^ks&M0CfG@62^Jc;BI(EV#ilCj9f2pi8>@Ox3%rXx_MR+*{%#@vQNtPd zAa{I=Pl(QZ*7$;pSbw1Pk-Vm0%t(sSV5L>$Sio4w__^+{ILS zhz3@F5s`Z-L(z<<1X9>@7{B=k!C#H}FXG$u+eSb4e#VQ)&tH~Iv*4BF7!vB_*)~EH!d}6}-9fHg{Bynn_)DAC$scah zbI4mujXVqZ6UhV&BBg{Q|9&p@i8WYQ0jG;4+4)qKrKx4q{hQwrUC%O!TQ&7T6ShZ&G zTnQ=qp*_Lb6#QU!{-WfiK9N1}KuBYn$X#GySb+s2C;V)6>yy^>h>-r9ExGS(+T&Ttsm-Ug0MgqCYKqk#&;M{yfp zFrX0O%X)h^dbT*`0_4lAWTAl?^vu0r!IDy*LhqujNy`@oG5rtFCnMP z&8lNH4dwxu z!C8KtD|x%G7&H)XnyI-G!yH5NEgyGB|OHsruu{;t2drkJHuw%s{B zQLRuyZVr}VD`B#*@LBZcjT4prG**63KC9vh7B!G8P^#Fu<&*ufc~?wU@)1FuM^B>GMlKj=_wo9AO1?}h&e4r3qSqewj^9}n+EV5|SoVD`{tzod&XGtX zf{qT-Ki}BExBN1*!)1s9*JCL;i)3lMftTkr1mf1SbOgs&;|Xc%0q(>(flC}V{3Z+V z8)xBWJeQXXro+efxz?^QaHtmm_6OEXBcmV$Zxu*gFb3R)|BTkOPNR{~=N#M%6Suu! zN$@MBgCf@H%w zjMDWIME~YqO?L^{qSK>9VCm}-&fc}W#neysuze`y{h+Vc`Ow+b|5p4ltj5rS#)vS5GpklXK zMc?dMb$CV@j?wAeIuJgjR{K?Ijy9%2`eMaj8`QoS@L((%J>MvEf?Z*g8BtY@P zgJ`**KZ}MxD}JA*7*x0&=;_9vZ{?POwc2d>Jf=!MQI@Z-4Y!`&R3eewN};n)5>xSQ z;r#Tis_;W2t|J>MTMc!!$`KYr;)?Hf$lm^8d95^**Cn+9rI_bsbxeWF?i^QAL#_BV z%u75vbKp-U!UkUh>LZwN3A=*o-M!$+s>uU|OHt1DNQxAfzHAG_tK6T8(HX9zzObZ; z*8@Dvs`sp1+Ka!XJD5KB3#OZqnQ=T%`T(>3dolPQ!)dEV1iy(#o%9E&B?yC8p4$Xz;j~I{zk-$P{Brd z)A_bFJe`8F!;Rd!4wfIGN3 z&op|oXZsSu>_oYn`}`;yI<6HQK4pG;8I^3F#5Y@7F;_0^%zDs7-e_G7;>@-i-~44R zu=Xc8cv5R7on*%u#1B%Axckx7c+;6<7>EW2bvLb_XF6J_;LiM=+8QV;mR5q()=jJR z?Unf`2Srj&qFC;TX6=CS&~@GC`ABUWE?WRv<#C6B7mCkMK*a{ozfr_W(|8x;;?Viv zX(m_jk>Yp4S>`SFprzVM>?_NkB}EW3s>+3)Yv0iq`Q8XQqzw~ZbU^*#($^+;QOK~C zlK*>H{Zz#IeiA0a4hG>W}dN*et!XYirXmLNoULIDe&3#84r!8 zplC~oY@Ucv>4+r4NK4*^*#(^SRj;`Rr$`vkVc8DZ*1Tj%$yQP}Q2fi8cq}_U zwzO=%8<_P?F_;%($%SOH4Z&(+Qzc?b>HO-@z%qfB19Ai+_z(qvvWW{#hVdFca0oC%a2p z^TecuP*+D)jcT17a3Yl=91@3y@&GoK9v&M-C{f7S`OWlF!diH(wJ(r!%tInqZ(Zn7GW$s}z?xmxq{XUgU?$pL=PMBx^+h`waB zrmX6L0@|tkKGjXPQ`p-fi&qLF*|?ayqRih#G05&xsiI@HH1`J$T#Z5?PFsUJoe3>o zZLO@IB7^>Xf>`K;NX2aK4ik6Q&;Ht7mrV7jbPC?tqPXI1OK+RhyP_NcxxNN$_*nB9 zj*bYDKYl2$CiCm+3674@Neboj(b{aAKGGVsbeja$5g)S%ADp#O`vMas78}xcEA=^Y zMNOTSTQHouDbkQkOoO?x`&Q^L1dayVw9b6tLS1D@3(e^|k?dD|O7rO{&+TewR6{Y>R zml!Sb*V;g3kc#Wadc2u*Ig=Kou1HM;`HMC~`KP9@bp<=~2+2zSqRf4o_Z9K3@Xs(s zOWV=sL?x?J1Fd{`QQC=uFY~_?IO!2&{^oecM(hg1L@CzVjgHPMS(Vvn`7?{sqKc_? zV?8i?aU7plT?uzZ)J2fWOL zstTy+T z!9PIE-iz_zVowhI6UH~RL9m@Nt~ov{kjDoW(C5p(uy$CJe%W(>chnNOoEmhoiO^4@ z!gYc~BmV9l{nxZEJ%zR|>6hqOl5ZB)6mqc1&z{vwA9x+xfcAvqt%Nu|%d{j9rf(HQ zcH3tJSJb-R*NELJ1+|n1LZgcH7!2`Pi&>yv@fPn9#)u<0SOrn?_VneG^Nhm-W=$R6p zv&fgM{P3iaC_r&!$cd3JTLgTgP}qmn?hfRereP_Iu&CEC8J**mq>94q(6n9#e?il! zTb&L{PUza~(G6>>B8gv*msnjSEN8vo#YJTFSaH%+m)R#7SB$Q^D?}k~7)^o3TEL;t zWSS|c1H6$jJt_bf0@m3W+Sp>+T-m{fDv<5u;!-yLj2$KuEHtCxhgpjoL_PPfuE#_g z=v2Z)mlJ5VWoLOuRiPCp;lN`zt}15iDc!!8SX&yeYwC=}q>5zti(Z{DLe(q=9iMY$ z2Y9^W$d!w5i`5<$XO@s)dNA0MxAQqQ9yAFOM_-*EYSdkD$8;w(9*?G)d20<|dI; zD@W_X=q=QHi6jw;igP&C=jLKWS1WVse>YrUe&C}U9*MRtU7)V5>x(xL;gjZe-!0RY zJ;0sb{+$x0P?om04c3RB`pB2E!Nbxh2RqhkE)Sb&RyTPZPM0gd72%|EX2zXQZyv>c z&+3iHN+5?7MSh$jWRZcUq9N!Lu6FunYs&54p8iM|Dj!adRw7SA&TghL{pmL@%FR@$ z_j|0}TpM00=hZ)b!)5pWU9P**Od8XnK;-zIQmwMijlu_6!1D50*6~H@Xd^3r`jNSP zjhQ~^bT&jsU#+F)`5dDK%kdh$&z;YGz=FaaUYtP{!`T|fdxqWoMt*O2kRWnc9 z$antA+PCQQqYFT^)hrWp@c<6+Y3CkQmsfhdkG4V5$pVMx_FVG$c`iTseBnGtr~Ul? z)xC?*3+lFu#`2r-J}}QsD%G>MF*t!gGl3@>`R5N_2?`mvpiSAM*;57e*wJkIcdy#_ zTOKHeTTv4*uU>Ef(zAco2Rr%)vcEK#IuSQV^DZFVjZ`30<@1b-_U_}5v%m}pBajJ?r8C9glb-AN+9n==M{oE}XN zEMO|d2;#u5FP~T1^Gpsrc%0>E#sfk&UD`xFpfV5&3KjbitCHm7)N}>~|7f}lVgC47 zE5mHbobkPbzOen_sbi!YU7Fju&>+62y!A`968Bu4u+vRh@8>JKqIp;7={1cs?hGXK zj*mG6i@Te`;i<*n)DHr!Gc;(fxdf=YZNIPD+#~*R(xoe}6qfD`M*8KlI^%<9lC!PC z->j;>EKw62E$6kHMb*|qND(oHox8+Wo1iv(bM^4k$oAx6$RhxXD58+lWS+Af>CCoz zpR#E?kjLgrZ9i0J6U=&sI#+kQ0zS4LwYS`Td`|qBX~vp<&Xl#Bxd9ceFSIH93#^&pO+@D)2|u zmWIAAIO-EUgaV>S-;d_Z{0s7>-bbiGmpZTYyZxy4JWTecW$sp@r0{w&1&H{BnU0=J z*1$hljfXZ^eBw)NpR=B$N=q1kf?rD4OQ6XLoNw<_pj!#dNY!0}l7z+To>~s~Xhq&& zG9S~r#7R~u1fZ2B1A>BRl2%(vgrmNc(&@L>yTccB3;b4%M1kZp$0d+RpC^b#l^SLwz`yOE~IOpu~r$%TX60bFLJ zEU3?g zkPRU{d=*PV{Be1X zwo8hq5B~g_J_V=!E8qh~NXW<5bsIZ<>39g!$pka9D$s~LEk4!Tt!%Y<1J3GPucog# zwWX{oi%qq(pNrFkQlT06omVnc}Ug-KmE_IFQGRk)~?&Q4FuIt?;UmD49oE(Dh zWxCjH#q1!`niceU_6`RiyD3%@7MHO+jUI01o@i_GUpyXrjArtAunIxSpAvJhkv|n1 zz>&nk=J-0%fq_6ETU*ruy>>FOy%>&_qf~OAJs`^xJm;nFAf=?Tt6Ggkv6u%CBhaHXwaV#B|gTgCvP=IgoS7^HO`)jnEds zraPEycUKncNvgWTdXH0Q$yv$r8@RF5|^s|g~=j5=dYmE z0<$!R3@Teg&ViN)23<9uA6Z)RcpHh)tUnDUo{xGw`CC@K$IcLyQpU$k2Yr z8v|s5UprHSA$Cm}4=7dl^Yk-t&hRgUT5(qp*U!KB^MLD zA6zZJessr@2hgWNGd;9&6?rTi`-l<|{^FEr<=QZ4b`ka-$+7=Ks`dlc8#$mW1I;W+ z9`w|;k&erLrKpwOkDoUdMjFw6#f&BKQY+Cp?LH7%EO#^xu7j1vP!Q|Qm+};AZ7(1n zY4}b_;`K>_n$N04wy}9;?b+K%8_N6Syf3x%cc#`RBs2~;$spXnRKh0d=B4(@(h}L& zZYBAbtMC1Gg9}u#HhApY&`UMX9%^`vIB7^mGblITp?Fj%DuMmeX+l^(O%pNbxz3-Q z{`}^@KkaYU8Qw`d_}TB?A3!|Qx;IE)1=Go&nR{hgZzRi=^{;v)L0eX3{KYBRGf)&t z3nnj49bc26PG+08&<9;4a5X5?IoV8G@stSi`j_4JmV&0puan4L`AsQ6&ovxkVVDIf z!|`miwMWPBRy}c_v6u+m!8j$)OcKdpEPJe)kRQx?RWu!scdFwPAn=wEb0azqsGT6I z!vw&lKXXJSftMPKRE~|`nE0-AzkT>gd zLs@~0FKBT!j)tHQjD~_^eGUd!jjOE)8{bu}+eL2GsJVai+^CN{RAcEN0Fm9#Evhi` zYsPmkD-(5%U;fg^SJU#zIkx%dGY`v>mzprwi$T}7zQ&#}lHaMJRn{|Ed6&5p?pt#8 zH{45^xrrn!#ARAI2rQYmcq$;yK?q}M(ukT4D=4qSj+ppifHE;Ji9v>o3|yK_O@}UC zY{UJ)3XPNFMrSTNX+ld0@oty6MS6;AYZ)r246sUktQq^4Z#sIieB*0Any8woc} z)NChiGOsAa@W&C2R@#LUo7~lUf5Th@=TB!f)Qql*%aMtaTTx_CZ6yG?aDc}%x&BNG z`R-P{G=1bCA@SNvwdi0jg$7SMoacWag@p-UGl5?MqPJ)cIUi=iPWHpM4=l4_{n_37 zZQlCID!z%*xc>Q`@*dJ9mIZ<|j068P>m7Z%;%72J)b~UNo6hl@g~7y(C8NRQy2&e( zCxS~^w`r342b}nH)yV)(u{-cOn8_E&L<*LDns8rrlV7Gu6mdp0Mxqb|5_hc=4#(Mb zU-EY-b5(LerzEt?3DV$-0DYt2=~2-+T6J$M=K8{kwWuPmZ1KqoOZH#N2(d!gvzN}# z&65uTbS>96kK%46{z&k!7v*{w&`ufV2hhI-*EXE132UINYw*p9`FAc=biRj3bB*6K|54Q2ENgg3WL3?u1nb%S= zi9Uv=_Q;^#E>T4yUG@~A{zpu2NC*gd3~PJyDDNvCMBgCMr_W)&ulgAPl!yIKwI3c9 z;3cUn&|zIp`$c`loZ%NJbuHD%4hLI!^#3uciX5OhJTTUAc(DLmxIK{3ck$@|nfM34 zqm51Ad;nNvp9^y`#Z$DIp51?Nl5v4df7e^&6YmZ|1i|-Jb2Tje9nw zUJ}-7RRp{$C!GdJC@cL_eK6zxL5#C)PfVd5H+^nFXO&5Qg)5&*i$MC=T(x!2Tapa@ zpY#!4T21?c>pPIW&$8yED(W3XcGV{Twv1i8Buh|&bPPpGJE*#-R5t)qTsZ>`0BBNW zk;+yptZ?ba(exZ&XdIhgAO9PrAA=FB+EKnDUKOt{e;l{3FEd`0pZ{PHqqBq!buBO{~$=i(QwgdPt)?hDRI*ODSe!KZ5|lpfulb~ z*pH10Hhn+(qU<)SSj|9^@z3zr%4Hys84M7b@MrsfjGV|e8z9h+y5`Xb^%k7_JJ`9F zf4IAYA=I?Ak(J7eHZ}+p|GMzc&Y;$MvA+MvRsVO1YZdALJDl1yu$=!*>;ImIE=hcm z`9D>Wcbek=|C`ReEkPnv-DYTK7)!#j94>xS?+JRhx>O2{`?n))scTspwzXh!U|*Rx z(iFZaVt{U8pVmw8!UC{K=|m?6+3A18tlts-ijjZ#)aSXR5WeDiIOiXUquG*c;g`l! zB2IZQYP-|qX$|lij=q^GLTtio%JpWdV4Y(vva5Je9DmeL_f>#?HoLUGP*usdelKP% zPlU6gHY03jEpiJHQ*?I&-}KVMT6sQo$}D%vb>~+qAH4%F;e5K)eh6%zSuWv=qB$gY zJC``E2F2?fyYbiBBf7f3k>R!ED+KJQ_B=>*89@=WOm7IDD`oFV zV{dn)cy*&tj@Iwpp9nU+QOn!dp?S3mq{wJ(+4+gF((lJgxI?r=TOF<|AmOsp2Mvs~ zrwavZXgnD=5M%*7MZ6a*fUU@v?KuK@%cZA|YtFHC*|>jgo@lNfD_CSt#D|LF3kcDyD~Bxr=JngV!|DI#!aMN$~y%T3v@bvDtd8X*L^L3|~P(&Zruu zV|twSe=ZO7SvV_NYec=7HMeAUlpr-2ToPKF(2a{rugVY4$I)U^6|S^|KqDfU&Q+r% zX*I-mKHs~@I=Bw=Kk(Y-3MDwRmY+#18iv9Dzqxjx^9?#bufGck0t=oRe5`*bCk;XS=!{%4ai;(gyXq7U5c1tox^QrA!(bMmJ$Q` zREihmzhR!0S$asmt}Qm?%Ga8RJz$o;F}LL~Kd1C^C>P(g@>gWb*Xk-5$+bt=X{9OE zfGQmmlf9G4mR0vqdPSfs@&*%5{IU~(nJpowTv&&r7NvQn+-QHMObW+T40o0OBg4;8 z9Cu1&WHrnzkw;9H#C`rkb$*hl7>-U`<*4eQ5jKtSp2FOr$5fVVRfHcH(7> z#yl>8$fgzDaEpAZeCAr*#ozy`P^lL--z?-nksZ)^g1;HN(_A=p%YQ^h|9*U=GhI?N z+Dasl5S;LNBG_OPq&mJRg9kopG7T0pM0|1}B;|{yQ^Q${mvmSKID3`Br`HR&8H83; z^xQ}nbe;1Hkgu;H#yFBkF`unP+bi1VqCI&B{d_A zQ4Yaj$eqaW)$oRW>JB942$U6~)2Yjs%~e*-RodGX)p{ry%}jkKY7AZiC(p=_}p3%Bj4t#H`H~)Xk_!tvHY;p zMW@l;Se~v(HSVXxWEmEQm*bdDsOjd4gaMLEQEExcT(LAkk>c!6iDTp2M29#iGIj{F zL{l9jXEdZ+lSMHQb$O|LsUwRKux-TA2WyjTww-2dYQ-xFswjdg=-Z-VcPnxcudH%R z=%-huD>}{F&&~1!tiz5ML3is5`aB zq}fN>(jO|t!m1mCncF8T0nHT|I|fCyG}kF~K}M^R34;&igKz}MRlB)iZ;Hv1Tn4}sdD*WR0~c9p2%b9 zzV<5V^qpDGSNmFxT`BUdu^saDGA>5H)^0$Hj86fc6Xdv5ss1`tpk~B<6C;ujPNJ)9 z^|+V&IDBM1hbukmNvt_RSC(*d?X2^ir+GK-J|x|`10VegPSKMFl$hn#q4Iev0Z%a{ zBPK53%Z&VNewwVbDL=gxahkLxIu}MOwmetTrrJHuL-+jMOJ(|-rpWcWc?{{3H%tIH zpxmeu-oc@Cu<2fn9T|>VmzS|4*VxRw@2M$BBK6#+1A1w!^o;>WOclP9>IckC7L>|+ zjJc7`;ombt=8jPZJc^_<02!ziEHD+qo1W!hBVTz37Adp%!l|1|v=y_B0H$I2l(XF| ziBLn5G)27MicxIaOQoumG^sZBea844Hl+KCb|mKARFX`f4Dwkml02s6B?~CAN*3kA zQ!er=yB?*OQcf)BrWPv^y_gCyadKbA!bhEy1XEfBqk@Vn`G7` ziONZWa;T$5g|O?P^iN~{yr@z4$0Z&nmhISj!7VeXZvq~NJ#(+%RhyC2*=EkiKuN=H zI=@_eg5|nvJeV-!VX|}~2nz~6r#oQqH`t8M?w05L>pR8N;T9aMltMVW6jAmIJLquT zMasnW^W*MPAkOH}pQ*DE(Y<)9Z_D^@YhU0Cbg!wCc4!?EQGthjX4IlZ4Z5Xu`^qhJJGAR`6@L2YdZ62W_5a z*s6$+74&>aMoY^sZkQzAwnreJ#}q2^=O-ci#WfkR-AW1>z3|s=@rT+&OTqjqY#9EQ zA9_S#W0~$+5S$$*NaCm}_%k{@BX42v3d@(!#HJk3?--q{P5>@z&N2tARrj*i>r+Di z_-RpQp6*9%xi+sO)rC3UGj+*@X#Uodn7T(!O2Ms-@A?=>)QSt_g=JgdfDFw?=pdNr z%sc3sX2=iYlLstu^uiHuiW0*%c$qa$2ctnCh9(`h)Wb9%Vt28VzlOwZS+5;p0RL8htUeDE9 zx6C(>z>Tj)&B~RG={XEq+hcJ`cQv}=W70>{lOWdN$z^o)9-D4M6)Cpv^}4_*vdZ~X zW!U+s@ye?RO>WNv9E%t{<9qhI zXHuyMk$B*x+`|ReIDshBXHSZ>ho>99Qu_D5#eWn6ddBC8)%)d@3P2i?mVFx1&C<2$ zLwt+)HmN|iVrL4o{Z^b7i66Vl4m!|3GdvjaJ1}^L7n)8=c_YSw;mY|_eQ)CLO4zdr z`opye;P_+P%ea&nVrGAIN8RW{S}D|$3$w*vL8!Z&-Z>UO?D!Oc>TDhpCi5+@7unQd zzn+}bW-PlPuMCd5Pa6Wjh{apG8=42bfW&gj2RDyNl>V0Qhg7PIZ%+<3+QKP)j}35Q zG(#v;gDFS|!h1_fiZ-U^%-|L{9tac|4yOvhX3P~pOU@?0O7o95^%y9557gN?-{70> z_7ZP>8+lvZ*j(v+CjI2?nSM@Ez-HC$g>xgW#y8jmHoO@~xlSV6`$=EFstzCb9*&;i z>Xn}+x3-Jf?l1E{WWR?ji$rBQ@{EqX_`zsjIJM3; zWRv~0H)jf5bi2dfbnstRB|#S45)RTdqTaF8%htb846Cok{JPN|9lx4#pk>_J-;z^B zT)b<#dN7m!`@GY07rRGb3PslRKOWrBpEvkTY&{$_l6H_PeVuQX8ZT#GW?pB z&$gIZ04B6{H>*&leU8=6T}}#DyD_6S{VIC#2&5ecCzL^*z2aED7(#B9NE@no>+ZC4 z0gA!5wS{K_-P7(K7+Buna)uB_Q$ZhcS}svRMOx5Bu8qw|y{leu&xd<dW-G2yD=*|j6`#>qp<>}x)}ZN4L{FEb8@ z%is7d79&@@?S3b-c)z*2v)>rcFo5Uw4Af!d%vP>g?!+V4C4jS~GwDvc-3dO~2r%c` zDY;uAMQkQ)-#Fqfob@L1J%aIZc;ia%meW1WGVp`X<0<7XJVC5OzzBd2*4fNh-|bI? zCaaJ94|=4>@$1cDOrIvld|6hg%bfTxH|A z9dP5bjp0+^{?jn)NI$G}_GS-=A^KGjoCV6J<~^(l`Qmq$$(22;bu#3`E`-IwQY9?> z5_KfHaNy924*9|}ta`1+l0M|4i`?o!zB|kigH#Q9kvaMlHejk%z;;V7D)fZGE_*6C zAjHEI6fE_M+r#?<-^4tIDz)4_AO|&x>%KjG)#G5Q_vt(klmM^*fKYYNMhd)LTc7${ z$>%Tp+d%%ZrPh`(AqACiRfGb8t|urdrOk=K0m|^cmv=G~3X6j((5kQ|J3%p!&sMurhHbu(<4WnO#%H zLHk3%jz|#@`+5#h8a>vtb;1u0`ZjoBesH#ifgpM9XE^+iSZ|-M$3HqaSx==0wgxPL zt#MI2vUU{IK~Ogt(^qWntI~MVROjQ|T7GBg_gn_;t_nUgh#yR1D2`@u^E~YNo1xjR zO{tSZ6yHF0R^wWS&ROZc#;Wq`Fus019Oo)YrSFGu0XfFs1EvZMYz_?kLa;sk$v6aN ztVdXcMgX|4*U6A|7E>bRUaRTpRfqQcbH7qIF^wIqC$V9J9RKmJ;M;Ecy??3dZ6 zWC`QVS$V-Qh1Jl*chj0F1QZuI-^fOyybpacwW>L}TOp7ij9N_*@Bp5DO|Ww|_y4E5 zvkZ!}iS|4aAi>=w1b25!&>+FxHMqNn;O-DS!5s#7cXx-uZEy?D4tcY0)!lpVZf(_m z+4=NLO?N-1&*?r-&CL1re-mJo-g^B;Oo929@z22W&+elMA)d;i-5RYV29Dv_AVzN3F;{jq$|B}Mt=M{g8_quufADc!{6CCDx6bOEBu_bi|$}18!Nw)H|gSlu3*IP zU7}+6!McJ`QuEytT+g*BI;I`EEDl)M=#VZ`!jYHJ47=#7MBX4dNP+?TZ5oo(yPmhIt; z`Bw59^KztJsWokUdhq^|-~2A^`VLW1erLqLL~Q7(^k=JM3V%5Pk|2K1oaJ~3;(h#a zb@L{hKDR(|z46@XhNuwI`t}k_jSWw{w8CV*wVrl2PQ2;)<*}JzLc|xCHtp?bk^bvP&az_jU+a~}SH)a$ zIqy7;4jLM7u`+rN7{Et-agW~OXQc4FK$+&t8UUh6rY*UQ^!&#)_{(!?(l$_)Bu>fe zRy}0$~qY1g3w1%-s_clCEjTh$vplux=M+J-7Al3M|a|c@rw-I4VCb17<7uc^KeK;PKp~jN;mRVhsL$zQs}W2VDru1_(#i% zTZW(E%3!Ps3NrYX9loqpIdlpF#UU=t-2Y)wJ@of6iDAz~&uE;EFXf-mhT(fN^Ba5E_NJC#fpe zSNaetC)Mlxiu791{EEb8=^c(s>z`K32K#ttJD=FNUfyqmOI!`qS(gzj!|fbFyvQJL zu-*{^kGS%L;6o-rov0bC_g|t=?yx7^VJOMSwwxIx5_oh7%G%rEheR6f4{yTcE^pbR zFHzRdsyMfmwLbR1HmPRx__aPqP^%DzuOrb5pa$)M4cb45=!5E-bIR^Z4%cvQ#UcyW zI@_!9p>NiUWr{A#6_G{4#y!he3z@szy-@mXQsZ8g(zr~raeYVX^T<-OZeZzrt2#g#l8%!3MT)eKc=hABc0O-vVo^oH8Hb8wmmQZ&o zkY6<)nA8rLy^81I1TElTp-_UDPN){2spW#kbMhwGV_URD)!^0{R|WRiC>i)xD$xg! zF_QKE?w-V0V>#W7URl)YNW4}OZV@E(U6Dd}#NK#Os$8qq;BKR$(59%f%PvXsx~b|W z_$lJ~7YqhA=gz4Hl38OtvJp;HUQHf;Vo{IelJosHlQ$~2-^u8BW28&OP0mcj<2mVn z{Ki=IwF-~C5ovTK?#w)o=ZSNK;A`6X*v$3BuJB_EC5 zrB2oE=!0owS&^SRi%CY*XED+%AJnioqN6H^%g@wAO`l;Xt^_}Y3A+rj zm?^R$BN2Vs67-#@4N7GL&0VC4XeIe*p6BHN)_BzRSiCgeI-B~+s(?Jxzko!|Y$Z5V z-uq@ockHJGHoOgIym-ftyxW%jGnWeKX2C=W?RL-oLqw)==ZeIH*hN798=%@}`w8QG z7p#|M@rdVXbC*8Ye)>s%AB&Glg(xq}oN@3jNfq3@eNn1paMi8z{7rRD(*2AuZx^dV zbs&i&)@JZ>Od4X7J`^;Ge4(3tNErFfIo7k)+AbUiUkbD` zgl4gaS}*SvE*v`SqkN-Dw!Tt_Bv^bF_ocp{3o43a%gahv(?^?QqeywI%v_b!fZ*1) zucnzX^iUq_w0JInnD1luDn+d=hUbTP+Xu2k?C_nDUm~ICs`(~++j23o)-s&$A~pL3-^VWzWC3;R&qRW8VuzOB7vAgF{f z(|9RZ=wwqT{iVEYL@jRS=TPlhUQViqEe4s}dqlKJ{kMWc+sJI(71Y}UJ4ub%1&Jl#q>93w>Pix2&B0i^r`9-icCE0+8i4s6so8)IU^ws8-DQ0M1es4#y-~6O{=wwbHr1H+1 z26iNEeCCEC#AiAh&np-=%4;xz?8fP#iUg)M6*0qzBwc;5#kS30atZiuXj3+l*2b8{ zH;QA4i_CNzx$TF2s$DCZ$bfymOUmHlK5n{WGq`e*V`-Y+l$@{aN9Tc^fVjxp{sI?+ z-9*YWn;u%8Z%{AwBn z3Y9y83Og{YezZqD&5Hm=K%X+Lu+5{JhW=s5lw@j^OO8UtbrMyHs&ZH=IDR=h_6~js z(R%K?z3d4enyPF>(7e=^X`afnQwg?z(baTclOtJqTaX|smQT1Vnef3!Guk^ju=MS1 zfE?9)r%@FaG7Rms&6&EOwE`mDx02_T-W2i1c$>~PWenb=um`3HjA=Ml!&x3;{2F7& zs$N!J^uo-c=ZZ}d2td7rV(G+y@#)T&)&;dX7E{@-vI{Q(x{3D>Y=$&4fEw5p?Q5!U z>Ne5z;^HoQzSoq16%+GnVfru6U!&M>lCdLq13P9)LdiE-ev~I+lyvkUvAsPOjFmM} z`#HKR1%MU!!ClOks7OS4^P%06%dcx&s_w~amM`uUh{f+XX`$H2mim>F{ot0O>Czj; zgDEiY%gN7Nl+CPlS~ zk2hw})I-VhDcZ=xwLH!nZKVUQ-JiAoU(sk4+c`@U0Ojx(VAw^Luh?={bLGzMyZucO z3xm2;XrY3Z+B<;~{3~mvm}4QrUFHmPZ&MbLT^M=K8;ew)mLLg<_7mU;e}#$BrpQ$f zdn9k94YWcmTE2XtH`Ojf>+rV?y~&?f%2mAE1&V2(BA#ntD6Ht^?U7zQ;bpHKD4wDg zJSPfUB_7)i+D8os3KP@cxkpi^Vbmo}iuUzT-mg)*_=CR)$e%D4m`v_4wP^9x3I}KE z_tw)j5OFz|?qV)ox%D|}(5+)c!8raO!xCj7$Pc(FBi=4@t5`MYq$q840s>M?Eo__;~Vp2E^23(w~m_BW=4 z#UAD0cnP>(8r}QoRHz$#{QIK6PW?3^Q(~W{a1r&CEm^@m3LJmF z^QtXFIG)3ofx2H5ajgY?oVEJ0Tedtv?nTRGj*o~eu%EAs$(s>ewtqfHFT2^h))n6PZGZtUfkH49Q>|6kavWJb|ExW<~8l&Sy;Z9 za(P0{p^qkS61xEUG3yoYjt3v>jCUwK`AZhSQ-E zKaKbcyfN|fWl9?xIt56)No~o>hb8s1m>(0Zg1HAne&$0ElZ>}SK}z?)@4-Pi}Q3NtvlTO#45OqqbGl{ajB!g zTMc!}5UoFdrkkMz9!uy_5}?qN`-7B^QI=L2OWaOh*0FR(lF1x=B2F$hYQbk;&a9KH ztJ?M_`S#pXI7aT}Pf_{6YJ<--?2_VN3IU_#q`iKajAILN{**vl{m(_l<}-EY2&nCj zH8AE6s|HJlB=%d;&{kT{(3j0~CQ4ZhMAb`N0A^=)xz&pz4Z_>feYVboL3A zqi(5YJ)Tz(Fyq#kv7xUIWCb5t?JXEU9~1PMurmY^__8NZOObM3^e5u398w)2`9u1v zrdja!wl(lYJ$CI_^6UfUAB5Xn<6dCJy|*N0?e(T8+xcdBBXcyDyPW7aITx37%^ex^ z@Fi{>Pip*bIJSgTkVo;mv6}_yPxQ`5XE$^W zsGfgtH(2G7C1{!+XxXgas3Lih|E8GvqHg>Aj~wPTOU{Zo@lvKHneJERFkfjhu~J`M*=<_o}7rD{e69HbX4> zvC5mT3Sw({pA8rEnSYKBB!*b(I#`}KUJMXL zMpKKubLQ<6NoJC^pOt;)CJ0ib#Ic{OYv+pix^aCNakmoOk~+`KdPS7{lw%;{{Dsr` zVZ2sm4S!f7%O%h2)3`CD48;F}6ZWSG!nbC}%q~$L+^r0usnJO=2UJ=t3T6*QXXm8x zYM{=*OQ_KivY=L3;5}p}85v7(O`qEs`BRLCc%(XA^(nPxj&Njl$G^gVY@FxusQx_6 zr=E<>u3yb^&=c%ubJ4k{4b!Hvz__3THmH1I+DTMr2H*M6{E^TMOcvsUDMtS}PO9>{ zZ_k+3b9WmPPH&t!Dpgt=|H|XhQMZh5{GXLmDEZVDl&%wPTqSZ)~OW?9xW%vxOJg7(-i|Wvy*l4yQq^mH^k$e!}1wQ;sAhzF8@YNR>UqODAKJ|-Rq91?vnBr8eFSMPrw|FfW2eBX%r35teyEOc#Xhq%8GXy z-5@#ds{aAdnRK?cA1HQ{XEp^OBZ_}i>BX`u zn~IGhDvp=< z>JxT^mGJ1AkioPk4F;WY_D>zyh5Dn<)@kuFAVB&s+T>4^_4#ojQ|Sh_NG&xE0R_q< zh-o+(!-tFp2t}Bnn~IND&n!-QE{O$Uncwt;OuLfoR(QR?*Gq}DSrF~V964I4uV5!y zvj`qe#8T!Ft_|N{eqc%auJMk)>YJp}&E+@-S(sqDQad5Vg8H|(Oj>1HPu^?mpo~)# zM6G9>=uEo(Q)1)I%A!XH209oA|6*bDu)-62{p$+tcF*Vz>2~J)X~ePK+?xR2!l`;V z3Z!<{wt3cPc;6_)344<*h?e0R{Wf2j`>gL@fl0cNM~$tyBpAPe>k0``%k8qpJ*my1 z;(H6z2k~DQfV#r%p%0BuJt5ic!<`yLMT*wUo+MNr+Ab3mvTzb;d0gMTsRYfxjvrHG zG26dbXX1Ww&YA_#&cdE<*pD z@l4@^;xjazz^iscbg(=_+)j}Nc&w1%R{}GJ+@a(7&t0?hFay`cI5{a1ujHfaXP35$ zR5%0#HQHKRxnO)4d95BhlWSaYZ{N5|TrzkUc57wgVIdq=$rWC7tUGD4+*B#DUq9lY_e27@SV~SMW}b$U!<5>7 zRaehpCeW${Te$GL$fWf6Y4I9j3Amzs`FfEge+1&pYYF6y*LK?NEzx-}-rM=VMjz=V z`Ol65*)wojVql-hzJGsSKkH&)ub2*JI2@zL)p|E9!%D#-LhGDs(iUt*%$M@S^&up3 zZlm6V8et5Ndb$TrA`%vBgD{;ffkw*~l_-9U#H3kshTdRs=QJ2_F)p7S0qf*bf>*ar z4;u~3?lNQSW>M_s%-Tice}SSs+LdOFCs>F~Fq-kUOiSb|J&_r)dv34inA`~NQby?L zFe{RbWzUsnXY}r!#vR<4!S*@v$T2X+`r8tK(7$<3eV@<$Q-d!VBEMR`7gl9ZgH4xl7Zq=n@NiWh(jB8+rcs{pf%w~ z$jf)6FC8lM^GJr<>DUQrHnLlap@PFWlie*o@h>OdD^ZX;d!24C>7{1w4^@}@>io*I zXD^vifn~&`K{pv*9rJz4B*4 zbb9;L`-lZr!9ey(TV-o;;_>g^enT}+H+pc(q;S9?EB?AM)7%{nl7bmqRRfS_6Fhz8>UHD=Gp|XP3+z)O`GfD!sPuf3 z#@ogEW3v=}t2;5@Pc==g4<{m07xoFmU|)Z&a=tgcN}WjxV~eCNw5MP8rwEtWEK=-bQxE=JYN_$t9(TwUFfBwc`)t< z#@H>g1LP1^)R^>Dk6BAW*BHTliTNiHt0PLjBMamBmF|^Lua;%-U=SGomKA)GmR`l9 z!<)Rs*sU+=kuefZ*^@KJ6!HcH$T_ae)9brrRF$4yW5D%OkW}R5okWbv5Uiz19@DZ6 z7k*pvt6KG4e~8UINK5+N+0Y$Z*1UjbBNn`-VmG6{KI4q{GJS0_;d)%eKg@0@0n+yK zu}nGxgqw7odmFve(PE(JVC6em>sA-4EV4H5dO8OVng!J*y1l406IjeONEa z9(jp4jJE!Klis3Q_6PqotPTnRg!Z`PHfijf2P;IN>xH3gDU_K#BQP-V(sMnPgABwK z>MMU#2#Nj%1@`qGWN-#DfB8GDQ zMDY2Tn}-3W?NHF=n7rNIqouMB+0v4CqI>{2IgsV8;=MKp_yXwf!C$m<0AKP>PCvYo zp0P>XjXZhWkAkyJlm#eM%<0PYmz{2TF3dzgg*kJg<67Ikytbl9$uHt|M+Zi;~u)!2Do zpaYMC1bz;Hr!f?Z(Tpe)9=>dQlkW0DYiBbNh>C9{VMA5I&S-|XFDx87!@HTP$O{JM zCvcA(_zd!-o;}DlCDhVOSqD$uwNr?+?Q3KWL>0YshyYR`LNEx5jvtsPEqk=K%tv4w zCOEkyJ!k8cT1UznAfF*zcfj}6;IeK7dv<* zO1*u|UpdRnyDGuQzlQ(=8(UnJ&hi!c>G1?*otocF;!9Nsj?L>PImBUofI+tdgG8}D zuIf`wG?rmL<84fq`VVc9Yv`LCv9BbM7sQo^fT}7hx=hH-XD~fTDjKwZy4!#|+lKws4bP^XNE` z8c?+fCfRK}`LLnD70SvV;XAGg`X29WK+NAN{HeghZQJUhE}ZHWfQl`;4Yb;Gxh4I+ z`Jnpi{50iWHPG$s3LYl?a3i)Fi57325!@`Ke!UDfy*$9Gm&ur}z6+V+9eXc$Tyok? zQt%xn3A!N3_y1sUvGDCN$GdH_pE4=o*?T4TwN^h2p(0a*Z%Fs)4|dgJZ(COX%hM%# zc_UZ}iIW1`#BHFgQe*{LDB@(XwO(2HmBb1e5&2CYvm7nsU6SzqOQ;wgu>;>xWsz-O zI~=ubLCot!Z}nSizfI4=@ikO5-%jZ(rBio{oPZb|t~;Z+v|0ZdWKVOSX(hVhhcUd4 zCU_`mZKPGd>_3qSK((FILf{sd0)F5>{2Fu-AhEg~`#eWL8~e(FiSAy1_Dwr0Oryrb zExV#VOlq{Hfc(dL4O@nFW9GwYfhHQ+gbh;IoacX$JyzUn(cTteAg$_O!0*zy_w!$_ z3_R+&CfMNkxh{CiGU=cTbxWNIP4(7a=zs7sV_i)Ix>AmciF!q8pwugDX1fM9k4F=$ zYp}^K=z?1Vr*E=&;$QQ2%v4aRKwX9Mc|y_7q>ez^C->J_lIr$?dtr@FeqV9fnHCrt z)C39JG+$tjPi`V5%e)Tz;gTk)m8bD|*r;qYv-}E7dGdsr+WJZThv_AApul7tDhw<> z8B+Mrkf}|4U(#vVz~(U^=m1ms#neJ5|j!$xk$97TKC4!r^VaE`t!R zFPEIIyRp>_5UQF~;IN<*JF^YOgGy@|!Be5E*KvpEoeh^`dTolL1Qso%Q!UK}3(PAQ zDQbxCk}I*AT0)f8M7`<#T_(H*aI7%Jeu0I#S*QL0#UHTn-f=SMwgiAx?JZ^spT#A3C10~Hc=L>wyh(!vD|6Rb zEk1*?5nc^$iPxn7|DlZN3J^)7fIc zt^NJaV9CyBkj3kagj?8DD|AJ`BDh|ak~x!SThY?-Z_;KgIHZ#vZ@UKad&86Ujf_M_ zQ#sc5Yj{TqmaAS}`9?lZ>Q$E0FD*^Ws~{gNk&8*I9K!E@(Ni(Ly|M8L+e%_~kE|%bF+qI@eQ9ci}(Cb`R*9^Iq+r_4>!A0Mv@aH|c*b!8Z0-{>@u| zTm;Urf8+n^da%n>|G(>sjUO4bzrX)4$gJYz#Eu`iTWhQK&m-klzW2HUEVPTQSw^9t zpzz|H@=vUa3k#`hG>7NgbPlv$^VFtD-aKg1tOlYmQ~V%t0T&to0DuL*jm+=v`E>ty z=QIqwYrw+sV4?29QKIbmFMn?kBPt=0GYdBp#h`jyakhDAR=#Zp854}6zWs{)lmOgXc5#4%LKBUmokVD#)MfU3FkDoL- zH+qg(hwl&aXs&+z+OtS>Nh`vjn z5aKlo?$n(cLq|uK zzycqYO@ANOyIRtP3%B$Wrw~$5!1$E++WbiXcm%mpi`j2*vsI5(oAtBLy2+mzNf8kd z3<(;nmd#`Iz;`V1{ofu4%gI^LV?B#VNg;+@fqxbel{Pe_HYn5Y;!QrYT11w0ZJg#}5teZRkk#%Een40QBxkyT4BKarL; z9hRVFOXY$E@N|S_{wKO{alNiV0#x4;Wgo*|DPWRaMg}Fy@P$P__7&1#P4`2B>sb=) zy|AGtPHPEXct65lW;RLu{a8fAC+Z>ff1jh{Dg1x`z=i`3DH$0np^E^CR4ij=W#xHs zh?JzH4)o@df|?plQ6FX@-N;3VVSB)fueP@6f$gZ5T)wT%hn*yHU%tG8_teIviy9X3 zi1ZWxe(b>a*t1!iyK?+X`<@+IZ#A01)<`DbVH>7M`f#fN2Zq1pgo>dy`!955uwK;E z$(_uiA`9gKvkFzKYBeiF^H2H8&5?2f5~R<$c6A?lz-=bkpnl4;#eEeyEPnqB&)_}l zZjO4(D=JWSCcC@3mh#I6TFfiab|x=#e*Pqjh=};JwZ)ykN{EW^?%g}5a~at0d}hX7 z3+7z-k(sG4Z{bArDz>*B&oJq&q-0OAYw#^C4Gk$Rt$#y9L+iF>>FgdBHuiRZ-D6!N zGqYfQ1A~rQRCIJaZAim4<7}mcV1A%?a{`a;E1S)M;Lm`C`wPyHcGC%G z7II_QitlcYMbh67@Iv^o>6Mv(;-v5G?Quo;*j;55Rr&@6A$!M=ks%sg9<6`)^a;aX z#OKyTCWSA)+Gew3Z)Uw_r!YM)Z|BFV+uikIWI_VZ!x!BebV`MK!$03We)wQO;Mx?M zQH|?ovOaMrLgIWHDm8G-@zGIXdhAwU_0PEF*6Wfpv$Mt)7Sz~*TTd>uomTv*bMo@8 zm1(#v7V1Kuuz{`C?~b4(aptTjE)HXISw@PJ&O*<Q)N+iX?rhvB zc4lS@{JoQrPwc0sS@D??q{eoKhJj)HSX4x$gFUca$jQ+`IjCP`3Vw2G`iytof`lPAk zQuv-PjZIFj9$$)<5465^JkkpiT@^?h%~vTqslzunfowXX73%k_H*RiMf9gX>Y<8$x zeJ8ROwl`A|XuC6k{>P6w$&E;{tpgXi_j~7xG$?*Aikcguv z2UtU4)k=Wtgr$Q0=+hj{ev_gG(MgH-4J?~O9gWcKov+DQ+~tR4WY*$!8ce;hOeB9q zRT^BlySuyB*4F;mFF&QBp;@|Kyc6i|M$qWB+MlcLOXjs-IwQHe5aTeNV#ZiUfi0#A z9P)FOR^yd6e=8LY?Td%mzXYw8aY()w@ei;qUOMdD{g_J5(C4y3?w{_6#c}USN^NMD zs4Q5iOy&4NB1gs=1#F~w(Soz0?h_Ew-jW1Dx++t7eznxI47Xm1+M69R_)>!-H6tb)68V;^XGx z+u&-yRU|@OF^Hm)$Xc;Oy~fDsD8ec4rFaMtjH^Pke>GprDLwz zhL(|$il6_sc=GlAB0$z#?FjY~dI{o&=$9{F)KUanj|4TUonGgqyIsMl{w}WR9X!$a z#`CG+=viPagYI-u?IE{9vFPCUu;9fcXO{C1Z#T*@xjNSM_aZBgdoV+7#@U#xM!zBr zHD6maw9rKz)TbtkK^JEEKYrl}(?rIQ5b)}xH=47E05)U7=3jpP%y&Dw*8GpZuKb7nO6Yg+7i z=0`gzEZKKPEKoUIn*-~Hfh|ENo4d`bA0a3vpA=^&QQmY8jOUw6uavdHJ=z;|i7c~l ziI$Z@ER*Z`dAdI7xL!@|0Oc$?Zh9Awo1?oka!}XK zz&U3t;)$1%$$W$Eu{IV?@f^bNksIXoK!MiZ>Uh zsQHwm+|A?d1Qzu@HtyF1O8i&!f;?1D0*70V7ok{!ZUv8|y1#Dh;T_*BLOVE#U42N1 zR@;77*u2w2MuPXuS1$PukYtuzszAe-9K&YJ5kRs#DH zUYmS)B74YXmvM3H2EFJMR8Ucopul)3D3}Uyg*fZ!>jyjhVW)3wajyh!M^c^SLV9;3 zrd)r8cc#+O9otPATBLiNg<@KBF2s{sm+YKJPk%A4@CkZtakFl5XRtn!YEA!|&tZ~G zz)B&dGC1P%g7$o=`|(<1Y{`B+9c{qaK^u85f{H6In)^;?-NeVfkAwNT(~Wl>i_rQA z+q;^Gm5FyZ2R?>>9xvItx$WrG7%udUpcEagX{M}=zqYVCE@An#U+l`q&^Q zUvMbSle>vKWJjEI26fEqPxZfVM`7We z^AV#lHQpRGA>EO`S5oTiWondmc6OdqEjYL9pR=1a8=IIQwkE2z+gtfPYp&a$@H*WM zA>vRAptG^@ag>|8U65qY0cs0-)Z(Rc$&QnpBD5HRqdbSB4H50LXef>&mR)KI9APoL zU~6KG1vxaDl=#vLI(U9$;%~2)t8e+$E~wp55#sB9%*2<* zT+(iT5Xi|d4Q<}ZhYrO#I(L-&w)|lQIObPK2&#trt?jS8Jbw`~1i>Xt*zA=wU87@S zVye_NFK^vI?rhqtT$KIzas1_Cc1MRO2+XK4Oq#Ox_LW&HC}q)fN~Hj^dfcUUbaZ$m zKNnMVSzt>zmvM7(*=~v|%{boIIpi%m>>Sj0VW!e$>|7Jp>*V+)u%&hxYk8YVA}v=Y zohtcc$;_Nk&FmxEQRr)Qwb+{of^7F8vQ4gy&Kq|_%GJ*|T!VIS7YFCAYsylLnyyP` z3o0%+o0KHmah`|hf3CYYGbm$=78cz@2B}E6(o9`W5I>2qRv$RWhSyd`_pm0DXz-(X z-Z5ynB%(2{Y@`h8SCgE4)>pewEKD+Lx|;~BE0E$3_hOD}yjgX9dAKQf7hel-d2vft zJq#D7LQkmo8?N`s8@RCWc(7$J`y?_>7?!=;T}8;@=v3;;5oU97j`3;{N24G|Z3Kai zXLr0lYxQPtW2y<+#{EjP5u!`}gx!cPLeD#uM>p5Lf40yvh|nKw&+t?cym(-k<)eHw z)_CYc-sW;UzfS^9h)H6y8Z@yu+8`0|?fueHY>0eKWb@KIPl6!LuPY8auzkf%o&Cdp zP`g`gf0_D-7m;O5uf&AMEFUR?R?3{|^! z;SY%m-5-#-+SPJ{0*r>1mV4#XEU2yqxHkz|Sy@HLRUffl z^`#1?uB@zhY)WWrzbrEx>@XT5{B>KBSI#Kjcu{b=@)^7`RRG{5wy}mr#s9s;T?0* z+vk!m!dd#iR9Md#XD=T*z`MRw{P3X-)H%LYrLCM63w%`4am9;+6_u4`U$+!=N^p-i z2dUX??gRtb*w|KyUFg+oqJID8P-Db_)mhr(HWZy9Q5hK&ZWpa2rRMWIKZj|`4Eo6^ zDJd89gLSfxGLft=1{*b&`n$HOG#p-pn!QZrZ}sVZ?8R}Ic+9lk@CApBu`Udf8tdC_ zq;*SPc~rze{yz0d86GlpAG(e}FJ?18D#f+4-e@`KdAHnne(OgK#chqaQ`vVpnsnhA z3@{edgaM$Pw&CJET+Fm z9H#}bVzzAX^^-Xd*dBqZiP>tUtySak4M3d|V5HYH3>%zP8%EszX}o|W=1I8E3vXvL zF0)4>k`J8>$vIlqdfGr7{qO9*^fc+rhO`TA#uo@H2`7v`I2oFML2j@qjDt9vn;VMJ zpW3A7nReHTO|3Xs?@>Eg|KLFJ?yBfwmCu5)h2N%eY&5J$EEXM+ir6!)HZD7L z*-teGZ0+nG8SAYpUyur)E}f$8wOwEi3ixA-?@&uby_o5n?4nT^Wva3T z6|ua?HSrw(SB{C`!H{lfVgT(CmA#s)pi~~d1EfOG^0Zs&sN8<)7cX<~^o{#s$}B;~doF5 z<}Ur^F77iLu+czulw2mhfLn;~)gX%@WZ&JHr_7Gt-TRjhegf~yJ*m6aaeerSIX4c? z%T|n61o%cdPTstX0*N+*^+|Vm(4lv8SDMGu76Lsp>bJUXw%wtQC7j-Dg;0#k^pj+F zbUbe2P~T1x_at9pCqLL9F%nl0SAILhQm!_QWJ+BghPwBmaLzT&S~M0AZKH0}2)IKj zzMoc)7oFkXG3xh z4#H2b%U9#H*saeNSd?79Nl}}ji?x~sOvqtziBc*C`7y`o>1jIjyY1|6GrpUH zR=|$I5xZSzKSM&IrK2OL7$p?*Tsa;vSRtt#v4?JQ_nb5{n#U+YmeNSr4}xo<-Jy*K zTrAQVgE!cm8dnx>w;xcVKkKyl==oRZ6uRl4;|X5DS3>u9lCBeRwOZc0ysEriWn}Dz zXxevw?b*AzW@OW>zr=I9@uohj>(+mH7nFL~cE=lum512+pmVMV^A19vG%?IgBrKvw z14ze>{>09e_CQ2$p(M)|L_I^p9s8xH#zsaBb~&5gQ>FKe7+KZrkYz`lgY(>?H{ zDhwg-JeEh_^2}mUdzOih-Qk97?nxC9T2o@o~6Bdfs@x9i_P7e3u3nq zO!&Nc^GMbTwnJ0eg)TGC4+b({)+E@E21oABeGo`+RIwghue3<%wYa^tO-!0zbJLrx zk|RDwR{C_~*N))AKSV+}|MjH$<}6qa%EKJJqcPJJ-Cw))yN1>ha`|$d%LgJ@@*Uz^ zW3(_Y)3%(|RG&KDqE2nO`lxYPY`%YYAKvgRZT}}Vng7q69KY!@L(IVTvuPD+pVUQ2 zc4?{g(pj57rXWXFK|wGqRpp|l-Hz7+lprZ3b$W9}(Z0C6__?7hRZf}fYP(bAtioJX zwbn0FFMPFJhxsO~Xc3yGy?N)|zj=OVvy*IiSnX^-gd@<@-qpf8^%asj^%W|{D9@m; zcH%o({bvC7us@|?MC2F!#XAlp?$jF3SOHrEHw$u&)bO=@LWKenr`DY?7+TZx_Vxzl zZ=fZYw1h-=&fJ=A-LuCJO&1$nm+XvlT{0pSD)@{!3&&JjQhp6$&dEcDHq07~Y= zhYvy7O}VM9>TjCU(6W}o5JDq04>(hoy4uTd-<{P@!f0Pk%HJuKrBEt2^P~iTIn`-*j%ge5l-?da$zjyfLOVI&_uPWP!K`7d^a$4C`5m z^SOE7hG*m2!nO0Qd@LQV)N-d$RBEam>GKQa>=9^9?_bLIYAUMy=H}*oEzK2-@ja_c z6maFi?(MRV3JOr)_K}%k7r6B{J}g|p#Ijkz30PJ4@}tNBGz2pIBcMRkos3AeUR3}H zk41*ug41x|ElkYo4!vQioSg;Bsa^@7bAVoM`irI~*)KS5o8xQeDh_T$9h9U5iyo1L-*q-pY^h@6DGwQeDmMj4w?hZM@=EuXLR8_VJv%m=4QcPErJA}d zG=OyF4$}?chA*0wl9K8zwsWjiV>IZCvl1uz5aA#yC7)$z*uooPv2fgtiS$m-rBt=R z<*Y2Vef1D|cHqUQycd=ctGzu z{>MpXJUl!KN=g(p&D>*2IOKwYf|A>UCoh3ErZGM_iKd1tp<`lU(HS1am$7+rdS-;5 z>*q(2TxyqPvZOk zjrRBdLS6LfQk3Iw2Y+5bSW$Bd3q!I~oC3}l-31pigCqzz%`@uj`4{wqBzWz1wI&PI z4;gi(fkvTqeYV%8Q{oWIx=~(Pc?`XV?355ZP95xgFGd6uhbPTi%nR${-<{8<4ITHZ zH~mD&4jJ{;Kry#mf4(rf%7e^dHY4)+bE4ZFa!FPe6{w^OH+P>SJ${%E6rwkz&mjsO z$vr)iu%h|Zt5-1L&Ys&!H^lw{*Z|fVzNnkXSFU8) zpTkCVyF%(8!9gtjg`SPcM+eDW@ZfG;$CxeYRK@d+6Bvcm9txWdqkKxySnWQoekO9 zVxc?CuveF_y~IX_2crM|s|OWR2eakf{oavgUhg%U+o5?KV$OPfp}NeryM! z=cEd?L?)2~?60t>sE?Z)Y8r4EuuTBfjM~^e1$zhtLs4U6TDw~Qh4mmhpaKF_6)49* zqLiC@BlmY6s{lY!;v$#Y?tCgJBNxlBuLsJSkc|yne}6yR$=MmI=?mABVX+Wg-^#71o zg&5EkgzVSv+X2E|=WHj{XEJ4d#+UP#Vf%_bZ@=h@ct@4;w<`eeeRh(yWS!J=;62(g zRqqpFn++StNKa&evzLa>%&7i0ofe|U4z#C*#UEV-85xw%zfE5fqIT9sRSie+04Zy_ zQ-*UsLwy^zhapWBlNx#I+bil#Mp#cDUraVa-$_X!!GU7cW#1;1oRgDtAR5F?-rfyeo8iyEdVUrSDDy-d?yini9I8Sl^E>F`YL7nW?x(ugvb+YBam%^n)3Hufp ztQ}xXjR+UcXxhT0S1tDg&XSOjaB(%+=qBC?&;qK>M^DVB zkHKaFb|H8MU#twA0I+W$|AC;=cX)k!;UeH7pjjt{XLEIRD{O5|&l}U17VFtzuO>7Cs3`;%b!n)Y;Mc?7R$Xo+R!FJW zA>sBjRxLgJaW_8(DA2^q+C>co%3Vhg-qfDmvuDpR&WN~S zze{e_$I>}8IT<@-u|R7%6MA>Mc(*R4d4;jQz77`^74;k$-q_ff+XxoCSSRKkR)c{P z1Ls!p9-;D;2>dBLIwu%6V*2b&x zqADPbJP? z-kZr}TzCTe+aFdy!zIaP*CfIY3H;Vtq#P9?Dz&LS=6#9?53M?zF{zMBhJ6>kYT1XN z9JVXqT;7lH!*_oJG5BXcWZM7EW_uvc@tC3XLIt zD6rn2=}wUKq(IL~3lh^95A0Zzc^>!*-aIIp>^06#_L(;i1?esAzGvtRu$_isKm`~I zRD->e&|l8|Y83dD7fRxr`zg4zfuxlC@2tO`R7;1Q^>46=!A20%{|yC%yTZ=Z^?t2f z?HD-te#j$igM6)kul60>1X0B*{WYwONslB{5XuGIm9D5xVF#cd@WmESGyy+Qn1hK^ z(9+T_thagpm%X0gsC}E=`!ogogN#)?R>O%_G@V7emFhm-k&#$!WyXGqg@vYY@`LCi zmv^;e?X_@!Z3%*nCf5D?Ve{<@Q&@^>A}e{0l57IBeXsylIe|+Gm}vOlZjj#ILfkhi z8*drv>+260Db2CCD0Af7IwKNz9qmSyjCK|NfLV|@0HARJ4%p9oJL-j{xEx7 zgFgWSe#$qp;pe>d7G-*y7EhnlDmPrU2@=Z@Mt6)pxn^E6*E(8Fy8`IY{fm(T-= zsZRcW!FZ1F$P_B9T(&K8+*wNmCwMoeU)%yJYKi^qjbXn61}iF!rKCK(Cwg|lw>|eZI-acPdPw?|= z8wOj|1$;%ZW!bV$3TA1%NGLd~$D^sxsQvwYhgb|w@RFoqEi(Sh?-9nDRM>$d>p;mJ zXnY8O92FUv)e%UMhk!}M3>IDe=g*6G&{$Yc&pjwx5e@V~`lR>_iE3VG3QJ#oRk#k* z8DFv0)Bwi?g@KHbEW)9%;p)rL@o|Mw6>M)wT|viFr$O2BSDsaH9KcCKh9@^Z)7gmC z7_8~={S#DP{B@QozGmEi??h@v9W>en6)Vbe18;%3AE5{EQ^miABmxiSo{<@BXub`i zNrv)Fe1wv;u~)!i5^SIXRE4L-#?76%LfzL>`~3%Z%g9K7-Q)$c>fpatJXBxUcbIV{ z?4y_)Pn?*)G0jM9(w);;V4zDdgRIOssp(#CmwVqDlCgNR!Kr3HZwmqv$R$&0OpgR>5XAn6}-C){r6zyf>Yofu=~mBDeOVuPDd=F>HZW_ z@GOrmbmqMlT;O+daoTI?!3MnqzsTR--X7OqBMpl*ctMF44e@~YS5Q=h+3Sy_feWquF7*n;^!1_ZD6MqfN)ecpZG36B$WZ?MY^w12>Q6&K7AFtVu) zrxX61zdvzRJpy*jzzBNzEJ;rgnTvjFjmHqf|73Wsfo$OkFmxC0#PoDGP~RMZO!3Ci zksJ2%@n0WRsinceAXEfv*JoFz73cNv14}83J{Vv)WCT_f?J9CRrE55=C<1vFM8F_S zpkThn4)o5HiY?XcLa1RSQMxX2hlbOUcDpH#OeijR-SMTTc=xX{sX+M?v|1e2Rn zt+iv&DB?_8RD*4iJ_i*`*^+h49XtUK%yEVO!X z4Al0K1&H5DGc-;wC00zPLguk*5(66GL!y zb#*fJm6wx410})W^h@k#-k>VlU8w(T!HQDS&>+y$(*t5nRz=0@g>}LY0DMqI8g~um z8AcO^wMX@x%rTXeY@u@)?k}~?6vRn|2H2~>f`^YxJgYb8L^vFgVq!6w$oGVo042i0 zf<}E8p-{N^<3H}`MV+>P0Nyu#d-MN<*k4C8^71+>59BpTpmrHChnx0O888o-otD*_ zJ-N(Oluv|}Ajln|ZYSh$pi!}lv*8ehv3d;yG{oKbGm(D&3)+3wlPB-H7C(M`fk`Dj zw93;4tMb4~Kq3B2)JI-X5fcP((Ehf?&Fl5h0}k+%>E-2Ewe!Rf>yBtp?tzMio0}Vv zXoGT}Kh~2KuV9TU`_Xe2Q6uT|j%qS+JP_}W1@Btl;iva2}k; z$Hz^lir@13b%z;$`Lg1HKuk)A`p9z8^<0}H5m2xmH7{Uq1G0)vs}IVf#%ql5tg&ut z=B47QF4=fl=T6a?KA8j=YO)A8m^u$orC+;_zdbFh?9G;(`^xMhltYZj77M^ZKv}?K zV^faLby6U};952jf*P@E={ph)6H|6wT@s*MOWkg-9AaPK7^DZ2 zDc$kN4vYKk|NeW@?_;be_Y>J`tM(U4OJ8|Zz_A0?gDIT7jJ5P{apYbDqjX!zQseLG z+Ryel;DD6ex_h-#v>C)U0M2noEfxrH;VfIf@(Nc@M&&wJ58LkJYB>pD8P>q3Wc>U@ ze-?WtA?J#pKNBprthAYyho?I6X$wce&oW*us9-@Nz;57Hl$TTNLik{BOB)&+6S=Qg zDgGAWi=jk-sIUR57N&3OTX3F}UIj2kHiNB4%r~n_)o>0oi|37hi)^MsO~xuP4G^s7haZU=*RP{fQc@gZk6C21*%l~a(V7sHKA&133fWCtq~_J@*x+z* z&VL8pu3($J%2tV6B z@E=((w^|hi`5&`}$hz2WcgmyRAicPl8IVcbfZYHMQGm`F%e0Aubu)t#_a{AcLf1YH zlu4f}&GkV(16|7wvE(pZq#LzB#`g_$ChL-#&eS?Mkd-Aw#HTpZ!jobnj*J)cEKmZs zXTIFvZ0rW5+drtNZa1wDbIN!4^gr9%fJ?*(q}-=71TKGAmU@faxHz~eaL(P++lv65 zv+5)|WMn8rhR2lsEbw0Es~3Us!jW2?6o7%+2TPu4Fv&julgU3Amg2NcJ5i)5XJBAZ zC*@C2uv|_8ywuMgF+E~Ds;_)bO8Nb{bBoXO;0>N5zi=8@`lEPOJ1-0R9O`=SaX}g2 zJ3+@SVtq;uPtsZ#9*TjQguY%jp$r#x#>*2o*e~CY@b{pro)8;&yxaDnSGw!&wpy+& z4BW%Unl2Vsw{}U~E>?@-sqhI@N01&BvFnpY#SaqPws|6t7V0B8w|KTbDa94m3+p}5x7ei`#DH9p)miFb{;%uP; zS2UoC{#!iig<&(Oih%T>qN4H_$9HS95fFXQOL@nRomOB^_c8L=~2j#(yh~6^}YO5JMmoY5Nr|H_Lm?gf*TKq zjKaXsdVP{IGVSZu;6@7U!MIfrmDLT>lW)dWr{~L_$fBlO5&p_+8B9P}ANBGX*gprHVz zRhYR5J_WM9c%}s`_%POXy8v5XAoM*nYrfhG- zb%ky+0FHKh6_eIS-PC{>{a*=OQREuzjf69@ut5Bu>%e~aQ_ryU08eTTi>OILXz*Eu z*QZ~-5+LOCba$r##Jivm+twfzAmsu2f#QUU9{5a1+EyA%pyXRulcOB5KB(A?*Vl+VYAp@yAP}%uTkpzwTStI_;=~Zw29o& zx=0NJ8GSa0_@1V}3)4#X1P>XdK*`byIJ-7O!E2R&4^L`1Gi)Z?69YW^*Y0==+VSfyLEzj_}}n?{-a64$Gh8Tesui=k8GHffhaF`?rRbDa~OpPh6EfaF{Zfa z&1GLT`uoSHHe2lL2Ji5b_4EF9v^LOptqg7t+ z-x_Td;krmr@$6jJL}q0k%uS-M*O}bAV;gVQy%su@z>?R*=Uk5aT`qeyv(nS=^EEa~ zthWB^OlJOb1)iJnU874T6Ng|46C7Rcv$6a77J*r}19{i3_FsqtN|X-(-uHppp8#xg zzO*P^Jgr#xJ`J_b_n30xUj?7z`)I#ubb)z_axOTe|9z7CA*lBT?@7C`!$AXUR@xm! zTsZgq!0~0Ap$r)30Y>XAz2Bjj5_4~_?6E8X-sXvREUa1P#bG z-w6`qMd2Tp#{!*6zPG0^co0&(7VbW8frJ4P7s7Sh&Dohr>m*kMriUoA}D z0kBytPa-0opZ_zjWDt#Kvlmp!R^(GctMiwZ9jrVqYZocU*3! zhW?M5TWQge))DGZX&Y(v;zD({ePqigY>^W^e(Ih^wkq8&1LzZLaD4*s8?d_L#eeh} z3G@H~T${;Gwo7+AL2qht3hl) zPmubb;X%P{3|OcZ9~p4+L>~o4kBwCKRcZ(u-PW>WK@7SFG#|m(Cx1^EY4?vcF)6I3 zTDH%Xnsa7M59l^Z=W;~tOc;2Hv6nt^&6u|hc_~W2kohl>YAY;Jr2BAld#*$_47QcC ziSIvEoGgi|Kso;({S?ju$j=~r_`Kk#?Hq%6gO#h3k8?2nar+VYPj3h^*n!28Gl*{N;cM{u8~cO1jd-UN2%_3I6kJ8Chv)r$OeNAzhl5LkYxNTW)sA^D?| zV+X#UY>2C<^KhOe{~%pWC=a79VS{xB;EVTzI@a>rJn`b};)Yt(1yJ~y{Td`8*QoWR zV3t~%1 zWpM04%cnIc`2cf=x0g>vEvB#eM(SMjcV%NAqx*~6nmp!HRdE*ma14o#7f#+cW4aP0 z>#S^Yiu-_KUoGur1RzFC8%7^=k*Yhc14T6EFOMC>A`wH69X0AbQ_Z@WANg z!oaU3sM5XJ#aN6(R-kAiEnWw-?M^87nmu3Z*ZXITzP>Qm9qGgztD6SQ+3=>a3P>e(H{UF+4ii=o!6 z*k#?{w$yDXJeoe{u;^v!r+lQ3WtbYml1I8Q4z~6S%*6pG0K4n}7{L!PO|XP+Molau z&52Yw;PvG4UO`LbHucq7vEFFX-<$@D+}OonNxzT;_(T0_*o( zEFxVm?eG2$cAJJ4a0&I+0l+)p6LU^G1&8;kVC}~l$*Yem5-892O({3#WwOiAES&tr zm3Zf&qE6&*b>NxL^oKEL_@CpSUI{#ngFA3Nnv~rs)tZTX8~t^fMYK6de;JAcm6B76 zg1FyZ14rw?Om7w{wBb3HEbP=Pe?Je!xAiDsx3x%LNE8Z5OPb4MqdAOS?nzi65zyc0 ztThGMaT$m#RDyz=ll@x^*6?8`B41P+ng~>qYOj291)%i zLUbonbe(Auev?Oq<<>zyPzQ=hXmjJaH6<|B2_sSm+H3AitY~fKGfg(iYiV4ap*5sK z&z&GH!h!$1Ro7WKJF>@3y88OVFMbR8J;!7xDTd%-lngtwteHe1b3^X#q&M%TDKj!< zTjJtzAS3TYR<@%QWry!e2@IS^^OD0)OJdX=oL?Pg!TaIsc$xejieJel=NG9e3nq3a z=TGBLQ^3E7AUl~NZzCiTETYT({;X33%yow#Ha#h5O#&CiKq0>x7x}ifn_KV%Qmc z5W&{NR@u3-%%)SchfI;iCRNhuQplF?Z$>~)7q=lQ>A?3BJ-D(T+Rr1~HHQB9OgkZB zSR9@rpCkA|@`dfL7NQ_If-^Kz=1r6`dEh&fdsgR9EiYzxL4<#kmwU*Noo`6{B9e)^ zUdy~r3f852%)`rDeYKsZQf7cMY!n43G?<=Yvs=SY8oB)W(9Wam)mBUf%yXwgZ(>)sQ*x&@R55ciJ*HE+I; zjREM2iBl60<)z1}YBh6YHABkg$m+Um4jBITlO;ht6F5`j)k)@yy?=j5z(Sosf{0__Ros^^?YEI92j0Q4hx*)8|CfW~-WYCx^OKiVai zESzG)0pYxh#D^q)Ku}uEHK%^f;hB+A1Ug+geguMU)5XF@s8c{-;P^a+umMo-ujlr6 zciYS>6%`bIf<&$t<^Ow3m|e)XD8Y#S`{d%CeR*a`NR$BIFjU^(2rHFS+CC@Vuky*y z_fC;5Z4qgc4Ti7Ox)C;sike3}V5I{Pk}3a+ztg_`yBTE1$a50wSdJSdh07MZ+p2GU z(^8ae5YSEN^!F73I_;?a0LUFsr4t$qB;{7v0>a47?x<4`8Xm}4^CwJ5CVNyuQ86+l zRG8*ov@kL^|Fp!}CY1J^h=>SiL0ct6LO@#ppoZVjH?FhNQ}hnr>DQ$8mQfmSailW6 z5Osf>-wIO2?xiy)%TL@Iz7soy>&cI$``i*(%n#ykIg?{_fFPl$fW3tDI#P$0v|kl8 z7m;;G$OKTLJrX6{FpXfS#$^_Id~iy#IEz%z9(v=+(gyk=;y^`Fis3$WYA3htzC!vj{~5qEG&vN^ib2pGu+Jpn@P-H*!lyI zo<{8W5Y~=pEoN~p!U>zRrzc)^!$gIYqF(#*{QS#2PQt+_UJ1z#47 z7ZQ&l^U(}Q3iRG<2~x~MB(fRG0nA!t?P}j1V4~V^zz0?3D4P4(eX#9Lo^QvN}U zJ-eBo-6Uuh)6#oKhW@=-tFD(zZ+d#V)mj5+^$1#QyzybS94aqlIh=?oCMttArm*>- zU+fYzS{GiB&F#G0D2X5%XxFh-nBWK<$sd)Vdw&3c+0lELKnH2O6MzwjK;c{zkUTEVp9FFBGO6ZFe3xqR9f|{c=(_*3-lBz08)N)_wCmU6C=pp z;)6nRdC$(qWFvYZrK<4#2B>#Ut*6Y_Cw3_GGbrCs)C!@=3)ZHYAjF53(d81Jr+C*s zNIJXm*s^G93bB+mS*VJG=lt-Bjt&GPaQB|#t8XDFD&(|&b6drx2GqXp=*r|^7csWZxg}H9 z{gxU*>cm%ky>QfD)T80+tD{XVmRZXgL1JQVvuk*1>pyN+qp6_Zw-dBI$$%?Bb-n(u z#wG;qVr7<$7m7))^pOmyS@UB;q+=3dKYS>p?)2988$>5K`Gq`+Nir9A1a`?g+#Gg-|A0cQ3futg1~sn96qZ?` z?seAnw06--c}J0MqU|f(ah8{4>7Ro)QOJg&w5sC^PQVUS^V9tC7x002F%uIVEG_n&Wf=K7mY~{p+V?bJxfNejgsen(w|>LsyK$ z=fxujuu1Q0?gA#W?PI%@L$mD4q8Fj=2qns19fhw&bkkfrvTfsO_sg~YJbK^z%)WoC zs`@1N-HW(6M=m1ojEj6_$!yQuyK`sK=Rao4_dHLd^~7PVGN7yqgZ=k9fbLa_AQNsf zm(|tT;MSv%%P3s^@P*G}XsW`eVNK)Jqu@Al4LjS32+$g}R)`o?El0_8ikqK>mBdkS zN(3E&!nD`1`M@PERo`vv}&yP2_xIaE}24-eulnuQA#gmAXDAkkS5QxT@l;ASkb^;t1p81npo18ALz@FcCLqin-d)-oT_@8qa zsnPLRNKoSHpa~vZmLz)`GqcrE3)8|*N<2QeSZhf^!$y>^vxCM0to!gTaapZd^j>L2 z41=1v>p2_2-dctcW|7?3+;yctjDACueoj=<< zXmMwDta~5F4$}5BW!VOZ*^D!(SAuxN2N~Mq6u%pf&@;!g=&$L>H&LX=fqWOq|A9d?%BEx~N zZ)!LFy1K4X?=DHebwD}((5xWB8tLv%5ozfL>F$(nLAtw98l*%J6a@DEdB1OW zc6N7m=9LlV_0KP!=RWtId(OG|*g<#BeW<=+aNlxt_`9?HV_wua{au=0^SwtWDUVei z7PZA_2}w?TG(WqEPIYfXf~C!%>wDjqgv7k}rXh9E#YIKF`SN%~__|MeO~MoloJ|$8 z+D=@XgAJQdK4G^e)*XT+uETn@hXW51m0otZElaUYV zb*R7aYx+RDcIOsyb7+k3PJXymZtjH)(R{oD_M(ud^?lDq=NFdKc{B-bOu3a=t6KGe z6oJs|@`091VWUJ>b$eBPOyh8R;YBO!`mucdFK6kriq~4Q2p5*i%Y91gbfPkAl;0%% zL0v+u4y8m1WT&I=d)H|w+3L}KzKASWT{6D)PgAt~dR{(qZ#5~WKDN;2sCAHJ#g3aO1kL}%} zyd7aJY0>*mnkg7#+=F$J8>{GCFAQ-b4^rP#W+8Pf?dJrsOg6I)L6^uhz5!ugbwN*s zQa_ZKJsT|k8UM5Csp5n)yWdH2n{O64}!9r15+TNBt889DCc zy;ByDDk66@XV2KJ*x~~M(-p#eYC7Raulwq}pI*i_I_QTTmV^t+y~K}1y98NAlA6@f zc;(Uibd)9O$;a(n1R+KP?oN@|5)NmK4+;I~#ZX*_)F?EVD+ov|EZmZfj_%yd==3$y zlSbX8=zT;nH!9YHmBm0VUh+8-;u-MCdZiPgWX2n};|~{Hd<|_7M6MhnMh^Z7%wVgw zv>u3FlzEJu5P42%_Oc8BbP7R7WV6Gw!oxpBhi63vr;HaaJrgF~G4evT23Vro*i6Q_ zifgF>qdH%{4|OKQmSIrMb}OrBiCSAfm*c-+wW@^;grNRK%F|D3_b~Z=dpfB-aMl$7 z7$)70Ua)E5;>Ttd@#1l*8LZB6?Dkc(`^3hFO?_Vn#_-as`DLYkM(hd$CogavUO?Z5>aK5^i06cIM4( z7ke~Cdym#aBEz7)H-RO%H#|;nh(@W60%dYy!gqRy5&=41^R$@zDAkw~KH1-ph?9;g znRWTcWAa;hT;JN8xI7`MB5{%8uRHAyZTrGrUZ2*m?5aj_)xX60&Tpg;qtQ}C)?WW# ztN@92@oM~v@_%0WhfNatrD=uhu)J6iS4llOr8PlVbhma&-ng)RUPUUeHS8?R@RpWz z!J?65FkZyePMhJg44W~LV(c(-4qnuqullSqvKamUcZxxazDO6+=TX6+9g#6&(W>88 z85#C$Ew7&;0XViGVfgjzAI`}dwq*(R+7*h1&_RXaDGNn+vH|WTfxfy?@_}Dq2rh7I zfFf)j&n-bCx_tz-JUmi-wYVmNfKx}h%AqeNfG{%SiE21wS{4|ql%}0%N&Wm&G;3H$ zI<_R>Q{Yc?p+&tNhNXGJEJ+qEK1|!8<#JQDHHFojGP7rzq&=Ngl6cE7( z?Tit5^9yHeuX8Q$apQ1zH2sr*c<5(yoDHOjCsU24+pcLTJZV`qQ(>R=_;OchRb^-? zG(P{gNyMt7`>ZqPj<`eX-xbDRhK(xp(wIdUTes>RnJ`-i65gp(6 z%@T>km{?7Lw#T}Li)kj7^X(dorx6w&B>Lqa)=bb69 zOve`Y*PE6w;H0zR>Z5UDXfr8)9rNy{m``qaFgpIl2@YXEREZjKtUV@he;k+HO?2<^ ztumXa`%@aU38D?{_u|Zllc%BCC6D(ECHs)F*>2(o_n@@=9kj*@6;rFRNQ9Hs_jCC_ zBPX||tpmG4x8$f1WD6GAv70MoSpo(}bb#(f4i?SCjtgt7qH#yxns8*GCb~BoGyaMd z%ZFI}bNqL^82KN3gwbDwfc7-DXPC_+`GR1glcM7QP+4%J3kFV0;=sVy(G z8Y3pQO|mFG+FdhG{8nCc0qcA}t&}2@EzQCG8n$eI1xOhr-wofiJCe(N>KN9LDuSg6 zWERAADkt-$DNGma{%n%g+Eg&(UNTa!;1fec)G8D6TcdGq1F=bg$4Z~GobTPQ=M!=? zH#z=2mPnM@k4Y6b!QVpB+-hI&)f0=+oeLJ?M=k2YV2PE?qN01C61!GC0QpwJfLN-= zkfq|M5LzWZhDR~=jFA1QJDf394y|(f({we+@fBH?U#y=lY6~?q$0(c%LDRRh^l$dTw$Ydp7op7mvPO2wAao-evQa7~um|j} zWzpEDNQW9T@iYg|F4laSOpzOTJ&i6#BcX!mMg(nU|AJ(&MS^w1IC1qpm4)_4{j**; z(YStd5gVt`XP;iR#)dB|;NWyfz-) zfhDMC3*0a%zNih-;$#n!KKn}}xk&wa@Qfm9=Sfb(#r-yn z`1{N6i2T*mAOhz_Wy9|SrP@|mnlY&`SF+T&)4S&{4bH}T0}Y}-aro$+<5Cz(hd6JO zF@J42&Up@6bm&?KPuWKK)|CIFLS(vOf7Q&qQa%p$+ep8Knj!KQBp=T7j=v{Ye@~8J z=9Aow=RW7Vn|9J?U8$SW^uB7%4&d3ks6N=OyV#t-WdZ8wRt~<8c1s_D-ID9(@^3eU z)7~{%hMMu1gy$>N{;OK;??CYc6f00pXVr9#(t|Q&VF6*41Ogb$-i^;RQolZvpWroe zPnDDZQoo{(N*~huy3et)r7e8oZ^pc69HE=!Ke#n_P!F^izO)Bv!yF$@bv!`SMNSl| z0PGyllq3Ubl?G=_pX2VRiGvPDETp*faieTnx?taj+Kl%hq!@}OcqD{P%zfg(H0idao9E0ls_N3Z9QAo#`duxqq8Oj(DB3`xC{9eQehFu4)!Hv_8wrsv zXpIS5RFw)m2@?7F*8|GW5QMr1$g3bb0`4BDdqlFcr73$;B>b&0_B=IAkF1|!-Tz*N zr^~8pdRUUv*?PzNn16S_G_%V3#(FNcgcR>5Rl763O)wQzrN`qvxI+p|ylV*%LbF4F zPW4YH)CYdx0zxm?a$ig){91di`6ao~!0X>FW76%n+Z~rxh>xadY6zp;J#O;rZq4Gi z`PIVmf_G?k{0SqWUc{Vbm$J)C(|0`kmDrHUD%OGHj}iWdchfzb)h45vzOAVW-&w!H z8$!G*zYbj>TK4q!KdNR6GXhZ#NZ&!PyQ0;1x6k*GCcYr&k2^fVwfBqm^;(5O4i4)fz1fbG?Nb~WRaRskZ(iy2G#GxJFa*NI3 zk8|bt(Mf7p?qvjR>#RSMSzh|An0nrA=P!qlq%KfWf2ZTZ^cf9t{XzUX+;Y35v*Vp0^@*I!%X-)%WpxB~!%==gDfSUJ7Sb zf2;|ia&oBEJ8+AR(Pl~XM;j}-pKa`!Xy7c`)Q=Y_{P0@*NeMiirJTSJG|E!HbGVe} z=OnLWl;fjA-u@<$Iv1w0_Uge2^~>LLaZa;>)RSjB1#hfG6eobLd?+Q@^b8d(QRD=q zRsivV%rL5)Pn$MHBLdIwKT)UFATFbAg~SKr#C4P4Trd&nMRG zxsKaqJdJePD`?x!dCb2jSo9Wis%nu?G@<5Q^7{V3os<2*7L7_7w@lT*7hgTfd>K%Rg$7sy3GEInL;8JyJiLHte&mkzdfE zCf{@!^KhcK5BtLE`M7HVD}+l_{f~wL|B2H_YO`yhRV1k%3s5XCf?{davWZ!J}y z2jy9tTyr;fW@j#LS{yS~{}p2XxEZ}h7p~t)X7|(!S}TcM_PHRi{Tqli-5XOUs~528 z7^n7@v{9EfkrFO)%U!f94zlkFw{Kmm-@K>acE}EEpr9v)$lszI`l}kA1aGYAlGgs! zurfzMD%rk@Weq6^A6h* z2L=%{q2fRS7<<9P*hr+NGr)@M$V?+3V=kue!uhBu+qq{auXaUI(sf!8VwE?g2$!8? z;Plw%%k~|>-a2%h^zRc^{3fj6IqH#3@$P@~fs~SA=oHFLWXZ+6*DLA*)tmR(QNBI3 zKRuM@G{b<5t9NUBvNl-Oiv+f}lKa`M+?1tSKnOV92RPkNe5D)VkEdY*^Dm|+j`X>N zN2d-fBxUImB6mdjxj@gU7D6lCcgA{qs$inFxx-ycZuRg$$2r@lVJI)d{@Y!HB<_(g{`mR?o^+W5B zEQC>)la!TxQl)P&?n@4B!h6)&NLb>{D=#gcf zd*;_`FC<38x)~RGMvjhb&>+%u7!_Tx>t{g72F_+?G_30V_V=y8kon%vAbrubpECmm zz=BZd6GFPVEbBqS6XfW={aFlkChLzd$)AqL>h86#>(9TjI!u?7{<_nBhQA^?CbUY{ z!7z$nhr@+)kAGslisG2aaQ=-hk6DX$aep8{A|JR98UbAN9FD1ffQK|}Ss_1Sk(IsV zJ)82%_AS@Oibsgr?I-j|d@_m~BL2eK+K&qg3V@|2CG^i=duvMv1Q9RKcRwR#=jGWi zI!@cJhH~sX{{tt?uTbN%j*L3M4-cW?n5IgfTX369cmS4kDi=00|7AUf7sImnd$vpW zSv-&Ub;y<_;jB4Ey?{R+9BN9(fM1p-i z(x9)?Op;`-{gpQm9l=Bn=>G`KG^CDT62R-^10Dci0giBJQF6itO|-4j6lG8W)WeeFI~f6|B8#uYnU)kKD4sV990?&<82)9bh+{KpAG8m428bp^#m3) zfFJxTzGl9c5D~k5fioBW+;dwB(OuF5=Ln=*iicxk7%dF$2%GjPc2!U(r47G=ZxD*+ zoxQzIqtB5DtA5BnVIDxI9^vs3qo2ev0p8d`ZOr~_62}%%2ho}l+~N)kn+U6C9Wt~z zL*o(>Dd-_-_?iDHt`BAfA(4b0YgBYJt}Jmb2+@JbXxq4m<}8Bo1rd~h*=92^tZ*cS zbG7-7_-MF0K`TF$ajo3pH`@qqKDet_7THrS#}S!@#`>}mhj zs69XhiQwy^@Urvsdm+=k^jNk@fPnSu*Do>4&@44~cOJ(0U=Rr*3^hO`VZlC+$$Jo` znMS6!IJgm)hk#L=?OBja!}M-?GO+$`3ZWe98i;vjQy+e(HQYpld?uQSB2P~otV!00 z3maMW{!bg4Nq-EB`LWj(PCe0Z5@1{jPbc396gEe+1H>3% zOC{#x$9!5qa6j(P5-KKFx_Mdi8-!-_ePlTYje6_5i#Bdq(ED)`huwO71T5af=U#vN zkq{aRgoCi#C>a(2PzpxtUa`e71(0#eSeGZVzCv}f=iY0ftGRtO*btT%RK@;!pYyQ? z?;j2G7e<@WW&>oC+R}X?<4gv$7PwqEpjuZgd2_J@_jRw|nyQ{OTxRfUZE>~P!Rqqa zxY<&!%lP&Tp7wjwKHDEJF^RZUFhl)NxD!wANFy34?o)bmMNb_u#C!f#lIlBF5RR!YknAz5LZfjz~EUX|OvY*ujtI~5nWL2&`E&DPtT8kZSeJ;=z zGiFYe*Zsk;UD|Fgy8O*RHvJJgKnJDSc5vJ1cW%d#=?{XQKIhs*LcD&jX1?x6VG+&q zgKvvwu#Erxif=_uuc?WHGc1;n#Rwf)7GKlm1_T#k`vE4z&7SYtP?H@-@6#-9I=urP zKFuW`a3udx4erCKD{HZ#G}g|}ow8_gkE6VC^_|hA2u<}nOm}r%O*4F8PWP|V=XM`skV^j0HkeD|pTlBkv?b2 zq)xSYEdq)KDc)_jgs1xj@|xhKiYG-)Z6~IUt@WqVnw3j zs@+?W6-z+NAOu0&j32T5PUf>LOV^!p%%B;DHpTvO;@h5*(ko;2z^H~R=Dkp&kL;(; zJc*NgtQ!ueZ2M|k?)X}sjWeSsbbFV5-n4eK%~W=73>zfU0{wl#%30xJI(;2KIF_f^xKVHpet2)_~IP4m?*N~{ks)?p=a`IVGh8975^ za`yzLCKrxVm z2zNrbgvKjQO67}3%#Vnkxy{+ zIzz-7E^J&wvlfdN=*pJQE1UNWA>b@KDM77u_{RyU!ZWO2jyLW@6&IvA`a)f^Gn0oD zsE7_WSabl-*v?lI?pxxF-Rc~9iuGk%Y3cd|{^V#ZUT*gO?c8Q<;n5A2iqvB@>F99z z+J2g`#c2^09sK(#&rDBMMwc`?!%kFCrt~ED*&>^%4eSn36X}ZWx2q}ii)cvy{$Y8M zFsN$8@VbX0Dm{#Ag8&FZSZ9Mg7suF)D3Dp%$&Gzcsonc?t>a z_~s*lIR$?Cd4BtAFZh1Ae0%FN{Z^)sElqFQL*RLv;lIYiu1AO(_F-YqOKfL_+w=2p zjf54wYRAt^wHzkOjL+goa2&0uvo4qi6vMtmWK^fD7+#T%KZBYno!>xmyZXMbw3Yo0 z!)8NOIwPS9)bWtNtbfQ19($t44M_W)kfH7)blcYwcOdTerYL<$&a{q=f@Z))7NbV)tgFHws9UHN(D)h+XxkzWKDMM=eLvGUJR>ub>n_s zsXdpY=x>MIclq#lK9fvpPRg#opmTcAe#;#_5z}4u@m4BDYg~BDm+8^Hr-Ibcutvo;=;?&k z6VQ%c>?M*J&wgeE>NT-y`@a~?)_1(%Oz;fUHb`gFdH=Vv38604_Ba$pVrF5f2m9>= z(Ep3a(UaNg*ff;W|3UBiWxA%+I=ovLuQRPzGcsV*-&#IZ`(p6jozmVsU7OPbEn-Xa z_P32;kBvi?q@TgIL>k)_HK3B~D!koo7@}5fh6~Pn$(|>(sEc?$qkv zG=KCsy=BM8DsTVvs()?MDT;5K)q>g+Z`{C<=l)bcN6zK8-TkWbqMyfaU=Gymva4e& zsiHsAllb}w)S*=-?!g43)WCvtp~k+T9$mEt2U>*U*J;ia#u6ZClc?C7i@;n4Q9k&r zL(TF%Gnf(K1h9;-RO#+o0?UqK_3DqH?C}{JQ$XlmZ6=jfD(u$#IF%TojsM+2H+7dR z?yZpQn;YFz$`5OrhT^`9P1RiZho!EcPlz&-tC2BbNZASTFZdRK=|Jv=R z&s@VZXPtJqM%)TQGwPwLLKpnhg*>&Ni|n}?-PvD|AK0NKck8hg%}a>2 zx*lzS`vAh8JZf5fo z#BQxlYV|{Ds*rlhTPbq6zp@?>2#K24j=9g{{&7z-wLitA30kqx5<1~6VtW4WQ?|yH z{}*=dQfK|8#fQP8{HlTts17*>1(LFGuCx6ljdJyQblF6?~;p+ zH4fH={!L)Eqy0^-EBm~Yvl>VE9bxW2W|>&*NK$9;-!(xTa4(21AoW}O_^k9hLc+i8 z0PTWK-qQtp@kJRhpCI%`LqT{6z~n7?tyPm6`%PJ={RKjxe-?wGkxTpwWQ)3vw2Y34 zFIJ~^Q+3=o${&61Wlt4+_tT@2jJK}oiqv%T+$D;;93{jeX+(>5Ftwk zVs=JFMdY{~f6By!QbS{->tQFpTAj_ZV0a3v>DXN*mp@JiOM`I?8Z?ZIxy8lwU~R_* z<)(pwK?Gua3Y5u!{};d6;rvwlGN^rDXvzBTqmmW|fsm+U#e~!gT+I_Z&ORhj(XUG$ zpF(8%6D#U5(aTFR;~D$$<_4~EILl%(ClU!TDsvUh&7Xno)4zWnGMl{8T3TAb6;@PN zFYMlb0iC8FJ39sW`2?|FhD?ej4Dw0d9+N;iQE3sBWhRLlDhTrA--P8WVi#=RNkhM4B zth@N=>E4w+2W^) z4Pj@neh=nI2|5?rbHCbDA}pH@1`|RP)-dQjv4A7tqSZ9Cz3_r>7fo`PK5ooPHjYYE z_nmIN6&RWLlJGPD|2F@=0a^WRkvB5rc}CdXoZK(uJ2zTRXCJ$H55xrF7!P!r4xfK5 z8wXQER=`ml=)jEai;m5Ieh>Q0iKB0s+!e#u|FpBXOG zz2hiya{v9w*8W1puo64cTZOBc)KW9()9$tGxhS6B2|EzKxA%x9XZLTT%dm{b)x9o>cXM+_)?mULPAc39GlXv6ipYyx~ zDQ4mn5Z9t_2;Asuq$nRDEmf zZA=SyRz2sRlCQYZSJ-u3Q|G@M6J=SSJ@*U6sxg`>VmZP0%g9?Bv0oz0QG9g_WXx|s zPv_MJYBz9~qqQL#hLSLmN&gOjgVZ3N;zwA4fIrN45cS+RbFR<;D_n$B8!Sy>azP5< zi`F`5i~N^|55NZxOj&H3@WJbmP+^fkVQ~$%Y7A?GIrq%e5mnKnOQBC$fxWYvPaMor zdc=8FuN9x;$83@EU3&>g0%C_KQk=w+cPDk`WCqHw_H!R-Fz5LU-SU+RpzB0bB^ z|5L0Zl7Nl}`~$cunoKXImThe2Eoi~Dwa3Nm@~d&V}yHH4B6@Nt8XnYVW*_lxGJedh>Vk>2KG1wQ^xe=+cPeeKgp ztG{Aj>UWd0@Yy526DG*@pcx;M9Rx-&;fr{1&0s4T`d@~Eu>B~8F$FL+4yJqZy^ygw z*o1ryNECT_M+^4Dqu_h^>Q%r8DojG=^-`9@V=H6({&5S-3yzv{KmAbSRVUf@2o34C zVtV1767Z<=lNQ5lXC7z5-cN)72x#~NqRiMGK_FET0Z)V5aDdqehd^R&Nh(IZMaB6b`qT#`-T4d zHp2w|pBv)UxG%{6vbcCKD4~2q5FH&;q-eyGq<*CqI5NHEK)&@|{EN=~)iFxs7dN2W zqk>ix7J_;=@LtN3#{c|cdakhJ8PFG}Vru?P#`cmTtnV*9KuYLtPt;*ND(|1~J}|S5 zvFWD}v`qgqZv^3TH1Q_u29B1`V2otSsn&;4Z2*Tr0&d{8gl}q6OxnZ~qlkUiU82S)8p3-(XXrO8 zV}&2_ZBwGk^ImG(7bHtlFPWVkqKGBi&7|^pC_^}uyh$712cqsuZ~;TLM>;hGv>L%o zwq9R)ziYW|=lrig)nonHTaf=DXmb>!IWu zD^@R|mPXdQl9^#n2>`%$b@j5NASfxKDC0Gy?$BN$RaLJzI*cdIwV!*+{d_l$jl^kL zKQpi`J@5}Fx{?nxSa=vxV#x=w2rGub5aa(oGdYG34C0vU4?N78t(SivAnP3?CrQ*I zC}mNi9HR}+vudY!GQESy7V-@W?Z@@e+dDgu)U{VGl}xzc*`H*;RgAOczT?Q9Wl!4QoSygOCOin0#Pn zX~OaGv1eB^7_(`D>(n_ff|1$@o*AqIu^$6>)5hcas}_b2uA*lj_VQqCCUr`yoxRnw z9IP_FB=4)^tYJL(9%ucJavh+a z%bNlfW*8_&4kN-^H&fjcV0a@`b{(oYfjyBd4V4usN1|lo8(6b-Aw5GS578s0r=&AB zJ=Iw>On#_KX?0@}PodZBX@8yTLJ1P_@g-|kmF^uFXiPS(fBW>AA(|0*Y&c7-e1~{E z`M;LG2shJfBUp;v6fq9)4D=3%-?)5w{=m%wt_`6cg3@5Pd=k92V$e%12H?N(P0XWx zs=0Ylhinb$4QUmnm}0IGZOY$@%=9jQIhcQ6s;c z#MCwg<$BjS=fZzAJl8Dw3?)5%3w4D^)ZhhLcg!?2Yz97 zkO$Pma)03ZFIytozv9!E`DS+xFl)}Jt4jc?pi4Z<21q(#oQEKI=OPS?p;1$=UIH@# z!5)Dr4?M#`C-?3h-<^(;;yj#Z&+}>r1qZH(nwJ+j2M7mh-nRKB#_QeF=(%sLuQ+_f zdRp3PzU2Mwn3YtD$8b9@+Uyzi>?a@EhwVngUvZe;&!XY&L`<;9G&D2-BPg$`%Aw~m z7!%w@t^-6Li~=^g9$A02c&rtVhK4Z3HK6}3A_(*Jc33Fpc z=I8s@fi&9F9Wq5db~U`@8Ni`d{ebw&h{LpjmcbCek3Ob3l5zz}@S*1SxKqdAd~}o* zG10?-q=NTofi&Nd46K#JW@GTsVs7SGa#v0tJh0Hd$?!QMZYgbNWsFKLiidjmXziu* zP`#X$6(gAa=P0Nd>C^Y)pxqV}6dZGk7t9(>pES$aV-$u51aod8A|l45KKc5NaBrxY z;hn6-xIU2o;#fd5h0zSk+t)v^%*}^rj~>@NP7>=Y_%`eLizjTYgo3+{bhJy%L`WDBAizyuZabJybhtZB<|JrmvD}9kfJi>>PHIX&K6vM^HBKy3c0&N zvo=3jVT~{eMHW&Sif{l1B_pHkNB?_p2|}8=V8)G?BCDbsE%RCsOLlpAnbmSG<^lh| z2x!LWFefH38=^s}d|MX(&f7X0QqQwv(w8rZ-%l2)L?r9&P~h|+Rc;MYeMn=wC-sV7l5m0SYn&Yef-whUMF8Z+f@-{y9(^=ok2 zqbr&~y}lCfR9^Q!I!AAe@9Km+PV|UbEQg1OZQJs4)~y%pNRS;vwcrn?KWA97Pdtw? zYZvSghTq<3YIjL*%8u_-nITPL9CX`L_z9I^I1f6r zs_gpJzvk)7d~P$thuTs%a--+v^BwDyA5Xx9I&N}qM*=g$*~Aaj@MPc((P-gPErWv; zqVdb+u=NlgPRmz?QBOd1-p#`l@8M~70rYa zz%Y1D#d7uJ#PC8=Y{_UP%PQM;-P?D&C_3p7|Gf6yM`lj+#<`#jN%E>L0y7bZOgkIl z4~|UvI11$Ai1i^9xFh2y!HHX-T=XVg7{0ZWAjI7{bvAje{7J<#;!pk~iGJ*Dvsz<7->DJA{){d1i7b2cOpW;mSwf+saw4lL7e#;olFkqPaB7 zy^{w94a|56is>?ii?AQhJ`C0JW;=f^2Igge$PQ&D-~`AVe; zoWq>N7<@2CNFaa=@o424O8_f{!2sHxzb6F(uwF#jV8Ku63cF9qMO0#An#X%jlPSJ= zMoE4YZ!dDlGD^85xEHflNo@ba+M@q$=YAnIXFK=Q)%l;oWlP&e%Y+m)CFTj?jMrK6 zKjdjI@)i<{18m*UMhn{SPKl-2x7OLgC1i<(5-~SOAfD_BbLeAUBA+j>bzgCKQzXMB z#bo|0EHN!83TxN->gn>NZ>!ELZ-w^h-ONlBY`i6;Jv;K$tk%lfIor2m1?t*77DJkq zXwHiK4G-vY!p=rL^Kp`G&JL4h*)ITypWV3SyJJ35IOi3_`t^JtKPL-JRY4-Iyrrys zy|79}0x(DnVjBcdp2D@9WfnU$Ohq>c_Uj#}a9D`4qRla!uJovjNbX9=rY^Kq?EbD( z)^@<4Osr3`dDd(a<0e#9vBjD(qv>kF!=3^!o~HZ>N|OlT56t3{5Lptt``Rgikdb_J#VL!kmsAxeP# zbmjS)CfaWIC>IXrWI3-r9%+NBnV6Gs9`l_u%j=zbm-Lbnl^P!7jSQ=qk8>|&dd}`@ zW(X$yI>9a~l`_$k3Nuqs+f5#bT&VtMsFGRsq-(!8?ZqqxwnBcU(;L=+EScTCYt2f50>-f0f3YH<|^LvPxF{( z;)j1pz9ub(A5+MTPosS2lmMn$qPkk55>I7(@`7SAXGPfFnJICNRogy&dVMu+$$gH> zu8Q}IrM$^B(i<+9PH?wb3`Q~0(+cl}@=#(u*0qYkSRcqUV3#(MalQFsq!5DyejI6fG>Lp@J zJSpSUmDxN9c|KYYDnj%F2~CQR53s32 z3+l-ar4*8?-YP}jp~HSOmrS6K4T~y2X3k!FeuGul)=1wrwSCnwQeT2x1V32J5F+Ls zF)ez@&fh$f#kQ+mm69olCH589WUi{9SOI?vsWf(Pz81&~1VESZ9)EIqygfdgWSI=8 z2OMAITwI9&OyZyw+=u-?t)RENn9iAEF;*BdQEK-<@Jn8_WZzaHMZslo-$Pk{Qx?L% zUaJmCSEt{f-evaBQ1aI!ebh)E9zxE86!>ByS38GqBIt*Nw76oGmLF#&7#|mEWwUSDuOieaaX9MhY`Q&>h+0a zl$=NStnJO8aH^NjS}2t*;FO#>Xyqj<_TnVZh4MTP_PD^jq)YC(EcTlSZ%NUyG_<2P zX7_R!{{m-8^h*Co=j96{=q)4UA2nL%KYt)ts4cFob1^&u> zqoiUc>iGKt*fKIkMpU6Y948k#-%)vmYV?`nj+L{Gt;q*n5)_wT&7|^RHckbaePPdqglqO+i61 zLV6wBG)G|Q8|uA#MdevZ0$8$?W}}(6VecY+hb5Kae!5?*-Tdj1Ro!w^_efZTB zA7EuzeUQv(+)`!OmmWwZ&&T1GE^QVmOv1`74gO%xpHl*dxWMef=_d|r?FRisMZQ)eJIwEB)M+m;?2l97sp_};^HUdLA`Yp5gtRoe z1Lyis7zs~ca4u+N&{hG`&=3Ig|O!F?0(sQy{l(TqNg2KLJ$wwOVw0}unUsDTzsGA?EUO?R))!y4q zLi}93YrP_4kXEGVBx>a{*1{Yul_n)E(ih*`lWqA$AvYzE|GTEj5+RD@_Kpry zFl2c0?hh%fn>GVxCc{1)NaoR!wL7hr+x;N=2S9|nbD!vIPb7Igl(G_%k_cK#!k`hi z^FHwm2b=?$(ql_*{7e;Wrqw(lK6&iXShl`Omrr@D34bdR>6WFvpHC;JJazBq-b!{dsl={JIx1}jt_C05W{3=wuO1-jBG3r-YGBo*qGFqu8e zQcArF&rF-A_;B~MiO-ZFi^i5#)CAn z^E-r?D@wiE>^7YTe5_$|gD5^x*JS2c>22S-Df<^|cL{%3T2$}1>CT05JP<;g%r7x2 zoQvUP<9>-2$ZA^t>W(50b(~nD@~C-MTMk*?NH5hSW>cIdhVhf9{5^d^AI6JH#cof( zDtW2(q|nU7oDBVtMqaUZ<~@z@&B&#`{14;;wywQ-5U0M|a7va~UY{Fup9Th`?dmoa(3QdG8bz8UChlg|J-Kz&7-cj0-o(J%cej+9iP8%v zk*c`S1r{oN43eD6SuBZOg_m>{Gwa{`rBm^^t?1{?@3!-^z}~LiJT|kFQ1*uj2EK09 z2V##NqZk)EY_e_Y+pp<4Zs<&MSJcZOR$qFMP+5=&jO>Ox%NF#xW5PP!qEmmJfBe>& zt1=Z46MM78-3GOq(wf36T8T=?@xQJIqRFOPul9}*R1VT^UH!u@b0_h|6QV&T6nb4- zk0}%XjwVF6lP`d|lzr5;x}KCC^PJ-#D&_YB>WReUk5$VGsj{5wBE*6vC}l#g-US8^ zG_}r{rVR38;8rK``>W&Md{6z%C%IXQ)8aDNVcJPnEP+dRR;F39xd3bFeE=r71!W?+@+Ug zIFI^08{LcEPx`Dci2IfIXsJP&F`s+oMzrZ!Kub{c`{3Z~!bd3ZDAG3?@#{nLSIpj? z=*QX~Zr$PYYwTdCo4N?&F?^)0$gUlpL?d-$E=)Kd4@>(>_z1^QGz+tZ|L?eB{U)^&6!) zVRW1DMA5RqknJb?;3Q+9n1xj}R}XDpO_^;uyMLxTwWl%0P5kTiSuPqq@*$p5irZ-Ow)`o>@h9d(vn_T)ifHk0C7RdC44P{bGVhGeC-q-I!{HNl*e#glyg(OQJ9-tD)tsbWRh zR{v)`I+(%}zqMroR-);`s00$6h*&$FA;xqOgkcppzw@+6hvWCNQ^6v>+K#L}n|o;K zy$`n6Q0DWO3j$gaBXxv-RzFa`PGh(;rgu;=FiJ;?E+yL$@4QY{^pTugb?|)&?gOV! z7_|?Zb+jZ)WRk$)VeZ z$odSfR~#wQYvnQ&I+^qi#HxV1yFrL3PKECBBc{>Q(xS#n;%>dZj#+&8@9xFHaLxQM&*R%#`028>v)t2tz8 z08wB-K?;ndzuv<1Szt8y2`pG{ zl6*v5sz^Pn&36IlGjTfpP=dFUfTqam@&8>rq@|<`{#@=b9m@_nr~$M}yT}IiLEw8- zRMao4hsfOM*N@5W87nIelAjRfck!EfaG*;6zh(BWOWlj^2h9MdY6d?2C|=iIG_sNln@E~qt{ba zpj|&8#FrRnz#ARv5o3D9#qY8h4;jE|22OvwAHr#AXzK1nYH{a{2WgkVbd8ui@_}=knj*7EF17Kvrtztq|kn@QJspysj?IF;@%=6mrmH^FB*=tHzP= zze`^24oBe*vFo*9k4{Q@wDqE|1mHMgmF?&QJ&dz&=r+m^Ywx9PbErW(Gz^ zV|BLrrj1et6{?`uy$#b1A&twNG`mF-_21$-8WDaFnUvHIOCb8_2PQf_m$8o7rus#g4TRj3UWff zfPis8iUCp+{(fBOsF*9AVLn~j(h8}Gcdx`l_ zpz*}(ACuSR-@Fhj0No%!NJmr2=z#zo&e1#EG4BgrDE{aQBk6}}*kZ6(Y}_Ws!SnmR z0EiXfWo4L-=lTW&uo(0_1WJ+2>HgvkMB+oj69)wYsm|dws(k&zWAOhU=qdce^TqhM zxQ*Hi@YnB0{DHa67eBcEE6cBU#!=Bj#UDdTMQ<2d4aWV%5r?^mXfzNF{tw3fJD%(Q zeILhFSBs{KmNcvoC7DGj4J%v8C|SwKmf6xU%gibxBq3z;R3RZNq^xI1$ljapap$^T z@6YeQ-*vlP|L8d$&-*yXaUAD)0v^T`&AS^U$)`tZh}aG8br|pc*kQ49@u3e87;lZa zIT8fc|1!|Ku`ivoKP4%$FTzz4lfYEY#n(m3bNugW`#pVnr(FMZLSp8PK)0C+;;|Io zLfq<(Y>+n3BWdwG`A=ey*^?(vPQ|^w$-3^OwSMAx-Yg!FZyky*zx!JpTOP741b=%B9ek{h3&SJLBtLW$w3M_7nP57+NP5=kPqR}YE5-Zz^I?h>AkLj?~4#+mcshpn`5q12+O><<2m>%@@?J8?Lq+~V0~H#gdJMg|6{9Icf$ zIhud{{P`Xg+%4Negt<-{z2Bj!oZwHnxYaUvpgW@QXy$3kQ&Pvp&Ened|wS0i3gQNDJ2> zW3m#e_n(`Y*n(dyBO~J~_eW*P!#viR8>`pGboevQ5sMp~DFY^%L=!*y#tBNzx9Cte zIg8CtUjH9#KuY~mq%^}uu0oYmT-$#)$C{Ti8CiRa){5r;4p-as{-vY4bleo){I_jfYKgl2O<9@E{r`j1 z4{2umVe%4Yh%RT6`fz^zqCcLT!qI}_baa5K{dX~{kkELDw%=qbZ?^4#oV>hzc){8= zORM{-5^CP_mo8m$<*@+xkK4pb>SJ@Y%v@y!POS`9%y62qnI7vo2eSZbX0)e%qi>1T z$KnLQ1(TC@KpdLmnP1_33agF-?9YXY|8?DPjvrZ>!Vd7+?3n#2xCWk(AS~DD$ukT4Z_2BP)&4+?tVat7-+g;)1EyQ&|DWI zy{w$cgeMHl%*?#GIg)%AdGvjS)=?53x1{9cR_#4_u*CMtrAymYEn$_~t-tjoNW6xod-1?WcK-O{DRji@!wfbm@drE zsj-K10n!10P|$4^Ks;n%SrLGdl8TBARvaB(bw))x$;w?@`Z`h?`%u`YsHljI$!f=N z7m+Ys`gD?IDh_Rg#3R&!bq@)H75d1Y##IdkDIz~EtjB| zE3%ttw#5>XhU&S$i+6f{j^9-6Z<*7KE)@tLgFKVDkA{?YS!L5YL}t=&c(lUv?17nP zJ(x~lRRoeszJ!S)|EyjuZP7+A940qC-si5B#Ex);1Yl1W*RC`SX;uW0UJpS=-UHmc*4vt@4)lpc(i4NeYOwZBE|t%Ztne)v;=Jp zakcfy_~zx?!Ixr+`T;Sq7nq7fGdSL8wt{H$o*%gmykev7W1QA+{TuUgXUIUqgDIbg{yB#%Xv(Wg!Sef5x2W633S*5}xo!6o>; zVbrp_iyxR`H>9I;{rV1J3zh?CJqAO3)aWrp?)>@3z24EUIX3I??~K-z4w8!A&%yEd z>C;UYg#05veUi}N3B=cb)lZF(=UB600~;<{RjY@6f`(ir%CMQibdY4`wd59Yt|i413Mwa7vf znd=DmDRE{vd5*9sx&{U_e+52vH^`C}#Q^nw5&Ob9w#3&8od4JLpcxQS!)j z4j61J1ZEro^eYr+AoAC8nX~5Kx&F~c=0D|vRF-4M>H=?JBv1~?&3{syL`kR$K6}xX z7FzeyX|5g3$Lyb)2}j51LIk!Z99@EFfWM ztHFDIp`R9STL{=J=(zj*g+cYraAvp zC;1(2F{ChF#qA#r`T6VyuVRvst3THK!Z205dKFiLELpE2)CTcMM8&vJ%c~WP@3g2bg zoE@PUhgpkQ@B0X>8LrIq01D>6`P~B^4tJAedV9RtX`T@i+@}8D){-ep>$V?0sT#M1 zRm8mQAEUJ3%uFGwpHQBbcH8^AU%s3NchWC?wxhH2`r++~mvhPQNF#lxrzZxWR5B&l zPu#z#2t+0b_l0<)6`!D4cznKclZm8h?$B8>2p3Zc;ZlG9`SUHucEFuehM9?}{O*d? zuf8iCw8&{j+JB=y{zJEwMO)!o%>7%=RBBz+dJ8T~_vInIJ>IJ`8p|V2KjO3si-`Y- zWp-I`2K+$$2NK+>ePQ)=>fC4P7ona(G9J_iInvoE#B7lsiX2-3o71(ZvH&HIuahbd zK86Ru7ee74cR7`f)BkU7l3wah7_R{L2*uwZN7WOs`B`?ndMJYklY4}*JL!icfl|<< zR0i+?U?|f0h^9rQ@1wt(W*jwb5_+$G2=*cGm#t>FWBc3{EuB2+brWJ+kzhsm8Y zHs;1weQqBDG=E-JHrjsdFCG~_YW0TQ$jDY~+Op-ek$^it$l*E0q! znIduEoTuZI;a1w{Zw{}4XqygPR^h#sbTn(|=%Qdy@NpOvI2Jt zV|U=y4O?=BBO)S}Ig`|;ux<#&D;P{01RgDjcc*jOe+C|nJMI>=uCdx4=|P8pMlDGT zl+Anf1np;2u)C)Db2e0}dIDyMWs@m|i-F|o{kr5D1ViJ=0=+~=BNOxzF>uysTLwqI zacXLc2K~jpMM}EssQp6li>Y=+840TK{@ShBTjlNl;`%u5nH+&W$G;i!r)UEac)bZ( z>{(r176k=`CGKKt-Ntf;GlF?Y8Mig%#sQJYU+0Fap_BZNh@V9<#2Z6)Qlax)Qh)xR^x?2ompoESz(SmsV>a~k^t!dl zFh+uBpt zW|NbAiF>=}XTDw6FY610;9~QA>*IrRi+`tQuWjJRp0PG75ojjw7wHToJYX$(B8^<; zAUF4D8r9V~ba8n%b(U6Evf~gW)$byQ0LbQ6+9F}5KW(O>R}>idgoK_0GoatIr$4?& z*#{XL7u`4~qHNr|H=vmHK`*DIrb3qcx{7O^1`pOiF4lLhpAS*n)iiYyp0UXCDA?Gf zo9V{BZ;r(?5SH$-XS#OQS zLuEa+-9%s)l%$NI|49YO2XKDz6A{w9f=yMpbZJefR&?TPHeWce4vQ((M7s6s#~M{5 zU%7Q_B~AD|e6QI%4!^YW>aP!X1DQhbnFV85(4;@ zI1||l$gyjEw5U^+4i5__=NpNQ5xj>+IxAx@z7%788$oh*zhiw3X$OFjc0T6;U5(rt3J67Dj*6O^NR5U#@ z8+rLnSlKW+#(y)-SsX8I?NTJrB@{oo*pad5(i%UfB{-tvX+bNfWx&}89G)ECvk}QG zf@o)FG5fq1gM))T&K-;57(XYu*qZvlVFC`uUNRVuhT#X!v!9#$Z4=9jCFkyUR6jMS zx3`zCQJ526zQ`zze%m(JYD35X$fbQ9j;F=q>KXCZ6hY5}Ko&gkk)57;V%;pq2(tLt z;#DqNNQEEdG)OC!R_e@3Wi4M>12T0K^0coLt65u>m%YU#gOB%D*Tk#w0${1gG;c4R z^}3z5Tg2)yu;&^iJ|i!#Lhg0~1PAcQSg5HSK`c`Q8WA1aHkq#mvQk)FT$7|N4&UwQ ztqEN^BKomraf{zyET}CI0KLCiUor6liF?n)i%Nd@%>4A3KOtt015{l;B3V+3m1BZA zLo&J^S@c-*%g|6602WRYF9)nk>5z-wS+-)fWdtC&TQfTX&)@YGNK!uZ+067a3!3hh z4dx%Z*%&$Z1D}S1AC1}`zjJK9Pe6)PXP9Yg@YD-W08ySvvqY;WVe7ws1M?G97iv;p z%*B_^a*cHfSL*9+;F>Q68H+wDQ0VuSmBgJ+ZtT0e;k4fh(~Mn@3R-Q&wl0jIa%>Rb z_z<#H1eR{i`o&&MO{giSj5HyRm($me8PFI8LRtgaJym$rl~HR@R7`xj>i}~3KcbK1 zUL4F1JvSJgH?cIoPUaA3W-hq0B^VO&K@D&`ZF7yyHK#P-Ll@Z8darmwQoL zj?1S9Yu_7I-xIO^TlHGu0A+sqCRHp1$S=kqoxi_oJ;VaB2w;d7>&JY-SQiTY!nN}3 zkzc*Eq~akn(G}sfTGAc@p!GM0;2CGeOIW1?s^Gn-S#5v_CT8w-`Cb%@fNq6ivw-`K>awc%Ls;WHtHE1>^sb$1$LT5nda0= zV#INItq=uG)*9$zAU%7j?J`Fi_4svf`8Wl;y=vp0uh_1Bhf3f$5R;FGZW#+Q2@DP{ z#oLJB{eragieKQZbRYyYZpjS-+Sib7Qih3b+SCz39Q^uq75U2#Bohb;FgH0`!(|`o zbc9rg>R?)*$b`UtQaFqYPlCCnUx6|fBiJQi%GBHh6%8HPpx{4t^ZVi3!@88IKf4r_ zly;o3=Qqn6V?Yp%;8eTO6bM3Yl9E;)3MxALI4D3S@#(;Xi*X6Cyp-atY0XHXolGWh z8hwL<-RgTfdXn<>$ulEiBhQ>+RtW%1flV|9iBV^$G`{BFIAml-&FV-d!s!0o5jkof zZ{DbncIP07#BhX9Rz2T6VZHePzT{u;#{TmASdB43_a2Ac$L*d{oMv9s1aeNdiXJ=W zzwh+JsyqkVZ0nxSAbyVo$xw&;TKy_1NC(P>BsaFf%bA-g#mURMUGi>I8Iy(s+_q;37V(@6m$|b2f1R8#>kSy=N@NamoT2s2H_} z9CBoEDNZxv7F1_6h`0vumrS?h@b7ds36fzW_W^+8BdO@&K2b@NYyfb8j5_OWAmM9; z1^LiE6k20*z1VB4&s#69uh+&QIo~m-*;JWx8$r8pTLVkVvYLUH2A4TDi7wikzE7%koWg!sDg z(oUe1BhufuF)$2+));hIm~iQ|dhu2@-ot{PQ8HR_nP177Tj38}U5`LsK=g@WBGyv9 zjexC_(er}zg|`x6LdZeNAB`K{uSN0>3I;LrIwp~=$n$h`_9I?{b3TS{(AY|;c+~_$ zFvpz!ZUj}wsZ7x=@#z^J`H^-LA%75(<4ZjeV}d6c#xeMod)pE&{Gob+epUD>;B@bdfAS#} z#)A_XwsNAz4m=4=y+^q|_KT^$L+lC*FRwR7qul~R&5*-sT%85^PXtW(r%#_IG5oWP zfb~c$u|BaiB?v4$w30|TpEI6LSkBhlywrvRM>e|2%ez8C@sahM-?>vN6;UnM-Xt~--sbb zSv{TSrwmk+w8}vpm4f>vJ$fc#1m@^f4b&0A!6(o+W|&Qgb^^P=6wQjt>SsLbPE}d@e+I`~URNH0hejfK~ zm?)9!#-Z&1XEf+Qsj8~7Xl}uq{(QK*f>I%N5@9larfBS;iQb>5$=glAF#(HQ5j#iz==KtQ;V^2s5sOh64g-n%Ks4kV4KA zyFAZQ=Q?jT4u6H}Q>SoEi(UJA?Z{zcL|y23KyEX2W+Rdvra%aZ^R6}?Gi~~j)VGyZ zGH9mgi%#;^T^AuFAW3*(;pChW$5z6}!6ff-8dbnm^ygHGgwv$xHsSI@Vj4FvN5z1} zGw04B%iRgXgODs1{*BV<015}9pIK10MP9}vdLfFEr5dIwA6ev$kUUKLTs7b<<0^QZ z67`l*a{Ysy$Be2_v!=6U$eKqoJBhuze2Tt4yQQLkNg7sX%@`86B>!JJDJz|Ei&oSeXt1F;EbXlFaIT1NiIcwSg z>;r)2utNn9{;&EY);|+zU~0lqPDh!@b(k>K4hg#fS6vGeS(yG4N^9Yl=c>m%5qiLn z)Fx@?|ETP+H@R|$=4d+oQ)FhoX3NWkN=a z4xA;-gX-H$AAtauJl=fZwNh+_aOYe~Q{bXxDUb!!E&Tgo^R_ctG@1Y4qpyP?>ut>N zhWBm@m5b+SffYUvV5hq|M-abaBg8T)-_D042D!$+lXx(sArBt>qdnXQP7r5e#(G0f zZCn9t9AN2C5o?<7+hS#;JCOACvxmb{z+XZ#u*D=C>s99{5Vl~~*$T}Q1O)?;eNFbO#n<6BgbH35g~eDg2msr>Mpls2jPlu!8fwch&QMC16SZ8X zAr=6dCHT1slDC}y)J=a300HunkX7szW^e)|>O!o@IMbnC?Z>}9{I%G(P+`kUuvM@D zq*2h&`T~l)#fa2^cnAvTLHLu#RM0UEw?1MHcy7ErXg})%jvNw9B8EoqFH!98vw=uO z<$Px+2)?8PxDN`}B$vL4Y>*~Fm$|w5H*xWe8I798%uU7umV>`eIZt=92XJ5hgd<=F zd+lhQpJ=U3E5{xN%>OIQ45@6MI~x<0@M=PZ;mc*EHRcMypNS8GS`09~)~z-9gMu(E zG&pmGq`wHl9_yCiU2UUDBJhr+126&`uie5#GJsHkXrHzv_DStbbjHFQHmuqUw{WK> zt(NGc_BalHykjK>K^l@44>OSoaypT#UD9EX%UxH%D51=acz)duz6YS~h>P9?YuEKO zaTh*nHx2G;hxo$OMnu&UA;SlG$`Qx`7Qok&bdswr&%FICvKeV*Wo6f%qyDd6i4`Nc z?rO~7g-t;eji}}`9E8l-vu8g{A4J%96=#XxtR)l$<`5%lx2=gK_akarL=9&IKdH_3 z3i5z;nkT#=o?j%m=+f@pTMo|0sYaO;2Iuhd@>bTbMF|7Kx1xb5fUHUGB>8GS8z4s-pyK>C_pDCvj15$$0OkY{_qB@ZGJ zNWMLykA{Cihp*k!-dq$Za3b}Kt+O*_Ak(-i{Nu_-4RljC(sGab*vqmD@&hZ>BEaHgIZwgo~dF%nYe|6_@ zlvGxpi>)CPOzrE<;BQBBuEJVlp`lB#t#GqixpF0CCQ0q$@u7c^gm*-Fi*Y9NVnV4b zoM?~eMTWY00>u4mppAaKYZZ3=LWkeEVFSgq(Har0aZ_gX#EV_Wj2=iux53Lh_f@^S zb{h1Cv~dVPd-Go(mgUd5Cz>J9fx#uaS>()9f4wp0)hhtgQ^>3)yP~|GBOtE#^2GY6B}wvAHNKd&1z_(@0&l+wzK`A3uI{o}=tJ zVIPh@!*EhS#e&D9e_5q;K@ene%5GTx!vf3iBbm1Y(URAE_p0AqZ1_W@y-V}`i;#niAa?{+D2Td zwIQ|H@D%tFhmnASZ&`+pEgK(uc(9=O00CC8fJrH?g-3nQSenUgs@hQ;9{w9XNm284 zYT6Ms-Ca<`fI`-9$~1>^hDsShKSAgfNb$wTdn-Z9t#Fk%m86w7+c1tYy|J_eiHLB7 z8_3|wGL|oTkuzr=Xm1`LegzQhYlxsXAd;UGW{56prrgKSs^{cfpVA#h?C$)3m=*QK}ldMMWBKv`pK%WVf>8uDl zP!VK7;qn~NXZV;&X75KB5((M~XA%%1DC-G_eu&YirQ!ynM}KZ5(04Z&+I?_&?mD6x ze3FtH9+6Jx#sl32mr`=vxqU~DJj1NKt>6ogQ$vK2kYs0x&y|m_bA;*y9(e<^*jMD* zQk@lH3h{NLy+3`C3x<%)(I$tWB1p13VnR%_O~t*5A_1T~xXGl@qbWC#os&ctKq?Mo z5Y8!P0wLNFoBF*4@;%6w*FJmpjD?MD82HL{RIK&206T*M0ToE+5>$@@V}TYhsaX`+ zit)$Jo{w-gK>{YhNiC!q068;VRd?}w#embHehS&a>8R?%s18D9T#|{1P4DBgQLX3| zp^JRLxG^30a8X;4yYQ`FG$-swPsZ1MgIo?;-r^Ep(llk3x-iEKv= z#k5%jtMjP7CmKWn0ZW^1cZLDrJPLP-^z$^L>%!IulOkKi-E#mMdg+Dy_RZRe@z0bK{1#_Z;P9bb)(65?x-njuBm3vCka z-4N_WxMG};gKi=X6(XXMtbTXji$Q#@MC&}*a5L#yaIh5Q{Dh)M4nUA8#DBOZz*C-@?QAFC z1+(s?v=jGLS@k?8y!7CE6-NiqV&|2EkSw?+!te27X9hXH~~UYp|@957us( zH4@my8U{(_cO7rNpqTwUg;2Q;+39L$ok5aA#MTzH}vhLrsq zfOPD9d(Q;<1I77b_3#i7x>q>!WQ|dDK~TYUt5(pUR|0WchxJ2yh@Gq51twQ*apSv-P_1A%|ry~XuGG?2Gpqku^iAsv=NbpfsjTH7(fo-=owaay|W4ncJ(RddZV2973;u%usa2z-|$iqOnsfOLarUO7n z`#yhu9^MVx&IH*LGgv%;?0^uYacC9=BEZHV5V@AsZ5A+kLMNj!=%|60*~<1v=9&;e zGdaKjD4cBvq*x3y^YUdSJj(?gok(7V3;@80GEhLQ189n<3$b)j)rN#fz|ZVJ$^u_T znduk(HF}=X=2M2>{|5p2KFCv(G)dt00{=tfVqKAy!R;c>3ICFd#D&&sciYY%8D`7? zGB#Rm*ZYQz5k@{o+yYgy7O)#^3g`F{oM%Y#atn1bI~{t*extw6A7nbjN;i6;3L;3`mR z075?E`1bXL5>ywu%Vqvjt~K4dNUVt&qqI+vbfoT$3JS@=Pn%Ju*vQ!LG`jNV&!3Oj z_<^IViFG-REx6H`KB|>*9%+~T>29qDm%h3x9zFe!GMpvJ4prR|nsdlsil-9mq2_ zvmsx;kpR;;gp4x>s$Inb-Sz^D>Tk{&!5x}8dvrOGHbzxnY7@InvB`TrQ^pCp7Fw{1 z9sAJ12ZPD;NtrPmBhYdJuI^GH+z8eh(Gz%msPnfFvfeg z04&F8E(G-2OXS;#dQno~vFAj#^0}j`@I~R_;nII?F{}cI%HeN)(nXXrRT)X!>gvhK zZOGo62eJ8`n*34##v_Ofxk-)l#^ZxDxPPH$(=nNL$iSl@Xrd^{w6!IGwuH|)3|Ta7 z&8gMkr9T0zanJ_6q#9o*$kcBPOyQFuKSM|pHWRmM6BCnW?{)Ct2)S9sCOW?LjImPK z&;rG2eu@!IsD>un7xFkuPF@gt(o>BLorJ{DUTFo*cjO$Ou5_6%IN*I;0wF3_?L=jr z0Jk{8~;l{fb=^Qjq6N#@=Z?nok?a#VZr;|&V0?Q-; zcS9jkB=+4j1C~m@%~CrNa>B@{DBf=21XV654m+^iu#GEV8~vuIWEMH_Gr~Xxg`8(6 zt(8s+3kaka(z&Pl$^wDQ=VY{9V!k;yU>gwMl#P~}2BNm$v)bgGPH33%8vq+Bal6#~ zZuz{dCBj1siTFA#!6zPo&F-*=5H*29vLSyk{keVqf`|b0_(=4q64EvE}+){)9x7gtMt%MgA^%u0w$7;9DDCcss9H^qpQA_*(BA#U~2 zn?Y~lNZrdCh3^{1DJ4Lkq3j`Z(8M>B`MA-JZ33e5`%dx;$&)~h!uH?$9&xuKC)0v(g502 zvsPgE@53{Fbba5-FUd(0xh@E3_EonpElr_{Eo(_klg1VXch)Mi#Ui;#nJgb&B`qy6 zNOO|v9qQVAmr5rS-{76_@gAfv7c7Bv^-^((i{^!tZXgn*)(aG}PKIcry2pQ?j=lj_ z`MSPHV`1t>r3DW9OCjedvpleAxn2Yn#8+%9DuciHn8>AF10& zGITuiUKW;rkgUl+^}-#l2HSRU=VCdh`9rwJny7Od->7;bMVNKCa?~`ZPZa^>Lv<&S zW+lK)@E3A+c0x!b!}l^F{~$6qQVdzgJh~NGOABn5W*Q?o`wuedp`dwV{SZon>RtDw z$Al04kwE={?wJ7c&!unUf96A`afhXk;-bg;Q%3vJ)j2K)aC-S^T6;k;?xm z_ZI);$GNnXOE36e!RgY=|5uQ+^kyl4k*Z~j|MJ*<1fK4Q!?-b$5A@)mThy5p>eU$X zBe&uM=0Z3xD-q4`)AZ3x2I;}3jt@XTg3xdcGt!sTxN2WfIRVN608J^h19sqVQQW)1 z>Qdy^62GOsU3~oj>8>L(^)ZHa+FWxQr03;#SC`r3oT!R;tjKa8^iFQ>anN=XmZoMu zQ&`~>mpwkO{11W|HjFL1lhdC~=PJ=^y8y0(gPZ#mKn`S8t)VlO0nmLm_WeQdYJkL< zu}Moq4E7K}%M_d%O1YWF{7@mq z)^ytvQa!dN%yPHp;}YFZx_prMqIM_eRWaOSWfNvQP~xnSM6ja$CBbb_LK&f4Y8Nzo z*6ZqCG=c#L&*-3yfr+31V^J+sFU0u+HO@%5a{eT7C9a1$0RuSi6}Ru>oQKYXnE*AA_t~LK7U-+DnmGZW=n*&( z0su(HaFDyAe`u(qTx8l?r0cc9YuQn8dMJKLKJXzftafGoFCwfku25sS98W8q|tR!q%ecp&r0$*bNn<L@3wc7 zlsTOk=Or1p)k`5S8e{m0ZZ9V7)f) z-t8d~fp~Ib>Q7){MI9F452R&09&^PilByY^fH71~{VAiR)uLiz2@ztEPDFV4hY1m= zhdg!(dI3WVg~I7VN?fSVW;uh~=o@weX7|>#S}!Cripbk=UAwEJR_D)+@51H~yab6G z&7>&`XHfE7Z5EPnC4-2pwm?D-(a{BCnaeqx7zu8lInYM&TE1_ra`KOn&}NfC)PP4^oG`G*Zq&M7%M77>r9b-0_l8?g`VNL z<8CVI5ZnTKlzai;9f;*Wn7#L(e2J{@slWfq59o7tdI@?2xNQzu97#Qap%XL{IV24s zi;%-vqP^~v^A0G!cMw{(eW)tm#Rbye?wbRMpTl79kcpQ=Ychpj7N?yrhGvq_w<;Vj zBZ)X3OrFyp{CYJ?H8CBUEDiAD2WD#~hC9lp@MKPXPN2nJOsFB)gla~KFy{cFAlgM4 z^*QPh09|aLA`ux7lm^Tk>7%k?1JDeT5JI$!l)|}~h`k7Yi`#*BB7?>8ng#Ad#>5~# z?a%Jvfz>`x{1Qz)A3lV7cln*-_EscIK;t04+UdC}q#SiN2rZ|1INV~y;_8e>?ehlA z5>(ymdsxNAnQ@m54Z9>YA3&!*eO4o=T%QI(Z^xya4J%i#o{7;9yM5O<`4SC&Y}c4Q z)xa!G-H9>FK5q|6R%{=Xb$nW^GjdcKb~7EZ&x0nIIwG=r15kg*852xjm08VsN?`^c z^)IZ2LQNF#fu~>Wr`p|E(LN-9>C$8PUDoZD^(~S$MZnY08^FEjo_YcsXvqmX^Wd{l z5g36(N#_byH4+Fx>NwIz1uO$?GVB6s7bmj9E?y`OhKZNnKMy;MrF{rWR{QodkbdT) zT?&}VabGK`6NurN*qUcpUYs7N--sjyM;ohpInVx>IA%@>gm{9cumo|YY4LNVf(Wp* zJ;1IcBjhKYE+f_UBETOwE_4r-%s7F$as=aque3y4D2M@3YBhMmh8%5d55%SP+Np?2 z@$Vzwqpd9vSr+_hPp54rH;HOc?MT26!LqHewfc~IgIXP4yiJaveD$yj) zU#bX%>KFK67<|e|F@P%)8K5jCzIGMoDL^qGshrCSMZ@I!DKn;sRBYrNh;*swyDK@l zgxy8*XWIyZSO!G(6bAo17jqIR3*<#p2sBG811eMs(|it@q?6Jdp6v^0`zfq4QJ{#= z4z559Nve%I$^waD?=bLHGHKCRVt(I@EEU6?+=&t5YBhic1A)Ul|6gDe#SaRS+gd*)j(&5*L37 zWIyQJ<#h?FJ<{B$gxB7s0?sz09DuabOEgy@1wyDX2g(6^+G#QnIr2y%Vo=lttRZ|2 z`-RvMxG!g?{l7xzcFASvp~5`5vu%XwCWz8bk8sjM>NYBzP>cEt*iS2*73pdVQoF@@ z#sp}@B&1!0K?bea3TZ=ehea~fF;E|oDlHU*4vrO(QvxzK2+-Z_q=KTleQLg}kR2Eh zvLsInA_~;RWdH6Y^`fwj?KnGUq7@ZjNRa|cHO+t3t-&`xk%+ho8%-cI@FN#v>NOt$ zXYLPC%>X!*tCeJCkJutoJKq@~4rwKadrv?{$Z<#`0&Emvc%iW1nHgih_(I0Qf)At! z&PgcXGI|Dv&!8LefC1~&cs>J=dw`Z(cosr@BC0tIxp{uNda{nI*%`v^7GYmNZO8=i z71g&uLAnFYNYLvtltfC<>P}hTLn#0JsKEx{%dAlG!OIT;-U`Us$y0x?&PWfKCZduR zRRE$PZ-C)A5u0~l|9&@)+ARNfOFBPqWB4BX$gAid40Y!{bj6<0a zE&O~q)We-D1|X6=@w|i<*zg@{R|wdLniqAU=(P=KFX)7Wz!o-{ld(0(CG!V?&Xf(Y zCyGiQ@OAUb113`@HBQ7W+aeDo3f+Eo6pG!bx3Dm1R5HvO_WM!}1qn9w{KH5E*$s~h zGN{{FUBS{KH^)l9QB9D7216lE8Bk&VjT3=T#!u60;hLgm1&0qllQOONNgL)lsSl)? zFmQCJ3OYyK-Tn%|L6G+YZsiH=@oC6`;U(`p@uqZmH!A87{llth}~qY3)OCIAuf}f>I*%HUZd($~A&hk>;qV zT}8%w2fVs`nvn{y&oZba=bWE!t$&EBanz=nffsPfmx@>Bg#SNx>Cz8Y@p=+u;3$w- z&vC~zI+*KZAgJIVRkExh6;>F8GVLRWWS~7d|DZz5wrzg%aN165*-;$`u*hkq8EP3+ zP?1dsSBWNttEla=Y&ClM<6z7lfTm2$>2Cw|6^T=_Znw}T@Rx8D(Oa}UNDK+V9_bbc z?@g4h5PcK$89g3hz+qp6h(rP=oSGWJ(TT6){RpHG>-@r}w;{ZQ zYk3DgW=f{wFld#(-xH*9+HQkU1xh!I0fVH@6?^)p@}+li`!x_5l)0I&<2;G?g~5qZ zlU3#APXd;7VhV>70rursB<-%MebBV_GRTc59v&bEQXz#UI=Y|6pRXP!0&yH~m)W7! z&^rT#uf_(EEE39OZ=6@0%x}vW>OG0$gbHH=lgJqq$|T8^b$E=z{v!qeP%{uX>ht$X zmIe=+s|1xKK$wKaGFEcso6Ma!cGl@4L^u%U=6r^yVE`uvJ zIwL%1RkxM1WKvNUqCsF$L$XR%InB_Op+$=v;rvrkB135hC^_?&F08pl@=083;}yg( z70gSDhzln}$#iB^rvvma`SN8@Hpu<)h^ByFKPcCag1u@2;!QRcG1r!@TYo!v zWTxxp&R=%>^YTwe2eO(SVva)KZ;E5n0U0y~j->)%OswXDN#oo%7ngOwQXX`1tJQc`ja zaWI+XgLfE02$?+a+k4+m_ys;x~d4!(i>@=3a+;fl~w|4fMwc#sW}` zC_4WwS;T(k_ABHiN|)XKHf_}m*C22X$OV5pIf0#?us}g=R!p=THjH~}Kt)G+prjxX zb>7r8k@`C4o4nPj9*ZoH?YwE7c`Yz2IyyRzYd@i72&Cj#1d6(8(LpI8Z6gOp*0*jI z(rx?z5Uw<#p+&ed=MzuQv9W{-;{gz&+M^@EGHmc_iIbOl>f#vz!|Oq6K*}|6O2P9x z&c!39gm)wDr$FI2ds}B|3}0S-2lj){@{bI1mBQ3Yu`&c*Q-HCEN4+*%`PpjiC#oC?Cf+SBIZk zp)H%TAA(TD`ZjJdXc_!tG+295kAPejKs}+ycO17n44?_J(kHl1TW2PLS%8%>tO{R7 zq?$LW?l;Why4wdJTF`?tCCNK?iumVKy&gs3wWVYPu1_J2jFsiRZ{K7P#^4jF-51Gsn(r%E}W_@7_H{ zsnJcv4`ULYvOp7G!Xkjef^ZvM19;b3=_?FCVS*J0mH;P&AqS|?sTCeG;!&w~%~FgT z^=uRJ4Q-r*1)wUwMUxK1?8w!TXafAh(L?VZ`Ym#j2lSa z@{m{S4nq65`fW58V~Y3cwH!RjLg8`2?iZr80vHtH!8^fBH$#xKdFAEVdiU`g#qVINg<$la53jl!2KlY%KL9B7u;g%m`>Pl z{a_}di69#%6#6>FJbP<+!AI%|pN-OpO|eq*bd3FXQ$`W)f0R+9fXkjvwkW} z8NsE8V&eRq}^k#FdS0eI}02F<{U91JK|f|Rs^rey#)E*8qyvJ&Chdh zP=}_fzQ^gnYbl?XFI`$zGNG4iDG)EGv}f)yHOoNv!@L8k>qIz4t_fPnP&ZDbC^o0D z=l!IIj`Rp}9vZQQSxYMK0RTHInXN&45L6_2pd?#FB$x<4g9obwVA%aQbV#tr-x##Y4M>4Grq9?GFhamS~P;1+`5iJr&o%AIO z^4zz|E`*4?h!LU^NgD}?0aCn#)>>*kyLmHy3Pro5Tt9W-2Pfwq(=>Hqo8Ik^LQ&cj z*d(@YTg}@&5J{(pG8OPr@Sd5!oNe_#aR8t6@bIvttz?v)SS4iB`#fpQI9>qVdS+;& zyMpq?gV=Z&WRmPc%jVSi7)4jVT%v*y2*Hn$+z5ZjALkKyM2KsWGmWSJG$s;1*k|p& z^rPqU*1A9`PVqnGZggV@f+X*@pr9b+Kl>$LdY9|BgM}bPx~V^By>`JH%C0%0FN+!SdHm5=K{s=)Sj5XChQlnuaPprtbzFG0#7 zz&K$Z70p%bT9H^M1c6d8_ppv2b%_WHY3nb8$gRj|#CKu6)CCKnDBDv@|dhjDUx12^V>N`ny?2z}PFi#{l5C2q#W zfD%qI-IT0p#wEB1*$>Vrj?8XX3H15fFD{;6nAxA%zJc0&CEl|7JX$0&yu)Ad2ndrqMZ#-&3x#Vp9~GF(eTp=6)2iP3F0xt_RwUz7z#!wEJ#v7R9l46P z%{_mEd`crLWx$e?b5fK46Hu|qhMca@YZ;ABeJDuot8Xz%YnaIT*zMHTyK-gW0we>} zc};$iF~x=)G#`7OD#pD+7(qd2*M4yq zh!ATNHFp7~17A!i2`C)b?Kz%V7#Gy`R{=2gZ7!8JTf2lx!y*sL8q_y$YBLf+EeEvP zsF!eemT`6#!$JG;;|HCon!Nn4D+;n2m#~RwxRGlTfNbyl`7gpRFDp0{7ZpueTiK6x zeFT_7243$tX0#PRFT03{`kzqk2xMC<>~f)S3oj=wm6eoy1U>4Ot^MSiBaiuwB~FSr zx0UtXbGj#JS6Nrdckp0$wzh0s*K$RAi8l4@(76g!*W2XGsmkW~^c5~>$b@WTeN{34 zRm#2>zys>sDbr{6c%KvVRlF59Gf}krCQ<;jL92mZ<&vvl*t!+O_b+L=CpWp+*;i{m z)iM&&-{?TE6ZcYPpN1{9$Fr?uq(HFD_A1yq_}5X z+DE1u%`@i((^$tpKg!Hk*~ZUFH{&uBI@QROS#dNqzVeZ@-AusFu3=Tam`=-C!*uh$ z?-d@=t$j{iH7RR7d%A_w*nEGoHh9e@6>N7wd3En?S zbJ07F+umHaL)`0HS8UI_AJKn$ig1G&W#+>!ubEb4rc=ANIpAU5jt0?ZQtEX(I?sp} z+a~GtCAnCBXl-i~pNoA?X+5HTZq+l9_!IZVzwV=21YJ8Zx|aEP)AAy!@AmP(#v8Ua z94tT3rPG-61#1h!5!r@csV!;PlvL=tX7|4o;A82k&@25+GDq{K-!`n981di=$!E*| za$t+#RdMaJr~ML@>C#LyL!7Pe$uqm_P9)`A{*$uc@NAbf)nTFFsm*3X%`1~SpVT?=$-sBzmRGDf$D<`c~)lTEAj?rcuy}ldOv!!a5A|3$wcB2nVuS&`+8iyFXkTj ztInidczmpdd*WCpoRFYPC&QT@cdDZxe-pxutOR z%?c5AvCo5QHJ4{glyZj)mVF(MxBb;TRmw6gKwFgU#96@36Z@gjPy})Dd=8yD(^hf?z$y? z*s%6!lY%&Xn4ou9{=(;RB_4^LREeZlQiUy(%gs(u?w%8qE}{=P>HCUBf_C_^>Xi5U zoHu+M-(AlMzUoQ8CRX2f=Yx1ZA-R}$3z;i3is^3O(9Lpbj~Yt0Th0Ib$3W^k(NZIK zwRKQgQ0fK}O{t})zTb7y4Q{r}{aASHajfDkSyQL^^f@zfx*WONi?>LeYxE-Unguf&?o96PxB4Zf^a?Uf53;`md( zTju)dW5Koq4~8qZS>>4t^~K+K_iDDZQ=GNk%Aw1CLHd1C%&FTtM+aO|L%py$pMz68 z4xP%ds64dxOs3_J>6(S1>U(oS%_%y6i!Mzyiocpjl)RVC=4t;qqA6h22HvK|Q2Vpn zq-cMpIMSO2u6P_h*SzuYYqg_qw)<_>f2-G>A%2v1iX*Y@&uIZlSh0g3n}m~3zW8iF zhYdRq)kIa%c_#FwFOyL7rp9^Ge&$+teU){YskW)J#ob>uoF?~+6^Goo%uDMpb$2Yu z?Do}K7TSajS4Fvy3vR{mb+M190K8S4i-X%>9P^tRee{IUUCMfOCV=8~4 zPWg_q0!z1fyMP@oC#1^c54q-0+9R%69(*U_6k-#nv$yBV+R-Cp@*G}OdV_}>R5pHk zk$o?rA!NIirBh_Cu<^{45*zosnz8(geD9o8JOb`>&-ZibCyMnlUc?i%CO776(H^C+ zITd|!vI*6*ey_em@!QtiZ)#r>2blFHcW-gs@TKTp5E zo$obcdAo<5#l$;{7QwITVP1X5ne}BF@2=Znbrk>bx_rl5Z2@;}=Z1=F4i5c$zQ2e` zUZpHm^*OzLs&&cte-tz~P2^ z{_ul83V4;A^}Ork(n9vMRW~}wXZWldNgkU$(s7faKdp4{3G>VU^`uH{N-PNFcer8o zdH>0Bv#h|6<)gn7j(5fRqL-8UQf=v@#CPieb+B4Xlg+BVn%=e#x$?}=f?(q zgzJ}gHQl6PYRUEAEY6^Gm~DPlf93##QJ12;y!~(ZUdpUF z<$I2OsJfl$NgZQd%8j!jrWqFP^%qCFIv1{<7+5{E`rtY30;QvCoN5~F5+-Q1Lh>0y zZf@osS1)v5FdggtnxB;Ga9%ulq5xvz>b~ zYUcAOxzkZW`g#3-C;|KI!tTrWSiRm#8!l`jk{15==)~|xb;~$SW%jY`kdKE>SI2NX zpPb{QQneaslvl0z>*CHWX!g_3)O%N)f{M4|ZIKuH2|b%fUd~KS@*WkiUFW2|;Ge?c_)88+4Y*??L76!w=o8LnH`^9 z#HccN$C_eq%-8C&9dCN^{OF#Vx0bxqwswJ(D4on3RZj1{)=8{4{5$bGar)RT+0gQS zIj5|Nbh~?W^-t$9L@j;Kmk;_8LOq>_&6@d*z51K%2R7N)uVw43eC8PQ=Br7P^Vi3B zJs*FIoAr*^xazQe;uDD-Y^*%bzSb`(Rn%f^Eci($7p1X4X1q~YqqSXfKFIg0NT~Tp0b$S%}&5(DXBNl#M~LtV!gIbNGQ8aT5FH6*Hn&-P+DE4$rlcv_opbj zF<$EeE5sVv`oT z?fUHH+r71PL}I+ELuz_ zyk7cPzPtIxqRff6nw#U&+4#rOZ2xxKF~7W1?49TDu$dnd#T#>-Y!67NnDPZWe+{7u zn``N8>+D@mJ!~|+JYa5+Gk&Acklo%C3)#!rv-EF6@}6yt-Q}Hl#DrNXGCPiGJy*Q^ z!Ql&9g1a&r*IK@axKeYL|26fl-~ylgf&PCQH|n;J+f{ep=gr^qEiCK;twO|buR-f$ z^RHUdOzqn9VVVMrFRK1gZ;AYUFRQ?RVCd_+q1pGxR-`ID;d#6BLNoL4=G*a8t#!_| z-{US&(*x?*M0t`Jb5-RMG|RrP;4#gOJC=5@mLpa0aEB}Jl-|7ZS?3(D>l*}1)z^2& zD9;b*OjX%y1tmX@rsjX16jw~nI^Ok{KjM*;^%(E#jR~D~JKQf#&vCtNFevJH7+Nkq zSeilid|_&Tdf2VDafwU1$7x-CwDQ!IxX!!sIUG>Zp7han4ley-R_}9-*IxS#9oukW z8vTWOHNl}uFQ>Qj9nxC8mv=>dD@|l>F=%3-3uyfRxH`um$%1Wd&$Op)+wN)GwykN~ zwr$(CZB*N~Z5y}GJzv}(U;V18ol&tPcIC=9bG_?XoHI;X0c#2UH3hu?p7zuueGXlm zQ8{cwNN4g>(jD%*co{lX9T2e=N!9m3JCz zl04&*?ckA+e8X>TC39Lv92&~M;tKzZ?++c{t4IB3=RS)rFjIk|_Vwit9x_d^xhq0R z2ijY$yJG(j9tQKgZ23Zm1JIP4t`SDao95K8*_!pMZsVojBeJsuWic{cfvVNZ;c@IF zQ9_Q{aHK{PW6z}qG6{_K2oPI^zEz4O?b8%R_vW4UlD zuq>~QasR9`^s-s;Bt3~_c3gVX8<8mq7r*MY>JE7lk8a$kx{NJQlsQw{`>j! z4(<7sdr4uB9&!GYeGAtP&&oXT!{%{lyv~GRT>&h9rIHecqvfFZ#nq7#WjG zLV@pmJFZIsZ5kU|rz3h*p@kHsApXQF`Pz5c|L3nM6W#W9Fy)zLvnOMiSw^&+!;S@) zhhP#FW|u8-4SdWp)D*VDv!UF{ZDyRnC=tzgz_X|N2&^pq?k((~d6nh&5e1J%eUuwg zGyQdl$IC;Ayki6mN!fT;X?Kp!#ef53m1~!tT@+U>YuN*BQyFn{~%|m-44B zYeRra$Hq16`n+^JY5|6h>Q3!@mIgYfz+i^V!SyH@z5LNNR}`Anoexyz?g~Y%egNLe$|WXwEc=Xe>k2k zkrvEXl0YujrcV7Ef|6`|)p4(1@k_bb-3Bxqk3gk_n>)Ys#0d_2;J_?Ur-Z=c@Ekg; z5xinRu!<9};5^Re<;rIGH6ixg3dCn6uvlKh{9t=D6NNYS$Ud7Fv~V7&|CFC#9WJML zH;TWcCGPG~Oh?lWB7-X38E#W#(TGv~pTmvN?0%xT{%=NluC+#5jU>qEl7`81b`wm6 zTq|CCpb)ZT{H`eA1d-ddy8t!q*+!2WmU7x0fIvmpN}MKPtjyoX^X-$>J)Zdf4EuE- z_rI@6x~mE4;?*602soJ8E4Ge$5}f7!IVd)K4>)o^aW{N7>vMWg-c?nZ!dHnDjTk)F zAfq5eu4y>tqI}e8EXkeZ_oTGHmn$rJz_@(C0JJ$u%eFf zL=O00&H^)!ZyF^{?hIj2l|tjzR;B-YDn~EWq;$A^?;ebU3!NflRbprEnxVHkqi=GT1(MSM74^yswMqzD&=g;- zOq9Q8bIluU%lL0&C|J6GmXnGV%eKd3&#n94NA>w0KrKt*CD(zkjhvnk9zxRPf#>pD z+A*15HQzO@&RzTynj6UhOXPbMLHWw<;_EL^7Fsg2bq_1bE$a&+W>;30`j;Y7d zdwU5_8?cUvQ(n^qKzr9JYib(HV=h;exqQnN@B7sp1IZ`7uYb$RwNaWxgPSM7=o2=+ zs$60*^`GOCHM$2^U*bkX_Hhm{HgiXrQE1_Q0`HBqE^E^HVzELjBq-zpJ>*}%)xF*b zG=mB;ZvH$^X}>Ih5*Y`oPU{#BTRiC2x4<*sOf1cDJGIBU8VB zR~bp63)U`uz*2@MXjt*7m{yzD`F66sYvn5oHlgiOHo0T7J8C8VxZSqQ*N zQkERw^&Y+ZDiy-CUN$;HWu^aI^Le8M^yz8%!+ICD7 zXi2U$hbdJgK{wLF?D1iUK6QQvrBgzSat%!>t|lc}k5)hv5Hlqg$-;Eur=;S+7)oHQ zQp@T+0R5=j4xB6yLc{j|Nxj+aq3K|A`)xaC_ZKf9AZ`0%B4}`xd>v)Lb?bL+-uBE& zt`3hZsFvtF{%w?CUGOcH^*L5WA_JQr=8&EuZ9KvTjes{15D(zFdkj+*MrEuLOU@q@=^DAF5^U z{{PDL)}#N=oG#(y)bxW@vTSys6&{OP3BD{tljl1d+Qh!ExFiB9R$=5=mWyk%#xlE# zsG3xWFwU~b9n+d_ZrnSY%&gNc!34=rSZ#gZHe50qTrLwqZ#Fnz(df5WG+zqC>?2iR zz1&Sbj}&;MhcVUU1zWCj^0wpQ+_%BC!tiG?0;{^p;$Iw02Qkq-=}IK{&lox|7voVY z!m)^SVy8-RTW?ICOP=6+f4rm9%SAGl+Rt1h)w;G>#%zaa`pjyo6@LA*`C=XD?0u}2 zWRpo7_+F z%zv%;|J}>{xbFX3)bLYZ;dGsHJ6*>t6*cFHTp?XP`yXB{>7SAuT+*)~C0xzMaQP_T z3uezdCra*Y*Vw_@2%l&Ux?b=y#Z%q0ifY4#ql6mq%UcjQ9Y?+*(`OII%X`i))Jj9c z{rK={l;N@6N1wa<>mK`&nRIK`t)a;(gIO(r+&IZOPcn_#M0t|B(55YL*V~pD@W87k zyUp?SM;#U1wfZ4^c}q_OEOW$bWTsi=+bP{dl)IgzK3g{iZpbpD$H1;bHd*8C-CdqVu$d zFNo|0mkS?d3*OLZ=EUP|pR8BxH%=6z4r3+N#DyUh`{ql+1{F!nAl<0xzR^BBUSpXN ze4P_@q@;R6pSk=Gii5#CMK*zQIU0o!f!!hVvEec}5}}gzk$u2Kn4r3AfjOdqQsSLU zG@40Xm_Sh6%|Gd1Qw7E{g(kaEmHd&nK*t3QuG%VfkdE%$U;(&M6RJ1IkN;Xw z^yxM2DjrT?M4&lHAa`b(o3JUr_**Sis8yhxYrwj;eEDnGO(;W`jRrJKi)V@k;VlQ@ z3wJN(8*H#jfU5fVLU|w-)nyMJSL#qS{;qG7`ahsYf;rk}gZEz6{%fL|!3da1bhLHl zX0r@RopkQR1nwgMG;RYd@2`(_Olb#;nBXm$;0xtdPP&tp_Xn?(a_7q>d05@JTQi+4 zVP`;djvelN;H^u3YEq*lZ1^fDFS`>G9zb<(5}u5MoofVZRh)K1LM?t8ZkTBD227Jy zkyT9sd(IUw!37vldw#-vnZ07xJco58`li0F5-$}s9r?*xQ{$2G9mDcmFvij;C@+|9 zrC~PLxa!S$bb9Z5P=A&4CJ`&zX}nQiT|Q&_aGJ+1P^2r`OdlmYr!zhsUt?%>9ME-i z@pcZ4Z<{zce@@a~!xehLEUN3m%CJOr#fWzE@-K>FIymU3+SKiLqU79+W%%SMjx@X# zgguyJ2bZ6O)u-I-PZd$C63WKIl%4ZIjnArXUSeMQn4V`Xw;EWgW;7Fm=n7V=yHWEtG_ zfrGsmL&)O&r%o?)mNQ>3-u5MVz6`-C;L^jqUQCjiAp{HFS`eJ4kN0C#_8fDZeAysJ zC>Ctj(SS`=wbd`;FP4a~UcC3-r8H6}{&<)SU1mu!SFK$;I2~Q@#nAB{C16;vnMUvC zy^fp07xOEhC?>xkw;h}fAD%7%HOWo zFsY(8GR|Q*Zg%i?x%P^0I+3E(f@HhpkTsTcCkz8)#?6)?zt^rwbF4z8*||8$o(9Z) zvl0#+kn6`wheCogL1L}B(uDbgUH{U}M6r2OWCyn3u#=CORW~BXSr~y9FC-<)cf2pk z4Iz#3XbL4Dh+k2H@}K*S!m-yBHj?S}=f>GnVJIC}bin{eB@MA@L&DRXfngX3`N3|C zEkr1Cbas$510KbE!JdjDriqwB>z3*yjIf?!Cw;mKTT*l*#5+AHW>qdfrG6-T-l0$*vLXb=7w zga1c{-RvIc_*~uA$!pUk&9EAWyBFb6jMf_czDLb2ED)y(@=y zn`Qc5T2)C3aUlbcb%>*T^7yZtCBrUb-gf3dMHX8>&-;RlD|)V;3xxJwPlN?y!a4-Z zv-=&(J=rmyE92M1a_Z+P>f8eCGa*gxZ%wFb8I2Mm4@ir}?e?T`sERmBxG*y_Q*nr!>+H7xO@ z=m8mbELWrr=^^z@b6E;Bl!_O9e;(MS(OFVh&%TS}80$peXxf@O9x z&D~=lXzK%U%p&CUgU*bA21?OWkfbrGGFP{8(@kn1&Qq6i0o^z?#Z)mfCyV|GGu6hq zksF7%EE^m??**jvXFKnX(PERtmd@|&lZ)0C$CfhZIX({$OB(wzts{fVP5Bre%4cNV^_VW-e#RhdJK|J&=}QpZFJh|^cu$Jhyo zPza8tpd~5ZQJcYO!sn{Hg}CWshQrb{=ep&veWHmh3BcwbfzORcoMlJt75^%psKV#I zBDm$ivTp)Qnvl>bdzf59&rgsd*XsOpV_P2~LdUl-jP%I{ld%@y52_{K<2+p~6;AK^ zDepd-Wv3zm^zn@tf0L3B>Tt2@5NNSRos(wp;`J)@gRUGYJ>_gIirG0hn#_zGk$=MA zd>>x3@c1#C@%j{p7k<7{X|i08mhv{>GSc_Zne6rYL1Qy#;k4A#>JzfHn81lUFO-{^tU=mG>7GrW9`4EUgQT3s)swhGl$tg? z=VwvWHT-QJfr@sRm=9-xDHGlm%E=gj5N09k%XPs*^8QVt%;*%9CU0WNf{=qk51%%uUUp>P;obBm-${TTqPmT}Q6CVLSw)Y#f-LNnd$N4)1B|l)6 zg8;{`4fY870C0NS1t#zAxL|1}BoA*CU1*HKTt6f=XHvw0gn+Wcw>qy2?&n__(h~|X zaVHZDvS*{y_+nxoKKgyRlAB$CMXIo*c=a0(Ysm5aIG-Y_@I_7EV1M>(lFMo{SUyFjouGjROGF#x`zrX1Gr;!mHjFrmK2_1}(vwZE@tftnW=5V}_;Gpv4^GI|p zyDdctiTVOZdjYyU0Mfq&(h|evPrcK_C?TGc>(rU=uO>2El1E2ZuaS5>f&4@zwU72d zA`xk=<5;OY_T&NJ(I&UIe-*Mf$P{1h5BQ&QU?t^xghhok+os&-i1~;uf(FLN{H!~1JA;;|QN)d{i$%%HH)8V%{b~Pgm zdx}}5TC)MI22>br=lCrxc!pADzuBh8ey+CosB0wJ366X`JB;^6*{sdz^>H8Y6D&}~ zjj-aFCzOg2HqBYKFF{Xev&~%T+ZM#Erku&7ouSNLe}oP}MM30yB;Tr!zT6($Wn1;# zS8*inpspJ`8{8LveKbBRj+JekG(6&Z*)QSXNOEl{!3fyf0R1J4EktrFQChXcOM#(< zePe2LLk4Cr4IYx-Vbpo8`TGu6^$(M z(HkFd#0IvQDHVa2RTSGaJpO{%(3H(rztC&O4HI>QE)*%i1nI(`UhqlJfI@Fl5iLP6 zAa~dk+bU^T^#_&xfHd}MrKHfVhu4ZY$b$&VS&1DewhfLF)0Y-U=H4LJcsq1h6|U5rW)~JbavBp910vVToTIbsCa2r}09Eng@pCv!OXZG` zuW0Sv1P)=VJtq#yz+!V15l#yCL%04KjLhy}<2;6!6zGy4$zS0%AsmO4E%s{SQ)2Mo zZuv}s&pDC;z{BlLSL;1g2W-6A@fwVM;^=uHBGy0wQ{4VZ)ly3pYs;$peMs82HZwS{{ce>|_ zoDwZv+|N%#c4dxxc zEOTx7lAN+=O64d(CYmFI${_GX<7jX-qa^<1FZxS^q^+UgE?(!43J))hkJi_Cm=8Nqe`dm)9!{xhLGefuG9BiL$u&qsZ4~#&KB+-U2ofB=gwgIa( z&b7G)2wIEuA8(h2-7YvXpY}nva%a5iTF#{3afa${TH!<*I0 z@&O~Nj|RtjBR;z+#ITBEuja2ttP^=7!uO{QJ;&2MHyOP605Atj;bCec9v=;n4nK;B ztDPk~6j9^FCMbwM>nF>m(4yJ*nsgHxW8PM7q}f{%x;DRmGpd?2=I%2%L_ZtDGJU>5 zG@Nm)4IDN<7R;t9OKo$cMsYhxl4kX$yYa> z8)RY}a{;S&Iz}HW0|ogym=csj=P5dCX4=qFN>xA-cOV~Ffmz3pdpktT`eJPz-Pj#> z-EjP!qpEtp({h~J-nL7~wKsP?8d}tqo%uphb{dN;h)kHDfi=}nMEhAiX`PUqtOrbJ z-O3Kd)&9V#QOBes8zqL65L$;Nj;5lM3isCqfD_)$nqzGK`n~=zn+Ncek**3V(DEHr zb}XZh;x^%|KUzUHXQzY@-C19f{>j4rZDz^3%RDCifj;fr$-Gb+Laa$YRI`T#cwFDu zr74V%0L1sU{b`a04H^9t=z_X%82bn2!lFDQX-mrbqXfFA$U%PpC{Q8dTuIg+-gsb1 z3{r8Ss8A96-s!%d2K=v%@g^7;Z@+A|5s+TKy#aKExyPM@%;Qgn8ixD?Bp{J^B*&En zT#%Rcyu?+(y!ib5S)OB~HX7zQ!v^03UA<0bXBz_xcj~wphc5+qrl-FZmHTn=iQySJ z6`#+?z3J&v>B?ah6$4l=TvKHEe#BM3`)1$EM5SP+gYEuG=! zp{EPUA}BKuvZexmP(Ug1{`^?waj-oW<>-*xHu3#NK z=AktWi192pItduq-7|qfJ)XjXx+aNS`syumN8j$)iWWZgg#mh0dS0WWNqk6<~b` zl11TkH$;k-E^o&wr)dJ~zkupuvhm(j%9KN}jcQX6{C-T!HZ)=xI!-fOVP@xi6=eD` z$rc_g4Q>uc+v0?+1G@<^af_!SvNaPCy5g(NQMY8r8%6;|4^L`+h=q^N?njbEX8J1pu*dfOZsPqDe0E05A0%N!CYlw; zTZWz--sz-qi_;6o05?U}*ay;AOcVGzL+o+f<--U%%-x%OK$njRhl-NDS)VCOP>CMW z7NmMpYUM{*LT%G>W{AM(;_~ZsDrXLyPL1%ZbWfmGujcho6KY z|NFiM>6{@%1FgE~L5d^@-t+uzvU zTzp6mc7L8Y!vLD)8}A?8L#`3a{wm*Us|N%e06}>Y-k(oGari1=6Lyz0L(Ukl=B`9g z7#rVDv!)1GB`{dN7kGIejT7D2e{$%T|2$(VX-weQ$D|QWi-$D4jT=;}nhunb{NNVz zpLd{NH}H)BS3F7@I~SYeWI1~!QB>67z}8KfO;xw)m9UdYnzX>d9KJlAqbQV*4%4}D8I^BJSKVYviKkMlax4a|f ztW@@9%Rs*v;JCO*x+rXD?3JMH-KeneYiEB&U7F7i3o1L1GVv$^T&k$ zjj6h2M5&Ic8#Z)(01-7rk`w0hl}!=#{P`y(OUOMj!J5gHQiqcdwu>Y7wuR=|_P zKZ5n@?1SVqxx-blSJ;o8K0Xxu*$uu3luQ)aAE<4@n=5C9_|S@|l>5fcg-n2848m;wJQ%AmBla*WzA8!Jm;@= zVqq5v$Th_i9B<%A`g>vkP`n-{mB6K8`c@zt6YIx}5^0EQu2tBTh~kFmyWv3eJ%v8N za7-i0W$87#)Ci}443)^WUXxz+TRf_c&tRDqrHaq>zJ0C(ND$ zu8+)Cklom=ItO1imLC_MrZp|_lH;C!LS1PW#LZIfiz@UMZ&WBfMY3y_Z$u4TH61(y zvIM9q@pN|da1=iFes?Ob}WP?~VlT=3_4%WDjUpPmo$wHtN_2~}HW@Anshz(jW%<=TUt9h=0rmcx4EVx`TtFfaxhfHT_RZ6cS7h2;}_d6veWf%PHK+ZuV_^nip*zYx#JFZlKI8zQyce)BAb)p_K6r zR1HZ9%ANg?9(?>4)dKzBFQ+dEQrhYJv3Bs&I<5YtWmDrN)?Y4)-ODLbRVL$4xlvoJ zDc99W`NyTGiPYyQAsyX_sy)B7et+7`t9cA0%-q5_NM+ZrSB@`&C02WbtVE?YUvXsT z<;=>tIy7H;u@k<0Wq6%XMU|)`q$SE}?nRZ9W*~Wup%t0O`M1~KhydOIoDUin{gv-Vng}sugjAG6CY#jP16t%`J zavQv8_5Y^AgTlxv_-j5A4{{BaKIG=cI8?i35erR})eV|ivAw;lwRHIK|0$CA%|Zm$ z5Gs;I|IL6!ACaq2VV{V`&|f7)O+j&DY3^}S@T+-24VQGqmhK>;Lf}bA4MmtQ88EKG zlz@pAvUDW&TXz~QcE^v;+FO2|LgBVW z)@&Hfs?h2y!+>K72y;-AI(GJt-z2IJziIqWw@6Zzb*FbDbS3HQ(&%o%Oa8vA0+WO| z9c4-VHeW2MoO3bU%ROs_0KB1-giBaKn`D)1r!&O9| z_!4uk(}oYWSP{6li(}a`{ZCRHNPch}kagtE&*alWG49|EOppHXOAJ&)d?d^tp!(pe zKaV#U6$=W&ZRdQ`KH2m7+RzVYfWU^TjX)&psYPj2wm2IcW_H% z)!!$1;n3L_$riOBb*GVhZ*+ZTFZiOb|w&cTk61b91MuLfdXkA!jgBnM~OkxOsQ z+L>)>z+z3Rm5SEPGLu(M33i6PX0cnq88B(~)*xH0JLc(C>67PAIhbBPU$k58k)%&I zMeAB;jC3csA)Gx~N{(|Z-b(dMoa(YP><$V(tf}cELXT{E+Nq)z$i&ELES34CTp^OV zSnJJ?il#bhASzg~xwYoSK|w@UvyJ@#X{GvR$@bAn3jYZR{+3>Xmq5rUg5?%f9nh`) z*Bm;6lQ~n$Fi0{R__f48MDq?4kO3A29a&~ad_F$^9`OZ2c5SqHbxtfY6#FlB|7c(4 zpRHGV64J`Y{ZD6&h|O`*2JY?nMyvIpU}nkPM8^awvgZ1Sly!86%{lH=7GO<9a${C1 zm+e8*=3UKt7^@DZd}DF4_^;0Xuqw&d!{)e7=Rmkqwm?z<%*XAdJM4MX!F4%>Ko+i* zsd@>=*Ahgv?n@F4mh8=p>8NUJe5EZ`@I!f|<{*Evbx&lUVUQztm>J8$)Ku$;Aw;V} zleSAan$$syGUv;{zkW)MIyIWqIR$YU@d5IR(Oy{(cm=+!L(l}Vy)Fan)s~R`<#@4S zP+mBgM|W?jNqm?;XvcM5avWx0P{x8B6r&U;Rg_$x+Of;mqa&F7Y@}KT=Z_(?7DZ4o#E6Kn}vaxn^q)TH=fk#dPeHQCoXh8B%BF} z|G(*^q*VQZ*-aSmAm3c)A0SWCU?E$hgzm;7VQ?Xc0u4iE7qeS<7~(e#ao?#h2#oP$ z%d;t~B_4_s@qGD15X=#zrKPG%)(tL)ihv2WsS@nnl^Q2K`gbfNVs|c_?e@j{rS#?zp z{Cy{5#av$mp;efpDG7)`XrQ{f4XsR=L$-4U!G454X;K2?SfXJ(ZcxJ6nJj_q`=VBt z6Jt7)a{tvKHv0!A2pI@=m)C9WU_3IKIHs>uR6zjt{e8J=>;8EG(VHE^EL_5`-$fBb zl)J0iXJ^t|fZ!POJw4ZvNlth3=Q7!2xHmh~HIGOQpD;(pP<@UVB8TXBAl2_*T#F1K zJqm%&@3FXfj}+Q7aft;9E8WlTU!W!7Ik%@J+~oV7Xn_7GZ02;oe`k%>l5QoLP#Pqd zXfVx^rxV+*$OWpLN0n(p5x@LiLS{!Op>yrV<;nM8Tb5DZERI{0afba*j#DLe;i3^Dmuf7u~?q8i}5j`hhSXmO3H$xV+ zX8_+>(rqQ0YGK!EOBsyScMXaRt{$8EV>^ZHMq=uOyeA~%Bh8C1qZO?B#-KCQ0y{|} z^Eo&nY+ccGkj8{9FGbDHyup$*R8lgF! zIUNW5W^QFQqi3)Tgw`DOe=gtGiz2<~_G?*N&J|vkVM<$#navoONwd#{BiMg2)RGDWyo#`g40yc7< z=0o7~3y`!CkqXuSUicmPb&E^d5Cu8eK2s15ULQ^>gPk=oJA(d`Z#Lh!d2}4$-yIIQ zhAp;m2mAjcV>cdW6d>NU^RPgq^y!D%p5bLcU!%F zWW<>1ZlQ_Jik+eOUAN$lFB!{aT|HbaCLAqp^^fnuVm)6SKOps$Avb9>=Ea8yM)M^{ zI^5O!0W@Rs7&~{!IwAO@LUex*uB0#6Yy_s;Y#U$95w*Uw?!@oTnk*H?B^wJn+Vj(- z5JA{$d4ALm4aH+4@z+B@KxTIStc2sb`kREeHX^PcFLfuEpY`kpTS5=SBBcH%zp`VD z=Sx@h{f{;bj6za66ecL>2h}$qZWcOMUYK)bBZ(1Yfw=JKOY9s_L8tb%!lS3*>=;Yn zkxcq%dJ|0C+p6$vq7c%YTtIPX$_58dZ69t2lV<2XvLcX0|y&^upqg16KQV8HF=V<1LNc8P2F8I zy4_r01LIZb;qB~oWvYtEXs|@)NO!P@Kp{N*ee^UJ~ z*Fbm$pvPM}Glo4;$6&PU3viTI*Lg(uMEhQKq@gv5&9tmyGTEL`S|~P#Jf;+0i}8~8 zIWx{@2&A*(oB5lo3AVBUwO(32<|MqMhMC(7NuRi1zqOC>j0HB8nqjV|9wIr6YqEtH z0SH4O-a}iCzyRLW{xzN7TIs7HdOwpaL+}Df=>}NE`E=Oj*|C$`7|O!KWNky#rn>jW z=_lMc21}<#XQo5>nl}9tj@~ok*5&jW$opf|#zBxM`&^vIs8JgZWFM5t((p`QbN~wG zoHKE41{PfTo&`r)IvcD_R`ky;T*RtSW179N&_rju+q3nZA%S!_ZAp6#>j3$fX`giIAG1TI4rfG~WVG$}zbWr5pT_igJf(hin#l7t z|3n%@K=N8wT0&}nL%kC!$CT0c%iYJTGTG3zIO$~#JDbtK55>*46-%TrRr4%qW7K`y zb9V6uLqyTp+4*J8O7fNS!_XDF%#B+Tq$AeD!fQsKqN9i80tUKK8mx7CbH5s|m@SKl89cW@l@+;CmEj_Gfl)ak;0^`E%{D=)+DIZ@?=i z5PZnGr@#mhbY9?8GF@r}&s7^0Sa#5%TCw}rsM_e?Cm%|kYv!EcgfLjEkx`O^I@*b& zTlwN2XR#2awRkde!d)5yLRCO04ydFDj`<8Y+P20IZFH%uJSQ0Hu=ys3}dvd_{?rf4r) z%^Fwl>KZjt1>cy{!MiIAP=gI)VRA6)Mmt%iv;j3fNi@O!-5A{lZmoYyBN-6Mftg6@ z&+&%MYaKCYiRu?I@qNoQdxxlGMkJaS2-c@=W$B!eQ@?UBi^gB=>Q)!_Z^*ekXnA1X zlE<>oaI{pEKs1*zis}vC&JAsu+PJb84~ulZK$zsLX}h;cNRbxEN^q?EFz{*Y_tFOA zdJU5Psv~oxqt-1gX{;Asoy4S8x@P06AfYK`W| z2!{-tA&$eg+z_;R;Qn<7TZUnO{CU`!*M4FB08#*Ju} z{R16%f|w{t39kp-$t@Op=RSB58lI7^Ons!@BxxzH7Lug%CX>sg;ivOG0QcTLY|~8` z$M_PFv3*o=#e6`V!FGw8Ejf`1jmbdqWTox{Mb;Qw{!@*}@gFGl7CT1GS+sU_rR2mftIPHNuxbi*o$b%~{@N(#ES*RpmgE#Rjy3Ib4DCi8B#3Hb| zW>O(-k5KdlUtHaVBF}$5?%tQvnnq&^{mUq6_&d7A(3sL}3~9ZZj~sIi&G9N;8~9Hw zv-MIrU> z>OSRdX60-;k0V^7a<^-Li=1 zK#4=+!co##vZ<*KXiXcfQ++tgx<|d-MGvCT*T_;E-FXu-7*d)_(?7VaEDjP7>uPN) zGBmUWe5`j-6AmhK=?IdA7r1^n@*XffP%hO<~d0fLYPA(^l}#h9|o-9c%& zJ$Za|erF!pjFmqUT#S|@EjB|?rT_WGgAI)yQUmJjy>~sv^7c*r^|gvr#(qHb<@&2X zBrAPmU$F@f%#&kZUoi>KME8I}wv{EO$bd;W_VFT$6Zz0?OEfsr$84UP87||@1%7+z9)C?-Wt|Bd zMd14{3%Fp(`ELh5nC%>lUCJjWUd!;VrWh26@)4}sw7a^Q%C&%wLxx{Y?Qq^;u6+|2 zf;v2$8}`iy75y$>yuIC4x3yZRnj|S(yNqyDrQbHVp~hMBB&Y4!oE76SJu*x+e{az$ zW`~nd&4}y7=P%jfl8)R|%_YZBp$)r?czAlLL`*4UlWM!)JT;jvU3P^}sbv@cd}^+u zsip+pMKEInO)gqSJ0!oD39i#j5c7&I$r$v(J08uxHjUus&M7Vz>> z&|Ik%T6<&pWMV+wh@?_+oK*62aA&lwl@2E+uj0@Z=%UjzFmBvjEO>@2*OmK@8Fpss~>VNiqp34^rTiiDOLV6 zW4WXxhohP3{LJ(!&ToL=`<(QyOrNg+@(3x5pKeSa6OG>CTA#93-ND}L%^LZnsDXz9 z!6m+AaW>#V01c8C?7><}uHSBnE^N>e(T{izhET`Ph$7(;y_F*uRqN))4~=t*B!51_ z^{8?{I6%X8XkeIDqv%{JG-#{1L+VKh70OwN1kLuIIv9UMY#GC$N#3N+jr-+pwc?Iv>PeL7#%^Rg(`5(%e>{bexG3=Rzrtkla{bM^}-!- ze&wPg81v)dQ_Q>?Ufw`z$J$^czw~X_C(-;8uh92c0Ghyea?(wTsKaw7ip7!a!vn!y zuW&q91^k+@AIJRt>8FU&6TR*0&2?%{MkWm&jn7>4na*wIX4}uLumXw@gCI-mTv`9U zRpvch&BaA)gWWF%(^O*W;Gj)byam0?^_UQalji8p&xxWxOk;5}`cR(>L_BGqY?R|( zTeSI&q{&^n%Oq;m+i+?5sZNw3Y-DRN>ys!s+XqxoB~ z2<9mqYm4puWVwZH>j9z$YhADd)&A)cgp(|Gd7`D^Uj@kOvcmgsiyq?YZteA*Izv z6)W8^G?a>4Ln1v)x{nPL0k{mv4N(QE`*p@qc!>C~Z-1Q0DBul&YJ)uDPz|qm7N!uM z?x9DQ+1hC!HASeWl554N8C*Qgp&;B)CddsY)r*sT@oo5jw9&}G)XnE8`Gg8W=ntbj^ITdg+D5T02)jrydZs&^D3!AkmsWL|UT;(|@>RU`GyLTdk3Z&4M5Z!z zgg@an$Ss+CGeqQ&IWT8qL@$D3;ER}ZYNHQ+AybyuOz2OOzdj9*$+)P~jr^&2H1KUu zdLM;zW>)}xDSJ-zWTgS2X7EyGhkB=nfQr_FF*)GyXLi=LfHtOA& zN^+G?$iNJtSX*&9K!!?Z6Lo6wpap%-Q}7TcmbIH=LYM2YQrA7Wi$7GK@EaT~AnrFs zD9V0^Tk`8pg6z62dr6xeA`n+S*c2OKF!fU=c*em=d+>QdxgBcRwr>wj<7~1-hKPRi zzLHGE=nk67k~~(_gr`p%5DcV?eLG_;(=+nEwlmtDgM0|th!yUV_%z@1%CUh-MapA8 z(6fH58IDM{TB$dI(emY-NRPC_(wPn%4qdwaQ~;5?vmZC1@ZIP=wTBpE|B3@j24V7k z&GppSK%S~|s?cHZQt@7&(w98eSB&)?#(;1xJm$h2ytz*(m>sCS-9L#CRp~oXJ6be0S6Q)bdL2VD<5sR;A#@U21o~)$aUu_f# zB`z)DdG}}M%zuI&a=m|@Vr&th2QkRM$qR7t`6t5t(DM24q&-Cvp(;(8_H8QjU*_!4 z^`zqB#e#)?5#Fh?>{XCnEOG3uGJ65XNC=F>hS=IpdL!qbKt ziER}P?6a@Nmx>V$+?tn!y(GXB1Cj-IYwDmlW`4s#q)Lq!{32x6o5L4T5zI;Y|=L-Ed8*rnR%n>=;TRM;orG)cJZ67qLlq`OX5*ag#8EbIY?BVH6&{|UauMt*c zY|ywAilgz_D-V~#pAlhfVXG{`c9L(*&6^ZU*96h}GRM6&sjX_Q&>=}UVfo1`GKuLq z77g@nnahI52$>3su7QhI)%v4#{hE_>tdx1Y!-cfy(luGNZN&YmlgFQPqZipc{$vOb ztDR|zdOIOx>*rDFSi zk?d5jIlTcYpvgn5c8Q_{@^*qT(_+p-sL2!`zuynobyDy!AQpsTi0-VWH?Vxl%y|c+ zV!I8W13Qtadhj!1FGN(>$%H0nak+2nK81_G$%_Pz3&_pgte79)=?Qgm@_xO*<$9>t z?a57;XcS^8HFxVAyaPwCD3o&z`W%pM@dgh6wV`1Vb9Bpr)9PO8Zd+jdt2^A~V|bEnn6vN;Y#kqp-HMB~*ZbZo0}zZzlwB z`LeD6!vAE5AR3oTu$?$xhA)=8Z*g!$J#D>##1jF+q7sK2`yI>aU5P+$>PlK~!@T9_ z7!(7;Y0WNF>?l$|GcX()GkZKB%@w^ZC}~TjvB3D&m1C0Hj$v;-Zuha;$*#h^TpZs7 zOorC7+W18EQ+U%m*)Jq6ZvuBc;&*8dCa*JRPZ6)3*^u*t7?Lq(%Kq*xsn)hf(|fI* zUxw-<+3AC!Di71mhB+7D!4c*s{jce@&5=s3j*{9?5;=vM94Tuv4z1?*M8^*Qg4Dks zF@TRkslueewnfMQVO)poU+ZBiCXB=)6W_xe>=me6*X$)HC40q6G;=5vItv@ue;U8Z zP#*Ms*M6Z6?He;s3HGuijjo7`)u7!SKXY;cSE)zrxPFbGZ$UIbzT}M@R&^6jMH6okJqMe^f}E0mK=I=JuC=BNz{>+ve>6pk_N@ZH>lE&c;O|8xz`0y?3doy0z8HH)9iIr7 zYR~CWGntT-5^6(qh5;MORlB~9wgAq~|B8gXRB-*X>PyM=)MA7>!Q3=x*%1uFV}vSqeWH`lN`zBdEFZO z{%ML|A3Iwpzi+f&D?rhPZs1GAI6EvG`Q1iR_TmoGe~fqV7Bi^59p7&(!hlz9(aLIu z-Z1|lCyL*L?_92f`5A1d;j<_$`w8-!7_AgfJ&TDpbG};;4LV(x$>&T=IwW-jF-Z!K zx!9&JtqHy-L7Dywoy?Lg&m}!r#Rc|z7U|sgP{qVNcFRV$8cG^3YfC@gd{{Wxt7cWcdWXpkamBRaJszvO}r6T5^ z^Q8X!FM}nW@l`R%ep5fu*n~3NjFxn*dMac3FPv02zv)UHPHa5YfElW(jco4F)ioB_ zzrV>sPc9o{(Lr(^UGL!j3@#wS-L^KxSco^Fma9ezrR!y&w0b>(D@^hob=U^Cbvffk z$`P3LckXjy+YC;u!QF=Q^>1+U$0gVXz}$1B&0Ct%V!w$vd2ZqQ&J_0~xxacFYCN5D za>d$7LJbdpw3F0;Zqxs5p=B-?D+-9pS|Fj3l&^~ zaLkaJ+CrD0=&mTnmJZ$%Q`-{LGPY$dvxQ-$hNlZ}o{cQn&FQqj16eQmQm@JPt{)SU zzda-Pxl)VIZZ0pZn?s;=%)S?=4kae)p(bFp`&MRvC$^YIvY(IdKoWj*X73ZY28Abf z54%kXta&$Qxc*Sj_uJniiI?PrCKKZ+K7f!~oot~HF}8QPfc*-aP?if&h?JdD!5 z6s!n`u}CgR&a|SQ^Ks>R>*<`U7JKAED@FE--Mf#qk2Z6gACHWjBZLoTAKKk(#)MKo!e`4N!st26>udU2yBfFO>Z3KFco8s!APXIf_78 zcLJHph|*!gwa>dC{j~$qy~+!Tr8?J-P_MtIO;>oXYuPIF7NpqE9^SBRvlo9$hFm4} z>thl!mVh6v_G%_ zS{u4n^$P$8IU@#RJ)#Ios4xpkdYbhbQ%>+E%OK*rt9dvr5v1NhymUQ8RLar%ZMuew zJ0BW9)+{;*-y%!-4U%LpGus)5$!k~&&zKr)gk$l*Q?gF!2xP?)!s6P7ji~y&jE)x? z-=b38jlf1Pv$k0Wv|~P+q3gsfi0-2?O7ly+$py#Q-m_Oy7ZW!GN20Gs)!LQ^B)e|6Xy|_A1#jm>FV@!5Zz+@yQ}kh z>{gZ-_n&wM|9c@}S`WO>N?dLOiFN&35rgm6@EZpAi-x$SV?1_n* z^aEi;QgbiG99qetzjesKDrx{1eE!|#MBdPA{gx92o)Gua0jYJ%=>tC+7WQj;npO;u ztrM#Ord^36lrTWucP<4f%hW??%;)cC^xx$S(yyYCLqfHHOR|KGv1CHE8#g-)H-ks< zMs=N}xD=OCHKD1``(sYOJ1zyb^$MHkPdR%7IAN{IykEay&X%adKfb`|viL`vdT@uH zxyQ+U2jncC8MSrlK%m1xN<(cyOQ$6E?sv!X0xASG#txRdb z+P~5rjK&PXFftOTpqGy^VN~Gq_XIhjXwy)F*?*5M!iUHro2wWo2%!q-*k4eg5{?K| z*M@G77EOke{d?B=;{4Jx#b>$>DAjy?`!?77PffPh6RPdMbMW@P3G?3rlYba-X8m`n z`k&V=Skm7A7IOdjE&$)&L=5$h5z7A=HE0s{fTiC5Gu*3w1^oAP|8vCvU!?IxS8Zn#Ke`1kf^(!YQA1%c>)9s1jV z4|hIc*T)Pg-4#4o@h?>dq6l#MhFE&N$!fC)VkBFX*=!V^`oz-r_0<<#l<|Ji7ijit z#7`1mA-l$puF0nKasM)vHs&fU62FD`@1qF$>rHvhh3b%=LtwS(@39m(a(DCQJ1Rx} zcxb#bjR!jZi}r|B!*RE^h(9572Y7g?M^%TlvlGml&(xxiTrOqL&FND2OC~D|1gjuk9NZby~eQL{<0y*JdN>ufsa>J^v~rT4g+@4P)5Yz9+LC=v24W?BHGW{laJ ziR^cMIvhE^VSN~Vtw+!)+0dTG+3nRZE5!xjG_N%26OYp}eqY?B8~D7dzSR#y*9OOD znJ5k4FuTzZ7E1D3mv=^K&f%dJQh=pmp24Id2ifMqSa=o9=b>x!oyNL3b1iFdedGT( z*it(Zga|Kmrm)qYzguJ4pbN>Vd#C`?#j`QbiQ7D?;Px~i*nN1t$Fa;G?eRIbA{ zUukL*L7Cs`y9=bno=w1O&NU}@P=&ons4GqRU!o$^@(=t`eEu+i?uy8$-1pI zmkK`u#fRY>bJm)WkQ3kGJ5h+9BWNiNaeMU;rNHRb9LG(?Ip7AEKk1sCsloM1xdq=t zxYj?-*uZ*h)(CSklsY&_t6|jc*1`i&m7CM#XeY0O9i5NIsz+bm@ee)kz5oFeapeCQ z{a*#%9K~cl9dJaW?y0E4q~ce;y(E_wBC{bqR%ldcZ$RQj7`)u@KnljQ6Mw!HW>5Wa zd!8mVA33Xj9FmG|5bMr{ZAA;6E_srWcupQAZHaWpMCefpb@=>>oq0>5JNY#>D`CCu z#rqwgG&ue8Fpr<)&>d;GD;9&@K-c(b{=C={1w|X38VDzRMxc&iJyr`@GUj7pzI8`}Kn+kTCht$Thnk%r zKDGqZRNV!Ot@vW<;TgMK3Fn{Z_cqSSM4^zjEvIAztQI8xe0?e$zGh^psHa2a?T-m7f zTlAK6+EkKGR+*XI{#N1ONG%eRH5trSMNBAbM(sQJmZe@Gx(+wMn22A@4qn)7%z-Z` zYCq9|m+~5WXh>9xBZZ)Mls)NmoP4f0EMP#~O~T213<7X5H}t`;W7j6SoB<8@1p3V$ zjp>siFeIb?u*T%QO|d-Z%4gcHkjeww_FIeeI0SW;ePK(8_T1VNcV~vE#g3XIA`Y}7 zS91hSurUTq0Nm|}@pO>Eibnhx-ee)pM~xqj45-+8Q>jOg}aQ!@ZM~4HW%LLBPBCdux^gmp-^Dr)_##LGd%4l zl9_#94sCFeCgUC0e>l8fqh((%`#BU{rp**$_pF@^N*<0ueE2`l))$03*;)dE z$83KJ^)}q55F0fw2mV_I(+d2^$zLLOb|2nV2pGhv293o`t;;Nu$<$GVJF`!jAVr(gbs2xf^c>AK2D44{;>VgmH)Re7wsk2_F_nXT z#!HHyEr;y_rQeH+eAh4NZ}X9ub2LGcDxM7FIq9cc*X%Y~ZBRLY}o?C+SdGV!=u zwXl(|gB4qJIq(JPcA)&Q)e`edw59XjKROIM%{QJ&GD;UHe(%Wf<*9*Do0&G3=-#KW zRIi1SrCKMQa)tr>TPptmjh(cdAG390>@TjxUZk%1ox+f7KwEzN@uwjM*(&PbX2?Wc zdKK*VQdPP9%dAE-ZnJ*<-&z;`o6Xjkq{g6oip8OZ_{wf?j#ixHBsGnNE_~W`mtvL8 zBWYg5NUzK42ig*$WSz@=mkKISKCa=Vest+<0;WPMlcZ(z*)CnFH5Xga%_AjF#BBWXQ~$aPyv=Y^oA#M{K{8LfR)3?(E!gUcAtLsc z5KrL~)#cC?9bt=^c96V(uJxo!kmMh%IrV7aBaf?hJ}Go>5u0a7SLrj6B1si$MQNJekrhGgh?;yZL%IwFUSy9%%Hxn;Jux(X6G&iUG)EOy6sF zDu&Qpxym)wEL4_U9XN;W{2GutE3uchLW7UzXYsmE={&w zY=+h|OHiw4lQRSZHj{bxP_<;md5rco+c{~+TGwMue64|e1}xxhw#Tm^u&Ea9L@JI{ zC`uF1N%W7t>TI&45Q;S-nOct>{L_@Gojyu*F4}%7mV0=$v(x!2!LMJiU7l>U9x$_| zDYg23kh@;Cg{YTvUbeFr7beDeo%G=Z@ZB;k7evZazMQ^}^($HPJ|b#Op>PKZv9`-o z%cl%~N>5+*cdYnbP&i~amCdsx4Uzq+t?6pg-fqdU8ug&r(UzJAboY34*M&AGJ_&kE zE2J>Q*r9~zpn#BT{(1!pg-i2_wo``8@9G($IQgOY{PgDYokE*QM>*lHh1(CI8T#0m zDVb};nG5;5;&Z{uW^6x9g}P9P((J|?xlC{zKNylWUo4CU&2`iWd{bch-0f_kiq9V% z9CRM~VzB5b%TupE*}3vzIn!mQmXZ7{Y$UO>*975BKh!PHQ&4>F`r4@9M6~%PZK6F!XK5y1%W2DD9>v3vIx-b+uX{D(mxW~t zqyYXCvy_~e;n>N;pi)cNl_s>DU$5^pvJK{=#fh_Lr6U}+xEbzZJs8Bao63|hU%=O7L>@P1AcDB{Zg+21?4UJ6j&pgm z*#JHpUbDV$9*Kw(720#3qZ7H)loo@gf7;~)?@X|m&-h$!a~4Qwu3#B)X@T6`e)F#h zQ>50B4+6RQv+=io5Mj4f(d}fBxxUU}D%5uI)`@wFdd6RJB^D}l{YTskCi8S6G~m2k zUGN2^A_Wcu3W@VtFd~t_KAKz>=vXnbQyg3r@ld@{m&!0Tq3^5SD*?0m0?z}*1kqFg z{Fvqc!jBnkz)yXz#%T;6BbiT9!m<=zXvvz_;ObBEGK=8|&}3GJFgEH!bw)ys(EU|P zZJA=JJ`o-`{iEHKR?9{s%xFfdhETbi&Dn26;fk`{7oyk`#4ZZvUpXi&<|BQaQ1@|H z5e-kP>n@*zs4sOebh?iFayg5F&@PpFg|uA98B9GAhXv>S8iotG`HGWB9 z3kzmeprVV?+Bd=Gl-gE1e zK#*8**x#puNSJYvg-~~nf<**VwL}TFMj3mn5)#F4?U95R&Ru~ebE_C$NPsR@HKs#T#0M1G=)Ds zmba~W;>UT+R<8IxHmbd1=aWY@r#E?IOPTqX0mzo)7-hHQ2Ogjh8n7ZaSO?08<`yXT z!YWsD-kCd?oXAm*^5?%Wl;&K-=@OU{dja=se9yWs78qNm6Zlx5oybAzOI>HbwN=o_ z-hVhgn#Ss(+yxNS9!*_}#d1G~s(6`lpc1EzCJmHZ{jf4p9fSqkt7RJIY)8OgmXDf? z?(;iFeozDYRwMOwq8>Y8ClCoJ+ zOg9buS1H{Ldtcq|fJ11wO>Fp|H%*Jww?7q%gp63=C1rWiQc6~mMh8dG8mS&>RmRxG zeriEan_i9fX)Rpd&^0Ajo$Q0xwG<%1s223Mj-fuE%1zhM8?y9>Zo>2G=^GVXKB#!3 zX=3WU$lzmjmd<(L!euBMN!ILI!?7FO+Dv!4=*@KMbM3Md7~GWib1zMn+d;2`|PnTfcUmwGF0pH#6E9+0<>Uh*YDel}fLZ(p>hUWi) zJx=>5czeR?1d*v@*`i^SU@rO453+SXM&`fnQ1 z-e#=a5GndVM>VVgf_^~zmTl<(>fUwaN;d9DlgGVFrmnIVS&6lgais*v{~LqImVOZAR=V8AJvgmr^fsmxHg#w|R|Ijr zvbI;wgoqy>qNK0k>7x+S2XBMfMH~?Z`pJ9_7i=YntITxpBdu|EfHWNujcUOWXME_< zaJFx&@?{sr64lINx%&JCyC=8PKVb_UbYvJyq6G`!vB~SqcVvuQ9ZIIa;~;etKu)wBS(+ zslX^9Z-?%d;x;jVPDI58UnkG+ivN`~SHIDvpKr)yGo+xI66ucGoyI72rErv3Q29*= z)C5FYnP*isI-%?C|D>nt>(oW}KiaMbRX?fBD&T5MVnp`^6v5APU`nEb#ff!FtNM*B&iC8MUtSq}?rxf^uY zxWboa7{M<3(J@+ilt&ZoT%sH~hU#S>V3yUE3D9<>qkp9H)yT9h;b!WLhxMRKgt_-` zo8T)c+#f@ct~w0UW^i&f1o52wGyXLNW4`;s2GWvkpg$D=5W^~G1X=F=gpBKvRXDK` z`~2YkM5}H+qD~SM+SBmXjx10f4V%jrN6pFUC#O1e+rW$1$!+XO8bQFWX5@9|(CU*> zAFTHF@+N+)U=jKGPLjj+31y>>IlpRW&SKB%Spe3P9r?05-`;pld|XnKxBr5~LVNRj z5~PFaGYPn1GWO}7{Z1v)HcGQIa20H;M2Z!ouNf|HctHtMn(gDwcRIGfYNos@iU7l( zO#BV42LXJ(B_kv9o}1Br5hQ92^9LXA#MIZON6m(ubXy#eQK^e|+G=7i&X;%n%ecg{ zKVI1Z_Gq&|-s@cwe&S;6lhZO&S+J6;uqeyBIK07LJVfE(y7|MUF1j;hkzGs8tIr>( zbLKZ)(IK}|f3FLqW}D{RXc>V}5 z(XQ2yk#D3Sfkrb3f22cDN|2~9XmR)l9;$s1zGJ+UkY*})i1fRr0y!OM383RZMw;*9 z@>iHvxV8wHpauj(58BHz4krCa5{Qo{3o*kk2MYQoDH9WKWHkL}rmBdK3)uX1hWx4k zViYN_Y;WZg>EB_1LydrZc|1;{M5DX>jLx5TDbnpLSW8Xh+at>*ZPt61=y>{nAgEh! z{zVEYuy&;z5eS7A60L2fv}zMWQF{0jJF~n{vv>w`4-J@GUi<{ZCmI3ku(>K0^xB`) z3O!rgg*j>c`9@-}=&g($+$v`LO>i_j{@Re!1wU~tpd>IU{CjGC)a`UsR#pnx)#lpU z=Ph?^zM7~)@M(AkqBdYp0}*H^F@esCC^n)=p$qn>p^8UnM2ibBi#=8f=MqP3mOPLS zM5v1*@%Xsce!$)mEay#+7=F|-20IrdKhWhZJj^?tmTWV+)jUh?$4iWmUzT2>O_$lV z7_uG3cU7hl=-&P4O0*8E(U2| zH82CsCv5W(jG6v))?|ZQ67DVS3N0G<-4fjF@p_J0uM40rQJ?+r$0P1N9Nlt)yLbAt zT>kEZ=PgQo`cfIBxn1w(z7t;Aw(_?TqX>CAKw75yHs3z)2WipTac)dF1m>z08+jww zL=m6q=hznLvNd?c72%i;h586c6Hn@Exu{gYQq)`(A>Wnd!4KST6Uc%NdZVKZx=I4% zsE(t@GND26nqzZ%&Z()f)QGKkcNbC_`hUv-_u=+mkCQV2Y_`Q@wtHzmrPhE>>>TX7 zKi~aPj-76>KxVeNMn1wpusG}qQBuEUrKfP6-ymg88UaeV&aMZF!wn~GO-5V=jyCK8 ziJ9z9Of(JGQXDzQ%-hlv*W<4|4VU~boZZWgHq)k0+ETUaaASj4eGXG6&~xo-cr57{ zb@f~>)D?ANqKN6SdiSRG0bRq+sMd?oXXQ~6`SsYFn(qbtb*cLYXl_2CUc0#tu3lFI zcUytx_OAVZIv?Y_o|$bNa|$hEvk$7(|HO2`+>r)MRq!=MK$D6bv1+)ZFgV#fF{rM3 z-fO5`cN%?M-Qi4H>EO z@w#N&I2<`01~~2IH?&On&a0*vEI_a#+FcsiCDF>UUwq?#IqTZTQL|6ipP$SBaJr=$ zN4xyly@f1Sgx^Jx5Lz;0_a2jYwVqJD$VC1wS?*#Z`g@oCDQD!1jguOq=QXCm*5+hW z`A?R_+hb$-3kP0JN%H9@oB0j$uhfF(ne536c14Q(WGnfhZyk#O{xivgm< za+w*2_Ojgr-H!_D%o{nfmt7OWbrdF0CA$3iGC>3DxSPJog)Aufn4|5t9c>2A z2bW{Ce+@}#H^UYU3CW0JR@BL4Or;ByZbjg|dHn;6QRiUU4rHs%vw56AaarI;YV}_u zIk!IFVsphkNXS*C4{}G-D(u#^e04`~al^hmI4&(!3oMI@AGX~=OkS^GPpUw%KFxpxWGXU=J!%qj(A1(o+M?vxP|xZ8HMO(-U*xo&SBOyNr` zB*raKHbRGph%uffh*&9C`5Kr{8Px^%_o9d!7zLv%#Qs*!8;m~s>&6>2gep#df}iv2 zNe1yr8D8%2vCz=p4acLYPk4l36ZNNkD)?;JS?kX|@18I#Q*QhApW+)HEtE^(vYg$w zQExV&RxhNn*{sjj`$|BEIc&7w{-&t7=izJg zPcG@rT{dVACrs@fYr^e17brRcp(u0|2;nEeR_!CX;=ox>W5m{V5CGb4W=@;qOj{_} zz8wU{JN1CO0+?j^*<;*8v=mD=jU`#VVs%`tQ*4>gVJ&8uzdF5UnW#r$>UD(9t*a|AV=1nDWKn>>TX}g_*hvuGJJpG=>zycVesf z-PenSL4K%Ix7ff$6r3co(HSt9wZ2nybDWp6`Yq0aFkNHxHj)R6&yhVi-2aPleP(-o z_~W)8i^-f+saf}DuepaKzwGNhDa3OQ*@XGfTVXawW6KkhU!Q#ltfyw|z`5X^Jy%b_ zG+&4KtjTR5mr8NIIPS1&jPY{s`UM9|$PR4gKT_5s_tuRam<#5p5Q8mc8f1{4c2Zc4 z24pTrxP~SWc0kC@U2+ApBqh;wS8Uac~R>AMCYG7{j{XsU0Ax z@HV@$KipGo1vi;&jb$o~o7H>gk@ipZR_++^hW}!-lp_W0ORcm%uJ--m<6Ms=q7xkG zxEnVPeu;H(0xnoOJ83k8HIc;~)w*HOs@)0K>#amkCo7Ud53q{oLCw2C!EszcovKBm zjKYc(9boQ28k-}dq|qk1BMJ-t&dz!33B?hW>gaMV3v)T=EHNMHl)XC07VDK3N`Sxp zqnI?xMMR|(;b!bPr+6ptJOQEJz#C)O^`^+kPA>>VS9^93m8CR2aG?>UIS6)z&oSVy zL(aLe@wb9PVWKVkK}?XoB6g~g02 z#2wMk6|TSy?_eqR9ZN|z-AKv=3JTD2j~^gIq= zqEWIje}x0=Ho!7x!@!Ut>4JwXxL{S0SY+J|NXfo%LZa_XUM~c_yRVV=U1B|)4frcm zYmQ3hV{2C;cK6AIy*Ob5Wk555ZM3N-4}oyNf-X& zP7*+oYfm9mL$;nat2nv3T7cEE9kzP*TWy!v{jhS~QdGo$uLgk6`QM2K3_NmJk{ta< zoISTQ=Axt_^LluPmgm4prEBP`Lj*+QeYNAikCG0!sMYia6V9s*+ngNMoKwIjP4$mG zJ~m{4q<$R7YLny%0+odm=ffGY$?Rq(d_~!?M?AiZlc>*o+? zT9tc`9%rTwayk7`5Fa-B;tU;2dE!^x#O91qR&uWq230!u*txY$>Fn0D_jJQ#JO1u7 zEwuBp)u?sb1JSnDoZ&Beu7+=z-^9*Y4QbPSS$!9UOHp4$hv}YGOFUPC#9=>CsNCom zuyog3o??$exlf<8b$8EKJc_BC)>87?q^FSD0xGX41E{>J36K!y@Er`ced!3+N7wq@ z7&JwU)pI%vkp_J(I&Xf@*gl@>Jb4utKjjOCp&0Pf8?|(!%IG@qc(2PEtq!2KYU@oM zqddoFm#l%}ZjTi1$YMXFjFLz)96w(YxB?vHfAt{S!!z=6aJnw z6tHXLdoD_q+y{ju#Q2*_GHEFcxGxSP#B^mklskz9`g>&6BAHO-I`E>^+&>3y&ih6y zee?xuXrx~qvroQw`g)(;9O^whzIeXk{*Kx0!2Z!gE98_jwfB~ zV5#=d0K?Y-WXVyR3tRQWP$(1+X%C(e?qzy(vOWzVV7SnFRcK|Zzo-&+Kn0itg<~4( z13wk2H)sBgp*+{UPhkdZkLJiMgcN;bsQZUww9GB*03oYy7J=*ev_!054$JeHfDlA; zN2~)DY#H*hc$xSNsJqT%_bUdk=WpicvUFSH_`SD`9_dla4A}KOLd&JK+LBA30-;6@ zy%F@Du(q}RLZ{{IifUSb(m0vB;d#wycZ5Re%d)WCbbV=tQ=bcS&sMpZd=ZtpI|8+q zAG6Xl`{{H^tqb%D7;#4pEUCVbhS?rk$Ya>_OvEB%Xfj!IVDP&09qG34kC+?Oy?fKq zx7~VW?sOWRj`4Q|W>e|hBRddwci;G&-$(yN?UoxTn00>g*^d9g1l-zNZ)m?5=B2Ye z1EPm_cwrI_V;3paS5A!P^3nJyEg=J5f>DHs7$8+n^N&>d+69}11r{6OUPxxc-vX(g z0yL@p$AvH!#&FDzStp_QE&BN*`gtqr>x773?rq|!%_jl06Q&qgRiOnl{M$F`EL$RdqRj* zw>~u@;ThX}w@C@y4X@DH@gz?J0)ViDZ;y9eoFtL*gI~@~;lBD{r&=-^$Xi|usl(?P zt95hF2~(m!41!V-qgV({{)NAxtlx-Hd#I(b=I*^>f=iIT%@No4 z#de|9=My9vM&m&oo2dD~NEnMFZ;n+%eBDqJSE1jG_Y^11c{+)P!WdcOl`9~b;I6v6 zU(Plin!&7 z5ec`#qD;8F6U(ryxAi|X8VSmo*Au9<%{A%IFPLBRnTYXugjQO+JhIISEXM)!c0iLA z9w>f>bmPR62h*R(FKm)ZtiV)2;WmmQbT24rP!Jgm7{WNr zy3_AJf2%fqgz<6&4kE~;1xe%BFD;aa77-#>SL^o-My(PL|DZ5tGFHC`FinXbw`Rjx zsRU#OGd@s%Sd;mZMW!aOh7LBLXU*>0OT{B2M1Qq5N06Y8H{eSGNckB{p<{a*N3b~) zR}8%AgPGAM0xAe-ci06S-qwZxaHKGj1xkd)Exf#7qWRQ+NxqG@zVhsptoF=RZ+Uep zwV~hTAPB@vLi7=k8^NjWC#Y-`0MSh`#c}Q)5{suM#pt3cP+ftYlUVDEYxZn3%yB{rgAP4czlCa$Xt>Ae6f=k>d>`~qtIdyjku&c zYwUp3UQ@7MX-Q}7pP6Civd0b+!U;5DcVQgh)@ihaXNh4nyJ|$n@r(}YsztZS6;f@^XDPeVbaer<( zijy!q9n23!@pR|l2#+gGQo5azB)5CJouF90aOCYkH)uYYZ00GYt-S1X-9@4* zLI{v)w)nkmjnEe2ehu7F2u-v^CuTA`0uli^k)gp0(e0`hC~lUJy&v``h{rU&&OS94 z^DI%iI%)OB*nwy5gg~r49DT2PmxhS|(CEl75WCbn@n4GxSynyQwJz&$16-RYyN{&4 z(PwZK`6`MFuFvv+9#gY}VBc=T;ToDyZUy5H`7D@uZ^+t~ViEj=6|Nk@x0S$qRn0?A&d zerI|Jb+5GTejy<|Flk_&2M7H)-KZlTlY+tQV$5R64EuVLe1qSV>|BJ?>xdV@H}}&X z2lgTJe4q-rXo6+4MdSy( z`()gP|NhZ}5_LpWL zj-$?XuJ)vM`>?j;^U)#?p?4WIt^GwQe4p86@wvxl``=eP%m|ikKSqNYa84bfKKS`X#2A2YSpC*WgQoPG@buprk+&6JNW`{_GHP=TFjb^mC!7NKb*>+n z$MaP88uOAI;#mggc1tuJyBZSwEjf5zv5ry-skf6<$A*b=|0DU|@{iFDE1$Il+OZd| zrU&Az@s#a6V^Rh%QkHsi_OK>OzP5qic#Zwgs^H(WkF>@KJOmrAJ;kaOP|iCVL!|-- z7_8-m_zEd<2Fq~}=VIK7x2elAg%f~f819hu7!ms`o8d-)1B*-N*jzY4!(n$!(|s5IEogW{DeKh)4! z#yG{t!x3g&wE=A47p}+wF#<7D=Kk>UjqnC!^?KT`yOM17%`Lj1HN(D>N->t7_Z*W^ zDxkATLfTndd~Rq!s2tqfgqVpQSs9Bz!760j_cRKe9fW*Otz30x6mu_t-WKU1b}!^6 zAlVI*M)ml5{%;1~8%CH+(Tnl^wf5>4T9Gzh&J>9@t~}JKIYDGtvoWmL`)98bV)mMQ zxxDFDoOsEr$vrrm?N@k?-xNp6`t<#XInmRZuIn||1i{b;>7^fEzB!nJ$mvy(Bx0Nu z&+YTf?=AD`$YkmtG1b+r9`7iqsK4~(5rbhkh+2woI7Z`qx(pT*6S71U9ifkn7C}{N zq<`->aEqD_)>uPsMZn^iuu>(f=2vc|#PGqdxKNP+gYm`>^XyblCLG`1RLn=s$)3T= zWjQ8M=82#ah(;~G&E@4GPZN7^9EQg%bna(l*EyWO(oz%i84nbKCuJA*cKuS}@?(iX z$|IWrVz&(j25|AGeV?H#7o-19_@(u)a8>f_;d}qahr6=Hp0Psey0`q6yAu1&Y7I-a zjh(pPaokzn(;eOB6wB)Y58AC3KLiG=?h(1yOW)(Fl0qkHfUiIKYDuybp4}#Tmg150 z*orquk&?cGo~-=8=d9e_B|Vk4v=1_`r}Ayv<7DyckEJz}D0G>2t6k@pt;6Z3a=X#= zR@eXWP{_PJ)BC7Ra2IZKA_=iujw!HfAyU4%4$P4PqvO#W`aeL`e#?@j2eLGp&sgj= zJQMDWx6g4^M@qLN$IFH1Gc}O0Is*uG@Qo+(nR}8jy9TL}D5f1+dx8 z*$)vp3`Xi|Am+*T!K5{aTbIj!m@L6sfNJyQB!@{O)Zs=& zM8z4MPU)SiH$$I$-Xnr;w{L~h2zZC2jn+a1wLdy3K~ln9kw)^#zj{Hg&a*u!&Feoi ztSZ<1;$2CwnQcc28~2x;##_*G;OvC{pW3eSEvheCE25-yr;-vPISj2xr+|QTmvjv= zFqCvjcL{>P5Yi>xozgvY3|&Lq!QXQ~+&|!+5A&Q@`(5i@>)rdDndi*jzN{IIKP+)Q z;xJa!6E5@5H}2OG7F42F=C!EZ2S9Or^JT`>Uof`ozs=H9u6UXdt3~A#e6GaM?;;#4 z7C8cqI$aLF`~d|>K=XPePiC?^bF=xQu7}JnzlC8Z3&n;-^K*JFeJF*089b*JP>@2M z$Vm;-aW>}?QjEr^Yc>}u(22+$o00U&e(*o|P3-;<{z=X;Nl{11aEM6x)^oHF^r|6z zxk6CKX)h^T=J{7Nk^+T+%_{p(f#UgbgUK}*xvvK#CdwG5!doOmQxz>|e>^qQgI79U zxo}WolVe%4?S!GkmudzV*$PhlR{z-FuxPPa8&(n zS}s|GruW%5#tX4o@Ha)QPsn;LMp{s=W0-r#BV&?nG*jnQR=Q9L7M$HPMYP8Fsy+`~JW$8ccQy5tLf4KX|i4 zG&F3h6GI-^YXBNsn@=QHlN#+^2keb9 z&2Q`*R-3j$d|=yzK!TBT2Qub3C`&bI2K(o;*({p2nb!5Ch|zCP7P1gB&GlTzd?Z}T znvCMciP^tCr$`*oPRtyc&58ap)qrVN+hV?yE>R^ganj6+X(61>Gh+5Muo#P??LLl+>Iu4aBWFEYRwiqIK_eu`M;0NK$e%x5Fx1>M$L?FpcEx*>NgMopE7_sf@wm@~ zVBDCarMCBY7Vz838NI`Rgh<@j$2kw)Du8oz`tsX^noBc1@ycT5d8*pLiuetH}D(0bC& zX}w*IqmPqXkz$RH`GocFRKRG#9Kinuy4CZVu>!uZH5IQV{a`4rKvh_`vYyl(HhKem zp}BG}sQbZV?%UCDoR_#hZ-#5U7*i)+5+;WzB(KC$+PlY4dryD z%{xI7wW|LN9r*wTR_1iRz+oTAU%9qg$T;Ry6udB)n^Fv<_V)Nl`MePxd@GR<;f0d)ny5+pc74vQuWO_))v?Tv4T?DX&NfQS0P zR4{fctKnjMaLTcgOJC$_2uFbY!Kh!99~{eyQb{@kY`4Ut=23SVx8T`U<^k%m=V~YR zx;y44a{C&shgkYlmEljLd3v@C3EtsE*E}@rGI8TE@&0azdcT!^Per=L(EKo^09V$N z97q8T!`&BYvwr}SP+9`=pdf}G>VBK`Hf(Yb>$6jU^Y|dK4Z~!W@{u#?QLvDzd~Wig9ig{&oTM|FF$h zGHH|Mvu)nu)+0+e57!OG2;3V6@GY>Ix+*YHVNTfZCSZ@SnqDuS&L|SFNLXZ%9THW% zSQwP-m^)gYUiPbbu_|K7)b-m9ws964vlv?SNl<^nx16#A)1&2M&LESUC{h_D)`oYKTjcS`2*pt-t1IMB0Ai-jx`&eY9B9;O25islSjT^O~#| z9Az~j#J!csrQl&jN1&)Q*K;Kpo2` zIK?}84PwK~mISs&@MUIwl}*;0g2c@98DbK<8Hw$>EQi}A{>DJF6y7f;C0rRoqluDG zZ9B)~k*|r(l31VB;2*CX>JII9ZE4GLlNBny#(G&}y;DkWb-n;*dY|Jcsmx6J9;T1t z-{0;~KtshkY$f*WZ5Pg{seE2A^Vc{pgtufrc-!3{bz$AoHJfK^NqGq5)d+Fiz^m2B zU1sH^IJEoX)r#zQ<3DM?Fg2GmR;|wTJ9Aoy9sZVQ`;hCS0NNX10r(;|E>bZH>r0uC{fNiV7 zFrCm`2%7$le-+lDX`Q1DhrMw5fiT~?9?D$-=c;?p(~mia%r2eo9@N$WgaRYPck#1x zb)mK(i=c~Aq2!80(f->qHFA$>Wk>!dz6_b-lQVkz%C(Y_y0x1z?1Klj^d%s6$14I* z`LWo;s+;rS#(>WjRp($nfBo@3|E0&-vAoiT;}idHTD=J$H77U<58;)KBO-p+2eF(H z+kf{P@100Vo1OU&z&IOu3StJ$CvKvAsymY$_(Gq=jjm~yyev`VX|0q;UNgsNH$_Pz zmztE>lgQ_%;h*<7cB9uif%Rb@WXfG{R@+?J(=@Z<7FuC^CQb@CecgTOoqyx@R)P*- zAm%2r^|moO4v`^iVZy^k;v+7*=~(4RS*L)8#e8OB)lzlFr$}nw`W@c!)IpacUeILqz&@N_Je&;I zMclunN=1xOmC}d?uPakaF>l#v*w2Xn(#PK(p{WbG$K8K8L*qSldNC;cKFw#Lz5d~s zSR?JL8{v&4iKEiGIDGGX>Z52Qp5f;)tgOg&Konof7h?_t0p0GL@Lnh-g^$oSC9A;) znR`fc6wx#@tj|*2aSA9NR%b1o(23$#qU)^UU`5lo&w}@;?S$aeb|vf#1wa|``5Lga zkFo+;FCk*bxlG3(C;E4cMb=WFy1E#*m;VEGRZA}Aap!~TW%=pOC?M)YuaR=rVMT~! z(VHv$P&8At!0m_P6bF)tzj|@@NWL_O(I~xZbia<*yLFP~fn#r_f`C7fKBadQY(z}M zY2xXm=!q(yYa*u$qF`|4^YwH!?qmDAr z7q$!48t!c@~3<%R~(uCk;v7NX0hyL+z0?$ii-a$j&xAxpvFrT!xT$C_YAgUXa zuZ{2_>`u#B?Fp4FuaEaJB4oKZ->gMibHJL<^{&rNrpKzkk_re2fLpwR)k_Y8^GK6) z7s6M@cNQv5DIr3w7e6nLdSm50b=~oq|Dt{xSsqk$7Bf~pzv+n|TNhw6;=@%+9r3w9 z{rzlSOVEt9?0Q@P-u`;bV>RI2MtfAF$4)m3`>ckaZNf>{?OyV6VVkdtiKPh{=M{H1 zVA4DTd6=7P3yl=!y_X|Nj9j@#_=tNM9ZYj4jelGeyL>ZsT!s8E=GAdT+4*^8rYu>g z$kom}NLxtUyrE5Ik7Pg(H=-Esc7T56)H_=qp;enclA_t4|wc1GkR&#w6&~fE%19|+{bYk?#o7Lmr3GdW$DOek4!C$8dRpi zwvUJMbr0+AWrr!8UW<61WrHsdu_=MEX{8WaUS7e7)@;NiJs3X+=JIe9^`Fkxdy{sx zNqvQGK>AYE_hxj#LRY)@lHy~|Gayv!2{Em+ZdO3x?$is6gmpF?iSVM}TcQnk8~2)Q zQ7Jl2X|IKw3S((J_FOU^P9XY}SD?;CHNvg~4XS;4AZny$jaS;sT>Fc_LQc?DA;7_X zd8@CK^=G*QuHUW?<&tUBFy2P9YEe7>z6|kC$)Bjl-?t0zrb`tWi0H`P6ae1^QyDNi z?vOJ~ad<@xb)$8UZ%u}K9L?L$%o$St5PGF?J*!q1PF-*IJ^4dX&loI)E{%p}q%|kb zqo~TiW5&>~r9odHtnC1FaxY(ItbPDzT|T}JF8*@8Q|sEE|Lpqk6!&_%#f<=d92*~h zcBk2d)#0$Fq`*)J(-FO2$%5)uCM=A@u^=YMC(WWGIE(3Yl&+jC1O1emg_hpLoCJEH zKWcw?-#}lRCb_h>lGKuM^wApmTih+l*DJAyG0=R(3bFj?AeSwqa#o#`d zE`I|(52<%IGb7;7{5}H2GXnYMX3CJ3OjAr9=0}c(lrErN-UZ5d+KDI#%3-rR3tY z&m!Wv20b0oG<6YgOm6PDJbCyl$?vWVQUtR%qu2BosvMXklNN`|f?Fu`UaTl=Od7m3 zyx*`bnm>be2lB8$mp6l+54H&N&27%ty4zc$3soc=WT>^Yev~ftbQY++1kt3Or&E|D zsKVHa2)qA)=cDUgf6A&B_(gcTF!4V$%`Yj>S~t1PYb`kW0-D??Efy;XcnS%h^(a|v z4v7`$&^Kbp;bU2_)0GXV>+*FjwvKEAG@9;lNl2NUcBTSkT?M%W?_#fc_<0U%TQ-?} zLhswGY22Phzimh?ErnZyw5%DsQNZ{$S+Maqi$pr`ha6byjt{7(_ z0vKD#o!f(wu#q5PlU5|sI9p$^^g9zuulNS^7>7y@|F<>_cVVC7wZ zPOH61-2#KTXd2cT*#?me9o=G{!41GD+}jMVqK zcg*E1worl-P?F(l=ULt;$J@ZG0O*eUAm>7!d3Dz?TOOwc_)K&ogr1QnjzN?6Q}yje zd{Xdl!dM5P46-0uIdh%%>oF6|Ck}W?<58f#((REZp3p%k`t>!qcm6jTWu1!#!ChJE zz&-}GY}9nIV-*i3Cgy%bL;!Lwl}N0Wfn$7yeg?Ms7(Uylv;v|s14_9$j!Lab0nxk_ z6BIXgChT&~r!3(@GwT%1kkI0sd;7JvN~x3Kdi#_6gFl7rqVbS^QK+iZ>76@IrRBbK z{_Cd(IL(FxDO!EJpk(z_d5D&7u9&r=Exn@2-T7Cg`akJHnk|E7J7`Pw3flDH<4q{H znb>oT))qr1A5>KEY3#)nMs~{HFpG%%Ib_}mE>zAE{j5PVNU|Zu4gyNdfL2`iD&hyD z@GzwvfO;?QSM8+5ztJNE?mv~yqiTr+I~dDh7?fRyd0k|@CT3siE5*b;sGM*hd`|BJ@Lc&pdYFur7H)UXlTXz4a92#z$4B8@cRp z`PV7*+%@gh%?v7ma zAeKxN+<~Ei2|?Uw2BxEs|4Xm$TpZm5Ir zr@UB)&h|~j+sc>eNYqssyosIH!znJPU{<8xo>-EX>CL_13`LK!%Ab~68QFzF%dDLS zsv#pSDZU%eK5a7~t;RzP>mrIi324k1Fk2QG5i${B5Jr8G*x$Y+7DW1zk+=Mf(eyJx zXm(d;$#Uj%+A^!D{1|u5O`@81ePy8M8zh9&S;)+5sDoLqk(w>Q2wG0dM)k^A=k4T2 zd}Q_#tG}$Asoxuek|3bg^2Z2F-F7Uz5?bm~z)X++8&dM?M@-rrp6AG{_!_B2Zrc>d z`L8LcWgCYO{s15myivX zG<274RnhHD5AuA~x1~SebtkKUaN1ZBtaPYMc3)-M*Q*k7@Kv3BQ5O2;O8X~jD<9TMew!R+i!p{&?g;9(|lVRgG?nGx&&Rlr^m@}udMp&#o z8Q)hTyIO>do22otuFc{U99(^#^=xwik0ES@-~w_m<3A^LJNaqbvz5$c%{Q!+0Gn1? zUQP$z`XEAIKSIvbvN)uTqG>9V*v(ipnGguto=^J=8Lp~`p6n^8!(jZZm1Sr?^0V@^ z`F_GbwkDk+_z`4WD89;W!d+X951mo5d^4bl#J^_Dm+SEtNj$$MT!GkAeR%lzF`E;{ z^1Gg<`tK1ii2G0kQ}tykd6n{^G*Y3d^%h9abkbRIbNA^NeRr~)^W*Va^7=d(P+0R> zFIA|KrBx=T4+X>2pS=E_Xt!;k|Dp?xCneAY)K;Eu-9nOMd(-wwW ze!v;D+UephbMEp}dh>be>|v2>94#*~UK2 zd{$TQI@^h^@Va)Qi|LIy{P8Fsm7h7~EJT=G;0vlmb}czhGMgG}wtpz}rN7GyH-DGv zH5M_&KAL{8WVb2ufT|U2X2_Pd$|Ai-{dyLtIIRzUnIY(^y7F)6kU+ysV^U1*#b8d) z{$~S8PNe_{P*GGMAwqa1)jF7x2Bx5QquYv)0O)~=kE;7UrQOc-=^ujhd$o@I{(iZ9 zWuJJT+x54^FG-R7W~1|i6cf!J!z6hJiRE+0UhV>M(L^I3=5{oVCkXSO+9T;gRCnw2#)+F+Y>_3ZG zXKcJb|D!&%d%b1wix41?p|m*!u27nRIRKOS;eFob{|Lfhro+O{$%#{sA$`^cVefV z6FUOarxhody?vEaA33*w-3)HiBOh&~tB_@Q&yLHQOUjVb%g7SsN!S--EfPozZE%O* z#0E;yGv+5k1Z@@m-^6q17Ze|}xw-CC&p8CdmOf3an4FPAHi;tUNWw}Q=u|1zZF_jO X`^9hk++0EM338K_RFWtaH}wA>mqAHD literal 0 HcmV?d00001 diff --git a/docs/source/images/nanogpt-curves-orig.png b/docs/source/images/nanogpt-curves-orig.png new file mode 100644 index 0000000000000000000000000000000000000000..75f9ecebbf578d88a144d7e60f3ac728ee66d826 GIT binary patch literal 119671 zcmc$`WmJ`G_&&%H6-7cM6$B+k8j*&rh$5|^bV@7Tji3@z5`yHG7U`DU0@5WQ-Q6JF zF!#$jzyHjd_%v(Q>`w~p&HL1KUvWS9J(H8b$Gw4zg@uJL`BYp13+sF=7S=hHOBdlQ zOg=~ASpQ&QNs2#w;TXF-;@}uFGjOuLwrSiFVbn`~`5~n2+hnFEsM$gAK!v@Ya$z#zQ>VCeyPxEut zvQ)BoW&CG~maq~-I6I9=N`?YQ%;i-)Y2g*dMuJOp>3psVt4e0ExXYZ&1tC3)E+{f&9VgsB?> z#A%}!+2DgE8_Y|H1D`M*xzmMQ;vJf1?IAwnj@Ub#P*^>*RLaq6bK4z}U?me;9D0)UoW0_&n2aK^=VNj4>G|!Q z9d1EELFQevlA(3y%FCC{doxvQwNEz4pFDXIbWlu3N2lSsUe~{}Bl%4xl@6ID$-5-D z>%WNQ$9E5zwh#3U4Xr0WkWbnQ!hb5XtgN`qOJWid5|YZ|T4oY|{7^9(gXf?A_Cf6S z$##xMXJ@CC3{z!Lc;C{@(PjtJ{rh%ZRSljbWMq@}<=1cC6#0H#PA)AhF_AW|Rhs*W zva+%;3kyr9ov&{r$>EyEskEKl_H%r)pq|9X{wjH|9@mJ)zR9VKcuH^GKehT}iXc-d z&Ob^jLpf(Q!H1fb%-^bsA<}X-%1CZOr?9-7GScN$r2Iv1pT^eK7skxAq3^X^*B{%j z4omXfh}&9wsjAxak(A{cul;_jk!bDDpFeej51ZuO#njXWNnX&qWNMa2NKiYy4+$Y+ zFPo=r7d>2Nj#KUJ?tWLh(1piD6rz8hhesTC=G>;N8bzV~idtED`NHZdBX#f+{@(Ml zwW>p7Q`2D%65->m?%gYf$jWSXpfFj}@8_uRF$m|Nfm$aCE}*4nIFZTU*<5f1V!Z zr^#rA)8Wmf-NS)Ay#{yej`T0p_L&4P*6IT(nORu6H=h8!@Wiw|k&|mDYP;(+^I6$f zoR(FqV$b*7-N2k6`8RJ^gdA3fEp}W~RR`RsQsd*PVLeK;78e(fsfQxi z)W6&_3Obj@+;?z^&u!_sV@qG@La3LxG94|q+5P+X*Ed=?g3(1R%*>rSMjtDxa{hQm z)EF5UJXci}5ZyC*{hIdbB^i@AiH!XGYu!A%*K^r4%U*P3G!;+qY3!Ecv;Np%(kQP9 z+O}k&2;OldA`*B7BQzcU5n{<7&@h!yQK8`$l-d|yrQjm=WqEn|xc-(XI(UC`F1W+; zwTX$q;>}NKgTtjZyI0udh1bOac*6U*--8beYyB=Q8{yim&;02nffK<K3=oXI zC(B6@+v`Sa;QS^0=fX~~TR`ADp@oZ%i$boMy3`i?2eKLpx$#4&9ZFvI<4a?zc_HZG zy!h-Cnu(p@xO3+R)_=?IZR_Em$I6$Vs73$yIOF8fd%bw0QFd(d`Djg_McI3=N{%s3 zv>j^ye(?SCZ%x}zmLr|CXRO-a;g%Gvu4?}s+SUyCys0Z{)nmP(RkF#o6 zbU$_z=RsR?bfpcOP4G*-D{Ki%PTg;=giu&H zTqbwlQ><_c!D%^VYRk`F*0*`uOjRpmVH1$!(VhYvUHaZH-Cu+Gv40tn!>l< zDl)dI{B(|6j?H#sP!~sVRssH{D>s&w?c4Z#VP%DZj&AD92_sF18B@Eg3a6HL-Y9*aoV~ z=GeKn!7JdLZAJ5e=Pf1zY@Ym`CMMPGnkyB5j~|}PezBZ)<@WM8jzwjF_KH_#U7A&B zNhH2UX%y4iA#vaEDL-GlG4_Gu%m9;9s6ux6!Zq9-4eG(VY@8b=4jj|(m)VP}2CKt5 zGE-*~jDEOxa*4iOJuKa+p>%ZqDYsgq_@yJ`Xl9nFbj>01@1}%%^W!KNg2StGZn~s& z={Dcg7E*Q>jGQ*A%YQ6QdCY_?Me2;XwNheoKaSrH5^xounGMu--1)h?9`uA4QyA~v z`4H=1v9F-2r)|g=`{Ky1N4S607Rhn3LPXqW**@Zqztit_*MAenPn1L&>i%vVWI|64 z^l?qD&jbwU%D>e1O zUd)e)1yR!!ChLP$J?gSUb5ZNHW7i3KLf-9m(va({B)EBoZ8sDY6$OB_1T)(LW73_f z56o1_oBVX_o!T`Oj1nDvdT2@Qvdk61t|^#xE=_#c{`>iN58Y2T(FHbha*3UHE?l^v z;V@!{WBO5afN@5CFSs|y=vYKq0gFV#iEycJtV7f+_MB?Pp*-b4vE4oOHAnJe*k?la z1N@ui8LEeWM~2;ZYcn;E?c`E}#*X)h2uknd^vN$M$Lmd9A>54g(bkzdrb*;iYjE4G za17{iu<2_sA7AjimR;&W$06*>=YXn+2=q9-%AgticYkRujpnVmas+5#=WXefHW_9#!MI{W(@iRFbKIygAgq^hk&-QR4p za1R3^K$o3VS{h~Hz8O=o)g|5SjuSedv%T2E!^<0VddkDcw+Z!Ie#p$5xOvBr-A(Uq zQU|K}dnZNbh%)MM{;oHFpo3<#r@fW~YUY|G0(xCHB$LU6i zX`p!2C$wpYh{kSD+I;&V&B>2h0qT=l8ppjl{G*jYCp^)8gDO`H+LDJL!+c^W!};6M zFh*r}PrGy~M?T-Yr%%PiVmh1VC@|jg{oWo|N3MI0p4^z5j03-y_ib8k#qI6w_11eG z&(zhWd2Wat;S(gln1F?U3*#OUVLlZLXCJinyH&8wj9)fb zs@%@Ng@K`}!60@|EH#YO0SiRFrd7Gpq^{=(*09 zNxu@Y-l!HiUuvLapJJ1{5^(`NUFDHA>B$i<-W+f8u$VNqu<@5^x~`Eh?`@CHjMFh1 zCzC(5LRD5GFU^NmtCRgx(1%py$se3fatOOA{jL!%$FEfFCN;XDFwO_6@y827>w|0* zp0^J1ZR&%a$VJN^QzrN}29O0e5R0-$-@)v+S^QWj9{QG@ExO310Nm4I*b1j)#6ITW z4P(f!^Kwc{qR0NGCQc*=2TnV^9acR}4KdW)<=V`n+4xdh z`bD04%B^Cri`Tcde#|vC*kf7G)pQplFY@O-3$;$}>9ettaK6}=Sk+3=2hXSJxm$oi zu~+TAl9HAVcKL1w0t@&YS5#C~_xHZHGAI`1?BwKR{UaxZ7cYM08?WVujr~i}mzftzRE!S_Edo754@UlHM~} zcJ=qJ;QyLp$~FDaBzpR?*oNOG{X$E&tf_CEujXGXQJ%5743u#23^^u-<8=O?_zdcr z_%Qj0T7veqZ?A~9``p#6IH={bZmtdjNuR*Y&CL(wHvQW-{#A_Gg9Vly zME2Z{w*E1%L z;MiKj(zbQ|bt9^mA`Lb}=^$qJkLFMvpo0ABUa^AKkgX+Lz^^#HS$Z^OaXLLPzwb!n z_K~__pkl=phwRGL(PqNY6YAB>?fhxYwaMi8f(+@9mUOL*cmNB(1N&dT(H#Z)iG70ov{l-Dmee$QiQ_Y3M&ILK~IZ>vg5 z+>o<;q=aw<8NdUU|SLX*sALb;LAZgkEVk9hs1ltHkMO|4I*Yk_?aDYxPo1f- zqcI#*wHx|~^K%CE^q|<(rC@Q{6{0#jbg2S=dASSy@Nv%V$9RLkcMsttKgdGTcIyK8OocB89Ce* zqZLW!=iG%$;-t(=T-5sbMp#{zE=&&}@}2B2J9Eb`++LAgn-%uIKx}+?+_$e;{K zXjQfA{b${9DEsBuF#KsPs!gok3SE01T=nJa-3G1Rek~d>7YI` zrc{c3dUH*Lh;C5AVT=fvLa^)B&t0p!$=cBd<`(&1uPT0yt**F_g#IwMU!EV`9UqPF z6B%W4d}SBDml!6na;cdFv>bEpGB-Q(^%+!&`v<0>uF~<9v{!2N-Zsa_?;Z#al=5n| zJad^}iR-`C{$rk1flb}%NEg*P|AsE$Z!cz3#f)U_x`lp1UCk6)qsGfA(%PuuP6+>7 z|4Bag=-pW3I*R246!l!ENJAm!go3hOhtTa;QKmN7g5J>;Kbh%n68^57KvfLq``53& zqjBC~_RkeZ&qTcOs9n@#t2o+{FWreHT;IJ+ed{LC&|}y2DgW`5HOr<|XEpn**|bv| z2F#lc%T;{zto~bnjyYLs$At#v^*|r3)RdGDni(FoftVq&;ev3|&44_;E&`j@35_G_ z(Qw0M4;l2^Ul{<}^b?-TZG?8~Hj4;l*d6gdhdA6 zqXXP7m(|K$ts{4j#FkK1gf0zR*5Ye;9B%`(oOTk}`4oL>852c6Q^eN%U@BJUx z+S;-xwD#tKo)~x5t~pUkIX%RZ*Vk-#sXe4fsG^agu{AQbQQ)u=?`_~h zS0cK8FLSNZ6Xmv}usYzTu0HYDb=OnfDe2~T;PAoTM34a95M}B1I{KB{*vYO{%Q0&F zNJM&O)RB!QewQD0c-T<7w^paI*?6}qyvmnD*qWx8&^=cBWD8s4sC%TOO|6}2^-u); zol`3&QEQukec^;C5`A2hmleUL72$TWdvp|vC#y$J3pHo?Z-#4F6K%Rz#dVfwx7UocEMS>h{P-~Dk^Ld~Xs=_~vJUO5 zcROiO&vikBaPt_`o447+zkb;yPpPJ8S43=ZOYccIK zAfU!Xh#`q~7tBz~qNO6R88k@?2@6ZYcq%5tSwf#|rA_@p`ZaJqm@saD_}}pSwYFc_O@>0pskq>7q7& z*{QJ!+mRfTA@Xc7yy-E@HR$HpYC;N=Z;P^T$A!OgqW9ffY0Zxu*!ZZoHmF_Pg3U@S z+tt^S7rx$)Z^Qv_1!b6W4PnO07b zo4Oq?UHn#T&3?RaPs4fX)=kZKzZCNuWeW<{JYoWe`9^A{EXqFUSg=(76BfR#(Y5Rn z_g64c54_tP5dI*GrRc9E6c*a5nQw5Z2L=Q@)6^7kFU`y}#e~^`hq*jl8V1gOI{Vs} zm&(e)=++@UFl<_*?iul3kkiws62lEX(!mYh8>f5bc$sflqy6VY!aK9!k48NM3oFFB z2zK3ql2@^*ZE$H$+S@c7e^aL9{V>j%=+(t^++Hk72~l?Jm+)d^3#Q|j5jO;qBIDw6 z2)=YjTib=NNl8<_t3o8WvWKd?-Kx)@6XmIxV$c(w_Ps65&58N>))=g`Z{Lg+-Uw5F z{`~n%YN>pibR_%Be%lIDYU?Y)61K+GF{q9bsRQ47f(;in+q?Elw?~htCU&izu2|ko zom&Z^0*_e{+$TXHp*%Ih(+HQ2v28jtWzMjg;Dv8rzrGC#2{|^rYhYxw#&>aj-?bHu z^1v{dfMw9lOBD;|wtGi8XkwxmvqhtD8Gp@1M%jZ`qpo$6;>uQ+7-Fp+OD6hX4Wa)W z9sN<4!q&lIb8jkGesQ-MA_`wiJj6Vu6Sbdd{(R58q@=l7D!c07HMR5n?N@meks8*1 zhJ3Plg^$VCS0DRji;3&+Je8EJ=EQS$adBEqPh;G6ajG0Op>&opT66Au=zL7|6X+e} zCUx{m_tt^ewtQCBYN87ku4!z%M6PK|1WZG^G2631$TF1fSBff+Vg++U0=L>$t?Z}JKFDI6ofO?v9?{3TRS^lot<3i1r?QA9MTc20*TjO zgu13mKT9<->Js`8e6sla%9Y3Z_W46A6)R>IU}PF~*=Oett?ap0NtUkZeU6IqYniE8 z+Wt=d8R9pYJHT6GMz9)3LR)t0Vk#MGR;` z@$1*GZ_rRkS}(1x*3ZaA{XWFQBQ&|5?(0h*7ZE|$)!p3?Nh-CmxUf(?VQbkZ@^_#C;gc54e$Gm-hcEY&Hbp4*>{P3h$74Ys2utK{Zoe$`2ntIEx5NWo2c7(}pO)v8AREu!mEM ziax`CP4=UO<1V~=IQrL{4}wM~h!`8&+vRK}6`yp3_kpH@RC)Q{Pl69(UsC=8?TD#+9r%EGUw zL3x5;HCDiehXN%vGBQ$FRAg0inQ7k0A^e;t*6yea$IF*5Zv?a?U%L1XOC?Kb+1|w~ zBZ|MMsOW`~(p{{pL?OpV2YndBYOf2D4zI*KyO#hcz%cAIwmm8nxfe?ta*3Il`EsTl zl4j51UVJyDmpO3tVp?kID~gNjj>4TO$VL|uys#-q;Da69P|$%95_s=GVISm646jOqf3-CkMSAECHWMYvSEabm@aw!6a)eoWl?bYU)%B zh6ddcIj^8kwtjg0W5hgDi6!CLysrfx4-XGpq=R>@t_FMG-b`ob^NTn*X$D!{o$S`6u{n(2RxIJ_gyolngxhQ$B2v{lpB#ln&9uYP3gB&E=owpBAg8fq z3^=eoQ}nqsx=)djk@MH!sZ(3;_h~%3L-OiyXQkkKLlC{3g9Bg4Q+?Mb&-kfCC;3== zdV0F8EeSrw#7NuQ^WMC96Qaa-US8)aiG*JG;wInXBWAkS92 zK~Ky`g4;PTAPJxW+3C;zs12Clp_Td-dzy_Ms048lLbMPd=9!NOZftDKFE1xUoC;BA z@|#O8nT66LdEM-G4RYmwZVIg^I4pz52dck+#U5&6^b5-83sIrsJEo6@fB*h8uK>$& zZ@=xAvVP2RjoN=-F% zjcn3SV6c;{;gRj?>Y})DBQYVt6PB3eq7z`@dN$_78Gfnt`yFa|X47j#A=vw{&4b3) zjM}2_>l+y{mBlh66K>22{S*@43kykN+ER&$X(hg8W(Gt@-+04~3`u&Xr6p%(#sicgH{r>IKWHa%QynCyOH98oMa6r9!wXO?~^33k{N%{0(NBs;?zT{#8#^6 z^d%_yfom_)#wdx1oGO7><${$_Ghqopps+X*PLEgHpSGWI@ewVpHRssLPDzn;T;ua^ z-@Z{0#@@Jb;|t2e-HQbVzcAgaZqdVThpYi?nYsnr*YNaog-oXB;{Ff=2E;i|`iawq zBnyS4ojv2uI&6dI<$wIBQo6bnzGq}KxbQFR!g<~Kv0AnSD^HKJZFvo8Qc;11qCmYX!yW8wlK5FX+N2pD7-y&Da4cDsgVih&Uo?Sg$~6-oh>`akiKAYAMwJ_kR83|?|W%?#DNBQb_wQq<%dSLc;r0;5`G4vZk{F}j2RjJsi}ey3e$=bCA7;(EXPlsuEOcEEoa)duZEvJ9>Ju8x|9j;m$+bke-(pB49Hcz;}(D zC}etT)d52p-0(<+a;l=i5ybxV`Ey;I!>A-wPcIfBcjrx@SU#H( zB^8g6Ly;sR5(`BVsiH4l(3jD}NOP}nR6HyH)AR*C;Rcj&Z+yycSy=}A8#5Xp6I``Dz@RVWa(FXd`KCZm!`BIInb{D?dLU3oxp%u#o=84VaD~q<_e~V*SbX0kOc7 zy1F_zmWqmD_TQnIY_;mpyy~DQhnt`O3$p3us;O$&N``9yYM=`S$6?fIj%dWNOw7GYjlS+$4Y#NW^eDXD?HhCe#h=pGjE|3JUaCmz0`8rs zB7-W|gDO$;{bdGhTg2#ro-6UlLc-vNFsP~o85I?=Wo3e}@sQ@6OQ+xvL2-9o_t%#B zo)>T~tyZPbm7!NI{R}?e8_9VrfO z!1xXa2hL9(Sy;YQv&!x^SO{oEx^?f~z0Q-Vw6vkFazZOS05f`{>p>2=Cf%v)Eg(HW z1KpscBvwImgm3G41_p+ye-zEkvUuYGN|KX*E?^T_Gx}sXc;y_JxXVz-$0$fs<*d4oT549e-mK<>=0mM z!jd21bd5Z=OIan8)}=p=cZTu?O#Rp;^y@p<0ujV29x|sR^IwOzeF1LpFZ|qm@bV=o z7-cDNhN`NnYIn?#6DrGV+W+7VNf-C6kf~?j{XnUN2p?$swm317(_RPaX*}$ z-f`PYu+sYa`t)UZFi7U@84iyqpfu}%8PoGGU%UXskk{0_nWw^vjO^(-w^bW79SVg? zH{|nl=4ZHB%USQjUsPPYxLgE}X|z6Bm#I;zQ2yGF%16XAON|0XRrBGBz)E2^$OS)k z11S^wD1E%zXM{)v2x0y+Je(Qk3M*p!U;m^`E{LmjRWDdxH6`FLk`iYPm&;QCMt38% zoYM_!K2rH0PxzA=!6Z0@lF4hXb3zq!B5Mf<`vW|;!zM+*A9~u5k~Lv)dG!AqI zw!fsqIkEpf(`D?C8q{m8VF5!KNyzoAZ{K`DXhcUx!x2|ujS!E7%5~YYTbA#+hDJc& zlB_MWqLlPvXn;vH^d6oAfdqX4l9rF5{EU#0lPjvK`sq^mu|L*-GE|V6T~(!pu#r#| zUESN8a~(8e)jXngKbGIYCCVAM(q z3#0*X-}nb<`hr*3;^3Q{I>UlI$dlm3zu9_k|4*9e=NAuUY-|b+5@Dk${72~!y1+jd zYZH1G?^DxmREMi~W+p3MEz-o7gu!_9;&4K)HYL>4)7#sAa9tCqfezoWCH7pL12w1_ zWRV#9&WXWkfPOD7ym!VtQV@E=hCTz-R3gk+POc>59|SYp*fAT;ERL0V9$68FW^zUIPJ= z{JPcC;}k&SkqM>D=ZJ`F1U^ptGm(VOGCVgfe66gE!>8m8zouwHx%_ICuo zcza(2j{a0yy8BHVQr>G{7X@)V|4i36GSUVBy}!SY>k-^FK7I=rJ%~rymcHo)gpJS6 z8bSq3U1_RlIv7gEI3nxI&49GZU|H@A7;T0GyS{h89FBloq~4IV0b*B(X9D_#;`()1 z6`)xU$12)m_=3TofXxb>&-_$D%-Oci+z7$)YnXa&Y#m-!R+c+pnw*+?mq$PfbQ%y4 z=w~5~rUJF*rAyJhbKgm@KrIXE6+BAUG5&9sa0IzzzC%%r-CpVo-py92aO6rg=>k`I z=}>i9m6u;oFnU}HDX9@aD)sYJ5I83D?;ISAp0@_Z5na0YJ%fMZJ@FrY4MrJtbxPn$ z7C>>ahMnf(0O{I=H(QgV4SsiZaW6T{A{6*|?eqi@%iza=vSv;xwUlf&EC82t+Yr2& z9moi`cONqTL3#s3Pu7DAEGNVOE@AnOX8ACZA5$)XBLK`7m`ibS1W1x(+yNkeP7O-1 z?LHH|D|}KW660q>dQnXXG=<&?6e)$5FH`vHxOL^AWTTS2b6ix5uk@BQg*A^2NCf?S8(}Ab2=!xsq~8k_v4KYxaLnj98&ESJ z1G@U0=Yy^|+h^n9$uhN44+kYHHj~$EYUZ2pe99b4_jiT!Hk}18@-dYO15$n^ben$! zJDrC)M*XAA#qtq?fDffeJrvz6Q|B~qztFpLkp>iTkX0QG0*|)Gr zb@zrx>ZqF8Vi#h~c<{lV!ch`Ma{yy(l|sNe)K!?85WbM^OyPAdE!ikqDc@RdvSPEU z!Ov-pfnqhq%o4k!clcL#kM=|9Wq`&G1P!3j*Bcr5Hm)<@xkE9Q4va&tWf8vF8tlGB zhjh3C=lZu+(Nm;;$c!cCw^-#X{W@vr>dsPGtyWxo8WohFd<}Ub=-N>Jae>yMp?~@L^2%M2_#6E>O7m_};JkEr{Bg{vr_9o~-`lHZ3jm z!|%Krx1Rl^Y$q@pHu{iq42?cQom_2NGNWWqQXLB)j0U{(uLF*$Gr(3aHWw_z-$?coqvP(BP|2gFSB>& zHI{zopJ_5Z#b?jnngroj79l=WsOLE*fm{APj~9lex!yiAelI(ORvt$7x^=+ayLam& z)1lqzHR7JOI^)Jpzs0K+d;OaEikP%t%x!13jsmCdRPjg#Fe`K*)CSQ%ldni<@Mi^T3x{>pMo`9UVfR^gf}{ z!LS0DH*emYErQ|w3v$*^l`{PEyBt0Z=+Th{|Gqc;K*9(tRgj#Lh2U1jyKYEUUMGvv zLHKE9C*0Sty(TYCFKODJ>j_z0^#5)H5HxXSIdab_Dy$Qfxtw~3Y~RJ--kzSBWhmnT z>OzkKIKKQr%yg{kX8A}wz`r4W$xxQz5iquSWy}F7Nag!)dOyX*#m!r%2EYhdYYQi3 zIJ~O(Z@;0Wq&)rgQBl6`0X8F1$bd2A6kuDdr=V!Ltw4H!kVWWW2)#9=5$EVNyDec@ zD!l$Nj|JAI>%3`*{Uzq9zo#W7Jt}}BWtcASz7aNKOx+n=7n6^O*!6SWMAYu}~%*#}CP}tb@_8FAlQxF1TfGoEoJn0WGe*K;#A^fkm zsjcniTCkF`a@xqP$Ks(sL7?q*qff*+DP#~1DH1#F`-M!Xbb#ec!Jw?2>ZczTLV@`B z)3eizXS1t=0jUPFi|l7GBDkxOBjoV?%EfnQoB8~|2|vohsYSg$Qi%y0=0_io*q;yG z`oSXAc{=C4SysicB~2QucJ>H7x?vPw<$|nZr4}e2?#sjGyN(E@Y?^>HwPC`N{>*|P z`1b!K#HCr)TbnB;_A`qp#PxKK9^JRKw|@;@S|jVGC!y;Dm|&U2;b9FRB5(rw&I+d` z5I3-HuG?()C{L|B_UGzigKTGo%OxQ7+$2%Aa{Cn<(YuCs7l$c&Irr>;_k$Cf4+ti1 z7!E{OraL)1n|+}HoE>f#W=p}hKoTw6?)QBP3Hc`=B?N3`Z+${#Rx;DND~%eKTohb2 z#bms^Cy&b!N6VM$Ldg*RgIiahA%iRcrh@$~J$+)ipo1DSJT?{#7mNxEc|2dr683)i zykC8RB7iMpX@01>I7|IsWc+%#hC!o`myOQ$)N-X_9HQb{-GhZUE@AfSIKV8I`*#ru zl3>ozN7GD*VI=WqJg=z84)?pDGgNC(R3Hz6bLx#r)a!BY@|LGF<6qk~Ps`3G$m!#= zs*`&Z^7Q@asMZq+38=RL#CR8hCw1>Q=m}EZxB+Q+cW-aRynysZxuIUbZ;} zEAUgzT6i2%LW8zL)n?qkQa||U$n_ZA98hB9P5c0a8Zi6CC`MXtf03iVwR^|zqJKRJ zKZS;do}I$9+P!AwDzApAiF!&4;lwE)Z*RuZL_MHcR^=mLrcPyZ#YXrP{FUL1zqM$Z zRD5VPOW5Y4TBSZRKYs_2fCjDnoyOsasqev;TJY=+&uvUW(;X@TZ2A$^WWTHabr3wV zzy&{uXT%bofzuUb2U4y6og!NXF;Ubpd*v1wuL_}eP0wAB^V#x~clBBsD%sqH?e1%F z54=Lj%=!NNyuA6K2*}tq9Ea~ek+_)48{f(sPfr%idG$3t^a<2l_GkJwxL^cf@$}40 z2#`T!qM$-h5#Ye+0aNQimffbn<2GuiYMbZ(FQMVG7!^KV4pO*vb==ItCB{CTI7-~S zL_J3q5NjIDxL$UXpj%JPi>Figj-xiXQGl1`nPL;tUMB(Ir@=E5}E34c_SlIz}s z$Q87y$9B(Ikt7}mad1=n^q>zm%^BVdb~l`nNkfdwT4f~CF^xw+z=&`FYB`)|U`Wkh zOGWvUO03g<6y!)Xn%8~87MSV2%}&O8L|xBnLCsFfJtPA6%ifFKu38VYpa(Uq%EH8i z)6D~j2GPnh4UNny7;V8HK(fDJA?*~P>!#M}&kTLZps3pT>aJ$_{FxL9-}+ix{8hve zja`lvk7NY+oJgWMz5_&ECF=&GuS`Ux!(H&UlX7#P1s*&G_(r@tNSW33Ei`ORoBixg@^da zg{-M@GP}Z@Q}35$Gl#Y%wA}2}k=h3{-3$A!@I5eT`Pz|^y;O8Z|5mniz8yhV%8n&Hve(ds)UUUr>y#e??i#MP)rKF~wTLc!vJ!ICO`wtkF zcSxiV1o8}{;`292<21XF_SyC0>r=fg1AL-_ipJTtm(co`2g*e}@(v^|6W)~OUkeKn z|5zq}dD&E^oBl>1gzgWmtnv;(VYG@N-8N9O7(k~Y{s~)p*kxU=mFW~VJ@ei=ohMuh zLi2$+bFkycp}FX(Eu=lxBoKEYCIcw>M*5+9wU@sFaWB9{R-I5_hOLV16 zM{%-eL>4q-&i~AAXX299`kGD@?_{{V1VPJxDW~&5HxE@I!Xyo8F354~!a^o)b*rcz z2$W&>Y(2YD@xUgj33F;zH@6vra0>3T1nfh2I0;(<2uW3Gn;8Q!RT-hWc+%LJ6D;rR zwV|Efs`=6No<`6mm1Z|9C$4|-2QXU%pf34ZjpGfv-++3lA#>b<1y6OqFuWaK6&D8?A=5N0Hq z0@({8#6cY3T?*{;GsB5n-ZM-_xZA5my?vow{@p)I(e+|!zILQStqn#9gLdafNgxWxur$5!ZA;dz*oi-ZXd4Syp51Vefzs?Y%^eu{(LQ)Er7y} z)$gM4{vflf5kc-2fIGIY^TaQ$@Heq0p)n`WaoW)6Xu%Vwe+@v-9;mJ|RIohhK;M<5 z7>tx%{cy)Fyq`Q?raBS1KO)bBuot3`_e1|{G0EYl7cx`do$D}ZwZowQe)9LOI^7Vf zm7gd0Ly!Elz}pVj*?ScT=#bzdE}Uw|)VlHhcc2%$h!+VC|EKu)`2Q$-c&=(}xTA6? zk(MM6uaw&w#;NM0{8v#SjBD^;1;zzxls-_T0DSK$}ynibE%8n_lH$?K*w2@mk2cjzVqd`7Fc`}I#ZX|)o{;e~zj7nNL zkwBGL$)x{HKlGcy75R9jPfg#?p4k1dKV?Bg1PE-TX(d&Fm$eU(Tv2R* zMdA4aruv3t_xe%foR{7oXIk_kkq?oi7L>dxbYD*ok$>Ga6>y9je*ynd1T!U1f1O({ z{!}SZq)z-5UEx+vQuf@y*;vcL3!q8-{$1*9s3Nkp|G<6vvpI`EoW(lmjR*h}?_Z>m zVTSv-l9HE68T|ct*e~FDKhj*B0Z4+rt^l}%4)-#jcKvk}X4Uv{wVMa7FDn82;%-^K zuFf#DRDO;r%33Mo^U?6D=7M$#^vs`OOvFL_OFwp5aq+j%LgE7~Ko3ESk59Yqu(JN* z#5-^>pXf1DKijcn`)~)mjU}f-%--JqeLhflxbwXQAD?3n>8F_GMR3f zGR+s_*UX@~Vp?$UFTVM>wmj-F{WVEjAoJW0L+wmK7hkJ+UBsPM5!y|zEtQ#O>>8o? zZ<5+7_JJ?6;hE({h`qxDYvD!_pn{KiF<)0q> z@bBs2Q!LqLf2GNWz&X9ihg#h;p6A_4+o$Ei#v2m|UFC|3inS0!Hu<2AbVWLLx+qlS zZ^L?e@{4)d4p6V^x~WoKkLDtY?o^V*1bBHf?ulbE^J7>PCvMp}O< zT09c0m+=*fJStJq4q3YYz5z`)TeTfpP^WXK~rD`gmxrCwrPeSth+H)nortzp4d0TNGLip$yN&W*dQQ_CN7UU9A1 z8V0|LD+oG}IDNo+P0rQ?DoE_GIUi!6@A7p(=aG1>%G0)N{gkCmEYsJA*ntCX;^j1Y zI#ZWkYSR3)Uln>geK12pBpEKWR&my8VQS7CmV*JN4mPv~w7;ul`OthMz{i(unR~@g z1<~rY;0Cxr^@}I?R5-igi-hXIg+E)OxF$EZF2>KyME2eDsGhuBdb&VQ=FbHY+W5Fm zSLHuKS}qZ4@F5G(%JV<1ED5T+LiY!thhyE*b+C$2-Uu9E;QXirNL^#v>fw!v(u4Bg z(I!{K1A{TC=E+XEDmJ5IVlt}8iKHt1UoBh(ivnDvkfdF{_zoIXO(NQ#FaBRzxS*&u{f|dtB(*>) zl}hSwLf0*V4#Pz?QqG;wFwLG2COl`_p0$ssIq(oJSjsRpIi>r0F$t2Ay0u*h1QQM0 zxf9GQx0Q;`N18=;habtw1&Mh-y0aPARFTQ2dqXK&-gta~4KvfDT=L!Xhi@ZZ&&Y@@ zBpsc`JO~MT9LBv8@Ih{~(K2G9fFl9Z`3PP5k=lmEZN($i?5~0IxgOT`%5(k;^tCTT z=R<+_d|t{f4|;az*UrjNup`;W#)KmX*eBy%RwsglR&NOjsq1#@I#OMT)m$h2UXK1s z$MnEEX<|-vMMqA=l@mwGXcSpN*5KV%F5a;ldBGHZ@G&%+#Nnjf@$~y7;y4 zDJk(CeJS<~?AIf}y~9-W^<$mW+QY!Q$KpZS?(U*w{*M#h%6tL6I0Rx0X$W=8J{`wF zHAYr#rARJNa&?m?$F&Z(`uBHkUrM&)R5&KLU*-S2e?|Te_g#1vMh*@nMYfd|Amwi= zG%Oin=7ygOQ`zL1$_1J7g`+EDrWY1shFI?}UKVfTc>Bsi+Ece<*kAwgUFfb+AQ0g(=H;uUT8G$t78Q z^^Ay?HC=J?ObF@WNv(K$zf0bI>`J)fzoxKZRn68+Yg@&qo0_LzJ4-$v$YKn%X0pB* zTE{U!CqtA(+7tcyyrGpB;T=h@*D3k8TL~@OYvi9BI|=Dh7WPw`jQl-)AjIP5OH}}U z>b&N|-#{Fs`W4)*3aBgWj=v`rb8PXRH^+;9XyK{2K4S*7Ep(^KzOSA5kLd&$8%cS$U;Cib(%0$`#OKy?gDVqrBWvysUHBixQ!{D`XIj@a-t4^UR!l8&+Gkpn~meDgUkxV1m znIKOJyxKy=OzN+bq{KGcWm*dUiAyWuP0+ud*8i~4Ax6^&qrSw3+=U!A0P(s3=KiXF*7^vG6_&Iu&i+rtQR9cMPa=iVT}&m0HFI zKm(?kxTP31RsJ^aM8bqZ?UCKCVrz%fmDGgWl|XAz2TSw?-XqacFoN*ZcyEjw+q?F* zA9Hn2X+(&Zo>D`%83?{f_$w696Xh2`@LOe_N*@5Rm!@ z0vVG}UOUIXy2J3(#;)YqV8q8aO{F@opCrQX;Xr2UV}t6?1p{>2t*IAyLzjBJvij#T zYn2D!kLP@NaO)=oaIF#e9G!i|@pf4y-Kefc%Wq)`&v4WSS5rzoG!tpOeV>Yf67+;y zBHj>FLl5D!-N!EZlo|)oRNVw89eg3}y?^q2Z}|HA${5A)noYtzJ_B3ZGgND{3{|Z& zmrSAErOh48}= z<*ifV$Tn%+18Y2BS|Rjy(5i3^_d|b`n~;6E2s8&YF2!nE(d9XYN3S+G-8LZ29!-mi z*<@tP0wt7+T#JWnNwvTbe_&t$LfPi|oX}Q5d+EEI`Rw&?rl0>_)v;+fE}VNoEdS=0 z(sQxoq!@z4t)PmxZ3|yH%>)pEo6yOtcCq*Q*dbNc@e;F)cKo~#2$ZYHUo?V0RPiO1 zQOsXP;_*V7)oN^x!q_3JnbSWfnjZINR4slL8XI%aa^H-zZf1rqC1{L`!Uu(^vBup+ z<`z`QwqUA97ntMne9h^yme_9$u_zQSbSDAv8M$)(meJI4zdDGO)LGqwiJ1Q z-iLOjp^iSJbl4EdVk=ImH%&ay>(QTTNm~Xf`p`iu<33P zq`O4AQ3L^{1tg_ogEUBYgES~ecMC{&H=FK|?sx6;eBb-ud&U`OI2<0__r2GeYpyx3 z>o=n;T;ouqKr@lcL&v)Adt%wj#GYWmet%lVKd*nr0ZE5qN~%O!8g|GkKfH3{ zw;u2Duu3INI1@*E5GEH@%d788F(YaiZq&hLAed}qc7xw2%nS)-coLd8%Kqr1%>%DK z2P1%@aAo>rAbV|^-Y4pvjCxnppAB7`CaMTV_+}W!Yg~Aj0$hdfj`^}a{UNY)$|t*%vL%w z%Wck?d9OKt&#`IRDi7~#&%QOVDOInD{~X*ZsV-T3`NH>;|ky+M0uQ~5rYutr6s7@7LrRcv19HYDGK zyP)oSObj0AjpNQ|5oBEcx z84aQJaZHifQlF@dYl@2%G<+y>6?v8ckBMK?J;qSC?`=hqct$5KMJ7pYHvIcE|Esv+ z$pDR^VV&Mfa1joXWiQtG+)k#@y<XV_KaOFV%42D%z6jzWTfO~Q#uJBp(NZ!9E%yL?|w zde`5jM+4VA*YCqXCsj7O3az&$IapVW8(>#i|SRg}lSCcYBQ) zUV~4(lT%JtTgh6Kuov}%G?~8pil_vTqYGG9~#(~!A{UQXLNXFgFm$E*XF17V;g9Gdjv5WFZprTEd1cfB#U!tipkCZ3cB<9~A z6Cb5pI>w@Z^{@z%zP<-b6>k+hUFXr|DX5@}jyuu-|1mKNN>0l01wRAt`Ac0?uu zioM^rDxKyzX171eQWd1QRz)Kq7E_bXvVR;4T-(=P2;eOxk^N5k{u4v`WQXIvtN-pq zNgbCcrc=vzn4RHpX5?^fu;W*`!H-c%s2nu%8=^;UIH1}JWZq^!KmgcnTvoJQhDLqp z@D4i6Z>{k$E#lY6xRzlr1+uqqE=_CXU;+Qv3i(ff9GzoO{h|AHzTbh=X9qXtv{S7G zMEqS_`vO7A8BPrFdWc%e$$~zZk#XLW6w^l_RH2g$%}Vr)s?a7L|%J_4Z{y1z%yCp&b%LpKvsfpR{=vfoFa&pPKh2b3WW z*=J9h5k@6fPZl-VJ0u$#K6RF%$`iWr264w{W9^d8 zy94b$q6lx%Nhjt6B4_o*JI+^;p)k2%gE_ssa1`e+&-@lko2h1tU%QTd2{yf&CAnfv zU*xd&%}dXF`aZfr*zsBS6GtrT9Mctgccf~6_Y}M8%jngfydOIRo^Fh>AG}ynh(QUH z>lUcdH8;FTtJ%9SQO3&&t@3*H`wi}0wyT?axTX`2^M=~nlpKq|z45nqZnCS0ke8c* z<8nXr{@!4jpKcDH1w4W9-{p+WrjzppjyU|{GlXi3wRV*&c8VG!29|PLwXK;&5hSnebS?!9w zGQx?m6s^}kP7*t1{n+}O$Y=+uwJlfv05!ZkE>i75I#XTm8+Oryv!cm>PoZ6mW0jV?R>q`>o1I(T4IT)uws#T z-#*O{oHN{`9Xbs)NiTPP&{P271W1EG;RHxB|BWfy$mo8xlNIUgs~MLS-Anb7uPY!{ zbr;;bd64cNVXgCJNPeAs<=G!%BE3_~uW@iOj_Vb6vunvxWUPLeXHrmzSRYI_14jli z5v;O)x&~VE=h15`r`N~Vf0tgAe9DxBT9d^2bxn6})`*XYkG0jmuC~+=irRqcNPYTcXnmJ z@8oKmR#QmJ1oqVVV@RABj(!5)DmBq)ga(cE=}Fwn{M+q%07%g6NNPx&I1eF{TVU~;7rLH2F+)@*1Eo4Q9V~Otz=GJEweQ zKQ|d}e9zg~uP7CoKYs~H40ibw6(0|V~nX5&oK zp>m1dh5OMS2u8w%O$C-Qca#XZaizb=>!}lB`5iy4F%)NS8`eL2Ozk{ukmS+;*@S@W zwMnn^7a!20%Y{s1Qz>Ifn zeO`<@MaGqBQlp08+eHs-Kz|d2j>!wFCREpRd_K1bI)Wyfr`MjKIr`?xQ*5gAa!f4| z`-G_MDrIKIoe0`4G*om&Vk*XqPxKrDvCa~ok$aTj7;Jx!UTte2FecWLRst|J84`=y zE9+!GYFiUC4iBNe^a$g^-zHsF%w&z*W5{W-i(d1M1+q1MW}B6xgro+<9WqfH_q^hx zWghQL6DMR!o_z|be`E0BBKaMU4lJ6C93;iuVZ&QFKYxA)!UKT0nE|v@@Jh6eEP#PM z2oyy?<|hHg43wkJM@G(!yT zPk4c36YUKu@op(&=GUK6_)vXz6`W2YDCHt73nZ;j?@TZH#wW{DSKhm%SL70 zBl-8+3q>=E&89!)YRm-)w6!KKg*<$;7Wq#VfcdTz$c{aMbBPwf8Admhh{I{#nEhyK z64Cmt*SkCUt9x3?qxoY3n-15Kod`E~Y~_$p=)iyysEf%tIBdG*JsJ#z$;5*H!9wo>|bCB#u_;U7DY+NTlsbwAsa!S0q)J$> zd}F0*IIEhmv#p1Tw%*z?pGST9ADXGpQ>N;Mq7(-apMXYL#;hc`1m zjsC;sGoK$Uv8Tr^hTVG$No815JgfIqOY~b7Z3nc%bRQF#RFznSWNQMo+ zR4_690C?rr<>8SL+?ou$g>wc(7i8~Tq+0K!v@}h5j*lVz!A}6xrrqghDsd-e{RkU0p&*tPA+Hg}r9Y@}mw}+D%VYdWWT?8K&eJ)208Dix6Yc z^tNrC+5_`1l}5SqbHl zwCo=2e!m0_da7i1kG0L?8a-;@E2L_1IuA>`cScm0*}6>zYzkV@RCAFC_FCKy(F-Ai ztAAxCL_ipeDt~*X-pAZFz^eRRV#f}sCA7;7e1PmIIvIG(LAMhT`*wB1P5@}=fBg7? z@Xhr|FFOE(I2`NwuFaoxoAwV*3<(w2-`b!3uyl;mdgS7((w!ZHGS16pS<$Yr1Vz`U`i! z@f#lNco|4eI=kS-|6s|QKP&gB#_pvk%sF-0D`-h=6mi)9>#^UOk|O^%Ziu8E%bRVWTe;UOHokBC37rqDJYRuoUgi;%S?^;PSGI z_%o9$UNXY%z*`dQEbq&$=Eja3QR*?66uoijl7fd+UO82-(2`>B*+z`mdfONFPEIB8 zSgy(gi_rLiwCqJ_++VNKOYFEO&CQtAHAY%mmc7knuS7MaW71#>h|^gfTpRi98|%I} zm09QfSG3__xnc^7+cSIwMZfjF>lei0U9ON1GV*Xq+3x+i(oEUDMA;55&ZDg(bZ@s% z?%w-tV6&MjShW563aI~;6dy0WAA>qzD{#9iU#fE`NiR>_f-F1HU zdAgGNiWD?I3?q&_xfW059A0R~d8PbejP3v@e@eGso=>xHUdGB{vI1@C;0O2Sa@<}i zuZ3s(q%3X+Zy-szSlX|+zsfCFj;HeDj42O%U-Z#%9@VDITkKyNSP<=B`Oc$g5pxgv zN5GGGy7%cZ1rF%{&jTCMW0)uStcaxvuU0fQ{lQypz9mB9Q=Ds0cvnMM+MoDbnK4cj zF0?RAe5(sQ?!8wPH*&R3Q8s6*VS(s%vO9eXottX>K+-5!boA%>a7XYNH-Sg|^Ghla zZf)$rL$DD4{wFQ&7@d0tFl?aU9NfJYB(5!OmIsaxIO5PbF0pclIoAB3U*MKhzS0SI zz{>3PipnaAtNx3(CK=*@JRbzO-C;@Oxtx+PPye@%UAabr^hE;uo;)V${ z)1IGUzKS@Ef{3*Kfyd+H0`QA!B>sZAf_RKbhfouAJ6$m>V9P&%2CUZ z`gbgzT}~U7YeTT>eBR`Sr_8L1CS?&enfd<9~`LBRG4kYd5-*$HuP=3}LV%yTwSE9R^VfYM)-K$k@aD1ciMC`oX0j!Av`#3A}RymL+6-Q9z#8Y>I}r9EHSC&ozjb%HdCFYE4> z#V;7LzD{_HY5S5{Qh4X$30t(|m=xdYk|&p^Sf90^N&RNn(GAKzcBIn4f>D4RcTYI~ zlANKTp~j#MP=eBYjjPZk`f6q>Xpxhe<$?v=N}#C&gu^|9gIzE?p#O;g>;;(J2(X3z zZV*aYd!8Vqw;nXVLvlxS5J4i8ql)`XIJhoo$37qg)%u%!+wN9g#&QLB9g^*po=7AJEG|poSL>^MD)E z?greSy&pWC_mf^L57NNj+-$w~ah2HB7t4lBjbl3Qp%ru16IZMK(~&QkIQ!6hbBx7B zue%2ukYoWUk8h5xh*HG}vJ8P<>0&9P8GUuX*pzs0YG}8PwB<Q85R=CV=$|NacI` z`#W|VW8&lASX#0m04r!0(4qig$zOF=!Ha85mkWFcOx%BtW64}o7=47jU3^thqtOsp z+(A-?3d$8n2b5jlW*h>+)Z~v-U}EpFvlWHPFe3g655pv_dBFVmL0C@q5jRvtzz1Y< z0RNFi;R|m>1c0pv@;Z3{ei3#V%4=|cTE1>~M^leO~0$KoL_g_6Wr9Zt|14e6rI?IK#kd`S&Kdu5<6+qgk3 zG!xcrharQ=BW%obRIhn-D84?0dBgp|pf1Gxw$;L|D>tt?pet0>1okHY7#*2!-QdbB zKX95t%|x#63ydABam}GVK5FE%&Fs0e4ch}2D|v0DA*5FjqS!;0EGp}##J^Xi z#W9m%7XuA>6H2o?HCUCh$of5uxKdB34Ja z!0q|<>(`#);m-yJ2%%kne7jzlY`K<0__9wpnJ@NF)?{fuS2X9Of!4MG=Y$OB!s~oq zas=H4#gxh<&m$>~{d%4=G1=KQS>8RRI;PfxYiJt&K{ZRPa!V zYPcGoxDW{khAhEkB8^SqM1xLQk&O5iS?cl8l|x($5s5hmxAkhOA{P@jwq5jzh#(ng z&iSpr3l(z*p^<&hEUUy5Ijpb%Ip)UK451u}XOhDC{G=d6Ba-IxqhjKJGK0-%jItd! z^u2QEZ{sOnk7v@0LAsfnZuxtB%n8Q~#{PU74Aj@jc!0JAK(e?%doPo-rg(_mX)$OZ zc_<8mw%?Kgc?S%8p{882h&Dqp-$RU<*|?SiL~{aN!qhA8GmY?_6K=@Y^71`aH zfM@)usE8qV*i_Spa@8E{mD?_VFDNB-(PBmXMA>{oGWFE!jq+uzYxu^QE|Dg#u0%>D zZluFLR#SI$X@;si`0D7QBV;#k~-=Bn<&$b09Jot zAVV=f8>!qiBL(r9D=(s|HfB76Xem;~z=Nq4B zLq?ee1G|>gw?~f#Q1MloJy_ z^OrBr^bG0aT~+KWtWc>!;%zk_yfAjBJ@oy#U+*=@7dslUAxz}TQ6`=3n7Lf$_ITY? zIm>^`y5w&@!eSN={%-cQrDR~&5PT0+gu@*^sG;PCd>PeH3SDC8@4Nw&3PnJ``TtHQ zE(}-wjWoS~>I&qH<71aJq6sSC(6z8VEU0Wu-Ztk(+`^UY=ft#a<0hfEof%S75KZ;t zt->tMs~uL&*m#mQ;IZb0Geo!+EDUy zsc@N_9mt#%8wEY1)2o-MIDWI|b8~!UKH4IGR?JC}a@cd{94z)dd=6pcYls@dJuQgi zdwIxg>zGILw7r$W<_eVHI@px(1XBs~6CW`Wc9?)R5=Ag6=7;V-bl;ZF2*y0G4U#&2 z)IFY&L(b+j&?3TO($kdzerQKW$5lBNj5IMWH?N6N`V9_-XV0F+9M%I}q_o-7#7RYa z*i!KEFAfy#Y?s5J-9j=} zUK+IF>jCZ%ISijKkl^uXQ3cn!3`3VVCm;Q_%SuN7lvFpEE7}VjH{jT26iY(2_X8Y{ z6DlBxX3xsHHqXtJo%RZRANVN8LqXEo66{cax9@Q*fMKS0%p)+zHzgY~GqaUr%$)=d ze;N{;as1*u#Fy;+EboH*uY*q7=4UVsUd;r2DK;Wu2LY!Jb1< z8Zh->f&}@Qf6-01UX`wpjgdVK@d}1(rDVVt2M3HkJjDR&4fKS+`X(1=T#O5T#F5-0 zBhm_V60hqJU3etC-IIJq`BpnG(6govUr2R5*)cKQz}o=DiNyd5y(NS!~VPg_)?YD^C14(v|}WK2*N5$O3L*cN>I!tzh)$nrl* z6r=Z?;m@H6;^-yRQ03{Hwh*1bOlL1ddJe$HA#)h@d^nj6#b_6=()M^J2-oWmPFs@ z?Z%yyZq;O>mP6_RO8&!MNcGexz72y)*kOb6nVZaZemM1ew2s8r#igN!@BLwMX!gcv z-+{^0i21Ah(cr zPi8(9sJ|jnfNqVD&x9bt71uh4Zr>z5e*|UyO*FUkND22-)g)T{J2XP;SX+|U<3aK! zV9Pj?0LD{ME?As2g*KHyA|rDr%_Rb1rm(B)Jx}{K{jOa)ErcLh$z>VJ&zRwwZN?uW zC#Cc6uNkAJ8a87%G-Y}Tp&T;HA&xqVR(^rv{%0_0%82G6aa)-@tG2+}(@ag(s(5Rp zj8!P3AG49R6AP_w?{DcZ=OdB)h0Sxs7ZQ=B2W7S_eqacI=BX2AM=ty>JgXMO<9Nm| zSlBibRBqSn`@A*^uPlb7Y~`R5c2H?(X0ldEu)4TdHs2$597eC`MV#+T4_I2j0Br(H z>aHTa)k1LI3R=`ohHSnE#c{%Yi&^gVF*0@Jl?>r6CCgoKzGFqxZ*d z*&W2WNnHZz*lZD;*xnc&NatY!*^nKvun9+n<}Edt2U36q>v8=Hk=`@Y%Gn7r0Kukr za6vQWV4`fbQAK;fDw^1&9X4~=-*&YHwJ^+vrsGDAfA$X{9pUC7l^DMu7G(T=R=bPHX?=OsSt#3nTezuDtljUx9fal6<%*zu|2cWLAD(r+UdJw#VI z{1)(o^*gJ1`*bZR2g+a*-#c>+KEX^45fmro1~`4H-T{txWpsdn7giEujcNe0CmPuv zxzm`|m6Er}8ih*_$geNhwV~{g8ZV>YHi1Ufm?&%hwDxwA3uZxucgDC+*k~Q+ffSctx!!bzJ>>v~zd2CPJ}|qn4L*A8^Vc zC(TcOE`@QHZ+}c2@3=)4Dqk8p;e7UEm~e^XxajB(=XH^OCr$zlVpllxKsiSj58Exz zxK}e0ba3%SM;Ixvl-}Ns+e#c02c~8jir%g1|J&?BM5g}JcZl!oqmL7%W>y(`#T7zIGiFp zT1Ik(r32DqmtmI3c*uN3Y@*cg7|$8YUaBqQ(zv9hOOnu|_zu>Tq&Jy^IR#^&*k4~? z7kYh7$z(A&`xH0^=4)-=Qv&mObCXz=P*mUJH6PJqZeVvmg+Y&K!|*itxU3(4uMcoM zkpJR7NU}tmoHDAMEh8i}VxhuIelB)18VL$XKqWc&C6$g2Skv)~tE+iO>B&aG-39dY zrE^9tbk`gSL!`G44uGWd>=?CQ3|JEYC{s)92(V~?und6^0oS;A%WU9?(YqU)>vdK+ zcN8;o?)$eZ7-v!Aux$2!2q~#QHPMJeYgmZIq*gCRI z1i*w0iqwEY6?IpP2JA@lA`^--3B`F0(!HX~w>L*?47w_aW?52|lW=7dkUy=m@U zmVema>sNg~$Kzhhy%bU^LNMU|7}bD0LSIr{G-y{6_x&@ z44G1|NPCy|_TnQKm*&iFYwCO9qk^uPZ~f|$$)d4tr$7G=2Z9rM6p+mTbqM5sKz@ve ziSnvAXZ`Z?nsWbI?oUlPbf!Vv+ykJC^Giy?-A*^@7n5_>UZ5evWx0Ul`3}xeKsPWU z0{_AZU(w8P)TGHX;3(DLJqkCO*}@CE3h7qbOkmG1EJW+zw5wl0_5i`>6it}Mn<0j< zYte|B;uw!C0$zKYD$xbbE|fMqY$JDHNQB|tU$ce={h2R8z4EoupVWL{cJx=zL*ZIg zIW9Pfn5zqzzkB-og^#C27%SRbxqX&Os=8d%2bawcyaMC^euShM`KG-r!YCh^b5T^A zEA^Eg6i8_F4P-2^T`IG@h)jLn5C|nkMK>bTP68IRnpp$DNzcvCgVxaq<|`oTj$WIX zpI4NXMFsw1d0sy~xKEEq^TR)sH0XiVzvVYryk`QBiUXB?WiHfO?@981;n3z0{VFBy zn8f+^x#zpX44p~jtNo?6hGRvmc~e7m6G6rA4?zntr)8hxIdS)BAiIbLbaSfTMT0<| z&)5HPP!Y+mQ${uz2ZwWo%xpqzs*$Zj|NHpQ6&7O8f3cYp|IsjM@Yh2>-DQJb&-3U` zMOK!MSa&QKuQOY-cg|2nyq7A`MTH?}4y2}>VJL%|_4l_(o-bnb_aC{kiIK^i`LOxN z%!)qGOm+)hhC0?(RTO>}ji|`0N}4=-O6nuvmvdl}KVaxn9}{EubSR7IWCRTf#`k1Y z?53lH=SqlJDSWlcG%B?rDWLx7&mIdZ&z>n?e{b)o);u^gXz=NTg-H|pS3?Fz^ zRe)0v-E#GeP>U#v{E9KWV{a!v-drF|wAsIEf=gko%XRU^x*$*|RH-5p4Ms@#G8T4g*Lb!;JVfvEu&xxTlt{TA8R)NEvjqY}Ec^BN4+n&_G%Zu316kgg-8OYaceMFhr|Q!iZxI2Iojmpmyh^D?P8`Y!>lL2Sm3nCb(SZaf6rvn z@06vg=8va&204Z}&ZX}3oN@0jg>Wp&*vj~~PoN>37?7MuqLv$w&_>v-7Wp;Jxru|8 zklL;=#(z8!&?{S~f}3uY91<#?y7m*rj9xO7&ArmN#7Vx;0)k7x*a4->jEonUq`YCB z^A-mlAt`_PiiP#Gq47yOe5)se(qsJGE4!aLUA|$eBGyk~lC2v^v!Wqeow?*1R5nZ0 zUC{#_`}jm{%ECm~T4frbGvSAXa|+B!ew@SON~0$mjA@LPql}(mg0{QsX;FgaMRESA ztCtD<8>ERbfa5*Q&}i4TNiYx5GghP#0FRq9<*|cri|@l7;xVcZmuJ^C!Q0%`Ug95| zpT%~Dj6~iVF{VB|cnO6HTi7^^o7doFCnWed4NyY4Up6MT7~pg8JcK-` z=yt$A5t;zI1`rWI>tRsMpo$p_dbhyW;<-k;YyjF;2yo3v<6%Q6Xdm529PI6_fu@dF zwlQ!|egE@EYDj304TO7O@J0Zw4M6nPt}sFWc=SN64!oY=5Gbg9nfPd>+h3z5ibM<7 z9=F>l8!OIJ79DuSJVcvO4Y87Y(+itE-$>pfcTUjwROWL&*lK1xzla%8g=mieTvzM< z3zy(-$4-!&4tyJYqvFp5!%96jSOMz;e0mdK&)fNBK!T{8|EM!LA4X;bXwS1MsH-hzTLLpkU3B!s|Mv)|29$uu6`D zIk*?^%}lw8`(^RfI6b8MM4MGKT4OfmRFx%^kb8Gn6UL^tApDx-^P;4Mi4_wPX^pqb z)kTI|KSqX^xNr;1$*6GRPfI}lP~V7>TSY#^4`s2ZX`HFEa1D1`h@NptC2PaaN0=RG z{)5olgADiO0iGG~k#|svgwnlPnV6pL1`|Hr0P__ABnT_qnb(+~ZpIR^UGKmk!yz~c z3h+@g!TjPY_;gnhis=cteb_(JDs zpp3?;cegs^O`I2kld%-@IK}G{jp)x}{#pmh-&}ay=!0gUcNpLMBQ#~n6Yu{av3#B6d+pEHUUG%&+9A6CN~xQ`<#A;AU5x25Y8GKpIiqt&tgHH)wuTBh){E6L zK1IWW+VsNcNs}>CB`I?y5&pV<=5I-yQISK6umo`?v4*I8RP3C}qM4DRDB~2YV%dN$ zvrb*29$3@l{-BOhp{lWGGtKtmw%hJiZINS_NNCBo7~i*P&fIzKjEF5*XlD_UB{3F8 zQt+pd(o?CXfjMS*dU}FUT<;&2s=p9{K~PN;x_Ww^^?w@Z_)@VbDJj!bK39jA3pv#5 zfnt-wygqQY1}fym)HI@*nwkRl!uQb7UxV8U;>%jn1QP|FZ{d=Ok<`F@fD{y@k`g+e;kJ4~t-@2hX~@nCS9QH0c|;;z(D?cC&rU@a$4S=GD{;05`gMQZndB3}zk0&Y z8&u5SK4x-q2}e4NHAGA-1jGqc9&=q8SEdNs8&MM13x|wdcEDJ=z`y$fDpi-fkcnu> z1h`X)QCPutVMncxX5A6wRo3%0b9BP}V6M*Z4-9|6t3jh?iA_sW0E! z6g>eifsTP8aM}AYJt+wS=kdX@i7`=8c0=7Yw#yICq+MYV6rz0~z*vlCKRr1)$-pl7 zJKOBR4}72WJUkQQSrY=_DYwtvjf{*$$Q8nYTMx`)$WqUdncT{4g?jWUIM|U{O4(bW zdo9|64CBMwmW>>hx$&z+vL>s5v?{U{e_OU&KiaaifxI8ZVJ|=GZ^w?LmAPZ76W@W9 z+5ns{WWvsvwzjq;JmwN$ilzm@et(0wmO;N>1ko)z+B4c^^)QxsQtJ;0qc`|*j!M+S zUl?1wyd2H>z6m{`7pke7%YU}+dlb+-R@;#7bUDTYW0a$&Jz?6TVbp!LPZ+8jSEass3gOda19I zcWC)8g9(I-M2+rVbDq^T&Fgpj;K5&6PB0;S^uBj$a5?%2Rk`xV{p06&4@Q;z^*AX{ zYU{>{Zy!E*h#I0*yS*_MaQ{|U#6cQ@7TB{fIj?}#TF4k(U!HE@Y z(zE^9SO8>Z$&n|$YzPb!cS-^SO?xz?LZ(oEeF&xHO4^Pgv!Zv_Lu4jn*Y3~S=dDXq zvl?7iJLJ09guX7Kjgy@(3df?kIzJMT`Q`?XWihWxj# zeo7b_(Y#FjHKpJ4>jiqjj~`Tccz8)-P$AIY{!vyI1-R2=1*()Q6D-c+k9XXDC|DX= z`F@({KRxGAefNgZ2i*SM&>H4PjwOecD zU7Mz>1(dmpEYHx6(w#3&PDw4XKf{rP-zM$A?_-UWhv7zD>fL?Q+WC$Wtw~aT9{6 zNRoQd960nlzonA-W7fI)4%99=Y^6-oKu}3%N>ZWXa;-ua?IPs$uU9LiizFV4OnTJR zKOP=4wV)rLTQ=f5rQfW)b3>(FPW!H6+&0Wg7O$X~V9zSdE#6_&FYCaCdUrt9r;Giv zuI5_T|GGvZ=O?c??7ssp8|5c&^CVDKmznkxMvw_g0n63ATt4^>Fm`&g<~~yV!POLr zKq~l`o|P3FP-(*bf?&P^Q0PlaN>*8n2Lhs}LIFD}Dk?r1nI7ywc_%X5Gs3K7Eq7Po zLuv%Fhj9w(qQ4*wq2l6{)xrrsU>f9rXaPU(DB%f@$s|Q0VsYiktbgi zqSF7jA1zBi#M-eCk&^E1;DU^2%5uJjJuffM)y)k)E$;s2++wPjcxh>A%ggR53G!vO zq#Z|#V)E;$tVvhX9kfGxq=uSwCV`B>_5>=;o4snUulickZ_>4w#kiT>ku0FYsy38SAu{}%Z&iOb`a8}(oB z2X7*kY3H7+_uoj+#w^emFL=`@_;FqK<^?ceNlFa=l!IpCZfUw|*ub-u2`i!t&L-Kt zR^)odpUqx9R#Eu0#Mq`=tu8!$&dDSQB0uzyVz>J|$0p)a5ruO4O`&gxTG68S*r_6R zImZ^GRRkXP=7ZF?mX|QE6R@hN`05xDZyMqm(jRX2IIziuIzh#U!71&(SN3xB=6r9q z)|Qfg+4Ja`dbS4UKnN0H zjpjC#EWQ<8)1Ty5N7&+iP-I8ucA1tm3cZ@1JDKLJJqX3oU*=F*uzoJ!8?bkna%DdY z&t{X5&7cka^UY|VxYeMqE@RafQsKZYX#Rldv90Bq#&`GKvz&OJpbp!xW>bBWbM((# zSm19%0y@krfg;57N(FmjYTQoJxlo`sGgi3zF(;|0kR(o3o7ib^i`J|k9xwbA7L zfg7^a4u6-T6t?DAD5H0-&W!SU^5u9*yQb?G*58`vC*73B)gS7wH&egp+bjN;cKzf7 zN~*684m-=Ot#t5L8r+}f>8|&VP}ypt@Hqm|Vvvb>CIOnE5(B<-V>+19^Sk-t!-pz! zdzpW^pp;ekmL(=O79PH@yD@qP!=$aZeHhrn%VxMpCM=Y^VD?;G6XsY>+qjnyW#jqm z@tq$20z;P%;J>S-ky6zj!kY{D1g1xMO3N{9;h{}yr1k-}RLvjJEj&JD?;EI2v|Mke z9!ll0($cU)tlbsEsJjH#`pg&V@KqE-&UZT-7FiS$?p44&l$=ygwR@yx|MH`3$_)pf zu#3G~X^y=mUO96sEqI%vmy}cG`N@s1fP{ic#mT8A!;cPsO>cd~pf^7zEx)|0K&OMz`Qxs#j25cX0tLM+ z_^Sy+pMt4Q0M-x2I+BB#I=>uqRSHzn&A|)wxIElgY4-zqmX(c3gz5%%^mcTfFzo!E z`QG8wgy{5yj=72hY5z_kqd_iSKs*+j~EIdZKAfxD2s=Ze@A4NXxd>VwwE;Hc+O|DFYJ#Vx2SoH(M%4=MM=yjx9Lk{|37K}BTSPTx zo&zD-2M|BiiqUI)RGfp?Isg7xa^smyjKQQu&Kt$Pt9CdGJbaWL&pnT-3fvIo5=?50 z+h1Lrz}ZoUF)_kd0vh$pfdBT19xq4{lmMk3i2Z2k={I~!=>h@+@kmHMz@R{Q$^hIH zX=!BOyz`&5>@wg)usdFn1o)hgSt|xQI)9L})6mj}eE%+0F#eo`}fY-=HvKi&M_Z5A7%!SI!b`&*E|Q-e90tE)!f#t0Fxqxv#c3_FH(>@T-w0(A!_ z3X*b%E#1V3KU7dW%5*_xwB;Q^lI98P6yEj*&50d!jv=C}_O_(Z({q&PyX^SJhZHZ) zX!1AxJ!dRtlIUyxsiB&A7we;2@G zZHM2Ij+*oDh3bT`cD)Lh7hJBIWh(XB(m50^K1lqOVM#fT3N5C0XhfTQQ%aR*L+ z6?{=j1uKGe$2C=~(}9E|q9#dsvit4}o=T(hfl*%^GaWrWVpLRA8yL)m{_L5M@jS@P z02N3O_##jww6!6?6OyHz%J-hj1pKH;B>=z%)87oitpXmCFG8TP|h6TKLmEQXvAbMQLU)zyI?q|A0%%rq`m z>-JH6Et+2>^otn<{exs-L~hh#BL-nCvELj9nC0>Ir)Fu+9?v(zwi@Hc-Vw4bZ5uq< z%VybrGwa}p2!;jXelUeMVA9g~YfXJ&)T_KU)cVTv)*wOaL7k93H+gDobe!eh?TD(}fg`0aE&xDSXz? zzO@^)S@E&f0wgWV_|}ljbc;*gWwrUJuS_@zV7!KOz<%h~+I0T?%Oan5xY+1$cV%xH zrzt0g27|%+6ux%rf3zx`4hhp>WZXr@Z(>aEY2pHdM9dc(ctMt@FxI=&AYA^7^lSEp z{D4&!n-%`vsgLJ&p1Z%UO-o3y(5o1${~oLGi61J?U!bh}X@Wlyz zEIepPzku+VM-^+hRI_BtxZK3mwY5nA-eADItgLLM+o{>c5s<5Zt^LwA3jAry=CU%* zKY#wPG(R@%{f5t9wOXi)hJ(8Sf($GiP(};qvSDm6MSKoXPkq+5LsQqU_ROJKotjAdzn$LM1;6`YoP(0@ z*Q4gku~u(y5T?QVHN4oL-QsPq0Mju9nOWG9XgUM*+m znu6Db<<9#vupL~>;$P(lAQ)w3WdX_}P)>?n94=yf1B4?mCPezzA{t~ zK#7(BzSMCi6Y0`e;Q1LCM6kWwDvpSV2&TPMf(`F}zUws9oiKe_Xrq!L(%S;PE&S9f z20p==qX6#^tRn~=3yYqLstsNO0VfYgub)#;9PX%Drh$kJ1}uUKN&HR|5ts=S3S<6HDAC*)#GB{5FAb*<$25U!g{VM3zTX;7Y|Hly+*w_S{?>{IG}r$&C`Ibt-S_m zw#D@CMT_~&n|nCgvwHet2UskOnVuDM%D-Yw{#zB6DYh-wD9s;-HC06eGBYy^I_iz+ zU~hfEkkpN-26OTEE_1w=|N>4}lLbGouYJ$AG)0Wp&E&@v%Z^ zg3JC)Gzu2UF|Z4(y{4EA@QF+2wE(uj#YG}Qw5r9n>=wZFeV z6YKK3&6e9kS{9a{bKy4s+uAS!sy}@!z-|B z?s|LgpQMCNn3lg ze}rVHk5yKF#74_BZz}TNC%cNCUL~$Ps3Z8%qqy0fn>WL@G*5IXh@cNI)4nDN1-G;5 zGB`TbejR@(yP&8*;AO^JbW8r8nCQj18&WA;PPBKAvWdJl&1FS-q;8=-mG?;@aT9)A zJk^Au?SzX_jIXbMMtdL)P9*GM$)LyNL;W*dWt%1y@(ardr>@mMKBw?Shtl?W(imQlc4f9 z%`E;%U4Ek?xm2rB9D!twgZL0sC3fB?zdY%C#ISVDuiPOioTwtXvtkRo5UpJ6<744*$p`BIK zyG_#NFk*1t2Y-58aKt!fgl}YG;!JM%3)EL{;#j|Bi@nTzCq?s0T*LnoAn7~6GbIVo z)ckX@CLfYw>-^xCSIa=O%lK^Tyz$->H-IM@=Z>Xc?SRxk?)dG5rM@ged$B~q6?^;7 zy-3im4LES%02>$AJsixyavhu*kscyJ!0P)&O1cW4eR3H2=g-J;F4|ou|5yhlY%%Pq zjIYfwtCz_+c<2x&vfEUbZQZ&x6({GI)I_*9bFI$O@d$HDl7&xWNSkGE@x4umw9aCIEI1taf^aB*XjKGbY$-@K z(t+~uVHd*M4yXoICIN^Litz!kn68MgFJH;(c5d*8h>%&`KSA<6{fB55TixQ#S8%t< z{5zPfy=-j7^_{x>jmb<}Yx5W8XMA>w|H;dm7#+>^snwP6dUvDAytJk!rMdKcIB58% zp9lE2tsx)%N}JkeC)PB?dlTnBo21D>wMc970SNy5S@Yp0ruQak1US$@@mGvp@Bf$f z(Uf$+Sk$4`B0Fg7?smP_f%7k~V@48A&3Z9=97eNbH@}{%a#QkL9Qu$6fcc>;?(!d%StR9 zU_$x+9omQMM!~SSG`jXZ+SXSyFpS!XXev+DwH%}-`%;E=UO%DCM&+UIF5 z4=6u*?AQ|C;7Sr}GI&WfeEe?O3{G;m>*p6-dzSVx^$S4Vht2AWE#lNI&RwC1klooe z>*VU%Juoovce%Zwcgf#%>+QJ#4Z25*Flq62{C~w6TO9mU7bCZcYX=Y=Lj7{W#d>c& zh=Gj5cKbC07ejrg42to$R?0m{WUL;u{A9K-kLHWwPAexdHgXmB6R7SyaT`oZhLvTiUsn>5A@T~ zF15JvAgJt)Ycsyct6kmQk~<%eC-Q=B@$EsTBmqoBFdie1O7rpG3xubK0E?n`?@&)g zFv>RAfV?G3=uq{Av?e6;Qe-sTb?t*e+`Nn983R%IO-1A!b;#W*VsC!BMQ$*Bc;|!QwYxnR6BP5AUFlqBFme*{HYq&^f;qE{G;R1wQR_S>n?=OOhS z##Io6iHQ1O%-1gr|AbQQjC*atS9)EkzD!&a^ml`Eo-1iPx7XB@twZmx{_p>Dk}Q zApITl*?!~!QVT8-;bU^NMp_F}PsRr+l7xcmk(y)y%3}ooF3fhf(X)hx)=!zh!c?Q4 z770>{{y#+P*rLJ0sqfpT$NJyZNG~YJ$gJS!=l5X|`fi(~4j4VAnWEB@+Rz|$Z=-@{8s2c+~5_+~@g zrLiB41JC`$oqm4}S~V1P^34*pyaQzlXLP)G$RI7!(so8YRn+-zfI;Q-`@MD%0rqVWtsDE2@{T3 zFb3qIZSF9hIG{r;S2~C)|B(l~RXhKHVL;)Z(}r+(fb#?Xmg1S6Es~d)$ELL?EiOl) z07558Jvx+*Qu280e`=Z)rSbJ&DXnJ++wI)>^Q+#xy!ejTXpm`^;zWq$U@UVF*Ql)Lfh{E(F{ zW@r-xsD4G}ZxYD+(|qN>_dkT+@XlYIjgaqtetv)dOL;jS0+Ln6uj3JajJ_UjB9n~?;h!)HgV)0$W zD%EPyg7X*Xv$o(X6*nGR#IozVZsuc3J5+HLzwfHUN%If#@AA+phtf>Gg2}y@CiUE) zFapPqU&v}9i-)IGgSMPk{wJdVpH1RIje^9KxsX`Fia5^>m(}*$L5&#sV(9nEx?5%k zV;OS`Fx8L*@+Roe4JZ~2{Ul34C@;QWcC&eN`aPh^5NKjbUN)!XisZ=0SW3!4Od-vBo@j7yBD2wkMq#n6C{8_R3RUBYJJ_-dg zEq^U$|Dv;~iB_#zb#CM87zIBc+{;+b` zh8$}LJ&OFt{l`oRvzM-y8^MfnR3mu{ASV8|;h&{i2e#_+YG%wSzNm&%tgk&^oI8j< zw@v?GfFfj-;GUV2f2#&TlMmbA?NLh<|CN0UtrNkBaOv6y$$`I=jE&p2IaM?MB6(*} z&K@}#V@h_2*}cd1J`Ir+XOO#!`LHBEefp+;Ps4ee5Bhl^aFo2x?Q6rKE_7&GbQXDd zpPh~EX?=gvtTQ_eq=33jCWoo@^p-tesstkmi0K9tI#y>;+x5oTE@n z@X@uX<|tF)y5YDLjJqWMTTS)ee|U92Z3|f+tJriZUdfjSM~AeuY~#O)cD+%Y z|AbM9$H6uB7-`{oXWuVR$s>ENb#|nZ4+orT-WTyRXd*vLft<4^0U zI!6w(ti^xu70y0H$cwey6qeL>*rNIH)&u9?Ag{V;-5Da7udxbFlj=O@(L=UiBBCvrCsbrR@yzouUm38TSV}MmSpix8p%sgin+T4}<|KqF z7MNY%$78fTp0EYgy*PF%PNujQevy36t1+9MYZZ97DYi~n8hn(e|0DA0PT;2IX;Rl+A6 z(6p6{<(g0n)Nm5!(v1<#P`1-@hupUJC@LumuEzS4f29;Ujh0#U3Z|OX3%B ztO-Ue{Wq>of=if0+Zc`qBlj(7%f7G%$~bdL4)^Q9jtQ9LS-~(r;?=8|=2uNx@qf|LU@qL<+pB3|_`Di8@8}`xcAWWwF#o zy=z>Uy9)Fnpax#VW@d8(w)#N-a3HQybK!vCU|H0k4ObU=A111X?rU?X_b&PE78Fll z&fUA877c>Qu?6V)?hqOr9Bh~OznZT41eO(YO)d12g1J9BbdX;Q=l*PmvZOITKR40A zAEH19EH(_%&KhPoI>2`kM6SW+jL{F<$!2@;)K6;qu4$zE*4#F5n~S#}yZ$@Bb=OJw z25IS%C$3>hE6CM63^gVPrO_S6ah{+C)?bHwHd4Z5@G;xkaCUY!8!5Z|!v`Y-%QuO& zXV*#mrdz#w^)7D%5~Ea5>fLIj^k67F&@OFj2(@Fr99SLWuHt)>Gc&GaO@Cqw16*}l zn41uF9$^F1J2^f5x@`h9N3MlOhx>#z`M=5E7h#2QihW00)h0i`yaIw;N6q|Xw~6)* zctiOOOV?vJn`uQRt5XN$WKN#c{rHDgk1My>4K@dX&3eym-}RhCDvY-h0aeW#>H4IPt|y#!YL`Y7_8qC9RQ>@ z64uwMBOOpVM^PTBou3&ew^kq=p4pSs--r_YNV7gN(IgC0 z^w!ce%&<17M81ubpjyr-6%g%wd?~i=8o#CoQhHHd^rqEFok;nN1zh&=*@>FuLv#2% z9KNBx(NUe7iwg^v4*U^*#_sHJ18g~x9z08+d3Jl3+{8wl`?VK}AAZ5981^r~LC;JZ zDutNR?&UZn36+31%mjEi^0ZP*>;wV;Z{=0$Ux;Ti@~s@KYdN=llnNzd+UJ|fsjQ&N z_wNw{>YpB4hFTBDgCa)P0Y2M0Se=oVPwFYhSAoL%)h6dqh=R?C%LcYr?$+q-2iXPs zh(-)iHFhw1?MP#K=GI_~oRr?S+(s;1V!S~WP_ls{#1T+lzbCx&IFqKhL0te%UI@ju ztRU+MR6q9YU`%vWO9eRpd#HhHk$mx2By^z=B*rC31!)m^hb>#PwK=5yc8Zq*xWJ^q z4>HidyJzYDO4FWi<4ciEA%7$0JevEp6!rSWSLY$XQ6KH&Oisf{Q9MUD2b42rfuI(h zVt<4B?e)DlGw()L))5GN6{`s3dV3+k%GK-LF{ICj=Rq9EquG2zKMfjR?2xWFu$$ z7##?srEaWoFra1Yq-U=y#)Bq`dh(+oaT8LaVQthYI$|?QROMU*AmDCq-n@x8J%;1T zkF^~?qZKOc@OIC+w zwYzu&j8yb-SnKn$ImguZ8xWA0jBTvg zGEXyo&_WyjN&3q)(Za)?lR^JMP~_ZZyP@`-$&Ek$8*s^C5bgP{*iAI{xLN4NIo#Lv zyi||Q7O|}eq`=F4 z|7F#PiuTVxPc=Tj*sL(ForSBnV!b=g&Wm?@kydn2nB9OzUJx z0jK)}JK$1=c2ln7kh(z`7Y?3by>&H6Dn6(9v;7){W zWm|iMVB&kg1W6~SYzSIEF%PNMbq1d0q(DxZ2w*+c>M}sZ!EfBIQgR~dR|&;J9@#~5 zV>B~DEnlx@WsRi$lCXwRpZ=qvHgYS8$SsC2@=|+a7JLR%d*MQg~j@K`~(%*br?rK9Xj-~ znImHH3L6d~>42zLZ>c_Kj9tCe+Y_WUGONTQWB3o9Nm>XVHAU!2zZ<;%IIBw@Dmp5r zM&5olns{=TB_@s|R1rYItK4YhFC8l+++_Ef=FU>h7+{qzO5A!r*WsXJn>`cpB7wo_ zM#g&XAl?;Z^X>fxjlgK&~;P@^gpF<7@wL$v5 z@a%%hc=cWfq7vqrPC!$~{}c;g@KIdKP-cTKi57gsqo+-nWqAmoJbP|=&6oJ!C=5%bFoPutXR%9a* zVqwMFuF2u<3dgpqg>y2HeS?N3NtFryglA-)H~8B(BhY-&dGQFjF!QK)PR!t}v=g}3 zftjI#rVmmJQ9Y1~p2!k-KQIeMAX4NP_P}hgVV_KPhJ}fViSPVQeiq+W6t@+wKR!NN z1}A6AOOKasQ#RzyXG^80z6C(u=$Q`@G=2E^aYFhal;)^eVe;P2--w`t$SpALaE6#?)=`jO=Kuozm5nZouS}^Q z>(D0$^dwLpS(v6ZQQT4!b#a~*;&0=R$w#;A)vJvJh7k+|(2IlWx*7t5&YrG9ld%5yK^J zE-h%P6#(SbzcyuzXF@lL^#oO$0iJYl{JvP&qwtaN)CfX*adUSkz6W$+VADhN){XX{ zBfx9Tid!QQxpw_UP*A%boxZ6$I19AQZ}aUmaic=zeV=MvNyR?Ra_&G#eLcq`?GLJU zcUGH2sLc$J+<-&W9{_V(@FKNYR}Eg|d}kr`M3$jtbGq5sP)8*G$_AoR8rFckUpZmK z3I`N80_xg(U8IkzIpY4i;n0gU+YVjFIdda$2@(FP0NfR2D&JaE5*Uj74y*RxfB@D0 z;k=9QR(!7xul(7RMnDB$-tVQ@7s9|C#r<|85Q+hP;P|_QF!UyIXdL(hYZ3gHkg;1x>1+ zlt3O05iq7lx&94S7?6}YW+GmIEiSe|8VBGO%Dv4h* zEw+n0{wl7l^oCNk7F+|QlaM1=`B|G1AO2 ztpWRl)qpFJcS%u#wIG0~sF9+t*FVuz9>B@-(-GP(9QzKXLyGnTI&ww*+!J67x@~f{ z37c=Ds|r05@3Als!YfXD@UCVNiXmb=SCX3grq!FttCQ&i3nQ^2(Lu%KMfo?0wMdLq zkU$ZjHfCs_g%f1~2e?k5m;$tau9L+WTv_5VKgzo3vAKvF|NQ)`sM5>{P@SlxNEe`U z4eYsJ;j1;xiznh68f@zDA!^(~bq)yBP`*=bdgNhIJ7r~chz^9907_|6dLZ+lP?Ld?u47Ezr`!_9tVDL8LLmemtbWwl zdWuqz*$yPieAkKZ2s>YMhYH|5^~9fV^?LU<)1rkKCSXxC#6j4kydjKX&dn*T&sXqS zlJ`MO-X>tQDg`F)*nI(j#Q*|%ou8(_CrIF=ak6|(*^VCE>g{8*QAiP~9w7~(II59i z)|}2oBq2m0&9e^|+}Z%*h%F84FU+3*`W7Ayl#44NqIt_ilcN=%$^~%X$=w8gg+rX2 zoI{G%Lx=vw$)M7Zs+Z%6vD87Z`|OfrfpH1cS`Lna27s~fJu3pGXeBlSj+1z{c^0KQ zAH)gK?&1-A?cMzR%4WGE{ujt2F-aT7y|G4J04@H8eC#DukjR8Y9YTIWvLR2gVu1W> z-jYcz&D6`0WyM@C`yWqL60>_A>7ZydO>e`UfQlyZ`^_?F)b+sKmU+W6pHd9`s!Y&~ zJO32wDbUAOszxO$&&#+Kkdts)ZK_EX3$RR5UNr(eoX5skr9x6Ud5Dqov!hDD7n+@~ zt{i4*LZ&q=^?7^Y(E4rL-h)2)K(vJ>^XJ+b_rX}~w{V;p%&LdU7JHAH-gYpnH7dOg z-VV|%B`TVuoe|zhp}sttUEhycBfla3<&INPWYIZ+v?UDfnP6UH6k=f_);GcK?!t6Q zyua4V3rj%0T`IZ1g(yF#ytn#4dUOXvi@{JDLD?f=C5T5;0AL0|^b`iadh(OVPFP^e z^0y;P)*+xu?32?&Cp%R}lx%)Sgcpbuy$+#;g$wEdA%*KDUH`D{-`^MJ*#Hlm64dzC z%+ZO7cu3sD!^6(J9A7xW=y^aufLUG4y_Qk5HZ}D9`PG!h>BYG7cy#UltSwu%)B?T< zHrpWnr(dou8WaXM*+VIXQ}3RipXM&j`B=Sn?6xy<^luU~spXnq>2Rd30)bEyf#{69 ze7TRg1?$BNJ3jnX#}+20W%13K$h4N2{657s5lOYUl8UVl3{!K4FF%Hjd%*(F377&~ zJ8*(^{raanf@da%q)l5gD2)}=K{ammjNIdB9pv#LV@!Ylwn7l@{-YGD^6JWGUQluC zq$YPGi-2{3FGP%;8HiDg_FmEmpV0NS0K&0xU^%p^68gNqX5#Cub%69l+UCh?CbAtC zbCJ%~e9tyZ&4*`k905=dx7^Em_1>Thw-ey)lugjd`o-MAh5VagY48iU$+OUG!Qtzh zCBU6s4z+J;YM{{&agWjK7Ko?V6yv0t$BgfvoMbxXY z8p+r;egf2b>>3u9it9_3k@+b&u^MTI>=gtIbQ2H|gSHgxA~@}h!AQWxdXP<4Cf3M7d7cbDN|U|{h>_A^ds$WLnH zUe(wXdbOY&v9k&m@hO1BwrfvM^~5PfDk21uNX7kgeIrLu_-^pV1hFEmmw<6>mB|PP z+lNJyRc(i_8u)+|_vmnQ6n|uBXc)992lW~}m`__%8Yx6D9y?i6Q{$J_aZN`~Omet* z6KM4IYcm$?tEaCD-H z3J$*gn z{sfgGjHlPGtf{w$L07ugAyw*t@(uZ43RPdyHjFw#oDfyN*^*hQLD5roM+Njx7 z2e96cx9yYhx$x!XMM%R-<0*;IW5Z!3`MfWj;J`pG)xiE%R`;XCa5fmU*-3OwGeB6P%1f~zj3X0^-1CNZX5a|u5pWivY6tbcythq)DX zq&ZG0U9xNi#8HWzU0c+(SV?EV8fJmp*RRvz86d5G3wLaXi>`j(+#_(P?M~fmW83?% z4G4k(DY3=UAbeWG=CFPbRuI{=t(H8zyf5r}H^Q0bp)Un~6Y3Aj@(rSvFBoEwh%9E~g#~ z9evxl*=y0Yun5!`yU1#m0&O(q*v5}UXlGga%EyXzXPY-BH8LkVA^!7dMv`?0%-#0u zNQgqkoavR8C#dOHBe8&AC94>0{smw~ps6s_%&)h6KX|pi-rJ1{%^hXL4N}wO0Kd2O z74hK8*ak`hFSqzF&^3WZo8AEFDe&Y0FNys zB&o@-%S=8zeoK^_ii$2^6gwWNHt0{_&2T4_o8z;yp=hS0>)}v!=9yUZNy5GZhY9Iy z>Ukv-y#W&CAc^94A;IZLzeHgG9|f=p8fnb2Wjif@{l=vu8jRAp}wLB?E0aX<5im+VdMGZ>nBDN^ud^%21KTwg^&w=cZGfCWM<5XP0Xq`_aW4T;7M_>*^nUO`Wk9ag zmvcvo6^tz$T{evTc$&K5_|^RJM>zySp< z{}-*qv(CV9B35O?#*N2aTwb@>vPxgm)mq!%p{OJbT_>l}?un$UgXhizA@jQ_Pj27- z2eJojNu`!qm=-|}5bAHvNN(sAuL97Wf)g7MTGzHp2teBx74~)@d=Midc75x~fthFu zfumv+&{GS`eMb;_LZw00EQ= z@{GYv@Sr_H7C$7{@T;|)XlZShrl{u-+4l2fjsIx$F7ajx2oLwgH9)E@D=YKDHGF>G z@d7{y&)pzXe7OyZgPncXvxm zcTnFWeQV{SCC?yWk)q&l? z45tVvsqMVM46qpyjukt4$zxD9>m-bB2BHE~36#n)*dUe@?)Y5dlL8*pY<7T?QUatioy6G36XxxCAyDcdv0{OD zBZF%c=*BCA6g}oQ49v8E@5EE?#glJv#dJOM+NGkww+62Qw<8-}YXt^WYHm!Gq*hc| zPWT9Cl<`jR9@D6SoPig)=aL)cTNt+suYo84EdW|M2QM#UkbDWry|}=>)||AuOl0m_ zb@vfLEFLktzN7fY&N%GIX|mIc1a&J)hco6^eFT8Th>(l#?(m}wiyZ_-AERKlL(pUm z7{6!0E%Gf@5t6yu_IK&VT--f9-&JD#78VU5HQL+T-EbBL`j={78OlN%{XIHr*q-N1 zZU8|Vpel$bhFwkyb{^s@CnPrR1){+-=+E-K?wF1UCtMh?Dv2it4(io9He=VE1&VI_ zaXgSS97};m;*l}^y#zCmqnM<^7U}+n8LZK{hfATWUCEef{~d$dp9Xq?d=_wbDGKCl zU{geWWR}ry0ZMhi>IfVtMZ3TsvP((r3a&hb`WX*n1-=8UtG`j_AA-t_GSm~B23%kd z7E{tIIZ{HsL`xUc92A;%xkI4_Y?R!mR zjr&BKEa)R(eM3;C6(wI$4TNAw8Uq48tl$IzVtuXw1CFmL2=qk^zkqqr(G1ZtN8%vV zCooD*VAE6X*V+dWrSq>TiJ0+-;RiE(y6nWvkhF>86Ua)wMG}3DP=LAy2df~<7y!rI z0HK1;)QcE#q3b%BA>Wb@yGEnsw|pG280(&l5cmi?WW$af9}w1HnZ%otesYY*B8GN2 zXA!T{K!_n^cq_(7!(?#4pdbvXiA;b%WGaI`e;FhK0w@$zV*5CRdyCc~Vhmax#^SGu zWr6;)V|aqC8ajM@Y!Tt%dYyMmkib=XzJ1F^tD?kCuk58{jiC^f*eR0%?>`&ftaM?1 zLOLrT0ig>tU?@GSEo3NcEAE4LTrIemV>J4Hz5!R@C33 zHI11<%2TdV?jVCS1+dYpyY}B4a%=ftje=S3bXXJJOMFYu-M@A$S4wU-|+8!$4X%n5!rSf zFTM%?96g$XK$qc}b^;G;PKyJEZ1tEOUI_;%pa2*$3*rAUXaH~Ij0#&JK{XU0o%JkU~nd*3j?!00umDT8*dH} z@qi@N16Q;$3j4;?WwtpuVZ zprhI4qD>iEn9zf8`SN8?zk@tHm|8=-5=25c&?Co14{l|!?Y$-gFpMa=4i=Kcks+2v zcd-zfGI~G@Hku^1oY}(6tku8zZDGJ!K#tr4ocfD~2MS+kRVj9KQe%Z>a1?H=y!FKW zS@9-#N}?+TBy%J>*@7ELryEOAW#vGUoDng#Ne~zJ@w?4u$j9Q(+qIar1ZlqvR(Ye`rz>3U?x~HBN)3MhVwMpp;;RCxbDg)Ru8%Ci4oxU(FLyA;K>#$HESgd?hpI?{b`!5>#n@+=~S~{ ztu=sYKn~S?&w= zFPg=%Y@tS;ivtt{_2g6gZTxlxFFJF0bI-^LgCD}WmYmy4-gr+^5Prn$dkTaK99X}1r%E|RG9Mg`3IM#OM!piJsCV? zE_)~lnKV<6Z;#AXk{drn%+)gQqkUO4c%v6ZHLj6%6KA_o%Iy^tRFxC%w-b|^F^|wm zl}m0MX)+zH-qtkMA8(7t4luukkFU+xTiLZKf*7jTFYSYE>DG$2lXiB8kqNFNN3BQ0 ziGqTUiu6;I;h=%mqA3PAw7uQMWvuTQh~ocTyBXS8WYH3J@4*Y?(V5=N%vKO1#l^)> z8zbZ=NEjw}AWZ1gg#(VVzT zYDEXPyCN+;oMZMUlc4TW25kPMEE;Ft2hyH|+SP~I8^~WKXlq7!u>Uo8;f?-ZTri7X zI@<64=T4$MELeX2x}wO#xV>Kef1PYoQz(Kl8R+UH9t|?#xSa@I@`q6m`0IE&x%M{#Jps6|*TI@dT3yp%@@vfrfVi4!TCAHMZN1!N)%|Qv= z*!Jj&U zR=$f6=@>vDC2A53Ox`ZL_9WGT^QG9OmV<~;QL!l6w8uMn z`%@TxZ3nZw5#6- zCqNY7EF_1Xm5wNLNe4A4F^E8p*o$eC)vY5KGU;dc{o|P}f;kbW*_m_Ou9dpO z+efMtkdS{;BaFy7%rFZa0{cSN6Yw(8F!}iW2#WRE<#KwXVw5%v0%H3_bpn8&!8uop ztX_(W3%OzC)vH51yEM_oCaXU|F?w}AFy<2AeOp1Ux3@RffdkNIH|*Rwn7F=#y`rkh z{)`EN66S+mM~#NOfola~gAxqk2OuE!YR?k<#}@sksQQru_3f5AqEy}Tr)k~-`2GVd z@(2XbjQlXZ0>*oCo;bQN3C6{6r;il&Ey~;9A9ty7p>H9k5s5F;ICwu0nh6Reh$?`wtQq2)uIwRfkc|-aj>~wXCK1m;s)4H_Rtalp zM-VBAV}+D$kjDu<0=^*})7vVodoB_iR>9T$ygUro`3kRZ$)#7ipFX3y<>AqIjxtHB zX(h3}5bYhz-k$J1$`&Zo(I#UI(irX9xO4E5+Ay zbz~sEVy8m<&`2JSil{M33;*nByvXqxL7=X2zj+oh_?HHa@ z0gsGj4@?PhF)_FC23E-SFj=1Ov<39OHZynx`xrbw$i194hv~!le`zGrGy`V^sehrv zy#6_~K+(Lmh7r^{3ZX)*KytBw6lK^^=Mhh!%%GHoodPN}nq!V2{pjUzyMj?InCxOe z^-V^l;EQUA)f4s$loi>k_+T>U1pHfB{zUs3JaEaWPgkBpjra&_G9D>3udhUc2L}(e z#fGdSjbs$$ZqSEj;Y>oweF$^|L~y%LLq%xh?b*h#cC9g#h>8I>DA53KNan_p z#nvtZ#P2{ZX819&zK)B^jSmki%ajKicuKpAMPPSf$wOs8FA2~%F+ZxP+?0vlbLi0T zylD-NfO{A=qRSgR(u$cvHhKC>Z@s;)O~tvRn063kW)H$uv|qyA2udm;sL=o-6Ad^z z^gOx>VU*G9orS#qlm)_Rg_Qby$df7)+#b}~xC6T19X-Z*47RAIdt{paKreGp_o03J zPE6%hNN*AeTF=YNtI?c0d+De#NW-a)g?U-{D3C|tQmTg>WGA+hxVfM)U^C(xiT@Rb z4vxczH-j7l_`z3(0n$G*QVGR0^bXnNMl#AQ_IcIWBW<_g|3*YcEu88yPQj}|!|)&w z48J$uXKf9|Z3%2Pt@G_*K2WYlVZvy|-H?$MG6zi&9k+NYJK$S^Wg;BGjPBaPoFkcD z0>;ti)8_+9!cb#;L8d=~Hm&|pQT-|WJ7mTE6(B*2(?4x+pl4)^12u&Wv&*G#kMYSh zVbecJV}-HZ8gBQ>r4b^_q-V905)K$<2_8=wk7ho9%Ol0Oj!iO78L@cdizXJb00 z@jiTR8qd!Yy@t#TLs1T24Oa_%Ay4o>K;n^k!tiw{;Nq}(MiCtNF&UhoJ4_p~$&u#9U$^{_;bt|$p;19gja$Be$PnkQ{37Ao){hXc)=I z(EjAsM*#ud!rD>T^%yVJW4)k(+}B89KX70TRl{38R>hKG)v7m@l?c_Of1{0TuGP>w zr_A~MQQF`$(a3QTyQ*enOiJn4u^R$)s}<)QR{rzPdhCxKhViT^fJ-0?K?TB?qWWO8 zeu4oe1lZ9Zifmw#rE2i3?zF+`LO{u4yXY*y4QZ;Vpc1YDCsguivsSm9feo(@ZXROQ z1AzYq-RB)l)fyV5JSllMHIic?66c`<7k*AceiIKL*6LY!nx(B zh0z)L2PAVj_Eyl`cNlMp_Wc!!whYsc1hSqa(AO=KQu12IY69JuuA*_6p%FI7PcQV6 zeZojmtrm+J9y6*(Ho#!@OOQdy8pEoJz$VbRbKT(P2)@V-D-j`^k{%l(Js}~(g^(){ zH3%u<0ssRPIXcv7WCzs_UQKJ-27eYdD0%M09JGe{CV(gw??{e*6q?M?f^^#2n*T_f z-6kxlZp?+Up zPTG`VKh*Io`!LpDALOD>rd0w)oT6$7UTv2^JVy6^94-PIos3Zg^w#R!l7^(XD?_^q zI$>Zy$YLK8@yDSS1?zw(DTB^^;-k#LHUnlM_T5zJRm!W;eS$SLGEvf?Qg`_<8Y`6J z6?bOUo(r0L5D{U{Ev+>?F~NEGu$t(eM;%|i9bF#2%=>g^>b*54h_i1vf9~83iDEmf(upc(d}{q;{lX3NqRlWc7{bF&TzDjA;!dX`6C_iT zrWzGCDw+}nAJH8%tD-#?=1GJB{e()zUb$uhKy&@homF~{ZLw_6kwLjzh0$7uZO#K4 zr_G_oA`itK$^Qqm)BB<>ZLor%0WU3@fQbhUfC-iGiZ}2EK|{s(qX-BraEr-c4Pz!} zJXQv|>3#hrq_K2AFRx5=?V&554q%pJ;D#%>&q2SKx@L(uO-6%b!^1FGUR|w0e-2*P z&j8#60Vw{9tS%uzyQfMvk?Z-V|GmSgfX(#54_#&sEV1&~GQ1h?FugPM)C~JqeLc18D zjZsOEv6B?c6`jH+dm*6E0ozSqUmram2fPpez>5lOicJFR$Wp9T5dUUaZpa`{k??(G z-l0%AgKC4QK~QaA`hk^n@67G>@i-Wqv;m@PiHB1NE;=TmjfXJgDPQ03?S*rScX-5C z!{lg*$KW?H^5n2d+SJ2q9anNkxCE~Y48Fo}4m7ahg_+*>?mdGJ1oR6@AjqgYHZ!-$ zTtg-aqb7~T!sTOawwLVT4TE}%TZqcU*ytJN76W!Lhs0)=;QEIDsQ}o)nCD{91+GBDT8*xXtyI8 z4td1%y+g$6I`QU;>|lIt*B|AhhQSGz(VH77LU4CgT+Xh~z&b%8fDHsEv=MF?_DtvC z;8{dPcs(!mR0JWR48Py!_c6!T)be#d2j=QOP>eUkLt)bMmer>PII!7E zCc)spcrWk(vruLc!LcLNoE8w7Wx&on*LGYDQXgFjk^tt$cx#N@Kc}S}4Q{r=;XkZv zU6&Uu)92`UGB4M5`OL5zdi3KCh2Z2BG~B?UO6Uq(Y$DPKv)F6#{I@mi6)Kc;Ov3t> zXOIHEl+IFfpk#`Ho*rjZ7#?Q9L@F};g_vM9l11(M{DBjeW#@P-CFC*9)&zAK>+Wdx3y=U{T{I+rK*0nWZ&}Afg0C9EzCNMf;zn>b0+R z0a{&Zbtrp%&`Jb$mF&~!&wUX6;n)CKwnggK*zGzn?lm*xiqaZ2ruK(BU(Duy>bFBN zh3bxd|Ni&X2P|PM{g|fp?a!Br2f=WM6O$x(F1{R(dL&8#K|$hW#G=L+K$3jkUf7M( z%|zo+0#~3SZ`;K)bE0a`>i9JH6?w7OnUMO>6md>h_XDC=HTI0_DgDA!BUMp3VZGPD zS}!eHgmFU8(UT|t0aCyZ_8SzV*WDDI!iX>^Q;$KI2*DsVXjhHw`Ib&jB36{sS3uiT z_ZbEIB$x)OW_r<#M>L%rKad4EKOdB;IkTofNPDm&df6M>;A};$UK^#gh}1 z3bhSiA5OKro6rgulNsiZ5*Hh0O?0C9yc3e$y9ZN8)AT7?44@!p-$6wVj>w&Z(*Jj>`UyqwxbNEbv~CjFr`WJTiR-4;5NtH+C3y5W;KerB|3ms$zCFe56J{Rz>|hoo;k(nC)~< z!`0VT;9XF+s6ws5xGrG3iiC4V6Z_HqSiPSM-RnH&qe$1VJf4Ku8zSK>fM<*__$`Vz zTOY8>4WcEPCx+yi)o+L2RSaZ-!SPXk327j;PH*j=$U0|CSdxHYk@Qi{;{HpZ2?;6S zXpS@w4+FYW!5O69yp_OAtkh)OB4~~3$BHXo0uCc&U@~w;d5@>AFENP%yqg|ar+e+D zE|4I+XNhd#@K(qn+@`$W|2kP6-&d>fnk49l2C7alRIW(dCm1-QGG7S1|0Xe6k~Es& z9qCBmfNQmp`p(Uz>k@h2NNyN?hM`i!oU|Le^*7GljO6 z99vF&n(9WC_g;@bc@@w(jd^M6wOww|wBIogA?s895k6<2` zH@@Z>@|trk+#gxpV<(0JqN0Qz%n5j|=XMVFDm49V5&x^wZ87Nu<$>|3!{Ay93ky|3 z6Wu&<%6RvuLaF@g_Bj;q`eXnVUS@Fmz)MOKEqfLjnRXJ=$b)P z7Tg-XzA=C;thydXWwKR?MF+4MuIQu87r$AF5F;dMIffDNo&S9XWhts906ZNu;oB&b z&<@DKu{8dNa}bj&UrM)X)Bbz-DWo89AjFLTg^1X!vseUS`;W)iPXenjl;B^0h6}IG zV@erDAU%fuSdpX={1vV{6BsN(*F}P#z)%MaS5Uscu^m%C5Km%Hn_=~XaQsGO`45lz zNr{MR6IG-vTwOS42%}`k>=)Ag+G7vHzY++3SY!yY$I^kFg?nez8Dc+%-*5`gp0IJ` z$S(BplZFXGf~G;LueZ%hoWKXxn)f~<|hCK61P*Y^ujYNR(N$}L@h0Z zLsQuVX?`yhA?IISIOZ81T_4{VE)~wG=&l4f?uFNq-hH?V6meiS9I8NsuVAUsD$u-3 z4^p7;%MhW5LY_h5iMY_+gu<#+_yS$zCjcWC4$-VtLC*O?GNc zGpfJPDjAI0#mt`MMi-g60>m4@Wuh6u8WpY>|5MujM}sou^IZm}ZQJfPQeyGo0cw%_ zf=1T=UQp#w8Y)CbBFxt(FeMv|4Z1Ce;R+Qb402j#sQ`7BDHx)_k3QQGW}p}5Dx1|6 zGkG9e@luHCfWY`ra3iFYtqHl3;T*GQ2_tKT&S249u!GK3fS0NPc-scqV{=^bh?qA_FK zH5#RkPHPHX537yfL>oK$_?EtbHSG&R4*+%#3LyYmOjjngDk=Us)pjd-!5{W0BGr{7jFQc6$GhM3*}-@;v|l|L3aXr zun<1o00-)rU!Npt2V*0|z(>P}j&cJRc-zyHAN7`)!_T{m=0FTphnolu8GBHBVCqi6 z^bZ~kfAfLO4Cw|dR<-b&d%MSid)5X;fF*c+Fl*ii#u=dG?~vd9jXK#e_i6`NFWhl5 z^Ah6_Aesx3PX_sgX%s4QvjF^=MH>`BOu>c=OwJnn(Gh7ju!bB%RnZp@RDnOkQ!}N? zo|CHBaPG&MpcynW1UCM}hC6}pg*f*iLQ+Sp{CL=@snxK%sc7)>M2ab2i=DIiQ6#MwhZF3vTAvPeA(8mUahRiTNytv=lUcJzR}bLpiSfN zZ~HwTmXv^^2F8$;+bX()5!EPoAzYKO_Hzt0mhy?p%*ty?Yp?kECAFp_E0)~#gX0BPW}%o21N zyX-RY>tJ~t%kdykF%LWEgX7%Voll;;{c}>23v;4yCa*yL4>w~I6ZoQU!pSU}g=U$Q z^e<7_PCb!rO;TPUdJsPlJT0?=Sz8va{X{I)%Ajn)(O+Gd^}^1{Dck?yer3w+s949I zIaDj)Un9pyQGygfgO`8sHpz|GdJ1z@I}LwR>v|6dKMrJcQis*)EHyOKpdc@wZS@O% zrcvf#&O_Deg0-Di;Cb1eIWeO*zky;Odv7;xE1S}}fcMaMhwSw5zB#-122Z6ctYj2T_YfxVvuITvo-YVrZ zLbo=Ux)u2O&-IRoM%KJYjjHZ-HD3@{X6H>|I9Q~IRu%51IBo-_yHBX=9Hm==wuL{b zR}!SJ{{$m?2bbaMFAAWst}`a6b^B``xV2@c|`w%sqV$4Q9T!S6(nxiv(U`hW}Id}y>Py<&(fs8oz)^dbn@~~ zTQ}R7%`Ht*4Ayr>{1r@XGmRRi+3bXNiy7H|jvcQfPc~|Nml&4l*unkm!@HTUO;RcMu}#8k zL+{6Z(|b-Hr*Y#`0~?y0+vz-*Jzt%Bc_t-5Bs;Wh(Aj)$jCq%n$3A1zssvqQm4k=- zByWaxk6ji#@g>cUFZri&7>oC&@~VBT-+%EQ>oXUh{2lIhPL$6=bhwh!-%IM6x0kDo z;m6s4*K$8U2%Q+}kuEp*-I3dV=~l`8=$Z5F?%6!{c@KtL@>M5`RPM`mDC;U&GCMC@ zo0nlKCo(8I-mW>@n6+}J;my?}Kji(YwYRdK*r=#U71kC1sVKI>wL2m%4_7&59x3h5 zKa}#$PBbkiOUb1A>!W|pI}A4k485u_diyC}eCH3(_21iqc7@xyMKUgE8=StIe20kG zT}$&pT7cmL9ydzJS|6psBle`Yb15fpFw;ffC+i1U<;!aMhJW1H9+IB7d0>i9O5+{% zyI(l-8p?Vv;W&2puEyDna1R-!HBl!R3>ql=Thi`{Dteg)2%6SXX8$R*OYuIYTJg%h z@YVOz{Q9+$+rIE@U6}M&a5rW;A~~`B`@bvFPhURaesx(nOGdAE(f9Z?J~O7}`YdWr zs_|N?d)*rn?mu`nb6ScqokP0WalBXl%x6Q$3Fb z)a?$0jJb(31~WZdyM5En^~Yys^uPAp$~rWqxsN^giu7G`g@N0d|Ad~seMKn#K~NLx zZ-)4hW!ie%JzYLueEIrob>`u5_2|;#2kN#~ zV&+HL^oFv%I+wLE7v>J2N zXA$=t|1PlO&g2SHUUIT(gr(L7$mjBerxI4x0*{b1tb~cymb`yTf zch5cMv+rd+Dy&$$Hm}xXuH1d1;FNW_nd#U;!&N+$UmptwzCPHPU~!At-JR#Sov68v z-NQ-#BlUZlnMON%+@^+)GdARVu9RVnI+fr(Tj+7&leo0)(+@IJPZPxS?%3A6ypc|w zshQD}h~)dU%%u8_(U3StuGCKw@7=F%IcL5^TdFE}^jDn4xx*IIcCs^NPynv4 z&~;q6ykg?Y$%GfTpRcW)9P5r0Jx7VHYjwUgknJ&cVOog(lqrtMX^x!w?%`qevO#7- zJo4$4f&FfWxqsh}sobP=@yq1a1@T>HA0G@@W4lVZN{2tj*5k=P&dM%*N2KP>m>vm; zAGBn7cXs%qQ%>ra$@$OEg@P|7*N3#ot~I=n`+8E8;mbttC8>o=Yx`pt3Z?0l+x~ik zj_babloi$Lis~^NpK8(b;B?Fwe>>T`2XA^rxNM;#PV(!e(jw`g!kxxhf~94$A=7m) z+?S5$z341IC1e!SFx~T0p)e$y!{v#?j$?Y`Iy*`$?*5{#{*<5`;*d7az{#kyw|oAG z_l($`uhe}}K7DC1G#vtYNcPG_{HAvM$PhwHv&QqDdrrCT4AxcsTF3-2$pcKCg(??vmv(RB~5 z&2sjC+VmkWd70oslgY=ij-2l4X{RgdNA<+~vI9z!*NqI znH00@kMR~pqvH2}CUc&zssB-XMJ?gFL%7RFUY7%o1sS_+n~djQ2K772%qz?U#>Vz| zwB%SMFy1D^Wmypss9l!Nr>X*|yMps{zuDiAW*33-4 zNp-ixWM^jOf_`yHQ|K|pSDQ?~{y$WGW3VJmwC&89!#TEX+nQtBwr$(CZQHhO+qU)k zyYX(sd;4c~RdsZ9R_3mi3uY}#=#!8Qr1l;Wx}D$bG#gDKjl_#kQyIB&8%v2X;M%2kB3#dpG#JoHO9KN;42 zUO{}CjJI4;*B>kj&UXLEGNq^6dpUAXxD+4Qw$2U?F^X0RPvEVyLh~0W%RRp6co@`GGx(3+2taNfXQ{+_g z3X}UeqP@`(I%;QLq^_J_T`k+*^R^C;}dd%;}vHU4_U|l)QI+s7PGi; za8xG*;gop!e}q;$DoGb9p3PUV;4RyK#GO_T#!eQ@$VXopt}! z^76Ry5#%;TAoI_FtSs7AtjQhQoKNuU53j6m zBukQN@<>-cT`|tTrA&Rg(rJ?-7xY*pwR`0g(uR(Y8{cfnq`WZZ+H)|bQR3EiRL*FfzkLfVx znDD+s7>v2Bsy7;)KS7(#=EGu4P5W>kXczCSrJen9N?#~5Luv**vz=|sIo(<7+`XTme|t!&dMP$fBpi*7spmOp|FyJxIv zt|lJ&UQQ-zN>}|*l{#Qu;vnfLn_m|@n^(>?pi^%!nH=?zEyA=1Rih}6J;nf+Stsr+ z3}E@18v{vqYs{9e`Y1+H)V^MK_|qlx*+$((T26GQ?wTI_9Q_W+TS78^rK$S}>C3g< za$S=BUU#S2ip5;B>Cz(3Y1bZettNNErtGdpw1}mk3lkdJ7yKSLz%&1Rv zv9^lL+oI-L2H+~xEh~}j5=z#vyw;qfvnY43%vaN+q3@@Q6!*38|hk zA_UBCm#?3Q$5zw1eOI)%JKAz7OduDfbAby-KpH${9}INtS|!9^k%dPbWu-W8KbX@I z4`2(EA`-{p5L_>B-pVwtaY>NLbY~sxo9H(p1u(3LHETo74&B})Bqa-9PY#)_=h;_J z*ld@~q*88|k1rj(hzSQfH!F70&nPPMc|$9E|M;E!)U>prUsV_gnCOOwixsu_H3)e5^TJJ)$3*(bOi@nwj#^c(*J9! z1~7CUJ6F@Gt9ucA4;-P}t)`z;f5ShE{>kY;Mp_CbOG+nP9Fg#F(h1vncyN{vX(Hlj zq|_4leZTs&f`IkYIFc((4Gze0L?&r4JZ`CE0b{&>ez=Mkk+q5a`$Q1B=#j@`r8uyi zNcA_1`fmARRD(Nd&f!_k+yTL3!}J~JBQt2Da&Rjr!Z>R6OmDZdX<1 z{ZiYnSp!oL%Yd#B>8|b=ZISI9S_b239;#)9nw^u?6VJA(@Fu0VgGEyUKgCe2gnxux zpv(r&Cf9BF!Mr`{;$mk#5^pc0 zJkyv}pzRYGt{g(8ShGuAPSNZxTVMt}iazLJfctGfh=zj;d8DJ_dnFyJ85|r><-x>G$ z@aca&a&E>N-~DEZx-+Mbzwtif%hO?!vHwHU1 z;0(FBCu=5#)p@^fP?v-I8QH>`Te7+t=ZV1Uj^U-Z;zYnxS>8Q4p8w+Q_@$?h~7tCe9@43&> z{gLv!{fVwE-F-2FY1io5&!r|Lt)f$zG1%vaQwv4|JSqaeS0akg3X!+ENyqkzHsUm{Shd!{Ha071Dnl9$~BQdP! z+Nf{EXTIeuYLFRga3g}_&gU}A$x%^dB_xM4JB;)nF!-`Q1O`>BMBJ|QTr(1-HkqtA zIn2;HzaNT=iwotoEY&|eAUn*QwjEqCJO~J)#AgqU>jspHsmbk=>^Ar5o9ix_TSP*d zw=UbJgt;O?DytsO?7X}SE{Gf79!Aeid&&y&5&atlU>h)G2FK1xdx?1iBq70&zmJ^u ziOM4-3)hy$b&)mm=ZWIK70fxZpmh-|oLLIyO_0a5@@Oex!S2b#rMEkARx6$s?Mg*@ zcJxK*KDRh!_z6x#D;C872_*khpGGv)Ym;lHkSuc%?m4hUjr?P4IX)&EU#gP%KqH?* zx5}Un-|j@sKaSON)X7=jhyvn&rB{~%u;sMJNnT z;Hd2jJWpp0GbEf0L1c{Mh>((IUM8+a!fWa)>cAP_|+xGcbUN&Wh5Lx#PStmZ?xt- z-A$*5!dx7BrHoGDA%b|Ra@za3| zKr#QPe%JtvfK{`r9gB?iLcz_gE%G<{2=_ zzWPNnJkrNSC7dF}=i{py;47b|a0#zZJJ~IxA|DgwLs}_5D`1mqp7BEA z>7KB|or);0(Kk-LpZgtxJ-JQg<4*Gnm^|YN9dpVO0GA%r#pT7vHUi;C`M)ewi%X)u{~W^rwLc() zkV!=J&Bp*P&4}P+SQ$x(kAK$_LBu9zn4;JxCSC?e$@~AywMJBE0FY4dzxD0iJabq< zIZb9L#>I%Xqpl$SKR4cqB4RVzeNI$o($EY{gY_wC`gNQOJ*jyknKBo3JLZhpWY%2+ z&lw#x{Nh#9kKh#F7J(yQqXYYyx#O^bi)7qSducMWSFnjBA>Xl0%2a(QWNaZ^`0%%< z_nSL}-Cd#3Fu}O7O6l(jRp#VfOzZ1E=_Gp|!^v{Gk)6ZAVy(-z*g%tIbMk_;z_^N^g=VpUmTuMj-NIN#O)$GSu*70ip(b zj;eof)R%2iS0BSS3MuQ2{)`mn5N1=M!q@(4PhWN^TnfLqE zZ{`ojg2W6-4qPH0E`*#1w)wyU6`svSAb%wp%T9DoQjx`bOKo_u88qV0%(8 zVJ>~)JS;d3S%}f*ZH`ghV1gC!1wvuMS@>!DcS_RS)BPjHeTH5EDN15q>K7%eDL@72C8OZNiy~Q&+wFWeeF-6lR^Lh^s?_1TV(?=5E zrau@)$}0t-L%F2%O5dtYxp>w@b`?Pc5$Z34czg`fY zZblsFAWeHVM7JH3Gk(q5G>%aV?6PPWPaU>OP0A`R3GzD`xiy%ZSCX`KSVGZzJb0c%@M(x)*A>QY4BR~ zjvffxBU6f9pAf1&EGVW=Zw+6PxE1nQlB6Ek)Yt zv_Gjyl$ywhGm(WwP(6g@y&XHIH)#wa8u5uEH`LqoA{?XQPS!ttky)gq$skdJI_mAa z{aGq*&bRko`oKo|s!IovkhW`)G%;q4&5Mm+Bj$00f*2e=qbtK2vzqb1not!Dj~>o- z9-7il3viAi8XU@FEVrF}(B^#9!(XR6-v>B=i`dlTyWA7hh(9>%{bNFS`2?wRhh*Uk z*1U-B623DhrkxYDRAs9Zh7S@B*u)rmSS4Gm&UWRfSZJixkMx+1P~}F+@Rh6{+R_f- z|MeR(_!Vtm3VlSGfTaRj?ucKT6WQL(G6}kpNH7?&s3p(2Wkdkqkib|~%pG35a5_9! z5x+$^4h}PlB?Yw&TYf12a@Cjl3?wg(qtYKCX;WMg@I0?;x*j?IBQKM#AzY_B+Yc4< zz-}*Z$YxbV426mjwo%$MKJnwv&U!O~x)6TUn-3EmkF1hN^E?A7Ni8@yV>f4PA&4T& z5}&QmRSIt}lj|fA|3apzy?ZIb>Q?lXF=hgQ308W45Poz-{TDP|T?c8Zz4Lm<&b(T% zCf@?Ua`Q$}k?rxYn18||nyA>=#aLHej>mvVD&6?f>Ppmh2w53^=O?@5p1&Zf&U(=5 zAdg?VYU*!44h$xtPn00E-L-FPfmzZ;lw_P9KTR>y) zh1`Y2w!B6?Fq-Srg6r}Ojh1t#(%sQjoh*vPG8-Si4cEMU+`9s;=M)UYx`WUEVZjuv zvwNkoreqGaQB8ThFPu?BP>#curYR7>VdoETW%9X1;91*_;kQRunO|+3Arwgtllok* zc|9&A++xYnvqXnxz!&m;KL$%*N|OqZ42p(Lr$%4V2ZZ#Y01F0f*@y3l6gAcul4ojV-+n)py zufix=i=9epm>diWK{SxxPaQ#2xw?p>Y#v10mjrswej7C+75O&^hs^J42_+Vd!&WF6 zJ7u>QIFc&Maoo3{l))D%yLXwEvl9Qu5`xPIBWm|D@oO|17Ry2(i5}E-K|Y~5P?W%F zcQ+1~H42MQ6h1{if6mOofHno+ctE(87%t6H(3=(79`H%7e(18 zalQ#(0R~vt&h{UnR>#kR=#APA|K?P|<(fnGE>q6gjkU@6Z1>|8XEZTte^`54h?9pk zLdu+)Ps)`-@H~J7H%#63&s99%{q_tX0dR?yh3B8CGWlv>7`p} zn21KFiC3FaoUo8QfX!o(wzT%T%U8qsl-kmaOSP60B~DzpJP#Nx&wV^&G#A{`p+|eQCIK!m==~A}rTeX-ADE*r*p+9Nzo1ErZ&Ds(cI6_J9#xHhh6-7|X`R64sgu*j9uJyE_0%6rr7NlD|L z1lwRmEY3wO?wvQvh4UD@eLBm5GY$JR;4_L)n32`Qdq`eGMYv@q>QhRl8f=8^gJ0mP z+G>nif<^HidavPEYljakzJBo0Y5@zq{yedG6aWf*%Mj{o)%aMgeF25nLw)_~`1X7= z(ix`t4uC5NE$tbZJ|jGRG13W^-X6aiWZH0ZkMsJ7wkISxYA7*TtKDOYiPF>5e8>Lk z`Mh_+g$Vm$K>5>7ggxBIh(9MNFG)0$!%NB_9kZrSM2=711?_9vhlV>8aWWGYDm|EL z@6d#vASzndk`^y2D*4ouDOrcRH35b@;#YafabVjLx`K zSO_vW#Y?Q$RO0CauEr1qrtMxnHy76Jfz6mxd@&hGDpM}FAOy_%IjuxYjY|1n5(&4= zzXCY|-C;u%1xx~t>Cj1fSRCu;wP;WJ!C}4FwS!!*nX<(nrOGftp$oY-jsiCU^NK#* zUqyeS&lkTaMz)832btoD<#1U}M?-`D0Zn9MD~~I}!fWe-g!v0UAVC!9f zIhrzZR<+{COdoC<@Ylpd>RIR4;HB>7BqY-LMuhmY`h1Ug$B4h0hv+T#@VML7?Mwo6 z9r+>G220Md-X*s44EBJ^SC8Q+EM_Z)(V)3^-_nu-8bd9v*9g_Fpe)yM%1j_!H5N~A zodtX<6#-oTn^gw;7oA_ML`52dA>LQ;XK!$Y92)Pa$NdkCeBgox_!@hZ4GOBBaftaV zI9$tv#5v5COnwEf()h*NTo2Ae6c7ZB7TMAu?8#yT#T{ILd1b$HA?Q%+Q1BS`s+~8@ z2rBc*P?P~y2ZlIQZSuYPNKXOp(+MU#g664wIYjp~ zZn7^tGB)H}#cqYEm3E7NXDXSV{m}gWLZc}`Q8wzBbQ+bRZd`ZL7Dp=nF(BM}GyDG> zQ|6iG_cg(HE|>E<7<2R}9UIlS{N+nxdfZsksa0)91|k|!*}8`w2sgox%0$!MJ?$=W zSd$oWq4!leRffI5mRXsw+dW7Ql!bEZ%J1%ILo_ax02B1+^a3@Rj}tB|Z#5sp6{9p2 z?)0>P%UYuka8Z(dAV8z$e97I8V)0K575zT3rwdg${!}*GzB(>*;xu7OixB^=W)M!s zFcCr|2_;GWUoG+7z8Ww<^iL6mK~L`=s|s&fufp;xH4GBEkHl1VqU!X3Ky)$fcKRZG zQ(v%#RlhDQ6ot?|D^#*Th?)@r7d=P}K4gg`S8k&2{2)|xQSW{_qC)*&uc)UY5(yZx zdJYM}pPV&+ifJFoxSSbt=|ZQPEZLziNP`-%(|DS0Gtgfo^6}(SnL%BJgZF?OA!gN4aO6~u!()#uhL}c!K8yuMGJVRdrrXl;s;^TA z7-7NN%LZ4(pLsD&HD(pkV->-O|41oLUFq+B;?nBzIR=RgThJaDcXd#Otm1bdqzFvi z$0PiCx{QO8+~1u~W&h}@yp?GO&mJ!vtG^t1DAo1G_b-|@oxbC^9eK)ypH*2hf!lHN z-ZX<-9FwbCA*~ITlmnD)`n2S@KeVD9+#;I#%dj6|v>(uEPD4Z}H}he5(R7Q>`vjp^dY_(lp2gO|cjS6!v6)zUY!}NpHNh zQvPTXE43}~L&*I>f-nh*<5PKm_35bv>j`Ic`?@Xm#&rS4VR#KbTdVC~=4ph+c@kRA zKZLr#5Wm!666MDsZUszR=S70d)88_#~Nh@r3h$-~3`ajC+^=Q+<7yd^5g1NAEqE zdvI}jR`xGAZK1$o*5QB#K{)3S;G>JgK!^JoF8=jWckxX3DLygQX~;*#;{{t+SOxYs z?H?eZ0@%|dwCg2gPlcam=`;Q{mu^tFprOoe>cX!;fY}doHU|# zyc;gD@1Wm5P|L=V-$)W4Ybbnl9;6`7kwxByzu%v&ClC#VsjNX{|H0*YvP^I6=Zb zwd6bZE4n7KHVjXls!&zz5Dl}vl+2NVlnXm)bX$IDw#_q^ zFjNT7Q(wUO411P>ifD*9pMu3#N9Cm`KSB#-L_+GT%;UA&8V4&8uA!Y2x)^6qKH%;0 zy_*IHM%=YIHQLXC-L_wl(MGm{@y5WjJKOTBh%hK?ade+`p=A$P8oPZam*IQ0XNN!@ zVNDPr$f5s>Ff-cM*ExqILh9GJJn-^^w^!x9Rt`#$HZtu4Er8jQ#|7<`F~29%piEoR z$d*DG)N-3?(r5E*w$4!{W7aA5)HXA>03Z6>>g+mg9hbOBeQ`v?z4aKpB&K{XV?=oH zh#H`2k8{vdwnu(G*gB6OK?TiWeqeRpmLX7idr!0*Czz+QCY*ppD_yPfSgxoP>m2kv znPExGU z#7Lv^Wx;lcb8jbqI2zId-|N%5gWh*`<1tISFaqyzSd(|1(1v?|l;S}^_z{vsdsLoX zJt!3IY6c)1>PUv24&`+&u|1!%#mWV};Y0V&Ri#(hu|4*JRlhDYu$G>m)80dU`D}pO zxTT$IvbXF#La0!2;zpquNg_hYC;-n+0Kne?X9F&UYxMdGL+uB6ilW* zD-wcO(1JgX9KZS)9r|vX@5F$$l_h8Nl1+Rb7QMW;?u6CF>)cC&cy+uzFrr!&XL?S3 zIGX05+36TnQNc6BdB|#914}vT`iA&2ePA7R2n+Px6PY8LdPj?#@|!bnrba0j6h+xv468tu->D&As+`+ZhN zbRQIhXp_*WG`DXcR2uD*PHDLV@(pIPh^cq89~sURX0*gv(#;J+x{_4r?J9&0YTb28D)&gTx~s3 z<#n)67y(iRec}c5y1bg~{_!!@nbZycEkEd#7yr|pg4uET=<%P4EE%743Ca}*NmJS2 znDE#6PZ!j6rf1>n^Ki^x#?jFjeK;vNpKOaSO$7VwB7UD;o4QZrFIEs?uwjOAO4afY z)(hV78Bv)q@$uX{ee`Rj?EOGRf6cTCd!je{kim??L-EOM?dt_<&*q9M$xNa>`X1OQ zpK;!F!o{|wV_ymUT_rH+#jbrXCYGjvk))+pAy^*W7cBU@sqv$MBY?m}qL`Z9DGA<7 zMdN`*?AuBkn41w6tq<+y@JKp$&k{zsyx)Pkgd^19VoN%+#-gG9F(=$XC{Zvr4FgY4y%P!enY&<2vkgDl?RAR&8sW+&`bH6@wh|lnPJGAnP5_W5 zXoyg7P29fAW^@U3o9FIbt-7DV=$l^2cGva@HR0R3$A= z;cJtCwp8ltbnaF_2wy-YdEKMxW(Y;$ut%@rMsKRlbE(ymV?ciI^?4Rwg%6ag%GCHq zkxP>xOe&|UCQ&SuKWwd~dFD@6j{*!K@2(3Mci&W`K;^R+&{6SYNAcA<(d5bGz{?O7 zAUU!iYyUl@*>Ax9Eox zl8l8*N$w6cEQWIvuw;PG&A6fz^ik_rUoHrRLzHH6kKGs!=_A7*{yhv9^+!x(D_0ni zTC4lbcNA&jOB~?RC)i9sG8AiAhYuJ_OM8uaO?;B0jQY&ghf_zf z-j$ZofG_ixlGObd@J@#Gw^vf%5@A0QbTLgMY7VSKc5Lo9U#&iUNf!#F0$mPJ8x!)E zZX9q)X6y=>jT=AYF5G`-B`cI0$<9+;&hV*680~fxpe;KYul?_$&htBbVeo)+q z=J6(C97T{kPPB}*y$Webs?OLX7|!!^DG1;V94~6Sbis)6j6KQ0pmf!U_emww#J;xz zSYkkILLf5E`O^g=I7ZTsh0 z^KvMYiL+OwcOFT^iM09gN}rjn^g~4b-9=cjPLuP#48_B)g+=&pNtwPndl+-%`6xr( zNmT7G#kdyC${gQ4!VlfPSw5_4D4As~^VKy8+4{J$`a9it^98}k^Aomh5Z>^5k$xD? zGuDVd0Ri9jolN~*Q((CqM6%;ETyNuDu-$)+gnC`4|9G?6U^7`OOfEB7`bQGoBV50L zWVYPbSGF5x32UX!Z4=G>hb1}~1SA{h)2|M*z|Lj?>*rFKHOYo=VDn;)%^Ir78&jlU z(V0c)V7r@HX};NjXJo(WWo*e9gvUcSr=Y@O^7{L;Yn z$fnNABRMT+Gy$ys-IPB+;LYK|+*`9P>qDA*m9CiIot_Xf5A?ybZvAwx? z>+}ZxMoX%@Mr{%gycGDBQYl=C)1&@sBrH;^9jAeqv!+1(_F!*}DOxAMX9WE;>Y54d zj(G~o1HEpSuGq3^E>36!RsO3FcJW95H@(b(D|{JAOa@Ajl#HYy7K+@rbbG>5DXB_; z4cfxy;CCFO8nj`-q(by88%JJv4F0(rR`{lLDF`D3{M@M(K{`@~|Ixa9)2)e6y6&U+&e>X-s4L z`nSpTRi}YWWA=NM*XGxFuFb1kn)OgnKc`YTO^L;)gUWX&(w0j_!I41(_!5$s1WRclR}eF;S2%1Yq*^%l6!gHVOC_Ubmyu)&=@!Re&IZwZdVaZ_kq1SZ7@d` zUv@Fjs5UTCr!VH#wr+2{-E1sne-k^~Q#~$&DwOUS7$L0n-ILTr845 zYqo|=n+{rv5_JDl0QEkLiJ{0OjA(Y4(iL}Or_pavtxJU_8bZxI>+OF1D23py54~D} zGIVh1EY!U4W+F2`T;E*6n!_BmhbT(VLPwDOWflCJZ~n{cv;7Ohy@1$*#f(h082NV) zPMy<@R3u>cS>-Y=z3;b(+kHQD@vG^+ zmAYmoGnpYpBe^sneX@UR3)}WEMvXI!b|-xKYST{sj}cEv$B!|I)b-2_5ATDU7~{$D zcW*HqX}VNt)UdQT|2G_91;ffK+@4f|@>vd8^SY>miS0-}Ncd%g!#!rE@npI*8W;Ria=DXmVQEz5lxuXYm& z?~M&ci#@@;%w~wj(qPC>;B^Hha-)#oiUKn{3?`)iI2PFuW6J94#jcEXr)#kUsw<7rD9*_=U(zA8$+I-?tb0IBZH-+gwZL zU=WmE_9OK*-f|eiXmZ=fGSLk52K41iLZ*2qqPyP=c@`GykKaJ}pAgNZC#Ne@eeWm8 za8{$dG*RVzq?Xb&jjxTYp#%mgR-?tIiOKd%ma;jZ43x~MxJRYLCUT4ZS~OgmQCu-@ ze4+%tP)%K+ZDWzMZa|dlrR?+FvKGq#{=wp~i zBvY%1OV=tvikilu=IX^Sv94;h#Pp}n3sNVQxg>z+?;fUsm!F3uFlXxSHdZ|-P5cu~ z!v1l1FQ@f#lv!Dt+*D11lzbv%%y#neEQUkfkx@{?q}w|0yk=$jlN-s~5Yh960L&fH z7#s1o#Yo-mH>9hN>F<~bfa;DIlBX>seyb753>*OiGZm}pLK}O3Gf&iT5i_ww?xrR_ zA7Nl|+xJi~(IOQaS4=aF)`&MSl=|W&6wIv`cz1_HM7Z#4QrUsdwsYpfalL&md^^rt z>vw4Og1REv#NBUv-8RRZth#Wnxll&!9?r#=^~4mY@eu=>;!wrK0*$VJ94tB$u`(vM z6l1A-n=iW$)?cTlJwe#^zct?YsIoR9#;Z~Kad`S=3z2_Ian3gPDZ0aI$i!|Z@-^P) zF09f+Rb5EUm+2;7x5q*EjRo|Q_t00n@hR5Hw4s*eTCcb_o~}kZ3A8gFI1&V}(161N zl~CzvKJb_BXWP>fM9s`zPqHOl9>B9@743NY$8d#@^bsVLGnZIxq<^luK=lZz))hks zFpS)BQNcLe4LkmHQotw^3R~Q-N)*q}SIy9MceTaOw8Q9$q{62&9WTsiC=Bxa%_03j zZyE`lhSDXuc0;o>j%CdltgYcxn8)`+)SFmyU<~nBXsbbIFNZYB#I3ttfzEUd*i~f6 zWd}%h`{c5^pDiBiNU63Q{&c~M12dI4q%g1<={HaffyV4JUCf5i_&MD^G?VTf!d`MZ zp4h-;j?30qb7QPluyx+Ef!U`9#^FG=4+vD^ z0T-d>NWNy_e^0MUj1Pc?{Xa~mACAepz0O&)v4gY0?wN@DrLF8sE*6X(tZ9?SIs8HL z``xog*u+Pc5yWr)5SAEH>5uN*(_P-PwgWT(F0y%w?&SzR*BZA%{c&e>{Ho&Q`8}@Z zp>rjA7|Xd(x}i_Gq?}{SFXYx7t^>&R{d57C1NKW>{Ah115);afQ=`T2lVNV zrpfqqR<5_93LP>^hOp|be`BiZIf_9j~22-#-QwBre-; zo6@mI4$`zE4(ZLs6&CK|liRl|g(kV27z&Qb&<2!t4H82b$@cB)ze{_$s+xz$_&8iy z06&knPNh0Tka*8%_hRj8UzOGa4PP)f!mvnD?|BBDp1QSywcNdafEbM_m;$Q zEmJcYw>vo~OCq-m9=WjckjAcN672gV>F4J&gIm{x&pB^x4< zviNu~5EN>=a6(Xb)c=@~ovsXG`)u8GbXQ7j%_bn;vrCWo%F4oY$AO2>eW3!CE1UYw z?g;q&5Pp^3rD1CaLRMZjH1Lz*uEvRul_Dd77JTom(@dzA$4K# zZtQ&>p$et4a5WKY>>U{!ZpArwMTYM*((Di`rbe}`f$*3Z8oie4x7GmqWzHjn+Z1_E zS7|yTA|}0|0>OJCV$Jhl)}$H!NpGm1r^g{YXhhU~0c&k3fI>f0HshNFaXODf8^brY z@mdA{C)xb$eL@-ZmD_O6M5p#N875&WK>A^Yp*%S|E5~Iw1TqY59-IuvuR)bP4+e+9 zBYs*kcWyR{=ycm2dGS3P+W|-p3!^XBCI+ssMVFE_{3)Vc_I6kk14@(6GoBnWCl2cU z`Boyw32^wdNCvL~HYzXKnp1#N=14Ig=cV_I>Z=*P{9s46X9P%%78t>UvrZpli7$r$S3 z27DU&1a;k6$8#JKP~ZoHU0aqxW8;OyvJTSuM&BsZl^py;^6>LV!*u8^QOo@e;76C- zfAryLxc79M_tA1uHYWfZSaNu2o8in`f6*+A7C@I-)b|V~1~6qdkIGV1CNfEFqc4KC zI@TWEKNQoJA*D+3(Y=;~Y_nj_J>=AsaJF{)ajqNlG@!8Bm_CYMgP7^eW>ceKG`p6g zYBQh9J?ts@eB1HM(%8ldlZureA-&N-4ck`XL}2F|+0$V~RDX5=eTM|l83FJ8ImV_b z&bBWM%efn=&Y#(3tH8SJO;8sD8LHOV2Bs%0WcHou6|~g65QRsiPq7m^Zz%7EbPXw0 zMqBU1lGkz>6jq=N=CT^&6h~*ExgeO&M&~sZsvX~~#V=(td_h60tLspy91&(l_qp+H zvX+M{k>^{#iuj{#bbAh+d90^r*Cfo7RC%6}1zM5dMEZzve{vlm)B?y@qeD*XN@Wo)IN$ z?nE6O)UpfQy8rXtgNk2bRs3e~x7NDp4@gu*MytCneqJ(@K4Nww)>@AHxq~GqYQ$Q6 z@zpRgS@J}GRzX9nsaQ0^-yN2S02eCP#Bu?rHXtE!?3|V=DvD8^wZ@-(`k^Tm_E!sV z6oYqe^?DB_@l+c-_b8AE2=MipnSej!Dn-AwqJwiiu3A!va#ghtW2PHk9l?te(Z7xw zTp$|O56++7f$)htREk>PSV~p^{p-1Wc__Et3T*vn9!LW8F1v*da9ki_+uxGqD)BzT z*toBYJocUdhdL6fH7+FLKjCpIccBjvGq;ne_BdMX6I*nU&(+%C{!nCMH)Rv59CeUF zmx`|-NW?Y!@Te#`Bt%g&U5z|Z;}3NmNxU&%qe2llkhh*6R505taU(SxO18aQY6((# zU+g}U+5OWS!uN?bf(An8;W2VtP~ijJn%wS`iL?s}2`?iiOy#K|r5`R~b%(d>K;zKp z;vM_6s}84p!H)}*cVkD9n}ww5m8i(~V_EtMb;GL?{%p0nWK~~(OjqCN`XG)iA3I30 zLLv)Q4iJE0()!|;EYy1LQH5KcgevZf?K|owb=rcRE>h^4T)h5T$P?m9%vQ=lsxy}X z%!b9kKOXj=ELHsWsk`;x1kO^yXkksnEaPM4&+UT|ZqB>W-kp12P&qdj zh4T@uT9YxS>+O)gO@c6sT$nwfb>`=8P%k$Z(Fzdj+L8@c=j&c2hBk@q1= z^->nz5(5JR2YVv-m(25LOXqC9K`d>ub<<#~N3J6vkZP16(=2J&CHN`%r()SaWor$|Wo0 zzL8{EORit}6RYLp>>z>~PkwH=XBcwX}5 z(D2%?nHJklEjyaU`&wR`N2gIERj;L21J0e8EV+o$WJ#@igE3s&=TUTKLA>9EcmDQD zTTqRXtT=cq7&^1SA5&Z%D+qIc4irs;vSbm&j|>3oAG7+Q8H)5mr)#Aq*10t5jdXo) zL2@%_O5n&UlT+NC4lmlQEiTmC;uw6^VBfL^{kF%<_MB^FXe0R1U3RBDZtP>53ceMx z0;8eL9JbmCI#ZZIVmKs1c+r$pm}JZBz>ixEZ^hG};)(y*=HKd276 ztoET^LA-@Mx4It+g&@|UYS7S^qp-2}H^O+ccyNZg;j6($=F0bHF*3*fU>}cgPYx~L z^pEFVg)E8>eSCg6HCE<2Jt8ixCKfOPPRT3T;iwqFwcENk&xy7v(Q-`JE$!~ulk1iS zqK*!H&>R)P_+g>f%lqRf`l6PF?At>pRk{x@)EFk^o0*n4XV>l|h~K-A!)t%%VG5M# zfy85;-emZd0O7`l4h!X>y(mNO*(a#tn`+aXV7T6vYo`xCt8jqW?VFZL4d2NajF0v zk1K}7g2uXC0=b_~-Ch_(LlO0Zw71BXdlp%%BTa(GWiY-)+q%fqoJ>cNOj$!fhQgmW z^Djh6lsK4&7{+F2kn=G^&53di^;Y&K;(!(2;nZo7Y39ksiwzW?5jbb4^O05q89Fs? zQ)Ua0gABuXdvY?n$I+O8DgUwD5sDUps&P2#q(7cuo>D6e5Lf^@;&euL?hR%Y_glV( zmuJ2-ojV`6@xssSxoZ>H?556874ieSB};J}bq0Nv&ma4rHviw0Is2US0yzavr$Vfr z;&1pH_VDt>7tziXz+fu(h8@9@>D5q{em(Cf72X_?Yl(XLS{*n>54-?{rbulnKtpRfVEm>n7D>WfivQIDEZ;zE|%U%HMo8wgZG z3V1O#mg!~dn%o@9 zy{W>a{;zg}yb@)w&~@w5s6Q+r1 z#qsY>e;sC=D&wL;c$Zrwb)GU60rb)W=T)n;IU7hh`842EK#;(pHwA(E1r#lBygA#E6n)@YaoU%r;8J)kRpZZ0p^4B$> z&&Z)Klt7Op^)cAQXB=)Z`gncL^+t!5k*?XR9s1;DHnQ9)$KYwF#8O2ICys(s_UJcZ z8vTluazMG8Osu10u{GvAxyB@)b!8&FKo+PXlT5I?Kvy03lY}3Qs#Z*8w1%jzG!^qb z&W94Ls-5xLVxbh(*>|$$Ny_<+Jar2A>%pD(iONYX?py&2~g$!qGW z7IS{ZLmck~gB~B@lC##=<^k%75v6a~g|JEvx!`)K4!ZcrZxBkj4_#)m>uYneqt6Evf(OI1t zZm{etZMAhV|L0VioaILPG~g*v4VwHcN(I8*8vp(rewl}{2-qtNk6}8HT7`!vZfw-t z!mdT8mU}SM9@_F)t`t7g9uY6$%i_2e;m7*>4xDr%txL5!rw_vsQsWhI@HI7M2SU1f zlnYYV!yQ)WtemWz<*lTH6ks3qU)|O4FOvF+#h5bqEoJhgkLiO3{f)FCXRB5u{wS_V zzqej3lLeOx+=LeJO!41xes1mB2%c;2(iSoPA(bg~8olZKoI<{?F9ox2T zr(@gfSRLE8ZQFJwxzz{XJ?Fmr-gxVeT{ZUDShZ`8z1Cdw2Yd)jC8$uS7Rl7qhVJ{1 zoBcXG5BV9VcnYE>gHciYn@3&dBKD?d>l0_{G#Z?zPmz%-bIpw?^dXa)O$#5Enkt@O z1QU0+c0zoI+AJmnCE6ypi6SVeGbt{7Hih|Ll1n16J1|7D+98D_b~Pa*=e$mvTJTh zwcmTpp9)^xM`RBoG56L%qv-0LH;*8vf0xvlT9?u;b^m4elBC7fZ-VDBnC zRRpIKJAY>PmDaor3F)evkCI(|@_fL{gf984cn(s!H1vi+ZO<@1sHDGnA5J1K0XRMU z+k}fkUy1372{0e33MVkKd4&rX@ZU;M7O_mXQdP)4OPF^fOn)3mxa~nXuXG!Tpe!tN z+|p^J_4_M>CS}d>wXF;anJbP!f+=J808skA$QcwX6V`K0jwHLPtCq9fIR)>Bl^>xq zSDwrCkEBsUF&KX6juxG11sL?=r~n*(IAZQt@yC#&EH3L?zCjPI5xc4tE#|4XG44K# zdt|0coY^)eIK8NLmNaD{DTl|4;Z6kbX}uVXc|{VyuaK=eAdDzyOtM4}wAIDy!gb+l z@87t8DV%>=spmmC_MF4E!F<_q+Zl*wl*ObO!Vh3}Sq_nY7nWI_bvfym?y z!~G!i9skkazY#;&Y$_<9gIz&ln-xuK)TvlPbqLGlJos;LW;`O&FfU z!BV=4xg01_a}=@So6Cc0xsdP5D>kY!E;JvMv#iH@@LLpx8Bm6rw5~*64zqJz(=#Gi zI}0b)z@HHyX6|GA4Yo5?)=E&ta-r}5jx!%jo>VC9AnF46m1X5n*YPSL_Sy4g2$W3I z=tuIIX)Q{|EaBjgUW*TAPY;IrOzp6P`@2v_;=gEb&nbEhej~BK^!obf=38PFwqKc# z=Y7*;m%N<<%dsC2M!{#h(fgPTvS)js4_brf3H7=tLb@hvQGp)BedO%+!Qmeo%CaZ- zo7mZy@iA`JG6-;WaW^+-rHd0#PH6p}(K~;M;sY-62M&cOs{v z`tNKU?e;8B%;C&4WHyPGix4^J#Vd{!p=$z%ON9P9`C4p;(kH)c&uvKIaz@Ov2(Szk zYh_28Ub&B^9H(B8$(JuoEl^k;oIjALi<*Sx4`Zd*CTA;D6fLA(bF1$ZUH2_?+j+B%N&h%y6GyhTY{d~i3jFh{Zet*2k{v^;zi)VyFDHiozzbkG zMg#7UW4U-_dy@G}U<9_^tst{$g7iprC~+oqY6IyepW%E#xrXB)^EcZOLgnvD6CU_< zyCLJnofh=zU@^wy8+WPP58tP|qS3?q+pebdbi$v=rFMiw zFL~VVmtK=en$#LU8ELvKx=X-hV<%Q7^0^;=On1TQj1pgVG+nsOTAahz+i~~FZdLA& zR$NdwH}C`it6+_#ksBKrw#tmzYfMBce`GqV`7q{V|%;uyfjfYj&`yc}j5 zk*x3+`mF5U{RUB(l|_uvxpE(j`W|q#nn{0(eBBX#N)+kc)%ypvI>j~WZ@4t3^yYTj zT9?$Ovo&tFs->bI^K=t&>wK;u)p@_3(HHEk4}+IxUq<26NBSoR#Mzk>oS*iU$~5*C zQo^R#oezl1O*$HzQg!li-36x@Ee}PJoJ=sKm;0BYWK-Alsn?JnPKYA7LDG2L>7BJd zjRy7&`;W?{8ocPZ#Peb(RX%^2Y z5BcJw42(Hhlk*M1sz)bpyU(058`e@yj6pk#NGe)WqTY2d>zglBO|B--c6##Zs!sbK z=jyxa?#~j)Gx-3&Jm)pSu zfyNtOX?-*TNIQ-?>Pxo_xqo%0tFm5M>H=5IBVLXOGj%`>p?%|-a<0jibwEF*d4RUr z|7g$Xh7CpY+M}74PXWt}`s~)<;lz}iU~dIGQ{gHH{<9FakN4uQG1i!d+zhy8XzWT? z5Mz%Rlp^}9A~JtDpk93XX1PmL$s0bn#uH+V$KKcA(y8=nfyISMQ*0H2;ag*?j@;=+ zJ?DkS$aoEH^5+w&C%u2KJSJ@9>;u=5GpW!(*_X0qm7G9#f@RnR1sFLMki)1wVHo9* z3>lw2oQfy%7C3fJ^LY}mn82v33Lc1kZ^_)ERil3g>tM8d#(qsy){lnTT@Rm;%4iB9 z6P0ZAEjKOJVWy@%nmXumBuIFbYT!~WG@Cb8xvV#sRIi5lf?A$5C`mIk*~P{L@GSmh zJS9iQgeAyRR1O!_H37hlS+14n)UZc8uw3Ru>R*y!V17E1pmI$S^PXJ92}-oh8^($Q zvcZqH=#xp-J|YTy zg@lHNA~+t5J~Y`5SKDD1d6zq$7h9}@NpFSu3dQ`3LCx!wmzP&PXHt9quFId27~0mx zB{*58+73)j6UPSISsDw~CH z?$OtOiRzU=?(xtCt0c##yFRUyRrmIUN3Hza8(%b;ZpEqyT}p@7;s-eGDPz%}ce=Ug z7O_Ev=kp_{Ot(1VzOA_UQbf?!3kTBmdR_-752H+j*aI;N&@LE3smO%p~ zn~`ES5-1afMCa$qxM?(LmuKh809AI!Lzv6U%g?Qi*k^PE9GoZ(!>6e#g@4bm32vqs zE{jF1R((U`-aEUiKFG{aB@jTR2m$VU-mUqce5_U+U!Fv88jitwPIQvFRq2N9ldo6)M$@LE8MqOi`k6&H|9f`hlMkzwDvwB zG<8<}@G-6LCAzcy2SH=a`mFB}{iJ?TUdpa}Tb)ro411%84Z{r zq1EP|nD8Zxy3*|jy3l`iwO#iR;e7jhY3R???Sc|5ytO&c<=z}4s!?$}2N9iRbdiJC zPyVa_krB{_z@$+d4^?q&%*%`&K~SOujl%J5+Xd|=syuh^PvZIPF4`oDsX9NnN$U6; zFaLabE)l&Au_lf||4q&$R3&ASJiqxTP=(Jf$R$>WQ>QHP%?thWCjvWy=m1v23}Ml> zKi|qmkU+9ldLEw-6hHW9`!M;K$=eAplm3E$omm5%7|^W#?dPF7K1Xe-(?ZJQKNlT< z+bW{v6(SISsbDikfk88}G?;w}&?G1&7YLUtR}tQS-|f%ykZP%-lC}RIGj-}CIDJn< z7*V0y?w_RJ8osOJ;-Ms+NyhzaAp;3%g}=HEJP8wqvs#0{KJJg;SAUrO@cnzPmU{M@ ze`X*3d_$jO!$Qep(8_Z*pLiL`|NQ;2&sPMh_{&`HzEA&(Gr?VwPc>Mwe|6)3KEmL? ziQD~q$obc~v9Ulc2mgQPH~+j3aNF1FKH|-EMT^7Kf7l-IbWV z%S#!cC*8|&evF-?qYwhsE*TS8&73=f$XH|EJ&-Of>Jl(k=)(P)?Iel0A4bpvhUlqyy-Kn z_|x*;7;9z5xY}1J)?C>)4JGn3(b*g}_J$8tmkl=K){FbZ0+B6~=T?f5cB-f4Ha)DR zNQyP_LKW3ut+$A1B|NXJhyp%m&Qq0^PB1a-~~ZFIPYUs=f9Eox)U z*%X^@=77gbsd0zT<%xKns?hxMzQ?J|S?$i#exf&dAmTZl>ChC(41m8rYqb0!y*Pfs zC|7M_s=-_;U~Eg((?n~p_iiD>%W_ksSfdv;$^2RHUAgu|u~OO>6NSajdXh5wI4YHi zsca=E_bcx*#VqmSkG>Y*gZ8A|97YmKkKR|p6D_QuWC#C z{SLk2vcT`+5+ZhU$ZV^cVnK2q*5a9jrt|2P`^6Y@;kSx9TQ59~$r}rg=veuc zg_DIq0;cjX)ptj(11$*F^tmi(Fbk(VkpxVo5h^6F7hvc%t2%qgFjvw4h?sm}#K-|{m@d~_gn&`jdIj{f0lIGjEoRa&SuF6w34!|i zAKfCw(Nb!g*b6h8k-Em+M(%B5oZONK%O{Sl^(E5c|Nrnk4fKi29IN_m@%@<#Vhk-! zdaGA^W-9ywQrFsqs1L^~bg^&~i}cNFFq6kj8QrFj$RH=r>ZTxP&%1?DN6XI_=N(Bc z+lgI*-_qHJLp0)a4&q81mdpEAyzLpdz>p8)7a-1OF}fnV2u{fbB3$j~j0QD4QLu_! z3JX}1i(U6>Y_*e$Dl)eD8gY=HND9;TO+|eunayMZ0GP9`BEcG#F|RMUw6vRHdQwzj zZr3lx?DN%q}oOXC4Vol&OVnHbz2rxA=CWwvm=P2+HS*mm-uUYYn(GoUDc$u|UL5R_IQH z_ldxEA_?#S2<3I)0R=Z0t63Puj5ic);Z8*} z+HBjknJL+>*$(f^{M%SEe-C|!YJ&MK{>I9j`hn56#DhPQ`p}?$S(lS`s}2jrPEmcF zZT~#O=(06!2zqVSvp=wBrlTeesPhyaJ)17Hf#zh^ovOxVV#lZ&@&~6Or7_Wp5F|Gg zbc-!1(!Xj2Z7d(F-xN}J<%?sp;tfg6)!Fc1u2!{oJFtg4mQKTi2M#?yN`wFXp6&NV z0190JgR~!LPN0LFF$=UZ=?5CYm;pVT+;8^<5E!HKU-++g_hyaBlU0-53&yd`C;up) zrOmc|w;_me*RMNZFDn6DBZ8tO>r1l<>UNtplqs|rh)@>2v+aS#9hze6bU;K5NL~E$ zHp`^T)Yb)4s8aXTU@eh)N^(UAZ~84pMj~PC;?pao@+U_sC%py2`xYeQVxmiLtEr+C zSB}EFi_*okFXwxF<@rtu@}W^~Nj0#lY8LISn$74lK(0c4pu5odM|#$9FoCWWOu;{5 z+;YLM`!PqZ?uHm{;4t63sdHbolriSgW5lP^S~0*+0^?Es>0y0_%-;BkF0;NMN)Lpa zQz0Kki?20QE&@@r4&#p+U(l4r0V(w}80*9cYx>i zPme{rujTxWvPBy2CmZa1bdPX_&H;z@wkv+48!3)+8i^(wTzMW938lQe?aN$|ix0Xl z>(yw5E5&IqiH5_~{JaX`!@9OK$Oart<+KfDBzWkbjLs)FXC(J{O(sU4!DcMxvNEzD z8jWJ53SEm5Ysp_xHGmHe}3}Su0wwWcUA~H!X8MmceKs9sBUS z*XUmVV6pRV%(_}t?(Q-W`feLaMnPg%qSB?qd>f|sdCkgfwVB=hUHQ24obxb-cV`aH zg4qIp8@iS4k-*>+VwJ^Q=1ul=d&`@IjYau*v56G3?OZEy<=~6qcb6jTrk(Q`+4z3a zuuqHv2(mDE6s8^Y$GW3yd7E&OQtc~bOx0V$1Tj00s%z-;EyvV+HW3L{bI~=8Lc7V{ z5$w;<2>5*^iX)-Hnl_FPQZm*HgQ*~5FN?(jc|GrN%PsoE@Ydc+Ay@l~*d;Kuq;sKyeB@(}+$P_#-aW;E8_l834s>CNmf6dbsP1vp18XA%Uc6}bP zIY((~KpRe33hhPclxDwrLVM)lX&sCg{(xe(ip(iLpEzp%z!h`3JX)?nZ)UR)299Vb z)v})2{q~kyJ~o3Undb*Yzr9(!j3^op-E6XesAnFw8dO{1F_{Mno6ltHE@=7Umy;&& z`SZCb1sh+XHzXu=1O$X`pwY<|Icwox8zJC?L%B>zA84wIS z>vmr>M>b&BAnIs5#_h9i5AT=`*=5R4=HX-}j3LVvFDO%EIV0%OX+>tNnFC`Y*EZ^A zp5RvRipqr~Ax~xd;|<+vFMUNYtlNi~>e_!#^}=kn6jDUHLRfdDw;j*bVaqD_u$ZVM zzKXp?nQ?zWl-^;U=_9q#zmL?hkuU&1)cdZ>pyv`9RQ-GH>FCh2(HLFOZ6C%==a1aA zM?=f4bDF`reP(Fi!;`REhsff&{FmY0TW_@Nqg;NTE zwcI|LGya$ZMF=lKc@8*XjsddbA(8RHm@$oakjyp)P7oq%@FSe+Z!VHS00@ z<~F&CB>?kBIZ`VK%qV61VB`>Sc=+RCI&*8VW!IWKTRN%Gh@at>61T?j1|u6oaTZXS zd0?Xk?xo#<0E!_Ozo{suu;CSpM{3(o`oUl>n{HspWImUPq}W8| z2TQ8UK<{$vkVMF#x;QH#T0U-8PbB$jbvjKuCGt}ypNXu#=>e(}g=&rlPdKN$v78o@~efo+2AX;b4Oy^Epw{& zOs2nBqzsD?2y7dqYk!0n!B}fzylq#&Tror?+_c^skdTxpFw>0=d8#>(f_h6lZzb_v zA}(&iqtks&gBdm@dVR#>W_JrWElOl=qXL%BHj@or#t0Mef|d$$)17b?z$UV?gp-a% z?x0@m>Idwc-=@Yl5QR#+b)2l_b}ua6qcBi?iPatdGJw%; z#q&P8O5mFlau%~dt>zHkX3WXd_6sHmMU(mlUcT8*ADElPT^MWwy~_ z{7WA+4@yJAk}!28r79P7*hQHr=cPQ0WHMm&b@+xMoz{fG{jR7Ba&Q1v3bIRxh*gG5wHg8vJUEJ>_C?PSt@yarV4|jm z^oEcbBuhD1Az#!al%S`Rt%358aChBmQFFXNrM~8r&>dGkpsE~|Z1n}Hdno{X&l7Bfbb}m|oyNWS4vJ(af z&zjV+VgWcBa4M_K2CXFk>x^q5JCE~l&$1KGrEIitj7Yda1mUdR=v`38H|RtZWy`%i zr8rDCH@yP__&fHsF&qrxC7{raj9c_2-W1uk7roWTJIr-0CbQ6phDs1SkZ7QFu!x9? zS`brt^Xy!!1n8jn@x&dWKNOxdBMCLf_M zH@uQR!J7P8*ViZxml``hO7rRf3n|1Ni{iUp8H8aykl0C(?1^Aftx;zhL#C1Ld(I+-cj4Njzwq_t69Q^I@&_Ac1 zC0=j(-=cGviMKx||9qrv*$(OdY*MuTvAZ{o{t~aS5We%V30e$o1@NW_p+ntHz>uzS z^wM1R>;FWL-@flb7^}z!XD`F-el~G``Hfmk2Ax`c2xdH(ZQ>oXJ`5VFvH7cPOIA4f}ysqbfP2#vzJf9=W6yX?O~Wy#A?e%NL~&Ra_Smm9Dc^`Fg!b60 zrS{+RY-DOnecxJO@>HA(UZ0TFr#<0yQRBOWJO+KS?q+hCPn)omD{9VD=YrzlY$HCx z4Cm@e)*Y?JW{DC?ijs&@rx<{OVTa?J5>Erx)09U1Kc)oZuz_P~@-hkA7 z5abw|!hm8B_3$qbbIl~oD~mCNRgek3zlt0=H2eG$=H6hkQexe>Dl03wfkv8ZkK*Fu z&8@8_T!oX-8B&md@CR8bwE;MQlL2-E6|mad_T4F0Y|fuJGhOKF*O_~Os&daY!Jzw?K{8}S2-sd&g5CQpesMT zK+8EIK$LK^o@)i{``5`00J zJPU!(;KFI{wILgAL`?m@0ya}wr%^w#Av!ZT5TG%xKUS%{EC$8-uMF($wQizp>ue%EJ~mw`Kg6n0tS*mSQhSi;#We^BewfJl_8A zj!vQC>i0@sYpV72x@aD@z3Huu$;4dmlQmatN!q}!Y}xXzvD}uC3z)3hWauhu@)c5T z-8`FG(+ymk!9p)QuXN`#8RSfnoIRUM{f*Z5b#WANz}=U##p0XyFOVCKp>GDLt1god zjN!c1~^izbw?W4;5?}dsYqwdh$o_0?MmO`J1Yt0U4XksZ316vdc z8CwT?fZV%0^CnC)#Wi`AA{8)}#%Vo4##J$Eqvd|dL>#VA30wCu%DCPD{OnK-?995R zE}2W7u*o_KoU#&yNpO-RF4n#^yKrx^ge{fqbx(idDcRckny}I%W#d(+(V$odQ|GJs zWXjqQ>`mlbnSi`cR@c<`K3XZ|PQ$Za+UHJXD-ehS^Hi{!K)xhpiXl~XjL&$GMzJS1 z{GtMRUI;?9`o10tD(m;JLzAMvOE;1lTZ~5F1q<Uy`-3sTrKgi9yjp` z&q2C;mvZR3&c~Jqd=UFehFIlMvNJI}|aM>JcqQc1I}PDCCrxJIK=6Q{DXjr)S< zxLJek?N40qKX1pDyn#y6VEvoiqselX5*lkBxWQ@&TDNl22(fkXzaGAoJPp`#FCG+` z4Nxk%LXZ-#1`Af=Gaf3}tcx9)I`)_hlCIKcb4|UYvtr|{|D@NbX($(^cvG9Kc#q?A zs_7@es0pet1w?14bpT3qHUpo-_(V4LLEmd)=Lk&xrVR!2HURO6fK@briKEND<7T`xK80Z=YV8M& zmtw6>7nk^nKk`gV)${fLKyw~kz5EB7Q|eG5n5@D?CRe-03~mx*NFad$y6q_lBt**{ z<3>a%{0D7~qXzvUGli7^*`xjQHl-+7@=%w8`sst#_&}i6JPvd6qy;S9fg3~iGAOhK zIGGY1S)9Z9-J2YK4tgb~2q|K<=>2yrC{HnFx*-W1Gu<|XFw>@+J!E82#BO1$dbo_i zs2(1RxX%-J*6iB&B(!@MZZi%;q>Jz&2e->6Y{uici#->b^7tp<0a>+1@?M?X0VZ$- zjbqj-QM$s_o2d@ji-&Ezu*1mpY3>o(2J{~dWC8z$|J;Vf6#JfoG{PGKb39HujyB9? z2bwji$zFHQ1&86Rz6e=+e4i=;`v`ZNxOd>Edxe9|?i9F@(PY0p5BB)Eo)mvbMg3__ z@m;8vlL%j|uz48Zt@?8uoqtG!^iQZ+dQ+YtL9?e7*wO^)ABa#=9FET2QCOo_zt!eV z^)c6o?AXMh*=1&3TqOswaWHyIU29v%3G)3Rv+9ROR z?*JTdXABGSgQsGhIy6upg^1l25uNq6i3GVZEY(tI@9J#&asp}0?4aRtF}1hxM8eqY z{|ydx>)2Ihv}b#&+3y5vQHhXR?#-7YVm~0|C-rpSz;N=gasZNa=c%E-HUxZwf;oo4 zAbc`#p|?2%gL=3SA@QWK^-OZH?m?O?J)N$s{N|DdMi7)~x^_6n9of-$*<$!y_L&0n zqW?`bKT>UiuDR6%O6dDz{#7biCPcVwotprJDjeSSasz}E8Y09Fpm8G8CzXG=a6t-# zG1f=ON3U;zJTmR@SsbAF6Hx}^ADWe{6lzlJjf5=ZgEXvCm)+5KjvO^59Ft$rd&xW}(59KRBrXYvM1-wq(y_D3mSu z*pg$lVO=m?(@?QlkWtHp0L`F21&Ad82|1ov%1ez%paw8*MMXy=C9bI>bDaOi6Ea5cNWDg6_ny1@J)}$u9#65{fC-;s>&}|kt zR*V~mstUZ!9%)uaH2Y)W1sK~w^GwPIhpWAFf*#e~80w3$loNc5L>>#NyPMW#y}%PN zFT~!NRp`X=jHw*Jl$#X^=032!3W_78 z39OFQX8EU30P)AgfGX1fL@B(Du-nbY;YdhgOp$pwsNI^}hzMeeK|c{<_dH_E1R|sH z?^b!CqU9uy?hrH`JaSNo+#3W2(J=&+Uv~*<0(HCvn_2KE+l~S?!1jd^9oVpEOz>HW zGl9~5E2RErBoqLQDbOM2his|-z+ePpBzfWu;^&qqK#?$E9$^ho;Qnz3+lwcZb?${Fx zzI1zVeO^(QT9pr4$cotBQX9k!fThAuN232dFhg+RlUP_#X&Uq$s9-+lt>9pURdxVn zI#qTM%1VR(O%aJR!f(+T0hWrx(Q|_?Q}SmoX}ct7?f(Ngy*Bip$mvtR1B`kn+72(n z6&8DHXxuYhsOp9;EUCKom}PD|Tvb?g;ccU)lwXR^)>=f^Hnb+|&F{x`xvcJ>GWy~g zIf;e4!c_!-S-vl4r6Vv}s6fW#FVXS#c<)<&e7Lv3n$QH%b80>@@CM-!Yjvqud~ z97aYgmj(?&QeL&Mln7JsyN^g-Wv4f1E$(=&t!V(i^Fegmn;Y;-Y-kg@vEU54#gsnm z?OR}x%C>{*gZI#aR?RH|K%z%u8$6v5&9mejmgS*zzEptXWC&Gd9X^a1R0z-tZF5eoX@j>nT~P^Ebr{f5;Qi2zb@4Rm7kZ~r z#n7fyuj{zCHeRjR7=f(em;hsE^2J@^@u>BBVlN~Anx{6~ij?5H#s9P1XCD>lVLx?9 zl}}Lt(%j#qyjtgSB;J2Gf#}vtM}&4ue`RX5nPGuEki7ZunRCo2O=p&aI1^*Py*EZk);=z=o~NJhzE1TQ|2dtBo1l%3MTqdpqW%t+Z!@ckhspu@x z=pUxNq-E8PppII^H6|R<*_+L)6bSn+&4;xC#qz}J3HcBMfo4F7dW{&s{^H@#f)0O? z1PyBjlF|v*OIl}-q5%UOK6Ss}}qMsL0RLE>mJTg=n zBjA}L`0MX}nQ~z>*USW(tJ+Ex+rtx>5en{XU`?i$3z5vUmb^cbBSA>n=bF0H=2ZL| zDB15Yq7HM|MW}9+vzlcoC)L~+uKQ@ns~EKi2QF*$B&01d^Ee(RGfNZ>wU3M6FHp{nfKRl{MtZ2Ya5q1IeS zkPY5+J6@3(bF)1q$0}{4C%NJHC5}(|K$9c%=1Nmu1~db;o>peK7Nwk8ZJf}6H5se^ zBrnyWmT-g3m9(nwq^U@PN3$_D;Ps2>6Scp}TM=wYDW5a^*3v{k0Fr^qQ$f!>k(LF0 zj`r5lK%W5Q6xtkv9WA&_s@WmzDYanh?%jjY866JAd4VWeS6U}5nceBM_xGfmA)UAm z7F3c=It2-B6SpO@>iP070-v>=5Q}oz!6cvB4{e58+_BWGC;37UkPO%IkaRyoiCDFn zXviNJou+eYNqR6}6Xg(_yMARo8P6Z{QAb5_Re(2<*zW4PWF!1~!{$K#^YuPiy}pRX zin085&KyKbyaYYunXR|1>;P6IefvS8grs1dUF~S7%Tn7f|Dh#IvLy2*@#S`xHt>4e zW)P&-oMMWM#*sj>gw9s-oe2Do-ZH{ zJL_82Ts{q{k}*3oyZiFZggOt?6CC}VdE}^FszS7xBS@mvK@dJVf8~Gx&2&vTU8Mmr0eVuXYi}4Umt1!fDQ5>P^JH@w7*I~UV!Q_o zJSa7Ok$9+x&2Jr-jO%M|G#yx$F>ybi2ubHho3cx~>Xcp!4{bCQX#Q7kv^X4&!xQty zDxHDThm4h0Ypk_KYs0lh$w2i#++Cd!I9$zwY`G*bp3_NE4Mitf4I=j**HTs8f?;A^cy*2~GVvkc8?xOckow&ZGg9HhN4 zc-XQPJY+)lB*e-rqlyV=czCh>A~ge~KWRk@rlK#hM@)rc;s!1F-3?9wMvR9S&yyu4$GX8tQ=b}MV&YC}`{c@+ zY-sI%v9KfE{^|UQgBb5<@t}p2hq^?*)>4bq<&N0;<_JuVd%Cc9aKm}0&REb{Tbjm1 zn5w#)O$jQ9|H#kvEkM~UsXuv%{d}i>L(4hgf6Dq^R;*62sMxlIH0pSsaxQlmlUy4l zJ}mu~YT;>Cp%&vc_DQcV3k@JW==0ii8$AH!YY$n7e^}A_zsdN*Z$Oji2!-2?vO@L` zfNc74gu(`RFzrku-!Jd%T7kc^4<}`r#EX=+vt3^Qyf4yWF!o@`RT2-+BO<^RhWQZ{ z+Z9MRioHcFA)iOS{ILlY_bHKhm1mO?L!UQ6KADs;IYLsQ7#?g}DEq$SgprejoKUB; zK`WWpINNlyq(u$PYne{-eF|CajrQpun;%T>>~u?~on&rGN4w>fb>FR3qisnWjWvPd zP3#o2Enw!36!Ps4w5d25`*#wYN*slly3w>T-lGN6l^%GjD6746Vljj+Hhq#{MmMRxN6lvFbK{CJ-)RSfz4 z`!}`X+4;rAUkEL`+qGV_N`oiS(be7jbZG5DSFv}8vlc7*@3~Ncc{|nk>symF7cP~1 z21Mb(F8mQ{e|0gnKZy1>vga3?DVG~Q2D9%|7#fk;M0<3t>E5bB(R=uab9a9vgy#-566UOlgeGtTIh0H+?mrP@qb_sF{}0`yY5s zkPROX8utGmL?*^)o~7z519%H&KxWRHD8t3k8AR0EwYMH*VvpvEPha>zYuJOBr$Un> zSiX^M0T++_ms__i(`ed#4B>5kV*_Bb}Vo`CNC6me_ z3kCwSOglLe#PX$pgm1QLUq!cl+aArEnj`au z_WKn|56b2+jwtYxTZfmw%DK%pf`SdZ4i=wvm#?LS0(ABu@8_y+z1aNcB~SCdfVT@)n_$=uP#2f(~|LOm^# z4H9+6pi>}S1p*!7zNL`57|f^jdevFNdxx)lweys)PQMxxAll4sc7V4W#4^?5hUC9( z`^n+JK}_1;B}BIu>bnptj0uH)E83jMFW4&E2N~nh_vLmUF1p`y=CWL6Uqab$UGqkH zZKZy-H9;}~1OmYU;l!x3XFfWgeRn<}YRuz0jx*}$y+vvTBSyL-8aIixz{`xV$0Xp2 zO5WT8mNYYGfQ*simST;aW7m%+0o8*y&+hA2gS{}Zej4YPVBVvkpAtCns;u^*&88OL zM4&u+f`oVHvV`@yV?@Il$3fA_#1Q_B4DA=HCA3!s{SnpG_gr=z5-hw5mXpXLye>p} z`>?*}bs2IgHnl*!j{JLFI9rXg!#^rGc+UdfO{N zE5?$i@N4#Y#0a#Ip@~HJntRgU+lH-P;w!LpnFCzX)~q825r z^$}GCZi~Nc)FTRU*Oe`LQ-b;~vw2CV=ZV^Zgen!dW#^Lpid!hv zZ2`0M{KtO0#XzFz)(S#n49DLgJ29X)_HecAt29Jfu%Fc7o`UoZNghiB9b|4nT~c{Y9e- zSJXUp1@lh5Feo@bu=V0G?gL{ht|%ee-*8aOe2d!L;so2oP>h^m#E2o?St1=)Lgm#q z7rp}XTabcmZR#X^L1M2k+3;v91LLD(&KJWu5i%+mVWyN)B%hASCxpZptLp!S>o(H8 zCbwFM8z6Z}7HPzTg zK!lr#kLE<-c!^AqK{Po_i#ekQoHV1?+tDedQx>Aof3x5`g3R-a^0GJ*yKk_}Y0i7>o0rvgp;YW0wje@y@O zxvlB@Z>jL<-U$Md_!EpXHXw4{n0X`ItPT;Dz9~MacTKpN#m> z>We~eN(|~5yZb$)!HGYeZLNJ9oSpKx?34k1EIOVo2Sdu_oiNOAy|TE4ba6HNQHQ3| zoA5)vW;bxHbk8FI8PHOuSxtYRA%){fxIiK|aXyudt))CcalFdtkTAM~$D+rNHq1#@ zBxxWKEwkSn*c*6Uofe^BT|r}B%Bf4zbB zVSi1YAy}uX^#0*2oSDK_WItJKD0H??)m76j-hq&scO2h1k6FAz{=C;6t1S6Xog}qStykZA%N00oAu>4J$A#Yp z!xO(AbI@v|*OVs?+&XB`xcKQ52QW*<7I(qzsm+ukAt!CwUp&E?DMAbV?Ew|WX<)w7 zZTA}FbdJGo*&1;9S6VwwYm_bJS=RpXT>WyRP3-^1`^jLlB_$*z1do&E`k`SvWSKL0 zsg%58{?S>#&2&mlF5D6x-QH}U`WCEZrbuU#fPu9&dz`;@ZHF*DfWdzmn&86LcY3A& zLr)U+!o5AtR>RgHP?J}}y;bm><}zeuwwOQ2)>vzdMruqk5HP8dhs#?y_rG}h%Ah!& zuj>Q?1Shx?B)GeipuyeU-CZ|8fZ!pxySqD!yD#qUEY1Rpy!`%8)%#_urf0VM*3|9p zJE!kGrw%#61Ft*j&*TU8A0I@QR9JqkOyXE;k7TCQYkHeC8e`ehaqQ}jUl=6n?=u+T z`DI~A!?d2MW|Np)+vo)04(J-PPvp9`@}>H9#X1Tv5Air&ovXaMZ;r*Ussy+!7zl*% z7#&{naA$BFxb?g>I(Vrpflzm*u9UFj;eK6mp_FtUOxTBH?;lPPUDjJ7mSIl>^WAc4 z2CHZuvt!|v_Z4-G)CJsW@2mf#m(Z8?dxvOqm*w5D*k}X~pEA%Aw0D0n3+&s*>W8$U zXD={~$BPuX^#3E&Ht9Bk`H{F?CDKfUbI{e5s82~)YulgxG!w(8?(gN{pps_aVPeo6 zF+MgqB@I-U2@&-wD=FQt^-QIz+AM+o+aF03i!beu>!pHkZ!7Inm$YTgvAdi0 z%Vr>7t6XVRAvf{A6cS&*lBi**Md0)xm6QOKM}xJxVe%+uKmEXl`wm3r96fEAr|EW<3vP^mXN0P}n7y?ckP2eR{4L zt=;#>eLI&Z=F7aKyIOGD@t`=I)2zguAoS|XmVLA`BzAVDS>l6yB-OQUYa4h#-zoiQ zoZ$uFjOWSd3kr@BL&oR1@$BFi9-zEa|0NPE3F*GrgGOa_1cC@~*(dGx3K0!4G|&1V zJO{#1)M5~c5#M0Ym#LfJMnbXBW@a}@Y@jC%BCc$zL->J(abvv+ni*D!=dh9{*5jvaw{;K>=R8p?p$ACw-mQqJ# zk=a%p--_|N@QaK?M*nM}T9fe>PN%AAf^zZ$xInK~MwU8P_HD{9+Z9vS@#yQxDi4S6 zisYolVM8{zE#(2fKS7G_@xRT3#@W>)AJoFIH)~;lH8}#vZixf!=2$k87rY)FxGwXC zd)YL9=1dyxaS4*m*K*>-YJL<s z2SRv}(Nk*Z(Ad!cqr0x2`JCX%&W|>84_F*6u)@I3iINZm)@5u)JZ`qT{c35%GfAKZ z9C+4*?F(6R_QZvdY<0V2y^ctbHw)KPHbmmj8)p8@aGU#S3Wp_{YO#X(J5}y@mQXsE z?LxvPYa3>!@n%x?(p7VLxBgDV&Gc!vIrX~T(6*b&n9O$Gh>n*+y(+NYxQ`pH9|x73 zp3MwQxG+KZBeDO*?$uf#b=_jNJnrMorobb6&(EY$EEfIAkc7!ayx84Ut;TChwAG9j zJ#KbPc`c-lrDgxsr8bD5b`ez_G4mSj5(lRAh)4e2U0?lHaTh3NpmeK?gxctoGmH7| z{OskP1C9SoQ`>o`*Uyb;vwE{1jFs$wSgpk|TvvXDcKIV;16^oyk?<=7h!T0~bzcgRxHiW%OD+{R2jojPFJ=e`!)$lunlyiGC zFvmjn7}p%&5_|9*8zUK26?_NmDf<$_Nmx2i`3#>E3Q_4{yEbg}w5i!E6xJe$bVI}C z1ibcn+G;W1UgHftNGTSa{cnC{L2VIMo6isS%=?P2w?{|&&+fYyr#Jvr!?M_C1d=MV zjJ&R~+Wd}Wf;E4$YzDVglrnHtKy}X6sCb}c~RTwx*P5TX5O`)FlV&1P;?b-*U-i!qURJZ^yGV<1;cd=}3#w`3HG1tKvbAlYi z0_&*-H9`NX;cLQaHGwcvd7AwM|5MSU`@v*EoJ2^ap8wNr<$&ddLa}}@i?`XjhQXgU z`EoP2%#_Cgs!(Ch_>>~h$LUYp`>6x$DL3AvGp2gHOMTfNLMu~g!7tC7{v;1T3cxw*E)_Dg{98r~+Uv{4e*Mx+ zRJtqK#K!2$*UGtMzVhN}(R|5}kXtn@W?6?6e7+>bO(BW5V$%GZYuWR|0KTjf@6Db4 zX73KAbWfzF8TXy@DH?XaIXiG<05aUWiQC6h7SjW&2cGoIDF`Aj5w~%s4?y8)Z(VXr z(!%Ajl%=m1%=Yct?Ov23T-HPVDP<4klmzXUSnOGTNuEzUQ{HI&qG_>vyAVmtEZvE{ zO;6_RLE&*zS19QDeC=GFEp)lOXgY68PnXDG3vAh(cNbkC@s>KCx)doOXkJo=@%AYb z@HsF^TqS;I1YYU%3UbvSXWeH5M^-&1BoM;^Wy#5WAhwN>qZ+Fk7|NkO?U$jB^-k|+ z#ChE|_S*MvIf40%K5-x_Lx{&x07Kuf0WYX_U?RJ4s$MW)`1KoPA(2pwb^9af<6_^BdhJ5q`7gvtD?3fxe&={V$nF&!4=3i&rZ= z>q(kkKHYh1_D=B!h04BIgy64qe3|{orF@5q^h_J%Up{&7%l)iu(dRQn`&iQP84bNV zHN~g2b0zp!4MxXJ(QjH=A~i3#Fsv;(Dy}P-(pJy&o6nyq`*kOiWeb*mDmLeN!Efg? z9qUYm^D#s+CI>ZVlSa#g$Q&wamBjx}z}}gpfz6rq!6M#{9a66{8_Ih0YZ{Um8-e`{ zxlTC24W#Ge(?&$38?#%?qm86P!}_6y)OMyloyf=j^XF(kWeH{>^Jv@;(SHPDTAu<$ zVsh_6fSc)NU2irSxN<4g8|@GR8ob%fRPR!K#P}tCcFdnHPlkZ(H_9q=qW`W&KW(Zk z)V8Q=u@0oxKtsPMJBg9_rz`k(<*xrDk+aREbEBNT{=*_t?h^PQA^IC+IZyhBEJ;VK~5rIAd~J z6Sbq^fi6w`8g|TJBc&bAau}m+0oZ1~|IC-?_w-EW%(SX>h}UDTQQKnf^+m1pXo0K# zF<-r(hhPkAMk1{B_%gDjf1k%*dnIAFcAGo*1{wI|Kg#eYLQ`+V;FEr99L5fAo{H^^ z!HwpQ7IH? ze8%VzifJDM322X!==1zn-Z-+9>B{-2SXj7Mw!<8>5+_ z5y5~EYjDZ6pk#i_Bfq4`h@ChXMy!Uf+w>G(iC*LuURq_r^qub!>U<{u zK%hc13ML~_Nlr={%oPh`O_k#%Mq6{>c@JDN6{M7uFi;6OzkdC?m3_3rG{#VZYwx>A ze;$iX#7FfTL-fyT=<4~PMg6wR{@UA5<%H*M+5DND#Z%XdSK8~6Ih|Yh&Dqxo$-#^N z>{#3{0{kQ8b^3)xqkus~6ggu>B_JU1=D4x3@$MM$DJehSwN9+&N@ga&Ss0?9$@_s3 z6Bd2jPJ4iG&C5ykcY_*rqQ~|{i$S@Ir*A!3Nbh~_c(|y-7Q-T-T`n2|oOk-eMSrMf zot@d$)z`BTMyE{hH5VmLOaC`|o*8#KlZX@jG48SVN41`>GNLr^t)govJe5uSOia=5 zMkNq(<0NEeW^QsbF`XZ{DgF1i9mD=}54{|Y7T>o`>+1f$?upTcP*39iuWMax*y78dulfSy*v&&YT3UPnlqCFy!2WDz3Hou zS@!yH=qwA-c1{)00N^caD|c%HJ8l&X>hjI>pY7dZDqE1_0=nTK<$ci=whK__7683U35o~GTL)nP%Ka&rm_GB z1!E)i3wlHM+q0yw&J7;7?=n4h`8rpWu+R#qEHCmV3{vC#h#Amobm-z3XmcJm({Aaj zqv_@xFjZZl)t&~6ch-U=T&&|MiF)>Ptmuc?oBv-!O}}2xp~`PvA|G08d8Zn}`(<^F zjG~zW`+%>bPwO$syN5BDtPW3JnxYf3y!PP51F-WT5q%{=^>VEB8P z|ILR@*-8z=rqSapem2Hw>D{t zlP?^%%At_spD#%8>`fiLX?J1O6izmSd&vf7S_wQf4Hwt*8UkpVLg176FiCj*xK>s- zuScKljpPXd*(M{;mOnN`eV4|{UdKqyoDO^gL#Jp-bWRZ^czu7RLch3rV&k9?*!i0? zi4ZgMIhDv zZZk(Lt8WYH=tGe4<~a#KAz5tasdq=Rj3TB3&dbf_IBA~e@{9CY*x(4l&K@HF9Ul)s z9a>_gkz%j+ZO;(C%|luvA&$e>4u|Zux6Y3^og7O~C;S0|dS9K8F2;1T6$&-ENURTI za#DQ5QQ)RWnR zNydLwIkNfexG3nJN%TfJg({?ub7a#0dpY&S_4lkn2~_1}1L3hZdU6TQGCK)M$QDBd zK$bsQlH~Dw#!0lbB16cx`{fL~=Qif2(@gkM_eX?c-KZ_L);qXtV8#?uYuf|SrfS`u zU^#C`wtBDJI-re(hyFrVuh&K!59C;~(#P_9RjQ&XK~CdU;W>kQt0b zP5}oFK)4z67X=qJzNHSVn9c`MF;&f`%xxWLVqRWyA~wF3buViaT_RqIfvTWGZ*a&q z=es_F=aB7A2h$!b1<;FWL0ZOO^da&-~u{uLPA2K;N^WN%fcci zmQzxi{{JgxvSj{`O6C_eX2IF+U&<4F*|n8gOKN|YjrXFgk=NnpzvY$Dlwpop^(Wq0 z+$WEv3bzUUaJH0oX_wXBEwtr6);6&zamHDxxbZF!(|zBbenbX|a5Ts7{4u}z7wgRL z$}D30=`o+%em}7&I=oCZjK5XHeu%kGPtcrp+c0GSZ)|gKyIu5;Nb>Jh|Hrqjq2!C* zNp2=&Ncq5Ljc5CBzMDP=_MX^Pw||btT+thc44rXH7YO?!6+I}QpOFTeJHrg#R~Y|M ze$1eD*6XV(ucPW)l1F$0@Mab_Q{=ly%^07}h{Y!4=3zj)CJ6B}nJ6^W$!Kh-U%{?PP=p}4AL}_$ zPfsP;m&4Cn+gD2QB*z=-^JU9l-umH1c7AFZ6oC{1s@4yt-Dm2Xm^UcNA%b48ZfW&lMUM{QYU!#c(yXWWnMfdNai&~XlDm*IT z+Ey0ge_Q=C^ytpVjbFbsSGAdZo^=_`*du|Id!J1pipu;=qJ1N`giBKo|BJbaOw)nKg5pu`PvPu$@;Wz?IY=J2bQ%D=e zOlU4kN1vr^a;@JLg5YK~T3at_2_6SIge$4u=;2L+2!?c zjJOR8s&-@q!PJc+T8`Mb_=a{vLTuJ~ltFRp8cZPQ`C&u#)_OL=7rZGMG~w=!H#k9k zb(X{5U%lKdJ(L(4Cvnx=?eQsVb!2venYdM2W`6p^95)Abf^9pk$O|QzAdXo4f<8gl z9YvCDsRt2fDA@h{Z1aKqO2IcUKztg~I&@U=7R&`G*mlCz@_<;`APDfanRB;i@~Mj= z+FIOzlxen6FAfx<@6$rfQ3yeqZSN%(x5~r{y4zej0CC+DKD7E>1{wtcCm&ccK@u_0 ze}jEPbFF>r6t+K66LmdM6kh(4Y{=9EX*Sh!^TyKpq0U%QDU!$JP9n8*m55y2iG^zG z6CbceBP_B%UW!iocjPx)7F`jpMy{~oV@n+@AKfd9ZfC9N$VcfVD1cL`lWmmTy=m_G?F*EU zR!N0Sg`sw9OG#UIFI+#F2TC{FFQHfcY$}AK`ikVV1Xg^xe^XCHjkxTTfpfU-d<=V6uk9wju|%;#xT>bx^%tyM7p z7levF+m_7Rxbc%XG!upODcf1;uT$gvtap&2+qK&&O#;nDodoHSjP0`=+PEzGLtKl& z29od9d{>dQI6BUo$P=X`_pE(~X9EW%Bl-|jpePZKe-Vp`jvNl-VGO1%YSBP9UQ4?z z!-m%uokm9eF3QEw8&g0C7-u{~fH0D)1rDNUsF2JJIu{&Ww6)8ZK~K zEK~rFL+#?D1xGkFU`~_qg@0Vp7CqTo9}jO}R}noYU>03am?L=Ti9Q+V1@y%~w=L&= z8lXUHV@Bx3u(-z-fVM>ZMh$6#UQ1ZH*`c z4@1OB0wmbX{Kv!7_wGNTuJmL2tpI_#_uN)Jhd`+m;SvH+q_7_O2Qqc)VitM{>WwMy z=vxs(<5o~iW}%=nWY$EslI$$9SRgY2gaqfgZJ;P3BbDv%CwTMuPWVXo0014<4Gs?| zKg!2Jyi|_&KA?!$fzPv=?m(Q8L27I)=WE{KlB}k^PAEbRN16HS=O*INjMe*cd8HEP z7Qp1NSlHNl?}r)RD4bAPkn@$2J1i(~L8+$exOYlmW^*%3BUWQ)7ozZs%d&6|P=giy z{BNmMRw8*NK`>j}`XqMu>Qy;~Ti(oLrO(aPSHJ5;0*e4F`AOr?_rj4+Yri$$Cq(?n zRb=bx@y0!8M2)M{j1L7E5GDPQlo!kMX8m6_pRs!a>l^lWQ!s&GMyP?v`Teoyz61KRj=-}rHZRfSghq4v-fO z2@YJldgzTBZEOlB#xU{cYP_ira-(B`N|9njHga#4rfQZl#8I+`BJeKqHou~lAAEls zQFK*2;!BDq8NbccAO?Sa>b$|JzQgb3c-LR!aKjGPd35o)Gev=JDpYvze4!NCIhD3= zX^u^+J`AiMeu)1gHONtynAPcu!0IuhZKn`&G*Yg`>P3srr87X}>%ua3YjNKplyN2d zI_RMuZ*SOHGL(JzI+$wbbBy}h5Nmv9t}DfQ5XdY}*nG3|%u9UValV>9YOJHC4$qpO z6e1UfMplhLaJNNW6q`0&$Fab!Iign|6AKbp3GOrO7_tbf?8ay^=y=i66n%RQ8p1RD zdU8iW!j@V~`$L@D?w8M3@&z?)@8%xNpD<6z{=3PwUl=9@0K-+*=gWGR(8F@EW-XR1ZUH~KO+>Qo1Bf7sG5 zVA{DaO?8e=DpMNtz-tb5^J;yA-+ma8hokVNoH#0*dPsL_BTW8+q z`^ukgLuhlOuOZqu&Bpco+hC+mwsZU3Jjz~&LoMZ5G?Tkw50mHPsGndz&oEgI9XIv6 zdHo6{dyN&T#MNssHcDSPVtF1OP0(oomO-5T2jz(sRL3-mzTowWS1i<$QndSltEunKWzg49+2WzC@^Jpp*y`zU!*R-nQIhT zs;&(44;f~&p4CcZ3^(jJJYg~j6Nu)gLA@_thqT-%u*#v#MgAONP#w1)<*VpP|B_9& zKD)vT_03{wFRi_#b{Zq;r@4!luOs~2qH;jx>D;MO&X;rpucaOxS*7tfD!eD>TCcZp z*{?Pgdfyy~OG$xRT`dvQ|A%Xgey?6MWyQa|y%n!se9t)!DK4f-PTeR~qi+<|hAnFl z)72umps?JVaZwoHQ-=FJqfjUsziDx5zEWUX8>l&vgi>z!rBX+t_utjpe9lK$ORXBd zr|Gf)jF>(V4%cJ`&)JhtU5{qWH_3L}EkI;*NIFVqJr|_&nqeV_zoDxf)_scl@6p2c zvA)Yswiy4p7^j+r_md;py1>fIQxy27pnX%!Gc$xS_ORY&F3QHv`FA`a&dI+6O z@Ipn*hvE0zURoe&l`4-vTVJ0VD}&5875(rJ3x=)u;qPsw3U2^)*Nd`!U#bmw{ViW( z1KmMOf)I0VZKe?BD};M{*+OK%^#DP9v+SOhK%@MxSZx+MHlp9=Ku+Q;uIlAY! zvs?O&mj@V))w{yq+iV;Uo84^%Z_mut&t^>CCKrP0dB?Nm=>EPQ1r8;mv$d2N(XP&^ zP|Solbmczb{d$(5 z&@C@M;YFmLb23YJx-oOlO#u9N^j58$PB>jXrh#%*JD#T_PCcX53J$2WQldU=-`roh zQ>4MWkNElX)}+A2GmXJJ@c9nyiAYsy?Cz_ZrTlTZ3yN-$^Mo~}C9fyg5d97Lp!OjdnOC2y6r+L~>Lm(h&Cv^6WUub7sXmIUPkpx1 zJb^oQ*5zOmXuub2$e_I%XjlEF7t3FORedhfZ7cEWb6fVdeTvgV=@^w4@`6_+0ztiJ zUz7EpP65yDvm)b>bZZ|iIEQHq_`=xU02H^HwcmDEHDvKUq&U~-4Qp|SqRV|t-UMO6 zC=}MGc1i8T$;@X03Z|kLG8sB-S)Pzp8N|8!>x0|aQhrxP*tZL!(2{QJwTYy6Er6DeMbvbMwg@i=$yj_@v{|@ zDz{b*woYyS@k(#3X*qB!C%<6PYA`TToF`0OQ0ci#SE=n zust($U@#Wl-f{q@z!&bF`e2Fs|*1?rHKl1SK-QNYQW57hPm%?suQyz4gY-{wJbr z{gJDR42-k2XH%V8TWFEuz`Ag7;5AWe*siZnE%6jc{Iw{uvqL_ zmt&|#T~e*4;o@yZgV-k@)krY|(nf+sCJKB&Sq&E((HlZO*J}vpe*!8Qru;HPmWy)U zBYpSp>T{7D4{Dl0Z%-*h29JO24@p`Q{zF-}d$G_C+rRrH5$(@4%lRBiu}8|0X1Mow zxHoG94c4=B0kR?u4d$~=hN+)_p|J#`&3D^B&qDCW_@fT=>$xnw6#DVonvtj*HH;4P^-O;cq@e5)ubLxHY|wYg6fB)?sGzm@xUK%mKH>c`)oQtc5(+< zHox=^pF28l%G!=>K);ysR~H(buKc!^%Xk+0%T~VGhwB+Zh67R8Yx6W)(sx%P(;x4c z7$}eIC9@Y3dqM0nAc`Gxwl5Tz$VwYP@yI{q+gs&`jFdUr!!v~ALL7H8%VR9rLFU578F?E}wji9S9sU6&> z*NHM8bM<^AB}c#T7L<3UW8USEmhn6+!Rv0 z%Qydv(6)fJ)Nyy5k)Jb(Y}gs`Dq^1txx#THxgoVq4H>FfU!_0-j{7rav?J>}{A;=B zoPS<$B|O)47j1uFq^^gPn8`y$Q0L#H$|I6y?FwDJpy#LS2V{Iw_@DQ@{h&cFR?1g* zAD`S~mnSAYQV&-3jc!@zp>sNU6eL3h26IG`S>7w2KALR$bZOl(j`2hl7(=s4fRDi( zwGKL(_Tmjb-vmZFi%wrCkVDn>mx>jy6hNVKx5HZyGTb`K$v6%-meF5bbJY-ZS8g^~ z&()tcnQJ}1mX(jW^5Wv3!o$Josm(nhBo+GqWQ=kJAoDr9)7(~QenSgbc)WjgwPBnf zSpI+R`KOu&OqnHWjTH=w?{LOZI<$0A8`=;^Vc2ah_s%$Fqx(sfJ(A`>+ z(KujAwgG{T$!}488ciRay-fhnm_mXcFKoXXbzo$Vy&JI>Fca(WtyiNrMW%ioji1CI zB#0B1l>Zv}J3zn%Ex(@99jYdNV}Pr*)@=kXsTfMk@K+7?C3i5#OBV6LT%=#QTIe9L zq~iXaNv4GzXa5p)31$z3Ngs>|@n{Y^z@6wF{RPH2OG1r!QvkzJl+JyMpa9j4oY)J6 z4S+Wr6s_9)b-PhwMNt0bwepgmJ3f`8zIM-?=}RBJlO&_&(RuXvLX4#St>s<0K% z`HXiIN7n!X0_UyQ1DDB(Enl+_DuRBD3t~nauqlmh#xjM&iQU#UILacwZF=cckVx1t z6lB1Qy>!3N*o31kS*=ZSgo!)l$YFWu-oyoX&~K#n#MEy8DAW$w1>+{xUZhN>%i3#= zWpHqsswEHNPde^Jnk$@1WNV6q9yHcJ+77ztP1c2JdG|sm_W+8X6kA7!&yKe6>N?H& zG8<$hbQur3w0O2d82YQ@hV>fQM3akWBibFp70jvW7PV(Bp-J|uhEJv%j6g6ngvAn1 zS9jcMGqvVnGeyYuTdD=ZdXfJ`!ogvRfESW&J!8n>H>6nO!E{9s-KArfmn#E7qV7zU zO8|{D-y?$5t}Vy)JO3=`3d^yh^&gPu>`?q|+b#WxJZ?E!@H=l+}bY3%csTY_^y0sBf*lwqv4K#mIyMrd336N>z7v_Ni z@IKg+o`smxg{8K+t@N0&V}Q#;bov6++Qd`xp+h4ih^AQM|9o!BMRBr}ZJc zb_>@!SfbD6(p&!>4u8HPOCYv1eW~knBbW`mIo2I%qES^jGMk1|6*_X@u6RiQ?YfW- z8Sv{Jk_%`@A`>L)V~~Zi%fq$Q@ztM;ZTP>zpoNL{(!{3{vIdVpfW5O_Af;^z=s35q z*FJ&9@;PsF$JmhV>5boJZO5`o$Y|okGAhuuLe!Ujo-|bEMf+-Tojs9eV~*8Vl*Pbc zC0SGb7iidUY|{iB>k)yAmRL>QU_~T9_MPXBQ}zLQlo*DK$eqky^xv|uE21-+6r!#R zNxeZ#|J$8%R9byfZR_dH*-JMyGBIDFi;MF^W&D_1}L3u`KDvSz9FJ7v5l^lU1Ltd>-Cni?5D z*EZV*uoMl{Xq%ZC@|8?`@7o7s4%RhVN#!H?G#Jc8Wqpy~HXvN)uN1)Qipl$NLoLog zH@8rAu-4V=wao;}kp=FPjfAqQzta!!by{;zb+(3`W&ERQlJ|{{rLbUin6uLmDTB$t zhcNC8B=eZCCEKojBHoV7exoP~pj5DbP2ml@3spbja@a^aN;0+r%+#H+{CwBSHn;=C zABII69W2JO>%uYnB1vg3<;9FGKdJ%#@oA_c`MFRk!i#y4SCJ1>&SkHtE=ErS5E1nYuZCV= zD_3(h_7V5LwPk~tHTeumq%P-@H`(a75J(Ab;oB!9!YbBB;n^XFbaM~q%Ap-y>2{~l zKXDPRClBc6o=s823|A!n(`irBYmUoM^8jjkx$LwP3UE@ z#r!DUl4?C~ttX2GdgYrC-cG#j@~AG`#?c9slS? ze^SV+_`k@w-~Pe_G2=Ha%X>cl?(^eK%Q~jpv#J@$#czE(Enq5B81@O|Xy6wb>J`hrbW9Y-nQ&>YJLqRbbXg6t z=}j}yQWIC|7U}x&Y9W(4YN6c^DR{&%kJ|GTn)iD0LK{A`nHCa!i2qR0ZeRncTdMov z@YdyhwRZC>t4t^OmS`n}f1|7ys;0eSwB75N{&)~}A&mz9aS^y~tvn-kKfzFEv;s(p zoX;yIQ*z5RFglFL)N^78p}Dz>%+rOye=3tTXpf2Nd#*bXSRivzZPqmm^(c+>oN0aYHoXOPIgAPU=bYAM;H0tmqS$JM=s7+7f(o}*v6eQ>_j$q zo$4!(os62U8dltOwrS7cK;u!4(=HPmTU>S78ox7WG0~f0FE&GPdC!$IB~m0%YoYC& zz>S!?Q2%hP^1f%QG(RC0p!{p6&eYT9>QV{M`>nk;ePT&pVzP32=Xf`~ zDq%Nk)xlzdXLIX(SM}&Bg6V0tD2;|+zDKIEfJ9KnuMmC?Z7!Dya-KL!mhCrrb<~4m zg&nB*2>eA$!yX0#rhQUa9PGJZ*_;TL{fGF+MqVi98k(yX6y`sfw78nXvj5fg*#c!f zdsAHUJ9gzs&nAGFf~$n(2AC~>*xiAyi+l)xMtE*MrCgv2ImdI6L9` z!%rL*@c*Os4EEk9Ax(O-4O2j^{OL(<$lQ9zY0OyigzVdvsFV_oB{GymyZ3O(q!$bwVXthe|M3+ zbg8!MHEtqu_Y^}32zqklk)=}a&9#|S!l`wTw z%>s?A#ugtz+w^`aYTlf-+`#W*`d9WRyqv)Xup$Q;RfMRgExb>Q-kk2)2le~A);o!- z8_s_SI^nOc_(uVO+R3l)=St%v}&z@XKAcrpAq%s8Qck$ykzg-DoN!&aS9?Wm5osPPMO zG822?(Tjb?5%r&zK2giq8FWyfY1|6HFJWF3 z$yQRr5stHlo-iMvq75pd$BYA)8!m-53_C)0t$R@PpP8nbQTp50#@=(%ODF2>9G>aZ zEjDCpE)qxK`zNSfdBay1TOZ}1^x=c59yNYuv^F^SB4l!_i$!dJTXaG7&H_i9`%7@t zL;+7$Z_GX=;)FM5@M_oK>5bIP@c?{@+T}zaYHEYNqS@*sC(3hH=KyZPKcEnmd=uAsg2h@7lUb7*)}x}GoZ zM!Yt8!{*|R)VyZ^x8rKe5(N#EGm4tLkzV&J%XUarKi|SIkUO}SvqaV^zV?>RU}a|v z*6_ki##C{xZ0*BIAC<0D==aB@K+IDERy^_?m)wJnFnc5v}~xpFOJqL#5`DPysj*~IBLuDipB zyVp6mTP=>Pv~MBtizgec&kGyj+1XVcNwn`NiueMhS9sE%CDR1AN5l*48-yD_d_-YN zJxIEYzkDK$(5%u@jhFU3dM{0isa6P~XZgimWLkpp%fS9?L0#@cUL~~&785>9$913K zz&|-Z748U244ghcWHssVIu~;S=|-63)*rb793j{o55G0ZhgE}R5{Z}#(^!c_uNF#_ zN2ewCC#k*_e^Cz-oGDRj{isGyT_33A7_G*SV?x99N@0zLqZ5q4A6A+0G93;=&3aF4 zeN6Hv;m>q%vo|z`JR5(49!K)ZbWk>4-n&-ZAO|{(V$9Ms{GQ=>3vcMr3r}1Eq$4#6S6PKKk>Blq!b@|VRtV@jy@|zOyGGXN zpfv_HvQ}mE6ZsO}%wjP``jx9rOE;a#4W5D#@G<^D?PIDwP6^J!?_ZE-h3&xWet>C>8?dR4Od;G%`iK49P#q#08Vpf z{ssLZ+mYEYI1U4t4WJ`PceO++8u~{$@4nMR#czh!)6pjFR-|7s?hZxor}6-|e`*9e zqOzPF43RZEv2sCbk%6@)189>g%yqsq<$~M}dRK>^uMBq@{w7-Yl=9YE;7iNMQ7tEg zOz4XedCqW4+#?Od@G%`;gwvIptwtCs(0OoJ4^qAJ+DS_{Z_%!ZJDZI7NB5gyr8cpj z`9@@(&0+t;e7`rryylsqRmWb-o1q1YrutT7+Pbb#HNAtS5?(e>QHp(_d$7#`w!=E+ z=IYws{^Xs^DK{E3e;Zi*cY{|-B3mc7%?ld+_}B%VF3&$XL809 zG`~-T-fwno)K!d8)X1cPuU=&`L^4teTt$-zhW;B$%%~e5g{^QHmXk6-2XzfSHUkl8 zWQP6$A!&mGAIK_$xPWmj;*p+K-^w(i3tUnRk8zC95EKfva%XXPop^CNI$mU|H=_S+ zS2hA%E@Lme0Zt@-f}@u(&VV|Lc{GJWl6`1vy!T11vz(VH95k9+FNW%#C3W&Q%^-)z zwO>>AB(0NRtm)=KwA~8>LA&b5od?$mdXJ1zDu*e|B1NyDE36HpfdESG;ts`EX-r<6 zFU??kgS01X$htS##TiHNF)PCcv0U@eLYzS;x3xDi|20~t0NTO)UG1-s_|1;dt^4)V zfUnhay~8KU_D5`BWDngBpz3jgg)B5Io_&_S79UP~e&PNhfc6JuDPWd(>_Wa)qRNwk z>DnbH$pzuy^E&R=q`PXgB7ev+am}HyC8Ss`OE4>7J(TURef>;0OD(JX!m!6mCYatm zd={tW*Sv&sbwotPZG%b{j`Y3~PEm^%v3D_djFw3>Qf?x}h-?+bDd*O;1WMwciju$< zt;u~6j~Ha+D#4!hjp^;-NV)wR4JE8*lRuEWx8c#o^2f3rH}e*`eh}U;XtM9*zZ&^fGw}kkYHa4H{1AH zV$7Id)_JxfMbQ)r+8!$GTK%8o;&G#MZ%u{2jh0_n6jR4bNc9QcCag0KTnDl#`vKux z!Purqh|aZC=KM<7GzYh4JG=5r*Ub3#t@b;SP12ongPHQPAzqju*D}ks1-E}YEPwwV zj|0JO<}@zk!f|fpL{6si@?W{R!|gN!U@=+x|D`0=qYPw8%RIwjA~WXG4kS$uAYV~g z@+m3Rm{WZLkOz)G*Yrt-&iqvIBgRv$uUR>3X<0{u+PL|gt80YPYt>b>+yy+?oTb#W zkBoDgy5cB&#)0D`M1=!-D0zG3sJA+`$!IrF2Fo$v`}t4rksO{cr)1Kj1i>BuMr2PkS@J!y*s=H-s~)t!k!lhJhd6?+`6j~G}sSGf~?Z*%Z% z=?%wG!8ann6WwU={?qsf%~G8idoUOfG34WBakdj|rWH*Ly(f{p^Z%E4$X-MD0s43{ zv6#~OM77MV`Ff!$1u23)m=llx++A}tWHVO+Y&>y9+gUERQ;jC0lrl)eyJ>f3;}IQD zA@~|N40IQsE>+zYHJ3GTSH5#Yk^uLbmY-)tV|>xK%g-HWb>Mgd4h!%59&TSyJ1(6a z*^>k(k#%1P8V_iSI7bOwZS>fJVpuvRpOrJc*-m&id@opfNuQ!Fike1J$5&c@L3O%D z@^G%!c)C;&U1#hL?hY5Xi7&qaU#d|_>AN2Yf>%05ZgF-WPF{bl9ky6KuX*620I#W* z-y~=hlQjn1_+|w8D)G|E0BOwVLc@Oq)X$!GTmw)^1}&#JSwMk8VIr`GUTKNwlg8bh zhu&K~WqMj2DehNHUf!>)FfPDwUgtAxW)`CC+asB_K{6J`f@1k^4Htr4+4F^S;s4;i z^KS~hvFp+c_n=dXPW$<`y6?vVS^B}N5^b0gN*e}=anjt@RC=EZzo&D4i`c?k{ROQ4 zIDaqnoVzBa>jJ!!oua}DDTuuKj$vLg^97DrC;&-_kF{m?wg1Ql+_y zv?~+b!)|_r`CFA+0zA3iik|l0v~CG(Y?w=k@|Si1-mwgp!ZgH{I<5H}jrx^t*q4jh^4gOC`mosWy zl+RM-8gX5>A6Gg$!pRim3VsMhk3W8U4PLW)VZZ{XW688|VDUO~#@nO!i)A;r?>CKf zrZBAxCZ@(LSJITodKNjiN|VgwsVV^^WZ6o#vWC=lR#Xp6Q=BE+=IkQEg5vLT4n;3E7xQtq>Z!dQ3J)wd-jqdK!Nl46*u1Y|P?!nv{xzAUw;@4b}IR|)4H+79(! z^0lK4f3xT`!r>H#mDw~RRlI|UI}+dn=XPZ!C;rS|O>*{1<;unb7I2%g(M}9GyqSAV zn3`h_Lm~0w1)u8XPS)fpQi;xF!qrz|K(bs#d(Ef4rT@Qgy`k=1gU;IB1 zB$Dl`GT^$s zv%fNjpecyA(keLAva@t7Wj(_+i&;7mlb2q&(HKM zj)*3=x6W#+hQ}EHue-DUit2m!xJn302nfi~(v5V-NH@b!L#mWWcMc^W{h^xyX@>3+ zq?88f7`j19Bqi>_@8{mN?ppURxbxGjIs2URoc-*5_L;rkulLc74YH)j34{+ilFkFkwjmE6#G8g$Te?vn!8Stm?2+ z-2JA_GE9vem&|xlq@fw`@SOj3!<~4Rkyp|n%JynBo=ws1ipW@6+v}5I70Xc2NeIEV z)pQAgwJZ4*J~r*)egV=j%v76#DTXB(yCMBYjlIIHgZdpKZgG*?%iGS4EV}OYR4RLf zvU6_nsaI{IgGfX}-v?-I6s!#Mr(X3*Ni5%m95ffUk)1&f*Bcb>vcLYFeD)>^zvB%p z`U|I|Jz{7oYhutb1gT-DmZ64G7AnB>mAQES%%=M1^3Iwv{7^f~!?VQEU~uLsMX=Zx zU#CYbnsPvo0u~wx^jPA0dlgD_%D|P$fwL(g+L{kIBqY3`^?Y%>&Ly$Wy)nZSQyZUa zt8FO%&^{P`IsQ}ba8GyJ&cLmO$RXz}I$Oz%J~mdn7}h}=B_1`QlAK{+6#cftj4$^W z?zg42exg$EE3aJ5>R0(NhbH3Zg&103R+5H5IV6WOxtQ*pYzJYqlayttpvM9 z_P?&jJb22cYtSZh6lD6O$I5v$OFqyrLL_=X->V6dTYgSh zFoN5D4T7`>Mo)ALp$^9SX?vm=1GG}jA-h&c!Z{iBoA4c@pX^j~2_mi_M@vK)#-^1> zXVtir9y;+;$|}0|f@WCzFP-7e8O9xrYMK)z1hPnPFCwB7oDYObOR4JDzE6^!0@S^x z^DLr36GfUqApQfhsg(g7$tK*La)t%B>%p2EljWGX@k{eNPPfCwft4jo4Zm|`9KVeQ z<&9fim*`9UsLk=nfKpUT=t7IT^YSPxBu#QF_81WhIkGAIWy*+G;xA$OmiDEb3$*{nm-?;o(a(My%@cLHC5;RpR4Ly`a$K95eb)}K+ zD0yoOTF|WN7HHO~I6CfTG~@L2-}d4!9=|2SbD@-CON`E@@#{-3kM)E4W%K%rCdUt`3jb1ysGn#_){J)A(Y#{N~TQJy>g$3{Cr;c@b{rHfvPEuv}YNW52ESj z(WNs?7qp}qU$lrI9BmgR<_|!}BsSViwEB@S;ImA^GhTsGLS?x$->v4IJh2nhec|m^ z95pAH;NXYe4$>Y@HJE_Z-zLv zvZ0SyUuyJN%X00SOh=qkG+-R%u&FYv{x%86kAn+@kn(Y9q76}sKp;q28!&1BXI1Sr za;2gek`|U|XP~Tr1Z853AV|A!GJC4lrIZy7E|EsqCU(>*c{G-J0?mS7cobFY0~mi0 z&460RTNPS^4P^!YHkBe|AEy535!%u4(xDg@)|o7_{#{+tSx+*_(kS=0+L*5+ zuW>!btU8rZxed6U$kku@W0@@BgNEL$(U&f zsE|zkKG?CZFJgzAo`u`JMCu^oD8-T}=k4J;6Y;b3_T^ricSZ-a5bKNVi)4Bw%d7YXvHUMb;i_SDpi^ek5toxWcy3u+(&4LU3rB?CR+!| zED(1ne5v%^$lctioOXx`EnHlZ!g1`%kLZ0$LzD|{>V*|KNo!@+;g~!(k47-?FoY_< z7N&86O=ZJmFaFHk?$3Ajxi77@;0RI;P<@^Z2h=X9LJLQ*J)PD`ILL4CK;JBI*YqH@ zDG^PJ=pKNs#HuFwHrkp{^lL ztygsAY3&Fu2ZF-f1aW~zHs1Q{o`h~Q9%5>6SItH3(!z^f!5p1i`oa#E%&OoUEkea% z=n*z9)8057Wm}{k53BytT~-K46tBDuU zBhf^xksv`&Qdz+G&)Z2tjVjjLtBeQ^Z)^7B+5FZ#SkMz(i8sVpqokSi`o8bsG>*7R zmv0r2uaT@C$0Z%lAYr{UOq_y6{4f+5X_%I0HkVf$GNo_O>GLsmkFt@6H?f!B`iWGb zbalZ%RbhPEerK6k-1GFrKiPrvW&SKlGf?$F*3eZWD_`W%QA~C?;%T9>bpDUeRRsvZhpolH9Q?+Lx(uEC!7P#)9;cJYK@S50?#V2 zxA-oKu12CRr*ZafzS2V;qNqQ3mt5(xAqVOsX+_$T zJGXosT#5_E{_1iWEmABza?K7DE^UYq5Q< zj`8)<@ZVZ7BvRN^^iiwW8l|6A?04-n*K*}qnmx*qfwmfQ=XG(lzD#_1?Xd5W%oT$^ z=Iy~Wk=AJd-Xc}&Q6ji#ayvRmmdDx^&rq|Dvf*+ZAmcWbI_U0eU=!S|rH>n9Z`914 zB*(%kR%mR;>xqlPz$i100g=~(UyGtrwBi+%8nT()*eyssE|w1+cPjw`5+P0PdjA`dfBPTz1D9XtR*x^(A)X4_USJz4h%OPcv4=v>|XYKHPdzVsiEml2;8S>WH^%%26W8w=EoSwT|yqs?^GBd9yH@MABH8XEHei|j^+Vh^hdtYgh zOi+fu^Jja_dL5~r12ceK^9+}+e@A-~(i~M(5IFTtjv%oqMAu}YSNN$A4_=iiN=D*( zNHt&7A2wuVs&TwE7BfRiDvt`w9Q0jw)JfwcqHcgEXrAt0UtD96P^gS}F1$8*Rq*ch zRjWHJaLcgSxNB#{J#NGp)|=xZ8XftX-RD><+S}byqXDI^vL;s$^JaI7W6c6A zJ~}d8*v`&J9ykf!iO!`7ZsO4?(1p%@7!QAQu1)LMM_*tg_Bv@^?c0jFr_X_1!uDb} zdXeV4^}+^fK%r{sKrG}~ow6D>-=2D&wclfutD8&x3eh)I7aVG$r_%j>-}1zJWNfee zN9k0I>7w#b#lZn97)3AceYTzxO4S_onu%W|k;plXAa^#_IaD|qqh=NxW~5fvSV(QbniCg3R zrQI91!l`#h$34>x%Q2)dVk@r(*ap~1M{R@1e|32Fdo@HcGNQ}ul!m#x z;mA`Dze?7&o^+~97326b{F$3m+9s} zSY~)5^YZi^HbXkjW46^Ql)?tRe>tCZo3GY(SD*UXAGu~Q?R(ct9bTGHoN@FA;S{&f zy0~Q3r6``qSXdw!Te1B^z81uA16FOqOtRYFPi8w^sC=(p9T+F^hNYH61Q%g9oL|tO zQi8xc-;~>#EV;99aHQn(;cGhv=A6j0O_U~E_Kapd>zXcm*YT*3MGv33mz@c#t9R7M zw~2p(oT!YLEa0!ga`1=f3*HaI#{Tuy$^gOYIx-P$JK2~wI4OGPer-X8H~ zZP!vT`G1Qj$y_q)s}&HAPfSwO{>A}QxLS)(S*RczCw@RlJ0yS7_n3|gcY3^THPL?a zExA~j{DEN>Z>Y2>QI_6si5s+aw6iNc1$F+F>gPoce~vv5;BMq1?FulQlp$2umm+=3BTtBha%@XF;}Ll;XkO#V{9k@K-eCsQ-C;z< zx)oFd$y}CzrLkVwlwY&5VB)Z3hKBqjmcA-8h7GZTF%Z@h{N79oSg4s@v z|AUw4utRkjy;$s-A(n0i0}Uul z%okkrd^<`$_FXd@TNHaC$&$thWRq+Rc)vyf+pE~xh%E&SjrYfR*i0H+-*5Nw8;ye7 zE;?e>R(o71tVTUAi3rwX+N?Gp`g6-P;>Ry{@J6>=k7BSoO}Hjir(f?=DmBuD<^7i7 z%-gWxO`J+!{v_4PXAKqNwlt63*CF!N{Xj(}_94Dk06-jsv{uLq2D-TYR>22>`maseWx>d@q0bwD)fI?2$M^ zq-)up4D#bhg%q#ZKQ>0a*cNjp|8vy2 z=az{n`^k@>W_GaE^F&cit-x}Gif|dQJK^X>cI{msiEf|vFnu5eqrBKjo<1doD*+tY z{#>1`Gg|w{g62#ZmIOYq|if5dNh0_^CK+;ixWKX-&f-D9FW{>(Li$4MBEH$&vm$nd!fZJ$yD6 zn`Ym`VI6_gDWO|6WBHOxdbftA(I#zwUDw@^vu`78gsGki_sT5s#ODG#E`n~ITj^9n zlxo*$yakA+lq%Br%jXvIO1Dcm&CddH_!avlR8F(iWP3oH6I}n?1f}UmX*<_H2rev$6bz-UdJ*MIIOZ}s&Lp)kxSD1-2 z{qYAiazWB3X}?k^SyINexKq{NIHo+Hce$7A+YvDYs|ut?0XWilslJwo+fTiN_>jAK zO4X|nI$-Di#xNUi9#3g4s+z7(!V3!42Th5AVCcmI>`dV5Kh4W1zKI{-S5!X>}*J*u%vMovy?8QOikeh3r@wfVl$RUl>& z8v8NQ_d2r{Pu6ka0`IT&3$-&)e!Z6f<@lZQF6t}Gq!TFqassfPF_Vxj+4%jhH^S|E zbKG<@W+TVG!oR zgi0cxp!+sN!S*P5zl#k*D?uHqQk8DDvCC$F-~ffSvv=nlcE1Sdv^gHBbYZ zJJ5P7ta(nQDnRI7lQOQvw26OJb+*%le>>Oc=qS0DRuYiBN!*B>K$99CU12g8TIkUS zxrNG>J>)F@ zW<08#s{MUHJi^kv%#m>ha_ z7Ib&ZNL-pd#3dKVlaHbD@Y=OkptiSOMiL)VwoBuE>8MU((08d0vg-efwu;a+cU zihfSVwZ~7#T1$j2*mAnq5ie_JHs^dfK8fr98&3NwvLcC)lr)(dxPysE++M_k-^A-6 zH-L`MsWV$^+5rZbb%1N^=z8tmbyWoPYaKAS{NTXxzNy&{2z<0O;cC`{p{eHTf12AqZg0c=;@B)}*p;1e{pv{;=?vGE&9W92*?}vv0 zw-UWaJolXl#qE85K*!AdNhy)_^Z3R`IZ^k6yth5!fUyKwnLuHi)p)@WA7JF^@Bd)J znjaWqC`H{yOZ9mlp`qFLUWWhA00zWnXAAg0RUXMb`j53$wB7opw(&AO5$%20t4WL< zQTHua=CYY=OZQ*G;C>7ElhePFQQ$l91M~lQv^O3$3Nsk51H)*@qi_-@k@Mnt-ne7_QSev+l8uVLePy>S-9bTRO{Xc k1ApfycQr^;%9rmRPJ*5aaEleKJODn5vT8ErVAH_=0--)3ZU6uP literal 0 HcmV?d00001 diff --git a/docs/source/images/nanogpt-curves.png b/docs/source/images/nanogpt-curves.png new file mode 100644 index 0000000000000000000000000000000000000000..ff37369e94506085ea8ec2a87d38dbd2e97a0e17 GIT binary patch literal 97778 zcmd43byStx+b*mKB7!I(A*D3ZpmeCTq@aX!OD?)Yq+0|DDQS`J?pQR^-QC@>&V1P0 z-+R9Ee&aiTeB+F9o-y1T)_P(-bKY^?*L6?)^73VP*R#4ZCWwP z$JNyOiT%<)c;mIW<49>&eQ1ou`jHQ`d%eO$unX<0w;VcE3zv5u(#S~J%`X<@99R=` z>qwT7s>|bU>}77$!W8U7IguWr`Jy3TB743T28ZbLtf7Ez1GpXs*}z3##dkkHKVfw6 z(U5VIX~AYK##3C3NB6z`0t1`O%ZLiEMDE`9a z(Vh)IKh}tV&(=q`?A|2(`ANn}fT2@F5}WBn4i2H1Zi5ZMI-b%j{Qxj*=nx-#H%ke9 z2OXYoJ;ML_*v}8+If~3bpOTZa{lew)|M!cwpYs35FDQJX{u!FpKLZ0z$s&QP+uH)! z6D=(*toGY#^9u{1@ARROHHD9C=&$vs+1S`T@)yQDvnwblh_h%!D_3KZq8?z)mh-`e zUg2GCNy+278F#4=v$6>Z85sd;xInZ>Al}|$xGXlN@577IBsssLUIC&amEd}CTdPw( z(uXxQHSonzyDE~h2{UnvBAYe&1gCX67l9e=nc*DhAdQ&y6 zuCcLkKx%sW6LnudKVm~Nx6@}6rDnX@;ocN4wHuKQdK1i){j8^KN3byNKIGZ8s*=@qA?d5@Mi#0!iKOo_aG8}LRFD~z^SuN}M8 zbnFc+F30__)u7X1SH1JsyA)h5 z5_{W&vz9y!kKXB4W~MHV$|qQ1&$22jf5-})r6LrnE-zi;H;HKyHI$4s73LHa6bO0k zOx-^#sH;oM%Jy3399RUC@ZuGX@jbQq!DIVdLkT|K;lAtXtH`$9A!spW$~)}Yt;ILF z2YjOq7@u;9@neryw=xDK-n|PO=H`*H%b_o!_mYs6o!KC!NeorF2P8t&-{ie;Lc&sso5S*b`?kHX**hKTYa(|E4OM z@$n8ZGS@*wW-s`ku#3y3fagO(7e`0ONtc6e)XVqA+mjVt@w{b-3+L?$-(Vt6($M{o z(a{3kiuK=FceCeGx6>DXI&s75jyOm%yv~-hg@_P6sylxy->&{?4fJb!w&%ddpi-Dx zN2K5Wjj90g#WWndg7{jqm9Ih)6R!YL$bS*C&%Vjg5<0+*a~sYm_2=>kt1$x*%NqH z9n6Bu(+AS4@Nhqj^u6fEMX46~I8rj^RqJmY=~W+^2l-#PO7yg`s=6mtCtF86D*SU7U>J>&^vdqEst zzXqGR=E&kH`-8MQoOfoYkCs@og|8M}6IbVYL6-gFT-uG_`LVl$7PquQUHeQeKhAF9 zS&QKA$;k7y+v2h9oV6#m2Kw7S43*)RnwJ4erwnCR^KS}IF0bLbr>76yS4uiap9VW} zPIWxi4J~~5T&@?6xMoMFO$%#s&&)b&CW$0ba};#(1M={uDHV)RuKHVPaaRb8s{^wUb^4fLegMOL*De!#Du=Ry&Zm43h$`J zo0?7p*6#)40cH%xafS7YyuADyE319&5t|of{0UA@PLLUgiF;)#WRT|P@c+)a#^>ZO z4|*3h(^R(7x>XR;|MuGQN+>&K-*&sk=`M$fpU>ma{W4#{bO8ct?5N(cnev2QSqmM) z#i-XmwNXN^wzMOMP|?KXW3G>22SYekelK?%#WYH5FAYgu%y>zX?pdDESg(!LwxMye zZjzSfu|`?EkzrU)y6At5#mzt(>YJLHO5r3c@MD6MUbW18%0zqpN~Ow6*0F*e z@va2^Sm4UvzJ0@VR!>Q(*xFk>a)QqAc=cH#oO;fx0>kl&INW`Yu`{ZOZ!-=5r#3|= zZ?)9+-CsImv9Lm|JsmtmZg|2d^5$I8v)S2g^x8JfvN!9gI5FGD`TLo?jtA{D<;N8x z75aq_KIu+5!}h6BE@$!aztwD9c$D#y3mWm+@}u1tqoR9rx2Nv}aXSF>2_ma0+tF2f-}2*eXXfS>^3NO{?#_+*Q(~-o;ZEzB z`MAVtk*b`}Mj5GnhpoKTXN}?24kx^q2jAXHy)W48E^~nO=_Kr+>Ch32Smy@;x)0=x@AY!7b|aT`!Y*0XFfGtrg6@K*z@_rnai_- zhmM`IX9p|nuIDyF%jG+reNr!pFb4*K;gp|DJ7@m-^%UFxYiw+1sM{fh8dFqX+5>)^ zNeBGeraMXY73*FH5~C$QH=RQd%F3n<27@O2qWP*%uh1)&jtrzK6u6ACW6WSRa8>5Z zsHknHVAR@ZZ`ZH#8N>K>y(neXuZ!rsSqCIpVCroXu&XHF-V6TT>uZ<7O+2aM{rR%q z1^7;}Hgj~B|2F2d`Y9gs<;9HJ#puE1CnqEhYkFq46+X91JEGDr?OkKs%PsDRSeU>u zvEYyG=JXP45j}!~iz^28$_OtnWNO(27vUsHGc!go0XoN{(D1;M?D8{6<_n}eB;jSdGt8Pj`(k;N zL~XVE?;@GqLvD9+9i>{W;be{GyFJ(-D~~?T*k1I=;gGrRCXcK+gk2x+L@Af9au~?< z2%U7Vd=B5|>&Cr2yranI=AWVQvFrCr?lb21J>j8D+>Pt1rz2(VT@xf{Y8R8w@#?_U zSN&{?{MHZ7qz}&Kx_kRD>v=ENk5OzQr#q5g!uAGd66SLrqgU99dgVNvY#A`Cg4;w) zcZg-3a4u^dTL;Cs!X2J(Rq%arJHCHU98$}Yriju@1-)N{BzMq&#?NSz=UG?L__5jK8@8Lm4x!} zrML5pIr&xxjq>@h0^jaQhJ*bQg<8x+rRJ~h1!vR5 zNaxb1=pC0Uv3X{yzpEyiP(KQ&|58=1`LFBE-GKw3Yj{;3|rL4Y9IPN4M<~9;} zV?b7AvtItr=R_M5*3$S>#zna*?EX)iWwb<|h62m;%mJ2yFj+<6&(;;!Rimfc0rVf6o3=H9Oef=Q^3RmXMGj zls9?M_JQtMG{>oDv$C3Z|(@tDEgONB9+rNvCV zE>)!>Heza*Vf%6eDF&h)?s%0IjgFi;_k@R~w%W42*&HjPV_?|SB!{XIfJvSdE)V{R79QIt z9*MiA5XcW%>SEUD4h@+1AgQ)HxpL96F>}sW$*Ir%+1k1otGT>zpt02Jeod2CeTr}H zuou6x$sm4O2|;JJ>B?y zQu9jH$Qbz>X3~2s!r@Ij-i$;y9#F=KxsZoPMTr?1tvyw;J6-OIX@c;)qq?LdIw;l1 zvHe-VT+1t*TorM>d^xl`tJ{|>@`96d0o1Za<0bUzhR1SJp`vBCOa zCM6{$f%@>HW5QaicXycu5AN@H_tal?s21}q?^UpqC^Nj$;=W?z~Qte9(L zU@*YRBCClvj7v1XlX}R2JJ)P zP2Z>_?;#^Wfa~z^&^5ss6tB8vWF2o?;XR;k{PgM5Q&!e699h1-{B>@C*O&qG;DWfA zG3@^F5*g_>$f|sWcn<{Cb}rXB@N6QjHsHMn-PbC%M6`&k+a3zZ;aoq9QsQMZ7*xxY zu%Cr~v#h#KGA*ZeO;Ei)cQn6U{Q$A=^piRJD9xzK^z-JF z_f$m;01m1?#22=ZWLBWIJ2G;%Wr%%_KM|t7OP?_D!r>()kFEz34==f}&gcTAHqp|M zcH>wIp$XfOx~VnW+?Uu2#FyCH%tJeuKCr7(1WBT(u8U6kJ&x+7GGZr2NAMk@>=lp zPms!jnMb`m;nDDEhR1Kk^J_^^=!R0o-7EexuK$YlpvO>yFWYGe<>PTRUw$N_Rt?iA z!|E7W>)JMmj%m5yz2xYjf45*PnVTb9x*QRs*&Vp>F~A4Q?j|$va+m(sp{&-!dJO3; zd`Ht3|^$iKbvi#GW4!ZCB+Pnzk9PUA6TAL-yR-jt+Oy|YPWdx@^rPc6? zY`y6kZ@dX?2N_$UGrxdW$gE|7aMFK|Cf@o!TI5mnwLE*$pKV>#VZlxl{b!Rj@eK}v z-XnE_-j-BRyIMlb4fkF#I2umY>xbt=^Qany3OneJM4!us7HF@VrUm<@9|eK}PyxO!!Sbb~65%+7 z7qJ?!L>x=7F;;xD)%Y@G26IRk~-qvjOD#X@i*Ldxsbm?%U!VIRr<{9D0;}~Vj z6>gWUP_lxw(#Y`~nAHe=@7x#k+(pFqElcdOb954Vamd0Y)R`Z#& zpW8k&Ib0T$bHKrn9|~TwK>tGdj|sZFyG5g@DCz-|b-EYo-qp=H(*=@O9?3eUF`wr7 zCBfy;_ZLfQ^=Si)%JvD}!9TcBGAaUSU}uk+UDVo_MG#9-QE@GKN*LQuWJ=j#^OgRJ z!&3($wHLE*INlHC+wLWvwthh9Rk>nvy27YAvSNyA9WebPUtX@o)-8I6(r4z5btBXk9pzQXK?W6)aJeLBEhZ$AK%sCz@47F^^Fa&F7gGP zIKYza9V8;+Vqvv0WixI>bu%z=61e$W^<$;eVQ$2t`}8SERNOwVOjSmZFVpnNpD#Pt zv4UoGeq?;^G)lFuG+GoZ=yeO}=VVqhg*m48>%A7rYd5m0k)qAJ9W2(|S;w?#LXdUAbC!6+MCu4lM@pYdQDP?BG4icN~j_gl6$?7^P+$S9; z231^k+I=f(M~kaOY7WVnOwaf=f1DiD=0q}@r(f-AU)yMZ8y^+%0?A%R;7o7YafJj> zbO6$_0eTd*E(s%GecgDw_G0@w)7^TxYYwonr1>0;(9#0_?fg#c`uh4a_r^-c1B2_M zJa<55-aZ|vyB0AwXA%(+5wiQSAotaMjyvMg2&X0yfxW<3Wp7#Jy39m+o%pfVamM{@ zZ%Ou1!yEavva@!5=Dn(n&apz)H+piB%AG5fG%MlzDKINwI}IrIVz&JyRMLMP#A zW3C%aDl%2!QQJlPN2-Yn@F>BG10g1&F~jRGH15uLXaGrFg;d0YpE+bM{=|hmBWEL+ zH0K0wpxy8V1q6Q?zCB^U>>ZIh!~yr&hYMa8xUbKE>;k33yHep0)FpF0eCMvEpsLEd zXwN!bYDN?)a89eP?yAsA>V9ciSXfBFX+CM0xa@#2V)1sdu`YMEYxv{GE17=w+f2(z zg|2BbbaB=@sHU|k-^Uf13);CU;YcnX8y9e?*+_5w$-_Q9j>(oxd_M6bHxs^K z=MGXG?Xiqdos`qUy8Q4CC_8l_$6%NeD?yt~mW zzVA0}%pI)gIFWgN!8R#z zW2I#F1r*Wk6We639XSQpvoi{C;?{gC0~7p&0|67BE*&dMiA|Va6iw``krxGm`k}j~ zynNwE^zyO+NLGB=JYL5ic7A`zqGezh+}L=vUk8^Ysw)enHV~@W=grhBU#(X?+qoATe4nnoaRZzQ$py(*c-I}$dLMdr)RBhm6}k24mFO6kg%NwJ zV>j5Mj{G)hhAL`J;+;Vg>#!K>KVo0HzqdcbzzM^md~{DvQ`23%Zq`sw@7A9`e=Nk| z#}`*ZwVejEo!#A3#F#rr#M>r@HG_liTSG{yls<>Y#yU%x@IBFmhqRb-JpAPdfY<%l zkjQQc(b9vjJ_d;hA-C@{qO{eCh>QvvtuoTts6?UaJY{>10#ucQW4Tt4VDJIE)Y>YV zJ?x3-sM&D6l8d4zkeP$aMa0Jk_pVH2DHv%^=>y&kx-QSfo+5yG0 zcjgjj-NOw~5CQhr;LUmo%+~@y2trpi@&M#&n@6l1^N%OMdWFYW`z1AELLq9scqN*lFhPtgH*!Aj-;;;^Id{_=3OSn`_(s zFXXYX{CP@BN+_wRk7em2R)@yN11P8!NdUc$?Yh6eAH)P{oqEFUgo7mYdX#dBvHv#+ zfB!$F#et`e=W7|zzhEZay5GNlJAfZm%FN^Y(l8`VMsm~A(ppy8eA2ci$^gHjYeU#1 z52!Doo>5@?6R3CO4*)Z)wB0zIB=JZ{AZ-mL>jnDC?!}d86veykvz5dZZ+wlXxc3Rw zJN){GhFU-}OI|r5P_A?12gt+;L8!-Tx2Xkg=hThKme>nlAIbtICK|;y59AwfsAx7G zr2X;ZM-UU7MKJ{(tX^2sQURParMMV3utmp)?nc$@2bS61)>dyCGA)p}>beMh2pv{Z zv_NY7=Lmb92J|Rl;k&zME;jb|MCv|}29UpWPD|^yEUR_{Xii@W+-_M zx`Wtiza(>7ad89)wJDOUL3j`F!4Pi8$ytW@C?4?Ou$0}?g;4PG&Dd{75*8-%-y~zUSD-coJkGnCEJ$T;>-%JEoSd7>^n0Xb z@Cl?Tv(+DsXok-07!IUM1D!gh6S@_lU2g5UF_2y*qnM}|8M{}lFp~axN55kfCZO^O zhbSs4;xR?h|K`;GN3*^&X+PsSo+bmnTcqVXo%MW&c&2eEV0AMx@E>mK|1MkKd9SEg zXt^M|B~Js*saHp5=UBNVMT6F#Re6k*mvP`;dY@IzD@3XWA;p{68-yi|szE|XkOp!K z5TU_{rKrRE(!lS27?MEiJ9QDr9$pPrPd*6)h`dc;$`VfUcU1q3uxAws-1p|t)BEpw zJ4U|Sl~W{x${Q|Xj!IlG@Gxy7P)RaqR7byhs%1mx)A#SR<+)#({5x=5^O_^YEqOzm z+t7p-OU1BxM@JLU+i_TJqge+I5b44-uf(tP5(Yetl6K?ta>9dB#&D&i0cyxDJkUH^|Z0GTTeY=IdiGKPPBL0#tiLUSpWopw zDbax*fz?w0bcua`he&}C=J85Pe1t1vXkY)y(OydQvW_W`a_3UfYFG3J1{jre0U7Ptp4Tcp45*Q?W`Mg!DuB_$stxBY~CsLQ_gfCQvYSxW3@0Nc17fw}mXVVa4CGGl2JktB+TpPs$*X_Z=~svo zs1L;7-+wMbL54N|IS7~wy4ROy(0q8kd6ps>hD8ndxi4RkZl-0PA5g>4cYw5cmWC@PVTUHeU%4aw?cs*(!y0qFu{m^{}llXrEJ>i@gU#dkEQFzLKg^KWjscLgkA zZrXV#t*UAQN?!j&deGW~e>yR9oqqG{;yr$F)g>sP(=syP&{-l)AeZWWrq3<^5)_2N z_=E{euay@RLrpn|_V)Im*#2*fMZs75`8Q`ko~^zf#PkYEAw0h9kPFoD%waX{=21U%&bIZ;W|^*fx*oj9@A z5K4JBkcF+@c0^sw&aU(Zs$g-y;heO%gXw!GW55WU^DhdBvl}=1n{S~EFrN0m%eRm) zV?7|vY4i`Qp|ClbvbSCc=Irn5)BdmfLismJ(@o`HB#`8dj=#U($viMI0leZr1C-{R z5_WSFP|T9Sf!-seIj3y;t(qD!Q&cB3-Ru$&w$TCtSVUx`=uz%-XuiaBLI2AC=P(s; z(Y^0%t}5%>5Q6D#zFg+4n_JJyqkeP`STJN=`8E?~WtKx9l=`Np&7t8uk?EtwO*Y)X z)|hg_3C^tPVmu{FH*lX?{SB7ePr1n80SNEnl>b4QAUc62nA0ma2nix(eA45<>FF^z zIDU-Gqdj-b`xHsR8kfi4@80EMhG1HGx$@sx0BJlND90e-1~Ij;{E){FmZBy`i*YgN z=NI;y-HY6FJIhm1KOsw@#{pePTJz2RaSrqT|G|RMz~!ObZ%{-CtigtwFwJzC0F3i{ z5qwhG>s}83#M#fwvKKbF9F;Hr#c_noxLhKPIF_Zj>3VK4bTNQ3my+zoOVG4`nYBU6 zTFLeIyjF1DOB^)Tm^hAgL#g)j=X2s;eR5w`S6$gnPoC?(3Tz4R6HZCI+B8(NYNz_^ zMo_MTQj`N9DKr&|cY>{`yC~tI%ZH7L;eDTUKP;rYNPD0E`&cvjbJ1J3JJCT*3Sga6 z3Jcl(VlhBg2f!De5n%jd;b`Fc7JPQy^WOF;<&Mtp2uX#vI2#*@+HZ%I+Qg%EZ2W=C z5rYmrbx?W=ACX7h*u=v1#SRVV0R=QTw4M<>x9jyvAH}|ZH9eSj;2o}^HyPyo#{Q(L9kvDYcNsX(({Wv|&a5PCZ@=O()uP+XC z8pEFEOmB%QQ9%y_M^g|I%~)UywvfTJ;Vm2wIZl%eQqD@szjG;y%hj%tZGh@1Va;-T zPTX(%)t&JI@KC2oFkNwu%P}{W1D#D6nOk%Gv4?8`1JugFz$UJp>24 zE6)+n%|VYOIZ%TE_Vyx0?pN=JK`#ch*RdT5&Dw3MFS86dU_xuT(|h%&&2zKgb{&kWSOj@z}iapYkv*A`KIRR4Ckwy}V^;~11c z;aa_VI;~N=zEvtfgg%FQZJ}fFE3L)1E4Ek$E|M`}a&5lp`%kKdiPXP=1wfrSA1;gL zcP_Qdyy@?KqoihuQe@DbTv0LhPlP`1#4BKb#_|})b%SfD*ocG%E9}#su`z%60&67( zn9MNktdx{Hp>QEe?3y|uF!V9zq9GaiZPU>LYmyAD3}}>XmQY4wHR`9_w5$4ChsX+E zy)ZH|x{=kO6i0bs$fheRZuq8E|5Ruy8|ET<1t7w^r_|Iwr>71nouK%{A}j?M$v-qS zbY0#3F<1NT7%4n8R&L>Y3u|;EJU96m*ol?g{kX-?rdYfSP8o+e3%F}(rDkX!;ZXGn z9hkx%;ly9ZjN$RTM60UsOwLhl3lVTtw(AnTrgm2}>t8oh2Z;%o`=7x zy#b5BRUq$(rIR8wyqm-76?ov!E^2!J0!J^3XCEs$i;}JhQzWsm!gNjez{jdiSYyNf z#a@aE1Emr79@QXhq2j6zakT>sT&F-e=%!ZIQ+B%Cga6Mog4Jj?HEM+f4feA)zIl<- z4|;D*7J7~!#%t(3RJt(tzsb}idr`a|um1iCnC3sn*$WiFbPo@=ir}jYTUxRL$OH+a zpaGUhJ#wy#12s1;Cg^!+6WO|s!a!gEa1+s=k&$?SzE5^idJNfgO5flhfc7Le$sXVo zBeLAe_2+Jib$@Mlbu?nV zo|%Dt5Bcx=Ri3YB#c;w%K?U;Cr zmI@C}b5J0DoKiV9c!*7dJNa`86w{(#FSe^>|1t%rWG*c0_L&N^Hje-#yaX-Mv!GG5 zS&zpdyHKYkU_ovu45+N>Wo0L;#z9{G$U5Si#n|`l6-G(}l<*M=QWr*^iYsy--o1M* zBI4brkD&?zrf3w4X^vV&|0=c*Bfq_v`OxnVqEY#>fH(~4=(PTfqD>C~5Qf`HQX#dYX3Cm|sL-ClCa$`J5Ba}fZ< zBSdt6#jx|yxpCe?P81o`-I^WOP*6qhZ5H*C?wiHQtgvRUwmI^Vf-=;g=Bf)MJ!UF+Cch@xAcW}=qT+PVJ939XHNVx2aMabYG^Io0xU*uc{xGVRPHu4 zv@{g&<>kNht^|&6Q_J}d2wctM*VS(v?Rzu3cmOWKi1)Q{_`h=}jpo4|(+~ zwp*GRhqaOxL}y(4NlVmTfeQwXSg;z-HE_AsnAn4i~0b3g`|98rt0Kl>m`dhsBnw9jk~NJBDS+IeIl#ab+*k;~XNq zpd&c}iH2JTFFZKNt34uB;}tAFuvBmKL{{SqC6^c@Fc}kx=x5p03iwBEV0>zP{Cy& zQV;$eJm`(7-fJZ~JU)hyVzEy+#{IxC8JT$onKp#L#1GB10u#oUW|9*A^tX#wvq4w4 z)r{-Ozg0?##Y()r=oOPKz*m2iyX*?p2MFSwNcD4IMajpn=S}G%QsZF|BV|PZx^z4T z)`HeWnDq6~85})ZgAdxT*K4A)F(Ixs3~O~zL6(XDW8Ok|=K{b$!uPC-oHWaty`{?| zzl=0HP#@Q!g!_LRJP2Zv1{uB&&|MJM)@(eve;-}dO~luTWW`LtJ9{rmi)|x98Srx> z;3fYl4@L2y7izRI_{dC&dkk`vm&oI~j5Nb6XhHn^UV|M^XCO8=gC-&n3=KeC*Q zGEWE;M5zk1l%74N1>rwfV;w~VwQV{okl*l|8aaIr%|amEj|aY1je?H1dpIDAbKe`} z98o{pjp02|N5=yt1_Ub2%Xxe>WF!D+vHgh=)}D?j=D-UE<}r8OT0l!~9g+xJyjitB9YN?HjWWyW+(zhR&+PS`=ThSKH zvj=g57~)(&Z5QWuz6SOn5vYee17Mz2mvyXNnlDTtzzykDn&Iu8s4E|FWYLW0Ho(Py z_wXG^0d_ORDNK^v4q79q)0Zkr?aPqUuEM*jDH*`4{q#*_?}usWE*yaLEPx>(((uJ{ z)F1(UIgE@4^myI^`+N=&E?>RhJOg`y0JZ7kl}Ygd24il}lFJ)=4R*+Z9cneXM(%#E}B4Z_`(hb1Kvc;R}6Iv6;@E^x53|ZIc zDe}FoUejJ8!am6+xfP;SgYB#bftr^{zD*6l{-E6@tEY#DHXtCnVLUJRs?Gboa=Pr( z_E_8b2YZ~AE>y`wQ_UK)362y=f-tb>3g~t$l|fh%d1JaM$;pkY*rKVDOcltVD=T?C z>AA}xF|Xfwe-3C;@VKGPr@WSTQ@-`;sBXl9y1d~yvWy_!s`SRH+bls%#93fI#RH{? zO}mR{Q$V2e0V!lR8u_QAx_yzw1&Qe=mq^+Dk-cs$3QKk-FD%khJ6qcV+YQBZshDIf z$FDh>bqPU1LDxqedKR0b1)hL_3L@hiYj%@Vz<_sg|{IUItJ&=AoWyqO9AO|Q3VM(hA&fD5o=SO_!t0}1&84vw6 zTL*Bc!{*|?wCkl)jrE>?KY4_C5zHhBoI+R*ys9^NRn9$RYJCkE5b@G=={D^6{OxjZ zun3)&dlP#-9}4^evEUEiUNa7K&q*M5q^6}k@Xt4jvRq5lbvegpiOpt^xQc$z&HN|? z`R2|(6(DDY0DSZn$Z9uy0}Z!1u_6r)d#$+Te$1(7JSb3nwisamQTqA0bN~*nVaM%+ z7e$8c({}j+CfExJ3BrW@PA>rZ1WfjUJKz9)nBT`Hu=qQ+2tuA0x0$hABlrl-Xip6g z|2z;ARd=-o19}+HRGL0L08a!ph0)$ZbUZ@6Lv^NgK%sSbd4%A*74@gvmVl~Q_?Nxg z&};jtazoG)%ZWDXwWhmV^OwZxp zTQ;B4SmL#`8_I2~ymH)djwq=Xvh0Yt2~@o5c4PXWa7pf|o>v{4JjCv%<7Z%~3hzs5 z=dx<`l)wQR(c&)P0yPJc=S119PKN39ToBM%_?CCA`c4BVN<^25Lz6Hy6(olUG;0SMM~44$?Jx_iF=Yaf8_6N=7ZEiW?o z8p$5|dzXl}DsDQ3>&yn?k^P;ZvYL>P@Z$@$Wcl=>2!@px9Go@63IU11^xN9wUF@u? zI>Aut@(BZR;%;oWsEbQ=UmAfZiHr_`v+4a5$-E{n|7k5$ka2|qeFOancD|r9&Bo4L zrh_ql)a~!tN&HyG<%mWUyMpuAgu?=vYolRq2eJRIHy|>#vzH01!VMH6E(t+732Z8* z`iLLGlwb>G)LL9eW{R6ni)jYkknB+-uuPPo8!uA^w53k~dJAicK2MJwPO$ohirSeN z-!#|__Ma=G2YYomRL1Z$8kU!%hM;A)A^;Ce${u#GLhqrVxqAOao9uDI$?2yBX0&Wo z9VqR;+eq>yoiCj*ryv+)9>JcHp{qp+91~Ueop{o@L!L^q0nF>s)e^J&95^SXsA%9H zMj}Yx{ws?C(*Vuo*Ec7Is}Jl73$6(TRBZ7*>=R#`00xlQ{r^eUVtikU6P0+Y88r~n zBRkrE7y!Yg#jn+%9^(UBCV`v;c{pG|npC=3aj6nTHOq06e(OJ@&)%8+I4xnx^4WHr zugi%0t`Z$NHb0R2cnXV(N244Qw|H}Ic zeHkJaO1gOQWljSSFh82UGju4280zw-JW&dBIzEaXunq<+9p!U~8wPdpa|>*3rKkZ% zQ)8!syw1#xI;n(X!md>W*g?BFqhc5@A}V?p5JNHyfzx74LLuGQl9nBWE(V9pZ4$9! zKu?h@94Y*BaZf*Rhhp-*F2vuFV|ytprXqJt|D1n^NIluF&Cl+?xa=@~ z+)7!1L;1R>(Ivm(6;v#JxVhJb?SJF9g}x0DOM>~6p*v^Iz?%^@-)}D-{|I<4O#rQa z$gYznPbnGM!~iOzf4(Kq@oXsyP#JBc*u*Mb^w)HFMe#nx8qa6yLd9vdAcl@jm{G0# zDEtX?{ z(>VxcQUU3wd@Vh;t+y9$JY@8xEum1I9b{P2*jSBZhWxl3#ynQ)E9+?K^Exp3ccuWa z0*(mwyg>w@lYPXqUIYIW+B5cY-N3fc#DA>4yi>)}l4=>Dj}Os1cl6$`S41=Nq~T$Oc`WzD6CG6zu1&r3K?!UZ zeXE9+yNU4)3HbW7zU!yk4HnEF%*gtPRf5Nc8%Mi{+UyfeZ7RU{*xnr3a>sWrr{iaC zC^wWwju6HkJ9beA-oOP6IeTkT@G0o(I4HDD?wB%Nd1 z>TkHP*`UIP@Qr!Z76(Mo02uf$ojWRNp~p&b!Xp*Ex|*Y_-|Z>vM*uQ)TYiZ__Y;B3 zT_JL;D~88oi2wwrfifss8V^rZGu{3O=^i~mmtg{c=UHreDA>GnLUbbmwdWg_HHSxW z`Bpb`Biq?=e<)+*0M8E7-pCNVq65mm9|#}ZpNU7LWEigrfwhf?ESYLpJ)Dy}yq13k zEx+SRJl}pB1tk>~lzTzm^P3H5H*nf6iDsGG$<8ouaetNA73fDC=8r>THSbSzqrCW% zgX9b_+}ehfRZZ{kFy~7^qYK~qUnL1TiZ9*Rvb7hK0D;b;+shAMVGV6b-UM%9(efLh zZzq>H-T8Y6Btyk0Aokq@MG984QLm0uPYJS_oMBe-@uSK(mj;gKPo`6A_sxtaqDV5{ zLTj1YX4#!Q>WXI&?)`f^bgb~24B0*+bKv7Y1vnw=<}mVp`0}>}p4*z`;rpd0zn;6_ zJ}&y~z_OWAX@-X0;nq^MRHyfDX(!WM(@D()4~sAW>a(zHr;7R%#Ssd+e06bI zJ{M(cW=W+fQ7t%!H{f#LZKlVX&x4Ypbzu(@L5rXK3Zj(}xKJHMUS^#lw^X8h1!5}kXycH8a8W!F~!u-^hV20&|s&L~+aW(MQ zF(B=rD^ZZ_AIjvNZ+H+_K>?vu;;v9OZy`{s?v;>8uI+vpSkBS;*z+Ms7J)QX6Krbg z5Xt0Dffgbq$}N+sJ+a}C{UQ6eb97v@3P3DjzdnFn2&a1stTBhya*GyHZRN-y#-+32 zOgO|J1@~eb=d55n%~_R_G6?LW3IKJW$${c#iI9=u&w$2cOFaWUv7_8N5Cwz-N~qWw zS2id=vMCoWF~3(S&CfY|PLwJ^beb$7BX#W0N0lvUj>O%HXcWEj^bY4Q&=MDvOppH?!;se{sv&Mg`6oR?x(op zO5%rWtiDd_zIk?e8;u(HYugC^j^Dn50B7rrF}eLOph1mTp199#j}=}U z!uWGBanfBkqvqF(!*ItCXJ50LDo3_VGYecUI7j9-N_GL``7JKaMZ6^F1z6i=_LkO# zxLra6WW#ThJf4iwc}Q4ltQA2L{P=cutSyyRgfEw~UNC<2=v8lWo`BJCT5=en;q3Xy ze>SBc^8qT)z@I;W2H>p32-9X}hGk|nJ=SyDwpEt857>(5 zd^Iq#f*od->pY7SE^8i|dmlmy34eEFyQWgQPd}k2H7I3%9Va z9X*8O)FWN*%Aid-qrSbhTujWOYO%d=-Xq5v{IYZR@D$BE*?!?3iVXHkP^_gW^ip-c zd~hpp?DZFZ-Ovs72+DX7%CqOg%{h7|2OAM3wUWk{1lU(NmfpP-V=kD~_N~|3Z=kde zwmlMZn#X~dPKcZv%?Du)lx+ovyiNQ;JF}#e%QE55^NU3i{$$<~L}G#BxStNTx8ED@ zAHUlkeYJ0ro6C`;sL@jm<}i59 z^Efk3P{dL(iD5FG+f^=tU34)QHPa(oE1IE`e!Pr;yNS3X$Vw3nVE%K(BG2aN%PS4C z*yzq^Lc#j}(pzFbzrOAD1>$oWWA%4&xG2)~{refNI2l4yGURMlX{o7S0a^s(_Z%{d@n4!>f3S!qh5H53noYmutFPg+y5&QOWzFFx9|g zEQpR9=XjL6CxUJFLcQyv;!+N|)b!cV8YKmV5Re_(WH@*WGE)eM7RbBp*uEz`q=l)! zQaO0jL1FUgwe_tnno(_T2QaJ{OAtpju*Pio!T7vHPMD`IKjnk zbSAm4$|vN9d;fNhp8W9AVoT4{XSktsblt1Jp$tjTCi$b&D8&tv8ZbYBwnqpQEvKF# zP)B5vYzi-yucM793=MzouH}8Q`Y=*p$OD-88v0ke>l#VlIy=MJ_eH(7k;jjjDtL0Z z?0uyC3(hE8AM3EGu!L@Hh_hwr>&kRaaZj1q{~9ba|Av$b)RXYG+scfrdOR6�U|q z?3B^U1?}e)8^1QxUQ6~W7V2hvT@9q6OaHu>ipIrGeMgHO{%yAHMY$8;_YZPlu7MU@ zT>v6O@R;>$)QhNr0QZorrjF+blk6-6^VYO6c?vmhMqJt-3E{iyv3UU;Fh}>Vi+hnc z_YUPJO2WjwQ3pr~dm4~Z`L|$6+x%6Laq6yf->6;^DY-mjs=Fg8W9MG>jhiYVRnFcu{6M~JwQ6&3i#Bv!uZc4{e15WNulotf(#_}wzA zz(=awWv?6u{_W4p+gaR^3?naEGh7pp#UlIC075G5>5&0_3|2u)RNXDy>xR+1R>~?T zkA}7amF4VE!|nyW{Q!WbSbDdTcSZ)Yvhj)1KNiq=-$`7 z3^p^ZryY03D?O0=u{0r7SQ-YFMsj(J#w_F3-vG;&EiA+S=Z5{?akTdtGZzoN1P<$` zCVN&LI6tLrAQOpZP~VC?TwcvgqNJw}g87(#Rj!a0jbct;38G#DLD}dJ_ySXLhAYZ^ zPZrXr6hy$GOmA`ItWWDzRfH0|l9lYc@+hAx8ciq^=4`3UmC=*ZP$t34qSRND^e!1* z@u47P;t?0Wb2Xzyf{ll*y&tV$|Fp`6(b>S&E4I9?^ho|5KoUu^hb>fU(Ni*DFP>(! zy!#{NW4wRhqPnRx{FrU?smiH!Vd*mNONHgex2r?5q3rY^f9q5+Lzd4omk)hJoxFty z2x*CgO-$(ELeh2X1t?S@WI8MzUsmZTtn@W-@0_jdz5F5D5MndIso2KE8mmT*wHKPe z_9;)o{Ozw(%wJFbyf95+-PMLDrl+qVkq!|7kV-!V50;(&QhT>wXEC-W>yO#OC;yr( zld~`VxHE#5=ZH9gOr@saVSgV>B+|Ikx1(DzIcF_t_1PcrZf+`D+9-=mYHON}5R>4z zQPBJNcJbSVhE;lA&n2kh%&gNr@)6D8;w2XwZ`>BZ^nyEnqXl=2E6by0gk9t{f|Gkn z_8mnGTDPQ&iAO7Ircc`SDGS)q;-=cXE74h5>fL&=g?I<@*ob00Ih9xG)g|I8;6n_x)= zeQjrm-wYC`3?@ch^TDNKuQ^;%1_w;Ni5>+A4Qp|~1o!&-(zvR_QBtrrxsX**z1WKP zZ(f%q{qzQG(GzA0^3WG2uQJ$qUpSr19e(7k9Fhtg>lqD}7VkaAo+f`pPjvSO4?V=j zPuT0PM{&iSR^MHrR~LvYh|ixE6c1JmokJI?z`b`zLQ7^TJ&j77>6FC3^VU&Bfio&h z=(YzvDhm$l{2;GrNsx)WU2*xzvcD#GY?noxEnVkLb!26&~ zef`4W_jMg}Q$Nq{L?l1s>7M>sR4QZ_`pv3;)q-_0_?0-Nn0nR%@o0>W$VA!x`x!@3 z7i0Rx!<(WcNq4uGdZXKe+X!*2RFb|LSU@chK)Do)bP{Hd?Gd9#i}c%tta}e4Odr4f zVofX1l}TfdO~8A2C(0I8%(UMB;w{>!d10x3*7c2-ilTvCR7Dd56<8HN1FIH(@KM|& zLs(ZqlOi*8|JPDR-x98?@b}>74)rx|NRst!>+mk=MwgrGebdLw>zY^%yd8SKlQbhM z+}`O1s~F+#q4jx_S*#PQ%tU=Wl#yjt9g@&g*JC0sxtLz#ah&i?l{!3KL3^QnJK8tf zyk&6Qo7sjgBMrBw>T8e_1W%*(8$}&}GS;RwP`*JbNMn)YVH_TDt^w^ZbLrKo?uXlC za^>`LJLFn!tm-VA_2{G_JB@}PoSf@sG_s99o2B2MT@&^SBUJMR9a~G zyHyz+4VSA9O^RnL)!g5d&w8rdY?KCjN4~jGKw78Er&nij1J?E6fw@x1y zS3FB@n|_m-9@SGT>r0a;=)qOOZGrCTXt1FyH5NQVU0@$fw7EXoSmNe4b}}`z#EXaVuHP1uj#k zi7uvO7Y-LTibIz>JK|#lN;>Lm_-GjJDzpZ(P&`to(B3|XI2=XW8X^hnw1DB*)+dyi zanUi$p{h&)8*0xyTcU5bZ?W|0`H?}9z_6SY-Q(-n+k;!i<1>1E3j|1d%dVg8oj4w? zCC$#qHvCvq`b=+-LB)CsK(~$h00(y@g zE7_)f=IbAcDpy`7r8^@3>(9AQce$5bDXl#=JGRqJu^Nu(oSltW%8JhxhW)`5c3UPY zTt$j9#loWWPZ-Ow#FUBii8=Fw>l+OfzNBG^qz0i--Mbb;SQI^Ib=NSLiy`7X^qIzI ziS)0mjgB)l&=kCJ{(3O?tKz@~ImXnwNNFY`hbHT8ABx8==EItoGWRRugShZs9KCL@ zj77rHDxjeMNT_^Gl>e}8d4FJPD(L$nOdQ=C>xz%$6CKyBe6%d1XhVyacRS0#DJdu4 zwB{34zae=CzVz=zA$KQYy|9cnQEQ{4cQ}5Yt1Fvlk4~G@FAD13!EoHT<2TzIk|g6! zc0S@xOrEih4%kRXTAaJX(m z>y0U%O-qE~m7m}4OvVQ%E{UmiCH2b}yZ5LhB8kXj2V+9)35JU6B|X0RzU94hkgz^F zc+T|bQDS<&6WR0<6`ByK_{06Y1c=8wc@WD`Fary-DbP_5A!(0K6cas-vUr)uuldt3 z-1MV|Khf>1zFb|2{szD3qL(GQVxoa!-%4^G^r|%5&vBD+l=K>EiD>rguWah#8_R{e0A2Nb`5-<)R zx&A_^Agug6OhV%BTcYt!@83h7*mG<}zXh>322NU}_z(M6tpUt3h=GqEQ$t}V{DG{n z=3-a46TZyA4l)bDl$^kHX7Q#A*ms$D=hMF6Oj+DlU0=#-$g=ecg3~b7S%`=>?Q8ME zqdK&iVhkpjh4(gz7&Ze+zL`6|=dFl1_Qim#aciwKG@d1(bt@X&a+g=7_HBrSHYZQ$ zg{H(TM{RBVg^cZRV3UX-a*k+Kyr(u?(TUb{nu*j64QI$QpRy^j=m@N07;^)^e@A!o z5_id1EvD@LcO9xe4Mf|}&?v!oV|Ls1vIdbUQs#z9J~16cs~kz-m0rA=gr0~mv#M07 zutRQHTXvRXXN$K`Hf7+bjLIvxvtFwqreTd@^y1JawlTSFo)%QiyHi; zKGH5#VXY{l^beqG;kDGF)Fry9+(+OHyW+pd;|CSthjKBFcd zkp$lR9mVd^A3KE5ITQ(aJj2a^wU$BkgST}gXlBB{EkaP#8BKW`rU-Rb*6t?$N}Rm^ zs;sWTZYSKDwv8q;jF>r0Sm$#24yn(#uN4`t4=3^&(p6Q(e%vQ^aO}3dDS44#Ouea+ z=A@3vBD^-SOMhFTg2>8}GP<|2+C*_Orc!G|F|_2gCj8XaYogAgf}KlYHnfuw47ASm z!uNAU@8|Av^o4wRBR#nK!mh5{KU#{~+`%W%Q>1cle+5y~8spcVpel`C_p`5A-87`6 z(Pq*fdM)Ka78YI^I=^E8gFg$xN>oY0;|n++3OK#aKVOOo9Ps9tOk{U$tlmx>#}8fR z@H%cJ@4>JFG};AmXv@4c^m>+Yd?|F8G_bpA1T29VzN4uM%rgafQc$5^@ zM}HGWtC&^RN_n*&y0^Yy@?hwRhUcKmR^sGKOQ>l~U*5nBv)#NC75R+^Y%zK~e`?@U zQmTHU33XU3l8v^e z)eDPA8|8&4TExGLijg&`HQE{Vt+bAM%3!{sMNRWpR`s?weCbn~p?TcmqAP1><71{x zxX22brpq((beuJzp&Jg-&yUpgh)XzMt(8*!2>VUR#rddYED&CjX6PLU2?+^UMk%(d zH~Sg_5h=B{WoC%X^Up<`(-}2+UPO`aUL2XE)KrA96`G z55ulI8&h>Z1&e0xz}+|e!A1Yp5d8v?mj0;p=&Ok;&l84@#Mdh9;BMunFq3praq0f8 z5b>vPC3^44S^mJzgwe0aB}G;(v^rVPA2Ouzx_3Ufw43Z0W>S3&Px}~de+a4fs=%o> zsjNQ43=_ZiJX2?G%~C_8ueGe?DV(ig57^UEUC| zA~rSuEf89sQ;_36XFMULi+ex1H0Jg)BZ1DG@MNkgO+TJTWttFsuw%BOWjtR&X?Wj> zRgs&r$l)~&8U@DyF2QW7i@R7zh@QrdH8%Uiw6@rgqn0yK>B+a1z`IzCZ&bZMztzXU zL&FP?o4$66LXD#bRmc{T;<@QDRNZFr;B|gzUHCE9+}Dgu$;@z}nH4VNHZVGf+)TV! z`qnr!tng7yvLfvx4k#ZHi^O=FNd1 z9m#$IENZ_%ekCQPN|pR9Oc6XbKYkJ`OAZm{`bOM_n2n%3FXqC*mJ#tI`UZ3Fk!7gSdnCmN1GN&UC|9vO_>n}%G$5~7P1(d zhe&x5-wmZ!(plct)!e2X5-R)_n}7V!)pk8dd$u)LmCvP#Ly0X%< zoE`h`r+n$?{k+W~JeS8@E$OOFGBO?Yd=#?Rd$MDqW$bKeHt$3}k0}y=mGaC2rH7zu zO%6NcF9a*tZ9QAoJzLkz12(L6WW$}6U89D|BNhCLCNEzK+L9|D0H?o`Rl;Av%8hu}(a zcHfoSE$hUSck?^W{fUL1)%SGJ0RHS&hCyR52?eT5X?- z5OIzuj58~Wc0pxM{3J$ku=uQA@_So+%uK^s9D4~*Q>ZT?Vr$IuEkR#)Pv^yrr$)qS zSU@ zq6Ml1uyUh~LKLKDvi~y(4BaSl^cdznq(XiCKkAXwZzb%7-p$M$M0m!n8&sU2nW;+X zZn8FKJm1WJl7Qwyg3F`e{H4zmt0~_Ise8WSNWCIc6&d%(ecmD=>aq?SiycG9vzm)f z5!|E(2gE!3`RR{7N*K-QnEbLw;&3kGe^6OIkgu^i1fv^pX7y%|BAUzyO>f@Kd;ND5 zGoyBqt2qQIaoYbP*)WQ$ChF9K633y8F}j5k+s<1oa$%||l${%AKfhpT|GDnMfs~JI zAvqHp3}YLbn(D#;m}1@lK^_0{SNxEcIsYzAmG=89yA_2+tL%<^3E+oGrbFIn4$>@e zP|lYxhJE4i7M6PXTAleI13hPbCf#hk##6Zh=a%)47_77uJGAX%BkJCZhlnVQz5IbQ zm-V6Cke*rk%VCjPt^_)Lnl%q>#&-IONyLh{4w99inkyD+O-(YdF@9M2R}TwWB7Rc) zaU)+W%E57-njtACtuJraF@5<+!8Rd)3zTcc029*TqFp-suhtICIR;H|v^gW#&69dE3i$+9cX>lME(5;rq$_Tllwm>DT#j6FsSivnIJ2 z<0lqcnA%JQv0Pe;cVI%3cxBNxhns~J*QPx8Jsf+Ev*zr>U)2|G8B+vLhLxohR5W;~ z%f2Vh4%cryE41|N@%a-c<*>d*FR(D!TkbP&_{9>-{ zZ>7~8C)l%s(`6eiGFFK7>e(Q^V;rkG-lOl2Ngf<~M!vBN$0gC&wmy>g%wNtfad6bJ9B+_l#gDa;FJWoFdrrZmh?v4dY`gg^4j~U;Y^(*M z2!GqYo!X4S^4A@gr&UJ6$s^Y5h%iax4OAGQ@JW8?WMCvX~#GHJ&qdyQZ@8`GZ1-EGV&xZZr z9viRr(0Iq0--rF`mwYwX;_*3sDaDPQc27T+8pg7`)}MR$v%&K1_#Y!G$**dzWbW+D zuG`;>(Lb*!YD?AlmTYsnzsxYDG??OKdsxF2Vi?8k!R55^nll75LjSMh1_1$^jLZvq zrK??ld+b}-+Iuu$G=Ol!a7|gtJnIv5yhx!^ppp9GOwf zkX$idXM2*jD%%+6z^y4jA6%_6H2EknVt{)v#Op`F&~T3l(d6sO|b8c4aKq;3kwTt0E*A33G3JI zjGj>bVC0uEK^7@OU7$$hp~ikPUGKE)Ev`>vg>FX9`GK-m^$|?oxKG9MI5zW#&I{ht z{s0>a=p|jkVkFW1XepKPIX^sX0B+XS)|g7Bw1(O{lyqvhd~n zbff{XcEB_Tr6Hz^C^4FoV^v<=brQnn_^_bO+xi|u7;D1VA#bc<$lK@&6(wW^za+^l z4eleRPX_!#Yz-@pu@ou0E$)XyckbNzjck557S0Pk)ecc-yrbWky1x(PuvmR8e1?QA z@MgE7MZD8*`7ryqk@40Pct7|SFns)>lCX*`H(&cwL%`b)Mv^zmE}#T0teIfL@k=Qa z0sInZF_bBlSyAyahr1RbO??0cpaXKfC?!-C_OFHNO?}8Iz zb>cr;%fXSH<-P<}Bht~%9A%r|4OYi^bLssQ-=jm*GiL&SU@I#F#{kWgd~jHRw*s}H zE{J8?bh!Lki&?O)<)Om!lweHIFDQ|HMc~*7dMd`nO|~j6Q1(5vsCxNB6dK&hj^!x# z9654sz5Mn2@V#N6TnR~WWMrZ#lr-M{ejU*|TCcjr;mmT(r+VzM8T~x$A-DX{7GjYF z78W4gx<@deWeSTdiDGEWC|IjTWTAsMc*Kra=?*USBu2i=IZo?Kv~g2gCy7Q?C3OGt z`g8sL>fuYxcHR$bF$+|E6f7U(J*WtFUtb(+!=yMHOsZc=EFxW|hdl$}W0By)*fHaMMXvp)JVK$prU7a@n*P6#aX}ZTKW7JOPk8)pHsXR8Hhf}r521cb}nk`V2 z>t@Yk<%&8XA#)-Yce~MnB9Gxp#z3pYmaOA2Ve8NC(e_^-dv_+s0{p#XrtH+!dX~_C z);;t@6K8jq;ioUNcXQKsm||6!}Y6^cqJsp~Nuv!i}D2o%;5>!YZfR zy$6ryi{g8tHv+n-UtM$JZFkcabgXV1nL3hCL80`LVG2}n{d)ib7@YqCHG-w#KY|h% ziLKt8d%n$ky|xM8H(J>5F-G(g;<@FBFUnTzQt*eM3V|tr^r1Yw zzG0K*tadu|y||zmqa?sFRnY5c&zG1Oe+;cO6Fe!0Jw(D(|H(UpFf;1*PpI+ z5R3v%UihOKdMaI~%2z2E*Wy^8?H0TfA2SdIcbm&p}New|yQVh@Cro7t+_y!;PcrB(m3 zUI^diZ}yCrRGELxr{ajFL7I_@e;FBJa*&tlulMWad6eR}I6tK*?2&sZd75~t5|2K3 zZdMLb&UMBGveW!r@u8O9K#<35A6Wrw%+o9EDRz|qW48o7UZ(;q(hfLd8AZiTfXWJF z5aRbP^lZk7){xNsb4FoLMiiZgsinnSxR@Bs{tfB*y$==Gfl9$8hf39C1%+RYuYEVZ z^e*WTRQCGJ!g86)S{$TU5VH|25*TBHlRe+>Biy>;LnXb2z+{2Vz?}P4Ff)6P{o!(w z>+29Xw1t%VmZa8t#glKOfU_bt|tTkIR^Ecl%o4ML8 zM6b;s57sW#AsB?!EzH%2KE?k2=(M!7fehNCMu;^MPm(6WS1~VSxsdlcO-c(q?)WmM zhiMXtc|O@(>doD9_*@=Dpu_g8RZ~(AxM>8VxJ*{7I4uEmpQ+l@C6?7xSuUv(u^)Bk zK6?C(v1mr^KnBgEhmRg1u&GW?PH^$?2vqwDW!2OqMMcrCpe4K$#>bTbNA9YT#>5|( zQkd!X^FU^_8-jmcU`jgO^E({s{7&H2l}5CQDI_Fx z6Lx2D@$uKHIK^Sd1fYKegC4Swe@aMEykddckZ>P*V#SewSU4sAe6wIAV9bQLm<3ap zGRzMkdtf=UQLjX^V&Y(OvuQP6!rA#%Vj|NM&;hAHH7=-=ANq$?)o9hv z<@3{M@tgWz+7BNF96siQzi35TMZ*!1imn`4oDc2$nwJ@PfYCoNd~q+tBAQSN6SZp} zYNJOFXJ)@?3N2F|qik9qZZ&F^hng`oZW0E3f9Ls=Fu~l+(mK(84h{BOp)Y{wryv?O z{|D3y)P8(@cjH%aqOs3Rd)_Ec9N*!r4N#P`d0f|o=`SNmHJvDEUihhRnyB=TNqX%f z5wwwgTOxg*{Z8ps5D7HB>?$?&;Zoj{g{H-EUV>IC955Zb4 zB0G~=wV@r-DI%;D)a2e_(Sk8Ch@TkcQ56P$7k_6M`mfL%tDnVQ34My$+2JNaucWD4 z*kii<6f}(2(*}#fPy1M`Ph6#*FSOh@xz;$QOEI<4uG>r6FePCtM4Mpy^qu6hvL$ro z-bH6ZHC{1QB4)bHG~5J1)otaRR{=Yy6Du2eR1_a^(+~GMS9_cAU*N2+lFuCt5NCbu zolBu&;1Bxxuz`$Mu$jDFUnLMCnIF}{SBALQ6$(%M-dSnG?BXW3e{YE=XYAvQ*Km2F zsA1#me`rm6i%Uz&rAZpYI4>eD))XG@&E6y+qT<^<9g$n$>W$uy(jEWq&x^|#;ENmf zDIZ$;);#<29+ObkrlP1_hmd%74{1@0WM!EZ?;3ys)J${v9 z;TBUhL;5Sxx*ns}t+;gC5H_f)nz~*LFQDn)~I}wE?%y=~UPH*oQ z+zDuJ2!FxIB8tph``~CBV85rNKxzNSpnkS~DxNbQ0v+HZM$@%F#E0vrR6 z0i0^)@0Ok4=NYH{B!erd^0kWWX<=EC8|0FR-oO;!&F8CDjWm9U7EEV}t$**%fy~Sa z+U^uSrKes4fI(Lf*b0jrCqQ!0*9&nk==ar~ix4V^1Q}@LPzWlI@m=%t%>zYV%5kEJEVK9G z-|U`3H<6iQ{(I#-`au4_D5phinBNq@+?HSkF14~UdTvgOp?ZaMe0=Gas=8}Ua6BMZYrfkHGubh;79<} zvm(nyQCJH-S(S3F?l0p0yDGK1v7rOL^Cr<#Qfzvr7E~R(w7H$Q)r~w&z82=vLldg> zwssm$NevY%M0ZZGO#{1vkFdnnAyp~d0=giGHdueRG0r zbC35c8{@912-K>_B~N?4d3a z+qqswrHPG*1LHdq-rdLGj&@d8oF~J#(q$ zVd;vHRl2s)BpSPreB0W3z3RU~o}?c`Wiz!ti*N+@kMH~evojBB4X-NbzgQyxa|iSN zlNK!)0AK24Mkl0|P zkC(6%XT*#%PCEF-;>;uP*BIVt`{Ok|=UIi${!2oUKGGh_vt7q=?}YXIn&6EMlm9}b znN`l-AjA{`UyViKJ~Zne6a?rZ#g}LQm;f;?0;T|g@k4GmhtdB}n5b!d75@ErmXRMY ze5#1eCCPz;u&(Zrd)r3{5z0?y>t{L~wgZTg!j z6h%I56U50SeFCXAcaqc;Yq?6j?#2=)(^TRH^a?f(@=ayD7csR@3a@lfl#uiYk!-#$ zQr6e3`Fe+Rd0+wj45|j8jp$``D-ycJ>EYdtqOZ~Gq{M#XY;^VF zrn}yVV^6unS_}#plRv_uPtk3?sYTsa>l@i%QVRA@-2)th{|onoki#7Z>-Dy1bC}eV z{2H9^(b>_m%&XfB(3KQEe$BosNhxRA9vT@$!~O|*O2}tV*@7d*;srxK<G9FSy&mRGJ>U6SC8*4+5t9^pzi0)dJj z*xA|+o7@Ia^-E5mr^_wX}k*llLa)X|Xa8ofNoc1aNbRWx*S@dDNZk%gtAf4C~upW!aGs!KXavHwiG zJ0?RcWbb_Q5#3|iv6Fj23G0k!wc=(mJca{!#M6DqssG5TOD6n zAFvkDHh4f|?|F=GQB{{^Z=(^bZBXW~EmOMu>bpKHA$z}$*j)ZRqp|33F$!j1I&y*C z%6V&!lfT=p0IG(UiH`78pVlDZrF?M5vV+Kue^Wdri!DH0R8*@woCf%axDZ=urDbKi zz{L|bB*}@fGSOxR(3A+q)xVrWZ`64X0LE_*7!DASlaFL;VF6Y4O`m=j+OyVqLp`rW zqMYt>*~ddX71?3KiMYjf(~Wrxj8=EqbmaC(B=|}l( z`@m@pPNaZt*UxPWPhPI_=jwGNPmRrjuAG|5ZSrPgqlnb%zRtsEJ8UL_FZ zY46BSN(ij;vWEYS%RPP2uJaL7Xp3K*v53EV+^@7^uy)D8W0$e?7k%5{uQ#>d+>XPG zMmu_&<3Sy0P@7JJKdZ|)`{WqT!&&=}MHiVsC@l2tvkJfAiDwPhd578iHE z5cYxHQVE=U?S*{kIDNbfOGFvhd+P6o+V9sNkt;;Oy2=DhJQ4r-aB%$I536~o1f%As zNIVxF4*r};SG4%qb|(Zc_)uU`PJWN5?A_{3;8Z|>I5+OfD=HdTqrp~*FZ$#@Qgpb3 z;S4HsGO!7E1{3+%!~+t*^OSWN=M-BbtBQeHm7qK;3}wL!AZ>Y3$JFS8A^xmXUDIey^R}eFJ9vgiF*;}F)dB5C{xH5))W!g z_3Am#-i=qe`K}9>C3Z8j(sHUuKCTmC66RREcZpC^Uals4$jI@_H8RNKKD|Zlp_M|7 z?(cu=7OBAXd#iOv#9u;x7z<04M{VEjai$Z=a+m0INb(UX;beAN+UnUzV6M<8%lxo1 z$9?@_s_)KclF5RHi@dQo2ENrsgCe&`Jh6Cv-%BrGOY+_#;aAHMcoit0PWIZ6X8BHy zNO>YfN%X_GA^zl)7Z(r7Y`!M8_z~3%i5IgYCyYnE_%1>yiId+dt5Gu7C(ef_k7AsR z8Ajrz+9X1r5dE343Wj0tjZWg1C!Po6WIuYs9CS~^z~7`RlF6m-s%>-rw&zgC zna-&1feK;e^C;R`A`zy)VnfB-zrSMaX&K(WdZ3yZ@_599U3xzJK9nW?m{!Po;YZs^ zyyTRJ;5=XK33)^d=WlarNoz&U!0JmoeKA&w-xe{*78}YwoXmK3ry|3%2O>{37t$Oa zarTSP*{|{+5)nvQ>?#K9nNn_@)6pHueO=BO zH$n=DRvYPEtgGje3f+2dgB7hYRUGeKtJhYKZH_L*!Vg|G*>I%Lzb7i?=u0!zo*H=e zd_wk?&RKn_VQf!Ic-YzuQj2Z<~pX{a7-PW7nTQgALi8-CfOQPqY>5YQpV* zPx|c)40fS%9&lWECnu+#e380(ldX8X&hX^ckEvZ~QgrNMZyf{(X0k3RvJOA^d&wR9 zbI1M}I<;O-hgT0$sL^TE5Om+p5kO!<_az>7mg?rrwz}EekZG=(!?JdL+*i z1>SF`m-bSz*o#>*a}sj*pB8Nne~uTKXx}ZqU+77c-ZMK8{@s@>$5cH&LCj4CvV-i{jtej#kSWB5ci8D!8C zOP~j^q=fRFw`6PjvhPRj4Jy~<9FvSq znq;HZDPm9Rq!=3>#*54Zx0_H5*xh}UV8s|OuRT(eW!an*?RIm=cogm3+XqN%dq1-( zdPTKHABsP`9Vh9`x9b+muFU>3OPWtF8wYVhm8fM#KSPz)YKh$7>e9G!*@q~X#N&c8 zdyI_|Wq)YHO@E>#^!FaE77*?}i6A3hNUdRX>JR>{*p1)Z_pNis?(MP`#yRn3aQCyf zN@^p8L_gg(spi?z*;^X?_6@>|qKd0ZEwQiO#hkowGr~wzUMsbE|HFFb?9C@wDQ8$ zO{(kDWrsOx*2G%ipVp+ugdzfelI#7Umhjn~Fd79g_ zs7>u8kH2x3Se%Ck2@MSmFt`WHj1_KR5_SToLzMQ|hxv8{I1;2(Z>lU8?IGAeaMuHS ze87Fu`ydS!Z-pXJc@Ab=_v7ln?~jzl=g-!?P#jLY9p9?UygO3A{NZxsM~+KdQESKJ zCNsNkI-9Gbe!03c+X6((zb)8miTp=!{XAujjA)$>H#C8L9_jkcn?n`m>ZeEBN$Y!Y zFU^-`eI941R9PZfZ%-!x{PUb`jdZqI-TEhxh*{Hjyqp%Dd(R?=|B}b_&$ptvjX5T2 zCf!XpT8tg6zB2lAltEV(Bbjuf8qHt0nWpB^)S~Jf@CgAfcol9H z$!fVTBs_d!^DR0NOsoQPjt-LpXE<+%QQ4ej6@0$1f{(@;x(l{7s@+>0D;~sx3j3Nb z?n@@yRu$^;drWkPkO`sdG-u1Ci=r|yAl$*NSrKf>Dn+29*6)i?0WzfPH*SQ6gfumK z-8LLZZDScZI67*$c5j+3bI1>_H#5`F*sl0RS65f_`Ge;VgQhXle-w?GtXOlfZ<%=j zHM3cT*x1zdurAMDy_WudsZgAe?+qMCMl}g^0eIZp4WxhhC9h2>lCS^ z*MHW2-gZ+^bMSIJSlfjFTiq(7Ix8!S^WGeW6&oQ4nM^9bAIc-`H<19mEG&FWQ&SUh zVaB-L2Y^SmW!4OgKJb@tYH8k3%9qqMG$Mv$G`e=$0Qe@wb%1MB77M6{Uq{yKh(ZzX zfInp){6%O4;69C&pkZJf*s`fflM}1{%6WI&ZK}rZX#s>zi|h=tGBVh@-!EKmr2PJ^ zp-TkEHD_jbaqbZOn0ssUIpXk0(DO?4`z=8ZfC-_3Z#R}S!~z`?GXppz0twl%D_#9dWmf|2>_&viPnFo|CYTU);XR6A4#3{1=@y_83<>-b^zQME*HnTpaC6iV#11 zjh|MV)*yZ%MY*@B!;AEHo4Z>Cl#1jSP(n@xAuPi&5&netvHXMw~ctudc*v2tY6p+5X<%-iGfI ztR6<(M)Rl7pYMMB_;DJO4F;wm7z zHJc*znOH7gMd=16;&pgdiX?E@q=Kpd8I$AxI9tT^;XgYS{^RssokL*_0B2JN4F30{ z0uirlt^J4LwbY5e%$2RGtB+j$sHBIpGd$XwdiCcAAHR?g4iQl=Y`yXu85spH!b5$f zvs!HE=e|7eyUFnHO7Jo1PTUyJq2EoP&~Dv&LPFwB#A*8lbhOd^nibths=#fq&95#t zM~kkhtE(frW6$^zpPf_!%blA`XYO-#&(a1ly?Rjkf@oqIv(4C_znO-GB?2N0L?;*p zi;z$p;B%G0QrYBjn@zIOzRTeU6KMyJ z_c|ip6a*rO~*?~xy2I5d z)_+!qmaC%*RI2t&1DBRmD=plw-NbyGAx_X8^=y070zKiY&}VFX{0yD!wgXO=(-)3) zfbe+B+(%{wWC4WQ*WaIZP)g8D(9w|t63OtV3X!4;oY5yl< zmZ6?Q?+`>k@G2Y+Hyozc;S46*ea+OMCk%M{Wy-1=+H{?Aw0N&yzm~Cm`6hEH1t5?& z%y{9;8KMH*+>k_KC!fUiW~gkZZI}fLenoDanNdG+{U?cLav}EOir?hvk!-i8^J*g_ zBfSd>w%?7!aXGQV%MFjigWnSn1ycqFo@THl=Sl}A{h_)GzEd`>qm2$5EfW;q=x0t_ zX9$_Q*fn>l=;!WW!f0Hj;I={kbrZk}6~gOP`*@=JjGqpnlwh`9491EnY-^?_Qm(!e zR}_JI3v)J4fG}es#a<#rm<{<30h8yan|HCnE=^V2AU|L=d4SaM`LRAkHiX@ApAcK-E-Wb4H;|lX@)PwU7UBMX&gc$zEY}CSd zpf`f^Jzs?#&^>F3f)UcZg(SjuWrkXI=Q{@04~7v@sS(~uDp>q!%s22Im0yF2xkN23 zncn5dz4=BTn;n%}H3FaX>>L7=MYmqP`5D(s0fd<&lx_u71&f8wdms?Hvn&^tGhc?n>7iv z2=1IlKKB9W!!+h{erQZg7fj7OEg5#;MEWQ62+R1v?qBxvpMU982)*(ro<1<;g}^QT zzPM*gjW}E^ufaqvr%@QmxRO=-pum5x@h-gMFr~FtQ=I$vmp8HCJAEh28r_f~&O|U1 zLCp|etirz-ABz~&!wWF?@7O=Rx_H>yue^YtM)ZSsxBufz?*2VHic(rb{4a@F(~lHB z-AE4en@_I(e)VIE;XikX_zh32lN<=8L1(YbKlp6OqYyp;q8&9GL^+Ony>e=c6DJFP zdsY`wlml9)|4FpQfv+pz{xe`$LU<+`CbZ0^Pd9O}uy}ZRQNTt%fink_wp4WvLmBNb zm$EWm_VlTI?iY7iFDzbM0(IBFECqFk4}q!cf0Q@L<&XgJZ>y`TSN@?jc{V#c8zLkh zcB>_%|M-X1nFiPV?*)J({qCE55%8aReEPJlk%oAnBNoF!Pd7I=Qm)Q)#1ppWAjqvxv^Y}9uzpfv-7h+53)0QZHiqjL4yn#0KB_uf{RsR95jN+Up?zKLTsMewwM z;j3J8ytufqBdvY3WX%CrI6gbOKlv-Po+CB(c5ZwjK{R77kJ)OPo2$85_*|Pvl9bT8;aui6@Fd70N2RSMf2nB*X?N8W6g4 zvq^$YJw82Ms&?k#;Ry~|U}j>XewQ=6i6IAqK#c?l##CTT5QuR9LMsxp=9e?heBkiF z57QK*Ws(lazh+~L24l= zygAJCcmzv@Y@FrliqOG0q;C2)op|^VF zT=e?2wY9;W-3Im3fY1dIQiHc$tOK->^ctc@F#HIm1#%q{28Lj@I!6Z@ON2@D^;jN~ z;%?J;Ug&!H(RORnuVq#PoaX4}A>sFrfV~ae_9dWy059y`GK)Yg@3T=td~T8m5uCpV zK?&y{I;ebwMsX!X3}uSrHFnPsOn}OHk;%zV5FUcnY|?jVwWsJDVB+neX@Yk5E)~3$ z5p~yuh^Gom)wu{g-og}MxH6c&-_Afy3AjJjl3(LbJlvWJ_`ZnnZK4Y9jg35X_}^{v zE&&8PdB8vsxH}2;5uX)xy*QzQWA8vDst7}+edQwOPS1fh2d2US=9rzJLuBDLM6M7! zft$N=kyKbs3=*{S{mAN8%8^K-7X@l0zhS!U}lc?d0PwS4sZspb-(bfk75|9Dlt7oaM<< zoH-kWd2liyahfG#x6sk6mwLapX*EF(Yv#Qd0@-_fivzOZ^JvvN~{O%kaL=^W5 z#l*xQS@tY3*GvV2$252`9jU`7tyiY(u#(~fKH1a9NBhOYI-jEGmzresW*UO7P>m&VN#njZ8K%t~lJqbUbP%|@!g@@ykl98pW z>??r=LB!+fF?B#y|`>s%@=>X`8(Ju_K5t$;`z6_v=6R5_%N*gv86x0wO;GO&ea$@u=7vX&a zTA<@kA0!e+IZ$dS`{E%6BC3jG4FbkU-AT&I7q&xCp+V3gB3uYZLSKLr;v0M?pt^v# z`p$E!C2^Qjq$u(6`}XjK0TyV&SI~Zl(QgK{H3NZIpk(Ge@H&D9Q?+(q6f4{8Y|4QxU}AS{tB(UZU=Wcyuo4kh8*dl_&rvv|z#kp*Szc9mk1YtM5I;c=X>t zL$CHR?P^G61I{4}1+~*USo;Zej2)Oa*#UI;9Vq32mjT=VtnLps7$teNZ!+oK3Fr!l zZ;uUn^b)eM2ke$<6?}+9&yzY_zCaD=5JeN&eNvMl<#f8Y=F38f(cG@QE1+CGHcxbwgb5$Wc|C4&yR9 z!^OgKo7n8LIZ5EOOC4Dst8;n@4ifjNd>|$f$7Hqz_>NlzMn95~i;9W_9OV(o7xa)o zXNDmCRE3Nv`=`+#$G-{Rq^^IebbqGE_JPXA_&)ul=DzI z+C5;Y@P85Z)=^n+QP-&5T^NLVq{~1WJfecMq|!>4AR#4Tp{NK-NQWRL0@4i%3er*{ zEy4rR2-0=u#`C`SyLa3%?jPrjp~Cb0V(+!)nsctTbpY)vTW{=B1^Az~+2r~S#ZMke zUwS>;37g>-%J42rr5Krz;~Ew#95@r(i7uaKD9*vahytSDmEqvpMJ_q1O5l>$e5^T!ImsD&JS-oJl;yq$VnKSbE+aK5JO#fuLKjKCoS z*jpDbU3x7amq#P>+AB0v1|=n2G2SV;=fN|0Gs?=p7I9q@@SWN16j@**Zo!}!#G~XX zmpLd?EtC{*=~ZCiVlW4vPd}kT`aI19Gd(@LK<=su9E9Fv@!dj4M+Z0*Uf=(9a1xd9 zF(Dxuzluv(wil9;?;7vMB6A&;kixmjpYEySbF1sX~4a& zIJIc=HAib{%C(b0Fo9DhNesTI-*(bTO_8;RW(NLPYZvd4iWnnO{ravw7FkspaAsp< z?UyWz?ipm0uZCRU(YHP>eyP(;sqnM<^Gl92rTpdV+4eHZd@{?tbBrYVpHr>LIvWwG zivLPF%zHL7A`Y=Azx#q0v>HK98;1&`m}PNE(#R$8eibT{YZ<18s5mv({Xnpw;Hwpr zG}{FPNU_m&9#n=dSVe^XFACT=DEqrUZlt>yS^J}}uM~$=e84I17TM~2ug+QgTnFoT zgx&9-v5G9Pv$K>jkjIjaI^G=>NKeanW-Qa`vt?%OIj8K$b>}7rH~!mS7uDa9FR%&+ zp<_I2m?fpWtFKRNqPHshR@a(;MBP_$q?--k2Y}@!RTuD_K846fif#vQq~3@HkGdE{ zsA&J5zR^)niKz1QXX?q?pLuT&_|zG5;1GEgl6ta^Ahtk+Kn%5&z!x4u!a(5&kfzOb zSg2zj$2khU8bwK?k!wRiDd5AVGA{FYr7)8Vr*I$8*X-7@8EAN7R3Ci`xWLiT@v@ZE zn>*7ZO=s0{I>|aQ59*gaio%xsAs}Fxn_54ZWm-pvA+KV7nh=aO;Zz9z9#3|$H-DyCN!w#Fy>R;z?_}se==T}%Mzb2 zi;~wj>Ep8C0uX(Yo#t(Tgpyl_hk(kl_7INVUtXPVQ2(7qH| zyNinIv2r{;6VsR0+~Z_#9MCCoKgPrq`z6C{&EGuua|?bn3P_}m=AydNvjgnc5RP$9 z@S3B7g27-Rfo37M-oWYbfO5cA#V|pB0T}QXv_EZ@c_cwu5(fot+PwK?w-lDTMx5j9 z`2UHyWbN#Dp(rq(IkPteM`~TS!=GyeC)~#~KVJWVUHQp49SKbLhZgo-#HRlq;lnH= zl#xyO?>=7tV#OeUe-dfacxu+VBG~i{;H1(BEPSj*^T*__HpwyN1L~#SY<-2p70n<2 zJX^P-OYWitT>fuax~bio8d4z!KK_vZ6I9@Nyhb_nzq~>quf=f?MgW-qx4e8z5P#fF z00sYlfCBmf$!4|fR^-~SVevl;wf~rszklvw{U5B2(McDMeyh?e|E z+hB^BIQL5CXuFT&V92%pWvc&YIS{Ob@(XMJfA1TS5r^?^T3B4nhb(y6Zw1cSp=1Sa zuW7xpa%KBe3knKQT<`B!aiRZ@z+1WM|Ka`r-owQLd3pwhoZ{l*d>XOWJSN}RR^{9t z?4?{L7pr1*6`Eyli0W`x4)A+NIFikd3b7#ir?vj+axsJ zp1j@cxAJvzo0?6z!N@^i>)1gri4TAgm;-}$7fu=d?U{;as*B1ygfnm5l z+>-uhsM#AeYZzX_7(}E4GixVoP$?c*!6`Zd`H$y6%!I@dUr&|H) zcAi@}2}Ta=!32S}!HDDnE=!Jw&lzR)vBZQkOP!mI|+rX1hcdc#fR zL%%gbk6_~CpUcs*Vb!TM!WHjt6oZJ$=18?2)y%OSO;qpWw;icOBJedSBZ6>8Ubf|S zC!9(lR9%x)F!!`2-GsI2E5h`l+;(TNshT5G!>wYdD}2TQt@i?WlvxZ*OG{u*FuDk!H-Qv1W3%+BG zR2~=&K8^-IKMl0^kzWgQ9LE=JUbTSx00MW#x%Mw!y!ba#^tl}l3Gp4q*RKQW+(9Qy z3`VdyY&*;qdu92k;PS%6!RMlSVK|Ypk~JyME7H<8Ft8lTV}EVrMu3_dI1(Ffloycs zU_l^00%(_KPz^BPDp#-FFoxN>UN2stW>PM<0ZP2O7+f?E=_#CJe97Y1*Vp)LDBAEI z3LuS_^Ed19;C$SuAyT;D%FQx^!kp)x;9}ofbc;geG^~5G84J(&-aG$jNc35+fcLHN z!s~Kwn|Qi3^E`OGgXU+g7@d+PP7jzCVtEXLxDVhz24IMs!Co26SKl3cMA(rRPKPM` zCAi{ym@G)KmkB-RBHU!6p|2k;-_g`UcP#Nut>KmUM%R((UYk}hXOn|X22gZxdWptl z<~VlCg_!ze>=Q!GRZz5XFCda}EaE{yyOX5P7pa0YOT1G}e+nplJ+l>uQwf3Oyxx{& zL8xoI))xpt7BgdAB8&6Eerai3p)O1Is+nfeIEEZd_Qfdq-#8|oa(Tld@UCe~+7lEK zd+F$8K?K;(4`X)E7YI}y%JF(li5lfLk9N>JZ8jxij6KgS`9G$0aDyIk#=Jw2tThz6 z6H?R%Y`~oGufhf;O9++A1ZwVH!{4AWOI{qi@T)ERjD>|oWe|@ip3JG+1B%EVc<()k zL|xbr6bR6cV9UyvbJaN*WP(Y9fC^^9qW(pB`w0BJC*ZF6&=d{SLbj>V&M<>TbKD+k zi;SDCB?y23|Bh452D=J|8?}6=aP=r*$`B zi&C)YT;CQ1sI_akK!bl_dl~H&Po6(tA8t?+iU5~-{o@)ym+z?t zJIVP=u%3YI?Qoos1~2ivbyy!FC~6K(r%5FUzTgM|8J;#IL8X$!#OLglQIXO}G2 zJoT&g0<&6XYt_^Jta#Rs$N8YpLRWAFl%<**SHmvHBKk?H%7p6KR1y)uIrX<@avpvz z%4d^OVg38(XW*ajkwYwsu^%EXJ-dZ^v)l}neEIKdFUqTp(|hN7TKlbGnjpuxV3%H9 zGD`P*WHi*2M2GFJ3#tP3^Rwc|JI{bP#F@hY;5x9lY@-tn-p{H$RxL1k@vUcOpL>y3f z7l$*JJ7F!^-X2UQbUi99tS7{jTZtO?iTc@Y)(8E%w6s)=IS8Rk2k{4#!R>*D*C31x z(^YBj44s0-sWc&HTVbja>P6*G}PMKp=d=F0>u($Q}2uf}@}Wi@^TJy^WS3 z`?=>;zvkMSyNmvF{`~noobET&mc2K(y&8`PRhb|PPgLoo;0Go81z{X`;@!36`yivD z+H!9XiX!EzWLxUJ7WZl~%~0RC$u?Wvkn1o64KdFgMTj!1=v)>MdKzv34B;8l1q;}E z5T1-Xs5#q;k-TLnv{0vfmKNtHKpQiOM`l~~rQkrL5)YaR+zlZo5G^JzU(+BR9CqKd_VMiVxIQAy$ zYr{z6q%@?hSif!^&Ol6TyQ|L_`9t27Q{X|qd8aU*Ur%!WhYue_G&&$n^+7C@(4#vl zBosn)B&G`DDM8)f*@T2HE-Sl{gT@otzucjj>Njfz_xxgbc(_Psnj=60;11bRWl{5N16>Iwbd_Iu1}Ze#(fRp#)={kl)2*q9Z$LJsh@ZX}DZWIAX&olp zxO{2jh7Bxz6tJ^*s#4<7-6Yl=r}%cw^rtVnrzwYPU`s%Ad^qSH+GV_E7)%);E>LOO zt$r?g0fi|5Ore)MmM0CAKKZFaTn|bYay~ag&Jk|YA4o?^8PjM7yqEXN&x&-#EFj_S6>)B+Pz-vJl(#=BcP)$B&eoaN`KNagXD&8L)a#K z^wzFj>wClpc}f8Ur01}1lSZ2v4^lWWUD1wiT(@rD)8q?gqKL4~sOdp8mzzMnP<(f- zHj-nVUt%F(o|5LyEn7sQwHPk7_NoT7#q}FCe)J)){QOk#fuRB+7s+R-Uq7 zURsn0&=^57@nKivu1Vpu{&mG;^q={ql3^RGpPx3rRgCi*oVLt-#g2ux*6`cALwL1f zG4&KX$TN!|t&JOEUoX(&;+UH5|8)%oP*VammU*pz?b^?xI5xvx!!cb{*K#U}zqHJo z5vZ$7&QCtR5y*kL{rq#$e^Bs%hV{Uqw>qiK-}CIxfYxSUVAudI0Z|S@kS=Z*)qRFp zy9CGuc#T97lIV`g7qeelP5>nQYJ9=UsjR6%>Fs7gs%8cVwOfL=gw0Cxf0q6WC5nKDyCJJ*Kq#K_TzkJ^W`Rs}~it2<8^CO7Xxng_? z&MiW)X+{R#xRX{Ffa|o0kuZVOKqgY5f&oYscLajhuqN?~p_CDgd`W^EcMvl~aV%Pq z_(m^(|I3i}N@umOrYKz+_$hkHSASw#{Xjd4uYv1?i1!pHJZ&VYq`(sc^)WkF0cJ9z zuAu|A(!qO|kZQ`+T7NGM)P#B_wp)jN9}Q8PRnrWQZH_sjh+Qqk;9a*}M!g;Jo}6b7 zBwapPnjTvgf8^h(nadEBbdeB9Dkg|W;>?f5-<8M9Gg)Zxm3~`ojtaP?#lFk;Y0%$9<`vb=}oO z55W-ilU$rUJ)F}nRSXM73Wd&#tH_cEX-JKwRXE^(zlG7nd7uk5u& z@{A+&763reB+Yy#m6Q{s1@jy$R-gXFTG+T(;giWhVV_Va|E#4bw1J&30T1t~nxGGR zbotQ^7%n2PqzLnp&99(2{?^KRxW$)4BWKsagKiR0+Oy*nVgL&kc4{%1^ug}39Q$z| zcEDm_AW>^;>rnRKEtVMII*_Z5ALX+je^Y7#-k{EUlI5*(LOF~xf@Tfic{gp_rkE3l z?4f4Q$<8hv_YevH*38%q&*7+;n9lu_MdFD>GDGGGGH?yYbwf}2(4AgcDWymK5VN%WS?KZq%nIWUqSg=j zW>hVOw4(<4`YbC!KN4O(JcOEnGX@87o-Wa|)G|!FVBK)16>8Dh^?B%YzT=Bg5NDJ; z)Eon6E70^-9XZryW}Db8VW*(k7gAVQ=w-Ub5CW5Nx+bo9=)*GK`(JoE^Uuf!!}bF0eB zelK6X&U48mmIr3MPh6W)=kwRPIP<-{0SO z;H4Od5))CxgNj9CrWiWfHnSnYFY4BKA(@@2QuqeHIZNcAMJOE z>;UDAFfjl16~vv00GSoQYk}i{fj!RQ!^-7snx6#Zm5+>{|FZRdp6aIO2Tv9KBNwEx z`}g#843G=)Vk^pkzr=$)A0@w=mNpR!5Z-4U@bi)KahtYpr(oUCwhG^wdUtxF#+O5u zd;P|ZtUCP2Z|-=R6xA*-4wjz|x^x1TdFq9KOYj?LzktL12DH|{Rs0GO_%R4q^i(;X zf|p*TpK1YgfPoNnfjs7@x+wrddU)1DE$Kug0zLM`)3dK&ek9kV(33R7Y^CBjY4|mq znmvfy^!o>d9+0MyKrE!L0WS=sgDtlm*2{)>D&FKRX=~5m=_X%elgQ>V{JIJW3WeWS z<8s3{G^~%C^F@)&$7_5=BEMI z@IgD7NslN%kW>wu5^2CH=z=T&h&jNiS&3c{Ly%JNHp5(d8t;cgJqtWdFG~H-O-+7i zA?@%j9YVwx8jvLX5t1$NOg-!$xJv=hHg1K5=ZGKbPrWR>bl^Vx`ERoj*(s*J7Op&8 z1EK8oQv*N;jf1WFieLTXAOs884YsTmXya4eofosE(Q^6&+)d$YJ)ysDbo261&K{z{(Ys7`0fpiPQ4x>!)4-iHr6wBSz)Ro9CH2{zoQ8OF0=bWYCw0M zDIF;dQMQvG2|xe9&BH^KP!FwmrO-0L_>^g!MzAe{KF;isuj~_q2Be26_BekI23sm@a0`6;!P!a?S+Izt^-Pp}wjNk=7MsH3v3g1mT zc2rxPY>Q90!*dADu-IhR$9u|XsS1JDm^4BXeF6g|K-~kxRNM|~Vrea6x)X*XBi^TM zVRqus$?L1Y@{e6MQ%@L5oOCt|v*M9}a)hKrymfe(Eb|U%E<-=fxS-7zt7er1WoWu# zktTBZOw`SNxPL$SVB#4E#`|05h#(8|$n>(H#rgZk{Oa4_p zd5%TDlR9+ff}Ik$GC)|Jg_+;9eS_FTo6w;I+f7rq6;RaGeEYm_^Ort?Tqc8n!F)?- zGo#8uMJ3_K5i)poK)GVk%v;T#T{F-B{DTe2*8TFHont1xgaSv8MA_uA+Vldy2;w{> z5*`>i)`UZ=iBYt$&?U6Ar?Gs}p@TCP$}Hl3O|1xy+*+Z((E^0Dh}?1@ z&-SxELfY8gC;-v2K6+6GI|L8wn1j<&#ycq)nUWvfq|59_ox&{kU=ZrUJ7~=$XA6=Q zE-@`_u~@J$E{Ur6b(;2`CUe9$l5k+gp$gc-|F5?8aRzVZrfG~_1j zEUSA9pz{-I6Kp@08co{CX#INGmtp=YDakxEY_p%2G@z3Q5QU~ah0^M^&F7(a))9IdI_-l}``;H*i3t1Q-f(QLg?f3Vuz{oGnp`8)e&EMqAr|>M%cP zaH$Sq{R%&=8?2YCYYVkJS-qvft7Nfvm*_DD_q@y@QuF0slXJNYFF}{}PU?~#_+w8G zgvj_Au)Zl`;NnXAmm!Gb>Wn1m+hGzxnX$#Gj+u>Z8+n?Y-hY4Ij4`tI+AnXJ|K#Og zx_kfn7nrQ^^KS0$pPHJQ&}6d<8aG~<#f35k7{$K4jJyWBl1BZ)N~d;*9@k;;HJ=Q* z4!nJAhs?7NyXKX*G~QRTPN7bjSJZdYjMf9@U5Axd9JL>!Ms6eF7Q^%b%l|btUTTju zzCM62Kqn|i&`*1nW)pXkuELZFv}f11B22`5a9=v69)9S1Fu)DSLtOv&Thz7<59IxF zYMXyqN^g(|7TtxLB)H*rs~~QxzaHE#whsBk*fxA}ygl7Z zn!F#tmAr(_jQ^SFld?N>6>iWPkXEH#KV@<{biI}D+-W^3J}gcCM)TNTi<8@e(iJc1 zqyIH$ew<_of=Uk@2mRV*ON0Mki|6_({5{^LPN{P#WE`kE9T|3iog8ig1EMR7c1KLT zIMU;rsFfFZk=E=)GuMqGPUPQOH0x}8)XL-R+t9XRlqhYJO6do8b^PSXudT0FY$*FK z^LU#&^ajL$)67cnPWjJ%V*4rKM3C;IzsA;JGx}!hTs2S0zwU_4cKjRE8R*9y9qWnT zz-ixh;taSVq!LmIvd|yR9p<4Uf9*S&f*Sm7-9h&NBNqb zwT3=cS$E5XsW@X+481(+JUOt4uDBgP5V{ghIy(i8LEt&!sK<4q5of179j6F!FGIr! zf?m-#r-|3*50rL$@k(zdL|%l3Zh@t2+aoaB0-EsJjU5FmMP_@`Qgsqj3K{q?I9Q?M zDb)M<^5?3Q)?t(Z4?vW#p34GTt&i~;d6?8qUMdbYx5fX95Mqn9bK3Jt>0(u-{pNB&nxv)y%gCLqTNf=1wi@%65-JjrdrwU z=K-Yt+J6FG`ubi%u7w8-{_<_vfR;vJEc0a=qa*8H-;@l|jT^5wMaqOomo(o%!{4M<%dT?NF)uC6Y0h2ZjFb-~>r z^urO2G${+K8Nd7b<2j z+#lcpHj#!qLhqtOItJ{5Ao?Sk8s!JDMSlRLZX|x9Rgl2dy^dx600}`MRL!srxfoje zp51sDgh&mo7Q*bd@^SJYQOO~*$w99OweLNII>NHu29O;-*P>yY`usb#iY&gRCVsZ{ zpoXYHH~j>W23_nGx?vunX}4`^v9MJy>;cSXUFgQRn4nC~V1zt%gz%D}^-djrZrYJY zmLwiN>bxIl=z5L%b`>?hwHrJV&+%#$pcH62{=O7-I{!xT2GSvq9uGc;KTKfA-axP? zuO5of??}}T$TJ3OMjix(!5VH31g=0XqX=VOg%JG5BRV=w({|q zt6#r<6ak{#-fp0QnKEC#B=bXFgcm$H^XSmy)%LiVg}K2bgtCoDpx9LyxN!(|HL+r$ zHyXeWRu#m_ukhnxvIv^tV7Gn&6CCO%3`&oNa0Rg>ywP|T42SQHn{WfTVK|3PZV09} z5G5Y>N?cOVWTzy||FN(F4YwBTlF(8|#?2$do{5D4r|kUMyNG61U@|ac$PMokH4K)- z+PjF9?=b}cou)DI&iXbfMlc!?5`q|Ny{yYO%E`$QI0XIT6l|Mhcl?mImB?xci2?b9 zDp}K^pFZl+HubrY0;uKsHK8GccGAtA*%g$ZI5Ofwl)Qc3fxUY#z>xq82bobrE1Teg z5}N!pasgxKFgtY24Pd^A839?{4+Ih1=qL|dQ!3#*`2z0@!VkQxV1bXKpp*eCPl6#d zzc=G)PEhNlFx~}_481kQhy^Oi+UxOt%lWpS)G65pYEJ@{#>k9tr!>g87u_?mm_tmu zq0sM8p3Q4JqMl>jUj*&eYIuqpe}aZL`ul`AsE)j;787*A`GFa{!nX{YS1De~uWt!Q zK(6fsn;&!W$(|K6T)>&fc%NSr)OZk<9*IVF!I$?Fj2QU<;wFbLX&il6c7lyzAl6{2 zP}-x9ww^@ldCGXAEv}VMvqFXqcL=C~B8HYd5^6i?16Il{(#KG@3m3#u78`P**g^k= z&zpMoa+!lz$k=bl_Sh`UO*gTLDCsC(J-_Y2Ccq>C=0Oe87Zg+_3P*x!!U122sJsKv z_0R!9v9y3@;kmN4fUL&bJC3762`w(l@%yo*9R2*7?1AomUdi0*&lEU$@&#CR3U%&+ zBRr@W{=Np^1(aqPs^zQ>3mTd4aYE)(6|P)Dgf|Fl3ych{c+rrqBAN&+6LLv&6shC^ z8d#DboM;i!4&^$KI~QzOWWteg`Z((Qh(=@L$R!#wlpLuLeqSu>(@t~NELRgN!JiWX z?=hY8b(mbaOuo5;ucAwBmhz+(f@R=^W!nNinrK3f5Esa3kY4${v*(RJH1*?E5DDpD zdw%Cg_yVi2?tEXJ0eh5d>|IjfElofzl0S?)r{VftT9}X}$;fDp^VXEyM!lW8cKE3# zBJYsSXv7XE?N?RPz)+gQi6EC@s)9{?4hcruudv-LMyt>nxfLQ$1$h5obx~W9R>8=W zO*=fBn;Moxzdl}%HT(zB(m~_M?)KZR&7&SWed`Bdk?W{6gc)Q*D-l(oM=>C>Be7*&**P?L6AB|4n1n$WqLIXL3uJ~hSpv0_ zATy$-1H^$CCP*7x6)_LXzAPva+*VH>jdOlJt(bE{Q z;B-cZ*@){fXn^wo+d`@rD3byX_9@Ks`t(PS-hKc6B<0aH=LhB;6fB-(&6Hf_*_j!p zCm+yH6J??dCmY^B94gESs%FY?sKX}1$C2VKg~XAW|7@zABF!vI?zj`vkS|r_n;vmU5N{hiYg{J2p|D{YxM~;p8-Gap_ zHApck@2{^0?#PeN3GA#wOy7^EeF8EB#0s*uuzR4}5t)rAJs)P&Kv)*)QVnG3Z`?-w zR`(MEZv+#If|s=4ojf3esBA$>BDnH8FYw~i*M%dmBcQ_r zUJ5Q6JYzVL54z)5bT9chpXeAK9qc^V7SPLcP{4)_MUuSgJ2ZRr+N-+_Dqttz0H1$R zKGhnh2cL7%JapFqun8Ln0nAt81r{U}7o^>BfC#i_Flpp)kva<9N5r^Yeze=qr4OoC z6uQj_aStaX>Stdwho2>6o(J)#W8+Q$sv$khU_hczz_H zl@j;Y;?p1b~XP>1|P2Oyc)fEo_; z66t=F^y8pMgT`~wVni@Wl6f%k41B11p504muob6{eIgAj#dsln&I?(V3VUg3ofX(u zKfsK;gWCX9DEhofLAxFocfB;tS#YeO^)Asp-Cj3(N6xQZZ^M-wA=LzB>=ZPpsU(6w z@Zr=7#XeBy@bfawYqaDEkRp#yLi z>y|C@)lY_c9Bp{`0SPHB{!O#vq~8se59h_ntJvej_oez4esTLWdBylfG9}JvyDf1W zKs^Q^BCZlxD_>&spspbHKS-j$_&Vs~r}$nqY>3@Wyiq{57|oQ7_Z39@OrlMtE^34& zrF0F%4~%SCL)`dnFO8@reOW$ZDKdjhNgO>l_bbweMM*+GA$L6ljE4BDo(XE1DUP$2 zAnD&!#>ypW3h8Bor#+?n<%J{70(=b*%0E0s(5m=dpgUx)f!#~k2~+M=Fym;S{|o_^ zhz4K>p?e)|nnYi6gnO+r2Ko&y<$*9=YbTgM)x^omdk{_l0DYMlZ;%mWMk7`V?D}${ zKGTpRQCrx*3uBHI_&m|G*CtPY7BUf_0-3G^t)>KwubcG3 zvuqC;(^V7N42z=7hHSN4?m(l)SC3P_`_|b>m!Qm=`mOzVWbI{i7#?_bslRuFU7K~w zunkn07fC$se&&O?LRt;yAK$!He#+M|qr(o11}1a(U5V^oUymJ#vaC0!gR(gai9~(6 zPs2G;vu4P!NMS1d8wNet>*7o2(4cS)dnfWQSP8n^V#|vhH|2F0nIqV{DN0<7wGag5`1S} zpXkMltNa4fG7M{vi_r_@Ql4j7kXBq27oRDXwdnGJZY&FA-H+r17}kUF1=z-70ruk( z&9e1a@VFdjkI`OLo{LwH@LcR66F`YofH>SSrQ$a}=r^JX>4+_9kJTzHs2gBbLmEl! zl!UR0_s$gCPT`rO#9W1!Oo0%P7bhV>yeI(JCQXTYusEX#B21HN0)wC+b(Xa-4}2bC ztj}4y_eaD&1FwDks23i`#KeTc9+~FCCw{gu4qo-3JN`*Id?Cx8DgvtTbsk@Ue~>uFEuw zFH6(v4`&Zooy`ran33k;0ObM?6qW-S7RAacX|#U}lMa7B2Kzp^b!6Kk^@;^hN7tcV zlE(p%f^95=_hB}0ZxB^-4IJU$F>eJM4%ZbGuVH7GH9m_PvbuIl22{e+^%}>Hb=KWR z=Qy$FZG+8z{ni6xz(uTQR6a%>QAc~5X3|ytTf1CWnmr+2@j^lO7_nZ3sETr10rNj8 zz2l;xCBvPCVgCO9)b(o(I*+O)C}aMNIDLN_CdH7cj$*2{?Fkq`M4XaXRw!{#!4G0$ zQ~#y1QgNve+xbIO%*vz(6f^6tqbP(f+44^g={Xcu|l0MyQY5m93y=^;X(H8v_}kb?OZ;ca<&V?|qiAg#QWRpO*u5aK}re*?lE z9?gb-`(bw7(?wu`g1aIwJ3TMPP}=zE0knBv#aE7Z%ahi-w})sQjOks8jK$l ziXofaQuhb=>mV<6kQf(Z>f66wvu&Py&+7SXXauo;>vPda!FMQn(6|@SKXgE^=6IiF zw$G{hM6CGN$ccw)0Sigh`CY2{%sykT!=7`Kl(p9kDp*$^PhA5?e85O`P~`-Q9cO4%YNg zY;zkjyA4KVO*oJjs!0v|LPY-$QNQz2Hs0yPHhN*JP>o=ZAr57gm88EMe$@wnop;<% zDIl~HogEyBeU~d$ta8F(7-42+?2r{VVyukJb}g?fK}oh1Hthb#WblraMkj{1^`Z^YJ<-3^XE^z z1(+h0ST1HBWn*&(mLuZ-1iob)4hvkcJ#LvWd;U0T64Rt&P;SGWK^0V0f~!Fe%r3RT z2m&|2L97NM&y!n2Uk_7XLc7`3Xx$CO;6>cBSn1iy9vq{}w)VFCS=J#Kb_!MLT{yA| zY_j{GeZ~#1feQu=Oybd_`;Su{e*Or%n=o<7=Tgk2+Rb=iJi(l!FW=c+U(^~Zmh{QT zDbm)z1AReR#^H&+8#9_aEi^mRQ30A|dV|A2;iaESIao-zS(J__I-#XOE}_gY12vzv zL8K*{-;^6%Gz9s7k(i1vd|1wv($sVj!k8Zrs-&5FaPV1~?oY6Vs4LpcGEdhN$F7j@ zASyPv9Vr6b2e;J)QV!@g*?XgL7tQ~}Hh{f=<5=EETEJZ*NlmPJwA9qW-N#m`2c|EA z5QiHx5ZMGHHsDrY|(Zx9bqDK-;e(&31sF81V8LJT1Cdx{y>!sfwl9o^cS! z)w}yXo0$gsN_@)2rYdhY6P8l%#WaK;GD{2A;wOJuTzfOCo-Z#YymXM+1gF01@Uid$ z2}Oc30h}hCoan3qPfA(_c5M8;DMgKMP5F#`GqJ4cap%GgO1y*Hx9h=-ATNI4CrgUCA$j>$DmG0| zh8Tx4tsKcPOg=I52Ni{9fG9l&qa>mtKc>OTCK%yRm$ULr~$g94`xssITDtm-2 z1o%J@$Kcb&QJ4C5F=3h@G=gxTgN6h(Du6Rp)fOd#FqB)7mbB+UG0^t?LF)X1!A6)u z#S3~;7XSnyme8}rb6yrLBI4`s|LaxnQYJ%gr#%>`o)3kRXqkee$9F`jX%Ora)GP=r zcd9607sshh<=XOwMn{Z0FQB&N4tkDGB+sVQnqk%^C%6v?sfEFK#CXPV>Xcku<V8QJ zE4^7J&}`m?#3?5YGe9^P$1S>w)^FQTtW?@5DJ}gO?P!uk@oP)-7 zMCPJf5xkJo(FNUu8e$${t$hv!xp^CI7j#(7mtXYF$~Uv#HG)aijw%f2&dItJCG3~Y z?|dlFf7Y*4{|&=Cz2TPILjuD^zB`3}Y*C5`FIxM@nd{Z9gv0f(*E{kU@Ys|t=ia%L zapzUcVV{3vxGfAn+EzB-EbYoEpW=V6-T!`v5BHlC%Uc2+bHxuSJ7mfg6EDf;a3x%t z{C;^zY&zMbUHbrMa|QF`XA47Zc4s;*5;waIgxF{6Y!^x5We^H682qB8BG<(%>gn*R zB~g0etHIGT9xCs(hI4oE>1OKY1jf@xCidN_340^5)giC_;pqZZPcxyo(r>BG!wyrS zDcyPZMhj}*NauY$xh|}@ur{k{lYCtj50!Gd0VX(EH7MQGbL`w1ush7Pv7+=xV?k0I z!0}LEMb9_-QGH5Z#-|V^H|Mey-TL8lVk9eHOSX!X1@oaXRe0#_@ zS1K??WG;<$ZmX@y`T@4*nKjh!xwYgw8Sq#2#(8z@n9 zVq=yQUI))EN0|+9|LjPt85=0*v`g*V#O;5RceKBlhm+@Z*3yY*p|ftCmCAWJVzpKF zS*2@zCTlh6yZ&&#HFq(%J*dycv|r?X!RD&NMgdNH#ChA(wX>7%>^!{g$e5IZFfLLz zLw4?Nht9W`_5!Kd?Hf)NSlwPzl{-rrHdlFRRMIg(JK(p2zAXkH&_6S=PW@Eji(R(c z$*hJ?+vmVsz{Z2%B%xq_?BdH;&e}W06x6fY9g0{LBXs{gxLuRF{QZNSr>`fkZ=IAC zR{oNp+gz<+&2r>{Jb4;e=7CVNTrJrg( z*h2U8YMj)Yq{B*c;#0JeCHQ_e=d!vi9QsTRST1r3eowje5KI+IA~jzgmQB+ zzqmi%^Sq%$ANs%TRW3MAzxG1h^pKdN<*#EK*e^2<+R$u#xj#-Np8Hs|)bX@H$<4JE zou^nnNeL7tEC`pnU-FXv$-zLx{jBpmEm4w~2&f!$RrpfUZvA&Q;Uya@LtDHLm&i95 z)o+_j@l6}}y+c}}|0K=Q#IHE_@Gh)p?ZTyEVfiPeFyns6xKGU{9ug zFzn+y9u|blHqL1f^nke4Std+!ePxb7si@vs<-RW8O?jgAwgWAP)y-V!} z3Ttc`hL=YtWYy;uUJ2P}iXYncs`%YG)kho}19BpMMfdi*oHQCZEb~2tr%BJe_}ZM= z!TC-)nnTHU2Y32Rp1-5NzeZT-{5)sV zoz31eH><==FU7fP31@|xIb{B<45YHNxwz$!e)wFjRM5=6^aqlATHhXO(UbgWwz1rD z*^09+TY7zwwMuaF`xjv@3qcPK9I@}MuIQ+-;_e7g>phEN&3B`o-exf(y_Pr~$)ZKk>3&$c@55>0q*c{RD5AW0Gj3aU4ZlcdwtT+c@&_Y z2@fpu-)&5ko|w`wnk!qIxBU1C^>*5bpoz<#h~?CxXB`g5o7x1In;5-&y8NVYx+Qk< zXN}{evBGhuB@^-UHxFkxzpy!%U%j={ywo&z?q+{=w}D0D^kXk+yOCM_;2Dekr8oNy zwxQ~*Xe^QL6l;}_&Y8~wE+sQN z1>~5oB&ChVx4+*XG8$AqSmP|+KA$t8TfEDW>u%zfX<<#*x@z-)g2F=@p>{6}3#Go% zhJES{)T+z3H*$|z^F3!~RZzl>G^al*4&0kA<-4p&O=k(dt$Dr8?W&<%IbFiQa#-W* z_H`Wti$ZFrE89Na=y;fArfI%!-P?c`TaJd)eYee2QZ z$hnDu9|1$Mwm<6jMQXMt#-4Sh%>HsXFe>O#-!HnC<^p5-VfK*Nt9%Eat6rK*Y+UHI zygs_LZD-@=l5iRQb~ncCq;4@bA-&=_=JSIb`L%}TmL=U^sO%>LME6bHOUti6^vz(? zxrB31T#pV++XXOMJ0zx4COd1cSFLB2xqT`yZ%Js`RORD+t%4uf-}Gu1d7|sTZKkX3 zi&qLx=CiQ!&o-axdXQ}L{-JO&t1M5!pV|4P`Gwo|js>e_IA;UzB{@~PJH>t1*zi_p zia*JRJAAt3%|rf_rN!qpopxd7%3&n~x4IQ_xP}7C`DbQ{iy2;1_f+<+{=GP8!9SGd z*l)zIv_;fh;PczC4X1uooNEp;y>NA4wL^i{5xQ>Qrq=_19P-#H3x9ZbESBYZX7tU5 zWqRqZ*Qv6xHc2!n`)>Q|j1yIWGUFT5-ucYVAE`zG{)u0hYQ7&Vy_CDwHo!QI{q?0J z=evg~PPN9S3%R#x4Te4nqF|Mln0YPnOsAL1>CjyLVGq6R&F|)UIfX|aB-PNcovgvPfJYD6nxPdN{@n&%4M(^Gm$yKkF{bUpq6X?^f zH4DbHakzb}`TprOM{UDc-1%;*#s$0ZQ;kneuEf_hrP68o18;ltAA^E*Oo6*PFGq~p`#Vpk^_BUa`)$$r+f?xV6RZ7npGPd$?x<@@&{2rq z+&5U|w*70k1l`|bVyD0m? zW`4^Rw;=ejR&0Y73e>&Zq#zOKtSV65AAPgit~VrrIS{;bsT4P_&JX03Q5fMC!UUPd2XZ?&5O{;)o5q@kt1~6?&^bH zha)mHYG2%_s7@bGIec?;xY}Dt;Ol$ST~TU<)pzf;+Wo1&*)ycRZdj74I9&Nb^x^!a zX7QxK#f9xVGr=B`I`6FMK7_{{*%c?QUD^}opEbKrFW2c)5W7U1 zJk`6sMgduEX73sIy8PG|z%i{SUdXkFBj?La!SUt6>QZlhh0j+j`lA*P@y=?!lAZOZ zDD?7b{dbJW;$yekA^p~~=a#twIQ;th7;|okX$((<@Y1+4O52Z*FD_1V1yr+6 zH`zZQKVVz;vU$4B>f)B$oy*=YJ(d4CnQ?|QbJ6wuRyN&cg)dyW7P-eJhMYY#(-wJ( zdt5xiin{Vn4_*tg?Hnv#d+KABCueVvRYzr z+*RHkW!RL;qvPjNW6`%fZ?@hq_TSCEfjU0tlk#)JK9w%!hs>0|-MjbOLxvlx#D_G6 zcJ{eYYU6I46*+6H7&q_FotJ)|n7`=D(rH&RPwwnWvxUjLTOsZACSoxnjLT3XnAEiBhrzzt8`3Hek z)SXay9xsgtX8mJu*`o2ougHJwDOVgkJ&Sd%o z(-d#j)f>{v;tx%Y%MMk&|G_rDsV1Gjc=Sxhjf^iZx+1RskBd0#Y@c&?-)9Y@@r$<| z2Zf4)Zn{KtYHw8Y2>zM<{pqjx$?Mp{>)@)5U$@lxlw$1M|!Q70d!|9IZ4VE;+!zHx1rpMHtF$&+<^65Dz^H}WklE&6ahtkwRo zr&o}2G{yQvNrl+=P`)SGJGpyYbB=L2F4}$*wDft_Qs?9G=YqNMqk>t<1uwnq@)sh5 z0m-^l)340`HS}eiy`exQrOap+!nwMDUy=7oC%2!Ij6h-!8=Kj!z0!VR-)?m-^`|*5 z+4kC)ec3uGVmQ+>`p&tBiRx?=GmX(je6H*>zkqN@<^4I!*$+#^svchQJANgpB;fs% zY!ezD!70C{$~j;AkR#(F`Bx6-O&y6(na<4Gou9qFfL%*p?Tm8fkxUmSO^sx)$5(g8w3f8BoL(wa3z3N_ zdet6My?$(?!t%)FVyYv%6B#D6Ef z4y@S9UF)a93{p=u->cBLpON1g{YtsCWYqb_NSC6i8T;xe)?HBoA72SkS+qAgJ^$&S zwxq?jeETC`t_d@%`i`imui11ZV~bC_{euq|MVRbXlc~D*z<|S7ylgy8F3!sSxM1ig z?8?_}7;$3)Br=|0D3gKaGNzCUib3UjFq%x`9f&1I`ol-v1X< zZy6O=(=`nzAwWoQcS3M?4-nklT?Th|3GVJNFu1$BI}8MOcXuD)quJ!%x?wNID zpO#&_s;n9`|3v#StO*y4C+_%<8lT=g4U^rBg4^$eTrCNkeW@Ec0h%Q|H}LQOzk2M-f9uh_rs zNrq+!%%G$c$XdU?rPD^!6OM(BkW0r~p+YB8GcipO5E7oBasF?bJa7L0B>~Al%4wgt z+cd4LXfdU>I`Trx=L1KHar$kSCvjqGyu+&<8q9T7`_i1-IFI8vTaP7YZFFJ4OLwP| zcJTo9G@9fasW8)-Hx^D?@*XUWsKKO&90|Vo1NL8KY=sU?Jf`y(zk2MS_!9G6>X;u3 zbg!OlWsP5c++8jZ8A*0fJT*AcjpCp!RiBvA19s8LQ|=d1aVL|ZOn*FehHk6iDYm3W z)Zw|>&z?7nqSkY&;Hi?&Iwe-XPUuhWv@zg$A{oC%Kty7 z&HqP!U^Z+19~t9AMf$~()b<=+uxw&Galix2RPZqb-R48zsP}DBYuermT6FNemxQ{S zV%%sq@~4NF9lvpC*B`Y)0&-6R9THxd_vOle70@kHI6nXhM2T^ zBYdvTnhoB6?-8V8SAZylmYFM~+t6bgEt4{Lhy?ES#xnlb>LXb`iPmDQtd&Z^G@dGp zGY>WUF;BFM=^|-}XYR#1N~5@SSlm6qICxxJ6M*)hX_>?IdaG#b@Kv%i;AE{3F{@8G zmGKQVt0%X;<=f^3w0(0~!@(BCvckYcI59z$Db6@>wHe6uk`Kwtt?VHxYHhNs@D%&APN$uIU$I#oqW7ViX2L%*okMmars04# zbywOLXKh3)%pdGYpM8neFcu6r*c~xHC=embp;H^7RrE0@B`ZOFCOw8q2A2;4MT(#4 z%gNE$O$nW88VK{Iy6*_1MY3kXN34X5FNefZ7M>0V4VV(rlEQ?n`zJR-<)SZ6J(52+ z_4#wLL`a9==7IP{i3rqQB(PGNGuf>W={#QNPS|8K;Ob<8dKg~3S`eu`)oZVRcGhH? zX-~jr?vFy#=98Zq%tVt))>{ch*$MGY%BJ`3UnOuUSMkKmhxb2JW0V$dz?wE93-Z5kU~Ae$Q%7Q zaFH|0w(BwM$rG)n0ly~>u7G#&N<96B`Rt{l8(lTYBgvre?~&Jy0ZTAL#-D|6Y}EHR z2gPyaE-h8q%5$Qk4&<%T{MYIU%un3x!#XraYPz0=lv>u{r&>$EsF`C+T=m7y7%Y)7 z#GZH0LZzqWktFYW}e z0MPD<=bTvqOB>}eElz%6TJ;fB=p_l!2{U4Cd_oPj;J3S(wQd(Wqx=Gc+G9A%M3qC6 zh=_H9bfSpd;9D3e^RV60{hIpBgd_1>WM6?`1rzH2L4YbWHNN_^l6aGjMw7#iL#y3M z0M${M7>RS)$tLN}Hdp*OJ-mSr`a-FLgbC+Jt6gveIR$klVH8VP3v^j6$ObxJqwT$f`)sk0b=|zQMuC((Fc$Q>P?X zY-?Yie$vt)F@$X(ZcH@or83^r&pPAJXpT^I@|$z3)K;Pn!Tq-{F1Z$7 zG5&1)Ypf|{&kqG@8_%x71gcl-1M){mTF;oQOsW40v;UARS8aTR+B^yVh}})YC3D&p zP3~{-i9b$dc=)1;XXuf6%y?O{!v0kF=Z~zJb31+8$MR1&WC_2^H8(4%3}E4t4+-We8WKADWhP4|S| zY1mrWZ0WnypQyb)5Rv|Q*_2t;uo>Cw6`|h6g1+5G;mbH|mjS>L*3YkXrN)`~J5krS z(MNyo4c{^CMQL*BuP7F~FtTR+za(G$H8p>(|Hh*-orpJ4%VEL&w?wzV<(QoLQ5qv& zbM%c8QzDefbPPkyVwO6exm-NI=H!A5JX9E93Wco%9P<(A|^L{(o*!^VnIz=B<34m3K)$5bce};?eZx9AVT^}MB z!yi8|%Z>EzLv$v&TlyW?&53>#GCaP!gvYY{NOo9UOBDmbt7Q!b@;*^UsfejJZo;=b z81eL!2y>7fv;LIUDEaK1s|@6Jh&@d@?%+ttw6`w@-ZI}gNPW*@LUEYv_&w9XHFzpj zUx_1Q*n{7c(C03iIaseSBVq9`p=eBUtnnm{^+bnnjd1KQA3~Lbl4D>J9_p3#T7+2N zd)w5%fy?^{8N2tkP-xTD<{QnzhrO}xWwsT2<+H&$w!3!~pI)~|uozWZL->>B+Uj`Z ztZPHDjaQo$5-j^WctK}0yTu}EYN9jFmU(I}0f{i;k7I*gB}v($TTh58FAnHq}M9?!l;rMmYM)>&;{s#MIT(6k>9@op|OU0x8}s`hK7{?nh+96#Sg}w zB*u^SefQ0A>0e8^6qLgU&qSZ{$9MYTw5Z;%9a(Q#wx1}^katAse&!AGNIx}V-G%Fj z)gIkq?oa5-c^!X#douZH+kXGu)|)#)M?w{N|Kl|MnalRhJt+k?8XzOc?Xt(Q#mbYy zU@&(sY=y{SmrmzsTF1v1Ftg&jbJO&Ed8`Fuc2q;1|K%=2&zT=I!v=hNk%Eb=>K;CS zkQaqE_w5qQXu?XsK^59*HQDc&Pi9iZ!0F*FULp?8K24x@#Ia=sg_j9BmLd<0Pqwny zj&YQ~m7rnWR~pV$vgCilS05;?Usm>l5RO{C1o=Tl%iBOim?6hX9J^o2eBC3uQV0ys zYai&NVDC1V8D>xaNXH&eZ$RB>9_UyUE63h*R5f~sh~V_}vsOm9uUppR0Ys8eq+`+v zicys7QL!aP%;j6**ql;$A+vMYn&uc0(~5qIE#)e~iYU~LK)#?%Tx@y2SM;LzRbf^d zAs4y=>7>*$UpHkjn3$8pr*v80|6b*8m%@wihB1z^+71H9r@ zSGqNS*h_`1N!L?IOh;RE^tJ^QCnO`uN}CuCPBsGC)ioS#H4`&rJIJs=K3lv7Q!`(F zwFBMN%-%JGG_bkyV&XvzRb2239Q3l!>H#o|ko{p|bxkA9{h;H5Eu&|RM=N~tZ>zXK zt&|#^*Ds-ar{za7TPX56=%twq^onrwZ6%4Jy1!ADO21|90x2`LCHuZ zJg}cWc~?jX3n(W?h{@mfPc3vse((c-Uyh1iBXRmbJO0!`*e;Q-_3)a4@Y&!N@!|Bx z&D%u2z|dALIsLm}!D*vL=ZOdg7D#OqI*0YW+}GV0DM{)q_r1!MDGu}f1#}^p9{8JcBV3RgMDfR zvm~CU72D|TMdg^L;m$ZEL8fr9aIpH6sBo@Q+xAW|=`ZG2QVI>2#L+~obawXQ&k+OP zDVb-mZp z&p!O_p_+vW+KLzh3?V_VV*CF92j-n?OfktVz$x(*b(TG?%2_?}C%Lt$V$^qQevo`7 znn*-5s`u+pfG>YHcuxyQcLD(x<>^r@G=*1bg2nP$*HF?6e8mRZseqAIOVjLvjsTaW z5B|vnW8vt~8t@%GU16Bjxlj!4k+UA%uoy22#(4GRtO)wp#6fnjL@M9JG?st;Kt=9o zg)$yGGk93rdfbz{7-9_9evPiJH?HN|L3>7@K*I8QjP95k!11cN&#mo?Zh=?tnd9HJ)DY88qzQGyS05ukp>>iOYn%}$G1Jr7aBp6JW zCP-OV=0bZhq2IFU%%UweB~c97piI}%i#n*D{PT904!sw>(4kM?%qh@4P$xX*(@gP- zKG=^chu$;!M|PfE^1hD*%)Ys86ytRIyJ!znS+evu3V{?YGx4PVCs*PABkR@Oke&#* z5}p)$47T5VcftFW6P%KtwYl-4v`K$3AgSRC!`ws$MzX?%zIn5vAsXksS5_LSoR2O} zF?o$<=bap$R|xurCFGS&d+*R2iV+>iyV@#$XFPk4cX)lMJ{e>TM^}n>8S$rdxaaDQ zSJf>>N1d|dQL=}8erx85{1ud#yPtBYFeN6E-C$vNbJr_U@kekF8K=MQa2D&2sxtl3 zm4h;|gnkslzC?4fTOOVb;<+0)b{PD(UlZ#DavU<@8zyC=L8rRW#i-n<5K?^Wb`%ag znN~e;)d?Uf;^!!-!9`L)`90iHRr2G(*IT=G`0?dEms|X1%M1(90aYyNZ8i{qqNw8XhkzPXzaqbb)U*974!!T5rmF5!{k(7 z0s*NW7gGvm#-6{mt@vs;H&<#b`rsC*?e1L!3;q$u`E*O<#Qw4U6NBbL>hoc!V)S3# zWA>2G8b+ z+x49O#v1I`bTWSm??L)Zl6z{Ft10_MbnI@1vKRUJIr_|(YWTTh^*=LL)Zx&JX%=wBvo}VBI72{UhHvbu;BN^rX9azwRg(pVU|3{&K}T*$OZpWvl3Ce$iFJ4TCPXr+a6+50`iwc zf{K#<{m^=RI*0=Lm=H^?|7zFqo(k;UtTUlEK3H?BO)M3%KX76O=W<-H@bMb+xXSn@ zv28)m6;^Q)X*SVD9dG{*uJD#_hi~QSnzr5)?WLu6Glrtl$k?(Z zLNiU8F@sWBVJEk|H5rt9il7XLqs0R&lMI@Ckc_H@{Hr?zVOwl5i^ajE;RC}Lk zw+b(5&g3ShOWlZ7-z(6nY7AxSwXv81C{3^!YRRCXWI~=?dJVMOe8oV&)4p9kCgw+q zO`EF7>|GDYQ&LaY>R`sN+Z>;Ex{3R(w$XKwH|vB4gKOKx4P2=d(1}F83CQU!h#|ib z;Y^*->Yx0rqdux5%dxK(6FIh(ML>)m#X_QH#Z^{kEb2T;fr$zr8BHOWD(=1xj>d%z zae6s)hLPvqw0gL1(mL8qJ?|wE)Re{sGuM_6f;lKYNY$ijmH2LpBW8Vh|0uk1B&k^& zEzoP>bhOcHGy#I2&F{uGz9vjX(vB6cd5xs8eEeeXS4ap!Ejy4>NAu47wtp~9o|DtC z>am&mbtAl-Oy>5|$P7b0*nIXz0J~d2@VTe&>Grs6}w* zOH8$TZ*|2hONead{_4}1w&ol(sJmcXmq@yuuLBMEC6UrPg8S-uxxz}Kw#@Ngwz~$t zH5$?Qc|eomNWzP&?#r4^xedv0w)DWCA&v^)O_fldej}INIdh%%zL0*AP0QiX{|Oru z-foQnQP0g=F9anorY_9yOSQ?d{v{0f?0`<~PpOvi=aqfRdWKlE9)I8nqQvF%M0+ry zv-zwyL*;dD$u0YcCRA93e8O@wt}&6-*4v<@METLKiLfYOgd5~T4`Bq&xZE90EsL!B zd7^#%4dO1-gw@a+_>fWs2-Yy604prB%QXeLj}qU-KN=7ysDn+UZRtA zrFcwEf2ta_zw%uANbreL-kDPExG}dqe{gM0q(5VQFad||%X@*d*Z<#)*&1Ax+5gnp}iq!o7*ttmlZ-q}`2zrs-aWeRoPjOgMUV;gdAzlp6}IP4x8SS1-Dc6rtUACX`8( zA2T|+x!|C$A8G;Wwjto{$PZ$<3*HHE;!9pckAXqMe@n-}jlS5FmuA=WT>{cFga>mR zjH%X{cREuXbd`yV9G$IWSX|6|Em`n^7YKIoWr`CZi7WLn##%xn3xID8EOuf?09V^} zR+jRM zBXBVq&QqyTmWQhLG%c-_Mu0B=DwxRM$IP#n#)EuFWU4NtWor0c>{VgjTc}>Sv9agi zgBP7Cb>)*I1$%%Y1xXG059w0zpo_9gFo?yxc)~1J3OTSTMLpnv-?(F8$iASiX!@x1*m81+5I$ z6A5T7>@tIglNT1*q@M(UgA*@ZBmp#xBBT8IlTUZPRJR@bOG&zSa3BMJ0&^S_N1{Zq z1ZhX29G607+FeO>`LEHAzbm3Ct;vtteBpvww0lxa=W9{H;Iv$lFb=IafRLLzXN7t> zG7k2i#0tDl4Sc06nvC4ANDe>}07t$F@q+})h=PR?^E;28?LgJ!_N^q*(Lp$#HdSVX zKf-+^QzuAKJKU#UxSNW_%W*p|au;$$Ay2#z5ki;j>lmKJz~z-FoXOt@ppNG`GKkn~ z7_ff_d29Y;7#;or^r2NEmfMcm*|A)mr>2t6w_BzlL_{i?u5}GLS~xcLXU7np%eqz@()FT&@=)C6A^m?>*BI_ z|4DtS5jf%~R5jpp9W9xxk}SMrfo>!3To+p-yZ9Bi6FJUG{;wNVDABseL<^6O@~B^f zlv|q?pmQ$2*vL^tH6*fRblC1cR&MyU9W=6y7OY2WPDzQy%73jsyblP=qG)7Gl?6H@ z>jfl%CKQ0^FyeKKfU*QhAYK=AK?e;XwbW=84S*JGW^AkO$d zr{26g61ya4rN->4xfmDYYo%%30NavWBNf}tDl~3Q4(Gvnn9biGtN?l6<3LixMKOu5 zK<2&z1P}*s0*n8C^F&YzTV*cba%Tn`qJ^BVFi1sG`R@M3I)GW+#bvQ$Ws2M(M`kj3 z{P2#1cL?3+N^Nc!)gE{e0ujvyo^_mFZOE=QlOS}y`tRv66x>hhP*JWK3>@mU?8t)* z?ofT1`W~2^PvO2trilRm-i-!rBD?d9(Iib-+e(y` z_{LhZqIL%5y^lDL_heTm3cvkcuFn>{K1-FYb34hAuq2`0FWZfdp4ItWbhrt;FkcGh&GOo4VW7dRbAqLiiW( zbtC5Bk)ki{-|vTrot%hozqUtY8q2f5;RLRB=dD>x#>#LF>DrLu4e~p@fn*^xb^!fC zFgx8`aX&HITa5JStp9vomq@20)s`pN>Wf6LsQ)jKh*}F%*Tz@Iu$r-X0JG_?i?Wgr zCuh(zlO^Sv4Gebxth1Px2;h|{Qr!(15ZpKwEh6SyINC8Xd(;s-Tr?FJXnyI`BxXE^ zf`b{Qq1vgkI%4=QuHbwj(%jJKVmPfFhFN}}O35k{SQ8dT3UOyGEuSzWMXSj{`Plt5 zhFLqLScw#@x-M1UP@6cCGdD{%nCP#`{nG^wga?R9C*wL1D)`~EAARi&m}CYF2t*d{ zR(aY;)|U%sdEza(&k38jfacj}gcFD4O{7UE_wi)7GuV3 z|IlvkR5xGZL<#ZZzh8}d^9emEJ3;Yu&k7t}eq)Ymj^afWYh;O~&idj5+KWZp#Gzen zA$pB^qNa?dEk^|XOxedWFHLbFi zSE`VSR50eE?5`!`o-uIu%`ozpolodyAH+3W8F2#g+T4O*A?xQ$f?y7-5g1r8HYSC* z-BRrcZfDlvx&7+%^)@8V9G9N4nf|(elv%t{UrCh_?-w6^B-AjK|9;E?R|C zQYcfhPH(gvZza}Avu8$pr6Kyp04bt3&`$WhD9pbnYgWrA)W*CcEJ(!gengA65epo( zBjk;vQfq;$&{Ww8;RC4STJ#vk=}mlP6uE(Duon6w94?pripGS3k?L(p=^f0P03iV& z4M|>8=HGa-aIyuwTW>7{_(JR)_|`v8-lFy7pU|czK`7Q=jCU@u?^R&G^9(wiF4>-T zCXf788mYqODkT}vA*?-5LBRKv=`OJwAa^=d$ET1PU9LhABdoJ}P;r8xRIY&|;Km%{ zieeu_TiH~J9kk{z%JY>8f8<2=^kVizT=hcwT(mt~sglup&yKI{p6Hzh42+A%vDU`7 z$Haj(MUbS*Ii<@KN3-4~_kp=|^kABGZKNSsK zfdJufwS=e7ngY2XRIxDb}lb3-HK*l!?75mb*m<|LE4?=%TKVS)t@3^ISf`i(P^u6 zcNbaDgJsqlDINBUB78|E*4;0IkDN%N`~N>5K_XINW7 z+j>$HVRZ#c70P9pdHO3;MT$N=j<1$q=CdPlvd^7fy|j2u7azc0npnb49pQvHsY}looXE+u=FLNy6v4dv1Dp^;a>oOj~_f>`!Wj2HKV0q8e9gAmSc|+2? z_2K6lR&bULa4EK(+#cmjmmWP9@n6%^*Rarx><1S5==T!K1!y>5-3jTvH$rpp0OE@c zSQx?MMyd~}he*){PpNDDv%qnp1iZK*U76DooLQXe_`}CUfqm7w2bRW0B=&T(Xy28; zIe2*zugcSE@x(N~xk4ynz*EB7LPH7@;lX%a?ZT;twlw7}IDS@w!QjJ3RtrCYJoAs1 zD>(i~BvT<}ZCqex^p;#d*_;6EWZFI>k7vE29=58RPJJ9VuJ&merV%h(Z zkmk53J%0>Rc%XYLe5*~%M(w~>*F^0h7eIciLO5H*wbARo-h>$>=qkjq z`oegw8NK($HohRsQZFQ>oTdw&i!niXHJ`6$m)slcbjL+Bxxa2G*XjzSTC_&8kJU7Mu3sd7HX&8Om-9W7YtH( zZ;OP30<90%P4KKel)AFxq+Wc;hk>+@e#0b@qr?4zL?WnFvO~cNIed5Jv?qHI-hY`Ai!hv`ot4yVOq;p;5JZy9G&fW}Smw z+#e<+m--C+2=_%sL9I@|@k8&E4s=t)p}iD%fA?)C_+g_}9Wp@NGfXjIS8F?3|A=o@ zCQ+KT#x=)iqJt0q+$Oz3Ts9)7zFz_x)Gp9sSTwH5=PD$3_a=wm8ftj!0%jCM$8mn6 z2hr>gsoll0M987&i+d62KoV9a$2psgh2qiw7V|C$DlOk&DA49ja`!5cjId6pD{3h7KSu@XfI zQ1bN@=}qVR*-I=VJ1+lv{|K4RbrRt_v5TdsDTI$~!zBog%olqVd_+p?iJ!C7=yLd+ zr}boDyCd{84@m2BKscFaFwmUj-ffT>^`aHt@sKfJ&qictXNP+~X9N0GjLBCYrAX`5 zrZ=1@KDhhN%4QkeIG}t4zjVnMJ{zA3RiNVW|(kp46Q1#`!0e<;jneIVSwOR;Uo6I3A)*tvmkq0h)LDb=B-fIB?>9X?ggU zT>H8|-7VV7_eP3@b+&0?7$00X{7F?U8LssnTdMJzl}IIBJDq$sJvPJ;V^onfgXOzK zZ;?34=_z5Jy!dX}dVI9XEbH;{m&uDu=?T#%lRV0s5-gA_#P0_TntNG8HXuM9{o|O^ zZ$xg}faz3fD6~`m?DXvS!vjsV%a2wM!nLO({dL_lROg6qyrfFNplWZoZ-B~@n98bJ zNfepzGL<;m5<<{^MsOunjW0I;Us|5Wc zjr-gJJU!-1UGXzeziEk+@`xDVX!lU4(pK%2n>6WIN~C_thdC?5{CPaiIj@EK?y~cV z$EMvTfFpIe5M4=$9|*<}qe_OI+*NkoLAyusHu-V_7L6=7)(1CKSv6FLpf%2$^%4$~ z>Xp}AZR*b2wGZdwa8TYV)xq0C87VI(dsaS7y z*SnbCjs%zG-PXoPhQ+((f+&a%9M8W789b5V6*#&(l8L?yhE*{H0T@Syzmttzbnx*y zeU~XJO}y9}%YAYc09DuZ>R2d@a|J4yV$6NEvUs2hT+^Zca75DK#=O@RyxgSV*A5uQW2E(Hm?- zxbL}KvouDYB~V%u>BX;9slV^B-8@QgN>%{H7yl|O?BRk4EkSwN3k0vjVax@ez}_xUhNnj z(UPyWLUHiqxv7#>yeb~;*IB^EuuqAA3n&z~iPIqQ%AdU``_uLFTW?o{^}QFRvSj%~ zW&E}zJ3_LjTI}ySfaS$4Vlu9+)W&>){rAO*?#>;~>;AyGGHCvC;2-aoM7p!)LFIRb zrOU<8k=1bbp{^X&v-@3L_7ti$pAcfpmfM!#J{|5NprR@4LgCZR-t*80NM?ONWNa11 zJgZ9C3|y~1-CC|Uf!BTiCKS#mNe{?%2B{{xd^Z}(_KD=V#?Bjv*soJ%%n`svr;scGWOs2TXm@XI~ht(Za=by!&`S{Ib%KZB$fq&oreWu$odwc zI#kB=m%PBbxzmsjRwmfUU~JH9ql;d{wnD#qwCGVrl5PJR9rscuP$c4~H5Y7H%89HC zySS{4n@Mx9bZrj2m83b?S@q2gIe}>|$T?~`arg=-#ZTC4^~{5{R)D?M{_8wv=+U49 z2XI}i`d>F;ZvhC=K*E+t_w49JLtqzg?CGrZ$e1^GQ0eF;4Xa5F!DKGMU-~g)CQmZj zLR92IK6+9kX2k3(J9GZTAp9-d&Li+HHA*xzQjDQ3t}c{ll+M)6KGKR!T(;6EiCL{~ zy3ED56r*;)>j1Z{e82y1HjT#Y3D;9L>f0QT((S!2f^auPCUu~g1G*4D~3=6^*fOT5S@oLl@fbNM653?yxy;ExZVijcUYB;6!u|Jw$s;AY0 zO?pN>0wdJ#z&I|$8=tLl!p|pg@6=H=zZpK}eL2%MB(O-#SrWYcKx2 z|3-ZEXM%gzyC)djP%}jEyqANgK|8g3uSb?mx0jb5(H9dnUTlmNz~x#)MqkbNT`&}PBW_)+P$mcE*P+< zYdC{1yOf_rG3z5PN~N$f|&CtXZ%J zNX8NKYBsf%k=@RMr|1HASjCrBZ;t-2Xi^Hy&OA$XDHeMxrQmdCZQ{AOnJyJ!%`O&Uq zlmppmP$gRq0V-&iSHh;^UOE<+-~;C?uj(}Sh^$-=NXOsYly|32OP$z5p_|ldz4u(g zxr`Td$wS9KAo=!}Swz{5sf4<9O3``xz)Wt2yP-dD3}`8Nzs@X5ri%TJ5h#=JY=|^^FaXEB zr?zhNyf9Op`DtCra=U+ZGjh1?E@RPN5Nmwz7V0UrlCQ}urtS`Agj-dL-VYHMmn{Fn z0v9tx@fMD6iPrdPz$-smo3czXy|-@%rGC5Fke^Ny&%YF#o#OJIWw(iH(fZ);%gPzH=Y!hGVS5Xbcsw0D^*5w;Fbmh~rhnPSv;mKZqrD*Ss4#f*?Xg#c*HqV$ z*ojg@tW;)alopfQoI4~fa|%B$2!2RakkSlC^f!%(+0l?7vBvybG0Gv6sy-`80>ejO zt1+;pV~?XUJ<1lR`X+a&UJplY(yy#5()d)DpcOr&lEQh6+tG9Edu96cvKSQ%U4|4= z7p02~?pUu_E0#bF?go5e&4my<5^#vD!i(#UqxrK+T3x-jBI^<4H}K# z_0~nF>Kd-Uz>2f<(9;}WmCn%6N93Q3JK0j6oMmE0bwA;xs05lE+*fG3TlB}5?%WWV zHZK+8??G0w`Ksep!PN0mzG*~Q(E7*mVXppaNaQXwA$L_8k?VzH!hTnHr$829CWBc2 z5tq`Eh7TFRO>%CQ1A$iDPvNpCFD3f#X7!gVTE{WJhhqUC(FaEVi^3P+z?HJ-WuN&E5c zY3}xi8cw^!_m2|#zXb;QjfB>O@puNF{qrh!YmDL^jiFKzLhFgCE}mU&$qMF2 zjYU(X-q3UKmmOOt{zS?;sFqiY9QjkSlKBY9v?uQ0L=#YgdD5Kw)Y5c8{=B5ongw26 z03VNdf-(|&6|ySdgcRH}n%wr<2ohJ2!+@20+@7MNzLMk;ugOq)Kq^c1^gQg&S_Zwd zLrxoK1}I>hl!3CQ3O&hWvVP}et>GYI8pYKAYgBiitqG4b5> zVGE}zTymuJ4{xT#iP?#QgKePRH&z@)<2Cd`%d>5D2D;spKQ~%)i&S&Qd`fK?9kusrDsFw%y^Mh{s2@r5f z5w0~&X-n-3wkt?`P+xXOqcLO5Pocn#vyDk;UIfkb6J=nBXb?Kuh|p4a!+Xej&d#gXdt_)Wnu00*RA(C*dLr}azr1AlV`ku|NGjJDgk`npE5)M=Y!3)~|V|^AP6W$Qgyaj5d1d$oY~3h4vhYPvP0aX&|COuZSfG2Ncdi zO3lNE{cSl%l_#_to_;2J z^TxBiLzMfF38Bel_(Qz}t7EtG_PSTM8Lv~LbneXUp(~meH*cOFm9}zoCX7K^$wTM& z@FwG-VkzjBSLc7qsJW&xK|{4PCT}H6PIp~vy^GUGX=@(vD`)|z1s=0G&1+~Mk{^h5 zOvdt?hox!wG7V9KEM-bVETR0S(xyu7Zz3kr=cp>fXG0k4Xe&kJ{iWngRVGdZqA8+!ujjZNZG_-BgX3zJvqRHeBDrSGqmC*BNW z1J`)K;B1If26H@h;j463ZCuC%LhAjRzD3KLj|>qDCe1FrZV>S4)hyO6QrcJ6bXoIt z?BS@*!HIyvW}(ROs@={Tj;7KmyHkUj!{x0vDzk(TKNUo7oY~CdiOKQ{F!fniMvQ4{ zBgwH#+Hmj54gTcbzEeu&_yH()ad`xB8j)N(N4tEOL9TX$RRfETe{@$$so>rMA97_f zZ9aPHyqETMIxs&H4EI~iK*2111Rr1vUr$U;@NaiH64d{N*&|Y&I_e31!P)g1tW_E( zFc~TYpj1XS`r8Nlp=!Y4hZW_^@3PPxXiN7RSJQnQXM zxT5yUw;!7Pl(DNehT8bYcK+<BG+vL4)0ZX^^h~#jvL_C3b$UZ1SriA>`tk{hBpXX)PiHQWqllL8tzI= zz3WPH{$_4duRWNi>(0efLLb0!GWN9d4A=6}T}~RxT-T3_rN3dkdw1VU&TVJm1zq5OdoPRGKY!L! zB~SL(VCy%5C;PUyW^Z9T^mkAdoA42HuXX{HcsF@ZMVU$}hduTFIiDMJs{m7a-4dFc@OeJ=8@gY7<>1y%w zx{NPe#B5!%l^2REA+Z{CjMU=+OH&BqPa%B9()#t{vjz|wldez=_lfV7GjzSqON<56 z5M*_2%Zx$?HagL+P(E*NUw3_!ft>)5w>jty3H=e@%=j8;G7e^^{QLbD@J3#MoM2vU zte<0BfR#YG2DZ1`Pe#n{y5fGP4ib+b6S%ouq8!MoPz>u|4eEo}4xVk`EHAK|C0Vo8 zZ-0~HXYM~Y$7?=)o4J4e`CifeMiy~1k1WcM;L97@`_L+%QV05YhD+5S--LWb+}`H?zD11G+&j5b%pU#?msz^3?SliR=TOy&d=`?b+*xR`*~$ zt!IBxjJNxuV8Xwy=9^m>VnZG^=oea9K&k3(K#6@7;rq0VQW;aJMuHq@_{2|)`me{T zTFba8`Z`uZ~`)UW_&R%aMBnSV}w4uYn^sw|!T@L^~Qa>W0*%8*NZ-1_}0kr=X~ zmEGjF|2^FQeL*=i4uKhi3Tv2$ByYk+ftsyI3|jP!#En~R)r$RR?nlUj@0rM{Q~*VO zmHfx^@xy!!neHSR;QRZJYlJsbQQlijR4e}XG?DPr4=(8HDPQfc=uCsk$hp6K^FI_& z>OE%2Q_E2Q)wu59>u7-6dnXOg%1!EF{uv4_5y^&vsUJV_~>4I3(C-T>+5gyOiPlhLRcstY?(K&O&@8V;cC`*Jc z^9P>@PI$eszKNlSg~nOjHq4B!XMwx4RzoVv-Hwr<3>Aj-3e0oWQkGB0-oNj*sFiV; z^?|e5*CDVw)pA4ybG-NOb{U*rQy_ojctml`Cv2nPq)Qj)THRx#h*V7IW{N2wvfG)J z;erxMA#=+k%c-ORADd!Woe26w#_HEa=f04kxMw??BgbPY%ze?57dLE`2-3LM6Rlzj z>806Rdu=pfcnn%i~p*O@w5gXfjGe4YfZ=Nen=WOzC-B>%rdSZ)vFgVLc% zXhRM@Os$oLcs|ypPL>k>lxP?A+*?b{FXcK8fTF`lCNq4#%I&7L5KtL`bb3RMa)BVb z(8|tmC_MF&6ZB7YI(lCfs}^V4@R4}0XW%I#R36ExbIDHSyk8-Fsr=eNcl_}_sL`8c zAPdWFRfE7~4uRQBH4*k5%>wXs1)XX%UCoA&xMi)pQ|~QMyAmxZKF5?~tklK}xL#@J zh#yLw3{Cq0izRW=ft(1!BpMDxSD&G1C$z>0Gw>8zXxhow3t4U{6G^jS?pSmdFTN?*#FsUl-_&IQQgS+ z>fmWF_D{cL2$Bj494klm*rlggOC^*T4j{+P8}q@9pDRXprxzV3YlXn3at{4oLkgk#erXEhDjB7(KO-~Zap%J?N7cN(=MDYVA@SjK zzUYz302;eF>rn$LMG}-`bhEG(iq8upbXvn!Ynj zV^k)Bg7&DL>wRHlQr&d`cl&ZWLsJrt-Hk21>k0;bh&J%@CiORR`|;#tCnb}I*VtpQ zyHu5z#$LA6D(UQ)p5$;2`Q##3wp&NKL%0hOW?Xb{xVF~g%$hgKaw4+wpE6yArQ0K4 zpGcy6_Fmha^o*0{OzMrDxd^k)s60Ds7CG6oZ-u4O`%ABN7QiGzEn{1vwHQjVE#a#1 z7t{UfKwJqJixL6q^kx4mYvasSpAcRh2TE3R$+DK20}e|bM8|%`E;zcSXxyGjb>7}A z*KVMe3ez{+t8Sko`F@t6NaH?e!MY8Fx6nRWBM5s$X6|EQ*{9iw)pp&xN)p+mPWkn6pRk8>C)|lbE*& zNRFl&48%{ypy1FkF@G&C&OBV}MRaz=PU%Q*Z+ybt<5n`A&p~oK>ACc73Qlj0G*z&v za_8bbUj;lP)ZlouMXXmDGj1qeQHewyiPVO8kmM>Pc+<^=C~=|@FTg0@SE@U z)Lzbtl1^r90)HIuKhpV3Ch7f(qTr!Kl3WX0%j)Zo=(7zjdN(+jSo(EWs2h z;#x~Os$QR2!c(?|cOUYtj3q;b+p8cb`VU{#N69dzn?O&wcI4fWqyo#UH)9;)QPe+O zbK&&qD71XnEv?xh?A-xzeV0_fd_RW1DIL`!oR-N2F=uv5DVJw;@6Q3h-LQWyyS<&z zow!F=EK#N1Z5nnq>Y?Ge#BizINP+oz0rJSfi;b7m3V#RK1u2QD03FGPCRT*mzvLB5CdaFOer6#oE2$cRS@0; zV?e(3mvHZAAI6viUY&?9fxKCuZt}hF_D}Gp0i4oHljYWKEtO4CEIueL4YdLW@umJM z8l~v|Wuo4MlRLxvFH3UjUWt&WWmDO17FbGU6E06bZ7Zgutoy90z*5m(7?e$z&KDBS9myK|G$F!s-Vp}(-?3mM$pO-o z!jBSBq9hNu>JRQl+#ItSbPeAosvDKaERLo7jIDkxn`(n3%8b#53gitR9TKyD{0w@$ z4#h&XS>i7nXQtSdxES}FHJ|?Y*0Oup@(hb=Ccx$c-w;ZdqY@oBiu((#(f8DIrUkBG z#Jhh3eySIH-Q2gTCYs?B*JSC4?OlI=xNiJK77-)hb*5pGY5mKYSBF1w-pjO1E&8+` z(fx_GGd}kyOEW!-Ia|JWvt|Az3IM2bp>}u1JGg1V7>LbWsC1j#ls_|sdL82cPEaII zy`GDAS+znqTO64p&T(ZfmaXMX95l)rM`G^#=6OmjYO^e#V?|;gO64a7)7cdM$z`Q@ za*w4dgCgj7fVIW}?6f0zYqr*b0Ev^IkX7=I8JL z72sP+wXf2H51xoV-_570jMRXjzUz%>C2CElh!3@B z98T%zMnlF~E%3f4$BSGEY6b?&B z<|BSu=)W6fmBFwS?}Edd>_)rZE^{9o?<@L3|3l`w^`5LTtG;i720JeQ1MG~$l2~cV z9h2Lz2%d2xvfZ}q*$J0`FGb0OFdny!a61mkFLH!WkrY~!v5RHUYS%w92UApk^r$0f zo`}?|lbDq&)_j;G^N1XC$F<`m22vNnZt(0^f&SV9fW(!Y&ru)6n(s}m8Pr?x(Ur}N zt3NDJANUA&00Al&-JgxeG6EvIKcU|R@-V%;(sCG*S9}ese&VMeJMPS}V|4+bF=%!; zLxgi|_BdTG8FY^cUk;b?=X^HpjZxmUa|I@Du6FFZMkKKXJC;r9i3wZUoDsDBLQG2- zii&56lU(-EACD3XdjMxE%yb*pPhFX=bD~%94imS-Pz9Q5Rxt=utE4tQ5*WXbF`vz) zaK{EOR#gegzS$ePOkc7KQy7vW6+1OH+wdpy!d?f&*YRL^Wm)0dyLX>J3(_ zVg7jrVFO;ntak3ksoF{{SXJgzLfr`gGFHFN zoHYr0*1J=^;|ZPLS?sLzFnF^(g1+U!ih_frtQ2Aa@<(6o3AHw<4%S~a#<9w-yhdPm z6bt0ux*a3WvFr6O4Y2JosCIeZ9I!z(~Kg62BZZw`UU}+9V z{{uAnAi0NOcyV=mL!}3{j$Y6yQ)}A!+!*en>#PMy8LL6OGN3{hz@_S76?=DJ%Dpn{u5B$2Rlf8P<{1>sMDKd0{K`4gK;x`=jL z;|<3YfBgslxd3QBuUM^-JD$sDC}`k4zcKcpx63v=fj+UF4spOUJZ*O5exwMHh_6YX z!mgcs4o;LRJ{Dh@;riSMlB>Q{c~&goR%n4Uej)x;)juvVV{(XiLvt?H`vM(M8aDG- zhBZsW7f9=`f((ONJQF4p*>;d!15uH*k3>tp=GwE*`P8U0F=_Ea$p-~zExZ}S_nn?N z=yhyoVf}L%$glK)#JWSp7m| zU}ftfM~F0f`iH*!LGuz0VZ$u|p{yTjOE4sAdgEv}vaI4G46&VD?1B7x$_3`Q2^u#H zFGGftF^af28M@W?iJxUxO1r-U@Nfs^+mT2JR@whVVOM$l{y&MMK(NJRhdD&MYD;I%u5CCKfsOiD0y>6DvQPNGQ%9FUHiylG zO387;L#e{z2V;_|6~-6#{ZleTJ45KXS$w2Fro;Ms{H&sXelN4{B@{OKdPyh{U`8WG ziit@TG{`TweNwGNHV1+wlp!M?L8peI5^0C{s*HEa5h@*R{x3ACCIA`*wXop8#ff73 zIpUM3G&4=ikgdXI2b}4lXs@5%XxJmbUAc9k(Fa{1qW-%zs7;PsAA(97v|^{+Z@;g$ zWoNFKsJ|vsN&EEaT;>Jy%-sLsfy23iw^ys3BJxq;L*GNW6?E8Ay)f<*fNP@M@Fo8} zkBE`)N`)49mFD)YcPLO%OGBXh>0RaRnC=zyT??=DzOg8FC`uUnecn&@+}aY)x2kC< zcm`!~OBpvg(29>sdcJG->oaE_T7d(r=PENrrSKmiO(AxM!dbPH2G>Q89Qi`83caxK zjn;Eg|AZ@{!f7@qidY`P4a2bZKEHzN2fJOpGaE`$`PS%tWKB%iX$LMIsWYSa3I;kO8k#N_v??%?CH*J=E$tg!P`IFQsLxdCJt=Z@GH8kQ_fr*gHC#DKY^QN2a z#%JrH})0xcDaq6kHssomy2#AQ_|uGr|6pgDJ}XCI<*a zz`?O)TtgggqdaR4BhLEmfjZ^j!t|uJIah8mo}~4KiuipQS+sOfOaD@qjGjcdnD07$QKR#HTei@Z6xrd!E^Om1jECF$bX_`33X^l#=;swPs8rM`)73sDlnhlia0Ma z@Xdk9V8GMHmGF_2pAL{1#fu!eeXu|2G=cr7*et4Ix#)YmAQ@2XYg z4G0=5f@tE0LECN{2nEv!b-xTEepdTa#f&x^x_@(Ym2mm#_=Rgl+8n(-QT(gT#;O%0 zwn(_{%<`9YcEHen`1S1{d-CZQlJ|6P5!_2I(NzV{2|`@#jKvI< z8#cT53HRRbibL&L(#zmw|2>_CZnoDitx+RwIdJ_014S&?2%4LFhZ&=H6QPXl1~(f> z6NaJyEhblZ&+{)nL5uQv3RFN!Wbdt6;h+m);ih&MHa{R*st3*k6Y4TnhH0X3?8VPw z+U}SDk!PfoyUVbSj$n8nO_H1ULH|o3-+R?Vy_2oZFU*63EnFu6-rwF{PxOw%%GSLk ztrqt;I39O6d*+5)2C{F4Jwfr@jNIERK%0a0#`7|-cb1FbXi2Z(vl?Is!71W4}5zGE0wP&kmIR?cgY z_+#cmfs}xY)B+;Ib~#CoY5jA+9G{flUal5^IHE}*-b|ncVT8LK1>ROF$r$ThG38YNEcJ!3A| zWM?MJWFIU`PvQ}4F7^DoMeV%SY@8?)O9#C_X`48Nk=(YWA=$wNey5Z#t;!sOKW|vl z-|iSwb?z&+)#nLIT)ST+uLxrAqBE;lDfPO&n?i(!Lh*u?-i);o;5QT)1`c95B=z!! zGq^M53>?{s>dSz@8+`$rJ>go+^{EzJROtPveP8WleF>%i#2CNNth@thV>n{)$=Ky- z)gfdp`5XN39S?1^nTV`RHd6{WL7m`{ja|Nsb9Ia#kCWV_CksVdi(KMS)3atP3C8di zzv#=P22ZiMVe9e-+ZeIoOjk6ky`=lwwX2}(iDj^a^^GG@`aIsc)?Pe79eUVLG618u zE!z83x?1BQMmYCR1UdYn>JJ!vXPrWtQDJpu#t`*b1EkUv8PI>Bj4Yr4;kcRuC(Pls znIdQT$gn; zKG_Cy!wGa6zeVD~D47McxT5F-^Sm!@QQ02HuEtYjGFolD{`vYG3X@oRQdJ-CDk2~WJ;OaspmL?-Pm`kRTY@DG# zaGqJc1T`ofGu=wamvaGVcqpf{&tb<8eM3-2&weEy^&qq@irgn84Mn$+O(mK&p8L0xd_j% zcAsRwM27_7Kwc`4wfZVO$&m*OX*w!$w%fm`H59@5z)~lQdn&C2Fna~)C z8`^#;9^uz6CbJCGUt;e!0du9jwt(GC9)m*6nIh#68ROt#-KFFjl+?RorHG9}9gU*5 zC@#tbVk))i>RzmAr;D!WB|RJ^eN@%9r|SZv(?)foC4vUzeAS-r=Xif#AP5Bw=oeiP zlX)T&3JJ?ZP^AAGis0mouD;G^F1X%+U=(E3!P-Ig8>nsU!XKpno4*bGi@*J& zeq`moP(-l6QGEithblRsW%5)`=*i(m)fi!1{`q=;^nB57#P@})prq@3N^ zWJ?$qVG%2E^vWIv4Hm18q@-s}S5AKrFFp^1Txbr)gH?L^7!_;GNg=X_B zG!u48vn>sdq$Av>1j)Ris3@>TfJBYH2n9gJoy##-%e%eUP@yUvN(g1ZtzH|@eoeqj zwcj22@|7JI`z>hhe0p2B)~u7UGV_y)&qGp;htq*PPD|!2&GG1I9#xY$VK~KpnQvQ*Bu`SoNj|jE7|zDH;`96X%rr#1NY! zW(sz&3T!t<(qs$0yy#)k-2HJ;HnedsZwG7sH^3C_B4&S5*1ZkiuGz!2z)GbbG8Olc zC~92Qo~Ef%$8!FlMI+Vmd>gB*>UUT)A@Fe=`)xmPtWAZX2_V?mQp+{!OvVsr zsOd!!zF>-puMdrTxe%*`Z(;`#<8n=yHA5zfwH;&pZ(SxMx2&-1=3d6r;1x0H0#{^S zE7VUR9MW5|8|}u7O|v<5d#8Rm*o|qB=?;;at6S3+0b`2E;9%3%7CDmbA@!6!67UCC zp`8#iKrxuOqQ~exR4lqMwW<;L5rs@trQoSwt2c^wJm+k!7*%)ajb!)nTyG2#=& zD7D|1APnGzP@SKMK^Z7?vox-Vo8`!N$-Jh~&M?I43qEYvl#I1`e_-y{SWSMy`fT1Zeuj(xh9&-66@Rb}lab7`kGKk@`fO+{pe*ktb{8#Xn#S zc{U-sghUSQZmK;lwFoU%2G0%?Qj}u4+{15!r-1+_N=k<31DGI)?EZ1>$k|8of%H{l zt=r7enIn{s>}&02zj(jq4444PClma1hn(u5sK^uB&c)kVI28+g9plly-jPb_Nwy7t zpR9mzvpTJAQ2#TP1oKJR^kk~I4bx+rfd?FszlBj2x(ud~=G79KI_ewT&LOw~LVGog z4qV(oVc)82&SGf?e6CO~2cm_i2H;#!-}ge zm#Y@Ocw_ksoCYN~er-r)Z>0RMY`eO_aCOXI^!wC`!F5!&qJY4McjJVQXyY`a&)E0g z{B`_XpQ%DVt>gW^HSNOPXi-zbPw#v8%ehyWRObR3PyEgx+9j;}Os#vKdY4cI3d!NX zE!G?l52^jSsLQA?1{*$zPHl6a;B{x@FxjGLx!r(rfDXC-UX+vK2c@o}1y(u1OJh zp#Q;YXgmJ_JCz#iNNRbkv7Xyb@#oIb*=&Y$$hFgc?eVg@D@Q4zF2m{bKJdIaw*Qnr zUS>%}>h8=^Yci^B91E2p{qPzjB!km5lD$=6_p885gtSeEsE2}v*Xe@p@yevx5r*%F zJbf+YqFfA?Y1bu+LKdA=+v(iqGh0nLXt>I||5YUL3U7AfatKl7R53Y5o zy%P$L@?+v3DKE!19Zs-Skgy$vK0_fFzXb^LBG()mZJvPsoGqR(AFDz*O=AQOn9P#v zZjR>o51pwd^>q4!MKRxVe*dj1)s@qbHX{3uLgHZsp_iAVyHt$|!0yL)F4~{%*imjfq#6&Nr;*HS zVUN~(m!;DApN&^JVmHOohkw-u@k`7b-!BdHs_il_@y|6x4q>&YUMjp24HFo^@(#0| z7TJyKlUT6g3soPU-ZLMGPSCxZ78BDod0hf6d}NBEwtu37AEF%gnC-s+>kEY`Sl$;O zbx-qD%Yji`oz5nc{l7+OwGDp?5@A=UC%g;#d5PzFiC8aFobl^of7O1~HY;bovF9hD z*p781EGe{q&?iSlGXv3n6}Mq?a@OI_=RaGO(c))qmS?t1+U;|cO-2^JlneNFbK~Uo znb1Mz#8n1;IhVb-_Nwrlz@r{I6QPfznTdI9cjJxsZ_R2Jjd-5GoH=aSJe{$k7xqV5 zYophYGiI~a-y_B@Vaiq^!>@E_Jiv|2vhPNyu!^vdx5C?W&ZxTBRV*ccx-SXnMN zL|QG6UyoXDJ|t?LEZT)o8IOchC#QW%mmfZ3AX~uc1-SRMA@JGV9KoPF0SE5-XpA}o zEF%uS;1QUV+ukoMxdUtmf@^}3%a6F}yL6@QTy6Tl-fqnbrB+TSab7wd)e5KjTZpN3FJK?*7 zd#Vf;AS2_D$AxX5E@OY|mdX*p}bE$K05x-VTB(}f;aILtx>lB(_+s0GgGj=!MzBT#lK^mT= z?&LlS$AuWQb7j+_AFVp6B&SEG$1yMZ14VfHj)C0~-~8Z8?Tb&LcltrY;WuLSs4d5# zLmZu@-{o#+T)`(&V~7464R25h6mbvx7Yy{c`a2h^P2iPYdOqQxFameShfp^a zjjP>&B~)C*VlT7#S_t0o%FQ#v_US@ZPOI5nvTD;sT=&s;{q8O!`6~HT72583`9Rm- z_T9K}9m$R>j|ZE26)I!T=2ZLm*PTJ0R7ObgMED(nWc+>min(N|at#0sW3kusq1t+I z-@?j;nEBkU=$^Jk-$tktMdsv}x6Ip1HBWdp@Mx5L;8jLG!a+!+w*!^HVDTjXr`83n zBAM1B?oeLAHim?{y%A6u&)~Lpd^@%^(zEK@W@Y>&lO~|MmN3T5ie#u9onau47tQ;R(I8U2K0KDV3-s6-};dr3g?%E zR~(8PBUMYkRP(lEwyYJMS?Tel z;JnK3}op#3-_}b7e;`&7mk0V1SOUR)@x=qw7^d-0c&}-Fk@E(L6@yr!%tCtNZCvrW(h+c*oTI>qkMB-c+y96wJr80SU%ga% ztb|SZF)*Zx$>+NEa6R^q>onCc&G~=2p2;sqaWR?D7cr#f{8M&FPWsqtf0is>&W%5@ zIpT7t($0vxe@m+SD0?!OR4@`_t$537?{rN#N{V8sKGcH*Av>b@%_df(*(_I#Z^J?_(-IRQk#95{DVdq z{6iXgs|RqeijScY46%GuD%}nbIqZ^fRuLX9(9qFn+j;>-v~kl>d_8TvX}5Gh)Z}4D zmyS-SE3#Z#e%c!OiFn1|y z2p18&XFt-CaH-{H6aA%yx5hR%&zSK6M;?tPaQ*J+Cwv|VIj z`6oXRk7W2p3u-1{A`|UKkAzpM#!Afo5>+^QVl)0uz5C5LYRtLt*9cB~qA>JtJ5#{gr{+rK;3cDtcb_!4bAJ5Mkr{nArrWy{{vGao z@3*AYc}RRb51kB4`@aR!j3Y2*ez@MY&imZs@Y3n|6Yj7UCHagY@}2F~OoeHo>)5xX z*Qg5QDC$MJiDVo4lg>RKH!De1$23_e zWV;FkCWde3$?0pfkTFSsr5*+@dZfzP0w(X?Lv40`FkGW4@{nz%2{a+`^8TVd45L;l z=G%?R@(n#w`djNb#lMzGaQ&IreEtW4HH;z#^E93rXtv#BC234Dh&|Q!Xky=>4Fb}4 z>t4!q)wjM6z*gGbGQ4Y!B}s^sFI>&OIoo%+bEG~%WLDa>9oq(7=c(+N1XD%jh&U|l zzx3TeR&-2?GGcRa!Dwkg>}z;soR$cmP5l+V8oV+SbWg3sVgRyodEzYA;OvjC8#Fov zzv*3WIb4;RGe+*WTkHjO2v zQm2VM5DI{uep;(7(yS1&E=c znd9)RxGBsbP3u!N&t|&txcH0a1c5AM@-{_|pj>hnz{Z6^s=gbt3%XZQxU3P5mL^l< zM6}rawo-0WhN z)78cKCdE%yE^i^Fx}IAO5FL#=3fs@q!`WcXIL4Y_GdQMgd4dNLT;&27IZ4q6!x|@! zKmIfaw0nxNYbvNZ^+b~9S&*!-NUsfpzGn3ra2n0QucSv5CJ9nRC13dL-oIm-886{~ zM|PBhKY{?UKSClxcF!K@m2J4juE|n|ESI__m$)l1G@ueMsnOjMno>6G+F{gC3qWsU)#BM4o;a_ zNk%C;fUMHvk7P+*dKXn?)qJI@A?0ze#fmQ%AQm4zRSW{vbaBHsZk)u!tE4A)*2D=0 zNBYqT1t16AXkFgZIuh2y_C&(izF-BxL+M8gGc*&e+A;TBjwMQ znT?%Fr8b*IcnI-0N4l;uS&RWBIAGXcw18f1%vBC*dqz8u63r?i^XrKx>)JcF&i*M5 z+%XkZc*|PEC7l*K*bU9Ggifr5t~*nEmvtcS%$2^fm>DJ6$stK@`t%z0M0O9?4T4WA zdkfl1_j6AQ5{8XrXG1J_0TPSsi@7q#&}#Czk2To4`{|)DZX(o`hAIV*D`FCglpA0@ zMYZ{;FJ)AqcFS1Q3(rdoMK=bkpMeq-2CtQbYdz&=c-wHP>xr<}3cPhTo=kU< z2xDhMIUsFR>STObCsf4o6@tNpquLwSuPEv^dIDYcJ1z@X7bjyX%lRuDtMcc7)d=&t zr^F;a=d!`*lBI}Y#jG3EP%k|GfRo=D8f5fP_OK=<@g1RY(lTym~q{zbr*{ zYrBHoXDjlipAo%#K%$LlEquil>rMam-6LdA1BF7Wk7sQ-frk0?yKj(2AaIT6B7C)< zZT)0B6hVq%L3P^YJKvjG9~fn0Mjo{$nZ5Q7gBA)3*%gfHz=f!5&}%UcMRk)asz>bO zIhRPHl818VVFllAfm(3*3{7+a5c|vR4mhuQJ}%GKPbPaFgzUAY2l}>@0qd6KO=?C3+&lrS*n4 zU2ecVT_Sm^YNLZ43bwaeny$2`{b7dKGm>m@a6WXzazOM2uC2Grjo4JN0$Ss=&I?unN8@SBO8iR@S{*L1?1tQ#oCxkPSxXaP4qsHvsee7&rg1 z_gDp&;(^fYi1Mt3M*hX*4g7-H`Jktt5%Reugy16`uHkD)zv6K-r1U~y1Vx?%$BQ6_ zz^@GwgmHXkHEPbIu+ENzUV83LHDW>AGcm5bhJvcbpVyzm4MZm#&aM2V4@>5vmjb0V7i$1zqz0soRsz_Vg029H z%a_7GQ*%g(96^(|G^Ax{WrRU7Vm!bKCYd8cKN!$&kR?AhT2i$j!FzF(T4@d^?RgXw z6QRHIL5z|70jCI$7%l}uji;YQQe%vmS}tt$X<%`*PNsE^cobL}o1E>^JYO*E9>I!C zg~oj5XbrY@`!<>^h4mlT-dz$`Fj2TX13V5}l(w+&cTqCz93)S4ue?5IT)n&BG){0% zen}!PRe#6~>erKgWQ?sy8541|i{(ay`ypuK?}1*FurG~i<{vR<*2N%17nRW_@yFt$_=HWy=6#7pmDTlm&Y!L+^gzP27*ihOvIxt0BqY3W zHQ;~%F`+nPe~N8)E^;BqM{*pT8l~?kOuA_Mp9~m|KeAt~8nQwz$=BTGQ9VF@?W{RZYH7$hoT+sH3?y z-J*oJO<}La9D|lFQkpyT#aiEcF-wm~3dirRsU? zHFCA)CJdAM<4Ao}l>utL0HFZP!sqenG!~ z&+I#2*5IGCW0W&hiqkZ2>EMNW*L%bT8ZMm&t!j#OQLjQFOfZ+|}jdF7iy!o~CxPCh|+%&aQr*djOI2P9ar zgg5tK#nau+hDOx=ch>WP<8yLmLwJX7P6l&ZE!Gr6{!oq3j&k)CB*esz>;qAWPm3zK zj6Kf|n{oG;OI)R-#8Jhz!)i%MOvX&*mt<9y{Ba+ zogs~KV^acJBAxiDS z$>9&_oJ$J{Z5VqkY*Bfdqd7aJ-PKM1)ia0O1Re?u%xyN%b+hmR(F=^9`Da6U1Ub5^ zldO$Ug9e7pJI?rHO0L^0H5Q7*tTU}^YwHhGLfJJ6Vk>&2muJi=(qE;gwwKNW!gPsO{w~2+!ff ziY426dU*g1y!(^ia{^{k_5xnNZY14-=j*?l31UsOo< zb6cSgCpHOvcgTAIpmCxNK;=T`of%VKTlsqdQ*8|`pgLa_=w&%V&NVW9U7_qynX-y^ zs24a4hlY;v3Ms|FjTk9>y64}T!EV~!ih_q1mC@eAwOaVM&>`9-fgv{A^>XXr_HB!( z4nHBH&r(2{Tt^1ltsoS#7BYd-D9!IPNWcF#qXT9n5UsW^EML*f)2|RryP0z3r>y}p z@7Gt#_DrWx6U2%Tg5o8XiAU2xY;_#L`0~bzc~2|7J6oc_P97F=L~+dTU8t+Kszb{OZiNjFB19qT;RwHI$+gIzLn3~pEr zP^a^DH|Ww~D5rw`3P_MFAL-<{ws)^-d0l@KJcGsQl9XjT#r3sP9N)@^pT(>5`3sO0 z^4Fu`aAsVZ8}r?Nb#oEECz6RQXnp&7|r>8leHkbh;O_B^cI=}^LH}-d0*9tNkkF{y(~1K zv|T*WgCqf#g;hcnHLhub{9l3lt_ki+zV+Uoke}S)&HMp(yy^;z;hmte+70Wh=$QNC z9lNT)ER`}u9`1iR6HxO=(Bp-JL8SxwEQDklT;U!r2s|^Qb|roDpA6y&Ma_d_Z#iCps?{JcS(Ig%doKD~)6-LG&K5 zs@JVatNvfLTxmR%Ya4$^CrdKqNXIg^F^#2q{g{%BGZ@KIF^tAG6G}AXlqKYF)NqKO zK^n_Qa&Qm^*&P~7^V+k;WEpjm95Tu>MZ^2d4DW~c`}^tsJmU#kfD)$f98A`$z|5+Z!}5EY(HdF_Bh6{t?Mt z_7ltH3n0AZ{qrh+-}d&l#;lgln*+TGSluPMi*L8nqx8SdNcoeeBb74*bZd|9pkL`Y z=JY3IalCSVxX|``vYcGv!SNtb$|cdKKg5HBRwDdsvo$b@uS%VDvkm<~Vz#?|BZ|OJ zHMcZO&RZzuevV0RjzUc>eQ9Mn?D8$lkTgMwlU5Q9T3$He{%c)q@ldpS@6jCB`9=3c zFV-1{&{j|)4G&uQ~%6uZJs zH@o3oZA^WM&do^j(FaeWS3-}UM-OUSr)6dfgi_kK)AUtz|0)SG)MC zSDp*_(lFl6LT(-SDSy}L_?q$);Zrx#m<7l~$V!?^v`n0Ewa5+h?Gj## zs|c|;6eP=jjptnKD1Sx9Gz!Y1*`_l8di`KSz?i$V=dSv_)E%qtG#G-=sw_#q%+YEz zf7~xZ$2P-gZe4868F{N=n_%!>*woT$YJZ6qXx(9=3?`X6t;P2bYLt6{nQe3fap?&# z*OSrl(u(J~vC4pRw!&*Q;%)bXgaz3Jq&vs!xly!W{$!QF&~U`Zg%bujg8kMk9`nc3 ziR$_J`C#aQ^Ky7&Kg4?$ZE0Yx;m4+Q>rbXhxPZjlx-QLU<@?@<`I@A%gmz6pyTgTT zauCSRpxsM4YmQPdi>MbSKW=wVM-Ue^C)bs1T;!vYip zSUA1Jt-j)YFIA>?bK3V5#?tahRqa-g(XK3q!8?on0>EepN&%`f@sK8dT0ohVN_Bq1 z*EG)1qM6L1$yF8t86N3Ms(+N06vKALg*y#q1{dmbii zU!X611MddrndO03(Y;S9vL_WWy8yWUtm-*Lti4TP8d)m(f-D$Zc|G^9tGZeCQVM#V zELN8PSZSi)F2ID>&+xWH;w)HF?ppp~oIouUN0vD0{K^YqffH;W>0!!o*b+w!i8RUy zNGj#lT@|=5_QbvMmZ9IUB7s17XafVT_C8V%Tb}Y6e9Xh}V%o`4*v4Bh8jVKLl5lgA zb}n_wzcpv`^m>EZ_sZPnd3&Wq(UW&cgu~`XAa>FwH?H_!tNXCCBPoXUvPE{PWaC(o z($?XRhdMs1q&0Lr-0hAu+_0DzchVsNkls^w_>=HE6!qF|%;f9fnV zD3tdYpWajk0GF{44=3l0s6ZRUX6qp82rxUd9<_6`O=mGe&VFR-t2F?OkJSZ?u4`>g zwKb2OJjn9h2Taw}1r(Gz%41}^+(r&jZ$qiQuxmAVJMC{rP5?2HLm?i9_J^IFF&Y0j zEMKV?1zOWuIR$>%$9fG)A8NeQJdLqD$l(leJ!pgTO3=dq@hqYcEcTS2YiO%}?j`~d iFg6Is(t_`6&9UrApFV0r``R!+D(5j5EGx{3vHt<&K}~D` literal 0 HcmV?d00001 diff --git a/docs/source/index.rst b/docs/source/index.rst index ca339793..7934db1e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -12,6 +12,7 @@ Welcome to nnScaler's documentation! readme quickstart + nanogpt_example pytorch_lightning parallel_module register_custom_op diff --git a/docs/source/nanogpt_example.rst b/docs/source/nanogpt_example.rst new file mode 100644 index 00000000..3fd9685d --- /dev/null +++ b/docs/source/nanogpt_example.rst @@ -0,0 +1,193 @@ +############### +nanoGPT Example +############### + +This is an example showing how to parallelize `nanoGPT `_ +with nnScaler and `Lightning `_ trainer. + +This example contains one single script, ``train_nnscaler.py``, besides the original nanoGPT repo. + +*********** +Get Started +*********** + +Installation +============ + +1. Install nnScaler :: + + pip install nnscaler + +2. Clone nnScaler repo to get the example :: + + git clone --recursive https://msrasrg.visualstudio.com/SuperScaler/_git/MagicCube + cd MagicCube/examples/nanogpt + +.. + FIXME: update url to github? + +3. Install nanoGPT's dependencies :: + + pip install -r requirements.txt + +4. Prepare dataset :: + + python nanoGPT/data/shakespeare_char/prepare.py + +Test with Single GPU +==================== + +Now you can run ``train_nnscaler.py`` with `torchrun `_: :: + + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +This will train a baby GPT model on a single GPU. +It will take several minutes and the best validation loss will be around 1.47. + +Get Distributed +=============== + +nnScaler is meant for distribution. For v0.1 release, we are focusing on data parallel. + +If you have 4 GPUs on one node: :: + + torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +Or if you have multiple nodes, for example 2 nodes with 4 GPUs each: :: + + # on each node + torchrun --nnodes=2 --nproc_per_node=4 --rdzv-id=NNSCALER_NANOGPT --rdzv-backend=c10d --rdzv-endpoint= \ + train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +NOTE: The local batch size is fixed by default, so using more workers will result in larger total batch size. + +Tensor Parallel (Experimental) +============================== + +nnScaler will support tensor parallel and hybrid parallel in the next release. +You can try this feature now, but its stability and parity has not been strictly verified yet. + +Using data parallel: (each model instance runs on 1 GPU, 4 instances using DP) :: + + torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=1 --runtime_ngpus=4 + +Using model parallel: (a model instance runs on all 4 GPUs, no DP) :: + + torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=4 --runtime_ngpus=4 + +Using hybrid parallel: (each model instance runs on 2 GPUs, 2 instances using DP) :: + + torchrun --standalone --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --plan_ngpus=2 --runtime_ngpus=4 + +Resuming +======== + +You may resume a interrupted training: :: + + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --init_from=resume + +This will load the latest checkpoint saved by Lightning. + +For distributed environments, checkpoints must be *merged* when the environment changes. +Check :doc:`the reference ` for details. + +.. + FIXME: link to the section (dunno how to link into markdown) + +******** +The Code +******** + +The example code ``train_nnscaler.py`` is modified from nanoGPT's ``train.py``. + +The modification consists of two parts, (1) porting to Lightning trainer and (2) using nnScaler for distribution. + +The Lightning port is not the point of this example. Check the source code if you are interested. + +To parallelize the lightning model with nnScaler, there are 2 noteworthy places: + +1. Define the forward function and declare it's inputs: + + .. code-block:: python + + class LitModel(L.LightningModule): + def __init__(self): + super().__init__() + self.model = model + self.dummy_forward_args_fn = lambda batch: {'x': batch[0], 'y': batch[1]} + + def forward(self, x, y): + _logits, loss = self.model(x, y) + return loss + + A separate forward function is *required* because nnScaler will only parallelizes the codes in ``forward()``, + and will not touch those in ``training_step()``. + + And then, a special function ``dummy_forward_args_fn`` need to be defined to the ``LightningModule``. + It takes ``training_step()``'s ``batch`` argument, and returns a ``dict`` presenting ``forward()``'s parameters. + This function will be used to trace the module's forward graph. + +2. Register nnScaler's strategy and plugin to the Lightning trainer: + + .. code-block:: python + + compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, constant_folding=True) + strategy = NnScalerStrategy(compute_config=compute_config, pas_policy='autodist') + plugins = [NnScalerPrecision(precision)] + + trainer = L.Trainer(strateg=strategy, plugins=plugins, ...) + + For data parallel, always set ``plan_ngpus`` to 1 and set ``runtime_ngpus`` to the total GPU number. + + Other parameters are used for performance (efficiency) tuning. + +For details, please check the :doc:`API reference `. + +********************** +Parity and Limitations +********************** + +Single GPU +========== + +For comparison, you can run the script without using nnScaler: :: + + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --use_nnscaler=False + +This will result in a similar loss curve: + +.. image:: ./images/nanogpt-curves.png + +There are several causes for the mismatch: + +1. nnScaler and Lightning have slightly different gradient clip implementation. +2. It cannot fully syncronize the random state for dropouts. +3. PyTorch is not deterministic by default. + +To get a perfectly matched curve, use the following command: +(The overfitting is significant due to the lack of dropout) +:: + + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --deterministic=True + torchrun --standalone --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --deterministic=True --use_nnscaler=False + +.. image:: ./images/nanogpt-curves-deterministic.png + +Data Parallel +============= + +Here is a comparison between nnScaler's and Lightning's builtin data parallel: + +The curve is not fully reproducable due the nature of parallel. + +.. image:: ./images/nanogpt-curves-dp2.png + +The Lightning Port +================== + +The Lightning port is not exactly the same as the original nanoGPT training script for the following reaons: + +1. The Lightning ``Trainer`` is different from nanoGPT's training loop. +2. nnScaler v0.1 lacks the support for multiple parameter groups, and therefore the weight decay is configured for all parameters. + +.. image:: ./images/nanogpt-curves-orig.png diff --git a/examples/nanogpt/.gitignore b/examples/nanogpt/.gitignore new file mode 100644 index 00000000..7c443661 --- /dev/null +++ b/examples/nanogpt/.gitignore @@ -0,0 +1 @@ +lightning_logs/ diff --git a/examples/nanogpt/README.md b/examples/nanogpt/README.md new file mode 100644 index 00000000..9468fa06 --- /dev/null +++ b/examples/nanogpt/README.md @@ -0,0 +1,14 @@ +Prepare data: +``` +python nanoGPT/data/shakespeare_char/prepare.py +``` + +Run without nnscaler +``` +python train_lightning.py nanoGPT/config/train_shakespeare_char.py +``` + +Run with nnscaler +``` +torchrun --standalone --nproc_per_node=1 train_lightning.py nanoGPT/config/train_shakespeare_char.py +``` diff --git a/examples/nanogpt/README.rst b/examples/nanogpt/README.rst new file mode 120000 index 00000000..a9f2be20 --- /dev/null +++ b/examples/nanogpt/README.rst @@ -0,0 +1 @@ +../../docs/source/nanogpt_example.rst \ No newline at end of file diff --git a/examples/nanogpt/nanoGPT b/examples/nanogpt/nanoGPT new file mode 160000 index 00000000..9755682b --- /dev/null +++ b/examples/nanogpt/nanoGPT @@ -0,0 +1 @@ +Subproject commit 9755682b981a45507f6eb9b11eadef8cb83cebd5 diff --git a/examples/nanogpt/requirements.txt b/examples/nanogpt/requirements.txt new file mode 100644 index 00000000..28dfed63 --- /dev/null +++ b/examples/nanogpt/requirements.txt @@ -0,0 +1,8 @@ +datasets +lightning +numpy +requests +tiktoken +torch +tqdm +transformers diff --git a/examples/nanogpt/train_nnscaler.py b/examples/nanogpt/train_nnscaler.py new file mode 100644 index 00000000..c4f60964 --- /dev/null +++ b/examples/nanogpt/train_nnscaler.py @@ -0,0 +1,274 @@ +import math +import os +from pathlib import Path +import pickle +import random +import sys +import time + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +import lightning as L +import nnscaler.integration.lightning.pytorch + +nanogpt_path = Path(__file__).absolute().with_name('nanoGPT') +sys.path.append(str(nanogpt_path)) + +from model import GPTConfig, GPT + +torch.manual_seed(0) + +# ----------------------------------------------------------------------------- +# default config values designed to train a gpt2 (124M) on OpenWebText +# I/O +eval_interval = 2000 +log_interval = 1 +eval_iters = 200 +eval_only = False # if True, script exits right after the first eval +always_save_checkpoint = True # if True, always save a checkpoint after each eval +init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' +# data +dataset = 'openwebtext' +gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes +batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size +block_size = 1024 +# model +n_layer = 12 +n_head = 12 +n_embd = 768 +dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ +bias = False # do we use bias inside LayerNorm and Linear layers? +# adamw optimizer +learning_rate = 6e-4 # max learning rate +max_iters = 600000 # total number of training iterations +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 +# learning rate decay settings +decay_lr = True # whether to decay the learning rate +warmup_iters = 2000 # how many steps to warm up for +lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla +min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla +# system +dtype = 'float32' + +# nnscaler +use_nnscaler = True +plan_ngpus = 1 +runtime_ngpus = -1 # use -1 for WORLD_SIZE since nanoGPT's argparse require it to have static type + +deterministic = False + +# ----------------------------------------------------------------------------- +config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] +exec(open(nanogpt_path / 'configurator.py').read()) # overrides from command line or config file +config = {k: globals()[k] for k in config_keys} # will be useful for logging +# ----------------------------------------------------------------------------- + +if deterministic: + # seed is set at the top of the file + dropout = 0.0 # must set before model init + grad_clip = 0.0 + torch.use_deterministic_algorithms(True) # NOTE: requires env CUBLAS_WORKSPACE_CONFIG=":4096:8" + +# various inits, derived attributes, I/O setup + +torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul +torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn +# note: float16 data type will automatically use a GradScaler +ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] + +# poor man's data loader +data_dir = os.path.join(nanogpt_path, 'data', dataset) +def get_batch(split, ix): + # We recreate np.memmap every batch to avoid a memory leak, as per + # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 + if split == 'train': + data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') + else: + data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') + x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) + return x, y + +# init these up here, can override if init_from='resume' (i.e. from a checkpoint) +iter_num = 0 +best_val_loss = 1e9 + +# attempt to derive vocab_size from the dataset +meta_path = os.path.join(data_dir, 'meta.pkl') +meta_vocab_size = None +if os.path.exists(meta_path): + with open(meta_path, 'rb') as f: + meta = pickle.load(f) + meta_vocab_size = meta['vocab_size'] + print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") + +# model init +model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, + bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line +if init_from == 'scratch': + # init a new model from scratch + print("Initializing a new model from scratch") + # determine the vocab size we'll use for from-scratch training + if meta_vocab_size is None: + print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") + model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 + gptconf = GPTConfig(**model_args) + model = GPT(gptconf) +elif init_from == 'resume': + print(f"Resuming training") + # resume training from a checkpoint. (handled by lightning) + if meta_vocab_size is None: + print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") + model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 + gptconf = GPTConfig(**model_args) + model = GPT(gptconf) +elif init_from.startswith('gpt2'): + print(f"Initializing from OpenAI GPT-2 weights: {init_from}") + # initialize from OpenAI GPT-2 weights + override_args = dict(dropout=dropout) + model = GPT.from_pretrained(init_from, override_args) + # read off the created config params, so we can store them into checkpoint correctly + for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: + model_args[k] = getattr(model.config, k) +# crop down the model block size if desired, using model surgery +if block_size < model.config.block_size: + model.crop_block_size(block_size) + model_args['block_size'] = block_size # so that the checkpoint will have the right value + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + +## Lightning Wrappers ## + +class NanoGptDataset(Dataset): + def __init__(self, split): + self.split = split + data = np.memmap(os.path.join(data_dir, f'{split}.bin'), dtype=np.uint16, mode='r') + self.len = len(data) - block_size + + def __getitems__(self, indices): + x, y = get_batch(self.split, indices) + return ( + x.clone().detach(), # theoretically unnecessary, for robustness + y.clone().detach(), + ) + + def __len__(self): + return self.len + +class Scheduler(torch.optim.lr_scheduler.LRScheduler): + def __init__(self, optimizer): + self.it = 0 # must before super().__init__() + super().__init__(optimizer) + + def get_lr(self): + lr = get_lr(self.it) + self.it += 1 + return [lr for _ in self.optimizer.param_groups] + +class LitModel(L.LightningModule): + def __init__(self): + super().__init__() + self.model = model + self.dummy_forward_args_fn = lambda batch: {'x': batch[0], 'y': batch[1]} + + def forward(self, x, y): + _logits, loss = self.model(x, y) + return loss + + def step(self, batch, batch_idx, log_name): + x, y = batch + loss = self(x, y) + self.log(log_name, loss, logger=True, on_epoch=True, sync_dist=True) + return {'loss': loss} + + def training_step(self, batch, batch_idx): + return self.step(batch, batch_idx, log_name='train_loss') + + def validation_step(self, batch, batch_idx): + return self.step(batch, batch_idx, log_name='val_loss') + + def test_step(self, batch, batch_idx): + return self.step(batch, batch_idx, log_name='test_loss') + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate, betas=(beta1, beta2), fused=True) + scheduler = Scheduler(optimizer) + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': scheduler, + 'interval': 'step', + 'frequency': 1, + }, + } + +## Training Loop ## + +def main(): + global runtime_ngpus + + precision = {'float32': '32-true', 'bfloat16': 'bf16-true', 'float16': '16-true'}[dtype] + + if use_nnscaler: + if not os.getenv('WORLD_SIZE'): + print('[ERROR] nnScaler must be launched with torchrun') + print('Example usage for single GPU:') + print(' torchrun --standalone --nproc_per_node=1 train.py nanoGPT/config/train_shakespeare_char.py') + exit(1) + + if runtime_ngpus == -1: + runtime_ngpus = int(os.getenv('WORLD_SIZE')) + + compute_config = nnscaler.ComputeConfig(plan_ngpus, runtime_ngpus, constant_folding=True) + strategy = nnscaler.integration.lightning.pytorch.NnScalerStrategy( + compute_config=compute_config, + pas_policy='autodist', + reuse='override', + ) + plugins = [nnscaler.integration.lightning.pytorch.NnScalerPrecision(precision)] + precision = None + + else: + strategy = 'ddp' + plugins = None + + lightning_model = LitModel() + + trainer = L.Trainer( + strategy=strategy, + precision=precision, + max_steps=max_iters, + limit_train_batches=eval_interval, + limit_val_batches=eval_iters, + limit_test_batches=eval_iters, + accumulate_grad_batches=gradient_accumulation_steps, + gradient_clip_val=(grad_clip if grad_clip != 0.0 else None), + plugins=plugins, + ) + + trainer.fit( + lightning_model, + DataLoader(NanoGptDataset('train'), batch_size=batch_size, shuffle=True), + DataLoader(NanoGptDataset('val'), batch_size=batch_size, shuffle=True), + ckpt_path=('last' if init_from == 'resume' else None), + ) + +if __name__ == '__main__': + main() From 13ad8ff5914fbacaa80ec5b72d911da56f2bd3fa Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 15 Jul 2024 02:59:17 +0000 Subject: [PATCH 1681/1892] Merged PR 2199: fix ifexpr warning fix ifexpr warning unit test pass parity check pass --- nnscaler/codegen/frontend_mapping.py | 2 +- nnscaler/graph/function/function.py | 14 +++++++++++++- nnscaler/graph/parser/converter.py | 7 ++++++- nnscaler/graph/parser/fx/mapping.py | 3 ++- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/nnscaler/codegen/frontend_mapping.py b/nnscaler/codegen/frontend_mapping.py index 61bd27e6..8bcce3d3 100644 --- a/nnscaler/codegen/frontend_mapping.py +++ b/nnscaler/codegen/frontend_mapping.py @@ -19,7 +19,7 @@ def __init__(self) -> None: self._sign2rule = { 'torch.slice': self.emit_slice, SELF_GETATTR_SIG: self.emit_self_getattr, - 'nnscaler.runtime.function.function.ifexpr': self.emit_ifexpr, + 'nnscaler.runtime.function.ifexpr': self.emit_ifexpr, } def map(self, signature: str) -> Callable: diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 5979b4d2..613c3c55 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -79,6 +79,18 @@ def Identity(tensor: IRObject, signature = None): return IRDimops(Identity, 'identity', signature, [anno], [tensor]) +def Ifexpr(cond: Any, true_value: Any, false_value: Any, signature = None) -> IRPyFunc: + signature = 'nnscaler.runtime.function.ifexpr' + cond_val = cond.value if isinstance(cond, IRObject) else cond + result = true_value if cond_val else false_value + result_val= result.value if isinstance(result, IRObject) else result + + return IRPyFunc(signature, + inputs=[cond, true_value, false_value], + outputs=[IRObject(name='ifexpr', value=result_val, is_constant=False)] + ) + + def MultiRef(tensor: IRTensor, times: int, signature = None): """ nnscaler.runtime.function.multiref(itensor: torch.Tensor, times: int) -> Tuple[torch.Tensor] @@ -2527,7 +2539,7 @@ def Log(input, *, out=None, signature=None): def FullLike(input, fill_value, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=None, signature=None): """ - torch.full_like(input, fill_value, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor + torch.full_like(input, fill_value, *, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) → Tensor """ creation_function_args_check('torch.full_like', dtype=dtype, layout=layout, memory_format=memory_format) kwargs = {'fill_value': fill_value, 'requires_grad': requires_grad,'dtype': dtype} diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 27e46701..517b3b1f 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -2,6 +2,7 @@ import logging from pathlib import Path import operator +import warnings from nnscaler.ir.tensor import IRFullTensor from nnscaler.graph.parser.register import CustomizedOps @@ -86,7 +87,11 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: }) dce_ignored_funcs = set(cube_rt_funcs) - with no_save_tensor_hook(): + with no_save_tensor_hook(), warnings.catch_warnings(): + # ignore the warning from fx about get_attr + warnings.filterwarnings("ignore", message= + ".*does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target" + ) traced_model = concrete_trace( model, dummy_input, diff --git a/nnscaler/graph/parser/fx/mapping.py b/nnscaler/graph/parser/fx/mapping.py index 92a3255d..85b13938 100644 --- a/nnscaler/graph/parser/fx/mapping.py +++ b/nnscaler/graph/parser/fx/mapping.py @@ -220,7 +220,7 @@ def exist(signature: str) -> bool: __tttemplate('contiguous'): function.Contiguous, __ttemplate('reshape'): function.Reshape, - + __ttemplate('conv1d'): function.Conv1D, __ftemplate('conv1d'): function.Conv1D, __ttemplate('conv_transpose1d'): function.ConvTranspose1D, @@ -253,6 +253,7 @@ def exist(signature: str) -> bool: # # runtime functions __rtemplate('anchor'): function.GraphAnchor, + __rtemplate('ifexpr'): function.Ifexpr, __rtemplate('identity'): function.Identity, __rtemplate('multiref'): function.MultiRef, __rtemplate('accum'): function.Accum, From d1852063c4daa8a3e7b3e8016265176eb86f2d0d Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 16 Jul 2024 06:19:56 +0000 Subject: [PATCH 1682/1892] Merged PR 2205: Fix sum anno bug ensure the output tensor has dim anno when it is a scalar tensor --- nnscaler/graph/function/function.py | 3 +++ tests/parallel_module/test_gencode.py | 30 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 613c3c55..7c0aca8c 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1115,6 +1115,9 @@ def Sum(input, dim=None, keepdim=False, *, dtype=None, signature = None): sort_dim.sort() for dimidx in sort_dim[::-1]: eoutput.pop(dimidx) + # handle the case of scalar tensor output + if not eoutput: + eoutput = ['1'] anno = OpAnno.create_op_str([einput], [eoutput]) return IRDimops(Sum, 'sum', signature, [anno], [input], dim=dim, keepdim=keepdim) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index d6eefb4c..8ff8de56 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -979,6 +979,36 @@ def test_codegen_kwargs(tmp_path): ) +class ScalarTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.proj = torch.nn.Linear(1024, 1024, bias=False) + self.scale = torch.nn.Parameter(torch.zeros(64)) + + def forward(self, x): + x = self.proj(x) + coef = torch.exp(torch.sum(self.scale, dim=-1)) + x = x / coef + return x.sum() + + +@replace_all_device_with('cpu') +def test_codegen_scalar_tensor(tmp_path): + m = ScalarTensorModule() + m.train() + parallelize( + m, + {'x': torch.randn(1024, 1024)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # parallelize will succeed. + assert True + + class ConvTranspose1DModule(torch.nn.Module): def __init__(self, weight, bias=None, stride=1, padding=0, output_padding=0, dilation=1, groups=1): super().__init__() From 5368dad7d3d4d6023d1eb97b20593165d123d0bc Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Tue, 16 Jul 2024 08:15:57 +0000 Subject: [PATCH 1683/1892] Merged PR 2201: Add release pipeline and update dev version to v0.2 Adds a pipeline to upload release wheel to devops artifact and test.pypi.org. (Has already been run basing on 0.1 tag) And then update `version.py` to 0.2. Don't want to bother create a separate PR. Pipeline usage: 1. Open the pipeline webpage: https://msrasrg.visualstudio.com/SuperScaler/_build?definitionId=116 2. Click "Run pipeline" 3. Choose the branch/tag 4. Click "Variables" 5. Click "version" and set the value to something like "v0.1" 6. Confirm update and run (the update will not be saved and must be done every time) --- nnscaler/version.py | 2 +- pipelines/release.yaml | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 pipelines/release.yaml diff --git a/nnscaler/version.py b/nnscaler/version.py index 2273bc24..506a4934 100644 --- a/nnscaler/version.py +++ b/nnscaler/version.py @@ -1 +1 @@ -__version__ = '0.1.dev0' +__version__ = '0.2.dev0' diff --git a/pipelines/release.yaml b/pipelines/release.yaml new file mode 100644 index 00000000..b7a06433 --- /dev/null +++ b/pipelines/release.yaml @@ -0,0 +1,39 @@ +# depends on two variables: +# +# - version +# must be set on devops website for each run +# the value should be something like "0.1" or "v0.1a1" (w/ or w/o leading v) +# +# - test_pypi_token +# secret, should never expire +# to view it or to update it, check onenote accounts/pypi page (test.pypi token) + +trigger: none +pr: none + +pool: + vmImage: ubuntu-latest + +steps: +- task: TwineAuthenticate@1 + inputs: + artifactFeed: SuperScaler/release + +- script: | + python -m pip install --upgrade build twine + displayName: prepare environment + +- script: | + python pipelines/scripts/update_version.py $(version) + python -m build + number_of_wheels=`ls dist/*.whl | wc -l` + test $number_of_wheels -eq 1 + displayName: build wheel + +- script: | + python -m twine upload -r release --config-file $(PYPIRC_PATH) dist/*.whl + displayName: upload to artifact + +- script: | + python -m twine upload -r testpypi -p $(test_pypi_token) dist/*.whl + displayName: upload to testpypi From c59f7b1dd334aa609e4b79f6e5f4d8dc3352a383 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 16 Jul 2024 08:33:14 +0000 Subject: [PATCH 1684/1892] Merged PR 2202: add scalar tensor support 1. Add a flag to IRTensor to indicate whether it is originally a scalar tensor. 2. During graph transformation, do as it is. 3. When generate code, check the flag to generate correct code. unit test pass parity check pass --- nnscaler/graph/graph.py | 16 +- nnscaler/graph/parser/fx/parser.py | 9 +- nnscaler/graph/segment.py | 44 ++--- nnscaler/ir/adapter/prim.py | 16 +- nnscaler/ir/cten.py | 126 +++++++------ nnscaler/ir/tensor.py | 249 +++++++++++++++++++------ tests/graph/function/test_functions.py | 18 +- tests/ir/{tensor.py => test_tensor.py} | 0 tests/parallel_module/test_end2end.py | 9 + 9 files changed, 324 insertions(+), 163 deletions(-) rename tests/ir/{tensor.py => test_tensor.py} (100%) diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 79f57f07..48275d7a 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -73,7 +73,7 @@ def forward(self, *args: Tuple[IRObject]) -> Union[IRTensor, Tuple[IRTensor]]: """ if not all(isinstance(arg, IRObject) for arg in args): raise TypeError("Expected input arguments to be IRObject") - + # align graph with input tensors iobjs: Tuple[IRObject, ...] = self.inputs() if len(args) != len(iobjs): @@ -560,7 +560,7 @@ def reside(self, tensor: IRSubTensor, devices: Union[int, List[int]]): def sequential(self, prev_nodes: Tuple[IRFwOperation], succ_nodes: Tuple[IRFwOperation]): """Schedule primitive: schedule prev_nodes right before the succ_nodes - + The position of `succ_nodes` will keep unchanged in the sequence while the `prev_nodes` will be scheduled right before the `succ_nodes`. Corresponding backward operators will also be re-ordered. @@ -595,7 +595,7 @@ def sequential(self, prev_nodes: Tuple[IRFwOperation], succ_nodes: Tuple[IRFwOpe if len(set(prev_indices).intersection(set(succ_indices))) != 0: raise ValueError(f'find duplicated node in both succ_nodes and prev_nodes') # TODO: check dependency - + seq = list(self._nodes) # cut out prev_nodes fstart, fend = min(prev_indices), max(prev_indices) + 1 @@ -635,7 +635,7 @@ def depends(self, pre_node: IRCell, succ_node: IRCell) -> bool: Note this function only checks direct data dependency that whether the outputs in `prev_node` and inputs in `post_node` have data dependency. - + The function cannot detect data dependency in graph like: pre_node -> (some nodes) ... -> post_node @@ -749,7 +749,7 @@ def blocking(self, nodes: Tuple[IRFwOperation]): Args: nodes Tuple[IRFwOperations]: the start forward node of each stage. - + Returns: None """ @@ -799,7 +799,7 @@ def blocking(self, nodes: Tuple[IRFwOperation]): assert all(isinstance(node, IRFwOperation) for node in fnodes), \ f"find at least one nodes are not of IRFwOperation in the stage {sid}. They should be moved to the front" fstages.append(fnodes) - + # grouping into segment for sid in range(len(fstages)): self.group(fstages[sid]) @@ -1115,7 +1115,7 @@ def load(filename: str): def checksum(self, strict: bool = True) -> str: """Get the MD5 checksum of the graph. - + This is used to guarantee the consistency of the graph between multiple nodes. @@ -1123,7 +1123,7 @@ def checksum(self, strict: bool = True) -> str: The checksum considers the IDGenerator status. If the user modifies the IDGenerator status (i.e., creating tensors or nodes), it will have a different checksum. - + Args: strict (bool): If True (by default), get the checksum of the whole graph status, including tensor shapes, tensor ids and node ids; diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index e1918242..d9800264 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -129,8 +129,6 @@ def meta2var(meta: Any) -> Any: """Support complex data type of List, Tuple, Dict, Tensor/Object""" if isinstance(meta, TensorMetadata): shape = meta.shape - # TODO: support scalar type - shape = torch.Size([1]) if shape == torch.Size([]) else shape dtype = meta.dtype requires_grad = meta.requires_grad return IRFullTensor(shape=shape, name=node.name, @@ -265,6 +263,13 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule else: output_val = frame.get_var(node.name) if isinstance(ir_node, IRDimops): + # TODO: refine here + # infer_type actually just check whether the annoation is consistent + # with actual output + # internally it will set the shape of output, + # but the output is quickly rewritten by the actual output + # in following code `ir_node.set_output(0, output_val)` + # So the scalar-tensor flag is not removed with `infer_shape` ir_node.infer_shape() if isinstance(output_val, IRTensor) and isinstance(ir_node.output(0), IRTensor): assert output_val.shape == ir_node.output(0).shape, ( diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index feeb0934..39be3049 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -23,7 +23,7 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: assert isinstance(other, CellPosition), "Cannot compare with non-GraphIndex object" return self.indices == other.indices - + def __lt__(self, other: object) -> bool: assert isinstance(other, CellPosition), "Cannot compare with non-GraphIndex object" if len(self.indices) < len(other.indices): @@ -203,7 +203,7 @@ def node(self, index: Union[int, CellPosition]) -> IRCell: def index(self, node: IRCell) -> CellPosition: """ - Get node index. The dispatched node (e.g., IRAdapter, IRSegment) + Get node index. The dispatched node (e.g., IRAdapter, IRSegment) will return the index to its un-dispatched node @param node IRCell: the queried node @@ -293,7 +293,7 @@ def ptensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: A full tensor (ftensor) is originally produced by some operator(s). These operators can be further partitioned into multiple sub-operators. Each sub-operator potentially produces a smaller part of the ftensor (a.k.a. sub-tensor). - This function returns all the sub-tensors that are produced by operators + This function returns all the sub-tensors that are produced by operators inside the segment. Args: @@ -310,7 +310,7 @@ def ctensors(self, ftensor: IRFullTensor) -> Tuple[IRSubTensor]: A full tensor (ftensor) is originally consumed by some operator(s). These operators can be further partitioned into multiple sub-operators. Each sub-operator potentially consumes a smaller part of the ftensor (a.k.a. sub-tensor). - This function returns all the sub-tensors that are consumed by operators + This function returns all the sub-tensors that are consumed by operators inside the segment. Args: @@ -329,15 +329,15 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: applied for this graph. If a tensor is consumed by multiple consumers, the value map of its gradient - will be in exponential format. + will be in exponential format. E.g., t has consumed by node1, node2, node3 and node4. Then the gradient value_map of t (t.grad) of each consumer is (idx, nchunks): (0, 2), (2, 4), (6, 8), (7, 8), where: (0, 2) + (2, 4) + (6, 8) + (7, 8) - = (0, 2) + (2, 4) + (3, 4) - = (0, 2) + (1, 2) + = (0, 2) + (2, 4) + (3, 4) + = (0, 2) + (1, 2) = FULL VALUE @param ftensor IRFullTensor: the full tensor. @@ -362,7 +362,7 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: grad = None if fgrad is None else fgrad.select(ptensor.indmap, (0, 1)) for t in producer.find(ptensor): t.grad = grad - + # set for consumers consumers, ctensors = [], [] # consumers that require gradient for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): @@ -437,7 +437,7 @@ def _add_ftensor(self, ftensor: IRObject): self._ctensors[ftensor] = [] if ftensor.is_attr(): self._attributes.add(ftensor) - + def _remove_ftensor(self, ftensor: IRObject): """ Remove a full tensor in segment @@ -533,7 +533,7 @@ def remove(self, node: IRCell, _pos: Union[int, CellPosition] = None) -> CellPos Args: node (IRCell): the removed node _pos (Optional[Union[int, CellPosition]): help to save cost if provide node position. - + Returns: CellPosition: the removed index """ @@ -603,15 +603,15 @@ def reorder(self, node: IRCell, index: int): @contextmanager def update(self, node): """ - Update a node. Note the related change in backward operator + Update a node. Note the related change in backward operator will not be automatically updated. - + TODO: update operator dependency e.g., with graph.modify(node) as node: node.set_input(0, tensor) - + @param node IRCell: the node that must in the graph @return node IRCell: the modify node """ @@ -645,7 +645,7 @@ def select(self, name: Optional[str] = None, ntype: Optional[IRCell] = None, fla IRSegment, turn `flatten=False` will get the same result as `flatten=True`, and can save more time because `flatten=False` will not traverse the nodes in IRSegment. - + Args: name (Optional[str]): the node name ntype (Optional[Type]): the node type @@ -681,7 +681,7 @@ def finsert(self, fwop: IRFwOperation, index: Union[int, CellPosition]) -> IRFwO assert isinstance(fwop, IRFwOperation), "Only allow insert an IRFwOperation" pos = CellPosition((index,)) if isinstance(index, int) else index assert isinstance(pos, CellPosition), "Expect index to be int or CellPosition" - + index = pos.indices[-1] fsegment = self if len(pos) == 1 else self.node(CellPosition(pos.indices[1:])) fsegment.insert(fwop, index) @@ -720,7 +720,7 @@ def multiref(self, ftensor: IRFullTensor, *deprecated_args) -> IRFwOperation: # check no transformation assert len(self.ptensors(ftensor)) <= 1, f"no transformation should be called before multiref" assert len(set(self.ctensors(ftensor))) == 1, f"no transformation should be called before multiref" - + # create new full tensors consumers = self.consumers(ftensor) tensor = self.ctensors(ftensor)[0] @@ -768,7 +768,7 @@ def multiref(self, ftensor: IRFullTensor, *deprecated_args) -> IRFwOperation: def single_consume(self, one_for_all: bool = True): """ Transform graph to make each non-attribute tensor has up to - one consumer. Multiref nodes will be inserted. The API is useful + one consumer. Multiref nodes will be inserted. The API is useful for cases like inference, where different consumers are partitioned with different tensor dimensions. @@ -890,14 +890,14 @@ def single_consume(self, one_for_all: bool = True): self.insert(multiref, idx) # ====================== Graph Generations ============================ - + @staticmethod def get_inputs(nodes: List[IRCell], exclude_attr: bool = True): """ Get all the input tensors that are required by nodes. @param nodes List[IRCell]: the nodes - + @return inputs List[IRTensor]: the input tensors """ all_outputs = list() @@ -961,7 +961,7 @@ def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> I attr_as_inputs (bool): whether to treat attributes as segment inputs Returns: - segment (IRSegment): the grouped segment. + segment (IRSegment): the grouped segment. """ segment = self segment_outputs = IRSegment.get_objects_from_complex(segment.outputs()) @@ -981,7 +981,7 @@ def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> I # tensor and its device match dmatch = lambda t1, t2: t1 == t2 and t1.device == t2.device - + inputs, outputs = set(), set() sub_cids = set(node.cid for node in nodes) for node in nodes: @@ -1056,7 +1056,7 @@ def order(tensors: Set[IRObject]) -> Tuple[IRObject]: tids = np.array([t.parent.tid for t in tensors]) indices = np.argsort(tids) return tuple(tensors[idx] for idx in indices) - + if self.isfw(): inputs, outputs = order(inputs), order(outputs) diff --git a/nnscaler/ir/adapter/prim.py b/nnscaler/ir/adapter/prim.py index e89900d7..47e4583f 100644 --- a/nnscaler/ir/adapter/prim.py +++ b/nnscaler/ir/adapter/prim.py @@ -180,7 +180,7 @@ class MovePrim(CommPrim): def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): if len(kwargs) == 0: assert len(itensors) == 1 and len(otensors) == 1 - kwargs['shape'] = itensors[0].shape + kwargs['shape'] = itensors[0].origin_shape kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None @@ -222,7 +222,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim """ if len(kwargs) == 0: assert len(itensors) == 1 - kwargs['shape'] = tuple(itensors[0].shape) + kwargs['shape'] = itensors[0].origin_shape kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None kwargs['dsts'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) @@ -253,7 +253,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **k """ if len(kwargs) == 0: assert len(itensors) == 1 - kwargs['shape'] = tuple(itensors[0].shape) + kwargs['shape'] = itensors[0].origin_shape kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None kwargs['dsts'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) @@ -278,7 +278,7 @@ class RDGatherPrim(CommPrim): def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim: int, **kwargs): if len(kwargs) == 0: assert len(otensors) == 1 - kwargs['shape'] = tuple(itensors[0].shape) # the input tensor shape + kwargs['shape'] = itensors[0].origin_shape kwargs['dtype'] = str(itensors[0].dtype) kwargs['srcs'] = tuple(itensor.device[0] if len(itensor.device) > 0 else None for itensor in itensors) kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None @@ -288,7 +288,7 @@ def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], dim def volume(self) -> int: return self.output(0).nelement() - + def __repr__(self) -> str: inputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.inputs()) outputs = ', '.join(f'{t.name}{t.tid}{t.shape}{t.valmap}' for t in self.outputs()) @@ -303,7 +303,7 @@ class RVGatherPrim(CollectivePrim): def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): if len(kwargs) == 0: assert len(otensors) == 1 - kwargs['shape'] = tuple(itensors[0].shape) + kwargs['shape'] = itensors[0].origin_shape kwargs['dtype'] = str(itensors[0].dtype) kwargs['srcs'] = tuple(otensor.device[0] if len(otensor.device) > 0 else None for otensor in otensors) kwargs['dst'] = otensors[0].device[0] if len(otensors[0].device) > 0 else None @@ -327,7 +327,7 @@ class BroadcastPrim(CollectivePrim): def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): if len(kwargs) == 0: assert len(itensors) == 1 - kwargs['shape'] = tuple(itensors[0].shape) + kwargs['shape'] = itensors[0].origin_shape kwargs['dtype'] = str(itensors[0].dtype) kwargs['src'] = itensors[0].device[0] if len(itensors[0].device) > 0 else None super().__init__(itensors, otensors, **kwargs) @@ -460,7 +460,7 @@ class VChunkPrim(CollectivePrim): def __init__(self, itensors: List[IRSubTensor], otensors: List[IRSubTensor], **kwargs): super().__init__(itensors, otensors, **kwargs) self.signature = 'nnscaler.runtime.adapter.vchunk' - + def volume(self) -> int: return 0 diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 94ecee5a..cd268749 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -541,46 +541,84 @@ class IRTensor(IRObject): Note by setting IRTensor name to "None" indicates this tensor holds nothing and will be translated to None in code generation. - """ - - _meta = ['name', '_is_attr', '_is_grad', '_requires_grad', '_dtype', '_persistent'] - def __init__(self, shape=None, name='tensor', dtype=None, tid=None): + Note scalar tensors will always be converted to 1-d tensors. + So all further operations could ignore the scalar tensor case. + You can get the original shape with `origin_shape` property. + """ + def __init__(self, shape=None, name='tensor', dtype=None, tid=None, *, + is_attr=False, is_grad=False, requires_grad=False, persistent=False + ): super().__init__(name, tid, is_constant=False) - self._shape: Tuple[int] = () if shape is None else tuple(shape) - self._cell: Optional[IRCell] = None - self._dtype: Optional[torch.dtype] = dtype + self._is_scalar_tensor: bool = True + self._shape: Tuple[int] = () + self._dtype: Optional[torch.dtype] = None # tensor gradient self._is_grad: bool = False self._requires_grad: bool = False - self._grad: Optional[Union[IRTensor, float]] = None # _persistent is a buffer only field, but in inference mode all params will be post-processed to buffers, # so set _persistent True in as_param() for register these params to persistent buffers. - self._persistent = False + self._persistent: bool = False + self._update( + shape=shape if shape is not None else (), + name=name, + dtype=dtype, + is_attr=is_attr, + is_grad=is_grad, + requires_grad=requires_grad, + persistent=persistent, + ) + self._cell: Optional[IRCell] = None + + def _update( + self, + shape=None, + name=None, + dtype=None, + is_attr=None, + is_grad=None, + requires_grad=None, + persistent=None, + ): + """ + Set tensor metadata + """ + if shape is not None: + self._is_scalar_tensor = not shape + # will always convert scalar tensor to 1-d tensor + self._shape: Tuple[int] = (1,) if not shape else tuple(shape) + if name is not None or self.name is None: + self.name = name + if dtype is not None: + if not isinstance(dtype, torch.dtype): + raise ValueError( + "Only support setting IRTensor with dtype of torch.dtype" + ) + self._dtype = dtype + if is_attr is not None: + self._is_attr = is_attr + if is_grad is not None: + self._is_grad = is_grad + if requires_grad is not None: + self._requires_grad = requires_grad + if persistent is not None: + self._persistent = persistent + + return self @property def dtype(self) -> Optional[torch.dtype]: """Tensor data type""" return self._dtype - @dtype.setter - def dtype(self, val: Optional[torch.dtype]): - """Set data type""" - if not isinstance(val, torch.dtype): - raise NotImplementedError( - "Only support setting IRTensor with dtype of torch.dtype") - self._dtype = val - if isinstance(self._grad, IRTensor): - self._grad._dtype = val - def is_param(self) -> bool: """! Check if the tensor is parameter @return is_param boolean: True if is parameter. """ - return self._is_attr and self.requires_grad + return not self._is_grad and self._is_attr and self._requires_grad def is_buffer(self) -> bool: """! @@ -588,7 +626,7 @@ def is_buffer(self) -> bool: @return is_buffer boolean: True if is buffer. """ - return self._is_attr and not self.requires_grad + return not self._is_grad and self._is_attr and not self._requires_grad def is_persistent(self) -> bool: """! @@ -606,35 +644,13 @@ def is_grad(self) -> bool: """ return self._is_grad - def as_param(self): - """ - Set the tensor as trainable parameter - """ - assert self._grad is not None, "missing grad tensor" - self._requires_grad = True - self._is_attr = True - self._is_grad = False - self._persistent = True - return self + def is_scalar_tensor(self) -> bool: + """! + Check if the tensor is scalar tensor - def as_buffer(self, persistent=True): - """ - Set the tensor as un-trainable buffer + @return is_scalar_tensor boolean: True if is scalar tensor """ - self._requires_grad = False - self._is_attr = True - self._is_grad = False - self._persistent = persistent - return self - - def as_grad(self): - """ - Set the tensor as gradient - """ - self._is_param = False - self._is_attr = False - self._is_grad = True - return self + return self._is_scalar_tensor @property def requires_grad(self) -> bool: @@ -655,6 +671,14 @@ def __copy__(self): tensor.cell = None return tensor + @property + def origin_shape(self) -> Tuple[int]: + """ + Get the original shape of the tensor + (Because self.shape will convert scalar tensor to 1-dim tensor) + """ + return self.shape if not self.is_scalar_tensor() else () + @property def shape(self) -> Tuple[int]: # NOTE: here return a tuple but not a real torch.Size obj may have risk, here is an example: @@ -662,12 +686,6 @@ def shape(self) -> Tuple[int]: # (torch.Size + list -> torch.Size) will change to (tuple + list -> error), is wrong. return self._shape - @shape.setter - def shape(self, val: Tuple[int]): - self._shape = tuple(val) - if isinstance(self._grad, IRTensor): - self._grad.shape = tuple(val) - def nelement(self) -> int: """ Get total number of element in the tensor. diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index 94a7553f..09312e76 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -25,6 +25,7 @@ """ from typing import List, Optional, Union, Tuple, NewType, Dict, Any +import torch from nnscaler.ir.cten import IRTensor @@ -255,15 +256,58 @@ class IRFullTensor(IRTensor): the sequentail execution order by its graph. """ - def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=None): - - super().__init__(shape, name, dtype) - + def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=None, *, + is_attr=False, is_grad=False, persistent=False, is_loss=False + ): + self._is_loss: bool = False # record all created sub_tensors self._subtensors : Dict[(ValueMap, IndexMap), int] = dict() + self._grad: Optional[IRFullTensor] = None - self.requires_grad = requires_grad - self._is_loss = False + super().__init__(shape, name, dtype, requires_grad=requires_grad, is_attr=is_attr, is_grad=is_grad, persistent=persistent) + self._update( + is_loss=is_loss, + ) + + def _update( + self, + shape=None, + name=None, + dtype=None, + is_attr=None, + is_grad=None, + requires_grad=None, + persistent=None, + is_loss=None, + ): + super()._update( + shape=shape, + name=name, + dtype=dtype, + is_attr=is_attr, + is_grad=is_grad, + requires_grad=requires_grad, + persistent=persistent, + ) + + # reset grad + if self._requires_grad: + if self._grad is None: + self._grad = self.like_grad() + else: + self._grad = None + + if is_loss is not None: + self._is_loss = is_loss + + if self._grad is not None: + self._grad._update( + shape=self.origin_shape, + is_attr=self._is_attr, + dtype=self._dtype, + is_loss=self._is_loss, + ) + return self def __hash__(self) -> int: return self._id @@ -279,15 +323,28 @@ def __copy__(self): def like(self): """! - Create a IRFullTensor with same meta data but a different id. + Create a IRFullTensor with same name/shape/dtype/requires_grad/is_loss but a different id. @return tensor IRFullTensor: the created tensor """ - tensor = IRFullTensor(self.shape, self.name, self.requires_grad, self.dtype) - if self.is_loss(): - tensor.to_loss() + tensor = IRFullTensor( + self.origin_shape, self.name, self._requires_grad, + self._dtype, is_loss=self._is_loss + ) return tensor + def like_grad(self): + """! + Create a gradient IRFullTensor with same shape, dtype and is_attr. + + @return tensor IRFullTensor: the created tensor + """ + grad = IRFullTensor( + self.origin_shape, 'g' + self.name, + requires_grad=False, dtype=self.dtype + ).as_grad(self._is_attr) + return grad + @property def grad(self) -> Optional[IRTensor]: return self._grad @@ -297,11 +354,17 @@ def grad(self, val: Optional[IRTensor]): """ Setup gradient for the tensor. """ + assert not self._subtensors, "Grad can only be updated before creating sub-tensors" assert val is None or isinstance(val, IRFullTensor) if val is not None: assert self._requires_grad, f"Cannot assign {val} to no grad-required tensor" - assert val.shape == self.shape + assert val.origin_shape == self.origin_shape assert val.is_attr() == self.is_attr() + # TODO: we should check the grad-required here + # it is very common in current code that we assign None to grad + # so currently it is impossible to check the grad-required here + # else: + # assert not self._requires_grad, f"Cannot assign {val} to grad-required tensor" self._grad = val def is_loss(self) -> bool: @@ -317,53 +380,54 @@ def to_loss(self): Set this tensor as loss tensor. The tensor shape must be [1,] """ assert tuple(self.shape) == (1,), f"Loss tensor can only have shape [1,] but got {self.shape}" - self._is_loss = True - if isinstance(self.grad, IRFullTensor): - self.grad._is_loss = True + self._update(is_loss=True) - @property - def requires_grad(self): - return self._requires_grad + @IRTensor.dtype.setter + def dtype(self, val: Optional[torch.dtype]): + """Set data type""" + assert not self._subtensors, "Cannot change dtype after creating sub-tensors" + self._update(dtype=val) - @requires_grad.setter + @IRTensor.requires_grad.setter def requires_grad(self, req_grad: bool): - if req_grad: - self._requires_grad = True - if self._grad is None: - grad = IRFullTensor( - self.shape, 'g' + self.name, - requires_grad=False, dtype=self.dtype - ).as_grad(self.is_attr()) - self._grad = grad - else: - self._requires_grad = False - self._grad = None + self._update(requires_grad=req_grad) + + @IRTensor.shape.setter + def shape(self, val: Tuple[int]): + assert not self._subtensors, "Cannot change shape after creating sub-tensors" + self._update(shape=val) + + def as_attr(self): + raise RuntimeError("as_attr is ambiguous for FullTensor, use as_param or as_buffer instead") def as_param(self): """ Set the tensor as trainable parameter """ - self.requires_grad = True - self._is_attr = True - self._is_grad = False - self._persistent = True - if isinstance(self.grad, IRFullTensor): - self.grad._is_attr = True + return self._update( + requires_grad=True, + is_attr=True, + is_grad=False, + persistent=True, + ) def as_buffer(self, persistent=True): """ Set the tensor as un-trainable buffer """ - self.requires_grad = False - self._is_attr = True - self._is_grad = False - self._persistent = persistent + return self._update( + requires_grad=False, + is_attr=True, + is_grad=False, + persistent=persistent, + ) def as_grad(self, of_attr: bool = False): - self._attr = True if of_attr else False - self.requires_grad = False - self._is_grad = True - return self + return self._update( + requires_grad=False, + is_attr=of_attr, + is_grad=True, + ) def select(self, indmap: IndexMap, valmap: ValueMap): """! @@ -384,6 +448,7 @@ def select(self, indmap: IndexMap, valmap: ValueMap): else: sub_tensor = IRSubTensor(self, indmap, valmap) self._subtensors[keys] = sub_tensor.tid + return sub_tensor def tosub(self): @@ -413,7 +478,6 @@ def extra_repr(self) -> str: class IRSubTensor(IRTensor): - def __init__(self, ftensor: IRFullTensor, indmap: Union[Tuple[StartEnd], IndexMap], valmap: Union[Tuple[StartEnd], ValueMap], @@ -429,15 +493,23 @@ def __init__(self, ftensor: IRFullTensor, assert isinstance(ftensor, IRFullTensor), "Expcted ftensor to be IRFullTensor" assert 'dtype' not in kwargs, "IRSubTensor is not allowed to initialize with a dtype" super().__init__(shape=indmap.shape, name=ftensor.name, **kwargs) - for attr in IRFullTensor._meta: - setattr(self, attr, getattr(ftensor, attr)) - self.cell = None # the full tensor self._full_tensor = ftensor + + # remove the redundant attributes to avoid misuse. + # they will be updated from parent + del self._dtype + del self._is_attr + del self._is_grad + del self._requires_grad + del self._persistent + + self.cell = None # the index from full_tensor self._indmap: IndexMap = indmap # val map self._valmap: ValueMap = valmap + self._grad: Optional[IRSubTensor] = None def __eq__(self, other) -> bool: if isinstance(other, IRSubTensor): @@ -475,16 +547,8 @@ def valmap(self) -> IdxChunk: def ndims(self) -> int: return len(self.shape) - @property - def dtype(self) -> Any: - return self.parent.dtype - - @dtype.setter - def dtype(self, val): - raise RuntimeError( - f"IRSubTensor dtype must follow IRFullTensor dtype. " - f"Please set it by subtensor.parent.dtype = {val}" - ) + def as_attr(self): + raise RuntimeError("as_attr is not allowed for SubTensor") def splitdims(self) -> Tuple[int]: """! @@ -601,15 +665,74 @@ def __copy__(self): tensor._cell = None return tensor + # forward all property to parent + @property + def dtype(self) -> Optional[torch.dtype]: + """Tensor data type""" + return self.parent.dtype + + @IRTensor.shape.setter + def shape(self, val: Tuple[int]): + # TODO: remove this function + # It is not reasonable to set shape for a subtensor. + # But there are codes doing that in current repo. + + # Here we check against self.shape instead of self.origin_shape + # because the assignment of shape is done after we erase all scalar-tensor + assert tuple(val) == tuple(self.shape), 'Cannot modify shape of a sub-tensor.' + + def is_attr(self) -> bool: + """Check if the tensor is attribute""" + return self.parent.is_attr() + + def is_param(self) -> bool: + """! + Check if the tensor is parameter + + @return is_param boolean: True if is parameter. + """ + return self.parent.is_param() + + def is_buffer(self) -> bool: + """! + Check if the tensor is buffer. + + @return is_buffer boolean: True if is buffer. + """ + return self.parent.is_buffer() + + def is_persistent(self) -> bool: + """! + Check if the tensor is persistent buffer. + + @return is_persistent boolean: True if is persistent. + """ + return self.parent.is_persistent() + + def is_grad(self) -> bool: + """! + Check if the tensor is gradient + + @return is_grad boolean: True if is gradient + """ + return self.parent.is_grad() + + def is_scalar_tensor(self) -> bool: + """! + Check if the tensor is scalar tensor + + @return is_scalar_tensor boolean: True if is scalar tensor + """ + return self.parent.is_scalar_tensor() + @property def requires_grad(self) -> bool: - self._requires_grad = self.parent.requires_grad return self.parent.requires_grad @property def grad(self) -> bool: """Get the gradient of this tensor. - + The gradient is kept aligned with its parent IRFullTensor. """ if not self.requires_grad: @@ -620,12 +743,18 @@ def grad(self) -> bool: def grad(self, val: Optional[IRTensor]): """ Setup gradient for the tensor. + Currently unlike IRFullTensor, IRSubTensor's grad is never created automically """ assert val is None or isinstance(val, IRSubTensor) if val is not None: - assert self._requires_grad, f"Cannot assign {val} to no grad-required tensor" + assert self.requires_grad, f"Cannot assign {val} to no grad-required tensor" assert val.shape == self.shape assert val.is_attr() == self.is_attr() + # TODO: we should check the grad-required here + # it is very common in current code that we assign None to grad + # so currently it is impossible to check the grad-required here + # else: + # assert not self._requires_grad, f"Cannot assign {val} to grad-required tensor" self._grad = val def is_loss(self) -> bool: diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 0072aac8..aa1743bd 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -432,7 +432,7 @@ def test_Len(): assert op.outputs()[0].value == 3 and not op.outputs()[0].is_constant -def test_Min(): +def test_Min(): op = F.Min(IRTensor([2, 3, 4])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a^ b^ c^ -> 1' op = F.Min(IRTensor([2, 3, 4]), 1, True) @@ -465,10 +465,10 @@ def test_FullLike(): op = F.FullLike(IRTensor([2, 1, 4, 1]), 1.) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a b c d' op_int = F.FullLike(IRTensor([3, 2]), 5) - assert len(op_int._annos_candidates) == 1 and op_int._annos_candidates[0] == 'a b -> a b' + assert len(op_int._annos_candidates) == 1 and op_int._annos_candidates[0] == 'a b -> a b' op_true = F.FullLike(IRTensor([2, 2]), 1., requires_grad=True) - assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' - op_float = F.FullLike(IRTensor([1, 2],dtype=int), 1, dtype=torch.float) + assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' + op_float = F.FullLike(IRTensor([1, 2],dtype=torch.int), 1, dtype=torch.float) assert len(op_float._annos_candidates) == 1 and op_float._annos_candidates[0] == 'a b -> a b' @@ -488,8 +488,8 @@ def test_ZerosLike(): op = F.ZerosLike(IRTensor([2, 1, 4, 1])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a b c d' op_true = F.ZerosLike(IRTensor([2, 2]), requires_grad=True) - assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' - op_float = F.ZerosLike(IRTensor([1, 2],dtype=int), dtype=torch.float) + assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' + op_float = F.ZerosLike(IRTensor([1, 2],dtype=torch.int), dtype=torch.float) assert len(op_float._annos_candidates) == 1 and op_float._annos_candidates[0] == 'a b -> a b' @@ -497,8 +497,8 @@ def test_OnesLike(): op = F.OnesLike(IRTensor([2, 1, 4, 1])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c d -> a b c d' op_true = F.OnesLike(IRTensor([2, 2]), requires_grad=True) - assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' - op_float = F.OnesLike(IRTensor([1, 2],dtype=int), dtype=torch.float) + assert len(op_true._annos_candidates) == 1 and op_true._annos_candidates[0] == 'a b -> a b' + op_float = F.OnesLike(IRTensor([1, 2],dtype=torch.int), dtype=torch.float) assert len(op_float._annos_candidates) == 1 and op_float._annos_candidates[0] == 'a b -> a b' @@ -595,7 +595,7 @@ def test_Softmax(): op = F.Softmax(IRTensor([2, 3, 4]), dtype=torch.float32) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'a b c -> a b c' - + def test_Conv1D(): op = F.Conv1D(IRTensor([3, 4]), IRTensor([3, 3, 1])) assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == 'iC+ 4, oC iC+ 1 -> oC 4' diff --git a/tests/ir/tensor.py b/tests/ir/test_tensor.py similarity index 100% rename from tests/ir/tensor.py rename to tests/ir/test_tensor.py diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index de18cbc9..402aa949 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -171,16 +171,25 @@ def test_end2end(): assert len(ga4_result) == 16 cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid', True) # micro_batch_size = 4 + for _, v in cube2_results.items(): + # all losses should be scalar tensor + assert all(i.shape == () for i in v[1]) cube2_result = merge_cube_result({k: v[0] for k, v in cube2_results.items()}) assert len(cube2_result) == 16 allclose(cube2_result, ga4_result) cube4_results = launch_torchrun(4, gpu_worker_cube, 4, 4, PASMegatron, True) # micro_batch_size = 4 + for _, v in cube2_results.items(): + # all losses should be scalar tensor + assert all(i.shape == () for i in v[1]) cube4_result = merge_cube_result({k: v[0] for k, v in cube4_results.items()}) assert len(cube4_result) == 16 allclose(cube4_result, ga4_result) cube2_results_non_pipeline = launch_torchrun(4, gpu_worker_cube, 4, 2, 'tp', False) # micro_batch_size = 4 + for _, v in cube2_results.items(): + # all losses should be scalar tensor + assert all(i.shape == () for i in v[1]) cube2_result_non_pipeline = merge_cube_result({k: v[0] for k, v in cube2_results_non_pipeline.items()}) assert len(cube2_result_non_pipeline) == 16 allclose(cube2_result_non_pipeline, ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error From 106bbf20b63e10a3dbccf7031906b9386692831c Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 17 Jul 2024 06:26:14 +0000 Subject: [PATCH 1685/1892] Merged PR 2204: fix embedding padding index --- nnscaler/graph/function/function.py | 2 ++ nnscaler/runtime/function/function.py | 25 ++++++++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index 7c0aca8c..eea632f7 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1939,6 +1939,8 @@ def CubeEmbedding(input, weight, padding_idx, signature = None, **kwargs): start, stop = weight.indmap[0] else: start, stop = 0, weight.shape[0] + # here we can split the vocab dim with `+`, because we rewrite the embedding logic to ensure the result is right + # please review nnscaler.runtime.function.embedding for more information annos = ['*, n+ e -> * e'] return IRDimops(CubeEmbedding, 'embedding', signature, annos, [input, weight], padding_idx=padding_idx, start=start, stop=stop) diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 018a3efc..9cc5f4a3 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -101,21 +101,40 @@ def conv3d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso def embedding(input: torch.Tensor, weight: torch.Tensor, padding_idx: Optional[int], start: int, stop: int): """ - Embedding + add start/stop to make vocab dim partitionable. + + for example, if the vocab size is 100, and partition the weigth on vocab dim to 4 part, + then on each part, it will have different start/stop: + 1: [start=0, stop=25] + 2: [start=25, stop=50] + 3: [start=50, stop=75] + 4: [start=75, stop=100] + before do embedding, the input index outside the range will be masked, + and directly assign 0.0 to the masked position on the output. + + If vocab dim is partitioned, the results are summed to ensure the correctness of the final result. Inputs: input: torch.Tensor [*] weight: [vocab size, embed size] - start: int - stop: int + start: int, the weight split start index on vocab dim + stop: int, the weight split stop index on vocab dim Outputs: output: [*, embed_size] """ input = input.long() input_mask = (input < start) | (input >= stop) + # make the range of value in the input to [0, stop-start) + # note that the embedding is implemented like a look up table. masked_input = input.clone() - start masked_input[input_mask] = 0 + # if padding_idx is inside [start, stop), should map it to [0, stop-start) + # if padding_idx is outside [start, stop), directly make it None + if padding_idx is not None and start <= padding_idx < stop: + padding_idx -= start + else: + padding_idx = None output = TorchF.embedding( masked_input, weight, padding_idx, None, 2.0, False, False From 992b945f08172a17a0f3e26d810895c4b2b4853a Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 22 Jul 2024 06:46:05 +0000 Subject: [PATCH 1686/1892] Merged PR 2192: minitrainer: refine config MiniTrainer workable version. the parity check is against lightning. --- nnscaler/__init__.py | 3 +- nnscaler/cli/arg_parser.py | 28 +- nnscaler/cli/loggers/__init__.py | 2 + nnscaler/cli/loggers/logger_base.py | 24 + nnscaler/cli/loggers/tensorboard.py | 83 +++ nnscaler/cli/loggers/wandb.py | 53 ++ nnscaler/cli/train.py | 6 +- nnscaler/cli/train_hook.py | 141 +++++ nnscaler/cli/trainer.py | 550 +++++++++++++++--- nnscaler/cli/trainer_args.py | 409 ++++++++++--- .../integration/lightning/pytorch/strategy.py | 15 +- nnscaler/parallel.py | 35 +- nnscaler/utils.py | 101 +++- requirements-dev.txt | 2 + requirements.txt | 1 + tests/cli/common.py | 1 + tests/cli/test_arg_parser.py | 23 +- tests/cli/test_trainer.py | 171 +++++- tests/cli/trainer_args.yaml | 34 +- .../lightning/pytorch/simple_models.py | 13 + .../lightning/pytorch/test_strategy.py | 231 +++++++- tests/parallel_module/test_attr_dedup.py | 2 +- tests/parallel_module/test_checkpoint.py | 6 +- .../parallel_module/test_checkpoint_buffer.py | 2 +- .../parallel_module/test_checkpoint_dedup.py | 2 +- .../parallel_module/test_checkpoint_shared.py | 4 +- .../parallel_module/test_checkpoint_unused.py | 2 +- 27 files changed, 1712 insertions(+), 232 deletions(-) create mode 100644 nnscaler/cli/loggers/__init__.py create mode 100644 nnscaler/cli/loggers/logger_base.py create mode 100644 nnscaler/cli/loggers/tensorboard.py create mode 100644 nnscaler/cli/loggers/wandb.py create mode 100644 nnscaler/cli/train_hook.py diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index 63ba391f..f43567cd 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -7,10 +7,11 @@ parallelize, build_optimizer, merge_state_dicts, - load_merged_state_dicts, + load_merged_state_dict, deduped_state_dict, load_deduped_state_dict, broadcast_weights, + load_sharded_state_dict, ) from nnscaler.graph.parser.register import register_op diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index 927b9a4a..11b74606 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -199,10 +199,12 @@ def deserialize_dataclass(value, value_type): type_info = _get_type_info(value_type) member_values = {} - for k, ti in type_info.items(): - if not k in value: + used_keys = set() + for key, ti in type_info.items(): + if not key in value: continue - v = value[k] + used_keys.add(key) + v = value[key] if ti.type is bool and v is None: v = True # set bool to True if it shows up in cmd line if v is None: @@ -215,16 +217,24 @@ def deserialize_dataclass(value, value_type): ti.type = type(v) if ti.item_type or ti.key_type or ti.value_type: - if ti.type == list: - v = [_deserialize_object(x, ti.item_type) for x in v] - elif ti.type == tuple: - v = tuple(_deserialize_object(x, ti.item_type) for x in v) + if ti.type in (list, tuple): + if isinstance(v, (list, tuple)): + v = ti.type(_deserialize_object(x, ti.item_type) for x in v) + elif isinstance(v, dict): + v_dict = {_deserialize_object(k, int): _deserialize_object(v, ti.item_type) for k, v in v.items()} + v = [None] * (max(v_dict.keys()) + 1) + for k, x in v_dict.items(): + v[k] = x + v = ti.type(v) + else: + raise ValueError(f"Invalid value {v} for {value_type}") elif ti.type == dict: v = {_deserialize_object(k, ti.key_type): _deserialize_object(v, ti.value_type) for k, v in v.items()} else: v = _deserialize_object(v, ti.type) if v is not None: # for none values, use default value. - member_values[k] = v - + member_values[key] = v + if set(value.keys()) - used_keys: + raise ValueError(f"Unknown members {set(value.keys()) - used_keys} for {value_type}") return value_type(**member_values) diff --git a/nnscaler/cli/loggers/__init__.py b/nnscaler/cli/loggers/__init__.py new file mode 100644 index 00000000..20900fb3 --- /dev/null +++ b/nnscaler/cli/loggers/__init__.py @@ -0,0 +1,2 @@ +from .tensorboard import TensorBoardLogger +from .wandb import WandbLogger diff --git a/nnscaler/cli/loggers/logger_base.py b/nnscaler/cli/loggers/logger_base.py new file mode 100644 index 00000000..a507bf1e --- /dev/null +++ b/nnscaler/cli/loggers/logger_base.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from typing import Optional, Dict + + +class LoggerBase(ABC): + """ + Base class for experiment loggers. + """ + + @abstractmethod + def setup(self, config: Dict) -> None: + """ + Setup logger with trainer args. This is useful for saving hyperparameters. + Will be called once before `log_metrics` + """ + ... + + @abstractmethod + def log_metrics(self, metrics: Dict[str, float], step: int) -> None: + ... + + @abstractmethod + def finalize(self) -> None: + ... diff --git a/nnscaler/cli/loggers/tensorboard.py b/nnscaler/cli/loggers/tensorboard.py new file mode 100644 index 00000000..ed20a92c --- /dev/null +++ b/nnscaler/cli/loggers/tensorboard.py @@ -0,0 +1,83 @@ +import atexit +from pathlib import Path +from typing import Dict, Optional +from datetime import datetime + +import yaml +import torch +try: + _tensorboard_writers = [] + from torch.utils.tensorboard import SummaryWriter +except ImportError: + SummaryWriter = None + +from nnscaler.utils import rank_zero_only +from .logger_base import LoggerBase + + +class TensorBoardLogger(LoggerBase): + def __init__( + self, + name: str, + root_dir: str, + **kwargs, + ): + if SummaryWriter is None: + raise RuntimeError( + "tensorboard not found, please install with: pip install tensorboard" + ) + + super().__init__() + self._name = name + self._root_dir = Path(root_dir).expanduser().resolve() + self._kwargs = kwargs + + self._summary_writer = None + + @property + def log_dir(self) -> str: + """ + Root directory to save logging output, which is `_log_dir/_name`. + """ + sub_path = [s for s in [self._name] if s] + ld = self._root_dir.joinpath(*sub_path) + ld.mkdir(parents=True, exist_ok=True) + return str(ld) + + @rank_zero_only + def setup(self, config: Dict) -> None: + self._ensure_writer() + self._summary_writer.add_text("config", yaml.dump(config)) + + def _ensure_writer(self): + if not self._summary_writer: + self._summary_writer = SummaryWriter(log_dir=self.log_dir, **self._kwargs) + _tensorboard_writers.append(self._summary_writer) + return self._summary_writer + + @rank_zero_only + def log_metrics(self, metrics: Dict[str, float], step: int) -> None: + self._ensure_writer() + + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + + if isinstance(v, dict): + self._summary_writer.add_scalars(k, v, step) + else: + self._summary_writer.add_scalar(k, v, step) + + @rank_zero_only + def finalize(self) -> None: + if self._summary_writer: + self._summary_writer.close() + _tensorboard_writers.remove(self._summary_writer) + + +def _close_writers(): + for w in _tensorboard_writers: + w.close() + +# Close all writers on exit +atexit.register(_close_writers) diff --git a/nnscaler/cli/loggers/wandb.py b/nnscaler/cli/loggers/wandb.py new file mode 100644 index 00000000..dd263be0 --- /dev/null +++ b/nnscaler/cli/loggers/wandb.py @@ -0,0 +1,53 @@ +from typing import Dict, Optional +from pathlib import Path + +try: + import wandb +except ImportError: + wandb = None + +from nnscaler.utils import rank_zero_only + +from .logger_base import LoggerBase + + +class WandbLogger(LoggerBase): + def __init__( + self, + name: Optional[str] = None, + project: Optional[str] = None, + entity: Optional[str] = None, + dir: Optional[str] = None, + **kwargs + ) -> None: + super().__init__() + + self._name = name + self._project = project + self._entity = entity + self._dir = dir + self._kwargs = kwargs + + @rank_zero_only + def setup(self, config: Dict) -> None: + if self._dir is not None: + self._dir = Path(self._dir).expanduser().resolve() + self._dir.mkdir(parents=True, exist_ok=True) + + # reinit=False to ensure if wandb.init() is called multiple times + # within one process it still references the same run + wandb.init(name=self._name, project=self._project, + entity=self._entity, + reinit=False, + dir=self._dir, + config=config, + **self._kwargs + ) + + @rank_zero_only + def log_metrics(self, metrics: Dict[str, float], step: int) -> None: + wandb.log(metrics, step=step) + + @rank_zero_only + def finalize(self) -> None: + wandb.finish() diff --git a/nnscaler/cli/train.py b/nnscaler/cli/train.py index bf7fd749..b0aa3c8f 100644 --- a/nnscaler/cli/train.py +++ b/nnscaler/cli/train.py @@ -1,8 +1,12 @@ +import logging + +import nnscaler + from .trainer import Trainer if __name__ == '__main__': + nnscaler.utils.set_default_logger_level(level=logging.INFO) trainer = Trainer() if trainer.train_args.run_mode == 'run': trainer.train() - diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py new file mode 100644 index 00000000..8f17c166 --- /dev/null +++ b/nnscaler/cli/train_hook.py @@ -0,0 +1,141 @@ +from typing import Any, Dict, List, TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from nnscaler.cli.trainer import Trainer + from nnscaler.cli.trainer_args import AggregatedOutputs + + +class TrainHook: + """ + Note: All hooks are called in all ranks, and the inputs of hooks are only the local data. + """ + def on_train_start(self, trainer: 'Trainer') -> None: + """Called at the beginning of training""" + + def on_train_end(self, trainer: 'Trainer') -> None: + """Called at the end of training""" + + def on_val_start(self, trainer: 'Trainer') -> None: + """Called at the beginning of validation""" + + def on_val_end(self, trainer: 'Trainer', val_loss: float) -> None: + """Called at the end of validation""" + + def on_epoch_start(self, trainer: 'Trainer', epoch: int) -> None: + """ + Called at the beginning of each epoch + Args: + epoch: the current epoch index + """ + + def on_epoch_end(self, trainer: 'Trainer', epoch: int) -> None: + """ + Called at the end of each epoch + Args: + epoch: the current epoch index + """ + + def on_train_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: + """ + Called at the beginning of each training step + Please note one train step may contain multiple batches + Args: + batches: the current batches + idx: the index of current step + """ + + def on_train_step_end(self, trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None: + """ + Called at the end of each training step + Args: + outputs: the outputs of the train_step + batches: the current batches + idx: the index of current step + """ + + def on_val_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: + """ + Called at the beginning of each validating step + Please note one val step may contain multiple batches + Args: + batches: the current batches + idx: the index of current step + """ + + def on_val_step_end(self, trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None: + """ + Called at the end of each validating step + Args: + outputs: the outputs of the val_step + batches: the current batches + idx: the index of current step + """ + + def after_aggregate_train_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None: + """ + Called after aggregating outputs in train step + Args: + aggregated_outputs: the aggregated outputs + train_loss: the loss of the current step + idx: the index of current step + """ + + def after_aggregate_val_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None: + """ + Called after aggregating outputs in val step + Args: + aggregated_outputs: the aggregated outputs + val_loss: the loss of the current step + idx: the index of current step + """ + + def before_zero_grad(self, trainer: 'Trainer') -> None: + """ + Called before zero_grad + """ + + def after_zero_grad(self, trainer: 'Trainer') -> None: + """ + Called after zero_grad + """ + + def before_gnorm_clip(self, trainer: 'Trainer') -> None: + """ + Called before gradient clipping + """ + + def after_gnorm_clip(self, trainer: 'Trainer', gnorm: torch.Tensor) -> None: + """ + Called after gradient clipping + """ + + def before_optimizer_step(self, trainer: 'Trainer') -> None: + """ + Called before optimizer.step() + """ + + def after_optimizer_step(self, trainer: 'Trainer') -> None: + """ + Called after optimizer.step() + """ + + def on_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: + """ + Called after loading checkpoint. + If you saved something with `on_save_checkpoint` this is + your chance to restore this. + + Args: + checkpoint: the checkpoint loaded + """ + + def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: + """ + Called before saving checkpoint. + If you want to save something, you can add it to the checkpoint here. + + Args: + checkpoint: the checkpoint to be saved + """ diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 9927965f..c7e19e33 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -1,26 +1,65 @@ from dataclasses import dataclass, asdict from typing import Any, Dict, List, Optional, Union from pathlib import Path -import sys +import sys, os import copy import inspect +import warnings +import shutil +import logging import torch import torch.distributed +from torch.utils.data import DataLoader + +from tqdm import tqdm + import nnscaler +from nnscaler.utils import enforce_zero_num_worker +import nnscaler.utils + +from .trainer_args import AggregatedOutputs, TrainerArgs + -from .trainer_args import TrainerArgs +logger = logging.getLogger(__name__) + + +# the format of the checkpoint file +# keys: epoch, step, rank +# currently it is not configurable +# TODO: make it configurable +CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/{rank}.ckpt' +CHECKPOINT_LAST_DIR_NAME: str = 'last' +CHECKPOINT_BEST_DIR_NAME: str = 'best' +CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}.ckpt' +CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}.ckpt' @dataclass class TrainStatus: + best_loss = float('inf') epoch: int = 0 - in_epoch_pos: int = 0 # the position inside an epoch, used for resuming training + # used for resuming training + # it is the index of the next batch in the current epoch + # i means the i-1 batch is done, and we should resume from ith batch + # for example + # 0 means the epoch is not started + # 1 means the 0th batch is done, and we should resume from 1st batch + next_batch_index: int = 0 + + +@dataclass +class _StepStat: + train_loss: float = None + val_loss: float = None + lr: float = None + gnorm: float = None class Trainer: def __init__(self, argv: Optional[List[str]] = None, + *, train_args: Optional[Union[Dict[str, Any], TrainerArgs]] = None ): """ @@ -38,18 +77,27 @@ def __init__(self, raise ValueError(f"train_args should be a dict or TrainerArgs, got {type(train_args)}") self.train_args = TrainerArgs.from_dict(train_args) else: - cli_args = argv or sys.argv[1:] # remve the leading script name from sys.argv + cli_args = argv or sys.argv[1:] # remove the leading script name from sys.argv self.train_args = TrainerArgs.from_cli(cli_args) + self.rank = None + self.sync_group = None self.model = None self.optimizer = None self.dataset = {'train': None, 'val': None, 'test': None} - self.dataloader = {'train': None, 'val': None, 'test': None} + self.dataloader: Dict[str, Optional[DataLoader]] = {'train': None, 'val': None, 'test': None} self.lr_scheduler = None self.train_status = TrainStatus() self.dummy_input = None + self.total_train_steps_per_epoch = None + self.loggers = [] + self.hook = None self._setup() + @property + def num_train_steps(self): + return self.train_status.epoch * self.total_train_steps_per_epoch + self.train_status.next_batch_index + def _fix_input(self, input): if isinstance(input, dict): return {k: self._fix_input(v) for k, v in input.items()} @@ -77,10 +125,21 @@ def _create_dummy_forward_args(self): ) return {arg_names[1]: self.dummy_input} # arg_names[0] is self + def _load_dummy_input(self): + with enforce_zero_num_worker(DataLoader): + assert self.dataset['train'] is not None, "train dataset is not set" + dataloader = self.train_args.create_dataloader('train', self.dataset['train']) + assert dataloader.num_workers == 0, "The dataloader must have `num_workers=0`." + return next(iter(dataloader)) + def _setup(self): compile_only = self.train_args.run_mode == 'compile' if not compile_only: nnscaler.init() + if torch.distributed.get_rank() == 0: + logging.getLogger().setLevel(logging.INFO) + else: + logging.getLogger().setLevel(logging.WARNING) def _create_model(): model = self.train_args.create_model() @@ -88,26 +147,26 @@ def _create_model(): model = model.half() elif self.train_args.bf16: model = model.bfloat16() - if self.train_args.ckpt_tracing: - model.load_state_dict(torch.load(self.train_args.ckpt_tracing)) + if self.train_args.tracing_from_weights: + model.load_state_dict(torch.load(self.train_args.tracing_from_weights)) return model - # load a dummy input from training dataset - if not compile_only: - for stage in ['train', 'val', 'test']: - self.dataset[stage] = self.train_args.create_dataset(stage) - self.dataloader[stage] = self.train_args.create_dataloader(stage, self.dataset[stage]) - - self.dummy_input = self.dataloader['train'].collate_fn( - [self.dataset['train'][idx] for idx in range(self.train_args.micro_batch_size)] - ) - else: - train_dataset = self.train_args.create_dataset('train') - self.dummy_input = self.train_args.collate_fn( - [train_dataset[idx] for idx in range(self.train_args.micro_batch_size)] - ) - del train_dataset + # create dataset and dataloader + for stage in ['train', 'val', 'test']: + self.dataset[stage] = self.train_args.create_dataset(stage) + self.dataloader[stage] = self.train_args.create_dataloader(stage, self.dataset[stage]) + if self.dataloader[stage] is not None \ + and not self.dataloader[stage].drop_last \ + and len(self.dataset[stage]) % (self.train_args.micro_batch_size * self.train_args.scaling_factor) != 0: + warnings.warn( + f"Length of {stage} dataset ({len(self.dataset[stage])}) " + f"is not multiple of micro_batch_size * scale_factor ({self.train_args.micro_batch_size * self.train_args.scaling_factor}). " + f"In this case, the train_step for the last batch of samples can fail! " + f"You can specify `drop_last=True` in DataLoader to fix this problem." + ) + # load a dummy input from training dataset + self.dummy_input = self._load_dummy_input() self.dummy_input = self._fix_input(self.dummy_input) # setup compute config @@ -118,6 +177,7 @@ def _create_model(): 'gbs': self.train_args.global_batch_size, 'fp16': self.train_args.fp16, 'bf16': self.train_args.bf16, + 'model_args': self.train_args.model_config.args, } # parallalize model @@ -137,78 +197,208 @@ def _create_model(): return torch.distributed.barrier() - + self.rank = torch.distributed.get_rank() + self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq + if len(self.dataloader['train']) % self.train_args.update_freq != 0: + self.total_train_steps_per_epoch += 1 # will add extra dummy batches + _, self.sync_group = self.train_args.compute_config.get_sync_group() self.model = pmodel_class() self.optimizer = self.train_args.create_parallel_optimizer(self.model) self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) + self.loggers = self.train_args.create_loggers() + self.hook = self.train_args.create_hook() + self._log_config(self.train_args.to_dict()) self._load_checkpoint() + @classmethod + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): + state_dicts = [torch.load(f, map_location='cpu') for f in checkpoint_files] + for i in range(1, len(state_dicts)): + if state_dicts[i]['train_args'] != state_dicts[0]['train_args']: + raise ValueError(f"train_args in {checkpoint_files[i]} is different from {checkpoint_files[0]}") + + module_state_dict, opt_state_dict = nnscaler.merge_state_dicts( + [s['model'] for s in state_dicts], + [s['optimizer'] for s in state_dicts] + ) + train_args = copy.deepcopy(state_dicts[0]['train_args']) + train_args['checkpoint_config']['save_type'] = 'merged' + merged_state_dict = { + 'model': module_state_dict, + 'optimizer': opt_state_dict, + 'lr_scheduler': state_dicts[0].get('lr_scheduler', None), + 'train_status': state_dicts[0]['train_status'], + 'train_args': train_args, + } + torch.save(merged_state_dict, output_file) + + def _log_finalize(self): + for logger in self.loggers: + logger.finalize() + + def _log_metrics(self, metrics: Dict[str, float], step: int): + for logger in self.loggers: + logger.log_metrics(metrics, step) + + def _log_config(self, config: Dict): + for logger in self.loggers: + logger.setup(config) + def _load_checkpoint(self): - if not self.train_args.ckpt_load_file: + resume_from = self.train_args.checkpoint_config.get_resume_checkpoint_dir() + if not resume_from: return - state_dict = torch.load(self.train_args.ckpt_load_file, map_location='cpu') - ckpt_save_type = state_dict.get('train_args', {}).get('ckpt_save_type', None) + logger.info(f"Resuming from {resume_from}") + if resume_from.is_file(): + resume_from = resume_from # when we load from merged checkpoint + else: + resume_from = resume_from / f'{self.rank}.ckpt' + state_dict = torch.load(resume_from, map_location='cpu') + self.hook.on_load_checkpoint(self, state_dict) + ckpt_save_type = state_dict['train_args']['checkpoint_config']['save_type'] - if not ckpt_save_type: # it is a merged state dict - nnscaler.load_merged_state_dicts( + if ckpt_save_type == 'merged': # it is a merged state dict + nnscaler.load_merged_state_dict( self.model, state_dict['model'], self.optimizer, state_dict['optimizer'], ) - if 'lr_scheduler' in state_dict: - self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) elif ckpt_save_type == 'sharded': - self.model.load_state_dict(state_dict['model']) - self.model.cuda() - self.optimizer.load_state_dict(state_dict['optimizer']) - if 'lr_scheduler' in state_dict: - self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) - self.train_status = TrainStatus(**state_dict['train_status']) + nnscaler.load_sharded_state_dict( + self.model, state_dict['model'], + self.optimizer, state_dict['optimizer'], + ) elif ckpt_save_type == 'deduped': nnscaler.load_deduped_state_dict( self.model, state_dict['model'], self.optimizer, state_dict['optimizer'], ) - if 'lr_scheduler' in state_dict: - self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) - self.train_status = TrainStatus(**state_dict['train_status']) else: raise ValueError(f"Unknown checkpoint type: {ckpt_save_type}") - def _save_checkpoint(self, from_end_of_epoch=True): - if not self.train_args.ckpt_save_dir: + if 'lr_scheduler' in state_dict: + if state_dict['lr_scheduler'] and not self.lr_scheduler: + raise ValueError("lr_scheduler is not set in the current trainer") + if self.lr_scheduler: + self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + self.train_status = TrainStatus(**state_dict['train_status']) + + def _save_checkpoint(self, loss): + checkpoint_config = self.train_args.checkpoint_config + + if checkpoint_config.no_save: + logger.info('Skip saving checkpoint because `no_save` is set to True') return - save_dir = Path(self.train_args.ckpt_save_dir) + + torch.distributed.barrier() + logger.info(f"Saving checkpoint after {self.num_train_steps} steps with loss={loss:.3f}.") + save_dir = Path(checkpoint_config.save_dir) save_dir.mkdir(parents=True, exist_ok=True) - if self.train_args.ckpt_save_type == 'sharded': + if checkpoint_config.save_type == 'sharded': model_state_dict= self.model.state_dict() optimizer_state_dict = self.optimizer.state_dict() - elif self.train_args.ckpt_save_type == 'deduped': + elif checkpoint_config.save_type == 'deduped': model_state_dict, optimizer_state_dict = nnscaler.deduped_state_dict( self.model, self.optimizer ) + elif checkpoint_config.save_type == 'merged': + raise ValueError("merged checkpoint is not supported for saving") else: - raise ValueError(f"Unknown checkpoint type: {self.train_args.ckpt_save_type}") - - train_status = copy.deepcopy(self.train_status) - if from_end_of_epoch: - train_status.in_epoch_pos = 0 - train_status.epoch += 1 + raise ValueError(f"Unknown checkpoint type: {checkpoint_config.save_type}") state_dict = { 'model': model_state_dict, 'optimizer': optimizer_state_dict, 'lr_scheduler': self.lr_scheduler.state_dict() if self.lr_scheduler else None, - 'train_status': asdict(train_status), + 'train_status': asdict(self.train_status), 'train_args': self.train_args.to_dict(), } - torch.save(state_dict, save_dir / - f'ckpt_{train_status.epoch}_{train_status.in_epoch_pos}_rank{torch.distributed.get_rank()}.pt' + self.hook.on_save_checkpoint(self, state_dict) + ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( + epoch=self.train_status.epoch, + step=self.num_train_steps, + rank=self.rank, ) + logger.info(f"Saving checkpoint to {str(ckpt_file.parent)}") + ckpt_file.parent.mkdir(parents=True, exist_ok=True) + torch.save(state_dict, ckpt_file) + + # save last + if checkpoint_config.save_last: + logger.info(f"Saving checkpoint as the last checkpoint.") + last_file = save_dir / CHECKPOINT_LAST_FILE_FORMAT.format( + rank=self.rank + ) + last_file.parent.mkdir(parents=True, exist_ok=True) + if checkpoint_config.symlink_best_and_last: + # remove the old symlink or file + if last_file.is_symlink() or last_file.exists(): + last_file.unlink() + # symblink as relative path + last_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) + # last_file.symlink_to(ckpt_file) + else: + shutil.copy(ckpt_file, last_file) + + # save best + if checkpoint_config.save_best and loss <= self.train_status.best_loss: + logger.info(f"Best loss updated: {self.train_status.best_loss:.3f} -> {loss:.3f}") + logger.info(f"Saving checkpoint as the best checkpoint.") + best_file = save_dir / CHECKPOINT_BEST_FILE_FORMAT.format( + epoch=self.train_status.epoch, + step=self.num_train_steps, + rank=self.rank, + ) + best_file.parent.mkdir(parents=True, exist_ok=True) + if checkpoint_config.symlink_best_and_last: + # symblink as relative path + if best_file.is_symlink() or best_file.exists(): + best_file.unlink() + best_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) + # best_file.symlink_to(ckpt_file) + else: + shutil.copy(ckpt_file, best_file) + + # remove old checkpoints + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + # only the first rank in the group will do the job + if self.rank % local_world_size == 0: + self._expire_checkpoints() - def _global_batch_iterator(self, num_skip_first = 0): + def _expire_checkpoints(self): + if not self.train_args.checkpoint_config.keep_last_n_checkpoints: # keep all + return + + save_dir = Path(self.train_args.checkpoint_config.save_dir) + checkpoints = [ + p.name for p in save_dir.glob('*') + if p.is_dir() and p.name not in [CHECKPOINT_BEST_DIR_NAME, CHECKPOINT_LAST_DIR_NAME] + ] + if len(checkpoints) <= self.train_args.checkpoint_config.keep_last_n_checkpoints: + return + + # (step, num) pairs + checkpoint_info = [(int(p.split('-')[1]), p) for p in checkpoints] + checkpoint_info.sort() + expire_list = checkpoint_info[:-self.train_args.checkpoint_config.keep_last_n_checkpoints] + + best_ckpt = save_dir / CHECKPOINT_BEST_DIR_NAME + if best_ckpt.exists(): + for p in best_ckpt.glob('*.ckpt'): + if p.is_symlink(): + ckpt_name = p.resolve().parent.name + if ckpt_name in expire_list: + expire_list.remove(ckpt_name) + logger.info('Keep old checkpoint `%s` because it is the best.', ckpt_name) + break # just check the first file is enough + + for _, ckpt_name in expire_list: + logger.info('Removing old checkpoint: %s', ckpt_name) + shutil.rmtree(save_dir / ckpt_name) + + def _global_batch_iterator(self, num_skip_first = 0, stage='train'): samples = [] - for idx, sample in enumerate(self.dataloader['train']): + for idx, sample in enumerate(self.dataloader[stage]): if idx < num_skip_first * self.train_args.update_freq: continue sample = self._fix_input(sample) @@ -219,28 +409,240 @@ def _global_batch_iterator(self, num_skip_first = 0): if samples: yield samples + def aggregate_outputs(self, loss_outputs, sync_group) -> AggregatedOutputs: + # loss is the first element of the output (or the only element) + losses = [ + loss if isinstance(loss, torch.Tensor) + else loss[0] + for loss in loss_outputs + ] + loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) + torch.distributed.all_reduce(loss_sum, group=sync_group) + num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) + torch.distributed.all_reduce(num_batches, group=sync_group) + + return AggregatedOutputs( + loss_sum = loss_sum.item(), + num_batches=num_batches.item(), + ) + + def _fix_batches(self, batches): + num_batches = len(batches) + is_dummy_batch = [False] * num_batches + if num_batches < self.train_args.update_freq: + gap = self.train_args.update_freq - num_batches + is_dummy_batch += [True] * gap + batches += [self.dummy_input] * gap + return batches, is_dummy_batch + def train(self): - num_skip_fist = self.train_status.in_epoch_pos + assert self.train_status.next_batch_index <= self.total_train_steps_per_epoch, \ + f"next_batch_index({self.train_status.next_batch_index}) " \ + f"should not be larger than total_train_steps_per_epoch ({self.total_train_steps_per_epoch})" + + if self.train_status.next_batch_index == self.total_train_steps_per_epoch: + self.train_status.epoch += 1 + self.train_status.next_batch_index = 0 + + next_batch_index = self.train_status.next_batch_index + self.hook.on_train_start(self) + for epoch in range(self.train_status.epoch, self.train_args.max_epochs): + self.dataloader['train'].sampler.set_epoch(epoch) + + torch.distributed.barrier() self.train_status.epoch = epoch - for idx, samples in enumerate(self._global_batch_iterator(num_skip_fist)): - self.train_status.in_epoch_pos = idx - is_dummy_batch = [False] * len(samples) - if len(samples) < self.train_args.update_freq: - gap = self.train_args.update_freq - len(samples) - is_dummy_batch += [True] * gap - samples += [self.dummy_input] * gap - - self.model.train() - self.optimizer.zero_grad() - losses = self.model.train_step(samples, is_dummy_batch) - if self.train_args.clip_gnorm: - self.optimizer.clip_gnorm(self.train_args.clip_gnorm) - self.optimizer.step() + self.train_status.next_batch_index = next_batch_index - if self.lr_scheduler: - self.lr_scheduler.step(epoch) + self.hook.on_epoch_start(self, epoch) + self.train_epoch(epoch) + self.hook.on_epoch_end(self, epoch) + + if self.lr_scheduler and self.train_args.lr_scheduler_config.interval == 'epoch': + self.lr_scheduler.step() + + if self.train_args.max_train_steps and self.num_train_steps >= self.train_args.max_train_steps: + logger.info(f"Reached train steps({self.train_args.max_train_steps}): Training is done.") + break + + next_batch_index = 0 + else: # not from `break` + logger.info(f"Reached max_epochs({self.train_args.max_epochs}): Training is done.") + + self._log_finalize() + self.hook.on_train_end(self) + torch.distributed.barrier() + + def _validate_and_save(self, step_stat: _StepStat): + if self.dataloader['val'] is None: + self._save_checkpoint(step_stat.train_loss) + return + loss = self._validate(step_stat) + self._save_checkpoint(loss) + if self.train_status.best_loss > loss: + self.train_status.best_loss = loss + + def _validate(self, step_stat: _StepStat): + if self.dataloader['val'] is None: + logger.info('No val dataset specified. Validation skipped.') + return step_stat.train_loss + + logger.info('Validating...') + data_iter = enumerate(self._global_batch_iterator(stage='val')) + if self.rank == 0: + total_val_steps_per_epoch = len(self.dataloader['val']) // self.train_args.update_freq + if len(self.dataloader['val']) % self.train_args.update_freq != 0: + total_val_steps_per_epoch += 1 # will add extra dummy batches + data_iter = tqdm( + data_iter, + total=total_val_steps_per_epoch, + initial=0, + desc=f'Validating', + disable=not self.train_args.enable_progress_bar + ) + + loss_sum = 0.0 + batches_count = 0 + + self.hook.on_val_start(self) + for idx, batches in data_iter: + if self.train_args.max_val_steps and idx >= self.train_args.max_val_steps: + break + + num_batches = len(batches) + batches, _ = self._fix_batches(batches) + + self.model.eval() + with torch.inference_mode(): + self.hook.on_val_step_start(self, batches[:num_batches], idx) + losses = self.model.infer_step(batches) + self.hook.on_val_step_end(self, losses[:num_batches], batches[:num_batches], idx) + + aggregate_outputs = self.train_args.aggregate_outputs or self.aggregate_outputs + aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) + self.hook.after_aggregate_val_step_outputs( + self, aggregated_outputs, + aggregated_outputs.loss_sum/aggregated_outputs.num_batches, + idx + ) + loss_sum += aggregated_outputs.loss_sum + batches_count += aggregated_outputs.num_batches + + # update train status + loss = loss_sum / batches_count + self.hook.on_val_end(self, loss) + + step_stat.val_loss = loss + self._log_metrics(asdict(step_stat), self.num_train_steps) + return loss + + def train_epoch(self, epoch): + VAL_STATUS_NO = 0 # not validated or saved + VAL_STATUS_VAL = 1 # validated but not saved + VAL_STATUS_SAVE = 2 # validated and saved + has_validated = VAL_STATUS_NO # 3 states + resume_from_idx = self.train_status.next_batch_index + data_iter = enumerate(self._global_batch_iterator(num_skip_first=resume_from_idx)) + if self.rank == 0: + data_iter = tqdm( + data_iter, + total=self.total_train_steps_per_epoch, + initial=resume_from_idx, + desc=f'Epoch {epoch:04d}', + disable=not self.train_args.enable_progress_bar + ) + + step_stat = _StepStat() + for idx, batches in data_iter: + has_validated = VAL_STATUS_NO + # the current batch is idx + resume_from_idx + # `+1` because the next_batch_index is the index of the next batch + # all save_checkpoint will be done at the end of the loop with correct next_batch_index + self.train_status.next_batch_index = idx + resume_from_idx + 1 + num_batches = len(batches) + batches, is_dummy_batch = self._fix_batches(batches) + + self.model.train() + + self.hook.before_zero_grad(self) + self.optimizer.zero_grad() + self.hook.after_zero_grad(self) + + self.hook.on_train_step_start(self, batches[:num_batches], idx) + losses = self.model.train_step(batches, is_dummy_batch) + self.hook.on_train_step_end(self, losses[:num_batches], batches[:num_batches], idx) + + aggregate_outputs = self.train_args.aggregate_outputs or self.aggregate_outputs + aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) + if self.train_args.optimizer_config.loss_reduction == 'mean': + loss = aggregated_outputs.loss_sum / aggregated_outputs.num_batches + else: + loss = aggregated_outputs.loss_sum + step_stat.train_loss = loss + self.hook.after_aggregate_train_step_outputs(self, aggregated_outputs, loss, idx) + if self.rank == 0: + data_iter.set_postfix({'loss': loss}) + + self.optimizer.sync_shard_grad() + + # scale gradients + if self.train_args.optimizer_config.grad_reduction == 'sum': + # do nothing. Already done in reducers + pass + elif self.train_args.optimizer_config.grad_reduction == 'mean': + if not aggregated_outputs.num_batches: + raise RuntimeError("`aggregate_outputs` doesn't set `num_batches` field") + self.optimizer.scale_grads(1.0 / aggregated_outputs.num_batches) + else: + assert self.train_args.optimizer_config.grad_reduction == 'per-token-mean' + if not aggregated_outputs.num_tokens: + raise RuntimeError("`aggregate_outputs` doesn't set `num_tokens` field") + self.optimizer.scale_grads(1.0 / aggregated_outputs.num_tokens) + + # clip gradients + self.hook.before_gnorm_clip(self) + if self.train_args.optimizer_config.clip_gnorm: + step_stat.gnorm = self.optimizer.clip_gnorm(self.train_args.optimizer_config.clip_gnorm) + else: + step_stat.gnorm = self.optimizer.clip_gnorm() + self.hook.after_gnorm_clip(self, step_stat.gnorm) + step_stat.gnorm = step_stat.gnorm.item() + + # update parameters + step_stat.lr = self.optimizer.param_groups[0]['lr'] + self.hook.before_optimizer_step(self) + self.optimizer.step() + self.hook.after_optimizer_step(self) + if self.lr_scheduler and self.train_args.lr_scheduler_config.interval == 'step': + self.lr_scheduler.step() + + # validate and save checkpoint + if self.train_args.checkpoint_config.every_n_train_steps and \ + self.num_train_steps % self.train_args.checkpoint_config.every_n_train_steps == 0: + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE - self._save_checkpoint(True) + if self.train_args.max_train_steps and self.num_train_steps >= self.train_args.max_train_steps: + if not has_validated: + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + logger.info(f"Reached max_train_steps({self.train_args.max_train_steps}): Training is done.") + break - num_skip_fist = 0 + if not has_validated and self.train_args.val_every_n_train_steps and \ + self.num_train_steps % self.train_args.val_every_n_train_steps == 0: + self._validate(step_stat) + has_validated = VAL_STATUS_VAL + # import time + # time.sleep(0.2) + else: # not from `break` + if not has_validated: + if self.train_args.max_epochs == self.train_status.epoch + 1 \ + or (self.train_args.checkpoint_config.every_n_epochs and \ + (self.train_status.epoch + 1) % self.train_args.checkpoint_config.every_n_epochs == 0): + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + elif self.train_args.val_every_n_epochs and \ + (self.train_status.epoch + 1) % self.train_args.val_every_n_epochs == 0: + self._validate(step_stat) + has_validated = VAL_STATUS_VAL diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index e6b3462d..6efe6d7d 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -1,6 +1,9 @@ from dataclasses import asdict, dataclass, field import importlib -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING +from pathlib import Path +import logging +import copy import torch import torch.utils @@ -9,13 +12,25 @@ import yaml import torch +from nnscaler.utils import transform_recursively from nnscaler.parallel import ComputeConfig, build_optimizer from nnscaler.runtime.module import ParallelModule from .arg_parser import deserialize_dataclass, merge_args, parse_args +from .loggers.logger_base import LoggerBase +from .train_hook import TrainHook + + +logger = logging.getLogger(__name__) def load_type(type_name: str): + """ + Load function/class from its full qualified name + """ + if callable(type_name): # a function or class + return type_name + parts = type_name.rsplit('.', 1) if len(parts) == 1: nm = __builtins__ @@ -31,13 +46,210 @@ class AggregatedOutputs: """ Aggregated outputs from all micro-batches """ - loss: Optional[int] = None - num_samples: Optional[int] = None + # the aggregated loss as a sum + loss_sum: float = None + # number of mini batches + num_batches: int = None + # number of tokens (only used when grad_reduction is 'per-token-mean') num_tokens: Optional[int] = None # any other custom outputs aggregated_outputs: Any = None +@dataclass +class ModelConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class OptimizerConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + clip_gnorm: float = 0.0 + + # loss reduction method + # mean: average the loss over all micro-batches + # sum: sum the loss of all micro-batches + # Please note in validation stage, this configuration is ignored + # the loss is always averaged over all batches + loss_reduction: str = 'mean' + # different ways of calculating grad + # sum: sum the gradients of all micro-batches + # mean: average the gradients over all micro-batches + # per-token-mean: average the gradients over all tokens + # you must specify `aggregate_outputs_fn` and return the number of tokens + grad_reduction: str = 'mean' + # the function to aggregate the outputs from all micro-batches + # inputs: (list of local outputs, torch group) + # output: AggregateOutputs + # you can use `torch.distributed.*` functions to do the work + aggregate_outputs_fn: str = None + + def __post_init__(self): + if self.grad_reduction not in ('sum', 'mean', 'per-token-mean'): + raise ValueError(f"Invalid gradient_accumulation {self.grad_reduction}") + if self.grad_reduction == 'per-token-mean' and not self.aggregate_outputs_fn: + raise ValueError("aggregate_outputs_fn is required when grad_reduction is 'per-token-mean'") + if self.loss_reduction not in ('mean', 'sum'): + raise ValueError(f"Invalid loss_reduction {self.loss_reduction}") + +@dataclass +class DatasetConfig: + type: str = None + train_args: Dict[str, Any] = field(default_factory=dict) + val_args: Dict[str, Any] = field(default_factory=dict) + test_args: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class DataloaderConfig: + type: str = 'torch.utils.data.DataLoader' + train_args: Dict[str, Any] = field(default_factory=dict) + # default to train_args + val_args: Dict[str, Any] = field(default_factory=dict) + # default to train_args + test_args: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class DatasetSamplerConfig: + type: str = 'torch.utils.data.DistributedSampler' + train_args: Dict[str, Any] = field(default_factory=dict) + val_args: Dict[str, Any] = field(default_factory=dict) + test_args: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class LRSchedulerConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + interval: str = 'epoch' + + def __post_init__(self): + if self.interval not in ('epoch', 'step'): + raise ValueError(f"Invalid interval {self.interval}") + + +@dataclass +class CheckpointConfig: + save_dir: str = './checkpoints' + no_save: bool = False + + # `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is + # a folder with as many files as the world size. + # `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is + # a folder with as many files as the world size. + # `"merged"`: everything has been merged into a single file. + # Used internally only when you merge the checkpoint files via `Trainer.merge_checkpoints` + save_type: str = 'sharded' + + save_last: bool = True + save_best: bool = True + symlink_best_and_last: bool = True + + # save the checkpoint every n train steps + # Please note we always run validation before saving the checkpoint + every_n_train_steps: Optional[int] = None + every_n_epochs: Optional[int] = None + keep_last_n_checkpoints: Optional[int] = None + + # resume training from a checkpoint folder + # can be 'last'/'best'/a specific folder + # we will not resume if resume_from is last or best but the corresponding checkpoint does not exist + resume_from: str = None + + def get_resume_checkpoint_dir(self) -> Optional[Path]: + if not self.resume_from: + return None + if self.resume_from in ['last', 'best']: + d = Path(self.save_dir) / self.resume_from + if not d.exists(): + return None + return d + return Path(self.resume_from) + + def __post_init__(self): + if self.resume_from: + if self.resume_from in ['last', 'best']: + if not self.save_dir: + raise ValueError("save_dir is required when resume_from is 'last'/'best'") + if not (Path(self.save_dir) / self.resume_from).exists(): + logger.warning(f"`{self.resume_from}` checkpoint does not exist. Will train from scratch.") + elif not Path(self.resume_from).exists(): + raise ValueError(f"resume_from {self.resume_from} does not exist") + if self.no_save: + return + + if self.save_type not in ('sharded', 'deduped', 'merged'): + raise ValueError(f"Invalid save_type {self.save_type}") + if not self.save_dir: + raise ValueError("save_dir is required") + + if self.every_n_epochs is not None and self.every_n_train_steps is not None: + raise ValueError("Cannot specify both every_n_epochs and every_n_train_steps") + if self.every_n_epochs is None and self.every_n_train_steps is None: + self.every_n_epochs = 1 # default to 1 epoch + + if self.every_n_train_steps is not None and self.every_n_train_steps < 1: + raise ValueError("every_n_train_steps must be positive") + if self.every_n_epochs is not None and self.every_n_epochs < 1: + raise ValueError("every_n_epochs must be positive") + if self.keep_last_n_checkpoints is not None and self.keep_last_n_checkpoints < 1: + raise ValueError("keep_last_n_checkpoints must be positive") + + +@dataclass +class LogConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HookConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HookMapConfig: + on_train_start: str = None + on_train_end: str = None + on_val_start: str = None + on_val_end: str = None + + on_epoch_start: str = None + on_epoch_end: str = None + + on_train_step_start: str = None + on_train_step_end: str = None + on_val_step_start: str = None + on_val_step_end: str = None + + after_aggregate_train_step_outputs: str = None + after_aggregate_val_step_outputs: str = None + + before_zero_grad: str = None + after_zero_grad: str = None + + before_gnorm_clip: str = None + after_gnorm_clip: str = None + + before_optimizer_step: str = None + after_optimizer_step: str = None + + on_load_checkpoint: str = None + on_save_checkpoint: str = None + + +class ArgsTrainHook(TrainHook): + def __init__(self, hook_config: HookMapConfig): + self.config = hook_config + for k, v in asdict(hook_config).items(): + if v: + setattr(self, k, load_type(v)) + + @dataclass class TrainerArgs: compute_config: ComputeConfig = None @@ -49,85 +261,64 @@ class TrainerArgs: # compile: compile the model but not training # run: compile and run the model run_mode: str = 'run' - # the model state dict for tracing. - ckpt_tracing: str = None + # the model state dict file for tracing. + # It is only used in tracing to serve as the initial state dict of the model. + tracing_from_weights: str = None + + model_config: ModelConfig = field(default_factory=ModelConfig) + optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig) + dataset_config: DatasetConfig = field(default_factory=DatasetConfig) + dataloader_config: DataloaderConfig = field(default_factory=DataloaderConfig) + dataset_sampler_config: DatasetSamplerConfig = field(default_factory=DatasetSamplerConfig) + lr_scheduler_config: Optional[LRSchedulerConfig] = None + checkpoint_config: CheckpointConfig = field(default_factory=CheckpointConfig) + log_config: List[LogConfig] = field(default_factory=list) + # It can be `HookConfig` or `HookMapConfig` + hook_config: Any = None - model_class: str = None - model_args: Dict[str, Any] = field(default_factory=dict) fp16: bool = False bf16: bool = False - - optimizer_class: str = None - optimizer_args: Dict[str, Any] = field(default_factory=dict) - - dataset_class: str = None - train_dataset_args: Dict[str, Any] = field(default_factory=dict) - val_dataset_args: Dict[str, Any] = field(default_factory=dict) - test_dataset_args: Dict[str, Any] = field(default_factory=dict) - - dataloader_class: str = 'torch.utils.data.DataLoader' - train_dataloader_args: Dict[str, Any] = field(default_factory=dict) - # default to train_dataloader_args - val_dataloader_args: Dict[str, Any] = field(default_factory=dict) - # default to train_dataloader_args - test_dataloader_args: Dict[str, Any] = field(default_factory=dict) - - dataset_sampler_class: str = 'torch.utils.data.DistributedSampler' - train_dataset_sampler_args: Dict[str, Any] = field(default_factory=dict) - val_dataset_sampler_args: Dict[str, Any] = field(default_factory=dict) - test_dataset_sampler_args: Dict[str, Any] = field(default_factory=dict) - - lr_scheduler_class: str = None - lr_scheduler_args: Dict[str, Any] = field(default_factory=dict) - micro_batch_size: int = 1 - global_batch_size: int = 1 + # default is self.micro_batch_size*self.scaling_factor + # which means update_freq is 1 + global_batch_size: Optional[int] = None max_epochs: int = 1000 - clip_gnorm: float = 0.0 - # TODO: support different ways of calculating grad and loss - # sum: sum the gradients of all micro-batches - # per-sample-mean: average the gradients over all micro-batches - # per-token-mean: average the gradients over all tokens - # you must specify `aggregate_outputs_fn` and return the number of tokens - gradient_accumulation: str = 'sum' - # the function to aggregate the outputs from all micro-batches - # inputs: (list of local outputs, torch group) - # output: AggregateOutputs - # you can use `torch.distributed.*` functions to do the work - aggregate_outputs_fn: str = None + max_train_steps: int = 0 + max_val_steps: int = 0 - ckpt_save_dir: str = None - # `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is - # a folder with as many files as the world size. - # `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is - # a folder with as many files as the world size. - ckpt_save_type: str = 'sharded' - ckpt_load_file: str = None + # validation frequency + val_every_n_train_steps: Optional[int] = None + val_every_n_epochs: Optional[int] = 1 + + enable_progress_bar: bool = True def __post_init__(self): if not self.compute_config: raise ValueError("compute_config is required") if not self.compute_config.use_end2end: raise ValueError("use_end2end must be True") - if self.global_batch_size % self.micro_batch_size != 0: - raise ValueError(f"global_batch_size {self.global_batch_size} is not divisible by micro_batch_size {self.micro_batch_size}") + if not self.global_batch_size: + self.global_batch_size = self.micro_batch_size*self.scaling_factor + if self.global_batch_size % (self.micro_batch_size*self.scaling_factor) != 0: + raise ValueError(f"`global_batch_size` {self.global_batch_size} is not divisible by `micro_batch_size*(runtime_ngpus/plan_ngpus)` " + f"which is {self.micro_batch_size * self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus}") if self.run_mode not in ('compile', 'run'): raise ValueError(f"Invalid run_mode {self.run_mode}") - if self.ckpt_save_type not in ('sharded', 'deduped'): - raise ValueError(f"Invalid ckpt_save_type {self.ckpt_save_type}") if self.fp16 and self.bf16: raise ValueError("Cannot use both fp16 and bf16") - if not self.model_class: - raise ValueError("model_class is required") - if not self.optimizer_class: - raise ValueError("optimizer_class is required") - if not self.dataset_class: - raise ValueError("dataset_class is required") - if not self.dataloader_class: - raise ValueError("dataloader_class is required") - if not self.dataset_sampler_class: - raise ValueError("dataset_sampler_class is required") + if not self.model_config.type: + raise ValueError("model type is required") + if not self.optimizer_config.type: + raise ValueError("optimizer type is required") + if not self.dataset_config.type: + raise ValueError("dataset type is required") + if not self.dataloader_config.type: + raise ValueError("dataloader type is required") + if not self.dataset_sampler_config.type: + raise ValueError("dataset_sampler type is required") + if self.lr_scheduler_config and not self.lr_scheduler_config.type: + raise ValueError("lr_scheduler type is required") @classmethod def from_cli(cls, argv: List[str]) -> 'TrainerArgs': @@ -146,7 +337,13 @@ def from_dict(cls, d: Dict[str, Any]) -> 'TrainerArgs': return ta def to_dict(self): - return asdict(self) + # replace all callable with their full qualified name + # please note it is not reversible if local functions are used + return transform_recursively( + asdict(self), + lambda class_or_func: f'{class_or_func.__module__}.{class_or_func.__qualname__}', + callable, + ) @classmethod def from_yaml(cls, path: str) -> 'TrainerArgs': @@ -155,6 +352,7 @@ def from_yaml(cls, path: str) -> 'TrainerArgs': @classmethod def create_kwarg(cls, value: dict): + value = copy.deepcopy(value) for k, v in value.items(): if isinstance(v, dict): value[k] = cls.create_kwarg(v) @@ -181,18 +379,13 @@ def create_kwarg(cls, value: dict): @property def model_type(self): - return load_type(self.model_class) + return load_type(self.model_config.type) @property - def collate_fn(self): - """ - Used to generate dummy input from dataset - """ - args = self.train_dataloader_args - if 'collate_fn' in args: - return load_type(args['collate_fn']) - # hack to get default collate_fn - return torch.utils.data.dataloader.default_collate + def aggregate_outputs(self): + if not self.optimizer_config.aggregate_outputs_fn: + return None + return load_type(self.optimizer_config.aggregate_outputs_fn) @property def scaling_factor(self): @@ -203,38 +396,42 @@ def update_freq(self): return self.global_batch_size // self.micro_batch_size // self.scaling_factor def create_model(self) -> torch.nn.Module: - kwargs = self.create_kwarg(self.model_args) + kwargs = self.create_kwarg(self.model_config.args) return self.model_type(**kwargs) def create_parallel_optimizer(self, parallel_model: ParallelModule): - kwargs = self.create_kwarg(self.optimizer_args) - optimizer_class = load_type(self.optimizer_class) + kwargs = self.create_kwarg(self.optimizer_config.args) + optimizer_class = load_type(self.optimizer_config.type) return build_optimizer(parallel_model, optimizer_class, **kwargs) def create_dataset(self, stage='train'): - dataset_args = getattr(self, f'{stage}_dataset_args') + dataset_args = getattr(self.dataset_config, f'{stage}_args') if not dataset_args: return None kwargs = self.create_kwarg(dataset_args) - dataset_class = load_type(self.dataset_class) - if issubclass(dataset_class, torch.utils.data.IterableDataset): + dataset_class = load_type(self.dataset_config.type) + dataset = dataset_class(**kwargs) + if isinstance(dataset_class, torch.utils.data.IterableDataset): raise ValueError("IterableDataset is not supported") - return dataset_class(**kwargs) + return dataset def create_sampler(self, dataset, stage='train'): - sampler_args = getattr(self, f'{stage}_dataset_sampler_args') - sampler_args = sampler_args or self.train_dataset_sampler_args + sampler_args = getattr(self.dataset_sampler_config, f'{stage}_args') + sampler_args = sampler_args or self.dataset_sampler_config.train_args kwargs = self.create_kwarg(sampler_args) kwargs['dataset'] = dataset kwargs['num_replicas'] = self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus kwargs['rank'] = torch.distributed.get_rank() // self.compute_config.plan_ngpus - sampler_class = load_type(self.dataset_sampler_class) + sampler_class = load_type(self.dataset_sampler_config.type) return sampler_class(**kwargs) def create_dataloader(self, stage='train', dataset=None): - dataloader_args = getattr(self, f'{stage}_dataloader_args') - dataloader_args = dataloader_args or self.train_dataloader_args + dataloader_args = getattr(self.dataloader_config, f'{stage}_args') + dataloader_args = dataloader_args or self.dataloader_config.train_args kwargs = self.create_kwarg(dataloader_args) + if 'batch_size' in kwargs: + raise ValueError("`batch_size` should not be specified in dataloader_args. " + "You should use `micro_batch_size` instead.") kwargs['dataset'] = dataset or self.create_dataset(stage) if kwargs['dataset'] is None: return None @@ -244,12 +441,40 @@ def create_dataloader(self, stage='train', dataset=None): kwargs['collate_fn'] = load_type(kwargs['collate_fn']) kwargs['batch_size'] = self.micro_batch_size kwargs['sampler'] = self.create_sampler(kwargs['dataset'], stage) - dataloader_class = load_type(self.dataloader_class) + dataloader_class = load_type(self.dataloader_config.type) return dataloader_class(**kwargs) def create_lr_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.LRScheduler: - if not self.lr_scheduler_class: + if not self.lr_scheduler_config: return None - kwargs = self.create_kwarg(self.lr_scheduler_args) - lr_scheduler_class = load_type(self.lr_scheduler_class) + kwargs = self.create_kwarg(self.lr_scheduler_config.args) + lr_scheduler_class = load_type(self.lr_scheduler_config.type) return lr_scheduler_class(optimizer, **kwargs) + + def create_loggers(self) -> List['LoggerBase']: + loggers = [] + for log_config in self.log_config: + kwargs = self.create_kwarg(log_config.args) + logger_class = load_type(log_config.type) + loggers.append(logger_class(**kwargs)) + return loggers + + def create_hook(self) -> TrainHook: + if not self.hook_config: + return TrainHook() # empty hook + + if isinstance(self.hook_config, dict): + if 'type' in self.hook_config: + hook_config = HookConfig(**self.hook_config) + else: + hook_config = HookMapConfig(**self.hook_config) + else: + hook_config = self.hook_config + + if isinstance(hook_config, HookConfig): + kwargs = self.create_kwarg(hook_config.args) + return load_type(hook_config.type)(kwargs) + elif isinstance(hook_config, HookMapConfig): + return ArgsTrainHook(hook_config) + else: + raise ValueError(f"Invalid hook_config {hook_config}") diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py index 4aa0204d..2dd184ab 100644 --- a/nnscaler/integration/lightning/pytorch/strategy.py +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -59,6 +59,7 @@ import nnscaler from nnscaler.integration.lightning.utils import inplace_optimizer_fn from nnscaler.runtime.device import DeviceGroup +from nnscaler.utils import enforce_zero_num_worker from .precision import NnScalerPrecision @@ -216,7 +217,7 @@ def _get_dummy_forward_args(self, model: pl.LightningModule) -> Dict[str, Any]: data_source = trainer.fit_loop._data_source assert data_source is not None, "The `data_source` must be defined in the trainer." assert data_source.instance is not None, "The `instance` must be defined in the data source." - with enforce_0_num_worker(DataLoader): + with enforce_zero_num_worker(DataLoader): dataloader = data_source.dataloader() assert dataloader.num_workers == 0, "The dataloader must have `num_workers=0`." data = next(iter(dataloader)) @@ -547,15 +548,3 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: def _get_process_group_backend(self) -> str: return 'nccl' # nnscaler only support nccl - - -@contextmanager -def enforce_0_num_worker(cls) -> Generator[None, None, None]: - """Context manager to enforce the number of workers to be 0 in DataLoader.""" - _old__init__ = cls.__init__ - def _new__init__(self, *args, **kwargs) -> None: - kwargs['num_workers'] = 0 - _old__init__(self, *args, **kwargs) - cls.__init__ = _new__init__ - yield - cls.__init__ = _old__init__ diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 72fc6c71..a022dd03 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1723,7 +1723,7 @@ def merge_state_dicts( @torch.no_grad() -def load_merged_state_dicts( +def load_merged_state_dict( module: torch.nn.Module, module_state_dict: Dict[str, Any], optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, @@ -2323,3 +2323,36 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): module.mark_non_persistent_buffers_inited() torch.distributed.barrier() + + +@torch.no_grad() +def load_sharded_state_dict( + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + device: Union[str, torch.device] = None +): + """ + Load the sharded state dicts to the module, and optionally the optimizer to a specified device. + + Args: + module (torch.nn.Module): the module to be loaded + module_state_dict (Dict[str, Any]): the sharded model state dict + optimizer (Optional[torch.optim.Optimizer]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the sharded optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. + + Returns: + None + """ + + device = device or torch.cuda.current_device() + module.load_state_dict(module_state_dict) + module.to(device) + if optimizer: + if optimizer_state_dict is None: + raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") + optimizer.load_state_dict(optimizer_state_dict) diff --git a/nnscaler/utils.py b/nnscaler/utils.py index e14819b7..30dcfeab 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -1,10 +1,12 @@ -import os -from typing import Optional, Tuple, Callable, List, Set, Any, Iterable +from contextlib import contextmanager +from functools import wraps +from typing import Generator, Optional, Tuple, Callable, List, Set, Any, Iterable, Type, Union import logging from pathlib import Path import sys from collections import defaultdict from dataclasses import dataclass +import inspect import nnscaler from nnscaler.runtime.device import DeviceGroup @@ -165,6 +167,101 @@ def set_default_logger_level(level): ) +@contextmanager +def enforce_zero_num_worker(cls) -> Generator[None, None, None]: + """Context manager to enforce the number of workers to be 0 in DataLoader.""" + _old__init__ = cls.__init__ + def _new__init__(self, *args, **kwargs) -> None: + kwargs['num_workers'] = 0 + _old__init__(self, *args, **kwargs) + cls.__init__ = _new__init__ + yield + cls.__init__ = _old__init__ + + +def rank_zero_only(fn: Callable[..., None]) -> Callable[..., None]: + """ + Wrap a function to call internal function only in rank zero. + Function that can be used as a decorator to enable a function/method being called only on global rank 0. + Please note + 1. that the fn should be no return values, and no side effect. + So it is only recommend to use this decorator for logging or printing. + 2. `fn` will also be called if the distributed environment is not initialized. + """ + + @wraps(fn) + def wrapped_fn(*args, **kwargs): + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else None + if rank == 0 or rank is None: + fn(*args, **kwargs) + + return wrapped_fn + + +_DICT_ITEMS_TYPE = type({}.items()) +_DICT_KEYS_TYPE = type({}.keys()) +_DICT_VALUES_TYPE = type({}.values()) + + +def transform_recursively(data: Any, fn: Callable[[Any], Any], + target_types: Union[Callable[[Any], bool], Type, Tuple[Type, ...]], + collection_types = (tuple, list, dict), skip_dict_keys = True +) -> Any: + """ + Transform the data with the given function, will recursively apply the function to the nested data. + Args: + data: the data to be transformed. + fn: the function to apply. + target_types: the target types to apply the function. + collection_types: the collection types to apply the function to the nested data. + skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). + _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. + """ + if isinstance(data, collection_types): + if isinstance(data, tuple): + return tuple(transform_recursively(t, fn, target_types, collection_types) for t in data) + if isinstance(data, list): + return list(transform_recursively(t, fn, target_types, collection_types) for t in data) + if isinstance(data, set): + return set(transform_recursively(t, fn, target_types, collection_types) for t in data) + if isinstance(data, dict): + return { + k if skip_dict_keys else transform_recursively(k, fn, target_types, collection_types): + transform_recursively(v, fn, target_types, collection_types) + for k, v in data.items() + } + if isinstance(data, _DICT_ITEMS_TYPE): + return { + k if skip_dict_keys else transform_recursively(k, fn, target_types, collection_types): + transform_recursively(v, fn, target_types, collection_types) + for k, v in data + }.items() + if isinstance(data, _DICT_KEYS_TYPE): + return { + transform_recursively(k, fn, target_types, collection_types): i + for i, k in enumerate(data) + }.keys() + if isinstance(data, _DICT_VALUES_TYPE): + return { + i: transform_recursively(v, fn, target_types, collection_types) + for i, v in enumerate(data) + }.values() + if isinstance(data, slice): + return slice( + transform_recursively(data.start, fn, target_types, collection_types), + transform_recursively(data.stop, fn, target_types, collection_types), + transform_recursively(data.step, fn, target_types, collection_types) + ) + raise ValueError(f"Unsupported collection type: {type(data)}") + elif isinstance(target_types, (tuple, list)) or inspect.isclass(target_types): + if isinstance(data, target_types): + return fn(data) + elif callable(target_types): # not a class, but callable. treat as a check function. + if target_types(data): + return fn(data) + return data + + class accum_mode: """Make cube execution in gradient accumulation mode. diff --git a/requirements-dev.txt b/requirements-dev.txt index d9c13b5e..9107136d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,3 +15,5 @@ tabulate tox tox-conda yapf +wandb +tensorboard diff --git a/requirements.txt b/requirements.txt index cda12044..f8e85677 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ pulp pybind11 pyyaml torch>=2.0 +tqdm diff --git a/tests/cli/common.py b/tests/cli/common.py index 6020579c..ea300e1b 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -5,6 +5,7 @@ class SimpleDataset(Dataset): def __init__(self, dim: int, size: int = 100): + torch.manual_seed(0) self.data = torch.randn(size, dim) self.target = torch.rand(size, dim) diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 69d06753..99592aeb 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -1,4 +1,4 @@ -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from typing import List, Optional, Tuple, Dict, Any, Union import sys @@ -108,6 +108,14 @@ class A: y = deserialize_dataclass(x, A) assert y == A(a=1, b=False, c=C(d=3, e=4), f=F(g=G(h=5)), k=[10, 12], v={'a': 10, 'b': 20}) + x = parse_args(['--a=1', '--b', 'False', '--c.d=3', '--c.e', '4', '--f.g.unknown=5', '--v.a=10', '--v.b=20', '--k=[10,12]']) + with pytest.raises(ValueError): + y = deserialize_dataclass(x, A) + + x = parse_args(['--unknowna=1', '--b', 'False', '--c.d=3', '--c.e', '4', '--f.g.h=5', '--v.a=10', '--v.b=20', '--k=[10,12]']) + with pytest.raises(ValueError): + y = deserialize_dataclass(x, A) + x = parse_args(['--a=1', '--b', '0', '--c.d=3', '--c.e', '4', '--f.g.h=5', '--v.a=10', '--v.b=20', '--z.__type=tests.cli.test_arg_parser.GConfig', @@ -138,3 +146,16 @@ class A: } ) assert deserialize_dataclass(asdict(y), A) == y + + +def test_deserialize_list(): + @dataclass + class A: + a: List[int] = field(default_factory=list) + b: List[GConfig] = field(default_factory=list) + c: Tuple[int, ...] = None + + + x = parse_args(['--a.0=1', '--a.1=2', '--b.0.h=3', '--b.1.h=4', '--c.1=4']) + y = deserialize_dataclass(x, A) + assert y == A(a=[1, 2], b=[GConfig(h=3), GConfig(h=4)], c=(None, 4)) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index a8e2b179..e17a0e23 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -1,28 +1,189 @@ from pathlib import Path +import shutil import torch import pytest +import torch.distributed from nnscaler.cli.trainer import Trainer +from tests.parallel_module.common import assert_equal from ..launch_torchrun import launch_torchrun -def trainer_worker(save_dir): +def trainer_logging_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + log_savedir = save_dir / 'log' + tb_log_savedir = log_savedir / 'tensorboard' + wandb_log_savedir = log_savedir / 'wandb' + # train 4 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint_config.no_save', 'true', + '--log_config.0.type', 'nnscaler.cli.loggers.TensorBoardLogger', + '--log_config.0.args.name', 'test-cli', + '--log_config.0.args.root_dir', str(tb_log_savedir), + '--log_config.1.type', 'nnscaler.cli.loggers.WandbLogger', + '--log_config.1.args.name', 'test-cli', + '--log_config.1.args.dir', str(wandb_log_savedir), + '--log_config.1.args.project', 'nnscaler', + '--log_config.1.args.mode', 'offline', + ]) + trainer.train() + + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + assert (tb_log_savedir / 'test-cli').exists() + tfevents = list((tb_log_savedir / 'test-cli').glob('events.out.tfevents.*')) + assert len(tfevents) == 1 + assert tfevents[0].stat().st_size > 1000 + + assert (wandb_log_savedir / 'wandb').exists() + wandb_offline_dir = list((wandb_log_savedir / 'wandb').glob('offline-run-*')) + assert len(wandb_offline_dir) == 1 + wandb_run_db = list(wandb_offline_dir[0].glob('run-*.wandb')) + assert len(wandb_run_db) == 1 + assert wandb_run_db[0].stat().st_size > 1000 + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_trainer_logging(tmp_path): + launch_torchrun(4, trainer_logging_worker, tmp_path) + + +def trainer_resume_worker(save_dir, save_type, bf16): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) gen_savedir = save_dir / 'gen' ckpt_savedir = save_dir / 'ckpt' + # train 4 epcho in one time trainer = Trainer([ '-f', config_path, + '--bf16', str(bf16), + '--max_epochs', '4', + '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', - '--ckpt_save_type', 'sharded', - '--ckpt_save_dir', str(ckpt_savedir), + '--checkpoint_config.save_type', save_type, + '--checkpoint_config.save_dir', str(ckpt_savedir), + '--checkpoint_config.resume_from', 'last', + '--checkpoint_config.keep_last_n_checkpoints', '30', ]) trainer.train() + ckpt_files = set(ckpt_savedir.glob('**/*.ckpt')) + assert len(ckpt_files)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last + + # train 4 epcho two times (resume from last) + ckpt0_savedir = save_dir / 'ckpt0' + # first two epochs + trainer = Trainer([ + '-f', config_path, + '--bf16', str(bf16), + '--max_epochs', '2', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint_config.save_type', save_type, + '--checkpoint_config.save_dir', str(ckpt0_savedir), + '--checkpoint_config.resume_from', 'last', + '--checkpoint_config.keep_last_n_checkpoints', '30', + ]) + trainer.train() + ckpt0_files0 = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} + assert len(ckpt0_files0)/4 == min(30, trainer.total_train_steps_per_epoch * 2) + 2 # 2 for best/last + + # create merged checkpoint + ckpt1_savedir = save_dir / 'ckpt1' + ckpt1_savedir.mkdir(parents=True, exist_ok=True) + if trainer.rank == 0: + Trainer.merge_checkpoint(list((ckpt0_savedir / 'last').glob('*.ckpt')), ckpt1_savedir / 'merged.pt') + + # continue with the last two epochs (resume for sharded/deduped checkpoint) + trainer = Trainer([ + '-f', config_path, + '--bf16', str(bf16), + '--max_epochs', '4', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint_config.save_type', save_type, + '--checkpoint_config.save_dir', str(ckpt0_savedir), + '--checkpoint_config.resume_from', 'last', + '--checkpoint_config.keep_last_n_checkpoints', '30', + ]) + trainer.train() + left_files = { + f: f.stat().st_mtime_ns for f in ckpt0_files0.keys() + if f.exists() and f.parent.name not in ['last', 'best'] + } + assert left_files # some checkpoints are removed + for f, s in left_files.items(): # make sure the old checkpoints are not overwritten + assert ckpt0_files0[f] == s + + ckpt0_files1 = set(ckpt0_savedir.glob('**/*.ckpt')) + assert len(ckpt0_files1)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last + + torch.distributed.barrier() + + # continue with the last two epochs (resume for merged) + trainer = Trainer([ + '-f', config_path, + '--bf16', str(bf16), + '--max_epochs', '4', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint_config.save_type', save_type, + '--checkpoint_config.save_dir', str(ckpt1_savedir), + '--checkpoint_config.resume_from', str(ckpt1_savedir / 'merged.pt'), + '--checkpoint_config.keep_last_n_checkpoints', '30', + ]) + trainer.train() + left_files = { + f: f.stat().st_mtime_ns for f in ckpt0_files0.keys() + if f.exists() and f.parent.name not in ['last', 'best'] + } + assert left_files # some checkpoints are removed + for f, s in left_files.items(): # make sure the old checkpoints are not overwritten + assert ckpt0_files0[f] == s + + ckpt0_files1 = set(ckpt0_savedir.glob('**/*.ckpt')) + assert len(ckpt0_files1)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last + + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} + for i in range(4): + x = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt') + y = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt') + z = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt') + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + assert_equal(x['lr_scheduler'], y['lr_scheduler']) + assert_equal(x['model'], z['model']) + assert_equal(x['optimizer'], z['optimizer']) + assert_equal(x['lr_scheduler'], z['lr_scheduler']) + + if save_type == 'deduped': + assert (ckpt_savedir / 'last/0.ckpt').stat().st_size > (ckpt_savedir / 'last/2.ckpt').stat().st_size + assert (ckpt_savedir / 'last/1.ckpt').stat().st_size > (ckpt_savedir / 'last/3.ckpt').stat().st_size + else: + assert (ckpt_savedir / 'last/0.ckpt').stat().st_size == (ckpt_savedir / 'last/2.ckpt').stat().st_size + assert (ckpt_savedir / 'last/1.ckpt').stat().st_size == (ckpt_savedir / 'last/3.ckpt').stat().st_size @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') -def test_trainer(tmp_path): - launch_torchrun(4, trainer_worker, tmp_path) +@pytest.mark.parametrize('save_type', ['sharded', 'deduped']) +@pytest.mark.parametrize('bf16', [True, False]) +def test_trainer_resume(tmp_path, save_type, bf16): + launch_torchrun(4, trainer_resume_worker, tmp_path, save_type, bf16) diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml index 5454597d..e11c3048 100644 --- a/tests/cli/trainer_args.yaml +++ b/tests/cli/trainer_args.yaml @@ -10,17 +10,29 @@ pas_policy: autodist micro_batch_size: 2 global_batch_size: 8 max_epochs: 4 +max_train_steps: 100 -model_class: tests.cli.common.MLP -model_args: - dim: 16 - nlayers: 16 +model_config: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 -optimizer_class: torch.optim.Adam -optimizer_args: - lr: 0.01 +optimizer_config: + type: torch.optim.Adam + args: + lr: 0.01 -dataset_class: tests.cli.common.SimpleDataset -train_dataset_args: - dim: 16 - size: 100 +dataset_config: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +checkpoint_config: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/integration/lightning/pytorch/simple_models.py b/tests/integration/lightning/pytorch/simple_models.py index a481c2c5..037abb22 100644 --- a/tests/integration/lightning/pytorch/simple_models.py +++ b/tests/integration/lightning/pytorch/simple_models.py @@ -33,6 +33,8 @@ def __init__(self, num_features=32, num_classes=3, batch_size=10, lr=0.01): self.test_acc = acc.clone() self.dummy_forward_args_fn = lambda batch: {"x": batch[0]} self.update_history = [] + self.loss_history = [] + self.val_loss_history = [] # @property # def dummy_forward_args(self): @@ -65,8 +67,11 @@ def validation_step(self, batch, batch_idx): assert not self.training x, y = batch logits = self.forward(x) + loss = F.cross_entropy(logits, y) self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False, sync_dist=True) self.log("val_acc", self.valid_acc(logits, y), prog_bar=True, sync_dist=True) + return {'loss': loss} + def test_step(self, batch, batch_idx): assert not self.training @@ -91,6 +96,14 @@ def _fix_name(name): self.update_history.append((grads, weights)) return super().configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm) + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: + self.loss_history.append(outputs["loss"].item()) + + def on_validation_batch_end( + self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ): + self.val_loss_history.append(outputs["loss"].item()) + class ClassificationModelWithLRScheduler(ClassificationModel): def configure_optimizers(self): diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index fa00ee6c..4a037731 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -15,10 +15,13 @@ from nnscaler.integration.lightning.pytorch import NnScalerStrategy, NnScalerPrecision import nnscaler.runtime +from nnscaler.cli.trainer import Trainer as CliTrainer +from nnscaler.cli.trainer_args import CheckpointConfig, DataloaderConfig, DatasetConfig, DatasetSamplerConfig, HookConfig, ModelConfig, TrainerArgs, OptimizerConfig, LRSchedulerConfig + from ....launch_torchrun import launch_torchrun from ....utils import init_random from ....parallel_module.common import assert_close, assert_equal -from .simple_datamodules import ClassifDataModule +from .simple_datamodules import ClassifDataModule, SklearnDataset from .simple_models import BoringModel, ClassificationModel, ClassificationModelWithLRScheduler @@ -29,6 +32,7 @@ def fit_worker(tmp_path): trainer = Trainer( default_root_dir=tmp_path, max_epochs=2, + enable_progress_bar=False, accelerator="gpu", devices=1, gradient_clip_val=None, strategy=NnScalerStrategy(compute_config=compute_config, pas_policy='tp', gen_savedir=tmp_path), @@ -91,7 +95,8 @@ def on_load_checkpoint(self, _): state = pl_load(ckpt / '0.pt') # Resume training trainer = Trainer( - default_root_dir=tmp_path, max_epochs=2, enable_progress_bar=False, + default_root_dir=tmp_path, max_epochs=2, + enable_progress_bar=False, strategy=NnScalerStrategy( compute_config=compute_config, pas_policy='tp', @@ -120,6 +125,7 @@ def trainer_accumulate_grad_batches_zero_grad(tmp_path, accumulate_grad_batches) limit_train_batches=20, limit_val_batches=1, max_epochs=1, + enable_progress_bar=False, enable_model_summary=False, accumulate_grad_batches=accumulate_grad_batches, strategy=NnScalerStrategy(compute_config=ComputeConfig(1, 2), pas_policy='tp', gen_savedir=tmp_path), @@ -136,6 +142,158 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmp_path, accumulate_grad_bat launch_torchrun(2, trainer_accumulate_grad_batches_zero_grad, tmp_path, accumulate_grad_batches) +# hack to satisfy cli requirements +_correctnes_worker_datamodule: ClassifDataModule = None +_correctnes_worker_model: ClassificationModel = None +_correctnes_worker_update_history = [] +_correctnes_worker_train_loss_history = [] +_correctnes_worker_single_loss_history = [] +_correctnes_worker_val_loss_history = [] + + +def get_full_qualified_name(class_or_func): + return f'{class_or_func.__module__}.{class_or_func.__qualname__}' + + +def correctnes_worker_cli_dataset(stage): + if stage == 'train': + return SklearnDataset(_correctnes_worker_datamodule.x_train, + _correctnes_worker_datamodule.y_train, + _correctnes_worker_datamodule._x_type, + _correctnes_worker_datamodule._y_type + ) + elif stage == 'val': + return SklearnDataset(_correctnes_worker_datamodule.x_valid, + _correctnes_worker_datamodule.y_valid, + _correctnes_worker_datamodule._x_type, + _correctnes_worker_datamodule._y_type + ) + else: + raise ValueError(f'Unknown stage: {stage}') + + +class CorrectnessWorkerM(torch.nn.Module): + def __init__(self): + super().__init__() + self.m =_correctnes_worker_model + self.m.log = lambda *args, **kwargs: None + del self.m.train_acc + self.m.train_acc = lambda *args, **kwargs: None + + def forward(self, batch): + return self.m.training_step(batch, 0)['loss'] + + +def on_before_grad_clip(trainer: Trainer): + grads = {n: p.grad.cpu() for n, p in trainer.model.named_parameters()} + weights = {n: p.data.cpu() for n, p in trainer.model.named_parameters()} + _correctnes_worker_update_history.append((grads, weights)) + + +def after_aggregate_train_step_outputs(trainer: Trainer, aggregated_outputs, train_loss, idx): + _correctnes_worker_train_loss_history.append(train_loss) + + +def on_train_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: + _correctnes_worker_single_loss_history.append(outputs[0].item()) + + + +def correctnes_worker_cli( + tmp_path, + gradient_clip_val, + with_lr_scheduler, + precision='32-true', + with_tp=False +): + + def on_val_step_end(trainer: Trainer, outputs, batches, idx) -> None: + _correctnes_worker_val_loss_history.append(outputs[0].item()) + + assert precision == '32-true' + global _correctnes_worker_datamodule + global _correctnes_worker_model + init_random() + dm = ClassifDataModule() + _correctnes_worker_datamodule = dm + lr_config = None + init_random() + _correctnes_worker_model = ClassificationModel() + if with_lr_scheduler: + lr_config = LRSchedulerConfig( + type=torch.optim.lr_scheduler.StepLR, + args={ + 'step_size': 1, + } + ) + + if with_tp: + compute_config=ComputeConfig(2, 4, use_end2end=True) + policy = 'tp' + else: + compute_config=ComputeConfig(1, 2, use_end2end=True) + policy = 'dp' + + tmp_path = Path(tmp_path) / 'cli' + train_args = TrainerArgs( + compute_config=compute_config, + gen_savedir=tmp_path / 'code', + micro_batch_size=_correctnes_worker_model.batch_size, + global_batch_size=_correctnes_worker_model.batch_size*2, + max_epochs=2, + pas_policy=policy, + instance_name=f'cli_{policy}', + enable_progress_bar=False, + model_config=ModelConfig( + type=CorrectnessWorkerM, + ), + dataset_config=DatasetConfig( + type=correctnes_worker_cli_dataset, + train_args={ + 'stage': 'train' + }, + val_args={ + 'stage': 'val' + }, + ), + dataset_sampler_config=DatasetSamplerConfig( + type='torch.utils.data.DistributedSampler', + val_args={ + 'shuffle': False, # lightning doesn't shuffle val set + }, + ), + optimizer_config=OptimizerConfig( + type=torch.optim.Adam, + args={ + 'lr': _correctnes_worker_model.lr + }, + clip_gnorm=gradient_clip_val, + ), + checkpoint_config=CheckpointConfig( + no_save=True, + ), + lr_scheduler_config=lr_config, + hook_config=dict( + before_gnorm_clip=on_before_grad_clip, + after_aggregate_train_step_outputs=after_aggregate_train_step_outputs, + on_train_step_end=on_train_step_end, + on_val_step_end=on_val_step_end, + ), + ) + trainer = CliTrainer( + train_args=train_args, + ) + _correctnes_worker_update_history.clear() + _correctnes_worker_train_loss_history.clear() + _correctnes_worker_single_loss_history.clear() + _correctnes_worker_val_loss_history.clear() + trainer.train() + return _correctnes_worker_update_history, trainer.model.fullmap, \ + _correctnes_worker_val_loss_history, \ + _correctnes_worker_train_loss_history, \ + _correctnes_worker_single_loss_history + + def correctnes_worker_nnscaler(tmp_path, gradient_clip_val, with_lr_scheduler, precision='32-true', with_tp=False, with_empty_scaler=False @@ -161,6 +319,8 @@ def correctnes_worker_nnscaler(tmp_path, gradient_clip_val, with_lr_scheduler, trainer = Trainer( default_root_dir=tmp_path, max_epochs=2, + enable_progress_bar=False, + num_sanity_val_steps=0, accelerator="gpu", devices=devices, gradient_clip_val=gradient_clip_val, strategy=NnScalerStrategy( @@ -170,7 +330,7 @@ def correctnes_worker_nnscaler(tmp_path, gradient_clip_val, with_lr_scheduler, plugins=[NnScalerPrecision(precision, scaler=scaler)] ) trainer.fit(model, datamodule=dm) - return model.update_history, model.nnscaler_pmodule.fullmap + return model.update_history, model.nnscaler_pmodule.fullmap, model.val_loss_history, model.loss_history def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_scheduler, @@ -200,6 +360,8 @@ def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_s trainer = Trainer( default_root_dir=tmp_path, max_epochs=1, + enable_progress_bar=False, + num_sanity_val_steps=0, callbacks=[ModelCheckpoint(dirpath=tmp_path, save_top_k=1, save_last=True)], accelerator="gpu", devices=devices, gradient_clip_val=gradient_clip_val, @@ -215,6 +377,8 @@ def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_s trainer = Trainer( default_root_dir=tmp_path, max_epochs=2, + enable_progress_bar=False, + num_sanity_val_steps=0, callbacks=[ModelCheckpoint(dirpath=tmp_path, save_top_k=1, save_last=True)], accelerator="gpu", devices=devices, gradient_clip_val=gradient_clip_val, @@ -226,7 +390,7 @@ def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_s plugins=[NnScalerPrecision(precision, scaler=scaler)] ) trainer.fit(model, datamodule=dm, ckpt_path='last') - return model.update_history, model.nnscaler_pmodule.fullmap + return model.update_history, model.nnscaler_pmodule.fullmap, model.val_loss_history, model.loss_history def correctnes_worker_ddp(tmp_path, gradient_clip_val, with_lr_scheduler, precision='32-true'): @@ -239,6 +403,8 @@ def correctnes_worker_ddp(tmp_path, gradient_clip_val, with_lr_scheduler, precis model = ClassificationModel() trainer = Trainer( default_root_dir=tmp_path, + enable_progress_bar=False, + num_sanity_val_steps=0, precision=precision, max_epochs=2, accelerator="gpu", devices=2, @@ -246,7 +412,7 @@ def correctnes_worker_ddp(tmp_path, gradient_clip_val, with_lr_scheduler, precis strategy='ddp', ) trainer.fit(model, datamodule=dm) - return model.update_history + return {'update': model.update_history, 'loss': model.loss_history, 'val_loss': model.val_loss_history} @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') @@ -273,6 +439,13 @@ def _merge_results(returns): ) return weight_results, grad_results + def _assert_loss_equal(returns0, returns1, loss_idx0=-1, loss_idx1=-1, val_loss_idx0=-2, val_loss_idx1=-2): + # TODO: val_loss check + assert len(returns0) == len(returns1) + for i in range(len(returns0)): + assert returns0[i][loss_idx0] == returns1[i][loss_idx1] + assert returns0[i][val_loss_idx0] == returns1[i][val_loss_idx1] + # Test 16-mixed with and without gradient clipping # when gradient clipping is on, the following check will fail # TODO: fix the test when gradient clipping is on @@ -283,41 +456,73 @@ def _merge_results(returns): nnscaler_merged_weight_results_fp16, nnscaler_merged_grad_results_fp16 = _merge_results(nnscaler_returns) for i in range(len(ddp_results[0])): - assert_close(nnscaler_merged_weight_results_fp16[i], ddp_results[0][i][1]) - assert_close(nnscaler_merged_grad_results_fp16[i], ddp_results[0][i][0]) - assert_equal(ddp_results[1][i], ddp_results[0][i]) + assert_close(nnscaler_merged_weight_results_fp16[i], ddp_results[0]['update'][i][1]) + assert_close(nnscaler_merged_grad_results_fp16[i], ddp_results[0]['update'][i][0]) + assert_equal(ddp_results[1]['update'][i], ddp_results[0]['update'][i]) nnscaler_returns_ckpt = launch_torchrun(2, correctnes_worker_nnscaler_checkpoint, tmp_path, gradient_clip_val, with_lr_scheduler) nnscaler_merged_weight_results_ckpt, nnscaler_merged_grad_results_ckpt = _merge_results(nnscaler_returns_ckpt) nnscaler_returns = launch_torchrun(2, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler) nnscaler_merged_weight_results, nnscaler_merged_grad_results = _merge_results(nnscaler_returns) + _assert_loss_equal(nnscaler_returns_ckpt, nnscaler_returns) nnscaler_returns = launch_torchrun(2, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler, '32-true', False, True) nnscaler_merged_weight_results_scaler, nnscaler_merged_grad_results_scaler = _merge_results(nnscaler_returns) + _assert_loss_equal(nnscaler_returns_ckpt, nnscaler_returns) + + cli_returns = launch_torchrun(2, correctnes_worker_cli, tmp_path, gradient_clip_val, with_lr_scheduler) + cli_merged_weight_results, cli_merged_grad_results = _merge_results(cli_returns) + # remove leading 'm.' in names + cli_merged_weight_results = [{k[2:]: v for k, v in x.items()} for x in cli_merged_weight_results] + cli_merged_grad_results = [{k[2:]: v for k, v in x.items()} for x in cli_merged_grad_results] + _assert_loss_equal(cli_returns, nnscaler_returns, val_loss_idx0=-3) + assert cli_returns[0][-2] == cli_returns[1][-2] + assert [(x+y)/2 for x, y in zip(cli_returns[0][-1],cli_returns[1][-1])] == cli_returns[0][-2] assert len(nnscaler_merged_weight_results) == len(nnscaler_merged_weight_results_ckpt) assert len(nnscaler_merged_weight_results) == len(nnscaler_merged_weight_results_scaler) + assert len(nnscaler_merged_weight_results) == len(cli_merged_weight_results) assert len(nnscaler_merged_grad_results) == len(nnscaler_merged_grad_results_ckpt) assert len(nnscaler_merged_grad_results) == len(nnscaler_merged_grad_results_scaler) + assert len(nnscaler_merged_grad_results) == len(cli_merged_grad_results) for i in range(len(nnscaler_merged_weight_results_scaler)): assert_equal(nnscaler_merged_weight_results[i], nnscaler_merged_weight_results_scaler[i]) assert_equal(nnscaler_merged_weight_results[i], nnscaler_merged_weight_results_ckpt[i]) + assert_equal(nnscaler_merged_weight_results[i], cli_merged_weight_results[i]) + assert_equal(nnscaler_merged_grad_results[i], nnscaler_merged_grad_results_scaler[i]) assert_equal(nnscaler_merged_grad_results[i], nnscaler_merged_grad_results_ckpt[i]) + assert_equal(nnscaler_merged_grad_results[i], cli_merged_grad_results[i]) ddp_results = launch_torchrun(2, correctnes_worker_ddp, tmp_path, gradient_clip_val, with_lr_scheduler) + if not gradient_clip_val: + _assert_loss_equal(ddp_results, nnscaler_returns, loss_idx0='loss', val_loss_idx0='val_loss') for i in range(len(ddp_results[0])): - assert_close(nnscaler_merged_weight_results[i], ddp_results[0][i][1]) - assert_close(nnscaler_merged_grad_results[i], ddp_results[0][i][0]) - assert_equal(ddp_results[1][i], ddp_results[0][i]) + if gradient_clip_val: # currently it is not exactly the same when gradient clipping is on + assert_close(nnscaler_merged_weight_results[i], ddp_results[0]['update'][i][1]) + assert_close(nnscaler_merged_grad_results[i], ddp_results[0]['update'][i][0]) + else: + assert_equal(nnscaler_merged_weight_results[i], ddp_results[0]['update'][i][1]) + assert_equal(nnscaler_merged_grad_results[i], ddp_results[0]['update'][i][0]) + assert_equal(ddp_results[1]['update'][i], ddp_results[0]['update'][i]) if torch.cuda.device_count() >= 4: nnscaler_returns = launch_torchrun(4, correctnes_worker_nnscaler, tmp_path, gradient_clip_val, with_lr_scheduler, '32-true', True) nnscaler_merged_weight_results, nnscaler_merged_grad_results = _merge_results(nnscaler_returns) for i in range(len(ddp_results[0])): - assert_close(nnscaler_merged_weight_results[i], ddp_results[0][i][1]) - assert_close(nnscaler_merged_grad_results[i], ddp_results[0][i][0]) + assert_close(nnscaler_merged_weight_results[i], ddp_results[0]['update'][i][1]) + assert_close(nnscaler_merged_grad_results[i], ddp_results[0]['update'][i][0]) + + cli_returns = launch_torchrun(4, correctnes_worker_cli, tmp_path, gradient_clip_val, with_lr_scheduler, '32-true', True) + cli_merged_weight_results, cli_merged_grad_results = _merge_results(cli_returns) + # remove leading 'm.' in names + cli_merged_weight_results = [{k[2:]: v for k, v in x.items()} for x in cli_merged_weight_results] + cli_merged_grad_results = [{k[2:]: v for k, v in x.items()} for x in cli_merged_grad_results] + + for i in range(len(nnscaler_merged_weight_results)): + assert_equal(nnscaler_merged_weight_results[i], cli_merged_weight_results[i]) + assert_equal(nnscaler_merged_grad_results[i], cli_merged_grad_results[i]) diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py index 146a9c98..97ac8ff4 100644 --- a/tests/parallel_module/test_attr_dedup.py +++ b/tests/parallel_module/test_attr_dedup.py @@ -7,7 +7,7 @@ from torch import nn from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, \ - merge_state_dicts, load_merged_state_dicts, \ + merge_state_dicts, load_merged_state_dict, \ deduped_state_dict, load_deduped_state_dict from nnscaler.runtime.module import ParallelModule from nnscaler.graph.graph import IRGraph diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 9a68cfd5..ad49d9cf 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -14,7 +14,7 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dict from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm @@ -326,7 +326,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf # assert not check_model_state_dict_equal(inference_module.state_dict(), model_state_dict) # inference model can be loaded from merged state_dict - load_merged_state_dicts(inference_module, merged_model_state_dict) + load_merged_state_dict(inference_module, merged_model_state_dict) torch.save(inference_module.state_dict(), temp_inferenece_ckpt_file) torch.distributed.barrier() inference_ckpt_files = [ckpt_dir / temp_inferenece_ckpt_file_template.format(rank=i) for i in range(torch.distributed.get_world_size())] @@ -336,7 +336,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf model_from_merged = type(model)() optimizer_from_merged = build_optimizer(model_from_merged, torch.optim.Adam, lr=0.01) - load_merged_state_dicts( + load_merged_state_dict( model_from_merged, merged_model_state_dict, optimizer_from_merged, merged_opt_state_dict, ) diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index b304dc92..f40685f7 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -5,7 +5,7 @@ import pytest import torch.distributed -from nnscaler.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dicts, broadcast_weights +from nnscaler.parallel import parallelize, ComputeConfig, merge_state_dicts, load_merged_state_dict, broadcast_weights from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index ee7e9edd..59115655 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -7,7 +7,7 @@ from torch import nn from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, \ - merge_state_dicts, load_merged_state_dicts, \ + merge_state_dicts, load_merged_state_dict, \ deduped_state_dict, load_deduped_state_dict from nnscaler.runtime.module import ParallelModule diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index afcc9517..d0d02f7e 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -6,7 +6,7 @@ import torch from torch import nn -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dict from .common import CubeLinear, init_random, init_distributed, clear_dir_on_rank0 from ..launch_torchrun import launch_torchrun @@ -136,7 +136,7 @@ def _load_merged(parallel_model: torch.nn.Module, ckpt_dir): raw_model_state_dict: Dict[str, Any] = raw_ckpt_dict['model'] raw_opt_state_dict = raw_ckpt_dict['optimizer'] optimizer = build_optimizer(parallel_model, torch.optim.Adam, lr=0.01) - load_merged_state_dicts( + load_merged_state_dict( parallel_model, raw_model_state_dict, optimizer, raw_opt_state_dict, ) diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py index d49a5dc8..21e886cf 100644 --- a/tests/parallel_module/test_checkpoint_unused.py +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -14,7 +14,7 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dicts +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dict from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm From 027fd6437bbbc2f76cea306c7a14c1b8212561c5 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 24 Jul 2024 07:52:55 +0000 Subject: [PATCH 1687/1892] Merged PR 2209: Nanogpt with mini-trainer parity matched between lightning version & mini-trainer version --- examples/nanogpt/README.md | 14 --- examples/nanogpt/train_cli.py | 182 +++++++++++++++++++++++++++ examples/nanogpt/train_cli_args.yaml | 66 ++++++++++ examples/nanogpt/train_nnscaler.py | 3 +- nnscaler/cli/loggers/wandb.py | 5 + nnscaler/cli/train.py | 2 +- nnscaler/cli/trainer.py | 20 ++- nnscaler/cli/trainer_args.py | 93 +++++++++----- nnscaler/parallel.py | 6 +- 9 files changed, 337 insertions(+), 54 deletions(-) delete mode 100644 examples/nanogpt/README.md create mode 100644 examples/nanogpt/train_cli.py create mode 100644 examples/nanogpt/train_cli_args.yaml diff --git a/examples/nanogpt/README.md b/examples/nanogpt/README.md deleted file mode 100644 index 9468fa06..00000000 --- a/examples/nanogpt/README.md +++ /dev/null @@ -1,14 +0,0 @@ -Prepare data: -``` -python nanoGPT/data/shakespeare_char/prepare.py -``` - -Run without nnscaler -``` -python train_lightning.py nanoGPT/config/train_shakespeare_char.py -``` - -Run with nnscaler -``` -torchrun --standalone --nproc_per_node=1 train_lightning.py nanoGPT/config/train_shakespeare_char.py -``` diff --git a/examples/nanogpt/train_cli.py b/examples/nanogpt/train_cli.py new file mode 100644 index 00000000..84ad7e9d --- /dev/null +++ b/examples/nanogpt/train_cli.py @@ -0,0 +1,182 @@ +""" +Run training with this command in this directory: +``` +DETERMINISTIC=1 torchrun --standalone --nproc_per_node=1 \ + ../../nnscaler/cli/train.py -f train_cli_args.yaml +``` +""" +import math +import os +from pathlib import Path +import pickle +import random +import sys +from typing import TYPE_CHECKING +import time + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +import lightning as L +import nnscaler +if TYPE_CHECKING: + from nnscaler.cli.trainer import Trainer + from nnscaler.cli.trainer_args import TrainerArgs + +nanogpt_path = Path(__file__).absolute().with_name('nanoGPT') +sys.path.append(str(nanogpt_path)) + +from model import GPTConfig, GPT + + +def init_env(train_args: 'TrainerArgs'): + torch.manual_seed(0) + np.random.seed(0) + random.seed(0) + if os.environ.get('DETERMINISTIC') is not None: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + + # be consistent with nanogpt settings + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + + +def on_train_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: + if torch.distributed.get_rank() == 0: + print(f'# train_loss {idx:03d}', outputs[0].item()) + + +def on_val_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: + if torch.distributed.get_rank() == 0: + print(f'# val_loss {idx:03d}', outputs[0].item()) + + +# poor man's data loader +class NanoGptDataset(Dataset): + def __init__(self, data_dir, split, block_size): + self.split = split + self.block_size = block_size + self.data_dir = data_dir + data = np.memmap(os.path.join(self.data_dir, f'{split}.bin'), dtype=np.uint16, mode='r') + self.len = len(data) - self.block_size + + def __getitems__(self, indices): + x, y = self.get_batch(self.split, indices) + return ( + x.clone().detach(), # theoretically unnecessary, for robustness + y.clone().detach(), + ) + + def __len__(self): + return self.len + + def get_batch(self, split, ix): + # We recreate np.memmap every batch to avoid a memory leak, as per + # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 + data = np.memmap(os.path.join(self.data_dir, f'{split}.bin'), dtype=np.uint16, mode='r') + x = torch.stack([torch.from_numpy((data[i:i+self.block_size]).astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy((data[i+1:i+1+self.block_size]).astype(np.int64)) for i in ix]) + return x, y + + +def _create_nano_gpt_model( + init_from, + *, + n_layer, + n_head, + n_embd, + block_size, + bias, + dropout, + meta_path=None, +): + # reset seeds to ensure the same initialization with nanogpt + torch.manual_seed(0) + np.random.seed(0) + random.seed(0) + # attempt to derive vocab_size from the dataset + meta_vocab_size = None + if meta_path and os.path.exists(meta_path): + with open(meta_path, 'rb') as f: + meta = pickle.load(f) + meta_vocab_size = meta['vocab_size'] + print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") + + # model init + model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, + bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line + if init_from == 'scratch': + # init a new model from scratch + print("Initializing a new model from scratch") + # determine the vocab size we'll use for from-scratch training + if meta_vocab_size is None: + print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") + model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 + gptconf = GPTConfig(**model_args) + model = GPT(gptconf) + elif init_from == 'resume': + print(f"Resuming training") + # resume training from a checkpoint. (handled by lightning) + if meta_vocab_size is None: + print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") + model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 + gptconf = GPTConfig(**model_args) + model = GPT(gptconf) + elif init_from.startswith('gpt2'): + print(f"Initializing from OpenAI GPT-2 weights: {init_from}") + # initialize from OpenAI GPT-2 weights + override_args = dict(dropout=dropout) + model = GPT.from_pretrained(init_from, override_args) + # read off the created config params, so we can store them into checkpoint correctly + for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: + model_args[k] = getattr(model.config, k) + + # crop down the model block size if desired, using model surgery + if block_size < model.config.block_size: + model.crop_block_size(block_size) + model_args['block_size'] = block_size # so that the checkpoint will have the right value + + return model + + +class Model(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.model = _create_nano_gpt_model(*args, **kwargs) + + def forward(self, batch): + _logits, loss = self.model(batch[0], batch[1]) + return loss + + +class Scheduler(torch.optim.lr_scheduler.LRScheduler): + def __init__(self, optimizer, warmup_iters, learning_rate, lr_decay_iters, min_lr): + self.it = 0 # must before super().__init__() + self.warmup_iters = warmup_iters + self.learning_rate = learning_rate + self.lr_decay_iters = lr_decay_iters + self.min_lr = min_lr + super().__init__(optimizer) + + def get_lr(self): + lr = self._get_lr(self.it) + self.it += 1 + return [lr for _ in self.optimizer.param_groups] + + # learning rate decay scheduler (cosine with warmup) + def _get_lr(self, it): + # 1) linear warmup for warmup_iters steps + if it < self.warmup_iters: + return self.learning_rate * it / self.warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > self.lr_decay_iters: + return self.min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return self.min_lr + coeff * ( self.learning_rate - self.min_lr) diff --git a/examples/nanogpt/train_cli_args.yaml b/examples/nanogpt/train_cli_args.yaml new file mode 100644 index 00000000..351089aa --- /dev/null +++ b/examples/nanogpt/train_cli_args.yaml @@ -0,0 +1,66 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: false + use_end2end: true + +init_env_fn: examples.nanogpt.train_cli.init_env +run_mode: run +pas_policy: autodist +micro_batch_size: 64 +grad_accumulation_steps: 1 +max_train_steps: 5000 +max_val_steps: 200 +val_every_n_train_steps: 250 +enable_progress_bar: True + +model_config: + type: examples.nanogpt.train_cli.Model + args: + init_from: scratch + n_layer: 6 + n_head: 6 + n_embd: 384 + dropout: 0.0 + bias: false + block_size: 256 + meta_path: ./nanoGPT/data/shakespeare_char/meta.pkl + +optimizer_config: + type: torch.optim.AdamW + args: + lr: 1e-3 + betas: + - 0.9 + - 0.99 + fused: true + clip_gnorm: 0.0 + +lr_scheduler_config: + type: examples.nanogpt.train_cli.Scheduler + args: + warmup_iters: 100 + learning_rate: 1e-3 + lr_decay_iters: 5000 + min_lr: 1e-4 + interval: step + +dataset_config: + type: examples.nanogpt.train_cli.NanoGptDataset + train_args: + data_dir: ./nanoGPT/data/shakespeare_char + split: train + block_size: 256 + val_args: + data_dir: ./nanoGPT/data/shakespeare_char + split: val + block_size: 256 + +checkpoint_config: + keep_last_n_checkpoints: 10 + every_n_train_steps: 250 + save_type: deduped + +# hook_config: +# on_train_step_end: examples.nanogpt.train_cli.on_train_step_end +# on_val_step_end: examples.nanogpt.train_cli.on_val_step_end diff --git a/examples/nanogpt/train_nnscaler.py b/examples/nanogpt/train_nnscaler.py index c4f60964..a5fe394d 100644 --- a/examples/nanogpt/train_nnscaler.py +++ b/examples/nanogpt/train_nnscaler.py @@ -60,7 +60,7 @@ plan_ngpus = 1 runtime_ngpus = -1 # use -1 for WORLD_SIZE since nanoGPT's argparse require it to have static type -deterministic = False +deterministic = os.environ.get('DETERMINISTIC') is not None # ----------------------------------------------------------------------------- config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] @@ -73,6 +73,7 @@ dropout = 0.0 # must set before model init grad_clip = 0.0 torch.use_deterministic_algorithms(True) # NOTE: requires env CUBLAS_WORKSPACE_CONFIG=":4096:8" + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" # various inits, derived attributes, I/O setup diff --git a/nnscaler/cli/loggers/wandb.py b/nnscaler/cli/loggers/wandb.py index dd263be0..c23c25b3 100644 --- a/nnscaler/cli/loggers/wandb.py +++ b/nnscaler/cli/loggers/wandb.py @@ -20,6 +20,11 @@ def __init__( dir: Optional[str] = None, **kwargs ) -> None: + if wandb is None: + raise RuntimeError( + "wandb not found, please install with: pip install wandb" + ) + super().__init__() self._name = name diff --git a/nnscaler/cli/train.py b/nnscaler/cli/train.py index b0aa3c8f..3070dbf2 100644 --- a/nnscaler/cli/train.py +++ b/nnscaler/cli/train.py @@ -2,7 +2,7 @@ import nnscaler -from .trainer import Trainer +from nnscaler.cli.trainer import Trainer if __name__ == '__main__': diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index c7e19e33..fb3fd99d 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -133,6 +133,7 @@ def _load_dummy_input(self): return next(iter(dataloader)) def _setup(self): + self.train_args.init_env() compile_only = self.train_args.run_mode == 'compile' if not compile_only: nnscaler.init() @@ -154,6 +155,12 @@ def _create_model(): # create dataset and dataloader for stage in ['train', 'val', 'test']: self.dataset[stage] = self.train_args.create_dataset(stage) + + # load a dummy input from training dataset + self.dummy_input = self._load_dummy_input() + self.dummy_input = self._fix_input(self.dummy_input) + + for stage in ['train', 'val', 'test']: self.dataloader[stage] = self.train_args.create_dataloader(stage, self.dataset[stage]) if self.dataloader[stage] is not None \ and not self.dataloader[stage].drop_last \ @@ -165,13 +172,14 @@ def _create_model(): f"You can specify `drop_last=True` in DataLoader to fix this problem." ) - # load a dummy input from training dataset - self.dummy_input = self._load_dummy_input() - self.dummy_input = self._fix_input(self.dummy_input) - # setup compute config compute_config = copy.deepcopy(self.train_args.compute_config) compute_config.pas_config['__pas_name'] = self.train_args.pas_policy + # autodist configs + compute_config.pas_config['update_freq'] = self.train_args.update_freq + compute_config.pas_config['use_bf16'] = self.train_args.bf16 + compute_config.pas_config['use_fp16'] = self.train_args.fp16 + compute_config.user_config['__from_trainer_args'] = { 'mbs': self.train_args.micro_batch_size, 'gbs': self.train_args.global_batch_size, @@ -203,6 +211,7 @@ def _create_model(): self.total_train_steps_per_epoch += 1 # will add extra dummy batches _, self.sync_group = self.train_args.compute_config.get_sync_group() self.model = pmodel_class() + self.model.cuda() self.optimizer = self.train_args.create_parallel_optimizer(self.model) self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) self.loggers = self.train_args.create_loggers() @@ -447,7 +456,7 @@ def train(self): next_batch_index = self.train_status.next_batch_index self.hook.on_train_start(self) - for epoch in range(self.train_status.epoch, self.train_args.max_epochs): + for epoch in range(self.train_status.epoch, self.train_args.max_epochs or sys.maxsize): self.dataloader['train'].sampler.set_epoch(epoch) torch.distributed.barrier() @@ -487,7 +496,6 @@ def _validate(self, step_stat: _StepStat): logger.info('No val dataset specified. Validation skipped.') return step_stat.train_loss - logger.info('Validating...') data_iter = enumerate(self._global_batch_iterator(stage='val')) if self.rank == 0: total_val_steps_per_epoch = len(self.dataloader['val']) // self.train_args.update_freq diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 6efe6d7d..0cafa6b9 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -279,13 +279,17 @@ class TrainerArgs: fp16: bool = False bf16: bool = False micro_batch_size: int = 1 - # default is self.micro_batch_size*self.scaling_factor - # which means update_freq is 1 + # You can set one of `global_batch_size` and `grad_accumulation_steps` option. + # Please note if both are set, they must be consistent. + # default is + # global_batch_size = self.micro_batch_size*self.scaling_factor + # grad_accumulation_steps = 1 global_batch_size: Optional[int] = None + grad_accumulation_steps: Optional[int] = None - max_epochs: int = 1000 - max_train_steps: int = 0 - max_val_steps: int = 0 + max_epochs: Optional[int] = None + max_train_steps: Optional[int] = None + max_val_steps: Optional[int] = None # validation frequency val_every_n_train_steps: Optional[int] = None @@ -293,20 +297,35 @@ class TrainerArgs: enable_progress_bar: bool = True + seed: Optional[int] = None + # environment initialization function + # you can put your environment initialization code here + init_env_fn: str = None + def __post_init__(self): if not self.compute_config: raise ValueError("compute_config is required") if not self.compute_config.use_end2end: raise ValueError("use_end2end must be True") - if not self.global_batch_size: + + if not self.global_batch_size and not self.grad_accumulation_steps: self.global_batch_size = self.micro_batch_size*self.scaling_factor - if self.global_batch_size % (self.micro_batch_size*self.scaling_factor) != 0: - raise ValueError(f"`global_batch_size` {self.global_batch_size} is not divisible by `micro_batch_size*(runtime_ngpus/plan_ngpus)` " - f"which is {self.micro_batch_size * self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus}") + self.grad_accumulation_steps = 1 + elif not self.global_batch_size: + self.global_batch_size = self.micro_batch_size*self.scaling_factor*self.grad_accumulation_steps + elif not self.grad_accumulation_steps: + self.grad_accumulation_steps = self.global_batch_size // (self.micro_batch_size*self.scaling_factor) + + if self.global_batch_size != self.micro_batch_size*self.scaling_factor*self.grad_accumulation_steps: + raise ValueError(f"`global_batch_size` {self.global_batch_size} is not equal to `micro_batch_size*scaling_factor*grad_accumulation_steps` " + f"{self.micro_batch_size*self.scaling_factor*self.grad_accumulation_steps}") + if self.run_mode not in ('compile', 'run'): raise ValueError(f"Invalid run_mode {self.run_mode}") if self.fp16 and self.bf16: raise ValueError("Cannot use both fp16 and bf16") + if not self.max_epochs and not self.max_train_steps: + raise ValueError("max_epochs or max_train_steps is required") if not self.model_config.type: raise ValueError("model type is required") if not self.optimizer_config.type: @@ -351,29 +370,28 @@ def from_yaml(cls, path: str) -> 'TrainerArgs': return cls.from_dict(yaml.safe_load(f)) @classmethod - def create_kwarg(cls, value: dict): - value = copy.deepcopy(value) - for k, v in value.items(): - if isinstance(v, dict): - value[k] = cls.create_kwarg(v) - elif isinstance(v, list): - value[k] = [cls.create_kwarg(i) for i in v] - elif isinstance(v, tuple): - value[k] = tuple(cls.create_kwarg(i) for i in v) - - if '__type' in value: - value_type = load_type(value.pop('__type')) - return value_type(**value) - elif '__value_type' in value: - if 'value' not in value: - raise ValueError("value is required when __value_type is present") - value_type = value.pop('__value_type') - if value_type == 'function': # when type is function, the value should be the full qualified name of the function - return load_type(value['value']) + def create_kwarg(cls, value: Any): + if isinstance(value, dict): + value = {k: cls.create_kwarg(v) for k, v in value.items()} + if '__type' in value: + value_type = load_type(value.pop('__type')) + return value_type(**value) + elif '__value_type' in value: + if 'value' not in value: + raise ValueError("value is required when __value_type is present") + value_type = value.pop('__value_type') + if value_type == 'function': # when type is function, the value should be the full qualified name of the function + return load_type(value['value']) + else: + # call its __init__ function + value_type = load_type(value_type) + return value_type(value['value']) else: - # call its __init__ function - value_type = load_type(value_type) - return value_type(value['value']) + return value + elif isinstance(value, list): + return [cls.create_kwarg(i) for i in value] + elif isinstance(value, tuple): + return tuple(cls.create_kwarg(i) for i in value) else: return value @@ -395,6 +413,19 @@ def scaling_factor(self): def update_freq(self): return self.global_batch_size // self.micro_batch_size // self.scaling_factor + def init_env(self): + if self.seed is not None: + import random + import numpy as np + torch.manual_seed(self.seed) + np.random.seed(self.seed) + random.seed(self.seed) + + if self.init_env_fn is None: + return + init_env_fn = load_type(self.init_env_fn) + init_env_fn(self) + def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model_config.args) return self.model_type(**kwargs) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index a022dd03..83163346 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -65,7 +65,7 @@ @dataclass(frozen=True) class ComputeConfig: plan_ngpus: int - runtime_ngpus: int + runtime_ngpus: Optional[int] = None # whether to fold constant when generating code constant_folding: bool = False @@ -128,6 +128,10 @@ class ComputeConfig: def __post_init__(self): if self.plan_ngpus <= 0: raise ValueError(f"plan_ngpus {self.plan_ngpus} must be > 0") + if self.runtime_ngpus is None: + super().__setattr__('runtime_ngpus', int(os.environ.get('WORLD_SIZE', 0))) + if not self.runtime_ngpus: + raise ValueError(f"runtime_ngpus is not set and WORLD_SIZE is not set.") if self.runtime_ngpus <= 0: raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be > 0") if self.runtime_ngpus % self.plan_ngpus != 0: From 167704aa4661ad76a303a54421758d652c0fa51b Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 25 Jul 2024 09:11:31 +0000 Subject: [PATCH 1688/1892] Merged PR 2210: Minitrainer: refine names / precision support --- examples/nanogpt/train_cli_args.yaml | 15 ++- nnscaler/cli/arg_parser.py | 5 +- nnscaler/cli/trainer.py | 80 ++++++----- nnscaler/cli/trainer_args.py | 124 +++++++++++------- tests/cli/test_arg_parser.py | 32 +++-- tests/cli/test_trainer.py | 58 ++++---- tests/cli/trainer_args.yaml | 8 +- .../lightning/pytorch/test_strategy.py | 14 +- 8 files changed, 201 insertions(+), 135 deletions(-) diff --git a/examples/nanogpt/train_cli_args.yaml b/examples/nanogpt/train_cli_args.yaml index 351089aa..0e10681f 100644 --- a/examples/nanogpt/train_cli_args.yaml +++ b/examples/nanogpt/train_cli_args.yaml @@ -12,9 +12,10 @@ grad_accumulation_steps: 1 max_train_steps: 5000 max_val_steps: 200 val_every_n_train_steps: 250 -enable_progress_bar: True +enable_progress_bar: true +# precision: bf16 -model_config: +model: type: examples.nanogpt.train_cli.Model args: init_from: scratch @@ -26,7 +27,7 @@ model_config: block_size: 256 meta_path: ./nanoGPT/data/shakespeare_char/meta.pkl -optimizer_config: +optimizer: type: torch.optim.AdamW args: lr: 1e-3 @@ -36,7 +37,7 @@ optimizer_config: fused: true clip_gnorm: 0.0 -lr_scheduler_config: +lr_scheduler: type: examples.nanogpt.train_cli.Scheduler args: warmup_iters: 100 @@ -45,7 +46,7 @@ lr_scheduler_config: min_lr: 1e-4 interval: step -dataset_config: +dataset: type: examples.nanogpt.train_cli.NanoGptDataset train_args: data_dir: ./nanoGPT/data/shakespeare_char @@ -56,11 +57,11 @@ dataset_config: split: val block_size: 256 -checkpoint_config: +checkpoint: keep_last_n_checkpoints: 10 every_n_train_steps: 250 save_type: deduped -# hook_config: +# hook: # on_train_step_end: examples.nanogpt.train_cli.on_train_step_end # on_val_step_end: examples.nanogpt.train_cli.on_val_step_end diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index 11b74606..008f0dc4 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -60,7 +60,10 @@ def _fix_optional(type_info): or (UnionType and isinstance(type_info, UnionType)): args = getattr(type_info, '__args__', None) if len(args) != 2 or (args[1] != type(None) and args[0] != type(None)): - raise ValueError(f"Invalid optional type {type_info}") + # when multiple types are allowed, + # we don't do any conversion + # let's the user to handle it + return Any if args[1] == type(None): return _fix_optional(args[0]) else: diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index fb3fd99d..d5fb2e23 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -106,10 +106,8 @@ def _fix_input(self, input): elif isinstance(input, tuple): return tuple(self._fix_input(v) for v in input) elif isinstance(input, torch.Tensor): - if self.train_args.fp16: - return input.half().cuda() - elif self.train_args.bf16: - return input.bfloat16().cuda() + if input.is_floating_point() and self.train_args.input_dtype is not None: + return input.to(self.train_args.input_dtype).cuda() else: return input.cuda() return input @@ -144,10 +142,27 @@ def _setup(self): def _create_model(): model = self.train_args.create_model() - if self.train_args.fp16: - model = model.half() - elif self.train_args.bf16: - model = model.bfloat16() + if self.train_args.param_dtype == self.train_args.buffer_dtype: + if self.train_args.param_dtype is not None: + model = model.to(self.train_args.param_dtype) + else: + # separate param and buffer dtype + # TODO: a little hacky. A better way? + # 3 kinds of tensors are converted in Module._apply: + # model parameters, its grad, and buffer + # param_dtype controls the first two, (but grad is `None` here) + # and buffer_dtype controls the last one + buf_ids = { id(buf) for buf in model.buffers(recurse=True) } + if self.train_args.param_dtype is not None: + model._apply( + lambda t: t.to(self.train_args.param_dtype) + if t.is_floating_point() and id(t) not in buf_ids + else t) + if self.train_args.buffer_dtype is not None: + model._apply( + lambda t: t.to(self.train_args.buffer_dtype) + if t.is_floating_point() and id(t) in buf_ids + else t) if self.train_args.tracing_from_weights: model.load_state_dict(torch.load(self.train_args.tracing_from_weights)) return model @@ -177,15 +192,14 @@ def _create_model(): compute_config.pas_config['__pas_name'] = self.train_args.pas_policy # autodist configs compute_config.pas_config['update_freq'] = self.train_args.update_freq - compute_config.pas_config['use_bf16'] = self.train_args.bf16 - compute_config.pas_config['use_fp16'] = self.train_args.fp16 + compute_config.pas_config['use_bf16'] = self.train_args.param_dtype == torch.bfloat16 + compute_config.pas_config['use_fp16'] = self.train_args.param_dtype == torch.float16 compute_config.user_config['__from_trainer_args'] = { 'mbs': self.train_args.micro_batch_size, 'gbs': self.train_args.global_batch_size, - 'fp16': self.train_args.fp16, - 'bf16': self.train_args.bf16, - 'model_args': self.train_args.model_config.args, + 'precision': self.train_args.precision, + 'model_args': self.train_args.model.args, } # parallalize model @@ -231,7 +245,7 @@ def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): [s['optimizer'] for s in state_dicts] ) train_args = copy.deepcopy(state_dicts[0]['train_args']) - train_args['checkpoint_config']['save_type'] = 'merged' + train_args['checkpoint']['save_type'] = 'merged' merged_state_dict = { 'model': module_state_dict, 'optimizer': opt_state_dict, @@ -254,7 +268,7 @@ def _log_config(self, config: Dict): logger.setup(config) def _load_checkpoint(self): - resume_from = self.train_args.checkpoint_config.get_resume_checkpoint_dir() + resume_from = self.train_args.checkpoint.get_resume_checkpoint_dir() if not resume_from: return logger.info(f"Resuming from {resume_from}") @@ -264,7 +278,7 @@ def _load_checkpoint(self): resume_from = resume_from / f'{self.rank}.ckpt' state_dict = torch.load(resume_from, map_location='cpu') self.hook.on_load_checkpoint(self, state_dict) - ckpt_save_type = state_dict['train_args']['checkpoint_config']['save_type'] + ckpt_save_type = state_dict['train_args']['checkpoint']['save_type'] if ckpt_save_type == 'merged': # it is a merged state dict nnscaler.load_merged_state_dict( @@ -292,7 +306,7 @@ def _load_checkpoint(self): self.train_status = TrainStatus(**state_dict['train_status']) def _save_checkpoint(self, loss): - checkpoint_config = self.train_args.checkpoint_config + checkpoint_config = self.train_args.checkpoint if checkpoint_config.no_save: logger.info('Skip saving checkpoint because `no_save` is set to True') @@ -375,21 +389,21 @@ def _save_checkpoint(self, loss): self._expire_checkpoints() def _expire_checkpoints(self): - if not self.train_args.checkpoint_config.keep_last_n_checkpoints: # keep all + if not self.train_args.checkpoint.keep_last_n_checkpoints: # keep all return - save_dir = Path(self.train_args.checkpoint_config.save_dir) + save_dir = Path(self.train_args.checkpoint.save_dir) checkpoints = [ p.name for p in save_dir.glob('*') if p.is_dir() and p.name not in [CHECKPOINT_BEST_DIR_NAME, CHECKPOINT_LAST_DIR_NAME] ] - if len(checkpoints) <= self.train_args.checkpoint_config.keep_last_n_checkpoints: + if len(checkpoints) <= self.train_args.checkpoint.keep_last_n_checkpoints: return # (step, num) pairs checkpoint_info = [(int(p.split('-')[1]), p) for p in checkpoints] checkpoint_info.sort() - expire_list = checkpoint_info[:-self.train_args.checkpoint_config.keep_last_n_checkpoints] + expire_list = checkpoint_info[:-self.train_args.checkpoint.keep_last_n_checkpoints] best_ckpt = save_dir / CHECKPOINT_BEST_DIR_NAME if best_ckpt.exists(): @@ -467,7 +481,7 @@ def train(self): self.train_epoch(epoch) self.hook.on_epoch_end(self, epoch) - if self.lr_scheduler and self.train_args.lr_scheduler_config.interval == 'epoch': + if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'epoch': self.lr_scheduler.step() if self.train_args.max_train_steps and self.num_train_steps >= self.train_args.max_train_steps: @@ -582,7 +596,7 @@ def train_epoch(self, epoch): aggregate_outputs = self.train_args.aggregate_outputs or self.aggregate_outputs aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) - if self.train_args.optimizer_config.loss_reduction == 'mean': + if self.train_args.optimizer.loss_reduction == 'mean': loss = aggregated_outputs.loss_sum / aggregated_outputs.num_batches else: loss = aggregated_outputs.loss_sum @@ -594,23 +608,23 @@ def train_epoch(self, epoch): self.optimizer.sync_shard_grad() # scale gradients - if self.train_args.optimizer_config.grad_reduction == 'sum': + if self.train_args.optimizer.grad_reduction == 'sum': # do nothing. Already done in reducers pass - elif self.train_args.optimizer_config.grad_reduction == 'mean': + elif self.train_args.optimizer.grad_reduction == 'mean': if not aggregated_outputs.num_batches: raise RuntimeError("`aggregate_outputs` doesn't set `num_batches` field") self.optimizer.scale_grads(1.0 / aggregated_outputs.num_batches) else: - assert self.train_args.optimizer_config.grad_reduction == 'per-token-mean' + assert self.train_args.optimizer.grad_reduction == 'per-token-mean' if not aggregated_outputs.num_tokens: raise RuntimeError("`aggregate_outputs` doesn't set `num_tokens` field") self.optimizer.scale_grads(1.0 / aggregated_outputs.num_tokens) # clip gradients self.hook.before_gnorm_clip(self) - if self.train_args.optimizer_config.clip_gnorm: - step_stat.gnorm = self.optimizer.clip_gnorm(self.train_args.optimizer_config.clip_gnorm) + if self.train_args.optimizer.clip_gnorm: + step_stat.gnorm = self.optimizer.clip_gnorm(self.train_args.optimizer.clip_gnorm) else: step_stat.gnorm = self.optimizer.clip_gnorm() self.hook.after_gnorm_clip(self, step_stat.gnorm) @@ -621,12 +635,12 @@ def train_epoch(self, epoch): self.hook.before_optimizer_step(self) self.optimizer.step() self.hook.after_optimizer_step(self) - if self.lr_scheduler and self.train_args.lr_scheduler_config.interval == 'step': + if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'step': self.lr_scheduler.step() # validate and save checkpoint - if self.train_args.checkpoint_config.every_n_train_steps and \ - self.num_train_steps % self.train_args.checkpoint_config.every_n_train_steps == 0: + if self.train_args.checkpoint.every_n_train_steps and \ + self.num_train_steps % self.train_args.checkpoint.every_n_train_steps == 0: self._validate_and_save(step_stat) has_validated = VAL_STATUS_SAVE @@ -646,8 +660,8 @@ def train_epoch(self, epoch): else: # not from `break` if not has_validated: if self.train_args.max_epochs == self.train_status.epoch + 1 \ - or (self.train_args.checkpoint_config.every_n_epochs and \ - (self.train_status.epoch + 1) % self.train_args.checkpoint_config.every_n_epochs == 0): + or (self.train_args.checkpoint.every_n_epochs and \ + (self.train_status.epoch + 1) % self.train_args.checkpoint.every_n_epochs == 0): self._validate_and_save(step_stat) has_validated = VAL_STATUS_SAVE elif self.train_args.val_every_n_epochs and \ diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 0cafa6b9..9f27c4ab 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -1,6 +1,6 @@ from dataclasses import asdict, dataclass, field import importlib -from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union, get_args from pathlib import Path import logging import copy @@ -250,6 +250,16 @@ def __init__(self, hook_config: HookMapConfig): setattr(self, k, load_type(v)) +_TENSOR_TYPE = Literal['param', 'buffer', 'input'] +_PRECISION_TYPE = Literal['fp32', 'fp16', 'bf16', 'none'] +_PRECISION_MAP = { + 'fp32': torch.float32, + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + 'none': None # as it is. no conversion will happen. +} + + @dataclass class TrainerArgs: compute_config: ComputeConfig = None @@ -265,19 +275,20 @@ class TrainerArgs: # It is only used in tracing to serve as the initial state dict of the model. tracing_from_weights: str = None - model_config: ModelConfig = field(default_factory=ModelConfig) - optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig) - dataset_config: DatasetConfig = field(default_factory=DatasetConfig) - dataloader_config: DataloaderConfig = field(default_factory=DataloaderConfig) - dataset_sampler_config: DatasetSamplerConfig = field(default_factory=DatasetSamplerConfig) - lr_scheduler_config: Optional[LRSchedulerConfig] = None - checkpoint_config: CheckpointConfig = field(default_factory=CheckpointConfig) - log_config: List[LogConfig] = field(default_factory=list) + model: ModelConfig = field(default_factory=ModelConfig) + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + dataset: DatasetConfig = field(default_factory=DatasetConfig) + dataloader: DataloaderConfig = field(default_factory=DataloaderConfig) + dataset_sampler: DatasetSamplerConfig = field(default_factory=DatasetSamplerConfig) + lr_scheduler: Optional[LRSchedulerConfig] = None + checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) + log: List[LogConfig] = field(default_factory=list) # It can be `HookConfig` or `HookMapConfig` - hook_config: Any = None + hook: Union[HookConfig, HookMapConfig, None] = None + + # TODO: mixed precision support + precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None - fp16: bool = False - bf16: bool = False micro_batch_size: int = 1 # You can set one of `global_batch_size` and `grad_accumulation_steps` option. # Please note if both are set, they must be consistent. @@ -322,21 +333,34 @@ def __post_init__(self): if self.run_mode not in ('compile', 'run'): raise ValueError(f"Invalid run_mode {self.run_mode}") - if self.fp16 and self.bf16: - raise ValueError("Cannot use both fp16 and bf16") + + supported_precision_type = get_args(_PRECISION_TYPE) + supported_tensor_type = get_args(_TENSOR_TYPE) + if not self.precision: + self.precision = 'none' + if isinstance(self.precision, str): + self.precision = {k: self.precision for k in supported_tensor_type} + for tensor_type in supported_tensor_type: + if tensor_type not in self.precision: + self.precision[tensor_type] = 'none' + if self.precision[tensor_type] not in supported_precision_type: + raise ValueError(f"Invalid precision {self.precision[tensor_type]} for {tensor_type}") + if any(k not in supported_tensor_type for k in self.precision): + raise ValueError(f"Invalid tensor type found in {self.precision.keys()}") + if not self.max_epochs and not self.max_train_steps: raise ValueError("max_epochs or max_train_steps is required") - if not self.model_config.type: + if not self.model.type: raise ValueError("model type is required") - if not self.optimizer_config.type: + if not self.optimizer.type: raise ValueError("optimizer type is required") - if not self.dataset_config.type: + if not self.dataset.type: raise ValueError("dataset type is required") - if not self.dataloader_config.type: + if not self.dataloader.type: raise ValueError("dataloader type is required") - if not self.dataset_sampler_config.type: + if not self.dataset_sampler.type: raise ValueError("dataset_sampler type is required") - if self.lr_scheduler_config and not self.lr_scheduler_config.type: + if self.lr_scheduler and not self.lr_scheduler.type: raise ValueError("lr_scheduler type is required") @classmethod @@ -397,13 +421,13 @@ def create_kwarg(cls, value: Any): @property def model_type(self): - return load_type(self.model_config.type) + return load_type(self.model.type) @property def aggregate_outputs(self): - if not self.optimizer_config.aggregate_outputs_fn: + if not self.optimizer.aggregate_outputs_fn: return None - return load_type(self.optimizer_config.aggregate_outputs_fn) + return load_type(self.optimizer.aggregate_outputs_fn) @property def scaling_factor(self): @@ -413,6 +437,18 @@ def scaling_factor(self): def update_freq(self): return self.global_batch_size // self.micro_batch_size // self.scaling_factor + @property + def param_dtype(self) -> torch.dtype: + return _PRECISION_MAP[self.precision['param']] + + @property + def buffer_dtype(self) -> torch.dtype: + return _PRECISION_MAP[self.precision['buffer']] + + @property + def input_dtype(self) -> torch.dtype: + return _PRECISION_MAP[self.precision['input']] + def init_env(self): if self.seed is not None: import random @@ -427,38 +463,38 @@ def init_env(self): init_env_fn(self) def create_model(self) -> torch.nn.Module: - kwargs = self.create_kwarg(self.model_config.args) + kwargs = self.create_kwarg(self.model.args) return self.model_type(**kwargs) def create_parallel_optimizer(self, parallel_model: ParallelModule): - kwargs = self.create_kwarg(self.optimizer_config.args) - optimizer_class = load_type(self.optimizer_config.type) + kwargs = self.create_kwarg(self.optimizer.args) + optimizer_class = load_type(self.optimizer.type) return build_optimizer(parallel_model, optimizer_class, **kwargs) def create_dataset(self, stage='train'): - dataset_args = getattr(self.dataset_config, f'{stage}_args') + dataset_args = getattr(self.dataset, f'{stage}_args') if not dataset_args: return None kwargs = self.create_kwarg(dataset_args) - dataset_class = load_type(self.dataset_config.type) + dataset_class = load_type(self.dataset.type) dataset = dataset_class(**kwargs) if isinstance(dataset_class, torch.utils.data.IterableDataset): raise ValueError("IterableDataset is not supported") return dataset def create_sampler(self, dataset, stage='train'): - sampler_args = getattr(self.dataset_sampler_config, f'{stage}_args') - sampler_args = sampler_args or self.dataset_sampler_config.train_args + sampler_args = getattr(self.dataset_sampler, f'{stage}_args') + sampler_args = sampler_args or self.dataset_sampler.train_args kwargs = self.create_kwarg(sampler_args) kwargs['dataset'] = dataset kwargs['num_replicas'] = self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus kwargs['rank'] = torch.distributed.get_rank() // self.compute_config.plan_ngpus - sampler_class = load_type(self.dataset_sampler_config.type) + sampler_class = load_type(self.dataset_sampler.type) return sampler_class(**kwargs) def create_dataloader(self, stage='train', dataset=None): - dataloader_args = getattr(self.dataloader_config, f'{stage}_args') - dataloader_args = dataloader_args or self.dataloader_config.train_args + dataloader_args = getattr(self.dataloader, f'{stage}_args') + dataloader_args = dataloader_args or self.dataloader.train_args kwargs = self.create_kwarg(dataloader_args) if 'batch_size' in kwargs: raise ValueError("`batch_size` should not be specified in dataloader_args. " @@ -472,35 +508,35 @@ def create_dataloader(self, stage='train', dataset=None): kwargs['collate_fn'] = load_type(kwargs['collate_fn']) kwargs['batch_size'] = self.micro_batch_size kwargs['sampler'] = self.create_sampler(kwargs['dataset'], stage) - dataloader_class = load_type(self.dataloader_config.type) + dataloader_class = load_type(self.dataloader.type) return dataloader_class(**kwargs) def create_lr_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.LRScheduler: - if not self.lr_scheduler_config: + if not self.lr_scheduler: return None - kwargs = self.create_kwarg(self.lr_scheduler_config.args) - lr_scheduler_class = load_type(self.lr_scheduler_config.type) + kwargs = self.create_kwarg(self.lr_scheduler.args) + lr_scheduler_class = load_type(self.lr_scheduler.type) return lr_scheduler_class(optimizer, **kwargs) def create_loggers(self) -> List['LoggerBase']: loggers = [] - for log_config in self.log_config: + for log_config in self.log: kwargs = self.create_kwarg(log_config.args) logger_class = load_type(log_config.type) loggers.append(logger_class(**kwargs)) return loggers def create_hook(self) -> TrainHook: - if not self.hook_config: + if not self.hook: return TrainHook() # empty hook - if isinstance(self.hook_config, dict): - if 'type' in self.hook_config: - hook_config = HookConfig(**self.hook_config) + if isinstance(self.hook, dict): + if 'type' in self.hook: + hook_config = HookConfig(**self.hook) else: - hook_config = HookMapConfig(**self.hook_config) + hook_config = HookMapConfig(**self.hook) else: - hook_config = self.hook_config + hook_config = self.hook if isinstance(hook_config, HookConfig): kwargs = self.create_kwarg(hook_config.args) diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 99592aeb..0f4bd8df 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -42,11 +42,8 @@ def test_fix_type(): with pytest.raises(ValueError): _fix_type(List[str], True) - with pytest.raises(ValueError): - _fix_type(Union[bool, int]) - - with pytest.raises(ValueError): - _fix_type(Union[bool, int, None]) + assert _fix_type(Union[bool, int]) == None + assert _fix_type(Union[bool, int, None]) == None @pytest.mark.skipif(sys.version_info < (3, 10), reason='| is not available as union type for python < 3.10') @@ -59,11 +56,8 @@ def test_fix_type2(): with pytest.raises(ValueError): _fix_type(list[str], True) - with pytest.raises(ValueError): - _fix_type(bool|int) - - with pytest.raises(ValueError): - _fix_type(bool|int|None) + assert _fix_type(bool|int) == None + assert _fix_type(bool|int|None) == None @dataclass @@ -159,3 +153,21 @@ class A: x = parse_args(['--a.0=1', '--a.1=2', '--b.0.h=3', '--b.1.h=4', '--c.1=4']) y = deserialize_dataclass(x, A) assert y == A(a=[1, 2], b=[GConfig(h=3), GConfig(h=4)], c=(None, 4)) + + +def test_deserialize_union(): + @dataclass + class A: + p: Union[str, Dict[str, str], None] = None + + x = parse_args(['--p=hello']) + y = deserialize_dataclass(x, A) + assert y.p == 'hello' + + x = parse_args(['--p.a=a', '--p.b=b']) + y = deserialize_dataclass(x, A) + assert y.p == {'a': 'a', 'b': 'b'} + + x = parse_args(['--p.a=1', '--p.b=b']) + y = deserialize_dataclass(x, A) + assert y.p == {'a': 1, 'b': 'b'} # Dict[str, str] is ignored. so '1' will be converted to int diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index e17a0e23..22bc952d 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -24,15 +24,15 @@ def trainer_logging_worker(save_dir): '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', - '--checkpoint_config.no_save', 'true', - '--log_config.0.type', 'nnscaler.cli.loggers.TensorBoardLogger', - '--log_config.0.args.name', 'test-cli', - '--log_config.0.args.root_dir', str(tb_log_savedir), - '--log_config.1.type', 'nnscaler.cli.loggers.WandbLogger', - '--log_config.1.args.name', 'test-cli', - '--log_config.1.args.dir', str(wandb_log_savedir), - '--log_config.1.args.project', 'nnscaler', - '--log_config.1.args.mode', 'offline', + '--checkpoint.no_save', 'true', + '--log.0.type', 'nnscaler.cli.loggers.TensorBoardLogger', + '--log.0.args.name', 'test-cli', + '--log.0.args.root_dir', str(tb_log_savedir), + '--log.1.type', 'nnscaler.cli.loggers.WandbLogger', + '--log.1.args.name', 'test-cli', + '--log.1.args.dir', str(wandb_log_savedir), + '--log.1.args.project', 'nnscaler', + '--log.1.args.mode', 'offline', ]) trainer.train() @@ -65,16 +65,16 @@ def trainer_resume_worker(save_dir, save_type, bf16): # train 4 epcho in one time trainer = Trainer([ '-f', config_path, - '--bf16', str(bf16), + '--precision', 'bf16' if bf16 else 'none', '--max_epochs', '4', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', - '--checkpoint_config.save_type', save_type, - '--checkpoint_config.save_dir', str(ckpt_savedir), - '--checkpoint_config.resume_from', 'last', - '--checkpoint_config.keep_last_n_checkpoints', '30', + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.train() ckpt_files = set(ckpt_savedir.glob('**/*.ckpt')) @@ -85,16 +85,16 @@ def trainer_resume_worker(save_dir, save_type, bf16): # first two epochs trainer = Trainer([ '-f', config_path, - '--bf16', str(bf16), + '--precision', 'bf16' if bf16 else 'none', '--max_epochs', '2', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', - '--checkpoint_config.save_type', save_type, - '--checkpoint_config.save_dir', str(ckpt0_savedir), - '--checkpoint_config.resume_from', 'last', - '--checkpoint_config.keep_last_n_checkpoints', '30', + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.train() ckpt0_files0 = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} @@ -109,16 +109,16 @@ def trainer_resume_worker(save_dir, save_type, bf16): # continue with the last two epochs (resume for sharded/deduped checkpoint) trainer = Trainer([ '-f', config_path, - '--bf16', str(bf16), + '--precision', 'bf16' if bf16 else 'none', '--max_epochs', '4', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', - '--checkpoint_config.save_type', save_type, - '--checkpoint_config.save_dir', str(ckpt0_savedir), - '--checkpoint_config.resume_from', 'last', - '--checkpoint_config.keep_last_n_checkpoints', '30', + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.train() left_files = { @@ -137,15 +137,15 @@ def trainer_resume_worker(save_dir, save_type, bf16): # continue with the last two epochs (resume for merged) trainer = Trainer([ '-f', config_path, - '--bf16', str(bf16), + '--precision', 'bf16' if bf16 else 'none', '--max_epochs', '4', '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', - '--checkpoint_config.save_type', save_type, - '--checkpoint_config.save_dir', str(ckpt1_savedir), - '--checkpoint_config.resume_from', str(ckpt1_savedir / 'merged.pt'), - '--checkpoint_config.keep_last_n_checkpoints', '30', + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt1_savedir), + '--checkpoint.resume_from', str(ckpt1_savedir / 'merged.pt'), + '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.train() left_files = { diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml index e11c3048..67a4d82e 100644 --- a/tests/cli/trainer_args.yaml +++ b/tests/cli/trainer_args.yaml @@ -12,18 +12,18 @@ global_batch_size: 8 max_epochs: 4 max_train_steps: 100 -model_config: +model: type: tests.cli.common.MLP args: dim: 16 nlayers: 16 -optimizer_config: +optimizer: type: torch.optim.Adam args: lr: 0.01 -dataset_config: +dataset: type: tests.cli.common.SimpleDataset train_args: dim: 16 @@ -32,7 +32,7 @@ dataset_config: dim: 16 size: 10 -checkpoint_config: +checkpoint: keep_last_n_checkpoints: 30 every_n_train_steps: 1 save_type: deduped diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index 4a037731..c588fb80 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -244,10 +244,10 @@ def on_val_step_end(trainer: Trainer, outputs, batches, idx) -> None: pas_policy=policy, instance_name=f'cli_{policy}', enable_progress_bar=False, - model_config=ModelConfig( + model=ModelConfig( type=CorrectnessWorkerM, ), - dataset_config=DatasetConfig( + dataset=DatasetConfig( type=correctnes_worker_cli_dataset, train_args={ 'stage': 'train' @@ -256,24 +256,24 @@ def on_val_step_end(trainer: Trainer, outputs, batches, idx) -> None: 'stage': 'val' }, ), - dataset_sampler_config=DatasetSamplerConfig( + dataset_sampler=DatasetSamplerConfig( type='torch.utils.data.DistributedSampler', val_args={ 'shuffle': False, # lightning doesn't shuffle val set }, ), - optimizer_config=OptimizerConfig( + optimizer=OptimizerConfig( type=torch.optim.Adam, args={ 'lr': _correctnes_worker_model.lr }, clip_gnorm=gradient_clip_val, ), - checkpoint_config=CheckpointConfig( + checkpoint=CheckpointConfig( no_save=True, ), - lr_scheduler_config=lr_config, - hook_config=dict( + lr_scheduler=lr_config, + hook=dict( before_gnorm_clip=on_before_grad_clip, after_aggregate_train_step_outputs=after_aggregate_train_step_outputs, on_train_step_end=on_train_step_end, From c5b6dfbdd2fe79405b2a6c3150347c667f4c695c Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 29 Jul 2024 06:44:22 +0000 Subject: [PATCH 1689/1892] Merged PR 2212: add mixed precision f16 optimizer add mixed precision f16 optimizer --- nnscaler/parallel.py | 3 +- nnscaler/runtime/f16_optimizer.py | 128 ++++++++++++++++++ tests/cli/test_trainer.py | 12 +- tests/cli/trainer_args.yaml | 8 +- tests/runtime/test_f16_optimizer.py | 57 ++++++++ .../test_f16_optimizer_trainer_args.yaml | 45 ++++++ 6 files changed, 250 insertions(+), 3 deletions(-) create mode 100644 nnscaler/runtime/f16_optimizer.py create mode 100644 tests/runtime/test_f16_optimizer.py create mode 100644 tests/runtime/test_f16_optimizer_trainer_args.yaml diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 83163346..63571d95 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -2266,7 +2266,8 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g # TODO: can be slow? for k in state_indexes: keys = sorted(optimizer_state_dict['state'][k].keys()) - assert set(keys) == {'step', 'exp_avg', 'exp_avg_sq'} + # for mixed precision f16 optimizer, we will add custom keys + # assert set(keys) == {'step', 'exp_avg', 'exp_avg_sq'} keys.remove('step') # we have done step in previous. for key in keys: value = optimizer_state_dict['state'][k][key] diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py new file mode 100644 index 00000000..676649d1 --- /dev/null +++ b/nnscaler/runtime/f16_optimizer.py @@ -0,0 +1,128 @@ +import logging + +import torch + +logger = logging.getLogger(__name__) + + +class MixedPrecisionF16OptimizerMixin: + """ + A mixin class for mixed precision optimizer. + Support both FP16 and BF16 parameters. + + 1. It will create a copy of FP32 parameters and grads, + and use the FP32 copy for optimization (via `build_fp32_params`). + 2. It will sync FP16 grads to FP32 grads before optimizer.step(). + 3. It will sync FP32 params back to FP16 params after optimizer.step(). + 4. It will zero FP16 grads and FP32 grads to zero in zero_grad(). + + """ + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in mro(method resolution order) + super().__init__(*args, **kwargs) + + @classmethod + def build_fp32_params(cls, params): + # create FP32 copy of parameters and grads + fp32_params = [] + for p in params: + p32 = torch.nn.Parameter(p.data.float()) + p32.grad = torch.zeros_like(p32.data) + fp32_params.append(p32) + return fp32_params + + def step(self, closure=None): + """Performs a single optimization step.""" + self._sync_f16_grads_to_fp32() + super().step(closure) + self._sync_fp32_params_to_f16() + # No need to call gather_params here when zero is enabled, + # as the gathered params are not in the optimizer + + def zero_grad(self, set_to_none: bool = True): + """ + Clears the gradients of all optimized parameters. + Will ignore `set_to_none` and always set fp16 grads to None, and fp32 grads to zero. + """ + for p in self.f16_params: + p.grad = None + for p32 in self.fp32_params: + if p32.grad is not None: + p32.grad.zero_() + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = super().state_dict() + + # move fp32_params to the same level with 'exp_avg' and 'exp_avg_sq' + # we do this to handle the merge of sharded checkpoint in nnscaler + assert 'state' in state_dict, f'state not found in state_dict: {state_dict.keys()}' + assert isinstance(state_dict['state'], dict), f'state is not a dict: {type(state_dict["state"])}' + assert len(self.fp32_params) == len(state_dict['state']), \ + f'len(fp32_params) != len(state[state]): {len(self.fp32_params)} != {len(state_dict["state"])}' + assert 'exp_avg' in state_dict['state'][0], f'currently only verified for adam-like optimizer' + for key, value in state_dict['state'].items(): + assert self.fp32_params[key].shape == value['exp_avg'].shape, f'Shape mismatch: {value["exp_avg"].shape} vs {self.fp32_params[key].shape}' + value['fp32_params'] = self.fp32_params[key] + + return state_dict + + def load_state_dict(self, state_dict): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if 'state' in state_dict and len(state_dict['state']) > 0 and 'fp32_params' in state_dict['state'][0]: + logger.info('try to load fp32_params from state_dict in f16_optimizer') + assert isinstance(self.fp32_params, list), f'fp32_params is not a list: {type(self.fp32_params)}' + device = torch.cuda.current_device() + for i, param in enumerate(self.fp32_params): + ckpt_param = state_dict['state'][i]['fp32_params'] + assert param.shape == ckpt_param.shape, f'Shape mismatch: {param.shape} vs {ckpt_param.shape}' + logger.info(f'param {i}, fp16 norm: {param.data.detach().norm().item()}, fp32 norm: {ckpt_param.data.detach().norm().item()}') + param.data = state_dict['state'][i]['fp32_params'].data.to(device) + # pop to avoid store a redundant copy in the wrapped optimizer + state_dict['state'][i].pop('fp32_params') + + if len(self.param_groups) != 1: + raise RuntimeError('only support one param group') + self.param_groups[0]['params'] = self.fp32_params + + super().load_state_dict(state_dict) + + def _sync_f16_grads_to_fp32(self): + # copy FP16 grads to FP32 + for p, p32 in zip(self.f16_params, self.fp32_params): + if not p.requires_grad: + continue + if p.grad is not None: + if p32.grad is None: + p32.grad = p.grad.data.float() + else: + p32.grad.data.copy_(p.grad.data) + else: + p32.grad = torch.zeros_like(p.data, dtype=torch.float) + + def _sync_fp32_params_to_f16(self): + # copy FP32 params back into FP16 model + for p, p32 in zip(self.f16_params, self.fp32_params): + if not p.requires_grad: + continue + p.data.copy_(p32.data) + + +class MixedPrecisionAdam(MixedPrecisionF16OptimizerMixin, torch.optim.Adam): + def __init__(self, params, **kwargs): + self.f16_params = list(params) + self.fp32_params = self.build_fp32_params(self.f16_params) + super().__init__(self.fp32_params, **kwargs) + + +class MixedPrecisionAdamW(MixedPrecisionF16OptimizerMixin, torch.optim.AdamW): + def __init__(self, params, **kwargs): + self.f16_params = list(params) + self.fp32_params = self.build_fp32_params(self.f16_params) + super().__init__(self.fp32_params, **kwargs) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 22bc952d..444d93c2 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -63,9 +63,14 @@ def trainer_resume_worker(save_dir, save_type, bf16): gen_savedir = save_dir / 'gen' ckpt_savedir = save_dir / 'ckpt' # train 4 epcho in one time + optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' \ + if bf16 == 'Mixed' \ + else 'torch.optim.Adam' + trainer = Trainer([ '-f', config_path, '--precision', 'bf16' if bf16 else 'none', + '--optimizer.type', optimizer_type, '--max_epochs', '4', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), @@ -80,12 +85,14 @@ def trainer_resume_worker(save_dir, save_type, bf16): ckpt_files = set(ckpt_savedir.glob('**/*.ckpt')) assert len(ckpt_files)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last + torch.distributed.barrier() # train 4 epcho two times (resume from last) ckpt0_savedir = save_dir / 'ckpt0' # first two epochs trainer = Trainer([ '-f', config_path, '--precision', 'bf16' if bf16 else 'none', + '--optimizer.type', optimizer_type, '--max_epochs', '2', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), @@ -106,10 +113,12 @@ def trainer_resume_worker(save_dir, save_type, bf16): if trainer.rank == 0: Trainer.merge_checkpoint(list((ckpt0_savedir / 'last').glob('*.ckpt')), ckpt1_savedir / 'merged.pt') + torch.distributed.barrier() # continue with the last two epochs (resume for sharded/deduped checkpoint) trainer = Trainer([ '-f', config_path, '--precision', 'bf16' if bf16 else 'none', + '--optimizer.type', optimizer_type, '--max_epochs', '4', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), @@ -138,6 +147,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): trainer = Trainer([ '-f', config_path, '--precision', 'bf16' if bf16 else 'none', + '--optimizer.type', optimizer_type, '--max_epochs', '4', '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', @@ -184,6 +194,6 @@ def trainer_resume_worker(save_dir, save_type, bf16): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @pytest.mark.parametrize('save_type', ['sharded', 'deduped']) -@pytest.mark.parametrize('bf16', [True, False]) +@pytest.mark.parametrize('bf16', [True, False, 'Mixed']) def test_trainer_resume(tmp_path, save_type, bf16): launch_torchrun(4, trainer_resume_worker, tmp_path, save_type, bf16) diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml index 67a4d82e..272ce791 100644 --- a/tests/cli/trainer_args.yaml +++ b/tests/cli/trainer_args.yaml @@ -6,7 +6,7 @@ compute_config: use_end2end: true run_mode: run -pas_policy: autodist +pas_policy: tp micro_batch_size: 2 global_batch_size: 8 max_epochs: 4 @@ -32,6 +32,12 @@ dataset: dim: 16 size: 10 +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + checkpoint: keep_last_n_checkpoints: 30 every_n_train_steps: 1 diff --git a/tests/runtime/test_f16_optimizer.py b/tests/runtime/test_f16_optimizer.py new file mode 100644 index 00000000..f077f2ea --- /dev/null +++ b/tests/runtime/test_f16_optimizer.py @@ -0,0 +1,57 @@ +from pathlib import Path +import shutil + +import torch +import pytest +import torch.distributed + +from nnscaler.cli.trainer import Trainer +from tests.parallel_module.common import assert_close +from ..launch_torchrun import launch_torchrun + + +def trainer_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('test_f16_optimizer_trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + # train with normal mixed optimizer + ckpt0_savedir = save_dir / 'ckpt0' + trainer = Trainer([ + '-f', config_path, + '--optimizer.type', 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.train() + torch.distributed.barrier() + + # train with normal optimizer + ckpt1_savedir = save_dir / 'ckpt1' + trainer = Trainer([ + '-f', config_path, + '--optimizer.type', 'torch.optim.Adam', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + ]) + trainer.train() + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt') + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt') + # actually they are not close + # assert_close(x['model'], y['model']) + # assert_close(x['optimizer'], y['optimizer']) + # assert_close(x['lr_scheduler'], y['lr_scheduler']) + assert x['optimizer']['state'][0]['exp_avg'].dtype == torch.float32 + assert 'fp32_params' in x['optimizer']['state'][0] + assert x['optimizer']['state'][0]['fp32_params'].dtype == torch.float32 + assert y['optimizer']['state'][0]['exp_avg'].dtype == torch.bfloat16 + assert 'fp32_params' not in y['optimizer']['state'][0] + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_bf16(tmp_path): + launch_torchrun(2, trainer_worker, tmp_path) diff --git a/tests/runtime/test_f16_optimizer_trainer_args.yaml b/tests/runtime/test_f16_optimizer_trainer_args.yaml new file mode 100644 index 00000000..22895c57 --- /dev/null +++ b/tests/runtime/test_f16_optimizer_trainer_args.yaml @@ -0,0 +1,45 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 1 +max_train_steps: 10 +enable_progress_bar: false +precision: bf16 + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: torch.optim.Adam + args: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped From 67f0e813dbfe5f85922943f3d754d28665e66652 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 30 Jul 2024 02:23:26 +0000 Subject: [PATCH 1690/1892] Merged PR 2214: lightning: add merged checkpoint support --- docs/source/pytorch_lightning.md | 13 +---- .../integration/lightning/pytorch/strategy.py | 58 +++++++++++++++---- .../lightning/pytorch/test_strategy.py | 17 +++++- 3 files changed, 65 insertions(+), 23 deletions(-) diff --git a/docs/source/pytorch_lightning.md b/docs/source/pytorch_lightning.md index 918d9473..a66bc082 100644 --- a/docs/source/pytorch_lightning.md +++ b/docs/source/pytorch_lightning.md @@ -68,19 +68,12 @@ Just like other pytorch lightning strategy, you can resume from a checkpoint by specifying the `ckpt_path` argument in the `Trainer.fit` function. Please note when the parallel plan is changed (i.e you re-trace the model with different configurations), the checkpoints become incompatible, and can't be loaded any more. -You must firstly merge the checkpoints to a merged checkpoint and load it as a pretrained model. +You must firstly merge the checkpoints to a merged checkpoint with `NnScalerStrategy.merge_checkpoint` and then load the merged checkpoint as a regular checkpoint. -You can also merge all checkpoints (saved by each rank) to a complete checkpoint by using the `nnscaler.merge_state_dicts` function. ```python -import nnscaler -from pathlib import Path -state_dicts = [] -CHECKPOINT_DIR = Path(...) -for rank in range(world_size): - state_dicts.append(torch.load(CHECKPOINT_DIR / f"{rank}.pt")['state_dict']) -merged_state_dict, _ = nnscaler.merge_state_dicts(state_dicts) -torch.save(merged_state_dict, CHECKPOINT_DIR / "merged.pt") +def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str) -> None: ``` +where `checkpoint_files` is a list of checkpoint files to merge, and `output_file` is the output file path. ## Limitation diff --git a/nnscaler/integration/lightning/pytorch/strategy.py b/nnscaler/integration/lightning/pytorch/strategy.py index 2dd184ab..83477023 100644 --- a/nnscaler/integration/lightning/pytorch/strategy.py +++ b/nnscaler/integration/lightning/pytorch/strategy.py @@ -81,6 +81,12 @@ class NnScalerStrategy(ParallelStrategy): """ strategy_name = "nnscaler" _registered_strategies: List[str] = [] + _nnscaler_extra_state_key = 'nnscaler-extra-state' + _state_dict_type_key = 'state-dict-type' + _pl_module_name_key = 'pl_state_dict' # save some extra pl module states + _pmodule_attr_name = 'nnscaler_pmodule' + _module_name_key = 'state_dict' + _opt_name_key = 'optimizer_states' def __init__( self, @@ -116,12 +122,6 @@ def __init__( raise ValueError("The `pas_policy` must be provided to the `NnScalerStrategy`.") self._state_dict_type = state_dict_type - self._nnscaler_extra_state_key = 'nnscaler-extra-state' - self._state_dict_type_key = 'state-dict-type' - self._pl_module_name_key = 'pl_state_dict' # save some extra pl module states - self._pmodule_attr_name = 'nnscaler_pmodule' - self._module_name_key = 'state_dict' - self._opt_name_key = 'optimizer_states' @override def setup_environment(self) -> None: @@ -507,7 +507,13 @@ def load_checkpoint( assert self.model is not None assert self.lightning_module is not None - state_dict: dict = torch.load(path / f'{self.global_rank}.pt') + if not path.exists(): + raise FileNotFoundError(f"Checkpoint file {path} not found.") + + if path.is_dir(): + state_dict: dict = torch.load(path / f'{self.global_rank}.pt') + else: + state_dict: dict = torch.load(path) nnscaler_extra_state = state_dict.pop(self._nnscaler_extra_state_key) # load the extra states of the pl module self._lightning_module.load_state_dict(nnscaler_extra_state[self._pl_module_name_key], strict=False) @@ -524,15 +530,43 @@ def load_checkpoint( module = getattr(self._lightning_module, self._pmodule_attr_name) optimizer = self.optimizers[0] if self.optimizers else None - if state_dict_type == "deduped": + if state_dict_type == 'merged': + nnscaler.load_merged_state_dict(module, module_dict, optimizer, optimizer_dict) + elif state_dict_type == "deduped": nnscaler.load_deduped_state_dict(module, module_dict, optimizer, optimizer_dict) - else: - module.load_state_dict(module_dict) - if optimizer_dict is not None: - optimizer.load_state_dict(optimizer_dict) + else: # sharded + nnscaler.load_sharded_state_dict(module, module_dict, optimizer, optimizer_dict) return state_dict + @classmethod + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str) -> None: + """ + Merge the checkpoint files into a single checkpoint file. + + Args: + checkpoint_files: The list of checkpoint files to merge. + output_file: The output file path. + """ + state_dicts = [torch.load(f, map_location='cpu') for f in checkpoint_files] + + module_state_dicts = [s[cls._module_name_key] for s in state_dicts] + opt_state_dicts = None + if cls._opt_name_key in state_dicts[0]: + opt_state_dicts = [s[cls._opt_name_key][0] for s in state_dicts] + + merged_module_state_dict, merged_opt_state_dict = nnscaler.merge_state_dicts( + module_state_dicts, + opt_state_dicts + ) + merged_state_dict = state_dicts[0] + # reuse all states from the first checkpoint except the module and optimizer states + merged_state_dict[cls._module_name_key] = merged_module_state_dict + merged_state_dict[cls._opt_name_key] = [merged_opt_state_dict] + merged_state_dict[cls._nnscaler_extra_state_key][cls._state_dict_type_key] = 'merged' + + torch.save(merged_state_dict, output_file) + @classmethod @override def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index c588fb80..04fdcc41 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -346,6 +346,10 @@ def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_s else: model = ClassificationModel() state_dict_type = 'deduped' + if gradient_clip_val: + do_merge = True + else: + do_merge = False if with_tp: compute_config=ComputeConfig(2, 4) policy = 'tp' @@ -374,6 +378,17 @@ def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_s ) trainer.fit(model, datamodule=dm) + torch.distributed.barrier() + if do_merge: + resume_ckpt = Path(tmp_path) / 'merged.ckpt' + ckpt_last_dir = Path(tmp_path) / 'last.ckpt' + ckpt_last_files = list(ckpt_last_dir.glob('**/*.pt')) + if torch.distributed.get_rank() == 0: + NnScalerStrategy.merge_checkpoint(ckpt_last_files, resume_ckpt) + else: + resume_ckpt = Path(tmp_path) / 'last.ckpt' + torch.distributed.barrier() + trainer = Trainer( default_root_dir=tmp_path, max_epochs=2, @@ -389,7 +404,7 @@ def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_s ), plugins=[NnScalerPrecision(precision, scaler=scaler)] ) - trainer.fit(model, datamodule=dm, ckpt_path='last') + trainer.fit(model, datamodule=dm, ckpt_path=resume_ckpt) return model.update_history, model.nnscaler_pmodule.fullmap, model.val_loss_history, model.loss_history From 6a069fae36b5065cd569923cb826858a0ca92b9a Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 30 Jul 2024 02:53:25 +0000 Subject: [PATCH 1691/1892] Merged PR 2213: Refine ring flash attn: add llama 3.1's implementation --- .../README.md | 5 +- .../core/ring_attn_implementation.py | 263 ++++++++++++++++++ .../core/utils.py} | 27 +- .../core}/zigzag_attn_implementation.py | 2 +- examples/ring_attention/ring_attn.py | 101 +++++++ examples/ring_attention/test_ring_attn.py | 136 +++++++++ .../test_zigzag_attn.py | 12 +- .../zigzag_attn.py | 2 +- 8 files changed, 538 insertions(+), 10 deletions(-) rename examples/{zigzag_ring_attention => ring_attention}/README.md (75%) create mode 100644 examples/ring_attention/core/ring_attn_implementation.py rename examples/{zigzag_ring_attention/zigzag_utils/zigzag_utils.py => ring_attention/core/utils.py} (89%) rename examples/{zigzag_ring_attention/zigzag_utils => ring_attention/core}/zigzag_attn_implementation.py (99%) create mode 100644 examples/ring_attention/ring_attn.py create mode 100644 examples/ring_attention/test_ring_attn.py rename examples/{zigzag_ring_attention => ring_attention}/test_zigzag_attn.py (93%) rename examples/{zigzag_ring_attention => ring_attention}/zigzag_attn.py (97%) diff --git a/examples/zigzag_ring_attention/README.md b/examples/ring_attention/README.md similarity index 75% rename from examples/zigzag_ring_attention/README.md rename to examples/ring_attention/README.md index f7cf2b28..35f9fede 100644 --- a/examples/zigzag_ring_attention/README.md +++ b/examples/ring_attention/README.md @@ -1,4 +1,4 @@ -# zigzag ring attention +# ring attention Tensor parallel (partition head) is a widely used distributed plan to train large language models. Computation and memory are distributed evenly across devices. However, when the sequence length is extremely long (e.g., 1M), the partition degree of @@ -12,7 +12,10 @@ implements a high-performance version in PyTorch. This example attempts to integ The interface is wrapped in `zigzag_attn.py`. [flash attention](https://github.com/Dao-AILab/flash-attention) is required for this example. +In addition to the zigzag version, we also include a implementation based on [llama 3.1](https://ai.meta.com/research/publications/the-llama-3-herd-of-models/)'s technical report. This version uses `all_gather` and `reduce_scatter` to collect and distribute the kv values and gradients. You can check the code in `ring_attn.py`. + Test can be run with the following command: ```bash +torchrun --nproc_per_node 4 test_ring_attn.py torchrun --nproc_per_node 4 test_zigzag_attn.py ``` \ No newline at end of file diff --git a/examples/ring_attention/core/ring_attn_implementation.py b/examples/ring_attention/core/ring_attn_implementation.py new file mode 100644 index 00000000..e6731b54 --- /dev/null +++ b/examples/ring_attention/core/ring_attn_implementation.py @@ -0,0 +1,263 @@ +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from .utils import shuffle_input, recover_output, GlobalMemoryBuffer + + +_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + up_q = q[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + up_out, _, _, _, _, up_lse, _, _ = _flash_attn_forward( + up_q, + up_k, + up_v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + + down_q = q[:, block_len:] + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + down_out, _, _, _, _, down_lse, _, _ = _flash_attn_forward( + down_q, + down_k, + down_v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + + out = torch.cat([up_out, down_out], dim=1) + return out, up_lse, down_lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + up_lse, + down_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + dq = torch.zeros_like(q) + dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_dk") + dk_buffer.zero_() + dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_dv") + dv_buffer.zero_() + + up_q = q[:, :block_len] + up_out = out[:, :block_len] + up_dout = dout[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + _flash_attn_backward( + up_dout, + up_q, + up_k, + up_v, + up_out, + up_lse, + dq[:, :block_len], + dk_buffer[:, :(up_rank + 1) * block_len], + dv_buffer[:, :(up_rank + 1) * block_len], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + down_q = q[:, block_len:] + down_out = out[:, block_len:] + down_dout = dout[:, block_len:] + # TODO: optimize the buffer allocation + down_dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_down_dk") + down_dk_buffer.zero_() + down_dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_down_dv") + down_dv_buffer.zero_() + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + _flash_attn_backward( + down_dout, + down_q, + down_k, + down_v, + down_out, + down_lse, + dq[:, block_len:], + down_dk_buffer[:, :(down_rank + 1) * block_len], + down_dv_buffer[:, :(down_rank + 1) * block_len], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + dk_buffer.add_(down_dk_buffer) + dv_buffer.add_(down_dv_buffer) + + dim_size = list(k.size()) + dim_size[1] = dim_size[1] // world_size + dk = torch.empty(dim_size, dtype=k.dtype, device=k.device) + dv = torch.empty(dim_size, dtype=v.dtype, device=v.device) + dist._reduce_scatter_base(dk, dk_buffer, group=process_group) + dist._reduce_scatter_base(dv, dv_buffer, group=process_group) + + return dq, dk, dv + + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, all gather k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, reduce scatter dk, dv +''' +class RingFlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + q = shuffle_input(to_send=q, process_group=group) + world_size = dist.get_world_size(group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed._all_gather_base(k_buffer, k, group=group) + torch.distributed._all_gather_base(v_buffer, v, group=group) + + out, up_lse, down_lse = ring_flash_attn_forward( + group, + q, + k_buffer, + v_buffer, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, up_lse, down_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_output(out, process_group=group) + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + dout = shuffle_input(to_send=dout, process_group=ctx.group) + q, k, v, out, up_lse, down_lse = ctx.saved_tensors + world_size = dist.get_world_size(ctx.group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed._all_gather_base(k_buffer, k, group=ctx.group) + torch.distributed._all_gather_base(v_buffer, v, group=ctx.group) + + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k_buffer, + v_buffer, + out, + up_lse, + down_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + dq = recover_output(dq, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None diff --git a/examples/zigzag_ring_attention/zigzag_utils/zigzag_utils.py b/examples/ring_attention/core/utils.py similarity index 89% rename from examples/zigzag_ring_attention/zigzag_utils/zigzag_utils.py rename to examples/ring_attention/core/utils.py index a576db2f..ba1d1b61 100644 --- a/examples/zigzag_ring_attention/zigzag_utils/zigzag_utils.py +++ b/examples/ring_attention/core/utils.py @@ -1,9 +1,31 @@ from typing import Optional, Tuple +from functools import reduce +import operator import torch import torch.distributed as dist -__all__ = ["update_out_and_lse", "RingComm"] + +# copy from megatron/core/utils.py +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if ( + self.buffer.get((name, dtype), None) is None + or self.buffer[(name, dtype)].numel() < required_len + ): + self.buffer[(name, dtype)] = torch.empty( + required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) def update_out_and_lse( @@ -46,6 +68,7 @@ def _update_out_and_lse( out, lse = _update_out_and_lse(out, lse, block_out, block_lse) return out, lse + class RingComm: def __init__(self, process_group: dist.ProcessGroup): self._process_group = process_group @@ -94,6 +117,7 @@ def wait(self): self._reqs = None self._ops = [] + def shuffle_input(to_send: torch.Tensor, process_group: dist.ProcessGroup = None): @@ -159,6 +183,7 @@ def shuffle_input(to_send: torch.Tensor, return to_send_f + def recover_output(to_send: torch.Tensor, process_group: dist.ProcessGroup = None): diff --git a/examples/zigzag_ring_attention/zigzag_utils/zigzag_attn_implementation.py b/examples/ring_attention/core/zigzag_attn_implementation.py similarity index 99% rename from examples/zigzag_ring_attention/zigzag_utils/zigzag_attn_implementation.py rename to examples/ring_attention/core/zigzag_attn_implementation.py index 1e473301..b1cbfdba 100644 --- a/examples/zigzag_ring_attention/zigzag_utils/zigzag_attn_implementation.py +++ b/examples/ring_attention/core/zigzag_attn_implementation.py @@ -1,7 +1,7 @@ import torch import torch.distributed as dist from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward -from .zigzag_utils import RingComm, update_out_and_lse, shuffle_input, recover_output +from .utils import RingComm, update_out_and_lse, shuffle_input, recover_output ''' Assume we have 4 GPUs A, B, C, D. diff --git a/examples/ring_attention/ring_attn.py b/examples/ring_attention/ring_attn.py new file mode 100644 index 00000000..e6da2f3a --- /dev/null +++ b/examples/ring_attention/ring_attn.py @@ -0,0 +1,101 @@ +from typing import Tuple, List, Dict +import torch +from torch import Tensor +import torch.distributed + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir.operator import IRFwOperation +from core.ring_attn_implementation import RingFlashAttnFunc +from flash_attn import flash_attn_func + +import torch.distributed as dist +from nnscaler.runtime.device import DeviceGroup + +def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, + dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None) -> Tensor: + ''' + wrap the ring_attn_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + return output + + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + + output = RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ).contiguous() + + return output + +def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate ring_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + +register_op('bs l h dim^, bs l h dim^, bs l h dim^ -> bs l h dim^', emit_fn=emit_ring)(wrap_ring_attn_func) diff --git a/examples/ring_attention/test_ring_attn.py b/examples/ring_attention/test_ring_attn.py new file mode 100644 index 00000000..198a1f4a --- /dev/null +++ b/examples/ring_attention/test_ring_attn.py @@ -0,0 +1,136 @@ +import torch +import nnscaler +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType +import torch.distributed as dist +from flash_attn import flash_attn_func + +import nnscaler.graph +import nnscaler.graph.function +from ring_attn import wrap_ring_attn_func + +import random + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, _in0, _in1, _in2): + out = wrap_ring_attn_func(_in0, _in1, _in2) + return out + +def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: + ngpus = resource.plan_ngpus + partitioned = False + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == 'ring_attn.wrap_ring_attn_func': + print('Partitioned node: ', node) + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx=0, dim=1, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + assert partitioned, f'expect ring_attn_func in graph, but not found.' + return graph + +if __name__ == "__main__": + nnscaler.init() + rank_id = torch.distributed.get_rank() + world_size = dist.get_world_size() + + set_seed(rank_id) + bsz = 1 + seqlen = 8192 + nheads = 24 + d = 128 + + device = torch.device(f"cuda:{rank_id}") + # dtype = torch.float16 + dtype = torch.bfloat16 + + q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + + dist.broadcast(q, src=0) + dist.broadcast(k, src=0) + dist.broadcast(v, src=0) + dist.barrier() + + single_out = wrap_ring_attn_func(q, k, v) + single_out.retain_grad() + single_loss = single_out.sum() + single_loss.backward() + + model = TestModule() + + _in0 = q.detach().clone().requires_grad_() + _in1 = k.detach().clone().requires_grad_() + _in2 = v.detach().clone().requires_grad_() + + parallel_model = parallelize(model, dummy_forward_args={"_in0": _in0, "_in1": _in1, "_in2": _in2}, pas_policy=policy, + compute_config=ComputeConfig(world_size, world_size), reuse=ReuseType.OVERRIDE) + parallel_model = parallel_model.cuda() + + + parallel_model.train() + + _in0 = q.detach().clone().requires_grad_() + _in1 = k.detach().clone().requires_grad_() + _in2 = v.detach().clone().requires_grad_() + + para_out = parallel_model(_in0, _in1, _in2) + para_loss = para_out.sum() + para_loss.backward() + parallel_model.sync_grad() + + log("single out", single_out, rank0_only=True) + log("multi out", para_out, rank0_only=True) + log("out diff", single_out - para_out, rank0_only=True) + + log("single dq", q.grad, rank0_only=True) + log("multi dq", _in0.grad, rank0_only=True) + log("dq diff", q.grad - _in0.grad, rank0_only=True) + + log("single dk", k.grad, rank0_only=True) + log("multi dk", _in1.grad, rank0_only=True) + log("dk diff", k.grad - _in1.grad, rank0_only=True) + + log("single dv", v.grad, rank0_only=True) + log("multi dv", _in2.grad, rank0_only=True) + log("dv diff", v.grad - _in2.grad, rank0_only=True) diff --git a/examples/zigzag_ring_attention/test_zigzag_attn.py b/examples/ring_attention/test_zigzag_attn.py similarity index 93% rename from examples/zigzag_ring_attention/test_zigzag_attn.py rename to examples/ring_attention/test_zigzag_attn.py index 399db43b..e5eea77d 100644 --- a/examples/zigzag_ring_attention/test_zigzag_attn.py +++ b/examples/ring_attention/test_zigzag_attn.py @@ -8,7 +8,7 @@ import nnscaler.graph import nnscaler.graph.function -from examples.zigzag_ring_attention.zigzag_attn import wrap_zigzag_attn_func +from zigzag_attn import wrap_zigzag_attn_func import random @@ -56,7 +56,7 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: ngpus = resource.plan_ngpus partitioned = False for idx, node in enumerate(graph.select(ntype=IRFwOperation)): - if not partitioned and node.signature == 'examples.zigzag_ring_attention.zigzag_attn.wrap_zigzag_attn_func': + if not partitioned and node.signature == 'zigzag_attn.wrap_zigzag_attn_func': print('Partitioned node: ', node) sub_nodes = graph.partition( node, node.algorithms('dim'), idx=0, dim=1, num=ngpus) @@ -75,13 +75,13 @@ def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: set_seed(rank_id) bsz = 1 - seqlen = 3824 - nheads = 5 + seqlen = 8192 + nheads = 24 d = 128 device = torch.device(f"cuda:{rank_id}") - dtype = torch.float16 - # dtype = torch.bfloat16 + # dtype = torch.float16 + dtype = torch.bfloat16 q = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(bsz, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) diff --git a/examples/zigzag_ring_attention/zigzag_attn.py b/examples/ring_attention/zigzag_attn.py similarity index 97% rename from examples/zigzag_ring_attention/zigzag_attn.py rename to examples/ring_attention/zigzag_attn.py index 4e2a7bac..a20fcba0 100644 --- a/examples/zigzag_ring_attention/zigzag_attn.py +++ b/examples/ring_attention/zigzag_attn.py @@ -5,7 +5,7 @@ from nnscaler.graph.parser.register import register_op from nnscaler.ir.operator import IRFwOperation -from examples.zigzag_ring_attention.zigzag_utils.zigzag_attn_implementation import ZigZagRingFlashAttnFunc +from core.zigzag_attn_implementation import ZigZagRingFlashAttnFunc from flash_attn import flash_attn_func import torch.distributed as dist From 965907c638164993fe10e225b6d1cab8bc252b45 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 30 Jul 2024 03:19:01 +0000 Subject: [PATCH 1692/1892] Merged PR 2207: Fix loss related gencode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Loss is a special tensor in the computation graph. - requires_grad = True - the forward graph and backward graph share exactly a same tensor physically The main branch exists problem when partitioning the loss. Since the loss is a scalar tensor by default, it is partitioned along the value dimension. Assume we have a operator `nll_loss([1024, 2048], [1024]) -> [1]` with annotation `N+ C^, C^ ->1`. In LLM training, `N` is the token dim, `C` is the dictionary dim, partition along `N` will partition the loss along value. In the main branch, following code will be generated ![image (3).png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2207/attachments/image%20%283%29.png) Although it is runnable and correct, it breaks our definition of `IRSegment`, **the intermediate variable `nll_loss_10138` should not be passed out as an output tensor**. However, removing this sub-tensor directly does not solve the problem, since the real loss tensor is generated by an adapter `nnscaler.runtime.adapter.all_reduce`, which means its `requires_grad` field equals to `False` at runtime. In addition, the additional partitioned `nll_loss_10138` disappears at pipeline in the main branch. ![image (4).png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2207/attachments/image%20%284%29.png) Root causes are - when `gen_activations` is called to generate adapters, the returned adapter for the partitioned loss is wrong. It should be a `nnscaler.runtime.adapter.nn.allreduce_identity` instead of `nnscaler.runtime.adapter.all_reduce` - an additional compiling pass `Grouping` is called for spmd/tp. `Grouping` will dispatch the partitioned graph to each device and build an `IRSegment` for each device. - in the `create_segment` method, there is an additional check when determining the outputs: `isinstance(otensor, IRSubTensor) and otensor.is_loss()`. This check will add both of `nll_loss_10138` and `nll_loss_1955` to the segment's output. - `nll_loss_1955` is annotated with `requires_grad=False` and `grad=None`, `nll_loss_10138` is annotated `requires_grad=True` and `grad = gtensorxxx`. According to the logic in `get_backward_callsite_io_tensors`, `nll_loss_10138` will be recognized as the real loss to the backward graph. - However, in the pipeline code generation, there is no `Grouping` pass. The dispatch process (ExeReuseCell -> Segment -> IRCell) strictly follows the assumption that output of a segment should be a full tensor. To solve this problem, in this PR - generate correct adapters when the output loss is used in another operator (like the `.data` operation in fairseq's criterion) - choose tensor as the segment's output carefully to make the emit process runnable   parity check passed ![image.png](https://msrasrg.visualstudio.com/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/r... --- nnscaler/graph/gener/gen.py | 18 +++ nnscaler/graph/graph.py | 2 +- nnscaler/graph/segment.py | 35 ++++-- tests/compiler/test_compile.py | 4 + tests/graph/test_loss.py | 165 +++++++++++++++++++++++++++ tests/graph/test_segment.py | 59 ++++++++++ tests/parallel_module/test_pyfunc.py | 78 +++++++++++++ 7 files changed, 352 insertions(+), 9 deletions(-) create mode 100644 tests/graph/test_loss.py create mode 100644 tests/graph/test_segment.py create mode 100644 tests/parallel_module/test_pyfunc.py diff --git a/nnscaler/graph/gener/gen.py b/nnscaler/graph/gener/gen.py index 0ee20e83..07884a88 100644 --- a/nnscaler/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -2,6 +2,7 @@ import numpy as np import itertools import logging +import copy from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.gener.concurrent import ConcurrentGener @@ -381,6 +382,23 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: bctensors = bctensors + tuple(fwop.output(0).grad for fwop in input_producer[ftensor]) bctensors = expand_devices(bctensors, consumer=True) assert all(len(ctensor.device) == 1 for ctensor in bctensors), "Not support for multi-device" + # special case for loss tensor: + # 1) Since loss is the output of the whole graph, we don't have a backward producer node for loss. + # Therefore, bptensors is empty for loss tensor. + # 2) We must make sure bptensors to be non-empty to generate correct communication prims. If bptensor + # is empty, grad communication (the backward adapter) will not be generated, so only forward 'all-reduce' + # will be used. As a result, the loss tensor's requires_grad will be set to False at runtime. + # 3) According to loss's semantics in current deep learning, the backward prim should be `identity`. When + # the loss tensor is partitioned along the value dimension, since it is reduced by `add` operation, it is + # safe to use `identity` as the backward prim. + # 4) To generated `identity`, we follow the implementation at activation -> graph/segment output below: create + # dummy producer tensor and assign device information. Note, it is equivalent to copy bptensors from bctensors. + if ftensor.is_loss() and ftensor.requires_grad: + assert len(bptensors) == 0, f'expect no backward producer for loss tensor {ftensor}, but got {bproducers} with {bptensors}' + assert ftensor in output_consumer, f'expect loss tensor {ftensor} in output_consumer' + bptensors = tuple(fwop.input(0).grad for fwop in output_consumer[ftensor]) + bptensors = expand_devices(bptensors, producer=True) + fadapters = [] diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 48275d7a..8f0b8e18 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -193,7 +193,7 @@ def from_logic_graph(nodes: List[IRCell], if any(isinstance(t, IRSubTensor) and t.requires_grad for t in node.outputs()): requires_grad_pyfunc.append(node) if len(requires_grad_pyfunc) > 0: - dscp = (f'Cube does not support to compute gradients for IRPyFunc.\n' + dscp = (f'nnScaler does not support to compute gradients for IRPyFunc.\n' f'Following nodes require gradients, this may trigger error in backward:\n') for node in requires_grad_pyfunc: dscp += f'\t{node.signature}, cid: {node.cid}\n' diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index 39be3049..5e3b31a7 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -1,6 +1,7 @@ from contextlib import contextmanager from typing import Dict, Union, List, Optional, Set, Tuple, Any, Callable import numpy as np +import logging from nnscaler.ir.tensor import IRFullTensor, IRSubTensor, ValueMap from nnscaler.ir.cten import IRTensor, IRCell, IRObject @@ -11,6 +12,9 @@ from nnscaler.graph.function.pyfunc import IRPyFunc +_logger = logging.getLogger(__name__) + + class CellPosition: def __init__(self, indices: Tuple[int]): @@ -358,24 +362,39 @@ def infer_grad(self, ftensor: IRFullTensor) -> None: # set for producer for ptensor, producer in zip(self.ptensors(ftensor), self.producers(ftensor)): # filter out non-autograd operators of IRPyFunc - if isinstance(producer, IRPyFunc): continue + if isinstance(producer, IRPyFunc): + if fgrad is not None: + msg = f'nnScaler does not support backward of IRPyFunc: {producer}, ' + \ + 'skip setting gradient, please register it as IRDimOps.' + _logger.warning(msg) + continue grad = None if fgrad is None else fgrad.select(ptensor.indmap, (0, 1)) for t in producer.find(ptensor): t.grad = grad # set for consumers - consumers, ctensors = [], [] # consumers that require gradient + # We strictly follow the behavior in the fx graph. It is possible that there + # exists a node that consumes a tensor with gradient but generates tensors without + # gradient, e.g., the `.data` operation in torch. As a result, nnscaler will generate + # backward adapter (communications) between this consumer and its producer. + # According to the runtime behavior, we have + # case 1: there are gradients flowing in the consume full tensor. This case happens + # when the full tensor is the segment output at the same time. Note in nnscaler + # we will replicate segment's outputs and we will generate another adapter for + # the activation -> segment output case if the two adapters are different. As a + # result, the node (not matter IRDimOps or IRPyFunc, e.g., .data) should be replicated + # as well. In this case, the backward adapter is correct. + # case 2: no gradients exist, then the backward adapter does not influence the result. + consumers, ctensors = [], [] for ctensor, consumer in zip(self.ctensors(ftensor), self.consumers(ftensor)): itensors = consumer.find(ctensor) # set by default None for itensor in itensors: itensor.grad = None - # filter out non-autograd operators - if fgrad is None: continue - if isinstance(consumer, IRPyFunc): continue - if any(isinstance(t, IRSubTensor) and t.requires_grad for t in consumer.outputs()): + if fgrad is not None: consumers.append(consumer) ctensors.append(ctensor) + # set with value map curr_valmap = ValueMap((0, 1)) nconsumers = len(consumers) @@ -1001,8 +1020,8 @@ def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> I inputs.add(itensor) for otensor in node.outputs(): if not isinstance(otensor, IRObject): continue - # if the tensor is required by segment outputs or is loss during train, set as output - if (isinstance(otensor, IRSubTensor) and otensor.is_loss()) or otensor in segment_outputs: + # if the tensor is required by segment outputs, set as output + if otensor in segment_outputs: outputs.add(otensor) continue consumers, ctensors = self.consumers(otensor.parent), self.ctensors(otensor.parent) diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index f54f3f4b..66bdd515 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -94,6 +94,10 @@ def tensor_parallelism(node, idx, dim, num): graph.assign(sub_node, idx) return sub_nodes + # loss partition + for loss in graph.select(name='sum'): + tensor_parallelism(loss, idx=0, dim=0, num=ngpus) + l1, l2, l3, l4 = graph.select(name='linear') # l1 tensor parallelism diff --git a/tests/graph/test_loss.py b/tests/graph/test_loss.py new file mode 100644 index 00000000..3494726a --- /dev/null +++ b/tests/graph/test_loss.py @@ -0,0 +1,165 @@ +import pytest + +import tempfile +import torch +import math +import os +from pathlib import Path +from nnscaler.parallel import _gen_graph +from nnscaler.policies import _tp, _replica +from nnscaler.graph.gener.gen import IRAdapterGener +from nnscaler.execplan import ExecutionPlan +from nnscaler.execplan.planpass.fusion import DiffFusion +from nnscaler.execplan.planpass.grouping import Grouping +from nnscaler.ir.adapter.prim import AllReduceIdentityPrim, AllToAllAllToAllPrim, AllGatherSplitPrim +from nnscaler.codegen.emit import FuncEmission +from ..utils import replace_all_device_with + + +# in this test, we check following assumptions when partition the loss +# - the output of a parallel module should be same with the input module +# - the adapter should be AllReduceIdentityPrim + +class ModuleA(torch.nn.Module): + def __init__(self): + super(ModuleA, self).__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x): + x = self.fc(x) + x = x.sum() + return x + + +def pas_partition_loss_simple(graph): + dataloader = graph.nodes()[0] + linear = graph.nodes()[1] + loss = graph.nodes()[2] + _replica(graph, dataloader, [0, 1]) + _tp(graph, linear, [0, 1], 0, 0) + _tp(graph, loss, [0, 1], 0, 0) + return graph + + +class ModuleB(torch.nn.Module): + def __init__(self): + super(ModuleB, self).__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x): + x = self.fc(x) + x = x.sum() + y = x.data + return x, y + + +def pas_partition_loss_hard(graph): + dataloader = graph.nodes()[0] + linear = graph.nodes()[1] + loss = graph.nodes()[2] + get_attr = graph.nodes()[3] + _replica(graph, dataloader, [0, 1]) + _tp(graph, linear, [0, 1], 0, 0) + _tp(graph, loss, [0, 1], 0, 0) + # .data is automatically replicated since it is a IRPyFunc + return graph + +class ModuleC(torch.nn.Module): + def __init__(self): + super(ModuleC, self).__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x): + x = self.fc(x) + y = x + x + return x, y + + +def pas_parallel_module(graph): + linear = graph.nodes()[0] + add = graph.nodes()[1] + _tp(graph, linear, [0, 1], 0, 0) + _tp(graph, add, [0, 1], 0, 1) + return graph + + +def mini_compile_and_check(model_type, pas, checker, end2end_mode): + dummy_input = {'x': torch.randn(2, 10)} + model = model_type() + model.train() + + with tempfile.TemporaryDirectory() as tempdir: + init_graph, _ = _gen_graph(model, dummy_input, tempdir, constant_folding=True, end2end_mode=end2end_mode) + partitioned_graph = pas(init_graph) + adapter_graph = IRAdapterGener.gen(partitioned_graph, cost_fn=None) + execplan = ExecutionPlan.from_graph(adapter_graph) + execplan = DiffFusion.apply(execplan) + execplan = Grouping.apply(execplan) + checker(init_graph, partitioned_graph, adapter_graph, execplan) + + +@replace_all_device_with('cpu') +def test_loss_partition_simple(): + from nnscaler.ir.unique import IDGenerator + IDGenerator().clear() + + def checker(init_graph, partitioned_graph, adapter_graph, execplan): + fw_graph = execplan.seq(0)[1] + bw_graph = execplan.seq(0)[2] + adapter = fw_graph.nodes()[-1] + assert len(adapter.prims) == 1 + assert isinstance(adapter.prims[0], AllReduceIdentityPrim) + assert fw_graph.outputs() == init_graph.outputs() + emit = FuncEmission() + input_tensors, output_tensors, output_grads, input_grads = \ + emit.get_backward_callsite_io_tensors(bw_graph) + assert len(output_tensors) == 1 + assert output_tensors[0] == fw_graph.outputs()[0] + + mini_compile_and_check(ModuleA, pas_partition_loss_simple, checker, True) + + +@replace_all_device_with('cpu') +def test_loss_partition_hard(): + from nnscaler.ir.unique import IDGenerator + IDGenerator().clear() + + def checker(init_graph, partitioned_graph, adapter_graph, execplan): + fw_graph = execplan.seq(0)[1] + bw_graph = execplan.seq(0)[2] + adapter = fw_graph.nodes()[-2] + assert len(adapter.prims) == 1 + assert isinstance(adapter.prims[0], AllReduceIdentityPrim) + assert fw_graph.outputs() == init_graph.outputs() + emit = FuncEmission() + input_tensors, output_tensors, output_grads, input_grads = \ + emit.get_backward_callsite_io_tensors(bw_graph) + assert len(output_tensors) == 1 + assert output_tensors[0] == fw_graph.outputs()[0] + + mini_compile_and_check(ModuleB, pas_partition_loss_hard, checker, True) + + +@replace_all_device_with('cpu') +def test_segment_parallel_module(): + from nnscaler.ir.unique import IDGenerator + IDGenerator().clear() + + def checker(init_graph, partitioned_graph, adapter_graph, execplan): + # print(adapter_graph.nodes()) + fw_graph = execplan.seq(0)[0] + bw_graph = execplan.seq(0)[1] + # print(fw_graph.nodes()) + # print(bw_graph.nodes()) + adapter0 = fw_graph.nodes()[2] + adapter1 = fw_graph.nodes()[3] + adapter2 = fw_graph.nodes()[5] + assert(len(adapter0.prims) == 1) + assert(isinstance(adapter0.prims[0], AllToAllAllToAllPrim)) + assert(len(adapter1.prims) == 1) + assert(isinstance(adapter1.prims[0], AllGatherSplitPrim)) + assert(len(adapter2.prims) == 1) + assert(isinstance(adapter2.prims[0], AllGatherSplitPrim)) + assert fw_graph.outputs() == init_graph.outputs() + + mini_compile_and_check(ModuleC, pas_parallel_module, checker, False) diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py new file mode 100644 index 00000000..5af0678d --- /dev/null +++ b/tests/graph/test_segment.py @@ -0,0 +1,59 @@ +import nnscaler +import nnscaler.graph.function.function as F +from nnscaler.ir.tensor import IRFullTensor +from nnscaler.graph import IRGraph +from nnscaler.ir.adapter import IRAdapter + + +def _tensor(shape, requires_grad=True): + return IRFullTensor(shape, requires_grad=requires_grad).tosub() + + +def test_create_segment_loss_func(): + data = _tensor([256, 256], False) + w1 = _tensor([256, 256]) + out1 = _tensor([256, 256]) + matmul_1 = F.Linear(data, w1) + matmul_1.set_output(0, out1) + w2 = _tensor([256, 256]) + out2 = _tensor([256, 256]) + matmul_2 = F.Linear(out1, w2) + matmul_2.set_output(0, out2) + loss = _tensor([1]) + sum = F.Sum(out2) + sum.set_output(0, loss) + d = _tensor([1], False) + get = F.GetAttr(loss, 'data', 'getattr') + get.set_output(0, d) + nodes = [matmul_1, matmul_2, sum, get] + graph = IRGraph(nodes, [data], [loss, d], 'genmodel') + graph.backward(loss) + segment = graph.create_segment([matmul_2, sum, get]) + print(segment.extra_repr()) + assert len(segment.outputs()) == 2 + assert segment.output(0) == loss + assert segment.output(1) == d + + +def test_create_segment_loss_adapter(): + data = _tensor([256, 256], False) + w1 = _tensor([256, 256]) + out1 = _tensor([256, 256]) + matmul_1 = F.Linear(data, w1) + matmul_1.set_output(0, out1) + w2 = _tensor([256, 256]) + out2 = _tensor([256, 256]) + matmul_2 = F.Linear(out1, w2) + matmul_2.set_output(0, out2) + loss = _tensor([1]) + sum = F.Sum(out2) + sum.set_output(0, loss) + sum.device = 0 + adapter = IRAdapter([sum.output(0)], [sum.output(0)]) + nodes = [matmul_1, matmul_2, sum, adapter] + graph = IRGraph(nodes, [data], [loss], 'genmodel') + graph.backward(loss) + segment = graph.create_segment([matmul_2, sum, adapter]) + print(segment.extra_repr()) + assert len(segment.outputs()) == 1 + assert segment.output(0) == loss diff --git a/tests/parallel_module/test_pyfunc.py b/tests/parallel_module/test_pyfunc.py new file mode 100644 index 00000000..86a854ff --- /dev/null +++ b/tests/parallel_module/test_pyfunc.py @@ -0,0 +1,78 @@ +import pytest + +import tempfile +import torch +import math +import os +from pathlib import Path +from nnscaler.parallel import parallelize, ComputeConfig + +from .common import init_distributed +from ..utils import replace_all_device_with, catch_log +from ..launch_torchrun import launch_torchrun + + +class MyMatmul(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return x.mm(y) + + @staticmethod + def backward(ctx, grad_output): + x, y = ctx.saved_tensors + grad_x = grad_output.mm(y.t()) + grad_y = x.t().mm(grad_output) + return grad_x, grad_y + + +class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.weight = torch.nn.Parameter(torch.rand(10, 10)) + + def forward(self, x): + x = MyMatmul.apply(x, self.weight) + return x + + +def _worker(): + init_distributed() + + dummy_input = {'x': torch.rand(2, 10)} + from nnscaler.graph.parser.fx.parser import _logger as _logger_parser + from nnscaler.graph.graph import _logger as _logger_graph + from nnscaler.graph.segment import _logger as _logger_seg + with tempfile.TemporaryDirectory() as tempdir, \ + catch_log(_logger_parser) as log_stream_parser, \ + catch_log(_logger_seg) as log_stream_seg, \ + catch_log(_logger_graph) as log_stream_graph: + + m_new = parallelize( + MyModule(), + dummy_input, + 'dp', + ComputeConfig(1, 1, use_end2end=False), + gen_savedir=tempdir, + load_module=True + ) + parser_logs = log_stream_parser.getvalue() + seg_logs = log_stream_seg.getvalue() + graph_logs = log_stream_graph.getvalue() + # parser.py: parse_prim_function_method + assert 'non register python runtime function' in parser_logs + # segment.py: infer_grad + assert 'nnScaler does not support backward of IRPyFunc' in seg_logs + # graph.py: from_logic_graph + assert 'nnScaler does not support to compute gradients for IRPyFunc.' in graph_logs + + # not registered, encounter NameError + with pytest.raises(NameError): + logit = m_new(dummy_input['x']) + print(logit) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason='lack of gpu devices') +@replace_all_device_with('cpu') +def test_ir_pyfunc(): + launch_torchrun(1, _worker) From 65d84d6c8357c91172c1191c3ca701f7b90ecb73 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 30 Jul 2024 06:32:46 +0000 Subject: [PATCH 1693/1892] Merged PR 2216: fix out of disk error when running pipeline. Azure pipeline agent is limited to 10GB space, and cannot be changed. The current solution is to delete virtual env after every test config is done --- tox.ini | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tox.ini b/tox.ini index e455477d..ae7e80bd 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,7 @@ envlist = py38,py310 skipsdist = True [testenv] +allowlist_externals = rm passenv = * install_command = pip install {opts} {packages} deps = @@ -16,3 +17,4 @@ deps = commands = coverage erase pytest --cov={toxinidir}/nnscaler -x tests coverage html + rm -rf {envdir} From 9b8bdb75ba6b1212dbcc7f774be51464721e9cb0 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 31 Jul 2024 05:50:00 +0000 Subject: [PATCH 1694/1892] Merged PR 2217: scalar tensor: use right shape in generated module init code scalar tensor: use right shape in generated module init code unit test pass parity check pass --- nnscaler/codegen/module/module.py | 9 +++-- tests/parallel_module/test_gencode.py | 48 +++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 5a87abb0..f4eb993e 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -626,13 +626,13 @@ def init_attributes(self, node: IRCell): if itensor.is_param(): code = psign.format( name=self.tensor_name(itensor), - shape=tuple(itensor.shape), + shape=tuple(itensor.origin_shape), dtype=itensor.dtype ) elif itensor.is_buffer(): code = bsign.format( name=self.tensor_name(itensor), - shape=tuple(itensor.shape), + shape=tuple(itensor.origin_shape), dtype=itensor.dtype, persistent=itensor.is_persistent() ) @@ -640,13 +640,16 @@ def init_attributes(self, node: IRCell): raise RuntimeError(f"Unexpected tensor type: {itensor}") self.model_init_statements.append(code) slicers = tuple(slice(start, stop) for (start, stop) in itensor.indmap) + if itensor.is_scalar_tensor(): + assert len(slicers) == 1 and slicers[0] == slice(0, 1), f"Unexpected slicers {slicers} for scalar tensor." + slicers = '...' # Ellipsis slicer for scalar tensor, x[...] is equivalent to x val_chunks = itensor.valmap[1] code = map_sign.format( attr=self.tensor_name(itensor), tid=itensor.parent.tid, is_param=itensor.is_param(), orig_name=itensor.parent.name, - full_shape=tuple(itensor.parent.shape), + full_shape=tuple(itensor.parent.origin_shape), slicers=str(slicers), val_chunks=val_chunks ) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 8ff8de56..507f0ff4 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -984,8 +984,11 @@ def __init__(self): super().__init__() self.proj = torch.nn.Linear(1024, 1024, bias=False) self.scale = torch.nn.Parameter(torch.zeros(64)) + self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) + self.num_batches_tracked: torch.Tensor def forward(self, x): + self.num_batches_tracked.add_(1) x = self.proj(x) coef = torch.exp(torch.sum(self.scale, dim=-1)) x = x / coef @@ -1005,8 +1008,47 @@ def test_codegen_scalar_tensor(tmp_path): load_module=False, reuse='override', ) - # parallelize will succeed. - assert True + # the code will look like this: + # def __init__(self, init_params=True): + # super().__init__() + # # communication groups + + # self.register_buffer('num_batches_tracked_33', torch.empty((), dtype=torch.int64), persistent=True) + # self.add_full_map('num_batches_tracked_33', 2, False, 'num_batches_tracked', (), ..., 1) + + # self.register_parameter('proj_weight_35', torch.nn.Parameter(torch.empty((1024, 1024), dtype=torch.float32))) + # self.add_full_map('proj_weight_35', 4, True, 'proj.weight', (1024, 1024), (slice(0, 1024, None), slice(0, 1024, None)), 1) + + # self.register_parameter('scale_37', torch.nn.Parameter(torch.empty((64,), dtype=torch.float32))) + # self.add_full_map('scale_37', 8, True, 'scale', (64,), (slice(0, 64, None),), 1) + + + # self._post_init(init_params) + + # def segment41(self, x_43): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 990, in forward, self.num_batches_tracked.add_(1) + # add__34 = torch.Tensor.add_(self.num_batches_tracked_33, 1) + # del add__34 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 991, in forward, x = self.proj(x) + # linear_36 = torch.nn.functional.linear(x_43, self.proj_weight_35, bias=None) + # del x_43 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 992, in forward, coef = torch.exp(torch.sum(self.scale, dim=-1)) + # sum_1_38 = torch.sum(self.scale_37, dim=(-1,), keepdim=False) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 992, in forward, coef = torch.exp(torch.sum(self.scale, dim=-1)) + # exp_39 = torch.exp(sum_1_38) + # del sum_1_38 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 993, in forward, x = x / coef + # truediv_40 = torch.div(linear_36, exp_39, rounding_mode=None) + # del linear_36, exp_39 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 994, in forward, return x.sum() + # sum_2_32 = torch.sum(truediv_40) + # del truediv_40 + # return sum_2_32 + + assert _gencode_contains(tmp_path, ScalarTensorModule, 0, + r"self\.register_buffer\('num_batches_tracked_\d+', torch\.empty\(\(\), dtype=torch\.int64\), persistent=True\)") + assert _gencode_contains(tmp_path, ScalarTensorModule, 0, + r"self\.add_full_map\('num_batches_tracked_\d+', 2, False, 'num_batches_tracked', \(\), \.\.\., 1\)") class ConvTranspose1DModule(torch.nn.Module): @@ -1114,7 +1156,7 @@ def _gencode_conv2d_function_(tempdir): m_new = parallelize( Conv2DModule(weight, bias, groups=2), { - 'input': torch.randn(2, 6, 32, 32), + 'input': torch.randn(2, 6, 32, 32), 'groups': 2, }, 'dp', From f6448a5d8f29a6af58bf794bdba6419d55cc62c6 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 2 Aug 2024 01:56:18 +0000 Subject: [PATCH 1695/1892] Merged PR 2218: minitrainer: fix checkpoint bug (last checkpoint may not be saved) fix checkpoint bug --- nnscaler/cli/trainer.py | 54 ++++++++++++++++++++++----------------- tests/cli/test_trainer.py | 29 +++++++++++++++++++++ 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index d5fb2e23..cbaf4a44 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -50,10 +50,10 @@ class TrainStatus: @dataclass class _StepStat: - train_loss: float = None - val_loss: float = None - lr: float = None - gnorm: float = None + train_loss: Optional[float] = None + val_loss: Optional[float] = None + lr: Optional[float] = None + gnorm: Optional[float] = None class Trainer: @@ -485,11 +485,12 @@ def train(self): self.lr_scheduler.step() if self.train_args.max_train_steps and self.num_train_steps >= self.train_args.max_train_steps: - logger.info(f"Reached train steps({self.train_args.max_train_steps}): Training is done.") + logger.info(f"Reached max train steps({self.train_args.max_train_steps}): Training is done.") break next_batch_index = 0 - else: # not from `break` + else: # not break from for loop, which means not finished with max_train_steps + # finished with max_epochs logger.info(f"Reached max_epochs({self.train_args.max_epochs}): Training is done.") self._log_finalize() @@ -500,15 +501,20 @@ def _validate_and_save(self, step_stat: _StepStat): if self.dataloader['val'] is None: self._save_checkpoint(step_stat.train_loss) return - loss = self._validate(step_stat) + + if step_stat.val_loss is None: + self._validate(step_stat) # will update step_stat.val_loss internally + + loss = step_stat.val_loss self._save_checkpoint(loss) if self.train_status.best_loss > loss: self.train_status.best_loss = loss def _validate(self, step_stat: _StepStat): if self.dataloader['val'] is None: - logger.info('No val dataset specified. Validation skipped.') - return step_stat.train_loss + logger.info('No val dataset specified. Use train_loss as val_loss.') + step_stat.val_loss = step_stat.train_loss + return step_stat.val_loss data_iter = enumerate(self._global_batch_iterator(stage='val')) if self.rank == 0: @@ -556,7 +562,7 @@ def _validate(self, step_stat: _StepStat): step_stat.val_loss = loss self._log_metrics(asdict(step_stat), self.num_train_steps) - return loss + return step_stat.val_loss def train_epoch(self, epoch): VAL_STATUS_NO = 0 # not validated or saved @@ -574,8 +580,9 @@ def train_epoch(self, epoch): disable=not self.train_args.enable_progress_bar ) - step_stat = _StepStat() + step_stat: Optional[_StepStat] = None for idx, batches in data_iter: + step_stat = _StepStat() has_validated = VAL_STATUS_NO # the current batch is idx + resume_from_idx # `+1` because the next_batch_index is the index of the next batch @@ -648,7 +655,6 @@ def train_epoch(self, epoch): if not has_validated: self._validate_and_save(step_stat) has_validated = VAL_STATUS_SAVE - logger.info(f"Reached max_train_steps({self.train_args.max_train_steps}): Training is done.") break if not has_validated and self.train_args.val_every_n_train_steps and \ @@ -657,14 +663,16 @@ def train_epoch(self, epoch): has_validated = VAL_STATUS_VAL # import time # time.sleep(0.2) - else: # not from `break` - if not has_validated: - if self.train_args.max_epochs == self.train_status.epoch + 1 \ - or (self.train_args.checkpoint.every_n_epochs and \ - (self.train_status.epoch + 1) % self.train_args.checkpoint.every_n_epochs == 0): - self._validate_and_save(step_stat) - has_validated = VAL_STATUS_SAVE - elif self.train_args.val_every_n_epochs and \ - (self.train_status.epoch + 1) % self.train_args.val_every_n_epochs == 0: - self._validate(step_stat) - has_validated = VAL_STATUS_VAL + else: # not finished with max_train_steps + if step_stat is None: + return # no train step runs. No need to save checkpoint + if has_validated < VAL_STATUS_SAVE and \ + self.train_args.max_epochs == self.train_status.epoch + 1 \ + or (self.train_args.checkpoint.every_n_epochs and \ + (self.train_status.epoch + 1) % self.train_args.checkpoint.every_n_epochs == 0): + self._validate_and_save(step_stat) + has_validated = VAL_STATUS_SAVE + if not has_validated and self.train_args.val_every_n_epochs and \ + (self.train_status.epoch + 1) % self.train_args.val_every_n_epochs == 0: + self._validate(step_stat) + has_validated = VAL_STATUS_VAL diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 444d93c2..78efa2f6 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -197,3 +197,32 @@ def trainer_resume_worker(save_dir, save_type, bf16): @pytest.mark.parametrize('bf16', [True, False, 'Mixed']) def test_trainer_resume(tmp_path, save_type, bf16): launch_torchrun(4, trainer_resume_worker, tmp_path, save_type, bf16) + + +def trainer_last_checkpoint_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '1', + '--global_batch_size', '4', # mini_batch_size=2, update_freq=2 + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + '--val_every_n_train_steps', '1', + '--checkpoint.every_n_train_steps', '15', + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer.train() + + torch.distributed.barrier() + # make sure the last checkpoint is saved. + assert (ckpt_savedir / '0000-0025' / f'{trainer.rank}.ckpt').exists() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_trainer_last_checkpoint(tmp_path): + launch_torchrun(1, trainer_last_checkpoint_worker, tmp_path) From 7c3a85f7aa9c245e035df3a8c370ac434553b9c2 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 2 Aug 2024 03:01:09 +0000 Subject: [PATCH 1696/1892] Merged PR 2219: Add example chunk_linear_cross_entropy and refine autodist interface - add an example code to reduce the memory footprint when the sequence length and dictionary size is extremely large. It is verified in real model training - add an option `transient_mem_coef` to control the memory constraint --- .../chunk_linear_cross_entropy.py | 60 +++++++++++++++++++ .../test_chunk_linear_cross_entropy.py | 43 +++++++++++++ .../ring_attention/README.md | 0 .../core/ring_attn_implementation.py | 0 .../ring_attention/core/utils.py | 0 .../core/zigzag_attn_implementation.py | 0 .../ring_attention/ring_attn.py | 0 .../ring_attention/test_ring_attn.py | 0 .../ring_attention/test_zigzag_attn.py | 0 .../ring_attention/zigzag_attn.py | 0 nnscaler/autodist/autodist_config.py | 7 +++ nnscaler/autodist/spmd_solver.py | 16 +++-- 12 files changed, 120 insertions(+), 6 deletions(-) create mode 100644 examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py create mode 100644 examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py rename examples/{ => customized_ops}/ring_attention/README.md (100%) rename examples/{ => customized_ops}/ring_attention/core/ring_attn_implementation.py (100%) rename examples/{ => customized_ops}/ring_attention/core/utils.py (100%) rename examples/{ => customized_ops}/ring_attention/core/zigzag_attn_implementation.py (100%) rename examples/{ => customized_ops}/ring_attention/ring_attn.py (100%) rename examples/{ => customized_ops}/ring_attention/test_ring_attn.py (100%) rename examples/{ => customized_ops}/ring_attention/test_zigzag_attn.py (100%) rename examples/{ => customized_ops}/ring_attention/zigzag_attn.py (100%) diff --git a/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py b/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py new file mode 100644 index 00000000..d56a32b0 --- /dev/null +++ b/examples/customized_ops/chunk_linear_cross_entropy/chunk_linear_cross_entropy.py @@ -0,0 +1,60 @@ +import torch +import torch.utils.checkpoint as ckpt + + +def linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int = 0) -> torch.Tensor: + """ + Compute the cross entropy loss of a linear layer. + + Args: + + x: [token_num, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [token_num], the target token index + padding_idx: int, the index of padding token + + Returns: + + losses: [token_num], the cross entropy loss of each token + """ + logits = torch.nn.functional.linear(x, w) + normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) + losses = torch.nn.functional.nll_loss(normalized_logits, y, reduction='none', ignore_index=padding_idx) + return losses + + +def chunk_linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor: + """ + In order to reduce the memory usage when the sequence length and dictionary size are large, we can split the input + tensor into chunks and compute the cross entropy loss of each chunk separately. + You can register this function with annotation 'b l d^, n^ d^, b l -> b l'. + + Args: + + x: [bsz, seq_len, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [bsz, seq_len], the target token index + padding_idx: int, the index of padding token + chunk_size: int, the size of each chunk + + Returns: + + losses: [bsz, seq_len], the cross entropy loss of each token + """ + bsz, seq_len, hidden_size = x.size() + token_num = bsz * seq_len + x = x.view(token_num, hidden_size) + y = y.view(token_num) + + if token_num % chunk_size != 0: + raise ValueError(f"token_num {token_num} is not divisible by chunk_size {chunk_size}") + + chunk_num = token_num // chunk_size + xs = x.view(chunk_num, chunk_size, hidden_size) + ys = y.view(chunk_num, chunk_size) + losses = [] + for i in range(chunk_num): + loss = ckpt.checkpoint(linear_cross_entropy, xs[i], w, ys[i], padding_idx, use_reentrant=False) + losses.append(loss) + losses = torch.stack(losses).view(bsz, seq_len) + return losses diff --git a/examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py b/examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py new file mode 100644 index 00000000..1be4b144 --- /dev/null +++ b/examples/customized_ops/chunk_linear_cross_entropy/test_chunk_linear_cross_entropy.py @@ -0,0 +1,43 @@ +import torch +from chunk_linear_cross_entropy import chunk_linear_cross_entropy, linear_cross_entropy + + +def test_chunk_linear_cross_entropy( + bsz: int, + seq_len: int, + hidden_size: int, + dict_size: int, + dtype: torch.dtype, + chunk_size: int = 1024, + padding_idx: int = 1): + print(f'test chunk linear cross entropy with {dtype}') + device = torch.device('cuda') + + x = torch.randn(bsz, seq_len, hidden_size, dtype=dtype, device=device) + x1 = x.clone().detach() + w = torch.nn.Parameter(torch.randn(dict_size, hidden_size, dtype=dtype, device=device)) + w1 = w.clone().detach().requires_grad_(True) + y = torch.randint(0, dict_size, (bsz, seq_len), dtype=torch.long, device=device) + y1 = y.clone().detach() + + x1 = x1.reshape(bsz * seq_len, hidden_size) + y1 = y1.reshape(bsz * seq_len) + bsl_losses = linear_cross_entropy(x1, w1, y1, padding_idx).reshape(bsz, seq_len) + bsl_loss = bsl_losses.sum() + bsl_loss.backward() + + test_losses = chunk_linear_cross_entropy(x, w, y, padding_idx, chunk_size) + test_loss = test_losses.sum() + test_loss.backward() + + losses_diff = (bsl_losses - test_losses).abs() + print(f'losses_diff max: {losses_diff.max().item()}, losses_diff mean: {losses_diff.mean().item()}') + w_grad_diff = (w.grad - w1.grad).abs() + print(f'w_grad_diff max: {w_grad_diff.max().item()}, w_grad_diff mean: {w_grad_diff.mean().item()}') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + test_chunk_linear_cross_entropy(2, 4096, 4096, 32000, torch.bfloat16) + test_chunk_linear_cross_entropy(2, 4096, 4096, 32000, torch.float16) diff --git a/examples/ring_attention/README.md b/examples/customized_ops/ring_attention/README.md similarity index 100% rename from examples/ring_attention/README.md rename to examples/customized_ops/ring_attention/README.md diff --git a/examples/ring_attention/core/ring_attn_implementation.py b/examples/customized_ops/ring_attention/core/ring_attn_implementation.py similarity index 100% rename from examples/ring_attention/core/ring_attn_implementation.py rename to examples/customized_ops/ring_attention/core/ring_attn_implementation.py diff --git a/examples/ring_attention/core/utils.py b/examples/customized_ops/ring_attention/core/utils.py similarity index 100% rename from examples/ring_attention/core/utils.py rename to examples/customized_ops/ring_attention/core/utils.py diff --git a/examples/ring_attention/core/zigzag_attn_implementation.py b/examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py similarity index 100% rename from examples/ring_attention/core/zigzag_attn_implementation.py rename to examples/customized_ops/ring_attention/core/zigzag_attn_implementation.py diff --git a/examples/ring_attention/ring_attn.py b/examples/customized_ops/ring_attention/ring_attn.py similarity index 100% rename from examples/ring_attention/ring_attn.py rename to examples/customized_ops/ring_attention/ring_attn.py diff --git a/examples/ring_attention/test_ring_attn.py b/examples/customized_ops/ring_attention/test_ring_attn.py similarity index 100% rename from examples/ring_attention/test_ring_attn.py rename to examples/customized_ops/ring_attention/test_ring_attn.py diff --git a/examples/ring_attention/test_zigzag_attn.py b/examples/customized_ops/ring_attention/test_zigzag_attn.py similarity index 100% rename from examples/ring_attention/test_zigzag_attn.py rename to examples/customized_ops/ring_attention/test_zigzag_attn.py diff --git a/examples/ring_attention/zigzag_attn.py b/examples/customized_ops/ring_attention/zigzag_attn.py similarity index 100% rename from examples/ring_attention/zigzag_attn.py rename to examples/customized_ops/ring_attention/zigzag_attn.py diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index c925890b..9606cd4d 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -102,6 +102,11 @@ class AutoDistConfig: The solver to use in spmd parallelism. Currently only support `'dp'` (dynamic programming) `'ilp'` (integer linear programming). + - transient_mem_coef (`float`, *optional*, defaults to `2`): + In autodist, a heuristic is used to estimate the transient memory size: + `transient_mem_size = opt_transient_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)`. This formula + is useful in many cases, but it may be too strict when some operators consume or generate a large tensor + (>= 4GB). In this case, you can set `transient_mem_coef` to a smaller value to relax the constraint. """ def __init__(self, @@ -135,6 +140,7 @@ def __init__(self, max_pipeline_bubble_ratio=0.4, max_pipeline_unbalance_ratio=0.5, solver='ilp', + transient_mem_coef=2, **kwargs): self.pc_path = partition_constraints_path self.profile_dir = profile_dir @@ -169,6 +175,7 @@ def __init__(self, self.max_pipeline_bubble_ratio = max_pipeline_bubble_ratio self.max_pipeline_unbalance_ratio = max_pipeline_unbalance_ratio self.solver = solver + self.transient_mem_coef = transient_mem_coef ignored_keys = list(kwargs.keys()) if ignored_keys: diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 5730d22b..c8e1f35d 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -816,8 +816,11 @@ def calc_mem_cost(self, plan: List[Tuple[int, int]]) -> ModuleMemCostDesc: transient_mem_cost = transient_mem[0] else: transient_mem_cost = transient_mem[0] + transient_mem[1] - if self.autodist_config.is_train: - transient_mem_cost *= 2 + transient_mem_cost *= self.autodist_config.transient_mem_coef + if not self.autodist_config.is_train: + transient_mem_cost /= 2 + else: + transient_mem_cost = 0 cost += transient_mem_cost return ModuleMemCostDesc(cost, mem, act_mem, opt_transient_mem, recompute_mem_cost, transient_mem_cost) @@ -1021,9 +1024,9 @@ def _solve_by_ilp(self, start: int, end: int) -> SPMDSearchOutput: prob += act_mem <= max_act_opt_transient prob += opt_transient_mem <= max_act_opt_transient if self.autodist_config.is_train: - transient_coef = 4 + transient_coef = 2 * self.autodist_config.transient_mem_coef else: - transient_coef = 2 + transient_coef = self.autodist_config.transient_mem_coef prob += mem - act_mem + max_act_opt_transient + transient_coef * max_transient + recompute_mem <= self.mem_bound # 4.3. constraint over e @@ -1134,7 +1137,8 @@ def do_dp(self, intervals: List[Tuple[int, int]], if self.autodist_config.memory_granularity < 1024: raise RuntimeError('dp solver assumes the memory granularity is at least 1024 bytes') - buf_mul = 2 if self.is_train else 1 + buf_mul = self.autodist_config.transient_mem_coef + if not self.is_train: buf_mul /= 2 mem_divisor = self.autodist_config.memory_granularity solver = dp_solver.DPSolver(self.autodist_config.verbose, self.mem_bound // mem_divisor, topk) for start, end in intervals: @@ -1146,7 +1150,7 @@ def do_dp(self, intervals: List[Tuple[int, int]], for i, partition in enumerate(self._op_partitions[idx]): p_cost_desc = self.partition_info[idx][i] solver.add_partition(idx, i, p_cost_desc.comp_time + p_cost_desc.weight_update_time, - p_cost_desc.mem // mem_divisor, p_cost_desc.in_mem // mem_divisor, buf_mul * p_cost_desc.transient_mem // mem_divisor, + p_cost_desc.mem // mem_divisor, p_cost_desc.in_mem // mem_divisor, int(buf_mul * p_cost_desc.transient_mem // mem_divisor), p_cost_desc.activation_mem // mem_divisor, p_cost_desc.opt_transient_mem // mem_divisor, self.p_fathers[idx][i], p_cost_desc.comm_time) solver.solve() From a9942a706c133847b9153c123fceb64584a020ee Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 6 Aug 2024 06:10:21 +0000 Subject: [PATCH 1697/1892] Merged PR 2203: Refine pipeline implementations Verified on llama3 8B + 4K on 4xA6000, distributed plan is 2 pipeline stages, each stage is composed of 2 devices. ![image.png](https://dev.azure.com/msrasrg/bb54e96e-8cc1-46f6-9021-c7048165b5bc/_apis/git/repositories/66b74611-09f4-4d0e-89b7-5ee93c087d3c/pullRequests/2203/attachments/image.png) This PR includes: 1. refine autodist implementations, including - add option `parallel_profile` to control whether profiling nodes in parallel: in pipeline solver we need to build the SPMDSolver and profile nodes for many times, only the first constructing needs parallel profiling to speed up, dumping the graph and sync takes a lot of time when in parallel -> we only profile in serial for later SPMDSolver constructing - fix bugs to generate correct partition plans and analysis for intervals whose searching result is built from identical intervals - add a flag in gecode to tell front-end the loaded module is a pipeline stage or not. It is helpful when compile stage is separated from runtime. 2. fix bug in executor to support bf16 backward 3. refine comments unit test passed parity test passed with [PR](https://dev.azure.com/msrasrg/SuperScaler/_git/Fairseq/pullrequest/2220) verified on 8xH100 with 4 pipeline stages, each stage is composed of one stage. --- docs/source/parallel_module.md | 2 + nnscaler/autodist/apis.py | 2 +- nnscaler/autodist/autodist_config.py | 22 +++++-- nnscaler/autodist/cost_database.py | 88 ++++++++++++++++------------ nnscaler/autodist/model_graph.py | 5 +- nnscaler/autodist/pipeline_solver.py | 70 +++++++++++++++++----- nnscaler/autodist/spmd_solver.py | 50 +++++++++++++--- nnscaler/codegen/module/module.py | 5 +- nnscaler/execplan/execplan.py | 2 +- nnscaler/graph/gener/gen.py | 2 +- nnscaler/graph/segment.py | 12 +--- nnscaler/policies.py | 15 +++-- nnscaler/runtime/executor.py | 9 +-- 13 files changed, 190 insertions(+), 94 deletions(-) diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index c7816683..2d7030e6 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -590,6 +590,8 @@ The configuration of the PAS policy should be passed in the `pas_config` of `Com - `recompute_modules (str)`: The module names to recompute, separated by `,`. For example, `module1,module2`. Optional. - `pipeline_pivots (str)`: The module names to pivot the pipeline, separated by `,`. For example, if `module1,module2` is specified, stages searched by pipeline solver only start from either `module1` or `module2`. Optional. - `use_apex_fused_adam_v2`: If set to `True`, the apex fused adam v2 optimizer will be used. Default is `False`. Optional. + - `parallel_profile`: If set to `True`, autodist will profile operators in parallel by using available gpus. Default is `True`. Optional. + - `max_partition_degree`: Max degree when partitioning an operator / node. When pipeline parallelism is enbaled (`use_pipeline` is True), user can change the value to constrain the plan to be composed of stages that span on less or equal to `max_partition_degree` devices (recommend to set `max_partition_degree` to the number of devices in a node to avoid inter-node communication, but should be be no more than `plan_ngpus`). Default is `plan_ngpus`. Optional. Please note all options to `autodist` are just suggestions. `autodist` will try to find the best partition for you, which may not be the same with your suggestions. diff --git a/nnscaler/autodist/apis.py b/nnscaler/autodist/apis.py index 60833f1c..8e9073b2 100644 --- a/nnscaler/autodist/apis.py +++ b/nnscaler/autodist/apis.py @@ -155,7 +155,7 @@ def parallelize_graph(graph: IRGraph, ) (p_idx, p_dim), p_num = stage_desc.partition_descs[ consumer.cid].desc[0] - if p_idx != -1 and consumer.inputs()[p_dim] == ftensor: + if p_idx != -1 and consumer.inputs()[p_idx] == ftensor: raise RuntimeError( f'node {consumer} has partitioned input {ftensor}' ) diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 9606cd4d..5373115a 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -92,16 +92,19 @@ class AutoDistConfig: is specified, stages searched by pipeline solver only start from either `module1` or `module2`. - pipeline_nstages(`int`, *optional*, defaults to `1`): The number of stages in pipeline parallelism. This option is only used when pipeline is True. - - max_pipeline_bubble_ratio (`float`, *optional*, defaults to `0.4`): + - max_pipeline_bubble_ratio (`float`, *optional*, defaults to `0.2`): The maximum bubble ratio in pipeline parallelism. The higher the ratio, the more bubbles will be allowed, the larger search space will be explored. - max_pipeline_unbalance_ratio (`float`, *optional*, defaults to `0.5`): - The maximum unbalance ratio in pipeline parallelism. The higher the ratio, the more unbalance is required, - the smaller search space will be explored. - - solver (`str`, *optional*, defaults to `'ilp'`): + The maximum unbalance ratio in pipeline parallelism. This is a metric control min_pipeline_stage_time / max_pipeline_stage_time. + The higher the ratio, the more balance is required, the smaller search space will be explored. + - solver (`str`, *optional*, defaults to `'dp'`): The solver to use in spmd parallelism. Currently only support `'dp'` (dynamic programming) `'ilp'` (integer linear programming). + - parallel_profile (`bool`, *optional*, defaults to `True`): + Whether to profile on multiple device in parallel. If set to `False`, the profiling will be done in a + single device sequentially. - transient_mem_coef (`float`, *optional*, defaults to `2`): In autodist, a heuristic is used to estimate the transient memory size: `transient_mem_size = opt_transient_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)`. This formula @@ -137,9 +140,10 @@ def __init__(self, pipeline=False, pipeline_pivots='', pipeline_nstages=1, - max_pipeline_bubble_ratio=0.4, + max_pipeline_bubble_ratio=0.2, max_pipeline_unbalance_ratio=0.5, - solver='ilp', + solver='dp', + parallel_profile=True, transient_mem_coef=2, **kwargs): self.pc_path = partition_constraints_path @@ -175,6 +179,12 @@ def __init__(self, self.max_pipeline_bubble_ratio = max_pipeline_bubble_ratio self.max_pipeline_unbalance_ratio = max_pipeline_unbalance_ratio self.solver = solver + if pipeline and solver != 'dp': + _logger.warning( + f'pipeline is enabled, but solver is not dp, set solver to dp' + ) + self.solver = 'dp' + self.parallel_profile = parallel_profile self.transient_mem_coef = transient_mem_coef ignored_keys = list(kwargs.keys()) diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index ea292998..c7224960 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -56,7 +56,7 @@ def _piecewise_estimator(xs: List[float], ys: List[float], x: float) -> float: raise RuntimeError(f'x={x}, xs={xs}, ys={ys}, should not reach here') -def _filter_and_group_nodes(graph: IRGraph, db: ProfileDataBase) -> List[List[IRFwOperation]]: +def _filter_nodes(graph: IRGraph, db: ProfileDataBase) -> List[List[IRFwOperation]]: visited_nodes = set() node_to_profile = list() for node in graph.select(ntype=IRFwOperation): @@ -67,26 +67,17 @@ def _filter_and_group_nodes(graph: IRGraph, db: ProfileDataBase) -> List[List[IR continue node_to_profile.append(node) visited_nodes.add(hash_code) + return node_to_profile - dev_num = torch.cuda.device_count() - # divide `node_to_profile` into `dev_num` groups - node_groups = [[] for _ in range(dev_num)] +def _group_nodes(node_to_profile: List[IRFwOperation], group_num: int) -> List[List[IRFwOperation]]: + node_groups = [[] for _ in range(group_num)] for i, node in enumerate(node_to_profile): - node_groups[i % dev_num].append(node) + node_groups[i % group_num].append(node) return node_groups -def _profile_nodes(dilled_info: str, dev_id: int, partition_degree: int, re_profile: bool, comp_profile_path: str, result: multiprocessing.Queue): - import dill - torch.cuda.set_device(dev_id) - - id_state, dilled_graph = dill.loads(dilled_info) - graph = IRGraph.from_dill(id_state, dilled_graph) - db = ProfileDataBase() - db.load_ops(comp_profile_path) - nodes = _filter_and_group_nodes(graph, db)[dev_id] - +def _profile_nodes(nodes: List[IRFwOperation], db: ProfileDataBase, partition_degree: int, re_profile: bool): ret = list() for node in nodes: if isinstance(node, IRDimops): @@ -99,6 +90,20 @@ def _profile_nodes(dilled_info: str, dev_id: int, partition_degree: int, re_prof for partition_node in partition_nodes: profiled_metrics: ProfiledMetrics = db.profile(partition_node, override=re_profile) ret.append((partition_node.signature, db._serialize(partition_node), profiled_metrics)) + return ret + + +def _profile_graph(dilled_info: str, dev_id: int, partition_degree: int, re_profile: bool, comp_profile_path: str, result: multiprocessing.Queue): + import dill + torch.cuda.set_device(dev_id) + + id_state, dilled_graph = dill.loads(dilled_info) + graph = IRGraph.from_dill(id_state, dilled_graph) + db = ProfileDataBase() + db.load_ops(comp_profile_path) + node_to_profile = _filter_nodes(graph, db) + nodes = _group_nodes(node_to_profile, group_num=torch.cuda.device_count())[dev_id] + ret = _profile_nodes(nodes, db, partition_degree, re_profile) _logger.info(f'device {dev_id} finished profiling {len(nodes)} nodes') result.put(ret) @@ -130,30 +135,35 @@ def __init__(self, graph: IRGraph, config: AutoDistConfig): self.ignore_small_tensor_threshold = self.autodist_config.ignore_small_tensor_threshold def profile_comp(self, partition_degree: int): - - # use spawn to make sure the profiling process is independent from each other - # and the main process, this is also required by torch - mp_context = multiprocessing.get_context('spawn') - - results = mp_context.Queue() - processes = [] - for i in range(torch.cuda.device_count()): - p = mp_context.Process(target=_profile_nodes, - args=(self.graph.dumps(), i, partition_degree, self.autodist_config.re_profile, self.comp_profile_path, results)) - processes.append(p) - p.start() - - # put queue.get() before join to avoid deadlock - for p in processes: - ret = results.get() - for sign, serialized, profiled_metrics in ret: - _logger.debug(f'profiled {sign} in {serialized} with {profiled_metrics}') - if not self.db.exist_serialized(sign, serialized): - self.db.insert(sign, serialized, profiled_metrics) - results.close() - - for p in processes: - p.join() + if self.autodist_config.parallel_profile: + _logger.info('Profiling in parallel') + # use spawn to make sure the profiling process is independent from each other + # and the main process, this is also required by torch + mp_context = multiprocessing.get_context('spawn') + + results = mp_context.Queue() + processes = [] + for i in range(torch.cuda.device_count()): + p = mp_context.Process(target=_profile_graph, + args=(self.graph.dumps(), i, partition_degree, self.autodist_config.re_profile, self.comp_profile_path, results)) + processes.append(p) + p.start() + + # put queue.get() before join to avoid deadlock + for p in processes: + ret = results.get() + for sign, serialized, profiled_metrics in ret: + _logger.debug(f'profiled {sign} in {serialized} with {profiled_metrics}') + if not self.db.exist_serialized(sign, serialized): + self.db.insert(sign, serialized, profiled_metrics) + results.close() + + for p in processes: + p.join() + else: + _logger.info('Profiling in serial') + node_to_profile = _filter_nodes(self.graph, self.db) + _profile_nodes(node_to_profile, self.db, partition_degree, self.autodist_config.re_profile) self.db.dump_ops(self.comp_profile_path, override=True) diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 710e67be..1b824c46 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -90,14 +90,15 @@ def estimate_mem_lower_bound( 1. activations, parameters, buffers and gradients are distributed evenly across plan_ngpus 2. the optimizer memory is distributed evenly across zero_group_size (when zero stage 1 is enabled) or plan_ngpus ''' + opt_resident_mem = cfg.opt_resident_coef * param_mem + opt_transient_mem = cfg.opt_transient_coef * param_mem + # avg memory cost of activation, param (grad), buffer activation_mem = activation_mem / plan_ngpus param_mem = param_mem / plan_ngpus buffer_mem = buffer_mem / plan_ngpus # avg opt mem - opt_resident_mem = cfg.opt_resident_coef * param_mem - opt_transient_mem = cfg.opt_transient_coef * param_mem if cfg.zero_stage == 1: opt_resident_mem = opt_resident_mem / zero_group_size opt_transient_mem = opt_transient_mem / zero_group_size diff --git a/nnscaler/autodist/pipeline_solver.py b/nnscaler/autodist/pipeline_solver.py index b3e4b5d4..9d000461 100644 --- a/nnscaler/autodist/pipeline_solver.py +++ b/nnscaler/autodist/pipeline_solver.py @@ -43,7 +43,6 @@ def _collect_tp_intervals( tp_degree: int, stage_num: int, interval_groups: List[List[IntervalInfo]], - spmd_solver: SPMDSolver, ) -> List[int]: ''' collect intervals for given tp_degree and stage_num @@ -97,8 +96,6 @@ def calc_min_mem(start, end): start, end = group[0].start, group[0].end if calc_min_mem(start, end) > cfg.memory_constraint: continue - if spmd_solver.estimate_min_mem(start, end) > cfg.memory_constraint: - continue local_fw_span = model_graph.query_fw_span(start, end) / tp_degree if local_fw_span < min_fw_span or local_fw_span > max_fw_span: continue @@ -133,19 +130,12 @@ def _compute_tp_info( no_solution_states = set() def process_case(device_num, stage_num): - solver = SPMDSolver(graph=model_graph, - mesh_desc=_dev_num2mesh_desc( - device_num, cfg.mesh_desc.col), - autodist_config=cfg, - stage_num=stage_num) - selected_group_idxs = _collect_tp_intervals( model_graph, cfg, device_num, stage_num, interval_groups, - solver, ) intervals = [] for i in selected_group_idxs: @@ -156,8 +146,19 @@ def process_case(device_num, stage_num): _logger.info( f'process case: tp {device_num}, s {stage_num}, {len(intervals)} intervals' ) + if not intervals: + return None, [], [] + # postpone the initialization of SPMDSolver to save time + cur_cfg = copy.deepcopy(cfg) + cur_cfg.world_size = cfg.world_size // cfg.mesh_desc.ngpus * device_num + solver = SPMDSolver(graph=model_graph, + mesh_desc=_dev_num2mesh_desc( + device_num, cfg.mesh_desc.col), + autodist_config=cur_cfg, + stage_num=stage_num) + solver_ret = solver.solve(intervals, 1) - return intervals, solver_ret + return solver, intervals, solver_ret def _calc_upper_bound(tp_degree: int): # bubble time percentage <= bubble_ratio: @@ -169,19 +170,56 @@ def _calc_upper_bound(tp_degree: int): (1 - bubble_ratio) * micro_batch_num + 1) return min(cfg.mesh_desc.ngpus - tp_degree + 1, upper_bound) - # TODO(yizhu1): use multiprocessing to speed up + # intervals in a same group share a distributed plan. To make the code generation + # correct, we need to adjust the plan for each interval based on each offset with + # respect to the first interval in the group. + def shift_plan(solver, spmd_desc, offset: int, shifted_start: int, shifted_end: int): + assert offset >= 0, f'invalid offset {offset}' + if offset == 0: + return spmd_desc + new_spmd_desc = copy.deepcopy(spmd_desc) + new_partition_descs = dict() + plan = list() + shifted_idx = shifted_start + for k, v in spmd_desc.desc.partition_descs.items(): + new_partition_descs[k + offset] = v + plan.append((shifted_idx, solver.node_desc2idx(shifted_idx, v))) + shifted_idx += 1 + assert shifted_idx == shifted_end + 1, f'expect {shifted_end + 1}, got {shifted_idx}' + new_spmd_desc.desc.partition_descs = new_partition_descs + new_spmd_desc.desc.analysis = solver.analyze_plan(plan) + return new_spmd_desc + tp_info = {} for tp_degree in legal_tp_degrees: + # In current parallel profiler's implementation, the profiling is divided into + # following steps: + # 1. searialize the input graph + # 2. lauch the multi-process profiling by python's spawn method + # 3. each process loads the serialized graph and do profiling + # 4. transport the profiling result back to the main process + # It helps to reduce the profiling time when the graph has not been met before. + # But the procedure itself has a large overhead. + # In PipelineSolver, the SPMDSolver is constructed and used to search the optimal + # plan for multiple times. For given `tp_degree`, cases that need to be profiled + # are the same. As a result, we set `cfg.parallel_profile` to True at the first time + # and set it to False for the rest of the time. + cfg.parallel_profile = True for stage_num in range(1, _calc_upper_bound(tp_degree) + 1): - intervals, solver_ret = process_case(tp_degree, stage_num) + solver, intervals, solver_ret = process_case(tp_degree, stage_num) + cfg.parallel_profile = False for interval, spmd_descs in zip(intervals, solver_ret): start, end = interval if spmd_descs: for group in interval_groups: if group[0].start == start and group[0].end == end: - for interval in group: - tp_info[(tp_degree, stage_num, interval.start, - interval.end)] = spmd_descs[0] + # iso -> isomorphic + for iso_interval in group: + iso_start, iso_end = iso_interval.start, iso_interval.end + offset = model_graph.operator_list[iso_start].ir_cell.cid - \ + model_graph.operator_list[start].ir_cell.cid + tp_info[(tp_degree, stage_num, iso_start, + iso_end)] = shift_plan(solver, spmd_descs[0], offset, iso_start, iso_end) else: no_solution_states.add((start, end, tp_degree)) _logger.info( diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index c8e1f35d..ea0caf5a 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -1112,7 +1112,7 @@ def get_non_zero_index(binary_vector): for i in range(start, end + 1): plans.append((i, s_val[i - start])) mem_cost = self.calc_mem_cost(plans).total_cost - return SPMDSearchOutput(self.partition_path2desc(plans), + return SPMDSearchOutput(self.build_tp_desc(plans), mem_cost / 1024 / 1024 / 1024, all_time_cost, self.calc_inner_time_cost(plans)) @@ -1159,7 +1159,7 @@ def do_dp(self, intervals: List[Tuple[int, int]], cpp_results = solver.get_results(start, end) descs = [] for result in cpp_results: - desc = self.partition_path2desc(result.path) + desc = self.build_tp_desc(result.path) descs.append(SPMDSearchOutput(desc, result.memory * mem_divisor / 1024 / 1024 / 1024, result.all_time, self.calc_inner_time_cost(result.path))) ret.append(descs) return ret @@ -1329,18 +1329,41 @@ def solve(self, intervals: List[Tuple[int, int]], raise RuntimeError( f'unsupported solver {self.autodist_config.solver}') + + def node_desc2idx(self, node_idx: int, node_desc: NodePartitionDesc) -> int: + ''' + convert the node partition description to the corresponding index + + Args: + node_idx (int): the index of the node + node_desc (NodePartitionDesc): the partition description of the node + + Returns: + int: the index of the partition + ''' + for i, p in enumerate(self._op_partitions[node_idx]): + op = p.operator + p_info = tuple([ + (op.dim_id2pos(dim), num) + for dim, num in zip(p.partition_dims, p.partition_nums) + ]) + if p_info == node_desc.desc: + return i + raise RuntimeError(f'fail to find the partition {node_desc} for node {self.get_operator(node_idx)}') + + def partition_path2desc( - self, plans: List[Tuple[int, int]]) -> Dict[int, NodePartitionDesc]: + self, plan: List[Tuple[int, int]]) -> Dict[int, NodePartitionDesc]: ''' convert the partition representation: (op_idx, partition_idx) to (op_cid, partition_desc) Args: - plans (List[Tuple[int, int]]): the partition plan to be converted + plan (List[Tuple[int, int]]): the partition plan to be converted Returns: Dict[int, NodePartitionDesc]: the converted partition plan ''' - partitions = [self._op_partitions[u][v] for u, v in plans] + partitions = [self._op_partitions[u][v] for u, v in plan] partition_descs = {} for p in partitions: @@ -1351,10 +1374,23 @@ def partition_path2desc( ]) partition_descs[op.ir_cell.cid] = NodePartitionDesc(desc=p_info) - return TensorParallelDesc(partition_descs=partition_descs, + return partition_descs + + + def build_tp_desc(self, plan: List[Tuple[int, int]]) -> TensorParallelDesc: + ''' + build the tensor parallelism description for the plan + + Args: + plan (List[Tuple[int, int]]): the plan to be converted + + Returns: + TensorParallelDesc: the tensor parallelism description + ''' + return TensorParallelDesc(partition_descs=self.partition_path2desc(plan), mesh_desc=self.mesh_desc, recompute_groups=[], - analysis=self.analyze_plan(plans)) + analysis=self.analyze_plan(plan)) def analysis_pretty_printer(analysis: Dict[str, Any]) -> str: diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index f4eb993e..d1cb3508 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -462,6 +462,7 @@ def forward(self, x, y=None, z=None): class_name='GenModel', derived=[f'nnscaler.runtime.module.{"ParallelModule" if as_parallel_module else "CubeModule"}'] ) as cb: + cb.insert_body(f'use_scheduler = {self.execplan.graph.sched is not None}') if as_parallel_module: cb.insert_body(f'rank = {device}') # save rank in class level with FunctionBlock(func_name='__init__', args=['self', 'init_params=True']) as ib: @@ -707,9 +708,9 @@ def emit_segment(self, segment: IRSegment, runtime_devid: int) -> List[str]: """ Emit IRSegment code. - The resultant `List[str]` will be lines of the statements of the final + The returned `List[str]` will be lines of the statements of the final Python method for the targeted Segment. - The resultant lines will not include the signature and the return statement + The returned lines will not include the signature and the return statement of the generated Python method. These lines will be put into `model_methods_bodies` and the missing Python-syntactic parts will be injected later on. diff --git a/nnscaler/execplan/execplan.py b/nnscaler/execplan/execplan.py index 387eaf47..e9bf96e1 100644 --- a/nnscaler/execplan/execplan.py +++ b/nnscaler/execplan/execplan.py @@ -149,7 +149,7 @@ def block2reuse(node: Block) -> ExeReuseCell: assert isinstance(block, IRCell) topo_seqs.append(block) - # set up returning outputs by packing output results from each micro-batch into a list + # set up returned outputs by packing output results from each micro-batch into a list outputs = [] for mid in range(schedplan.nmicros): outs = [] diff --git a/nnscaler/graph/gener/gen.py b/nnscaler/graph/gener/gen.py index 07884a88..79c69a3b 100644 --- a/nnscaler/graph/gener/gen.py +++ b/nnscaler/graph/gener/gen.py @@ -410,7 +410,7 @@ def skip(ptensors: List[IRSubTensor], ctensors: List[IRSubTensor]) -> bool: fadapters.append(fadapter) # (activation -> graph/segment output) generation: generate communication adapters between - # producer operatiors and graph/segment output tensors. Note graph/segment output tensors + # producer operators and graph/segment output tensors. Note graph/segment output tensors # always require for full-shape/value for output, while consumers may partition them. Therefore, # we need to additionally generate adapters for this case. if ftensor in output_consumer: diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index 5e3b31a7..a53282d4 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -965,7 +965,7 @@ def get_outputs(nodes: List[IRCell], exclude_attr: bool = True): def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> IRCell: """Create a segment (sub-graph) with part of the nodes. - This only return the created segment wihout modifying the graph. + This only return the created segment without modifying the graph. Calling this requires that the dependencies are already materialized, i.e., every input IRSubTensor should have a corresponding producer. Two scenarios @@ -1045,7 +1045,7 @@ def order(tensors: Set[IRObject]) -> Tuple[IRObject]: def dispatch(self, devid: int, _gen_mirror: bool = True) -> Optional[IRCell]: """ - Instantiate the segement to a specific device. + Instantiate the segment to a specific device. @param devid int: the target device @@ -1057,17 +1057,11 @@ def dispatch(self, devid: int, _gen_mirror: bool = True) -> Optional[IRCell]: return self if devid in self._dispatch_cached: return self._dispatch_cached[devid] - # inputs, outputs, nodes = [], [], [] + inputs, outputs, nodes = self.inputs(), self.outputs(), [] for node in self._nodes: if devid in node.device: nodes.append(node.dispatch(devid)) - # for itensor in node.inputs(): - # if itensor in self._inputs and itensor not in inputs: - # inputs.append(itensor) - # for otensor in node.outputs(): - # if otensor in self._outputs and otensor not in outputs: - # outputs.append(otensor) def order(tensors: Set[IRObject]) -> Tuple[IRObject]: """Reorder by logical tensor id. Temporally necessary for pipeline scheduling""" diff --git a/nnscaler/policies.py b/nnscaler/policies.py index be94e414..b3750caa 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -207,9 +207,12 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: raise ValueError("pipeline_nmicros should be equal to update_freq") # optional parameters - mesh_col = pas_cfg.get('mesh_col', cfg.plan_ngpus) - if mesh_col != cfg.plan_ngpus: - raise ValueError("mesh_col should be equal to plan_ngpus") + mesh_col = pas_cfg.get('max_partition_degree', cfg.plan_ngpus) + if cfg.plan_ngpus % mesh_col != 0: + raise ValueError(f"plan_ngpus {cfg.plan_ngpus} should be divisible by max_partition_degree {mesh_col}") + mesh_row = cfg.plan_ngpus // mesh_col + if not cfg.use_pipeline and mesh_row != 1: + raise ValueError("mesh_row should be 1 if pipeline is not enabled") memory_constraint = pas_cfg.get('mem_constraint', -1) task_name = pas_cfg.get('task_name', '_') use_memory_efficient_fp16 = pas_cfg.get('use_memory_efficient_fp16', False) @@ -224,10 +227,9 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: recompute_modules = pas_cfg.get('recompute_modules', '') pipeline_pivots = pas_cfg.get('pipeline_pivots', '') use_apex_fused_adam_v2 = pas_cfg.get('use_apex_fused_adam_v2', False) + parallel_profile = pas_cfg.get('parallel_profile', True) - mesh_row = 1 - ngpus = mesh_row * mesh_col - task_name = f'{task_name}_{ngpus}gpus_{update_freq}update_freq' + task_name = f'{task_name}_{cfg.plan_ngpus}gpus_{update_freq}update_freq' if memory_constraint == -1: # consider memory fragmentation and other buffers, use 80% of the memory memory_constraint = int(0.8 * torch.cuda.mem_get_info()[1] / 1024 / @@ -289,6 +291,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: pipeline=cfg.use_pipeline, pipeline_pivots=pipeline_pivots, pipeline_nstages=cfg.pipeline_nstages, + parallel_profile=parallel_profile, ) return parallelize_graph(graph, autodist_cfg) diff --git a/nnscaler/runtime/executor.py b/nnscaler/runtime/executor.py index 7e59b2a8..c762f438 100644 --- a/nnscaler/runtime/executor.py +++ b/nnscaler/runtime/executor.py @@ -9,6 +9,8 @@ _logger = logging.getLogger(__name__) +_ALLOW_GRAD_DTYPES = (torch.double, torch.float32, torch.float16, torch.bfloat16) + def debug_id(tensors, msg: str, rank: int): if torch.distributed.get_rank() == rank: @@ -113,11 +115,10 @@ def aexecute(subgraph: Callable, *input_tensors: Tuple[Any], requires_grad=True) outputs = subgraph(*input_tensors) else: outputs = subgraph(*input_tensors) - allow_grad_dtypes = (torch.float32, torch.float16) - if torch.is_tensor(outputs) and outputs.dtype in allow_grad_dtypes: + if isinstance(outputs, tuple): + outputs = (t.requires_grad_() if torch.is_tensor(t) and t.dtype in _ALLOW_GRAD_DTYPES else t for t in outputs) + elif torch.is_tensor(outputs) and outputs.dtype in _ALLOW_GRAD_DTYPES: outputs = outputs.requires_grad_() - else: - outputs = (t.requires_grad_() if t.dtype in allow_grad_dtypes else t for t in outputs) return outputs @staticmethod From b34d283c3313789516379058af37bb3b252b026b Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 7 Aug 2024 08:13:49 +0000 Subject: [PATCH 1698/1892] Merged PR 2221: minitrainer: add document minitrainer: add document --- docs/source/trainer.md | 437 ++++++++++++++++++++++++++++++ nnscaler/cli/arg_parser.py | 8 + nnscaler/cli/train_hook.py | 10 + nnscaler/cli/trainer.py | 18 +- nnscaler/cli/trainer_args.py | 52 +++- nnscaler/runtime/f16_optimizer.py | 9 +- tests/cli/test_arg_parser.py | 14 + tests/cli/test_trainer.py | 177 ++++++++++++ 8 files changed, 700 insertions(+), 25 deletions(-) create mode 100644 docs/source/trainer.md diff --git a/docs/source/trainer.md b/docs/source/trainer.md new file mode 100644 index 00000000..454b6a9f --- /dev/null +++ b/docs/source/trainer.md @@ -0,0 +1,437 @@ +# Trainer + +We provide a `Trainer` class that can be used to train and evaluate models. It will firstly parallelize the model on multiple GPUs with `parallelize` API, and then train the model with the given dataset and optimizer in a distributed way. + + +## Arguments + +All the arguments are defined in `TrainerArgs` class. Here is the definition of `TrainerArgs`: + +```python +@dataclass +class TrainerArgs: + compute_config: ComputeConfig = None + + gen_savedir: str = './.nnscaler' + gen_reuse: str = 'auto' + pas_policy: str = 'autodist' + broadcast_strategy: str = 'all' + instance_name: str = None + run_mode: str = 'run' + tracing_from_weights: str = None + + model: ModelConfig = field(default_factory=ModelConfig) + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + dataset: DatasetConfig = field(default_factory=DatasetConfig) + dataloader: DataloaderConfig = field(default_factory=DataloaderConfig) + dataset_sampler: DatasetSamplerConfig = field(default_factory=DatasetSamplerConfig) + lr_scheduler: Optional[LRSchedulerConfig] = None + checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) + log: List[LogConfig] = field(default_factory=list) + hook: Union[HookConfig, HookMapConfig, None] = None + + precision: Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None] = None + + micro_batch_size: int = 1 + global_batch_size: Optional[int] = None + grad_accumulation_steps: Optional[int] = None + + max_epochs: Optional[int] = None + max_train_steps: Optional[int] = None + max_val_steps: Optional[int] = None + + val_every_n_train_steps: Optional[int] = None + val_every_n_epochs: Optional[int] = 1 + + enable_progress_bar: bool = True + + seed: Optional[int] = None + init_env_fn: str = None +``` + +The design philosophy of `Trainer` arguments is: +The classes(or factory functions) of components(model/optimizer/etc) +and their arguments are provided in the `TrainerArgs` class (functions/types are passed as fully qualified names), +and we are responsible for creating them. + +For example, you can tell me how to create a model by providing the model type and its arguments in `ModelConfig` class. + +Please note some of the arguments of components are set automatically, and you should not set them manually. +For example, arguments `dataset`, `num_replicas` and `rank` of the dataset sampler are set automatically by the `Trainer` class. +Those 3 arguments passed in the `DatasetSamplerConfig.train_args/val_args`(if any) will be ignored. + +```python +'dataset': { + 'type': 'SomeDataset', + 'train_args': { + ... + }, + 'val_args': { + ... + } +} +'dataset_sampler': { + 'type': 'SomeDatasetSampler', + 'train_args': { + 'num_replicas': ..., # this will be ignored + 'dataset': ..., # this will be ignored + 'rank': ..., # this will be ignored + ... + }, + 'val_args': { + 'num_replicas': ..., # this will be ignored + 'dataset': ..., # this will be ignored + 'rank': ..., # this will be ignored + ... + }, +} +``` + +If any argument type is a class, you can pass it as a dict, and add a special key `__type` to specify the class type. + +For example, if the module `__init__` takes `ModelConfig` object +```python +class SomeModule(torch.nn.Module): + def __init__(self, model_config: ModelConfig): + ... +``` +You can pass the `model_config` as +```python +{ + 'type': 'SomeModule', + 'args': { + 'model_config': { + '__type': 'ModelConfig', + # arguments to create ModelConfig + } + } +} +``` + +We also use `ast.literal_eval` to guess the type of the string arguments, You can skip it by passing a dict with `__value_type` and `value` keys. For example, you want a number to be a str, you can use +```python +{ + '__value_type': 'str', + 'value': '1' +} +``` +Internally we will get the final value with `__value_type(value)`. + +### Component Configs + +- `model` (`ModelConfig`): The model to be trained. You need to provide the model type and its arguments in `ModelConfig` class. Here is the definition of `ModelConfig`: + + ```python + @dataclass + class ModelConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + ``` +- `optimizer` (`OptimizerConfig`): The optimizer to be used. + + ```python + @dataclass + class OptimizerConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + clip_gnorm: float = 0.0 + + loss_reduction: str = 'mean' + grad_reduction: str = 'mean' + aggregate_outputs_fn: str = None + ``` + - `type` (`str`): The optimizer type or factory function. + Please note the first parameter of the optimizer constructor must be the model parameters. + - `args` (`Dict[str, Any]`): The arguments of the optimizer. + - `clip_gnorm` (`float`): The maximum norm value for gradient clipping. 0.0/None means no clipping. + - `loss_reduction` (`str`): The reduction method for loss. + It can be `mean` (average the loss over all micro-batches), + `sum` (sum the loss of all micro-batches). + Default is `mean`. + Please note in validation stage, this configuration is ignored the loss is always averaged over all batches + - `grad_reduction` (`str`): The reduction method for gradients. It can be `mean` (average the gradients over all micro-batches), `sum` (sum the gradients of all micro-batches), `per-token-mean` (average the gradients over all tokens). Default is `mean`. Please note if `per-token-mean` is used, you need to specify `aggregate_outputs_fn`, which will return the number of tokens + - `aggregate_outputs_fn` (`str`): The function to aggregate the outputs of the model. It is required when `grad_reduction` is `per-token-mean`. Its signature should be `def aggregate_outputs(self, loss_outputs, sync_group) -> AggregatedOutputs`, where `loss_outputs` is a list of outputs of the model, and `sync_group` is the `torch.distributed.ProcessGroup` to sync with. The function should return an `AggregatedOutputs` object, which defines as: + ```python + @dataclass + class AggregatedOutputs: + # the aggregated loss as a sum + loss_sum: float = None + # number of mini batches + num_batches: int = None + # number of tokens (necessary when grad_reduction is 'per-token-mean') + num_tokens: Optional[int] = None + # any other custom outputs + aggregated_outputs: Any = None + ``` +- `dataset` (`DatasetConfig`): The dataset to be used. + ```python + @dataclass + class DatasetConfig: + type: str = None + train_args: Dict[str, Any] = field(default_factory=dict) + val_args: Dict[str, Any] = field(default_factory=dict) + ``` + - `type` (`str`): The dataset type or factory function. + - `train_args` (`Dict[str, Any]`): The arguments of the training dataset. + - `val_args` (`Dict[str, Any]`): The arguments of the validation dataset. +- `dataloader` (`DataloaderConfig`): The dataloader to be used. Please note we recommend to pass `drop_last=True` in the dataloader arguments to avoid the last batch with different sizes. + + ```python + @dataclass + class DataloaderConfig: + type: str = 'torch.utils.data.DataLoader' + train_args: Dict[str, Any] = field(default_factory=dict) + # default to train_args + val_args: Dict[str, Any] = field(default_factory=dict) + # default to train_args + test_args: Dict[str, Any] = field(default_factory=dict) + ``` + - `type` (`str`): The dataloader type or factory function. + Please note the dataloader constructor must at least have 3 parameters `dataset`, `batch_size`, `sampler`. + - `train_args` (`Dict[str, Any]`): The arguments (except `dataset`,`batch_size`, `sampler`) of the training dataloader. Argument `batch_size` will be set to `micro_batch_size`. + - `val_args` (`Dict[str, Any]`): The arguments (except `dataset`,`batch_size`, `sampler`) of the validation dataloader. + +- `dataset_sampler` (`DatasetSamplerConfig`): The dataset sampler to be used. + + ```python + @dataclass + class DatasetSamplerConfig: + type: str = 'torch.utils.data.DistributedSampler' + train_args: Dict[str, Any] = field(default_factory=dict) + val_args: Dict[str, Any] = field(default_factory=dict) + test_args: Dict[str, Any] = field(default_factory=dict) + ``` + - `type` (`str`): The dataset sampler type or factory function. + Please note the dataset sampler constructor must at least have 3 parameters `dataset`, `num_replicas`, `rank`. + - `train_args` (`Dict[str, Any]`): The arguments (except `dataset`,`num_replicas`, `rank`) of the training dataset sampler. + - `val_args` (`Dict[str, Any]`): The arguments (except `dataset`,`num_replicas`, `rank`) of the validation dataset sampler. + +- `lr_scheduler` (`LRSchedulerConfig`): The learning rate scheduler to be used. This is optional. + + ```python + @dataclass + class LRSchedulerConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + interval: str = 'epoch' + ``` + - `type` (`str`): The learning rate scheduler type or factory function. + Please note the first parameter of the learning rate scheduler constructor must be optimizer. + - `args` (`Dict[str, Any]`): The arguments of the learning rate scheduler. + - `interval` (`str`): The interval to update the learning rate. It can be `epoch` or `step`. Default is `epoch`. + +- `log` (`List[LogConfig]`): The loggers to be used. You can provide multiple loggers. Currently we have two builtin loggers: `TensorBoardLogger` and `WandbLogger`. + + ```python + @dataclass + class LogConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + ``` + - `type` (`str`): The logger type or factory function. + - `args` (`Dict[str, Any]`): The arguments of the logger. + +- `hook` (`Union[HookConfig, HookMapConfig, None]`): The hooks to be used. You can provide a hook with a hook class or a map of hook functions. + + Hook class: + + ```python + @dataclass + class HookConfig: + type: str = None + args: Dict[str, Any] = field(default_factory=dict) + ``` + + - `type` (`str`): The hook type or factory function. + - `args` (`Dict[str, Any]`): The arguments of the hook. + + Hook map: + + ```python + @dataclass + class HookMapConfig: + on_train_start: str = None + on_train_end: str = None + on_val_start: str = None + on_val_end: str = None + + on_epoch_start: str = None + on_epoch_end: str = None + + on_train_step_start: str = None + on_train_step_end: str = None + on_val_step_start: str = None + on_val_step_end: str = None + + after_aggregate_train_step_outputs: str = None + after_aggregate_val_step_outputs: str = None + + before_zero_grad: str = None + after_zero_grad: str = None + + before_gnorm_clip: str = None + after_gnorm_clip: str = None + + before_optimizer_step: str = None + after_optimizer_step: str = None + + on_load_checkpoint: str = None + on_save_checkpoint: str = None + ``` + - `on_train_start` (`str`): The hook function to be called at the start of the training stage. Signature: `def on_train_start(trainer: 'Trainer') -> None:` + - `on_train_end` (`str`): The hook function to be called at the end of the training stage. Signature: `def on_train_end(trainer: 'Trainer') -> None:` + - `on_val_start` (`str`): The hook function to be called at the start of the validation stage. Signature: `def on_val_start(trainer: 'Trainer') -> None:` + - `on_val_end` (`str`): The hook function to be called at the end of the validation stage. Signature: `def on_val_end(trainer: 'Trainer', val_loss: float) -> None:` + - `on_epoch_start` (`str`): The hook function to be called at the start of each epoch. Signature: `def on_epoch_start(trainer: 'Trainer', epoch: int) -> None:` + - `on_epoch_end` (`str`): The hook function to be called at the end of each epoch. Signature: `def on_epoch_end(trainer: 'Trainer', epoch: int) -> None:` + - `on_train_step_start` (`str`): The hook function to be called at the start of each training step. Signature: `def on_train_step_start(trainer: 'Trainer', batches: List[Any], idx: int) -> None:` + - `on_train_step_end` (`str`): The hook function to be called at the end of each training step. Signature: `def on_train_step_end(trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None:` + - `on_val_step_start` (`str`): The hook function to be called at the start of each validation step. Signature: `def on_val_step_start(trainer: 'Trainer', batches: List[Any], idx: int) -> None:` + - `on_val_step_end` (`str`): The hook function to be called at the end of each validation step. Signature: `def on_val_step_end(trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None:` + - `after_aggregate_train_step_outputs` (`str`): The hook function to be called after aggregating the outputs of the model in the training step. Signature: `def after_aggregate_train_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None:` + - `after_aggregate_val_step_outputs` (`str`): The hook function to be called after aggregating the outputs of the model in the validation step. Signature: `def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None:` + - `before_zero_grad` (`str`): The hook function to be called before zeroing the gradients. Signature: `def before_zero_grad(trainer: 'Trainer') -> None:` + - `after_zero_grad` (`str`): The hook function to be called after zeroing the gradients. Signature: `def after_zero_grad(trainer: 'Trainer') -> None:` + - `before_sync_grad` (`str`): The hook function to be called before syncing the gradients between ranks. Signature: `def before_sync_grad(trainer: 'Trainer') -> None:` + - `after_sync_grad` (`str`): The hook function to be called after syncing the gradients between ranks. Signature: `def after_sync_grad(trainer: 'Trainer') -> None:` + - `before_gnorm_clip` (`str`): The hook function to be called before gradient clipping. Signature: `def before_gnorm_clip(trainer: 'Trainer') -> None:` + - `after_gnorm_clip` (`str`): The hook function to be called after gradient clipping. Signature: `def after_gnorm_clip(trainer: 'Trainer', gnorm: torch.Tensor) -> None:` + - `before_optimizer_step` (`str`): The hook function to be called before the optimizer step. Signature: `def before_optimizer_step(trainer: 'Trainer') -> None:` + - `after_optimizer_step` (`str`): The hook function to be called after the optimizer step. Signature: `def after_optimizer_step(trainer: 'Trainer') -> None:` + - `on_load_checkpoint` (`str`): The hook function to be called after loading the checkpoint. If you saved something with `on_save_checkpoint` this is + your chance to restore this. Signature: `def on_load_checkpoint(trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None:` + - `on_save_checkpoint` (`str`): The hook function to be called before saving the checkpoint. If you want to save something, you can add it to the checkpoint here. Signature: `def on_save_checkpoint(trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None:` + +### Compute Config + +All compute configs are put in `compute_config` (`ComputeConfig`). Please refer to [link](./parallel_module.md#ComputeConfig) for more information. + +Please note only end2end mode is supported in the trainer, so you must set `compute_config.use_end2end` to `True` to make it work. + +### Checkpoint Config + +```python +@dataclass +class CheckpointConfig: + save_dir: str = './checkpoints' + no_save: bool = False + + save_type: str = 'sharded' + + save_last: bool = True + save_best: bool = True + symlink_best_and_last: bool = True + + every_n_train_steps: Optional[int] = None + every_n_epochs: Optional[int] = None + keep_last_n_checkpoints: Optional[int] = None + + resume_from: str = None +``` + +- `save_dir` (`str`): The directory to save the checkpoints. +- `no_save` (`bool`): Whether to save the checkpoints. Default is `False`. +- `save_type` (`str`): The type of saving checkpoint. It can be `sharded` or `deduped`. Default is `sharded`. + - `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. + The checkpoint is a folder with as many files as the world size. + - `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. + The checkpoint is a folder with as many files as the world size. + - `"merged"`: everything has been merged into a single file. Used internally only when you merge the checkpoint files via `Trainer.merge_checkpoints` +- `save_last` (`bool`): Whether to save the last checkpoint. Default is `True`. +- `save_best` (`bool`): Whether to save the best checkpoint. Default is `True`. +- `symlink_best_and_last` (`bool`): Whether to use symlink (instead of copy) to the best and last checkpoint. Default is `True`. +- `every_n_train_steps` (`Optional[int]`): Save the checkpoint every `every_n_train_steps` training steps. Default is `None`, which means no checkpoint is saved based on training steps. +- `every_n_epochs` (`Optional[int]`): Save the checkpoint every `every_n_epochs` epochs. Default is `None`, which means no checkpoint is saved based on epochs. +- `keep_last_n_checkpoints` (`Optional[int]`): Keep the last `keep_last_n_checkpoints` checkpoints. If we have more than `keep_last_n_checkpoints` checkpoints, we will remove the oldest ones. +Default is `None`, which means all checkpoints are kept. +- `resume_from` (`str`): The path to the checkpoint to resume from. It can be `last`/`best`/a specific folder/file. +We will not resume (nor report error) if resume_from is `last` or `best` but the corresponding checkpoint does not exist. +Default is `None`. + +Please note when the parallel plan is changed (i.e you re-trace the model with different configurations), +the checkpoints become incompatible, and can't be loaded any more. +You must firstly merge the checkpoints to a merged checkpoint with `Trainer.merge_checkpoint` and then load the merged checkpoint just like a regular checkpoint. + +```python +def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): +``` +where `checkpoint_files` is a list of checkpoint files to merge, and `output_file` is the output file path. + +### Other configs +- `gen_savedir` (`str`): The directory to save the generated files. Default is `./.nnscaler`. +- `gen_reuse` (`str`): the reuse strategy of the generated code, it can be + - `auto`: automatically decide the reuse strategy (`moo` for `compile`, `match` for `run`) + - one of `match`/`override`/`moo`/`graph`. See `parallelize` API for more information. +- `pas_policy` (`str`): The policy of parameter partitioning. Default is `autodist`. +You can pass builtin pas policy name or your own pas policy function. +See `parallelize` API for more information. +- `broadcast_strategy` (`str`): The strategy of broadcasting the model. Default is `all`. See `parallelize` API for more information. +- `instance_name` (`str`): The instance name of the trainer. Default is `None`. See `parallelize` API for more information. +- `run_mode` (`str`): The run mode of the trainer. It can be `run` (compile and train the model) and `compile` (only compile the model to generate code). Default is `run`. +- `tracing_from_weights` (`str`): The path to the weights to be loaded when tracing(compiling) the model. It is only used in tracing to serve as the initial state dict of the model. Default is `None`. +- `precison`(`Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None]`): The precision of the model. It can be a `str`, which means the same precision for all tensors, or a `Dict[_TENSOR_TYPE, _PRECISION_TYPE]`, which means the precision for each tensor type. Default is `None`. Currently we support 3 tensor types (`param`, `buffer`, `input`) and three precisions (`fp32`, `fp16`, `bf16`). You can set precision to `none` to avoid any precision conversion. +- `micro_batch_size` (`int`): The micro batch size. Default is `1`. +- `global_batch_size` (`Optional[int]`) and `grad_accumulation_steps` (`Optional[int]`): You can set one of `global_batch_size` and `grad_accumulation_steps` option. Please note if both are set, they must be consistent. Default is `micro_batch_size*scaling_factor` and `1` respectively. +- `max_epochs` (`Optional[int]`): The maximum number of epochs to train. Default is `None`, which means no limit. +- `max_train_steps` (`Optional[int]`): The maximum number of training steps to train. Default is `None`, which means no limit. +- `max_val_steps` (`Optional[int]`): The maximum number of validation steps to validate. Default is `None`, which means no limit. +- `val_every_n_train_steps` (`Optional[int]`): Validate every `val_every_n_train_steps` training steps. Default is `None`, which means no validation based on training steps. +- `val_every_n_epochs` (`Optional[int]`): Validate every `val_every_n_epochs` epochs. Default is `1`. +- `enable_progress_bar` (`bool`): Whether to enable the progress bar. Default is `True`. +- `seed` (`Optional[int]`): The random seed. Default is `None`. +- `init_env_fn` (`str`): The function to initialize the environment. Default is `None`. + +## CLI + +You can run the trainer with the following command: + +```bash +torchrun [torchrun arguments] ${NNSCALER_HOME}/cli/train.py -f ${CONFIG_FILE} [other arguments] +``` + +CONFIG_FILE is the path to the configuration yaml file. It looks like (taken from our test case) + +```yaml +compute_config: + plan_ngpus: 4 + runtime_ngpus: 100 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 +max_train_steps: 10 + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: torch.optim.Adam + args: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped +``` + +All the arguments in the yaml file are the same as the arguments in the `TrainerArgs` class. +And they can be override with the command line arguments. +For example, you can override the `max_epochs` with `--max_epochs 2`, or override the `model` with `--model.args.dim 32 --model.args.nlayers 32`. diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index 008f0dc4..ce82542f 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -10,6 +10,11 @@ UnionType = None # for python < 3.10 +_TYPE_KEY = '__type' +_VALUE_TYPE_KEY = '__value_type' +_VALUE_KEY = 'value' + + def parse_args(argv: List[str]) -> dict: raw_args = {} last_key = None @@ -137,6 +142,9 @@ def _is_primitive_type(data_type): def _guess_deserialize_object(value): if isinstance(value, dict): + if _VALUE_KEY in value and _VALUE_TYPE_KEY in value and len(value) == 2: + # keep as it is if it is a value object + return value return {_guess_deserialize_object(k): _guess_deserialize_object(v) for k, v in value.items()} if isinstance(value, list): return [_guess_deserialize_object(v) for v in value] diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index 8f17c166..28e590e0 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -101,6 +101,16 @@ def after_zero_grad(self, trainer: 'Trainer') -> None: Called after zero_grad """ + def before_sync_grad(self, trainer: 'Trainer') -> None: + """ + Called before sync_shard_grad + """ + + def after_sync_grad(self, trainer: 'Trainer') -> None: + """ + Called after sync_shard_grad + """ + def before_gnorm_clip(self, trainer: 'Trainer') -> None: """ Called before gradient clipping diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index cbaf4a44..42c1750f 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -206,11 +206,11 @@ def _create_model(): pmodel_class = nnscaler.parallelize( self.train_args.model_type, self._create_dummy_forward_args(), - self.train_args.pas_policy, + self.train_args.resolved_pas_policy, compute_config, module_fn=_create_model, gen_savedir=self.train_args.gen_savedir, - reuse='moo' if compile_only else 'match', + reuse=self.train_args.gen_reuse, instance_name=self.train_args.instance_name, broadcast_strategy='all', load_module=not compile_only, @@ -382,11 +382,17 @@ def _save_checkpoint(self, loss): else: shutil.copy(ckpt_file, best_file) + torch.distributed.barrier() # remove old checkpoints local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) # only the first rank in the group will do the job if self.rank % local_world_size == 0: - self._expire_checkpoints() + try: + self._expire_checkpoints() + except Exception as e: + logger.warning('Error when removing old checkpoints: %s. Will try later.', e) + + torch.distributed.barrier() def _expire_checkpoints(self): if not self.train_args.checkpoint.keep_last_n_checkpoints: # keep all @@ -546,7 +552,7 @@ def _validate(self, step_stat: _StepStat): losses = self.model.infer_step(batches) self.hook.on_val_step_end(self, losses[:num_batches], batches[:num_batches], idx) - aggregate_outputs = self.train_args.aggregate_outputs or self.aggregate_outputs + aggregate_outputs = self.train_args.resolved_aggregate_outputs_fn or self.aggregate_outputs aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) self.hook.after_aggregate_val_step_outputs( self, aggregated_outputs, @@ -601,7 +607,7 @@ def train_epoch(self, epoch): losses = self.model.train_step(batches, is_dummy_batch) self.hook.on_train_step_end(self, losses[:num_batches], batches[:num_batches], idx) - aggregate_outputs = self.train_args.aggregate_outputs or self.aggregate_outputs + aggregate_outputs = self.train_args.resolved_aggregate_outputs_fn or self.aggregate_outputs aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) if self.train_args.optimizer.loss_reduction == 'mean': loss = aggregated_outputs.loss_sum / aggregated_outputs.num_batches @@ -612,7 +618,9 @@ def train_epoch(self, epoch): if self.rank == 0: data_iter.set_postfix({'loss': loss}) + self.hook.before_sync_grad(self) self.optimizer.sync_shard_grad() + self.hook.after_sync_grad(self) # scale gradients if self.train_args.optimizer.grad_reduction == 'sum': diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 9f27c4ab..e868ba9f 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -1,6 +1,7 @@ from dataclasses import asdict, dataclass, field import importlib -from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union, get_args +from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union +from typing_extensions import get_args from pathlib import Path import logging import copy @@ -13,10 +14,10 @@ import torch from nnscaler.utils import transform_recursively -from nnscaler.parallel import ComputeConfig, build_optimizer +from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule -from .arg_parser import deserialize_dataclass, merge_args, parse_args +from .arg_parser import deserialize_dataclass, merge_args, parse_args, _TYPE_KEY, _VALUE_TYPE_KEY, _VALUE_KEY from .loggers.logger_base import LoggerBase from .train_hook import TrainHook @@ -154,8 +155,8 @@ class CheckpointConfig: every_n_epochs: Optional[int] = None keep_last_n_checkpoints: Optional[int] = None - # resume training from a checkpoint folder - # can be 'last'/'best'/a specific folder + # resume training from a checkpoint folder/file + # can be 'last'/'best'/a specific folder/file # we will not resume if resume_from is last or best but the corresponding checkpoint does not exist resume_from: str = None @@ -232,6 +233,9 @@ class HookMapConfig: before_zero_grad: str = None after_zero_grad: str = None + before_sync_grad: str = None + after_sync_grad: str = None + before_gnorm_clip: str = None after_gnorm_clip: str = None @@ -265,6 +269,10 @@ class TrainerArgs: compute_config: ComputeConfig = None gen_savedir: str = './.nnscaler' + # the reuse strategy of the generated code + # auto: automatically decide the reuse strategy (moo for compile, match for run) + # Or one of match/override/moo/graph (see `nnscaler.ReuseType`) + gen_reuse: str = 'auto' pas_policy: str = 'autodist' broadcast_strategy: str = 'all' instance_name: str = None @@ -331,9 +339,19 @@ def __post_init__(self): raise ValueError(f"`global_batch_size` {self.global_batch_size} is not equal to `micro_batch_size*scaling_factor*grad_accumulation_steps` " f"{self.micro_batch_size*self.scaling_factor*self.grad_accumulation_steps}") + if self.compute_config.use_pipeline and self.grad_accumulation_steps != self.compute_config.pipeline_nmicros: + raise ValueError(f"grad_accumulation_steps {self.grad_accumulation_steps} must be equal to " + f"compute_config.pipeline_nmicros {self.compute_config.pipeline_nmicros}") + if self.run_mode not in ('compile', 'run'): raise ValueError(f"Invalid run_mode {self.run_mode}") + if self.gen_reuse != 'auto': + if self.gen_reuse not in [e.value for e in ReuseType]: + raise ValueError(f"Invalid gen_reuse {self.gen_reuse}") + else: + self.gen_reuse = 'moo' if self.run_mode == 'compile' else 'match' + supported_precision_type = get_args(_PRECISION_TYPE) supported_tensor_type = get_args(_TENSOR_TYPE) if not self.precision: @@ -397,19 +415,19 @@ def from_yaml(cls, path: str) -> 'TrainerArgs': def create_kwarg(cls, value: Any): if isinstance(value, dict): value = {k: cls.create_kwarg(v) for k, v in value.items()} - if '__type' in value: - value_type = load_type(value.pop('__type')) + if _TYPE_KEY in value: + value_type = load_type(value.pop(_TYPE_KEY)) return value_type(**value) - elif '__value_type' in value: - if 'value' not in value: - raise ValueError("value is required when __value_type is present") - value_type = value.pop('__value_type') + elif _VALUE_TYPE_KEY in value: + if _VALUE_KEY not in value: + raise ValueError(f"`{_VALUE_KEY}` is required when `{_VALUE_TYPE_KEY}` is present") + value_type = value.pop(_VALUE_TYPE_KEY) if value_type == 'function': # when type is function, the value should be the full qualified name of the function - return load_type(value['value']) + return load_type(value[_VALUE_KEY]) else: # call its __init__ function value_type = load_type(value_type) - return value_type(value['value']) + return value_type(value[_VALUE_KEY]) else: return value elif isinstance(value, list): @@ -424,11 +442,17 @@ def model_type(self): return load_type(self.model.type) @property - def aggregate_outputs(self): + def resolved_aggregate_outputs_fn(self): if not self.optimizer.aggregate_outputs_fn: return None return load_type(self.optimizer.aggregate_outputs_fn) + @property + def resolved_pas_policy(self): + if self.pas_policy in _PREDEFINED_POLICIES: + return self.pas_policy + return load_type(self.pas_policy) + @property def scaling_factor(self): return self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py index 676649d1..feae1496 100644 --- a/nnscaler/runtime/f16_optimizer.py +++ b/nnscaler/runtime/f16_optimizer.py @@ -63,17 +63,14 @@ def state_dict(self): assert 'exp_avg' in state_dict['state'][0], f'currently only verified for adam-like optimizer' for key, value in state_dict['state'].items(): assert self.fp32_params[key].shape == value['exp_avg'].shape, f'Shape mismatch: {value["exp_avg"].shape} vs {self.fp32_params[key].shape}' - value['fp32_params'] = self.fp32_params[key] + # .detach(): save tensor instead of Parameter. + value['fp32_params'] = self.fp32_params[key].detach() return state_dict def load_state_dict(self, state_dict): """Load an optimizer state dict. - - In general we should prefer the configuration of the existing optimizer - instance (e.g., learning rate) over that found in the state_dict. This - allows us to resume training from a checkpoint using a new set of - optimizer args. + This will also load the fp32_params from the state """ if 'state' in state_dict and len(state_dict['state']) > 0 and 'fp32_params' in state_dict['state'][0]: logger.info('try to load fp32_params from state_dict in f16_optimizer') diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 0f4bd8df..7e4ba3c1 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -171,3 +171,17 @@ class A: x = parse_args(['--p.a=1', '--p.b=b']) y = deserialize_dataclass(x, A) assert y.p == {'a': 1, 'b': 'b'} # Dict[str, str] is ignored. so '1' will be converted to int + + +def test_deserialize_value_type(): + @dataclass + class A: + p: Any = None + + x = parse_args(['--p.__value_type=int', '--p.value=1']) + y = deserialize_dataclass(x, A) + assert y.p == {'__value_type': 'int', 'value': '1'} + + x = parse_args(['--p.value=1']) + y = deserialize_dataclass(x, A) + assert y.p == {'value': 1} diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 78efa2f6..f588e9ac 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -6,6 +6,7 @@ import torch.distributed from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer_args import AggregatedOutputs from tests.parallel_module.common import assert_equal from ..launch_torchrun import launch_torchrun @@ -66,6 +67,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' \ if bf16 == 'Mixed' \ else 'torch.optim.Adam' + use_zero = save_type == 'sharded' trainer = Trainer([ '-f', config_path, @@ -76,6 +78,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt_savedir), '--checkpoint.resume_from', 'last', @@ -98,6 +101,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', @@ -124,6 +128,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', @@ -152,6 +157,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--gen_savedir', str(gen_savedir), '--compute_config.plan_ngpus', '2', '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt1_savedir), '--checkpoint.resume_from', str(ckpt1_savedir / 'merged.pt'), @@ -226,3 +232,174 @@ def trainer_last_checkpoint_worker(save_dir): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_trainer_last_checkpoint(tmp_path): launch_torchrun(1, trainer_last_checkpoint_worker, tmp_path) + + +_train_losses = [] +_val_losses = [] + + +def after_aggregate_train_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None: + _train_losses.append(train_loss) + + +def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None: + _val_losses.append(val_loss) + + +def trainer_loss_reduction_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + _train_losses.clear() + _val_losses.clear() + trainer = Trainer([ + '-f', config_path, + '--enable_progress_bar', 'false', + '--max_epochs', '1', + '--global_batch_size', '4', # mini_batch_size=2, update_freq=2 + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + '--val_every_n_train_steps', '1', + '--checkpoint.no_save', 'true', + '--optimizer.loss_reduction', 'mean', + '--hook.after_aggregate_train_step_outputs', + 'tests.cli.test_trainer.after_aggregate_train_step_outputs', + '--hook.after_aggregate_val_step_outputs', + 'tests.cli.test_trainer.after_aggregate_val_step_outputs', + ]) + trainer.train() + + # get a copy + train_loss_mean = _train_losses[:] + val_loss_mean = _val_losses[:] + + torch.distributed.barrier() + + _train_losses.clear() + _val_losses.clear() + + trainer = Trainer([ + '-f', config_path, + '--enable_progress_bar', 'false', + '--max_epochs', '1', + '--global_batch_size', '4', # mini_batch_size=2, update_freq=2 + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + '--val_every_n_train_steps', '1', + '--checkpoint.no_save', 'true', + '--optimizer.loss_reduction', 'sum', + '--hook.after_aggregate_train_step_outputs', + 'tests.cli.test_trainer.after_aggregate_train_step_outputs', + '--hook.after_aggregate_val_step_outputs', + 'tests.cli.test_trainer.after_aggregate_val_step_outputs', + ]) + trainer.train() + torch.distributed.barrier() + + assert len(train_loss_mean) == len(_train_losses) + assert len(val_loss_mean) == len(_val_losses) + for i in range(len(train_loss_mean)): + assert train_loss_mean[i] == _train_losses[i] / 2 # 2 is update freq + + for i in range(len(val_loss_mean)): + assert val_loss_mean[i] == _val_losses[i] + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_trainer_loss_reduction(tmp_path): + launch_torchrun(1, trainer_loss_reduction_worker, tmp_path) + + +_before_step_grads = None +def before_gnorm_clip(trainer: 'Trainer') -> None: + global _before_step_grads + _before_step_grads = {i: g.grad.clone() for i, g in enumerate(trainer.optimizer.param_groups[0]['params'])} + + +def aggregate_outputs(loss_outputs, sync_group) -> 'AggregatedOutputs': + # loss is the first element of the output (or the only element) + losses = [ + loss if isinstance(loss, torch.Tensor) + else loss[0] + for loss in loss_outputs + ] + loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) + torch.distributed.all_reduce(loss_sum, group=sync_group) + num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) + torch.distributed.all_reduce(num_batches, group=sync_group) + num_tokens = num_batches * 2 # fake value + + return AggregatedOutputs( + loss_sum = loss_sum.item(), + num_batches=num_batches.item(), + num_tokens=num_tokens.item(), + ) + + +def trainer_per_token_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + _train_losses.clear() + _val_losses.clear() + trainer = Trainer([ + '-f', config_path, + '--enable_progress_bar', 'false', + '--max_epochs', '1', + '--global_batch_size', '8', + '--grad_accumulation_steps', '2', + '--gen_savedir', str(gen_savedir), + '--max_train_steps', '1', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '2', + '--val_every_n_train_steps', '1', + '--checkpoint.no_save', 'true', + '--optimizer.grad_reduction', 'mean', + '--optimizer.aggregate_outputs_fn', 'tests.cli.test_trainer.aggregate_outputs', + '--hook.before_gnorm_clip', + 'tests.cli.test_trainer.before_gnorm_clip', + ]) + trainer.train() + + # get a copy + grads = _before_step_grads + + torch.distributed.barrier() + + trainer = Trainer([ + '-f', config_path, + '--enable_progress_bar', 'false', + '--max_epochs', '1', + '--global_batch_size', '8', + '--grad_accumulation_steps', '2', + '--gen_savedir', str(gen_savedir), + '--max_train_steps', '1', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '2', + '--val_every_n_train_steps', '1', + '--checkpoint.no_save', 'true', + '--optimizer.grad_reduction', 'per-token-mean', + '--optimizer.aggregate_outputs_fn', 'tests.cli.test_trainer.aggregate_outputs', + '--hook.before_gnorm_clip', + 'tests.cli.test_trainer.before_gnorm_clip', + ]) + trainer.train() + + torch.distributed.barrier() + + assert set(grads.keys()) == set(_before_step_grads.keys()) + for n, p in grads.items(): + assert torch.equal(p / 2, _before_step_grads[n]) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_trainer_per_token(tmp_path): + launch_torchrun(2, trainer_per_token_worker, tmp_path) From cb40f768026df2e2680e0cebf030bd08d8d5bfd2 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 8 Aug 2024 03:12:22 +0000 Subject: [PATCH 1699/1892] Merged PR 2222: parallel module: remove pipeline related config from global compute config parallel module: remove pipeline related config from global compute config --- docs/source/parallel_module.md | 41 +++++++------- examples/vision/swin/policy/gallery.py | 8 ++- examples/vision/swin/train.py | 8 ++- nnscaler/autodist/autodist_config.py | 6 ++ nnscaler/cli/trainer_args.py | 4 -- nnscaler/codegen/module/module.py | 4 +- nnscaler/parallel.py | 70 +++++++++++------------- nnscaler/policies.py | 30 +++++----- nnscaler/runtime/module.py | 24 +++++--- tests/parallel_module/common.py | 6 +- tests/parallel_module/test_checkpoint.py | 13 +++-- tests/parallel_module/test_end2end.py | 35 +++++++----- tests/parallel_module/test_gencode.py | 9 +-- 13 files changed, 140 insertions(+), 118 deletions(-) diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 2d7030e6..fb0e5ebf 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -140,10 +140,6 @@ class End2EndMLP(nn.Module): runtime_ngpus=..., use_zero=..., use_end2end=True, - use_pipeline=..., - pipeline_nmicros=..., - pipeline_nstages=..., - pipeline_scheduler=..., ..., ) # compute environment config ParallelizedPipelinedLLM = parallelize( @@ -154,8 +150,6 @@ class End2EndMLP(nn.Module): ) ``` -If you want to enable pipeline parallelism, you need to set `use_end2end=True` and `use_pipeline=True` in `ComputeConfig`. You also need to set `pipeline_nmicros` and `pipeline_nstages` to specify the number of microbatches and stages in the pipeline. The `pipeline_scheduler` is the scheduler to schedule the pipeline. See below for details. - For end2end modules, you can't use `Module.forward`. Instead, you must use `ParallelModule.train_step` and `ParallelModule.infer_step` to train/infer the module. @@ -196,11 +190,6 @@ class ComputeConfig: inference_only : bool = False use_end2end: bool = False - use_pipeline: bool = False - pipeline_nmicros: int = -1 - pipeline_nstages: int = 1 - pipeline_scheduler: Optional[str] = None - pas_config: Dict[str, Any] = field(default_factory=dict) user_config: Dict[str, Any] = field(default_factory=dict) ``` @@ -236,10 +225,6 @@ We can categorize the fields into 4 categories: - `zero_ngroups`: the number of groups to be used in zero. - `inference_only`: whether to generate code for inference only. If it is true, the generated code can not be used to train the model. - `use_end2end`: whether to use end2end training. For the requirement of end2end, see the description above. - - `use_pipeline`: whether to use pipeline. Please note the pipeline parallelism is only supported for end2end modules, so you must set `use_end2end=True` if you want to use pipeline. - - `pipeline_nmicros`: the number of microbatches in the pipeline. - - `pipeline_nstages`: the number of stages in the pipeline. - - `pipeline_scheduler`: the scheduler name for the pipeline. Current we support four schedulers in training `1f1b`/`1f1b_plus`/`gpipe`/`chimera_direct` (4 stages pipeline only), and one scheduler in inference `infer_pipe`. - `pas_config`: the configuration for the PAS policy (partition-assign-schedule policy, which describes how to place all computations across devices. For details, please refer to [PAS Policies](#pas-policies)). It is a dictionary, and will be used by the PAS policy. Please note different PAS will have different configurations, @@ -543,7 +528,7 @@ Please note: It has the following arguments: - `samples` (`List[Any]`): a list of samples. - if pipeline is used, it must have the same length as pipeline_nmicros + if pipeline is used, it must have the same length as configured to pas policy. - `is_dummy_batch` (`Optional[List[bool]]`): indicates whether the each micro-batch is dummy - `scale_fn` (`Optional[Callable[[torch.Tensor], torch.Tensor]]`): the function to scale the loss @@ -555,7 +540,7 @@ def infer_step(self, samples: List[Any]) -> List[Any]: ... ``` The inference step function. It should be called in the inference loop. -The input is a list of samples, and returns a list of outputs for the samples. If pipeline is used, it must have the same length as pipeline_nmicros +The input is a list of samples, and returns a list of outputs for the samples. If pipeline is used, it must have the same length as configured to pas policy. ### PAS Policies @@ -568,11 +553,22 @@ The configuration of the PAS policy should be passed in the `pas_config` of `Com 2. `tp`: tensor parallelism + data parallelism. It will do tensor parallelism inside a scale unit, and run data parallelism across scale units. It has only one configuration: - seed: the random seed for choose the partition dimension. Default is `1` -3. `pp`: pipeline parallelism + data parallelism. It will do model parallelism inside a scale unit, and run data parallelism across scale units. It requires the `use_end2end` and `use_pipeline` to be true. It has no configurations. +3. `pp`: pipeline parallelism + data parallelism. +It will do model parallelism inside a scale unit, +and run data parallelism across scale units. +It requires the `use_end2end` be true. +It has two configurations `pipeline_nmicros` and `pipeline_scheduler`. +See `hybrid` policy for more details. 4. `data`: tensor parallelism on batch dimension. It has no configurations. -5. `hybrid`: pipeline parallelism + tensor parallelism + data parallelism. It will do model parallelism and tensor parallelism(on 0 dimension) inside a scale unit, and run data parallelism across scale units. It requires the `use_end2end` and `use_pipeline` to be true. It has no configurations. +5. `hybrid`: pipeline parallelism + tensor parallelism + data parallelism. +It will do model parallelism and tensor parallelism(on 0 dimension) inside a scale unit, +and run data parallelism across scale units. +It requires the `use_end2end` to be true. It has the following configurations. + - `pipeline_nstages`: the number of stages in the pipeline. Default is `plan_ngpus`. Optional. + - `pipeline_nmicros`: the number of microbatches in the pipeline. Required. + - `pipeline_scheduler`: the scheduler name for the pipeline. Current we support four schedulers in training `1f1b`/`1f1b_plus`/`gpipe`/`chimera_direct` (4 stages pipeline only), and one scheduler in inference `infer_pipe`. Default is `1f1b`. Optional. 6. `autodist`: the recommended policy for most cases. Currently it only support Adam-like optimizers. It will automatically choose the best partition for you by balancing the memory usage and speed. It has the following configurations. - `update_freq (int)`: the update frequency when training the module. Default is 1. Optional. @@ -590,10 +586,11 @@ The configuration of the PAS policy should be passed in the `pas_config` of `Com - `recompute_modules (str)`: The module names to recompute, separated by `,`. For example, `module1,module2`. Optional. - `pipeline_pivots (str)`: The module names to pivot the pipeline, separated by `,`. For example, if `module1,module2` is specified, stages searched by pipeline solver only start from either `module1` or `module2`. Optional. - `use_apex_fused_adam_v2`: If set to `True`, the apex fused adam v2 optimizer will be used. Default is `False`. Optional. + - `explore_pipeline`: If set to `True`, autodist will try pipeline parallelism to find the best partition plan + (but the selected partition plan is not necessarily pipeline parallelism). + - `pipeline_scheduler`: The scheduler name for the pipeline. Please note currently `1f1b` is the only supported scheduler in `autodist`. Default is `1f1b`. Optional. - `parallel_profile`: If set to `True`, autodist will profile operators in parallel by using available gpus. Default is `True`. Optional. - - `max_partition_degree`: Max degree when partitioning an operator / node. When pipeline parallelism is enbaled (`use_pipeline` is True), user can change the value to constrain the plan to be composed of stages that span on less or equal to `max_partition_degree` devices (recommend to set `max_partition_degree` to the number of devices in a node to avoid inter-node communication, but should be be no more than `plan_ngpus`). Default is `plan_ngpus`. Optional. - -Please note all options to `autodist` are just suggestions. `autodist` will try to find the best partition for you, which may not be the same with your suggestions. + - `max_partition_degree`: Max degree when partitioning an operator / node. When pipeline parallelism is enabled to explore (`explore_pipeline` is True), user can change the value to constrain the plan to be composed of stages that span on less or equal to `max_partition_degree` devices (recommend to set `max_partition_degree` to the number of devices in a node to avoid inter-node communication, but should be be no more than `plan_ngpus`). Default is `plan_ngpus`. Optional. You can also put any other settings that can affect code generation here. but please prefix the keys with `_` to avoid conflicts with predefined keys. diff --git a/examples/vision/swin/policy/gallery.py b/examples/vision/swin/policy/gallery.py index 380c30b6..7576ae9c 100644 --- a/examples/vision/swin/policy/gallery.py +++ b/examples/vision/swin/policy/gallery.py @@ -80,13 +80,15 @@ def pas_mesh_shard(graph: IRGraph, cfg: ComputeConfig): def pas_1f1b(graph: IRGraph, cfg: ComputeConfig): """1F1B schedule""" - num_stages = cfg.pipeline_nstages + num_stages = cfg.pas_config['pipeline_nstages'] + nmicros = cfg.pas_config['pipeline_nmicros'] + scheduler = cfg.pas_config.get('pipeline_scheduler', '1f1b') if num_stages != cfg.plan_ngpus: raise ValueError('1F1B schedule requires num_stages == plan_ngpus') # group to transformer layers transformers = group_to_layers(graph.select(ntype=IRFwOperation)) - stages = mitr.divide(cfg.pipeline_nstages, transformers) + stages = mitr.divide(num_stages, transformers) stages = [list(itertools.chain(*s)) for s in stages] graph.staging([t[0] for t in stages]) @@ -102,5 +104,5 @@ def pas_1f1b(graph: IRGraph, cfg: ComputeConfig): for node in graph.select(ntype=IRDataOperation): replica(graph, node, list(range(cfg.plan_ngpus))) # apply 1f1b schedule - cfg.apply_pipeline_scheduler(graph) + cfg.apply_pipeline_scheduler(graph, num_stages, nmicros, scheduler) return graph diff --git a/examples/vision/swin/train.py b/examples/vision/swin/train.py index 190ea5a2..27416f71 100644 --- a/examples/vision/swin/train.py +++ b/examples/vision/swin/train.py @@ -150,9 +150,7 @@ def train(args, compute_config: nnscaler.ComputeConfig): use_zero=args.zero, use_end2end=True, constant_folding=True, - use_pipeline=args.pp_size > 1, - pipeline_nmicros=args.gbs // args.mbs // args.dp_size, - pipeline_nstages=args.pp_size, + pas_config={ # customized settings that can affect code generation. '_pas_name': args.policy, @@ -163,6 +161,10 @@ def train(args, compute_config: nnscaler.ComputeConfig): # for autodist only 'update_freq': args.gbs // args.mbs// args.dp_size, 'use_fp16': args.fp16, + 'explore_pipeline': args.pp_size > 1, + # for pp + 'pipeline_nmicros': args.gbs // args.mbs // args.dp_size, + 'pipeline_nstages': args.pp_size, }, user_config={ 'mbs': args.mbs, diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 5373115a..fbef7612 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -92,6 +92,8 @@ class AutoDistConfig: is specified, stages searched by pipeline solver only start from either `module1` or `module2`. - pipeline_nstages(`int`, *optional*, defaults to `1`): The number of stages in pipeline parallelism. This option is only used when pipeline is True. + - pipeline_scheduler (`str`, *optional*, defaults to `'1f1b'`): + The pipeline scheduler to use. Currently only support `'1f1b'`. - max_pipeline_bubble_ratio (`float`, *optional*, defaults to `0.2`): The maximum bubble ratio in pipeline parallelism. The higher the ratio, the more bubbles will be allowed, the larger search space will be explored. @@ -140,6 +142,7 @@ def __init__(self, pipeline=False, pipeline_pivots='', pipeline_nstages=1, + pipeline_scheduler='1f1b', max_pipeline_bubble_ratio=0.2, max_pipeline_unbalance_ratio=0.5, solver='dp', @@ -176,6 +179,9 @@ def __init__(self, self.pipeline = pipeline self.pipeline_pivots = pipeline_pivots self.pipeline_nstages = pipeline_nstages + self.pipeline_scheduler = pipeline_scheduler + if self.pipeline_scheduler != '1f1b': + raise ValueError(f'pipeline scheduler {self.pipeline_scheduler} must be 1f1b') self.max_pipeline_bubble_ratio = max_pipeline_bubble_ratio self.max_pipeline_unbalance_ratio = max_pipeline_unbalance_ratio self.solver = solver diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index e868ba9f..dae04974 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -339,10 +339,6 @@ def __post_init__(self): raise ValueError(f"`global_batch_size` {self.global_batch_size} is not equal to `micro_batch_size*scaling_factor*grad_accumulation_steps` " f"{self.micro_batch_size*self.scaling_factor*self.grad_accumulation_steps}") - if self.compute_config.use_pipeline and self.grad_accumulation_steps != self.compute_config.pipeline_nmicros: - raise ValueError(f"grad_accumulation_steps {self.grad_accumulation_steps} must be equal to " - f"compute_config.pipeline_nmicros {self.compute_config.pipeline_nmicros}") - if self.run_mode not in ('compile', 'run'): raise ValueError(f"Invalid run_mode {self.run_mode}") diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index d1cb3508..1e5a2e25 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -462,7 +462,9 @@ def forward(self, x, y=None, z=None): class_name='GenModel', derived=[f'nnscaler.runtime.module.{"ParallelModule" if as_parallel_module else "CubeModule"}'] ) as cb: - cb.insert_body(f'use_scheduler = {self.execplan.graph.sched is not None}') + graph_sched = self.execplan.graph.sched + cb.insert_body(f'use_scheduler = {graph_sched is not None}') + cb.insert_body(f'nmicros_per_scheduler_step = {graph_sched.nmicros if graph_sched is not None else 1}') if as_parallel_module: cb.insert_body(f'rank = {device}') # save rank in class level with FunctionBlock(func_name='__init__', args=['self', 'init_params=True']) as ib: diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 63571d95..8d24e4fc 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -82,15 +82,6 @@ class ComputeConfig: # which must be a scalar tensor use_end2end: bool = False - # current only end2end module supports in pipeline mode. - # so be sure to set use_end2end=True when use_pipeline=True - use_pipeline: bool = False - # number of micro-batches - pipeline_nmicros: int = -1 - # number of stages - pipeline_nstages: int = -1 - # it is pas's responsibility to apply the scheduler - pipeline_scheduler: str = '1f1b' # PAS policy settings # you can also put any other settings that can affect code generation here. # but please prefix the keys with `_` to avoid conflicts with predefined keys. @@ -143,32 +134,35 @@ def __post_init__(self): # have to use __setattr__ for frozen dataclass super().__setattr__('zero_ngroups', 1) - if self.use_pipeline: - if not self.use_end2end: - raise ValueError("pipeline is only supported in end2end mode") - if self.pipeline_nmicros <= 0: - raise ValueError(f"pipeline_nmicros {self.pipeline_nmicros} must be > 0 when use pipeline") - if self.pipeline_nstages <= 0: - raise ValueError(f"pipeline_nstages {self.pipeline_nstages} must be > 0 when use pipeline") - if self.plan_ngpus % self.pipeline_nstages != 0: - raise ValueError(f"pipeline_nstages {self.plan_ngpus} must be a multiple of plan_ngpus {self.pipeline_nstages}") - if self.pipeline_scheduler not in _PREDEFINE_SCHEDS: - raise ValueError(f"pipeline_scheduler {self.pipeline_scheduler} is not supported. " - f"Supported schedulers are {_PREDEFINE_SCHEDS.keys()}") - if self.inference_only and self.pipeline_scheduler not in _PREDEFINED_INFERENCE_SCHEDS: - raise ValueError(f"pipeline_scheduler {self.pipeline_scheduler} is not supported in inference mode. " - f"Supported schedulers are {_PREDEFINED_INFERENCE_SCHEDS}") - if not self.inference_only and self.pipeline_scheduler in _PREDEFINED_INFERENCE_SCHEDS: - raise ValueError(f"pipeline_scheduler {self.pipeline_scheduler} is not supported in training mode.") - - def apply_pipeline_scheduler(self, graph: IRGraph) -> Optional[SchedulePlan]: + def apply_pipeline_scheduler( + self, + graph: IRGraph, + pipeline_nstages: int, + pipeline_nmicros: int, + pipeline_scheduler: str, + ) -> Optional[SchedulePlan]: """ Apply the pipeline scheduler to the graph. - Do nothing if not use_pipeline """ - if self.use_pipeline: - sched = _PREDEFINE_SCHEDS[self.pipeline_scheduler] - return sched(graph, self.pipeline_nmicros, self.pipeline_nstages) + if not self.use_end2end: + raise ValueError("pipeline is only supported in end2end mode") + if pipeline_nmicros <= 0: + raise ValueError(f"pipeline_nmicros {pipeline_nmicros} must be > 0.") + if pipeline_nstages <= 0: + raise ValueError(f"pipeline_nstages {pipeline_nstages} must be > 0.") + if self.plan_ngpus % pipeline_nstages != 0: + raise ValueError(f"pipeline_nstages {pipeline_nstages} must be a multiple of plan_ngpus {self.plan_ngpus}") + if pipeline_scheduler not in _PREDEFINE_SCHEDS: + raise ValueError(f"pipeline_scheduler {pipeline_scheduler} is not supported. " + f"Supported schedulers are {_PREDEFINE_SCHEDS.keys()}") + if self.inference_only and pipeline_scheduler not in _PREDEFINED_INFERENCE_SCHEDS: + raise ValueError(f"pipeline_scheduler {pipeline_scheduler} is not supported in inference mode. " + f"Supported schedulers are {_PREDEFINED_INFERENCE_SCHEDS}") + if not self.inference_only and pipeline_scheduler in _PREDEFINED_INFERENCE_SCHEDS: + raise ValueError(f"pipeline_scheduler {pipeline_scheduler} is not supported in training mode.") + + sched = _PREDEFINE_SCHEDS[pipeline_scheduler] + return sched(graph, pipeline_nmicros, pipeline_nstages) @property def gpu_config(self) -> Dict[str, int]: @@ -183,7 +177,6 @@ def graph_config(self) -> Dict[str, Any]: 'constant_folding': self.constant_folding, 'user_config': self.user_config, 'inference_only': self.inference_only, # there will be no backward nodes in the graph in inference mode - 'use_pipeline': self.use_pipeline, # pipeline option can affect the graph generation. 'end2end_mode': self.use_end2end, # end2end_mode can affect the graph generation. } @@ -618,7 +611,6 @@ def _gen_graph( constant_folding: bool, end2end_mode: bool = False, inference_only: bool = False, - use_pipeline: bool = False, ): # reset environment program = Program() @@ -702,8 +694,6 @@ def _gen_graph( if ir_dummy_outputs is None: ir_dummy_outputs = [] elif not isinstance(ir_dummy_outputs, (tuple, list)): ir_dummy_outputs = [ir_dummy_outputs] - if use_pipeline and _contains_uncommutable_data(ir_dummy_outputs): - raise RuntimeError(f"Communication generation error: some of outputs are not commutable between gpus, which is not supported in pipeline parallelism.") program.set_output(ir_dummy_outputs) program.finalize() @@ -783,7 +773,6 @@ def _gencode( module, dummy_forward_args, outdir, constant_folding=compute_config.constant_folding, end2end_mode=compute_config.use_end2end, inference_only=compute_config.inference_only, - use_pipeline=compute_config.use_pipeline, ) graph.dump(graph_ckp) torch.save(forward_args, forward_args_ckp) @@ -800,6 +789,13 @@ def _gencode( if not isinstance(graph, IRGraph): raise RuntimeError("Expected policy return IRGraph") + # currently graph.sched is only used for pipeline parallelism + # so it is not none means we are in pipeline parallelism + if graph.sched is not None and _contains_uncommutable_data(graph.outputs()): + raise RuntimeError("Communication generation error: " + "some of outputs are not commutable between gpus, " + "which is not supported in pipeline parallelism.") + # check assignment for node in graph.nodes(flatten=True): # skip graph anchor: will be removed diff --git a/nnscaler/policies.py b/nnscaler/policies.py index b3750caa..cfb65264 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -117,8 +117,8 @@ def pas_pp(graph: IRGraph, cfg: 'ComputeConfig'): """ pipeline parallelism inside a scale unit, and dp across scale units """ - if cfg.pipeline_nstages != cfg.plan_ngpus: - raise ValueError("pipeline_nstages should be equal to plan_ngpus") + if cfg.pas_config.get('pipeline_nstages', cfg.plan_ngpus) != cfg.plan_ngpus: + raise ValueError("pas_pp requires pipeline_nstages == plan_ngpus") return pas_hybrid(graph, cfg) @@ -155,11 +155,12 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): """ pipeline and tensor parallelism inside a scale unit, and dp across scale units """ - if not cfg.use_pipeline: - raise ValueError("pipeline should be enabled") - + if not cfg.use_end2end: + raise ValueError("Hybrid policy only supports end2end module") ngpus: int = cfg.plan_ngpus - nstages = cfg.pipeline_nstages + nstages = cfg.pas_config.get('pipeline_nstages', cfg.plan_ngpus) + nmicros = cfg.pas_config['pipeline_nmicros'] + scheduler = cfg.pas_config.get('pipeline_scheduler', '1f1b') tp_size: int = cfg.plan_ngpus // nstages if ngpus % tp_size != 0: raise ValueError(f'invalid tp_size {tp_size} for ngpus {ngpus}') @@ -192,26 +193,30 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): for dl in graph.select(ntype=IRDataOperation): _replica(graph, dl, devs=list(range(ngpus))) - cfg.apply_pipeline_scheduler(graph) + cfg.apply_pipeline_scheduler(graph, nstages, nmicros, scheduler) return graph def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: pas_cfg = cfg.pas_config - # required parameters update_freq = pas_cfg.get('update_freq', 1) if isinstance(update_freq, (tuple, list)): update_freq = update_freq[0] - if cfg.use_pipeline and update_freq != cfg.pipeline_nmicros: - raise ValueError("pipeline_nmicros should be equal to update_freq") # optional parameters + explore_pipeline = pas_cfg.get('explore_pipeline', False) + if explore_pipeline and not cfg.use_end2end: + raise ValueError("explore_pipeline cannot be enabled if use_end2end is False") + pipeline_scheduler = pas_cfg.get('pipeline_scheduler', '1f1b') + if pipeline_scheduler != '1f1b': + raise ValueError(f"Only 1f1b scheduler is supported in autodist.") + mesh_col = pas_cfg.get('max_partition_degree', cfg.plan_ngpus) if cfg.plan_ngpus % mesh_col != 0: raise ValueError(f"plan_ngpus {cfg.plan_ngpus} should be divisible by max_partition_degree {mesh_col}") mesh_row = cfg.plan_ngpus // mesh_col - if not cfg.use_pipeline and mesh_row != 1: + if not explore_pipeline and mesh_row != 1: raise ValueError("mesh_row should be 1 if pipeline is not enabled") memory_constraint = pas_cfg.get('mem_constraint', -1) task_name = pas_cfg.get('task_name', '_') @@ -288,9 +293,8 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: zero_ngroups=zero_ngroups, load_plan_path=load_plan_path, save_plan_path=save_plan_path, - pipeline=cfg.use_pipeline, + pipeline=explore_pipeline, pipeline_pivots=pipeline_pivots, - pipeline_nstages=cfg.pipeline_nstages, parallel_profile=parallel_profile, ) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index b6285ef8..6b41f183 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -121,6 +121,14 @@ class CubeModule(torch.nn.Module): before training """ + # whether the train_step/infer_step is using a scheduler, + # will be assigned in the generated subclasses + use_scheduler: bool + # the number of microbatches in one scheduler train/infer step + # 1 if no scheduler is used. + # will be assigned in the generated subclasses + nmicros_per_scheduler_step: int + def __init__(self): super().__init__() self._reducers: List[Reducer] = list() @@ -925,7 +933,7 @@ def train_step(self, because the gradients will be cleared in the beginning of this function Args: samples (List[Any]): a list of samples. - if pipeline is used, it must have the same length as pipeline_nmicros + if pipeline is used, it must have the same length as configured to pas policy is_dummy_batch (Optional[List[bool]]): indicates whether the each micro-batch is dummy scale_fn (Optional[Callable[[torch.Tensor], torch.Tensor]]): the function to scale the loss Results: @@ -946,9 +954,9 @@ def train_step(self, sample_count = len(samples) dataloader = microbatches(samples, cycle=False) - if self.compute_config.use_pipeline: - if len(samples) != self.compute_config.pipeline_nmicros: - raise ValueError(f"Expected {self.compute_config.pipeline_nmicros} samples, but got {sample_count}") + if self.use_scheduler: + if len(samples) != self.nmicros_per_scheduler_step: + raise ValueError(f"Expected {self.nmicros_per_scheduler_step} samples, but got {sample_count}") # only one step, so begin/end are both True with accum_mode(begin=True, end=True): return self._train_step(dataloader) @@ -967,7 +975,7 @@ def infer_step(self, samples: List[Any]) -> List[Any]: Args: samples (List[Any]): a list of samples. - if pipeline is used, it must have the same length as pipeline_nmicros + if pipeline is used, it must have the same length as configured to pas policy Results: List[Any]: a list of outputs for each sample """ @@ -978,9 +986,9 @@ def infer_step(self, samples: List[Any]) -> List[Any]: sample_count = len(samples) dataloader = microbatches(samples, cycle=False) - if self.compute_config.use_pipeline: - if len(samples) != self.compute_config.pipeline_nmicros: - raise ValueError(f"Expected {self.compute_config.pipeline_nmicros} samples, but got {sample_count}") + if self.use_scheduler: + if len(samples) != self.nmicros_per_scheduler_step: + raise ValueError(f"Expected {self.nmicros_per_scheduler_step} samples, but got {sample_count}") return self._infer_step(dataloader) else: outputs = [] diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 1ff3456e..c06ac4a7 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -51,7 +51,9 @@ def create_mesh(ngpus: int, group_num: Tuple[int]) -> Tuple[Tuple[Tuple[int]]]: def PASMegatron(graph: IRGraph, config: ComputeConfig): - num_stages = config.pipeline_nstages + num_stages = config.pas_config['pipeline_nstages'] + nmicros = config.pas_config['pipeline_nmicros'] + scheduler = config.pas_config.get('pipeline_scheduler', '1f1b') tp_size = config.plan_ngpus // num_stages _, tp_mesh = create_mesh(config.plan_ngpus, (num_stages, tp_size)) @@ -74,7 +76,7 @@ def PASMegatron(graph: IRGraph, config: ComputeConfig): for dl in graph.select(ntype=IRDataOperation): _replica(graph, dl, devs=list(range(config.plan_ngpus))) - config.apply_pipeline_scheduler(graph) + config.apply_pipeline_scheduler(graph, num_stages, nmicros, scheduler) return graph diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index ad49d9cf..fb0ac375 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -89,10 +89,11 @@ def to_pipeline_module(cls, compute_config: ComputeConfig, cube_savedir, assert compute_config.plan_ngpus == 2 compute_config = replace(compute_config, use_end2end=True, - use_pipeline=True, - pipeline_nmicros=2, - pipeline_nstages=2, - pipeline_scheduler=scheduler + pas_config=dict( + pipeline_nmicros=2, + pipeline_nstages=2, + pipeline_scheduler=scheduler + ) ) return parallelize( cls, @@ -132,7 +133,7 @@ def __init__(self): def train_step(model, x, y, optimizer): model.train() - if isinstance(model, ParallelModule) and model.compute_config.use_pipeline: + if isinstance(model, ParallelModule) and model.use_scheduler: # actually train_step will return two losses (for each input) # here we fake one loss to y_pred, so we don't need to change the check logic y_pred, loss = model.train_step(x) @@ -155,7 +156,7 @@ def train_step(model, x, y, optimizer): def gendata(model, data_size, start, end, rank, num_replicas): data = [] init_random() - if isinstance(model, ParallelModule) and model.compute_config.use_pipeline: + if isinstance(model, ParallelModule) and model.use_scheduler: data = End2EndMLP.gen_pipeline_data(data_size, start, end, rank, num_replicas) elif isinstance(model, End2EndMLP): data = End2EndMLP.gen_raw_data(data_size, start, end, rank, num_replicas) diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index 402aa949..b1d5fc38 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -102,7 +102,7 @@ def _train_ga(model, update_freq, data_size=DATA_SIZE): return results -def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, use_pipeline, nstages=None, nmicros=None, model_cls=MLP, pipeline_scheduler='1f1b'): +def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, nstages=None, nmicros=None, model_cls=MLP, pipeline_scheduler='1f1b'): init_distributed() init_random() nstages = nstages or plan_ngpus @@ -117,8 +117,11 @@ def gpu_worker_cube(runtime_ngpus, plan_ngpus, policy, use_pipeline, nstages=Non compute_config= ComputeConfig( plan_ngpus, runtime_ngpus, use_end2end=True, - use_pipeline=use_pipeline, pipeline_nmicros=nmicros, pipeline_nstages=nstages, - pipeline_scheduler=pipeline_scheduler + pas_config=dict( + pipeline_nmicros=nmicros, + pipeline_nstages=nstages, + pipeline_scheduler=pipeline_scheduler + ), ), gen_savedir=tempdir ) @@ -170,7 +173,7 @@ def test_end2end(): ga4_result = _train_ga(model, 4) # micro_batch_size = 4 assert len(ga4_result) == 16 - cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid', True) # micro_batch_size = 4 + cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid') # micro_batch_size = 4 for _, v in cube2_results.items(): # all losses should be scalar tensor assert all(i.shape == () for i in v[1]) @@ -178,7 +181,7 @@ def test_end2end(): assert len(cube2_result) == 16 allclose(cube2_result, ga4_result) - cube4_results = launch_torchrun(4, gpu_worker_cube, 4, 4, PASMegatron, True) # micro_batch_size = 4 + cube4_results = launch_torchrun(4, gpu_worker_cube, 4, 4, PASMegatron) # micro_batch_size = 4 for _, v in cube2_results.items(): # all losses should be scalar tensor assert all(i.shape == () for i in v[1]) @@ -186,7 +189,7 @@ def test_end2end(): assert len(cube4_result) == 16 allclose(cube4_result, ga4_result) - cube2_results_non_pipeline = launch_torchrun(4, gpu_worker_cube, 4, 2, 'tp', False) # micro_batch_size = 4 + cube2_results_non_pipeline = launch_torchrun(4, gpu_worker_cube, 4, 2, 'tp') # micro_batch_size = 4 for _, v in cube2_results.items(): # all losses should be scalar tensor assert all(i.shape == () for i in v[1]) @@ -234,16 +237,16 @@ def test_pipeline_shared(): ComputeConfig( 2, 2, inference_only=False, - use_end2end=True, - use_pipeline=True, pipeline_nmicros=2, pipeline_nstages=2, + use_end2end=True).apply_pipeline_scheduler( + None, pipeline_nmicros=2, pipeline_nstages=2, pipeline_scheduler='infer_pipe' ) with pytest.raises(ValueError, match='is not supported in inference mode'): ComputeConfig( 2, 2, inference_only=True, - use_end2end=True, - use_pipeline=True, pipeline_nmicros=2, pipeline_nstages=2, + use_end2end=True).apply_pipeline_scheduler( + None, pipeline_nmicros=2, pipeline_nstages=2, pipeline_scheduler='1f1b' ) @@ -251,13 +254,13 @@ def test_pipeline_shared(): # 'chimera_direct' needs more gpus # 'infer_pipe' only work for inference # None looks doesn't work - cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid', True, None, None, MLPShared, ps) # micro_batch_size = 4 + cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid', None, None, MLPShared, ps) # micro_batch_size = 4 cube2_result = merge_cube_result({k: v[0] for k, v in cube2_results.items()}) assert len(cube2_result) == 16 allclose(cube2_result, ga4_result) # TODO: fix `chimera_direct` - # cube4_results = launch_torchrun(4, gpu_worker_cube, 4, 4, PASMegatron, True, None, None, MLPShared, 'chimera_direct') # micro_batch_size = 4 + # cube4_results = launch_torchrun(4, gpu_worker_cube, 4, 4, PASMegatron, None, None, MLPShared, 'chimera_direct') # micro_batch_size = 4 # cube4_result = merge_cube_result({k: v[0] for k, v in cube4_results.items()}) # assert len(cube4_result) == 16 # allclose(cube4_result, ga4_result) @@ -274,7 +277,7 @@ def test_pipeline(): # pp_size = 2 # tp_size = 2 # scale unit size = 4 - cube8_results = launch_torchrun(8, gpu_worker_cube, 8, 4, PASMegatron, True, 2, 2) # micro_batch_size = 4 + cube8_results = launch_torchrun(8, gpu_worker_cube, 8, 4, PASMegatron, 2, 2) # micro_batch_size = 4 cube8_result = merge_cube_result({k: v[0] for k, v in cube8_results.items()}) assert len(cube8_result) == 16 allclose(cube8_result, ga4_result, atol=1e-5, rtol=1e-5) # looks tp introduces more error @@ -336,8 +339,10 @@ def gpu_worker_cube_one_sample(): compute_config= ComputeConfig( 2, 2, use_end2end=True, - use_pipeline=True, pipeline_nmicros=2, pipeline_nstages=2, - pipeline_scheduler='1f1b' + pas_config=dict( + pipeline_nmicros=2, pipeline_nstages=2, + pipeline_scheduler='1f1b' + ), ), gen_savedir=tempdir ) diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 507f0ff4..130645a5 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -727,10 +727,11 @@ def p(cube_dir, use_pipeline, constant_folding, return_type, inference_only=Fals inference_only=inference_only, constant_folding=constant_folding, use_end2end=True, - use_pipeline=use_pipeline, - pipeline_nmicros=4, - pipeline_nstages=4, - pipeline_scheduler='infer_pipe' if inference_only else '1f1b' + pas_config=dict( + pipeline_nmicros=4, + pipeline_nstages=4, + pipeline_scheduler='infer_pipe' if inference_only else '1f1b' + ) ), gen_savedir=cube_dir, load_module=False, From 06b94fc1a7b74667a31bb83fa2f48b3a2c0e9229 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Wed, 14 Aug 2024 00:13:37 +0000 Subject: [PATCH 1700/1892] Merged PR 2225: Add missing annotation for cross_entropy cross_entropy supports `reduction='none'` which means do no reduction. The PR adds support for this case. https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html --- nnscaler/graph/function/function.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index eea632f7..f6790e4d 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -2013,7 +2013,13 @@ def CrossEntropy(input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0) """ - if reduction == 'sum': + if reduction == 'none': + annos = [ + 'C^, N -> N', + 'N C^, N -> N', + 'N C^ *, N * -> N', + ] + elif reduction == 'sum': annos = [ 'C^, N -> 1', 'N+ C^, N+ -> 1', From a999b2509625ca45e9290bb9eac2bfc7538c8842 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 15 Aug 2024 02:19:05 +0000 Subject: [PATCH 1701/1892] Merged PR 2230: minitrainer: Add model/optimizer/lr_scheduler hook support. minitrainer: Add model/optimizer/lr_scheduler hook support. --- docs/source/trainer.md | 18 ++++- nnscaler/cli/train_hook.py | 107 +++++++++++++++++++++++++++++- nnscaler/cli/trainer.py | 18 ++++- nnscaler/cli/trainer_args.py | 2 + nnscaler/runtime/f16_optimizer.py | 46 ++++++++++++- 5 files changed, 186 insertions(+), 5 deletions(-) diff --git a/docs/source/trainer.md b/docs/source/trainer.md index 454b6a9f..ed24be93 100644 --- a/docs/source/trainer.md +++ b/docs/source/trainer.md @@ -231,7 +231,12 @@ Internally we will get the final value with `__value_type(value)`. - `type` (`str`): The logger type or factory function. - `args` (`Dict[str, Any]`): The arguments of the logger. -- `hook` (`Union[HookConfig, HookMapConfig, None]`): The hooks to be used. You can provide a hook with a hook class or a map of hook functions. +- `hook` (`Union[HookConfig, HookMapConfig, None]`): The hooks to be used. +You can provide a hook with a hook class or a map of hook functions. +Please note if your `model`/`optimizer`/`lr_scheduler` inherit from `TrainHook`, +their hook functions will be called automatically. +The order of the hook functions called is `model` -> `optimizer` -> `lr_scheduler`, +and hooks passed with this config is called in the last. Hook class: @@ -250,6 +255,8 @@ Internally we will get the final value with `__value_type(value)`. ```python @dataclass class HookMapConfig: + after_setup: str = None + on_train_start: str = None on_train_end: str = None on_val_start: str = None @@ -278,6 +285,9 @@ Internally we will get the final value with `__value_type(value)`. on_load_checkpoint: str = None on_save_checkpoint: str = None ``` + - `after_setup` (`str`): The hook function to be called after setting up the trainer. + Only be called when `run_mode == 'run'`. + Signature: `def after_setup(trainer: 'Trainer') -> None:` - `on_train_start` (`str`): The hook function to be called at the start of the training stage. Signature: `def on_train_start(trainer: 'Trainer') -> None:` - `on_train_end` (`str`): The hook function to be called at the end of the training stage. Signature: `def on_train_end(trainer: 'Trainer') -> None:` - `on_val_start` (`str`): The hook function to be called at the start of the validation stage. Signature: `def on_val_start(trainer: 'Trainer') -> None:` @@ -292,7 +302,11 @@ Internally we will get the final value with `__value_type(value)`. - `after_aggregate_val_step_outputs` (`str`): The hook function to be called after aggregating the outputs of the model in the validation step. Signature: `def after_aggregate_val_step_outputs(trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None:` - `before_zero_grad` (`str`): The hook function to be called before zeroing the gradients. Signature: `def before_zero_grad(trainer: 'Trainer') -> None:` - `after_zero_grad` (`str`): The hook function to be called after zeroing the gradients. Signature: `def after_zero_grad(trainer: 'Trainer') -> None:` - - `before_sync_grad` (`str`): The hook function to be called before syncing the gradients between ranks. Signature: `def before_sync_grad(trainer: 'Trainer') -> None:` + - `before_sync_grad` (`str`): The hook function to be called before syncing the gradients between ranks. + Please note this hook can't be triggered correctly, + and you should not reply on this. + Will fix it later. + Signature: `def before_sync_grad(trainer: 'Trainer') -> None:` - `after_sync_grad` (`str`): The hook function to be called after syncing the gradients between ranks. Signature: `def after_sync_grad(trainer: 'Trainer') -> None:` - `before_gnorm_clip` (`str`): The hook function to be called before gradient clipping. Signature: `def before_gnorm_clip(trainer: 'Trainer') -> None:` - `after_gnorm_clip` (`str`): The hook function to be called after gradient clipping. Signature: `def after_gnorm_clip(trainer: 'Trainer', gnorm: torch.Tensor) -> None:` diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index 28e590e0..ea68a823 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -11,6 +11,13 @@ class TrainHook: """ Note: All hooks are called in all ranks, and the inputs of hooks are only the local data. """ + + def after_setup(self, trainer: 'Trainer') -> None: + """ + Called after trainer setup when run_mode == 'run'. + When run_mode == 'compile', this hook will not be called. + """ + def on_train_start(self, trainer: 'Trainer') -> None: """Called at the beginning of training""" @@ -103,7 +110,8 @@ def after_zero_grad(self, trainer: 'Trainer') -> None: def before_sync_grad(self, trainer: 'Trainer') -> None: """ - Called before sync_shard_grad + Called before sync_shard_grad. + TODO: Please note this can't be triggered correctly, because end2end mode is not supported. """ def after_sync_grad(self, trainer: 'Trainer') -> None: @@ -149,3 +157,100 @@ def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> Args: checkpoint: the checkpoint to be saved """ + + +class AggregatedTrainHook(TrainHook): + def __init__(self, hooks: List[TrainHook]): + self.hooks = hooks + + def after_setup(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.after_setup(trainer) + + def on_train_start(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.on_train_start(trainer) + + def on_train_end(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.on_train_end(trainer) + + def on_val_start(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.on_val_start(trainer) + + def on_val_end(self, trainer: 'Trainer', val_loss: float) -> None: + for hook in self.hooks: + hook.on_val_end(trainer, val_loss) + + def on_epoch_start(self, trainer: 'Trainer', epoch: int) -> None: + for hook in self.hooks: + hook.on_epoch_start(trainer, epoch) + + def on_epoch_end(self, trainer: 'Trainer', epoch: int) -> None: + for hook in self.hooks: + hook.on_epoch_end(trainer, epoch) + + def on_train_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: + for hook in self.hooks: + hook.on_train_step_start(trainer, batches, idx) + + def on_train_step_end(self, trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None: + for hook in self.hooks: + hook.on_train_step_end(trainer, outputs, batches, idx) + + def on_val_step_start(self, trainer: 'Trainer', batches: List[Any], idx: int) -> None: + for hook in self.hooks: + hook.on_val_step_start(trainer, batches, idx) + + def on_val_step_end(self, trainer: 'Trainer', outputs: List[Any], batches: List[Any], idx: int) -> None: + for hook in self.hooks: + hook.on_val_step_end(trainer, outputs, batches, idx) + + def after_aggregate_train_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', train_loss: float, idx: int) -> None: + for hook in self.hooks: + hook.after_aggregate_train_step_outputs(trainer, aggregated_outputs, train_loss, idx) + + def after_aggregate_val_step_outputs(self, trainer: 'Trainer', aggregated_outputs: 'AggregatedOutputs', val_loss: float, idx: int) -> None: + for hook in self.hooks: + hook.after_aggregate_val_step_outputs(trainer, aggregated_outputs, val_loss, idx) + + def before_zero_grad(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.before_zero_grad(trainer) + + def after_zero_grad(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.after_zero_grad(trainer) + + def before_sync_grad(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.before_sync_grad(trainer) + + def after_sync_grad(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.after_sync_grad(trainer) + + def before_gnorm_clip(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.before_gnorm_clip(trainer) + + def after_gnorm_clip(self, trainer: 'Trainer', gnorm: torch.Tensor) -> None: + for hook in self.hooks: + hook.after_gnorm_clip(trainer, gnorm) + + def before_optimizer_step(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.before_optimizer_step(trainer) + + def after_optimizer_step(self, trainer: 'Trainer') -> None: + for hook in self.hooks: + hook.after_optimizer_step(trainer) + + def on_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: + for hook in self.hooks: + hook.on_load_checkpoint(trainer, checkpoint) + + def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: + for hook in self.hooks: + hook.on_save_checkpoint(trainer, checkpoint) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 42c1750f..00611637 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -19,6 +19,7 @@ import nnscaler.utils from .trainer_args import AggregatedOutputs, TrainerArgs +from .train_hook import AggregatedTrainHook, TrainHook logger = logging.getLogger(__name__) @@ -229,10 +230,22 @@ def _create_model(): self.optimizer = self.train_args.create_parallel_optimizer(self.model) self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) self.loggers = self.train_args.create_loggers() - self.hook = self.train_args.create_hook() + + supported_hook_components = [ + self.model, + self.optimizer, + self.lr_scheduler, + ] + self.hook = AggregatedTrainHook( + [x for x in supported_hook_components if isinstance(x, TrainHook)] + + [self.train_args.create_hook()] + ) + self._log_config(self.train_args.to_dict()) self._load_checkpoint() + self.hook.after_setup(self) + @classmethod def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): state_dicts = [torch.load(f, map_location='cpu') for f in checkpoint_files] @@ -619,6 +632,9 @@ def train_epoch(self, epoch): data_iter.set_postfix({'loss': loss}) self.hook.before_sync_grad(self) + # actually `sync_shard_grad` is no-op here + # because trainer only supports end2end model + # and syncing grad in end2end model is done in `_train_step`. self.optimizer.sync_shard_grad() self.hook.after_sync_grad(self) diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index dae04974..bef0bc1c 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -214,6 +214,8 @@ class HookConfig: @dataclass class HookMapConfig: + after_setup: str = None + on_train_start: str = None on_train_end: str = None on_val_start: str = None diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py index feae1496..7d128c68 100644 --- a/nnscaler/runtime/f16_optimizer.py +++ b/nnscaler/runtime/f16_optimizer.py @@ -1,11 +1,17 @@ import logging +from typing import Optional, TYPE_CHECKING import torch +from nnscaler.cli.train_hook import TrainHook + +if TYPE_CHECKING: + from nnscaler.cli.trainer import Trainer + logger = logging.getLogger(__name__) -class MixedPrecisionF16OptimizerMixin: +class MixedPrecisionF16OptimizerMixin(TrainHook): """ A mixin class for mixed precision optimizer. Support both FP16 and BF16 parameters. @@ -20,6 +26,23 @@ class MixedPrecisionF16OptimizerMixin: def __init__(self, *args, **kwargs): # forward __init__ call to the next class in mro(method resolution order) super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + def after_setup(self, trainer: 'Trainer') -> None: + """ + Here we override the clip_gnorm and scale_grads methods in the optimizer. + Reason: + 1. The original clip_gnorm and scale_grads methods apply to bf16 grads, which is not what we want. + We need to apply them to fp32 grads. + 2. Combine the multiply_factors of clip_gnorm and scale_grads. So only one muliply is needed. + This can mitigate the precision loss caused by multiple multiplications. + Assumption: + `clip_gnorm` is called immediately after `scale_grads` in training loop. + """ + trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm + trainer.optimizer.clip_gnorm = self.overrided_clip_gnorm + trainer.optimizer._scale_grads = trainer.optimizer.scale_grads + trainer.optimizer.scale_grads = self.overrided_scale_grads @classmethod def build_fp32_params(cls, params): @@ -102,6 +125,9 @@ def _sync_f16_grads_to_fp32(self): p32.grad.data.copy_(p.grad.data) else: p32.grad = torch.zeros_like(p.data, dtype=torch.float) + if self._multiply_factor != 1.0: + p32.grad.mul_(self._multiply_factor) + self._multiply_factor = 1.0 def _sync_fp32_params_to_f16(self): # copy FP32 params back into FP16 model @@ -110,6 +136,24 @@ def _sync_fp32_params_to_f16(self): continue p.data.copy_(p32.data) + def overrided_scale_grads(self, scale: float): + """ + Scale the gradients by a factor. + Will override the original scale_grads method in ParallelOptimizer. + """ + self._multiply_factor *= scale + + def overrided_clip_gnorm(self, max_norm: Optional[float] = None) -> float: + """ + Will override the original clip_gnorm method in ParallelOptimizer. + """ + # self._clip_gnorm() is ParallelOptimizer.clip_gnorm + grad_norm = self._multiply_factor * self._clip_gnorm() + if max_norm is not None and max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp(max=1.0) + self._multiply_factor *= clip_coef + return grad_norm + class MixedPrecisionAdam(MixedPrecisionF16OptimizerMixin, torch.optim.Adam): def __init__(self, params, **kwargs): From c4e574a6ae4d580a05eb0d26f29a3d26814a1f5d Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 15 Aug 2024 03:15:35 +0000 Subject: [PATCH 1702/1892] Merged PR 2231: Align grad computation with fairseq 1. register pre-hook to each reducer to divide scaling_factor 2. adjust back grads after allreduce --- nnscaler/cli/trainer.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 00611637..3d94bcab 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -228,6 +228,10 @@ def _create_model(): self.model = pmodel_class() self.model.cuda() self.optimizer = self.train_args.create_parallel_optimizer(self.model) + # the reduce op is `sum` by default, follow torch's c10d, grad is divided by scaling_factor before allreduce + def reducer_pre_hook(reducer, grad): + grad.div_(self.train_args.scaling_factor) + self.optimizer.register_reducer_pre_hook(reducer_pre_hook) self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) self.loggers = self.train_args.create_loggers() @@ -464,7 +468,7 @@ def aggregate_outputs(self, loss_outputs, sync_group) -> AggregatedOutputs: torch.distributed.all_reduce(num_batches, group=sync_group) return AggregatedOutputs( - loss_sum = loss_sum.item(), + loss_sum=loss_sum.item(), num_batches=num_batches.item(), ) @@ -569,7 +573,7 @@ def _validate(self, step_stat: _StepStat): aggregated_outputs = aggregate_outputs(losses[:num_batches], self.sync_group) self.hook.after_aggregate_val_step_outputs( self, aggregated_outputs, - aggregated_outputs.loss_sum/aggregated_outputs.num_batches, + aggregated_outputs.loss_sum / aggregated_outputs.num_batches, idx ) loss_sum += aggregated_outputs.loss_sum @@ -639,18 +643,20 @@ def train_epoch(self, epoch): self.hook.after_sync_grad(self) # scale gradients + multiplier = self.train_args.scaling_factor if self.train_args.optimizer.grad_reduction == 'sum': - # do nothing. Already done in reducers + # do nothing. `multiplier` is already correct pass elif self.train_args.optimizer.grad_reduction == 'mean': if not aggregated_outputs.num_batches: raise RuntimeError("`aggregate_outputs` doesn't set `num_batches` field") - self.optimizer.scale_grads(1.0 / aggregated_outputs.num_batches) + multiplier /= aggregated_outputs.num_batches else: assert self.train_args.optimizer.grad_reduction == 'per-token-mean' if not aggregated_outputs.num_tokens: raise RuntimeError("`aggregate_outputs` doesn't set `num_tokens` field") - self.optimizer.scale_grads(1.0 / aggregated_outputs.num_tokens) + multiplier /= aggregated_outputs.num_tokens + self.optimizer.scale_grads(multiplier) # clip gradients self.hook.before_gnorm_clip(self) From c07cd2d626f234077457d8af85b30fe1b5a6f135 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 15 Aug 2024 06:43:07 +0000 Subject: [PATCH 1703/1892] Merged PR 2223: refine pytree in this pr, directly copy torch 2.3 pytree as `_pytree.py`, it is more powerful than the previous version, and implemented most of the functions we needed, the only reason to copy the file instead of import is we also support version 2.0 <= torch < 2.3. there is no need to review `_pytree.py`. the additional support for pytree is in `pytree_utils.py` parity check passed. --- .../parser/fx/concrete_trace_utils/_pytree.py | 1552 +++++++++++++++++ .../fx/concrete_trace_utils/concrete_proxy.py | 7 +- .../concrete_trace_utils/concrete_tracer.py | 77 +- .../fx/concrete_trace_utils/orig_func.py | 43 + .../fx/concrete_trace_utils/pytree_utils.py | 105 ++ .../parser/fx/concrete_trace_utils/utils.py | 145 +- tests/graph/tracer/test_pytree.py | 140 +- 7 files changed, 1768 insertions(+), 301 deletions(-) create mode 100644 nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py create mode 100644 nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py create mode 100644 nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py b/nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py new file mode 100644 index 00000000..3cc31f63 --- /dev/null +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/_pytree.py @@ -0,0 +1,1552 @@ +""" +NOTE: This file is copy from torch 2.3.0, and make some extension + +Contains utility functions for working with nested python data structures. + +A *pytree* is Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. + +This pytree implementation is not very performant due to Python overhead +To improve the performance we can move parts of the implementation to C++. +""" + +import dataclasses +import importlib +import json +import sys +import threading +import types +import warnings +from collections import defaultdict, deque, namedtuple, OrderedDict +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Deque, + Dict, + FrozenSet, + Generic, + Hashable, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + OrderedDict as GenericOrderedDict, + overload, + Protocol, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) + + +__all__ = [ + "PyTree", + "Context", + "FlattenFunc", + "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", + "TreeSpec", + "LeafSpec", + "keystr", + "key_get", + "register_pytree_node", + "tree_flatten", + "tree_flatten_with_path", + "tree_unflatten", + "tree_leaves", + "tree_leaves_with_path", + "tree_structure", + "tree_map", + "tree_map_with_path", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_dumps", + "treespec_loads", + "treespec_pprint", +] + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + + +DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 +NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" + + +class KeyEntry(Protocol): + def __hash__(self) -> int: + ... + + def __eq__(self, other: object) -> bool: + ... + + def __str__(self) -> str: + ... + + def get(self, parent: Any) -> Any: + ... + + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +ToStrFunc = Callable[["TreeSpec", List[str]], str] +MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] +KeyPath = Tuple[KeyEntry, ...] +FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]] + + +# A NodeDef holds two callables: +# - flatten_fn should take the collection and return a flat list of values. +# It can also return some context that is used in reconstructing the +# collection. +# - unflatten_fn should take a flat list of values and some context +# (returned by flatten_fn). It returns the collection by reconstructing +# it from the list and the context. +# - flatten_with_keys_fn, which is a callable that takes a +# pytree and returns a list of (keypath, value) pairs and a context. +class NodeDef(NamedTuple): + type: Type[Any] + flatten_fn: FlattenFunc + unflatten_fn: UnflattenFunc + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] + + +_NODE_REGISTRY_LOCK = threading.Lock() +SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} + + +# _SerializeNodeDef holds the following: +# - typ: the type of the node (e.g., "Dict", "List", etc) +# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict" +# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the +# context, and the version number +# - from_dumpable_context takes in a string representation of the context, and the +# version, and returns the deserialized context +class _SerializeNodeDef(NamedTuple): + typ: Type[Any] + serialized_type_name: str + to_dumpable_context: Optional[ToDumpableContextFn] + from_dumpable_context: Optional[FromDumpableContextFn] + + +SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} +SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} + + +def register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + flatten_with_keys_fn: An optional keyword argument to specify how to + access each pytree leaf's keypath when flattening and tree-mapping. + Like ``flatten_fn``, but in place of a List[leaf], it should return + a List[(keypath, leaf)]. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + raise ValueError(f"{cls} is already registered as pytree node.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + flatten_with_keys_fn=flatten_with_keys_fn, + ) + + try: + from . import _cxx_pytree as cxx + except ImportError: + pass + else: + cxx._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + to_str_fn: Optional[ToStrFunc] = None, # deprecated + maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node for the Python pytree only. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + flatten_with_keys_fn: An optional keyword argument to specify how to + access each pytree leaf's keypath when flattening and tree-mapping. + Like ``flatten_fn``, but in place of a List[leaf], it should return + a List[(keypath, leaf)]. + """ + warnings.warn( + "torch.utils._pytree._register_pytree_node is deprecated. " + "Please use torch.utils._pytree.register_pytree_node instead.", + stacklevel=2, + ) + + if to_str_fn is not None or maybe_from_str_fn is not None: + warnings.warn( + "to_str_fn and maybe_from_str_fn is deprecated. " + "Please use to_dumpable_context and from_dumpable_context instead." + ) + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + flatten_with_keys_fn=flatten_with_keys_fn, + ) + + +def _private_register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """This is an internal function that is used to register a pytree node type + for the Python pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + # TODO: change this warning to an error after OSS/internal stabilize + warnings.warn( + f"{cls} is already registered as pytree node. " + "Overwriting the previous registration.", + ) + + node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn) + SUPPORTED_NODES[cls] = node_def + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + if serialized_type_name is None: + serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND + + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + + +@dataclasses.dataclass(frozen=True) +class SequenceKey(Generic[T]): + idx: int + + def __str__(self) -> str: + return f"[{self.idx!r}]" + + def get(self, sequence: Sequence[T]) -> T: + return sequence[self.idx] + + +K = TypeVar("K", bound=Hashable) + + +@dataclasses.dataclass(frozen=True) +class MappingKey(Generic[K, T]): + key: K + + def __str__(self) -> str: + return f"[{self.key!r}]" + + def get(self, mapping: Mapping[K, T]) -> T: + return mapping[self.key] + + +@dataclasses.dataclass(frozen=True) +class GetAttrKey: + name: str + + def __str__(self) -> str: + return f".{self.name}" + + def get(self, obj: Any) -> Any: + return getattr(obj, self.name) + + +def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: + return list(d), None + + +def _tuple_flatten_with_keys( + d: Tuple[Any, ...] +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _tuple_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]: + return tuple(values) + + +def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return d, None + + +def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _list_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]: + return list(values) + + +def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _dict_flatten_with_keys( + d: Dict[Any, Any] +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _dict_flatten(d) + return [(MappingKey(k), v) for k, v in zip(context, values)], context + + +def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]: + return dict(zip(context, values)) + + +def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: + return list(d), type(d) + + +def _namedtuple_flatten_with_keys( + d: NamedTuple, +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _namedtuple_flatten(d) + return ( + [(GetAttrKey(field), v) for field, v in zip(context._fields, values)], + context, + ) + + +def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple: + return cast(NamedTuple, context(*values)) + + +def _namedtuple_serialize(context: Context) -> DumpableContext: + json_namedtuple = { + "class_name": context.__name__, + "fields": context._fields, + } + return json_namedtuple + + +def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: + class_name = dumpable_context["class_name"] + assert isinstance(class_name, str) + context = namedtuple(class_name, dumpable_context["fields"]) # type: ignore[misc] + return context + + +def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _ordereddict_flatten_with_keys( + d: GenericOrderedDict[Any, Any] +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _ordereddict_flatten(d) + return [(MappingKey(k), v) for k, v in zip(context, values)], context + + +def _ordereddict_unflatten( + values: Iterable[Any], + context: Context, +) -> GenericOrderedDict[Any, Any]: + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_odict_flatten = _ordereddict_flatten +_odict_unflatten = _ordereddict_unflatten + + +def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]: + values, dict_context = _dict_flatten(d) + return values, [d.default_factory, dict_context] + + +def _defaultdict_flatten_with_keys( + d: DefaultDict[Any, Any] +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _defaultdict_flatten(d) + _, dict_context = context + return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context + + +def _defaultdict_unflatten( + values: Iterable[Any], + context: Context, +) -> DefaultDict[Any, Any]: + default_factory, dict_context = context + return defaultdict(default_factory, _dict_unflatten(values, dict_context)) + + +def _defaultdict_serialize(context: Context) -> DumpableContext: + default_factory, dict_context = context + json_defaultdict = { + "default_factory_module": default_factory.__module__, + "default_factory_name": default_factory.__qualname__, + "dict_context": dict_context, + } + return json_defaultdict + + +def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context: + assert isinstance(dumpable_context, dict) + assert set(dumpable_context) == { + "default_factory_module", + "default_factory_name", + "dict_context", + } + + default_factory_module = dumpable_context["default_factory_module"] + default_factory_name = dumpable_context["default_factory_name"] + assert isinstance(default_factory_module, str) + assert isinstance(default_factory_name, str) + module = importlib.import_module(default_factory_module) + default_factory = getattr(module, default_factory_name) + + dict_context = dumpable_context["dict_context"] + return [default_factory, dict_context] + + +def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]: + return list(d), d.maxlen + + +def _deque_flatten_with_keys( + d: Deque[Any], +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _deque_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]: + return deque(values, maxlen=context) + + +_private_register_pytree_node( + tuple, + _tuple_flatten, + _tuple_unflatten, + serialized_type_name="builtins.tuple", + flatten_with_keys_fn=_tuple_flatten_with_keys, +) +_private_register_pytree_node( + list, + _list_flatten, + _list_unflatten, + serialized_type_name="builtins.list", + flatten_with_keys_fn=_list_flatten_with_keys, +) +_private_register_pytree_node( + dict, + _dict_flatten, + _dict_unflatten, + serialized_type_name="builtins.dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) +_private_register_pytree_node( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name="collections.namedtuple", + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, + flatten_with_keys_fn=_namedtuple_flatten_with_keys, +) +_private_register_pytree_node( + OrderedDict, + _ordereddict_flatten, + _ordereddict_unflatten, + serialized_type_name="collections.OrderedDict", + flatten_with_keys_fn=_ordereddict_flatten_with_keys, +) +_private_register_pytree_node( + defaultdict, + _defaultdict_flatten, + _defaultdict_unflatten, + serialized_type_name="collections.defaultdict", + to_dumpable_context=_defaultdict_serialize, + from_dumpable_context=_defaultdict_deserialize, + flatten_with_keys_fn=_defaultdict_flatten_with_keys, +) +_private_register_pytree_node( + deque, + _deque_flatten, + _deque_unflatten, + serialized_type_name="collections.deque", + flatten_with_keys_fn=_deque_flatten_with_keys, +) + + +STANDARD_DICT_TYPES: FrozenSet[type] = frozenset( + {dict, OrderedDict, defaultdict}, +) +BUILTIN_TYPES: FrozenSet[type] = frozenset( + {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type] +) + + +# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple +def _is_namedtuple_instance(tree: Any) -> bool: + typ = type(tree) + bases = typ.__bases__ + if len(bases) != 1 or bases[0] != tuple: + return False + fields = getattr(typ, "_fields", None) + if not isinstance(fields, tuple): + return False + return all(type(entry) == str for entry in fields) + + +def _get_node_type(tree: Any) -> Any: + if _is_namedtuple_instance(tree): + return namedtuple + return type(tree) + + +# A leaf is defined as anything that is not a Node. +def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: + return (is_leaf is not None and is_leaf(tree)) or _get_node_type( + tree + ) not in SUPPORTED_NODES + + +# A TreeSpec represents the structure of a pytree. It holds: +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children_specs: specs for each child of the root Node +# num_leaves: the number of leaves +@dataclasses.dataclass +class TreeSpec: + type: Any + context: Context + children_specs: List["TreeSpec"] + + num_nodes: int = dataclasses.field(init=False) + num_leaves: int = dataclasses.field(init=False) + num_children: int = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.num_nodes = 1 + sum(spec.num_nodes for spec in self.children_specs) + self.num_leaves = sum(spec.num_leaves for spec in self.children_specs) + self.num_children = len(self.children_specs) + + def __repr__(self, indent: int = 0) -> str: + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + children_specs_str: str = "" + if self.num_children > 0: + indent += 2 + children_specs_str += self.children_specs[0].__repr__(indent) + children_specs_str += "," if self.num_children > 1 else "" + children_specs_str += ",".join( + [ + "\n" + " " * indent + child.__repr__(indent) + for child in self.children_specs[1:] + ] + ) + repr_suffix: str = f"{children_specs_str}])" + return repr_prefix + repr_suffix + + def is_leaf(self) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: + if self.is_leaf(): + subtrees.append(tree) + return + + node_type = _get_node_type(tree) + if self.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != self.type: + raise ValueError( + f"Type mismatch; " + f"expected {self.type!r}, but got {node_type!r}.", + ) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if len(child_pytrees) != self.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {self.num_children}, but got {len(child_pytrees)}.", + ) + if context != self.context: + raise ValueError( + f"Node context mismatch for custom node type {self.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES + ) + if node_type != self.type and not both_standard_dict: + raise ValueError( + f"Node type mismatch; " + f"expected {self.type!r}, but got {node_type!r}.", + ) + if len(tree) != self.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {self.num_children}, but got {len(tree)}.", + ) + + if both_standard_dict: # dictionary types are compatible with each other + dict_context = ( + self.context + if self.type is not defaultdict + # ignore mismatch of `default_factory` for defaultdict + else self.context[1] + ) + expected_keys = dict_context + got_key_set = set(tree) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + child_pytrees = [tree[key] for key in expected_keys] + else: + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if ( + context != self.context + and self.type is not deque # ignore mismatch of `maxlen` for deque + ): + raise ValueError( + f"Node context mismatch for node type {self.type!r}; " + f"expected {self.context!r}, but got {context!r}.", # namedtuple type mismatch + ) + + for child_pytree, child_spec in zip(child_pytrees, self.children_specs): + child_spec._flatten_up_to_helper(child_pytree, subtrees) + + def flatten_up_to(self, tree: PyTree) -> List[PyTree]: + subtrees: List[PyTree] = [] + self._flatten_up_to_helper(tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any]) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn + + # Recursively unflatten the children + start = 0 + end = 0 + child_pytrees = [] + for child_spec in self.children_specs: + end += child_spec.num_leaves + child_pytrees.append(child_spec.unflatten(leaves[start:end])) + start = end + + return unflatten_fn(child_pytrees, self.context) + + +class LeafSpec(TreeSpec): + def __init__(self) -> None: + super().__init__(None, None, []) + + def __post_init__(self) -> None: + self.num_nodes = 1 + self.num_leaves = 1 + self.num_children = 0 + + def __repr__(self, indent: int = 0) -> str: + return "*" + + +# All leaves are equivalent, so represent with a single object to save on +# object construction time +_LEAF_SPEC = LeafSpec() + + +def _tree_flatten_helper( + tree: PyTree, + leaves: List[Any], + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> TreeSpec: + if _is_leaf(tree, is_leaf=is_leaf): + leaves.append(tree) + return _LEAF_SPEC + + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + + # Recursively flatten the children + children_specs = [ + _tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees + ] + + return TreeSpec(node_type, context, children_specs) + + +def tree_flatten( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used + to reconstruct the pytree. + """ + leaves: List[Any] = [] + spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf) + return leaves, spec + + +def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: + """Given a list of values and a TreeSpec, builds a pytree. + This is the inverse operation of `tree_flatten`. + """ + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be " + f"instance of TreeSpec but got item of type {type(treespec)}.", + ) + return treespec.unflatten(leaves) + + +def _tree_leaves_helper( + tree: PyTree, + leaves: List[Any], + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> None: + if _is_leaf(tree, is_leaf=is_leaf): + leaves.append(tree) + return + + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, _ = flatten_fn(tree) + + # Recursively flatten the children + for child in child_pytrees: + _tree_leaves_helper(child, leaves, is_leaf=is_leaf) + + +def tree_leaves( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> List[Any]: + """Get a list of leaves of a pytree.""" + leaves: List[Any] = [] + _tree_leaves_helper(tree, leaves, is_leaf=is_leaf) + return leaves + + +def tree_structure( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> TreeSpec: + """Get the TreeSpec for a pytree.""" + return tree_flatten(tree, is_leaf=is_leaf)[1] + + +def tree_map( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Map a multi-input function over pytree args to produce a new pytree. + + See also :func:`tree_map_`. + + >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + {'x': 8, 'y': (43, 65)} + >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + {'x': False, 'y': (False, False), 'z': True} + + If multiple inputs are given, the structure of the tree is taken from the first input; + subsequent inputs need only have ``tree`` as a prefix: + + >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` + is the tuple of values at corresponding nodes in ``rests``. + """ + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(map(func, *flat_args)) + + +def tree_map_( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. + + See also :func:`tree_map`. + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + The original ``tree`` with the value at each leaf is given by the side-effect of function + ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf + in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. + """ + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + tuple(map(func, *flat_args)) # consume and exhaust the iterable + return tree + + +Type2 = Tuple[Type[T], Type[S]] +Type3 = Tuple[Type[T], Type[S], Type[U]] +if sys.version_info >= (3, 10): + TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType] +else: + TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] + +Fn2 = Callable[[Union[T, S]], R] +Fn3 = Callable[[Union[T, S, U]], R] +Fn = Callable[[T], R] +FnAny = Callable[[Any], R] + +MapOnlyFn = Callable[[T], Callable[[Any], Any]] + + +# These specializations help with type inference on the lambda passed to this +# function +@overload +def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: + ... + + +@overload +def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: + ... + + +@overload +def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]: + ... + + +# This specialization is needed for the implementations below that call +@overload +def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]: + ... + + +@overload +def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]: + ... + + +def map_only( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]] +) -> MapOnlyFn[FnAny[Any]]: + """ + Suppose you are writing a tree_map over tensors, leaving everything + else unchanged. Ordinarily you would have to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... + + You can also directly use 'tree_map_only' + """ + if isinstance(__type_or_types_or_pred, (type, tuple)) or ( + sys.version_info >= (3, 10) + and isinstance(__type_or_types_or_pred, types.UnionType) + ): + + def pred(x: Any) -> bool: + return isinstance(x, __type_or_types_or_pred) # type: ignore[arg-type] + + elif callable(__type_or_types_or_pred): + pred = __type_or_types_or_pred # type: ignore[assignment] + else: + raise TypeError("Argument must be a type, a tuple of types, or a callable.") + + def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: + # @functools.wraps(func) # torch dynamo doesn't support this yet + def wrapped(x: T) -> Any: + if pred(x): + return func(x) + return x + + return wrapped + + return wrapper + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type[T], + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Callable[[Any], bool], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +def tree_map_only( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type[T], + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Callable[[Any], bool], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +def tree_map_only_( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +def tree_all( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_leaves(tree, is_leaf=is_leaf) + return all(map(pred, flat_args)) + + +def tree_any( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_leaves(tree, is_leaf=is_leaf) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_all_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_leaves(tree, is_leaf=is_leaf) + return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +@overload +def tree_any_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_any_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_leaves(tree, is_leaf=is_leaf) + return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. +def _broadcast_to_and_flatten( + tree: PyTree, + treespec: TreeSpec, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Optional[List[Any]]: + assert isinstance(treespec, TreeSpec) + + if _is_leaf(tree, is_leaf=is_leaf): + return [tree] * treespec.num_leaves + if treespec.is_leaf(): + return None + node_type = _get_node_type(tree) + if node_type != treespec.type: + return None + + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, ctx = flatten_fn(tree) + + # Check if the Node is different from the spec + if len(child_pytrees) != treespec.num_children or ctx != treespec.context: + return None + + # Recursively flatten the children + result: List[Any] = [] + for child, child_spec in zip(child_pytrees, treespec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) + if flat is not None: + result += flat + else: + return None + + return result + + +@dataclasses.dataclass +class _TreeSpecSchema: + """ + _TreeSpecSchema is the schema used to serialize the TreeSpec + It contains the following fields: + - type: A string name of the type. null for the case of a LeafSpec. + - context: Any format which is json dumpable + - children_spec: A list of children serialized specs. + """ + + type: Optional[str] + context: DumpableContext + children_spec: List["_TreeSpecSchema"] + + +class _ProtocolFn(NamedTuple): + treespec_to_json: Callable[[TreeSpec], DumpableContext] + json_to_treespec: Callable[[DumpableContext], TreeSpec] + + +_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} + + +def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: + if treespec.is_leaf(): + return _TreeSpecSchema(None, None, []) + + if treespec.type not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Serializing {treespec.type} in pytree is not registered.", + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] + + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"No registered serialization name for {treespec.type} found. " + "Please update your _register_pytree_node call with a `serialized_type_name` kwarg." + ) + + if serialize_node_def.to_dumpable_context is None: + try: + serialized_context = json.dumps(treespec.context) + except TypeError as e: + raise TypeError( + "Unable to serialize context. " + "Please make the context json dump-able, or register a " + "custom serializer using _register_pytree_node." + ) from e + else: + serialized_context = serialize_node_def.to_dumpable_context(treespec.context) + + child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] + + return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) + + +def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: + if ( + json_schema["type"] is None + and json_schema["context"] is None + and len(json_schema["children_spec"]) == 0 + ): + return _LEAF_SPEC + + if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f'Deserializing {json_schema["type"]} in pytree is not registered.', + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] + + if serialize_node_def.from_dumpable_context is None: + try: + context = json.loads(json_schema["context"]) + except TypeError as ex: + raise TypeError( + "Unable to deserialize context. " + "Please make the context json load-able, or register a " + "custom serializer using _register_pytree_node.", + ) from ex + else: + context = serialize_node_def.from_dumpable_context(json_schema["context"]) + + children_specs = [] + for child_string in json_schema["children_spec"]: + children_specs.append(_json_to_treespec(child_string)) + + return TreeSpec(typ, context, children_specs) + + +_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " + f"TreeSpec but got item of type {type(treespec)}.", + ) + + if protocol is None: + protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL + + if protocol in _SUPPORTED_PROTOCOLS: + json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) + else: + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) + return str_spec + + +def treespec_loads(serialized: str) -> TreeSpec: + protocol, json_schema = json.loads(serialized) + + if protocol in _SUPPORTED_PROTOCOLS: + return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + +class _DummyLeaf: + def __repr__(self) -> str: + return "*" + + +def treespec_pprint(treespec: TreeSpec) -> str: + dummy_tree = tree_unflatten( + [_DummyLeaf() for _ in range(treespec.num_leaves)], + treespec, + ) + return repr(dummy_tree) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def pytree_to_str(treespec: TreeSpec) -> str: + warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") + return treespec_dumps(treespec) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def str_to_pytree(json: str) -> TreeSpec: + warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") + return treespec_loads(json) + + +def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: + """Get a flat list of arguments to this function + + A slightly faster version of tree_leaves((args, kwargs)) + """ + leaves: List[Any] = [] + for a in args: + _tree_leaves_helper(a, leaves) + for a in kwargs.values(): + _tree_leaves_helper(a, leaves) + return leaves + + +def tree_flatten_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]: + """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. + + Args: + tree: a pytree to flatten. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A tuple where the first element is a list of (key path, leaf) pairs, and the + second element is a :class:`TreeSpec` representing the structure of the flattened + tree. + """ + _, treespec = tree_flatten(tree, is_leaf) + return list(_generate_key_paths((), tree, is_leaf)), treespec + + +def tree_leaves_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> List[Tuple[KeyPath, Any]]: + """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. + + Args: + tree: a pytree. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A list of (key path, leaf) pairs. + """ + return list(_generate_key_paths((), tree, is_leaf)) + + +def _generate_key_paths( + key_path: KeyPath, + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Tuple[KeyPath, Any]]: + if is_leaf and is_leaf(tree): + yield key_path, tree + return + + node_type = _get_node_type(tree) + handler = SUPPORTED_NODES.get(node_type) + if not handler: + # This is a leaf + yield key_path, tree + return + + flatten_with_keys = handler.flatten_with_keys_fn + if flatten_with_keys: + key_children, _ = flatten_with_keys(tree) + for k, c in key_children: + yield from _generate_key_paths((*key_path, k), c, is_leaf) + else: + # We registered this pytree but didn't add a flatten_with_keys_fn, complain. + raise ValueError( + f"Did not find a flatten_with_keys_fn for type: {node_type}. " + "Please pass a flatten_with_keys_fn argument to register_pytree_node." + ) + + +def tree_map_with_path( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but the provided callable takes an additional key path argument. + + Args: + func: A function that takes ``2 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. The first positional argument + to ``func`` is the key path of the leaf in question. The second + positional argument is the value of the leaf. + tree: A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests: A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the + corresponding leaf in ``tree``, ``x`` is the value at that leaf, and + ``xs`` is the tuple of values at corresponding nodes in ``rests``. + """ + keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf) + keypath_leaves = list(zip(*keypath_leaves)) + all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) + + +def keystr(kp: KeyPath) -> str: + """Given a key path, return a pretty-printed representation.""" + return "".join([str(k) for k in kp]) + + +def key_get(obj: Any, kp: KeyPath) -> Any: + """Given an object and a key path, return the value at the key path.""" + for k in kp: + obj = k.get(obj) + return obj diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 58ed679f..2c58b9ae 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -18,6 +18,7 @@ from torch.overrides import is_tensor_method_or_property from . import concrete_tracer as et +from . import pytree_utils from .utils import ( _orig_tuple, _orig_list, @@ -31,7 +32,6 @@ _orig_bool, _orig_slice, _orig_set, - map_recursive, get_frame_record, ) @@ -262,8 +262,9 @@ def __torch_function__(cls, orig_method, types, args=None, kwargs=None): def find_tracer(a): if _orig_isinstance(a, cls): tracers.add(a.tracer) - map_recursive(find_tracer, args) - map_recursive(find_tracer, kwargs) + + pytree_utils.tree_map(find_tracer, args) + pytree_utils.tree_map(find_tracer, kwargs) if _orig_len(tracers) > 1: raise RuntimeError(f'Found multiple different tracers {_orig_list(tracers)} while ' diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 41613bc0..5b449150 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -83,6 +83,7 @@ def __exit__(self, *args): return from . import concrete_proxy as ep +from . import pytree_utils from .function_patcher import FunctionPatcher from .operator_patcher import OperatorPatcherContext from .utils import ( @@ -135,10 +136,6 @@ def __exit__(self, *args): ExtraSEFPatcher, EmptyResult, extract_results_metadata, - flatten_trees_with_func, - flatten_trees_with_func_and_spec, - get_common_spec, - map_trees_with_func, get_frame_record, ) @@ -410,21 +407,26 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] try: if self.cpu_offload: - args, kwargs = tree_to_cuda(args), tree_to_cuda(kwargs) + args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cuda(), args) + kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cuda(), kwargs) result = run(kind, target, args, kwargs) except torch.cuda.OutOfMemoryError: if self.cpu_offload: _logger.warning(f"cuda out of memory, try to trace {target} on cpu.") - args, kwargs = tree_to_cpu(args), tree_to_cpu(kwargs) + args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), args) + kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), kwargs) result = run(kind, target, args, kwargs) else: raise if self.cpu_offload: - args, kwargs, result = tree_to_cpu(args), tree_to_cpu(kwargs), tree_to_cpu(result) + args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), args) + kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), kwargs) + result = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), result) - unexpected_types = types_other_than(result, (*base_types, type(None), torch.Tensor)) - if not contains_types(result, (torch.Tensor,)) and unexpected_types: + if not pytree_utils.tree_any(lambda x: isinstance(x, torch.Tensor), result) and \ + pytree_utils.tree_any(lambda x: not isinstance(x, (*base_types, type(None), torch.Tensor)), result): + unexpected_types = set([type(elem) for elem in pytree_utils.tree_flatten(result)[0] if not isinstance(elem, (*base_types, type(None), torch.Tensor))]) _logger.warning(f"result of target {target} contains unexpected types {unexpected_types}, which is not a common behavior.") torch.cuda.empty_cache() @@ -474,7 +476,11 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: create the nodes for the target and the input of the target (if the target is one of call_method, call_function, call_module). """ with self.do_temp_call_origin(): - args_unwrapped, kwargs_unwrapped = unwrap_nested_proxy(args), unwrap_nested_proxy(kwargs) + def unwrap_nested_proxy(proxy: ep.ConcreteProxy): + return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) + + args_unwrapped = pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, args) + kwargs_unwrapped = pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, kwargs) if self.need_revert(target): with self.patcher.revert(): @@ -902,6 +908,8 @@ class map_wrapper_clz: _fx_wrapped_ori_clz = _orig_map def __new__(cls, the_func, *iterables: Any): + if self.temp_call_origin: + return _orig_map(the_func, *iterables) tracers = _orig_set() for one_iter in iterables: if _orig_isinstance(one_iter, ep.Proxy): @@ -1358,47 +1366,6 @@ def wrapped(*args, **kwargs): return wrapped -def contains_types(pytree, types) -> bool: - """if pytree leaf has the given types, return true""" - return any(flatten_trees_with_func(lambda x: isinstance(x, types), [pytree])[0]) - - -def types_other_than(pytree, given_types) -> Set[Type]: - """return a set of types of the pytree leaf other than given_types""" - types = set(flatten_trees_with_func(lambda x: type(x) if not isinstance(x, given_types) else None, [pytree])[0]) - if None in types: - types.remove(None) - return types - - -def tree_to_cuda(pytree): - """return a same spec pytree with all the given pytree leaf tensor to cuda""" - # any operations under torch.no_grad context will have the result tensor with attribute requires_grad is False, - # here we must follow the original tensor requires_grad attribute when we move tensor to cuda to ensure the correctness of the tensor requires_grad state - return map_trees_with_func(lambda a: a.cuda().requires_grad_(a.requires_grad) if isinstance(a, torch.Tensor) else a, [pytree]) - - -def tree_to_cpu(pytree): - """return a same spec pytree with all the given pytree leaf tensor to cpu""" - # any operations under torch.no_grad context will have the result tensor with attribute requires_grad is False, - # here we must follow the original tensor requires_grad attribute when we move tensor to cpu to ensure the correctness of the tensor requires_grad state - return map_trees_with_func(lambda a: a.cpu().requires_grad_(a.requires_grad) if isinstance(a, torch.Tensor) else a, [pytree]) - - -def unwrap_nested_proxy(pytree): - """ - return a same spec pytree with the ConcreteProxy in the old pytree replaced with ConcreteProxy.value - """ - def unwrap(obj: Any): - while isinstance(obj, ep.ConcreteProxy): - obj = obj.value - return obj - - while contains_types(pytree, (ep.ConcreteProxy,)): - pytree = map_trees_with_func(unwrap, [pytree]) - return pytree - - def update_tree_proxy_value(dst_pytree, src_pytree): """ copy the value from src_pytree to dst_pytree with the dst_pytree spec, @@ -1408,7 +1375,7 @@ def update_tree_proxy_value(dst_pytree, src_pytree): # dst_pytree: {'a': [1, 2, 3]} # src_pytree: {'a': [1, 2, 3, 4]} # then the public spec is {'a': *}, we don't want to flatten the list here. - spec = get_common_spec(tree_flatten(dst_pytree)[1], tree_flatten(src_pytree)[1]) + common_spec = pytree_utils.get_common_spec(pytree_utils.tree_structure(dst_pytree), pytree_utils.tree_structure(src_pytree)) def update_proxy_value(a, b): if isinstance(a, ep.ConcreteProxy): @@ -1417,8 +1384,10 @@ def update_proxy_value(a, b): else: return b - flat_arg = flatten_trees_with_func_and_spec(update_proxy_value, [dst_pytree, src_pytree], spec) - return tree_unflatten(flat_arg, spec) + flat_dst_leaves = pytree_utils.tree_leaves_with_spec(dst_pytree, common_spec) + flat_src_leaves = pytree_utils.tree_leaves_with_spec(src_pytree, common_spec) + new_leaves = [update_proxy_value(dst_leaf, src_leaf) for dst_leaf, src_leaf in zip(flat_dst_leaves, flat_src_leaves)] + return pytree_utils.tree_unflatten(new_leaves, common_spec) @compatibility(is_backward_compatible=True) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py b/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py new file mode 100644 index 00000000..9918446b --- /dev/null +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py @@ -0,0 +1,43 @@ +""" +During tracing, the function or class in this file might be wrapped as another function or class. +If the original function is needed to use (usually in tracer), should call the function in this file. +""" + +# all functions in operator will be wrapped during tracing +from operator import * + +# the wrapped functon/class in builtins +import builtins + +isinstance = builtins.isinstance +issubclass = builtins.issubclass +len = builtins.len +getattr = builtins.getattr +id = builtins.id + +bool = builtins.bool +int = builtins.int +float = builtins.float +frozenset = builtins.frozenset +tuple = builtins.tuple +list = builtins.list +set = builtins.set +dict = builtins.dict + +enumerate = builtins.enumerate +map = builtins.map +range = builtins.range +reversed = builtins.reversed +type = builtins.type +slice = builtins.slice + +# the wrapped functon/class method/class in torch +import torch + +torch_module_getattr = torch.nn.Module.__getattr__ +torch_module_getattribute = torch.nn.Module.__getattribute__ +torch_module_call = torch.nn.Module.__call__ +torch_agfunc_apply = torch.autograd.function.Function.apply +torch_assert = torch._assert +torch_Size = torch.Size +torch_finfo = torch.finfo diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py new file mode 100644 index 00000000..b428c0e0 --- /dev/null +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py @@ -0,0 +1,105 @@ +""" +This file is the pytree extension by nnscaler. +""" +from collections import namedtuple +from typing import Any, List, Tuple, Iterable + +from . import orig_func, _pytree +from ._pytree import * + +import nnscaler.graph.parser.fx.concrete_trace_utils.concrete_proxy as cct + + +# if pytree is a ConcreteProxy, type(pytree) will return the type of ConcreteProxy.value +# so here we need add additional check for ConcreteProxy. +# this function will override the `_get_node_type` function defined in the original pytree code +def _get_node_type(pytree: _pytree.PyTree) -> Any: + if orig_func.isinstance(pytree, cct.ConcreteProxy): + return orig_func.type(pytree) + if _pytree._is_namedtuple_instance(pytree): + return namedtuple + return type(pytree) + +_pytree._get_node_type = _get_node_type + + +# by default, the registered types are: +# builtins.tuple +# builtins.list +# builtins.dict +# collections.namedtuple +# collections.OrderedDict +# collections.defaultdict +# collections.deque + +# register slice to pytree. +def _slice_flatten(d: slice) -> Tuple[List[Any], Context]: + return [d.start, d.stop, d.step], None + + +def _slice_flatten_with_keys( + d: Tuple[Any, ...] +) -> Tuple[List[Tuple[_pytree.KeyEntry, Any]], Context]: + values, context = _slice_flatten(d) + return [('start', values[0]), ('stop', values[1]), ('step', values[2])], context + + +def _slice_unflatten(values: Iterable[Any], context: Context) -> slice: + return slice(*values) + + +_pytree._private_register_pytree_node( + slice, + _slice_flatten, + _slice_unflatten, + serialized_type_name="builtins.slice", + flatten_with_keys_fn=_slice_flatten_with_keys, +) + + +def tree_leaves_with_spec(pytree: _pytree.PyTree, spec: TreeSpec) -> List: + """ + Flat a pytree with a given spec. + + Example: + + pytree = [1, (2, {3: 4})] + spec = TreeSpec([*, (*, *)]) + + # the returned value is + [1, 2, {3: 4}] + """ + assert isinstance(spec, TreeSpec) + + if isinstance(spec, LeafSpec): + return [pytree] + + flatten_fn = _pytree.SUPPORTED_NODES[spec.type].flatten_fn + child_pytrees, _ = flatten_fn(pytree) + + if len(child_pytrees) != len(spec.children_specs): + raise RuntimeError(f'The number of pytree children is not equal to the give specs.') + + result = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = tree_leaves_with_spec(child, child_spec) + result += flat + + return result + + +def get_common_spec(*specs: TreeSpec) -> TreeSpec: + """ + Return the common part of treespecs. + For example: + specs[0] is {'a': [*,], 'b': [*, *]} + specs[1] is {'a': [*,], 'b': [*, *, *]} + common spec is {'a': [*,], 'b': *} + """ + if tree_any(lambda spec: isinstance(spec, LeafSpec), specs): + return LeafSpec() + if all(spec.type == specs[0].type and spec.context == specs[0].context for spec in specs): + if all(len(spec.children_specs) == len(specs[0].children_specs) for spec in specs): + children_specs = [get_common_spec(*children_specs) for children_specs in zip(*(spec.children_specs for spec in specs))] + return TreeSpec(type=specs[0].type, context=specs[0].context, children_specs=children_specs) + return LeafSpec() diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py index c1ba419b..b925ca7a 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py @@ -2,20 +2,15 @@ # Licensed under the MIT license. import builtins -from collections import namedtuple from dataclasses import dataclass import importlib import operator import traceback from pathlib import Path -from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Type, List +from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Type import torch -import torch.utils._pytree as torch_pytree from torch.fx.node import Node, map_aggregate, _side_effectful_functions -from torch.utils._pytree import tree_flatten, tree_unflatten, LeafSpec, TreeSpec, SUPPORTED_NODES - -from . import concrete_proxy as ep DICT_KEYS_TYPE = type({}.keys()) DICT_VALUES_TYPE= type({}.values()) @@ -76,144 +71,6 @@ } -def map_recursive(fn: Callable, arg) -> Any: - """ - Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. - """ - if _orig_type(arg) != torch.Size and _orig_isinstance(arg, _orig_tuple): - t = _orig_tuple(map_recursive(fn, elem) for elem in arg) - # Support NamedTuple (if it has `_fields`) by repacking into original type. - return t if not hasattr(arg, '_fields') else _orig_type(arg)(*t) - elif _orig_isinstance(arg, _orig_list): - return _orig_list(map_recursive(fn, elem) for elem in arg) - elif _orig_isinstance(arg, _orig_dict): - return {k: map_recursive(fn, v) for k, v in arg.items()} - else: - return fn(arg) - - -def _get_node_type(pytree: Any) -> Any: - if isinstance(pytree, ep.ConcreteProxy): - return _orig_type(pytree) - if torch_pytree._is_namedtuple_instance(pytree): - return namedtuple - return type(pytree) - -torch_pytree._get_node_type = _get_node_type - - -def get_common_spec(dst_spec: TreeSpec, src_sepc: TreeSpec) -> TreeSpec: - """ - Return the common part of two treespec. - For example: - dst_spec is {'a': [*,], 'b': [*, *]} - src_sepc is {'a': [*,], 'b': [*, *, *]} - common spec is {'a': [*,], 'b': *} - """ - if isinstance(dst_spec, LeafSpec) or isinstance(src_sepc, LeafSpec): - return LeafSpec() - if dst_spec.type == src_sepc.type and dst_spec.context == src_sepc.context: - if len(dst_spec.children_specs) == len(src_sepc.children_specs): - children_specs = [get_common_spec(dst, src) for dst, src in zip(dst_spec.children_specs, src_sepc.children_specs)] - return TreeSpec(type=dst_spec.type, context=dst_spec.context, children_specs=children_specs) - return LeafSpec() - - -def flatten_trees_with_func(fn, pytrees) -> Tuple[List[Any], TreeSpec]: - """ - Each pytree in pytrees should have the same structure. - - Example: - - pytrees = [ - [1, 2, (3, 4)], # pytree 1 - [5, 6, (7, 8)], # pytree 2 - ] - - # the returned value is - [fn(1, 5), fn(2, 6), fn(3, 7), fn(4, 8)], [*, *, (*, *)] - """ - flat_trees = [tree_flatten(pytree) for pytree in pytrees] - flat_args = [v[0] for v in flat_trees] - specs = [v[1] for v in flat_trees] - - if not all(len(flat_arg) == len(flat_args[0]) for flat_arg in flat_args): - raise RuntimeError('the element number of pytrees are not equal') - if not all(str(spec) == str(specs[0]) for spec in specs): - raise RuntimeError('the structure of pytrees are not equal') - - return [fn(*vals) for vals in zip(*flat_args)], specs[0] - - -def map_trees_with_func(fn, pytrees): - """ - Each pytree in pytrees should have the same structure. - The returned value has the same structure with pytree in pytrees. - - Example: - - pytrees = [ - [1, 2, (3, 4)], # pytree 1 - [5, 6, (7, 8)], # pytree 2 - ] - - # the returned value is - [fn(1, 5), fn(2, 6), (fn(3, 7), fn(4, 8))] - """ - flat_args, spec = flatten_trees_with_func(fn, pytrees) - return tree_unflatten([i for i in flat_args], spec) - - -def flatten_tree_with_spec(pytree, spec: TreeSpec) -> List: - """ - Flat a pytree with a given spec. - - Example: - - pytree = [1, (2, {3: 4})] - spec = TreeSpec([*, (*, *)]) - - # the returned value is - [1, 2, {3: 4}] - """ - assert isinstance(spec, TreeSpec) - - if isinstance(spec, LeafSpec): - return [pytree] - - flatten_fn = SUPPORTED_NODES[spec.type].flatten_fn - child_pytrees, _ = flatten_fn(pytree) - - if len(child_pytrees) != len(spec.children_specs): - raise RuntimeError(f'The number of pytree children is not equal to the give specs.') - - result = [] - for child, child_spec in zip(child_pytrees, spec.children_specs): - flat = flatten_tree_with_spec(child, child_spec) - result += flat - - return result - - -def flatten_trees_with_func_and_spec(fn, pytrees, spec): - """ - Example: - - pytrees = [ - [1, (2, {3: 4})], - [5, (6, 7)] - ] - spec = [*, (*, *)] - - # the returned value is - [fn(1, 5), fn(2, 6), fn({3: 4}, 7)] - """ - flat_args = [flatten_tree_with_spec(pytree, spec) for pytree in pytrees] - if not all(len(flat_arg) == len(flat_args[0]) for flat_arg in flat_args): - raise RuntimeError('the element number of pytrees are not equal') - return [fn(*vals) for vals in zip(*flat_args)] - - class ExtraSEFPatcher: def __init__(self, extra_side_effectful_functions: Set[Callable]): self.extra_side_effectful_functions = extra_side_effectful_functions diff --git a/tests/graph/tracer/test_pytree.py b/tests/graph/tracer/test_pytree.py index 24d3035f..096a52a9 100644 --- a/tests/graph/tracer/test_pytree.py +++ b/tests/graph/tracer/test_pytree.py @@ -1,111 +1,51 @@ -import pytest - -from torch.utils._pytree import tree_flatten - -from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_proxy import ( - ConcreteProxy, - Node, +from nnscaler.graph.parser.fx.concrete_trace_utils import pytree_utils +from nnscaler.graph.parser.fx.concrete_trace_utils.pytree_utils import ( + get_common_spec, + tree_leaves_with_spec, ) +from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_proxy import ConcreteProxy from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import ( - update_tree_proxy_value + update_tree_proxy_value, ) -from nnscaler.graph.parser.fx.concrete_trace_utils.utils import ( - flatten_tree_with_spec, - flatten_trees_with_func, - flatten_trees_with_func_and_spec, - map_trees_with_func -) - - -def test_flatten_tree_with_spec(): - pytree_1 = [1, (2, (3, 4))] - pytree_2 = [1, (2, (3, (4, 5)))] - pytree_3 = [1, (2, (3,))] - _, spec = tree_flatten(pytree_1) - - assert flatten_tree_with_spec(pytree_2, spec) == [1, 2, 3, (4, 5)] - - # pytree_3 can not flatten by pytree_1 spec, so it should raise error - with pytest.raises(RuntimeError): - flatten_tree_with_spec(pytree_3, spec) - -def test_flatten_trees_with_func(): - pytree_1 = [1, (2, {3: 4})] - pytree_2 = [5, (6, {3: 5})] - flat_args, spec = flatten_trees_with_func(lambda a, b: a + b, [pytree_1, pytree_2]) - assert flat_args == [6, 8, 9] - assert spec == tree_flatten(pytree_1)[1] - pytree_3 = [1, (2, (3,))] - # pytree_3 has different spec with pytree_1 and pytree_2, so it should raise error - with pytest.raises(RuntimeError): - flatten_trees_with_func(lambda a, b, c: a + b + c, [pytree_1, pytree_2, pytree_3]) +def test_pytree_related_utils(): + pytree_1 = {'a': [1, {'b': 2}]} + pytree_2 = {'a': [3, 4]} + pytree_3 = {'a': [5, [6, 7]]} + pytree_1_spec = pytree_utils.tree_flatten(pytree_1)[1] + pytree_2_spec = pytree_utils.tree_flatten(pytree_2)[1] + pytree_3_spec = pytree_utils.tree_flatten(pytree_3)[1] -def test_flatten_trees_with_func_and_spec(): - pytree_0 = [1, (2, 3)] - _, spec = tree_flatten(pytree_0) + # test get_common_spec + common_spec = get_common_spec(pytree_1_spec, pytree_2_spec, pytree_3_spec) - def merge(a, b): - if isinstance(a, dict): - assert isinstance(b, dict) - return {**a, **b} - else: - return a + b + assert common_spec == \ + pytree_utils.TreeSpec(dict, ['a'], [pytree_utils.TreeSpec(list, None, [pytree_utils.LeafSpec(), pytree_utils.LeafSpec()])]),\ + f"expect TreeSpec(dict, ['a'], [TreeSpec(list, None, [*, *])]), but get {common_spec}" - pytree_1 = [1, (2, {3: 4})] - pytree_2 = [5, (6, {4: 5})] - assert flatten_trees_with_func_and_spec(merge, [pytree_1, pytree_2], spec) == [6, 8, {3: 4, 4: 5}] + # test tree_leaves_with_spec + assert tree_leaves_with_spec(pytree_1, common_spec) == [1, {'b': 2}] + assert tree_leaves_with_spec(pytree_2, common_spec) == [3, 4] + assert tree_leaves_with_spec(pytree_3, common_spec) == [5, [6, 7]] - -def test_map_trees_with_func(): - pytree_1 = [1, (2, {3: 4})] - pytree_2 = [5, (6, {3: 5})] - - assert map_trees_with_func(lambda a, b: a + b, [pytree_1, pytree_2]) == [6, (8, {3: 9})] - - -def test_update_tree_proxy_value(): + # test update_tree_proxy_value class DummyNode: - def __init__(self, name): - self.name = name - self.graph = None - - pytree_1 = ConcreteProxy(node=DummyNode('test_node'), value={'a': {'b': [1, 2]}}, tracer=None) - pytree_2 = {'a': {'b': [1, 3]}} - new_pytree = update_tree_proxy_value(pytree_1, pytree_2) - assert str(new_pytree) == "ConcreteProxy(test_node, {'a': {'b': [1, 3]}})" - - pytree_1 = {'a': ConcreteProxy(node=DummyNode('test_node'), value={'b': [1, 2]}, tracer=None)} - pytree_2 = {'a': {'b': [1, 3]}} - new_pytree = update_tree_proxy_value(pytree_1, pytree_2) - assert str(new_pytree) == "{'a': ConcreteProxy(test_node, {'b': [1, 3]})}" - - pytree_1 = ConcreteProxy( - node=DummyNode('t1'), - value={'a': ConcreteProxy( - node=DummyNode('t2'), - value={'b': ConcreteProxy( - node=DummyNode('t3'), - value=[1, ConcreteProxy( - node=DummyNode('t4'), - value=2, - tracer=None - )], - tracer=None) - }, - tracer=None) - }, - tracer=None - ) - pytree_2 = {'a': {'b': [1, 3]}} - new_pytree = update_tree_proxy_value(pytree_1, pytree_2) - assert str(new_pytree) == "ConcreteProxy(t1, {'a': ConcreteProxy(t2, {'b': ConcreteProxy(t3, [1, ConcreteProxy(t4, 3)])})})" - - pytree_1 = {'a': ConcreteProxy(node=DummyNode('test_node'), value={'b': [1, 2]}, tracer=None)} - pytree_2 = {'b': {'a': [1, 3]}} - new_pytree = update_tree_proxy_value(pytree_1, pytree_2) - # because the spec of pytree_1 - {'a': {'b': *}} - and pytree_2 - {'b': {'a': *}} - is completely differet, - # the result is directly pytree_2 - assert str(new_pytree) == "{'b': {'a': [1, 3]}}" + pass + class DummyTracer: + pass + + pytree_src = {'key': 1, 'value': (2, 3)} + pytree_dst = { + 'key': ConcreteProxy(DummyNode(), 4, DummyTracer()), + 'value': ConcreteProxy( + DummyNode(), + (5, ConcreteProxy(DummyNode(), 6, DummyTracer())), + DummyTracer() + ) + } + new_pytree = update_tree_proxy_value(pytree_dst, pytree_src) + assert new_pytree['key'].value == 1 + assert new_pytree['value'].value[0] == 2 + assert new_pytree['value'].value[1].value == 3 From 6bb0a9ed5f1c07a5159137ccf86170beeaf922cb Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 15 Aug 2024 07:36:45 +0000 Subject: [PATCH 1704/1892] Merged PR 2233: Minitrainer logging: log tag support Minitrainer logging: log tag support --- docs/source/trainer.md | 19 ++++++++++----- nnscaler/cli/loggers/logger_base.py | 2 +- nnscaler/cli/loggers/tensorboard.py | 38 ++++++++++++++--------------- nnscaler/cli/loggers/wandb.py | 4 ++- nnscaler/cli/trainer.py | 12 ++++++--- tests/cli/test_trainer.py | 7 +++--- 6 files changed, 48 insertions(+), 34 deletions(-) diff --git a/docs/source/trainer.md b/docs/source/trainer.md index ed24be93..4c67d2af 100644 --- a/docs/source/trainer.md +++ b/docs/source/trainer.md @@ -352,7 +352,7 @@ class CheckpointConfig: The checkpoint is a folder with as many files as the world size. - `"merged"`: everything has been merged into a single file. Used internally only when you merge the checkpoint files via `Trainer.merge_checkpoints` - `save_last` (`bool`): Whether to save the last checkpoint. Default is `True`. -- `save_best` (`bool`): Whether to save the best checkpoint. Default is `True`. +- `save_best` (`bool`): Whether to save the best (lowest `val_loss`) checkpoint. Default is `True`. - `symlink_best_and_last` (`bool`): Whether to use symlink (instead of copy) to the best and last checkpoint. Default is `True`. - `every_n_train_steps` (`Optional[int]`): Save the checkpoint every `every_n_train_steps` training steps. Default is `None`, which means no checkpoint is saved based on training steps. - `every_n_epochs` (`Optional[int]`): Save the checkpoint every `every_n_epochs` epochs. Default is `None`, which means no checkpoint is saved based on epochs. @@ -362,14 +362,21 @@ Default is `None`, which means all checkpoints are kept. We will not resume (nor report error) if resume_from is `last` or `best` but the corresponding checkpoint does not exist. Default is `None`. -Please note when the parallel plan is changed (i.e you re-trace the model with different configurations), +Please note + +1. When the parallel plan is changed (i.e you re-trace the model with different configurations), the checkpoints become incompatible, and can't be loaded any more. You must firstly merge the checkpoints to a merged checkpoint with `Trainer.merge_checkpoint` and then load the merged checkpoint just like a regular checkpoint. -```python -def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): -``` -where `checkpoint_files` is a list of checkpoint files to merge, and `output_file` is the output file path. + ```python + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): + ``` + where `checkpoint_files` is a list of checkpoint files to merge, and `output_file` is the output file path. + +2. When a checkpoint is saved, +we will run validation on the validation dataset and save the validation loss to the checkpoint file. +The validation run will ignore the `val_every_n_train_steps` and `val_every_n_epochs` configurations. +If no valid dataset is provided, validation is skipped and `valid_loss` is set to `train_loss` by default. ### Other configs - `gen_savedir` (`str`): The directory to save the generated files. Default is `./.nnscaler`. diff --git a/nnscaler/cli/loggers/logger_base.py b/nnscaler/cli/loggers/logger_base.py index a507bf1e..00b2f0bc 100644 --- a/nnscaler/cli/loggers/logger_base.py +++ b/nnscaler/cli/loggers/logger_base.py @@ -16,7 +16,7 @@ def setup(self, config: Dict) -> None: ... @abstractmethod - def log_metrics(self, metrics: Dict[str, float], step: int) -> None: + def log_metrics(self, metrics: Dict[str, float], step: int, *, tag: Optional[str] = None) -> None: ... @abstractmethod diff --git a/nnscaler/cli/loggers/tensorboard.py b/nnscaler/cli/loggers/tensorboard.py index ed20a92c..43295fe4 100644 --- a/nnscaler/cli/loggers/tensorboard.py +++ b/nnscaler/cli/loggers/tensorboard.py @@ -6,7 +6,7 @@ import yaml import torch try: - _tensorboard_writers = [] + _tensorboard_writers = {} from torch.utils.tensorboard import SummaryWriter except ImportError: SummaryWriter = None @@ -31,52 +31,50 @@ def __init__( self._name = name self._root_dir = Path(root_dir).expanduser().resolve() self._kwargs = kwargs - - self._summary_writer = None + self._yaml_config = None # will be set in `setup` @property - def log_dir(self) -> str: + def log_dir(self) -> Path: """ Root directory to save logging output, which is `_log_dir/_name`. """ sub_path = [s for s in [self._name] if s] ld = self._root_dir.joinpath(*sub_path) ld.mkdir(parents=True, exist_ok=True) - return str(ld) + return ld @rank_zero_only def setup(self, config: Dict) -> None: - self._ensure_writer() - self._summary_writer.add_text("config", yaml.dump(config)) + self._yaml_config = yaml.dump(config) - def _ensure_writer(self): - if not self._summary_writer: - self._summary_writer = SummaryWriter(log_dir=self.log_dir, **self._kwargs) - _tensorboard_writers.append(self._summary_writer) - return self._summary_writer + def _get_or_create_writer(self, tag: Optional[str] = None): + tag = tag or '' + if tag not in _tensorboard_writers: + _tensorboard_writers[tag] = SummaryWriter(log_dir=self.log_dir / tag, **self._kwargs) + _tensorboard_writers[tag].add_text("config", self._yaml_config) + return _tensorboard_writers[tag] @rank_zero_only - def log_metrics(self, metrics: Dict[str, float], step: int) -> None: - self._ensure_writer() + def log_metrics(self, metrics: Dict[str, float], step: int, *, tag: Optional[str] = None) -> None: + summary_writer = self._get_or_create_writer(tag) for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() if isinstance(v, dict): - self._summary_writer.add_scalars(k, v, step) + summary_writer.add_scalars(k, v, step) else: - self._summary_writer.add_scalar(k, v, step) + summary_writer.add_scalar(k, v, step) @rank_zero_only def finalize(self) -> None: - if self._summary_writer: - self._summary_writer.close() - _tensorboard_writers.remove(self._summary_writer) + # will do nothing, as the writers will be closed on exit + pass def _close_writers(): - for w in _tensorboard_writers: + for w in _tensorboard_writers.values(): w.close() # Close all writers on exit diff --git a/nnscaler/cli/loggers/wandb.py b/nnscaler/cli/loggers/wandb.py index c23c25b3..08e2dd76 100644 --- a/nnscaler/cli/loggers/wandb.py +++ b/nnscaler/cli/loggers/wandb.py @@ -50,7 +50,9 @@ def setup(self, config: Dict) -> None: ) @rank_zero_only - def log_metrics(self, metrics: Dict[str, float], step: int) -> None: + def log_metrics(self, metrics: Dict[str, float], step: int, *, tag: Optional[str] = None) -> None: + prefix = "" if tag is None else tag + "/" + metrics = {prefix + k: v for k, v in metrics.items()} wandb.log(metrics, step=step) @rank_zero_only diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 3d94bcab..3c90eb40 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -276,9 +276,9 @@ def _log_finalize(self): for logger in self.loggers: logger.finalize() - def _log_metrics(self, metrics: Dict[str, float], step: int): + def log_metrics(self, metrics: Dict[str, float], step: int, *, tag: Optional[str] = None): for logger in self.loggers: - logger.log_metrics(metrics, step) + logger.log_metrics(metrics, step, tag=tag) def _log_config(self, config: Dict): for logger in self.loggers: @@ -584,7 +584,7 @@ def _validate(self, step_stat: _StepStat): self.hook.on_val_end(self, loss) step_stat.val_loss = loss - self._log_metrics(asdict(step_stat), self.num_train_steps) + self.log_metrics(asdict(step_stat), self.num_train_steps, tag='val') return step_stat.val_loss def train_epoch(self, epoch): @@ -675,6 +675,12 @@ def train_epoch(self, epoch): if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'step': self.lr_scheduler.step() + self.log_metrics( + {k:v for k, v in asdict(step_stat).items() if v is not None}, + self.num_train_steps, + tag='train' + ) + # validate and save checkpoint if self.train_args.checkpoint.every_n_train_steps and \ self.num_train_steps % self.train_args.checkpoint.every_n_train_steps == 0: diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index f588e9ac..61365777 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -41,9 +41,10 @@ def trainer_logging_worker(save_dir): if torch.distributed.get_rank() == 0: assert (tb_log_savedir / 'test-cli').exists() - tfevents = list((tb_log_savedir / 'test-cli').glob('events.out.tfevents.*')) - assert len(tfevents) == 1 - assert tfevents[0].stat().st_size > 1000 + for tag in ['val', 'train']: + tfevents = list((tb_log_savedir / 'test-cli' / tag).glob('events.out.tfevents.*')) + assert len(tfevents) == 1 + assert tfevents[0].stat().st_size > 1000 assert (wandb_log_savedir / 'wandb').exists() wandb_offline_dir = list((wandb_log_savedir / 'wandb').glob('offline-run-*')) From 0c394ffe51ce59e08aa364da787ecb6be4ba3b8d Mon Sep 17 00:00:00 2001 From: "Xin Ji (CSI Interfusion Co Ltd)" Date: Thu, 15 Aug 2024 07:37:58 +0000 Subject: [PATCH 1705/1892] Merged PR 2211: support batchnorm2d 1. Wrap batchnorm2d/instancenorm2d as a customized function because it has control flow in its forward function. Create new batchnorm2d/instancenorm2 module for replacing the original modules using a utility function automatically. 2. support communication within operator. An example operator is batchnorm2d when partitioning the batch dimension. --- nnscaler/graph/function/wrapnn.py | 470 +++++++++++++++ tests/parallel_module/test_normlayer.py | 765 ++++++++++++++++++++++++ 2 files changed, 1235 insertions(+) create mode 100644 nnscaler/graph/function/wrapnn.py create mode 100644 tests/parallel_module/test_normlayer.py diff --git a/nnscaler/graph/function/wrapnn.py b/nnscaler/graph/function/wrapnn.py new file mode 100644 index 00000000..70651a6e --- /dev/null +++ b/nnscaler/graph/function/wrapnn.py @@ -0,0 +1,470 @@ +""" +This file deals with some special nn modules which have control flows (if/else) in their forward function. +These control flows go different branches according to self.training. +So we rewrite these nn modules to update their forward function, the new forward function uses a registered +customized function to wrap the control flows. nnscaler treats the customized function as a black-box leaf node. + +Currently, this file wraps the following nn modules: + nn.BatchNorm2d + nn.InstanceNorm2d + +At last, we provide a utility function to replace the original nn modules with the wrapped nn modules. +""" + +from dataclasses import dataclass +from typing import Tuple, List, Dict +from typing import Tuple +import warnings +import torch +from torch import Tensor +from torch.nn import functional as F +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.instancenorm import _InstanceNorm +from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm + +from nnscaler.graph.function.function import _unwrap_value +from nnscaler.graph.parser.register import register_op +from nnscaler.ir.operator import IRFwOperation +from nnscaler.ir.cten import IRObject, IRTensor +from nnscaler.runtime.device import DeviceGroup + + +def wrap_batchnorm2d_func( + input: Tensor, + weight: Tensor, + bias: Tensor, + running_mean: Tensor, + running_var: Tensor, + num_batches_tracked: Tensor, + momentum: float = 0.1, + training: bool = True, + track_running_stats: bool = True, + eps: float = 1e-05, + process_group: Tuple[int] = None, +) -> Tensor: + """ + This function wraps the original batchnorm2d forward function, because it has both control flows and nccl communication. + Most of the code is copied from https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#SyncBatchNorm + NOTE: the non-tensor inputs must be kwargs with default value. + NOTE: the invocation of the function must use kw format to pass kwargs. + NOTE: process_group and world_size is for the internal nccl communication, process_group specifies + the group of devices that will perform the synchronization, and world_size specifies the number of devices. + """ + if input.dim() != 4: + raise ValueError(f"expected 4D input (got {input.dim()}D input)") + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = momentum + + if training and track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if num_batches_tracked is not None: # type: ignore[has-type] + num_batches_tracked.add_(1) # type: ignore[has-type] + if momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + + if training: + bn_training = True + else: + bn_training = (running_mean is None) and (running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + # If buffers are not to be tracked, ensure that they won't be updated + running_mean = running_mean if not training or track_running_stats else None + running_var = running_var if not training or track_running_stats else None + # Don't sync batchnorm stats in inference mode (model.eval()). + need_sync = bn_training and training and process_group is not None + if need_sync: + # currently only GPU/PrivateUse1 input is supported + process_group = DeviceGroup().get_group(process_group) + if process_group is None: + process_group = torch.distributed.group.WORLD + world_size = torch.distributed.get_world_size(process_group) + need_sync = world_size > 1 + # fallback to framework BN when synchronization is not necessary + if not need_sync: + return F.batch_norm( + input, + running_mean, + running_var, + weight, + bias, + bn_training, + exponential_average_factor, + eps, + ) + else: + assert bn_training + return sync_batch_norm.apply( + input, + weight, + bias, + running_mean, + running_var, + eps, + exponential_average_factor, + process_group, # type: ignore[possibly-undefined] + world_size, # type: ignore[possibly-undefined] + ) + + +def unwrap_if_irobject(x): + return x.value if isinstance(x, IRObject) and not isinstance(x, IRTensor) else x + + +def batchnorm2d_annotation_fn(*inputs, **kwargs): + assert ( + len(inputs) == 6 + ), f"Expected 6 inputs: input, weight, bias, running_mean, running_var, num_batches_tracked, but got {len(inputs)} {inputs}." + input, weight, bias, running_mean, running_var, num_batches_tracked = inputs + """ + Restrictions: + 1. If `weight` is None, then `bias` must also be None. This is because in the absence of `weight`, + BatchNorm2d does not apply affine transformation, which means there is no need for `bias`. + 2. If `running_mean` is None, then `running_var` and `num_batches_tracked` must also be None. + This is because `running_mean` and `running_var` are used for tracking the statistics of + the batch normalization during training. If `running_mean` is not provided, it implies + that the module should not track statistics, hence `running_var` and `num_batches_tracked` + should also be absent. + Reference: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + """ + weight = unwrap_if_irobject(weight) + bias = unwrap_if_irobject(bias) + running_mean = unwrap_if_irobject(running_mean) + running_var = unwrap_if_irobject(running_var) + num_batches_tracked = unwrap_if_irobject(num_batches_tracked) + + if weight is None: + assert bias is None + wb_annos = "?, ?" + else: + assert isinstance(weight, IRTensor) + assert isinstance(bias, IRTensor) + wb_annos = "c, c" + + if running_mean is None: + assert ( + running_var is None and num_batches_tracked is None + ), "If running_mean is None, both running_var and num_batches_tracked must also be None" + r_annos = "?, ?, ?" + else: + assert isinstance(running_mean, IRTensor) + assert isinstance(running_var, IRTensor) + assert isinstance(num_batches_tracked, IRTensor) + r_annos = "c, c, 1" + + return "n c h^ w^, " + wb_annos + ", " + r_annos + " -> n c h^ w^" + + +class NnScalerBatchNorm2d(_BatchNorm): + def forward(self, input: Tensor) -> Tensor: + return wrap_batchnorm2d_func( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.num_batches_tracked, + momentum=self.momentum, + training=self.training, + track_running_stats=self.track_running_stats, + eps=self.eps, + ) + + +def batchnorm2d_reinit(module: _BatchNorm) -> _BatchNorm: + """Reinitialize the batchnorm2d module with the same parameters and arguments, but using + the wrapped module NnScalerBatchNorm2d.""" + if not isinstance(module, _BatchNorm): + raise TypeError(f"Expected module of type _BatchNorm, but got {type(module)}") + new_module = NnScalerBatchNorm2d( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + ) + if module.affine: + with torch.no_grad(): + new_module.weight = module.weight + new_module.bias = module.bias + new_module.running_mean = module.running_mean + new_module.running_var = module.running_var + new_module.num_batches_tracked = module.num_batches_tracked + return new_module + + +def emit_batchnorm2d( + node: IRFwOperation, + args: List[str], + kwargs: Dict[str, str], + runtime_devid: int, + plan_ndevs: int, + runtime_ndevs: int, +) -> str: + """Special rule to generate batchnorm2d node""" + + signature = node.signature + + # Compute scale unit device ids + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f"{key}={val}" + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [ + i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f + ] + assert ( + len(partition_dims) <= 1 + ), f"only support one partition dim for now, got {partition_dims}" + + if len(partition_dims) == 1 and partition_dims[0] == 0: # partition on batch dim + # if batch dim is partitioned, it means batchnorm is partitioned in batch dim both + # within scaleunit and across scaleunits + kw_pairs.append(f"process_group={tuple(range(runtime_ndevs))}") + else: + # the synchronization should occur across scaleunits + assert len(partition_dims) == 0 or partition_dims[0] != 0 + if runtime_ndevs == len(scale_unit_dev_ids): + kw_pairs.append("process_group=None") + else: + start_id = runtime_devid % len(scale_unit_dev_ids) + process_group = tuple( + range(start_id, runtime_ndevs, len(scale_unit_dev_ids)) + ) + kw_pairs.append(f"process_group={process_group}") + assert len(process_group) == runtime_ndevs // len(scale_unit_dev_ids) + + args_str = ", ".join(args) + kwargs_str = ", ".join(kw_pairs) + return f"{signature}({args_str}, {kwargs_str})" + + +register_op(batchnorm2d_annotation_fn, emit_fn=emit_batchnorm2d)(wrap_batchnorm2d_func) + + +""" + This function wraps the original InstanceNorm2d forward function. + + The logic in this function is exactly the same as in the original PyTorch implementation. + We copied the logic here to register it as a customized operation because nnscaler's + `register_op` only supports functions, not nn.Module classes. Therefore, this function + serves as a wrapper around the InstanceNorm2d forward logic, treating the entire function + as a black-box leaf node in nnscaler. +""" + + +def wrap_instancenorm2d_func( + input: Tensor, + weight: Tensor, + bias: Tensor, + running_mean: Tensor, + running_var: Tensor, + momentum: float = 0.1, + eps: float = 1e-05, + training: bool = True, + track_running_stats: bool = False, + num_features: int = 0, + affine: bool = False, +) -> Tensor: + """ + This operation applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) + note: `InstanceNorm2d` is appliedon each channel of channeled data like RGB images,usually don't apply affine transform. + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` or :math:`(C, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + Reference: https://pytorch.org/docs/stable/_modules/torch/nn/modules/instancenorm.html#InstanceNorm2d + """ + + def _get_no_batch_dim(): + """ + This function returns the dimension that indicates no batch dimension for InstanceNorm2d. + For 2D data, typically we have the following dimensions: + - 4D input: (N, C, H, W) where N is the batch size + - 3D input: (C, H, W) without the batch dimension + + InstanceNorm2d can work with both 4D and 3D inputs. When the input is 3D, we need to temporarily + add a batch dimension to perform normalization, and then remove it afterwards. + """ + return 3 + + if input.dim() not in (3, 4): + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") + + """ + Explanation: + - For a 4D input (N, C, H, W), the channel dimension is the 2nd dimension (index 1). + - For a 3D input (C, H, W), the channel dimension is the 1st dimension (index 0). + This logic ensures that we correctly identify the channel dimension for both 3D and 4D inputs. + """ + feature_dim = input.dim() - _get_no_batch_dim() + if input.size(feature_dim) != num_features: + if affine: + raise ValueError( + f"expected input's size at dim={feature_dim} to match num_features" + f" ({num_features}), but got: {input.size(feature_dim)}." + ) + else: + warnings.warn( + f"input's size at dim={feature_dim} does not match num_features. " + "You can silence this warning by not passing in num_features, " + "which is not used because affine=False" + ) + + if input.dim() == _get_no_batch_dim(): + return F.instance_norm( + input.unsqueeze(0), + running_mean, + running_var, + weight, + bias, + training or not track_running_stats, + momentum, + eps, + ).squeeze(0) + + return F.instance_norm( + input, + running_mean, + running_var, + weight, + bias, + training or not track_running_stats, + momentum, + eps, + ) + + +def instancenorm2d_annotation_fn(*inputs, **kwargs): + assert ( + len(inputs) == 5 + ), "Expected 5 inputs: input, weight, bias, running_mean, running_var" + input, weight, bias, running_mean, running_var = inputs + + weight = unwrap_if_irobject(weight) + bias = unwrap_if_irobject(bias) + running_mean = unwrap_if_irobject(running_mean) + running_var = unwrap_if_irobject(running_var) + + if weight is None: + assert bias is None + wb_annos = "?, ?" + else: + assert isinstance(weight, IRTensor) + assert isinstance(bias, IRTensor) + wb_annos = "c^, c^" + + if running_mean is None: + assert ( + running_var is None + ), "If running_mean is None, running_var must also be None" + r_annos = "?, ?" + else: + assert isinstance(running_mean, IRTensor) + assert isinstance(running_var, IRTensor) + r_annos = "c^, c^" + + # FIXME: c cannot be partitioned, because the kwargs num_features cannot be updated for now + return "n c^ h^ w^, " + wb_annos + ", " + r_annos + " -> n c^ h^ w^" + + +register_op(instancenorm2d_annotation_fn)(wrap_instancenorm2d_func) + + +class NnScalerInstanceNorm2d(_InstanceNorm): + def forward(self, input: Tensor) -> Tensor: + return wrap_instancenorm2d_func( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + momentum=self.momentum, + eps=self.eps, + training=self.training, + track_running_stats=self.track_running_stats, + num_features=self.num_features, + affine=self.affine, + ) + + +def instancenorm2d_reinit(module: _InstanceNorm) -> _InstanceNorm: + """Reinitialize the instancenorm2d module with the same parameters and arguments, but using + the wrapped module NnScalerInstanceNorm2d.""" + new_module = NnScalerInstanceNorm2d( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + ) + if module.affine: + with torch.no_grad(): + new_module.weight = module.weight + new_module.bias = module.bias + new_module.running_mean = module.running_mean + new_module.running_var = module.running_var + new_module.num_batches_tracked = module.num_batches_tracked + return new_module + + +wrapped_modules = { + torch.nn.BatchNorm2d: batchnorm2d_reinit, + torch.nn.InstanceNorm2d: instancenorm2d_reinit, +} + + +def convert_to_wrapnn(module: torch.nn.Module): + """Traverse the module and replace the original nn module with its wrapped version + if it is in the `wrapped_modules`. + Currently `wrapped_modules` contains `BatchNorm2d` and `InstanceNorm2d`. + + It is necessary to call this function on user instantiated model before parallelizing + the it, otherwise the modules in `wrapped_modules` cannot be partitioned, but be always + replicated. + + Anyway, it is safe to call this function on the model, even if the model + does not have the modules in `wrapped_modules`. + """ + if type(module) in wrapped_modules: + return wrapped_modules[type(module)](module) + + for name, child in module.named_children(): + module.add_module( + name, convert_to_wrapnn(child) + ) # will inplace replace the module with the same name + return module diff --git a/tests/parallel_module/test_normlayer.py b/tests/parallel_module/test_normlayer.py new file mode 100644 index 00000000..6e590e19 --- /dev/null +++ b/tests/parallel_module/test_normlayer.py @@ -0,0 +1,765 @@ +import uuid +import torch.distributed as dist +import tempfile +import torch +import pytest +import random +import nnscaler +from nnscaler.graph.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.runtime.device import DeviceGroup +from tests.parallel_module.test_gencode import _gencode_contains +from nnscaler.graph.function.wrapnn import convert_to_wrapnn +from nnscaler.parallel import parallelize, ComputeConfig +from torch.nn.parallel import DistributedDataParallel as DDP + +from tests.utils import retry, init_random +from .common import init_distributed +from ..launch_torchrun import launch_torchrun +from torch.distributed.run import elastic_launch, LaunchConfig +from torch.distributed.elastic.multiprocessing.errors import ChildFailedError + + +def policy(graph: IRGraph, resource: ComputeConfig, dim: int) -> IRGraph: + ngpus = resource.plan_ngpus + partitioned = False + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if ( + not partitioned + and node.signature == "nnscaler.graph.function.wrapnn.wrap_batchnorm2d_func" + ): + print("Partitioned node: ", node) + sub_nodes = graph.partition( + node, node.algorithms("dim"), idx=0, dim=dim, num=ngpus + ) + partitioned = True + elif ( + not partitioned + and node.signature + == "nnscaler.graph.function.wrapnn.wrap_instancenorm2d_func" + ): + print("Partitioned node: ", node) + sub_nodes = graph.partition( + node, node.algorithms("dim"), idx=0, dim=0, num=ngpus + ) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + assert partitioned, f"expect instancenorm / batchnorm in graph, but not found." + return graph + + +def compute_error(tensor1, tensor2): + mean_abs_error = torch.abs(tensor1 - tensor2).mean().item() + max_abs_error = torch.abs(tensor1 - tensor2).max().item() + return mean_abs_error, max_abs_error + + +def generate_parallel_data(size, device, dtype): + shared_data = [torch.randn(size, device=device, dtype=dtype) for _ in range(2)] + return shared_data + + +class BatchNorm2dModule(torch.nn.Module): + def __init__(self): + super(BatchNorm2dModule, self).__init__() + self.bn = torch.nn.BatchNorm2d(8) + + def forward(self, x): + return self.bn(x) + + +def _gencode_batchnorm2d_function(tempdir, config, pas_policy): + init_distributed() + m = BatchNorm2dModule().cuda() + m_2d = convert_to_wrapnn(m) + x = torch.randn(8, 8, 32, 32).cuda() + + m_new = parallelize( + m_2d, + {"x": x}, + pas_policy, + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + assert m_new is not None + m_new.train() + output = m_new(x) + assert output is not None + + bn = BatchNorm2dModule().cuda() + bn.train() + ref_output = bn(x) + assert torch.equal( + m_new.bn_running_mean_22, bn.bn.running_mean + ), "Custom output does not match PyTorch output" + + assert torch.equal( + output, ref_output + ), "Custom output does not match PyTorch output" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="lack of GPU devices") +def test_codegen_batchnorm2d_1_1(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 1, _gencode_batchnorm2d_function, tempdir, ComputeConfig(1, 1), "dp" + ) + + +def _gencode_batchnorm2d_function_2(tempdir, config, pas_policy): + nnscaler.init() + rank_id = dist.get_rank() + dtype = torch.bfloat16 + init_random() + device = torch.device(f"cuda:{rank_id}") + + m = BatchNorm2dModule().to(device) + m_2d = convert_to_wrapnn(m) + shared_data = generate_parallel_data((8, 8, 32, 32), device, dtype) + x_part = shared_data[rank_id] + + m_new = parallelize( + m_2d, + {"x": x_part}, + pas_policy, + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + m_new.to(device) + assert m_new is not None + m_new.train() + output = m_new(x_part) + assert output is not None + + gather_output = [torch.empty_like(output) for _ in range(2)] + dist.all_gather(gather_output, output) + y_output = torch.cat(gather_output, dim=0) + + bn = BatchNorm2dModule().to(device) + s_bn = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + bn, process_group=dist.new_group([0, 1]) + ) + s = DDP(s_bn, device_ids=[rank_id]) + s.train() + s_output = s(x_part) + s_gather_output = [torch.empty_like(s_output) for _ in range(2)] + dist.all_gather(s_gather_output, s_output) + sync_output = torch.cat(s_gather_output, dim=0) + + assert torch.equal( + y_output, sync_output + ), "Custom output does not match PyTorch output" + + y = torch.cat(shared_data, dim=0) + model = BatchNorm2dModule().cuda() + model.train() + output = model(y) + current_mean_error, current_max_error = compute_error(output, y_output) + mean_error, max_error = compute_error(sync_output, output) + assert (current_mean_error - mean_error) == 0 and ( + current_max_error - max_error + ) == 0, "Custom output is not the same as PyTorch output error" + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU devices") +def test_codegen_batchnorm2d_1_2(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 2, _gencode_batchnorm2d_function_2, tempdir, ComputeConfig(1, 2), "dp" + ) + + +def _gencode_batchnorm2d_function_4(tempdir, config, pas_policy, dim): + nnscaler.init() + rank_id = dist.get_rank() + dtype = torch.bfloat16 + init_random() + device = torch.device(f"cuda:{rank_id}") + + m = BatchNorm2dModule().to(device) + m_2d = convert_to_wrapnn(m) + + x_list = generate_parallel_data((8, 8, 32, 32), device, dtype) + x = x_list[rank_id // 2] + + m_new = parallelize( + m_2d, + {"x": x}, + lambda graph, resource: pas_policy(graph, resource, dim), + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + m_new.to(device) + assert m_new is not None + m_new.train() + output = m_new(x) + assert output is not None + + gather_output = [torch.empty_like(output) for _ in range(4)] + dist.all_gather(gather_output, output) + y_output = torch.cat([gather_output[0], gather_output[2]], dim=0) + + y = torch.cat([x_list[0], x_list[1]], dim=0) + bn = BatchNorm2dModule().cuda() + bn.train() + ref_output = bn(y) + current_mean_error, current_max_error = compute_error(y_output, ref_output) + assert ( + current_mean_error + ) < 1e-6, "Custom output is not the same as PyTorch output error" + + x = torch.chunk(x_list[rank_id // 2], 2, dim=0)[rank_id % 2] + + bn = BatchNorm2dModule().to(device) + s_bn = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + bn, process_group=dist.new_group([0, 1, 2, 3]) + ) + s = DDP(s_bn, device_ids=[rank_id]) + s.train() + s_output = s(x) + s_gather_output = [torch.empty_like(s_output) for _ in range(4)] + dist.all_gather(s_gather_output, s_output) + sync_output = torch.cat([s_gather_output[0], s_gather_output[1]], dim=0) + sync_output_all = torch.cat( + [ + s_gather_output[0], + s_gather_output[1], + s_gather_output[2], + s_gather_output[3], + ], + dim=0, + ) + + assert torch.equal( + gather_output[0], sync_output + ), "Custom output does not match PyTorch SyncBatchNorm output" + assert torch.equal( + y_output, sync_output_all + ), "Custom output does not match PyTorch SyncBatchNorm output" + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPU devices") +@pytest.mark.parametrize("dim", [0, 1]) +def test_codegen_batchnorm2d_2_4(dim): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 4, + _gencode_batchnorm2d_function_4, + tempdir, + ComputeConfig(2, 4), + policy, + dim, + ) + + +def _gencode_batchnorm2d_function_eval(tempdir, config, pas_policy): + init_distributed() + m = BatchNorm2dModule().cuda() + x = torch.randn(8, 8, 32, 32).cuda() + m = convert_to_wrapnn(m) + m_new = parallelize( + m, + {"x": x}, + pas_policy, + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + assert m_new is not None + m_new.eval() + output = m_new(x) + assert output is not None + bn = BatchNorm2dModule().cuda() + bn.eval() + ref_output = bn(x) + assert torch.equal( + m_new.bn_running_mean_22, bn.bn.running_mean + ), "Custom output does not match PyTorch output" + assert torch.equal( + output, ref_output + ), "Custom output does not match PyTorch output" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="lack of GPU devices") +def test_codegen_batchnorm2d_eval_1_1(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 1, _gencode_batchnorm2d_function_eval, tempdir, ComputeConfig(1, 1), "dp" + ) + + +def _gencode_batchnorm2d_function_eval_2(tempdir, config, pas_policy): + nnscaler.init() + rank_id = dist.get_rank() + dtype = torch.bfloat16 + init_random() + device = torch.device(f"cuda:{rank_id}") + + m = BatchNorm2dModule().to(device) + m_2d = convert_to_wrapnn(m) + shared_data = generate_parallel_data((4, 8, 32, 32), device, dtype) + x_part = shared_data[rank_id] + + m_new = parallelize( + m_2d, + {"x": x_part}, + pas_policy, + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + m_new.to(device) + assert m_new is not None + m_new.eval() + output = m_new(x_part) + assert output is not None + + gather_output = [torch.empty_like(output) for _ in range(2)] + dist.all_gather(gather_output, output) + y_output = torch.cat(gather_output, dim=0) + + bn = BatchNorm2dModule().to(device) + s_bn = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + bn, process_group=dist.new_group([0, 1]) + ) + + s = DDP(s_bn, device_ids=[rank_id]) + s.eval() + s_output = s(x_part) + s_gather_output = [torch.empty_like(s_output) for _ in range(2)] + dist.all_gather(s_gather_output, s_output) + sync_output = torch.cat(s_gather_output, dim=0) + + assert torch.equal( + y_output, sync_output + ), "Custom output does not match PyTorch output" + + y = torch.cat(shared_data, dim=0) + model = BatchNorm2dModule().cuda() + model.eval() + output = model(y) + current_mean_error, current_max_error = compute_error(output, y_output) + ref_mean_error, ref_max_error = compute_error(sync_output, output) + assert ( + abs(current_mean_error - ref_mean_error) == 0 + and abs(current_max_error - ref_max_error) == 0 + ), "Custom output is not the same as PyTorch output error" + assert torch.allclose( + output, y_output, atol=1e-6 + ), "Custom output does not match PyTorch output" + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU devices") +def test_codegen_batchnorm2d_eval_1_2(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 2, _gencode_batchnorm2d_function_eval_2, tempdir, ComputeConfig(1, 2), "dp" + ) + + +def _gencode_batchnorm2d_function_eval_4(tempdir, config, pas_policy, dim): + nnscaler.init() + rank_id = dist.get_rank() + world_size = dist.get_world_size() + dtype = torch.bfloat16 + init_random() + device = torch.device(f"cuda:{rank_id}") + + m = BatchNorm2dModule().to(device) + m_2d = convert_to_wrapnn(m) + + x_list = generate_parallel_data((8, 8, 32, 32), device, dtype) + x = x_list[rank_id // 2] + + m_new = parallelize( + m_2d, + {"x": x}, + lambda graph, resource: pas_policy(graph, resource, dim), + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + m_new.to(device) + assert m_new is not None + m_new.eval() + output = m_new(x) + assert output is not None + + gather_output = [torch.empty_like(output) for _ in range(4)] + dist.all_gather(gather_output, output) + y_output = torch.cat([gather_output[0], gather_output[2]], dim=0) + + y = torch.cat([x_list[0], x_list[1]], dim=0) + bn = BatchNorm2dModule().cuda() + bn.eval() + ref_output = bn(y) + current_mean_error, current_max_error = compute_error(y_output, ref_output) + assert ( + current_mean_error + ) < 1e-6, "Custom output is not the same as PyTorch output error" + assert torch.allclose( + y_output, ref_output, 1e-6 + ), "Custom output does not match PyTorch output" + + x = torch.chunk(x_list[rank_id // 2], 2, dim=0)[rank_id % 2] + + bn = BatchNorm2dModule().to(device) + s_bn = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + bn, process_group=dist.new_group([0, 1, 2, 3]) + ) + s = DDP(s_bn, device_ids=[rank_id]) + s.eval() + s_output = s(x) + s_gather_output = [torch.empty_like(s_output) for _ in range(4)] + dist.all_gather(s_gather_output, s_output) + sync_output = torch.cat([s_gather_output[0], s_gather_output[1]], dim=0) + sync_output_all = torch.cat( + [ + s_gather_output[0], + s_gather_output[1], + s_gather_output[2], + s_gather_output[3], + ], + dim=0, + ) + + assert torch.equal( + gather_output[0], sync_output + ), "Custom output does not match PyTorch SyncBatchNorm output" + assert torch.equal( + y_output, sync_output_all + ), "Custom output does not match PyTorch SyncBatchNorm output" + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPU devices") +@pytest.mark.parametrize("dim", [0, 1]) +def test_codegen_batchnorm2d_eval_2_4(dim): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 4, + _gencode_batchnorm2d_function_eval_4, + tempdir, + ComputeConfig(2, 4), + policy, + dim, + ) + + +class InstanceNorm2dModule(torch.nn.Module): + def __init__(self): + super(InstanceNorm2dModule, self).__init__() + self.inorm = torch.nn.InstanceNorm2d(4) + self.inorm.running_mean = torch.zeros(4) + self.inorm.running_var = torch.ones(4) + + def forward(self, x): + return self.inorm(x) + + +def _gencode_instancenorm2d_function(tempdir, config, pas_policy): + init_distributed() + m = InstanceNorm2dModule().cuda() + m = convert_to_wrapnn(m) + m_new = parallelize( + m, + {"x": torch.randn(4, 4, 32, 32).cuda()}, + lambda graph, resource: pas_policy(graph, resource, dim=0), + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + assert m_new is not None + x = torch.randn(4, 4, 32, 32).cuda() + m_new.train() + output = m_new(x) + assert output is not None + bn = torch.nn.InstanceNorm2d(4).cuda() + bn.running_mean = torch.zeros(4) + bn.running_var = torch.ones(4) + bn.train() + ref_output = bn(x) + assert torch.equal( + output, ref_output + ), "Custom output does not match PyTorch output in training mode" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="lack of GPU devices") +def test_codegen_instancenorm2d_1_1(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 1, _gencode_instancenorm2d_function, tempdir, ComputeConfig(1, 1), policy + ) + + +def _gencode_instancenorm2d_function_2(tempdir, config, pas_policy): + init_distributed() + rank_id = dist.get_rank() + dtype = torch.bfloat16 + init_random() + device = torch.device(f"cuda:{rank_id}") + m = InstanceNorm2dModule().cuda() + m = convert_to_wrapnn(m) + + shared_data = generate_parallel_data((2, 4, 32, 32), device, dtype) + x_part = shared_data[rank_id] + + m_new = parallelize( + m, + {"x": x_part}, + lambda graph, resource: pas_policy(graph, resource, dim=0), + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + assert m_new is not None + m_new.train() + output = m_new(x_part) + assert output is not None + + gather_output = [torch.empty_like(output) for _ in range(2)] + dist.all_gather(gather_output, output) + y_output = torch.cat(gather_output, dim=0) + + bn = torch.nn.InstanceNorm2d(4).to(device) + bn.running_mean = torch.zeros(4, device=device) + bn.running_var = torch.ones(4, device=device) + y = torch.cat(shared_data, dim=0) + bn.train() + ref_output = bn(y) + current_mean_error, current_max_error = compute_error(y_output, ref_output) + assert ( + abs(current_mean_error) < 1e-6 + ), "Custom output is not the same as PyTorch output error" + assert torch.allclose( + y_output, ref_output, atol=1e-6 + ), "Custom output does not match PyTorch output" + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU devices") +def test_codegen_instancenorm2d_1_2(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 2, _gencode_instancenorm2d_function_2, tempdir, ComputeConfig(1, 2), policy + ) + + +def _gencode_instancenorm2d_function_4(tempdir, config, pas_policy): + init_distributed() + rank_id = dist.get_rank() + dtype = torch.bfloat16 + init_random() + device = torch.device(f"cuda:{rank_id}") + m = InstanceNorm2dModule().cuda() + m = convert_to_wrapnn(m) + + x_list = generate_parallel_data((2, 4, 32, 32), device, dtype) + x = x_list[rank_id // 2] + + m_new = parallelize( + m, + {"x": x}, + lambda graph, resource: pas_policy(graph, resource, dim=0), + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + assert m_new is not None + m_new.train() + output = m_new(x) + assert output is not None + + gather_output = [torch.empty_like(output) for _ in range(4)] + dist.all_gather(gather_output, output) + y_output = torch.cat([gather_output[0], gather_output[2]], dim=0) + + bn = torch.nn.InstanceNorm2d(4).to(device) + bn.running_mean = torch.zeros(4, device=device) + bn.running_var = torch.ones(4, device=device) + y = torch.cat([x_list[0], x_list[1]], dim=0) + bn.train() + ref_output = bn(y) + current_mean_error, current_max_error = compute_error(y_output, ref_output) + assert ( + abs(current_mean_error) < 1e-6 + ), "Custom output is not the same as PyTorch output error" + assert torch.allclose( + y_output, ref_output, atol=1e-6 + ), "Custom output does not match PyTorch output" + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPU devices") +def test_codegen_instancenorm2d_2_4(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 4, _gencode_instancenorm2d_function_4, tempdir, ComputeConfig(2, 4), policy + ) + + +def _gencode_instancenorm2d_function_eval(tempdir, config, pas_policy): + init_distributed() + m = InstanceNorm2dModule().cuda() + m = convert_to_wrapnn(m) + m.eval() + m_new = parallelize( + m, + {"x": torch.randn(4, 4, 32, 32).cuda()}, + lambda graph, resource: pas_policy(graph, resource, dim=0), + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + assert m_new is not None + x = torch.randn(4, 4, 32, 32).cuda() + output = m_new(x) + assert output is not None + + bn = torch.nn.InstanceNorm2d(4).cuda() + bn.running_mean = torch.zeros(4) + bn.running_var = torch.ones(4) + bn.eval() + ref_output = bn(x) + assert torch.equal( + output, ref_output + ), "Custom output does not match PyTorch output in evaluation mode" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="lack of GPU devices") +def test_codegen_instancenorm2d_1_1_eval(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 1, + _gencode_instancenorm2d_function_eval, + tempdir, + ComputeConfig(1, 1), + policy, + ) + + +def _gencode_instancenorm2d_function_eval_2(tempdir, config, pas_policy): + init_distributed() + rank_id = dist.get_rank() + dtype = torch.bfloat16 + init_random() + device = torch.device(f"cuda:{rank_id}") + m = InstanceNorm2dModule().cuda() + m = convert_to_wrapnn(m) + + shared_data = generate_parallel_data((2, 4, 32, 32), device, dtype) + x_part = shared_data[rank_id] + + m_new = parallelize( + m, + {"x": x_part}, + lambda graph, resource: pas_policy(graph, resource, dim=0), + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + assert m_new is not None + m_new.eval() + output = m_new(x_part) + assert output is not None + + gather_output = [torch.empty_like(output) for _ in range(2)] + dist.all_gather(gather_output, output) + y_output = torch.cat(gather_output, dim=0) + + bn = torch.nn.InstanceNorm2d(4).to(device) + bn.running_mean = torch.zeros(4, device=device) + bn.running_var = torch.ones(4, device=device) + y = torch.cat(shared_data, dim=0) + bn.eval() + ref_output = bn(y) + current_mean_error, current_max_error = compute_error(y_output, ref_output) + assert ( + abs(current_mean_error) < 1e-6 + ), "Custom output is not the same as PyTorch output error" + assert torch.allclose( + y_output, ref_output, atol=1e-6 + ), "Custom output does not match PyTorch output" + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU devices") +def test_codegen_instancenorm2d_1_2_eval(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 2, + _gencode_instancenorm2d_function_eval_2, + tempdir, + ComputeConfig(1, 2), + policy, + ) + + +def _gencode_instancenorm2d_function_eval_4(tempdir, config, pas_policy): + init_distributed() + rank_id = dist.get_rank() + dtype = torch.bfloat16 + init_random() + device = torch.device(f"cuda:{rank_id}") + m = InstanceNorm2dModule().cuda() + m = convert_to_wrapnn(m) + + x_list = generate_parallel_data((2, 4, 32, 32), device, dtype) + x = x_list[rank_id // 2] + + m_new = parallelize( + m, + {"x": x}, + lambda graph, resource: pas_policy(graph, resource, dim=0), + config, + gen_savedir=tempdir, + load_module=True, + reuse="override", + ) + assert m_new is not None + m_new.eval() + output = m_new(x) + assert output is not None + + gather_output = [torch.empty_like(output) for _ in range(4)] + dist.all_gather(gather_output, output) + y_output = torch.cat([gather_output[0], gather_output[2]], dim=0) + + bn = torch.nn.InstanceNorm2d(4).to(device) + bn.running_mean = torch.zeros(4, device=device) + bn.running_var = torch.ones(4, device=device) + y = torch.cat([x_list[0], x_list[1]], dim=0) + bn.eval() + ref_output = bn(y) + current_mean_error, current_max_error = compute_error(y_output, ref_output) + assert ( + abs(current_mean_error) < 1e-6 + ), "Custom output is not the same as PyTorch output error" + assert torch.allclose( + y_output, ref_output, atol=1e-6 + ), "Custom output does not match PyTorch output" + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPU devices") +def test_codegen_instancenorm2d_2_4_eval(): + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun( + 4, + _gencode_instancenorm2d_function_eval_4, + tempdir, + ComputeConfig(2, 4), + policy, + ) From fefb87a8f51f2850aa4f20c8c82310711676ee0a Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 16 Aug 2024 05:59:17 +0000 Subject: [PATCH 1706/1892] Merged PR 2234: minitrainer: remove torchrun requirements for compile minitrainer: remove torchrun requirements for compile --- docs/source/trainer.md | 6 +++++- nnscaler/cli/trainer.py | 33 ++++++++++++++++++++++++++++++--- nnscaler/cli/trainer_args.py | 9 +++++++-- requirements.txt | 1 + tests/cli/test_trainer.py | 22 ++++++++++++++++++++++ 5 files changed, 65 insertions(+), 6 deletions(-) diff --git a/docs/source/trainer.md b/docs/source/trainer.md index 4c67d2af..5b13ece8 100644 --- a/docs/source/trainer.md +++ b/docs/source/trainer.md @@ -388,7 +388,11 @@ You can pass builtin pas policy name or your own pas policy function. See `parallelize` API for more information. - `broadcast_strategy` (`str`): The strategy of broadcasting the model. Default is `all`. See `parallelize` API for more information. - `instance_name` (`str`): The instance name of the trainer. Default is `None`. See `parallelize` API for more information. -- `run_mode` (`str`): The run mode of the trainer. It can be `run` (compile and train the model) and `compile` (only compile the model to generate code). Default is `run`. +- `run_mode` (`str`): The run mode of the trainer. +It can be `run` (compile and train the model in a single python script OR train from previous compiling results) and `compile` (only compile the model for code generation). Default is `run`. +Please note you can only use `run` mode with `torchrun`. +On the other hand, if you disable broadcasting generated files (by setting `broadcast_strategy` to `none`), +you can run `compile` mode without `torchrun`. - `tracing_from_weights` (`str`): The path to the weights to be loaded when tracing(compiling) the model. It is only used in tracing to serve as the initial state dict of the model. Default is `None`. - `precison`(`Union[str, Dict[_TENSOR_TYPE, _PRECISION_TYPE], None]`): The precision of the model. It can be a `str`, which means the same precision for all tensors, or a `Dict[_TENSOR_TYPE, _PRECISION_TYPE]`, which means the precision for each tensor type. Default is `None`. Currently we support 3 tensor types (`param`, `buffer`, `input`) and three precisions (`fp32`, `fp16`, `bf16`). You can set precision to `none` to avoid any precision conversion. - `micro_batch_size` (`int`): The micro batch size. Default is `1`. diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 3c90eb40..1342def9 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -11,6 +11,7 @@ import torch import torch.distributed from torch.utils.data import DataLoader +import psutil from tqdm import tqdm @@ -213,7 +214,7 @@ def _create_model(): gen_savedir=self.train_args.gen_savedir, reuse=self.train_args.gen_reuse, instance_name=self.train_args.instance_name, - broadcast_strategy='all', + broadcast_strategy=self.train_args.broadcast_strategy, load_module=not compile_only, ) if compile_only: @@ -228,7 +229,11 @@ def _create_model(): self.model = pmodel_class() self.model.cuda() self.optimizer = self.train_args.create_parallel_optimizer(self.model) - # the reduce op is `sum` by default, follow torch's c10d, grad is divided by scaling_factor before allreduce + # Here we carefully scale down the gradient locally with 1/scale_factor before reduce, + # (the reduce op is `sum` by default, follow torch's c10d, grad is divided by scaling_factor before allreduce) + # and scale up the gradient after reduce + # (see `train_args.optimizer.grad_reduction`` handling in `train_epoch`). + # This is useful to avoid overflow when the gradients are large. def reducer_pre_hook(reducer, grad): grad.div_(self.train_args.scaling_factor) self.optimizer.register_reducer_pre_hook(reducer_pre_hook) @@ -276,7 +281,8 @@ def _log_finalize(self): for logger in self.loggers: logger.finalize() - def log_metrics(self, metrics: Dict[str, float], step: int, *, tag: Optional[str] = None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None, *, tag: Optional[str] = None): + step = step or self.num_train_steps for logger in self.loggers: logger.log_metrics(metrics, step, tag=tag) @@ -322,6 +328,22 @@ def _load_checkpoint(self): self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) self.train_status = TrainStatus(**state_dict['train_status']) + def _log_mem_stats(self, tag=None): + # log minimum free memory over the iteration + cuda_free, _ = torch.cuda.mem_get_info() + cuda_gb_free = cuda_free / 1024 / 1024 / 1024 + cuda_gb_allocated = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024 + ram_gb_used = psutil.virtual_memory().used / 1024 / 1024 / 1024 + torch.cuda.reset_peak_memory_stats() + + self.log_metrics({ + 'cuda_gb_allocated': cuda_gb_allocated, + 'cuda_gb_reserved': cuda_gb_reserved, + 'cuda_gb_free': cuda_gb_free, + 'ram_gb_used': ram_gb_used, + }, tag=tag) + def _save_checkpoint(self, loss): checkpoint_config = self.train_args.checkpoint @@ -486,6 +508,10 @@ def train(self): f"next_batch_index({self.train_status.next_batch_index}) " \ f"should not be larger than total_train_steps_per_epoch ({self.total_train_steps_per_epoch})" + # reset peak memory stats before training + # So that we can get accurate peak memory usage for each step + torch.cuda.reset_peak_memory_stats() + if self.train_status.next_batch_index == self.total_train_steps_per_epoch: self.train_status.epoch += 1 self.train_status.next_batch_index = 0 @@ -675,6 +701,7 @@ def train_epoch(self, epoch): if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'step': self.lr_scheduler.step() + self._log_mem_stats(tag='train') self.log_metrics( {k:v for k, v in asdict(step_stat).items() if v is not None}, self.num_train_steps, diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index bef0bc1c..eb4de2f6 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -5,6 +5,7 @@ from pathlib import Path import logging import copy +import os import torch import torch.utils @@ -14,7 +15,7 @@ import torch from nnscaler.utils import transform_recursively -from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, _PREDEFINED_POLICIES +from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule from .arg_parser import deserialize_dataclass, merge_args, parse_args, _TYPE_KEY, _VALUE_TYPE_KEY, _VALUE_KEY @@ -350,6 +351,9 @@ def __post_init__(self): else: self.gen_reuse = 'moo' if self.run_mode == 'compile' else 'match' + if self.broadcast_strategy not in [e.value for e in BroadcastGenFilesStrategy]: + raise ValueError(f"Invalid broadcast_strategy {self.broadcast_strategy}") + supported_precision_type = get_args(_PRECISION_TYPE) supported_tensor_type = get_args(_TENSOR_TYPE) if not self.precision: @@ -510,7 +514,8 @@ def create_sampler(self, dataset, stage='train'): kwargs = self.create_kwarg(sampler_args) kwargs['dataset'] = dataset kwargs['num_replicas'] = self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus - kwargs['rank'] = torch.distributed.get_rank() // self.compute_config.plan_ngpus + # if not distributed, we use the rank 0 sampler + kwargs['rank'] = int(os.environ.get('RANK', 0)) // self.compute_config.plan_ngpus sampler_class = load_type(self.dataset_sampler.type) return sampler_class(**kwargs) diff --git a/requirements.txt b/requirements.txt index f8e85677..a96ac7c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ importlib-resources matplotlib more-itertools numpy>=1.23.0 +psutil pulp pybind11 pyyaml diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 61365777..bd8036d5 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -8,6 +8,7 @@ from nnscaler.cli.trainer import Trainer from nnscaler.cli.trainer_args import AggregatedOutputs from tests.parallel_module.common import assert_equal +from tests.utils import replace_all_device_with from ..launch_torchrun import launch_torchrun @@ -59,6 +60,27 @@ def test_trainer_logging(tmp_path): launch_torchrun(4, trainer_logging_worker, tmp_path) +@replace_all_device_with('cpu') +def test_trainer_compile_worker(tmp_path): + save_dir = Path(tmp_path) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + # compile only + Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + ]) + + assert set([f.name for f in gen_savedir.glob('**/*.py')]) == set(['gencode0.py', 'gencode1.py', 'gencode2.py', 'gencode3.py']) + shutil.rmtree(gen_savedir) + + def trainer_resume_worker(save_dir, save_type, bf16): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) From 8fec4b8702cc0bfcf40a8882924a843c41e81e43 Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Fri, 16 Aug 2024 08:31:54 +0000 Subject: [PATCH 1707/1892] Merged PR 2227: Log importance ratio for operators to make it easy to debug and optimize for autodist output each operator's importance ratio (percentages of states that can be reduced by forcing the operator to be partitioned in a single partition) an example output is ```text operator FwOp7-()(name=embedding, inputs=(t1768(p20,(1, 8192),d(),v(0/1)), w1770(p22,(32256, 4096),d(),v(0/1))), outputs=(t1771(p24,(1, 8192, 4096),d(),v(0/1)),)) has 4 partitions, importance ratio 0.225 at File "/home/yizhu1/ts_dev/Fairseq/nnscaler_examples/finetune_hf_model/src/model_helper/customize/modeling_nnscaler_mixtral_4_42.py", line 1047, in forward, inputs_embeds = self.embed_tokens(input_ids) operator FwOp22-()(name=transpose, inputs=(t1786(p66,(1, 8192, 8, 128),d(),v(0/1)),), outputs=(t1787(p68,(1, 8, 8192, 128),d(),v(0/1)),)) has 3 partitions, importance ratio 0.159 at File "/home/yizhu1/ts_dev/Fairseq/nnscaler_examples/finetune_hf_model/src/model_helper/customize/modeling_nnscaler_mixtral_4_42.py", line 358, in forward, value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) operator FwOp235-()(name=add, inputs=(t2029(p708,(1, 8192, 4096),d(),v(0/1)), t2063(p783,(1, 8192, 4096),d(),v(0/1))), outputs=(t2064(p785,(1, 8192, 4096),d(),v(0/1)),)) has 3 partitions, importance ratio 0.150 at File "/home/yizhu1/ts_dev/Fairseq/nnscaler_examples/finetune_hf_model/src/model_helper/customize/modeling_nnscaler_mixtral_4_42.py", line 831, in forward, hidden_states = residual + hidden_states operator FwOp160-()(name=add, inputs=(t1932(p457,(1, 8192, 4096),d(),v(0/1)), t1966(p532,(1, 8192, 4096),d(),v(0/1))), outputs=(t1967(p534,(1, 8192, 4096),d(),v(0/1)),)) has 3 partitions, importance ratio 0.150 at File "/home/yizhu1/ts_dev/Fairseq/nnscaler_examples/finetune_hf_model/src/model_helper/customize/modeling_nnscaler_mixtral_4_42.py", line 831, in forward, hidden_states = residual + hidden_states operator FwOp85-()(name=add, inputs=(t1835(p206,(1, 8192, 4096),d(),v(0/1)), t1869(p281,(1, 8192, 4096),d(),v(0/1))), outputs=(t1870(p283,(1, 8192, 4096),d(),v(0/1)),)) has 3 partitions, importance ratio 0.150 at File "/home/yizhu1/ts_dev/Fairseq/nnscaler_examples/finetune_hf_model/src/model_helper/customize/modeling_nnscaler_mixtral_4_42.py", line 831, in forward, hidden_states = residual + hidden_states ``` which means that constrain the partition space of embedding and residual add can reduce a large search space parity check & unit test passed --- nnscaler/autodist/spmd_solver.py | 113 ++++++++++++++++++++++--------- 1 file changed, 80 insertions(+), 33 deletions(-) diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index ea0caf5a..75503fdd 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -89,6 +89,61 @@ def __repr__(self): class SPMDSolver: + """ + Assume the dataflow graph is + a + / \ + b c + | / \ + d e f + | / | + g h + and operators are stored in a topological order [a, b, d, c, e, f, g, h]. + + In SPMD solver, dynamic programming is used to find the optimal partition plan where + dp[p(u), M] is the optimal plan for the subgraph ending with u in partition state p, + with memory bound M. if v is the predecessor of u in the topological order, then + dp[p(u), M] = min(dp[q(v), M - mem(p(u))] + comm_cost(p(u), q(v))) + comp_cost(p(u)) + comm_cost(p(u)) + where q(v) is the partition state of v, mem(p(u)) is the memory cost of p(u), + comm_cost(p(u), q(v)) is the communication cost between p(u) and q(v), and + comp_cost(p(u)) is the computation cost of p(u), comm_cost(p(u)) is the communication + cost of p(u) (like the allreduce cost in model update). + + However, u and v may be disconnected in the dataflow graph, like [d, c], [e, f], [f, g] + and [g, h] in the example above. To calculate the communication cost between p(u) and q(v), + we need to store additional information in the partition state. For example, we need to maintain + the partition state of node d in the partition state of node c, so that we can calculate the + communication cost when reaching node g. + + To achieve this, we calcuate the `cut_ops` for each node, which is the set of nodes that + need to be maintained in the partition state of the current node. The cut ops for the example + above are: + a: [a] + b: [a, b] + d: [a, d] + c: [d, c] + e: [d, c, e] + f: [d, e, f] + g: [f, g] + h: [h] + + To be more specific, the `cut_ops` is calculated by the following steps: + 1. calculate the `out_degs` for each node, which is the number of consumers of the node + 2. traverse the nodes in the topological order + - decrease each producer's `out_degs` by 1, if the `out_degs` is 0, remove the producer + from the `unclosed_idx` and set current node's idx (time) as the producer's 'close_time' + - set current node's `cut_ops` as the union of `unclosed_idx` and the node itself + - add the node to `unclosed_idx` if its #consumers > 0 + + However, in real-world scenarios, certain positions might have a large number of `cut_ops`, + and each op may have more than one partitioning strategy (for example, when the input data flow graph + is a complete graph). In such cases, the search space becomes very large, making it impossible to solve + within limited time and space. To help users reduce the search space, we calculate a metric called + `importance_ratio` for each op, which describes the percentage reduction in search space if the partitioning + strategy for that op is restricted to just one. + Thus, users can view the top 10 ops with the highest `importance ratio` output by autodist and use the + `partition constraint` interface to restrict the partitioning space of these ops, thereby speeding up the search process. + """ def __init__( self, @@ -119,39 +174,6 @@ def __init__( self.cost_database.profile_comp(self.device_num) self.stage_num = stage_num - # assume the dataflow graph is - # a - # / \ - # b c - # | / \ - # d e f - # | / | - # g h - # the ops are stored in a topological order [a, b, d, c, e, f, g, h] - # in spmd solver, dynamic programming is used to find the optimal partition plan - # dp[p(u), M] is the optimal plan for the subgraph ending with u in partition state p, - # with memory bound M. if v is the predecessor of u in the topological order, then - # dp[p(u), M] = min(dp[q(v), M - mem(p(u))] + comm_cost(p(u), q(v))) + comp_cost(p(u)) + comm_cost(p(u)) - # where q(v) is the partition state of v, mem(p(u)) is the memory cost of p(u), - # comm_cost(p(u), q(v)) is the communication cost between p(u) and q(v), and - # comp_cost(p(u)) is the computation cost of p(u), comm_cost(p(u)) is the communication - # cost of p(u) (like the allreduce cost in model update). - # However, u and v may not be connected in the dataflow graph, like [d, c], [e, f], [f, g] - # and [g, h] in the example above. To calculate the communication cost between p(u) and q(v), - # we need to store additional information in the partition state. For example, we need to maintain - # the partition state of node d in the partition state of node c, so that we can calculate the - # communication cost when reaching node g. - # to achieve this, we calcuate the 'cut ops' for each node, which is the set of nodes that - # need to be maintained in the partition state of the current node. The cut ops for the example - # above are: - # a: [a] - # b: [a, b] - # d: [a, d] - # c: [d, c] - # e: [d, c, e] - # f: [d, e, f] - # g: [f, g] - # h: [h] self.initialize() def initialize(self): @@ -686,6 +708,7 @@ def calc_partition_cost(self, op_idx: int, partition_idx: int): def calc_partition_info(self): self.partition_info: List[List[PartitionCostDesc]] = list() state_num = 0 + prefix_state_sums: List[int] = [0] * self.graph.op_num for i in range(self.graph.op_num): cur_info = [] _logger.debug(f'calc partition info for {self.get_operator(i)}') @@ -702,8 +725,30 @@ def calc_partition_info(self): cut_partition_cnts = [self.get_op_partition_count(idx) for idx in self.cut_ops[i]] cur_state_num = functools.reduce(lambda x, y: x * y, cut_partition_cnts, 1) state_num += cur_state_num + prefix_state_sums[i] = state_num _logger.debug(f'{i}-th operator follow {self.get_father_id(i)} with cut ops {self.cut_ops[i]}, {cut_partition_cnts}, {cur_state_num}') _logger.info(f'total state num is {state_num}') + if state_num > 1024 * 1024: + _logger.warning(f'too many states, please consider to add constraints to partition spaces of following operators') + importance_ratios: List[Tuple[float, int]] = list() + desc_str = 'output each operator\'s importance ratio (percentages of states that can be reduced by forcing the operator to be partitioned in a single partition)\n' + for i in range(self.graph.op_num): + partition_cnt = self.get_op_partition_count(i) + if partition_cnt == 1: + continue + if i not in self.close_times: + continue + related_state_num = prefix_state_sums[self.close_times[i] - 1] + if i > 0: + related_state_num -= prefix_state_sums[i - 1] + ratio = related_state_num / partition_cnt / state_num * (partition_cnt - 1) + importance_ratios.append((ratio, i)) + importance_ratios.sort(reverse=True) + for idx in range(min(10, len(importance_ratios))): + ratio, i = importance_ratios[idx] + node = self.get_operator(i).ir_cell + desc_str += f'operator {node} has {self.get_op_partition_count(i)} partitions, importance ratio {ratio:.3f}\nat {node.comment}\n\n' + _logger.info(desc_str) _logger.info('finish spmd solver initializetion') def estimate_min_mem(self, start: int, end: int) -> int: @@ -868,6 +913,7 @@ def build_cut_ops(self): out_degs = [len(op.consumers) for op in self.graph.operator_list] unclosed_idx = set() self.cut_ops: List[List[int]] = list() + self.close_times: Dict[int, int] = dict() for i, op in enumerate(self.graph.operator_list): for pred in op.producers: pred_idx = cid2idx[pred.ir_cell.cid] @@ -875,6 +921,7 @@ def build_cut_ops(self): out_degs[pred_idx] -= 1 if out_degs[pred_idx] == 0: unclosed_idx.remove(pred_idx) + self.close_times[pred_idx] = i ret = list(unclosed_idx) + [i] ret.sort() self.cut_ops.append(ret) From e6d19e957f266d5de1db571dccb7e2e20b0af1e7 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 16 Aug 2024 08:39:13 +0000 Subject: [PATCH 1708/1892] Merged PR 2226: add nested output support add nested output support --- nnscaler/codegen/emit.py | 50 ++++++- nnscaler/codegen/lifecycle.py | 4 +- nnscaler/graph/function/dimops.py | 2 +- nnscaler/graph/graph.py | 5 +- nnscaler/graph/segment.py | 6 +- nnscaler/ir/cten.py | 46 ++++++- nnscaler/ir/operator.py | 19 ++- nnscaler/utils.py | 17 ++- tests/compiler/test_model.py | 208 ++++++++++++++++++++++++++++++ tests/test_utils.py | 10 ++ 10 files changed, 331 insertions(+), 36 deletions(-) create mode 100644 tests/compiler/test_model.py create mode 100644 tests/test_utils.py diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index fec27406..d2060a22 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -223,12 +223,52 @@ def emit_fnode(self, node: IRFwOperation, runtime_devid: int, plan_ndevs: int, r body = emit_rule(node, inputs, kwargs, runtime_devid, plan_ndevs, runtime_ndevs) if len(node.outputs()) == 0: - code = body + codes.append(body) else: - outputs = [self.tensor_name(t) for t in node.outputs()] - outputs = ', '.join(outputs) - code = f'{outputs} = {body}' - codes.append(code) + irobj_path = {} + def r(t, current_path): + if isinstance(t, IRObject): + irobj_path[t] = current_path + elif isinstance(t, (list, tuple)): + for i, v in enumerate(t): + r(v, current_path + [i]) + elif isinstance(t, dict): + for k, v in t.items(): + r(v, current_path + [k]) + else: + # do nothing + pass + r(node.outputs(), []) + if all(len(x) == 1 for x in irobj_path.values()): + # if all IRObjects are leafs, we can directly assign the output + outputs = [self.tensor_name(t) for t in node.outputs()] + outputs = ', '.join(outputs) + codes.append(f'{outputs} = {body}') + else: + outputs = [] + im_outputs = [] + for t in node.outputs(): + if isinstance(t, IRObject): + outputs.append(self.tensor_name(t)) + else: + # new intermediate output + im_ouptut = self.tensor_name(IRObject('im_output')) + im_outputs.append(im_ouptut) + outputs.append(im_ouptut) + codes.append(f'{", ".join(outputs)} = {body}') + + for t, path in irobj_path.items(): + if len(path) == 1: # immediate output, skip + continue + out = outputs[path[0]] + for p in path[1:]: + out = f'{out}[{repr(p)}]' # extract step by step + codes.append(f'{self.tensor_name(t)} = {out}') + # release intermediate outputs + # because they are not used in the future, and don't managed by lifecycle + for im_output in im_outputs: + codes.append(f'del {im_output}') + return codes def emit_adapter(self, node: IRAdapter, prefix_attr: Optional[str] = None, diff --git a/nnscaler/codegen/lifecycle.py b/nnscaler/codegen/lifecycle.py index 9aff2052..cd1be218 100644 --- a/nnscaler/codegen/lifecycle.py +++ b/nnscaler/codegen/lifecycle.py @@ -54,10 +54,10 @@ def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: # aggressively mark all outputs for immediate deletion, # namely *after* 'i'-th statement, in case it's never used. - self.lifetime.update((tout, i) for tout in outputs if is_activation(tout)) + self.lifetime.update((tout, i) for tout in IRSegment.get_objects_from_complex(outputs) if is_activation(tout)) # "fast-forward" all inputs to the current statement, namely after 'i'-th node. - self.lifetime.update((tin, i) for tin in inputs if is_activation(tin)) + self.lifetime.update((tin, i) for tin in IRSegment.get_objects_from_complex(inputs) if is_activation(tin)) # Here (i+1) is always greater than 'len(nodes)' diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index 9904bcae..513f39ee 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -716,7 +716,7 @@ def infer_shape(self) -> bool: for oidx, otensor in enumerate(self.outputs()): shape_anno = self.oanno(oidx) if shape_anno.ignore: - assert isinstance(otensor, IRObject), f"expect IRObject for unknown shape, get {otensor}" + # otensor can be any type, including IRObject, collection types (list, dict, etc.) continue shape = [] for odim in range(shape_anno.ndims): diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 8f0b8e18..63d57749 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -183,7 +183,10 @@ def from_logic_graph(nodes: List[IRCell], for idx, ftensor in enumerate(node.outputs()): subtensor = IRCell.modify_objects_of_complex(ftensor, modifier) node.set_output(idx, subtensor) - node.kwargs.update(IRCell.modify_objects_of_complex(node.kwargs, modifier)) + for key in node.kwargs.keys(): + subtensor = IRCell.modify_objects_of_complex(node.kwargs[key], modifier) + node.set_kwarg(key, subtensor) + graph = IRGraph(nodes, inputs, outputs, module_name) # check IRPyFunc diff --git a/nnscaler/graph/segment.py b/nnscaler/graph/segment.py index a53282d4..251f192d 100644 --- a/nnscaler/graph/segment.py +++ b/nnscaler/graph/segment.py @@ -1004,8 +1004,7 @@ def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> I inputs, outputs = set(), set() sub_cids = set(node.cid for node in nodes) for node in nodes: - for itensor in node.inputs(): - if not isinstance(itensor, IRObject): continue + for itensor in node.iobjs(): if itensor.is_attr(): if attr_as_inputs: inputs.add(itensor) @@ -1018,8 +1017,7 @@ def create_segment(self, nodes: List[IRCell], attr_as_inputs: bool = False) -> I # if no producers inside the nodes can produce data, set as input if all(pid not in sub_cids for pid in pids): inputs.add(itensor) - for otensor in node.outputs(): - if not isinstance(otensor, IRObject): continue + for otensor in node.oobjs(): # if the tensor is required by segment outputs, set as output if otensor in segment_outputs: outputs.add(otensor) diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index cd268749..c3c4f5ac 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -215,6 +215,11 @@ def outputs(self) -> Tuple[NestedVarOrStatic]: """ return tuple(self._outputs) + def _copy_and_set_cell(self, x: IRObject) -> IRObject: + x = copy.copy(x) + x.cell = self + return x + def reset_inputs(self, length:int) -> None: """ Resize the inputs list to the new length and reset all input items to None. @@ -232,15 +237,42 @@ def set_input(self, index: int, val: NestedVarOrStatic) -> NestedVarOrStatic: Returns: NestedVarOrStatic: copied value """ - if isinstance(val, IRObject): - # copy the val - val = copy.copy(val) - val.cell = self + # recursive set cell + val = IRCell.modify_objects_of_complex(val, self._copy_and_set_cell) + self._inputs[index] = val self.inputs.cache_clear() self.iobjs.cache_clear() return val + def reset_kwargs(self) -> None: + """ + Clear all kwargs + """ + self._kwargs = {} + self.iobjs.cache_clear() + + def set_kwarg(self, name: str, val: NestedVarOrStatic) -> NestedVarOrStatic: + """Set the kwarg with name + + Args: + val (NestedVarOrStatic): (nested) IRObject or any deterministic value (int, bool, str, etc) + + Returns: + NestedVarOrStatic: copied value + """ + # TODO: is it possible that kwargs can be IRTensor? + # But it is used in unit tests. + # if isinstance(val, IRTensor): + # raise ValueError("IRTensor is not allowed to be a kwarg") + + # recursive set cell + val = IRCell.modify_objects_of_complex(val, self._copy_and_set_cell) + + self._kwargs[name] = val + self.iobjs.cache_clear() + return val + def reset_outputs(self, length:int) -> None: """ Resize the outputs list to the new length and reset all output items to None. @@ -259,9 +291,9 @@ def set_output(self, index: int, val: NestedVarOrStatic): Returns: NestedVarOrStatic: copied value """ - if isinstance(val, IRObject): - val = copy.copy(val) - val.cell = self + # recursive set cell + val = IRCell.modify_objects_of_complex(val, self._copy_and_set_cell) + self._outputs[index] = val self.outputs.cache_clear() self.oobjs.cache_clear() diff --git a/nnscaler/ir/operator.py b/nnscaler/ir/operator.py index 17217618..9269ad5e 100644 --- a/nnscaler/ir/operator.py +++ b/nnscaler/ir/operator.py @@ -32,16 +32,8 @@ def __init__(self, name: str, signature: str, self.set_input(idx, input) # setup kwargs - # similar with set_input and set_output, the IRObject - # in kwargs will be set with copy-on-write to avoid - # potential modifications outside. - def replace(t: IRObject): - t = copy.copy(t) - t.cell = self - return t - - kwargs = IRCell.modify_objects_of_complex(kwargs, replace) - self.kwargs.update(kwargs) + for name, value in kwargs.items(): + self.set_kwarg(name, value) # default infer rule requires_grad = any( @@ -122,6 +114,9 @@ def replicate(self): cpy.reset_outputs(len(self.outputs())) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) + cpy.reset_kwargs() + for name, value in self.kwargs.items(): + cpy.set_kwarg(name, value) cpy._mirror = None cpy.recompute = self.recompute return cpy @@ -177,6 +172,7 @@ def replicate(self): cpy.reset_outputs(len(self.outputs())) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) + assert not cpy.kwargs, "No kwargs for backward op" cpy._mirror = None return cpy @@ -218,6 +214,9 @@ def replicate(self): cpy.reset_outputs(len(self.outputs())) for idx, output in enumerate(self.outputs()): cpy.set_output(idx, output) + cpy.reset_kwargs() + for name, value in self.kwargs.items(): + cpy.set_kwarg(name, value) cpy._mirror = None return cpy diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 30dcfeab..cb877e34 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -9,7 +9,6 @@ import inspect import nnscaler -from nnscaler.runtime.device import DeviceGroup from nnscaler.flags import RuntimeFlag, CompileFlag import torch @@ -57,7 +56,7 @@ def _load_module_attr(filename: str, name: str): def load_model(filename: Optional[str] = None, load_content: bool = True, fullmodel_filename: Optional[str] = None): - filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename + filename = f'gencode{nnscaler.runtime.device.DeviceGroup().rank}.py' if filename is None else filename module = _load_module_attr(filename, Path(filename).stem) loaded_module: nnscaler.runtime.module.CubeModule = module.GenModel().cuda() non_persistent_buffers = loaded_module.get_non_persistent_buffers() @@ -79,13 +78,13 @@ def load_model(filename: Optional[str] = None, load_content: bool = True, fullmo def load_default_schedule(filename: Optional[str] = None): - filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename + filename = f'gencode{nnscaler.runtime.device.DeviceGroup().rank}.py' if filename is None else filename module = _load_module_attr(filename, Path(filename).stem) return module._train_step def load_eval_schedule(filename: Optional[str] = None): - filename = f'gencode{DeviceGroup().rank}.py' if filename is None else filename + filename = f'gencode{nnscaler.runtime.device.DeviceGroup().rank}.py' if filename is None else filename module = _load_module_attr(filename, Path(filename).stem) return module._infer_step @@ -141,10 +140,10 @@ def setup_stride_broadcast_group(stride_size: int) -> BroadcastGroup: world_size = torch.distributed.get_world_size() for i in range(stride_size): ranks = list(range(i, world_size, stride_size)) - DeviceGroup().get_group(ranks) + nnscaler.runtime.device.DeviceGroup().get_group(ranks) curr_parallel_group_ranks = list(range(rank % stride_size, world_size, stride_size)) - curr_parallel_group = DeviceGroup().get_group(curr_parallel_group_ranks) + curr_parallel_group = nnscaler.runtime.device.DeviceGroup().get_group(curr_parallel_group_ranks) src_rank = min(curr_parallel_group_ranks) return BroadcastGroup( @@ -262,6 +261,12 @@ def transform_recursively(data: Any, fn: Callable[[Any], Any], return data +def select_many(data: Iterable[Any], fn: Callable[[Any], Iterable[Any]]) -> Iterable[Any]: + """Select many elements from the iterable with the given function.""" + for item in data: + yield from fn(item) + + class accum_mode: """Make cube execution in gradient accumulation mode. diff --git a/tests/compiler/test_model.py b/tests/compiler/test_model.py new file mode 100644 index 00000000..2d1b3ef9 --- /dev/null +++ b/tests/compiler/test_model.py @@ -0,0 +1,208 @@ +import os +import logging +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F + +import pytest + +import nnscaler +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType + +from ..launch_torchrun import launch_torchrun +from ..utils import assert_parity + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[ + :, + :, + max(-pad_y0, 0) : \ + out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0) : \ + out.shape[3] - max(-pad_x1, 0), + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + self.factor = factor + kernel = make_kernel(kernel) * (factor**2) + self.register_buffer("kernel", kernel) + p = kernel.shape[0] - factor + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + self.pad = (pad0, pad1) + + def forward(self, input): + return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + +def baseline(): + model = Upsample([1, 3, 3, 1], factor=2) + torch.manual_seed(0) + input = torch.rand(1, 3, 64, 64) + + # single gpu execution + model = model.cuda() + single_out = model(input.cuda()) + single_out = single_out.to('cpu') + return single_out + + +def parallelize_run(tmp_path): + model = Upsample([1, 3, 3, 1], factor=2) + torch.manual_seed(0) + input = torch.rand(1, 3, 64, 64) + + # multiple gpu execution + nnscaler.init() + nnscaler.utils.set_default_logger_level(logging.INFO) + compute_config = ComputeConfig( + 2, 2, constant_folding=True, use_zero=True, + pas_config={'update_freq': 1, '_batch_size': 1} + ) + para_model = parallelize( + model, + {'input': input,}, + 'autodist', + compute_config, + reuse=ReuseType.OVERRIDE, + gen_savedir=tmp_path + ) + para_model = para_model.cuda() + para_out = para_model(input.cuda()) + para_out = para_out.to('cpu') + return para_out + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_complex_model(tmp_path): + """ + The simplified model `Upsample` has complex operations, + such as, slice, max, permute, reshape, view, pad, flip, conv2d, etc. + """ + + launch_torchrun(2, assert_parity, baseline, partial(parallelize_run, tmp_path)) + + + +@nnscaler.register_op('m n -> m n, m n, ?') +def func_multi_outputs(x): + return x, x, 3 + + +# NOTE: "x" can be partitioned because "?" has no dependency on `x` +@nnscaler.register_op('m n -> m n, ?') +def func_output_complex_dict(x, factor=1): + x = x * factor + return x, {'y': 10} + + +# NOTE: "x" can be partitioned because "?" has no dependency on `x` +@nnscaler.register_op('m n -> m n, ?') +def func_output_complex_slice(x, factor=1): + x = x * factor + return x, slice(0, 10, factor) + + +# NOTE: "x" cannot be partitioned because there is "?" in output annotation +# and has dependency on `x` +@nnscaler.register_op('m^ n^ -> ?') +def func_output_list(x, factor=1): + x = x * factor + return [x, x] + + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = func_multi_outputs(x) + y, _, scalar = x + (sz, _) = y.shape + sz = sz + scalar + out = func_output_list(y, factor=sz) + out = out[0] + out, out_dict = func_output_complex_dict(out, factor=scalar) + out, out_slice = func_output_complex_slice(out, factor=out_dict['y']) + return {'out': out, 'slice_start': out_slice.start} + + +def single_run(): + torch.manual_seed(0) + dummy_input = torch.randn(4, 4) + module = MyModule() + module = module.cuda() + out = module(dummy_input.cuda()) + out = {'out': [each.to('cpu') for each in out['out']], 'slice_start': out['slice_start']} + return out + + +def two_gpu_run(tmp_path): + torch.manual_seed(0) + dummy_input = torch.randn(4, 4) + module = MyModule() + module = module.cuda() + nnscaler.init() + nnscaler.utils.set_default_logger_level(logging.INFO) + compute_config = ComputeConfig( + 2, 2, constant_folding=False, use_zero=True, + pas_config={'update_freq': 1, '_batch_size': 1} + ) + para_module = parallelize( + module, + {'x': dummy_input.cuda()}, + 'tp', + compute_config, reuse=ReuseType.OVERRIDE, + gen_savedir=tmp_path + ) + para_module = para_module.cuda() + out = para_module(dummy_input.cuda()) + out = {'out': [each.to('cpu') for each in out['out']], 'slice_start': out['slice_start']} + return out + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_complex_outputs(tmp_path): + launch_torchrun(2, assert_parity, + single_run, + partial(two_gpu_run, tmp_path) + ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..1f5f3d7f --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,10 @@ +import pytest + +from nnscaler.utils import select_many + + +def test_select_many(): + assert list(select_many([1, 2], lambda k: [])) == [] + assert list(select_many([1, [2, 3]], lambda k: k if isinstance(k, list) else [k])) == [1, 2, 3] + with pytest.raises(TypeError): + list(select_many([1, [2, 3]], lambda k: k)) From c4335c3d09f03c4209ed71b30dc465b8773c5cc6 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 16 Aug 2024 09:13:23 +0000 Subject: [PATCH 1709/1892] Merged PR 2238: minitrainer: fix bug when running compile with multiple workers minitrainer: fix bug when running compile with multiple workers --- nnscaler/cli/trainer.py | 5 +++-- nnscaler/utils.py | 9 +++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 1342def9..efcc9a88 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -16,7 +16,7 @@ from tqdm import tqdm import nnscaler -from nnscaler.utils import enforce_zero_num_worker +from nnscaler.utils import enforce_zero_num_worker, is_running_distributed import nnscaler.utils from .trainer_args import AggregatedOutputs, TrainerArgs @@ -135,7 +135,8 @@ def _load_dummy_input(self): def _setup(self): self.train_args.init_env() compile_only = self.train_args.run_mode == 'compile' - if not compile_only: + + if is_running_distributed(): nnscaler.init() if torch.distributed.get_rank() == 0: logging.getLogger().setLevel(logging.INFO) diff --git a/nnscaler/utils.py b/nnscaler/utils.py index cb877e34..046641a0 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -7,6 +7,7 @@ from collections import defaultdict from dataclasses import dataclass import inspect +import os import nnscaler from nnscaler.flags import RuntimeFlag, CompileFlag @@ -261,6 +262,14 @@ def transform_recursively(data: Any, fn: Callable[[Any], Any], return data +def is_running_distributed() -> bool: + """Check if the current process is running under torchrun.""" + # TORCHELASTIC_RUN_ID is more unique than 'RANK'/'WORLD_SIZE' + # so we use it to determine if the process is running under torchrun. + # TODO: Is there a better way? + return 'TORCHELASTIC_RUN_ID' in os.environ + + def select_many(data: Iterable[Any], fn: Callable[[Any], Iterable[Any]]) -> Iterable[Any]: """Select many elements from the iterable with the given function.""" for item in data: From 5d8826608254a92222851ec76c08952b9d766822 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 21 Aug 2024 06:00:03 +0000 Subject: [PATCH 1710/1892] Merged PR 2235: bugfix: grad track in trace --- .../concrete_trace_utils/concrete_tracer.py | 19 +++-- tests/graph/tracer/test_ctxt_manager.py | 75 ++++++++++++------- 2 files changed, 58 insertions(+), 36 deletions(-) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 5b449150..9793fa35 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -407,22 +407,27 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] try: if self.cpu_offload: - args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cuda(), args) - kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cuda(), kwargs) + # Concrete tracer use `.cuda()` to execute operators in device and `.cpu()` to move the result back to host. + # In most cases, `.cuda()` and `.cpu()` keeps the source tensor's `requires_grad` attributes. + # The context `torch.no_grad` enforces requires_grad=False for all tensors that generated in its scope. + # As a result, behavior of `.cuda()` and `.cpu()' is unexpected. + # To handle this case, we manually set the `requires_grad` field after `.cuda()` and `.cpu()`. + args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cuda().requires_grad_(x.requires_grad), args) + kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cuda().requires_grad_(x.requires_grad), kwargs) result = run(kind, target, args, kwargs) except torch.cuda.OutOfMemoryError: if self.cpu_offload: _logger.warning(f"cuda out of memory, try to trace {target} on cpu.") - args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), args) - kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), kwargs) + args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), args) + kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), kwargs) result = run(kind, target, args, kwargs) else: raise if self.cpu_offload: - args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), args) - kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), kwargs) - result = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu(), result) + args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), args) + kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), kwargs) + result = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), result) if not pytree_utils.tree_any(lambda x: isinstance(x, torch.Tensor), result) and \ pytree_utils.tree_any(lambda x: not isinstance(x, (*base_types, type(None), torch.Tensor)), result): diff --git a/tests/graph/tracer/test_ctxt_manager.py b/tests/graph/tracer/test_ctxt_manager.py index 8e0de855..914dfe1a 100644 --- a/tests/graph/tracer/test_ctxt_manager.py +++ b/tests/graph/tracer/test_ctxt_manager.py @@ -5,41 +5,58 @@ from ...utils import replace_all_device_with -class SimpleModel(torch.nn.Module): - def __init__(self): +# copy from transformers llama modeling +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() - self.fc = torch.nn.Linear(10, 10) + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, x): - with torch.no_grad(): - y = self.fc(x) - z = self.fc(x) - return y + z + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class TestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.rotary_emb = LlamaRotaryEmbedding(128) + self.fc1 = torch.nn.Linear(128, 128) + self.fc2 = torch.nn.Linear(128, 128) + + def forward(self, x, position_ids): + hidden = self.fc1(x) + cos, sin = self.rotary_emb(hidden, position_ids) + return self.fc2(hidden * cos * sin) @replace_all_device_with('cpu') def test_requires_grad(): with tempfile.TemporaryDirectory() as tempdir: - model = SimpleModel() - dummy_input = {'x': torch.rand(10)} + model = TestModule() + dummy_input = {'x': torch.rand(1, 100, 128), 'position_ids': torch.arange(0, 100, dtype=torch.int64).reshape(1, 100)} graph = convert_model(model, dummy_input, tempdir) - node_no_grad_fc, node_fc, node_add = graph.nodes() - # x under no grad context - assert node_no_grad_fc.inputs()[0].parent.requires_grad is False - # fc weight under no grad context - assert node_no_grad_fc.inputs()[1].parent.requires_grad is True - # fc output under no grad context - assert node_no_grad_fc.outputs()[0].parent.requires_grad is False - # x outside no grad context - assert node_fc.inputs()[0].parent.requires_grad is False - # fc weight outside no grad context - assert node_fc.inputs()[1].parent.requires_grad is True - # fc output outside no grad context - assert node_fc.outputs()[0].parent.requires_grad is True - # y - assert node_add.inputs()[0].parent.requires_grad is False - # z - assert node_add.inputs()[1].parent.requires_grad is True - # result - assert node_add.outputs()[0].parent.requires_grad is True + hidden_mul_cos_node = graph.nodes()[15] + + # hidden requires_grad is True + assert hidden_mul_cos_node.inputs()[0].parent.requires_grad is True + # cos requires_grad is False + assert hidden_mul_cos_node.inputs()[1].parent.requires_grad is False + # output requires_grad is True + assert hidden_mul_cos_node.outputs()[0].parent.requires_grad is True From ec6137a6aa746428d9e99893009c01000d06d1de Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 21 Aug 2024 08:16:53 +0000 Subject: [PATCH 1711/1892] Merged PR 2239: parallel module: decouple from Program() parallel module: decouple from Program() --- nnscaler/graph/graph.py | 73 +++++++++++++++- nnscaler/ir/cten.py | 5 +- nnscaler/parallel.py | 40 +++------ nnscaler/program.py | 53 ++++++------ nnscaler/runtime/device.py | 120 ++++++++++++-------------- tests/compiler/test_compile.py | 45 +++++++++- tests/graph/parser/test_parser.py | 2 +- tests/parallel_module/test_gencode.py | 61 +++++++++++++ tests/utils.py | 10 ++- 9 files changed, 279 insertions(+), 130 deletions(-) diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 63d57749..de4098d1 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -102,10 +102,15 @@ def forward(self, *args: Tuple[IRObject]) -> Union[IRTensor, Tuple[IRTensor]]: # reset output self.replace_output(iobj, arg) - from nnscaler.program import Program - Program().add_nodes(self.nodes()) - - # return + # set global graph, so @compile can access it. + # @compile needs a global graph to work + from nnscaler.program import Program, is_global_graph_enabled + if is_global_graph_enabled(): + Program().add_nodes(self.nodes()) + + # return the output of the graph + # the return value simulates the output of the model `forward` + # e.g. If there is only one output, return the output tensor directly instead of a tuple if len(self.outputs()) == 1: return self.output(0) else: @@ -136,6 +141,18 @@ def backward(self, loss: Optional[IRSubTensor] = None): # set loss gradient loss.parent.to_loss() + # update input gradient + # Please note `infer_grad` will not set the grad of input tensors. + for t in IRGraph.get_objects_from_complex(self.inputs()): + if isinstance(t, IRSubTensor) and t.requires_grad: + t.grad = t.parent.grad.tosub() + + # update output gradient + # Please note `infer_grad` will not set the grad of output tensors. + for t in IRGraph.get_objects_from_complex(self.outputs()): + if isinstance(t, IRSubTensor) and t.requires_grad: + t.grad = t.parent.grad.tosub() + # infer gradient for ftensor in self.full_tensors(): self.infer_grad(ftensor) @@ -171,6 +188,30 @@ def from_logic_graph(nodes: List[IRCell], Returns: IRGraph: the graph with each tensor is IRSubTensor. """ + # currently fx graph always has only one output + assert len(outputs) == 1, "Single output graph is expected" + if isinstance(outputs[0], tuple): + # fx graph will always wrap the graph output with a tuple of outputs + # case 1: the return value of graph looks like `return x, y, z` + # here `outputs` will be `[(x,y,z)]` + # we will remove the outer tuple to make graph outputs[0]/[1]/[2] as x/y/z respectively + # case 2: the return value of graph is a single value `return [[x]]` + # here `outputs` will be `[[[x]]]` + # just meet our requirement, no need to change + # the graph outputs[0] is `[[x]]`` + # case 3: the return value of graph is a single value `return x` + # here `outputs` will be `[x]` + # just meet our requirement, no need to change + # the graph outputs[0] is `x` + # Please note that + # 1. we treat `return x, y, z` and `return tuple(x, y, z)` as the same + # 2. we treat `return (x,)` and `return x` as the same + # Case 2 can lead to problem because it changes the return of `module.forward`, + # so we will raise error for this case for now. + outputs = outputs[0] + if isinstance(outputs, tuple) and len(outputs) == 1: + raise RuntimeError("Single tuple outputs (like `return (x,)`) is not supported") + modifier = lambda t: t.tosub() if isinstance(t, IRFullTensor) else t # input / output inputs = [IRCell.modify_objects_of_complex(t, modifier) for t in inputs] @@ -230,6 +271,30 @@ def from_logic_graph(nodes: List[IRCell], return graph + def use_dataloader_input(self): + """ + connect the graph with dataloader input. + """ + # replace graph inputs with dataloader + # the IRObject representing the `dataloader` instance, which is only used by the + # IRDataOperation. Since we already know the output of the dataloader, + # we don't need to set the value for it. + ir_root_obj = IRObject(name='dataloader', value=None, is_constant=False) + data_op = IRDataOperation(ir_root_obj, self.inputs()) + # add the data operation to the graph, which will use `next` to get data. + self.insert(data_op, 0) + self.reset_inputs(1) + self.set_input(0, ir_root_obj) + + def no_backward(self): + """ + Set all tensors with requires_grad=False to simulate no backward scenario (inference only). + """ + if any(isinstance(node, IRBpOperation) for node in self.nodes()): + raise RuntimeError("Cannot set no_backward for a graph with backward operators") + for ftensor in self.full_tensors(): + ftensor.requires_grad = False + ##### Transformation Primitives ##### def replicate(self, node: Union[IRFwOperation, IRDataOperation], times=1) -> List[IRCell]: diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index c3c4f5ac..7a215362 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -734,13 +734,14 @@ def byte_size(self) -> int: def backward(self) -> None: """ - Autograd backward on the tensor + Autograd backward on the tensor, which is used in @compile The backward will apply on the program graph @return None """ - from nnscaler.program import Program + from nnscaler.program import Program, is_global_graph_enabled + assert is_global_graph_enabled(), "Require global graph enabled to call loss.backward()" graph = Program().get_graph() return graph.backward(self) diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 8d24e4fc..a7ab528a 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -42,7 +42,7 @@ from nnscaler.flags import CompileFlag, RuntimeFlag import nnscaler.policies as policies -from nnscaler.program import Program +from nnscaler.program import disable_global_graph from nnscaler.utils import get_member_by_name, setup_stride_broadcast_group, get_shared_params logger = logging.getLogger(__name__) @@ -613,9 +613,8 @@ def _gen_graph( inference_only: bool = False, ): # reset environment - program = Program() - program.clear() IDGenerator().clear() + disable_global_graph() module.cpu() forward_args_default = _get_arg_default_values(module.forward) @@ -628,7 +627,7 @@ def _gen_graph( fx_graph = parser.to_fx_graph(module, dummy_forward_args) # generate ir logic graph - ir_graph = parser.to_ir_graph( + graph = parser.to_ir_graph( fx_graph, dummy_forward_args, outdir, constant_folding ) @@ -660,42 +659,27 @@ def _gen_graph( ir_dummy_inputs[i] = IRObject(fx_input_nodes[i].target, value=ir_dummy_inputs[i], is_constant=False) # generate complete ir graph + ir_dummy_outputs = graph(*ir_dummy_inputs) if end2end_mode: # in end2end mode, we must use dataloader as the first argument of forward # we assume the first argument of forward is the data sample (which is a requirement in our doc) + graph.use_dataloader_input() - # the IRObject representing the `dataloader` instance, which is only used by the - # IRDataOperation. Since we already know the output of the dataloader, - # we don't need to set the value for it. - ir_root_obj = IRObject(name='dataloader', value=None, is_constant=False) - Program().set_input([ir_root_obj]) - data_op = IRDataOperation(ir_root_obj, ir_dummy_inputs) - # add the data operation to the graph, which will use `next` to get data. - Program().add_node(data_op) - ir_dummy_outputs = ir_graph(*ir_dummy_inputs) - graph = program.get_graph() # we require the first output is the loss if isinstance(ir_dummy_outputs, (list, tuple)): ir_loss = ir_dummy_outputs[0] else: ir_loss = ir_dummy_outputs if not isinstance(ir_loss, IRTensor) or ir_loss.shape != (1,): - # TODO: update when we support scalar tensor + # internally scalar tensor will be reshaped to (1,) in IRGraph raise RuntimeError(f"Loss can only be scalar tensor but got {ir_loss.shape if isinstance(ir_loss, IRTensor) else ir_loss}") - if not inference_only: - ir_loss.backward() else: - program.set_input(ir_dummy_inputs) - ir_dummy_outputs = ir_graph(*ir_dummy_inputs) - graph = program.get_graph() - if not inference_only: - graph.backward() - - if ir_dummy_outputs is None: ir_dummy_outputs = [] - elif not isinstance(ir_dummy_outputs, (tuple, list)): - ir_dummy_outputs = [ir_dummy_outputs] - program.set_output(ir_dummy_outputs) - program.finalize() + ir_loss = None + + if not inference_only: + graph.backward(ir_loss) + else: + graph.no_backward() return graph, forward_args diff --git a/nnscaler/program.py b/nnscaler/program.py index 075e1f79..2ffddfb7 100644 --- a/nnscaler/program.py +++ b/nnscaler/program.py @@ -18,47 +18,52 @@ import torch.utils.data as data -class Program: +_program_graph: Optional[IRGraph] = None + - class __Program: +def enable_global_graph(): + global _program_graph + _program_graph = IRGraph([], [], [], 'program') - def __init__(self): - self._graph = IRGraph([], [], [], 'program') - instance = None +def disable_global_graph(): + global _program_graph + _program_graph = None - def __init__(self): - if not Program.instance: - Program.instance = Program.__Program() - def __getattr__(self, name): - return getattr(self.instance, name) +def is_global_graph_enabled(): + return _program_graph is not None + +class Program: + """ + This is only used in @compile for backward compatibility. + """ def add_node(self, node: IRCell): - self.instance._graph.insert(node, self.instance._graph.nnodes) + _program_graph.insert(node, _program_graph.nnodes) def add_nodes(self, nodes: List[IRCell]): for node in nodes: self.add_node(node) def get_graph(self) -> IRGraph: - return self.instance._graph + return _program_graph def set_input(self, inputs: Tuple[Any]): - self.instance._graph.reset_inputs(len(inputs)) + _program_graph.reset_inputs(len(inputs)) for idx, obj in enumerate(inputs): - self.instance._graph.set_input(idx, obj) + _program_graph.set_input(idx, obj) # update gradient - for t in IRGraph.get_objects_from_complex(self.instance._graph.inputs()): + for t in IRGraph.get_objects_from_complex(_program_graph.inputs()): if isinstance(t, IRSubTensor) and t.requires_grad: t.grad = t.parent.grad.tosub() def set_output(self, outputs: Tuple[Any]): - self.instance._graph.reset_outputs(len(outputs)) + _program_graph.reset_outputs(len(outputs)) for idx, otensor in enumerate(outputs): - self.instance._graph.set_output(idx, otensor) + _program_graph.set_output(idx, otensor) # update gradient - for t in IRGraph.get_objects_from_complex(self.instance._graph.outputs()): + for t in IRGraph.get_objects_from_complex(_program_graph.outputs()): if isinstance(t, IRSubTensor) and t.requires_grad: t.grad = t.parent.grad.tosub() @@ -67,18 +72,16 @@ def finalize(self): Close the recording of program. If the program doesn't do backward, set all tensors with requires_grad=False. """ - graph = self.get_graph() # inference scenario, set all gradients to none. - if not any(isinstance(node, IRBpOperation) for node in graph.nodes()): - # set gradients of activation tensors to none - for ftensor in graph.full_tensors(): - ftensor.requires_grad = False + if not any(isinstance(node, IRBpOperation) for node in _program_graph.nodes()): + _program_graph.no_backward() def clear(self): - Program.instance._graph = IRGraph([], [], [], 'program') + # will enable and create an empty global graph + enable_global_graph() def __repr__(self): - return repr(self.instance._graph) + return repr(_program_graph) class SemanticDataLoader: diff --git a/nnscaler/runtime/device.py b/nnscaler/runtime/device.py index 63f00003..f28c47b2 100644 --- a/nnscaler/runtime/device.py +++ b/nnscaler/runtime/device.py @@ -1,7 +1,7 @@ """ Communication group settings among devices """ -from typing import List, Dict +from typing import List, Dict, Optional import numpy as np import torch import os @@ -9,70 +9,53 @@ import datetime from nnscaler.flags import CompileFlag +from nnscaler.utils import is_running_distributed _logger = logging.getLogger(__name__) _LARGE_TIMEOUT = datetime.timedelta(seconds=21600) -class DeviceGroup: - - class __DeviceGroup: - - def __init__(self): - if CompileFlag.dev_mode: - self.rank = 0 - self.world_size = 1 - self.local_world_size = 1 - self.local_rank = 0 - self.node_rank = 0 - else: - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group( - backend='nccl', timeout=_LARGE_TIMEOUT - ) - - # disable it for now due to connection refused error when nnodes > 1 - # TODO: investigate the root cause - # create a barrier group for synchronization - # it is OK even the user has already created this gloo group - # this new timeout will override the old one. - # self.barrier_gloo_group = torch.distributed.new_group( - # backend='gloo', timeout=_LARGE_TIMEOUT - # ) - - self.rank = torch.distributed.get_rank() - self.world_size = torch.distributed.get_world_size() - # assume each node has the same device number - self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) - self.local_rank = int(os.environ.get('LOCAL_RANK')) - self.node_rank = int(os.environ.get('GROUP_RANK')) - - torch.cuda.set_device(self.local_rank) - self.groups: Dict = { '1'*self.world_size: None } - self.streams: Dict[str, torch.cuda.Stream] = { - 'default': torch.cuda.default_stream()} - - instance = None - +class _DeviceGroup: def __init__(self): - if not DeviceGroup.instance: - DeviceGroup.instance = DeviceGroup.__DeviceGroup() - - def __getattr__(self, name): - return getattr(self.instance, name) - - # def __setattr__(self, name): - # return setattr(self.instance, name) - - def __len__(self, name): - return DeviceGroup.instance.world_size + if CompileFlag.dev_mode or not is_running_distributed(): + self.rank = 0 + self.world_size = 1 + self.local_world_size = 1 + self.local_rank = 0 + self.node_rank = 0 + else: + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend='nccl', timeout=_LARGE_TIMEOUT + ) + + # disable it for now due to connection refused error when nnodes > 1 + # TODO: investigate the root cause + # create a barrier group for synchronization + # it is OK even the user has already created this gloo group + # this new timeout will override the old one. + # self.barrier_gloo_group = torch.distributed.new_group( + # backend='gloo', timeout=_LARGE_TIMEOUT + # ) + + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + # assume each node has the same device number + self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) + self.local_rank = int(os.environ.get('LOCAL_RANK')) + self.node_rank = int(os.environ.get('GROUP_RANK')) + + torch.cuda.set_device(self.local_rank) + self.groups: Dict = { '1'*self.world_size: None } + self.streams: Dict[str, torch.cuda.Stream] = { + 'default': torch.cuda.default_stream()} def group_exists(self, ranks): """ Check if group exists """ - rank_bits = DeviceGroup.bitmap(ranks) - return rank_bits in self.instance.groups + rank_bits = self.bitmap(ranks) + return rank_bits in self.groups def get_group(self, ranks): """ @@ -80,10 +63,10 @@ def get_group(self, ranks): None will be returned if length of ranks are equal to world size """ - if len(ranks) == self.instance.world_size: + if len(ranks) == self.world_size: return None - rank_bits = DeviceGroup.bitmap(ranks) - if rank_bits not in self.instance.groups: + rank_bits = self.bitmap(ranks) + if rank_bits not in self.groups: self.groups[rank_bits] = torch.distributed.new_group( list(ranks), timeout=_LARGE_TIMEOUT) return self.groups[rank_bits] @@ -92,7 +75,7 @@ def long_barrier(self): """ Barrier synchronization with very long timeout """ - # torch.distributed.barrier(group=self.instance.barrier_gloo_group) + # torch.distributed.barrier(group=self.barrier_gloo_group) torch.distributed.barrier() def get_stream(self, name: str) -> torch.cuda.Stream: @@ -100,7 +83,7 @@ def get_stream(self, name: str) -> torch.cuda.Stream: Get stream by name. If name doesn't exist, will create a new one. """ - return DeviceGroup.instance.streams.setdefault( + return self.streams.setdefault( name, torch.cuda.Stream()) def create_hybrid(self, group_num: List[int]) -> List[List[int]]: @@ -129,13 +112,11 @@ def create_hybrid(self, group_num: List[int]) -> List[List[int]]: assert len(outputs) == len(group_num) return outputs - - @staticmethod - def bitmap(ranks): + def bitmap(self, ranks): """ map the rank list to the bit map string """ - bits = '0' * DeviceGroup.instance.world_size + bits = '0' * self.world_size for rank in ranks: if rank >= len(bits): raise ValueError("rank {} out of range ({})".format(rank, len(bits))) @@ -143,10 +124,19 @@ def bitmap(ranks): return bits def __repr__(self): - msg = 'node id: [{}] rank: [{}] local rank: [{}]\n'.format(self.node_id, self.rank, self.local_rank) + msg = 'node rank: [{}] rank: [{}] local rank: [{}]\n'.format(self.node_rank, self.rank, self.local_rank) msg += 'communication groups (ranks):\n' for bitmap, group in self.groups.items(): ranks = [rank for rank, bit in enumerate(bitmap) if bit == '1'] - if self.instance.rank in ranks: + if self.rank in ranks: msg += '\t group {}: my group rank: [{}]\n'.format(ranks, torch.distributed.get_rank(group)) return msg + + +_instance: Optional[_DeviceGroup] = None + +def DeviceGroup() -> _DeviceGroup: + global _instance + if _instance is None: + _instance = _DeviceGroup() + return _instance diff --git a/tests/compiler/test_compile.py b/tests/compiler/test_compile.py index 66bdd515..a9d4d6a3 100644 --- a/tests/compiler/test_compile.py +++ b/tests/compiler/test_compile.py @@ -5,7 +5,10 @@ from functools import partial import more_itertools as mitr +import pytest + import nnscaler +from nnscaler.ir.tensor import IRSubTensor from nnscaler.utils import load_model from nnscaler.compiler import compile from nnscaler.runtime.utils import microbatches @@ -14,7 +17,7 @@ from nnscaler.ir.operator import IRFwOperation, IRDataOperation from nnscaler.flags import CompileFlag from ..launch_torchrun import torchrun -from ..utils import init_parameter, assert_parity +from ..utils import init_parameter, assert_parity, replace_all_device_with class MLP(torch.nn.Module): @@ -194,3 +197,43 @@ def train_iter(model, dataloader): baseline, partial(cube_run, 2, pipe_policy) ) + + +class TupleReturnModule2(torch.nn.Module): + def __init__(self, return_type=0): + super().__init__() + self.return_type = return_type + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x): + if self.return_type == 0: + return self.linear(x), + else: + return [[self.linear(x)]] + + +def tuple_return_run(return_type): + from nnscaler.policies import pas_dp + from nnscaler import ComputeConfig + from contextlib import nullcontext + + model = TupleReturnModule2(return_type) + data = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + dl = microbatches([data,]) + + def policy(graph, *args, **kwargs): + return pas_dp(graph, ComputeConfig(1, 1)) + + context = nullcontext() if return_type != 0 else pytest.raises(RuntimeError, match='Single tuple outputs.*') + with context: + @compile(model, dl, PAS=policy, scale=False) + def train_iter(model, dataloader): + x = next(iter(dataloader)) + loss = model(x) + assert len(loss) == 1 and len(loss[0]) == 1 and isinstance(loss[0][0], IRSubTensor) + return loss + + + +test_tuple_return0 = partial(torchrun, 1, tuple_return_run, 0) +test_tuple_return1 = partial(torchrun, 1, tuple_return_run, 1) diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index b4bdf553..85632f83 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -165,8 +165,8 @@ def forward(self, x): assert len(ir_graph.nodes()) == 5 assert len(ir_graph.nodes()[0].outputs()) == 3 - assert isinstance(ir_graph.output(0), list) assert len(ir_graph.outputs()) == 1 + assert isinstance(ir_graph.output(0), list) if output_list: assert len(ir_graph.nodes()[-1].outputs()) == 1 else: diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index 130645a5..c07d75f6 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -1,5 +1,6 @@ import inspect import tempfile +from contextlib import nullcontext import torch import pytest @@ -96,6 +97,66 @@ def test_codegen_args(): ) +class TupleReturnModule1(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, y): + return self.linear(x) + y, y + 10 + + +@replace_all_device_with('cpu') +def test_codegen_tuple_return1(): + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + TupleReturnModule1(), + { + 'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + 'y': 1.0, + }, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False + ) + assert len(_gencode_contains(tempdir, TupleReturnModule1, 0, + r"return add_.*, add_.*")) == 2 + + +class TupleReturnModule2(torch.nn.Module): + def __init__(self, return_type): + super().__init__() + self.return_type = return_type + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, y): + if self.return_type == 0: + return self.linear(x), + else: + return [[self.linear(x) + y]] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('return_type', [0, 1]) +def test_codegen_tuple_return2(return_type): + test_context = nullcontext() if return_type != 0 else pytest.raises(RuntimeError, match='Single tuple outputs.*') + with tempfile.TemporaryDirectory() as tempdir, test_context: + parallelize( + TupleReturnModule2(return_type), + { + 'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + 'y': 1.0, + }, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False + ) + assert _gencode_contains(tempdir, TupleReturnModule2, 0, + r"return \[\[add_.*\]\]") + + class UnusedArgsModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/tests/utils.py b/tests/utils.py index 5c8f41e7..c6bd6972 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,6 +17,7 @@ import torch.distributed.distributed_c10d as c10d from nnscaler.parallel import ComputeConfig +import nnscaler from nnscaler.runtime.module import ParallelModule from nnscaler.runtime.device import DeviceGroup, CompileFlag @@ -256,26 +257,27 @@ def mock_dist(rank, world_size): @contextmanager def mock_cube_env(cube_module_cls: Type[ParallelModule], compute_config): - old_device_group = DeviceGroup.instance + old_device_group = nnscaler.runtime.device._instance old_dev_mode = CompileFlag.dev_mode used_cuda_fns = ['set_device', 'current_device', 'default_stream'] old_cuda_fns = { fname: getattr(torch.cuda, fname) for fname in used_cuda_fns } - torchrun_envs = ['RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_WORLD_SIZE', 'GROUP_RANK'] + torchrun_envs = ['RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'LOCAL_WORLD_SIZE', 'GROUP_RANK', 'TORCHELASTIC_RUN_ID'] old_envs = { env: os.environ.get(env, None) for env in torchrun_envs } try: - DeviceGroup.instance = None + nnscaler.runtime.device._instance = None CompileFlag.dev_mode = False for fname, fn in old_cuda_fns.items(): setattr(torch.cuda, fname, lambda *args, **kwargs: None) os.environ['RANK'] = os.environ['LOCAL_RANK'] = str(cube_module_cls.rank) os.environ['WORLD_SIZE'] = os.environ['LOCAL_WORLD_SIZE'] = str(compute_config.runtime_ngpus) os.environ['GROUP_RANK'] = '0' + os.environ['TORCHELASTIC_RUN_ID'] = '0' # fake torchrun env yield finally: for env, val in old_envs.items(): @@ -286,7 +288,7 @@ def mock_cube_env(cube_module_cls: Type[ParallelModule], compute_config): for fname, fn in old_cuda_fns.items(): setattr(torch.cuda, fname, fn) CompileFlag.dev_mode = old_dev_mode - DeviceGroup.instance = old_device_group + nnscaler.runtime.device._instance = old_device_group def new_empty(cube_module_cls: Type[ParallelModule], device='meta', init_params=False): From 07637d876c0dfdbe2f970c9bf38ca3f98481df6a Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Thu, 22 Aug 2024 04:08:40 +0000 Subject: [PATCH 1712/1892] Merged PR 2243: Fix parallel module when loading from a merged checkpoint If a tensor is a non-persistent buffer, we will check its existence when loading from merged checkpoint. --- nnscaler/graph/function/dimops.py | 2 +- nnscaler/parallel.py | 1 + nnscaler/runtime/module.py | 8 +++++--- tests/parallel_module/test_checkpoint_buffer.py | 8 ++++---- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index 513f39ee..b555624f 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -505,7 +505,7 @@ def transform_space(self) -> List[Tuple[int, int]]: if not str.isdecimal(identifier): nonleading_ids.add(identifier) - visited : Set[str] = set() # to remove equavalent configurations + visited : Set[str] = set() # to remove equivalent configurations configs = [] shapes = self.inputs() for idx, shape in enumerate(shapes): diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index a7ab528a..e8e75353 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1397,6 +1397,7 @@ def _scale_grads(self, scale: float) -> None: for p in pg['params']: if p.grad is not None: p.grad.mul_(scale) + optimizer.scale_grads = types.MethodType(_scale_grads, optimizer) def _register_reducer_pre_hook(self, fn: Callable[[Reducer, torch.Tensor], None]): diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 6b41f183..ff672e7a 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -390,8 +390,8 @@ def merge_model_state_dicts( for local_name, meta in local_fullmap.items(): if local_name not in model_state_dict: # the parameter may not in model_state_dict (deduped with optimization) - # Another casee is when this is a non persistent buffer, we should skip it. - # because non persistent buffer should be stored in the fullmap, but not in the model state dict + # Another case is when this is a non persistent buffer, we should skip it, + # since non persistent buffer should be stored in the fullmap, but not in the model state dict continue # create full tensor on cpu partial_tensor = model_state_dict[local_name] @@ -1200,9 +1200,11 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s dist2param = self.dist_param_map orig_param_names = list(dist2param.values()) # param names in original module (without prefix) + non_persistent_buffers = self.get_non_persistent_buffers() with torch.no_grad(): - attr_names = set(self._fullmap.keys()) + # avoid checking the non-persistent buffers + attr_names = set([attr for attr in self._fullmap.keys() if attr not in non_persistent_buffers]) origname_tid_map = {meta.orig_name: meta.tid for meta in self._fullmap.values()} tid_info = defaultdict(list) diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index f40685f7..d3ea99fb 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -73,7 +73,7 @@ def _gpu_worker(): from nnscaler.runtime.module import _logger with catch_log(_logger) as log_stream: net2 = _to_cube_model(Net2(), compute_config, tempdir, 'net2', (256, 64)) - net2.load_merged_state_dict(merged_state_dict, strict=False) # should success + net2.load_merged_state_dict(merged_state_dict, strict=True) # should success assert torch.equal(list(net2._buffers.values())[0], torch.ones(256, 64)) logs = log_stream.getvalue() @@ -81,7 +81,7 @@ def _gpu_worker(): with catch_log(_logger) as log_stream: net2 = _to_cube_model(Net2(), compute_config, tempdir, 'net2-2', (256, 64), init_module_params=False) - net2.load_merged_state_dict(merged_state_dict, strict=False) # should success + net2.load_merged_state_dict(merged_state_dict, strict=True) # should success assert not torch.equal(list(net2._buffers.values())[0], torch.ones(256, 64)) logs = log_stream.getvalue() @@ -99,7 +99,7 @@ def _gpu_worker(): with catch_log(_logger) as log_stream: net3 = _to_cube_model(Net3(), compute_config, tempdir, 'net3-2', (128, 64)) - net3.load_merged_state_dict(merged_state_dict, strict=False) # should success + net3.load_merged_state_dict(merged_state_dict, strict=True) # should success assert torch.equal(list(net3._buffers.values())[0], torch.ones(128, 64)) logs = log_stream.getvalue() @@ -107,7 +107,7 @@ def _gpu_worker(): with catch_log(_logger) as log_stream: net3 = _to_cube_model(Net3(), compute_config, tempdir, 'net3-2', (128, 64), init_module_params=False) - net3.load_merged_state_dict(merged_state_dict, strict=False) # should success + net3.load_merged_state_dict(merged_state_dict, strict=True) # should success assert torch.equal(list(net3._buffers.values())[0], torch.ones(128, 64)) logs = log_stream.getvalue() From 969f3f7948a6db102e72658af267ac8a51fbd454 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Thu, 22 Aug 2024 05:36:06 +0000 Subject: [PATCH 1713/1892] Merged PR 2240: replace _orig_xxx with orig_func --- .../fx/concrete_trace_utils/concrete_proxy.py | 98 +++-- .../concrete_trace_utils/concrete_tracer.py | 360 ++++++++---------- .../concrete_trace_utils/function_patcher.py | 4 +- .../concrete_trace_utils/operator_patcher.py | 47 +-- .../fx/concrete_trace_utils/orig_func.py | 4 + .../parser/fx/concrete_trace_utils/utils.py | 46 +-- 6 files changed, 227 insertions(+), 332 deletions(-) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 2c58b9ae..e4012cb6 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -18,20 +18,8 @@ from torch.overrides import is_tensor_method_or_property from . import concrete_tracer as et -from . import pytree_utils +from . import pytree_utils, orig_func from .utils import ( - _orig_tuple, - _orig_list, - _orig_type, - _orig_isinstance, - _orig_getattr, - _orig_range, - _orig_dict, - _orig_len, - _orig_index, - _orig_bool, - _orig_slice, - _orig_set, get_frame_record, ) @@ -52,7 +40,7 @@ class ConcreteProxy(Proxy): 'POP_JUMP_IF_TRUE', 'JUMP_IF_NOT_EXC_MATCH', # occurred in new python vertion, not tested ) - jump_opcodes = _orig_tuple(dis.opmap[name] for name in jump_opnames if name in dis.opmap) + jump_opcodes = orig_func.tuple(dis.opmap[name] for name in jump_opnames if name in dis.opmap) op_compare = dis.opmap['COMPARE_OP'] op_extended_arg = dis.opmap['EXTENDED_ARG'] op_call_ex = dis.opmap['CALL_FUNCTION_EX'] @@ -78,7 +66,7 @@ def __repr__(self) -> str: def __getattr__(self, k) -> ConcreteProxy: # if the proxy is a wrapped module, forward this call to the torch.nn.Module.__getattribute__ - if _orig_isinstance(self.value, torch.nn.Module): + if orig_func.isinstance(self.value, torch.nn.Module): return torch.nn.Module.__getattribute__(self.value, k) return ConcreteAttrProxy(self, k) @@ -97,7 +85,7 @@ def __iter__(self) -> Union[Iterable, ConcreteProxy]: calling_frame = frame.f_back assert calling_frame is not None cur = calling_frame.f_lasti // 2 - insts: List[dis.Instruction] = _orig_list(dis.get_instructions(calling_frame.f_code)) + insts: List[dis.Instruction] = orig_func.list(dis.get_instructions(calling_frame.f_code)) while insts[cur].opcode == self.op_extended_arg: cur += 1 @@ -117,7 +105,7 @@ def __iter__(self) -> Union[Iterable, ConcreteProxy]: elif insts[cur].opcode == self.op_unpack_sequence: # in executing `a, b, c = atuple` return ConcreteUnpackIterProxy(self) - elif insts[cur].opname == 'GET_ITER' and insts[cur + 1].opname == 'FOR_ITER' and _orig_isinstance(self.value, _orig_range): + elif insts[cur].opname == 'GET_ITER' and insts[cur + 1].opname == 'FOR_ITER' and orig_func.isinstance(self.value, orig_func.range): # in executing `for i in range(...)` return iter(self.value) # elif insts[cur].opname == 'CONTAINS_OP': @@ -136,23 +124,23 @@ def __len__(self) -> Union[int, ConcreteProxy]: calling_frame = frame.f_back assert calling_frame is not None cur = calling_frame.f_lasti // 2 - insts: List[dis.Instruction] = _orig_list(dis.get_instructions(calling_frame.f_code)) + insts: List[dis.Instruction] = orig_func.list(dis.get_instructions(calling_frame.f_code)) while insts[cur].opcode == self.op_extended_arg: cur += 1 if insts[cur].opcode == self.op_call_ex: # in executing func(..., *proxy) - return _orig_len(self.value) + return orig_func.len(self.value) elif insts[cur].opcode == self.op_tuple_unpack_call: # in executing func(*..., *proxy) # <= python 3.8 - return _orig_len(self.value) + return orig_func.len(self.value) elif insts[cur].opcode == self.op_list_extend: # in executing x.extend(*proxy) or [x, *proxy] # >= python 3.9 - return _orig_len(self.value) + return orig_func.len(self.value) else: - return self.tracer.create_proxy('call_function', _orig_len, (self,), {}) + return self.tracer.create_proxy('call_function', orig_func.len, (self,), {}) def __getitem__(self, *args, **kwargs) -> ConcreteProxy: return self.tracer.create_proxy('call_function', operator.getitem, (self,) + args, kwargs) @@ -167,34 +155,34 @@ def __bool__(self) -> Union[bool, ConcreteProxy]: calling_frame = frame.f_back assert calling_frame is not None cur = calling_frame.f_lasti // 2 - insts: List[dis.Instruction] = _orig_list(dis.get_instructions(calling_frame.f_code)) + insts: List[dis.Instruction] = orig_func.list(dis.get_instructions(calling_frame.f_code)) while insts[cur].opcode == self.op_extended_arg: cur += 1 if insts[cur].opcode in self.jump_opcodes or ( insts[cur].opcode in self.jump_before_opcodes and insts[cur + 1].opcode in self.jump_opcodes): # in executing branch condition - return _orig_bool(self.value) + return orig_func.bool(self.value) elif insts[cur].opname == 'CONTAINS_OP': # in executing 'in' - return _orig_bool(self.value) + return orig_func.bool(self.value) elif insts[cur].opname == 'BINARY_SUBSCR': # in executing slice or index, my_list[index] or my_dict[key] - return _orig_bool(self.value) + return orig_func.bool(self.value) elif insts[cur].opcode == self.op_call_ex: # in executing func(..., *proxy) - return _orig_bool(self.value) + return orig_func.bool(self.value) elif insts[cur].opcode == self.op_not: # We cannot return a proxy because 'UNARY_NOT' op will check the type. _logger.warning('please use the function patcher, or use "x = operator.not_(y)" instead of "x = not y",' 'otherwise the traced graph may be wrong') - return _orig_bool(self.value) + return orig_func.bool(self.value) else: - return self.tracer.create_proxy('call_function', _orig_bool, (self,), {}) + return self.tracer.create_proxy('call_function', orig_func.bool, (self,), {}) def __index__(self) -> Union[int, ConcreteProxy]: # should only be in list/tuple getitem - return _orig_index(self.value) + return orig_func.index(self.value) def __hash__(self) -> Union[int, ConcreteProxy]: # should only be in dict getitem @@ -224,7 +212,7 @@ def keys(self): calling_frame = frame.f_back assert calling_frame is not None cur = calling_frame.f_lasti // 2 - insts: List[dis.Instruction] = _orig_list(dis.get_instructions(calling_frame.f_code)) + insts: List[dis.Instruction] = orig_func.list(dis.get_instructions(calling_frame.f_code)) while insts[cur].opcode == self.op_extended_arg: cur += 1 @@ -257,17 +245,17 @@ def __torch_function__(cls, orig_method, types, args=None, kwargs=None): args = args if args else () kwargs = kwargs if kwargs else {} - tracers: Set[Any] = _orig_set() + tracers: Set[Any] = orig_func.set() def find_tracer(a): - if _orig_isinstance(a, cls): + if orig_func.isinstance(a, cls): tracers.add(a.tracer) pytree_utils.tree_map(find_tracer, args) pytree_utils.tree_map(find_tracer, kwargs) - if _orig_len(tracers) > 1: - raise RuntimeError(f'Found multiple different tracers {_orig_list(tracers)} while ' + if orig_func.len(tracers) > 1: + raise RuntimeError(f'Found multiple different tracers {orig_func.list(tracers)} while ' f'trying to trace operations {orig_method}') tracer, = tracers @@ -293,16 +281,16 @@ def __init__(self, root: ConcreteProxy, attr: str): self.attr = attr self.tracer = root.tracer self._node: Optional[Node] = None - if _orig_isinstance(root.value, torch.Tensor) and attr == 'is_cuda' and self.tracer.cpu_offload: + if orig_func.isinstance(root.value, torch.Tensor) and attr == 'is_cuda' and self.tracer.cpu_offload: self.value = True - elif _orig_isinstance(root.value, torch.Tensor) and attr == 'device' and self.tracer.cpu_offload: + elif orig_func.isinstance(root.value, torch.Tensor) and attr == 'device' and self.tracer.cpu_offload: self.value = torch.device('cuda') warning_msg = "operation .device is detected, it will always return torch.device('cuda') during trace, " + \ "please make sure don't manually change the tensor device in the code.\n" + \ f"\t{get_frame_record()}" _logger.warning(warning_msg) else: - self.value = _orig_getattr(root.value, attr) + self.value = orig_func.getattr(root.value, attr) def __repr__(self) -> str: calling_frame_name = inspect.stack()[1][1] @@ -316,7 +304,7 @@ def node(self): # which do not rely on the getitem call if self._node is None: self._node = self.tracer.create_proxy( - 'call_function', _orig_getattr, (self.root, self.attr), {}).node + 'call_function', orig_func.getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): @@ -363,15 +351,15 @@ def try_create(root: Any): def __init__(self, root: ConcreteProxy): if not hasattr(root.value, '__getitem__'): # transfer 'set' to 'tuple' - # it's tuple not _orig_tuple! + # it's tuple not orig_func.tuple! # root = tuple(root) - root = root.tracer.create_proxy('call_function', _orig_tuple, (root,), {}) + root = root.tracer.create_proxy('call_function', orig_func.tuple, (root,), {}) self.root = root self.tracer = root.tracer self._node: Optional[Node] = None self._value: List[Any] = [] self.index = -1 - self.len = _orig_len(root.value) + self.len = orig_func.len(root.value) def __repr__(self) -> str: return f'ConcreteUnpackIterProxy({self.node.name})' @@ -389,7 +377,7 @@ def node(self): def value(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call - if _orig_len(self._value) == 0: + if orig_func.len(self._value) == 0: self._value.append(iter(self.root.value)) return self._value[0] @@ -404,18 +392,18 @@ def map_aggregate_not_proxy(a, fn): """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ - if _orig_isinstance(a, ConcreteProxy): + if orig_func.isinstance(a, ConcreteProxy): return fn(a) - elif _orig_isinstance(a, _orig_tuple): - t = _orig_tuple(map_aggregate_not_proxy(elem, fn) for elem in a) + elif orig_func.isinstance(a, orig_func.tuple): + t = orig_func.tuple(map_aggregate_not_proxy(elem, fn) for elem in a) # Support NamedTuple (if it has `_fields`) by repacking into original type. - return t if not hasattr(a, '_fields') else _orig_type(a)(*t) - elif _orig_type(a) == _orig_list: - return _orig_list(map_aggregate_not_proxy(elem, fn) for elem in a) - elif _orig_isinstance(a, _orig_dict): - return _orig_dict((k, map_aggregate_not_proxy(v, fn)) for k, v in a.items()) - elif _orig_isinstance(a, _orig_slice): - return _orig_slice(map_aggregate_not_proxy(a.start, fn), map_aggregate_not_proxy(a.stop, fn), map_aggregate_not_proxy(a.step, fn)) + return t if not hasattr(a, '_fields') else orig_func.type(a)(*t) + elif orig_func.type(a) == orig_func.list: + return orig_func.list(map_aggregate_not_proxy(elem, fn) for elem in a) + elif orig_func.isinstance(a, orig_func.dict): + return orig_func.dict((k, map_aggregate_not_proxy(v, fn)) for k, v in a.items()) + elif orig_func.isinstance(a, orig_func.slice): + return orig_func.slice(map_aggregate_not_proxy(a.start, fn), map_aggregate_not_proxy(a.stop, fn), map_aggregate_not_proxy(a.step, fn)) else: return fn(a) @@ -442,7 +430,7 @@ def map_aggregate_not_proxy(a, fn): def _scope(method): def impl(*args, **kwargs): tracer = args[0].tracer - target = _orig_getattr(operator, method) + target = orig_func.getattr(operator, method) return tracer.create_proxy('call_function', target, args, kwargs) impl.__name__ = method as_magic = f'__{method.strip("_")}__' @@ -454,7 +442,7 @@ def _define_reflectable(orig_method_name): method_name = f'__r{orig_method_name.strip("_")}__' def impl(self, rhs): - target = _orig_getattr(operator, orig_method_name) + target = orig_func.getattr(operator, orig_method_name) return self.tracer.create_proxy('call_function', target, (rhs, self), {}) impl.__name__ = method_name impl.__qualname__ = method_name diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 9793fa35..054b3a90 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -83,56 +83,12 @@ def __exit__(self, *args): return from . import concrete_proxy as ep -from . import pytree_utils +from . import pytree_utils, orig_func from .function_patcher import FunctionPatcher from .operator_patcher import OperatorPatcherContext from .utils import ( - _orig_module_call, - _orig_module_getattr, - _orig_module_getattribute, - - _orig_agfunc_apply, - _orig_torch_assert, - - _orig_type, - _orig_isinstance, - _orig_issubclass, - _orig_getattr, - - _orig_range, - _orig_int, - _orig_float, - _orig_bool, - _orig_tuple, - _orig_list, - _orig_set, - _orig_frozenset, - _orig_dict, - _orig_map, - _orig_zip, - _orig_enumerate, - _orig_slice, - _orig_reversed, - - _orig_torch_size, - _orig_torch_finfo, - - _orig_len, - _orig_not, - _orig_is, - _orig_is_not, - _orig_contains, - _orig_index, - - _orig_all, - _orig_min, - _orig_max, - _orig_node_is_impure, - side_effectful_inplace_ops, -) -from .utils import ( ExtraSEFPatcher, EmptyResult, extract_results_metadata, @@ -181,7 +137,7 @@ class LeafClassWrapInfo: def is_autograd_apply(func) -> bool: return getattr(func, '__name__', None) == 'apply' \ - and _orig_isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) + and orig_func.isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) @compatibility(is_backward_compatible=True) @@ -199,15 +155,15 @@ class ConcreteTracer(TracerBase): ) default_autowrap_leaf_function: Dict[Any, LeafFnWrapInfo] = { # function - _orig_len: LeafFnWrapInfo([], False, None), - _orig_not: LeafFnWrapInfo([], False, None), - _orig_is: LeafFnWrapInfo([], False, None), - _orig_is_not: LeafFnWrapInfo([], False, None), - _orig_contains: LeafFnWrapInfo([], False, None), - _orig_index: LeafFnWrapInfo([], False, None), - _orig_all: LeafFnWrapInfo([], False, None), - _orig_min: LeafFnWrapInfo([], False, None), - _orig_max: LeafFnWrapInfo([], False, None), + orig_func.len: LeafFnWrapInfo([], False, None), + orig_func.not_: LeafFnWrapInfo([], False, None), + orig_func.is_: LeafFnWrapInfo([], False, None), + orig_func.is_not: LeafFnWrapInfo([], False, None), + orig_func.contains: LeafFnWrapInfo([], False, None), + orig_func.index: LeafFnWrapInfo([], False, None), + orig_func.all: LeafFnWrapInfo([], False, None), + orig_func.min: LeafFnWrapInfo([], False, None), + orig_func.max: LeafFnWrapInfo([], False, None), # force-traced function (the factory functions of tensor creation) torch.arange: LeafFnWrapInfo([], True, None), @@ -229,26 +185,26 @@ class ConcreteTracer(TracerBase): # method Sequential.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - Sequential.__len__: LeafFnWrapInfo([], False, _orig_len), + Sequential.__len__: LeafFnWrapInfo([], False, orig_func.len), Sequential.__iter__: LeafFnWrapInfo([], False, iter), ModuleList.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - ModuleList.__len__: LeafFnWrapInfo([], False, _orig_len), + ModuleList.__len__: LeafFnWrapInfo([], False, orig_func.len), ModuleList.__iter__: LeafFnWrapInfo([], False, iter), ModuleDict.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - ModuleDict.__len__: LeafFnWrapInfo([], False, _orig_len), + ModuleDict.__len__: LeafFnWrapInfo([], False, orig_func.len), ModuleDict.__iter__: LeafFnWrapInfo([], False, iter), - ModuleDict.__contains__: LeafFnWrapInfo([], False, _orig_contains), + ModuleDict.__contains__: LeafFnWrapInfo([], False, orig_func.contains), ParameterList.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - ParameterList.__len__: LeafFnWrapInfo([], False, _orig_len), + ParameterList.__len__: LeafFnWrapInfo([], False, orig_func.len), ParameterList.__iter__: LeafFnWrapInfo([], False, iter), ParameterDict.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - ParameterDict.__len__: LeafFnWrapInfo([], False, _orig_len), + ParameterDict.__len__: LeafFnWrapInfo([], False, orig_func.len), ParameterDict.__iter__: LeafFnWrapInfo([], False, iter), - ParameterDict.__contains__: LeafFnWrapInfo([], False, _orig_contains), + ParameterDict.__contains__: LeafFnWrapInfo([], False, orig_func.contains), } # equals to `from torch.nn import functional as nn_functional` # to pass pyright check @@ -260,7 +216,7 @@ class ConcreteTracer(TracerBase): default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, attr) for name in dir(nn_functional): attr = getattr(nn_functional, name) - if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__')\ + if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__')\ and getattr(attr, '__module__', None) not in ('typing', 'torch.nn.modules.utils'): if attr not in default_autowrap_leaf_function: default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) @@ -268,19 +224,19 @@ class ConcreteTracer(TracerBase): default_autowrap_leaf_function[attr].extra_locs.append(Location(nn_functional, name)) for name in dir(torch._C._VariableFunctions): attr = getattr(torch._C._VariableFunctions, name) - if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): + if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__'): if attr not in default_autowrap_leaf_function: default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) for name in dir(torch._C._nn): attr = getattr(torch._C._nn, name) - if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): + if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__'): if attr not in default_autowrap_leaf_function: default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) if hasattr(attr, '__module__') and attr.__module__ != 'torch._C._nn': default_autowrap_leaf_function[attr].extra_locs.append(Location(torch._C._nn, name)) for name in dir(torch._C._TensorBase): attr = getattr(torch._C._TensorBase, name) - if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__'): + if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__'): if attr not in default_autowrap_leaf_function: to_func = getattr(torch.Tensor, name, None) to_func = None if to_func == attr else to_func @@ -288,28 +244,26 @@ class ConcreteTracer(TracerBase): # find the multi position for default_autowrap_leaf_function in torch.__dir__() for name in dir(torch): attr = getattr(torch, name) - if callable(attr) and not _orig_isinstance(attr, Type) and not name.startswith('__') \ + if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__') \ and attr in default_autowrap_leaf_function: default_autowrap_leaf_function[attr].extra_locs.append(Location(torch, name)) default_autowrap_leaf_class: Dict[Type, LeafClassWrapInfo] = { # class - _orig_bool: LeafClassWrapInfo([], False), - # we don't want zip appear as a node in the graph - # _orig_zip: LeafClassWrapInfo([], False), - _orig_int: LeafClassWrapInfo([], False), - _orig_float: LeafClassWrapInfo([], False), + orig_func.bool: LeafClassWrapInfo([], False), + orig_func.int: LeafClassWrapInfo([], False), + orig_func.float: LeafClassWrapInfo([], False), # iterable class - _orig_tuple: LeafClassWrapInfo([], True), - _orig_list: LeafClassWrapInfo([], True), - _orig_set: LeafClassWrapInfo([], True), - _orig_frozenset: LeafClassWrapInfo([], True), - _orig_dict: LeafClassWrapInfo([], True), - _orig_reversed: LeafClassWrapInfo([], False), - - _orig_torch_size: LeafClassWrapInfo([], False), - _orig_torch_finfo: LeafClassWrapInfo([], False), + orig_func.tuple: LeafClassWrapInfo([], True), + orig_func.list: LeafClassWrapInfo([], True), + orig_func.set: LeafClassWrapInfo([], True), + orig_func.frozenset: LeafClassWrapInfo([], True), + orig_func.dict: LeafClassWrapInfo([], True), + orig_func.reversed: LeafClassWrapInfo([], False), + + orig_func.torch_Size: LeafClassWrapInfo([], False), + orig_func.torch_finfo: LeafClassWrapInfo([], False), } @compatibility(is_backward_compatible=True) @@ -359,12 +313,12 @@ def fetch_attr(self, target: str) -> Any: with self.do_temp_call_origin(): target_atoms = target.split('.') attr_itr = self.root - for i, atom in _orig_enumerate(target_atoms): + for i, atom in orig_func.enumerate(target_atoms): # if atom == '': # continue if not hasattr(attr_itr, atom): raise RuntimeError(f"Node referenced nonexistent target \'{'.'.join(target_atoms[:i])}\'") - attr_itr = _orig_getattr(attr_itr, atom) + attr_itr = orig_func.getattr(attr_itr, atom) return attr_itr def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]): @@ -384,7 +338,7 @@ def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any] result = fn(*args, **kwargs) elif kind == 'call_method': self_obj, *args_tail = args - fn = _orig_getattr(self_obj, target) + fn = orig_func.getattr(self_obj, target) result = fn(*args_tail, **kwargs) elif kind == 'call_module': assert isinstance(target, str) @@ -541,10 +495,10 @@ def create_arg(self, a: Any) -> Union[Node, Any]: start = self.create_arg(a.start) stop = self.create_arg(a.stop) step = self.create_arg(a.step) - if _orig_isinstance(start, Node)\ - or _orig_isinstance(stop, Node)\ - or _orig_isinstance(step, Node): - return self.create_node('call_function', _orig_slice, (start, stop, step), {}, node_result=a) + if orig_func.isinstance(start, Node)\ + or orig_func.isinstance(stop, Node)\ + or orig_func.isinstance(step, Node): + return self.create_node('call_function', orig_func.slice, (start, stop, step), {}, node_result=a) else: return a # For NamedTuple instances that appear literally as args, we emit @@ -577,7 +531,7 @@ def create_arg(self, a: Any) -> Union[Node, Any]: return self.create_node('get_attr', qualname, (), {}, node_result=a) - if _orig_type(a) in _proxyable_classes: + if orig_func.type(a) in _proxyable_classes: # This is an instance of a proxyable class for which we did not # witness its construction. Intern this as a constant attribute @@ -607,8 +561,8 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool """ similar to _symbolic_trace.Tracer.is_leaf_module """ - return (m.__module__.startswith('torch.nn.functional') and not _orig_isinstance(m, (Sequential, ModuleList, ModuleDict)))\ - or _orig_isinstance(m, self.leaf_module) + return (m.__module__.startswith('torch.nn.functional') and not orig_func.isinstance(m, (Sequential, ModuleList, ModuleDict)))\ + or orig_func.isinstance(m, self.leaf_module) @compatibility(is_backward_compatible=True) def path_of_module(self, mod: torch.nn.Module) -> str: @@ -625,9 +579,9 @@ def path_of_module(self, mod: torch.nn.Module) -> str: module_constants = self.root._module_constants assert isinstance(module_constants, torch.nn.ModuleList) if hasattr(mod, 'extra_repr'): - sub_path = _orig_type(mod).__name__ + mod.extra_repr() + sub_path = orig_func.type(mod).__name__ + mod.extra_repr() else: - sub_path = str(_orig_len(module_constants)) + sub_path = str(orig_func.len(module_constants)) if not hasattr(module_constants, sub_path): module_constants.add_module(sub_path, mod) path = '_module_constants.%s' % sub_path @@ -679,10 +633,10 @@ def create_args_for_root(self, root_fn, is_module, concrete_args: Union[Dict[str cnt = 0 self.placeholder_dict = {} arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] - diff_len = _orig_len(arg_names) - _orig_len(default_value_list) + diff_len = orig_func.len(arg_names) - orig_func.len(default_value_list) default_args = {arg_names[idx + diff_len]: default_value_list[idx] for idx in range(len(default_value_list))} if isinstance(concrete_args, tuple): - if _orig_len(arg_names) != _orig_len(concrete_args): + if orig_func.len(arg_names) != orig_func.len(concrete_args): raise RuntimeError(f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments") concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} def proxy_placeholder(name: str): @@ -794,7 +748,7 @@ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, # TODO: better infomation assert hasattr( root, forward_function_name - ), f"traced_func_name={forward_function_name} doesn't exist in {_orig_type(root).__name__}" + ), f"traced_func_name={forward_function_name} doesn't exist in {orig_func.type(root).__name__}" fn = getattr(root, forward_function_name) self.submodule_paths = {mod: name for name, mod in root.named_modules()} @@ -845,27 +799,27 @@ def get_middle_class(node, memo = set(), prefix = ''): yield m self.the_path_of_middle_class = {id(v): k for k, v in get_middle_class(self.root)} - @functools.wraps(_orig_module_getattribute) + @functools.wraps(orig_func.torch_module_getattribute) def module_getattribute_wrapper(mod, attr): if self.temp_call_origin: try: - return _orig_module_getattribute(mod, attr) + return orig_func.torch_module_getattribute(mod, attr) except AttributeError: - return _orig_module_getattr(mod, attr) + return orig_func.torch_module_getattr(mod, attr) with self.do_temp_call_origin(): try: - attr_val = _orig_module_getattribute(mod, attr) + attr_val = orig_func.torch_module_getattribute(mod, attr) except AttributeError: - attr_val = _orig_module_getattr(mod, attr) - if _orig_isinstance(attr_val, ep.ConcreteProxy): + attr_val = orig_func.torch_module_getattr(mod, attr) + if orig_func.isinstance(attr_val, ep.ConcreteProxy): warn_msg = f'Detected {self.the_path_of_middle_class[id(mod)]}.{attr} is a ConcreteProxy, ' + \ 'this is usually caused by directly assigning the return value of some leaf function to the attribute of the module. ' + \ 'Please note that this writing method may cause some trace errors.' _logger.warning(warn_msg) return attr_val - # using isinstance instead of _orig_isinstance to judge whether + # using isinstance instead of orig_func.isinstance to judge whether # the ConcreteProxy.value is the following three types if the attr_val is a ConcreteProxy - elif isinstance(attr_val, (_orig_tuple, _orig_list)): + elif isinstance(attr_val, (orig_func.tuple, orig_func.list)): if self.the_path_of_middle_class[id(mod)] == '': return self.create_proxy('get_attr', f'{attr}', (), {}) else: @@ -881,14 +835,14 @@ def module_getattribute_wrapper(mod, attr): return self.create_proxy('get_attr', self.the_path_of_buffer[id(attr_val)], (), {}) return attr_val - @functools.wraps(_orig_module_call) + @functools.wraps(orig_func.torch_module_call) def module_call_wrapper(mod, *args, **kwargs): if self.temp_call_origin: - return _orig_module_call(mod, *args, **kwargs) + return orig_func.torch_module_call(mod, *args, **kwargs) else: # codes below corresponds to symbolic tracer's call_module module_qualified_name = self.path_of_module(mod) - with ScopeContextManager(self.scope, Scope(module_qualified_name, _orig_type(mod))) as _scope: + with ScopeContextManager(self.scope, Scope(module_qualified_name, orig_func.type(mod))) as _scope: self.module_stack[_scope.module_path] = _scope.module_type if not self.is_leaf_module(mod, module_qualified_name): _autowrap_check(self, @@ -901,7 +855,7 @@ def module_call_wrapper(mod, *args, **kwargs): self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - ret_val = _orig_module_call(mod, *args, **kwargs) + ret_val = orig_func.torch_module_call(mod, *args, **kwargs) else: ret_val = self.create_proxy('call_module', module_qualified_name, args, kwargs) key, _ = self.module_stack.popitem(last=True) @@ -910,132 +864,132 @@ def module_call_wrapper(mod, *args, **kwargs): class map_wrapper_clz: # used to track the original class - _fx_wrapped_ori_clz = _orig_map + _fx_wrapped_ori_clz = orig_func.map def __new__(cls, the_func, *iterables: Any): if self.temp_call_origin: - return _orig_map(the_func, *iterables) - tracers = _orig_set() + return orig_func.map(the_func, *iterables) + tracers = orig_func.set() for one_iter in iterables: - if _orig_isinstance(one_iter, ep.Proxy): + if orig_func.isinstance(one_iter, ep.Proxy): tracers.add(one_iter.tracer) - if _orig_len(tracers) > 1: + if orig_func.len(tracers) > 1: raise Exception('more than 1 tracer detected. please report the issue') - elif _orig_len(tracers) == 1: - results = _orig_list() - for args in _orig_zip(*iterables): + elif orig_func.len(tracers) == 1: + results = orig_func.list() + for args in zip(*iterables): results.append(the_func(*args)) - return next(iter(tracers)).create_proxy('call_function', _orig_tuple, (results,), {}) + return next(iter(tracers)).create_proxy('call_function', orig_func.tuple, (results,), {}) ## for the multi-level list/tuple - iterables = _orig_list(_orig_list(it) for it in iterables) + iterables = orig_func.list(orig_func.list(it) for it in iterables) for it in iterables: for arg in it: - if _orig_isinstance(arg, ep.Proxy): + if orig_func.isinstance(arg, ep.Proxy): tracers.add(arg.tracer) - if _orig_len(tracers) > 1: + if orig_func.len(tracers) > 1: raise Exception('more than 1 tracer detected. please report the issue') - elif _orig_len(tracers) == 1: - results = _orig_list() - for args in _orig_zip(*iterables): + elif orig_func.len(tracers) == 1: + results = orig_func.list() + for args in zip(*iterables): results.append(the_func(*args)) - return next(iter(tracers)).create_proxy('call_function', _orig_tuple, (results,), {}) + return next(iter(tracers)).create_proxy('call_function', orig_func.tuple, (results,), {}) ## for the multi-level list/tuple end - return _orig_map(the_func, *iterables) + return orig_func.map(the_func, *iterables) def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(_orig_map)) + return id(__o) in (id(self), id(orig_func.map)) def __hash__(self): return id(self) map_wrapper = map_wrapper_clz class range_wrapper_clz: # used to track the original class - _fx_wrapped_ori_clz = _orig_range + _fx_wrapped_ori_clz = orig_func.range def __new__(cls, *args): # TODO: better infomation - assert 1 <= _orig_len(args) <= 3 - args = (arg.value if _orig_isinstance(arg, ep.ConcreteProxy) else arg for arg in args) - return _orig_range(*args) + assert 1 <= orig_func.len(args) <= 3 + args = (arg.value if orig_func.isinstance(arg, ep.ConcreteProxy) else arg for arg in args) + return orig_func.range(*args) def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(_orig_range)) + return id(__o) in (id(self), id(orig_func.range)) def __hash__(self): return id(self) range_wrapper = range_wrapper_clz class enumerate_wrapper_clz: # used to track the original class - _fx_wrapped_ori_clz = _orig_enumerate + _fx_wrapped_ori_clz = orig_func.enumerate def __new__(cls, iterable, start=0): count = start for elem in iterable: - if _orig_isinstance(elem, ep.ConcreteProxy) and _orig_isinstance(elem.value, (_orig_int, str)): + if orig_func.isinstance(elem, ep.ConcreteProxy) and orig_func.isinstance(elem.value, (orig_func.int, str)): yield count, elem.value else: yield count, elem count += 1 def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(_orig_enumerate)) + return id(__o) in (id(self), id(orig_func.enumerate)) def __hash__(self): return id(self) enumerate_wrapper = enumerate_wrapper_clz class type_wrapper_clz: # used to track the original class - _fx_wrapped_ori_clz = _orig_type + _fx_wrapped_ori_clz = orig_func.type def __new__(cls, obj_or_name, *args): # case 1: class type(name, bases, dict, **kwds) - if _orig_len(args) > 0: - assert _orig_len(args) == 2 + if orig_func.len(args) > 0: + assert orig_func.len(args) == 2 base_cls, cls_dict = args[0], args[1] # if it is a wrapped class, replace it to the original one - base_cls = _orig_tuple(bs._fx_wrapped_ori_clz if hasattr(bs, '_fx_wrapped_ori_clz') else bs for bs in base_cls) - return _orig_type(obj_or_name, base_cls, cls_dict) + base_cls = orig_func.tuple(bs._fx_wrapped_ori_clz if hasattr(bs, '_fx_wrapped_ori_clz') else bs for bs in base_cls) + return orig_func.type(obj_or_name, base_cls, cls_dict) # case 2: class type(object) else: - orig_type = _orig_type(obj_or_name) + orig_type = orig_func.type(obj_or_name) if orig_type in (ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): - return _orig_type(obj_or_name.value) + return orig_func.type(obj_or_name.value) else: return orig_type def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(_orig_enumerate)) + return id(__o) in (id(self), id(orig_func.enumerate)) def __hash__(self): return id(self) type_wrapper = type_wrapper_clz @classmethod - @functools.wraps(_orig_agfunc_apply) + @functools.wraps(orig_func.torch_agfunc_apply) def agfunc_apply_wrapper(clz, *args, **kwargs): if clz not in self.agfunc_dict: self.agfunc_dict[clz] = torch._C._FunctionBase.__dict__['apply'].__get__(None, clz) if self.temp_call_origin: return self.agfunc_dict[clz](*args, **kwargs) - tracers = _orig_set() + tracers = orig_func.set() def unwrap_detect_tracers(obj): if isinstance(obj, ep.ConcreteProxy): tracers.add(obj.tracer) ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) - if _orig_len(tracers) == 0: + if orig_func.len(tracers) == 0: return self.agfunc_dict[clz](*args, **kwargs) - elif _orig_len(tracers) == 1 and next(iter(tracers)) == self: + elif orig_func.len(tracers) == 1 and next(iter(tracers)) == self: return self.create_proxy('call_function', self.agfunc_dict[clz], args, kwargs) else: raise Exception('more than 1 tracer detected. please report the issue') - @functools.wraps(_orig_torch_assert) + @functools.wraps(orig_func.torch_assert) def torch_assert_wrapper(condition, message): - while _orig_isinstance(condition, ep.ConcreteProxy): + while orig_func.isinstance(condition, ep.ConcreteProxy): condition = condition.value - return _orig_torch_assert(condition, message) + return orig_func.torch_assert(condition, message) self.agfunc_dict: dict[Type, Any] = {} self.autowrap_leaf_pairs = { - id(_orig_torch_assert): torch_assert_wrapper, + id(orig_func.torch_assert): torch_assert_wrapper, } self.wrapped_leaf: Dict[Any, Tuple[Tuple[Location,...], Any]] = dict() @@ -1080,7 +1034,7 @@ def torch_assert_wrapper(condition, message): wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn, (self,)) else: wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn) - elif _orig_isinstance(func, (MethodDescriptorType, MethodWrapperType)): + elif orig_func.isinstance(func, (MethodDescriptorType, MethodWrapperType)): wrapped = _create_wrapped_leaf_method(self, func, func.__name__, wrap_info.replace_fn) elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn' \ and not func.__qualname__.startswith('PyCapsule'): @@ -1125,10 +1079,10 @@ def torch_assert_wrapper(condition, message): self.add_need_revert_function(func, self.wrapped_leaf.get(func, (None, None))[1]) self.clz_wrapper_map: Dict[Any, Type] = { - map_wrapper: _orig_map, - enumerate_wrapper: _orig_enumerate, - range_wrapper: _orig_range, - type_wrapper: _orig_type, + map_wrapper: orig_func.map, + enumerate_wrapper: orig_func.enumerate, + range_wrapper: orig_func.range, + type_wrapper: orig_func.type, } for clz, wrap_info in self.autowrap_leaf_class.items(): if clz.__module__.startswith('_') and clz.__module__ != '__main__': @@ -1152,54 +1106,54 @@ def torch_assert_wrapper(condition, message): wrapped = _create_wrapped_nn_module_func(self, mod, forward_function_name) self.wrapped_leaf[mod.forward] = ((Location(mod, forward_function_name),), wrapped) - @functools.wraps(_orig_isinstance) + @functools.wraps(orig_func.isinstance) def isinstance_wrapper(instance, clz): - if _orig_type(clz) in (slice, tuple, list, _orig_slice, _orig_tuple, _orig_list): + if orig_func.type(clz) in (slice, tuple, list, orig_func.slice, orig_func.tuple, orig_func.list): clz_wrapped = [] for wrapped_type, orig_type in self.clz_wrapper_map.items(): if wrapped_type in clz: clz_wrapped.append(orig_type) clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in self.clz_wrapper_map)) - # use _orig_isinstance(clz, Iterable) will cause an endless recursive loop + # use orig_func.isinstance(clz, Iterable) will cause an endless recursive loop for cls in (object, ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): - if cls in clz and _orig_isinstance(instance, cls): + if cls in clz and orig_func.isinstance(instance, cls): return True - if _orig_isinstance(instance, ep.ConcreteProxy): - return _orig_isinstance(instance.value, clz) + if orig_func.isinstance(instance, ep.ConcreteProxy): + return orig_func.isinstance(instance.value, clz) else: - return _orig_isinstance(instance, clz) + return orig_func.isinstance(instance, clz) else: if clz in (object, ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): - return _orig_isinstance(instance, clz) + return orig_func.isinstance(instance, clz) if clz in self.clz_wrapper_map: clz = self.clz_wrapper_map[clz] - if _orig_isinstance(instance, ep.ConcreteProxy): + if orig_func.isinstance(instance, ep.ConcreteProxy): instance = instance.value - return _orig_isinstance(instance, clz) + return orig_func.isinstance(instance, clz) - @functools.wraps(_orig_issubclass) + @functools.wraps(orig_func.issubclass) def issubclass_wrapper(subclass, clz): - if _orig_type(clz) in (slice, tuple, list, _orig_slice, _orig_tuple, _orig_list): + if orig_func.type(clz) in (slice, tuple, list, orig_func.slice, orig_func.tuple, orig_func.list): clz_wrapped = [] for wrapped_type, orig_type in self.clz_wrapper_map.items(): if wrapped_type in clz: clz_wrapped.append(orig_type) clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in self.clz_wrapper_map)) - return _orig_issubclass(subclass, clz) + return orig_func.issubclass(subclass, clz) else: if clz in self.clz_wrapper_map: clz = self.clz_wrapper_map[clz] - return _orig_issubclass(subclass, clz) + return orig_func.issubclass(subclass, clz) - @functools.wraps(_orig_getattr) + @functools.wraps(orig_func.getattr) def getattr_wrapper(obj, *args): # TODO: better infomation - if not 1 <= _orig_len(args) <= 2: + if not 1 <= orig_func.len(args) <= 2: raise Exception() - args = _orig_list(args) - if _orig_isinstance(args[0], ep.ConcreteProxy): + args = orig_func.list(args) + if orig_func.isinstance(args[0], ep.ConcreteProxy): args[0] = args[0].value - return _orig_getattr(obj, *args) + return orig_func.getattr(obj, *args) try: with self.patcher: @@ -1232,7 +1186,7 @@ def getattr_wrapper(obj, *args): results = OperatorPatcherContext.patch_run(fn, *args, *more_args, **kwargs) # we should unwrap proxy to the original value in the results when we record it to node.meta['tensor_meta'] def unwrap(obj: Any): - while _orig_isinstance(obj, ep.ConcreteProxy): + while orig_func.isinstance(obj, ep.ConcreteProxy): obj = obj.value return obj self.create_node('output', 'output', (self.create_arg(results),), @@ -1293,7 +1247,7 @@ def _patch_wrapped_functions(patcher : FunctionPatcher): """ for frame_dict, name in _wrapped_fns_to_patch: if name not in frame_dict and hasattr(builtins, name): - orig_fn = _orig_getattr(builtins, name) + orig_fn = orig_func.getattr(builtins, name) else: orig_fn = frame_dict[name] patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) @@ -1323,7 +1277,7 @@ def _autowrap_check(tracer: ConcreteTracer, frame_dict : Dict[str, Any], functio patcher.patch(frame_dict, name, agfunc_dict[value.__self__]) def _create_wrapped_method(cls, name): - orig_fn = _orig_getattr(cls, name) + orig_fn = orig_func.getattr(cls, name) @functools.wraps(orig_fn) def wrapped(*args, **kwargs): @@ -1341,14 +1295,14 @@ def wrapped(*args, **kwargs): return wrapped def _create_wrapped_nn_module_func(tracer: ConcreteTracer, mod: torch.nn.Module, name: str): - orig_fn = _orig_getattr(mod, name) - if not _orig_isinstance(orig_fn, MethodType): + orig_fn = orig_func.getattr(mod, name) + if not orig_func.isinstance(orig_fn, MethodType): raise RuntimeError(f'{tracer.path_of_module(mod)}.{name} is not a bound method, only support wrap bound method.') @functools.wraps(orig_fn) def wrapped(*args, **kwargs): module_qualified_name = tracer.path_of_module(mod) - with ScopeContextManager(tracer.scope, Scope(module_qualified_name, _orig_type(mod))) as _scope: + with ScopeContextManager(tracer.scope, Scope(module_qualified_name, orig_func.type(mod))) as _scope: need_pop = False if _scope.module_path not in tracer.module_stack: need_pop = True @@ -1515,15 +1469,15 @@ def _create_wrapped_leaf_func(tracer: ConcreteTracer, func: Callable, to_func: O def func_wrapper(*args, **kwargs): if tracer.temp_call_origin: return func(*args, **kwargs) - tracers = _orig_set(init_tracers) + tracers = orig_func.set(init_tracers) def unwrap_detect_tracers(obj): if isinstance(obj, ep.ConcreteProxy): tracers.add(obj.tracer) ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) - if _orig_len(tracers) == 0: + if orig_func.len(tracers) == 0: return to_func(*args, **kwargs) - elif _orig_len(tracers) == 1 and next(iter(tracers)) == tracer: + elif orig_func.len(tracers) == 1 and next(iter(tracers)) == tracer: return tracer.create_proxy('call_function', to_func, args, kwargs) else: raise Exception('more than 1 tracer detected. please report the issue') @@ -1534,15 +1488,15 @@ def _create_wrapped_leaf_method(tracer: ConcreteTracer, method, name: str, to_fu def method_wrapper(*args, **kwargs): if tracer.temp_call_origin: return method(*args, **kwargs) - tracers = _orig_set() + tracers = orig_func.set() def unwrap_detect_tracers(obj): if isinstance(obj, ep.ConcreteProxy): tracers.add(obj.tracer) ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) - if _orig_len(tracers) == 0: + if orig_func.len(tracers) == 0: return method(*args, **kwargs) - elif _orig_len(tracers) == 1 and next(iter(tracers)) == tracer: + elif orig_func.len(tracers) == 1 and next(iter(tracers)) == tracer: if to_func is not None: return tracer.create_proxy('call_function', to_func, args, kwargs) else: @@ -1569,15 +1523,15 @@ class clz_wrapper_clz: def __new__(cls, *args, **kwargs): if tracer.temp_call_origin: return clz(*args, **kwargs) - tracers = _orig_set() + tracers = orig_func.set() def unwrap_detect_tracers(obj): if isinstance(obj, ep.ConcreteProxy): tracers.add(obj.tracer) ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) - if _orig_len(tracers) == 0: + if orig_func.len(tracers) == 0: return clz(*args, **kwargs) - elif _orig_len(tracers) == 1 and next(iter(tracers)) == tracer: + elif orig_func.len(tracers) == 1 and next(iter(tracers)) == tracer: return tracer.create_proxy('call_function', clz, args, kwargs) else: raise Exception('more than 1 tracer detected. please report the issue') @@ -1587,9 +1541,9 @@ def __hash__(self): return id(self) for name in dir(clz): - attr = _orig_getattr(clz, name) + attr = orig_func.getattr(clz, name) if not name.startswith('_'): - if _orig_isinstance(attr, Callable): + if orig_func.isinstance(attr, Callable): setattr(clz_wrapper_clz, name, _create_wrapped_leaf_method(tracer, attr, name, None)) else: setattr(clz_wrapper_clz, name, attr) @@ -1612,19 +1566,19 @@ class clz_wrapper_clz: def __new__(cls, *args, **kwargs): if tracer.temp_call_origin: return clz(*args, **kwargs) - tracers = _orig_set() - if _orig_len(args) != 0: - if _orig_isinstance(args[0], ep.Proxy): + tracers = orig_func.set() + if orig_func.len(args) != 0: + if orig_func.isinstance(args[0], ep.Proxy): tracers.add(args[0].tracer) - if _orig_isinstance(args[0], Iterator): + if orig_func.isinstance(args[0], Iterator): args = (clz(args[0]), *args[1:]) - if _orig_isinstance(args[0], Iterable): + if orig_func.isinstance(args[0], Iterable): for item in args[0]: - if _orig_isinstance(item, ep.Proxy): + if orig_func.isinstance(item, ep.Proxy): tracers.add(item.tracer) - if _orig_len(tracers) == 0: + if orig_func.len(tracers) == 0: return clz(*args, **kwargs) - elif _orig_len(tracers) == 1 and next(iter(tracers)) == tracer: + elif orig_func.len(tracers) == 1 and next(iter(tracers)) == tracer: return tracer.create_proxy('call_function', clz, args, kwargs) else: @@ -1635,9 +1589,9 @@ def __hash__(self): return id(self) for name in dir(clz): - attr = _orig_getattr(clz, name) + attr = orig_func.getattr(clz, name) if not name.startswith('_') or name in ('__getitem__', '__setitem__', '__iter__', '__len__'): - if _orig_isinstance(attr, Callable): + if orig_func.isinstance(attr, Callable): setattr(clz_wrapper_clz, name, _create_wrapped_leaf_method(tracer, attr, name, None)) else: setattr(clz_wrapper_clz, name, attr) @@ -1667,13 +1621,13 @@ def _retain_weight_consistency(root: torch.nn.Module): _flag = 0 for module in root.modules(): for name, param in module.named_parameters(): - if _orig_isinstance(param, ep.ConcreteProxy): + if orig_func.isinstance(param, ep.ConcreteProxy): param: ep.ConcreteProxy _logger.warning(f'Parameter {name} of {module} is a ConcreteProxy. Some weight may be modified inplace within forward().') setattr(module, name, param.value) _flag |= 1 for name, buffer in module.named_buffers(): - if _orig_isinstance(buffer, ep.ConcreteProxy): + if orig_func.isinstance(buffer, ep.ConcreteProxy): buffer: ep.ConcreteProxy _logger.warning(f'Buffer {name} of {module} is a ConcreteProxy. Some buffer may be modified inplace within forward().') setattr(module, name, buffer.value) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py index 8bb97921..ff66d44b 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/function_patcher.py @@ -3,7 +3,7 @@ from typing import Any, Callable, List, Dict, NamedTuple from torch.fx._symbolic_trace import _Patcher -from .utils import _orig_reversed +from . import orig_func class _PatchedFnReusable(NamedTuple): frame_dict: Any @@ -99,7 +99,7 @@ def patch_method( def revert(self): if self.in_global_context: self._change_patch_mode_to(False) - for patch in _orig_reversed(self.patches_made): + for patch in orig_func.reversed(self.patches_made): # unpatch in reverse order to handle duplicates correctly patch.revert() try: diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py index 9f9e8a34..abe3d755 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -13,18 +13,11 @@ from textwrap import dedent from types import MethodType, FunctionType -from typing import List, Optional, Callable, Dict, Tuple +from typing import List, Optional, Tuple import torch -from .utils import ( - _orig_type, - _orig_isinstance, - _orig_len, - _orig_dict, - _orig_zip, - _orig_tuple, -) +from . import orig_func _logger = logging.getLogger(__name__) @@ -50,11 +43,11 @@ def visit_IfExp(self, node: ast.IfExp): # For example, # `x[0] if x is not None else None` will raise an error # if we convert it to `nnscaler.runtime.function.ifexpr(x is not None, x[0], None)` - if not _orig_isinstance(node.test, ast.Attribute) \ - or not _orig_isinstance(node.test.value, ast.Name) \ + if not orig_func.isinstance(node.test, ast.Attribute) \ + or not orig_func.isinstance(node.test.value, ast.Name) \ or node.test.value.id != 'self' or node.test.attr != 'training'\ - or any(_orig_isinstance(n, ast.Call) for n in ast.walk(node.body)) \ - or any(_orig_isinstance(n, ast.Call) for n in ast.walk(node.orelse)): + or any(orig_func.isinstance(n, ast.Call) for n in ast.walk(node.body)) \ + or any(orig_func.isinstance(n, ast.Call) for n in ast.walk(node.orelse)): return self.generic_visit(node) self.modified = True @@ -93,7 +86,7 @@ def visit_IfExp(self, node: ast.IfExp): ) def visit_UnaryOp(self, node: ast.UnaryOp): - if _orig_isinstance(node.op, ast.Not): + if orig_func.isinstance(node.op, ast.Not): self.modified = True return self.generic_visit(ast.Call( func=ast.Name(id=self.func_map[ast.Not], ctx=ast.Load()), @@ -104,17 +97,17 @@ def visit_UnaryOp(self, node: ast.UnaryOp): return self.generic_visit(node) def visit_Compare(self, node: ast.Compare): - if not any(_orig_isinstance(op, (ast.Is, ast.IsNot, ast.In, ast.NotIn)) for op in node.ops): + if not any(orig_func.isinstance(op, (ast.Is, ast.IsNot, ast.In, ast.NotIn)) for op in node.ops): return self.generic_visit(node) - if _orig_len(node.ops) != 1: + if orig_func.len(node.ops) != 1: raise RuntimeError('Chained Comparison is not supported') self.modified = True - if _orig_isinstance(node.ops[0], (ast.In, ast.NotIn)): + if orig_func.isinstance(node.ops[0], (ast.In, ast.NotIn)): args = [node.comparators[0], node.left] else: args = [node.left, node.comparators[0]] - if not _orig_isinstance(node.ops[0], ast.NotIn): + if not orig_func.isinstance(node.ops[0], ast.NotIn): ret_node = ast.Call( func=ast.Name(id=self.func_map[type(node.ops[0])], ctx=ast.Load()), args=args, @@ -143,7 +136,7 @@ class SuperTransformer(TrackedTransformer): super() is not supported for a standalone function. """ def visit_Call(self, node: ast.Call): - if _orig_isinstance(node.func, ast.Name) and node.func.id == 'super' and _orig_len(node.args) == 0: + if orig_func.isinstance(node.func, ast.Name) and node.func.id == 'super' and orig_func.len(node.args) == 0: self.modified = True # convert super() to super(self.__class__, self) return self.generic_visit(ast.Call( @@ -172,7 +165,7 @@ def __init__(self, proxy_call_name: str, ignore_funcs: Optional[List[str]] = Non def visit_Call(self, node: ast.Call): # will transform all function call to `proxy_call_name(func_name, *args, **kwargs)` # node.func can be expression, in that case, node.func.id is undefined. - if not _orig_isinstance(node.func, ast.Name) or ( + if not orig_func.isinstance(node.func, ast.Name) or ( node.func.id != self.proxy_call_name and node.func.id not in self.ignore_funcs ): self.modified = True @@ -208,7 +201,7 @@ def __init__(self, use_operator_patch: bool, operator_patch_backlist: List[str]) self.proxy_call_name = OperatorPatcherContext.patch_run.__name__ def patch_func_or_module(self, func_or_module): - if _orig_isinstance(func_or_module, torch.nn.Module): + if orig_func.isinstance(func_or_module, torch.nn.Module): module, func = func_or_module, func_or_module.forward new_func = self.patch_func_helper(func) module.forward = new_func @@ -241,7 +234,7 @@ def patch_func_helper(self, func): if self.use_operator_patch == (func in self.operator_patch_backlist): return func - if _orig_isinstance(func, MethodType): + if orig_func.isinstance(func, MethodType): # patch the function, not bound method, the function will be bound back after patch func_inner = func.__func__ the_self = func.__self__ @@ -249,7 +242,7 @@ def patch_func_helper(self, func): func_inner = func the_self = None # if it is not a function, or it has no code, then we can not patch it, directly return - if not _orig_isinstance(func_inner, FunctionType) or not hasattr(func_inner, '__code__'): + if not orig_func.isinstance(func_inner, FunctionType) or not hasattr(func_inner, '__code__'): return func lines, lnum = inspect.findsource(func_inner) @@ -305,14 +298,14 @@ def patch_func_helper(self, func): closure_dict = {} closures = func_inner.__closure__ co_freevars = func_inner.__code__.co_freevars - if (closures != None and _orig_len(closures) != 0) or _orig_len(co_freevars) != 0: - assert _orig_len(closures) == _orig_len(co_freevars) - closure_dict = _orig_dict(_orig_zip(co_freevars, [c.cell_contents for c in closures])) + if (closures != None and orig_func.len(closures) != 0) or orig_func.len(co_freevars) != 0: + assert orig_func.len(closures) == orig_func.len(co_freevars) + closure_dict = orig_func.dict(zip(co_freevars, [c.cell_contents for c in closures])) tuple_wrapped = tuple try: if sys.version_info < (3, 9): - setattr(builtins, 'tuple', _orig_tuple) + setattr(builtins, 'tuple', orig_func.tuple) var_dict = {} exec( # use func.__code__.co_filename to make the new function easily debuggable. diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py b/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py index 9918446b..3a3aa099 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/orig_func.py @@ -31,6 +31,10 @@ type = builtins.type slice = builtins.slice +all = builtins.all +min = builtins.min +max = builtins.max + # the wrapped functon/class method/class in torch import torch diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py index b925ca7a..6c0da1eb 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py @@ -1,13 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import builtins from dataclasses import dataclass import importlib import operator import traceback from pathlib import Path -from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Type +from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple import torch from torch.fx.node import Node, map_aggregate, _side_effectful_functions @@ -16,49 +15,6 @@ DICT_VALUES_TYPE= type({}.values()) DICT_ITEMS_TYPE= type({}.items()) - -# These need to run in global scope to handle nested calls correctly -_orig_module_call: Callable = torch.nn.Module.__call__ -_orig_module_getattr: Callable = torch.nn.Module.__getattr__ -_orig_module_getattribute: Callable = torch.nn.Module.__getattribute__ - -_orig_agfunc_apply: Callable = torch.autograd.function.Function.apply -_orig_torch_assert: Callable = torch._assert - -_orig_type: Callable = builtins.type -_orig_isinstance: Callable = builtins.isinstance -_orig_issubclass: Callable = builtins.issubclass -_orig_getattr: Callable = builtins.getattr - -_orig_range: Type[Any] = builtins.range -_orig_int: Type[Any] = builtins.int -_orig_float: Type[Any] = builtins.float -_orig_bool: Type[Any] = builtins.bool -_orig_tuple: Type[Any] = builtins.tuple -_orig_list: Type[Any] = builtins.list -_orig_set: Type[Any] = builtins.set -_orig_frozenset: Type[Any] = builtins.frozenset -_orig_dict: Type[Any] = builtins.dict -_orig_map: Type[Any] = builtins.map -_orig_zip: Type[Any] = builtins.zip -_orig_enumerate: Type[Any] = builtins.enumerate -_orig_slice: Type[Any] = builtins.slice -_orig_reversed: Type[Any] = builtins.reversed - -_orig_torch_size: Type[Any] = torch.Size -_orig_torch_finfo: Type[Any] = torch.finfo - -_orig_len: Callable = builtins.len -_orig_not: Callable = operator.not_ -_orig_is: Callable = operator.is_ -_orig_is_not: Callable = operator.is_not -_orig_contains: Callable = operator.contains -_orig_index: Callable = operator.index - -_orig_all: Callable = builtins.all -_orig_min: Callable = builtins.min -_orig_max: Callable = builtins.max - _orig_node_is_impure: Callable = Node.is_impure side_effectful_inplace_ops = { From 4b05bbc7718884efc088936033ec4efb30113465 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 23 Aug 2024 06:21:22 +0000 Subject: [PATCH 1714/1892] Merged PR 2242: Refine requires_grad of input tensors handling 1. Refine requires_grad of input tensors handling. Tracer/parser will respect the requires_grad flag of dummy inputs. 2. Fix IRObject lifecycle problem (just ignore IRObject lifecycle) parity check pass unit test pass --- nnscaler/codegen/lifecycle.py | 51 +++++++++------ nnscaler/codegen/module/module.py | 8 ++- nnscaler/graph/parser/converter.py | 4 -- nnscaler/parallel.py | 3 +- tests/parallel_module/test_gencode.py | 85 ++++++++++++++++++++++++- tests/parallel_module/test_normlayer.py | 21 +++++- 6 files changed, 143 insertions(+), 29 deletions(-) diff --git a/nnscaler/codegen/lifecycle.py b/nnscaler/codegen/lifecycle.py index cd1be218..e2413cde 100644 --- a/nnscaler/codegen/lifecycle.py +++ b/nnscaler/codegen/lifecycle.py @@ -17,7 +17,7 @@ def __init__(self, nodes: List[IRCell], graph_inputs: List[Any], graph_outputs: graph_outputs = IRSegment.get_objects_from_complex(graph_outputs) func_emission = FuncEmission() - self.nodes: Dict[int] = {node: lid for lid, node in enumerate(nodes)} + self.nodes: Dict[IRCell, int] = {node: lid for lid, node in enumerate(nodes)} # the last line id of consuming or producing a tensor self.lifetime: Dict[IRObject, int] = {} # the tensors can be released given the finish of line id @@ -91,33 +91,46 @@ def release_tensors_after_node(self, node: IRCell) -> List[IRSubTensor]: line_id = self.nodes[node] return self.release.get(line_id, []) - def releasable_after_node(self, tensor: IRSubTensor, node: IRCell) -> bool: + def releasable_after_node(self, obj: IRObject, node: IRCell) -> bool: """ - Check if the tensor is releasable after executing the node - - @param tensor IRSubTensor - @param node IRCell - - @return releasable bool + Check if the tensor is releasable after executing the node. + Please note that if it is not a IRSubTensor(is IRObject), + we will never manually release it. + + Args: + tensor (IRObject): the tensor to be checked + node (IRCell): the node to be checked + Returns: + releasable (bool): whether the tensor is releasable after executing """ + if not isinstance(obj, IRSubTensor): + return False + assert node in self.nodes - assert tensor in self.lifetime[tensor] + assert obj in self.lifetime line_id = self.nodes[node] - return self.lifetime[tensor] < line_id + return self.lifetime[obj] < line_id - def releasable_after_line(self, tensor: IRSubTensor, line: int) -> bool: + def releasable_after_line(self, obj: IRObject, line: int) -> bool: """ - Check if the tensor is releasable after executing the node - - @param tensor IRSubTensor - @param line int - - @return releasable bool + Check if the tensor is releasable after executing the node. + Please note that if it is not a IRSubTensor(is IRObject), + we will never manually release it. + + Args: + tensor (IRObject): the tensor to be checked + line (int): the line to be checked + Returns: + releasable (bool): whether the tensor is releasable after specific line. """ - return self.lifetime[tensor] < line + if not isinstance(obj, IRSubTensor): + return False + + assert obj in self.lifetime + return self.lifetime[obj] < line def get_line(self, node: IRCell) -> int: """ Get line id of the node """ - return self.nodes[node] \ No newline at end of file + return self.nodes[node] diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 1e5a2e25..d73f9941 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -699,7 +699,13 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: reducer=reducer_name, ranks=ranks, reduce_op=reduce_op, async_op=async_op, zero=zero, max_nbytes=max_nbytes, zero_ngroups=zero_ngroups) self.model_init_statements.append(init_code) - weights = [self.tensor_name(t, prefix_attr='self.') for t in weights] + # sorted with tid: to make the order of weights deterministic + # different order of weights in reducer can lead different all-reduce result + # not sure why (may be the result of ring-allreduce). but it is observed in the test, and causes parity check failed in the test + weights = [ + self.tensor_name(t, prefix_attr='self.') + for t in sorted(weights, key=lambda x: x.tid) + ] for weight in weights: add_param_code = add_param.format(reducer=reducer_name, weight=weight) self.model_init_statements.append(add_param_code) diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 517b3b1f..bfd2f0f7 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -135,10 +135,6 @@ def to_ir_graph( ) module_name = traced_model.__class__.__name__ - for input in inputs: - if isinstance(input, IRFullTensor): - input.requires_grad = False - graph = IRGraph.from_logic_graph(nodes, inputs, outputs, module_name) return graph diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index e8e75353..9481130a 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -308,7 +308,8 @@ def _to_cpu(val: Any): if isinstance(val, set): return {_to_cpu(t) for t in val} if isinstance(val, torch.Tensor): - return val.cpu() + requires_grad = val.is_floating_point() or val.is_complex() + return val.cpu().requires_grad_(requires_grad) return val diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index c07d75f6..a4295d49 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -259,6 +259,89 @@ def test_codegen_unused_args2(): launch_torchrun(1, _gencode_unused_args_worker2, tempdir) +def pas_dp_with_recompute(graph, cfg): + """ + pure data parallelism policy + """ + from nnscaler.ir import IRFwOperation, IRDataOperation + from nnscaler.policies import _replica + ngpus = cfg.plan_ngpus + if ngpus != 1: + raise ValueError("Data parallelism only supports 1 plan GPU") + + # combine + # x = _add(x, v1) + # x = x + v2 + # x = self.linear(x) + # together as a recompute unit + graph.recompute([ + *graph.select(name='_add', ntype=IRFwOperation), + *graph.select(name='add', ntype=IRFwOperation), + *graph.select(name='linear', ntype=IRFwOperation), + ]) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + _replica(graph, node, [0]) + return graph + + +@nnscaler.register_op('* -> *') +def _add(x, k): + return x + k + + +class RecomputeKwArgsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 5) + + def forward(self, x, v1, v2): + x = _add(x, v1) # v1 will be kwargs + x = x + v2 # v2 will be normal args + x = self.linear(x) + x = x - v2 + return x + + +@replace_all_device_with('cpu') +def test_codegen_recompute_kwargs(): + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + RecomputeKwArgsModule(), + { + 'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + 'v1': 1.0, + 'v2': 2.0, + }, + pas_dp_with_recompute, + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False + ) + # It will look like + # def segment33(self, x_36, v1_38, v2_39): + # def recompute(x_36, v1_38, v2_39): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 237, in forward, x = _add(x, v1) # v1 will be kwargs + # _add_29 = tests.parallel_module.test_gencode._add(x_36, k=v1_38) + # del x_36 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 238, in forward, x = x + v2 # v2 will be normal args + # add_30 = torch.add(_add_29, v2_39, alpha=1) + # del _add_29 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 239, in forward, x = self.linear(x) + # linear_33 = torch.nn.functional.linear(add_30, self.linear_weight_31, self.linear_bias_32) + # del add_30 + # return linear_33 + # linear_33 = ckpt.checkpoint(recompute, x_36, v1_38, v2_39, use_reentrant=False) + # del x_36 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 240, in forward, x = x - v2 + # sub_28 = torch.sub(linear_33, v2_39, alpha=1) + # del linear_33 + # return sub_28 + assert _gencode_contains(tempdir, RecomputeKwArgsModule, 0, + r'def recompute\(x_\d+, v1_\d+, v2_\d+\)' + ) + + + class DefaultArgsModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1110,7 +1193,7 @@ def test_codegen_scalar_tensor(tmp_path): assert _gencode_contains(tmp_path, ScalarTensorModule, 0, r"self\.register_buffer\('num_batches_tracked_\d+', torch\.empty\(\(\), dtype=torch\.int64\), persistent=True\)") assert _gencode_contains(tmp_path, ScalarTensorModule, 0, - r"self\.add_full_map\('num_batches_tracked_\d+', 2, False, 'num_batches_tracked', \(\), \.\.\., 1\)") + r"self\.add_full_map\('num_batches_tracked_\d+', \d+, False, 'num_batches_tracked', \(\), \.\.\., 1\)") class ConvTranspose1DModule(torch.nn.Module): diff --git a/tests/parallel_module/test_normlayer.py b/tests/parallel_module/test_normlayer.py index 6e590e19..2a4f063d 100644 --- a/tests/parallel_module/test_normlayer.py +++ b/tests/parallel_module/test_normlayer.py @@ -95,7 +95,13 @@ def _gencode_batchnorm2d_function(tempdir, config, pas_policy): bn.train() ref_output = bn(x) assert torch.equal( - m_new.bn_running_mean_22, bn.bn.running_mean + [y for x, y in m_new.named_buffers() if x.startswith('bn_running_mean_')][0], + bn.bn.running_mean + ), "Custom output does not match PyTorch output" + + assert torch.equal( + [y for x, y in m_new.named_buffers() if x.startswith('bn_running_var_')][0], + bn.bn.running_var ), "Custom output does not match PyTorch output" assert torch.equal( @@ -282,9 +288,18 @@ def _gencode_batchnorm2d_function_eval(tempdir, config, pas_policy): bn = BatchNorm2dModule().cuda() bn.eval() ref_output = bn(x) + + assert torch.equal( + [y for x, y in m_new.named_buffers() if x.startswith('bn_running_mean_')][0], + bn.bn.running_mean + ), "Custom output does not match PyTorch output" + assert torch.equal( - m_new.bn_running_mean_22, bn.bn.running_mean + [y for x, y in m_new.named_buffers() if x.startswith('bn_running_var_')][0], + bn.bn.running_var ), "Custom output does not match PyTorch output" + + assert torch.equal( output, ref_output ), "Custom output does not match PyTorch output" @@ -412,7 +427,7 @@ def _gencode_batchnorm2d_function_eval_4(tempdir, config, pas_policy, dim): assert torch.allclose( y_output, ref_output, 1e-6 ), "Custom output does not match PyTorch output" - + x = torch.chunk(x_list[rank_id // 2], 2, dim=0)[rank_id % 2] bn = BatchNorm2dModule().to(device) From 2ac649997f0884486d25f58ba105ae0505a09da4 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Mon, 26 Aug 2024 08:59:06 +0000 Subject: [PATCH 1715/1892] Merged PR 2245: Hotfix for reducer generation. --- nnscaler/codegen/module/module.py | 8 +--- .../lightning/pytorch/test_strategy.py | 37 ++++++++++++++++++- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index d73f9941..1e5a2e25 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -699,13 +699,7 @@ def init_reducer(self, node: IRWeightReducer, device: int) -> None: reducer=reducer_name, ranks=ranks, reduce_op=reduce_op, async_op=async_op, zero=zero, max_nbytes=max_nbytes, zero_ngroups=zero_ngroups) self.model_init_statements.append(init_code) - # sorted with tid: to make the order of weights deterministic - # different order of weights in reducer can lead different all-reduce result - # not sure why (may be the result of ring-allreduce). but it is observed in the test, and causes parity check failed in the test - weights = [ - self.tensor_name(t, prefix_attr='self.') - for t in sorted(weights, key=lambda x: x.tid) - ] + weights = [self.tensor_name(t, prefix_attr='self.') for t in weights] for weight in weights: add_param_code = add_param.format(reducer=reducer_name, weight=weight) self.model_init_statements.append(add_param_code) diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index 04fdcc41..350a5821 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -1,6 +1,8 @@ +from contextlib import contextmanager import os from pathlib import Path import math +from typing import Dict, List import torch from lightning import Trainer @@ -198,7 +200,38 @@ def on_train_step_end(trainer: 'Trainer', outputs, batches, idx: int) -> None: _correctnes_worker_single_loss_history.append(outputs[0].item()) - +_mocked_params: Dict[int, List[torch.Tensor]] = {} +@contextmanager +def mock_reducer_add_param(): + """ + Reorder the parameters in the reducer to match the order in the model + """ + from nnscaler.runtime.adapter.reducer import Reducer + from nnscaler.runtime.module import CubeModule + def add_param(self, param): + if id(self) not in _mocked_params: + _mocked_params[id(self)] = [] + _mocked_params[id(self)].append(param) + old_add_param = Reducer.add_param + old_add_reducer = CubeModule.add_reducer + Reducer.add_param = add_param + def add_reducer(self, reducer): + register_parameters = {} + for idx, p in enumerate(self.parameters()): + register_parameters[id(p)] = idx + if id(reducer) in _mocked_params: + _mocked_params[id(reducer)].sort(key=lambda x: register_parameters[id(x)]) + for p in _mocked_params[id(reducer)]: + old_add_param(reducer, p) + _mocked_params.pop(id(reducer)) + old_add_reducer(self, reducer) + CubeModule.add_reducer = add_reducer + yield + Reducer.add_param = old_add_param + CubeModule.add_reducer = old_add_reducer + + +@mock_reducer_add_param() def correctnes_worker_cli( tmp_path, gradient_clip_val, @@ -294,6 +327,7 @@ def on_val_step_end(trainer: Trainer, outputs, batches, idx) -> None: _correctnes_worker_single_loss_history +@mock_reducer_add_param() def correctnes_worker_nnscaler(tmp_path, gradient_clip_val, with_lr_scheduler, precision='32-true', with_tp=False, with_empty_scaler=False @@ -333,6 +367,7 @@ def correctnes_worker_nnscaler(tmp_path, gradient_clip_val, with_lr_scheduler, return model.update_history, model.nnscaler_pmodule.fullmap, model.val_loss_history, model.loss_history +@mock_reducer_add_param() def correctnes_worker_nnscaler_checkpoint(tmp_path, gradient_clip_val, with_lr_scheduler, precision='32-true', with_tp=False, with_empty_scaler=False From ce8ff4bff30c00576d6138a65bdcbc0167ab7c26 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Wed, 28 Aug 2024 06:04:06 +0000 Subject: [PATCH 1716/1892] Merged PR 2246: refine tracer wrap logic --- nnscaler/graph/parser/converter.py | 6 +- .../fx/concrete_trace_utils/__init__.py | 1 + .../fx/concrete_trace_utils/concrete_proxy.py | 11 +- .../concrete_trace_utils/concrete_tracer.py | 1063 +++-------------- .../concrete_trace_utils/operator_patcher.py | 4 +- .../fx/concrete_trace_utils/wrap_utils.py | 602 ++++++++++ nnscaler/graph/parser/register.py | 2 +- tests/utils.py | 10 +- 8 files changed, 778 insertions(+), 921 deletions(-) create mode 100644 nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index bfd2f0f7..f09ea014 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -11,7 +11,7 @@ from nnscaler.graph.parser.fx.parser import FxModuleParser from nnscaler.graph.parser.fx.concrete_trace_utils import concrete_trace -from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import Location, is_autograd_apply, LeafFnWrapInfo +from nnscaler.graph.parser.fx.concrete_trace_utils.wrap_utils import Location, is_autograd_apply, LeafWrapInfo from nnscaler.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops import nnscaler.runtime.function as cube_rt_function @@ -77,12 +77,12 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: autowrap_funcs = [CustomizedOps.kOpRuntime[sign] for sign in CustomizedOps.kOpMap] # filter out torch.autograd.Function.apply as concrete trace already treats them as leaf function autowrap_funcs = [fn for fn in autowrap_funcs if not is_autograd_apply(fn)] - leaf_functions = {func: LeafFnWrapInfo([], True, None) for func in autowrap_funcs if func is not None} + leaf_functions = {func: LeafWrapInfo([], True, None) for func in autowrap_funcs if func is not None} # get cube runtime functions cube_rt_funcs = [cube_rt_function.anchor, cube_rt_function.ifexpr] leaf_functions.update({ - func: LeafFnWrapInfo([Location(cube_rt_function, func.__name__)], True, None) + func: LeafWrapInfo([Location(cube_rt_function, func.__name__)], True, None) for func in cube_rt_funcs }) dce_ignored_funcs = set(cube_rt_funcs) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py b/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py index 630bfd03..279b44c1 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py @@ -12,4 +12,5 @@ More information about concrete tracing can be found in the :func:`concrete_trace` documentation. """ from .concrete_tracer import ConcreteTracer, concrete_trace +from .concrete_proxy import ConcreteProxy from .utils import ExtraSEFPatcher, TensorMetadata diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index e4012cb6..95b416d5 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -6,7 +6,6 @@ import dis import logging import inspect -import operator from typing import List, Optional, Iterable, Any, Set, Union @@ -143,10 +142,10 @@ def __len__(self) -> Union[int, ConcreteProxy]: return self.tracer.create_proxy('call_function', orig_func.len, (self,), {}) def __getitem__(self, *args, **kwargs) -> ConcreteProxy: - return self.tracer.create_proxy('call_function', operator.getitem, (self,) + args, kwargs) + return self.tracer.create_proxy('call_function', orig_func.getitem, (self,) + args, kwargs) def __setitem__(self, *args, **kwargs) -> ConcreteProxy: - return self.tracer.create_proxy('call_function', operator.setitem, (self,) + args, kwargs) + return self.tracer.create_proxy('call_function', orig_func.setitem, (self,) + args, kwargs) def __bool__(self) -> Union[bool, ConcreteProxy]: # to detect if in executing branch condition @@ -385,7 +384,7 @@ def __next__(self): self.index += 1 if self.index == self.len: raise StopIteration() - return self.tracer.create_proxy('call_function', operator.getitem, (self.root, self.index), {}) + return self.tracer.create_proxy('call_function', orig_func.getitem, (self.root, self.index), {}) @compatibility(is_backward_compatible=True) def map_aggregate_not_proxy(a, fn): @@ -430,7 +429,7 @@ def map_aggregate_not_proxy(a, fn): def _scope(method): def impl(*args, **kwargs): tracer = args[0].tracer - target = orig_func.getattr(operator, method) + target = orig_func.getattr(orig_func, method) return tracer.create_proxy('call_function', target, args, kwargs) impl.__name__ = method as_magic = f'__{method.strip("_")}__' @@ -442,7 +441,7 @@ def _define_reflectable(orig_method_name): method_name = f'__r{orig_method_name.strip("_")}__' def impl(self, rhs): - target = orig_func.getattr(operator, orig_method_name) + target = orig_func.getattr(orig_func, orig_method_name) return self.tracer.create_proxy('call_function', target, (rhs, self), {}) impl.__name__ = method_name impl.__qualname__ = method_name diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index 054b3a90..fb5f707c 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -11,17 +11,14 @@ import operator import functools import builtins -from dataclasses import dataclass, field -from itertools import chain from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType -from typing import Any, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, List, Callable, Union +from typing import Any, Dict, Optional, Set, Tuple, Type, List, Callable, Union from contextlib import contextmanager import torch -from torch._C import ScriptObject, ScriptFunction +from torch._C import ScriptObject from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict -from torch.utils._pytree import tree_flatten, tree_unflatten import torch.fx from torch.fx import GraphModule @@ -83,7 +80,7 @@ def __exit__(self, *args): return from . import concrete_proxy as ep -from . import pytree_utils, orig_func +from . import pytree_utils, orig_func, wrap_utils from .function_patcher import FunctionPatcher from .operator_patcher import OperatorPatcherContext from .utils import ( @@ -100,46 +97,6 @@ def __exit__(self, *args): HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS -@dataclass -class Location: - """ - The place a function/class locates. - Please note one function/class can be in multiple places. - Take `torch.meshgrid` for example, there are `torch.meshgrid`, 'torch.functional.meshgrid', 'torch._C._VariableFunctions.meshgrid', - """ - ns: Union[Type, ModuleType, Any] # the namespace of the name. It can be a class/module, etc. - name: str - - -@dataclass -class LeafFnWrapInfo: - """ - extra_locs: The place the function is imported. - is_force_trace: If set to false, the function will only be traced if inputs include proxy. - Such as 'torch.rand', we should trace it even if it doesn't have proxy as input, so it should be force traced. - replace_fn: If not `None`, we will use it to replace the original function in traced code. - Such as ModuleList.__getitem__, we can use operator.getitem to replace it. - """ - extra_locs: List[Location] = field(default_factory=list) - is_force_trace: bool = False - replace_fn: Optional[Callable] = None - - -@dataclass -class LeafClassWrapInfo: - """ - extra_locs: The place the class is imported. - is_iterable: Is the class init from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True. - """ - extra_locs: List[Location] = field(default_factory=list) - is_iterable: bool = False - - -def is_autograd_apply(func) -> bool: - return getattr(func, '__name__', None) == 'apply' \ - and orig_func.isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) - - @compatibility(is_backward_compatible=True) class ConcreteTracer(TracerBase): """ @@ -150,121 +107,6 @@ class ConcreteTracer(TracerBase): default_module_getattr = ( 'training', ) - default_autowrap_modules = ( - 'math', - ) - default_autowrap_leaf_function: Dict[Any, LeafFnWrapInfo] = { - # function - orig_func.len: LeafFnWrapInfo([], False, None), - orig_func.not_: LeafFnWrapInfo([], False, None), - orig_func.is_: LeafFnWrapInfo([], False, None), - orig_func.is_not: LeafFnWrapInfo([], False, None), - orig_func.contains: LeafFnWrapInfo([], False, None), - orig_func.index: LeafFnWrapInfo([], False, None), - orig_func.all: LeafFnWrapInfo([], False, None), - orig_func.min: LeafFnWrapInfo([], False, None), - orig_func.max: LeafFnWrapInfo([], False, None), - - # force-traced function (the factory functions of tensor creation) - torch.arange: LeafFnWrapInfo([], True, None), - torch.empty: LeafFnWrapInfo([], True, None), - torch.eye: LeafFnWrapInfo([], True, None), - torch.full: LeafFnWrapInfo([], True, None), - torch.linspace: LeafFnWrapInfo([], True, None), - torch.logspace: LeafFnWrapInfo([], True, None), - torch.ones: LeafFnWrapInfo([], True, None), - torch.rand: LeafFnWrapInfo([], True, None), - torch.randint: LeafFnWrapInfo([], True, None), - torch.randn: LeafFnWrapInfo([], True, None), - # torch.rand_like: LeafFnWrapInfo([], True, None), # seems that xxx_like will not directly call torch._TensorBase.xxx - # torch.randn_like: LeafFnWrapInfo([], True, None), - # torch.randint_like: LeafFnWrapInfo([], True, None), - torch.randperm: LeafFnWrapInfo([], True, None), - torch.tensor: LeafFnWrapInfo([], True, None), - torch.zeros: LeafFnWrapInfo([], True, None), - - # method - Sequential.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - Sequential.__len__: LeafFnWrapInfo([], False, orig_func.len), - Sequential.__iter__: LeafFnWrapInfo([], False, iter), - - ModuleList.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - ModuleList.__len__: LeafFnWrapInfo([], False, orig_func.len), - ModuleList.__iter__: LeafFnWrapInfo([], False, iter), - - ModuleDict.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - ModuleDict.__len__: LeafFnWrapInfo([], False, orig_func.len), - ModuleDict.__iter__: LeafFnWrapInfo([], False, iter), - ModuleDict.__contains__: LeafFnWrapInfo([], False, orig_func.contains), - - ParameterList.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - ParameterList.__len__: LeafFnWrapInfo([], False, orig_func.len), - ParameterList.__iter__: LeafFnWrapInfo([], False, iter), - - ParameterDict.__getitem__: LeafFnWrapInfo([], False, operator.getitem), - ParameterDict.__len__: LeafFnWrapInfo([], False, orig_func.len), - ParameterDict.__iter__: LeafFnWrapInfo([], False, iter), - ParameterDict.__contains__: LeafFnWrapInfo([], False, orig_func.contains), - } - # equals to `from torch.nn import functional as nn_functional` - # to pass pyright check - nn_functional = getattr(torch.nn, 'functional') - # order: torch.nn.functional > torch._C._VariableFunctions > torch._C._nn > torch._C._TensorBase - for name in torch.functional.__all__: - attr = getattr(torch.functional, name) - if attr not in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, attr) - for name in dir(nn_functional): - attr = getattr(nn_functional, name) - if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__')\ - and getattr(attr, '__module__', None) not in ('typing', 'torch.nn.modules.utils'): - if attr not in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) - if hasattr(attr, '__module__') and attr.__module__ != 'torch.nn.functional': - default_autowrap_leaf_function[attr].extra_locs.append(Location(nn_functional, name)) - for name in dir(torch._C._VariableFunctions): - attr = getattr(torch._C._VariableFunctions, name) - if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__'): - if attr not in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) - for name in dir(torch._C._nn): - attr = getattr(torch._C._nn, name) - if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__'): - if attr not in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, getattr(torch.functional, name, None)) - if hasattr(attr, '__module__') and attr.__module__ != 'torch._C._nn': - default_autowrap_leaf_function[attr].extra_locs.append(Location(torch._C._nn, name)) - for name in dir(torch._C._TensorBase): - attr = getattr(torch._C._TensorBase, name) - if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__'): - if attr not in default_autowrap_leaf_function: - to_func = getattr(torch.Tensor, name, None) - to_func = None if to_func == attr else to_func - default_autowrap_leaf_function[attr] = LeafFnWrapInfo([], False, to_func) - # find the multi position for default_autowrap_leaf_function in torch.__dir__() - for name in dir(torch): - attr = getattr(torch, name) - if callable(attr) and not orig_func.isinstance(attr, Type) and not name.startswith('__') \ - and attr in default_autowrap_leaf_function: - default_autowrap_leaf_function[attr].extra_locs.append(Location(torch, name)) - - default_autowrap_leaf_class: Dict[Type, LeafClassWrapInfo] = { - # class - orig_func.bool: LeafClassWrapInfo([], False), - orig_func.int: LeafClassWrapInfo([], False), - orig_func.float: LeafClassWrapInfo([], False), - - # iterable class - orig_func.tuple: LeafClassWrapInfo([], True), - orig_func.list: LeafClassWrapInfo([], True), - orig_func.set: LeafClassWrapInfo([], True), - orig_func.frozenset: LeafClassWrapInfo([], True), - orig_func.dict: LeafClassWrapInfo([], True), - orig_func.reversed: LeafClassWrapInfo([], False), - - orig_func.torch_Size: LeafClassWrapInfo([], False), - orig_func.torch_finfo: LeafClassWrapInfo([], False), - } @compatibility(is_backward_compatible=True) def __init__(self, cpu_offload = False, record_frames = False): @@ -310,7 +152,7 @@ def fetch_attr(self, target: str) -> Any: """ to get the attr in self.root. only for execution of 'call_module' nodes. """ - with self.do_temp_call_origin(): + with wrap_utils.do_temp_call_origin(): target_atoms = target.split('.') attr_itr = self.root for i, atom in orig_func.enumerate(target_atoms): @@ -324,7 +166,6 @@ def fetch_attr(self, target: str) -> Any: def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]): """ actually execute the code. - apply the patcher, and the _autowrap_check to the target function. """ if kind == 'output': return args[0], args, kwargs @@ -434,7 +275,7 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: use the 'run_target' to actually execute the code, and store the value in 'value' field. create the nodes for the target and the input of the target (if the target is one of call_method, call_function, call_module). """ - with self.do_temp_call_origin(): + with wrap_utils.do_temp_call_origin(): def unwrap_nested_proxy(proxy: ep.ConcreteProxy): return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) @@ -448,7 +289,7 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): value_unwrapped, args_run, kwargs_run = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) # because setitem is an inplace operation and will not return the obj, so here is a workaound to record node result - node_result = args_run[0] if kind == "call_function" and target == operator.setitem else value_unwrapped + node_result = args_run[0] if kind == "call_function" and target == orig_func.setitem else value_unwrapped # here update the origin args/kwargs to prevent inplace operator to the input args = update_tree_proxy_value(args, args_run) kwargs = update_tree_proxy_value(kwargs, kwargs_run) @@ -461,7 +302,7 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): node = self.create_node(kind, target, args_, kwargs_, name, type_expr, node_result) if self.record_frames and kind != 'placeholder': - with self.do_temp_call_origin(): + with wrap_utils.do_temp_call_origin(): node.meta['frame_record'] = get_frame_record() proxy = self.proxy(value_unwrapped, node) @@ -559,43 +400,9 @@ def create_arg(self, a: Any) -> Union[Node, Any]: @compatibility(is_backward_compatible=True) def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: """ - similar to _symbolic_trace.Tracer.is_leaf_module - """ - return (m.__module__.startswith('torch.nn.functional') and not orig_func.isinstance(m, (Sequential, ModuleList, ModuleDict)))\ - or orig_func.isinstance(m, self.leaf_module) - - @compatibility(is_backward_compatible=True) - def path_of_module(self, mod: torch.nn.Module) -> str: - """ - similar to _symbolic_trace.Tracer.path_of_module + in nnscaler, will unpack all module to functions, so always return False """ - # Prefer the O(1) algorithm - if self.submodule_paths: - path = self.submodule_paths.get(mod) - # TODO: better infomation - if path is None: - if not hasattr(self.root, '_module_constants'): - self.root._module_constants = torch.nn.ModuleList() - module_constants = self.root._module_constants - assert isinstance(module_constants, torch.nn.ModuleList) - if hasattr(mod, 'extra_repr'): - sub_path = orig_func.type(mod).__name__ + mod.extra_repr() - else: - sub_path = str(orig_func.len(module_constants)) - if not hasattr(module_constants, sub_path): - module_constants.add_module(sub_path, mod) - path = '_module_constants.%s' % sub_path - self.submodule_paths[mod] = path - return path - assert isinstance(path, str) - return path - # O(N^2) fallback in the case that we didn't store the submodule - # paths. - else: - for n, p in self.root.named_modules(): - if mod is p: - return n - raise NameError('module is not installed as a submodule') + return False # This method will be refactored @compatibility(is_backward_compatible=False) @@ -671,13 +478,126 @@ def proxy_placeholder(name: str): return root_fn, args, more_args, kwargs + def get_wrapped_leaves(self, leaf_functions: Dict[Callable, wrap_utils.LeafWrapInfo], leaf_class: Dict[ModuleType, wrap_utils.LeafWrapInfo]): + wrapped_leaf_leaves = {} + for func, wrap_info in leaf_functions.items(): + locations = tuple(wrap_info.extra_locs) + if wrap_utils.is_autograd_apply(func): + # torch.autograd.function + assert wrap_info.replacement == None, '.apply should set to_func to None!' + if func.__self__ not in self.autograd_functions_mapping: + self.autograd_functions_mapping[func.__self__] = wrap_utils.create_wrapped_leaf_func(func) + wrapped = self.autograd_functions_mapping[func.__self__] + elif isinstance(func, torch._C.ScriptFunction): + # if it is a script function, + # here will wrap the origin function location and forward the script function to the origin one. + # _torchdynamo_inline is introduced in pytorch 2.0, it is the original function of the script function. + inner_func = func._torchdynamo_inline + # some `func.__module__` may have additional `_` compare with its import path in user code, + # for example, `operator.add.__module__` is `_operator` and `_operator` is a built-in module and we don't want to touch it, + # we assume user won't import function from module named with prefix `_`, + # here we only wrap the function under no prefix `_` module, i.e. functions under `operator`. + if inner_func.__module__.startswith('_') and inner_func.__module__ != '__main__': + path = sys.modules.get(inner_func.__module__[1:], sys.modules[inner_func.__module__]) + else: + path = sys.modules[inner_func.__module__] + locations = (*locations, wrap_utils.Location(path, inner_func.__name__)) + wrapped = wrap_utils.create_wrapped_leaf_func( + func, + replace_func=inner_func, + default_tracer=self if wrap_info.is_force_trace else None, + ) + else: + # 'TensorBase': torch >= 2.3, '_TensorBase': torch < 2.3 + if func.__qualname__.startswith('_TensorBase') or func.__qualname__.startswith('TensorBase'): + locations = (*locations, wrap_utils.Location(torch.Tensor, func.__name__)) + wrapped = wrap_utils.create_wrapped_leaf_func( + getattr(torch.Tensor, func.__name__), + replace_func=wrap_info.replacement, + default_tracer=self if wrap_info.is_force_trace else None, + is_method=True, + method_name=func.__name__, + ) + elif func.__qualname__.startswith('_VariableFunctionsClass'): + if hasattr(torch, func.__name__) and getattr(torch, func.__name__) == func: + # avoid bad attr like 'unique_dim' + locations = (*locations, wrap_utils.Location(torch, func.__name__)) + wrapped = wrap_utils.create_wrapped_leaf_func( + func, + replace_func=wrap_info.replacement, + default_tracer=self if wrap_info.is_force_trace else None, + ) + elif isinstance(func, (MethodDescriptorType, MethodWrapperType)): + wrapped = wrap_utils.create_wrapped_leaf_func( + func, + replace_func=wrap_info.replacement, + default_tracer=self if wrap_info.is_force_trace else None, + is_method=True, + method_name=func.__name__, + ) + elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn' \ + and not func.__qualname__.startswith('PyCapsule'): + # method + # in torch >= 2.2, we found two functions under torch._C has no __module__: + # + # + if func.__module__ is not None and func.__module__ in sys.modules: + if func.__module__.startswith('_') and func.__module__ != '__main__': + path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) + else: + path = sys.modules[func.__module__] + path = getattr(path, func.__qualname__.split('.')[0]) + locations = (*locations, wrap_utils.Location(path, func.__name__)) + if len(locations) == 0: + _logger.warning(f'Can not find location of {func}, skip wrap it.') + continue + wrapped = wrap_utils.create_wrapped_leaf_func( + func, + replace_func=wrap_info.replacement, + default_tracer=self if wrap_info.is_force_trace else None, + is_method=True, + method_name=func.__name__, + ) + else: + # common function + # in torch >= 2.2, we found two functions under torch._C has no __module__: + # + # + if func.__module__ is not None and func.__module__ in sys.modules: + if func.__module__.startswith('_') and func.__module__ != '__main__': + path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) + else: + path = sys.modules[func.__module__] + locations = (*locations, wrap_utils.Location(path, func.__name__)) + if len(locations) == 0: + _logger.warning(f'Can not find location of {func}, skip wrap it.') + continue + wrapped = wrap_utils.create_wrapped_leaf_func( + func, + replace_func=wrap_info.replacement, + default_tracer=self if wrap_info.is_force_trace else None, + ) + wrapped_leaf_leaves[func] = (locations, wrapped) + + for clz, wrap_info in leaf_class.items(): + if clz.__module__.startswith('_') and clz.__module__ != '__main__': + path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) + else: + path = sys.modules[clz.__module__] + wrapped = wrap_utils.create_wrapped_leaf_class( + clz, + replace_cls=wrap_info.replacement, + default_tracer=self if wrap_info.is_force_trace else None, + ) + locations = (*wrap_info.extra_locs, wrap_utils.Location(path, clz.__name__)) + wrapped_leaf_leaves[clz] = (locations, wrapped) + + return wrapped_leaf_leaves + @compatibility(is_backward_compatible=True) def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, - autowrap_modules: Tuple[str] | None = None, - autowrap_leaf_function: Optional[Dict[Any, LeafFnWrapInfo]] = None, - autowrap_leaf_class: Optional[Dict[Type, LeafClassWrapInfo]] = None, - leaf_module = None, - fake_middle_class = None, + autowrap_leaf_function: Optional[Dict[Any, wrap_utils.LeafWrapInfo]] = None, + autowrap_leaf_class: Optional[Dict[Type, wrap_utils.LeafWrapInfo]] = None, concrete_args: Union[Dict[str, Any], Tuple], use_operator_patch: bool = True, operator_patch_backlist: List[str] | None = None, @@ -720,28 +640,12 @@ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, } # preprocess arguments - autowrap_modules = autowrap_modules if autowrap_modules is not None else tuple() autowrap_leaf_function = autowrap_leaf_function if autowrap_leaf_function is not None else {} autowrap_leaf_class = autowrap_leaf_class if autowrap_leaf_class is not None else {} - leaf_module = leaf_module if leaf_module is not None else () - fake_middle_class = fake_middle_class if fake_middle_class is not None else () operator_patch_backlist = operator_patch_backlist if operator_patch_backlist is not None else [] - # Python modules to apply autowrap to at the start, in addition to - # modules we see while tracing - self._autowrap_search: List[ModuleType] = list( - sys.modules[m] for m in (*autowrap_modules, *ConcreteTracer.default_autowrap_modules) - ) - # Functions we will eagerly wrap when we see them while tracing - # this captures both `math.sqrt()` and `from math import sqrt` automatically - self._autowrap_function_ids: Set[int] = { - id(value) for name, value in chain(*[m.__dict__.items() for m in self._autowrap_search]) - if not name.startswith("_") and callable(value)} - self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None - self.autowrap_leaf_function = {**autowrap_leaf_function, **ConcreteTracer.default_autowrap_leaf_function} - self.autowrap_leaf_class = {**autowrap_leaf_class, **ConcreteTracer.default_autowrap_leaf_class} - self.leaf_module = leaf_module - self.fake_middle_class = fake_middle_class + self.autowrap_leaf_function = {**autowrap_leaf_function, **wrap_utils.default_autowrap_leaf_function} + self.autowrap_leaf_class = {**autowrap_leaf_class, **wrap_utils.default_autowrap_leaf_class} if isinstance(root, torch.nn.Module): self.root = root @@ -751,7 +655,6 @@ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, ), f"traced_func_name={forward_function_name} doesn't exist in {orig_func.type(root).__name__}" fn = getattr(root, forward_function_name) - self.submodule_paths = {mod: name for name, mod in root.named_modules()} else: self.root = torch.nn.Module() fn = root @@ -779,409 +682,53 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): assert isinstance(fn, FunctionType) fn_globals = fn.__globals__ # run before it gets patched - fn, args, more_args, kwargs = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args) - self.the_path_of_parameter = {id(v): k for k, v in self.root.named_parameters()} - self.the_path_of_buffer = {id(v): k for k, v in self.root.named_buffers()} - - def get_middle_class(node, memo = set(), prefix = ''): - if node not in memo: - memo.add(node) - yield prefix, node - if isinstance(node, torch.nn.Module): - items = (*((k, v) for k, v in node.__dict__.items() if not k.startswith('_')), *node._modules.items()) - else: - items = ((k, v) for k, v in node.__dict__.items() if not k.startswith('_')) - for name, subfield in items: - if isinstance(subfield, (torch.nn.Module, self.fake_middle_class)): - submodule_prefix = prefix + ('.' if prefix else '') + name - for m in get_middle_class(subfield, memo, submodule_prefix): - yield m - self.the_path_of_middle_class = {id(v): k for k, v in get_middle_class(self.root)} - - @functools.wraps(orig_func.torch_module_getattribute) - def module_getattribute_wrapper(mod, attr): - if self.temp_call_origin: - try: - return orig_func.torch_module_getattribute(mod, attr) - except AttributeError: - return orig_func.torch_module_getattr(mod, attr) - with self.do_temp_call_origin(): - try: - attr_val = orig_func.torch_module_getattribute(mod, attr) - except AttributeError: - attr_val = orig_func.torch_module_getattr(mod, attr) - if orig_func.isinstance(attr_val, ep.ConcreteProxy): - warn_msg = f'Detected {self.the_path_of_middle_class[id(mod)]}.{attr} is a ConcreteProxy, ' + \ - 'this is usually caused by directly assigning the return value of some leaf function to the attribute of the module. ' + \ - 'Please note that this writing method may cause some trace errors.' - _logger.warning(warn_msg) - return attr_val - # using isinstance instead of orig_func.isinstance to judge whether - # the ConcreteProxy.value is the following three types if the attr_val is a ConcreteProxy - elif isinstance(attr_val, (orig_func.tuple, orig_func.list)): - if self.the_path_of_middle_class[id(mod)] == '': - return self.create_proxy('get_attr', f'{attr}', (), {}) - else: - return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) - elif attr in self.default_module_getattr: - if self.the_path_of_middle_class[id(mod)] == '': - return self.create_proxy('get_attr', f'{attr}', (), {}) - else: - return self.create_proxy('get_attr', f'{self.the_path_of_middle_class[id(mod)]}.{attr}', (), {}) - elif id(attr_val) in self.the_path_of_parameter: - return self.create_proxy('get_attr', self.the_path_of_parameter[id(attr_val)], (), {}) - elif id(attr_val) in self.the_path_of_buffer: - return self.create_proxy('get_attr', self.the_path_of_buffer[id(attr_val)], (), {}) - return attr_val - - @functools.wraps(orig_func.torch_module_call) - def module_call_wrapper(mod, *args, **kwargs): - if self.temp_call_origin: - return orig_func.torch_module_call(mod, *args, **kwargs) - else: - # codes below corresponds to symbolic tracer's call_module - module_qualified_name = self.path_of_module(mod) - with ScopeContextManager(self.scope, Scope(module_qualified_name, orig_func.type(mod))) as _scope: - self.module_stack[_scope.module_path] = _scope.module_type - if not self.is_leaf_module(mod, module_qualified_name): - _autowrap_check(self, - mod.forward.__globals__, - self._autowrap_function_ids, - self.autowrap_leaf_pairs, - self.agfunc_dict) - _autowrap_check(self, - mod.__dict__, - self._autowrap_function_ids, - self.autowrap_leaf_pairs, - self.agfunc_dict) - ret_val = orig_func.torch_module_call(mod, *args, **kwargs) - else: - ret_val = self.create_proxy('call_module', module_qualified_name, args, kwargs) - key, _ = self.module_stack.popitem(last=True) - assert key == _scope.module_path, f" Unexpected key {key}" - return ret_val - - class map_wrapper_clz: - # used to track the original class - _fx_wrapped_ori_clz = orig_func.map - - def __new__(cls, the_func, *iterables: Any): - if self.temp_call_origin: - return orig_func.map(the_func, *iterables) - tracers = orig_func.set() - for one_iter in iterables: - if orig_func.isinstance(one_iter, ep.Proxy): - tracers.add(one_iter.tracer) - if orig_func.len(tracers) > 1: - raise Exception('more than 1 tracer detected. please report the issue') - elif orig_func.len(tracers) == 1: - results = orig_func.list() - for args in zip(*iterables): - results.append(the_func(*args)) - return next(iter(tracers)).create_proxy('call_function', orig_func.tuple, (results,), {}) - - ## for the multi-level list/tuple - iterables = orig_func.list(orig_func.list(it) for it in iterables) - for it in iterables: - for arg in it: - if orig_func.isinstance(arg, ep.Proxy): - tracers.add(arg.tracer) - if orig_func.len(tracers) > 1: - raise Exception('more than 1 tracer detected. please report the issue') - elif orig_func.len(tracers) == 1: - results = orig_func.list() - for args in zip(*iterables): - results.append(the_func(*args)) - return next(iter(tracers)).create_proxy('call_function', orig_func.tuple, (results,), {}) - ## for the multi-level list/tuple end - - return orig_func.map(the_func, *iterables) - def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(orig_func.map)) - def __hash__(self): - return id(self) - map_wrapper = map_wrapper_clz - - class range_wrapper_clz: - # used to track the original class - _fx_wrapped_ori_clz = orig_func.range - - def __new__(cls, *args): - # TODO: better infomation - assert 1 <= orig_func.len(args) <= 3 - args = (arg.value if orig_func.isinstance(arg, ep.ConcreteProxy) else arg for arg in args) - return orig_func.range(*args) - def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(orig_func.range)) - def __hash__(self): - return id(self) - range_wrapper = range_wrapper_clz - - class enumerate_wrapper_clz: - # used to track the original class - _fx_wrapped_ori_clz = orig_func.enumerate - - def __new__(cls, iterable, start=0): - count = start - for elem in iterable: - if orig_func.isinstance(elem, ep.ConcreteProxy) and orig_func.isinstance(elem.value, (orig_func.int, str)): - yield count, elem.value - else: - yield count, elem - count += 1 - def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(orig_func.enumerate)) - def __hash__(self): - return id(self) - enumerate_wrapper = enumerate_wrapper_clz - - class type_wrapper_clz: - # used to track the original class - _fx_wrapped_ori_clz = orig_func.type - - def __new__(cls, obj_or_name, *args): - # case 1: class type(name, bases, dict, **kwds) - if orig_func.len(args) > 0: - assert orig_func.len(args) == 2 - base_cls, cls_dict = args[0], args[1] - # if it is a wrapped class, replace it to the original one - base_cls = orig_func.tuple(bs._fx_wrapped_ori_clz if hasattr(bs, '_fx_wrapped_ori_clz') else bs for bs in base_cls) - return orig_func.type(obj_or_name, base_cls, cls_dict) - # case 2: class type(object) - else: - orig_type = orig_func.type(obj_or_name) - if orig_type in (ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): - return orig_func.type(obj_or_name.value) - else: - return orig_type - def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(orig_func.enumerate)) - def __hash__(self): - return id(self) - type_wrapper = type_wrapper_clz - - @classmethod - @functools.wraps(orig_func.torch_agfunc_apply) - def agfunc_apply_wrapper(clz, *args, **kwargs): - if clz not in self.agfunc_dict: - self.agfunc_dict[clz] = torch._C._FunctionBase.__dict__['apply'].__get__(None, clz) - if self.temp_call_origin: - return self.agfunc_dict[clz](*args, **kwargs) - tracers = orig_func.set() - def unwrap_detect_tracers(obj): - if isinstance(obj, ep.ConcreteProxy): - tracers.add(obj.tracer) - ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) - ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) - if orig_func.len(tracers) == 0: - return self.agfunc_dict[clz](*args, **kwargs) - elif orig_func.len(tracers) == 1 and next(iter(tracers)) == self: - return self.create_proxy('call_function', self.agfunc_dict[clz], args, kwargs) - else: - raise Exception('more than 1 tracer detected. please report the issue') + fn, args, more_args, kwargs = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args) - @functools.wraps(orig_func.torch_assert) - def torch_assert_wrapper(condition, message): - while orig_func.isinstance(condition, ep.ConcreteProxy): - condition = condition.value - return orig_func.torch_assert(condition, message) + self.path_of_module = {id(v): k for k, v in self.root.named_modules()} + self.path_of_parameter = {id(v): k for k, v in self.root.named_parameters()} + self.path_of_buffer = {id(v): k for k, v in self.root.named_buffers()} - self.agfunc_dict: dict[Type, Any] = {} - self.autowrap_leaf_pairs = { - id(orig_func.torch_assert): torch_assert_wrapper, - } - self.wrapped_leaf: Dict[Any, Tuple[Tuple[Location,...], Any]] = dict() - - for func, wrap_info in self.autowrap_leaf_function.items(): - locations = tuple(wrap_info.extra_locs) - if is_autograd_apply(func): - # torch.autograd.function - assert wrap_info.replace_fn == None, '.apply should set to_func to None!' - if func.__self__ not in self.agfunc_dict: - self.agfunc_dict[func.__self__] = _create_wrapped_leaf_func(self, func, func) - wrapped = self.agfunc_dict[func.__self__] - elif isinstance(func, ScriptFunction): - # if it is a script function, - # here will wrap the origin function location and forward the script function to the origin one. - # _torchdynamo_inline is introduced in pytorch 2.0, it is the original function of the script function. - inner_func = func._torchdynamo_inline - # some `func.__module__` may have additional `_` compare with its import path in user code, - # for example, `operator.add.__module__` is `_operator` and `_operator` is a built-in module and we don't want to touch it, - # we assume user won't import function from module named with prefix `_`, - # here we only wrap the function under no prefix `_` module, i.e. functions under `operator`. - if inner_func.__module__.startswith('_') and inner_func.__module__ != '__main__': - path = sys.modules.get(inner_func.__module__[1:], sys.modules[inner_func.__module__]) - else: - path = sys.modules[inner_func.__module__] - locations = (*locations, Location(path, inner_func.__name__)) - if wrap_info.is_force_trace: - wrapped = _create_wrapped_leaf_func(self, func, inner_func, (self,)) - else: - wrapped = _create_wrapped_leaf_func(self, func, inner_func) - else: - # for example, torch.Tensor.view.__qualname__ is 'TensorBase.view', - # should also add the location `Location(torch.Tensor, func.__name__)` for these methods. - # NOTE: `_TensorBase` is renamed to `TensorBase` in the latest pytorch version. - if func.__qualname__.startswith('_TensorBase') or func.__qualname__.startswith('TensorBase'): - locations = (*locations, Location(torch.Tensor, func.__name__)) - wrapped = _create_wrapped_leaf_method(self, getattr(torch.Tensor, func.__name__), func.__name__, wrap_info.replace_fn) - elif func.__qualname__.startswith('_VariableFunctionsClass'): - if hasattr(torch, func.__name__) and getattr(torch, func.__name__) == func: - # avoid bad attr like 'unique_dim' - locations = (*locations, Location(torch, func.__name__)) - if wrap_info.is_force_trace: - wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn, (self,)) - else: - wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn) - elif orig_func.isinstance(func, (MethodDescriptorType, MethodWrapperType)): - wrapped = _create_wrapped_leaf_method(self, func, func.__name__, wrap_info.replace_fn) - elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn' \ - and not func.__qualname__.startswith('PyCapsule'): - # method - # in torch >= 2.2, we found two functions under torch._C has no __module__: - # - # - if func.__module__ is not None: - if func.__module__.startswith('_') and func.__module__ != '__main__': - path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) - else: - path = sys.modules[func.__module__] - path = getattr(path, func.__qualname__.split('.')[0]) - locations = (*locations, Location(path, func.__name__)) - if len(locations) == 0: - _logger.warning(f'Can not find location of {func}, skip wrap it.') - continue - wrapped = _create_wrapped_leaf_method(self, func, func.__name__, wrap_info.replace_fn) - else: - # common function - # in torch >= 2.2, we found two functions under torch._C has no __module__: - # - # - if func.__module__ is not None: - if func.__module__.startswith('_') and func.__module__ != '__main__': - path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) - else: - path = sys.modules[func.__module__] - locations = (*locations, Location(path, func.__name__)) - if len(locations) == 0: - _logger.warning(f'Can not find location of {func}, skip wrap it.') - continue - if wrap_info.is_force_trace: - wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn, (self,)) - else: - wrapped = _create_wrapped_leaf_func(self, func, wrap_info.replace_fn) - self.wrapped_leaf[func] = (locations, wrapped) + # use to track the autograd function classes with the wrapped apply method + # {autograd_function_class: wrapped_autograd_function_apply} + self.autograd_functions_mapping: dict[Type, Any] = {} + self.wrapped_leaf: Dict[Any, Tuple[Tuple[wrap_utils.Location,...], Any]] = self.get_wrapped_leaves(self.autowrap_leaf_function, self.autowrap_leaf_class) # for the customized functions, we need to revert all the wrapped function to the original one to run it # for the functions default wrapped, we don't revert to save time for func in autowrap_leaf_function: self.add_need_revert_function(func, self.wrapped_leaf.get(func, (None, None))[1]) - self.clz_wrapper_map: Dict[Any, Type] = { - map_wrapper: orig_func.map, - enumerate_wrapper: orig_func.enumerate, - range_wrapper: orig_func.range, - type_wrapper: orig_func.type, - } - for clz, wrap_info in self.autowrap_leaf_class.items(): - if clz.__module__.startswith('_') and clz.__module__ != '__main__': - path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) - else: - path = sys.modules[clz.__module__] - if wrap_info.is_iterable: - wrapped = _create_wrapped_leaf_iterable_class(self, clz) - else: - wrapped = _create_wrapped_leaf_class(self, clz) - locations = (*wrap_info.extra_locs, Location(path, clz.__name__)) - self.wrapped_leaf[clz] = (locations, wrapped) - self.clz_wrapper_map[wrapped] = clz - - for clz in self.fake_middle_class: - wrapped = _create_wrapped_attr_for_middle_class(self, clz, self.the_path_of_middle_class) - self.wrapped_leaf[clz.__getattribute__] = ((Location(clz, '__getattribute__'),), wrapped) - # wrap all forward in the submodule to trace the module stack + # NOTE: temp disable the forward wrap, will add back later for mod in self.root.modules(): - wrapped = _create_wrapped_nn_module_func(self, mod, forward_function_name) - self.wrapped_leaf[mod.forward] = ((Location(mod, forward_function_name),), wrapped) - - @functools.wraps(orig_func.isinstance) - def isinstance_wrapper(instance, clz): - if orig_func.type(clz) in (slice, tuple, list, orig_func.slice, orig_func.tuple, orig_func.list): - clz_wrapped = [] - for wrapped_type, orig_type in self.clz_wrapper_map.items(): - if wrapped_type in clz: - clz_wrapped.append(orig_type) - clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in self.clz_wrapper_map)) - # use orig_func.isinstance(clz, Iterable) will cause an endless recursive loop - for cls in (object, ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): - if cls in clz and orig_func.isinstance(instance, cls): - return True - if orig_func.isinstance(instance, ep.ConcreteProxy): - return orig_func.isinstance(instance.value, clz) - else: - return orig_func.isinstance(instance, clz) - else: - if clz in (object, ep.ConcreteProxy, ep.ConcreteAttrProxy, ep.ConcreteUnpackIterProxy): - return orig_func.isinstance(instance, clz) - if clz in self.clz_wrapper_map: - clz = self.clz_wrapper_map[clz] - if orig_func.isinstance(instance, ep.ConcreteProxy): - instance = instance.value - return orig_func.isinstance(instance, clz) - - @functools.wraps(orig_func.issubclass) - def issubclass_wrapper(subclass, clz): - if orig_func.type(clz) in (slice, tuple, list, orig_func.slice, orig_func.tuple, orig_func.list): - clz_wrapped = [] - for wrapped_type, orig_type in self.clz_wrapper_map.items(): - if wrapped_type in clz: - clz_wrapped.append(orig_type) - clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in self.clz_wrapper_map)) - return orig_func.issubclass(subclass, clz) - else: - if clz in self.clz_wrapper_map: - clz = self.clz_wrapper_map[clz] - return orig_func.issubclass(subclass, clz) - - @functools.wraps(orig_func.getattr) - def getattr_wrapper(obj, *args): - # TODO: better infomation - if not 1 <= orig_func.len(args) <= 2: - raise Exception() - args = orig_func.list(args) - if orig_func.isinstance(args[0], ep.ConcreteProxy): - args[0] = args[0].value - return orig_func.getattr(obj, *args) + wrapped = wrap_utils.create_wrapped_nn_module_func(self, mod, forward_function_name) + self.wrapped_leaf[mod.forward] = ((wrap_utils.Location(mod, forward_function_name),), wrapped) try: with self.patcher: # allow duplicate patches to support the case of nested calls - self.patcher.patch_method(torch.nn.Module, "__getattribute__", module_getattribute_wrapper, deduplicate=False) + self.patcher.patch_method(torch.nn.Module, "__getattribute__", wrap_utils.create_wrapped_module_getattribute(self), deduplicate=False) - self.patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False) + self.patcher.patch_method(torch.nn.Module, "__call__", wrap_utils.create_wrapped_module_call(self), deduplicate=False) # for cuda versions of pytorch, autograd.Function.apply should be reverted by delattr - self.patcher.patch_method(torch.autograd.Function, "apply", agfunc_apply_wrapper, deduplicate=False, revert_by_del=True) - self.patcher.patch_method(torch, "_assert", torch_assert_wrapper, deduplicate=False) + self.patcher.patch_method(torch.autograd.Function, "apply", wrap_utils.create_wrapped_autograd_apply(self), deduplicate=False, revert_by_del=True) + self.patcher.patch_method(torch, "_assert", wrap_utils.torch_assert_wrapper, deduplicate=False) - self.patcher.patch_method(builtins, "map", map_wrapper, deduplicate=False) - self.patcher.patch_method(builtins, "enumerate", enumerate_wrapper, deduplicate=False) - self.patcher.patch_method(builtins, "range", range_wrapper, deduplicate=False) - self.patcher.patch_method(builtins, "type", type_wrapper, deduplicate=False) - self.patcher.patch_method(builtins, "isinstance", isinstance_wrapper, deduplicate=False) - self.patcher.patch_method(builtins, "issubclass", issubclass_wrapper, deduplicate=False) - self.patcher.patch_method(builtins, "getattr", getattr_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "map", wrap_utils.map_wrapper_clz, deduplicate=False) + self.patcher.patch_method(builtins, "enumerate", wrap_utils.enumerate_wrapper_clz, deduplicate=False) + self.patcher.patch_method(builtins, "range", wrap_utils.range_wrapper_clz, deduplicate=False) + self.patcher.patch_method(builtins, "type", wrap_utils.type_wrapper_clz, deduplicate=False) + self.patcher.patch_method(builtins, "isinstance", wrap_utils.isinstance_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "issubclass", wrap_utils.issubclass_wrapper, deduplicate=False) + self.patcher.patch_method(builtins, "getattr", wrap_utils.getattr_wrapper, deduplicate=False) for obj, (positions, wrapped) in self.wrapped_leaf.items(): for loc in positions: self.patcher.patch_method(loc.ns, loc.name, wrapped, deduplicate=False) - self.autowrap_leaf_pairs[id(obj)] = wrapped + + wrap_utils.autowrap_check(self, fn_globals) - _patch_wrapped_functions(self.patcher) - _autowrap_check(self, fn_globals, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) - for module in self._autowrap_search: - _autowrap_check(self, module.__dict__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict) with OperatorPatcherContext(self, use_operator_patch, operator_patch_backlist): results = OperatorPatcherContext.patch_run(fn, *args, *more_args, **kwargs) # we should unwrap proxy to the original value in the results when we record it to node.meta['tensor_meta'] @@ -1193,137 +740,9 @@ def unwrap(obj: Any): {}, type_expr=fn.__annotations__.get('return', None), node_result=ep.map_aggregate_not_proxy(results, unwrap)) finally: _retain_weight_consistency(self.root) - pass - self.submodule_paths = None return self.graph -# List of pairs of (global dict, function name) functions -# to patch for the purposes of the wrap() API. -_wrapped_fns_to_patch : List[Tuple[dict, str]] = [] - -# List of methods on classes to wrap (class type, function name) -# this currently only works for Tensor.* methods that aren't traced properly -_wrapped_methods_to_patch : List[Tuple[type, str]] = [] - - -def _find_proxy(*objects_to_search): - """ - Recursively search a data structure for a Proxy() and return it, - return None if not found. - """ - proxy = None - - def find_proxy(x): - nonlocal proxy - if isinstance(x, ep.ConcreteProxy): - proxy = x - - ep.map_aggregate_not_proxy(objects_to_search, find_proxy) - return proxy - -def _create_wrapped_func(orig_fn): - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - """ - Given an closed-over ``orig_function`` to invoke, search the args and kwargs for - a Proxy object. If there is one, emit a ``call_function`` node to preserve the - call to this leaf function directly. Otherwise, just return the results of - this function call, as this function is not being traced. - """ - proxy = _find_proxy(args, kwargs) - if proxy is not None: - return_proxy = proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs) - return_proxy.node.meta['is_wrapped'] = True - return return_proxy - return orig_fn(*args, **kwargs) - - return wrapped - -def _patch_wrapped_functions(patcher : FunctionPatcher): - """ - Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap - the listed global functions in the `_create_wrapped_func` wrapper. - """ - for frame_dict, name in _wrapped_fns_to_patch: - if name not in frame_dict and hasattr(builtins, name): - orig_fn = orig_func.getattr(builtins, name) - else: - orig_fn = frame_dict[name] - patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) - - for cls, name in _wrapped_methods_to_patch: - patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) - -def _autowrap_check(tracer: ConcreteTracer, frame_dict : Dict[str, Any], function_ids : Set[int],\ - function_pairs : Dict[int, Callable], agfunc_dict: dict[Type, Any]): - """ - Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. - This method searches a scope for them and patches them if found. - """ - patcher = tracer.patcher - if patcher.visit_once(frame_dict): - for name, value in frame_dict.items(): - # if callable(value) and (not name.startswith('_') or name == '_assert'): - if callable(value) and not name.startswith('__') and not name.startswith('_orig_'): - if id(value) in function_ids: - patcher.patch(frame_dict, name, _create_wrapped_func(value)) - elif id(value) in function_pairs: - patcher.patch(frame_dict, name, function_pairs[id(value)]) - elif is_autograd_apply(value): - # torch.autograd.function - if value.__self__ not in agfunc_dict: - agfunc_dict[value.__self__] = _create_wrapped_leaf_func(tracer, value, value) - patcher.patch(frame_dict, name, agfunc_dict[value.__self__]) - -def _create_wrapped_method(cls, name): - orig_fn = orig_func.getattr(cls, name) - - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - """ - Search the args and kwargs for a Proxy object. If there is one, - emit a ``call_method`` node to preserve the call to this method - directly. Otherwise, just return the results of this function - call, as this function is not being traced. - """ - proxy = _find_proxy(args, kwargs) - if proxy is not None: - return proxy.tracer.create_proxy('call_method', name, args, kwargs) - return orig_fn(*args, **kwargs) - - return wrapped - -def _create_wrapped_nn_module_func(tracer: ConcreteTracer, mod: torch.nn.Module, name: str): - orig_fn = orig_func.getattr(mod, name) - if not orig_func.isinstance(orig_fn, MethodType): - raise RuntimeError(f'{tracer.path_of_module(mod)}.{name} is not a bound method, only support wrap bound method.') - - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - module_qualified_name = tracer.path_of_module(mod) - with ScopeContextManager(tracer.scope, Scope(module_qualified_name, orig_func.type(mod))) as _scope: - need_pop = False - if _scope.module_path not in tracer.module_stack: - need_pop = True - tracer.module_stack[_scope.module_path] = _scope.module_type - elif _scope.module_path != list(tracer.module_stack)[-1]: - raise RuntimeError(f'Scope not match: {_scope.module_path} vs {list(tracer.module_stack)[-1]}') - # has tracer means in tracing progress - if OperatorPatcherContext.ctx_tracer and OperatorPatcherContext.ctx_patcher: - # `patch_run` is needed because this function will be patched by fx patcher, - # which means it will have `__fx_already_patched` flag, and operator patcher will not patch it again, - # so directly call `patch_run` here to avoid the `orig_fn is not patched by the operator patcher. - result = OperatorPatcherContext.patch_run(orig_fn, *args, **kwargs) - else: - result = orig_fn(*args, **kwargs) - if need_pop: - key, _ = tracer.module_stack.popitem(last=True) - assert key == _scope.module_path, f" Unexpected key {key}" - return result - - return wrapped - def update_tree_proxy_value(dst_pytree, src_pytree): """ @@ -1405,7 +824,7 @@ def copy_attr_new(from_module: torch.nn.Module, to_module: torch.nn.Module, targ @staticmethod def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: - if is_autograd_apply(orig_method): + if wrap_utils.is_autograd_apply(orig_method): # for torch.autograd.Function return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' @@ -1439,7 +858,7 @@ def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: @staticmethod def format_import_statement_new(name: str, obj: Any, importer) -> str: - if is_autograd_apply(obj): # type: ignore + if wrap_utils.is_autograd_apply(obj): # type: ignore # torch.autograd.function return MagicMethodPatcher.format_import_statement_ori(name, obj.__self__, importer) + f'\n{name} = {name}.apply' return MagicMethodPatcher.format_import_statement_ori(name, obj, importer) @@ -1459,164 +878,6 @@ def __exit__(self, exc_type, exc_value, tb): MagicMethodPatcher.available = False return exc_type is None -def _create_wrapped_leaf_func(tracer: ConcreteTracer, func: Callable, to_func: Optional[Callable], init_tracers = ()): - # to_func: to call correct replacement instead of the original (the original func may be wrong). - # such as: call torch.nn.norm instead of torch._C._VariableFunctions.norm. - # torch.nn.norm will help to pack dim to list if dim is an int. - if to_func is None: - to_func = func - @functools.wraps(func) - def func_wrapper(*args, **kwargs): - if tracer.temp_call_origin: - return func(*args, **kwargs) - tracers = orig_func.set(init_tracers) - def unwrap_detect_tracers(obj): - if isinstance(obj, ep.ConcreteProxy): - tracers.add(obj.tracer) - ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) - ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) - if orig_func.len(tracers) == 0: - return to_func(*args, **kwargs) - elif orig_func.len(tracers) == 1 and next(iter(tracers)) == tracer: - return tracer.create_proxy('call_function', to_func, args, kwargs) - else: - raise Exception('more than 1 tracer detected. please report the issue') - return func_wrapper - -def _create_wrapped_leaf_method(tracer: ConcreteTracer, method, name: str, to_func: Optional[Callable]): - @functools.wraps(method) - def method_wrapper(*args, **kwargs): - if tracer.temp_call_origin: - return method(*args, **kwargs) - tracers = orig_func.set() - def unwrap_detect_tracers(obj): - if isinstance(obj, ep.ConcreteProxy): - tracers.add(obj.tracer) - ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) - ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) - if orig_func.len(tracers) == 0: - return method(*args, **kwargs) - elif orig_func.len(tracers) == 1 and next(iter(tracers)) == tracer: - if to_func is not None: - return tracer.create_proxy('call_function', to_func, args, kwargs) - else: - return tracer.create_proxy('call_method', name, args, kwargs) - else: - raise Exception('more than 1 tracer detected. please report the issue') - return method_wrapper - -def _create_wrapped_leaf_class(tracer: ConcreteTracer, clz): - """ - Wrap a class as a tracable class, we usually wrap some classes that can be seen as creation functions. - For example, we can prevent the trace be interrupted by wrap ```tuple``` in the following case: - - ... - # x is a scalar - x_value = int(x) - new_x = torch.tensor([x_value, x_value]) - ... - """ - class clz_wrapper_clz: - # used to track the original class - _fx_wrapped_ori_clz = clz - - def __new__(cls, *args, **kwargs): - if tracer.temp_call_origin: - return clz(*args, **kwargs) - tracers = orig_func.set() - def unwrap_detect_tracers(obj): - if isinstance(obj, ep.ConcreteProxy): - tracers.add(obj.tracer) - ep.map_aggregate_not_proxy(args, unwrap_detect_tracers) - ep.map_aggregate_not_proxy(kwargs, unwrap_detect_tracers) - if orig_func.len(tracers) == 0: - return clz(*args, **kwargs) - elif orig_func.len(tracers) == 1 and next(iter(tracers)) == tracer: - return tracer.create_proxy('call_function', clz, args, kwargs) - else: - raise Exception('more than 1 tracer detected. please report the issue') - def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(clz)) - def __hash__(self): - return id(self) - - for name in dir(clz): - attr = orig_func.getattr(clz, name) - if not name.startswith('_'): - if orig_func.isinstance(attr, Callable): - setattr(clz_wrapper_clz, name, _create_wrapped_leaf_method(tracer, attr, name, None)) - else: - setattr(clz_wrapper_clz, name, attr) - return clz_wrapper_clz - -def _create_wrapped_leaf_iterable_class(tracer: ConcreteTracer, clz): - """ - Wrap a class as a tracable class, we usually wrap some classes that can be seen as creation functions. - For example, we can prevent the trace be interrupted by wrap ```tuple``` in the following case: - - ... - # x is a tensor - x_1st = tuple(x)[0] - ... - """ - class clz_wrapper_clz: - # used to track the original class - _fx_wrapped_ori_clz = clz - - def __new__(cls, *args, **kwargs): - if tracer.temp_call_origin: - return clz(*args, **kwargs) - tracers = orig_func.set() - if orig_func.len(args) != 0: - if orig_func.isinstance(args[0], ep.Proxy): - tracers.add(args[0].tracer) - if orig_func.isinstance(args[0], Iterator): - args = (clz(args[0]), *args[1:]) - if orig_func.isinstance(args[0], Iterable): - for item in args[0]: - if orig_func.isinstance(item, ep.Proxy): - tracers.add(item.tracer) - if orig_func.len(tracers) == 0: - return clz(*args, **kwargs) - elif orig_func.len(tracers) == 1 and next(iter(tracers)) == tracer: - return tracer.create_proxy('call_function', - clz, args, kwargs) - else: - raise Exception('more than 1 tracer detected. please report the issue') - def __eq__(self, __o: object) -> bool: - return id(__o) in (id(self), id(clz)) - def __hash__(self): - return id(self) - - for name in dir(clz): - attr = orig_func.getattr(clz, name) - if not name.startswith('_') or name in ('__getitem__', '__setitem__', '__iter__', '__len__'): - if orig_func.isinstance(attr, Callable): - setattr(clz_wrapper_clz, name, _create_wrapped_leaf_method(tracer, attr, name, None)) - else: - setattr(clz_wrapper_clz, name, attr) - return clz_wrapper_clz - -def _create_wrapped_attr_for_middle_class(tracer: ConcreteTracer, clz, the_path_of_middle_class): - _orig_clz_getattribute = clz.__getattribute__ - if hasattr(clz, '__getattr__'): - _orig_clz_getattr = clz.__getattr__ - else: - _orig_clz_getattr = None - @functools.wraps(_orig_clz_getattribute) - def clz_getattr_wrapper(obj, attr): - if tracer.temp_call_origin: - if _orig_clz_getattr == None: - return _orig_clz_getattribute(obj, attr) - else: - try: - return _orig_clz_getattribute(obj, attr) - except AttributeError: - return _orig_clz_getattr(obj, attr) - else: - return tracer.create_proxy('get_attr', f'{the_path_of_middle_class[id(obj)]}.{attr}', (), {}) - return clz_getattr_wrapper - def _retain_weight_consistency(root: torch.nn.Module): _flag = 0 for module in root.modules(): @@ -1695,10 +956,8 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], operator_patch_backlist: List[str] | None = None, forward_function_name: str = 'forward', check_args: Optional[Dict[str, Any]] = None, - autowrap_leaf_function: Optional[Dict[Any, LeafFnWrapInfo]] = None, - autowrap_leaf_class: Optional[Dict[Type, LeafClassWrapInfo]] = None, - leaf_module: Tuple | None = None, - fake_middle_class: Tuple | None = None, + autowrap_leaf_function: Optional[Dict[Any, wrap_utils.LeafWrapInfo]] = None, + autowrap_leaf_class: Optional[Dict[Type, wrap_utils.LeafWrapInfo]] = None, dce: bool = True, dce_ignored_function: Set[Callable] | None = None, cpu_offload: bool = False, @@ -1839,8 +1098,6 @@ def f(x, y): graph = tracer.trace(root, autowrap_leaf_function = autowrap_leaf_function, autowrap_leaf_class = autowrap_leaf_class, - leaf_module = leaf_module, - fake_middle_class = fake_middle_class, concrete_args = concrete_args, use_operator_patch = use_operator_patch, operator_patch_backlist = operator_patch_backlist, @@ -1851,8 +1108,6 @@ def f(x, y): graph_check = tracer.trace(root, autowrap_leaf_function = autowrap_leaf_function, autowrap_leaf_class = autowrap_leaf_class, - leaf_module = leaf_module, - fake_middle_class = fake_middle_class, concrete_args = concrete_args, use_operator_patch = use_operator_patch, operator_patch_backlist = operator_patch_backlist, @@ -1868,8 +1123,8 @@ def f(x, y): if node_a.op == 'get_attr' and node_a.name.startswith('_tensor_constant'): assert node_b.op == 'get_attr' and node_b.name.startswith('_tensor_constant') assert torch.equal(getattr(root, node_a.name), getattr(root, node_b.name)) - elif node_a.op == 'call_function' and is_autograd_apply(target_a): - assert node_b.op == 'call_function' and is_autograd_apply(target_b) + elif node_a.op == 'call_function' and wrap_utils.is_autograd_apply(target_a): + assert node_b.op == 'call_function' and wrap_utils.is_autograd_apply(target_b) else: assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}' diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py index abe3d755..2b623d3b 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/operator_patcher.py @@ -17,7 +17,7 @@ import torch -from . import orig_func +from . import orig_func, wrap_utils _logger = logging.getLogger(__name__) @@ -349,6 +349,6 @@ def __exit__(self, exc_type, exc_value, tb): def patch_run(func, *args, **kwargs): assert OperatorPatcherContext.ctx_tracer is not None assert OperatorPatcherContext.ctx_patcher is not None - with OperatorPatcherContext.ctx_tracer.do_temp_call_origin(): + with wrap_utils.do_temp_call_origin(): new_func = OperatorPatcherContext.ctx_patcher.patch_func_or_module(func) return new_func(*args, **kwargs) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py new file mode 100644 index 00000000..0ebd27ff --- /dev/null +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py @@ -0,0 +1,602 @@ +import builtins +from contextlib import contextmanager +from dataclasses import dataclass, field +import functools +import operator + +from types import MethodType, ModuleType +from typing import Any, Dict, Optional, Type, List, Callable, Union, TYPE_CHECKING, Tuple + +import math +import torch +from torch.fx.proxy import Scope, ScopeContextManager + +import nnscaler.graph.parser.fx.concrete_trace_utils as cct +from . import pytree_utils, orig_func, operator_patcher +if TYPE_CHECKING: + from .concrete_tracer import ConcreteTracer + +import logging +_logger = logging.getLogger(__name__) + +# global variable to control if the wrapped function should only execute the original logic +TEMP_CALL_ORIGIN = False + + +@contextmanager +def do_temp_call_origin(): + """ + Under this context, the wrapped functon will directly execute the original logic. + """ + global TEMP_CALL_ORIGIN + temp_call_origin = TEMP_CALL_ORIGIN + TEMP_CALL_ORIGIN = True + try: + yield + finally: + TEMP_CALL_ORIGIN = temp_call_origin + + +@dataclass +class Location: + """ + The place a function/class locates. + Please note one function/class can be in multiple places. + Take `torch.meshgrid` for example, there are `torch.meshgrid`, 'torch.functional.meshgrid', 'torch._C._VariableFunctions.meshgrid', + """ + ns: Union[Type, ModuleType, Any] # the namespace of the name. It can be a class/module, etc. + name: str + + +@dataclass +class LeafWrapInfo: + """ + extra_locs: The place the function is imported. + is_force_trace: If set to false, the function will only be traced if inputs include proxy. + Such as 'torch.rand', we should trace it even if it doesn't have proxy as input, so it should be force traced. + replacement: If not `None`, we will use it to replace the original function/class in traced code. + Such as ModuleList.__getitem__, we can use operator.getitem to replace it. + """ + extra_locs: List[Location] = field(default_factory=list) + is_force_trace: bool = False + replacement: Union[None, Callable, Type] = None + + +default_autowrap_leaf_function: Dict[Any, LeafWrapInfo] = { + # wrap widely used builtins functions that can be applied on torch.Tensor + builtins.len: LeafWrapInfo([], False, None), + builtins.abs: LeafWrapInfo([], False, None), + builtins.all: LeafWrapInfo([], False, None), + builtins.any: LeafWrapInfo([], False, None), + builtins.min: LeafWrapInfo([], False, None), + builtins.max: LeafWrapInfo([], False, None), + + # force-traced function (the factory functions of tensor creation) + torch.arange: LeafWrapInfo([], True, None), + torch.empty: LeafWrapInfo([], True, None), + torch.eye: LeafWrapInfo([], True, None), + torch.full: LeafWrapInfo([], True, None), + torch.linspace: LeafWrapInfo([], True, None), + torch.logspace: LeafWrapInfo([], True, None), + torch.ones: LeafWrapInfo([], True, None), + torch.rand: LeafWrapInfo([], True, None), + torch.randint: LeafWrapInfo([], True, None), + torch.randn: LeafWrapInfo([], True, None), + torch.randperm: LeafWrapInfo([], True, None), + torch.tensor: LeafWrapInfo([], True, None), + torch.zeros: LeafWrapInfo([], True, None), + + # method + torch.nn.Sequential.__getitem__: LeafWrapInfo([], False, operator.getitem), + torch.nn.Sequential.__len__: LeafWrapInfo([], False, builtins.len), + torch.nn.Sequential.__iter__: LeafWrapInfo([], False, builtins.iter), + + torch.nn.ModuleList.__getitem__: LeafWrapInfo([], False, operator.getitem), + torch.nn.ModuleList.__len__: LeafWrapInfo([], False, builtins.len), + torch.nn.ModuleList.__iter__: LeafWrapInfo([], False, builtins.iter), + + torch.nn.ModuleDict.__getitem__: LeafWrapInfo([], False, operator.getitem), + torch.nn.ModuleDict.__len__: LeafWrapInfo([], False, builtins.len), + torch.nn.ModuleDict.__iter__: LeafWrapInfo([], False, builtins.iter), + torch.nn.ModuleDict.__contains__: LeafWrapInfo([], False, operator.contains), + + torch.nn.ParameterList.__getitem__: LeafWrapInfo([], False, operator.getitem), + torch.nn.ParameterList.__len__: LeafWrapInfo([], False, builtins.len), + torch.nn.ParameterList.__iter__: LeafWrapInfo([], False, builtins.iter), + + torch.nn.ParameterDict.__getitem__: LeafWrapInfo([], False, operator.getitem), + torch.nn.ParameterDict.__len__: LeafWrapInfo([], False, builtins.len), + torch.nn.ParameterDict.__iter__: LeafWrapInfo([], False, builtins.iter), + torch.nn.ParameterDict.__contains__: LeafWrapInfo([], False, operator.contains), +} + + +def _functions_in_module(module: ModuleType): + """ + Detect all the callable functions in the module, exclude the functions start with `_` or typing related + """ + for name in module.__dir__(): + # get all callable except private function + if not name.startswith('_') and callable(getattr(module, name)): + op = getattr(module, name) + # exclude the typing related + if not isinstance(op, Type) and (hasattr(op, '__module__') and op.__module__ not in ('typing,')): + yield op, name + + +# get all functions in the default_autowrap_modules and add them to default_autowrap_leaf_function +default_autowrap_modules = (operator, math, torch, torch.functional, torch.nn.functional) +for module in default_autowrap_modules: + for func, func_name in _functions_in_module(module): + if func in default_autowrap_leaf_function: + default_autowrap_leaf_function[func].extra_locs.append(Location(module, func_name)) + else: + default_autowrap_leaf_function[func] = LeafWrapInfo([Location(module, func_name)], False, None) + + +default_autowrap_leaf_class: Dict[Type, LeafWrapInfo] = { + # class + builtins.bool: LeafWrapInfo([], False), + builtins.int: LeafWrapInfo([], False), + builtins.float: LeafWrapInfo([], False), + + # iterable class + builtins.tuple: LeafWrapInfo([], False), + builtins.list: LeafWrapInfo([], False), + builtins.set: LeafWrapInfo([], False), + builtins.frozenset: LeafWrapInfo([], False), + builtins.dict: LeafWrapInfo([], False), + builtins.reversed: LeafWrapInfo([], False), + + torch.Size: LeafWrapInfo([], False), + torch.finfo: LeafWrapInfo([], False), +} + + +# all wrapped classes should add to this mapping, use to track the original class, used by isinstance wrapper +# {class_wrapper: original_class} +wrapped_cls_to_orig_cls: Dict[Type, Type] = {} + + +def create_wrapped_leaf_func(func: Callable, *, replace_func: Optional[Callable]=None, default_tracer: Optional['ConcreteTracer']=None, + is_method: bool=False, method_name: Optional[str]=None): + """ + Create a wrapped function/method that will generate a call_function/call_method node when call `func` if there has proxy in the inputs. + + Args: + func (Callable) : the original function. + replace_func (Optional[Callable]) : forward the call to another function. + default_tracer (Tracer) : if the tracer is set, then use this tracer to create a node, no matter there has proxy in the inputs. + is_method (bool): if the functionl is a bound method. + method_name (str): use to identify the method name, if the function is a bound method. + """ + @functools.wraps(func) + def func_wrapper(*args, **kwargs): + global TEMP_CALL_ORIGIN + if TEMP_CALL_ORIGIN: + return func(*args, **kwargs) + else: + with do_temp_call_origin(): + tracers = set() + if default_tracer is not None: + tracers.add(default_tracer) + + def detect_tracer(obj): + if isinstance(obj, cct.ConcreteProxy): + tracers.add(obj.tracer) + + pytree_utils.tree_map(detect_tracer, args) + pytree_utils.tree_map(detect_tracer, kwargs) + + if len(tracers) > 1: + raise Exception('more than 1 tracer detected. please report the issue') + + tracer = None if len(tracers) == 0 else tracers.pop() + + if tracer is None: + return func(*args, **kwargs) + else: + if replace_func is None: + if is_method: + return tracer.create_proxy('call_method', method_name, args, kwargs) + else: + return tracer.create_proxy('call_function', func, args, kwargs) + else: + return tracer.create_proxy('call_function', replace_func, args, kwargs) + + return func_wrapper + + +def create_wrapped_leaf_class(clz, *, replace_cls: Optional[Callable]=None, default_tracer: Optional['ConcreteTracer']=None): + """ + Wrap a class as a tracable class, we usually wrap some classes that can be seen as creation functions. + For example, we can prevent the trace be interrupted by wrap ```int``` in the following case: + + ... + # x is a scalar + x_value = int(x) + new_x = torch.tensor([x_value, x_value]) + ... + + Args: + clz : the original class. + replace_cls : forward the call to another function. + default_tracer (Tracer) : if the tracer is set, then use this tracer to create a node, no matter there has proxy in the inputs. + is_method (bool): if the functionl is a bound method. + method_name (str): use to identify the method name, if the function is a bound method. + """ + class clz_wrapper_clz: + # used to track the original class + _fx_wrapped_ori_clz = clz + + def __new__(cls, *args, **kwargs): + global TEMP_CALL_ORIGIN + if TEMP_CALL_ORIGIN: + return clz(*args, **kwargs) + else: + with do_temp_call_origin(): + tracers = set() + if default_tracer is not None: + tracers.add(default_tracer) + + def detect_tracer(obj): + if isinstance(obj, cct.ConcreteProxy): + tracers.add(obj.tracer) + + pytree_utils.tree_map(detect_tracer, args) + pytree_utils.tree_map(detect_tracer, kwargs) + + if len(tracers) > 1: + raise Exception('more than 1 tracer detected. please report the issue') + + tracer = None if len(tracers) == 0 else tracers.pop() + + if tracer is None: + return clz(*args, **kwargs) + else: + if replace_cls is None: + return tracer.create_proxy('call_function', clz, args, kwargs) + else: + return tracer.create_proxy('call_function', replace_cls, args, kwargs) + + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(clz)) + + def __hash__(self): + return id(self) + + with do_temp_call_origin(): + for name in dir(clz): + attr = getattr(clz, name) + # '__getitem__', '__setitem__', '__iter__', '__len__' means this class can be iterable + # then we should wrap these methods to keep the graph preserved + if not name.startswith('_') or name in ('__getitem__', '__setitem__', '__iter__', '__len__'): + if isinstance(attr, Callable): + wrapped_method = create_wrapped_leaf_func(attr, default_tracer=default_tracer, is_method=True, method_name=name) + setattr(clz_wrapper_clz, name, wrapped_method) + else: + setattr(clz_wrapper_clz, name, attr) + + wrapped_cls_to_orig_cls[clz_wrapper_clz] = clz + return clz_wrapper_clz + + +def create_wrapped_module_getattribute(tracer: 'ConcreteTracer'): + @functools.wraps(orig_func.torch_module_getattribute) + def module_getattribute_wrapper(mod, attr): + global TEMP_CALL_ORIGIN + if TEMP_CALL_ORIGIN: + try: + return orig_func.torch_module_getattribute(mod, attr) + except AttributeError: + return orig_func.torch_module_getattr(mod, attr) + with do_temp_call_origin(): + try: + attr_val = orig_func.torch_module_getattribute(mod, attr) + except AttributeError: + attr_val = orig_func.torch_module_getattr(mod, attr) + if orig_func.isinstance(attr_val, cct.ConcreteProxy): + warn_msg = f'Detected {tracer.path_of_module[id(mod)]}.{attr} is a ConcreteProxy, ' + \ + 'this is usually caused by directly assigning the return value of some leaf function to the attribute of the module. ' + \ + 'Please note that this writing method may cause some trace errors.' + _logger.warning(warn_msg) + return attr_val + # using isinstance instead of _orig_isinstance to judge whether + # the ConcreteProxy.value is the following three types if the attr_val is a ConcreteProxy + elif isinstance(attr_val, (orig_func.tuple, orig_func.list)): + if tracer.path_of_module[id(mod)] == '': + return tracer.create_proxy('get_attr', f'{attr}', (), {}) + else: + return tracer.create_proxy('get_attr', f'{tracer.path_of_module[id(mod)]}.{attr}', (), {}) + elif attr in tracer.default_module_getattr: + if tracer.path_of_module[id(mod)] == '': + return tracer.create_proxy('get_attr', f'{attr}', (), {}) + else: + return tracer.create_proxy('get_attr', f'{tracer.path_of_module[id(mod)]}.{attr}', (), {}) + elif id(attr_val) in tracer.path_of_parameter: + return tracer.create_proxy('get_attr', tracer.path_of_parameter[id(attr_val)], (), {}) + elif id(attr_val) in tracer.path_of_buffer: + return tracer.create_proxy('get_attr', tracer.path_of_buffer[id(attr_val)], (), {}) + return attr_val + return module_getattribute_wrapper + + +def create_wrapped_module_call(tracer: 'ConcreteTracer'): + @functools.wraps(orig_func.torch_module_call) + def module_call_wrapper(mod, *args, **kwargs): + global TEMP_CALL_ORIGIN + if TEMP_CALL_ORIGIN: + return orig_func.torch_module_call(mod, *args, **kwargs) + else: + # codes below corresponds to symbolic tracer's call_module + module_qualified_name = tracer.path_of_module[id(mod)] + with ScopeContextManager(tracer.scope, Scope(module_qualified_name, type(mod))) as _scope: + tracer.module_stack[_scope.module_path] = _scope.module_type + if not tracer.is_leaf_module(mod, module_qualified_name): + autowrap_check(tracer, mod.__dict__) + ret_val = orig_func.torch_module_call(mod, *args, **kwargs) + else: + ret_val = tracer.create_proxy('call_module', module_qualified_name, args, kwargs) + key, _ = tracer.module_stack.popitem(last=True) + assert key == _scope.module_path, f" Unexpected key {key}" + return ret_val + return module_call_wrapper + + +def create_wrapped_nn_module_func(tracer: 'ConcreteTracer', mod: torch.nn.Module, name: str): + orig_fn = orig_func.getattr(mod, name) + if not orig_func.isinstance(orig_fn, MethodType): + raise RuntimeError(f'{tracer.path_of_module[id(mod)]}.{name} is not a bound method, only support wrap bound method.') + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + module_qualified_name = tracer.path_of_module[id(mod)] + with ScopeContextManager(tracer.scope, Scope(module_qualified_name, orig_func.type(mod))) as _scope: + need_pop = False + if _scope.module_path not in tracer.module_stack: + need_pop = True + tracer.module_stack[_scope.module_path] = _scope.module_type + elif _scope.module_path != list(tracer.module_stack)[-1]: + raise RuntimeError(f'Scope not match: {_scope.module_path} vs {list(tracer.module_stack)[-1]}') + # has tracer means in tracing progress + if operator_patcher.OperatorPatcherContext.ctx_tracer and operator_patcher.OperatorPatcherContext.ctx_patcher: + autowrap_check(tracer, orig_fn.__globals__) + # `patch_run` is needed because this function will be patched by fx patcher, + # which means it will have `__fx_already_patched` flag, and operator patcher will not patch it again, + # so directly call `patch_run` here to avoid the `orig_fn is not patched by the operator patcher. + result = operator_patcher.OperatorPatcherContext.patch_run(orig_fn, *args, **kwargs) + else: + result = orig_fn(*args, **kwargs) + if need_pop: + key, _ = tracer.module_stack.popitem(last=True) + assert key == _scope.module_path, f" Unexpected key {key}" + return result + + return wrapped + + +def is_autograd_apply(func) -> bool: + return getattr(func, '__name__', None) == 'apply' \ + and orig_func.isinstance(getattr(func, '__self__', None), Type) and issubclass(func.__self__, torch.autograd.Function) + + +def create_wrapped_autograd_apply(default_tracer: 'ConcreteTracer'): + @classmethod + @functools.wraps(orig_func.torch_agfunc_apply) + def agfunc_apply_wrapper(clz, *args, **kwargs): + if clz not in default_tracer.autograd_functions_mapping: + default_tracer.autograd_functions_mapping[clz] = torch._C._FunctionBase.__dict__['apply'].__get__(None, clz) + global TEMP_CALL_ORIGIN + if TEMP_CALL_ORIGIN: + return default_tracer.autograd_functions_mapping[clz](*args, **kwargs) + with do_temp_call_origin(): + tracers = set() + + def detect_tracer(obj): + if isinstance(obj, cct.ConcreteProxy): + tracers.add(obj.tracer) + + pytree_utils.tree_map(detect_tracer, args) + pytree_utils.tree_map(detect_tracer, kwargs) + + if len(tracers) > 1: + raise Exception('more than 1 tracer detected. please report the issue') + + tracer = None if len(tracers) == 0 else tracers.pop() + if tracer is None: + return default_tracer.autograd_functions_mapping[clz](*args, **kwargs) + else: + assert tracer == default_tracer + return default_tracer.create_proxy('call_function', default_tracer.autograd_functions_mapping[clz], args, kwargs) + return agfunc_apply_wrapper + + +class map_wrapper_clz: + # used to track the original class + _fx_wrapped_ori_clz = orig_func.map + + def __new__(cls, the_func, *iterables: Any): + global TEMP_CALL_ORIGIN + if TEMP_CALL_ORIGIN: + return orig_func.map(the_func, *iterables) + else: + # get the result first + results = orig_func.list() + for args in zip(*iterables): + results.append(the_func(*args)) + # if there contains proxy in results, then create a proxy with tuple as target + with do_temp_call_origin(): + tracers = set() + + def detect_tracer(obj): + if isinstance(obj, cct.ConcreteProxy): + tracers.add(obj.tracer) + + pytree_utils.tree_map(detect_tracer, results) + + if len(tracers) > 1: + raise Exception('more than 1 tracer detected. please report the issue') + elif len(tracers) == 1: + return next(iter(tracers)).create_proxy('call_function', orig_func.tuple, (results,), {}) + + return orig_func.tuple(results) + + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(orig_func.map)) + + def __hash__(self): + return id(self) + +wrapped_cls_to_orig_cls[map_wrapper_clz] = orig_func.map + + +class range_wrapper_clz: + # used to track the original class + _fx_wrapped_ori_clz = orig_func.range + + def __new__(cls, *args): + assert 1 <= orig_func.len(args) <= 3 + args = (arg.value if orig_func.isinstance(arg, cct.ConcreteProxy) else arg for arg in args) + return orig_func.range(*args) + + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(orig_func.range)) + + def __hash__(self): + return id(self) + +wrapped_cls_to_orig_cls[range_wrapper_clz] = orig_func.range + + +class enumerate_wrapper_clz: + # used to track the original class + _fx_wrapped_ori_clz = orig_func.enumerate + + def __new__(cls, iterable, start=0): + count = start + for elem in iterable: + if orig_func.isinstance(elem, cct.ConcreteProxy) and orig_func.isinstance(elem.value, (orig_func.int, str)): + yield count, elem.value + else: + yield count, elem + count += 1 + + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(orig_func.enumerate)) + + def __hash__(self): + return id(self) + +wrapped_cls_to_orig_cls[enumerate_wrapper_clz] = orig_func.enumerate + + +class type_wrapper_clz: + # used to track the original class + _fx_wrapped_ori_clz = orig_func.type + + def __new__(cls, obj_or_name, *args): + # case 1: class type(name, bases, dict, **kwds) + if orig_func.len(args) > 0: + assert orig_func.len(args) == 2 + base_cls, cls_dict = args[0], args[1] + # if it is a wrapped class, replace it to the original one + base_cls = orig_func.tuple(bs._fx_wrapped_ori_clz if hasattr(bs, '_fx_wrapped_ori_clz') else bs for bs in base_cls) + return orig_func.type(obj_or_name, base_cls, cls_dict) + # case 2: class type(object) + else: + orig_type = orig_func.type(obj_or_name) + if issubclass(orig_type, cct.ConcreteProxy): + return orig_func.type(obj_or_name.value) + else: + return orig_type + + def __eq__(self, __o: object) -> bool: + return id(__o) in (id(self), id(orig_func.type)) + + def __hash__(self): + return id(self) + +wrapped_cls_to_orig_cls[type_wrapper_clz] = orig_func.type + + +@functools.wraps(orig_func.torch_assert) +def torch_assert_wrapper(condition, message): + if orig_func.isinstance(condition, cct.ConcreteProxy): + condition = condition.value + return orig_func.isinstance(condition, message) + + +@functools.wraps(orig_func.isinstance) +def isinstance_wrapper(instance, clz): + if orig_func.type(clz) in (slice, tuple, list, orig_func.slice, orig_func.tuple, orig_func.list): + clz_wrapped = [] + for wrapped_type, orig_type in wrapped_cls_to_orig_cls.items(): + if wrapped_type in clz: + clz_wrapped.append(orig_type) + clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in wrapped_cls_to_orig_cls)) + # use _orig_isinstance(clz, Iterable) will cause an endless recursive loop + for cls in (object, cct.ConcreteProxy): + if cls in clz and orig_func.isinstance(instance, cls): + return True + if orig_func.isinstance(instance, cct.ConcreteProxy): + return orig_func.isinstance(instance.value, clz) + else: + return orig_func.isinstance(instance, clz) + else: + if clz in (object, cct.ConcreteProxy): + return orig_func.isinstance(instance, clz) + if clz in wrapped_cls_to_orig_cls: + clz = wrapped_cls_to_orig_cls[clz] + if orig_func.isinstance(instance, cct.ConcreteProxy): + instance = instance.value + return orig_func.isinstance(instance, clz) + + +@functools.wraps(orig_func.issubclass) +def issubclass_wrapper(subclass, clz): + if orig_func.type(clz) in (slice, tuple, list, orig_func.slice, orig_func.tuple, orig_func.list): + clz_wrapped = [] + for wrapped_type, orig_type in wrapped_cls_to_orig_cls.items(): + if wrapped_type in clz: + clz_wrapped.append(orig_type) + clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in wrapped_cls_to_orig_cls)) + return orig_func.issubclass(subclass, clz) + else: + if clz in wrapped_cls_to_orig_cls: + clz = wrapped_cls_to_orig_cls[clz] + return orig_func.issubclass(subclass, clz) + + +@functools.wraps(orig_func.getattr) +def getattr_wrapper(obj, *args): + if not 1 <= orig_func.len(args) <= 2: + raise Exception() + args = orig_func.list(args) + if orig_func.isinstance(args[0], cct.ConcreteProxy): + args[0] = args[0].value + return orig_func.getattr(obj, *args) + + +# NOTE: not in used, still need some test +@functools.wraps(orig_func.id) +def id_wrapper(obj): + if hasattr(obj, '_fx_wrapped_ori_clz'): + return orig_func.id(orig_func.getattr(obj, '_fx_wrapped_ori_clz')) + else: + return orig_func.id(obj) + + +def autowrap_check(tracer: 'ConcreteTracer', frame_dict : Dict[str, Any]): + """ + Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. + This method searches a scope for them and patches them if found. + """ + if tracer.patcher.visit_once(frame_dict): + for name, value in frame_dict.items(): + if callable(value) and not name.startswith('_') and (getattr(orig_func, name, None) is not value): + if value in tracer.wrapped_leaf: + tracer.patcher.patch(frame_dict, name, tracer.wrapped_leaf[value][1]) + if is_autograd_apply(value): + if value.__self__ not in tracer.autograd_functions_mapping: + tracer.autograd_functions_mapping[value.__self__] = create_wrapped_leaf_func(value) + tracer.patcher.patch(frame_dict, name, tracer.autograd_functions_mapping[value.__self__]) diff --git a/nnscaler/graph/parser/register.py b/nnscaler/graph/parser/register.py index 9cd50124..fbb4dec8 100644 --- a/nnscaler/graph/parser/register.py +++ b/nnscaler/graph/parser/register.py @@ -10,7 +10,7 @@ from torch import ScriptFunction from nnscaler.graph.function.dimops import IRDimops, OpAnno -from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import is_autograd_apply +from nnscaler.graph.parser.fx.concrete_trace_utils.wrap_utils import is_autograd_apply from nnscaler.ir.operator import IRTensor, IRFwOperation _logger = logging.getLogger(__name__) diff --git a/tests/utils.py b/tests/utils.py index c6bd6972..717c033b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -105,7 +105,7 @@ def replace_all_device_with(device='cpu', force=False): yield return - from nnscaler.graph.parser.fx.concrete_trace_utils.concrete_tracer import ConcreteTracer + from nnscaler.graph.parser.fx.concrete_trace_utils import wrap_utils orig_to = torch.Tensor.to orig_cuda = torch.Tensor.cuda @@ -182,17 +182,17 @@ def patched_cpu(self, *args, **kwargs): # patch concrete tracer's autowrap leaf function for tf_name, fn in old_tensor_constructors.items(): - leaf_info = ConcreteTracer.default_autowrap_leaf_function.pop(fn, None) + leaf_info = wrap_utils.default_autowrap_leaf_function.pop(fn, None) if leaf_info: - ConcreteTracer.default_autowrap_leaf_function[ + wrap_utils.default_autowrap_leaf_function[ patched_tensor_constructors[tf_name] ] = leaf_info yield finally: for tf_name, fn in patched_tensor_constructors.items(): - leaf_info = ConcreteTracer.default_autowrap_leaf_function.pop(fn, None) + leaf_info = wrap_utils.default_autowrap_leaf_function.pop(fn, None) if leaf_info: - ConcreteTracer.default_autowrap_leaf_function[ + wrap_utils.default_autowrap_leaf_function[ old_tensor_constructors[tf_name] ] = leaf_info for tf_name, fn in old_tensor_member_constructors.items(): From 11aa2c7cf2d22881b288fde5d06441e20aeb477f Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 28 Aug 2024 06:50:26 +0000 Subject: [PATCH 1717/1892] Merged PR 2250: bugfix: submodule buffer persistent bugfix: submodule buffer persistent --- nnscaler/graph/parser/fx/parser.py | 6 ++- tests/graph/tracer/test_buffer.py | 64 +++++++++++++++++++++++++ tests/graph/tracer/test_ctxt_manager.py | 2 +- 3 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 tests/graph/tracer/test_buffer.py diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index d9800264..608cee42 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -309,7 +309,11 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, if tensor.requires_grad: tensor.as_param() else: - persistent = node.name not in module._non_persistent_buffers_set + direct_module = module + full_qualified_name = node.target.split('.') + for name in full_qualified_name[:-1]: # last one is the attribute name + direct_module = getattr(direct_module, name) + persistent = full_qualified_name[-1] not in direct_module._non_persistent_buffers_set tensor.as_buffer(persistent=persistent) frame.add_attr(tensor, concrete_value, node.target) # the case that the parameter is consumed multiple times and registered previously diff --git a/tests/graph/tracer/test_buffer.py b/tests/graph/tracer/test_buffer.py new file mode 100644 index 00000000..60ecf086 --- /dev/null +++ b/tests/graph/tracer/test_buffer.py @@ -0,0 +1,64 @@ +import tempfile + +import torch + +from nnscaler import parallelize, ComputeConfig +from tests.utils import replace_all_device_with + +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode +from .test_ctxt_manager import TestModule + + +class BufferModuleNested(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("sub_buffer0_u", torch.tensor(1.0), persistent=False) + self.register_buffer("sub_buffer0_p", torch.tensor([2.0]), persistent=True) + + def forward(self, x): + return x + self.sub_buffer0_u + self.sub_buffer0_p + + +class BufferModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.test_module = TestModule() + self.buffer_module = BufferModuleNested() + self.register_buffer("root_buffer0_u", torch.tensor([1.0]), persistent=False) + self.register_buffer("root_buffer0_p", torch.tensor(2.0), persistent=True) + + def forward(self, x, position_ids): + x = self.test_module(x, position_ids) + x = self.buffer_module(x) + return x + self.root_buffer0_u + self.root_buffer0_p + + + +@replace_all_device_with('cpu') +def test_buffer(): + with tempfile.TemporaryDirectory() as tempdir: + model = BufferModule() + dummy_input = {'x': torch.rand(1, 100, 128), 'position_ids': torch.arange(0, 100, dtype=torch.int64).reshape(1, 100)} + parallelize(model, dummy_input, 'dp', ComputeConfig(1, 1), gen_savedir=tempdir, load_module=False) + # code will look like: + # self.register_buffer('test_module_rotary_emb_inv_freq_94', torch.empty((64,), dtype=torch.float32), persistent=False) + # self.register_buffer('buffer_module_sub_buffer0_u_114', torch.empty((), dtype=torch.float32), persistent=False) + # self.register_buffer('buffer_module_sub_buffer0_p_116', torch.empty((1,), dtype=torch.float32), persistent=True) + # self.register_buffer('root_buffer0_u_118', torch.empty((1,), dtype=torch.float32), persistent=False) + # self.register_buffer('root_buffer0_p_120', torch.empty((), dtype=torch.float32), persistent=True) + + assert _gencode_contains(tempdir, BufferModule, 0, + r'self.register_buffer\(\'test_module_rotary_emb_inv_freq_\d+\', torch.empty\(\(64,\), dtype=torch.float32\), persistent=False\)' + ) + assert _gencode_contains(tempdir, BufferModule, 0, + r'self.register_buffer\(\'buffer_module_sub_buffer0_u_\d+\', torch.empty\(\(\), dtype=torch.float32\), persistent=False\)' + ) + assert _gencode_contains(tempdir, BufferModule, 0, + r'self.register_buffer\(\'buffer_module_sub_buffer0_p_\d+\', torch.empty\(\(1,\), dtype=torch.float32\), persistent=True\)' + ) + assert _gencode_contains(tempdir, BufferModule, 0, + r'self.register_buffer\(\'root_buffer0_u_\d+\', torch.empty\(\(1,\), dtype=torch.float32\), persistent=False\)' + ) + assert _gencode_contains(tempdir, BufferModule, 0, + r'self.register_buffer\(\'root_buffer0_p_\d+\', torch.empty\(\(\), dtype=torch.float32\), persistent=True\)' + ) diff --git a/tests/graph/tracer/test_ctxt_manager.py b/tests/graph/tracer/test_ctxt_manager.py index 914dfe1a..9d214e9f 100644 --- a/tests/graph/tracer/test_ctxt_manager.py +++ b/tests/graph/tracer/test_ctxt_manager.py @@ -38,7 +38,7 @@ def __init__(self) -> None: self.rotary_emb = LlamaRotaryEmbedding(128) self.fc1 = torch.nn.Linear(128, 128) self.fc2 = torch.nn.Linear(128, 128) - + def forward(self, x, position_ids): hidden = self.fc1(x) cos, sin = self.rotary_emb(hidden, position_ids) From ae9d0f9fd1c241f1c492efa3073e7a28ec28fc3c Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Thu, 29 Aug 2024 02:40:16 +0000 Subject: [PATCH 1718/1892] Merged PR 2248: Minitrainer: refine progress bar and load_type --- nnscaler/cli/trainer.py | 180 +++++++++++++++++++++++------------ nnscaler/cli/trainer_args.py | 39 ++++++-- tests/cli/test_train_args.py | 25 +++++ tests/cli/test_trainer.py | 21 ++++ 4 files changed, 194 insertions(+), 71 deletions(-) create mode 100644 tests/cli/test_train_args.py diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index efcc9a88..77977f42 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -7,6 +7,7 @@ import warnings import shutil import logging +import time import torch import torch.distributed @@ -40,14 +41,7 @@ @dataclass class TrainStatus: best_loss = float('inf') - epoch: int = 0 - # used for resuming training - # it is the index of the next batch in the current epoch - # i means the i-1 batch is done, and we should resume from ith batch - # for example - # 0 means the epoch is not started - # 1 means the 0th batch is done, and we should resume from 1st batch - next_batch_index: int = 0 + num_train_steps_done: int = 0 @dataclass @@ -92,14 +86,11 @@ def __init__(self, self.train_status = TrainStatus() self.dummy_input = None self.total_train_steps_per_epoch = None + self.max_train_steps = None self.loggers = [] self.hook = None self._setup() - @property - def num_train_steps(self): - return self.train_status.epoch * self.total_train_steps_per_epoch + self.train_status.next_batch_index - def _fix_input(self, input): if isinstance(input, dict): return {k: self._fix_input(v) for k, v in input.items()} @@ -223,9 +214,22 @@ def _create_model(): torch.distributed.barrier() self.rank = torch.distributed.get_rank() + self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq if len(self.dataloader['train']) % self.train_args.update_freq != 0: self.total_train_steps_per_epoch += 1 # will add extra dummy batches + + if self.train_args.max_epochs and self.train_args.max_train_steps: + self.max_train_steps = min( + self.total_train_steps_per_epoch * self.train_args.max_epochs, + self.train_args.max_train_steps + ) + elif self.train_args.max_train_steps: + self.max_train_steps = self.train_args.max_train_steps + else: + assert self.train_args.max_epochs, "max_epochs or max_train_steps should be specified" + self.max_train_steps = self.total_train_steps_per_epoch * self.train_args.max_epochs + _, self.sync_group = self.train_args.compute_config.get_sync_group() self.model = pmodel_class() self.model.cuda() @@ -283,7 +287,7 @@ def _log_finalize(self): logger.finalize() def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None, *, tag: Optional[str] = None): - step = step or self.num_train_steps + step = step or self.train_status.num_train_steps_done for logger in self.loggers: logger.log_metrics(metrics, step, tag=tag) @@ -345,6 +349,23 @@ def _log_mem_stats(self, tag=None): 'ram_gb_used': ram_gb_used, }, tag=tag) + def _format_metrics(self, epoch_desc, idx, metrics: Dict[str, Union[float,int]]): + ndigits = len(str(self.total_train_steps_per_epoch)) + idx_format = f"0{ndigits}d" + int_format = '' + float_format = '.3f' + metris_str = ', '.join( + [ + f"{k}={format(v, float_format if isinstance(v, float) else int_format)}" + for k, v in metrics.items() + ] + ) + if idx is not None: + step_str = f'{format(idx, idx_format)}/{self.total_train_steps_per_epoch} ' + else: + step_str = f'' + return f"{epoch_desc}: {step_str}{metris_str}" + def _save_checkpoint(self, loss): checkpoint_config = self.train_args.checkpoint @@ -353,9 +374,13 @@ def _save_checkpoint(self, loss): return torch.distributed.barrier() - logger.info(f"Saving checkpoint after {self.num_train_steps} steps with loss={loss:.3f}.") + logger.info(f"Saving checkpoint after {self.train_status.num_train_steps_done} steps with loss={loss:.3f}.") save_dir = Path(checkpoint_config.save_dir) save_dir.mkdir(parents=True, exist_ok=True) + current_epoch = self.train_status.num_train_steps_done // self.total_train_steps_per_epoch + # the last step of the epoch + if self.train_status.num_train_steps_done % self.total_train_steps_per_epoch == 0: + current_epoch -= 1 if checkpoint_config.save_type == 'sharded': model_state_dict= self.model.state_dict() @@ -378,8 +403,8 @@ def _save_checkpoint(self, loss): } self.hook.on_save_checkpoint(self, state_dict) ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( - epoch=self.train_status.epoch, - step=self.num_train_steps, + epoch=current_epoch, + step=self.train_status.num_train_steps_done, rank=self.rank, ) logger.info(f"Saving checkpoint to {str(ckpt_file.parent)}") @@ -408,8 +433,8 @@ def _save_checkpoint(self, loss): logger.info(f"Best loss updated: {self.train_status.best_loss:.3f} -> {loss:.3f}") logger.info(f"Saving checkpoint as the best checkpoint.") best_file = save_dir / CHECKPOINT_BEST_FILE_FORMAT.format( - epoch=self.train_status.epoch, - step=self.num_train_steps, + epoch=current_epoch, + step=self.train_status.num_train_steps_done, rank=self.rank, ) best_file.parent.mkdir(parents=True, exist_ok=True) @@ -505,27 +530,23 @@ def _fix_batches(self, batches): return batches, is_dummy_batch def train(self): - assert self.train_status.next_batch_index <= self.total_train_steps_per_epoch, \ - f"next_batch_index({self.train_status.next_batch_index}) " \ - f"should not be larger than total_train_steps_per_epoch ({self.total_train_steps_per_epoch})" - + logger.info('Training...') # reset peak memory stats before training # So that we can get accurate peak memory usage for each step torch.cuda.reset_peak_memory_stats() - if self.train_status.next_batch_index == self.total_train_steps_per_epoch: - self.train_status.epoch += 1 - self.train_status.next_batch_index = 0 + if self.train_status.num_train_steps_done >= self.max_train_steps: + logger.info(f"Training is skipped: already done.") + return + + start_epoch = self.train_status.num_train_steps_done // self.total_train_steps_per_epoch - next_batch_index = self.train_status.next_batch_index self.hook.on_train_start(self) - for epoch in range(self.train_status.epoch, self.train_args.max_epochs or sys.maxsize): + for epoch in range(start_epoch, self.train_args.max_epochs or sys.maxsize): self.dataloader['train'].sampler.set_epoch(epoch) torch.distributed.barrier() - self.train_status.epoch = epoch - self.train_status.next_batch_index = next_batch_index self.hook.on_epoch_start(self, epoch) self.train_epoch(epoch) @@ -534,11 +555,10 @@ def train(self): if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'epoch': self.lr_scheduler.step() - if self.train_args.max_train_steps and self.num_train_steps >= self.train_args.max_train_steps: + if self.train_args.max_train_steps and self.train_status.num_train_steps_done >= self.train_args.max_train_steps: logger.info(f"Reached max train steps({self.train_args.max_train_steps}): Training is done.") break - next_batch_index = 0 else: # not break from for loop, which means not finished with max_train_steps # finished with max_epochs logger.info(f"Reached max_epochs({self.train_args.max_epochs}): Training is done.") @@ -566,6 +586,7 @@ def _validate(self, step_stat: _StepStat): step_stat.val_loss = step_stat.train_loss return step_stat.val_loss + logger.info(f"Validating...") data_iter = enumerate(self._global_batch_iterator(stage='val')) if self.rank == 0: total_val_steps_per_epoch = len(self.dataloader['val']) // self.train_args.update_freq @@ -576,7 +597,7 @@ def _validate(self, step_stat: _StepStat): total=total_val_steps_per_epoch, initial=0, desc=f'Validating', - disable=not self.train_args.enable_progress_bar + disable=not self.train_args.enable_progress_bar, ) loss_sum = 0.0 @@ -611,7 +632,10 @@ def _validate(self, step_stat: _StepStat): self.hook.on_val_end(self, loss) step_stat.val_loss = loss - self.log_metrics(asdict(step_stat), self.num_train_steps, tag='val') + val_metrics = asdict(step_stat) + self.log_metrics(val_metrics, tag='val') + if self.rank == 0 and self.train_args.enable_log_progress: + logger.info(self._format_metrics(f'Validation', None, val_metrics)) return step_stat.val_loss def train_epoch(self, epoch): @@ -619,25 +643,42 @@ def train_epoch(self, epoch): VAL_STATUS_VAL = 1 # validated but not saved VAL_STATUS_SAVE = 2 # validated and saved has_validated = VAL_STATUS_NO # 3 states - resume_from_idx = self.train_status.next_batch_index + + resume_from_idx = self.train_status.num_train_steps_done % self.total_train_steps_per_epoch data_iter = enumerate(self._global_batch_iterator(num_skip_first=resume_from_idx)) + + max_epoch = self.max_train_steps // self.total_train_steps_per_epoch + if self.max_train_steps % self.total_train_steps_per_epoch != 0: + max_epoch += 1 + ndigits = len(str(max_epoch)) + epoch_format = f"0{ndigits}d" + epoch_desc = f'Epoch {format(epoch, epoch_format)}' + if self.rank == 0: - data_iter = tqdm( - data_iter, + progress = tqdm( + None, total=self.total_train_steps_per_epoch, initial=resume_from_idx, - desc=f'Epoch {epoch:04d}', - disable=not self.train_args.enable_progress_bar + desc=epoch_desc, + disable=not self.train_args.enable_progress_bar, ) + else: + progress = None step_stat: Optional[_StepStat] = None - for idx, batches in data_iter: + for i, batches in data_iter: + idx = i + resume_from_idx + + if self.rank == 0: + # looks manually update progress bar is easier + # than using tqdm directly + # the difference is we update progress bar at the beginning of the loop + # instead of the end of the loop + progress.update(1) + step_start_at = time.perf_counter() step_stat = _StepStat() + step_metrics = {} has_validated = VAL_STATUS_NO - # the current batch is idx + resume_from_idx - # `+1` because the next_batch_index is the index of the next batch - # all save_checkpoint will be done at the end of the loop with correct next_batch_index - self.train_status.next_batch_index = idx + resume_from_idx + 1 num_batches = len(batches) batches, is_dummy_batch = self._fix_batches(batches) @@ -659,8 +700,6 @@ def train_epoch(self, epoch): loss = aggregated_outputs.loss_sum step_stat.train_loss = loss self.hook.after_aggregate_train_step_outputs(self, aggregated_outputs, loss, idx) - if self.rank == 0: - data_iter.set_postfix({'loss': loss}) self.hook.before_sync_grad(self) # actually `sync_shard_grad` is no-op here @@ -702,41 +741,56 @@ def train_epoch(self, epoch): if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'step': self.lr_scheduler.step() + self.train_status.num_train_steps_done += 1 self._log_mem_stats(tag='train') - self.log_metrics( - {k:v for k, v in asdict(step_stat).items() if v is not None}, - self.num_train_steps, - tag='train' - ) + step_metrics = {k:v for k, v in asdict(step_stat).items() if v is not None} + step_metrics['train_wall'] = time.perf_counter() - step_start_at + self.log_metrics(step_metrics, tag='train') + if self.rank == 0: + progress.set_postfix(step_metrics) + if self.train_args.enable_log_progress \ + and self.train_status.num_train_steps_done % self.train_args.log_progress_every_n_train_steps == 0: + logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) + step_metrics = {} # validate and save checkpoint if self.train_args.checkpoint.every_n_train_steps and \ - self.num_train_steps % self.train_args.checkpoint.every_n_train_steps == 0: + self.train_status.num_train_steps_done % self.train_args.checkpoint.every_n_train_steps == 0: self._validate_and_save(step_stat) has_validated = VAL_STATUS_SAVE - if self.train_args.max_train_steps and self.num_train_steps >= self.train_args.max_train_steps: + # max_train_steps is reached + if self.train_status.num_train_steps_done >= self.max_train_steps: + if step_metrics and self.train_args.enable_log_progress: + logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) + step_metrics = {} if not has_validated: self._validate_and_save(step_stat) has_validated = VAL_STATUS_SAVE + if self.rank == 0: + # disable refresh the progress bar to avoid redundant progress bar + progress.leave = False + progress.close() break if not has_validated and self.train_args.val_every_n_train_steps and \ - self.num_train_steps % self.train_args.val_every_n_train_steps == 0: + self.train_status.num_train_steps_done % self.train_args.val_every_n_train_steps == 0: self._validate(step_stat) has_validated = VAL_STATUS_VAL - # import time - # time.sleep(0.2) - else: # not finished with max_train_steps + + # time.sleep(1) + else: + # Do per-epoch operations here. + # if the loop exits with `break` (max_train_steps is reached) + # those operations have done in the loop if step_stat is None: - return # no train step runs. No need to save checkpoint - if has_validated < VAL_STATUS_SAVE and \ - self.train_args.max_epochs == self.train_status.epoch + 1 \ - or (self.train_args.checkpoint.every_n_epochs and \ - (self.train_status.epoch + 1) % self.train_args.checkpoint.every_n_epochs == 0): + return # no train step runs. Nothing to do. + if has_validated < VAL_STATUS_SAVE \ + and self.train_args.checkpoint.every_n_epochs \ + and (epoch + 1) % self.train_args.checkpoint.every_n_epochs == 0: self._validate_and_save(step_stat) has_validated = VAL_STATUS_SAVE - if not has_validated and self.train_args.val_every_n_epochs and \ - (self.train_status.epoch + 1) % self.train_args.val_every_n_epochs == 0: + if not has_validated and self.train_args.val_every_n_epochs \ + and (epoch + 1) % self.train_args.val_every_n_epochs == 0: self._validate(step_stat) has_validated = VAL_STATUS_VAL diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index eb4de2f6..f17b13c5 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -6,6 +6,7 @@ import logging import copy import os +import builtins import torch import torch.utils @@ -33,14 +34,28 @@ def load_type(type_name: str): if callable(type_name): # a function or class return type_name - parts = type_name.rsplit('.', 1) - if len(parts) == 1: - nm = __builtins__ - type_name = parts[0] - else: - namespace, type_name = parts - nm = importlib.import_module(namespace) - return getattr(nm, type_name) + parts = type_name.split('.') + + # s: the number of parts to be the namespace + # s == 0: use builtins + # so the range() part includes 0 (with stop=-1) + for s in range(len(parts) - 1, -1, -1): + if s == 0: + nm = builtins + else: + namespace = '.'.join(parts[:s]) + try: + nm = importlib.import_module(namespace) + break + except (ImportError, ModuleNotFoundError): + pass + + try: + for i in range(s, len(parts)): + nm = getattr(nm, parts[i]) + return nm + except AttributeError as e: + raise RuntimeError(f"Failed to load type {type_name}") from e @dataclass @@ -318,6 +333,10 @@ class TrainerArgs: val_every_n_epochs: Optional[int] = 1 enable_progress_bar: bool = True + # if progress_bar is disabled (enable_progress_bar is False), + # the frequency to print the training progress + # validation metrics will also be printed if it is not None. + log_progress_every_n_train_steps: Optional[int] = 100 seed: Optional[int] = None # environment initialization function @@ -463,6 +482,10 @@ def scaling_factor(self): def update_freq(self): return self.global_batch_size // self.micro_batch_size // self.scaling_factor + @property + def enable_log_progress(self): + return not self.enable_progress_bar and self.log_progress_every_n_train_steps + @property def param_dtype(self) -> torch.dtype: return _PRECISION_MAP[self.precision['param']] diff --git a/tests/cli/test_train_args.py b/tests/cli/test_train_args.py new file mode 100644 index 00000000..b6c9b2ad --- /dev/null +++ b/tests/cli/test_train_args.py @@ -0,0 +1,25 @@ +import pytest + +import nnscaler +from nnscaler.cli.trainer_args import load_type + + +def test_load_type(): + assert load_type(int) == int + assert load_type('int') == int + assert load_type(int.to_bytes) == int.to_bytes + assert load_type('int.to_bytes') == int.to_bytes + assert load_type('nnscaler.cli.trainer_args.TrainerArgs') == nnscaler.cli.trainer_args.TrainerArgs + assert load_type('nnscaler.cli.trainer_args.TrainerArgs.from_cli') == nnscaler.cli.trainer_args.TrainerArgs.from_cli + + with pytest.raises(RuntimeError): + load_type('not_exist_name') + + with pytest.raises(RuntimeError): + load_type('not_exist_namespace.not_exist_name') + + with pytest.raises(RuntimeError): + load_type('nnscaler.not_exist_name') + + with pytest.raises(RuntimeError): + load_type('nnscaler.cli.trainer_args.TrainerArgs.not_exist_name') diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index bd8036d5..a92cbea8 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -134,6 +134,27 @@ def trainer_resume_worker(save_dir, save_type, bf16): ckpt0_files0 = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} assert len(ckpt0_files0)/4 == min(30, trainer.total_train_steps_per_epoch * 2) + 2 # 2 for best/last + # resume from last without update max_epochs + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16' if bf16 else 'none', + '--optimizer.type', optimizer_type, + '--max_epochs', '2', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.train() + ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} + # nothing should be updated in this case. + assert ckpt0_files0 == ckpt0_files0_x + # create merged checkpoint ckpt1_savedir = save_dir / 'ckpt1' ckpt1_savedir.mkdir(parents=True, exist_ok=True) From 1472c0a2f2b415564a448691e19ee3ed72e453f0 Mon Sep 17 00:00:00 2001 From: "Xin Ji (CSI Interfusion Co Ltd)" Date: Thu, 29 Aug 2024 22:56:52 +0000 Subject: [PATCH 1719/1892] Merged PR 2208: auto op partition testing auto op partition testing --- nnscaler/graph/function/function.py | 1 - utility/verify_ops/verify_dimops.py | 462 ++++++++++++++++++ utility/verify_ops/verify_graph_operations.py | 161 ++++++ utility/verify_ops/verify_op.md | 347 +++++++++++++ 4 files changed, 970 insertions(+), 1 deletion(-) create mode 100644 utility/verify_ops/verify_dimops.py create mode 100644 utility/verify_ops/verify_graph_operations.py create mode 100644 utility/verify_ops/verify_op.md diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index f6790e4d..1a605aea 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -1895,7 +1895,6 @@ def Repeat(tensor, repeats: _VariadicInt, *arg_repeats, signature = None): """ torch.Tensor.repeat(*sizes) """ - signature = 'torch.ops.aten.repeat' if isinstance(repeats, (list, tuple)) or ( isinstance(repeats, IRObject) and isinstance(repeats.value, (list, tuple)) ): diff --git a/utility/verify_ops/verify_dimops.py b/utility/verify_ops/verify_dimops.py new file mode 100644 index 00000000..f9c56eca --- /dev/null +++ b/utility/verify_ops/verify_dimops.py @@ -0,0 +1,462 @@ +""" +This test verifies the correctness of an operator's annotation by running its distributed versions. +The processing pipeline is: +1. generate the input and calculate the output for the operator on a single device +2. construct the partition search space based on its annotation +3. for each partition choice, nnscaler will generate runnable code with communication adapters automatically +4. compare each distributed result with single device version, the difference should be less than a threshold +NOTE: only consider partitioning along one dimension currently +""" + +import os +from typing import Dict, List, Tuple, Any, Union +from dataclasses import dataclass, field +import logging +import subprocess +import torch + +from nnscaler.graph.function.dimops import IRDimops, OpAnno, DimAnno +from nnscaler.ir.cten import IRTensor, IRObject + + +logger = logging.getLogger(__name__) + + +_SINGLE_GPU_TEST_FILE = "single_gpu_test.py" +_TWO_GPUS_TEST_FILE = "two_gpus_test.py" + +module_template_common = """ +import os +import numpy +import sys +import torch +import nnscaler + +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType +{import_cumsomized_func} + +import nnscaler.graph +import nnscaler.graph.function +import nnscaler.graph.function.wrapnn + +import torch +import numpy as np +import random + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, {args}): + # Add clone to resolve the issue: + # a leaf Variable that requires grad is being used in an in-place operation. + {clone_args} + + {func_sig_call} + + out = 0 + for one_out in [{outputs}]: + if not isinstance(one_out, torch.Tensor): + continue + out += torch.sum(one_out) + return out + +model = TestModule() #.to(torch.float16) +""" + +module_template_single_main = """ +# Load inputs from file, ensuring inputs.pt is always a tuple, even when there's only one input +{args}, = torch.load('{func_sig}_inputs.pt', map_location=torch.device('cuda:0')) + +model = model.cuda() + +single_loss = model({args}) +single_loss.backward() + +grad_tensors = {grad_tensors} +torch.save([grad_tensors, single_loss], '{func_sig}_loss_single.pt') +print('single gpu loss: ', single_loss) +""" + +module_template_single = module_template_common + module_template_single_main + +module_template_parallel_main = """ +nnscaler.init() +rank_id = torch.distributed.get_rank() + +{args}, = torch.load('{func_sig}_inputs.pt', map_location=torch.device(f'cuda:{{rank_id}}')) + +def policy(graph: IRGraph, resource) -> IRGraph: + ngpus = 2 + partitioned = False + + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == '{func_sig}': + print('Partitioned node: ', node) + sub_nodes = graph.partition( + node, node.algorithms('dim'), idx={idx}, dim={dim}, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + + assert partitioned, f'No node is partitioned for {func_sig}.' + return graph + +parallel_model = parallelize( + model, + dummy_forward_args={dummy_input_str}, + pas_policy=policy, + compute_config=ComputeConfig(2, 2), + reuse=ReuseType.OVERRIDE +) + +parallel_model.train() + +parallel_loss = parallel_model({args}) +parallel_loss.backward() + +grad_tensors = {grad_tensors} +torch.save([grad_tensors, parallel_loss], '{func_sig}_loss_para_'+str(rank_id)+'.pt') +print('two gpus loss: ', parallel_loss) +""" + +module_template_parallel = module_template_common + module_template_parallel_main + + +@dataclass +class TensorInfo: + value_form: str # 'shape' or 'value' + value: Union[Tuple[int], Any] + dtype: torch.dtype = torch.float32 + requires_grad: bool = True + + # make TensorInfo hashable + def __hash__(self): + value = self.value + if isinstance(value, slice): + value = (value.start, value.stop, value.step) + return hash((self.value_form, value)) + + +@dataclass +class VerifyConfig: + fsig: str + args: List[TensorInfo] + kwargs: Dict[str, Any] + noutputs: int + parti_options: List[Dict[str, int]] + import_customized_func: str = "" + non_grad_indices: List[int] = field(default_factory=list) + + +def _complex(val: Any): + """ + Convert IRObject to concrete value + NOTE: only used for handling kwargs + """ + if isinstance(val, tuple): + return tuple(_complex(t) for t in val) + if isinstance(val, list): + return list(_complex(t) for t in val) + if isinstance(val, dict): + return {_complex(key): _complex(val) for key, val in val.items()} + if isinstance(val, slice): + return slice(_complex(val.start), _complex(val.stop), _complex(val.step)) + if isinstance(val, IRObject): + assert not isinstance(val, IRTensor), "IRTensor should not be in kwargs" + return _complex(val.value) + return val + + +def get_candidate_options( + anno: OpAnno, ins_outs_shape: List[TensorInfo], npartitions: int = 2 +) -> List[Dict[str, int]]: + """ + Get all the feasible partitions specified by the annotation of an operator. + Checks whether the dimension can be divided, and also checks whether the size of the dimension can be evenly divided by the number of partitions + Args: + anno (OpAnno): operator annotation + ins_outs_shape (List[TensorInfo]): input and output shapes + npartitions (int, optional): number of partitions. Defaults to 2. + Returns: + List[Dict[str, int]]: a list of feasible partitions + + """ + all_configs = anno.transform_space() + + candidate_partitions = [] + for idx, dim in all_configs: + if ( + ins_outs_shape[idx].value_form == "shape" + and ins_outs_shape[idx].value[dim] % npartitions == 0 + ): + candidate_partitions.append({"idx": idx, "dim": dim}) + + return candidate_partitions + + +def handle_buffer_parameters(inputs, non_grad_indices): + """ + Detach specified buffer parameters from the computational graph and disable their gradient computation. + This is necessary for parameters that should not participate in the backward pass, + such as statistical parameters in certain layers (e.g., running_mean in normalization layers). + + Args: + inputs (List[torch.Tensor]): The list of input tensors. + non_grad_indices (List[int]): The indices of buffer parameters in the input list. + """ + for idx in non_grad_indices: + if inputs[idx] is not None: + inputs[idx] = inputs[idx].detach() + inputs[idx].requires_grad = False + + +def _create_op_inputs(verify_config: VerifyConfig) -> List[Any]: + """ + Create input tensors/non-tensors for the operator. + The input tensors/non-tensors are only for args, not for kwargs. + Args: + verify_config (VerifyConfig): configuration for verifying the partitions + Returns: + List[Any]: input tensors + """ + torch.manual_seed(0) + inputs = [] + + def process_slice(slice_obj): + start = ( + slice_obj.start.value + if isinstance(slice_obj.start, IRObject) + else slice_obj.start + ) + stop = ( + slice_obj.stop.value + if isinstance(slice_obj.stop, IRObject) + else slice_obj.stop + ) + step = slice_obj.step + return slice(start, stop, step) + + for i, tensor_info in enumerate(verify_config.args): + if tensor_info.value_form == "shape": + # Special handling: For torch. rsqrt, generate random integers between 1 and 10 to avoid invalid values + if verify_config.fsig == "torch.rsqrt": + inputs.append( + torch.randint( + 1, + 10, + tensor_info.value, + dtype=tensor_info.dtype, + requires_grad=tensor_info.requires_grad, + ) + ) + # Special handling: for the first parameter of torch.where which is a boolean mask + elif verify_config.fsig == "torch.where" and i == 0: + inputs.append( + torch.rand( + *tensor_info.value, dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad + ) + > 0.5 + ) + elif verify_config.fsig == "torch.add" and tensor_info.value == (1,): + # Special handling:add in the model generates values that cannot be partitioned + inputs.append(torch.randn(4, dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad)) + else: + if tensor_info.value == (): + inputs.append( + torch.randn( + (), dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad + ).squeeze() + ) + else: + inputs.append( + torch.randn( + *tensor_info.value, + dtype=tensor_info.dtype, + requires_grad=tensor_info.requires_grad, + ) + ) + elif tensor_info.value_form == "value" and isinstance(tensor_info.value, slice): + inputs.append(process_slice(tensor_info.value)) + else: + inputs.append(tensor_info.value) + if verify_config.non_grad_indices: + handle_buffer_parameters(inputs, verify_config.non_grad_indices) + return inputs + + +def verify_partition_options(verify_config: VerifyConfig) -> bool: + errors = [] + try: + logger.info(f"Verifying partitions of {verify_config.fsig}...") + inputs = _create_op_inputs(verify_config) + torch.save(inputs, f"{verify_config.fsig}_inputs.pt") + logger.info(f"Input tensors saved to {verify_config.fsig}_inputs.pt") + + outputs_str = ", ".join([f"_out{i}" for i in range(verify_config.noutputs)]) + + kwargs_str = ", ".join( + [ + f'{k}="{v}"' if isinstance(v, str) else f"{k}={_complex(v)}" + for k, v in verify_config.kwargs.items() + ] + ) + args_str = ", ".join([f"_in{i}" for i, tinfo in enumerate(verify_config.args)]) + func_sig_call = verify_config.fsig + + if args_str: + func_call = f"{outputs_str} = {func_sig_call}({args_str}, {kwargs_str})" + else: + func_call = f"{outputs_str} = {func_sig_call}({kwargs_str})" + + clone_args_right = ", ".join( + [ + f"_in{i}.clone()" + for i, tinfo in enumerate(verify_config.args) + if tinfo.value_form == "shape" + ] + ) + if clone_args_right: + clone_args_left = ", ".join( + [ + f"_in{i}" + for i, tinfo in enumerate(verify_config.args) + if tinfo.value_form == "shape" + ] + ) + clone_args = f"{clone_args_left} = {clone_args_right}" + else: + clone_args = "" + + dummy_input_str = ( + "{" + + ", ".join([f'"_in{i}": _in{i}' for i in range(len(verify_config.args))]) + + "}" + ) + + grad_tensors = ( + "[" + + ", ".join( + [ + f"_in{i}.grad" + for i in range(len(verify_config.args)) + if i not in verify_config.non_grad_indices + and verify_config.args[i].value_form == "shape" + ] + ) + + "]" + ) + module_single_str = module_template_single.format( + import_cumsomized_func=verify_config.import_customized_func, + clone_args=clone_args, + args=args_str, + kwargs=kwargs_str, + func_sig=verify_config.fsig, + func_sig_call=func_call, + outputs=outputs_str, + grad_tensors=grad_tensors, + ) + with open(_SINGLE_GPU_TEST_FILE, "w") as f: + f.write(module_single_str) + logger.info("Generated test code for single gpu and running...") + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_single.pt"]) + subprocess.run(["python", _SINGLE_GPU_TEST_FILE]) + logger.info( + f"Single GPU test completed. Output saved to {verify_config.fsig}_loss_single.pt" + ) + logger.info(f"verify_config: {verify_config}") + logger.info(f"verify_config.parti_options: {verify_config.parti_options}") + + for poption in verify_config.parti_options: + try: + logger.info(f"Verifying the partition {poption}...") + module_para_str = module_template_parallel.format( + import_cumsomized_func=verify_config.import_customized_func, + clone_args=clone_args, + args=args_str, + kwargs=kwargs_str, + func_sig=verify_config.fsig, + func_sig_call=func_call, + outputs=outputs_str, + dummy_input_str=dummy_input_str, + grad_tensors=grad_tensors, + idx=poption["idx"], + dim=poption["dim"], + ) + with open(_TWO_GPUS_TEST_FILE, "w") as f: + f.write(module_para_str) + logger.info("Generated test code for two gpus.") + + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_para_0.pt"]) + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_para_1.pt"]) + subprocess.run( + [ + "torchrun", + "--nproc_per_node=2", + "--nnodes=1", + "--rdzv-endpoint=localhost:23457", + _TWO_GPUS_TEST_FILE, + ] + ) + logger.info( + f"Two GPU test completed. Outputs saved to {verify_config.fsig}_loss_para_0.pt and {verify_config.fsig}_loss_para_1.pt" + ) + single = torch.load(f"{verify_config.fsig}_loss_single.pt") + logger.info( + f"Loading single loss from: {verify_config.fsig}_loss_single.pt" + ) + para0 = torch.load(f"{verify_config.fsig}_loss_para_0.pt") + para1 = torch.load(f"{verify_config.fsig}_loss_para_1.pt") + + logger.info(f"Single loss: {single[1]}") + logger.info(f"Multi-GPU loss (para0): {para0[1]}") + logger.info(f"Multi-GPU loss (para1): {para1[1]}") + + assert torch.allclose( + single[1], para0[1], rtol=1e-3, atol=1e-5 + ), f"Loss mismatch between single and multi-GPU (para0)" + assert torch.equal( + para0[1], para1[1].to(para0[1]) + ), f"Loss mismatch between multi-GPU (para0 and para1)" + + for i in range(len(single[0])): + if single[0][i] is None or para0[0][i] is None: + logger.debug( + f"Skipping comparison for index {i} because it is None" + ) + continue + logger.debug(f"Absolute error: {single[0][i] - para0[0][i]}") + logger.debug( + f"Relative error: {(single[0][i] - para0[0][i]) / single[0][i]}" + ) + assert torch.allclose( + single[0][i], para0[0][i], rtol=1e-3, atol=1e-5 + ), f"Gradient mismatch between single and multi-GPU (para0)" + assert torch.equal( + para0[0][i], para1[0][i].to(para0[0][i]) + ), f"Gradient mismatch between multi-GPU (para0 and para1)" + + logger.info( + f"{verify_config.fsig} of partition {poption} passed the allclose comparison." + ) + except Exception as e: + error_message = f"Partition {poption} failed with error: {str(e)}" + logger.error(error_message) + errors.append(error_message) + if errors: + logger.error("Some partitions failed:") + for error in errors: + logger.error(error) + return False + else: + logger.info( + f"Verified all the partitions of {verify_config.fsig} successfully." + ) + return True + except Exception as e: + logger.exception("Exception occurred during verification process") + raise e diff --git a/utility/verify_ops/verify_graph_operations.py b/utility/verify_ops/verify_graph_operations.py new file mode 100644 index 00000000..680d3fc1 --- /dev/null +++ b/utility/verify_ops/verify_graph_operations.py @@ -0,0 +1,161 @@ +import argparse +import os +import sys +import torch +from nnscaler.graph.function.dimops import DimAnno, IRDimops, OpAnno +from nnscaler.graph.graph import IRGraph +from nnscaler.ir.cten import IRObject, IRTensor +from pathlib import Path +import logging + +from verify_dimops import TensorInfo, get_candidate_options + +_VERIFIED_OPS_FILE_NAME = "verified_ops.pt" +_DEFAULT_CACHE_DIR = Path(os.path.expanduser("~/.cache/nnscaler")) + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger(__name__) + + +def load_verified_ops(outdir: Path): + verified_ops_file = outdir / _VERIFIED_OPS_FILE_NAME + if verified_ops_file.exists(): + logger.info(f"{verified_ops_file} exists, load it.") + return torch.load(verified_ops_file) + else: + logger.info(f"{verified_ops_file} does not exist, start from scratch.") + return set() + + +def save_verified_ops(outdir: Path, verified_ops: set): + verified_ops_file = outdir / _VERIFIED_OPS_FILE_NAME + torch.save(verified_ops, verified_ops_file) + logger.info(f"Verification results saved to {verified_ops_file}") + + +def verify_op_partitions(graph: IRGraph, outdir: Path): + """ + Test if the partitioned ops in the graph are computationally correct. + + Args: + graph (IRGraph): the graph to be verified + outdir (Path): the directory to save the verified ops + + Returns: + None + """ + from verify_dimops import ( + VerifyConfig, + TensorInfo, + verify_partition_options, + ) + + verified_ops = load_verified_ops(outdir) + skipped_nodes = [] + + gnodes = graph.nodes(flatten=True) + for idx, node in enumerate(gnodes): + logger.info(f"node: {node}") + logger.info(f"Verification progress: {idx} / {len(gnodes)}") + if node.isfw() and isinstance(node, IRDimops): + ins_info = [ + ( + TensorInfo("shape", _input.shape) + if isinstance(_input, IRTensor) + else TensorInfo( + "value", + _input.value if isinstance(_input, IRObject) else _input, + ) + ) + for _input in node.inputs() + ] + if not ins_info: + skipped_nodes.append(f"{node.signature} (type: {type(node)})") + logger.info(f"ins_info is empty for node: {node.signature}, skipping.") + continue + + outs_info = [ + ( + TensorInfo("shape", output.shape) + if isinstance(output, IRTensor) + else TensorInfo( + "value", + output.value if isinstance(output, IRObject) else output, + ) + ) + for output in node.outputs() + ] + if (node.signature, tuple(ins_info + outs_info)) in verified_ops: + logger.info(f"{node.signature} has been verified before, skip.") + continue + + logger.info(f"Node annos: {node.signature}, {node.anno}") + + parti_options = get_candidate_options(node.anno, ins_info + outs_info) + + logger.info(f"Candidate partition options: {parti_options}") + + verify_config = VerifyConfig( + fsig=node.signature, + args=ins_info, + kwargs=node.kwargs, + noutputs=len(node.outputs()), + parti_options=parti_options, + ) + try: + iscorrect = verify_partition_options(verify_config) + except Exception as e: + logger.warning( + f"Verification failed for {node.signature}, {e}, please manually verify." + ) + iscorrect = True # fake true to skip this node + if not iscorrect: + logger.warning(f"Verification failed for {node.signature}, continuing execution.") + continue + + verified_ops.add((node.signature, tuple(ins_info + outs_info))) + save_verified_ops(outdir, verified_ops) + + if skipped_nodes: + logger.info("Skipped the following nodes due to empty ins_info:") + for node_info in skipped_nodes: + logger.info(f" - {node_info}") + +def main(): + parser = argparse.ArgumentParser( + description="Verify partitions of operations in an IRGraph." + ) + parser.add_argument( + "--graph", type=str, required=True, help="Path to the graph file." + ) + parser.add_argument( + "--outdir", + type=str, + help="Optional directory to save the verified operations. If not provided, results will be saved to the default cache directory.", + ) + + args = parser.parse_args() + + graph_path = Path(args.graph) + if not graph_path.exists(): + raise FileNotFoundError(f"Graph file {graph_path} does not exist.") + + graph = IRGraph.load(graph_path) + + if args.outdir: + outdir = Path(args.outdir) + else: + outdir = _DEFAULT_CACHE_DIR + + outdir.mkdir(parents=True, exist_ok=True) + verify_op_partitions(graph, outdir) + + +if __name__ == "__main__": + main() diff --git a/utility/verify_ops/verify_op.md b/utility/verify_ops/verify_op.md new file mode 100644 index 00000000..46cfc5d2 --- /dev/null +++ b/utility/verify_ops/verify_op.md @@ -0,0 +1,347 @@ +## verify-graph support +""" +Used to verify operations in IRGraph to ensure their functionality and consistency across single and multiple Gpus. +""" +## example: +Command-line interface for verifying operations in an IRGraph. + +Usage: +python verify_graph_operations.py --graph --outdir + +Parameters: +--graph (str): Path to the graph checkpoint file (.ckp) to be loaded. This is the same graph used as the input for the pas policy. +--outdir (str): Directory where verification results will be saved. + +This script performs the following steps: +1. Load the IRGraph: Reads the graph checkpoint file specified by the `--graph` argument. +2. Verify Operations: Performs verification on the operations defined in the graph. This includes: + - Registering the operations for further testing. + - Verifying single-GPU and multi-GPU functionality. + - Checking the consistency of partitioned operations across different GPUs. +3. Generate and Save Results: Outputs verification results, including loss values for single and multiple GPUs, and details of partition validations. + +To test a module: you should first use parallelize to generate the required graph.ckp file, then test graph against the current script. + +## verify-dimops support +""" +Define a configuration for verifying partition options of a tensor operation. +This configuration helps ensure that the operation's partitioning logic is valid +by specifying the function signature, arguments, expected outputs, and partitioning options. +""" +## example 1: +This is used to verify that Conv2D's partition configuration is correct. This configuration defines a basic Conv2D operation with input Tensor, convolution kernel, and bias. +```python +@dataclass +class VerifyConfig: + fsig: str + args: List[TensorInfo] + kwargs: Dict[str, Any] + noutputs: int + parti_options: List[Dict[str, int]] + import_customized_func: str = "" + non_grad_indices: List[int] = field(default_factory=list) + +Parameters: + fsig (str): Function signature of the operator to be tested. + args (List[TensorInfo]): List of TensorInfo objects representing the input arguments for the operator. + kwargs (Dict[str, Any]): Keyword arguments for the operator. + noutputs (int): Number of outputs expected from the operator. + parti_options (List[Dict[str, int]]): List of partition options specifying how to partition the operator. + import_customized_func (str): A string containing import statements for any custom functions or modules required by the operator. This ensures that all necessary functions are available in the generated test code. + non_grad_indices (List[int]): List of indices specifying which input tensors are buffer parameters + (e.g., running_mean, running_var) that should not participate in the + backward pass. These parameters will be detached and have their gradients + disabled during the test. + +conv2d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv2d', + fsig = 'torch.conv2d', + args = [ + TensorInfo('shape', (8, 32, 32)), + TensorInfo('shape', (16, 4, 3, 3)), + TensorInfo('shape', (16,)) + ], + kwargs = {'stride': 1, 'padding': 0, 'dilation': 1, 'groups': 2}, + parti_options = [{'idx': 0, 'dim': 0}], + noutputs = 1, +) +verify_partition_options(conv2d_config) +``` + +## Examples for configuring different op + +``` +dropout_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import dropout', + fsig = 'torch.nn.functional.dropout', + args = [ + TensorInfo('shape', (1, 512, 4, 4)) + ], + kwargs = {'p': 0.5, 'training':False, 'inplace':False}, + parti_options = [{'idx': 0, 'dim': 1}, + {'idx': 0, 'dim': 2}, + {'idx': 0, 'dim': 3}], + noutputs = 1, +) +verify_partition_options(dropout_config) + + +where_config = VerifyConfig( + fsig='torch.where', + args=[ + TensorInfo('shape', value=(1, 1, 9, 9)), + TensorInfo('shape', value=(1, 12, 9, 9)), + TensorInfo('shape', value=(1,)) + ], + kwargs={}, + noutputs=1, + parti_options=[{'idx': 1, 'dim': 1}], +) +verify_partition_options(where_config) + + +view_config = VerifyConfig( + fsig='torch.Tensor.view', + args=[ + TensorInfo('shape', value=(1, 9, 768)) + ], + kwargs={'size': (-1, 768)}, + noutputs=1, + parti_options=[{'idx': 0, 'dim': 2}], +) +verify_partition_options(view_config) + + +embedding_config = VerifyConfig( + import_customized_func='from torch.nn.functional import embedding', + fsig='nnscaler.runtime.function.embedding', + args=[ + TensorInfo('shape', value=(1, 9)), + TensorInfo('shape', value=(50257, 768)) + ], + kwargs={'padding_idx': None, 'start': 0, 'stop': 50257}, + parti_options=[{'idx': 1, 'dim': 1}], + noutputs=1 +) +verify_partition_options(embedding_config) + + +dropout_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import dropout', + fsig = 'torch.nn.functional.dropout', + args = [ + TensorInfo('shape', (1, 9, 768)) + ], + kwargs = {'p': 0.1, 'training':False, 'inplace':False}, + parti_options = [{'idx': 0, 'dim': 2}], + noutputs = 1, +) +verify_partition_options(dropout_config) + + +fullslice_config = VerifyConfig( + fsig='nnscaler.runtime.function.fullslice', + args=[ + TensorInfo('shape', value=(1, 1, 1024, 1024)), + TensorInfo('value', value=slice(None, None, None)), + TensorInfo('value', value=slice(None, None, None)), + TensorInfo('value', value=slice(IRObject('sub391', value=0, is_constant=True), IRObject('size_9388', value=9, is_constant=True), None)), + TensorInfo('value', value=slice(None, IRObject('size_9388', value=9, is_constant=True), None)) + ], + kwargs={}, + noutputs=1, + parti_options=[], +) +verify_partition_options(fullslice_config) + + +conv2d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv2d', + fsig = 'torch.conv2d', + args = [TensorInfo('shape',(8192, 4, 4)), TensorInfo('shape', (8192, 512, 3, 3)),TensorInfo('shape', (8192,))], + kwargs = {'stride': 1, 'padding': 0, 'dilation': 1, 'groups': 16}, + parti_options = [{'idx': 0, 'dim': 0}], + noutputs = 1, +) +verify_partition_options(conv2d_config) +conv2d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv2d', + fsig = 'torch.conv2d', + args = [ + TensorInfo('shape', (4, 8, 32, 32)), + TensorInfo('shape', (16, 4, 3, 3)), + TensorInfo('shape', (16,)) + ], + kwargs = {'stride': 1, 'padding': 0, 'dilation': 1, 'groups': 2}, + parti_options = [{'idx': 0, 'dim': 1}], + noutputs = 1, +) +verify_partition_options(conv2d_config) +conv2d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv2d', + fsig = 'torch.conv2d', + args = [TensorInfo('shape',(1, 8192, 4, 4)), TensorInfo('shape', (8192, 512, 3, 3)),TensorInfo('shape', (8192,))], + kwargs = {'stride': 1, 'padding': 0, 'dilation': 1, 'groups': 16}, + parti_options = [{'idx': 0, 'dim': 1}], + noutputs = 1, +) +verify_partition_options(conv2d_config) + + +conv_transpose2d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv_transpose2d', + fsig = 'torch.conv_transpose2d', + args = [TensorInfo('shape',(8192, 4, 4)), TensorInfo('shape', (8192, 512, 3, 3))], + kwargs = {'stride': 1, 'padding': 0, 'output_padding': 0, 'dilation': 1, 'groups': 16}, + parti_options = [{'idx': 0, 'dim': 0}], + noutputs = 1, +) +verify_partition_options(conv_transpose2d_config) +conv_transpose2d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv_transpose2d', + fsig = 'torch.conv_transpose2d', + args = [ + TensorInfo('shape', (512, 4, 4)), + TensorInfo('shape', (512, 512, 3, 3)) + ], + kwargs = {'stride': 1, 'padding': 0, 'output_padding': 0, 'dilation': 1, 'groups': 1}, + parti_options = [{'idx': 0, 'dim': 0}], + noutputs = 1, +) + +verify_partition_options(conv_transpose2d_config) +conv_transpose2d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv_transpose2d', + fsig = 'torch.conv_transpose2d', + args = [TensorInfo('shape',(1, 8192, 4, 4)), TensorInfo('shape', (8192, 512, 3, 3))], + kwargs = {'stride': 1, 'padding': 0, 'output_padding': 0, 'dilation': 1, 'groups': 16}, + parti_options = [{'idx': 0, 'dim': 1}], + noutputs = 1, +) +verify_partition_options(conv_transpose2d_config) +conv_transpose2d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv_transpose2d', + fsig = 'torch.conv_transpose2d', + args = [ + TensorInfo('shape', (1, 512, 4, 4)), + TensorInfo('shape', (512, 512, 3, 3)) + ], + kwargs = {'stride': 1, 'padding': 0, 'output_padding': 0, 'dilation': 1, 'groups': 1}, + parti_options = [{'idx': 0, 'dim': 1}], + noutputs = 1, +) +verify_partition_options(conv_transpose2d_config) + + +conv1d_config = VerifyConfig( + # import_customized_func = 'from torch.nn.functional import conv1d', + fsig = 'torch.conv1d', + args = [ + TensorInfo('shape',(1, 512, 400)), + TensorInfo('shape', (128, 512, 3)) + ], + kwargs = {'stride': 1, 'padding': 0, 'dilation': 1, 'groups': 1}, + parti_options = [{'idx': 0, 'dim': 1}], + noutputs = 1, +) +verify_partition_options(conv1d_config) +conv1d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv1d', + fsig = 'torch.conv1d', + args = [ + TensorInfo('shape', (1, 8192, 400)), + TensorInfo('shape', (128, 512, 3)) + ], + kwargs = {'stride': 1, 'padding': 0, 'dilation': 1, 'groups': 16}, + parti_options = [{'idx': 0, 'dim': 1}], + noutputs = 1, +) +verify_partition_options(conv1d_config) + + +pose1d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv_transpose1d', + fsig = 'torch.conv_transpose1d', + args = [ + TensorInfo('shape', (1, 512, 100)), + TensorInfo('shape', (512, 256, 3)) + ], + kwargs = {'stride': 1, 'padding': 0, 'output_padding': 0, 'dilation': 1, 'groups': 8}, + parti_options = [{'idx': 0, 'dim': 1}], + noutputs = 1, +) +verify_partition_options(pose1d_config) +pose1d_config = VerifyConfig( + import_customized_func = 'from torch.nn.functional import conv_transpose1d', + fsig = 'torch.conv_transpose1d', + args = [ + TensorInfo('shape', (1, 512, 100)), + TensorInfo('shape', (512, 256, 3)) + ], + kwargs = {'stride': 1, 'padding': 0, 'output_padding': 0, 'dilation': 1, 'groups': 1}, + parti_options = [{'idx': 0, 'dim': 1}], + noutputs = 1, +) +verify_partition_options(pose1d_config) + + +verify_config = VerifyConfig( + fsig='nnscaler.graph.function.wrapnn.wrap_batchnorm2d_func', + args=[ + TensorInfo('shape', value=(32, 64, 8, 8)), + TensorInfo('shape', value=(64,)), + TensorInfo('shape', value=(64,)), + TensorInfo('shape', value=(64,)), + TensorInfo('shape', value=(64,)), + TensorInfo('shape', value=(1,)), + ], + kwargs={ + 'momentum': 0.1, + 'training': True, + 'track_running_stats': True, + 'eps': 1e-05 + }, + noutputs=1, + parti_options=[{'idx': 0, 'dim': 0}, + {'idx': 0, 'dim': 1}], + non_grad_indices=[3, 4, 5] +) +verify_partition_options(verify_config) + + +addmm_config = VerifyConfig( + import_customized_func='from torch import addmm', + fsig='torch.addmm', + args=[ + TensorInfo('shape', (2, 3)), + TensorInfo('shape', (2, 3)), + TensorInfo('shape', (3, 3)) + ], + kwargs={}, + parti_options=[{'idx': 0, 'dim': 0}, {'idx': 1, 'dim': 0}], + noutputs=1 +) + +verify_partition_options(addmm_config) + + +verify_config = VerifyConfig( + fsig='nnscaler.graph.function.wrapnn.wrap_instancenorm2d_func', + args=[ + TensorInfo('shape', value=(32, 64, 8, 8)), + TensorInfo('shape', value=(64,)), + TensorInfo('shape', value=(64,)), + TensorInfo('shape', value=(64,)), + TensorInfo('shape', value=(64,)), + ], + kwargs={ + 'training': True, + 'momentum':0.1, + 'eps': 1e-05 + }, + noutputs=1, + parti_options=[{'idx': 0, 'dim': 0}], + non_grad_indices=[3, 4] +) +verify_partition_options(verify_config) +``` \ No newline at end of file From 93e35da3758b7665222b623bfb8e79a4b0e1d61c Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Mon, 2 Sep 2024 06:30:40 +0000 Subject: [PATCH 1720/1892] Merged PR 2249: Llama3 128K finetuning for v0.3 release Check the `README.md` for guidance --- docs/source/parallel_module.md | 1 + examples/llama3_8B_128K/.gitignore | 2 + examples/llama3_8B_128K/README.md | 121 ++++++++ examples/llama3_8B_128K/bookcorpus.py | 60 ++++ .../chunk_linear_cross_entropy.py | 65 ++++ examples/llama3_8B_128K/ckpt_merger.py | 33 ++ examples/llama3_8B_128K/create_mini_model.py | 38 +++ examples/llama3_8B_128K/modeling_modifier.py | 267 ++++++++++++++++ examples/llama3_8B_128K/requirements.txt | 2 + examples/llama3_8B_128K/train.py | 286 ++++++++++++++++++ nnscaler/policies.py | 2 + 11 files changed, 877 insertions(+) create mode 100644 examples/llama3_8B_128K/.gitignore create mode 100644 examples/llama3_8B_128K/README.md create mode 100644 examples/llama3_8B_128K/bookcorpus.py create mode 100644 examples/llama3_8B_128K/chunk_linear_cross_entropy.py create mode 100644 examples/llama3_8B_128K/ckpt_merger.py create mode 100644 examples/llama3_8B_128K/create_mini_model.py create mode 100644 examples/llama3_8B_128K/modeling_modifier.py create mode 100644 examples/llama3_8B_128K/requirements.txt create mode 100644 examples/llama3_8B_128K/train.py diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index fb0e5ebf..43ca9137 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -591,6 +591,7 @@ It requires the `use_end2end` to be true. It has the following configurations. - `pipeline_scheduler`: The scheduler name for the pipeline. Please note currently `1f1b` is the only supported scheduler in `autodist`. Default is `1f1b`. Optional. - `parallel_profile`: If set to `True`, autodist will profile operators in parallel by using available gpus. Default is `True`. Optional. - `max_partition_degree`: Max degree when partitioning an operator / node. When pipeline parallelism is enabled to explore (`explore_pipeline` is True), user can change the value to constrain the plan to be composed of stages that span on less or equal to `max_partition_degree` devices (recommend to set `max_partition_degree` to the number of devices in a node to avoid inter-node communication, but should be be no more than `plan_ngpus`). Default is `plan_ngpus`. Optional. + - `transient_mem_coef`: In autodist, a heuristic is used to estimate the transient memory size: `transient_mem_size = opt_transient_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)`. This formula is useful in many cases, but it may be too strict when some operators consume or generate a large tensor (>= 4GB). In this case, you can set `transient_mem_coef` to a smaller value to relax the constraint. Default is `2`. Optional. You can also put any other settings that can affect code generation here. but please prefix the keys with `_` to avoid conflicts with predefined keys. diff --git a/examples/llama3_8B_128K/.gitignore b/examples/llama3_8B_128K/.gitignore new file mode 100644 index 00000000..8c2014ac --- /dev/null +++ b/examples/llama3_8B_128K/.gitignore @@ -0,0 +1,2 @@ +runs/ +*.log diff --git a/examples/llama3_8B_128K/README.md b/examples/llama3_8B_128K/README.md new file mode 100644 index 00000000..6dd30eda --- /dev/null +++ b/examples/llama3_8B_128K/README.md @@ -0,0 +1,121 @@ +# Introduction + +This example demonstrates how to train llama3-8B-128k model with 8xH100s or 8xA100s. + +# Requirements + +To run this example, you need to install the following packages: + +```text +nnscaler +transformers==4.40.0 +datasets==2.20.0 +apex +flash-attn +``` + +*nnScaler* is a framework for distributed training by automatically partitioning the model. Apart from the core nnScaler library, it also includes a mini-trainer for modern model training. You can find related documents and examples at [nnScaler](TODO). + +*transformers* and *datasets* are required to prepare the data and loading the Llama model. + +To speed up the training process, [*apex*](https://github.com/NVIDIA/apex) and [*flash-attn*](https://github.com/Dao-AILab/flash-attention) are required. You can install them by following instructions in their official repositories. We also recommend to launch the script under a Nvidia docker directly, like nvidia/pytorch:24.02-py3. + +# Data Preparation + +We use the [bookcorpus](https://huggingface.co/datasets/bookcorpus) dataset for training. The dataset is tokenized with the [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) tokenizer. Tokenized data is saved in the `bookcorpus_llama3_128K` directory. + +```bash +python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_128K --sequence_length 131072 +``` + +# Training + +nnScaler adopts a compiler approach to launch the distributed training. The processing pipeline is divided into two stages: + +1. Compile stage: trace the original PyTorch model and get the dataflow graph. Analyze the graph and generate an efficient plan for distributed training. Generate python code for the runtime stage. +2. Runtime stage: run the generated python code to train the model. + +For better user experience, we recommend to use separate commands for the compile and runtime stages. You can also use the `Run` command directly to combine the two stages. + +**Note**: currently we only tested `"_attn_implementation": "flash_attention_2"` and `"use_cache": false` in the config file. Other configurations may not work properly. + +## Register Customized Function + +Llama3's vocabulary size is about 128K, which is much larger then the 32K in Llama2. At the same time the sequence length in this example is 128K, the output tensor size of the last projection layer is quite large: 128K x 128K x 2 bytes = 32GB. +Although this tensor can be partitioned evenly to 8 GPUs, 4GB memory is still quite large for limited GPU memory. What makes it worse is that we need to store additional 8GB for `log_softmax` and `cross_entropy_loss` computation. +In order to reduce the memory consumption: +- we split the input sequence on each device to chunks of 1K tokens +- for each chunk, we recompute a function which is composed of last projection layer, log_softmax and loss +- as a result, we only need to store the input tensor to the last projection layer, whose initial size is 128K x 4K x 2 bytes = 1GB, which is much smaller than 32GB + +You can find the detailed implementation in `chunk_linear_cross_entropy.py`. +The interface of the `chunk_linear_cross_entropy` function is `(hidden_states: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor`, where +- `hidden_states` is the output of the last transformer layer, with shape `[batch_size, sequence_length, hidden_size]` +- `weight` is the weight matrix of the last projection layer, with shape `[vocab_size, hidden_size]` +- `labels` is the target labels, with shape `[batch_size, sequence_length]` +- `padding_idx` is the padding index +- `chunk_size` is the size of the chunk, default is 1024 + +We want to register this function to nnScaler and tell it to partition this function along batch size or sequence dimension. A possible annotation is `b l d^, n^ d^, b l -> b l`. Here `b` stands for batch size, `l` stands for sequence length, `d` stands for hidden size, and `n` stands for vocab size. The `^` means the dimension cannot be partitioned. More details about the annotation can be found in related documents. + +## Compile + +```bash +python train.py --run_mode compile --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 2>&1 | tee compile.log +``` + +## Run + +```bash +torchrun --nproc_per_node=8 train.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --dataset_path ./bookcorpus_llama3_128K --plan_ngpus=8 --runtime_ngpus=8 2>&1 | tee run.log +``` + +## Checkpoint + +This script will save the model checkpoint in the `./checkpoints` directory. You can change the checkpoint directory by updating the `CheckpointConfig` in the `train.py` script. + +nnScalar saves checkpoints in shards: each rank may save parameters and optimizer states in a file. These checkpoints can be directly loaded by nnScaler if the partitioning strategy is the same. If you want to evaluate the checkpoints on downstream tasks, you need to merge the shards into a single file. You can use the following command to merge the shards: + +```bash +python ckpt_merger.py --ckpt_dir ./checkpoints --output_fname ./merged.ckpt +``` + +The merged checkpoint can be loaded by nnScaler by setting the `--resume_path` option to the merged file. + +If the script is modified for different hardware configurations. +- All sharded checkpoint files should be collected and placed in a same directory before `ckpt_merger.py` is called. +- If the config is changed (plan_ngus/runtime_ngus/etc), the sharded checkpoint can not be used anymore. You need to merge them so the trainer can load from merged checkpoint. + +# Performance + +The flops of the forward computation for Llama3 is + +$2 \cdot ( param\_num \cdot seqlen + 2 \cdot layer\_num \cdot hidden\_dim \cdot seqlen ^ 2)$ + +For the 8B model, the forward flops is about 11104.35 TFLOPs. The detailed config is as following: +- $param\_num = 8 \times 10^9$ +- $seqlen = 128 \times 1024$ +- $layer\_num = 32$ +- $hidden\_dim = 4096$ + +Generally, the computational cost of backpropagation is twice that of the forward pass. In addition, the gradient accumulation number is set to 4. As a result, the flops for a step of the training script is 133252.22 TFLOPs. + +We execute the training script on a node with 8xH100 80GB HBM3. The time cost is about 41.12s for a step. The theoretical BF16 computational speed of the H100 is 989 TFLOPS. Combine them together, this script can achieve 40.96% MFU. You can optimize the performance furtherly by +- add more devices to avoid recomputation: in order to fit the model into the memory, we recompute by layer. +- do more kernel optimizations. For example, the swiglu activation can be fused into the matmul ahead of it. + +# Debugging + +Since the 128K config is challenging, it is recommended to use a smaller model for debugging. For example, you can use the following command to prepare data and train a smaller llama3 (same architecture, but with 4 decoder layers) model on two GPUs. + +```bash +# prepare data +python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 + +# build the mini model +python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini + +# compile and run using data parallelism + zero1 +torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K + +``` diff --git a/examples/llama3_8B_128K/bookcorpus.py b/examples/llama3_8B_128K/bookcorpus.py new file mode 100644 index 00000000..b69b90b0 --- /dev/null +++ b/examples/llama3_8B_128K/bookcorpus.py @@ -0,0 +1,60 @@ +import argparse +from typing import List, Dict + +from datasets import load_dataset, Dataset +from transformers import AutoTokenizer, PreTrainedTokenizer +import numpy +import torch + + +def get_tokenizer(model_path): + return AutoTokenizer.from_pretrained(model_path) + + +def tokenize(sample: Dict[str, str], tokenizer: PreTrainedTokenizer, text_key: str): + input_ids = tokenizer.encode(tokenizer.bos_token + sample[text_key] + tokenizer.eos_token, add_special_tokens=False) + return {"input_ids": input_ids} + + +def concate_split(samples: Dict[str, List[List[int]]], sample_len: int, text_key: str): + buffer = samples[text_key][0] + resized_ids = [] + length = [] + for in_ids in samples[text_key]: + buffer.extend(in_ids) + while len(buffer) >= sample_len: + resized_ids.append(buffer[:sample_len]) + length.append(sample_len) + buffer = buffer[sample_len:] + return {"input_ids": resized_ids, "length": length} + + +def create_dataset(tokenizer: PreTrainedTokenizer, raw_dataset: Dataset, text_key: str, sample_len: int = 8 * 1024, batch_size=10000): + tokenized_dataset = raw_dataset.map(tokenize, remove_columns=raw_dataset.column_names, num_proc=32, + fn_kwargs={'tokenizer': tokenizer, 'text_key': text_key}) + return tokenized_dataset.map(concate_split, remove_columns=tokenized_dataset.column_names, num_proc=32, batched=True, + batch_size=batch_size, fn_kwargs={'sample_len': sample_len, 'text_key': 'input_ids'}) + + +if __name__ == '__main__': + # python bookcorpus.py --data_path_or_name "bookcorpus/bookcorpus" --tokenizer_path_or_name "meta-llama/Llama-2-7b-hf" --save_path "bookcorpus-llama2-2k-hf" --sequence_length 2048 + parser = argparse.ArgumentParser() + parser.add_argument('--data_path_or_name', help='the path or name of the raw dataset, for exmaple, "bookcorpus/bookcorpus"', type=str, required=True) + parser.add_argument('--tokenizer_path_or_name', help='the tokenizer path or name, for example, "meta-llama/Llama-2-7b-hf"', type=str, required=True) + parser.add_argument('--save_path', help='the path to save the tokenized dataset', type=str, required=True) + parser.add_argument('--sequence_length', help='the length of each sample in the tokenized dataset, usually set to the max sequence length', type=int, required=True) + args = parser.parse_args() + data_path_or_name = args.data_path_or_name + tokenizer_path_or_name = args.tokenizer_path_or_name + save_path = args.save_path + sequence_length = args.sequence_length + + raw_dataset = load_dataset(data_path_or_name)["train"] + tokenizer = get_tokenizer(tokenizer_path_or_name) + dataset = create_dataset(tokenizer, raw_dataset, "text", sequence_length) + dataset.save_to_disk(save_path) + # used by fairseq dataset + sizes = numpy.array(dataset["length"]) + torch.save(sizes, f"{save_path}/fairseq_dataset_sizes.np") + + print(f"If you want to use this dataset, please set `DATASET_TYPE=hf_disk` and `DATASET_PATH={save_path}` in run script") diff --git a/examples/llama3_8B_128K/chunk_linear_cross_entropy.py b/examples/llama3_8B_128K/chunk_linear_cross_entropy.py new file mode 100644 index 00000000..db902b92 --- /dev/null +++ b/examples/llama3_8B_128K/chunk_linear_cross_entropy.py @@ -0,0 +1,65 @@ +import torch +import torch.utils.checkpoint as ckpt + +from nnscaler.graph.parser.register import register_op + + +def linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int = 0) -> torch.Tensor: + """ + Compute the cross entropy loss of a linear layer. + + Args: + + x: [token_num, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [token_num], the target token index + padding_idx: int, the index of padding token + + Returns: + + losses: [token_num], the cross entropy loss of each token + """ + logits = torch.nn.functional.linear(x, w) + normalized_logits = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32) + losses = torch.nn.functional.nll_loss(normalized_logits, y, reduction='none', ignore_index=padding_idx) + return losses + + +def chunk_linear_cross_entropy(x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, padding_idx: int, chunk_size: int) -> torch.Tensor: + """ + In order to reduce the memory usage when the sequence length and dictionary size are large, we can split the input + tensor into chunks and compute the cross entropy loss of each chunk separately. + You can register this function with annotation 'b l d^, n^ d^, b l -> b l'. + + Args: + + x: [bsz, seq_len, hidden_size], the last hidden state of the model + w: [dict_size, hidden_size], the weight matrix of the last linear layer + y: [bsz, seq_len], the target token index + padding_idx: int, the index of padding token + chunk_size: int, the size of each chunk + + Returns: + + losses: [bsz, seq_len], the cross entropy loss of each token + """ + bsz, seq_len, hidden_size = x.size() + token_num = bsz * seq_len + x = x.view(token_num, hidden_size) + y = y.view(token_num) + + if token_num % chunk_size != 0: + raise ValueError(f"token_num {token_num} is not divisible by chunk_size {chunk_size}") + + chunk_num = token_num // chunk_size + xs = x.view(chunk_num, chunk_size, hidden_size) + ys = y.view(chunk_num, chunk_size) + losses = [] + for i in range(chunk_num): + loss = ckpt.checkpoint(linear_cross_entropy, xs[i], w, ys[i], padding_idx, use_reentrant=False) + losses.append(loss) + losses = torch.stack(losses).view(bsz, seq_len) + return losses + + +register_op('b l d^, n^ d^, b l -> b l')(chunk_linear_cross_entropy) diff --git a/examples/llama3_8B_128K/ckpt_merger.py b/examples/llama3_8B_128K/ckpt_merger.py new file mode 100644 index 00000000..ce80f5a7 --- /dev/null +++ b/examples/llama3_8B_128K/ckpt_merger.py @@ -0,0 +1,33 @@ +import argparse +import os + +import torch + +from nnscaler.cli.trainer import Trainer +from nnscaler.utils import set_default_logger_level + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--ckpt_dir', + default=None, + type=str, + help='path to the sharded checkpoint directory', + ) + parser.add_argument( + '--output_fname', + default=None, + type=str, + help='output filename', + ) + args = parser.parse_args() + + if args.ckpt_dir is None: + raise ValueError('ckpt_dir is required') + if args.output_fname is None: + raise ValueError('output_fname is required') + + set_default_logger_level('INFO') + ckpt_files = [os.path.join(args.ckpt_dir, f) for f in os.listdir(args.ckpt_dir) if f.endswith('.ckpt')] + Trainer.merge_checkpoint(ckpt_files, args.output_fname) diff --git a/examples/llama3_8B_128K/create_mini_model.py b/examples/llama3_8B_128K/create_mini_model.py new file mode 100644 index 00000000..e0718480 --- /dev/null +++ b/examples/llama3_8B_128K/create_mini_model.py @@ -0,0 +1,38 @@ +import argparse +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + + +def main(args): + config = AutoConfig.from_pretrained(args.model_id) + config.num_hidden_layers = 4 + config.use_cache = False + config._attn_implementation = 'flash_attention_2' + model = AutoModelForCausalLM.from_config(config) + model.save_pretrained(args.output_id) + + tokenizer = AutoTokenizer.from_pretrained(args.model_id) + tokenizer.save_pretrained(args.output_id) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_id', + default=None, + type=str, + help='huggingface model id / path', + ) + parser.add_argument( + '--output_id', + default=None, + type=str, + help='output model id / path', + ) + args = parser.parse_args() + + if args.model_id is None: + raise ValueError('model_id is required') + if args.output_id is None: + raise ValueError('output_id is required') + + main(args) diff --git a/examples/llama3_8B_128K/modeling_modifier.py b/examples/llama3_8B_128K/modeling_modifier.py new file mode 100644 index 00000000..57b65b90 --- /dev/null +++ b/examples/llama3_8B_128K/modeling_modifier.py @@ -0,0 +1,267 @@ +# This file modifies the official modeling_llama.py file at runtime to +# 1. register the flash attention function to nnscaler and update related code +# 2. replace the un-fused RMSNorm with apex's fused version + +import types +from typing import List, Optional, Tuple, Union + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir import IRTensor + +import torch + +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.models.llama.modeling_llama import LlamaAttention, LLAMA_ATTENTION_CLASSES, apply_rotary_pos_emb, LlamaRMSNorm +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +try: + from apex.normalization.fused_layer_norm import fused_rms_norm_affine + has_apex = True +except ImportError: + has_apex = False + + +def rmsnorm_fwd(self, hidden_states): + if has_apex: + return fused_rms_norm_affine(hidden_states, self.weight, self.weight.shape, self.variance_epsilon) + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class NNScalerLlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and q_len != 1 + + attn_output = nnscaler_flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, causal=causal + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def nnscaler_flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, causal=True +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = nnscaler_upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + +def nnscaler_upad_input(query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + _, _, num_heads, _ = query_layer.shape + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def llama_flash_attention_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str: + if query_states.shape[2] != key_states.shape[2]: + assert query_states.shape[2] % key_states.shape[2] == 0 + group_size = query_states.shape[2] // key_states.shape[2] + assert query_states.shape[2] == value_states.shape[2] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + if isinstance(attention_mask, IRTensor): + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, b l^ -> b l^ {q_anno} vd^' + else: + return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^' + + +register_op(llama_flash_attention_anno)(nnscaler_flash_attention_forward) + + +def nnscaler_llama_init(): + LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NNScalerLlamaFlashAttention2 + LlamaRMSNorm.forward = rmsnorm_fwd diff --git a/examples/llama3_8B_128K/requirements.txt b/examples/llama3_8B_128K/requirements.txt new file mode 100644 index 00000000..8001637d --- /dev/null +++ b/examples/llama3_8B_128K/requirements.txt @@ -0,0 +1,2 @@ +transformers==4.40.0 +datasets==2.20.0 diff --git a/examples/llama3_8B_128K/train.py b/examples/llama3_8B_128K/train.py new file mode 100644 index 00000000..8b116cb0 --- /dev/null +++ b/examples/llama3_8B_128K/train.py @@ -0,0 +1,286 @@ +import argparse +import os + +import datasets +from datasets import load_from_disk +import huggingface_hub +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling +from modeling_modifier import nnscaler_llama_init +from chunk_linear_cross_entropy import chunk_linear_cross_entropy + +from nnscaler.utils import set_default_logger_level +from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer_args import ( + CheckpointConfig, + DatasetConfig, + HookMapConfig, + ModelConfig, + OptimizerConfig, + TrainerArgs, + DataloaderConfig, + AggregatedOutputs, + LogConfig, + DatasetSamplerConfig, +) +from nnscaler.parallel import ComputeConfig, BroadcastGenFilesStrategy +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdamW +from nnscaler.cli.loggers.tensorboard import TensorBoardLogger + + +IGNORE_IDX = -100 + + +def get_tokenizer(tokenizer_name_or_path, + model_max_length=None, + default_bos_token="", + default_eos_token="", + default_pad_token="[PAD]", + default_unk_token=""): + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + special_tokens_dict = dict() + if tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = default_pad_token + if tokenizer.eos_token is None: + special_tokens_dict["eos_token"] = default_eos_token + if tokenizer.bos_token is None: + special_tokens_dict["bos_token"] = default_bos_token + if tokenizer.unk_token is None: + special_tokens_dict["unk_token"] = default_unk_token + + tokenizer.add_special_tokens(special_tokens_dict) + if model_max_length: + tokenizer.model_max_length = model_max_length + return tokenizer + + +class WrapperModel(torch.nn.Module): + def __init__(self, model_id): + super().__init__() + self.model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation='flash_attention_2') + + def forward(self, samples): + outputs = self.model.model( + input_ids=samples['net_input']['src_tokens'], + use_cache=False, + return_dict=False, + ) + hidden_states = outputs[0] + losses = chunk_linear_cross_entropy(hidden_states, self.model.lm_head.weight, samples['target'], IGNORE_IDX, 1024) + loss = torch.sum(losses) + return loss, loss.data, samples['ntokens'], samples['nsentences'] + + +def aggregate_outputs_fn(loss_outputs, sync_group) -> AggregatedOutputs: + losses, ntokens_info = [], [] + for _, loss, ntokens, _ in loss_outputs: + losses.append(loss) + ntokens_info.append(ntokens) + + loss_sum = torch.sum(torch.stack(losses), dtype=torch.float64) + torch.distributed.all_reduce(loss_sum, group=sync_group) + ntokens_sum = torch.sum(torch.tensor(ntokens_info, dtype=torch.float64, device=torch.cuda.current_device())) + torch.distributed.all_reduce(ntokens_sum, group=sync_group) + num_batches = torch.tensor(len(losses), device=torch.cuda.current_device()) + torch.distributed.all_reduce(num_batches, group=sync_group) + + return AggregatedOutputs( + loss_sum=loss_sum.item() / ntokens_sum.item(), + num_batches=num_batches.item(), + num_tokens=ntokens_sum.item(), + ) + + +def main(args): + + if args.run_mode == 'run': + broadcast_strategy = 'all' + else: + broadcast_strategy = 'none' + + set_default_logger_level('INFO') + + nnscaler_llama_init() + + ## Setup Dataset ## + + dataset = load_from_disk(args.dataset_path) + tokenizer = get_tokenizer(args.model_id) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + def collate_fn(samples): + if len(samples) == 0: + return {} + + mini_batch = data_collator(samples) + _mini_batch = {} + + src_tokens = mini_batch.pop('input_ids') + seq_len = src_tokens.size(-1) + _mini_batch['src_tokens'] = src_tokens + + shift_labels = mini_batch['labels'][..., 1:] + _mini_batch['labels'] = torch.nn.functional.pad(shift_labels, (0, 1), 'constant', IGNORE_IDX).contiguous() + + return { + "nsentences": len(samples), + "ntokens": len(samples) * seq_len, + "net_input": _mini_batch, + "target": _mini_batch.pop('labels'), + } + + ## Config Trainer ## + + if args.run_mode == 'compile': + if args.runtime_ngpus is None: + raise ValueError('runtime_ngpus must be specified in compile mode') + runtime_ngpus = args.runtime_ngpus + elif args.run_mode == 'run': + world_size = int(os.getenv('WORLD_SIZE')) + if args.runtime_ngpus is None: + runtime_ngpus = world_size + else: + if args.runtime_ngpus != world_size: + raise ValueError('runtime_ngpus must match the number of GPUs in run mode') + runtime_ngpus = args.runtime_ngpus + if runtime_ngpus % args.plan_ngpus != 0: + raise ValueError('runtime_ngpus must be a multiple of plan_ngpus') + + compute_config = ComputeConfig( + plan_ngpus=args.plan_ngpus, + runtime_ngpus=runtime_ngpus, + constant_folding=True, + use_zero=True, + use_end2end=True, + # autodist config: + # - memory constraint is set to 64GB + # - recompute by the transformer layer in Llama + pas_config={ + 'mem_constraint': 64, + 'recompute_modules': 'LlamaDecoderLayer', + }, + ) + + model_config = ModelConfig( + type=WrapperModel, + args={ + 'model_id': args.model_id, + }, + ) + + # optimizer hyperparameters are from YaRN + optimizer_config = OptimizerConfig( + type=MixedPrecisionAdamW, + args={'lr': 2e-5, 'betas': (0.9, 0.95), 'weight_decay': 0.0, 'fused': True}, + clip_gnorm=1.0, + loss_reduction='sum', + grad_reduction='per-token-mean', + aggregate_outputs_fn=aggregate_outputs_fn, + ) + + dataset_config = DatasetConfig( + type=(lambda split: dataset), + train_args={'split': 'train'}, + ) + + dataloader_config = DataloaderConfig( + train_args={ + 'collate_fn': collate_fn, + 'drop_last': True, + }, + ) + + sampler_config = DatasetSamplerConfig( + train_args={ + 'shuffle': False, + }, + ) + + checkpoint_config = CheckpointConfig( + every_n_train_steps=1000, + save_type='deduped', + resume_from=(args.resume_path or 'last'), + ) + + log_config = LogConfig( + type=TensorBoardLogger, + args={ + 'name': args.name, + 'root_dir': './runs', + }, + ) + + trainer_args = TrainerArgs( + instance_name=args.name, + run_mode=args.run_mode, + compute_config=compute_config, + pas_policy='autodist', + model=model_config, + optimizer=optimizer_config, + dataset=dataset_config, + dataloader=dataloader_config, + checkpoint=checkpoint_config, + precision='bf16', + max_epochs=2, + grad_accumulation_steps=4, + log=[log_config], + seed=0, + broadcast_strategy=broadcast_strategy, + dataset_sampler=sampler_config, + ) + + trainer = Trainer(train_args=trainer_args) + if args.run_mode == 'run': + trainer.train() + + +if __name__ == '__main__': + ## Parse Args ## + + parser = argparse.ArgumentParser() + parser.add_argument( + '--name', + default='llama3-8b', + type=str, + help='name of the experiment', + ) + parser.add_argument( + '--run_mode', + default='run', + choices=['run', 'compile'], + help='run or compile', + ) + parser.add_argument( + '--plan_ngpus', + type=int, + required=True, + help='specify the scale unit size', + ) + parser.add_argument( + '--runtime_ngpus', + type=int, + required=True, + help='specify the number of GPUs to use', + ) + parser.add_argument( + '--resume_path', + default=None, + type=str, + help='path to dir of ckpts or the ckpt file to resume from', + ) + parser.add_argument( + '--dataset_path', + default=None, + type=str, + help='path to the dataset', + ) + parser.add_argument( + '--model_id', + default=None, + type=str, + help='transformers model id', + ) + args = parser.parse_args() + + main(args) diff --git a/nnscaler/policies.py b/nnscaler/policies.py index cfb65264..567c5e3f 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -233,6 +233,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: pipeline_pivots = pas_cfg.get('pipeline_pivots', '') use_apex_fused_adam_v2 = pas_cfg.get('use_apex_fused_adam_v2', False) parallel_profile = pas_cfg.get('parallel_profile', True) + transient_mem_coef = pas_cfg.get('transient_mem_coef', 2) task_name = f'{task_name}_{cfg.plan_ngpus}gpus_{update_freq}update_freq' if memory_constraint == -1: @@ -296,6 +297,7 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: pipeline=explore_pipeline, pipeline_pivots=pipeline_pivots, parallel_profile=parallel_profile, + transient_mem_coef=transient_mem_coef, ) return parallelize_graph(graph, autodist_cfg) From a0b5775cabb564caf5696384d7643762acc77ca8 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 3 Sep 2024 03:58:49 +0000 Subject: [PATCH 1721/1892] Merged PR 2254: Code Refine: Input to IRObject conversion Refactor to unify input-> IRObject conversion. unit test pass parity checkpass --- nnscaler/cli/trainer.py | 33 +++---- nnscaler/graph/parser/fx/parser.py | 29 ++----- nnscaler/ir/cten.py | 126 +++++++++++++++++++++++++-- nnscaler/parallel.py | 42 +++------ nnscaler/program.py | 15 +--- nnscaler/utils.py | 1 + tests/ir/test_cten.py | 133 +++++++++++++++++++++++++++++ 7 files changed, 289 insertions(+), 90 deletions(-) create mode 100644 tests/ir/test_cten.py diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index 77977f42..f3faa6e8 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -41,7 +41,8 @@ @dataclass class TrainStatus: best_loss = float('inf') - num_train_steps_done: int = 0 + # the train steps done so far + finished_train_steps: int = 0 @dataclass @@ -287,7 +288,7 @@ def _log_finalize(self): logger.finalize() def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None, *, tag: Optional[str] = None): - step = step or self.train_status.num_train_steps_done + step = step or self.train_status.finished_train_steps for logger in self.loggers: logger.log_metrics(metrics, step, tag=tag) @@ -374,12 +375,12 @@ def _save_checkpoint(self, loss): return torch.distributed.barrier() - logger.info(f"Saving checkpoint after {self.train_status.num_train_steps_done} steps with loss={loss:.3f}.") + logger.info(f"Saving checkpoint after {self.train_status.finished_train_steps} steps with loss={loss:.3f}.") save_dir = Path(checkpoint_config.save_dir) save_dir.mkdir(parents=True, exist_ok=True) - current_epoch = self.train_status.num_train_steps_done // self.total_train_steps_per_epoch + current_epoch = self.train_status.finished_train_steps // self.total_train_steps_per_epoch # the last step of the epoch - if self.train_status.num_train_steps_done % self.total_train_steps_per_epoch == 0: + if self.train_status.finished_train_steps % self.total_train_steps_per_epoch == 0: current_epoch -= 1 if checkpoint_config.save_type == 'sharded': @@ -404,7 +405,7 @@ def _save_checkpoint(self, loss): self.hook.on_save_checkpoint(self, state_dict) ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( epoch=current_epoch, - step=self.train_status.num_train_steps_done, + step=self.train_status.finished_train_steps, rank=self.rank, ) logger.info(f"Saving checkpoint to {str(ckpt_file.parent)}") @@ -434,7 +435,7 @@ def _save_checkpoint(self, loss): logger.info(f"Saving checkpoint as the best checkpoint.") best_file = save_dir / CHECKPOINT_BEST_FILE_FORMAT.format( epoch=current_epoch, - step=self.train_status.num_train_steps_done, + step=self.train_status.finished_train_steps, rank=self.rank, ) best_file.parent.mkdir(parents=True, exist_ok=True) @@ -535,11 +536,11 @@ def train(self): # So that we can get accurate peak memory usage for each step torch.cuda.reset_peak_memory_stats() - if self.train_status.num_train_steps_done >= self.max_train_steps: + if self.train_status.finished_train_steps >= self.max_train_steps: logger.info(f"Training is skipped: already done.") return - start_epoch = self.train_status.num_train_steps_done // self.total_train_steps_per_epoch + start_epoch = self.train_status.finished_train_steps // self.total_train_steps_per_epoch self.hook.on_train_start(self) @@ -555,7 +556,7 @@ def train(self): if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'epoch': self.lr_scheduler.step() - if self.train_args.max_train_steps and self.train_status.num_train_steps_done >= self.train_args.max_train_steps: + if self.train_args.max_train_steps and self.train_status.finished_train_steps >= self.train_args.max_train_steps: logger.info(f"Reached max train steps({self.train_args.max_train_steps}): Training is done.") break @@ -644,7 +645,7 @@ def train_epoch(self, epoch): VAL_STATUS_SAVE = 2 # validated and saved has_validated = VAL_STATUS_NO # 3 states - resume_from_idx = self.train_status.num_train_steps_done % self.total_train_steps_per_epoch + resume_from_idx = self.train_status.finished_train_steps % self.total_train_steps_per_epoch data_iter = enumerate(self._global_batch_iterator(num_skip_first=resume_from_idx)) max_epoch = self.max_train_steps // self.total_train_steps_per_epoch @@ -741,7 +742,7 @@ def train_epoch(self, epoch): if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'step': self.lr_scheduler.step() - self.train_status.num_train_steps_done += 1 + self.train_status.finished_train_steps += 1 self._log_mem_stats(tag='train') step_metrics = {k:v for k, v in asdict(step_stat).items() if v is not None} step_metrics['train_wall'] = time.perf_counter() - step_start_at @@ -749,18 +750,18 @@ def train_epoch(self, epoch): if self.rank == 0: progress.set_postfix(step_metrics) if self.train_args.enable_log_progress \ - and self.train_status.num_train_steps_done % self.train_args.log_progress_every_n_train_steps == 0: + and self.train_status.finished_train_steps % self.train_args.log_progress_every_n_train_steps == 0: logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) step_metrics = {} # validate and save checkpoint if self.train_args.checkpoint.every_n_train_steps and \ - self.train_status.num_train_steps_done % self.train_args.checkpoint.every_n_train_steps == 0: + self.train_status.finished_train_steps % self.train_args.checkpoint.every_n_train_steps == 0: self._validate_and_save(step_stat) has_validated = VAL_STATUS_SAVE # max_train_steps is reached - if self.train_status.num_train_steps_done >= self.max_train_steps: + if self.train_status.finished_train_steps >= self.max_train_steps: if step_metrics and self.train_args.enable_log_progress: logger.info(self._format_metrics(epoch_desc, idx + 1, step_metrics)) step_metrics = {} @@ -774,7 +775,7 @@ def train_epoch(self, epoch): break if not has_validated and self.train_args.val_every_n_train_steps and \ - self.train_status.num_train_steps_done % self.train_args.val_every_n_train_steps == 0: + self.train_status.finished_train_steps % self.train_args.val_every_n_train_steps == 0: self._validate(step_stat) has_validated = VAL_STATUS_VAL diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index 608cee42..73c58d57 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -125,32 +125,13 @@ def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame, is_constant: bool = True): assert isinstance(node, torch.fx.Node) - def meta2var(meta: Any) -> Any: - """Support complex data type of List, Tuple, Dict, Tensor/Object""" - if isinstance(meta, TensorMetadata): - shape = meta.shape - dtype = meta.dtype - requires_grad = meta.requires_grad - return IRFullTensor(shape=shape, name=node.name, - requires_grad=requires_grad, dtype=dtype) - if isinstance(meta, list): - return list(meta2var(item) for item in meta) - if isinstance(meta, tuple): - return tuple(meta2var(item) for item in meta) - if isinstance(meta, dict): - if not all(isinstance(key, str) for key in meta.keys()): - raise TypeError(f"only support dict type with str key, but got {meta.keys()}.\n{node}") - return {key : meta2var(value) for key, value in meta.items()} - if isinstance(meta, DICT_VALUES_TYPE): - return {key : meta2var(value) for key, value in enumerate(meta)}.values() - if isinstance(meta, DICT_ITEMS_TYPE): - return {key : meta2var(value) for key, value in meta}.items() - # TODO: data type check, with cases like {'a': 1.2, 'b': torch.Tensor} - return IRObject(name=node.name, value=meta, is_constant=is_constant) - assert hasattr(node, 'meta') and 'tensor_meta' in node.meta, f"Node {node} should have tensor_meta" meta = node.meta['tensor_meta'] - val = meta2var(meta) + val = IRObject.from_complex(node.name, meta, + collection_types=(list, tuple, dict, DICT_VALUES_TYPE, DICT_ITEMS_TYPE), + tensor_types=(TensorMetadata,), + is_constant=is_constant + ) frame.add_var(node.name, val) @staticmethod diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index 7a215362..f91f21e8 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -23,6 +23,7 @@ from nnscaler.ir.unique import IDGenerator from nnscaler.ir.dtype import DTypeInfo +from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE NestedVarOrStatic = Any @@ -470,11 +471,6 @@ def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: self._value: Optional[Any] = value self._is_constant: bool = is_constant - def __eq__(self, obj): - if not isinstance(obj, IRObject): - return False - return self._id == obj.tid - def __hash__(self) -> int: return self._id @@ -563,6 +559,126 @@ def overlap(self, other: Any) -> bool: else: return False + @classmethod + def from_complex(cls, + name: str, + data: Any, + *, + collection_types: Tuple = (tuple, list, dict), + tensor_types: Tuple = (torch.Tensor,), + is_constant: bool = False, + requires_grad: Optional[bool] = None, + tosub: bool = False, + ) -> Any: + """ + Convert complex data type of + collection_types (tuple, list, dict) + tensor_types (has shape/dtype/requires_grad) + into intermediate representation object. + + Rule: + 1. All tensor-like objects will be converted into IRFullTensor + 2. For any complex types, + a. if there is no tensor-like object, the whole object will be converted into IRObject + b. if there is tensor-like object, all items will be converted into IRObject. + Examples: + [1, 2, torch.tensor(3)] + -> [IRObject(1), IRObject(2), IRFullTensor(3)] + {'a': [1, 2, 3], 'b': torch.tensor(2)} + -> {'a': IRObject([1, 2, 3]), 'b': IRFullTensor(2)} + {'a': [1, 2, torch.tensor(3)], 'b': 2} + -> {'a': [IRObject(1), IRObject(2), IRFullTensor(3)], 'b': IRObject(2)} + Args: + name (str): the object name + data (Any): the complex data structure to be converted + collection_types (Tuple): the complex data types to be converted + tensor_types (Tuple): the tensor data types to be converted + tosub(bool): whether convert full tensor to sub-tensor + requires_grad (Optional[bool]): the requires_grad flag for the tensor-like object + None: will respect the original requires_grad flag + True: will set requires_grad to True + False: will set requires_grad to False + """ + from nnscaler.ir.tensor import IRFullTensor + + collection_types = tuple(collection_types) + tensor_types = tuple(tensor_types) + supported_collection_types = (tuple, list, dict, _DICT_VALUES_TYPE, _DICT_ITEMS_TYPE) + if any(t not in supported_collection_types for t in collection_types): + raise ValueError(f"Only support converting complex data type of {supported_collection_types}") + + def _inner(obj) -> Tuple[Any, bool]: + # second return is to know if there is any tensor-like object + + if isinstance(obj, tensor_types): + if requires_grad is None: + rg = obj.requires_grad + else: + # PyTorch only supports floating point and complex tensors for autograd. + # To align with PyTorch, we set requires_grad to False for other types. + rg = requires_grad and (obj.dtype.is_floating_point or obj.dtype.is_complex) + + tensor = IRFullTensor( + list(obj.shape), + name, + dtype=obj.dtype, + requires_grad=rg, + ) + if tosub: + tensor = tensor.tosub() + tensor._value = obj # is required in SemanticModel.forward + return tensor, True + + if isinstance(obj, collection_types): + if isinstance(obj, tuple): + result = [_inner(item) for item in obj] + if not any(r[1] for r in result): + return IRObject(name, value=obj, is_constant=is_constant), False + else: + return tuple(r[0] for r in result), True + if isinstance(obj, list): + result = [_inner(item) for item in obj] + if not any(r[1] for r in result): + return IRObject(name, value=obj, is_constant=is_constant), False + else: + return [r[0] for r in result], True + if isinstance(obj, dict): + if not all(isinstance(key, str) for key in obj.keys()): + raise TypeError(f"only support dict type with str key, but got {obj.keys()}.") + result = {k: _inner(v) for k, v in obj.items()} + if not any(r[1] for r in result.values()): + return IRObject(name, value=obj, is_constant=is_constant), False + else: + return {k: r[0] for k, r in result.items()}, True + if isinstance(obj, _DICT_VALUES_TYPE): + result = [_inner(item) for item in obj] + if not any(r[1] for r in result): + return IRObject(name, value=obj, is_constant=is_constant), False + else: + return {k: r[0] for k, r in enumerate(result)}.values(), True + if isinstance(obj, _DICT_ITEMS_TYPE): + result = {k: _inner(v) for k, v in obj} + if not any(r[1] for r in result.values()): + return IRObject(name, value=obj, is_constant=is_constant), False + else: + return {k: r[0] for k, r in result.items()}.items(), True + + return IRObject(name, value=obj, is_constant=is_constant), False + + return _inner(data)[0] + + @classmethod + def tosub_complex(cls, obj: Any) -> Any: + """ + Convert complex data type of tensor-like object into sub-tensor + + Args: + obj (Any): the complex data structure to be converted + """ + from nnscaler.ir.tensor import IRFullTensor + modifier = lambda t: t.tosub() if isinstance(t, IRFullTensor) else t + return IRCell.modify_objects_of_complex(obj, modifier) + def __repr__(self): return f'Object({self.name}{self.tid}, val={self.value}, is_constant={self.is_constant})' diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 9481130a..75ba5010 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -313,35 +313,6 @@ def _to_cpu(val: Any): return val -def to_ir_input(sample, name): - """Support complex of types: Tuple, List, Dict, torch.Tensor""" - if isinstance(sample, tuple): - return tuple(to_ir_input(t, name) for t in sample) - if isinstance(sample, list): - return list(to_ir_input(t, name) for t in sample) - if isinstance(sample, dict): - return {k: to_ir_input(v, str(k)) for k, v in sample.items()} - if isinstance(sample, torch.Tensor): - # note: we will always set tensor to require gradient, which may - # generate backward communications in adapter. However, as long as - # the data doesn't require gradient in real runtime, the backward - # communication will not be triggered. - # PyTorch only supports floating point and complex tensors for autograd. - # To align with PyTorch, we set requires_grad to False for other types. - requires_grad = sample.is_floating_point() or sample.is_complex() - tensor = IRFullTensor( - shape=sample.size(), - name=name, - requires_grad=requires_grad, - dtype=sample.dtype - ).tosub() - tensor._value = sample - if requires_grad: - tensor.grad = tensor.parent.grad.tosub() - return tensor - return IRObject(name, value=sample, is_constant=False) - - def _contains_uncommutable_data(ir_outputs: Any): """ only IRObject (but not IRTensor) is not commutable between gpus. @@ -654,8 +625,17 @@ def _gen_graph( else: raise ValueError(f"Input {node.target} not in dummy forward args, nor has default value.") for i in range(len(ir_dummy_inputs)): - ir_dummy_inputs[i] = to_ir_input(ir_dummy_inputs[i], fx_input_nodes[i].target) - # if the input is not a tensor, we should wrap it with IRObject + # note: we will always set tensor to require gradient, which may + # generate backward communications in adapter. However, as long as + # the data doesn't require gradient in real runtime, the backward + # communication will not be triggered. + ir_dummy_inputs[i] = IRObject.from_complex( + fx_input_nodes[i].target, ir_dummy_inputs[i], + requires_grad=True, + tosub=True, + is_constant=False, + ) + # if the input is a complex type, we should wrap it with IRObject if not isinstance(ir_dummy_inputs[i], IRObject): ir_dummy_inputs[i] = IRObject(fx_input_nodes[i].target, value=ir_dummy_inputs[i], is_constant=False) diff --git a/nnscaler/program.py b/nnscaler/program.py index 2ffddfb7..421f6d04 100644 --- a/nnscaler/program.py +++ b/nnscaler/program.py @@ -110,25 +110,12 @@ def __iter__(self): return self def __next__(self): - def generate_output(sample, name='data'): - """Support complex of types: Tuple, List, Dict, torch.Tensor""" - if isinstance(sample, tuple): - return tuple(generate_output(t, name) for t in sample) - if isinstance(sample, list): - return list(generate_output(t, name) for t in sample) - if isinstance(sample, dict): - return {k: generate_output(v, str(k)) for k, v in sample.items()} - if isinstance(sample, torch.Tensor): - tensor = IRFullTensor(list(sample.shape), name, dtype=sample.dtype).tosub() - tensor._value = sample - return tensor - return IRObject(name, value=sample, is_constant=False) # get dataloader sample sample = next(iter(self.dataloader)) if not isinstance(sample, tuple): sample = (sample,) # turn sample into IRObjects - outputs = tuple(generate_output(s) for s in sample) + outputs = tuple(IRObject.from_complex('data', s, tosub=True, requires_grad=False, is_constant=False) for s in sample) outputs = tuple(IRObject('data', value=out) if not isinstance(out, IRObject) else out for out in outputs) # create dataloader operation # the `self.irobj` is the IRObject standing for the non-tensor value of real dataloader. diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 046641a0..e10b8f9f 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -173,6 +173,7 @@ def enforce_zero_num_worker(cls) -> Generator[None, None, None]: _old__init__ = cls.__init__ def _new__init__(self, *args, **kwargs) -> None: kwargs['num_workers'] = 0 + kwargs['prefetch_factor'] = None _old__init__(self, *args, **kwargs) cls.__init__ = _new__init__ yield diff --git a/tests/ir/test_cten.py b/tests/ir/test_cten.py new file mode 100644 index 00000000..6c65b158 --- /dev/null +++ b/tests/ir/test_cten.py @@ -0,0 +1,133 @@ +import torch + +import pytest + +from nnscaler.ir.cten import IRObject +from nnscaler.ir.tensor import IRFullTensor, IRSubTensor +from nnscaler.graph.parser.fx.parser import TensorMetadata, DICT_VALUES_TYPE, DICT_ITEMS_TYPE + + +@pytest.mark.parametrize('tosub', [True, False]) +@pytest.mark.parametrize('requires_grad', [True, False, None]) +def test_from_complex(tosub, requires_grad): + tensor_type = IRSubTensor if tosub else IRFullTensor + rg = requires_grad + if rg is None: + rg = False + rgt = requires_grad + if rgt is None: + rgt = True + obj = IRObject.from_complex('n', 1, tosub=tosub, requires_grad=requires_grad) + assert type(obj) == IRObject and obj.value == 1 and not obj.is_constant and obj.name == 'n' + + obj = IRObject.from_complex('n', [1, 2], tosub=tosub, requires_grad=requires_grad) + assert type(obj) == IRObject and obj.value == [1, 2] and not obj.is_constant and obj.name == 'n' + + obj = IRObject.from_complex('n', {'a': 1, 'b': 2}, tosub=tosub, requires_grad=requires_grad) + assert type(obj) == IRObject and obj.value == {'a': 1, 'b': 2} and not obj.is_constant and obj.name == 'n' + + obj = IRObject.from_complex('n', {'a': {'c': [3, 4], 'd': [4, 5]}, 'b': [1,2]}, tosub=tosub, requires_grad=requires_grad) + assert type(obj) == IRObject and obj.value == {'a': {'c': [3, 4], 'd': [4, 5]}, 'b': [1,2]} and not obj.is_constant and obj.name == 'n' + + t1 = torch.tensor(1.0) + t2 = torch.tensor([2.0, 3.0], requires_grad=True) + + obj = IRObject.from_complex('n', t1, tosub=tosub, requires_grad=requires_grad) + assert type(obj) == tensor_type and id(obj.value) == id(t1) \ + and obj.shape == (1,) and obj.origin_shape == () and obj.dtype == torch.float \ + and obj.requires_grad == rg and not obj.is_constant \ + and obj.name == 'n' + + obj = IRObject.from_complex('n', [t1, t2, 1], tosub=tosub, requires_grad=requires_grad) + assert type(obj) == list and len(obj) == 3 + assert type(obj[0]) == tensor_type and id(obj[0].value) == id(t1) \ + and obj[0].shape == (1,) and obj[0].origin_shape == () and obj[0].dtype == torch.float \ + and obj[0].requires_grad == rg and not obj[0].is_constant \ + and obj[0].name == 'n' + assert type(obj[1]) == tensor_type and id(obj[1].value) == id(t2) \ + and obj[1].shape == (2,) and obj[1].origin_shape == (2,) and obj[1].dtype == torch.float \ + and obj[1].requires_grad == rgt and not obj[1].is_constant \ + and obj[1].name == 'n' + assert type(obj[2]) == IRObject and obj[2].value == 1 and not obj[2].is_constant and obj[2].name == 'n' + + obj = IRObject.from_complex('n', {'a': [1, 2, t1], 'b': 2}, tosub=tosub, requires_grad=requires_grad) + assert type(obj) == dict and len(obj) == 2 + x = obj['a'] + assert type(x) == list and len(x) == 3 + assert type(x[0]) == IRObject and x[0].value == 1 and not x[0].is_constant and x[0].name == 'n' + assert type(x[1]) == IRObject and x[1].value == 2 and not x[1].is_constant and x[1].name == 'n' + assert type(x[2]) == tensor_type and id(x[2].value) == id(t1) \ + and x[2].shape == (1,) and x[2].origin_shape == () and x[2].dtype == torch.float \ + and x[2].requires_grad == rg and not x[2].is_constant \ + and x[2].name == 'n' + y = obj['b'] + assert type(y) == IRObject and y.value == 2 and not y.is_constant and y.name == 'n' + + x = [t1, t2, 1] + obj = IRObject.from_complex('n', x, tosub=tosub, tensor_types=(), requires_grad=requires_grad) + assert type(obj) == IRObject and id(obj.value) == id(x) and not obj.is_constant and obj.name == 'n' + + obj = IRObject.from_complex('n', x, tosub=tosub, collection_types=(tuple,), requires_grad=requires_grad) + assert type(obj) == IRObject and id(obj.value) == id(x) and not obj.is_constant and obj.name == 'n' + + obj = IRObject.from_complex('n', [t1, [1, 2, {'a': 3}], (4, 5, {'b': 6, 'c': t2})], tosub=tosub, requires_grad=requires_grad) + assert type(obj) == list and len(obj) == 3 + assert type(obj[0]) == tensor_type and id(obj[0].value) == id(t1) \ + and obj[0].shape == (1,) and obj[0].origin_shape == () and obj[0].dtype == torch.float \ + and obj[0].requires_grad == rg and not obj[0].is_constant \ + and obj[0].name == 'n' + assert type(obj[1]) == IRObject and obj[1].value == [1, 2, {'a': 3}] and not obj[1].is_constant and obj[1].name == 'n' + x = obj[2] + assert type(x) == tuple and len(x) == 3 + assert type(x[0]) == IRObject and x[0].value == 4 and not x[0].is_constant and x[0].name == 'n' + assert type(x[1]) == IRObject and x[1].value == 5 and not x[1].is_constant and x[1].name == 'n' + y = x[2] + assert type(y) == dict and len(y) == 2 + assert type(y['b']) == IRObject and y['b'].value == 6 and not y['b'].is_constant and y['b'].name == 'n' + assert type(y['c']) == tensor_type and id(y['c'].value) == id(t2) \ + and y['c'].shape == (2,) and y['c'].origin_shape == (2,) and y['c'].dtype == torch.float \ + and y['c'].requires_grad == rgt and not y['c'].is_constant \ + and y['c'].name == 'n' + + t1 = TensorMetadata(shape=(), dtype=torch.float, requires_grad=False, + stride=None, memory_format=None, is_quantized=None, qparams=None) + t2 = TensorMetadata(shape=(2,), dtype=torch.float, requires_grad=True, + stride=None, memory_format=None, is_quantized=None, qparams=None) + + obj = IRObject.from_complex('n', {'a': t1, 'b': t2}.values(), + collection_types=(DICT_VALUES_TYPE,), + tensor_types=(TensorMetadata,), + tosub=tosub, requires_grad=requires_grad + ) + assert type(obj) == DICT_VALUES_TYPE and len(obj) == 2 + x = list(obj)[0] + assert type(x) == tensor_type and id(x.value) == id(t1) \ + and x.shape == (1,) and x.origin_shape == () and x.dtype == torch.float \ + and x.requires_grad == rg and not x.is_constant \ + and x.name == 'n' + y = list(obj)[1] + assert type(y) == tensor_type and id(y.value) == id(t2) \ + and y.shape == (2,) and y.origin_shape == (2,) and y.dtype == torch.float \ + and y.requires_grad == rgt and not y.is_constant \ + and y.name == 'n' + + obj = IRObject.from_complex('n', {'a': t1, 'b': t2}.items(), + collection_types=(DICT_ITEMS_TYPE,), + tensor_types=(TensorMetadata,), + tosub=tosub, requires_grad=requires_grad + ) + assert type(obj) == DICT_ITEMS_TYPE and len(obj) == 2 + x = list(obj)[0] + assert x[0] == 'a' + x = x[1] + assert type(x) == tensor_type and id(x.value) == id(t1) \ + and x.shape == (1,) and x.origin_shape == () and x.dtype == torch.float \ + and x.requires_grad == rg and not x.is_constant \ + and x.name == 'n' + y = list(obj)[1] + assert y[0] == 'b' + y = y[1] + assert type(y) == tensor_type and id(y.value) == id(t2) \ + and y.shape == (2,) and y.origin_shape == (2,) and y.dtype == torch.float \ + and y.requires_grad == rgt and not y.is_constant \ + and y.name == 'n' From adc72f6f2404ce4165a6547a65c81372d0e13a33 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 3 Sep 2024 08:32:01 +0000 Subject: [PATCH 1722/1892] Merged PR 2257: minitrainer: add run method for both train/compile minitrainer: add run method --- examples/llama3_8B_128K/train.py | 3 +-- nnscaler/cli/train.py | 3 +-- nnscaler/cli/trainer.py | 13 +++++++--- nnscaler/cli/trainer_args.py | 4 +++ tests/cli/test_trainer.py | 25 ++++++++++--------- .../lightning/pytorch/test_strategy.py | 2 +- tests/runtime/test_f16_optimizer.py | 4 +-- 7 files changed, 31 insertions(+), 23 deletions(-) diff --git a/examples/llama3_8B_128K/train.py b/examples/llama3_8B_128K/train.py index 8b116cb0..f74c3318 100644 --- a/examples/llama3_8B_128K/train.py +++ b/examples/llama3_8B_128K/train.py @@ -231,8 +231,7 @@ def collate_fn(samples): ) trainer = Trainer(train_args=trainer_args) - if args.run_mode == 'run': - trainer.train() + trainer.run() if __name__ == '__main__': diff --git a/nnscaler/cli/train.py b/nnscaler/cli/train.py index 3070dbf2..c6981d76 100644 --- a/nnscaler/cli/train.py +++ b/nnscaler/cli/train.py @@ -8,5 +8,4 @@ if __name__ == '__main__': nnscaler.utils.set_default_logger_level(level=logging.INFO) trainer = Trainer() - if trainer.train_args.run_mode == 'run': - trainer.train() + trainer.run() diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index f3faa6e8..188801e0 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -90,7 +90,12 @@ def __init__(self, self.max_train_steps = None self.loggers = [] self.hook = None + + def run(self): self._setup() + if self.train_args.compile_mode: + return + self._train() def _fix_input(self, input): if isinstance(input, dict): @@ -126,7 +131,7 @@ def _load_dummy_input(self): def _setup(self): self.train_args.init_env() - compile_only = self.train_args.run_mode == 'compile' + compile_only = self.train_args.compile_mode if is_running_distributed(): nnscaler.init() @@ -530,7 +535,7 @@ def _fix_batches(self, batches): batches += [self.dummy_input] * gap return batches, is_dummy_batch - def train(self): + def _train(self): logger.info('Training...') # reset peak memory stats before training # So that we can get accurate peak memory usage for each step @@ -550,7 +555,7 @@ def train(self): torch.distributed.barrier() self.hook.on_epoch_start(self, epoch) - self.train_epoch(epoch) + self._train_epoch(epoch) self.hook.on_epoch_end(self, epoch) if self.lr_scheduler and self.train_args.lr_scheduler.interval == 'epoch': @@ -639,7 +644,7 @@ def _validate(self, step_stat: _StepStat): logger.info(self._format_metrics(f'Validation', None, val_metrics)) return step_stat.val_loss - def train_epoch(self, epoch): + def _train_epoch(self, epoch): VAL_STATUS_NO = 0 # not validated or saved VAL_STATUS_VAL = 1 # validated but not saved VAL_STATUS_SAVE = 2 # validated and saved diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index f17b13c5..dfa91552 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -486,6 +486,10 @@ def update_freq(self): def enable_log_progress(self): return not self.enable_progress_bar and self.log_progress_every_n_train_steps + @property + def compile_mode(self) -> bool: + return self.run_mode == 'compile' + @property def param_dtype(self) -> torch.dtype: return _PRECISION_MAP[self.precision['param']] diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index a92cbea8..f3d7e80b 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -36,7 +36,7 @@ def trainer_logging_worker(save_dir): '--log.1.args.project', 'nnscaler', '--log.1.args.mode', 'offline', ]) - trainer.train() + trainer.run() torch.distributed.barrier() @@ -66,7 +66,7 @@ def test_trainer_compile_worker(tmp_path): config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) gen_savedir = save_dir / 'gen' # compile only - Trainer([ + trainer = Trainer([ '-f', config_path, '--max_epochs', '2', '--gen_savedir', str(gen_savedir), @@ -76,6 +76,7 @@ def test_trainer_compile_worker(tmp_path): '--run_mode', 'compile', '--broadcast_strategy', 'none', ]) + trainer.run() assert set([f.name for f in gen_savedir.glob('**/*.py')]) == set(['gencode0.py', 'gencode1.py', 'gencode2.py', 'gencode3.py']) shutil.rmtree(gen_savedir) @@ -107,7 +108,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', ]) - trainer.train() + trainer.run() ckpt_files = set(ckpt_savedir.glob('**/*.ckpt')) assert len(ckpt_files)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last @@ -130,7 +131,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', ]) - trainer.train() + trainer.run() ckpt0_files0 = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} assert len(ckpt0_files0)/4 == min(30, trainer.total_train_steps_per_epoch * 2) + 2 # 2 for best/last @@ -150,7 +151,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', ]) - trainer.train() + trainer.run() ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} # nothing should be updated in this case. assert ckpt0_files0 == ckpt0_files0_x @@ -178,7 +179,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', ]) - trainer.train() + trainer.run() left_files = { f: f.stat().st_mtime_ns for f in ckpt0_files0.keys() if f.exists() and f.parent.name not in ['last', 'best'] @@ -207,7 +208,7 @@ def trainer_resume_worker(save_dir, save_type, bf16): '--checkpoint.resume_from', str(ckpt1_savedir / 'merged.pt'), '--checkpoint.keep_last_n_checkpoints', '30', ]) - trainer.train() + trainer.run() left_files = { f: f.stat().st_mtime_ns for f in ckpt0_files0.keys() if f.exists() and f.parent.name not in ['last', 'best'] @@ -266,7 +267,7 @@ def trainer_last_checkpoint_worker(save_dir): '--checkpoint.every_n_train_steps', '15', '--checkpoint.save_dir', str(ckpt_savedir), ]) - trainer.train() + trainer.run() torch.distributed.barrier() # make sure the last checkpoint is saved. @@ -313,7 +314,7 @@ def trainer_loss_reduction_worker(save_dir): '--hook.after_aggregate_val_step_outputs', 'tests.cli.test_trainer.after_aggregate_val_step_outputs', ]) - trainer.train() + trainer.run() # get a copy train_loss_mean = _train_losses[:] @@ -340,7 +341,7 @@ def trainer_loss_reduction_worker(save_dir): '--hook.after_aggregate_val_step_outputs', 'tests.cli.test_trainer.after_aggregate_val_step_outputs', ]) - trainer.train() + trainer.run() torch.distributed.barrier() assert len(train_loss_mean) == len(_train_losses) @@ -409,7 +410,7 @@ def trainer_per_token_worker(save_dir): '--hook.before_gnorm_clip', 'tests.cli.test_trainer.before_gnorm_clip', ]) - trainer.train() + trainer.run() # get a copy grads = _before_step_grads @@ -433,7 +434,7 @@ def trainer_per_token_worker(save_dir): '--hook.before_gnorm_clip', 'tests.cli.test_trainer.before_gnorm_clip', ]) - trainer.train() + trainer.run() torch.distributed.barrier() diff --git a/tests/integration/lightning/pytorch/test_strategy.py b/tests/integration/lightning/pytorch/test_strategy.py index 350a5821..39229f1b 100644 --- a/tests/integration/lightning/pytorch/test_strategy.py +++ b/tests/integration/lightning/pytorch/test_strategy.py @@ -320,7 +320,7 @@ def on_val_step_end(trainer: Trainer, outputs, batches, idx) -> None: _correctnes_worker_train_loss_history.clear() _correctnes_worker_single_loss_history.clear() _correctnes_worker_val_loss_history.clear() - trainer.train() + trainer.run() return _correctnes_worker_update_history, trainer.model.fullmap, \ _correctnes_worker_val_loss_history, \ _correctnes_worker_train_loss_history, \ diff --git a/tests/runtime/test_f16_optimizer.py b/tests/runtime/test_f16_optimizer.py index f077f2ea..95609c90 100644 --- a/tests/runtime/test_f16_optimizer.py +++ b/tests/runtime/test_f16_optimizer.py @@ -23,7 +23,7 @@ def trainer_worker(save_dir): '--gen_savedir', str(gen_savedir), '--checkpoint.save_dir', str(ckpt0_savedir), ]) - trainer.train() + trainer.run() torch.distributed.barrier() # train with normal optimizer @@ -34,7 +34,7 @@ def trainer_worker(save_dir): '--gen_savedir', str(gen_savedir), '--checkpoint.save_dir', str(ckpt1_savedir), ]) - trainer.train() + trainer.run() torch.distributed.barrier() if torch.distributed.get_rank() == 0: From a375445a9ef251f0112079ef56d511470edeed15 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Fri, 6 Sep 2024 06:21:37 +0000 Subject: [PATCH 1723/1892] Merged PR 2258: remove unpolished examples Remove unpolished examples. Those examples will be added back when polished. --- examples/llama/__init__.py | 0 examples/llama/chat.py | 67 ----- examples/llama/generation.py | 344 ------------------------- examples/llama/model.py | 307 ---------------------- examples/llama/test_chat_completion.py | 108 -------- examples/llama/tokenizer.py | 41 --- examples/mlp/__init__.py | 0 examples/mlp/policy/__init__.py | 0 examples/mlp/policy/gallery.py | 89 ------- examples/mlp/train.py | 101 -------- examples/nlp/__init__.py | 0 examples/nlp/blocks/__init__.py | 0 examples/nlp/blocks/attention.py | 160 ------------ examples/nlp/blocks/mlp.py | 32 --- examples/nlp/blocks/transformer.py | 53 ---- examples/nlp/gpt/__init__.py | 0 examples/nlp/gpt/model.py | 103 -------- examples/nlp/gpt/policy/__init__.py | 0 examples/nlp/gpt/policy/mpmd.py | 126 --------- examples/nlp/gpt/policy/spmd.py | 93 ------- examples/nlp/gpt/train.py | 129 ---------- examples/nlp/mbart/__init__.py | 0 examples/nlp/mbart/model.py | 183 ------------- examples/nlp/mbart/policy/__init__.py | 0 examples/nlp/mbart/policy/gallery.py | 186 ------------- examples/nlp/mbart/train.py | 149 ----------- 26 files changed, 2271 deletions(-) delete mode 100644 examples/llama/__init__.py delete mode 100644 examples/llama/chat.py delete mode 100644 examples/llama/generation.py delete mode 100644 examples/llama/model.py delete mode 100644 examples/llama/test_chat_completion.py delete mode 100644 examples/llama/tokenizer.py delete mode 100644 examples/mlp/__init__.py delete mode 100644 examples/mlp/policy/__init__.py delete mode 100644 examples/mlp/policy/gallery.py delete mode 100644 examples/mlp/train.py delete mode 100644 examples/nlp/__init__.py delete mode 100644 examples/nlp/blocks/__init__.py delete mode 100644 examples/nlp/blocks/attention.py delete mode 100644 examples/nlp/blocks/mlp.py delete mode 100644 examples/nlp/blocks/transformer.py delete mode 100644 examples/nlp/gpt/__init__.py delete mode 100644 examples/nlp/gpt/model.py delete mode 100644 examples/nlp/gpt/policy/__init__.py delete mode 100644 examples/nlp/gpt/policy/mpmd.py delete mode 100644 examples/nlp/gpt/policy/spmd.py delete mode 100644 examples/nlp/gpt/train.py delete mode 100644 examples/nlp/mbart/__init__.py delete mode 100644 examples/nlp/mbart/model.py delete mode 100644 examples/nlp/mbart/policy/__init__.py delete mode 100644 examples/nlp/mbart/policy/gallery.py delete mode 100644 examples/nlp/mbart/train.py diff --git a/examples/llama/__init__.py b/examples/llama/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/llama/chat.py b/examples/llama/chat.py deleted file mode 100644 index 9e5925b9..00000000 --- a/examples/llama/chat.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -pip install fire sentencepiece - -PYTHONPATH=.:$PYTHONPATH torchrun \ - --nproc_per_node=1 \ - examples/llama/chat.py \ - --ckpt_dir=/home/t-zhiqilin/llama/llama-2-7b-chat \ - --tokenizer_path=/home/t-zhiqilin/llama/tokenizer.model \ - --max_seq_len 512 --max_batch_size 8 --temperature 0 \ - --use-cube -""" - -from typing import Optional - -import fire -import logging - -from examples.llama.generation import Llama - -import nnscaler -from nnscaler.utils import set_default_logger_level - -nnscaler.init() -set_default_logger_level(level=logging.WARNING) -logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) - - -def main( - ckpt_dir: str, - tokenizer_path: str, - temperature: float = 0.6, - top_p: float = 0.9, - max_seq_len: int = 512, - max_batch_size: int = 8, - max_gen_len: Optional[int] = None, - use_cube: bool = False, -): - generator = Llama.build( - ckpt_dir=ckpt_dir, - tokenizer_path=tokenizer_path, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - use_cube=use_cube, - ) - - dialog = [ - {"role": "system", "content": - "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature."}, - ] - - print('Assistant: Hello, this is Llama 2') - while True: - user_content = input("Prompt >> ") - dialog.append({"role": "user", "content": user_content}) - result = generator.chat_completion( - [dialog], - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - )[0] - assit_content = result['generation']['content'] - print(f"{result['generation']['role'].capitalize()}: {assit_content}") - dialog.append({"role": "assistant", "content": assit_content}) - - -if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file diff --git a/examples/llama/generation.py b/examples/llama/generation.py deleted file mode 100644 index d603da48..00000000 --- a/examples/llama/generation.py +++ /dev/null @@ -1,344 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import json -import os -import sys -import time -from pathlib import Path -from typing import List, Literal, Optional, Tuple, TypedDict - -import torch -import torch.nn.functional as F - -from examples.llama.model import ModelArgs, Transformer -from examples.llama.tokenizer import Tokenizer - -Role = Literal["system", "user", "assistant"] - -import nnscaler -from nnscaler.compiler import compile -from nnscaler.utils import load_model -from nnscaler.flags import CompileFlag - - -class Message(TypedDict): - role: Role - content: str - - -class CompletionPrediction(TypedDict, total=False): - generation: str - tokens: List[str] # not required - logprobs: List[float] # not required - - -class ChatPrediction(TypedDict, total=False): - generation: Message - tokens: List[str] # not required - logprobs: List[float] # not required - - -Dialog = List[Message] - -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>\n", "\n<>\n\n" - -SPECIAL_TAGS = [B_INST, E_INST, "<>", "<>"] -UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt." - - -class Llama: - @staticmethod - def build( - ckpt_dir: str, - tokenizer_path: str, - max_seq_len: int, - max_batch_size: int, - use_cube: bool, - ) -> "Llama": - - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - torch.cuda.set_device(local_rank) - - # seed must be the same in all processes - torch.manual_seed(1) - - if local_rank > 0: - sys.stdout = open(os.devnull, "w") - - start_time = time.time() - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" - # assert model_parallel_size == len( - # checkpoints - # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" - ckpt_path = checkpoints[0] - # ckpt_path = checkpoints[get_model_parallel_rank()] - checkpoint = torch.load(ckpt_path, map_location="cpu") - with open(Path(ckpt_dir) / "params.json", "r") as f: - params = json.loads(f.read()) - - model_args: ModelArgs = ModelArgs( - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - **params, - ) - tokenizer = Tokenizer(model_path=tokenizer_path) - model_args.vocab_size = tokenizer.n_words - torch.set_default_tensor_type(torch.cuda.HalfTensor) - model = Transformer(model_args) - model.load_state_dict(checkpoint, strict=False) - print(f"Loaded in {time.time() - start_time:.2f} seconds") - - return Llama(model, tokenizer, use_cube) - - def __init__(self, model: Transformer, tokenizer: Tokenizer, use_cube: bool): - self.model = model - self.tokenizer = tokenizer - - # ======================= cube initilizer ================= - self.use_cube = use_cube - if use_cube: - print(f"Build using cube engine") - CompileFlag.disable_code_line_info = False - self.build_inference() - - def build_inference(self): - - sample_tokens = torch.randint( - 1, 1000, size=(4, 38), dtype=torch.int64) - - def policy(graph, resource): - from nnscaler.ir.operator import IRFwOperation - for fwop in graph.select(ntype=IRFwOperation): - graph.assign(fwop, 0) - return graph - - @compile(self.model, sample_tokens, 0, - PAS=policy, model_constant_folding=False) - def infer(model: torch.nn.Module, tokens: torch.Tensor, prev_pos: int): - logits = model(tokens, prev_pos) - return logits - - params = self.model.params - vocab_size, n_layers = params.vocab_size, params.n_layers - - del self.model - self.model = load_model() - - # TODO: support auto reset non-parameter attributes for llama model - self.model.params = params - self.model.vocab_size = vocab_size - self.model.n_layers = n_layers - self.infer_fn = (infer,) - - @torch.inference_mode() - def generate( - self, - prompt_tokens: List[List[int]], - max_gen_len: int, - temperature: float = 0.6, - top_p: float = 0.9, - logprobs: bool = False, - echo: bool = False, - ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: - params = self.model.params - bsz = len(prompt_tokens) - assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) - - min_prompt_len = min(len(t) for t in prompt_tokens) - max_prompt_len = max(len(t) for t in prompt_tokens) - assert max_prompt_len <= params.max_seq_len - total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) - - pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") - for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") - if logprobs: - token_logprobs = torch.zeros_like(tokens, dtype=torch.float) - - prev_pos = 0 - eos_reached = torch.tensor([False] * bsz, device="cuda") - input_text_mask = tokens != pad_id - for cur_pos in range(min_prompt_len, total_len): - if self.use_cube: - logits = self.infer_fn[0](self.model, tokens[:, prev_pos:cur_pos], prev_pos) - else: - logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) - if logprobs: - token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( - input=logits.transpose(1, 2), - target=tokens[:, prev_pos + 1 : cur_pos + 1], - reduction="none", - ignore_index=pad_id, - ) - if temperature > 0: - probs = torch.softmax(logits[:, -1] / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) - else: - next_token = torch.argmax(logits[:, -1], dim=-1) - - next_token = next_token.reshape(-1) - # only replace token if prompt has already been generated - next_token = torch.where( - input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token - ) - tokens[:, cur_pos] = next_token - eos_reached |= (~input_text_mask[:, cur_pos]) & ( - next_token == self.tokenizer.eos_id - ) - prev_pos = cur_pos - if all(eos_reached): - break - - if logprobs: - token_logprobs = token_logprobs.tolist() - out_tokens, out_logprobs = [], [] - for i, toks in enumerate(tokens.tolist()): - # cut to max gen len - start = 0 if echo else len(prompt_tokens[i]) - toks = toks[start : len(prompt_tokens[i]) + max_gen_len] - probs = None - if logprobs: - probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] - # cut to eos tok if any - if self.tokenizer.eos_id in toks: - eos_idx = toks.index(self.tokenizer.eos_id) - toks = toks[:eos_idx] - probs = probs[:eos_idx] if logprobs else None - out_tokens.append(toks) - out_logprobs.append(probs) - return (out_tokens, out_logprobs if logprobs else None) - - def text_completion( - self, - prompts: List[str], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - echo: bool = False, - ) -> List[CompletionPrediction]: - if max_gen_len is None: - max_gen_len = self.model.params.max_seq_len - 1 - prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] - generation_tokens, generation_logprobs = self.generate( - prompt_tokens=prompt_tokens, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, - echo=echo, - ) - if logprobs: - return [ - { - "generation": self.tokenizer.decode(t), - "tokens": [self.tokenizer.decode(x) for x in t], - "logprobs": logprobs_i, - } - for t, logprobs_i in zip(generation_tokens, generation_logprobs) - ] - return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] - - def chat_completion( - self, - dialogs: List[Dialog], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - ) -> List[ChatPrediction]: - if max_gen_len is None: - max_gen_len = self.model.params.max_seq_len - 1 - prompt_tokens = [] - unsafe_requests = [] - for dialog in dialogs: - unsafe_requests.append( - any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) - ) - if dialog[0]["role"] == "system": - dialog = [ - { - "role": dialog[1]["role"], - "content": B_SYS - + dialog[0]["content"] - + E_SYS - + dialog[1]["content"], - } - ] + dialog[2:] - assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( - [msg["role"] == "assistant" for msg in dialog[1::2]] - ), ( - "model only supports 'system', 'user' and 'assistant' roles, " - "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" - ) - dialog_tokens: List[int] = sum( - [ - self.tokenizer.encode( - f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", - bos=True, - eos=True, - ) - for prompt, answer in zip( - dialog[::2], - dialog[1::2], - ) - ], - [], - ) - assert ( - dialog[-1]["role"] == "user" - ), f"Last message must be from user, got {dialog[-1]['role']}" - dialog_tokens += self.tokenizer.encode( - f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", - bos=True, - eos=False, - ) - prompt_tokens.append(dialog_tokens) - - generation_tokens, generation_logprobs = self.generate( - prompt_tokens=prompt_tokens, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, - ) - if logprobs: - return [ - { - "generation": { - "role": "assistant", - "content": self.tokenizer.decode(t) - if not unsafe - else UNSAFE_ERROR, - }, - "tokens": [self.tokenizer.decode(x) for x in t], - "logprobs": logprobs_i, - } - for t, logprobs_i, unsafe in zip( - generation_tokens, generation_logprobs, unsafe_requests - ) - ] - return [ - { - "generation": { - "role": "assistant", - "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR, - } - } - for t, unsafe in zip(generation_tokens, unsafe_requests) - ] - - -def sample_top_p(probs, p): - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token \ No newline at end of file diff --git a/examples/llama/model.py b/examples/llama/model.py deleted file mode 100644 index 64ae7a3b..00000000 --- a/examples/llama/model.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import math -from dataclasses import dataclass -from typing import Any, Optional, Tuple - - -import torch -import torch.nn.functional as F -from torch import nn - -import nnscaler - -@dataclass -class ModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 # defined later by tokenizer - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - - max_batch_size: int = 32 - max_seq_len: int = 2048 - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -# TODO: fix annotation -@nnscaler.register_op('*, *, 38^ 64^ -> *, *') -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -@nnscaler.register_op('N seqlen^, N seqlen^ H^ -> 1 1 seqlen^ seqlen^') -def create_mask(tokens: torch.Tensor, h: torch.Tensor, start_pos: int): - seqlen = tokens.shape[1] - mask = None - if seqlen > 1: - mask = torch.full( - (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device - ) - mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) - return mask - - -@nnscaler.register_op('N seqlen *, 1 1 * -> N seqlen *') -def apply_mask(x: torch.Tensor, mask: torch.Tensor): - return x if mask is None else x + mask - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - self.n_local_heads = args.n_heads - self.n_local_kv_heads = self.n_kv_heads - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // args.n_heads - - self.wq = torch.nn.Linear( - args.dim, - args.n_heads * self.head_dim, - bias=False, - ) - self.wk = torch.nn.Linear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=False, - ) - self.wv = torch.nn.Linear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=False, - ) - self.wo = torch.nn.Linear( - args.n_heads * self.head_dim, - args.dim, - bias=False, - ) - - self.cache_k = torch.zeros( - ( - args.max_batch_size, - args.max_seq_len, - self.n_local_kv_heads, - self.head_dim, - ) - ).cuda() - self.cache_v = torch.zeros( - ( - args.max_batch_size, - args.max_seq_len, - self.n_local_kv_heads, - self.head_dim, - ) - ).cuda() - - def forward( - self, - x: torch.Tensor, - start_pos: int, - freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], - ): - bsz, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - # TODO: support register function with kwargs on tensor - # xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis) - - # modification: move `.to(xq)` to the belowing - # self.cache_k = self.cache_k.to(xq) - # self.cache_v = self.cache_v.to(xq) - - self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk - self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv - - keys = self.cache_k.to(xq)[:bsz, : start_pos + seqlen] - values = self.cache_v.to(xq)[:bsz, : start_pos + seqlen] - - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) - - # NOTE: cube doesn't support dynamic graph - # if mask is not None: - # scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) - scores = apply_mask(scores, mask) - - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) - output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) - return self.wo(output) - - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = torch.nn.Linear( - dim, hidden_dim, bias=False - ) - self.w2 = torch.nn.Linear( - hidden_dim, dim, bias=False - ) - self.w3 = torch.nn.Linear( - dim, hidden_dim, bias=False - ) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, layer_id: int, args: ModelArgs): - super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.head_dim = args.dim // args.n_heads - self.attention = Attention(args) - self.feed_forward = FeedForward( - dim=args.dim, - hidden_dim=4 * args.dim, - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - ) - self.layer_id = layer_id - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - - def forward( - self, - x: torch.Tensor, - start_pos: int, - freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], - ): - h = x + self.attention.forward( - self.attention_norm(x), start_pos, freqs_cis, mask - ) - out = h + self.feed_forward.forward(self.ffn_norm(h)) - return out - - -class Transformer(nn.Module): - def __init__(self, params: ModelArgs): - super().__init__() - self.params = params - self.vocab_size = params.vocab_size - self.n_layers = params.n_layers - - self.tok_embeddings = torch.nn.Embedding( - params.vocab_size, params.dim - ) - - self.layers = torch.nn.ModuleList() - for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock(layer_id, params)) - - self.norm = RMSNorm(params.dim, eps=params.norm_eps) - self.output = torch.nn.Linear( - params.dim, params.vocab_size, bias=False - ) - - self.freqs_cis = precompute_freqs_cis( - self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 - ) - - @torch.inference_mode() - def forward(self, tokens: torch.Tensor, start_pos: int): - - # TODO: support tracking dependency on kwarg IRObject - start_pos = start_pos + 0 - - _bsz, seqlen = tokens.shape - h = self.tok_embeddings(tokens) - # self.freqs_cis = self.freqs_cis.to(h.device) - freqs_cis = self.freqs_cis.to(h.device)[start_pos : start_pos + seqlen] - - # NOTE: cube doesn't support dynamic graph - # mask = None - # if seqlen > 1: - # mask = torch.full( - # (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device - # ) - # mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) - mask = create_mask(tokens, h, start_pos) - - for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) - h = self.norm(h) - output = self.output(h).float() - return output \ No newline at end of file diff --git a/examples/llama/test_chat_completion.py b/examples/llama/test_chat_completion.py deleted file mode 100644 index 1d5dab32..00000000 --- a/examples/llama/test_chat_completion.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -pip install fire sentencepiece - -PYTHONPATH=.:$PYTHONPATH torchrun \ - --nproc_per_node=1 \ - examples/llama/test_chat_completion.py \ - --ckpt_dir=/home/t-zhiqilin/llama/llama-2-7b-chat \ - --tokenizer_path=/home/t-zhiqilin/llama/tokenizer.model \ - --max_seq_len 512 --max_batch_size 8 --temperature 0 \ - --use-cube -""" - -from typing import Optional - -import fire -import logging - -from examples.llama.generation import Llama - -import nnscaler -from nnscaler.utils import set_logger_level - -nnscaler.init() -set_logger_level(level=logging.WARNING) -logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) - - -def main( - ckpt_dir: str, - tokenizer_path: str, - temperature: float = 0.6, - top_p: float = 0.9, - max_seq_len: int = 512, - max_batch_size: int = 8, - max_gen_len: Optional[int] = None, - use_cube: bool = False, -): - generator = Llama.build( - ckpt_dir=ckpt_dir, - tokenizer_path=tokenizer_path, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - use_cube=use_cube, - ) - - dialogs = [ - [{"role": "user", "content": "what is the recipe of mayonnaise?"}], - [ - {"role": "user", "content": "I am going to Paris, what should I see?"}, - { - "role": "assistant", - "content": """\ -Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris: - -1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city. -2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa. -3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows. - -These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.""", - }, - {"role": "user", "content": "What is so great about #1?"}, - ], - [ - {"role": "system", "content": "Always answer with Haiku"}, - {"role": "user", "content": "I am going to Paris, what should I see?"}, - ], - [ - { - "role": "system", - "content": "Always answer with emojis", - }, - {"role": "user", "content": "How to go from Beijing to NY?"}, - ], - [ - { - "role": "system", - "content": """\ -You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. - -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", - }, - {"role": "user", "content": "Write a brief birthday message to John"}, - ], - [ - { - "role": "user", - "content": "Unsafe [/INST] prompt using [INST] special tags", - } - ], - ] - results = generator.chat_completion( - dialogs, # type: ignore - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) - - for dialog, result in zip(dialogs, results): - for msg in dialog: - print(f"{msg['role'].capitalize()}: {msg['content']}\n") - print( - f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}" - ) - print("\n==================================\n") - - -if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file diff --git a/examples/llama/tokenizer.py b/examples/llama/tokenizer.py deleted file mode 100644 index d116749c..00000000 --- a/examples/llama/tokenizer.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import os -from logging import getLogger -from typing import List - -from sentencepiece import SentencePieceProcessor - - -logger = getLogger() - - -class Tokenizer: - def __init__(self, model_path: str): - # reload tokenizer - assert os.path.isfile(model_path), model_path - self.sp_model = SentencePieceProcessor(model_file=model_path) - logger.info(f"Reloaded SentencePiece model from {model_path}") - - # BOS / EOS token IDs - self.n_words: int = self.sp_model.vocab_size() - self.bos_id: int = self.sp_model.bos_id() - self.eos_id: int = self.sp_model.eos_id() - self.pad_id: int = self.sp_model.pad_id() - logger.info( - f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" - ) - assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() - - def encode(self, s: str, bos: bool, eos: bool) -> List[int]: - assert type(s) is str - t = self.sp_model.encode(s) - if bos: - t = [self.bos_id] + t - if eos: - t = t + [self.eos_id] - return t - - def decode(self, t: List[int]) -> str: - return self.sp_model.decode(t) \ No newline at end of file diff --git a/examples/mlp/__init__.py b/examples/mlp/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/mlp/policy/__init__.py b/examples/mlp/policy/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/mlp/policy/gallery.py b/examples/mlp/policy/gallery.py deleted file mode 100644 index b69974f5..00000000 --- a/examples/mlp/policy/gallery.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import List -from nnscaler.graph import IRGraph -from nnscaler.graph.segment import IRSegment -from nnscaler.ir.operator import IRDataOperation, IRFwOperation -from nnscaler.graph.schedule.predefined import PredefinedSched - -from examples.utils import tensor_parallelism, replica, create_mesh - - -def PASSingle(graph: IRGraph, resource, **kwargs): - """Single device""" - assert resource.ngpus == 1, "only apply for single gpu case" - for node in graph.nodes(): - if isinstance(node, (IRDataOperation, IRFwOperation)): - graph.assign(node, 0) - return graph - - -def PASData(graph: IRGraph, resource, **kwargs): - """Data Parallellism""" - devs = list(range(resource.ngpus)) - for node in graph.select(ntype=IRFwOperation): - tensor_parallelism(graph, node, idx=0, dim=0, devs=devs) - for node in graph.select(ntype=IRDataOperation): - replica(graph, node, devs=devs) - return graph - - -def PASCol(graph: IRGraph, resource, **kwargs): - """Linear Column Parallel""" - devs = list(range(resource.ngpus)) - for node in graph.select(name='linear'): - tensor_parallelism(graph, node, idx=1, dim=0, devs=devs) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - replica(graph, node, devs=devs) - return graph - - -def PASRow(graph: IRGraph, resource, **kwargs): - """Linear Row Parallel""" - devs = list(range(resource.ngpus)) - for node in graph.select(name='linear'): - tensor_parallelism(graph, node, idx=0, dim=1, devs=devs) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - replica(graph, node, devs=devs) - return graph - - -def PASMegatronTP(graph: IRGraph, resource, **kwargs): - """Linear Hybrid Parallelism (Megatron)""" - devs = list(range(resource.ngpus)) - for idx, node in enumerate(graph.select(name='linear')): - tensor_parallelism(graph, node, idx=1, dim=idx%2, devs=devs) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - replica(graph, node, devs=devs) - return graph - - -def PASMegatron(graph: IRGraph, resource, nmicros: int, tp_size: int, **kwargs): - - num_stages = resource.ngpus // tp_size - _, tp_mesh = create_mesh(resource.ngpus, (num_stages, tp_size)) - - # group to sub-graphs - linears = graph.select(name='linear') - stage_start_nodes = linears[::len(linears) // num_stages][:num_stages] - graph.staging(stage_start_nodes) - - segments = graph.select(ntype=IRSegment, flatten=False) - fsegs = [seg for seg in segments if seg.isfw()] - - for sid, segment in enumerate(fsegs): - # get tensor parallel group - tp_group = tp_mesh[sid] - for idx, node in enumerate(segment.nodes()): - if node.name == 'linear': - tensor_parallelism(graph, node, idx=1, dim=idx%2, devs=tp_group) - else: - replica(graph, node, devs=tp_group) - - for dl in graph.select(ntype=IRDataOperation): - replica(graph, dl, devs=list(range(resource.ngpus))) - - PredefinedSched.sched_1f1b(graph, nmicros, num_stages) - return graph - diff --git a/examples/mlp/train.py b/examples/mlp/train.py deleted file mode 100644 index dfefff34..00000000 --- a/examples/mlp/train.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -PYTHONPATH=.:$PYTHONPATH torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASMegatronTP -""" - -import torch -from torch import nn -from functools import partial - -import nnscaler -from nnscaler.compiler import compile -from nnscaler.profiler import CudaTimer -from nnscaler.profiler.timer import print_each_rank -from nnscaler.runtime.utils import microbatches - - -import examples.mlp.policy.gallery as gallery -from examples.utils import get_policy - -import argparse - -parser = argparse.ArgumentParser(description='MLP example') -parser.add_argument('--policy', type=str, help='policy choice, starting with "PAS"') -parser.add_argument('--dim', type=int, default=1024, help='model hidden size') -parser.add_argument('--layers', type=int, default=16, help='number of linear layers') -parser.add_argument('--gbs', type=int, default=64, help='global batch size') -parser.add_argument('--mbs', type=int, default=64, help='micro batch size') -parser.add_argument('--tp-size', type=int, default=2, help='tensor parallelism size only for Megatron policy') -args = parser.parse_args() - -nnscaler.init() - -# get policy -policy = get_policy([gallery], args.policy) -policy = partial(policy, nmicros=args.gbs//args.mbs, tp_size=args.tp_size) - -# =================== Semantic Model Description ==================== - -class MLP(nn.Module): - def __init__(self, dim: int, nlayers: int): - super().__init__() - self.layers = torch.nn.ModuleList([]) - for _ in range(nlayers): - self.layers.append(nn.Linear(dim, dim, bias=False)) - - def forward(self, data): - x = data - for layer in self.layers: - x = layer(x) - loss = torch.sum(x) - return loss - -def dummy_data(): - return torch.randn( - args.mbs, args.dim, device=torch.cuda.current_device()) - - -def train(): - - model = MLP(dim=args.dim, nlayers=args.layers) - # create dummy data - dataloader = microbatches((dummy_data(),)) - - # compile a training iteration - @compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - data = next(dataloader) - loss = model(data) - loss.backward() - # load generated model - model = nnscaler.utils.load_model() - - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - - CudaTimer(enable=False).warmup() - iter_num, warmup = 5, 2 - for step in range(iter_num): - if step == warmup: - CudaTimer(enable=True).start('e2e') - - # get data samples - samples = [dummy_data() for _ in range(args.gbs // args.mbs)] - dataloader = microbatches(samples) - # run training iteration - train_iter(model, dataloader) - - optimizer.step() - optimizer.zero_grad() - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - CudaTimer().stop('e2e') - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) - - -if __name__ == '__main__': - train() \ No newline at end of file diff --git a/examples/nlp/__init__.py b/examples/nlp/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/nlp/blocks/__init__.py b/examples/nlp/blocks/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/nlp/blocks/attention.py b/examples/nlp/blocks/attention.py deleted file mode 100644 index 913363ec..00000000 --- a/examples/nlp/blocks/attention.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch -import nnscaler - - -@nnscaler.register_op('L^ N E^, (h+ d^ 3) E^, (h+ d^ 3), E^ (h+ d^) -> L^ N E^', name='self_attention') -def self_attention(query: torch.Tensor, - qkv_proj: torch.Tensor, qkv_bias: torch.Tensor, - out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = False): - num_head = h - L, N = query.size(0), query.size(1) - dim_head = qkv_proj.size(0) // num_head // 3 - - qkv = torch.nn.functional.linear(query, qkv_proj, qkv_bias) # L N E, (h d 3) E -> L N (h d 3) - qkv = qkv.view(L, N, num_head * dim_head, 3) # L N (h d 3) -> L N (h d) 3 - q, k, v = qkv.chunk(3, dim=-1) # L N (3 h d) -> L N (h d), L N (h d), L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - - # ======== replace the semantic into more efficient implementation ============ - # q = q.transpose(0, 1) # L (N h) d -> (N h) L d - # k = k.transpose(0, 1) # L (N h) d -> (N h) L d - # q = q * scale # (N h) L d, 1 -> (N h) L d - # k = k.transpose(1, 2) # (N h) L d -> (N h) d L - # attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # preallocating input tensor: (N h) L L - matmul_input_buffer = torch.empty([N * h, L, L], dtype=query.dtype, device=query.device) - # L (N h) d, L (N h) d -> (N h) L L - attn = torch.baddbmm( - matmul_input_buffer, - q.transpose(0, 1), # (N h) L d - k.transpose(0, 1).transpose(1, 2), # (N h) d L - beta=0.0, alpha=scale - ) - # ======== replace the semantic into more efficient implementation ============ - - # attention mask - if mask: # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - amask = torch.tril(ones) - amask = amask.view(N, 1, L, L) - amask = (amask < 0.5) - attn = attn.masked_fill_(amask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj) # L N (h d), E E -> L N E - return output - - -@nnscaler.register_op('L^ N E^, L^ N E^, (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), (h+ d) E^, (h+ d), E^ (h+ d) -> L^ N E^', name='cross_attention') -def cross_attention(query: torch.Tensor, key: torch.Tensor, - q_proj: torch.Tensor, q_bias: torch.Tensor, - k_proj: torch.Tensor, k_bias: torch.Tensor, - v_proj: torch.Tensor, v_bias: torch.Tensor, - out_proj: torch.Tensor, - h: int, scale: float, dropout_p: float, mask: bool = False): - num_head = h - L, N = query.size(0), query.size(1) - dim_head = q_proj.size(0) // num_head - - q = torch.nn.functional.linear(query, q_proj, q_bias) # L N E, (h d) E -> L N (h d) - k = torch.nn.functional.linear(key, k_proj, k_bias) # L N E, (h d) E -> L N (h d) - v = torch.nn.functional.linear(key, v_proj, v_bias) # L N E, (h d) E -> L N (h d) - q = q.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - k = k.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - v = v.contiguous().view(L, (N * num_head), dim_head) # L N (h d) -> L (N h) d - q = q.transpose(0, 1) # L (N h) d -> (N h) L d - k = k.transpose(0, 1) # L (N h) d -> (N h) L d - v = v.transpose(0, 1) # L (N h) d -> (N h) L d - q = q * scale # (N h) L d, 1 -> (N h) L d - k = k.transpose(1, 2) # (N h) L d -> (N h) d L - attn = torch.bmm(q, k) # (N h) L d, (N h) d L -> (N h) L L - - # attention mask - if mask: # (N h) L L -> (N h) L L - attn = attn.view(N, num_head, L, L) - ones = torch.ones((N, L, L), device=attn.device) - amask = torch.tril(ones) - amask = amask.view(N, 1, L, L) - amask = (amask < 0.5) - attn = attn.masked_fill_(amask, -10000.0) - attn = attn.view((N * num_head), L, L) - - attn = torch.nn.functional.softmax(attn, dim=-1) # (N h) L L -> (N h) L L - attn = torch.nn.functional.dropout(attn, dropout_p, True, False) # (N h) L L -> (N h) L L - output = torch.bmm(attn, v) # (N h) L L, (N h) L d -> (N h) L d - output = output.transpose(0, 1).contiguous() # (N h) L d -> L (N h) d - output = output.view(L, N, num_head * dim_head) # (N h) L d -> L N (h d) - output = torch.nn.functional.linear(output, out_proj, None) # L N (h d), E E -> L N E - return output - - -class MultiHeadSelfAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): - super().__init__() - self.inner_dim = inner_dim - self.num_heads = num_heads - self.head_dim = inner_dim // num_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # QKV [(h d 3), E] - self.qkv_proj = torch.nn.Parameter(torch.empty(3 * inner_dim, embed_dim)) - self.qkv_bias = torch.nn.Parameter(torch.empty(3 * inner_dim)) - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) - - def forward(self, query): - attn = self_attention( - query, self.qkv_proj, self.qkv_bias, - self.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=False - ) - attn = attn + self.out_bias - return attn - - -class MultiHeadCrossAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, inner_dim: int, dropout: float = 0.0): - super().__init__() - self.inner_dim = inner_dim - self.num_heads = num_heads - self.head_dim = inner_dim // num_heads - self.scaling = self.head_dim ** -0.5 - self.dropout_p = dropout - # Q - self.q_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.q_bias = torch.nn.Parameter(torch.empty(inner_dim)) - # K - self.k_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.k_bias = torch.nn.Parameter(torch.empty(inner_dim)) - # V - self.v_proj = torch.nn.Parameter(torch.empty(inner_dim, embed_dim)) - self.v_bias = torch.nn.Parameter(torch.empty(inner_dim)) - # Out - self.out_proj = torch.nn.Parameter(torch.empty(embed_dim, inner_dim)) - self.out_bias = torch.nn.Parameter(torch.empty(embed_dim)) - - def forward(self, query: torch.Tensor, key: torch.Tensor): - attn = cross_attention( - query, key, - self.q_proj, self.q_bias, - self.k_proj, self.k_bias, - self.v_proj, self.v_bias, - self.out_proj, - self.num_heads, self.scaling, self.dropout_p, mask=True - ) - attn = attn + self.out_bias - return attn diff --git a/examples/nlp/blocks/mlp.py b/examples/nlp/blocks/mlp.py deleted file mode 100644 index 4a5f3e7c..00000000 --- a/examples/nlp/blocks/mlp.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -import nnscaler - - -@nnscaler.register_op('L^ N E^, H+ E^, H+, E^ H+ -> L^ N E^', name='feedforward') -def feedforward(x: torch.Tensor, - proj1: torch.Tensor, proj1_bias: torch.Tensor, - proj2: torch.Tensor, - dropout: float, - is_training: bool = True) -> torch.Tensor: - x = torch.nn.functional.linear(x, proj1, proj1_bias) - x = torch.nn.functional.gelu(x) - x = torch.nn.functional.dropout(x, dropout, is_training, False) - x = torch.nn.functional.linear(x, proj2, None) - return x - - -class MLP(torch.nn.Module): - - def __init__(self, embed_dim: int, hidden_dim: int, dropout: float): - super().__init__() - self.proj1 = torch.nn.Parameter(torch.empty((hidden_dim, embed_dim))) - self.proj1_bias = torch.nn.Parameter(torch.empty((hidden_dim,))) - self.proj2 = torch.nn.Parameter(torch.empty((embed_dim, hidden_dim))) - self.proj2_bias = torch.nn.Parameter(torch.empty((embed_dim,))) - self.dropout = dropout - - def forward(self, x: torch.Tensor): - x = feedforward(x, self.proj1, self.proj1_bias, - self.proj2, self.dropout, self.training) - x = x + self.proj2_bias - return x diff --git a/examples/nlp/blocks/transformer.py b/examples/nlp/blocks/transformer.py deleted file mode 100644 index 069a370e..00000000 --- a/examples/nlp/blocks/transformer.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -from examples.nlp.blocks.attention import MultiHeadSelfAttention -from examples.nlp.blocks.attention import MultiHeadCrossAttention -from examples.nlp.blocks.mlp import MLP - - -class TransformerLayer(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, - attn_hidden_dim: int, ffn_hidden_dim: int, - dropout: float = 0.2, atten_dropout: float = 0.2, activation_dropout: float = 0.2, - use_cross_attention: bool = False): - super().__init__() - self.self_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.self_attn = MultiHeadSelfAttention( - embed_dim, num_heads, attn_hidden_dim, atten_dropout - ) - - self.use_cross_attention = use_cross_attention - if use_cross_attention: - self.cross_attn_layer_norm = torch.nn.LayerNorm(embed_dim) - self.cross_attn = MultiHeadCrossAttention( - embed_dim, num_heads, attn_hidden_dim, atten_dropout - ) - - self.dropout = torch.nn.Dropout(p=dropout) - self.mlp = MLP(embed_dim, ffn_hidden_dim, activation_dropout) - self.final_layer_norm = torch.nn.LayerNorm(embed_dim) - - def forward(self, x: torch.Tensor, encoder_output = None) -> torch.Tensor: - # self attention - residual = x - x = self.self_attn_layer_norm(x) - x = self.self_attn(x) - x = self.dropout(x) - x = x + residual - - # cross attention - if self.use_cross_attention: - residual = x - x = self.cross_attn_layer_norm(x) - x = self.cross_attn(x, encoder_output) - x = self.dropout(x) - x = x + residual - - # mlp - residual = x - x = self.final_layer_norm(x) - x = self.mlp(x) - x = self.dropout(x) - x = x + residual - - return x diff --git a/examples/nlp/gpt/__init__.py b/examples/nlp/gpt/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/nlp/gpt/model.py b/examples/nlp/gpt/model.py deleted file mode 100644 index 6fca2959..00000000 --- a/examples/nlp/gpt/model.py +++ /dev/null @@ -1,103 +0,0 @@ -import torch -from dataclasses import dataclass - -import nnscaler - -from examples.nlp.blocks.transformer import TransformerLayer - - -@dataclass -class Config: - hidden: int = 1024 - layers: int = 8 - heads: int = 16 - ffn_hidden_dim: int = 4096 - num_embeddings: int = 51200 - seqlen: int = 1024 - dropout: float = 0.2 - attn_dropout: float = 0.2 - activation_dropout: float = 0.2 - - -def build_gpt_config(name: str) -> Config: - if name == 'toy': - hidden, layers, heads = 1024, 4, 16 - elif name == '350M': - hidden, layers, heads = 1024, 24, 16 - elif name == '760M': - hidden, layers, heads = 1536, 24, 16 - elif name == '1.3B': - hidden, layers, heads = 2048, 24, 32 - elif name == '2.6B': - hidden, layers, heads = 2560, 32, 32 - elif name == '6.7B': - hidden, layers, heads = 4096, 32, 32 - elif name == '15B': - hidden, layers, heads = 5120, 48, 40 - elif name == '39B': - hidden, layers, heads = 8192, 48, 64 - elif name == '175B': - hidden, layers, heads = 12288, 96, 96 - else: - assert False, f'unrecognized name: {name}' - return Config(hidden, layers, heads, hidden, 4 * hidden) - - -class GPT(torch.nn.Module): - - def __init__(self, cfg: Config): - super().__init__() - - # self.embed = torch.nn.Embedding(cfg.num_embeddings, cfg.hidden) - self.embedw = torch.nn.Parameter(torch.empty(cfg.num_embeddings, cfg.hidden)) - self.position = torch.nn.Embedding(cfg.seqlen, cfg.hidden) - self.embed_dropout = torch.nn.Dropout() - - self.layers = torch.nn.ModuleList( - [TransformerLayer( - cfg.hidden, cfg.heads, - cfg.hidden, cfg.ffn_hidden_dim, - cfg.dropout, cfg.attn_dropout, cfg.activation_dropout, - use_cross_attention=False, - ) for _ in range(cfg.layers)] - ) - self.final_layernorm = torch.nn.LayerNorm(cfg.hidden) - - def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): - - # embed = self.embed(input_ids) - embed = torch.nn.functional.embedding( - input_ids, self.embedw, padding_idx=None, - max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False - ) - pos_embed = self.position(position_ids) - embed = embed + pos_embed - embed = self.embed_dropout(embed) - enc = embed.transpose(0, 1) - - for layer in self.layers: - nnscaler.runtime.function.anchor('transformer start') - enc = layer(enc) - enc = self.final_layernorm(enc) - - # logits = torch.nn.functional.linear(enc, self.embed.weight) - logits = torch.nn.functional.linear(enc, self.embedw) - # simplified - loss = torch.sum(logits) - return loss - - -def dummy_data(batch_size: int, cfg: Config): - - input_ids = torch.randint( - 0, cfg.num_embeddings, - size=(batch_size, cfg.seqlen,), - dtype=torch.int64, - device=torch.cuda.current_device() - ) - position_ids = torch.arange( - 0, cfg.seqlen, dtype=torch.int64, - device=torch.cuda.current_device() - ).repeat(batch_size, 1).view(batch_size, cfg.seqlen,) - - return input_ids, position_ids diff --git a/examples/nlp/gpt/policy/__init__.py b/examples/nlp/gpt/policy/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/nlp/gpt/policy/mpmd.py b/examples/nlp/gpt/policy/mpmd.py deleted file mode 100644 index ec6b1041..00000000 --- a/examples/nlp/gpt/policy/mpmd.py +++ /dev/null @@ -1,126 +0,0 @@ -"""GPT policy gallery for MPMD Parallelism""" -from nnscaler.graph import IRGraph -from nnscaler.graph.segment import IRSegment -from nnscaler.ir.operator import IRDataOperation, IRFwOperation -from nnscaler.graph.schedule.predefined import PredefinedSched - -from examples.utils import create_mesh, tensor_parallelism, replica, group_to_layers - - -def PASRoundRobin(graph: IRGraph, resource, **kwargs): - """ - roundrobin scheduling - """ - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - - # group to transformer layers - transformers = group_to_layers(fnodes) - - for lid, transformer in enumerate(transformers): - stage_id = lid % resource.ngpus - print(f'assigning {lid} transformer to stage {stage_id}') - for node in transformer: - graph.assign(node, stage_id) - - for node in graph.nodes(): - if len(node.device) == 0: - graph.assign(node, 0) - - return graph - - -def PAS1F1B(graph: IRGraph, resource, nmicros: int = 16, **kwargs): - """1F1B schedule""" - num_stages = resource.ngpus - num_microbatch = nmicros - - # group to transformer layers - fnodes = [node for node in graph.nodes() if isinstance(node, IRFwOperation)] - transformers = group_to_layers(fnodes) - assert len(transformers) >= num_stages - - # staging - fstages = [[] for _ in range(num_stages)] - nlayer_per_stage = (len(transformers) // resource.ngpus) - for lid, fnodes in enumerate(transformers): - stage_id = min(lid // nlayer_per_stage, num_stages - 1) - fstages[stage_id] += fnodes - graph.staging(tuple(stages[0] for stages in fstages)) - - # stage to device - fsegments = [seg for seg in graph.nodes() if isinstance(seg, IRSegment) and seg.isfw()] - assert len(fsegments) == num_stages - for devid, segment in enumerate(fsegments): - graph.assign(segment, devid) - - for node in graph.nodes(): - if isinstance(node, IRDataOperation): - graph.assign(node, 0) - - if graph.train: - PredefinedSched.sched_1f1b(graph, num_microbatch, num_stages) - else: - PredefinedSched.sched_infer_pipe(graph, num_microbatch, num_stages) - return graph - - -def PASMegatron(graph: IRGraph, resource, - tp_size: int = 2, dp_size: int = 1, - nmicros: int = 16, **kwargs ): - """Megatron policy for hybrid data-tensor-pipeline parallelism""" - pp_size = resource.ngpus // (dp_size * tp_size) - num_microbatch = nmicros - - # device mesh - dp_groups, pp_groups, tp_groups = \ - create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) - print(f'dp groups: {dp_groups}') - print(f'pp groups: {pp_groups}') - print(f'tp groups: {tp_groups}') - - def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: - return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] - - # group to transformer layers - transformers = group_to_layers(graph.select(ntype=IRFwOperation)) - - # group to stage: set each stage operators - fstages = [[] for _ in range(pp_size)] - nlayer_per_stage = (len(transformers) // pp_size) - for lid, fnodes in enumerate(transformers): - stage_id = min(lid // nlayer_per_stage, pp_size - 1) - fstages[stage_id] += fnodes - graph.staging(tuple(stages[0] for stages in fstages)) - - dataloader = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[0] - - # partition dataloader - dls = replica(graph, dataloader, [0]*dp_size) # graph.partition(dataloader, dataloader.algorithms('data'), num=dp_size) - for dp_idx, dl in enumerate(dls): - # only stage 0 needs dataloader - devices = [get_device(dp_idx, 0, tp_idx) for tp_idx in range(tp_size)] - replica(graph, dl, devices) - - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - assert len(fstages) > 0 - for pp_idx, fstage in enumerate(fstages): - for fnode in fstage.nodes(): - if len(fnode.inputs()) == 0: continue # anchor - if fnode.name == 'self_attention' or fnode.name == 'feedforward': - fnodes = tensor_parallelism(graph, fnode, idx=1, dim=0, devs=[0]*tp_size) - elif fnode.name == 'embedding': - fnodes = tensor_parallelism(graph, fnode, idx=1, dim=0, devs=[0]*tp_size) - elif fnode.name == 'linear': # the last embeding linear - fnodes = tensor_parallelism(graph, fnode, idx=1, dim=0, devs=[0]*tp_size) - elif fnode.name == 'sum': - fnodes = tensor_parallelism(graph, fnode, idx=0, dim=2, devs=[0]*tp_size) - else: - fnodes = replica(graph, fnode, [0]*tp_size) - # data parallel - for tp_idx, fnode in enumerate(fnodes): - dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] - batch_dim = fnode.input(0).shape.index(bs) - tensor_parallelism(graph, fnode, idx=0, dim=batch_dim, devs=dp_devices) - PredefinedSched.sched_1f1b(graph, num_microbatch, pp_size) - return graph diff --git a/examples/nlp/gpt/policy/spmd.py b/examples/nlp/gpt/policy/spmd.py deleted file mode 100644 index 1c6da6db..00000000 --- a/examples/nlp/gpt/policy/spmd.py +++ /dev/null @@ -1,93 +0,0 @@ -"""GPT policy gallery for MPMD Parallelism""" - -from typing import List - -from nnscaler.graph import IRGraph -from nnscaler.graph.function.pyfunc import IRPyFunc -from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation - -from examples.utils import tensor_parallelism, replica - - -# coshard -def coshard(graph: IRGraph, node: IRFwOperation, devs: List[int], colocate: int, - idx: int, dim: int): - algo = node.algorithms('dim') - sub_nodes = graph.partition(node, algo, idx=idx, dim=dim, num=colocate*len(devs)) - assert sub_nodes is not None - graph.recompute(sub_nodes) - for devid in devs: - for coid in range(colocate): - sub_node = sub_nodes[devid * colocate + coid] - graph.assign(sub_node, devid) - return sub_nodes - - -def PASSingle(graph: IRGraph, resource, **kwargs): - """Single-device execution""" - assert resource.ngpus == 1 - for node in graph.nodes(): - if not isinstance(node, IRBpOperation): - graph.assign(node, 0) - return graph - - -def PASDP(graph: IRGraph, resource, **kwargs): - """Data parallelism""" - devs = list(range(resource.ngpus)) - dataloader = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[0] - # replicate dataloader - replica(graph, dataloader, devs) - # partition forward operators - for node in graph.select(ntype=IRFwOperation): - if isinstance(node, IRPyFunc): - graph.assign(node, 0) - continue - if len(node.inputs()) == 0: continue - batch_dim = node.input(0).shape.index(bs) - tensor_parallelism(graph, node, idx=0, dim=batch_dim, devs=devs) - return graph - - -def PASMegatronTP(graph: IRGraph, resource, **kwargs): - """Megatron-way tensor parallelism""" - devs = list(range(resource.ngpus)) - # attention - for attn in graph.select(name='self_attention'): - tensor_parallelism(graph, attn, idx=1, dim=0, devs=devs) - # feedforward - for ffn in graph.select(name='feedforward'): - tensor_parallelism(graph, ffn, idx=1, dim=0, devs=devs) - # partition embed - for embed in graph.select(name='embedding'): - tensor_parallelism(graph, embed, idx=1, dim=0, devs=devs) - # partition last linear - linears = graph.select(name='linear') - tensor_parallelism(graph, linears[-1], idx=1, dim=0, devs=devs) - # partition loss - sums = graph.select(name='sum') - tensor_parallelism(graph, sums[0], idx=0, dim=2, devs=devs) - # replica other nodes - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - replica(graph, node, devs) - return graph - - -def PASMeshShard(graph: IRGraph, resource, **kwargs): - """Coshard policy for long sequence""" - devs = list(range(resource.ngpus)) - # attention - for attn in graph.select(name='self_attention'): - # tensor_parallelism(graph, attn, idx=1, dim=0, devs) - coshard(graph, attn, devs, colocate=2, idx=1, dim=0) - # feedforward - for ffn in graph.select(name='feedforward'): - # tensor_parallelism(graph, ffn, idx=1, dim=0, devs) - coshard(graph, ffn, devs, colocate=4, idx=1, dim=0) - # replica other nodes - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - if len(node.device) == 0: - replica(graph, node, devs) - return graph diff --git a/examples/nlp/gpt/train.py b/examples/nlp/gpt/train.py deleted file mode 100644 index 1982e135..00000000 --- a/examples/nlp/gpt/train.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -example: - -PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - examples/nlp/gpt/train.py --policy PASMegatronTP --fp16 -""" - - -import torch -import logging -from functools import partial - -from model import GPT, Config, dummy_data - -import nnscaler -from nnscaler.compiler import compile -from nnscaler.utils import set_default_logger_level -from nnscaler.profiler.timer import CudaTimer, print_each_rank -from nnscaler.profiler.memory import memory_summary -from nnscaler.runtime.utils import microbatches - -import examples.nlp.gpt.policy.spmd as spmd -import examples.nlp.gpt.policy.mpmd as mpmd - -from examples.utils import get_policy - -import argparse - -parser = argparse.ArgumentParser(description='GPT Train') - -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') -parser.add_argument('--fp16', action='store_true', default=False, - help='use fp16 for the training') -parser.add_argument('--mbs', type=int, default=8, - help='micro-batch size') -parser.add_argument('--gbs', type=int, default=8, - help='global batch size') -parser.add_argument('--dp', type=int, default=1, - help='data parallel size, only for megatron') -parser.add_argument('--tp', type=int, default=1, - help='tensor parallel size, only for megatron') - -# arch -parser.add_argument('--layers', type=int, default=4, - help='number of transformer layers') -parser.add_argument('--hidden', type=int, default=1024, - help='hidden size') -parser.add_argument('--heads', type=int, default=16, - help='number of attention heads') -parser.add_argument('--seqlen', type=int, default=1024, - help='sequence length') -args = parser.parse_args() - - -nnscaler.init() -set_default_logger_level(logging.WARN) -logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) - -# get policy -policy = get_policy([spmd, mpmd], args.policy) -policy = partial(policy, - nmicros=args.gbs//args.mbs, - dp_size=args.dp, - tp_size=args.tp -) - - -def train(): - - config = Config( - hidden=args.hidden, - layers=args.layers, - heads=args.heads, - ffn_hidden_dim=4*args.hidden, - num_embeddings=51200, - seqlen=args.seqlen, - ) - model = GPT(config) - model = model if not args.fp16 else model.half() - - gen_data = partial(dummy_data, args.mbs, config) - dataloader = microbatches((gen_data(),), cycle=True) - - @compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - input_ids, position_ids = next(dataloader) - loss = model(input_ids, position_ids) - loss.backward() - model = nnscaler.utils.load_model() - - optimizer = torch.optim.Adam(model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - torch.distributed.barrier() - print_each_rank('model weight consumpition:', rank_only=0) - memory_summary() - - CudaTimer().warmup() - iter_num, warmup = 5, 2 - for step in range(iter_num): - if step == warmup: - CudaTimer(enable=True).start('e2e') - - # collect dummy data - samples = [gen_data() for _ in range(args.gbs // args.mbs)] - dataloader = microbatches(samples) - - # train - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - - if step == 0: - print_each_rank('passed first iteration') - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - CudaTimer().stop('e2e') - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) - - memory_summary() - - -if __name__ == '__main__': - - nnscaler.init() - train() \ No newline at end of file diff --git a/examples/nlp/mbart/__init__.py b/examples/nlp/mbart/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/nlp/mbart/model.py b/examples/nlp/mbart/model.py deleted file mode 100644 index 5a9e6675..00000000 --- a/examples/nlp/mbart/model.py +++ /dev/null @@ -1,183 +0,0 @@ -import torch -import math -from dataclasses import dataclass - -from examples.nlp.blocks.transformer import TransformerLayer - -import nnscaler - - -@dataclass -class Config: - - hidden: int = 1024 - heads: int = 16 - layers: int = 4 # for encoder and decoder layers separately - seqlen: int = 2048 - ffn_hidden_dim: int = 4096 - vocab: int = 2500 - - attention_dropout: float = 0.2 - dropout: float = 0.2 - activation_dropout: float = 0.2 - - pad_token_id: int = 1 - eos_token_id: int = 1 - num_classes: int = 3 - - -class PositionalEmbedding(torch.nn.Embedding): - - def __init__(self, vocab: int, embedding_dim: int): - self.offset = 2 - super().__init__(vocab + self.offset, embedding_dim) - - def forward(self, seq_len: int): - positions = torch.arange( - 0, seq_len, dtype=torch.long, device=torch.cuda.current_device() - ) - return super().forward(positions + self.offset) - - -class MBartClassificationHead(torch.nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - - self.num_classes = num_classes - self.dense = torch.nn.Linear(input_dim, inner_dim) - self.dropout = torch.nn.Dropout(p=pooler_dropout) - self.out_proj = torch.nn.Linear(inner_dim, num_classes) - self.loss_fct = torch.nn.CrossEntropyLoss() - - # def forward(self, dec: torch.Tensor, labels): - def forward(self, dec: torch.Tensor): - # sentence_represent = dec[eos_mask,:].view(dec.size(0), -1, hidden_states.size(-1))[:,-1,:] - dec = torch.select(dec, dim=1, index=-1) - # dec = dec[:,-1,:] - sentence_represent = dec - hidden_states = self.dropout(sentence_represent) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - logits = self.out_proj(hidden_states) - # loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) - loss = logits.sum() - return loss - - -class MBartForSentenceClassification(torch.nn.Module): - - def __init__(self, batch_size: int, cfg: Config): - super().__init__() - self.vocab_size = cfg.vocab - # embedding - self.vocab = torch.nn.Parameter(torch.empty( - cfg.vocab, cfg.hidden)) - # encoder embedding - self.embed_offset = 2 - self.encoder_position = torch.nn.Parameter(torch.empty( - cfg.seqlen, cfg.hidden)) - self.embed_scale_encoder = math.sqrt(cfg.hidden) - self.layernorm_embedding_encoder = torch.nn.LayerNorm(cfg.hidden) - - # encoder layers - self.encoders = torch.nn.ModuleList( - [TransformerLayer( - cfg.hidden, cfg.heads, - cfg.hidden, cfg.ffn_hidden_dim, - cfg.dropout, cfg.attention_dropout, cfg.activation_dropout, - use_cross_attention=False, - ) for _ in range(cfg.layers)] - ) - self.layer_norm_encoder = torch.nn.LayerNorm(cfg.hidden) - - # decoder embedding - self.decoder_position = torch.nn.Parameter(torch.empty( - cfg.seqlen, cfg.hidden)) - self.embed_scale_decoder = math.sqrt(cfg.hidden) - self.layernorm_embedding_decoder = torch.nn.LayerNorm(cfg.hidden) - - # decoder layers - self.decoders = torch.nn.ModuleList( - [TransformerLayer( - cfg.hidden, cfg.heads, - cfg.hidden, cfg.ffn_hidden_dim, - cfg.dropout, cfg.attention_dropout, cfg.activation_dropout, - use_cross_attention=True, - ) for _ in range(cfg.layers)] - ) - self.layer_norm_decoder = torch.nn.LayerNorm(cfg.hidden) - self.head = MBartClassificationHead(cfg.hidden, 1024, cfg.num_classes, 0.0) - - def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor): - """ - The forward is only for benchmark performance, - the original input of input_ids, decoder_input_ids and labels are - simplied by using only ine input_ids. - - The loss computation is also simplified by using sum. - """ - # encoder embedding - nnscaler.runtime.function.anchor('encoder embedding') - enc_emb = torch.nn.functional.embedding(input_ids, self.vocab) - enc_emb = enc_emb * self.embed_scale_encoder - enc_emb = enc_emb + self.encoder_position - enc_emb = self.layernorm_embedding_encoder(enc_emb) - enc_emb = torch.nn.functional.dropout(enc_emb, p=0.1) - enc = enc_emb.transpose(0, 1) - - # encoder layers - for layer in self.encoders: - nnscaler.runtime.function.anchor('encoder layer') - enc = layer(enc) - enc = self.layer_norm_encoder(enc) - - # decoder embedding - nnscaler.runtime.function.anchor('decoder embedding') - dec_emb = torch.nn.functional.embedding(decoder_input_ids, self.vocab) - dec_emb = dec_emb * self.embed_scale_decoder - dec_emb = dec_emb + self.decoder_position - dec_emb = self.layernorm_embedding_decoder(dec_emb) - dec_emb = torch.nn.functional.dropout(dec_emb, p=0.1) - dec = dec_emb.transpose(0, 1) - - # decoder layers - for layer in self.decoders: - nnscaler.runtime.function.anchor('decoder layer') - dec = layer(dec, enc) - - dec = self.layer_norm_decoder(dec) - dec = dec.transpose(0, 1) - - # head - loss = self.head(dec) - return loss - - -def dummy_data(batch_size: int, config: Config): - - input_ids = torch.randint( - 0, config.vocab, - size=(batch_size, config.seqlen,), - dtype=torch.int64, device=torch.cuda.current_device() - ) - decoder_input_ids = torch.randint( - 0, config.vocab, - size=(batch_size, config.seqlen,), - dtype=torch.int64, device=torch.cuda.current_device() - ) - labels = torch.randint( - 0, config.num_classes, - size=(batch_size, ), # scalar - dtype=torch.int64, - device=torch.cuda.current_device() - ) - return input_ids, decoder_input_ids diff --git a/examples/nlp/mbart/policy/__init__.py b/examples/nlp/mbart/policy/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/nlp/mbart/policy/gallery.py b/examples/nlp/mbart/policy/gallery.py deleted file mode 100644 index 0e675958..00000000 --- a/examples/nlp/mbart/policy/gallery.py +++ /dev/null @@ -1,186 +0,0 @@ -from typing import List - -from nnscaler.graph import IRGraph -from nnscaler.ir.operator import IRFwOperation, IRDataOperation -from nnscaler.graph.function.anchor import IRGraphAnchor -from nnscaler.graph.schedule.predefined import PredefinedSched -from nnscaler.graph.segment import IRSegment -from nnscaler.ir.cten import IRCell - -from examples.utils import create_mesh, tensor_parallelism, replica - - -def _group_to_blocks(fnodes) -> List[List[IRCell]]: - """ - Grouping to [ - [Encoder Embed], - [Encoder Layer], [Encoder Layer], ..., - [Decoder Embed], - [Decoder Layer], [Decoder Layer], ... - ] - """ - blocks = [] - anchors = [node for node in fnodes if isinstance(node, IRGraphAnchor)] - indices = [fnodes.index(anchor) for anchor in anchors] - # encoder embedding - fnodes[indices[0] + 1].comment = f'==> start of encoder embedding' - assert anchors[0].name == 'encoder embedding' - blocks.append(fnodes[0:indices[1]]) - indices.pop(0) - anchors.pop(0) - # encoder layers - lid = 0 - while anchors[0].name == 'encoder layer': - start, end = indices[0], indices[1] - fnodes[start + 1].comment = f'==> start of encoder layer {lid}' - blocks.append(fnodes[start:end]) - indices.pop(0) - anchors.pop(0) - lid += 1 - # decoder embedding - assert anchors[0].name == 'decoder embedding' - blocks.append(fnodes[indices[0]:indices[1]]) - indices.pop(0) - anchors.pop(0) - # decoder layers - lid = 0 - while len(indices) != 0: - assert anchors[0].name == 'decoder layer' - start, end = indices[0], indices[1] if len(indices) > 1 else len(fnodes) - fnodes[start + 1].comment = f'==> start of decoder layer {lid}' - blocks.append(fnodes[indices[0]:end]) - indices.pop(0) - anchors.pop(0) - lid += 1 - return blocks - - - -def PASSingle(graph: IRGraph, resource, **kwargs): - assert resource.ngpus == 1 - _ = _group_to_blocks(graph.select(ntype=IRFwOperation)) - for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): - graph.assign(node, 0) - return graph - - -def PAS1F1B(graph: IRGraph, resource, nmicros: int = 16, **kwargs): - - num_stages = resource.ngpus - recompute: bool = True - - blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) - enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] - dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] - if recompute: - for block in blocks: - graph.recompute(block) - - # staging - fstages = [[] for _ in range(num_stages)] - nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // num_stages - for lid, fnodes in enumerate(enc_layers + dec_layers): - if lid == 0: - fstages[0] += enc_emb - elif lid == len(enc_layers): - fstages[num_stages // 2] += dec_emb - stage_id = min(lid // nlayer_per_stage, num_stages - 1) - fstages[stage_id] += fnodes - graph.staging(tuple(stage[0] for stage in fstages)) - - dataloader = graph.select(ntype=IRDataOperation)[0] - replica(graph, dataloader, [0, num_stages // 2]) - - fsegments = [seg for seg in graph.select(ntype=IRSegment, flatten=False) if seg.isfw()] - assert len(fsegments) == num_stages, f"Not match: {len(fsegments)} != {num_stages}" - for devid, segment in enumerate(fsegments): - graph.assign(segment, devid) - - strategy = PredefinedSched(graph, nmicros, num_stages) - graph.predef_sched(strategy) - - return graph - - -def PASMegatronTP(graph: IRGraph, resource, **kwargs): - """Megatron-way tensor parallelism""" - devs = list(range(resource.ngpus)) - for node in graph.select(ntype=(IRDataOperation, IRFwOperation)): - if node.name == 'embedding': - tensor_parallelism(graph, node, idx=1, dim=0, devs=devs) - elif node.name == 'self_attention' or node.name == 'feedforward': - tensor_parallelism(graph, node, idx=1, dim=0, devs=devs) - elif node.name == 'cross_attention': - tensor_parallelism(graph, node, idx=2, dim=0, devs=devs) - else: - replica(graph, node, devs) - return graph - - -def PASMegatron(graph: IRGraph, resource, - tp_size: int = 2, dp_size: int = 1, - nmicros: int = 16, **kwargs): - """Megatron policy for hybrid data-tensor-pipeline parallelism""" - dp_size = 2 - tp_size = 2 - pp_size = resource.ngpus // (dp_size * tp_size) - recompute: bool = True - num_microbatch = nmicros - - # device mesh - dp_groups, pp_groups, tp_groups = \ - create_mesh(resource.ngpus, (dp_size, pp_size, tp_size)) - print(f'dp groups: {dp_groups}') - print(f'pp groups: {pp_groups}') - print(f'tp groups: {tp_groups}') - - def get_device(dp_idx: int, pp_idx: int, tp_idx: int, ) -> int: - return tp_groups[dp_idx * pp_size + pp_idx][tp_idx] - - blocks = _group_to_blocks(graph.select(ntype=IRFwOperation)) - enc_emb, enc_layers = blocks[0], blocks[1:len(blocks)//2] - dec_emb, dec_layers = blocks[len(blocks)//2], blocks[len(blocks)//2+1:] - if recompute: - for block in blocks: - graph.recompute(block) - - # pipelien stage - fstages = [[] for _ in range(pp_size)] - nlayer_per_stage = (len(enc_layers) + len(dec_layers)) // pp_size - for lid, fnodes in enumerate(enc_layers + dec_layers): - if lid == 0: - fstages[0] += enc_emb - elif lid == len(enc_layers): - fstages[pp_size // 2] += dec_emb - stage_id = min(lid // nlayer_per_stage, pp_size - 1) - fstages[stage_id] += fnodes - graph.staging(tuple(stage[0] for stage in fstages)) - - # partition dataloader - dataloader = graph.select(ntype=IRDataOperation)[0] - bs = dataloader.output(0).shape[0] - replica(graph, dataloader, list(range(resource.ngpus))) - - # tp-dp partition - fstages = [stage for stage in graph.select(ntype=IRSegment, flatten=False) if stage.isfw()] - assert len(fstages) == pp_size - for pp_idx, fstage in enumerate(fstages): - for node in fstage.nodes(): - if len(node.inputs()) == 0: continue # anchor - if node.name == 'embedding': - nodes = tensor_parallelism(graph, node, idx=1, dim=0, devs=[0]*tp_size) - elif node.name == 'self_attention' or node.name == 'feedforward': - nodes = tensor_parallelism(graph, node, idx=1, dim=0, devs=[0]*tp_size) - elif node.name == 'cross_attention': - nodes = tensor_parallelism(graph, node, idx=2, dim=0, devs=[0]*tp_size) - else: - nodes = replica(graph, node, [0]*tp_size) - # data parallel - for tp_idx, node in enumerate(nodes): - dp_devices = [get_device(dp_idx, pp_idx, tp_idx) for dp_idx in range(dp_size)] - batch_dim = node.input(0).shape.index(bs) - tensor_parallelism(graph, node, dp_devices, idx=0, dim=batch_dim) - - strategy = PredefinedSched.sched_1f1b(graph, num_microbatch) - graph.predef_sched(strategy) - return graph diff --git a/examples/nlp/mbart/train.py b/examples/nlp/mbart/train.py deleted file mode 100644 index 28ae9da9..00000000 --- a/examples/nlp/mbart/train.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -example: - -PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=4 torchrun \ - --nproc_per_node=4 \ - examples/nlp/mbart/train.py --policy PASMegatronTP --fp16 -""" - - -import torch -import logging -import argparse -import math -from functools import partial - -from examples.nlp.mbart.model import MBartForSentenceClassification, Config -from examples.nlp.mbart.model import dummy_data - -import nnscaler -from nnscaler.compiler import compile -from nnscaler.utils import set_default_logger_level, load_model -from nnscaler.profiler.timer import CudaTimer, print_each_rank -from nnscaler.profiler.memory import memory_summary -from nnscaler.runtime.utils import microbatches - -import examples.nlp.mbart.policy.gallery as gallery - -from examples.utils import get_policy - -parser = argparse.ArgumentParser(description='GPT Train') -parser.add_argument('--policy', type=str, help='PAS policy choice, starting with PAS') -parser.add_argument('--fp16', action='store_true', default=False, - help='use fp16 for the training') -parser.add_argument('--dp', type=int, default=1, - help='data parallel size, only for megatron') -parser.add_argument('--tp', type=int, default=1, - help='tensor parallel size, only for megatron') -# training -parser.add_argument('--gbs', type=int, default=4, help='global batch size') -parser.add_argument('--mbs', type=int, default=4, help='micro batch size') -# arch -parser.add_argument('--vocab', type=int, default=2500, - help='used vocabulary size') -parser.add_argument('--layers', type=int, default=8, - help='layer number of each encoder and decoder') -parser.add_argument('--heads', type=int, default=16, - help='head number') -parser.add_argument('--hidden', type=int, default=2048, - help='head number') -parser.add_argument('--seqlen', type=int, default=1024, - help='sequence length') - -args = parser.parse_args() - -nnscaler.init() -print(args) - - -nnscaler.init() -set_default_logger_level(logging.WARN) -logging.getLogger('nnscaler.compiler').setLevel(logging.INFO) - -# get policy -policy = get_policy([gallery], args.policy) -policy = partial(policy, - nmicros=args.gbs//args.mbs, - dp_size=args.dp, - tp_size=args.tp -) - - -def trunc_normal_(tensor: torch.Tensor, mean=0., std=1., a=-2., b=2.): - def norm_cdf(x): - return (1. + math.erf(x / math.sqrt(2.))) / 2. - - with torch.no_grad(): - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - tensor.uniform_(2 * l - 1, 2 * u - 1) - # tensor.erfinv_() - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - tensor.clamp_(min=a, max=b) - return tensor - - -def train(): - - batch_size = args.mbs - - config = Config( - hidden=args.hidden, - heads=args.heads, - layers=args.layers, - seqlen=args.seqlen, - ffn_hidden_dim=args.hidden * 4, - vocab=args.vocab, - ) - print_each_rank(config) - - model = MBartForSentenceClassification(batch_size, config) - torch.manual_seed(0) - for param in model.parameters(): - trunc_normal_(param) - model = model.half() if args.fp16 else model - - gen_data = partial(dummy_data, batch_size, config) - dataloader = microbatches((gen_data(),), cycle=True) - - @compile(model, dataloader, PAS=policy) - def train_iter(model, dataloader): - input_ids, decoder_input_ids = next(dataloader) - loss = model(input_ids, decoder_input_ids) - loss.backward() - model = load_model() - - optimizer = torch.optim.Adam( - model.parameters(), lr=3e-05, betas=(0.9, 0.98)) - - CudaTimer().warmup() - iter_num, warmup = 5, 2 - for step in range(iter_num): - if step == warmup: - CudaTimer(enable=True).start('e2e') - # prepare input data - samples = [gen_data() for _ in range(args.gbs // args.mbs)] - dataloader = microbatches(samples) - - # training - train_iter(model, dataloader) - optimizer.step() - optimizer.zero_grad() - - if step == 0: - print_each_rank('passed first iteration') - if (step + 1) % 2 == 0: - print_each_rank(f'iter [{step + 1}/{iter_num}]', rank_only=0) - - CudaTimer().stop('e2e') - print_each_rank('e2e time (ms) per iteration: {} ms'.format( - CudaTimer().duration(iter_num-warmup, field_name='e2e'))) - CudaTimer().print_all(times=iter_num-warmup) - memory_summary() - - -if __name__ == '__main__': - - nnscaler.init() - train() \ No newline at end of file From d1e1c24846910b72bc4b0599c9259995c51a59d2 Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Fri, 6 Sep 2024 06:28:36 +0000 Subject: [PATCH 1724/1892] Merged PR 2253: refine tracer utils MagicMethodPatcher + node_is_impure_wrapper -> TorchFXPatcher split utils.py to different files --- nnscaler/graph/parser/converter.py | 2 +- .../fx/concrete_trace_utils/__init__.py | 2 +- .../fx/concrete_trace_utils/concrete_proxy.py | 99 +++---- .../concrete_trace_utils/concrete_tracer.py | 235 +---------------- .../fx/concrete_trace_utils/frame_utils.py | 75 ++++++ .../{utils.py => metadata.py} | 89 +------ .../fx/concrete_trace_utils/pytree_utils.py | 36 +++ .../concrete_trace_utils/torch_fx_patcher.py | 243 ++++++++++++++++++ nnscaler/graph/parser/fx/parser.py | 3 +- tests/graph/tracer/test_inplace.py | 2 +- 10 files changed, 418 insertions(+), 368 deletions(-) create mode 100644 nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py rename nnscaler/graph/parser/fx/concrete_trace_utils/{utils.py => metadata.py} (51%) create mode 100644 nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index f09ea014..69342f1a 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -12,7 +12,7 @@ from nnscaler.graph.parser.fx.parser import FxModuleParser from nnscaler.graph.parser.fx.concrete_trace_utils import concrete_trace from nnscaler.graph.parser.fx.concrete_trace_utils.wrap_utils import Location, is_autograd_apply, LeafWrapInfo -from nnscaler.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops +from nnscaler.graph.parser.fx.concrete_trace_utils.torch_fx_patcher import side_effectful_inplace_ops import nnscaler.runtime.function as cube_rt_function diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py b/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py index 279b44c1..282a505c 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/__init__.py @@ -13,4 +13,4 @@ """ from .concrete_tracer import ConcreteTracer, concrete_trace from .concrete_proxy import ConcreteProxy -from .utils import ExtraSEFPatcher, TensorMetadata +from .metadata import TensorMetadata, DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index 95b416d5..d7d9a61e 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -7,7 +7,7 @@ import logging import inspect -from typing import List, Optional, Iterable, Any, Set, Union +from typing import List, Optional, Iterable, Any, Union import torch from torch.fx._compatibility import compatibility @@ -17,10 +17,8 @@ from torch.overrides import is_tensor_method_or_property from . import concrete_tracer as et -from . import pytree_utils, orig_func -from .utils import ( - get_frame_record, -) +from . import pytree_utils, orig_func, wrap_utils +from .frame_utils import get_frame_record, get_instruction _logger = logging.getLogger(__name__) @@ -77,16 +75,8 @@ def __call__(self, *args, **kwargs) -> ConcreteProxy: return self.value.__call__(*args, **kwargs) return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) - def __iter__(self) -> Union[Iterable, ConcreteProxy]: - # to detect if in executing `*proxy`, or `a, b, c = atuple` - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - cur = calling_frame.f_lasti // 2 - insts: List[dis.Instruction] = orig_func.list(dis.get_instructions(calling_frame.f_code)) - while insts[cur].opcode == self.op_extended_arg: - cur += 1 + def __iter__(self) -> Union[Iterable, ConcreteProxy]: + insts, cur = get_instruction(1) if insts[cur].opcode == self.op_call_ex: # in executing func(..., *proxy) @@ -117,15 +107,7 @@ def __next__(self) -> ConcreteProxy: return self.tracer.create_proxy('call_function', next, (self,), {}) def __len__(self) -> Union[int, ConcreteProxy]: - # to detect if in executing `*proxy` - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - cur = calling_frame.f_lasti // 2 - insts: List[dis.Instruction] = orig_func.list(dis.get_instructions(calling_frame.f_code)) - while insts[cur].opcode == self.op_extended_arg: - cur += 1 + insts, cur = get_instruction(1) if insts[cur].opcode == self.op_call_ex: # in executing func(..., *proxy) @@ -148,15 +130,7 @@ def __setitem__(self, *args, **kwargs) -> ConcreteProxy: return self.tracer.create_proxy('call_function', orig_func.setitem, (self,) + args, kwargs) def __bool__(self) -> Union[bool, ConcreteProxy]: - # to detect if in executing branch condition - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - cur = calling_frame.f_lasti // 2 - insts: List[dis.Instruction] = orig_func.list(dis.get_instructions(calling_frame.f_code)) - while insts[cur].opcode == self.op_extended_arg: - cur += 1 + insts, cur = get_instruction(1) if insts[cur].opcode in self.jump_opcodes or ( insts[cur].opcode in self.jump_before_opcodes and insts[cur + 1].opcode in self.jump_opcodes): @@ -205,15 +179,7 @@ def __exit__(self, exc_type, exc_value, traceback): @compatibility(is_backward_compatible=True) def keys(self): - # to detect if in executing `**proxy` - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - cur = calling_frame.f_lasti // 2 - insts: List[dis.Instruction] = orig_func.list(dis.get_instructions(calling_frame.f_code)) - while insts[cur].opcode == self.op_extended_arg: - cur += 1 + insts, cur = get_instruction(1) if insts[cur].opcode == self.op_call_ex or insts[cur].opcode == self.op_dict_merge: # in executing `**proxy` @@ -239,33 +205,38 @@ def items(self): def __torch_function__(cls, orig_method, types, args=None, kwargs=None): # to wrap all the functions/methods with tensor inputs in the namespace 'torch.*'. # actually a simple way to do wrap, but may get wrong in functions with no tensor inputs. - # TODO: now for most functions in torch namespace, we do wrap directly and not use __torch_function__ + # NOTE: now for most functions in torch namespace, we do wrap directly and not use __torch_function__ + _logger.warning(f"{orig_method} is not wrapped by tracer, which is not expected, please consider to register this function.") args = args if args else () kwargs = kwargs if kwargs else {} - tracers: Set[Any] = orig_func.set() - - def find_tracer(a): - if orig_func.isinstance(a, cls): - tracers.add(a.tracer) - - pytree_utils.tree_map(find_tracer, args) - pytree_utils.tree_map(find_tracer, kwargs) - - if orig_func.len(tracers) > 1: - raise RuntimeError(f'Found multiple different tracers {orig_func.list(tracers)} while ' - f'trying to trace operations {orig_method}') - tracer, = tracers - - if isinstance(orig_method, torch._C.ScriptMethod): - args = (orig_method.owner,) + args - return tracer.create_proxy('call_method', orig_method.name, args, kwargs) - if is_tensor_method_or_property(orig_method): - return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) + with wrap_utils.do_temp_call_origin(): + tracers = orig_func.set() + + def detect_tracer(obj): + if isinstance(obj, ConcreteProxy): + tracers.add(obj.tracer) + + pytree_utils.tree_map(detect_tracer, args) + pytree_utils.tree_map(detect_tracer, kwargs) + + if len(tracers) > 1: + raise Exception('more than 1 tracer detected. please report the issue') + + tracer = None if len(tracers) == 0 else tracers.pop() + + if tracer is None: + raise RuntimeError(f"no proxy detected in the inputs of {orig_method}, please wrap this function for trace completeness.") else: - return tracer.create_proxy('call_function', orig_method, args, kwargs, - name=tracer.graph._target_to_str(orig_method.__name__)) + if isinstance(orig_method, torch._C.ScriptMethod): + args = (orig_method.owner,) + args + return tracer.create_proxy('call_method', orig_method.name, args, kwargs) + if is_tensor_method_or_property(orig_method): + return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) + else: + return tracer.create_proxy('call_function', orig_method, args, kwargs, + name=tracer.graph._target_to_str(orig_method.__name__)) @compatibility(is_backward_compatible=True) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index fb5f707c..d30e61f1 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -8,8 +8,6 @@ import sys import inspect import logging -import operator -import functools import builtins from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType @@ -18,79 +16,27 @@ import torch from torch._C import ScriptObject -from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict import torch.fx from torch.fx import GraphModule from torch.fx._compatibility import compatibility from torch.fx._symbolic_trace import _proxyable_classes from torch.fx.graph import Graph -from torch.fx.node import Target, Node, Argument, _side_effectful_functions, base_types -from torch.fx.proxy import TracerBase +from torch.fx.node import Target, Node, Argument, base_types +from torch.fx.proxy import TracerBase, Scope from torch.fx.operator_schemas import check_for_mutable_operation dict_keys_type = type(dict().keys()) dict_values_type = type(dict().values()) dict_items_type = type(dict().items()) -try: - # Scope is a new class to record module path in pytorch 2.0 - from torch.fx.proxy import Scope -except ImportError: - # copy from pytorch 2.0 - @compatibility(is_backward_compatible=False) - class Scope: - def __init__(self, module_path: str, module_type: Any): - super().__init__() - self.module_path = module_path - self.module_type = module_type - -try: - # comes with Scope - from torch.fx.proxy import ScopeContextManager -except ImportError: - # copy from pytorch 2.0 - @compatibility(is_backward_compatible=False) - class ScopeContextManager: - """ A context manager to track the Scope of Node during symbolic tracing. - When entering a forward function of a Module, we'll update the scope information of - the current module, and when we exit, we'll restore the previous scope information. - """ - - def __init__( - self, - scope: Scope, - current_scope: Scope, - ): - super().__init__() - # Keep a copy of prev scope to restore on exit - self._prev_scope = copy.copy(scope) - # Update scope to current scope - scope.module_path = current_scope.module_path - scope.module_type = current_scope.module_type - # Save a reference so we can restore it - self._scope = scope - - def __enter__(self): - return self._scope - - def __exit__(self, *args): - self._scope.module_path = self._prev_scope.module_path - self._scope.module_type = self._prev_scope.module_type - return - from . import concrete_proxy as ep from . import pytree_utils, orig_func, wrap_utils +from .frame_utils import get_frame_record from .function_patcher import FunctionPatcher +from .metadata import EmptyResult, extract_results_metadata from .operator_patcher import OperatorPatcherContext -from .utils import ( - _orig_node_is_impure, - side_effectful_inplace_ops, - ExtraSEFPatcher, - EmptyResult, - extract_results_metadata, - get_frame_record, -) +from .torch_fx_patcher import TorchFXPatcher, ExtraSEFPatcher, side_effectful_inplace_ops # pyright: reportGeneralTypeIssues=false _logger = logging.getLogger(__name__) @@ -732,12 +678,13 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): with OperatorPatcherContext(self, use_operator_patch, operator_patch_backlist): results = OperatorPatcherContext.patch_run(fn, *args, *more_args, **kwargs) # we should unwrap proxy to the original value in the results when we record it to node.meta['tensor_meta'] - def unwrap(obj: Any): - while orig_func.isinstance(obj, ep.ConcreteProxy): - obj = obj.value - return obj + with wrap_utils.do_temp_call_origin(): + def unwrap_nested_proxy(proxy: ep.ConcreteProxy): + return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) + + node_result = pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, results) self.create_node('output', 'output', (self.create_arg(results),), - {}, type_expr=fn.__annotations__.get('return', None), node_result=ep.map_aggregate_not_proxy(results, unwrap)) + {}, type_expr=fn.__annotations__.get('return', None), node_result=node_result) finally: _retain_weight_consistency(self.root) @@ -774,109 +721,6 @@ def __init__(self, graph: Graph): super().__init__() self.graph = graph -class MagicMethodPatcher: - from torch.fx import graph as fx_graph - from torch.fx import graph_module as fx_graph_module - from torch.fx import node as fx_node - magic_methods_ori = fx_graph.magic_methods - magic_methods_new = { - **fx_graph.magic_methods, - 'not_': 'not {}', - 'is_': '{} is {}', - 'is_not': '{} is not {}', - 'contains': '{1} in {0}', - } - copy_attr_ori: Any = fx_graph_module._copy_attr - find_module_of_method_ori: Any = fx_node._find_module_of_method - format_import_statement_ori: Any = fx_graph_module._format_import_statement - - @staticmethod - def copy_attr_new(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): - *prefix, field = target.split('.') - for item in prefix: - f = getattr(from_module, item) - t = getattr(to_module, item, None) - if f is t: - return - - if t is None: - if isinstance(f, Sequential): - t = Sequential() - elif isinstance(f, ModuleList): - t = ModuleList() - elif isinstance(f, ModuleDict): - t = ModuleDict() - else: - t = torch.nn.Module() - if hasattr(f, '_get_name'): - t._get_name = f._get_name - to_module.add_module(item, t) - from_module, to_module = f, t - - orig = getattr(from_module, field) - - # If it is a buffer, register the tensor as the same type of buffer, otherwise, just set the attribute. - if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): - persistent = field in from_module._buffers and field not in from_module._non_persistent_buffers_set - to_module.register_buffer(field, orig, persistent=persistent) - else: - setattr(to_module, field, orig) - - @staticmethod - def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: - if wrap_utils.is_autograd_apply(orig_method): - # for torch.autograd.Function - return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' - - name = orig_method.__name__ - module = orig_method.__module__ - # if hasattr(orig_method, '__qualname__') and isinstance(orig_method.__qualname__, str): - # # if there has '.' in '__qualname__', it means this function is in a nested structure, - # # - # # for example, it is a method / function in a class: - # # torch.nn.Linear.forward.__module__ = torch.nn - # # torch.nn.Linear.forward.__name__ = forward - # # torch.nn.Linear.forward.__qualname__ = Linear.forward - # # - # # And in fx.node qualified name creating rule, the module also should include the class name, - # # in this example, the returned module should be `torch.nn.Linear`. - # # It is not the original meaning of a obj's module, but we need this workaround to reuse fx node. - # splited_names = orig_method.__qualname__.split('.') - # class_name, name = splited_names[:-1], splited_names[-1] - # module = '.'.join([module] + class_name) - if module == 'torch.autograd.grad_mode' and name in ['__enter__', '__exit__']: - return 'torch.autograd.grad_mode.no_grad' - if module is not None: - return module - if hasattr(orig_method, '__qualname__')\ - and isinstance(orig_method.__qualname__, str) and orig_method.__qualname__.startswith('_VariableFunctionsClass.'): - return 'torch._C._VariableFunctions' - for guess in [torch, getattr(torch.nn, 'functional')]: - if getattr(guess, name, None) is orig_method: - return guess.__name__ - raise RuntimeError(f'cannot find module for {orig_method}') - - @staticmethod - def format_import_statement_new(name: str, obj: Any, importer) -> str: - if wrap_utils.is_autograd_apply(obj): # type: ignore - # torch.autograd.function - return MagicMethodPatcher.format_import_statement_ori(name, obj.__self__, importer) + f'\n{name} = {name}.apply' - return MagicMethodPatcher.format_import_statement_ori(name, obj, importer) - - def __enter__(self): - MagicMethodPatcher.fx_graph.magic_methods = self.magic_methods_new - MagicMethodPatcher.fx_graph_module._copy_attr = self.copy_attr_new - MagicMethodPatcher.fx_node._find_module_of_method = self.find_module_of_method_new - MagicMethodPatcher.fx_graph_module._format_import_statement = self.format_import_statement_new - MagicMethodPatcher.available = True - - def __exit__(self, exc_type, exc_value, tb): - MagicMethodPatcher.fx_graph.magic_methods = MagicMethodPatcher.magic_methods_ori - MagicMethodPatcher.fx_graph_module._copy_attr = MagicMethodPatcher.copy_attr_ori - MagicMethodPatcher.fx_node._find_module_of_method = MagicMethodPatcher.find_module_of_method_ori - MagicMethodPatcher.fx_graph_module._format_import_statement = MagicMethodPatcher.format_import_statement_ori - MagicMethodPatcher.available = False - return exc_type is None def _retain_weight_consistency(root: torch.nn.Module): _flag = 0 @@ -898,56 +742,6 @@ def _retain_weight_consistency(root: torch.nn.Module): ' ``concrete_trace`` may not guarantee the consistency of the traced graph.') return root -@functools.wraps(_orig_node_is_impure) -def node_is_impure_wrapper(node): - if is_useless_iter(node): - return False - - if node.op in {"placeholder", "output"}: - return True - - if node.op == "call_function": - return node.target in _side_effectful_functions - - if node.op == "call_method": - return node.target.endswith("_") - - if node.op == "call_module": - assert ( - node.graph.owning_module is not None - ), "self.graph.owning_module not set for purity check" - target_mod = node.graph.owning_module.get_submodule(node.target) - assert ( - target_mod is not None - ), f"Did not find expected submodule target {node.target}" - return getattr(target_mod, "_is_impure", False) - - return False - -def is_useless_iter(node: Node): - if node.op == 'call_function' and node.target is iter: - node_is_impure = False - for iter_user in node.users: - if not is_useless_next(iter_user): - node_is_impure = True - break - if not node_is_impure: - for iter_user in list(node.users.keys()): - setattr(iter_user, '_is_impure', False) - iter_user.graph.erase_node(iter_user) - if len(node.users) > 0: - raise RuntimeError('The user node of iter is not empty, something goning wrong.') - setattr(node, '_is_impure', False) - return True - else: - return False - -def is_useless_next(node: Node): - if node.op == "call_function" and node.target is next: - if len(node.users) == 0: - return True - else: - return False def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Union[Dict[str, Any], Tuple], @@ -1128,7 +922,7 @@ def f(x, y): else: assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}' - with MagicMethodPatcher(): + with TorchFXPatcher(): name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ traced = GraphModule(tracer.root, graph, name) @@ -1140,10 +934,9 @@ def f(x, y): *side_effectful_inplace_ops } extra_side_effectful_functions = default_extra_side_effectful_functions | dce_ignored_function - with FunctionPatcher() as patcher, ExtraSEFPatcher(extra_side_effectful_functions): - patcher.patch_method(Node, 'is_impure', node_is_impure_wrapper, deduplicate=False) + with ExtraSEFPatcher(extra_side_effectful_functions): traced.graph.eliminate_dead_code() - traced.recompile() # this need to be done in MagicMethodPatcher context + traced.recompile() # this need to be done in TorchFXPatcher context # TODO: better infomation if check_args is not None: diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py new file mode 100644 index 00000000..34b741ed --- /dev/null +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/frame_utils.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass +import dis +import importlib +import inspect +from pathlib import Path +import sys +import traceback + +from typing import List, Optional + + +def get_instruction(back_times=1) -> dis.Instruction: + """ + Get the instruction of the (back_times)-th frame from the bottom. + By default (back_times=1), the instruction of the frame who call this function will be returned. + """ + frame = inspect.currentframe() + assert frame is not None + # the frame who call get_instruction + calling_frame = frame.f_back + for _ in range(back_times): + calling_frame = calling_frame.f_back + assert calling_frame is not None + insts: List[dis.Instruction] = list(dis.get_instructions(calling_frame.f_code)) + + if sys.version_info >= (3, 11): + from bisect import bisect_left + # bisect_left find the position where an element should be inserted in a sorted list to maintain the list’s order. + # If the element already exists in the list, + # bisect_left returns the position to the left of the first occurrence of that element. + # here use bisect_left to find the position of calling_frame.f_lasti in the insts. + cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) + else: + # based on the assumption that most bytecodes in Python are two bytes, + # dividing by 2 results in the sequence number of the instructions. + cur = calling_frame.f_lasti // 2 + + # From python doc: + # EXTENDED_ARG(ext): Prefixes any opcode which has an argument too big to fit into the default one byte. + # ext holds an additional byte which act as higher bits in the argument. + # For each opcode, at most three prefixal EXTENDED_ARG are allowed, forming an argument from two-byte to four-byte. + while insts[cur].opname == 'EXTENDED_ARG': + cur += 1 + return insts, cur + + +@dataclass +class FrameRecord: + filename: str + lineno: str + line: str + # the name of the frame is the function name + name: str + + def __repr__(self) -> str: + if self.filename: + return f'File "{self.filename}", line {self.lineno}, in {self.name}, {self.line}' + else: + return '' + + +def get_frame_record() -> Optional[FrameRecord]: + # record code frame, include filename, line number, and function name + frame_record = None + cube_path = str(Path(importlib.util.find_spec('nnscaler').origin).parent) + '/' # the cube path + torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path + ignore_dirs = [cube_path, torch_path] + # the last frame is the current frame [get_frame_record], so we need to skip it + for frame in traceback.extract_stack()[-2::-1]: + if any(p in frame.filename for p in ignore_dirs): + continue + frame_record = FrameRecord(frame.filename, frame.lineno, frame.line, frame.name) + break + return frame_record + diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py similarity index 51% rename from nnscaler/graph/parser/fx/concrete_trace_utils/utils.py rename to nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py index 6c0da1eb..a463fc23 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/metadata.py @@ -1,43 +1,20 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from dataclasses import dataclass -import importlib -import operator -import traceback -from pathlib import Path -from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple +from typing import Any, Dict, NamedTuple, Optional, Tuple import torch -from torch.fx.node import Node, map_aggregate, _side_effectful_functions +from torch.fx.node import Node + +from . import pytree_utils DICT_KEYS_TYPE = type({}.keys()) DICT_VALUES_TYPE= type({}.values()) DICT_ITEMS_TYPE= type({}.items()) -_orig_node_is_impure: Callable = Node.is_impure - -side_effectful_inplace_ops = { - operator.iadd, operator.isub, operator.imul, operator.itruediv, operator.ifloordiv, - operator.iand, operator.ior, operator.ixor, operator.ilshift, operator.irshift, - operator.imod, operator.ipow, - # operator.imatmul is not implemented in torch - # so let's ignore it now - operator.setitem, -} - - -class ExtraSEFPatcher: - def __init__(self, extra_side_effectful_functions: Set[Callable]): - self.extra_side_effectful_functions = extra_side_effectful_functions - self.incontext_funcs = set() - def __enter__(self): - self.incontext_funcs = self.extra_side_effectful_functions - _side_effectful_functions - _side_effectful_functions.update(self.incontext_funcs) - - def __exit__(self, exc_type, exc_val, exc_tb): - _side_effectful_functions.difference_update(self.incontext_funcs) +class EmptyResult: + """ + Used for identification no results. + """ + pass class TensorMetadata(NamedTuple): @@ -94,21 +71,13 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] - return TensorMetadata( - shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) - - -def extract_tensor_metadata(obj: Any): - if isinstance(obj, torch.Tensor): - return _extract_tensor_metadata(obj) - else: - return obj + return TensorMetadata(shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) def extract_results_metadata(results: Any, node: Node): if results is not EmptyResult: res = tuple(results) if isinstance(results, (DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE)) else results - meta = map_aggregate(res, extract_tensor_metadata) + meta = pytree_utils.tree_map_only(torch.Tensor, _extract_tensor_metadata, res) # we should get the meta info of the inner element of these type obj if isinstance(results, DICT_KEYS_TYPE): meta = {i: m for i, m in enumerate(meta)}.keys() @@ -118,39 +87,3 @@ def extract_results_metadata(results: Any, node: Node): meta = {i: m for i, m in meta}.items() node.meta['tensor_meta'] = meta node.meta['type'] = type(results) - - -class EmptyResult: - """Used for identification no results. - """ - pass - - -@dataclass -class FrameRecord: - filename: str - lineno: str - line: str - # the name of the frame is the function name - name: str - - def __repr__(self) -> str: - if self.filename: - return f'File "{self.filename}", line {self.lineno}, in {self.name}, {self.line}' - else: - return '' - - -def get_frame_record() -> Optional[FrameRecord]: - # record code frame, include filename, line number, and function name - frame_record = None - cube_path = str(Path(importlib.util.find_spec('nnscaler').origin).parent) + '/' # the cube path - torch_path = str(Path(importlib.util.find_spec('torch').origin).parent) + '/' # the torch path - ignore_dirs = [cube_path, torch_path] - # the last frame is the current frame [get_frame_record], so we need to skip it - for frame in traceback.extract_stack()[-2::-1]: - if any(p in frame.filename for p in ignore_dirs): - continue - frame_record = FrameRecord(frame.filename, frame.lineno, frame.line, frame.name) - break - return frame_record diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py index b428c0e0..f23a034a 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/pytree_utils.py @@ -4,6 +4,9 @@ from collections import namedtuple from typing import Any, List, Tuple, Iterable +import torch +import inspect + from . import orig_func, _pytree from ._pytree import * @@ -57,6 +60,39 @@ def _slice_unflatten(values: Iterable[Any], context: Context) -> slice: ) +# register return_types to pytree, copy from torch.return_types +return_types = torch._C._return_types + +def pytree_register_structseq(cls): + def structseq_flatten(structseq): + return list(structseq), None + + def structseq_flatten_with_keys(structseq): + values, context = structseq_flatten(structseq) + return [(_pytree.SequenceKey(i), v) for i, v in enumerate(values)], context + + def structseq_unflatten(values, context): + return cls(values) + + _pytree.register_pytree_node( + cls, + structseq_flatten, + structseq_unflatten, + flatten_with_keys_fn=structseq_flatten_with_keys, + ) + +for name in dir(return_types): + if name.startswith('__'): + continue + + _attr = getattr(return_types, name) + globals()[name] = _attr + + # torch.return_types is a structseq, aka a "namedtuple"-like thing defined by the Python C-API. + if inspect.isclass(_attr) and issubclass(_attr, tuple): + pytree_register_structseq(_attr) + + def tree_leaves_with_spec(pytree: _pytree.PyTree, spec: TreeSpec) -> List: """ Flat a pytree with a given spec. diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py b/nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py new file mode 100644 index 00000000..0cf6c50b --- /dev/null +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/torch_fx_patcher.py @@ -0,0 +1,243 @@ +import operator +from typing import Any, Callable, Set + +import torch +from torch.nn import Sequential, ModuleList, ModuleDict +from torch.fx.node import _side_effectful_functions, Node + +from . import wrap_utils + + +side_effectful_inplace_ops = { + operator.iadd, operator.isub, operator.imul, operator.itruediv, operator.ifloordiv, + operator.iand, operator.ior, operator.ixor, operator.ilshift, operator.irshift, + operator.imod, operator.ipow, + # operator.imatmul is not implemented in torch + # so let's ignore it now + operator.setitem, +} + + +class ExtraSEFPatcher: + def __init__(self, extra_side_effectful_functions: Set[Callable]): + self.extra_side_effectful_functions = extra_side_effectful_functions + self.incontext_funcs = set() + + def __enter__(self): + self.incontext_funcs = self.extra_side_effectful_functions - _side_effectful_functions + _side_effectful_functions.update(self.incontext_funcs) + + def __exit__(self, exc_type, exc_val, exc_tb): + _side_effectful_functions.difference_update(self.incontext_funcs) + + +def is_useless_iter(node: Node): + if node.op == 'call_function' and node.target is iter: + node_is_impure = False + for iter_user in node.users: + if not is_useless_next(iter_user): + node_is_impure = True + break + if not node_is_impure: + for iter_user in list(node.users.keys()): + setattr(iter_user, '_is_impure', False) + iter_user.graph.erase_node(iter_user) + if len(node.users) > 0: + raise RuntimeError('The user node of iter is not empty, something goning wrong.') + setattr(node, '_is_impure', False) + return True + else: + return False + + +def is_useless_next(node: Node): + if node.op == "call_function" and node.target is next: + if len(node.users) == 0: + return True + else: + return False + + +class TorchFXPatcher: + """ + this patcher is a context mananger, when enter the context, several torch.fx functions will be patched, + and revert these functions when exit. + + The following function will be patched: + + torch.fx.graph.magic_methods: + additional add not_/is_/is_not/contains, because these functions are transformed by nnscaler operator patcher. + + torch.fx.graph_module._copy_attr: + additional track persistent attribute for buffer. + + torch.fx.graph_module._format_import_statement: + additional support autograd functions code generation. + + torch.fx.node._find_module_of_method: + additional support autograd functions and _VariableFunctionsClass functions for find the correct module. + + torch.fx.node.is_impure: + additional add inplace functions as impure nodes and useless iter nodes as non-impure node. + """ + from torch.fx import graph as fx_graph + from torch.fx import graph_module as fx_graph_module + from torch.fx import node as fx_node + + magic_methods_ori = fx_graph.magic_methods + copy_attr_ori = fx_graph_module._copy_attr + find_module_of_method_ori = fx_node._find_module_of_method + is_impure_ori = fx_node.Node.is_impure + format_import_statement_ori = fx_graph_module._format_import_statement + + magic_methods_new = { + **fx_graph.magic_methods, + # NOTE by nnscaler: add these method because we use operator patcher to transform the origin code to `_operator.xxx`, + # torch.fx.graph.magic_methods is used to emit node to generate code, so here add these mapping to make the gencode more readable. + # for example: + # in original code: if mask is not None: + # will be transformed to: if _operator.is_not(mask, None): + 'not_': 'not {}', + 'is_': '{} is {}', + 'is_not': '{} is not {}', + 'contains': '{1} in {0}', + } + + @staticmethod + def copy_attr_new(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): + """ + copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' + This installs empty Modules where none exist yet if they are subpaths of target + """ + *prefix, field = target.split('.') + for item in prefix: + f = getattr(from_module, item) + t = getattr(to_module, item, None) + if f is t: + # we have already installed one of its parents + # (e.g. target = root.linear.weight, but we have already installed root.linear) + # once we have installed a parent, we no longer need to copy the children + # since all needed attributes have been copied + return + + if t is None: + # NOTE by nnscaler: in the original copy_attr, only create torch.nn.Module for all cases, + # here we add more kinds of official subclasses of torch.nn.Module + if isinstance(f, Sequential): + t = Sequential() + elif isinstance(f, ModuleList): + t = ModuleList() + elif isinstance(f, ModuleDict): + t = ModuleDict() + else: + t = torch.nn.Module() + # NOTE by nnscaler: for readable reason, we want the to_module has the same repr with the from_module, + # so here we bind the from_module._get_name to to_module._get_name + if hasattr(f, '_get_name'): + t._get_name = f._get_name + to_module.add_module(item, t) + from_module, to_module = f, t + + orig = getattr(from_module, field) + + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + # NOTE by nnscaler: persistent state is not considered by the original copy_attr, so add it here + persistent = field in from_module._buffers and field not in from_module._non_persistent_buffers_set + to_module.register_buffer(field, orig, persistent=persistent) + else: + setattr(to_module, field, orig) + + @staticmethod + def find_module_of_method_new(orig_method: Callable[..., Any]) -> str: + # NOTE by nnscaler: if the method is torch.autograd.Function.apply, we should return its name with bound module + # for example, cus_module.CusAutogradFunction is a class inherit the torch.autograd.Function, then: + # cus_module.CusAutogradFunction.apply.__name__ is "apply" + # cus_module.CusAutogradFunction.apply.__module__ is "torch.autograd.function" + # cus_module.CusAutogradFunction.apply.__self__.__name__ is "CusAutogradFunction" + # cus_module.CusAutogradFunction.apply.__self__.__module__ is "cus_module" + # so the correct module path of the autograd apply method is f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' + if wrap_utils.is_autograd_apply(orig_method): + return f'{orig_method.__self__.__module__}.{orig_method.__self__.__name__}' + + name = orig_method.__name__ + module = orig_method.__module__ + + if module is not None: + return module + # NOTE by nnscaler: add a special support for torch._C._VariableFunctions + if hasattr(orig_method, '__qualname__') \ + and isinstance(orig_method.__qualname__, str) and orig_method.__qualname__.startswith('_VariableFunctionsClass.'): + return 'torch._C._VariableFunctions' + for guess in [torch, getattr(torch.nn, 'functional')]: + if getattr(guess, name, None) is orig_method: + return guess.__name__ + raise RuntimeError(f'cannot find module for {orig_method}') + + @staticmethod + def format_import_statement_new(name: str, obj: Any, importer) -> str: + # NOTE by nnscaler: to support code generation of autograd function in nnscaler + # for example: + # => input: name=model_layer_CustomizedAutogradFunc_apply, obj=CustomizedAutogradFunc.apply + # => obj.__self__ is CustomizedAutogradFunc + # => return: from xxx import CustomizedAutogradFunc as model_layer_CustomizedAutogradFunc_apply + # model_layer_CustomizedAutogradFunc_apply = model_layer_CustomizedAutogradFunc_apply.apply + if wrap_utils.is_autograd_apply(obj): + return TorchFXPatcher.format_import_statement_ori(name, obj.__self__, importer) + f'\n{name} = {name}.apply' + return TorchFXPatcher.format_import_statement_ori(name, obj, importer) + + @staticmethod + def is_impure_new(node: fx_node.Node): + """ + Returns whether this op is impure, i.e. if its op is a placeholder or + output, or if a call_function or call_module which is impure. + + Returns: + + bool: If the op is impure or not. + """ + if is_useless_iter(node): + return False + + if node.op in {"placeholder", "output"}: + return True + + # Check if an impure function. + if node.op == "call_function": + return node.target in _side_effectful_functions + + # NOTE by nnscaler: we assume all method end with "_" is inplace operation, + # and we take all inplace operations impure. + if node.op == "call_method": + return node.target.endswith("_") + + # Check if an impure module. + if node.op == "call_module": + assert ( + node.graph.owning_module is not None + ), "self.graph.owning_module not set for purity check" + target_mod = node.graph.owning_module.get_submodule(node.target) + assert ( + target_mod is not None + ), f"Did not find expected submodule target {node.target}" + return getattr(target_mod, "_is_impure", False) + + return False + + def __enter__(self): + TorchFXPatcher.fx_graph.magic_methods = self.magic_methods_new + TorchFXPatcher.fx_graph_module._copy_attr = self.copy_attr_new + TorchFXPatcher.fx_node._find_module_of_method = self.find_module_of_method_new + TorchFXPatcher.fx_node.Node.is_impure = self.is_impure_new + TorchFXPatcher.fx_graph_module._format_import_statement = self.format_import_statement_new + TorchFXPatcher.available = True + + def __exit__(self, exc_type, exc_value, tb): + TorchFXPatcher.fx_graph.magic_methods = TorchFXPatcher.magic_methods_ori + TorchFXPatcher.fx_graph_module._copy_attr = TorchFXPatcher.copy_attr_ori + TorchFXPatcher.fx_node._find_module_of_method = TorchFXPatcher.find_module_of_method_ori + TorchFXPatcher.fx_node.Node.is_impure = TorchFXPatcher.is_impure_ori + TorchFXPatcher.fx_graph_module._format_import_statement = TorchFXPatcher.format_import_statement_ori + TorchFXPatcher.available = False + return exc_type is None diff --git a/nnscaler/graph/parser/fx/parser.py b/nnscaler/graph/parser/fx/parser.py index 73c58d57..6d98a2ff 100644 --- a/nnscaler/graph/parser/fx/parser.py +++ b/nnscaler/graph/parser/fx/parser.py @@ -14,8 +14,7 @@ from nnscaler.graph.function.function import any_ir_object_satisfy import torch.fx -from .concrete_trace_utils import TensorMetadata -from .concrete_trace_utils.utils import DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE +from .concrete_trace_utils import TensorMetadata, DICT_KEYS_TYPE, DICT_VALUES_TYPE, DICT_ITEMS_TYPE _logger = logging.getLogger(__name__) diff --git a/tests/graph/tracer/test_inplace.py b/tests/graph/tracer/test_inplace.py index 90552e46..6323c5ed 100644 --- a/tests/graph/tracer/test_inplace.py +++ b/tests/graph/tracer/test_inplace.py @@ -3,7 +3,7 @@ import torch from nnscaler.graph.parser.converter import to_fx_graph -from nnscaler.graph.parser.fx.concrete_trace_utils.utils import side_effectful_inplace_ops +from nnscaler.graph.parser.fx.concrete_trace_utils.torch_fx_patcher import side_effectful_inplace_ops import nnscaler.runtime.function as cube_rt_function from ...utils import replace_all_device_with From a25f753241e0467318f2dfec1ce1dab6fe11d880 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Tue, 10 Sep 2024 03:43:16 +0000 Subject: [PATCH 1725/1892] Merged PR 2263: minitrainer: empty dataset train_args support minitrainer: empty dataset train_args support --- nnscaler/cli/__init__.py | 17 +++++++++++++++++ nnscaler/cli/trainer_args.py | 6 +++++- tests/cli/test_trainer.py | 20 +++++++++++++++++++- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/nnscaler/cli/__init__.py b/nnscaler/cli/__init__.py index e69de29b..134577af 100644 --- a/nnscaler/cli/__init__.py +++ b/nnscaler/cli/__init__.py @@ -0,0 +1,17 @@ +from nnscaler.cli.trainer import Trainer +from nnscaler.cli.trainer_args import ( + TrainerArgs, + CheckpointConfig, + DataloaderConfig, + DatasetConfig, + DatasetSamplerConfig, + ModelConfig, + OptimizerConfig, + LRSchedulerConfig, + LogConfig, + HookConfig, + HookMapConfig, + AggregatedOutputs, +) + +from nnscaler.parallel import ComputeConfig diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index dfa91552..575357cb 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -526,7 +526,11 @@ def create_parallel_optimizer(self, parallel_model: ParallelModule): def create_dataset(self, stage='train'): dataset_args = getattr(self.dataset, f'{stage}_args') - if not dataset_args: + # Sometimes a user uses a parameterless dataset class/factory function. + # To support this case, we will create train dataset even without any arguments. + # but val/test dataset must have arguments. + if not dataset_args and stage != 'train': + logger.info(f"{stage} dataset will not be created because empty arguments are provided.") return None kwargs = self.create_kwarg(dataset_args) dataset_class = load_type(self.dataset.type) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index f3d7e80b..20b681c0 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -6,7 +6,7 @@ import torch.distributed from nnscaler.cli.trainer import Trainer -from nnscaler.cli.trainer_args import AggregatedOutputs +from nnscaler.cli.trainer_args import AggregatedOutputs, TrainerArgs from tests.parallel_module.common import assert_equal from tests.utils import replace_all_device_with from ..launch_torchrun import launch_torchrun @@ -448,3 +448,21 @@ def trainer_per_token_worker(save_dir): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_trainer_per_token(tmp_path): launch_torchrun(2, trainer_per_token_worker, tmp_path) + + +def test_dataset_empty_train_args(): + def _empty_train_args(): + from .common import SimpleDataset + return SimpleDataset(10) + + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + train_args = TrainerArgs.from_cli([ + '-f', config_path, + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '2', + ]) + train_args.dataset.type = _empty_train_args + train_args.dataset.train_args = {} + train_args.dataset.val_args = {} + assert train_args.create_dataset() is not None + assert train_args.create_dataset('val') is None From e0143fb3981c937f33e79e2b7da27d510fa0e733 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Wed, 18 Sep 2024 05:28:25 +0000 Subject: [PATCH 1726/1892] Merged PR 2265: Add help message for cpp module when install from source Add help message for cpp module when install from source --- nnscaler/autodist/spmd_solver.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 75503fdd..a9ff87df 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -131,11 +131,11 @@ class SPMDSolver: 1. calculate the `out_degs` for each node, which is the number of consumers of the node 2. traverse the nodes in the topological order - decrease each producer's `out_degs` by 1, if the `out_degs` is 0, remove the producer - from the `unclosed_idx` and set current node's idx (time) as the producer's 'close_time' + from the `unclosed_idx` and set current node's idx (time) as the producer's 'close_time' - set current node's `cut_ops` as the union of `unclosed_idx` and the node itself - - add the node to `unclosed_idx` if its #consumers > 0 + - add the node to `unclosed_idx` if its #consumers > 0 - However, in real-world scenarios, certain positions might have a large number of `cut_ops`, + However, in real-world scenarios, certain positions might have a large number of `cut_ops`, and each op may have more than one partitioning strategy (for example, when the input data flow graph is a complete graph). In such cases, the search space becomes very large, making it impossible to solve within limited time and space. To help users reduce the search space, we calculate a metric called @@ -1180,7 +1180,14 @@ def do_ilp(self, intervals: List[Tuple[int, int]], def do_dp(self, intervals: List[Tuple[int, int]], topk: int) -> List[List[SPMDSearchOutput]]: import cppimport.import_hook - import nnscaler.autodist.dp_solver as dp_solver + try: + import nnscaler.autodist.dp_solver as dp_solver + except ImportError: + raise RuntimeError( + 'Failed to import solver. ' + 'If you installed nnscaler from source (`pip install -e .`), ' + 'please also make sure to put parent directory of `nnscaler` in `PYTHONPATH`.' + ) if self.autodist_config.memory_granularity < 1024: raise RuntimeError('dp solver assumes the memory granularity is at least 1024 bytes') From 3577eb88d5d736005021f5626bedca3392c40759 Mon Sep 17 00:00:00 2001 From: Weijiang Xu Date: Sun, 22 Sep 2024 07:22:04 +0000 Subject: [PATCH 1727/1892] Merged PR 2266: Fix np repr break change from latest np 1. Fix np repr break change from latest np ``` old version (1.26.4): >>> repr(np.int64(1)) '1' new version (2.1.1): >>> repr(np.int64(1)) 'np.int64(1)' ```` 2. Fix a typo in inter rvd 3. Fix libstdc++.so.6 version error for tox --- nnscaler/graph/gener/rvd/inter.py | 40 ++++++++++------- nnscaler/graph/gener/rvd/intra.py | 74 +++++++++++++++++-------------- tox.ini | 1 + 3 files changed, 64 insertions(+), 51 deletions(-) diff --git a/nnscaler/graph/gener/rvd/inter.py b/nnscaler/graph/gener/rvd/inter.py index f05dbd30..48df84d0 100644 --- a/nnscaler/graph/gener/rvd/inter.py +++ b/nnscaler/graph/gener/rvd/inter.py @@ -57,7 +57,7 @@ def decr(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: rvd = list(rvd) rvd[0] = rvd[0] // chunks return rvd, MovePrim - + @staticmethod def incd(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: """ @@ -73,7 +73,7 @@ def incd(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: rvd = list(rvd) rvd[2+dim] = rvd[2+dim] * chunks return rvd, partial(RDScatterPrim, dim=dim) - + @staticmethod def decd(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: """ @@ -90,7 +90,7 @@ def decd(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: rvd = list(rvd) rvd[2+dim] = rvd[2+dim] // chunks return rvd, partial(RDGatherPrim, dim=dim) - + @staticmethod def incv(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: """ @@ -103,9 +103,9 @@ def incv(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: @return prim Callable: primitive class """ rvd = list(rvd) - rvd[1] *= 2 + rvd[1] *= chunks return rvd, RVScatterPrim - + @staticmethod def decv(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: """ @@ -121,7 +121,7 @@ def decv(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: rvd = list(rvd) rvd[1] = rvd[1] // chunks return rvd, RVGatherPrim - + @staticmethod def transitionable(src_rvd: TRVD, dst_rvd: TRVD) -> Optional[Callable]: """ @@ -154,7 +154,7 @@ def transitionable(src_rvd: TRVD, dst_rvd: TRVD) -> Optional[Callable]: return InterTransition.decv else: return partial(InterTransition.decd, dim=decd-2) - + @staticmethod def transition(src_layout: RVDLayout, dst_rvd: TRVD, placement: Optional[Tuple[int]] = None) -> Tuple[RVDLayout, List[IRAdapterPrim]]: """ @@ -165,11 +165,11 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD, placement: Optional[Tuple[i @param src_layout RVDLayout: source ilayout @param dst_rvd Tuple[int]: destination RVD @param placement Tuple[int]: output layout device placement - + @return rets Tuple[GridLayout, List[IRAdapterPrim]]: pairs of of output """ - + src_rvd = src_layout.vec ftensor = src_layout.ftensor dst_layout: RVDLayout = RVDLayout.grid(ftensor, r=dst_rvd[0], v=dst_rvd[1], dims=dst_rvd[2:], devices=placement) @@ -180,7 +180,7 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD, placement: Optional[Tuple[i decd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 > d2] if len(incd) == 0 and len(decd) == 0: decd = [0] - + if len(incd) == 1: change_dim = incd[0] chunks = dst_rvd[change_dim] // src_rvd[change_dim] @@ -191,7 +191,7 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD, placement: Optional[Tuple[i imat = RVDLayout.dim2last(src_layout.mat, change_dim, src_rvd[change_dim]) omat = RVDLayout.dim2last(dst_layout.mat, change_dim, dst_rvd[change_dim]) - + prims = [] if len(incd) == 1: for src, dsts in zip(imat.flatten(), omat.reshape(-1, chunks)): @@ -233,7 +233,7 @@ def path(ilayout: RVDLayout, olayout: RVDLayout, cost_fn: Optional[Callable] = N """ ftensor: IRFullTensor = ilayout.ftensor cost_fn = InterPathFinder.default_cost_fn if cost_fn is None else cost_fn - + inter_rvds: List[InterRVD] = InterPathFinder.get_optimal_path( ftensor, ilayout.vec, olayout.vec, cost_fn) @@ -277,7 +277,7 @@ def device_align(ilayout: RVDLayout, olayout: RVDLayout, assert align, "Internal Error: inter-rvd producer side device fails to align" break # we only take the first one assert producer_out_devs is not None, f"Can't find inter-rvd producer out device placement" - + # setup consumer primitives and entry device placement consumer_entry_devs = None for cdevs in cdev_space: @@ -292,7 +292,7 @@ def device_align(ilayout: RVDLayout, olayout: RVDLayout, # setup inter-primitive _, iprims = InterTransition.transition(playout, crvds[0], consumer_entry_devs) - + # merge together return pprims + iprims + cprims @@ -300,14 +300,16 @@ def device_align(ilayout: RVDLayout, olayout: RVDLayout, def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, cost_fn: Optional[Callable] = None) -> List[InterRVD]: """ Get optimal RVD path from source RVD to destination RVD - + @param src_rvd Tuple[int]: source RVD @param dst_rvd Tuple[int]: destination RVD @return path Tuple[InterRVD]: The first one is src_rvd. The last one is dst_rvd. - Otherwise they are intermediate RVD status + Otherwise they are intermediate RVD status """ + # Please note the following int can be either python int or np.int* + src_ndevs = np.prod(src_rvd) src = ('p',) + src_rvd dst_ndevs = np.prod(dst_rvd) @@ -359,6 +361,10 @@ def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, cost_fn: Optional[Ca path = paths[nodes.index(dst)] assert len(path) > 0, f"Un-reachable src RVD {src} -> dst RVD {dst}" inter_rvds = tuple(nodes[idx] for idx in path) + inter_rvds = tuple( + (rvd[0],) + tuple(int(x) for x in rvd[1:]) # make sure all int (not np.int*) for rvd[1:] + for rvd in inter_rvds + ) return inter_rvds @staticmethod @@ -411,7 +417,7 @@ def init_graph(ftensor: IRFullTensor, src_ndevs: int, dst_ndevs: int, cost_fn: C # set for [len(src_nodes) + j, i] edges[len(src_nodes) + j, i] = cost return nodes, edges - + @staticmethod def decode(inter_rvds: Tuple[InterRVD]) -> Tuple[Tuple[TRVD], Tuple[TRVD]]: """ diff --git a/nnscaler/graph/gener/rvd/intra.py b/nnscaler/graph/gener/rvd/intra.py index 08ae8a00..dcc6f82e 100644 --- a/nnscaler/graph/gener/rvd/intra.py +++ b/nnscaler/graph/gener/rvd/intra.py @@ -36,11 +36,11 @@ class IntraTransition: def d2r(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: """ intra-RVD primitive D->R: allgather - + @param rvd Tuple[int]: input RVD @param dim int: tensor dimension @param chunks int: the number of chunks to transfer - + @return rvd Tuple[int]: output RVD @return prim Callable: IRAdapter primitive """ @@ -53,12 +53,12 @@ def d2r(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: def d2d(rvd: TRVD, from_dim: int, to_dim: int, chunks: int) -> Tuple[TRVD, Callable]: """ intra-RVD primitive D(...,i,..)->D(..,j,...): alltoall - + @param rvd Tuple[int]: input RVD @param from_dim int: source tensor axis @param to_dim int: destination tensor axis @param chunks int: the number of chunks to transfer - + @return rvd Tuple[int]: output RVD @return prim Callable: IRAdapter primitive """ @@ -71,10 +71,10 @@ def d2d(rvd: TRVD, from_dim: int, to_dim: int, chunks: int) -> Tuple[TRVD, Calla def v2r(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: """ intra-RVD primitive V->R: allreduce - + @param dim int: tensor dimension @param chunks int: the number of chunks to transfer - + @return rvd Tuple[int]: output RVD @return prim Callable: IRAdapter primitive """ @@ -82,15 +82,15 @@ def v2r(rvd: TRVD, chunks: int) -> Tuple[TRVD, Callable]: rvd = list(rvd) rvd[1], rvd[0] = rvd[1] // chunks, rvd[0] * chunks return rvd, AllReducePrim - + @staticmethod def v2d(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: """ intra-RVD primitive V->D: reduce-scatter - + @param dim int: tensor dimension @param chunks int: the number of chunks to transfer - + @return rvd Tuple[int]: output RVD """ assert rvd[1] % chunks == 0, f"not dividable value chunks: {rvd[1]} // {chunks}" @@ -102,10 +102,10 @@ def v2d(rvd: TRVD, dim: int, chunks: int) -> Tuple[TRVD, Callable]: def r2d(rvd: TRVD, dim: int, chunks: int) -> Tuple: """ intra-RVD primitive V->D: schunk - + @param dim int: tensor axis @param chunks int: the number of chunks to transfer - + @return rvd Tuple[int]: output RVD @return prim Callable: IRAdapter primitive """ @@ -118,10 +118,10 @@ def r2d(rvd: TRVD, dim: int, chunks: int) -> Tuple: def r2v(rvd: TRVD, chunks: int) -> Tuple: """ intra-RVD primitive V->D: schunk - + @param dim int: tensor axis @param chunks int: the number of chunks to transfer - + @return rvd Tuple[int]: output RVD @return prim Callable: IRAdapter primitive """ @@ -174,7 +174,7 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD) -> List[Tuple[RVDLayout, Li @param src_layout RVDLayout: source ilayout @param dst_rvd Tuple[int]: destination RVD - + @return rets List[Tuple[GridLayout, List[IRAdapterPrim]], ...]: tuple of pairs of with each has a different device mapping. """ @@ -187,7 +187,7 @@ def transition(src_layout: RVDLayout, dst_rvd: TRVD) -> List[Tuple[RVDLayout, Li decd = [dim for dim, (d1, d2) in enumerate(zip(src_rvd, dst_rvd)) if d1 > d2][0] chunks = src_rvd[decd] // dst_rvd[decd] _, primitive = trans_fn(src_rvd, chunks=chunks) - + # get device spaces optional_dims = {0, 1} devices = tuple(t.device[0] for t in src_layout.mat.flatten()) @@ -281,7 +281,7 @@ def backup_path(ilayout: RVDLayout, olayout: RVDLayout, @param olayout RVDLayout: output tensor layout @param cost_fn Optional[Callable]: cost function of each primitive. Default (None) will use transmission volume as metrics - + @return all_primitives List[IRAdapterPrims]: all primitives for communication path """ assert ilayout.ftensor == olayout.ftensor, f"ilayout and olayout should have a same full tensor" @@ -336,14 +336,16 @@ def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, cost_fn: Optional[Callable] = None) -> Tuple[TRVD]: """ Get optimal RVD path from source RVD to destination RVD - + @param src_rvd Tuple[int]: source RVD @param dst_rvd Tuple[int]: destination RVD @return path Tuple[Tuple[int]]: The first one is src_rvd. The last one is dst_rvd. - Otherwise they are intermediate RVD status + Otherwise they are intermediate RVD status """ + # Please note the following int can be either python int or np.int* + src_rvd, dst_rvd = tuple(src_rvd), tuple(dst_rvd) if src_rvd == dst_rvd: return [src_rvd, dst_rvd] @@ -388,14 +390,18 @@ def get_optimal_path(ftensor, src_rvd: TRVD, dst_rvd: TRVD, unvisited.remove(visit) visited.add(visit) IntraPathFinder._cached_intra_paths[key][src_rvd] = paths - + # for idx, path in enumerate(paths): # print(f"{src} -> {nodes[idx]}: {' -> '.join([str(nodes[i]) for i in path])} | cost: {cost[idx]}") - + # get layout nodes = IntraPathFinder._cached_intra_nodes[key] path: List[int] = paths[nodes.index(dst_rvd)] rvds: List[Tuple[int]] = [nodes[idx] for idx in path] + rvds = tuple( + tuple(int(x) for x in rvd) # make sure all int (not np.int*) for rvds + for rvd in rvds + ) assert len(path) > 0, f"Un-reachable src RVD ({src_rvd}) -> dst RVD ({dst_rvd})" # print(f'get optimal path from {src_rvd} -> {dst_rvd}: {rvds}') return rvds @@ -422,7 +428,7 @@ def get_device_space(ftensor: IRFullTensor, rvd_paths: List[TRVD], placement: Tu @param ftensor IRFullTensor @param rvd_paths Tuple[TRVDS]: transition RVD paths from source to destination @param placement Tuple[int]: device placement of the first RVD in rvd_paths - + @return placements Set[Tuple[int]]: all possible device placement """ init, hops = rvd_paths[0], rvd_paths[1:] @@ -475,10 +481,10 @@ def get_rvd_space(ftensor: IRFullTensor, ndevs: int) -> List[Tuple[int, ...]]: @param ilayout GridLayout: input layout @param olayout GridLayout: output layout - @return layouts List[GridLayout]: + @return layouts List[GridLayout]: """ all_layouts: List[int] = [] - + def factors(ndevs: int, length: int): if length == 1: yield [ndevs] else: @@ -486,7 +492,7 @@ def factors(ndevs: int, length: int): if ndevs % i == 0: for res in factors(ndevs // i, length - 1): yield [i] + res - + for rvd in factors(ndevs, 2+len(ftensor.shape)): skip = False for dimlen, pnum in zip(ftensor.shape, rvd[2:]): @@ -537,7 +543,7 @@ def auto_place(graph: IRSegment, ftensor: IRFullTensor, """ Automatically find good device placement for consumers given the producer placement The backward will also be considered. - + @param graph IRSegment @param ftensor IRFullTensor @param producers List[IRCell]: producers that must be assigned to devices @@ -548,10 +554,10 @@ def auto_place(graph: IRSegment, ftensor: IRFullTensor, """ assert not ftensor.is_param(), f"Cannot automatically assign device given weight tensor" assert all(len(p.device) > 0 for p in producers), f"Expect all producers have been assigned to a device" - + devices = [p.device[0] for p in producers] assert len(set(devices)) == len(producers),f"Expect each producer is on a different device" - + assert len(producers) == len(consumers), \ f"Expect same number of producer and consumer, but got {len(producers)} producers and {len(consumers)} consumers" @@ -586,14 +592,14 @@ def auto_place(graph: IRSegment, ftensor: IRFullTensor, if ftensor.grad is not None: bw_src_rvd = RVDLayout.togrid(ftensor.grad, bptensors).vec bw_dst_rvd = RVDLayout.togrid(ftensor.grad, bctensors).vec - + # get placement advice devices = [t.device[0] for t in fw_src.mat.flatten()] placement, _ = IntraAutoPlacer.advice( - ftensor.shape, + ftensor.shape, fw_src_rvd, fw_dst_rvd, bw_src_rvd, bw_dst_rvd, devices, cost_fn) - + # assign to device ordered_placement = [None] * len(consumers) for devid, t in zip(placement, fw_dst.mat.flatten()): @@ -609,9 +615,9 @@ def advice(shape: TShape, src_placement: List[int], cost_fn: Optional[Callable] = None) -> Tuple[Tuple[int], float]: """ - Search for a good device placement for + Search for a good device placement for source and destination RVD partition - + @param shape Tuple[int]: full tensor shape @param fw_src_rvd Tuple[int]: forward producer RVD layout vector @param fw_dst_rvd Tuple[int]: forward consumer RVD layout vector @@ -619,7 +625,7 @@ def advice(shape: TShape, @param bw_dst_rvd Optional[Tuple[int]]: backward consumer RVD layout vector @param cost_fn Optional[Callable]: cost function of each primitive. Default (None) will use communication volume as metrics - + @return devices Tuple[int]: device sequence for RVD tensors @return cost float: Cost of communication plan """ @@ -642,7 +648,7 @@ def advice(shape: TShape, bw_consumer_devices = IntraPathFinder.get_device_space( ftensor, bw_rvd_hops, bw_producer_devs ) - # FIXME: this comparison on tuples some misses possible placement + # FIXME: this comparison on tuples some misses possible placement # that can be actually aligned by using layout.align (false possitive). if src_placement in bw_consumer_devices: devices.add(bw_producer_devs) diff --git a/tox.ini b/tox.ini index ae7e80bd..fd753382 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,7 @@ deps = -rrequirements.txt -rrequirements-dev.txt commands = coverage erase + rm -f {envdir}/lib/libstdc++.so.6 # force using system libstdc++ pytest --cov={toxinidir}/nnscaler -x tests coverage html rm -rf {envdir} From 5006d539bccccb2b944ad42e776dbc808580d1cf Mon Sep 17 00:00:00 2001 From: Youshan Miao Date: Mon, 23 Sep 2024 03:37:50 +0000 Subject: [PATCH 1728/1892] Merged PR 2256: update readme update readme Related work items: #2014 --- README.md | 278 +++++++++++++++++++-------- dev.md | 95 +++++++++ docs/source/images/nnScaler-c-1.png | Bin 0 -> 11005 bytes docs/source/images/nnScaler_flow.png | Bin 0 -> 33230 bytes 4 files changed, 288 insertions(+), 85 deletions(-) create mode 100644 dev.md create mode 100644 docs/source/images/nnScaler-c-1.png create mode 100644 docs/source/images/nnScaler_flow.png diff --git a/README.md b/README.md index 6f1f60df..d5c6cddb 100644 --- a/README.md +++ b/README.md @@ -1,123 +1,231 @@ -# MagicCube +drawing -AI System Compiler to map a semantic (single-device) model into distributed execution using policies specified by developers. +nnScaler: Compiling DNN models for Parallel Training over Multiple Devices +============== -## Prerequisite + +# What is nnScaler? + +--------- +nnScaler is a parallelization engine that compiles a Deep neural network (DNN) model that designed for single-GPU execution into a program that capable of running in parallel across multiple GPUs. + +drawing + + +### System Highlights: + +* Ease of Use: Only a few lines of code need to be changed to enable automated parallelization. +* Pythonic: The parallelization output is in PyTorch code, making it easy for users to understand and convenient for further development or customization. +* Extensibility: nnScaler exposes an API to support new operators for emerging models. +* Reliability: Verified through various end-to-end training sessions, nnScaler is a dependable system. +* Performance: By exploring a large parallelization space, nnScaler can significantly enhance parallel training performance. + +For **_DNN scientists_**, they can concentrate on model design with PyTorch on single GPU, while leaving parallelization complexities to nnScaler. It introduces innovative parallelism techniques that surpass existing methods in performance. Additionally, nnScaler supports the extension of DNN modules with new structures or execution patterns, enabling users to parallelize their custom DNN models. + +For **_DNN system experts_**, they can leverage nnScaler to explore new DNN parallelization mechanisms and policies for emerging models. By providing user-defined functions for new operators not recognized by nnScaler, it ensures seamless parallelization of novel DNN models. For example, to facilitate long sequence support in LLMs. + + +# Quick start + +--------- + +## Installation + +### Prerequisite Install the following packages before the installation of cube: -* Python >= 3.8 + Python >= 3.8, < 3.11 (3.10 is recommanded) -* PyTorch >= 2.0 + PyTorch >= 2.0, < 2.4 (2.2.0 is recommanded) -## Install +### (Option 1) Install nnScaler from source +Execute below commands in nnScaler directory: -```bash -pip install -e . -``` + pip install -r requirements.txt + pip install -e . -## Run Example +Besides, to avoid *cppimport* error, it also needs to include nnScaler directory in environment variable **PYTHONPATH**: -Run an MLP Model on 4 GPUs: + export NNSCALER_HOME=$(pwd) + export PYTHONPATH=${NNSCALER_HOME}:$PYTHONPATH -```sh -PYTHONPATH=:.$PYTHONPATH torchrun \ - --nproc_per_node=4 \ - --nnodes=1 \ - examples/mlp/train.py --policy PASCol -``` +[//]: # (Reference output: Successfully installed MarkupSafe-2.1.5 contourpy-1.3.0 cppimport-22.8.2 cycler-0.12.1 dill-0.3.8 filelock-3.15.4 fonttools-4.53.1 fsspec-2024.6.1 importlib-resources-6.4.4 jinja2-3.1.4 kiwisolver-1.4.5 mako-1.3.5 matplotlib-3.9.2 more-itertools-10.4.0 mpmath-1.3.0 networkx-3.3 numpy-2.1.0 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.6.68 nvidia-nvtx-cu12-12.1.105 packaging-24.1 pillow-10.4.0 psutil-6.0.0 pulp-2.9.0 pybind11-2.13.5 pyparsing-3.1.4 python-dateutil-2.9.0.post0 pyyaml-6.0.2 six-1.16.0 sympy-1.13.2 torch-2.4.0 tqdm-4.66.5 triton-3.0.0 typing-extensions-4.12.2) +### (Option 2) Install nnScaler from whl package -## Development Docstring +To get started, install the latest wheel by visiting [DevOps Artifacts](https://msrasrg.visualstudio.com/SuperScaler/_artifacts/feed/nightly/PyPI/nnscaler/overview/). You may follow DevOps guide to set up the repository, or alternatively download the **.whl** file from the “Files” section of the website, then install locally: -We follow [Google Style Python Docstring](https://google.github.io/styleguide/pyguide.html) for development. + python -m pip install nnscaler-*.whl -Following is an typical example: +## Example Llama-3 -```python -class SampleClass: - """Summary of class here. +### Prerequisite for Llama-3 + +Install packages required to run Llama-3. Besides, a certain version of CUDA library is needed during flash-attn installation. For example, [CUDA V11.8](https://developer.nvidia.com/cuda-11-8-0-download-archive) is needed if using PyTorch 2.20. - Longer class information... - Longer class information... + python -m pip install transformers==4.40.0 flash-attn==2.5.5 tensorboard - """ +### Model Access - def __init__(self, likes_spam: bool = False): - """Initializes the instance based on spam preference. +Obtain access of Llama-3 model from [HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), where you will receive an access token which should be set as an environment variable: - Args: - likes_spam: Defines if instance exhibits this preference. - """ - self.likes_spam = likes_spam - self.eggs = 0 + export HF_TOKEN= - def public_method(self, a, b): - """Performs operation blah. +### Code Changes for Parallelization - Long description here. +You can find all the example code at `examples/llama3_8B_128K`. As shown below, a user needs to: +* Wrap the Model: Include loss computation and other necessary components. +* Configure Components: Set up the model, optimizer, and dataloader. +* Initialize and Start: In the main function, create an nnScaler trainer with the above configurations and start the training process. - Args: - a (int): xxx - b (int/str): xxx +```python +# import the nnScaler build-in parallelization-capable trainer +from nnscaler.cli.trainer import Trainer + +# wrap model to include loss computing, etc. +class WrapperModel(torch.nn.Module): + def __init__(self, model_id): + super().__init__() + self.model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation='flash_attention_2') + + def forward(self, samples): + outputs = self.model.model( + input_ids=samples['net_input']['src_tokens'], + use_cache=False, + return_dict=False, + ) + loss = torch.sum(chunk_linear_cross_entropy(outputs[0], self.model.lm_head.weight, samples['target'], ...)) + return loss, samples['ntokens'], samples['nsentences'] + +def main(args): + # data config + dataloader_config = ... + + # model config + model_config = ModelConfig( + type=WrapperModel, + args={ + 'model_id': args.model_id, + }, + ) + # optimizer hyperparameters + optimizer_config = OptimizerConfig( + type=MixedPrecisionAdamW, + args={'lr': 2e-5, 'betas': (0.9, 0.95), 'weight_decay': 0.0, 'fused': True}, + #... + ) + #... + + # setup trainer with configs of dataloader/model/optimizer, etc. + trainer = Trainer(train_args=TrainerArgs( + #... + model=model_config, + optimizer=optimizer_config, + dataloader=dataloader_config, + #... + )) + trainer.run() - Returns: - t (bool): xxx - k (int): xxx - """ - # function implementation goes here ``` -## Run unit tests +### Run the example Llama-3 training -We use `tox` to run unit tests. You should install `tox` in your development environemnt -``` -pip install tox -``` -Currently we only use python3.10 to run tests. If you don't have python3.10 in your system, you can use conda. After conda is installed, you should install tox conda plugin by running -``` -pip install tox-conda -``` -After tox is ready, you can run all the unit test by running -``` -tox -``` -Please note tox will reuse the same virtual environment which is initialized by installing all packages listed in `requirements.txt` and `requirements-dev.txt`. If any of above files are modified, you should re-create virtual environment by running -``` -tox -r -``` +Then we can start the example, and all the parallelization tasks will be finished by nnScaler automatically. -To run a single unit test task during development, you can run +```shell +cd examples/llama3_8B_128K + +# prepare training data: +python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 + +# build the mini model +python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini + +#compile and run using data parallelism + zero1 +torchrun --nproc_per_node=2 train.py --plan_ngpus 1 --runtime_ngpus 2 --name llama3_debug --model_id ./llama3_mini --dataset_path ./bookcorpus_llama3_4K -``` -pytest tests/your_test_file.py ``` -### Unit tests in AzureDevops pipeline +## Example nanoGPT -We use AzureDevops to run unit tests before you can merge your PR to main branch. You can find the pipeline definition in `azure-pipelines.yml`. +We also provide an example to demonstrate how to parallelize a model through a [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)-compatible interface in nnScaler. -Please note that in AzureDevops pipeline agent, no gpu is available. So you must make sure your unit tests can run on cpu to pass the CI. Two options are available: -1. Use `@replace_all_device_with('cpu')` decorator to replace all devices with cpu. Please refer to other tests for example. -2. Mark your test case only work on gpu by using `@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices')` decorator. Please refer to existing tests for example. +* Find the [nanoGPT](https://github.com/karpathy/nanoGPT) example in nnScaler repo: -Before you push your code, please run tests at least on GPU machines to make sure all tests can pass. GPU test cases can't be run in AzureDevops pipeline. Of course, it would be better if you can run all tests on both GPU and CPU machines. -### Run unit tests in vscode + cd MagicCube/examples/nanogpt -VS Code has a great support to unit tests. You can run/debug every tests easily in VS Code. Please refer to this document to set up your environment https://code.visualstudio.com/docs/python/testing +* Install nanoGPT's dependencies: -Another trick is, if you want to step into pakcage source code, you can add the following config to your .vscode/launch.json: -``` -{ - "name": "Debug Unit Test", - "type": "python", - "request": "test", - "justMyCode": false, -}, -``` -### Write Unit Tests -1. If you need to use torchrun, please refer to `unit_test/launch_torchrun.py`, and you can find examples in `unit_tests/runtime/test_runtime_collectives.py`. Please note that `torchrun` is very slow, you should reduce its usage as possible. -2. If you want to mock up any functions/methods, please use pytest-mock. -3. **NOTE**: The name of test files and test functions must start with `test_` + pip install -r requirements.txt + +* Prepare dataset: + + + python nanoGPT/data/shakespeare_char/prepare.py + +* Test with Single GPU + +Now you can run ``train_nnscaler.py`` with `torchrun `_: + + torchrun --nproc_per_node=1 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +This will train a baby GPT model on a single GPU. +It will take several minutes and the best validation loss will be around 1.47. + +* Test with Multi-GPU + +By default, nnScaler parallelizes a model over GPUs with _data parallelism_. +If you have 4 GPUs on one node: + + torchrun --nproc_per_node=4 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +Or if you have multiple nodes, for example 2 nodes with 4 GPUs each: + + # on each node + torchrun --nnodes=2 --nproc_per_node=4 --rdzv-id=NNSCALER_NANOGPT --rdzv-backend=c10d --rdzv-endpoint= \ + train_nnscaler.py nanoGPT/config/train_shakespeare_char.py + +NOTE: The local batch size is fixed by default, so using more workers will result in a larger global batch size. + +💡 _For advanced usages, please refer to: **TODO:link to rst docs**_ + + +# Success Stories + +nnScaler has been adopted by multiple projects, including both product and research explorations: +* [(YOCO)You only cache once: Decoder-decoder architectures for language models](https://arxiv.org/abs/2405.05254) +* [LongRoPE: Extending LLM context window beyond 2 million tokens](https://arxiv.org/abs/2402.13753) +* Post training for the long context version of [Phi-3 series](https://arxiv.org/abs/2404.14219) + +# Reference + +--------- +Please cite nnScaler in your publications if it helps your research: + + @inproceedings{lin2024nnscaler, + title = {nnScaler: Constraint-Guided Parallelization Plan Generation for Deep Learning Training}, + author={Lin, Zhiqi and Miao, Youshan and Zhang, Quanlu and Yang, Fan and Zhu, Yi and Li, Cheng and Maleki, Saeed and Cao, Xu and Shang, Ning and Yang, Yilei and Xu, Weijiang and Yang, Mao and Zhang, Lintao and Zhou, Lidong}, + booktitle={18th USENIX Symposium on Operating Systems Design and Implementation (OSDI 24)}, + pages={347--363}, + year={2024} + } + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. + +When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + +## Trademarks + +This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third-party's policies. + +## Contact + +You may find our public repo from or microsoft internal repo . +For any questions or inquiries, please contact us at [nnscaler@service.microsoft.com](mailto:nnscaler@service.microsoft.com). \ No newline at end of file diff --git a/dev.md b/dev.md new file mode 100644 index 00000000..0960bde9 --- /dev/null +++ b/dev.md @@ -0,0 +1,95 @@ +# Development Guide + +## Code style + +We follow [Google Style Python Docstring](https://google.github.io/styleguide/pyguide.html) for development. + +Following is an typical example: + +```python +class SampleClass: + """Summary of class here. + + Longer class information... + Longer class information... + + """ + + def __init__(self, likes_spam: bool = False): + """Initializes the instance based on spam preference. + + Args: + likes_spam: Defines if instance exhibits this preference. + """ + self.likes_spam = likes_spam + self.eggs = 0 + + def public_method(self, a, b): + """Performs operation blah. + + Long description here. + + Args: + a (int): xxx + b (int/str): xxx + + Returns: + t (bool): xxx + k (int): xxx + """ + # function implementation goes here +``` + +## Run unit tests + +We use `tox` to run unit tests. You should install `tox` in your development environemnt +``` +pip install tox +``` +Currently we only use python3.10 to run tests. If you don't have python3.10 in your system, you can use conda. After conda is installed, you should install tox conda plugin by running +``` +pip install tox-conda +``` +After tox is ready, you can run all the unit test by running +``` +tox +``` +Please note tox will reuse the same virtual environment which is initialized by installing all packages listed in `requirements.txt` and `requirements-dev.txt`. If any of above files are modified, you should re-create virtual environment by running +``` +tox -r +``` + +To run a single unit test task during development, you can run + +``` +pytest tests/your_test_file.py +``` + +### Unit tests in AzureDevops pipeline + +We use AzureDevops to run unit tests before you can merge your PR to main branch. You can find the pipeline definition in `azure-pipelines.yml`. + +Please note that in AzureDevops pipeline agent, no gpu is available. So you must make sure your unit tests can run on cpu to pass the CI. Two options are available: +1. Use `@replace_all_device_with('cpu')` decorator to replace all devices with cpu. Please refer to other tests for example. +2. Mark your test case only work on gpu by using `@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices')` decorator. Please refer to existing tests for example. + +Before you push your code, please run tests at least on GPU machines to make sure all tests can pass. GPU test cases can't be run in AzureDevops pipeline. Of course, it would be better if you can run all tests on both GPU and CPU machines. + +### Run unit tests in vscode + +VS Code has a great support to unit tests. You can run/debug every tests easily in VS Code. Please refer to this document to set up your environment https://code.visualstudio.com/docs/python/testing + +Another trick is, if you want to step into pakcage source code, you can add the following config to your .vscode/launch.json: +``` +{ + "name": "Debug Unit Test", + "type": "python", + "request": "test", + "justMyCode": false, +}, +``` + +### Write Unit Tests +1. If you need to use torchrun, please refer to `unit_test/launch_torchrun.py`, and you can find examples in `unit_tests/runtime/test_runtime_collectives.py`. Please note that `torchrun` is very slow, you should reduce its usage as possible. +2. If you want to mock up any functions/methods, please use pytest-mock. +3. **NOTE**: The name of test files and test functions must start with `test_` \ No newline at end of file diff --git a/docs/source/images/nnScaler-c-1.png b/docs/source/images/nnScaler-c-1.png new file mode 100644 index 0000000000000000000000000000000000000000..8deb166bd0cddf670e9c17cdc2dfbdfa6eb51d8d GIT binary patch literal 11005 zcmeHtRb13h`0f%aEgd4Uq{LDRO4rie-5}k)fGE=4NJ}>hNXWv1g~UoX3jzXC(kv-( zzWzSv+@1ULzc?5B+1=@RXXbsLdESZB)>I-UctijKfrwR~M#Oqa8AAqvs;W9v>R$!DBKsBwr?2KV zDCy1QJh9JYi^B$e>Z4X72Cc-CVu4<;LP4MrJTlPhr*gQUmBjy_^nbO)hvgT3pTgS% zMvJ&AlCyM`RBI$JAM#X8ZNQ5?`CsUNveQXtdrZEv3=-8*g=FfY+56WNx%}ncf$qT2 z*Woy9hN{SR32M(*H=*>tiy#*knSckj^vnyJ4ND*Fk{`>&TK@vtKC{QGtJ#=JVR(;6 z#tbRVnU<-QSYD9K_q`DCkHo7SaWvj>+IjGh87kl(!U$j9bbO;@9YD>Y;sQoA9u1vf ziczi0;!-N_SmbasSQ79tZuNm!R_92*k-MwQY+G@T7}C%7%HYs#W95z#8rG5Uo6wNK zDgr&(SnE@u_hkD=2R2G@jOrKr;VMEqrC94$A6JDO3R26R;YlnS(hW> zC?aN7?lXvx+LP^5b&hot9vM-x`h1e)4xiJ5 zdo9Hb9aFQIqAPXXFniaV)0GZXT;6#u;{?=U|CCs@K=o_9jR47f)D@}XY<=&YiXtJ^ zDLFX8a?5s4!LK1S=G{)YyrVECU30b-(Nv8Ph(T@_t9nJx42tTjaH1@GYUHiG_Lr9p0uwF!%6$Oxk59Dk$8N`hm zTvuACywlgfKkMv*z2j0u;<|&Oq~`AxBP$>a75mMD$VvHYwq3Q_9ro4&C4Mp%yvjXl zrcQ(s)0yVOJirw zs*z@JQk^q^qjB-a0(UF!&0prM z2|+e>5=cK&SQLaxn6?^;2=PipA5e9|hr;FvN2mQiHYH@Av}(%&R(*(|y^sG)a<)QX zZ?dHRj#ASrisopxu@>lw<@`X?`ekH&&LS$VOJ%8SxG%jOqVFX7&@*xX>ffyFj+21Zun=9{r1ALq(Ytnp zljs4`iSnyrY+~x7#TL)7$^Nv3s?(zXROwQcz(Ls=x5Sw5(&>^GMPXZk5jXKr@al)E z-Dln=b{a(VU?G+kEPeB*P;cop>OJ>q&b%3$P)Efm`$?aloBOT%ANp!9$XK9?a@gxF z6+BBr)O)p7yF&4D!+j`d6d91YR0hD$jrcA?M-3aH80|Bg**E zyd-dgMzcXk74`d#P`aDWJVAq?$3VHc1<3@Q;#L>4*7FP1u6Pu0Z5LK!x@YOJCPW5x zDRkg0FHI6T@M)0g^6p+Jl}LkXv{!{ig$`4lnB@MfLtXoP1yN3Px6{Wv%*oz|6*z7o zEA+nROWQE~LvA;-)(*k$X090qe~tT7qH~7s#crWxJ=)Yj69I{N1jaEFyVfa*Ha?2_{woQufb>q6otrpFm8cCC%yXCR(ya za4AHJF_48-CU!-$r#A62C5-*ULLMTM=Ui~&^M#gQgDqW8sN4?SNm|dVXEL!w#Rp!m zd#?Vrw1ot4t^L$ej+rxgUpLyLeDEn9$A4Xye!QPeRJ){c;97OLrh#XFD%8Joq$@qVx-!7l6PLcmJAwQwve1sI0Q z6}Mc#JeInJBY#31VxB;D5ic|1+QXH*FOFdv3d=UxjQpMuzpq)p;IOL(Z+F#7r@uX( zz9Pp}BDQp#Ig13*;qOBgp#*C*5KAnyo5XaEKgA4W;arR9dGg;gS%u02bezFifBNk{ zI+!2b8>8K|;gRNWGP=O%M$0H?A8DTw%i1YFP*2dX64w5yVR7E$E6>LUv(~zBoOP8t z@&ri;RVWz0fr5?C%zu;9SvY^tx&z zi4`Kxc4G0xKEv!+v^5eIe)kO0v^z@{e*Nj$QVlDVLv!CWH2#wavVFm6CCAwh(;aTQ z{G!Ds*1kve-9>xBfoYkP?R8ZOACcc{#_;Qv8Ik*IGO^6xMgq9aLdLykt#+yRd@ycp z*+WhnS#5L_&hDAkw+CtFWeYVRz9KC?5c+%X|!WWWhaFnK(5Dv-_Y-4nt zVy+ToOyxvqG}|oV-klj74zmEd+^hs5y-1$q`Bm%8#dHR5(P^M7wKN% zmn0Eyjt~5G#gRQI`)DylDCz~9MORRbII;2l0rQvon39)oi&&wRyPX9qd_&Y@H&>T4 z=jGY$rT__gYMJ|cA*#FHiN=R-Rj2E>yp`lZ@mY&P$b#xi&cL}G2`<6jl zHA(12us!5%93Vol`s)~}&$*4+{w05BM?0>{&$`xBG>Q{{y6+V=I1qoPxWcXK+TNkC zY(&BW0IUCI4a6+Y@CsOMd#8%ata%SiVi-j^!o}JPWsqkUa)pp8ZR_d_uSj4_4A5XdgOzi>Hyr?0q+ z*JgV<;(@%8mtyQEHY-xTx}0ZRiTQsG06r#&v>$n7n+)D3y8>Y}_|j)RoV zc386G$RLqQsYVHi#IDxkH`9&?b>^&=PTcdG2pt7Bj5{iLL(rk4COP_p`o!8;Q`C^c zR$q1z?MNQ&K2+P$ELy_Q;hv4f@&)Fk0hnO>Mz`@`M0@kl%O55k1sz1GQR-r}CnE`S zq;k0J6Uey@T*t)(3q~?LIm!=AigRt{59M~)zZW@9{OTD^fC%bF;>lG*@YEcw ziji|x4iJ6r9rhEMJ-MqN#woho45sD#Zvd`@$Hjw3Pcyc$!``c=G+)X1GTvZH%}|%$ zyO8Vtdn`Fz_iA$)PE5L8%9jh@njCY(DK)-w%U@!8M6d4wnS-VyInwideL=@dOv8=?s{y~Yup#`wIu6DB#H6&X;*ID61+%Y0>#Lqct zNFjMBCq&5D(j4B>(EbPbdfy81qOHY^_-kBBgihQ=`9r)zGsS^p z-QQwj>7yiqP54jV6fSvK6_Ze^;o?U%e*IIgM<*6m{W;WrmoPC7HCO1%+sUoD=XiBYx znp=wV@@sAO3590v?J!?PlIS<6QE z<4wW#0_mK?u>_d|-X11ag;@2)J@bZ&hnc%TX|E?3ay`NBd4MtGvz5{ER6evEW{DrUT}7o=5SgzUJ zR0a9XpQ2i0aP(0N0+wA!u^WAC9#cG#YwG~n3^C5R;tTZX2b##m>Zr^(dPc@FnJFR! zAX7!FAe>FlAWY=V&VOR~*dC7Rm?N)Q?*D7% zN#Uk5u0(h;qhA{P=XX)v#hbhDSnrHOeU{{?kL%rL@#BlL0BFMoZPNUzE;fXw(ncEYs8y_Z7hq3kf$|mN+PZUJ39cMR2uh3hyV~8 zp4_+S0F3bMvOE>3X$)%hY1$?&s-ezzLZXu3f$B1-$6wu3G%?V?Dq~^0O7k z4XSQCjIhAL`u)#iRE(r7Phen7dcu87dZSs0v6ZCK?2d`wjTh!`?K*nGU{lQgR3L{- zn=Nt(fqMZK*7PRw=sCwv)vZ)<3Op zN;@m>aOIsK&4VE?zL#%8nB_m8X-#<9m)(}c;V{%ll9DndX`MhSOA({XALl^e!ah2~ znABKQu|JNPJ_!h@%GRLc8)_>Ev{1e3C8{T5EnUY?p9TKH83W-u{1ErDB#5YYe5a1-kC1o>>r|EBOTcT^N@#C zWk1d=)~=XoJ`B?8QF~|H@iNE6TYp7Xz#ECn^^k7o10&w50D!KlP3rskQfRc5J zFY;e^VOb)%^``IModY+@1c3TnQGTBkmITBxH9aKsoEuW(g2E9@i2UOuFDq#Y!xx>~h7VUWKJ?H^Jy+NU$=poWDV`M{Q}% zc-Cj9$HcT1?x2yBVNOU@92E(0M%@`X2Z&xojAc#)Cjjzknp$n{CFzf=32SoPYTNpe zQ<5)j?<{&P~9PtLbqNCc%*bu)KxMm zhRK-^Snju#1E9d6?eCXl?IG&6fiT$5B4VWqUrxi-xq3X*-aqU9jn*8vnF;_uNI4UoQgfwCBz0pHg<;KP9VRKt3i2R@EL}n@yz`VdgHH;tYmz zsJFK;z*4KtQI~`$m`e2w<5lfnPm2&|?emYB2qpA&Sves-dsESq)`Gx7^1q$6Q3RE5 zR_oK!R^#n1^9Gr&13QnW_k{T3;l?h6yCgB~a>R~VfG_SYoP}fO?02fggsJ%@ZG3xK z(YZ**>oC*;N38N@`0)o=;PQvKnY8sBUJwe5>mL9!6VPu)&Gak74Yq%`KP6mY=NldxPvyAuD!uKom@% zK5v?R<5N$g*nWE5j78*_6q%?M-u%LZVJPUV3>=FN7L#{UPv~#2h#IPuMV7lzJPE`8 ztGC10#wm+z3B5bDK`L3DMTN6t(iIc{wCKB+e>D=h3Qq$=6=iOpNIwr^N3-h*HoO!i z{sHWZZ`qjiUmV)bb>uZ65_AN-Lv7rb6Ki3s%PNFfEwI5MfcR#)M~a&A;8(7LoXG)jb-;0Lmyob%l1UYpB!ayb-8Xgp#ep>?jQ1Z?*fqWunlm5COlL`fDjnk zqcFrN4X^q>Wc#fV9p%@cgSu2NF;KDEW#pR>7iLJL0(g$3*@&Gz9@*9nYobwchi_X? zAy4uBSMoN%4RWj$xs6B%o1hj*SUW?O=uKXKA^X>=@BU=nXl>!%X<~M>5h4dxqM%ja$7#+VH^;?fgB^A$;=fHMu;^(f;O6JPR!qdRO^8>M z)qsupJ9IrFmUhGpEdIV?vMFAA>ojb7Bmj`1>YQjIUa6ovl$XA~C}5METW~Q5=SQNe zV-j(}LgYd{5+B}uJ@A=L-E#}n(c$N6CzS17RD-yl?Y->`OYkTzbG2oJnC(cJ!7hAL zOb0&x6~W*La7^T$E5n=_P)l*?CIVOMUf?^lzoWQh17><=e?DY#Bwzw|1NF4vvHL0S zbSHQZvEN+1{d0G+R%*%S7*HhCItK=%LQFI?!q0x$ZaO96G^}u{-p*N^jSrMJ%+#I5 zrp~aaeKl~WX;>*-8ZofHn!lN}l<4JXSaCliJ90O;Q^Ag0t%>`kX5du+@D+`d5Y++pbOQ0grj)`tIb#JD|3S!~X6r*$J{H=%ycYmt;B@B=r6y-oMY1(k z=d_N0die4H;5C9W_Cxy>N~r8s0k~N)DYJGYo{$jlh8}G=B;O(nfIn25wvC&FM3{Q` z?5-qt<)5q8S}np8gmAZWf`dAZFqs56kWwn@h@Ni8$*-k^!?w?uqK4e*?(|-i_d9Kh z&lH^|mW^w_o$D|SOHaQKKOZRjnbpZ%7l=hChg-3}-n{9}p8ucfBc3B*b6@KL_K)ta zw8Qtp_%1pWvJlC;jLSMuZ~hJ=XW8^Q zw~^rc7r1HXuKDu2nDKC{w zBJm&#SpR}~*lh)H*OFray?PGlj%n*;EpCP!K%vB3M0*3ppToE302nuI#Yq9~`UN08 z`?n-i{~KeM@b9<+j!iwrF+7ijDxHn1IpX`lcsiC4QOo zMH64Ga*nT4oX1K20)ftCV)WfjMSd$ zs~ht27!(LZQ1nJagI&klpQwlA_r!$n z*@xv~l3t7hZR7gM)YS><9Au~L!|VshL|51B@?triaEIGV?uy?&wrZ+9`}Wi4Ptmhn zoXMy!v%Kfa#e)akbwTDg1tKVJBKynwPayA?KNFcYm*!|y{Zhk)z$ z@+CYE>~cb^c8khTCp@o;en<5uX{7XZ*4$ZTNQ1)egbIG}W!&8!x@r2&&x$eO8d&-G zoBzFJz4YuCUe);<4`K54xR@75Z$AE|_A*8|*u9O@O*C8MCubG&Au(D1EF;>oAX8Er ze|Dn3@M`&vvezneR3o4=@9z0N$>oEd3V{mGOwd%LjtH0%0tjMVQfQsFck#1r13>mn;J zPP_R^t`BG*OwgahT1TO&v_Q*hldXebhl}}V%C+A=`iIG0IUsrda^8r|ihW9F>wI$Q z?MUrz*w_Ml-M8C{k*qV>=n040zgZU-4QQB(W|L7*3CY+cNzJ7f!|qPC5zDEWYhHZ{ zlb?O_Cikkrc)8alT5iU{Zqi`whrc|O{j13(I3}URCei#2^5UcpVOPjwZj{t_5EVMg zovdSyF^)oLPkR~OUO-%4ZCVcQ!Wqc@=tk1Pypz|Vx<^kPzUF2Lr>CDEkK6s0*7M=b z_J_ZL4~HLdoc^t2ff{DtvzQz;y7kVERL-DSs9$$<#redd?4otpZsjBQ?B+K3?ebCm z4z4K;BKlzYX+ik`@2f_g*nzco*yZi_4UTUuMM7UCF=oJTU+GZLHFj=L>P-sBPk+}m zXmm47MYaCYioVkqPvn4aD@UgdZLQ`o!Jd5oU}u!Ymd?*+pvEL)r&aR{lre{pu(rt*5u>u*@)f>zVuJ)in`|5P}!t7hfw z^y#nlf;Q%_O0h(GlL5?Hnj|Dfhs*htf$Zz?SYunN!Yb(%xTC>+9qwv^uc#DsGVpOK z4}^tZFXB#mlLT|x!uFD+^w?=QHRRbYp!j|%q;y5}F}Y%3O;x_|!=GO%eT+&E9=(F& zQ-m2EDZC7te=^c|wQ$nlC$H(t7d?=)#vpri6C1IvFDN(sEi7QS;W`krH|xGsyh8(eD4U`V>jhp+j(T;p!9{jC$75Udu z<;dcJib>Mi;Fz}AJd?~LDF zs4W1io|J|59)7xUbmDQcn;f+G6lZkAG_UkUaN@D@qkEyiA*^O#i6{Ap!I3&yrbX41WmvTBx`1wbWh zpYr=U-H!sD9V~wjs&uY?L9}b_^ffJcd|Y4(0wffpE!3uU%fB+`hqMh%Ic`gWxXoi3S(_WWV~3 z%sXZvNKzVjG-(m6mA9C?-rfP>Tj$vuWs^6xBE0K{ppm8z;v{B&o#*era>3En{Ke$X zgwEK$4E1=#Mh?p%zJ6%~1(?mSuV~tiK<5)tm8x}uAg(B%$$&P0&TqkkJ^ObVXU7$s ziWc&zSzAoRJS!T3(lHL##maa)8CRQmoAimT)%*dg2IOqs!>-{R2Pi~E*tQ^ZJ%y~5 zyP81x>iZBIe6-}p(cSp@7I>F)T50;(@zuyy*P@C2Q$%w4p#|($=5mOK6M4}Q30i)es63#`L z1zt>1Wp7A&6=-;7FnL+8;Yb={v2n@h4`Vb5c@S^+t>SQJdS*JVXS43F41Rxa-Dm3< zhC2^b|KQBui;*8fp4o;C7zX@s-jmWr+dw`SEb zf!sc)Dc)?JD$-n^Gg>;b7!09%{^Iw@=))Uoc9kz?k(L>@rx|qGKF6+EHZF~`Yt*ln zEv(=xYRQ*~%`RzMoc~@zy8Pom|07t>|M)?aWnP8r{aEn(o-`8(yk`ZeC}=_(<*XzA E7kN*!=Kufz literal 0 HcmV?d00001 diff --git a/docs/source/images/nnScaler_flow.png b/docs/source/images/nnScaler_flow.png new file mode 100644 index 0000000000000000000000000000000000000000..1e22c25df39777e39bec701695801ab16871d701 GIT binary patch literal 33230 zcmd3OWm8;T)FtkO27)e{g)b7w2)GDfU|`5nl46QrV36!!VBmsqFu-48X+yGrU*JxP5+Y!g z4|$ISN}Fj)n|J{|ffRD8yKGO#1j+;Mt|XFqAvlsCkO&+JIoOc- zIvQ{Ryprq|78d>Q@9^r>=Ku2!a71Y2|9P)|UhaQaK~VG`+5de$DAq4HAmsf{K_Ikh zg~|V}L?AE`aR2vBMgE`O)Tfuo_nE}~ayQO$u2f1{IdSQTyESLNLQBTlx_n9cUz}b* zoK4LRJ0TyLbhTNH2V?ciblMudo-YQ_2ncM>Y*3pbslo!eZ`b)7SRo-EZ#tWn;%{Wo5jKpZiHncQ`eM1s1M^4aB$dm||$6zzkWZcE^mN`e=QPL@a-ZAkd`t-|mr(V;IcP3Fi$A=oebma%!i8agMu4SQd$DR`JiDOJ7JbLy+-)+;)9CjBiQdf8@ja zWL=#_#mA?;?(<|*qV)4icx^uF`-nl&eq`@0OJvg1T{|KN8!uKVOCB3$b=;M#oAU-D zgApmv?eK7XxHO@#RISh~sx1-5M#5$mfT?vr+JuP-JGl{!O>>i-4D3M_y*;ZbBXMB;9uWRRLgs!m?aL>6xo&pIt8Ummax zSUqp8DL#We+AP%;wQiKE&{3GPR|_EHvP)73ak-qNF4128MiDILVbb$?C{iv~=&_E1 z#M#vKJmsK>RjtysX$+^=ZdM?;0vs`_R>(_jE}5>7F9v_O+ESN(^4ob>$^IFUsXaIPe5&C4(h-h5&5Yq^ zrw#8rnlc3KDf?rj-b=9F9(-;`v7%Z=joLy>Cn!w1Ya{;szgxeLQmL}tL>ef91z-$K zhTMgtX}tG&`s3#T7f>iP3MVG+IK#Z-)@)v+f@=nc+&oWG z2@65TQk}zDPGz5v!DaKK#G6EyE&T1#%jMQbhzA5R!i#j`fvbmsH6=}HnG{7Yzu@CdlusHw6Am@=o!4Raxp+QEoes*HLdw&R7W6kHC=RL{2W zUw&t5D|&cIyIoI7*rbgrh{9l!K^+xTA|cq2cA8EXO1V_(&)3g1`z#VYk=6*sk%}uL z<4M?O`@B4sT#qi*nkYSOHZ$t9Qaj8Rec>f>Rgqq6jS$fN$$NW!pU7%Dimf66)5qyNDNs3fH6d6fhKS&- z+2hvr_UdUe7hnSaM5e=r5N!Acq}&Y8Wys^vT9c7P3|Mhq2mEra=2s98b)(e#RK>3k zf)HVS-`4<;P~nBl^Zs1`yZ-;%8uxFu$p7zZ+e<#Bh$09-3* zCDUsd&zEcLXJ4qt^t=E`yZEI{rHt}mB8T%14X9C3>^94@^F)o0H^(EL-cLa0cHryU znQwKr3X6(TIT%f#iaC8o#`~l^vGaT|o~7iS#qIPN<^z6le@VG|$Z zf{RhHa!I_Hm{{@4-%J6l#xGes?y7fyk?ojw0dl2CAUUmwmCeJ%K0YuJY{r+0U@mXxdBWSa$KDB;eB zeY}R8{r)IC#nx{YX43f6ytin-l15g%7DY~QU8t#;3@b&g?0edL~-_C z?9IPY@zJ)GfE|vRL8r!89J>f)`aawQr>nHZ#hc@WvfICZ|1JU*CavOjK)FJ5?hIR> zpOKXGmvxlU!{uJ%!FV>mRbzkumyAzx4Jd?si`HRL9nQz9))ANthRDL`ZSGg5bsT5U zElvlGS%R#r(_WMxGT$B!a^&r8Gk=4`j-1z9oCd5L&QT@~U#}mg&bS(ZO*w!ILDjSP zZ{HIANP|9#K(D?ojd z1_GP&My98DdAw2b{zV}N$;qy6nC-R=Unfqk^h0*K)JSD2@ZLafn|%p{MJd&yi{ZabbyTVg;7a}d9fi(8vDR3sZjO+FJ@e)ep3Aq>>tR#^ z-rsXEoA!U&08F8;0-V<)LbRxAuBW2VVMuld30}BB4 z0l8P)hRt-8`HuMCHY)*#MwW8YHX(O{F6V*k{Z2QZ(ZkrwN9HG1kfa0m)OevRMSex|^AH>#^=)mC? zKKGGh{(^3jM2&3R@OhF8!z#br5uohJ@wyB{<27G+0-gIEBTiuh__r}s8p;p*qdL!K z>mK7D*l)-44F16b8aW^mdqFy_#&R^~P$z%!z5TI_;pJ`Yg?N1XXnDaG_9^WD5dXGL zfaau7Z1*LEKg0-hc0_&rVu7gV-S_hCjb# z#X~?NP+LW=TP{>qSZ*~RKmK4#=-ET&Ftk4d3QX}Xw9!B`#~tZ^(I(SJqHYaeG3l(^P>=+A9g^u|4j%v5ge|%6fw`gpfKoE zl~pct-1l*unC@`cOe5*_I<+x!GkIMrkDPY;qb{GXCPsi%hThw|*%5pLbs)AS-O(c* z2$5Xo{jrB&S4#oiaf3b;%EE11EBa%QwAgpE(tLnOK zK~w14Z|D#Q=&$wEfNc+2ZuwJ_o zA7FBDdKG*2PEz!IYE&-%V@0p>lj`REF+_GwY`+VaRt#af9Rs=W!DOObM9E%J;vuY; z>!sfDHHN4WLy-I1%iRc&Q%ryi^pWjf1mK?))Xii4k&@BXEr?A)tLHMrnE-UzhKCry z+7A1-fk5%>!UlQ5d9tB6y&d`kSqN|*It@#Ir^9YX_0a$7K1qpux0bm+g)owwrn;1VmhhZsD4y}q ze~TxmkP+M_&%X_)7pXR;D0l=b2x*#Abdh-n=Rc@WA}@r98;TQ)8wREgE)1!#@n0yQ z_*SU;Tn(UNwN_}i)V9hYUblSxmheB%&ZBAn zK01VuGVu;}>HX8NVZc>)pUrBs!x0NQ#2Kz&UrH7u~-`ZW!JDfgWy z&QVN=_gOOB_&~@X$TUIp}?_LETv`y?(?QUm^?O?vQT^b(^*@s=@@REcp;Fp_W3s)PM0~jy&vlz zsLwJ1tl#AQd72Ky!e%y}NaS`d~S3lKckY9BH ztU;ae$%a)64*mlICO!4<5aa}4ul;<2MihaQvI>x%EKIJp=R<*xdYdYgLN}q2l3~dc zM-G|=_}%ncAuOhnb+j*5-*93IERhXu(Xd+5D2Ya){>5}WnX3-~|DOZ@3|cu}{uS8K zEyrK$l8HXupy{g6%ig>~;;gUgde!{(>(?kIu)0c+q2*X}b1!^pRZ0ujV*s`;0|4C7 z`B`XNy9TQT)g8awp=|!P2U+{T=2{>Jk=rwsy@>HzOts|fGANG%gH!D{_;9#V(tled z_Tn^W9pxk%=xKGf+A{X<_V9I?ak-Ep5Tp?c!i6-Xpi*ov}~~ z+ZzmKx&)!Sn1)2>%dJ)&hs_^=d04NtssM|J@@~~BW(ha5y91?c_VFCdaR0%FXbVM* zpOPF~qscsZ6!)s}q9q2MHOKolObASkvg0Rq~x zefh6l;DbJxsHpD>5lapQxy5EBHUC6I+6bMORIFA3LFnY#)vOF>D;vp7;K1229S{U*)?d za8a227zfTgHw$Z z=1%bWuc(Ub%0k2|uqo?RLI6^fFqA+w^7eWsc!lcF9|VtXt*2_wF*3wHN=^Wdy3x6y zKNQPyMuo9~vrwg{b_XvNUe)<1?T!He;+V>C@?y8kA*#JOHN7MN+N3dblI222VKS?) zfvHb{0pL@ysZ)Qg)%i#@MKt!z$ln^(pMu+Z_Q)kmExIlQn{)-BWE?9{iotUm;)wr<@Id-<(=fxq+6i z6a06X(v+Kn=Us1=wj+r&1-2yiu|4JhX*psEYtjs$Z6NyT5K^c`=ymOo;p}*|KL$Wk zsV|vaakk-%3D=0V*;{pAgm=J!g09r`j>KATFhNqH-#@6_y`8_k%CE(D!!8ZPx}f23 z$7cPG15j((3IIsUw2Wt;fLK*J;Po^iAmOuWQ9sg0AB%~LPE$cgB(Vy-IOO;iI?CCZ ztMNXo{)!q-Iiu)-@G)XuDu4x(3ZGGn9%7Qf8YW;9Z{L5ckX&*xY**qg_BeyR@K}n}kmXh)4-CM4@mXjH%VIzmNzfmAdKU z?Wt8`ybuf{$veQ)B*0_V_uC!MTJfTc9vqasF-l<46QBXGoTcb3ju&RE7C~Ah<(JGq zv$YA?)s$O-k)H&v?ixaS`^bJ4vs)+W2BGi_lU|953(&nHC4#7^Ck`7js}?{U_F zhrC+Y2KzZSQ@~r3$MuZfX1SgvCk*yA+l}zdwAcaato4}DBnOvwD zMX}h%hB63_)K*dD-6ucLPZ95Eco>Z&Gh9xJq6`5=rBKKp;<6vlZV-T9g~mhi^?c+# zfCiC6FM~;M0DPaCMfTS!7z(xBsDVzutps2xQsL>Mbn%w#rPD{!N2` zg1HA|cj1d=2-N1}O7$8McWZXbSxE|zX00*&u|g5JFpW7F#9c$Nf6))IFf>96gl?Am zUy}p+5Rp>lWFUU9fbaK$lD{MG&>Wa06X02PXjOiEfwKoB!V)csqxp)mZ94@Lq8tu& zo7jpFkgLCnso2DR?Ny3NKiL5~n)+wcNVh0-%{WscD<&u1_Z?31yCS`~_p*zDoK?Cn<<}@Q2S^9W*2y!JG^Nwa&ww1OHB|c6;9;kj^!pr3I26I)TuiDO?TkY@inrejZ{%DunW1y z>1LYSnMhXWE~UcWNE6K9M)m1C#e62xK7+clKzb96+`w*3yLweIi$Sy#NkA9Zs&Zfw zz=mZSEazrjk8e*`=C;2W`o7)U`T#Pt2_P^Z@j}*+iHCobO2S5ql=qe2^6GlNlqvoA zbcJbe=ku_aVsmDJ?Fr}yNfAV=+03w8u69R|BSe_eB*yw#_5HJrRtsNDC$di)q@Ht$ zcn1YJnPjpQ8H!Yj&f-}_S6vYUo={PosBZP#--cQoU6b8L{`dtm?A~T!ySHn@S~&Vj ztguAfeNl1X&|nQNF=92j_FlFppH1g+E}amQYei*;QwZuRX#mqV>~28Hml)V=Zj}Y_ zI~O8RvR=s z|5l^J3zgU<@Tx`QWv3l(59V6Gr{_xv>uQ_Gg6N97kQ~szAtfJR3Y^Itan@=z*|J(p z7n1AI(X**5AfTY65b}FaH`UL|YDA_pRhRx2nQ3=-+M`UIKS#kF--M4}7Z;kD}qo`zi{!9PD!^FC7av`>t#PVb?TtcMzLEY z$opo$Y!M}t^_qylea!kIQ?$#5bpV^FN_?7-eqMiDwZ+k-iE|4LJ#H8tN@7; zjzsKCBo@eoKK|_yh@vW8PL&Gv8zFFKDu%3n@}8)Yo?U++n!u6+DF@GZuP5BUZekqJ z758T31yw9yE_Sw%I^=^9Fs$Jr@biNE-@1DrD_l4zj&jP5E^;mCitfEJjC&GMFOR1m z9B&N0UO>&h&OV;yyL15d4*QMgJ&HXg@1yf+0`o*K&(!;n*D`uwfyr=0bEp0H-u*eM zsSJK~F5ldaP03O*h3s4@@{TS5WvzAUB)@Q-py*RrKAwyycBxRo520B$_HTs+!z?Xz z=3K_r(rXw6PQH=J*C%0SJyHG9)0w$#AV1Pvg_z4Np(7_}FO&6X%E)|WW6Z8Fekf(N zpEm4jL@1&{MD7%gLwYB_N4%UR6O2AF4szALn*CXZ;0Tq!?0IF5#fuB+*{PZS?k32W z9>Z|5JrnLtOjyTn0%JnU2+2YR$P(8X8-(hX`eDs(7x^BMTx{yB{4b0`$Z4lkKJi_Z z_(^|KeRrj=pNvnVG~RZh;C8?#HQK?(Re4Lfm)7^0o$m)f@x7_2&8e z30E2W*x!QOFqwqC)o==M^|iS&qmF>m9X#><`ZFc%T1 zZrZ;-DFadP&f5LFQFQ{J-4`XaQTyG9Wn)V<5b<68sBt+rTc8YKg#2W29&1n50(Plu zh9RzeQYbIiuzqKT^b~hk%A7EEo4^c7UDD^PJ%@bRdAMs6$R>6f zyhzXHYn7*voiv2eZVOc|=+zTKLKdJpQ#Y@k4D$XuLsxNJsgHzx4p6xG7w$Vu`yLtV z{D{OwkI3fn>dAY$RM(ygP_W9)j|#^_cy&ucpL>Ih3;@Hf@O8qJeHd49Hu?0{B7!m6 z{?S`C3^({BO_kIBX{ko?98J{{tnSKjJEJ2GZ$ z76$VjrZAPcK8rtE$Q)D|p*pk(y{zNm%+HwGPJMAAJ-8BKAwIK{kAO>3~P9ZIYKO*JJ_u1VQ?o{KdS6j<` zbfRliPJh?vI)U(HM!WS0SP1BC%kzDX{(@dA&FcDQGCyL{YD8P6On!^T{;!S0zlQ^$ ztRj+hx+-IJj#H?Xn#j05Sp2*>7Q;PXmRlJzI@L`SNnRvw?j6fu?p>U?iFs}h6Sc-< z-rmrMkcgL1bT$XXVWTI{OEXS3br}V7+%S@L&vSm2I+P~h`-pDn@qGzP4m7@ zrZ;V@c6$?_f>vLJCaEgFHeH$VvN&;6Nw^fVf{G(t(6nhS`ZMFme78+!^7zDTwdQdj>} zVWQM#h->7UAV0pWX&h!7Cw+em5pp;oJ`5-%?Sa@RDB|4n#Tuhwo;dtmP-8%?=<`m{LPV^7zh$0qJgzU4p$R~0hABR* zBP_G}uK*>#>@V0_vTpScCJ9AYd8o1I1S@3~-_W!aOOyn)qsTilK*~X+epS2iXn-G3zx_4g?`ls zu^Ba_3KHUbrGFWQlSh6rghgZZ6u04!HaL7h%S6biV9X;LD1l#;*-2=ADaA0Ik zYVY@;pwx=Wo>zct*4-S8av&dwLBPH|e<$*xruvd;FN59f2{BLVP5-H^2ywF^k8-L= zCdK;cm(p)dY90!qRf@X~9`C(W%|A3rO#vT=;1_mhAI1ayN)_%6j8<{>X6KChUF)7m zM+v-izEcuIBMXGuzx=9zQB4Iz-x47~8#(_!lV@fizWeSYKYRSL;16PmpX#@;trxKj zp}_;^ok*^50^S@Lt>VG>Q2Z0nF2tr`K5B_xg1xj%1^kf3c^bm@`;+#`CY`k?VG}Ij zZQL%V?&(4pPb?tG3Twye7gzoxHrITVz=@I2W;`_N?0%6e`STgRX&FBc4AuU=bDj8Z z71`66r-KkR*V9lSv3<&@M(v_SCRbU)lc2e84wvP6M1gjaCcMvNL~@C~Yc!&nb_69{ zHP*({XftEIPr@ce2{x{tUobcQuEBKF2CH?Ebdq8=LUjJ;^9YtsK%4w>CoWh#2Xhge zpA*)SB5dt-yJY5gbEM*aHO84_wak0oP6mrYsJw?p%;WQT2+%v>HXi|rX#C*_(h|d@ z)RCS9O1(}m%W^J416b<0YQUz$`!&#EyNM_a09>n+FGJp}>`u)W-iId{Yp%rZepTax z*tm#10x;W9MCyLI-d5=*>R&+{x#Hiyj3J>y-jRlWz(=Wp$O|OW4#l)K4{U(NyPYG2 zF@%qC_`qOXUoD)G4JS&UCy$WP4Wn1yO>m5-D7YiDIQL`6G4w4Utd6%=?r#?Qg2c2S zJu+P_n^}OY>{qS4#(Z`gDI4aJe>^Vsq32I zN48p5JohFe9rrpLIW3sFUB71RVW$o#X*4U(h{=~pA3!Q20o3fV(|q;W+2iMh(5KH!gcO-`Bj1>gSHL4DYf@G@ht z)Cdxs-{|su26g*FpgeQ)seJJ{H(XBJa_zbL4BYGiZIl;oyc`VjC;$7N*GD&c&0jl) zzz%Tf8frdUYy#XejowO-P+KutD-!D()2c9x776;iMDEM?s*y>lD+@ZUyN}@+{oF=T zIcVfqLTh+tJ5Sds;QiRh7l+$zycKY8!jOH)B!TchFj+y^AXaL_-SXlem;{Tks(2^g zjs8k}KPaEa03!+dWW|q*U(ko9+*x;(E69+W36F>7_AXeJ7XM&FjKXN201W`-+_{AW zI~L^X*RWsHxTcJcME@yzJ}N4-hsAd_k)!ugH05WkCr!^*|E+49W=N>JCP0~O;4+Yh zA$Gs4r5D~m6~3gdlF;)xAIHXSZUUAhTBnbSpQ%2zP*m3RyMgs4{o=_}Le}c(oIfW! z?tFX{{dRpcC!--s`~iq(bsJkWR`oy^+WqlI2Sh&4o>P4F^R9#N5Yw;JErAv zw2q;)QZ<+;_1u1ugFqk2R6D6sLcPf!6848XGC6XuNSX9nRE0zUgI02;w6d{SKg;*?-BT)68D&{* zn89D%GNpy!+14$;5sEuWB3-b(TWV{Kk%uT5XkDCmIw?bP;ZYn2uhOs&Swmn*!Ok?7 zX(6rPdn4QC(R{&-=8_5h7Pi00!oYvnTtQaiJ%VY!o|ZACBe}3v7uaA_N+Y9+A^j`p z1L5ZuU){Dt&*qMZ(U8EiqRQmB%?C0|nmRN7?h_ax^sGCU?b-LmuP~7h`ebb#z(YMJhWjA~%h>}s^aH82IvM+~ieP^U}`C8il<^Zvq2Xi}*H~US4-8<68 zxA`YSHjAptvP9gMda}<0!!<9wk#9RsdaOp711!v{!wW(kRrf4;f2j<=l^;wJ5A|s7 zJfvoh+=q;(J=Sg0)3y=ykP z>E+O*Jf<+eU@-C8Q$QuujqnR>X6G?3GQXb?)okQ3b?gay_q$bCXlaM7DWba{(IW4& za(kdacmmJ%YXM~%xsV1hi0HcLLAZuiq31m)dX{X!QC&u=u!otKJNh<)G5_!0*N%hZ zB@@V}Z~A5}h9!pqk$UnS*U6^?1c#|#iD!d}x|g-m)3 zi%xEHSn|8Pj6x~l^73?uQuI=M67u1R`>)m8Y#Uj&8={V8zJzqliu^$Fl>=nURmK9` zpv(qaWO%;ztZhtz+xs;po!0uxX(9alogmB$RC}qwRAJ&)7}fLzJzt&WMw>ELohZOr z$q=eIdr#;#2S|ZD!WCG2A@{&B<&W^6;BGnk;u?+CX}!I@#h+}Fgf>H=^^Q02*0H zuK{8Q%)MWKSpY|ME*vJ+x#dRg)40u^HoP^OeYNFW%)C9A7)~l02Gaq)=e-)^wyWWa zVX&D|b|V)_L$s}Fm4@cAN7(Z?eoPHQjw~_=>gIp<5K`y4orR*k1~MNzcLDWS?R65w zMh)?qMHH<-f#psS?RV|vaCJ#MTTI9Cn9|xtvKMqjdW?2~A`c$1DJ40%4+H|;P^sO| z;cXnF-QIK5QidXn_^+8WHL3eG-n{R!cI~&@pSmN(pQj0BSy>FORb}aXh%VY<7Db;w z&zC4s`KN+^M;YRqf} z@QI`zqz$=6B5}Tm4Yi(CGD4lrI^Bqaxz2I>gVvt->z?w4FiZq!Y&>19b&l=l$Mh1c{6snx3*8Y1k7FpuSVHi3}LE}UqqmIpYR{z72WBi z&;ADVGl^~=AP?>8AQtl)qj2ocOGVOI?OQ97drBEGb;X0?d+y2P@!FEniD5S5gA-fs z!FV9@{Je=ho(a5!d7V;zTOzCG{2d5d#NEc*d<;%v@*|4mTz3;)yQAb${qZ>pci4`% zBOK>h@=sMEmg2aG4fzD63)9G>8KJ1}d{|&m$`vPAVPd7nDAIw#+4n{RM z!tBfNls7MhM7Edo;4Y1_@*2?dv8N_~}eUuUFZg7S; zhQE*YemZTM3nD<|nIFn$(`WcT2YqI2Pc*%-2YARXSDhwGR@OT_WdPf?&MkD$Or;L5 ziV30jM_T1UFG>7?b zp{~!*0tUL9M@|7Av>qyWUEKb*V^Z=$$~(MZ-RXz6zXDL{+Da0K=fXJ>%}Q?)+UO7C z-rAg-PXuntJ`OunNh8>+^}996xIP%$=y5w7#wIc;E9n+nc>SRF9^K&3ZG3JKaWo#I zgCxqKOpCP`lj#;=b3RheX~}v>$`xjjZZ_<3oz{rLVQ{MsfK*~tg?~maG@EA;Dz@2y z>%}gYsm!KYhwPkfJreYr>(2mx^nH72lGFZaHX=*>ZYB`Vg~NT}jG#TtM!@sBBf*4l z2p^;mG3=EhIF_Cyb|Vsj>F%ShaUiBF0KD_G2HlA?O4_`=V{ zZNPX_x&=XEfWZ^yFZGtjW~L1GY4!J%onr}L;CDp%Wr$8stLJqQj7B;xIlP9{-GKxF zzHtW9`JRJF*S`AnYv;%w!yYFDTxMc1Qr|;pi=2lfQ zmR-JhMk5#M1|%J_*!Lz*Gw_%mh+60iHfhO6u&xgKKoM~lD19rJ^L2&{uyjCvQ11ks((g{Tlf|0fUuMaZw~#P#Leqc@Ed|&D z60JNOdyy)^4R!6a6K$6v%%DZ~0iDs31j&J){KbL$F%wB|N#&8jO+;5~lD)@;@scUY zf`0$@MtQ=~WroYF{DLZq<%9a->fwDg0Z*HfcmE4~7b?5Q_U^P86*WJ2bgDpN4vG)xVd?wX9gKNb19mqp!$ec#cRvrf z{(x46v)(Z`X~3qg!jB$`s~#CEh3Etssf9Hgl#$Gah+He1-w2;AY9MEZaWsP_L9%md zB?0kWNiAlm=ue$=J^ANvvW4zUH=&=0&3DcjXujz1nntd(jTp9;hz6EZ5c8{>zjvD0 z)?A8*c0n2hCaELfi3Gq(kQ7v~QYSBUc)C*PZ@$jenaMX)=TU7+!uh?#IgDa;A0azH zfP*b?l7swmfy|pH4u@$aEC=@W9yq_CLA_cs%XvuJ`GJiy8&!56p@TOg7VYON9y+1S791cr{4N zz+E6q)ij`Q_U=j zkBVl#JB*N0-JJX4VD&Wkhoj1^Jr4%CSD~^RBJ7pH9Zzpy@?P9PPTfqw8jbuDzV{YC z4r?#c<*{h$_34V5+%No|aJE(RQnf0^x2KIaa^KhIMu@b%;Uu7Gk0s0r%e6mjAz~qF zJN+Y9FOl%yin&uHUr5P7U6)_{?Ih;PZm)tVbq(2Glt^Y`CGpe^1tobZ*hLCjBN0bnWgZ8RJ z*12EfNDz-9`maEQ;dWI|HY){Ea(5RYf=2-`^ffr*5b1s`fGWpJF$`z2Kmx^mKg}$W zIAE`T<;2@R5;E6Gtgt- zwv0z)p71DpxhM5*I-%8HAQ;VLqExcu=;zIEfvMnTZ`^5s(=ll9Nq#)Z6y;mIbg-37 z->(?A_Bn<%=|_&tRH0G7cUX8l7PT@}`}&9U=Z=<-eq%X@?RoFx!ofTfR20*;8L`kN zRisuZj`EH(Wvaf1ZtJ%aG*%EhuDYo`l2}v&lU?4xFCL;^J^xsL@dS|&&~|9FNl(6B zE;C@5Pue90`o1U#B!^2Ab^S05wtK@YvlA9SdEMKaSdh^`CZ~8hCM8#oy3+Y?H=8j3 z++V8W%0V|Gf;w@nyvhDn!ZEm0;arqLZs(QD99Gzm$19Rc(5!I1 zKzG@52U=lK%2EdfE5}vZ4hoq`qBM?Tf@;(3n(pXfE?QLjmlA`d(NY$EV<=u)` zdntOQos(u&x*gU;lZoA$?=7JaX}iX-auFClp*XX%7o04R>fb}S&~VWzZS7XWr8gZh z2Ziuqy$Nvb=5vRV88TQY$o>%al~+seAoPyEdxKTPTqw-Df9Ckm`o7$xNIN}jM=*I% zgepXE`j&~myLj)zXEE8ASs}{;yTWil<}B=~Ro*}^lBKZQ`vHCWmux-_pdYIwThRTo zFg!+>_mf*s?$}XH{#{MKUY=fa{Z@-gnQ95H6xVuEMg|)v_kbJYvHDFXwv}b_q2!06 z%ica$f({T#O8!J?YVs@?JS^!KyJX%P*FwqY_%Axs&p5y|iiWVHhHlxyG_k4uVIpx2 zDHm+VV$xN{QQxH9hub#Ne+>C^xKNd|WQp>MJ&B^~G6_r1F10nzWpFBd%h6m#5`jzN zkDxQ|H2_dI0O&@=bCX9XyI(t$EYv2f-N+q1B%&?6Q5L*`N)%WAbS1eyp?scB7{RH< z*y3p2afAUZygw+*)xXV3LAOWE?VSJ^6>v*Zch!%YLyYYqk%DJ4yf?rC>bN|E*zw&3 zFq(=^jEO@+cmjCUk;1Y1XN6V>TWo#{t3Zc`mLtq{oAD5uvu!h_<$_Nnq(i3Ues0Ww z2y?B;n$raTGyry;8NJ($aw_}Namg3<=(k5rA1p_{U!G?*8=*P{S{rfN>7@Y2%;aY3 z0=zXB=UG~480a*Ba|1WSyZpETPc8DmD3j2M;xWJ&am)Z8-83|fT2$y%?)gG!aCa<2ZEbsGbeO5@IbYyU@#gU`PeC8jYv|;6Uc-JJ zpuz2@ZCY(w?5@Msd4l6$>Zn!|;Njh`^GDGof`Zmn>rKuvS<009Z1qekr8x2Pv zAnKFNg@{m$#^*+ZgVNps&dpYqi5*b~B455jbZ6aZMc`cfpTF~=l*fCE{s|0e1wP6k+``HyyX*LFWtR9+_79yzI?zq{DPfAwBo9!H^kZr1x% z_w!sFSJPZOFv+9Y=6s_|@8#66Mj-;<&VRJ1_#9?ZuvfO=bClwq=MdI?Sujj8g}P*X zYvnSoQk#oShJw_nh>;B>@Tx+SgF>45le+ii-iTShp%c>eL;gzQ{rSxY=7%Oc01sLCa5o#256@?wi;| zTugF?dD_O7?~^s$L_x2Ti@-7FEJ!8kb^64RlT(}uK5%;)t6{XByv4+kGUH45Nh_== zG&+D|^p(FEV`roDkldz+934lrrZ#$r0_p?ijY3SX7G5us7#OD1mg3pDA?2i@`7|!M zPXP4(WS*!?4bS2y>&4jJa}i>28L0^g!FbmUABuN*GRSPGa{24u%Ph(n6}-kDcdqsD zu|CL@%gZ%uCZ;rHN~OVTEGQ}FKj?+>(tay{Vykzq;7~tGaUZ5Kq~!REC=AKmvTzs? zH{J|5@pREYfykNF1Rnns9I=eXyt|S2r{3LY4cr0V4w0O_k(jJ*a&9(YRVoo&BB5rh zwap*N`_gtKjKAeHqB$ib*w`uNefFl&Dr$6KvaP-Eb*hq#=7XN~2IN{*bSQf!pkGD< zqP;_yd@PfjO-^-hhPOIgJs4)#pu2S2=DqJ#+LpL+sTCQYO99?2{`D^&5MIx~QTP|6 z+~6qA>i*Kr*HBb4W1xf7G{UJJW5W=GRXy!j3PT?f$+}VZy`9PO0l|!cf|%%~J-?+vcPbePVX@*>)_m ztiNi+n);}LcZmD;lH*|hz)>t#RevRMLdDNYtzN;u&PY*u@VH>8{?=yI(d5Y&^ple! z`#|}&*TDJ?igyf8qG77~kU@!I#lfz%&DH9g;y#hB#loqg-?9FetA%|2JLg0K+86xw zpW{T-@E_Y**56%LRooTddb=b`KOkGq9|i{i%=zIo9|>vKWu{<$3<#;qtOMFzS+B zSvgTR-RIs84>a3!Z!~;C-Qf`YF}`se=6LkH;mty?apvY`BPpF*_4}utYh#A>*weSr zZ)IA&x9xSvMqDok=liG^ryLt8kF2!|GmXcU3HYFG6mkqvr8-l& zl{@Tdg(tf4AQ;M<>Sec?J){G`Ru`suC=prL@B?XeZGR%?936{OzQ?nu3x!0w!=i|E z{NjnOT!{yk^&@Lgp+InW66<+i95?y6Gh}ky9Wu#P2MddU1D8kxHw0FLT_pU&ZXOer zI%?E^)}sUoQe!dsr(Vbf#_T%57!r?6|114%L3i^BsDlg45hP)bb~8loW~ac5!B>1P z!EE0RkoLhuCbQNW-|I2k_u z-4(`XjfZ^y=M#-?S?_(3&1SwJlzGw|RXV&20v{^#*tFg`Omi+zWDl~S`(*S~eJyIu zAMLGuh1fJ<6^Rv>q06!9i^G*SBXj#)HI1xx+n}(=gVG0T!in)Z(xEHz@9!3#bA}pG zgQZ%T(;uh1H7-ojdD|w+czzT{>p?xb7%IWxDYc0XPBG0F(e?uCo~Hv(V?iYSK9y>Q z2XgpF*WM)Lu|lPWaQ~~e_m0Q%|Ns7lY$dX@vUg`F*&~IBtZcHg_uf&I5LsoXWbcu! z%*qHv!~~3$=?&W!{OU04KY21|a27>f96UyVeG+b2Z~D$FZZ~F zTgFe-fg*j?l)JB1t9%VbexG#wu9;hH98}!8PW0lZ{MSn7L-z>p$)iyJf#uGhZCN~DI zZU7gKoJ(@<$<5V%mL0KtR(3ib;|mkRouy-GA*M1iWo^$#zHfdGZD~7vYHL7e z6kW3UUPrCu9r~Dhq}xf1CSvWuyx;zauTkpiapgMAA?9!f(=-nCgz-v$;luYUm>Kh3 zolFY!BmyKUU%q@6(SDr*a6%<{PZ`lIDRk2`qVu(zgnqyslIy)0sRu z$~OJEy0wDeLSzL3$wY>U>aAU)CGio)3(cUNd29YpJ3|q}Iw;(CugmL4z?DHS9=pV7 zpV$H$iq4@s=#lh$uM6F^sQX8)$|BD+|2ikaJJwW_$*VT1`rUQD)WOaKyL?mV=2hAe zeA`c+yqigRXx$k^qOi}FnJF4aS&uW6m+#&N$%sl_XzRY0zMgT6w?Al8O9xGesrU9Tb82SP)D#72W-Xh( zwbgn~L#!vi%aoRfdVj2LoZ)2ld$h`Y{bN!kIz3W(=1ZA4{bpC2XGNE?n`xc+Vh%Yz2&KmYgM(>8dsT`&9z$bj77wi!b_&_rW?Gzx29zGsoieZlO zvwPgALNn6x32h;S&CL{~1X{&Q)ZmbTN~o;)j!b&pT`H-8crgDfrb+l|Jh-Pl`E#kq zMnYDwp-^_-1Jdy5rjHD@VXX6wXD&PA zKQshc-BQycetq|^TX%ju*nXS2GVd_mX~Lc~p7+Uq#|_6DU&mTvoH_5Tdi6jq7?=39 zh^$1=k-bdU5`*U|{OhEsNm_pW-txS(+Ih5)>ka1A*}^Vrrl7kuE;qpOYE*hq`|xo` zSke7pTQd?C2c^Bch>+lwypL>5&M)oNOkvTessbi0?D_pxtq8<7dBcmERwA8;P+?{smrgLeY~V8LnqExq>&ye=QF(^U(Y(b z!uJ>mQN9YJQ@@=kAvhM?Y%MWs6E0FZni7dRpZ()~nP~+^fu10DP~z7n148LT)OCkM zUQXN%*Y@A7wRO!VFB9LZ&sT_1UY{mz6Q`UPO5I)Z#8)5`d!E6>qKbASA_hVgi^08O zs;fq=D#`a|f4&ZXq)AblVL%@XU1ZbpcluVsl6{Q7#G=|$+MlqI;f1x2yOrVNVJ}Cm ze%1P)n#%Lz`6!3+p&$j$u7*@I7a=T1$+fX+jo@x`2!52-v)c$gBx})?E>sQ^d+5mD z8=eZr#Q`V&nqKrn^Y;!v9Z(bAoXnAx#Z#bxj){B?W(g0IPHAG)oYvv}o)fjRg{A|< zbuTtX_PyS{S`N=k-&`yI$T7X*c`i$Jy4W@~cYm%c!UiH&flb!lp=aovg!wj`%@_;G zY2ut1rpP^ZP|hYGmP{~fahQU)NyKec9snErJW%@t@BtpaaOrXr7hqE5%_As5r3ocG zg$tnlj|U&{`@@d8xW22M+pNX!^rPyyv4$hDp8FOKbYdd!{Wxg;r`ic6HHuxYADV;o$ohU21mA2G3v2Bv%Omki{dQ*I7Ey9h4d%yg#kk0OFCdPjci)wTzw zIcD9C31YgrZx_z>@ly{N-)kq&hbN1+*{I(&{6#;@u(?Y5%+!JYlQx zo@w8|Td`*0F*)6wuDeV&vXfvG$rn>`o1!DTaXW#wDznG(EAHQmxmRhc^SY~DV(?tX zmxL^J|3kM>^nO(B`3#l&9GjR-NQ{gW$aEL&0LL|1=k?V$o@B1C_#81)OVL9BIeDp= zjU1SNUVj|DE#GqOBGz+tL0;W*aKT>fuyO~l+5S;!^WYPC9yT z1A0z$mYDA**-xb*f6`Z+-M+(;@c(7PjD7)LV|=nVlews~CRE_LPx(Ofc5C>j!$psa zd_wFKJF7!);KMaQ}gT8?s;TEwu(W! zxsqSyqDDjpEuMlU&;F+N>qQI`#P_u-fu?`(-~k|cD8eHmG`Q=55B7pbsZTlVdE&(N zg&ENmbw+i7$*B=>c`SD)sZZUOZZ$8H+<8duD{~a?r#FJcyLQ1iatses&~zl%x^r*I zi;-MS%U|bq;I-#1;`dc+nA}SJvSdW~ECPT&d*brTpuJSZtNG6SU^qXb)MLng_&_BR^ZwLIQA{w8(4|9IQHb)kxn zsf9b_@#D{kX{iLO!~OS|sDb%V9jB`#vo!8x_pMz!=4BOz`%&MV@t;Yjig^v1f*qb$ z=4zcFX7<0AA(Qj9?V<8zh@6J2Oj@*{oIkdP30 zsQqSXmzbRK>v@o|znWL{-Tjt40X_#|57eZyAjDej>T>!ITrYPkUcfQGkWe*8jn%xy zy$$F-LlUx&iBNiT@+;_&EXrB;GMDAM5Bt-D#xO^c{jTP!=uwX^qI5(K)rB}&SX5P4 zIw*V8DB7#WhhpGPSAj3qo-LIg(cUA9Z@)*y|#w(6r z9+3&MmL3NNzDY%55;3Ui42phEFg>@vMy z{&W%4ikAd_D*?4fToT@&GSy^0ugpxc{K?Is9$+>jG2fd>UApS9&S+|7yhJbQUq?XHQt7+z{Pv*O+Ekwh z#rMh=M9xdozY$Oj@B!cY{s9=1b!CkTAPQN;gSsujN)>yKPDI#1K`xnCs}U#=TVpo~ zoGl0Qb<}mtmwHpUq0Dei4;TJdbbGT8$Vs=z)a2s3c}zes;#r??N6MeMifPPi?Tj3pSD zM=y@QXwq2R&)WVAb@aqBNknYSaGum4^}-L>q%ubgZT)*B8iJcZ!MiCZrvMQ61-i28V14vLH{nb&TN{*5-`Fzo25){oj9;!YkkXpPW^3h+(lP3=Bqh zJUy%bh7M$|fKuQCA3PH|mxk(V(eE04#ijZ+uAR1SY4E;76c?R^oH+>1G55VYI7xH} z!$@Kqi|7&<+ynTY{Z@Jb)^yh4y=9dvkK2^NP0I{V>Lp+pxYoLt(*g>g8rNXu^SCT^ z-_wdegN2h=un{{j53e1`96d*zBeEFo8GoXJQ|hvV;fwIKgD&vHcaAd?E_^jQ=z6Yq ztf>})61*!#f}jfCy04F0cY)ONBA>qY2Vd##e4|D~fJHpmSf>{uc5(jCjU^6I>Ukx^ z3R)cNx0H<^fP*5KMyzyt%d`oJ|B>br_>JO$oY{F^@+f`7hsz>Rh**FFFc39v38&wj zWW*Tt3+k|y8a?lYmb=J}WH3W*b(|HSJTXkTttiHMe&#SlOuZs4#KO7n4ySg9u`lXM zxG3_bh+%4xR{Kk97~ks~X_EPxE{z4AvDc(RC;NQ?8l>qw5E%Mk(--&q?hM^`6=h}# zJ;kTCn5yaZTQS--``{+xkfyFqnRIjyL=qo*Vvo@R5D7i+GdTE{M)(iSv}zijPOZtINV)za5h+_X`9G zV{;Gj6#80@d}Mo1ekJf~{yII_nqT1e8=TT#=XR4N0}MYDq~r8}mM7~@0w32s45pQo z=vo=pnW$lK-w8Z9RR!Oc$6;)q3I2|{5@}fIPKJl`qg3` zJRC32DPXV#f<}Vh;qN4$mN#TX#yr(*P}r*!m29Zf|F+>-u-xS5M?}gG@51!tGvu0& ziA=5V(4-}XnLrvS`Ob$NqR{j&K_s$KJ6fPL(w3b}^l3-gjSEhoK)a#Z4%!Ir=3_XI z)Yp3tsv-Ngh2ozE11wII-qa@T_hY>ZYxByMH;2vd4pA);1$H2`T}pNO-cZbMzq1cw z7F$t^t=Zl}gjP1wbzYcXFtPxTRlD4Zk?s+P;B7_Ch4(dSp9tH{d;0{tTINi>PPo z;t%Kts8%7nKk|Z4;mSRuhyz>cQ`G)7y}5nhzAQo;>#A`R#$QCATm-+o7p(pv6k-pn zpq`T=GG~WnrfEmxw}E){^kx-i!ZQw2!TA0cKRY7$aNs6UNrwLFUoLB#DIPKC+9Id+ zZH97LA78oSsN~#kgU5^tiJL+phZ;T+yvPHD{tU8@gb`fH9iSr$+6;HwDqr#=(ysF? zfrjTNdjzaEhFXa@mtgv~qEtb=aNVlmJS4)ST$-HO9F0S|%a?C>eec|hB`8pRZ=ZGwD6JT-Ba zYaQ|dkOwgP8$G#IdylDM5qzE&Hghb?R9gV`;JjZ5ksfCk(PR%v9s;v*iCWlQmsk4( zU{=R(GH#6ka9F*@nZIFvK;`f}0r{*TIcHNV=tHvu?;`<$Dbs`EzB3W^reFc9YH$KL zRqsE069Hp4Zdxh~!zmIX&*FnL6W8B-?Gbuy*i+`$`0lq`!DCEeETihzX^bAGs_Io^ z^ck9`ID$Yv1t)-F#Xc^YU(s{U4XO zucrPtO0WEuSy;ht-y<-}w)_1QI`GpK5Q39eM;W_wXkmYy0RccR-dpCdC(`j(E4-~EYR{0${oJ#)iwX@e?uHW9olw0_Dj zP&VGI!fdLY7@DTjDi<8HgFEe7!Y-_$*bd?SPXO&`hq{V5D}oP%U5`ESs0bh>(_YJT zhywTz127YpLX`l| zT*J}NuF43j>gmKg7dcd=K8R2GY}Q+!d6~tRIAK`rD2<;K7KR5%xPWfUn?0JA2uY** z1tRbYus4hsd8?}k<{ehk$JKq^;Cp^4{DJJSsP1Gz|AXiq-<+3c2=zMFMNvOJKb#v_ zIF76|rPxf#y?HlY>A`BE615-YM2xPda(}ci$2E8eS-5IkR#*0G-Pd33ckDt)CCteF z=gB#`7~sWCuyuL;TGE{60^~ww7@E0wHqwj1^wn|%SFdp1(>epezACCmEL@lO#j45yJHVp0(Cy2t#QAsRT^F9AWc$-A@q@Q6y!oNB z;Sm;!2+u?i z^)Jay)1aft+i6ZS4P@klg5UcXp?V;hUWGK&%?GgbacH4>9ggU4_Z#m6UH&&Wp_=Zu zH4(wTUy7;4RIH+31Gr`DQ9rWaUJ2@{r$IqG7c&e%F3-3MqT^Z2Go(@I#^ZxK%3fQ4xmhU5fE>RD~wI05dftM+8Y1iCAh2uE+sUN(zcsj zIgb0R_hpnzVQ20+t0~UrglOW$Kx9|-f5IM&gBEa<| zR0GbJlV7Gw*F_i(`Ge0(-FM3WaQ*rsLFg}fF;CTNcH>zsq|;qsFC7cN;IZ3ki6OLq zL)g;dA%yX^GH!iOZ>)qLZnk4C1z0#iHdeLwhC6z=p@n(FJiyUjLw4Za@` zmj8##CPrC}ZF{n^vO^M_=8$?>^E9bTm#xm~>wxsonfVuI^`&!M>34G|26ENTzxH`m zbs7_f%}7hfIuFS7&$`KSG^w4v5cDuzfbf}{5wCDynO^`{kjJW>-Y8Z_Yj;NopdVcu z9Fx{3)31PYfESUF;gj@7N#LY07zk$E@no^!_t@I;w?wo0R#Q{48#@4RSj z`Qc1)i}54!Hgw{%7nUqG^4HwDcl7S9)|FVGdrtCDs$7&_gVkC=WJ&*KHbiBlqs1FY z1$8~Yj}-D=j*RuMFLL|1X=}tQQm)zZWiB<GVJyMJtVE%*Ec8XJyy6qZX7%o{kS> z^2*#VGSpm_tM#`t|NOE9pTQwkU;5Wf!6)FsH!NVMBY?QZE zypcE3!esS8Jg6f`SStA<>6Py{NieE}5l?rFWPpEFYXrH*qDRHK^JvG}rZYkkHr6%~ zxJ74{hv!J7E=m0CVD6-O!hZuC_Eh|8ii-t|a?bIJ(KG$s!ie7T482pM!>jS*KTLm4!`D8qLB}= zI7caKkQk+g%W{5;-)mXoUoyrqd+j#>yC<|x!UKYM+z|wGs5)RXx!`2_DX-eZ+Rdd1 zQztemS@UeiN-GDy4BO2%1y%07F2_!4Z!1`5-sE}6jg?TYP@EDaR*=t(sw(2XqipSr zXzw$uqVExJSoXHJvR@Q91+2w z2=^W51qdwM3yriGRWF$C%tuLn2EG*|WD#m6NcN{-m<`wB^-_VSAW(FKGs~o5C+ex5 z)1(RO-7KR9f9WZaU_ULoBV7yIW2tuz{xkD-`$0O+Osf&<(wPI2v;pc=kLIhlAG;3t z^fAt-SU$rK%W&Fh@WDDw#`ee^>0}(M1e!~2M)-yzYp#hLpD7^^lGknR&%*gl; zDG+~hN7?lIlIyF+c?)2RQ_GUS5H+WVco%0~$RPQtJYc}9cF;k~<=j8uG_^@FCKOm} zU4Z(LeSwr}g!XgB6ehF{*g%P;?}hejq55#0|jnVK#oz`DstQI{npq}|}G{Bx0@INMl% zSeL7rK#dh1r!E3PVOQIt+SG(%gr*p=CkPE+lAD8!2?l(;886qg0P^_iAJ>bl-x>PY zS&ta4UI{+T4EYMLl%@rWSbx-&CrT2dq(4{9Lvyp$sVkl}o0J@E8+S=$%l@=jhd=%B zZiBK{M~#(u%S`sba+~-;`KRw)adBr?lJzuxEV=qB6G@ZN$9$Un8t`v}G741S){nSy zrqx>FDV<2L!d153usDoF7D%P^fT$pfkXUE5Z6X;Y`Jmbo3-s;JAW{$38f`ilXz=sV zDYtr2cXE7KSxq_UwK7pvgjABliR<~tje}%a{eC>rrZ$t{HfpGI=Jq*Q$&Nc5g-Cno z7Xo_av-FILEgjW9yd4ooPe|~n^5NIEYmrim|3~j_ik_5TCvb~Gr(ky-5a=H4_SxUl-9O(yNR=lblW(s*O+l&;(K^Vgw;FsJthAmvN z+W$S>-A;$Yu&$aw3#rCUkl8qlWH9C2c3$g`&>mC%@JN$;ST!|gSY@NM+I@vaDAai} z4_n!JwZ*rzvM9OQRsJ(&Zc_Si_mRQ8c~=@Y%kAO1t6S)YG5A`S`I|-lj>N(DE)Et2 zcpWPgdwdN=nwboL5Lcc_`&~AjpEP*_!zTIaIB=U2%nevk@!x7XJB#TQTnfN zGWO4~C;quBh>!}AgYJg&g5$off}tT|qCLR?W}(k_f} zUe=yIPGr)&t=l-_jV`Z~NW-_D#@N!cx!yc2I;U@*G(I*#Y0Lw)D13jDSQc%|=r-zM z>a5>xt&q{jnX~!){(&GNIYtW0D}#jzsKx1_CO1cSS__I{8_d|qIMdaiN936L=WW&xtp$0iV2me3XzBHVJq&PA= zs9tTiFMfBq>T7;~rD_)6%TD?$_QZ#dMvZqRK<7{w?;2B{PcC`w@h4^|XDl+v^ zc21&kOx7iLVi0?*C;9=e&5|QR9?YKid8aFxqj}`inb_RO^XznOp=oD>Lvv36c76Z9 zoKq*iCma?*ANvyQXo&?Od;D-{ykvVU>{uE4WI;9aGP4>-_PrmP@0qOVxp(uwqK5=q zl(gmDT~u}-5S5{{q$YY?hoxPxa|8G&_;W;fmhna=~HqJ%$He{4qWMccD-Ob;COT5` zKbTK6(S1r58OjbA(tmwPIuHL9qER`VR+9*~p%khy<~<5TN__Xq*m9f>X#z+DBb%%- zDH~wQMXtj8S~4e-C~j#araba?ExE~rF@rx5SrlVsn)c+WkRqd%;|~N#xh~3C=4#fH ztZbMouJm-Nq3`a+o+*@WWtuiKlSioH&Ri6=346xqP2V2VMG?s5idVBJwaAxD#^R-!pTX- zik+J`gpmjIqcO_l5o8~c&qx0bwQ%1UPvRD9EPub1IhoMsCH9&d;`j$@7;s=fX8?krp%5H0 z)qr*`3z$x}dfkMLWUb$-84JAybT|cwmnaU#B~fHjVsRgo^>jR!5(9^9z`i`(oq$_+pH{k z_c&33sM@DqE1-_Z=D2{!;n*OU*QoO2XXhC5mDlR~0^XD+77<+u0z7}KDB1J0{f;p& zdyHE7kNxxi)0Bx0C5n9o)n&Go%*>LzO6r7R)cH#`sQGU4=hf#!a6=U9y?1_K9yl3H zR5>`DTp;HeI_Duq+(JN*XCk?tiF-vWEK|z6@6KYl%WF|@ zCK*1tG`H3GCBj((N+<8K&*@q$w|`!J&-%D%tsm(`z(ebt%XmT`lOpQz8OCy$K`AY~ zF!;L7uu`i^x)qy#fp3Y4{qu|)KELzfSmH|O6*#7nePe2n?}RGsM_(Pj*n^-iK97y# z%o;wfeg>oxPVO!!h!4)cf6uW~$2^=!D?SWUb#gt62K_fdPK-i@`?19kPSUv@MYz!lI~qS-TgVb3Cm4#Ud8<37#^wNi#{Yk& zp6y$&6N3VuiZ-5@U2(zH{d-2o*dFX|76q51oYn_(Vs*pqV7J3ObH<~L(#ZE7(mrb; zVoTLs@O%b!$((LIj6bs7KvwSHZzT#$SKzU8fpf&p-W$3&nrOxbwi^^Qe%(;M8nl#> zT(%=5?tSFR!GdiC0E0}e2B`|@FPcG&$x6w5qGfMc{C@I90IQ&SVhGcNz)tX3UPA>N zA4ia$=+s>DyG!?M93_7%z4=&}?@`J1$e2tG1OtmTpqWIf=-8!t z&^C~TNmuO9@GT}-Xn=}Z%H7Kkgpf&Oxfg;vbE>K!#?oQsl7FV@r6R33-!f<7YHb^s z$QW=|h82vyT(n7%4$WEgGL)E|db46k&nxc;@@UQY^TgQQP+_e1ZdeH3RV4fkEiZO! z1x0?TOL}D2JZTLdAX{KoB6KkGjC}_$OAnXXj?Ee9C*Okv78(%ABiE}Ro+VId?PrWl zP9^%J*_S}}@ns-N0@M(z;tZnh`A^f51Rut@f)oSP9*dd}p!IWi(f`SIH;lY2;WfnF zi5@S1I8QluCU}a=>c~JF7YhWVNQ|I>72|y4lcJ=4F(N^t|FlCeHdj5j+4y6+UXOy` z(6;O%Y;1Rd;{D}fqh-*B^AdDOkL8fnAF{_3)+nVV%2~JQyvs3iTh7 z3CITri3ndzZzz)e#I1y!k1;sl?!bx`%7+qZr2;Aj+t@h_Ggjmct%4{?`=j&|P+z}y z{Z`Y#qko0zmVo)^Ec$E@UpiUY1Oen%MGG==CI}Gg_mMDU3Zz~*!7!1_wy4DHqA9KP z65!nri;y-v_ZL3tG% zd)6fX9f3qL6O8{8p1CH8*wawblcuD^Ao=HEVbPEq`Up`soeWWN(B+6wh|4?Do>64 z>pIVRj$TzEQg)>eIC6sKDyON%0LE!~gG45*Fa7j2lqZRS(ymOH&*|9$Pw2nqrH1t( zYCPm_Ft(&XMv09DjrD)wM-0yYp*sv>G~|Iv%fp0>MQG)Hj0t2(l>Pyuf&5ORauQNo zseWV``QDUwbYAMF{s7HlR_F&?Jlj2uso_0ANbX2|8tI#22jk35B?Pm{@~kKS{p^Py zVbow2h)C3c2c33z4^M%R(E^6Ua4|74ky21tRC+UK7U>i4MqD?dTIvgzaLL4o|||I?|p zbqFq$1>`U>(6{%3Z?FJmt%?d@g|b-fUUb-rrQZXfvfaok%%Y5gc;a;*_W9FXgvr0# zDBuq(7))c}2+LK-aRYHsas(Ah<}x#LPyL@q(Cbdx{UPlG({A@U_F+=Oc!58XEOKfW z8wamhOa&W|tWVcZxEK81m|BZl+jRK%4(_4aqZks%*zXFbgwewwSY*B*+j=831XS;8 zDa=`Y!D*fVx9|t=om?>2Aye}{5^LmZe{q*g{3nLMCqel5Z22U7>?)}*VKQGHzBOo* zk?jhm8!@=3%;6PI!4D2rl-V=*ZUs%cjTkH$20a(}1`4n~MVJ9-eFC%TSYh(ql0Y-o z>4TQXaE75C9q_m~fi|xmwYxsS4jp-Xw(<~c(SNt(41xahjfO+}Z2Oyzw|yWKWQESt zs+X-!so4wU%_?579W9nV8#%uZw1Lg$5!3)%B- zlME(lo-8I2qhc#z2xs`$O@HP-7_{JuN*b(gZ^tm(;~0Xf9~;BWN2KDYDO@R zL5&7X7&5DT>MZ`z1sZf}K|8&k>lGt|W)!gxe?)~~DeT;bnRZoJJ1Id5ziA!nm{i2j zAuzsYXgF$v!HJP`Wrhw^Q%s1{Gxiq*6`dfR91=ptB0CR#^s6p_0eLd`G-3D=9?)=i zpKr}Nfqb-_mfvs!Fwq^_7?xP0LcBUzGWoo8%I;Y7OhqlJ=o`dB!PzOW5owC(hY zZpuRmG46FLGYAva&?27VP-HCIi5fwUJXUx8l&_TkUO#LOQTHFJ3RqXXLe74JOhA)+ z5d-tD_#p3xCw*)`lo9gqMNRA>v^BjRel{^0%6DH^%m!1AV?@JK#C<9wXs_Q@UY+q? z;|0X45iAOgBEmn=>*~VKCrrh75d%l%YjR9lwna}m42m3-Y1^p&-j!4sMyaQKZ~_V) zA^I?J?H1@qGsDz4H=nxE0A$9jx9r6SJk?Rc9C!r!pBA=YBjL6o$!GBtCs708n3kym zmmv-X!=)(HTzy64N=|mFnrUz=?#Etq&01e4#LF+`+73jv#zAnqjGe}7^lwuk4KyaN zgN0fd=ciPODmm-u^AVd|>r6H7872D%2K}*?`5$XZ8OMe2WB(wbaR?UY3O~v*6?*)6 z_>Q32t2Tj|T0(ovlick}!6v^?-$)7am#*Zy^;61zw(i^wuXkBDtozPvC zwmc(-gTUHhm>!e`Q>_&P8P5@1#w*m#NlL^p4fg!R+qMm*you)ugrZ`{0GoeV>{7Q; zJS!z=-j`}XXX;1X`Gur?)g2 z)a&vopA=zQH8SpJRl-m}|HqY;CwDFt;Wn>)Y+7riS&x(l@Q8?+^Do3#QyfCq~sa@%4Y(WE|Zf~4uWs@)cqnO7RSYoRZfHO3`mFYy@yX77b z{;m6I=5-c!siqG_>|@2Z7>gN-?qPK$x(k#AqgNBjJ=%~F*4%mLQdREcQy|kY^GbBv zU%ousKJ%Z2e>tB5+YA}z9G5bSzUZ*K9G0qY`Z0O9!u84vFxNLCO1!ny9!)$Bq-0vORbrFpkiBZi(Y2Clis9BWEv1@qICSFZ@!HSH%R;_YQ-cKW#m)HwySK*fH`4WqySk0Hqg9@>=TQy_jZ zyYba$_#b+Q&qhmO#G;{n!j^bZ&ucaamsD&szY!ce)j!UcF!wr=V?9#1@zfm|xniy6 zFonA;J%<&B$;+eKAy6_BeF;kDmM-ljjk_L{x>4DYT7N#TBgUV8B&_GT!%* zCmQX;Hrn7Mp}uCD=nQAhD8qc9xqPD*O#F;AZ1k@u73`H_C_d7TlAfVT$g_FdR{;X@ z8~|RD3efzaLB`q}=4HIQ9Q=goKB$ z{cnwdmh)dN96ApF^UUMF7YY`Pe~TUJ-2b;LRlRx4(COjIg+VNfi2;8U Date: Mon, 23 Sep 2024 07:23:16 +0000 Subject: [PATCH 1729/1892] Merged PR 2260: modularize run target logic to strategy parity check pass --- docs/source/parallel_module.md | 8 + nnscaler/flags.py | 2 + nnscaler/graph/parser/converter.py | 2 +- .../fx/concrete_trace_utils/concrete_proxy.py | 6 +- .../concrete_trace_utils/concrete_tracer.py | 131 +++---- .../fx/concrete_trace_utils/trace_strategy.py | 334 ++++++++++++++++++ nnscaler/parallel.py | 7 +- tests/utils.py | 3 +- 8 files changed, 406 insertions(+), 87 deletions(-) create mode 100644 nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py diff --git a/docs/source/parallel_module.md b/docs/source/parallel_module.md index 43ca9137..616bb71e 100644 --- a/docs/source/parallel_module.md +++ b/docs/source/parallel_module.md @@ -183,6 +183,7 @@ class ComputeConfig: runtime_ngpus: int constant_folding: bool = False + trace_strategy: Literal['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'] = 'cuda_run_cpu_offload' use_zero: bool = False zero_ngroups: int = 1 @@ -214,6 +215,13 @@ We can categorize the fields into 4 categories: and can make the compiling process faster and reduce the communication cost at runtime. However, user should make sure that inputs at runtime share a same schema (including shape) with tracing and correspond to a same computation graph. Errors may be raised at runtime when this assumption is broken. + - `trace_strategy`: how to execute the functions during trace. + Five strategies are supported: + 1. `cpu`: Execute all functions on cpu device, model weights and intermediate results are on cpu device. + 2. `cuda`: Execute all functions on cuda device, model weights and intermediate results are on cuda device. This strategy is recommended if the model can inference on single gpu. + 3. `meta`: Execute all functions on meta device, model weights are on cpu and intermediate results are on meta device. For more information about meta device type, please view https://pytorch.org/docs/stable/meta.html. + 4. `cuda_run_cpu_offload`: Try to execute all functions on cuda, and retry to execute the function on cpu as backup if OOM is catched, model weights and intermediate results are on cpu. This strategy is recommanded for most case if the model is too large to inference on single gpu. + 5. `reuse_cache`: Compared to `cuda_run_cpu_offload` strategy, maintains a map from function signatures to output values. The cached output is returned when the signature of the function that generates it has been executed. Same signature means the funtions are the same and have almost the same inputs (for tensor type input, just check if they have same tensor meta data[shape, dtyep, requires_grad, stride, memory_format, ...], and don't check the value). This strategy is an experimental strategy to speedup the large-model-large-input case, and have risk to trace an incorrect graph if the signature defined here can not distinguish the differnet functions used in the model, for example, torch.nonzero will always return the same result if the input have same meta data but different value. We have plan to continue improve this strategy to handle most these kind of data dependence cases, but please note that the risk is still inevitable. 2. Compute environment configuration - `plan_ngpus`: the number of gpus to be used as a unit. The model is partitioned (TP or PP) within a unit, and then data parallelism is applied across multiple units. So every `plan_ngpus` devices holds the whole model. Furthermore, assume we have two workers, and their ranks are `rank1` and `rank2`: 1. if `rank1 // plan_gpus == rank2 // plan_ngpus`, then they are in the same unit. diff --git a/nnscaler/flags.py b/nnscaler/flags.py index 2e843e6a..31fd5c91 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -29,6 +29,8 @@ class CompileFlag: use_nnfusion = _to_bool('USE_NNFUSION') use_jit = _to_bool('USE_JIT') disable_code_line_info = _to_bool('DISABLE_CODE_LINE_INFO') # will add original code information in generated code, note that this will make trace slow + # how to execute the functions during trace, available choices ['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'] + trace_strategy = os.environ.get('TRACE_STRATEGY', default='cuda_run_cpu_offload') # ============== runtime ==================== dev_mode = _to_bool('SINGLE_DEV_MODE') # allow to use python xx.py diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index 69342f1a..41a9debd 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -98,7 +98,7 @@ def to_fx_graph(model: torch.nn.Module, dummy_input) -> torch.fx.GraphModule: use_operator_patch=True, autowrap_leaf_function=leaf_functions, dce_ignored_function=dce_ignored_funcs, - cpu_offload=True, + strategy=CompileFlag.trace_strategy, record_frames=not CompileFlag.disable_code_line_info, ) _rewrite_inplace_ops(traced_model) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py index d7d9a61e..82db105a 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_proxy.py @@ -17,7 +17,7 @@ from torch.overrides import is_tensor_method_or_property from . import concrete_tracer as et -from . import pytree_utils, orig_func, wrap_utils +from . import pytree_utils, orig_func, wrap_utils, trace_strategy from .frame_utils import get_frame_record, get_instruction _logger = logging.getLogger(__name__) @@ -251,9 +251,9 @@ def __init__(self, root: ConcreteProxy, attr: str): self.attr = attr self.tracer = root.tracer self._node: Optional[Node] = None - if orig_func.isinstance(root.value, torch.Tensor) and attr == 'is_cuda' and self.tracer.cpu_offload: + if orig_func.isinstance(root.value, torch.Tensor) and attr == 'is_cuda': self.value = True - elif orig_func.isinstance(root.value, torch.Tensor) and attr == 'device' and self.tracer.cpu_offload: + elif orig_func.isinstance(root.value, torch.Tensor) and attr == 'device': self.value = torch.device('cuda') warning_msg = "operation .device is detected, it will always return torch.device('cuda') during trace, " + \ "please make sure don't manually change the tensor device in the code.\n" + \ diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index d30e61f1..beecd331 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -11,7 +11,7 @@ import builtins from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType -from typing import Any, Dict, Optional, Set, Tuple, Type, List, Callable, Union +from typing import Any, Dict, Optional, Set, Tuple, Type, List, Callable, Union, Literal from contextlib import contextmanager import torch @@ -22,7 +22,7 @@ from torch.fx._compatibility import compatibility from torch.fx._symbolic_trace import _proxyable_classes from torch.fx.graph import Graph -from torch.fx.node import Target, Node, Argument, base_types +from torch.fx.node import Target, Node, Argument from torch.fx.proxy import TracerBase, Scope from torch.fx.operator_schemas import check_for_mutable_operation @@ -37,6 +37,7 @@ from .metadata import EmptyResult, extract_results_metadata from .operator_patcher import OperatorPatcherContext from .torch_fx_patcher import TorchFXPatcher, ExtraSEFPatcher, side_effectful_inplace_ops +from .trace_strategy import TRACE_STRATEGY # pyright: reportGeneralTypeIssues=false _logger = logging.getLogger(__name__) @@ -55,16 +56,37 @@ class ConcreteTracer(TracerBase): ) @compatibility(is_backward_compatible=True) - def __init__(self, cpu_offload = False, record_frames = False): + def __init__(self, strategy, record_frames = False): """ similar to _symbolic_trace.Tracer.__init__. remove the 'param_shapes_constant' because we can get real shape when executing. + + Args: + strategy (Literal['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache']): + The device placement strategy for intermediate results and module parameters/buffer, and run target. + The following strategies are supported: + 'cpu': Execute all functions on cpu, model weights and intermediate results are on cpu. + `cuda': Execute all functions on cuda, model weights and intermediate results are on cuda. + This strategy is recommended if the model can inference on single gpu. + 'meta': Execute all functions on meta, model weights are on cpu and intermediate results are on meta. + 'cuda_run_cpu_offload': Try to execute all functions on cuda, and retry to execute the function on cpu as backup if meet OOM error, + model weights and intermediate results are on cpu. This strategy is recommanded for most case if the model is too large to inference on single gpu. + 'reuse_cache': Similar to `cuda_run_cpu_offload` strategy, additional add a buffer to cache all the intermediate results with different function signatures on cpu, + function with same signature exist in cache directly take the cached result as this time function execution to save time. + Same signature means the funtions are the same and have almost the same inputs + (for tensor type input, just check if they have same tensor meta data[shape, dtyep, requires_grad, stride, memory_format, ...], and don't check the value). + This strategy is an experimental strategy to speedup the large-model-large-input case, + and have risk to trace an incorrect graph if the signature defined here can not distinguish the differnet functions used in the model, + for example, torch.nonzero will always return the same result if the input have same meta data but different value. + We have plan to continue improve this strategy to handle most these kind of data dependence cases, but please note that the risk is still inevitable. + + record_frames (bool): If set to True, will add frame information to node.meta['frame_record']. Note this will cost additional trace time. """ super().__init__() self.scope = Scope("", None) self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} - self.cpu_offload = cpu_offload + self.strategy = TRACE_STRATEGY[strategy](self) self.record_frames = record_frames self.patcher = FunctionPatcher() @@ -109,75 +131,6 @@ def fetch_attr(self, target: str) -> Any: attr_itr = orig_func.getattr(attr_itr, atom) return attr_itr - def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]): - """ - actually execute the code. - """ - if kind == 'output': - return args[0], args, kwargs - elif kind == 'placeholder': - return self.placeholder_dict[target], args, kwargs - - def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]): - if kind == 'call_function': - assert isinstance(target, Callable) - fn = target - result = fn(*args, **kwargs) - elif kind == 'call_method': - self_obj, *args_tail = args - fn = orig_func.getattr(self_obj, target) - result = fn(*args_tail, **kwargs) - elif kind == 'call_module': - assert isinstance(target, str) - mod = self.fetch_attr(target) - if self.cpu_offload: - try: - mod.cuda() - result = mod(*args, **kwargs) - except: - mod.cpu() - raise - else: - result = mod(*args, **kwargs) - elif kind == 'get_attr': - assert isinstance(target, str) - return self.fetch_attr(target) - else: - raise RuntimeError() - return result - - try: - if self.cpu_offload: - # Concrete tracer use `.cuda()` to execute operators in device and `.cpu()` to move the result back to host. - # In most cases, `.cuda()` and `.cpu()` keeps the source tensor's `requires_grad` attributes. - # The context `torch.no_grad` enforces requires_grad=False for all tensors that generated in its scope. - # As a result, behavior of `.cuda()` and `.cpu()' is unexpected. - # To handle this case, we manually set the `requires_grad` field after `.cuda()` and `.cpu()`. - args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cuda().requires_grad_(x.requires_grad), args) - kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cuda().requires_grad_(x.requires_grad), kwargs) - result = run(kind, target, args, kwargs) - except torch.cuda.OutOfMemoryError: - if self.cpu_offload: - _logger.warning(f"cuda out of memory, try to trace {target} on cpu.") - args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), args) - kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), kwargs) - result = run(kind, target, args, kwargs) - else: - raise - - if self.cpu_offload: - args = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), args) - kwargs = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), kwargs) - result = pytree_utils.tree_map_only(torch.Tensor, lambda x: x.cpu().requires_grad_(x.requires_grad), result) - - if not pytree_utils.tree_any(lambda x: isinstance(x, torch.Tensor), result) and \ - pytree_utils.tree_any(lambda x: not isinstance(x, (*base_types, type(None), torch.Tensor)), result): - unexpected_types = set([type(elem) for elem in pytree_utils.tree_flatten(result)[0] if not isinstance(elem, (*base_types, type(None), torch.Tensor))]) - _logger.warning(f"result of target {target} contains unexpected types {unexpected_types}, which is not a common behavior.") - torch.cuda.empty_cache() - - return result, args, kwargs - @compatibility(is_backward_compatible=True) def create_node(self, kind : str, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, @@ -230,9 +183,9 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): if self.need_revert(target): with self.patcher.revert(): - value_unwrapped, args_run, kwargs_run = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) + value_unwrapped, args_run, kwargs_run = self.strategy.run_target(kind, target, args_unwrapped, kwargs_unwrapped) else: - value_unwrapped, args_run, kwargs_run = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped) + value_unwrapped, args_run, kwargs_run = self.strategy.run_target(kind, target, args_unwrapped, kwargs_unwrapped) # because setitem is an inplace operation and will not return the obj, so here is a workaound to record node result node_result = args_run[0] if kind == "call_function" and target == orig_func.setitem else value_unwrapped @@ -572,13 +525,17 @@ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], *, # TODO: support trace any callable function by add the fill default values logic. raise RuntimeError('Only support trace a torch.nn.Module instance now.') + self.root = self.strategy.place_model(root) + # fill default values args = inspect.getfullargspec(getattr(root, forward_function_name)).args[1:] defaults = inspect.getfullargspec(getattr(root, forward_function_name)).defaults defaults = tuple() if defaults is None else defaults if isinstance(concrete_args, (tuple, list)): + concrete_args, _ = self.strategy.place_inputs(concrete_args, {}) concrete_args = (*concrete_args, *defaults[len(concrete_args) + len(defaults) - len(args):]) else: + _, concrete_args = self.strategy.place_inputs((), concrete_args) kv_default = {k: v for k, v in zip(args[-len(defaults):], defaults)} concrete_args = { **concrete_args, @@ -754,7 +711,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]], autowrap_leaf_class: Optional[Dict[Type, wrap_utils.LeafWrapInfo]] = None, dce: bool = True, dce_ignored_function: Set[Callable] | None = None, - cpu_offload: bool = False, + strategy: Literal['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache'] = 'cuda_run_cpu_offload', trace_twice: bool = False, record_frames: bool = False, ) -> GraphModule: @@ -873,9 +830,23 @@ def f(x, y): dce_ignored_function (Set[Callable]): The node that its target in this set will not be removed from the graph during dce. - cpu_offload (bool): Whether to offload the module to CPU during tracing. If set to True, the traced code will be executed on GPU, - but is offloaded to CPU afterward. This is useful for reducing memory usage during tracing, but may cause performance issues. - If set to False, there will be no offloading during tracing, but the traced code will be executed on default device. + strategy (Literal['cpu', 'cuda', 'meta', 'cuda_run_cpu_offload', 'reuse_cache']): + The device placement strategy for intermediate results and module parameters/buffer, and run target. + The following strategies are supported: + 'cpu': Execute all functions on cpu, model weights and intermediate results are on cpu. + `cuda': Execute all functions on cuda, model weights and intermediate results are on cuda. + This strategy is recommended if the model can inference on single gpu. + 'meta': Execute all functions on meta, model weights are on cpu and intermediate results are on meta. + 'cuda_run_cpu_offload': Try to execute all functions on cuda, and retry to execute the function on cpu as backup if meet OOM error, + model weights and intermediate results are on cpu. This strategy is recommanded for most case if the model is too large to inference on single gpu. + 'reuse_cache': Similar to `cuda_run_cpu_offload` strategy, additional add a buffer to cache all the intermediate results with different function signatures on cpu, + function with same signature exist in cache directly take the cached result as this time function execution to save time. + Same signature means the funtions are the same and have almost the same inputs + (for tensor type input, just check if they have same tensor meta data[shape, dtyep, requires_grad, stride, memory_format, ...], and don't check the value). + This strategy is an experimental strategy to speedup the large-model-large-input case, + and have risk to trace an incorrect graph if the signature defined here can not distinguish the differnet functions used in the model, + for example, torch.nonzero will always return the same result if the input have same meta data but different value. + We have plan to continue improve this strategy to handle most these kind of data dependence cases, but please note that the risk is still inevitable. trace_twice (bool): If set to True, a second trace will be performed, and the two obtained graphs will be checked for consistency. @@ -887,7 +858,7 @@ def f(x, y): dce_ignored_function = dce_ignored_function if isinstance(dce_ignored_function, set) else set() assert all(callable(ignore_func) for ignore_func in dce_ignored_function) - tracer = ConcreteTracer(cpu_offload = cpu_offload, record_frames = record_frames) + tracer = ConcreteTracer(strategy = strategy, record_frames = record_frames) graph = tracer.trace(root, autowrap_leaf_function = autowrap_leaf_function, diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py b/nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py new file mode 100644 index 00000000..2abc82a7 --- /dev/null +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/trace_strategy.py @@ -0,0 +1,334 @@ +import time +from typing import TYPE_CHECKING, Any, Tuple, Dict, Callable, Type + +import torch +from torch.fx.node import Target + +from . import pytree_utils, metadata + +if TYPE_CHECKING: + from .concrete_tracer import ConcreteTracer + +import logging +_logger = logging.getLogger(__name__) + + +class BaseTraceStrategy: + """ + The base class for trace strategy, which executes the function target with concrete arguments and return the result. + + There are six kinds of node in a fx.Graph: + + - placeholder: + `placeholder` means this node has no parent. A placeholder node usually means it is the input of the traced function. + The target of placeholder node is the name of the object, for example, the input argument name. + + - get_attr: + `get_attr` specifically refers to obtaining attributes from root module. The target is the name path of the attribute. + For example, 'layer1.weight', it means get root.layer1.weight. + + - call_function: + The target of `call_function` is a callable function, executed by target(*args, **kwargs). + + - call_method: + The target of `call_method` is a string of the bound method name, the method is bound to the first element of `args`. + Executed by getattr(args[0], target)(*args[1:], **kwargs). + + - call_module: + The target of `call_module` is a string which means the sub-module path of the root module. For example, 'layer1.linear'. + Executed by fetch_attr(root, target)(*args, **kwargs). + + - output: the output node of the graph. + The target of `output` node is not matter, usually is string 'output'. Only used to identify the output of the graph. + So here we assume the `args` is an one element tuple and `kwargs` is empty, we directly take the first argument as output result. + """ + + # identify the name of the strategy + _name: str + + def __init__(self, tracer: 'ConcreteTracer', main_device: str) -> None: + self.tracer = tracer + self.main_device = main_device + + @property + def name(self): + return self._name + + @staticmethod + def _place_module_to(module: torch.nn.Module, device: str) -> torch.nn.Module: + if device == 'cpu': + module.cpu() + elif device == 'cuda': + module.cuda() + elif device == 'meta': + # NOTE: this device move is not recoverable, the data will lose + module.to_empty(device='meta') + else: + raise ValueError(f'unsupported device type: {device}') + return module + + @staticmethod + def _place_tensors_to(*args, device: str) -> Tuple[Any]: + # In most cases, device placement operation keeps the source tensor's `requires_grad` attributes. + # The context `torch.no_grad` enforces requires_grad=False for all tensors that generated in its scope. + # As a result, behavior of device placement operation is unexpected. + # To handle this case, we need manually set the `requires_grad` field after device placement operation. + if device not in ['cpu', 'cuda', 'meta']: + raise ValueError(f'unsupported device type: {device}') + return pytree_utils.tree_map_only(torch.Tensor, lambda x: x.to(device).requires_grad_(x.requires_grad), args) + + def run_placeholder(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + return self.tracer.placeholder_dict[target], args, kwargs + + def run_get_attr(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + assert isinstance(target, str) + return self.tracer.fetch_attr(target), args, kwargs + + def run_output(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + return args[0], args, kwargs + + def _run_call_function_on(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], device: str) -> Tuple[Any, Tuple, Dict]: + if not isinstance(target, Callable): + raise ValueError(f'the target of "call_function" should be a callable function, but get target {target}') + args, kwargs = self._place_tensors_to(args, kwargs, device=device) + return target(*args, **kwargs), args, kwargs + + def _run_call_method_on(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], device: str) -> Tuple[Any, Tuple, Dict]: + if not isinstance(target, str): + raise ValueError(f'the target of "call_method" should be a string, a bound method name of the first argument, but get target {target}') + args, kwargs = self._place_tensors_to(args, kwargs, device=device) + self_obj, *args_tail = args + func = getattr(self_obj, target) + return func(*args_tail, **kwargs), args, kwargs + + def _run_call_module_on(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], device: str) -> Tuple[Any, Tuple, Dict]: + if not isinstance(target, str): + raise ValueError(f'the target of "call_module" should be a string, a name of the nn module, but get target {target}') + args, kwargs = self._place_tensors_to(args, kwargs, device=device) + mod = self.tracer.fetch_attr(target) + return mod(*args, **kwargs), args, kwargs + + def run_call_function(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + return self._run_call_function_on(target, args, kwargs, device=self.main_device) + + def run_call_method(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + return self._run_call_method_on(target, args, kwargs, device=self.main_device) + + def run_call_module(self, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + return self._run_call_module_on(target, args, kwargs, device=self.main_device) + + def _run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + start_time = time.time() + + if kind == 'placeholder': + result = self.run_placeholder(target, args, kwargs) + elif kind == 'get_attr': + result = self.run_get_attr(target, args, kwargs) + elif kind == 'call_function': + result = self.run_call_function(target, args, kwargs) + elif kind == 'call_method': + result = self.run_call_method(target, args, kwargs) + elif kind == 'call_module': + result = self.run_call_module(target, args, kwargs) + elif kind == 'output': + result = self.run_output(target, args, kwargs) + else: + raise RuntimeError(f'unexpected kind {kind}') + + cost = time.time() - start_time + if cost > 0.05: + cost_msg = f'Run time cost -- [{kind}][{target.__module__ if callable(target) else ""}][{str(target) if not callable(target) else getattr(target, "__qualname__", getattr(target, "__name__"))}]: {cost}s' + _logger.debug(cost_msg) + return result + + def place_model(self, model: torch.nn.Module) -> torch.nn.Module: + """ + Place the model to the preference device. + """ + return self._place_module_to(model, device=self.main_device) + + def place_inputs(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple, Dict]: + """ + Place the tensor in the inputs to the preference device. + """ + return self._place_tensors_to(args, kwargs, device=self.main_device) + + def run_target(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + """ + Concrete execute the target and return the result. + + Args: + kind (str) : one of "placeholder", "call_function", "call_method", "call_module", "get_attr", "output". + """ + return self._run_target(kind, target, args, kwargs) + + +class CpuStrategy(BaseTraceStrategy): + """ + Pure cpu strategy, model is placed on cpu, run target is on cpu, intermediate results are on cpu. + """ + _name = 'cpu' + + def __init__(self, tracer: 'ConcreteTracer'): + super().__init__(tracer, 'cpu') + + +class CudaStrategy(BaseTraceStrategy): + """ + Pure cuda strategy, model is placed on cuda, run target is on cuda, intermediate results are on cuda. + """ + _name = 'cuda' + + def __init__(self, tracer: 'ConcreteTracer'): + super().__init__(tracer, 'cuda') + + +class MetaStrategy(BaseTraceStrategy): + """ + Meta strategy, run target is on meta, intermediate results are on meta, but note model is placed on cpu for current version. + """ + _name = 'meta' + + def __init__(self, tracer: 'ConcreteTracer'): + super().__init__(tracer, 'meta') + + def place_model(self, model: torch.nn.Module) -> torch.nn.Module: + # TODO: save the original model paramenters/buffers data to somewhere, the concrate value will lose after place the model to meta + # return self._place_module_to_meta(model) + return self._place_module_to(model, device='cpu') + + +class CudaRunCpuOffloadStrategy(BaseTraceStrategy): + """ + This is the previous tracer run target logic (nnscaler <= v0.2). + + Model is placed on cpu, run target is on cuda, intermediate results are on cpu. + If detect OOM during run target, will retry run target on cpu. + """ + _name = 'cuda_run_cpu_offload' + + def __init__(self, tracer: 'ConcreteTracer'): + super().__init__(tracer, 'cpu') + + def run_call_function(self, target: Target, args: Tuple[Any], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + try: + result, args_cuda, kwargs_cuda = self._run_call_function_on(target, args, kwargs, device='cuda') + result, args_cpu, kwargs_cpu = self._place_tensors_to(result, args_cuda, kwargs_cuda, device='cpu') + return result, args_cpu, kwargs_cpu + except torch.cuda.OutOfMemoryError as e: + _logger.info(f'tracing {target} on cuda failed, try to trace on cpu, error message is: {str(e)}') + return self._run_call_function_on(target, args, kwargs, 'cpu') + + def run_call_method(self, target: Target, args: Tuple[Any], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + try: + result, args_cuda, kwargs_cuda = self._run_call_method_on(target, args, kwargs, device='cuda') + result, args_cpu, kwargs_cpu = self._place_tensors_to(result, args_cuda, kwargs_cuda, device='cpu') + return result, args_cpu, kwargs_cpu + except torch.cuda.OutOfMemoryError as e: + _logger.info(f'tracing {target} on cuda failed, try to trace on cpu, error message is: {str(e)}') + return self._run_call_method_on(target, args, kwargs, device='cpu') + + def run_call_module(self, target: Target, args: Tuple[Any], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + if not isinstance(target, str): + raise ValueError(f'the target of "call_module" should be a string, a name of the nn module, but get target {target}') + mod: torch.nn.Module = self.tracer.fetch_attr(target) + try: + mod.cuda() + result, args_cuda, kwargs_cuda = self._run_call_module_on(target, args, kwargs, device='cuda') + mod.cpu() + result, args_cpu, kwargs_cpu = self._place_tensors_to(result, args_cuda, kwargs_cuda, device='cpu') + return result, args_cpu, kwargs_cpu + except torch.cuda.OutOfMemoryError as e: + _logger.info(f'tracing {target} on cuda failed, try to trace on cpu, error message is: {str(e)}') + mod.cpu() + return self._run_call_module_on(target, args, kwargs, device='cpu') + + +class ReuseCacheStrategy(CudaRunCpuOffloadStrategy): + """ + In this strategy, the result of a node will be cached, and next time the same op with same input + (for tensor, only check if the tensor meta is the same, please view class TensorMetadata in metadata.py for more information), + will directly return the previous cached result. + + Please note that this strategy break the data dependence and might give tensor with wrong shape as result, for example: + + x = torch.nonzero(torch.tensor([0, 1, 2])) + y = torch.nonzero(torch.tensor([0, 0, 2])) + # in this case, during tracing, because the function is the same, and the input tensor has same meta data, + # then y is not calculted and directly use x as result for the second torch.nonzero call. + """ + _name = 'reuse_cache' + + def __init__(self, tracer: 'ConcreteTracer') -> None: + super().__init__(tracer) + self.cache: Dict[str, Any] = {} + self.cache_size = 0 + + # some ops don't use gpu, so directly run them on cpu and don't cache them + # TODO: add functions to optimize the cache memory cost. + self.force_cpu_ops = [] + + def force_cpu_run(self, kind, target): + if kind == 'call_function': + if target.__module__ == 'builtins': + return True + if target in self.force_cpu_ops: + return True + return False + + @staticmethod + def hash_input(kind, target, args, kwargs): + assert kind != 'call_module', 'call_module is not supported hash input' + args, kwargs = pytree_utils.tree_map_only(torch.Tensor, metadata._extract_tensor_metadata, (args, kwargs)) + # NOTE: here torch.is_grad_enabled is used to detect if under the torch.no_grad context, + # the tensor in the result might have different requires_grad although the operation and its inputs are the same. + # TODO: we don't know if args and kwargs are all hashable, so here simply use str as their hash value, + # for example, list is widly used, but it is not hashable, there is a risk if the str is not good enough to identity the different inputs can reuse the cached output, + # should improve the implementation of the hash_input if we can find a better way to category the inputs that can reuse the output. + return str((kind, target, args, kwargs, torch.is_grad_enabled())) + + @staticmethod + def count_tensor_memory_cost(t: torch.Tensor): + return t.dtype.itemsize * t.numel() + + def run_call_function(self, target: Target, args: Tuple[Any], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + if self.force_cpu_run('call_function', target): + return self._run_call_function_on(target, args, kwargs, device='cpu') + + input_hash = self.hash_input('call_function', target, args, kwargs) + if input_hash in self.cache: + return self.cache[input_hash] + else: + result = super().run_call_function(target, args, kwargs) + self.cache[input_hash] = result + self.cache_size += sum([self.count_tensor_memory_cost(t) for t in pytree_utils.tree_flatten(result)[0] if isinstance(t, torch.Tensor)]) + _logger.debug(f'cache [{input_hash}], current total cache size is: {self.cache_size / 1024 / 1024 / 1024} GB') + return result + + def run_call_method(self, target: Target, args: Tuple[Any], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + if self.force_cpu_run('call_method', target): + return self._run_call_method_on(target, args, kwargs, device='cpu') + + input_hash = self.hash_input('call_method', target, args, kwargs) + if input_hash in self.cache: + return self.cache[input_hash] + else: + result = super().run_call_method(target, args, kwargs) + self.cache[input_hash] = result + self.cache_size += sum([self.count_tensor_memory_cost(t) for t in pytree_utils.tree_flatten(result) if isinstance(t, torch.Tensor)]) + _logger.debug(f'cache [{input_hash}], current total cache size is: {self.cache_size / 1024 / 1024 / 1024} GB') + return result + + + def run_call_module(self, target: Target, args: Tuple[Any], kwargs: Dict[str, Any]) -> Tuple[Any, Tuple, Dict]: + # TODO: also add cache for the call_module + return super().run_call_module(target, args, kwargs) + + +TRACE_STRATEGY: Dict[str, Type[BaseTraceStrategy]] = { + CpuStrategy._name: CpuStrategy, + CudaStrategy._name: CudaStrategy, + MetaStrategy._name: MetaStrategy, + CudaRunCpuOffloadStrategy._name: CudaRunCpuOffloadStrategy, + ReuseCacheStrategy._name: ReuseCacheStrategy, +} diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index 75ba5010..f7c67670 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -1,7 +1,7 @@ from enum import Enum from functools import partial import types -from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar, List, Set +from typing import Callable, Any, Dict, Optional, Tuple, Type, Union, TypeVar, List, Set, Literal from pathlib import Path import inspect import sys @@ -70,6 +70,9 @@ class ComputeConfig: # whether to fold constant when generating code constant_folding: bool = False + # how to execute the functions during trace + trace_strategy: str = 'cuda_run_cpu_offload' + use_zero: bool = False zero_ngroups: int = 1 @@ -178,6 +181,7 @@ def graph_config(self) -> Dict[str, Any]: 'user_config': self.user_config, 'inference_only': self.inference_only, # there will be no backward nodes in the graph in inference mode 'end2end_mode': self.use_end2end, # end2end_mode can affect the graph generation. + 'trace_strategy': self.trace_strategy, # different strategy might lead to different graph } @property @@ -290,6 +294,7 @@ def _compile_flags(compute_config: ComputeConfig): async_reducer=False, reducer_op='sum', async_comm=False, use_zero=compute_config.use_zero, zero_ngroups=compute_config.zero_ngroups, + trace_strategy=compute_config.trace_strategy, ) diff --git a/tests/utils.py b/tests/utils.py index 717c033b..bd1eca0a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -155,8 +155,7 @@ def wrapper(*args, **kwargs): def patched_to(self, *args, **kwargs): if len(args) > 0 and isinstance(args[0], (torch.device, str)): - args[0] = device - return orig_to(self, *args, **kwargs) + return orig_to(self, device, *args[1:], **kwargs) if 'device' in kwargs: kwargs['device'] = device return orig_to(self, *args, **kwargs) From c3ee6cfd0cf12bb87a759cd5167a6503f4900fba Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 23 Sep 2024 07:55:41 +0000 Subject: [PATCH 1730/1892] Merged PR 2262: fix module creation during trace if a module is dynamic created during forward, we cannot get its name for track module stack, in this pr, put all these modules in a ModuleList `root._module_constants` parity check passed --- .../concrete_trace_utils/concrete_tracer.py | 24 ++++++++++++++++++ .../fx/concrete_trace_utils/wrap_utils.py | 18 ++++++------- tests/graph/tracer/test_module_jit_init.py | 25 +++++++++++++++++++ 3 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 tests/graph/tracer/test_module_jit_init.py diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index beecd331..b754bc1b 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -303,6 +303,30 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool """ return False + def get_path_of_module(self, mod: torch.nn.Module): + if id(mod) in self.path_of_module: + return self.path_of_module[id(mod)] + else: + # if the module id does not exsit in the self.path_of_module, that means this module is not in the orginal root model, + # may be created somewhere outside of the root model, e.g., created on the fly in the forward computation, + # in the following example, a new CrossEntropyLoss module will be created during forward: + # + # def forward(self, x, y): + # loss = torch.nn.CrossEntropyLoss() + # return loss(x, y) + # + # in this case, we create a `_module_constants` field on root model to save these module for the completeness. + if not hasattr(self.root, '_module_constants'): + self.root._module_constants = torch.nn.ModuleList() + module_constants = self.root._module_constants + assert isinstance(module_constants, torch.nn.ModuleList) + sub_path = str(orig_func.len(module_constants)) + if not hasattr(module_constants, sub_path): + module_constants.add_module(sub_path, mod) + path = '_module_constants.%s' % sub_path + self.path_of_module[id(mod)] = path + return path + # This method will be refactored @compatibility(is_backward_compatible=False) def create_args_for_root(self, root_fn, is_module, concrete_args: Union[Dict[str, Any], Tuple]) -> Tuple[Any, list, Any, Any]: diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py b/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py index 0ebd27ff..1151d52d 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/wrap_utils.py @@ -296,7 +296,7 @@ def module_getattribute_wrapper(mod, attr): except AttributeError: attr_val = orig_func.torch_module_getattr(mod, attr) if orig_func.isinstance(attr_val, cct.ConcreteProxy): - warn_msg = f'Detected {tracer.path_of_module[id(mod)]}.{attr} is a ConcreteProxy, ' + \ + warn_msg = f'Detected {tracer.get_path_of_module(mod)}.{attr} is a ConcreteProxy, ' + \ 'this is usually caused by directly assigning the return value of some leaf function to the attribute of the module. ' + \ 'Please note that this writing method may cause some trace errors.' _logger.warning(warn_msg) @@ -304,15 +304,15 @@ def module_getattribute_wrapper(mod, attr): # using isinstance instead of _orig_isinstance to judge whether # the ConcreteProxy.value is the following three types if the attr_val is a ConcreteProxy elif isinstance(attr_val, (orig_func.tuple, orig_func.list)): - if tracer.path_of_module[id(mod)] == '': + if tracer.get_path_of_module(mod) == '': return tracer.create_proxy('get_attr', f'{attr}', (), {}) else: - return tracer.create_proxy('get_attr', f'{tracer.path_of_module[id(mod)]}.{attr}', (), {}) + return tracer.create_proxy('get_attr', f'{tracer.get_path_of_module(mod)}.{attr}', (), {}) elif attr in tracer.default_module_getattr: - if tracer.path_of_module[id(mod)] == '': + if tracer.get_path_of_module(mod) == '': return tracer.create_proxy('get_attr', f'{attr}', (), {}) else: - return tracer.create_proxy('get_attr', f'{tracer.path_of_module[id(mod)]}.{attr}', (), {}) + return tracer.create_proxy('get_attr', f'{tracer.get_path_of_module(mod)}.{attr}', (), {}) elif id(attr_val) in tracer.path_of_parameter: return tracer.create_proxy('get_attr', tracer.path_of_parameter[id(attr_val)], (), {}) elif id(attr_val) in tracer.path_of_buffer: @@ -329,7 +329,7 @@ def module_call_wrapper(mod, *args, **kwargs): return orig_func.torch_module_call(mod, *args, **kwargs) else: # codes below corresponds to symbolic tracer's call_module - module_qualified_name = tracer.path_of_module[id(mod)] + module_qualified_name = tracer.get_path_of_module(mod) with ScopeContextManager(tracer.scope, Scope(module_qualified_name, type(mod))) as _scope: tracer.module_stack[_scope.module_path] = _scope.module_type if not tracer.is_leaf_module(mod, module_qualified_name): @@ -346,12 +346,12 @@ def module_call_wrapper(mod, *args, **kwargs): def create_wrapped_nn_module_func(tracer: 'ConcreteTracer', mod: torch.nn.Module, name: str): orig_fn = orig_func.getattr(mod, name) if not orig_func.isinstance(orig_fn, MethodType): - raise RuntimeError(f'{tracer.path_of_module[id(mod)]}.{name} is not a bound method, only support wrap bound method.') + raise RuntimeError(f'{tracer.get_path_of_module(mod)}.{name} is not a bound method, only support wrap bound method.') @functools.wraps(orig_fn) def wrapped(*args, **kwargs): - module_qualified_name = tracer.path_of_module[id(mod)] - with ScopeContextManager(tracer.scope, Scope(module_qualified_name, orig_func.type(mod))) as _scope: + module_qualified_name = tracer.get_path_of_module(mod) + with ScopeContextManager(tracer.scope, Scope(module_qualified_name, type(mod))) as _scope: need_pop = False if _scope.module_path not in tracer.module_stack: need_pop = True diff --git a/tests/graph/tracer/test_module_jit_init.py b/tests/graph/tracer/test_module_jit_init.py new file mode 100644 index 00000000..9165725d --- /dev/null +++ b/tests/graph/tracer/test_module_jit_init.py @@ -0,0 +1,25 @@ +import torch +from nnscaler.graph.parser.converter import to_fx_graph + +from ...utils import replace_all_device_with + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + + def forward(self, x, y): + loss = torch.nn.CrossEntropyLoss() + return loss(self.fc1(x), y) + + +@replace_all_device_with('cpu') +def test_module_jit_init(): + model = SimpleModel() + dummy_input = {'x': torch.rand(2, 10), 'y': torch.tensor([0,1])} + traced_graph = to_fx_graph(model, dummy_input) + + cross_entropy_node = list(traced_graph.graph.nodes)[5] + assert cross_entropy_node.name == 'cross_entropy', f'{cross_entropy_node.name}' + assert '_module_constants.0' in cross_entropy_node.meta['nn_module_stack'], f"{cross_entropy_node.meta['nn_module_stack']}" From 095645f3edf5c07a53c8d2515408689e444a448f Mon Sep 17 00:00:00 2001 From: Ning Shang Date: Mon, 23 Sep 2024 08:09:10 +0000 Subject: [PATCH 1731/1892] Merged PR 2264: fix proxy in tensor metadata unwrap all proxy before create metadata --- .../fx/concrete_trace_utils/concrete_tracer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py index b754bc1b..f8fa1186 100644 --- a/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py +++ b/nnscaler/graph/parser/fx/concrete_trace_utils/concrete_tracer.py @@ -155,6 +155,12 @@ def create_node(self, kind : str, target : Target, node.meta['nn_module_stack'] = copy.copy(self.module_stack) else: node.meta['nn_module_stack'] = collections.OrderedDict() + + def unwrap_nested_proxy(proxy: ep.ConcreteProxy): + return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) + + # unwrap all proxy in the node result here, because no proxy should be record in the tensor metadata + node_result = pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, node_result) extract_results_metadata(node_result, node) return node @@ -193,10 +199,10 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): args = update_tree_proxy_value(args, args_run) kwargs = update_tree_proxy_value(kwargs, kwargs_run) - args_ = self.create_arg(args) - kwargs_ = self.create_arg(kwargs) - assert isinstance(args_, tuple) - assert isinstance(kwargs_, dict) + args_ = self.create_arg(args) + kwargs_ = self.create_arg(kwargs) + assert isinstance(args_, tuple) + assert isinstance(kwargs_, dict) node = self.create_node(kind, target, args_, kwargs_, name, type_expr, node_result) From 1527c7ebf59bd0f4482471e3d5cf58497506b587 Mon Sep 17 00:00:00 2001 From: Zhe Liu Date: Mon, 23 Sep 2024 08:48:17 +0000 Subject: [PATCH 1732/1892] Merged PR 2244: Llama 3 8B-8K example (recreated PR) Waiting for PR 2262 to merge --- docs/source/images/llama3-curves-8b.png | Bin 0 -> 62345 bytes docs/source/images/llama3-curves-mini.png | Bin 0 -> 53371 bytes docs/source/llama3_demo_example.rst | 87 +++++++ examples/llama3_demo/README.rst | 1 + examples/llama3_demo/requirements.txt | 3 + examples/llama3_demo/train.py | 271 ++++++++++++++++++++++ 6 files changed, 362 insertions(+) create mode 100644 docs/source/images/llama3-curves-8b.png create mode 100644 docs/source/images/llama3-curves-mini.png create mode 100644 docs/source/llama3_demo_example.rst create mode 120000 examples/llama3_demo/README.rst create mode 100644 examples/llama3_demo/requirements.txt create mode 100644 examples/llama3_demo/train.py diff --git a/docs/source/images/llama3-curves-8b.png b/docs/source/images/llama3-curves-8b.png new file mode 100644 index 0000000000000000000000000000000000000000..fadc5dc8f114bbaa56f652a14031eb0b03d008f9 GIT binary patch literal 62345 zcmc$_byQVR_cprePU$WMQ2_yw?of~pm2MD_ZlqgUL_|bNT1rZ~LlFUK$wN1XZn*0R zzQ6bT#&_>OcieFq42QruXRkfiT+e)-XU-k2q9lihO@$3X5Z>cQ(rOTdnhZh6&oI%! zZ`yU1L?CEa<*~GchWp3$a|^?XC68NY`fLL@*}Z9p@_vR}G8y^z6yDM5gwI);co|sC zxtAqRj@r$NjIm6aGP7=fav4s}v?-6XHkPZr`TQ%@O*N8&XW{|Sqdoj5TlGl%^yn1E zVe;7Xgl+q$OoodFo0KG2iJMD4j*_r0T*_19h;X0Rkn{Ns7^ zTEc%{q)?5Gnpv{Au{O|XcJ^$aoVcwz^y1sVJp$rpI-Qkqi z`wt&#vUWjG;9bly;&;iW(|;`M|7X3Hcpk!ZZN{ZSjxXaLp;l;n%N5nX8OTwdOL(eZ zUFc-SO&Lui@U?@qzy7t+jjJH`Q(=a-@_X*F3Dfn`*2icSv}eYMdX5MM|M>AEHaR(B zWo1Q2kdu>>z|73dEtkE9e1V@PnleQxaYOJr7Q+=~7QGo&Rfk0r)&tp!qZQU8detMw z4SQLfNYE$G)!V}7djpic^(U(Z?i&>&5X9MVG&{-hvQ@}o7In&f4iCyJ?rjdl*jhbL zQDPPP7vkZ7=`9S%l9y_Is{S;3V(#pC*S>ZVe7nt`Q1=*I?i-*$2V2uIC>Z#r!fq#f zFFp!5p+j&v-#t0E`pWVC@$m@y3wY>CZ{uP8R{H19s0j%PZWZ_;iq~miZD?+e4l$5m z8*wLTxhNSJ(1ZNy8y)?YKNl=%)KR~jnY1QdP+V*P|2VKxXPshYZRYOo-fVJye(ok> z^P|4Nt(H8=pfmMyb8EKO;Gsn=@zq}a0lR|pD@7<2hn%CyxYV?kriS1y8M~&Y=H$NM zu<9u9Yu__d5v#*Xv_Hc;U#B?geD~OLpvh zCS~hmXLz1maD7N;@FYZ?8PCnlxO>Sm75&dwrD$b~{S?c7*w zPC}>8^~y0$W{~~dE`@yD%E^?^R#NA_XDA$jzcgA-?_AC_ZdrA@U2GH6=v>gBwQg-f zgc4hvje(cD!)z^3+O)s#>CkkCQoP|ozIoln(j}hV3>NqdvG9!J)Js)b4xgpVnZkZ9 zNCf{p=Q3dW%VpTPGS(fAiW zVZH0T67~_Dgn#4FfBt=GH$5$FrYO}b=O?|So~~})j0Z^T3wLK{=YH7u*jL}n3z@YS zmX^c9uJZdSl5oxFV~1JkUPOgL!Ou(|GTn9$h^3i8k=A zM+?xTle*f;!&Fo5@Efd3;yD9G-S6&IxkSNWM*JyC!tf>(N?NHgZXTY%)XTFaNKI(& z9^b2)yp<06InKs2?uP9syNQYuEytNAe@*)FzP=LE8Ruf`(AJ~HhmBaF%tQRU9n^L+ zy%+1f48-A^RE|&Muv>H*4$sVvcDgtC5C0Ir14lU0&X)`%ljdB`U?T=01du)M<%Y5{ zL*I#kedU~9-5voaY_gI7ZcwS5c}ahX&3UZlg?n{mCJ7uJ9e-cqq9P-3i%V4WBCc3S zR5W@--wV%Qe1_Rz%@yRt-oe4F&h&{^P1#CmCtcu^TRQe%z)6F(h9W@#D~^&IkE!<# z51Hv=riVT}%peduhKi*y1@}wOLmpQx;I_tMJt0|B>nZIlc$AE$*yyhr2 zIBf19|3UZaCoueeD=VXSR!qdnW3|s+=5*UXjBI5wT~dK3wGPzpCC%hZeZV z2_|d3O-(I*^(2iq{n=dW>GVbK(MArwVP`b$Xz5%Pd~Rr7#>1mar~8j!?S0%dm(EYg zp9p8VU@xH>hXrCtIC}arMH{o|g7n~}4{Gn(NMq}oMJQ3+dlg-q0JV}oCRfE}*+*SP z#V_ZRr&WKFY*qfZZxqTR`-376Yr{81YTdR%w7fFoIxTsPM~d{^>Z8{=k5HkJ?xWvv z#V`#Y5xR??!HC|Xz^wD=@Dnj~~w1Zlju(IjK zI2DB-bnWy$MU?M%S{D8jKxyg2IH8fT&+LB4O~f%Z zrriOJ?&!d`%@uM#8`*Yf_i{h?y`VnJJL%ota-|@o^TzQ#H6~9x;&&ZF&edzb-pz8@ z(JRrqgFn`fZpA}>XJ@Z|nbqLf(y3fzaeBYdT&h2sjd08OC^Xv^e#3n(>LCPyjMLDn z`M9Y&s3U3q@Nq&z-*HSfo zY|wQ%K@VfSTp5ASjX9yU$G~)oUr`faL52_hl`dJtnFz!nm_0w7#(-#q-D4$gc$?PO z*B`+MeBWQ3c}<_v$;->DJb&)-bnROpPck-qe)pyOs0fyuE>a*CBpO8?N~Z9!+IcvdBH6Y_iftozB!1aOdgZk&^NP!P zR}(eYbJ63Y=Pjhhp=gMN>;2nr0@TFsMg1?gFIeGXOnEiCB9ifPetx=Prz2FjSN81ZXN`IMmrWw;k!#FUM2WI%O=hu;N^E z(`}~q={+)bGZVv8ky7eOTJ^rjm`how$^wQf{KHF2OM}QF&}^%y?dtc3Y3mc^vuZL& zrrB&i!FZ?rQ7|@F^p?of4Gdx|QoYIS>bB7xot(Dkc*IuGEOo6{}V0$FJ;k-sIzqovDRLn1D?ppNYh#p?cjq0SK;I2OC$pC z)oIxA);TWh5%2iTer`b_A(&(JS3UH?fH50c0|SG~?ljG?&N)kKYbhXRTyK|SdwY9x zCHIVR1Zp)eIHs4C@r;}Gu~$@70QG5Fi>nQn(6;eBx>{YTau)!upwT(@L&4PSNL?j4Dw4)Prv(A za;)WFI)n{JN~VkZh|YyXMRqS=C3u)+PaOaua2wsOly7DJ?=|NO%J}$r@Uqs^t;W<1 zwd;2eqvZ@>m5qJe;AV^9FSW+n*w_HMmRnRLDe0(7)Y^X_3LRME- zZ4zayf99$i|6G?;zQlaT_saJ1>x7_l*4Wl@X!gDFHY0)Bc9_pAs1_O0!I z^W0}gM@NT`n#2^k+w#*a{uy*J1Y!q>d2qVJ+Ko%J3o+4I#yBG0#Xh+O+jb|6X#c8* zUC2O)$v_vgdpIM=_25BpQWE8;4d492f>|w%w!~es)}KFdeJh9k?fD~t;x8T&ywAgf zk~0u4&JeqyW{C!~r+|*iD)8ac&)2#^MW(Ih0Jq*Y=a8$B@BB?eT|Lm?Y?l1}$B)tu z4%=UsZpA-F4Q#Gez2zEUHbzFWw*yq%Jg0ye47s$z^ISj&(vnFtqM_36_@#Qko&(}2I`e)$rl!uHcupgf?V zL9}bZk%W}g44f ziV|8?2zDo2%h$>AhW$c2hle$I(MAKrz{Ut*mWwttHbw?3|4+h(S?bMj*N}2qUPC66 zw#a|Y=c{!{mzy^n_t+s{9fW^Yk2X(XVWA3p!d2p5J0(Pg zNn*$GVlZ%IXe{`ymqpb7!Dv*P`kiWlq>sLSW%o4(*A#QvWZ>Ol#Xuy5@+H@s@23Bb zUJd6*F<-i^x?PYaX?JIT-ueRv#skRI_zmu>P;TN(UmR2+**~m_iu`BOfQz4>9!h2Q z1KdKKrvI$|Vus)X5p(;GW7ofvmbMQ0&)bL!DgW!m(ez@`iTa+m5$NP#{y+D%lfAVU1>TpHbm+Dsh z-&NnY(GD0`X`{b@(Pw03-8>&fq|h^>!Urj`{&jVO^1mC>@2{);VC-6oLA?_Ca@*Ct z-qDB%oppOsBJ^xJ!bC)$WOS0a0kIH2HoO(leq{;}O|nw(-J6~N4v3+8j#7F-0S?G6 z_k%HuGF{T^4ezgyF&f}**iG3IPgj}g6uY?pcg3TnCU-u~9iH}&jJT2KXT}%}Kq^CipwbC%x7|4EkMSut4psN@5uJBDo@9#`4KU<-e?d+L z0n7#15TbaP3$Pu)OrFLE1cvyFLjc;&M84r80F4k+oIPPD=8FphS%d@wAp+grd{F8?8bc zuTU!)u`2U2%4WmZMvM`EKm%(h@iY4|(^dXLRI?$WWA#Tdoi)Gi1xXfS*oak1$+NqC zE6&&GMi|qHdF^QUPo)s@27-?F|*8 zlNcv~6!Q&W_5W}cxZbrjy4G8N`r4Qn0YVUlcW{aHq+izn`XS(8Y|S-e(EmgPnTGIR zu2^?iRhZYJ{^0&EJpbf-;Za^*{uhfmGMTT+WX6NcffW@EjlavJ;47Q~0GOMf-{RCO z0(itSKHF_>2G9V2xP`VWX!(%rNBZAD2rP*n>Ii! zp4dk$UEMv+6M}-py)ohpex^*H7FL(d!C%5;Gtb7fF+h6>)ZL>BZHzbHSrB~yMzRPq zfKlH8qXU0K@JLJE*MMRIb{1#P|8axZk`$b02myMaWr9G`pbsCYeCqkQ0J(MFtX@h} zVS^79`|Q8#-3L_n^xQ3tUE4Z1I2iOS`tnHO$H>bo^pbXVoS%suDitzi05N)2I1G}k zU~pt45qx|?NJz+6J>)_)ewvW(g&(ez1$`bI`V{|W)1=6?P9CFH0FmM-A3N@8Vg zy}enxH9P4xjTXC_3$_oCt~()$;gsuK|r;X>0+7TxMgPI71xy$dX&M` z*l3z@YyWkcY^jBkwO~U}4SHO$oH1Tn<%Tebk@7`FCzo35@Sq}JSYEM%Nd$EY=Sk+{ zcjfC3VU3rM1vn9sWqxsp zcx=pCy;K29V9` zFKgL!6$Y_XhgR7RKxUyI@;GsciGSj3O5H{|WufE=HLi@!wtRuh~r~euajCPt$c>0?dLrQ>+SCrJ?qX+ z?o;4tYb21*`JW3h5M!0;v0q(JWUY4koTJii`V#_l!433qQ@U_by0G2motbFI2MDdZ zMcX!1B#Mc8JUXZnC?dqi^N*BjqbK$gO}3d1GIDb6XAyd%osA8_xO{ zzCh4qgkQN6>O=G0l-PnpH+X1Knv0^N*XNi+v~Dn5%ByO&t67hR?5YGf$J=?Y=}bh~2AY612c3Ly zgqrpn7~dwm=i(Peis-?L=n2Ov!&6rui2cq2Kt8)R^{TXm@d<*Oac8vA8i#M_X=Z^J zB$XM@@3G1KQw!r_e<-u7yIV%xo3vr4y|-lzC#~%9&E=lwwp zu3zx^ZFkz+Zun@Nw+g$d<<@dL#u5iJAug+>T-QbJA#OcOJr<~?yi25KQtE;!S z_0vY)%t~ve5-xsB=A5>e&n^$W6bLTr@$vhTQKLa?_{lhZT4@nBaX2QTCh|mHmr)l# z-+E$i)Y{4iC=iX(Lr-*Bhs9Syzp7jqDP9}eo6>Mp5ey6fn-N8Yv`banrUjguuy!Wh zZxg{~8l|Y5nHBK=npleYC6| zZDZtw5HkdPgKgjD2ED>q`L^ettiwvu(-ei|55k<;Bj8i0!ihq?xmNN)N8FJ(ff7ua((Ss!mQA z)dAKB-#lIF8+aHQ{+K#DCud;9eXGp_lpqE|+|0h*)v|k@Yt2@021=}jcPu$)b>Nf2-!~3GNHElqg;n`$yu)@Yj zzZh?byRbh~Je_1ZBvi>OaVufsw|62j2+}SFW1<4jRL;0jr6AZ924d{t&wPJ(Tl-H{ zHgQwD!1>g&>_@Q4&4YqWqgN^mertU_;kf{sRBNRs6HA>W2?T`_N zLi+w^>|q~v`(S#b4Ov(^+1Ut8+bE1WY83N98}SNt*gGds)wOKOI#^R7XX~5Lh)}oS zY+bZRj2>8U>dMHxU4 zW)$Vq;3r~u77EhJ>RB~LtJCXC@KOaza=~WVs5Lhl&^{UCoDNoD6L*opI#{># zS7YpTef)Rp4h#QRgTvaUs5|e=C@KVMW{~bJ8R1V&(S;n5%Fm2=E#p!XLnp^|;KN=P z_UrR#(kAzX&)Wi3x<0!f_1QjZV*>kz3~9)tRq~FQ*t;Y98)7!PpHtZ2lYb_@FIY^_ zWaIR?*4Ohyu5W4xBkoby>g(TV{sWxTf=C7o4}R9=*r9Id)rYmuF4I%BHj9j>#VTy~ ziuvBqs>J~jxOmw-@NJ4G=%7gdZ6Xy3G(P~UEivySsll$GK`l` zYhgI3{m*p574e3yeca77v7hfZ^ReMrd)P;|7Z?A@pN*yg$GjyGTM>(RB%V!3-GvDnp+ zT@}N^djQ_uzUwZ3=1?&gZU8K+($+Rej8OV<93$<=yE(ZkF`d!tt=(k`Gq17tk49Th z9p?Kuob|UllXf@0O$h}h`goAkpEH;I(K~{xVH|9Z;)V{O6m!Nl3`gHn{5JJ4(w;vc zVq?U0HNuiu}#bCuWuwl+Ps?W z4@5xzC{0-Vjl2=V-%@PP1j4YsVz zr&+)(aATBO$T5wq6-3ZOI{SsVJ~S*%8deB7z;+&l$w-3la=wvF@$X9Pe^2xuen3sk zof@Rh<|y5B&_;a2ZGs*2nq-Fb=kwYFS0w0G{=aiv_6gzs}1Go)nY0B|^iY z++mY<%;DyYsO0fl7}AoAAPO}#7=SQF1bS6Uf(zz=A)Ln%w86J%FM5i_u`j@_EnZUths__x4=peSZEhi%%Z6#`BI}cxU^j2H(_H^f}bi$CFX(+R}q+BYe>(O0oY3trIQ4=p+9Cdg6FQVKTLf zjvhW6c{Ky#nDq~v`qLMmgW~&tUIHp&4-@fK_ z`S;M;XvNDLv)iZvFic^6v3J@ZwJZnI>pl}AONat+m^sRe_%_jxQq@VoKuTPH*z5r? zDkHCfidN6tQbFlnwXmHMDEUnNmK_sWQ~pFyUXoUvCGYvscc(vyGY~_XoX+IpwN*r|2UF&T7e2?_c%^Wz$11?Ik5Vfl;>p_fl)a$2SD~3XK`C#08 ziWS|$snyQbw$-1s)S)DeH~1f*wDfG6Q`ZVf7Lu8i!(It?A`^9~2l;mt(lyD!&~3o0 zb<3B?P@V4yF?vV`Ab{i(Re(A$6RbD?Dokb|iU7RVVPBiRgSC(n>>H!A3jv_d3?$Ln zweR2Ezzi|uUS%MeZ?$q_{Pc~;4KVy%P3+Nij%PL;{JQ!e!w`8r&y7{&x{m}YeX5T< zI}=7V$0*R*;AkWBpamrhFeAe7ppr2?kWQA98lesTL?4B{DAVmx7j6c-?`*fOF!U|j zH?Qn43D{B23U&5M-X$|~Qopxq&CcanNS`BG7XXgR%zlpK?Sh{0uDYTx<7RH((lxq( z-TOOcjK(%n@mX4IH;`kPrH{&IgvZxCh__poHnV}A1%=CBnkhuTK*{YdzwSTTS>W_d zLYRD@^6|Y-mddT}6g>}TCXA4_wEf}_m_vm;R`8P#qUIMidC_aQukU&5mzlA?{2_U{$4$?{wVy64rU+Jy(uv{rI2Xei!Mbd&JdQL6Stg zwl4yD>3?D-TNV6BQ~Tk-Zhun*nzZaQ?u{mkRGh%9vHc|QfQvW!=@fr72Ek^KB!^UI zh3FBAk#57EsFkD_m^-^fm?3c=KS+Ws?{{z#{OEal^n4F53iiDdKfSsXJ*T&-bz$3F zciPwq8OrFVpcDE^0OHQ?vntsU>k5?!Oaj(G*t4cmCNvo)tyfNBSdJqpD@ovLtru@Rl^0dvkn9#IGtk zMTUjA{oT8hy;p^`J+PgbC^u!3-^Nl7=}3(S;h`#;s*WRTe|t_DG0{NABYh*fwNkF5 zC0yC}S)n)VPh~X{#(_=n6LutMeXNwyF^R?7MqQ8e!)pEv0V;q-JYNRXVooS#1ei!L zPD7_}V09`$jk+_%;Ro7o>V5!e2`++LEaqoD8MtC%RUjXim>h9-rqTJXKds%)dvPwe zkiB(qkXJf`Ig3Ntv2l zi2uG2e>BCGgwlnACdsdtJV>?pai}bx9J}<_(T_(t^o_`&-GMXfaHmXt9B_H?{$`ki z&Gf{3(cVqxIaZ)19+%Og*P}OX=XI*6nMyN|v>aD*0Em%bOf>3zC#(DI6`prs=4f3W z8GupZa0)s6O=(+O4#C>f{)b$z^X5sI;w)v)U2`yiB!fxnuukZU35$vWvb*_g)YPrU zp*j--XsFP`m))3LocawR1$y-*qaQI*TQR3K?))V(h4dh`m*0VsLqeXtD7`URO-VpM zm4cuAP6z`UWWNcdJrbnv3rYGz+KNcv7YAEb?Ssx%<^zolJNSTn-KFBcJ(KUMFe*=}n>lqUNly)x4Ok zJ*btK?q*1GE@iSzO`#}OWp_vh1UN#zJn2oA7;>!;nxh&by}3+FqZq7HHmUA}+{{%3 zc!Z%P%}_#Cn6yb=QFdm|08d&#x_!r4etQ4oNaa5ebWia?Y&T0-FTO(+NM_l&5+RY5 zF|Fc8Z(DyD;T)55!0Z8yTlVDnWE*Ot1P zQ(_-yMR0uIJw^evr3JQk*Zr4iUrU<$TpQpWC&+9nT`FppP2$eN-ai)=U1EE6(D^SX ze(V@8S6Bg#CVaYt>zLB}~ zz}lI842Si=U4Dl}HNkau?v+xvPIDSD52A2WM+rQrZ{k^CC;e3DL>(9{c6`Z)TXVlE zugnW-7)w&(3&rWT#2xWOcQ}|Ga*lG%p|+zFLMafb(BPz~D@AqqZ5;ZH<%t0gA?{o! zpaD$jP`EU4GF85Pbtd(qM=d6LfYDiZ@3iFOQlbh?v-j`I5d;b}=rVz=0M5INw_^M2 zvJ@qQ<+fT@PG+izx!u%b3P79G_JdpCM`v*NbNkb?yh+@fc{iQ$bA;+^jkric##Woe zF3%L7yx*|Q&1>=((E%FtT|Flv(_uDHYiQv*Dh6Vd(LS?E^D(>Pjv+W4B& z_qh9;JxsDvfw_j2jg;GFs(jCWXUM<;c9tt}Q${u{DDS4(gM{{y(Fx38=9OLSDD6i# zjmvHGrBu`D1PihAkMOJZRRu+EpY5`E(eUlG^nqgS4-1AVZ5{LoaUgC6V znN#d>0sP-`YqfnLTHn3yV$X}0pBpp=zF648Ql|Gs<_X79m}TqQiPvLne( zI@gbJkhYgCpn;$=Z}n0y_HvB4w;TKUK-)**=+8(4dV)L}bOh%8UF!CSb;j<6Sq9#x zEI-;QOKs^}MxE{FicWbU2Mn10v|yF@1X+^NoIV%vDnHBslTf)$3jSr|;Jy3pof=~& ztJnc_usMMnFOfVYVhgk=koJMHg-_f@AU_5J2q=o}jqbL8)SxU_SwWSsx#JO*22UU_jdmDx|DxI@0%flDm9}mJF^Z z9xzs-8!j0Ex1_RS=Q65nT_Z4c=AN{<_}b8p8>Iw3T&h?-Gd}}W)R$1-0P;+sGLFMU z8aL|BG1x(=x)_8&49*q9+1kMe9{N`Ct&mB%dR(nG}nC81QvHoQgY^hV(K+@wm z(ZSZAua>mH!64^wYliKJ7i%(;L^IHY&(2qO+gE!-e*F$SclqQxJepNdB_%|`viKej z8GBjZDE}~YZv>(Jlg^A9L?-K z2QK(8Q4Xn#^|LIIaUCu1`eXk-HIkj5gsZUD+TVIoualAUhjcRm?{G{a0bUaT@i6HO zL`fYX4{~=UbO$UzM|;&_uT20KUm4JC9N;0?*VJ(>KX6}+kC`bv{AFnG%I{(AL@Nn< zB%4-7QLhvw&h@dB*RG*DO`b6^J2W4q33xMXckUaR4&oXx@SlpV%PLQE{nA@)OTyUz zo4-PP{H3gR1Ih0R({nC4D|Ufka7K6^4yM)Gq1r0v^K`H5F%U<=9OWLOhq8j6o2tH6 z7;w;9;@Hqb?Ql&9LYSJ?RrZg8NKi_t#Vwc&R^bl-Hnl_EU%S6w#19NpAX?$NjXtzs zg4G=sOb#B#VTu!HC^@6|f9rkK$pce;XP%RDj?glw9OBYB9O6t4ZVk2q5bAMW>>aA` z87(<2`q81+DTR0tl+HjvniYJsu1WzxKKp-OswA%7enWig%Y^bjCSG7u?zpT+TpY9x>?BIlXFam(}Hltcnf z>wv;BWY~LPZu6o@PEYf82kXnx-)b0vN_d${cBtANrKVuebHjI zmEj_hm~O(vCgC1@nkd26rE|=J46#@&g$GMAwi!o+zH#0HjpEtHvf=RmYFE0iC>hnb zv6w%Mg7ps9+ihCBHYkXSG8(`J(_DxOTX(#eTtG#*d3XvtJbfx0{k6EHgb+-^+ExjH z5lHNy0Ab;BUmQso8?GSfx)zQwN9PaLDX9)wb*cLHV`#^T5ME6LpFEl>DBv=-@ACW9 z?7$MQp3C+iqwWz-&gARG8e)xvRrSE?9Ah>S*R7Hi~Wg2si?HkUW^Mf|puiwsf6t#$if9637jY&5~+8_i(6gDHI> zCM!53#P%ht?2k|O{J4qT6+8aU%FAnO-L6>W@;7eWxI-^}hzn*%MbGA0#gJmPmE;a9 zCnkcE9~7d6D0UxpJZwOMMjL&_5T$YkfeB#P<(-IJ!BCMqu9<+G67ZtT{QUmp4C97ceG-k&7J6TnvhxHxmN&Pvd<_ zqwhWE&96(xJE#SN``72@mZW{npfQdk1T+^=>Gl-`5fEx5h-!_!)M<%tP%EMEnU0<} zg+H7mM2;2a?8PPeLW+f|^;3!t{5vb1Jf}hTjj(%spL#@td&5%hN^-Nci$>7eFq*^D<}+GiOB2EJ`dCq>top>!f-kgWeB_IFPvA znEs2)@8?MK!#kF42klveKZvV9i(vg?acxKm!XASP>(>P3V&-H zKT8rlQvzqhX69z10wc-inskyfdV1jV0p^4mj106ZFtxH%QT$;2kHOEVHbyM7;hK$P z2zRr_3Gvt`Vn7g#>$Tx<*7+gmVQ#+sc`A6sF4Xrp(f7NfrnIy|LDAjj?2uK`R_0(` zbO|5Pj4%mccrp=B?$p0^q&(Qr_6V6YX75%2tq?Gvq2=;Q&&5ObSp`kbc&I%~&HQhv zzg}vM5Gb}86&0l99J+oZYr!LH3%>=JYL*Hssq=w@6ma&6(GP?fJ9j6AyU4?W!`_uVE3 zW3^v{fI9C-cNE%9+a$}91YGumj-VW4q9lv_#%D6z5TY2?E$J7C>pSB_(}`e1xXFmQ z5^3r3q8CH*fGU3z7Nrq$L2|4n0iskboZ_YoI;cLOvU&*fA9~d5!hV~3>i3)4i}d^8 z!O_g}(YFqty61rS4tm1={Gc@I*Bnk-K`lLSElIOv5SgJ{dG__k5c~uvNfFGl9)z7* z@M#;7cW%F)HtdkTP28RS4u-=)%U^qB(RB(Zqy;D)Vy1qIfMzaKNu+~KO9_89|G6B_ zSxFsoh=EL5oP|WZ;PJijf&vbmZ90w^f~`7V{ANSURF55_L5&1hp@~J`xEQuS13`@F zH*z3c?7hzU9Zr~v_VIy@F3wXsFjZ8TtfHmnE~nqvxhBU&<1d?4m(yDKxD0RW6MxL7 z3#0Xk8Z5y=E5(*MgELgjPaO#wvYTJPSTLBJ1yjOcPN~pZ+Xab* z_&)GV0Ru7K=J(UcC5d4r!VPvfXwlT&SkmeOMt`v2&L{>kSCPy)2V4bbBOzZh*)U&Bi3WfkX$PXn}?qw1lzVq#|~xXBeb79`7h7pwwQ z&8E41&fK|k^+Y(VT}k6@w&Odw4^RNvIV=W_`KXB_;|tekiEZsyr>; zhXJ8RLCJ#x0f!cN`k-j6THk?#jM@IRT?Ri6Tp1CRTOXyl1%gc-u|aYI)8Jmz1JL)a zH+55IqfvC@IX*)|m4eB{oU#GRol#n`@*W{W#t&1rUSUv0q5B@X`GOffHU-daBI>Wn ze%lS8l-!$x5VzSz5sjD11<8tCfhfE~#qxLS5xJgf5D?BYSGUa1frKSOMKta9qZNNTY zNRYSBI~TmWTT>B`j#=Z0uBu__*S^@)erE)j1C8Zf-?m#JJYyYQ1HwGu(TP(2*99qxBk;iH72SE=5LQ;jY+e%kAd~6@s50M1+;E=6KlC!vJ;$j|28*E{;b=7dj-Yf2Nvz)8g~xg0(J<}BQ(V= z8nFliS`oB@9xP3NWbQoFbbFRbJW^~qLTow1Qi1!nmCc*ZnV|IodN-hp{GNB5M^d=h z%_T)KbB-)bS~9F%a(-#S4E4Fm4Kgrg2fT7=lTx#yH^0Y3{1B5>0KsR&QRO2P-A_yi z3@Xy>;i zXuNa$@qg=~OJsxpwB~fhT<61DX+;*2QWbfpvR+exU5Yq&obHZ$g*;N8G;&ODjJ^A| z$ro4FFRn{7Jk@CXoo%M)_}sn-W}oK zKJCgvjJ1awX>4hV48~w=V;*{+jeaIw8W&7gy$x)0?xLZov52ZAZ#({vbT)P`bxT7| z6T~Q5K9Ax1GK9@aq5wN;j023tZZd->mlv-F06=X$Bypq!-vE3Fdp4mek}Ta2Fz8Z0 zFroHC{t4SXzS+HF*KZ;XXuWn;xgA9tJuAJwE_*Dxn}ZI(4*=zRr56csnV{CwO5j$c zZCE}0Xv--)Mcu5fDTC!qMGBnE8E3C_t=rqBp@*4WpF@Wh+k)Do5pMo-Rj%~<2uv`s6&9q#uY>LlV%lZjO(k4N!A9+*AfE=Mc=#0(!-mi$k&^F}si@qe9B|4lkJmo6_~=Bx>1E zDTB@L+yt&=-o6(akY6alR>%HSy7M<@!(J`tLp+Nh#Gn^w5`q_4zW5T%NeBkLN~qfzK{46Ij2%tbCh5lF^qN2!5z#{gs*HO97SCS72R;6M|D^UOmVBTnI?;4wx1=zLx#?W*~oR(ezq4M^EF~0 z1lQFGgRUc9__H7Y4MoleX<#-3kRLR8f&3p4_l+uh8lu_my0U^M(aeU4#eFVe{7i{J z$xM9EG%%{D1tUpdrcV;=Bv=F(+e-uG7|6HW;@hFXlkKGs$`?nBn1Q|9N^J-N&2=Du zOhzgfo!Q)Za=HUI8;pQv3(|UjNE$6| zTHgG@bDXz3d$INmq|!1O_n&PP?Rk46$b)65*VK+cylQCOgf5h)lnX{}0RPSDj)*Su zS$44kdU}+G!(I5o`DwgFjw~CEgxfS1YCxULbQuV|XCak&0SZ9DDFGjRrZ_^5DstfW zfz6Gk`4ileYy~4g1^0B{s1#OFCHiQD7F&ieS{g@>rB6wytoOb-6qeOmpAA; z+CHQCtDs(pQ_y?4Uy9K3DZYQ{e%KcJh$X(y3Y!cs8O*H#{Y4DM9gVn`5sYZkqDa5~ z!v-;1_g{YGEtRP?dn1OXtyx6pKEE)+;b9_v6UR!U0vP}35@y_)#6C+re3&xt@8W$ORg zWY1yg2{=*O=dLiOcqGuBZOx2d0S~_ht22%QJ4h*!yw)O+uu<%-{?~_qg{>U4z;>93 z!!v3!Q2k|GInLNl&^zHjilfsDEC_)ogLod4g}4>WgS->Y0o}SA=0aeY6Aal*4`q2& zG0}woUsQc%R8{#KE~1noqEeEAh%|x%(%m85-O}BNw3JB45u{V3I}{L*luqexP#VPh zo|(D-b?>Y-^TBb!+57zV8&B=rFBmWn0rQbK1mA}QMCe>$ySTXgUjhH!8P2$yT-B8$ z*CqQbex0Mfa4WdK@|Xr%g9w#8+H;UY`+N|<1;ei>FL|d-Ni!nNBU_OUtSx|ysvluA zC^27OB~u6W)Oku6yN^#pO;vi~C!#WUqscnq-*n4L@Ppb+1}KTcbFLS-^ES`d^6_XP z*>&D%5TRj#;2 zfz%e$OwY!zR$d+n03CMk^!4@qC$opSA}6ez%KO*Soy`0Jlq|h)xY1&6g@_AdEyjD9 zWzF}1rw4&yL9(QbGplL1Fmuxi*{}*)*r8){eE$rns4rYsD5I171>a-9Er1aZH}nQA zOI?|@0&{$BKm*L!fGy$?S8kZT?s-*M(rV{(yhKEt`VH3p-PwhRNg19Rn^rsTbB7Py zZV&kWJi-iag~6uO#pyao;>0_YW?nAEjzySy$wW&%28~d7xyKP9s0A(;e2`xd2qs=Q zyTGbScSrmo zQR{*a50v@&4?Ewx`}*3)`fleJG3y`jNq<2{JZpM3hhaYbl=W4cUz0W=qeBCzKR;eB zlBuZUiRh+($3(!~44x2V%Tu8je~Qr5Bvf^hcz2i7}zoFfc~CQg$Y zT82Vx;e0;yX#gs&1Tj@t9=m090}gy9FGF?<#VKTOq8!W1)mj{u?#s(wO(`_#2k?X< z^@@GB+241m5&9=%Yb9Z8A)zP$e+=DfhYgkC1gCTl9a`se!TJF!I*bf3MGzrM7eTO7 zOvC}EB5|PM3K^uIahI7WVL|Cs)KrTA<^$y9@vm(`)jwjlowhcwY6C`p^1*LMiqV*uj^0ZV_Pc3C=5S|p%fP0& zbO*Xg#z4M=W+)cGZfG~p{5@o4qM>-d6;tTPAt_;G5r;0)Efb!=YtEdnwS5k{yk8!E ziAt0p%hq;q|9c$KX%_;ELcyJqhvY9GmW#f*e^Wxudj{qnKe2>s=%^Xaf=*6PhJUiv z3hs;N-v>y|K9@Ub)xZ9CA`dGshle(qFo;%xHU<0)_Xo53htdh<5{7&2U-w_+=9xo% zN>}+lTs1*zt6}}_&;@QC@W1uJfUl)p-90>vdiaDB0}L1}#2j^e{1c=MZUum{oF8=h zUUt6kGC|_M6}mS~v5;U@=7p$-6s4>6EC1=GwyZn5?N>9!8(cD01#tZt4huL|pev+; zF99c!@U<_l$C(Eaak56aNFJu=&3A(DpKYFt!IYA zeIPi1o^96n!%u<8@9WT7@0zpM!iVuLmx|B3LZCKas-iBdV{JsnlspaR&8S|3a8(Wf+zMa2S~w2#uM7sZuR%ar)c7e&om-tetKNE)wBXM*G3Z{ zA6d+F8yCxw^{G+Ee*ObL1x%DrDza<&?4O};pz2^#~9&o?M+3QnW&MGk^h@Q09l8{4Jzo_na4z(%l$(7u(IMDR%6W>CAeG= zsu>%&qqhBaO}~y(pyCfy;Nk6e85LdzEp_Nk#K42Smwp!F(_><^Ju`VTRtQFFejUp% z%|>g=o4i;6I8Yni4U;V}4TBLEc)5)CH`7&U=j}ZVD;lB$NN=>=1%o6QL_zZnI5BFf z+HK#b5m4{~TMYw+qs7{CM{w*J#+{uwNBchVd>~IMmZ!x>a*^;DGt%I_o9_*B{m?Lh z^>7qTU0L-esOayoxG+uy;fIRe#biTz0kxu#<1gT)O6hS~&DFeQ^KfRts3TZJ_jT#c z4n}F|#Vb0vesFLhkEvm-Ftb>hR8(QuJr&hB!Pf=XTLxM^L<>ICclmd&C}AC2+aUoC zx4-*xYu}NZo5@fc1!hjdh@q`!2Kg^~x{(T=U#f8+`B84$QUfpt1#XF7#=)Vy_wSI} zp6q)>M=24HgE?>AP!?%?!X)GN(6m{gAY&fXf-s4lV@R;Rsw*2|B7OIemnVhVAkg8COaS_n88%1?OS`4MDLK8BHSx@g!37-YZ*SOP|0IRIR(_( z&9W*0k4RB50gZ6*cAf0a@KHZOT7LyvY!DVU3zE2tDzC4 zU6yW!tws{Ma$K+uL16|8*U|C)VBZSNBq-ud1Yk*RyPDB9V6#wf6aXgZC6Y6}6D=8d zGUqGspOpZrvP-AQNQ*N`TRWk0x(cI4p{llKwZ*)$x^D2FhMH`iO!%=l+mm57#gRxO z;Dcu(%!%o_g?NI6+Fcg4tUQG-%4BJtm-gZ(g8)PU1h181Nw^bJk6%7%>LOT(SO;a3 z1gsX@I45B|?K8b51B>&sHb`W1!yvz^4hPijL#c{@kpY8{0gPl?8+CjbxSMJWHqwE_ zhd$9_EOGqV@~JmeC;+@17MS9p#h#!Pjs9;<19rQ@PMra;KEST%r3r&MPf+0ijjeKn zz&zU=r_bftwKu{E#DcbNwa_+#C2;oqR6LDC(-c46jDGZ$WrT^=8u2@YxbWcnbuWTIQ~2%fi_zDK zNK-{$>aSWeKymrv%uN;+w%hl&OVkbHpXiGEHChuk$@Ut8O>aEFB;df^grb(L64Cu@ z@z9AT;P_D=Vc}<5p(85F;o_LLV&JX3hxB%y`4f>o0M%oVL-Z8L<;aKjw0>!Ur4kk= zyGA!+Xbq$Bk0;72@+_RX#gPHBt4cHe6de@QER#p2y zb!(TaT@0P~U|k42)?h}Qi}#xCmv@~6b`vbKWgf$uW_(OJg^kFMWC)a%N-vHGC&=Qq zTbhur=0onIu)R>Bybc&6qpv^YVh{S$P@OJFr|&XR`JwLGf)yp{yT!I@oDbj^1}*73 z(x2^L0q(*%A4+TvgviLss;d9%E`<^22OXH5O@_=HBD|&~c23#Wq4cpoi0aCYiFh2y zpeK^o0#4AZ@3k&KX4ZXC>HQc0kQcRXB(R>iKB39)$y6PnPo)55`9DuIKg?A z2>>&F4%M{j-9bQM2t;U%F*d{@;RN{+#v)Xx(#TGk8gl&R;KO*iR({PN&ul6pB>za^ zC+IF`c~im&^Tzn9*2yC<6abqIym|!sPWeJ~=Tc|rvqbyzHU@AN1-=Se3Vmx~Z1k_* z!m2xe94zD_1_nolfTXpEgt!?S4WT32g5bGo!M&c@&;ka#M7xhjsV(`7()&=^b$@Aq z$Dr9y5ojMEl~2_sSqUOLK#xs%4=bHnVOGfpAjJg>Y^5?x`0plaur}08q^BKunsc-7 zVd2%(Ho&9MEHe7j7ZeTuA}-GhQNbI}zj%}hWv*@4UL=cIgTp$keQ6&}>6$V)S%1{j zQ+@?j1gUG_LIDUIS!q(tY*-BV(zoIyN!>Ej4U zhJOG$`E3ack5-*?t9$2Q;SU)7MxVV+)#k|fw0yu`z)Bq6MxoJ6!gMdRWvk)nOkFMH z7SM*bP(a05gU12xoyOxa3|a=dK%qR%vl~=Y-T%8GFmYKHc0Wv%x!c&TDWw$pf(#|7 z^{%t)o%AIw3K}NYZCHt6<7>>>fHrAe^Pd|oX}H=tKxG2@PFT<{v}(U zCRQ**l8oY}!~mFW+}?d`&u+g~C+pc`I1DIP-Vc4L5IIQQEf7qDoB-SyK>30LWW!4On-MKOXTmwa56GYg z8XB8|3yCa90S8V3ti-3+g{&>}^{QPfM*WcPsWL39^D=wtdJ_z#<=h8y(Hf-9_z5%D z1fP@ZA3Z!+0n?}RekC!$D+Gw^AngwnrbO-$Gh#@72WaKK2ZzK&hZHu=?}k^zz>{<7 zr#% zf5!Qp2N&CT$oiw(WveGuw8r@XL|OpXjHCx;jBk_sg1wO&KJ}e`-Hy}k%RI^Dj$f53 zlK9>3muh`3EKCF(9zK2%29{N1GK1?oE}}jvK#^beu*?Q`rHeTk_c6J(o67J;e}#2- zY6>a!DXS|WfUZ8tY$bMJf=XT@aO6ORRWqKr5=M*29Ha$Y+N8YU8DZ2e-@g~To`xha zlKw220CEEcWYiXwH;eY317OVu-vexkG2=*_2$mA0sKLV7anQbVLteKBG6U%Ew$y$r z-B0to>-{&cM1-CoqrMyusLlV#9JJ2*pZv}zkzeywGaIQe|7C&}J%1>4P|P}|%KUDl zj2i!l^B#mc5Dwy3hQ;9cCZd!M%EK4Yw|JqV!>JO#OPbU&%O?B>zHj|r0s#wH*g*>p z7Eb#X4}M>itfd6@V$;Wi8l=Ye)u0!O`YGWclPj}Aeitkz6NjfD{<2p88T1GpG4-Rpaw*!YU|^4I zp}1U^8Cez0k_piqv{-el1*ym^G!!u8r`iI+<^SXLuV^$~#pv}Z#;ow!*3x~X>H3Gj zUgO9gvWkYgH-r4AiPOlVo*)%L-~j{-2|0|(XLTq?QPl60(xpincn<}`A_hx29?S10 zp9%xo2>1NwxZ?tk6i&F?=C8-yVPlqYI%mi{HrT!Q0Q^F6c z=-uL5fSn`lY*jYH<(*p+jg8K3Dji2U=Qel%r`${6>@h$h>9m@ZKLi0Np#tfN8YegZ$}XIWIk=nTAF&nv7+MDBd>#Y!a9-}^hCrCDzfUlRRhQ# zraI=vtA44*^VN%WG9q1j*{680SouQ|$nFjkOeY%0Da=uvkCw3{GA$q?fa34AoKC%? zPs>5H_kUrC5RoGq zSFHL7Us4-Z<lUc>KMaeCzWO&s zI@2pmmda0J;@C5J!S`F!_?-|2*Z_;cw^~Eee~BzwX_~*_jRow{ETb z*;LHQt&l7v0fQj|ej)q1rWa1lWP53+_Jd1=h!E<#;RDe6$h3)Zqr?o|Fb;l?&uahj zZfHwAI0t+O@@@7nL==i_w^IlrO>zD-pEWO-qUN79bc$uMPJGyP?i~Lqcz_W3SZDhf zNOEulD`=}G@Y{J0RnL=ie&k^@YncdSpTWluZ#S&sd4`GyI1Xqo!0OH9nQFoK?(^<7 zNQX&CX{v1XS%X9rkX3_(c>ajV`z7);)=H1A2=C|Sq%b7=MxI4!TFbk8zs|p9>!Yc~ z8~y!L&uXA=w0tANqwzJc^i7)=5D73rw(Cr&#+z2Zr^S>L{^Z`a8CB-02L)k<{!`ah*?I8mrL?@x(bAnD#fquB!zu8#f z_ZFmi8q_+KJY7pi#wNWVa1fA)0*N^0W2D-zxpbbVXqlaMG~^NTn!9=KuX^Pz{6>Qe zCnWi6$Gc{~G|KtX0R7p(PrUSRMz`F){XPpQoti0FTXMk0j>^hnV1M&^ty(TgT7IDw z8i{CTej#lOjK`0yaBj2#mi_H`Y#I}>7#e}XK=3l#oPx=%AL1jxz=r{5buf-HxQezf zIwaZP)c(|iyx{6^_0J1~7$=|W$3&6s^-jh^fk)ZW)G@b&YwKn<#vGaPiYest?t>y; zOw8)pi;rUz|NAN=zg0AZFxg(s?{D#mAkC40M}b~jIaRKLGG2u$7!LXK%+>q1zwTn6K8y{^pnnh0KFr+^XJM}Q?4E8ViFvby^Mb8? z0he5cSPARjI?3n38G}+f8|!{*x^Fhal-4T?AMfEK+&*_MwZ*?nJ9I)48`5W)hOsQT z-zJtLb0BHLpvAFDWBV8*c$$>?a`DI*wvDB)X~n>qQxml9Pf&xgjNX=bj0^$<7@z&d z@Gs@&-ADn~|H8-Dm-OIIR?yW9vtSI?Z+wjnx&jK{Wfag+m@{JaddrnfYwgy=h*Blv zoUOtUE!9XZ?Zqj#LeHC!tt=^3KPrmbiS;HrIO=?x!wobocEzq!* zPqLARIT6%vvRb5|-+M~zLaRWp){Lpt@IATMRWu=)4r)Sr(oS4T<;1}*1_da37JVh9 zz8QTt4ad-pmUJF%C@iKPJ*ymHt0w_zkz2Bphd1brbe~3e^TS42X1dD9Tg>=Y;x*LpAF8muAt@^B#SU3U@VX;r zDweqpG*J2V4UX#`f?CrPLt;U5*@`f#G#3OGe}kXa zJw&HI6%fe*$~xE(ZX3s{okj(*)qh%zl~kG$sG(x-=|YP<7k?u8oLJ)F`^~fzfh6zd zoP-}@S-CXe@GdMH(gLd>OIdIv^gIt4Mhbvy8(%{J z)VF)eoWkYqs4yi$!`_|b=F=1Szzjq7w#K1PT`J`QwB|if!a@!}4$nL&?~%kAj~yLQ z4FR0D1@*i7WOhJt3gQUdci$A$b>A`d9f2?msA}A(C&=u?m7*B{ivdUhv$72HUig#y zx9WFcl>ETF$22N=61DgGq~U*@J8buO$twPl!(0)hh!-@jPm@C@Uxi-3n&oIpQ~{rHj+KMW zwgbc$x1w0^k>+Gbrb++p`M0%BxlF_IYxVe@5|*daWtHv|aazPlXXM}{0JRt^e~A?x zaU>;iWRAc`l0_dUjeGAz9dZC&v3p9^+((8;MATtYr!;r8%4caXHm}=n8_>q!lr6JK zd3xx>f^sxVdAXgq_t}!L4$*ELil5lmmncbX=u}HkJP6Z03|ZjWB&w^cL*kK~O~GLD z-2VD?;526ZPDs`|1*!I*TMsiYAZ8X&HZadT=Trmr4>FekxlFaM2C7hr+uGvzX9UfN zf&|8WX+{~IaS*mcoflGED0Y#UOeL?Z(*L1h$^KE4N*r>-K$f$Plg(`cdogGrk*hgy zCdyr>+e^}HgqHH_Jd`O9*daiaqjwh z1?>gB27BFuqi`{*e-*4w zEkinlF(8{uDO3twuZIjG z2~F>V(!qYCyLJBw!OO2kE`rUi_Xae0(`Hp|<$2TIfE~ZyU_o$K5V@B{o0ec9OPM@W zPHQdeJRaYtZ->{Wkx6rq4gvEr)FDXYu*feKeEAm>lF>4M7h6E9ZQHlx7vf=9bzU&{+&P{?(>t^gg-{*4t6m?hx6K@J?NQUQ_WXyh7Fm~OGh z0pkf$UJUpFXaSIUmq|?T9%H=AGvmxbB?1K_Q=bc@yTQm#{ALm-e{#)*7$3HC_+B-v zaUihb8bqm~C<#R)-Tb;Jk@YBCTwg(JuZ|sQ`}&TYI2oSI4Jr^~@pguAITmFyAB;yXq&3-cAic;UxKCD}T6bO|B{0yX!+>8Dp9##-; zco19>*v(2|<%Qc3G98<8?dS?pqJ}0ByYITY`XE$H{O>$1u@wXKs$d8g0sYviT#FiG znxHj@Wc?YJ0rRD%^;1tvx;dSRI7w<&Hruq@PnG?$Ax$V*rQAeag_^kid;WuI&N)oX z#n5m5lQOpO`oyS@K~g0Ua3HofPYbfXFj-G)2x%BzY0nB~GZ44c4z<`h>%b&icGb&>+^w5D=Rp{Kf%F=MW&+@Hkf!`J<2Fp^cnzJ5-#=#}*t8hr+oZ0E< z(~3gg2#7TyVasA`-A|UjE{Ax{Kf@iS$l==+qYg_~Wzgmc8=D`(O=hmr2mg#91SuIj zZCqA3hgCH_;8oe@+;;cI*T}K|#c_)VEgstoT|GeXRn+4c>#~_~b#yUCydhp;#zzEM z(;yq9YINmvetSi`9(~5Qu&9VKw)k}{di~#jUo({Hm7+rfLh((S!XoBQkbj?cvCQhR zS%UzSR)P8QQ772a_M3Mg^6l}`0csBEW3k}FHMj+)k7N5%s{E)Qrhz$P1{Le&AP zOyJ9JxcJ>uDsZ2i%2ARMudN_0qaAf|`7Lh(=*vEaza(|uL;P@1H$yI;QCx+SO$-<`%zZfl@QA>Vxd;M~rNu==dA`KwLP0|=_zS0bc^=%Ci!LsUsLi@r$q za90H;FTQA8I1FHyA)MK<;RluHeRJk(D2K%Ot~NH!E2VUxVX!wb8eD@JRJ@akh^*$T z!|**0&}f391uUtPTOmN1oZMmIdjI;WNxY*kDNWwz4?}m@fPy*}WJ18S36RX%wwKoz zX_{EDgS7+u>Zhjs#`oeZZ`i5p-^^vCVoG0%$xmB1J%(Q*OBjj?$QxkigKwCatzbDJ zQHd$Hy6~UaS-=nf5mGDJ#;4jCzP*AC{#!4z%ZNw)ENMI8%1$vsjgN?rb5kU<3>K35=7 zx;>wk!5<28`|1t*M+Wtp)b}8BO$@XQVY26wLi3!$X<)mE;w{DSacrSUW<)4V|4r0XXJP4QCa$3>5*(^v2!rC-ZlhfBv~6d7FE#zU~@4iErC8(HaJQ zRDcSpwl3@#>qDfg5vJ!4ya8Vg`iXo)l@LE0ZS3rw~L!7NkgbBP6#{GD+kC9F&tXd*l!f2683mC*m@#F z*9mf^00mmmv+)DK*g5}Q)YH5(=i7ZO!0oSPH>-!=x{Z{o*4jF&^zBII=@}KQ{-I7z zg2Xx@=(*5}z5O1_q9b*y<+@sG&Mq$CY&PDTZTi>IK1|i9$AV0_NYBc;j*>pSfqc(T zdR06fJA4;|P4ZL(Gma85AgB=fa-QSh2I>mUeTh2xEMZek2X5On9sD=8&r-EN|auhY0{Ok zSkt|0pJ5z|TL7uCbEk*K`Ls;JYe!d)OMsVvV0gP%=p&?%dJi5(T0G9#BDvd0aV`qM zohaa0k9Vg0p)K&#*!pYbKxC7km%8Ox>?DW@B|xnG>7OJ^QhB7q4T#}VFYRYKMHx(G%}w1&3gV8DMM#mBy`;>nUjknp=d4az0vLq>o9*0{|bsH&e( zbk5aX%x;sRUPHY>{f1WC0p629pP8TUGOdMAIy$6gQ#^IeCiO+>)hD$gHPTx9%zQyH z6$KMIo(RG3c;$$W$pZzJHn3;a^O;^Xmm zIoKCAi%Qg-ohzIF9Ba3?x4&BLO+KIe@jCZtYdoak6}&}WH*Lnuri?~eB`#|+b6lcf zBFQ2rhnkM%ivmZu7~hbxs;24=!zVKQf|2DLz6Q&s)2N;jn#BY;G*9UHnLFkx@)yc; zQK>>^a_@>@(DQR>6^+}DFxWBuD)F`G%j->>ap|g*l71U?bfQ+S)!T>uQZ2l$ozl5V zK99p@reM#Xw0w+o+D~&gxwR7(zJ_P(v>`goHTUi>`rkx!GteKkj~B(yGKtM7#UhHI ztT@-HYUoT^)J<9|>Bv^b#Ak&2`iY>=_tkA8_835TUjCOUMMZ{olDzHNTEiI7M|8x z{8gYYIKZv7Es*LaA6T0&zfSp{B~aR8q^Lm8NcaO>(WPI+sFB)_-B;x)eOf0=V{eYc zZNYfg5jFyYx&kv868Dog4Mi?m%m zGqB@-YbmcKo!KvNAv-!ULeJ!p>UgqzAwnA7{xhF5x!rr8Bd_Ry(d%gXTZbr`E=!L! z2OS@i<|{g>DTI2uPf)v71q^4m*O&H2O*jyE8>JWF+3IR)8y80i747>k!u52lD``?a zV$MyEtq+o-qg{`B4v$(3T)0&g@P1l*hp;6_UZu%44{T1PMdQd_v{77rwRYAeHB<)H zVH183O~Z^lI1^Cp@ye$^NlQm4XSaRMmkg=TdmPdr71C$PULw;-c3f)tYr6mv1nANimSrB z{&V9g>7QDUhFSNMiZ`2W1*ID@dqdn(cqo?oNARiN# zgsTzIZONzM-cUW3)0+8Jb3TQkcjjKvXyy{e(Rf{~{BW46)lqM_-K*($I>qN#7`}4Z zaJg(=F5@+Gf9H9shgVp7XBA{svgOI1$^Kz`R{5oT<6PrtNA{(RqP47ZEPIB(DJvpg zMF4wngCu4gzn0H7zH-Uo`aAYOxnf!-rqM@HyqWNZF5dj3@USoy6_t^eVqx0DjbRB7 z50CDio?C9pxPf5YSXo{UYp(ejIeBRb?4(X{9SI6QMG)zhqLSRcW(Du(Gzo(GcZ-$(!@Iv#5TIZq%dFXNv%T zd8IISte{+JF44PWyeMX|Rc%^j%XuSv;g4yBhRW8bn9ilR8CzACsuDqW2M33Wi>K1k za&Jc_q9>=CpS>ih*^Qar8B62yLam-uSIERHkr-vK>Yj;Z`R%9KYpv##Ra%gp{q2iQ z@o-M|sia5Kta#^ft*)a6)pJXW1YX;`u|E^`0xmJ*PTvV!3Tb(2_v8>-%)I2U4lD)D zQe=aQ%r$t?ys9U<(&Tm!1^^XqnJsnXL;wT^n8IAKmNg$Dxo5tsp){-=cw(DK>nu}`tW(L z#``mRdSq^@D5&JdeM{m9kvqD`r#%=Ym@b=k(@&Mp1I-TPwF{3=mp08+LE-y#Q`+P? z5#IeB1(oH6n!XL^vyzvo%rzXBrBZxKnW~za1<)Yz@$u1oZZ0j*ETKfF69king|&(QgxH zXJ^lcQ&923TUb7S`zG4qY z$F_!nm6i3U9G7LBdY!65xjtV*-sTAF6ZGE2F$G1bWVOGgA~ARrgn3t*e?|m~+&Zg@ z8sYr}V%@#H)-E$we=fG3n8SN2mL5~*&{9%LnL-nYPAkj^D8KC4jxJLQlnTWvBX+%X z=-!@mrzu}m86W8;c9B&e z47IvUVpsVP^qQ%0%Qqq)u!?$IgjM>O&)<0yLjJ=jYH0IldijXpbdwA}@s+P>{I+Yu z`N>}1=>(DLSZ&pA=ZAo4z#L^G1QhlgyM485Y-dKk`1;yRRZ8)B9=>lI^tPdsB4*rp z!o*nOLV`!OqTo`j9Q4@lsN9+Z;d!)~Gb_h?B_1xKR-_tZS3CXnD*5~}Ts15*a`3xh zCw%Oy&4gJIcIJ(xA77LV9g=lyYl3lhOZ51LE>8-$ToJ0*&yU9Dg|5fP;pTN0TJB5{ zzOWv&tSU7Bk+wPA5OOa}X4Gkh!)A8Y`|D)f_wSOoAH?mHj4rdE?#50l=PmxSj92%F zOZhOcz4e^vNC3Cyq(SYqP#J-Dsg7c_YpUx|DZ$6#;(}rq`Y$7d1Y@Hfs;Wy$VpS3x z2sxMo!$L!`JpYQdn|l6fVu2^ydGY-fv8&Nq>W9L@Lg`xsj0&Ef^&L?Z6uYbG_+&=` z{OXBJNnB33uP+ZIm^4ZsuG{pQIeoPzc)<XUirb5B=Maq(`WvcNiai^xHXU=#I_C`X@&}Q@z|c=+kIoD8IvJRbNTtSD})KBZMg;|P8fH3mF9ZW zt|;wqi;3@^VEj}I*F@gcq`M+VyAm>w&w_i%#!G7I^YZfb_aa2hLVJcx2~%^}w>XRJ z%z~>|w$)5ctL`lt$0wvvRdg-ivE6)tjovOQ!9d5VhksGz9rWqqL=pK5e0WITxqd7O z40M0|q}@te-ZpD-UOY=NIyzd8yluh$)C->o8$G$tY;MOD!RG0#3{zWXRu;QQYXKau z@nDLYt_{bHxWOldZsMGLf4du)su@@*^h zRY_^u+_X}YX><6(KrSEh(*D{v{w>xhp7A=~7>Nz0yQqC_GwarLPA<2@r+y}8dhXg6 z1^rM3dM5YMh+;?+OGTX;$Fv`LyloAK-E_{cFzLD(TllUjHyR8z#p0WFe`kGC3#TG; z)qR?if2l13c$pCd{O9yqP`9$?5pEbTRD>&|b(oGV{E(&a;pW|Ge+Jx|xS z%gG&h_KAtq6>)D7r=Oc}w6)hU#CEN8T1%JZF_%zQ@NFE-oO?^gaQzyQBs`ge=WyEF zLd>>mKgo)5Zxi!7r^7E+Ra0f4qx?LnImLL}tT#L=d~AIofIVHmBz;$gHmNVy%t|zw z1dlK(!g>AMQdDf0iIQ_1p8+CiT{4yJT$C(vWsB9kDrv=`hCi_1l1oFN;A&6AJgh>C z@oJiijKlQk6?dcct?~?M~+s zc12qK1dEylJ@Sz|!QQk?Rn_M=56(8H6XS@+t*1BJ(kIqw<+zuqBY8Y(yHC_PZU$`z z410Q9+!XP8vV1G8K8J>mgQjqDpP7gzwrujjBOZpD>8#8crTF-r#7QG1vmd()wnK_L z>R-$>R3uwPlau^ob_6@{D}r$NrQ)&iBz^G=nFo zrrC(7jxUV^&62&U>fCaa+o#I07-nYYK2v2=y9HhkUKe^M9#6BpqN1YVW`84_xnzcb zmh`x9?ti!A1tlI@TcR#W!`VbBFpqkEge?@6E5c4{38w*nrb957$;db}E+Hkqn5%)^ z#6aNh`+9})$xoZZMGcOwJ6Me~ma0a=I^UxR84;sH`5`J+Pyn2-xXd}~eR`@>g+k;) z92?BnM^nr(`PpL z&Ed&ijtr5z9LOudL?k8A-sZV_fsdb9__BmtXmrsstbOPZm#bl>@ILw-1#jyJ&+I&B zjZ@CtjE~zWSkqUBg8*em)AI6qdqrfHUIQz&F)oK4EiC~QC+e2oGWsT9mA@m_!v?mPcaED=ku;MRRxvI z%eJJ0;@n)dY1h@l@^US@F09SbR{N^eICT}vv++4-9BKt!p?x?Qk#FC=?_3(}-s7O< z=2j|c`ct-Ol!bqvY}9%JpohwdZYWv}PIE0Gz#hv=h6^m&TDgL7A;*oJP^ z(%K&5ZKxe}sN!Omsiz|A94CHMHh+`n`mgYQHMNgvFq7>Z zAVC|S`a`SaUYhoHaRz=4aX>9pbd=zUcdEMTEH0eT&$(=ksotC0xm6C&itEa*6&_7h zZf><2qIwszwr>fGk-3@Wi)0zABH570_oV|`H@oI${TeBPa7|bn<6=If+cj&KuzbkXC$sfthE=M+D8UoC zIV_5zs;-{bnyFO$*CG6Pr{|IN<{~Dt^rMg}Z@U)7+*@j$iNG=CNFHIxLLjTYw0s1z z1hk8mD0qnF3@P1nS!rK;t}Vz*Ka?>nPn-+?(uC*ORAo}rvGR+<^rB|%tg^*Gp`ff> z1dpnf`Yu1*G3XoXqcs8v?bd4EI+gTAmM}4}?94hnCFTB^5f3HDV+Yd@c|&&lhx=!| zMt^_B<lsN0C@wDzA=q_1;LnVZi8G*sI8QD~?`=M{Zk(o7&+ftRya7L2JSDgF@zFKN@(*-yWmjnHMd^>%Y<3`r?ZdC(|-0=nDYekOQZz^#nRI9Xk(~^WT$h* z)oSLBHf*Vx5r?@SdB~$A8XQmPpwL=&AFmH~l8-ZS&W%`-I9O9b1{O>TvXXN0Bg+TI z_!Jabm2uhjoKF~KX=05sQ&^3a36Zlb&vLG2fbgp^HSFy1N+; zdwPDTYDwEy?A86WUD;u{S#M)oW2aUh+&)!QSa@_)vsLLM@flh$VzNK0@p3-ODIb=h zrlS^fdV1#J&p&y6wE^ua`n0_-^2y-2YE412-f&u@zxb9Hugg@d`ojvYT1-Q`eoww% z`=0+R&0kqF^gm+v)o*m$b3XcHJ=Rgt78C<$oV1et1?PjvD}(N(&&nR~x`YBK^s6?- zu&4tRYX#l9e42@#nJKo*-Q6!SPe}1s0#PJ7%zP?Z&H60O$IYJ3FXI8gn>6g5RbM+e zO0f*<*-`B8=Q}qXK7KQ8mC6hdl~?R^WiDVL@I(}SGV%@gi=W=vI&faE)_+~GTKdi_ zKYRC&a|hick#ZL2{~F@}5^^Lb;4qc9}27gjoU?Px69~<>!+OFx{>rpGK4&<0`F@3W*fYV@YL8 z44kZ6Xn8%Ot6CDwiQ2HtI&$j@9myB26%<*qjfA~(TnC_ngNyWsZH=Yi7d;zGeQh{| zJCoziKOVhIRW~(dBE%9xpmA{dV6ZzFll3 zT|8U>(9fuqo$sgkBF@T!)OGHEODRexb>8r($zE1m_6=R%qoB=%gr644Flj+Ifysj` zl!gOvm4KT*#Dy;Pl}4k)P5YZ(kqPmM>YH_mgy7L>(Yg%zn3?X{eCW_u+Xyj)J(3(PZ{oY zH~ldMrvEWTZQs_%h5nsmi=_lIt$OeH*FQ(jv0cq#r2t@Juj^ephIZ)JYZhCMVN!iS zPKKF*gD3K`>xI1eM=v47iJvhB8ge^@0v$ZOGSq)vthAUwpVc`kPu}A!2p&h*__zeC zrLhlcC&nCz5vIpd47wz7iTc@}`1wR;)uK9wP5SaTQ?!aTs;WnKd_*>HF1vbEu6-Tg zPMqG^%K^+*`j&uJrc2+UC3oCE9iXD>S&hTp15cVP#hGcYEJI-}id5H*Z>~GnQ~#Zu zP#!KrRm|iuz|=zCj}`n2)wQF|rl`Mp>a&;m@h%Le&x!Hwlj-o$WDOFPer`8XgCoL~ z%DgjMQF3%#$2T=COS!MG@6J1wjL7@hbS6}6_f+(Gp_uA;cdt8ofRIDkU|5*5#oBJu zjwMSyPVnGAmoMCJ$;!aQRVEU(;CyaY&A>8O^0-I6y7qI=%Gm6sux77WP7eAwbND3Q z(+Fnh;}tX!qx%`xwP(gw95L`EtKz3u2CF~JOvO(h`)D(o-QJ#iV(&{H<*2=mZ_0rH zSm+wD%VIhQ#T{M3f96A!`FumRs$!zL=^aHW%^-RPs}!BxgAp{|_8 zzY_Dpa>51{V-lS|)Gu{Y!=$CFO42nH6_07g9qM(2MxR-R*?reF*f*`vWT)mS8b!%1 ztQTic`^>j(T=Q+m(aMd)ivwsj)91>hmkJ@lO8;P8So|oAOcMq8J&BbCJ};_q|Jk@u zv`i{e*q<^e=cPTD1=MJ>bSX(z{4vcp*y?HxZJG0A08LzWtgk%zDr=Ep!jVkeCk}@@ zZsyDXA?!WCn#{VjVH|q_bS!kRP(%cjDj*;-f`EWXlNyl@Qj{KqU`3@VRX}=gL8U~x ziijY+g&q+pNvJ|70g`_`$h_yA>pRzf=9+n5CFCi4uX3;ZUVHB_@HUo3{{v6IrgV>F z6SOr8bW#2bRZv|zd8kyNnF2LCTVxjAhB}I|br-1`U#iL~-zENw>x{75vzA2Jjklps zEVus|%;YQcci)v;^=nh&#I3Jqp!S>ckF5Ki z#QIU4I$G|_($0~$J0TgGLxuUOwjZ_d?T%Qkp^7x+s@a=LE--6{xGGl+j7 z>{1u9N1?{q@!0a|vpsx#UmcsCouCa|Oq8AOEx|SP>$LA3is}^>&$b-t{pVlrZ+6&Q zj??jF3sK3)Lhjy*K(`a z9%u+gdsb+kdTN97RSXBs$#LEOl)f&xOPa9L5SKIZ_JZZXs{B7CE0r!47u?u5m7`{x znE&_#tF`;Ga2^)hPM*Dg|Ml{Yda*wM{-aK<1rFiRNYP}G?x>h+dnbujGU*c$Sy+)t za}CHFaF_$Q*wM~^X0v{8L{+nEGDgQX^Jiq06kPC>e>8GS6yXp)e3+Cvs^Ah@ce^Kf zlqx79B|Dko1AFS9r|hf=3EA(D_aW(-xBqs%7}1cO-*h0sYl#gMO|{0!#9ZSn4f>ix z2?FHr-sCFh{_|Xy0_U2~;;Btej79}El*Y+SbGD(b0wtF-k4~ISj)~yaheH-Rd$?#d zAu5kjPi@D~cXj6!->-S0(FlBg${mBC_td3*lO=jEIGMt7_a)<)*fR0gb(Qp$!RPG- zpUZkQJ!h9%yb!`AuUF-pndhZa{$9-2;n&+AXzA|V`iZiIjo-K17Kz7RBl-Q9S5xGb zaKK7M)=lM=b#p@KAM#z;z({kdObCU~?1z(pD?}>h&L;SfPMck@e#nbzeOG8u5Cv7_ zjC}P851)Y={Cn-;y`<`&y;l3SQf@teUR`Q&B(MGyfz7p|-R@j@kSco} zlwjsXcLz=tgl&b|>4bucGH7cahSzjKz(rDYgAm*$V#v8O`|xgZX9l4djVgdTd#SYx zPry3fladRnOHFx1%>&;r^Zl*=j=*vioZdrQTW9s_^fo6T6P-PlDxZRZ(f!#&R47-O z2YVdOlOxXykV?E1bd)UK0dkhgB|PVIho@5-zca<@GwZP6)EaVAC~XrQ#0~I ztN*4X?u!P>a@o0g`7MG(ay2ru^Y7QFe?PO>5MCH(hQEE@aqh0)fT?ez-Qv+xE?C<* zc-<$uU&*hsXn7gm4F%71%%GzCeeL3Xi${95Qu5?M{)qF(A5ZSl@0RQ@_Q(nbn@Z&r zV?Xdh{(90kT|)ueg2H88C7tK+0m5r1M(4-%_rEMT(c#=lv{aq^;9nJQR&RH1$I%ra ztgoo<-X-bbQH|E!rF=_MFupmpy>>T_RVp@k7p3v4p~|C&gipmOCJy0<_!=9mQpe4z zf`8e4yC+A0vD0z(e(LAczsI;)pRyb}l+ra*CFy>ho*Wiv$-EWYl%XNs!ii8Y)N4`i zT7F;o$COkadifeYAb%L-he(JO(M}aeNmMN4l?2tO6|N6T#X_59YY8c)`fDndSNHkG zH}iWir}lp6v(gJwep-pt5Hrn&?t1Cv(ySexc87jvXMg>a=#o%M5fftzA;rX|_{Xvd zV{LH*?(JPkldmw<=}yh5^Q1m`O`TVtC>`_NHv3HR z&bZPt+lXd|Ze;wcYATe(F$(goLrmEkzsM1*#zTv z_X>DXg_+9xAL2ZNBUG3eKTJ>W6eZ^7VI5SzxA*;_9~opHx3{5RTD`j&_hzJEgA8Y9 zP5+qkP)SpQ1SQE#6=? zdE%O8gK$<1$;!R?#~9`OlM!|tD3#Wl%Ijr5JE2>ni~2S2EnL1T4m<+rYAEx#Rt8wF z?tYFT1?dfYIYy-LS>M3nb-mLEh;o)M8Q%N~hK0B410E?YI z(|7{bpV%4w^Us$FgpVnC6>X>eT8~++-!i_3H8vKvH?%bo-4k6S?)9 zlbl<|%_X8BqQauZRw(bY^PcPw?bc_zJ8MnupTy!d)wMMYto2J=yp8w6mYc5LwQI5u z%RaqSv@Ou(kaCLo^h0Qh%j28nQ5fyllczxtJ0she0<4?bZnQmD%9<{rVqdmu^Cqu$ zDX_>$SXM+x*EUgB{gl?XCBn?xzQ}HzZ6K%k2@wOu3fK3Wf&!*L+aTx>57Aun2Kht` zRLwJqeqyXAflVP~X8{}D{UIXI@%Q6*>^`4AC7Qq+u?;C3?y2Mhm5q(@MW`ZmZbNpq zav}8FySwhGwwtql{0=x{d4`mH?VEPH#pPc)v)iEKiEFVM2*|jL{&%kS_vw%VKS|b4 zdpgcmUj`rrdBskzYtXsK;U`-)^=ciTHFd6NOo>PCHo1*+)W8xe;+23%&^d}ph?dg1D-TUS41D?xQQC^JMuKugbK3^w!_J4%l_>?TMi$38U#n0 zpRb@;2ln-3wu*8s36LJZ-yH2icW5J1hlp&6a7U04E)MmFoUo5*x0`xL`$7cLx6!%CZrwi3147TI*;=c8*=c&MRr1>p2dyxma%@IVUARpD8 zl?6zx*uFE%Jm|I zt!Hr6gAO4?1VX!O01>o);``ZYrohyb)9mth<}92X&#{K)vTFyvIQLJnaZv*l$%5y# zELD>Zc32YbFEq=J28&-V*LD|K{DTpaL-Kb!s1_J+Jc4+I>8W%TV?Cg*t!d!nlL6xf z48`}lFPi*m&=q8eNDBk6V!T%BzTUlq>?#lcJdSF=6?4}KYeV4ntZe6{}$Fb6Ax1Zm$8|Kn4zB1J|cvH*7n8Y@)-G%X0?bH<)H{Xlw;`}O3dVup5WY*IK|bae*EIr+}WK1 z=d7+8P<&+KH@AaRW9!EWh@pT0@gnHCL720uCo5N+v>(-7tpl{S+QGjc@83dRy1?Y) zr=`ugXKWZQ?ARhoE`6({lJXWE{uKcj7!Hd4nJflQC8!Gfrc2>XC(r%oXP`qMqwXax z3x^nloDmYL9Xjy%(0{IGdbjRqc)HU->60?PJuH869z*EywDx3ZiUjeiF#Pc0&qHrs z20=$d!#{&~+=ILbr{qC*-kwtq{~BW1Z1vjzyo1eTTME*m^FQw} z7pQpd|9XceLzG`em;&|mVDlwboj;)u9AG>1#pBOG4cjWF;k7L5L-j9TzWhH&DG#;r z0@2L29*yC+)t2S`m#rMPQrKA3tpZaTao+X$S=^?w3%{6n|F@c22ec{vH$!CD6Z=PC zltG?mEAa%pESo7OzYf6$ZK1dSHUjQD172>Gd0-U3zl)L|?B?F{KZbeO zdz7&`URrc>0FQGbt_uSQlZFir4*qYkfZ?(YsVfJ$$@$M#@ z2^*)!n4<;%jEiJiaj~AZ!fHCd)|+hW@6k^xWHxW>zE67CI}Nd{Cy2p(&g8FL|NUKd zEK;}Fop@uquSj**cIHhi#PI3DLX$Y6wxRv;44jGhA-f_5MQk$8YdJ>urJ#GaZJOdZhF*i^A{rmUl zp(8s^2neXk$jHFNKP-0W&UKaEres{<>GUB!;JCh?o*gVfKtB4ic5C7tq<8q(>7!Rc z%g8HL?pM4Ua3J+QM>vculw_*h+x4RKsU|cQs%vPFKOQ@Ar1VZ-DxA;6$S!eJO-*gY znaTK_-HSfIyb?{x_$Heo=-`rjdu7iP7&j17zu$i6S}pqk{?d0x z3i{fv>*#ZrP-fe`eoj+_a1T@{LR1wn$R&*3nQo_hFxjO%wc2%<3!mfCw9Kjx1)O>+fG3 zeby-T;lsbwZrr=~KJ!}a$@u0TyP>!^c45;G=dwLEI7k zsHiCWu56uf-n{PIL#I^i2-SX?*rfRQJOD#R26^2z&d$y=ukMI#Uf@y|@~!ajzm3iw z1GA#K6!qN^V~wj`gIT*}1Q)k{cQZ0320!AobMyBno67%b&V;8hIvZ`^k_Uu>zP|n# zy+R=${K$@Sb8Kv^zt~=PjXjvj5Y*?zMfamDEP3zWiw!m_2|FUAl7R36hK$*~A`0OR9U<dqAvm8S*Mo0nin zxDvBB{^Y4s9_#b<--+&xx6Z97XHczcpEB;hiTN0?O||^_<}bRV@#ek-AUYWo7$2R< zJbtbGTCZN%$?QKLzJY;)5Ny%I_pj5Q8bx{=`%4@*f!r4U*~1p!tTHqz68it(t$>%$ zPoF+L$Ir+2q=GA9*s-vRU#)yHHR$u;-2JiWjK_Y@Fd;OEbmKWARgVczU8 zVPKFAy6zSCZ>qiY?Z~lXH@46SL(0F1F5u<)Cl#B)f#LW?uqJS^3> z0gLlTs6opzQy{ciOAcyYez_IVCT~O7h3uhkX2vy>DFK3k7I+i<3y%4IU>aHEQ+Io> zS7+P4Eo55$vawP9=&@q}y_^gnlaNRbe0+Qmw-JxszXdQZv9udK%6z(S|9)t`2k@S} z^RecMaSD3-S2&{H99G|e0bPdH`ZL19pz1G)dX?+< z>J@1zCzPLNYbPnNOW3_Mha=QorE|W+K{M@uYjSkbR7yS!zmaj1uxm@^Z*7OjSh&z4 zr;ViBBn3Yt+~yP)ch%e<`TT@k(jgMF3g7C}@tSz8Z*0sy#~Q{isZPMq{UXSvOWrtl z5e*FuNod+l4t+F^oN-I5|Bx|p-!aU~w}r3Yx#f)vsWU4qTEwh>>UXW*bn%kH^B)rT zgUR&fSn;fObb-Qk#*Q4b-kcEVj& z&aX`m4ko{@l_fsdeBn%2Pmja$%n(>r%+gd}8bl%+yhoX7E`%>|$bso-JbC)H%F!c7 zybl>W;|5F&4JRltmRX&|A&5!z*^RYjUs&g|P9C3_@FllCfMJ~(EbD~R_=ZmrCPi-d zy|*6D&p$g$>dvczc7ZiWfc5dQvBaTjKVNT-o}Epv&(|VZ3p|`q>Q2 z50*GJ6J)!6SI1&};hQ*EKjJ$cYm2sRdObpylam`*TNw>QKJlWEx`)99XEC&eB&JH+ z`b28|2l#C=h1z1~>FJq9G4`!Ak&7v^ZCwiGlo9)J0Zy}>>Xq{f&cOu7z|AOV6aN<6 zqBm~YG}Q^>`%H673)2NJ#**cd&W<(ShteJ-z1d(k@k)*8v-5OEk_@3^w0Z5vD;;{G zKBhXZGZD{(XXQ$x-IZS~>F!598X3U3BuNVh2>c)bm0D@}HF(~}#s)1eE?xympcHvx z+8B@D;;ZGyhqdNJYkq}mRKiM+c~LqUMJG+;3zJ|dI^(0F#{A|!?cTnG*`VOpa9kH$ zlb0!YluOrBDh9|jt&i1YdI73tcx}F(N4ne(66lEk)v$eXo>OxZ`?RuaZsC5FHD;0%B`8Z(uEQG$keq6Z1U- zu!&lDyPe+U%YO}RlWPm)@y~u|S?>VVO?bKx_QU1a*jV&z)$-@T_m+e@7F51V&4vMl zOy3|f^4!gdX^Y?{M)VGSe^+M1&j=<9B39WXAutNtO;=_$&ABtSRasaJwz zC(XIyd-9Fap!ZJ;RvCrgpTNS}1hDLFNV zQ)WTkf2SokEumKJ_A)9e0~WtX8{;zSE2<{`c3ObKNIeGS?^E=BV4>fB+7EuznMBje zO7M)MqX0c@kH#%|FDQ6hAvmiXKsx2iVeJ(Hb9~32$sL8>Q9T0#gEYK>W0>^pqrj}v z($a?8KQPqJq^=4t7e5H~$a+LrV}w=Z_W4q?j<2DW`2UVbnuM)6JM#*wb3MpC#3)1a zSG;~OnbhSPdO?}n1PS@n2{)7}kaYhBY$Ft&zPd0$hW7O~ zk|8}F6+^&#im~3;xxRss`^!RiH+01=?A8olr=xfz)b?ek^Bg*KS5H&YCm%^`fC5-u zO<)lOELW@TO2v*AcWN!54i^6c1Go+z_0fqkVb>fd8UCxrE+)OV_a;4X94h0;sHmfF zovvLA27a)=kBYS(D0SM#LIwkbh#n%c!vx$=HYU6nDg#LFUuyDctL-;%s2$}$UkLs6 zlpBgb2LKQN^;Oe6Owg4Bz9NSz1roh4Lop-FRAb-jm{qc@w!;N(++XIF1G~2ekQBe%-a%^=BbPR42f|YqNuD!pQU;PP~iGZj0UBEws zS}u&jal8Cb7ih+cU5S@+3g;~_uR`!?IgLu@Q)_5wh`Z=>@6tivpU(vx+LD#R(~DYU zHJ(3zP8&nrpMS!yotXdYmz7!v_+}wwDP)2}`Yq0}F2bUW(}03t)oR#9?l8Ea`+yTp zP*5GJX28)k-(1ic ztq;Q=H8edm!=JJjiS^PhiQMSP6sVV!(6rSFO1B}3xH^%9?|o+(&v^czf}Gb(BrlAW z1_E`fd~q!o7n*FyDI=5bTt_YgTq*74U`*-O(`Hd4^%TSd-d)GI^uASsrwdQ_TU!$3 z&wVf&Bq(zC?x|q%kk}zK0x}%n>!_q8(ZJdh4J;qIa=5QGr+{ja ztR`;T(iNsi%Z4=>B1m0}M4Ta?i3Bu|jgg6o*tOXj#X-%_Cyt4jW&;bc387&UYw&fD zVTrNzg|*0>2>>y6>Khnz9A4hzh&BD*k)dHN7As-# zc}EDR%nk4WAZ!XlUCB(j`T1QC4)X4fLJ#68<>etSRM3OJt?kkyBags)Xaz^1fL<9&1KM8K9W?q@<+&S^n9(3>z3okUXTI{Nt}VxRq!5d>f`7{ zpK&(%#WrPHin~kY{L`9)&Q(j5p2ogO|G2^!vg(Dk^l-Zj@0K&)u^%HRO;nEOfKG`^$=x6N>5@dW;zDT+&0MG7ZZ;y z4-Ahn!a$Qs2b^^s9CB&E%1wUS&7-HJ^q}UcuI>tLiY$MOgCno2OAjiR2Egw&MUuv? zx6X0Ca}6!(P$DJGfv5%9l}8Mr6`+UGN^Es|7WL#SNTT4p2(xHX5$2(|LDZ%|Qe08I zj~K)B0FFFW)XL=cy}UF-BZh=Q~S+)7lEB}vz2K;#&M7Mc9kI*5e*yhBBn37FYeyGd+R%MwxVlc%*O@TP`T#P z`Wk_bh4ShtYYFn@QB1Jgy?cm zHO%V5N#A>h1rEId96Sk$Yre)GZnQWa@8YmH*&hB&^zsmbb#B#j$hv0(RYQUMnTOAE zUW90ecx(C(!tS!K7w!QB&6RZQnf%J(a^ORiFrF2_YGNk3ybHr3Kl5R99hZ@wj;uh0 zsZ-a~)c5!Iztoy4Y*Kpq{CS=5!dojT^6ogmWTJZ#U%y61aV-Ywi|L^XL#Qo7`T@Mb z#m{DBW>Os`WH_1R*XL@_oIc$KX3e>{LCS@$X{h7V5k%7qjk4=prciG|ELL=N=snN> zYu~Ag7eXlk=|w^F_^MW;^ItSQGJDFJ<}}gd|w5i zYN>o^X0uG#n8nP(BNc9urOfIMbRVxy$GO@^ zkaIetKt8M-O zm{b)fNWSycP+Mt=m;e>A41VH#FCG{PSH)u0CtPx#-fnJgRRArw#r^{oJq{rVyqug0 zaOLIwCMG6iC}oDr%gbNlGLn-&a_C&Y{=+FZUhwjFG6hynWBY#EL2EU@UMG>nr?tx) z%Y=pRBRW4}6?WImP{l?K3GYbc?4|>?c~G>)rT*pz*Ay^WCcxYZKbCs9XQQKIqN8(* ziz6TQOm}rzUi4r0){K|Zge0*@=z(~MOoZ2k*+&Hle_TnqRS5#vQ$~)S>;P|Hg&1$1 z10k6$pGGQ~?}Z%V4s;9ow03I%DzaTovW6~WjZsvDAOu9=Q)OQF<%=B@R*m1jy@cka z4>~o}?n2&GbjtxA1yju00AXGJMWR2|zUp_QkXU^4y)x7C`L7%vpLVjWEYqQy)%hMR z9MNnO0mLhUjc$K=_HwAOxLCrdpy%cE>`h%ow_CM`^Z+cP*XIS?*XNaG(*tcREK;Cc ze>?=RBgLr5=I%((UP!pAu(~=q*;e4Xv2iY-RAWfAoGg6_46~0C>LG1W8ifka37H9e z1RD9XI3{?gvXra z48T@^At#0=-T{|=*|TDZhp7k3)Zt2jD4+*ToXd=W_Cfe?;`zOQ=!(gcKqmm7#4%&@ z_!=w(hejA!9HUiBbSJ&5=Qaci+_`do7X>C&d0##Ya2bDufdK5s+wUY~Oe}39BP0K5 z>)Hhb#KO3I--wc~{J*@zvv1$NY4KbDvf+G?FGtsIO@QgDymsOL5*asu>4A06(tXoQ zpoVF}B6WbmDsU}|QYY4|VqGbsli2l$Ss}eHh{db-t|nKh?I!&Kf;OyPS5wmgYXxMe zr>lDkU<%l(t8_DD8^|gRHbo61uII?69A%~=m^SwfN{t@WDjPTmIod`>M!(TjCqxed zIkri`v{@S!7+ij1`SSZ%hpx8#7;_wzLKc7`4pQ{xwWaA!n7Xbs)i5SGh1$u@CtwIY z1kVUOd^}?u@F-U2SY^` zT|0JYg2wR(zK*bhax))vdG60Ws5Ey8EPk@ZNF^V3s~DmJ6zjg-hSW*`W{B;oJi1m^ z=~^;lkxY(T4QC9O1oU|%z_uHBB;figO1f%lIpZyUdLZgb6YfU<(fO~A6z~WvYF()l z!mS&oO~Pe3f|iU>c2H`>$$4Rc{C6%V@&jk+d9#goYkU~oWvi2kZ1C%L*Gm=c;8aw~ zZcTU=^kREGFM{i!YEB37*LmtD*$$glm#sGo~;j z2-sA=b`1dz2r`8T8T{ACpMiZzTO-7j_-~-iJi4RUB~BwD1`Xszu_Io@46`BH4a1<; z0kJFvdh3VWJFkItwZ}=X64j@)Df21^)yvwdWa%GIB^=tLKvb z$ivFWdhSvLzvc~u_ddt^i-Pum%48Dw5JNyTu+uDT6eH<5)c~sO2L!C;`yB#@8-^Vu znB@LEAOjKsrAV>vlTMV{U$XC6S5km9BW}rbZG93)k8`01ULA;K|Nm2T{};XCf0G;; z_ONX^6T{+tA@Y*7Mi8OpE`(HNmw8& z8l7>l+tQ_=%d11<{f>p*u#OrF;ELF1n*;}ziBzxL2IAk5w!R8ec?Vz3up{uTU=spL z{h+AQ6p;gIh605>ci%2G1cFPT7bj2on=)L&kGNh~1gL2eFYRLB=~+7Lw>%Uu#;mZC zjA(X%M9l6r@L4$Ux&A6Jt1tjnw?tiCeH~Y|*r5UO84V?(q}8{BY8{aHwUKHQ;e|kX z`5Qa_*S60aLpD$e1GhlqZCGZ1l@65zIo_rE6|b<-`+q7102F*+=u z)%E51VaZz`8Pq!9w{mQ0T6i^$>L_ene3jf&u-`4{ z!Pn1}p-fOcgN_Mz0CTgB1J@%$sO|K?o2j1L&#AO{M^+tCD3hu+I^fWt+`?+A`fLro z+NSAsN`jo1!|LxbP{`*|Z}0?uwHVWTY9D1`Ryz-CUrYfZ3Ylu%EG?%JMs3B7;_;RQ zA+ie0)G16U8)~lYe~8k>ErsSu}Ohs&&#* z+;9hy*axT*W1Mq>0&T^Tj$W8(U1sjYW@ldq{Q;3nK*RyQUgU?ber;a|pNrgT$c!O+)5f@T zjs24M<`4s$GT>aH$*j2k))^}9zXTXA1*&Un8-i&=MVRbG@Wb;+b>r{72n)l%K5xmZ zhLl(opv6{jX2o`|#+&1_AV?QAP_IFL+?MnzaM%+#F(D}_1E`|n#H?Q+48uS*rN_jA zqcIdKEF`1>q5voc2$7)@%z6UkUm$V8*_&&ITm6^B>2y(((rW-)I{9KuPeLBsFIsus z5s#&aqTqf{53r>I(`XVd`URn`AhFvjGsd6HpiHjol2B=O~KFT`b#U?(ekWBSJwK%?sw#$Rkk&q@!_-e-&g>aJ$7W zUF{O`3PLY}s+k2q4R&q0giKDbO4{)K1x@%;uG!iyrg|=~!4nOJ$_t2nmQ5rMOv*aL zV;~s>&|&ad7{c(`IA;j9fVX;(U{kg7GYI)J3>&3GLv-iVDM3LsD5Tqv521z7iBP@q zEXL&$zOBnEzVxCY8R-@2ukt+Sa2(4ug3J4XcW}zve0nuQA!_L%B!a~(*2On)iK~LVEsP7^b z@T!5~0J9wHL%C?N!KAsy$?A;!G7@**9?~CkEN%IoUeL(hqN>5+GLwffuQZ5kCj633*Po1E_{0cn>BRm2| z7FtJN{}*BnrzXb#`>{szE`2!cM-8>J?w> zhi!qn0SbjTHP>k)vcM2XVrciFdZu@R2kG6oaT0j)+a{ug!5B(eAaDf`ng96X) z^W(i|L`CsPFVXc&g=xjdcn?WzW036t)-4NOg$S(f9Q}+RXl|=9g%Fy(2JZ%fmvH;9 zScI^izq~Ns+zGQeh#i$$-+BMnw^uZaqm*6D+=JMS?M@}Yg#w%AtFEQh1Im;wA207D zQvI>nzDY^in6gEvgcnDJ)=Qe0kRgT)$;f%(I#eIqzH40=d2A5oA~xU6kkZ}{FX@A= zhUAV0i$BW6Re+(d5unbG@rT-s1oyw=wnhrDC{3ABWRS0HHbVN}ow6{|v1I4B03WHZC<{TRYBxMQN(`{3#H76kF5Cb&rmreMSON|{ zz1xYMiEo-vQ-dT9gg#m7wfs{)Qhx@CN;nUG2$&}MI$y0mq5U^1^$@!GkLjBG`DK9F zkipiYdFf3Vnu0=n=i^oN=(4_Sb*e7HHQ1^s{gl;`8qdt6J;W>@;q(NhpJEkgZ{HeH= zt?l;_*Q?H?R2q>hHEerQ=tD zJU2uneAExw+h6t|W?^W0_X#gyu7m3Md2-TDtpnuoOr#41%mpw;g_Q(*i9`1pXzDRF zH1mjp$=+e+J6*EvNdzAP9lOEw^Tb;&)GAgO>J}bKu8N~NcScV*R8gs$DGL;;jD`pS zGSI;>?&mIs1O+V=Vn{hzq@@8Dmr^szNH{P3U^)7}6vMP?ulo_eo{KEse8F^pnaqUj zgYfpiGCq0oggO$)vrc||17sEFvax8JAmTQLwnOaxncXZxSZVtXUNisIvG8<279_aY zGW~az{|41&c@>~e;~l9rObB4@AdUiYr+x_}4?+7b9$5O~dt>Tq9YFb!=wqB0!WGu% zloK$tRharSA|eKN@4f?yOeK_)C~ACG4UfW9{Tv+~gYFVj)3cC>5ke5g05-V_Uz#Np z0!;%fDQSJ7b$Pb|1&pxd?x@N%D01)uAOtvgdULmNYi`9nkWnQ(r)~lYfqLm0ga}%{ zKP5#Mb}NI3E~19?xjkBHz3@+rY9Y#i{w(c6*!uzcq|2Zo3lr1_<|t0I366 zn0GV$umjZ#`c!Vyn@m1IWe1&+52rjqh6HR69T?OgzU?Chrl$E&H6BY^0DuCK(|J+Z z0?2IOBiNmgkq2n82?N{#L!6P8jv=c*)(2TKx~C9owE3#?g0(w)lNc`zfaG|iD5vOr zeDevP&K(TNMehct11W5*Q&M~P?sXcLZhQ66M@wpXJrw#sbF*aeQG0N=lJcC4YW`f{ zE%l2wA@ zh!v$cutJ|}D#Wqdw#UG1dFJn-SU=W*n1aqVo++IqHv!TG)zci~9i@gmZ9?y2h%IXJ z%4L=edx{*6VsKi%)Znfky+3gGq9*Kz@JWto8&+mwrK`(zsV*#}QO&)wtKsi{7nJCt znowA?W~Z0!UudRdjV8+_3Iscsw>-xeO??QHwaW_aw&V;YTg~<>SzuJs&=RE0` z#ivS>T?>P5g0u0=lwpA+`>I{U;JvowtW%v`iyCRhb@4TwEP3itGOPJge%PxW9QdX7 zvs`9`hJco5iJEo>?Hn~-uSX`xJs7{7RMM;@Jv2|kxv2!5c6v})u6){i247AjD1MuH znJ#baJxiAE?l6$cZ3#pv;E`=WZFx z&E5`SzyDCH`c|fm;@eBYoQ!rcsnY&A=d;Z7jkMtzQ$;CVO>UCwizvstM&cY(-a0YC zvH2YI^=K)Fv|;K(NZiHozpFl+y2V%J-N33@J>QCUOe>x*P$YkCuyj}ZhzwtGrmcUBhRd0M_!FL;mCXl$n`OwW1t+|-Z7HL=B+ZM$@{AZc1& zSiExp^Xtg$H2#`H?V5Aw^7EoKAMr3nyxdw*)o?W7@BLMUrsu!O*|i;au^E>fHc+%M z9G>Qil4)PNs`V&M{%$D_H9kM=W{Vjg`o1PA+oe7`lqOcSA+nCcy`^P-#`ua(;Rih) zcdm^DD-Ad4c%f_ZG0w6zi^DQ!I_2!X2Ga0P7i!stm15H>y|_&q*2WUjDq|nncq_W^ zs$pu2+ljYAi?3Px)|>eSEq}K7kd+!#wcNF*I}5>e7TBu-h57ojO?b#Jvsj49BMUwIt)>Xz5X#AxEiO zU!~r*Es?60SH%%rA31;U3mqQ5DU34g#1LmVoZA?s2MaKM5%|-qE_VsOxkoUK|CqVW z=lu4%Et!-X@l;)(T-|L`lK$zPOHx`42WQ2P0UV&X%YFlIy}DXNB>>O~Q{UxUj25q> zyvR=j*7ORmk%yBCm+?m>eR%Sus|#dsgk%oq4`yz|!`T7C_^QIy*+N~NVd>p7zM~#- zJwMJ@`(hRPdYg{rjgt%H(>#V(-&I>S)2}Md#4OQQUX8b6P~=ATQpMSi3bXot^;yxa zUa>uXOVz`4^u&oUWur%HgNtRm0@AdWgD77n$R+b+uf-oSL?su4rLTT*xD#gh>Xq^t zPW<4(eMzZ?E}co$T$zeA2fcH?JsQlF`9IqYNhEXHBVx4=N2LFB5PGBKuDT=f)}@N+ zwwb>K{d!OH^$;R2sMmI6n8jo6+}_AQW7MKLvP^T&Q(83_Sj}Yl%=9&e@g3BXVN`1F zqG0*r4|R4o;*ktP;r2C+N2uL)&Cdu`Y26$>17m zHucJ>lL_g~)RihKjmEL~G3i*NTO+Duj_%*{<$A~X4u2P)ide_>@xu0HQXYQ7k;Wpo zy~oRGt@2ZeVw>b4DXzK~!ygPRznZ1d*Q$+gl=yr5Mf2}IU}=(LJ5#P9p7|!#KJ0$h znWK)Qs;#6_HxYmU&;wNeip*F4 z!=lllx|=%L)9FEKP3&rPP^%!|t5}k9lUj^~V$y$A@A8wKFfenQxV&myq`Y8bT2PTV znCk4uMG5IFvHVJmqpz8Fhv#Cv&5HQL+Dm9(#~elovmGv?a=N3o)oBU2J&7)U%kxwe zD)0P+Lh@{GI*ChTams08^39s`ib+hx+-x}+7ej2C7pJsn<)B;Xi-S%PlH>i-T~m{n z@W##YvMLdnwQ5g|T=P|)A=Wl-`U~o0pR}ILyn|*~)EOM6pcNWY%}Q=eW230+3xy?B z?=98@h;Inh`58iZJ6r2>69c#*v)23rD{C_*S;R>Du*r(t%vMPT&ZygKcCPV34CTcu zL)N+M{vrN{)^lATl=ZAZI}Eo zQr-MW@|<$+X(7r+)rgbd2O;wb^JU%XjW5goRj#Wo2AJ!^bKhiSH-6TTLqh6qbbFz# z`UT5f^nb_-qtXO&7K>^ou_m;fql`27Eau_#3U2e1eyN_Ize4$r>``7PmHj z+rF__YI^O~p>99wc%LmL)b}BcRzKfd+F!DwO`X0RCPJOv*Q_wN|4O?Xi+vNymfX!I zN`BufSWXw+@6lg?Y2#s`>tO~ZHJDwl;!9lUTFOlK1I}WmbPCj}R9b7imLDv!5K$9h zT~oHG_C$MeJwkau4QF5HXt?IT;U?Va^3^nB=d9jjm{2s)fe=Eq8-HFtKZGk8k?#AA zr(&lI<6|}qYp#lG+_Rs4ljP9fiNO!*cP?g)gv@q^X3ciha!|=(LgovzUUgkCHXbO{;Mo|pcWNouqDvEh;9{FOaE+s`L^^+n zVeEMWPyM|7MQ6Swy0+ik5AS)dzfvSkaWwlE3T4D&2|mr;<duGaTJ|+jd)iHzlW@5wz0wez6#=94#7xZ;1QKcum zCdY7?9R7)QL0>hD|LvlmBC9TsW^cGm_KVvR-RSBC7#G=v!A>}9=u}nF?b5HouSmXJ z{Mzxk{N`nsVw74l(|Q`pmlvXp-IrNo^$UIunvV4>>UyFd?VtxJo%+hjsq)hMy^8he zw;CmA!L*9>0*;cDkTr#y|4`VfDl*g?Y6T`Avl7A+P%HkDa~lgDhd1V;mi1rVO%ze| zcIDH$sJ#6m%6|Sff$&au29M&ux}WZzg7x!`>L~Oxt>L)mx)CwEVJgJ3`Y@C9({yaS zCp)R|u0wBmfb;n7rz^ztAosU{1x9T0e?z=K+O~Sp@3oeeT(7>@w|ufk_`b!uEKNqH z!F)rOJ1@1WByQDMehqt-cI){ut|^I(_l8NGso(6&->;q(x%5?ARN`YcVQKD!WL*%sDcsgav|QL6@3YD~bI=C~M&cA5zsnQX4$*4qec$JSxG zcoZ}`MCs?1H5Yf==TGU&DB2`a1s|`y11&b^ZO&H&{p{SIY7&bLB_AY;oSt=__?-@J!Clr6TiolB7_{y)pIz zf7)MtJ;Q~HHC@-1D>F2TJ0emRL|j(P9O7E`UVUav-V&L2gi!F?dBM-b+1kx=13qHD5W} zT4IGRk0y>4c1_q~Yct_^dQ1f3DccI23dwiB>6pUmMKuBGBs|ob=lL)ar>nVZ4}FKAoy6 z=|K}P&Jm1TZ&TM~x|w#ZmG7$hdVzwUw|ZJz{cy64O8ZY{O4Y#W(L~(hdWWK2@5&>` zLh3?Bxup-!5M^eKTK9nD&oTZ6>(-B{kJrJz*k~5kFt`yMDKxy;)^ zqZ4Z7USIY!D@?ZzoEs9q~FNOSNZRS<+mFM@Tqhz_)e}(lJ#u8;+C5}Iiu1>2v zN-Y0XYR|auv392ui+p;HOLdxlQN7)@T1pW$HzxhXZ?T+D$u06etFkMHQPEVY!Hti` zzQ5gF3${kRsydd~k-INy_nituAI8S4uXQRK_r6buGx_h8c}ur>t(nPRDx>F$pR^RY zVA8Na=#&#`iBdSD(-WoWh@RwbdvY###;(NTRrHaowW8E3kNS=n@4Irja`vc~y?4?` zx#N`|nNj-P;gUCAY0g|QsJE`u;OKF%RQK|j^`ApQKjhz}F#W}+{%(Yc7ZoI<4ZnxQ zb4y0XQ?q)-Nf zkE_}9GuM`u=fr4x&LJ=KB7woRi8Vz>Ik>-QN_X1M)v*sxaEM7g${~<(jaFjQN-LV7 zO`mO!^Y}sdbxkcc?xJI~joG==x5DJc<6?!A!`vyIwq;Q9^c0Z+bv`_`QDBpyS0Zt`BC8 zcp=qT_6xi#9s!{To7_YlFQM2p42XG8XY1s8i1laEb{$x{B1=pQ6f`0@G-NTVp2=@( zoaIw>7eBfL+i>6Jh#8#Oh1umBqHsT#pD$)It%jx_r19Rp$0u+tg8|vzo2%gcQ~S5d zRO%RpQ~~M2*&JJD*wp37r)o#2ZS79~^>@|MT1By4mdp5)grJ#@S(K2@?>?3twD&z9 zUz0J|seb0g0U`QWJl8jUtOtD~LyPA=mU}=%?FxESdYR~^+0KLU%FS$EtE+=3_NB^9#iotRb=~1Yvu;Bk)1&>I@=At@gs$B> z`63k-FH)ryv5vFGyaeTUxs)>)f-(KM0u)}O>(ZoaNw~f62v>EkcD<3cm z8|0^fC5`8`@aA7HSwHTbo0sJ)k990mo1i`=y7*u2RJqo9IG#_?M6u$6Nl6eb#rN2> zUP*3=Tf029KGoP7b#6$}&2_0&#K3d<ShOsViM#H)+B_feXU=yB29%%c2m#m~Ngbwc^GG>;4#wPqK2hhK<-sUqDdTD_;x zP^ri0q~l+ph%Aa}cO-FNL`@9+W;SyrDe}g4Z>+P)idtfi-CrwdjPzS6`mv+PWr1~Q zc_}ToYSAlB#<{vs;AK13d_o}?H|uP{aIcUJJ$*C=TRq~R;QM)Mdp#Koz{XtKO8tt< zy2`Sji011x0fOt3FGRF=k0bj+mWfAe8NGAryp1Nc^N}pQ(qu7gzg9Qgx+1BNZNl1a z=AjG4BQ|nqn26~J%w6sj$%NJO$lXYN`zcT4Q+Aeri|h(|v9m#}fpC%+c2ik6IBk$@LQQ z4^wGHbEzn@G2xR{u%Zr*HnCT8RFf^mhP!b0@Ou0k#TwDMe#if>y)%t!;#}i+uJ^E1 zJs?$4SwvI_vfLn>7>JMpMGCM@5N|h`4}oiT4>MYI{yU-7ojc%_kU!mv`P7CJfK-|9ftXyq?O`<6ID? ztX#RO_TSe|QQKlX8#=_C>WIYE*b(RMdi7H;A(iejQ#BBu%VB2JbxOu`3(qn~zaF~R z`>4UqhHC+TO-jaMU)n8s)1&#GM%e#TQ5ihBly|6(t)eT7>)xy1AKeQx5JhHiuTMVp zrOltGlGEQYX6GIg*%Je&JDIOt;Xlcq5jg4hn&j^YFN~$XN9Du>I?-vu)V!k?bu|0+ zn1#ctR?#o6!2g8d9y0#9y#S5BneY5;pP7-uw$q{X^A-!U-1MC33EO*NM~|-U^7lXRf+$ojTB;lZ0_EHS3?gQ#2E_h%Fy%E>ji8m&{Azh9px5 zXvU_E$EK$09>=~6mmV??8`~ZdF0#`R8OV%a9duqrK1Zfha)C5LL7|goZt%n~zKR@o}vnSTm**we9?tUGW zBOH;ayZUheV_|I^YPL2{i_!X8aa5N-O)SFXE{ z39ao3h!uXA) zZ8UO6GlyyGb2l*7mE94Dy=H=p3S>_vJU}4iAnw@E&}$a}bX>|2Pkucg3iKW;dHQYh z*Xb!&2L1VH_!1Lay}{V6O%huRwAgP?p6_HZ1YXhjUwTy(_y{(X_9yg6PL{5%wO&xZ zV7L`UyZV}3>EeQdKuVZALq@&_5Gjb>K?VPSEt=DBUWOjc1f%^%703p4{H!!4Bmp7} zghgl$rs|28qI>XC+85_|B$e(D7tOIVFTnC zQ2^;2US+#GY)U_P2OVRXt-pN&C~-g)Qr5L>(6UfB8;tD-n$LJ#K7j6_UjhgCW@5}5s=BBr7_Sx;T)XL$-uMkyMe9+d5@yC#nV?2-xMA<6!r)WOYF&0NOX@gs~E881}rS z=v}_TalxO0@%Aw9+DLyrdA+z2IbqnPYq-BY3mMb<6Wp_LUtR{s<~J# z!FTOabz|e?Eyc%APyJ0&_EXSRo16b{iYA|8vHu@Vn<8$>hdeN0X4EIVDK;d}Jo^I@ zx}-)RP_TAr!j4R<6#&L#0bIB2*vz;UiT-LI>^_LEKu~0aep%z3NCv6_(OO$}Xysv; zYn$&X(-x3m*uylbl?YAKO0Ub3Pa`QRwZx#h=)i4KtO1`pN!&T0zn6%a5rBQraBKHK9^%Sw zWR6WqBip3sw)y}UJ>K|8u8m2Lt_}g3b>O$Qel8N!(ak>Auf=X!H z9By6gv+Y_94W{DUz=VN5yT~Ka+gp|fY8)H*R;ZG;fW+Yqq*T&&Nv*XkWmfRriFAmI z&Qebj!+;b4gGZ&mMP?AfkAh(YUr+crC^=Ksl>Eq=DSMnQG1VbiSzTm%Oo#MOxBA>DkYbsOD{8&W? z7?(@z2MTw~mR<`Nn9DdIzO0SNvOtg%ipn8?dgV12ez*XZ*BY|H)uRVejx|JX?GFO4N6U@uthkSwP(SJ4MHR5byk%js3eq?EmPVqRYAs zUi3BUbo96Y;l@`ul?OSUks8B!JT4)i+sE5sp`+a*qVpYB12F;N_OA@F=$D#KfI$Vq z4nX3p2URI~C*_FM?FcFW>6^p@=4vbhpwV-Jt@~Ee#^wBGRoO-O?c`4bt6$bc5vG2k?Er z?>FxK=Z4bIm!|J_aeuOQAj_dI*6)P+vS3SB5|k5+M+H4dnaa zN}Kko00c51`a)bpe2eatzntPETH-|KjHch_5ZKFHR}%6|$nFp?vsu8#jf zl#*XSKv#})!4Y2o*}=)Y^Gn+D_m z#uXRyH^lxo1}dSy+`nHyUlDNryFM6c`2YC__G}#C?(PlU|Kq^BIZ=kBtE=nF+dnfy zI;QJPT3276tU?$DMM=DyhrgnX%%jPvsbo9Lr%#_EAR_XtOy1*hKYhk&Ik`Mg-O$jW zQTnPALtUI2w}qDK=9NL2HuAQ_e2k^Myqx>UMO0L@*!}Om;SNr4L_|PhB5_S^Es|qO zZ7tv8(o(VtZkYHI5izlwsF;|pF4FB{41RMuTvK{Z^U=_yu%}11Cyv!vBc(e5iQF9n zQv9ZuC||dBbhD&4k;}OIGc&~7dDg|GN~ZQVDYp$|s?r+C;PfC~$m^|Hl4gMORN}6EZ@r6SVH5)%2SAUwf(BmY@N%S2vR|J0;zBk zI0*mI@XHpxL~0tW2b@;ZrZtz_wL4UtoOmx@ylBxvYRO6{c(;u{@3}-jtBS@Dglh0fCvtlB}#O!wnT; z#;;$$;+^2);-b2ZFLUPiJE?Z<%_aVd;ypo3BJApB+%NVTW7+I~iL#;^bh|fas@mEi z#vS1c%bn5R3JM$@d7$?-9>ug8?G=wL)D`Ef%~idggii18WTe+-RSLn2T`aIWsMFj% zp_BAitn=e7F=uBkupuSZbAsPS-m1q+Wa_xO@{ShiN6gO~eDL>Q^g$vMmX_{<4f)!^ z^>M-RWckxt+g@NJiwq4yZ@><9KYCpD-yBylbjY0YhftT9F ze`v8Sq~zV6QS9jar{X_CLRYMPd0$NLd(U_^M|8imUD`=$SzjAE?T@$)IN%X#O?{Kp zVSX(T)}BPZC?u3}Ul^&At7xH@bY%cucpQ2$KZqJf#;{iZ7WKQ}36e!e*y)UY9_Gx+ z@k=xAqshc^r9wmuf5gU9)NT33t}8kZQu+F8>pG7bOn*C1Zz(R>fSe>f-cGJ}TfL<8 zWAGTX1`E3A^WXnOVW8|2+{N%6zzlO4HYqSHL zAm1843`vvHzJ;X0(>ZT7`xtlFY!)$^dR)v;)RPRjEm-GSwGc}Pn~pNcK6v|MqW?Nn zH`VifK*Ma13GG7&ePpEmApu1tnr0{!lHokh*|3I`m z<#N2@uRz_&&)D(fydaEm7nIwq3eznpC_qL*Avv~B{>y) zu*GP<+E<~aASvmudjj`+ZtnT;>}>c@w!%F~mO`o(5u;F25*BB@>lsvWmkO`Y^-$*afw>+a1g$U~hT+axQdtL3M?){Ar{2}0TOtK^#bQ_S+MMo#6 z#U`J7$f&3i^77$gWAx0-5&#hJn6+MwvMw`%JUrd+ehNq6)LzEv?%MpJo3Lq2-z&@{ z#pTWTI`PjsmxE0~PHO6wH}gx%TI+8~@(R{S z-EFcSiBw*9;y~fN`#)vxZmfu!qPxGBD}bc=6f)%#OF~Hrlb4s*zv8#~;NDV)oCr2#cw}UuH3+*!FTmVL2_3ym{;`NK zjS&Kt`@z2c-cd>P&myyX1E{)Byw~6gSGLhk((|_#AAf?^fNpga(!+gbN`l|y4zbVn zIgc%DH5QU?JF&hXw5Rp(-%M09ygV{5wsFNWf)8CCxBoX4ju}vc|NJ2=-8ectOslPR z`&r1PQB;4N&}Ohzpj~OSF`|zXiN4rh!SYyh-o(P;3l6d* zM}d@!eXI{<hhLV(ptHLMIZj+h8%4GcceS&DS(`%o;RkYi%{?P2ixma>5SdXE*5+fBtht)fnUu zC`0W42hR{wA;(t>&fIBT?bh;^&FFGPFW{u-PU+B4omqXZFyIv#wBh2eC0H04kT3#N zhCa%{!IxS!vs+JpFoCQwIxYAik+o$Dd0_mR^YSH^nEKDwI<=!=EqNir!+n0LA3i#& z7El!b>C^IVU3U5_^I}fO@Yq=5=eNun#Sp5fpX;{NhQ-GVEdhhW!>?=`EB?y*`T5Z^ zF_m+3LZJ?dWvTwriuA3Ip4NH4%$EPr+iPioicFrb7K3MlE+(EDB{#q&R(OA&tEY{r z3ht4gc8U2pEp*OHNcQocr;BBlUxml(H3Z?BM#(9U_5^#}s|4_~{v`s>Uy~<-Fa-JnK#~uNO>} zIQq8!B6*hYinxy~maki2=QkC3vB7a4(B*FzWY?_GU)j#pzsvqb9+Kip@)FKqCAsPc zIrqWI`EIRsUwTaqk%ZKBfXQs^l=ZX3>?{$&J zVJXVaw9)0jht#?VubBL{e$Gq6lkLCk5I7^Nt0$gHCkR)brO1A+2Jf5Bs{0~$qvcLn z4o${)^GnE@X6WFa=(!SKJIh!XMH?{aU2R{_+m5|FUB<6mY?-x3WPP&xD=W4{1p+z# z7j!9=0Pu!)#~h!KMHjzF^~PAE^3n6q+HTpTtc$IDD zpX1t{hLG`b4VuS~QO!x^Qzw?ON$%B}?3^PQBi}#L$15XJxa4^toru(GsMUy%K|#TG z|DIK_iNcK7W&J<)_c<;u5P(u*18t4gKUx%4U)=c>zRaejhDlvr6cLNsA@^6NZtp;(sH|rOIh<#j1FtDheg0Y_T57B>={Ny zMMY%y%6N%!z{^w_8Z4fpv$5bOCfr9H*0Xh(^&3eXfByW@X!I2LZas&_#l@AcS5L&i zz|f+H^k23cnVcL}Q^QM-^DtXb;rVmq(z3GTGZLsPvtAuR2>w$iGt~dGX>X8}lNIKJ zzkV5MZF3IS)zx9(;4FVK=VJt%^6KhpcW)2O?L%I7I;Me+v9WPPMa5k`6Bj@JR#X%a z8;c)%f|-zzFg|g+2IAuC0i<%nv59$)C~k9+Dx<{r{~vPW|AnA|iZ`4tK62k{!>51p z#NLs?I6<81hxPR?TVdltwub9Li(+}Mj_$8Gn&?hNC8dRRwsvT;d;mb{w7k6Tk*}d< zZEbDea&zq*35^r93fkk@O$LUBme!Tqp@4%YBoL{`vSu;FNGA!njgGtyb%|xs74`6_ z&+8@_!g+|;mnyUaMy$KIaEEeo*3^Wu%CW>!^Y9RMc6JuByq2wfx0ms znB+ZueSP&o@p)DUBbd`56w+85BSw7r?TM{h=1!Q}ByO70~r0IV=YZx(?(XTHrZZ5xe_!m9US0IuywXVKm zsmBW6UGyj}3zAI!`3!Sx3+pB)!)(LC!oKC?*g6s}L;U5SZ*ZQ!0_xexk?{5tPy{m9 zHOihuu721`RE7Ri#*#_jzR?3sT~mW%2l>gF(RfCN*{+|3nif&RrJJ_Hgk`+HH^w3vW%FfK}8 z1EszJdsc6{!24!JOv`0nyzo(^ces7}yGcS9>n`@Yr^Bp1E%wk7#eBpzUf5pPUT3?j z5Ni+rleK}SKmXM^stn9kAZ56nHsfH=Ve{sntPkVv7d&TzfJqUt-ql3#UOGBDB#_m2 zJJB9`Axa2r0qma9Sl=W`cF3Bcm-k7oeu_gZC0mZZ=c5W%R;u}U3L@M zhLa10?=+nE!Pc6)?7zVntLORv;+#nw_McE33}pAZat{=zQjZlu1AauNvYio1Y1uEE zeObZ7X`kC|GVO6$45F(do}q7xmxI1t!S7;vp9R9zU*T@0SX%!RMS=b(Ewn@xZ0xNV z?`$5^-%MjrquivQ%>C-`>fJmg3wlCvs@$q-Yfx^V+3vm9<%a&dvp?B4QEN&0pDoDO zsm7wDr2IcamHhg&1GY0p;E_j@&10#bP!pc@$@8DYUJHL6%@C$)IqZzb*iF8J@@1@x z=kl3aZzS0uO}@U(cPk=^fQC$W7f6^U2fusYZ4^}MfV7!UKuZ)noxhi_-$?ef;WT67 z68Jq4f@k!tMy%lHO;4Q9^t2AX>WkaQbl7s)sXt*7EE-y~kt-DF>g{ctah^j3$qH%u z@Bq7q%xfu(%VvF;0e}h~RTMI8Bv?8m@w@!OUAavlOkI1c+AYox0Rdo9?3y?xCU6J{ z0waX32sFWK_FgdL#YG9!>!x@EEPjL@Cb{Ybg8fy_BJmq zV_;=1nd^{|kzuu66d$rQUK^gATyM?23*JDB4m1agMA%H-jzBR-qf2*w}hA8Ujgo4^yU6oKGlRYE? ze)2|YYCf^%<>ftsD5|I|FG{y_7N~`Uhojxt25fH}9UarNvjeVo>`TGDpd1`~u~qWJ z6B9w(+g4!I{LIXIsHmu>W@fek)IaML!A!cjZu!xPi9eq<*er;QfXd!IIIsu6t>)&& zSFXhjfA3zeo(Vd7i}nO?r+|vt#dWRCZDkcH8f%%EN#1!)_iN)wF>hq|=*TZ|ilaKi zMZ>|tL6`z#w|-L`X7-sv-NTaefujZ9%6>)cTU25X%KeX_p&=+=jK7|?HZefsQ#~g( zFn3^XA^VA{$FjWvwpiWCtQ#1$EyDAVT0NE>b}3sY)^G>|7=vo7!h-`qYAMp=JeQVc zb4#wTuV=NI<_0K7#m0tp<$00r#R0~hUs(8YQi=^e3f6*>k@4XTBAfL5dRwu3eEi#O zB53$6M?zvECN6FpivXQ6n8(4v61eNyT_PRrT4iABMv=O2Ji zCzh@#oUTa0%4$2lk@2HbR-yfmJ>t)&pLi@21n&v}YH0V`86{Q_0rc)db(7cC0XhHq zX4(IV4#|AZiHmB(u>dqfpjvM(t6}^Huv0wG%xGw6s$f-+FVVE;XAee8OUu2#er*FA z)x@_YaL|;~e3U6)qon!&hD;bszHw#apPq@-rx!D!`;VomnMG;K`bZue z1PPtY!S?Zg0*QExL~~(j$#;L>ZfFaDa&5n*m6fEueHnXKhMrLg0dTaq?SF^18M8+H z0GI|CatH8m=T{hPMEs1UjEyt%WC#v!XdPC29Z*|9L)QXYw*!ptLKefGmpLjo&MBdB zN`;sHzyL_jJMu_2*WIVCq~sGE{18ya-+T3&I#yPFe?*MNU$xbr%!Ytk+j%JT1MojM z2mmT1el9?|qN1YY3wE^BG&LW@%ld!%gwxvEs=>d#>Wl+S8!?Y)c;FiE9^lrbbqy9} zW!(pehm4Ny4Db+!cwODyQ6E1pmXq%ysHm&AfG)&chX4BG0DK&Hb#*gZKR_h&4U1AsR0#{o27iqCjWOtO|PBp4osbS#MKcTUgDfHB&} zCW2~eYEJdY?ps+|fp^4%l3{Dh3fXle^CRG#DE7Q`h1GA!?qbMOqwVQxmcl!5TbtZF zX}nG*&CbMW$?{zO{d;lBy0I;Q-1Fymfjc%}`03N94?tkTYB2=KZ@d`5+qB2Y3_U%4 z6Icz4F)UuY`kv5N$O!%@M+eqU3x|K**8p$R!M%;$!3y7GIZg`za3hZgRL>R@XCU5y z@*wk{s@)L|A~YEOJZ7h!fW9MyJfwNN_R8d6&K>-ujxg9;Y`X)G_M{$^63zL$kk3@G zv}6XMa6^;~vBP-!g^rLtZTjVQ_l`w9(@|VKV)JC~rXk`A2#w9@uHhZKyBmU9UPFT* zBQx{g5H5c|8f`z9y8SIw7KUdID`__W(1{Es(Nwmp=ed7V3n}ClU{>#%Qem_wGtT5$ z^2vZg{DiZ(yc3uim{K_^H8lu!82^LP3;tcke)alwHEc;^2M6Og(w(hd`g}o1Cc%)u zxivHw#B9^p*x3IazNx*bA%I7BUdHcA8?zVYzFTIS|F+Pm`+B+NLYJ3c{qMQnvSrP7 zxCO}WU#{9Tb8}}~iQAdJD*9jIOfy}Pes?W1LX-(lj0w-UJ(TcgZ?B`HCiJ-`Vdu|2 zX?Dch#}pS|b_ekJ?vB%KPF13_#iIZhyh99KV0pdtCtN$_g1&nhadzKIW?9+L$m2Ua z3o`>B=!^=W|~j0R9~&PsiMmZYandz))_2;7CFZ zAjLdTby-s@D=WYRur-ue{+16SkbnBWYT=iz^9&*+vhx*7;fMPnAW>m7(dPQ{#Q4vz zuOKU>fTs&nR<(U2fEeYd$QSNT0%SlWCntAsa+W5&MWx|__~X*((&u7 z3?^t83Ah~_fUS53^uA@vK^wk41Oob3HZd=pCQnzabQp7mEmY{qs_YRz2+kI=+!6bq zE7Q}{&s_GGDeK?dbolaJ7X0vc$BiQlFVBu3-cn>P2!NUVsBzo?-ptHSHBcB?nhMZE z(8OcWsq*fOrlWrRxUJCs`zDW{?God{>T1B|rUk<3IZt)hpaO054}ep^OM=lW0KNbf zgyrSYre~^$hctsjLKe2R!hpdAA{Sfo@{acSieU;thY z3?$x_8l750SnC1{3rn_Af(jFoR#7n`fjcdiw-KUq`V2ICLnLcE(dRb4+ehxWsQ4C(XWU= z@jqR%CX?`PEXwPqyv3X+85tS#z*oDa4lvRudyru2kQf<-;?s1m=DQOxL2mEi1WQA1 z3inToY`itJEBK#HEl0huM z6uLLq*l5~UG-5f*nZi_Jo`yja#(zIz^~9(2{*%P`uK!F6g&}qBce@nZjn|%e5duJ^ zy!#z*;$KN&7rg3u`3y22V(nzSWHOaDQV~z~B=YAo4?EibOdM%<=nxl=NyDIoOeI~> z0Pu*ajlx7cacln~Ah7TOOuVnA zdI9?gM!xT-fm*W=K)x8;c=0t)hLSEmGUe6xSB1hxr$G9MlFk9nd_MM2?RbU+U+EGa3WgF=T#978qO z-}E#YM+o>IZH%3`O@r&W452FT&zh`)vRb18?^hKHB;^G2>lynP5mA@(nEiM`XEN_F^$S`Br z#^xD#W*yzGxUY)zifU^9Nl7FC+Ot+xSE>2=N&YLqfo5p_@**1cPN%&^3BaPwxz6@R z)b$r;-F6)RnWW<4nfF_0iP6aT10UhD%tuL)0V`w~zW;hljzl27TjUPJSxNzj)*X#% z-}1vwe|3EiT+o;n5T}aDlSwyZYv?b8aS6e{u5{3PU+deVX!;JQ5;s^Ky`g>Vi5BB^ zE@gQ#6!;+IZ3GqImAr)@6zl>MIFffjX+PsMWqG4N06Kcw(XPLKBI~+fZ7q_nvhbaw zyV!8%EwV?QuHqdYRIGrxta$G#1sh^33F(~@k0IT8pQ!gu{`VoK;f@autgCFd_X&fL zd!AD2dH#KU$K0S0V~8qH*_mQH{{XA-fxhS8$9i6ujyLalgrr1XS)l2!V#^P#@9Bo$ zyr0z6mZFQCYigKof+)kh@A70W0(i@JV-8CJeato#-D~CMiVT6=t&a1TJH3A2e5`QA z${`acb`TTQz}@wuS@;kPjg}|)*N&1|3WINf%k&lS8~?pvN45qjbU%Jx{EVtF-xM{( zol0jIdpqSQ>IxlrXi` z#fP_>)SmJ`tP2QmCrewl%nMt_v`<9y%*A}ffB##M_PW%o!|cGYFc3Vj(k1Rhl*%G0 zV_rB>S`8nBLQ3k2uXamwp`6Ux-$P**vX0~(#TUOJ#Zk$$UxFxii!JUmMRsx1+*v>0 zV8w?oOhMs&c6QT0lS3E#fB6~l;C5ZI3X)E0L>bi06d3R3yP4epdM@M@xb4MWyY)kB z|BQ~vlz;#^edwX1!+NvPgKLg!1C)>^p{qS~e05RZw&EaD49}D02wDB}haNWT?t(1K z`0=O@uHSCxxmP$owbHG}Ihm<`=-D8u{z>{~)uFipmxn{opFh8GW>_+*?FKB(KQBCO zbR)?S_8Rh78bXoTYR`U5$@c^`1)in$c6aSs^EJzgjoN@tzwx*oVytv4hRa`BiPjKo zSn2r$=ki``w{cOZ)9dQo#xmyMF4$ft+=Cd2k+ht>Z)<3n+jjI^?`>FV z`VdB!De7nif$k0Oq3d*wE8&HEwejukqc~RkQ=J_jEK`8{$veH!3Z82`BzHgjomygc zE=VHiNk-}_pvIKad*{0U^vBnG2WM0=ek+`gmVD6#n*b;Pgk%w>jjpaCA(FV7ixv`> z3Lw+Bo(gzraxj=awgh%7P1#(oV_D|)H**WlnUqMYwZJpib3rD3?nj=UVAbz*8I^1j zH|9UK;4jC?f3}v%?YNFZ=C;XNcRXoc$PyqK|G>=uP=M|UslMe*NLZ?6%h=ShCdo;* z`ka5=SKOtc=-q&9G7%faw1un2#<1gXw2NbWA>rr6ra#wfR8)eGkI^7ar^7VwTwKs5 zD(+}Jh-)=c{&yHiXe3Qh)Kb9s*3m@^v#TuB3)WVD+9rU0ecur^mM@;uat<`dkRVkw z3=(buWM1w^UfiyW{F8i#5QDQF#VDK@Lr~raa#fbbgr>YD?6X+!v~31u8=%9WFM|?- z$J{(xpsl2;>IaxR@C{~xrA1)>Liw%JIanu9QX4dz9KcRQ8erDBI#&|s_s(rg)fFmPa} z3{L?{1Qj)$bro_|spUz?$%P$+)JZx5lbQxAy`aEyxV|sQHp=vKrRdy*+0ou3M6B}4 zL}#{Sp4foXHH={nsPWWxrlNWIbNSG=^%eVtC!1bXRrz&9Sg3z8e*Ex7Dlz!+ z-m`yWbMu2%0a-X!(F^c)jE8<;d?LppdmRUE795-@ltgLiJm&4_pNzXB-!HlC_NmHi zI3dC+@18-4P+e;*uc-_2OYxL7W6T7j^Xu+22djPhnL98z#iI(+4bKmL$>wi>O!^m` zsS^@k=}d{TI$W-Tq@7{X12ocp34=89U(rjX+p|OkU@9=%-V{m{O3c+avJ$EZ4#L7t zBRL8^m5hd+)+FW%FDi)!eGSlXIB;~LGCut=HeLirzpd!A!oIB-XsD zDuUVB*_+}8(3Od0q2%T! z0JQ`ryOrs2K>r`KLpt{l4lv2c!i$Ta!BlnUiX8}9TP|L@f4%t5hhAP@^fA(9=PHc2 zLiq!#@sTW?wajZ<%Y>;Z2ns$Bzh7+S&JOdRiDT1BUW6-fdLUJ3nLCBE22PfH2tnq{ zrRFSa|$B(ourW}z+Dp22?ya?{bz$Qf_-}!;4 zAyN8JMP-DNOFuaE`jo7-LmX}0{gx3EX#wM?YAKv=_i#wTc&5hjXJ23YEhqNs}BBNvcH}f45y9XO4 zY$1@ysDwbiQI?Lx9go*F{hyin8awwd1!d z{Y81W*M2+`U~~A2C1J%MjkREOzvh-jLF)h{mbf0hU0zvP_?>8d9|A1E`9ImIzxTO` zlp8PG&?BkB#a@`Jvw$BtdMA^4o0NAKPm?leL01A@!{PzQc=P_JT(_dI=Ms z6wXkVCioTZ!1`ow_ybftN~kz=ZA0euN?=&1x8j;*P0@2TpV+HC7Lb>_<6;{&GZ~uS za?D%xuXY4L8c_1FqgEE^`R^7R1T+@;eEOtHNyS9q+0g8Rp2fQLvr}hN$N{UNK{Q{l zp(*`!Hf*snJE1WM!pVEuv(NZ#G9o2VDIPm1&xnaCMLC@(EEmVuY*|7 z^tj+^JcXO&<8_U|Omt+LabGgBm~232tY~+bdu+xaczMP*Jj#{O{MfuboAyU{H?|qW zgWcnO(I-7ES)rfAne?Iew4F&s43%yY9VbT9fkU@*91)l1q4$?6lA7w#U`9#Qz$qEA^# zHDvmN>h$lK!FH$xoi< z7pa~Xa6PF4h$eQAWFzJrE~yX>`wWOu6F!)D9bpL0)jgu?|D9P-KuVYs1p2g4T`^^y z@B0c;9h=S}L5reLA{1?3C%2J|;C@e!<>8OC8(!Lh}!Z7;(wBRPTLJYRgv+XH9A( z-(q=Q=LPJ|_os-ttgOrt(X>nb=za;o^=`a8@W|^=O#y&oRfC2Uk=zzEgijfx8h1Z{ zIJQss&w~Ed;9^7hCj=F>5s|UcaCQA_8^t1@@~UQwtLR|@f=a=i<;U*7u$@S&N&EE?8yw&;;z;`V>66&nlBI1U2n1W! z(Ms9bh7FA0J;n6`>q1xOtO#UeKRczRfYx&N6Qqw)70DcoC#|&K!2}ZWMBA+-|Cfc8 zd{jRCVyi}oVqR4_R@4Gz)>oUN#si*1R{;ktR4A#icO$v-$3@wQ?ylUl?UxEK%pOCD z3T6zW^YbT-=)_DIn2Y>B{)qJ3&($iY|MWaCtkf-55>|RLO5OvFOAsA3dA-kPZdGF# zHL>-jivHcvA8%=e7fodfSIA13>GVvTsdEmGJ0TJJ400q9_V|zB{C+K#WPE&>w`wux zfqaghDk|Yf3Yp&v*lEw&I#*PBUB%3t^DGhlwFXsoOQp%a&E_&9l(SM8Xk1V+X0J@y zoyfVjcjcYhN~OEcrJtmannp2(OIdngH%!(bNh#%xUY`*5E$cPcX-30_3dA-_(zubYgFl;r+PeRk`tjhVla?L4 zdtH5+sp`VYn2@l1tAp9t$j_fLzvQn6FXaLZ69HWQja*_NkpBVU$r7ff z7*^G`L2XI->1tX^mXk-%7zE#_^yr`M9&B+tE4}|)mTmAODF;nIJ0llQ+lfRtk{)vK zMow}cmyZ)>w?_^YWur~@aE7VZ2skMRzwjHwe2VzURjDdbMnTYnIj z|JSMCaA-_|P0Zjhf=fu!QP-EjfbpX$fq&}ii}aE23umwGf%bwUO4Md5biVvmjr<6; zIlNwKVEonJs``d0XB>#)+`#mX?c>*~YJ5@BT_~hVs|PfS<_`1>dCfbi7noVEi2Z51 zx^m@6U#hzvToyrew)wkzpT3ta_>s?m&D>1ea$p{rMDWNcaM-!TJ(iL^5kQ=>)YhweC9oNP+Ax-{#>Xe;*+g|{DmY3(s>J#&AXHba} za`LpD$$$7Z46p60*0*8* z6|>G4wx*^Qk^njyW^3^AS#UE88#*ssU&q+kJGX4#>!4&{V0U8z&x+d1tazlkP5Y#1 zolBSz?80yBlk}SJn2R3>fKaBUdxCo9*${cuXTr!_B>CFWBh_5Z7#J~dSDz`B=L)k* zyNJqjRRVpFT(rX(^S=Gy;zLiXE5&tN9J{E)sw>cKtbRd9T3?B(njNc zOh{9Wwwl5AA&JGrUj5c&>x12ciiX4)VL*tL<|rzO*e3F-^1Iea=rlDsK-UErxxDj= zHnv76M2R?&F{)GiZ+3oQb31JEtGsQQKLJ8OEQugwvq$Ch{IAi@kLYiOmNuv7YkB1v zL_acwlX~kv0CE6oHrS_`&A8HFqO#(*2bD;QfK0Vknda9UvMw4^E2(Mv0tMLh#2jg) z-eGr5lZ4xj$j8T1qs&gJ(LSCr7X?lwPdW^!2`OiFF51^wRKOP!$$+>Yu2hZ#O?7ek z0fcy3TNg@cFk?%21n!7q)%oZKE3?V_KTcLD2w-EXBI!V@14GDV*QqDY=$u?FM)hdC z=D@ABOR9&}%L&BM&3~ zPRR+AlBTXj(Ny8=Zu}A?NZ`GI%VrsGsw!)}Ub>S0HlhWEcFOTBG3!Gg)iw&D#b`a! zj6XU670=2!Hnbu#=S$r~cm?#AfizF@I33k=1^aiWuRKLVIv_2OeVCFf&$v%1;)A?< z()J6HeIR}8giqZ4I(=yT`$igrtaNCEWS+9+sH%#ZlJD5;1wpVH?&xeW!E4-+Dh}rK=|2IOb2@rb1f8}<$wVU)2ssg&2?ogs0p#JS zxQ`45&RWF2VE24z*Avv*?&|8=bdZ&+EI1hw>EcE@1!p2QtdRJr9> z*M_0ki7Q=q)rZe*x5z)5uM@cZ1$_T2Jf*8`@MU!zy8!VXs}g5zYtFP2Rb$vSs z(7HifQ+6~G!LXc(C0OCU{RXu>>|Ant{_yfUWIul^nx%lPq0QU;AKNGOc#N8r(W5RW>B{%z8niI`o2|BgRL9PjplgN3uv{7t_S-t5i?*? zn1p1()BI+QGjc_Z=ZF{X9zEeNj;D|O2}U#B zi_6b{2yBC_sl`?TkbN2D(+#?-$E3V|xkN3$Y91eCi>c{_y{U$;{;NE638>4vi}=O` z^nU*LYd@a>^2qdtJw|m{*la}*@@^-Z8L$6*Hd}kpRpJ?Jqwh({FS)0*I?VkrHjy(N z5`A*7vHQseASm&@v@(GqrO1tFq2zkA|7;FP?kKqRP%P0-egdaM4rX>SPM=Hd`>XzPv3lQ#;fAD}2K7Q8N;5dw1EsG+_)+)MF5CsLeX3_UO+d_^k}9m%Tdn8HWix zI4DS(YX;4;wve>kQJxMg#t@hIkU(Z``Y|&vj~Lix!0+R93zbB2k;yXy-v>?;Ztc*i zEEX?Sql!*|_+Z8Ow_kYr+`guH;=k^k#sI*kvC?M;lT76d9OiF6ff z7Fs3mIMtZqeJ9UoFILERsMx3wCN#~mz;|<-a=skey^Bf02>>wx{ER^=Q!Kjn<~aq^ zxP#%F?s^Rp&vwc2z-CdFjMMWwYH2BB zB6|qfOGL24#5WRZ-VEuZr4O`d5h{iG%E^!kkGA`iXU` z-26hZT<+P$y{>iFms~N-svJk}4(f|+E`|~F?b2F0mB7YCC^TN7h4nI@cKSNcgI-K8 zSy7?6uvG(2m>e_Ahxp6?fapGJ-F4%Rm)~TKTZYVlSilD?_jefXUad0_NtqEI5C=*obJ;k7gkwH|v# z%D)H}13ZR^XeF!2d?VqR_2?_ma{^*7pi=KVd4dPlEYQDdR8JdytE4`M5cVAvq#vx^ z!m=!5yWHwBR9Xa_1xgSXi(d$EV%6f{>1>UdJTJ%Ze+gdN%UjiJnV6P=cn9n>qK-Y~ z*!*T9XJ^C&z98X>K|HWm{C1J@?o0Ypo}AR0_YEYFLe8W)U2O49wmaOdl=A+_Z(MkUhuXl1t4#q)q8v{XE-Xblb3S;Y*G zS@EtKL?KNPJB_%E6?i}?z)BCOsDaqV6buYyHJ?~OG{RjSqiW#L5qusDa_i4&Fd9}= z7Da)-=#8CVju#^Bkk=LV5$9YRdreQs7Y8g&Osf3=6*aR_((ij^mUycqyiRx!MOEE^ z?G&<*jwh}r#0-T=D=MmwVzkUlK3L86kPHuVh>jB1v~@^Q0q+kHE8c7(v!3zoIwsDm z>J1XeP@%W&Af}>ufpZTjpCep{JmKf|nUZ7Po+MsOIlH#Bw-DFbq}l6-klix!nb+(V zytcIHeiE?(d#U`1Tx+Yjg_h7qzS#Vb*Z{Di9=Jk4C8cEJ#{t4={N$1qd{4n>b~6W<4qT2L6L~vW zHD{9Ge`Psqe>2RuX1Dc(qZ{{IevTCAwls)Z_atz3GW4^L0vF0?$wYfFSB09Dg{X3d z8>(wB?qE4yZiY=Ox7a^57zjFPVH$a#`=xZs)9sC))f!I_`CB-M?d)SHr3rlywLU%L z{@y)3^3L?K_vh2954hcUl*N@#sDUTVpyn?!^)bnm`?!mAy6`hM>wUN@Pfqgdg|kE$ zbPv#2x6RF+Vt~VB;8-6Zs2++UFOA&G+w3Y^Pfe#CTBv{nzc>?7NwK3`A~q*xJ3FMLFjXke^<+*KUr>ayLG6$I_X)ARGG3%%jQVQwg#E3+wc>emrmnD;@GK)SqCw}`Q`V(FcqBDGc`94P#X&eJp`DUKM`+s z+D3h%H)78#I-3D?oDsZVleF8oM3T6i*aD$#+Es@GT z#}=bMY~M0JyRN(+g$YCgQ3hA?tqs1!eU%_f$Cx;(Yyqk4kiW5Mz^IN!b7Mgvp_I07I+OPKejo6*v-|PmcoRcG(ZR6Dlrh@|_{0UxtQSKMlQegnm=l z`JqM&*3k#PGe_-HT1JvBQqs#5b=@&<+Gbs{D0xex7ccp&FUeq$qCndwSsO&E5Lj?Pmag&Y;A za#fK}C3MBu8>YNa1gfT0V-x8lK>2e#lB?&lCz?f{+Y<7hp0!_htBnTL9$?$Ns0gML zxENS$@F4(^8Y#9V>(O5c?m3gyA6Gi`^zFbWf`u^Xv4~4)u-MkV%ONnae@d7FEC?mF z@DeHGmyfRYes`R9B$68sPUoEdsOLg|`$G3+Su>M6bmpo0VZ$5_xD6FG#{>2(4I!_c zbZp=^HYb^3W3HK>r?$&6KU6tt7)w}rs4_O3S(1ZTRR5Rt*T)-#M?rz77C+qOfl8|gxCC0K2!zt%vBV$nFji*^W9*3 zMAEeLYOaq$O~UHNrZ@l(-r5>PWi7uAJX!nogD8905mO>DPxoufHw zLlZlhc`w${m#`&@MBwk+A+CE>Txqc&3+yq+JmwSy_coi&#|I{9r8TqV_mcOxMqcds z4=}Q{@1{R_*3`e*=r8}pVQA~iDM8@bCd$}YM)eTvRGGtx15C}l9WS@(Q-N7Nq9M}` z`T&99?_Kn~&qB1%O3NrkdmZx~DY~OgGGWASXN${V{|gH67tn@!@*99ffk{)U=jJX; z^B(@54UFZT{K~H)LAzWpLEXhc-aX?>74u}7bGVL*t_}t2Bv93Q>_ z?d!DXYK8}=dAP)+;4~|7!b-s#pNPRw22I7>;Ws-axe534WadLDUjhoHnf9uxz`Hge zamG<~j6?LRvbF~kBNYxfJpA)1lZZ#=^g3%exAOv4v#7++K3OsG@p%)vkr6IR(1sCo z!JfT#$ZNAb@(|6s0tslmv^r&TmC;&k0lc-vJ%Hx0tSp?paK(QVY*Uhkthw|O<=|5x z^2c@JmGO#ef`%Cn0?qo`vk|T?OiVnBhshu@&>WdKpoCYlg--r2L4@a;qwI9jRyV(8 z22P-RjTs%MCD&$>hLUE8Um2}Y7tR6Xt7c)-Uw#I=^GzED29^%%>%*Pyj&gb>&6g45 z^i#TlzsD@v{M;EB^ES4$j_K6t_eMtR|1ZYg1D@-*eINdmRAyu(n-CdM_DB(BmMyZg z_uea;qR1v8B(k@R%p#lYmF%6p|L0rX-}}CQ&+q>{k5~25XTHaEo#Qx<^E`Gp7gt{; z5E_!8#vTpyX0tCz^SlwQBl`4ltp8r{``zSA9smx>kO{KH+(!RGdtfkBE^luxgoKnF zayl2W*C@L4B%p2>eNv1M@ng2?tqWm-R+VjfeWVn@Qml?&>F+`X2pl0NSF!e1iysh- z&O;}`ODL-Ugl?Pqj{XblT{>2a^T25F&dqgCr8fG7bb@f!X_PtU3v?BKe1&%f9b%6gH$FLv(>I9GKm-8Xv+O73 zcxR5Qp{npIF)|b~%ARb72LZ`ceAZ)x{98n@LHn4AwcXS%hbDaIT9)TkY}Ey<-e`e@ z6ZhD`?#U=?R@6IHWpB$J@_Z*2Fd4~TXc?7E8tw)i&!Y7j?Z zC^MF(n^5^UkSXHU<8j6ZDvr@&Yj0FkpMM_H4-T1?X}+{l^=*DglO^KTlNS=_Tv)bAXmQ+bzuYy&*18}&#U0cR86IGH*cjbyvbj#fiDhr@-u9?%adY>*SVW^QZ;s;%Q@(mjzonA}XCsl(zOodJs>Wk&=1~W)(Z3WbDqm=PysCv{l3se}DTN zU4`XJP}s|c@wW$jNKI*A;|K#)J6fs{26@A4JTkwsm|N$9&UU6eroF>``AZ6l?mB4{ zkH=vjF7OFDX0p@EZ`-mv8sL(~{eomGXNNShZ)<38%S>#aFe)Mu&{8r)n!#- zHBN98#Jt*$7YpUDIQE}N{m#7PNV9PCP+D3HgD;&iPFce?vI>DgNlm%r=^+N$)3Ql zpayrxoVqtl>#vyAUK3L3Ssy=^%j7bM7h+WM*EX7e9($B9J!fl7L#C~VQ7mry@xiAq z-C>~h-00h#bYTU+?)EP(4T;pQdpDS=dUL8%p)p{^wp*H7l9jE&&>Qr6mULha)vCN`XS@#kJ%s8j0MpAjP|_!dw(R?T+` zJQlMQ8<)*^%?MF*3QAZy;nmQpwS;vS@jV647l~3)I}e7R))_h^L|)Rt$d_}#G2>y& zSR~_aY`!1ov0*|;M&Ye3=Vk2CQxUIncYH7iXa_UxU+)DEu+wP*CxrHievQMKKje9&q=(@vayO3A1uw9pa2Un7a_jRmk(K%BJQc=QzIjtwvo4!UM!Hs za|{^Q9@%ql^>xxU12(8eD5>Do+sPOZpgnhjF*q=C_K|n(n!N&+&BlA6JXtL~y%QIw z4mHt0UjReS26 zZs0dx!l7znVJzlga04j3 zv6z?=>r8vzD29@}vS0u41_w-IK{hn>-J31Y7Nm(KHz3(@SF`{g9VWZIIRib6XlwDC zTUjU>OnWBfEXInvJ?NI4rW7=^ZvZ1cMY}%UYO{gG+r}M|mku@+;plCY9w7BYB>6Yu zUhDYsU@0rYR>*OIRp&y?uzm6K7o3;8wGnA=#3s>(GBjy-*rGic!YJ`#e8CeopranLlK|w1OWx-&gjo;;A2MW@Hb_VFM3YN-( zCTI*{VJ)|feHF%-eudsFCZvYKNz3|{2S`dIZj`vjYWngadn(`;?K)(7rdqo4G>Rl+~v%F!mdv)Xn`FMcNRh+f7>f` z>BB%geAiPNv(UgI>%Y*sV2)19`Q3zuS18Et5jR9{hX1#`O)z; zCzEZbX9jfF1UX3{u0weZF1?K#Vptk+JqNvE_aFFAdRO+}C7cLDYFogcg{%>ced+Gb zw|;@;{bSU-^lF6ED0;CsD`(7&Df^RtK6sXUPDboVFjUfji<;?ZY094GgU~ys@pl>X z?5#n8Fq5CYgid9y14~vWMsM_Xyr-StH#AqswvcA!A<{Bp3h6*YF|g57<@re7@& z2rYnt$^kx?#7!-CjE$_Qp=yBx=m^lZ#O``gbvLXl=W0QMfRNxNw?`b-nOjY&wQZKV`n;wg|&iH@|x8!iz8R0h(bS7hJ$r zGyVFleCk&SFS6em?bMlO!b8IOLpQ%N`z`tUatar6K4)k>nD$&F+MfWvFw!E$@>Vo= zwTRhiVQER^Ub5L^9mE3x&49#;YJE-oUU(3ydnmyfD(uQ07sPs+-&ilUbPKE0cRkU>?uM%So27*?C^O!hQeIuo2 zs3AL^y(x)6rmpP@_f_*U<4qFcB5O&GHS7nP0AD{g zPFw64gqpq6c1P~8+tZxKx+jx8YXtk$7I*^PS7&bzx_wsC-&iO1yIj84eb13Q=>GA; z7h&93zKW!*4xXNu55f|bXQgccq-@`5x4R{OqmP@%(Mai|oU*1*UZ*ysD!70dTTT7m zQx5wCV36?m{<#C(#5(ASv#6QphsX*6r=JfhhOP6W^w>~(8W2VP)}>%2YSk@r4z z(Uv;$m93X&<6P)3w#WX~xwg*$BUre;PxlzvBRCfq6D8&IWUU|Xdx}$4XFlJX1#BF$ z%~xXrTE^7o%l?x;ueU$-18?BR0+M-=OtpGkt(URcQ6Pl%kmA@cozlnh%cL| z)IJMqm#uGi*z>!!>*P8aTHGqrVVrQ(A6;G8QxwnsQWjzH%*8p3Q)F4e-^ajUslCTi zUr7p-w8|66ObaU4EYhJDbjy{z_Za zuzP7fg59-ZW2hrFwGP;0FH3NRJ_`!H9u>$irtZ0J_%lacIi;!!4PpCLjVBfHlaFk| zwSbZen)X}~CKg6r9iIY+`R{W;A8FVOn>6Ml$TM1*D3J`szSOuSHA@TA%!;>W?)491 zMZYq<_B6kdgclfuY;~eOGExTTv#-$-zQ#veeHIOS>*D+0X-Hzkjg}!nAfK=MZE1|& zjV3VoVx!+7Et0_Y0@_5&at>&rz=lO9xc+`Z*LyO52>SgmQxao4^6*p|-h8628x#Bb z1u)+9z=if$*E&!HdEbH@+eP%KLawNo<-So)h4wS=+oT zV@aa#CdjG!4D+#;+cpQnGC$cLH#9P__%h~m-TJo~ZySKs(ME}dLHwn8xQDa!+Xso9yZZJNKtTA!rXE>Zq^5D6jGXYI zuol$^F8zldwS=4OE$C*@$%)+kGx#Zx-@q+)R_a{UJ0%IjzMk;G1$9Db%!h12_88 zXQB35e_Wp-V-LF#znd(NS&0;Z7!Gi+8kE7~7Q@9Eckd-gqD z6O2E5SCNC^*QigdVE~)8N=JzjTw<-A=Tqr(V+Z|IE=gHveqw2lKXMKnE_>K*o|i7+ zuS>UB5@f>{I@gt`w(8#7_Q zXJ8-@#&F($aOsFG=GRkX-x~Ew28M}fS5rN|f9B%nmT60$tLI1hVb>x!H?Sxceuo^Z z#*aSRb~1ZbGaSFeW!OXqEmwdc6IS0DyNF9#VP0=2lNbo?A3vc}6`2DAG`y5+BH#&P zohHWAC#R-^I$u#|^|QB|-@?TR3&Q5He>|G^J@5T|cu_AQiF9i$3h4Og-K{51FM}{2 zNGtMMEES%U%w6Sl4P_RH%?=927E?8Pl^Mane*-#8*>2EcPgV&uyH_5t14cVwH}*#2 z>O~BEKz&Uxl~ZU%WjdayjGL+PU%86^9Kb?=@VYwbR!i%Rr@Q%LIBAH0fZ?yP=;8%# z^uBMv1XRIL=o3tHkXI^Pe)HCe~G!pfd(M4 zRlcJkzj|H$^U28-3ogzAdKS8ij^i#uz+M7oYT@Xtzu&og6Z75>YDgy~AO))LiLpM9 zYQ}=v#wk(Sb1a3$^Y28|-PW9v zik#3(U71zP>M=FK2^l;REcm_PQLkxA`00k5=P&FVE@)|x#scl{z@surEz$(&8L8!s zdDypT*f4{vyLOM~t`_Dpe_UXx$cD-QtUUu_@~wb#IC)`E{me?MnVF+c!z|l6#YS#MN<%{Y>@QdI3vZdeyoHGj0bOAB< z@y-{9_%jL$wpedHCX*OH|6B*Ymd8e&@Ta#uV_iD^Gh`GL{-GJ6xQn)g0-K49uVNg@ z=eH8SM!y*-zAu|)S}#E5jt1$U!K02Q1>s7nJwKntaF`-d2UA?mc|ilY_$d=MM?}Qz zO`pzW(+a(7W23mH(p^qN7w#G94j}VV_=-@4luL`Yu{TScD<0}J$0H*OwVpEbd=>f< z_kmF$jahIIE)SNhZc5#H=e{}uBfyGb26i_d0Z}MGwbWd8j|MScztE8Nf%O);)~Ow$ zuZ|W3&aA3*Jn$3H$eXp6Q7GxRG2UB_BY}X5nH`$NGnvH!pAkG8N!_*D-cdXK3Bs%R z!LWB&w4POC+;yng<64kclF9R*+ zzh`Ia2AId~Pcb>#!uyZHQ-Et>QG1D4YyTCGv_fOKw{FpwLvLT&WKq;@0Sy4iX(Dv?nw$@}I ztaET_ucN}fATYSbLPsL7P2L0*-n*ZEh>VhX@QZxvI)fwgP&T-5&Z66>2;ANMXHH+e zq!yt21~S0MELjMtJ*S5D!#2@A^^IG-3%7VgA9PPYmIPo!QQ5m_ENvLJ5$wVvw&@=W zL*MJj6LWBes|5m~VF234$fQgRW7=wb4M zT+rS^w8Y8T0<#1h>~O@yN%-FZr3&faDlz-+)aEbF2aMA^-F0$>PCMvvfH1KcX;D;_ zzXC-)ev7;pt@DV7T=lHzd#vYlk+x10yLF~cl(jwRBG#ASMq5C&ZvbC|r7k^)K%UjX zgVVwN>MSo5)~CjXExp1@uZ-=@At*Hsgq63`pT(X{&0hazaIdEiT=R;?+pE6zrP+5? zAHM3#q1TJ<)~8e$3-i_n4K)=XVbPalYdM3QDwiR0J5s%7zY_R*zA@rfoa#7DAwl2MK8X_4atyEUe4uwd0Ql zR^2j=>6*l@-C^sdr7j>wp`YIeNQnYk1E8WPnk9)tUhJU(r8yv+*O`x$c?iTqA^3Gf zqy#%r?yrEz*FhBAsH~i#b5^Bd5s%afLVg%={#B{_Nrypflj(^*ftw36f&-RkwqTnm&B=NrQKYp}_1Hqw7!W}IyXhHHm{mt zM&J!?%24u%bAMOx0DuR|idABy{SX=wLuh_A>3N(i2EYIVs0N_)r@5f0k1Z;BA>a4= zTQx|MtbpK^@3wAM`8>TEd-4sq9qgLVFIK!Oe`rsK^8pwx?iAzR{{letXJvXi5nQcT zI+}p6a3EwLZ*T*a&$jCeE(oL6`e!|H9eEt>>>clorQBD1Mp98xi4T+#*gndo)}{2> zr9GLJSMI4A;R&LjRqT(uAg)7m&u2j!3#%CWy@-zw9)UC83rrk*8ipY&X%|V_lsX?d zn48?<&+)6V9!uP zy$-r=c6Xi7uD7cT{&0NTqdg;V=B~U5r!uRS7p$^)sReMwR;?e?kQPHASP-|zV*1JJ zbN4qi>iiyx%_J+|6yd$jBdm>jKsKZA3 z9{LHET)MfxGe5NWq-)nX3Gi|JpjDL5dj4XXZNAK@1a4`a@#!oDoCZ<8lHl<^Oht^X}BrKhGr24k8s1`8ZqA+4lcm4DsV zJUlFrH#>2aMV-^Jlg#1O;-B$ul_Z8de6-yBXY;%8zkVyImf=CQ>OZ2n7juWcaC}i5 z!0muV%ci)+D?>ts7%1gjRilmk4`a9trf9im>X*`?fN%EJb@~F&)Dd&k_a>P0g-3KX zm?*?hV5G|~4PqtW|J?l-r3ZDZ)1 zhidEOKqgi?%|sJ@A<&wAG=Ikd|9(AF*X&^B)qJ(K911D+TBa6qFtVdw<0q%D83P_$CcyVQz#ajkG*-GUm- z50t-xUs2;j4IcUvXki}aJGfaTiHXuGt>cd_t`q4*sC_~RW*Qo!Rn5D))7^t>TTh-t zc4C4Vf9Q^LukY>)(J!CuJNNf6@gYqcamyTRI(M35F`)_UQ22?b-bUA|7NL3{IOSCp zoeM(_#7%5IGoc_fM2rb^FGAQJc5<+KU3I6kCot&qVPTa-PTru-%+%D&qF+uu)T_*p zCLv9GFh&cXfWm@U5T~Je7)xrVqZaPZT1ySisIPgI)Ea8|L>MWRxUAyZG~qH{gC-04 zoLyUZXkpa4Jx{uW1(oU-+5!Gt8b=jHh( ziaEIa>B5B&$SXV?fHI<=JbD7vVR3T?{Z^&mMvK+AvZWOhzjGI89bWZgU|`-id$ZNf z=nB+(WKIQ>NE&D^L0loJy;6*>b zOAAk2!7^LLy*g(ZgGMy58q(lvG?BM81u5I`W!HO)e6E*|18be8o>#=60MV}$l`_tpwKkLdzkk6!3GsS-#X zb+(Hnpry^b(ZZi!Lx4^>>hsmBA{p;nwY2aBucGfC=(cS|O)A|tVn1!$KO7$`swm0< zR?8U}4t8*2u)ns56KCV>6voI>CNw2meEm6#I(y6 zQfCzww$8L?f>sf+PyLab``S{jhwoH|jrH~tUs?S=o-OI%AKi&9dvTOzv5A_A~fXd-EaDZ}+f8fU!W{tFkb zI&{IzI%?TK4s6m=hXlsjyldp`s_vhKfrzCrtNSEvSsN~wVdvyi&km;0Qfy5?2lA}8 z747UDVjX;$vp#UxbF#(=HPRZL^}?9f{k%1_3Lkb4>T)o2pmTxU+EZAIuCY0;jzOx}3$nx68^iW^8{ag62iRSRx?>zW)PHs|(=-%O|&2T2sxKIQ7p7V(Q(bX#q6R7*lNNXO?t33U z%*%IU%X5q!{-Wo<72;8S$49f85pxs58|hU~b4OwHbGujh<4}+~YC~VjhwL1lHp}3b z+#}J?=--{iPL{j7YGLDj?16W^A*L=J(uEPLQNip=XLLYc5E5b41y`R8U!!{L^gCyi zr4>n#Ci}L#-^L#W{<`1Fe|YGNjEcsJb948@yWV*Af>)qb{i~B>q$eo>52p-6TJ+1A z`rnzq-nVYu_r0LE^4-j>bTwxFGc?`RruCI10}0T0kiZJkz=8aoi@ta6sFqkUP&!XQ zSZMy?WARb-iU9BDlL{Px7Oqxx4yM~UcL3qc|Lw@l%ZP`JM1=0+AtLUy?ExG@Lkh#bgW!H_Od z8&*k+v3jax|WqAlZ(y+ z0j5FxK@83YE(S3TchhIA8+59^ODqRbH~r6{SD0mdrIzz*{0%9Q==-UWt_1@f*W3aU9PzHJ7M4_yeQ^j5!X$cwr|Q~W3RmgIikN8 zlZDmo_$8=BgaRcwhBZVgIE({lo6?BsjS#T=*ITe^&Y=K6Uz>5&svB1^QfoK-wR*h) znZF2s3l}Ryk}m-p^l=Y6r>C;JK>Q1z+5?O1&eE)JUsdLOU=$^c16tc!=lJfhaNX<{ zo!m2D*}4Oz zYB^x%ZuZ`P7~C?kJPoaX=~F7gRhG4jS4GmNk-j_^HCmMDyCS1@&Ys_0wZlMrqD{Kr97?uv-t7mbD9_0|G99nvZgH&zaiRn>8Snt$U!#SCftnbjkp*0 zURh&|-GzoE$iV#|r_X!!&Lg@HXSZtC=^7W`nE`ha5h~FE`)LCLm{=7foryl?iZLFb4qxb_w~mT!_Av#VAva~waXY6 zoUiu&8)>9VrR!rI408hZK8!h`wH9 z)m!CfB`|(t>BFfkE&uGMtyvSc+Oc|IaNhAP85Rlm*3?_q6>miE`c$;5|2+b~YG4Kc zzD%pkO69Jx{N2(fWAiIMTrQ_w8U_4L?ke04)TXe%7X4Pw(#_U{Z zAR zlEX?HIhs)OyYnIG!y&VenhAuTSZ>7u$2H5RF5;)&+P+qjb=i-}ndw0?!9i{FCy9m5 znCHHJe<%(O5@!$H*N3%^@t~DvAff9V34hipp?j0Si9MYJYd6EF@r#c?6@LIOg@~pm zC8Cz$QME;-2zvy2Ud4bNdugnY82*-wPrXQv$bwKzXR8IK57>A!DcW#D1QL>F(R2Vu zA?;Hi5RsINXKOfvn!9&j6;F+bl{n1MFzpF;)DCb>qLiepA%l+&1B{ivt8Jw`#acdMZXQV5?rkk z7wow=n6Q6QL!t;&OhjJrsnv@u+j{O0b2HwfMy>4;xQX9 zgyEZ@a-e%n6!=H+XGrc$;P!2p7OTM|>9g#IX5B`d{S8UiQ5Q#=b==fwb!qEgF@%b{ zx*oi8U$vOhCh;FNEz?w(>Zpdvm_6fd)>AG}&L?Dv6F`TNZaOC$TU+Z5zr}7RD`&~g zPM`nuBM&VxxAQ{CjO>?7s~T9^6a61DGO%n7W$S)P7(dP|4#dW)SXdRy>Z@yR#`+eN zkGJo1?UvxVJwJNJd7K%5n(q5u(CXVLDt3Ga@-{oJbNP$5u`t^cq!%o4C<9+&7FPO0 z8-2id>xxCKyW?1(@4!Qt6*~M`99k0Mvp#gJxWlx5xG&TK8^7fXt7msM?IzPBvARI4 zq2RY#S^1fu_19J7sUXaw+aMN@oWEy74GLT44lf1z*$T@_!0)?^qj?2X3dmp9>-v#i z-+;}b=`_;}WQ_0r$z&ZoZ)YbbybX?XfRb$R3wB&@?L(v z_Ck-693vs3&BM&b=3tMZ_ZH4vN~*Q?LbnprjT<(|BXBoC0Vu20ia$AnEPeKUYPndD zzrVlE$%(syp3r9(!!KFt)2~W!M=g9y$9XRye^Kw`n5dPQ{FjEVjsAS{Bt4gmjI3*5 zV6M(2N7KHThHicPw^yll!o0hQ+v?nKFWX6}Noh63rCcra+Rm-pF)p7+&fFA!bZ!sE z*Y2zz5|DkjcqAQCzv5$P_1ldG?2Z2~XTecd6iJ2Qq5C;BO@&WgG>X3hbfG4rCMW0p zR#u~Cc2h}9D{SAT*&*h;B$r+D-gV#M>_Vk94s;{0$%A0*D_aNd#`ep+Zc8lB()wWH ztGJcd%{Wprd|=@;%_s*Rx>H9=c6FOvbP^OCf$EuA|Ke)j!Two)uaEipU58cUqdL}! zKrik6@#8GJuNQ0-qG|n@@K_^WYJs%zS#A#o$F1><{O?a{b`1ue`K-&}ppeO=GBRRH zN=m|pB)-1BpwHD66sa-%pE4PKK2r%||5C(I>HjTV0z={n2??n|*>GiLWU}SoiWtWM zMd4E*Y6fplmyWG#XZUB{XVJ3xcyn;3g10&h85>Aa$y7{%A)z4AadmsPI_hbK9mbb0 zUjhTI#=JHsC&=)g_gs;JilKia)+~*)m(E1JMDpWzpk1^v(+(5U(bNi@j0{uIw{f|y zTs#*T7S`OGSu*Q+<}Q%$Y|*DtZ1?Ce8;1IrmpyVo#}gFyuCaV~CX@KkOG-%C4}r-U zh)4x@76X>!zHr*k*b%ngXp3qwFFQ~!)B~pU8Mm`MJUr*n(Q$Uw&af}~+c-KtFf+^6 zn3ajWh0ZrfFgZD?Q{rnuZBtI(A$%gpvtU#3N1_AKjDhLvd>MbH8)vsXlxYCgA1IX0 zVKV*XEj;rZjicRSAFQ*ru2TJKORV&dbLi(MBCmL0JR++c!?}N&jU5=sI6FD<3Ld@{ zzln>PPs`0sd;+t^<&>4r*Rr!8*7;$@2<$BQPa(M%{FKK5Kn@{?4`$rCvtegMS!l00 zz@7tbB)CP|-@@rpC=h=^l*>T{ZqwXd>eN5oPYuh+&?)7A;>-n6q=rVQ6uGbmd2w;E zYQC-?jF?A6uIfW>RwKIG_~qs01{?p9KuBn7Rq)y?Oo!&=*BnyqJ@t9NK-9nooW^&% z18C%5K$IV}W0<5;vFg*2DU5%9Zn31b0!pIoHJ&Fp{zv7(f4bbr=-lgIB0Ec{R5Ntj#y=ZEv>;iA^bS=5MiGsPdN zoWGOm@n7;XtK{5&_%QH~=GfoEOTnLQ7Cz>HtTyf6a+;?~mw#VyDv$AJMtq=A!{BG< z+hZ<^tjN6aiSK_E8WB-b*Hs}@Eus=MBx2P0A4h3BY+gzOyY|0?kGcYFA%n8vsgMi4 z{jbaWpF*RwDKzykJK@idc)}&&Lp;zKDfpiU4l#SfVs6c;@Wj8AmSBcTJ19xw5fM%Q zo%HzUQ;ZtQ_h#bD5uwlj)PMcwNb$Q~{adm3TKoSG`8{NY{hwnC=`kv8^p~|f3QUd&Yk}9w0(2Fhs$l($}rkn zUb~&)^uY<6?bY$srOS^cePvD`sGUZlB9no<^&(5|h;wrKnhX3